deoldify_predictor.py 5.5 KB
Newer Older
L
LielinJiang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#  Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
L
LielinJiang 已提交
14

L
LielinJiang 已提交
15
import os
L
LielinJiang 已提交
16 17 18 19 20
import cv2
import glob
import numpy as np
from PIL import Image
from tqdm import tqdm
L
LielinJiang 已提交
21 22

import paddle
L
LielinJiang 已提交
23
from ppgan.utils.download import get_path_from_url
L
LielinJiang 已提交
24 25
from ppgan.utils.video import frames2video, video2frames
from ppgan.models.generators.deoldify import build_model
L
LielinJiang 已提交
26
from ppgan.utils.logger import get_logger
L
LielinJiang 已提交
27

L
LielinJiang 已提交
28
from .base_predictor import BasePredictor
L
LielinJiang 已提交
29

L
LielinJiang 已提交
30 31
DEOLDIFY_STABLE_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams'
DEOLDIFY_ART_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_art.pdparams'
L
LielinJiang 已提交
32

L
LielinJiang 已提交
33

L
LielinJiang 已提交
34
class DeOldifyPredictor(BasePredictor):
L
LielinJiang 已提交
35 36 37 38 39
    def __init__(self,
                 output='output',
                 weight_path=None,
                 artistic=False,
                 render_factor=32):
        """Build the DeOldify generator and load its weights.

        Args:
            output (str): Parent directory for results. Everything is written
                to a ``DeOldify`` sub-directory of it, which is created when
                missing.
            weight_path (str, optional): Path handed to ``paddle.load``. When
                ``None``, the path is obtained from ``get_path_from_url``
                using the default URL that matches ``artistic``.
            artistic (bool): Build the 'artistic' model variant instead of the
                'stable' one and, if no ``weight_path`` is given, pick the
                matching default weights.
            render_factor (int): Stored on the instance; ``run_image`` passes
                it to ``norm``, where it scales the square model input size.
        """
        self.output = os.path.join(output, 'DeOldify')
        # exist_ok removes the window between "does it exist?" and "create
        # it" in which another process could create the directory first.
        os.makedirs(self.output, exist_ok=True)
        self.render_factor = render_factor

        self.model = build_model(
            model_type='artistic' if artistic else 'stable')

        if weight_path is None:
            weight_url = (DEOLDIFY_ART_WEIGHT_URL
                          if artistic else DEOLDIFY_STABLE_WEIGHT_URL)
            weight_path = get_path_from_url(weight_url)

        state_dict = paddle.load(weight_path)
        self.model.load_dict(state_dict)
        # The predictor is only used for inference.
        self.model.eval()

    def norm(self, img, render_factor=32, render_base=16):
        """Resize an image to a square and normalize it per channel.

        Args:
            img: Image object exposing a PIL-style ``resize``; it must convert
                to a 3-D array with the 3 colour channels on the last axis.
            render_factor (int): Multiplied by ``render_base`` to get the side
                length of the square the image is resized to.
            render_base (int): Base unit of the side length.

        Returns:
            numpy.ndarray: float32 array with the channel axis first, after
            dividing by 255, subtracting the per-channel mean and dividing by
            the per-channel std.
        """
        side = render_factor * render_base
        resized = img.resize((side, side), resample=Image.BILINEAR)

        # Move channels to the front and rescale the pixel values.
        chw = np.array(resized).transpose([2, 0, 1]).astype('float32') / 255.0

        # Shaped (3, 1, 1) so they broadcast over the two spatial axes.
        mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
        std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))

        chw -= mean
        chw /= std
        return chw.astype('float32')

    def denorm(self, img):
        """Undo the per-channel normalization and convert to 8-bit pixels.

        The input is left untouched: all arithmetic happens on a private
        floating-point copy.

        Args:
            img (numpy.ndarray): Channel-first array with 3 channels, as
                produced by ``norm`` (or a model output in the same layout).

        Returns:
            numpy.ndarray: uint8 array with the channel axis moved last;
            values are scaled by 255, clipped to [0, 255] and truncated.
        """
        img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
        img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))

        # astype() always returns a new array here, so the in-place updates
        # below never write into the caller's buffer. Floating inputs keep
        # their precision; integer inputs are promoted to a float type, which
        # the in-place operators could not do on the original array.
        img = np.asarray(img)
        img = img.astype(np.result_type(img.dtype, np.float32))

        img *= img_std
        img += img_mean
        img = img.transpose((1, 2, 0))

        return (img * 255).clip(0, 255).astype('uint8')
L
LielinJiang 已提交
78

L
LielinJiang 已提交
79 80 81 82 83 84 85 86 87 88
    def post_process(self, raw_color, orig):
        """Merge the chroma of ``raw_color`` into ``orig``.

        Both inputs are converted to YUV; the result keeps the first (Y)
        plane of ``orig`` and takes the other two planes from ``raw_color``.

        Args:
            raw_color: Colorized image, convertible with ``np.asarray``.
            orig: Source image of the same size, convertible the same way.

        Returns:
            PIL.Image.Image: The merged image.
        """
        # NOTE(review): the BGR conversion codes are applied to arrays that
        # come from PIL images (presumably RGB-ordered). The same convention
        # is used in both directions; confirm this is intended.
        colorized_yuv = cv2.cvtColor(np.asarray(raw_color), cv2.COLOR_BGR2YUV)
        source_yuv = cv2.cvtColor(np.asarray(orig), cv2.COLOR_BGR2YUV)

        merged = source_yuv.copy()
        merged[:, :, 1:3] = colorized_yuv[:, :, 1:3]

        return Image.fromarray(cv2.cvtColor(merged, cv2.COLOR_YUV2BGR))
L
LielinJiang 已提交
89

L
LielinJiang 已提交
90 91 92 93 94 95 96 97
    def run_image(self, img):
        """Colorize a single image.

        Args:
            img (str | numpy.ndarray | PIL.Image.Image): File path, pixel
                array, or PIL image. Paths and arrays are converted to a
                grey-scale image stored as RGB. PIL images in RGB mode are
                used as given; PIL images in any other mode get the same
                grey-scale conversion, because ``norm`` needs exactly three
                channels.

        Returns:
            PIL.Image.Image: Colorized image with the size of the input.

        Raises:
            TypeError: If ``img`` is none of the supported types.
        """
        if isinstance(img, str):
            ori_img = Image.open(img).convert('LA').convert('RGB')
        elif isinstance(img, np.ndarray):
            ori_img = Image.fromarray(img).convert('LA').convert('RGB')
        elif isinstance(img, Image.Image):
            if img.mode == 'RGB':
                ori_img = img
            else:
                # 'L', 'RGBA', 'P', ... cannot be normalized with the
                # 3-channel mean/std, so bring them to 3 channels the same
                # way the other branches do.
                ori_img = img.convert('LA').convert('RGB')
        else:
            raise TypeError(
                'img must be a file path, a numpy.ndarray or a PIL image, '
                'got {}'.format(type(img).__name__))

        model_input = self.norm(ori_img, self.render_factor)
        # Add the leading batch axis the model call is given.
        x = paddle.to_tensor(model_input[np.newaxis, ...])
        # Inference only: skip autograd bookkeeping for the forward pass.
        with paddle.no_grad():
            out = self.model(x)

        pred_img = self.denorm(out.numpy()[0])
        pred_img = Image.fromarray(pred_img)
        # The model works on a square; go back to the input's size before
        # merging with the original.
        pred_img = pred_img.resize(ori_img.size, resample=Image.BILINEAR)
        pred_img = self.post_process(pred_img, ori_img)
        return pred_img

L
LielinJiang 已提交
108 109
    def run_video(self, video):
        """Colorize a video frame by frame and re-encode the result.

        Results go to ``<self.output>/<name>/``, where ``<name>`` is the part
        of the video's file name before its first dot. Colorized frames are
        saved in the ``frames_pred`` sub-directory.

        Args:
            video (str): Path of the input video.

        Returns:
            tuple: ``(frame_pattern, video_path)``: the ``%08d.png`` pattern
            of the colorized frames and the path of the output ``.mp4``.
        """
        base_name = os.path.basename(video).split('.')[0]
        output_path = os.path.join(self.output, base_name)
        pred_frame_path = os.path.join(output_path, 'frames_pred')

        # pred_frame_path lives inside output_path, so one call creates both.
        os.makedirs(pred_frame_path, exist_ok=True)

        # The capture is only needed to read the frame rate; release it right
        # away instead of keeping the video open for the whole run.
        cap = cv2.VideoCapture(video)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
        finally:
            cap.release()

        out_path = video2frames(video, output_path)
        frames = sorted(glob.glob(os.path.join(out_path, '*.png')))

        for frame in tqdm(frames):
            pred_img = self.run_image(frame)
            frame_name = os.path.basename(frame)
            pred_img.save(os.path.join(pred_frame_path, frame_name))

        # Frames keep the names given by video2frames; presumably these match
        # the '%08d.png' pattern -- verify against video2frames.
        frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')
        vid_out_path = os.path.join(output_path,
                                    '{}_deoldify_out.mp4'.format(base_name))
        frames2video(frame_pattern_combined, vid_out_path, str(int(fps)))

        return frame_pattern_combined, vid_out_path
L
LielinJiang 已提交
139 140

    def run(self, input):
        """Colorize ``input``, dispatching on whether it is an image.

        Args:
            input: Anything ``self.is_image`` accepts as an image is handled
                by ``run_image``; everything else is passed to ``run_video``.

        Returns:
            tuple: For videos, the return value of ``run_video``. For images,
            ``(pred_img, out_path)`` where ``out_path`` is the saved PNG path
            (``None`` if ``self.output`` is empty).
        """
        if not self.is_image(input):
            return self.run_video(input)

        pred_img = self.run_image(input)

        out_path = None
        if self.output:
            try:
                base_name = os.path.splitext(os.path.basename(input))[0]
            except TypeError:
                # In-memory inputs (arrays, PIL images) are not path-like and
                # carry no file name; fall back to a fixed one.
                base_name = 'result'

            out_path = os.path.join(self.output, base_name + '.png')
            pred_img.save(out_path)

            logger = get_logger()
            logger.info('Image saved to {}'.format(out_path))

        return pred_img, out_path