提交 345b30e2 编写于 作者: Q qingqing01

Update CycleGAN

上级 df33864a
...@@ -18,9 +18,8 @@ from __future__ import print_function ...@@ -18,9 +18,8 @@ from __future__ import print_function
import numpy as np import numpy as np
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.incubate.hapi.model import Model
from paddle.incubate.hapi.loss import Loss
from layers import ConvBN, DeConvBN from layers import ConvBN, DeConvBN
...@@ -133,7 +132,7 @@ class NLayerDiscriminator(fluid.dygraph.Layer): ...@@ -133,7 +132,7 @@ class NLayerDiscriminator(fluid.dygraph.Layer):
return y return y
class Generator(Model): class Generator(paddle.nn.Layer):
def __init__(self, input_channel=3): def __init__(self, input_channel=3):
super(Generator, self).__init__() super(Generator, self).__init__()
self.g = ResnetGenerator(input_channel) self.g = ResnetGenerator(input_channel)
...@@ -143,7 +142,7 @@ class Generator(Model): ...@@ -143,7 +142,7 @@ class Generator(Model):
return fake return fake
class GeneratorCombine(Model): class GeneratorCombine(paddle.nn.Layer):
def __init__(self, g_AB=None, g_BA=None, d_A=None, d_B=None, def __init__(self, g_AB=None, g_BA=None, d_A=None, d_B=None,
is_train=True): is_train=True):
super(GeneratorCombine, self).__init__() super(GeneratorCombine, self).__init__()
...@@ -177,16 +176,15 @@ class GeneratorCombine(Model): ...@@ -177,16 +176,15 @@ class GeneratorCombine(Model):
return input_A, input_B, fake_A, fake_B, cyc_A, cyc_B, idt_A, idt_B, valid_A, valid_B return input_A, input_B, fake_A, fake_B, cyc_A, cyc_B, idt_A, idt_B, valid_A, valid_B
class GLoss(Loss): class GLoss(paddle.nn.Layer):
def __init__(self, lambda_A=10., lambda_B=10., lambda_identity=0.5): def __init__(self, lambda_A=10., lambda_B=10., lambda_identity=0.5):
super(GLoss, self).__init__() super(GLoss, self).__init__()
self.lambda_A = lambda_A self.lambda_A = lambda_A
self.lambda_B = lambda_B self.lambda_B = lambda_B
self.lambda_identity = lambda_identity self.lambda_identity = lambda_identity
def forward(self, outputs, labels=None): def forward(self, input_A, input_B, fake_A, fake_B, cyc_A, cyc_B, idt_A,
input_A, input_B, fake_A, fake_B, cyc_A, cyc_B, idt_A, idt_B, valid_A, valid_B = outputs idt_B, valid_A, valid_B):
def mse(a, b): def mse(a, b):
return fluid.layers.reduce_mean(fluid.layers.square(a - b)) return fluid.layers.reduce_mean(fluid.layers.square(a - b))
...@@ -211,7 +209,7 @@ class GLoss(Loss): ...@@ -211,7 +209,7 @@ class GLoss(Loss):
return loss return loss
class Discriminator(Model): class Discriminator(paddle.nn.Layer):
def __init__(self, input_channel=3): def __init__(self, input_channel=3):
super(Discriminator, self).__init__() super(Discriminator, self).__init__()
self.d = NLayerDiscriminator(input_channel) self.d = NLayerDiscriminator(input_channel)
...@@ -222,13 +220,11 @@ class Discriminator(Model): ...@@ -222,13 +220,11 @@ class Discriminator(Model):
return pred_real, pred_fake return pred_real, pred_fake
class DLoss(Loss): class DLoss(paddle.nn.Layer):
def __init__(self): def __init__(self):
super(DLoss, self).__init__() super(DLoss, self).__init__()
def forward(self, inputs, labels=None): def forward(self, real, fake):
pred_real, pred_fake = inputs loss = fluid.layers.square(fake) + fluid.layers.square(real - 1.)
loss = fluid.layers.square(pred_fake) + fluid.layers.square(pred_real -
1.)
loss = fluid.layers.reduce_mean(loss / 2.0) loss = fluid.layers.reduce_mean(loss / 2.0)
return loss return loss
...@@ -24,26 +24,32 @@ import argparse ...@@ -24,26 +24,32 @@ import argparse
from PIL import Image from PIL import Image
from scipy.misc import imsave from scipy.misc import imsave
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.incubate.hapi.model import Model, Input, set_device from paddle.static import InputSpec as Input
from check import check_gpu, check_version from check import check_gpu, check_version
from cyclegan import Generator, GeneratorCombine from cyclegan import Generator, GeneratorCombine
def main(): def main():
place = set_device(FLAGS.device) place = paddle.set_device(FLAGS.device)
fluid.enable_dygraph(place) if FLAGS.dynamic else None fluid.enable_dygraph(place) if FLAGS.dynamic else None
im_shape = [-1, 3, 256, 256]
input_A = Input(im_shape, 'float32', 'input_A')
input_B = Input(im_shape, 'float32', 'input_B')
# Generators # Generators
g_AB = Generator() g_AB = Generator()
g_BA = Generator() g_BA = Generator()
g = GeneratorCombine(g_AB, g_BA, is_train=False)
im_shape = [-1, 3, 256, 256] g = paddle.Model(
input_A = Input(im_shape, 'float32', 'input_A') GeneratorCombine(
input_B = Input(im_shape, 'float32', 'input_B') g_AB, g_BA, is_train=False),
g.prepare(inputs=[input_A, input_B], device=FLAGS.device) inputs=[input_A, input_B])
g.prepare()
g.load(FLAGS.init_model, skip_mismatch=True, reset_optimizer=True) g.load(FLAGS.init_model, skip_mismatch=True, reset_optimizer=True)
out_path = FLAGS.output + "/single" out_path = FLAGS.output + "/single"
......
...@@ -21,8 +21,9 @@ import argparse ...@@ -21,8 +21,9 @@ import argparse
import numpy as np import numpy as np
from scipy.misc import imsave from scipy.misc import imsave
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.incubate.hapi.model import Model, Input, set_device from paddle.static import InputSpec as Input
from check import check_gpu, check_version from check import check_gpu, check_version
from cyclegan import Generator, GeneratorCombine from cyclegan import Generator, GeneratorCombine
...@@ -30,18 +31,22 @@ import data as data ...@@ -30,18 +31,22 @@ import data as data
def main(): def main():
place = set_device(FLAGS.device) place = paddle.set_device(FLAGS.device)
fluid.enable_dygraph(place) if FLAGS.dynamic else None fluid.enable_dygraph(place) if FLAGS.dynamic else None
im_shape = [-1, 3, 256, 256]
input_A = Input(im_shape, 'float32', 'input_A')
input_B = Input(im_shape, 'float32', 'input_B')
# Generators # Generators
g_AB = Generator() g_AB = Generator()
g_BA = Generator() g_BA = Generator()
g = GeneratorCombine(g_AB, g_BA, is_train=False) g = paddle.Model(
GeneratorCombine(
g_AB, g_BA, is_train=False),
inputs=[input_A, input_B])
im_shape = [-1, 3, 256, 256] g.prepare()
input_A = Input(im_shape, 'float32', 'input_A')
input_B = Input(im_shape, 'float32', 'input_B')
g.prepare(inputs=[input_A, input_B], device=FLAGS.device)
g.load(FLAGS.init_model, skip_mismatch=True, reset_optimizer=True) g.load(FLAGS.init_model, skip_mismatch=True, reset_optimizer=True)
if not os.path.exists(FLAGS.output): if not os.path.exists(FLAGS.output):
......
...@@ -24,7 +24,7 @@ import time ...@@ -24,7 +24,7 @@ import time
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.incubate.hapi.model import Model, Input, set_device from paddle.static import InputSpec as Input
from check import check_gpu, check_version from check import check_gpu, check_version
from cyclegan import Generator, Discriminator, GeneratorCombine, GLoss, DLoss from cyclegan import Generator, Discriminator, GeneratorCombine, GLoss, DLoss
...@@ -48,18 +48,29 @@ def opt(parameters): ...@@ -48,18 +48,29 @@ def opt(parameters):
def main(): def main():
place = set_device(FLAGS.device) place = paddle.set_device(FLAGS.device)
fluid.enable_dygraph(place) if FLAGS.dynamic else None fluid.enable_dygraph(place) if FLAGS.dynamic else None
im_shape = [None, 3, 256, 256]
input_A = Input(im_shape, 'float32', 'input_A')
input_B = Input(im_shape, 'float32', 'input_B')
fake_A = Input(im_shape, 'float32', 'fake_A')
fake_B = Input(im_shape, 'float32', 'fake_B')
# Generators # Generators
g_AB = Generator() g_AB = Generator()
g_BA = Generator() g_BA = Generator()
# Discriminators
d_A = Discriminator() d_A = Discriminator()
d_B = Discriminator() d_B = Discriminator()
g = GeneratorCombine(g_AB, g_BA, d_A, d_B) g = paddle.Model(
GeneratorCombine(g_AB, g_BA, d_A, d_B), inputs=[input_A, input_B])
g_AB = paddle.Model(g_AB, [input_A])
g_BA = paddle.Model(g_BA, [input_B])
# Discriminators
d_A = paddle.Model(d_A, [input_B, fake_B])
d_B = paddle.Model(d_B, [input_A, fake_A])
da_params = d_A.parameters() da_params = d_A.parameters()
db_params = d_B.parameters() db_params = d_B.parameters()
...@@ -69,21 +80,12 @@ def main(): ...@@ -69,21 +80,12 @@ def main():
db_optimizer = opt(db_params) db_optimizer = opt(db_params)
g_optimizer = opt(g_params) g_optimizer = opt(g_params)
im_shape = [None, 3, 256, 256] g_AB.prepare()
input_A = Input(im_shape, 'float32', 'input_A') g_BA.prepare()
input_B = Input(im_shape, 'float32', 'input_B')
fake_A = Input(im_shape, 'float32', 'fake_A')
fake_B = Input(im_shape, 'float32', 'fake_B')
g_AB.prepare(inputs=[input_A], device=FLAGS.device)
g_BA.prepare(inputs=[input_B], device=FLAGS.device)
g.prepare( g.prepare(g_optimizer, GLoss())
g_optimizer, GLoss(), inputs=[input_A, input_B], device=FLAGS.device) d_A.prepare(da_optimizer, DLoss())
d_A.prepare( d_B.prepare(db_optimizer, DLoss())
da_optimizer, DLoss(), inputs=[input_B, fake_B], device=FLAGS.device)
d_B.prepare(
db_optimizer, DLoss(), inputs=[input_A, fake_A], device=FLAGS.device)
if FLAGS.resume: if FLAGS.resume:
g.load(FLAGS.resume) g.load(FLAGS.resume)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册