ctr_modul_trainer.py 20.2 KB
Newer Older
T
tangwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

T
tangwei 已提交
15

X
xiexionghang 已提交
16 17 18 19
import sys
import time
import json
import datetime
T
tangwei 已提交
20 21
import numpy as np

X
xiexionghang 已提交
22 23
import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
24
from paddle.fluid.incubate.fleet.base.role_maker import GeneralRoleMaker
X
xiexionghang 已提交
25

T
tangwei 已提交
26

27 28 29 30 31 32
from paddlerec.core.utils import fs as fs
from paddlerec.core.utils import util as util
from paddlerec.core.metrics.auc_metrics import AUCMetric
from paddlerec.core.modules.modul import build as model_basic
from paddlerec.core.utils import dataset
from paddlerec.core.trainer import Trainer
T
tangwei 已提交
33

T
tangwei 已提交
34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74

def wroker_numric_opt(value, env, opt):
    """Reduce a scalar across all workers with the given operator.

    The local value is wrapped in a one-element numpy array and reduced
    through ``fleet._role_maker.all_reduce_worker``.

    Args:
        value: this worker's local number.
        env: accepted for interface compatibility (e.g. "mpi"/"gloo");
            it is not used by this function.
        opt: reduce operator name forwarded to ``all_reduce_worker``; the
            callers in this module pass "sum", "min" and "max".

    Returns:
        The reduced value (element 0 of the reduced array).
    """
    local_value = np.array([value])
    # Output buffer with the same shape/dtype as the input, filled by the
    # all-reduce call below.
    global_value = np.zeros_like(local_value)
    fleet._role_maker.all_reduce_worker(local_value, global_value, opt)
    return global_value[0]


def worker_numric_sum(value, env="mpi"):
    """Sum ``value`` over every worker.

    Args:
        value: this worker's local number.
        env: forwarded unchanged to ``wroker_numric_opt``.

    Returns:
        The sum of ``value`` across all workers.
    """
    reduce_op = "sum"
    return wroker_numric_opt(value, env, reduce_op)


def worker_numric_avg(value, env="mpi"):
    """Average ``value`` over every worker.

    Computed as the cross-worker sum divided by ``fleet.worker_num()``.

    Args:
        value: this worker's local number.
        env: forwarded unchanged to ``worker_numric_sum``.

    Returns:
        The mean of ``value`` across all workers.
    """
    total = worker_numric_sum(value, env)
    return total / fleet.worker_num()


def worker_numric_min(value, env="mpi"):
    """Return the minimum of ``value`` over every worker.

    Args:
        value: this worker's local number.
        env: forwarded unchanged to ``wroker_numric_opt``.
    """
    return wroker_numric_opt(value, env=env, opt="min")


def worker_numric_max(value, env="mpi"):
    """Return the maximum of ``value`` over every worker.

    Args:
        value: this worker's local number.
        env: forwarded unchanged to ``wroker_numric_opt``.
    """
    return wroker_numric_opt(value, env=env, opt="max")


T
tangwei 已提交
75
class CtrPaddleTrainer(Trainer):
    """CTR trainer on PSLib fleet, driven by a day/pass state machine.

    ``__init__`` registers one handler per ``context['status']`` value:
    ``uninit`` -> ``init``, ``startup`` -> ``startup``, ``begin_day`` ->
    ``begin_day``, ``train_pass`` -> ``train_pass``, ``end_day`` -> ``end_day``.
    Each handler writes the next status (or ``context['is_exit']``).
    """

    def __init__(self, config):
        """Store the config, build output path templates and register handlers.

        Args:
            config: global trainer config dict. NOTE: ``config['output_path']``
                is rewritten in place to the result of
                ``util.get_absolute_path(output_path, config['io']['afs'])``.
        """
        Trainer.__init__(self, config)
        config['output_path'] = util.get_absolute_path(
            config['output_path'], config['io']['afs'])

        self.global_config = config
        self._metrics = {}

        # Output locations for xbox base/delta models, their done files and
        # batch-model checkpoints. Extra templates may be supplied through
        # config['path_generator'].
        self._path_generator = util.PathGenerator({
            'templates': [
                {'name': 'xbox_base_done', 'template': config['output_path'] + '/xbox_base_done.txt'},
                {'name': 'xbox_delta_done', 'template': config['output_path'] + '/xbox_patch_done.txt'},
                {'name': 'xbox_base', 'template': config['output_path'] + '/xbox/{day}/base/'},
                {'name': 'xbox_delta', 'template': config['output_path'] + '/xbox/{day}/delta-{pass_id}/'},
                {'name': 'batch_model', 'template': config['output_path'] + '/batch_model/{day}/{pass_id}/'}
            ]
        })
        if 'path_generator' in config:
            self._path_generator.add_path_template(config['path_generator'])

        self.regist_context_processor('uninit', self.init)
        self.regist_context_processor('startup', self.startup)
        self.regist_context_processor('begin_day', self.begin_day)
        self.regist_context_processor('train_pass', self.train_pass)
        self.regist_context_processor('end_day', self.end_day)

    def init(self, context):
        """Handle ``uninit``: set up fleet, models, optimizer and datasets.

        Initialises fleet. When ``process_mode`` is ``brilliant_cpu`` it uses a
        GeneralRoleMaker rooted at ``<output_path>/gloo``. It then creates one
        scope and model per configured executor and minimises every cost op
        with a single optimizer. Each startup program is run in its scope.
        Parameter servers return at that point. Workers also build one
        FluidTimeSplitDataset per ``dataset.data_list`` entry and call
        ``fleet.init_worker()``.

        Args:
            context: state-machine context; ``status`` is set to ``startup``.

        Returns:
            0 on parameter servers, None on workers.
        """
        role_maker = None
        if self.global_config.get('process_mode', 'mpi') == 'brilliant_cpu':
            afs_config = self.global_config['io']['afs']
            role_maker = GeneralRoleMaker(
                hdfs_name=afs_config['fs_name'], hdfs_ugi=afs_config['fs_ugi'],
                path=self.global_config['output_path'] + "/gloo",
                init_timeout_seconds=1200, run_timeout_seconds=1200)
        fleet.init(role_maker)
        data_var_list = []
        data_var_name_dict = {}
        runnable_scope = []
        runnable_cost_op = []
        context['status'] = 'startup'

        for executor in self.global_config['executor']:
            scope = fluid.Scope()
            self._exector_context[executor['name']] = {}
            self._exector_context[executor['name']]['scope'] = scope
            self._exector_context[executor['name']]['model'] = model_basic.create(executor)
            model = self._exector_context[executor['name']]['model']
            self._metrics.update(model.get_metrics())
            runnable_scope.append(scope)
            runnable_cost_op.append(model.get_cost_op())
            # Collect data vars across executors, keeping the first var seen
            # for each name so the datasets get a de-duplicated list.
            for var in model._data_var:
                if var.name in data_var_name_dict:
                    continue
                data_var_list.append(var)
                data_var_name_dict[var.name] = var

        optimizer = model_basic.YamlModel.build_optimizer({
            'metrics': self._metrics,
            'optimizer_conf': self.global_config['optimizer']
        })
        optimizer.minimize(runnable_cost_op, runnable_scope)
        for executor in self.global_config['executor']:
            scope = self._exector_context[executor['name']]['scope']
            model = self._exector_context[executor['name']]['model']
            program = model._build_param['model']['train_program']
            if not executor['is_update_sparse']:
                # Clear this program's sparse push list for executors that
                # must not update sparse tables.
                program._fleet_opt["program_configs"][str(id(model.get_cost_op().block.program))]["push_sparse"] = []
            if 'train_thread_num' not in executor:
                executor['train_thread_num'] = self.global_config['train_thread_num']
            with fluid.scope_guard(scope):
                self._exe.run(model._build_param['model']['startup_program'])
            model.dump_model_program('./')

        # server init done
        if fleet.is_server():
            return 0

        self._dataset = {}
        for dataset_item in self.global_config['dataset']['data_list']:
            dataset_item['data_vars'] = data_var_list
            dataset_item.update(self.global_config['io']['afs'])
            dataset_item["batch_size"] = self.global_config['batch_size']
            self._dataset[dataset_item['name']] = dataset.FluidTimeSplitDataset(dataset_item)
        # if config.need_reqi_changeslot and config.reqi_dnn_plugin_day >= last_day and config.reqi_dnn_plugin_pass >= last_pass:
        #    util.reqi_changeslot(config.hdfs_dnn_plugin_path, join_save_params, common_save_params, update_save_params, scope2, scope3)
        fleet.init_worker()

    def print_log(self, log_str, params):
        """Print ``log_str`` according to ``params``.

        Sets ``params['index']`` to this worker's index. When
        ``params['master']`` is truthy, only worker 0 prints, and it then
        flushes stdout. Otherwise every worker prints. If ``params`` has a
        ``stdout`` entry, a timestamped copy of the line is appended to that
        dict entry. The caller's own string is not modified.

        Args:
            log_str: message to print.
            params: dict with at least ``master``; optionally ``stdout``.
        """
        params['index'] = fleet.worker_index()
        if params['master']:
            if fleet.worker_index() == 0:
                print(log_str)
                sys.stdout.flush()
        else:
            print(log_str)
        if 'stdout' in params:
            params['stdout'] += str(datetime.datetime.now()) + log_str

    def print_global_metrics(self, scope, model, monitor_data, stdout_str):
        """Compute, log and reset every metric exposed by ``model``.

        Args:
            scope: fluid scope holding the metric variables.
            model: model whose ``get_metrics()`` dict is evaluated.
            monitor_data: string that each metric result is appended to.
            stdout_str: forwarded to ``print_log`` as its ``stdout`` entry.

        Returns:
            ``monitor_data`` with every metric result string appended.
            Python strings are immutable, so callers must use this return
            value; the argument they passed in is not modified.
        """
        metrics = model.get_metrics()
        metric_calculator = AUCMetric(None)
        for metric in metrics:
            metric_param = {'label': metric, 'metric_dict': metrics[metric]}
            metric_calculator.calculate(scope, metric_param)
            metric_result = metric_calculator.get_result_to_string()
            self.print_log(metric_result, {'master': True, 'stdout': stdout_str})
            monitor_data += metric_result
            metric_calculator.clear(scope, metric_param)
        return monitor_data

    def save_model(self, day, pass_index, base_key):
        """Save all persistables as a batch-model checkpoint.

        Uses save mode 3 when ``pass_index < 1`` and 0 otherwise. The first
        worker records the train progress as a checkpoint.

        Args:
            day: day string for the ``batch_model`` path template.
            pass_index: pass number for the path template.
            base_key: base key recorded with the train progress.

        Returns:
            The generated model path.
        """
        cost_printer = util.CostPrinter(util.print_cost,
                                        {'master': True, 'log_format': 'save model cost %s sec'})
        model_path = self._path_generator.generate_path('batch_model', {'day': day, 'pass_id': pass_index})
        save_mode = 0  # just save all
        if pass_index < 1:  # batch_model
            save_mode = 3  # unseen_day++, save all
        util.rank0_print("going to save_model %s" % model_path)
        fleet.save_persistables(None, model_path, mode=save_mode)
        if fleet._role_maker.is_first_worker():
            self._train_pass.save_train_progress(day, pass_index, base_key, model_path, is_checkpoint=True)
        cost_printer.done()
        return model_path

    def save_xbox_model(self, day, pass_index, xbox_base_key, monitor_data):
        """Save an xbox base model (pass_index < 1) or a delta model.

        Writes the model with ``fleet.save_persistables`` (mode 2 for base,
        1 for delta). If ``save_cache_model`` is set, it also writes the
        sparse cache model and its meta file. The first worker dumps the
        inference dense params of each executor that defines
        ``layer_for_inference`` and appends a JSON record to the matching
        done file. Workers synchronise with barriers between stages.

        Args:
            day: day string used in the path templates.
            pass_index: pass number; values < 1 select the base model.
            xbox_base_key: key of the current base model; also used as the
                patch id when saving a base model.
            monitor_data: string stored under ``monitor_data`` in the done file.

        Returns:
            The local ``stdout_str`` buffer (initialised to "" here).
        """
        stdout_str = ""
        xbox_patch_id = str(int(time.time()))
        util.rank0_print("begin save delta model")

        cost_printer = util.CostPrinter(util.print_cost, {'master': True,
                                                          'log_format': 'save xbox model cost %s sec',
                                                          'stdout': stdout_str})
        if pass_index < 1:
            save_mode = 2
            xbox_patch_id = xbox_base_key
            model_path = self._path_generator.generate_path('xbox_base', {'day': day})
            xbox_model_donefile = self._path_generator.generate_path('xbox_base_done', {'day': day})
        else:
            save_mode = 1
            model_path = self._path_generator.generate_path('xbox_delta', {'day': day, 'pass_id': pass_index})
            xbox_model_donefile = self._path_generator.generate_path('xbox_delta_done', {'day': day})
        fleet.save_persistables(None, model_path, mode=save_mode)
        cost_printer.done()

        cost_printer = util.CostPrinter(util.print_cost, {'master': True,
                                                          'log_format': 'save cache model cost %s sec',
                                                          'stdout': stdout_str})
        model_file_handler = fs.FileHandler(self.global_config['io']['afs'])
        save_cache_model = self.global_config['save_cache_model']
        if save_cache_model:
            cache_save_num = fleet.save_cache_model(None, model_path, mode=save_mode)
            model_file_handler.write(
                "file_prefix:part\npart_num:16\nkey_num:%d\n" % cache_save_num,
                model_path + '/000_cache/sparse_cache.meta', 'w')
        cost_printer.done()
        if save_cache_model:
            # cache_save_num only exists when the cache was saved; logging it
            # unconditionally raised UnboundLocalError with the cache disabled.
            util.rank0_print("save xbox cache model done, key_num=%s" % cache_save_num)

        save_env_param = {
            'executor': self._exe,
            'save_combine': True
        }
        cost_printer = util.CostPrinter(util.print_cost, {'master': True,
                                                          'log_format': 'save dense model cost %s sec',
                                                          'stdout': stdout_str})
        if fleet._role_maker.is_first_worker():
            for executor in self.global_config['executor']:
                if 'layer_for_inference' not in executor:
                    continue
                executor_name = executor['name']
                model = self._exector_context[executor_name]['model']
                save_env_param['inference_list'] = executor['layer_for_inference']
                save_env_param['scope'] = self._exector_context[executor_name]['scope']
                model.dump_inference_param(save_env_param)
                for dnn_layer in executor['layer_for_inference']:
                    model_file_handler.cp(dnn_layer['save_file_name'],
                                          model_path + '/dnn_plugin/' + dnn_layer['save_file_name'])
        fleet._role_maker._barrier_worker()
        cost_printer.done()

        xbox_done_info = {
            "id": xbox_patch_id,
            "key": xbox_base_key,
            "ins_path": "",
            "ins_tag": "feasign",
            "partition_type": "2",
            "record_count": "111111",
            "monitor_data": monitor_data,
            "mpi_size": str(fleet.worker_num()),
            "input": model_path.rstrip("/") + "/000",
            "job_id": util.get_env_value("JOB_ID"),
            "job_name": util.get_env_value("JOB_NAME")
        }
        if fleet._role_maker.is_first_worker():
            model_file_handler.write(json.dumps(xbox_done_info) + "\n", xbox_model_donefile, 'a')
            if pass_index > 0:
                self._train_pass.save_train_progress(day, pass_index, xbox_base_key, model_path, is_checkpoint=False)
        fleet._role_maker._barrier_worker()
        return stdout_str

    def run_executor(self, executor_config, dataset, stdout_str):
        """Train one executor on ``dataset`` for the current pass.

        Runs the executor's train program in its own scope. It logs the
        average, min and max training time in minutes across workers, stores
        the max under the executor's ``cost`` entry, and prints the model
        metrics. When ``need_dump_inference(pass_id)`` and
        ``executor_config['dump_inference_model']`` are both true, it also
        saves an xbox model using the collected metric results as monitor
        data. Ends with a worker barrier.

        Args:
            executor_config: one entry of ``global_config['executor']``.
            dataset: loaded dataset to train on (shadows the module name
                locally).
            stdout_str: log buffer forwarded to the metric printer.
        """
        day = self._train_pass.date()
        pass_id = self._train_pass._pass_id
        xbox_base_key = self._train_pass._base_key
        executor_name = executor_config['name']
        scope = self._exector_context[executor_name]['scope']
        model = self._exector_context[executor_name]['model']
        with fluid.scope_guard(scope):
            util.rank0_print("Begin " + executor_name + " pass")
            begin = time.time()
            program = model._build_param['model']['train_program']
            self._exe.train_from_dataset(program, dataset, scope,
                                         thread=executor_config['train_thread_num'], debug=self.global_config['debug'])
            end = time.time()
            local_cost = (end - begin) / 60.0
            avg_cost = worker_numric_avg(local_cost)
            min_cost = worker_numric_min(local_cost)
            max_cost = worker_numric_max(local_cost)
            util.rank0_print("avg train time %s mins, min %s mins, max %s mins" % (avg_cost, min_cost, max_cost))
            self._exector_context[executor_name]['cost'] = max_cost

            # Use the returned string: the metric results are what ends up in
            # the xbox done file's monitor_data field.
            monitor_data = self.print_global_metrics(scope, model, "", stdout_str)
            util.rank0_print("End " + executor_name + " pass")
            if self._train_pass.need_dump_inference(pass_id) and executor_config['dump_inference_model']:
                stdout_str += self.save_xbox_model(day, pass_id, xbox_base_key, monitor_data)
        fleet._role_maker._barrier_worker()

    def startup(self, context):
        """Handle ``startup``: run servers, or prepare workers for training.

        Servers call ``fleet.run_server()`` and set status ``wait``. Workers
        create the TimeTrainPass. Unless ``cold_start`` is set, they load the
        checkpoint with ``fleet.init_server(path, mode=0)``. If
        ``save_first_base`` is set, they save an xbox base model under a new
        timestamp key. Finally they set status ``begin_day``.

        Args:
            context: state-machine context.
        """
        if fleet.is_server():
            fleet.run_server()
            context['status'] = 'wait'
            return
        stdout_str = ""
        self._train_pass = util.TimeTrainPass(self.global_config)
        if not self.global_config['cold_start']:
            cost_printer = util.CostPrinter(util.print_cost,
                                            {'master': True, 'log_format': 'load model cost %s sec',
                                             'stdout': stdout_str})
            self.print_log("going to load model %s" % self._train_pass._checkpoint_model_path, {'master': True})
            # if config.need_reqi_changeslot and config.reqi_dnn_plugin_day >= self._train_pass.date()
            #    and config.reqi_dnn_plugin_pass >= self._pass_id:
            #    fleet.load_one_table(0, self._train_pass._checkpoint_model_path)
            # else:
            fleet.init_server(self._train_pass._checkpoint_model_path, mode=0)
            cost_printer.done()
        if self.global_config['save_first_base']:
            self.print_log("save_first_base=True", {'master': True})
            self.print_log("going to save xbox base model", {'master': True, 'stdout': stdout_str})
            self._train_pass._base_key = int(time.time())
            stdout_str += self.save_xbox_model(self._train_pass.date(), 0, self._train_pass._base_key, "")
        context['status'] = 'begin_day'

    def begin_day(self, context):
        """Handle ``begin_day``: advance the pass and route to the next state.

        Sets ``context['is_exit']`` when ``_train_pass.next()`` returns a
        falsy value. Logs the day. Sets status to ``end_day`` when the
        current pass id equals ``max_pass_num_day()``, otherwise to
        ``train_pass``.

        Args:
            context: state-machine context.
        """
        stdout_str = ""
        if not self._train_pass.next():
            context['is_exit'] = True
        day = self._train_pass.date()
        pass_id = self._train_pass._pass_id
        self.print_log("======== BEGIN DAY:%s ========" % day, {'master': True, 'stdout': stdout_str})
        if pass_id == self._train_pass.max_pass_num_day():
            context['status'] = 'end_day'
        else:
            context['status'] = 'train_pass'

    def end_day(self, context):
        """Handle ``end_day``: shrink tables and save next-day base models.

        Shrinks the sparse table and each model's dense params (using
        ``optimizer.dense_decay_rate``). It then saves an xbox base model and
        a batch model for the next day under a fresh timestamp base key,
        which becomes ``_train_pass._base_key``. Status is set back to
        ``begin_day``.

        Args:
            context: state-machine context.
        """
        xbox_base_key = int(time.time())
        context['status'] = 'begin_day'

        util.rank0_print("shrink table")
        cost_printer = util.CostPrinter(util.print_cost,
                                        {'master': True, 'log_format': 'shrink table done, cost %s sec'})
        fleet.shrink_sparse_table()
        for executor in self._exector_context:
            self._exector_context[executor]['model'].shrink({
                'scope': self._exector_context[executor]['scope'],
                'decay': self.global_config['optimizer']['dense_decay_rate']
            })
        cost_printer.done()

        next_date = self._train_pass.date(delta_day=1)
        util.rank0_print("going to save xbox base model")
        self.save_xbox_model(next_date, 0, xbox_base_key, "")
        util.rank0_print("going to save batch model")
        self.save_model(next_date, 0, xbox_base_key)
        self._train_pass._base_key = xbox_base_key
        fleet._role_maker._barrier_worker()

    def train_pass(self, context):
        """Handle ``train_pass``: load, shuffle and train one time window.

        Loads this pass's time window for every dataset and shuffles it
        globally. If ``prefetch_data`` is set, it preloads the next window.
        It then runs every executor, releases the data, and saves a batch
        model on checkpoint passes. Afterwards it sets status ``end_day``
        after the day's last pass, or ``is_exit`` when
        ``_train_pass.next()`` returns a falsy value.

        Args:
            context: state-machine context.
        """
        stdout_str = ""
        day = self._train_pass.date()
        pass_id = self._train_pass._pass_id
        base_key = self._train_pass._base_key
        pass_time = self._train_pass._current_train_time.strftime("%Y%m%d%H%M")
        self.print_log("    ==== begin delta:%s ========" % pass_id, {'master': True, 'stdout': stdout_str})
        train_begin_time = time.time()

        cost_printer = util.CostPrinter(util.print_cost,
                                        {'master': True, 'log_format': 'load into memory done, cost %s sec',
                                         'stdout': stdout_str})
        current_dataset = {}
        for name in self._dataset:
            current_dataset[name] = self._dataset[name].load_dataset({
                'node_num': fleet.worker_num(), 'node_idx': fleet.worker_index(),
                'begin_time': pass_time, 'time_window_min': self._train_pass._interval_per_pass
            })
        fleet._role_maker._barrier_worker()
        cost_printer.done()

        util.rank0_print("going to global shuffle")
        cost_printer = util.CostPrinter(util.print_cost, {
            'master': True, 'stdout': stdout_str,
            'log_format': 'global shuffle done, cost %s sec'})
        for name in current_dataset:
            current_dataset[name].global_shuffle(fleet, self.global_config['dataset']['shuffle_thread'])
        cost_printer.done()
        # str(dataset.get_shuffle_data_size(fleet))
        fleet._role_maker._barrier_worker()

        if self.global_config['prefetch_data']:
            # Preload the window that starts one pass interval later.
            next_pass_time = (self._train_pass._current_train_time +
                              datetime.timedelta(minutes=self._train_pass._interval_per_pass)).strftime("%Y%m%d%H%M")
            for name in self._dataset:
                self._dataset[name].preload_dataset({
                    'node_num': fleet.worker_num(), 'node_idx': fleet.worker_index(),
                    'begin_time': next_pass_time, 'time_window_min': self._train_pass._interval_per_pass
                })

        fleet._role_maker._barrier_worker()
        pure_train_begin = time.time()
        for executor in self.global_config['executor']:
            self.run_executor(executor, current_dataset[executor['dataset_name']], stdout_str)
        cost_printer = util.CostPrinter(util.print_cost,
                                        {'master': True, 'log_format': 'release_memory cost %s sec'})
        for name in current_dataset:
            current_dataset[name].release_memory()
        # Close this timer like every other CostPrinter here; it was
        # previously never finished, so the release_memory cost went unreported.
        cost_printer.done()
        pure_train_cost = time.time() - pure_train_begin

        if self._train_pass.is_checkpoint_pass(pass_id):
            self.save_model(day, pass_id, base_key)

        train_end_time = time.time()
        train_cost = train_end_time - train_begin_time
        other_cost = train_cost - pure_train_cost
        log_str = "finished train day %s pass %s time cost:%s sec job time cost:" % (day, pass_id, train_cost)
        for executor in self._exector_context:
            log_str += '[' + executor + ':' + str(self._exector_context[executor]['cost']) + ']'
        log_str += '[other_cost:' + str(other_cost) + ']'
        util.rank0_print(log_str)
        stdout_str += util.now_time_str() + log_str
        sys.stdout.write(stdout_str)
        fleet._role_maker._barrier_worker()
        if pass_id == self._train_pass.max_pass_num_day():
            context['status'] = 'end_day'
            return
        elif not self._train_pass.next():
            context['is_exit'] = True