From 1040c9fcf669c68092ebbd1edf9cbb1e03eded28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <48008469+puyuan1996@users.noreply.github.com> Date: Wed, 22 Dec 2021 10:34:19 +0800 Subject: [PATCH] fix(pu): fix dqfd compatibility (#161) * polish(pu): polish r2d3 * polish(pu): first abs then sum each item in td-error * fix(pu): fix dqfd compatibility --- ding/policy/dqfd.py | 2 +- ding/rl_utils/tests/test_td.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ding/policy/dqfd.py b/ding/policy/dqfd.py index 22108b6..d64e75f 100644 --- a/ding/policy/dqfd.py +++ b/ding/policy/dqfd.py @@ -209,7 +209,7 @@ class DQFDPolicy(DQNPolicy): data['is_expert'] # set is_expert flag(expert 1, agent 0) ) value_gamma = data.get('value_gamma') - loss, td_error_per_sample = dqfd_nstep_td_error( + loss, td_error_per_sample, loss_statistics = dqfd_nstep_td_error( data_n, self._gamma, self.lambda1, diff --git a/ding/rl_utils/tests/test_td.py b/ding/rl_utils/tests/test_td.py index 06fbe60..3fd32cf 100644 --- a/ding/rl_utils/tests/test_td.py +++ b/ding/rl_utils/tests/test_td.py @@ -253,7 +253,7 @@ def test_dqfd_nstep_td(): data = dqfd_nstep_td_data( q, next_q, action, next_action, reward, done, done_1, None, next_q_one_step, next_action_one_step, is_expert ) - loss, td_error_per_sample = dqfd_nstep_td_error( + loss, td_error_per_sample, loss_statistics = dqfd_nstep_td_error( data, 0.95, lambda_n_step_td=1, lambda_supervised_loss=1, margin_function=0.8, nstep=nstep ) assert td_error_per_sample.shape == (batch_size, ) -- GitLab