未验证 提交 1826c083 编写于 作者: 七年期限 提交者: GitHub

add stgan_bald module

上级 96d997c9
# stgan_bald
基于PaddleHub的秃头生成器
# 模型概述
秃头生成器(stgan_bald),该模型可自动根据图像生成1年、3年、5年的秃头效果。
# 模型效果:
详情请查看此链接:https://aistudio.baidu.com/aistudio/projectdetail/1145381
本模型为大家提供了小程序,欢迎大家体验
![image](https://github.com/1084667371/stgan_bald/blob/main/images/code.jpg)
# 选择模型版本进行安装
$ hub install stgan_bald==1.0.0
# Module API说明
def bald(self,
images=None,
paths=None,
use_gpu=False,
visualization=False):
秃头生成器API预测接口,预测输入一张人像,输出三张秃头效果(1年、3年、5年)
## 参数
images (list(numpy.ndarray)): 图像数据,每个图像的形状为[H,W,C],颜色空间为BGR。
paths (list[str]): 图像的路径。
use_gpu (bool): 是否使用gpu。
visualization (bool): 是否保存图像。
## 返回
data_0 ([numpy.ndarray]):秃头一年的预测结果图。
data_1 ([numpy.ndarray]):秃头三年的预测结果图。
data_2 ([numpy.ndarray]):秃头五年的预测结果图。
# API预测代码示例
import paddlehub as hub
import cv2
stgan_bald = hub.Module('stgan_bald')
im = cv2.imread('/PATH/TO/IMAGE')
res = stgan_bald.bald(images=[im],visualization=True)
# 服务部署
## 第一步:启动PaddleHub Serving
$ hub serving start -m stgan_bald
## 第二步:发送预测请求
import requests
import json
import base64
import cv2
import numpy as np
def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tostring()).decode('utf8')
def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8'))
data = np.fromstring(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data
# 发送HTTP请求
org_im = cv2.imread('/PATH/TO/IMAGE')
data = {'images':[cv2_to_base64(org_im)]}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/stgan_bald"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
# 保存图片 1年 3年 5年
one_year =cv2.cvtColor(base64_to_cv2(r.json()["results"]['data_0']), cv2.COLOR_RGB2BGR)
three_year =cv2.cvtColor(base64_to_cv2(r.json()["results"]['data_1']), cv2.COLOR_RGB2BGR)
five_year =cv2.cvtColor(base64_to_cv2(r.json()["results"]['data_2']), cv2.COLOR_RGB2BGR)
cv2.imwrite("stgan_bald_server.png", one_year)
# 贡献者
刘炫、彭兆帅、郑博培
# 依赖
paddlepaddle >= 1.8.2
paddlehub >= 1.8.0
# 查看代码
[基于PaddleHub的秃头生成器](https://github.com/PaddlePaddle/PaddleHub/tree/release/v1.8/hub_module/modules/image/gan/stgan_bald)
# -*- coding:utf-8 -*-
import os
import time
from collections import OrderedDict
from PIL import Image, ImageOps
import numpy as np
from PIL import Image
import cv2
__all__ = ['reader']
def reader(images=None, paths=None, org_labels=None, target_labels=None):
"""
Preprocess to yield image.
Args:
images (list(numpy.ndarray)): images data, shape of each is [H, W, C]
paths (list[str]): paths to images.
Yield:
each (collections.OrderedDict): info of original image, preprocessed image.
"""
component = list()
if paths:
for i, im_path in enumerate(paths):
each = OrderedDict()
assert os.path.isfile(
im_path), "The {} isn't a valid file path.".format(im_path)
im = cv2.imread(im_path)
each['org_im'] = im
each['org_im_path'] = im_path
each['org_label'] = np.array(org_labels[i]).astype('float32')
if not target_labels:
each['target_label'] = np.array(org_labels[i]).astype('float32')
else:
each['target_label'] = np.array(target_labels[i]).astype('float32')
component.append(each)
if images is not None:
assert type(images) is list, "images should be a list."
for i, im in enumerate(images):
each = OrderedDict()
each['org_im'] = im
each['org_im_path'] = 'ndarray_time={}'.format(round(time.time(), 6) * 1e6)
each['org_label'] = np.array(org_labels[i]).astype('float32')
if not target_labels:
each['target_label'] = np.array(org_labels[i]).astype('float32')
else:
each['target_label'] = np.array(target_labels[i]).astype('float32')
component.append(each)
for element in component:
img = cv2.cvtColor(element['org_im'], cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (128, 128), interpolation=cv2.INTER_LINEAR)
img = (img.astype('float32') / 255.0 - 0.5) / 0.5
img = img.transpose([2, 0, 1])
element['img'] = img[np.newaxis, :, :, :]
yield element
# -*- coding:utf-8 -*-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import os
import argparse
import copy
import numpy as np
import paddle.fluid as fluid
import paddlehub as hub
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
from paddlehub.module.module import moduleinfo, runnable, serving
from stgan_bald.data_feed import reader
from stgan_bald.processor import postprocess, base64_to_cv2, cv2_to_base64, check_dir
def check_attribute_conflict(label_batch):
''' Based on https://github.com/LynnHo/AttGAN-Tensorflow'''
attrs = "Bald,Bangs,Black_Hair,Blond_Hair,Brown_Hair,Bushy_Eyebrows,Eyeglasses,Male,Mouth_Slightly_Open,Mustache,No_Beard,Pale_Skin,Young".split(
',')
def _set(label, value, attr):
if attr in attrs:
label[attrs.index(attr)] = value
attr_id = attrs.index('Bald')
for label in label_batch:
if attrs[attr_id] != 0:
_set(label, 0, 'Bangs')
return label_batch
@moduleinfo(
name="stgan_bald",
version="1.0.0",
summary="Baldness generator",
author="Arrow, 七年期限,Mr.郑先生_",
author_email="1084667371@qq.com,2733821739@qq.com",
type="image/gan")
class StganBald(hub.Module):
def _initialize(self):
self.default_pretrained_model_path = os.path.join(self.directory, "module")
self._set_config()
def _set_config(self):
"""
predictor config setting
"""
self.model_file_path = os.path.join(self.default_pretrained_model_path,
'__model__')
self.params_file_path = os.path.join(self.default_pretrained_model_path,
'__params__')
cpu_config = AnalysisConfig(self.model_file_path, self.params_file_path)
cpu_config.disable_glog_info()
cpu_config.disable_gpu()
self.cpu_predictor = create_paddle_predictor(cpu_config)
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
use_gpu = True
self.place = fluid.CUDAPlace(0)
except:
use_gpu = False
self.place = fluid.CPUPlace()
if use_gpu:
gpu_config =AnalysisConfig(self.model_file_path, self.params_file_path)
gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(
memory_pool_init_size_mb=1000, device_id=0)
self.gpu_predictor = create_paddle_predictor(gpu_config)
def bald(self,
images=None,
paths=None,
data=None,
use_gpu=False,
org_labels=[[0.,0.,1.,0.,0.,1.,1.,1.,0.,0.,0.,0.,1.]],
target_labels=None,
visualization=True,
output_dir="bald_output"):
"""
API for super resolution.
Args:
images (list(numpy.ndarray)): images data, shape of each is [H, W, C], the color space is BGR.
paths (list[str]): The paths of images.
data (dict): key is 'image', the corresponding value is the path to image.
use_gpu (bool): Whether to use gpu.
visualization (bool): Whether to save image or not.
output_dir (str): The path to store output images.
Returns:
res (list[dict]): each element in the list is a dict, the keys and values are:
save_path (str, optional): the path to save images. (Exists only if visualization is True)
data (numpy.ndarray): data of post processed image.
"""
if use_gpu:
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
except:
raise RuntimeError(
"Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES as cuda_device_id."
)
if data and 'image' in data:
if paths is None:
paths = list()
paths += data['image']
all_data = list()
for yield_data in reader(images, paths, org_labels, target_labels):
all_data.append(yield_data)
total_num = len(all_data)
res = list()
outputs = []
for i in range(total_num):
image_np = all_data[i]['img']
org_label_np = [all_data[i]['org_label']]
target_label_np = [all_data[i]['target_label']]
for j in range(5):
if j % 2 == 0:
label_trg_tmp = copy.deepcopy(target_label_np)
new_i = 0
label_trg_tmp[0][new_i] = 1.0 - label_trg_tmp[0][new_i]
label_trg_tmp = check_attribute_conflict(
label_trg_tmp)
change_num = j * 0.02 + 0.3
label_org_tmp = list(
map(lambda x: ((x * 2) - 1) * change_num, org_label_np))
label_trg_tmp = list(
map(lambda x: ((x * 2) - 1) * change_num, label_trg_tmp))
image = PaddleTensor(image_np.copy())
org_label = PaddleTensor(np.array(label_org_tmp).astype('float32'))
target_label = PaddleTensor(np.array(label_trg_tmp).astype('float32'))
output = self.gpu_predictor.run(
[image, target_label, org_label]
) if use_gpu else self.cpu_predictor.run([image, org_label, target_label])
outputs.append(output)
out = postprocess(
data_out=outputs,
org_im=all_data[i]['org_im'],
org_im_path=all_data[i]['org_im_path'],
output_dir=output_dir,
visualization=visualization)
res.append(out)
return res
@serving
def serving_method(self, images, **kwargs):
"""
Run as a service.
"""
images_decode = [base64_to_cv2(image) for image in images]
results = self.bald(images=images_decode, **kwargs)
output = {}
for key, value in results[0].items():
output[key] = cv2_to_base64(value)
return output
# -*- coding:utf-8 -*-
import os
import time
import base64
import cv2
from PIL import Image
import numpy as np
__all__ = ['cv2_to_base64', 'base64_to_cv2', 'postprocess']
def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tostring()).decode('utf8')
def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8'))
data = np.fromstring(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data
def postprocess(data_out,
org_im,
org_im_path,
output_dir,
visualization,
thresh=120):
"""
Postprocess output of network. one image at a time.
Args:
data_out (numpy.ndarray): output of network.
org_im (numpy.ndarray): original image.
org_im_shape (list): shape pf original image.
org_im_path (list): path of riginal image.
output_dir (str): output directory to store image.
visualization (bool): whether to save image or not.
thresh (float): threshold.
Returns:
result (dict): The data of processed image.
"""
result = dict()
for i, img in enumerate(data_out):
img = np.squeeze(img[0].as_ndarray(), 0).transpose((1,2,0))
img = ((img + 1) * 127.5).astype(np.uint8)
img = cv2.resize(img, (256, 341), cv2.INTER_CUBIC)
fake_image = Image.fromarray(img)
if visualization:
check_dir(output_dir)
save_im_path = get_save_image_name(org_im_path, output_dir, i)
img_name = '{}.png'.format(i)
fake_image.save(os.path.join(output_dir, img_name))
result['data_{}'.format(i)] = img
return result
def check_dir(dir_path):
if not os.path.exists(dir_path):
os.makedirs(dir_path)
elif os.path.isfile(dir_path):
os.remove(dir_path)
os.makedirs(dir_path)
def get_save_image_name( org_im_path, output_dir, num):
"""
Get save image name from source image path.
"""
# name prefix of orginal image
org_im_name = os.path.split(org_im_path)[-1]
im_prefix = os.path.splitext(org_im_name)[0]
ext = '.png'
# save image path
save_im_path = os.path.join(output_dir, im_prefix + ext)
if os.path.exists(save_im_path):
save_im_path = os.path.join(
output_dir, im_prefix + str(num) + ext)
return save_im_path
paddlepaddle>=1.8.4
paddlehub>=1.8.0
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册