FusedMultiTransformer

class paddle.incubate.nn.FusedMultiTransformer(embed_dim, num_heads, dim_feedforward, dropout_rate=0.0, activation='gelu', normalize_before=True, ln_scale_attrs=None, ln_bias_attrs=None, qkv_weight_attrs=None, qkv_bias_attrs=None, linear_weight_attrs=None, linear_bias_attrs=None, ffn_ln_scale_attrs=None, ffn_ln_bias_attrs=None, ffn1_weight_attrs=None, ffn1_bias_attrs=None, ffn2_weight_attrs=None, ffn2_bias_attrs=None, epsilon=1e-05, num_layers=-1, nranks=1, trans_qkvw=True, ring_id=-1, name=None)

FusedMultiTransformer is composed of multiple transformer layers, each of which contains two sub-layers: (multi-head) self-attention and a feedforward network. The function of one transformer layer is consistent with the following pseudocode:

>>> if pre_layer_norm:
...     out = layer_norm(x)
...     out = qkv_linear(out) + qkv_bias
... else:
...     out = qkv_linear(x) + qkv_bias
>>> out = transpose(out, perm=[2, 0, 3, 1, 4])
>>> # extract q, k and v from out.
>>> q = out[0:1, ::]
>>> k = out[1:2, ::]
>>> v = out[2:3, ::]
>>> out = q * k^t
>>> out = attn_mask + out
>>> out = softmax(out)
>>> out = dropout(out)
>>> out = out * v
>>> out = transpose(out, perm=[0, 2, 1, 3])
>>> out = linear(out)
>>> if pre_layer_norm:
...     out = x + dropout(out + bias)
... else:
...     out = layer_norm(x + dropout(out + bias))

>>> residual = out
>>> if pre_layer_norm:
...     out = ffn_layer_norm(out)
>>> out = ffn1_linear(out)
>>> out = dropout(activation(out + ffn1_bias))
>>> out = ffn2_linear(out)
>>> out = residual + dropout(out + ffn2_bias)
>>> if not pre_layer_norm:
...     out = ffn_layer_norm(out)
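
For reference, each fused layer computes essentially the same function as a standard (unfused) Paddle encoder layer. The following is a minimal, illustrative comparison using paddle.nn.TransformerEncoderLayer with matching hyperparameters; weights are initialized independently, so outputs differ numerically:

>>> import paddle
>>> # Unfused reference for one pre-LN layer, mirroring the pseudocode above.
>>> ref_layer = paddle.nn.TransformerEncoderLayer(
...     d_model=128, nhead=2, dim_feedforward=512,
...     dropout=0.0, activation='gelu', normalize_before=True)
>>> x = paddle.rand((2, 4, 128))
>>> print(ref_layer(x).shape)
[2, 4, 128]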

Parameters
  • embed_dim (int) – The expected feature size in the input and output.

  • num_heads (int) – The number of heads in multi-head attention (MHA).

  • dim_feedforward (int) – The hidden layer size in the feedforward network (FFN).

  • dropout_rate (float, optional) – The dropout probability used in the pre-process and post-process of the MHA and FFN sub-layers. Default: 0.0.

  • activation (str, optional) – The activation function in the feedforward network. Default “gelu”.

  • normalize_before (bool, optional) – Indicate whether to put layer normalization into the preprocessing of the MHA and FFN sub-layers. If True, the pre-process is layer normalization and the post-process includes dropout and residual connection. Otherwise, there is no pre-process and the post-process includes dropout, residual connection and layer normalization. Default: True.

  • ln_scale_attrs (ParamAttr|list|tuple, optional) – To specify the weight parameter property for the Attention layer_norm. If it is a list/tuple, attrs[i] would be used as the attr for transformer layer i; otherwise, all layers use it as the attr to create parameters (a sketch of per-layer attrs follows this parameter list). Default: None, which means the default weight parameter property is used. See usage for details in ParamAttr.

  • ln_bias_attrs (ParamAttr|list|tuple|bool, optional) – To specify the bias parameter property for the Attention layer_norm. If it is a list/tuple, attrs[i] would be used as the attr for transformer layer i; otherwise, all layers use it as the attr to create parameters. A False value means the corresponding layer has no trainable bias parameter. Default: None, which means the default bias parameter property is used. See usage for details in ParamAttr.

  • qkv_weight_attrs (ParamAttr|list|tuple, optional) – To specify the weight parameter property for the Attention qkv computation. If it is a list/tuple, attrs[i] would be used as the attr for transformer layer i; otherwise, all layers use it as the attr to create parameters. Default: None, which means the default weight parameter property is used. See usage for details in ParamAttr.

  • qkv_bias_attrs (ParamAttr|list|tuple|bool, optional) – To specify the bias parameter property for the Attention qkv computation. If it is a list/tuple, attrs[i] would be used as the attr for transformer layer i; otherwise, all layers use it as the attr to create parameters. A False value means the corresponding layer has no trainable bias parameter. Default: None, which means the default bias parameter property is used. See usage for details in ParamAttr.

  • linear_weight_attrs (ParamAttr|list|tuple, optional) – To specify the weight parameter property for the Attention linear. If it is a list/tuple, attrs[i] would be used as the attr for transformer layer i; otherwise, all layers use it as the attr to create parameters. Default: None, which means the default weight parameter property is used. See usage for details in ParamAttr.

  • linear_bias_attrs (ParamAttr|list|tuple|bool, optional) – To specify the bias parameter property for the Attention linear computation. If it is a list/tuple, attrs[i] would be used as the attr for transformer layer i; otherwise, all layers use it as the attr to create parameters. A False value means the corresponding layer has no trainable bias parameter. Default: None, which means the default bias parameter property is used. See usage for details in ParamAttr.

  • ffn_ln_scale_attrs (ParamAttr|list|tuple, optional) – To specify the weight parameter property for the FFN layer_norm. If it is a list/tuple, attrs[i] would be used as the attr for transformer layer i; otherwise, all layers use it as the attr to create parameters. Default: None, which means the default weight parameter property is used. See usage for details in ParamAttr.

  • ffn_ln_bias_attrs (ParamAttr|list|tuple|bool, optional) – To specify the bias parameter property for the FFN layer_norm. If it is a list/tuple, attrs[i] would be used as the attr for transformer layer i; otherwise, all layers use it as the attr to create parameters. A False value means the corresponding layer has no trainable bias parameter. Default: None, which means the default bias parameter property is used. See usage for details in ParamAttr.

  • ffn1_weight_attrs (ParamAttr|list|tuple, optional) – To specify the weight parameter property for the FFN first linear. If it is a list/tuple, attrs[i] would be used as the attr for transformer layer i; otherwise, all layers use it as the attr to create parameters. Default: None, which means the default weight parameter property is used. See usage for details in ParamAttr.

  • ffn1_bias_attrs (ParamAttr|list|tuple|bool, optional) – To specify the bias parameter property for the FFN first linear. If it is a list/tuple, attrs[i] would be used as the attr for transformer layer i; otherwise, all layers use it as the attr to create parameters. A False value means the corresponding layer has no trainable bias parameter. Default: None, which means the default bias parameter property is used. See usage for details in ParamAttr.

  • ffn2_weight_attrs (ParamAttr|list|tuple, optional) – To specify the weight parameter property for the FFN second linear. If it is a list/tuple, attrs[i] would be used as the attr for transformer layer i; otherwise, all layers use it as the attr to create parameters. Default: None, which means the default weight parameter property is used. See usage for details in ParamAttr.

  • ffn2_bias_attrs (ParamAttr|list|tuple|bool, optional) – To specify the bias parameter property for the FFN second linear. If it is a list/tuple, attrs[i] would be used as the attr for transformer layer i; otherwise, all layers use it as the attr to create parameters. A False value means the corresponding layer has no trainable bias parameter. Default: None, which means the default bias parameter property is used. See usage for details in ParamAttr.

  • epsilon (float, optional) – Small float value added to the denominator of layer_norm to avoid dividing by zero. Default: 1e-05.

  • num_layers (int, optional) – The number of layers of the transformer. If qkv_weight_attrs is a list or tuple, the number of layers is obtained from qkv_weight_attrs. num_layers only takes effect when qkv_weight_attrs is not a list or tuple. Default: -1.

  • nranks (int, optional) – Distributed tensor model parallel nranks. Default: 1, which means tensor model parallelism is not used.

  • trans_qkvw (bool, optional) – Whether to transpose the qkv weights. If True, the shape of the qkv weights should be [3, num_head, dim_head, dim_embed]; otherwise, it should be [dim_embed, 3, num_head, dim_head]. Default: True.

  • ring_id (int, optional) – The ring id for distributed tensor model parallel. Default: -1, which means tensor model parallelism is not used.

  • name (str, optional) – The default value is None. Normally there is no need for the user to set this property. For more information, please refer to Name.
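
As referenced in the attrs parameters above, per-layer parameter properties can be passed as lists. A minimal sketch, assuming a 2-layer configuration; the parameter names and the initializer are illustrative only:

>>> import paddle
>>> from paddle.incubate.nn import FusedMultiTransformer
>>> paddle.device.set_device('gpu')

>>> # One ParamAttr per layer: attrs[i] configures transformer layer i.
>>> qkv_weight_attrs = [
...     paddle.ParamAttr(
...         name=f'layer{i}_qkv_weight',
...         initializer=paddle.nn.initializer.XavierUniform())
...     for i in range(2)]
>>> # num_layers is inferred from len(qkv_weight_attrs), so it is 2 here.
>>> model = FusedMultiTransformer(128, 2, 512, qkv_weight_attrs=qkv_weight_attrs)

A bias attr of False disables the corresponding trainable bias, as described in the parameter list above.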

Examples

>>> import paddle
>>> from paddle.incubate.nn import FusedMultiTransformer
>>> paddle.device.set_device('gpu')

>>> # encoder input: [batch_size, src_len, d_model]
>>> enc_input = paddle.rand((2, 4, 128))
>>> # self attention mask: [batch_size, 1, src_len, src_len]
>>> attn_mask = paddle.rand((2, 1, 4, 4))
>>> encoder_layers = FusedMultiTransformer(128, 2, 512, num_layers=1)
>>> enc_output = encoder_layers(enc_input, attn_mask)
>>> print(enc_output.shape)
[2, 4, 128]
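
The layout of the fused qkv weights follows trans_qkvw (True by default): with embed_dim=128 and num_heads=2, head_dim is 64, so each layer's qkv weight is [3, num_head, dim_head, dim_embed]. The qkv_weights attribute used below is an internal implementation detail and may change between versions:

>>> # Continues the example above; shape is [3, num_head, dim_head, dim_embed].
>>> print(encoder_layers.qkv_weights[0].shape)
[3, 2, 64, 128]
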
forward(src, attn_mask=None, caches=None, pre_caches=None, rotary_embs=None, rotary_emb_dims=0, seq_lens=None, time_step=None)

Applies the multi transformer layers to the input.

Parameters
  • src (Tensor) – The input of Transformer layers. It is a tensor with shape [batch_size, sequence_length, d_model]. The data type should be float16 or float32.

  • attn_mask (Tensor, optional) – A tensor used in multi-head attention to prevent attention to some unwanted positions, usually the paddings or the subsequent positions. It is a tensor with shape [batch_size, 1, sequence_length, sequence_length]. It can be None when nothing needs to be masked. Default: None.

  • caches (list(Tensor)|tuple(Tensor), optional) – The cache structure tensors for the inference generation model. They are only used for inference and should be None for training. The shape is [2, batch_size, num_head, max_seq_len, head_dim]. Default: None.

  • pre_caches (list(Tensor)|tuple(Tensor), optional) – The prefix caches for the generation model. The shape is [2, bsz, num_head, cache_len, head_dim]. Default: None.

  • rotary_embs (Tensor, optional) – The RoPE embeddings for the rotary computation. The shape is [2, bsz, 1, seq_len, head_dim]. Default: None.

  • rotary_emb_dims (int, optional) – The number of dimensions of the rotary computation: 0 when rotary_embs is None, 1 when rotary_embs is not None and pos_extra_ids is None, and 2 when both rotary_embs and pos_extra_ids are not None. Default: 0.

  • seq_lens (Tensor, optional) – The sequence lengths of this batch. The shape is [bsz]. Default: None.

  • time_step (Tensor, optional) – The time step tensor for the generation model, used in the decode stage to represent the time step, that is, the real sequence length of the cached key/value. The shape is [1], and it must be on CPUPlace. Default: None.

Returns

If caches is None, return a tensor that has the same shape and data type as src, representing the output of the Transformer layers. If caches is not None, return the tuple (output, caches), where output is the output of the Transformer layers and caches is updated in place with the input caches.

Return type

Tensor|tuple
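
The sketch below illustrates the intended two-phase use of caches and time_step during generation. It is a rough, hypothetical sketch assuming zero-initialized caches; masking, dtypes and cache pre-filling depend on the surrounding generation loop and may differ across versions:

>>> import paddle
>>> from paddle.incubate.nn import FusedMultiTransformer
>>> paddle.device.set_device('gpu')

>>> batch, src_len, d_model, num_head = 2, 4, 128, 2
>>> head_dim = d_model // num_head
>>> layer = FusedMultiTransformer(d_model, num_head, 512, num_layers=1)
>>> # One cache per layer, shaped [2, batch, num_head, max_seq_len, head_dim].
>>> caches = [paddle.zeros([2, batch, num_head, 16, head_dim])]

>>> # Context phase: encode the prompt and fill the caches.
>>> prompt = paddle.rand((batch, src_len, d_model))
>>> mask = paddle.rand((batch, 1, src_len, src_len))
>>> out, caches = layer(prompt, mask, caches=caches)

>>> # Decode phase: one token at a time; time_step holds the current cache
>>> # length as an int32 tensor of shape [1] on CPUPlace.
>>> token = paddle.rand((batch, 1, d_model))
>>> time_step = paddle.to_tensor([src_len], dtype='int32', place=paddle.CPUPlace())
>>> out, caches = layer(token, caches=caches, time_step=time_step)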