未验证 提交 2e56337e 编写于 作者: B Bo Zhou 提交者: GitHub

support paddle 1.8.2 (#317)

上级 524ba6f6
......@@ -93,7 +93,7 @@ def main():
act_dim = env.action_space.n
model = AtariModel(act_dim, args.algo)
- if args.algo == 'Double':
+ if args.algo == 'DDQN':
algorithm = parl.algorithms.DDQN(model, act_dim=act_dim, gamma=GAMMA)
elif args.algo in ['DQN', 'Dueling']:
algorithm = parl.algorithms.DQN(model, act_dim=act_dim, gamma=GAMMA)
......
......@@ -75,7 +75,7 @@ class DDQN(Algorithm):
greedy_action = layers.argmax(next_action_value, axis=-1)
# calculate the target q value with target network
- batch_size = layers.cast(layers.shape(greedy_action)[0], dtype='int')
+ batch_size = layers.cast(layers.shape(greedy_action)[0], dtype='int32')
range_tmp = layers.range(
start=0, end=batch_size, step=1, dtype='int64') * self.act_dim
a_indices = range_tmp + greedy_action
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册