Commit 2f6d5f9e authored by HappyAngel, committed by Xiaoyang LI

speed up fp32 depthwise conv

* update conv_dw

* update

* add conv_depthwise_3x3s1.cc and conv_depthwise_3x3s2.cc

* add conv_depthwise_3x3s1_fp32 and conv_depthwise_3x3s2_fp32

* add new conv_dw

* only support conv_dw with pad = 0 or 1

* add conv_dw_s1 conv_dw_s2 fp32

* update conv_dw, add conv_3x3s1 and conv_3x3s2, pad=[0,1]

* fix format, test=develop

* fix format, test=develop
Parent a3241ca7
......
@@ -78,6 +78,8 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
             conv5x5s2_depthwise_fp32.cc
             conv_depthwise_3x3p0.cc
             conv_depthwise_3x3p1.cc
+            conv_depthwise_3x3s1.cc
+            conv_depthwise_3x3s2.cc
             conv_winograd_3x3.cc
             conv_impl.cc
             softmax.cc
......
......
@@ -25,7 +25,6 @@ namespace paddle {
 namespace lite {
 namespace arm {
 namespace math {
-
 void conv_3x3s1_depthwise_fp32(const float* i_data,
                                float* o_data,
                                int bs,
......
......
@@ -53,6 +53,38 @@ void conv_3x3s2_depthwise_fp32(const float* i_data,
                                const operators::ConvParam& param,
                                ARMContext* ctx);

+void conv_depthwise_3x3s1_fp32(const float* din,
+                               float* dout,
+                               int num,
+                               int ch_out,
+                               int h_out,
+                               int w_out,
+                               int ch_in,
+                               int h_in,
+                               int w_in,
+                               const float* weights,
+                               const float* bias,
+                               int pad,
+                               bool flag_bias,
+                               bool flag_relu,
+                               ARMContext* ctx);
+
+void conv_depthwise_3x3s2_fp32(const float* din,
+                               float* dout,
+                               int num,
+                               int ch_out,
+                               int h_out,
+                               int w_out,
+                               int ch_in,
+                               int h_in,
+                               int w_in,
+                               const float* weights,
+                               const float* bias,
+                               int pad,
+                               bool flag_bias,
+                               bool flag_relu,
+                               ARMContext* ctx);
+
 void conv_depthwise_3x3p0_fp32(const float* din,
                                float* dout,
                                int num,
......
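The two collapsed files below (conv_depthwise_3x3s1.cc and conv_depthwise_3x3s2.cc, per the CMake hunk above) implement these entry points with hand-tuned NEON code. As a semantic reference only, here is a minimal scalar sketch of what a depthwise 3x3 with this parameter list is assumed to compute; the NCHW layout and per-channel filter packing are assumptions inferred from the signature, not code from this patch:

```cpp
#include <algorithm>

// Scalar NCHW depthwise 3x3 reference (illustrative, not the NEON kernel).
// Parameter meanings mirror conv_depthwise_3x3s{1,2}_fp32 as declared above;
// the memory layout is an assumption inferred from the signature.
void depthwise_conv3x3_ref(const float* din, float* dout, int num, int ch,
                           int h_in, int w_in, int h_out, int w_out,
                           const float* weights, const float* bias, int stride,
                           int pad, bool flag_bias, bool flag_relu) {
  for (int n = 0; n < num; ++n) {
    for (int c = 0; c < ch; ++c) {
      const float* src = din + (n * ch + c) * h_in * w_in;
      const float* wgt = weights + c * 9;  // one 3x3 filter per channel
      float* dst = dout + (n * ch + c) * h_out * w_out;
      for (int oh = 0; oh < h_out; ++oh) {
        for (int ow = 0; ow < w_out; ++ow) {
          float acc = flag_bias ? bias[c] : 0.f;
          for (int kh = 0; kh < 3; ++kh) {
            for (int kw = 0; kw < 3; ++kw) {
              int ih = oh * stride - pad + kh;  // zero padding outside bounds
              int iw = ow * stride - pad + kw;
              if (ih >= 0 && ih < h_in && iw >= 0 && iw < w_in) {
                acc += src[ih * w_in + iw] * wgt[kh * 3 + kw];
              }
            }
          }
          dst[oh * w_out + ow] = flag_relu ? std::max(acc, 0.f) : acc;
        }
      }
    }
  }
}
```

A scalar reference like this is also handy as ground truth when unit-testing the optimized kernels.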
This diff is collapsed.
This diff is collapsed.
......
@@ -562,9 +562,19 @@ void conv_depthwise_3x3_fp32(const void* din,
                             const operators::ConvParam& param,
                             ARMContext* ctx,
                             const float* scale) {
+  const int pad_h = param.paddings[0];
+  const int pad_w = param.paddings[1];
+  if (pad_w != pad_h) {
+    LOG(FATAL) << "fp32 depthwise conv3x3 pad_w: " << pad_w
+               << ", pad_h: " << pad_h << " must be equal";
+    return;
+  }
   int stride = param.strides[1];
-  if (stride == 1) {
-    conv_3x3s1_depthwise_fp32(reinterpret_cast<const float*>(din),
+  int pad = pad_w;
+  bool flag_relu = param.fuse_relu;
+  bool flag_bias = param.bias != nullptr;
+  if (stride == 1 && pad < 2) {  // support pad = [0, 1]
+    conv_depthwise_3x3s1_fp32(reinterpret_cast<const float*>(din),
                               reinterpret_cast<float*>(dout),
                               num,
                               ch_out,
......
@@ -575,10 +585,12 @@ void conv_depthwise_3x3_fp32(const void* din,
                               w_in,
                               reinterpret_cast<const float*>(weights),
                               bias,
-                              param,
+                              pad,
+                              flag_bias,
+                              flag_relu,
                               ctx);
-  } else if (stride == 2) {
-    conv_3x3s2_depthwise_fp32(reinterpret_cast<const float*>(din),
+  } else if (stride == 2 && pad < 2) {  // support pad = [0, 1]
+    conv_depthwise_3x3s2_fp32(reinterpret_cast<const float*>(din),
                               reinterpret_cast<float*>(dout),
                               num,
                               ch_out,
......
@@ -589,10 +601,13 @@ void conv_depthwise_3x3_fp32(const void* din,
                               w_in,
                               reinterpret_cast<const float*>(weights),
                               bias,
-                              param,
+                              pad,
+                              flag_bias,
+                              flag_relu,
                               ctx);
   } else {
-    LOG(FATAL) << "fp32 depthwise conv3x3 stride: " << stride << " unsupported";
+    LOG(FATAL) << "fp32 depthwise conv3x3 stride: " << stride
+               << " or pad(<2): " << pad << " unsupported";
   }
 #if 0
   if (pad == 1) {
......
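The rewritten dispatch takes the fast path only for stride 1 or 2 with pad < 2; every other configuration now hits LOG(FATAL). As a self-contained sketch of the output geometry those four configurations imply (standard convolution arithmetic; the helper name here is illustrative, not from the patch):

```cpp
#include <cstdio>
#include <initializer_list>

// Standard conv output size: floor((in + 2*pad - k) / stride) + 1, with k = 3.
int conv3x3_out_dim(int in, int pad, int stride) {
  return (in + 2 * pad - 3) / stride + 1;
}

int main() {
  // The four configurations the new fp32 depthwise path accepts.
  for (int stride : {1, 2}) {
    for (int pad : {0, 1}) {
      std::printf("stride=%d pad=%d: 224 -> %d\n", stride, pad,
                  conv3x3_out_dim(224, pad, stride));
    }
  }
  return 0;
}
```

With a 224x224 input, s1p1 preserves 224 and s2p1 yields 112, the two cases MobileNet-style depthwise blocks rely on.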
......
@@ -37,12 +37,13 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
     auto kh = w_dims[2];
     auto cround = ROUNDUP(oc, cblock);
     weights_.Resize({cround, 1, kh, kw});
-    auto w_data = weights_.mutable_data<float>();
-    auto w_data_in = param.filter->data<float>();
-    lite::arm::math::conv_trans_weights_numc(
-        w_data_in, w_data, oc, 1, cblock, kh * kw);
+    // auto w_data = weights_.mutable_data<float>();
+    // auto w_data_in = param.filter->data<float>();
+    // lite::arm::math::conv_trans_weights_numc(
+    //     w_data_in, w_data, oc, 1, cblock, kh * kw);
     impl_ = lite::arm::math::conv_depthwise_3x3_fp32;
-    flag_trans_weights_ = true;
+    flag_trans_weights_ = false;
+    // flag_trans_weights_ = true;
   } else if (kw == 5) {
     VLOG(5) << "invoke 5x5 dw conv fp32";
     impl_ = lite::arm::math::conv_depthwise_5x5_fp32;
......
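Since flag_trans_weights_ is now false and the conv_trans_weights_numc call is commented out, the new s1/s2 kernels evidently read the filter in its original per-channel layout rather than the channel-block-padded layout the old path packed. For reference, a sketch of the channel rounding that packing relied on (ROUNDUP semantics as commonly defined; the actual macro lives in the repo):

```cpp
// Illustrative round-up to a multiple of the NEON channel block
// (the repo's ROUNDUP macro is assumed to behave like this).
inline int round_up(int x, int block) {
  return ((x + block - 1) / block) * block;
}
// e.g. with oc = 10 and cblock = 4, round_up(10, 4) == 12, so the old
// path resized weights_ to {12, 1, 3, 3} and repacked the filters into
// channel blocks; the new kernels skip that repacking entirely.
```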