提交 5e8aa333 编写于 作者: M Megvii Engine Team

refactor(dnn): refactor winograd output transpose

GitOrigin-RevId: 6d4b225ea54a14c6c5479788b1d2b42a5b9d3cf5
上级 c6eb2e8d
......@@ -235,7 +235,7 @@ void StrategyHelper<
input_filter_compute_type* input_transform_buf,
input_filter_compute_type* transform_mid_buf,
int ih_start, int iw_start, size_t IH, size_t IW,
size_t IC, size_t unit_idx, size_t nr_units_in_tile,
size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile,
size_t m, size_t r,
const std::vector<float>& interp_points, DType dtype,
float rescale) {
......@@ -284,7 +284,7 @@ void StrategyHelper<
const output_compute_type* bias, dst_type* output,
output_compute_type* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t oh_start,
size_t ow_start, size_t OH, size_t OW, size_t oc_start,
size_t ow_start, size_t OH, size_t OW, size_t OC, size_t oc_start,
size_t oc_index, size_t unit_idx, size_t nr_units_in_tile,
size_t m, size_t r,
const std::vector<float>& interp_points, DType dtype,
......@@ -296,7 +296,7 @@ void StrategyHelper<
output_compute_type* mid_buf1 = transform_mid_buf;
output_compute_type* mid_buf2 = transform_mid_buf + alpha * alpha;
OutputGetter<output_compute_type, dst_type> getter(dtype);
OutputVisitor<layout, format> output_visitor(oc_end - oc_start);
OutputVisitor<layout, format> output_visitor(OC);
size_t oc = oc_start + oc_index;
......
......@@ -6,8 +6,7 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
......@@ -44,8 +43,8 @@ public:
input_filter_compute_type* input_transform_buf,
input_filter_compute_type* transform_mid_buf,
int ih_start, int iw_start, size_t IH, size_t IW,
size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile,
size_t m, size_t r,
size_t IC, size_t ic, size_t unit_idx,
size_t nr_units_in_tile, size_t m, size_t r,
const std::vector<float>& interp_points, DType dtype,
float rescale = 1.0f);
......@@ -54,7 +53,7 @@ public:
const output_compute_type* bias, dst_type* output,
output_compute_type* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t oh_start, size_t ow_start,
size_t OH, size_t OW, size_t oc_start, size_t oc_index,
size_t OH, size_t OW, size_t OC, size_t oc_start, size_t oc_index,
size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r,
const std::vector<float>& interp_points, DType dtype,
float input_filter_scale = 1.0f, // input_scale * filter_scale
......
......@@ -45,7 +45,6 @@ public:
static_cast<fallback::MatrixMulImpl*>(matmul_opr)->algo_pack();
for (auto&& algo : matmul_algos) {
if (algo->algoset() ==
//! TODO: threre should filter MK matmul
MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV) {
continue;
}
......
......@@ -536,7 +536,6 @@ public:
NonlineMode nonline_mode, size_t OH, size_t OW, \
size_t oc_start, size_t oc_end, size_t unit_start_idx, \
size_t nr_tiles_in_unit); \
};
#define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \
......
......@@ -186,58 +186,56 @@ struct OutputTransform2X3_NCHW88 {
float* output, float* transform_mid_buf,
size_t oh_start, size_t ow_start, size_t OH,
size_t OW, size_t oc_start, size_t oc_end,
size_t unit_idx, size_t nr_units_in_tile,
const DType& src_dtype, const DType& dst_dtype) {
size_t oc_index, size_t unit_idx,
size_t nr_units_in_tile, const DType& src_dtype,
const DType& dst_dtype) {
MEGDNN_MARK_USED_VAR(transform_mid_buf);
megdnn_assert(
(oc_end - oc_start) % 8 == 0 && oc_start % 8 == 0 &&
oc_end % 8 == 0,
"Winograd output transform input param is not times of 8!");
Op op(src_dtype, dst_dtype);
//! AT * m * A
size_t OCB = (oc_end - oc_start) / 8;
for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) {
size_t ocb = (oc - oc_start) / 8;
size_t oc = oc_start + oc_index;
size_t ocb = oc_index / 8;
#define cb(m, n) \
auto v##m##n = Vector<float, 8>::load( \
output_transform_buf + \
(m * alpha + n) * OCB * nr_units_in_tile * 8 + \
ocb * nr_units_in_tile * 8 + unit_idx * 8);
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb);
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb);
#undef cb
//! 1 1 1 0 v00 v01 v02 v03 1 0
//! 0 1 -1 1 v10 v11 v12 v13 1 1
//! v20 v21 v22 v23 1 -1
//! v30 v31 v32 v33 0 1
//! 1 1 1 0 v00 v01 v02 v03 1 0
//! 0 1 -1 1 v10 v11 v12 v13 1 1
//! v20 v21 v22 v23 1 -1
//! v30 v31 v32 v33 0 1
#define cb(m) \
auto t0##m = v0##m + v1##m + v2##m; \
auto t1##m = v1##m - v2##m + v3##m;
UNROLL_CALL_NOWRAPPER(4, cb);
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
#define cb(m) \
v##m##0 = t##m##0 + t##m##1 + t##m##2; \
v##m##1 = t##m##1 - t##m##2 + t##m##3;
UNROLL_CALL_NOWRAPPER(2, cb);
UNROLL_CALL_NOWRAPPER(2, cb);
#undef cb
Vector<float, 8> vbias;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
vbias = Vector<float, 8>::load(bias + oc);
Vector<float, 8> vbias;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
vbias = Vector<float, 8>::load(bias + oc);
#define cb(m, n) v##m##n += vbias;
UNROLL_CALL_RAW_D2(2, 2, cb);
UNROLL_CALL_RAW_D2(2, 2, cb);
#undef cb
}
if (bmode != BiasMode::BIAS) {
}
if (bmode != BiasMode::BIAS) {
#define cb(m, n) v##m##n = op(CONCAT(v##m, n).value);
UNROLL_CALL_RAW_D2(2, 2, cb);
UNROLL_CALL_RAW_D2(2, 2, cb);
#undef cb
}
}
#define out_save(oho, owo) \
do { \
size_t oh = oh_start + oho; \
......@@ -252,8 +250,7 @@ struct OutputTransform2X3_NCHW88 {
ow * 8); \
} \
} while (0);
UNROLL_CALL_RAW_D2(2, 2, out_save);
}
UNROLL_CALL_RAW_D2(2, 2, out_save);
}
};
#undef CONCAT
......@@ -315,20 +312,40 @@ void winograd_nchw88_2x3_8x8_f::input(const float* input,
}
}
void winograd_nchw88_2x3_8x8_f::output(
const float* output_transform_buf, const float* bias, float* output,
float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode,
size_t oh_start, size_t ow_start, size_t OH, size_t OW, size_t oc_start,
size_t oc_end, size_t unit_idx, size_t nr_units_in_tile) {
void winograd_nchw88_2x3_8x8_f::output(const float* output_transform_buf,
const float* bias, float* output,
float* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t OH,
size_t OW, size_t oc_start,
size_t oc_end, size_t unit_start_idx,
size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...) \
OutputTransform2X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \
__VA_ARGS__);
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_x86_winograd_nchw88_fp32_F23_8x8, cb, SIMDType::AVX2, float,
float, bmode, nonline_mode, output_transform_buf, bias, output,
transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end,
unit_idx, nr_units_in_tile, src_dtype, dst_dtype);
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);
size_t OC = oc_end - oc_start;
megdnn_assert(OC % 8 == 0 && oc_start % 8 == 0 && oc_end % 8 == 0,
"Winograd output transform input param is not times of 8!");
for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_x86_winograd_nchw88_fp32_F23_8x8, cb, SIMDType::AVX2,
float, float, bmode, nonline_mode, output_transform_buf,
bias, output, transform_mid_buf, oh_start, ow_start, OH, OW,
oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype,
dst_dtype);
}
}
#undef cb
}
......
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/common/unroll_macro.h"
......@@ -19,10 +20,10 @@
#include <x86intrin.h>
#ifdef WIN32CMAKE
#include <avxintrin.h>
#include <smmintrin.h>
#include <avx2intrin.h>
#include <avxintrin.h>
#include <fmaintrin.h>
#include <smmintrin.h>
#endif
#include "midout.h"
......@@ -40,7 +41,7 @@ struct InputTransform6X3_NCHW88 {
int ih_start, int iw_start, size_t IH, size_t IW,
size_t ic, size_t IC) {
MEGDNN_MARK_USED_VAR(patch);
size_t IW8 = IW * 8; //! For nchw88 mode
size_t IW8 = IW * 8; //! For nchw88 mode
size_t iw8_start = iw_start * 8; //! For nchw88 mode
size_t icb = ic / 8;
if (!(inner && ic + 8 < IC)) {
......@@ -171,7 +172,7 @@ struct FilterTransform6X3_MCHW88 {
for (size_t ocb = oc_start / 8; ocb < oc_end / 8; ocb++) {
for (size_t icb = 0; icb < ICB; icb++) {
for (size_t ic_inner = 0; ic_inner < 8; ic_inner++){
for (size_t ic_inner = 0; ic_inner < 8; ic_inner++) {
const float* fptr = filter +
(ocb * ICB + icb) * 3 * 3 * 8 * 8 +
ic_inner * 8;
......@@ -220,41 +221,39 @@ struct OutputTransform6X3_NCHW88 {
float* output, float* transform_mid_buf,
size_t oh_start, size_t ow_start, size_t OH,
size_t OW, size_t oc_start, size_t oc_end,
size_t unit_idx, size_t nr_units_in_tile,
const DType& src_dtype, const DType& dst_dtype) {
size_t oc_index, size_t unit_idx,
size_t nr_units_in_tile, const DType& src_dtype,
const DType& dst_dtype) {
MEGDNN_MARK_USED_VAR(transform_mid_buf);
megdnn_assert(
(oc_end - oc_start) % 8 == 0 && oc_start % 8 == 0 &&
oc_end % 8 == 0,
"Winograd output transform input param is not times of 8!");
Op op(src_dtype, dst_dtype);
//! AT * m * A
size_t OCB = (oc_end - oc_start) / 8;
for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) {
size_t ocb = (oc - oc_start) / 8;
size_t oc = oc_start + oc_index;
size_t ocb = oc_index / 8;
#define cb(m, n) \
auto v##m##n = Vector<float, 8>::load( \
output_transform_buf + \
(m * alpha + n) * OCB * nr_units_in_tile * 8 + \
ocb * nr_units_in_tile * 8 + unit_idx * 8);
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
#undef cb
/**
* A
*
* 1 0 0 0 0 0
* 1 1 1 1 1 1
* 1 -1 1 -1 1 -1
* 1 2 4 8 16 32
* 1 -2 4 -8 16 -32
* 1 0.5 0.25 0.125 0.0625 0.03125
* 1 -0.5 0.25 -0.125 0.0625 -0.03125
* 0 0.0 0 0 0 1
*/
Vector<float, 8> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6,
v5subv6;
/**
* A
*
* 1 0 0 0 0 0
* 1 1 1 1 1 1
* 1 -1 1 -1 1 -1
* 1 2 4 8 16 32
* 1 -2 4 -8 16 -32
* 1 0.5 0.25 0.125 0.0625 0.03125
* 1 -0.5 0.25 -0.125 0.0625 -0.03125
* 0 0.0 0 0 0 1
*/
Vector<float, 8> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6;
#define cb(m) \
v1addv2 = v1##m + v2##m; \
v1subv2 = v1##m - v2##m; \
......@@ -269,7 +268,7 @@ struct OutputTransform6X3_NCHW88 {
auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \
auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m;
UNROLL_CALL_NOWRAPPER(8, cb);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
#define cb(m) \
......@@ -286,22 +285,22 @@ struct OutputTransform6X3_NCHW88 {
v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \
v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7;
UNROLL_CALL_NOWRAPPER(6, cb);
UNROLL_CALL_NOWRAPPER(6, cb);
#undef cb
Vector<float, 8> vbias;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
vbias = Vector<float, 8>::load(bias + oc);
Vector<float, 8> vbias;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
vbias = Vector<float, 8>::load(bias + oc);
#define cb(m, n) v##m##n += vbias;
UNROLL_CALL_RAW_D2(6, 6, cb);
UNROLL_CALL_RAW_D2(6, 6, cb);
#undef cb
}
if (bmode != BiasMode::BIAS) {
}
if (bmode != BiasMode::BIAS) {
#define cb(m, n) v##m##n = op(CONCAT(v##m, n).value);
UNROLL_CALL_RAW_D2(6, 6, cb);
UNROLL_CALL_RAW_D2(6, 6, cb);
#undef cb
}
}
#define out_save(oho, owo) \
do { \
size_t oh = oh_start + oho; \
......@@ -316,8 +315,7 @@ struct OutputTransform6X3_NCHW88 {
ow * 8); \
} \
} while (0);
UNROLL_CALL_RAW_D2(6, 6, out_save);
}
UNROLL_CALL_RAW_D2(6, 6, out_save);
}
};
#undef CONCAT
......@@ -348,7 +346,8 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input,
megdnn_assert(IC % 8 == 0);
// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
auto units_w =
div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
float* patch = transform_mid_buf;
float* patchT = transform_mid_buf + 8 * alpha * alpha;
......@@ -379,25 +378,45 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input,
}
}
void winograd_nchw88_6x3_8x8_f::output(
const float* output_transform_buf, const float* bias, float* output,
float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode,
size_t oh_start, size_t ow_start, size_t OH, size_t OW, size_t oc_start,
size_t oc_end, size_t unit_idx, size_t nr_units_in_tile) {
void winograd_nchw88_6x3_8x8_f::output(const float* output_transform_buf,
const float* bias, float* output,
float* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t OH,
size_t OW, size_t oc_start,
size_t oc_end, size_t unit_start_idx,
size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...) \
OutputTransform6X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \
__VA_ARGS__);
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_x86_winograd_nchw88_fp32_F63_8x8, cb, SIMDType::AVX2, float,
float, bmode, nonline_mode, output_transform_buf, bias, output,
transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end,
unit_idx, nr_units_in_tile, src_dtype, dst_dtype);
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);
size_t OC = oc_end - oc_start;
megdnn_assert(OC % 8 == 0 && oc_start % 8 == 0 && oc_end % 8 == 0,
"Winograd output transform input param is not times of 8!");
for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_x86_winograd_nchw88_fp32_F63_8x8, cb, SIMDType::AVX2,
float, float, bmode, nonline_mode, output_transform_buf,
bias, output, transform_mid_buf, oh_start, ow_start, OH, OW,
oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile,
src_dtype, dst_dtype);
}
}
#undef cb
}
} // namespace winograd
} // namespace arm_common
} // namespace x86
} // namespace megdnn
// vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册