You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

14 lines
413 B

import torch.nn as nn
from models.Stripformer import Stripformer
def get_generator(model_config):
    """Build the generator network named in *model_config* and wrap it in DataParallel.

    Args:
        model_config: Mapping read with ``model_config['g_name']`` to select the
            generator architecture. Presumably a dict loaded from a config
            file — confirm against callers.

    Returns:
        ``nn.DataParallel`` wrapping the constructed generator. Currently only
        ``'Stripformer'`` (built with its default constructor arguments) is
        supported.

    Raises:
        KeyError: if *model_config* has no ``'g_name'`` entry (plain dict
            subscripting; other mapping types may raise differently).
        ValueError: if ``g_name`` names an unsupported generator.
    """
    generator_name = model_config['g_name']
    if generator_name != 'Stripformer':
        # f-string rather than `%`: with `%`, a tuple-valued name would be
        # unpacked as format arguments and raise TypeError instead of this
        # ValueError.
        raise ValueError(f"Generator Network [{generator_name}] not recognized.")
    # Only reached for a supported name, so Stripformer is constructed lazily.
    return nn.DataParallel(Stripformer())
def get_nets(model_config):
    """Return the network(s) described by *model_config*.

    This currently delegates entirely to ``get_generator``. Only the
    generator is built, and its result is returned unchanged.
    """
    generator = get_generator(model_config)
    return generator