DI-engine (OpenDILab open-source decision intelligence platform)
Commit 5c6df8b3
Authored Nov 13, 2021 by niuyazhe

Merge branch 'main' into feature/buffer

Parents: 0bdda6b7 3a91c429
Changes: 3 files, +29 additions, -25 deletions

  Dockerfile.base                                                +1   -1
  ding/worker/collector/battle_interaction_serial_evaluator.py   +5   -4
  ding/worker/collector/battle_sample_serial_collector.py        +23  -20
Dockerfile.base
@@ -3,7 +3,7 @@ FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-runtime
 WORKDIR /ding
 RUN apt update \
-    && apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl gcc \g++ make locales -y \
+    && apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git gcc \g++ make locales -y \
     && apt clean \
     && rm -rf /var/cache/apt/* \
     && sed -i '/en_US.UTF-8/s/^# //g' /etc/locale.gen \
ding/worker/collector/battle_interaction_serial_evaluator.py
@@ -17,7 +17,7 @@ from .base_serial_evaluator import ISerialEvaluator
 class BattleInteractionSerialEvaluator(ISerialEvaluator):
     """
     Overview:
-        1v1 battle evaluator class.
+        Multiple player battle evaluator class.
     Interfaces:
         __init__, reset, reset_policy, reset_env, close, should_eval, eval
     Property:
@@ -108,8 +108,9 @@ class BattleInteractionSerialEvaluator(ISerialEvaluator):
         """
         assert hasattr(self, '_env'), "please set env first"
         if _policy is not None:
-            assert len(_policy) == 2, "1v1 serial evaluator needs 2 policy, but found {}".format(len(_policy))
+            assert len(_policy) > 1, "battle evaluator needs more than 1 policy, but found {}".format(len(_policy))
             self._policy = _policy
+            self._policy_num = len(self._policy)
         for p in self._policy:
             p.reset()
@@ -192,7 +193,7 @@ class BattleInteractionSerialEvaluator(ISerialEvaluator):
         assert n_episode is not None, "please indicate eval n_episode"
         envstep_count = 0
         info = {}
-        return_info = [[] for _ in range(2)]
+        return_info = [[] for _ in range(self._policy_num)]
         eval_monitor = VectorEvalMonitor(self._env.env_num, n_episode)
         self._env.reset()
         for p in self._policy:
@@ -223,7 +224,7 @@ class BattleInteractionSerialEvaluator(ISerialEvaluator):
                 if 'episode_info' in t.info[0]:
                     eval_monitor.update_info(env_id, t.info[0]['episode_info'])
                 eval_monitor.update_reward(env_id, reward)
-                for policy_id in range(2):
+                for policy_id in range(self._policy_num):
                     return_info[policy_id].append(t.info[policy_id])
                 self._logger.info(
                     "[EVALUATOR]env {} finish episode, final reward: {}, current episode: {}".format(
...
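Note: the net effect of the hunks above is that the evaluator no longer assumes a 1v1 match. The hardcoded len(_policy) == 2 check and range(2) loops are replaced by self._policy_num, derived from whatever policy list is passed to reset_policy. The standalone toy sketch below illustrates only that pattern; ToyPolicy and the free-standing reset_policy function are hypothetical stand-ins, not the DI-engine API.

class ToyPolicy:
    """Hypothetical stand-in for a DI-engine policy; only reset() matters here."""

    def reset(self):
        pass


def reset_policy(policies):
    # mirrors the new check: any battle with more than one policy is accepted
    assert len(policies) > 1, "battle evaluator needs more than 1 policy, but found {}".format(len(policies))
    policy_num = len(policies)
    for p in policies:
        p.reset()
    # per-policy result containers now scale with the number of players
    return_info = [[] for _ in range(policy_num)]
    return policy_num, return_info


# a 3-player battle, which the pre-commit 1v1 assert would have rejected
policy_num, return_info = reset_policy([ToyPolicy() for _ in range(3)])
print(policy_num, return_info)  # -> 3 [[], [], []]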
ding/worker/collector/battle_sample_serial_collector.py
 from typing import Optional, Any, List, Tuple
-from collections import namedtuple, deque
+from collections import namedtuple
 from easydict import EasyDict
 import numpy as np
 import torch
@@ -14,7 +14,7 @@ from .base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF,
 class BattleSampleSerialCollector(ISerialCollector):
     """
     Overview:
-        Sample collector(n_sample) with two policy battle
+        Sample collector(n_sample) with multiple(n VS n) policy battle
     Interfaces:
         __init__, reset, reset_env, reset_policy, collect, close
     Property:
@@ -91,12 +91,17 @@ class BattleSampleSerialCollector(ISerialCollector):
         """
         assert hasattr(self, '_env'), "please set env first"
         if _policy is not None:
-            assert len(_policy) == 2, "1v1 sample collector needs 2 policy, but found {}".format(len(_policy))
+            assert len(_policy) > 1, "battle sample collector needs more than 1 policy, but found {}".format(
+                len(_policy)
+            )
             self._policy = _policy
+            self._policy_num = len(self._policy)
             self._default_n_sample = _policy[0].get_attribute('cfg').collect.get('n_sample', None)
             self._unroll_len = _policy[0].get_attribute('unroll_len')
             self._on_policy = _policy[0].get_attribute('cfg').on_policy
-            self._policy_collect_data = [getattr(self._policy[i], 'collect_data', True) for i in range(2)]
+            self._policy_collect_data = [
+                getattr(self._policy[i], 'collect_data', True) for i in range(self._policy_num)
+            ]
         if self._default_n_sample is not None:
             self._traj_len = max(
                 self._unroll_len,
@@ -136,7 +141,7 @@ class BattleSampleSerialCollector(ISerialCollector):
         # _traj_buffer is {env_id: {policy_id: TrajBuffer}}, is used to store traj_len pieces of transitions
         self._traj_buffer = {
-            env_id: {policy_id: TrajBuffer(maxlen=self._traj_len) for policy_id in range(2)}
+            env_id: {policy_id: TrajBuffer(maxlen=self._traj_len) for policy_id in range(self._policy_num)}
             for env_id in range(self._env_num)
         }
         self._env_info = {env_id: {'time': 0., 'step': 0, 'train_sample': 0} for env_id in range(self._env_num)}
@@ -221,9 +226,9 @@ class BattleSampleSerialCollector(ISerialCollector):
         )
         if policy_kwargs is None:
             policy_kwargs = {}
-        collected_sample = [0 for _ in range(2)]
-        return_data = [[] for _ in range(2)]
-        return_info = [[] for _ in range(2)]
+        collected_sample = [0 for _ in range(self._policy_num)]
+        return_data = [[] for _ in range(self._policy_num)]
+        return_info = [[] for _ in range(self._policy_num)]
         while any([c < n_sample for i, c in enumerate(collected_sample) if self._policy_collect_data[i]]):
             with self._timer:
@@ -281,12 +286,12 @@ class BattleSampleSerialCollector(ISerialCollector):
                 if timestep.done:
                     self._total_episode_count += 1
                     info = {
-                        'reward0': timestep.info[0]['final_eval_reward'],
-                        'reward1': timestep.info[1]['final_eval_reward'],
                         'time': self._env_info[env_id]['time'],
                         'step': self._env_info[env_id]['step'],
                         'train_sample': self._env_info[env_id]['train_sample'],
                     }
+                    for i in range(self._policy_num):
+                        info['reward{}'.format(i)] = timestep.info[i]['final_eval_reward']
                     self._episode_info.append(info)
                     for i, p in enumerate(self._policy):
                         p.reset([env_id])
@@ -311,8 +316,10 @@ class BattleSampleSerialCollector(ISerialCollector):
         episode_count = len(self._episode_info)
         envstep_count = sum([d['step'] for d in self._episode_info])
         duration = sum([d['time'] for d in self._episode_info])
-        episode_reward0 = [d['reward0'] for d in self._episode_info]
-        episode_reward1 = [d['reward1'] for d in self._episode_info]
+        episode_reward = []
+        for i in range(self._policy_num):
+            episode_reward_item = [d['reward{}'.format(i)] for d in self._episode_info]
+            episode_reward.append(episode_reward_item)
         self._total_duration += duration
         info = {
             'episode_count': episode_count,
@@ -321,18 +328,14 @@ class BattleSampleSerialCollector(ISerialCollector):
             'avg_envstep_per_sec': envstep_count / duration,
             'avg_episode_per_sec': episode_count / duration,
             'collect_time': duration,
-            'reward0_mean': np.mean(episode_reward0),
-            'reward0_std': np.std(episode_reward0),
-            'reward0_max': np.max(episode_reward0),
-            'reward0_min': np.min(episode_reward0),
-            'reward1_mean': np.mean(episode_reward1),
-            'reward1_std': np.std(episode_reward1),
-            'reward1_max': np.max(episode_reward1),
-            'reward1_min': np.min(episode_reward1),
             'total_envstep_count': self._total_envstep_count,
             'total_episode_count': self._total_episode_count,
             'total_duration': self._total_duration,
         }
+        for k, fn in {'mean': np.mean, 'std': np.std, 'max': np.max, 'min': np.min}.items():
+            for i in range(self._policy_num):
+                # such as reward0_mean
+                info['reward{}_{}'.format(i, k)] = fn(episode_reward[i])
         self._episode_info.clear()
         self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))
         for k, v in info.items():
...
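Note: the last two hunks above replace the fixed reward0*/reward1* statistics with loops over self._policy_num, so the logged keys become reward{i}_mean/std/max/min for any number of players. Below is a self-contained sketch of that aggregation; the episode records are invented toy data for illustration, not actual collector output.

import numpy as np

# toy episode records in the shape the collector now produces:
# one 'reward{i}' key per policy instead of fixed 'reward0'/'reward1'
policy_num = 3
episode_info = [
    {'reward0': 1.0, 'reward1': -1.0, 'reward2': 0.5},
    {'reward0': 0.0, 'reward1': 2.0, 'reward2': -0.5},
]

# gather per-policy reward lists, as in the new @@ -311 hunk
episode_reward = []
for i in range(policy_num):
    episode_reward.append([d['reward{}'.format(i)] for d in episode_info])

# aggregate into reward{i}_{stat} keys, as in the new @@ -321 hunk
info = {}
for k, fn in {'mean': np.mean, 'std': np.std, 'max': np.max, 'min': np.min}.items():
    for i in range(policy_num):
        # e.g. reward0_mean, reward2_min, ...
        info['reward{}_{}'.format(i, k)] = fn(episode_reward[i])

print(info['reward2_mean'])  # -> 0.0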