未验证 提交 61d4939d 编写于 作者: W wangna11BD 提交者: GitHub

add EDVR predictor dynamic (#315)

上级 d81d9cc5
......@@ -119,10 +119,8 @@ if __name__ == "__main__":
weight_path=args.RealSR_weight)
frames_path, temp_video_path = predictor.run(temp_video_path)
elif order == 'EDVR':
paddle.enable_static()
predictor = EDVRPredictor(args.output, weight_path=args.EDVR_weight)
frames_path, temp_video_path = predictor.run(temp_video_path)
paddle.disable_static()
print('Model {} output frames path:'.format(order), frames_path)
print('Model {} output video path:'.format(order), temp_video_path)
......
......@@ -19,12 +19,15 @@ import glob
import numpy as np
from tqdm import tqdm
import paddle
from paddle.io import Dataset, DataLoader
from ppgan.utils.download import get_path_from_url
from ppgan.utils.video import frames2video, video2frames
from ppgan.models.generators import EDVRNet
from .base_predictor import BasePredictor
EDVR_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/edvr_infer_model.tar'
EDVR_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_SRx4.pdparams'
def get_img(pred):
......@@ -110,7 +113,7 @@ def get_test_neighbor_frames(crt_i, N, max_n, padding='new_info'):
return return_l
class EDVRDataset:
class EDVRDataset(Dataset):
def __init__(self, frame_paths):
self.frames = frame_paths
......@@ -133,16 +136,15 @@ class EDVRDataset:
class EDVRPredictor(BasePredictor):
def __init__(self, output='output', weight_path=None):
def __init__(self, output='output', weight_path=None, bs=1):
self.input = input
self.output = os.path.join(output, 'EDVR')
self.bs = bs
self.model = EDVRNet(nf=128, back_RBs=40)
if weight_path is None:
weight_path = get_path_from_url(EDVR_WEIGHT_URL)
self.weight_path = weight_path
self.build_inference_model()
self.model.set_dict(paddle.load(weight_path)['generator'])
self.model.eval()
def run(self, video_path):
vid = video_path
......@@ -163,23 +165,23 @@ class EDVRPredictor(BasePredictor):
frames = sorted(glob.glob(os.path.join(out_path, '*.png')))
dataset = EDVRDataset(frames)
test_dataset = EDVRDataset(frames)
dataset = DataLoader(test_dataset, batch_size=self.bs, num_workers=2)
periods = []
cur_time = time.time()
for infer_iter, data in enumerate(tqdm(dataset)):
data_feed_in = [data[0]]
outs = self.base_forward(np.array(data_feed_in))
infer_result_list = [item for item in outs]
data_feed_in = paddle.to_tensor(data[0])
with paddle.no_grad():
outs = self.model(data_feed_in).numpy()
infer_result_list = [outs[i, :, :, :] for i in range(self.bs)]
frame_path = data[1]
img_i = get_img(infer_result_list[0])
save_img(
img_i,
os.path.join(pred_frame_path, os.path.basename(frame_path)))
for i in range(self.bs):
img_i = get_img(infer_result_list[i])
save_img(
img_i,
os.path.join(pred_frame_path,
os.path.basename(frame_path[i])))
prev_time = cur_time
cur_time = time.time()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册