Unverified commit 96551765, authored by Birdylx, committed by GitHub

Support amp for esrgan (#712)

Parent af681065
......@@ -133,7 +133,7 @@ class Trainer:
                                                    cfg.optimizer)
         # setup amp train
-        self.scaler = self.setup_amp_train() if self.cfg.amp else None
+        self.scalers = self.setup_amp_train() if self.cfg.amp else None
 
         # multiple gpus prepare
         if ParallelEnv().nranks > 1:
......@@ -164,11 +164,10 @@ class Trainer:
         self.profiler_options = cfg.profiler_options
 
     def setup_amp_train(self):
-        """ decerate model, optimizer and return a GradScaler """
+        """ decorate model, optimizer and return a list of GradScalers """
         self.logger.info('use AMP to train. AMP level = {}'.format(
             self.cfg.amp_level))
-        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
         # need to decorate model and optim if amp_level == 'O2'
         if self.cfg.amp_level == 'O2':
             nets, optimizers = list(self.model.nets.values()), list(
......@@ -181,7 +180,13 @@
                 self.model.nets[k] = nets[i]
             for i, (k, _) in enumerate(self.optimizers.items()):
                 self.optimizers[k] = optimizers[i]
-        return scaler
+
+        scalers = [
+            paddle.amp.GradScaler(init_loss_scaling=1024)
+            for _ in range(len(self.optimizers))
+        ]
+        return scalers
 
     def distributed_data_parallel(self):
         paddle.distributed.init_parallel_env()
......@@ -223,7 +228,7 @@ class Trainer:
                 self.model.setup_input(data)
 
                 if self.cfg.amp:
-                    self.model.train_iter_amp(self.optimizers, self.scaler,
+                    self.model.train_iter_amp(self.optimizers, self.scalers,
                                               self.cfg.amp_level)  # amp train
                 else:
                     self.model.train_iter(self.optimizers)  # norm train
......
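
For readers skimming the diff: the trainer now builds one `GradScaler` per optimizer instead of a single shared scaler. A minimal sketch of the resulting setup logic, assuming the `nets`/`optimizers` dicts the trainer already holds (a paraphrase, not the exact PaddleGAN code):

```python
import paddle

def setup_amp_train(nets, optimizers, amp_level='O1'):
    """Sketch: decorate for O2, then return one GradScaler per optimizer."""
    if amp_level == 'O2':
        # O2 keeps parameters in float16, so models and optimizers must be
        # decorated to maintain float32 master weights.
        net_list, opt_list = paddle.amp.decorate(
            models=list(nets.values()),
            optimizers=list(optimizers.values()),
            level='O2')
        for i, k in enumerate(list(nets)):
            nets[k] = net_list[i]
        for i, k in enumerate(list(optimizers)):
            optimizers[k] = opt_list[i]
    # One scaler per optimizer, so each loss keeps its own loss-scaling state.
    return [paddle.amp.GradScaler(init_loss_scaling=1024) for _ in optimizers]
```

Returning a list rather than a single scaler is what lets a GAN model such as ESRGAN scale its generator and discriminator losses independently.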
......@@ -76,7 +76,7 @@ class EDVRModel(BaseSRModel):
         self.current_iter += 1
 
     # amp train with brute force implementation
-    def train_iter_amp(self, optims=None, scaler=None, amp_level='O1'):
+    def train_iter_amp(self, optims=None, scalers=None, amp_level='O1'):
         optims['optim'].clear_grad()
         if self.tsa_iter:
             if self.current_iter == 1:
......@@ -97,9 +97,9 @@ class EDVRModel(BaseSRModel):
             loss_pixel = self.pixel_criterion(self.output, self.gt)
             self.losses['loss_pixel'] = loss_pixel
 
-        scaled_loss = scaler.scale(loss_pixel)
+        scaled_loss = scalers[0].scale(loss_pixel)
         scaled_loss.backward()
-        scaler.minimize(optims['optim'], scaled_loss)
+        scalers[0].minimize(optims['optim'], scaled_loss)
 
         self.current_iter += 1
......
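
The single-optimizer models (EDVR above, plus the multi-stage VSR and base SR models further down) all share the same three-call AMP step. A hedged sketch of that pattern, with illustrative names:

```python
import paddle

def amp_step(loss, optimizer, scaler):
    """One AMP update: scale -> backward -> minimize (illustrative helper)."""
    # Multiplying the loss by the current loss scale keeps small fp16
    # gradients from underflowing to zero during backward.
    scaled = scaler.scale(loss)
    scaled.backward()
    # minimize() unscales the gradients, skips the update when any gradient
    # is inf/nan, applies the optimizer step, and adapts the loss scale.
    scaler.minimize(optimizer, scaled)
```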
......@@ -29,6 +29,7 @@ class ESRGAN(BaseSRModel):
     ESRGAN paper: https://arxiv.org/pdf/1809.00219.pdf
     """
+
     def __init__(self,
                  generator,
                  discriminator=None,
......@@ -127,3 +128,87 @@ class ESRGAN(BaseSRModel):
         else:
             l_total.backward()
             optimizers['optimG'].step()
+
+    # amp training
+    def train_iter_amp(self, optimizers=None, scalers=None, amp_level='O1'):
+        optimizers['optimG'].clear_grad()
+        l_total = 0
+        # put fwd and loss computation in amp context
+        with paddle.amp.auto_cast(enable=True, level=amp_level):
+            self.output = self.nets['generator'](self.lq)
+            self.visual_items['output'] = self.output
+            # pixel loss
+            if self.pixel_criterion:
+                l_pix = self.pixel_criterion(self.output, self.gt)
+                l_total += l_pix
+                self.losses['loss_pix'] = l_pix
+            if self.perceptual_criterion:
+                l_g_percep, l_g_style = self.perceptual_criterion(
+                    self.output, self.gt)
+                if l_g_percep is not None:
+                    l_total += l_g_percep
+                    self.losses['loss_percep'] = l_g_percep
+                if l_g_style is not None:
+                    l_total += l_g_style
+                    self.losses['loss_style'] = l_g_style
+
+        # gan loss (relativistic gan)
+        if hasattr(self, 'gan_criterion'):
+            self.set_requires_grad(self.nets['discriminator'], False)
+            # put fwd and loss computation in amp context
+            with paddle.amp.auto_cast(enable=True, level=amp_level):
+                real_d_pred = self.nets['discriminator'](self.gt).detach()
+                fake_g_pred = self.nets['discriminator'](self.output)
+                l_g_real = self.gan_criterion(real_d_pred -
+                                              paddle.mean(fake_g_pred),
+                                              False,
+                                              is_disc=False)
+                l_g_fake = self.gan_criterion(fake_g_pred -
+                                              paddle.mean(real_d_pred),
+                                              True,
+                                              is_disc=False)
+                l_g_gan = (l_g_real + l_g_fake) / 2
+                l_total += l_g_gan
+                self.losses['l_g_gan'] = l_g_gan
+
+            scaled_l_total = scalers[0].scale(l_total)
+            scaled_l_total.backward()
+            scalers[0].minimize(optimizers['optimG'], scaled_l_total)
+
+            self.set_requires_grad(self.nets['discriminator'], True)
+            optimizers['optimD'].clear_grad()
+            with paddle.amp.auto_cast(enable=True, level=amp_level):
+                # real
+                fake_d_pred = self.nets['discriminator'](self.output).detach()
+                real_d_pred = self.nets['discriminator'](self.gt)
+                l_d_real = self.gan_criterion(
+                    real_d_pred - paddle.mean(fake_d_pred), True,
+                    is_disc=True) * 0.5
+                # fake
+                fake_d_pred = self.nets['discriminator'](self.output.detach())
+                l_d_fake = self.gan_criterion(
+                    fake_d_pred - paddle.mean(real_d_pred.detach()),
+                    False,
+                    is_disc=True) * 0.5
+                l_temp = l_d_real + l_d_fake
+
+            scaled_l_temp = scalers[1].scale(l_temp)
+            scaled_l_temp.backward()
+            scalers[1].minimize(optimizers['optimD'], scaled_l_temp)
+
+            self.losses['l_d_real'] = l_d_real
+            self.losses['l_d_fake'] = l_d_fake
+            self.losses['out_d_real'] = paddle.mean(real_d_pred.detach())
+            self.losses['out_d_fake'] = paddle.mean(fake_d_pred.detach())
+        else:
+            scaled_l_total = scalers[0].scale(l_total)
+            scaled_l_total.backward()
+            scalers[0].minimize(optimizers['optimG'], scaled_l_total)
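
This new method is the heart of the commit: generator and discriminator each get their own scaler (`scalers[0]` and `scalers[1]`), since loss-scaling state tuned for one loss can over- or underflow for the other. A condensed sketch of the alternating update; `g`, `d`, `opt_g`, `opt_d`, and `loss_fn` are placeholders, not PaddleGAN APIs:

```python
import paddle

def gan_amp_step(g, d, opt_g, opt_d, scalers, lq, gt, loss_fn, amp_level='O1'):
    # --- generator update (the real code disables d's gradients here) ---
    opt_g.clear_grad()
    with paddle.amp.auto_cast(enable=True, level=amp_level):
        fake = g(lq)
        loss_g = loss_fn(d(fake), target_is_real=True)
    scaled_g = scalers[0].scale(loss_g)
    scaled_g.backward()
    scalers[0].minimize(opt_g, scaled_g)

    # --- discriminator update, with its own independent scaler ---
    opt_d.clear_grad()
    with paddle.amp.auto_cast(enable=True, level=amp_level):
        loss_d = (loss_fn(d(gt), target_is_real=True) +
                  loss_fn(d(fake.detach()), target_is_real=False))
    scaled_d = scalers[1].scale(loss_d)
    scaled_d.backward()
    scalers[1].minimize(opt_d, scaled_d)
```

Note that each loss is scaled and minimized by the same scaler; mixing them would unscale gradients by the wrong factor.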
......@@ -98,7 +98,7 @@ class MultiStageVSRModel(BaseSRModel):
         self.current_iter += 1
 
     # amp train with brute force implementation
-    def train_iter_amp(self, optims=None, scaler=None, amp_level='O1'):
+    def train_iter_amp(self, optims=None, scalers=None, amp_level='O1'):
         optims['optim'].clear_grad()
         if self.fix_iter:
             if self.current_iter == 1:
......@@ -133,9 +133,9 @@ class MultiStageVSRModel(BaseSRModel):
                             if 'loss_pix' in _key)
             self.losses['loss'] = self.loss
 
-        scaled_loss = scaler.scale(self.loss)
+        scaled_loss = scalers[0].scale(self.loss)
         scaled_loss.backward()
-        scaler.minimize(optims['optim'], scaled_loss)
+        scalers[0].minimize(optims['optim'], scaled_loss)
 
         self.current_iter += 1
......
......@@ -27,6 +27,7 @@ from ..modules.init import reset_parameters
 class BaseSRModel(BaseModel):
     """Base SR model for single image super-resolution.
     """
+
     def __init__(self, generator, pixel_criterion=None, use_init_weight=False):
         """
         Args:
......@@ -65,6 +66,22 @@ class BaseSRModel(BaseModel):
         loss_pixel.backward()
         optims['optim'].step()
 
+    # amp training
+    def train_iter_amp(self, optims=None, scalers=None, amp_level='O1'):
+        optims['optim'].clear_grad()
+        # put fwd and loss computation in amp context
+        with paddle.amp.auto_cast(enable=True, level=amp_level):
+            self.output = self.nets['generator'](self.lq)
+            self.visual_items['output'] = self.output
+            # pixel loss
+            loss_pixel = self.pixel_criterion(self.output, self.gt)
+            self.losses['loss_pixel'] = loss_pixel
+
+        scaled_loss_pixel = scalers[0].scale(loss_pixel)
+        scaled_loss_pixel.backward()
+        scalers[0].minimize(optims['optim'], scaled_loss_pixel)
+
     def test_iter(self, metrics=None):
         self.nets['generator'].eval()
         with paddle.no_grad():
......@@ -84,6 +101,7 @@ class BaseSRModel(BaseModel):
 def init_sr_weight(net):
+
     def reset_func(m):
         if hasattr(m, 'weight') and (not isinstance(
                 m, (nn.BatchNorm, nn.BatchNorm2D))):
......
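
To sanity-check the recipe end to end, here is a self-contained toy run of the same scale/backward/minimize pattern. The one-layer "generator" and random data are invented for illustration; run on a GPU for float16 to actually kick in:

```python
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

gen = nn.Conv2D(3, 3, 3, padding=1)  # stand-in "generator"
opt = paddle.optimizer.Adam(parameters=gen.parameters())
scalers = [paddle.amp.GradScaler(init_loss_scaling=1024)]

lq = paddle.rand([2, 3, 32, 32])  # fake low-res batch
gt = paddle.rand([2, 3, 32, 32])  # fake ground truth

for step in range(3):
    opt.clear_grad()
    with paddle.amp.auto_cast(enable=True, level='O1'):
        out = gen(lq)
        loss = F.l1_loss(out, gt)
    scaled = scalers[0].scale(loss)
    scaled.backward()
    scalers[0].minimize(opt, scaled)
    print(step, float(loss))
```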
......@@ -51,7 +51,7 @@ null:null
 null:null
 ===========================train_benchmark_params==========================
 batch_size:64
-fp_items:fp32
+fp_items:fp32|fp16
 total_iters:100
 --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
 flags:FLAGS_cudnn_exhaustive_search=1
......@@ -51,7 +51,7 @@ null:null
 null:null
 ===========================train_benchmark_params==========================
 batch_size:32|64
-fp_items:fp32
+fp_items:fp32|fp16
 total_iters:500
 --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
 flags:FLAGS_cudnn_exhaustive_search=1
......@@ -51,7 +51,7 @@ null:null
 null:null
 ===========================train_benchmark_params==========================
 batch_size:2|4
-fp_items:fp32
+fp_items:fp32|fp16
 total_iters:60
 --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
 flags:FLAGS_cudnn_exhaustive_search=1
......
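
In these TIPC benchmark configs, `fp_items` lists the precisions the benchmark sweeps, so `fp32|fp16` runs each model once per precision, crossed with every `batch_size`. A hypothetical sketch of that expansion (the real test_tipc runners are shell scripts; names here are illustrative only):

```python
# Illustrative only: expand pipe-separated TIPC fields into benchmark runs.
fp_items = "fp32|fp16".split("|")
batch_sizes = "2|4".split("|")

for fp in fp_items:
    for bs in batch_sizes:
        # fp16 runs exercise the AMP code path added in this commit.
        amp = (fp == "fp16")
        print(f"benchmark run: batch_size={bs}, precision={fp}, amp={amp}")
```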
......@@ -197,5 +197,6 @@ elif [ ${MODE} = "cpp_infer" ]; then
         rm -rf ./inference/msvsr*
         wget -nc -P ./inference https://paddlegan.bj.bcebos.com/static_model/msvsr.tar --no-check-certificate
         cd ./inference && tar xf msvsr.tar && cd ../
+        wget -nc -P ./data https://paddlegan.bj.bcebos.com/datasets/low_res.mp4 --no-check-certificate
     fi
 fi