Commit 022dbea8 authored by Megvii Engine Team

feat(opr): add masked_fill op

GitOrigin-RevId: 47cd068b9e2220448c2c2735001ce6416f06252f
Parent 9a6ba334
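For orientation, a minimal megdnn-level usage sketch of the new operator (not part of the diff; the handle and tensor setup are assumed, and the call sequence follows the `MaskedFill` interface declared below):

// Usage sketch only -- assumes a valid megdnn::Handle* handle and
// pre-allocated TensorNDs: `origin` (any computing dtype), `mask`
// (dtype::Bool, shape equal to a leading prefix of origin's shape) and
// `dst` (same layout as `origin`).
auto opr = handle->create_operator<megdnn::MaskedFill>();
opr->param().value = 1.0f;  // param::Fill carries the fill value
TensorLayout dst_layout;
opr->deduce_layout(origin.layout, mask.layout, dst_layout);
opr->exec(origin, mask, dst);  // dst element = mask ? value : origin element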
......@@ -141,6 +141,7 @@ union OperatorParam {
param.Softmax = 90,
param.Diag = 91,
param.GroupNorm = 92,
param.Fill = 93,
}
table Operator {
......
......@@ -1392,6 +1392,31 @@ protected:
void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
};
class MaskedFill : public OperatorBase {
DEF_OPR_PARAM(Fill);
DEF_OPR_IMPL(MaskedFill, OperatorBase, 2, 1);
public:
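    //! Fill elements of `origin` where the Bool mask `index` is true with
    //! param().value; the result is written to `dst`.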
virtual void exec(
_megdnn_tensor_in origin, _megdnn_tensor_in index,
_megdnn_tensor_out dst) = 0;
void exec(
_megdnn_tensor_in origin, _megdnn_tensor_in index, _megdnn_tensor_out dst,
_megdnn_workspace /*workspace*/) {
exec(origin, index, dst);
}
virtual size_t get_workspace_in_bytes(
const TensorLayout& origin, const TensorLayout& index,
const TensorLayout& dest) = 0;
void deduce_layout(
const TensorLayout& origin, const TensorLayout& index, TensorLayout& dest);
protected:
void check_exec(
const TensorLayout& origin, const TensorLayout& index,
const TensorLayout& dest);
};
/*!
* \brief standard padding operator
 * Inputs must have the same dtype, and the output tensor shape must be greater than or equal
......
......@@ -218,7 +218,8 @@ private:
cb(RegionRestrictedConvolutionBackwardData) \
cb(RegionRestrictedConvolutionBackwardFilter) \
cb(GroupNormForward) \
cb(GroupNormBackward)
cb(GroupNormBackward) \
cb(MaskedFill)
// clang-format on
/*!
......
#include "megdnn/oprs.h"
#include "src/common/utils.h"
namespace megdnn {
void MaskedFill::deduce_layout(
const TensorLayout& origin, const TensorLayout& /*index*/, TensorLayout& dest) {
dest = TensorLayout(origin, origin.dtype);
}
void MaskedFill::check_exec(
const TensorLayout& origin, const TensorLayout& index,
const TensorLayout& dest) {
megdnn_assert_contiguous(index);
megdnn_assert_contiguous(dest);
megdnn_assert(index.dtype == dtype::Bool());
megdnn_assert(origin.ndim >= index.ndim);
bool correct_index_shape = true;
for (size_t i = 0; i < index.ndim; i++) {
correct_index_shape = correct_index_shape && origin.shape[i] == index.shape[i];
}
megdnn_assert(correct_index_shape, "unsupported index shape");
bool supported_dtype = false;
#define cb(Dtype) supported_dtype = supported_dtype || (origin.dtype == Dtype());
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(megdnn::dtype::Bool)
#undef cb
megdnn_assert(supported_dtype, "unsupported dtype");
}
} // namespace megdnn
\ No newline at end of file
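In short, check_exec above accepts a contiguous Bool mask whose shape equals a leading prefix of origin's shape; an illustrative sketch of accepted and rejected combinations:

// Illustrative shape combinations for check_exec:
//   origin (2, 3, 2, 1), index (2, 3)       -> OK (mask covers leading dims)
//   origin (2, 3, 2, 1), index (2, 3, 2, 1) -> OK (elementwise mask)
//   origin (2, 3, 2, 1), index (3, 2)       -> asserts "unsupported index shape"
//   origin (2, 3, 2, 1), index of Int32     -> asserts: index.dtype must be Bool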
......@@ -144,6 +144,7 @@ DEF(RegionRestrictedConvolutionBackwardData, 5, true, false);
DEF(RegionRestrictedConvolutionBackwardFilter, 5, true, false);
DEF(GroupNormForward, 6, true, true);
DEF(GroupNormBackward, 8, true, true);
DEF(MaskedFill, 3, false, true);
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -44,6 +44,7 @@
#include "src/cuda/lrn/opr_impl.h"
#include "src/cuda/lsq/opr_impl.h"
#include "src/cuda/mask_conv/opr_impl.h"
#include "src/cuda/masked_fill/opr_impl.h"
#include "src/cuda/matrix_inverse/opr_impl.h"
#include "src/cuda/matrix_mul/opr_impl.h"
#include "src/cuda/max_tensor_diff/opr_impl.h"
......@@ -178,6 +179,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(ParamPackConcat);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MaxTensorDiff);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MaskConvForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MaskPropagate);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MaskedFill);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Convolution3DForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Convolution3DBackwardData);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Convolution3DBackwardFilter);
......
#include "./kern.cuh"
namespace megdnn {
namespace cuda {
#define cb(_dtype) \
INST_RUN_ELEMWISE( \
MaskedFillScalarKernOp<DTypeTrait<_dtype>::ctype>, \
DTypeTrait<_dtype>::ctype, 1);
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
} // namespace cuda
} // namespace megdnn
#pragma once
#include "src/cuda/elemwise_helper.cuh"
#include "src/cuda/utils.cuh"
#if MEGDNN_CC_HOST
#include "megdnn/oprs.h"
#endif
namespace megdnn {
namespace cuda {
template <typename ctype>
struct MaskedFillScalarKernOp {
using VectTypeTrait = elemwise_intl::VectTypeTrait<ctype>;
typedef typename VectTypeTrait::vect_type vect_type;
ctype* output;
bool* mask;
ctype value;
uint32_t mask_stride;
__device__ __forceinline__ void operator()(uint32_t idx, ctype orig) {
        output[idx] = mask[idx / mask_stride]
                ? value
                : orig;  //! equivalently: (1 - mask) * orig + mask * value
}
__device__ __forceinline__ void operator()(uint32_t idx, vect_type orig) {
ctype a = mask[(idx) / mask_stride] ? value : orig.x;
ctype b = mask[(idx + 1) / mask_stride] ? value : orig.y;
ctype g = mask[(idx + 2) / mask_stride] ? value : orig.z;
ctype r = mask[(idx + 3) / mask_stride] ? value : orig.w;
*(vect_type*)(&output[idx]) = VectTypeTrait::make_vector(a, b, g, r);
}
#if MEGDNN_CC_HOST
MaskedFillScalarKernOp(
const TensorND& output, const TensorND& mask, ctype value,
uint32_t mask_stride)
: output{output.ptr<ctype>()},
mask{mask.ptr<bool>()},
value{value},
mask_stride{mask_stride} {}
#endif
};
} // namespace cuda
} // namespace megdnn
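A worked example of the mask_stride addressing used by the kernel above (assuming the contiguous layouts that the CUDA exec asserts):

// origin shape (2, 3, 2, 1) is contiguous with strides (6, 2, 1, 1); with a
// (2, 3) mask, exec passes mask_stride = stride[index.ndim - 1] = stride[1] = 2.
// So output idx 0..1 read mask[0], idx 2..3 read mask[1], ..., idx 10..11 read
// mask[5]: each mask element governs the trailing elements it broadcasts over.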
#include "./opr_impl.h"
#include "./kern.cuh"
#include "src/common/utils.h"
namespace megdnn {
namespace cuda {
void MaskedFillImpl::exec(
_megdnn_tensor_in origin, _megdnn_tensor_in index, _megdnn_tensor_out dest) {
check_exec(origin.layout, index.layout, dest.layout);
megdnn_assert(index.layout.is_contiguous());
uint32_t mask_stride = TensorLayout(origin.layout, origin.layout.dtype)
.stride[index.layout.ndim - 1];
ElemwiseOpParamN<1> ele_param;
ele_param[0] = origin;
ele_param.init_from_given_tensor();
auto stream = cuda_stream(handle());
#define cb(DType) \
if (origin.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
auto value = static_cast<T>(param().value); \
run_elemwise<MaskedFillScalarKernOp<T>, T, 1>( \
ele_param, stream, {dest, index, value, mask_stride}); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
}
} // namespace cuda
} // namespace megdnn
\ No newline at end of file
#pragma once
#include "megdnn/oprs.h"
#include "src/cuda/utils.h"
namespace megdnn {
namespace cuda {
class MaskedFillImpl : public MaskedFill {
public:
using MaskedFill::MaskedFill;
void exec(_megdnn_tensor_in origin, _megdnn_tensor_in index, _megdnn_tensor_out dst)
override;
size_t get_workspace_in_bytes(
const TensorLayout&, const TensorLayout&, const TensorLayout&) override {
return 0;
}
};
} // namespace cuda
} // namespace megdnn
......@@ -48,6 +48,7 @@
#include "src/naive/lstm/opr_impl.h"
#include "src/naive/lstm_cell/opr_impl.h"
#include "src/naive/mask_conv/opr_impl.h"
#include "src/naive/masked_fill/opr_impl.h"
#include "src/naive/matrix_inverse/opr_impl.h"
#include "src/naive/matrix_mul/opr_impl.h"
#include "src/naive/max_tensor_diff/opr_impl.h"
......
#include "src/naive/masked_fill/opr_impl.h"
#include <cmath>
#include "megdnn/tensor_iter.h"
#include "src/common/elemwise_helper.cuh"
#include "src/common/utils.h"
#include "src/naive/handle.h"
namespace {
using namespace megdnn;
template <typename T>
void forward_impl(const ElemwiseOpParamN<3> src, const T value) {
auto inp = tensor_iter_valonly<T>(src[0]).begin();
auto out = tensor_iter_valonly<T>(src[1]).begin();
auto mask = tensor_iter_valonly<bool>(src[2]).begin();
size_t total = src[0].layout.total_nr_elems();
for (size_t i = 0; i < total; ++i) {
*out = *mask ? value : *inp;
++inp;
++out;
++mask;
}
}
} // namespace
namespace megdnn {
namespace naive {
void MaskedFillImpl::exec(
_megdnn_tensor_in origin, _megdnn_tensor_in index, _megdnn_tensor_out dest) {
check_exec(origin.layout, index.layout, dest.layout);
megdnn_assert(origin.layout.is_contiguous() && index.layout.is_contiguous());
ElemwiseOpParamN<3> src;
src[0] = origin;
src[1] = dest;
src[2] = index;
if (src[2].layout.ndim < src[0].layout.ndim) {
for (size_t n = src[2].layout.ndim; n < src[0].layout.ndim; n++)
src[2].layout.add_axis_cont_inplace(n);
}
src[2].layout = src[2].layout.broadcast(origin.layout);
#define cb(DType) \
if (origin.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
auto value = static_cast<T>(param().value); \
forward_impl<T>(src, value); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
}
} // namespace naive
} // namespace megdnn
\ No newline at end of file
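The naive path expands the mask to origin's full shape before iterating; for example:

// A (2, 3) mask over a (2, 3, 2, 1) origin first gains trailing unit axes via
// add_axis_cont_inplace -> shape (2, 3, 1, 1); broadcast() then gives the new
// axes stride 0, so the tensor iterator revisits each mask element for every
// trailing position of origin.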
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace naive {
class MaskedFillImpl : public MaskedFill {
public:
using MaskedFill::MaskedFill;
void exec(_megdnn_tensor_in origin, _megdnn_tensor_in index, _megdnn_tensor_out dst)
override;
size_t get_workspace_in_bytes(
const TensorLayout&, const TensorLayout&, const TensorLayout&) override {
return 0;
}
};
} // namespace naive
} // namespace megdnn
#include "test/cuda/fixture.h"
#include "test/common/checker.h"
namespace megdnn {
namespace test {
TEST_F(CUDA, MASKEDFILL) {
using Param = MaskedFill::Param;
Param param;
param.value = 1.0;
Checker<MaskedFill> checker(handle_cuda());
checker.set_epsilon(1e-2);
auto run = [&](DType d) {
for (size_t A : {2, 3})
for (size_t B : {6, 9}) {
checker.set_param(param)
.set_dtype(0, d)
.set_dtype(1, dtype::Bool())
.set_dtype(2, d)
.execs({{A, B, 2, 1}, {A, B}, {A, B, 2, 1}});
}
for (size_t A : {2, 3})
for (size_t B : {6, 9}) {
checker.set_param(param)
.set_dtype(0, d)
.set_dtype(1, dtype::Bool())
.set_dtype(2, d)
.execs({{A, B, 2, 1}, {A, B, 2, 1}, {A, B, 2, 1}});
}
};
run(dtype::Float32());
run(dtype::Float16());
run(dtype::BFloat16());
run(dtype::Uint8());
}
} // namespace test
} // namespace megdnn
#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/naive/fixture.h"
namespace megdnn {
namespace test {
TEST_F(NAIVE, MASKEDFILL) {
Checker<MaskedFill> checker(handle(), true);
MaskedFill::Param param;
param.value = 0.2;
checker.set_param(param).exect(
Testcase{
TensorValue(
{2, 3, 2, 1}, dtype::Float32(),
{3.3179, 0.109, -0.5855, 0.2566, -1.2897, 1.2683, -2.0587,
0.0711, -0.1169, 0.2509, -0.2393, 0.0876}), // input
TensorValue({2}, dtype::Bool(), {false, true}), // mask
{}},
Testcase{
{},
{},
TensorValue(
{2, 3, 2, 1}, dtype::Float32(),
{3.3179, 0.109, -0.5855, 0.2566, -1.2897, 1.2683, 0.2, 0.2,
0.2, 0.2, 0.2, 0.2}), // output
});
checker.set_param(param).exect(
Testcase{
TensorValue(
{1, 3, 1, 2}, dtype::Float32(),
{-2.4348, -1.7948, 0.5223, 0.0932, -0.2955,
-0.0492}), // input
TensorValue({1, 3}, dtype::Bool(), {false, true, false}), // mask
{},
},
Testcase{
{},
{},
TensorValue(
{1, 3, 1, 2}, dtype::Float32(),
{-2.4348, -1.7948, 0.2, 0.2, -0.2955, -0.0492}),
});
}
} // namespace test
} // namespace megdnn
......@@ -1290,6 +1290,29 @@ py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) {
py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_hdl) {
py::object org_shape = getattr(inp_hdl, "shape");
py::object val = py::reinterpret_borrow<py::object>(val_hdl);
bool is_val_scalar = false;
float value;
if (PyLong_Check(val.ptr())) {
is_val_scalar = true;
value = static_cast<float>(PyLong_AsDouble(val.ptr()));
}
if (PyFloat_Check(val.ptr())) {
is_val_scalar = true;
value = static_cast<float>(PyFloat_AsDouble(val.ptr()));
}
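// fastpath: `tensor[bool_mask] = python_scalar` lowers to a single
// MaskedFill apply instead of the generic setitem path below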
if (TensorWrapper::try_cast(idx_hdl.ptr()) && is_bool_dtype(idx_hdl.ptr()) &&
is_val_scalar && enable_fastpath(inp_hdl)) {
std::vector<PyObject*> q(3);
std::shared_ptr<OpDef> Op = MaskedFill::make(value);
py::object maskedfill = py::cast(Op);
q[0] = maskedfill.ptr();
q[1] = inp_hdl.ptr();
q[2] = idx_hdl.ptr();
py::tuple result =
py::reinterpret_steal<py::object>(py_apply(NULL, q.data(), q.size()));
py::object res = result[0];
return res;
}
if (!TensorWrapper::try_cast(val.ptr())) {
val = _Const(val_hdl, getattr(inp_hdl, "dtype"), getattr(inp_hdl, "device"));
}
......
......@@ -292,4 +292,62 @@ OP_TRAIT_REG(Split, Split, opr::Split)
} // namespace split
namespace masked_fill {
cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op_def = def.cast_final_safe<MaskedFill>();
OperatorNodeConfig config{op_def.make_name()};
mgb_assert(inputs.size() == 2);
return opr::MaskedFill::make(inputs[0], inputs[1], op_def.param(), config)
.node()
->owner_opr();
}
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
layout_checker[0] = [](const TensorLayout& layout) {
return layout.is_contiguous();
};
return layout_checker;
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
return {{{{input_descs[0].layout, input_descs[0].layout.dtype},
input_descs[0].comp_node}},
input_descs[0].layout.ndim != 0};
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& op = def.cast_final_safe<MaskedFill>();
auto&& inp = inputs[0];
auto&& mask = inputs[1];
TensorLayout outlayout(inp->layout(), inp->layout().dtype);
auto output = Tensor::make(outlayout, inp->comp_node());
DnnOprCaller<megdnn::MaskedFill> dnn_opr{inp->comp_node(), op.param()};
dnn_opr.exec_with_ws(inp, mask, output);
return {output};
}
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node = &node_->cast_final_safe<opr::MaskedFill>();
return MaskedFill::make(node->param());
}
OP_TRAIT_REG(MaskedFill, MaskedFill, mgb::opr::MaskedFill)
.get_input_layout_constraint(get_input_layout_constraint)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.apply_on_var_node(apply_on_var_node)
.make_from_op_node(make_from_op_node)
.fallback();
} // namespace masked_fill
} // namespace mgb::imperative
8dd504f360fd3d3bfb560c970b568153 ../../dnn/scripts/opr_param_defs.py
7d6df1c8e50a22ef2c36b7ea89daa9c5 ../../src/core/include/megbrain/ir/ops.td
f30ae9494b4bf3363cd74d9396acaf49 generated/opdef.h.inl
cb27f486b28a099221f38c6fcaa06a44 generated/opdef.cpp.inl
adb758acd1147f213db7f0cb1b708773 generated/opdef.py.inl
30ad8e75a5994edf9ec46387c6285312 generated/opdef.cpy.inl
4bd0317fd84b5065c8d88a7ca6241908 ../../src/core/include/megbrain/ir/ops.td
cb32cb1ef6b2ef4a7defaeb02ecd36e3 generated/opdef.h.inl
1c0230f60ddf3459de2aa4e16c1e2957 generated/opdef.cpp.inl
f6cbfd25f0d61e7b94c687733f5ae9b9 generated/opdef.py.inl
3a023199c39ea5611975b902a882bbba generated/opdef.cpy.inl
71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h
......@@ -4788,6 +4788,43 @@ OP_TRAIT_REG(MagicMindRuntime, MagicMindRuntime)
.props(MagicMindRuntime_props_impl)
.make_name(MagicMindRuntime_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(MaskedFill);
namespace {
size_t MaskedFill_hash_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<MaskedFill>();
static_cast<void>(op_);
size_t val = mgb::hash(op_.dyn_typeinfo());
val = mgb::hash_pair_combine(val, mgb::hash(op_.value));
return val;
}
bool MaskedFill_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
auto &&a_ = lhs_.cast_final_safe<MaskedFill>(),
&&b_ = rhs_.cast_final_safe<MaskedFill>();
static_cast<void>(a_);
static_cast<void>(b_);
if (a_.value != b_.value) return false;
return true;
}
std::vector<std::pair<const char*, std::string>> MaskedFill_props_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<MaskedFill>();
static_cast<void>(op_);
std::vector<std::pair<const char*, std::string>> props_;
props_.emplace_back("value", std::to_string(op_.value));
return props_;
}
std::string MaskedFill_make_name_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<MaskedFill>();
static_cast<void>(op_);
return "MaskedFill";
}
} // anonymous namespace
OP_TRAIT_REG(MaskedFill, MaskedFill)
.hash(MaskedFill_hash_impl)
.is_same_st(MaskedFill_is_same_st_impl)
.props(MaskedFill_props_impl)
.make_name(MaskedFill_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(MatrixInverse);
namespace {
......
......@@ -14037,6 +14037,115 @@ void _init_py_MagicMindRuntime(py::module m) {
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(MagicMindRuntime::typeinfo(), &py_type).second);
}
PyOpDefBegin(MaskedFill) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
static PyObject* getstate(PyObject* self, PyObject*) {
auto& opdef = reinterpret_cast<PyOp(MaskedFill)*>(self)->inst();
static_cast<void>(opdef);
std::unordered_map<std::string, py::object> state {
{"value", serialization<decltype(opdef.value)>::dump(opdef.value)}
};
return py::cast(state).release().ptr();
}
static PyObject* setstate(PyObject* self, PyObject* args) {
PyObject* dict = PyTuple_GetItem(args, 0);
if (!dict) return NULL;
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
auto& opdef = reinterpret_cast<PyOp(MaskedFill)*>(self)->inst();
static_cast<void>(opdef);
{
auto&& iter = state.find("value");
if (iter != state.end()) {
opdef.value = serialization<decltype(opdef.value)>::load(iter->second);
}
}
Py_RETURN_NONE;
}
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
static PyObject* py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds);
static PyMethodDef py_init_methoddef;
// };
PyOpDefEnd(MaskedFill)
int PyOp(MaskedFill)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
static const char* kwlist[] = {"value", "scope", NULL};
PyObject *value = NULL, *scope = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OO", const_cast<char**>(kwlist), &value, &scope))
return -1;
if (value) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(MaskedFill)*>(self)->inst().value =
py::cast<decltype(MaskedFill::value)>(py::handle(value));
} CATCH_ALL(-1)
}
if (scope) {
try {
reinterpret_cast<PyOp(OpDef)*>(self)->op
->set_scope(py::cast<std::string>(py::handle(scope)));
} CATCH_ALL(-1)
}
return 0;
}
PyGetSetDef PyOp(MaskedFill)::py_getsetters[] = {
{const_cast<char*>("value"), py_get_generic(MaskedFill, value), py_set_generic(MaskedFill, value), const_cast<char*>("value"), NULL},
{NULL} /* Sentinel */
};
PyMethodDef PyOp(MaskedFill)::tp_methods[] = {
{const_cast<char*>("__getstate__"), PyOp(MaskedFill)::getstate, METH_NOARGS, "MaskedFill getstate"},
{const_cast<char*>("__setstate__"), PyOp(MaskedFill)::setstate, METH_VARARGS, "MaskedFill setstate"},
{NULL} /* Sentinel */
};
PyObject *PyOp(MaskedFill)::py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds) {
if (PyOp(MaskedFill)::py_init(self, args, kwds) < 0) {
return NULL;
}
Py_RETURN_NONE;
}
PyMethodDef PyOp(MaskedFill)::py_init_methoddef = {
"__init__",
(PyCFunction)PyOp(MaskedFill)::py_init_proxy,
METH_VARARGS | METH_KEYWORDS,
"__init__(self, value: float = ...) -> None\n"
};
void _init_py_MaskedFill(py::module m) {
using py_op = PyOp(MaskedFill);
auto& py_type = PyOpType(MaskedFill);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.MaskedFill";
py_type.tp_basicsize = sizeof(PyOp(MaskedFill));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "MaskedFill";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_methods = py_op::tp_methods;
py_type.tp_getset = py_op::py_getsetters;
py_type.tp_dict = PyDict_New();
PyObject* descr = PyDescr_NewMethod(&PyOpType(MaskedFill), &PyOp(MaskedFill)::py_init_methoddef);
PyDict_SetItemString(py_type.tp_dict, "__init__", descr);
mgb_assert(PyType_Ready(&py_type) >= 0);
PyType_Modified(&py_type);
m.add_object("MaskedFill", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(MaskedFill::typeinfo(), &py_type).second);
}
PyOpDefBegin(MatrixInverse) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
......@@ -22157,6 +22266,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) {
_init_py_LayerNorm(m); \
_init_py_Linspace(m); \
_init_py_MagicMindRuntime(m); \
_init_py_MaskedFill(m); \
_init_py_MatrixInverse(m); \
_init_py_MatrixMul(m); \
_init_py_MeshGrid(m); \
......
......@@ -1288,6 +1288,19 @@ public:
MagicMindRuntime(std::string buf_, size_t buf_size_, std::string scope_ = {}): buf(buf_), buf_size(buf_size_) { set_scope(scope_); }
};
class MaskedFill : public OpDefImplBase<MaskedFill> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
float value = 0;
MaskedFill() = default;
MaskedFill(float value_, std::string scope_ = {}): value(value_) { set_scope(scope_); }
MaskedFill(::megdnn::param::Fill packed_param_0): value(packed_param_0.value) {}
::megdnn::param::Fill param() const {
return {value};
}
};
class MatrixInverse : public OpDefImplBase<MatrixInverse> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
......
......@@ -1412,6 +1412,12 @@ MagicMindRuntimeInst
.def_readwrite("buf", &MagicMindRuntime::buf)
.def_readwrite("buf_size", &MagicMindRuntime::buf_size);
py::class_<MaskedFill, std::shared_ptr<MaskedFill>, OpDef> MaskedFillInst(m, "MaskedFill");
MaskedFillInst
.def(py::init<float, std::string>(), py::arg("value") = 0, py::arg("scope") = {})
.def_readwrite("value", &MaskedFill::value);
py::class_<MatrixInverse, std::shared_ptr<MatrixInverse>, OpDef> MatrixInverseInst(m, "MatrixInverse");
MatrixInverseInst
......
......@@ -553,5 +553,6 @@ def MeshGrid: MgbHashableOp<"MeshGrid"> {
def RegionRestrictedConvolution: MgbHashableOp<"RegionRestrictedConvolution", [ConvolutionParam]>;
def RegionRestrictedConvolutionBackwardData: MgbHashableOp<"RegionRestrictedConvolutionBackwardData", [ConvolutionParam]>;
def MaskedFill: MgbHashableOp<"MaskedFill", [FillParam]>;
#endif // MGB_OPS
......@@ -1631,4 +1631,26 @@ MEGDNN_OPR_INIT2(PaddingBackward, "padding_backward", 1, false);
// f}}}
/* f{{{ ======================= MaskedFill ======================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(MaskedFill);
MEGDNN_OPR_INIT2(MaskedFill, "masked_fill");
void MaskedFill::init_output_dtype() {
output(0)->dtype(input(0)->dtype());
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(MaskedFill) {
mgb_assert(opr.input().size() == 2);
if (wrt_idx == 0) {
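        // masked positions were overwritten by a constant, so their gradient
        // is zero: reuse MaskedFill with value 0 on out_grad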
SymbolVar grad = MaskedFill::make(out_grad[0], opr.input(1), {.0});
return grad.node();
} else
return InvalidGrad::make(opr, wrt_idx);
}
#endif
// f}}}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -214,6 +214,7 @@ MGB_SEREG_OPR(RelayoutFormatV1, 1);
MGB_SEREG_OPR(Padding, 1);
MGB_SEREG_OPR(PaddingBackward, 2);
MGB_SEREG_OPR(MaskedFill, 2);
} // namespace opr
} // namespace mgb
......
......@@ -630,6 +630,18 @@ public:
const OperatorNodeConfig& config = {});
};
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
MaskedFill, intl::MegDNNOprWrapperFwd<megdnn::MaskedFill>) // {
public:
MGE_WIN_DECLSPEC_FUC MaskedFill(
VarNode* src, VarNode* index, const Param& param,
const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar src, SymbolVar index, const Param& param = {},
const OperatorNodeConfig& config = {});
MGE_WIN_DECLSPEC_FUC void init_output_dtype() override;
};
} // namespace opr
} // namespace mgb
......
......@@ -124,6 +124,7 @@ union OperatorParam {
param.Softmax = 90,
param.Diag = 91,
param.GroupNorm = 92,
param.Fill = 93,
}
table Operator {
......
......@@ -141,6 +141,7 @@ union OperatorParam {
param.Softmax = 90,
param.Diag = 91,
param.GroupNorm = 92,
param.Fill = 93,
}
table Operator {
......