未验证 提交 083ea332 编写于 作者: J Jintao Lin 提交者: GitHub

Move `start_index` from `SampleFrames` to dataset level (#89)

上级 a4e100e1
......@@ -31,12 +31,7 @@ img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
train_pipeline = [
dict(type='DecordInit'),
dict(
type='SampleFrames',
clip_len=32,
frame_interval=2,
num_clips=1,
start_index=0),
dict(type='SampleFrames', clip_len=32, frame_interval=2, num_clips=1),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(
......@@ -59,7 +54,6 @@ val_pipeline = [
clip_len=32,
frame_interval=2,
num_clips=1,
start_index=0,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
......@@ -77,7 +71,6 @@ test_pipeline = [
clip_len=32,
frame_interval=2,
num_clips=10,
start_index=0,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
......
......@@ -30,7 +30,6 @@ test_pipeline = [
clip_len=32,
frame_interval=2,
num_clips=1,
start_index=0,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
......
......@@ -38,12 +38,7 @@ img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
train_pipeline = [
dict(type='DecordInit'),
dict(
type='SampleFrames',
clip_len=8,
frame_interval=8,
num_clips=1,
start_index=0),
dict(type='SampleFrames', clip_len=8, frame_interval=8, num_clips=1),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='RandomResizedCrop'),
......@@ -61,7 +56,6 @@ val_pipeline = [
clip_len=8,
frame_interval=8,
num_clips=1,
start_index=0,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
......@@ -79,7 +73,6 @@ test_pipeline = [
clip_len=8,
frame_interval=8,
num_clips=10,
start_index=0,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
......
......@@ -38,7 +38,6 @@ test_pipeline = [
clip_len=8,
frame_interval=8,
num_clips=10,
start_index=0,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
......
......@@ -45,12 +45,7 @@ img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
train_pipeline = [
dict(type='DecordInit'),
dict(
type='SampleFrames',
clip_len=32,
frame_interval=2,
num_clips=1,
start_index=0),
dict(type='SampleFrames', clip_len=32, frame_interval=2, num_clips=1),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='RandomResizedCrop'),
......@@ -68,7 +63,6 @@ val_pipeline = [
clip_len=32,
frame_interval=2,
num_clips=1,
start_index=0,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
......@@ -86,7 +80,6 @@ test_pipeline = [
clip_len=32,
frame_interval=2,
num_clips=10,
start_index=0,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
......
......@@ -48,7 +48,6 @@ test_pipeline = [
clip_len=32,
frame_interval=2,
num_clips=10,
start_index=0,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
......
......@@ -28,12 +28,7 @@ img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
train_pipeline = [
dict(type='DecordInit'),
dict(
type='SampleFrames',
clip_len=4,
frame_interval=16,
num_clips=1,
start_index=0),
dict(type='SampleFrames', clip_len=4, frame_interval=16, num_clips=1),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='RandomResizedCrop'),
......@@ -51,7 +46,6 @@ val_pipeline = [
clip_len=4,
frame_interval=16,
num_clips=1,
start_index=0,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
......@@ -69,7 +63,6 @@ test_pipeline = [
clip_len=4,
frame_interval=16,
num_clips=10,
start_index=0,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
......
......@@ -28,7 +28,6 @@ test_pipeline = [
clip_len=4,
frame_interval=16,
num_clips=10,
start_index=0,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
......
......@@ -30,12 +30,7 @@ img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
train_pipeline = [
dict(type='DecordInit'),
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=8,
start_index=0),
dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=8),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(
......@@ -59,7 +54,6 @@ val_pipeline = [
clip_len=1,
frame_interval=1,
num_clips=8,
start_index=0,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
......@@ -77,7 +71,6 @@ test_pipeline = [
clip_len=1,
frame_interval=1,
num_clips=8,
start_index=0,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
......
......@@ -30,7 +30,6 @@ test_pipeline = [
clip_len=1,
frame_interval=1,
num_clips=8,
start_index=0,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
......
......@@ -28,12 +28,7 @@ img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
train_pipeline = [
dict(type='DecordInit'),
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=8,
start_index=0),
dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=8),
dict(type='DecordDecode'),
dict(
type='MultiScaleCrop',
......@@ -55,7 +50,6 @@ val_pipeline = [
clip_len=1,
frame_interval=1,
num_clips=8,
start_index=0,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
......@@ -73,7 +67,6 @@ test_pipeline = [
clip_len=1,
frame_interval=1,
num_clips=25,
start_index=0,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
......
......@@ -28,12 +28,7 @@ img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
train_pipeline = [
dict(type='DecordInit'),
dict(
type='DenseSampleFrames',
clip_len=1,
frame_interval=1,
num_clips=8,
start_index=0),
dict(type='DenseSampleFrames', clip_len=1, frame_interval=1, num_clips=8),
dict(type='DecordDecode'),
dict(
type='MultiScaleCrop',
......@@ -55,7 +50,6 @@ val_pipeline = [
clip_len=1,
frame_interval=1,
num_clips=8,
start_index=0,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
......@@ -73,7 +67,6 @@ test_pipeline = [
clip_len=1,
frame_interval=1,
num_clips=8,
start_index=0,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
......
......@@ -27,7 +27,6 @@ test_pipeline = [
clip_len=1,
frame_interval=1,
num_clips=25,
start_index=0,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
......
......@@ -88,15 +88,22 @@ def inference_recognizer(model, video_path, label_path, use_frames=False):
if use_frames:
filename_tmpl = cfg.data.test.get('filename_tmpl', 'img_{:05}.jpg')
modality = cfg.data.test.get('modality', 'RGB')
start_index = cfg.data.test.get('start_index', 1)
data = dict(
frame_dir=video_path,
total_frames=len(os.listdir(video_path)),
# assuming files in ``video_path`` are all named with ``filename_tmpl`` # noqa: E501
label=-1,
start_index=start_index,
filename_tmpl=filename_tmpl,
modality=modality)
else:
data = dict(filename=video_path, label=-1, modality='RGB')
start_index = cfg.data.test.get('start_index', 0)
data = dict(
filename=video_path,
label=-1,
start_index=start_index,
modality='RGB')
data = test_pipeline(data)
data = collate([data], samples_per_gpu=1)
if next(model.parameters()).is_cuda:
......
......@@ -32,6 +32,10 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
dataset. Default: False.
num_classes (int): Number of classes of the dataset, used in
multi-class datasets. Default: None.
start_index (int): Specify a start index for frames in consideration of
different filename format. However, when taking videos as input,
it should be set to 0, since frames loaded from videos count
from 0. Default: 1.
modality (str): Modality of data. Support 'RGB', 'Flow'.
Default: 'RGB'.
"""
......@@ -43,6 +47,7 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
test_mode=False,
multi_class=False,
num_classes=None,
start_index=1,
modality='RGB'):
super().__init__()
......@@ -52,6 +57,7 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
self.test_mode = test_mode
self.multi_class = multi_class
self.num_classes = num_classes
self.start_index = start_index
self.modality = modality
self.pipeline = Compose(pipeline)
self.video_infos = self.load_annotations()
......@@ -83,12 +89,14 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
"""Prepare the frames for training given the index."""
results = copy.deepcopy(self.video_infos[idx])
results['modality'] = self.modality
results['start_index'] = self.start_index
return self.pipeline(results)
def prepare_test_frames(self, idx):
"""Prepare the frames for testing given the index."""
results = copy.deepcopy(self.video_infos[idx])
results['modality'] = self.modality
results['start_index'] = self.start_index
return self.pipeline(results)
def __len__(self):
......
......@@ -17,18 +17,14 @@ from ..registry import PIPELINES
class SampleFrames(object):
"""Sample frames from the video.
Required keys are "filename", "total_frames", added or modified keys are
"frame_inds", "frame_interval" and "num_clips".
Required keys are "filename", "total_frames", "start_index" , added or
modified keys are "frame_inds", "frame_interval" and "num_clips".
Args:
clip_len (int): Frames of each sampled output clip.
frame_interval (int): Temporal interval of adjacent sampled frames.
Default: 1.
num_clips (int): Number of clips to be sampled. Default: 1.
start_index (int): Specify a start index for frames in consideration of
different filename format. However, when taking videos as input,
it should be set to 0, since frames loaded from videos count
from 0. Default: 1.
temporal_jitter (bool): Whether to apply temporal jittering.
Default: False.
twice_sample (bool): Whether to use twice sample when testing.
......@@ -39,28 +35,35 @@ class SampleFrames(object):
Default: 'loop'.
test_mode (bool): Store True when building test or validation dataset.
Default: False.
start_index (None): This argument is deprecated and moved to dataset
class (``BaseDataset``, ``VideoDatset``, ``RawframeDataset``, etc),
see this: https://github.com/open-mmlab/mmaction2/pull/89.
"""
def __init__(self,
clip_len,
frame_interval=1,
num_clips=1,
start_index=1,
temporal_jitter=False,
twice_sample=False,
out_of_bound_opt='loop',
test_mode=False):
test_mode=False,
start_index=None):
self.clip_len = clip_len
self.frame_interval = frame_interval
self.num_clips = num_clips
self.start_index = start_index
self.temporal_jitter = temporal_jitter
self.twice_sample = twice_sample
self.out_of_bound_opt = out_of_bound_opt
self.test_mode = test_mode
assert self.out_of_bound_opt in ['loop', 'repeat_last']
if start_index is not None:
warnings.warn('No longer support "start_index" in "SampleFrames", '
'it should be set in dataset class, see this pr: '
'https://github.com/open-mmlab/mmaction2/pull/89')
def _get_train_clips(self, num_frames):
"""Get clip offsets in train mode.
......@@ -165,7 +168,9 @@ class SampleFrames(object):
frame_inds = new_inds
else:
raise ValueError('Illegal out_of_bound option.')
frame_inds = np.concatenate(frame_inds) + self.start_index
start_index = results['start_index']
frame_inds = np.concatenate(frame_inds) + start_index
results['frame_inds'] = frame_inds.astype(np.int)
results['clip_len'] = self.clip_len
results['frame_interval'] = self.frame_interval
......@@ -185,8 +190,6 @@ class DenseSampleFrames(SampleFrames):
frame_interval (int): Temporal interval of adjacent sampled frames.
Default: 1.
num_clips (int): Number of clips to be sampled. Default: 1.
start_index (int): Specify a start index for frames in consideration of
different filename format. Default: 1.
sample_range (int): Total sample range for dense sample.
Default: 64.
num_sample_positions (int): Number of sample start positions, Which is
......@@ -201,7 +204,6 @@ class DenseSampleFrames(SampleFrames):
clip_len,
frame_interval=1,
num_clips=1,
start_index=1,
sample_range=64,
num_sample_positions=10,
temporal_jitter=False,
......@@ -211,7 +213,6 @@ class DenseSampleFrames(SampleFrames):
clip_len,
frame_interval,
num_clips,
start_index,
temporal_jitter,
out_of_bound_opt=out_of_bound_opt,
test_mode=test_mode)
......@@ -285,10 +286,6 @@ class SampleProposalFrames(SampleFrames):
Default: 1.
test_interval (int): Temporal interval of adjacent sampled frames
in test mode. Default: 6.
start_index (int): Specify a start index for frames in consideration of
different filename format. However, when taking videos as input,
it should be set to 0, since frames loaded from videos count
from 0. Default: 1.
temporal_jitter (bool): Whether to apply temporal jittering.
Default: False.
mode (str): Choose 'train', 'val' or 'test' mode.
......@@ -302,13 +299,11 @@ class SampleProposalFrames(SampleFrames):
aug_ratio,
frame_interval=1,
test_interval=6,
start_index=1,
temporal_jitter=False,
mode='train'):
super().__init__(
clip_len,
frame_interval=frame_interval,
start_index=start_index,
temporal_jitter=temporal_jitter)
self.body_segments = body_segments
self.aug_segments = aug_segments
......@@ -500,7 +495,8 @@ class SampleProposalFrames(SampleFrames):
self.frame_interval, size=len(frame_inds))
frame_inds += perframe_offsets
frame_inds = np.mod(frame_inds, total_frames) + self.start_index
start_index = results['start_index']
frame_inds = np.mod(frame_inds, total_frames) + start_index
results['frame_inds'] = np.array(frame_inds).astype(np.int)
results['clip_len'] = self.clip_len
......
......@@ -85,11 +85,12 @@ class RawframeDataset(BaseDataset):
with_offset=False,
multi_class=False,
num_classes=None,
start_index=1,
modality='RGB'):
self.filename_tmpl = filename_tmpl
self.with_offset = with_offset
super().__init__(ann_file, pipeline, data_prefix, test_mode,
multi_class, num_classes, modality)
multi_class, num_classes, start_index, modality)
def load_annotations(self):
"""Load annotation file to get video information."""
......@@ -134,6 +135,7 @@ class RawframeDataset(BaseDataset):
results = copy.deepcopy(self.video_infos[idx])
results['filename_tmpl'] = self.filename_tmpl
results['modality'] = self.modality
results['start_index'] = self.start_index
return self.pipeline(results)
def prepare_test_frames(self, idx):
......@@ -141,6 +143,7 @@ class RawframeDataset(BaseDataset):
results = copy.deepcopy(self.video_infos[idx])
results['filename_tmpl'] = self.filename_tmpl
results['modality'] = self.modality
results['start_index'] = self.start_index
return self.pipeline(results)
def evaluate(self,
......
......@@ -27,8 +27,21 @@ class VideoDataset(BaseDataset):
some/path/003.mp4 2
some/path/004.mp4 3
some/path/005.mp4 3
Args:
ann_file (str): Path to the annotation file.
pipeline (list[dict | callable]): A sequence of data transforms.
start_index (int): Specify a start index for frames in consideration of
different filename format. However, when taking videos as input,
it should be set to 0, since frames loaded from videos count
from 0. Default: 0.
**kwargs: Keyword arguments for ``BaseDataset``.
"""
def __init__(self, ann_file, pipeline, start_index=0, **kwargs):
super().__init__(ann_file, pipeline, start_index=start_index, **kwargs)
def load_annotations(self):
"""Load annotation file to get video information."""
video_infos = []
......
......@@ -59,6 +59,7 @@ class TestDataset(object):
assert rawframe_infos == [
dict(frame_dir=frame_dir, total_frames=5, label=127)
] * 2
assert rawframe_dataset.start_index == 1
def test_rawframe_dataset_with_offset(self):
rawframe_dataset = RawframeDataset(
......@@ -71,6 +72,7 @@ class TestDataset(object):
assert rawframe_infos == [
dict(frame_dir=frame_dir, offset=2, total_frames=5, label=127)
] * 2
assert rawframe_dataset.start_index == 1
def test_rawframe_dataset_multi_label(self):
rawframe_dataset = RawframeDataset(
......@@ -90,6 +92,7 @@ class TestDataset(object):
assert info['frame_dir'] == frame_dir
assert info['total_frames'] == 5
assert torch.all(info['label'] == label)
assert rawframe_dataset.start_index == 1
def test_dataset_realpath(self):
dataset = RawframeDataset(self.frame_ann_file, self.frame_pipeline,
......@@ -100,14 +103,20 @@ class TestDataset(object):
assert dataset.data_prefix == 's3://good'
def test_video_dataset(self):
video_dataset = VideoDataset(self.video_ann_file, self.video_pipeline,
self.data_prefix)
video_dataset = VideoDataset(
self.video_ann_file,
self.video_pipeline,
data_prefix=self.data_prefix)
video_infos = video_dataset.video_infos
video_filename = osp.join(self.data_prefix, 'test.mp4')
assert video_infos == [dict(filename=video_filename, label=0)] * 2
assert video_dataset.start_index == 0
def test_rawframe_pipeline(self):
target_keys = ['frame_dir', 'total_frames', 'label', 'filename_tmpl']
target_keys = [
'frame_dir', 'total_frames', 'label', 'filename_tmpl',
'start_index', 'modality'
]
# RawframeDataset not in test mode
rawframe_dataset = RawframeDataset(
......@@ -129,6 +138,17 @@ class TestDataset(object):
result = rawframe_dataset[0]
assert self.check_keys_contain(result.keys(), target_keys)
# RawframeDataset with offset
rawframe_dataset = RawframeDataset(
self.frame_ann_file_with_offset,
self.frame_pipeline,
self.data_prefix,
with_offset=True,
num_classes=400,
test_mode=False)
result = rawframe_dataset[0]
assert self.check_keys_contain(result.keys(), target_keys + ['offset'])
# RawframeDataset in test mode
rawframe_dataset = RawframeDataset(
self.frame_ann_file,
......@@ -149,14 +169,25 @@ class TestDataset(object):
result = rawframe_dataset[0]
assert self.check_keys_contain(result.keys(), target_keys)
# RawframeDataset with offset
rawframe_dataset = RawframeDataset(
self.frame_ann_file_with_offset,
self.frame_pipeline,
self.data_prefix,
with_offset=True,
num_classes=400,
test_mode=True)
result = rawframe_dataset[0]
assert self.check_keys_contain(result.keys(), target_keys + ['offset'])
def test_video_pipeline(self):
target_keys = ['filename', 'label']
target_keys = ['filename', 'label', 'start_index', 'modality']
# VideoDataset not in test mode
video_dataset = VideoDataset(
self.video_ann_file,
self.video_pipeline,
self.data_prefix,
data_prefix=self.data_prefix,
test_mode=False)
result = video_dataset[0]
assert self.check_keys_contain(result.keys(), target_keys)
......@@ -165,7 +196,7 @@ class TestDataset(object):
video_dataset = VideoDataset(
self.video_ann_file,
self.video_pipeline,
self.data_prefix,
data_prefix=self.data_prefix,
test_mode=True)
result = video_dataset[0]
assert self.check_keys_contain(result.keys(), target_keys)
......@@ -221,8 +252,10 @@ class TestDataset(object):
['top1_acc', 'top5_acc', 'mean_class_accuracy'])
def test_video_evaluate(self):
video_dataset = VideoDataset(self.video_ann_file, self.video_pipeline,
self.data_prefix)
video_dataset = VideoDataset(
self.video_ann_file,
self.video_pipeline,
data_prefix=self.data_prefix)
with pytest.raises(TypeError):
# results must be a list
......@@ -248,10 +281,14 @@ class TestDataset(object):
['top1_acc', 'top5_acc', 'mean_class_accuracy'])
def test_base_dataset(self):
video_dataset = VideoDataset(self.video_ann_file, self.video_pipeline,
self.data_prefix)
video_dataset = VideoDataset(
self.video_ann_file,
self.video_pipeline,
data_prefix=self.data_prefix,
start_index=3)
assert len(video_dataset) == 2
assert type(video_dataset[0]) == dict
assert video_dataset.start_index == 3
def test_repeat_dataset(self):
rawframe_dataset = RawframeDataset(self.frame_ann_file,
......
......@@ -60,11 +60,15 @@ class TestLoading(object):
cls.flow_filename_tmpl = '{}_{:05d}.jpg'
video_total_frames = len(mmcv.VideoReader(cls.video_path))
cls.video_results = dict(
filename=cls.video_path, label=1, total_frames=video_total_frames)
filename=cls.video_path,
label=1,
total_frames=video_total_frames,
start_index=0)
cls.frame_results = dict(
frame_dir=cls.img_dir,
total_frames=cls.total_frames,
filename_tmpl=cls.filename_tmpl,
start_index=1,
modality='RGB',
offset=0,
label=1)
......@@ -92,6 +96,7 @@ class TestLoading(object):
video_id='test_imgs',
total_frames=cls.total_frames,
filename_tmpl=cls.filename_tmpl,
start_index=1,
out_props=[[['test_imgs',
ExampleSSNInstance(1, 4, 10, 1, 1, 1)], 0],
[['test_imgs',
......@@ -103,6 +108,12 @@ class TestLoading(object):
'total_frames'
]
with pytest.warns(UserWarning):
# start_index has been deprecated
config = dict(
clip_len=3, frame_interval=1, num_clips=5, start_index=1)
SampleFrames(**config)
# Sample Frame with no temporal_jitter
# clip_len=3, frame_interval=1, num_clips=5
video_result = copy.deepcopy(self.video_results)
......@@ -116,6 +127,8 @@ class TestLoading(object):
assert len(sample_frames_results['frame_inds']) == 15
sample_frames_results = sample_frames(frame_result)
assert len(sample_frames_results['frame_inds']) == 15
assert np.max(sample_frames_results['frame_inds']) <= 5
assert np.min(sample_frames_results['frame_inds']) >= 1
# Sample Frame with no temporal_jitter
# clip_len=5, frame_interval=1, num_clips=5,
......@@ -150,6 +163,8 @@ class TestLoading(object):
frame_inds = sample_frames_results['frame_inds'].reshape([5, 5])
for i in range(5):
assert check_monotonous(frame_inds[i])
assert np.max(sample_frames_results['frame_inds']) <= 5
assert np.min(sample_frames_results['frame_inds']) >= 1
# Sample Frame with temporal_jitter
# clip_len=4, frame_interval=2, num_clips=5
......@@ -164,6 +179,8 @@ class TestLoading(object):
assert len(sample_frames_results['frame_inds']) == 20
sample_frames_results = sample_frames(frame_result)
assert len(sample_frames_results['frame_inds']) == 20
assert np.max(sample_frames_results['frame_inds']) <= 5
assert np.min(sample_frames_results['frame_inds']) >= 1
# Sample Frame with no temporal_jitter in test mode
# clip_len=4, frame_interval=1, num_clips=6
......@@ -182,6 +199,8 @@ class TestLoading(object):
assert len(sample_frames_results['frame_inds']) == 24
sample_frames_results = sample_frames(frame_result)
assert len(sample_frames_results['frame_inds']) == 24
assert np.max(sample_frames_results['frame_inds']) <= 5
assert np.min(sample_frames_results['frame_inds']) >= 1
# Sample Frame with no temporal_jitter in test mode
# clip_len=3, frame_interval=1, num_clips=6
......@@ -200,6 +219,8 @@ class TestLoading(object):
assert len(sample_frames_results['frame_inds']) == 18
sample_frames_results = sample_frames(frame_result)
assert len(sample_frames_results['frame_inds']) == 18
assert np.max(sample_frames_results['frame_inds']) <= 5
assert np.min(sample_frames_results['frame_inds']) >= 1
# Sample Frame with no temporal_jitter to get clip_offsets
# clip_len=1, frame_interval=1, num_clips=8
......@@ -223,7 +244,7 @@ class TestLoading(object):
np.array([1, 2, 2, 3, 4, 5, 5, 6]))
# Sample Frame with no temporal_jitter to get clip_offsets
# clip_len=1, frame_interval=1, num_clips=8, start_index=0
# clip_len=1, frame_interval=1, num_clips=8
video_result = copy.deepcopy(self.video_results)
frame_result = copy.deepcopy(self.frame_results)
frame_result['total_frames'] = 6
......@@ -231,18 +252,18 @@ class TestLoading(object):
clip_len=1,
frame_interval=1,
num_clips=8,
start_index=0,
temporal_jitter=False,
test_mode=True)
sample_frames = SampleFrames(**config)
sample_frames_results = sample_frames(video_result)
assert sample_frames_results['start_index'] == 0
assert self.check_keys_contain(sample_frames_results.keys(),
target_keys)
assert len(sample_frames_results['frame_inds']) == 8
sample_frames_results = sample_frames(frame_result)
assert len(sample_frames_results['frame_inds']) == 8
assert_array_equal(sample_frames_results['frame_inds'],
np.array([0, 1, 1, 2, 3, 4, 4, 5]))
np.array([1, 2, 2, 3, 4, 5, 5, 6]))
# Sample Frame with no temporal_jitter to get clip_offsets zero
# clip_len=6, frame_interval=1, num_clips=1
......@@ -257,6 +278,7 @@ class TestLoading(object):
test_mode=True)
sample_frames = SampleFrames(**config)
sample_frames_results = sample_frames(video_result)
assert sample_frames_results['start_index'] == 0
assert self.check_keys_contain(sample_frames_results.keys(),
target_keys)
assert len(sample_frames_results['frame_inds']) == 6
......@@ -278,11 +300,14 @@ class TestLoading(object):
test_mode=False)
sample_frames = SampleFrames(**config)
sample_frames_results = sample_frames(video_result)
assert sample_frames_results['start_index'] == 0
assert self.check_keys_contain(sample_frames_results.keys(),
target_keys)
assert len(sample_frames_results['frame_inds']) == 240
sample_frames_results = sample_frames(frame_result)
assert len(sample_frames_results['frame_inds']) == 240
assert np.max(sample_frames_results['frame_inds']) <= 30
assert np.min(sample_frames_results['frame_inds']) >= 1
# Sample Frame with no temporal_jitter to get clip_offsets
# clip_len=1, frame_interval=1, num_clips=8
......@@ -299,6 +324,7 @@ class TestLoading(object):
sample_frames_results = sample_frames(video_result)
assert self.check_keys_contain(sample_frames_results.keys(),
target_keys)
assert sample_frames_results['start_index'] == 0
assert len(sample_frames_results['frame_inds']) == 8
sample_frames_results = sample_frames(frame_result)
assert len(sample_frames_results['frame_inds']) == 8
......@@ -318,11 +344,14 @@ class TestLoading(object):
test_mode=False)
sample_frames = SampleFrames(**config)
sample_frames_results = sample_frames(video_result)
assert sample_frames_results['start_index'] == 0
assert self.check_keys_contain(sample_frames_results.keys(),
target_keys)
assert len(sample_frames_results['frame_inds']) == 24
sample_frames_results = sample_frames(frame_result)
assert len(sample_frames_results['frame_inds']) == 24
assert np.max(sample_frames_results['frame_inds']) <= 10
assert np.min(sample_frames_results['frame_inds']) >= 1
# Sample Frame using twice sample
# clip_len=12, frame_interval=1, num_clips=2
......@@ -338,11 +367,14 @@ class TestLoading(object):
test_mode=True)
sample_frames = SampleFrames(**config)
sample_frames_results = sample_frames(video_result)
assert sample_frames_results['start_index'] == 0
assert self.check_keys_contain(sample_frames_results.keys(),
target_keys)
assert len(sample_frames_results['frame_inds']) == 48
sample_frames_results = sample_frames(frame_result)
assert len(sample_frames_results['frame_inds']) == 48
assert np.max(sample_frames_results['frame_inds']) <= 40
assert np.min(sample_frames_results['frame_inds']) >= 1
def test_dense_sample_frames(self):
target_keys = [
......@@ -362,6 +394,7 @@ class TestLoading(object):
test_mode=True)
dense_sample_frames = DenseSampleFrames(**config)
dense_sample_frames_results = dense_sample_frames(video_result)
assert dense_sample_frames_results['start_index'] == 0
assert self.check_keys_contain(dense_sample_frames_results.keys(),
target_keys)
assert len(dense_sample_frames_results['frame_inds']) == 240
......@@ -376,6 +409,7 @@ class TestLoading(object):
clip_len=4, frame_interval=1, num_clips=6, temporal_jitter=False)
dense_sample_frames = DenseSampleFrames(**config)
dense_sample_frames_results = dense_sample_frames(video_result)
assert dense_sample_frames_results['start_index'] == 0
assert self.check_keys_contain(dense_sample_frames_results.keys(),
target_keys)
assert len(dense_sample_frames_results['frame_inds']) == 24
......@@ -395,6 +429,7 @@ class TestLoading(object):
test_mode=True)
dense_sample_frames = DenseSampleFrames(**config)
dense_sample_frames_results = dense_sample_frames(video_result)
assert dense_sample_frames_results['start_index'] == 0
assert self.check_keys_contain(dense_sample_frames_results.keys(),
target_keys)
assert len(dense_sample_frames_results['frame_inds']) == 240
......@@ -413,6 +448,7 @@ class TestLoading(object):
temporal_jitter=False)
dense_sample_frames = DenseSampleFrames(**config)
dense_sample_frames_results = dense_sample_frames(video_result)
assert dense_sample_frames_results['start_index'] == 0
assert self.check_keys_contain(dense_sample_frames_results.keys(),
target_keys)
assert len(dense_sample_frames_results['frame_inds']) == 24
......@@ -431,6 +467,7 @@ class TestLoading(object):
temporal_jitter=False)
dense_sample_frames = DenseSampleFrames(**config)
dense_sample_frames_results = dense_sample_frames(video_result)
assert dense_sample_frames_results['start_index'] == 0
assert self.check_keys_contain(dense_sample_frames_results.keys(),
target_keys)
assert len(dense_sample_frames_results['frame_inds']) == 24
......@@ -452,6 +489,7 @@ class TestLoading(object):
test_mode=True)
dense_sample_frames = DenseSampleFrames(**config)
dense_sample_frames_results = dense_sample_frames(video_result)
assert dense_sample_frames_results['start_index'] == 0
assert self.check_keys_contain(dense_sample_frames_results.keys(),
target_keys)
assert len(dense_sample_frames_results['frame_inds']) == 120
......@@ -461,7 +499,7 @@ class TestLoading(object):
def test_sample_proposal_frames(self):
target_keys = [
'frame_inds', 'clip_len', 'frame_interval', 'num_clips',
'total_frames'
'total_frames', 'start_index'
]
# test error cases
......@@ -475,7 +513,7 @@ class TestLoading(object):
aug_ratio=0.5,
temporal_jitter=False)
sample_frames = SampleProposalFrames(**config)
sample_frames_results = sample_frames(proposal_result)
sample_frames(proposal_result)
# test normal cases
# Sample Frame with no temporal_jitter
......@@ -839,7 +877,7 @@ class TestLoading(object):
def test_rawframe_decode(self):
target_keys = ['frame_inds', 'imgs', 'original_shape', 'modality']
# test frame selector with 2 dim input when start_index = 0
# test frame selector with 2 dim input
inputs = copy.deepcopy(self.frame_results)
inputs['frame_inds'] = np.arange(0, self.total_frames, 2)[:,
np.newaxis]
......@@ -887,7 +925,7 @@ class TestLoading(object):
320, 3)
assert results['original_shape'] == (240, 320)
# test frame selector with 1 dim input when start_index = 0
# test frame selector with 1 dim input
inputs = copy.deepcopy(self.frame_results)
inputs['frame_inds'] = np.arange(0, self.total_frames, 2)
# since the test images start with index 1, we plus 1 to frame_inds
......@@ -911,7 +949,6 @@ class TestLoading(object):
assert results['original_shape'] == (240, 320)
# test frame selector with 1 dim input for flow images
# when start_index = 0
inputs = copy.deepcopy(self.flow_frame_results)
inputs['frame_inds'] = np.arange(0, self.total_frames, 2)
# since the test images start with index 1, we plus 1 to frame_inds
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册