未验证 提交 3c9dd7b3 编写于 作者: L lzzyzlbb 提交者: GitHub

1.add compare result, 2.add seed for paddlegan (#493)

* 1.add compare result, 2.add seed for paddlegan

* 1.add compare result, 2.add seed for paddlegan

* 1.add compare result, 2.add seed for paddlegan

* 1.add compare result, 2.add seed for paddlegan
上级 3d4fa146
......@@ -123,7 +123,7 @@ def read_video(name: Path, frame_shape=tuple([256, 256, 3]), saveto='folder'):
except FileExistsError:
pass
for idx, img in enumerate(video_array_reshape):
cv2.imwrite(sub_dir.joinpath('%i.png' % idx), img)
cv2.imwrite(str(sub_dir.joinpath('%i.png' % idx)), img[:,:,[2,1,0]])
name.unlink()
return video_array_reshape
else:
......@@ -207,7 +207,6 @@ class FramesDataset(Dataset):
num_frames, replace=True,
size=2)) if self.is_train else range(num_frames)
video_array = [video_array[i] for i in frame_idx]
# convert to 3-channel image
if video_array[0].shape[-1] == 4:
video_array = [i[..., :3] for i in video_array]
......@@ -218,7 +217,6 @@ class FramesDataset(Dataset):
np.tile(i[..., np.newaxis], (1, 1, 3)) for i in video_array
]
out = {}
if self.is_train:
if self.transform is not None: #modify
t = self.transform(tuple(video_array))
......
......@@ -67,6 +67,12 @@ def parse_args():
default=None,
help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
)
# fix random numbers by setting seed
parser.add_argument('--seed',
type=int,
default=None,
help='fix random numbers by setting seed\".'
)
args = parser.parse_args()
return args
......@@ -15,10 +15,10 @@
import os
import time
import paddle
import numpy as np
import random
from .logger import setup_logger
def setup(args, cfg):
if args.evaluate_only:
cfg.is_train = False
......@@ -44,3 +44,10 @@ def setup(args, cfg):
paddle.set_device('gpu')
else:
paddle.set_device('cpu')
if args.seed:
paddle.seed(args.seed)
random.seed(args.seed)
np.random.seed(args.seed)
paddle.framework.random._manual_program_seed(args.seed)
import numpy as np
import os
import subprocess
import json
import argparse
import glob
def init_args():
parser = argparse.ArgumentParser()
# params for testing assert allclose
parser.add_argument("--atol", type=float, default=1e-3)
parser.add_argument("--rtol", type=float, default=1e-3)
parser.add_argument("--gt_file", type=str, default="")
parser.add_argument("--log_file", type=str, default="")
parser.add_argument("--precision", type=str, default="fp32")
return parser
def parse_args():
parser = init_args()
return parser.parse_args()
def load_from_file(gt_file):
if not os.path.exists(gt_file):
raise ValueError("The log file {} does not exists!".format(gt_file))
with open(gt_file, 'r') as f:
data = f.readlines()
f.close()
parser_gt = {}
for line in data:
metric_name, result = line.strip("\n").split(":")
parser_gt[metric_name] = float(result)
return parser_gt
if __name__ == "__main__":
# Usage:
# python3.7 test_tipc/compare_results.py --gt_file=./test_tipc/results/*.txt --log_file=./test_tipc/output/*/*.txt
args = parse_args()
gt_collection = load_from_file(args.gt_file)
pre_collection = load_from_file(args.log_file)
for metric in pre_collection.keys():
try:
np.testing.assert_allclose(
np.array(pre_collection[metric]), np.array(gt_collection[metric]), atol=args.atol, rtol=args.rtol)
print(
"Assert allclose passed! The results of {} are consistent!".
format(metric))
except Exception as E:
print(E)
raise ValueError(
"The results of {} are inconsistent!".
format(metric))
\ No newline at end of file
......@@ -4,7 +4,7 @@ python:python3.7
gpu_list:0
##
auto_cast:null
total_iters:lite_train_lite_infer=5|whole_train_whole_infer=200
total_iters:lite_train_lite_infer=10|whole_train_whole_infer=200
output_dir:./output/
dataset.train.batch_size:lite_train_lite_infer=1|whole_train_whole_infer=1
pretrained_model:null
......@@ -13,7 +13,7 @@ train_infer_img_dir:./data/basicvsr_reds/test
null:null
##
trainer:norm_train
norm_train:tools/main.py -c configs/basicvsr_reds.yaml -o dataset.train.dataset.num_clips=2
norm_train:tools/main.py -c configs/basicvsr_reds.yaml --seed 123 -o dataset.train.dataset.num_clips=2 dataset.train.num_workers=0 log_config.interval=1 snapshot_config.interval=5
pact_train:null
fpgm_train:null
distill_train:null
......@@ -37,7 +37,7 @@ inference_dir:basicvsrmodel_generator
train_model:./inference/basicvsr/basicvsrmodel_generator
infer_export:null
infer_quant:False
inference:tools/inference.py --model_type basicvsr -c configs/basicvsr_reds.yaml -o dataset.test.num_clips=2 dataset.test.number_frames=6
inference:tools/inference.py --model_type basicvsr -c configs/basicvsr_reds.yaml --seed 123 -o dataset.test.num_clips=2 dataset.test.number_frames=6 --output_path test_tipc/output/
--device:gpu
null:null
null:null
......
......@@ -4,7 +4,7 @@ python:python3.7
gpu_list:0|0,1
##
auto_cast:null
epochs:lite_train_lite_infer=5|whole_train_whole_infer=200
epochs:lite_train_lite_infer=1|whole_train_whole_infer=200
output_dir:./output/
dataset.train.batch_size:lite_train_lite_infer=1|whole_train_whole_infer=1
pretrained_model:null
......@@ -13,7 +13,7 @@ train_infer_img_dir:./data/horse2zebra/test
null:null
##
trainer:norm_train
norm_train:tools/main.py -c configs/cyclegan_horse2zebra.yaml -o
norm_train:tools/main.py -c configs/cyclegan_horse2zebra.yaml --seed 123 -o log_config.interval=10 snapshot_config.interval=1
pact_train:null
fpgm_train:null
distill_train:null
......@@ -37,7 +37,7 @@ inference_dir:cycleganmodel_netG_A
train_model:./inference/cyclegan_horse2zebra/cycleganmodel_netG_A
infer_export:null
infer_quant:False
inference:tools/inference.py --model_type cyclegan -c configs/cyclegan_horse2zebra.yaml
inference:tools/inference.py --model_type cyclegan --seed 123 -c configs/cyclegan_horse2zebra.yaml --output_path test_tipc/output/
--device:gpu
null:null
null:null
......
......@@ -4,7 +4,7 @@ python:python3.7
gpu_list:0
##
auto_cast:null
epochs:lite_train_lite_infer=10|whole_train_whole_infer=100
epochs:lite_train_lite_infer=1|whole_train_whole_infer=100
output_dir:./output/
dataset.train.batch_size:lite_train_lite_infer=8|whole_train_whole_infer=8
pretrained_model:null
......@@ -13,7 +13,7 @@ train_infer_img_dir:./data/firstorder_vox_256/test
null:null
##
trainer:norm_train
norm_train:tools/main.py -c configs/firstorder_vox_256.yaml -o
norm_train:tools/main.py -c configs/firstorder_vox_256.yaml --seed 123 -o dataset.train.num_workers=0 log_config.interval=1 snapshot_config.interval=1 dataset.train.num_repeats=1 dataset.train.id_sampling=False
pact_train:null
fpgm_train:null
distill_train:null
......@@ -37,7 +37,7 @@ inference_dir:fom_dy2st
train_model:./inference/fom_dy2st/
infer_export:null
infer_quant:False
inference:tools/fom_infer.py --driving_path data/first_order/Voxceleb/test --output_path infer_output/fom
inference:tools/fom_infer.py --driving_path data/first_order/Voxceleb/test --output_path test_tipc/output/fom/
--device:gpu
null:null
null:null
......
===========================train_params===========================
model_name:pix2pix
python:python3.7
gpu_list:0|0,1
gpu_list:0
##
auto_cast:null
epochs:lite_train_lite_infer=5|whole_train_whole_infer=200
epochs:lite_train_lite_infer=10|whole_train_whole_infer=200
output_dir:./output/
dataset.train.batch_size:lite_train_lite_infer=1|whole_train_whole_infer=1
pretrained_model:null
......@@ -13,7 +13,7 @@ train_infer_img_dir:./data/facades/test
null:null
##
trainer:norm_train
norm_train:tools/main.py -c configs/pix2pix_facades.yaml -o
norm_train:tools/main.py -c configs/pix2pix_facades.yaml --seed 123 -o dataset.train.num_workers=0 log_config.interval=5
pact_train:null
fpgm_train:null
distill_train:null
......@@ -27,7 +27,7 @@ null:null
===========================infer_params===========================
--output_dir:./output/
load:null
norm_export:tools/export_model.py -c configs/pix2pix_facades.yaml --inputs_size="-1,3,-1,-1" --load
norm_export:tools/export_model.py -c configs/pix2pix_facades.yaml --inputs_size="-1,3,-1,-1" --load
quant_export:null
fpgm_export:null
distill_export:null
......@@ -37,7 +37,7 @@ inference_dir:pix2pixmodel_netG
train_model:./inference/pix2pix_facade/pix2pixmodel_netG
infer_export:null
infer_quant:False
inference:tools/inference.py --model_type pix2pix -c configs/pix2pix_facades.yaml
inference:tools/inference.py --model_type pix2pix --seed 123 -c configs/pix2pix_facades.yaml --output_path test_tipc/output/
--device:cpu
null:null
null:null
......
......@@ -4,7 +4,7 @@ python:python3.7
gpu_list:0
##
auto_cast:null
total_iters::lite_train_lite_infer=10|whole_train_whole_infer=800
total_iters:lite_train_lite_infer=10|whole_train_whole_infer=800
output_dir:./output/
dataset.train.batch_size:lite_train_lite_infer=3|whole_train_whole_infer=3
pretrained_model:null
......@@ -13,7 +13,7 @@ train_infer_img_dir:null
null:null
##
trainer:norm_train
norm_train:tools/main.py -c configs/stylegan_v2_256_ffhq.yaml -o
norm_train:tools/main.py -c configs/stylegan_v2_256_ffhq.yaml --seed 123 -o dataset.train.num_workers=0 log_config.interval=1 snapshot_config.interval=10
pact_train:null
fpgm_train:null
distill_train:null
......@@ -37,7 +37,7 @@ inference_dir:stylegan2model_gen
train_model:./inference/stylegan2/stylegan2model_gen
infer_export:null
infer_quant:False
inference:tools/inference.py --model_type stylegan2 -c configs/stylegan_v2_256_ffhq.yaml
inference:tools/inference.py --model_type stylegan2 --seed 123 -c configs/stylegan_v2_256_ffhq.yaml --output_path test_tipc/output/
--device:gpu
null:null
null:null
......
......@@ -108,7 +108,7 @@ Run failed with command - python3.7 tools/export_model.py -c configs/basicvsr_re
#### 使用方式
运行命令:
```shell
python3.7 test_tipc/compare_results.py --gt_file=./test_tipc/results/python_*.txt --log_file=./test_tipc/output/python_*.log --atol=1e-3 --rtol=1e-3
python3.7 test_tipc/compare_results.py --gt_file=./test_tipc/results/*.txt --log_file=./test_tipc/output/*/*.txt --atol=1e-3 --rtol=1e-3
```
参数介绍:
......
......@@ -37,12 +37,11 @@ MODE=$2
if [ ${MODE} = "lite_train_lite_infer" ];then
if [ ${model_name} == "pix2pix" ]; then
rm -rf ./data/facades*
rm -rf ./data/pix2pix*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/pix2pix_facade_lite.tar --no-check-certificate
cd ./data/ && tar xf pix2pix_facade_lite.tar && cd ../
elif [ ${model_name} == "cyclegan" ]; then
rm -rf ./data/horse2zebra*
rm -rf ./data/cyclegan*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/cyclegan_horse2zebra_lite.tar --no-check-certificate
cd ./data/ && tar xf cyclegan_horse2zebra_lite.tar && cd ../
elif [ ${model_name} == "stylegan2" ]; then
......@@ -50,11 +49,11 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/ffhq.tar --no-check-certificate
cd ./data/ && tar xf ffhq.tar && cd ../
elif [ ${model_name} == "fom" ]; then
rm -rf ./data/first_order*
rm -rf ./data/fom_lite*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/fom_lite.tar --no-check-certificate --no-check-certificate
cd ./data/ && tar xf fom_lite.tar && cd ../
elif [ ${model_name} == "basicvsr" ]; then
rm -rf ./data/REDS*
rm -rf ./data/basicvsr*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/basicvsr_lite.tar --no-check-certificate
cd ./data/ && tar xf basicvsr_lite.tar && cd ../
fi
......@@ -94,18 +93,21 @@ elif [ ${MODE} = "lite_train_whole_infer" ];then
elif [ ${MODE} = "whole_infer" ];then
if [ ${model_name} = "pix2pix" ]; then
rm -rf ./data/facades*
rm -rf ./inference/pix2pix*
wget -nc -P ./inference https://paddlegan.bj.bcebos.com/static_model/pix2pix_facade.tar --no-check-certificate
wget -nc -P ./data https://paddlegan.bj.bcebos.com/datasets/facades_test.tar --no-check-certificate
cd ./data && tar xf facades_test.tar && mv facades_test facades && cd ../
cd ./inference && tar xf pix2pix_facade.tar && cd ../
elif [ ${model_name} = "cyclegan" ]; then
rm -rf ./data/horse2zebra*
rm -rf ./data/cyclegan*
rm -rf ./inference/cyclegan*
wget -nc -P ./inference https://paddlegan.bj.bcebos.com/static_model/cyclegan_horse2zebra.tar --no-check-certificate
wget -nc -P ./data https://paddlegan.bj.bcebos.com/datasets/cyclegan_horse2zebra_test.tar --no-check-certificate
cd ./data && tar xf cyclegan_horse2zebra_test.tar && mv cyclegan_test horse2zebra && cd ../
cd ./inference && tar xf cyclegan_horse2zebra.tar && cd ../
elif [ ${model_name} == "fom" ]; then
rm -rf ./data/first_order*
rm -rf ./inference/fom_dy2st*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/fom_lite_test.tar --no-check-certificate
wget -nc -P ./inference https://paddlegan.bj.bcebos.com/static_model/fom_dy2st.tar --no-check-certificate
cd ./data/ && tar xf fom_lite_test.tar && cd ../
......
......@@ -43,10 +43,10 @@ test_tipc/
├── ...
├── results/ # 预先保存的预测结果,用于和实际预测结果进行精读比对
├── python_basicvsr_results_fp32.txt # 预存的basicvsr模型python预测fp32精度的结果
├── python_cyclegan_results_fp32.txt # 预存的cyclegan模型python预测fp32精度的结果
├── python_pix2pix_results_fp32.txt # 预存的pix2pix模型python预测的fp32精度的结果
├── python_stylegan_results_fp32.txt # 预存的stylegan模型python预测的fp32精度的结果
├── python_basicvsr_results_fp32.txt # 预存的basicvsr模型python预测fp32精度的结果
├── python_cyclegan_results_fp32.txt # 预存的cyclegan模型python预测fp32精度的结果
├── python_pix2pix_results_fp32.txt # 预存的pix2pix模型python预测的fp32精度的结果
├── python_stylegan2_results_fp32.txt # 预存的stylegan2模型python预测的fp32精度的结果
├── ...
├── prepare.sh # 完成test_*.sh运行所需要的数据和模型下载
├── test_train_inference_python.sh # 测试python训练预测的主程序
......
Metric psnr: 28.2953
Metric ssim: 0.8334
......@@ -95,12 +95,6 @@ function func_inference(){
for threads in ${cpu_threads_list[*]}; do
for batch_size in ${batch_size_list[*]}; do
for precision in ${precision_list[*]}; do
if [ ${use_mkldnn} = "False" ] && [ ${precision} = "fp16" ]; then
continue
fi # skip when enable fp16 but disable mkldnn
if [ ${_flag_quant} = "True" ] && [ ${precision} != "int8" ]; then
continue
fi # skip when quant model inference but precision is not int8
set_precision=$(func_set_params "${precision_key}" "${precision}")
_save_log_path="${_log_path}/python_infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_precision_${precision}_batchsize_${batch_size}.log"
......
......@@ -100,11 +100,11 @@ def main():
# 创建 config
kp_detector_config = paddle_infer.Config(os.path.join(
args.model_path, "/kp_detector.pdmodel"),
os.path.join(args.model_path, "/kp_detector.pdiparams"))
args.model_path, "kp_detector.pdmodel"),
os.path.join(args.model_path, "kp_detector.pdiparams"))
generator_config = paddle_infer.Config(os.path.join(
args.model_path, "/generator.pdmodel"),
os.path.join(args.model_path, "/generator.pdiparams"))
args.model_path, "generator.pdmodel"),
os.path.join(args.model_path, "generator.pdiparams"))
if args.device == "gpu":
kp_detector_config.enable_use_gpu(100, 0)
generator_config.enable_use_gpu(100, 0)
......@@ -120,7 +120,7 @@ def main():
# 根据 config 创建 predictor
kp_detector_predictor = paddle_infer.create_predictor(kp_detector_config)
generator_predictor = paddle_infer.create_predictor(generator_config)
test_loss = []
for k in range(len(driving_paths)):
driving_path = driving_paths[k]
driving_video, fps = read_video(driving_path)
......@@ -194,8 +194,11 @@ def main():
generator_output_handle = generator_predictor.get_output_handle(
generator_output_names[0])
output_data = generator_output_handle.copy_to_cpu()
loss = paddle.abs(paddle.to_tensor(output_data) -
paddle.to_tensor(driving_video[i])).mean().cpu().numpy()
test_loss.append(loss)
output_data = np.transpose(output_data, [0, 2, 3, 1])[0] * 255.0
#Todo:add blazeface static model
#frame = source_img.copy()
#frame[left:right, up:bottom] = cv2.resize(output_data.astype(np.uint8), (bottom - up, right - left), cv2.INTER_AREA)
......@@ -205,6 +208,11 @@ def main():
"result_" + str(k) + ".mp4"),
[frame for frame in results],
fps=fps)
metric_file = os.path.join(args.output_path, "metric.txt")
log_file = open(metric_file, 'a')
loss_string = "Metric {}: {:.4f}".format(
"l1 loss", np.mean(test_loss))
log_file.close()
def parse_args():
......
import paddle
import argparse
import numpy as np
import random
import os
from collections import OrderedDict
from ppgan.utils.config import get_config
from ppgan.datasets.builder import build_dataloader
......@@ -8,8 +11,10 @@ from ppgan.engine.trainer import IterLoader
from ppgan.utils.visual import save_image
from ppgan.utils.visual import tensor2img
from ppgan.utils.filesystem import makedirs
from ppgan.metrics import build_metric
MODEL_CLASSES = ["pix2pix", "cyclegan", "wav2lip", "esrgan", "edvr"]
MODEL_CLASSES = ["pix2pix", "cyclegan", "wav2lip", "esrgan", "edvr", "fom", "stylegan2", "basicvsr"]
def parse_args():
......@@ -37,11 +42,21 @@ def parse_args():
'--config-file',
metavar="FILE",
help='config file path')
parser.add_argument("--output_path",
type=str,
default="infer_output",
help="output_path")
# config options
parser.add_argument("-o",
"--opt",
nargs='+',
help="set configuration options")
# fix random numbers by setting seed
parser.add_argument('--seed',
type=int,
default=None,
help='fix random numbers by setting seed\".'
)
args = parser.parse_args()
return args
......@@ -61,9 +76,23 @@ def create_predictor(model_path, device="gpu"):
predictor = paddle.inference.create_predictor(config)
return predictor
def setup_metrics(cfg):
metrics = OrderedDict()
if isinstance(list(cfg.values())[0], dict):
for metric_name, cfg_ in cfg.items():
metrics[metric_name] = build_metric(cfg_)
else:
metric = build_metric(cfg)
metrics[metric.__class__.__name__] = metric
return metrics
def main():
args = parse_args()
if args.seed:
paddle.seed(args.seed)
random.seed(args.seed)
np.random.seed(args.seed)
cfg = get_config(args.config_file, args.opt)
predictor = create_predictor(args.model_path, args.device)
input_handles = [
......@@ -83,7 +112,15 @@ def main():
min_max = (-1., 1.)
model_type = args.model_type
makedirs("infer_output/" + model_type)
makedirs(os.path.join(args.output_path, model_type))
validate_cfg = cfg.get('validate', None)
metrics = None
if validate_cfg and 'metrics' in validate_cfg:
metrics = setup_metrics(validate_cfg['metrics'])
for metric in metrics.values():
metric.reset()
for i in range(max_eval_steps):
data = next(iter_loader)
if model_type == "pix2pix":
......@@ -91,17 +128,27 @@ def main():
input_handles[0].copy_from_cpu(real_A)
predictor.run()
prediction = output_handle.copy_to_cpu()
prediction = paddle.to_tensor(prediction[0])
image_numpy = tensor2img(prediction, min_max)
save_image(image_numpy, "infer_output/pix2pix/{}.png".format(i))
prediction = paddle.to_tensor(prediction)
image_numpy = tensor2img(prediction[0], min_max)
save_image(image_numpy, os.path.join(args.output_path, "pix2pix/{}.png".format(i)))
metric_file = os.path.join(args.output_path, "pix2pix/metric.txt")
real_B = paddle.to_tensor(data['A'])
for metric in metrics.values():
metric.update(prediction, real_B)
elif model_type == "cyclegan":
real_A = data['A'].numpy()
input_handles[0].copy_from_cpu(real_A)
predictor.run()
prediction = output_handle.copy_to_cpu()
prediction = paddle.to_tensor(prediction[0])
image_numpy = tensor2img(prediction, min_max)
save_image(image_numpy, "infer_output/cyclegan/{}.png".format(i))
prediction = paddle.to_tensor(prediction)
image_numpy = tensor2img(prediction[0], min_max)
save_image(image_numpy, os.path.join(args.output_path, "cyclegan/{}.png".format(i)))
metric_file = os.path.join(args.output_path, "cyclegan/metric.txt")
real_B = paddle.to_tensor(data['B'])
for metric in metrics.values():
metric.update(prediction, real_B)
elif model_type == "wav2lip":
indiv_mels, x = data['indiv_mels'].numpy()[0], data['x'].numpy()[0]
x = x.transpose([1, 0, 2, 3])
......@@ -115,6 +162,7 @@ def main():
image_numpy = tensor2img(image_numpy, (0, 1))
save_image(image_numpy,
"infer_output/wav2lip/{}_{}.png".format(i, j))
elif model_type == "esrgan":
lq = data['lq'].numpy()
input_handles[0].copy_from_cpu(lq)
......@@ -137,18 +185,42 @@ def main():
input_handles[1].copy_from_cpu(np.array([0.7]).astype('float32'))
predictor.run()
prediction = output_handle.copy_to_cpu()
prediction = paddle.to_tensor(prediction[0])
image_numpy = tensor2img(prediction, min_max)
save_image(image_numpy, "infer_output/stylegan2/{}.png".format(i))
prediction = paddle.to_tensor(prediction)
image_numpy = tensor2img(prediction[0], min_max)
save_image(image_numpy, os.path.join(args.output_path, "stylegan2/{}.png".format(i)))
metric_file = os.path.join(args.output_path, "stylegan2/metric.txt")
real_img = paddle.to_tensor(data['A'])
for metric in metrics.values():
metric.update(prediction, real_img)
elif model_type == "basicvsr":
lq = data['lq'].numpy()
input_handles[0].copy_from_cpu(lq)
predictor.run()
prediction = output_handle.copy_to_cpu()
prediction = paddle.to_tensor(prediction[0])
image_numpy = tensor2img(prediction, min_max)
save_image(image_numpy, "infer_output/basicvsr/{}.png".format(i))
prediction = paddle.to_tensor(prediction)
_, t, _, _, _ = prediction.shape
out_img = []
gt_img = []
for ti in range(t):
out_tensor = prediction[0, ti]
gt_tensor = data['gt'][0, ti]
out_img.append(tensor2img(out_tensor, (0.,1.)))
gt_img.append(tensor2img(gt_tensor, (0.,1.)))
image_numpy = tensor2img(prediction[0], min_max)
save_image(image_numpy, os.path.join(args.output_path, "basicvsr/{}.png".format(i)))
metric_file = os.path.join(args.output_path, "basicvsr/metric.txt")
for metric in metrics.values():
metric.update(out_img, gt_img, is_seq=True)
if metrics:
log_file = open(metric_file, 'a')
for metric_name, metric in metrics.items():
loss_string = "Metric {}: {:.4f}".format(
metric_name, metric.accumulate())
print(loss_string, file=log_file)
log_file.close()
if __name__ == '__main__':
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册