提交 d0410d61 编写于 作者: X xiefangqi

md delete set_dataset_size interface

上级 c45f79d3
......@@ -3052,14 +3052,6 @@ class MindDataset(MappableDataset):
self.dataset_size = num_rows
return self.dataset_size
# manually set dataset_size as a tempoary solution.
def set_dataset_size(self, value):
logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
if value >= 0:
self.dataset_size = value
else:
raise ValueError('Set dataset_size with negative value {}'.format(value))
def is_shuffled(self):
if self.shuffle_option is None:
return True
......@@ -3503,13 +3495,6 @@ class GeneratorDataset(MappableDataset):
self.dataset_size = num_rows
return self.dataset_size
# manually set dataset_size as a temporary solution.
def set_dataset_size(self, value):
if value >= 0:
self.dataset_size = value
else:
raise ValueError('Set dataset_size with negative value {}'.format(value))
def __deepcopy__(self, memodict):
if id(self) in memodict:
return memodict[id(self)]
......@@ -3696,14 +3681,6 @@ class TFRecordDataset(SourceDataset):
self.dataset_size = self.num_samples
return self.dataset_size
# manually set dataset_size as a tempoary solution.
def set_dataset_size(self, value):
logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
if value >= 0:
self.dataset_size = value
else:
raise ValueError('Set dataset_size with negative value {}'.format(value))
def is_shuffled(self):
return self.shuffle_files
......
......@@ -141,7 +141,6 @@ def classification_dataset(data_dir, image_size, per_batch_size, max_epoch, rank
dataset = TxtDataset(root, data_dir)
sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle)
de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler)
de_dataset.set_dataset_size(len(sampler))
de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=num_parallel_workers,
operations=transform_img)
......
......@@ -156,7 +156,6 @@ def classification_dataset(data_dir, image_size, per_batch_size, rank=0, group_s
dataset = TxtDataset(root, data_dir)
sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle)
de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler)
de_dataset.set_dataset_size(len(sampler))
de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=8, operations=transform_img)
de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=8, operations=transform_label)
......
......@@ -81,7 +81,6 @@ def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_
dataset = _CaptchaDataset(dataset_path, cf.max_captcha_digits, device_target)
ds = de.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=num_shards, shard_id=shard_id)
ds.set_dataset_size(m.ceil(len(dataset) / num_shards))
image_trans = [
vc.Rescale(1.0 / 255.0, 0.0),
vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]),
......
......@@ -173,7 +173,6 @@ def _get_h5_dataset(directory, train_mode=True, epochs=1, batch_size=1000):
yield train_eval_gen.__next__()
ds = de.GeneratorDataset(_iter_h5_data, ["ids", "weights", "labels"])
ds.set_dataset_size(numbers_of_batch)
ds = ds.repeat(epochs)
return ds
......
......@@ -165,7 +165,6 @@ def _get_h5_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000):
yield train_eval_gen.__next__()
ds = de.GeneratorDataset(_iter_h5_data(), ["ids", "weights", "labels"])
ds.set_dataset_size(numbers_of_batch)
ds = ds.repeat(epochs)
return ds
......
......@@ -161,7 +161,6 @@ def _get_h5_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000):
ds = de.GeneratorDataset(_iter_h5_data(),
["ids", "weights", "labels"])
ds.set_dataset_size(numbers_of_batch)
ds = ds.repeat(epochs)
return ds
......
......@@ -23,8 +23,7 @@ from mindspore import log as logger
from .config import bert_net_cfg
def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", enable_data_sink="true",
data_sink_steps=1, data_dir=None, schema_dir=None):
def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", data_dir=None, schema_dir=None):
"""create train dataset"""
# apply repeat operations
repeat_count = epoch_size
......@@ -40,10 +39,6 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
shard_equal_rows=True)
ori_dataset_size = ds.get_dataset_size()
print('origin dataset size: ', ori_dataset_size)
new_size = ori_dataset_size
if enable_data_sink == "true":
new_size = data_sink_steps * bert_net_cfg.batch_size
ds.set_dataset_size(new_size)
new_repeat_count = int(repeat_count * ori_dataset_size // ds.get_dataset_size())
type_cast_op = C.TypeCast(mstype.int32)
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
......
......@@ -94,11 +94,9 @@ def create_dataset(args, data_url, epoch_num=1, batch_size=1, usage="train", shu
"""
# create iter dataset
dataset = HwVocRawDataset(data_url, usage=usage)
dataset_len = len(dataset)
# wrapped with GeneratorDataset
dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=None)
dataset.set_dataset_size(dataset_len)
dataset = dataset.map(input_columns=["image", "label"], operations=DataTransform(args, usage=usage))
channelswap_op = C.HWC2CHW()
......
......@@ -262,9 +262,6 @@ def test_concat_12():
data1 = ds.GeneratorDataset(generator, ["col1"])
data2 = ds.GeneratorDataset(generator_10, ["col1"])
data1.set_dataset_size(3)
data2.set_dataset_size(7)
data3 = data1 + data2
res = [8, 6, 2, 5, 0, 4, 9, 3, 7, 1]
......@@ -288,9 +285,6 @@ def test_concat_13():
data1 = ds.GeneratorDataset(generator, ["col1"])
data2 = ds.GeneratorDataset(generator_20, ["col1"])
data1.set_dataset_size(3)
data2.set_dataset_size(10)
data1 = data1.batch(3)
data2 = data2.batch(5)
......
......@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import mindspore.dataset as ds
from mindspore import log as logger
......@@ -161,18 +159,6 @@ def test_imagefolder():
assert data.num_classes() == 4
def test_generator():
def generator():
for i in range(64):
yield (np.array([i]),)
data1 = ds.GeneratorDataset(generator, ["data"])
data1.set_dataset_size(10)
assert data1.get_dataset_size() == 10
data1.output_shapes()
assert data1.get_dataset_size() == 10
if __name__ == '__main__':
# test_compare_v1_and_2()
# test_imagefolder()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册