Commit a0435286 (unverified)
Authored Dec 29, 2021 by Robin Chen; committed via GitHub on Dec 29, 2021
Parent: 2699aa5e

polish(nyz): update multi-discrete policies (#167)
Showing 4 changed files with 24 additions and 15 deletions (+24 −15):

ding/policy/policy_factory.py          +4  −1
dizoo/common/policy/md_dqn.py          +7  −2
dizoo/common/policy/md_ppo.py          +10 −9
dizoo/common/policy/md_rainbow_dqn.py  +3  −3
ding/policy/policy_factory.py

@@ -31,7 +31,10 @@ class PolicyFactory:
         def forward(data: Dict[int, Any], *args, **kwargs) -> Dict[int, Any]:

             def discrete_random_action(min_val, max_val, shape):
-                return np.random.randint(min_val, max_val, shape)
+                action = np.random.randint(min_val, max_val, shape)
+                if len(action) > 1:
+                    action = list(np.expand_dims(action, axis=1))
+                return action

             def continuous_random_action(min_val, max_val, shape):
                 bounded_below = min_val != float("inf")
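The change above makes the random policy's discrete sampler return a list of per-branch actions when the action shape has more than one entry (the multi-discrete case), instead of a single flat array. A minimal, self-contained sketch of that behaviour using only NumPy; the function body mirrors the diff, while the example calls and shapes are illustrative:

    import numpy as np

    def discrete_random_action(min_val, max_val, shape):
        # Sample one random integer action per discrete branch.
        action = np.random.randint(min_val, max_val, shape)
        if len(action) > 1:
            # Multi-discrete: return a list with one (1,)-shaped array per branch.
            action = list(np.expand_dims(action, axis=1))
        return action

    discrete_random_action(0, 4, (1, ))  # single branch, e.g. array([2])
    discrete_random_action(0, 4, (3, ))  # three branches, e.g. [array([1]), array([3]), array([0])]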
dizoo/common/policy/md_dqn.py

@@ -56,9 +56,10 @@ class MultiDiscreteDQNPolicy(DQNPolicy):
         value_gamma = data.get('value_gamma')
         if isinstance(q_value, list):
-            tl_num = len(q_value)
+            act_num = len(q_value)
             loss, td_error_per_sample = [], []
-            for i in range(tl_num):
+            q_value_list = []
+            for i in range(act_num):
                 td_data = q_nstep_td_data(
                     q_value[i], target_q_value[i], data['action'][i], target_q_action[i], data['reward'],
                     data['done'], data['weight']
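In the multi-discrete branch, q_value, target_q_value and data['action'] are lists indexed by action branch, so the loop above builds one n-step TD tuple per branch. A toy illustration of the data layout this loop expects, with made-up shapes (two branches with 3 and 5 discrete choices; names mirror the diff, values are illustrative):

    import torch

    batch = 4
    # One Q-value tensor per action branch, each of shape (batch, n_choices_i).
    q_value = [torch.randn(batch, 3), torch.randn(batch, 5)]
    target_q_value = [torch.randn(batch, 3), torch.randn(batch, 5)]
    # One action-index tensor per branch, each of shape (batch, ).
    action = [torch.randint(0, 3, (batch, )), torch.randint(0, 5, (batch, ))]

    act_num = len(q_value)  # 2 branches; this is the quantity renamed to act_num above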
@@ -68,8 +69,10 @@ class MultiDiscreteDQNPolicy(DQNPolicy):
                 )
                 loss.append(loss_)
                 td_error_per_sample.append(td_error_per_sample_.abs())
+                q_value_list.append(q_value[i].mean().item())
             loss = sum(loss) / (len(loss) + 1e-8)
             td_error_per_sample = sum(td_error_per_sample) / (len(td_error_per_sample) + 1e-8)
+            q_value_mean = sum(q_value_list) / act_num
         else:
             data_n = q_nstep_td_data(
                 q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], data['weight']

@@ -77,6 +80,7 @@ class MultiDiscreteDQNPolicy(DQNPolicy):
             loss, td_error_per_sample = q_nstep_td_error(
                 data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma
             )
+            q_value_mean = q_value.mean().item()
         # ====================
         # Q-learning update

@@ -94,5 +98,6 @@ class MultiDiscreteDQNPolicy(DQNPolicy):
         return {
             'cur_lr': self._optimizer.defaults['lr'],
             'total_loss': loss.item(),
+            'q_value_mean': q_value_mean,
             'priority': td_error_per_sample.abs().tolist(),
         }
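Taken together, the multi-discrete path computes an independent n-step TD loss per action branch and then averages the losses, the per-sample TD errors (used as replay priorities), and the newly logged mean Q-value across branches. A condensed sketch of that averaging pattern; per_branch_td_loss stands in for the q_nstep_td_data / q_nstep_td_error pair and is purely illustrative:

    import torch

    def multi_discrete_dqn_loss(q_value, target_q_value, action, per_branch_td_loss):
        # q_value / target_q_value: lists of (batch, n_choices_i) tensors, one per branch.
        act_num = len(q_value)
        loss, td_error_per_sample, q_value_list = [], [], []
        for i in range(act_num):
            loss_, td_error_ = per_branch_td_loss(q_value[i], target_q_value[i], action[i])
            loss.append(loss_)
            td_error_per_sample.append(td_error_.abs())
            q_value_list.append(q_value[i].mean().item())
        # Average over branches; the 1e-8 guards against division by zero for an empty list.
        loss = sum(loss) / (len(loss) + 1e-8)
        td_error_per_sample = sum(td_error_per_sample) / (len(td_error_per_sample) + 1e-8)
        q_value_mean = sum(q_value_list) / act_num
        return loss, td_error_per_sample, q_value_mean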
dizoo/common/policy/md_ppo.py

@@ -34,26 +34,19 @@ class MultiDiscretePPOPolicy(PPOPolicy):
         # ====================
         return_infos = []
         self._learn_model.train()
-        if self._value_norm:
-            unnormalized_return = data['adv'] + data['value'] * self._running_mean_std.std
-            data['return'] = unnormalized_return / self._running_mean_std.std
-            self._running_mean_std.update(unnormalized_return.cpu().numpy())
-        else:
-            data['return'] = data['adv'] + data['value']

         for epoch in range(self._cfg.learn.epoch_per_collect):
             if self._recompute_adv:
                 with torch.no_grad():
                     # obs = torch.cat([data['obs'], data['next_obs'][-1:]])
                     value = self._learn_model.forward(data['obs'], mode='compute_critic')['value']
                     next_value = self._learn_model.forward(data['next_obs'], mode='compute_critic')['value']
                     if self._value_norm:
                         value *= self._running_mean_std.std
                         next_value *= self._running_mean_std.std
-                    gae_data_ = gae_data(value, next_value, data['reward'], data['done'])
+                    compute_adv_data = gae_data(value, next_value, data['reward'], data['done'], data['traj_flag'])
                     # GAE need (T, B) shape input and return (T, B) output
-                    data['adv'] = gae(gae_data_, self._gamma, self._gae_lambda)
+                    data['adv'] = gae(compute_adv_data, self._gamma, self._gae_lambda)
                     # value = value[:-1]
                     unnormalized_returns = value + data['adv']
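When advantages are recomputed each epoch, gae expects (T, B)-shaped tensors and, after this change, also receives data['traj_flag'] so that trajectory boundaries can be treated separately from true episode termination. The sketch below is a generic GAE computation over (T, B) tensors, not DI-engine's ding.rl_utils.gae, and it only uses the done mask; it is here to make the recomputation step concrete:

    import torch

    def gae_sketch(value, next_value, reward, done, gamma=0.99, lambd=0.95):
        # value, next_value, reward: (T, B) tensors; done: (T, B) float mask in {0., 1.}.
        # Returns advantages of shape (T, B), accumulated backwards in time.
        adv = torch.zeros_like(reward)
        running = torch.zeros_like(reward[0])
        for t in reversed(range(reward.shape[0])):
            delta = reward[t] + gamma * next_value[t] * (1 - done[t]) - value[t]
            running = delta + gamma * lambd * (1 - done[t]) * running
            adv[t] = running
        return adv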
@@ -65,6 +58,14 @@ class MultiDiscretePPOPolicy(PPOPolicy):
                     data['value'] = value
                     data['return'] = unnormalized_returns
             else:
+                # don't recompute adv
+                if self._value_norm:
+                    unnormalized_return = data['adv'] + data['value'] * self._running_mean_std.std
+                    data['return'] = unnormalized_return / self._running_mean_std.std
+                    self._running_mean_std.update(unnormalized_return.cpu().numpy())
+                else:
+                    data['return'] = data['adv'] + data['value']
             for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True):
                 output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic')
                 adv = batch['adv']
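When value normalization is enabled, the critic target stays in normalized value space: the unnormalized return is rebuilt from the stored advantage and the de-normalized value prediction, then divided by the running standard deviation before being written back to data['return'], and the running statistics are updated from the unnormalized returns. A small numerical illustration with a fixed std standing in for self._running_mean_std.std:

    import torch

    running_std = 5.0                                   # stands in for self._running_mean_std.std
    adv = torch.tensor([1.0, -0.5, 2.0])                # advantages
    value = torch.tensor([0.2, 0.1, 0.4])               # critic outputs in normalized space

    unnormalized_return = adv + value * running_std     # back on the reward scale
    return_target = unnormalized_return / running_std   # critic regression target, normalized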
dizoo/common/policy/md_rainbow_dqn.py

@@ -67,10 +67,10 @@ class MultiDiscreteRainbowDQNPolicy(RainbowDQNPolicy):
                 value_gamma=value_gamma
             )
         else:
-            tl_num = len(q_dist)
+            act_num = len(q_dist)
             losses = []
             td_error_per_samples = []
-            for i in range(tl_num):
+            for i in range(act_num):
                 td_data = dist_nstep_td_data(
                     q_dist[i], target_q_dist[i], data['action'][i], target_q_action[i], data['reward'],
                     data['done'], data['weight']

@@ -87,7 +87,7 @@ class MultiDiscreteRainbowDQNPolicy(RainbowDQNPolicy):
                 losses.append(td_loss)
                 td_error_per_samples.append(td_error_per_sample)
             loss = sum(losses) / (len(losses) + 1e-8)
-            td_error_per_sample_mean = sum(td_error_per_samples)
+            td_error_per_sample_mean = sum(td_error_per_samples) / (len(td_error_per_samples) + 1e-8)
         # ====================
         # Rainbow update
         # ====================
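The second hunk fixes the per-sample TD error used for replay priorities: previously the per-branch errors were only summed, so priorities grew with the number of action branches; now they are averaged, matching how the loss is averaged. A small illustration with made-up per-branch errors:

    import torch

    # Absolute per-sample TD errors for three action branches (values are made up).
    td_error_per_samples = [
        torch.tensor([0.5, 1.0]),
        torch.tensor([0.2, 0.4]),
        torch.tensor([0.9, 0.1]),
    ]

    before = sum(td_error_per_samples)                                      # scale grows with branch count
    after = sum(td_error_per_samples) / (len(td_error_per_samples) + 1e-8)  # branch-averaged priorities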