Unverified commit 189ac02b authored by 123malin, committed by GitHub

test=develop, add distributed tools (#22623) (#22637)

Parent 77428e8f
@@ -15,6 +15,7 @@
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.contrib.utils import HDFSClient
import os
import time
def check_all_trainers_ready(ready_path, epoch):
@@ -23,15 +23,19 @@ import sys
import time
import paddle.fluid as fluid
from paddle.fluid.log_helper import get_logger
from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet as fleet_pslib
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet as fleet_transpiler
from . import hdfs
from .hdfs import *
from . import utils
__all__ = ["FleetUtil"]
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
fleet = fleet_pslib
class FleetUtil(object):
"""
@@ -46,6 +50,16 @@ class FleetUtil(object):
"""
def __init__(self, mode="pslib"):
global fleet
if mode == "pslib":
fleet = fleet_pslib
elif mode == "transpiler":
fleet = fleet_transpiler
else:
raise ValueError(
"Please choose one mode from [\"pslib\", \"transpiler\"]")
def rank0_print(self, s):
"""
Worker of rank 0 prints the given log string.
@@ -1535,3 +1549,69 @@ class FleetUtil(object):
(print_prefix, auc, bucket_error, mae, rmse,
actual_ctr, predicted_ctr, copc, mean_predict_qvalue,
total_ins_num))
def program_type_trans(self, prog_dir, prog_fn, is_text):
"""convert a program file between text and binary format"""
return utils.program_type_trans(prog_dir, prog_fn, is_text)
def draw_from_program_file(self, model_filename, is_text, output_dir,
output_filename):
"""draw program from file"""
program = utils.load_program(model_filename, is_text)
utils.graphviz(program.global_block(), output_dir, output_filename)
def draw_from_program(self, program, output_dir, output_name):
"""draw Program"""
utils.graphviz(program.global_block(), output_dir, output_name)
def check_two_programs(self, config):
"""check that persistable vars in a pruned program match the train program"""
train_prog = utils.load_program(config.train_prog_path,
config.is_text_train_program)
pruned_prog = utils.load_program(config.pruned_prog_path,
config.is_text_pruned_program)
if config.draw:
pruned_dir = os.path.dirname(config.pruned_prog_path)
self.draw_from_program(pruned_prog, pruned_dir,
config.draw_out_name)
res = utils.check_pruned_program_vars(train_prog, pruned_prog)
if res:
_logger.info("check_programs succeed.")
else:
_logger.info(
"check_programs failed. pruned program and train program not match!"
)
return res
def check_vars_and_dump(self, config):
"""check saved vars in a dumped model and try to dump fetch targets"""
_logger.info("start check_vars_and_dump.")
results = utils.check_saved_vars_try_dump(
config.dump_model_dir, config.dump_program_filename,
config.is_text_dump_program, config.feed_config,
config.fetch_config, config.batch_size, config.save_params_filename)
_logger.info("check_vars_and_dump succeed.")
return results
def parse_program_proto(self, prog_path, is_text, output_dir):
"""
Parse program.proto into a more readable format.
This function will generate three files:
output_dir/vars_all.log,
output_dir/vars_persistable.log,
output_dir/ops.log.
Args:
prog_path(str): proto file path to be parsed.
is_text(bool): whether the proto file is human-readable text (True) or binary (False).
output_dir(str): output dir.
Examples:
.. code-block:: python
from paddle.fluid.incubate.fleet.utils.fleet_util import FleetUtil
fleet_util = FleetUtil()
program_path = "./program.pbtxt"
is_text = True
output_dir = "/tmp/"
fleet_util.parse_program_proto(program_path, is_text, output_dir)
"""
program = utils.load_program(prog_path, is_text)
utils.parse_program(program, output_dir)
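# End-to-end sketch of the new FleetUtil helpers (hypothetical paths; assumes
# a text-format program dump exists under ./model_dir):
#
#     from paddle.fluid.incubate.fleet.utils.fleet_util import FleetUtil
#     fleet_util = FleetUtil()
#     # text -> binary; returns the output file name "program.pbtxt.bin"
#     bin_fn = fleet_util.program_type_trans("./model_dir", "program.pbtxt", True)
#     # render the program's global block to ./model_dir/debug.dot and debug.pdf
#     fleet_util.draw_from_program_file("./model_dir/program.pbtxt", True,
#                                       "./model_dir", "debug")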
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function, absolute_import
import os
import sys
import logging
import subprocess
import numpy as np
from collections import OrderedDict
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.log_helper import get_logger
from google.protobuf import text_format
from paddle.fluid import debugger
from paddle.fluid.framework import Program
from paddle.fluid.proto import framework_pb2
__all__ = [
"load_program", "save_program", "program_type_trans",
"check_saved_vars_try_dump", "parse_program", "check_pruned_program_vars",
"graphviz"
]
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
persistable_vars_out_fn = "vars_persistable.log"
all_vars_out_fn = "vars_all.log"
ops_out_fn = "ops.log"
feed_fetch_type_list = [
core.VarDesc.VarType.FEED_MINIBATCH, core.VarDesc.VarType.FETCH_LIST
]
not_expected_op_types = ["lookup_table"]
def load_program(model_filename, is_text=False):
if is_text:
return load_program_text(model_filename)
return load_program_binary(model_filename)
def load_program_binary(model_filename):
"""load program from binary string file"""
with open(model_filename, "rb") as f:
program_desc_str = f.read()
return Program.parse_from_string(program_desc_str)
def load_program_text(model_filename):
"""load program from human-readable text file"""
with open(model_filename, "r") as f:
program_desc_text = f.read()
prog_desc = framework_pb2.ProgramDesc()
text_format.Merge(program_desc_text, prog_desc)
return Program.parse_from_string(prog_desc.SerializeToString())
def save_program(program, model_filename='__model__', is_text=False):
if is_text:
with open(model_filename, "w") as f:
f.write(str(program))
else:
with open(model_filename, "wb") as f:
f.write(program.desc.serialize_to_string())
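# Round-trip sketch for load_program/save_program (hypothetical file names):
#
#     prog = load_program("./__model__.pbtxt", is_text=True)  # text -> Program
#     save_program(prog, "./__model__.bin", is_text=False)    # write binary
#     prog2 = load_program("./__model__.bin")                  # binary -> Program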
def check_pruned_program_vars(train_prog, pruned_prog):
is_match = True
pruned_vars = [(v.name, v) for v in pruned_prog.list_vars()
if fluid.io.is_persistable(v)]
pruned_vars = OrderedDict(pruned_vars)
pruned_vars_name = [name for name in pruned_vars]
logger.info("persistable vars in pruned program: {}".format(
pruned_vars_name))
for var_name in pruned_vars:
var = pruned_vars[var_name]
# feed and fetch ops are added to the pruned program during pruning,
# so they need not exist in the train program
if var.type in feed_fetch_type_list:
continue
try:
train_prog_var = train_prog.global_block().var(var_name)
except ValueError as e:
logger.error(
"not find variable '%s' in train program. please check pruning."
% var_name)
logger.error(e)
continue
if var.shape != train_prog_var.shape or var.dtype != train_prog_var.dtype:
logger.error(
"variable: {} not match. in pruned program shape: {} dtype:{}, in train program shape: {} dtype: {}".
format(var_name, var.shape, var.dtype, train_prog_var.shape,
train_prog_var.dtype))
is_match = False
return is_match
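# Usage sketch (hypothetical paths): True only when every persistable var of
# the pruned program matches shape and dtype in the train program:
#
#     train_prog = load_program("./train_program.pbtxt", is_text=True)
#     pruned_prog = load_program("./pruned_program.pbtxt", is_text=True)
#     assert check_pruned_program_vars(train_prog, pruned_prog)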
def graphviz(block, output_dir="", filename='debug'):
dot_path = os.path.join(output_dir, filename + '.dot')
pdf_path = os.path.join(output_dir, filename + '.pdf')
debugger.draw_block_graphviz(block, path=dot_path)
cmd = ["dot", "-Tpdf", dot_path, "-o", pdf_path]
p = subprocess.Popen(
cmd,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
p.wait()
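# graphviz() shells out to the Graphviz `dot` binary, which must be on PATH.
# Sketch (hypothetical path):
#
#     prog = load_program("./program.pbtxt", is_text=True)
#     graphviz(prog.global_block(), output_dir="/tmp", filename="debug")
#     # writes /tmp/debug.dot and /tmp/debug.pdf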
def program_type_trans(prog_dir, prog_fn, is_text):
prog = load_program(os.path.join(prog_dir, prog_fn), is_text)
prog_out_fn = prog_fn + ".bin" if is_text else prog_fn + ".pbtxt"
save_program(prog, os.path.join(prog_dir, prog_out_fn), not is_text)
return prog_out_fn
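# Naming sketch: a text input gains ".bin", a binary input gains ".pbtxt",
# and the converted file is written next to the input:
#
#     out_fn = program_type_trans("./model_dir", "program.pbtxt", is_text=True)
#     # out_fn == "program.pbtxt.bin"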
def append_save_op(block, var, path):
block.append_op(
type='save', inputs={'X': [var]}, outputs={},
attrs={'file_path': path})
def append_load_op(block, var, path):
block.append_op(
type='load',
inputs={},
outputs={'Out': [var]},
attrs={'file_path': path})
def save_var(np_array, var_name, shape_list, dtype, save_path):
program = fluid.Program()
place = fluid.CPUPlace()
exe = fluid.Executor(place)
with fluid.program_guard(program):
d0_data = fluid.layers.data(var_name, shape=shape_list, dtype=dtype)
append_save_op(program.global_block(), d0_data, save_path)
exe.run(feed={var_name: np_array}, fetch_list=[])
def load_var(var_name, shape_list, dtype, save_path):
program = fluid.Program()
place = fluid.CPUPlace()
exe = fluid.Executor(place)
with fluid.program_guard(program):
d0_data = fluid.layers.data(var_name, shape=shape_list, dtype=dtype)
append_load_op(program.global_block(), d0_data, save_path)
outs = exe.run(feed={}, fetch_list=[d0_data])
return outs
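# save_var/load_var round-trip sketch: persist one numpy array through the
# save/load ops. shape_list excludes the batch dim (fluid.layers.data
# prepends -1), so feed an array of shape [batch] + shape_list:
#
#     arr = np.random.random((1, 8)).astype("float32")
#     save_var(arr, "x", [8], "float32", "/tmp/x.tensor")
#     out = load_var("x", [8], "float32", "/tmp/x.tensor")[0]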
def reader(batch_size, fn, dim):
data = []
if isinstance(dim, (list, tuple)):
shape = list(dim)
_temp = 1
for x in dim:
_temp = _temp * x
dim = _temp
else:
shape = [dim]
shape = [batch_size] + shape
dim = dim * batch_size
for line in open(fn, 'r'):
fields = line.strip().split(' ')
fields = [float(d) for d in fields]
while len(fields) >= dim:
tmp = fields[:dim]
fields = fields[dim:]
data.append(np.array(tmp).reshape(shape))
return data
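# reader() expects space-separated floats in `fn` and packs every
# batch_size * dim consecutive values into one array of shape
# [batch_size] + shape. Sketch for batch_size=2, dim=3 (hypothetical file):
#
#     # data.txt holds: 0.1 0.2 0.3 0.4 0.5 0.6
#     batches = reader(2, "data.txt", 3)  # -> [array of shape (2, 3)]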
def feed_gen(batch_size, feeded_vars_dims, feeded_vars_filelist):
batch_feed = []
for i, fn in enumerate(feeded_vars_filelist):
batch_feed.append(reader(batch_size, fn, feeded_vars_dims[i]))
return batch_feed
def try_load_model_vars(dump_dir, dump_prog_fn, is_text_dump_program,
batch_size, feed_config, fetch_config, save_filename,
saved_params):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.core.Scope()
with fluid.scope_guard(scope):
if is_text_dump_program:
dump_prog_fn = program_type_trans(dump_dir, dump_prog_fn,
is_text_dump_program)
inference_program, feed_target_names, fetch_targets = \
fluid.io.load_inference_model(dump_dir, exe, model_filename=dump_prog_fn,
params_filename=save_filename)
# check program vars and saved vars shape
orig_para_shape = {
each_var.name: tuple(each_var.desc.shape())
for each_var in saved_params
}
for each_var in saved_params:
var_temp = fluid.global_scope().find_var(each_var.name)
assert var_temp is not None, "cannot find var: " + each_var.name
new_shape = (np.array(var_temp.get_tensor())).shape
assert each_var.name in orig_para_shape, each_var.name + " MUST be in var list"
orig_shape = orig_para_shape.get(each_var.name)
if new_shape != orig_shape:
raise RuntimeError(
"Shape not matching: the Program requires a parameter with a shape of ({}), "
"while the loaded parameter (namely [ {} ]) has a shape of ({}).".
format(orig_shape, each_var.name, new_shape))
# check feed/fetch vars in program and config
fetch_targets_names = [v.name for v in fetch_targets]
if not feed_target_names:
logger.warning("no feed targets in program.")
if not fetch_targets_names:
logger.warning("no fetch targets in program.")
fetch_list = fetch_targets
feed_name_list = feed_target_names
if feed_config.feeded_vars_names is not None and feed_target_names != feed_config.feeded_vars_names:
logger.warning(
"feed vars in program and config are diff: feed in program: {}. feed in config {}.".
format(feed_target_names, feed_config.feeded_vars_names))
feed_name_list = feed_config.feeded_vars_names
# remove feed op in inference_program. new feed op will be added in exe.run
global_block = inference_program.global_block()
need_to_remove_op_index = []
for i, op in enumerate(global_block.ops):
op.desc.set_is_target(False)
if op.type == "feed": # only remove feed op here
need_to_remove_op_index.append(i)
for index in need_to_remove_op_index[::-1]:
global_block._remove_op(index)
if fetch_config.fetch_vars_names is not None and fetch_targets_names != fetch_config.fetch_vars_names:
logger.warning(
"fetch vars in program and config are diff: fetch in program: {}. fetch in config {}.".
format(fetch_targets_names, fetch_config.fetch_vars_names))
fetch_list = [
inference_program.global_block().var(i)
for i in fetch_config.fetch_vars_names
]
# remove fetch op in inference_program. new fetch op will be added in exe.run
global_block = inference_program.global_block()
need_to_remove_op_index = []
for i, op in enumerate(global_block.ops):
op.desc.set_is_target(False)
if op.type == "fetch": # only remove fetch op here
need_to_remove_op_index.append(i)
for index in need_to_remove_op_index[::-1]:
global_block._remove_op(index)
# fetch results can be returned as numpy only if no fetch var is a LoD tensor
return_numpy = all([v.lod_level == 0 for v in fetch_list])
# try dump fetch_targets
feed_tensors = []
assert len(feed_config.feeded_vars_names) == len(
feed_config.feeded_vars_dims) == len(feed_config.feeded_vars_types)
# check program vars and feed tensor shape in config
for i in range(len(feed_config.feeded_vars_names)):
var = inference_program.global_block().var(
feed_config.feeded_vars_names[i])
if not isinstance(feed_config.feeded_vars_dims[i], (list, tuple)):
tensor_shape = (feed_config.feeded_vars_dims[i], )
else:
tensor_shape = tuple(feed_config.feeded_vars_dims[i])
feed_config.feeded_vars_dims[i] = tensor_shape
var_shape = var.shape[1:]
if tensor_shape != var_shape:
raise RuntimeError(
"feed variable '{}' shape not match. infer program shape: {}. feed tensor shape: {}".
format(feed_config.feeded_vars_names[i], var_shape,
tensor_shape))
if not feed_config.feeded_vars_filelist:
logger.info("generate random feed vars.")
for i in range(len(feed_config.feeded_vars_names)):
var = inference_program.global_block().var(
feed_config.feeded_vars_names[i])
# create a fake feed tensor; vars with lod_level == 1 need create_lod_tensor()
if var.lod_level == 0:
feed_tensors.append(
np.array(
np.random.random(
tuple([batch_size] + list(
feed_config.feeded_vars_dims[i]))),
dtype=feed_config.feeded_vars_types[i]))
elif var.lod_level == 1:
t = np.array(
np.random.random(
tuple([batch_size] + list(
feed_config.feeded_vars_dims[i]))),
dtype=feed_config.feeded_vars_types[i])
feed_tensors.append(
fluid.create_lod_tensor(t, [[1] * batch_size], place))
else:
raise RuntimeError(
"vars with lod_level >= 2 is not supported now in this infer program check tool."
)
results = exe.run(inference_program,
feed={
name: feed_tensors[i]
for i, name in enumerate(feed_name_list)
},
fetch_list=fetch_list,
return_numpy=return_numpy)
else:
logger.info("load feed vars from files: {}.".format(
feed_config.feeded_vars_filelist))
feed_vars = [
inference_program.global_block().var(
feed_config.feeded_vars_names[i])
for i in range(len(feed_config.feeded_vars_names))
]
feeder = fluid.DataFeeder(feed_list=feed_vars, place=place)
batch_feed = feed_gen(batch_size, feed_config.feeded_vars_dims,
feed_config.feeded_vars_filelist)
slots = [batch_feed]
results = exe.run(inference_program,
feed=feeder.feed(slots),
fetch_list=fetch_list,
return_numpy=return_numpy)
for i, v in enumerate(fetch_list):
logger.info("fetch_targets name: %s" % v.name)
logger.info("fetch_targets: {}".format(results[i]))
return results
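# Note: try_load_model_vars feeds either random tensors (shaped from
# feed_config.feeded_vars_dims, typed from feeded_vars_types) or file data
# read via feed_gen(), runs the inference program once, and returns the
# fetched results; feed vars with lod_level >= 2 are rejected.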
def check_not_expected_ops(prog):
op_types_set = set()
for op in prog.global_block().ops:
if op.type in not_expected_op_types and op.type not in op_types_set:
logger.warning(
"find op type '{}' in program, please check if your program is pruned correctly !".
format(op.type))
op_types_set.add(op.type)
def check_saved_vars_try_dump(dump_dir,
dump_prog_fn,
is_text_dump_program,
feed_config,
fetch_config,
batch_size=1,
save_filename=None):
dump_prog = load_program(
os.path.join(dump_dir, dump_prog_fn), is_text_dump_program)
saved_params = [
v for v in dump_prog.list_vars() if fluid.io.is_persistable(v)
]
logger.info("persistable vars in dump program: {}".format(
[v.name for v in saved_params]))
check_not_expected_ops(dump_prog)
return try_load_model_vars(dump_dir, dump_prog_fn, is_text_dump_program,
batch_size, feed_config, fetch_config,
save_filename, saved_params)
def parse_program(program, output_dir):
# persistable vars
output = {}
persistable_vars = [
v for v in program.list_vars() if fluid.io.is_persistable(v)
]
output["persistable_vars"] = [{
'name': str(v.name),
'shape': str(v.shape),
'lod_level': int(v.lod_level),
'dtype': str(v.dtype),
'type': str(v.type)
} for v in persistable_vars]
with open(os.path.join(output_dir, persistable_vars_out_fn), 'w') as f:
f.write("persistable vars:\n")
for var in output["persistable_vars"]:
f.write(str(var))
f.write("\n")
# all vars
all_vars = list(program.list_vars())
output["all_vars"] = [{
'name': str(v.name),
'shape': str(v.shape),
'lod_level': int(v.lod_level),
'dtype': str(v.dtype)
} if v.type not in feed_fetch_type_list else {
'name': str(v.name),
'type': str(v.type)
} for v in all_vars]
with open(os.path.join(output_dir, all_vars_out_fn), 'w') as f:
f.write("all vars:\n")
for var in output["all_vars"]:
f.write(str(var))
f.write("\n")
# ops
ops = program.global_block().ops
output["ops"] = [{
'type': op.type,
'input_arg_names': str(op.input_arg_names),
'output_arg_names': str(op.output_arg_names)
} for op in ops]
with open(os.path.join(output_dir, ops_out_fn), 'w') as f:
f.write("ops:\n")
for op in output["ops"]:
f.write(str(op))
f.write("\n")
@@ -13,14 +13,43 @@
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import unittest
import numpy as np
import tarfile
import tempfile
import os
import sys
from paddle.dataset.common import download, DATA_HOME
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.incubate.fleet.utils.fleet_barrier_util import check_all_trainers_ready
from paddle.fluid.incubate.fleet.utils.fleet_util import FleetUtil
import paddle.fluid.incubate.fleet.utils.utils as utils
class TestFleetUtils(unittest.TestCase):
proto_data_url = "https://fleet.bj.bcebos.com/fleet_util_data.tgz"
proto_data_md5 = "59b7f12fd9dc24b64ae8e4629523a92a"
module_name = "fleet_util_data"
pruned_dir = os.path.join("fleet_util_data", "pruned_model")
train_dir = os.path.join("fleet_util_data", "train_program")
def download_files(self):
path = download(self.proto_data_url, self.module_name,
self.proto_data_md5)
print('data is downloaded at ' + path)
tar = tarfile.open(path)
unzip_folder = tempfile.mkdtemp()
tar.extractall(unzip_folder)
return unzip_folder
def test_fleet_util_init(self):
fleet_util_pslib = FleetUtil()
fleet_util_transpiler = FleetUtil(mode="transpiler")
self.assertRaises(Exception, FleetUtil, "other")
def test_fleet_barrier(self):
role = role_maker.UserDefinedRoleMaker(
current_id=0,
@@ -30,6 +59,165 @@ class TestFleetUtils(unittest.TestCase):
fleet.init(role)
check_all_trainers_ready("/ready_path/", 0)
def test_program_type_trans(self):
data_dir = self.download_files()
program_dir = os.path.join(data_dir, self.pruned_dir)
text_program = "pruned_main_program.pbtxt"
binary_program = "pruned_main_program.bin"
fleet_util = FleetUtil()
text_to_binary = fleet_util.program_type_trans(program_dir,
text_program, True)
binary_to_text = fleet_util.program_type_trans(program_dir,
binary_program, False)
self.assertTrue(
os.path.exists(os.path.join(program_dir, text_to_binary)))
self.assertTrue(
os.path.exists(os.path.join(program_dir, binary_to_text)))
def test_parse_program_proto(self):
data_dir = self.download_files()
parse_program_file_path = os.path.join(
data_dir,
os.path.join(self.pruned_dir, "pruned_main_program.pbtxt"))
is_text_parse_program = True
parse_output_dir = os.path.join(data_dir, self.pruned_dir)
fleet_util = FleetUtil()
fleet_util.parse_program_proto(parse_program_file_path,
is_text_parse_program, parse_output_dir)
ops_log = os.path.join(parse_output_dir, "ops.log")
vars_log = os.path.join(parse_output_dir, "vars_all.log")
vars_persistable = os.path.join(parse_output_dir,
"vars_persistable.log")
self.assertTrue(os.path.exists(ops_log))
self.assertTrue(os.path.exists(vars_log))
self.assertTrue(os.path.exists(vars_persistable))
def test_check_vars_and_dump(self):
data_dir = self.download_files()
class config:
pass
feed_config = config()
feed_config.feeded_vars_names = ['concat_1.tmp_0', 'concat_2.tmp_0']
feed_config.feeded_vars_dims = [682, 1199]
feed_config.feeded_vars_types = [np.float32, np.float32]
feed_config.feeded_vars_filelist = [
os.path.join(data_dir, os.path.join(self.pruned_dir, "concat_1")),
os.path.join(data_dir, os.path.join(self.pruned_dir, "concat_2"))
]
fetch_config = config()
fetch_config.fetch_vars_names = ['similarity_norm.tmp_0']
conf = config()
conf.batch_size = 1
conf.feed_config = feed_config
conf.fetch_config = fetch_config
conf.dump_model_dir = os.path.join(data_dir, self.pruned_dir)
conf.dump_program_filename = "pruned_main_program.pbtxt"
conf.is_text_dump_program = True
conf.save_params_filename = None
fleet_util = FleetUtil()
# test saved var's shape
conf.dump_program_filename = "pruned_main_program.save_var_shape_not_match"
self.assertRaises(Exception, fleet_util.check_vars_and_dump, conf)
# test program.proto without feed_op and fetch_op
conf.dump_program_filename = "pruned_main_program.no_feed_fetch"
results = fleet_util.check_vars_and_dump(conf)
self.assertTrue(len(results) == 1)
np.testing.assert_array_almost_equal(
results[0], np.array(
[[3.0590223e-07]], dtype=np.float32))
# test feed_var's shape
conf.dump_program_filename = "pruned_main_program.feed_var_shape_not_match"
self.assertRaises(Exception, fleet_util.check_vars_and_dump, conf)
# test correct case with feed_vars_filelist
conf.dump_program_filename = "pruned_main_program.pbtxt"
results = fleet_util.check_vars_and_dump(conf)
self.assertTrue(len(results) == 1)
np.testing.assert_array_almost_equal(
results[0], np.array(
[[3.0590223e-07]], dtype=np.float32))
# test correct case without feed_vars_filelist
conf.feed_config.feeded_vars_filelist = None
# test feed var with lod_level >= 2
conf.dump_program_filename = "pruned_main_program.feed_lod2"
self.assertRaises(Exception, fleet_util.check_vars_and_dump, conf)
conf.dump_program_filename = "pruned_main_program.pbtxt"
results = fleet_util.check_vars_and_dump(conf)
self.assertTrue(len(results) == 1)
def test_check_two_programs(self):
data_dir = self.download_files()
class config:
pass
conf = config()
conf.train_prog_path = os.path.join(
data_dir, os.path.join(self.train_dir, "join_main_program.pbtxt"))
conf.is_text_train_program = True
# test not match
conf.pruned_prog_path = os.path.join(
data_dir,
os.path.join(self.pruned_dir,
"pruned_main_program.save_var_shape_not_match"))
conf.is_text_pruned_program = True
conf.draw = False
fleet_util = FleetUtil()
res = fleet_util.check_two_programs(conf)
self.assertFalse(res)
# test match
conf.pruned_prog_path = os.path.join(
data_dir,
os.path.join(self.pruned_dir, "pruned_main_program.pbtxt"))
if sys.platform == 'win32':
conf.draw = False
else:
conf.draw = True
conf.draw_out_name = "pruned_check"
res = fleet_util.check_two_programs(conf)
self.assertTrue(res)
def test_draw_program(self):
if sys.platform == 'win32':
pass
else:
data_dir = self.download_files()
program_path = os.path.join(
data_dir,
os.path.join(self.train_dir, "join_main_program.pbtxt"))
is_text = True
program = utils.load_program(program_path, is_text)
output_dir = os.path.join(data_dir, self.train_dir)
output_filename_1 = "draw_prog_1"
output_filename_2 = "draw_prog_2"
fleet_util = FleetUtil()
fleet_util.draw_from_program_file(program_path, is_text, output_dir,
output_filename_1)
fleet_util.draw_from_program(program, output_dir, output_filename_2)
self.assertTrue(
os.path.exists(
os.path.join(output_dir, output_filename_1 + ".dot")))
self.assertTrue(
os.path.exists(
os.path.join(output_dir, output_filename_1 + ".pdf")))
self.assertTrue(
os.path.exists(
os.path.join(output_dir, output_filename_2 + ".dot")))
self.assertTrue(
os.path.exists(
os.path.join(output_dir, output_filename_2 + ".pdf")))
if __name__ == '__main__':
unittest.main()