未验证 提交 01cb542f 编写于 作者: W wangna11BD 提交者: GitHub

Support @to_static traing for edvr pix2pix and esrgan (#750)

上级 461bc8cd
......@@ -24,6 +24,8 @@ model:
w_TSA: False
pixel_criterion:
name: CharbonnierLoss
# training model under @to_static
to_static: False
export_model:
- {name: 'generator', inputs_num: 1}
......
......@@ -14,6 +14,8 @@ model:
nb: 23
pixel_criterion:
name: L1Loss
# training model under @to_static
to_static: False
export_model:
- {name: 'generator', inputs_num: 1}
......
......@@ -24,6 +24,8 @@ model:
gan_criterion:
name: GANLoss
gan_mode: vanilla
# training model under @to_static
to_static: False
dataset:
train:
......
......@@ -15,6 +15,7 @@
import paddle
import paddle.nn as nn
from .base_model import apply_to_static
from .builder import MODELS
from .sr_model import BaseSRModel
from .generators.edvr import ResidualBlockNoBN, DCNPack
......@@ -28,7 +29,8 @@ class EDVRModel(BaseSRModel):
Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks.
"""
def __init__(self, generator, tsa_iter, pixel_criterion=None):
def __init__(self, generator, tsa_iter, pixel_criterion=None, to_static=False,
image_shape=None):
"""Initialize the EDVR class.
Args:
......@@ -36,7 +38,9 @@ class EDVRModel(BaseSRModel):
tsa_iter (dict): config of tsa_iter.
pixel_criterion (dict): config of pixel criterion.
"""
super(EDVRModel, self).__init__(generator, pixel_criterion)
super(EDVRModel, self).__init__(generator, pixel_criterion,
to_static=to_static,
image_shape=image_shape)
self.tsa_iter = tsa_iter
self.current_iter = 1
init_edvr_weight(self.nets['generator'])
......
......@@ -13,7 +13,7 @@
# limitations under the License.
import paddle
from .base_model import BaseModel
from .base_model import BaseModel, apply_to_static
from .builder import MODELS
from .generators.builder import build_generator
......@@ -36,7 +36,9 @@ class Pix2PixModel(BaseModel):
discriminator=None,
pixel_criterion=None,
gan_criterion=None,
direction='a2b'):
direction='a2b',
to_static=False,
image_shape=None):
"""Initialize the pix2pix class.
Args:
......@@ -51,11 +53,15 @@ class Pix2PixModel(BaseModel):
# define networks (both generator and discriminator)
self.nets['netG'] = build_generator(generator)
init_weights(self.nets['netG'])
# set @to_static for benchmark, skip this by default.
apply_to_static(to_static, image_shape, self.nets['netG'])
# define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
if discriminator:
self.nets['netD'] = build_discriminator(discriminator)
init_weights(self.nets['netD'])
# set @to_static for benchmark, skip this by default.
apply_to_static(to_static, image_shape, self.nets['netD'])
if pixel_criterion:
self.pixel_criterion = build_criterion(pixel_criterion)
......
......@@ -17,7 +17,7 @@ import paddle.nn as nn
from .generators.builder import build_generator
from .criterions.builder import build_criterion
from .base_model import BaseModel
from .base_model import BaseModel, apply_to_static
from .builder import MODELS
from ..utils.visual import tensor2img
from ..modules.init import reset_parameters
......@@ -28,7 +28,8 @@ class BaseSRModel(BaseModel):
"""Base SR model for single image super-resolution.
"""
def __init__(self, generator, pixel_criterion=None, use_init_weight=False):
def __init__(self, generator, pixel_criterion=None, use_init_weight=False, to_static=False,
image_shape=None):
"""
Args:
generator (dict): config of generator.
......@@ -37,6 +38,8 @@ class BaseSRModel(BaseModel):
super(BaseSRModel, self).__init__()
self.nets['generator'] = build_generator(generator)
# set @to_static for benchmark, skip this by default.
apply_to_static(to_static, image_shape, self.nets['generator'])
if pixel_criterion:
self.pixel_criterion = build_criterion(pixel_criterion)
......
......@@ -17,7 +17,7 @@ norm_train:tools/main.py -c configs/pix2pix_facades.yaml --seed 123 -o log_confi
pact_train:null
fpgm_train:null
distill_train:null
null:null
to_static_train:model.to_static=True
null:null
##
===========================eval_params===========================
......
......@@ -17,7 +17,7 @@ norm_train:tools/main.py -c configs/edvr_m_wo_tsa.yaml --seed 123 -o log_config.
pact_train:null
fpgm_train:null
distill_train:null
null:null
to_static_train:model.to_static=True
null:null
##
===========================eval_params===========================
......
......@@ -17,7 +17,7 @@ norm_train:tools/main.py -c configs/esrgan_psnr_x4_div2k.yaml --seed 123 -o log_
pact_train:null
fpgm_train:null
distill_train:null
null:null
to_static_train:model.to_static=True
null:null
##
===========================eval_params===========================
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册