提交 d68d4d1d 编写于 作者: M Megvii Engine Team

perf(atlas): use async d2d

GitOrigin-RevId: 55914631cb63bc1057b2f2a124d8168b3e1e29cb
上级 d8ac6c70
......@@ -41,26 +41,22 @@ AtlasComputingContext::~AtlasComputingContext() {
/*!
 * \brief Copy \p size_in_bytes bytes from \p src to \p dst.
 *
 * Host<->device transfers go through aclrtMemcpy. Device-to-device transfers
 * are enqueued on m_ctx.stream with aclrtMemcpyAsync, so the caller has to
 * call synchronize() before relying on the destination contents.
 *
 * \param dst destination buffer; also passed as the destination capacity
 *            (size_in_bytes) to the ACL call
 * \param src source buffer
 * \param size_in_bytes number of bytes to copy
 * \param kind direction of the transfer; any value other than the three
 *             handled below is reported through megdnn_throw
 */
void AtlasComputingContext::memcpy(void* dst, const void* src,
                                   size_t size_in_bytes,
                                   megcoreMemcpyKind_t kind) {
    // Each branch issues exactly one copy. There is deliberately no second,
    // kind-generic copy after the switch: that would transfer the data twice.
    switch (kind) {
        case megcoreMemcpyDeviceToHost:
            acl_check(aclrtMemcpy(dst, size_in_bytes, src, size_in_bytes,
                                  ACL_MEMCPY_DEVICE_TO_HOST));
            break;
        case megcoreMemcpyHostToDevice:
            acl_check(aclrtMemcpy(dst, size_in_bytes, src, size_in_bytes,
                                  ACL_MEMCPY_HOST_TO_DEVICE));
            break;
        case megcoreMemcpyDeviceToDevice:
            // Only the d2d path is asynchronous; it runs on m_ctx.stream.
            acl_check(aclrtMemcpyAsync(dst, size_in_bytes, src, size_in_bytes,
                                       ACL_MEMCPY_DEVICE_TO_DEVICE,
                                       m_ctx.stream));
            break;
        default:
            megdnn_throw("bad atlas memcpy kind");
    }
}
void AtlasComputingContext::memset(void* dst, int value, size_t size_in_bytes) {
......@@ -69,11 +65,7 @@ void AtlasComputingContext::memset(void* dst, int value, size_t size_in_bytes) {
}
void AtlasComputingContext::synchronize() {
#if MGB_USE_ATLAS_ASYNC_API
acl_check(aclrtSynchronizeStream(m_ctx.stream));
#else
return;
#endif
}
// vim: syntax=cpp.doxygen
......@@ -230,10 +230,10 @@ void AtlasCompNodeImpl::peer_copy_to(Impl* dest_impl, void* dest,
auto&& src_env = m_env.atlas_env();
activate();
if (dst_env.device == src_env.device) {
#if MGB_USE_ATLAS_ASYNC_API
MGB_ATLAS_CHECK(aclrtMemcpyAsync(dest, size, src, size,
ACL_MEMCPY_DEVICE_TO_DEVICE,
dst_env.stream));
#if 1
MGB_ATLAS_CHECK(aclrtMemcpyAsync(dest, size, src, size,
ACL_MEMCPY_DEVICE_TO_DEVICE,
dst_env.stream));
#else
MGB_ATLAS_CHECK(aclrtMemcpy(dest, size, src, size,
ACL_MEMCPY_DEVICE_TO_DEVICE));
......
......@@ -361,6 +361,7 @@ void AtlasRuntimeOpr::scn_do_execute() {
i, output(i)->cname());
aclmdlAddDatasetBuffer(model_outputs, output_db);
}
MGB_ATLAS_CHECK(aclmdlExecute(m_model_id, model_inputs, model_outputs));
for (size_t i = 0; i < nr_inputs; ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册