From a490729ff7d18c7ef2507fe9c6f46ded9a140ec4 Mon Sep 17 00:00:00 2001 From: Xu Jingxin Date: Thu, 9 Dec 2021 16:28:24 +0800 Subject: [PATCH] feature(xjx): refactor buffer (#129) * Init base buffer and storage * Use ratelimit as middleware * Pass style check * Keep the return original return value * Add buffer.view * Add replace flag on sample, rewrite middleware processing * Test slicing * Add buffer copy middleware * Add update/delete api in buffer, rename middleware * Implement update and delete api of buffer * add naive use time count middleware in buffer * Rename next to chain * feature(nyz): add staleness check middleware and polish buffer * feature(nyz): add naive priority experience replay * Sample by indices * Combine buffer and storage layers * Support indices when deleting items from the queue * Use dataclass to save buffered data, remove return_index and return_meta * Add ignore_insufficient * polish(nyz): add return index in push and copy same data in sample * Drop useless import * Fix sample with indices, ensure return size is equal to input size or indices size * Make sure sampled data in buffer is different from each other * Support sample by grouped meta key * Support sample by rolling window * Add import/export data in buffer * Padding after sampling from buffer * Polish use_time_check * Use buffer as dataset * Set collate_fn in buffer test * feature(nyz): add deque buffer compatibility wrapper and demo * polish(nyz): polish code style and add pong dqn new deque buffer demo * feature(nyz): add use_time_count compatibility in wrapper * feature(nyz): add priority replay buffer compatibility in wrapper * Improve performance of buffer.update * polish(nyz): add priority max limit and correct flake8 * Use __call__ to rewrite middleware * Rewrite buffer index * Fix buffer delete * Skip first item * Rewrite buffer delete * Use caller * Use caller in priority * Add group sample Co-authored-by: niuyazhe --- ding/worker/__init__.py | 1 + ding/worker/buffer/__init__.py | 3 + ding/worker/buffer/buffer.py | 197 ++++++++++++ ding/worker/buffer/deque_buffer.py | 282 +++++++++++++++++ ding/worker/buffer/deque_buffer_wrapper.py | 110 +++++++ ding/worker/buffer/middleware/__init__.py | 6 + ding/worker/buffer/middleware/clone_object.py | 28 ++ ding/worker/buffer/middleware/group_sample.py | 37 +++ ding/worker/buffer/middleware/padding.py | 40 +++ ding/worker/buffer/middleware/priority.py | 132 ++++++++ .../buffer/middleware/staleness_check.py | 37 +++ .../buffer/middleware/use_time_check.py | 47 +++ ding/worker/buffer/tests/test_buffer.py | 292 ++++++++++++++++++ ding/worker/buffer/tests/test_middleware.py | 155 ++++++++++ ding/worker/buffer/utils/__init__.py | 1 + ding/worker/buffer/utils/fast_copy.py | 58 ++++ .../config/serial/pong/pong_dqn_config.py | 1 + .../config/lunarlander_dqn_deque_config.py | 78 +++++ .../entry/cartpole_dqn_buffer_main.py | 80 +++++ 19 files changed, 1585 insertions(+) create mode 100644 ding/worker/buffer/__init__.py create mode 100644 ding/worker/buffer/buffer.py create mode 100644 ding/worker/buffer/deque_buffer.py create mode 100644 ding/worker/buffer/deque_buffer_wrapper.py create mode 100644 ding/worker/buffer/middleware/__init__.py create mode 100644 ding/worker/buffer/middleware/clone_object.py create mode 100644 ding/worker/buffer/middleware/group_sample.py create mode 100644 ding/worker/buffer/middleware/padding.py create mode 100644 ding/worker/buffer/middleware/priority.py create mode 100644 ding/worker/buffer/middleware/staleness_check.py create mode 
100644 ding/worker/buffer/middleware/use_time_check.py create mode 100644 ding/worker/buffer/tests/test_buffer.py create mode 100644 ding/worker/buffer/tests/test_middleware.py create mode 100644 ding/worker/buffer/utils/__init__.py create mode 100644 ding/worker/buffer/utils/fast_copy.py create mode 100644 dizoo/box2d/lunarlander/config/lunarlander_dqn_deque_config.py create mode 100644 dizoo/classic_control/cartpole/entry/cartpole_dqn_buffer_main.py diff --git a/ding/worker/__init__.py b/ding/worker/__init__.py index f4c5b41..4a231e0 100644 --- a/ding/worker/__init__.py +++ b/ding/worker/__init__.py @@ -3,3 +3,4 @@ from .learner import * from .replay_buffer import * from .coordinator import * from .adapter import * +from .buffer import * diff --git a/ding/worker/buffer/__init__.py b/ding/worker/buffer/__init__.py new file mode 100644 index 0000000..d3cf7c0 --- /dev/null +++ b/ding/worker/buffer/__init__.py @@ -0,0 +1,3 @@ +from .buffer import Buffer, apply_middleware, BufferedData +from .deque_buffer import DequeBuffer +from .deque_buffer_wrapper import DequeBufferWrapper diff --git a/ding/worker/buffer/buffer.py b/ding/worker/buffer/buffer.py new file mode 100644 index 0000000..1b6db8c --- /dev/null +++ b/ding/worker/buffer/buffer.py @@ -0,0 +1,197 @@ +from abc import abstractmethod +from typing import Any, List, Optional, Union, Callable +import copy +from dataclasses import dataclass + + +def apply_middleware(func_name: str): + + def wrap_func(base_func: Callable): + + def handler(buffer, *args, **kwargs): + """ + Overview: + The real processing starts here, we apply the middleware one by one, + each middleware will receive next `chained` function, which is an executor of next + middleware. You can change the input arguments to the next `chained` middleware, and you + also can get the return value from the next middleware, so you have the + maximum freedom to choose at what stage to implement your method. + """ + + def wrap_handler(middleware, *args, **kwargs): + if len(middleware) == 0: + return base_func(buffer, *args, **kwargs) + + def chain(*args, **kwargs): + return wrap_handler(middleware[1:], *args, **kwargs) + + func = middleware[0] + return func(func_name, chain, *args, **kwargs) + + return wrap_handler(buffer.middleware, *args, **kwargs) + + return handler + + return wrap_func + + +@dataclass +class BufferedData: + data: Any + index: str + meta: dict + + +class Buffer: + """ + Buffer is an abstraction of device storage, third-party services or data structures, + For example, memory queue, sum-tree, redis, or di-store. + """ + + def __init__(self) -> None: + self.middleware = [] + + @abstractmethod + def push(self, data: Any, meta: Optional[dict] = None) -> BufferedData: + """ + Overview: + Push data and it's meta information in buffer. + Arguments: + - data (:obj:`Any`): The data which will be pushed into buffer. + - meta (:obj:`dict`): Meta information, e.g. priority, count, staleness. + Returns: + - buffered_data (:obj:`BufferedData`): The pushed data. + """ + raise NotImplementedError + + @abstractmethod + def sample( + self, + size: Optional[int] = None, + indices: Optional[List[str]] = None, + replace: bool = False, + sample_range: Optional[slice] = None, + ignore_insufficient: bool = False, + groupby: str = None, + rolling_window: int = None + ) -> Union[List[BufferedData], List[List[BufferedData]]]: + """ + Overview: + Sample data with length ``size``. + Arguments: + - size (:obj:`Optional[int]`): The number of the data that will be sampled. 
+ - indices (:obj:`Optional[List[str]]`): Sample with multiple indices. + - replace (:obj:`bool`): If use replace is true, you may receive duplicated data from the buffer. + - sample_range (:obj:`slice`): Sample range slice. + - ignore_insufficient (:obj:`bool`): If ignore_insufficient is true, sampling more than buffer size + with no repetition will not cause an exception. + - groupby (:obj:`str`): Groupby key in meta. + - rolling_window (:obj:`int`): Return batches of window size. + Returns: + - sample_data (:obj:`Union[List[BufferedData], List[List[BufferedData]]]`): + A list of data with length ``size``, may be nested if groupby or rolling_window is set. + """ + raise NotImplementedError + + @abstractmethod + def update(self, index: str, data: Optional[Any] = None, meta: Optional[dict] = None) -> bool: + """ + Overview: + Update data and meta by index + Arguments: + - index (:obj:`str`): Index of data. + - data (:obj:`any`): Pure data. + - meta (:obj:`dict`): Meta information. + Returns: + - success (:obj:`bool`): Success or not, if data with the index not exist in buffer, return false. + """ + raise NotImplementedError + + @abstractmethod + def batch_update( + self, + indices: List[str], + datas: Optional[List[Optional[Any]]] = None, + metas: Optional[List[Optional[dict]]] = None + ) -> None: + """ + Overview: + Batch update data and meta by indices, maybe useful in some data architectures. + Arguments: + - indices (:obj:`List[str]`): Index of data. + - datas (:obj:`Optional[List[Optional[Any]]]`): Pure data. + - metas (:obj:`Optional[List[Optional[dict]]]`): Meta information. + """ + raise NotImplementedError + + @abstractmethod + def delete(self, index: str): + """ + Overview: + Delete one data sample by index + Arguments: + - index (:obj:`str`): Index + """ + raise NotImplementedError + + @abstractmethod + def count(self) -> int: + raise NotImplementedError + + @abstractmethod + def clear(self) -> None: + raise NotImplementedError + + @abstractmethod + def get(self, idx: int) -> BufferedData: + """ + Overview: + Get item by subscript index + Arguments: + - idx (:obj:`int`): Subscript index + Returns: + - buffered_data (:obj:`BufferedData`): Item from buffer + """ + raise NotImplementedError + + def use(self, func: Callable) -> "Buffer": + r""" + Overview: + Use algorithm middleware to modify the behavior of the buffer. + Every middleware should be a callable function, it will receive three argument parts, including: + 1. The buffer instance, you can use this instance to visit every thing of the buffer, + including the storage. + 2. The functions called by the user, there are three methods named `push`, `sample` and `clear`, + so you can use these function name to decide which action to choose. + 3. The remaining arguments passed by the user to the original function, will be passed in *args. + + Each middleware handler should return two parts of the value, including: + 1. The first value is `done` (True or False), if done==True, the middleware chain will stop immediately, + no more middleware will be executed during this execution + 2. The remaining values, will be passed to the next middleware or the default function in the buffer. + Arguments: + - func (:obj:`Callable`): The middleware handler + Returns: + - buffer (:obj:`Buffer`): The instance self + """ + self.middleware.append(func) + return self + + def view(self) -> "Buffer": + r""" + Overview: + A view is a new instance of buffer, with a deepcopy of every property except the storage. 
+ The storage is shared among all the buffer instances. + Returns: + - buffer (:obj:`Buffer`): The instance self + """ + return copy.copy(self) + + def __copy__(self) -> "Buffer": + raise NotImplementedError + + def __len__(self) -> int: + return self.count() + + def __getitem__(self, idx: int) -> BufferedData: + return self.get(idx) diff --git a/ding/worker/buffer/deque_buffer.py b/ding/worker/buffer/deque_buffer.py new file mode 100644 index 0000000..dcb34cf --- /dev/null +++ b/ding/worker/buffer/deque_buffer.py @@ -0,0 +1,282 @@ +from typing import Any, Iterable, List, Optional, Tuple, Union +from collections import defaultdict, deque, OrderedDict + +from ding.worker.buffer import Buffer, apply_middleware, BufferedData +from ding.worker.buffer.utils import fastcopy +import itertools +import random +import uuid +import logging + + +class BufferIndex(): + """ + Overview: + Save index string and offset in key value pair. + """ + + def __init__(self, maxlen: int, *args, **kwargs): + self.maxlen = maxlen + self.__map = OrderedDict(*args, **kwargs) + self._last_key = next(reversed(self.__map)) if len(self) > 0 else None + self._cumlen = len(self.__map) + + def get(self, key: str) -> int: + value = self.__map[key] + value = value % self._cumlen + min(0, (self.maxlen - self._cumlen)) + return value + + def __len__(self) -> int: + return len(self.__map) + + def has(self, key: str) -> bool: + return key in self.__map + + def append(self, key: str): + self.__map[key] = self.__map[self._last_key] + 1 if self._last_key else 0 + self._last_key = key + self._cumlen += 1 + if len(self) > self.maxlen: + self.__map.popitem(last=False) + + def clear(self): + self.__map = OrderedDict() + self._last_key = None + self._cumlen = 0 + + +class DequeBuffer(Buffer): + + def __init__(self, size: int) -> None: + super().__init__() + self.storage = deque(maxlen=size) + # Meta index is a dict which use deque as values + self.indices = BufferIndex(maxlen=size) + self.meta_index = {} + + @apply_middleware("push") + def push(self, data: Any, meta: Optional[dict] = None) -> BufferedData: + return self._push(data, meta) + + @apply_middleware("sample") + def sample( + self, + size: Optional[int] = None, + indices: Optional[List[str]] = None, + replace: bool = False, + sample_range: Optional[slice] = None, + ignore_insufficient: bool = False, + groupby: str = None, + rolling_window: int = None + ) -> Union[List[BufferedData], List[List[BufferedData]]]: + storage = self.storage + if sample_range: + storage = list(itertools.islice(self.storage, sample_range.start, sample_range.stop, sample_range.step)) + + # Size and indices + assert size or indices, "One of size and indices must not be empty." + if (size and indices) and (size != len(indices)): + raise AssertionError("Size and indices length must be equal.") + if not size: + size = len(indices) + # Indices and groupby + assert not (indices and groupby), "Cannot use groupby and indicex at the same time." + # Groupby and rolling_window + assert not (groupby and rolling_window), "Cannot use groupby and rolling_window at the same time." + assert not (indices and rolling_window), "Cannot use indices and rolling_window at the same time." 
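+        # The asserts above make the sampling modes mutually exclusive; the branches below dispatch on them:
+        # explicit `indices` returns records in the requested order, `groupby` and `rolling_window` return
+        # nested lists of records, and the default branch draws a uniform random sample (with or without replacement).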
+ + value_error = None + sampled_data = [] + if indices: + indices_set = set(indices) + hashed_data = filter(lambda item: item.index in indices_set, storage) + hashed_data = map(lambda item: (item.index, item), hashed_data) + hashed_data = dict(hashed_data) + # Re-sample and return in indices order + sampled_data = [hashed_data[index] for index in indices] + elif groupby: + sampled_data = self._sample_by_group(size=size, groupby=groupby, replace=replace, storage=storage) + elif rolling_window: + sampled_data = self._sample_by_rolling_window( + size=size, replace=replace, rolling_window=rolling_window, storage=storage + ) + else: + if replace: + sampled_data = random.choices(storage, k=size) + else: + try: + sampled_data = random.sample(storage, k=size) + except ValueError as e: + value_error = e + + if value_error or len(sampled_data) != size: + if ignore_insufficient: + logging.warning( + "Sample operation is ignored due to data insufficient, current buffer is {} while sample is {}". + format(self.count(), size) + ) + else: + raise ValueError("There are less than {} records/groups in buffer({})".format(size, self.count())) + + sampled_data = self._independence(sampled_data) + + return sampled_data + + @apply_middleware("update") + def update(self, index: str, data: Optional[Any] = None, meta: Optional[dict] = None) -> bool: + if not self.indices.has(index): + return False + i = self.indices.get(index) + item = self.storage[i] + if data is not None: + item.data = data + if meta is not None: + item.meta = meta + for key in self.meta_index: + self.meta_index[key][i] = meta[key] if key in meta else None + return True + + @apply_middleware("delete") + def delete(self, indices: Union[str, Iterable[str]]) -> None: + if isinstance(indices, str): + indices = [indices] + del_idx = [] + for index in indices: + if self.indices.has(index): + del_idx.append(self.indices.get(index)) + if len(del_idx) == 0: + return + del_idx = sorted(del_idx, reverse=True) + for idx in del_idx: + del self.storage[idx] + remain_indices = [item.index for item in self.storage] + key_value_pairs = zip(remain_indices, range(len(indices))) + self.indices = BufferIndex(self.storage.maxlen, key_value_pairs) + + def count(self) -> int: + return len(self.storage) + + def get(self, idx: int) -> BufferedData: + return self.storage[idx] + + @apply_middleware("clear") + def clear(self) -> None: + self.storage.clear() + self.indices.clear() + self.meta_index = {} + + def import_data(self, data_with_meta: List[Tuple[Any, dict]]) -> None: + for data, meta in data_with_meta: + self._push(data, meta) + + def export_data(self) -> List[BufferedData]: + return list(self.storage) + + def _push(self, data: Any, meta: Optional[dict] = None) -> BufferedData: + index = uuid.uuid1().hex + if meta is None: + meta = {} + buffered = BufferedData(data=data, index=index, meta=meta) + self.storage.append(buffered) + self.indices.append(index) + # Add meta index + for key in self.meta_index: + self.meta_index[key].append(meta[key] if key in meta else None) + + return buffered + + def _independence( + self, buffered_samples: Union[List[BufferedData], List[List[BufferedData]]] + ) -> Union[List[BufferedData], List[List[BufferedData]]]: + """ + Overview: + Make sure that each record is different from each other, but remember that this function + is different from clone_object. You may change the data in the buffer by modifying a record. 
+ Arguments: + - buffered_samples (:obj:`Union[List[BufferedData], List[List[BufferedData]]]`) Sampled data, + can be nested if groupby or rolling_window has been set. + """ + if len(buffered_samples) == 0: + return buffered_samples + occurred = defaultdict(int) + + for i, buffered in enumerate(buffered_samples): + if isinstance(buffered, list): + sampled_list = buffered + # Loop over nested samples + for j, buffered in enumerate(sampled_list): + occurred[buffered.index] += 1 + if occurred[buffered.index] > 1: + sampled_list[j] = fastcopy.copy(buffered) + elif isinstance(buffered, BufferedData): + occurred[buffered.index] += 1 + if occurred[buffered.index] > 1: + buffered_samples[i] = fastcopy.copy(buffered) + else: + raise Exception("Get unexpected buffered type {}".format(type(buffered))) + return buffered_samples + + def _sample_by_group(self, + size: int, + groupby: str, + replace: bool = False, + storage: deque = None) -> List[List[BufferedData]]: + """ + Overview: + Sampling by `group` instead of records, the result will be a collection + of lists with a length of `size`, but the length of each list may be different from other lists. + """ + if storage is None: + storage = self.storage + if groupby not in self.meta_index: + self._create_index(groupby) + meta_indices = list(set(self.meta_index[groupby])) + sampled_groups = [] + if replace: + sampled_groups = random.choices(meta_indices, k=size) + else: + try: + sampled_groups = random.sample(meta_indices, k=size) + except ValueError: + raise ValueError("There are less than {} groups in buffer({} groups)".format(size, len(meta_indices))) + sampled_data = defaultdict(list) + for buffered in storage: + meta_value = buffered.meta[groupby] if groupby in buffered.meta else None + if meta_value in sampled_groups: + sampled_data[buffered.meta[groupby]].append(buffered) + return list(sampled_data.values()) + + def _sample_by_rolling_window( + self, + size: Optional[int] = None, + replace: bool = False, + rolling_window: int = None, + storage: deque = None + ) -> List[List[BufferedData]]: + if storage is None: + storage = self.storage + if replace: + sampled_indices = random.choices(range(len(storage)), k=size) + else: + try: + sampled_indices = random.sample(range(len(storage)), k=size) + except ValueError as e: + pass + sampled_data = [] + for idx in sampled_indices: + slice_ = list(itertools.islice(storage, idx, idx + rolling_window)) + sampled_data.append(slice_) + return sampled_data + + def _create_index(self, meta_key: str): + self.meta_index[meta_key] = deque(maxlen=self.storage.maxlen) + for data in self.storage: + self.meta_index[meta_key].append(data.meta[meta_key] if meta_key in data.meta else None) + + def __iter__(self) -> deque: + return iter(self.storage) + + def __copy__(self) -> "DequeBuffer": + buffer = type(self)(size=self.storage.maxlen) + buffer.storage = self.storage + return buffer diff --git a/ding/worker/buffer/deque_buffer_wrapper.py b/ding/worker/buffer/deque_buffer_wrapper.py new file mode 100644 index 0000000..3e35afc --- /dev/null +++ b/ding/worker/buffer/deque_buffer_wrapper.py @@ -0,0 +1,110 @@ +from typing import Optional +import copy +from easydict import EasyDict +import numpy as np + +from ding.worker.buffer import DequeBuffer +from ding.worker.buffer.middleware import use_time_check, PriorityExperienceReplay +from ding.utils import BUFFER_REGISTRY + + +@BUFFER_REGISTRY.register('deque') +class DequeBufferWrapper(object): + + @classmethod + def default_config(cls: type) -> EasyDict: + cfg = 
EasyDict(copy.deepcopy(cls.config)) + cfg.cfg_type = cls.__name__ + 'Dict' + return cfg + + config = dict( + replay_buffer_size=10000, + max_use=float("inf"), + train_iter_per_log=100, + priority=False, + priority_IS_weight=False, + priority_power_factor=0.6, + IS_weight_power_factor=0.4, + IS_weight_anneal_train_iter=int(1e5), + priority_max_limit=1000, + ) + + def __init__( + self, + cfg: EasyDict, + tb_logger: Optional[object] = None, + exp_name: str = 'default_experiement', + instance_name: str = 'buffer' + ) -> None: + self.cfg = cfg + self.priority_max_limit = cfg.priority_max_limit + self.name = '{}_iter'.format(instance_name) + self.tb_logger = tb_logger + self.buffer = DequeBuffer(size=cfg.replay_buffer_size) + self.last_log_train_iter = -1 + + # use_count middleware + if self.cfg.max_use != float("inf"): + self.buffer.use(use_time_check(self.buffer, max_use=self.cfg.max_use)) + # priority middleware + if self.cfg.priority: + self.buffer.use( + PriorityExperienceReplay( + self.buffer, + self.cfg.replay_buffer_size, + IS_weight=self.cfg.priority_IS_weight, + priority_power_factor=self.cfg.priority_power_factor, + IS_weight_power_factor=self.cfg.IS_weight_power_factor, + IS_weight_anneal_train_iter=self.cfg.IS_weight_anneal_train_iter + ) + ) + self.last_sample_index = None + self.last_sample_meta = None + + def sample(self, size: int, train_iter: int): + output = self.buffer.sample(size=size, ignore_insufficient=True) + if len(output) > 0: + if self.last_log_train_iter == -1 or train_iter - self.last_log_train_iter >= self.cfg.train_iter_per_log: + meta = [o.meta for o in output] + if self.cfg.max_use != float("inf"): + use_count_avg = np.mean([m['use_count'] for m in meta]) + self.tb_logger.add_scalar('{}/use_count_avg'.format(self.name), use_count_avg, train_iter) + if self.cfg.priority: + self.last_sample_index = [o.index for o in output] + self.last_sample_meta = meta + priority_list = [m['priority'] for m in meta] + priority_avg = np.mean(priority_list) + priority_max = np.max(priority_list) + self.tb_logger.add_scalar('{}/priority_avg'.format(self.name), priority_avg, train_iter) + self.tb_logger.add_scalar('{}/priority_max'.format(self.name), priority_max, train_iter) + self.tb_logger.add_scalar('{}/buffer_data_count'.format(self.name), self.buffer.count(), train_iter) + + data = [o.data for o in output] + if self.cfg.priority_IS_weight: + IS = [o.meta['priority_IS'] for o in output] + for i in range(len(data)): + data[i]['IS'] = IS[i] + return data + else: + return None + + def push(self, data, cur_collector_envstep: int = -1) -> None: + for d in data: + meta = {} + if self.cfg.priority and 'priority' in d: + init_priority = d.pop('priority') + meta['priority'] = init_priority + self.buffer.push(d, meta=meta) + + def update(self, meta: dict) -> None: + if not self.cfg.priority: + return + if self.last_sample_index is None: + return + new_meta = self.last_sample_meta + for m, p in zip(new_meta, meta['priority']): + m['priority'] = min(self.priority_max_limit, p) + for idx, m in zip(self.last_sample_index, new_meta): + self.buffer.update(idx, data=None, meta=m) + self.last_sample_index = None + self.last_sample_meta = None diff --git a/ding/worker/buffer/middleware/__init__.py b/ding/worker/buffer/middleware/__init__.py new file mode 100644 index 0000000..056e18a --- /dev/null +++ b/ding/worker/buffer/middleware/__init__.py @@ -0,0 +1,6 @@ +from .clone_object import clone_object +from .use_time_check import use_time_check +from .staleness_check import staleness_check +from 
.priority import PriorityExperienceReplay +from .padding import padding +from .group_sample import group_sample diff --git a/ding/worker/buffer/middleware/clone_object.py b/ding/worker/buffer/middleware/clone_object.py new file mode 100644 index 0000000..9f74842 --- /dev/null +++ b/ding/worker/buffer/middleware/clone_object.py @@ -0,0 +1,28 @@ +from typing import Callable, Any, List, Union +from ding.worker.buffer import BufferedData +from ding.worker.buffer.utils import fastcopy + + +def clone_object(): + """ + This middleware freezes the objects saved in memory buffer as a copy, + try this middleware when you need to keep the object unchanged in buffer, and modify + the object after sampling it (usuallly in multiple threads) + """ + + def push(chain: Callable, data: Any, *args, **kwargs) -> BufferedData: + data = fastcopy.copy(data) + return chain(data, *args, **kwargs) + + def sample(chain: Callable, *args, **kwargs) -> Union[List[BufferedData], List[List[BufferedData]]]: + data = chain(*args, **kwargs) + return fastcopy.copy(data) + + def _clone_object(action: str, chain: Callable, *args, **kwargs): + if action == "push": + return push(chain, *args, **kwargs) + elif action == "sample": + return sample(chain, *args, **kwargs) + return chain(*args, **kwargs) + + return _clone_object diff --git a/ding/worker/buffer/middleware/group_sample.py b/ding/worker/buffer/middleware/group_sample.py new file mode 100644 index 0000000..6956cfc --- /dev/null +++ b/ding/worker/buffer/middleware/group_sample.py @@ -0,0 +1,37 @@ +import random +from typing import Callable, List +from ding.worker.buffer.buffer import BufferedData + + +def group_sample(size_in_group: int, ordered_in_group: bool = True, max_use_in_group: bool = True) -> Callable: + """ + Overview: + The middleware is designed to process the data in each group after sampling from the buffer. + Arguments: + - size_in_group (:obj:`int`): Sample size in each group. + - ordered_in_group (:obj:`bool`): Whether to keep the original order of records, default is true. + - max_use_in_group (:obj:`bool`): Whether to use as much data in each group as possible, default is true. + """ + + def sample(chain: Callable, *args, **kwargs) -> List[List[BufferedData]]: + if not kwargs.get("groupby"): + raise Exception("Group sample must be used when the `groupby` parameter is specified.") + sampled_data = chain(*args, **kwargs) + for i, grouped_data in enumerate(sampled_data): + if ordered_in_group: + if max_use_in_group: + end = max(0, len(grouped_data) - size_in_group) + 1 + else: + end = len(grouped_data) + start_idx = random.choice(range(end)) + sampled_data[i] = grouped_data[start_idx:start_idx + size_in_group] + else: + sampled_data[i] = random.sample(grouped_data, k=size_in_group) + return sampled_data + + def _group_sample(action: str, chain: Callable, *args, **kwargs): + if action == "sample": + return sample(chain, *args, **kwargs) + return chain(*args, **kwargs) + + return _group_sample diff --git a/ding/worker/buffer/middleware/padding.py b/ding/worker/buffer/middleware/padding.py new file mode 100644 index 0000000..8e0dbf2 --- /dev/null +++ b/ding/worker/buffer/middleware/padding.py @@ -0,0 +1,40 @@ +import random +from typing import Callable, Union, List + +from ding.worker.buffer import BufferedData +from ding.worker.buffer.utils import fastcopy + + +def padding(policy="random"): + """ + Overview: + Fill the nested buffer list to the same size as the largest list. 
+ The default policy `random` will randomly select data from each group + and fill it into the current group list. + Arguments: + - policy (:obj:`str`): Padding policy, supports `random`, `none`. + """ + + def sample(chain: Callable, *args, **kwargs) -> Union[List[BufferedData], List[List[BufferedData]]]: + sampled_data = chain(*args, **kwargs) + if len(sampled_data) == 0 or isinstance(sampled_data[0], BufferedData): + return sampled_data + max_len = len(max(sampled_data, key=len)) + for i, grouped_data in enumerate(sampled_data): + group_len = len(grouped_data) + if group_len == max_len: + continue + for _ in range(max_len - group_len): + if policy == "random": + sampled_data[i].append(fastcopy.copy(random.choice(grouped_data))) + elif policy == "none": + sampled_data[i].append(BufferedData(data=None, index=None, meta=None)) + + return sampled_data + + def _padding(action: str, chain: Callable, *args, **kwargs): + if action == "sample": + return sample(chain, *args, **kwargs) + return chain(*args, **kwargs) + + return _padding diff --git a/ding/worker/buffer/middleware/priority.py b/ding/worker/buffer/middleware/priority.py new file mode 100644 index 0000000..ef1dadb --- /dev/null +++ b/ding/worker/buffer/middleware/priority.py @@ -0,0 +1,132 @@ +from typing import Callable, Any, List, Dict, Optional, Union +import copy +import numpy as np +from ding.utils import SumSegmentTree, MinSegmentTree +from ding.worker.buffer.buffer import BufferedData + + +class PriorityExperienceReplay: + + def __init__( + self, + buffer: 'Buffer', # noqa + buffer_size: int, + IS_weight: bool = True, + priority_power_factor: float = 0.6, + IS_weight_power_factor: float = 0.4, + IS_weight_anneal_train_iter: int = int(1e5), + ) -> None: + self.buffer = buffer + self.buffer_idx = {} + self.buffer_size = buffer_size + self.IS_weight = IS_weight + self.priority_power_factor = priority_power_factor + self.IS_weight_power_factor = IS_weight_power_factor + self.IS_weight_anneal_train_iter = IS_weight_anneal_train_iter + + # Max priority till now, it's used to initizalize data's priority if "priority" is not passed in with the data. + self.max_priority = 1.0 + # Capacity needs to be the power of 2. + capacity = int(np.power(2, np.ceil(np.log2(self.buffer_size)))) + self.sum_tree = SumSegmentTree(capacity) + if self.IS_weight: + self.min_tree = MinSegmentTree(capacity) + self.delta_anneal = (1 - self.IS_weight_power_factor) / self.IS_weight_anneal_train_iter + self.pivot = 0 + + def push(self, chain: Callable, data: Any, meta: Optional[dict] = None, *args, **kwargs) -> BufferedData: + if meta is None: + meta = {'priority': self.max_priority} + else: + if 'priority' not in meta: + meta['priority'] = self.max_priority + meta['priority_idx'] = self.pivot + self._update_tree(meta['priority'], self.pivot) + buffered = chain(data, meta=meta, *args, **kwargs) + index = buffered.index + self.buffer_idx[self.pivot] = index + self.pivot = (self.pivot + 1) % self.buffer_size + return buffered + + def sample(self, chain: Callable, size: int, *args, + **kwargs) -> Union[List[BufferedData], List[List[BufferedData]]]: + # Divide [0, 1) into size intervals on average + intervals = np.array([i * 1.0 / size for i in range(size)]) + # Uniformly sample within each interval + mass = intervals + np.random.uniform(size=(size, )) * 1. 
/ size
+        # Rescale to [0, S), where S is the sum of all data's priorities (root value of the sum tree)
+        mass *= self.sum_tree.reduce()
+        indices = [self.sum_tree.find_prefixsum_idx(m) for m in mass]
+        indices = [self.buffer_idx[i] for i in indices]
+        # Sample with indices
+        data = chain(indices=indices, *args, **kwargs)
+        if self.IS_weight:
+            # Calculate max weight for normalizing IS
+            sum_tree_root = self.sum_tree.reduce()
+            p_min = self.min_tree.reduce() / sum_tree_root
+            buffer_count = self.buffer.count()
+            max_weight = (buffer_count * p_min) ** (-self.IS_weight_power_factor)
+            for i in range(len(data)):
+                meta = data[i].meta
+                priority_idx = meta['priority_idx']
+                p_sample = self.sum_tree[priority_idx] / sum_tree_root
+                weight = (buffer_count * p_sample) ** (-self.IS_weight_power_factor)
+                meta['priority_IS'] = weight / max_weight
+            self.IS_weight_power_factor = min(1.0, self.IS_weight_power_factor + self.delta_anneal)
+        return data
+
+    def update(self, chain: Callable, index: str, data: Any, meta: Any, *args, **kwargs) -> None:
+        update_flag = chain(index, data, meta, *args, **kwargs)
+        if update_flag:  # when the update succeeds
+            assert meta is not None, "Please provide a dict-type meta for priority update"
+            new_priority, idx = meta['priority'], meta['priority_idx']
+            assert new_priority >= 0, "new_priority should be greater than 0, but found {}".format(new_priority)
+            new_priority += 1e-5  # Add epsilon to avoid priority == 0
+            self._update_tree(new_priority, idx)
+            self.max_priority = max(self.max_priority, new_priority)
+
+    def delete(self, chain: Callable, index: str, *args, **kwargs) -> None:
+        for item in self.buffer.storage:
+            meta = item.meta
+            priority_idx = meta['priority_idx']
+            self.sum_tree[priority_idx] = self.sum_tree.neutral_element
+            self.min_tree[priority_idx] = self.min_tree.neutral_element
+            self.buffer_idx.pop(priority_idx)
+        return chain(index, *args, **kwargs)
+
+    def clear(self, chain: Callable) -> None:
+        self.max_priority = 1.0
+        capacity = int(np.power(2, np.ceil(np.log2(self.buffer_size))))
+        self.sum_tree = SumSegmentTree(capacity)
+        if self.IS_weight:
+            self.min_tree = MinSegmentTree(capacity)
+        self.buffer_idx = {}
+        self.pivot = 0
+        chain()
+
+    def _update_tree(self, priority: float, idx: int) -> None:
+        weight = priority ** self.priority_power_factor
+        self.sum_tree[idx] = weight
+        if self.IS_weight:
+            self.min_tree[idx] = weight
+
+    def state_dict(self) -> Dict:
+        return {
+            'max_priority': self.max_priority,
+            'IS_weight_power_factor': self.IS_weight_power_factor,
+            'sum_tree': self.sum_tree,
+            'min_tree': self.min_tree,
+            'buffer_idx': self.buffer_idx,
+        }
+
+    def load_state_dict(self, _state_dict: Dict, deepcopy: bool = False) -> None:
+        for k, v in _state_dict.items():
+            if deepcopy:
+                setattr(self, '{}'.format(k), copy.deepcopy(v))
+            else:
+                setattr(self, '{}'.format(k), v)
+
+    def __call__(self, action: str, chain: Callable, *args, **kwargs) -> Any:
+        if action in ["push", "sample", "update", "delete", "clear"]:
+            return getattr(self, action)(chain, *args, **kwargs)
+        return chain(*args, **kwargs)
diff --git a/ding/worker/buffer/middleware/staleness_check.py b/ding/worker/buffer/middleware/staleness_check.py
new file mode 100644
index 0000000..1bad6a4
--- /dev/null
+++ b/ding/worker/buffer/middleware/staleness_check.py
@@ -0,0 +1,37 @@
+from typing import Callable, Any, List
+
+
+def staleness_check(buffer_: 'Buffer', max_staleness: int = float("inf")) -> Callable:  # noqa
+    """
+    Overview:
+        This middleware aims to check staleness before each sample operation,
staleness = train_iter_sample_data - train_iter_data_collected, means how old/off-policy the data is, + If data's staleness is greater(>) than max_staleness, this data will be removed from buffer as soon as possible. + """ + + def push(next: Callable, data: Any, *args, **kwargs) -> Any: + assert 'meta' in kwargs and 'train_iter_data_collected' in kwargs[ + 'meta'], "staleness_check middleware must push data with meta={'train_iter_data_collected': }" + return next(data, *args, **kwargs) + + def sample(next: Callable, train_iter_sample_data: int, *args, **kwargs) -> List[Any]: + delete_index = [] + for i, item in enumerate(buffer_.storage): + index, meta = item.index, item.meta + staleness = train_iter_sample_data - meta['train_iter_data_collected'] + meta['staleness'] = staleness + if staleness > max_staleness: + delete_index.append(index) + for index in delete_index: + buffer_.delete(index) + data = next(*args, **kwargs) + return data + + def _staleness_check(action: str, next: Callable, *args, **kwargs) -> Any: + if action == "push": + return push(next, *args, **kwargs) + elif action == "sample": + return sample(next, *args, **kwargs) + return next(*args, **kwargs) + + return _staleness_check diff --git a/ding/worker/buffer/middleware/use_time_check.py b/ding/worker/buffer/middleware/use_time_check.py new file mode 100644 index 0000000..871137a --- /dev/null +++ b/ding/worker/buffer/middleware/use_time_check.py @@ -0,0 +1,47 @@ +from collections import defaultdict +from typing import Callable, Any, List, Optional, Union +from ding.worker.buffer import BufferedData + + +def use_time_check(buffer_: 'Buffer', max_use: int = float("inf")) -> Callable: # noqa + """ + Overview: + This middleware aims to check the usage times of data in buffer. If the usage times of a data is + greater than or equal to max_use, this data will be removed from buffer as soon as possible. 
+ """ + use_count = defaultdict(int) + + def _need_delete(item: BufferedData) -> bool: + nonlocal use_count + idx = item.index + use_count[idx] += 1 + item.meta['use_count'] = use_count[idx] + if use_count[idx] >= max_use: + return True + else: + return False + + def _check_use_count(sampled_data: List[BufferedData]): + delete_indices = [item.index for item in filter(_need_delete, sampled_data)] + buffer_.delete(delete_indices) + for index in delete_indices: + del use_count[index] + + def sample(chain: Callable, *args, **kwargs) -> Union[List[BufferedData], List[List[BufferedData]]]: + sampled_data = chain(*args, **kwargs) + if len(sampled_data) == 0: + return sampled_data + + if isinstance(sampled_data[0], BufferedData): + _check_use_count(sampled_data) + else: + for grouped_data in sampled_data: + _check_use_count(grouped_data) + return sampled_data + + def _use_time_check(action: str, chain: Callable, *args, **kwargs) -> Any: + if action == "sample": + return sample(chain, *args, **kwargs) + return chain(*args, **kwargs) + + return _use_time_check diff --git a/ding/worker/buffer/tests/test_buffer.py b/ding/worker/buffer/tests/test_buffer.py new file mode 100644 index 0000000..da001d3 --- /dev/null +++ b/ding/worker/buffer/tests/test_buffer.py @@ -0,0 +1,292 @@ +import pytest +import time +import random +from typing import Callable +from ding.worker.buffer import DequeBuffer +from ding.worker.buffer.buffer import BufferedData +from torch.utils.data import DataLoader + + +class RateLimit: + r""" + Add rate limit threshold to push function + """ + + def __init__(self, max_rate: int = float("inf"), window_seconds: int = 30) -> None: + self.max_rate = max_rate + self.window_seconds = window_seconds + self.buffered = [] + + def __call__(self, action: str, chain: Callable, *args, **kwargs): + if action == "push": + return self.push(chain, *args, **kwargs) + return chain(*args, **kwargs) + + def push(self, chain, data, *args, **kwargs) -> None: + current = time.time() + # Cut off stale records + self.buffered = [t for t in self.buffered if t > current - self.window_seconds] + if len(self.buffered) < self.max_rate: + self.buffered.append(current) + return chain(data, *args, **kwargs) + else: + return None + + +def add_10() -> Callable: + """ + Transform data on sampling + """ + + def sample(chain: Callable, size: int, replace: bool = False, *args, **kwargs): + sampled_data = chain(size, replace, *args, **kwargs) + return [BufferedData(data=item.data + 10, index=item.index, meta=item.meta) for item in sampled_data] + + def _subview(action: str, chain: Callable, *args, **kwargs): + if action == "sample": + return sample(chain, *args, **kwargs) + return chain(*args, **kwargs) + + return _subview + + +@pytest.mark.unittest +def test_naive_push_sample(): + # Push and sample + buffer = DequeBuffer(size=10) + for i in range(20): + buffer.push(i) + assert buffer.count() == 10 + assert 0 not in [item.data for item in buffer.sample(10)] + + # Clear + buffer.clear() + assert buffer.count() == 0 + + # Test replace sample + for i in range(5): + buffer.push(i) + assert buffer.count() == 5 + assert len(buffer.sample(10, replace=True)) == 10 + + # Test slicing + buffer.clear() + for i in range(10): + buffer.push(i) + assert len(buffer.sample(5, sample_range=slice(5, 10))) == 5 + assert 0 not in [item.data for item in buffer.sample(5, sample_range=slice(5, 10))] + + +@pytest.mark.unittest +def test_rate_limit_push_sample(): + buffer = DequeBuffer(size=10).use(RateLimit(max_rate=5)) + for i in range(10): + 
buffer.push(i) + assert buffer.count() == 5 + assert 5 not in buffer.sample(5) + + +@pytest.mark.unittest +def test_buffer_view(): + buf1 = DequeBuffer(size=10) + for i in range(1): + buf1.push(i) + assert buf1.count() == 1 + + buf2 = buf1.view().use(RateLimit(max_rate=5)).use(add_10()) + + for i in range(10): + buf2.push(i) + # With 1 record written by buf1 and 5 records written by buf2 + assert len(buf1.middleware) == 0 + assert buf1.count() == 6 + # All data in buffer should bigger than 10 because of `add_10` + assert all(d.data >= 10 for d in buf2.sample(5)) + # But data in storage is still less than 10 + assert all(d.data < 10 for d in buf1.sample(5)) + + +@pytest.mark.unittest +def test_sample_with_index(): + buf = DequeBuffer(size=10) + for i in range(10): + buf.push({"data": i}, {"meta": i}) + # Random sample and get indices + indices = [item.index for item in buf.sample(10)] + assert len(indices) == 10 + random.shuffle(indices) + indices = indices[:5] + + # Resample by indices + new_indices = [item.index for item in buf.sample(indices=indices)] + assert len(new_indices) == len(indices) + for index in new_indices: + assert index in indices + + +@pytest.mark.unittest +def test_update(): + buf = DequeBuffer(size=10) + for i in range(1): + buf.push({"data": i}, {"meta": i}) + + # Update one data + [item] = buf.sample(1) + item.data["new_prop"] = "any" + meta = None + success = buf.update(item.index, item.data, item.meta) + assert success + # Resample + [item] = buf.sample(1) + assert "new_prop" in item.data + assert meta is None + # Update object that not exists in buffer + success = buf.update("invalidindex", {}, None) + assert not success + + # When exceed buffer size + for i in range(20): + buf.push({"data": i}) + assert len(buf.indices) == 10 + assert len(buf.storage) == 10 + for i in range(10): + index = buf.storage[i].index + assert buf.indices.get(index) == i + + +@pytest.mark.unittest +def test_delete(): + maxlen = 100 + cumlen = 40 + dellen = 20 + buf = DequeBuffer(size=maxlen) + for i in range(cumlen): + buf.push(i) + # Delete data + del_indices = [item.index for item in buf.sample(dellen)] + buf.delete(del_indices) + # Reappend + for i in range(10): + buf.push(i) + remlen = min(cumlen, maxlen) - dellen + 10 + assert len(buf.indices) == remlen + assert len(buf.storage) == remlen + for i in range(remlen): + index = buf.storage[i].index + assert buf.indices.get(index) == i + + +@pytest.mark.unittest +def test_ignore_insufficient(): + buffer = DequeBuffer(size=10) + for i in range(2): + buffer.push(i) + + with pytest.raises(ValueError): + buffer.sample(3, ignore_insufficient=False) + data = buffer.sample(3, ignore_insufficient=True) + assert len(data) == 0 + + +@pytest.mark.unittest +def test_independence(): + # By replace + buffer = DequeBuffer(size=1) + data = {"key": "origin"} + buffer.push(data) + sampled_data = buffer.sample(2, replace=True) + assert len(sampled_data) == 2 + sampled_data[0].data["key"] = "new" + assert sampled_data[1].data["key"] == "origin" + + # By indices + buffer = DequeBuffer(size=1) + data = {"key": "origin"} + buffered = buffer.push(data) + indices = [buffered.index, buffered.index] + sampled_data = buffer.sample(indices=indices) + assert len(sampled_data) == 2 + sampled_data[0].data["key"] = "new" + assert sampled_data[1].data["key"] == "origin" + + +@pytest.mark.unittest +def test_groupby(): + buffer = DequeBuffer(size=3) + buffer.push("a", {"group": 1}) + buffer.push("b", {"group": 2}) + buffer.push("c", {"group": 2}) + + sampled_data = 
buffer.sample(2, groupby="group") + assert len(sampled_data) == 2 + group1 = sampled_data[0] if len(sampled_data[0]) == 1 else sampled_data[1] + group2 = sampled_data[0] if len(sampled_data[0]) == 2 else sampled_data[1] + # Group1 should contain a + assert "a" == group1[0].data + # Group2 should contain b and c + data = [buffered.data for buffered in group2] # ["b", "c"] + assert "b" in data + assert "c" in data + + # Push new data and swap out a, the result will all in group 2 + buffer.push("d", {"group": 2}) + sampled_data = buffer.sample(1, groupby="group") + assert len(sampled_data) == 1 + assert len(sampled_data[0]) == 3 + data = [buffered.data for buffered in sampled_data[0]] + assert "d" in data + + # Update meta, set first data's group to 1 + first: BufferedData = buffer.storage[0] + buffer.update(first.index, first.data, {"group": 1}) + sampled_data = buffer.sample(2, groupby="group") + assert len(sampled_data) == 2 + + # Delete last record, each group will only have one record + last: BufferedData = buffer.storage[-1] + buffer.delete(last.index) + sampled_data = buffer.sample(2, groupby="group") + assert len(sampled_data) == 2 + + +@pytest.mark.unittest +def test_rolling_window(): + buffer = DequeBuffer(size=10) + for i in range(10): + buffer.push(i) + sampled_data = buffer.sample(10, rolling_window=3) + assert len(sampled_data) == 10 + + # Test data independence + buffer = DequeBuffer(size=2) + for i in range(2): + buffer.push({"key": i}) + sampled_data = buffer.sample(2, rolling_window=3) + assert len(sampled_data) == 2 + group_long = sampled_data[0] if len(sampled_data[0]) == 2 else sampled_data[1] + group_short = sampled_data[0] if len(sampled_data[0]) == 1 else sampled_data[1] + + # Modify the second value + group_long[1].data["key"] = 10 + assert group_short[0].data["key"] == 1 + + +@pytest.mark.unittest +def test_import_export(): + buffer = DequeBuffer(size=10) + data_with_meta = [(i, {}) for i in range(10)] + buffer.import_data(data_with_meta) + assert buffer.count() == 10 + + sampled_data = buffer.export_data() + assert len(sampled_data) == 10 + + +@pytest.mark.unittest +def test_dataset(): + buffer = DequeBuffer(size=10) + for i in range(10): + buffer.push(i) + dataloader = DataLoader(buffer, batch_size=6, shuffle=True, collate_fn=lambda batch: batch) + for batch in dataloader: + assert len(batch) in [4, 6] diff --git a/ding/worker/buffer/tests/test_middleware.py b/ding/worker/buffer/tests/test_middleware.py new file mode 100644 index 0000000..f306b9c --- /dev/null +++ b/ding/worker/buffer/tests/test_middleware.py @@ -0,0 +1,155 @@ +import pytest +import torch +from ding.worker.buffer import DequeBuffer +from ding.worker.buffer.middleware import clone_object, use_time_check, staleness_check +from ding.worker.buffer.middleware import PriorityExperienceReplay, group_sample +from ding.worker.buffer.middleware.padding import padding + + +@pytest.mark.unittest +def test_clone_object(): + buffer = DequeBuffer(size=10).use(clone_object()) + + # Store a dict, a list, a tensor + arr = [{"key": "v1"}, ["a"], torch.Tensor([1, 2, 3])] + for o in arr: + buffer.push(o) + + # Modify it + for item in buffer.sample(len(arr)): + item = item.data + if isinstance(item, dict): + item["key"] = "v2" + elif isinstance(item, list): + item.append("b") + elif isinstance(item, torch.Tensor): + item[0] = 3 + else: + raise Exception("Unexpected type") + + # Resample it, and check their values + for item in buffer.sample(len(arr)): + item = item.data + if isinstance(item, dict): + assert item["key"] 
== "v1" + elif isinstance(item, list): + assert len(item) == 1 + elif isinstance(item, torch.Tensor): + assert item[0] == 1 + else: + raise Exception("Unexpected type") + + +def get_data(): + return {'obs': torch.randn(4), 'reward': torch.randn(1), 'info': 'xxx'} + + +@pytest.mark.unittest +def test_use_time_check(): + N = 6 + buffer = DequeBuffer(size=10) + buffer.use(use_time_check(buffer, max_use=2)) + + for _ in range(N): + buffer.push(get_data()) + + for _ in range(2): + data = buffer.sample(size=N, replace=False) + assert len(data) == N + with pytest.raises(ValueError): + buffer.sample(size=1, replace=False) + + +@pytest.mark.unittest +def test_staleness_check(): + N = 6 + buffer = DequeBuffer(size=10) + buffer.use(staleness_check(buffer, max_staleness=10)) + + with pytest.raises(AssertionError): + buffer.push(get_data()) + for _ in range(N): + buffer.push(get_data(), meta={'train_iter_data_collected': 0}) + data = buffer.sample(size=N, replace=False, train_iter_sample_data=9) + assert len(data) == N + data = buffer.sample(size=N, replace=False, train_iter_sample_data=10) # edge case + assert len(data) == N + for _ in range(2): + buffer.push(get_data(), meta={'train_iter_data_collected': 5}) + assert buffer.count() == 8 + with pytest.raises(ValueError): + data = buffer.sample(size=N, replace=False, train_iter_sample_data=11) + assert buffer.count() == 2 + + +@pytest.mark.unittest +def test_priority(): + N = 5 + buffer = DequeBuffer(size=10) + buffer.use(PriorityExperienceReplay(buffer, buffer_size=10, IS_weight=True)) + for _ in range(N): + buffer.push(get_data()) + assert buffer.count() == N + for _ in range(N): + buffer.push(get_data(), meta={'priority': 2.0}) + assert buffer.count() == N + N + data = buffer.sample(size=N + N, replace=False) + assert len(data) == N + N + for item in data: + meta = item.meta + assert set(meta.keys()).issuperset(set(['priority', 'priority_idx', 'priority_IS'])) + meta['priority'] = 3.0 + for item in data: + data, index, meta = item.data, item.index, item.meta + buffer.update(index, data, meta) + data = buffer.sample(size=1) + assert data[0].meta['priority'] == 3.0 + buffer.delete(data[0].index) + assert buffer.count() == N + N - 1 + buffer.clear() + assert buffer.count() == 0 + + +@pytest.mark.unittest +def test_padding(): + buffer = DequeBuffer(size=10) + buffer.use(padding()) + for i in range(10): + buffer.push(i, {"group": i & 5}) # [3,3,2,2] + sampled_data = buffer.sample(4, groupby="group") + assert len(sampled_data) == 4 + for grouped_data in sampled_data: + assert len(grouped_data) == 3 + + +@pytest.mark.unittest +def test_group_sample(): + buffer = DequeBuffer(size=10) + buffer.use(padding(policy="none")).use(group_sample(size_in_group=5, ordered_in_group=True, max_use_in_group=True)) + for i in range(4): + buffer.push(i, {"episode": 0}) + for i in range(6): + buffer.push(i, {"episode": 1}) + sampled_data = buffer.sample(2, groupby="episode") + assert len(sampled_data) == 2 + + def check_group0(grouped_data): + # In group0 should find only last record with data as None + n_none = 0 + for item in grouped_data: + if item.data is None: + n_none += 1 + assert n_none == 1 + + def check_group1(grouped_data): + # In group1 every record should have data and meta + for item in grouped_data: + assert item.data is not None + + for grouped_data in sampled_data: + assert len(grouped_data) == 5 + meta = grouped_data[0].meta + if meta and "episode" in meta and meta["episode"] == 1: + check_group1(grouped_data) + else: + check_group0(grouped_data) diff --git 
a/ding/worker/buffer/utils/__init__.py b/ding/worker/buffer/utils/__init__.py new file mode 100644 index 0000000..495c5f7 --- /dev/null +++ b/ding/worker/buffer/utils/__init__.py @@ -0,0 +1 @@ +from .fast_copy import FastCopy, fastcopy diff --git a/ding/worker/buffer/utils/fast_copy.py b/ding/worker/buffer/utils/fast_copy.py new file mode 100644 index 0000000..865c68d --- /dev/null +++ b/ding/worker/buffer/utils/fast_copy.py @@ -0,0 +1,58 @@ +import torch +import numpy as np +from typing import Any, List +from ding.worker.buffer.buffer import BufferedData + + +class FastCopy: + """ + The idea of this class comes from this article + https://newbedev.com/what-is-a-fast-pythonic-way-to-deepcopy-just-data-from-a-python-dict-or-list. + We use recursive calls to copy each object that needs to be copied, which will be 5x faster + than copy.deepcopy. + """ + + def __init__(self): + dispatch = {} + dispatch[list] = self._copy_list + dispatch[dict] = self._copy_dict + dispatch[torch.Tensor] = self._copy_tensor + dispatch[np.ndarray] = self._copy_ndarray + dispatch[BufferedData] = self._copy_buffereddata + self.dispatch = dispatch + + def _copy_list(self, l: List) -> dict: + ret = l.copy() + for idx, item in enumerate(ret): + cp = self.dispatch.get(type(item)) + if cp is not None: + ret[idx] = cp(item) + return ret + + def _copy_dict(self, d: dict) -> dict: + ret = d.copy() + for key, value in ret.items(): + cp = self.dispatch.get(type(value)) + if cp is not None: + ret[key] = cp(value) + + return ret + + def _copy_tensor(self, t: torch.Tensor) -> torch.Tensor: + return t.clone() + + def _copy_ndarray(self, a: np.ndarray) -> np.ndarray: + return np.copy(a) + + def _copy_buffereddata(self, d: BufferedData) -> BufferedData: + return BufferedData(data=self.copy(d.data), index=d.index, meta=self.copy(d.meta)) + + def copy(self, sth: Any) -> Any: + cp = self.dispatch.get(type(sth)) + if cp is None: + return sth + else: + return cp(sth) + + +fastcopy = FastCopy() diff --git a/dizoo/atari/config/serial/pong/pong_dqn_config.py b/dizoo/atari/config/serial/pong/pong_dqn_config.py index 2f7735d..df07e39 100644 --- a/dizoo/atari/config/serial/pong/pong_dqn_config.py +++ b/dizoo/atari/config/serial/pong/pong_dqn_config.py @@ -51,6 +51,7 @@ pong_dqn_create_config = dict( ), env_manager=dict(type='subprocess'), policy=dict(type='dqn'), + # replay_buffer=dict(type='deque'), ) pong_dqn_create_config = EasyDict(pong_dqn_create_config) create_config = pong_dqn_create_config diff --git a/dizoo/box2d/lunarlander/config/lunarlander_dqn_deque_config.py b/dizoo/box2d/lunarlander/config/lunarlander_dqn_deque_config.py new file mode 100644 index 0000000..04dcb7f --- /dev/null +++ b/dizoo/box2d/lunarlander/config/lunarlander_dqn_deque_config.py @@ -0,0 +1,78 @@ +from easydict import EasyDict +from ding.entry import serial_pipeline + +nstep = 3 +lunarlander_dqn_default_config = dict( + exp_name='lunarlander_dqn_priority', + env=dict( + # Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess' + manager=dict(shared_memory=True, ), + # Env number respectively for collector and evaluator. + collector_env_num=8, + evaluator_env_num=5, + n_evaluator_episode=5, + stop_value=200, + ), + policy=dict( + # Whether to use cuda for network. + cuda=False, + priority=True, + priority_IS_weight=False, + model=dict( + obs_shape=8, + action_shape=4, + encoder_hidden_size_list=[512, 64], + # Whether to use dueling head. + dueling=True, + ), + # Reward's future discount factor, aka. gamma. 
+ discount_factor=0.99, + # How many steps in td error. + nstep=nstep, + # learn_mode config + learn=dict( + update_per_collect=10, + batch_size=64, + learning_rate=0.001, + # Frequency of target network update. + target_update_freq=100, + ), + # collect_mode config + collect=dict( + # You can use either "n_sample" or "n_episode" in collector.collect. + # Get "n_sample" samples per collect. + n_sample=64, + # Cut trajectories into pieces with length "unroll_len". + unroll_len=1, + ), + # command_mode config + other=dict( + # Epsilon greedy with decay. + eps=dict( + # Decay type. Support ['exp', 'linear']. + type='exp', + start=0.95, + end=0.1, + decay=50000, + ), + replay_buffer=dict(replay_buffer_size=100000, priority=True, priority_IS_weight=False) + ), + ), +) +lunarlander_dqn_default_config = EasyDict(lunarlander_dqn_default_config) +main_config = lunarlander_dqn_default_config + +lunarlander_dqn_create_config = dict( + env=dict( + type='lunarlander', + import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='dqn'), + replay_buffer=dict(type='deque'), +) +lunarlander_dqn_create_config = EasyDict(lunarlander_dqn_create_config) +create_config = lunarlander_dqn_create_config + +if __name__ == "__main__": + serial_pipeline([main_config, create_config], seed=0) diff --git a/dizoo/classic_control/cartpole/entry/cartpole_dqn_buffer_main.py b/dizoo/classic_control/cartpole/entry/cartpole_dqn_buffer_main.py new file mode 100644 index 0000000..561ea89 --- /dev/null +++ b/dizoo/classic_control/cartpole/entry/cartpole_dqn_buffer_main.py @@ -0,0 +1,80 @@ +import os +import gym +from tensorboardX import SummaryWriter + +from ding.config import compile_config +from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, DequeBufferWrapper +from ding.envs import BaseEnvManager, DingEnvWrapper +from ding.policy import DQNPolicy +from ding.model import DQN +from ding.utils import set_pkg_seed +from ding.rl_utils import get_epsilon_greedy_fn +from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config + + +# Get DI-engine form env class +def wrapped_cartpole_env(): + return DingEnvWrapper(gym.make('CartPole-v0')) + + +def main(cfg, seed=0): + cfg = compile_config( + cfg, + BaseEnvManager, + DQNPolicy, + BaseLearner, + SampleSerialCollector, + InteractionSerialEvaluator, + DequeBufferWrapper, + save_cfg=True + ) + collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num + collector_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(collector_env_num)], cfg=cfg.env.manager) + evaluator_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager) + + # Set random seed for all package and instance + collector_env.seed(seed) + evaluator_env.seed(seed, dynamic_seed=False) + set_pkg_seed(seed, use_cuda=cfg.policy.cuda) + + # Set up RL Policy + model = DQN(**cfg.policy.model) + policy = DQNPolicy(cfg.policy, model=model) + + # Set up collection, training and evaluation utilities + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + collector = SampleSerialCollector( + cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name + ) + evaluator = InteractionSerialEvaluator( + cfg.policy.eval.evaluator, evaluator_env, 
policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+    )
+    replay_buffer = DequeBufferWrapper(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
+
+    # Set up other modules, e.g. epsilon greedy
+    eps_cfg = cfg.policy.other.eps
+    epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
+
+    # Training & Evaluation loop
+    while True:
+        # Evaluate at the beginning and then with a fixed frequency
+        if evaluator.should_eval(learner.train_iter):
+            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+            if stop:
+                break
+        # Update other modules
+        eps = epsilon_greedy(collector.envstep)
+        # Sampling data from environments
+        new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
+        replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+        # Training
+        for i in range(cfg.policy.learn.update_per_collect):
+            train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+            if train_data is None:
+                break
+            learner.train(train_data, collector.envstep)
+
+
+if __name__ == "__main__":
+    main(cartpole_dqn_config)
-- 
GitLab
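For reference, a minimal usage sketch of the API introduced by this patch (DequeBuffer, the chained `use` middleware, push/sample with meta), using only names added in the diff above; the buffer size, meta key and middleware choices are illustrative, not prescriptive:

    from ding.worker.buffer import DequeBuffer
    from ding.worker.buffer.middleware import clone_object, use_time_check

    # Build a bounded buffer and attach middleware; `use` returns the buffer, so calls can be chained.
    buffer = DequeBuffer(size=1000)
    buffer.use(clone_object()).use(use_time_check(buffer, max_use=2))

    # Push data together with meta information, which grouped sampling relies on.
    for i in range(16):
        buffer.push({"obs": i}, meta={"env": i % 4})

    flat_batch = buffer.sample(8)               # flat list of BufferedData
    grouped = buffer.sample(2, groupby="env")   # nested list, one inner list per sampled group

Each sampled item exposes `.data`, `.index` and `.meta`; with `clone_object` in the chain, the returned records are copies, so they can be modified without touching the data kept in the buffer.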