Unverified commit d84a30b2, authored by Chen Weihang, committed by GitHub

[cherry-pick] append scale to static runner and remove loader place (#24649)

* Append scale for static runner outputs (#24627)

* add scale for static runner outputs, test=develop

* fix import relation, test=develop

* remove len limit, test=develop

* remove imperative data loader place limit, test=develop (#24641)
Parent 62047d30
@@ -24,6 +24,7 @@ from .. import core
 from .. import framework
 from .. import backward
 from ..layers import nn
+from .base import switch_to_static_graph
 from ... import compat as cpt
@@ -359,8 +360,27 @@ class StaticModelRunner(layers.Layer):
         # NOTE: reverse feed vars
         self._input_names.reverse()
 
+        # Step 4. add scale for outputs
+        tmp_program = self._build_program_by_desc(program_desc)
+        self._append_scale_to_output(tmp_program)
+
         return program_desc
 
+    @switch_to_static_graph
+    def _append_scale_to_output(self, program):
+        # 1. append scale & save var
+        scale_output_vars = []
+        with framework.program_guard(program):
+            for i, out in enumerate(self._output_descs):
+                var = program.global_block().var(out.name())
+                var = nn.scale(
+                    var, 1., name="static_model_runner/scale_{}".format(i))
+                scale_output_vars.append(var)
+        # 2. update output names & descs
+        for i, var in enumerate(scale_output_vars):
+            self._output_names[i] = var.name
+            self._output_descs[i] = var.desc
+
     @switch_to_static_graph
     def _append_backward_desc(self):
         assert self._infer_program_desc is not None, "The StaticModelRunner not initialized properly."
......
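The hunks above implement the "Append scale for static runner outputs" part of this commit. The point of the appended scale(x, 1.0) is not numerical: building it inside program_guard appends a fresh op and returns a brand-new output variable inside the program, so the runner ends up with output variables whose names and descs it controls and can rebind. A minimal standalone sketch of that behavior, assuming the 1.x paddle.fluid API of this branch (the program, the 'x' input and the 'demo/scale_0' name are illustrative, not part of the patch):

# Sketch: appending an identity scale op inside a static Program.
import paddle.fluid as fluid

prog = fluid.Program()
with fluid.program_guard(prog):
    x = fluid.data(name='x', shape=[None, 784], dtype='float32')
    # scale=1.0 leaves values untouched but returns a new variable
    out = fluid.layers.scale(x, scale=1.0, name='demo/scale_0')

# A 'scale' op has been appended to the global block, and the returned
# variable carries a name derived from the given prefix
# (e.g. 'demo/scale_0.tmp_0'), which the caller can record and reuse.
print([op.type for op in prog.global_block().ops])
print(out.name)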
@@ -18,7 +18,7 @@ import six
 import numpy as np
 import threading
 import paddle
-from .framework import Program, Variable, program_guard, default_main_program, default_startup_program, in_dygraph_mode, cpu_places
+from .framework import Program, Variable, program_guard, default_main_program, default_startup_program, in_dygraph_mode, cpu_places, _current_expected_place
 from .executor import global_scope
 from .data_feeder import DataFeeder, BatchedTensorProvider
 from .multiprocess_utils import multiprocess_queue_set, CleanupFuncRegistrar, _cleanup_mmap, _cleanup, _set_SIGCHLD_handler
@@ -671,12 +671,12 @@ class DygraphGeneratorLoader(DataLoaderBase):
         if not iterable:
             logging.warning(
-                "Please NOTE: dygraph can support iterable mode only. Change to iterable mode."
+                "Please NOTE: imperative mode can support iterable mode only. Change to iterable mode."
             )
         self._iterable = True
         if not return_list:
             logging.warning(
-                "Please NOTE: dygraph can support return as list only. Change to return as list."
+                "Please NOTE: imperative mode can support return as list only. Change to return as list."
             )
         self._return_list = True
@@ -941,10 +941,11 @@ class DygraphGeneratorLoader(DataLoaderBase):
     def set_batch_generator(self, reader, places=None):
         self._batch_reader = reader
-        assert places is not None, "Places cannot be None when DataLoader is iterable"
+        if places is None:
+            places = _current_expected_place()
         self._places = _convert_places(places)
         assert len(self._places) == 1, \
-            "Number of places must be 1 in dygraph mode"
+            "Number of places must be 1 in imperative mode"
         return self
......
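The hunks above implement the "remove imperative data loader place limit" part of this commit: when places is None, the imperative-mode loader now falls back to _current_expected_place() instead of asserting. A minimal usage sketch under that assumption, mirroring the new test added below (the generator and sizes are illustrative, assuming the 1.x paddle.fluid API of this branch):

# Sketch: a dygraph DataLoader with no explicit places after this change.
import numpy as np
import paddle.fluid as fluid

def sample_generator():
    # yields (image, label) samples, like sample_generator_creator in the test
    for _ in range(16):
        yield np.random.random([784]).astype('float32'), \
              np.array([1]).astype('int64')

with fluid.dygraph.guard():
    loader = fluid.io.DataLoader.from_generator(capacity=5)
    # no places argument: the loader picks up the current expected place
    loader.set_sample_generator(sample_generator, batch_size=4)
    for image, label in loader():
        print(image.shape, label.shape)  # [4, 784] [4, 1]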
@@ -41,6 +41,14 @@ class TestDygraphDataLoader(unittest.TestCase):
         self.epoch_num = 1
         self.capacity = 5
 
+    def iter_loader_data(self, loader):
+        for _ in range(self.epoch_num):
+            for image, label in loader():
+                relu = fluid.layers.relu(image)
+                self.assertEqual(image.shape, [self.batch_size, 784])
+                self.assertEqual(label.shape, [self.batch_size, 1])
+                self.assertEqual(relu.shape, [self.batch_size, 784])
+
     def test_single_process_loader(self):
         with fluid.dygraph.guard():
             loader = fluid.io.DataLoader.from_generator(
@@ -49,12 +57,7 @@ class TestDygraphDataLoader(unittest.TestCase):
                 sample_generator_creator(self.batch_size, self.batch_num),
                 batch_size=self.batch_size,
                 places=fluid.CPUPlace())
-            for _ in range(self.epoch_num):
-                for image, label in loader():
-                    relu = fluid.layers.relu(image)
-                    self.assertEqual(image.shape, [self.batch_size, 784])
-                    self.assertEqual(label.shape, [self.batch_size, 1])
-                    self.assertEqual(relu.shape, [self.batch_size, 784])
+            self.iter_loader_data(loader)
 
     def test_multi_process_loader(self):
         with fluid.dygraph.guard():
@@ -64,12 +67,15 @@ class TestDygraphDataLoader(unittest.TestCase):
                 sample_generator_creator(self.batch_size, self.batch_num),
                 batch_size=self.batch_size,
                 places=fluid.CPUPlace())
-            for _ in range(self.epoch_num):
-                for image, label in loader():
-                    relu = fluid.layers.relu(image)
-                    self.assertEqual(image.shape, [self.batch_size, 784])
-                    self.assertEqual(label.shape, [self.batch_size, 1])
-                    self.assertEqual(relu.shape, [self.batch_size, 784])
+            self.iter_loader_data(loader)
+
+    def test_generator_no_places(self):
+        with fluid.dygraph.guard():
+            loader = fluid.io.DataLoader.from_generator(capacity=self.capacity)
+            loader.set_sample_generator(
+                sample_generator_creator(self.batch_size, self.batch_num),
+                batch_size=self.batch_size)
+            self.iter_loader_data(loader)
 
 if __name__ == '__main__':
......