Unverified commit bbcb707b, authored by Hongsheng Zeng, committed by GitHub

torch benchmark policy gradient (#203)

* torch benchmark policy gradient

* refine comments and use native api
Parent 9216d941
## PyTorch Benchmark Quick Start
Train an agent with PARL to solve the CartPole problem, a classic benchmark in RL.
## How to use
### Dependencies:
+ [parl](https://github.com/PaddlePaddle/PARL)
+ torch
+ gym
### Start Training:
```
# Install dependencies
pip install torch torchvision gym
git clone https://github.com/PaddlePaddle/PARL.git
cd PARL
pip install .
# Train model
cd benchmark/torch/QuickStart
python train.py
```
### Expected Result
<img src="https://github.com/PaddlePaddle/PARL/blob/develop/examples/QuickStart/performance.gif" width = "300" height ="200" alt="result"/>
The agent should reach around 200 points (the maximum episode return for CartPole-v0) within a few minutes of training.
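For orientation, the benchmark follows PARL's usual Model / Algorithm / Agent layering. A minimal sketch of how the pieces in this commit are wired together (the same steps `train.py` below performs):
```
model = CartpoleModel(obs_dim=4, act_dim=2)            # policy network
alg = parl.algorithms.PolicyGradient(model, lr=1e-3)   # policy-gradient update rule
agent = CartpoleAgent(alg)                             # samples actions and updates the model
```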
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import parl
import torch
import numpy as np
class CartpoleAgent(parl.Agent):
"""Agent of Cartpole env.
Args:
algorithm(parl.Algorithm): algorithm used to solve the problem.
"""
def __init__(self, algorithm):
self.algorithm = algorithm
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
def sample(self, obs):
"""Sample an action when given an observation
Args:
obs(np.float32): shape of (obs_dim,)
Returns:
action(int)
"""
        obs = torch.tensor(obs, device=self.device, dtype=torch.float)
        prob = self.algorithm.predict(obs)
        # Move the probabilities back to the CPU before converting to numpy,
        # then sample an action from the resulting categorical distribution.
        prob = prob.detach().cpu().numpy()
        action = np.random.choice(len(prob), 1, p=prob)[0]
        return action
def predict(self, obs):
"""Predict an action when given an observation
Args:
obs(np.float32): shape of (obs_dim,)
Returns:
action(int)
"""
obs = torch.tensor(obs, device=self.device, dtype=torch.float)
prob = self.algorithm.predict(obs)
_, action = prob.max(-1)
return action.item()
def learn(self, obs, action, reward):
"""Update model with an episode data
Args:
obs(np.float32): shape of (batch_size, obs_dim)
action(np.int64): shape of (batch_size)
reward(np.float32): shape of (batch_size)
Returns:
loss(float)
"""
obs = torch.tensor(obs, device=self.device, dtype=torch.float)
action = torch.tensor(action, device=self.device, dtype=torch.long)
reward = torch.tensor(reward, device=self.device, dtype=torch.float)
loss = self.algorithm.learn(obs, action, reward)
return loss.item()
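As a brief usage note (assuming a `gym` CartPole environment and an agent constructed as in `train.py` below): `sample` draws a stochastic action for collecting training data, while `predict` takes the greedy action for evaluation:
```
obs = env.reset()
action = agent.sample(obs)    # stochastic action, used while training
action = agent.predict(obs)   # greedy (argmax) action, used at test time
```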
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
import parl
class CartpoleModel(parl.Model):
""" Linear network to solve Cartpole problem.
Args:
obs_dim (int): Dimension of observation space.
act_dim (int): Dimension of action space.
"""
def __init__(self, obs_dim, act_dim):
super(CartpoleModel, self).__init__()
hid1_size = act_dim * 10
self.fc1 = nn.Linear(obs_dim, hid1_size)
self.fc2 = nn.Linear(hid1_size, act_dim)
def forward(self, x):
out = torch.tanh(self.fc1(x))
prob = F.softmax(self.fc2(out), dim=-1)
return prob
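As a quick sanity check (not part of the commit), the model maps a single observation to a probability distribution over the two CartPole actions:
```
model = CartpoleModel(obs_dim=4, act_dim=2)
obs = torch.randn(4)     # dummy observation
prob = model(obs)        # tensor of shape (2,); softmax output, sums to 1
```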
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gym
import numpy as np
import parl
from parl.utils import logger
from cartpole_model import CartpoleModel
from cartpole_agent import CartpoleAgent
OBS_DIM = 4
ACT_DIM = 2
LEARNING_RATE = 1e-3
def run_episode(env, agent, train_or_test='train'):
    """Run one episode and return the collected observations, actions and rewards."""
obs_list, action_list, reward_list = [], [], []
obs = env.reset()
while True:
obs_list.append(obs)
if train_or_test == 'train':
action = agent.sample(obs)
else:
action = agent.predict(obs)
action_list.append(action)
obs, reward, done, _ = env.step(action)
reward_list.append(reward)
if done:
break
return obs_list, action_list, reward_list
def calc_reward_to_go(reward_list):
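    # Undiscounted return-to-go: G_t = r_t + r_{t+1} + ... + r_T, computed by a
    # backward cumulative sum, e.g. [1, 1, 1] -> [3, 2, 1]. These returns weight
    # the log-probabilities in PolicyGradient.learn.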
for i in range(len(reward_list) - 2, -1, -1):
reward_list[i] += reward_list[i + 1]
return np.array(reward_list)
def main():
env = gym.make('CartPole-v0')
model = CartpoleModel(obs_dim=OBS_DIM, act_dim=ACT_DIM)
alg = parl.algorithms.PolicyGradient(model, LEARNING_RATE)
agent = CartpoleAgent(alg)
for i in range(1000): # 1000 episodes
obs_list, action_list, reward_list = run_episode(env, agent)
if i % 10 == 0:
logger.info("Episode {}, Reward Sum {}.".format(
i, sum(reward_list)))
batch_obs = np.array(obs_list)
batch_action = np.array(action_list)
batch_reward = calc_reward_to_go(reward_list)
agent.learn(batch_obs, batch_action, batch_reward)
if (i + 1) % 100 == 0:
_, _, reward_list = run_episode(env, agent, train_or_test='test')
total_reward = np.sum(reward_list)
logger.info('Test reward: {}'.format(total_reward))
if __name__ == '__main__':
main()
@@ -16,3 +16,4 @@ from parl.algorithms.torch.ddqn import *
from parl.algorithms.torch.dqn import *
from parl.algorithms.torch.a2c import *
from parl.algorithms.torch.td3 import *
from parl.algorithms.torch.policy_gradient import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.optim as optim
import parl
from torch.distributions import Categorical
__all__ = ['PolicyGradient']
class PolicyGradient(parl.Algorithm):
def __init__(self, model, lr):
"""Policy gradient algorithm
Args:
model (parl.Model): model defining forward network of policy.
lr (float): learning rate.
"""
assert isinstance(lr, float)
self.model = model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(device)
self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
def predict(self, obs):
"""Predict the probability of actions
Args:
obs (torch.tensor): shape of (obs_dim,)
Returns:
prob (torch.tensor): shape of (action_dim,)
"""
prob = self.model(obs)
return prob
def learn(self, obs, action, reward):
"""Update model with policy gradient algorithm
Args:
obs (torch.tensor): shape of (batch_size, obs_dim)
            action (torch.tensor): shape of (batch_size)
            reward (torch.tensor): shape of (batch_size)
        Returns:
            loss (torch.tensor): scalar policy-gradient loss
"""
        prob = self.model(obs)
        # REINFORCE objective: maximize the log-probability of the taken
        # actions, weighted by the reward-to-go of each step.
        log_prob = Categorical(prob).log_prob(action)
        loss = torch.mean(-1 * log_prob * reward)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return loss
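For readers unfamiliar with `torch.distributions.Categorical`, a small illustrative snippet (not part of the commit) showing how the loss in `learn` is formed:
```
probs = torch.tensor([[0.7, 0.3], [0.2, 0.8]])     # policy output for two steps
actions = torch.tensor([0, 1])                      # actions actually taken
returns = torch.tensor([2.0, 1.0])                  # reward-to-go of each step
log_prob = Categorical(probs).log_prob(actions)     # log pi(a_t | s_t), shape (2,)
loss = torch.mean(-1 * log_prob * returns)          # scalar loss, as in learn()
```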