提交 d712ac0d 编写于 作者: L limingqi107

add count of graphs using the parameter

上级 0aa9f900
......@@ -469,6 +469,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
}
TraceManager::EndTrace();
}
new_parameter->IncreaseUsedGraphCount();
graph_inputs->push_back(new_parameter);
valid_inputs->push_back(true);
return new_parameter;
......@@ -812,6 +813,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
}
TraceManager::EndTrace();
}
new_parameter->IncreaseUsedGraphCount();
return new_parameter;
}
......
......@@ -803,11 +803,18 @@ void KernelRuntime::ClearOutputAddress(const std::vector<AnfNodePtr> &inputs,
if (!input_node->isa<Parameter>()) {
continue;
}
auto parameter = input_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(parameter);
parameter->DecreaseUsedGraphCount();
// Only the parameter has no graph used, then clear the output address.
if (parameter->used_graph_count() != 0) {
continue;
}
for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(input_node); ++index) {
if (!AnfAlgo::OutputAddrExist(input_node, index)) {
continue;
}
AnfAlgo::SetOutputAddr(nullptr, 0, input_node.get());
AnfAlgo::SetOutputAddr(nullptr, index, input_node.get());
}
}
// clear input value node output address.
......
......@@ -282,7 +282,7 @@ class ANode : public AnfNode {
class Parameter : public ANode {
public:
explicit Parameter(const FuncGraphPtr &func_graph)
: ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr) {}
: ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr), used_graph_count_(0) {}
~Parameter() override = default;
MS_DECLARE_PARENT(Parameter, ANode);
......@@ -300,6 +300,10 @@ class Parameter : public ANode {
ValuePtr default_param() const { return default_param_; }
ParamInfoPtr param_info() const;
void IncreaseUsedGraphCount() { used_graph_count_++; }
void DecreaseUsedGraphCount() { used_graph_count_--; }
int used_graph_count() const { return used_graph_count_; }
bool operator==(const AnfNode &other) const override {
if (!other.isa<Parameter>()) {
return false;
......@@ -315,6 +319,8 @@ class Parameter : public ANode {
std::string name_;
bool has_default_;
ValuePtr default_param_;
// The count of graphs using the parameter.
int used_graph_count_;
};
using ParameterPtr = std::shared_ptr<Parameter>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册