Commit cded8ef1 authored by Megvii Engine Team

feat(imperative): add rng opdef

GitOrigin-RevId: b62dcea7f5e7d9cacb5697dacf3f05eea23190cf
Parent: 40bab1ed
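For orientation, a minimal usage sketch of the handle-based RNG API this commit adds, pieced together from the Python test in this diff. The device string "xpu0" and the seed value are arbitrary placeholders; any valid comp node would work.

    from megengine import tensor
    from megengine.core._imperative_rt import CompNode
    from megengine.core._imperative_rt.ops import delete_rng_handle, new_rng_handle
    from megengine.core.ops.builtin import GaussianRNG
    from megengine.core.tensor.core import apply

    cn = CompNode("xpu0")                # placeholder device
    handle = new_rng_handle(cn, 233333)  # bind a (comp_node, seed) pair
    op = GaussianRNG(3.0, 1.0, handle)   # mean=3.0, std=1.0, explicit handle
    shape = tensor((8, 9, 11, 12), dtype="int32")
    (output,) = apply(op, shape)         # samples drawn on cn with the given seed
    delete_rng_handle(handle)            # release only after the op has executed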
@@ -50,11 +50,11 @@ def normal(
"""
if size is None:
size = (1,)
seed = _random_seed_generator().__next__()
op = GaussianRNG(seed=seed, mean=mean, std=std)
op = GaussianRNG(mean, std)
_ref = Tensor([], dtype="int32")
size = utils.astensor1d(size, _ref, dtype="int32")
(output,) = apply(op, size)
shape = utils.astensor1d(size, _ref, dtype="int32")
shape = Tensor(shape, dtype="int32")
(output,) = apply(op, shape)
return output
@@ -92,10 +92,10 @@ def uniform(
if size is None:
size = (1,)
seed = _random_seed_generator().__next__()
op = UniformRNG(seed=seed)
op = UniformRNG()
_ref = Tensor([], dtype="int32")
size = utils.astensor1d(size, _ref, dtype="int32")
(output,) = apply(op, size)
shape = utils.astensor1d(size, _ref, dtype="int32")
shape = Tensor(shape, dtype="int32")
(output,) = apply(op, shape)
return low + (high - low) * output
@@ -16,6 +16,7 @@
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/rng.h"
#include <Python.h>
#include <unordered_map>
@@ -489,4 +490,22 @@ void init_ops(py::module m) {
_init_py_backward_graph(m);
_init_py_op_base(m);
INIT_ALL_OP(m)
m.def("new_rng_handle", &RNGMixin::new_handle);
// FIXME: RNG op might execute after handle released due to async dispatch,
// which would cause memory leak or use-after-free
m.def("delete_rng_handle", &RNGMixin::delete_handle);
m.def("set_rng_seed", &set_rng_seed);
py::class_<UniformRNG, std::shared_ptr<UniformRNG>, OpDef>(m, "UniformRNG")
.def(py::init<>())
.def(py::init<mgb::CompNode>())
.def(py::init<RNGMixin::Handle>());
py::class_<GaussianRNG, std::shared_ptr<GaussianRNG>, OpDef>(m, "GaussianRNG")
.def(py::init<>())
.def(py::init<mgb::CompNode>())
.def(py::init<float, float>())
.def(py::init<float, float, mgb::CompNode>())
.def(py::init<float, float, RNGMixin::Handle>());
}
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
from megengine import tensor
from megengine.core._imperative_rt import CompNode
from megengine.core._imperative_rt.ops import delete_rng_handle, new_rng_handle
from megengine.core.ops.builtin import GaussianRNG, UniformRNG
from megengine.core.tensor.core import apply
def test_gaussian_rng():
shape = (
8,
9,
11,
12,
)
shape = tensor(shape, dtype="int32")
op = GaussianRNG(1.0, 3.0)
(output,) = apply(op, shape)
assert np.fabs(output.numpy().mean() - 1.0) < 1e-1
assert np.fabs(np.sqrt(output.numpy().var()) - 3.0) < 1e-1
assert str(output.device) == str(CompNode("xpux"))
cn = CompNode("xpu1")
op = GaussianRNG(-1.0, 2.0, cn)
(output,) = apply(op, shape)
assert np.fabs(output.numpy().mean() - (-1.0)) < 1e-1
assert np.fabs(np.sqrt(output.numpy().var()) - 2.0) < 1e-1
assert str(output.device) == str(cn)
cn = CompNode("xpu2")
seed = 233333
h = new_rng_handle(cn, seed)
op = GaussianRNG(3.0, 1.0, h)
(output,) = apply(op, shape)
delete_rng_handle(h)
assert np.fabs(output.numpy().mean() - 3.0) < 1e-1
assert np.fabs(np.sqrt(output.numpy().var()) - 1.0) < 1e-1
assert str(output.device) == str(cn)
def test_uniform_rng():
shape = (
8,
9,
11,
12,
)
shape = tensor(shape, dtype="int32")
op = UniformRNG()
(output,) = apply(op, shape)
assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
assert str(output.device) == str(CompNode("xpux"))
cn = CompNode("xpu1")
op = UniformRNG(cn)
(output,) = apply(op, shape)
assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
assert str(output.device) == str(cn)
cn = CompNode("xpu2")
seed = 233333
h = new_rng_handle(cn, seed)
op = UniformRNG(h)
(output,) = apply(op, shape)
delete_rng_handle(h)
assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
assert str(output.device) == str(cn)
/**
* \file imperative/src/impl/ops/rng.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/ops/rng.h"
#include <cstdint>
#include "megbrain/comp_node_env.h"
#include "megbrain/graph/helper.h"
#include "megbrain/opr/rand.h"
//#include "megbrain/common.h"
#include "../op_trait.h"
namespace mgb {
namespace imperative {
namespace {
template <typename HandleFactory, typename THandle>
class DnnOpManagerT : public CompNodeDepedentObject, public NonCopyableObj {
public:
using Handle = THandle;
template <typename... Args>
Handle new_handle(Args&&... args) {
return static_cast<HandleFactory*>(this)->do_new_handle(
std::forward<Args>(args)...);
}
size_t delete_handle(Handle handle) {
size_t removed = 0;
if (!is_finalized()) {
MGB_LOCK_GUARD(m_mtx);
removed = m_handle2op.erase(handle);
}
static_cast<HandleFactory*>(this)->do_delete_handle(handle);
return removed;
}
template <typename DnnOp>
auto get_dnn_op(Handle handle, CompNode cn) {
mgb_assert(!is_finalized());
DnnOpWithMutex* dnn_op_with_mtx;
{
MGB_LOCK_GUARD(m_mtx);
dnn_op_with_mtx = &m_handle2op[handle];
}
auto dnn_handle =
MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
DnnOp* dnn_op;
std::unique_lock<std::mutex> lock(dnn_op_with_mtx->mtx);
bool initialized = false;
if ((dnn_op = dynamic_cast<DnnOp*>(dnn_op_with_mtx->op.get())) !=
nullptr) {
mgb_assert(dnn_op->handle() == dnn_handle);
initialized = true;
} else {
auto new_op = dnn_handle->create_operator<DnnOp>();
dnn_op = new_op.get();
dnn_op_with_mtx->op = std::move(new_op);
}
return std::make_tuple(initialized, dnn_op, std::move(lock));
}
protected:
using DnnOpManagerBase = DnnOpManagerT<HandleFactory, Handle>;
DnnOpManagerT() = default;
private:
struct DnnOpWithMutex {
std::mutex mtx;
std::unique_ptr<megdnn::OperatorBase> op;
};
std::shared_ptr<void> on_comp_node_finalize() override {
MGB_LOCK_GUARD(m_mtx);
m_handle2op.clear();
return {};
}
std::unordered_map<Handle, DnnOpWithMutex> m_handle2op;
std::mutex m_mtx;
};
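// Process-wide singleton that owns the megdnn RNG operators. Each handle is a
// pointer to a (comp_node, seed) record; a handle created without a valid
// comp node is "partial" and is expanded lazily by get_full_handle() into one
// full handle per comp node it is used on.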
class RNGDnnOpManager final
: public DnnOpManagerT<RNGDnnOpManager, RNGMixin::Handle> {
public:
size_t delete_handle(Handle handle) {
size_t ret = 0;
{
MGB_LOCK_GUARD(sm_mtx);
auto iter = sm_partial2full.find(handle);
if (iter != sm_partial2full.end()) {
for (auto&& h : iter->second) {
ret += DnnOpManagerBase::delete_handle(h.second);
}
sm_partial2full.erase(iter);
}
}
ret += DnnOpManagerBase::delete_handle(handle);
return ret;
}
Handle do_new_handle(CompNode comp_node, uint64_t seed) {
auto handle = m_handle_pool.alloc(comp_node, seed);
return reinterpret_cast<Handle>(handle);
}
void do_delete_handle(Handle handle) {
m_handle_pool.free(reinterpret_cast<HandleData*>(handle));
}
static uint64_t get_seed(Handle handle) {
return reinterpret_cast<HandleData*>(handle)->seed;
}
static CompNode get_comp_node(Handle handle) {
return reinterpret_cast<HandleData*>(handle)->comp_node;
}
static Handle get_full_handle(Handle handle, CompNode comp_node) {
if (get_comp_node(handle).valid()) {
return handle;
}
MGB_LOCK_GUARD(sm_mtx);
auto&& full = sm_partial2full[handle][comp_node];
if (!full) {
full = inst().new_handle(comp_node, get_seed(handle));
}
return full;
}
static Handle get_default_handle(CompNode comp_node) {
static Handle glob_partial_handle =
inst().new_handle(CompNode{}, glob_default_seed);
if (!comp_node.valid()) {
return glob_partial_handle;
}
return get_full_handle(glob_partial_handle, comp_node);
}
static RNGDnnOpManager& inst() {
static RNGDnnOpManager mgr;
return mgr;
}
static void set_glob_default_seed(uint64_t seed) {
glob_default_seed = seed;
}
private:
struct HandleData {
CompNode comp_node;
uint64_t seed;
HandleData(CompNode cn, uint64_t seed) : comp_node(cn), seed(seed) {}
};
MemPool<HandleData> m_handle_pool;
static std::mutex sm_mtx;
static std::unordered_map<Handle, CompNode::UnorderedMap<Handle>>
sm_partial2full;
static uint64_t glob_default_seed;
};
uint64_t RNGDnnOpManager::glob_default_seed = 0;
std::mutex RNGDnnOpManager::sm_mtx;
std::unordered_map<RNGDnnOpManager::Handle,
CompNode::UnorderedMap<RNGDnnOpManager::Handle>>
RNGDnnOpManager::sm_partial2full;
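// OpMeth<Op> ties an imperative OpDef to its megdnn kernel, the kernel's Param
// type and the corresponding graph operator, and builds the Param (seed plus
// op-specific fields) from an OpDef instance.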
template <typename Op>
struct OpMeth;
template <>
struct OpMeth<UniformRNG> {
using DnnOp = megdnn::UniformRNG;
using Param = DnnOp::Param;
using OpNode = mgb::opr::UniformRNG;
static Param make_param(const UniformRNG& rng) {
return {RNGDnnOpManager::get_seed(rng.handle())};
}
};
template <>
struct OpMeth<GaussianRNG> {
using DnnOp = megdnn::GaussianRNG;
using Param = DnnOp::Param;
using OpNode = mgb::opr::GaussianRNG;
static Param make_param(const GaussianRNG& rng) {
return {RNGDnnOpManager::get_seed(rng.handle()), rng.mean, rng.std};
}
};
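// Runs the RNG kernel: resolve the full handle for the output comp node,
// fetch (or create) the cached megdnn operator for that handle, check the
// cached seed, then allocate a workspace and fill the preallocated output.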
template <typename Op>
void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs) {
auto&& rng = op.cast_final_safe<Op>();
auto dest = outputs[0];
auto cn = dest->comp_node();
auto handle = RNGDnnOpManager::get_full_handle(rng.handle(), cn);
{
auto handle_cn = RNGDnnOpManager::get_comp_node(handle);
mgb_assert(cn == handle_cn,
"inconsistent comp_node: handle: %s, output: %s",
cn.to_string().c_str(), handle_cn.to_string().c_str());
}
// retrieve dnn_op from glob cache
auto dnn_op_thread_safe = RNGDnnOpManager::inst()
.get_dnn_op<typename OpMeth<Op>::DnnOp>(handle, cn);
auto initialized = std::get<0>(dnn_op_thread_safe);
auto dnn_op = std::get<1>(dnn_op_thread_safe);
if (initialized) {
auto handle_seed = RNGDnnOpManager::get_seed(handle);
mgb_assert(dnn_op->param().seed == handle_seed,
"inconsistent rng seed: handle: %zu, dnn_op: %zu",
handle_seed, dnn_op->param().seed);
}
dnn_op->param() = OpMeth<Op>::make_param(rng);
// allocate workspace
size_t wk_size = dnn_op->get_workspace_in_bytes(dest->layout());
auto workspace = Blob::make(cn, wk_size);
megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size);
dnn_op->exec(dest->dev_tensor().as_megdnn(), dnn_wk);
}
template <typename Op>
SmallVector<LogicalTensorDesc> infer_output_attrs(
const OpDef& op, const SmallVector<TensorPtr>& inputs) {
LogicalTensorDesc dest;
dest.comp_node = op.cast_final_safe<Op>().comp_node();
if (!dest.comp_node.valid())
dest.comp_node = inputs[0]->comp_node();
auto hv = inputs[0]->get_value().proxy_to_default_cpu();
TensorShape tshape;
cg::copy_tensor_value_to_shape(tshape, hv);
dest.layout = TensorLayout(tshape, dtype::Float32());
return {dest};
}
template <typename Op>
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
auto desc = infer_output_attrs<Op>(def, inputs);
SmallVector<TensorPtr> outputs;
for (auto&& i : desc) {
outputs.push_back(Tensor::make(i.layout, i.comp_node));
}
exec<Op>(def, inputs, outputs);
return outputs;
}
template<typename Op>
cg::OperatorNodeBase* apply_on_var_node(
const OpDef& def, const VarNodeArray& inputs) {
size_t nr_inp = inputs.size();
auto&& rng = def.cast_final_safe<Op>();
mgb_assert(nr_inp == 1, "%s expects 1 input; got %lu actually",
rng.dyn_typeinfo()->name, nr_inp);
auto param = OpMeth<Op>::make_param(rng);
return OpMeth<Op>::OpNode::make(
inputs[0], param, {rng.comp_node()}).node()->owner_opr();
}
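// Shape inference: the single input is the target shape as a 1-d int32
// tensor. If its value is not yet known the output shape stays undetermined
// (ndim = 0); otherwise the output is a Float32 tensor of that shape.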
template<typename T>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
auto&& xxx_rng_def = def.cast_final_safe<T>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 1, "%s expects 1 input; got %lu actually",
xxx_rng_def.dyn_typeinfo()->name,
nr_inp);
auto&& tshp = inputs[0];
TensorLayout out_layout = tshp.layout;
out_layout.dtype = dtype::Float32();
if (tshp.layout.ndim == 0 || tshp.value.empty()) {
out_layout.ndim = 0;
return {{{out_layout, tshp.comp_node}}, true};
}
mgb_assert(
tshp.layout.ndim == 1,
"target shape of %s expects ndim=1; got ndim=%lu actually",
xxx_rng_def.dyn_typeinfo()->name,
tshp.layout.ndim);
size_t target_ndim = tshp.layout.shape[0];
out_layout.ndim = target_ndim;
auto* ptr = tshp.value.ptr<dt_int32>();
for (size_t i = 0; i < target_ndim; ++i) {
out_layout.shape[i] = ptr[i];
}
return {{{out_layout, tshp.comp_node}}, true};
}
} // anonymous namespace
RNGMixin::RNGMixin(CompNode cn):
m_handle(RNGDnnOpManager::get_default_handle(cn)) {}
uint64_t RNGMixin::seed() const {
return RNGDnnOpManager::get_seed(m_handle);
}
CompNode RNGMixin::comp_node() const {
return RNGDnnOpManager::get_comp_node(m_handle);
}
RNGMixin::Handle RNGMixin::new_handle(CompNode comp_node, uint64_t seed) {
return RNGDnnOpManager::inst().new_handle(comp_node, seed);
}
size_t RNGMixin::delete_handle(Handle handle) {
return RNGDnnOpManager::inst().delete_handle(handle);
}
void set_rng_seed(uint64_t seed) {
RNGDnnOpManager::set_glob_default_seed(seed);
}
#define REG_RNG_OP(NAME)\
namespace { \
OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
.apply_on_var_node(apply_on_var_node<NAME>) \
.apply_on_physical_tensor(apply_on_physical_tensor<NAME>) \
.infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
.fallback(); \
} \
MGB_DYN_TYPE_OBJ_FINAL_IMPL(NAME);
REG_RNG_OP(UniformRNG)
REG_RNG_OP(GaussianRNG)
} // namespace imperative
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file imperative/src/include/megbrain/imperative/ops/rng.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/imperative/op_def.h"
namespace mgb::imperative {
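// Shared by the RNG op defs below: stores an opaque handle that identifies a
// (comp_node, seed) pair managed by RNGDnnOpManager in rng.cpp. Constructing
// from a CompNode picks the global default handle for that device.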
class RNGMixin {
public:
using Handle = size_t;
static Handle new_handle(
CompNode comp_node={}, uint64_t seed=0);
static size_t delete_handle(Handle handle);
Handle handle() const {
return m_handle;
}
uint64_t seed() const;
CompNode comp_node() const;
protected:
RNGMixin(Handle handle): m_handle(handle) {}
RNGMixin(CompNode comp_node);
private:
Handle m_handle;
};
class GaussianRNG : public OpDefImplBase<GaussianRNG>,
public RNGMixin {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
float mean = 1.0f, std = 0.0f;
GaussianRNG(CompNode comp_node_): RNGMixin(comp_node_) {}
GaussianRNG(float mean_=1.0, float std_=0.0, CompNode comp_node_={}):
GaussianRNG(comp_node_) { mean = mean_; std = std_; }
GaussianRNG(float mean_, float std_, Handle handle):
RNGMixin(handle), mean(mean_), std(std_) {}
size_t hash() const override {
XXHash xxhash{};
auto append = [&xxhash](auto field){
auto hash_val = HashTrait<decltype(field)>::eval(field);
xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val));
};
append(dyn_typeinfo());
append(seed());
append(mean);
append(std);
return xxhash.digest();
}
bool is_same_st(const Hashable& rhs_) const override {
auto&& rhs = static_cast<const GaussianRNG&>(rhs_);
return rhs.seed() == seed()
&& rhs.mean == mean
&& rhs.std == std;
}
};
class UniformRNG : public OpDefImplBase<UniformRNG>,
public RNGMixin {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
UniformRNG(CompNode comp_node_={}): RNGMixin(comp_node_) {}
UniformRNG(Handle handle): RNGMixin(handle) {}
size_t hash() const override {
return hash_pair_combine(
mgb::hash(seed()),
reinterpret_cast<std::uintptr_t>(dyn_typeinfo()));
}
bool is_same_st(const Hashable& rhs_) const override {
auto&& rhs = static_cast<const UniformRNG&>(rhs_);
return rhs.dyn_typeinfo() == dyn_typeinfo()
&& rhs.seed() == seed();
}
};
void set_rng_seed(uint64_t seed);
} // namespace mgb::imperative
/**
* \file imperative/src/test/rng.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./helper.h"
#include "megbrain/imperative/ops/rng.h"
using namespace mgb;
using namespace imperative;
template<typename Op, typename ...Args>
void check_rng_basic(Args&& ...args) {
for (auto&& tshape: {
TensorShape{2, 3, 4, 5},
{3, 4, 5, 6},
{2333}})
for (auto&& cn: {
CompNode::load("cpu0"),
CompNode::load("xpu0")})
{
auto op = Op::make(std::forward<Args>(args)..., cn);
DeviceTensorND tshape_dev;
cg::copy_shape_to_tensor_value(tshape_dev, tshape);
auto outputs = OpDef::apply_on_physical_tensor(*op, {Tensor::make(tshape_dev)});
ASSERT_TRUE(outputs[0]->layout().eq_shape(tshape));
ASSERT_TRUE(cn == outputs[0]->comp_node());
}
}
TEST(TestImperative, UniformRNGBasic) {
check_rng_basic<UniformRNG>();
}
TEST(TestImperative, GaussianRNGBasic) {
check_rng_basic<GaussianRNG>(2.f, 3.f);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}