Unverified commit 130bd7f9, authored by LielinJiang, committed by GitHub

Fix 2.0-beta bugs (#183)

* fix 2.0-beta bugs

* update pretrained path

* add extract_weight.py
Parent 60066eb2
@@ -10,6 +10,7 @@ model:
   gan_criterion:
     name: GANLoss
     gan_mode: lsgan
+  # use your trained path
   pretrain_ckpt: output_dir/AnimeGANV2PreTrainModel-2020-11-29-17-02/epoch_2_checkpoint.pdparams
   g_adv_weight: 300.
   d_adv_weight: 300.
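The new comment flags that `pretrain_ckpt` must point at your own checkpoint. For orientation, restoring such a checkpoint in PaddlePaddle looks roughly like the sketch below; the `netG` key layout is an assumption, so inspect your checkpoint's keys first.

```python
import paddle


def load_pretrain_ckpt(model, ckpt_path, net_key='netG'):
    """Sketch: restore one sub-network from a training checkpoint.

    `net_key` is an assumption -- print the checkpoint's keys to find
    the entry you actually need.
    """
    state_dict = paddle.load(ckpt_path)
    # Training checkpoints usually nest one state dict per sub-network.
    if net_key in state_dict:
        state_dict = state_dict[net_key]
    model.set_state_dict(state_dict)
```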
@@ -47,21 +48,21 @@ dataset:
   test:
     name: SingleDataset
     dataroot: data/animedataset/test/HR_photo
-    max_dataset_size: inf
-    direction: BtoA
-    input_nc: 3
-    output_nc: 3
-    serial_batches: False
-    pool_size: 50
-    transforms:
-      - name: ResizeToScale
-        size: [256, 256]
-        scale: 32
-        interpolation: bilinear
-      - name: Transpose
-      - name: Normalize
-        mean: [127.5, 127.5, 127.5]
-        std: [127.5, 127.5, 127.5]
+    preprocess:
+      - name: LoadImageFromFile
+        key: A
+      - name: Transforms
+        input_keys: [A]
+        pipeline:
+          - name: ResizeToScale
+            size: [256, 256]
+            scale: 32
+            interpolation: bilinear
+          - name: Transpose
+          - name: Normalize
+            mean: [127.5, 127.5, 127.5]
+            std: [127.5, 127.5, 127.5]
+            keys: [image, image]
 
 lr_scheduler:
   name: LinearDecay
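The test pipeline moves from flat `transforms` to the new `preprocess` operator list: `LoadImageFromFile` turns the path stored under key `A` into an image, and `Transforms` applies the nested pipeline to the keys in `input_keys`. A loose sketch of that flow, with simplified stand-ins for the registry classes in ppgan (the `A_path` bookkeeping is suggested by the model changes later in this diff):

```python
import cv2


def load_image_from_file(sample, key='A'):
    # 'LoadImageFromFile' reads the path stored under `key` and replaces
    # it with the decoded image array, keeping the path under 'A_path'.
    sample[key + '_path'] = sample[key]
    sample[key] = cv2.imread(sample[key])
    return sample


def apply_transforms(sample, pipeline, input_keys=('A',)):
    # 'Transforms' runs the nested pipeline over the selected keys in order.
    for key in input_keys:
        for op in pipeline:
            sample[key] = op(sample[key])
    return sample
```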
@@ -21,44 +21,33 @@ model:
 dataset:
   train:
-    name: SingleDataset
-    dataroot: data/mnist/train
+    name: CommonVisionDataset
+    dataset_name: MNIST
+    num_workers: 0
     batch_size: 128
-    preprocess:
-      - name: LoadImageFromFile
-        key: A
-      - name: Transfroms
-        input_keys: [A]
-        pipeline:
-          - name: Resize
-            size: [64, 64]
-            interpolation: 'bicubic' #cv2.INTER_CUBIC
-            keys: [image, image]
-          - name: Transpose
-            keys: [image, image]
-          - name: Normalize
-            mean: [127.5, 127.5, 127.5]
-            std: [127.5, 127.5, 127.5]
-            keys: [image, image]
+    return_label: False
+    transforms:
+      - name: Resize
+        size: [64, 64]
+        interpolation: 'bicubic' #cv2.INTER_CUBIC
+      - name: Normalize
+        mean: [127.5]
+        std: [127.5]
+        keys: [image]
 
   test:
-    name: SingleDataset
-    dataroot: data/mnist/test
-    preprocess:
-      - name: LoadImageFromFile
-        key: A
-      - name: Transforms
-        input_keys: [A]
-        pipeline:
-          - name: Resize
-            size: [64, 64]
-            interpolation: 'bicubic' #cv2.INTER_CUBIC
-            keys: [image, image]
-          - name: Transpose
-            keys: [image, image]
-          - name: Normalize
-            mean: [127.5, 127.5, 127.5]
-            std: [127.5, 127.5, 127.5]
-            keys: [image, image]
+    name: CommonVisionDataset
+    dataset_name: MNIST
+    num_workers: 0
+    batch_size: 128
+    return_label: False
+    transforms:
+      - name: Resize
+        size: [64, 64]
+        interpolation: 'bicubic' #cv2.INTER_CUBIC
+      - name: Normalize
+        mean: [127.5]
+        std: [127.5]
+        keys: [image]
 
 lr_scheduler:
   name: LinearDecay
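The DCGAN configs switch from `SingleDataset` over raw image folders to `CommonVisionDataset`, which wraps a `paddle.vision.datasets` class selected by `dataset_name`. A hedged sketch of the idea; the real class lives in ppgan's datasets, and the constructor details here are assumptions:

```python
import paddle
from paddle.vision import datasets as vision_datasets


class CommonVisionDatasetSketch(paddle.io.Dataset):
    """Hedged sketch: wrap a paddle.vision dataset chosen by name."""

    def __init__(self, dataset_name='MNIST', mode='train',
                 transforms=None, return_label=False):
        super().__init__()
        dataset_cls = getattr(vision_datasets, dataset_name)
        self.dataset = dataset_cls(mode=mode, transform=transforms)
        self.return_label = return_label

    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        # With return_label: False the model only sees {'img': ...},
        # matching DCGANModel.setup_input further down in this diff.
        if self.return_label:
            return {'img': img, 'class_id': label}
        return {'img': img}

    def __len__(self):
        return len(self.dataset)
```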
@@ -64,7 +64,7 @@ dataset:
     preprocess:
       - name: LoadImageFromFile
         key: pair
-      - name: Transforms
+      - name: Transforms
         input_keys: [A, B]
         pipeline:
           - name: Resize
@@ -92,6 +92,21 @@ train model
 python tools/main.py -c configs/stylegan_v2_256_ffhq.yaml
 ```
 
+### Inference
+
+When training finishes, use ``tools/extract_weight.py`` to extract the corresponding weights.
+```
+python tools/extract_weight.py output_dir/YOUR_TRAINED_WEIGHT.pdparams --net-name gen_ema --output YOUR_WEIGHT_PATH.pdparams
+```
+
+Then use ``applications/tools/styleganv2.py`` to generate results:
+```
+python tools/styleganv2.py --output_path stylegan01 --weight_path YOUR_WEIGHT_PATH.pdparams --size 256
+```
+
+Note: ``--size`` must match the value in your config file.
+
 ## Results
 
 Random Samples:
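Before extracting, it can help to check which sub-networks your checkpoint actually contains, since `--net-name` must match a key. A quick hedged snippet; key names such as `gen` and `gen_ema` depend on the model config:

```python
import paddle

ckpt = paddle.load('output_dir/YOUR_TRAINED_WEIGHT.pdparams')
# A training checkpoint typically maps sub-network names (e.g. 'gen',
# 'gen_ema', 'disc') and optimizer states to their weights; 'gen_ema',
# the moving average of the generator, is the usual choice for inference.
print(list(ckpt.keys()))
```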
@@ -54,9 +54,56 @@ python -u tools/styleganv2.py \
 - n_col: number of columns in the sampled image grid
 - cpu: whether to run inference on CPU; remove this flag to use the GPU
 
-### Training (TODO)
-
-Training scripts will be added later so that users can train more kinds of StyleGAN V2 image generators.
+### Training
+
+#### Prepare the dataset
+You can download the corresponding dataset from [here](https://drive.google.com/drive/folders/1u2xu7bSrWxrbUxk-dT-UvEJq8IjdmNTP).
+For convenience, we also provide [images256x256.tar](https://paddlegan.bj.bcebos.com/datasets/images256x256.tar).
+The current config file assumes the following dataset layout:
+```
+PaddleGAN
+├── data
+    ├── ffhq
+        ├──images1024x1024
+            ├── 00000.png
+            ├── 00001.png
+            ├── 00002.png
+            ├── 00003.png
+            ├── 00004.png
+        ├──images256x256
+            ├── 00000.png
+            ├── 00001.png
+            ├── 00002.png
+            ├── 00003.png
+            ├── 00004.png
+        ├──custom_data
+            ├── img0.png
+            ├── img1.png
+            ├── img2.png
+            ├── img3.png
+            ├── img4.png
+            ...
+```
+
+Start training:
+```
+python tools/main.py -c configs/stylegan_v2_256_ffhq.yaml
+```
+
+### Inference
+
+After training, use ``tools/extract_weight.py`` to extract the corresponding weights for ``applications/tools/styleganv2.py`` to run inference.
+```
+python tools/extract_weight.py output_dir/YOUR_TRAINED_WEIGHT.pdparams --net-name gen_ema --output stylegan_config_f.pdparams
+```
+```
+python tools/styleganv2.py --output_path stylegan01 --weight_path YOUR_WEIGHT_PATH.pdparams --size 256
+```
+
+Note: ``--size`` must match the value in your config file.
 
 ## Results
@@ -20,7 +20,7 @@ from .base_dataset import BaseDataset
 from .image_folder import ImageFolder
 
 from .builder import DATASETS
-from .transforms.builder import build_transforms
+from .preprocess.builder import build_transforms
 
 
 @DATASETS.register()
@@ -17,7 +17,7 @@ import paddle
 
 from .builder import DATASETS
 from .base_dataset import BaseDataset
-from .transforms.builder import build_transforms
+from .preprocess.builder import build_transforms
 
 
 @DATASETS.register()
@@ -62,3 +62,15 @@ def build_preprocess(cfg):
     preproccess = Compose(preproccess)
 
     return preproccess
+
+
+def build_transforms(cfg):
+    transforms = []
+    for trans_cfg in cfg:
+        temp_trans_cfg = copy.deepcopy(trans_cfg)
+        name = temp_trans_cfg.pop('name')
+        transforms.append(TRANSFORMS.get(name)(**temp_trans_cfg))
+
+    transforms = Compose(transforms)
+
+    return transforms
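The new `build_transforms` mirrors `build_preprocess`: each config entry's `name` selects a registered class and the remaining fields become constructor kwargs. A self-contained sketch of that lookup, using a stand-in dict instead of ppgan's `TRANSFORMS` registry:

```python
import copy

# Minimal stand-in registry to illustrate the lookup; the real one is
# ppgan's TRANSFORMS Registry of transform classes.
REGISTRY = {'Resize': lambda **kw: ('Resize', kw),
            'Normalize': lambda **kw: ('Normalize', kw)}


def build_transforms_sketch(cfg):
    transforms = []
    for trans_cfg in cfg:
        cfg_copy = copy.deepcopy(trans_cfg)
        name = cfg_copy.pop('name')                     # 'name' selects the class
        transforms.append(REGISTRY[name](**cfg_copy))   # kwargs configure it
    return transforms


print(build_transforms_sketch([
    {'name': 'Resize', 'size': [64, 64], 'interpolation': 'bicubic'},
    {'name': 'Normalize', 'mean': [127.5], 'std': [127.5]},
]))
```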
@@ -264,3 +264,74 @@ class SRNoise(T.BaseTransform):
         image = image + normed_noise
         image = np.clip(image, 0., 1.)
         return image
+
+
+@TRANSFORMS.register()
+class Add(T.BaseTransform):
+    def __init__(self, value, keys=None):
+        """Initialize Add Transform
+
+        Parameters:
+            value (List[int]) -- the [r, g, b] value to add to the image, pixel-wise
+        """
+        super().__init__(keys=keys)
+        self.value = value
+
+    def _get_params(self, inputs):
+        params = {}
+        params['value'] = self.value
+        return params
+
+    def _apply_image(self, image):
+        return np.clip(image + self.params['value'], 0, 255).astype('uint8')
+        # return custom_F.add(image, self.params['value'])
+
+
+@TRANSFORMS.register()
+class ResizeToScale(T.BaseTransform):
+    def __init__(self,
+                 size: int,
+                 scale: int,
+                 interpolation='bilinear',
+                 keys=None):
+        """Initialize ResizeToScale Transform
+
+        Parameters:
+            size (List[int]) -- the minimum target size
+            scale (List[int]) -- the stride scale
+            interpolation (Optional[str]) -- interpolation method
+        """
+        super().__init__(keys=keys)
+        if isinstance(size, int):
+            self.size = (size, size)
+        else:
+            self.size = size
+        self.scale = scale
+        self.interpolation = interpolation
+
+    def _get_params(self, inputs):
+        image = inputs[self.keys.index('image')]
+        hw = image.shape[:2]
+        params = {}
+        params['target_size'] = self.reduce_to_scale(hw, self.size[::-1],
+                                                     self.scale)
+        return params
+
+    @staticmethod
+    def reduce_to_scale(img_hw, min_hw, scale):
+        im_h, im_w = img_hw
+        if im_h <= min_hw[0]:
+            im_h = min_hw[0]
+        else:
+            x = im_h % scale
+            im_h = im_h - x
+        if im_w < min_hw[1]:
+            im_w = min_hw[1]
+        else:
+            y = im_w % scale
+            im_w = im_w - y
+        return (im_h, im_w)
+
+    def _apply_image(self, image):
+        return F.resize(image, self.params['target_size'], self.interpolation)
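`reduce_to_scale` enforces a minimum size while rounding each side down to a multiple of `scale`, so feature maps stay aligned with the generator's stride. Two worked examples using the AnimeGAN config above (`size: [256, 256]`, `scale: 32`), with the logic inlined for a runnable check:

```python
def reduce_to_scale(img_hw, min_hw, scale):
    # Same logic as ResizeToScale.reduce_to_scale above.
    im_h, im_w = img_hw
    im_h = min_hw[0] if im_h <= min_hw[0] else im_h - im_h % scale
    im_w = min_hw[1] if im_w < min_hw[1] else im_w - im_w % scale
    return (im_h, im_w)


# 300x500 is above the minimum: each side rounds down to a multiple of 32.
assert reduce_to_scale((300, 500), (256, 256), 32) == (288, 480)
# 200x240 is below the minimum on both sides, so it is clamped to 256x256.
assert reduce_to_scale((200, 240), (256, 256), 32) == (256, 256)
```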
from .transforms import ResizeToScale, PairedRandomCrop, PairedRandomHorizontalFlip, Add
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division

from . import functional_cv2 as F_cv2
from paddle.vision.transforms.functional import _is_numpy_image, _is_pil_image

__all__ = ['add']


def add(pic, value):
    if not (_is_pil_image(pic) or _is_numpy_image(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(
            type(pic)))

    if _is_pil_image(pic):
        raise NotImplementedError('add does not support PIL Image')
    else:
        return F_cv2.add(pic, value)
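`add` follows paddle.vision's functional dispatch pattern: PIL images are rejected, numpy arrays go to the cv2 backend. A quick usage check of the clipping behaviour, inlined so it runs standalone:

```python
import numpy as np

# Pixel-wise shift with saturation at the uint8 bounds; the cv2 backend
# in functional_cv2.py below implements exactly this.
img = np.full((2, 2, 3), 250, dtype='uint8')
out = np.clip(img + np.array([10, 0, -30]), 0, 255).astype('uint8')
print(out[0, 0])  # [255 250 220]
```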
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import division

import numpy as np


def add(image, value):
    return np.clip(image + value, 0, 255).astype('uint8')
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import random
import numbers
import collections

import numpy as np

import paddle.vision.transforms as T
import paddle.vision.transforms.functional as F

from . import functional as custom_F
from .builder import TRANSFORMS

if sys.version_info < (3, 3):
    Sequence = collections.Sequence
    Iterable = collections.Iterable
else:
    Sequence = collections.abc.Sequence
    Iterable = collections.abc.Iterable

TRANSFORMS.register(T.Resize)
TRANSFORMS.register(T.RandomCrop)
TRANSFORMS.register(T.RandomHorizontalFlip)
TRANSFORMS.register(T.Normalize)
TRANSFORMS.register(T.Transpose)
TRANSFORMS.register(T.Grayscale)


@TRANSFORMS.register()
class PairedRandomCrop(T.RandomCrop):
    def __init__(self, size, keys=None):
        super().__init__(size, keys=keys)

        if isinstance(size, int):
            self.size = (size, size)
        else:
            self.size = size

    def _get_params(self, inputs):
        image = inputs[self.keys.index('image')]
        params = {}
        params['crop_params'] = self._get_param(image, self.size)
        return params

    def _apply_image(self, img):
        i, j, h, w = self.params['crop_params']
        return F.crop(img, i, j, h, w)


@TRANSFORMS.register()
class PairedRandomHorizontalFlip(T.RandomHorizontalFlip):
    def __init__(self, prob=0.5, keys=None):
        super().__init__(prob, keys=keys)

    def _get_params(self, inputs):
        params = {}
        params['flip'] = random.random() < self.prob
        return params

    def _apply_image(self, image):
        if self.params['flip']:
            return F.hflip(image)
        return image


@TRANSFORMS.register()
class Add(T.BaseTransform):
    def __init__(self, value, keys=None):
        """Initialize Add Transform

        Parameters:
            value (List[int]) -- the [r, g, b] value to add to the image, pixel-wise
        """
        super().__init__(keys=keys)
        self.value = value

    def _get_params(self, inputs):
        params = {}
        params['value'] = self.value
        return params

    def _apply_image(self, image):
        return custom_F.add(image, self.params['value'])


@TRANSFORMS.register()
class ResizeToScale(T.BaseTransform):
    def __init__(self,
                 size: int,
                 scale: int,
                 interpolation='bilinear',
                 keys=None):
        """Initialize ResizeToScale Transform

        Parameters:
            size (List[int]) -- the minimum target size
            scale (List[int]) -- the stride scale
            interpolation (Optional[str]) -- interpolation method
        """
        super().__init__(keys=keys)
        if isinstance(size, int):
            self.size = (size, size)
        else:
            self.size = size
        self.scale = scale
        self.interpolation = interpolation

    def _get_params(self, inputs):
        image = inputs[self.keys.index('image')]
        hw = image.shape[:2]
        params = {}
        params['target_size'] = self.reduce_to_scale(hw, self.size[::-1],
                                                     self.scale)
        return params

    @staticmethod
    def reduce_to_scale(img_hw, min_hw, scale):
        im_h, im_w = img_hw
        if im_h <= min_hw[0]:
            im_h = min_hw[0]
        else:
            x = im_h % scale
            im_h = im_h - x
        if im_w < min_hw[1]:
            im_w = min_hw[1]
        else:
            y = im_w % scale
            im_w = im_w - y
        return (im_h, im_w)

    def _apply_image(self, image):
        return F.resize(image, self.params['target_size'], self.interpolation)
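The "paired" transforms work because `_get_params` runs once per call, so every key in `keys` shares the same random crop box or flip decision, keeping (input, target) pairs aligned. A hedged usage sketch, relying on paddle.vision's `BaseTransform` call protocol:

```python
import numpy as np

# One call transforms every key with the SAME random parameters, so an
# (input, target) pair stays spatially aligned.
crop = PairedRandomCrop(size=128, keys=['image', 'image'])
lq = np.random.randint(0, 256, (256, 256, 3), dtype='uint8')
gt = np.random.randint(0, 256, (256, 256, 3), dtype='uint8')
lq_patch, gt_patch = crop((lq, gt))  # identical crop window for both
print(lq_patch.shape, gt_patch.shape)  # (128, 128, 3) (128, 128, 3)
```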
@@ -165,6 +165,8 @@ class Trainer:
         iter_loader = IterLoader(self.train_dataloader)
 
+        # set model.is_train = True
+        self.model.setup_train_mode(is_train=True)
         while self.current_iter < (self.total_iters + 1):
             self.current_epoch = iter_loader.epoch
             self.inner_iter = self.current_iter % self.iters_per_epoch
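The trainer now toggles the model's mode explicitly before the train loop and again before evaluation (next hunk). The `BaseModel` side is presumably a one-liner along these lines; a hedged sketch, with only the `is_train` flag name suggested by the comments in the diff:

```python
class BaseModelSketch:
    def setup_train_mode(self, is_train):
        # Record whether the upcoming setup_input/forward calls run in
        # training mode; e.g. AnimeGANV2Model reads extra input keys
        # (smooth_gray, etc.) only when training.
        self.is_train = is_train
```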
@@ -219,6 +221,9 @@ class Trainer:
         for metric in self.metrics.values():
             metric.reset()
 
+        # set model.is_train = False
+        self.model.setup_train_mode(is_train=False)
+
         for i in range(self.max_eval_steps):
             data = next(iter_loader)
             self.model.setup_input(data)
@@ -289,7 +294,9 @@ class Trainer:
         message += 'ips: %.5f images/s ' % self.ips
 
         if hasattr(self, 'step_time'):
-            eta = self.step_time * (self.total_iters - self.current_iter - 1)
+            eta = self.step_time * (self.total_iters - self.current_iter)
+            eta = eta if eta > 0 else 0
             eta_str = str(datetime.timedelta(seconds=int(eta)))
             message += f'eta: {eta_str}'
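The old ETA formula dropped one remaining step and went negative on the final iteration; the fix counts `total_iters - current_iter` steps and clamps at zero. A small check:

```python
step_time, total_iters = 0.5, 1000

for current_iter in (990, 1000):
    old_eta = step_time * (total_iters - current_iter - 1)        # could go negative
    new_eta = max(step_time * (total_iters - current_iter), 0)    # clamped at zero
    print(current_iter, old_eta, new_eta)
# 990  4.5  5.0   -- old value lagged one step behind
# 1000 -0.5 0.0   -- old value went negative on the last iteration
```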
@@ -83,7 +83,7 @@ class AnimeGANV2Model(BaseModel):
             self.smooth_gray = paddle.to_tensor(input['smooth_gray'])
         else:
             self.real = paddle.to_tensor(input['A'])
-            self.image_paths = input['A_paths']
+            self.image_paths = input['A_path']
 
     def forward(self):
         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
@@ -56,8 +56,9 @@ class DCGANModel(BaseModel):
             input (dict): include the data itself and its metadata information.
         """
         # get 1-channel gray image, or 3-channel color image
-        self.real = paddle.to_tensor(input['A'])
-        self.image_paths = input['A_path']
+        self.real = paddle.to_tensor(input['img'])
+        if 'img_path' in input:
+            self.image_paths = input['img_path']
 
     def forward(self):
         """Run forward pass; called by both functions <train_iter> and <test_iter>."""
@@ -74,10 +74,8 @@ class Pix2PixModel(BaseModel):
         AtoB = self.direction == 'AtoB'
 
-        self.real_A = paddle.to_tensor(
-            input['A' if AtoB else 'B'])
-        self.real_B = paddle.to_tensor(
-            input['B' if AtoB else 'A'])
+        self.real_A = paddle.to_tensor(input['A' if AtoB else 'B'])
+        self.real_B = paddle.to_tensor(input['B' if AtoB else 'A'])
 
         self.image_paths = input['A_path' if AtoB else 'B_path']
@@ -141,3 +139,7 @@ class Pix2PixModel(BaseModel):
         optimizers['optimG'].clear_grad()
         self.backward_G()
         optimizers['optimG'].step()
+
+    def test_iter(self, metrics=None):
+        with paddle.no_grad():
+            self.forward()
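Pix2Pix previously had no `test_iter`, so the new evaluation path could not drive it; the added method runs a gradient-free forward pass. A hedged sketch of how the Trainer's evaluation loop (shown earlier in this diff) calls it:

```python
def evaluate(model, dataloader, metrics=None):
    """Sketch of the eval flow, assuming the Trainer methods shown above."""
    model.setup_train_mode(is_train=False)
    for data in dataloader:
        model.setup_input(data)
        # test_iter wraps forward() in paddle.no_grad(), so evaluation
        # never touches gradients or optimizer state.
        model.test_iter(metrics=metrics)
```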
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,47 +12,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import copy
-import traceback
-
 import paddle
-
-from ...utils.registry import Registry
-
-TRANSFORMS = Registry("TRANSFORMS")
-
-
-class Compose(object):
-    """
-    Composes several transforms together use for composing list of transforms
-    together for a dataset transform.
-
-    Args:
-        transforms (list): List of transforms to compose.
-
-    Returns:
-        A compose object which is callable, __call__ for this Compose
-        object will call each given :attr:`transforms` sequencely.
-    """
-    def __init__(self, transforms):
-        self.transforms = transforms
-
-    def __call__(self, data):
-        for f in self.transforms:
-            try:
-                data = f(data)
-            except Exception as e:
-                print(f)
-                stack_info = traceback.format_exc()
-                print("fail to perform transform [{}] with error: "
-                      "{} and stack:\n{}".format(f, e, str(stack_info)))
-                raise e
-        return data
-
-
-def build_transforms(cfg):
-    transforms = []
-    for trans_cfg in cfg:
-        temp_trans_cfg = copy.deepcopy(trans_cfg)
-        name = temp_trans_cfg.pop('name')
-        transforms.append(TRANSFORMS.get(name)(**temp_trans_cfg))
-    transforms = Compose(transforms)
-    return transforms
+import argparse
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='This script extracts weights from a checkpoint')
+    parser.add_argument('checkpoint', help='checkpoint file')
+    parser.add_argument('--net-name',
+                        type=str,
+                        help='net name in checkpoint dict')
+    parser.add_argument('--output', type=str, help='destination file name')
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    assert args.output.endswith(".pdparams")
+    ckpt = paddle.load(args.checkpoint)
+    state_dict = ckpt[args.net_name]
+    paddle.save(state_dict, args.output)
+
+
+if __name__ == '__main__':
+    main()
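The extracted `.pdparams` file is a flat state dict holding only the chosen sub-network, so it can be restored directly with `Layer.set_state_dict`. A hedged sanity check; the generator class and its arguments below are assumptions, not PaddleGAN's exact API:

```python
import paddle

# The extracted file holds only the chosen sub-network's weights.
state_dict = paddle.load('stylegan_config_f.pdparams')
print(len(state_dict), 'tensors')

# It can then be restored into the generator, e.g. (assumed class/args):
# generator = StyleGANv2Generator(size=256, style_dim=512)
# generator.set_state_dict(state_dict)
```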