From 9f8abe917847bbd308b9c2a8ac90e08f252f2ade Mon Sep 17 00:00:00 2001
From: Chengmo
Date: Wed, 17 Jun 2020 11:41:50 +0800
Subject: [PATCH] Fix several bugs in the Save and Infer stages (#95)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* fix setup

* fix bug for dssm reader

* fix net bug at PY3 for afm

* fix multi cards with files

* fix ctr

* add validation

* add validation

* add validation

* fix compile

* fix ci

* fix user define runner

* fix gnn reader at PY3

* fix fast yaml config at PY3

Co-authored-by: tangwei
---
 core/trainers/framework/network.py            |  5 +
 core/trainers/framework/runner.py             | 90 +++++++++++++++---
 core/trainers/framework/startup.py            | 18 ++++
 core/trainers/general_trainer.py              | 10 +-
 core/utils/validation.py                      | 83 +++++++++++------
 .../match/dssm/synthetic_evaluate_reader.py   |  4 +-
 models/match/dssm/synthetic_reader.py         |  8 +-
 models/rank/afm/model.py                      |  2 +-
 models/rank/dnn/config.yaml                   |  4 +-
 models/recall/fasttext/config.yaml            | 92 +++++++++----------
 models/recall/gnn/evaluate_reader.py          |  3 +-
 models/recall/gnn/reader.py                   |  3 +-
 run.py                                        | 48 +++++++---
 13 files changed, 257 insertions(+), 113 deletions(-)

diff --git a/core/trainers/framework/network.py b/core/trainers/framework/network.py
index 71f2a4e7..74d2c975 100644
--- a/core/trainers/framework/network.py
+++ b/core/trainers/framework/network.py
@@ -94,6 +94,7 @@ class SingleNetwork(NetworkBase):
             context["model"][model_dict["name"]]["model"] = model
             context["model"][model_dict["name"]][
                 "default_main_program"] = train_program.clone()
+            context["model"][model_dict["name"]]["compiled_program"] = None
 
         context["dataset"] = {}
         for dataset in context["env"]["dataset"]:
@@ -149,6 +150,7 @@ class PSNetwork(NetworkBase):
         context["model"][model_dict["name"]]["model"] = model
         context["model"][model_dict["name"]]["default_main_program"] = context[
             "fleet"].main_program.clone()
+        context["model"][model_dict["name"]]["compiled_program"] = None
 
         if context["fleet"].is_server():
             self._server(context)
@@ -245,6 +247,8 @@ class PslibNetwork(NetworkBase):
                 context["model"][model_dict["name"]]["model"] = model
                 context["model"][model_dict["name"]][
                     "default_main_program"] = train_program.clone()
+                context["model"][model_dict["name"]][
+                    "compiled_program"] = None
 
         if context["fleet"].is_server():
             self._server(context)
@@ -314,6 +318,7 @@ class CollectiveNetwork(NetworkBase):
             context["model"][model_dict["name"]]["model"] = model
             context["model"][model_dict["name"]][
                 "default_main_program"] = train_program
+            context["model"][model_dict["name"]]["compiled_program"] = None
 
             context["dataset"] = {}
             for dataset in context["env"]["dataset"]:
diff --git a/core/trainers/framework/runner.py b/core/trainers/framework/runner.py
index 46a7a1c4..fb2c87a9 100644
--- a/core/trainers/framework/runner.py
+++ b/core/trainers/framework/runner.py
@@ -50,6 +50,7 @@ class RunnerBase(object):
         reader_name = model_dict["dataset_name"]
         model_name = model_dict["name"]
         model_class = context["model"][model_dict["name"]]["model"]
+
         fetch_vars = []
         fetch_alias = []
         fetch_period = int(
@@ -89,19 +90,7 @@ class RunnerBase(object):
     def _executor_dataloader_train(self, model_dict, context):
         model_name = model_dict["name"]
         model_class = context["model"][model_dict["name"]]["model"]
-
-        if context["is_infer"]:
-            program = context["model"][model_name]["main_program"]
-        elif context["is_fleet"]:
-            if context["fleet_mode"].upper() == "PS":
-                program = self._get_ps_program(model_dict, context)
-            elif context["fleet_mode"].upper() == "COLLECTIVE":
-                program = context["model"][model_name]["main_program"]
-        elif not context["is_fleet"]:
-            if context["device"].upper() == "CPU":
-                program = self._get_single_cpu_program(model_dict, context)
-            elif context["device"].upper() == "GPU":
-                program = self._get_single_gpu_program(model_dict, context)
+        program = self._get_dataloader_program(model_dict, context)
 
         reader_name = model_dict["dataset_name"]
         fetch_vars = []
@@ -143,6 +132,24 @@ class RunnerBase(object):
         except fluid.core.EOFException:
             reader.reset()
 
+    def _get_dataloader_program(self, model_dict, context):
+        model_name = model_dict["name"]
+        if context["model"][model_name]["compiled_program"] is None:
+            if context["is_infer"]:
+                program = context["model"][model_name]["main_program"]
+            elif context["is_fleet"]:
+                if context["fleet_mode"].upper() == "PS":
+                    program = self._get_ps_program(model_dict, context)
+                elif context["fleet_mode"].upper() == "COLLECTIVE":
+                    program = context["model"][model_name]["main_program"]
+            elif not context["is_fleet"]:
+                if context["device"].upper() == "CPU":
+                    program = self._get_single_cpu_program(model_dict, context)
+                elif context["device"].upper() == "GPU":
+                    program = self._get_single_gpu_program(model_dict, context)
+            context["model"][model_name]["compiled_program"] = program
+        return context["model"][model_name]["compiled_program"]
+
     def _get_strategy(self, model_dict, context):
         _build_strategy = fluid.BuildStrategy()
         _exe_strategy = fluid.ExecutionStrategy()
@@ -218,12 +225,17 @@ class RunnerBase(object):
 
     def save(self, epoch_id, context, is_fleet=False):
         def need_save(epoch_id, epoch_interval, is_last=False):
+            name = "runner." + context["runner_name"] + "."
+            total_epoch = int(envs.get_global_env(name + "epochs", 1))
+            if epoch_id + 1 == total_epoch:
+                is_last = True
+
             if is_last:
                 return True
 
             if epoch_id == -1:
                 return False
 
-            return epoch_id % epoch_interval == 0
+            return (epoch_id + 1) % epoch_interval == 0
 
         def save_inference_model():
             name = "runner." + context["runner_name"] + "."
@@ -415,3 +427,53 @@ class PslibRunner(RunnerBase):
 
         """
         context["status"] = "terminal_pass"
+
+
+class SingleInferRunner(RunnerBase):
+    def __init__(self, context):
+        print("Running SingleInferRunner.")
+        pass
+
+    def run(self, context):
+        self._dir_check(context)
+
+        for index, epoch_name in enumerate(self.epoch_model_name_list):
+            for model_dict in context["phases"]:
+                self._load(context, model_dict,
+                           self.epoch_model_path_list[index])
+                begin_time = time.time()
+                self._run(context, model_dict)
+                end_time = time.time()
+                seconds = end_time - begin_time
+                print("Infer {} of {} done, use time: {}".format(model_dict[
+                    "name"], epoch_name, seconds))
+        context["status"] = "terminal_pass"
+
+    def _load(self, context, model_dict, model_path):
+        if model_path is None or model_path == "":
+            return
+        print("load persistables from", model_path)
+
+        with fluid.scope_guard(context["model"][model_dict["name"]]["scope"]):
+            train_prog = context["model"][model_dict["name"]]["main_program"]
+            startup_prog = context["model"][model_dict["name"]][
+                "startup_program"]
+            with fluid.program_guard(train_prog, startup_prog):
+                fluid.io.load_persistables(
+                    context["exe"], model_path, main_program=train_prog)
+
+    def _dir_check(self, context):
+        dirname = envs.get_global_env(
+ context["runner_name"] + ".init_model_path", None) + self.epoch_model_path_list = [] + self.epoch_model_name_list = [] + + for file in os.listdir(dirname): + file_path = os.path.join(dirname, file) + if os.path.isdir(file_path): + self.epoch_model_path_list.append(file_path) + self.epoch_model_name_list.append(file) + + if len(self.epoch_model_path_list) == 0: + self.epoch_model_path_list.append(dirname) + self.epoch_model_name_list.append(dirname) diff --git a/core/trainers/framework/startup.py b/core/trainers/framework/startup.py index 82e24472..2687dcdd 100644 --- a/core/trainers/framework/startup.py +++ b/core/trainers/framework/startup.py @@ -101,3 +101,21 @@ class CollectiveStartup(StartupBase): context["exe"].run(startup_prog) self.load(context, True) context["status"] = "train_pass" + + +class SingleInferStartup(StartupBase): + def __init__(self, context): + print("Running SingleInferStartup.") + pass + + def startup(self, context): + for model_dict in context["phases"]: + with fluid.scope_guard(context["model"][model_dict["name"]][ + "scope"]): + train_prog = context["model"][model_dict["name"]][ + "main_program"] + startup_prog = context["model"][model_dict["name"]][ + "startup_program"] + with fluid.program_guard(train_prog, startup_prog): + context["exe"].run(startup_prog) + context["status"] = "train_pass" diff --git a/core/trainers/general_trainer.py b/core/trainers/general_trainer.py index ded99e1c..4936cc8b 100644 --- a/core/trainers/general_trainer.py +++ b/core/trainers/general_trainer.py @@ -101,7 +101,9 @@ class GeneralTrainer(Trainer): startup_class = envs.lazy_instance_by_fliename(startup_class_path, "Startup")(context) else: - if self.engine == EngineMode.SINGLE: + if self.engine == EngineMode.SINGLE and context["is_infer"]: + startup_class_name = "SingleInferStartup" + elif self.engine == EngineMode.SINGLE and not context["is_infer"]: startup_class_name = "SingleStartup" elif self.fleet_mode == FleetMode.PS or self.fleet_mode == FleetMode.PSLIB: startup_class_name = "PSStartup" @@ -117,12 +119,14 @@ class GeneralTrainer(Trainer): def runner(self, context): runner_class_path = envs.get_global_env( - self.runner_env_name + ".runner_class_paht", default_value=None) + self.runner_env_name + ".runner_class_path", default_value=None) if runner_class_path: runner_class = envs.lazy_instance_by_fliename(runner_class_path, "Runner")(context) else: - if self.engine == EngineMode.SINGLE: + if self.engine == EngineMode.SINGLE and context["is_infer"]: + runner_class_name = "SingleInferRunner" + elif self.engine == EngineMode.SINGLE and not context["is_infer"]: runner_class_name = "SingleRunner" elif self.fleet_mode == FleetMode.PSLIB: runner_class_name = "PslibRunner" diff --git a/core/utils/validation.py b/core/utils/validation.py index a2911cb3..7448daf8 100755 --- a/core/utils/validation.py +++ b/core/utils/validation.py @@ -16,16 +16,25 @@ from paddlerec.core.utils import envs class ValueFormat: - def __init__(self, value_type, value, value_handler): + def __init__(self, value_type, value, value_handler, required=False): self.value_type = value_type - self.value = value self.value_handler = value_handler + self.value = value + self.required = required def is_valid(self, name, value): - ret = self.is_type_valid(name, value) + + if not self.value_type: + ret = True + else: + ret = self.is_type_valid(name, value) + if not ret: return ret + if not self.value or not self.value_handler: + return True + ret = self.is_value_valid(name, value) return ret @@ -33,21 +42,21 @@ class 
         if self.value_type == "int":
             if not isinstance(value, int):
                 print("\nattr {} should be int, but {} now\n".format(
-                    name, self.value_type))
+                    name, type(value)))
                 return False
             return True
 
         elif self.value_type == "str":
             if not isinstance(value, str):
                 print("\nattr {} should be str, but {} now\n".format(
-                    name, self.value_type))
+                    name, type(value)))
                 return False
             return True
 
         elif self.value_type == "strs":
             if not isinstance(value, list):
                 print("\nattr {} should be list(str), but {} now\n".format(
-                    name, self.value_type))
+                    name, type(value)))
                 return False
             for v in value:
                 if not isinstance(v, str):
@@ -56,10 +65,29 @@ class ValueFormat:
                     return False
             return True
 
+        elif self.value_type == "dict":
+            if not isinstance(value, dict):
+                print("\nattr {} should be dict, but {} now\n".format(
+                    name, type(value)))
+                return False
+            return True
+
+        elif self.value_type == "dicts":
+            if not isinstance(value, list):
+                print("\nattr {} should be list(dict), but {} now\n".format(
+                    name, type(value)))
+                return False
+            for v in value:
+                if not isinstance(v, dict):
+                    print("\nattr {} should be list(dict), but list({}) now\n".
+                          format(name, type(v)))
+                    return False
+            return True
+
         elif self.value_type == "ints":
             if not isinstance(value, list):
                 print("\nattr {} should be list(int), but {} now\n".format(
-                    name, self.value_type))
+                    name, type(value)))
                 return False
             for v in value:
                 if not isinstance(v, int):
@@ -74,7 +102,7 @@ class ValueFormat:
             return False
 
     def is_value_valid(self, name, value):
-        ret = self.value_handler(value)
+        ret = self.value_handler(name, value, self.value)
         return ret
 
 
@@ -112,38 +140,35 @@ def le_value_handler(name, value, values):
 
 def register():
     validations = {}
-    validations["train.workspace"] = ValueFormat("str", None, eq_value_handler)
-    validations["train.device"] = ValueFormat("str", ["cpu", "gpu"],
-                                              in_value_handler)
-    validations["train.epochs"] = ValueFormat("int", 1, ge_value_handler)
-    validations["train.engine"] = ValueFormat(
-        "str", ["train", "infer", "local_cluster_train", "cluster_train"],
-        in_value_handler)
-
-    requires = ["workspace", "dataset", "mode", "runner", "phase"]
-    return validations, requires
+    validations["workspace"] = ValueFormat("str", None, None, True)
+    validations["mode"] = ValueFormat(None, None, None, True)
+    validations["runner"] = ValueFormat("dicts", None, None, True)
+    validations["phase"] = ValueFormat("dicts", None, None, True)
+    validations["hyper_parameters"] = ValueFormat("dict", None, None, False)
+    return validations
 
 
 def yaml_validation(config):
-    all_checkers, require_checkers = register()
+    all_checkers = register()
+
+    require_checkers = []
+    for name, checker in all_checkers.items():
+        if checker.required:
+            require_checkers.append(name)
 
     _config = envs.load_yaml(config)
-    flattens = envs.flatten_environs(_config)
 
     for required in require_checkers:
-        if required not in flattens.keys():
+        if required not in _config.keys():
             print("\ncan not find {} in yaml, which is required\n".format(
                 required))
             return False
 
-    for name, flatten in flattens.items():
+    for name, value in _config.items():
         checker = all_checkers.get(name, None)
-
-        if not checker:
-            continue
-
-        ret = checker.is_valid(name, flattens)
-        if not ret:
-            return False
+        if checker:
+            ret = checker.is_valid(name, value)
+            if not ret:
+                return False
 
     return True
diff --git a/models/match/dssm/synthetic_evaluate_reader.py b/models/match/dssm/synthetic_evaluate_reader.py
index 29c14ce5..3d8413cd 100755
--- a/models/match/dssm/synthetic_evaluate_reader.py
+++ b/models/match/dssm/synthetic_evaluate_reader.py
@@ -30,8 +30,8 @@ class Reader(ReaderBase):
         This function needs to be implemented by the user, based on data format
         """
         features = line.rstrip('\n').split('\t')
-        query = map(float, features[0].split(','))
-        pos_doc = map(float, features[1].split(','))
+        query = [float(feature) for feature in features[0].split(',')]
+        pos_doc = [float(feature) for feature in features[1].split(',')]
         feature_names = ['query', 'doc_pos']
 
         yield zip(feature_names, [query] + [pos_doc])
diff --git a/models/match/dssm/synthetic_reader.py b/models/match/dssm/synthetic_reader.py
index d5641dc1..e358e04c 100755
--- a/models/match/dssm/synthetic_reader.py
+++ b/models/match/dssm/synthetic_reader.py
@@ -31,13 +31,15 @@ class Reader(ReaderBase):
         This function needs to be implemented by the user, based on data format
         """
         features = line.rstrip('\n').split('\t')
-        query = map(float, features[0].split(','))
-        pos_doc = map(float, features[1].split(','))
+        query = [float(feature) for feature in features[0].split(',')]
+        pos_doc = [float(feature) for feature in features[1].split(',')]
         feature_names = ['query', 'doc_pos']
         neg_docs = []
         for i in range(len(features) - 2):
             feature_names.append('doc_neg_' + str(i))
-            neg_docs.append(map(float, features[i + 2].split(',')))
+            neg_docs.append([
+                float(feature) for feature in features[i + 2].split(',')
+            ])
 
         yield zip(feature_names, [query] + [pos_doc] + neg_docs)
diff --git a/models/rank/afm/model.py b/models/rank/afm/model.py
index c4bb0ef9..2ca2d5db 100644
--- a/models/rank/afm/model.py
+++ b/models/rank/afm/model.py
@@ -133,7 +133,7 @@ class Model(ModelBase):
             attention_h)  # (batch_size * (num_field*(num_field-1)/2)) * 1
         attention_out = fluid.layers.softmax(
             attention_out)  # (batch_size * (num_field*(num_field-1)/2)) * 1
-        num_interactions = self.num_field * (self.num_field - 1) / 2
+        num_interactions = int(self.num_field * (self.num_field - 1) / 2)
         attention_out = fluid.layers.reshape(
             attention_out,
             shape=[-1, num_interactions,
diff --git a/models/rank/dnn/config.yaml b/models/rank/dnn/config.yaml
index 539cbb00..a5032970 100755
--- a/models/rank/dnn/config.yaml
+++ b/models/rank/dnn/config.yaml
@@ -63,7 +63,7 @@ runner:
   device: cpu
   save_checkpoint_interval: 2 # save model interval of epochs
   save_inference_interval: 4 # save inference
-  save_checkpoint_path: "increment" # save checkpoint path
+  save_checkpoint_path: "increment_dnn" # save checkpoint path
   save_inference_path: "inference" # save inference path
   save_inference_feed_varnames: [] # feed vars of save inference
   save_inference_fetch_varnames: [] # fetch vars of save inference
@@ -77,7 +77,7 @@ runner:
   epochs: 1
   # device to run training or infer
   device: cpu
-  init_model_path: "increment/0" # load model path
+  init_model_path: "increment_dnn" # load model path
   phases: [phase2]
 
 # runner will run all the phase in each epoch
diff --git a/models/recall/fasttext/config.yaml b/models/recall/fasttext/config.yaml
index 87da67f7..c9725ab8 100644
--- a/models/recall/fasttext/config.yaml
+++ b/models/recall/fasttext/config.yaml
@@ -11,23 +11,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-workspace: "paddlerec.models.recall.fasttext" +workspace: "paddlerec.models.recall.fasttext" # list of dataset dataset: -- name: dataset_train # name of dataset to distinguish different datasets - batch_size: 10 - type: DataLoader # or QueueDataset - data_path: "{workspace}/data/train" - word_count_dict_path: "{workspace}/data/dict/word_count_dict.txt" - word_ngrams_path: "{workspace}/data/dict/word_ngrams_id.txt" - data_converter: "{workspace}/reader.py" -- name: dataset_infer # name - batch_size: 10 - type: DataLoader # or QueueDataset - data_path: "{workspace}/data/test" - word_id_dict_path: "{workspace}/data/dict/word_id_dict.txt" - data_converter: "{workspace}/evaluate_reader.py" + - name: dataset_train # name of dataset to distinguish different datasets + batch_size: 10 + type: DataLoader # or QueueDataset + data_path: "{workspace}/data/train" + word_count_dict_path: "{workspace}/data/dict/word_count_dict.txt" + word_ngrams_path: "{workspace}/data/dict/word_ngrams_id.txt" + data_converter: "{workspace}/reader.py" + - name: dataset_infer # name + batch_size: 10 + type: DataLoader # or QueueDataset + data_path: "{workspace}/data/test" + word_id_dict_path: "{workspace}/data/dict/word_id_dict.txt" + data_converter: "{workspace}/evaluate_reader.py" hyper_parameters: optimizer: @@ -45,41 +45,41 @@ hyper_parameters: max_n: 5 # select runner by name -mode: train_runner +mode: [train_runner, infer_runner] # config of each runner. # runner is a kind of paddle training class, which wraps the train/infer process. runner: -- name: train_runner - class: train - # num of epochs - epochs: 2 - # device to run training or infer - device: cpu - save_checkpoint_interval: 1 # save model interval of epochs - save_inference_interval: 1 # save inference - save_checkpoint_path: "increment" # save checkpoint path - save_inference_path: "inference" # save inference path - save_inference_feed_varnames: [] # feed vars of save inference - save_inference_fetch_varnames: [] # fetch vars of save inference - init_model_path: "" # load model path - print_interval: 1 -- name: infer_runner - class: single_infer - # num of epochs - epochs: 1 - # device to run training or infer - device: cpu - init_model_path: "increment/0" # load model path - print_interval: 1 + - name: train_runner + class: train + # num of epochs + epochs: 2 + # device to run training or infer + device: cpu + save_checkpoint_interval: 1 # save model interval of epochs + save_inference_interval: 1 # save inference + save_checkpoint_path: "increment" # save checkpoint path + save_inference_path: "inference" # save inference path + save_inference_feed_varnames: [] # feed vars of save inference + save_inference_fetch_varnames: [] # fetch vars of save inference + init_model_path: "" # load model path + print_interval: 10 + phases: [phase1] + - name: infer_runner + class: infer + # device to run training or infer + device: cpu + init_model_path: "increment/0" # load model path + print_interval: 1 + phases: [phase2] # runner will run all the phase in each epoch phase: -- name: phase1 - model: "{workspace}/model.py" # user-defined model - dataset_name: dataset_train # select dataset by name - thread_num: 1 - gradient_scale_strategy: 1 -# - name: phase2 -# model: "{workspace}/model.py" # user-defined model -# dataset_name: dataset_infer # select dataset by name -# thread_num: 1 + - name: phase1 + model: "{workspace}/model.py" # user-defined model + dataset_name: dataset_train # select dataset by name + thread_num: 1 + gradient_scale_strategy: 1 + - name: 
+  - name: phase2
+    model: "{workspace}/model.py" # user-defined model
+    dataset_name: dataset_infer # select dataset by name
+    thread_num: 1
diff --git a/models/recall/gnn/evaluate_reader.py b/models/recall/gnn/evaluate_reader.py
index 864d0047..6ea0efc9 100755
--- a/models/recall/gnn/evaluate_reader.py
+++ b/models/recall/gnn/evaluate_reader.py
@@ -36,7 +36,8 @@ class Reader(ReaderBase):
                 for line in fin:
                     line = line.strip().split('\t')
                     res.append(
-                        tuple([map(int, line[0].split(',')), int(line[1])]))
+                        tuple([[int(l)
+                                for l in line[0].split(',')], int(line[1])]))
         return res
 
     def make_data(self, cur_batch, batch_size):
diff --git a/models/recall/gnn/reader.py b/models/recall/gnn/reader.py
index 5ea4cfee..9e382aea 100755
--- a/models/recall/gnn/reader.py
+++ b/models/recall/gnn/reader.py
@@ -35,7 +35,8 @@ class Reader(ReaderBase):
                 for line in fin:
                     line = line.strip().split('\t')
                     res.append(
-                        tuple([map(int, line[0].split(',')), int(line[1])]))
+                        tuple([[int(l)
+                                for l in line[0].split(',')], int(line[1])]))
         return res
 
     def make_data(self, cur_batch, batch_size):
diff --git a/run.py b/run.py
index 1abefd5f..699d48f9 100755
--- a/run.py
+++ b/run.py
@@ -17,6 +17,7 @@ import subprocess
 import sys
 import argparse
 import tempfile
+import warnings
 import copy
 
 from paddlerec.core.factory import TrainerFactory
@@ -220,6 +221,14 @@ def single_infer_engine(args):
     device_class = ".".join(["runner", mode, "device"])
     selected_gpus_class = ".".join(["runner", mode, "selected_gpus"])
 
+    epochs_class = ".".join(["runner", mode, "epochs"])
+    epochs = run_extras.get(epochs_class, 1)
+    if epochs > 1:
+        warnings.warn(
+            "It makes no sense to predict the same model for multiple epochs",
+            category=UserWarning,
+            stacklevel=2)
+
     trainer = run_extras.get(trainer_class, "GeneralTrainer")
     fleet_mode = run_extras.get(fleet_class, "ps")
     device = run_extras.get(device_class, "cpu")
@@ -372,9 +381,19 @@ def local_cluster_engine(args):
             envs.workspace_adapter_by_specific(path, workspace)
             for path in datapaths
         ]
+
         all_workers = [len(os.listdir(path)) for path in datapaths]
         all_workers.append(workers)
-        return min(all_workers)
+        max_worker_num = min(all_workers)
+
+        if max_worker_num >= workers:
+            return workers
+
+        print(
+            "phases do not have enough data for training, set worker/gpu card num from {} to {}".
+            format(workers, max_worker_num))
+
+        return max_worker_num
 
     from paddlerec.core.engine.local_cluster import LocalClusterEngine
 
@@ -397,24 +416,31 @@ def local_cluster_engine(args):
     worker_num = run_extras.get(worker_class, 1)
     server_num = run_extras.get(server_class, 1)
 
-    max_worker_num = get_worker_num(run_extras, worker_num)
-
-    if max_worker_num < worker_num:
-        print(
-            "has phase do not have enough datas for training, set worker num from {} to {}".
-            format(worker_num, max_worker_num))
-        worker_num = max_worker_num
 
     device = device.upper()
     fleet_mode = fleet_mode.upper()
 
+    cluster_envs = {}
+
+    # TODO: delete the following hard code when paddle supports ps-gpu.
+ if device == "CPU": + fleet_mode = "PS" + elif device == "GPU": + fleet_mode = "COLLECTIVE" + if fleet_mode == "PS" and device != "CPU": + raise ValueError("PS can not be used with GPU") + if fleet_mode == "COLLECTIVE" and device != "GPU": - raise ValueError("COLLECTIVE can not be used with GPU") + raise ValueError("COLLECTIVE can not be used without GPU") - cluster_envs = {} + if fleet_mode == "PS": + worker_num = get_worker_num(run_extras, worker_num) - if device == "GPU": + if fleet_mode == "COLLECTIVE": cluster_envs["selected_gpus"] = selected_gpus + gpus = selected_gpus.split(",") + gpu_num = get_worker_num(run_extras, len(gpus)) + cluster_envs["selected_gpus"] = ','.join(gpus[:gpu_num]) cluster_envs["server_num"] = server_num cluster_envs["worker_num"] = worker_num -- GitLab