attention

paddle.sparse.nn.functional.attention(query, key, value, sparse_mask, key_padding_mask=None, attn_mask=None, name=None)
Note:
    This API is only supported from CUDA 11.7 onwards.
SparseCsrTensor is used to store the intermediate result of the Attention matrix in the Transformer module, which can reduce memory usage and improve performance. sparse_mask expresses the sparse layout in CSR format. The calculation equation is:

\[result = softmax(\frac{Q * K^T}{\sqrt{d}}) * V\]
where Q, K, and V represent the three input parameters of the attention module. The shape of all three parameters is [batch_size, num_heads, seq_len, head_dim], and d represents head_dim.
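
For intuition, here is a minimal dense sketch of the same equation (ignoring sparse_mask, key_padding_mask, and attn_mask). The helper name dense_attention_reference is illustrative only; this is a reference formulation, not how this API is implemented:

    import math
    import paddle

    # Dense reference of: result = softmax(Q * K^T / sqrt(d)) * V
    def dense_attention_reference(q, k, v):
        d = q.shape[-1]                                  # head_dim
        scores = paddle.matmul(q, k, transpose_y=True)   # [batch_size, num_heads, seq_len, seq_len]
        probs = paddle.nn.functional.softmax(scores / math.sqrt(d), axis=-1)
        return paddle.matmul(probs, v)                   # [batch_size, num_heads, seq_len, head_dim]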
Args:
    query (DenseTensor): query in the Attention module. 4D Tensor with float32 or float64.
    key (DenseTensor): key in the Attention module. 4D Tensor with float32 or float64.
    value (DenseTensor): value in the Attention module. 4D Tensor with float32 or float64.
    sparse_mask (SparseCsrTensor): The sparse layout in the Attention module. Its dense shape is [batch_size*num_heads, seq_len, seq_len]. The nnz of each batch must be the same. The dtype of crows and cols must be int64, and the dtype of values can be float32 or float64. See the construction sketch after this list.
    key_padding_mask (DenseTensor, optional): The key padding mask tensor in the Attention module. 2D tensor with shape [batch_size, seq_len]. dtype can be float32 or float64. Default: None.
    attn_mask (DenseTensor, optional): The attention mask tensor in the Attention module. 2D tensor with shape [seq_len, seq_len]. dtype can be float32 or float64. Default: None.
    name (str, optional): The default value is None. Normally there is no need for the user to set this property. For more information, please refer to Name.
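
Below is a minimal sketch of building a sparse_mask that satisfies these constraints. The lower-triangular (causal) layout is assumed purely for illustration; any 0/1 pattern shared across the batch works the same way:

    import paddle

    batch_size, num_heads, seq_len = 2, 4, 8

    # One lower-triangular layout, broadcast to every (batch, head) pair so that
    # nnz is identical across batches, then stored in CSR format.
    layout = paddle.tril(paddle.ones([seq_len, seq_len]))               # float32 zeros and ones
    layout = layout.expand([batch_size * num_heads, seq_len, seq_len])
    sparse_mask = layout.to_sparse_csr()                                # crows/cols are int64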
Returns:
    4D tensor with shape [batch_size, num_heads, seq_len, head_dim]. The dtype is the same as the input.
Examples:
    # required: gpu
    import paddle

    batch_size = 16
    num_heads = 16
    seq_len = 512
    head_dim = 32

    query = paddle.rand([batch_size, num_heads, seq_len, head_dim])
    key = paddle.rand([batch_size, num_heads, seq_len, head_dim])
    value = paddle.rand([batch_size, num_heads, seq_len, head_dim])
    query.stop_gradient = False
    key.stop_gradient = False
    value.stop_gradient = False

    # One random layout on [seq_len, seq_len], broadcast to every (batch, head)
    # pair so that nnz matches across batches, then converted to CSR.
    mask = paddle.nn.functional.dropout(paddle.ones([seq_len, seq_len])).expand([batch_size, num_heads, seq_len, seq_len])
    sp_mask = mask.reshape([-1, seq_len, seq_len]).to_sparse_csr()

    # key_padding_mask and attn_mask must be float32 or float64, so cast the random 0/1 integers.
    kp_mask = paddle.randint(0, 2, [batch_size, seq_len]).astype(paddle.float32)
    attn_mask = paddle.randint(0, 2, [seq_len, seq_len]).astype(paddle.float32)

    output = paddle.sparse.nn.functional.attention(query, key, value, sp_mask, kp_mask, attn_mask)
    output.backward()
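
As a usage note, the functional call can be wrapped in a small paddle.nn.Layer when the sparse layout is fixed ahead of time. The SparseSelfAttention class below is a hypothetical sketch for illustration, not part of the paddle API:

    import paddle

    class SparseSelfAttention(paddle.nn.Layer):
        # `layout` is a dense 0/1 tensor of shape [batch_size * num_heads, seq_len, seq_len].
        def __init__(self, layout):
            super().__init__()
            self.sparse_mask = layout.to_sparse_csr()   # built once, reused on every forward pass

        def forward(self, q, k, v, kp_mask=None, attn_mask=None):
            return paddle.sparse.nn.functional.attention(
                q, k, v, self.sparse_mask, kp_mask, attn_mask)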