Commit bb9b78b4 authored by LI Yunxiang, committed by Bo Zhou

add Double & Dueling DQN (#163)

* add Double & Dueling DQN

* yapf......................

* update

* Update train.py
Parent 4d763f36
......@@ -17,20 +17,23 @@ import paddle.fluid as fluid
import parl
from parl import layers
from parl.utils.scheduler import PiecewiseScheduler, LinearDecayScheduler
IMAGE_SIZE = (84, 84)
CONTEXT_LEN = 4
class AtariAgent(parl.Agent):
def __init__(self, algorithm, act_dim):
def __init__(self, algorithm, act_dim, start_lr, total_step):
super(AtariAgent, self).__init__(algorithm)
assert isinstance(act_dim, int)
self.act_dim = act_dim
self.exploration = 1.1
self.global_step = 0
self.update_target_steps = 10000 // 4
self.lr_scheduler = LinearDecayScheduler(start_lr, total_step)
def build_program(self):
self.pred_program = fluid.Program()
self.learn_program = fluid.Program()
......@@ -53,8 +56,11 @@ class AtariAgent(parl.Agent):
name='next_obs',
shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]],
dtype='float32')
lr = layers.data(
name='lr', shape=[1], dtype='float32', append_batch_size=False)
terminal = layers.data(name='terminal', shape=[], dtype='bool')
self.cost = self.alg.learn(obs, action, reward, next_obs, terminal)
self.cost = self.alg.learn(obs, action, reward, next_obs, terminal,
lr)
def sample(self, obs):
sample = np.random.random()
......@@ -89,6 +95,8 @@ class AtariAgent(parl.Agent):
self.alg.sync_target()
self.global_step += 1
lr = self.lr_scheduler.step(step_num=obs.shape[0])
act = np.expand_dims(act, -1)
reward = np.clip(reward, -1, 1)
feed = {
......@@ -96,7 +104,8 @@ class AtariAgent(parl.Agent):
'act': act.astype('int32'),
'reward': reward,
'next_obs': next_obs.astype('float32'),
'terminal': terminal
'terminal': terminal,
'lr': lr
}
cost = self.fluid_executor.run(
self.learn_program, feed=feed, fetch_list=[self.cost])[0]
......
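The agent now asks a `LinearDecayScheduler` for a fresh learning rate on every `learn` call (`step_num=obs.shape[0]`, i.e. one step per sampled transition) and feeds it into the program through the new `lr` data layer. Below is a minimal sketch of a scheduler with the semantics this code appears to rely on; the class name is illustrative and the exact decay rule of PARL's `LinearDecayScheduler` may differ.

```python
# Minimal sketch of a linearly decaying learning-rate scheduler, assuming the
# behavior the agent expects: decay from start_lr toward 0 over total_step
# steps, advanced by step_num on each call. Not the actual PARL implementation.
class LinearDecay:
    def __init__(self, start_lr, total_step):
        self.start_lr = start_lr
        self.total_step = float(total_step)
        self.cur_step = 0.0

    def step(self, step_num=1):
        self.cur_step = min(self.cur_step + step_num, self.total_step)
        return self.start_lr * (1.0 - self.cur_step / self.total_step)


scheduler = LinearDecay(start_lr=3e-4, total_step=int(1e7))
print(scheduler.step(step_num=64))  # just under 3e-4 after one 64-sample batch
```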
......@@ -18,7 +18,7 @@ from parl import layers
class AtariModel(parl.Model):
def __init__(self, act_dim):
def __init__(self, act_dim, algo='DQN'):
self.act_dim = act_dim
self.conv1 = layers.conv2d(
......@@ -29,7 +29,15 @@ class AtariModel(parl.Model):
num_filters=64, filter_size=4, stride=1, padding=1, act='relu')
self.conv4 = layers.conv2d(
num_filters=64, filter_size=3, stride=1, padding=1, act='relu')
self.fc1 = layers.fc(size=act_dim)
self.algo = algo
if algo == 'Dueling':
self.fc1_adv = layers.fc(size=512, act='relu')
self.fc2_adv = layers.fc(size=act_dim)
self.fc1_val = layers.fc(size=512, act='relu')
self.fc2_val = layers.fc(size=1)
else:
self.fc1 = layers.fc(size=act_dim)
def value(self, obs):
obs = obs / 255.0
......@@ -44,5 +52,11 @@ class AtariModel(parl.Model):
input=out, pool_size=2, pool_stride=2, pool_type='max')
out = self.conv4(out)
out = layers.flatten(out, axis=1)
out = self.fc1(out)
return out
if self.algo == 'Dueling':
As = self.fc2_adv(self.fc1_adv(out))
V = self.fc2_val(self.fc1_val(out))
Q = As + (V - layers.reduce_mean(As, dim=1, keep_dim=True))
else:
Q = self.fc1(out)
return Q
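The dueling head above splits the convolutional features into a state-value stream V(s) and an advantage stream A(s, a), then recombines them as Q(s, a) = V(s) + (A(s, a) − mean_a A(s, a)); subtracting the mean advantage keeps the V/A decomposition identifiable. A small NumPy illustration of that aggregation with toy numbers (outside the fluid graph):

```python
import numpy as np

# Toy batch: 2 states, 4 actions.
V = np.array([[1.0], [0.5]])                      # state-value stream, shape (2, 1)
A = np.array([[0.2, -0.1, 0.4, 0.0],
              [1.0,  0.0, 0.5, 0.5]])             # advantage stream, shape (2, 4)

# Mean-subtracted aggregation, matching `As + (V - reduce_mean(As, dim=1))` above.
Q = A + (V - A.mean(axis=1, keepdims=True))
print(Q[0])  # [1.075, 0.775, 1.275, 0.875]
```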
......@@ -20,10 +20,9 @@ import os
import parl
from atari_agent import AtariAgent
from atari_model import AtariModel
from collections import deque
from datetime import datetime
from replay_memory import ReplayMemory, Experience
from parl.utils import logger
from parl.utils import tensorboard, logger
from tqdm import tqdm
from utils import get_player
......@@ -34,7 +33,7 @@ CONTEXT_LEN = 4
FRAME_SKIP = 4
UPDATE_FREQ = 4
GAMMA = 0.99
LEARNING_RATE = 1e-3 * 0.5
LEARNING_RATE = 3e-4
def run_train_episode(env, agent, rpm):
......@@ -67,7 +66,7 @@ def run_train_episode(env, agent, rpm):
if all_cost:
logger.info('[Train]total_reward: {}, mean_cost: {}'.format(
total_reward, np.mean(all_cost)))
return total_reward, steps
return total_reward, steps, np.mean(all_cost)
def run_evaluate_episode(env, agent):
......@@ -93,27 +92,38 @@ def main():
rpm = ReplayMemory(MEMORY_SIZE, IMAGE_SIZE, CONTEXT_LEN)
act_dim = env.action_space.n
model = AtariModel(act_dim)
algorithm = parl.algorithms.DQN(
model, act_dim=act_dim, gamma=GAMMA, lr=LEARNING_RATE)
agent = AtariAgent(algorithm, act_dim=act_dim)
with tqdm(total=MEMORY_WARMUP_SIZE) as pbar:
model = AtariModel(act_dim, args.algo)
if args.algo == 'Double':
algorithm = parl.algorithms.DDQN(model, act_dim=act_dim, gamma=GAMMA)
elif args.algo in ['DQN', 'Dueling']:
algorithm = parl.algorithms.DQN(model, act_dim=act_dim, gamma=GAMMA)
agent = AtariAgent(
algorithm,
act_dim=act_dim,
start_lr=LEARNING_RATE,
total_step=args.train_total_steps)
with tqdm(
total=MEMORY_WARMUP_SIZE, desc='[Replay Memory Warm Up]') as pbar:
while rpm.size() < MEMORY_WARMUP_SIZE:
total_reward, steps = run_train_episode(env, agent, rpm)
total_reward, steps, _ = run_train_episode(env, agent, rpm)
pbar.update(steps)
# train
test_flag = 0
pbar = tqdm(total=args.train_total_steps)
recent_100_reward = []
total_steps = 0
max_reward = None
while total_steps < args.train_total_steps:
# start epoch
total_reward, steps = run_train_episode(env, agent, rpm)
total_reward, steps, loss = run_train_episode(env, agent, rpm)
total_steps += steps
pbar.set_description('[train]exploration:{}'.format(agent.exploration))
tensorboard.add_scalar('dqn/score', total_reward, total_steps)
tensorboard.add_scalar('dqn/loss', loss,
total_steps) # mean of total loss
tensorboard.add_scalar('dqn/exploration', agent.exploration,
total_steps)
pbar.update(steps)
if total_steps // args.test_every_steps >= test_flag:
......@@ -127,6 +137,8 @@ def main():
logger.info(
"eval_agent done, (steps, eval_reward): ({}, {})".format(
total_steps, np.mean(eval_rewards)))
eval_test = np.mean(eval_rewards)
tensorboard.add_scalar('dqn/eval', eval_test, total_steps)
pbar.close()
......@@ -137,10 +149,16 @@ if __name__ == '__main__':
'--rom', help='path of the rom of the atari game', required=True)
parser.add_argument(
'--batch_size', type=int, default=64, help='batch size for training')
parser.add_argument(
'--algo',
default='DQN',
help='DQN/Double/Dueling, representing DQN, Double DQN, and Dueling DQN respectively',
)
parser.add_argument(
'--train_total_steps',
type=int,
default=int(1e8),
default=int(1e7),
help='maximum environmental steps of games')
parser.add_argument(
'--test_every_steps',
......@@ -149,5 +167,4 @@ if __name__ == '__main__':
help='the step interval between two consecutive evaluations')
args = parser.parse_args()
main()
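The evaluation trigger in `main()` compares `total_steps // args.test_every_steps` against a counter instead of looking for an exact step count, so an evaluation still fires when an episode overshoots the interval boundary. A standalone illustration of that bookkeeping, assuming `test_flag` is incremented inside the branch as in the collapsed part of the diff:

```python
# Illustration of the `total_steps // test_every_steps >= test_flag` trigger
# (assumes test_flag += 1 inside the branch, as in the collapsed diff lines).
test_every_steps = 100
test_flag = 0
total_steps = 0
for episode_steps in [40, 70, 120, 60]:   # irregular episode lengths
    total_steps += episode_steps
    if total_steps // test_every_steps >= test_flag:
        test_flag += 1
        print('evaluate at step', total_steps)
# prints: evaluate at step 40, 110, 230
```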
......@@ -15,6 +15,7 @@
from parl.algorithms.fluid.a3c import *
from parl.algorithms.fluid.ddpg import *
from parl.algorithms.fluid.dqn import *
from parl.algorithms.fluid.ddqn import *
from parl.algorithms.fluid.policy_gradient import *
from parl.algorithms.fluid.ppo import *
from parl.algorithms.fluid.impala.impala import *
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.simplefilter('default')
import copy
import numpy as np
import paddle.fluid as fluid
from parl.core.fluid.algorithm import Algorithm
from parl.core.fluid import layers
class DDQN(Algorithm):
def __init__(
self,
model,
act_dim=None,
gamma=None,
):
""" Double DQN algorithm
Args:
model (parl.Model): model defining the forward network of the Q function.
act_dim (int): dimension of the action space.
gamma (float): discount factor for reward computation.
"""
self.model = model
self.target_model = copy.deepcopy(model)
assert isinstance(act_dim, int)
assert isinstance(gamma, float)
self.act_dim = act_dim
self.gamma = gamma
def predict(self, obs):
return self.model.value(obs)
def learn(self, obs, action, reward, next_obs, terminal, learning_rate):
pred_value = self.model.value(obs)
action_onehot = layers.one_hot(action, self.act_dim)
action_onehot = layers.cast(action_onehot, dtype='float32')
pred_action_value = layers.reduce_sum(
layers.elementwise_mul(action_onehot, pred_value), dim=1)
# select the greedy next action according to the behavior (online) network
next_action_value = self.model.value(next_obs)
greedy_action = layers.argmax(next_action_value, axis=-1)
# calculate the target q value with target network
batch_size = layers.cast(layers.shape(greedy_action)[0], dtype='int')
range_tmp = layers.range(
start=0, end=batch_size, step=1, dtype='int64') * self.act_dim
a_indices = range_tmp + greedy_action
a_indices = layers.cast(a_indices, dtype='int32')
next_pred_value = self.target_model.value(next_obs)
next_pred_value = layers.reshape(
next_pred_value, shape=[
-1,
])
max_v = layers.gather(next_pred_value, a_indices)
max_v = layers.reshape(
max_v, shape=[
-1,
])
max_v.stop_gradient = True
target = reward + (
1.0 - layers.cast(terminal, dtype='float32')) * self.gamma * max_v
cost = layers.square_error_cost(pred_action_value, target)
cost = layers.reduce_mean(cost)
optimizer = fluid.optimizer.Adam(
learning_rate=learning_rate, epsilon=1e-3)
optimizer.minimize(cost)
return cost
def sync_target(self, gpu_id=None):
""" sync weights of self.model to self.target_model
"""
if gpu_id is not None:
warnings.warn(
"the `gpu_id` argument of `sync_target` function in `parl.Algorithms.DQN` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.model.sync_weights_to(self.target_model)
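`DDQN.learn` selects the greedy next action with the behavior network and evaluates it with the target network; the `range * act_dim + greedy_action` arithmetic just converts per-row action indices into positions in the flattened `[batch * act_dim]` target-Q tensor so a single `gather` can do the per-row lookup. A NumPy sketch of the same selection (illustrative, outside the fluid graph):

```python
import numpy as np

act_dim = 3
q_behavior = np.array([[0.1, 0.9, 0.3],
                       [0.4, 0.2, 0.8]])   # Q(s', a) from the behavior network
q_target = np.array([[1.0, 2.0, 3.0],
                     [4.0, 5.0, 6.0]])     # Q(s', a) from the target network

greedy_action = q_behavior.argmax(axis=1)                            # [1, 2]
flat_idx = np.arange(len(greedy_action)) * act_dim + greedy_action   # [1, 5]
max_v = q_target.reshape(-1)[flat_idx]                               # [2.0, 6.0]

# Identical to direct per-row indexing:
assert np.allclose(max_v, q_target[np.arange(2), greedy_action])
```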
......@@ -25,12 +25,7 @@ __all__ = ['DQN']
class DQN(Algorithm):
def __init__(self,
model,
hyperparas=None,
act_dim=None,
gamma=None,
lr=None):
def __init__(self, model, hyperparas=None, act_dim=None, gamma=None):
""" DQN algorithm
Args:
......@@ -50,14 +45,11 @@ class DQN(Algorithm):
stacklevel=2)
self.act_dim = hyperparas['action_dim']
self.gamma = hyperparas['gamma']
self.lr = hyperparas['lr']
else:
assert isinstance(act_dim, int)
assert isinstance(gamma, float)
assert isinstance(lr, float)
self.act_dim = act_dim
self.gamma = gamma
self.lr = lr
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='predict')
......@@ -73,10 +65,12 @@ class DQN(Algorithm):
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='learn')
def define_learn(self, obs, action, reward, next_obs, terminal):
return self.learn(obs, action, reward, next_obs, terminal)
def define_learn(self, obs, action, reward, next_obs, terminal,
learning_rate):
return self.learn(obs, action, reward, next_obs, terminal,
learning_rate)
def learn(self, obs, action, reward, next_obs, terminal):
def learn(self, obs, action, reward, next_obs, terminal, learning_rate):
""" update value model self.model with DQN algorithm
"""
......@@ -93,7 +87,8 @@ class DQN(Algorithm):
layers.elementwise_mul(action_onehot, pred_value), dim=1)
cost = layers.square_error_cost(pred_action_value, target)
cost = layers.reduce_mean(cost)
optimizer = fluid.optimizer.Adam(self.lr, epsilon=1e-3)
optimizer = fluid.optimizer.Adam(
learning_rate=learning_rate, epsilon=1e-3)
optimizer.minimize(cost)
return cost
......
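For reference, the bootstrap targets implemented by `DQN.learn` (in the collapsed lines above) and `DDQN.learn` differ only in which network picks the next action:

$$y_{\text{DQN}} = r + (1 - \text{done})\,\gamma \max_{a'} Q_{\theta^{-}}(s', a')$$

$$y_{\text{DoubleDQN}} = r + (1 - \text{done})\,\gamma\, Q_{\theta^{-}}\!\big(s',\ \arg\max_{a'} Q_{\theta}(s', a')\big)$$

where $\theta$ are the behavior-network weights and $\theta^{-}$ the target-network weights synced by `sync_target`.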