From 757cc3916dd6d3448a53a007fdddd233ed0215a1 Mon Sep 17 00:00:00 2001 From: fuyw Date: Wed, 25 Sep 2019 20:40:41 +0800 Subject: [PATCH] torchdqn (#150) * git commit -m torchdqn * yapf * fix bugs * fix bugs * fix bugs * yapf * remove fstring format * torch_test yapf * yapf * Add torch in unittest.requirements * update torch_unittest * Torch and FLUID conflict problem in __init__.py * Unittest fail for torch when both torch and fluid exists. * cluster_test fail in the unittest, add timeout seconds. * Torch backend for PARL * add sleep time for unit test send_job_test.py * Unit test for send_job_test.py * use multiple try for unit test * Fix compatibility for python2.7. * fix send_job_test.py bugs * check file exist before send_job_test.py * Modify send_job_test.py --- .teamcity/requirements.txt | 1 + benchmark/torch/dqn/agent.py | 104 ++++++ benchmark/torch/dqn/atari.py | 155 ++++++++ benchmark/torch/dqn/atari_wrapper.py | 102 ++++++ benchmark/torch/dqn/model.py | 69 ++++ benchmark/torch/dqn/replay_memory.py | 114 ++++++ benchmark/torch/dqn/rom_files/breakout.bin | Bin 0 -> 2048 bytes benchmark/torch/dqn/rom_files/pong.bin | Bin 0 -> 2048 bytes .../torch/dqn/rom_files/space_invaders.bin | Bin 0 -> 4096 bytes benchmark/torch/dqn/train.py | 189 ++++++++++ benchmark/torch/dqn/utils.py | 37 ++ parl/__init__.py | 5 +- parl/algorithms/__init__.py | 4 +- parl/algorithms/torch/__init__.py | 16 + parl/algorithms/torch/ddqn.py | 75 ++++ parl/algorithms/torch/dqn.py | 72 ++++ ...agent_base_test.py => agent_base_test_.py} | 0 ...model_base_test.py => model_base_test_.py} | 0 ...n_test.py => policy_distribution_test_.py} | 0 parl/core/torch/__init__.py | 17 + parl/core/torch/agent.py | 150 ++++++++ parl/core/torch/algorithm.py | 92 +++++ parl/core/torch/model.py | 131 +++++++ parl/core/torch/tests/agent_base_test.py | 102 ++++++ parl/core/torch/tests/agent_base_test_.py | 102 ++++++ parl/core/torch/tests/model_base_test.py | 345 ++++++++++++++++++ parl/remote/client.py | 12 +- parl/remote/job.py | 10 + parl/remote/tests/cluster_test.py | 2 +- parl/remote/tests/reset_job_test.py | 1 + parl/remote/tests/rom/pong.bin | 0 parl/remote/tests/send_job_test.py | 81 ++++ parl/utils/utils.py | 8 +- 33 files changed, 1989 insertions(+), 7 deletions(-) create mode 100644 benchmark/torch/dqn/agent.py create mode 100644 benchmark/torch/dqn/atari.py create mode 100644 benchmark/torch/dqn/atari_wrapper.py create mode 100644 benchmark/torch/dqn/model.py create mode 100644 benchmark/torch/dqn/replay_memory.py create mode 100755 benchmark/torch/dqn/rom_files/breakout.bin create mode 100755 benchmark/torch/dqn/rom_files/pong.bin create mode 100755 benchmark/torch/dqn/rom_files/space_invaders.bin create mode 100644 benchmark/torch/dqn/train.py create mode 100644 benchmark/torch/dqn/utils.py create mode 100644 parl/algorithms/torch/__init__.py create mode 100644 parl/algorithms/torch/ddqn.py create mode 100644 parl/algorithms/torch/dqn.py rename parl/core/fluid/tests/{agent_base_test.py => agent_base_test_.py} (100%) rename parl/core/fluid/tests/{model_base_test.py => model_base_test_.py} (100%) rename parl/core/fluid/tests/{policy_distribution_test.py => policy_distribution_test_.py} (100%) create mode 100644 parl/core/torch/__init__.py create mode 100644 parl/core/torch/agent.py create mode 100644 parl/core/torch/algorithm.py create mode 100644 parl/core/torch/model.py create mode 100644 parl/core/torch/tests/agent_base_test.py create mode 100644 parl/core/torch/tests/agent_base_test_.py create mode 100644 
parl/core/torch/tests/model_base_test.py create mode 100644 parl/remote/tests/rom/pong.bin create mode 100644 parl/remote/tests/send_job_test.py diff --git a/.teamcity/requirements.txt b/.teamcity/requirements.txt index 0b2d8e3..07ab270 100644 --- a/.teamcity/requirements.txt +++ b/.teamcity/requirements.txt @@ -4,3 +4,4 @@ gym details parameterized timeout_decorator +torch==1.2.0 diff --git a/benchmark/torch/dqn/agent.py b/benchmark/torch/dqn/agent.py new file mode 100644 index 0000000..95f383a --- /dev/null +++ b/benchmark/torch/dqn/agent.py @@ -0,0 +1,104 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import gym + +import numpy as np + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F + +from parl.core.torch.agent import Agent + + +class AtariAgent(Agent): + """Base class of the Agent. + + Args: + algorithm (object): Algorithm used by this agent. + args (argparse.Namespace): Model configurations. + device (torch.device): use cpu or gpu. + """ + + def __init__(self, algorithm, act_dim): + assert isinstance(act_dim, int) + self.act_dim = act_dim + self.exploration = 1 + self.global_step = 0 + self.update_target_steps = 10000 // 4 + + self.alg = algorithm + self.device = torch.device('cuda' if torch.cuda. 
+ is_available() else 'cpu') + + def save(self, filepath): + state = { + 'model': self.alg.model.state_dict(), + 'target_model': self.alg.target_model.state_dict(), + 'optimizer': self.alg.optimizer.state_dict(), + 'scheduler': self.alg.scheduler.state_dict(), + 'exploration': self.exploration, + } + torch.save(state, filepath) + + def restore(self, filepath): + checkpoint = torch.load(filepath) + self.exploration = checkpoint['exploration'] + self.alg.model.load_state_dict(checkpoint['model']) + self.alg.target_model.load_state_dict(checkpoint['target_model']) + self.alg.optimizer.load_state_dict(checkpoint['optimizer']) + self.alg.scheduler.load_state_dict(checkpoint['scheduler']) + + def sample(self, obs): + sample = np.random.random() + if sample < self.exploration: + act = np.random.randint(self.act_dim) + else: + if np.random.random() < 0.01: + act = np.random.randint(self.act_dim) + else: + pred_q = self.predict(obs) + act = pred_q.max(1)[1].item() + self.exploration = max(0.1, self.exploration - 1e-6) + return act + + def predict(self, obs): + obs = np.expand_dims(obs, 0) + obs = torch.tensor(obs, dtype=torch.float, device=self.device) + pred_q = self.alg.predict(obs) + return pred_q + + def learn(self, obs, act, reward, next_obs, terminal): + if self.global_step % self.update_target_steps == 0: + self.alg.sync_target() + self.global_step += 1 + + act = np.expand_dims(act, -1) + terminal = np.expand_dims(terminal, -1) + reward = np.expand_dims(reward, -1) + reward = np.clip(reward, -1, 1) + + obs = torch.tensor(obs, dtype=torch.float, device=self.device) + next_obs = torch.tensor( + next_obs, dtype=torch.float, device=self.device) + act = torch.tensor(act, dtype=torch.long, device=self.device) + reward = torch.tensor(reward, dtype=torch.float, device=self.device) + terminal = torch.tensor( + terminal, dtype=torch.float, device=self.device) + + cost = self.alg.learn(obs, act, reward, next_obs, terminal) + return cost diff --git a/benchmark/torch/dqn/atari.py b/benchmark/torch/dqn/atari.py new file mode 100644 index 0000000..fa72788 --- /dev/null +++ b/benchmark/torch/dqn/atari.py @@ -0,0 +1,155 @@ +# Third party code +# +# The following code are copied or modified from: +# https://github.com/tensorpack/tensorpack/blob/master/examples/DeepQNetwork/atari.py + +import cv2 +import gym +import numpy as np +import os +import threading +from atari_py import ALEInterface +from gym import spaces +from gym.envs.atari.atari_env import ACTION_MEANING + +__all__ = ['AtariPlayer'] + +ROM_URL = "https://github.com/openai/atari-py/tree/master/atari_py/atari_roms" +_ALE_LOCK = threading.Lock() + + +class AtariPlayer(gym.Env): + """ + A wrapper for ALE emulator, with configurations to mimic DeepMind DQN settings. + Info: + score: the accumulated reward in the current game + gameOver: True when the current game is Over + """ + + def __init__(self, + rom_file, + viz=0, + frame_skip=4, + nullop_start=30, + live_lost_as_eoe=True, + max_num_frames=0): + """ + Args: + rom_file: path to the rom + frame_skip: skip every k frames and repeat the action + viz: visualization to be done. + Set to 0 to disable. + Set to a positive number to be the delay between frames to show. + Set to a string to be a directory to store frames. + nullop_start: start with random number of null ops. + live_losts_as_eoe: consider lost of lives as end of episode. Useful for training. + max_num_frames: maximum number of frames per episode. 
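            Example (a sketch; the rom path follows the benchmark's rom_files layout):
                player = AtariPlayer('rom_files/pong.bin', frame_skip=4, live_lost_as_eoe=True)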
+ """ + super(AtariPlayer, self).__init__() + assert os.path.isfile(rom_file), \ + "rom {} not found. Please download at {}".format(rom_file, ROM_URL) + + try: + ALEInterface.setLoggerMode(ALEInterface.Logger.Error) + except AttributeError: + print("You're not using latest ALE") + + # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86 + with _ALE_LOCK: + self.ale = ALEInterface() + self.ale.setInt(b"random_seed", np.random.randint(0, 30000)) + self.ale.setInt(b"max_num_frames_per_episode", max_num_frames) + self.ale.setBool(b"showinfo", False) + + self.ale.setInt(b"frame_skip", 1) + self.ale.setBool(b'color_averaging', False) + # manual.pdf suggests otherwise. + self.ale.setFloat(b'repeat_action_probability', 0.0) + + # viz setup + if isinstance(viz, str): + assert os.path.isdir(viz), viz + self.ale.setString(b'record_screen_dir', viz) + viz = 0 + if isinstance(viz, int): + viz = float(viz) + self.viz = viz + if self.viz and isinstance(self.viz, float): + self.windowname = os.path.basename(rom_file) + cv2.startWindowThread() + cv2.namedWindow(self.windowname) + + self.ale.loadROM(rom_file.encode('utf-8')) + self.width, self.height = self.ale.getScreenDims() + self.actions = self.ale.getMinimalActionSet() + + self.live_lost_as_eoe = live_lost_as_eoe + self.frame_skip = frame_skip + self.nullop_start = nullop_start + + self.action_space = spaces.Discrete(len(self.actions)) + self.observation_space = spaces.Box( + low=0, high=255, shape=(self.height, self.width), dtype=np.uint8) + self._restart_episode() + + def get_action_meanings(self): + return [ACTION_MEANING[i] for i in self.actions] + + def _grab_raw_image(self): + """ + :returns: the current 3-channel image + """ + m = self.ale.getScreenRGB() + return m.reshape((self.height, self.width, 3)) + + def _current_state(self): + """ + returns: a gray-scale (h, w) uint8 image + """ + ret = self._grab_raw_image() + # avoid missing frame issue: max-pooled over the last screen + ret = np.maximum(ret, self.last_raw_screen) + if self.viz: + if isinstance(self.viz, float): + cv2.imshow(self.windowname, ret) + cv2.waitKey(int(self.viz * 1000)) + ret = ret.astype('float32') + # 0.299,0.587.0.114. 
same as rgb2y in torch/image + ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY) + return ret.astype('uint8') # to save some memory + + def _restart_episode(self): + with _ALE_LOCK: + self.ale.reset_game() + + # random null-ops start + n = np.random.randint(self.nullop_start) + self.last_raw_screen = self._grab_raw_image() + for k in range(n): + if k == n - 1: + self.last_raw_screen = self._grab_raw_image() + self.ale.act(0) + + def reset(self): + if self.ale.game_over(): + self._restart_episode() + return self._current_state() + + def step(self, act): + oldlives = self.ale.lives() + r = 0 + for k in range(self.frame_skip): + if k == self.frame_skip - 1: + self.last_raw_screen = self._grab_raw_image() + r += self.ale.act(self.actions[act]) + newlives = self.ale.lives() + if self.ale.game_over() or \ + (self.live_lost_as_eoe and newlives < oldlives): + break + + isOver = self.ale.game_over() + if self.live_lost_as_eoe: + isOver = isOver or newlives < oldlives + + info = {'ale.lives': newlives} + return self._current_state(), r, isOver, info diff --git a/benchmark/torch/dqn/atari_wrapper.py b/benchmark/torch/dqn/atari_wrapper.py new file mode 100644 index 0000000..f405162 --- /dev/null +++ b/benchmark/torch/dqn/atari_wrapper.py @@ -0,0 +1,102 @@ +# Third party code +# +# The following code are copied or modified from: +# https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py + +import gym +import numpy as np +from collections import deque +from gym import spaces + +_v0, _v1 = gym.__version__.split('.')[:2] +assert int(_v0) > 0 or int(_v1) >= 10, gym.__version__ + + +class MapState(gym.ObservationWrapper): + def __init__(self, env, map_func): + gym.ObservationWrapper.__init__(self, env) + self._func = map_func + + def observation(self, obs): + return self._func(obs) + + +class FrameStack(gym.Wrapper): + def __init__(self, env, k): + """Buffer observations and stack across channels (last axis).""" + gym.Wrapper.__init__(self, env) + self.k = k + self.frames = deque([], maxlen=k) + shp = env.observation_space.shape + chan = 1 if len(shp) == 2 else shp[2] + self.observation_space = spaces.Box( + low=0, high=255, shape=(shp[0], shp[1], chan * k), dtype=np.uint8) + + def reset(self): + """Clear buffer and re-fill by duplicating the first observation.""" + ob = self.env.reset() + for _ in range(self.k - 1): + self.frames.append(np.zeros_like(ob)) + self.frames.append(ob) + return self.observation() + + def step(self, action): + ob, reward, done, info = self.env.step(action) + self.frames.append(ob) + return self.observation(), reward, done, info + + def observation(self): + assert len(self.frames) == self.k + return np.stack(self.frames, axis=0) + + +class _FireResetEnv(gym.Wrapper): + def __init__(self, env): + """Take action on reset for environments that are fixed until firing.""" + gym.Wrapper.__init__(self, env) + assert env.unwrapped.get_action_meanings()[1] == 'FIRE' + assert len(env.unwrapped.get_action_meanings()) >= 3 + + def reset(self): + self.env.reset() + obs, _, done, _ = self.env.step(1) + if done: + self.env.reset() + obs, _, done, _ = self.env.step(2) + if done: + self.env.reset() + return obs + + def step(self, action): + return self.env.step(action) + + +def FireResetEnv(env): + if isinstance(env, gym.Wrapper): + baseenv = env.unwrapped + else: + baseenv = env + if 'FIRE' in baseenv.get_action_meanings(): + return _FireResetEnv(env) + return env + + +class LimitLength(gym.Wrapper): + def __init__(self, env, k): + gym.Wrapper.__init__(self, env) + self.k = k 
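        # The wrappers in this module are usually chained together. A minimal sketch
        # (mirroring utils.get_player further below; the LimitLength step and its cap
        # are illustrative only, not part of the benchmark pipeline):
        #   env = AtariPlayer('rom_files/pong.bin', frame_skip=4)
        #   env = FireResetEnv(env)
        #   env = MapState(env, lambda im: cv2.resize(im, (84, 84)))
        #   env = FrameStack(env, 4)
        #   env = LimitLength(env, 30000)   # optional cap on episode length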
+ + def reset(self): + # This assumes that reset() will really reset the env. + # If the underlying env tries to be smart about reset + # (e.g. end-of-life), the assumption doesn't hold. + ob = self.env.reset() + self.cnt = 0 + return ob + + def step(self, action): + ob, r, done, info = self.env.step(action) + self.cnt += 1 + if self.cnt == self.k: + done = True + return ob, r, done, info diff --git a/benchmark/torch/dqn/model.py b/benchmark/torch/dqn/model.py new file mode 100644 index 0000000..8ba80d5 --- /dev/null +++ b/benchmark/torch/dqn/model.py @@ -0,0 +1,69 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from parl.core.torch.model import Model + + +class AtariModel(Model): + """CNN network used in TensorPack examples. + + Args: + input_channel (int): Input channel of states. + act_dim (int): Dimension of action space. + algo (str): which ('DQN', 'Double', 'Dueling') model to use. + """ + + def __init__(self, input_channel, act_dim, algo='DQN'): + super(AtariModel, self).__init__() + self.conv1 = nn.Conv2d( + input_channel, 32, kernel_size=8, stride=4, padding=2) + self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=2) + self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1) + + self.algo = algo + if self.algo == 'Dueling': + self.fc1_adv = nn.Linear(7744, 512) + self.fc1_val = nn.Linear(7744, 512) + self.fc2_adv = nn.Linear(512, act_dim) + self.fc2_val = nn.Linear(512, 1) + else: + self.fc1 = nn.Linear(7744, 512) + self.fc2 = nn.Linear(512, act_dim) + + self.reset_params() + + def reset_params(self): + for m in self.modules(): + if isinstance(m, (nn.Linear, nn.Conv2d)): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + nn.init.zeros_(m.bias) + + def forward(self, x): + x = x / 255.0 + x = F.relu(self.conv1(x)) + x = F.relu(self.conv2(x)) + x = F.relu(self.conv3(x)) + x = x.view(x.size(0), -1) + if self.algo == 'Dueling': + As = self.fc2_adv(F.relu(self.fc1_adv(x))) + V = self.fc2_val(F.relu(self.fc1_val(x))) + Q = As + (V - As.mean(dim=1, keepdim=True)) + else: + Q = self.fc2(F.relu(self.fc1(x))) + return Q diff --git a/benchmark/torch/dqn/replay_memory.py b/benchmark/torch/dqn/replay_memory.py new file mode 100644 index 0000000..ea8c656 --- /dev/null +++ b/benchmark/torch/dqn/replay_memory.py @@ -0,0 +1,114 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import copy +from collections import deque, namedtuple + +Experience = namedtuple('Experience', ['state', 'action', 'reward', 'isOver']) + + +class ReplayMemory(object): + def __init__(self, max_size, state_shape, context_len): + self.max_size = int(max_size) + self.state_shape = state_shape + self.context_len = int(context_len) + + self.state = np.zeros((self.max_size, ) + state_shape, dtype='uint8') + self.action = np.zeros((self.max_size, ), dtype='int32') + self.reward = np.zeros((self.max_size, ), dtype='float32') + self.isOver = np.zeros((self.max_size, ), dtype='bool') + + self._curr_size = 0 + self._curr_pos = 0 + self._context = deque(maxlen=context_len - 1) + + def append(self, exp): + """append a new experience into replay memory + """ + if self._curr_size < self.max_size: + self._assign(self._curr_pos, exp) + self._curr_size += 1 + else: + self._assign(self._curr_pos, exp) + self._curr_pos = (self._curr_pos + 1) % self.max_size + if exp.isOver: + self._context.clear() + else: + self._context.append(exp) + + def recent_state(self): + """ maintain recent state for training""" + lst = list(self._context) + states = [np.zeros(self.state_shape, dtype='uint8')] * \ + (self._context.maxlen - len(lst)) + states.extend([k.state for k in lst]) + return states + + def sample(self, idx): + """ return state, action, reward, isOver, + note that some frames in state may be generated from last episode, + they should be removed from state + """ + state = np.zeros( + (self.context_len + 1, ) + self.state_shape, dtype=np.uint8) + state_idx = np.arange(idx, + idx + self.context_len + 1) % self._curr_size + + # confirm that no frame was generated from last episode + has_last_episode = False + for k in range(self.context_len - 2, -1, -1): + to_check_idx = state_idx[k] + if self.isOver[to_check_idx]: + has_last_episode = True + state_idx = state_idx[k + 1:] + state[k + 1:] = self.state[state_idx] + break + + if not has_last_episode: + state = self.state[state_idx] + + real_idx = (idx + self.context_len - 1) % self._curr_size + action = self.action[real_idx] + reward = self.reward[real_idx] + isOver = self.isOver[real_idx] + return state, reward, action, isOver + + def __len__(self): + return self._curr_size + + def size(self): + return self._curr_size + + def _assign(self, pos, exp): + self.state[pos] = exp.state + self.reward[pos] = exp.reward + self.action[pos] = exp.action + self.isOver[pos] = exp.isOver + + def sample_batch(self, batch_size): + """sample a batch from replay memory for training + """ + batch_idx = np.random.randint( + self._curr_size - self.context_len - 1, size=batch_size) + batch_idx = (self._curr_pos + batch_idx) % self._curr_size + batch_exp = [self.sample(i) for i in batch_idx] + return self._process_batch(batch_exp) + + def _process_batch(self, batch_exp): + state = np.asarray([e[0] for e in batch_exp], dtype='uint8') + reward = np.asarray([e[1] for e in batch_exp], dtype='float32') + action = np.asarray([e[2] for e in batch_exp], dtype='int8') + isOver = np.asarray([e[3] for e in batch_exp], dtype='bool') + return [state, action, reward, isOver] diff --git a/benchmark/torch/dqn/rom_files/breakout.bin b/benchmark/torch/dqn/rom_files/breakout.bin new file mode 100755 index 0000000000000000000000000000000000000000..abab5a8c0a1890461a11b78d4265f1b794327793 GIT binary patch literal 2048 zcmYjSUu+Y}8Q&eR9orZ@0S?>)!V(N+xmLH$wUAR=C8`p1vGQ1#OV>y^s;Wp; 
[remaining breakout.bin binary patch data omitted]
literal 0
HcmV?d00001

diff --git a/benchmark/torch/dqn/rom_files/pong.bin b/benchmark/torch/dqn/rom_files/pong.bin
new file mode 100755
index 0000000000000000000000000000000000000000..14a5bdfc72548613c059938bdf712efdbb5d3806
GIT binary patch
literal 2048
[pong.bin binary patch data omitted]
literal 0
HcmV?d00001

diff --git a/benchmark/torch/dqn/rom_files/space_invaders.bin b/benchmark/torch/dqn/rom_files/space_invaders.bin
new file mode 100755
index 0000000000000000000000000000000000000000..82afd60dc52dc0a1df97c670b618574d9dfc457f
GIT binary patch
literal 4096
[space_invaders.bin binary patch data omitted]
literal 0
HcmV?d00001

diff --git a/benchmark/torch/dqn/train.py b/benchmark/torch/dqn/train.py
new file mode 100644
index 0000000..9db3b8f
--- /dev/null
+++ b/benchmark/torch/dqn/train.py
@@ -0,0 +1,189 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import cv2 +import gym +import os +import threading +import torch +import parl + +import numpy as np +from tqdm import tqdm +from parl.utils import tensorboard, logger +from parl.algorithms import DQN, DDQN + +from agent import AtariAgent +from atari_wrapper import FireResetEnv, FrameStack, LimitLength, MapState +from model import AtariModel +from replay_memory import ReplayMemory, Experience +from utils import get_player + +MEMORY_SIZE = int(1e6) +MEMORY_WARMUP_SIZE = MEMORY_SIZE // 20 +IMAGE_SIZE = (84, 84) +CONTEXT_LEN = 4 +FRAME_SKIP = 4 +UPDATE_FREQ = 4 +GAMMA = 0.99 + + +def run_train_episode(env, agent, rpm): + total_reward = 0 + all_cost = [] + state = env.reset() + steps = 0 + while True: + steps += 1 + context = rpm.recent_state() + context.append(state) + context = np.stack(context, axis=0) + action = agent.sample(context) + next_state, reward, isOver, _ = env.step(action) + rpm.append(Experience(state, action, reward, isOver)) + if rpm.size() > MEMORY_WARMUP_SIZE: + if steps % UPDATE_FREQ == 0: + batch_all_state, batch_action, batch_reward, batch_isOver = rpm.sample_batch( + args.batch_size) + batch_state = batch_all_state[:, :CONTEXT_LEN, :, :] + batch_next_state = batch_all_state[:, 1:, :, :] + cost = agent.learn(batch_state, batch_action, batch_reward, + batch_next_state, batch_isOver) + all_cost.append(cost) + total_reward += reward + state = next_state + if isOver: + mean_loss = np.mean(all_cost) if all_cost else None + return total_reward, steps, mean_loss + + +def run_evaluate_episode(env, agent): + state = env.reset() + total_reward = 0 + while True: + pred_Q = agent.predict(state) + action = pred_Q.max(1)[1].item() + state, reward, isOver, _ = env.step(action) + total_reward += reward + if isOver: + return total_reward + + +def get_fixed_states(rpm, batch_size): + states = [] + for _ in range(3): + batch_all_state = rpm.sample_batch(batch_size)[0] + batch_state = batch_all_state[:, :CONTEXT_LEN, :, :] + states.append(batch_state) + fixed_states = np.concatenate(states, axis=0) + return fixed_states + + +def evaluate_fixed_Q(agent, states): + with torch.no_grad(): + max_pred_Q = agent.alg.model(states).max(1)[0].mean() + return max_pred_Q.item() + + +def get_grad_norm(model): + total_norm = 0 + for p in model.parameters(): + if p.grad is not None: + param_norm = p.grad.data.norm(2) + total_norm += param_norm.item()**2 + total_norm = total_norm**(1. 
/ 2) + return total_norm + + +def main(): + env = get_player( + args.rom, image_size=IMAGE_SIZE, train=True, frame_skip=FRAME_SKIP) + test_env = get_player( + args.rom, + image_size=IMAGE_SIZE, + frame_skip=FRAME_SKIP, + context_len=CONTEXT_LEN) + rpm = ReplayMemory(MEMORY_SIZE, IMAGE_SIZE, CONTEXT_LEN) + act_dim = env.action_space.n + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = AtariModel(CONTEXT_LEN, act_dim, args.algo) + if args.algo in ['DQN', 'Dueling']: + algorithm = DQN(model, gamma=GAMMA, lr=args.lr) + elif args.algo is 'Double': + algorithm = DDQN(model, gamma=GAMMA, lr=args.lr) + agent = AtariAgent(algorithm, act_dim=act_dim) + + with tqdm( + total=MEMORY_WARMUP_SIZE, desc='[Replay Memory Warm Up]') as pbar: + while rpm.size() < MEMORY_WARMUP_SIZE: + total_reward, steps, _ = run_train_episode(env, agent, rpm) + pbar.update(steps) + + # Get fixed states to check value function. + fixed_states = get_fixed_states(rpm, args.batch_size) + fixed_states = torch.tensor(fixed_states, dtype=torch.float, device=device) + + # train + test_flag = 0 + total_steps = 0 + + with tqdm(total=args.train_total_steps, desc='[Training Model]') as pbar: + while total_steps < args.train_total_steps: + total_reward, steps, loss = run_train_episode(env, agent, rpm) + total_steps += steps + pbar.update(steps) + if total_steps // args.test_every_steps >= test_flag: + while total_steps // args.test_every_steps >= test_flag: + test_flag += 1 + + eval_rewards = [] + for _ in range(3): + eval_rewards.append(run_evaluate_episode(test_env, agent)) + + tensorboard.add_scalar('dqn/eval', np.mean(eval_rewards), + total_steps) + tensorboard.add_scalar('dqn/score', total_reward, total_steps) + tensorboard.add_scalar('dqn/loss', loss, total_steps) + tensorboard.add_scalar('dqn/exploration', agent.exploration, + total_steps) + tensorboard.add_scalar('dqn/Q value', + evaluate_fixed_Q(agent, fixed_states), + total_steps) + tensorboard.add_scalar('dqn/grad_norm', + get_grad_norm(agent.alg.model), + total_steps) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--rom', default='rom_files/breakout.bin') + parser.add_argument( + '--batch_size', type=int, default=32, help='batch size for training') + parser.add_argument('--lr', default=3e-4, help='learning_rate') + parser.add_argument('--algo', default='DQN', help='DQN/Double/Dueling DQN') + parser.add_argument( + '--train_total_steps', + type=int, + default=int(1e7), + help='maximum environmental steps of games') + parser.add_argument( + '--test_every_steps', + type=int, + default=int(1e5), + help='the step interval between two consecutive evaluations') + args = parser.parse_args() + rom_name = args.rom.split('/')[-1].split('.')[0] + logger.set_dir(os.path.join('./train_log', rom_name)) + main() diff --git a/benchmark/torch/dqn/utils.py b/benchmark/torch/dqn/utils.py new file mode 100644 index 0000000..b938819 --- /dev/null +++ b/benchmark/torch/dqn/utils.py @@ -0,0 +1,37 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import cv2 +from atari import AtariPlayer +from atari_wrapper import FrameStack, MapState, FireResetEnv + + +def get_player(rom, + image_size, + viz=False, + train=False, + frame_skip=1, + context_len=1): + env = AtariPlayer( + rom, + frame_skip=frame_skip, + viz=viz, + live_lost_as_eoe=train, + max_num_frames=60000) + env = FireResetEnv(env) + env = MapState(env, lambda im: cv2.resize(im, image_size)) + if not train: + # in training, context is taken care of in expreplay buffer + env = FrameStack(env, context_len) + return env diff --git a/parl/__init__.py b/parl/__init__.py index b62e723..c17f8cd 100644 --- a/parl/__init__.py +++ b/parl/__init__.py @@ -19,10 +19,13 @@ generates new PARL python API import os from tensorboardX import SummaryWriter -from parl.utils.utils import _HAS_FLUID +from parl.utils.utils import _HAS_FLUID, _HAS_TORCH + if _HAS_FLUID: from parl.core.fluid import * from parl.core.fluid.plutils.compiler import compile +elif _HAS_TORCH: + from parl.core.torch import * from parl.remote import remote_class, connect from parl import algorithms diff --git a/parl/algorithms/__init__.py b/parl/algorithms/__init__.py index 60359bb..20c3d3d 100644 --- a/parl/algorithms/__init__.py +++ b/parl/algorithms/__init__.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from parl.utils.utils import _HAS_FLUID +from parl.utils.utils import _HAS_FLUID, _HAS_TORCH if _HAS_FLUID: from parl.algorithms.fluid import * +elif _HAS_TORCH: + from parl.algorithms.torch import * diff --git a/parl/algorithms/torch/__init__.py b/parl/algorithms/torch/__init__.py new file mode 100644 index 0000000..abc70cd --- /dev/null +++ b/parl/algorithms/torch/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from parl.algorithms.torch.ddqn import * +from parl.algorithms.torch.dqn import * diff --git a/parl/algorithms/torch/ddqn.py b/parl/algorithms/torch/ddqn.py new file mode 100644 index 0000000..9b6e271 --- /dev/null +++ b/parl/algorithms/torch/ddqn.py @@ -0,0 +1,75 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
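# The DDQN class below decouples action selection from action evaluation: the
# online model picks the greedy action and the target model scores it. A minimal
# sketch of the update rule implemented in DDQN.learn(), with names following
# that method:
#
#   greedy_action = model(next_obs).max(dim=1, keepdim=True)[1]
#   max_v = target_model(next_obs).gather(1, greedy_action)
#   target = reward + (1 - terminal) * gamma * max_v
#   loss = mse_loss(model(obs).gather(1, action), target)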
+ +import warnings +warnings.simplefilter('default') + +import copy +import torch +import torch.optim as optim +import torch.nn.functional as F +from parl.core.torch.algorithm import Algorithm +import numpy as np + +__all__ = ['DDQN'] + + +class DDQN(Algorithm): + def __init__(self, model, gamma=None, lr=None): + """ Double DQN algorithm + + Args: + model (parl.Model): model defining forward network of Q function. + gamma (float): discounted factor for reward computation. + lr (float): learning rate. + """ + self.model = model + self.target_model = copy.deepcopy(model) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.model.to(device) + self.target_model.to(device) + + assert isinstance(gamma, float) + assert isinstance(lr, float) + self.gamma = gamma + self.lr = lr + + self.mse_loss = torch.nn.MSELoss() + self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr) + + def predict(self, obs): + """ use value model self.model to predict the action value + """ + with torch.no_grad(): + pred_q = self.model(obs) + return pred_q + + def learn(self, obs, action, reward, next_obs, terminal): + """ update value model self.model with Double DQN algorithm + """ + pred_value = self.model(obs).gather(1, action) + # model for selection actions. + greedy_action = self.model(next_obs).max(dim=1, keepdim=True)[1] + with torch.no_grad(): + # target_model for evaluation. + max_v = self.target_model(next_obs).gather(1, greedy_action) + target = reward + (1 - terminal) * self.gamma * max_v + self.optimizer.zero_grad() + loss = self.mse_loss(pred_value, target) + loss.backward() + self.optimizer.step() + return loss.item() + + def sync_target(self): + self.model.sync_weights_to(self.target_model) diff --git a/parl/algorithms/torch/dqn.py b/parl/algorithms/torch/dqn.py new file mode 100644 index 0000000..9244f5d --- /dev/null +++ b/parl/algorithms/torch/dqn.py @@ -0,0 +1,72 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +warnings.simplefilter('default') + +import copy +import torch +import torch.optim as optim +import torch.nn.functional as F +from parl.core.torch.algorithm import Algorithm +import numpy as np + +__all__ = ['DQN'] + + +class DQN(Algorithm): + def __init__(self, model, gamma=None, lr=None): + """ DQN algorithm + + Args: + model (parl.Model): model defining forward network of Q function. + gamma (float): discounted factor for reward computation. + lr (float): learning rate. 
+ """ + self.model = model + self.target_model = copy.deepcopy(model) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.model.to(device) + self.target_model.to(device) + + assert isinstance(gamma, float) + assert isinstance(lr, float) + self.gamma = gamma + self.lr = lr + + self.mse_loss = torch.nn.MSELoss() + self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr) + + def predict(self, obs): + """ use value model self.model to predict the action value + """ + with torch.no_grad(): + pred_q = self.model(obs) + return pred_q + + def learn(self, obs, action, reward, next_obs, terminal): + """ update value model self.model with DQN algorithm + """ + pred_value = self.model(obs).gather(1, action) + with torch.no_grad(): + max_v = self.target_model(next_obs).max(1, keepdim=True)[0] + target = reward + (1 - terminal) * self.gamma * max_v + self.optimizer.zero_grad() + loss = self.mse_loss(pred_value, target) + loss.backward() + self.optimizer.step() + return loss.item() + + def sync_target(self): + self.model.sync_weights_to(self.target_model) diff --git a/parl/core/fluid/tests/agent_base_test.py b/parl/core/fluid/tests/agent_base_test_.py similarity index 100% rename from parl/core/fluid/tests/agent_base_test.py rename to parl/core/fluid/tests/agent_base_test_.py diff --git a/parl/core/fluid/tests/model_base_test.py b/parl/core/fluid/tests/model_base_test_.py similarity index 100% rename from parl/core/fluid/tests/model_base_test.py rename to parl/core/fluid/tests/model_base_test_.py diff --git a/parl/core/fluid/tests/policy_distribution_test.py b/parl/core/fluid/tests/policy_distribution_test_.py similarity index 100% rename from parl/core/fluid/tests/policy_distribution_test.py rename to parl/core/fluid/tests/policy_distribution_test_.py diff --git a/parl/core/torch/__init__.py b/parl/core/torch/__init__.py new file mode 100644 index 0000000..64ea6ed --- /dev/null +++ b/parl/core/torch/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from parl.core.torch.model import * +from parl.core.torch.algorithm import * +from parl.core.torch.agent import * diff --git a/parl/core/torch/agent.py b/parl/core/torch/agent.py new file mode 100644 index 0000000..7e2ef38 --- /dev/null +++ b/parl/core/torch/agent.py @@ -0,0 +1,150 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import warnings +warnings.simplefilter('default') + +import os +import torch + +from parl.core.agent_base import AgentBase +from parl.core.torch.algorithm import Algorithm +from parl.utils import machine_info + +__all__ = ['Agent'] +torch.set_num_threads(1) + + +class Agent(AgentBase): + """ + | `alias`: ``parl.Agent`` + | `alias`: ``parl.core.torch.agent.Agent`` + + | Agent is one of the three basic classes of PARL. + + | It is responsible for interacting with the environment and collecting data for training the policy. + | To implement a customized ``Agent``, users can: + + .. code-block:: python + + import parl + + class MyAgent(parl.Agent): + def __init__(self, algorithm, act_dim): + super(MyAgent, self).__init__(algorithm) + self.act_dim = act_dim + + Attributes: + device (torch.device): select GPU/CPU to be used. + alg (parl.Algorithm): algorithm of this agent. + + Public Functions: + - ``sample``: return a noisy action to perform exploration according to the policy. + - ``predict``: return an estimate Q function given current observation. + - ``learn``: update the parameters of self.alg. + - ``save``: save parameters of the ``agent`` to a given path. + - ``restore``: restore previous saved parameters from a given path. + + Todo: + - allow users to get parameters of a specified model by specifying the model's name in ``get_weights()``. + """ + + def __init__(self, algorithm, device): + """. + + Args: + algorithm (parl.Algorithm): an instance of `parl.Algorithm`. This algorithm is then passed to `self.alg`. + device (torch.device): specify which GPU/CPU to be used. + """ + + assert isinstance(algorithm, Algorithm) + super(Agent, self).__init__(algorithm) + + self.alg = algorithm + self.device = torc.device('cuda' if torch.cuda. + is_available() else 'cpu') + + def learn(self, *args, **kwargs): + """The training interface for ``Agent``. + + It is often used in the training stage. + """ + raise NotImplementedError + + def predict(self, *args, **kwargs): + """Predict an estimated Q value when given the observation of the environment. + + It is often used in the evaluation stage. + """ + raise NotImplementedError + + def sample(self, *args, **kwargs): + """Return an action with noise when given the observation of the environment. + + In general, this function is used in train process as noise is added to the action to preform exploration. + + """ + raise NotImplementedError + + def save(self, save_path, model=None): + """Save parameters. + + Args: + save_path(str): where to save the parameters. + model(parl.Model): model that describes the neural network structure. If None, will use self.alg.model. + + Raises: + ValueError: if model is None and self.alg.model does not exist. + + Example: + + .. code-block:: python + + agent = AtariAgent() + agent.save('./model.ckpt') + + """ + if model is None: + model = self.alg.model + dirname = '/'.join(save_path.split('/')[:-1]) + if not os.path.exists(dirname): + os.makedirs(dirname) + torch.save(model.state_dict(), save_path) + + def restore(self, save_path, model=None): + """Restore previously saved parameters. + This method requires a model that describes the network structure. + The save_path argument is typically a value previously passed to ``save()``. + + Args: + save_path(str): path where parameters were previously saved. + model(parl.Model): model that describes the neural network structure. If None, will use self.alg.model. + + Raises: + ValueError: if model is None and self.alg does not exist. + + Example: + + .. 
code-block:: python + + agent = AtariAgent() + agent.save('./model.ckpt') + agent.restore('./model.ckpt') + + """ + + if model is None: + model = self.alg.model + checkpoint = torch.load(save_path) + model.load_state_dict(checkpoint) diff --git a/parl/core/torch/algorithm.py b/parl/core/torch/algorithm.py new file mode 100644 index 0000000..d953688 --- /dev/null +++ b/parl/core/torch/algorithm.py @@ -0,0 +1,92 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +warnings.simplefilter('default') + +from parl.core.algorithm_base import AlgorithmBase +from parl.core.torch.model import Model + +__all__ = ['Algorithm'] + + +class Algorithm(AlgorithmBase): + """ + | `alias`: ``parl.Algorithm`` + | `alias`: ``parl.core.torch.algorithm.Algorithm`` + + | ``Algorithm`` defines the way how to update the parameters of the + ``Model``. This is where we define loss functions and the optimizer of the + neural network. An ``Algorithm`` has at least a model. + + | PARL has implemented various algorithms(DQN/DDPG/PPO/A3C/IMPALA) that can + be reused quickly, which can be accessed with ``parl.algorithms``. + + Example: + + .. code-block:: python + + import parl + + model = Model() + dqn = parl.algorithms.DQN(model, lr=1e-3) + + Attributes: + model(``parl.Model``): a neural network that represents a policy or a + Q-value function. + + Pulic Functions: + - ``predict``: return an estimate q value given current observation. + - ``learn``: define the loss function and create an optimizer to + minimize the loss. + + """ + + def __init__(self, model=None): + """ + Args: + model(``parl.Model``): a neural network that represents a policy or + a Q-value function. + """ + assert isinstance(model, Model) + self.model = model + + def get_weights(self): + """ Get weights of self.model. + + Returns: + weights (list): a Python List containing the parameters of + self.model. + """ + return self.model.get_weights() + + def set_weights(self, params): + """ Set weights from ``get_weights`` to the model. + + Args: + weights (list): a Python List containing the parameters of + self.model. + """ + self.model.set_weights(params) + + def learn(self, *args, **kwargs): + """ Define the loss function and create an optimizer to minimize the loss. + """ + raise NotImplementedError + + def predict(self, *args, **kwargs): + """ Refine the predicting process, e.g,. use the policy model to + predict actions. + """ + raise NotImplementedError diff --git a/parl/core/torch/model.py b/parl/core/torch/model.py new file mode 100644 index 0000000..4827cfa --- /dev/null +++ b/parl/core/torch/model.py @@ -0,0 +1,131 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.nn as nn + +from parl.core.model_base import ModelBase +from parl.utils import machine_info + +__all__ = ['Model'] + + +class Model(nn.Module, ModelBase): + """ + | `alias`: ``parl.Model`` + | `alias`: ``parl.core.torch.agent.Model`` + + | ``Model`` is a base class of PARL for the neural network. A ``Model`` is + usually a policy or Q-value function, which predicts an action or an + estimate according to the environmental observation. + + | To use the ``PyTorch`` backend model, user needs to call ``super(Model, + self).__init__()`` at the beginning of ``__init__`` function. + + | ``Model`` supports duplicating a ``Model`` instance in a pythonic way: + + | ``copied_model = copy.deepcopy(model)`` + + Example: + + .. code-block:: python + + import parl + import torch.nn as nn + + class Policy(parl.Model): + def __init__(self): + super(Policy, self).__init__() + self.fc = nn.Linear(in_features=100, out_features=32) + + def policy(self, obs): + out = self.fc(obs) + return out + + policy = Policy() + copied_policy = copy.deepcopy(model) + + Attributes: + model_id(str): each model instance has its unique model_id. + + Public Functions: + - ``sync_weights_to``: synchronize parameters of the current model to + another model. + - ``get_weights``: return a list containing all the parameters of the + current model. + - ``set_weights``: copy parameters from ``set_weights()`` to the model. + - ``forward``: define the computations of a neural network. **Should** + be overridden by all subclasses. + + """ + + def __init___(self): + super(Model, self).__init__() + + def sync_weights_to(self, target_model, decay=0.0): + """Synchronize parameters of current model to another model. + + target_model_weights = decay * target_model_weights + (1 - decay) * + current_model_weights + + Args: + target_model (`parl.Model`): an instance of ``Model`` that has the + same neural network architecture as the current model. + decay (float): the rate of decline in copying parameters. 0 if no + parameters decay when synchronizing the parameters. + + Example: + + .. code-block:: python + + import copy + # create a model that has the same neural network structures. + target_model = copy.deepcopy(model) + + # after initializing the parameters ... + model.sync_weights_to(target_model) + + Note: + Before calling ``sync_weights_to``, parameters of the model must + have been initialized. + """ + + assert not target_model is self, "cannot copy between identical model" + assert isinstance(target_model, Model) + assert self.__class__.__name__ == target_model.__class__.__name__, \ + "must be the same class for params syncing!" + assert (decay >= 0 and decay <= 1) + + target_vars = dict(target_model.named_parameters()) + for name, var in self.named_parameters(): + target_vars[name].data.copy_(decay * target_vars[name].data + + (1 - decay) * var.data) + + def get_weights(self): + """Returns a Python list containing parameters of current model. + + Returns: a Python list containing the parameters of current model. 
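        Example (a sketch pairing it with ``set_weights`` defined just below):

        .. code-block:: python

            weights = model.get_weights()
            target_model.set_weights(weights)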
+ """ + return list(self.parameters()) + + def set_weights(self, weights): + """Copy parameters from ``set_weights()`` to the model. + + Args: + weights (list): a Python list containing the parameters. + """ + assert len(weights) == len(list(self.parameters())), \ + 'size of input weights should be same as weights number of current model' + + for var, weight in zip(self.parameters(), weights): + var.data.copy_(weight.data) diff --git a/parl/core/torch/tests/agent_base_test.py b/parl/core/torch/tests/agent_base_test.py new file mode 100644 index 0000000..688c716 --- /dev/null +++ b/parl/core/torch/tests/agent_base_test.py @@ -0,0 +1,102 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import unittest +import os + +import torch +import torch.nn as nn +import torch.optim as optim + +from parl.core.torch.model import Model +from parl.core.torch.algorithm import Algorithm +from parl.core.torch.agent import Agent + + +class TestModel(Model): + def __init__(self): + super(TestModel, self).__init__() + self.fc1 = nn.Linear(10, 256) + self.fc2 = nn.Linear(256, 1) + + def forward(self, obs): + out = self.fc1(obs) + out = self.fc2(out) + return out + + +class TestAlgorithm(Algorithm): + def __init__(self, model): + self.model = model + self.optimizer = optim.Adam(self.model.parameters(), lr=0.001) + + def predict(self, obs): + return self.model(obs) + + def learn(self, obs, label): + pred_output = self.model(obs) + cost = (pre_output - obs).pow(2) + self.optimizer.zero_grad() + cost.backward() + self.optimizer.step() + return cost.item() + + +class TestAgent(Agent): + def __init__(self, algorithm): + self.alg = algorithm + + def learn(self, obs, label): + cost = self.alg.lean(obs, label) + + def predict(self, obs): + return self.alg.predict(obs) + + +class AgentBaseTest(unittest.TestCase): + def setUp(self): + self.model = TestModel() + self.alg = TestAlgorithm(self.model) + + def test_agent(self): + agent = TestAgent(self.alg) + obs = torch.randn(3, 10) + output = agent.predict(obs) + self.assertIsNotNone(output) + + def test_save(self): + agent = TestAgent(self.alg) + obs = torch.randn(3, 10) + save_path1 = './model.ckpt' + save_path2 = './my_model/model-2.ckpt' + agent.save(save_path1) + agent.save(save_path2) + self.assertTrue(os.path.exists(save_path1)) + self.assertTrue(os.path.exists(save_path2)) + + def test_restore(self): + agent = TestAgent(self.alg) + obs = torch.randn(3, 10) + output = agent.predict(obs) + save_path1 = './model.ckpt' + previous_output = agent.predict(obs).detach().cpu().numpy() + agent.save(save_path1) + agent.restore(save_path1) + current_output = agent.predict(obs).detach().cpu().numpy() + np.testing.assert_equal(current_output, previous_output) + + +if __name__ == '__main__': + unittest.main() diff --git a/parl/core/torch/tests/agent_base_test_.py b/parl/core/torch/tests/agent_base_test_.py new file mode 100644 index 0000000..688c716 --- /dev/null +++ 
@@ -0,0 +1,102 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import unittest
+import os
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from parl.core.torch.model import Model
+from parl.core.torch.algorithm import Algorithm
+from parl.core.torch.agent import Agent
+
+
+class TestModel(Model):
+    def __init__(self):
+        super(TestModel, self).__init__()
+        self.fc1 = nn.Linear(10, 256)
+        self.fc2 = nn.Linear(256, 1)
+
+    def forward(self, obs):
+        out = self.fc1(obs)
+        out = self.fc2(out)
+        return out
+
+
+class TestAlgorithm(Algorithm):
+    def __init__(self, model):
+        self.model = model
+        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
+
+    def predict(self, obs):
+        return self.model(obs)
+
+    def learn(self, obs, label):
+        pred_output = self.model(obs)
+        cost = (pred_output - label).pow(2).mean()
+        self.optimizer.zero_grad()
+        cost.backward()
+        self.optimizer.step()
+        return cost.item()
+
+
+class TestAgent(Agent):
+    def __init__(self, algorithm):
+        self.alg = algorithm
+
+    def learn(self, obs, label):
+        return self.alg.learn(obs, label)
+
+    def predict(self, obs):
+        return self.alg.predict(obs)
+
+
+class AgentBaseTest(unittest.TestCase):
+    def setUp(self):
+        self.model = TestModel()
+        self.alg = TestAlgorithm(self.model)
+
+    def test_agent(self):
+        agent = TestAgent(self.alg)
+        obs = torch.randn(3, 10)
+        output = agent.predict(obs)
+        self.assertIsNotNone(output)
+
+    def test_save(self):
+        agent = TestAgent(self.alg)
+        obs = torch.randn(3, 10)
+        save_path1 = './model.ckpt'
+        save_path2 = './my_model/model-2.ckpt'
+        agent.save(save_path1)
+        agent.save(save_path2)
+        self.assertTrue(os.path.exists(save_path1))
+        self.assertTrue(os.path.exists(save_path2))
+
+    def test_restore(self):
+        agent = TestAgent(self.alg)
+        obs = torch.randn(3, 10)
+        output = agent.predict(obs)
+        save_path1 = './model.ckpt'
+        previous_output = agent.predict(obs).detach().cpu().numpy()
+        agent.save(save_path1)
+        agent.restore(save_path1)
+        current_output = agent.predict(obs).detach().cpu().numpy()
+        np.testing.assert_equal(current_output, previous_output)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/parl/core/torch/tests/model_base_test.py b/parl/core/torch/tests/model_base_test.py
new file mode 100644
index 0000000..d0554a6
--- /dev/null
+++ b/parl/core/torch/tests/model_base_test.py
@@ -0,0 +1,345 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import unittest +import os +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.optim as optim + +from parl.utils import get_gpu_count +from parl.core.torch.model import Model +from parl.core.torch.algorithm import Algorithm +from parl.core.torch.agent import Agent + + +class TestModel(Model): + def __init__(self): + super(TestModel, self).__init__() + self.fc1 = nn.Linear(4, 256) + self.fc2 = nn.Linear(256, 128) + self.fc3 = nn.Linear(128, 1) + + def predict(self, obs): + out = self.fc1(obs) + out = self.fc2(out) + out = self.fc3(out) + return out + + +class ModelBaseTest(unittest.TestCase): + def setUp(self): + self.model = TestModel() + self.target_model = TestModel() + self.target_model2 = TestModel() + + gpu_count = get_gpu_count() + device = torch.device('cuda' if gpu_count else 'cpu') + + def test_sync_weights_in_one_program(self): + obs = torch.randn(1, 4) + + N = 10 + random_obs = torch.randn(N, 4) + for i in range(N): + x = random_obs[i].view(1, -1) + model_output = self.model.predict(x).item() + target_model_output = self.target_model.predict(x).item() + self.assertNotEqual(model_output, target_model_output) + + self.model.sync_weights_to(self.target_model) + + random_obs = torch.randn(N, 4) + for i in range(N): + x = random_obs[i].view(1, -1) + model_output = self.model.predict(x).item() + target_model_output = self.target_model.predict(x).item() + self.assertEqual(model_output, target_model_output) + + def _numpy_update(self, target_model, decay): + target_parameters = dict(target_model.named_parameters()) + updated_parameters = {} + for name, param in self.model.named_parameters(): + updated_parameters[name] = decay * target_parameters[name].detach( + ).cpu().numpy() + (1 - decay) * param.detach().cpu().numpy() + return updated_parameters + + def test_sync_weights_with_decay(self): + decay = 0.9 + updated_parameters = self._numpy_update(self.target_model, decay) + (target_model_fc1_w, target_model_fc1_b, target_model_fc2_w, + target_model_fc2_b, target_model_fc3_w, + target_model_fc3_b) = (updated_parameters['fc1.weight'], + updated_parameters['fc1.bias'], + updated_parameters['fc2.weight'], + updated_parameters['fc2.bias'], + updated_parameters['fc3.weight'], + updated_parameters['fc3.bias']) + + self.model.sync_weights_to(self.target_model, decay) + + N = 10 + random_obs = np.random.randn(N, 4) + for i in range(N): + obs = np.expand_dims(random_obs[i], -1) # 4, 1 + real_target_outputs = self.target_model.predict( + torch.Tensor(obs).view(1, -1)).item() + + out_np = np.dot(target_model_fc1_w, obs) + np.expand_dims( + target_model_fc1_b, -1) # (256, 256) + out_np = np.dot(target_model_fc2_w, out_np) + np.expand_dims( + target_model_fc2_b, -1) + out_np = np.dot(target_model_fc3_w, out_np) + np.expand_dims( + target_model_fc3_b, -1) + + self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5) + + def test_sync_weights_with_multi_decay(self): + decay = 0.9 + updated_parameters = self._numpy_update(self.target_model, decay) + (target_model_fc1_w, target_model_fc1_b, target_model_fc2_w, + target_model_fc2_b, target_model_fc3_w, + target_model_fc3_b) = (updated_parameters['fc1.weight'], + updated_parameters['fc1.bias'], + updated_parameters['fc2.weight'], + updated_parameters['fc2.bias'], + updated_parameters['fc3.weight'], + updated_parameters['fc3.bias']) + + self.model.sync_weights_to(self.target_model, decay) + + N 
= 10 + random_obs = np.random.randn(N, 4) + for i in range(N): + obs = np.expand_dims(random_obs[i], -1) # 4, 1 + real_target_outputs = self.target_model.predict( + torch.Tensor(obs).view(1, -1)).item() + + out_np = np.dot(target_model_fc1_w, obs) + np.expand_dims( + target_model_fc1_b, -1) # (256, 256) + out_np = np.dot(target_model_fc2_w, out_np) + np.expand_dims( + target_model_fc2_b, -1) + out_np = np.dot(target_model_fc3_w, out_np) + np.expand_dims( + target_model_fc3_b, -1) + + self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5) + + updated_parameters = self._numpy_update(self.target_model, decay) + (target_model_fc1_w, target_model_fc1_b, target_model_fc2_w, + target_model_fc2_b, target_model_fc3_w, + target_model_fc3_b) = (updated_parameters['fc1.weight'], + updated_parameters['fc1.bias'], + updated_parameters['fc2.weight'], + updated_parameters['fc2.bias'], + updated_parameters['fc3.weight'], + updated_parameters['fc3.bias']) + + self.model.sync_weights_to(self.target_model, decay) + random_obs = np.random.randn(N, 4) + for i in range(N): + obs = np.expand_dims(random_obs[i], -1) # 4, 1 + real_target_outputs = self.target_model.predict( + torch.Tensor(obs).view(1, -1)).item() + + out_np = np.dot(target_model_fc1_w, obs) + np.expand_dims( + target_model_fc1_b, -1) # (256, 256) + out_np = np.dot(target_model_fc2_w, out_np) + np.expand_dims( + target_model_fc2_b, -1) + out_np = np.dot(target_model_fc3_w, out_np) + np.expand_dims( + target_model_fc3_b, -1) + + self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5) + + def test_sync_weights_with_different_decay(self): + decay = 0.9 + updated_parameters = self._numpy_update(self.target_model, decay) + (target_model_fc1_w, target_model_fc1_b, target_model_fc2_w, + target_model_fc2_b, target_model_fc3_w, + target_model_fc3_b) = (updated_parameters['fc1.weight'], + updated_parameters['fc1.bias'], + updated_parameters['fc2.weight'], + updated_parameters['fc2.bias'], + updated_parameters['fc3.weight'], + updated_parameters['fc3.bias']) + + self.model.sync_weights_to(self.target_model, decay) + + N = 10 + random_obs = np.random.randn(N, 4) + for i in range(N): + obs = np.expand_dims(random_obs[i], -1) # 4, 1 + real_target_outputs = self.target_model.predict( + torch.Tensor(obs).view(1, -1)).item() + + out_np = np.dot(target_model_fc1_w, obs) + np.expand_dims( + target_model_fc1_b, -1) # (256, 256) + out_np = np.dot(target_model_fc2_w, out_np) + np.expand_dims( + target_model_fc2_b, -1) + out_np = np.dot(target_model_fc3_w, out_np) + np.expand_dims( + target_model_fc3_b, -1) + + self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5) + + decay = 0.8 + + updated_parameters = self._numpy_update(self.target_model, decay) + (target_model_fc1_w, target_model_fc1_b, target_model_fc2_w, + target_model_fc2_b, target_model_fc3_w, + target_model_fc3_b) = (updated_parameters['fc1.weight'], + updated_parameters['fc1.bias'], + updated_parameters['fc2.weight'], + updated_parameters['fc2.bias'], + updated_parameters['fc3.weight'], + updated_parameters['fc3.bias']) + + self.model.sync_weights_to(self.target_model, decay) + random_obs = np.random.randn(N, 4) + for i in range(N): + obs = np.expand_dims(random_obs[i], -1) # 4, 1 + real_target_outputs = self.target_model.predict( + torch.Tensor(obs).view(1, -1)).item() + + out_np = np.dot(target_model_fc1_w, obs) + np.expand_dims( + target_model_fc1_b, -1) # (256, 256) + out_np = np.dot(target_model_fc2_w, out_np) + np.expand_dims( + target_model_fc2_b, -1) + out_np = 
np.dot(target_model_fc3_w, out_np) + np.expand_dims( + target_model_fc3_b, -1) + + self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5) + + def test_sync_weights_with_different_target_model(self): + decay = 0.9 + updated_parameters = self._numpy_update(self.target_model, decay) + (target_model_fc1_w, target_model_fc1_b, target_model_fc2_w, + target_model_fc2_b, target_model_fc3_w, + target_model_fc3_b) = (updated_parameters['fc1.weight'], + updated_parameters['fc1.bias'], + updated_parameters['fc2.weight'], + updated_parameters['fc2.bias'], + updated_parameters['fc3.weight'], + updated_parameters['fc3.bias']) + + self.model.sync_weights_to(self.target_model, decay) + + N = 10 + random_obs = np.random.randn(N, 4) + for i in range(N): + obs = np.expand_dims(random_obs[i], -1) # 4, 1 + real_target_outputs = self.target_model.predict( + torch.Tensor(obs).view(1, -1)).item() + + out_np = np.dot(target_model_fc1_w, obs) + np.expand_dims( + target_model_fc1_b, -1) # (256, 256) + out_np = np.dot(target_model_fc2_w, out_np) + np.expand_dims( + target_model_fc2_b, -1) + out_np = np.dot(target_model_fc3_w, out_np) + np.expand_dims( + target_model_fc3_b, -1) + + self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5) + + decay = 0.8 + + updated_parameters = self._numpy_update(self.target_model2, decay) + (target_model_fc1_w, target_model_fc1_b, target_model_fc2_w, + target_model_fc2_b, target_model_fc3_w, + target_model_fc3_b) = (updated_parameters['fc1.weight'], + updated_parameters['fc1.bias'], + updated_parameters['fc2.weight'], + updated_parameters['fc2.bias'], + updated_parameters['fc3.weight'], + updated_parameters['fc3.bias']) + + self.model.sync_weights_to(self.target_model2, decay) + random_obs = np.random.randn(N, 4) + for i in range(N): + obs = np.expand_dims(random_obs[i], -1) # 4, 1 + real_target_outputs = self.target_model2.predict( + torch.Tensor(obs).view(1, -1)).item() + + out_np = np.dot(target_model_fc1_w, obs) + np.expand_dims( + target_model_fc1_b, -1) # (256, 256) + out_np = np.dot(target_model_fc2_w, out_np) + np.expand_dims( + target_model_fc2_b, -1) + out_np = np.dot(target_model_fc3_w, out_np) + np.expand_dims( + target_model_fc3_b, -1) + + self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5) + + def test_get_weights(self): + params = self.model.get_weights() + expected_params = list(self.model.parameters()) + self.assertEqual(len(params), len(expected_params)) + for param in params: + flag = False + for expected_param in expected_params: + if param.sum().item() - expected_param.sum().item() < 1e-5: + flag = True + break + self.assertTrue(flag) + + def test_set_weights(self): + params = self.model.get_weights() + new_params = [x + 1.0 for x in params] + + self.model.set_weights(new_params) + + for x, y in list(zip(new_params, self.model.get_weights())): + self.assertEqual(x.sum().item(), y.sum().item()) + + def test_set_weights_between_different_models(self): + model1 = TestModel() + model2 = TestModel() + + N = 10 + random_obs = torch.randn(N, 4) + for i in range(N): + x = random_obs[i].view(1, -1) + model1_output = model1.predict(x).item() + model2_output = model2.predict(x).item() + self.assertNotEqual(model1_output, model2_output) + + params = model1.get_weights() + model2.set_weights(params) + + random_obs = torch.randn(N, 4) + for i in range(N): + x = random_obs[i].view(1, -1) + model1_output = model1.predict(x).item() + model2_output = model2.predict(x).item() + self.assertEqual(model1_output, model2_output) + + def 
test_set_weights_wrong_params_num(self): + params = self.model.get_weights() + try: + self.model.set_weights(params[1:]) + except: + return + assert False + + def test_set_weights_wrong_params_shape(self): + params = self.model.get_weights() + params.reverse() + try: + self.model.set_weights(params) + except: + return + assert False + + +if __name__ == '__main__': + unittest.main() diff --git a/parl/remote/client.py b/parl/remote/client.py index 0d682fe..cb0a23b 100644 --- a/parl/remote/client.py +++ b/parl/remote/client.py @@ -96,9 +96,15 @@ class Client(object): to_distributed_files = list(code_files) + distributed_files for file in to_distributed_files: - with open(file, 'rb') as code_file: - code = code_file.read() - pyfiles[file] = code + try: + assert os.path.exists(file) + with open(file, 'rb') as code_file: + code = code_file.read() + pyfiles[file] = code + except AssertionError as e: + raise Exception( + 'Failed to create the client, the file {} does not exist.'. + format(file)) return cloudpickle.dumps(pyfiles) def _create_sockets(self, master_address): diff --git a/parl/remote/job.py b/parl/remote/job.py index 44e60bc..4be53a9 100644 --- a/parl/remote/job.py +++ b/parl/remote/job.py @@ -239,6 +239,16 @@ class Job(object): envdir = tempfile.mkdtemp() for file in pyfiles: code = pyfiles[file] + + # create directory (i.e. ./rom_files/) + if '/' in file: + try: + os.makedirs( + os.path.join(envdir, + *file.rsplit('/')[:-1])) + except OSError as e: + pass + file = os.path.join(envdir, file) with open(file, 'wb') as code_file: code_file.write(code) diff --git a/parl/remote/tests/cluster_test.py b/parl/remote/tests/cluster_test.py index fbb062e..9025b7b 100644 --- a/parl/remote/tests/cluster_test.py +++ b/parl/remote/tests/cluster_test.py @@ -89,7 +89,7 @@ class TestCluster(unittest.TestCase): master.exit() worker1.exit() - @timeout_decorator.timeout(seconds=500) + @timeout_decorator.timeout(seconds=800) def test_actor_exception(self): master = Master(port=1236) th = threading.Thread(target=master.run) diff --git a/parl/remote/tests/reset_job_test.py b/parl/remote/tests/reset_job_test.py index 2f833dc..85f0718 100644 --- a/parl/remote/tests/reset_job_test.py +++ b/parl/remote/tests/reset_job_test.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import unittest import parl from parl.remote.master import Master diff --git a/parl/remote/tests/rom/pong.bin b/parl/remote/tests/rom/pong.bin new file mode 100644 index 0000000..e69de29 diff --git a/parl/remote/tests/send_job_test.py b/parl/remote/tests/send_job_test.py new file mode 100644 index 0000000..77ea421 --- /dev/null +++ b/parl/remote/tests/send_job_test.py @@ -0,0 +1,81 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import parl +import unittest +import time +import threading + +from parl.remote.master import Master +from parl.remote.worker import Worker +from parl.remote.client import disconnect + + +@parl.remote_class +class Actor(object): + def __init__(self, x=10): + self.x = x + + def check_local_file(self): + return os.path.exists('./rom_files/pong.bin') + + +class TestSendFile(unittest.TestCase): + def tearDown(self): + disconnect() + + def test_send_file(self): + port = 1239 + master = Master(port=port) + th = threading.Thread(target=master.run) + th.start() + worker = Worker('localhost:{}'.format(port), 1) + time.sleep(2) + + os.system('mkdir ./rom_files') + os.system('touch ./rom_files/pong.bin') + assert os.path.exists('./rom_files/pong.bin') + parl.connect( + 'localhost:{}'.format(port), + distributed_files=['./rom_files/pong.bin']) + time.sleep(5) + actor = Actor() + for _ in range(10): + if actor.check_local_file(): + break + time.sleep(10) + self.assertEqual(True, actor.check_local_file()) + del actor + time.sleep(10) + worker.exit() + master.exit() + + def test_send_file2(self): + port = 1240 + master = Master(port=port) + th = threading.Thread(target=master.run) + th.start() + worker = Worker('localhost:{}'.format(port), 1) + time.sleep(2) + + self.assertRaises(Exception, parl.connect, 'localhost:{}'.format(port), + ['./rom_files/no_pong.bin']) + + worker.exit() + master.exit() + + +if __name__ == '__main__': + unittest.main() diff --git a/parl/utils/utils.py b/parl/utils/utils.py index 8096cca..cb95b4d 100644 --- a/parl/utils/utils.py +++ b/parl/utils/utils.py @@ -16,7 +16,7 @@ import sys __all__ = [ 'has_func', 'action_mapping', 'to_str', 'to_byte', 'is_PY2', 'is_PY3', - 'MAX_INT32', '_HAS_FLUID' + 'MAX_INT32', '_HAS_FLUID', '_HAS_TORCH' ] @@ -86,3 +86,9 @@ try: _HAS_FLUID = True except ImportError: _HAS_FLUID = False + +try: + import torch + _HAS_TORCH = True +except ImportError: + _HAS_TORCH = False -- GitLab
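
For readers new to the PyTorch backend introduced by this patch, the sketch below shows how the three base classes are meant to be composed, following the same pattern as parl/core/torch/tests/agent_base_test.py: a Model subclass defines the network, an Algorithm owns the optimizer and the loss, and an Agent wraps the algorithm and exposes predict/learn plus save/restore. This is an illustrative sketch, not part of the patch: the layer sizes, learning rate, and random training data are placeholder assumptions, and skipping Agent.__init__ simply mirrors what the unit tests do.

    # Minimal usage sketch of the torch backend (illustrative only).
    import copy

    import torch
    import torch.nn as nn
    import torch.optim as optim

    from parl.core.torch.model import Model
    from parl.core.torch.algorithm import Algorithm
    from parl.core.torch.agent import Agent


    class MLPModel(Model):
        """A toy two-layer network; the layer sizes are arbitrary."""

        def __init__(self):
            super(MLPModel, self).__init__()
            self.fc1 = nn.Linear(10, 64)
            self.fc2 = nn.Linear(64, 1)

        def forward(self, obs):
            return self.fc2(torch.relu(self.fc1(obs)))


    class RegressionAlgorithm(Algorithm):
        """Owns the optimizer and defines the loss, like TestAlgorithm in the tests."""

        def __init__(self, model, lr=1e-3):
            self.model = model
            self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        def predict(self, obs):
            return self.model(obs)

        def learn(self, obs, label):
            cost = (self.model(obs) - label).pow(2).mean()
            self.optimizer.zero_grad()
            cost.backward()
            self.optimizer.step()
            return cost.item()


    class SimpleAgent(Agent):
        """Thin wrapper around the algorithm; __init__ mirrors the unit tests."""

        def __init__(self, algorithm):
            self.alg = algorithm

        def predict(self, obs):
            return self.alg.predict(obs)

        def learn(self, obs, label):
            return self.alg.learn(obs, label)


    if __name__ == '__main__':
        model = MLPModel()
        target_model = copy.deepcopy(model)  # duplicate a Model in a pythonic way
        agent = SimpleAgent(RegressionAlgorithm(model))

        obs, label = torch.randn(32, 10), torch.randn(32, 1)
        for _ in range(10):
            agent.learn(obs, label)

        # soft-update the copy, then checkpoint through the Agent API
        model.sync_weights_to(target_model, decay=0.9)
        agent.save('./model.ckpt')
        agent.restore('./model.ckpt')

The same structure scales up to the Atari benchmark shipped in benchmark/torch/dqn, where the Algorithm role is played by parl.algorithms.torch.DQN or DDQN.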