提交 72f77bde 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4363 export and load model for serving

Merge pull request !4363 from hexia/export_and_load_model_for_serving
......@@ -90,14 +90,17 @@ class IrExportBuilder {
void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto);
void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto);
void SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto);
void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto,
std::string suffix = "0");
void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::AttributeProto *const attr_proto,
std::string *const seq_string);
void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
void SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto);
void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto);
void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto, const std::string &value_name);
void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto,
std::string *const seq_string);
void SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto,
std::string *const seq_string);
onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id);
onnx::TensorProto_DataType GetOnnxDataBitsIntType(int bits);
......@@ -105,8 +108,10 @@ class IrExportBuilder {
std::string GetNodeName(const AnfNodePtr &node);
std::string GetUniqueNodeName(const AnfNodePtr &node);
std::string GetOpTypeName(const AnfNodePtr &node);
size_t AllocateIndex() { return ++node_index_; }
void ResetIndex() { node_index_ = 0; }
size_t GetNodeIndex() { return ++node_index_; }
void ResetNodeIndex() { node_index_ = 0; }
size_t GetTupleIndex() { return ++shape_index_; }
void ResetTupleIndex() { shape_index_ = 0; }
private:
onnx::ModelProto model_;
......@@ -114,6 +119,7 @@ class IrExportBuilder {
std::list<FuncGraphPtr> todo_;
std::map<AnfNodePtr, size_t> node_index_map_;
size_t node_index_{0};
size_t shape_index_{0};
};
using IrExporterPtr = std::shared_ptr<IrExporter>;
......@@ -146,7 +152,7 @@ void IrExportBuilder::BuildModelInfo() {
void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) {
onnx::GraphProto *graph_proto = model_.mutable_graph();
graph_proto->set_name(func_graph->ToString());
ResetIndex();
ResetNodeIndex();
todo_.clear();
todo_.push_back(func_graph);
while (!todo_.empty()) {
......@@ -177,7 +183,7 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::Grap
input_proto->set_name(param_name);
SetValueInfoProto(param, input_proto);
if (!param->has_default()) {
MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default";
MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default.";
continue;
}
......@@ -232,13 +238,20 @@ void IrExportBuilder::SetValueInfoProto(const TypePtr &type, const BaseShapePtr
auto elem_type = tensor->element();
const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
type_proto->mutable_tensor_type()->set_elem_type(GetOnnxDataType(elem_type->type_id()));
for (const auto &dim : dims) {
MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim;
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
if (dims.size() == 0) {
MS_LOG(DEBUG) << "SetValueInfoProto set default dim 1.";
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
} else {
for (const auto &dim : dims) {
MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim;
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
}
}
} else if (type->isa<Tuple>()) {
auto tup_shape = shape->cast<abstract::TupleShapePtr>();
type_proto->set_denotation(std::to_string(tup_shape->shape().size()));
type_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size()));
} else if (type->isa<Number>() || type->isa<String>()) {
type_proto->set_denotation(type->type_name());
} else {
MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!";
}
......@@ -248,9 +261,10 @@ void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, onnx::Att
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
attr_proto->set_ref_attr_name("tensor");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
attr_proto->set_ref_attr_name("tensor:value0");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
tensor_proto->set_name("value0");
auto data = value->cast<tensor::TensorPtr>();
tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes()));
auto dtype = data->data_type();
......@@ -284,6 +298,7 @@ void IrExportBuilder::SetParamToTensorProto(const ParameterPtr &param, onnx::Ten
void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
bool is_only_return = true;
for (const AnfNodePtr &node : nodes) {
if (!node->isa<CNode>()) {
MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode";
......@@ -291,9 +306,13 @@ void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProt
}
auto cnode = node->cast<CNodePtr>();
if (cnode == func_graph->get_return()) {
if (is_only_return) {
MS_LOG(EXCEPTION) << "Only has return node, can't convert to binary model!";
}
BuildOutput(cnode, graph_proto);
} else {
BuildCNode(cnode, graph_proto);
is_only_return = false;
}
}
}
......@@ -303,24 +322,11 @@ void IrExportBuilder::BuildOutput(const CNodePtr &node, onnx::GraphProto *const
MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2.";
}
AnfNodePtr arg = node->input(1);
// Using make_tuple to set multi-output
if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple)) {
auto tuple_node = arg->cast<CNodePtr>();
for (size_t i = 1; i < tuple_node->size(); i++) {
auto input_node = arg->cast<CNodePtr>()->input(i);
onnx::ValueInfoProto *output_proto = graph_proto->add_output();
auto output_name = GetUniqueNodeName(tuple_node->input(i));
output_proto->set_name(output_name);
last_node_->add_output(output_name);
SetValueInfoProto(tuple_node->input(i), output_proto);
}
} else {
onnx::ValueInfoProto *output_proto = graph_proto->add_output();
std::string output_name = GetUniqueNodeName(node);
output_proto->set_name(output_name);
last_node_->add_output(output_name);
SetValueInfoProto(arg, output_proto);
}
onnx::ValueInfoProto *output_proto = graph_proto->add_output();
std::string output_name = GetUniqueNodeName(node);
output_proto->set_name(output_name);
last_node_->set_output(0, output_name);
SetValueInfoProto(arg, output_proto);
}
std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
......@@ -343,45 +349,44 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
}
void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
onnx::NodeProto *const node_proto, std::string suffix) {
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_ref_attr_name("shape");
if (suffix.compare("0") != 0) {
attr_proto->set_name("shape" + suffix);
} else {
attr_proto->set_name("shape");
}
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
SetTensorProto(type, shape, tensor_proto);
}
void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) {
// Get shape of cnode
// 1. prim ArgMaxWithValue need to get shape from tuple element
// 2. some cnode doesn't has shape, such as LayerNorm
// 3. other cnodes have shape
if (node->IsApply(prim::kPrimArgMaxWithValue) || node->IsApply(prim::kPrimLayerNorm)) {
auto type = node->Type();
auto shape = node->Shape();
if (!type->isa<Tuple>()) {
MS_LOG(EXCEPTION) << "Output data of ArgMaxWithValue cnode must be tuple: " << type->type_name();
}
onnx::AttributeProto *const attr_proto, std::string *const seq_string) {
if (type->isa<Tuple>() && seq_string != nullptr) {
*seq_string += "Tuple[";
auto elements = type->cast<TuplePtr>()->elements();
auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape();
for (size_t i = 0; i < elements.size(); i++) {
SetShapeToNodeProto(elements[i], tuple_shape[i], node_proto, std::to_string(i));
SetShapeToNodeProto(elements[i], tuple_shape[i], attr_proto, seq_string);
}
*seq_string += "],";
} else if (type->isa<TensorType>() && shape->isa<abstract::Shape>() && seq_string != nullptr) {
string shape_name = "shape" + std::to_string(GetTupleIndex());
*seq_string += shape_name + ",";
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
tensor_proto->set_name(shape_name);
SetTensorProto(type, shape, tensor_proto);
} else if ((type->isa<Number>() || type->isa<String>()) && seq_string != nullptr) {
*seq_string += type->type_name() + ",";
} else {
auto type = node->Type();
auto shape = node->Shape();
if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
MS_LOG(DEBUG) << "Cnode has no shape: " << node->ToString();
return;
}
SetShapeToNodeProto(type, shape, node_proto);
MS_LOG(EXCEPTION) << "Type of cnode need to be supported: " << type->type_name();
}
}
void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) {
// Get shape of cnode
// 1. need to get shape from tuple element
// 2. save shape in TensorProto
// 3. save tuple string in ref_attr_name
MS_EXCEPTION_IF_NULL(node);
auto type = node->Type();
auto shape = node->Shape();
ResetTupleIndex();
std::string seq_string = "shape:";
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
SetShapeToNodeProto(type, shape, attr_proto, &seq_string);
attr_proto->set_ref_attr_name(seq_string);
MS_LOG(DEBUG) << "CNode shape: " << seq_string;
}
void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto) {
auto inputs_size = node->size();
if (inputs_size < 1) {
......@@ -443,15 +448,19 @@ std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) {
std::string node_name = "";
if (node->isa<Parameter>()) {
node_name = GetNodeName(node);
} else if (node->isa<CNode>() || node->isa<ValueNode>()) {
} else if (node->isa<CNode>()) {
auto iter = node_index_map_.find(node);
if (iter != node_index_map_.end()) {
node_name = GetNodeName(node) + ":" + std::to_string(iter->second);
} else {
auto node_idx = AllocateIndex();
auto node_idx = GetNodeIndex();
node_index_map_[node] = node_idx;
node_name = GetNodeName(node) + ":" + std::to_string(node_idx);
}
} else if (node->isa<ValueNode>()) {
auto node_idx = GetNodeIndex();
node_index_map_[node] = node_idx;
node_name = GetNodeName(node) + ":" + std::to_string(node_idx);
} else {
MS_LOG(EXCEPTION) << "Can not support type of node:" << node->ToString();
}
......@@ -485,17 +494,21 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, onnx::Attri
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
attr_proto->set_ref_attr_name("type");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
if (value->isa<Int>()) {
attr_proto->set_ref_attr_name("type:value0");
tensor_proto->set_name("value0");
auto int_value = value->cast<IntPtr>();
tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits()));
} else if (value->isa<Float>()) {
attr_proto->set_ref_attr_name("type:value0");
tensor_proto->set_name("value0");
auto float_value = value->cast<FloatPtr>();
tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits()));
} else if (value->isa<TensorType>()) {
tensor_proto->set_name("tensor");
attr_proto->set_ref_attr_name("type:tensor0");
tensor_proto->set_name("tensor0");
auto elem_type = value->cast<TensorTypePtr>()->element();
if (elem_type->isa<Int>()) {
auto int_value = elem_type->cast<IntPtr>();
......@@ -519,10 +532,18 @@ void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, onnx::Attr
SetScalarToAttributeProto(value, attr_proto);
} else if (value->isa<Number>() || value->isa<TensorType>()) {
SetTypeToAttributeProto(value, attr_proto);
} else if (value->isa<ValueSequeue>()) {
SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto);
} else if (value->isa<ValueSequeue>() || value->isa<ValueSequeue>()) {
ResetTupleIndex();
std::string seq_string = "scalar:";
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto, &seq_string);
attr_proto->set_ref_attr_name(seq_string);
MS_LOG(DEBUG) << "Attr string: " << seq_string;
} else if (value->isa<tensor::Tensor>()) {
SetTensorToAttributeProto(value, attr_proto);
} else if (value->isa<None>()) {
attr_proto->set_ref_attr_name("none");
MS_LOG(DEBUG) << "Attr string: " << value->type_name();
} else {
MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
}
......@@ -532,16 +553,18 @@ void IrExportBuilder::SetScalarToAttributeProto(const ValuePtr &value, onnx::Att
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
attr_proto->set_ref_attr_name("scalar");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
SetScalarToProto(value, tensor_proto);
attr_proto->set_ref_attr_name("scalar:value0");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
SetScalarToProto(value, tensor_proto, "value0");
}
void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto) {
void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto,
const std::string &value_name) {
if (value == nullptr || tensor_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or TensorProto is null!";
}
tensor_proto->set_name(value_name);
if (value->isa<StringImm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_STRING);
tensor_proto->add_string_data(GetValue<std::string>(value));
......@@ -560,44 +583,74 @@ void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto
} else if (value->isa<Int64Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64);
tensor_proto->add_int64_data(value->cast<Int64ImmPtr>()->value());
} else if (value->isa<FloatImm>()) {
} else if (value->isa<UInt8Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT8);
tensor_proto->add_int32_data(value->cast<UInt8ImmPtr>()->value());
} else if (value->isa<UInt16Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT16);
tensor_proto->add_int32_data(value->cast<UInt16ImmPtr>()->value());
} else if (value->isa<UInt32Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT32);
tensor_proto->add_uint64_data(value->cast<UInt32ImmPtr>()->value());
} else if (value->isa<UInt64Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT64);
tensor_proto->add_uint64_data(value->cast<UInt64ImmPtr>()->value());
} else if (value->isa<FP32Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_FLOAT);
tensor_proto->add_float_data(GetValue<float>(value));
} else if (value->isa<FP64Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_DOUBLE);
tensor_proto->add_double_data(GetValue<double>(value));
} else {
MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name();
}
}
void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
onnx::AttributeProto *const attr_proto) {
void IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto,
std::string *const seq_string) {
string value_name = "value" + std::to_string(GetTupleIndex());
if (seq_string != nullptr) {
*seq_string += value_name + ",";
}
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
SetScalarToProto(value, tensor_proto, value_name);
}
void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto,
std::string *const seq_string) {
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValueSequeuePtr or AttributeProto is null!";
}
attr_proto->set_ref_attr_name("scalar");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
if (value->isa<ValueTuple>()) {
if (value->isa<ValueTuple>() && seq_string != nullptr) {
*seq_string += "Tuple[";
const ValueTuplePtr &tuple_value = value->cast<ValueTuplePtr>();
if (tuple_value->value().size() == 0) {
MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0";
return;
}
auto type_id = tuple_value->value()[0]->type()->type_id();
tensor_proto->set_data_type(GetOnnxDataType(type_id));
for (const auto &item : tuple_value->value()) {
SetScalarToProto(item, tensor_proto);
if (item->isa<ValueTuple>()) {
SetSequenceToAttributeProto(item->cast<ValueTuplePtr>(), attr_proto, seq_string);
} else {
SetSeqElemToAttributeProto(item, attr_proto, seq_string);
}
}
} else if (value->isa<ValueList>()) {
*seq_string += "],";
} else if (value->isa<ValueList>() && seq_string != nullptr) {
*seq_string += "List[";
const ValueListPtr &list_value = value->cast<ValueListPtr>();
if (list_value->value().size() == 0) {
MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0";
MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0.";
return;
}
auto type_id = list_value->value()[0]->type()->type_id();
tensor_proto->set_data_type(GetOnnxDataType(type_id));
for (const auto &item : list_value->value()) {
SetScalarToProto(item, tensor_proto);
if (item->isa<ValueList>()) {
SetSequenceToAttributeProto(item->cast<ValueListPtr>(), attr_proto, seq_string);
} else {
SetSeqElemToAttributeProto(item, attr_proto, seq_string);
}
}
*seq_string += "],";
}
}
......
......@@ -57,7 +57,7 @@ int AnfConverter::ValidateFileStr(const std::string &modelFile, std::string file
bool AnfConverter::ReadOnnxFromBinary(const std::string &modelFile, google::protobuf::Message *onnx_model) {
std::unique_ptr<char> onnx_file(new (std::nothrow) char[PATH_MAX]{0});
int fd = open(onnx_file.get(), O_RDONLY);
int fd = open(modelFile.c_str(), O_RDONLY);
if (fd < 0) {
MS_LOG(EXCEPTION) << "failed to open file";
}
......
......@@ -18,8 +18,12 @@
#include <functional>
#include <map>
#include <memory>
#include <stack>
#include <string>
#include <vector>
#include <unordered_map>
#include <utility>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "ir/tensor.h"
#include "ir/param_info.h"
#include "frontend/operator/ops.h"
......@@ -55,6 +59,97 @@ static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{
{onnx::TensorProto_DataType_STRING, kObjectTypeString},
};
template <typename T, typename P>
std::shared_ptr<T> ParserAttr(const std::string &str, const std::unordered_map<string, P> &kv) {
std::stack<std::string> rules;
std::stack<P> value;
int count = 0;
for (size_t i = 0; i < str.length(); i++) {
if (str[i] == '[') {
rules.push("[");
} else if (str[i] == ']') {
// rules
std::vector<P> vec;
while (rules.top() != "[") {
rules.pop();
vec.push_back(value.top());
value.pop();
}
// pop "["
rules.pop();
// make tuple for names
std::string res = "dummy";
// make tuple for values
reverse(vec.begin(), vec.end());
auto vt = std::make_shared<T>(vec);
if (rules.empty() && value.empty()) {
return vt;
}
rules.push(res);
value.push(vt);
} else if (str[i] == ',') {
continue;
} else {
count++;
if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') {
auto value_name = str.substr(i - count + 1, count);
value.push(kv.at(value_name));
rules.push(value_name);
count = 0;
}
}
}
return {};
}
std::shared_ptr<ValueTuple> ParserScalarAttrValue(const std::string &attr_name,
const std::unordered_map<string, ValuePtr> &kv) {
std::string str = attr_name;
auto replace = [&](const string &orgStr, const string &newStr) {
std::string::size_type pos(0);
while ((pos = str.find(orgStr)) != std::string::npos) {
str.replace(pos, orgStr.length(), newStr);
}
return str;
};
// remove "scalar:"
str = replace("scalar:", "");
// remove "Tuple"
str = replace("Tuple", "");
// remove "List"
str = replace("List", "");
auto result = ParserAttr<ValueTuple>(str, kv);
if (!result) {
return {};
}
return result;
}
std::shared_ptr<abstract::AbstractTuple> ParserAttrShape(
const std::string &attr_name, const std::unordered_map<string, abstract::AbstractBasePtr> &kv) {
std::string str = attr_name;
auto replace = [&](const string &orgStr, const string &newStr) {
std::string::size_type pos(0);
while ((pos = str.find(orgStr)) != std::string::npos) {
str.replace(pos, orgStr.length(), newStr);
}
return str;
};
// remove "scalar:"
str = replace("shape:", "");
// remove "Tuple"
str = replace("Tuple", "");
// remove "List"
str = replace("List", "");
auto result = ParserAttr<abstract::AbstractTuple>(str, kv);
if (!result) {
return {};
}
return result;
}
#if 0
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \
const onnx::TensorProto &attr_tensor) { \
......@@ -67,9 +162,16 @@ static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{
if (attr_value_vec.size() == 1) { \
prim->AddAttr(attr_name, attr_value_vec[0]); \
} else { \
prim->AddAttr(attr_name, std::make_shared<ValueList>(attr_value_vec)); \
ParserScalarAttrValue(prim, attr_name, attr_value_vec); \
} \
}
#endif
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \
auto value = static_cast<valuetype>(attr_tensor.type##_data(0)); \
return MakeValue<valuetype>(value); \
}
PARSE_ONNXATTR_IN_SCALAR_FORM(double, double)
PARSE_ONNXATTR_IN_SCALAR_FORM(float, float)
......@@ -110,6 +212,7 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node, cons
tensor::TensorPtr tensor_info =
std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape);
MS_EXCEPTION_IF_NULL(tensor_info);
// tensor_info->MallocData();
auto tensor_abstract = tensor_info->ToAbstract();
MS_EXCEPTION_IF_NULL(tensor_abstract);
node->set_abstract(tensor_abstract);
......@@ -167,45 +270,35 @@ bool MSANFModelParser::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const
return true;
}
bool MSANFModelParser::ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor) {
MS_EXCEPTION_IF_NULL(prim);
ValuePtr MSANFModelParser::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) {
const int attr_tensor_type = attr_tensor.data_type();
switch (attr_tensor_type) {
case onnx::TensorProto_DataType_STRING: {
ParseAttrInScalar_string_string(prim, attr_name, attr_tensor);
break;
return ParseAttrInScalar_string_string(attr_tensor);
}
case onnx::TensorProto_DataType_INT32: {
ParseAttrInScalar_int32_int32(prim, attr_name, attr_tensor);
break;
return ParseAttrInScalar_int32_int32(attr_tensor);
}
case onnx::TensorProto_DataType_INT64: {
ParseAttrInScalar_int64_int64(prim, attr_name, attr_tensor);
break;
return ParseAttrInScalar_int64_int64(attr_tensor);
}
case onnx::TensorProto_DataType_UINT64: {
ParseAttrInScalar_uint64_uint64(prim, attr_name, attr_tensor);
break;
return ParseAttrInScalar_uint64_uint64(attr_tensor);
}
case onnx::TensorProto_DataType_FLOAT: {
ParseAttrInScalar_float_float(prim, attr_name, attr_tensor);
break;
return ParseAttrInScalar_float_float(attr_tensor);
}
case onnx::TensorProto_DataType_DOUBLE: {
ParseAttrInScalar_double_double(prim, attr_name, attr_tensor);
break;
return ParseAttrInScalar_double_double(attr_tensor);
}
case onnx::TensorProto_DataType_BOOL: {
ParseAttrInScalar_int32_bool(prim, attr_name, attr_tensor);
auto value = prim->GetAttr(attr_name);
break;
return ParseAttrInScalar_int32_bool(attr_tensor);
}
default:
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
return false;
return {};
}
return true;
return {};
}
bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
......@@ -223,21 +316,48 @@ bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx
return false;
}
const std::string &ref_attr_name = attr_proto.ref_attr_name();
const onnx::TensorProto &attr_tensor = attr_proto.t();
switch (kParseTypeSwitchMap[ref_attr_name]) {
case FORM_PARSE_TYPE: {
return ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor);
}
case FORM_PARSE_SCALAR: {
return ObtainCNodeAttrInScalarForm(prim, attr_name, attr_tensor);
string type;
std::size_t pos(0);
if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("scalar:").length() - 1);
} else if ((pos = ref_attr_name.find("type:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("type:").length() - 1);
} else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("tensor:").length() - 1);
}
std::unordered_map<std::string, ValuePtr> kv;
for (int i = 0; i < attr_proto.tensors_size(); i++) {
const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
switch (kParseTypeSwitchMap[type]) {
case FORM_PARSE_TYPE: {
ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor);
break;
}
case FORM_PARSE_SCALAR: {
auto res = ObtainCNodeAttrInScalarForm(attr_tensor);
kv.insert(std::pair<string, ValuePtr>(attr_tensor.name(), res));
break;
}
case FORM_PARSE_TENSOR: {
ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor);
break;
}
default:
MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
return false;
}
case FORM_PARSE_TENSOR: {
return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor);
}
if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) {
if (kv.size() == 1) {
auto iter = kv.begin();
prim->AddAttr(attr_name, iter->second);
} else {
auto res = ParserScalarAttrValue(ref_attr_name, kv);
prim->AddAttr(attr_name, res);
}
default:
MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
return false;
}
return true;
}
bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node_name,
const onnx::TensorProto &attr_tensor) {
......@@ -247,6 +367,7 @@ bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node
shape.push_back(attr_tensor.dims(i));
}
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
// tensor_info->MallocData();
const std::string &tensor_buf = attr_tensor.raw_data();
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size());
......@@ -324,22 +445,58 @@ bool MSANFModelParser::ObtainValueNodeInTypeForm(const std::string &value_node_n
return true;
}
bool MSANFModelParser::GetAttrValueForValueNode(const std::string &ref_attr_name, const std::string &value_node_name,
const onnx::TensorProto &attr_tensor) {
switch (kParseTypeSwitchMap[ref_attr_name]) {
case FORM_PARSE_SCALAR: {
return ObtainValueNodeInScalarForm(value_node_name, attr_tensor);
}
case FORM_PARSE_TENSOR: {
return ObtainValueNodeInTensorForm(value_node_name, attr_tensor);
bool MSANFModelParser::GetAttrValueForValueNode(const std::string &value_node_name,
const onnx::AttributeProto &attr_proto) {
if (!attr_proto.has_ref_attr_name()) {
MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
return false;
}
const std::string &ref_attr_name = attr_proto.ref_attr_name();
string type;
std::size_t pos(0);
if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("scalar:").length() - 1);
} else if ((pos = ref_attr_name.find("type:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("type:").length() - 1);
} else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("tensor:").length() - 1);
}
std::unordered_map<std::string, ValuePtr> kv;
for (int i = 0; i < attr_proto.tensors_size(); i++) {
const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
auto attr_name = attr_tensor.name();
switch (kParseTypeSwitchMap[type]) {
case FORM_PARSE_TYPE: {
return ObtainValueNodeInTypeForm(value_node_name, attr_tensor);
}
case FORM_PARSE_SCALAR: {
auto res = ObtainCNodeAttrInScalarForm(attr_tensor);
kv.insert(std::pair<string, ValuePtr>(attr_tensor.name(), res));
break;
}
case FORM_PARSE_TENSOR: {
return ObtainValueNodeInTensorForm(value_node_name, attr_tensor);
}
default:
MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
return false;
}
case FORM_PARSE_TYPE: {
return ObtainValueNodeInTypeForm(value_node_name, attr_tensor);
}
ValueNodePtr new_value_node;
if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) {
if (kv.size() == 1) {
auto iter = kv.begin();
new_value_node = NewValueNode(iter->second);
new_value_node->set_abstract(iter->second->ToAbstract());
} else {
auto value_ptr = ParserScalarAttrValue(ref_attr_name, kv);
new_value_node = NewValueNode(value_ptr);
new_value_node->set_abstract(value_ptr->ToAbstract());
}
default:
MS_LOG(ERROR) << "parse ValueNode value don't support input of ref_attr_name";
return false;
anfnode_build_map_[value_node_name] = new_value_node;
}
return true;
}
bool MSANFModelParser::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) {
......@@ -349,24 +506,26 @@ bool MSANFModelParser::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_pr
MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name";
return false;
}
const std::string &ref_attr_name = attr_proto.ref_attr_name();
const onnx::TensorProto &attr_tensor = attr_proto.t();
return GetAttrValueForValueNode(ref_attr_name, value_node_name, attr_tensor);
return GetAttrValueForValueNode(value_node_name, attr_proto);
}
AbstractBasePtr MSANFModelParser::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) {
ShapeVector shape_vec;
const onnx::TensorProto &attr_tensor = attr_proto.t();
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
shape_vec.push_back(attr_tensor.dims(i));
}
tensor::TensorPtr tensor_info =
std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec);
MS_EXCEPTION_IF_NULL(tensor_info);
auto abstract = tensor_info->ToAbstract();
MS_EXCEPTION_IF_NULL(abstract);
return abstract;
std::unordered_map<std::string, abstract::AbstractBasePtr> MSANFModelParser::GetAbstractForCNode(
const onnx::AttributeProto &attr_proto) {
std::unordered_map<std::string, abstract::AbstractBasePtr> kv;
for (int i = 0; i < attr_proto.tensors_size(); ++i) {
ShapeVector shape_vec;
const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
for (int j = 0; j < attr_tensor.dims_size(); ++j) {
shape_vec.push_back(attr_tensor.dims(j));
}
tensor::TensorPtr tensor_info =
std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec);
MS_EXCEPTION_IF_NULL(tensor_info);
auto abstract = tensor_info->ToAbstract();
MS_EXCEPTION_IF_NULL(abstract);
kv.insert(std::pair<string, abstract::AbstractBasePtr>(attr_tensor.name(), abstract));
}
return kv;
}
CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
......@@ -383,21 +542,13 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
MS_EXCEPTION_IF_NULL(prim);
prim->set_instance_name(node_type);
AbstractBasePtr abstract = nullptr;
AbstractBasePtr abstract_first = nullptr;
AbstractBasePtr abstract_second = nullptr;
std::unordered_map<std::string, abstract::AbstractBasePtr> kv;
string shape_ref_attr_name;
for (int i = 0; i < node_proto.attribute_size(); ++i) {
const onnx::AttributeProto &attr_proto = node_proto.attribute(i);
if (attr_proto.name() == kCNodeShapeAttr) {
abstract = GetAbstractForCNode(attr_proto);
continue;
}
if (attr_proto.name() == kCNodeShape1Attr) {
abstract_first = GetAbstractForCNode(attr_proto);
continue;
}
if (attr_proto.name() == kCNodeShape2Attr) {
abstract_second = GetAbstractForCNode(attr_proto);
if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
shape_ref_attr_name = attr_proto.ref_attr_name();
kv = GetAbstractForCNode(attr_proto);
continue;
}
if (!GetAttrValueForCNode(prim, attr_proto)) {
......@@ -419,24 +570,17 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
}
CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(cnode_ptr);
if (node_type == "LayerNorm") {
AbstractBasePtrList elem;
elem.push_back(abstract);
elem.push_back(abstract_first);
elem.push_back(abstract_second);
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
} else if (node_type == "ArgMaxWithValue") {
AbstractBasePtrList elem;
elem.push_back(abstract);
elem.push_back(abstract_first);
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
} else if (nullptr == abstract) {
if (0 == kv.size()) {
AbstractBasePtrList elem;
for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
elem.push_back(cnode_ptr->input(index)->abstract());
}
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
} else if (1 == kv.size()) {
std::unordered_map<std::string, abstract::AbstractBasePtr>::iterator iter = kv.begin();
cnode_ptr->set_abstract(iter->second);
} else {
auto abstract = ParserAttrShape(shape_ref_attr_name, kv);
cnode_ptr->set_abstract(abstract);
}
cnode_ptr->set_fullname_with_scope(fullname_with_scope);
......@@ -471,19 +615,15 @@ bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGra
} else {
const onnx::ValueInfoProto &output_node = importProto.output(0);
const onnx::TypeProto &output_typeproto = output_node.type();
int output_type = output_typeproto.tensor_type().elem_type();
ShapeVector output_shape;
for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); ++i) {
output_shape.push_back(output_typeproto.tensor_type().shape().dim(i).dim_value());
}
tensor::TensorPtr tensor_return =
std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[output_type], output_shape);
inputs.clear();
inputs.push_back(NewValueNode(prim::kPrimReturn));
inputs.push_back(cnode_ptr);
auto return_node = outputFuncGraph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(return_node);
return_node->set_abstract(tensor_return->ToAbstract());
outputFuncGraph->set_return(return_node);
MS_LOG(INFO) << "Construct funcgraph finined, all success!";
}
......
......@@ -52,18 +52,17 @@ class MSANFModelParser {
bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto);
bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor);
bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor);
ValuePtr ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor);
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor);
bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto);
bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
bool ObtainValueNodeInScalarForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
bool GetAttrValueForValueNode(const string &ref_attr_name, const std::string &value_node_name,
const onnx::TensorProto &attr_tensor);
bool GetAttrValueForValueNode(const std::string &value_node_name, const onnx::AttributeProto &attr_tensor);
bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
AbstractBasePtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
std::unordered_map<std::string, abstract::AbstractBasePtr> GetAbstractForCNode(
const onnx::AttributeProto &attr_proto);
std::string producer_name_;
int model_version_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册