提交 31035391 编写于 作者: M Megvii Engine Team

feat(mge/tools): make network_visualize's log_path an optional flag

GitOrigin-RevId: a74bdc08ba86d431a1a0cc9d1fc665d897ecd16f
上级 eeeddbbc
...@@ -40,30 +40,31 @@ def visualize( ...@@ -40,30 +40,31 @@ def visualize(
:param log_params: whether print and record params size. :param log_params: whether print and record params size.
:param log_flops: whether print and record op flops. :param log_flops: whether print and record op flops.
""" """
try: if log_path:
from tensorboard.compat.proto.attr_value_pb2 import AttrValue try:
from tensorboard.compat.proto.config_pb2 import RunMetadata from tensorboard.compat.proto.attr_value_pb2 import AttrValue
from tensorboard.compat.proto.graph_pb2 import GraphDef from tensorboard.compat.proto.config_pb2 import RunMetadata
from tensorboard.compat.proto.node_def_pb2 import NodeDef from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.proto.step_stats_pb2 import ( from tensorboard.compat.proto.node_def_pb2 import NodeDef
AllocatorMemoryUsed, from tensorboard.compat.proto.step_stats_pb2 import (
DeviceStepStats, AllocatorMemoryUsed,
NodeExecStats, DeviceStepStats,
StepStats, NodeExecStats,
) StepStats,
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto )
from tensorboard.compat.proto.versions_pb2 import VersionDef from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
from tensorboardX import SummaryWriter from tensorboard.compat.proto.versions_pb2 import VersionDef
except ImportError: from tensorboardX import SummaryWriter
logger.error( except ImportError:
"TensorBoard and TensorboardX are required for visualize.", exc_info=True logger.error(
) "TensorBoard and TensorboardX are required for visualize.",
return exc_info=True,
)
return
# FIXME: remove this after resolving "span dist too large" warning # FIXME: remove this after resolving "span dist too large" warning
old_level = set_mgb_log_level(logging.ERROR) old_level = set_mgb_log_level(logging.ERROR)
graph = Network.load(model_path) graph = Network.load(model_path)
writer = SummaryWriter(log_path)
def process_name(name): def process_name(name):
return name.replace(".", "/").encode(encoding="utf-8") return name.replace(".", "/").encode(encoding="utf-8")
...@@ -84,21 +85,27 @@ def visualize( ...@@ -84,21 +85,27 @@ def visualize(
node_oup = node.outputs[0] node_oup = node.outputs[0]
inp_list = [process_name(var.owner.name) for var in node.inputs] inp_list = [process_name(var.owner.name) for var in node.inputs]
attr = { if log_path:
"_output_shapes": AttrValue( attr = {
list=AttrValue.ListValue( "_output_shapes": AttrValue(
shape=[ list=AttrValue.ListValue(
TensorShapeProto( shape=[
dim=[TensorShapeProto.Dim(size=d) for d in node_oup.shape] TensorShapeProto(
) dim=[
] TensorShapeProto.Dim(size=d) for d in node_oup.shape
) ]
), )
} ]
)
),
}
if hasattr(node, "calc_flops"): if hasattr(node, "calc_flops"):
flops_num = node.calc_flops() flops_num = node.calc_flops()
# add op flops attr # add op flops attr
attr["flops"] = AttrValue(s=sizeof_fmt(flops_num).encode(encoding="utf-8")) if log_path:
attr["flops"] = AttrValue(
s=sizeof_fmt(flops_num).encode(encoding="utf-8")
)
flops_list.append( flops_list.append(
dict( dict(
name=node.name, name=node.name,
...@@ -114,9 +121,10 @@ def visualize( ...@@ -114,9 +121,10 @@ def visualize(
# TODO: consider other quantize dtypes # TODO: consider other quantize dtypes
param_bytes = 1 if is_quantize(node_oup.dtype) else 4 param_bytes = 1 if is_quantize(node_oup.dtype) else 4
# add tensor size attr # add tensor size attr
attr["size"] = AttrValue( if log_path:
s=sizeof_fmt(param_dim * param_bytes).encode(encoding="utf-8") attr["size"] = AttrValue(
) s=sizeof_fmt(param_dim * param_bytes).encode(encoding="utf-8")
)
params_list.append( params_list.append(
dict( dict(
name=node.name, name=node.name,
...@@ -132,25 +140,33 @@ def visualize( ...@@ -132,25 +140,33 @@ def visualize(
# FIXME(MGE-2165): nodes outside network module may lead to unknown display bug # FIXME(MGE-2165): nodes outside network module may lead to unknown display bug
if not len(node.name.split(".")) > 2 and not node in graph.input_vars: if not len(node.name.split(".")) > 2 and not node in graph.input_vars:
continue continue
node_list.append( if log_path:
NodeDef( node_list.append(
name=process_name(node.name), op=node.type, input=inp_list, attr=attr, NodeDef(
name=process_name(node.name),
op=node.type,
input=inp_list,
attr=attr,
)
) )
)
total_flops, total_params = 0, 0 total_flops, total_params = None, None
if log_params: if log_params:
total_params = print_params_stats(params_list, bar_length_max) total_params = print_params_stats(params_list, bar_length_max)
if log_flops: if log_flops:
total_flops = print_flops_stats(flops_list, bar_length_max) total_flops = print_flops_stats(flops_list, bar_length_max)
graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22)) if log_path:
graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
device = "/device:CPU:0" device = "/device:CPU:0"
stepstats = RunMetadata( stepstats = RunMetadata(
step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)]) step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
) )
writer._get_file_writer().add_graph((graph_def, stepstats)) writer = SummaryWriter(log_path)
writer._get_file_writer().add_graph((graph_def, stepstats))
# summary
# FIXME: remove this after resolving "span dist too large" warning # FIXME: remove this after resolving "span dist too large" warning
_imperative_rt_logger.set_log_level(old_level) _imperative_rt_logger.set_log_level(old_level)
...@@ -164,7 +180,7 @@ def main(): ...@@ -164,7 +180,7 @@ def main():
formatter_class=argparse.ArgumentDefaultsHelpFormatter, formatter_class=argparse.ArgumentDefaultsHelpFormatter,
) )
parser.add_argument("model_path", help="dumped model path.") parser.add_argument("model_path", help="dumped model path.")
parser.add_argument("log_path", help="tensorboard log path.") parser.add_argument("--log_path", help="tensorboard log path.")
parser.add_argument( parser.add_argument(
"--bar_length_max", "--bar_length_max",
type=int, type=int,
...@@ -179,7 +195,20 @@ def main(): ...@@ -179,7 +195,20 @@ def main():
parser.add_argument( parser.add_argument(
"--log_flops", action="store_true", help="whether print and record op flops.", "--log_flops", action="store_true", help="whether print and record op flops.",
) )
visualize(**vars(parser.parse_args())) parser.add_argument(
"--all",
action="store_true",
help="whether print all stats. Tensorboard logs will be placed in './log' if not specified.",
)
args = parser.parse_args()
if args.all:
args.log_params = True
args.log_flops = True
if not args.log_path:
args.log_path = "./log"
kwargs = vars(args)
kwargs.pop("all")
visualize(**kwargs)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册