提交 5e496435 编写于 作者: Q qingqing01

Clean code

上级 2ad718f7
...@@ -97,7 +97,7 @@ class PadTarget(object): ...@@ -97,7 +97,7 @@ class PadTarget(object):
return samples return samples
class MyBatchSampler(fluid.io.BatchSampler): class BatchSampler(fluid.io.BatchSampler):
def __init__(self, def __init__(self,
dataset, dataset,
batch_size, batch_size,
...@@ -178,8 +178,6 @@ class OCRDataset(paddle.io.Dataset): ...@@ -178,8 +178,6 @@ class OCRDataset(paddle.io.Dataset):
h, w = int(h), int(w) h, w = int(h), int(w)
labels = [int(c) for c in labels.split(',')] labels = [int(c) for c in labels.split(',')]
self._sample_infos.append(SampleInfo(i, h, w, im_name, labels)) self._sample_infos.append(SampleInfo(i, h, w, im_name, labels))
#self._sample_infos = sorted(self._sample_infos,
# key=lambda x: x.w)
def __getitem__(self, idx): def __getitem__(self, idx):
info = self._sample_infos[idx] info = self._sample_infos[idx]
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -76,7 +76,7 @@ def main(FLAGS): ...@@ -76,7 +76,7 @@ def main(FLAGS):
test_dataset = data.test() test_dataset = data.test()
test_collate_fn = BatchCompose( test_collate_fn = BatchCompose(
[data.Resize(), data.Normalize(), data.PadTarget()]) [data.Resize(), data.Normalize(), data.PadTarget()])
test_sampler = data.MyBatchSampler( test_sampler = data.BatchSampler(
test_dataset, test_dataset,
batch_size=FLAGS.batch_size, batch_size=FLAGS.batch_size,
drop_last=False, drop_last=False,
...@@ -125,7 +125,7 @@ def beam_search(FLAGS): ...@@ -125,7 +125,7 @@ def beam_search(FLAGS):
test_dataset = data.test() test_dataset = data.test()
test_collate_fn = BatchCompose( test_collate_fn = BatchCompose(
[data.Resize(), data.Normalize(), data.PadTarget()]) [data.Resize(), data.Normalize(), data.PadTarget()])
test_sampler = data.MyBatchSampler( test_sampler = data.BatchSampler(
test_dataset, test_dataset,
batch_size=FLAGS.batch_size, batch_size=FLAGS.batch_size,
drop_last=False, drop_last=False,
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -100,7 +100,7 @@ def main(FLAGS): ...@@ -100,7 +100,7 @@ def main(FLAGS):
train_dataset = data.train() train_dataset = data.train()
train_collate_fn = BatchCompose( train_collate_fn = BatchCompose(
[data.Resize(), data.Normalize(), data.PadTarget()]) [data.Resize(), data.Normalize(), data.PadTarget()])
train_sampler = data.MyBatchSampler( train_sampler = data.BatchSampler(
train_dataset, batch_size=FLAGS.batch_size, shuffle=True) train_dataset, batch_size=FLAGS.batch_size, shuffle=True)
train_loader = fluid.io.DataLoader( train_loader = fluid.io.DataLoader(
train_dataset, train_dataset,
...@@ -112,7 +112,7 @@ def main(FLAGS): ...@@ -112,7 +112,7 @@ def main(FLAGS):
test_dataset = data.test() test_dataset = data.test()
test_collate_fn = BatchCompose( test_collate_fn = BatchCompose(
[data.Resize(), data.Normalize(), data.PadTarget()]) [data.Resize(), data.Normalize(), data.PadTarget()])
test_sampler = data.MyBatchSampler( test_sampler = data.BatchSampler(
test_dataset, test_dataset,
batch_size=FLAGS.batch_size, batch_size=FLAGS.batch_size,
drop_last=False, drop_last=False,
......
"""Contains common utility functions.""" """Contains common utility functions."""
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); #Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. #you may not use this file except in compliance with the License.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册