You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
171 lines
7.0 KiB
171 lines
7.0 KiB
import logging
|
|
from functools import partial
|
|
import os
|
|
import cv2
|
|
import torch
|
|
import torch.optim as optim
|
|
import tqdm
|
|
import yaml
|
|
from joblib import cpu_count
|
|
from torch.utils.data import DataLoader
|
|
import random
|
|
from dataset import PairedDataset
|
|
from metric_counter import MetricCounter
|
|
from models.losses import get_loss
|
|
from models.models import get_model
|
|
from models.networks import get_nets
|
|
import numpy as np
|
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
|
cv2.setNumThreads(0)
|
|
|
|
|
|
class Trainer:
    """Drive training and periodic validation of a single generator network.

    The network, loss, optimizer and scheduler are built lazily by
    ``train()`` (via ``_init_params``) from the ``config`` mapping given to the
    constructor. All settings are read from ``self.config``; the class does not
    depend on any module-level ``config`` variable.
    """

    def __init__(self, config, train: DataLoader, val: DataLoader):
        """Store the configuration and the data loaders.

        Args:
            config: Mapping of experiment settings. This class reads the keys
                ``experiment_desc``, ``num_epochs``, ``model``,
                ``optimizer`` (``name``, ``lr``), ``scheduler`` (``name``,
                ``min_lr``) and the optional ``train_batches_per_epoch`` /
                ``val_batches_per_epoch``.
            train: Loader iterated once per training epoch.
            val: Loader iterated on validation epochs.
        """
        self.config = config
        self.train_dataset = train
        self.val_dataset = val
        self.metric_counter = MetricCounter(config['experiment_desc'])

    def train(self):
        """Run the full training loop, resuming from a checkpoint if present.

        If ``last_Stripformer_gopro.pth`` exists, model, optimizer and
        scheduler state are restored from it and training continues with the
        epoch after the one recorded in the checkpoint. Otherwise the generator
        weights are initialised from ``final_Stripformer_pretrained.pth``.

        After every epoch a ``last_<experiment_desc>.pth`` checkpoint is
        written; additional ``best_``, numbered ``last_`` and ``final_`` files
        are written under the conditions shown in the loop below.
        """
        self._init_params()
        num_epochs = self.config['num_epochs']
        experiment = self.config['experiment_desc']

        start_epoch = 0
        # NOTE(review): the resume path is hard-coded and only matches the
        # 'last_{experiment_desc}.pth' file written below when experiment_desc
        # is 'Stripformer_gopro' - confirm this is intended.
        if os.path.exists('last_Stripformer_gopro.pth'):
            print('load_pretrained')
            training_state = torch.load('last_Stripformer_gopro.pth')
            # The checkpoint records the epoch that had already completed (and
            # the scheduler had already stepped past), so resume with the next.
            start_epoch = training_state['epoch'] + 1
            self._load_merged_state(self.netG, training_state['model_state'])
            self._load_merged_state(self.optimizer_G, training_state['optimizer_state'])
            self._load_merged_state(self.scheduler_G, training_state['scheduler_state'])
        else:
            print('load_GoPro_pretrained')
            # This file holds a bare model state dict, not a training state.
            training_state = torch.load('final_Stripformer_pretrained.pth')
            self._load_merged_state(self.netG, training_state)

        for epoch in range(start_epoch, num_epochs):
            is_last_epoch = epoch == num_epochs - 1

            self._run_epoch(epoch)
            if epoch % 30 == 0 or is_last_epoch:
                self._validate(epoch)
            self.scheduler_G.step()

            training_state = {
                'epoch': epoch,
                'model_state': self.netG.state_dict(),
                'scheduler_state': self.scheduler_G.state_dict(),
                'optimizer_state': self.optimizer_G.state_dict(),
            }
            if self.metric_counter.update_best_model():
                torch.save(training_state['model_state'], 'best_{}.pth'.format(experiment))

            if epoch % 200 == 0:
                torch.save(training_state, 'last_{}_{}.pth'.format(experiment, epoch))

            if is_last_epoch:
                torch.save(training_state['model_state'], 'final_{}.pth'.format(experiment))

            torch.save(training_state, 'last_{}.pth'.format(experiment))
            # Lazy %-args: the message is only formatted if DEBUG is enabled.
            logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s",
                          experiment, epoch, self.metric_counter.loss_message())

    @staticmethod
    def _load_merged_state(target, saved_state):
        """Overlay ``saved_state`` on ``target``'s current state and load it.

        Starting from ``target.state_dict()`` means entries missing from
        ``saved_state`` keep their current values.
        """
        merged = target.state_dict()
        merged.update(saved_state)
        target.load_state_dict(merged)

    def _epoch_size(self, config_key, loader):
        """Return the batch budget for one pass over ``loader``.

        Uses ``self.config[config_key]`` when it is present and truthy,
        otherwise ``len(loader)``.
        """
        return self.config.get(config_key) or len(loader)

    def _run_epoch(self, epoch):
        """Train for one epoch and write the collected metrics to tensorboard.

        Args:
            epoch: Zero-based epoch index, used for the progress-bar label and
                the tensorboard step.
        """
        self.metric_counter.clear()
        # Shown in the progress bar; taken from the last parameter group.
        lr = self.optimizer_G.param_groups[-1]['lr']

        epoch_size = self._epoch_size('train_batches_per_epoch', self.train_dataset)
        tq = tqdm.tqdm(self.train_dataset, total=epoch_size)
        tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
        for i, data in enumerate(tq):
            inputs, targets = self.model.get_input(data)
            outputs = self.netG(inputs)
            self.optimizer_G.zero_grad()
            loss_G = self.criterionG(outputs, targets, inputs)
            loss_G.backward()
            self.optimizer_G.step()
            self.metric_counter.add_losses(loss_G.item())
            curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
            self.metric_counter.add_metrics(curr_psnr, curr_ssim)
            tq.set_postfix(loss=self.metric_counter.loss_message())
            if i == 0:
                # Only the first batch of the epoch is kept for visualisation.
                self.metric_counter.add_image(img_for_vis, tag='train')
            # Stop after exactly `epoch_size` batches have been processed.
            if i + 1 >= epoch_size:
                break
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch)

    def _validate(self, epoch):
        """Evaluate on the validation loader without tracking gradients.

        Args:
            epoch: Zero-based epoch index used as the tensorboard step.
        """
        self.metric_counter.clear()
        epoch_size = self._epoch_size('val_batches_per_epoch', self.val_dataset)
        tq = tqdm.tqdm(self.val_dataset, total=epoch_size)
        tq.set_description('Validation')
        for i, data in enumerate(tq):
            with torch.no_grad():
                inputs, targets = self.model.get_input(data)
                outputs = self.netG(inputs)
                loss_G = self.criterionG(outputs, targets, inputs)
                self.metric_counter.add_losses(loss_G.item())
                curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
                self.metric_counter.add_metrics(curr_psnr, curr_ssim)
                if i == 0:
                    # Only the first batch is kept for visualisation.
                    self.metric_counter.add_image(img_for_vis, tag='val')
            # Stop after exactly `epoch_size` batches have been processed.
            if i + 1 >= epoch_size:
                break
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch, validation=True)

    def _get_optim(self, params):
        """Build the optimizer named in ``config['optimizer']['name']``.

        Args:
            params: Iterable of parameters to optimise.

        Returns:
            An ``optim.Adam`` instance using ``config['optimizer']['lr']``.

        Raises:
            ValueError: If the configured name is not ``'adam'``.
        """
        optimizer_cfg = self.config['optimizer']
        if optimizer_cfg['name'] != 'adam':
            raise ValueError("Optimizer [%s] not recognized." % optimizer_cfg['name'])
        return optim.Adam(params, lr=optimizer_cfg['lr'])

    def _get_scheduler(self, optimizer):
        """Build the LR scheduler named in ``config['scheduler']['name']``.

        Args:
            optimizer: Optimizer whose learning rate is scheduled.

        Returns:
            A ``CosineAnnealingLR`` with ``T_max=config['num_epochs']`` and
            ``eta_min=config['scheduler']['min_lr']``.

        Raises:
            ValueError: If the configured name is not ``'cosine'``.
        """
        scheduler_cfg = self.config['scheduler']
        if scheduler_cfg['name'] != 'cosine':
            raise ValueError("Scheduler [%s] not recognized." % scheduler_cfg['name'])
        return CosineAnnealingLR(optimizer, T_max=self.config['num_epochs'],
                                 eta_min=scheduler_cfg['min_lr'])

    def _init_params(self):
        """Create the loss, network, model wrapper, optimizer and scheduler.

        The network is moved to CUDA before the optimizer is built, and only
        parameters with ``requires_grad`` set are handed to the optimizer.
        """
        model_cfg = self.config['model']
        self.criterionG = get_loss(model_cfg)
        self.netG = get_nets(model_cfg)
        self.netG.cuda()
        self.model = get_model(model_cfg)
        self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters()))
        self.scheduler_G = self._get_scheduler(self.optimizer_G)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: load the YAML config, seed the RNGs, build the
    # train/val loaders and run training.
    # NOTE(review): `config` is bound at module level under this exact name;
    # other code in this file may look it up as a global - confirm before
    # renaming it or moving this body into a function.
    with open('config/config_Stripformer_gopro.yaml', 'r') as config_file:
        config = yaml.safe_load(config_file)

    # Enable cuDNN and its benchmark mode.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    # Seed every random number generator used here with the same fixed value.
    seed = 666
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)

    # 'batch_size', 'train' and 'val' are popped, so the config handed to the
    # Trainer no longer contains them.
    loader_kwargs = dict(
        batch_size=config.pop('batch_size'),
        num_workers=cpu_count(),
        shuffle=True,
        drop_last=False,
    )
    # The generator is consumed by the unpacking, one split at a time:
    # pop the split's settings, build its dataset, then wrap it in a loader.
    train, val = (
        DataLoader(PairedDataset.from_config(config.pop(split)), **loader_kwargs)
        for split in ('train', 'val')
    )

    trainer = Trainer(config, train=train, val=val)
    trainer.train()
|