@ -50,7 +50,7 @@ add_arg('--adversary-weight', default=1e2, type=float, help='Weight
# Training-schedule and device command-line options (Theano device name,
# and the epochs at which each sub-network starts participating).
add_arg('--generator-start', default=0, type=int, help='Epoch count to start training generator.')
add_arg('--discriminator-start', default=1, type=int, help='Epoch count to update the discriminator.')
add_arg('--adversarial-start', default=2, type=int, help='Epoch for generator to use discriminator.')
add_arg('--device', default='gpu0', type=str, help='Name of the CPU/GPU to use, for Theano.')
args = parser.parse_args()
@ -132,7 +132,7 @@ class DataLoader(threading.Thread):
filename = os . path . join ( self . cwd , f )
filename = os . path . join ( self . cwd , f )
try :
try :
if f not in cache :
if f not in cache :
if len ( cache ) > 317 2:
if len ( cache ) > 2048 :
del cache [ random . choice ( list ( cache . keys ( ) ) ) ]
del cache [ random . choice ( list ( cache . keys ( ) ) ) ]
img = scipy . ndimage . imread ( filename , mode = ' RGB ' )
img = scipy . ndimage . imread ( filename , mode = ' RGB ' )
@ -200,14 +200,16 @@ class Model(object):
else :
else :
self . network [ ' img ' ] = InputLayer ( ( None , 3 , None , None ) )
self . network [ ' img ' ] = InputLayer ( ( None , 3 , None , None ) )
self . network [ ' seed ' ] = self . network [ ' img ' ]
self . network [ ' seed ' ] = self . network [ ' img ' ]
self . setup_generator ( self . last_layer ( ) )
config , params = self . load_model ( )
self . setup_generator ( self . last_layer ( ) , config )
if args . train :
if args . train :
concatenated = lasagne . layers . ConcatLayer ( [ self . network [ ' img ' ] , self . network [ ' out ' ] ] , axis = 0 )
concatenated = lasagne . layers . ConcatLayer ( [ self . network [ ' img ' ] , self . network [ ' out ' ] ] , axis = 0 )
self . setup_perceptual ( concatenated )
self . setup_perceptual ( concatenated )
self . load_perceptual ( )
self . load_perceptual ( )
self . setup_discriminator ( )
self . setup_discriminator ( )
self . load_generator ( )
self . load_generator ( params )
self . compile ( )
self . compile ( )
@ -230,7 +232,8 @@ class Model(object):
self . make_layer ( name + ' -B ' , self . last_layer ( ) , units , alpha = 1.0 )
self . make_layer ( name + ' -B ' , self . last_layer ( ) , units , alpha = 1.0 )
return ElemwiseSumLayer ( [ input , self . last_layer ( ) ] ) if args . generator_residual else self . last_layer ( )
return ElemwiseSumLayer ( [ input , self . last_layer ( ) ] ) if args . generator_residual else self . last_layer ( )
def setup_generator ( self , input ) :
def setup_generator ( self , input , config ) :
for k , v in config . items ( ) : setattr ( args , k , v )
units = args . generator_filters
units = args . generator_filters
self . make_layer ( ' iter.0 ' , input , units , filter_size = ( 5 , 5 ) , pad = ( 2 , 2 ) )
self . make_layer ( ' iter.0 ' , input , units , filter_size = ( 5 , 5 ) , pad = ( 2 , 2 ) )
@ -312,21 +315,25 @@ class Model(object):
def save_generator(self):
    """Serialize the generator to disk as a ``(config, params)`` tuple.

    Weights are down-cast to float16 to keep the file small; the
    architecture flags are stored alongside them so a later load can
    rebuild a matching network before restoring the weights.
    """
    params = {}
    for name, layer in self.list_generator_layers():
        params[name] = [w.get_value().astype(np.float16) for w in layer.get_params()]
    option_keys = ('generator_blocks', 'generator_residual', 'generator_filters')
    config = {key: getattr(args, key) for key in option_keys}
    filename = args.model % 2 ** args.scales
    pickle.dump((config, params), bz2.open(filename, 'wb'))
    print('- Saved model as `{}` after training.'.format(filename))
def load_model(self):
    """Return the pickled ``(config, params)`` tuple for this model.

    When no model file exists on disk, returns a pair of empty dicts so
    callers can proceed with a freshly initialized network.
    """
    filename = args.model % 2 ** args.scales
    if not os.path.exists(filename):
        return {}, {}
    print('- Loaded file `{}` with trained model.'.format(filename))
    return pickle.load(bz2.open(filename, 'rb'))
def load_generator(self, params):
    """Restore generator weights from `params` (as written by save_generator).

    `params` maps layer name -> list of float16 weight arrays. Does nothing
    when `params` is empty (no saved model was found on disk).

    Raises AssertionError when the loaded parameters do not match the
    current network architecture.
    """
    if len(params) == 0: return
    for k, l in self.list_generator_layers():
        # Bug fix: the message contained an uninterpolated `%s` and a stray
        # trailing quote; interpolate the layer name so failures are readable.
        assert k in params, "Couldn't find layer `%s` in loaded model." % k
        assert len(l.get_params()) == len(params[k]), "Mismatch in number of parameters for layer `%s`." % k
        for p, v in zip(l.get_params(), params[k]):
            # Weights were saved as float16; widen back to float32 for Theano.
            assert v.shape == p.get_value().shape, "Mismatch in shape of parameters for layer `%s`." % k
            p.set_value(v.astype(np.float32))
#------------------------------------------------------------------------------------------------------------------
#------------------------------------------------------------------------------------------------------------------
# Training & Loss Functions
# Training & Loss Functions
@ -401,29 +408,31 @@ class NeuralEnhancer(object):
self . imsave ( ' valid/ %03i _pixels.png ' % i , scald [ i ] )
self . imsave ( ' valid/ %03i _pixels.png ' % i , scald [ i ] )
self . imsave ( ' valid/ %03i _reprod.png ' % i , repro [ i ] )
self . imsave ( ' valid/ %03i _reprod.png ' % i , repro [ i ] )
def decay_with_restart(self):
    """Endless generator of learning rates: cosine annealing with warm restarts.

    Within each period the rate falls from the upper to the lower bound along
    a half-cosine; at the end of a period both bounds are halved (with
    floors) and the cycle restarts from the top.
    """
    rate_min, rate_max, rate_mult = 1E-7, 1E-3, 0.5
    step, period, period_mult = 10, 10, 1
    while True:
        cosine = 1.0 + math.cos(step / period * math.pi)
        yield rate_min + 0.5 * (rate_max - rate_min) * cosine
        step += 1
        if step > period:
            # Restart: reset the clock and decay both bounds (floored).
            step, period = 0, int(period * period_mult)
            rate_max = max(rate_max * rate_mult, 1e-12)
            rate_min = max(rate_min * rate_mult, 1e-8)
def train ( self ) :
def train ( self ) :
images = np . zeros ( ( args . batch_size , 3 , args . batch_resolution , args . batch_resolution ) , dtype = np . float32 )
images = np . zeros ( ( args . batch_size , 3 , args . batch_resolution , args . batch_resolution ) , dtype = np . float32 )
l_min , l_max , l_mult = 1E-7 , 1E-3 , 0.2
learning_rate = self . decay_with_restart ( )
t_cur , t_i , t_mult = 120 , 150 , 1
try :
try :
i , running , start = 0 , None , time . time ( )
running , start = None , time . time ( )
for epoch in range ( args . epochs ) :
for epoch in range ( args . epochs ) :
total , stats = None , None
total , stats = None , None
for _ in range ( args . epoch_size ) :
l_r = next ( learning_rate )
i + = 1
if epoch > = args . generator_start : self . model . gen_lr . set_value ( l_r )
l_r = l_min + 0.5 * ( l_max - l_min ) * ( 1.0 + math . cos ( t_cur / t_i * math . pi ) )
if epoch > = args . discriminator_start : self . model . disc_lr . set_value ( l_r )
t_cur + = 1
l_r = 1E-4
if epoch > = args . generator_start : self . model . gen_lr . set_value ( l_r )
if epoch > = args . discriminator_start : self . model . disc_lr . set_value ( l_r )
if t_cur > = t_i :
t_cur , t_i = 0 , int ( t_i * t_mult )
l_max = max ( l_max * l_mult , 1e-11 )
l_min = max ( l_min * l_mult , 1e-7 )
for _ in range ( args . epoch_size ) :
self . thread . copy ( images )
self . thread . copy ( images )
output = self . model . fit ( images )
output = self . model . fit ( images )
losses = np . array ( output [ : 3 ] , dtype = np . float32 )
losses = np . array ( output [ : 3 ] , dtype = np . float32 )
@ -440,12 +449,12 @@ class NeuralEnhancer(object):
stats / = args . epoch_size
stats / = args . epoch_size
totals , labels = [ sum ( total ) ] + list ( total ) , [ ' total ' , ' prcpt ' , ' smthn ' , ' advrs ' ]
totals , labels = [ sum ( total ) ] + list ( total ) , [ ' total ' , ' prcpt ' , ' smthn ' , ' advrs ' ]
gen_info = [ ' {} {} {} = {:4.2e} ' . format ( ansi . WHITE_B , k , ansi . ENDC , v ) for k , v in zip ( labels , totals ) ]
gen_info = [ ' {} {} {} = {:4.2e} ' . format ( ansi . WHITE_B , k , ansi . ENDC , v ) for k , v in zip ( labels , totals ) ]
print ( ' \r Epoch # {} at {:4.1f} s {} ' . format ( epoch + 1 , time . time ( ) - start , ' ' * args . epoch_size ) )
print ( ' \r Epoch # {} at {:4.1f} s , lr={:4.2e} {} ' . format ( epoch + 1 , time . time ( ) - start , l_r , ' ' * args . epoch_size ) )
print ( ' - generator {} ' . format ( ' ' . join ( gen_info ) ) )
print ( ' - generator {} ' . format ( ' ' . join ( gen_info ) ) )
real , fake = stats [ : args . batch_size ] , stats [ args . batch_size : ]
real , fake = stats [ : args . batch_size ] , stats [ args . batch_size : ]
print ( ' - discriminator ' , real . mean ( ) , len ( np . where ( real > 0.5 ) [ 0 ] ) , fake . mean ( ) , len ( np . where ( fake < 0.5 ) [ 0 ] ) )
print ( ' - discriminator ' , real . mean ( ) , len ( np . where ( real > 0.5 ) [ 0 ] ) , fake . mean ( ) , len ( np . where ( fake < 0.5 ) [ 0 ] ) )
if epoch == args . adversar y _start- 1 :
if epoch == args . adversar ial _start- 1 :
print ( ' - adversary mode: generator engaging discriminator. ' )
print ( ' - adversary mode: generator engaging discriminator. ' )
self . model . adversary_weight . set_value ( args . adversary_weight )
self . model . adversary_weight . set_value ( args . adversary_weight )
running = None
running = None