未验证 提交 2e5d9cb0 编写于 作者: H Hongsheng Zeng 提交者: GitHub

add unittest of get_weights set_weights with create_parameter (#262)

上级 4a449e3c
......@@ -690,6 +690,43 @@ class ModelBaseTest(unittest.TestCase):
self.executor.run(
pred_program, feed={'obs': x}, fetch_list=[model_output])
def test_get_weights_set_weights_with_create_parameter(self):
model1 = TestModel2()
model2 = TestModel2()
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[100], dtype='float32')
model1_output = model1.predict(obs)
model2_output = model2.predict(obs)
self.executor.run(fluid.default_startup_program())
N = 10
random_obs = np.random.random(size=(N, 100)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
outputs = self.executor.run(
pred_program,
feed={'obs': x},
fetch_list=[model1_output, model2_output])
self.assertNotEqual(
np.sum(outputs[0].flatten()), np.sum(outputs[1].flatten()))
# pass parameters of self.model to model2
params = model1.get_weights()
model2.set_weights(params)
random_obs = np.random.random(size=(N, 100)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
outputs = self.executor.run(
pred_program,
feed={'obs': x},
fetch_list=[model1_output, model2_output])
self.assertEqual(
np.sum(outputs[0].flatten()), np.sum(outputs[1].flatten()))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册