/**
 * \file dnn/src/cuda/resize/common.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

namespace megdnn {
namespace cuda {
namespace resize {

/**
 * \brief Map an output index to a fractional position along one source axis.
 *
 * The source position is computed as (idx + 0.5) / scale - 0.5 and split into
 * an integer part (\p origin_idx, the floor) and a fractional part (\p alpha).
 *
 * When \p cubic is false the pair is clamped:
 *   - below the first element: origin_idx = 0, alpha = 0;
 *   - when origin_idx + 1 would reach or pass \p size:
 *     origin_idx = size - 2, alpha = 1.
 * When \p cubic is true no clamping is applied, so \p origin_idx may be
 * negative or >= size - 1 on return.
 *
 * NOTE(review): presumably the caller blends the samples at origin_idx and
 * origin_idx + 1 with weights (1 - alpha) and alpha -- confirm against the
 * kernels that use this helper.
 * NOTE(review): for size < 2 the upper clamp yields origin_idx = size - 2,
 * which is negative; callers presumably guarantee size >= 2 or special-case
 * it -- confirm.
 *
 * \param scale      divisor applied to the output index (presumably the
 *                   destination/source size ratio -- confirm with callers)
 * \param size       number of elements along the source axis
 * \param idx        index along the output axis
 * \param alpha      [out] fractional offset from \p origin_idx
 * \param origin_idx [out] integer source index
 * \param cubic      skip the clamping step when true
 */
__device__ inline void get_origin_coord(float scale, int size, int idx,
                                        float& alpha, int& origin_idx,
                                        bool cubic = false) {
    // Fractional source position of the output sample.
    alpha = (idx + 0.5f) / scale - 0.5f;
    // floorf (not truncation) so that negative positions round downwards.
    origin_idx = static_cast<int>(floorf(alpha));
    // Keep only the fractional remainder in alpha.
    alpha -= origin_idx;
    if (!cubic) {
        if (origin_idx < 0) {
            // Before the first element: pin to element 0.
            origin_idx = 0;
            alpha = 0;
        } else if (origin_idx + 1 >= size) {
            // origin_idx + 1 would be out of range: step back one element and
            // move the whole fractional weight onto the upper neighbour.
            origin_idx = size - 2;
            alpha = 1;
        }
    }
}

/**
 * \brief Source index for an output index, obtained by truncating division.
 *
 * Computes idx / scale in floating point, converts the quotient to int
 * (discarding the fractional part) and caps the result at size - 1. No lower
 * bound is enforced.
 *
 * \param scale divisor applied to the output index (presumably the
 *              destination/source size ratio -- confirm with callers)
 * \param size  number of elements along the source axis
 * \param idx   index along the output axis
 * \return min(int(idx / scale), size - 1)
 */
__device__ inline int get_nearest_src(float scale, int size, int idx) {
    return min(static_cast<int>(idx / scale), size - 1);
}

}  // namespace resize
}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen