提交 5e8aa333 编写于 作者: M Megvii Engine Team

refactor(dnn): refactor winograd output transpose

GitOrigin-RevId: 6d4b225ea54a14c6c5479788b1d2b42a5b9d3cf5
上级 c6eb2e8d
......@@ -235,7 +235,7 @@ void StrategyHelper<
input_filter_compute_type* input_transform_buf,
input_filter_compute_type* transform_mid_buf,
int ih_start, int iw_start, size_t IH, size_t IW,
size_t IC, size_t unit_idx, size_t nr_units_in_tile,
size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile,
size_t m, size_t r,
const std::vector<float>& interp_points, DType dtype,
float rescale) {
......@@ -284,7 +284,7 @@ void StrategyHelper<
const output_compute_type* bias, dst_type* output,
output_compute_type* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t oh_start,
size_t ow_start, size_t OH, size_t OW, size_t oc_start,
size_t ow_start, size_t OH, size_t OW, size_t OC, size_t oc_start,
size_t oc_index, size_t unit_idx, size_t nr_units_in_tile,
size_t m, size_t r,
const std::vector<float>& interp_points, DType dtype,
......@@ -296,7 +296,7 @@ void StrategyHelper<
output_compute_type* mid_buf1 = transform_mid_buf;
output_compute_type* mid_buf2 = transform_mid_buf + alpha * alpha;
OutputGetter<output_compute_type, dst_type> getter(dtype);
OutputVisitor<layout, format> output_visitor(oc_end - oc_start);
OutputVisitor<layout, format> output_visitor(OC);
size_t oc = oc_start + oc_index;
......
......@@ -6,8 +6,7 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
......@@ -44,8 +43,8 @@ public:
input_filter_compute_type* input_transform_buf,
input_filter_compute_type* transform_mid_buf,
int ih_start, int iw_start, size_t IH, size_t IW,
size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile,
size_t m, size_t r,
size_t IC, size_t ic, size_t unit_idx,
size_t nr_units_in_tile, size_t m, size_t r,
const std::vector<float>& interp_points, DType dtype,
float rescale = 1.0f);
......@@ -54,7 +53,7 @@ public:
const output_compute_type* bias, dst_type* output,
output_compute_type* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t oh_start, size_t ow_start,
size_t OH, size_t OW, size_t oc_start, size_t oc_index,
size_t OH, size_t OW, size_t OC, size_t oc_start, size_t oc_index,
size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r,
const std::vector<float>& interp_points, DType dtype,
float input_filter_scale = 1.0f, // input_scale * filter_scale
......
......@@ -45,7 +45,6 @@ public:
static_cast<fallback::MatrixMulImpl*>(matmul_opr)->algo_pack();
for (auto&& algo : matmul_algos) {
if (algo->algoset() ==
//! TODO: MK matmul should be filtered here
MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV) {
continue;
}
......
......@@ -536,7 +536,6 @@ public:
NonlineMode nonline_mode, size_t OH, size_t OW, \
size_t oc_start, size_t oc_end, size_t unit_start_idx, \
size_t nr_tiles_in_unit); \
};
#define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \
......
......@@ -186,18 +186,16 @@ struct OutputTransform2X3_NCHW88 {
float* output, float* transform_mid_buf,
size_t oh_start, size_t ow_start, size_t OH,
size_t OW, size_t oc_start, size_t oc_end,
size_t unit_idx, size_t nr_units_in_tile,
const DType& src_dtype, const DType& dst_dtype) {
size_t oc_index, size_t unit_idx,
size_t nr_units_in_tile, const DType& src_dtype,
const DType& dst_dtype) {
MEGDNN_MARK_USED_VAR(transform_mid_buf);
megdnn_assert(
(oc_end - oc_start) % 8 == 0 && oc_start % 8 == 0 &&
oc_end % 8 == 0,
"Winograd output transform input param is not times of 8!");
Op op(src_dtype, dst_dtype);
//! AT * m * A
size_t OCB = (oc_end - oc_start) / 8;
for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) {
size_t ocb = (oc - oc_start) / 8;
size_t oc = oc_start + oc_index;
size_t ocb = oc_index / 8;
#define cb(m, n) \
auto v##m##n = Vector<float, 8>::load( \
output_transform_buf + \
......@@ -254,7 +252,6 @@ struct OutputTransform2X3_NCHW88 {
} while (0);
UNROLL_CALL_RAW_D2(2, 2, out_save);
}
}
};
#undef CONCAT
} // namespace
......@@ -315,20 +312,40 @@ void winograd_nchw88_2x3_8x8_f::input(const float* input,
}
}
void winograd_nchw88_2x3_8x8_f::output(
const float* output_transform_buf, const float* bias, float* output,
float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode,
size_t oh_start, size_t ow_start, size_t OH, size_t OW, size_t oc_start,
size_t oc_end, size_t unit_idx, size_t nr_units_in_tile) {
void winograd_nchw88_2x3_8x8_f::output(const float* output_transform_buf,
const float* bias, float* output,
float* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t OH,
size_t OW, size_t oc_start,
size_t oc_end, size_t unit_start_idx,
size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...) \
OutputTransform2X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \
__VA_ARGS__);
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);
size_t OC = oc_end - oc_start;
megdnn_assert(OC % 8 == 0 && oc_start % 8 == 0 && oc_end % 8 == 0,
"Winograd output transform input param is not times of 8!");
for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_x86_winograd_nchw88_fp32_F23_8x8, cb, SIMDType::AVX2, float,
float, bmode, nonline_mode, output_transform_buf, bias, output,
transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end,
unit_idx, nr_units_in_tile, src_dtype, dst_dtype);
megdnn_x86_winograd_nchw88_fp32_F23_8x8, cb, SIMDType::AVX2,
float, float, bmode, nonline_mode, output_transform_buf,
bias, output, transform_mid_buf, oh_start, ow_start, OH, OW,
oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype,
dst_dtype);
}
}
#undef cb
}
......
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/common/unroll_macro.h"
......@@ -19,10 +20,10 @@
#include <x86intrin.h>
#ifdef WIN32CMAKE
#include <avxintrin.h>
#include <smmintrin.h>
#include <avx2intrin.h>
#include <avxintrin.h>
#include <fmaintrin.h>
#include <smmintrin.h>
#endif
#include "midout.h"
......@@ -171,7 +172,7 @@ struct FilterTransform6X3_MCHW88 {
for (size_t ocb = oc_start / 8; ocb < oc_end / 8; ocb++) {
for (size_t icb = 0; icb < ICB; icb++) {
for (size_t ic_inner = 0; ic_inner < 8; ic_inner++){
for (size_t ic_inner = 0; ic_inner < 8; ic_inner++) {
const float* fptr = filter +
(ocb * ICB + icb) * 3 * 3 * 8 * 8 +
ic_inner * 8;
......@@ -220,18 +221,17 @@ struct OutputTransform6X3_NCHW88 {
float* output, float* transform_mid_buf,
size_t oh_start, size_t ow_start, size_t OH,
size_t OW, size_t oc_start, size_t oc_end,
size_t unit_idx, size_t nr_units_in_tile,
const DType& src_dtype, const DType& dst_dtype) {
size_t oc_index, size_t unit_idx,
size_t nr_units_in_tile, const DType& src_dtype,
const DType& dst_dtype) {
MEGDNN_MARK_USED_VAR(transform_mid_buf);
megdnn_assert(
(oc_end - oc_start) % 8 == 0 && oc_start % 8 == 0 &&
oc_end % 8 == 0,
"Winograd output transform input param is not times of 8!");
Op op(src_dtype, dst_dtype);
//! AT * m * A
size_t OCB = (oc_end - oc_start) / 8;
for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) {
size_t ocb = (oc - oc_start) / 8;
size_t oc = oc_start + oc_index;
size_t ocb = oc_index / 8;
#define cb(m, n) \
auto v##m##n = Vector<float, 8>::load( \
output_transform_buf + \
......@@ -253,8 +253,7 @@ struct OutputTransform6X3_NCHW88 {
* 0 0.0 0 0 0 1
*/
Vector<float, 8> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6,
v5subv6;
Vector<float, 8> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6;
#define cb(m) \
v1addv2 = v1##m + v2##m; \
v1subv2 = v1##m - v2##m; \
......@@ -318,7 +317,6 @@ struct OutputTransform6X3_NCHW88 {
} while (0);
UNROLL_CALL_RAW_D2(6, 6, out_save);
}
}
};
#undef CONCAT
} // namespace
......@@ -348,7 +346,8 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input,
megdnn_assert(IC % 8 == 0);
// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
auto units_w =
div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
float* patch = transform_mid_buf;
float* patchT = transform_mid_buf + 8 * alpha * alpha;
......@@ -379,25 +378,45 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input,
}
}
void winograd_nchw88_6x3_8x8_f::output(
const float* output_transform_buf, const float* bias, float* output,
float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode,
size_t oh_start, size_t ow_start, size_t OH, size_t OW, size_t oc_start,
size_t oc_end, size_t unit_idx, size_t nr_units_in_tile) {
void winograd_nchw88_6x3_8x8_f::output(const float* output_transform_buf,
const float* bias, float* output,
float* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t OH,
size_t OW, size_t oc_start,
size_t oc_end, size_t unit_start_idx,
size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...) \
OutputTransform6X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \
__VA_ARGS__);
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);
size_t OC = oc_end - oc_start;
megdnn_assert(OC % 8 == 0 && oc_start % 8 == 0 && oc_end % 8 == 0,
"Winograd output transform input param is not times of 8!");
for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_x86_winograd_nchw88_fp32_F63_8x8, cb, SIMDType::AVX2, float,
float, bmode, nonline_mode, output_transform_buf, bias, output,
transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end,
unit_idx, nr_units_in_tile, src_dtype, dst_dtype);
megdnn_x86_winograd_nchw88_fp32_F63_8x8, cb, SIMDType::AVX2,
float, float, bmode, nonline_mode, output_transform_buf,
bias, output, transform_mid_buf, oh_start, ow_start, OH, OW,
oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile,
src_dtype, dst_dtype);
}
}
#undef cb
}
} // namespace winograd
} // namespace arm_common
} // namespace x86
} // namespace megdnn
// vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册