提交 62b3340c 编写于 作者: L leiyuning 提交者: Gitee

!3 Add example for bert

Merge pull request !3 from c_34/master
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bert Init."""
from .bert_for_pre_training import BertNetworkWithLoss, BertPreTraining, \
BertPretrainingLoss, GetMaskedLMOutput, GetNextSentenceOutput, \
BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \
BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \
EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \
SaturateCast, CreateAttentionMaskFromInputMask
__all__ = [
"BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss",
"GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell", "BertTrainOneStepWithLossScaleCell",
"BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput",
"BertSelfAttention", "BertTransformer", "EmbeddingLookup",
"EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator",
"RelaPosMatrixGenerator", "SaturateCast", "CreateAttentionMaskFromInputMask"
]
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bert for pretraining."""
import numpy as np
import mindspore.nn as nn
from mindspore.common.initializer import initializer, TruncatedNormal
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.communication.management import get_group_size
from mindspore import context
from .bert_model import BertModel
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
class ClipGradients(nn.Cell):
"""
Clip gradients.
Inputs:
grads (tuple[Tensor]): Gradients.
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
clip_value (float): Specifies how much to clip.
Outputs:
tuple[Tensor], clipped gradients.
"""
def __init__(self):
super(ClipGradients, self).__init__()
self.clip_by_norm = nn.ClipByNorm()
self.cast = P.Cast()
self.dtype = P.DType()
def construct(self,
grads,
clip_type,
clip_value):
if clip_type != 0 and clip_type != 1:
return grads
new_grads = ()
for grad in grads:
dt = self.dtype(grad)
if clip_type == 0:
t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt),
self.cast(F.tuple_to_array((clip_value,)), dt))
else:
t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt))
new_grads = new_grads + (t,)
return new_grads
class GetMaskedLMOutput(nn.Cell):
"""
Get masked lm output.
Args:
config (BertConfig): The config of BertModel.
Returns:
Tensor, masked lm output.
"""
def __init__(self, config):
super(GetMaskedLMOutput, self).__init__()
self.width = config.hidden_size
self.reshape = P.Reshape()
self.gather = P.GatherV2()
weight_init = TruncatedNormal(config.initializer_range)
self.dense = nn.Dense(self.width,
config.hidden_size,
weight_init=weight_init,
activation=config.hidden_act).to_float(config.compute_type)
self.layernorm = nn.LayerNorm(config.hidden_size).to_float(config.compute_type)
self.output_bias = Parameter(
initializer(
'zero',
config.vocab_size),
name='output_bias')
self.matmul = P.MatMul(transpose_b=True)
self.log_softmax = nn.LogSoftmax(axis=-1)
self.shape_flat_offsets = (-1, 1)
self.rng = Tensor(np.array(range(0, config.batch_size)).astype(np.int32))
self.last_idx = (-1,)
self.shape_flat_sequence_tensor = (config.batch_size * config.seq_length, self.width)
self.seq_length_tensor = Tensor(np.array((config.seq_length,)).astype(np.int32))
self.cast = P.Cast()
self.compute_type = config.compute_type
self.dtype = config.dtype
def construct(self,
input_tensor,
output_weights,
positions):
flat_offsets = self.reshape(
self.rng * self.seq_length_tensor, self.shape_flat_offsets)
flat_position = self.reshape(positions + flat_offsets, self.last_idx)
flat_sequence_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor)
input_tensor = self.gather(flat_sequence_tensor, flat_position, 0)
input_tensor = self.cast(input_tensor, self.compute_type)
output_weights = self.cast(output_weights, self.compute_type)
input_tensor = self.dense(input_tensor)
input_tensor = self.layernorm(input_tensor)
logits = self.matmul(input_tensor, output_weights)
logits = self.cast(logits, self.dtype)
logits = logits + self.output_bias
log_probs = self.log_softmax(logits)
return log_probs
class GetNextSentenceOutput(nn.Cell):
"""
Get next sentence output.
Args:
config (BertConfig): The config of Bert.
Returns:
Tensor, next sentence output.
"""
def __init__(self, config):
super(GetNextSentenceOutput, self).__init__()
self.log_softmax = P.LogSoftmax()
self.weight_init = TruncatedNormal(config.initializer_range)
self.dense = nn.Dense(config.hidden_size, 2,
weight_init=self.weight_init, has_bias=True).to_float(config.compute_type)
self.dtype = config.dtype
self.cast = P.Cast()
def construct(self, input_tensor):
logits = self.dense(input_tensor)
logits = self.cast(logits, self.dtype)
log_prob = self.log_softmax(logits)
return log_prob
class BertPreTraining(nn.Cell):
"""
Bert pretraining network.
Args:
config (BertConfig): The config of BertModel.
is_training (bool): Specifies whether to use the training mode.
use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings.
Returns:
Tensor, prediction_scores, seq_relationship_score.
"""
def __init__(self, config, is_training, use_one_hot_embeddings):
super(BertPreTraining, self).__init__()
self.bert = BertModel(config, is_training, use_one_hot_embeddings)
self.cls1 = GetMaskedLMOutput(config)
self.cls2 = GetNextSentenceOutput(config)
def construct(self, input_ids, input_mask, token_type_id,
masked_lm_positions):
sequence_output, pooled_output, embedding_table = \
self.bert(input_ids, token_type_id, input_mask)
prediction_scores = self.cls1(sequence_output,
embedding_table,
masked_lm_positions)
seq_relationship_score = self.cls2(pooled_output)
return prediction_scores, seq_relationship_score
class BertPretrainingLoss(nn.Cell):
"""
Provide bert pre-training loss.
Args:
config (BertConfig): The config of BertModel.
Returns:
Tensor, total loss.
"""
def __init__(self, config):
super(BertPretrainingLoss, self).__init__()
self.vocab_size = config.vocab_size
self.onehot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.reduce_sum = P.ReduceSum()
self.reduce_mean = P.ReduceMean()
self.reshape = P.Reshape()
self.last_idx = (-1,)
self.neg = P.Neg()
self.cast = P.Cast()
def construct(self, prediction_scores, seq_relationship_score, masked_lm_ids,
masked_lm_weights, next_sentence_labels):
"""Defines the computation performed."""
label_ids = self.reshape(masked_lm_ids, self.last_idx)
label_weights = self.cast(self.reshape(masked_lm_weights, self.last_idx), mstype.float32)
one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value)
per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx))
numerator = self.reduce_sum(label_weights * per_example_loss, ())
denominator = self.reduce_sum(label_weights, ()) + self.cast(F.tuple_to_array((1e-5,)), mstype.float32)
masked_lm_loss = numerator / denominator
# next_sentence_loss
labels = self.reshape(next_sentence_labels, self.last_idx)
one_hot_labels = self.onehot(labels, 2, self.on_value, self.off_value)
per_example_loss = self.neg(self.reduce_sum(
one_hot_labels * seq_relationship_score, self.last_idx))
next_sentence_loss = self.reduce_mean(per_example_loss, self.last_idx)
# total_loss
total_loss = masked_lm_loss + next_sentence_loss
return total_loss
class BertNetworkWithLoss(nn.Cell):
"""
Provide bert pre-training loss through network.
Args:
config (BertConfig): The config of BertModel.
is_training (bool): Specifies whether to use the training mode.
use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False.
Returns:
Tensor, the loss of the network.
"""
def __init__(self, config, is_training, use_one_hot_embeddings=False):
super(BertNetworkWithLoss, self).__init__()
self.bert = BertPreTraining(config, is_training, use_one_hot_embeddings)
self.loss = BertPretrainingLoss(config)
self.cast = P.Cast()
def construct(self,
input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights):
prediction_scores, seq_relationship_score = \
self.bert(input_ids, input_mask, token_type_id, masked_lm_positions)
total_loss = self.loss(prediction_scores, seq_relationship_score,
masked_lm_ids, masked_lm_weights, next_sentence_labels)
return self.cast(total_loss, mstype.float32)
class BertTrainOneStepCell(nn.Cell):
"""
Encapsulation class of bert network training.
Append an optimizer to the training network after that the construct
function can be called to create the backward graph.
Args:
network (Cell): The training network. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
sens (Number): The adjust parameter. Default: 1.0.
"""
def __init__(self, network, optimizer, sens=1.0):
super(BertTrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = None
if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean")
degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.clip_gradients = ClipGradients()
self.cast = P.Cast()
def set_sens(self, value):
self.sens = value
def construct(self,
input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights):
"""Defines the computation performed."""
weights = self.weights
loss = self.network(input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights)
grads = self.grad(self.network, weights)(input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights,
self.cast(F.tuple_to_array((self.sens,)),
mstype.float32))
grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
succ = self.optimizer(grads)
return F.depend(loss, succ)
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
return grad * reciprocal(scale)
class BertTrainOneStepWithLossScaleCell(nn.Cell):
"""
Encapsulation class of bert network training.
Append an optimizer to the training network after that the construct
function can be called to create the backward graph.
Args:
network (Cell): The training network. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
"""
def __init__(self, network, optimizer, scale_update_cell=None):
super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation('grad',
get_by_list=True,
sens_param=True)
self.reducer_flag = False
self.allreduce = P.AllReduce()
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = None
if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean")
degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.clip_gradients = ClipGradients()
self.cast = P.Cast()
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_before_grad = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
self.base = Tensor(1, mstype.float32)
self.less_equal = P.LessEqual()
self.hyper_map = C.HyperMap()
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale")
self.add_flags(has_effect=True)
def construct(self,
input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights,
sens=None):
"""Defines the computation performed."""
weights = self.weights
loss = self.network(input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights)
if sens is None:
scaling_sens = self.loss_scale
else:
scaling_sens = sens
# alloc status and clear should be right before gradoperation
init = self.alloc_status()
self.clear_before_grad(init)
grads = self.grad(self.network, weights)(input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights,
self.cast(scaling_sens,
mstype.float32))
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
self.get_status(init)
flag_sum = self.reduce_sum(init, (0,))
if self.is_distributed:
# sum overflow flag over devices
flag_reduce = self.allreduce(flag_sum)
cond = self.less_equal(self.base, flag_reduce)
else:
cond = self.less_equal(self.base, flag_sum)
overflow = cond
if sens is None:
overflow = self.loss_scaling_manager(self.loss_scale, cond)
if overflow:
succ = False
else:
succ = self.optimizer(grads)
ret = (loss, cond)
return F.depend(ret, succ)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bert model."""
import math
import copy
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore.common.initializer import TruncatedNormal, initializer
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
class BertConfig:
"""
Configuration for `BertModel`.
Args:
batch_size (int): Batch size of input dataset.
seq_length (int): Length of input sequence. Default: 128.
vocab_size (int): The shape of each embedding vector. Default: 32000.
hidden_size (int): Size of the bert encoder layers. Default: 768.
num_hidden_layers (int): Number of hidden layers in the BertTransformer encoder
cell. Default: 12.
num_attention_heads (int): Number of attention heads in the BertTransformer
encoder cell. Default: 12.
intermediate_size (int): Size of intermediate layer in the BertTransformer
encoder cell. Default: 3072.
hidden_act (str): Activation function used in the BertTransformer encoder
cell. Default: "gelu".
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
attention_probs_dropout_prob (float): The dropout probability for
BertAttention. Default: 0.1.
max_position_embeddings (int): Maximum length of sequences used in this
model. Default: 512.
type_vocab_size (int): Size of token type vocab. Default: 16.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from
dataset. Default: True.
token_type_ids_from_dataset (bool): Specifies whether to use the token type ids that loaded
from dataset. Default: True.
dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32.
compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
"""
def __init__(self,
batch_size,
seq_length=128,
vocab_size=32000,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
initializer_range=0.02,
use_relative_positions=False,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float32):
self.batch_size = batch_size
self.seq_length = seq_length
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.input_mask_from_dataset = input_mask_from_dataset
self.token_type_ids_from_dataset = token_type_ids_from_dataset
self.use_relative_positions = use_relative_positions
self.dtype = dtype
self.compute_type = compute_type
class EmbeddingLookup(nn.Cell):
"""
A embeddings lookup table with a fixed dictionary and size.
Args:
vocab_size (int): Size of the dictionary of embeddings.
embedding_size (int): The size of each embedding vector.
embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
each embedding vector.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
"""
def __init__(self,
vocab_size,
embedding_size,
embedding_shape,
use_one_hot_embeddings=False,
initializer_range=0.02):
super(EmbeddingLookup, self).__init__()
self.vocab_size = vocab_size
self.use_one_hot_embeddings = use_one_hot_embeddings
self.embedding_table = Parameter(initializer
(TruncatedNormal(initializer_range),
[vocab_size, embedding_size]),
name='embedding_table')
self.expand = P.ExpandDims()
self.shape_flat = (-1,)
self.gather = P.GatherV2()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.array_mul = P.MatMul()
self.reshape = P.Reshape()
self.shape = tuple(embedding_shape)
def construct(self, input_ids):
extended_ids = self.expand(input_ids, -1)
flat_ids = self.reshape(extended_ids, self.shape_flat)
if self.use_one_hot_embeddings:
one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
output_for_reshape = self.array_mul(
one_hot_ids, self.embedding_table)
else:
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
output = self.reshape(output_for_reshape, self.shape)
return output, self.embedding_table
class EmbeddingPostprocessor(nn.Cell):
"""
Postprocessors apply positional and token type embeddings to word embeddings.
Args:
embedding_size (int): The size of each embedding vector.
embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
each embedding vector.
use_token_type (bool): Specifies whether to use token type embeddings. Default: False.
token_type_vocab_size (int): Size of token type vocab. Default: 16.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
max_position_embeddings (int): Maximum length of sequences used in this
model. Default: 512.
dropout_prob (float): The dropout probability. Default: 0.1.
"""
def __init__(self,
embedding_size,
embedding_shape,
use_relative_positions=False,
use_token_type=False,
token_type_vocab_size=16,
use_one_hot_embeddings=False,
initializer_range=0.02,
max_position_embeddings=512,
dropout_prob=0.1):
super(EmbeddingPostprocessor, self).__init__()
self.use_token_type = use_token_type
self.token_type_vocab_size = token_type_vocab_size
self.use_one_hot_embeddings = use_one_hot_embeddings
self.max_position_embeddings = max_position_embeddings
self.embedding_table = Parameter(initializer
(TruncatedNormal(initializer_range),
[token_type_vocab_size,
embedding_size]),
name='embedding_table')
self.shape_flat = (-1,)
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.1, mstype.float32)
self.array_mul = P.MatMul()
self.reshape = P.Reshape()
self.shape = tuple(embedding_shape)
self.layernorm = nn.LayerNorm(embedding_size)
self.dropout = nn.Dropout(1 - dropout_prob)
self.gather = P.GatherV2()
self.use_relative_positions = use_relative_positions
self.slice = P.Slice()
self.full_position_embeddings = Parameter(initializer
(TruncatedNormal(initializer_range),
[max_position_embeddings,
embedding_size]),
name='full_position_embeddings')
def construct(self, token_type_ids, word_embeddings):
output = word_embeddings
if self.use_token_type:
flat_ids = self.reshape(token_type_ids, self.shape_flat)
if self.use_one_hot_embeddings:
one_hot_ids = self.one_hot(flat_ids,
self.token_type_vocab_size, self.on_value, self.off_value)
token_type_embeddings = self.array_mul(one_hot_ids,
self.embedding_table)
else:
token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0)
token_type_embeddings = self.reshape(token_type_embeddings, self.shape)
output += token_type_embeddings
if not self.use_relative_positions:
_, seq, width = self.shape
position_embeddings = self.slice(self.full_position_embeddings, [0, 0], [seq, width])
position_embeddings = self.reshape(position_embeddings, (1, seq, width))
output += position_embeddings
output = self.layernorm(output)
output = self.dropout(output)
return output
class BertOutput(nn.Cell):
"""
Apply a linear computation to hidden status and a residual computation to input.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
dropout_prob (float): The dropout probability. Default: 0.1.
compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
"""
def __init__(self,
in_channels,
out_channels,
initializer_range=0.02,
dropout_prob=0.1,
compute_type=mstype.float32):
super(BertOutput, self).__init__()
self.dense = nn.Dense(in_channels, out_channels,
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
self.dropout = nn.Dropout(1 - dropout_prob)
self.add = P.TensorAdd()
self.layernorm = nn.LayerNorm(out_channels).to_float(compute_type)
self.cast = P.Cast()
def construct(self, hidden_status, input_tensor):
output = self.dense(hidden_status)
output = self.dropout(output)
output = self.add(output, input_tensor)
output = self.layernorm(output)
return output
class RelaPosMatrixGenerator(nn.Cell):
"""
Generates matrix of relative positions between inputs.
Args:
length (int): Length of one dim for the matrix to be generated.
max_relative_position (int): Max value of relative position.
"""
def __init__(self, length, max_relative_position):
super(RelaPosMatrixGenerator, self).__init__()
self._length = length
self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32)
self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32)
self.range_length = -length + 1
self.tile = P.Tile()
self.range_mat = P.Reshape()
self.sub = P.Sub()
self.expanddims = P.ExpandDims()
self.cast = P.Cast()
def construct(self):
range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(self._length)), mstype.int32)
range_vec_col_out = self.range_mat(range_vec_row_out, (self._length, -1))
tile_row_out = self.tile(range_vec_row_out, (self._length,))
tile_col_out = self.tile(range_vec_col_out, (1, self._length))
range_mat_out = self.range_mat(tile_row_out, (self._length, self._length))
transpose_out = self.range_mat(tile_col_out, (self._length, self._length))
distance_mat = self.sub(range_mat_out, transpose_out)
distance_mat_clipped = C.clip_by_value(distance_mat,
self._min_relative_position,
self._max_relative_position)
# Shift values to be >=0. Each integer still uniquely identifies a
# relative position difference.
final_mat = distance_mat_clipped + self._max_relative_position
return final_mat
class RelaPosEmbeddingsGenerator(nn.Cell):
"""
Generates tensor of size [length, length, depth].
Args:
length (int): Length of one dim for the matrix to be generated.
depth (int): Size of each attention head.
max_relative_position (int): Maxmum value of relative position.
initializer_range (float): Initialization value of TruncatedNormal.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
"""
def __init__(self,
length,
depth,
max_relative_position,
initializer_range,
use_one_hot_embeddings=False):
super(RelaPosEmbeddingsGenerator, self).__init__()
self.depth = depth
self.vocab_size = max_relative_position * 2 + 1
self.use_one_hot_embeddings = use_one_hot_embeddings
self.embeddings_table = Parameter(
initializer(TruncatedNormal(initializer_range),
[self.vocab_size, self.depth]),
name='embeddings_for_position')
self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
max_relative_position=max_relative_position)
self.reshape = P.Reshape()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.shape = P.Shape()
self.gather = P.GatherV2() # index_select
self.matmul = P.BatchMatMul()
def construct(self):
relative_positions_matrix_out = self.relative_positions_matrix()
# Generate embedding for each relative position of dimension depth.
if self.use_one_hot_embeddings:
flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
one_hot_relative_positions_matrix = self.one_hot(
flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value)
embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
embeddings = self.reshape(embeddings, my_shape)
else:
embeddings = self.gather(self.embeddings_table,
relative_positions_matrix_out, 0)
return embeddings
class SaturateCast(nn.Cell):
"""
Performs a safe saturating cast. This operation applies proper clamping before casting to prevent
the danger that the value will overflow or underflow.
Args:
src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32.
dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32.
"""
def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
super(SaturateCast, self).__init__()
np_type = mstype.dtype_to_nptype(dst_type)
min_type = np.finfo(np_type).min
max_type = np.finfo(np_type).max
self.tensor_min_type = Tensor([min_type], dtype=src_type)
self.tensor_max_type = Tensor([max_type], dtype=src_type)
self.min_op = P.Minimum()
self.max_op = P.Maximum()
self.cast = P.Cast()
self.dst_type = dst_type
def construct(self, x):
out = self.max_op(x, self.tensor_min_type)
out = self.min_op(out, self.tensor_max_type)
return self.cast(out, self.dst_type)
class BertAttention(nn.Cell):
"""
Apply multi-headed attention from "from_tensor" to "to_tensor".
Args:
batch_size (int): Batch size of input datasets.
from_tensor_width (int): Size of last dim of from_tensor.
to_tensor_width (int): Size of last dim of to_tensor.
from_seq_length (int): Length of from_tensor sequence.
to_seq_length (int): Length of to_tensor sequence.
num_attention_heads (int): Number of attention heads. Default: 1.
size_per_head (int): Size of each attention head. Default: 512.
query_act (str): Activation function for the query transform. Default: None.
key_act (str): Activation function for the key transform. Default: None.
value_act (str): Activation function for the value transform. Default: None.
has_attention_mask (bool): Specifies whether to use attention mask. Default: False.
attention_probs_dropout_prob (float): The dropout probability for
BertAttention. Default: 0.0.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
do_return_2d_tensor (bool): True for return 2d tensor. False for return 3d
tensor. Default: False.
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32.
"""
def __init__(self,
batch_size,
from_tensor_width,
to_tensor_width,
from_seq_length,
to_seq_length,
num_attention_heads=1,
size_per_head=512,
query_act=None,
key_act=None,
value_act=None,
has_attention_mask=False,
attention_probs_dropout_prob=0.0,
use_one_hot_embeddings=False,
initializer_range=0.02,
do_return_2d_tensor=False,
use_relative_positions=False,
compute_type=mstype.float32):
super(BertAttention, self).__init__()
self.batch_size = batch_size
self.from_seq_length = from_seq_length
self.to_seq_length = to_seq_length
self.num_attention_heads = num_attention_heads
self.size_per_head = size_per_head
self.has_attention_mask = has_attention_mask
self.use_relative_positions = use_relative_positions
self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type)
self.reshape = P.Reshape()
self.shape_from_2d = (-1, from_tensor_width)
self.shape_to_2d = (-1, to_tensor_width)
weight = TruncatedNormal(initializer_range)
units = num_attention_heads * size_per_head
self.query_layer = nn.Dense(from_tensor_width,
units,
activation=query_act,
weight_init=weight).to_float(compute_type)
self.key_layer = nn.Dense(to_tensor_width,
units,
activation=key_act,
weight_init=weight).to_float(compute_type)
self.value_layer = nn.Dense(to_tensor_width,
units,
activation=value_act,
weight_init=weight).to_float(compute_type)
self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head)
self.shape_to = (
batch_size, to_seq_length, num_attention_heads, size_per_head)
self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
self.multiply = P.Mul()
self.transpose = P.Transpose()
self.trans_shape = (0, 2, 1, 3)
self.trans_shape_relative = (2, 0, 1, 3)
self.trans_shape_position = (1, 2, 0, 3)
self.multiply_data = Tensor([-10000.0,], dtype=compute_type)
self.batch_num = batch_size * num_attention_heads
self.matmul = P.BatchMatMul()
self.softmax = nn.Softmax()
self.dropout = nn.Dropout(1 - attention_probs_dropout_prob)
if self.has_attention_mask:
self.expand_dims = P.ExpandDims()
self.sub = P.Sub()
self.add = P.TensorAdd()
self.cast = P.Cast()
self.get_dtype = P.DType()
if do_return_2d_tensor:
self.shape_return = (batch_size * from_seq_length, num_attention_heads * size_per_head)
else:
self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head)
self.cast_compute_type = SaturateCast(dst_type=compute_type)
self._generate_relative_positions_embeddings = \
RelaPosEmbeddingsGenerator(length=to_seq_length,
depth=size_per_head,
max_relative_position=16,
initializer_range=initializer_range,
use_one_hot_embeddings=use_one_hot_embeddings)
def construct(self, from_tensor, to_tensor, attention_mask):
# reshape 2d/3d input tensors to 2d
from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d)
to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d)
query_out = self.query_layer(from_tensor_2d)
key_out = self.key_layer(to_tensor_2d)
value_out = self.value_layer(to_tensor_2d)
query_layer = self.reshape(query_out, self.shape_from)
query_layer = self.transpose(query_layer, self.trans_shape)
key_layer = self.reshape(key_out, self.shape_to)
key_layer = self.transpose(key_layer, self.trans_shape)
attention_scores = self.matmul_trans_b(query_layer, key_layer)
# use_relative_position, supplementary logic
if self.use_relative_positions:
# 'relations_keys' = [F|T, F|T, H]
relations_keys = self._generate_relative_positions_embeddings()
relations_keys = self.cast_compute_type(relations_keys)
# query_layer_t is [F, B, N, H]
query_layer_t = self.transpose(query_layer, self.trans_shape_relative)
# query_layer_r is [F, B * N, H]
query_layer_r = self.reshape(query_layer_t,
(self.from_seq_length,
self.batch_num,
self.size_per_head))
# key_position_scores is [F, B * N, F|T]
key_position_scores = self.matmul_trans_b(query_layer_r,
relations_keys)
# key_position_scores_r is [F, B, N, F|T]
key_position_scores_r = self.reshape(key_position_scores,
(self.from_seq_length,
self.batch_size,
self.num_attention_heads,
self.from_seq_length))
# key_position_scores_r_t is [B, N, F, F|T]
key_position_scores_r_t = self.transpose(key_position_scores_r,
self.trans_shape_position)
attention_scores = attention_scores + key_position_scores_r_t
attention_scores = self.multiply(attention_scores, self.scores_mul)
if self.has_attention_mask:
attention_mask = self.expand_dims(attention_mask, 1)
multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)),
self.cast(attention_mask, self.get_dtype(attention_scores)))
adder = self.multiply(multiply_out, self.multiply_data)
attention_scores = self.add(adder, attention_scores)
attention_probs = self.softmax(attention_scores)
attention_probs = self.dropout(attention_probs)
value_layer = self.reshape(value_out, self.shape_to)
value_layer = self.transpose(value_layer, self.trans_shape)
context_layer = self.matmul(attention_probs, value_layer)
# use_relative_position, supplementary logic
if self.use_relative_positions:
# 'relations_values' = [F|T, F|T, H]
relations_values = self._generate_relative_positions_embeddings()
relations_values = self.cast_compute_type(relations_values)
# attention_probs_t is [F, B, N, T]
attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative)
# attention_probs_r is [F, B * N, T]
attention_probs_r = self.reshape(
attention_probs_t,
(self.from_seq_length,
self.batch_num,
self.to_seq_length))
# value_position_scores is [F, B * N, H]
value_position_scores = self.matmul(attention_probs_r,
relations_values)
# value_position_scores_r is [F, B, N, H]
value_position_scores_r = self.reshape(value_position_scores,
(self.from_seq_length,
self.batch_size,
self.num_attention_heads,
self.size_per_head))
# value_position_scores_r_t is [B, N, F, H]
value_position_scores_r_t = self.transpose(value_position_scores_r,
self.trans_shape_position)
context_layer = context_layer + value_position_scores_r_t
context_layer = self.transpose(context_layer, self.trans_shape)
context_layer = self.reshape(context_layer, self.shape_return)
return context_layer
class BertSelfAttention(nn.Cell):
"""
Apply self-attention.
Args:
batch_size (int): Batch size of input dataset.
seq_length (int): Length of input sequence.
hidden_size (int): Size of the bert encoder layers.
num_attention_heads (int): Number of attention heads. Default: 12.
attention_probs_dropout_prob (float): The dropout probability for
BertAttention. Default: 0.1.
use_one_hot_embeddings (bool): Specifies whether to use one_hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32.
"""
def __init__(self,
batch_size,
seq_length,
hidden_size,
num_attention_heads=12,
attention_probs_dropout_prob=0.1,
use_one_hot_embeddings=False,
initializer_range=0.02,
hidden_dropout_prob=0.1,
use_relative_positions=False,
compute_type=mstype.float32):
super(BertSelfAttention, self).__init__()
if hidden_size % num_attention_heads != 0:
raise ValueError("The hidden size (%d) is not a multiple of the number "
"of attention heads (%d)" % (hidden_size, num_attention_heads))
self.size_per_head = int(hidden_size / num_attention_heads)
self.attention = BertAttention(
batch_size=batch_size,
from_tensor_width=hidden_size,
to_tensor_width=hidden_size,
from_seq_length=seq_length,
to_seq_length=seq_length,
num_attention_heads=num_attention_heads,
size_per_head=self.size_per_head,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
use_relative_positions=use_relative_positions,
has_attention_mask=True,
do_return_2d_tensor=True,
compute_type=compute_type)
self.output = BertOutput(in_channels=hidden_size,
out_channels=hidden_size,
initializer_range=initializer_range,
dropout_prob=hidden_dropout_prob,
compute_type=compute_type)
self.reshape = P.Reshape()
self.shape = (-1, hidden_size)
def construct(self, input_tensor, attention_mask):
input_tensor = self.reshape(input_tensor, self.shape)
attention_output = self.attention(input_tensor, input_tensor, attention_mask)
output = self.output(attention_output, input_tensor)
return output
class BertEncoderCell(nn.Cell):
"""
Encoder cells used in BertTransformer.
Args:
batch_size (int): Batch size of input dataset.
hidden_size (int): Size of the bert encoder layers. Default: 768.
seq_length (int): Length of input sequence. Default: 512.
num_attention_heads (int): Number of attention heads. Default: 12.
intermediate_size (int): Size of intermediate layer. Default: 3072.
attention_probs_dropout_prob (float): The dropout probability for
BertAttention. Default: 0.02.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
hidden_act (str): Activation function. Default: "gelu".
compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32.
"""
def __init__(self,
batch_size,
hidden_size=768,
seq_length=512,
num_attention_heads=12,
intermediate_size=3072,
attention_probs_dropout_prob=0.02,
use_one_hot_embeddings=False,
initializer_range=0.02,
hidden_dropout_prob=0.1,
use_relative_positions=False,
hidden_act="gelu",
compute_type=mstype.float32):
super(BertEncoderCell, self).__init__()
self.attention = BertSelfAttention(
batch_size=batch_size,
hidden_size=hidden_size,
seq_length=seq_length,
num_attention_heads=num_attention_heads,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
use_relative_positions=use_relative_positions,
compute_type=compute_type)
self.intermediate = nn.Dense(in_channels=hidden_size,
out_channels=intermediate_size,
activation=hidden_act,
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
self.output = BertOutput(in_channels=intermediate_size,
out_channels=hidden_size,
initializer_range=initializer_range,
dropout_prob=hidden_dropout_prob,
compute_type=compute_type)
def construct(self, hidden_states, attention_mask):
# self-attention
attention_output = self.attention(hidden_states, attention_mask)
# feed construct
intermediate_output = self.intermediate(attention_output)
# add and normalize
output = self.output(intermediate_output, attention_output)
return output
class BertTransformer(nn.Cell):
"""
Multi-layer bert transformer.
Args:
batch_size (int): Batch size of input dataset.
hidden_size (int): Size of the encoder layers.
seq_length (int): Length of input sequence.
num_hidden_layers (int): Number of hidden layers in encoder cells.
num_attention_heads (int): Number of attention heads in encoder cells. Default: 12.
intermediate_size (int): Size of intermediate layer in encoder cells. Default: 3072.
attention_probs_dropout_prob (float): The dropout probability for
BertAttention. Default: 0.1.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
hidden_act (str): Activation function used in the encoder cells. Default: "gelu".
compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
return_all_encoders (bool): Specifies whether to return all encoders. Default: False.
"""
def __init__(self,
batch_size,
hidden_size,
seq_length,
num_hidden_layers,
num_attention_heads=12,
intermediate_size=3072,
attention_probs_dropout_prob=0.1,
use_one_hot_embeddings=False,
initializer_range=0.02,
hidden_dropout_prob=0.1,
use_relative_positions=False,
hidden_act="gelu",
compute_type=mstype.float32,
return_all_encoders=False):
super(BertTransformer, self).__init__()
self.return_all_encoders = return_all_encoders
layers = []
for _ in range(num_hidden_layers):
layer = BertEncoderCell(batch_size=batch_size,
hidden_size=hidden_size,
seq_length=seq_length,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
use_relative_positions=use_relative_positions,
hidden_act=hidden_act,
compute_type=compute_type)
layers.append(layer)
self.layers = nn.CellList(layers)
self.reshape = P.Reshape()
self.shape = (-1, hidden_size)
self.out_shape = (batch_size, seq_length, hidden_size)
def construct(self, input_tensor, attention_mask):
prev_output = self.reshape(input_tensor, self.shape)
all_encoder_layers = ()
for layer_module in self.layers:
layer_output = layer_module(prev_output, attention_mask)
prev_output = layer_output
if self.return_all_encoders:
layer_output = self.reshape(layer_output, self.out_shape)
all_encoder_layers = all_encoder_layers + (layer_output,)
if not self.return_all_encoders:
prev_output = self.reshape(prev_output, self.out_shape)
all_encoder_layers = all_encoder_layers + (prev_output,)
return all_encoder_layers
class CreateAttentionMaskFromInputMask(nn.Cell):
"""
Create attention mask according to input mask.
Args:
config (Class): Configuration for BertModel.
"""
def __init__(self, config):
super(CreateAttentionMaskFromInputMask, self).__init__()
self.input_mask_from_dataset = config.input_mask_from_dataset
self.input_mask = None
if not self.input_mask_from_dataset:
self.input_mask = initializer(
"ones", [config.batch_size, config.seq_length], mstype.int32)
self.cast = P.Cast()
self.reshape = P.Reshape()
self.shape = (config.batch_size, 1, config.seq_length)
self.broadcast_ones = initializer(
"ones", [config.batch_size, config.seq_length, 1], mstype.float32)
self.batch_matmul = P.BatchMatMul()
def construct(self, input_mask):
if not self.input_mask_from_dataset:
input_mask = self.input_mask
input_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
attention_mask = self.batch_matmul(self.broadcast_ones, input_mask)
return attention_mask
class BertModel(nn.Cell):
"""
Bidirectional Encoder Representations from Transformers.
Args:
config (Class): Configuration for BertModel.
is_training (bool): True for training mode. False for eval mode.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
"""
def __init__(self,
config,
is_training,
use_one_hot_embeddings=False):
super(BertModel, self).__init__()
config = copy.deepcopy(config)
if not is_training:
config.hidden_dropout_prob = 0.0
config.attention_probs_dropout_prob = 0.0
self.input_mask_from_dataset = config.input_mask_from_dataset
self.token_type_ids_from_dataset = config.token_type_ids_from_dataset
self.batch_size = config.batch_size
self.seq_length = config.seq_length
self.hidden_size = config.hidden_size
self.num_hidden_layers = config.num_hidden_layers
self.embedding_size = config.hidden_size
self.token_type_ids = None
self.last_idx = self.num_hidden_layers - 1
output_embedding_shape = [self.batch_size, self.seq_length,
self.embedding_size]
if not self.token_type_ids_from_dataset:
self.token_type_ids = initializer(
"zeros", [self.batch_size, self.seq_length], mstype.int32)
self.bert_embedding_lookup = EmbeddingLookup(
vocab_size=config.vocab_size,
embedding_size=self.embedding_size,
embedding_shape=output_embedding_shape,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=config.initializer_range)
self.bert_embedding_postprocessor = EmbeddingPostprocessor(
embedding_size=self.embedding_size,
embedding_shape=output_embedding_shape,
use_relative_positions=config.use_relative_positions,
use_token_type=True,
token_type_vocab_size=config.type_vocab_size,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=0.02,
max_position_embeddings=config.max_position_embeddings,
dropout_prob=config.hidden_dropout_prob)
self.bert_encoder = BertTransformer(
batch_size=self.batch_size,
hidden_size=self.hidden_size,
seq_length=self.seq_length,
num_attention_heads=config.num_attention_heads,
num_hidden_layers=self.num_hidden_layers,
intermediate_size=config.intermediate_size,
attention_probs_dropout_prob=config.attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=config.initializer_range,
hidden_dropout_prob=config.hidden_dropout_prob,
use_relative_positions=config.use_relative_positions,
hidden_act=config.hidden_act,
compute_type=config.compute_type,
return_all_encoders=True)
self.cast = P.Cast()
self.dtype = config.dtype
self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
self.slice = P.StridedSlice()
self.squeeze_1 = P.Squeeze(axis=1)
self.dense = nn.Dense(self.hidden_size, self.hidden_size,
activation="tanh",
weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type)
self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config)
def construct(self, input_ids, token_type_ids, input_mask):
# embedding
if not self.token_type_ids_from_dataset:
token_type_ids = self.token_type_ids
word_embeddings, embedding_tables = self.bert_embedding_lookup(input_ids)
embedding_output = self.bert_embedding_postprocessor(token_type_ids,
word_embeddings)
# attention mask [batch_size, seq_length, seq_length]
attention_mask = self._create_attention_mask_from_input_mask(input_mask)
# bert encoder
encoder_output = self.bert_encoder(self.cast_compute_type(embedding_output),
attention_mask)
sequence_output = self.cast(encoder_output[self.last_idx], self.dtype)
# pooler
sequence_slice = self.slice(sequence_output,
(0, 0, 0),
(self.batch_size, 1, self.hidden_size),
(1, 1, 1))
first_token = self.squeeze_1(sequence_slice)
pooled_output = self.dense(first_token)
pooled_output = self.cast(pooled_output, self.dtype)
return sequence_output, pooled_output, embedding_tables
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py
"""
from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from mindspore.model_zoo.Bert_NEZHA import BertConfig
bert_train_cfg = edict({
'epoch_size': 10,
'num_warmup_steps': 0,
'start_learning_rate': 1e-4,
'end_learning_rate': 0.0,
'decay_steps': 1000,
'power': 10.0,
'save_checkpoint_steps': 2000,
'keep_checkpoint_max': 10,
'checkpoint_prefix': "checkpoint_bert",
# please add your own dataset path
'DATA_DIR': "/your/path/examples.tfrecord",
# please add your own dataset schema path
'SCHEMA_DIR': "/your/path/datasetSchema.json"
})
bert_net_cfg = BertConfig(
batch_size=16,
seq_length=128,
vocab_size=21136,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=True,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
NEZHA (NEural contextualiZed representation for CHinese lAnguage understanding) is the Chinese pretrained language
model currently based on BERT developed by Huawei.
1. Prepare data
Following the data preparation as in BERT, run command as below to get dataset for training:
python ./create_pretraining_data.py \
--input_file=./sample_text.txt \
--output_file=./examples.tfrecord \
--vocab_file=./your/path/vocab.txt \
--do_lower_case=True \
--max_seq_length=128 \
--max_predictions_per_seq=20 \
--masked_lm_prob=0.15 \
--random_seed=12345 \
--dupe_factor=5
2. Pretrain
First, prepare the distributed training environment, then adjust configurations in config.py, finally run train.py.
"""
import os
import numpy as np
from config import bert_train_cfg, bert_net_cfg
import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as C
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.train.model import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell
from mindspore.nn.optim import Lamb
_current_dir = os.path.dirname(os.path.realpath(__file__))
def create_train_dataset(batch_size):
"""create train dataset"""
# apply repeat operations
repeat_count = bert_train_cfg.epoch_size
ds = de.StorageDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR,
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"])
type_cast_op = C.TypeCast(mstype.int32)
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op)
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_count)
return ds
def weight_variable(shape):
"""weight variable"""
np.random.seed(1)
ones = np.random.uniform(-0.1, 0.1, size=shape).astype(np.float32)
return Tensor(ones)
def train_bert():
"""train bert"""
context.set_context(mode=context.GRAPH_MODE)
context.set_context(device_target="Ascend")
context.set_context(enable_task_sink=True)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
ds = create_train_dataset(bert_net_cfg.batch_size)
netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
optimizer = Lamb(netwithloss.trainable_params(), decay_steps=bert_train_cfg.decay_steps,
start_learning_rate=bert_train_cfg.start_learning_rate,
end_learning_rate=bert_train_cfg.end_learning_rate, power=bert_train_cfg.power,
warmup_steps=bert_train_cfg.num_warmup_steps, decay_filter=lambda x: False)
netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
netwithgrads.set_train(True)
model = Model(netwithgrads)
config_ck = CheckpointConfig(save_checkpoint_steps=bert_train_cfg.save_checkpoint_steps,
keep_checkpoint_max=bert_train_cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix=bert_train_cfg.checkpoint_prefix, config=config_ck)
model.train(ds.get_repeat_count(), ds, callbacks=[LossMonitor(), ckpoint_cb], dataset_sink_mode=False)
if __name__ == '__main__':
train_bert()
# Bert NEZHA
`NEZHA` (**NE**ural contextuali**Z**ed representation for C**H**inese l**A**nguage understanding) is the Chinese pretrained language model currently based on BERT developed by Huawei.
- `Bert_NEZHA`: Source of NEZHA model same as the one from `mindspore.model_zoo.Bert_NEZHA`
- `Bert_NEZHA_cnwiki`: The NEZHA pretraining example using data from cnwiki.
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册