提交 a90c937d 编写于 作者: M Megvii Engine Team

feat(interpreter): add command buffer for inplace

GitOrigin-RevId: 020d1e88d4d58a63d55a2dec9bb167edc94ae6eb
上级 09af925f
......@@ -33,7 +33,7 @@ def _run_wrapped(
class launcher:
"""Decorator for launching multiple processes in single-machine multi-gpu training.
:param func: the function you want to launch in distributed mode.
:param n_gpus: how many devices each node.
:param world_size: how many devices totally.
......
......@@ -32,7 +32,7 @@ namespace views = ranges::views;
namespace mgb::imperative::python {
std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py;
interpreter::Interpreter::Channel* interpreter_for_py;
PyObject *cpp_apply_with_tracing, *cpp_apply_const_with_tracing,
*cpp_apply_compiled_mode, *cpp_apply_const_compiled_mode;
......@@ -673,7 +673,9 @@ py::object make_empty_tensorwrapper() {
}
void init_tensor(py::module m) {
interpreter_for_py = interpreter::Interpreter::inst().create_channel();
imperative::Tensor::static_initialize();
static auto sl_interpreter_for_py = interpreter::Interpreter::inst().create_channel();
interpreter_for_py = sl_interpreter_for_py.get();
auto* tensor_type = TensorWrapper::wrap_t::type()
.def<&TensorWrapper::numpy>("numpy")
......@@ -724,6 +726,8 @@ void init_tensor(py::module m) {
[](int level) { interpreter_for_py->config_async_level(level); });
m.def("get_async_level",
[]() { return interpreter_for_py->get_async_level(); });
m.def("set_buffer_length",
[](int length) { interpreter_for_py->set_buffer_length(length); });
m.def("sync",
[]() {
interpreter_for_py->sync();
......
......@@ -34,7 +34,7 @@ struct ObjectPtr : B {
namespace mgb::imperative::python {
extern std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py;
extern interpreter::Interpreter::Channel* interpreter_for_py;
class SharedHandle {
using Handle = interpreter::Interpreter::Handle;
......
......@@ -111,6 +111,11 @@ void BlobManagerImpl::defrag(const CompNode& cn) {
MGB_TRY{cn.free_device(cn.alloc_device(tot_sz));}
MGB_CATCH(MemAllocError&, {})
// sort blobs by created time, may be helpful for reduce memory fragment
std::sort(blob_data_arrary.begin(), blob_data_arrary.end(), [](auto& lhs, auto& rhs){
return lhs.blob->id() < rhs.blob->id();
});
// allocate for each storage
for (auto i : blob_data_arrary) {
DeviceTensorStorage d_storage = DeviceTensorStorage(cn);
......
......@@ -22,10 +22,10 @@ class FunctionHooker;
template <typename TRet, typename... TArgs>
class FunctionHooker<TRet(TArgs...)> {
public:
using FunctionType = thin_function<TRet(TArgs&&...)>;
using FunctionType = thin_function<TRet(TArgs...)>;
//Type of hooks. Hook should accept a real function as argument
//and invoke it on an appropriate time
using HookType = thin_function<TRet(FunctionType, TArgs&&...)>;
using HookType = thin_function<TRet(FunctionType, TArgs...)>;
explicit FunctionHooker(FunctionType* fptr) : m_fptr{fptr} {
m_backup = {nullptr, [](FunctionType*){}};
}
......@@ -43,7 +43,7 @@ public:
m_backup = decltype(m_backup)(backup, restorer);
}
//Replace with hooked version
*m_fptr = [func = *m_fptr, hook](TArgs&&... args) -> TRet {
*m_fptr = [func = *m_fptr, hook](TArgs... args) -> TRet {
return hook(func, std::forward<TArgs>(args)...);
};
//Convinent for chain call
......@@ -58,7 +58,7 @@ private:
//Helps to deduce template args
template <typename TRet, typename... TArgs>
FunctionHooker(thin_function<TRet(TArgs...)>* f)
->FunctionHooker<TRet(TArgs...)>;
-> FunctionHooker<TRet(TArgs...)>;
template<typename TSignature>
auto make_shared_hook(thin_function<TSignature>* fptr){
......
......@@ -11,20 +11,20 @@
#include "./interpreter_impl.h"
#include "megbrain/common.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/autogen.h"
using namespace mgb;
using namespace imperative;
using namespace interpreter;
using namespace interpreter::intl;
std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
return std::make_unique<ChannelImpl>();
}
Interpreter& Interpreter::inst() {
Tensor::_static_init();
static InterpreterImpl inst_;
return inst_;
}
......@@ -35,7 +35,7 @@ void* ChannelImpl::put(const HostTensorND& value, bool no_cache) {
info->desc.comp_node = value.comp_node();
info->desc.value = value.proxy_to_default_cpu();
m_valid_handle.insert(info);
m_worker.add_task(Put{info, value, no_cache});
m_buffer.enqueue(Put{info, value, no_cache});
return info;
}
......@@ -50,14 +50,14 @@ void* ChannelImpl::put(const DeviceTensorND& data) {
void ChannelImpl::del(void* handle) {
mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle);
m_worker.add_task(Del{reinterpret_cast<TensorInfo*>(handle)});
m_buffer.enqueue(Del{reinterpret_cast<TensorInfo*>(handle)});
}
void ChannelImpl::swap_in(void* handle) {
if (m_enable_evict & SWAP) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
m_worker.add_task(SwapIn{reinterpret_cast<TensorInfo*>(handle)});
m_buffer.enqueue(SwapIn{reinterpret_cast<TensorInfo*>(handle)});
}
}
......@@ -65,7 +65,7 @@ void ChannelImpl::swap_out(void* handle) {
if (m_enable_evict & SWAP) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
m_worker.add_task(SwapOut{reinterpret_cast<TensorInfo*>(handle)});
m_buffer.enqueue(SwapOut{reinterpret_cast<TensorInfo*>(handle)});
}
}
......@@ -73,7 +73,7 @@ void ChannelImpl::drop(void* handle) {
if (m_enable_evict & DROP) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
m_worker.add_task(Drop{reinterpret_cast<TensorInfo*>(handle)});
m_buffer.enqueue(Drop{reinterpret_cast<TensorInfo*>(handle)});
}
}
......@@ -88,14 +88,16 @@ SmallVector<void*> ChannelImpl::apply_op(
input_infos.reserve(inputs.size());
SmallVector<LogicalTensorDesc> input_descs;
input_descs.reserve(inputs.size());
std::unique_lock<decltype(m_mutex)> lock(m_mutex);
for (auto i : inputs) {
auto info = reinterpret_cast<TensorInfo*>(i);
mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!");
input_infos.push_back(info);
input_descs.push_back(info->desc);
{
MGB_LOCK_GUARD(m_mutex);
for (auto i : inputs) {
auto info = reinterpret_cast<TensorInfo*>(i);
mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!");
input_infos.push_back(info);
input_descs.push_back(info->desc);
}
}
lock.unlock();
auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
ApplyOp cmd{std::move(op)};
......@@ -127,7 +129,7 @@ SmallVector<void*> ChannelImpl::apply_op(
}
}
}
m_worker.add_task(std::move(cmd));
m_buffer.enqueue(std::move(cmd));
if (!(validated && validated_bkp) && m_async_level == 1) {
sync();
} else if (m_async_level == 0) {
......@@ -150,7 +152,7 @@ HostTensorND ChannelImpl::get_value(void* handle) {
if (!info->value_fetched) {
mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!");
m_waitee = info;
m_worker.add_task(GetValue{info});
m_buffer.enqueue(GetValue{info});
m_cv.wait(lock, [&]() {
check_worker_exc_unsafe();
return info->value_fetched;
......@@ -171,6 +173,7 @@ TensorShape ChannelImpl::get_shape(void* handle) {
std::unique_lock<decltype(m_mutex)> lock(m_mutex);
mgb_assert(!m_waitee);
m_waitee = info;
m_buffer.enqueue(Flush{info});
m_cv.wait(lock, [&]() {
check_worker_exc_unsafe();
return bool(info->ptr);
......@@ -206,6 +209,7 @@ DeviceTensorND ChannelImpl::get_dev_tensor(void* handle) {
std::unique_lock<decltype(m_mutex)> lock(m_mutex);
mgb_assert(!m_waitee);
m_waitee = info;
m_buffer.enqueue(Flush{info});
m_cv.wait(lock, [&]() {
check_worker_exc_unsafe();
return bool(info->ptr);
......@@ -215,6 +219,9 @@ DeviceTensorND ChannelImpl::get_dev_tensor(void* handle) {
}
void ChannelImpl::sync() {
if (!m_buffer.empty()) {
m_buffer.enqueue(Flush{});
}
m_worker.wait_all_task_finish();
MGB_LOCK_GUARD(m_mutex);
check_worker_exc_unsafe();
......@@ -350,6 +357,10 @@ void ChannelImpl::set_drop_flag(bool flag) {
}
}
void ChannelImpl::set_buffer_length(int length) {
m_buffer.set_capacity(length);
}
void ChannelImpl::regenerate(TensorInfo* info, bool must_drop = false) {
if (!info->ptr && info->evict_type != NONE) {
if (info->evict_type == SWAP) {
......@@ -401,6 +412,7 @@ void ChannelImpl::process_one_task(Command& cmd) {
} else if constexpr (std::is_same_v<T, ApplyOp>) {
SmallVector<TensorPtr> tensor_inputs;
tensor_inputs.reserve(cmd.inputs.size());
// refcnt == 1, owners: [TensorInfo::ptr]
for (auto i : cmd.inputs) {
if (m_enable_evict && i->evict_type != NONE) {
if (!i->ptr) {
......@@ -408,9 +420,20 @@ void ChannelImpl::process_one_task(Command& cmd) {
}
}
mgb_assert(i->ptr, "Invalid input tensor ptr!");
// refcnt ++, owners: [i->ptr, tensor_inputs]
tensor_inputs.push_back(i->ptr);
}
auto tensor_outputs = OpDef::apply_on_physical_tensor(*cmd.op, tensor_inputs);
// Fused by command buffer. @see: CommandBuffer::fuse_del
// Now if dest is inplacable, it's refcnt would be decreased to 1 and owned by tensor_inputs after Del.
// Note for exprs like 'y = x op x', inplace is unsupported yet but Del would be also fused.
for (auto* del : cmd.dels) {
// refcnt --, owners: [tensor_inputs]
// if it's decreased to 1, would be detected at @see: proxy_graph_detail::apply_on_physical_tensor
free(del);
}
// Here std::move is REQUIRED for removing duplicated references.
auto tensor_outputs = OpDef::apply_on_physical_tensor(
*cmd.op, std::move(tensor_inputs));
mgb_assert(tensor_outputs.size() == cmd.outputs.size());
for (size_t i = 0; i < tensor_outputs.size(); ++i) {
produce_tensor(cmd.outputs[i], std::move(tensor_outputs[i]));
......@@ -436,8 +459,12 @@ void ChannelImpl::process_one_task(Command& cmd) {
do_swap_out(cmd.dest);
} else if constexpr (std::is_same_v<T, Drop>) {
do_drop(cmd.dest);
} else if constexpr (std::is_same_v<T, Move>) {
produce_tensor(cmd.dest, cmd.src->ptr);
free(cmd.src);
} else {
static_assert(!std::is_same_v<T, T>);
static_assert(std::is_same_v<T, Flush> ||
std::is_same_v<T, Nop>);
}
} catch (...) {
MGB_LOCK_GUARD(m_mutex);
......@@ -454,7 +481,6 @@ void ChannelImpl::process_one_task(Command& cmd) {
}, cmd);
}
void ChannelImpl::check_worker_exc_unsafe() {
if (m_worker_exc) {
std::exception_ptr exc;
......@@ -462,3 +488,120 @@ void ChannelImpl::check_worker_exc_unsafe() {
std::rethrow_exception(exc);
}
}
void ChannelImpl::CommandBuffer::enqueue(Command cmd) {
if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) {
return;
}
auto command_repr = std::visit([](auto& cmd){ return cmd.to_string(); }, cmd);
mgb_log_debug("%s Enqueued", command_repr.c_str());
m_commands.push_back(std::move(cmd));
auto flush_pos = flush_pos_for(m_commands.back());
flush(flush_pos);
}
void ChannelImpl::CommandBuffer::flush(Handle pos) {
for (auto iter = m_commands.begin(); iter != pos; ++iter) {
auto command_repr = std::visit([](auto& cmd){ return cmd.to_string(); }, *iter);
mgb_log_debug("%s Flushed", command_repr.c_str());
m_owner->m_worker.add_task(std::move(*iter));
}
m_commands.erase(m_commands.begin(), pos);
}
auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
return std::visit([this](const auto& cmd) {
using T = std::decay_t<decltype(cmd)>;
if constexpr (std::is_same_v<T, ApplyOp>) {
auto* op_type = cmd.op->dyn_typeinfo();
if (op_type == RemoteRecv::typeinfo() ||
op_type == RemoteSend::typeinfo() ||
op_type == CollectiveComm::typeinfo() ||
op_type == opr::InputCallback::typeinfo() ||
op_type == opr::OutputCallback::typeinfo() ||
op_type == BackwardGraph::typeinfo()) {
return m_commands.end();
}
} else if constexpr (std::is_same_v<T, GetValue>) {
return m_commands.end();
} else if constexpr (std::is_same_v<T, Flush>) {
if (cmd.dest == nullptr) {
return m_commands.end();
}
auto produce_iter = find_produce(cmd.dest, {m_commands.begin(), m_commands.end()});
if (produce_iter != m_commands.end()) {
return produce_iter + 1;
}
}
if (m_commands.size() > m_capacity) {
return m_commands.begin() + (m_commands.size() - m_capacity);
}
return m_commands.begin();
}, cmd);
}
/**
* 1. Find ApplyOp(dest) in buffered commands
* 2. Check if there are other usages between ApplyOp and Del, return false if not
* 3. Fuse Del into ApplyOp, return true
*/
bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) {
auto* dest = cmd.dest;
// TODO: eliminate Puts
auto begin = m_commands.begin(), end = m_commands.end();
auto apply_iter = std::find_if(begin, end, [dest](const Command& cmd){
if (auto* apply = std::get_if<ApplyOp>(&cmd)) {
return std::count(apply->inputs.begin(), apply->inputs.end(), dest) > 0;
}
return false;
});
if (apply_iter == end || find_last_usage(dest, {apply_iter+1, end}) != end) {
return false;
}
mgb_log_debug("%s Fused", cmd.to_string().c_str());
std::get<ApplyOp>(*apply_iter).dels.push_back(dest);
return true;
}
auto ChannelImpl::CommandBuffer::find_last_usage(TensorInfo* dest, Range range)
-> Handle {
auto found = range[1];
for (auto iter = range[0]; iter != range[1]; ++iter) {
std::visit([&](const auto& cmd) {
using T = std::decay_t<decltype(cmd)>;
if constexpr (std::is_same_v<T, ApplyOp>) {
if (std::count(cmd.inputs.begin(), cmd.inputs.end(),
dest) > 0) {
found = iter;
}
} else if constexpr (std::is_same_v<T, GetValue>) {
if (cmd.dest == dest) {
found = iter;
}
} else if constexpr (std::is_same_v<T, SwapIn> ||
std::is_same_v<T, SwapOut> ||
std::is_same_v<T, Drop>) {
//TODO: ignore swap-like commands, just remove them from buffer
if (cmd.dest == dest) {
found = iter;
}
}
}, *iter);
};
return found;
}
auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range)
-> Handle {
return std::find_if(range[0], range[1], [dest](auto& cmd) {
return std::visit([dest](const auto& cmd){
using T = std::decay_t<decltype(cmd)>;
if constexpr (std::is_same_v<T, ApplyOp>) {
return std::count(cmd.outputs.begin(), cmd.outputs.end(), dest) > 0;
} else if constexpr (std::is_same_v<T, Put>) {
return cmd.dest == dest;
}
return false;
}, cmd);
});
}
......@@ -9,13 +9,15 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include <variant>
#include <deque>
#include <future>
#include <list>
#include <unordered_set>
#include <variant>
#include "megbrain/utils/mempool.h"
#include "megbrain/imperative/interpreter.h"
namespace mgb::imperative::interpreter::intl {
using Handle = Interpreter::Handle;
......@@ -58,39 +60,99 @@ struct Put {
TensorInfo* dest;
HostTensorND value;
bool no_cache = false;
std::string to_string() const { return ssprintf("Command: Put %p", dest); }
};
struct ApplyOp {
std::shared_ptr<OpDef> op;
SmallVector<TensorInfo*> inputs;
SmallVector<TensorInfo*> outputs;
SmallVector<TensorInfo*> dels;
std::string to_string() const {
std::string builder{"Command: ApplyOp {"};
builder += "inputs [";
for (auto* input : inputs) {
builder += ssprintf("%p, ", input);
}
builder += "], outputs [";
for (auto* output : outputs) {
builder += ssprintf("%p, ", output);
}
builder += "], dels [";
for (auto* del : dels) {
builder += ssprintf("%p, ", del);
}
builder += "]";
return builder;
}
};
struct Del {
TensorInfo* dest;
std::string to_string() const { return ssprintf("Command: Del %p", dest); }
};
struct GetValue {
TensorInfo* dest;
};
std::string to_string() const {
return ssprintf("Command: GetValue %p", dest);
}
};
struct SwapIn {
TensorInfo* dest;
std::string to_string() const {
return ssprintf("Command: SwapIn %p", dest);
}
};
struct SwapOut {
TensorInfo* dest;
std::string to_string() const {
return ssprintf("Command: SwapOut %p", dest);
}
};
struct Drop {
TensorInfo* dest;
std::string to_string() const {
return ssprintf("Command: Drop %p", dest);
}
};
struct Move {
TensorInfo* src;
TensorInfo* dest;
std::string to_string() const {
return ssprintf("Command: Move %s to %s",
src->desc.layout.to_string().c_str(),
dest->desc.layout.to_string().c_str());
}
};
struct Flush {
TensorInfo* dest = nullptr;
std::string to_string() const {
return ssprintf("Command: Flush %p", dest);
}
};
struct Nop {
std::string to_string() const { return "Command: Nop"; }
};
using Command = std::variant<Put,
ApplyOp,
Del,
GetValue,
SwapIn,
SwapOut,
Drop>;
Drop,
Move,
Flush,
Nop>;
struct ChannelImpl : Interpreter::Channel {
ChannelImpl() : m_worker(this) {}
ChannelImpl() : m_worker(this), m_buffer(this) {}
~ChannelImpl() override;
Handle put(const HostTensorND& value, bool no_cache) override;
......@@ -116,6 +178,7 @@ struct ChannelImpl : Interpreter::Channel {
void close() override;
void set_swap_flag(bool) override;
void set_drop_flag(bool) override;
void set_buffer_length(int) override;
void config_async_level(int level) override;
int get_async_level() override;
......@@ -174,7 +237,56 @@ private:
std::mutex mtx;
std::unordered_map<TensorInfo*, TensorInfoPtr> tmap;
}m_st;
/**
* Buf a command window for following fuse
* example:
* ---------------------------------------------------------------------
* | ..., Apply{in: (i0, i1), out: (o0, o1)}, ... + Del{i0} + Del{i1} |
* ---------------------------------------------------------------------
* | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0)}, ... + Del{i1} |
* ---------------------------------------------------------------------
* | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0, i1)}, ... |
* ---------------------------------------------------------------------
* Then the fused Apply may be invoked inplace. see: ChannelImpl::process_one_task
*/
struct CommandBuffer {
CommandBuffer(ChannelImpl* owner) : m_owner(owner) {
int capacity = 3;
if(const char* capacity_str = MGB_GETENV("MEGENGINE_COMMAND_BUFFER_LENGTH")) {
capacity = atoi(capacity_str);
}
set_capacity(capacity);
}
void enqueue(Command cmd);
bool empty() const {
return m_commands.empty();
}
void set_capacity(int capacity) {
mgb_assert(capacity >= 0 && capacity < 100, "invalid command buffer length");
m_capacity = capacity;
}
private:
ChannelImpl* m_owner;
size_t m_capacity;
std::deque<Command> m_commands;
using Handle = decltype(m_commands)::iterator;
// [begin, end)
using Range = std::array<Handle, 2>;
// Launch commands in range [m_commands.begin(), pos)
void flush(Handle pos);
// Select flush position for incoming cmd
Handle flush_pos_for(const Command& cmd);
// Fuse del command into suitable ApplyOp
bool fuse_del(const Del& cmd);
// Returns the last handle that dest is used within range. If dest is not used, returns range[1]
Handle find_last_usage(TensorInfo* dest, Range range);
// Returns the produce position of dest. If not found, returns range[1]
Handle find_produce(TensorInfo* dest, Range range);
} m_buffer;
//! config whether raise error exactly when invoking op.
//! level 2: both device and user side errors are async;
//! level 1: user side errors are sync;
......
......@@ -32,8 +32,8 @@ std::shared_ptr<OpDef> OpDef::make_from_op_node(
SmallVector<TensorPtr> OpDef::apply_on_physical_tensor(
const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
return def.trait()->apply_on_physical_tensor(def, inputs);
SmallVector<TensorPtr> inputs) {
return def.trait()->apply_on_physical_tensor(def, std::move(inputs));
}
VarNodeArray OpDef::apply_on_var_node(
......
......@@ -17,17 +17,17 @@ namespace mgb {
namespace imperative {
namespace detail {
template<typename Signature>
template <typename Signature>
struct OpMeth;
template<typename RType, typename ...Args>
struct OpMeth<RType(Args...)>: public thin_function<RType(Args...)> {
template <typename RType, typename... Args>
struct OpMeth<RType(Args...)> : public thin_function<RType(Args...)> {
using Base = thin_function<RType(Args...)>;
using Base::Base;
RType operator()(Args... args) const {
if (!this->Base::operator bool()) {
mgb_throw(MegBrainError, "Not Implemented");
}
return this->Base::operator ()(args...);
return this->Base::operator()(std::forward<Args>(args)...);
}
};
template<typename T>
......@@ -56,7 +56,7 @@ struct ToVarNodeArray<cg::OperatorNodeBase*>: std::true_type {
return opr->usable_output();
}
};
} // detail
} // namespace detail
using OpDefMaker = detail::OpMeth<
decltype(OpDef::make_from_op_node)>;
......
......@@ -56,17 +56,15 @@ protected:
return {};
}
AsyncReleaser() {
EventPool::without_timer();
}
public:
static AsyncReleaser* inst() {
static AsyncReleaser releaser;
return &releaser;
}
~AsyncReleaser() { m_waiter.wait_task_queue_empty(); }
~AsyncReleaser() {
m_waiter.wait_task_queue_empty();
}
void add(BlobPtr blob, CompNode cn) { add(cn, std::move(blob), {}); }
......@@ -85,8 +83,6 @@ public:
class CompNodeSyncManager : public CompNodeDepedentObject {
ThinHashMap<Blob*, std::unique_ptr<CompNode::Event>> m_blob2event;
std::mutex m_mtx;
private:
static CompNodeSyncManager mgr;
public:
std::shared_ptr<void> on_comp_node_finalize() override {
MGB_LOCK_GUARD(m_mtx);
......@@ -94,8 +90,9 @@ public:
return {};
}
static CompNodeSyncManager* inst() {
return &mgr;
static CompNodeSyncManager& inst() {
static CompNodeSyncManager sl_inst;
return sl_inst;
}
CompNode::Event* get_or_create_event(Blob* blob) {
......@@ -113,7 +110,6 @@ public:
m_blob2event.erase(blob);
}
};
CompNodeSyncManager CompNodeSyncManager::mgr;
// Cache for small blobs
// 1. A blob has to be seen twice (within a window) to be eligible for cache
......@@ -236,9 +232,12 @@ struct MultiCNConstTensorCache : CompNodeDepedentObject {
MGB_LOCK_GUARD(mtx);
return cn2cache[hv.comp_node()].lookup(hv);
}
};
MultiCNConstTensorCache const_tensor_cache;
static MultiCNConstTensorCache& inst() {
static MultiCNConstTensorCache sl_inst;
return sl_inst;
}
};
} // namespace
......@@ -246,20 +245,26 @@ void EventDeleter::operator()(CompNode::Event* event) {
EventPool::without_timer().free(event);
}
namespace {
std::atomic_uint64_t next_blob_id = 0;
}
Blob::Blob(const DeviceTensorStorage& s):
m_comp_node{s.comp_node()}, m_storage{s.raw_storage()},
m_size{s.size()} {
m_id = next_blob_id++;
BlobManager::inst()->register_blob(this);
}
Blob::Blob(CompNode cn, size_t sz):
m_comp_node{cn}, m_storage{}, m_size{sz} {
m_id = next_blob_id++;
BlobManager::inst()->register_blob(this);
}
Blob::~Blob() {
BlobManager::inst()->unregister_blob(this);
CompNodeSyncManager::inst()->remove(this);
CompNodeSyncManager::inst().remove(this);
}
const Blob::RawStorage& Blob::storage() {
......@@ -302,7 +307,7 @@ Tensor::Tensor(const BlobPtr blob, const size_t offset, const TensorLayout& layo
: m_layout{layout}, m_blob{blob}, m_offset{offset} {}
TensorPtr Tensor::make(const HostTensorND& hv) {
auto&& blob = const_tensor_cache.lookup(hv);
auto&& blob = MultiCNConstTensorCache::inst().lookup(hv);
if (blob) {
return make(std::forward<decltype(blob)>(blob), hv.layout(), hv);
}
......@@ -366,13 +371,17 @@ void Tensor::add_release_callback(CompNode cn) {
}
CompNode::Event* Tensor::get_or_create_event() {
auto e = CompNodeSyncManager::inst()->get_or_create_event(m_blob.get());
auto e = CompNodeSyncManager::inst().get_or_create_event(m_blob.get());
e->record();
return e;
}
void Tensor::_static_init() {
void Tensor::static_initialize() {
EventPool::with_timer();
EventPool::without_timer();
AsyncReleaser::inst();
CompNodeSyncManager::inst();
MultiCNConstTensorCache::inst();
}
} // namespace imperative
......
......@@ -117,7 +117,7 @@ void Profiler::start(uint32_t flags) {
auto hook_apply_on_var_node =
make_shared_hook(&trait.apply_on_var_node);
hook_apply_on_physical_tensor->apply_hook([this, flags]
(auto&& apply, const OpDef& def, const SmallVector<TensorPtr>& inputs) {
(auto&& apply, const OpDef& def, SmallVector<TensorPtr> inputs) {
auto shape2vector = [](const TensorShape& shape) {
std::vector<size_t> vector_shape;
for (size_t i = 0; i < shape.ndim; i++) {
......
......@@ -11,6 +11,7 @@
#include "./proxy_graph.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/imperative/ops/autogen.h"
namespace mgb {
namespace imperative {
......@@ -70,11 +71,34 @@ void exec(const OpDef& def,
SmallVector<TensorPtr>
apply_on_physical_tensor(const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
auto desc = infer_output_attrs(def, inputs);
SmallVector<TensorPtr> outputs;
for (auto&& i : desc) {
outputs.push_back(Tensor::make(i.layout, i.comp_node));
SmallVector<TensorPtr> inputs) {
auto output_descs = infer_output_attrs(def, inputs);
SmallVector<TensorPtr> outputs(output_descs.size(), {});
for (size_t i = 0; i < outputs.size(); i++) {
auto& output = outputs[i];
auto& output_desc = output_descs[i];
if (def.same_type<Elemwise>()) {
for (size_t j = 0; j < inputs.size(); j++) {
// TODO: reindex inputs to support inplace exprs like 'y = x op x'.
auto& input = inputs[j];
// Because we pass inputs by value, if input and input->blob() are all unique,
// their ownerships are on the stack, thus we can reuse them safely.
// @see: interpreter::intl::ChannelImpl::process_one_task
if (input.unique() && input->blob().unique() && input->blob()->storage().unique() &&
input->layout().dtype == output_desc.layout.dtype &&
input->layout().eq_layout(output_desc.layout) &&
input->comp_node() == output_desc.comp_node) {
static std::atomic_llong inplace_count = 0;
mgb_log_debug("do inplace for elemwise, layout: %s, count: %lld",
output_desc.layout.to_string().c_str(), ++inplace_count);
output = Tensor::make(input->blob(), input->layout(), input->offset());
break;
}
}
}
if (!output) {
output = Tensor::make(output_desc.layout, output_desc.comp_node);
}
}
exec(def, inputs, outputs);
return outputs;
......
......@@ -44,6 +44,7 @@ struct Interpreter {
virtual void close() = 0;
virtual void set_swap_flag(bool) = 0;
virtual void set_drop_flag(bool) = 0;
virtual void set_buffer_length(int) = 0;
virtual void config_async_level(int level) = 0;
virtual int get_async_level() = 0;
......
......@@ -38,7 +38,7 @@ public:
static SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def,
const SmallVector<TensorPtr>& inputs);
SmallVector<TensorPtr> inputs);
static cg::VarNodeArray apply_on_var_node(
const OpDef& def,
......
......@@ -46,11 +46,16 @@ public:
size_t size() const {
return m_size;
}
size_t id() const {
return m_id;
}
private:
friend class BlobManagerImpl;
CompNode m_comp_node;
mutable RawStorage m_storage;
size_t m_size = 0;
size_t m_id;
};
struct EventDeleter {
......@@ -134,8 +139,7 @@ public:
// Make sure all static objects required to destruct a tensor has completed
// construction. All static storage duration object that holds tensors must
// call this method before their constructors completes.
static void _static_init();
static void static_initialize();
private:
TensorLayout m_layout;
......
......@@ -19,7 +19,7 @@ namespace proxy_graph_detail {
SmallVector<TensorPtr>
apply_on_physical_tensor(const OpDef& def,
const SmallVector<TensorPtr>& inputs);
SmallVector<TensorPtr> inputs);
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册