From cded8ef1a6a363f8f2da47053e946a1fa521fd6b Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 9 Sep 2020 16:23:36 +0800
Subject: [PATCH] feat(imperative): add rng opdef

GitOrigin-RevId: b62dcea7f5e7d9cacb5697dacf3f05eea23190cf
---
 .../python/megengine/random/distribution.py   |  16 +-
 imperative/python/src/ops.cpp                 |  19 +
 imperative/python/test/unit/test_rng.py       |  76 ++++
 imperative/src/impl/ops/rng.cpp               | 350 ++++++++++++++++++
 .../src/include/megbrain/imperative/ops/rng.h |  95 +++++
 imperative/src/test/rng.cpp                   |  45 +++
 6 files changed, 593 insertions(+), 8 deletions(-)
 create mode 100644 imperative/python/test/unit/test_rng.py
 create mode 100644 imperative/src/impl/ops/rng.cpp
 create mode 100644 imperative/src/include/megbrain/imperative/ops/rng.h
 create mode 100644 imperative/src/test/rng.cpp

diff --git a/imperative/python/megengine/random/distribution.py b/imperative/python/megengine/random/distribution.py
index 268f69462..199778ebb 100644
--- a/imperative/python/megengine/random/distribution.py
+++ b/imperative/python/megengine/random/distribution.py
@@ -50,11 +50,11 @@ def normal(
     """
     if size is None:
         size = (1,)
-    seed = _random_seed_generator().__next__()
-    op = GaussianRNG(seed=seed, mean=mean, std=std)
+    op = GaussianRNG(mean, std)
     _ref = Tensor([], dtype="int32")
-    size = utils.astensor1d(size, _ref, dtype="int32")
-    (output,) = apply(op, size)
+    shape = utils.astensor1d(size, _ref, dtype="int32")
+    shape = Tensor(shape, dtype="int32")
+    (output,) = apply(op, shape)
     return output
@@ -92,10 +92,10 @@
     if size is None:
         size = (1,)
-    seed = _random_seed_generator().__next__()
-    op = UniformRNG(seed=seed)
+    op = UniformRNG()
     _ref = Tensor([], dtype="int32")
-    size = utils.astensor1d(size, _ref, dtype="int32")
-    (output,) = apply(op, size)
+    shape = utils.astensor1d(size, _ref, dtype="int32")
+    shape = Tensor(shape, dtype="int32")
+    (output,) = apply(op, shape)
     return low + (high - low) * output
diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp
index 1be1253d0..5dc3be143 100644
--- a/imperative/python/src/ops.cpp
+++ b/imperative/python/src/ops.cpp
@@ -16,6 +16,7 @@
 #include "megbrain/imperative/ops/opr_attr.h"
 #include "megbrain/imperative/ops/utility.h"
 #include "megbrain/imperative/ops/autogen.h"
+#include "megbrain/imperative/ops/rng.h"
 #include
 #include
@@ -489,4 +490,22 @@ void init_ops(py::module m) {
     _init_py_backward_graph(m);
     _init_py_op_base(m);
     INIT_ALL_OP(m)
+
+    m.def("new_rng_handle", &RNGMixin::new_handle);
+    // FIXME: RNG op might execute after handle released due to async dispatch,
+    // which would cause memory leak or use-after-free
+    m.def("delete_rng_handle", &RNGMixin::delete_handle);
+    m.def("set_rng_seed", &set_rng_seed);
+
+    py::class_<UniformRNG, std::shared_ptr<UniformRNG>, OpDef>(m, "UniformRNG")
+        .def(py::init<>())
+        .def(py::init<CompNode>())
+        .def(py::init<RNGMixin::Handle>());
+
+    py::class_<GaussianRNG, std::shared_ptr<GaussianRNG>, OpDef>(m, "GaussianRNG")
+        .def(py::init<>())
+        .def(py::init<CompNode>())
+        .def(py::init<float, float>())
+        .def(py::init<float, float, CompNode>())
+        .def(py::init<float, float, RNGMixin::Handle>());
 }
diff --git a/imperative/python/test/unit/test_rng.py b/imperative/python/test/unit/test_rng.py
new file mode 100644
index 000000000..d5d8e1643
--- /dev/null
+++ b/imperative/python/test/unit/test_rng.py
@@ -0,0 +1,76 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+
+from megengine import tensor
+from megengine.core._imperative_rt import CompNode
+from megengine.core._imperative_rt.ops import delete_rng_handle, new_rng_handle
+from megengine.core.ops.builtin import GaussianRNG, UniformRNG
+from megengine.core.tensor.core import apply
+
+
+def test_gaussian_rng():
+    shape = (
+        8,
+        9,
+        11,
+        12,
+    )
+    shape = tensor(shape, dtype="int32")
+    op = GaussianRNG(1.0, 3.0)
+    (output,) = apply(op, shape)
+    assert np.fabs(output.numpy().mean() - 1.0) < 1e-1
+    assert np.fabs(np.sqrt(output.numpy().var()) - 3.0) < 1e-1
+    assert str(output.device) == str(CompNode("xpux"))
+
+    cn = CompNode("xpu1")
+    op = GaussianRNG(-1.0, 2.0, cn)
+    (output,) = apply(op, shape)
+    assert np.fabs(output.numpy().mean() - (-1.0)) < 1e-1
+    assert np.fabs(np.sqrt(output.numpy().var()) - 2.0) < 1e-1
+    assert str(output.device) == str(cn)
+
+    cn = CompNode("xpu2")
+    seed = 233333
+    h = new_rng_handle(cn, seed)
+    op = GaussianRNG(3.0, 1.0, h)
+    (output,) = apply(op, shape)
+    delete_rng_handle(h)
+    assert np.fabs(output.numpy().mean() - 3.0) < 1e-1
+    assert np.fabs(np.sqrt(output.numpy().var()) - 1.0) < 1e-1
+    assert str(output.device) == str(cn)
+
+
+def test_uniform_rng():
+    shape = (
+        8,
+        9,
+        11,
+        12,
+    )
+    shape = tensor(shape, dtype="int32")
+    op = UniformRNG()
+    (output,) = apply(op, shape)
+    assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
+    assert str(output.device) == str(CompNode("xpux"))
+
+    cn = CompNode("xpu1")
+    op = UniformRNG(cn)
+    (output,) = apply(op, shape)
+    assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
+    assert str(output.device) == str(cn)
+
+    cn = CompNode("xpu2")
+    seed = 233333
+    h = new_rng_handle(cn, seed)
+    op = UniformRNG(h)
+    (output,) = apply(op, shape)
+    delete_rng_handle(h)
+    assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
+    assert str(output.device) == str(cn)
diff --git a/imperative/src/impl/ops/rng.cpp b/imperative/src/impl/ops/rng.cpp
new file mode 100644
index 000000000..21b373786
--- /dev/null
+++ b/imperative/src/impl/ops/rng.cpp
@@ -0,0 +1,350 @@
+/**
+ * \file imperative/src/impl/ops/rng.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "megbrain/imperative/ops/rng.h"
+#include
+#include "megbrain/comp_node_env.h"
+#include "megbrain/graph/helper.h"
+#include "megbrain/opr/rand.h"
+//#include "megbrain/common.h"
+
+#include "../op_trait.h"
+
+namespace mgb {
+namespace imperative {
+
+namespace {
+
+template <typename TDerived, typename THandle>
+class DnnOpManagerT : public CompNodeDepedentObject, public NonCopyableObj {
+public:
+    using Handle = THandle;
+
+    template <typename... Args>
+    Handle new_handle(Args&&... args) {
+        return static_cast<TDerived*>(this)->do_new_handle(
+                std::forward<Args>(args)...);
+    }
+
+    size_t delete_handle(Handle handle) {
+        size_t removed = 0;
+        if (!is_finalized()) {
+            MGB_LOCK_GUARD(m_mtx);
+            removed = m_handle2op.erase(handle);
+        }
+        static_cast<TDerived*>(this)->do_delete_handle(handle);
+        return removed;
+    }
+
+    template <typename DnnOp>
+    auto get_dnn_op(Handle handle, CompNode cn) {
+        mgb_assert(!is_finalized());
+        DnnOpWithMutex* dnn_op_with_mtx;
+        {
+            MGB_LOCK_GUARD(m_mtx);
+            dnn_op_with_mtx = &m_handle2op[handle];
+        }
+        auto dnn_handle =
+                MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
+        DnnOp* dnn_op;
+        std::unique_lock<std::mutex> lock(dnn_op_with_mtx->mtx);
+        bool initialized = false;
+        if ((dnn_op = dynamic_cast<DnnOp*>(dnn_op_with_mtx->op.get())) !=
+            nullptr) {
+            mgb_assert(dnn_op->handle() == dnn_handle);
+            initialized = true;
+        } else {
+            auto new_op = dnn_handle->create_operator<DnnOp>();
+            dnn_op = new_op.get();
+            dnn_op_with_mtx->op = std::move(new_op);
+        }
+        return std::make_tuple(initialized, dnn_op, std::move(lock));
+    }
+
+protected:
+    using DnnOpManagerBase = DnnOpManagerT<TDerived, THandle>;
+    DnnOpManagerT() = default;
+
+private:
+    struct DnnOpWithMutex {
+        std::mutex mtx;
+        std::unique_ptr<megdnn::OperatorBase> op;
+    };
+
+    std::shared_ptr<void> on_comp_node_finalize() override {
+        MGB_LOCK_GUARD(m_mtx);
+        m_handle2op.clear();
+        return {};
+    }
+
+    std::unordered_map<Handle, DnnOpWithMutex> m_handle2op;
+    std::mutex m_mtx;
+};
+
+class RNGDnnOpManager final
+        : public DnnOpManagerT<RNGDnnOpManager, RNGMixin::Handle> {
+public:
+    size_t delete_handle(Handle handle) {
+        size_t ret = 0;
+        {
+            MGB_LOCK_GUARD(sm_mtx);
+            auto iter = sm_partial2full.find(handle);
+            if (iter != sm_partial2full.end()) {
+                for (auto&& h : iter->second) {
+                    ret += DnnOpManagerBase::delete_handle(h.second);
+                }
+                sm_partial2full.erase(iter);
+            }
+        }
+        ret += DnnOpManagerBase::delete_handle(handle);
+        return ret;
+    }
+
+    Handle do_new_handle(CompNode comp_node, uint64_t seed) {
+        auto handle = m_handle_pool.alloc(comp_node, seed);
+        return reinterpret_cast<Handle>(handle);
+    }
+
+    void do_delete_handle(Handle handle) {
+        m_handle_pool.free(reinterpret_cast<HandleData*>(handle));
+    }
+
+    static uint64_t get_seed(Handle handle) {
+        return reinterpret_cast<HandleData*>(handle)->seed;
+    }
+
+    static CompNode get_comp_node(Handle handle) {
+        return reinterpret_cast<HandleData*>(handle)->comp_node;
+    }
+
+    static Handle get_full_handle(Handle handle, CompNode comp_node) {
+        if (get_comp_node(handle).valid()) {
+            return handle;
+        }
+        MGB_LOCK_GUARD(sm_mtx);
+        auto&& full = sm_partial2full[handle][comp_node];
+        if (!full) {
+            full = inst().new_handle(comp_node, get_seed(handle));
+        }
+        return full;
+    }
+
+    static Handle get_default_handle(CompNode comp_node) {
+        static Handle glob_partial_handle =
+                inst().new_handle(CompNode{}, glob_default_seed);
+        if (!comp_node.valid()) {
+            return glob_partial_handle;
+        }
+        return get_full_handle(glob_partial_handle, comp_node);
+    }
+
+    static RNGDnnOpManager& inst() {
+        static RNGDnnOpManager mgr;
+        return mgr;
+    }
+
+    static void set_glob_default_seed(uint64_t seed) {
+        glob_default_seed = seed;
+    }
+
+private:
+    struct HandleData {
+        CompNode comp_node;
+        uint64_t seed;
+        HandleData(CompNode cn, uint64_t seed) : comp_node(cn), seed(seed) {}
+    };
+
+    MemPool<HandleData> m_handle_pool;
+
+    static std::mutex sm_mtx;
+    static std::unordered_map<Handle, CompNode::UnorderedMap<Handle>>
+            sm_partial2full;
+    static uint64_t glob_default_seed;
+};
+
+uint64_t RNGDnnOpManager::glob_default_seed = 0;
+std::mutex RNGDnnOpManager::sm_mtx;
+std::unordered_map<RNGMixin::Handle, CompNode::UnorderedMap<RNGMixin::Handle>>
+        RNGDnnOpManager::sm_partial2full;
+
+template <typename Op>
+struct OpMeth;
+
+template <>
+struct OpMeth<UniformRNG> {
+    using DnnOp = megdnn::UniformRNG;
+    using Param = DnnOp::Param;
+    using OpNode = mgb::opr::UniformRNG;
+    static Param make_param(const UniformRNG& rng) {
+        return {RNGDnnOpManager::get_seed(rng.handle())};
+    }
+};
+
+template <>
+struct OpMeth<GaussianRNG> {
+    using DnnOp = megdnn::GaussianRNG;
+    using Param = DnnOp::Param;
+    using OpNode = mgb::opr::GaussianRNG;
+    static Param make_param(const GaussianRNG& rng) {
+        return {RNGDnnOpManager::get_seed(rng.handle()), rng.mean, rng.std};
+    }
+};
+
+template <typename Op>
+void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs,
+          const SmallVector<TensorPtr>& outputs) {
+    auto&& rng = op.cast_final_safe<Op>();
+    auto dest = outputs[0];
+
+    auto cn = dest->comp_node();
+    auto handle = RNGDnnOpManager::get_full_handle(rng.handle(), cn);
+    {
+        auto handle_cn = RNGDnnOpManager::get_comp_node(handle);
+        mgb_assert(cn == handle_cn,
+                "inconsistent comp_node: handle: %s, output: %s",
+                cn.to_string().c_str(), handle_cn.to_string().c_str());
+    }
+
+    // retrieve dnn_op from glob cache
+    auto dnn_op_thread_safe = RNGDnnOpManager::inst()
+            .get_dnn_op<typename OpMeth<Op>::DnnOp>(handle, cn);
+    auto initialized = std::get<0>(dnn_op_thread_safe);
+    auto dnn_op = std::get<1>(dnn_op_thread_safe);
+    if (initialized) {
+        auto handle_seed = RNGDnnOpManager::get_seed(handle);
+        mgb_assert(dnn_op->param().seed == handle_seed,
+                "inconsistent rng seed: handle: %zu, dnn_op: %zu",
+                handle_seed, dnn_op->param().seed);
+    }
+    dnn_op->param() = OpMeth<Op>::make_param(rng);
+
+    // allocate workspace
+    size_t wk_size = dnn_op->get_workspace_in_bytes(dest->layout());
+    auto workspace = Blob::make(cn, wk_size);
+    megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size);
+
+    dnn_op->exec(dest->dev_tensor().as_megdnn(), dnn_wk);
+}
+
+template <typename Op>
+SmallVector<LogicalTensorDesc> infer_output_attrs(
+        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
+    LogicalTensorDesc dest;
+    dest.comp_node = op.cast_final_safe<Op>().comp_node();
+    if (!dest.comp_node.valid())
+        dest.comp_node = inputs[0]->comp_node();
+
+    auto hv = inputs[0]->get_value().proxy_to_default_cpu();
+    TensorShape tshape;
+    cg::copy_tensor_value_to_shape(tshape, hv);
+    dest.layout = TensorLayout(tshape, dtype::Float32());
+    return {dest};
+}
+
+template <typename Op>
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
+    auto desc = infer_output_attrs<Op>(def, inputs);
+    SmallVector<TensorPtr> outputs;
+    for (auto&& i : desc) {
+        outputs.push_back(Tensor::make(i.layout, i.comp_node));
+    }
+    exec<Op>(def, inputs, outputs);
+    return outputs;
+}
+
+template <typename Op>
+cg::OperatorNodeBase* apply_on_var_node(
+        const OpDef& def, const VarNodeArray& inputs) {
+    size_t nr_inp = inputs.size();
+    mgb_assert(nr_inp == 1, "%s expects 1 input; got %lu actually",
+               def.dyn_typeinfo()->name, nr_inp);
+    auto&& rng = def.cast_final_safe<Op>();
+    auto param = OpMeth<Op>::make_param(rng);
+    return OpMeth<Op>::OpNode::make(
+            inputs[0], param, {rng.comp_node()}).node()->owner_opr();
+}
+
+template <typename Op>
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+    auto&& xxx_rng_def = def.cast_final_safe<Op>();
+    size_t nr_inp = inputs.size();
+    mgb_assert(nr_inp == 1, "%s expects 1 input; got %lu actually",
+               xxx_rng_def.dyn_typeinfo()->name,
+               nr_inp);
+
+    auto&& tshp = inputs[0];
+
+    TensorLayout out_layout = tshp.layout;
+    out_layout.dtype = dtype::Float32();
+    if (tshp.layout.ndim == 0 || tshp.value.empty()) {
+        out_layout.ndim = 0;
+        return {{{out_layout, tshp.comp_node}}, true};
+    }
+    mgb_assert(
+            tshp.layout.ndim == 1,
+            "target shape of %s expects ndim=1; got ndim=%lu actually",
+            xxx_rng_def.dyn_typeinfo()->name,
+            tshp.layout.ndim);
+    size_t target_ndim = tshp.layout.shape[0];
+    out_layout.ndim = target_ndim;
+    auto* ptr = tshp.value.ptr<dt_int32>();
+    for (size_t i = 0; i < target_ndim; ++i) {
+        out_layout.shape[i] = ptr[i];
+    }
+
+    return {{{out_layout, tshp.comp_node}}, true};
+}
+
+} // anonymous namespace
+
+RNGMixin::RNGMixin(CompNode cn):
+    m_handle(RNGDnnOpManager::get_default_handle(cn)) {}
+
+uint64_t RNGMixin::seed() const {
+    return RNGDnnOpManager::get_seed(m_handle);
+}
+
+CompNode RNGMixin::comp_node() const {
+    return RNGDnnOpManager::get_comp_node(m_handle);
+}
+
+RNGMixin::Handle RNGMixin::new_handle(CompNode comp_node, uint64_t seed) {
+    return RNGDnnOpManager::inst().new_handle(comp_node, seed);
+}
+
+size_t RNGMixin::delete_handle(Handle handle) {
+    return RNGDnnOpManager::inst().delete_handle(handle);
+}
+
+void set_rng_seed(uint64_t seed) {
+    RNGDnnOpManager::set_glob_default_seed(seed);
+}
+
+#define REG_RNG_OP(NAME)\
+namespace { \
+OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
+    .apply_on_var_node(apply_on_var_node<NAME>) \
+    .apply_on_physical_tensor(apply_on_physical_tensor<NAME>) \
+    .infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
+    .fallback(); \
+} \
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(NAME);
+
+REG_RNG_OP(UniformRNG)
+REG_RNG_OP(GaussianRNG)
+
+} // namespace imperative
+} // namespace mgb
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/imperative/src/include/megbrain/imperative/ops/rng.h b/imperative/src/include/megbrain/imperative/ops/rng.h
new file mode 100644
index 000000000..eb3ed7b41
--- /dev/null
+++ b/imperative/src/include/megbrain/imperative/ops/rng.h
@@ -0,0 +1,95 @@
+/**
+ * \file imperative/src/include/megbrain/imperative/ops/rng.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */ + +#pragma once + +#include "megbrain/imperative/op_def.h" + +namespace mgb::imperative { + +class RNGMixin { +public: + using Handle = size_t; + + static Handle new_handle( + CompNode comp_node={}, uint64_t seed=0); + + static size_t delete_handle(Handle handle); + + Handle handle() const { + return m_handle; + } + + uint64_t seed() const; + + CompNode comp_node() const; +protected: + RNGMixin(Handle handle): m_handle(handle) {} + RNGMixin(CompNode comp_node); +private: + Handle m_handle; +}; + +class GaussianRNG : public OpDefImplBase, + public RNGMixin { + MGB_DYN_TYPE_OBJ_FINAL_DECL; +public: + float mean = 1.0f, std = 0.0; + GaussianRNG(CompNode comp_node_): RNGMixin(comp_node_) {} + GaussianRNG(float mean_=1.0, float std_=0.0, CompNode comp_node_={}): + GaussianRNG(comp_node_) { mean = mean_; std = std_; } + GaussianRNG(float mean_, float std_, Handle handle): + RNGMixin(handle), mean(mean_), std(std_) {} + size_t hash() const override { + XXHash xxhash{}; + auto append = [&xxhash](auto field){ + auto hash_val = HashTrait::eval(field); + xxhash.update(reinterpret_cast(&hash_val), sizeof(hash_val)); + }; + append(dyn_typeinfo()); + append(seed()); + append(mean); + append(std); + return xxhash.digest(); + } + + + bool is_same_st(const Hashable& rhs_) const override { + auto&& rhs = static_cast(rhs_); + return rhs.seed() == seed() + && rhs.mean == mean + && rhs.std == std; + } +}; + +class UniformRNG : public OpDefImplBase, + public RNGMixin { + MGB_DYN_TYPE_OBJ_FINAL_DECL; +public: + UniformRNG(CompNode comp_node_={}): RNGMixin(comp_node_) {} + UniformRNG(Handle handle): RNGMixin(handle) {} + + size_t hash() const override { + return hash_pair_combine( + mgb::hash(seed()), + reinterpret_cast(dyn_typeinfo())); + } + + bool is_same_st(const Hashable& rhs_) const override { + auto&& rhs = static_cast(rhs_); + return rhs.dyn_typeinfo() == dyn_typeinfo() + && rhs.seed() == seed(); + } + +}; + +void set_rng_seed(uint64_t seed); +} // namespace mgb::imperative diff --git a/imperative/src/test/rng.cpp b/imperative/src/test/rng.cpp new file mode 100644 index 000000000..cf3786e54 --- /dev/null +++ b/imperative/src/test/rng.cpp @@ -0,0 +1,45 @@ +/** + * \file imperative/src/test/rng.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "./helper.h" +#include "megbrain/imperative/ops/rng.h" + +using namespace mgb; +using namespace imperative; + +template +void check_rng_basic(Args&& ...args) { + for (auto&& tshape: { + TensorShape{2, 3, 4, 5}, + {3, 4, 5, 6}, + {2333}}) + for (auto&& cn: { + CompNode::load("cpu0"), + CompNode::load("xpu0")}) + { + auto op = Op::make(std::forward(args)..., cn); + DeviceTensorND tshape_dev; + cg::copy_shape_to_tensor_value(tshape_dev, tshape); + auto outputs = OpDef::apply_on_physical_tensor(*op, {Tensor::make(tshape_dev)}); + ASSERT_TRUE(outputs[0]->layout().eq_shape(tshape)); + ASSERT_TRUE(cn == outputs[0]->comp_node()); + } +} + +TEST(TestImperative, UniformRNGBasic) { + check_rng_basic(); +} + +TEST(TestImperative, GaussianRNGBasic) { + check_rng_basic(2.f, 3.f); +} + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} -- GitLab