Commit 4abc0534 authored by LI Yunxiang, committed by Bo Zhou

add pytorch a2c (#167)

* add pytorch a2c

* add set/get_weights test & copyright

* yapf....

* Update model_base_test_torch.py

* update

* Delete banma.py

* Update model_base_test_torch.py

* update

* Update model.py

* update torch tests

* Update model_base_test_torch.py
Parent 7c406386
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
config = {
#========== remote config ==========
'master_address': 'localhost:8010',
#========== env config ==========
'env_name': 'BreakoutNoFrameskip-v4',
'env_dim': 84,
#========== actor config ==========
'actor_num': 5,
'env_num': 5,
'sample_batch_steps': 20,
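# each learner step consumes actor_num * env_num * sample_batch_steps
# = 5 * 5 * 20 = 500 transitions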
#========== learner config ==========
'max_sample_steps': int(1e7),
'gamma': 0.99,
'lambda': 1.0,
# start learning rate
'start_lr': 0.001,
'entropy_coeff_scheduler': [(0, -0.01)],
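# the scheduler above is a list of (step, value) pairs for PiecewiseScheduler;
# the coefficient is negative because the entropy term is added to the total
# loss while we want to maximize entropy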
'vf_loss_coeff': 0.5,
'get_remote_metrics_interval': 10,
'log_metrics_interval_s': 10,
'entropy_coeff': -0.05,
'learning_rate': 3e-4
}
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import gym
import parl
import torch
import numpy as np
from collections import defaultdict
from parl.env.atari_wrappers import wrap_deepmind, MonitorEnv, get_wrapper_by_cls
from parl.env.vector_env import VectorEnv
from parl.utils.rl_utils import calc_gae
from atari_model import ActorCritic
from parl.algorithms import A2C
from atari_agent import Agent
@parl.remote_class
class Actor(object):
def __init__(self, config):
# the cluster machines may not have a GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
self.actor_cuda = False
self.config = config
self.envs = []
for _ in range(config['env_num']):
env = gym.make(config['env_name'])
env = wrap_deepmind(env, dim=config['env_dim'], obs_format='NCHW')
self.envs.append(env)
self.vector_env = VectorEnv(self.envs)
self.obs_batch = self.vector_env.reset()
obs_shape = env.observation_space.shape
act_dim = env.action_space.n
self.config['obs_shape'] = obs_shape
self.config['act_dim'] = act_dim
model = ActorCritic(act_dim)
if self.actor_cuda:
model = model.cuda()
algorithm = A2C(model, config)
self.agent = Agent(algorithm, config)
def sample(self):
''' Interact with the environments for `sample_batch_steps` steps.
'''
sample_data = defaultdict(list)
env_sample_data = {}
for env_id in range(self.config['env_num']):
env_sample_data[env_id] = defaultdict(list)
for i in range(self.config['sample_batch_steps']):
self.obs_batch = np.stack(self.obs_batch)
self.obs_batch = torch.from_numpy(self.obs_batch).float()
if self.actor_cuda:
self.obs_batch = self.obs_batch.cuda()
action_batch, value_batch = self.agent.sample(self.obs_batch)
next_obs_batch, reward_batch, done_batch, info_batch = self.vector_env.step(
action_batch.cpu().numpy())
for env_id in range(self.config['env_num']):
env_sample_data[env_id]['obs'].append(
self.obs_batch[env_id].cpu().numpy())
env_sample_data[env_id]['actions'].append(
action_batch[env_id].item())
env_sample_data[env_id]['rewards'].append(reward_batch[env_id])
env_sample_data[env_id]['dones'].append(done_batch[env_id])
env_sample_data[env_id]['values'].append(
value_batch[env_id].item())
if done_batch[
env_id] or i == self.config['sample_batch_steps'] - 1:
next_value = 0
if not done_batch[env_id]:
next_obs = np.expand_dims(next_obs_batch[env_id], 0)
next_obs = torch.from_numpy(next_obs).float()
if self.actor_cuda:
next_obs = next_obs.cuda()
next_value = self.agent.value(next_obs).item()
values = env_sample_data[env_id]['values']
rewards = env_sample_data[env_id]['rewards']
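# calc_gae implements Generalized Advantage Estimation:
#   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
#   A_t = sum_k (gamma * lambda)^k * delta_{t+k}
# with lambda = 1.0 (the config value) this reduces to the discounted
# return minus the value baseline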
advantages = calc_gae(rewards, values, next_value,
self.config['gamma'],
self.config['lambda'])
target_values = advantages + values
sample_data['obs'].extend(env_sample_data[env_id]['obs'])
sample_data['actions'].extend(
env_sample_data[env_id]['actions'])
sample_data['advantages'].extend(advantages)
sample_data['target_values'].extend(target_values)
env_sample_data[env_id] = defaultdict(list)
self.obs_batch = next_obs_batch
for key in sample_data:
sample_data[key] = np.stack(sample_data[key])
return sample_data
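# compute_target builds discounted TD targets G_t = r_t + gamma * mask_t * G_{t+1}
# backwards from the bootstrap value v_final; note it is not called by sample(),
# which relies on calc_gae instead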
def compute_target(self, v_final, r_lst, mask_lst):
G = v_final.reshape(-1)
td_target = list()
for r, mask in zip(r_lst[::-1], mask_lst[::-1]):
G = r + self.config['gamma'] * G * mask
td_target.append(G)
return torch.tensor(td_target[::-1]).float()
def get_metrics(self):
metrics = defaultdict(list)
for env in self.envs:
monitor = get_wrapper_by_cls(env, MonitorEnv)
if monitor is not None:
for episode_rewards, episode_steps in monitor.next_episode_results(
):
metrics['episode_rewards'].append(episode_rewards)
metrics['episode_steps'].append(episode_steps)
return metrics
def set_weights(self, params):
self.agent.set_weights(params)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import parl
# PyTorch uses all CPU cores by default, which hurts performance when many actors run on one machine; limit each actor to a single thread.
torch.set_num_threads(1)
class Agent(parl.Agent):
def __init__(self, algorithm, config):
super(Agent, self).__init__(algorithm)
self.obs_shape = config['obs_shape']
def sample(self, obs):
sample_actions, values = self.algorithm.sample(obs)
return sample_actions, values
def predict(self, obs):
predict_actions = self.algorithm.predict(obs)
return predict_actions
def value(self, obs):
values = self.algorithm.value(obs)
return values
def learn(self, obs, actions, advantages, target_values):
total_loss, pi_loss, vf_loss, entropy, lr, entropy_coeff = self.algorithm.learn(
obs, actions, advantages, target_values)
return total_loss, pi_loss, vf_loss, entropy, lr, entropy_coeff
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import parl
class ActorCritic(parl.Model):
def __init__(self, act_dim):
super(ActorCritic, self).__init__()
self.conv1 = nn.Conv2d(
in_channels=4, out_channels=32, kernel_size=8, stride=4, padding=2)
self.conv2 = nn.Conv2d(
in_channels=32,
out_channels=64,
kernel_size=4,
stride=2,
padding=2)
self.conv3 = nn.Conv2d(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
padding=1)
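# with an 84x84 input the three convs yield a 64 x 11 x 11 feature map,
# so the flattened size is 64 * 11 * 11 = 7744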
self.fc = nn.Linear(7744, 512)
self.fc_pi = nn.Linear(512, act_dim)
self.fc_v = nn.Linear(512, 1)
def policy(self, x, softmax_dim=1):
x = x / 255.0
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))
x = torch.flatten(x, start_dim=1)
x = F.relu(self.fc(x))
logits = self.fc_pi(x)
prob = F.softmax(logits, dim=softmax_dim)
return prob
def value(self, x):
x = x / 255.0
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))
x = torch.flatten(x, start_dim=1)
x = F.relu(self.fc(x))
values = self.fc_v(x)
return values
def policy_and_value(self, x, softmax_dim=1):
x = x / 255.0
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))
x = torch.flatten(x, start_dim=1)
x = F.relu(self.fc(x))
values = self.fc_v(x)
logits = self.fc_pi(x)
prob = F.softmax(logits, dim=softmax_dim)
return prob, values
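# A minimal shape check for the model above (a sketch; batch size 8 and
# act_dim 4 are illustrative, not part of the original commit):
if __name__ == '__main__':
    model = ActorCritic(act_dim=4)
    dummy_obs = torch.zeros(8, 4, 84, 84)  # NCHW, matching wrap_deepmind(dim=84)
    prob, values = model.policy_and_value(dummy_obs)
    assert prob.shape == (8, 4) and values.shape == (8, 1)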
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import os
import gym
import six
import queue
import parl
import time
import threading
import numpy as np
from collections import defaultdict
from parl.env.atari_wrappers import wrap_deepmind
from parl.utils.window_stat import WindowStat
from parl.utils.time_stat import TimeStat
from parl.utils import machine_info
from parl.utils import logger, get_gpu_count, tensorboard
from parl.algorithms import A2C
from atari_model import ActorCritic
from atari_agent import Agent
from actor import Actor
import time
from statistics import mean
class Learner(object):
def __init__(self, config, cuda):
self.cuda = cuda
self.config = config
env = gym.make(config['env_name'])
env = wrap_deepmind(env, dim=config['env_dim'], obs_format='NCHW')
obs_shape = env.observation_space.shape
act_dim = env.action_space.n
self.config['obs_shape'] = obs_shape
self.config['act_dim'] = act_dim
model = ActorCritic(act_dim)
if self.cuda:
model = model.cuda()
algorithm = A2C(model, config)
self.agent = Agent(algorithm, config)
if machine_info.is_gpu_available():
assert get_gpu_count() == 1, 'Only single-GPU training is supported,\
please set the environment variable: `export CUDA_VISIBLE_DEVICES=[GPU_ID_YOU_WANT_TO_USE]`.'
else:
os.environ['CPU_NUM'] = str(1)
#========== Learner ==========
self.total_loss_stat = WindowStat(100)
self.pi_loss_stat = WindowStat(100)
self.vf_loss_stat = WindowStat(100)
self.entropy_stat = WindowStat(100)
self.lr = None
self.entropy_coeff = None
self.learn_time_stat = TimeStat(100)
self.start_time = None
#========== Remote Actor ===========
self.remote_count = 0
self.sample_total_steps = 0
self.sample_data_queue = queue.Queue()
self.remote_metrics_queue = queue.Queue()
self.params_queues = []
self.create_actors()
def create_actors(self):
parl.connect(self.config['master_address'])
logger.info('Waiting for {} remote actors to connect.'.format(
self.config['actor_num']))
for i in six.moves.range(self.config['actor_num']):
params_queue = queue.Queue()
self.params_queues.append(params_queue)
self.remote_count += 1
logger.info('Remote actor count: {}'.format(self.remote_count))
remote_thread = threading.Thread(
target=self.run_remote_sample, args=(params_queue, ))
remote_thread.setDaemon(True)
remote_thread.start()
logger.info('All remote actors are ready, begin to learn.')
self.start_time = time.time()
def run_remote_sample(self, params_queue):
remote_actor = Actor(self.config)
cnt = 0
while True:
latest_params = params_queue.get()
remote_actor.set_weights(latest_params)
batch = remote_actor.sample()
self.sample_data_queue.put(batch)
cnt += 1
if cnt % self.config['get_remote_metrics_interval'] == 0:
metrics = remote_actor.get_metrics()
if metrics:
self.remote_metrics_queue.put(metrics)
def step(self):
latest_params = self.agent.get_weights()
for params_queue in self.params_queues:
params_queue.put(latest_params)
train_batch = defaultdict(list)
for i in range(self.config['actor_num']):
sample_data = self.sample_data_queue.get()
for key, value in sample_data.items():
train_batch[key].append(value)
self.sample_total_steps += len(sample_data['obs'])
for key, value in train_batch.items():
train_batch[key] = np.concatenate(value)
train_batch[key] = torch.tensor(train_batch[key]).float()
if self.cuda:
train_batch[key] = train_batch[key].cuda()
with self.learn_time_stat:
total_loss, pi_loss, vf_loss, entropy, lr, entropy_coeff = self.agent.learn(
obs=train_batch['obs'],
actions=train_batch['actions'],
advantages=train_batch['advantages'],
target_values=train_batch['target_values'],
)
self.total_loss_stat.add(total_loss.item())
self.pi_loss_stat.add(pi_loss.item())
self.vf_loss_stat.add(vf_loss.item())
self.entropy_stat.add(entropy.item())
self.lr = lr
self.entropy_coeff = entropy_coeff
def log_metrics(self):
""" Log metrics of learner and actors
"""
if self.start_time is None:
return
metrics = []
while True:
try:
metric = self.remote_metrics_queue.get_nowait()
metrics.append(metric)
except queue.Empty:
break
episode_rewards, episode_steps = [], []
for x in metrics:
episode_rewards.extend(x['episode_rewards'])
episode_steps.extend(x['episode_steps'])
max_episode_rewards, mean_episode_rewards, min_episode_rewards, \
max_episode_steps, mean_episode_steps, min_episode_steps =\
None, None, None, None, None, None
if episode_rewards:
mean_episode_rewards = np.mean(np.array(episode_rewards).flatten())
max_episode_rewards = np.max(np.array(episode_rewards).flatten())
min_episode_rewards = np.min(np.array(episode_rewards).flatten())
mean_episode_steps = np.mean(np.array(episode_steps).flatten())
max_episode_steps = np.max(np.array(episode_steps).flatten())
min_episode_steps = np.min(np.array(episode_steps).flatten())
metric = {
'Sample steps': self.sample_total_steps,
'max_episode_rewards': max_episode_rewards,
'mean_episode_rewards': mean_episode_rewards,
'min_episode_rewards': min_episode_rewards,
'max_episode_steps': max_episode_steps,
'mean_episode_steps': mean_episode_steps,
'min_episode_steps': min_episode_steps,
'total_loss': self.total_loss_stat.mean,
'pi_loss': self.pi_loss_stat.mean,
'vf_loss': self.vf_loss_stat.mean,
'entropy': self.entropy_stat.mean,
'learn_time_s': self.learn_time_stat.mean,
'elapsed_time_s': int(time.time() - self.start_time),
'lr': self.lr,
'entropy_coeff': self.entropy_coeff,
}
if metric['mean_episode_rewards'] is not None:
tensorboard.add_scalar('train/mean_reward',
metric['mean_episode_rewards'],
self.sample_total_steps)
tensorboard.add_scalar('train/total_loss', metric['total_loss'],
self.sample_total_steps)
tensorboard.add_scalar('train/pi_loss', metric['pi_loss'],
self.sample_total_steps)
tensorboard.add_scalar('train/vf_loss', metric['vf_loss'],
self.sample_total_steps)
tensorboard.add_scalar('train/entropy', metric['entropy'],
self.sample_total_steps)
tensorboard.add_scalar('train/learn_rate', metric['lr'],
self.sample_total_steps)
logger.info(metric)
def should_stop(self):
return self.sample_total_steps >= self.config['max_sample_steps']
if __name__ == '__main__':
from a2c_config import config
cuda = torch.cuda.is_available()
learner = Learner(config, cuda)
assert config['log_metrics_interval_s'] > 0
while not learner.should_stop():
start = time.time()
while time.time() - start < config['log_metrics_interval_s']:
learner.step()
learner.log_metrics()
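# To run locally (a sketch, assuming an xparl cluster on the master_address from
# a2c_config.py; the CPU count is illustrative):
#   xparl start --port 8010 --cpu_num 5
#   python train.py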
......@@ -22,10 +22,10 @@ import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from parl.core.torch.agent import Agent
import parl
class AtariAgent(Agent):
class AtariAgent(parl.Agent):
"""Base class of the Agent.
Args:
......
......@@ -16,10 +16,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from parl.core.torch.model import Model
import parl
class AtariModel(Model):
class AtariModel(parl.Model):
"""CNN network used in TensorPack examples.
Args:
......
......@@ -14,3 +14,4 @@
from parl.algorithms.torch.ddqn import *
from parl.algorithms.torch.dqn import *
from parl.algorithms.torch.a2c import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch.distributions import Categorical
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from random import random, randint
import parl
from parl.utils.scheduler import PiecewiseScheduler, LinearDecayScheduler
__all__ = ['A2C']
class A2C(parl.Algorithm):
def __init__(self, model, config, hyperparas=None):
assert isinstance(config['vf_loss_coeff'], (int, float))
self.model = model
self.vf_loss_coeff = config['vf_loss_coeff']
self.optimizer = optim.Adam(
self.model.parameters(), lr=config['learning_rate'])
self.config = config
self.lr_scheduler = LinearDecayScheduler(config['start_lr'],
config['max_sample_steps'])
self.entropy_coeff_scheduler = PiecewiseScheduler(
config['entropy_coeff_scheduler'])
def learn(self, obs, actions, advantages, target_values):
prob = self.model.policy(obs, softmax_dim=1)
policy_distri = Categorical(prob)
actions_log_probs = policy_distri.log_prob(actions)
# The policy gradient loss
pi_loss = -((actions_log_probs * advantages).sum())
# The value function loss
values = self.model.value(obs).reshape(-1)
delta = values - target_values
vf_loss = 0.5 * torch.mul(delta, delta).sum()
# The entropy loss (we want to maximize entropy, so entropy_coeff < 0)
policy_entropy = policy_distri.entropy()
entropy = policy_entropy.sum()
lr = self.lr_scheduler.step(step_num=obs.shape[0])
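# the learning rate decays linearly from start_lr over max_sample_steps;
# step_num advances the schedule by the number of samples in this batch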
entropy_coeff = self.entropy_coeff_scheduler.step()
total_loss = pi_loss + vf_loss * self.vf_loss_coeff + entropy * entropy_coeff
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
total_loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
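# gradients are cleared after the update so they start from zero on the
# next learn() call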
return total_loss, pi_loss, vf_loss, entropy, lr, entropy_coeff
def sample(self, obs):
prob, values = self.model.policy_and_value(obs)
sample_actions = Categorical(prob).sample()
return sample_actions, values
def predict(self, obs):
prob = self.model.policy(obs)
_, predict_actions = prob.max(-1)
return predict_actions
def value(self, obs):
values = self.model.value(obs)
return values
......@@ -19,13 +19,13 @@ import copy
import torch
import torch.optim as optim
import torch.nn.functional as F
from parl.core.torch.algorithm import Algorithm
import parl
import numpy as np
__all__ = ['DDQN']
class DDQN(Algorithm):
class DDQN(parl.Algorithm):
def __init__(self, model, gamma=None, lr=None):
""" Double DQN algorithm
......
......@@ -19,13 +19,13 @@ import copy
import torch
import torch.optim as optim
import torch.nn.functional as F
from parl.core.torch.algorithm import Algorithm
import parl
import numpy as np
__all__ = ['DQN']
class DQN(Algorithm):
class DQN(parl.Algorithm):
def __init__(self, model, gamma=None, lr=None):
""" DQN algorithm
......
......@@ -39,7 +39,7 @@ class AlgorithmBase(object):
Args:
model_ids (List/Set): list/set of model_id, will only return weights of models
whiose model_id in the `model_ids`.
whose model_id in the `model_ids`.
Returns:
Dict of weights ({attribute name: numpy array/List/Dict})
......
......@@ -667,13 +667,8 @@ class ModelBaseTest(unittest.TestCase):
params = self.model.get_weights()
try:
with self.assertRaises(AssertionError):
self.model.set_weights(params[1:])
except:
# expected
return
assert False
def test_set_weights_with_wrong_params_shape(self):
pred_program = fluid.Program()
......@@ -691,14 +686,9 @@ class ModelBaseTest(unittest.TestCase):
x = np.random.random(size=(1, 4)).astype('float32')
try:
outputs = self.executor.run(
with self.assertRaises(fluid.core_avx.EnforceNotMet):
self.executor.run(
pred_program, feed={'obs': x}, fetch_list=[model_output])
except:
# expected
return
assert False
if __name__ == '__main__':
......
......@@ -52,7 +52,7 @@ class Agent(AgentBase):
Public Functions:
- ``sample``: return a noisy action to perform exploration according to the policy.
- ``predict``: return an estimate Q function given current observation.
- ``learn``: update the parameters of self.alg.
- ``learn``: update the parameters of self.algorithm.
- ``save``: save parameters of the ``agent`` to a given path.
- ``restore``: restore previous saved parameters from a given path.
......@@ -60,21 +60,17 @@ class Agent(AgentBase):
- allow users to get parameters of a specified model by specifying the model's name in ``get_weights()``.
"""
def __init__(self, algorithm, device):
def __init__(self, algorithm):
""".
Args:
algorithm (parl.Algorithm): an instance of `parl.Algorithm`. This algorithm is then passed to `self.alg`.
algorithm (parl.Algorithm): an instance of `parl.Algorithm`. This algorithm is then passed to `self.algorithm`.
device (torch.device): specify which GPU/CPU to be used.
"""
assert isinstance(algorithm, Algorithm)
super(Agent, self).__init__(algorithm)
self.alg = algorithm
self.device = torc.device('cuda' if torch.cuda.
is_available() else 'cpu')
def learn(self, *args, **kwargs):
"""The training interface for ``Agent``.
......@@ -102,10 +98,10 @@ class Agent(AgentBase):
Args:
save_path(str): where to save the parameters.
model(parl.Model): model that describes the neural network structure. If None, will use self.alg.model.
model(parl.Model): model that describes the neural network structure. If None, will use self.algorithm.model.
Raises:
ValueError: if model is None and self.alg.model does not exist.
ValueError: if model is None and self.algorithm.model does not exist.
Example:
......@@ -116,7 +112,7 @@ class Agent(AgentBase):
"""
if model is None:
model = self.alg.model
model = self.algorithm.model
dirname = '/'.join(save_path.split('/')[:-1])
if not os.path.exists(dirname):
os.makedirs(dirname)
......@@ -129,10 +125,10 @@ class Agent(AgentBase):
Args:
save_path(str): path where parameters were previously saved.
model(parl.Model): model that describes the neural network structure. If None, will use self.alg.model.
model(parl.Model): model that describes the neural network structure. If None, will use self.algorithm.model.
Raises:
ValueError: if model is None and self.alg does not exist.
ValueError: if model is None and self.algorithm does not exist.
Example:
......@@ -145,6 +141,6 @@ class Agent(AgentBase):
"""
if model is None:
model = self.alg.model
model = self.algorithm.model
checkpoint = torch.load(save_path)
model.load_state_dict(checkpoint)
......@@ -62,7 +62,7 @@ class Algorithm(AlgorithmBase):
assert isinstance(model, Model)
self.model = model
def get_weights(self):
def get_weights(self, model_ids=None):
""" Get weights of self.model.
Returns:
......@@ -71,7 +71,7 @@ class Algorithm(AlgorithmBase):
"""
return self.model.get_weights()
def set_weights(self, params):
def set_weights(self, params, model_ids=None):
""" Set weights from ``get_weights`` to the model.
Args:
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from parl.core.model_base import ModelBase
......@@ -116,7 +117,10 @@ class Model(nn.Module, ModelBase):
Returns: a Python list containing the parameters of current model.
"""
return list(self.parameters())
weights = self.state_dict()
for key in weights.keys():
weights[key] = weights[key].cpu().numpy()
return weights
def set_weights(self, weights):
"""Copy parameters from ``set_weights()`` to the model.
......@@ -124,8 +128,6 @@ class Model(nn.Module, ModelBase):
Args:
weights (list): a Python list containing the parameters.
"""
assert len(weights) == len(list(self.parameters())), \
'size of input weights should be same as weights number of current model'
for var, weight in zip(self.parameters(), weights):
var.data.copy_(weight.data)
for key in weights.keys():
weights[key] = torch.from_numpy(weights[key])
self.load_state_dict(weights)
......@@ -37,7 +37,7 @@ class TestModel(parl.Model):
class TestAlgorithm(parl.Algorithm):
def __init__(self, model):
self.model = model
super(TestAlgorithm, self).__init__(model)
self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
def predict(self, obs):
......@@ -54,13 +54,13 @@ class TestAlgorithm(parl.Algorithm):
class TestAgent(parl.Agent):
def __init__(self, algorithm):
self.alg = algorithm
super(TestAgent, self).__init__(algorithm)
def learn(self, obs, label):
cost = self.alg.learn(obs, label)
cost = self.algorithm.learn(obs, label)
def predict(self, obs):
return self.alg.predict(obs)
return self.algorithm.predict(obs)
class AgentBaseTest(unittest.TestCase):
......@@ -95,6 +95,11 @@ class AgentBaseTest(unittest.TestCase):
current_output = agent.predict(obs).detach().cpu().numpy()
np.testing.assert_equal(current_output, previous_output)
def test_weights(self):
agent = TestAgent(self.alg)
weight = agent.get_weights()
agent.set_weights(weight)
if __name__ == '__main__':
unittest.main()
......@@ -16,6 +16,7 @@ import numpy as np
import unittest
import os
from copy import deepcopy
from collections import OrderedDict
import torch
import torch.nn as nn
......@@ -44,6 +45,7 @@ class ModelBaseTest(unittest.TestCase):
self.model = TestModel()
self.target_model = TestModel()
self.target_model2 = TestModel()
self.target_model3 = TestModel()
gpu_count = get_gpu_count()
device = torch.device('cuda' if gpu_count else 'cpu')
......@@ -282,22 +284,18 @@ class ModelBaseTest(unittest.TestCase):
params = self.model.get_weights()
expected_params = list(self.model.parameters())
self.assertEqual(len(params), len(expected_params))
for param in params:
flag = False
for expected_param in expected_params:
if param.sum().item() - expected_param.sum().item() < 1e-5:
flag = True
break
self.assertTrue(flag)
for i, key in enumerate(params):
self.assertLess(
(params[key].sum().item() - expected_params[i].sum().item()),
1e-5)
def test_set_weights(self):
params = self.model.get_weights()
new_params = [x + 1.0 for x in params]
self.target_model3.set_weights(params)
self.model.set_weights(new_params)
for x, y in list(zip(new_params, self.model.get_weights())):
self.assertEqual(x.sum().item(), y.sum().item())
for i, j in zip(params.values(),
self.target_model3.get_weights().values()):
self.assertLessEqual(abs(i.sum().item() - j.sum().item()), 1e-3)
def test_set_weights_between_different_models(self):
model1 = TestModel()
......@@ -323,20 +321,14 @@ class ModelBaseTest(unittest.TestCase):
def test_set_weights_wrong_params_num(self):
params = self.model.get_weights()
try:
with self.assertRaises(TypeError):
self.model.set_weights(params[1:])
except:
return
assert False
def test_set_weights_wrong_params_shape(self):
params = self.model.get_weights()
params.reverse()
try:
params['fc1.weight'] = params['fc2.bias']
with self.assertRaises(RuntimeError):
self.model.set_weights(params)
except:
return
assert False
if __name__ == '__main__':
......
......@@ -53,20 +53,12 @@ class TestScheduler(unittest.TestCase):
assert value == 0.3
def test_PiecewiseScheduler_with_empty(self):
try:
with self.assertRaises(AssertionError):
scheduler = PiecewiseScheduler([])
except AssertionError:
# expected
return
assert False
def test_PiecewiseScheduler_with_incorrect_steps(self):
try:
scheduler = PiecewiseScheduler([(10, 0.1), (1, 0.2)])
except AssertionError:
# expected
return
assert False
with self.assertRaises(AssertionError):
tscheduler = PiecewiseScheduler([(10, 0.1), (1, 0.2)])
def test_LinearDecayScheduler(self):
scheduler = LinearDecayScheduler(start_value=10, max_steps=10)
......