提交 5054efed 编写于 作者: L LI Yunxiang 提交者: Bo Zhou

fix compiled_program restore (#192)

上级 cb4b3852
......@@ -212,6 +212,8 @@ class Agent(AgentBase):
if program is None:
program = self.learn_program
if type(program) is fluid.compiler.CompiledProgram:
program = program._init_program
dirname = '/'.join(save_path.split('/')[:-1])
filename = save_path.split('/')[-1]
fluid.io.load_params(
......
......@@ -40,7 +40,10 @@ def compile(program, loss=None):
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
return fluid.compiler.CompiledProgram(program).with_data_parallel(
loss_name=loss_name,
exec_strategy=exec_strategy,
build_strategy=build_strategy)
compiled_program = fluid.compiler.CompiledProgram(
program).with_data_parallel(
loss_name=loss_name,
exec_strategy=exec_strategy,
build_strategy=build_strategy)
compiled_program._init_program = program
return compiled_program
......@@ -116,6 +116,22 @@ class AgentBaseTest(unittest.TestCase):
current_output = another_agent.predict(obs)
np.testing.assert_equal(current_output, previous_output)
def test_compiled_restore(self):
agent = TestAgent(self.algorithm)
agent.learn_program = parl.compile(agent.learn_program)
obs = np.random.random([3, 10]).astype('float32')
previous_output = agent.predict(obs)
save_path1 = './model.ckpt'
agent.save(save_path1)
agent.restore(save_path1)
# a new agent instance
another_agent = TestAgent(self.algorithm)
another_agent.learn_program = parl.compile(another_agent.learn_program)
another_agent.restore(save_path1)
current_output = another_agent.predict(obs)
np.testing.assert_equal(current_output, previous_output)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册