Commit fb49a283 authored by Megvii Engine Team

refactor(mgb/dnn): refactor enum used in serializing

GitOrigin-RevId: e57af4a59c9b4e090f3972b4d0cf01a2737f8355
Parent d69b5903
......@@ -23,8 +23,14 @@ def _cname_to_fbname(cname):
}[cname]
def scramble_enum_member_name(name):
s = name.find('<<')
if s != -1:
name = name[0:name.find('=') + 1] + ' ' + name[s+2:]
if name in ("MIN", "MAX"):
return name + "_"
o_name = name.split(' ')[0].split('=')[0]
if o_name in ("MIN", "MAX"):
return name.replace(o_name, o_name + "_")
return name
class FlatBuffersWriter(IndentWriterBase):
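With the param definitions now spelling members as "NAME = value" (see the hunks below), every generator first cuts the string back to the bare name. A minimal sketch of that idiom, with strip_value as a hypothetical helper name:

def strip_value(member):
    # "MAX = 5" -> "MAX"; "REDUCE_SUM = 0" -> "REDUCE_SUM"; "OR" -> "OR"
    return member.split(' ')[0].split('=')[0]

assert strip_value("MAX = 5") == "MAX"
assert strip_value("REDUCE_SUM = 0") == "REDUCE_SUM"
# scramble_enum_member_name then renames MIN/MAX members (e.g. "MAX = 5" -> "MAX_ = 5"),
# presumably to avoid clashing with the MIN/MAX helpers FlatBuffers generates per enum.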
......@@ -97,7 +103,8 @@ class FlatBuffersWriter(IndentWriterBase):
if e.combined:
default = e.compose_combined_enum(e.default)
else:
default = scramble_enum_member_name(str(e.members[e.default]))
default = scramble_enum_member_name(
str(e.members[e.default]).split(' ')[0].split('=')[0])
self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default)
def _resolve_const(self, v):
......@@ -124,7 +131,8 @@ class FlatBuffersWriter(IndentWriterBase):
if s.combined:
default = s.compose_combined_enum(e.get_default())
else:
default = scramble_enum_member_name(str(s.members[e.get_default()]))
default = scramble_enum_member_name(
str(s.members[e.get_default()]).split(' ')[0].split('=')[0])
self._write("%s:%s = %s;", e.name_field, enum_name, default)
def _get_fb_default(self, cppdefault):
......
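The FlatBuffers writer then uses only the stripped name as the schema default; with illustrative field/param/enum names (strategy, ExecutionPolicy, Strategy) the emitted line would look roughly like this:

member = "HEURISTIC = 0"                              # illustrative member string
default = member.split(' ')[0].split('=')[0]          # -> "HEURISTIC"
print("%s:%s%s = %s;" % ("strategy", "ExecutionPolicy", "Strategy", default))
# strategy:ExecutionPolicyStrategy = HEURISTIC;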
......@@ -121,10 +121,12 @@ class member_defs:
def normalize_enum_value(self, value):
def normalize(v):
if isinstance(v, str):
if v not in self.members:
for idx, m in enumerate(self.members):
m = str(m).split(' ')[0].split('=')[0]
if v == m :
return idx
raise ValueError(
"enum member '{}' does not exist.".format(v))
v = self.members.index(v)
assert isinstance(v, int)
return v
if self.combined:
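normalize_enum_value can still be handed a bare member name; it now matches it against the stripped declarations and returns the member's index. A minimal sketch, assuming a members list shaped like the new param defs (resolve is a hypothetical stand-in):

members = ["REDUCE_SUM = 0", "BROADCAST = 1", "ALL_GATHER = 2"]

def resolve(name):
    for idx, m in enumerate(members):
        if name == m.split(' ')[0].split('=')[0]:
            return idx
    raise ValueError("enum member '{}' does not exist.".format(name))

assert resolve("BROADCAST") == 1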
......@@ -524,21 +526,25 @@ class SerializedDType(_ParamDefBase):
self._write_doc(e.name)
for idx, emem in enumerate(e.members):
for emem in e.members:
if e.combined:
self._write('%s = 1 << %d', emem, idx)
self._write('%s', emem)
self._write_doc(emem)
else:
self._write('%s = "%s"', emem, emem)
v = str(emem).split(' ')[0].split('=')[0]
n = int(str(emem).split('=')[1])
self._write('%s = "%s"', v, v)
self._write_doc(emem)
self._enum_member2num.append('id({}.{}):{}'.format(
qualname, emem, idx))
qualname, v, n))
for emem, emem_alias in e.member_alias:
em_a = emem_alias.split(' ')[0].split('=')[0]
if e.combined:
self._write('%s = %s', emem_alias, e.compose_combined_enum(emem))
self._write('%s = %s', em_a, e.compose_combined_enum(emem))
else:
self._write('%s = %s', emem_alias, emem)
em = str(emem).split(' ')[0].split('=')[0]
self._write('%s = %s', em_a, em)
self._unindent()
self._write('')
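For the generated Python enums, each member string is split into its name and its declared number, and that number, rather than the loop index, now feeds the member-to-number table. A sketch of the parsing, with an illustrative member and qualified name:

qualname = "CollectiveComm.Mode"              # illustrative qualified enum name
emem = "ALL_GATHER = 2"
v = emem.split(' ')[0].split('=')[0]          # -> "ALL_GATHER"
n = int(emem.split('=')[1])                   # -> 2
print('%s = "%s"' % (v, v))                   # generated Python enum member
print('id({}.{}):{}'.format(qualname, v, n))  # entry for _enum_member2num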
......@@ -546,7 +552,7 @@ class SerializedDType(_ParamDefBase):
if e.combined:
default = e.compose_combined_enum(e.default)
else:
default = "'{}'".format(e.members[e.default])
default = "'{}'".format(str(e.members[e.default]).split(' ')[0].split('=')[0])
self._cur_fields.append(self.FieldDef(
name=e.name_field,
......@@ -564,7 +570,7 @@ class SerializedDType(_ParamDefBase):
if s.combined:
default = s.compose_combined_enum(e.get_default())
else:
default = "'{}'".format(s.members[e.get_default()])
default = "'{}'".format(str(s.members[e.get_default()]).split(' ')[0].split('=')[0])
self._cur_fields.append(self.FieldDef(
name=e.name_field,
cvt='{}.convert({})'.format(qualname, e.name_field),
......@@ -700,11 +706,9 @@ class CPPWriter(IndentWriterBase):
def _on_member_enum(self, e):
self._write_doc(e.name)
self._write('enum class %s: uint32_t {', e.name, indent=1)
for idx, i in enumerate(e.members):
for i in e.members:
self._write_doc(i)
v = '{} = {}'.format(i, idx)
if e.combined:
v = '{} = 1 << {}'.format(i, idx)
v = str(i)
if i is not e.members[-1] or e.member_alias:
v += ','
self._write(v)
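In the generated C++ enum class the member string is now emitted verbatim, so the declared value (including 1 << n for bit-combination enums) is authoritative instead of the member's position. A small before/after sketch under that assumption:

members = ["HEURISTIC = 1 << 0", "PROFILE = 1 << 1", "REPRODUCIBLE = 1 << 2"]
# old behaviour: the value was derived from the member's position in the list
old = ['{} = 1 << {}'.format(m.split(' ')[0], idx) for idx, m in enumerate(members)]
# new behaviour: the declared string is written into the enum class unchanged
new = [str(m) for m in members]
assert old == new   # holds only while declaration order still matches the declared bits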
......@@ -712,7 +716,7 @@ class CPPWriter(IndentWriterBase):
if e.combined:
self._write('%s = %s,', alias, e.compose_combined_enum(mem))
else:
self._write('%s = %s,', alias, mem)
self._write('%s = %s,', str(alias).split(' ')[0].split('=')[0], str(mem).split(' ')[0].split('=')[0])
self._write('};', indent=-1)
self._non_static_members.append(e)
self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;',
......@@ -720,7 +724,9 @@ class CPPWriter(IndentWriterBase):
if e.combined:
default = 'static_cast<{}>({})'.format(e.name, e.compose_combined_enum(e.default))
else:
default = '{}::{}'.format(e.name, e.members[e.default])
value = str(e.members[e.default])
value = value.split(' ')[0].split('=')[0]
default = '{}::{}'.format(e.name, value)
self._add_ctor_args(e.name, default, e.name_field)
def _on_member_enum_alias(self, e):
......@@ -732,7 +738,9 @@ class CPPWriter(IndentWriterBase):
if s.combined:
default = 'static_cast<{}>({})'.format(e.name, s.compose_combined_enum(e.default))
else:
default = '{}::{}'.format(e.name, s.members[e.get_default()])
value = str(s.members[e.get_default()])
value = value.split(' ')[0].split('=')[0]
default = '{}::{}'.format(e.name, value)
self._add_ctor_args(e.name, default, e.name_field)
def _on_member_field(self, f):
......@@ -754,11 +762,12 @@ class CPPEnumValueWriter(CPPWriter):
def _on_member_enum(self, e):
self._write_doc(e.name)
self._write('struct %s {', e.name, indent=1)
for idx, val in enumerate(e.members):
for val in e.members:
self._write_doc(val)
self._write('static const uint32_t %s = %d;', val, idx)
v = str(val)
self._write('static const uint32_t %s;', v)
for mem, alias in e.member_alias:
self._write('static const uint32_t %s = %s;', alias, mem)
self._write('static const uint32_t %s = %s;', str(alias).split(' ')[0].split('=')[0], str(mem).split(' ')[0].split('=')[0])
self._write('};', indent=-1)
def _on_member_enum_alias(self, e):
......@@ -848,9 +857,11 @@ class CPPParamJsonFuncWriter(IndentWriterBase):
members = e.src_enum.members
else:
members = e.members
for idx, i in enumerate(members):
for i in members:
v = str(i)
v = v.split(' ')[0].split('=')[0]
self._write('case %s::%s::%s: return "%s";',
self._param_name, e.name, i, i, indent=0)
self._param_name, e.name, v, v, indent=0)
self._write('default: mgb_throw(MegBrainError, "Invalid %s::%s:%%d", static_cast<int>(arg));',
self._param_name, e.name, indent=0)
self._write('}', indent=-1)
......
......@@ -89,7 +89,7 @@ class ConverterWriter(IndentWriterBase):
fullname = "::megdnn::param::{}".format(p.name)
enum_def = "MgbEnumAttr<\"{}\", \"{}\", [".format(fullname, e.name)
def format(v):
return '\"{}\"'.format(str(v))
return '\"{}\"'.format(str(v).split(' ')[0].split('=')[0])
enum_def += ','.join(format(i) for i in e.members)
if e.combined:
......@@ -110,7 +110,8 @@ class ConverterWriter(IndentWriterBase):
default_val = "static_cast<{}::{}>({})".format(
fullname, e.name, e.compose_combined_enum(e.default))
else:
default_val = "{}::{}::{}".format(fullname, e.name, e.members[e.default])
default_val = "{}::{}::{}".format(
fullname, e.name, str(e.members[e.default]).split(' ')[0].split('=')[0])
wrapped = self._wrapped_with_default_value(td_class, default_val)
......@@ -134,7 +135,8 @@ class ConverterWriter(IndentWriterBase):
default_val = "static_cast<{}::{}>({})".format(
fullname, e.name, s.compose_combined_enum(e.get_default()))
else:
default_val = "{}::{}::{}".format(fullname, e.name, s.members[e.get_default()])
default_val = "{}::{}::{}".format(fullname, e.name, str(
s.members[e.get_default()]).split(' ')[0].split('=')[0])
wrapped = self._wrapped_with_default_value(td_class, default_val)
......
This diff is collapsed.
......@@ -241,14 +241,17 @@ private:
if (auto* enumAttr = llvm::dyn_cast<MgbEnumAttrMixin>(&it.attr)) {
body += formatv(" switch ({0}){{\n", "$_self." + it.name);
for (auto&& enumMember: enumAttr->getEnumMembers()) {
body += formatv(
" case {0}::{1}::{2}:\n",
getCppClassName(), enumAttr->getEnumName(), enumMember
);
body += formatv(
" props_.emplace_back(\"{0}\", \"{1}\");\n",
it.name, enumMember
);
size_t d1 = enumMember.find(' ');
size_t d2 = enumMember.find('=');
size_t d = d1 <= d2 ? d1 : d2;
body += formatv(" case {0}::{1}::{2}:\n",
getCppClassName(),
enumAttr->getEnumName(),
enumMember.substr(0, d));
body +=
formatv(" props_.emplace_back(\"{0}\", "
"\"{1}\");\n",
it.name, enumMember.substr(0, d));
body += " break;\n";
}
body += " default: break;\n";
......
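The TableGen emitters apply the same cut on the C++ side: d is the smaller of the first ' ' and the first '=' positions, and std::string::npos (returned when neither exists) makes substr(0, d) keep the whole string. A Python equivalent of that min-of-delimiters cut (cut_member is a hypothetical name):

def cut_member(member):
    d1 = member.find(' ')
    d2 = member.find('=')
    candidates = [d for d in (d1, d2) if d != -1]   # Python find() returns -1, not npos
    return member[:min(candidates)] if candidates else member

assert cut_member("REDUCE_SUM = 0") == "REDUCE_SUM"
assert cut_member("OR") == "OR"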
......@@ -177,9 +177,13 @@ void OpDefEmitter::emit_tpl_spl() {
std::vector<std::string> case_body;
std::string ename = formatv("{0}::{1}",
op.getCppClassName(), attr->getEnumName());
llvm::for_each(attr->getEnumMembers(), [&](auto&& v){
case_body.push_back(formatv(
"case {0}::{1}: return \"{1}\";", ename, v));
llvm::for_each(attr->getEnumMembers(), [&](auto&& v) {
size_t d1 = v.find(' ');
size_t d2 = v.find('=');
size_t d = d1 <= d2 ? d1 : d2;
case_body.push_back(
formatv("case {0}::{1}: return \"{1}\";", ename,
v.substr(0, d)));
});
os << formatv(R"(
template <>
......
......@@ -50,14 +50,15 @@ void OpDefEmitter::emit() {
);
std::vector<std::string> body;
for (auto&& i: attr->getEnumMembers()) {
os << formatv(
"\n .value(\"{2}\", {0}::{1}::{2})",
className, attr->getEnumName(), i
);
size_t d1 = i.find(' ');
size_t d2 = i.find('=');
size_t d = d1 <= d2 ? d1 : d2;
os << formatv("\n .value(\"{2}\", {0}::{1}::{2})",
className, attr->getEnumName(),
i.substr(0, d));
body.push_back(formatv(
"if (str == \"{2}\") return {0}::{1}::{2};",
className, attr->getEnumName(), i
));
className, attr->getEnumName(), i.substr(0, d)));
}
if (attr->getEnumCombinedFlag()) {
//! define operator |
......
......@@ -102,7 +102,10 @@ void EnumAttrEmitter::emit_tpl_spl() {
&ctx);
auto quote = [&](auto&& i) -> std::string {
return formatv("\"{0}\"", i);
size_t d1 = i.find(' ');
size_t d2 = i.find('=');
size_t d = d1 <= d2 ? d1 : d2;
return formatv("\"{0}\"", i.substr(0, d));
};
os << tgfmt(R"(
template<> const char*
......@@ -110,7 +113,11 @@ $enumTpl<$opClass::$enumClass>::members[] = {$0};
)", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), quote), ", "));
auto mem2value = [&](auto&& i) -> std::string {
return tgfmt("{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx, i);
size_t d1 = i.find(' ');
size_t d2 = i.find('=');
size_t d = d1 <= d2 ? d1 : d2;
return tgfmt("{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx,
i.substr(0, d));
};
os << tgfmt(R"(
template<> std::unordered_map<std::string, $opClass::$enumClass>
......@@ -192,12 +199,15 @@ os << tgfmt(R"(
auto&& members = attr->getEnumMembers();
for (size_t idx = 0; idx < members.size(); ++ idx) {
size_t d1 = members[idx].find(' ');
size_t d2 = members[idx].find('=');
size_t d = d1 <= d2 ? d1 : d2;
os << tgfmt(R"({
PyObject* inst = e_type->tp_alloc(e_type, 0);
reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0;
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "$0", inst) >= 0);
$enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst;
})", &ctx, members[idx], idx);
})", &ctx, members[idx].substr(0, d), idx);
}
}
......
......@@ -136,12 +136,13 @@ class HeaderGen:
mode_list = [i.strip() for i in fin]
for i in mode_list:
i = i.split(' ')[0].split('=')[0]
if i in self._elemwise_modes:
content = '_cb({})'.format(i)
else:
content = ''
self._write_def(
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'.format(i), content)
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'.format(i.split(' ')[0].split('=')[0]), content)
self._write_def('MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)',
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)')
......
......@@ -20,14 +20,14 @@ pdef('PersistentOutputStorage').add_fields(
(pdef('ExecutionPolicy', version=0, is_legacy=True).
add_enum('Strategy',
Doc('HEURISTIC', 'use heuristic to choose the fastest algorithm'),
Doc('HEURISTIC_REPRODUCIBLE', 'use heuristic to choose the fastest algorithm, '
Doc('HEURISTIC = 0', 'use heuristic to choose the fastest algorithm'),
Doc('HEURISTIC_REPRODUCIBLE = 1', 'use heuristic to choose the fastest algorithm, '
'and the chosen algorithm is reproducible'),
Doc('PROFILE',
Doc('PROFILE = 2',
'run possible algorithms on real device to find the best'),
Doc('PROFILE_REPRODUCIBLE',
Doc('PROFILE_REPRODUCIBLE = 3',
'the fastest of profile result that is also reproducible'),
Doc('PROFILE_HEURISTIC',
Doc('PROFILE_HEURISTIC = 4',
'use profile result and heuristic to choose the fastest algorithm')).
add_fields('uint64',
Doc('workspace_limit', 'workspace limit in bytes'),
......@@ -35,13 +35,13 @@ pdef('PersistentOutputStorage').add_fields(
(pdef('ExecutionPolicy', 'specify how to select an algorithm for an operator', version=1).
add_bit_combination_enum('Strategy',
Doc('HEURISTIC', 'use heuristic to choose the fastest algorithm'),
Doc('PROFILE',
Doc('HEURISTIC = 1 << 0', 'use heuristic to choose the fastest algorithm'),
Doc('PROFILE = 1 << 1',
'run possible algorithms on real device to find the best'),
Doc('REPRODUCIBLE',
Doc('REPRODUCIBLE = 1 << 2',
'when profile or heuristic algo selection it require the algos'
'must be reproducible'),
Doc('OPTIMIZED',
Doc('OPTIMIZED = 1 << 3',
'profile require algos are optmized to achieve fast-profile'),
default=('HEURISTIC',),
member_alias=[(('HEURISTIC', 'REPRODUCIBLE'), 'HEURISTIC_REPRODUCIBLE'),
......@@ -66,19 +66,19 @@ pdef('PersistentOutputStorage').add_fields(
(pdef('CollectiveComm', 'collective communication between multiple computing '
'nodes on localhost')
.add_enum(Doc('Mode', 'mode of collective communication'),
Doc('REDUCE_SUM', 'reduce by sum to output computing node'),
Doc('BROADCAST', 'copy input value to each output computing node'),
Doc('ALL_GATHER', 'each output comp node gets the concatenated '
Doc('REDUCE_SUM = 0', 'reduce by sum to output computing node'),
Doc('BROADCAST = 1', 'copy input value to each output computing node'),
Doc('ALL_GATHER = 2', 'each output comp node gets the concatenated '
'value of all inputs'),
Doc('REDUCE_SCATTER_SUM',
Doc('REDUCE_SCATTER_SUM = 3',
'reduce inputs by sum and each output gets one part of it'),
Doc('ALL_REDUCE_SUM', 'every output gets the sum of all inputs'),
Doc('ALL_REDUCE_MAX', 'every output gets the max of all inputs'),
Doc('ALL_REDUCE_MIN', 'every output gets the min of all inputs'),
Doc('ALL_REDUCE_PROD', 'every output gets the prod of all inputs'),
Doc('GATHER', 'concat inputs to one node'),
Doc('SCATTER', 'scatter input to each output computing node'),
Doc('ALL_TO_ALL', 'scatter inputs and gather them on each computing node'),
Doc('ALL_REDUCE_SUM = 4', 'every output gets the sum of all inputs'),
Doc('ALL_REDUCE_MAX = 5', 'every output gets the max of all inputs'),
Doc('ALL_REDUCE_MIN = 6', 'every output gets the min of all inputs'),
Doc('ALL_REDUCE_PROD = 7', 'every output gets the prod of all inputs'),
Doc('GATHER = 8', 'concat inputs to one node'),
Doc('SCATTER = 9', 'scatter input to each output computing node'),
Doc('ALL_TO_ALL = 10', 'scatter inputs and gather them on each computing node'),
name_field='mode'))
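Spelling the values out in the param definitions presumably pins the numeric IDs used in serialization, so inserting or reordering members cannot silently renumber existing ones; for bit-combination enums the declared shifts also make alias composition explicit. For instance, with the ExecutionPolicy (v1) Strategy flags declared above, and assuming aliases compose by bitwise OR:

HEURISTIC, PROFILE, REPRODUCIBLE, OPTIMIZED = 1 << 0, 1 << 1, 1 << 2, 1 << 3
HEURISTIC_REPRODUCIBLE = HEURISTIC | REPRODUCIBLE     # alias from member_alias above
assert HEURISTIC_REPRODUCIBLE == 5                    # (1 << 0) | (1 << 2)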
(pdef('FakeSerializedDType',
......@@ -91,13 +91,13 @@ pdef('PersistentOutputStorage').add_fields(
'evaluate a predicate and branch keys to setup ExecutionMask objects '
'with associated predicate proxy vars (PPVs)')
.add_enum(Doc('Mode', 'how to compare predicate var with branch keys'),
Doc('CASE',
Doc('CASE = 0',
'The outputs correspond to branch keys, '
'and the one which equals predicate would be activated. '
'This behaves like a case-statement in many languages.'),
Doc('CASE_FALLBACK', 'like :attr:`CASE`, but add an extra output '
Doc('CASE_FALLBACK = 1', 'like :attr:`CASE`, but add an extra output '
'that would be activated if no branch is matched'),
Doc('PIECEWISE', 'One more outputs would be produced than the '
Doc('PIECEWISE = 2', 'One more outputs would be produced than the '
'number of branch keys, representing the interval in which the '
'predicate var fits in. The intervals are defined as '
r':math:`(-\\infty, k_0), [k_0, k_1), \\ldots, '
......@@ -112,20 +112,20 @@ pdef('PersistentOutputStorage').add_fields(
(pdef('CondExecPredLogical',
'compute a logical function over a set of PPVs')
.add_enum('Mode', Doc('OR', 'logical or'),
Doc('AND', 'logical and'),
Doc('XOR', 'exclusive-or'),
Doc('NOR', 'not or(inputs)'),
Doc('NAND', 'not and(inputs)'),
Doc('XNOR', 'not xor(inputs)'))
.add_enum('Mode', Doc('OR = 0', 'logical or'),
Doc('AND = 1', 'logical and'),
Doc('XOR = 2', 'exclusive-or'),
Doc('NOR = 3', 'not or(inputs)'),
Doc('NAND = 4', 'not and(inputs)'),
Doc('XNOR = 5', 'not xor(inputs)'))
)
(pdef('CondExecMark',
'add ExecutionMask of the input PPV to this opr and readers of the '
'outputs of this opr')
.add_enum(Doc('GradMode', 'mode for computing the gradient'),
Doc('SUM', 'normal gradient mode: sum all the activated components'),
Doc('SUM_COND_OUT', 'use :attr:`CondExecMerge.SUM_COND_OUT` mode so '
Doc('SUM = 0', 'normal gradient mode: sum all the activated components'),
Doc('SUM_COND_OUT = 1', 'use :attr:`CondExecMerge.SUM_COND_OUT` mode so '
'oprs that depend on the gradient opr would not be executed '
'if the forward var is not used.'),
name_field='grad_mode')
......@@ -135,10 +135,10 @@ pdef('PersistentOutputStorage').add_fields(
execution into account, this option can be used to bypass static
inference errors. This is currently only used by automatically
generated gradient oprs."""),
Doc('SHAPE_VALUE', 'enable both shape and value inference'),
Doc('SHAPE_ONLY',
Doc('SHAPE_VALUE = 0', 'enable both shape and value inference'),
Doc('SHAPE_ONLY = 1',
'only enable shape inference (disable value inference)'),
Doc('NONE', 'disable both shape and value inference'),
Doc('NONE = 2', 'disable both shape and value inference'),
name_field='static_infer')
)
......@@ -147,17 +147,17 @@ pdef('PersistentOutputStorage').add_fields(
'number of output vars (i.e. vars per branch)'),
1)
.add_enum('Mode',
Doc('EXACT_ONE', 'copy the var whose mask is activated to the output'
Doc('EXACT_ONE = 0', 'copy the var whose mask is activated to the output'
', requiring that exactly one branch is active'),
Doc('EXACT_ONE_SAME_SHAPE', 'like :attr:`EXACT_ONE` with the '
Doc('EXACT_ONE_SAME_SHAPE = 1', 'like :attr:`EXACT_ONE` with the '
'requirement that all branches have the same shape, so shape '
'inference can be easier'),
Doc('SUM', 'sum all the active branches into output var; require '
Doc('SUM = 2', 'sum all the active branches into output var; require '
'all branches to have the same shape. Extra shape vars are '
'needed in this mod, so the outputs can be initialized to zero '
'when no input is active (and their shapes are probably '
'unknown).'),
Doc('SUM_COND_OUT', 'like :attr:`SUM` but also add an ExecutionMask'
Doc('SUM_COND_OUT = 3', 'like :attr:`SUM` but also add an ExecutionMask'
' to the readers of output vars, so they would be skipped if '
' no branch is taken')
)
......