Commit d063d577 authored by Megvii Engine Team

perf(functional): use fma to reduce elemwise but disable subgraph compilation

GitOrigin-RevId: c75a6e1a09b8e727a48e3b5eaabc6926aa046a46
Parent 2a063f8e
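In short: elemwise chains of the form a * b + c and a * b + c * d are folded into single fused kernels (Elemwise modes FUSE_MUL_ADD3 / FUSE_MUL_ADD4, exposed to subgraph bodies as "fma3" / "fma4"), while the SyncBn subgraphs drop their gopt_level argument and the matmul wrappers lower it from 3 to 2, which appears to be the "disable subgraph compilation" half of the title. A minimal numpy sketch of the fused semantics (fma3/fma4 here are illustrative stand-ins, not MegEngine APIs):

    import numpy as np

    def fma3(a, b, c):
        # FUSE_MUL_ADD3: a * b + c, one elemwise kernel instead of two
        return a * b + c

    def fma4(a, b, c, d):
        # FUSE_MUL_ADD4: a * b + c * d, one elemwise kernel instead of three
        return a * b + c * d

    x = np.arange(4, dtype="float32")
    assert np.allclose(fma3(x, x, x), x * x + x)
    assert np.allclose(fma4(x, x, x, x), 2 * x * x)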
@@ -242,16 +242,32 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
         "-": lambda: builtin.Elemwise(mode="negate"),
     }
+    ternary_ops = {
+        "fma3": lambda: builtin.Elemwise(mode="FUSE_MUL_ADD3"),
+    }
+    quaternary_ops = {"fma4": lambda: builtin.Elemwise(mode="FUSE_MUL_ADD4")}

     def decorator(func):
         builder = _SubgraphBuilder(name)

-        def apply_expr(op, *args):
+        def apply_expr(op, *args, nr_out=None):
             if isinstance(op, str):
                 if len(args) == 2:
                     op = binary_ops[op]()
                 elif len(args) == 1:
                     op = unary_ops[op]()
-            return builder.apply(op, args, 1)[0]
+                elif len(args) == 3:
+                    op = ternary_ops[op]()
+                elif len(args) == 4:
+                    op = quaternary_ops[op]()
+            results = builder.apply(op, args, 1 if nr_out is None else nr_out)
+            if nr_out is None:
+                assert len(results) == 1
+                return results[0]
+            else:
+                assert len(results) == nr_out
+                return results

         def apply_const(value, dtype=dtype, device=device):
             return builder.apply_const(value, dtype, device)
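A self-contained mock of the extended dispatch (plain lambdas stand in for builtin.Elemwise; the real apply_expr additionally threads results through the subgraph builder and honors nr_out):

    unary_ops = {"-": lambda a: -a}
    binary_ops = {"+": lambda a, b: a + b, "*": lambda a, b: a * b}
    ternary_ops = {"fma3": lambda a, b, c: a * b + c}
    quaternary_ops = {"fma4": lambda a, b, c, d: a * b + c * d}

    def apply_expr(op, *args):
        # String ops are resolved by arity, mirroring the if/elif chain above.
        if isinstance(op, str):
            by_arity = {1: unary_ops, 2: binary_ops, 3: ternary_ops, 4: quaternary_ops}
            op = by_arity[len(args)][op]
        return op(*args)

    assert apply_expr("fma3", 2.0, 3.0, 4.0) == 10.0       # 2*3 + 4
    assert apply_expr("fma4", 2.0, 3.0, 4.0, 5.0) == 26.0  # 2*3 + 4*5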
@@ -784,7 +784,7 @@ class _Hashable:
 def _get_extentedMatrixMulOp(
     device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
 ):
-    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=3)
+    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
     def extentedMatrixMulOp(inputs, f, c):
         assert len(inputs) == 2
         inp1, inp2 = inputs
@@ -884,7 +884,7 @@ def _get_extentedMatrixMulOp(
 def _get_extentedBatchedMatrixMulOp(
     device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
 ):
-    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=3)
+    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
     def extentedBatchedMatrixMulOp(inputs, f, c):
         assert len(inputs) == 2
         inp1, inp2 = inputs
@@ -1174,7 +1174,7 @@ def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels):
         reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size)
         return (reduce_shape, reduce_size_f, channel_x1s, channel_x2s), (False, False, True, True)

-    @subgraph("SyncBnStage1", dtype, device, 7, gopt_level=3)
+    @subgraph("SyncBnStage1", dtype, device, 7)
     def syncbn_stage1(inputs, f, c):
         input, reduce_size, channel_x1s, channel_x2s, eps = inputs[0:5]
         weight, bias = inputs[5:7]
@@ -1187,12 +1187,12 @@ def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels):
         inv_var_wt = f("*", invsqrt_channel_var, weight)
         neg_channel_mean = f("-", channel_mean)
         outvar =\
-            f("+", f("*", input, inv_var_wt),
+            f("fma3", input, inv_var_wt,
                 f("+", f("*", neg_channel_mean, inv_var_wt),
                     bias))
         return (outvar, channel_mean, channel_var, inv_var_wt), (True, False, False, False)

-    @subgraph("SyncBnStage1Inference", dtype, device, 6, gopt_level=3)
+    @subgraph("SyncBnStage1Inference", dtype, device, 6)
     def syncbn_stage1_inference(inputs, f, c):
         input, channel_mean, channel_var, eps = inputs[0:4]
         weight, bias = inputs[4:6]
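Why stage 1 can end in a single fma3: textbook batch norm is regrouped so the per-element work is one multiply-add. A numpy check of that regrouping (eps handling simplified; s and t correspond to inv_var_wt and the folded constant term in the diff):

    import numpy as np

    # textbook: out = (x - mean) / sqrt(var + eps) * weight + bias
    # regrouped: s = weight / sqrt(var + eps); t = (-mean) * s + bias
    #            out = fma3(x, s, t) = x * s + t
    rng = np.random.default_rng(0)
    x = rng.standard_normal(16).astype("float32")
    mean, var, eps = x.mean(), x.var(), 1e-5
    weight, bias = 1.5, 0.25

    textbook = (x - mean) / np.sqrt(var + eps) * weight + bias
    s = weight / np.sqrt(var + eps)
    t = (-mean) * s + bias
    fused = x * s + t  # fma3(x, s, t)
    assert np.allclose(textbook, fused, atol=1e-6)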
@@ -1205,36 +1205,36 @@
                     bias))
         return (outvar,), (True,)

-    @subgraph("SyncBnStage2", dtype, device, 7, gopt_level=3)
+    @subgraph("SyncBnStage2", dtype, device, 7)
     def syncbn_stage2(inputs, f, c):
         running_mean, running_var, momentum = inputs[0:3]
         reduce_size, channel_x1s, channel_x2s, channel_mean = inputs[3:7]
-        running_mean = f("*", running_mean, momentum)
-        running_mean =\
-            f("+", running_mean,
-                f("*", f("-", c(1), momentum),
-                    channel_mean))
+        c1_minus_momentum = f("-", c(1), momentum)
+        reduce_size_minus_c1 = f("-", reduce_size, c(1))
+        running_mean = f("fma4",
+            running_mean, momentum,
+            c1_minus_momentum, channel_mean,
+        )
         channel_variance_unbiased =\
             f("+", f("/", f("**", channel_x1s, c(2)),
                 f("*", f("-", reduce_size),
-                    f("-", reduce_size, c(1)))),
+                    reduce_size_minus_c1)),
                 f("/", channel_x2s,
-                    f("-", reduce_size, c(1))))
-        running_var = f("*", running_var, momentum)
-        running_var =\
-            f("+", running_var,
-                f("*", f("-", c(1), momentum),
-                    channel_variance_unbiased))
+                    reduce_size_minus_c1))
+        running_var = f("fma4",
+            running_var, momentum,
+            c1_minus_momentum, channel_variance_unbiased
+        )
         return (running_mean, running_var), (True, True)

-    @subgraph("SyncBnConcatStats", dtype, device, 3, gopt_level=3)
+    @subgraph("SyncBnConcatStats", dtype, device, 3)
     def syncbn_concat_stats(inputs, f, c):
         reduce_size, channel_x1s, channel_x2s = inputs[0:3]
         reduce_size = f(builtin.Broadcast(), reduce_size, c([1]*ndim, dtype="int32"))
         stats = f(builtin.Concat(axis=1, comp_node=device), reduce_size, channel_x1s, channel_x2s)
         return (stats,), (True,)

-    @subgraph("SyncBnSplitStats", dtype, device, 1, gopt_level=3)
+    @subgraph("SyncBnSplitStats", dtype, device, 1)
     def syncbn_split_stats(inputs, f, c):
         stats = inputs[0]
         c_1 = c(1, dtype="int32")
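The stage-2 rewrite leans on two identities: the unbiased variance recovered from the accumulated sums x1s = sum(x) and x2s = sum(x**2), and the running-stat update a*m + (1-m)*v, which is exactly one fma4. A numpy sanity check with hypothetical scalar values:

    import numpy as np

    # channel_variance_unbiased in the diff:
    #   x1s**2 / (-N * (N - 1)) + x2s / (N - 1)
    # should equal the textbook unbiased variance (ddof=1).
    rng = np.random.default_rng(1)
    x = rng.standard_normal(100)
    N = x.size
    x1s, x2s = x.sum(), (x ** 2).sum()

    var_unbiased = x1s ** 2 / (-N * (N - 1)) + x2s / (N - 1)
    assert np.isclose(var_unbiased, x.var(ddof=1))

    # fma4(running_var, momentum, 1 - momentum, var_unbiased):
    momentum, running_var = 0.9, 1.0
    running_var = running_var * momentum + (1 - momentum) * var_unbiased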