sparse_attention

paddle.nn.functional. sparse_attention ( query, key, value, sparse_csr_offset, sparse_csr_columns, name=None ) [源代码]

对 Transformer 模块中的 Attention 矩阵进行了稀疏化,从而减少内存消耗和计算量。

其稀疏数据排布通过 CSR 格式表示,CSR 格式包含两个参数,offsetcolunms。计算公式为:

\[result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V\]

其中,QKV 表示注意力模块的三个输入参数。这三个参数的维度是一样的。d 代表这三个参数的最后一个维度的大小。

警告

目前该 API 只在 CUDA11.3 及以上版本中使用。

参数

  • query (Tensor) - 输入的 Tensor,代表注意力模块中的 query,这是一个 4 维 Tensor,形状为:[batch_size, num_heads, seq_len, head_dim],数据类型为 float32 或 float64。

  • key (Tensor) - 输入的 Tensor,代表注意力模块中的 key,这是一个 4 维 Tensor,形状为:[batch_size, num_heads, seq_len, head_dim],数据类型为 float32 或 float64。

  • value (Tensor) - 输入的 Tensor,代表注意力模块中的 value,这是一个 4 维 Tensor,形状为:[batch_size, num_heads, seq_len, head_dim],数据类型为 float32 或 float64。

  • sparse_csr_offset (Tensor) - 输入的 Tensor,注意力模块中的稀疏特性,稀疏特性使用 CSR 格式表示,offset 代表矩阵中每一行非零元的数量。这是一个 3 维 Tensor,形状为:[batch_size, num_heads, seq_len + 1],数据类型为 int32。

  • sparse_csr_columns (Tensor) - 输入的 Tensor,注意力模块中的稀疏特性,稀疏特性使用 CSR 格式表示,colunms 代表矩阵中每一行非零元的列索引值。这是一个 3 维 Tensor,形状为:[batch_size, num_heads, sparse_nnz],数据类型为 int32。

返回

Tensor,代表注意力模块的结果。这是一个 4 维 Tensor,形状为:[batch_size, num_heads, seq_len, head_dim],数据类型为 float32 或 float64。

代码示例

>>> import paddle

>>> paddle.disable_static()

>>> # `query`, `key` and `value` all have shape [1, 1, 4, 2]
>>> query = paddle.to_tensor([[[[0, 1, ], [2, 3],
...                             [0, 1], [2, 3]]]], dtype="float32")
>>> key = paddle.to_tensor([[[[0, 1], [2, 3],
...                           [0, 1], [2, 3]]]], dtype="float32")
>>> value = paddle.to_tensor([[[[0, 1], [2, 3],
...                             [0, 1], [2, 3]]]], dtype="float32")
...
>>> offset = paddle.to_tensor([[[0, 2, 4, 6, 8]]], dtype="int32")
>>> columns = paddle.to_tensor([[[0, 1, 0, 1, 2, 3, 2, 3]]], dtype="int32")
...
>>> print(offset.shape)
[1, 1, 5]
>>> print(columns.shape)
[1, 1, 8]
...
>>> key_padding_mask = paddle.to_tensor([[1, 1, 1, 0]], dtype="float32")
>>> attention_mask = paddle.to_tensor([[1, 0, 1, 1],
...                                    [1, 1, 1, 1],
...                                    [1, 1, 1, 1],
...                                    [1, 1, 1, 1]], dtype="float32")
>>> output_mask = paddle.nn.functional.sparse_attention(query, key,
...                                                     value, offset, columns,
...                                                     key_padding_mask=key_padding_mask,
...                                                     attn_mask=attention_mask)
>>> print(output_mask)
Tensor(shape=[1, 1, 4, 2], dtype=float32, place=Place(cpu), stop_gradient=False,
[[[[0.        , 1.        ],
   [1.99830270, 2.99830270],
   [0.        , 1.        ],
   [0.        , 1.        ]]]])

>>> output = paddle.nn.functional.sparse_attention(query, key,
...                                             value, offset, columns)
>>> print(output)
Tensor(shape=[1, 1, 4, 2], dtype=float32, place=Place(cpu), stop_gradient=False,
[[[[1.60885942, 2.60885954],
   [1.99830270, 2.99830270],
   [1.60885942, 2.60885954],
   [1.99830270, 2.99830270]]]])