未验证 提交 d91c9c63 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Update multihead_trainer.py

上级 b88394ac
from paddle import fluid
from paddle.fluid import layers
from paddlepalm.distribute import gpu_dev_count, cpu_dev_count
from paddlepalm.distribute import gpu_dev_count, cpu_dev_count, data_feeder, decode_fake
from paddlepalm import Trainer
from paddlepalm.utils import reader_helper
import numpy as np
from paddlepalm.distribute import gpu_dev_count, data_feeder, decode_fake
import time
dev_count = 1 if gpu_dev_count <= 1 else gpu_dev_count
......@@ -61,10 +60,6 @@ class MultiHeadTrainer(Trainer):
"""
Build forward computation graph for training, which usually built from input layer to loss node.
Args:
backbone: a Backbone object with phase == 'train', which is used to extract multi-level text features, e.g., contextual word embedding and sentence embedding.
heads: a list of Head objects. Phase of each head should be set as 'train', which is used to build task specific output layers.
Return:
- loss_var: a Variable object. The computational graph variable(node) of loss.
"""
......@@ -115,10 +110,14 @@ class MultiHeadTrainer(Trainer):
def fit_readers_with_mixratio(self, readers, sampling_reference, num_epochs, phase='train'):
"""
Bind readers and loaded train/predict data to trainers.
Bind readers and loaded train/predict data to trainers. The `num_epochs` argument only
works on `sampling_reference` task(trainer), and num_epochs of other tasks are infered from
their `mix_ratio`.
Args:
readers: a dict or list of Reader objects. For dict case, each key is a trainer's name, and the mapped value is the reader object to bind to the trainer. For list case, each
sampling_reference: a trainer name. The task(trainer) selected as baseline for task sampling.
num_epochs: training epochs of the sampling_reference task (trainer).
"""
self._check_phase(phase)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册