提交 5aa52d38 编写于 作者: M Megvii Engine Team

feat(dnn/rocm): add adaptive pooling opr

GitOrigin-RevId: e844b3e7706af483bbcf9d5cd20f4c009a7f7e27
上级 83cf4ee6
/**
* \file dnn/src/rocm/adaptive_pooling/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/rocm/adaptive_pooling/opr_impl.h"
namespace megdnn {
namespace rocm {
void AdaptivePoolingForwardImpl::exec(_megdnn_tensor_in src,
                                      _megdnn_tensor_out dst,
                                      _megdnn_workspace workspace) {
    //! Adaptive pooling is lowered onto the regular PoolingForward opr:
    //! window/stride/padding are deduced from the src and dst layouts.
    auto pooling_param = deduce_pooling_param(src.layout, dst.layout);
    auto pooling_opr = handle()->create_operator<PoolingForward>();
    pooling_opr->param() = pooling_param;
    pooling_opr->exec(src, dst, workspace);
}
size_t AdaptivePoolingForwardImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& dst) {
    //! Workspace requirement equals that of the underlying PoolingForward
    //! configured with the deduced pooling parameters.
    auto pooling_opr = handle()->create_operator<PoolingForward>();
    pooling_opr->param() = deduce_pooling_param(src, dst);
    size_t workspace_bytes = pooling_opr->get_workspace_in_bytes(src, dst);
    return workspace_bytes;
}
void AdaptivePoolingBackwardImpl::exec(_megdnn_tensor_in src,
                                       _megdnn_tensor_in dst,
                                       _megdnn_tensor_in diff,
                                       _megdnn_tensor_out grad,
                                       _megdnn_workspace workspace) {
    //! Backward pass also delegates to the plain pooling opr; the pooling
    //! parameters are re-deduced from the forward src/dst layouts.
    auto pooling_param = deduce_pooling_param(src.layout, dst.layout);
    auto pooling_opr = handle()->create_operator<PoolingBackward>();
    pooling_opr->param() = pooling_param;
    pooling_opr->exec(src, dst, diff, grad, workspace);
}
size_t AdaptivePoolingBackwardImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& dst,
        const TensorLayout& diff, const TensorLayout& grad) {
    //! Same as forward: the workspace is whatever PoolingBackward needs
    //! for the deduced pooling configuration.
    auto pooling_opr = handle()->create_operator<PoolingBackward>();
    pooling_opr->param() = deduce_pooling_param(src, dst);
    size_t workspace_bytes =
            pooling_opr->get_workspace_in_bytes(src, dst, diff, grad);
    return workspace_bytes;
}
} // namespace rocm
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/rocm/adaptive_pooling/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace rocm {
//! ROCm implementation of adaptive pooling forward. Both methods delegate
//! to the regular PoolingForward opr with parameters deduced from the
//! src/dst layouts (see opr_impl.cpp).
class AdaptivePoolingForwardImpl final : public AdaptivePoolingForward {
public:
using AdaptivePoolingForward::AdaptivePoolingForward;
//! Compute dst = adaptive_pool(src); workspace sized by
//! get_workspace_in_bytes() for the same layouts.
void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
//! Workspace (in bytes) required by exec() for the given layouts.
size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) override;
};
//! ROCm implementation of adaptive pooling backward. Delegates to the
//! regular PoolingBackward opr with deduced parameters (see opr_impl.cpp).
class AdaptivePoolingBackwardImpl final : public AdaptivePoolingBackward {
public:
using AdaptivePoolingBackward::AdaptivePoolingBackward;
//! Compute grad (w.r.t. src) from the forward src/dst pair and the
//! gradient `diff` w.r.t. dst.
void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) override;
//! Workspace (in bytes) required by exec() for the given layouts.
size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst,
const TensorLayout& diff,
const TensorLayout& grad) override;
};
} // namespace rocm
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -22,6 +22,7 @@
#include "src/rocm/elemwise/opr_impl.h"
#include "src/rocm/eye/opr_impl.h"
#include "src/rocm/pooling/opr_impl.h"
#include "src/rocm/adaptive_pooling/opr_impl.h"
#include "src/rocm/reduce/opr_impl.h"
#include "src/rocm/type_cvt/opr_impl.h"
#include "src/rocm/topk/opr_impl.h"
......@@ -160,6 +161,8 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(Eye);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ChecksumForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(PoolingForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(PoolingBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(AdaptivePoolingForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(AdaptivePoolingBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ReduceForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(TypeCvt);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(TopK);
......
/**
* \file dnn/test/rocm/adaptive_pooling.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "hcc_detail/hcc_defs_prologue.h"
#include "test/rocm/fixture.h"
#include "megdnn/tensor_iter.h"
#include "test/common/adaptive_pooling.h"
#include "test/common/checker.h"
#include "src/common/utils.h"
#include "test/rocm/utils.h"
#include "test/rocm/benchmarker.h"
namespace megdnn {
namespace test {
TEST_F(ROCM, ADAPTIVE_POOLING_FORWARD) {
    // Run every predefined adaptive-pooling case in NCHW / float32 and
    // compare the ROCm result against the naive reference implementation.
    using Format = param::AdaptivePooling::Format;
    const DType f32 = dtype::Float32();
    for (auto&& testcase : adaptive_pooling::get_args()) {
        auto run_param = testcase.param;
        run_param.format = Format::NCHW;
        Checker<AdaptivePooling> checker(handle_rocm());
        checker.set_epsilon(1e-2);  // float32 tolerance for pooling sums
        checker.set_param(run_param)
                .set_dtype(0, f32)
                .set_dtype(1, f32)
                .exec(TensorShapeArray{testcase.ishape, testcase.oshape, {}});
    }
}
TEST_F(ROCM, ADAPTIVE_POOLING_BACKWARD) {
// Check the backward opr on every predefined case. Backward needs a
// consistent (src, dst) pair, so a tensor constraint recomputes the
// forward result on the ROCm handle before each backward run.
auto args = adaptive_pooling::get_args();
for (auto&& arg : args) {
Checker<AdaptivePoolingBackward> checker(handle_rocm());
TensorLayout ilayout = TensorLayout(arg.ishape, dtype::Float32());
TensorLayout olayout = TensorLayout(arg.oshape, dtype::Float32());
// Constraint invoked on the checker's input tensors
// (src, dst, diff, grad) before each run: overwrite tensors_orig[1]
// (dst) with the actual forward output for tensors_orig[0] (src).
auto constraint = [this,
arg](CheckerHelper::TensorValueArray& tensors_orig) {
megdnn_assert(tensors_orig.size() == 4);
auto opr = handle_rocm()->create_operator<AdaptivePoolingForward>();
opr->param() = arg.param;
// Device-side scratch tensors for the forward src/dst.
auto tensors_rocm_storage = CheckerHelper::alloc_tensors(
handle_rocm(),
{tensors_orig[0].layout, tensors_orig[1].layout}, 0);
auto&& tensors_rocm = *tensors_rocm_storage;
// Copy src host -> device.
auto span = tensors_rocm[0].layout.span();
auto dst = static_cast<dt_byte*>(tensors_rocm[0].raw_ptr) +
span.low_byte;
auto src = static_cast<const dt_byte*>(tensors_orig[0].raw_ptr) +
span.low_byte;
megdnn_memcpy_H2D(handle_rocm(), dst, src, span.dist_byte());
// Run the forward opr on device with a freshly allocated workspace.
auto workspace_size = opr->get_workspace_in_bytes(
tensors_rocm[0].layout, tensors_rocm[1].layout);
auto workspace_rocm = megdnn_malloc(handle_rocm(), workspace_size);
Workspace workspace{static_cast<dt_byte*>(workspace_rocm),
workspace_size};
opr->exec(tensors_rocm[0], tensors_rocm[1], workspace);
megdnn_free(handle_rocm(), workspace_rocm);
// Copy the computed dst device -> host into the checker's input.
span = tensors_rocm[1].layout.span();
dst = static_cast<dt_byte*>(tensors_orig[1].raw_ptr) +
span.low_byte;
src = static_cast<const dt_byte*>(tensors_rocm[1].raw_ptr) +
span.low_byte;
megdnn_memcpy_D2H(handle_rocm(), dst, src, span.dist_byte());
};
DType dtype = dtype::Float32();
// Tensor order: src, dst, diff (grad of dst), grad (grad of src).
checker.set_tensors_constraint(constraint)
.set_dtype(0, dtype)
.set_dtype(1, dtype)
.set_dtype(2, dtype)
.set_dtype(3, dtype)
.set_param(arg.param)
.exec(TensorShapeArray{ilayout, olayout, olayout, ilayout});
}
}
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册