Unverified commit 3a27f407, authored by Zheyue Tan, committed by GitHub

Add Prioritized DQN (#326)

- add prioritized dqn
- fix #239
Parent: c85204dc
## Prioritized Experience Replay
Reproducing the paper [Prioritized Experience Replay](https://arxiv.org/abs/1511.05952).
Prioritized experience replay (PER) provides a framework for prioritizing experience so that important transitions are replayed more frequently. Transitions can be prioritized in two ways, rank-based and proportional. Our implementation is the proportional variant, which the original paper reports to perform better.
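As a quick reference, the sketch below (NumPy, illustrative values only, not part of the shipped code) shows how the proportional variant turns absolute TD-errors into sampling probabilities and importance-sampling weights:

```python
import numpy as np

# Illustrative values only; the real buffer keeps one priority per stored transition.
td_errors = np.array([0.5, 2.0, 0.1, 1.2])  # |TD-error| of a few transitions
alpha, beta, eps = 0.6, 0.5, 0.01           # placeholder hyperparameters

priorities = (td_errors + eps) ** alpha     # p_i = (|delta_i| + eps) ^ alpha
probs = priorities / priorities.sum()       # P(i) = p_i / sum_k p_k

n = len(td_errors)
weights = (n * probs) ** (-beta)            # w_i = (N * P(i)) ^ (-beta)
weights /= weights.max()                    # normalize so the largest weight is 1
```

In this example, `alpha` and the small constant `eps` are applied when transitions are stored or their priorities are updated, while `beta` is annealed towards 1 over the course of training.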
## Reproduced Results
Results have been reproduced with [Double DQN](https://arxiv.org/abs/1509.06461v3) on the following three environments:
<p align="center">
<img src="result.png"/>
</p>
## How to use
### Dependencies:
+ [paddlepaddle>=1.6.1](https://github.com/PaddlePaddle/Paddle)
+ [parl](https://github.com/PaddlePaddle/PARL)
+ gym[atari]==0.17.2
+ atari-py==0.2.6
+ tqdm
+ [ale_python_interface](https://github.com/mgbellemare/Arcade-Learning-Environment)
### Start Training:
Train on the BattleZone game:
```bash
python train.py --rom ./rom_files/battle_zone.bin
```
> To train on other games, you can download more ROM files from [here](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms).
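> Training uses prioritized Double DQN by default; pass `--alg dqn` to switch to prioritized DQN (see the argument parser in `train.py`).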
../DQN_variant/atari.py (symlink)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid
import parl
from parl import layers
IMAGE_SIZE = (84, 84)
CONTEXT_LEN = 4
class AtariAgent(parl.Agent):
def __init__(self, algorithm, act_dim, update_freq):
super(AtariAgent, self).__init__(algorithm)
assert isinstance(act_dim, int)
self.act_dim = act_dim
self.exploration = 1.0
self.global_step = 0
self.update_target_steps = 10000 // 4
self.update_freq = update_freq
def build_program(self):
self.pred_program = fluid.Program()
self.learn_program = fluid.Program()
with fluid.program_guard(self.pred_program):
obs = layers.data(
name='obs',
shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]],
dtype='float32')
self.value = self.alg.predict(obs)
with fluid.program_guard(self.learn_program):
obs = layers.data(
name='obs',
shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]],
dtype='float32')
action = layers.data(name='act', shape=[1], dtype='int32')
reward = layers.data(name='reward', shape=[], dtype='float32')
next_obs = layers.data(
name='next_obs',
shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]],
dtype='float32')
terminal = layers.data(name='terminal', shape=[], dtype='bool')
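            # importance-sampling weights produced by the prioritized replay buffer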
sample_weight = layers.data(
name='sample_weight', shape=[1], dtype='float32')
self.cost, self.delta = self.alg.learn(
obs, action, reward, next_obs, terminal, sample_weight)
def sample(self, obs, decay_exploration=True):
sample = np.random.random()
if sample < self.exploration:
act = np.random.randint(self.act_dim)
else:
if np.random.random() < 0.01:
act = np.random.randint(self.act_dim)
else:
obs = np.expand_dims(obs, axis=0)
pred_Q = self.fluid_executor.run(
self.pred_program,
feed={'obs': obs.astype('float32')},
fetch_list=[self.value])[0]
pred_Q = np.squeeze(pred_Q, axis=0)
act = np.argmax(pred_Q)
if decay_exploration:
self.exploration = max(0.1, self.exploration - 1e-6)
return act
def predict(self, obs):
obs = np.expand_dims(obs, axis=0)
pred_Q = self.fluid_executor.run(
self.pred_program,
feed={'obs': obs.astype('float32')},
fetch_list=[self.value])[0]
pred_Q = np.squeeze(pred_Q, axis=0)
act = np.argmax(pred_Q)
return act
def learn(self, obs, act, reward, next_obs, terminal, sample_weight):
if self.global_step % self.update_target_steps == 0:
self.alg.sync_target()
self.global_step += 1
act = np.expand_dims(act, -1)
reward = np.clip(reward, -1, 1)
feed = {
'obs': obs.astype('float32'),
'act': act.astype('int32'),
'reward': reward.astype('float32'),
'next_obs': next_obs.astype('float32'),
'terminal': terminal.astype('bool'),
'sample_weight': sample_weight.astype('float32')
}
cost, delta = self.fluid_executor.run(
self.learn_program, feed=feed, fetch_list=[self.cost, self.delta])
return cost, delta
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid
import parl
from parl import layers
class AtariModel(parl.Model):
def __init__(self, act_dim):
self.act_dim = act_dim
self.conv1 = layers.conv2d(
num_filters=32, filter_size=5, stride=1, padding=2, act='relu')
self.conv2 = layers.conv2d(
num_filters=32, filter_size=5, stride=1, padding=2, act='relu')
self.conv3 = layers.conv2d(
num_filters=64, filter_size=4, stride=1, padding=1, act='relu')
self.conv4 = layers.conv2d(
num_filters=64, filter_size=3, stride=1, padding=1, act='relu')
self.fc1 = layers.fc(size=act_dim)
def value(self, obs):
obs = obs / 255.0
out = self.conv1(obs)
out = layers.pool2d(
input=out, pool_size=2, pool_stride=2, pool_type='max')
out = self.conv2(out)
out = layers.pool2d(
input=out, pool_size=2, pool_stride=2, pool_type='max')
out = self.conv3(out)
out = layers.pool2d(
input=out, pool_size=2, pool_stride=2, pool_type='max')
out = self.conv4(out)
out = layers.flatten(out, axis=1)
Q = self.fc1(out)
return Q
../DQN_variant/atari_wrapper.py (symlink)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import numpy as np
import paddle.fluid as fluid
import parl
from parl.core.fluid import layers
class PrioritizedDQN(parl.Algorithm):
def __init__(self, model, act_dim=None, gamma=None, lr=None):
""" DQN algorithm with prioritized experience replay.
Args:
model (parl.Model): model defining forward network of Q function
act_dim (int): dimension of the action space
            gamma (float): discount factor for reward computation.
lr (float): learning rate.
"""
self.model = model
self.target_model = copy.deepcopy(model)
assert isinstance(act_dim, int)
assert isinstance(gamma, float)
self.act_dim = act_dim
self.gamma = gamma
self.lr = lr
def predict(self, obs):
""" use value model self.model to predict the action value
"""
return self.model.value(obs)
def learn(self, obs, action, reward, next_obs, terminal, sample_weight):
""" update value model self.model with DQN algorithm
"""
pred_value = self.model.value(obs)
next_pred_value = self.target_model.value(next_obs)
best_v = layers.reduce_max(next_pred_value, dim=1)
best_v.stop_gradient = True
target = reward + (
1.0 - layers.cast(terminal, dtype='float32')) * self.gamma * best_v
action_onehot = layers.one_hot(action, self.act_dim)
action_onehot = layers.cast(action_onehot, dtype='float32')
pred_action_value = layers.reduce_sum(
action_onehot * pred_value, dim=1)
delta = layers.abs(target - pred_action_value)
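        # weight each sample's squared TD-error by its importance-sampling
        # weight to correct the bias introduced by prioritized sampling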
cost = sample_weight * layers.square_error_cost(
pred_action_value, target)
cost = layers.reduce_mean(cost)
optimizer = fluid.optimizer.Adam(learning_rate=self.lr, epsilon=1e-3)
optimizer.minimize(cost)
return cost, delta # `delta` is the TD-error
def sync_target(self):
""" sync weights of self.model to self.target_model
"""
self.model.sync_weights_to(self.target_model)
class PrioritizedDoubleDQN(parl.Algorithm):
def __init__(self, model, act_dim=None, gamma=None, lr=None):
""" Double DQN algorithm
Args:
model (parl.Model): model defining forward network of Q function.
gamma (float): discounted factor for reward computation.
"""
self.model = model
self.target_model = copy.deepcopy(model)
assert isinstance(act_dim, int)
assert isinstance(gamma, float)
self.act_dim = act_dim
self.gamma = gamma
self.lr = lr
def predict(self, obs):
return self.model.value(obs)
def learn(self, obs, action, reward, next_obs, terminal, sample_weight):
pred_value = self.model.value(obs)
action_onehot = layers.one_hot(action, self.act_dim)
pred_action_value = layers.reduce_sum(
action_onehot * pred_value, dim=1)
# calculate the target q value
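        # Double DQN: the behavior network selects the greedy action,
        # while the target network evaluates its value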
next_action_value = self.model.value(next_obs)
greedy_action = layers.argmax(next_action_value, axis=-1)
greedy_action = layers.unsqueeze(greedy_action, axes=[1])
greedy_action_onehot = layers.one_hot(greedy_action, self.act_dim)
next_pred_value = self.target_model.value(next_obs)
max_v = layers.reduce_sum(
greedy_action_onehot * next_pred_value, dim=1)
max_v.stop_gradient = True
target = reward + (
1.0 - layers.cast(terminal, dtype='float32')) * self.gamma * max_v
delta = layers.abs(target - pred_action_value)
cost = sample_weight * layers.square_error_cost(
pred_action_value, target)
cost = layers.reduce_mean(cost)
optimizer = fluid.optimizer.Adam(learning_rate=self.lr, epsilon=1e-3)
optimizer.minimize(cost)
return cost, delta
def sync_target(self):
""" sync weights of self.model to self.target_model
"""
self.model.sync_weights_to(self.target_model)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
class SumTree(object):
def __init__(self, capacity):
self.capacity = capacity
self.elements = [None for _ in range(capacity)]
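        # complete binary tree stored as a flat array: capacity - 1 internal
        # nodes come first, followed by `capacity` leaves holding priorities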
self.tree = [0 for _ in range(2 * capacity - 1)]
self._ptr = 0
self._min = 10
def full(self):
return all(self.elements) # no `None` in self.elements
def add(self, item, priority):
self.elements[self._ptr] = item
tree_idx = self._ptr + self.capacity - 1
self.update(tree_idx, priority)
self._ptr = (self._ptr + 1) % self.capacity
self._min = min(self._min, priority)
def update(self, tree_idx, priority):
diff = priority - self.tree[tree_idx]
self.tree[tree_idx] = priority
while tree_idx != 0:
tree_idx = (tree_idx - 1) >> 1
self.tree[tree_idx] += diff
self._min = min(self._min, priority)
def retrieve(self, value):
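        # walk down from the root: go left when `value` falls inside the left
        # subtree's priority mass, otherwise subtract that mass and go right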
parent_idx = 0
while True:
left_child_idx = 2 * parent_idx + 1
right_child_idx = left_child_idx + 1
if left_child_idx >= len(self.tree):
leaf_idx = parent_idx
break
else:
if value <= self.tree[left_child_idx]:
parent_idx = left_child_idx
else:
value -= self.tree[left_child_idx]
parent_idx = right_child_idx
elem_idx = leaf_idx - self.capacity + 1
priority = self.tree[leaf_idx]
return self.elements[elem_idx], leaf_idx, priority
def from_list(self, lst):
assert len(lst) == self.capacity
self.elements = list(lst)
for i in range(self.capacity - 1, 2 * self.capacity - 1):
self.update(i, 1.0)
@property
def total_p(self):
return self.tree[0]
class ProportionalPER(object):
"""Proportional Prioritized Experience Replay.
"""
def __init__(self,
alpha,
seg_num,
size=1e6,
eps=0.01,
init_mem=None,
framestack=4):
self.alpha = alpha
self.seg_num = seg_num
self.size = int(size)
self.elements = SumTree(self.size)
if init_mem:
self.elements.from_list(init_mem)
self.framestack = framestack
self._max_priority = 1.0
self.eps = eps
def _get_stacked_item(self, idx):
""" For atari environment, we use a 4-frame-stack as input
"""
obs, act, reward, next_obs, done = self.elements.elements[idx]
stacked_obs = np.zeros((self.framestack, ) + obs.shape)
stacked_obs[-1] = obs
for i in range(self.framestack - 2, -1, -1):
elem_idx = (self.size + idx + i - self.framestack + 1) % self.size
obs, _, _, _, d = self.elements.elements[elem_idx]
if d:
break
stacked_obs[i] = obs
return (stacked_obs, act, reward, next_obs, done)
def store(self, item, delta=None):
assert len(item) == 5 # (s, a, r, s', terminal)
if not delta:
delta = self._max_priority
assert delta >= 0
ps = np.power(delta + self.eps, self.alpha)
self.elements.add(item, ps)
def update(self, indices, priorities):
priorities = np.array(priorities) + self.eps
priorities_alpha = np.power(priorities, self.alpha)
for idx, priority in zip(indices, priorities_alpha):
self.elements.update(idx, priority)
self._max_priority = max(priority, self._max_priority)
def sample_one(self):
assert self.elements.full(), "The replay memory is not full!"
sample_val = np.random.uniform(0, self.elements.total_p)
item, tree_idx, _ = self.elements.retrieve(sample_val)
return item, tree_idx
def sample(self, beta=1):
""" sample a batch of `seg_num` transitions
Args:
beta: float, degree of using importance sampling weights,
0 - no corrections, 1 - full correction
Return:
items: sampled transitions
indices: idxs of sampled items, used to update priorities later
sample_weights: importance sampling weight
"""
assert self.elements.full(), "The replay memory is not full!"
seg_size = self.elements.total_p / self.seg_num
seg_bound = [(seg_size * i, seg_size * (i + 1))
for i in range(self.seg_num)]
items, indices, priorities = [], [], []
for low, high in seg_bound:
sample_val = np.random.uniform(low, high)
_, tree_idx, priority = self.elements.retrieve(sample_val)
elem_idx = tree_idx - self.elements.capacity + 1
item = self._get_stacked_item(elem_idx)
items.append(item)
indices.append(tree_idx)
priorities.append(priority)
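        # importance-sampling weights w_i = (N * P(i)) ** (-beta), normalized by
        # the largest weight (which corresponds to the minimum priority in the tree)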
batch_probs = self.size * np.array(priorities) / self.elements.total_p
min_prob = self.size * self.elements._min / self.elements.total_p
sample_weights = np.power(batch_probs / min_prob, -beta)
return np.array(items), np.array(indices), sample_weights
../DQN_variant/rom_files (symlink)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import pickle
from collections import deque
from datetime import datetime
import gym
import numpy as np
import paddle.fluid as fluid
from tqdm import tqdm
import parl
from atari_agent import AtariAgent
from atari_model import AtariModel
from parl.utils import logger, summary
from per_alg import PrioritizedDoubleDQN, PrioritizedDQN
from proportional_per import ProportionalPER
from utils import get_player
MEMORY_SIZE = 1e6
MEMORY_WARMUP_SIZE = MEMORY_SIZE
IMAGE_SIZE = (84, 84)
CONTEXT_LEN = 4
FRAME_SKIP = 4
UPDATE_FREQ = 4
GAMMA = 0.99
LEARNING_RATE = 0.00025 / 4
def beta_adder(init_beta, step_size=0.0001):
beta = init_beta
step_size = step_size
def adder():
nonlocal beta, step_size
beta += step_size
return min(beta, 1)
return adder
def process_transitions(transitions):
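    # `transitions` holds (stacked_obs, act, reward, next_frame, terminal) tuples;
    # the next-state stack is built by appending next_frame to the current
    # stack and dropping its oldest frame.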
transitions = np.array(transitions)
batch_obs = np.stack(transitions[:, 0].copy())
batch_act = transitions[:, 1].copy()
batch_reward = transitions[:, 2].copy()
batch_next_obs = np.expand_dims(np.stack(transitions[:, 3]), axis=1)
batch_next_obs = np.concatenate([batch_obs, batch_next_obs],
axis=1)[:, 1:, :, :].copy()
batch_terminal = transitions[:, 4].copy()
batch = (batch_obs, batch_act, batch_reward, batch_next_obs,
batch_terminal)
return batch
def run_episode(env, agent, per, mem=None, warmup=False, train=False):
total_reward = 0
all_cost = []
traj = deque(maxlen=CONTEXT_LEN)
obs = env.reset()
for _ in range(CONTEXT_LEN - 1):
traj.append(np.zeros(obs.shape))
steps = 0
if warmup:
decay_exploration = False
else:
decay_exploration = True
while True:
steps += 1
traj.append(obs)
context = np.stack(traj, axis=0)
action = agent.sample(context, decay_exploration=decay_exploration)
next_obs, reward, terminal, _ = env.step(action)
transition = [obs, action, reward, next_obs, terminal]
if warmup:
mem.append(transition)
if train:
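            # new transitions are stored with the current maximum priority,
            # so they are likely to be sampled soon after insertion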
per.store(transition)
if steps % UPDATE_FREQ == 0:
beta = get_beta()
transitions, idxs, sample_weights = per.sample(beta=beta)
batch = process_transitions(transitions)
cost, delta = agent.learn(*batch, sample_weights)
all_cost.append(cost)
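                # refresh the priorities of the sampled transitions with
                # their newly computed TD-errors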
per.update(idxs, delta)
total_reward += reward
obs = next_obs
if terminal:
break
return total_reward, steps, np.mean(all_cost)
def run_evaluate_episode(env, agent):
obs = env.reset()
total_reward = 0
while True:
action = agent.predict(obs)
obs, reward, isOver, info = env.step(action)
total_reward += reward
if isOver:
break
return total_reward
def main():
# Prepare environments
env = get_player(
args.rom, image_size=IMAGE_SIZE, train=True, frame_skip=FRAME_SKIP)
test_env = get_player(
args.rom,
image_size=IMAGE_SIZE,
frame_skip=FRAME_SKIP,
context_len=CONTEXT_LEN)
# Init Prioritized Replay Memory
per = ProportionalPER(alpha=0.6, seg_num=args.batch_size, size=MEMORY_SIZE)
# Prepare PARL agent
act_dim = env.action_space.n
model = AtariModel(act_dim)
if args.alg == 'ddqn':
algorithm = PrioritizedDoubleDQN(
model, act_dim=act_dim, gamma=GAMMA, lr=LEARNING_RATE)
elif args.alg == 'dqn':
algorithm = PrioritizedDQN(
model, act_dim=act_dim, gamma=GAMMA, lr=LEARNING_RATE)
agent = AtariAgent(algorithm, act_dim=act_dim, update_freq=UPDATE_FREQ)
# Replay memory warmup
total_step = 0
with tqdm(total=MEMORY_SIZE, desc='[Replay Memory Warm Up]') as pbar:
mem = []
while total_step < MEMORY_WARMUP_SIZE:
total_reward, steps, _ = run_episode(
env, agent, per, mem=mem, warmup=True)
total_step += steps
pbar.update(steps)
per.elements.from_list(mem[:int(MEMORY_WARMUP_SIZE)])
env_name = args.rom.split('/')[-1].split('.')[0]
test_flag = 0
total_steps = 0
pbar = tqdm(total=args.train_total_steps)
while total_steps < args.train_total_steps:
# start epoch
total_reward, steps, loss = run_episode(env, agent, per, train=True)
total_steps += steps
pbar.set_description('[train]exploration:{}'.format(agent.exploration))
summary.add_scalar('{}/score'.format(env_name), total_reward,
total_steps)
summary.add_scalar('{}/loss'.format(env_name), loss,
total_steps) # mean of total loss
summary.add_scalar('{}/exploration'.format(env_name),
agent.exploration, total_steps)
pbar.update(steps)
if total_steps // args.test_every_steps >= test_flag:
while total_steps // args.test_every_steps >= test_flag:
test_flag += 1
pbar.write("testing")
test_rewards = []
for _ in tqdm(range(3), desc='eval agent'):
eval_reward = run_evaluate_episode(test_env, agent)
test_rewards.append(eval_reward)
eval_reward = np.mean(test_rewards)
logger.info(
"eval_agent done, (steps, eval_reward): ({}, {})".format(
total_steps, eval_reward))
summary.add_scalar('{}/eval'.format(env_name), eval_reward,
total_steps)
pbar.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--rom', help='path of the rom of the atari game', required=True)
parser.add_argument(
'--batch_size', type=int, default=32, help='batch size for training')
parser.add_argument(
'--alg',
type=str,
default="ddqn",
help='dqn or ddqn, training algorithm to use.')
parser.add_argument(
'--train_total_steps',
type=int,
default=int(1e7),
        help='maximum number of environment steps for training')
parser.add_argument(
'--test_every_steps',
type=int,
default=100000,
help='the step interval between two consecutive evaluations')
args = parser.parse_args()
    assert args.alg in ['dqn', 'ddqn'], \
'used algorithm should be dqn or ddqn (double dqn)'
get_beta = beta_adder(init_beta=0.5)
main()
../DQN_variant/utils.py (symlink)
......@@ -70,26 +70,14 @@ class DDQN(Algorithm):
         pred_action_value = layers.reduce_sum(
             layers.elementwise_mul(action_onehot, pred_value), dim=1)
-        # calculate the target q value
+        # choose acc. to behavior network
         next_action_value = self.model.value(next_obs)
         greedy_action = layers.argmax(next_action_value, axis=-1)
-        batch_size = layers.cast(layers.shape(greedy_action)[0], dtype='int64')
-        range_tmp = layers.range(
-            start=0, end=batch_size, step=1, dtype='int64') * self.act_dim
-        a_indices = range_tmp + greedy_action
-        a_indices = layers.cast(a_indices, dtype='int32')
+        # calculate the target q value with target network
+        greedy_action = layers.unsqueeze(greedy_action, axes=[1])
+        greedy_action_onehot = layers.one_hot(greedy_action, self.act_dim)
         next_pred_value = self.target_model.value(next_obs)
-        next_pred_value = layers.reshape(
-            next_pred_value, shape=[
-                -1,
-            ])
-        max_v = layers.gather(next_pred_value, a_indices)
-        max_v = layers.reshape(
-            max_v, shape=[
-                -1,
-            ])
+        max_v = layers.reduce_sum(
+            greedy_action_onehot * next_pred_value, dim=1)
         max_v.stop_gradient = True
         target = reward + (
......