Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-engine
提交
35241df3
D
DI-engine
项目概览
OpenDILab开源决策智能平台
/
DI-engine
上一次同步 2 年多
通知
56
Star
321
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-engine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
35241df3
编写于
1月 01, 2022
作者:
N
niuyazhe
1
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feature(nyz): add vim in docker and add multiple seed cli
上级
58084df3
变更
3
隐藏空白更改
内联
并排
Showing
3 changed files
with
104 additions
and
88 deletions
+104
-88
Dockerfile.base
Dockerfile.base
+1
-1
ding/config/config.py
ding/config/config.py
+4
-0
ding/entry/cli.py
ding/entry/cli.py
+99
-87
未找到文件。
Dockerfile.base
浏览文件 @
35241df3
...
...
@@ -3,7 +3,7 @@ FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-runtime
WORKDIR /ding
RUN apt update \
&& apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git gcc \g++ make locales -y \
&& apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git
vim
gcc \g++ make locales -y \
&& apt clean \
&& rm -rf /var/cache/apt/* \
&& sed -i '/en_US.UTF-8/s/^# //g' /etc/locale.gen \
...
...
ding/config/config.py
浏览文件 @
35241df3
...
...
@@ -414,6 +414,8 @@ def compile_config(
cfg
.
policy
.
eval
.
evaluator
.
n_episode
=
cfg
.
env
.
n_evaluator_episode
if
'exp_name'
not
in
cfg
:
cfg
.
exp_name
=
'default_experiment'
# add seed as suffix of exp_name
cfg
.
exp_name
=
cfg
.
exp_name
+
'_seed{}'
.
format
(
seed
)
if
save_cfg
:
if
not
os
.
path
.
exists
(
cfg
.
exp_name
):
try
:
...
...
@@ -524,6 +526,8 @@ def compile_config_parallel(
cfg
.
system
.
coordinator
=
deep_merge_dicts
(
Coordinator
.
default_config
(),
cfg
.
system
.
coordinator
)
# seed
cfg
.
seed
=
seed
# add seed as suffix of exp_name
cfg
.
exp_name
=
cfg
.
exp_name
+
'_seed{}'
.
format
(
seed
)
if
save_cfg
:
save_config
(
cfg
,
save_path
)
...
...
ding/entry/cli.py
浏览文件 @
35241df3
from
typing
import
List
,
Union
import
click
from
click.core
import
Context
,
Option
import
numpy
as
np
from
ding
import
__TITLE__
,
__VERSION__
,
__AUTHOR__
,
__AUTHOR_EMAIL__
from
.predefined_config
import
get_predefined_config
...
...
@@ -65,7 +67,8 @@ CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
'-s'
,
'--seed'
,
type
=
int
,
default
=
0
,
default
=
[
0
],
multiple
=
True
,
help
=
'random generator seed(for all the possible package: random, numpy, torch and user env)'
)
@
click
.
option
(
'-e'
,
'--env'
,
type
=
str
,
help
=
'RL env name'
)
...
...
@@ -117,7 +120,7 @@ def cli(
# serial/eval
mode
:
str
,
config
:
str
,
seed
:
int
,
seed
:
Union
[
int
,
List
]
,
env
:
str
,
policy
:
str
,
train_iter
:
int
,
...
...
@@ -155,89 +158,98 @@ def cli(
from
..utils.profiler_helper
import
Profiler
profiler
=
Profiler
()
profiler
.
profile
(
profile
)
if
mode
==
'serial'
:
from
.serial_entry
import
serial_pipeline
if
config
is
None
:
config
=
get_predefined_config
(
env
,
policy
)
serial_pipeline
(
config
,
seed
,
max_iterations
=
train_iter
)
elif
mode
==
'serial_onpolicy'
:
from
.serial_entry_onpolicy
import
serial_pipeline_onpolicy
if
config
is
None
:
config
=
get_predefined_config
(
env
,
policy
)
serial_pipeline_onpolicy
(
config
,
seed
,
max_iterations
=
train_iter
)
elif
mode
==
'serial_sqil'
:
if
config
==
'lunarlander_sqil_config.py'
or
'cartpole_sqil_config.py'
or
'pong_sqil_config.py'
\
or
'spaceinvaders_sqil_config.py'
or
'qbert_sqil_config.py'
:
from
.serial_entry_sqil
import
serial_pipeline_sqil
if
config
is
None
:
config
=
get_predefined_config
(
env
,
policy
)
expert_config
=
input
(
"Enter the name of the config you used to generate your expert model: "
)
serial_pipeline_sqil
(
config
,
expert_config
,
seed
,
max_iterations
=
train_iter
)
elif
mode
==
'serial_reward_model'
:
from
.serial_entry_reward_model
import
serial_pipeline_reward_model
if
config
is
None
:
config
=
get_predefined_config
(
env
,
policy
)
serial_pipeline_reward_model
(
config
,
seed
,
max_iterations
=
train_iter
)
elif
mode
==
'serial_gail'
:
from
.serial_entry_gail
import
serial_pipeline_gail
if
config
is
None
:
config
=
get_predefined_config
(
env
,
policy
)
expert_config
=
input
(
"Enter the name of the config you used to generate your expert model: "
)
serial_pipeline_gail
(
config
,
expert_config
,
seed
,
max_iterations
=
train_iter
,
collect_data
=
True
)
elif
mode
==
'serial_dqfd'
:
from
.serial_entry_dqfd
import
serial_pipeline_dqfd
if
config
is
None
:
config
=
get_predefined_config
(
env
,
policy
)
expert_config
=
input
(
"Enter the name of the config you used to generate your expert model: "
)
assert
(
expert_config
==
config
[:
config
.
find
(
'_dqfd'
)]
+
'_dqfd_config.py'
),
"DQFD only supports "
\
+
"the models used in q learning now; However, one should still type the DQFD config in this "
\
+
"place, i.e., {}{}"
.
format
(
config
[:
config
.
find
(
'_dqfd'
)],
'_dqfd_config.py'
)
serial_pipeline_dqfd
(
config
,
expert_config
,
seed
,
max_iterations
=
train_iter
)
elif
mode
==
'serial_trex'
:
from
.serial_entry_trex
import
serial_pipeline_reward_model_trex
if
config
is
None
:
config
=
get_predefined_config
(
env
,
policy
)
serial_pipeline_reward_model_trex
(
config
,
seed
,
max_iterations
=
train_iter
)
elif
mode
==
'serial_trex_onpolicy'
:
from
.serial_entry_trex_onpolicy
import
serial_pipeline_reward_model_trex_onpolicy
if
config
is
None
:
config
=
get_predefined_config
(
env
,
policy
)
serial_pipeline_reward_model_trex_onpolicy
(
config
,
seed
,
max_iterations
=
train_iter
)
elif
mode
==
'parallel'
:
from
.parallel_entry
import
parallel_pipeline
parallel_pipeline
(
config
,
seed
,
enable_total_log
,
disable_flask_log
)
elif
mode
==
'dist'
:
from
.dist_entry
import
dist_launch_coordinator
,
dist_launch_collector
,
dist_launch_learner
,
\
dist_prepare_config
,
dist_launch_learner_aggregator
,
dist_launch_spawn_learner
,
\
dist_add_replicas
,
dist_delete_replicas
,
dist_restart_replicas
if
module
==
'config'
:
dist_prepare_config
(
config
,
seed
,
platform
,
coordinator_host
,
learner_host
,
collector_host
,
coordinator_port
,
learner_port
,
collector_port
)
elif
module
==
'coordinator'
:
dist_launch_coordinator
(
config
,
seed
,
coordinator_port
,
disable_flask_log
)
elif
module
==
'learner_aggregator'
:
dist_launch_learner_aggregator
(
config
,
seed
,
aggregator_host
,
aggregator_port
,
module_name
,
disable_flask_log
)
elif
module
==
'collector'
:
dist_launch_collector
(
config
,
seed
,
collector_port
,
module_name
,
disable_flask_log
)
elif
module
==
'learner'
:
dist_launch_learner
(
config
,
seed
,
learner_port
,
module_name
,
disable_flask_log
)
elif
module
==
'spawn_learner'
:
dist_launch_spawn_learner
(
config
,
seed
,
learner_port
,
module_name
,
disable_flask_log
)
elif
add
in
[
'collector'
,
'learner'
]:
dist_add_replicas
(
add
,
kubeconfig
,
replicas
,
coordinator_name
,
namespace
,
cpus
,
gpus
,
memory
)
elif
delete
in
[
'collector'
,
'learner'
]:
dist_delete_replicas
(
delete
,
kubeconfig
,
replicas
,
coordinator_name
,
namespace
)
elif
restart
in
[
'collector'
,
'learner'
]:
dist_restart_replicas
(
restart
,
kubeconfig
,
coordinator_name
,
namespace
,
restart_pod_name
)
else
:
raise
Exception
elif
mode
==
'eval'
:
from
.application_entry
import
eval
if
config
is
None
:
config
=
get_predefined_config
(
env
,
policy
)
eval
(
config
,
seed
,
load_path
=
load_path
,
replay_path
=
replay_path
)
def run_single_pipeline(seed, config):
    """
    Dispatch one training/eval run for a single integer ``seed``.

    Nested inside ``cli``: reads ``mode``, ``env``, ``policy``, ``train_iter``
    and all the dist/parallel options from the enclosing closure.

    :param seed: random seed for this run.
    :param config: path to a config file, or ``None`` to look up a
        predefined config from ``env`` + ``policy``.
    """
    if mode == 'serial':
        from .serial_entry import serial_pipeline
        if config is None:
            config = get_predefined_config(env, policy)
        serial_pipeline(config, seed, max_iterations=train_iter)
    elif mode == 'serial_onpolicy':
        from .serial_entry_onpolicy import serial_pipeline_onpolicy
        if config is None:
            config = get_predefined_config(env, policy)
        serial_pipeline_onpolicy(config, seed, max_iterations=train_iter)
    elif mode == 'serial_sqil':
        # BUG FIX: the original guard was `config == 'x_config.py' or 'y_config.py' or ...`,
        # which is always truthy (each bare string literal evaluates to True), so it never
        # restricted anything. Use a real substring-membership test instead; substring (not
        # equality) so full paths like 'dizoo/.../pong_sqil_config.py' still match.
        supported_sqil = (
            'lunarlander_sqil', 'cartpole_sqil', 'pong_sqil', 'spaceinvaders_sqil', 'qbert_sqil'
        )
        if config is None or any(name in config for name in supported_sqil):
            from .serial_entry_sqil import serial_pipeline_sqil
            if config is None:
                config = get_predefined_config(env, policy)
            expert_config = input("Enter the name of the config you used to generate your expert model: ")
            serial_pipeline_sqil(config, expert_config, seed, max_iterations=train_iter)
        else:
            raise ValueError("serial_sqil only supports the following configs: {}".format(supported_sqil))
    elif mode == 'serial_reward_model':
        from .serial_entry_reward_model import serial_pipeline_reward_model
        if config is None:
            config = get_predefined_config(env, policy)
        serial_pipeline_reward_model(config, seed, max_iterations=train_iter)
    elif mode == 'serial_gail':
        from .serial_entry_gail import serial_pipeline_gail
        if config is None:
            config = get_predefined_config(env, policy)
        expert_config = input("Enter the name of the config you used to generate your expert model: ")
        serial_pipeline_gail(config, expert_config, seed, max_iterations=train_iter, collect_data=True)
    elif mode == 'serial_dqfd':
        from .serial_entry_dqfd import serial_pipeline_dqfd
        if config is None:
            config = get_predefined_config(env, policy)
        expert_config = input("Enter the name of the config you used to generate your expert model: ")
        # DQFD requires the expert config to share the config prefix with the DQFD config.
        assert (expert_config == config[:config.find('_dqfd')] + '_dqfd_config.py'), "DQFD only supports " \
            + "the models used in q learning now; However, one should still type the DQFD config in this " \
            + "place, i.e., {}{}".format(config[:config.find('_dqfd')], '_dqfd_config.py')
        serial_pipeline_dqfd(config, expert_config, seed, max_iterations=train_iter)
    elif mode == 'serial_trex':
        from .serial_entry_trex import serial_pipeline_reward_model_trex
        if config is None:
            config = get_predefined_config(env, policy)
        serial_pipeline_reward_model_trex(config, seed, max_iterations=train_iter)
    elif mode == 'serial_trex_onpolicy':
        from .serial_entry_trex_onpolicy import serial_pipeline_reward_model_trex_onpolicy
        if config is None:
            config = get_predefined_config(env, policy)
        serial_pipeline_reward_model_trex_onpolicy(config, seed, max_iterations=train_iter)
    elif mode == 'parallel':
        from .parallel_entry import parallel_pipeline
        parallel_pipeline(config, seed, enable_total_log, disable_flask_log)
    elif mode == 'dist':
        from .dist_entry import dist_launch_coordinator, dist_launch_collector, dist_launch_learner, \
            dist_prepare_config, dist_launch_learner_aggregator, dist_launch_spawn_learner, \
            dist_add_replicas, dist_delete_replicas, dist_restart_replicas
        if module == 'config':
            dist_prepare_config(
                config, seed, platform, coordinator_host, learner_host, collector_host, coordinator_port,
                learner_port, collector_port
            )
        elif module == 'coordinator':
            dist_launch_coordinator(config, seed, coordinator_port, disable_flask_log)
        elif module == 'learner_aggregator':
            dist_launch_learner_aggregator(
                config, seed, aggregator_host, aggregator_port, module_name, disable_flask_log
            )
        elif module == 'collector':
            dist_launch_collector(config, seed, collector_port, module_name, disable_flask_log)
        elif module == 'learner':
            dist_launch_learner(config, seed, learner_port, module_name, disable_flask_log)
        elif module == 'spawn_learner':
            dist_launch_spawn_learner(config, seed, learner_port, module_name, disable_flask_log)
        elif add in ['collector', 'learner']:
            dist_add_replicas(add, kubeconfig, replicas, coordinator_name, namespace, cpus, gpus, memory)
        elif delete in ['collector', 'learner']:
            dist_delete_replicas(delete, kubeconfig, replicas, coordinator_name, namespace)
        elif restart in ['collector', 'learner']:
            dist_restart_replicas(restart, kubeconfig, coordinator_name, namespace, restart_pod_name)
        else:
            raise Exception
    elif mode == 'eval':
        # NOTE(review): this project-level `eval` shadows the builtin inside this branch;
        # kept as-is since it is the project's public entry name.
        from .application_entry import eval
        if config is None:
            config = get_predefined_config(env, policy)
        eval(config, seed, load_path=load_path, replay_path=replay_path)
# Fan out over one or many seeds. The `--seed` option is `multiple=True`, so click
# hands us a tuple; the declared type is Union[int, List], so a bare int (e.g. a
# programmatic caller) must also be accepted — the original raised TypeError for it.
if isinstance(seed, (list, tuple)):
    # Input validation via raise, not assert: asserts are stripped under `python -O`.
    if len(seed) == 0:
        raise ValueError("Please input at least 1 seed")
    for s in seed:
        run_single_pipeline(s, config)
elif isinstance(seed, int):
    # Single-seed convenience path (matches the Union[int, List] annotation).
    run_single_pipeline(seed, config)
else:
    raise TypeError("invalid seed type: {}".format(type(seed)))
OpenDILab开源决策智能平台
@m0_55289267
mentioned in commit
b51cc77c
·
1月 02, 2022
mentioned in commit
b51cc77c
mentioned in commit b51cc77c617cb991fc7c992b18d1bd5667d196a7
开关提交列表
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录