提交 36b62146 编写于 作者: D Dario Pavllo

Add support for trajectory in inference in the wild

上级 c4675a16
......@@ -209,6 +209,18 @@ if args.resume or args.evaluate:
model_pos_train.load_state_dict(checkpoint['model_pos'])
model_pos.load_state_dict(checkpoint['model_pos'])
if args.evaluate and 'model_traj' in checkpoint:
# Load trajectory model if it contained in the checkpoint (e.g. for inference in the wild)
model_traj = TemporalModel(poses_valid_2d[0].shape[-2], poses_valid_2d[0].shape[-1], 1,
filter_widths=filter_widths, causal=args.causal, dropout=args.dropout, channels=args.channels,
dense=args.dense)
if torch.cuda.is_available():
model_traj = model_traj.cuda()
model_traj.load_state_dict(checkpoint['model_traj'])
else:
model_traj = None
test_generator = UnchunkedGenerator(cameras_valid, poses_valid, poses_valid_2d,
pad=pad, causal_shift=causal_shift, augment=False,
kps_left=kps_left, kps_right=kps_right, joints_left=joints_left, joints_right=joints_right)
......@@ -637,13 +649,16 @@ if not args.evaluate:
plt.close('all')
# Evaluate
def evaluate(test_generator, action=None, return_predictions=False):
def evaluate(test_generator, action=None, return_predictions=False, use_trajectory_model=False):
epoch_loss_3d_pos = 0
epoch_loss_3d_pos_procrustes = 0
epoch_loss_3d_pos_scale = 0
epoch_loss_3d_vel = 0
with torch.no_grad():
model_pos.eval()
if not use_trajectory_model:
model_pos.eval()
else:
model_traj.eval()
N = 0
for _, batch, batch_2d in test_generator.next_epoch():
inputs_2d = torch.from_numpy(batch_2d.astype('float32'))
......@@ -651,13 +666,17 @@ def evaluate(test_generator, action=None, return_predictions=False):
inputs_2d = inputs_2d.cuda()
# Positional model
predicted_3d_pos = model_pos(inputs_2d)
if not use_trajectory_model:
predicted_3d_pos = model_pos(inputs_2d)
else:
predicted_3d_pos = model_traj(inputs_2d)
# Test-time augmentation (if enabled)
if test_generator.augment_enabled():
# Undo flipping and take average with non-flipped version
predicted_3d_pos[1, :, :, 0] *= -1
predicted_3d_pos[1, :, joints_left + joints_right] = predicted_3d_pos[1, :, joints_right + joints_left]
if not use_trajectory_model:
predicted_3d_pos[1, :, joints_left + joints_right] = predicted_3d_pos[1, :, joints_right + joints_left]
predicted_3d_pos = torch.mean(predicted_3d_pos, dim=0, keepdim=True)
if return_predictions:
......@@ -717,6 +736,9 @@ if args.render:
pad=pad, causal_shift=causal_shift, augment=args.test_time_augmentation,
kps_left=kps_left, kps_right=kps_right, joints_left=joints_left, joints_right=joints_right)
prediction = evaluate(gen, return_predictions=True)
if model_traj is not None and ground_truth is None:
prediction_traj = evaluate(gen, return_predictions=True, use_trajectory_model=True)
prediction += prediction_traj
if args.viz_export is not None:
print('Exporting joint positions to', args.viz_export)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册