From 9d8fd9252dae7c8d48d50cdc439ca4c273e6d38f Mon Sep 17 00:00:00 2001
From: cjh9368
Date: Mon, 7 Sep 2020 19:00:37 +0800
Subject: [PATCH] run transformer decoder successfully

---
 mindspore/lite/src/ops/gather.cc                |  2 +-
 .../kernel/arm/base/fullconnection_base.cc      |  8 +++---
 .../kernel/arm/fp32/arithmetic_self.cc          |  9 -------
 .../runtime/kernel/arm/int8/gather_int8.cc      | 26 ++++----------------
 .../src/runtime/kernel/arm/int8/gather_int8.h   |  3 ---
 .../lite/tools/anf_exporter/anf_exporter.cc     | 13 +++++-----
 .../anf_importer/import_from_meta_graphT.cc     |  2 +-
 7 files changed, 18 insertions(+), 45 deletions(-)

diff --git a/mindspore/lite/src/ops/gather.cc b/mindspore/lite/src/ops/gather.cc
index 4d48933d1..25f0e9d86 100644
--- a/mindspore/lite/src/ops/gather.cc
+++ b/mindspore/lite/src/ops/gather.cc
@@ -91,7 +91,7 @@ int Gather::InferShape(std::vector inputs_, std::vector
   std::vector<int> out_shape{in_shape};
   out_shape.erase(out_shape.begin() + axis);
   for (int i = 0; i < indices_rank; i++) {
-    out_shape.insert(out_shape.begin() + axis, indices_shape[i]);
+    out_shape.insert(out_shape.begin() + axis + i, indices_shape[i]);
   }
   output->set_shape(out_shape);
   return RET_OK;
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc
index 62dec5ad5..e7d8eb3ee 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc
@@ -55,7 +55,7 @@ kernel::LiteKernel *CpuFullConnectionInt8KernelCreator(const std::vector
-  if (input_tensor->data_type() != kNumberTypeUInt8) {
+  if (input_tensor->data_type() != kNumberTypeInt8) {
     MS_LOG(ERROR) << "full connect input type error" << input_tensor->data_type();
     return RET_ERROR;
   }
@@ -63,7 +63,7 @@ int RestoreFullconnectWeight(lite::tensor::Tensor *input_tensor) {
     MS_LOG(ERROR) << "no quant param";
     return RET_ERROR;
   }
-  const auto* quant_data = static_cast<uint8_t *>(input_tensor->Data());
+  const auto* quant_data = static_cast<int8_t *>(input_tensor->Data());
   auto* dequant_data = static_cast<float *>(malloc(input_tensor->DataSize() * sizeof(float)));
   if (dequant_data == nullptr) {
     MS_LOG(ERROR) << "malloc faile";
@@ -108,7 +108,7 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector
   auto *restore_data = weight_tensor->Data();
-  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+  if (!weight_tensor->GetQuantParams().empty()) {
     RestoreFullconnectWeight(inputs.at(kWeightIndex));
   }
   auto kernel = new (std::nothrow) FullconnectionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
@@ -123,7 +123,7 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector
                   << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
     return nullptr;
   }
-  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+  if (!weight_tensor->GetQuantParams().empty()) {
     weight_tensor->FreeData();
     weight_tensor->SetData(restore_data);
   }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc
index 7a183f7d0..db76c4e4e 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc
@@ -116,11 +116,6 @@ int RestoreMulWeight(lite::tensor::Tensor *input_tensor) {
   return RET_OK;
 }
 int ArithmeticSelfCPUKernel::Run() {
-  void *restore_data = nullptr;
-  if (primitive_->GetQuantType() == schema::QuantType_WeightQuant) {
-    restore_data = in_tensors_[1]->Data();
-    RestoreMulWeight(in_tensors_[1]);
-  }
   auto ret = Prepare();
   if (ret != RET_OK) {
"Prepare fail!ret: " << ret; @@ -135,10 +130,6 @@ int ArithmeticSelfCPUKernel::Run() { MS_LOG(ERROR) << "ArithmeticSelfRun error error_code[" << ret << "]"; return ret; } - if (primitive_->GetQuantType() == schema::QuantType_WeightQuant) { - in_tensors_[1]->FreeData(); - in_tensors_[1]->SetData(restore_data); - } return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.cc index 749123770..2b897e38c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.cc @@ -35,29 +35,11 @@ int GatherInt8CPUKernel::Init() { axis_ = (reinterpret_cast(op_parameter_))->axis_; batchDims_ = (reinterpret_cast(op_parameter_))->batchDims_; auto in_quant_args = in_tensors_.at(0)->GetQuantParams(); - auto ind_quant_args = in_tensors_.at(1)->GetQuantParams(); auto out_quant_args = out_tensors_.at(0)->GetQuantParams(); param_.alpha_ = in_quant_args.front().scale / out_quant_args.front().scale; param_.zp_in_ = in_quant_args.front().zeroPoint; param_.zp_out_ = out_quant_args.front().zeroPoint; - auto indices_ptr = reinterpret_cast(in_tensors_.at(1)->Data()); - if (indices_ != nullptr) { - free(indices_); - indices_ = nullptr; - } - int count = in_tensors_.at(1)->ElementsNum(); - indices_ = reinterpret_cast(malloc(count * sizeof(int))); - if (indices_ == nullptr) { - MS_LOG(ERROR) << "Gather Malloc indices_ error!"; - return RET_ERROR; - } - (void)memset(indices_, 0, count * sizeof(int)); - for (int i = 0; i < count; ++i) { - indices_[i] = - static_cast(round((indices_ptr[i] - ind_quant_args.front().zeroPoint) * ind_quant_args.front().scale)); - } - if (!InferShapeDone()) { return RET_OK; } @@ -73,6 +55,7 @@ int GatherInt8CPUKernel::DoGather(int task_id) { auto input_ptr = reinterpret_cast(input_tensor->Data()); auto output_ptr = reinterpret_cast(out_tensor->Data()); + auto indices_ptr = reinterpret_cast(out_tensor->Data()); auto in_shape = input_tensor->shape(); int in_rank = in_shape.size(); @@ -80,8 +63,8 @@ int GatherInt8CPUKernel::DoGather(int task_id) { const int limit = in_shape[axis_]; for (int i = 0; i < indices_element_size; ++i) { - if (indices_[i] >= limit) { - MS_LOG(ERROR) << " indice data: " << indices_[i] << " is not in [ 0, " << limit - 1 << " ]"; + if (indices_ptr[i] >= limit) { + MS_LOG(ERROR) << " indice data: " << indices_ptr[i] << " is not in [ 0, " << limit - 1 << " ]"; return RET_ERROR; } } @@ -103,7 +86,7 @@ int GatherInt8CPUKernel::DoGather(int task_id) { int error_code; input_ptr += thread_stride * limit; output_ptr += thread_stride * indices_element_size; - error_code = GatherInt8(input_ptr, output_ptr, count, inner_size, limit, indices_, indices_element_size, param_); + error_code = GatherInt8(input_ptr, output_ptr, count, inner_size, limit, indices_ptr, indices_element_size, param_); if (error_code != RET_OK) { return RET_ERROR; @@ -127,6 +110,7 @@ int GatherInt8CPUKernel::Run() { MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; return prepare_ret; } + int error_code = ParallelLaunch(THREAD_POOL_DEFAULT, GatherInt8Run, this, thread_count_); if (error_code != RET_OK) { MS_LOG(ERROR) << "Gather function error error_code[" << error_code << "]"; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.h index af00f6c08..9633d6d5a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.h +++ 
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.h
@@ -30,8 +30,6 @@ class GatherInt8CPUKernel : public LiteKernel {
               const mindspore::lite::PrimitiveC *primitive)
       : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {}
   ~GatherInt8CPUKernel() {
-    free(indices_);
-    indices_ = nullptr;
   }
 
   int Init() override;
@@ -40,7 +38,6 @@ class GatherInt8CPUKernel : public LiteKernel {
   int DoGather(int task_id);
 
  private:
-  int *indices_ = nullptr;
   int thread_count_;
   int batchDims_;
   int axis_;
diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
index daa8a341b..aef14f865 100644
--- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc
+++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
@@ -129,7 +129,7 @@ void AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &
   for (auto node : graph_input_nodes_) {
     for (auto input : node->inputIndex) {
       auto tensor = meta_graphT->allTensors[input].get();
-      if (tensor->data.empty()) {
+      if (tensor->nodeType != schema::NodeType_CNode && tensor->data.empty()) {
         tensor->nodeType = schema::NodeType_ValueNode;
         tensor->format = schema::Format_NHWC;
         if (!IsContain(meta_graphT->inputIndex, input)) {
@@ -261,7 +261,6 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr input_anod
     return RET_OK;
   }
   auto paramTensor = std::make_unique<schema::TensorT>();
-  paramTensor->nodeType = schema::NodeType_ValueNode;
   paramTensor->format = schema::Format_NHWC;
   auto abstractBase = paramNode->abstract();
   if (abstractBase == nullptr) {
@@ -341,11 +340,10 @@ int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr
   if (cnode->inputs().size() <= 1) {
     return RET_OK;
   }
-  bool is_graph_input = true;
+  bool is_graph_input = false;
   for (size_t i = 1; i < cnode->inputs().size(); i++) {
     auto input_node = cnode->input(i);
     if (input_node->isa<CNode>()) {
-      is_graph_input = false;
       auto ret = ConvertInputCNode(input_node, fb_node);
       if (ret != RET_OK) {
         MS_LOG(ERROR) << "ConvertInputCNode failed";
@@ -357,6 +355,9 @@ int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr
+      if (!input_node->cast<ParameterPtr>()->has_default()) {
+        is_graph_input = true;
+      }
     } else if (input_node->isa<ValueNode>()) {
       auto ret = ConvertInputValueNode(input_node, meta_graphT, fb_node);
       if (ret != RET_OK) {
@@ -382,7 +383,7 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr
     auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
     for (size_t i = 0; i < tuple->size(); i++) {
       auto msTensor = new schema::TensorT();
-      msTensor->nodeType = schema::NodeType_Parameter;
+      msTensor->nodeType = schema::NodeType_CNode;
       fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size());
       if (tuple->size() == 1) {
         node_id_map_[cnode_name] = meta_graphT->allTensors.size();
@@ -399,7 +400,7 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr
-    ms_tensor->nodeType = schema::NodeType_Parameter;
+    ms_tensor->nodeType = schema::NodeType_CNode;
     fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size());
     node_id_map_[cnode_name] = meta_graphT->allTensors.size();
     meta_graphT->allTensors.emplace_back(ms_tensor);
diff --git a/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc b/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc
index 4f72dbb3b..3cf04acea 100644
--- a/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc
+++ b/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc
@@ -59,8 +59,8 @@ int AnfImporterFromMetaGraphT::ConverterConstTensor() {
       std::memcpy(tensor_data, tensor->data.data(), size);
       param_value->set_tensor_addr(tensor_data);
       param_value->set_tensor_size(size);
+      parameter->set_default_param(param_value);
     }
-    parameter->set_default_param(param_value);
     AddNode(i, parameter);
   }
   return RET_OK;
--
GitLab
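
Note on the gather.cc hunk above: the output shape of Gather replaces the input dimension at `axis` with the full indices shape, and inserting at `begin() + axis + i` keeps the indices dimensions in their original order (a fixed `begin() + axis` insertion point would reverse them). A minimal standalone sketch of that rule, using hypothetical names rather than the MindSpore Lite API:

#include <vector>

// Sketch of the shape rule fixed in Gather::InferShape: erase the gathered axis,
// then splice the indices shape in at that position, preserving order.
std::vector<int> GatherOutShape(const std::vector<int> &in_shape,
                                const std::vector<int> &indices_shape, int axis) {
  std::vector<int> out_shape{in_shape};
  out_shape.erase(out_shape.begin() + axis);
  for (int i = 0; i < static_cast<int>(indices_shape.size()); i++) {
    out_shape.insert(out_shape.begin() + axis + i, indices_shape[i]);
  }
  return out_shape;
}
// Example: in_shape = {4, 5, 6}, indices_shape = {2, 3}, axis = 1  ->  {4, 2, 3, 6}.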
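
Note on the fullconnection_base.cc hunks: the weight-restore path now expects int8 storage and triggers on the presence of quantization parameters rather than on the WeightQuant quant type. The restore itself is the usual affine dequantization, real_value = scale * (quantized - zero_point); a small illustrative sketch with a hypothetical helper, not the RestoreFullconnectWeight signature:

#include <cstdint>
#include <vector>

// Illustrative affine dequantization: real_value = scale * (quantized - zero_point).
// The patch reads the quantized weights as int8_t (previously uint8_t) before applying it.
std::vector<float> DequantizeWeights(const std::vector<int8_t> &q, float scale, int zero_point) {
  std::vector<float> out(q.size());
  for (size_t i = 0; i < q.size(); ++i) {
    out[i] = scale * (static_cast<int>(q[i]) - zero_point);
  }
  return out;
}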