提交 89bb65fd 编写于 作者: M Megvii Engine Team 提交者: wanchenxi

fix(dnn): fix softplus bwd kernel

GitOrigin-RevId: 1f01ab5592f29ead271d02f7de15cc1c8a65df44
上级 73d03189
......@@ -85,6 +85,15 @@ __device__ __host__ inline float gelu_grad(float x, float dy) {
return dy * (normcdf_v + x * phi);
}
//! grad of softplus
__device__ __host__ inline float softplus_grad(float x, float dy) {
float logg = -dy * expf(-fabs(x)) / (1.f + expf(-fabs(x)));
float grad0 = x > 0.f ? logg : -logg;
float relux = x < 0.f ? 0.f : x;
float grad1 = relux > 0.f ? dy : 0.f;
return grad0 + grad1;
}
__device__ __host__ inline bool feq(float a, float b) {
return fabsf(a - b) < 1e-6;
}
......@@ -287,7 +296,7 @@ DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y));
DEF_KERN_FLOAT(ASINH_GRAD, y / sqrt(x * x + 1.f));
DEF_KERN_FLOAT(ACOSH_GRAD, y / sqrt(x * x - 1.f));
DEF_KERN_FLOAT(ATANH_GRAD, y / (1.f - x * x));
DEF_KERN_FLOAT(SOFTPLUS_GRAD, y* expf(x) / (1.f + expf(x)));
DEF_KERN_FLOAT(SOFTPLUS_GRAD, softplus_grad(x, y));
DEF_KERN_FLOAT(RELU6_GRAD, x <= ctype(0) ? ctype(0) : (x >= ctype(6) ? ctype(0) : y));
DEF_KERN_FLOAT(
HSIGMOID_GRAD,
......
......@@ -397,7 +397,7 @@ def origin_softplus(inp: mge.tensor) -> mge.tensor:
def test_subgraph_elemwise_mode():
def _test_allclose(func, ori_func):
targets = np.array(2)
inp = np.random.randn(2, 256, 10, 16).astype("float32")
inp = np.random.uniform(size=(2, 16, 10, 16)).astype(np.float32)
ori_inp = mge.tensor(inp)
mge_inp = mge.tensor(inp)
......
......@@ -559,21 +559,12 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
}
case Mode::RELU6:
RET(EL2(RELU6_GRAD, i0, og));
case Mode::SOFTPLUS: {
auto abse = EL1(EXP, EL1(NEGATE, EL1(ABS, i0)));
auto logg = og * abse / (1 + abse);
auto absg = EL2(ABS_GRAD, i0, EL1(NEGATE, logg));
RET(EL2(ADD, absg, EL2(SWITCH_GT0, EL1(RELU, i0), og)));
}
case Mode::SOFTPLUS:
RET(EL2(SOFTPLUS_GRAD, i0, og));
case Mode::HSIGMOID:
RET(EL2(HSIGMOID_GRAD, i0, og));
case Mode::LOGSIGMOID: {
og = EL1(NEGATE, og);
auto abse = EL1(EXP, EL1(NEGATE, EL1(ABS, i0)));
auto logg = og * abse / (1 + abse);
auto absg = EL2(ABS_GRAD, i0, EL1(NEGATE, logg));
RET(EL2(SUB, absg, EL2(SWITCH_GT0, EL1(RELU, EL1(NEGATE, i0)), og)));
}
case Mode::LOGSIGMOID:
RET(EL2(SOFTPLUS_GRAD, -i0, og));
case Mode::SQRT:
RET(og / EL1(SQRT, i0) / 2);
case Mode::SQUARE:
......
......@@ -77,6 +77,14 @@ float do_fuse_add_h_swish(float x, float y) {
return z * fmaxf(fminf(z + 3.f, 6.f), 0.f) / 6.f;
}
float do_softplus_grad(float x, float y) {
float logg = -y * expf(-fabs(x)) / (1.f + expf(-fabs(x)));
float grad0 = x > 0.f ? logg : -logg;
float relux = x < 0.f ? 0.f : x;
float grad1 = relux > 0.f ? y : 0.f;
return grad0 + grad1;
}
template <typename T>
T do_shl(T, T); // undefined
template <typename T>
......
......@@ -61,7 +61,7 @@ DEF_TRAIT(GELU_GRAD, do_gelu_grad(x, y))
DEF_TRAIT(ASINH_GRAD, y / std::sqrt(x * x + 1))
DEF_TRAIT(ACOSH_GRAD, y / std::sqrt(x * x - 1))
DEF_TRAIT(ATANH_GRAD, y / (1 - x * x))
DEF_TRAIT(SOFTPLUS_GRAD, y* std::exp(x) / (1.f + std::exp(x)))
DEF_TRAIT(SOFTPLUS_GRAD, do_softplus_grad(x, y))
DEF_TRAIT(RELU6_GRAD, x <= 0.f ? 0.f : (x >= 6.f ? 0.f : y))
DEF_TRAIT(HSIGMOID_GRAD, x <= -3.f ? 0.f : (x >= 3.f ? 0.f : (y / 6.f)))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册