Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-engine
提交
ae6ab6c7
D
DI-engine
项目概览
OpenDILab开源决策智能平台
/
DI-engine
上一次同步 2 年多
通知
56
Star
321
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-engine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
ae6ab6c7
编写于
1月 01, 2022
作者:
N
niuyazhe
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(nyz): fix exp_name seedx name bug with data generation path
上级
35241df3
变更
14
显示空白变更内容
内联
并排
Showing
14 changed files
with
43 additions
and
28 deletions
+43
-28
ding/entry/application_entry.py
ding/entry/application_entry.py
+2
-0
ding/entry/tests/test_serial_entry.py
ding/entry/tests/test_serial_entry.py
+11
-9
ding/entry/tests/test_serial_entry_algo.py
ding/entry/tests/test_serial_entry_algo.py
+9
-8
ding/policy/dqn.py
ding/policy/dqn.py
+5
-0
dizoo/atari/config/serial/pong/pong_qrdqn_generation_data_config.py
...i/config/serial/pong/pong_qrdqn_generation_data_config.py
+2
-1
dizoo/atari/config/serial/qbert/qbert_qrdqn_generation_data_config.py
...config/serial/qbert/qbert_qrdqn_generation_data_config.py
+2
-1
dizoo/classic_control/cartpole/config/cartpole_cql_config.py
dizoo/classic_control/cartpole/config/cartpole_cql_config.py
+1
-1
dizoo/classic_control/cartpole/config/cartpole_qrdqn_generation_data_config.py
.../cartpole/config/cartpole_qrdqn_generation_data_config.py
+1
-1
dizoo/classic_control/pendulum/config/pendulum_cql_config.py
dizoo/classic_control/pendulum/config/pendulum_cql_config.py
+1
-1
dizoo/classic_control/pendulum/config/pendulum_sac_data_generation_default_config.py
...lum/config/pendulum_sac_data_generation_default_config.py
+2
-1
dizoo/classic_control/pendulum/config/pendulum_td3_bc_config.py
...classic_control/pendulum/config/pendulum_td3_bc_config.py
+1
-1
dizoo/classic_control/pendulum/config/pendulum_td3_data_generation_config.py
...ol/pendulum/config/pendulum_td3_data_generation_config.py
+2
-2
dizoo/mujoco/config/hopper_sac_data_generation_default_config.py
...ujoco/config/hopper_sac_data_generation_default_config.py
+2
-1
dizoo/mujoco/config/hopper_td3_data_generation_config.py
dizoo/mujoco/config/hopper_td3_data_generation_config.py
+2
-1
未找到文件。
ding/entry/application_entry.py
浏览文件 @
ae6ab6c7
...
...
@@ -154,6 +154,7 @@ def collect_demo_data(
if
cfg
.
policy
.
cuda
:
exp_data
=
to_device
(
exp_data
,
'cpu'
)
# Save data transitions.
expert_data_path
=
os
.
path
.
join
(
cfg
.
exp_name
,
expert_data_path
)
offline_data_save_type
(
exp_data
,
expert_data_path
,
data_type
=
cfg
.
policy
.
collect
.
get
(
'data_type'
,
'naive'
))
print
(
'Collect demo data successfully'
)
...
...
@@ -227,6 +228,7 @@ def collect_episodic_demo_data(
if
cfg
.
policy
.
cuda
:
exp_data
=
to_device
(
exp_data
,
'cpu'
)
# Save data transitions.
expert_data_path
=
os
.
path
.
join
(
cfg
.
exp_name
,
expert_data_path
)
offline_data_save_type
(
exp_data
,
expert_data_path
,
data_type
=
cfg
.
policy
.
collect
.
get
(
'data_type'
,
'naive'
))
print
(
'Collect episodic demo data successfully'
)
...
...
ding/entry/tests/test_serial_entry.py
浏览文件 @
ae6ab6c7
...
...
@@ -2,6 +2,7 @@ import pytest
import
time
import
os
from
copy
import
deepcopy
import
torch
from
ding.entry
import
serial_pipeline
,
collect_demo_data
,
serial_pipeline_offline
from
dizoo.classic_control.cartpole.config.cartpole_dqn_config
import
cartpole_dqn_config
,
cartpole_dqn_create_config
...
...
@@ -360,7 +361,9 @@ def test_sqn():
@
pytest
.
mark
.
unittest
def
test_selfplay
():
try
:
selfplay_main
(
deepcopy
(
league_demo_ppo_config
),
seed
=
0
,
max_iterations
=
1
)
config
=
deepcopy
(
league_demo_ppo_config
)
config
.
exp_name
=
'test_selfplay'
selfplay_main
(
config
,
seed
=
0
,
max_iterations
=
1
)
except
Exception
:
assert
False
,
"pipeline fail"
...
...
@@ -368,7 +371,9 @@ def test_selfplay():
@
pytest
.
mark
.
unittest
def
test_league
():
try
:
league_main
(
deepcopy
(
league_demo_ppo_config
),
seed
=
0
,
max_iterations
=
1
)
config
=
deepcopy
(
league_demo_ppo_config
)
config
.
exp_name
=
'test_league'
league_main
(
config
,
seed
=
0
,
max_iterations
=
1
)
except
Exception
as
e
:
assert
False
,
"pipeline fail"
...
...
@@ -395,14 +400,13 @@ def test_cql():
assert
False
,
"pipeline fail"
# collect expert data
import
torch
config
=
[
deepcopy
(
pendulum_sac_data_genearation_default_config
),
deepcopy
(
pendulum_sac_data_genearation_default_create_config
)
]
collect_count
=
1000
expert_data_path
=
config
[
0
].
policy
.
collect
.
save_path
state_dict
=
torch
.
load
(
'./sac/ckpt/iteration_0.pth.tar'
,
map_location
=
'cpu'
)
state_dict
=
torch
.
load
(
'./sac
_seed0
/ckpt/iteration_0.pth.tar'
,
map_location
=
'cpu'
)
try
:
collect_demo_data
(
config
,
seed
=
0
,
collect_count
=
collect_count
,
expert_data_path
=
expert_data_path
,
state_dict
=
state_dict
...
...
@@ -442,11 +446,10 @@ def test_discrete_cql():
except
Exception
:
assert
False
,
"pipeline fail"
# collect expert data
import
torch
config
=
[
deepcopy
(
cartpole_qrdqn_generation_data_config
),
deepcopy
(
cartpole_qrdqn_generation_data_create_config
)]
collect_count
=
1000
expert_data_path
=
config
[
0
].
policy
.
collect
.
save_path
state_dict
=
torch
.
load
(
'./cql_cartpole/ckpt/iteration_0.pth.tar'
,
map_location
=
'cpu'
)
state_dict
=
torch
.
load
(
'./cql_cartpole
_seed0
/ckpt/iteration_0.pth.tar'
,
map_location
=
'cpu'
)
try
:
collect_demo_data
(
config
,
seed
=
0
,
collect_count
=
collect_count
,
expert_data_path
=
expert_data_path
,
state_dict
=
state_dict
...
...
@@ -467,7 +470,7 @@ def test_discrete_cql():
os
.
popen
(
'rm -rf cartpole cartpole_cql'
)
@
pytest
.
mark
.
algo
test
@
pytest
.
mark
.
unit
test
def
test_td3_bc
():
# train expert
config
=
[
deepcopy
(
pendulum_td3_config
),
deepcopy
(
pendulum_td3_create_config
)]
...
...
@@ -479,11 +482,10 @@ def test_td3_bc():
assert
False
,
"pipeline fail"
# collect expert data
import
torch
config
=
[
deepcopy
(
pendulum_td3_generation_config
),
deepcopy
(
pendulum_td3_generation_create_config
)]
collect_count
=
1000
expert_data_path
=
config
[
0
].
policy
.
collect
.
save_path
state_dict
=
torch
.
load
(
'./td3/ckpt/iteration_0.pth.tar'
,
map_location
=
'cpu'
)
state_dict
=
torch
.
load
(
'./td3
_seed0
/ckpt/iteration_0.pth.tar'
,
map_location
=
'cpu'
)
try
:
collect_demo_data
(
config
,
seed
=
0
,
collect_count
=
collect_count
,
expert_data_path
=
expert_data_path
,
state_dict
=
state_dict
...
...
ding/entry/tests/test_serial_entry_algo.py
浏览文件 @
ae6ab6c7
...
...
@@ -281,7 +281,9 @@ def test_acer():
@
pytest
.
mark
.
algotest
def
test_selfplay
():
try
:
selfplay_main
(
deepcopy
(
league_demo_ppo_config
),
seed
=
0
)
config
=
deepcopy
(
league_demo_ppo_config
)
config
.
exp_name
=
'test_selfplay'
selfplay_main
(
config
,
seed
=
0
)
except
Exception
:
assert
False
,
"pipeline fail"
with
open
(
"./algo_record.log"
,
"a+"
)
as
f
:
...
...
@@ -291,7 +293,9 @@ def test_selfplay():
@
pytest
.
mark
.
algotest
def
test_league
():
try
:
league_main
(
deepcopy
(
league_demo_ppo_config
),
seed
=
0
)
config
=
deepcopy
(
league_demo_ppo_config
)
config
.
exp_name
=
'test_league'
league_main
(
config
,
seed
=
0
)
except
Exception
:
assert
False
,
"pipeline fail"
with
open
(
"./algo_record.log"
,
"a+"
)
as
f
:
...
...
@@ -326,14 +330,13 @@ def test_cql():
assert
False
,
"pipeline fail"
# collect expert data
import
torch
config
=
[
deepcopy
(
pendulum_sac_data_genearation_default_config
),
deepcopy
(
pendulum_sac_data_genearation_default_create_config
)
]
collect_count
=
config
[
0
].
policy
.
other
.
replay_buffer
.
replay_buffer_size
expert_data_path
=
config
[
0
].
policy
.
collect
.
save_path
state_dict
=
torch
.
load
(
config
[
0
].
policy
.
learn
.
learner
.
load_path
,
map_location
=
'cpu'
)
state_dict
=
torch
.
load
(
'./sac_seed0/ckpt/iteration_0.pth.tar'
,
map_location
=
'cpu'
)
try
:
collect_demo_data
(
config
,
seed
=
0
,
collect_count
=
collect_count
,
expert_data_path
=
expert_data_path
,
state_dict
=
state_dict
...
...
@@ -362,11 +365,10 @@ def test_discrete_cql():
assert
False
,
"pipeline fail"
# collect expert data
import
torch
config
=
[
deepcopy
(
cartpole_qrdqn_generation_data_config
),
deepcopy
(
cartpole_qrdqn_generation_data_create_config
)]
collect_count
=
config
[
0
].
policy
.
other
.
replay_buffer
.
replay_buffer_size
expert_data_path
=
config
[
0
].
policy
.
collect
.
save_path
state_dict
=
torch
.
load
(
config
[
0
].
policy
.
learn
.
learner
.
load_path
,
map_location
=
'cpu'
)
state_dict
=
torch
.
load
(
'./cql_cartpole_seed0/ckpt/iteration_0.pth.tar'
,
map_location
=
'cpu'
)
try
:
collect_demo_data
(
config
,
seed
=
0
,
collect_count
=
collect_count
,
expert_data_path
=
expert_data_path
,
state_dict
=
state_dict
...
...
@@ -406,11 +408,10 @@ def test_td3_bc():
assert
False
,
"pipeline fail"
# collect expert data
import
torch
config
=
[
deepcopy
(
pendulum_td3_generation_config
),
deepcopy
(
pendulum_td3_generation_create_config
)]
collect_count
=
config
[
0
].
policy
.
other
.
replay_buffer
.
replay_buffer_size
expert_data_path
=
config
[
0
].
policy
.
collect
.
save_path
state_dict
=
torch
.
load
(
config
[
0
].
policy
.
learn
.
learner
.
load_path
,
map_location
=
'cpu'
)
state_dict
=
torch
.
load
(
'./td3_seed0/ckpt/iteration_0.pth.tar'
,
map_location
=
'cpu'
)
try
:
collect_demo_data
(
config
,
seed
=
0
,
collect_count
=
collect_count
,
expert_data_path
=
expert_data_path
,
state_dict
=
state_dict
...
...
ding/policy/dqn.py
浏览文件 @
ae6ab6c7
...
...
@@ -71,12 +71,17 @@ class DQNPolicy(Policy):
config
=
dict
(
type
=
'dqn'
,
# (bool) Whether use cuda in policy
cuda
=
False
,
# (bool) Whether learning policy is the same as collecting data policy(on-policy)
on_policy
=
False
,
# (bool) Whether enable priority experience sample
priority
=
False
,
# (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
priority_IS_weight
=
False
,
# (float) Discount factor(gamma) for returns
discount_factor
=
0.97
,
# (int) The number of step for calculating target q_value
nstep
=
1
,
learn
=
dict
(
# (bool) Whether to use multi gpu
...
...
dizoo/atari/config/serial/pong/pong_qrdqn_generation_data_config.py
浏览文件 @
ae6ab6c7
...
...
@@ -3,6 +3,7 @@ from ding.entry import serial_pipeline
from
easydict
import
EasyDict
pong_qrdqn_config
=
dict
(
exp_name
=
'pong_qrdqn_generation'
,
env
=
dict
(
collector_env_num
=
8
,
evaluator_env_num
=
8
,
...
...
@@ -39,7 +40,7 @@ pong_qrdqn_config = dict(
collect
=
dict
(
n_sample
=
100
,
data_type
=
'hdf5'
,
save_path
=
'
./expert/
expert.pkl'
,
save_path
=
'expert.pkl'
,
),
eval
=
dict
(
evaluator
=
dict
(
eval_freq
=
4000
,
)),
other
=
dict
(
...
...
dizoo/atari/config/serial/qbert/qbert_qrdqn_generation_data_config.py
浏览文件 @
ae6ab6c7
...
...
@@ -3,6 +3,7 @@ from ding.entry import serial_pipeline
from
easydict
import
EasyDict
qbert_qrdqn_config
=
dict
(
exp_name
=
'qbert_qrdqn_geneation'
,
env
=
dict
(
collector_env_num
=
8
,
evaluator_env_num
=
8
,
...
...
@@ -39,7 +40,7 @@ qbert_qrdqn_config = dict(
collect
=
dict
(
n_sample
=
100
,
data_type
=
'hdf5'
,
save_path
=
'
./expert/
expert.pkl'
,
save_path
=
'expert.pkl'
,
),
eval
=
dict
(
evaluator
=
dict
(
eval_freq
=
4000
,
)),
other
=
dict
(
...
...
dizoo/classic_control/cartpole/config/cartpole_cql_config.py
浏览文件 @
ae6ab6c7
...
...
@@ -29,7 +29,7 @@ cartpole_discrete_cql_config = dict(
),
collect
=
dict
(
data_type
=
'hdf5'
,
data_path
=
'./cartpole_generation
/expert_demos.hdf5'
,
data_path
=
'./cartpole_generation
_seed0/expert_demos.hdf5'
,
# user-specific
n_sample
=
80
,
unroll_len
=
1
,
),
...
...
dizoo/classic_control/cartpole/config/cartpole_qrdqn_generation_data_config.py
浏览文件 @
ae6ab6c7
...
...
@@ -37,7 +37,7 @@ cartpole_qrdqn_generation_data_config = dict(
n_sample
=
80
,
unroll_len
=
1
,
data_type
=
'hdf5'
,
save_path
=
'
./cartpole_generation/
expert.pkl'
,
save_path
=
'expert.pkl'
,
),
other
=
dict
(
eps
=
dict
(
...
...
dizoo/classic_control/pendulum/config/pendulum_cql_config.py
浏览文件 @
ae6ab6c7
...
...
@@ -37,7 +37,7 @@ pendulum_cql_default_config = dict(
n_sample
=
1
,
unroll_len
=
1
,
data_type
=
'hdf5'
,
data_path
=
'./
sac/expert_demos.hdf5'
,
data_path
=
'./
peudulum_sac_generation_seed0/expert_demos.hdf5'
,
# user-specific
),
command
=
dict
(),
eval
=
dict
(
evaluator
=
dict
(
eval_freq
=
100
,
)),
...
...
dizoo/classic_control/pendulum/config/pendulum_sac_data_generation_default_config.py
浏览文件 @
ae6ab6c7
from
easydict
import
EasyDict
pendulum_sac_data_genearation_default_config
=
dict
(
exp_name
=
'peudulum_sac_generation'
,
seed
=
0
,
env
=
dict
(
collector_env_num
=
10
,
...
...
@@ -43,7 +44,7 @@ pendulum_sac_data_genearation_default_config = dict(
collect
=
dict
(
n_sample
=
1
,
unroll_len
=
1
,
save_path
=
'
./sac/
expert.pkl'
,
save_path
=
'expert.pkl'
,
data_type
=
'hdf5'
,
),
command
=
dict
(),
...
...
dizoo/classic_control/pendulum/config/pendulum_td3_bc_config.py
浏览文件 @
ae6ab6c7
...
...
@@ -44,7 +44,7 @@ pendulum_td3_bc_config = dict(
noise_sigma
=
0.1
,
collector
=
dict
(
collect_print_freq
=
1000
,
),
data_type
=
'hdf5'
,
data_path
=
'./
td3/expert_demos.hdf5'
,
data_path
=
'./
pendulum_td3_generation_seed0/expert_demos.hdf5'
,
# user-specific
normalize_states
=
True
,
),
eval
=
dict
(
evaluator
=
dict
(
eval_freq
=
100
,
),
),
...
...
dizoo/classic_control/pendulum/config/pendulum_td3_data_generation_config.py
浏览文件 @
ae6ab6c7
from
easydict
import
EasyDict
pendulum_td3_generation_config
=
dict
(
exp_name
=
'
td3
'
,
exp_name
=
'
pendulum_td3_generation
'
,
env
=
dict
(
collector_env_num
=
8
,
evaluator_env_num
=
10
,
...
...
@@ -45,7 +45,7 @@ pendulum_td3_generation_config = dict(
n_sample
=
10
,
noise_sigma
=
0.1
,
collector
=
dict
(
collect_print_freq
=
1000
,
),
save_path
=
'
./td3/
expert.pkl'
,
save_path
=
'expert.pkl'
,
data_type
=
'hdf5'
,
),
eval
=
dict
(
evaluator
=
dict
(
eval_freq
=
100
,
),
),
...
...
dizoo/mujoco/config/hopper_sac_data_generation_default_config.py
浏览文件 @
ae6ab6c7
from
easydict
import
EasyDict
hopper_sac_data_genearation_default_config
=
dict
(
exp
=
'hopper_sac_generation'
,
env
=
dict
(
env_id
=
'Hopper-v3'
,
norm_obs
=
dict
(
use_norm
=
False
,
),
...
...
@@ -45,7 +46,7 @@ hopper_sac_data_genearation_default_config = dict(
collect
=
dict
(
n_sample
=
1
,
unroll_len
=
1
,
save_path
=
'
./default_experiment/
expert_iteration_200000.pkl'
,
save_path
=
'expert_iteration_200000.pkl'
,
),
command
=
dict
(),
eval
=
dict
(),
...
...
dizoo/mujoco/config/hopper_td3_data_generation_config.py
浏览文件 @
ae6ab6c7
from
easydict
import
EasyDict
halfcheetah_td3_default_config
=
dict
(
exp_name
=
'halfcheetah_td3_generation'
,
env
=
dict
(
env_id
=
'Hopper-v3'
,
norm_obs
=
dict
(
use_norm
=
False
,
),
...
...
@@ -49,7 +50,7 @@ halfcheetah_td3_default_config = dict(
n_sample
=
1
,
unroll_len
=
1
,
noise_sigma
=
0.1
,
save_path
=
'
./td3/
expert.pkl'
,
save_path
=
'expert.pkl'
,
data_type
=
'hdf5'
,
),
other
=
dict
(
replay_buffer
=
dict
(
replay_buffer_size
=
1000000
,
),
),
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录