Commit c0e2e4c7 authored by: Megvii Engine Team

fix(param_pack): impl param pack concat in imperative_rt

GitOrigin-RevId: 91edd9c0bf9d3020ba0ef2f87ca63c6901f8b1ce
Parent: eac8f841
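
For context on what the new apply_on_physical_tensor path computes: every input except the trailing offsets tensor is flattened and packed back-to-back into one contiguous 1-D output. A minimal host-side sketch of that semantics (the name pack_concat and the plain float/std::vector types are illustrative only, not part of the MegEngine API; the real op performs the copies with the megdnn kernel on the compute device):

#include <cstring>
#include <vector>

// Illustrative only: pack N flat buffers back-to-back into one 1-D output.
std::vector<float> pack_concat(const std::vector<std::vector<float>>& srcs) {
    std::size_t nr_elems = 0;
    for (const auto& s : srcs) nr_elems += s.size();
    std::vector<float> out(nr_elems);
    std::size_t offset = 0;
    for (const auto& s : srcs) {
        std::memcpy(out.data() + offset, s.data(), s.size() * sizeof(float));
        offset += s.size();
    }
    return out;
}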
/**
* \file imperative/src/impl/async_releaser.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/comp_node.h"
#include "megbrain/imperative/blob_manager.h"
#include "megbrain/system.h"
#include "./event_pool.h"
namespace mgb {
namespace imperative {
class AsyncReleaser : public CompNodeDepedentObject {
struct WaiterParam {
CompNode cn;
CompNode::Event* event;
BlobPtr blob;
HostTensorStorage::RawStorage storage;
};
class Waiter final : public AsyncQueueSC<WaiterParam, Waiter> {
AsyncReleaser* m_par_releaser;
public:
// disable busy waiting by setting max_spin=0 to save CPU cycles
Waiter(AsyncReleaser* releaser)
: AsyncQueueSC<WaiterParam, Waiter>(0),
m_par_releaser(releaser) {}
void process_one_task(WaiterParam& param) {
if (param.event->finished()) {
param.blob.reset();
param.storage.reset();
EventPool::without_timer().free(param.event);
return;
}
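// event not finished yet: back off briefly, then re-enqueue the task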
using namespace std::literals;
std::this_thread::sleep_for(1us);
add_task(std::move(param));
}
void on_async_queue_worker_thread_start() override {
sys::set_thread_name("releaser");
}
};
Waiter m_waiter{this};
protected:
std::shared_ptr<void> on_comp_node_finalize() override {
m_waiter.wait_task_queue_empty();
return {};
}
public:
static AsyncReleaser* inst() {
static AsyncReleaser releaser;
return &releaser;
}
~AsyncReleaser() {
m_waiter.wait_task_queue_empty();
}
void add(BlobPtr blob, CompNode cn) { add(cn, std::move(blob), {}); }
void add(const HostTensorND& hv) {
add(hv.comp_node(), {}, hv.storage().raw_storage());
}
void add(CompNode cn, BlobPtr blob,
HostTensorStorage::RawStorage storage = {}) {
auto event = EventPool::without_timer().alloc(cn);
event->record();
m_waiter.add_task({cn, event, std::move(blob), std::move(storage)});
}
};
}
}
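
A usage sketch for the header above (hypothetical call site, not part of this diff; it only exercises the add() overloads declared in AsyncReleaser):

// Hypothetical call site: ownership of `blob` / the host storage is handed to
// the releaser, which frees them only after the event recorded on the comp
// node has finished, i.e. after all currently queued device work.
void release_after_device_work(BlobPtr blob, CompNode cn, const HostTensorND& hv) {
    AsyncReleaser::inst()->add(std::move(blob), cn);  // device blob
    AsyncReleaser::inst()->add(hv);                   // host tensor storage
}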
@@ -12,6 +12,9 @@
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/opr/tensor_manip.h"
#include "../async_releaser.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
namespace mgb::imperative {
@@ -173,6 +176,7 @@ SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor(
auto&& shapes = get_shapes(param.shapes);
size_t dtype_size = inputs[0]->layout().dtype.size();
for (size_t i = 0; i < shapes.size(); ++i) {
// memory forwarding: each output is a zero-copy sub-view of the packed input
ret.push_back(
inputs[0]->sub(param.offsets[i * 2] * dtype_size, shapes[i]));
}
@@ -197,8 +201,52 @@ cg::OperatorNodeBase* param_pack_concat_apply_on_var_node(
return opr;
}
SmallVector<TensorPtr> param_pack_concat_apply_on_physical_tensor(
const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
def.cast_final_safe<ParamPackConcat>();
mgb_assert(inputs.size() > 1, "ParamPackConcat expects at least one input tensor plus the trailing offsets tensor");
auto comp_node = inputs.front()->comp_node();
auto dtype = inputs.front()->dtype();
size_t nr_inputs = inputs.size() - 1;
size_t nr_elems = 0;
for (size_t i = 0; i < nr_inputs; ++i) {
auto& input = inputs[i];
mgb_assert(comp_node == input->comp_node(), "inputs for param_pack_concat must be on the same comp_node");
mgb_assert(dtype == input->dtype(), "inputs for param_pack_concat must have the same dtype");
nr_elems += input->layout().total_nr_elems();
}
auto dest_layout = TensorLayout({nr_elems}, dtype);
auto output = Tensor::make(dest_layout, comp_node);
auto caller = DnnOprCaller<megdnn::ParamPackConcat>(comp_node);
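// Build a host-side table of the inputs' device pointers. The megdnn kernel
// may read this table asynchronously, so the host buffer is handed to
// AsyncReleaser below rather than being freed when this function returns.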
size_t srcs_size = sizeof(void*)*nr_inputs;
void** srcs_raw_ptr = (void**)comp_node.alloc_host(srcs_size);
std::shared_ptr<dt_byte> srcs_ptr = {(dt_byte*)srcs_raw_ptr, [comp_node](dt_byte* ptr){
comp_node.free_host(ptr);
}};
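// The pointer table is exposed to megdnn as a 1-D tensor; the dtype here is
// only a placeholder, the kernel reads the buffer as an array of pointers.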
TensorLayout srcs_layout = TensorLayout{{nr_inputs}, dtype::Int32()};
size_t ws_size;
{
TensorShapeArray src_shapes;
for (size_t i = 0; i < nr_inputs; ++i) {
src_shapes.push_back(inputs[i]->shape());
}
ws_size = caller.op->get_workspace_in_bytes(src_shapes, inputs.back()->shape(), TensorShape{});
}
for (size_t i = 0; i < nr_inputs; ++i) {
srcs_raw_ptr[i] = inputs[i]->dev_tensor().as_megdnn().raw_ptr;
}
HostTensorStorage srcs_storage;
srcs_storage.reset(comp_node, srcs_size, srcs_ptr);
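// exec(srcs, offsets, dst, workspace): `srcs` wraps the pointer table built
// above and the trailing input is the offsets tensor.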
caller.op->exec({srcs_raw_ptr, srcs_layout}, inputs.back()->dev_tensor().as_megdnn(), output->dev_tensor().as_megdnn(),
caller.create_workspace({{ws_size}, dtype::Byte()}));
AsyncReleaser::inst()->add(HostTensorND{comp_node, srcs_layout}.storage(srcs_storage));
return { output };
}
OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat)
.apply_on_var_node(param_pack_concat_apply_on_var_node)
.apply_on_physical_tensor(param_pack_concat_apply_on_physical_tensor)
.fallback();
} // param_pack
......
@@ -11,7 +11,10 @@
#include "megbrain/imperative.h"
#include "megbrain/imperative/blob_manager.h"
#include "./event_pool.h"
#include "./async_releaser.h"
#include <mutex>
namespace mgb {
@@ -19,70 +22,6 @@ namespace imperative {
namespace {
class AsyncReleaser : public CompNodeDepedentObject {
struct WaiterParam {
CompNode cn;
CompNode::Event* event;
BlobPtr blob;
HostTensorStorage::RawStorage storage;
};
class Waiter final : public AsyncQueueSC<WaiterParam, Waiter> {
AsyncReleaser* m_par_releaser;
public:
// disable busy wait by set max_spin=0 to save CPU cycle
Waiter(AsyncReleaser* releaser)
: AsyncQueueSC<WaiterParam, Waiter>(0),
m_par_releaser(releaser) {}
void process_one_task(WaiterParam& param) {
if (param.event->finished()) {
param.blob.reset();
param.storage.reset();
EventPool::without_timer().free(param.event);
return;
}
using namespace std::literals;
std::this_thread::sleep_for(1us);
add_task(std::move(param));
}
void on_async_queue_worker_thread_start() override {
sys::set_thread_name("releaser");
}
};
Waiter m_waiter{this};
protected:
std::shared_ptr<void> on_comp_node_finalize() override {
m_waiter.wait_task_queue_empty();
return {};
}
public:
static AsyncReleaser* inst() {
static AsyncReleaser releaser;
return &releaser;
}
~AsyncReleaser() {
m_waiter.wait_task_queue_empty();
}
void add(BlobPtr blob, CompNode cn) { add(cn, std::move(blob), {}); }
void add(const HostTensorND& hv) {
add(hv.comp_node(), {}, hv.storage().raw_storage());
}
void add(CompNode cn, BlobPtr blob,
HostTensorStorage::RawStorage storage = {}) {
auto event = EventPool::without_timer().alloc(cn);
event->record();
m_waiter.add_task({cn, event, std::move(blob), std::move(storage)});
}
};
class CompNodeSyncManager : public CompNodeDepedentObject {
ThinHashMap<Blob*, std::unique_ptr<CompNode::Event>> m_blob2event;
std::mutex m_mtx;
......