From abb108dcebcf437df2fbd874863bd5e17ff417f6 Mon Sep 17 00:00:00 2001 From: Xiaoyao Xi <24541791+xixiaoyao@users.noreply.github.com> Date: Sun, 29 Mar 2020 23:52:41 +0800 Subject: [PATCH] Update base_backbone.py --- paddlepalm/backbone/base_backbone.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/paddlepalm/backbone/base_backbone.py b/paddlepalm/backbone/base_backbone.py index 9a8f79f..38c604b 100644 --- a/paddlepalm/backbone/base_backbone.py +++ b/paddlepalm/backbone/base_backbone.py @@ -12,23 +12,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""v1.1""" class Backbone(object): """interface of backbone model.""" - def __init__(self, config, phase): - """ + def __init__(self, phase): + """该函数完成一个主干网络的构造,至少需要包含一个phase参数。 + 注意:实现该构造函数时,必须保证对基类构造函数的调用,以创建必要的框架内建的成员变量。 Args: - config: dict类型。描述了 多任务配置文件+预训练模型配置文件 中定义超参数 - phase: str类型。运行阶段,目前支持train和predict + phase: str类型。用于区分主干网络被调用时所处的运行阶段,目前支持训练阶段train和预测阶段predict """ + assert isinstance(config, dict) @property def inputs_attr(self): - """描述backbone从reader处需要得到的输入对象的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。 + """描述backbone从reader处需要得到的输入对象的属性,包含各个对象的名字、shape以及数据类型。当某个对象 + 为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape + 中的相应维度设置为-1。 + Return: dict类型。对各个输入对象的属性描述。例如, 对于文本分类和匹配任务,bert backbone依赖的reader对象主要包含如下的对象 @@ -40,7 +43,9 @@ class Backbone(object): @property def outputs_attr(self): - """描述backbone输出对象的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。 + """描述backbone输出对象的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据类型(如 + str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。 + Return: dict类型。对各个输出对象的属性描述。例如, 对于文本分类和匹配任务,bert backbone的输出内容可能包含如下的对象 @@ -57,4 +62,3 @@ class Backbone(object): 需要输出的计算图变量,输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。 """ raise NotImplementedError() - -- GitLab