提交 c9d06030 编写于 作者: M Megvii Engine Team

feat(dnn/common): add named tensor shape

GitOrigin-RevId: 918928b8ba673cc1c0fa97aa3bb3d4459203fd8f
上级 ff0e6be7
/**
* \file dnn/include/megdnn/named_tensor.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/internal/defs.h"
#include "megdnn/opr_param_defs.h"
#include "src/common/utils.h"
#include <array>
#include <string>
#include "megdnn/thin/small_vector.h"
#include "megdnn/internal/visibility_prologue.h"
namespace megdnn {
class Dimension {
public:
enum class Name : char {
N = 'N', // Batch size
C = 'C', // input channel
H = 'H', // input height
W = 'W', // input width
K = 'K', // output channel
R = 'R', // filter height
S = 'S', // filter width
P = 'P', // output height
Q = 'Q', // output width
};
static constexpr uint32_t UNDETERMINED_EXTENT =
std::numeric_limits<uint32_t>::max();
static const Name NAME_ALL[];
static const int NR_NAMES;
Dimension() = default;
Dimension(const std::string& expr);
Dimension(Name name, uint32_t stride, uint32_t extent = UNDETERMINED_EXTENT)
: m_name{name}, m_stride{stride}, m_extent{extent} {}
Dimension(const Dimension& rhs) { operator=(rhs); }
Dimension& operator=(const Dimension& rhs);
bool operator==(const Dimension& rhs) const;
bool operator<(const Dimension& rhs) const;
Dimension operator*(const Dimension& rhs) const;
Dimension operator/(const Dimension& rhs) const;
std::string to_string() const;
Name name() const { return m_name; }
uint32_t extent() const { return m_extent; }
uint32_t stride() const { return m_stride; }
private:
Name m_name;
uint32_t m_stride;
uint32_t m_extent;
};
struct NamedTensorShape {
using Format = param::ConvBias::Format;
static constexpr size_t MAX_NDIM = MEGDNN_MAX_NDIM;
std::array<Dimension, MAX_NDIM> dims;
size_t ndim = 0;
NamedTensorShape() = default;
NamedTensorShape(const NamedTensorShape& rhs) = default;
NamedTensorShape(const SmallVector<Dimension>& init_shape);
NamedTensorShape(std::initializer_list<Dimension> init_shape);
std::string to_string() const;
bool eq_shape(const NamedTensorShape& rhs) const;
Dimension& operator[](size_t i) { return dims[i]; }
Dimension operator[](size_t i) const { return dims[i]; }
NamedTensorShape static make_named_tensor_shape(Format format);
};
} // namespace megdnn
#include "megdnn/internal/visibility_epilogue.h"
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/common/named_tensor.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/named_tensor.h"
#include "src/common/utils.h"
using namespace megdnn;
/* ===================== Dimension ============================ */
const Dimension::Name Dimension::NAME_ALL[] = {
Dimension::Name::N, Dimension::Name::C, Dimension::Name::H,
Dimension::Name::W, Dimension::Name::K, Dimension::Name::R,
Dimension::Name::S, Dimension::Name::P, Dimension::Name::Q,
};
const int Dimension::NR_NAMES = sizeof(Dimension::NAME_ALL);
Dimension::Dimension(const std::string& expr) {
auto errmsg = [&]() {
return ssprintf("Invalid dimension(%s)", expr.c_str());
};
const char* data = expr.data();
bool has_stride = false;
bool has_extent = false;
bool init_name = false;
while (*data) {
if (data[0] >= 'A' && data[0] <= 'Z') {
megdnn_throw_if(init_name, megdnn_error, errmsg().c_str());
for (auto e : NAME_ALL) {
if (data[0] == static_cast<char>(e)) {
init_name = true;
m_name = e;
break;
}
}
megdnn_throw_if(!init_name, megdnn_error, errmsg().c_str());
++data;
} else if (data[0] == '/' && data[1] == '/') {
megdnn_throw_if(!init_name || has_stride || has_extent,
megdnn_error, errmsg().c_str());
has_stride = true;
data += 2;
} else if (data[0] == '%') {
megdnn_throw_if(!init_name || has_extent, megdnn_error,
errmsg().c_str());
has_extent = true;
++data;
} else if (data[0] >= '0' && data[0] <= '9') {
megdnn_throw_if(!init_name, megdnn_error, errmsg().c_str());
uint32_t num = 0;
while (data[0] >= '0' && data[0] <= '9') {
num = num * 10 + (data[0] - '0');
++data;
}
if (has_extent)
m_extent = num;
else if (has_stride)
m_stride = num;
} else {
megdnn_throw(errmsg().c_str());
}
}
megdnn_throw_if(!init_name, megdnn_error, errmsg().c_str());
if (!has_extent) {
m_extent = UNDETERMINED_EXTENT;
}
if (!has_stride) {
m_stride = 1;
}
}
Dimension& Dimension::operator=(const Dimension& rhs) {
m_name = rhs.m_name;
m_stride = rhs.m_stride;
m_extent = rhs.m_extent;
return *this;
}
bool Dimension::operator==(const Dimension& rhs) const {
return m_name == rhs.m_name && m_stride == rhs.m_stride &&
m_extent == rhs.m_extent;
}
bool Dimension::operator<(const Dimension& rhs) const {
if (m_name != rhs.m_name) {
return static_cast<char>(m_name) < static_cast<char>(rhs.m_name);
}
return m_stride > rhs.m_stride;
}
Dimension Dimension::operator*(const Dimension& rhs) const {
megdnn_assert(m_name == rhs.m_name,
"Multiply operation cannot be applied on dimensions with "
"different name(lhs:%c, rhs:%c)",
static_cast<char>(m_name), static_cast<char>(rhs.m_name));
megdnn_assert(
m_stride == rhs.m_stride * rhs.m_extent,
"Multiply operation cannot be applied on operands(lhs:%s, rhs:%s)",
to_string().c_str(), rhs.to_string().c_str());
if (m_extent == UNDETERMINED_EXTENT)
return Dimension(m_name, rhs.m_stride);
return Dimension(m_name, rhs.m_stride, m_extent * rhs.m_extent);
}
Dimension Dimension::operator/(const Dimension& rhs) const {
megdnn_assert(m_name == rhs.m_name,
"Divide operation cannot be applied on dimensions with "
"different name(lhs:%c, rhs:%c)",
static_cast<char>(m_name), static_cast<char>(rhs.m_name));
if (operator==(rhs))
return Dimension(m_name, 1, 1);
megdnn_assert(
!(*this < rhs),
"Divisor must be smaller than dividend(dividend:%s, divisor:%s)",
to_string().c_str(), rhs.to_string().c_str());
if (m_stride == rhs.m_stride) {
if (m_extent == UNDETERMINED_EXTENT) {
megdnn_assert(rhs.m_extent != UNDETERMINED_EXTENT,
"Divide operation cannot be applied on "
"operands(dividend:%s, divisor:%s)",
to_string().c_str(), rhs.to_string().c_str());
return Dimension(m_name, rhs.m_extent * m_stride);
} else {
megdnn_assert(m_extent % rhs.m_extent == 0,
"Divide operation cannot be applied on "
"operands(dividend:%s, divisor:%s)",
to_string().c_str(), rhs.to_string().c_str());
return Dimension(m_name, rhs.m_extent * m_stride,
m_extent / rhs.m_extent);
}
} else {
if (m_extent == UNDETERMINED_EXTENT) {
megdnn_assert(rhs.m_extent == UNDETERMINED_EXTENT &&
rhs.m_stride % m_stride == 0,
"Divide operation cannot be applied on "
"operands(dividend:%s, divisor:%s)",
to_string().c_str(), rhs.to_string().c_str());
return Dimension(m_name, m_stride, rhs.m_stride / m_stride);
} else {
megdnn_assert(m_extent * m_stride == rhs.m_extent * rhs.m_stride &&
rhs.m_stride % m_stride == 0,
"Divide operation cannot be applied on "
"operands(dividend:%s, divisor:%s)",
to_string().c_str(), rhs.to_string().c_str());
return Dimension(m_name, m_stride, m_extent / rhs.m_extent);
}
}
}
std::string Dimension::to_string() const {
if (m_extent == UNDETERMINED_EXTENT) {
if (m_stride == 1)
return ssprintf("%c", static_cast<char>(m_name));
else
return ssprintf("%c//%u", static_cast<char>(m_name), m_stride);
} else {
if (m_stride == 1)
return ssprintf("%c%%%u", static_cast<char>(m_name), m_extent);
else
return ssprintf("%c//%u%%%u", static_cast<char>(m_name), m_stride,
m_extent);
}
}
/* ===================== NamedTensorShape ===================== */
NamedTensorShape::NamedTensorShape(const SmallVector<Dimension>& init_shape) {
megdnn_assert(init_shape.size() <= MAX_NDIM,
"Illegal to construct a NamedTensorShape with "
"more than MAX_NDIM(%zu) axes; got(%zu)",
MAX_NDIM, init_shape.size());
ndim = init_shape.size();
memcpy(this->dims.data(), init_shape.data(), sizeof(Dimension) * ndim);
}
NamedTensorShape::NamedTensorShape(std::initializer_list<Dimension> init_shape)
: NamedTensorShape(SmallVector<Dimension>{init_shape}) {}
bool NamedTensorShape::eq_shape(const NamedTensorShape& rhs) const {
MEGDNN_STATIC_ASSERT(MAX_NDIM == 7, "please update the code");
if (ndim == rhs.ndim) {
size_t eq = 0;
switch (ndim) {
case 7:
eq += dims[6] == rhs.dims[6];
MEGDNN_FALLTHRU
case 6:
eq += dims[5] == rhs.dims[5];
MEGDNN_FALLTHRU
case 5:
eq += dims[4] == rhs.dims[4];
MEGDNN_FALLTHRU
case 4:
eq += dims[3] == rhs.dims[3];
MEGDNN_FALLTHRU
case 3:
eq += dims[2] == rhs.dims[2];
MEGDNN_FALLTHRU
case 2:
eq += dims[1] == rhs.dims[1];
MEGDNN_FALLTHRU
case 1:
eq += dims[0] == rhs.dims[0];
}
return eq == ndim;
}
return false;
}
std::string NamedTensorShape::to_string() const {
std::string rst("{");
for (size_t i = 0; i < ndim; i++) {
if (i)
rst.append(",");
rst.append(dims[i].to_string());
}
rst.append("}");
return rst;
}
NamedTensorShape NamedTensorShape::make_named_tensor_shape(Format format) {
switch (format) {
case Format::NCHW:
return {{"N"}, {"C"}, {"H"}, {"W"}};
case Format::NHWC:
return {{"N"}, {"H"}, {"W"}, {"C"}};
case Format::NCHW4:
return {{"N"}, {"C//4"}, {"H"}, {"W"}, {"C%4"}};
case Format::NCHW8:
return {{"N"}, {"C//8"}, {"H"}, {"W"}, {"C%8"}};
case Format::NCHW32:
return {{"N"}, {"C//32"}, {"H"}, {"W"}, {"C%32"}};
case Format::NCHW64:
return {{"N"}, {"C//64"}, {"H"}, {"W"}, {"C%64"}};
case Format::NCHW44:
return {{"N//4"}, {"C//4"}, {"H"}, {"W"}, {"C%4"}, {"N%4"}};
case Format::NCHW88:
return {{"N//8"}, {"C//8"}, {"H"}, {"W"}, {"C%8"}, {"N%8"}};
case Format::NCHW44_DOT:
return {{"N//4"}, {"C//4"}, {"H"}, {"W"}, {"N%4"}, {"C%4"}};
default:
megdnn_throw(
ssprintf("Format unimplement(%d)", static_cast<int>(format))
.c_str());
}
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/test/common/named_tensor.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/named_tensor.h"
// clang-format off
#include "test/common/utils.h"
#include "test/common/fix_gtest_on_platforms_without_exception.inl"
// clang-format on
using namespace megdnn;
using megdnn::test::MegDNNError;
TEST(NAMED_TENSOR, NAMED_TENSOR_SHAPE_BASIC) {
ASSERT_EQ(Dimension::NR_NAMES, 9);
Dimension dim0 = {"C"}, dim1 = {"C//32"}, dim2 = {"C//4"}, dim3 = {"C%32"},
dim4 = {"C%4"}, dim5 = {"C//4%8"};
ASSERT_TRUE(dim0 == dim1 * dim3);
ASSERT_TRUE(dim2 == dim1 * dim5);
ASSERT_THROW(dim0 * dim0, MegDNNError);
ASSERT_TRUE(dim1 == dim0 / dim3);
ASSERT_TRUE(dim3 == dim0 / dim1);
ASSERT_TRUE(dim4 == dim3 / dim5);
ASSERT_TRUE(dim5 == dim3 / dim4);
ASSERT_TRUE(dim5 == dim2 / dim1);
ASSERT_THROW(dim5 / dim1, MegDNNError);
ASSERT_TRUE(dim1 < dim4);
ASSERT_TRUE(dim5 < dim4);
ASSERT_FALSE(dim4 < dim5);
ASSERT_TRUE(dim1 < dim2);
ASSERT_FALSE(dim2 < dim1);
auto shape0 = NamedTensorShape::make_named_tensor_shape(
NamedTensorShape::Format::NCHW);
SmallVector<Dimension> dims = {{"N"}, {"C"}, {"H"}, {"W"}};
NamedTensorShape shape1(dims);
NamedTensorShape shape2{{"N"}, {"C"}, {"H"}, {"W"}};
ASSERT_TRUE(shape0.eq_shape(shape1));
ASSERT_TRUE(shape0.eq_shape(shape2));
ASSERT_TRUE(shape1.eq_shape(shape2));
auto shape3 = NamedTensorShape::make_named_tensor_shape(
NamedTensorShape::Format::NCHW4);
ASSERT_FALSE(shape0.eq_shape(shape3));
auto shape4 = NamedTensorShape::make_named_tensor_shape(
NamedTensorShape::Format::NCHW44_DOT);
std::sort(shape4.dims.begin(), shape4.dims.begin() + shape4.ndim);
NamedTensorShape shape5{{"N"}, {"C//32"}, {"H"},
{"W"}, {"C//8%4"}, {"C%8"}};
std::sort(shape5.dims.begin(), shape5.dims.begin() + shape5.ndim);
}
// vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册