未验证 提交 b532b4cd 编写于 作者: 蒲源 提交者: GitHub

polish(pu): polish r2d3 for abs priority (#158)

* polish(pu): polish r2d3

* polish(pu): first abs then sum each item in td-error
上级 b7cd6751
...@@ -352,14 +352,15 @@ class R2D3Policy(Policy): ...@@ -352,14 +352,15 @@ class R2D3Policy(Policy):
value_gamma=value_gamma[t], value_gamma=value_gamma[t],
) )
loss.append(l) loss.append(l)
td_error.append(e.abs()) # td_error.append(e.abs()) # first sum then abs
td_error.append(e) # first abs then sum
# loss statistics for debugging # loss statistics for debugging
loss_nstep.append(loss_statistics[0]) loss_nstep.append(loss_statistics[0])
loss_1step.append(loss_statistics[1]) loss_1step.append(loss_statistics[1])
loss_sl.append(loss_statistics[2]) loss_sl.append(loss_statistics[2])
else: else:
l, e = dqfd_nstep_td_error( l, e, loss_statistics = dqfd_nstep_td_error(
td_data, td_data,
self._gamma, self._gamma,
self.lambda1, self.lambda1,
...@@ -371,7 +372,12 @@ class R2D3Policy(Policy): ...@@ -371,7 +372,12 @@ class R2D3Policy(Policy):
value_gamma=value_gamma[t], value_gamma=value_gamma[t],
) )
loss.append(l) loss.append(l)
td_error.append(e.abs()) # td_error.append(e.abs()) # first sum then abs
td_error.append(e) # first abs then sum
# loss statistics for debugging
loss_nstep.append(loss_statistics[0])
loss_1step.append(loss_statistics[1])
loss_sl.append(loss_statistics[2])
loss = sum(loss) / (len(loss) + 1e-8) loss = sum(loss) / (len(loss) + 1e-8)
# loss statistics for debugging # loss statistics for debugging
......
...@@ -669,8 +669,9 @@ def dqfd_nstep_td_error( ...@@ -669,8 +669,9 @@ def dqfd_nstep_td_error(
lambda_n_step_td * td_error_per_sample + lambda_one_step_td * td_error_one_step_per_sample + lambda_n_step_td * td_error_per_sample + lambda_one_step_td * td_error_one_step_per_sample +
lambda_supervised_loss * JE lambda_supervised_loss * JE
) * weight ) * weight
).mean(), lambda_n_step_td * td_error_per_sample + lambda_one_step_td * td_error_one_step_per_sample + ).mean(), lambda_n_step_td * td_error_per_sample.abs() +
lambda_supervised_loss * JE lambda_one_step_td * td_error_one_step_per_sample.abs() + lambda_supervised_loss * JE.abs(),
(td_error_per_sample.mean(), td_error_one_step_per_sample.mean(), JE.mean())
) )
...@@ -775,8 +776,9 @@ def dqfd_nstep_td_error_with_rescale( ...@@ -775,8 +776,9 @@ def dqfd_nstep_td_error_with_rescale(
lambda_n_step_td * td_error_per_sample + lambda_one_step_td * td_error_one_step_per_sample + lambda_n_step_td * td_error_per_sample + lambda_one_step_td * td_error_one_step_per_sample +
lambda_supervised_loss * JE lambda_supervised_loss * JE
) * weight ) * weight
).mean(), lambda_n_step_td * td_error_per_sample + lambda_one_step_td * td_error_one_step_per_sample + ).mean(), lambda_n_step_td * td_error_per_sample.abs() +
lambda_supervised_loss * JE, (td_error_per_sample.mean(), td_error_one_step_per_sample.mean(), JE.mean()) lambda_one_step_td * td_error_one_step_per_sample.abs() + lambda_supervised_loss * JE.abs(),
(td_error_per_sample.mean(), td_error_one_step_per_sample.mean(), JE.mean())
) )
......
...@@ -4,7 +4,7 @@ from ding.entry import serial_pipeline ...@@ -4,7 +4,7 @@ from ding.entry import serial_pipeline
collector_env_num = 8 collector_env_num = 8
evaluator_env_num = 5 evaluator_env_num = 5
pong_r2d2_config = dict( pong_r2d2_config = dict(
exp_name='debug_pong_r2d2_n5_bs2_ul40_rbs1e4_seed0', exp_name='pong_r2d2_n5_bs2_ul40_rbs1e4_seed0',
env=dict( env=dict(
collector_env_num=collector_env_num, collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num, evaluator_env_num=evaluator_env_num,
......
...@@ -6,10 +6,11 @@ module_path = os.path.dirname(__file__) ...@@ -6,10 +6,11 @@ module_path = os.path.dirname(__file__)
collector_env_num = 8 collector_env_num = 8
evaluator_env_num = 5 evaluator_env_num = 5
expert_replay_buffer_size=1000 #TODO 1000 expert_replay_buffer_size = int(5e3) # TODO(pu)
"""agent config""" """agent config"""
pong_r2d3_config = dict( pong_r2d3_config = dict(
exp_name='debug_pong_r2d3_offppoexpert_k0_pho1-256_rbs2e4', exp_name='pong_r2d3_offppoexpert_k0_pho1-4_rbs2e4_ds5e3',
env=dict( env=dict(
# Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess' # Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
manager=dict(shared_memory=True, force_reproducibility=True), manager=dict(shared_memory=True, force_reproducibility=True),
...@@ -62,7 +63,7 @@ pong_r2d3_config = dict( ...@@ -62,7 +63,7 @@ pong_r2d3_config = dict(
env_num=collector_env_num, env_num=collector_env_num,
# The hyperparameter pho, the demo ratio, control the propotion of data coming\ # The hyperparameter pho, the demo ratio, control the propotion of data coming\
# from expert demonstrations versus from the agent's own experience. # from expert demonstrations versus from the agent's own experience.
pho=1/256, # 1/256, #TODO(pu), 0.25, pho=1/4, # TODO(pu)
), ),
eval=dict(env_num=evaluator_env_num, ), eval=dict(env_num=evaluator_env_num, ),
other=dict( other=dict(
...@@ -98,7 +99,7 @@ create_config = pong_r2d3_create_config ...@@ -98,7 +99,7 @@ create_config = pong_r2d3_create_config
"""export config""" """export config"""
expert_pong_r2d3_config = dict( expert_pong_r2d3_config = dict(
exp_name='expert_pong_r2d3_ppoexpert_k0_pho1-256_rbs2e4', exp_name='expert_pong_r2d3_ppoexpert_k0_pho1-4_rbs2e4_ds5e3',
env=dict( env=dict(
# Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess' # Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
manager=dict(shared_memory=True, force_reproducibility=True), manager=dict(shared_memory=True, force_reproducibility=True),
......
...@@ -5,11 +5,11 @@ module_path = os.path.dirname(__file__) ...@@ -5,11 +5,11 @@ module_path = os.path.dirname(__file__)
collector_env_num = 8 collector_env_num = 8
evaluator_env_num = 5 evaluator_env_num = 5
expert_replay_buffer_size=int(5e3) expert_replay_buffer_size = int(5e3) # TODO(pu)
"""agent config""" """agent config"""
pong_r2d3_config = dict( pong_r2d3_config = dict(
exp_name='debug_pong_r2d3_r2d2expert_k0_pho1-4_rbs2e4_ds5e3', exp_name='pong_r2d3_r2d2expert_k0_pho1-4_rbs2e4_ds5e3',
env=dict( env=dict(
# Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess' # Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
manager=dict(shared_memory=True, force_reproducibility=True), manager=dict(shared_memory=True, force_reproducibility=True),
...@@ -45,7 +45,7 @@ pong_r2d3_config = dict( ...@@ -45,7 +45,7 @@ pong_r2d3_config = dict(
# in most environments # in most environments
value_rescale=True, value_rescale=True,
update_per_collect=8, update_per_collect=8,
batch_size=64, # TODO(pu) batch_size=64,
learning_rate=0.0005, learning_rate=0.0005,
target_update_theta=0.001, target_update_theta=0.001,
# DQFD related parameters # DQFD related parameters
...@@ -97,7 +97,7 @@ create_config = pong_r2d3_create_config ...@@ -97,7 +97,7 @@ create_config = pong_r2d3_create_config
"""export config""" """export config"""
expert_pong_r2d3_config = dict( expert_pong_r2d3_config = dict(
exp_name='expert_pong_r2d3_r2d2expert_k0_pho1-4_rbs1e4_ds5e3', exp_name='expert_pong_r2d3_r2d2expert_k0_pho1-4_rbs2e4_ds5e3',
env=dict( env=dict(
# Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess' # Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
manager=dict(shared_memory=True, force_reproducibility=True), manager=dict(shared_memory=True, force_reproducibility=True),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册