未验证 提交 069d8b09 编写于 作者: L lilong12 提交者: GitHub

update (#72)

上级 ce760c9f
......@@ -18,6 +18,7 @@ from __future__ import print_function
import errno
import json
import os
import math
import shutil
import subprocess
import sys
......@@ -142,7 +143,7 @@ class Entry(object):
self.log_period = 200
self.input_info = [{'name': 'image',
'shape': [-1, 3, 224, 224],
'shape': [-1, 3, 112, 112],
'dtype': 'float32'},
{'name': 'label',
'shape':[-1, 1],
......@@ -957,9 +958,8 @@ class Entry(object):
self.load_checkpoint(executor=exe, main_program=origin_prog)
if self.train_reader is None:
train_reader = paddle.batch(reader.arc_train(
self.dataset_dir, self.num_classes),
batch_size=self.train_batch_size)
train_reader = reader.arc_train(
self.dataset_dir, self.num_classes)
else:
train_reader = self.train_reader
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册