提交 c8697a70 编写于 作者: M Megvii Engine Team

feat(imperative/src): python wrapper for cambricon and atlas runtime opr

GitOrigin-RevId: bd969d1339463645d559cf1d0d016713d04191d9
上级 8cfed4a1
......@@ -529,7 +529,11 @@ class InputNode(OpNode):
@property
def device(self):
return self.outputs[0].device
var = self.outputs[0]
if isinstance(var, VarNode):
return var.device
else:
return var.comp_node
@property
def dtype(self):
......
......@@ -36,6 +36,10 @@ def _str2device_type(type_str: str, allow_unspec: bool = True):
return DeviceType.CPU
elif type_str == "GPU" or type_str == "CUDA":
return DeviceType.CUDA
elif type_str == "CAMBRICON":
return DeviceType.CAMBRICON
elif type_str == "ATLAS":
return DeviceType.ATLAS
else:
assert allow_unspec and str == "XPU", "device type can only be cpu, gpu or xpu"
return DeviceType.UNSPEC
......@@ -65,6 +69,24 @@ def is_cuda_available() -> bool:
return CompNode._get_device_count(t, False) > 0
def is_cambricon_available() -> bool:
"""
Returns whether cambricon device is available on this system.
"""
t = _str2device_type("cambricon")
return CompNode._get_device_count(t, False) > 0
def is_atlas_available() -> bool:
"""
Returns whether atlas device is available on this system.
"""
t = _str2device_type("atlas")
return CompNode._get_device_count(t, False) > 0
def set_default_device(device: str = "xpux"):
r"""
Sets default computing node.
......
......@@ -20,3 +20,30 @@ def tensorrt_runtime_opr(inputs, *, data: bytes = None):
op = builtin.TensorRTRuntime(data, len(data))
# return sequence of outputs
return apply(op, *inputs)
def cambricon_runtime_opr(inputs, data, symbol, tensor_dim_mutable):
r"""
Load a serialized Cambricon model as a runtime operator in MegEngine.
:param inputs: list of input tensors.
:param data: the serialized Cambricon model.
:param symbol: name of the function in Cambricon model.
:param tensor_dim_mutable: whether the input tensors' shapes are mutable
in ``cnrtModel_t``.
"""
op = builtin.CambriconRuntime(data, len(data), symbol, tensor_dim_mutable)
return apply(op, *inputs)
def atlas_runtime_opr(inputs, data):
r"""
Load a serialized Atlas model as a runtime operator in MegEngine.
:param inputs: list of input tensors.
:param data: the serialized Atlas model.
"""
op = builtin.AtlasRuntime(data, len(data))
return apply(op, *inputs)
......@@ -786,7 +786,11 @@ class trace:
)
output_names = output_names or self._output_names
dumped_device = as_device("xpux")
def dumped_device(info):
device_name = info.device.logical_name
if device_name[:3] in ("cpu", "gpu", "xpu"):
return as_device("xpux")
return info.device
h2v = {}
graph = G.Graph()
......@@ -794,19 +798,21 @@ class trace:
# apply graph_opt_level in dump
if self._graph_opt_level is not None:
graph.options.graph_opt_level = self._graph_opt_level
for i, h in enumerate(self._arg_bindings):
info = self._tinfo[h]
h2v[h] = graph.make_h2d(
dtype=info.dtype,
device=dumped_device,
device=dumped_device(info),
shape=info.shape or (1,),
name=arg_names[i] if arg_names else None,
)
for k, h in self._kwarg_bindings.items():
info = self._tinfo[h]
h2v[h] = graph.make_h2d(
dtype=info.dtype, device=dumped_device, shape=info.shape or (1,), name=k
dtype=info.dtype,
device=dumped_device(info),
shape=info.shape or (1,),
name=k,
)
for op, ihandles, ohandles in self._seq:
......@@ -833,7 +839,7 @@ class trace:
h2v[h] = graph.make_const(
info.bound_data.numpy(),
dtype=info.dtype,
device=dumped_device,
device=dumped_device(info),
name=info.name,
)
ivars.append(h2v[h])
......
......@@ -9,7 +9,11 @@
# pylint: disable=redefined-builtin
import numpy as np
from ..functional.external import tensorrt_runtime_opr
from ..functional.external import (
atlas_runtime_opr,
cambricon_runtime_opr,
tensorrt_runtime_opr,
)
from .module import Module
......@@ -33,3 +37,52 @@ class TensorrtRuntimeSubgraph(Module):
def forward(self, *inputs):
return tensorrt_runtime_opr(inputs, data=self._data)
class CambriconRuntimeSubgraph(Module):
r"""Load a serialized CambriconRuntime subgraph.
See :func:`~.cambricon_runtime_opr` for more details.
"""
def __init__(self, data, symbol, tensor_dim_mutable, **kwargs):
super(CambriconRuntimeSubgraph, self).__init__(**kwargs)
self._data = data
self.symbol = symbol
self.tensor_dim_mutable = tensor_dim_mutable
@property
def data(self):
return self._data
@data.setter
def data(self, val):
self._data = np.frombuffer(val, dtype=np.uint8)
def forward(self, *inputs):
outputs = cambricon_runtime_opr(
inputs, self._data, self.symbol, self.tensor_dim_mutable
)
return outputs
class AtlasRuntimeSubgraph(Module):
r"""Load a serialized AtlasRuntime subgraph.
See :func:`~.atlas_runtime_opr` for more details.
"""
def __init__(self, data, **kwargs):
super(AtlasRuntimeSubgraph, self).__init__(**kwargs)
self._data = data
@property
def data(self):
return self._data
@data.setter
def data(self, val):
self._data = np.frombuffer(val, dtype=np.uint8)
def forward(self, *inputs):
return atlas_runtime_opr(inputs, data=self._data)
......@@ -427,8 +427,9 @@ class GraphInference:
list(self._inp_dict.keys()), list(inputs.keys())
)
for key in self._inp_dict:
self._inp_dict[key].set_value(Tensor(inputs[key])._dev_tensor())
self._inp_dict[key].set_value(
Tensor(inputs[key], device=self._inp_dict[key].device)._dev_tensor()
)
self._func.execute()
self._func.wait()
......
......@@ -171,6 +171,8 @@ void init_common(py::module m) {
.value("UNSPEC", CompNode::DeviceType::UNSPEC)
.value("CUDA", CompNode::DeviceType::CUDA)
.value("CPU", CompNode::DeviceType::CPU)
.value("CAMBRICON", CompNode::DeviceType::CAMBRICON)
.value("ATLAS", CompNode::DeviceType::ATLAS)
.value("MULTITHREAD", CompNode::DeviceType::MULTITHREAD)
.value("MAX_DEVICE_ID", CompNode::DeviceType::MAX_DEVICE_ID);
......
/**
* \file imperative/src/impl/ops/tensorrt_runtime.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "../op_trait.h"
#include "megbrain/imperative/ops/autogen.h"
#if MGB_ATLAS
#include "megbrain/opr/atlas_runtime_op.h"
namespace mgb::imperative {
namespace {
namespace atlas_runtime {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const AtlasRuntime&>(def);
SymbolVarArray symbol_var_inputs(inputs.begin(), inputs.end());
OperatorNodeConfig config{op.make_name()};
return opr::AtlasRuntimeOpr::make(op.buf.c_str(), op.buf_size,
symbol_var_inputs, config);
}
OP_TRAIT_REG(AtlasRuntime, AtlasRuntime)
.apply_on_var_node(apply_on_var_node)
.fallback();
} // namespace atlas_runtime
} // namespace
} // namespace mgb::imperative
#endif
/**
* \file imperative/src/impl/ops/tensorrt_runtime.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "../op_trait.h"
#include "megbrain/imperative/ops/autogen.h"
#if MGB_CAMBRICON
#include "megbrain/cambricon/cambricon_runtime_opr.h"
namespace mgb::imperative {
namespace {
namespace cambricon_runtime {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const CambriconRuntime&>(def);
SymbolVarArray symbol_var_inputs(inputs.begin(), inputs.end());
OperatorNodeConfig config{op.make_name()};
return opr::CambriconRuntimeOpr::make(op.buf.c_str(), op.buf_size,
op.symbol, symbol_var_inputs,
op.tensor_dim_mutable, config);
}
OP_TRAIT_REG(CambriconRuntime, CambriconRuntime)
.apply_on_var_node(apply_on_var_node)
.fallback();
} // namespace cambricon_runtime
} // namespace
} // namespace mgb::imperative
#endif
\ No newline at end of file
......@@ -266,6 +266,22 @@ def TensorRTRuntime: MgbHashableOp<"TensorRTRuntime"> {
);
}
def AtlasRuntime: MgbHashableOp<"AtlasRuntime"> {
let extraArguments = (ins
MgbStringAttr:$buf,
MgbSizeTAddr:$buf_size
);
}
def CambriconRuntime: MgbHashableOp<"CambriconRuntime"> {
let extraArguments = (ins
MgbStringAttr:$buf,
MgbSizeTAddr:$buf_size,
MgbStringAttr:$symbol,
MgbBoolAttr:$tensor_dim_mutable
);
}
def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>;
#endif // MGB_OPS
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册