提交 31035391 编写于 作者: M Megvii Engine Team

feat(mge/tools): set network_visualize's log_path as optional flag

GitOrigin-RevId: a74bdc08ba86d431a1a0cc9d1fc665d897ecd16f
上级 eeeddbbc
......@@ -40,30 +40,31 @@ def visualize(
:param log_params: whether print and record params size.
:param log_flops: whether print and record op flops.
"""
try:
from tensorboard.compat.proto.attr_value_pb2 import AttrValue
from tensorboard.compat.proto.config_pb2 import RunMetadata
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.proto.node_def_pb2 import NodeDef
from tensorboard.compat.proto.step_stats_pb2 import (
AllocatorMemoryUsed,
DeviceStepStats,
NodeExecStats,
StepStats,
)
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
from tensorboard.compat.proto.versions_pb2 import VersionDef
from tensorboardX import SummaryWriter
except ImportError:
logger.error(
"TensorBoard and TensorboardX are required for visualize.", exc_info=True
)
return
if log_path:
try:
from tensorboard.compat.proto.attr_value_pb2 import AttrValue
from tensorboard.compat.proto.config_pb2 import RunMetadata
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.proto.node_def_pb2 import NodeDef
from tensorboard.compat.proto.step_stats_pb2 import (
AllocatorMemoryUsed,
DeviceStepStats,
NodeExecStats,
StepStats,
)
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
from tensorboard.compat.proto.versions_pb2 import VersionDef
from tensorboardX import SummaryWriter
except ImportError:
logger.error(
"TensorBoard and TensorboardX are required for visualize.",
exc_info=True,
)
return
# FIXME: remove this after resolving "span dist too large" warning
old_level = set_mgb_log_level(logging.ERROR)
graph = Network.load(model_path)
writer = SummaryWriter(log_path)
def process_name(name):
    """Return ``name`` as a "/"-separated UTF-8 byte string.

    Every "." in ``name`` is replaced by "/" and the result is encoded
    with UTF-8.

    :param name: dotted name string, e.g. a node name.
    :return: the converted name as ``bytes``.
    """
    return name.replace(".", "/").encode(encoding="utf-8")
......@@ -84,21 +85,27 @@ def visualize(
node_oup = node.outputs[0]
inp_list = [process_name(var.owner.name) for var in node.inputs]
attr = {
"_output_shapes": AttrValue(
list=AttrValue.ListValue(
shape=[
TensorShapeProto(
dim=[TensorShapeProto.Dim(size=d) for d in node_oup.shape]
)
]
)
),
}
if log_path:
attr = {
"_output_shapes": AttrValue(
list=AttrValue.ListValue(
shape=[
TensorShapeProto(
dim=[
TensorShapeProto.Dim(size=d) for d in node_oup.shape
]
)
]
)
),
}
if hasattr(node, "calc_flops"):
flops_num = node.calc_flops()
# add op flops attr
attr["flops"] = AttrValue(s=sizeof_fmt(flops_num).encode(encoding="utf-8"))
if log_path:
attr["flops"] = AttrValue(
s=sizeof_fmt(flops_num).encode(encoding="utf-8")
)
flops_list.append(
dict(
name=node.name,
......@@ -114,9 +121,10 @@ def visualize(
# TODO: consider other quantize dtypes
param_bytes = 1 if is_quantize(node_oup.dtype) else 4
# add tensor size attr
attr["size"] = AttrValue(
s=sizeof_fmt(param_dim * param_bytes).encode(encoding="utf-8")
)
if log_path:
attr["size"] = AttrValue(
s=sizeof_fmt(param_dim * param_bytes).encode(encoding="utf-8")
)
params_list.append(
dict(
name=node.name,
......@@ -132,25 +140,33 @@ def visualize(
# FIXME(MGE-2165): nodes outside network module may lead to unknown display bug
if not len(node.name.split(".")) > 2 and not node in graph.input_vars:
continue
node_list.append(
NodeDef(
name=process_name(node.name), op=node.type, input=inp_list, attr=attr,
if log_path:
node_list.append(
NodeDef(
name=process_name(node.name),
op=node.type,
input=inp_list,
attr=attr,
)
)
)
total_flops, total_params = 0, 0
total_flops, total_params = None, None
if log_params:
total_params = print_params_stats(params_list, bar_length_max)
if log_flops:
total_flops = print_flops_stats(flops_list, bar_length_max)
graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
if log_path:
graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
device = "/device:CPU:0"
stepstats = RunMetadata(
step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
)
writer._get_file_writer().add_graph((graph_def, stepstats))
device = "/device:CPU:0"
stepstats = RunMetadata(
step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
)
writer = SummaryWriter(log_path)
writer._get_file_writer().add_graph((graph_def, stepstats))
# summary
# FIXME: remove this after resolving "span dist too large" warning
_imperative_rt_logger.set_log_level(old_level)
......@@ -164,7 +180,7 @@ def main():
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("model_path", help="dumped model path.")
parser.add_argument("log_path", help="tensorboard log path.")
parser.add_argument("--log_path", help="tensorboard log path.")
parser.add_argument(
"--bar_length_max",
type=int,
......@@ -179,7 +195,20 @@ def main():
parser.add_argument(
"--log_flops", action="store_true", help="whether print and record op flops.",
)
visualize(**vars(parser.parse_args()))
parser.add_argument(
"--all",
action="store_true",
help="whether print all stats. Tensorboard logs will be placed in './log' if not specified.",
)
args = parser.parse_args()
if args.all:
args.log_params = True
args.log_flops = True
if not args.log_path:
args.log_path = "./log"
kwargs = vars(args)
kwargs.pop("all")
visualize(**kwargs)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册