提交 703306ce 编写于 作者: L liuyuhui

fix Collective multi dataset_name

上级 70b6bbe8
......@@ -363,6 +363,7 @@ class CollectiveNetwork(NetworkBase):
def build_network(self, context):
context["model"] = {}
if len(context["env"]["phase"]) > 1:
print("CollectiveNetwork phase:{}".format(context["env"]["phase"]))
warnings.warn(
"Cluster Train Only Support One Phase.",
category=UserWarning,
......@@ -407,16 +408,17 @@ class CollectiveNetwork(NetworkBase):
context["model"][model_dict["name"]]["compiled_program"] = None
context["dataset"] = {}
for dataset in context["env"]["dataset"]:
type = envs.get_global_env("dataset." + dataset["name"] + ".type")
for phase in context["env"]["phase"]:
type = envs.get_global_env("dataset." + phase["dataset_name"] +
".type")
if type == "QueueDataset":
raise ValueError(
"Collective don't support QueueDataset training, please use DataLoader."
)
dataset_class = QueueDataset(context)
context["dataset"][dataset[
"name"]] = dataset_class.create_dataset(dataset["name"],
context)
context["dataset"][phase[
"dataset_name"]] = dataset_class.create_dataset(
phase["dataset_name"], context)
context["status"] = "startup_pass"
def _build_strategy(self, context):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册