Commit 0aa9f900 authored by: M mindspore-ci-bot, committed by: Gitee

!5749 Support multi grad

Merge pull request !5749 from amongo/SupportPyantiveMultiGrad
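
For context: this change lets PyNative mode record and differentiate several top cells (including a cell nested inside another) in one session, keying the per-cell state by cell id. Below is a minimal sketch of the pattern the new tests at the bottom exercise; the names here are illustrative, not part of the commit:

```python
import numpy as np
from mindspore import context, nn, Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C

context.set_context(mode=context.PYNATIVE_MODE)

class MulNet(nn.Cell):
    def construct(self, x, y):
        return (x * x) * (y * y)

class AddNet(nn.Cell):
    def construct(self, x, y):
        return (x + x) * (y + y)

mul_net, add_net = MulNet(), AddNet()
x = Tensor(np.ones([4]), dtype=mstype.float32)
y = Tensor(np.ones([4]) * 2, dtype=mstype.float32)
mul_net.set_grad()
add_net.set_grad()
mul_net(x, y)  # each forward run now records its own top graph
add_net(x, y)
grad_all = C.GradOperation(get_all=True)
dmul = grad_all(mul_net)(x, y)  # taking grads of a second top cell is
dadd = grad_all(add_net)(x, y)  # the scenario this PR enables
```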
......@@ -1031,7 +1031,8 @@ PynativeExecutor::PynativeExecutor() {
void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
auto cell_id = GetCellId(cell, args);
if (cell_graph_map_.count(cell_id) != 0) {
// check graph_context_.empty() to create separate graphs except for the top
if (cell_graph_map_.count(cell_id) != 0 && graph_context_.empty()) {
if (cell_resource_map_.find(cell_id) != cell_resource_map_.end()) {
resource_ = cell_resource_map_[cell_id];
}
......@@ -1040,21 +1041,24 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
}
auto g = std::make_shared<FuncGraph>();
if (top_g_ == nullptr) {
if (graph_context_.empty()) {
// a df builder is built for every top function graph
df_builder_ = std::make_shared<FuncGraph>();
df_builder_map_[cell_id] = df_builder_;
top_g_ = curr_g_ = g;
resource_ = std::make_shared<pipeline::Resource>();
resource_->results()[pipeline::kPynativeGraphId] = graph_id_++;
cell_resource_map_[cell_id] = resource_;
df_builder_ = std::make_shared<FuncGraph>();
MS_LOG(DEBUG) << "First new graph" << top_g_.get();
first_grad_step_ = true;
top_graph_cells_.insert(cell_id);
Pushp();
} else {
Pushp();
if (df_builder_ == nullptr) {
MS_LOG(EXCEPTION) << "In NewGraphInner, got df builder is nullptr";
}
curr_g_ = g;
}
Pushp();
if (graph_info_map_.count(g) == 0) {
graph_info_map_[g] = GraphInfo();
}
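
The hunk above replaces the old `top_g_ == nullptr` test with a check on the new `graph_context_` stack: an empty stack means a top cell is being entered, and that cell now gets its own `df_builder_` and `resource_`, cached per cell id in `df_builder_map_` and `cell_resource_map_`. A rough Python model of this bookkeeping (illustrative only, not MindSpore code):

```python
class ExecutorSketch:
    """Models the NewGraphInner/Pushp bookkeeping; names mirror the C++ members."""
    def __init__(self):
        self.graph_context = []       # stack; the bottom entry is the top graph
        self.df_builder_map = {}      # cell_id -> df builder, one per top cell
        self.cell_resource_map = {}   # cell_id -> compile resource
        self.curr_g = None

    def new_graph(self, cell_id):
        g = {"cell": cell_id}         # stands in for a FuncGraph
        if not self.graph_context:    # empty stack => entering a top cell
            self.df_builder_map[cell_id] = {"df_builder_of": cell_id}
            self.cell_resource_map[cell_id] = {"resource_of": cell_id}
        self.curr_g = g
        self.graph_context.append(g)  # Pushp()
```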
......@@ -1171,22 +1175,25 @@ void PynativeExecutor::SetTupleParam(const py::object &obj, const AnfNodePtr &pa
}
}
void PynativeExecutor::Pushp() { graph_p_.push(curr_g_); }
void PynativeExecutor::Pushp() { graph_context_.push(curr_g_); }
void PynativeExecutor::Popp() {
if (graph_p_.empty()) {
MS_LOG(EXCEPTION) << "Stack graph_p_ is empty";
if (graph_context_.empty()) {
MS_LOG(EXCEPTION) << "Stack graph_context_ is empty";
}
graph_context_.pop();
if (!graph_context_.empty()) {
curr_g_ = graph_context_.top();
}
curr_g_ = graph_p_.top();
graph_p_.pop();
}
void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) {
auto cell_id = GetCellId(cell, args);
if (cell_graph_map_.count(cell_id) != 0) {
if (cell_graph_map_.count(cell_id) != 0 && graph_context_.empty()) {
MS_LOG(DEBUG) << "Endgraph already compiled";
return;
}
cell_graph_map_[cell_id] = curr_g_;
auto out_id = GetId(out);
if (!graph_info_map_[curr_g_].obj_node_map.count(out_id) && !graph_info_map_[curr_g_].param_map.count(out_id)) {
......@@ -1246,7 +1253,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
(void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(curr_g_)));
}
}
auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_);
auto newfg = ad::Grad(curr_g_, resource_, graph_context_.size() == 1);
if (need_replace_param) {
auto params = newfg->parameters();
auto manager = Manage({newfg}, false);
......@@ -1257,26 +1264,29 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
}
}
graph_info_map_.erase(curr_g_);
if (curr_g_ != top_g_) {
if (graph_context_.size() > 1) {
Popp();
// connect the inner graph's output to the previous (enclosing) graph
auto graph_prev = graph_context_.top();
for (size_t i = 0; i < args.size(); i++) {
auto input = GetInput(args[i], false);
inputs.push_back(input);
}
auto out_cnode = curr_g_->NewCNode(inputs);
set_pyobj(curr_g_, GetCellId(cell, args));
auto out_cnode = graph_prev->NewCNode(inputs);
set_pyobj(graph_prev, GetCellId(cell, args));
if (py::isinstance<py::tuple>(out)) {
auto out_list = py::cast<py::tuple>(out);
auto out_size = static_cast<int>(out_list.size());
for (int i = 0; i < out_size; i++) {
set_obj_node_map(curr_g_, GetId(out_list[i]), out_cnode, i);
set_obj_node_map(graph_prev, GetId(out_list[i]), out_cnode, i);
SetTupleOutput(out_list[i], out_cnode, std::vector<int>{i});
}
}
set_obj_node_map(curr_g_, GetId(out), out_cnode);
set_obj_node_map(graph_prev, GetId(out), out_cnode);
} else {
parse::ResolveFuncGraph(newfg, resource_);
resource_->set_func_graph(newfg);
Popp();
}
}
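
Correspondingly, `EndGraphByOutId` now tests for the top graph with `graph_context_.size() == 1` instead of `curr_g_ == top_g_`, and for an inner cell it pops the stack and registers the inner output as a node of the graph underneath (`graph_prev`). Continuing the sketch above (illustrative only):

```python
def end_graph(executor, out_id):
    """Companion to ExecutorSketch: pop the finished graph and rewire its output."""
    finished = executor.graph_context.pop()      # Popp()
    if executor.graph_context:                   # an enclosing graph remains
        graph_prev = executor.graph_context[-1]
        executor.curr_g = graph_prev
        # the inner cell's output becomes a call node of the enclosing graph,
        # mirroring set_obj_node_map(graph_prev, out_id, out_cnode)
        graph_prev.setdefault("obj_node_map", {})[out_id] = ("call", finished["cell"])
```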
......@@ -1348,14 +1358,36 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args
void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
const py::args &args) {
MS_LOG(INFO) << "GradNet start" << args.size();
std::size_t size = args.size();
std::string cell_id = GetCellId(cell, args);
if (graph_map_.count(cell_id) != 0) {
MS_LOG(DEBUG) << "GradNet already compiled";
return;
}
size_t forward_args_count = args.size();
if (grad->sens_param()) {
forward_args_count = forward_args_count - 1;
}
py::tuple forward_args(forward_args_count);
for (size_t i = 0; i < forward_args_count; i++) {
forward_args[i] = args[i];
}
std::string forward_cell_id = GetCellId(cell, forward_args);
MS_LOG(DEBUG) << "Forward cell_id:" << forward_cell_id;
if (df_builder_map_.find(forward_cell_id) == df_builder_map_.end()) {
MS_LOG(EXCEPTION) << "Cannot find df builder";
}
df_builder_ = df_builder_map_[forward_cell_id];
if (df_builder_ == nullptr) {
MS_LOG(EXCEPTION) << "Got unexpected null df builder";
}
if (cell_resource_map_.find(forward_cell_id) == cell_resource_map_.end()) {
MS_LOG(EXCEPTION) << "Cannot find resource for " << forward_cell_id;
}
MS_LOG(DEBUG) << "GradNet first compiled";
resource_ = cell_resource_map_[forward_cell_id];
std::vector<AnfNodePtr> new_params;
for (size_t i = 0; i < size; i++) {
ParameterPtr p = std::make_shared<Parameter>(df_builder_);
......@@ -1368,6 +1400,10 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
std::vector<AnfNodePtr> w_args = GetWeightsArgs(weights);
MS_EXCEPTION_IF_NULL(resource_->func_graph());
if (cell_graph_map_.find(forward_cell_id) == cell_graph_map_.end()) {
MS_LOG(EXCEPTION) << "Could not find top graph by cellid: " << forward_cell_id;
}
top_g_ = cell_graph_map_[forward_cell_id];
auto g = GradGraph(resource_->func_graph(), grad, w_args, size);
resource_->set_func_graph(g);
resource_->manager()->KeepRoots({g});
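
`GradNetInner` looks up the cached df builder, resource, and top graph by the *forward* cell id: when `sens_param` is set, the grad call carries one extra trailing sens argument that the forward run never saw, so it is dropped before building the cache key (the real key comes from `GetCellId` over the cell and its arguments). A small illustration of that stripping:

```python
def forward_call_args(args, sens_param):
    """Drop the trailing sens argument so the key matches the forward run."""
    return tuple(args[:-1]) if sens_param else tuple(args)

# e.g. the grad net is invoked as grad_net(x, y, sens), but the forward cell
# ran as cell(x, y), so the cache key must be built from (x, y) only.
```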
......@@ -1409,6 +1445,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
MapClear<std::unordered_map<std::string, FuncGraphPtr>>(&graph_map_, flag);
MapClear<std::unordered_map<std::string, FuncGraphPtr>>(&cell_graph_map_, flag);
MapClear<std::unordered_map<std::string, ResourcePtr>>(&cell_resource_map_, flag);
MapClear<std::unordered_map<std::string, FuncGraphPtr>>(&df_builder_map_, flag);
Clean();
// Maybe exit while running a pynative op, so the pynative flag needs to be reset.
auto ms_context = MsContext::GetInstance();
......@@ -1431,7 +1468,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
graph_info_map_.clear();
op_id_map_.clear();
obj_to_forward_id_.clear();
std::stack<FuncGraphPtr>().swap(graph_p_);
std::stack<FuncGraphPtr>().swap(graph_context_);
ConfigManager::GetInstance().ResetIterNum();
}
......@@ -1509,7 +1546,6 @@ py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase)
}
std::string backend = MsContext::GetInstance()->backend_policy();
MS_LOG(DEBUG) << "Eval run" << backend;
BaseRef value = (*run)(arg_list);
MS_LOG(DEBUG) << "Run end" << value.ToString();
......
......@@ -155,7 +155,9 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
std::unordered_map<std::string, size_t> op_id_map_;
std::unordered_map<std::string, std::string> obj_to_forward_id_;
std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
std::stack<FuncGraphPtr> graph_p_;
std::unordered_map<std::string, FuncGraphPtr> df_builder_map_;
// the stack that records the graph-creation context; the bottom entry is the top graph
std::stack<FuncGraphPtr> graph_context_;
FuncGraphPtr top_g_;
FuncGraphPtr df_builder_;
FuncGraphPtr curr_g_;
......
......@@ -21,7 +21,7 @@ from types import FunctionType
from mindspore import context
from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, \
TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
from ...common import dtype as mstype
from ...common.api import ms_function, _pynative_exec, _wrap_func
from .. import functional as F
......@@ -475,6 +475,7 @@ class _ListAppend(ListAppend_):
Args:
name (str): The name of the metafuncgraph object.
"""
def __init__(self, name):
ListAppend_.__init__(self, name)
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
from mindspore import context, nn, Tensor, Parameter, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
def setup_module():
context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False)
class _Grad(nn.Cell):
def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
super().__init__()
self.network = network
self.grad = grad
self.sens_param = self.grad.sens_param
self.wrt_params = wrt_params
self.real_inputs_count = real_inputs_count
if self.wrt_params:
self.params = ParameterTuple(self.network.trainable_params())
def construct(self, *inputs):
if self.wrt_params:
if self.real_inputs_count is None or self.sens_param is False:
return self.grad(self.network, self.params)(*inputs)
real_inputs = inputs[:self.real_inputs_count]
sense_param_inputs = inputs[self.real_inputs_count:]
return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs)
if self.real_inputs_count is None or self.sens_param is False:
return self.grad(self.network)(*inputs)
real_inputs = inputs[:self.real_inputs_count]
sense_param_inputs = inputs[self.real_inputs_count:]
return self.grad(self.network)(*real_inputs, sense_param_inputs)
class GradOfFirstInput(_Grad):
"""
get grad of first input
"""
def __init__(self, network, sens_param=True, real_inputs_count=None):
super().__init__(grad=C.GradOperation(sens_param=sens_param),
network=network, real_inputs_count=real_inputs_count)
class GradOfAllInputs(_Grad):
"""
get grads of all inputs
"""
def __init__(self, network, sens_param=True, real_inputs_count=None):
super().__init__(grad=C.GradOperation(get_all=True, sens_param=sens_param),
network=network, real_inputs_count=real_inputs_count)
def test_multi_grad():
class ForwardNetMul(nn.Cell):
def __init__(self):
super().__init__()
def construct(self, x, y):
a = x * x
b = y * y
return a * b
class ForwardNetAdd(nn.Cell):
def __init__(self):
super().__init__()
def construct(self, x, y):
a = x + x + x
b = y + y
return a * b
mulnet = ForwardNetMul()
addnet = ForwardNetAdd()
x = Tensor(np.ones([32]), dtype=mstype.float32)
y = Tensor(np.ones([32])*2, dtype=mstype.float32)
sens = Tensor(np.ones([32]), dtype=mstype.float32)
mulnet.set_grad()
addnet.set_grad()
out1 = mulnet(x, y)
out2 = addnet(x, y)
grad_mul = GradOfAllInputs(mulnet)
grad_add = GradOfAllInputs(addnet)
grad_mul(x, y, sens)
grad_add(x, y, sens)
def test_multi_same_grad():
class ForwardNetMul(nn.Cell):
def __init__(self):
super().__init__()
def construct(self, x, y):
a = x * x
b = y * y
return a * b
class ForwardNetAdd(nn.Cell):
def __init__(self):
super().__init__()
def construct(self, x, y):
a = x*3
b = y*2
return a + b
mulnet = ForwardNetMul()
addnet = ForwardNetAdd()
x = Tensor(np.ones([32]), dtype=mstype.float32)
y = Tensor(np.ones([32]), dtype=mstype.float32)
sens = Tensor(np.ones([32]), dtype=mstype.float32)
mulnet.set_grad()
addnet.set_grad()
out1 = mulnet(x, y)
out2 = addnet(x, y)
grad_mul = GradOfAllInputs(mulnet)
grad_add = GradOfFirstInput(mulnet)
grad_mul(x, y, sens)
grad_add(x, y, sens)
def test_net_inner_grad():
class ForwardNetMul(nn.Cell):
def __init__(self):
super().__init__()
def construct(self, x, y):
a = x * x
b = y * y
return a * b
class ForwardNetAdd(nn.Cell):
def __init__(self, net):
super().__init__()
self.net = net
def construct(self, x, y):
a = x + x
b = y + y
res = self.net(a, b)
return res
mulnet = ForwardNetMul()
addnet = ForwardNetAdd(mulnet)
x = Tensor(np.ones([32]), dtype=mstype.float32)
y = Tensor(np.ones([32]), dtype=mstype.float32)
sens = Tensor(np.ones([32]), dtype=mstype.float32)
mulnet.set_grad()
addnet.set_grad()
out1 = mulnet(x, y)
out2 = addnet(x, y)
grad_mul = GradOfAllInputs(addnet)
grad_add = GradOfAllInputs(mulnet)
grad_mul(x, y, sens)
grad_add(x, y, sens)
def test_net_inner_first_run_grad():
class ForwardNetMul(nn.Cell):
def __init__(self):
super().__init__()
self.z1 = Parameter(Tensor(np.ones([32])*2, dtype=mstype.float32), name='z1')
def construct(self, x, y):
a = x * self.z1
b = y * y
return a * b
class ForwardNetAdd(nn.Cell):
def __init__(self, net):
super().__init__()
self.net = net
self.z2 = Parameter(Tensor(np.ones([32]), dtype=mstype.float32), name='z2')
self.z3 = Parameter(Tensor(np.ones([32]), dtype=mstype.float32), name='z3')
def construct(self, x, y):
a = x + x*self.z3
b = y + y*self.z2
res = self.net(a, b)
return res
mulnet = ForwardNetMul()
addnet = ForwardNetAdd(mulnet)
x = Tensor(np.ones([32]), dtype=mstype.float32)
y = Tensor(np.ones([32]), dtype=mstype.float32)
sens = Tensor(np.ones([32]), dtype=mstype.float32)
mulnet.set_grad()
addnet.set_grad()
out1 = mulnet(x, y)
out2 = addnet(x, y)
grad_mul = GradOfAllInputs(addnet)
grad_add = GradOfFirstInput(mulnet)
grad_mul(x, y, sens)
grad_add(x, y, sens)