scaled_dot_product_attention
paddle.compat.nn.functional.scaled_dot_product_attention(query: Tensor, key: Tensor, value: Tensor, attn_mask: Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False) -> Tensor
The equation is:
\[\text{result} = \mathrm{softmax}\left(\frac{Q K^T}{\sqrt{d}}\right) V\]
where Q, K, and V represent the three input parameters of the attention module, and d represents the size of the last dimension (head_dim) of the three parameters. The dimensions of the three parameters are the same. If scale is given, it replaces the factor 1 / sqrt(d).
Warning
This API is only verified for inputs of dtype float16 and bfloat16. Other dtypes may fall back to the math implementation, which is less optimized.
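As a rough reference for the equation above, the following pure-Python sketch computes attention for a single head, with Q, K, and V given as lists of rows of shape [seq_len, head_dim]. The helper `attention_single_head` is hypothetical and is not part of Paddle. It scales the scores by 1 / sqrt(d) unless an explicit scale is passed, and it ignores masking and dropout.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    # Subtract the row maximum for numerical stability before exponentiating.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_single_head(q: Matrix, k: Matrix, v: Matrix,
                          scale: Optional[float] = None) -> Matrix:
    """Hypothetical reference: softmax(Q K^T * scale) V for one head."""
    d = len(q[0])  # size of the last dimension (head_dim)
    if scale is None:
        scale = 1.0 / math.sqrt(d)
    out = []
    for q_row in q:
        # One row of Q K^T: dot product of this query with every key.
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        weights = softmax(scores)
        # Weighted sum of the value rows.
        out.append([sum(w * v_row[j] for w, v_row in zip(weights, v))
                    for j in range(len(v[0]))])
    return out
```

Looping this over the batch and head dimensions of a [batch_size, num_heads, seq_len, head_dim] input gives the unmasked, dropout-free result for every head.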
Note
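This API differs from paddle.nn.functional.scaled_dot_product_attention in its QKV layout. This API expects [batch_size, num_heads, seq_len, head_dim] or [num_heads, seq_len, head_dim], whereas paddle.nn.functional.scaled_dot_product_attention expects [batch_size, seq_len, num_heads, head_dim].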
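For tensors stored in the [batch_size, seq_len, num_heads, head_dim] layout, swapping axes 1 and 2 gives the layout this API expects. In Paddle this is typically done with `paddle.transpose(x, perm=[0, 2, 1, 3])`. The pure-Python sketch below uses a hypothetical helper, `bshd_to_bhsd`, to show the index mapping on nested lists: the head_dim vector at `x[b][s][h]` moves to `y[b][h][s]`.

```python
from typing import List, TypeVar

T = TypeVar("T")  # stands for one head_dim vector


def bshd_to_bhsd(x: List[List[List[T]]]) -> List[List[List[T]]]:
    """Hypothetical helper: reorder [batch, seq_len, num_heads, head_dim]
    nested lists into [batch, num_heads, seq_len, head_dim]."""
    out = []
    for batch in x:
        seq_len = len(batch)
        num_heads = len(batch[0])
        # The vector at batch[s][h] moves to position [h][s].
        out.append([[batch[s][h] for s in range(seq_len)] for h in range(num_heads)])
    return out
```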
Parameters
query (Tensor) – The query tensor in the attention module: a 4-D tensor of shape [batch_size, num_heads, seq_len, head_dim] or a 3-D tensor of shape [num_heads, seq_len, head_dim]. The dtype can be float16 or bfloat16.
key (Tensor) – The key tensor in the attention module: a 4-D tensor of shape [batch_size, num_heads, seq_len, head_dim] or a 3-D tensor of shape [num_heads, seq_len, head_dim]. The dtype can be float16 or bfloat16.
value (Tensor) – The value tensor in the attention module: a 4-D tensor of shape [batch_size, num_heads, seq_len, head_dim] or a 3-D tensor of shape [num_heads, seq_len, head_dim]. The dtype can be float16 or bfloat16.
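attn_mask (Tensor, optional) – The attention mask. Its shape must be broadcastable to the shape of the attention scores, [batch_size, num_heads, seq_len_query, seq_len_key]. The dtype can be bool or the same dtype as query. A bool mask marks with True the positions that take part in attention; a non-bool mask is added to the attention scores.
dropout_p (float, optional) – Dropout probability; if greater than 0.0, dropout is applied to the attention weights. Default 0.0.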
is_causal (bool, optional) – Whether to enable causal mode. If True, a causal mask is applied: it is lower triangular when the attention score matrix is square (seq_len_query == seq_len_key), and takes the form of an upper-left aligned causal bias when the matrix is non-square. An error is raised if both attn_mask and is_causal are set. Default False.
scale (float, optional) – The scaling factor used in the calculation of attention weights. If None, scale = 1 / sqrt(head_dim).
enable_gqa (bool, optional) – Whether to enable grouped-query attention (GQA) mode, in which key and value may have fewer heads than query. Default False.
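The mask-related parameters and GQA can be illustrated with a few small pure-Python helpers. All of them are hypothetical sketches, not Paddle functions. `bool_mask_to_additive` assumes the convention that True means "attend" and turns False entries into -inf before they are added to the scores. `causal_mask` builds the upper-left aligned causal mask, where query position i may attend to key position j only when j <= i. `gqa_kv_head` assumes that, with enable_gqa=True, each group of consecutive query heads shares one key/value head (the convention of PyTorch's reference implementation); Paddle's internal mapping is not confirmed here.

```python
import math
from typing import List

NEG_INF = float("-inf")


def bool_mask_to_additive(mask: List[List[bool]]) -> List[List[float]]:
    """Hypothetical helper: True (attend) -> 0.0, False (ignore) -> -inf."""
    return [[0.0 if keep else NEG_INF for keep in row] for row in mask]


def masked_softmax(scores: List[float], additive_mask: List[float]) -> List[float]:
    # Adding -inf drives exp(...) to 0, so masked keys get zero weight.
    # Assumes at least one key per row is left unmasked.
    shifted = [s + m for s, m in zip(scores, additive_mask)]
    peak = max(shifted)
    exps = [math.exp(x - peak) for x in shifted]
    total = sum(exps)
    return [e / total for e in exps]


def causal_mask(seq_len_query: int, seq_len_key: int) -> List[List[bool]]:
    """Upper-left aligned causal mask: query i may attend to key j iff j <= i."""
    return [[j <= i for j in range(seq_len_key)] for i in range(seq_len_query)]


def gqa_kv_head(query_head: int, num_query_heads: int, num_kv_heads: int) -> int:
    """Assumed GQA mapping: consecutive query heads share one key/value head."""
    if num_query_heads % num_kv_heads != 0:
        raise ValueError("num_query_heads must be a multiple of num_kv_heads")
    group_size = num_query_heads // num_kv_heads
    return query_head // group_size
```

For example, `causal_mask(2, 3)` gives `[[True, False, False], [True, True, False]]`, while `causal_mask(3, 3)` is the usual lower-triangular mask.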
Returns
The attention tensor: a 4-D tensor of shape [batch_size, num_heads, seq_len, head_dim] or a 3-D tensor of shape [num_heads, seq_len, head_dim]. The dtype can be float16 or bfloat16.
Return type
out (Tensor)
Examples
>>> import paddle
>>> q = paddle.rand((1, 2, 128, 16), dtype=paddle.bfloat16)
>>> # Self-attention: q serves as query, key and value; dropout_p=0.9, is_causal=False
>>> output = paddle.compat.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False)
>>> print(output)
