提交 2df84754 编写于 作者: M Megvii Engine Team

fix(mge/tools): fix node display bug in tensorboard

GitOrigin-RevId: c997d6cccbfbdeaf2d24d6115650b1fee4bc0763
上级 13481fd2
......@@ -7,8 +7,8 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse
import json
import logging
import re
import numpy as np
......@@ -71,7 +71,10 @@ def visualize(
graph = Network.load(model_path)
def process_name(name):
return name.replace(".", "/").encode(encoding="utf-8")
# nodes that start with point or contain float const will lead to display bug
if not re.match(r"^[+-]?\d*\.\d*", name):
name = name.replace(".", "/")
return name.encode(encoding="utf-8")
summary = [["item", "value"]]
node_list = []
......@@ -128,10 +131,6 @@ def visualize(
param_stats["name"] = node.name
params_list.append(param_stats)
# FIXME(MGE-2165): nodes outside network module may lead to unknown display bug
if not len(node.name.split(".")) > 2 and not node in graph.input_vars:
continue
if log_path:
node_list.append(
NodeDef(
......
......@@ -230,6 +230,7 @@ def get_param_stats(param: np.ndarray):
param_dim = np.prod(param.shape)
param_size = param_dim * nbits // 8
return {
"dtype": param.dtype,
"shape": shape,
"mean": "{:.3g}".format(param.mean()),
"std": "{:.3g}".format(param.std()),
......@@ -260,6 +261,7 @@ def print_params_stats(params, bar_length_max=20):
header = [
"name",
"dtype",
"shape",
"mean",
"std",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册