diff --git a/src/core/impl/comp_node/cpu/comp_node.cpp b/src/core/impl/comp_node/cpu/comp_node.cpp index 4fc6657aff1ae8c98f0cca8f1ddc6492e4346df4..ca274fd6f341d4a6e33d55a2a0f71d02a2133a79 100644 --- a/src/core/impl/comp_node/cpu/comp_node.cpp +++ b/src/core/impl/comp_node/cpu/comp_node.cpp @@ -11,18 +11,18 @@ #include "./comp_node.h" +#include "megbrain/common.h" #include "megbrain/comp_node_env.h" #include "megbrain/system.h" #include "megbrain/utils/arith_helper.h" #include "megbrain/utils/thread.h" -#include "megbrain/utils/timer.h" #include "megbrain/utils/thread_pool.h" -#include "megbrain/common.h" +#include "megbrain/utils/timer.h" +#include #include #include #include -#include #include #ifndef __APPLE__ @@ -44,8 +44,6 @@ struct TaskElem { }; } // anonymous namespace -using CpuCompNodeImpl = CpuCompNode::CompNodeImpl; - void CpuCompNode::CpuDispatchableBase::add_callback(Task&& task) { dispatch(std::move(task)); } @@ -110,7 +108,15 @@ class CpuCompNode::SeqRecorderImpl final : public CompNodeSeqRecorder { * \brief use to check the all ther recording tasks are its self CompNode * related task, void hook other CompNode related task to the recorder. */ - void check_the_same_comp_node(const CompNode& comp_node) const; + void check_the_same_comp_node(const CompNode& comp_node) const { + if (mgb_unlikely(comp_node.valid())) { + mgb_assert(m_record_compnode == comp_node, + "CompNode %s can't hook in CompNode %s when recording\n", + comp_node.locator().to_string().c_str(), + m_record_compnode.locator().to_string().c_str()); + } + } + public: SeqRecorderImpl(SeqRecorderImpl** self_pointer, ThreadPool* thread_pool, const CompNode& comp_node) @@ -127,13 +133,13 @@ public: } } - void enter_fake_exec(const CompNode& comp_node) override { + void enter_fake_exec(const CompNode& comp_node) override { check_the_same_comp_node(comp_node); mgb_assert(!m_stopped && !m_fake_exec); m_fake_exec = true; } - void exit_fake_exec(const CompNode& comp_node) override { + void exit_fake_exec(const CompNode& comp_node) override { check_the_same_comp_node(comp_node); mgb_assert(!m_stopped && m_fake_exec); mgb_assert(m_tasks.empty()); @@ -165,9 +171,9 @@ public: m_thread_pool->add_task(i); } m_thread_pool->deactive(); - }else{ + } else { for (auto&& task : m_tasks) { - for(size_t i=0; i m_worker_queue; +//! ==================== CompNodeBaseImpl ====================== +class CpuCompNode::CompNodeBaseImpl : public CpuDispatchableBase { +protected: Locator m_locator, m_locator_logical; - std::unique_ptr m_thread_pool; - - //! ptr to default cpu, only used by check_global_finalized - static CpuCompNodeImpl *sm_default_cpu_comp_node_ptr; - - //! return whether global finalized, and print warning in such case - inline bool check_global_finalized(const char* reason); - - static void static_free_device(ImplBase* self, void* ptr) { - static_cast(self)->free_device(ptr); - } - - static void static_free_host(ImplBase* self, void* ptr) { - static_cast(self)->free_host(ptr); - } - public: - CompNodeImpl(const Locator& locator, const Locator& locator_logical, - const std::shared_ptr& worker_queue); - ~CompNodeImpl() { - if (sm_cur_recorder) { - sm_cur_recorder->stop(); - } - if (m_worker_queue) { - // synchronize before fini - m_worker_queue->wait_all_task_finish(); - } - m_env.fini(); - if (m_worker_queue) { - // wait for new kernels dispatched in fini() (like free_device()) - m_worker_queue->wait_all_task_finish(); - } - if (this == sm_default_cpu_comp_node_ptr) { - // This should only happen in global library .fini. We clear - // sm_default_cpu_comp_node_ptr so check_global_finalized() can - // work correctly - sm_default_cpu_comp_node_ptr = nullptr; - } - } +public: + CompNodeBaseImpl(const Locator& locator, const Locator& locator_logical, + free_func_t fd, free_func_t fh) + : CpuDispatchableBase(fd, fh), + m_locator(locator), + m_locator_logical(locator_logical) {} - ThreadPool* get_thread_pool() const { return m_thread_pool.get(); } + virtual ~CompNodeBaseImpl() {} - void* mgb_aligned_alloc(size_t size) { - auto alignment = get_mem_addr_alignment(); + void* mgb_aligned_alloc(size_t size) { + auto alignment = get_mem_addr_alignment(); #ifdef WIN32 - return _aligned_malloc(size, alignment); + return _aligned_malloc(size, alignment); #elif defined(__ANDROID__) || defined(ANDROID) - return memalign(alignment, size); + return memalign(alignment, size); #else - void *ptr = nullptr; - auto err = posix_memalign(&ptr, alignment, size); - mgb_assert(!err, "failed to malloc %zubytes with align %zu", - size, alignment); - return ptr; + void* ptr = nullptr; + auto err = posix_memalign(&ptr, alignment, size); + mgb_assert(!err, "failed to malloc %zubytes with align %zu", size, + alignment); + return ptr; #endif - } + } - static void mgb_aligned_free(void* ptr) { + static void mgb_aligned_free(void* ptr) { #ifdef WIN32 - _aligned_free(ptr); + _aligned_free(ptr); #else - ::free(ptr); + ::free(ptr); #endif - } - - void* alloc_device(size_t size) override { - if (sm_cur_recorder) { - sm_cur_recorder->on_alloc(this); - } - return mgb_aligned_alloc(size); - } - - void free_device(void *ptr) { - if (sm_cur_recorder || check_global_finalized("free_device()")) { - mgb_aligned_free(ptr); - if (sm_cur_recorder) { - sm_cur_recorder->on_free(this); - } - return; - } else { - auto do_free = [ptr]() { - mgb_aligned_free(ptr); - }; - m_env.cpu_env().dispatch(do_free); - } - } - - void *alloc_host(size_t size) override { - if (m_worker_queue) { - m_worker_queue->check_exception(); - } - return mgb_aligned_alloc(size); - } - - void free_host(void *ptr) { - if (check_global_finalized("free_host()")) { - mgb_aligned_free(ptr); - return; - } - if (m_worker_queue) { - m_worker_queue->check_exception(); - } - return mgb_aligned_free(ptr); - } - - void copy_to_host(void *host_ptr, - const void *device_ptr, size_t size) override { - if (m_worker_queue) { - m_worker_queue->check_exception(); - } - // use lambda capture to avoid memory allocation in std::bind - auto do_copy = [host_ptr, device_ptr, size]() { - std::memcpy(host_ptr, device_ptr, size); - }; - m_env.cpu_env().dispatch(do_copy); - } - - void copy_to_device(void *device_ptr, - const void *host_ptr, size_t size) override { - if (m_worker_queue) { - m_worker_queue->check_exception(); - } - // use lambda capture to avoid memory allocation in std::bind - auto do_copy = [device_ptr, host_ptr, size]() { - std::memcpy(device_ptr, host_ptr, size); - }; - m_env.cpu_env().dispatch(do_copy); - } + } - void peer_copy_to( - Impl *dest_impl, void *dest, - const void *src, size_t size) override { - if (!dest_impl->same_type()) { - if (dest_impl->env().property().type == DeviceType::ATLAS) { -#if MGB_ATLAS - dest_impl->copy_to_device(dest, src, size); - return; -#else - mgb_throw(MegBrainError, - "Atlas comp_node used but " - "MGB_ATLAS not enabled"); -#endif - } else if (dest_impl->env().property().type == - DeviceType::CAMBRICON) { -#if MGB_CAMBRICON - dest_impl->copy_to_device(dest, src, size); - return; -#else - mgb_throw(MegBrainError, - "Cambricon comp_node used but " - "MGB_CAMBRICON not enabled"); -#endif + void* alloc_device(size_t size) override { return mgb_aligned_alloc(size); } - } else { - mgb_assert(locator().device == Locator::DEVICE_CPU_DEFAULT, - "currently only peer copy from default cpu comp " - "nodes " - "is implemented"); - } - } - dest_impl->copy_to_device(dest, src, size); - } + void* alloc_host(size_t size) override { return mgb_aligned_alloc(size); } - size_t get_mem_addr_alignment() override { - return m_env.property().mem_alignment; - } + void copy_to_host(void* host_ptr, const void* device_ptr, + size_t size) override { + // use lambda capture to avoid memory allocation in std::bind + auto do_copy = [host_ptr, device_ptr, size]() { + std::memcpy(host_ptr, device_ptr, size); + }; + m_env.cpu_env().dispatch(do_copy); + } - std::unique_ptr create_event(size_t flags) override; + void copy_to_device(void* device_ptr, const void* host_ptr, + size_t size) override { + // use lambda capture to avoid memory allocation in std::bind + auto do_copy = [device_ptr, host_ptr, size]() { + std::memcpy(device_ptr, host_ptr, size); + }; + m_env.cpu_env().dispatch(do_copy); + } - void sync() override { - if (sm_cur_recorder) { - sm_cur_recorder->on_sync(this); - } else if (m_worker_queue) { - m_worker_queue->wait_all_task_finish(); - } - if (m_thread_pool) { - m_thread_pool->deactive(); - } - } + void peer_copy_to(Impl* dest_impl, void* dest, const void* src, + size_t size) override { + dest_impl->copy_to_device(dest, src, size); + } - void dispatch(Task &&task) override { - m_env.cpu_env().dispatch(std::move(task)); - } + size_t get_mem_addr_alignment() override { + return m_env.property().mem_alignment; + } - MemNode mem_node() override { - // TODO: numa nodes - return get_host_cpu_mem_node(); - } + void dispatch(Task&& task) override { + m_env.cpu_env().dispatch(std::move(task)); + } - std::pair get_mem_status_bytes() override { - return sys::get_ram_status_bytes(); - } + MemNode mem_node() override { + // TODO: numa nodes + return get_host_cpu_mem_node(); + } - Locator locator() override { - return m_locator; - } + std::pair get_mem_status_bytes() override { + return sys::get_ram_status_bytes(); + } - Locator locator_logical() override { - return m_locator_logical; - } + Locator locator() override { return m_locator; } - std::unique_ptr create_seq_recorder( - cg::ComputingGraph*) override { - return std::make_unique(&sm_cur_recorder, - m_thread_pool.get(), this); - } + Locator locator_logical() override { return m_locator_logical; } - //! current sequence recorder of this thread -#if !defined(IOS) && MGB_HAVE_THREAD - static SeqRecorderImpl* cur_recorder() { return sm_cur_recorder; } -#else - SeqRecorderImpl* cur_recorder() { return sm_cur_recorder; } -#endif + void add_callback(Task&& task) override { + CpuDispatchableBase::add_callback(std::move(task)); + } - void add_callback(Task &&task) override { - if (!check_global_finalized("add_callback()")) { - CpuDispatchableBase::add_callback(std::move(task)); - } else { - task(); - } - } + virtual SeqRecorderImpl* cur_recorder() const = 0; }; -MGB_DYN_TYPE_OBJ_FINAL_IMPL(CpuCompNodeImpl); -CpuCompNodeImpl* CpuCompNodeImpl::sm_default_cpu_comp_node_ptr; -#if !defined(IOS) && MGB_HAVE_THREAD -thread_local CpuCompNode::SeqRecorderImpl* CpuCompNodeImpl::sm_cur_recorder = - nullptr; -#endif - -void CpuCompNode::SeqRecorderImpl::check_the_same_comp_node( - const CompNode& comp_node) const { - if (mgb_unlikely(comp_node.valid())) { - mgb_assert(m_record_compnode == comp_node, - "CompNode %s can't hook in CompNode %s when recording\n", - comp_node.locator().to_string().c_str(), - m_record_compnode.locator().to_string().c_str()); - } -} //! implementation of CPUDispatcher that is passed to megdnn via megcore -class CpuCompNode::WorkerQueue::DispatcherImpl final: public CPUDispatcher { +class CpuCompNode::WorkerQueue::DispatcherImpl final : public CPUDispatcher { std::atomic_size_t m_nr_task{0}; std::shared_ptr m_queue; - CpuCompNode::CompNodeImpl* const m_comp_node; + //! DispatcherImpl only used by CompNodeRecorderImpl, but we still use + //! CompNodeBaseImpl* because of incomplete type error + CompNodeBaseImpl* const m_comp_node; public: DispatcherImpl(const std::shared_ptr& queue, - CpuCompNode::CompNodeImpl* comp_node) + CompNodeBaseImpl* comp_node) : m_queue{queue}, m_comp_node{comp_node} {} void dispatch(Task&& task) override { @@ -559,10 +405,12 @@ public: class InplaceCPUDispatcher final : public CPUDispatcher { std::atomic_size_t m_nr_task{0}; ThreadPool* m_thread_pool = nullptr; - CpuCompNode::CompNodeImpl* const m_comp_node; + //! InplaceCPUDispatcher may used by both type of compnodes, so + //! m_comp_node's type should be base class. + CompNodeBaseImpl* const m_comp_node; public: - InplaceCPUDispatcher(CpuCompNode::CompNodeImpl* comp_node, + InplaceCPUDispatcher(CompNodeBaseImpl* comp_node, ThreadPool* thread_pool = nullptr) : m_thread_pool(thread_pool), m_comp_node(comp_node) {} @@ -585,9 +433,9 @@ public: } else if (m_thread_pool) { m_nr_task.fetch_add(1, std::memory_order_relaxed); m_thread_pool->add_task({task, parallelism}); - }else{ + } else { m_nr_task.fetch_add(1, std::memory_order_relaxed); - for(size_t i=0; iget_thread_pool()->set_affinity(affinity_cb); } else if (m_thread_pool) { m_thread_pool->set_affinity(affinity_cb); - }else{ + } else { affinity_cb(0); } } }; -CpuCompNode::CompNodeImpl::CompNodeImpl( - const Locator& locator, const Locator& locator_logical, - const std::shared_ptr& worker_queue) - : CpuDispatchableBase(static_free_device, static_free_host), - m_worker_queue{worker_queue}, - m_locator(locator), - m_locator_logical(locator_logical) { - auto cn = make_comp_node_from_impl(this); - if (locator.type == DeviceType::MULTITHREAD) { - m_thread_pool = std::unique_ptr( - new ThreadPool(static_cast(locator.nr_threads))); - mgb_assert(m_thread_pool, "ThradPool create failed"); +//! ==================== CompNodeNoRecorderImpl ====================== +/** + * \note: CompNodeNoRecorderImpl will use most implements in base including: + * alloc_device, alloc_host, copy_to_host, copy_to_device, peer_copy_to, + * add_callback ... + */ +class CpuCompNode::CompNodeNoRecorderImpl final : public CompNodeBaseImpl { + MGB_DYN_TYPE_OBJ_FINAL_DECL; + +public: + //! ptr to default cpu, only used by check_global_finalized + static CompNodeNoRecorderImpl* sm_default_cpu_comp_node_ptr; + + static void static_free_device(ImplBase* self, void* ptr) { + static_cast(self)->free_device(ptr); } - if (locator.type == DeviceType::CPU) { - if (locator.device == Locator::DEVICE_CPU_DEFAULT) { - sm_default_cpu_comp_node_ptr = this; - m_env.init_cpu({std::make_shared(this)}, cn); - } else { - m_env.init_cpu({std::make_shared( - m_worker_queue, this)}, - cn); - } - } else if (locator.type == DeviceType::MULTITHREAD) { - if (locator.device == Locator::DEVICE_MULTITHREAD_DEFAULT) { - m_env.init_cpu({std::make_shared( - this, m_thread_pool.get())}, - cn); - } else { - m_worker_queue->attach_thread_pool(m_thread_pool.get()); - m_env.init_cpu({std::make_shared( - m_worker_queue, this)}, - cn); + static void static_free_host(ImplBase* self, void* ptr) { + static_cast(self)->free_host(ptr); + } + using CpuEventImpl = CpuDispatchableBase::EventImpl; + + CompNodeNoRecorderImpl(const Locator& locator, + const Locator& locator_logical) + : CompNodeBaseImpl(locator, locator_logical, static_free_device, + static_free_host) { + mgb_assert( + locator.type == DeviceType::CPU && + locator.device == Locator::DEVICE_CPU_DEFAULT, + "CompNodeNoRecorder is only constructed On DEVICE_CPU_DEFAULT"); + auto cn = make_comp_node_from_impl(this); + m_env.init_cpu({std::make_shared(this)}, cn); + sm_default_cpu_comp_node_ptr = this; + } + + ~CompNodeNoRecorderImpl() { + m_env.fini(); + sm_default_cpu_comp_node_ptr = nullptr; + } + + //! return whether global finalized, and print warning in such case + bool check_global_finalized(const char* reason) { + MGB_MARK_USED_VAR(reason); + if (!sm_default_cpu_comp_node_ptr) { + static std::atomic_flag warn_printed = ATOMIC_FLAG_INIT; + if (!warn_printed.test_and_set()) { + mgb_log_debug( + "cpu comp node method called after global finalize: " + "reason=%s", + reason); + } + return true; } + return false; } -} -class CpuCompNodeImpl::CompSeqRecEventImpl final - : public CpuDispatchableBase::EventImpl { - void do_record() override { - auto impl = static_cast(m_comp_node_impl); - if (auto rec = impl->cur_recorder()) { - auto callback = [this]() { - incr_nr_req(); - on_finish(); - }; - rec->dispatch_allow_after_sync(callback, m_comp_node_impl); + void free_device(void* ptr) { + if (check_global_finalized("free_device()")) { + CompNodeBaseImpl::mgb_aligned_free(ptr); + return; } else { - EventImpl::do_record(); + auto do_free = [ptr]() { CompNodeBaseImpl::mgb_aligned_free(ptr); }; + m_env.cpu_env().dispatch(do_free); } } - void do_device_wait_by(Impl*) override { - mgb_throw(MegBrainError, - "device_wait() should not be called on events created during " - "comp node seq recording"); + void free_host(void* ptr) { + check_global_finalized("free_host()"); + return CompNodeBaseImpl::mgb_aligned_free(ptr); } -public: - using EventImpl::EventImpl; + std::unique_ptr create_event(size_t flags) override { + return std::make_unique(this, flags); + } + + void sync() override {} + + std::unique_ptr create_seq_recorder( + cg::ComputingGraph*) override { + mgb_assert(false, "default_cpu has no ability to record"); + return nullptr; + } + + SeqRecorderImpl* cur_recorder() const override { return nullptr; } }; +MGB_DYN_TYPE_OBJ_FINAL_IMPL(CompNodeNoRecorderImpl); +CompNodeNoRecorderImpl* CompNodeNoRecorderImpl::sm_default_cpu_comp_node_ptr = + nullptr; + +//! ==================== CompNodeRecorderImpl ====================== +class CpuCompNode::CompNodeRecorderImpl final : public CompNodeBaseImpl { + MGB_DYN_TYPE_OBJ_FINAL_DECL; + std::unique_ptr m_thread_pool; + std::shared_ptr m_worker_queue; + + //! used during comp node seq rec + class CompSeqRecEventImpl final : public CpuDispatchableBase::EventImpl { + void do_record() override { + auto impl = static_cast(m_comp_node_impl); + if (auto rec = impl->cur_recorder()) { + auto callback = [this]() { + incr_nr_req(); + on_finish(); + }; + rec->dispatch_allow_after_sync(callback, m_comp_node_impl); + } else { + EventImpl::do_record(); + } + } + + void do_device_wait_by(Impl*) override { + mgb_throw(MegBrainError, + "device_wait() should not be called on events created " + "during " + "comp node seq recording"); + } -class CpuCompNodeImpl::CpuEventImpl final - : public CpuDispatchableBase::EventImpl { + public: + using EventImpl::EventImpl; + }; + + class CpuEventImpl final : public CpuDispatchableBase::EventImpl { #if MGB_HAVE_THREAD - void host_wait_cv() override { - CpuDispatchableBase::EventImpl::host_wait_cv(); - auto thread_pool = static_cast(m_comp_node_impl) - ->get_thread_pool(); - if (thread_pool) { - thread_pool->deactive(); + void host_wait_cv() override { + CpuDispatchableBase::EventImpl::host_wait_cv(); + auto thread_pool = + static_cast(m_comp_node_impl) + ->get_thread_pool(); + if (thread_pool) { + thread_pool->deactive(); + } } - } #endif + public: + using EventImpl::EventImpl; + }; + +//! TODO: because the x-code bug, see +//! https://github.com/tensorflow/tensorflow/issues/18356 +//! thread local is no support on IOS, +//! When update x-xode, this code should be deleted +#if !defined(IOS) && MGB_HAVE_THREAD + static thread_local SeqRecorderImpl* sm_cur_recorder; +#else + SeqRecorderImpl* sm_cur_recorder = nullptr; +#endif + public: - using EventImpl::EventImpl; -}; + static void static_free_device(ImplBase* self, void* ptr) { + static_cast(self)->free_device(ptr); + } -std::unique_ptr CpuCompNodeImpl::create_event(size_t flags) { - if (m_worker_queue) { - m_worker_queue->check_exception(); + static void static_free_host(ImplBase* self, void* ptr) { + static_cast(self)->free_host(ptr); } - if (sm_cur_recorder) { - return std::make_unique(this, flags); - } else { - return std::make_unique(this, flags); + + CompNodeRecorderImpl(const Locator& locator, const Locator& locator_logical, + const std::shared_ptr& worker_queue) + : CompNodeBaseImpl(locator, locator_logical, static_free_device, + static_free_host), + m_worker_queue(worker_queue) { + auto cn = make_comp_node_from_impl(this); + if (locator.type == DeviceType::MULTITHREAD) { + m_thread_pool = std::unique_ptr( + new ThreadPool(static_cast(locator.nr_threads))); + mgb_assert(m_thread_pool, "ThradPool create failed"); + } + if (locator.type == DeviceType::CPU) { + if (locator.device == Locator::DEVICE_CPU_DEFAULT) { + m_env.init_cpu({std::make_shared(this)}, + cn); + } else { + m_env.init_cpu({std::make_shared( + m_worker_queue, this)}, + cn); + } + } else if (locator.type == DeviceType::MULTITHREAD) { + if (locator.device == Locator::DEVICE_MULTITHREAD_DEFAULT) { + m_env.init_cpu({std::make_shared( + this, m_thread_pool.get())}, + cn); + } else { + m_worker_queue->attach_thread_pool(m_thread_pool.get()); + m_env.init_cpu({std::make_shared( + m_worker_queue, this)}, + cn); + } + } } -} + + ~CompNodeRecorderImpl() { + if (sm_cur_recorder) { + sm_cur_recorder->stop(); + } + if (m_worker_queue) { + // synchronize before fini + m_worker_queue->wait_all_task_finish(); + } + m_env.fini(); + if (m_worker_queue) { + // wait for new kernels dispatched in fini() (like free_device()) + m_worker_queue->wait_all_task_finish(); + } + } + + ThreadPool* get_thread_pool() const { return m_thread_pool.get(); } + + //! return whether global finalized, and print warning in such case + bool check_global_finalized(const char* reason) { + MGB_MARK_USED_VAR(reason); + if (!sm_pool) { + static std::atomic_flag warn_printed = ATOMIC_FLAG_INIT; + if (!warn_printed.test_and_set()) { + mgb_log_debug( + "cpu comp node method called after global finalize: " + "reason=%s", + reason); + } + return true; + } + return false; + } + + void* alloc_device(size_t size) override { + if (sm_cur_recorder) { + sm_cur_recorder->on_alloc(this); + } + return CompNodeBaseImpl::alloc_device(size); + } + + void free_device(void* ptr) { + if (sm_cur_recorder || check_global_finalized("free_device()")) { + CompNodeBaseImpl::mgb_aligned_free(ptr); + if (sm_cur_recorder) { + sm_cur_recorder->on_free(this); + } + return; + } else { + auto do_free = [ptr]() { CompNodeBaseImpl::mgb_aligned_free(ptr); }; + m_env.cpu_env().dispatch(do_free); + } + } + + void* alloc_host(size_t size) override { + if (m_worker_queue) { + m_worker_queue->check_exception(); + } + return CompNodeBaseImpl::alloc_host(size); + } + + void free_host(void* ptr) { + if (check_global_finalized("free_host()")) { + CompNodeBaseImpl::mgb_aligned_free(ptr); + return; + } + if (m_worker_queue) { + m_worker_queue->check_exception(); + } + CompNodeBaseImpl::mgb_aligned_free(ptr); + } + + void copy_to_host(void* host_ptr, const void* device_ptr, + size_t size) override { + if (m_worker_queue) { + m_worker_queue->check_exception(); + } + CompNodeBaseImpl::copy_to_host(host_ptr, device_ptr, size); + } + + void copy_to_device(void* device_ptr, const void* host_ptr, + size_t size) override { + if (m_worker_queue) { + m_worker_queue->check_exception(); + } + CompNodeBaseImpl::copy_to_device(device_ptr, host_ptr, size); + } + + void peer_copy_to(Impl* dest_impl, void* dest, const void* src, + size_t size) override { + //! copy to default_cpu + if (dest_impl->same_type()) { + CompNodeBaseImpl::peer_copy_to(dest_impl, dest, src, size); + return; + } + + if (!dest_impl->same_type()) { + if (dest_impl->env().property().type == DeviceType::ATLAS) { +#if MGB_ATLAS + dest_impl->copy_to_device(dest, src, size); + return; +#else + mgb_throw(MegBrainError, + "Atlas comp_node used but " + "MGB_ATLAS not enabled"); +#endif + } else if (dest_impl->env().property().type == + DeviceType::CAMBRICON) { +#if MGB_CAMBRICON + dest_impl->copy_to_device(dest, src, size); + return; +#else + mgb_throw(MegBrainError, + "Cambricon comp_node used but " + "MGB_CAMBRICON not enabled"); +#endif + } + else { + mgb_assert(locator().device == Locator::DEVICE_CPU_DEFAULT, + "currently only peer copy from default cpu comp " + "nodes " + "is implemented"); + } + } + dest_impl->copy_to_device(dest, src, size); + } + + std::unique_ptr create_event(size_t flags) override { + if (m_worker_queue) { + m_worker_queue->check_exception(); + } + if (sm_cur_recorder) { + return std::make_unique(this, flags); + } else { + return std::make_unique(this, flags); + } + } + + void sync() override { + if (sm_cur_recorder) { + sm_cur_recorder->on_sync(this); + } else if (m_worker_queue) { + m_worker_queue->wait_all_task_finish(); + } + if (m_thread_pool) { + m_thread_pool->deactive(); + } + } + + std::unique_ptr create_seq_recorder( + cg::ComputingGraph*) override { + return std::make_unique(&sm_cur_recorder, + m_thread_pool.get(), this); + } + + SeqRecorderImpl* cur_recorder() const override { return sm_cur_recorder; } + + void add_callback(Task&& task) override { + if (!check_global_finalized("add_callback()")) { + CompNodeBaseImpl::add_callback(std::move(task)); + } else { + task(); + } + } +}; +MGB_DYN_TYPE_OBJ_FINAL_IMPL(CompNodeRecorderImpl); +#if !defined(IOS) && MGB_HAVE_THREAD +thread_local CpuCompNode::SeqRecorderImpl* + CompNodeRecorderImpl::sm_cur_recorder = nullptr; +#endif /* ======================== CpuCompNode ======================== */ struct CpuCompNode::Pool { static constexpr int MAX_NR_COMP_NODE = 1024; - struct CpuCompNodeImplDeleter { - void operator ()(CpuCompNodeImpl *p) { - p->~CpuCompNodeImpl(); - } + struct CompNodeRecorderImplDeleter { + void operator()(CompNodeRecorderImpl* p) { p->~CompNodeRecorderImpl(); } }; std::recursive_mutex mtx; // use global memory pool to ensuare object memory accessible even after // global finalize - std::aligned_storage_t - impl_storage[MAX_NR_COMP_NODE]; + std::aligned_storage_t + impl_storage[MAX_NR_COMP_NODE]; size_t nr_used_impl_storage = 0; - std::unordered_map, - CompNode::LocatorPairHashKey::Hash> locator2impl; + std::unordered_map< + CompNode::LocatorPairHashKey, + std::unique_ptr, + CompNode::LocatorPairHashKey::Hash> + locator2impl; ThinHashMap, std::weak_ptr> physical2queue; - std::unordered_map, - CompNode::LocatorPairHashKey::Hash> locator2impl_multi_thread; + std::unordered_map< + CompNode::LocatorPairHashKey, + std::unique_ptr, + CompNode::LocatorPairHashKey::Hash> + locator2impl_multi_thread; ThinHashMap, std::weak_ptr> physical2queue_multithead; }; CpuCompNode::Pool* CpuCompNode::sm_pool; Spinlock CpuCompNode::sm_pool_mtx; -void CpuCompNode::foreach(thin_function callback) { +void CpuCompNode::foreach (thin_function callback) { if (!sm_pool) return; - for (size_t i = 0; ; ++ i) { + for (size_t i = 0;; ++i) { CompNode cur; { MGB_LOCK_GUARD(sm_pool->mtx); if (i >= sm_pool->nr_used_impl_storage) return; cur = make_comp_node_from_impl( - reinterpret_cast( - &sm_pool->impl_storage[i])); + reinterpret_cast( + &sm_pool->impl_storage[i])); } callback(cur); } @@ -781,7 +903,7 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator, // use static storage so object can be safely accessed even after // global finalize static std::aligned_storage_t storage; - sm_pool = new(&storage) Pool; + sm_pool = new (&storage) Pool; } } mgb_assert(locator.device >= 0 || @@ -800,23 +922,22 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator, locator_logical.type == CompNode::DeviceType::MULTITHREAD); } if (locator.type == DeviceType::CPU) { - auto &&pqueue_weak = - sm_pool->physical2queue[{locator.device, locator.stream}]; + auto&& pqueue_weak = + sm_pool->physical2queue[{locator.device, locator.stream}]; auto pqueue = pqueue_weak.lock(); if (!pqueue) { pqueue = std::make_shared(locator); pqueue_weak = pqueue; } - auto&& pimpl = sm_pool->locator2impl[{locator, - locator_logical}]; + auto&& pimpl = sm_pool->locator2impl[{locator, locator_logical}]; if (!pimpl) { mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE, "too many cpu comp nodes; max %d allowed", Pool::MAX_NR_COMP_NODE); pimpl.reset(new ( &sm_pool->impl_storage[sm_pool->nr_used_impl_storage++]) - CpuCompNodeImpl{locator, locator_logical, - pqueue}); + CompNodeRecorderImpl{locator, locator_logical, + pqueue}); } log_comp_node_created(locator, locator_logical); return pimpl.get(); @@ -829,16 +950,16 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator, pqueue = std::make_shared(locator); pqueue_weak = pqueue; } - auto&& pimpl = sm_pool->locator2impl_multi_thread[{ - locator, locator_logical}]; + auto&& pimpl = + sm_pool->locator2impl_multi_thread[{locator, locator_logical}]; if (!pimpl) { mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE, "too many cpu multithread comp nodes; max %d allowed", Pool::MAX_NR_COMP_NODE); pimpl.reset(new ( &sm_pool->impl_storage[sm_pool->nr_used_impl_storage++]) - CpuCompNodeImpl{locator, locator_logical, - pqueue}); + CompNodeRecorderImpl{locator, locator_logical, + pqueue}); } log_comp_node_created(locator, locator_logical); return pimpl.get(); @@ -850,25 +971,12 @@ void CpuCompNode::sync_all() { return; MGB_LOCK_GUARD(sm_pool->mtx); - for (auto &&i: sm_pool->locator2impl) + for (auto&& i : sm_pool->locator2impl) i.second->sync(); for (auto&& i : sm_pool->locator2impl_multi_thread) i.second->sync(); } -bool CpuCompNode::CompNodeImpl::check_global_finalized(const char* reason) { - MGB_MARK_USED_VAR(reason); - if (this != sm_default_cpu_comp_node_ptr && !sm_pool) { - static std::atomic_flag warn_printed = ATOMIC_FLAG_INIT; - if (!warn_printed.test_and_set()) { - mgb_log_debug("cpu comp node method called after global finalize: " - "reason=%s", reason); - } - return true; - } - return false; -} - /* ======================== CompNode methods ======================== */ // CompNode get by default_cpu() is different from the CompNode which is // produced by CompNode::load("cpu:default") @@ -878,9 +986,7 @@ bool CpuCompNode::CompNodeImpl::check_global_finalized(const char* reason) { // CpuCompNode::Pool CompNode CompNode::default_cpu() { static Locator locator{DeviceType::CPU, Locator::DEVICE_CPU_DEFAULT, {-1}}; - static auto empty_queue = - std::make_shared(locator); - static CpuCompNodeImpl impl{locator, locator, empty_queue}; + static CompNodeNoRecorderImpl impl{locator, locator}; return &impl; } @@ -890,22 +996,20 @@ bool CompNode::enable_affinity_for_cpu(bool flag) { return old; } - /* ======================== EventImpl ======================== */ - double CpuCompNode::CpuDispatchableBase::EventImpl::do_elapsed_time_until( - EventImplHelper &end) { - auto &&f1 = static_cast(end).m_prev_finish_time; + EventImplHelper& end) { + auto&& f1 = static_cast(end).m_prev_finish_time; return m_prev_finish_time.time_until_secs(f1); } #if MGB_HAVE_THREAD void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by( - Impl *cn_impl) { + Impl* cn_impl) { { auto locator = m_comp_node_impl->locator(); if (locator.device == Locator::DEVICE_CPU_DEFAULT && - !static_cast(m_comp_node_impl) + !static_cast(m_comp_node_impl) ->cur_recorder()) { auto v0 = m_record_nr_req.load(std::memory_order_relaxed), v1 = m_record_nr_finish.load(std::memory_order_relaxed); @@ -934,14 +1038,14 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by( mgb_throw(MegBrainError, "Atlas comp_node used but MGB_ATLAS not enabled"); #endif - } else if (cn_impl->env().property().type == CompNode::DeviceType::CAMBRICON) { + } else if (cn_impl->env().property().type == + CompNode::DeviceType::CAMBRICON) { #if MGB_CAMBRICON return m_comp_node_impl->sync(); #else mgb_throw(MegBrainError, "Cambricon comp_node used but MGB_CAMBRICON not enabled"); #endif - } auto version = m_record_nr_req.load(std::memory_order_relaxed); @@ -991,14 +1095,15 @@ bool CpuCompNode::CpuDispatchableBase::EventImpl::do_finished() { } void CpuCompNode::CpuDispatchableBase::EventImpl::host_wait_cv() { - for (size_t i = 0, it = SCQueueSynchronizer::get_default_max_spin() / 20; i < it; ++i) { + for (size_t i = 0, it = SCQueueSynchronizer::get_default_max_spin() / 20; + i < it; ++i) { if (finished()) { return; } } m_dev_wait_nr_waiter.fetch_add(1, std::memory_order_release); - for (; ; ) { + for (;;) { std::unique_lock lock{m_dev_wait_mtx}; if (finished()) { break; @@ -1011,23 +1116,23 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::host_wait_cv() { CpuCompNode::CpuDispatchableBase::EventImpl::~EventImpl() noexcept { auto check_all_finished = [this]() { return do_finished() && - !m_dev_wait_nr_waiter.load(std::memory_order_acquire); + !m_dev_wait_nr_waiter.load(std::memory_order_acquire); }; if (!check_all_finished()) { - mgb_log_debug("event %p has unfinished callbacks when destructed; " - "waiting ...", this); + mgb_log_debug( + "event %p has unfinished callbacks when destructed; " + "waiting ...", + this); while (!check_all_finished()) { std::this_thread::yield(); } } } -#else // MGB_HAVE_THREAD +#else // MGB_HAVE_THREAD -void CpuCompNode::CpuDispatchableBase::EventImpl::host_wait_cv() { -} +void CpuCompNode::CpuDispatchableBase::EventImpl::host_wait_cv() {} -void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by(Impl*) { -} +void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by(Impl*) {} void CpuCompNode::CpuDispatchableBase::EventImpl::do_record() { if (m_create_flags & Flags::NEED_TIMER) { @@ -1035,8 +1140,7 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_record() { } } -void CpuCompNode::CpuDispatchableBase::EventImpl::on_finish() { -} +void CpuCompNode::CpuDispatchableBase::EventImpl::on_finish() {} bool CpuCompNode::CpuDispatchableBase::EventImpl::do_finished() { return true; @@ -1046,5 +1150,4 @@ CpuCompNode::CpuDispatchableBase::EventImpl::~EventImpl() noexcept = default; #endif // MGB_HAVE_THREAD - // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/core/impl/comp_node/cpu/comp_node.h b/src/core/impl/comp_node/cpu/comp_node.h index 57b60a0b4eec685e4ca7d198964cdbb07318fdea..0db224c9bebfe78a720eae173ea6c789b673fd5e 100644 --- a/src/core/impl/comp_node/cpu/comp_node.h +++ b/src/core/impl/comp_node/cpu/comp_node.h @@ -54,7 +54,9 @@ namespace mgb { void add_callback(Task&& task) override; }; - class CompNodeImpl; + class CompNodeBaseImpl; + class CompNodeNoRecorderImpl; + class CompNodeRecorderImpl; static void foreach(thin_function callback); static void finalize(); diff --git a/src/core/test/comp_node_helper.cpp b/src/core/test/comp_node_helper.cpp index 997dfdd3d78e60bf7a56b36a1f276e78530df90a..66403d2e911cf1749ef2ebdc6aaa5f9113b54732 100644 --- a/src/core/test/comp_node_helper.cpp +++ b/src/core/test/comp_node_helper.cpp @@ -100,6 +100,26 @@ void run_comp_seq_rec_basic_level2(CompNode cn) { MGB_ASSERT_TENSOR_NEAR(expect, host_z, 1e-3) << "iter " << iter; } ASSERT_EQ(executed.size(), 2u); + + //! test default_cpu with record2 + { + HostTensorND hz; + graph = ComputingGraph::make(); + x = opr::Host2DeviceCopy::make(*graph, host_x); + y = opr::Host2DeviceCopy::make(*graph, host_y); + z = opr::ConvBias::make(x, y, param); + z = opr::GetVarShape::make(z); + graph->options().comp_node_seq_record_level = 2; + graph->options().var_sanity_check_first_run = false; + auto func = graph->compile({make_callback_copy(z, hz, true)}); + ComputingGraph::assert_destroy(graph); + func->execute(); + ASSERT_TRUE(hz.comp_node() == cn); + ASSERT_EQ(hz.ptr()[0], 3); + ASSERT_EQ(hz.ptr()[1], 6); + ASSERT_EQ(hz.ptr()[2], 8); + ASSERT_EQ(hz.ptr()[3], 6); + } } void run_comp_seq_rec_dyn_elemwise(CompNode cn, bool fake_first) {