diff --git a/ding/model/template/q_learning.py b/ding/model/template/q_learning.py
index ce3d789b1aa672c3dc4e814d6e88275162194941..ab8267a47ceb3230ce707a52836121d027bd15b3 100644
--- a/ding/model/template/q_learning.py
+++ b/ding/model/template/q_learning.py
@@ -678,11 +678,14 @@ class DRQN(nn.Module):
         """
         x, prev_state = inputs['obs'], inputs['prev_state']
+        # for both inference and other cases, the network structure is encoder -> rnn network -> head
+        # the difference is that inference takes data with seq_len=1 (i.e. T=1)
         if inference:
             x = self.encoder(x)
-            x = x.unsqueeze(0)
+            x = x.unsqueeze(0)  # for the rnn input, add a seq_len dim to x so that T=1
+            # prev_state data type: List[Tuple[torch.Tensor]]; initially it is a list of None
             x, next_state = self.rnn(x, prev_state)
-            x = x.squeeze(0)
+            x = x.squeeze(0)  # remove the seq_len dim to match the head network input
             x = self.head(x)
             x['next_state'] = next_state
             return x
@@ -700,11 +703,14 @@ class DRQN(nn.Module):
                     saved_hidden_state.append(prev_state)
                 lstm_embedding.append(output)
                 hidden_state = list(zip(*prev_state))  # {list: 2{tuple: B{Tensor:(1, 1, head_hidden_size}}}
+                # only keep h_t, {list: x.shape[0]{Tensor:(1, batch_size, head_hidden_size)}}
                 hidden_state_list.append(torch.cat(hidden_state[0], dim=1))
             x = torch.cat(lstm_embedding, 0)  # (T, B, head_hidden_size)
             x = parallel_wrapper(self.head)(x)  # (T, B, action_shape)
-            x['next_state'] = prev_state  # the last timestep state including h and c
-            x['hidden_state'] = torch.cat(hidden_state_list, dim=-3)  # the all hidden state h
+            # the last timestep state, including h and c of the lstm, {list: B{tuple: 2{Tensor:(1, 1, head_hidden_size)}}}
+            x['next_state'] = prev_state
+            # all hidden states h, a tensor of shape (seq_len, batch_size, head_hidden_size)
+            x['hidden_state'] = torch.cat(hidden_state_list, dim=-3)
             if saved_hidden_state_timesteps is not None:
                 x['saved_hidden_state'] = saved_hidden_state  # the selected saved hidden states, including h and c
             return x
diff --git a/ding/model/wrapper/model_wrappers.py b/ding/model/wrapper/model_wrappers.py
index 11ecb8b97a2206c88eee5d05ceaca0d5bcd2db34..57f607107fef82d3d7f2b63c86d7f6b2ae94e971 100644
--- a/ding/model/wrapper/model_wrappers.py
+++ b/ding/model/wrapper/model_wrappers.py
@@ -89,18 +89,20 @@ class HiddenStateWrapper(IModelWrapper):
         """
         super().__init__(model)
         self._state_num = state_num
+        # self._state maintains the hidden states: in forward, self._state is mapped into
+        # data['prev_state'], and the returned next_state is stored back into self._state
         self._state = {i: init_fn() for i in range(state_num)}
         self._save_prev_state = save_prev_state
         self._init_fn = init_fn

     def forward(self, data, **kwargs):
         state_id = kwargs.pop('data_id', None)
-        valid_id = kwargs.pop('valid_id', None)
-        data, state_info = self.before_forward(data, state_id)
+        valid_id = kwargs.pop('valid_id', None)  # usually None, not used elsewhere in DI-engine
+        data, state_info = self.before_forward(data, state_id)  # update data['prev_state'] with self._state
         output = self._model.forward(data, **kwargs)
         h = output.pop('next_state', None)
         if h is not None:
-            self.after_forward(h, state_info, valid_id)
+            self.after_forward(h, state_info, valid_id)  # store the returned next hidden state back into self._state
         if self._save_prev_state:
             prev_state = get_tensor_data(data['prev_state'])
             output['prev_state'] = prev_state
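
Below is a minimal, self-contained sketch (not DI-engine code) of the shape bookkeeping described in the DRQN comments above: it rolls a plain torch.nn.LSTM one timestep at a time and collects every h_t. Here prev_state is the (h, c) tensor tuple returned by nn.LSTM rather than the per-sample list DRQN uses, and all sizes are arbitrary toy values.

import torch
import torch.nn as nn

T, B, obs_dim, hidden = 4, 3, 8, 16  # arbitrary toy sizes
rnn = nn.LSTM(obs_dim, hidden, num_layers=1)
x = torch.randn(T, B, obs_dim)
prev_state = None  # no hidden state yet, analogous to DRQN's initial list of None
outputs, hidden_state_list = [], []
for t in range(T):
    out, prev_state = rnn(x[t:t + 1], prev_state)  # out: (1, B, hidden)
    outputs.append(out)
    hidden_state_list.append(prev_state[0])  # keep only h_t, shape (1, B, hidden)
embedding = torch.cat(outputs, dim=0)  # (T, B, hidden), cf. torch.cat(lstm_embedding, 0)
all_h = torch.cat(hidden_state_list, dim=0)  # (T, B, hidden), cf. x['hidden_state']
print(embedding.shape, all_h.shape)  # torch.Size([4, 3, 16]) torch.Size([4, 3, 16])
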
diff --git a/ding/policy/r2d2.py b/ding/policy/r2d2.py
index 95efb6164419b1069402e33a58b59bb83dadcbe3..d1b79b4f7eee9efaf2f2261124c4607f9bb48a90
--- a/ding/policy/r2d2.py
+++ b/ding/policy/r2d2.py
@@ -229,11 +229,13 @@ class R2D2Policy(Policy):
             data['weight'] = data['weight'] * torch.ones_like(data['done'])
             # every timestep in sequence has same weight, which is the _priority_IS_weight in PER

-        data['action'] = data['action'][bs:-self._nstep]
-        data['reward'] = data['reward'][bs:-self._nstep]
+        data['action'] = data['action'][bs:-self._nstep]  # slice the seq_len dim: keep timesteps from burnin_step to (seq_len - nstep)
+        data['reward'] = data['reward'][bs:-self._nstep]  # slice the seq_len dim: keep timesteps from burnin_step to (seq_len - nstep)

         # the burnin_nstep_obs is used to calculate the init hidden state of rnn for the calculation of the q_value,
         # target_q_value, and target_q_action
+
+        # all these slicing operations are applied to the outermost dim, i.e. the seq_len dim
         data['burnin_nstep_obs'] = data['obs'][:bs + self._nstep]
         # the main_obs is used to calculate the q_value, the [bs:-self._nstep] means using the data from
         # [bs] timestep to [self._unroll_len_add_burnin_step-self._nstep] timestep
@@ -259,10 +261,11 @@ class R2D2Policy(Policy):
            - total_loss (:obj:`float`): The calculated loss
        """
        # forward
-        data = self._data_preprocess_learn(data)
+        data = self._data_preprocess_learn(data)  # output data type: Dict
        self._learn_model.train()
        self._target_model.train()
        # use the hidden state in timestep=0
+        # note: the reset method is implemented in the hidden state wrapper and resets self._state
        self._learn_model.reset(data_id=None, state=data['prev_state'][0])
        self._target_model.reset(data_id=None, state=data['prev_state'][0])

@@ -271,7 +274,8 @@ class R2D2Policy(Policy):
         inputs = {'obs': data['burnin_nstep_obs'], 'enable_fast_timestep': True}
         burnin_output = self._learn_model.forward(
             inputs, saved_hidden_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep]
-        )
+        )  # the output keys include 'logit', 'hidden_state', 'saved_hidden_state' and 'action';
+        # for their specific dims, please refer to the DRQN model
         burnin_output_target = self._target_model.forward(
             inputs, saved_hidden_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep]
         )
@@ -307,6 +311,8 @@ class R2D2Policy(Policy):
             else:
                 l, e = q_nstep_td_error(td_data, self._gamma, self._nstep, value_gamma=value_gamma[t])
                 loss.append(l)
+                # td_error will be a list of length (self._unroll_len_add_burnin_step - self._burnin_step - self._nstep),
+                # and each element is a tensor of size batch_size
                 td_error.append(e.abs())

         loss = sum(loss) / (len(loss) + 1e-8)
@@ -314,6 +320,7 @@ class R2D2Policy(Policy):
         td_error_per_sample = 0.9 * torch.max(
             torch.stack(td_error), dim=0
         )[0] + (1 - 0.9) * (torch.sum(torch.stack(td_error), dim=0) / (len(td_error) + 1e-8))
+        # torch.max(torch.stack(td_error), dim=0) returns a (values, indices) namedtuple, hence the [0]; see torch.max
         # td_error shape list(, B), for example, (75,64)
         # torch.sum(torch.stack(td_error), dim=0) can also be replaced with sum(td_error)

@@ -332,7 +339,7 @@ class R2D2Policy(Policy):
         return {
             'cur_lr': self._optimizer.defaults['lr'],
             'total_loss': loss.item(),
-            'priority': td_error_per_sample.abs().tolist(),
+            'priority': td_error_per_sample.tolist(),  # note: the abs operation has already been applied above
             # the first timestep in the sequence, may not be the start of episode
             'q_s_taken-a_t0': q_s_a_t0.mean().item(),
             'target_q_s_max-a_t0': target_q_s_a_t0.mean().item(),
@@ -365,6 +372,8 @@ class R2D2Policy(Policy):
         self._unroll_len_add_burnin_step = self._cfg.unroll_len + self._cfg.burnin_step
         self._unroll_len = self._unroll_len_add_burnin_step  # for compatibility

+        # for r2d2, this hidden_state wrapper adds the previous hidden state ('prev_state') to each transition.
+        # Note that the collect envs form a batch, and the key is added for the whole batch at once.
         self._collect_model = model_wrap(
             self._model, wrapper_name='hidden_state', state_num=self._cfg.collect.env_num, save_prev_state=True
         )
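
The following toy sketch (assumed shapes, not taken from DI-engine) mirrors the priority computation annotated above: per-timestep absolute TD errors are mixed into one priority per sequence sample via eta * max + (1 - eta) * mean over the time dim, and it shows why the [0] index is needed after torch.max.

import torch

T, B = 5, 4  # timesteps kept for the loss and batch size, arbitrary toy values
td_error = [torch.rand(B) for _ in range(T)]  # one abs TD-error tensor of shape (B,) per timestep
stacked = torch.stack(td_error)  # (T, B)
# torch.max along a dim returns a (values, indices) namedtuple, hence the [0] in the policy code
max_part = torch.max(stacked, dim=0)[0]  # (B,)
mean_part = torch.sum(stacked, dim=0) / (len(td_error) + 1e-8)  # (B,), same as sum(td_error) / ...
eta = 0.9
priority_per_sample = eta * max_part + (1 - eta) * mean_part  # (B,), used as the PER priority
print(priority_per_sample.shape)  # torch.Size([4])
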
diff --git a/ding/utils/data/collate_fn.py b/ding/utils/data/collate_fn.py
index e871b91f39aae75e13fc6217b6b096e37b61bc30..d40b5e936ff661cf5df5f95f0d9a4c186eddd258 100644
--- a/ding/utils/data/collate_fn.py
+++ b/ding/utils/data/collate_fn.py
@@ -125,7 +125,7 @@ def timestep_collate(batch: List[Dict[str, Any]]) -> Dict[str, Union[torch.Tenso
     prev_state = [b.pop('prev_state') for b in batch]
     batch_data = default_collate(batch)  # -> {some_key: T lists}, each list is [B, some_dim]
     batch_data = stack(batch_data)  # -> {some_key: [T, B, some_dim]}
-    batch_data['prev_state'] = list(zip(*prev_state))
+    batch_data['prev_state'] = list(zip(*prev_state))  # swap the batch dim with the seq_len dim
     # append back prev_state, avoiding multi batch share the same data bug
     for i in range(len(batch)):
         batch[i]['prev_state'] = prev_state[i]
diff --git a/ding/worker/collector/sample_serial_collector.py b/ding/worker/collector/sample_serial_collector.py
index f03e3a2ee7cb429c68bb58403e819cbd1d17f530..7276b18c7f81f941355f16de77ed03a702e13e1a 100644
--- a/ding/worker/collector/sample_serial_collector.py
+++ b/ding/worker/collector/sample_serial_collector.py
@@ -257,6 +257,14 @@ class SampleSerialCollector(ISerialCollector):
                 self._total_envstep_count += 1
                 # prepare data
                 if timestep.done or len(self._traj_buffer[env_id]) == self._traj_len:
+                    # for r2d2:
+                    # 1. for each collect env, we want to collect a trajectory of length self._traj_len,
+                    # unless the episode ends first.
+                    # 2. even if the timestep is done and we have only collected, say, 9 transitions,
+                    # they will be padded automatically when passed through self._policy.get_train_sample.
+                    # 3. so one training sample for r2d2 is a sequence of length (burnin + nstep)
+                    # (collected_sample=1), and we need to collect n_sample such sequences.
+                    # Here the condition means: the episode is done or traj_buffer (maxlen=traj_len) is full.
                     transitions = to_tensor_transitions(self._traj_buffer[env_id])
                     train_sample = self._policy.get_train_sample(transitions)
diff --git a/dizoo/classic_control/cartpole/config/cartpole_r2d2_config.py b/dizoo/classic_control/cartpole/config/cartpole_r2d2_config.py
index 135cf5591d6ed8413cac9cb176f6b52ba5e95f63..c9f6c30de1dbe38fce4e8a969d87f9253b8f399b 100644
--- a/dizoo/classic_control/cartpole/config/cartpole_r2d2_config.py
+++ b/dizoo/classic_control/cartpole/config/cartpole_r2d2_config.py
@@ -30,7 +30,7 @@ cartpole_r2d2_config = dict(
         learn=dict(
             # according to the R2D2 paper, actor parameter update interval is 400
             # environment timesteps, and in per collect phase, we collect 32 sequence
-            # samples, the length of each samlpe sequence is + ,
+            # samples, the length of each sample sequence is + ,
             # which is 100 in our seeting, 32*100/400=8, so we set update_per_collect=8
             # in most environments
             update_per_collect=8,
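
As a small illustration of the timestep_collate comment above, the sketch below shows what list(zip(*prev_state)) does: it transposes the nesting from batch-major [B][T] to time-major [T][B], so that batch_data['prev_state'][t] holds the hidden states of every sample at timestep t. The string entries are purely hypothetical stand-ins for the real (h, c) tensor tuples.

B, T = 2, 3  # toy batch size and sequence length
# hypothetical per-sample prev_state lists; the real entries are (h, c) tensor tuples or None
prev_state = [[f"s{b}_{t}" for t in range(T)] for b in range(B)]  # nested as [B][T]
transposed = list(zip(*prev_state))  # nested as [T][B]
assert len(transposed) == T and len(transposed[0]) == B
print(transposed[0])  # ('s0_0', 's1_0'): the timestep-0 states of every sample in the batch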