fused_multi_transformer

paddle.incubate.nn.functional.fused_multi_transformer(x, ln_scales, ln_biases, qkv_weights, qkv_biases, linear_weights, linear_biases, ffn_ln_scales, ffn_ln_biases, ffn1_weights, ffn1_biases, ffn2_weights, ffn2_biases, pre_layer_norm=True, epsilon=1e-05, cache_kvs=None, pre_caches=None, seq_lens=None, rotary_embs=None, time_step=None, attn_mask=None, dropout_rate=0.0, rotary_emb_dims=0, activation='gelu', training=False, mode='upscale_in_train', trans_qkvw=True, ring_id=-1, name=None)

This is a fused operator that computes multiple transformer layers of the transformer model architecture. This operator only supports running on GPU. The function of each transformer layer is consistent with the following pseudo code:

>>> if pre_layer_norm:
...     out = layer_norm(x)
...     out = qkv_linear(out) + qkv_bias
... else:
...     out = qkv_linear(x) + qkv_bias
>>> out = transpose(out, perm=[2, 0, 3, 1, 4])
>>> # extract q, k and v from out.
>>> q = out[0:1, ::]
>>> k = out[1:2, ::]
>>> v = out[2:3, ::]
>>> out = q * k^t
>>> out = attn_mask + out
>>> out = softmax(out)
>>> out = dropout(out)
>>> out = out * v
>>> out = transpose(out, perm=[0, 2, 1, 3])
>>> out = linear(out)
>>> if pre_layer_norm:
...     out = x + dropout(out + bias)
... else:
...     out = layer_norm(x + dropout(out + bias))

>>> residual = out
>>> if pre_layer_norm:
...     out = ffn_layer_norm(out)
>>> out = ffn1_linear(out)
>>> out = dropout(activation(out + ffn1_bias))
>>> out = ffn2_linear(out)
>>> out = residual + dropout(out + ffn2_bias)
>>> if not pre_layer_norm:
...     out = ffn_layer_norm(out)
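
For reference, the same computation can be sketched with standard, unfused Paddle ops. The following is only an illustrative sketch, assuming pre_layer_norm=True, trans_qkvw=True, dropout disabled and no cache; it also applies the usual 1/sqrt(dim_head) attention scale (not shown in the pseudo code above) and is not guaranteed to match the fused kernel's exact numerics:

>>> import paddle
>>> import paddle.nn.functional as nnF
>>> def naive_transformer_layer(x, ln_scale, ln_bias, qkv_w, qkv_b,
...                             linear_w, linear_b, ffn_ln_scale, ffn_ln_bias,
...                             ffn1_w, ffn1_b, ffn2_w, ffn2_b,
...                             attn_mask, epsilon=1e-5):
...     # x: [bsz, seq_len, d_model], qkv_w: [3, num_head, dim_head, d_model]
...     bsz, seq_len, d_model = x.shape
...     _, num_head, dim_head, _ = qkv_w.shape
...     # attention block (pre_layer_norm)
...     out = nnF.layer_norm(x, d_model, weight=ln_scale, bias=ln_bias, epsilon=epsilon)
...     qkv = paddle.einsum('bsd,thid->bsthi', out, qkv_w) + qkv_b
...     qkv = paddle.transpose(qkv, perm=[2, 0, 3, 1, 4])  # [3, bsz, num_head, seq_len, dim_head]
...     q, k, v = qkv[0], qkv[1], qkv[2]
...     attn = paddle.matmul(q, k, transpose_y=True) / dim_head**0.5
...     attn = nnF.softmax(attn + attn_mask, axis=-1)
...     out = paddle.matmul(attn, v)                        # [bsz, num_head, seq_len, dim_head]
...     out = paddle.transpose(out, perm=[0, 2, 1, 3]).reshape([bsz, seq_len, -1])
...     out = x + paddle.matmul(out, linear_w) + linear_b   # residual (dropout omitted)
...     # feed-forward block (pre_layer_norm)
...     residual = out
...     out = nnF.layer_norm(out, d_model, weight=ffn_ln_scale, bias=ffn_ln_bias, epsilon=epsilon)
...     out = nnF.gelu(paddle.matmul(out, ffn1_w) + ffn1_b)
...     out = paddle.matmul(out, ffn2_w) + ffn2_b
...     return residual + out
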
Parameters
  • x (Tensor) – The input tensor. It is a 3-D tensor with shape [batch_size, sequence_length, d_model]; the data type can be float16 or float32.

  • ln_scales (list(Tensor)|tuple(Tensor)) – The weight tensors of attention layer_norm, the shape is [d_model].

  • ln_biases (list(Tensor)|tuple(Tensor)) – The bias tensors of attention layer_norm, the shape is [d_model].

  • qkv_weights (list(Tensor)|tuple(Tensor)) – The weight tensors of attention qkv computation. The shape is [3, num_head, dim_head, d_model].

  • qkv_biases (list(Tensor)|tuple(Tensor)|None) – The bias tensors of attention qkv computation. The shape is [3, num_head, dim_head].

  • linear_weights (list(Tensor)|tuple(Tensor)) – The weight tensors of attention linear. The shape is [num_head * dim_head, d_model].

  • linear_biases (list(Tensor)|tuple(Tensor)|None) – The bias tensors of attention linear. The shape is [d_model].

  • ffn_ln_scales (list(Tensor)|tuple(Tensor)) – The weight tensors of feedforward layer_norm, the shape is [d_model].

  • ffn_ln_biases (list(Tensor)|tuple(Tensor)) – The bias tensors of feedforward layer_norm, the shape is [d_model].

  • ffn1_weights (list(Tensor)|tuple(Tensor)) – The weight tensors of feedforward first linear, the shape is [d_model, dim_feedforward].

  • ffn1_biases (list(Tensor)|tuple(Tensor)|None) – The bias tensors of feedforward first linear, the shape is [dim_feedforward].

  • ffn2_weights (list(Tensor)|tuple(Tensor)) – The weight tensors of feedforward second linear, the shape is [dim_feedforward, d_model].

  • ffn2_biases (list(Tensor)|tuple(Tensor)|None) – The bias tensors of feedforward second linear, the shape is [d_model].

  • pre_layer_norm (bool, optional) – whether it is pre_layer_norm(True) or post_layer_norm(False). Default True.

  • epsilon (float, optional) – Small float value added to denominator of the layer_norm to avoid dividing by zero. Default is 1e-5.

  • cache_kvs (list(Tensor)|tuple(Tensor), optional) – The cache structure tensors for the generation model. The shape is [2, bsz, num_head, max_seq_len, head_dim]. Default None.

  • pre_caches (list(Tensor)|tuple(Tensor), optional) – The prefix caches for the generation model. The shape is [2, bsz, num_head, cache_len, head_dim]. Default None.

  • seq_lens (Tensor, optional) – The sequence lengths of this batch. The shape is [bsz]. Default None.

  • rotary_embs (Tensor, optional) – The RoPE embeddings for rotary computation. The shape is [2, bsz, 1, seq_len, head_dim]. Default None.

  • time_step (Tensor, optional) – The time step tensor for the generation model, used in the decode stage to represent the current time step, that is, the real sequence length of the cached KV. The shape is [1] and it must be placed on CPUPlace. Default None.

  • attn_mask (Tensor, optional) – A tensor used in multi-head attention to prevent attention to some unwanted positions, usually the paddings or the subsequent positions. It is a tensor with shape [batch_size, 1, sequence_length, sequence_length]. Default None.

  • dropout_rate (float, optional) – The dropout probability of setting units to zero. Default 0.0.

  • rotary_emb_dims (int, optional) – The rotary_emb_dims of rotary computation: 0 when rotary_embs is None, 1 when rotary_embs is not None and pos_extra_ids is None, and 2 when both rotary_embs and pos_extra_ids are not None. Default 0.

  • activation (str, optional) – The activation. Default “gelu”.

  • training (bool, optional) – A flag indicating whether it is in the training phase or not. Default False.

  • mode (str, optional) –

    [‘upscale_in_train’(default) | ‘downscale_in_infer’] (see the short numeric sketch after this parameter list)

    1. upscale_in_train(default), upscale the output at training time

      • train: out = input * mask / ( 1.0 - p )

      • inference: out = input

    2. downscale_in_infer, downscale the output at inference

      • train: out = input * mask

      • inference: out = input * (1.0 - p)

  • trans_qkvw (bool, optional) – Whether to transpose the weights of qkv. If True, the shape of the qkv weights should be [3, num_head, dim_head, dim_embed]. Otherwise, the shape of the qkv weights should be [dim_embed, 3, num_head, dim_head]. Default True.

  • ring_id (int, optional) – For distributed forward in tensor model parallel, only support NCCL. Default is -1, which means tensor model parallelism is not used.

  • name (str, optional) – Name for the operation (optional, default is None). For more information, please refer to Name.
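
The two dropout modes only differ in where the 1 / (1 - p) scaling is applied. The following is a minimal numeric sketch (plain tensor arithmetic with a hypothetical mask and probability p) of the training and inference formulas listed under mode:

>>> import paddle
>>> p = 0.5
>>> x = paddle.to_tensor([2.0, 4.0, 6.0, 8.0])
>>> mask = paddle.to_tensor([1.0, 0.0, 1.0, 0.0])  # hypothetical dropout mask
>>> # 'upscale_in_train': scale at training time, identity at inference
>>> train_out = x * mask / (1.0 - p)
>>> infer_out = x
>>> # 'downscale_in_infer': masked only at training time, scale at inference
>>> train_out = x * mask
>>> infer_out = x * (1.0 - p)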

Returns

If cache_kvs is None, return a tensor that has the same shape and data type as x, representing the output of the Transformer layers. If cache_kvs is not None, return a tuple (output, cache_kvs), where output is the output of the Transformer layers and cache_kvs is updated in place from the input cache_kvs.

Return type

Tensor|tuple

Examples

>>> import paddle
>>> paddle.device.set_device('gpu')
>>> import paddle.incubate.nn.functional as F

>>> # input: [batch_size, seq_len, embed_dim]
>>> x = paddle.rand(shape=(2, 4, 128), dtype="float32")

>>> # ln_scale: [embed_dim], ln_bias: [embed_dim]
>>> ln_scale = paddle.rand(shape=(128,), dtype="float32")
>>> ln_bias = paddle.rand(shape=(128,), dtype="float32")

>>> # qkv_weight: [3, num_head, head_dim, embed_dim], qkv_bias: [3, num_head, head_dim]
>>> qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32")
>>> qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32")

>>> # linear_weight: [embed_dim, embed_dim], linear_bias: [embed_dim]
>>> linear_weight = paddle.rand(shape=(128, 128), dtype="float32")
>>> linear_bias = paddle.rand(shape=(128,), dtype="float32")

>>> # ffn_ln_scale: [embed_dim], ffn_ln_bias: [embed_dim]
>>> ffn_ln_scale = paddle.rand(shape=(128,), dtype="float32")
>>> ffn_ln_bias = paddle.rand(shape=(128,), dtype="float32")

>>> # ffn1_weight: [embed_dim, 4*embed_dim], ffn1_bias: [4*embed_dim]
>>> ffn1_weight = paddle.rand(shape=(128, 4*128), dtype="float32")
>>> ffn1_bias = paddle.rand(shape=(4*128,), dtype="float32")

>>> # ffn2_weight: [4*embed_dim, embed_dim], ffn2_bias: [embed_dim]
>>> ffn2_weight = paddle.rand(shape=(4*128, 128), dtype="float32")
>>> ffn2_bias = paddle.rand(shape=(128,), dtype="float32")

>>> # self attention mask: [batch_size, 1, seq_len, seq_len]
>>> attn_mask = paddle.rand(shape=(2, 1, 4, 4), dtype="float32")

>>> # output: [batch_size, seq_len, embed_dim]
>>> output = F.fused_multi_transformer(
...     x, [ln_scale], [ln_bias], [qkv_weight], [qkv_bias],
...     [linear_weight], [linear_bias], [ffn_ln_scale], [ffn_ln_bias],
...     [ffn1_weight], [ffn1_bias], [ffn2_weight], [ffn2_bias],
...     attn_mask=attn_mask)
>>> print(output.shape)
[2, 4, 128]
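
For the generation model, cache tensors with the documented [2, bsz, num_head, max_seq_len, head_dim] shape can also be passed, in which case a (output, cache_kvs) tuple is returned, as described above. The following is a minimal sketch reusing the tensors from the previous example; max_seq_len=16 is an arbitrary choice for illustration, and the decode-stage details (time_step, per-step masks) are omitted:

>>> # cache_kv: [2, batch_size, num_head, max_seq_len, head_dim]
>>> cache_kv = paddle.zeros(shape=(2, 2, 4, 16, 32), dtype="float32")
>>> output, cache_kvs = F.fused_multi_transformer(
...     x, [ln_scale], [ln_bias], [qkv_weight], [qkv_bias],
...     [linear_weight], [linear_bias], [ffn_ln_scale], [ffn_ln_bias],
...     [ffn1_weight], [ffn1_bias], [ffn2_weight], [ffn2_bias],
...     attn_mask=attn_mask, cache_kvs=[cache_kv])
>>> print(output.shape)
[2, 4, 128]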