提交 67575d58 编写于 作者: M Megvii Engine Team

feat(mge/opr): add interpolate bicubic mode

GitOrigin-RevId: f7023a3fd381f36e64702893576702d59c5be2c6
上级 0558b212
......@@ -197,7 +197,7 @@ public:
protected:
//! get origin coord
std::pair<float, int> get_origin_coord(float scale, int size, int idx);
std::pair<float, int> get_origin_coord(float scale, int size, int idx, bool cubic=false);
//! get nearest index in src
int get_nearest_src(float scale, int size, int idx);
......
......@@ -11,6 +11,7 @@
*/
#include "megdnn/handle.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs.h"
#include "src/common/utils.h"
......@@ -29,8 +30,9 @@ void ResizeBase::check_layout_fwd(const TensorLayout& src,
if (param().format == Param::Format::NCHW) {
megdnn_assert(dst.shape[1] == src.shape[1], "%s", errmsg().c_str());
auto imode = param().imode;
megdnn_assert(imode == param::Resize::InterpolationMode::INTER_LINEAR ||
imode == param::Resize::InterpolationMode::NEAREST);
using IMode = param::Resize::InterpolationMode;
megdnn_assert(imode == IMode::INTER_LINEAR || imode == IMode::NEAREST ||
imode == IMode::INTER_CUBIC);
} else if (param().format == Param::Format::NHWC) {
megdnn_assert(dst.shape[3] == src.shape[3], "%s", errmsg().c_str());
} else if (param().format == Param::Format::NCHW4) {
......@@ -66,19 +68,20 @@ void ResizeBackward::check_exec(const TensorLayout& diff,
}
std::pair<float, int> ResizeBase::get_origin_coord(float scale, int size,
int idx) {
int idx, bool cubic) {
//! copy from resize_cv.cpp
float alpha = (idx + 0.5f) / scale - 0.5f;
int origin_idx = static_cast<int>(floor(alpha));
alpha -= origin_idx;
if (origin_idx < 0) {
origin_idx = 0;
alpha = 0;
} else if (origin_idx + 1 >= size) {
origin_idx = size - 2;
alpha = 1;
if (!cubic) {
if (origin_idx < 0) {
origin_idx = 0;
alpha = 0;
} else if (origin_idx + 1 >= size) {
origin_idx = size - 2;
alpha = 1;
}
}
return {alpha, origin_idx};
}
......
/**
* \file dnn/src/common/resize.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/arch.h"
// Host-side fallback for `__forceinline__`: only active when MEGDNN_CC_HOST is
// set and `__host__` is not defined as a macro.
#if MEGDNN_CC_HOST && !defined(__host__)
#if __GNUC__
#define __forceinline__ inline __attribute__((always_inline))
#elif defined(__has_attribute)
// `__has_attribute` has to be probed with `defined` first: on a preprocessor
// that lacks the operator, `__has_attribute(...)` inside `#if` is a hard
// error, even when an earlier `||` operand is already true.
#if __has_attribute(always_inline)
#define __forceinline__ inline __attribute__((always_inline))
#else
#define __forceinline__ inline
#endif
#else
#define __forceinline__ inline
#endif
#endif
namespace megdnn {
namespace resize {
/*!
 * \brief compute the four cubic interpolation weights for one axis
 * \param x fractional offset of the sample position
 *      (NOTE(review): callers appear to pass a value in [0, 1) - confirm)
 * \param[out] coeffs receives 4 weights in coeffs[0..3]
 *
 * The first three weights are evaluated from polynomials in x using the
 * constant -0.75; the last one is derived as 1 minus the other three.
 */
MEGDNN_HOST MEGDNN_DEVICE __forceinline__ void interpolate_cubic(
        float x, float* coeffs) {
    const float a = -0.75f;
    const float x_plus_one = x + 1;   // argument of the coeffs[0] polynomial
    const float one_minus_x = 1 - x;  // mirrored offset used by coeffs[2]
    coeffs[0] = ((a * x_plus_one - 5 * a) * x_plus_one + 8 * a) * x_plus_one -
                4 * a;
    coeffs[1] = ((a + 2) * x - (a + 3)) * x * x + 1;
    coeffs[2] = ((a + 2) * one_minus_x - (a + 3)) * one_minus_x * one_minus_x +
                1;
    // the remaining weight is whatever is left so that the four sum to one
    coeffs[3] = 1.f - coeffs[0] - coeffs[1] - coeffs[2];
}
} // namespace resize
} // namespace megdnn
/* vim: set ft=cpp: */
......@@ -71,7 +71,10 @@ struct RoundingConverter<uint8_t> {
//! convert a float to uint8_t: clamp into [0, 255], round, then cast
__host__ __device__ __forceinline__ uint8_t operator()(float x) const {
#if MEGDNN_CC_HOST
    // with MEGDNN_CC_HOST set, the unqualified calls below resolve to the
    // std:: versions; otherwise they must come from elsewhere (presumably
    // the CUDA built-ins - confirm)
    using std::max;
    using std::min;
    using std::round;
#endif
    //! FIXME!!! check other places
    const float clamped = min(255.0f, max(0.0f, x));
    return static_cast<uint8_t>(round(clamped));
}
};
......
......@@ -11,6 +11,7 @@
#pragma once
#include "src/common/cv/enums.h"
#include "src/common/resize.cuh"
#include "megdnn/basic_types.h"
......@@ -49,15 +50,6 @@ __device__ inline void interpolate_linear_coefs(float x, float* coeffs) {
coeffs[1] = x;
}
//! Fill coeffs[0..3] with the four cubic interpolation weights for the
//! fractional offset x; the last weight is 1 minus the first three.
__host__ __device__ inline void interpolate_cubic_coefs(float x,
                                                        float* coeffs) {
    const float a = -0.75f;
    const float shifted = x + 1;   // argument of the coeffs[0] polynomial
    const float mirrored = 1 - x;  // mirrored offset used by coeffs[2]
    coeffs[0] = ((a * shifted - 5 * a) * shifted + 8 * a) * shifted - 4 * a;
    coeffs[1] = ((a + 2) * x - (a + 3)) * x * x + 1;
    coeffs[2] = ((a + 2) * mirrored - (a + 3)) * mirrored * mirrored + 1;
    coeffs[3] = 1.f - coeffs[0] - coeffs[1] - coeffs[2];
}
__device__ inline void interpolate_lanczos4_coefs(float x, float* coeffs) {
const float s45 = 0.70710678118654752440084436210485;
const float cs[][2] = {{1, 0}, {-s45, -s45}, {0, 1}, {s45, -s45},
......@@ -197,7 +189,7 @@ __device__ inline void interpolate_coefs<INTER_LINEAR>(float x, float* coeffs) {
}
template <>
__device__ inline void interpolate_coefs<INTER_CUBIC>(float x, float* coeffs) {
interpolate_cubic_coefs(x, coeffs);
megdnn::resize::interpolate_cubic(x, coeffs);
}
template <>
__device__ inline void interpolate_coefs<INTER_LANCZOS4>(float x,
......
......@@ -12,6 +12,10 @@
#include "src/cuda/resize/common.h"
#include "src/cuda/utils.cuh"
#include "src/cuda/cv/kernel_common.cuh"
using megdnn::resize::interpolate_cubic;
using megdnn::megcv::saturate;
namespace megdnn {
namespace cuda {
......@@ -72,6 +76,42 @@ __global__ void resize_bwd_nearest_kernel(const float* hidden, float* dst,
}
}
}
/*!
 * \brief backward kernel of the cubic resize: scatter the output gradient
 *      back onto the input gradient
 *
 * Launch layout (as implied by the index computation): blockIdx.z selects the
 * batch item, the x/y thread indices select the output column/row. Threads
 * outside OW x OH do nothing. Each thread loops over all C channels.
 *
 * \param hidden output gradient, indexed as [n][c][oh][ow]
 * \param dst input gradient, indexed as [n][c][ih][iw]. This kernel only
 *      accumulates into it (NOTE(review): presumably zeroed by the caller
 *      before launch - confirm)
 * \param N batch size; not read here, the batch index is blockIdx.z
 * \param scale_h, scale_w forwarded to get_origin_coord for the row/column
 *
 * Accumulation uses atomicAdd because the 4x4 footprints of different output
 * pixels, and the clamped taps of a single pixel, can hit the same dst
 * element.
 */
__global__ void resize_bwd_cubic_kernel(const float* hidden, float* dst, int N,
                                        int C, int IH, int IW, int OH, int OW,
                                        float scale_h, float scale_w) {
    int n = blockIdx.z;
    int ow = blockIdx.x * blockDim.x + threadIdx.x;
    int oh = blockIdx.y * blockDim.y + threadIdx.y;
    hidden += n * C * OH * OW;
    dst += n * C * IH * IW;
    if (ow < OW && oh < OH) {
        float alphah, alphaw;
        int ih0, iw0;
        get_origin_coord(scale_h, IH, oh, alphah, ih0, true);
        get_origin_coord(scale_w, IW, ow, alphaw, iw0, true);
        // the 4-tap window starts one element before the origin index
        ih0--;
        iw0--;
        float h_coeff[4], w_coeff[4];
        interpolate_cubic(alphah, h_coeff);
        interpolate_cubic(alphaw, w_coeff);
        constexpr int ksize = 4;
        for (int c = 0; c < C; ++c) {
            // Read the gradient once per channel instead of once per tap; the
            // compiler cannot hoist this load itself across the atomicAdd
            // calls.
            const float grad = hidden[oh * OW + ow];
            for (int kh = 0; kh < ksize; kh++) {
                // taps falling outside the image are clamped onto the border
                int ih = saturate(ih0 + kh, 0, IH - 1);
                // same evaluation order as (grad * h) * w in the inner loop
                const float grad_h = grad * h_coeff[kh];
                float* dst_row = dst + ih * IW;
                for (int kw = 0; kw < ksize; kw++) {
                    int iw = saturate(iw0 + kw, 0, IW - 1);
                    atomicAdd(dst_row + iw, grad_h * w_coeff[kw]);
                }
            }
            // advance to the next channel plane
            hidden += OH * OW;
            dst += IH * IW;
        }
    }
}
void backward_data_proxy(InterpolationMode imode, const float* diff,
float* grad, int N, int C, int IH, int IW, int OH,
int OW, cudaStream_t stream) {
......@@ -83,13 +123,26 @@ void backward_data_proxy(InterpolationMode imode, const float* diff,
stream));
float scale_h = static_cast<float>(OH) / IH;
float scale_w = static_cast<float>(OW) / IW;
if(imode == InterpolationMode::INTER_LINEAR) {
resize_bwd_linear_kernel<<<blocks, threads, 0, stream>>>(
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
}
else if (imode == InterpolationMode::INTER_NEAREST) {
resize_bwd_nearest_kernel<<<blocks, threads, 0, stream>>>(
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
switch (imode) {
case InterpolationMode::INTER_LINEAR: {
resize_bwd_linear_kernel<<<blocks, threads, 0, stream>>>(
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
break;
}
case InterpolationMode::INTER_NEAREST: {
resize_bwd_nearest_kernel<<<blocks, threads, 0, stream>>>(
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
break;
}
case InterpolationMode::INTER_CUBIC: {
resize_bwd_cubic_kernel<<<blocks, threads, 0, stream>>>(
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
break;
}
default: {
megdnn_throw("unsupported interpolation mode");
break;
}
}
}
after_kernel_launch();
......
......@@ -15,16 +15,19 @@ namespace cuda {
namespace resize {
__device__ inline void get_origin_coord(float scale, int size, int idx,
float& alpha, int& origin_idx) {
float& alpha, int& origin_idx,
bool cubic = false) {
alpha = (idx + 0.5f) / scale - 0.5f;
origin_idx = static_cast<int>(floor(alpha));
alpha -= origin_idx;
if (origin_idx < 0) {
origin_idx = 0;
alpha = 0;
} else if (origin_idx + 1 >= size) {
origin_idx = size - 2;
alpha = 1;
if (!cubic) {
if (origin_idx < 0) {
origin_idx = 0;
alpha = 0;
} else if (origin_idx + 1 >= size) {
origin_idx = size - 2;
alpha = 1;
}
}
}
......
......@@ -147,9 +147,11 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
C, IH, IW, OH, OW, stream);
return;
}
megdnn_assert(param().imode == Param::InterpolationMode::LINEAR ||
param().imode == Param::InterpolationMode::NEAREST,
"unsupported interpolation mode for NCHW format");
megdnn_assert(
param().imode == Param::InterpolationMode::LINEAR ||
param().imode == Param::InterpolationMode::NEAREST ||
param().imode == Param::InterpolationMode::INTER_CUBIC,
"unsupported interpolation mode for NCHW format");
if (src.layout.dtype == dtype::Float32{}) {
resize::forward_proxy(is_nhwc, resize::get_imode((param().imode)),
......
......@@ -8,15 +8,20 @@
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/common/rounding_converter.cuh"
#include "src/common/utils.cuh"
#include "src/cuda/resize/common.cuh"
#include "src/cuda/resize/common.h"
#include "src/common/rounding_converter.cuh"
#include "src/cuda/resize/resize_cv.cuh"
#include "src/cuda/utils.cuh"
#include "src/cuda/cv/kernel_common.cuh"
#include "src/common/resize.cuh"
using namespace megdnn;
using namespace cuda;
using namespace resize;
using namespace megdnn::cuda::resize;
using megdnn::resize::interpolate_cubic;
using megdnn::megcv::saturate;
namespace {
......@@ -81,8 +86,7 @@ __global__ void kern_general_nearest(SrcVisitor src, ctype* __restrict dst,
int iw = get_nearest_src(scale_w, IW, ow);
for (int c = 0; c < C; ++c) {
dst[oh * OW + ow] = output_converter(
sptr[ih * S_IH + iw * S_IW]);
dst[oh * OW + ow] = output_converter(sptr[ih * S_IH + iw * S_IW]);
sptr += S_IC;
dst += OH * OW;
......@@ -90,6 +94,45 @@ __global__ void kern_general_nearest(SrcVisitor src, ctype* __restrict dst,
}
}
/*!
 * \brief forward cubic resize kernel for a strided source
 *
 * blockIdx.z selects the batch item and the x/y thread indices select the
 * output column/row; threads outside OW x OH return without writing. Each
 * thread loops over all C channels and, per channel, sums a 4x4 window of the
 * source weighted by the row and column coefficients. Window indices are
 * clamped to [0, IH - 1] / [0, IW - 1] through saturate.
 *
 * \param src visitor yielding the source pointer for a batch index
 * \param dst output written as [n][c][oh][ow]
 * \param S_IN, S_IC, S_IH, S_IW source strides for batch / channel / row /
 *      column
 * \tparam OutputConverter functor converting the float sum to ctype
 */
template <typename ctype, typename SrcVisitor, typename OutputConverter>
__global__ void kern_general_cubic(SrcVisitor src, ctype* __restrict dst, int C,
                                   int IH, int IW, int OH, int OW, int S_IN,
                                   int S_IC, int S_IH, int S_IW, float scale_h,
                                   float scale_w) {
    OutputConverter output_converter;
    const int ow = blockIdx.x * blockDim.x + threadIdx.x;
    const int oh = blockIdx.y * blockDim.y + threadIdx.y;
    const ctype* __restrict sptr = src.get(blockIdx.z, S_IN);
    dst += blockIdx.z * C * OH * OW;
    if (ow >= OW || oh >= OH) {
        return;
    }

    float frac_h, frac_w;
    int top, left;
    get_origin_coord(scale_h, IH, oh, frac_h, top, true);
    get_origin_coord(scale_w, IW, ow, frac_w, left, true);
    // the 4-tap window starts one element before the origin index
    top -= 1;
    left -= 1;

    float h_coeff[4], w_coeff[4];
    interpolate_cubic(frac_h, h_coeff);
    interpolate_cubic(frac_w, w_coeff);

    constexpr int ksize = 4;
    const int out_offset = oh * OW + ow;
    for (int c = 0; c < C; ++c) {
        float ret = 0;
        for (int kh = 0; kh < ksize; kh++) {
            const int row_offset = saturate(top + kh, 0, IH - 1) * S_IH;
            for (int kw = 0; kw < ksize; kw++) {
                const int iw = saturate(left + kw, 0, IW - 1);
                ret += sptr[row_offset + iw * S_IW] * h_coeff[kh] *
                       w_coeff[kw];
            }
        }
        dst[out_offset] = output_converter(ret);
        // advance to the next channel
        sptr += S_IC;
        dst += OH * OW;
    }
}
template <typename ctype, typename SrcVisitor, typename OutputConverter>
__global__ void kern_general_nhwc(SrcVisitor src, ctype* __restrict dst, int C,
int IH, int IW, int OH, int OW, float scale_h,
......@@ -140,18 +183,31 @@ void dispatch_with_visitor(bool is_nhwc, InterpolationMode imode,
<<<blocks, threads, 0, stream>>>(src, dst, C, IH, IW, OH,
OW, scale_h, scale_w);
} else {
if (imode == InterpolationMode::INTER_LINEAR) {
kern_general_linear<ctype, SrcVisitor,
rounding::RoundingConverter<ctype>>
<<<blocks, threads, 0, stream>>>(
src, dst, C, IH, IW, OH, OW, S_IN, S_IC, S_IH,
S_IW, scale_h, scale_w);
} else if (imode == InterpolationMode::INTER_NEAREST) {
kern_general_nearest<ctype, SrcVisitor,
rounding::RoundingConverter<ctype>>
<<<blocks, threads, 0, stream>>>(
src, dst, C, IH, IW, OH, OW, S_IN, S_IC, S_IH,
S_IW, scale_h, scale_w);
switch (imode) {
case InterpolationMode::INTER_LINEAR:
kern_general_linear<ctype, SrcVisitor,
rounding::RoundingConverter<ctype>>
<<<blocks, threads, 0, stream>>>(
src, dst, C, IH, IW, OH, OW, S_IN, S_IC,
S_IH, S_IW, scale_h, scale_w);
break;
case InterpolationMode::INTER_NEAREST:
kern_general_nearest<ctype, SrcVisitor,
rounding::RoundingConverter<ctype>>
<<<blocks, threads, 0, stream>>>(
src, dst, C, IH, IW, OH, OW, S_IN, S_IC,
S_IH, S_IW, scale_h, scale_w);
break;
case InterpolationMode::INTER_CUBIC:
kern_general_cubic<ctype, SrcVisitor,
rounding::RoundingConverter<ctype>>
<<<blocks, threads, 0, stream>>>(
src, dst, C, IH, IW, OH, OW, S_IN, S_IC,
S_IH, S_IW, scale_h, scale_w);
break;
default:
megdnn_throw("unsupported interpolation mode");
break;
}
}
N -= curr_batch_size;
......@@ -162,8 +218,8 @@ void dispatch_with_visitor(bool is_nhwc, InterpolationMode imode,
template <typename ctype, typename SrcVisitor, typename OutputConverter>
__global__ void kern_general_nchw4(SrcVisitor src, ctype* __restrict dst, int C,
int IH, int IW, int OH, int OW, float scale_h,
float scale_w) {
int IH, int IW, int OH, int OW,
float scale_h, float scale_w) {
OutputConverter output_converter;
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
......@@ -188,10 +244,11 @@ __global__ void kern_general_nchw4(SrcVisitor src, ctype* __restrict dst, int C,
#pragma unroll
for (int c1 = 0; c1 < 4; ++c1) {
dst[o_coor + c1] = output_converter(
sptr[i_coor00 + c1] * (1.0f - alphaw) * (1.0f - alphah) +
sptr[i_coor01 + c1] * alphaw * (1.0f - alphah) +
sptr[i_coor10 + c1] * (1.0f - alphaw) * alphah +
sptr[i_coor11 + c1] * alphaw * alphah);
sptr[i_coor00 + c1] * (1.0f - alphaw) *
(1.0f - alphah) +
sptr[i_coor01 + c1] * alphaw * (1.0f - alphah) +
sptr[i_coor10 + c1] * (1.0f - alphaw) * alphah +
sptr[i_coor11 + c1] * alphaw * alphah);
}
dst += OH * OW * 4;
sptr += IH * IW * 4;
......@@ -250,18 +307,18 @@ void forward_proxy_nchw4(const ctype* src, ctype* dst, int N, int C, int IH,
after_kernel_launch();
}
#define INST(ctype) \
template void forward_proxy(bool, InterpolationMode, const ctype*, ctype*, int, int, int, \
int, int, int, int, int, int, int, \
cudaStream_t);
#define INST(ctype) \
template void forward_proxy(bool, InterpolationMode, const ctype*, ctype*, \
int, int, int, int, int, int, int, int, int, \
int, cudaStream_t);
INST(float)
INST(uint8_t)
INST(int8_t)
#undef INST
#define INST(ctype) \
#define INST(ctype) \
template void forward_proxy_nchw4(const ctype*, ctype*, int, int, int, \
int, int, int, cudaStream_t)
int, int, int, cudaStream_t)
INST(int8_t);
#undef INST
......
......@@ -59,12 +59,14 @@
* ---------------------------------------------------------------------------
*/
#include "src/cuda/cv/kernel_common.cuh"
#include "src/common/resize.cuh"
#include "src/cuda/resize/resize_cv.cuh"
#include "src/cuda/utils.cuh"
using namespace megdnn;
using namespace cuda;
using namespace megcv;
using megdnn::resize::interpolate_cubic;
namespace {
......@@ -126,7 +128,7 @@ __global__ void precompute_cubic_coef_f32(float* dst, float scale,
fr -= sr[tid];
float coef[4];
interpolate_cubic_coefs(fr, coef);
interpolate_cubic(fr, coef);
#pragma unroll
for (int j = 0, index = 0; j < 4; j++, index += size) {
dst[tid + index] = coef[j];
......@@ -144,7 +146,7 @@ __global__ void precompute_cubic_coef_u8(short* dst, float scale, size_t size) {
fr -= sr[tid];
float coef[4];
interpolate_cubic_coefs(fr, coef);
interpolate_cubic(fr, coef);
#pragma unroll
for (int j = 0, index = 0; j < 4; j++, index += size) {
dst[tid + index] = (short)(coef[j] * ONE);
......@@ -406,7 +408,7 @@ __global__ void resize_cubic_32f_kernel_vector(
int sc = floor(fc);
fc -= sc;
float coef_col[4];
interpolate_cubic_coefs(fc, coef_col);
interpolate_cubic(fc, coef_col);
for (int i = 0; i < ELEMENTS_PER_THREADS; i++) {
if (dr >= dst_rows)
......@@ -415,7 +417,7 @@ __global__ void resize_cubic_32f_kernel_vector(
int sr = floor(fr);
fr -= sr;
float coef_row[4];
interpolate_cubic_coefs(fr, coef_row);
interpolate_cubic(fr, coef_row);
float dst_data[CH] = {0};
#pragma unroll
for (int offset_r = 0; offset_r < 4; ++offset_r) {
......@@ -459,7 +461,7 @@ __global__ void resize_cubic_8u_kernel_vector(
short icoef_col[4] = {0};
float coef_col[4];
interpolate_cubic_coefs(fc, coef_col);
interpolate_cubic(fc, coef_col);
#pragma unroll
for (int i = 0; i < 4; i++) {
icoef_col[i] = (short)(coef_col[i] * ONE);
......@@ -473,7 +475,7 @@ __global__ void resize_cubic_8u_kernel_vector(
fr -= sr;
short icoef_row[4];
float coef_row[4];
interpolate_cubic_coefs(fr, coef_row);
interpolate_cubic(fr, coef_row);
#pragma unroll
for (int i = 0; i < 4; i++) {
icoef_row[i] = (short)(coef_row[i] * ONE);
......
......@@ -118,7 +118,7 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
check_exec(src.layout, dst.layout, workspace.size);
if (param().format == param::Resize::Format::NCHW4 ||
(param().format == param::Resize::Format::NCHW &&
param().imode == param::Resize::InterpolationMode::NEAREST)) {
param().imode != param::Resize::InterpolationMode::INTER_LINEAR)) {
naive::ResizeImpl::exec(src, dst, workspace);
return;
}
......
......@@ -9,18 +9,21 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/naive/resize/opr_impl.h"
#include "midout.h"
#include "src/common/cv/enums.h"
#include "src/common/resize.cuh"
#include "src/common/rounding_converter.cuh"
#include "src/common/utils.cuh"
#include "src/naive/handle.h"
#include "src/naive/resize/opr_impl.h"
#include "src/naive/resize/resize_cv.h"
#include "midout.h"
MIDOUT_DECL(megdnn_naive_resize_layout)
MIDOUT_DECL(megdnn_naive_resize_layout_nearest)
MIDOUT_DECL(megdnn_naive_resize_nchw)
using namespace megdnn;
using namespace naive;
using namespace resize;
template <typename ctype>
ResizeImpl::KernParam<ctype> ResizeImpl::KernParam<ctype>::from_tensors(
......@@ -90,20 +93,84 @@ INST(dt_quint8);
#undef INST
template <typename ctype>
void ResizeImpl::kern_nchw_nearest (const KernParam<ctype>& kern_param) {
void ResizeImpl::kern_nchw(const KernParam<ctype>& kern_param,
InterpolationMode imode) {
megdnn_assert(kern_param.format == Format::NCHW);
UNPACK_RESIZE_FWD_KERN_PARAM_WITH_STRIDE(kern_param);
float scale_h = static_cast<float>(OH) / IH;
float scale_w = static_cast<float>(OW) / IW;
rounding::RoundingConverter<ctype> output_converter;
rep(n, N) {
rep(oh, OH) rep(ow, OW) {
auto ih = get_nearest_src(scale_h, IH, oh);
auto iw = get_nearest_src(scale_w, IW, ow);
switch (imode) {
case InterpolationMode::NEAREST: {
auto ih = get_nearest_src(scale_h, IH, oh);
auto iw = get_nearest_src(scale_w, IW, ow);
rep(c, static_cast<int>(C)) {
dptr[c * OH * OW + oh * OW + ow] =
sptr[c * S_IC + ih * S_IH + iw * S_IW];
}
break;
}
case InterpolationMode::INTER_LINEAR: {
auto coord_h = get_origin_coord(scale_h, IH, oh);
auto coord_w = get_origin_coord(scale_w, IW, ow);
float alphah = coord_h.first;
float alphaw = coord_w.first;
int ih0 = coord_h.second;
int ih1 = ih0 + 1;
int iw0 = coord_w.second;
int iw1 = iw0 + 1;
rep(c, static_cast<int>(C)) {
dptr[c * OH * OW + oh * OW + ow] = output_converter(
sptr[c * S_IC + ih0 * S_IH + iw0 * S_IW] *
(1.0f - alphaw) * (1.0f - alphah) +
sptr[c * S_IC + ih0 * S_IH + iw1 * S_IW] *
alphaw * (1.0f - alphah) +
sptr[c * S_IC + ih1 * S_IH + iw0 * S_IW] *
(1.0f - alphaw) * alphah +
sptr[c * S_IC + ih1 * S_IH + iw1 * S_IW] *
alphaw * alphah);
}
break;
}
case InterpolationMode::INTER_CUBIC: {
auto coord_h = get_origin_coord(scale_h, IH, oh, true);
auto coord_w = get_origin_coord(scale_w, IW, ow, true);
float alphah = coord_h.first;
float alphaw = coord_w.first;
rep(c, static_cast<int>(C)) {
dptr[c * OH * OW + oh * OW + ow] = sptr[c * S_IC + ih * S_IH + iw * S_IW];
int ih0 = coord_h.second - 1;
int iw0 = coord_w.second - 1;
float h_coeff[4], w_coeff[4];
interpolate_cubic(alphah, h_coeff);
interpolate_cubic(alphaw, w_coeff);
rep(c, static_cast<int>(C)) {
constexpr int ksize = 4;
float ret = 0;
rep(kh, ksize) {
int h = saturate<int, int>(ih0 + kh, 0, IH - 1);
rep(kw, ksize) {
int w = saturate<int, int>(iw0 + kw, 0, IW - 1);
ret += sptr[c * S_IC + h * S_IH + w * S_IW] *
h_coeff[kh] * w_coeff[kw];
}
}
dptr[c * OH * OW + oh * OW + ow] =
output_converter(ret);
}
break;
}
default:
megdnn_throw("unsupported mode in ResizeBackwardImpl");
break;
}
}
sptr += S_IN;
......@@ -131,40 +198,6 @@ void ResizeImpl::kern_naive(const KernParam<ctype>& kern_param) {
MIDOUT_END();
return;
}
megdnn_assert(kern_param.format == Format::NCHW);
UNPACK_RESIZE_FWD_KERN_PARAM_WITH_STRIDE(kern_param);
rounding::RoundingConverter<ctype> output_converter;
float scale_h = static_cast<float>(OH) / IH;
float scale_w = static_cast<float>(OW) / IW;
rep(n, N) {
rep(oh, OH) rep(ow, OW) {
auto coord_h = get_origin_coord(scale_h, IH, oh);
auto coord_w = get_origin_coord(scale_w, IW, ow);
float alphah = coord_h.first;
float alphaw = coord_w.first;
int ih0 = coord_h.second;
int ih1 = ih0 + 1;
int iw0 = coord_w.second;
int iw1 = iw0 + 1;
rep(c, static_cast<int>(C)) {
dptr[c * OH * OW + oh * OW + ow] = output_converter(
sptr[c * S_IC + ih0 * S_IH + iw0 * S_IW] *
(1.0f - alphaw) * (1.0f - alphah) +
sptr[c * S_IC + ih0 * S_IH + iw1 * S_IW] * alphaw *
(1.0f - alphah) +
sptr[c * S_IC + ih1 * S_IH + iw0 * S_IW] *
(1.0f - alphaw) * alphah +
sptr[c * S_IC + ih1 * S_IH + iw1 * S_IW] * alphaw *
alphah);
}
}
sptr += S_IN;
dptr += C * OH * OW;
}
}
template <typename ctype>
......@@ -290,18 +323,16 @@ void ResizeImpl::kern_naive_nchw4(const KernParam<ctype>& kern_param) {
void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
_megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size);
if (param().format == param::Resize::Format::NCHW &&
param().imode == param::Resize::InterpolationMode::NEAREST) {
#define cb(dt, ct, _midout_iv) \
case DTypeTrait<dt>::enumv: { \
MIDOUT_BEGIN(megdnn_naive_resize_layout_nearest, \
midout_iv(_midout_iv)) { \
auto kparam = KernParam<ct>::from_tensors(param().format, src, \
dst, workspace); \
MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw_nearest(kparam)); \
} \
MIDOUT_END(); \
return; \
if (param().format == param::Resize::Format::NCHW) {
#define cb(dt, ct, _midout_iv) \
case DTypeTrait<dt>::enumv: { \
MIDOUT_BEGIN(megdnn_naive_resize_nchw, midout_iv(_midout_iv)) { \
auto kparam = KernParam<ct>::from_tensors(param().format, src, \
dst, workspace); \
MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw(kparam, param().imode)); \
} \
MIDOUT_END(); \
return; \
}
switch (src.layout.dtype.enumv()) {
......@@ -319,12 +350,10 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
return;
}
#undef cb
#undef cb
}
if ((param().format == param::Resize::Format::NCHW ||
(src.layout[3] != 1 && src.layout[3] != 3) ||
if (((src.layout[3] != 1 && src.layout[3] != 3) ||
!is_nhwc_contig_wc(src.layout)) ||
(param().imode == param::Resize::InterpolationMode::LINEAR)) {
#define cb(dt, ct, _midout_iv) \
......@@ -378,37 +407,73 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
std::memset(sptr, 0, sizeof(float) * N * C * IH * IW);
rep(n, N) {
rep(oh, OH) rep(ow, OW) {
if(param().imode == InterpolationMode::INTER_LINEAR) {
auto coord_h = get_origin_coord(scale_h, IH, oh);
auto coord_w = get_origin_coord(scale_w, IW, ow);
float alphah = coord_h.first;
float alphaw = coord_w.first;
int ih0 = coord_h.second;
int ih1 = ih0 + 1;
int iw0 = coord_w.second;
int iw1 = iw0 + 1;
rep(c, C) {
float hidden = hptr[c * OH * OW + oh * OW + ow];
sptr[c * IH * IW + ih0 * IW + iw0] +=
(1.0f - alphaw) * (1.0f - alphah) * hidden;
sptr[c * IH * IW + ih1 * IW + iw0] +=
(1.0f - alphaw) * alphah * hidden;
sptr[c * IH * IW + ih0 * IW + iw1] +=
alphaw * (1.0f - alphah) * hidden;
sptr[c * IH * IW + ih1 * IW + iw1] +=
alphaw * alphah * hidden;
switch (param().imode) {
case InterpolationMode::INTER_LINEAR: {
auto coord_h = get_origin_coord(scale_h, IH, oh);
auto coord_w = get_origin_coord(scale_w, IW, ow);
float alphah = coord_h.first;
float alphaw = coord_w.first;
int ih0 = coord_h.second;
int ih1 = ih0 + 1;
int iw0 = coord_w.second;
int iw1 = iw0 + 1;
rep(c, C) {
float hidden = hptr[c * OH * OW + oh * OW + ow];
sptr[c * IH * IW + ih0 * IW + iw0] +=
(1.0f - alphaw) * (1.0f - alphah) * hidden;
sptr[c * IH * IW + ih1 * IW + iw0] +=
(1.0f - alphaw) * alphah * hidden;
sptr[c * IH * IW + ih0 * IW + iw1] +=
alphaw * (1.0f - alphah) * hidden;
sptr[c * IH * IW + ih1 * IW + iw1] +=
alphaw * alphah * hidden;
}
break;
}
} else if (param().imode == InterpolationMode::NEAREST) {
auto ih = get_nearest_src(scale_h, IH, oh);
auto iw = get_nearest_src(scale_w, IW, ow);
rep(c, static_cast<int>(C)) {
sptr[c * IH * IW + ih * IW + iw] += hptr[c * OH * OW + oh * OW + ow];
case InterpolationMode::NEAREST: {
auto ih = get_nearest_src(scale_h, IH, oh);
auto iw = get_nearest_src(scale_w, IW, ow);
rep(c, static_cast<int>(C)) {
sptr[c * IH * IW + ih * IW + iw] +=
hptr[c * OH * OW + oh * OW + ow];
}
break;
}
case InterpolationMode::INTER_CUBIC: {
auto coord_h = get_origin_coord(scale_h, IH, oh, true);
auto coord_w = get_origin_coord(scale_w, IW, ow, true);
float alphah = coord_h.first;
float alphaw = coord_w.first;
int ih0 = coord_h.second - 1;
int iw0 = coord_w.second - 1;
float h_coeff[4], w_coeff[4];
interpolate_cubic(alphah, h_coeff);
interpolate_cubic(alphaw, w_coeff);
rep(c, static_cast<int>(C)) {
constexpr int ksize = 4;
rep(kh, ksize) {
int h = saturate<int, int>(ih0 + kh, 0, IH - 1);
rep(kw, ksize) {
int w = saturate<int, int>(iw0 + kw, 0, IW - 1);
sptr[c * IH * IW + h * IW + w] +=
hptr[c * OH * OW + oh * OW + ow] *
h_coeff[kh] * w_coeff[kw];
}
}
}
break;
}
default: {
megdnn_throw("unsupported mode in ResizeBackwardImpl");
break;
}
}
else megdnn_throw("unsupported mode in ResizeBackwardImpl");
}
sptr += C * IH * IW;
hptr += C * OH * OW;
......
......@@ -47,7 +47,7 @@ private:
void kern_naive(const KernParam<ctype>& kern_param);
template <typename ctype>
void kern_nchw_nearest(const KernParam<ctype>& kern_param);
void kern_nchw(const KernParam<ctype>& kern_param, InterpolationMode imode);
template <typename ctype>
void kern_naive_nhwc(const KernParam<ctype>& kern_param);
......
......@@ -68,6 +68,7 @@
#include "src/common/cv/helper.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
#include "src/common/resize.cuh"
MIDOUT_DECL(megdnn_naive_resizecv_imode)
MIDOUT_DECL(megdnn_naive_resizecv_dtype)
......@@ -75,6 +76,7 @@ MIDOUT_DECL(megdnn_naive_resizecv_dtype)
using namespace megdnn;
using namespace naive;
using namespace megcv;
using namespace megdnn::resize;
namespace {
......@@ -383,14 +385,6 @@ using ResizeAreaFunc = void (*)(const Mat<T>& src, Mat<T>& dst,
const DecimateAlpha* ytab, int ytab_size,
const int* yofs);
//! Fill coeffs[0..3] with the four cubic interpolation weights for the
//! fractional offset x; the last weight is 1 minus the first three.
static inline void interpolate_cubic(float x, float* coeffs) {
    const float a = -0.75f;
    const float shifted = x + 1;   // argument of the coeffs[0] polynomial
    const float mirrored = 1 - x;  // mirrored offset used by coeffs[2]
    coeffs[0] = ((a * shifted - 5 * a) * shifted + 8 * a) * shifted - 4 * a;
    coeffs[1] = ((a + 2) * x - (a + 3)) * x * x + 1;
    coeffs[2] = ((a + 2) * mirrored - (a + 3)) * mirrored * mirrored + 1;
    coeffs[3] = 1.f - coeffs[0] - coeffs[1] - coeffs[2];
}
static inline void interpolate_lanczos4(float x, float* coeffs) {
static const double s45 = 0.70710678118654752440084436210485;
static const double cs[][2] = {{1, 0}, {-s45, -s45}, {0, 1}, {s45, -s45},
......
......@@ -43,7 +43,7 @@ TEST_F(CUDA, RESIZE_CV) {
TEST_F(CUDA, RESIZE_FORWARD) {
using namespace resize;
IMode modes[2] = {IMode::INTER_LINEAR, IMode::NEAREST};
IMode modes[] = {IMode::INTER_LINEAR, IMode::NEAREST, IMode::INTER_CUBIC};
for (auto imode : modes) {
std::vector<TestArg> args = get_args(imode);
Checker<Resize> checker(handle_cuda());
......@@ -88,7 +88,7 @@ TEST_F(CUDA, RESIZE_NCHW4) {
}
TEST_F(CUDA, RESIZE_NCHW_WITH_STRIDE) {
IMode modes[2] = {IMode::INTER_LINEAR, IMode::NEAREST};
IMode modes[] = {IMode::INTER_LINEAR, IMode::NEAREST, IMode::INTER_CUBIC};
for (auto imode : modes) {
param::Resize param;
param.format = param::Resize::Format::NCHW;
......@@ -117,7 +117,7 @@ TEST_F(CUDA, RESIZE_NCHW_WITH_STRIDE) {
}
TEST_F(CUDA, RESIZE_BACKWARD) {
IMode modes[2] = {IMode::INTER_LINEAR, IMode::NEAREST};
IMode modes[] = {IMode::INTER_LINEAR, IMode::NEAREST, IMode::INTER_CUBIC};
for (auto imode : modes) {
Checker<ResizeBackward> checker(handle_cuda());
param::Resize param;
......
......@@ -574,19 +574,25 @@ def interpolate(
raise ValueError("under linear mode, size can only be single value")
dsize = size
if not align_corners and mode in ("bilinear", "nearest") and inp.ndim in [4, 5]:
if not align_corners:
# fastpath for interpolate
op = builtin.Resize(
imode="linear" if mode == "bilinear" else "nearest", format="NCHW"
)
mode_map = {
"linear": "linear",
"bilinear": "linear",
"nearest": "nearest",
"bicubic": "cubic",
}
op = builtin.Resize(imode=mode_map[mode], format="NCHW")
shape = astensor1d(dsize, inp, dtype="int32", device=inp.device)
(result,) = apply(op, inp, shape)
return result
oh, ow = dsize[0], dsize[1]
ih, iw = inp.shape[2], inp.shape[3]
if align_corners:
(ret,) = apply(op, inp, shape)
else:
assert mode in [
"linear",
"bilinear",
], "align_corners only support linear or bilinear mode"
oh, ow = dsize[0], dsize[1]
ih, iw = inp.shape[2], inp.shape[3]
hscale = (ih - 1.0) / (oh - 1.0)
wscale = 1.0 * iw / ow
if mode != "linear":
......@@ -607,34 +613,11 @@ def interpolate(
axis=0,
).reshape(1, 3, 3)
weight = broadcast_to(weight, (inp.shape[0], 3, 3))
else:
hscale = 1.0 * ih / oh
wscale = 1.0 * iw / ow
row0 = concat(
[wscale, Tensor(0, dtype="float32", device=inp.device), 0.5 * wscale - 0.5],
axis=0,
).reshape(1, 3)
row1 = concat(
[Tensor(0, dtype="float32", device=inp.device), hscale, 0.5 * hscale - 0.5],
axis=0,
).reshape(1, 3)
weight = concat(
[row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
axis=0,
).reshape(1, 3, 3)
weight = broadcast_to(weight, (inp.shape[0], 3, 3))
weight = weight.astype("float32")
if mode in ["linear", "bilinear"]:
ret = warp_perspective(inp, weight, dsize, interp_mode="linear")
if mode == "linear":
ret = reshape(ret, ret.shape[0:3])
else:
# only NHWC format support "cubic" mode
assert mode == "bicubic"
inp = transpose(inp, (0, 2, 3, 1))
ret = warp_perspective(inp, weight, dsize, format="NHWC", interp_mode="cubic",)
ret = transpose(ret, (0, 3, 1, 2))
if mode == "linear":
ret = reshape(ret, ret.shape[0:3])
return ret
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册