Commit 5c6df8b3 authored by niuyazhe

Merge branch 'main' into feature/buffer

......@@ -3,7 +3,7 @@ FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-runtime
 WORKDIR /ding
 RUN apt update \
-    && apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl gcc g++ make locales -y \
+    && apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git gcc g++ make locales -y \
     && apt clean \
     && rm -rf /var/cache/apt/* \
     && sed -i '/en_US.UTF-8/s/^# //g' /etc/locale.gen \
......
......@@ -17,7 +17,7 @@ from .base_serial_evaluator import ISerialEvaluator
 class BattleInteractionSerialEvaluator(ISerialEvaluator):
     """
     Overview:
-        1v1 battle evaluator class.
+        Multiple player battle evaluator class.
     Interfaces:
         __init__, reset, reset_policy, reset_env, close, should_eval, eval
     Property:
......@@ -108,8 +108,9 @@ class BattleInteractionSerialEvaluator(ISerialEvaluator):
         """
         assert hasattr(self, '_env'), "please set env first"
         if _policy is not None:
-            assert len(_policy) == 2, "1v1 serial evaluator needs 2 policy, but found {}".format(len(_policy))
+            assert len(_policy) > 1, "battle evaluator needs more than 1 policy, but found {}".format(len(_policy))
             self._policy = _policy
+            self._policy_num = len(self._policy)
         for p in self._policy:
             p.reset()
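
Note: the relaxed assertion plus the cached `self._policy_num` is what the rest of this diff builds on. A minimal runnable sketch of the same validate-and-store pattern; `DummyPolicy` is hypothetical, standing in for a DI-engine eval-mode policy:

```python
from typing import List


class DummyPolicy:
    """Hypothetical stand-in for an eval-mode policy."""

    def reset(self) -> None:
        pass


def reset_policy(policies: List[DummyPolicy]) -> int:
    # Any count > 1 is accepted now, not just exactly 2.
    assert len(policies) > 1, "battle evaluator needs more than 1 policy, but found {}".format(len(policies))
    for p in policies:
        p.reset()
    return len(policies)  # cached as self._policy_num in the real class


print(reset_policy([DummyPolicy() for _ in range(3)]))  # a 3-player battle is now valid
```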
......@@ -192,7 +193,7 @@ class BattleInteractionSerialEvaluator(ISerialEvaluator):
         assert n_episode is not None, "please indicate eval n_episode"
         envstep_count = 0
         info = {}
-        return_info = [[] for _ in range(2)]
+        return_info = [[] for _ in range(self._policy_num)]
         eval_monitor = VectorEvalMonitor(self._env.env_num, n_episode)
         self._env.reset()
         for p in self._policy:
......@@ -223,7 +224,7 @@ class BattleInteractionSerialEvaluator(ISerialEvaluator):
                 if 'episode_info' in t.info[0]:
                     eval_monitor.update_info(env_id, t.info[0]['episode_info'])
                 eval_monitor.update_reward(env_id, reward)
-                for policy_id in range(2):
+                for policy_id in range(self._policy_num):
                     return_info[policy_id].append(t.info[policy_id])
                 self._logger.info(
                     "[EVALUATOR]env {} finish episode, final reward: {}, current episode: {}".format(
......
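
Note on the two evaluator hunks above: both the pre-allocation (`return_info = [[] for _ in range(self._policy_num)]`) and the fan-out loop key results by policy index. The comprehension matters because `[[]] * n` would alias one list n times. A small sketch of the fan-out, with fabricated `info` payloads standing in for `t.info`:

```python
policy_num = 3
return_info = [[] for _ in range(policy_num)]  # independent lists, not [[]] * policy_num

# Fabricated per-episode info, standing in for timestep.info from the env.
episode_infos = [
    [{'final_eval_reward': r} for r in rewards]
    for rewards in ([1.0, -1.0, 0.0], [0.5, 0.5, -1.0])
]
for t_info in episode_infos:
    for policy_id in range(policy_num):
        return_info[policy_id].append(t_info[policy_id])

print([len(x) for x in return_info])  # [2, 2, 2]: one entry per finished episode
```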
 from typing import Optional, Any, List, Tuple
-from collections import namedtuple, deque
+from collections import namedtuple
 from easydict import EasyDict
 import numpy as np
 import torch
......@@ -14,7 +14,7 @@ from .base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF,
 class BattleSampleSerialCollector(ISerialCollector):
     """
     Overview:
-        Sample collector(n_sample) with two policy battle
+        Sample collector(n_sample) with multiple(n VS n) policy battle
     Interfaces:
         __init__, reset, reset_env, reset_policy, collect, close
     Property:
......@@ -91,12 +91,17 @@ class BattleSampleSerialCollector(ISerialCollector):
         """
         assert hasattr(self, '_env'), "please set env first"
         if _policy is not None:
-            assert len(_policy) == 2, "1v1 sample collector needs 2 policy, but found {}".format(len(_policy))
+            assert len(_policy) > 1, "battle sample collector needs more than 1 policy, but found {}".format(
+                len(_policy)
+            )
             self._policy = _policy
+            self._policy_num = len(self._policy)
             self._default_n_sample = _policy[0].get_attribute('cfg').collect.get('n_sample', None)
             self._unroll_len = _policy[0].get_attribute('unroll_len')
             self._on_policy = _policy[0].get_attribute('cfg').on_policy
-            self._policy_collect_data = [getattr(self._policy[i], 'collect_data', True) for i in range(2)]
+            self._policy_collect_data = [
+                getattr(self._policy[i], 'collect_data', True) for i in range(self._policy_num)
+            ]
         if self._default_n_sample is not None:
             self._traj_len = max(
                 self._unroll_len,
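
Note: `getattr(policy, 'collect_data', True)` treats data collection as an opt-out capability, so policies that never define the attribute keep the old behavior. Checked in isolation with two hypothetical policy classes:

```python
class LearnedPolicy:
    """Hypothetical learner whose transitions should be stored."""
    pass  # no 'collect_data' attribute -> defaults to True


class ScriptedBot:
    """Hypothetical fixed opponent whose transitions we skip."""
    collect_data = False


policies = [LearnedPolicy(), ScriptedBot(), LearnedPolicy()]
policy_collect_data = [getattr(p, 'collect_data', True) for p in policies]
print(policy_collect_data)  # [True, False, True]
```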
......@@ -136,7 +141,7 @@ class BattleSampleSerialCollector(ISerialCollector):
         # _traj_buffer is {env_id: {policy_id: TrajBuffer}}, is used to store traj_len pieces of transitions
         self._traj_buffer = {
             env_id: {policy_id: TrajBuffer(maxlen=self._traj_len)
-                     for policy_id in range(2)}
+                     for policy_id in range(self._policy_num)}
             for env_id in range(self._env_num)
         }
         self._env_info = {env_id: {'time': 0., 'step': 0, 'train_sample': 0} for env_id in range(self._env_num)}
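
Note: the trajectory buffer is a dict of dicts keyed first by env, then by policy. A self-contained sketch of that shape, approximating `TrajBuffer`'s bounded behavior with `collections.deque` (an assumption for illustration only):

```python
from collections import deque

env_num, policy_num, traj_len = 2, 3, 4

# Same shape as self._traj_buffer: {env_id: {policy_id: bounded buffer}}.
traj_buffer = {
    env_id: {policy_id: deque(maxlen=traj_len)
             for policy_id in range(policy_num)}
    for env_id in range(env_num)
}
traj_buffer[0][2].append({'obs': 'fake-transition'})
print(len(traj_buffer), len(traj_buffer[0]), len(traj_buffer[0][2]))  # 2 3 1
```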
......@@ -221,9 +226,9 @@ class BattleSampleSerialCollector(ISerialCollector):
         )
         if policy_kwargs is None:
             policy_kwargs = {}
-        collected_sample = [0 for _ in range(2)]
-        return_data = [[] for _ in range(2)]
-        return_info = [[] for _ in range(2)]
+        collected_sample = [0 for _ in range(self._policy_num)]
+        return_data = [[] for _ in range(self._policy_num)]
+        return_info = [[] for _ in range(self._policy_num)]
         while any([c < n_sample for i, c in enumerate(collected_sample) if self._policy_collect_data[i]]):
             with self._timer:
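
Note: the `while` condition keeps collecting until every policy that actually collects data has reached `n_sample`; flagged-off policies are filtered out, so a non-collecting bot can never stall the loop. The condition checked in isolation:

```python
n_sample = 10
policy_collect_data = [True, False, True]  # middle policy is a non-collecting bot


def still_collecting(collected_sample):
    return any([c < n_sample for i, c in enumerate(collected_sample) if policy_collect_data[i]])


print(still_collecting([10, 0, 3]))   # True: policy 2 still needs samples
print(still_collecting([10, 0, 10]))  # False: the bot's count is ignored
```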
......@@ -281,12 +286,12 @@ class BattleSampleSerialCollector(ISerialCollector):
                 if timestep.done:
                     self._total_episode_count += 1
                     info = {
-                        'reward0': timestep.info[0]['final_eval_reward'],
-                        'reward1': timestep.info[1]['final_eval_reward'],
                         'time': self._env_info[env_id]['time'],
                         'step': self._env_info[env_id]['step'],
                         'train_sample': self._env_info[env_id]['train_sample'],
                     }
+                    for i in range(self._policy_num):
+                        info['reward{}'.format(i)] = timestep.info[i]['final_eval_reward']
                     self._episode_info.append(info)
                     for i, p in enumerate(self._policy):
                         p.reset([env_id])
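
Note: rather than hard-coding `reward0`/`reward1`, the episode record now grows one `reward{i}` key per policy. A toy reproduction with fabricated final rewards standing in for `timestep.info[i]['final_eval_reward']`:

```python
policy_num = 3
fake_final_rewards = [1.0, -1.0, 0.0]  # fabricated per-policy final rewards

info = {'time': 0.5, 'step': 200, 'train_sample': 16}
for i in range(policy_num):
    info['reward{}'.format(i)] = fake_final_rewards[i]
print(sorted(info))  # ['reward0', 'reward1', 'reward2', 'step', 'time', 'train_sample']
```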
......@@ -311,8 +316,10 @@ class BattleSampleSerialCollector(ISerialCollector):
         episode_count = len(self._episode_info)
         envstep_count = sum([d['step'] for d in self._episode_info])
         duration = sum([d['time'] for d in self._episode_info])
-        episode_reward0 = [d['reward0'] for d in self._episode_info]
-        episode_reward1 = [d['reward1'] for d in self._episode_info]
+        episode_reward = []
+        for i in range(self._policy_num):
+            episode_reward_item = [d['reward{}'.format(i)] for d in self._episode_info]
+            episode_reward.append(episode_reward_item)
         self._total_duration += duration
         info = {
             'episode_count': episode_count,
......@@ -321,18 +328,14 @@ class BattleSampleSerialCollector(ISerialCollector):
             'avg_envstep_per_sec': envstep_count / duration,
             'avg_episode_per_sec': episode_count / duration,
             'collect_time': duration,
-            'reward0_mean': np.mean(episode_reward0),
-            'reward0_std': np.std(episode_reward0),
-            'reward0_max': np.max(episode_reward0),
-            'reward0_min': np.min(episode_reward0),
-            'reward1_mean': np.mean(episode_reward1),
-            'reward1_std': np.std(episode_reward1),
-            'reward1_max': np.max(episode_reward1),
-            'reward1_min': np.min(episode_reward1),
             'total_envstep_count': self._total_envstep_count,
             'total_episode_count': self._total_episode_count,
             'total_duration': self._total_duration,
         }
+        for k, fn in {'mean': np.mean, 'std': np.std, 'max': np.max, 'min': np.min}.items():
+            for i in range(self._policy_num):
+                # such as reward0_mean
+                info['reward{}_{}'.format(i, k)] = fn(episode_reward[i])
         self._episode_info.clear()
         self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))
         for k, v in info.items():
......
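
Note: the eight hard-coded `rewardX_*` entries collapse into a dict of NumPy reducers crossed with policy indices, yielding 4 x policy_num keys. A compact check of the resulting key layout, with fabricated per-policy episode rewards:

```python
import numpy as np

policy_num = 3
episode_reward = [[1.0, 0.0], [-1.0, 0.0], [0.0, 0.0]]  # fabricated per-policy episode rewards

info = {}
for k, fn in {'mean': np.mean, 'std': np.std, 'max': np.max, 'min': np.min}.items():
    for i in range(policy_num):
        info['reward{}_{}'.format(i, k)] = fn(episode_reward[i])

print(len(info))             # 12 == 4 stats * 3 policies
print(info['reward0_mean'])  # 0.5
```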