scaled_dot_product_attention

paddle.fluid.nets. scaled_dot_product_attention ( queries, keys, values, num_heads=1, dropout_rate=0.0 ) [源代码]

该接口实现了的基于点积(并进行了缩放)的多头注意力(Multi-Head Attention)机制。attention可以表述为将一个查询(query)和一组键值对(key-value pair)映射为一个输出;Multi-Head Attention则是使用多路进行attention,而且对attention的输入进行了线性变换。公式如下:

\[\begin{split}MultiHead(Q, K, V ) & = Concat(head_1, ..., head_h)\\ where \ head_i & = Attention(QW_i^Q , KW_i^K , VW_i^V )\\ Attention(Q, K, V) & = softmax(\frac{QK^\mathrm{T}}{\sqrt{d_k}})V\\\end{split}\]

其中, \(Q, K, V\) 分别对应 querieskeysvalues ,详细内容请参阅 Attention Is All You Need

要注意该接口实现支持的是batch形式, \(Attention(Q, K, V)\) 中使用的矩阵乘是batch形式的矩阵乘法,参考 fluid.layers. matmul

参数

  • queries (Variable) - 形状为 \([N, L_q, d_k \times h]\) 的三维Tensor,其中 \(N\) 为batch_size, \(L_q\) 为查询序列长度, \(d_k \times h\) 为查询的特征维度大小,\(h\) 为head数。数据类型为float32或float64。

  • keys (Variable) - 形状为 \([N, L_k, d_k \times h]\) 的三维Tensor,其中 \(N\) 为batch_size, \(L_k\) 为键值序列长度, \(d_k \times h\) 为键的特征维度大小,\(h\) 为head数。数据类型与 queries 相同。

  • values (Variable) - 形状为 \([N, L_k, d_v \times h]\) 的三维Tensor,其中 \(N\) 为batch_size, \(L_k\) 为键值序列长度, \(d_v \times h\) 为值的特征维度大小,\(h\) 为head数。数据类型与 queries 相同。

  • num_heads (int) - 指明所使用的head数。head数为1时不对输入进行线性变换。默认值为1。

  • dropout_rate (float) - 以指定的概率对要attention到的内容进行dropout。默认值为0,即不使用dropout。

返回

形状为 \([N, L_q, d_v * h]\) 的三维Tensor,其中 \(N\) 为batch_size, \(L_q\) 为查询序列长度, \(d_v * h\) 为值的特征维度大小。与输入具有相同的数据类型。表示Multi-Head Attention的输出。

返回类型

Variable

抛出异常

  • ValueErrorquerieskeysvalues 必须都是三维。

  • ValueErrorquerieskeys 的最后一维(特征维度)大小必须相同。

  • ValueErrorkeysvalues 的第二维(长度维度)大小必须相同。

  • ValueErrorkeys 的最后一维(特征维度)大小必须是 num_heads 的整数倍。

  • ValueErrorvalues 的最后一维(特征维度)大小必须是 num_heads 的整数倍。

代码示例

import paddle.fluid as fluid

queries = fluid.data(name="queries", shape=[3, 5, 9], dtype="float32")
keys = fluid.data(name="keys", shape=[3, 6, 9], dtype="float32")
values = fluid.data(name="values", shape=[3, 6, 10], dtype="float32")
contexts = fluid.nets.scaled_dot_product_attention(queries, keys, values)
contexts.shape  # [3, 5, 10]