提交 f943b817 编写于 作者: H Haitao

Add stride slice op support


Former-commit-id: 21179e30e4299f3c5295bfc912984291daa43ce0
上级 45e3998d
......@@ -23,6 +23,7 @@ obj-y+=batchnorm.o
obj-y+=scale.o
obj-y+=logistic.o
obj-y+=detection_postprocess.o
obj-y+=stridedslice.o
obj-y+=fused/
obj-y+=init.o
......
......@@ -48,7 +48,7 @@ extern void RegisterReLuNodeExec(void);
extern void RegisterResizeNodeExec(void);
extern void RegisterLogisticNodeExec(void);
extern void RegisterDetectionPostProcessNodeExec(void);
extern void RegisterStridedSliceNodeExec(void);
#ifdef CONFIG_ARCH_BLAS
extern void RegisterConvBlasNodeExec(void);
extern void RegisterDeconvBlasNodeExec(void);
......@@ -87,7 +87,7 @@ void RegisterCommonOps(void)
RegisterResizeNodeExec();
RegisterLogisticNodeExec();
RegisterDetectionPostProcessNodeExec();
RegisterStridedSliceNodeExec();
#ifdef CONFIG_ARCH_BLAS
RegisterConvBlasNodeExec();
RegisterDeconvBlasNodeExec();
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* License); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* AS IS BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* Copyright (c) 2019, Open AI Lab
* Author: chunyinglv@openailab.com
*/
#include <iostream>
#include <functional>
#include <cstring>
#include <algorithm>
#include "logger.hpp"
#include "node_ops.hpp"
#include "tensor_mem.hpp"
#include "graph.hpp"
#include "operator/stridedslice.hpp"
namespace TEngine {
namespace StridedSliceImpl {
struct StridedSliceOps : public NodeOps
{
bool Run(Node* node) override
{
const Tensor * input_tensor=node->GetInputTensor(0);
Tensor * output_tensor=node->GetOutputTensor(0);
float * input=(float *)get_tensor_mem(input_tensor);
float * output=(float *)get_tensor_mem(output_tensor);
StridedSlice* slice_op = dynamic_cast<StridedSlice*>(node->GetOp());
StridedSliceParam* param = slice_op->GetParam();
const TShape& shape=input_tensor->GetShape();
const std::vector<int>& in_dim = shape.GetDim();
const TShape& shape1=output_tensor->GetShape();
const std::vector<int>& out_dim = shape1.GetDim();
int out_w = out_dim[3];
int out_hw = out_dim[2] * out_w;
int out_chw = out_dim[1] * out_hw;
int in_w = in_dim[3];
int in_hw = in_dim[2] * in_w;
int in_chw = in_dim[1] * in_hw;
for(int n=0;n<out_dim[0];n++)
{
for(int c=0;c<out_dim[1];c++)
{
for(int h=0;h<out_dim[2];h++)
{
for(int w=0;w<out_dim[3];w++)
{
output[n*out_chw + c*out_hw + h*out_w + w ] =
input[(param->begin[0]+ n*param->stride[0])*in_chw +
(param->begin[1]+ c*param->stride[1])*in_hw +
(param->begin[2]+ h*param->stride[2])*in_w +
(param->begin[3]+ w*param->stride[3])];
}
}
}
}
return true;
}
};
NodeOps* SelectFunc(const CPUInfo* cpu_info, Node* node)
{
Tensor* input = node->GetInputTensor(0);
const int data_type = input->GetDataType();
const ExecAttr* exec_attr = any_cast<const ExecAttr*>(node->GetAttr(ATTR_EXEC_ATTR));
if(data_type != TENGINE_DT_FP32 || exec_attr->graph_layout != TENGINE_LAYOUT_NCHW)
return nullptr;
StridedSliceOps* ops = new StridedSliceOps();
return ops;
}
} // namespace StridedSliceImpl
using namespace StridedSliceImpl;
void RegisterStridedSliceNodeExec(void)
{
NodeOpsRegistryManager::RegisterOPImplementor("common", "StridedSlice", StridedSliceImpl::SelectFunc, 1000);
}
} // namespace TEngine
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* License); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* AS IS BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* Copyright (c) 2019, Open AI Lab
* Author: chunyinglv@openailab.com
*/
#ifndef __STRIDEDSLICE_HPP__
#define __STRIDEDSLICE_HPP__
#include "operator.hpp"
#include "stridedslice_param.hpp"
namespace TEngine {
class StridedSlice : public OperatorWithParam<StridedSlice, StridedSliceParam>
{
public:
StridedSlice()
{
name_ = "StridedSlice";
}
StridedSlice(const StridedSlice& src) = default;
virtual ~StridedSlice() {}
bool InferShape(const std::vector<TEngine::TShape>& ishape, std::vector<TEngine::TShape>& oshape,
int layout) override;
};
} // namespace TEngine
#endif
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* License); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* AS IS BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* Copyright (c) 2019, Open AI Lab
* Author: chunyinglv@openailab.com
*/
#ifndef __STRIDEDSLICE_PARAM_HPP__
#define __STRIDEDSLICE_PARAM_HPP__
#include "parameter.hpp"
namespace TEngine {
struct StridedSliceParam : public NamedParam
{
int begin[4];
int end[4];
int stride[4];
};
} // namespace TEngine
#endif
......@@ -44,3 +44,4 @@ obj-y+=squeeze.o
obj-y+=swap_axis.o
obj-y+=tanh.o
obj-y+=gru.o
obj-y+=stridedslice.o
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* License); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* AS IS BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* Copyright (c) 2019, Open AI Lab
* Author: chunyinglv@openailab.com
*/
#include "operator/stridedslice.hpp"
namespace TEngine {
// bool StridedSlice::InferShape(const std::vector<TEngine::TShape>& ishape, std::vector<TEngine::TShape>& oshape, int layout)
// {
// const TShape& input = ishape[0];
// const std::vector<int>& in_dim = input.GetDim();
// std::vector<int> o_dim=input.GetDim();
// o_dim[0]= in_dim[0];//- param_.begin[0] + param_.end[0];
// o_dim[1]= in_dim[1];//- param_.begin[1] + param_.end[1];
// o_dim[2]= in_dim[2];//- param_.begin[2] + param_.end[2];
// o_dim[3]= in_dim[3];//- param_.begin[3] + param_.end[3];
// TShape shape;
// shape.SetDim(o_dim);
// shape.SetDataLayout(input.GetDataLayout());
// oshape[0] = shape;
// return true;
// }
bool StridedSlice::InferShape(const std::vector<TEngine::TShape>& ishape, std::vector<TEngine::TShape>& oshape, int layout)
{
const TShape& input = ishape[0];
const std::vector<int>& in_dim = input.GetDim();
std::vector<int> o_dim(4);
o_dim[0]= (in_dim[0]- param_.begin[0] + param_.end[0])/param_.stride[0];
o_dim[1]= (in_dim[1]- param_.begin[1] + param_.end[1])/param_.stride[1];
o_dim[2]= (in_dim[2]- param_.begin[2] + param_.end[2])/param_.stride[2];
o_dim[3]= (in_dim[3]- param_.begin[3] + param_.end[3])/param_.stride[3];
TShape shape;
shape.SetDim(o_dim);
shape.SetDataLayout(input.GetDataLayout());
oshape[0] = shape;
return true;
}
} // namespace TEngine
......@@ -52,6 +52,7 @@
#include "operator/lstm_param.hpp"
#include "operator/rnn_param.hpp"
#include "operator/gru_param.hpp"
#include "operator/stridedslice_param.hpp"
#include "operator_manager.hpp"
#include "type_name.hpp"
......@@ -3204,6 +3205,37 @@ static bool LoadGRU(TFNode* tf_node, TFGraph& tf_graph, StaticGraph* graph)
return true;
}
static bool LoadStridedSlice(TFNode* tf_node, TFGraph& tf_graph, StaticGraph* graph)
{
StaticNode* node = tf_node->static_node;
for(unsigned int i = 0; i < tf_node->inputs.size(); i++)
{
AddNodeInputTensor(node, tf_node->inputs[i]->static_tensor);
}
StridedSliceParam param = any_cast<StridedSliceParam>(OpManager::GetOpDefParam("StridedSlice"));
int* begins = ( int* )LoadConstParam(tf_node->inputs[1]);
int* ends = ( int* )LoadConstParam(tf_node->inputs[2]);
int* strides = ( int* )LoadConstParam(tf_node->inputs[3]);
// tengine NCHW layout
param.begin[nhwc_axis_swap[0]]=begins[0];
param.end[nhwc_axis_swap[0]] =ends[0];
param.stride[nhwc_axis_swap[0]]=strides[0];
param.begin[nhwc_axis_swap[1]]=begins[1];
param.end[nhwc_axis_swap[1]] =ends[1];
param.stride[nhwc_axis_swap[1]]=strides[1];
param.begin[nhwc_axis_swap[2]]=begins[2];
param.end[nhwc_axis_swap[2]] =ends[2];
param.stride[nhwc_axis_swap[2]]=strides[2];
param.begin[nhwc_axis_swap[3]]=begins[3];
param.end[nhwc_axis_swap[3]] =ends[3];
param.stride[nhwc_axis_swap[3]]=strides[3];
StaticOp* op = CreateStaticOp(graph, "StridedSlice");
SetOperatorParam(op, param);
SetNodeOp(node, op);
return true;
}
} // namespace tf_serializer
......@@ -3245,6 +3277,7 @@ bool TFSerializerRegisterOpLoader(void)
p_tf->RegisterOpLoadMethod("LSTM", op_load_t(LoadLSTM));
p_tf->RegisterOpLoadMethod("RNN", op_load_t(LoadRNN));
p_tf->RegisterOpLoadMethod("GRU", op_load_t(LoadGRU));
p_tf->RegisterOpLoadMethod("StridedSlice", op_load_t(LoadStridedSlice));
return true;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册