You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
14 lines
413 B
14 lines
413 B
import torch.nn as nn
|
|
from models.Stripformer import Stripformer
|
|
|
|
def get_generator(model_config):
    """Build the generator network named in ``model_config``.

    Args:
        model_config: Mapping that is read for its ``'g_name'`` entry,
            which names the generator architecture to construct.

    Returns:
        The constructed generator wrapped in ``nn.DataParallel``.

    Raises:
        ValueError: If ``'g_name'`` is anything other than
            ``'Stripformer'``.
    """
    requested = model_config['g_name']

    # Guard clause: 'Stripformer' is the only name this factory accepts.
    if requested != 'Stripformer':
        raise ValueError("Generator Network [%s] not recognized." % requested)

    return nn.DataParallel(Stripformer())
|
|
|
|
def get_nets(model_config):
    """Return the network built from ``model_config``.

    Delegates to ``get_generator`` with the same configuration and
    returns its result unchanged.
    """
    generator = get_generator(model_config)
    return generator
|