FusedMultiHeadAttention

class paddle.incubate.nn. FusedMultiHeadAttention ( embed_dim, num_heads, dropout_rate=0.5, attn_dropout_rate=0.5, kdim=None, vdim=None, normalize_before=False, need_weights=False, weight_attr=None, bias_attr=None, name=None ) [源代码]

多头注意力机制

注意力机制可以将查询(Query)与一组键值对(Key-Value)映射到输出。而多头注意力机制是将注意力机制的计算过程计算多次,以便模型提取不同子空间的信息。

细节可参考论文 Attention is all you need

FusedMultiHeadAttention 与已有的 MultiHeadAttention 有两处不同:

(1)表达的计算逻辑范围不同。相比 MultiHeadAttentionFusedMultiHeadAttention 的前面在 normalize_before=True 时,多了 layer_norm 算子,后面多了 residual adddropoutlayer_norm 的计算。

(2)q, k, v的weight的存储格式不同。 MultiHeadAttention 将q, k, v的weight存储在三个张量中。 FusedMultiHeadAttention 的q, k, v的weight被统一存在一个权重张量中,其维度为 [3, num_heads, head_dim, embed_dim]

参数

  • embed_dim (int) - 输入输出的维度。

  • num_heads (int) - 多头注意力机制的Head数量。

  • dropout_rate (float,可选) - multi-head attention后面的dropout算子的注意力目标的随机失活率。0表示进行dropout计算。默认值:0.5。

  • attn_dropout_rate (float,可选) - multi-head attention中的dropout算子的注意力目标的随机失活率。0表示不进行dropout计算。默认值:0.5。

  • kdim (int,可选) - 键值对中key的维度。如果为 Nonekdim = embed_dim 。默认值 None

  • vdim (int,可选) - 键值对中value的维度。如果为 Nonekdim = embed_dim 。默认值: None

  • normalize_before (bool, 可选) - 是pre_layer_norm结构(True)还是post_layer_norm结构(False)。pre_layer_norm结构中, layer_norm 算子位于multi-head attention和ffn的前面,post_layer_norm结构中, layer_norm 位于两者的后面。默认值: False

  • need_weights (bool, 可选) - 表明是否返回注意力权重。默认值: False

  • weight_attr (ParamAttr,可选) - 指定权重参数属性的对象。默认值: None ,表示使用默认的权重参数属性。具体用法请参见 ParamAttr

  • bias_attr (ParamAttr,可选)- 指定偏置参数属性的对象。默认值: None ,表示使用默认的偏置参数属性。具体用法请参见 ParamAttr

  • name (str,可选) - 操作的名称。默认值为: None 。更多信息请参见 Name

形状

  • x (Tensor): 默认形状为 [batch_size, sequence_length, embed_dim] ,其数据类型为float32,float64或者float16。

  • output (Tensor): 其形状和数据类型与输入x相同。

返回

计算FusedMultiHeadAttention的可调用对象

代码示例

import paddle
from paddle.incubate.nn import FusedMultiHeadAttention

# input: [batch_size, sequence_length, embed_dim]
query = paddle.rand((2, 4, 128))
# self-attention mask: [batch_size, num_heads, query_len, query_len]
attn_mask = paddle.rand((2, 2, 4, 4))
multi_head_attn = FusedMultiHeadAttention(128, 2)
output = multi_head_attn(query, None, None, attn_mask=attn_mask)  # [2, 4, 128]