api_cache.h 13.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/**
 * \file dnn/src/common/api_cache.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#pragma once

#include <atomic>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <mutex>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>

#include "megdnn/thin/function.h"

#include "./utils.h"

26
namespace megdnn {
27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84

// Read-write spinlock kept in a single atomic word: the top bit is the writer
// flag, the low 31 bits count the readers currently holding the lock.
// Design reference:
// https://jfdube.wordpress.com/2014/01/03/implementing-a-recursive-read-write-spinlock/
class RWSpin {
public:
    // Handle with the lock()/unlock() pair that scoped guards call. Which
    // RWSpin operations it forwards to is fixed at construction; use
    // RWSpin::reader() or RWSpin::writer() to obtain one.
    class Lock {
    private:
        using action_t = void (RWSpin::*)();

        RWSpin* m_owner;
        action_t m_acquire;
        action_t m_release;

    public:
        Lock(RWSpin* spin, action_t lock, action_t unlock)
                : m_owner{spin}, m_acquire{lock}, m_release{unlock} {}
        void lock() { (m_owner->*m_acquire)(); }
        void unlock() { (m_owner->*m_release)(); }
    };

private:
    static constexpr uint32_t sm_reader_mask = 0x7FFFFFFF;
    static constexpr uint32_t sm_writer_mask = 0x80000000;

    std::atomic<uint32_t> m_atomic{0};

    // Spin until the reader count could be bumped by one while the writer bit
    // was clear.
    void _reader_lock() {
        uint32_t seen = m_atomic.load();
        for (;;) {
            // only a state without the writer bit is acceptable as "expected"
            seen &= sm_reader_mask;
            if (m_atomic.compare_exchange_strong(seen, seen + 1)) {
                return;
            }
        }
    }

    // Drop one reader.
    void _reader_unlock() { m_atomic.fetch_sub(1); }

    // Two phases: first claim the writer bit (this keeps new readers and other
    // writers spinning), then wait until the remaining readers have left.
    void _writer_lock() {
        uint32_t seen = m_atomic.load();
        for (;;) {
            seen &= sm_reader_mask;
            if (m_atomic.compare_exchange_strong(seen,
                                                 seen | sm_writer_mask)) {
                break;
            }
        }
        while (m_atomic.load() != sm_writer_mask) {
            // busy wait until the reader count reaches zero
        }
    }

    // Release the writer bit. Expects m_atomic == sm_writer_mask on entry.
    void _writer_unlock() { m_atomic.store(0); }

public:
    // Lock handle for shared (read) access.
    Lock reader() {
        return Lock{this, &RWSpin::_reader_lock, &RWSpin::_reader_unlock};
    }
    // Lock handle for exclusive (write) access.
    Lock writer() {
        return Lock{this, &RWSpin::_writer_lock, &RWSpin::_writer_unlock};
    }
};

template <typename TSignature>
class FunctionCache;

template <typename TRet, typename... TArgs>
class FunctionCache<TRet(TArgs...)> {
85 86
public:
    using key_t = std::string;
87
    using value_t = TRet;
88 89 90
    using key_mapper_t = thin_function<key_t(TArgs...)>;
    using value_mapper_t = thin_function<value_t(TArgs...)>;
    using storage_t = std::unordered_map<key_t, value_t>;
91

92 93 94
    storage_t storage;
    key_mapper_t key_mapper;
    value_mapper_t value_mapper;
95

96 97 98 99
    RWSpin spin;

public:
    TRet operator()(TArgs... args) {
100
        key_t key = key_mapper(args...);
101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
        auto reader_lock = spin.reader();
        auto writer_lock = spin.writer();
        {
            MEGDNN_LOCK_GUARD(reader_lock);
            auto iter = storage.find(key);
            if (iter != storage.end()) {
                return iter->second;
            }
        }
        // RWSpin doesn't support upgrade
        {
            MEGDNN_LOCK_GUARD(writer_lock);
            if (storage.count(key) != 0) {
                return storage[key];
            }
            value_t ret = value_mapper(std::forward<TArgs>(args)...);
            storage[key] = ret;
            return ret;
119 120 121 122 123 124 125 126 127
        }
    }
};

// FIFO byte buffer: write_plain() appends raw object bytes at the back,
// read_plain() copies them out again starting at the front and moving a read
// cursor forward.
//
// No bounds checking is done on reads: the caller must read back exactly the
// types that were written, in the same order.
class StringSerializer {
private:
    std::string m_buffer;
    size_t m_cursor = 0;

public:
    // Read the next sizeof(T) bytes as a T and advance the cursor.
    template <typename T>
    T read_plain() {
        static_assert(std::is_trivially_copyable<T>::value, "invalid type");
        T ret;
        read_plain(&ret);
        return ret;
    }

    // Same as above, but copy the bytes straight into *dest.
    template <typename T>
    void read_plain(T* dest) {
        static_assert(std::is_trivially_copyable<T>::value, "invalid type");
        std::memcpy(dest, m_buffer.data() + m_cursor, sizeof(T));
        m_cursor += sizeof(T);
    }

    // Append the object representation of value to the buffer.
    template <typename T>
    void write_plain(const T& value) {
        static_assert(std::is_trivially_copyable<T>::value,
                      "type should be trivially copyable");
        m_buffer.append(reinterpret_cast<const char*>(&value), sizeof(T));
    }

    // Hand the accumulated bytes over to the caller. The serializer is left
    // empty with the cursor rewound, so it can be reused right away instead of
    // holding a moved-from string and a stale cursor.
    std::string take() {
        m_cursor = 0;
        return std::exchange(m_buffer, std::string{});
    }

    // Replace the buffer content and restart reading from its first byte.
    void reset(std::string new_buf) {
        m_cursor = 0;
        m_buffer = std::move(new_buf);
    }
};

// Empty tag type: passed as the "previous result" argument and returned from
// serialize()/deserialize() when a parameter has nothing to hand over to the
// next one.
struct Empty {};

159 160 161
// Shift every element of an index sequence by N.
// in:  seq[s1, s2, ..., sm]
// out: seq[N+s1, N+s2, ..., N+sm]
// Only the types matter; the argument and the returned object carry no state.
template <std::size_t N, std::size_t... Seq>
inline auto inc_index_sequence(std::index_sequence<Seq...>)
        -> std::index_sequence<(N + Seq)...> {
    return {};
}

167 168 169
template <typename... TParams>
class ParamBundle {
private:
170
    // out: Min, Min+1, ..., Max
171
    template <std::size_t Min, std::size_t Max>
172 173
    using make_index_range = decltype(
            inc_index_sequence<Min>(std::make_index_sequence<Max - Min>()));
174

175
    // store params in a tuple
176 177 178
    using storage_t = std::tuple<typename std::remove_reference_t<TParams>...>;
    storage_t m_storage;

179
    // deconstruct tuple and call functor
180
    template <typename TFunctor, size_t... Indices>
181
    auto call_helper(TFunctor&& functor, std::index_sequence<Indices...>) {
182 183
        return functor(std::get<Indices>(m_storage).value...);
    }
184

185
    template <size_t Index, size_t... Indices, typename TPrev>
186 187 188 189 190
    auto serialize_helper(StringSerializer& ser, TPrev&& prev,
                          std::index_sequence<Index, Indices...>) {
        return serialize_helper(ser,
                                std::get<Index>(m_storage).serialize(ser, prev),
                                std::index_sequence<Indices...>());
191
    }
192

193
    template <typename TPrev>
194 195
    auto serialize_helper(StringSerializer& ser, TPrev&& prev,
                          std::index_sequence<>) {}
196

197
    template <size_t Index, size_t... Indices, typename TPrev>
198 199 200 201 202
    auto deserialize_helper(StringSerializer& ser, TPrev&& prev,
                            std::index_sequence<Index, Indices...>) {
        return deserialize_helper(
                ser, std::get<Index>(m_storage).deserialize(ser, prev),
                std::index_sequence<Indices...>());
203
    }
204

205
    template <typename TPrev>
206 207
    auto deserialize_helper(StringSerializer& ser, TPrev&& prev,
                            std::index_sequence<>) {}
208

209
    template <size_t Index, size_t... Indices, typename TArg, typename... TArgs>
210 211
    void set_values_helper(std::index_sequence<Index, Indices...>, TArg&& arg,
                           TArgs&&... args) {
212
        std::get<Index>(m_storage).value = std::forward<TArg>(arg);
213 214
        set_values_helper(std::index_sequence<Indices...>(),
                          std::forward<TArgs>(args)...);
215
    }
216

217 218 219 220 221 222 223 224
    template <size_t... Indices>
    void set_values_helper(std::index_sequence<Indices...>) {
        static_assert(sizeof...(Indices) == 0, "redundant indices");
    }

public:
    template <typename TFunctor>
    auto call_by(TFunctor&& functor) {
225 226
        return call_helper(std::forward<TFunctor>(functor),
                           std::make_index_sequence<sizeof...(TParams)>());
227
    }
228 229

    // recursively store params into ser
230 231 232
    template <size_t NBegin, size_t NEnd>
    void serialize_params(StringSerializer& ser) {
        static_assert(NEnd >= NBegin, "invalid range");
233
        serialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>());
234
    }
235 236

    // recursively load params from ser
237 238 239
    template <size_t NBegin, size_t NEnd>
    void deserialize_params(StringSerializer& ser) {
        static_assert(NEnd >= NBegin, "invalid range");
240
        deserialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>());
241
    }
242 243

    // recursively set params into m_storage
244 245
    template <size_t NBegin, size_t NEnd, typename... TArgs>
    void set_values(TArgs&&... args) {
246 247
        set_values_helper(make_index_range<NBegin, NEnd>(),
                          std::forward<TArgs>(args)...);
248 249 250 251
    }
};

template <typename T>
252
class Param {
253 254
public:
    T value;
255

256 257 258 259
    Empty serialize(StringSerializer& ser, Empty) {
        ser.write_plain(value);
        return Empty{};
    }
260

261
    Empty deserialize(StringSerializer& ser, Empty) {
262
        ser.read_plain(&value);
263 264 265 266
        return Empty{};
    }
};

267 268
template <typename TRet = Param<Empty>, typename TInputs = std::tuple<>,
          typename TOutputs = std::tuple<>>
269 270
class FunctionCacheBuilder {
private:
271
    // decl value with type of tuple-of-args
272 273 274 275 276
    static auto declargs()
            -> decltype(std::tuple_cat(std::declval<TInputs>(),
                                       std::declval<TOutputs>())) {
        return {};
    }
277

278
    template <size_t... Indices>
279 280 281 282 283
    static auto declfunction_helper(std::index_sequence<Indices...>)
            -> thin_function<decltype(std::declval<TRet>().value)(
                    decltype(std::get<Indices>(declargs()).value)...)> {
        return {};
    }
284 285

    // decl value with type of original function
286
    static auto declfunction() {
287 288 289
        return declfunction_helper(
                std::make_index_sequence<std::tuple_size<TInputs>::value +
                                         std::tuple_size<TOutputs>::value>());
290
    }
291

292
    template <size_t... Indices>
293
    static auto declbundle_helper(std::index_sequence<Indices...>)
294 295
            -> ParamBundle<std::remove_reference_t<
                    decltype(std::get<Indices>(declargs()))>...> {
296 297
        return {};
    }
298 299

    // decl value with type of bundle-of-args
300
    static auto declbundle() {
301 302 303
        return declbundle_helper(
                std::make_index_sequence<std::tuple_size<TInputs>::value +
                                         std::tuple_size<TOutputs>::value>());
304
    }
305 306

    // type of original function
307
    using function_t = decltype(declfunction());
308
    // type of bundle-of-args
309
    using bundle_t = decltype(declbundle());
310

311
public:
312
    // declare new return type, cannot be override
313 314
    template <typename TNewRet>
    auto ret() {
315 316
        static_assert(std::is_same<TRet, Param<Empty>>::value,
                      "return value redefinition");
317 318
        return FunctionCacheBuilder<TNewRet, TInputs, TOutputs>{};
    }
319
    // declare new input
320 321
    template <typename TNewInput>
    auto input() {
322 323 324 325 326
        static_assert(std::tuple_size<TOutputs>::value == 0,
                      "input arg cannot be declared after output");
        using TNewInputs =
                decltype(std::tuple_cat(std::declval<TInputs>(),
                                        std::declval<std::tuple<TNewInput>>()));
327 328
        return FunctionCacheBuilder<TRet, TNewInputs, TOutputs>{};
    }
329
    // declare new output
330 331
    template <typename TNewOutput>
    auto output() {
332 333
        using TNewOutputs = decltype(
                std::tuple_cat(std::declval<TOutputs>(),
334
                               std::declval<std::tuple<TNewOutput>>()));
335 336
        return FunctionCacheBuilder<TRet, TInputs, TNewOutputs>{};
    }
337
    // summary
338
    template <typename TFunctor>
339 340 341
    function_t build(TFunctor&& func) {
        constexpr size_t n_inputs = std::tuple_size<TInputs>::value;
        constexpr size_t n_outputs = std::tuple_size<TOutputs>::value;
342 343 344
        auto cache = std::make_shared<FunctionCache<std::string(bundle_t)>>();
        // bundle -> ser(in args)
        cache->key_mapper = [](bundle_t bundle) {
345
            StringSerializer ser;
346
            bundle.template serialize_params<0, n_inputs>(ser);
347 348
            return ser.take();
        };
349
        // bundle -> ser(out args)
350
        cache->value_mapper = [func](bundle_t bundle) {
351 352 353 354
            StringSerializer ser;
            TRet ret;
            ret.value = bundle.call_by(func);
            ret.serialize(ser, Empty{});
355 356
            bundle.template serialize_params<n_inputs, n_inputs + n_outputs>(
                    ser);
357 358 359 360 361 362
            return ser.take();
        };
        return [=](auto&&... args) mutable {
            bundle_t bundle;
            TRet ret;
            StringSerializer ser;
363 364 365 366 367 368
            static_assert(
                    sizeof...(args) == std::tuple_size<TInputs>::value +
                                               std::tuple_size<TOutputs>::value,
                    "args count mismatch");
            bundle.template set_values<0, sizeof...(args)>(
                    std::forward<decltype(args)>(args)...);
369
            ser.reset((*cache)(bundle));
370
            ret.deserialize(ser, Empty{});
371 372
            bundle.template deserialize_params<n_inputs, n_inputs + n_outputs>(
                    ser);
373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391
            return ret.value;
        };
    }
};

// Parameter wrapper for a T passed by pointer: the pointee is what gets
// written to / restored from the serializer, not the pointer itself.
template <typename T>
class RefParam {
public:
    T* value;

    // Append *value to ser.
    Empty serialize(StringSerializer& ser, Empty) {
        const T& pointee = *value;
        ser.write_plain(pointee);
        return {};
    }

    // Read one T from ser and assign it to *value.
    Empty deserialize(StringSerializer& ser, Empty) {
        T& target = *value;
        target = ser.read_plain<T>();
        return {};
    }
};

392
// like RefParam but return *value while ser and deser. Working with ArrayParam
393 394 395 396 397 398 399 400 401
template <typename T>
class RefArraySizeParam {
public:
    T* value;
    T serialize(StringSerializer& ser, Empty) {
        ser.write_plain(*value);
        return *value;
    }
    T deserialize(StringSerializer& ser, Empty) {
402 403
        ser.read_plain(value);
        return *value;
404 405 406
    }
};

407
// accept array length from previous param. Working with RefArraySizeParam
408 409 410
template <typename TSize, typename TItem>
class ArrayParam {
public:
411
    decltype(std::declval<TItem>().value)* value;
412
    Empty serialize(StringSerializer& ser, TSize size) {
413
        TItem param;
414
        for (TSize i = 0; i < size; ++i) {
415 416
            param.value = value[i];
            param.serialize(ser, Empty{});
417 418 419 420
        }
        return Empty{};
    }
    Empty deserialize(StringSerializer& ser, TSize size) {
421
        TItem param;
422
        for (TSize i = 0; i < size; ++i) {
423 424
            param.deserialize(ser, Empty{});
            value[i] = param.value;
425 426 427 428 429
        }
        return Empty{};
    }
};

430
}  // namespace megdnn