提交 5e8aa333 编写于 作者: M Megvii Engine Team

refactor(dnn): refactor winograd output transpose

GitOrigin-RevId: 6d4b225ea54a14c6c5479788b1d2b42a5b9d3cf5
上级 c6eb2e8d
...@@ -235,7 +235,7 @@ void StrategyHelper< ...@@ -235,7 +235,7 @@ void StrategyHelper<
input_filter_compute_type* input_transform_buf, input_filter_compute_type* input_transform_buf,
input_filter_compute_type* transform_mid_buf, input_filter_compute_type* transform_mid_buf,
int ih_start, int iw_start, size_t IH, size_t IW, int ih_start, int iw_start, size_t IH, size_t IW,
size_t IC, size_t unit_idx, size_t nr_units_in_tile, size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile,
size_t m, size_t r, size_t m, size_t r,
const std::vector<float>& interp_points, DType dtype, const std::vector<float>& interp_points, DType dtype,
float rescale) { float rescale) {
...@@ -284,7 +284,7 @@ void StrategyHelper< ...@@ -284,7 +284,7 @@ void StrategyHelper<
const output_compute_type* bias, dst_type* output, const output_compute_type* bias, dst_type* output,
output_compute_type* transform_mid_buf, BiasMode bmode, output_compute_type* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t oh_start, NonlineMode nonline_mode, size_t oh_start,
size_t ow_start, size_t OH, size_t OW, size_t oc_start, size_t ow_start, size_t OH, size_t OW, size_t OC, size_t oc_start,
size_t oc_index, size_t unit_idx, size_t nr_units_in_tile, size_t oc_index, size_t unit_idx, size_t nr_units_in_tile,
size_t m, size_t r, size_t m, size_t r,
const std::vector<float>& interp_points, DType dtype, const std::vector<float>& interp_points, DType dtype,
...@@ -296,7 +296,7 @@ void StrategyHelper< ...@@ -296,7 +296,7 @@ void StrategyHelper<
output_compute_type* mid_buf1 = transform_mid_buf; output_compute_type* mid_buf1 = transform_mid_buf;
output_compute_type* mid_buf2 = transform_mid_buf + alpha * alpha; output_compute_type* mid_buf2 = transform_mid_buf + alpha * alpha;
OutputGetter<output_compute_type, dst_type> getter(dtype); OutputGetter<output_compute_type, dst_type> getter(dtype);
OutputVisitor<layout, format> output_visitor(oc_end - oc_start); OutputVisitor<layout, format> output_visitor(OC);
size_t oc = oc_start + oc_index; size_t oc = oc_start + oc_index;
......
...@@ -6,8 +6,7 @@ ...@@ -6,8 +6,7 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* implied.
*/ */
#pragma once #pragma once
...@@ -44,8 +43,8 @@ public: ...@@ -44,8 +43,8 @@ public:
input_filter_compute_type* input_transform_buf, input_filter_compute_type* input_transform_buf,
input_filter_compute_type* transform_mid_buf, input_filter_compute_type* transform_mid_buf,
int ih_start, int iw_start, size_t IH, size_t IW, int ih_start, int iw_start, size_t IH, size_t IW,
size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile, size_t IC, size_t ic, size_t unit_idx,
size_t m, size_t r, size_t nr_units_in_tile, size_t m, size_t r,
const std::vector<float>& interp_points, DType dtype, const std::vector<float>& interp_points, DType dtype,
float rescale = 1.0f); float rescale = 1.0f);
...@@ -54,7 +53,7 @@ public: ...@@ -54,7 +53,7 @@ public:
const output_compute_type* bias, dst_type* output, const output_compute_type* bias, dst_type* output,
output_compute_type* transform_mid_buf, BiasMode bmode, output_compute_type* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t oh_start, size_t ow_start, NonlineMode nonline_mode, size_t oh_start, size_t ow_start,
size_t OH, size_t OW, size_t oc_start, size_t oc_index, size_t OH, size_t OW, size_t OC, size_t oc_start, size_t oc_index,
size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r, size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r,
const std::vector<float>& interp_points, DType dtype, const std::vector<float>& interp_points, DType dtype,
float input_filter_scale = 1.0f, // input_scale * filter_scale float input_filter_scale = 1.0f, // input_scale * filter_scale
......
...@@ -45,7 +45,6 @@ public: ...@@ -45,7 +45,6 @@ public:
static_cast<fallback::MatrixMulImpl*>(matmul_opr)->algo_pack(); static_cast<fallback::MatrixMulImpl*>(matmul_opr)->algo_pack();
for (auto&& algo : matmul_algos) { for (auto&& algo : matmul_algos) {
if (algo->algoset() == if (algo->algoset() ==
//! TODO: threre should filter MK matmul
MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV) {
continue; continue;
} }
......
...@@ -536,7 +536,6 @@ public: ...@@ -536,7 +536,6 @@ public:
NonlineMode nonline_mode, size_t OH, size_t OW, \ NonlineMode nonline_mode, size_t OH, size_t OW, \
size_t oc_start, size_t oc_end, size_t unit_start_idx, \ size_t oc_start, size_t oc_end, size_t unit_start_idx, \
size_t nr_tiles_in_unit); \ size_t nr_tiles_in_unit); \
}; };
#define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \ #define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \
......
...@@ -186,18 +186,16 @@ struct OutputTransform2X3_NCHW88 { ...@@ -186,18 +186,16 @@ struct OutputTransform2X3_NCHW88 {
float* output, float* transform_mid_buf, float* output, float* transform_mid_buf,
size_t oh_start, size_t ow_start, size_t OH, size_t oh_start, size_t ow_start, size_t OH,
size_t OW, size_t oc_start, size_t oc_end, size_t OW, size_t oc_start, size_t oc_end,
size_t unit_idx, size_t nr_units_in_tile, size_t oc_index, size_t unit_idx,
const DType& src_dtype, const DType& dst_dtype) { size_t nr_units_in_tile, const DType& src_dtype,
const DType& dst_dtype) {
MEGDNN_MARK_USED_VAR(transform_mid_buf); MEGDNN_MARK_USED_VAR(transform_mid_buf);
megdnn_assert(
(oc_end - oc_start) % 8 == 0 && oc_start % 8 == 0 &&
oc_end % 8 == 0,
"Winograd output transform input param is not times of 8!");
Op op(src_dtype, dst_dtype); Op op(src_dtype, dst_dtype);
//! AT * m * A //! AT * m * A
size_t OCB = (oc_end - oc_start) / 8; size_t OCB = (oc_end - oc_start) / 8;
for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) { size_t oc = oc_start + oc_index;
size_t ocb = (oc - oc_start) / 8; size_t ocb = oc_index / 8;
#define cb(m, n) \ #define cb(m, n) \
auto v##m##n = Vector<float, 8>::load( \ auto v##m##n = Vector<float, 8>::load( \
output_transform_buf + \ output_transform_buf + \
...@@ -254,7 +252,6 @@ struct OutputTransform2X3_NCHW88 { ...@@ -254,7 +252,6 @@ struct OutputTransform2X3_NCHW88 {
} while (0); } while (0);
UNROLL_CALL_RAW_D2(2, 2, out_save); UNROLL_CALL_RAW_D2(2, 2, out_save);
} }
}
}; };
#undef CONCAT #undef CONCAT
} // namespace } // namespace
...@@ -315,20 +312,40 @@ void winograd_nchw88_2x3_8x8_f::input(const float* input, ...@@ -315,20 +312,40 @@ void winograd_nchw88_2x3_8x8_f::input(const float* input,
} }
} }
void winograd_nchw88_2x3_8x8_f::output( void winograd_nchw88_2x3_8x8_f::output(const float* output_transform_buf,
const float* output_transform_buf, const float* bias, float* output, const float* bias, float* output,
float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, float* transform_mid_buf, BiasMode bmode,
size_t oh_start, size_t ow_start, size_t OH, size_t OW, size_t oc_start, NonlineMode nonline_mode, size_t OH,
size_t oc_end, size_t unit_idx, size_t nr_units_in_tile) { size_t OW, size_t oc_start,
size_t oc_end, size_t unit_start_idx,
size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...) \ #define cb(_bmode, _nonline_op, ...) \
OutputTransform2X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \ OutputTransform2X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \
__VA_ARGS__); __VA_ARGS__);
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);
size_t OC = oc_end - oc_start;
megdnn_assert(OC % 8 == 0 && oc_start % 8 == 0 && oc_end % 8 == 0,
"Winograd output transform input param is not times of 8!");
for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS( DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_x86_winograd_nchw88_fp32_F23_8x8, cb, SIMDType::AVX2, float, megdnn_x86_winograd_nchw88_fp32_F23_8x8, cb, SIMDType::AVX2,
float, bmode, nonline_mode, output_transform_buf, bias, output, float, float, bmode, nonline_mode, output_transform_buf,
transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, bias, output, transform_mid_buf, oh_start, ow_start, OH, OW,
unit_idx, nr_units_in_tile, src_dtype, dst_dtype); oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype,
dst_dtype);
}
}
#undef cb #undef cb
} }
......
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"
...@@ -19,10 +20,10 @@ ...@@ -19,10 +20,10 @@
#include <x86intrin.h> #include <x86intrin.h>
#ifdef WIN32CMAKE #ifdef WIN32CMAKE
#include <avxintrin.h>
#include <smmintrin.h>
#include <avx2intrin.h> #include <avx2intrin.h>
#include <avxintrin.h>
#include <fmaintrin.h> #include <fmaintrin.h>
#include <smmintrin.h>
#endif #endif
#include "midout.h" #include "midout.h"
...@@ -171,7 +172,7 @@ struct FilterTransform6X3_MCHW88 { ...@@ -171,7 +172,7 @@ struct FilterTransform6X3_MCHW88 {
for (size_t ocb = oc_start / 8; ocb < oc_end / 8; ocb++) { for (size_t ocb = oc_start / 8; ocb < oc_end / 8; ocb++) {
for (size_t icb = 0; icb < ICB; icb++) { for (size_t icb = 0; icb < ICB; icb++) {
for (size_t ic_inner = 0; ic_inner < 8; ic_inner++){ for (size_t ic_inner = 0; ic_inner < 8; ic_inner++) {
const float* fptr = filter + const float* fptr = filter +
(ocb * ICB + icb) * 3 * 3 * 8 * 8 + (ocb * ICB + icb) * 3 * 3 * 8 * 8 +
ic_inner * 8; ic_inner * 8;
...@@ -220,18 +221,17 @@ struct OutputTransform6X3_NCHW88 { ...@@ -220,18 +221,17 @@ struct OutputTransform6X3_NCHW88 {
float* output, float* transform_mid_buf, float* output, float* transform_mid_buf,
size_t oh_start, size_t ow_start, size_t OH, size_t oh_start, size_t ow_start, size_t OH,
size_t OW, size_t oc_start, size_t oc_end, size_t OW, size_t oc_start, size_t oc_end,
size_t unit_idx, size_t nr_units_in_tile, size_t oc_index, size_t unit_idx,
const DType& src_dtype, const DType& dst_dtype) { size_t nr_units_in_tile, const DType& src_dtype,
const DType& dst_dtype) {
MEGDNN_MARK_USED_VAR(transform_mid_buf); MEGDNN_MARK_USED_VAR(transform_mid_buf);
megdnn_assert(
(oc_end - oc_start) % 8 == 0 && oc_start % 8 == 0 &&
oc_end % 8 == 0,
"Winograd output transform input param is not times of 8!");
Op op(src_dtype, dst_dtype); Op op(src_dtype, dst_dtype);
//! AT * m * A //! AT * m * A
size_t OCB = (oc_end - oc_start) / 8; size_t OCB = (oc_end - oc_start) / 8;
for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) { size_t oc = oc_start + oc_index;
size_t ocb = (oc - oc_start) / 8; size_t ocb = oc_index / 8;
#define cb(m, n) \ #define cb(m, n) \
auto v##m##n = Vector<float, 8>::load( \ auto v##m##n = Vector<float, 8>::load( \
output_transform_buf + \ output_transform_buf + \
...@@ -253,8 +253,7 @@ struct OutputTransform6X3_NCHW88 { ...@@ -253,8 +253,7 @@ struct OutputTransform6X3_NCHW88 {
* 0 0.0 0 0 0 1 * 0 0.0 0 0 0 1
*/ */
Vector<float, 8> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, Vector<float, 8> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6;
v5subv6;
#define cb(m) \ #define cb(m) \
v1addv2 = v1##m + v2##m; \ v1addv2 = v1##m + v2##m; \
v1subv2 = v1##m - v2##m; \ v1subv2 = v1##m - v2##m; \
...@@ -318,7 +317,6 @@ struct OutputTransform6X3_NCHW88 { ...@@ -318,7 +317,6 @@ struct OutputTransform6X3_NCHW88 {
} while (0); } while (0);
UNROLL_CALL_RAW_D2(6, 6, out_save); UNROLL_CALL_RAW_D2(6, 6, out_save);
} }
}
}; };
#undef CONCAT #undef CONCAT
} // namespace } // namespace
...@@ -348,7 +346,8 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input, ...@@ -348,7 +346,8 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input,
megdnn_assert(IC % 8 == 0); megdnn_assert(IC % 8 == 0);
// OW = IW + 2 * PW - KERNEL_SIZE + 1 // OW = IW + 2 * PW - KERNEL_SIZE + 1
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); auto units_w =
div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
float* patch = transform_mid_buf; float* patch = transform_mid_buf;
float* patchT = transform_mid_buf + 8 * alpha * alpha; float* patchT = transform_mid_buf + 8 * alpha * alpha;
...@@ -379,25 +378,45 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input, ...@@ -379,25 +378,45 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input,
} }
} }
void winograd_nchw88_6x3_8x8_f::output( void winograd_nchw88_6x3_8x8_f::output(const float* output_transform_buf,
const float* output_transform_buf, const float* bias, float* output, const float* bias, float* output,
float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, float* transform_mid_buf, BiasMode bmode,
size_t oh_start, size_t ow_start, size_t OH, size_t OW, size_t oc_start, NonlineMode nonline_mode, size_t OH,
size_t oc_end, size_t unit_idx, size_t nr_units_in_tile) { size_t OW, size_t oc_start,
size_t oc_end, size_t unit_start_idx,
size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...) \ #define cb(_bmode, _nonline_op, ...) \
OutputTransform6X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \ OutputTransform6X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \
__VA_ARGS__); __VA_ARGS__);
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);
size_t OC = oc_end - oc_start;
megdnn_assert(OC % 8 == 0 && oc_start % 8 == 0 && oc_end % 8 == 0,
"Winograd output transform input param is not times of 8!");
for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS( DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_x86_winograd_nchw88_fp32_F63_8x8, cb, SIMDType::AVX2, float, megdnn_x86_winograd_nchw88_fp32_F63_8x8, cb, SIMDType::AVX2,
float, bmode, nonline_mode, output_transform_buf, bias, output, float, float, bmode, nonline_mode, output_transform_buf,
transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, bias, output, transform_mid_buf, oh_start, ow_start, OH, OW,
unit_idx, nr_units_in_tile, src_dtype, dst_dtype); oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile,
src_dtype, dst_dtype);
}
}
#undef cb #undef cb
} }
} // namespace winograd } // namespace winograd
} // namespace arm_common } // namespace x86
} // namespace megdnn } // namespace megdnn
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册