/** * \file dnn/src/common/api_cache.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ #pragma once #include #include #include #include #include #include #include "megdnn/thin/function.h" #include "./utils.h" namespace megdnn { // https://jfdube.wordpress.com/2014/01/03/implementing-a-recursive-read-write-spinlock/ class RWSpin { public: class Lock { private: RWSpin* m_spin; void (RWSpin::*m_lock)(void); void (RWSpin::*m_unlock)(void); public: Lock(RWSpin* spin, decltype(m_lock) lock, decltype(m_unlock) unlock) : m_spin{spin}, m_lock{lock}, m_unlock{unlock} {} void lock() { (m_spin->*m_lock)(); } void unlock() { (m_spin->*m_unlock)(); } }; private: std::atomic m_atomic{0}; static constexpr uint32_t sm_reader_mask = 0x7FFFFFFF; static constexpr uint32_t sm_writer_mask = 0x80000000; void _reader_lock() { uint32_t expected = m_atomic; do { expected &= sm_reader_mask; } while (!m_atomic.compare_exchange_strong(expected, expected + 1)); } void _reader_unlock() { m_atomic--; } void _writer_lock() { uint32_t expected = m_atomic; do { expected &= sm_reader_mask; } while (!m_atomic.compare_exchange_strong(expected, expected | sm_writer_mask)); while (m_atomic.load() != sm_writer_mask) ; } void _writer_unlock() { // assert m_atomic == sm_writer_mask m_atomic = 0; } public: Lock reader() { return {this, &RWSpin::_reader_lock, &RWSpin::_reader_unlock}; } Lock writer() { return {this, &RWSpin::_writer_lock, &RWSpin::_writer_unlock}; } }; template class FunctionCache; template class FunctionCache { public: using key_t = std::string; using value_t = TRet; using key_mapper_t = thin_function; using value_mapper_t = thin_function; using storage_t = std::unordered_map; storage_t storage; key_mapper_t key_mapper; value_mapper_t value_mapper; RWSpin spin; public: TRet operator()(TArgs... args) { key_t key = key_mapper(args...); auto reader_lock = spin.reader(); auto writer_lock = spin.writer(); { MEGDNN_LOCK_GUARD(reader_lock); auto iter = storage.find(key); if (iter != storage.end()) { return iter->second; } } // RWSpin doesn't support upgrade { MEGDNN_LOCK_GUARD(writer_lock); if (storage.count(key) != 0) { return storage[key]; } value_t ret = value_mapper(std::forward(args)...); storage[key] = ret; return ret; } } }; // FIFO class StringSerializer { private: std::string m_buffer; size_t m_cursor = 0; public: template T read_plain() { static_assert(std::is_trivially_copyable::value, "invalid type"); T ret; std::memcpy(&ret, m_buffer.data() + m_cursor, sizeof(T)); m_cursor += sizeof(T); return ret; } template void read_plain(T* dest) { static_assert(std::is_trivially_copyable::value, "invalid type"); std::memcpy(dest, m_buffer.data() + m_cursor, sizeof(T)); m_cursor += sizeof(T); } template void write_plain(const T& value) { static_assert(std::is_trivially_copyable::value, "type should be trivially copyable"); m_buffer.append(reinterpret_cast(&value), sizeof(T)); } std::string take() { return std::move(m_buffer); } void reset(std::string new_buf) { m_cursor = 0; m_buffer = std::move(new_buf); } }; struct Empty {}; // in: seq[1, 2, ..., m] // out: seq[N+1, N+2, ... N+m] template inline std::index_sequence inc_index_sequence( std::index_sequence) { return {}; } template class ParamBundle { private: // out: Min, Min+1, ..., Max template using make_index_range = decltype( inc_index_sequence(std::make_index_sequence())); // store params in a tuple using storage_t = std::tuple...>; storage_t m_storage; // deconstruct tuple and call functor template auto call_helper(TFunctor&& functor, std::index_sequence) { return functor(std::get(m_storage).value...); } template auto serialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence) { return serialize_helper(ser, std::get(m_storage).serialize(ser, prev), std::index_sequence()); } template auto serialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {} template auto deserialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence) { return deserialize_helper( ser, std::get(m_storage).deserialize(ser, prev), std::index_sequence()); } template auto deserialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {} template void set_values_helper(std::index_sequence, TArg&& arg, TArgs&&... args) { std::get(m_storage).value = std::forward(arg); set_values_helper(std::index_sequence(), std::forward(args)...); } template void set_values_helper(std::index_sequence) { static_assert(sizeof...(Indices) == 0, "redundant indices"); } public: template auto call_by(TFunctor&& functor) { return call_helper(std::forward(functor), std::make_index_sequence()); } // recursively store params into ser template void serialize_params(StringSerializer& ser) { static_assert(NEnd >= NBegin, "invalid range"); serialize_helper(ser, Empty{}, make_index_range()); } // recursively load params from ser template void deserialize_params(StringSerializer& ser) { static_assert(NEnd >= NBegin, "invalid range"); deserialize_helper(ser, Empty{}, make_index_range()); } // recursively set params into m_storage template void set_values(TArgs&&... args) { set_values_helper(make_index_range(), std::forward(args)...); } }; template class Param { public: T value; Empty serialize(StringSerializer& ser, Empty) { ser.write_plain(value); return Empty{}; } Empty deserialize(StringSerializer& ser, Empty) { ser.read_plain(&value); return Empty{}; } }; template , typename TInputs = std::tuple<>, typename TOutputs = std::tuple<>> class FunctionCacheBuilder { private: // decl value with type of tuple-of-args static auto declargs() -> decltype(std::tuple_cat(std::declval(), std::declval())) { return {}; } template static auto declfunction_helper(std::index_sequence) -> thin_function().value)( decltype(std::get(declargs()).value)...)> { return {}; } // decl value with type of original function static auto declfunction() { return declfunction_helper( std::make_index_sequence::value + std::tuple_size::value>()); } template static auto declbundle_helper(std::index_sequence) -> ParamBundle(declargs()))>...> { return {}; } // decl value with type of bundle-of-args static auto declbundle() { return declbundle_helper( std::make_index_sequence::value + std::tuple_size::value>()); } // type of original function using function_t = decltype(declfunction()); // type of bundle-of-args using bundle_t = decltype(declbundle()); public: // declare new return type, cannot be override template auto ret() { static_assert(std::is_same>::value, "return value redefinition"); return FunctionCacheBuilder{}; } // declare new input template auto input() { static_assert(std::tuple_size::value == 0, "input arg cannot be declared after output"); using TNewInputs = decltype(std::tuple_cat(std::declval(), std::declval>())); return FunctionCacheBuilder{}; } // declare new output template auto output() { using TNewOutputs = decltype( std::tuple_cat(std::declval(), std::declval>())); return FunctionCacheBuilder{}; } // summary template function_t build(TFunctor&& func) { constexpr size_t n_inputs = std::tuple_size::value; constexpr size_t n_outputs = std::tuple_size::value; auto cache = std::make_shared>(); // bundle -> ser(in args) cache->key_mapper = [](bundle_t bundle) { StringSerializer ser; bundle.template serialize_params<0, n_inputs>(ser); return ser.take(); }; // bundle -> ser(out args) cache->value_mapper = [func](bundle_t bundle) { StringSerializer ser; TRet ret; ret.value = bundle.call_by(func); ret.serialize(ser, Empty{}); bundle.template serialize_params( ser); return ser.take(); }; return [=](auto&&... args) mutable { bundle_t bundle; TRet ret; StringSerializer ser; static_assert( sizeof...(args) == std::tuple_size::value + std::tuple_size::value, "args count mismatch"); bundle.template set_values<0, sizeof...(args)>( std::forward(args)...); ser.reset((*cache)(bundle)); ret.deserialize(ser, Empty{}); bundle.template deserialize_params( ser); return ret.value; }; } }; template class RefParam { public: T* value; Empty serialize(StringSerializer& ser, Empty) { ser.write_plain(*value); return Empty{}; } Empty deserialize(StringSerializer& ser, Empty) { *value = ser.read_plain(); return Empty{}; } }; // like RefParam but return *value while ser and deser. Working with ArrayParam template class RefArraySizeParam { public: T* value; T serialize(StringSerializer& ser, Empty) { ser.write_plain(*value); return *value; } T deserialize(StringSerializer& ser, Empty) { ser.read_plain(value); return *value; } }; // accept array length from previous param. Working with RefArraySizeParam template class ArrayParam { public: decltype(std::declval().value)* value; Empty serialize(StringSerializer& ser, TSize size) { TItem param; for (TSize i = 0; i < size; ++i) { param.value = value[i]; param.serialize(ser, Empty{}); } return Empty{}; } Empty deserialize(StringSerializer& ser, TSize size) { TItem param; for (TSize i = 0; i < size; ++i) { param.deserialize(ser, Empty{}); value[i] = param.value; } return Empty{}; } }; } // namespace megdnn