@ -130,7 +130,10 @@ class DataLoader(threading.Thread):
self . data_copied = threading . Event ( )
self . resolution = args . batch_resolution
self . seed_resolution = int ( args . batch_resolution / 2 * * args . scales )
self . buffer = np . zeros ( ( args . buffer_size , 3 , self . resolution , self . resolution ) , dtype = np . float32 )
self . seed_buffer = np . zeros ( ( args . buffer_size , 3 , self . seed_resolution , self . seed_resolution ) , dtype = np . float32 )
self . files = glob . glob ( args . train )
if len ( self . files ) == 0 :
error ( " There were no files found to train from searching for ` {} ` " . format ( args . train ) ,
@ -168,17 +171,23 @@ class DataLoader(threading.Thread):
i = self . available . pop ( )
self . buffer [ i ] = np . transpose ( copy / 255.0 - 0.5 , ( 2 , 0 , 1 ) )
seed_copy = scipy . misc . imresize ( copy ,
size = ( self . seed_resolution , self . seed_resolution ) ,
interp = ' bilinear ' )
self . seed_buffer [ i ] = np . transpose ( seed_copy / 255.0 - 0.5 , ( 2 , 0 , 1 ) )
self . ready . add ( i )
if len ( self . ready ) > = args . batch_size :
self . data_ready . set ( )
def copy ( self , outp ut) :
def copy ( self , images_out, seeds_o ut) :
self . data_ready . wait ( )
self . data_ready . clear ( )
for i , j in enumerate ( random . sample ( self . ready , args . batch_size ) ) :
output [ i ] = self . buffer [ j ]
images_out [ i ] = self . buffer [ j ]
seeds_out [ i ] = self . seed_buffer [ j ]
self . available . add ( j )
self . data_copied . set ( )
@ -214,7 +223,7 @@ class Model(object):
self . network = collections . OrderedDict ( )
if args . train :
self . network [ ' img ' ] = InputLayer ( ( None , 3 , None , None ) )
self . network [ ' seed ' ] = PoolLayer( self . network [ ' img ' ] , pool_size = 2 * * args . scales , mode = ' average_exc_pad ' )
self . network [ ' seed ' ] = InputLayer( ( None , 3 , None , None ) )
else :
self . network [ ' img ' ] = InputLayer ( ( None , 3 , None , None ) )
self . network [ ' seed ' ] = self . network [ ' img ' ]
@ -380,9 +389,10 @@ class Model(object):
def compile ( self ) :
# Helper function for rendering test images during training, or standalone non-training mode.
input_tensor = T . tensor4 ( )
input_layers = { self . network [ ' img ' ] : input_tensor }
seed_tensor = T . tensor4 ( )
input_layers = { self . network [ ' img ' ] : input_tensor , self . network [ ' seed ' ] : seed_tensor }
output = lasagne . layers . get_output ( [ self . network [ k ] for k in [ ' img ' , ' seed ' , ' out ' ] ] , input_layers , deterministic = True )
self . predict = theano . function ( [ input_tensor ], output )
self . predict = theano . function ( [ input_tensor , seed_tensor ], output )
if not args . train : return
@ -408,7 +418,7 @@ class Model(object):
# Combined Theano function for updating both generator and discriminator at the same time.
updates = collections . OrderedDict ( list ( gen_updates . items ( ) ) + list ( disc_updates . items ( ) ) )
self . fit = theano . function ( [ input_tensor ], gen_losses + [ disc_out . mean ( axis = ( 1 , 2 , 3 ) ) ] , updates = updates )
self . fit = theano . function ( [ input_tensor , seed_tensor ], gen_losses + [ disc_out . mean ( axis = ( 1 , 2 , 3 ) ) ] , updates = updates )
@ -448,7 +458,9 @@ class NeuralEnhancer(object):
if t_cur % args . learning_period == 0 : l_r * = args . learning_decay
def train ( self ) :
seed_size = int ( args . batch_resolution / 2 * * args . scales )
images = np . zeros ( ( args . batch_size , 3 , args . batch_resolution , args . batch_resolution ) , dtype = np . float32 )
seeds = np . zeros ( ( args . batch_size , 3 , seed_size , seed_size ) , dtype = np . float32 )
learning_rate = self . decay_learning_rate ( )
try :
running , start = None , time . time ( )
@ -459,8 +471,8 @@ class NeuralEnhancer(object):
if epoch > = args . discriminator_start : self . model . disc_lr . set_value ( l_r )
for _ in range ( args . epoch_size ) :
self . thread . copy ( images )
output = self . model . fit ( images )
self . thread . copy ( images , seeds )
output = self . model . fit ( images , seeds )
losses = np . array ( output [ : 3 ] , dtype = np . float32 )
stats = ( stats + output [ 3 ] ) if stats is not None else output [ 3 ]
total = total + losses if total is not None else losses
@ -469,7 +481,7 @@ class NeuralEnhancer(object):
running = l if running is None else running * 0.95 + 0.05 * l
print ( ' ↑ ' if l > running else ' ↓ ' , end = ' ' , flush = True )
orign , scald , repro = self . model . predict ( images )
orign , scald , repro = self . model . predict ( images , seeds )
self . show_progress ( orign , scald , repro )
total / = args . epoch_size
stats / = args . epoch_size