提交 d275a823 编写于 作者: M Megvii Engine Team

chore(scripts): clarify and fix default value of bit combined enum

GitOrigin-RevId: 3716bf9bb566a23c6916df611dae563934e824cf
上级 7c715bd4
......@@ -52,14 +52,11 @@ class FlatBuffersWriter(IndentWriterBase):
name = p + e
e = self._enums[(p, e)]
self._write_doc(e.name)
self._write("enum %s%s : uint {", p, e.name, indent=1)
attribute = "(bit_flags)" if e.combined else ""
self._write("enum %s%s : uint %s {", p, e.name, attribute, indent=1)
for idx, member in enumerate(e.members):
self._write_doc(member)
if e.combined:
self._write("%s=%d,", scramble_enum_member_name(str(member)),
1<<idx)
else:
self._write("%s,", scramble_enum_member_name(str(member)))
self._write("%s,", scramble_enum_member_name(str(member)))
self._write("}\n", indent=-1)
def _write_doc(self, doc):
......@@ -97,8 +94,11 @@ class FlatBuffersWriter(IndentWriterBase):
return
self._write_doc(e.name)
self._used_enum.add(key)
self._write("%s:%s%s = %s;", e.name_field, p.name, e.name,
scramble_enum_member_name(str(e.members[e.default])))
if e.combined:
default = e.compose_combined_enum(e.default)
else:
default = scramble_enum_member_name(str(e.members[e.default]))
self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default)
def _resolve_const(self, v):
while v in self._cur_const_val:
......@@ -120,9 +120,12 @@ class FlatBuffersWriter(IndentWriterBase):
return
self._used_enum.add((e.src_class, e.src_name))
enum_name = e.src_class + e.src_name
self._write(
"%s:%s = %s;", e.name_field, enum_name,
scramble_enum_member_name(str(e.src_enum.members[e.get_default()])))
s = e.src_enum
if s.combined:
default = s.compose_combined_enum(e.get_default())
else:
default = scramble_enum_member_name(str(s.members[e.get_default()]))
self._write("%s:%s = %s;", e.name_field, enum_name, default)
def _get_fb_default(self, cppdefault):
if not isinstance(cppdefault, str):
......
......@@ -73,11 +73,21 @@ class member_defs:
"""define an enum; the result would contain both an enum class def and its
corresponding data field
:param default: index of default member value
:param default:
for normal enum class: index of default member value
for bit combined class: tuple of index of default member value
For example, following representations of the default value for bit
combined class are all equivalent:
Enum(members=('a', 'b', 'c'), default=('a', 'b'), ...)
Enum(members=('a', 'b', 'c'), default=(0, 1), ...)
Enum(members=('a', 'b', 'c'), default=(1 << 0) | (1 << 1), ...)
:attr name_field: name of the data field of this enum in the param
struct
:attr member_alias: list of (member, alias) pairs
:attr member_alias:
for normal enum class: list of (member, alias) pairs
for bit combined class: list of (tuple of members, alias) paris
"""
__slots__ = ['name', 'name_field', 'members', 'default',
'member_alias', 'combined']
......@@ -90,17 +100,11 @@ class member_defs:
name = member_defs.Doc.make(name)
assert name.id[0].isupper()
members = tuple(map(member_defs.Doc.make, members))
if isinstance(default, str):
if default not in name_field:
raise ValueError(
"Default value '{}' does not exist.".format(default))
default = name_field.index(default)
assert isinstance(default, int)
self.name = name
self.combined = combined
self.name_field = self.get_name_field(name.id, name_field)
self.members = members
self.default = default
self.default = self.normalize_enum_value(default)
self.all_enums[(param_name, name.id)] = self
......@@ -114,6 +118,43 @@ class member_defs:
assert isinstance(name_field, str)
return name_field
def normalize_enum_value(self, value):
def normalize(v):
if isinstance(v, str):
if v not in self.members:
raise ValueError(
"enum member '{}' does not exist.".format(v))
v = self.members.index(v)
assert isinstance(v, int)
return v
if self.combined:
if isinstance(value, int):
value = self.decompose_combined_enum(value)
assert isinstance(value, tuple)
value = tuple(normalize(i) for i in value)
return value
else:
return normalize(value)
@staticmethod
def decompose_combined_enum(v):
"""Integer => tuple of the indexes of the enum members"""
assert isinstance(v, int)
idx = 0
members = []
while v > 0:
if v & 1:
members.append(idx)
idx += 1
v >>= 1
return tuple(members)
def compose_combined_enum(self, v):
"""tuple of members => Integer"""
assert self.combined and isinstance(v, tuple)
norm_v = self.normalize_enum_value(v)
return sum(1 << i for i in norm_v)
class Field(Base):
"""define a normal data field"""
__slots__ = ['name', 'dtype', 'default']
......@@ -146,6 +187,10 @@ class member_defs:
src_name = name
self.src_name = src_name
self.default = default
# TODO: remove this assertion if needed; adding mock param_defs in
# current testing framework is too complicated, and currently we
# only allow aliasing of normal enum
assert not self.src_enum.combined
@property
def src_enum(self):
......@@ -157,7 +202,7 @@ class member_defs:
set"""
if self.default is None:
return self.src_enum.default
return self.default
return self.src_enum.normalize_enum_value(self.default)
class ParamDef:
......@@ -198,7 +243,7 @@ class ParamDef:
self.name.id, name, name_field, members, default, member_alias))
return self
def add_bit_combination_enum(self, name, *members, default=0,
def add_bit_combination_enum(self, name, *members, default=tuple(),
name_field=None, member_alias=[]):
self.members.append(member_defs.Enum(
self.name.id, name, name_field, members, default, member_alias, True))
......@@ -322,11 +367,13 @@ class PyWriter(IndentWriterBase):
' for idx, v in enumerate(pdata):\n'
' if isinstance(v, _EnumBase):\n'
' pdata[idx] = _enum_member2num[id(v)]\n'
' elif isinstance(v, _BitCombinedEnumBase):\n'
' pdata[idx] = v._value_\n'
' return tag + self._packer.pack(*pdata)\n'
'\n'
)
self._write(
'class _EnumBase(enum.Enum):\n'
# it's hard to mix custom implemention into enum, just do copy-paste instead
classbody = (
' @classmethod\n'
' def __normalize(cls, val):\n'
' if isinstance(val, str):\n'
......@@ -349,6 +396,12 @@ class PyWriter(IndentWriterBase):
' return super()._missing_(value)\n'
'\n'
)
self._write(
'class _EnumBase(enum.Enum):\n' + classbody
)
self._write(
'class _BitCombinedEnumBase(enum.Flag):\n' + classbody
)
if not self._imperative:
self._write(
'def _as_dtype_num(dtype):\n'
......@@ -464,30 +517,42 @@ class SerializedDType(_ParamDefBase):
def _on_member_enum(self, e):
qualname = '{}.{}'.format(self._cur_param_name, e.name)
self._write('class %s(_EnumBase):', e.name, indent=1)
if e.combined:
self._write('class %s(_BitCombinedEnumBase):', e.name, indent=1)
else:
self._write('class %s(_EnumBase):', e.name, indent=1)
self._write_doc(e.name)
for idx, emem in enumerate(e.members):
self._write('%s = "%s"', emem, emem)
self._write_doc(emem)
if e.combined:
self._enum_member2num.append('id({}.{}):{}'.format(
qualname, emem, 1<<idx))
self._write('%s = 1 << %d', emem, idx)
self._write_doc(emem)
else:
self._write('%s = "%s"', emem, emem)
self._write_doc(emem)
self._enum_member2num.append('id({}.{}):{}'.format(
qualname, emem, idx))
for emem, emem_alis in e.member_alias:
self._write('%s = %s', emem_alis, emem)
for emem, emem_alias in e.member_alias:
if e.combined:
self._write('%s = %s', emem_alias, e.compose_combined_enum(emem))
else:
self._write('%s = %s', emem_alias, emem)
self._unindent()
self._write('')
if e.combined:
default = e.compose_combined_enum(e.default)
else:
default = "'{}'".format(e.members[e.default])
self._cur_fields.append(self.FieldDef(
name=e.name_field,
cvt='{}.convert({})'.format(qualname, e.name_field),
fmt='I',
default="'{}'".format(e.members[e.default]),
default=default,
type=qualname,
doc=None))
......@@ -495,11 +560,16 @@ class SerializedDType(_ParamDefBase):
self._write('%s = %s.%s', e.name, e.src_class, e.src_name)
s = e.src_enum
qualname = '{}.{}'.format(e.src_class, e.src_name)
if s.combined:
default = s.compose_combined_enum(e.get_default())
else:
default = "'{}'".format(s.members[e.get_default()])
self._cur_fields.append(self.FieldDef(
name=e.name_field,
cvt='{}.convert({})'.format(qualname, e.name_field),
fmt='I',
default="'{}'".format(s.members[e.get_default()]),
default=default,
type=qualname,
doc=None))
......@@ -639,14 +709,19 @@ class CPPWriter(IndentWriterBase):
v += ','
self._write(v)
for mem, alias in e.member_alias:
self._write('%s = %s,', alias, mem)
if e.combined:
self._write('%s = %s,', alias, e.compose_combined_enum(mem))
else:
self._write('%s = %s,', alias, mem)
self._write('};', indent=-1)
self._non_static_members.append(e)
self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;',
str(e.name).upper(), len(e.members))
self._add_ctor_args(e.name,
'{}::{}'.format(e.name, e.members[e.default]),
e.name_field)
if e.combined:
default = 'static_cast<{}>({})'.format(e.name, e.compose_combined_enum(e.default))
else:
default = '{}::{}'.format(e.name, e.members[e.default])
self._add_ctor_args(e.name, default, e.name_field)
def _on_member_enum_alias(self, e):
s = e.src_enum
......@@ -654,10 +729,11 @@ class CPPWriter(IndentWriterBase):
self._non_static_members.append(e)
self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;',
str(e.name).upper(), len(s.members))
self._add_ctor_args(e.name,
'{}::{}'.format(e.name,
s.members[e.get_default()]),
e.name_field)
if s.combined:
default = 'static_cast<{}>({})'.format(e.name, s.compose_combined_enum(e.default))
else:
default = '{}::{}'.format(e.name, s.members[e.get_default()])
self._add_ctor_args(e.name, default, e.name_field)
def _on_member_field(self, f):
self._non_static_members.append(f)
......
......@@ -106,7 +106,12 @@ class ConverterWriter(IndentWriterBase):
return
# wrapped with default value
default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.default)
if e.combined:
default_val = "static_cast<{}::{}>({})".format(
fullname, e.name, e.compose_combined_enum(e.default))
else:
default_val = "{}::{}::{}".format(fullname, e.name, e.members[e.default])
wrapped = self._wrapped_with_default_value(td_class, default_val)
self._current_tparams.append("{}:${}".format(wrapped, e.name_field))
......@@ -124,7 +129,13 @@ class ConverterWriter(IndentWriterBase):
self._write("def {} : {};".format(td_class, enum_def))
# wrapped with default value
default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.get_default())
s = e.src_enum
if s.combined:
default_val = "static_cast<{}::{}>({})".format(
fullname, e.name, s.compose_combined_enum(e.get_default()))
else:
default_val = "{}::{}::{}".format(fullname, e.name, s.members[e.get_default()])
wrapped = self._wrapped_with_default_value(td_class, default_val)
self._current_tparams.append("{}:${}".format(wrapped, e.name_field))
......
......@@ -87,9 +87,13 @@ struct pyobj_convert_generic {
}
};
template<typename T, typename SFINAE=void>
struct EnumTrait;
template <typename T>
struct EnumTrait {
struct EnumTrait<T, std::enable_if_t<std::is_enum_v<T>>> {
static constexpr bool is_bit_combined = false;
static constexpr std::underlying_type_t<T> max = 0;
};
template <typename T>
......@@ -264,18 +268,25 @@ struct BitCombinedEnumWrapper {
return ret;
}
}
static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject*, PyObject*) {
PyObject* obj = type->tp_alloc(type, 0);
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>(1);
return obj;
}
static int py_init(PyObject* self, PyObject* args, PyObject*) {
int input = 1;
if (PyArg_ParseTuple(args, "|i", &input)){
reinterpret_cast<BitCombinedEnumWrapper*>(self)->value =
static_cast<T>(input);
static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject* args, PyObject*) {
if (!PyTuple_Size(args)) {
PyObject* obj = type->tp_alloc(type, 0);
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = T();
return obj;
}
else {
PyObject* input;
if (!PyArg_ParseTuple(args, "|O", &input)) {
return nullptr;
}
T value;
try {
value = pyobj_convert_generic<T>::from(input);
} CATCH_ALL(nullptr);
PyObject* obj = type->tp_alloc(type, 0);
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = value;
return obj;
}
return 0;
}
static PyObject* py_repr(PyObject* self) {
return pyobj_convert_generic<std::string>::to(
......@@ -325,6 +336,12 @@ struct pyobj_convert_generic<T,
static T from(PyObject* obj) {
if (PyObject_TypeCheck(obj, &Wrapper::type)) {
return reinterpret_cast<Wrapper*>(obj)->value;
} else if(PyLong_Check(obj)) {
auto value = pyobj_convert_generic<std::underlying_type_t<T>>::from(obj);
mgb_throw_if(value > EnumTrait<T>::max, mgb::MegBrainError,
"out of range, cannot convert %zu to %s",
static_cast<uint32_t>(value), Wrapper::name);
return static_cast<T>(value);
}
// try as string
// TODO: type checkcd
......
......@@ -90,10 +90,12 @@ void EnumAttrEmitter::emit_tpl_spl() {
"template<> PyNumberMethods "
"$enumTpl<$opClass::$enumClass>::number_methods={};\n",
&ctx);
os << tgfmt(
"template<> struct EnumTrait<$opClass::$enumClass> { static constexpr "
"bool is_bit_combined = true;};\n",
&ctx);
os << tgfmt(R"(
template<> struct EnumTrait<$opClass::$enumClass> {
static constexpr bool is_bit_combined = true;
static constexpr std::underlying_type_t<$opClass::$enumClass> max = (1llu << $0) - 1;
};
)", &ctx, attr->getEnumMembers().size());
}
auto str2type = [&](auto&& i) -> std::string {
......@@ -138,7 +140,6 @@ void $0(PyTypeObject& py_type) {
// others should always use singleton
os << tgfmt(R"(
e_type.tp_new = $enumTpl<$opClass::$enumClass>::py_new_combined_enum;
e_type.tp_init = $enumTpl<$opClass::$enumClass>::py_init;
auto& number_method = $enumTpl<$opClass::$enumClass>::number_methods;
number_method.nb_or = $enumTpl<$opClass::$enumClass>::py_or;
number_method.nb_and = $enumTpl<$opClass::$enumClass>::py_and;
......
......@@ -6,7 +6,7 @@ decl_opr('Convolution',
'convolution kernel in '
'(out channel, in channel, kern row, kern col) format')],
params=[('param', 'ConvolutionV0'),
('execution_polity', 'ExecutionPolicy')],
('execution_polity', 'ExecutionPolicyV0')],
desc='batched convolution on channeled 2D images')
decl_opr('Convolution',
......@@ -28,7 +28,7 @@ decl_opr('ConvolutionBackwardData',
'convolution kernel in '
'(out channel, in channel, kern row, kern col) format')],
params=[('param', 'ConvolutionV0'),
('execution_polity', 'ExecutionPolicy')],
('execution_polity', 'ExecutionPolicyV0')],
body=[
'a, b = all_inputs',
'all_inputs = [b, a]'
......@@ -201,7 +201,7 @@ decl_opr('ConvBiasForward',
Doc('bias', 'bias'),
],
params=[('param', 'ConvBiasV1'),
('execution_policy', 'ExecutionPolicy')],
('execution_policy', 'ExecutionPolicyV0')],
desc=('activation(convolution(src, filter) + bias) with specified '
'dtype'),
has_out_dtype=True)
......
......@@ -42,7 +42,12 @@ pdef('PersistentOutputStorage').add_fields(
'when profile or heuristic algo selection it require the algos'
'must be reproducible'),
Doc('OPTMIZED',
'profile require algos are optmized to achieve fast-profile')).
'profile require algos are optmized to achieve fast-profile'),
default=('HEURISTIC',),
member_alias=[(('HEURISTIC', 'REPRODUCIBLE'), 'HEURISTIC_REPRODUCIBLE'),
(('PROFILE', 'REPRODUCIBLE'), 'PROFILE_REPRODUCIBLE'),
(('PROFILE', 'HEURISTIC'), 'PROFILE_HEURISTIC'),
]).
add_fields('uint64',
Doc('workspace_limit', 'workspace limit in bytes'),
str(2**64-1)+'ull'))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册