From 35241df379105c591e62ebb04682a65eef35f861 Mon Sep 17 00:00:00 2001
From: niuyazhe
Date: Sat, 1 Jan 2022 20:25:44 +0800
Subject: [PATCH] feature(nyz): add vim in docker and add multiple seed cli

---
 Dockerfile.base       |   2 +-
 ding/config/config.py |   4 +
 ding/entry/cli.py     | 186 ++++++++++++++++++++++--------------
 3 files changed, 104 insertions(+), 88 deletions(-)

diff --git a/Dockerfile.base b/Dockerfile.base
index f4af2d7..614e418 100644
--- a/Dockerfile.base
+++ b/Dockerfile.base
@@ -3,7 +3,7 @@ FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-runtime
 WORKDIR /ding
 
 RUN apt update \
-    && apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git gcc g++ make locales -y \
+    && apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git vim gcc g++ make locales -y \
     && apt clean \
     && rm -rf /var/cache/apt/* \
     && sed -i '/en_US.UTF-8/s/^# //g' /etc/locale.gen \
diff --git a/ding/config/config.py b/ding/config/config.py
index a27f0e5..a6c5447 100644
--- a/ding/config/config.py
+++ b/ding/config/config.py
@@ -414,6 +414,8 @@ def compile_config(
         cfg.policy.eval.evaluator.n_episode = cfg.env.n_evaluator_episode
     if 'exp_name' not in cfg:
         cfg.exp_name = 'default_experiment'
+    # add seed as suffix of exp_name
+    cfg.exp_name = cfg.exp_name + '_seed{}'.format(seed)
     if save_cfg:
         if not os.path.exists(cfg.exp_name):
             try:
@@ -524,6 +526,8 @@ def compile_config_parallel(
         cfg.system.coordinator = deep_merge_dicts(Coordinator.default_config(), cfg.system.coordinator)
     # seed
     cfg.seed = seed
+    # add seed as suffix of exp_name
+    cfg.exp_name = cfg.exp_name + '_seed{}'.format(seed)
     if save_cfg:
         save_config(cfg, save_path)
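Note: with the compile_config changes above, every run's output directory picks up a
"_seed{N}" suffix, so repeated runs of the same experiment under different seeds no
longer overwrite each other's logs and checkpoints. A minimal sketch of the resulting
naming (the helper function and experiment names are hypothetical, not part of the patch):

    # Sketch: effect of the seed suffix appended in compile_config.
    def with_seed_suffix(exp_name, seed):
        # same logic as the patch: add seed as suffix of exp_name
        return exp_name + '_seed{}'.format(seed)

    assert with_seed_suffix('default_experiment', 0) == 'default_experiment_seed0'
    assert with_seed_suffix('cartpole_dqn', 10) == 'cartpole_dqn_seed10'  # one output dir per seed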
diff --git a/ding/entry/cli.py b/ding/entry/cli.py
index e55ce67..a84ff53 100644
--- a/ding/entry/cli.py
+++ b/ding/entry/cli.py
@@ -1,5 +1,7 @@
+from typing import List, Union
 import click
 from click.core import Context, Option
+import numpy as np
 
 from ding import __TITLE__, __VERSION__, __AUTHOR__, __AUTHOR_EMAIL__
 from .predefined_config import get_predefined_config
@@ -65,7 +67,8 @@ CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
     '-s',
     '--seed',
     type=int,
-    default=0,
+    default=[0],
+    multiple=True,
     help='random generator seed (for all possible packages: random, numpy, torch and user env)'
 )
 @click.option('-e', '--env', type=str, help='RL env name')
@@ -117,7 +120,7 @@ def cli(
     # serial/eval
     mode: str,
     config: str,
-    seed: int,
+    seed: Union[int, List[int]],
     env: str,
     policy: str,
     train_iter: int,
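Note: with multiple=True, click collects every repeated -s/--seed flag into a tuple of
ints before cli() runs, which is why the default changes from 0 to [0] and the annotation
widens to Union[int, List[int]]. A standalone sketch of that behavior (the demo command is
hypothetical, not part of DI-engine):

    # Sketch: click's multiple=True turns repeated flags into a tuple.
    import click

    @click.command()
    @click.option('-s', '--seed', type=int, default=[0], multiple=True)
    def demo(seed):
        # e.g. "python demo.py -s 0 -s 10 -s 20" prints "seeds: (0, 10, 20)"
        click.echo('seeds: {}'.format(seed))

    if __name__ == '__main__':
        demo()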
@@ -155,89 +158,98 @@
         from ..utils.profiler_helper import Profiler
         profiler = Profiler()
         profiler.profile(profile)
-    if mode == 'serial':
-        from .serial_entry import serial_pipeline
-        if config is None:
-            config = get_predefined_config(env, policy)
-        serial_pipeline(config, seed, max_iterations=train_iter)
-    elif mode == 'serial_onpolicy':
-        from .serial_entry_onpolicy import serial_pipeline_onpolicy
-        if config is None:
-            config = get_predefined_config(env, policy)
-        serial_pipeline_onpolicy(config, seed, max_iterations=train_iter)
-    elif mode == 'serial_sqil':
-        if config == 'lunarlander_sqil_config.py' or 'cartpole_sqil_config.py' or 'pong_sqil_config.py' \
-                or 'spaceinvaders_sqil_config.py' or 'qbert_sqil_config.py':
-            from .serial_entry_sqil import serial_pipeline_sqil
-            if config is None:
-                config = get_predefined_config(env, policy)
-            expert_config = input("Enter the name of the config you used to generate your expert model: ")
-            serial_pipeline_sqil(config, expert_config, seed, max_iterations=train_iter)
-    elif mode == 'serial_reward_model':
-        from .serial_entry_reward_model import serial_pipeline_reward_model
-        if config is None:
-            config = get_predefined_config(env, policy)
-        serial_pipeline_reward_model(config, seed, max_iterations=train_iter)
-    elif mode == 'serial_gail':
-        from .serial_entry_gail import serial_pipeline_gail
-        if config is None:
-            config = get_predefined_config(env, policy)
-        expert_config = input("Enter the name of the config you used to generate your expert model: ")
-        serial_pipeline_gail(config, expert_config, seed, max_iterations=train_iter, collect_data=True)
-    elif mode == 'serial_dqfd':
-        from .serial_entry_dqfd import serial_pipeline_dqfd
-        if config is None:
-            config = get_predefined_config(env, policy)
-        expert_config = input("Enter the name of the config you used to generate your expert model: ")
-        assert (expert_config == config[:config.find('_dqfd')] + '_dqfd_config.py'), "DQFD only supports "\
-            + "the models used in q learning now; However, one should still type the DQFD config in this "\
-            + "place, i.e., {}{}".format(config[:config.find('_dqfd')], '_dqfd_config.py')
-        serial_pipeline_dqfd(config, expert_config, seed, max_iterations=train_iter)
-    elif mode == 'serial_trex':
-        from .serial_entry_trex import serial_pipeline_reward_model_trex
-        if config is None:
-            config = get_predefined_config(env, policy)
-        serial_pipeline_reward_model_trex(config, seed, max_iterations=train_iter)
-    elif mode == 'serial_trex_onpolicy':
-        from .serial_entry_trex_onpolicy import serial_pipeline_reward_model_trex_onpolicy
-        if config is None:
-            config = get_predefined_config(env, policy)
-        serial_pipeline_reward_model_trex_onpolicy(config, seed, max_iterations=train_iter)
-    elif mode == 'parallel':
-        from .parallel_entry import parallel_pipeline
-        parallel_pipeline(config, seed, enable_total_log, disable_flask_log)
-    elif mode == 'dist':
-        from .dist_entry import dist_launch_coordinator, dist_launch_collector, dist_launch_learner, \
-            dist_prepare_config, dist_launch_learner_aggregator, dist_launch_spawn_learner, \
-            dist_add_replicas, dist_delete_replicas, dist_restart_replicas
-        if module == 'config':
-            dist_prepare_config(
-                config, seed, platform, coordinator_host, learner_host, collector_host, coordinator_port, learner_port,
-                collector_port
-            )
-        elif module == 'coordinator':
-            dist_launch_coordinator(config, seed, coordinator_port, disable_flask_log)
-        elif module == 'learner_aggregator':
-            dist_launch_learner_aggregator(
-                config, seed, aggregator_host, aggregator_port, module_name, disable_flask_log
-            )
-        elif module == 'collector':
-            dist_launch_collector(config, seed, collector_port, module_name, disable_flask_log)
-        elif module == 'learner':
-            dist_launch_learner(config, seed, learner_port, module_name, disable_flask_log)
-        elif module == 'spawn_learner':
-            dist_launch_spawn_learner(config, seed, learner_port, module_name, disable_flask_log)
-        elif add in ['collector', 'learner']:
-            dist_add_replicas(add, kubeconfig, replicas, coordinator_name, namespace, cpus, gpus, memory)
-        elif delete in ['collector', 'learner']:
-            dist_delete_replicas(delete, kubeconfig, replicas, coordinator_name, namespace)
-        elif restart in ['collector', 'learner']:
-            dist_restart_replicas(restart, kubeconfig, coordinator_name, namespace, restart_pod_name)
-        else:
-            raise Exception
-    elif mode == 'eval':
-        from .application_entry import eval
-        if config is None:
-            config = get_predefined_config(env, policy)
-        eval(config, seed, load_path=load_path, replay_path=replay_path)
+    def run_single_pipeline(seed, config):
+        if mode == 'serial':
+            from .serial_entry import serial_pipeline
+            if config is None:
+                config = get_predefined_config(env, policy)
+            serial_pipeline(config, seed, max_iterations=train_iter)
+        elif mode == 'serial_onpolicy':
+            from .serial_entry_onpolicy import serial_pipeline_onpolicy
+            if config is None:
+                config = get_predefined_config(env, policy)
+            serial_pipeline_onpolicy(config, seed, max_iterations=train_iter)
+        elif mode == 'serial_sqil':
+            # "config == A or B" is always truthy; check membership (or None for a predefined config) instead
+            if config is None or config in (
+                'lunarlander_sqil_config.py', 'cartpole_sqil_config.py', 'pong_sqil_config.py',
+                'spaceinvaders_sqil_config.py', 'qbert_sqil_config.py'
+            ):
+                from .serial_entry_sqil import serial_pipeline_sqil
+                if config is None:
+                    config = get_predefined_config(env, policy)
+                expert_config = input("Enter the name of the config you used to generate your expert model: ")
+                serial_pipeline_sqil(config, expert_config, seed, max_iterations=train_iter)
+        elif mode == 'serial_reward_model':
+            from .serial_entry_reward_model import serial_pipeline_reward_model
+            if config is None:
+                config = get_predefined_config(env, policy)
+            serial_pipeline_reward_model(config, seed, max_iterations=train_iter)
+        elif mode == 'serial_gail':
+            from .serial_entry_gail import serial_pipeline_gail
+            if config is None:
+                config = get_predefined_config(env, policy)
+            expert_config = input("Enter the name of the config you used to generate your expert model: ")
+            serial_pipeline_gail(config, expert_config, seed, max_iterations=train_iter, collect_data=True)
+        elif mode == 'serial_dqfd':
+            from .serial_entry_dqfd import serial_pipeline_dqfd
+            if config is None:
+                config = get_predefined_config(env, policy)
+            expert_config = input("Enter the name of the config you used to generate your expert model: ")
+            assert (expert_config == config[:config.find('_dqfd')] + '_dqfd_config.py'), "DQFD only supports "\
+                + "the models used in q learning now; however, one should still type the DQFD config in this "\
+                + "place, i.e., {}{}".format(config[:config.find('_dqfd')], '_dqfd_config.py')
+            serial_pipeline_dqfd(config, expert_config, seed, max_iterations=train_iter)
+        elif mode == 'serial_trex':
+            from .serial_entry_trex import serial_pipeline_reward_model_trex
+            if config is None:
+                config = get_predefined_config(env, policy)
+            serial_pipeline_reward_model_trex(config, seed, max_iterations=train_iter)
+        elif mode == 'serial_trex_onpolicy':
+            from .serial_entry_trex_onpolicy import serial_pipeline_reward_model_trex_onpolicy
+            if config is None:
+                config = get_predefined_config(env, policy)
+            serial_pipeline_reward_model_trex_onpolicy(config, seed, max_iterations=train_iter)
+        elif mode == 'parallel':
+            from .parallel_entry import parallel_pipeline
+            parallel_pipeline(config, seed, enable_total_log, disable_flask_log)
+        elif mode == 'dist':
+            from .dist_entry import dist_launch_coordinator, dist_launch_collector, dist_launch_learner, \
+                dist_prepare_config, dist_launch_learner_aggregator, dist_launch_spawn_learner, \
+                dist_add_replicas, dist_delete_replicas, dist_restart_replicas
+            if module == 'config':
+                dist_prepare_config(
+                    config, seed, platform, coordinator_host, learner_host, collector_host, coordinator_port,
+                    learner_port, collector_port
+                )
+            elif module == 'coordinator':
+                dist_launch_coordinator(config, seed, coordinator_port, disable_flask_log)
+            elif module == 'learner_aggregator':
+                dist_launch_learner_aggregator(
+                    config, seed, aggregator_host, aggregator_port, module_name, disable_flask_log
+                )
+            elif module == 'collector':
+                dist_launch_collector(config, seed, collector_port, module_name, disable_flask_log)
+            elif module == 'learner':
+                dist_launch_learner(config, seed, learner_port, module_name, disable_flask_log)
+            elif module == 'spawn_learner':
+                dist_launch_spawn_learner(config, seed, learner_port, module_name, disable_flask_log)
+            elif add in ['collector', 'learner']:
+                dist_add_replicas(add, kubeconfig, replicas, coordinator_name, namespace, cpus, gpus, memory)
+            elif delete in ['collector', 'learner']:
+                dist_delete_replicas(delete, kubeconfig, replicas, coordinator_name, namespace)
+            elif restart in ['collector', 'learner']:
+                dist_restart_replicas(restart, kubeconfig, coordinator_name, namespace, restart_pod_name)
+            else:
+                raise Exception
+        elif mode == 'eval':
+            from .application_entry import eval
+            if config is None:
+                config = get_predefined_config(env, policy)
+            eval(config, seed, load_path=load_path, replay_path=replay_path)
+
+    if isinstance(seed, (list, tuple)):
+        assert len(seed) > 0, "Please input at least 1 seed"
+        for s in seed:
+            run_single_pipeline(s, config)
+    else:
+        raise TypeError("invalid seed type: {}".format(type(seed)))
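Note: the new entry point simply loops run_single_pipeline over the seed tuple, so one
CLI call fans out into one full pipeline per seed. The same pattern works
programmatically; a minimal sketch assuming a DI-engine serial config file (the config
path is hypothetical):

    # Sketch: one serial pipeline per seed, mirroring the loop added above.
    from ding.entry import serial_pipeline

    seeds = (0, 10, 20)  # what click hands to cli() for "-s 0 -s 10 -s 20"
    assert len(seeds) > 0, "Please input at least 1 seed"
    for s in seeds:
        # compile_config appends _seed{s}, so each run writes to its own directory
        serial_pipeline('cartpole_dqn_config.py', seed=s)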
--
GitLab