提交 9f2af209 编写于 作者: M Megvii Engine Team

feat(mgb): add enflame comp node

GitOrigin-RevId: 478c8538aa890dddf4e6ca95d1c4bb8a8b49ed8e
上级 15d3b3b9
......@@ -31,6 +31,7 @@
#endif
#if MEGDNN_WITH_CUDA
#include "src/cuda/handle.h"
#endif
......
......@@ -18,6 +18,7 @@
#endif
#if MEGDNN_WITH_ROCM
#include "src/rocm/megcore/computing_context.hpp"
#endif
......
......@@ -182,7 +182,8 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) {
}
dev_type = DeviceType::MULTITHREAD;
ptr += 11;
} else {
}
else {
if (ptr[1] != 'p' || ptr[2] != 'u') {
err();
}
......@@ -237,7 +238,7 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) {
//! num_stream stores the nr_thread
std::swap(num_dev, num_stream);
}
return {dev_type, num_dev, {num_stream}};
}
......
......@@ -1021,13 +1021,12 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by(
{
auto type = cn_impl->env().property().type;
mgb_throw_if(
type != CompNode::DeviceType::CPU &&
type != CompNode::DeviceType::CUDA
&& type != CompNode::DeviceType::ATLAS &&
type != CompNode::DeviceType::CAMBRICON,
MegBrainError,
"currently CPU can only wait for CPU, CUDA, ATLAS, CAMBRICON"
mgb_throw_if(type != CompNode::DeviceType::CPU
&& type != CompNode::DeviceType::CUDA
&& type != CompNode::DeviceType::ATLAS
,
MegBrainError,
"currently CPU can only wait for CPU, CUDA, ATLAS"
);
}
......
......@@ -36,6 +36,7 @@
#endif
using namespace mgb;
/* =================== MegDNNHandle =================== */
......@@ -232,6 +233,7 @@ void CompNodeEnv::init_cuda_async(int dev, CompNode comp_node,
}
#endif
#if MGB_ATLAS
void mgb::_on_atlas_error(const char* expr, int err, const char* file,
......@@ -421,6 +423,7 @@ void CompNodeEnv::fini() {
MGB_CUDA_CHECK(cudaStreamDestroy(m_cuda_env.stream));
}
#endif
#if MGB_ROCM
if (m_property.type == DeviceType::ROCM) {
m_rocm_env.activate();
......@@ -440,6 +443,7 @@ void CompNodeEnv::fini() {
MGB_ATLAS_CHECK(aclrtDestroyStream(m_atlas_env.stream));
}
#endif
}
#if MGB_ENABLE_COMP_NODE_ASYNC_INIT
......
......@@ -73,6 +73,7 @@ std::string CudaError::get_cuda_extra_info() {
#endif
}
AtlasError::AtlasError(const std::string &msg):
SystemError(msg)
{
......
......@@ -82,7 +82,7 @@ class CompNode {
CAMBRICON = 3,
ROCM = 8,
ATLAS = 9,
MULTITHREAD,
MULTITHREAD = 11,
MAX_DEVICE_ID,
};
static constexpr size_t NR_DEVICE_TYPE =
......
......@@ -63,6 +63,7 @@
#endif //MGB_ENABLE_LOGGING
#endif //MGB_CUDA
#if MGB_ATLAS
#include "megcore_atlas.h"
#include <atomic>
......@@ -205,6 +206,7 @@ namespace mgb {
#endif
#if MGB_ROCM
[[noreturn]] void _on_hip_error(const char* expr, hipError_t err,
const char* file, const char* func, int line);
......@@ -369,6 +371,7 @@ public:
const ContinuationCtx<cudaStream_t>& cont);
#endif
#if MGB_ATLAS
struct AtlasEnv {
int device = -1;
......
......@@ -139,6 +139,11 @@ public:
CudaError(const std::string& msg);
};
//! Exception type derived from SystemError that cannot be subclassed further
//! (declared final). It is constructed from a message string.
//! NOTE(review): presumably thrown for EnFlame device/runtime failures, since
//! it is introduced together with the enflame comp node -- confirm at the
//! throw sites, which are not visible here.
class EnFlameError final : public SystemError {
public:
//! \param msg error description; how it is stored or forwarded is defined
//! by the out-of-line constructor (not visible here).
EnFlameError(const std::string& msg);
};
class AtlasError final: public SystemError {
public:
AtlasError(const std::string& msg);
......
......@@ -166,6 +166,7 @@ TEST(TestCompNode, Load) {
ASSERT_NE(atlas0, atlas1);
#endif
}
TEST(TestCompNode, FreeAfterFinalize) {
......@@ -754,6 +755,7 @@ TEST(TestCompNodeCambricon, P2PCopy) {
#endif
#endif // MGB_CAMBRICON
#if MGB_ATLAS
TEST(TestCompNodeAtlas, D2DCopy) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册