Commit 8c47c1f1 authored by: Megvii Engine Team

perf(syncbn): reimplement with subgraph

GitOrigin-RevId: 13e7e3d3c0d0e9cd8939ad5ddf62bc91a5dabde0
Parent 53da5c79
@@ -13,6 +13,7 @@ import numpy as np
from .._imperative_rt import make_const
from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion, get_device
from .._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from .._wrap import as_device
from ..ops import builtin
from ..ops.special import Const
@@ -219,3 +220,49 @@ def _normalize_axis(
)
return axis
raise
def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
if device.physical_name.startswith("cpu"):
gopt_level = None # disable jit and compile
binary_ops = {
"+": builtin.Elemwise(mode="add"),
"-": builtin.Elemwise(mode="sub"),
"*": builtin.Elemwise(mode="mul"),
"/": builtin.Elemwise(mode="true_div"),
"//": builtin.Elemwise(mode="floor_div"),
"**": builtin.Elemwise(mode="pow"),
"√": builtin.Elemwise(mode="expm1"),
"max": builtin.Elemwise(mode="max"),
"additive": builtin.Elemwise(mode="add"),
}
unary_ops = {
"-": builtin.Elemwise(mode="negate"),
}
def decorator(func):
builder = _SubgraphBuilder(name)
def apply_expr(op, *args):
if isinstance(op, str):
if len(args) == 2:
op = binary_ops[op]
elif len(args) == 1:
op = unary_ops[op]
return builder.apply(op, args, 1)[0]
def apply_const(value, dtype=dtype, device=device):
return builder.apply_const(value, dtype, device)
inputs = [builder.input() for _ in range(nr_inputs)]
outputs, outputs_has_grad = func(inputs, apply_expr, apply_const)
builder.outputs(outputs)
builder.outputs_has_grad(outputs_has_grad)
if gopt_level is None:
return builder.get()
else:
return builder.compile(gopt_level)
return decorator
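
As a side note, the decorator above can be used to define and apply a small fused op in one step. The following is an illustrative sketch only (the op name, the function, and the toy computation are hypothetical, not part of the commit); it assumes the import paths used elsewhere in this diff:

import numpy as np
import megengine
from megengine.core._imperative_rt.core2 import apply
from megengine.core.tensor.utils import subgraph

x = megengine.Tensor(np.arange(4, dtype="float32"))

@subgraph("AddOneTimesTwo", x.dtype, x.device, nr_inputs=1, gopt_level=2)
def add_one_times_two(inputs, f, c):
    # f applies an op (a builtin OpDef or one of the "+", "*", ... shortcuts),
    # c creates a constant with the subgraph's default dtype/device
    (inp,) = inputs
    out = f("*", f("+", inp, c(1)), c(2))
    return (out,), (True,)  # one output, which participates in autograd

(y,) = apply(add_one_times_two, x)  # y == Tensor([2., 4., 6., 8.])

On CPU devices gopt_level is forced to None (see above), so the builder returns the uncompiled SubgraphOp there instead of a CompiledOp.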
@@ -7,11 +7,13 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=too-many-lines
from functools import lru_cache
from typing import NamedTuple, Optional, Sequence, Tuple, Union
from ..core._imperative_rt.core2 import apply, dtype_promotion
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt
from ..core.ops.special import Const
from ..core.tensor import amp, megbrain_graph
from ..core.tensor.array_method import _elwise_apply
@@ -20,10 +22,13 @@ from ..core.tensor.utils import (
    astype,
    cast_tensors,
    convert_single_value,
    make_shape_tuple,
    setscalar,
    subgraph,
)
from ..device import get_default_device
from ..distributed import WORLD, is_distributed
from ..jit import exclude_from_trace
from ..random import uniform
from ..tensor import Tensor
from ..utils.deprecation import deprecated_func
@@ -1153,6 +1158,111 @@ def batch_norm(
    return inp
@lru_cache(maxsize=None)
def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels):
# fmt: off
@subgraph("SyncBnStage0", dtype, device, 1)
def syncbn_stage0(inputs, f, c):
input = inputs[0]
reduce_shape = c((1, channels) + (1,) * (ndim - 2), dtype="int32", device=device)
input_shape = f(GetVarShape(), input)
input_elems = f(Reduce(mode="product", axis=0), input_shape)
reduce_elems = f(Reduce(mode="product", axis=0), reduce_shape)
reduce_size = f("//", input_elems, reduce_elems)
channel_x1s = f(Reduce(mode="sum"), input, reduce_shape)
channel_x2s = f(Reduce(mode="sum_sqr"), input, reduce_shape)
reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size)
return (reduce_shape, reduce_size_f, channel_x1s, channel_x2s), (False, False, True, True)
@subgraph("SyncBnStage1", dtype, device, 7, gopt_level=3)
def syncbn_stage1(inputs, f, c):
input, reduce_size, channel_x1s, channel_x2s, eps = inputs[0:5]
weight, bias = inputs[5:7]
channel_mean = f("/", channel_x1s, reduce_size)
channel_var =\
f("+", f("/", f("**", channel_x1s, c(2)),
f("-", f("*", reduce_size, reduce_size))),
f("/", channel_x2s, reduce_size))
invsqrt_channel_var = f("**", f(eps_mode, channel_var, eps), c(-0.5))
inv_var_wt = f("*", invsqrt_channel_var, weight)
neg_channel_mean = f("-", channel_mean)
outvar =\
f("+", f("*", input, inv_var_wt),
f("+", f("*", neg_channel_mean, inv_var_wt),
bias))
return (outvar, channel_mean, channel_var, inv_var_wt), (True, False, False, False)
@subgraph("SyncBnStage1Inference", dtype, device, 6, gopt_level=3)
def syncbn_stage1_inference(inputs, f, c):
input, channel_mean, channel_var, eps = inputs[0:4]
weight, bias = inputs[4:6]
invsqrt_channel_var = f("**", f(eps_mode, channel_var, eps), c(-0.5))
inv_var_wt = f("*", invsqrt_channel_var, weight)
neg_channel_mean = f("-", channel_mean)
outvar =\
f("+", f("*", input, inv_var_wt),
f("+", f("*", neg_channel_mean, inv_var_wt),
bias))
return (outvar,), (True,)
@subgraph("SyncBnStage2", dtype, device, 7, gopt_level=3)
def syncbn_stage2(inputs, f, c):
running_mean, running_var, momentum = inputs[0:3]
reduce_size, channel_x1s, channel_x2s, channel_mean = inputs[3:7]
running_mean = f("*", running_mean, momentum)
running_mean =\
f("+", running_mean,
f("*", f("-", c(1), momentum),
channel_mean))
channel_variance_unbiased =\
f("+", f("/", f("**", channel_x1s, c(2)),
f("*", f("-", reduce_size),
f("-", reduce_size, c(1)))),
f("/", channel_x2s,
f("-", reduce_size, c(1))))
running_var = f("*", running_var, momentum)
running_var =\
f("+", running_var,
f("*", f("-", c(1), momentum),
channel_variance_unbiased))
return (running_mean, running_var), (True, True)
@subgraph("SyncBnConcatStats", dtype, device, 3, gopt_level=3)
def syncbn_concat_stats(inputs, f, c):
reduce_size, channel_x1s, channel_x2s = inputs[0:3]
reduce_size = f(builtin.Broadcast(), reduce_size, c([1]*ndim, dtype="int32"))
stats = f(builtin.Concat(axis=1, comp_node=device), reduce_size, channel_x1s, channel_x2s)
return (stats,), (True,)
@subgraph("SyncBnSplitStats", dtype, device, 1, gopt_level=3)
def syncbn_split_stats(inputs, f, c):
stats = inputs[0]
c_1 = c(1, dtype="int32")
channel_x1s_end = c(channels+1, dtype="int32")
def _subtensor(src, axis, begin, end):
items = (axis, (begin is not None), (end is not None), False, False),
args = ()
if begin is not None:
args += begin,
if end is not None:
args += end,
return f(builtin.Subtensor(items=items), src, *args)
reduce_size = _subtensor(stats, 1, None, c_1)
channel_x1s = _subtensor(stats, 1, c_1, channel_x1s_end)
channel_x2s = _subtensor(stats, 1, channel_x1s_end, None)
reduce_size = f(builtin.Reshape(), reduce_size, c_1)
return (reduce_size, channel_x1s, channel_x2s), (False, True, True)
# fmt: on
return (
syncbn_stage0,
syncbn_stage1,
syncbn_stage1_inference,
syncbn_stage2,
syncbn_concat_stats,
syncbn_split_stats,
)
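
The variance expressions baked into SyncBnStage1 and SyncBnStage2 above are the standard raw-moment identities: the biased per-channel variance is E[x^2] - E[x]^2 = x2s/n - (x1s/n)^2, and the running statistics use the Bessel-corrected form x2s/(n-1) - x1s^2/(n*(n-1)). A quick NumPy check of both forms (illustrative only, not part of the commit):

import numpy as np

x = np.random.randn(4, 3, 8, 8)
n = x.size // x.shape[1]                 # elements reduced per channel
x1s = x.sum(axis=(0, 2, 3))              # channel_x1s
x2s = (x ** 2).sum(axis=(0, 2, 3))       # channel_x2s

var_biased = x1s ** 2 / (-(n * n)) + x2s / n              # SyncBnStage1
var_unbiased = x1s ** 2 / (-n * (n - 1)) + x2s / (n - 1)  # SyncBnStage2

np.testing.assert_allclose(var_biased, x.var(axis=(0, 2, 3)))
np.testing.assert_allclose(var_unbiased, x.var(axis=(0, 2, 3), ddof=1))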
def sync_batch_norm(
    inp: Tensor,
    running_mean: Tensor,
@@ -1193,52 +1303,55 @@ def sync_batch_norm(
    assert eps_mode.lower() in {"max", "additive"}, "unknown eps_mode: {}".format(
        eps_mode
    )
    # TODO: cudnnBn fastpath
    _channels = make_shape_tuple(inp.shape)[1]
    _ndim = inp.ndim
    _device = inp.device
    _dtype = inp.dtype

    def _make_full_if_none(x, value):
        if x is None:
            (x,) = Const(value, dtype=inp.dtype, device=_device)()
            (result,) = apply(builtin.Broadcast(), x, reduce_shape)
            return result
        elif x.ndim == 1:
            (result,) = apply(builtin.Reshape(), x, reduce_shape)
            return result
        return x

    (
        syncbn_stage0,
        syncbn_stage1,
        syncbn_stage1_inference,
        syncbn_stage2,
        syncbn_concat_stats,
        syncbn_split_stats,
    ) = _get_sync_bn_ops(_device, _dtype, eps_mode, _ndim, _channels)

    reduce_shape, reduce_size, channel_x1s, channel_x2s = apply(syncbn_stage0, inp)

    eps = convert_single_value(eps, dtype=inp.dtype, device=inp.device)

    weight = _make_full_if_none(weight, 1)
    bias = _make_full_if_none(bias, 0)

    if training:
        if is_distributed():
            # reduce all nodes' data to calculate mean and variance
            (stat,) = apply(syncbn_concat_stats, reduce_size, channel_x1s, channel_x2s)
            stat = all_reduce_sum(stat, group)
            reduce_size, channel_x1s, channel_x2s = apply(syncbn_split_stats, stat)

        outvar, channel_mean, *_ = apply(
            syncbn_stage1, inp, reduce_size, channel_x1s, channel_x2s, eps, weight, bias
        )
    else:
        assert running_var is not None and running_mean is not None
        channel_mean = running_mean
        channel_var = running_var
        outvar, *_ = apply(
            syncbn_stage1_inference, inp, channel_mean, channel_var, eps, weight, bias
        )

    # outvar = output * weight + bias
    # where output = inp * invsqrt_channel_variance + (
@@ -1246,28 +1359,18 @@ def sync_batch_norm(
    # )
    # Manually expand output for gopt

    if training and running_var is not None and running_mean is not None:
        momentum = convert_single_value(momentum, dtype=inp.dtype, device=inp.device)
        running_mean[...], running_var[...] = apply(
            syncbn_stage2,
            running_mean,
            running_var,
            momentum,
            reduce_size,
            channel_x1s,
            channel_x2s,
            channel_mean,
        )
    return outvar
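
For reference, a single-process call of the rewritten function might look like the sketch below; outside distributed mode no all-reduce is involved, and running_mean / running_var are updated in place through SyncBnStage2. The argument order after running_var is assumed here, so keyword arguments are used:

import numpy as np
import megengine
from megengine.functional.nn import sync_batch_norm

x = megengine.Tensor(np.random.randn(2, 3, 4, 4).astype("float32"))
running_mean = megengine.Tensor(np.zeros((1, 3, 1, 1), dtype="float32"))
running_var = megengine.Tensor(np.ones((1, 3, 1, 1), dtype="float32"))
weight = megengine.Tensor(np.ones((3,), dtype="float32"))   # 1-D, reshaped by _make_full_if_none
bias = megengine.Tensor(np.zeros((3,), dtype="float32"))

y = sync_batch_norm(
    x, running_mean, running_var,
    weight=weight, bias=bias,
    training=True, momentum=0.9, eps=1e-5, eps_mode="additive",
)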
......
@@ -66,7 +66,7 @@ def is_tracing():
@contextlib.contextmanager
def exclude_from_trace():
    global skip_tracing
    if skip_tracing or (active_trace is None):
        yield
        return
    try:
......
@@ -58,6 +58,9 @@ void init_common(py::module m) {
.def_property_readonly("logical_name", [](const CompNode& cn) {
return cn.to_string_logical();
})
.def_property_readonly("physical_name", [](const CompNode& cn) {
return cn.to_string();
})
.def_property_readonly("get_mem_status_bytes", [](const CompNode& cn) { .def_property_readonly("get_mem_status_bytes", [](const CompNode& cn) {
return cn.get_mem_status_bytes(); return cn.get_mem_status_bytes();
}) })
@@ -70,6 +73,7 @@ void init_common(py::module m) {
cn.to_string_physical().c_str(),
cn.to_string_logical().c_str());
})
.def("__hash__", [](CompNode cn){ return mgb::hash(cn); })
.def_static("_sync_all", &CompNode::sync_all) .def_static("_sync_all", &CompNode::sync_all)
.def(py::self == py::self) .def(py::self == py::self)
.def_static("_get_device_count", &CompNode::get_device_count, .def_static("_get_device_count", &CompNode::get_device_count,
......
@@ -15,6 +15,7 @@
#include "megbrain/common.h"
#include "megbrain/imperative.h"
#include "megbrain/imperative/graph_builder.h"
#include "megbrain/imperative/ops/backward_graph.h" #include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h" #include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/utility.h" #include "megbrain/imperative/ops/utility.h"
@@ -477,4 +478,50 @@ void init_ops(py::module m) {
m.def("set_global_rng_seed", &rng::set_global_rng_seed);
m.def("get_global_rng_seed", &rng::get_global_rng_seed);
m.def("get_rng_handle_compnode", &rng::get_rng_handle_compnode);
struct PySubgraphBuilder {
explicit PySubgraphBuilder(std::string name) : name{name}{}
std::string name;
Subgraph graph;
mgb::SmallVector<bool> output_grad_mask;
Subgraph::var_t next_var = 1;
};
py::class_<PySubgraphBuilder>(m, "SubgraphBuilder")
.def(py::init<std::string>())
.def("input", [](PySubgraphBuilder& self){
auto var = self.next_var++;
self.graph.inputs.push_back(var);
return var;
})
.def("apply", [](PySubgraphBuilder& self, std::shared_ptr<OpDef> op, Subgraph::vars_t inputs, size_t nr_outputs){
Subgraph::vars_t outputs;
for (size_t i = 0; i < nr_outputs; ++i) {
outputs.push_back(self.next_var++);
}
self.graph.exprs.push_back({op, inputs, outputs});
return outputs;
})
.def("apply_const", [](PySubgraphBuilder& self, py::object value, mgb::DType dtype, mgb::CompNode cn){
auto var = self.next_var++;
mgb::HostTensorND hvalue(cn);
npy::np2tensor(value.cast<py::array>().ptr(), npy::Meth::copy_into(&hvalue), dtype);
self.graph.constants.push_back({var, Tensor::make(hvalue)});
return var;
})
.def("outputs", [](PySubgraphBuilder& self, Subgraph::vars_t outputs){
self.graph.outputs = outputs;
self.output_grad_mask.resize(outputs.size(), true);
})
.def("outputs_has_grad", [](PySubgraphBuilder& self, mgb::SmallVector<bool> outputs_has_grad){
            mgb_assert(self.graph.outputs.size() == outputs_has_grad.size());
self.output_grad_mask = outputs_has_grad;
})
.def("get", [](PySubgraphBuilder& self){
return (std::shared_ptr<OpDef>)SubgraphOp::make(self.name, self.graph, self.output_grad_mask);
})
.def("compile", [](PySubgraphBuilder& self, int gopt_level){
auto op = SubgraphOp::make(self.name, self.graph, self.output_grad_mask);
return (std::shared_ptr<OpDef>)CompiledOp::make(op, gopt_level);
});
}
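
For completeness, a hedged Python-side sketch of the builder protocol this binding exposes (the subgraph decorator earlier in this commit wraps exactly this call sequence; the op and tensor below are illustrative):

import numpy as np
import megengine
from megengine.core._imperative_rt.core2 import apply
from megengine.core._imperative_rt.ops import SubgraphBuilder
from megengine.core.ops import builtin

t = megengine.Tensor(np.ones(3, dtype="float32"))

builder = SubgraphBuilder("MulByTwo")
x = builder.input()                               # declare one input var
c2 = builder.apply_const(2.0, t.dtype, t.device)  # constant with given dtype/device
(y,) = builder.apply(builtin.Elemwise(mode="mul"), [x, c2], 1)
builder.outputs([y])
builder.outputs_has_grad([True])
op = builder.get()       # or builder.compile(2) to attach graph optimization

(out,) = apply(op, t)    # out == Tensor([2., 2., 2.])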