diff --git a/core/trainers/framework/startup.py b/core/trainers/framework/startup.py index 2687dcdd0f3ce5ac39a17ec4cdf8fb4cfe031513..362592e6de64a4bbfecb6868726b4a733edf4e14 100644 --- a/core/trainers/framework/startup.py +++ b/core/trainers/framework/startup.py @@ -38,11 +38,9 @@ class StartupBase(object): if dirname is None or dirname == "": return print("going to load ", dirname) - if is_fleet: - context["fleet"].load_persistables(context["exe"], dirname) - else: - fluid.io.load_persistables( - context["exe"], dirname, main_program=main_program) + fluid.io.load_persistables( + context["exe"], dirname, main_program=main_program) + print("load from {} success".format(dirname)) class SingleStartup(StartupBase): @@ -81,7 +79,6 @@ class PSStartup(StartupBase): "startup_program"] with fluid.program_guard(train_prog, startup_prog): context["exe"].run(startup_prog) - self.load(context, True) context["status"] = "train_pass" @@ -99,7 +96,7 @@ class CollectiveStartup(StartupBase): "startup_program"] with fluid.program_guard(train_prog, startup_prog): context["exe"].run(startup_prog) - self.load(context, True) + self.load(context, main_program=train_prog) context["status"] = "train_pass"