variable_length_memory_efficient_attention

paddle.incubate.nn.functional.variable_length_memory_efficient_attention(query, key, value, seq_lens, kv_seq_lens, mask=None, scale=None, causal=False, pre_cache_length=0) [source]

Cutlass memory-efficient attention with variable sequence lengths. This method requires an SM architecture of sm70, sm75, or sm80.
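
Whether the current device meets this requirement can be checked at run time. A minimal sketch, assuming a CUDA build of Paddle; the compute capabilities below simply mirror the sm70/sm75/sm80 list documented above:

>>> import paddle
>>> # sm70 / sm75 / sm80 correspond to compute capabilities 7.0, 7.5 and 8.0
>>> major, minor = paddle.device.cuda.get_device_capability()
>>> assert (major, minor) in [(7, 0), (7, 5), (8, 0)], "unsupported SM architecture"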

Parameters
  • query (Tensor) – The Query Tensor. Its shape is [batch_size, num_head, seq_len, head_size].

  • key (Tensor) – The Key Tensor. Its shape is [batch_size, num_head, seq_len, head_size].

  • value (Tensor) – The Value Tensor. Its shape is [batch_size, num_head, seq_len, head_size].

  • seq_lens (Tensor) – The sequence lengths of the query sequences in the batch, used to index query. Its shape is [batch_size, 1].

  • kv_seq_lens (Tensor) – The sequence lengths of the key/value sequences in the batch, used to index key and value. Its shape is [batch_size, 1]. A construction sketch for both length tensors follows this parameter list.

  • mask (Tensor, optional) – The Mask Tensor. Its shape is [batch_size, 1, query_seq_len, key_seq_len]. Default is None.

  • scale (Float) – The scaling factor applied to the attention scores. Default is sqrt(1.0 / head_size).

  • causal (Bool) – Whether causal masking is used or not. Default is False.

  • pre_cache_length (Int) – The length of the pre-cache. Default is 0.
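
The two length tensors describe, per batch entry, how many positions of query and of key/value actually take part in attention. A minimal construction sketch (the per-batch lengths below are made up for illustration; shapes follow the parameter descriptions above):

>>> import paddle
>>> batch_size, seq_len = 2, 256
>>> # every query sequence uses its full length ...
>>> seq_lens = paddle.full([batch_size, 1], seq_len, dtype='int32')
>>> # ... while key/value sequences may be shorter (variable length)
>>> kv_seq_lens = paddle.to_tensor([[128], [256]], dtype='int32')
>>> print(seq_lens.shape, kv_seq_lens.shape)
[2, 1] [2, 1]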

Returns

the output Tensor.

Return type

Tensor

Examples

>>> import math
>>> import paddle
>>> from paddle.incubate.nn.functional import variable_length_memory_efficient_attention
>>> paddle.device.set_device('gpu')

>>> batch = 1
>>> num_head = 8
>>> seq_len = 256
>>> head_size = 32

>>> dtype = paddle.float16

>>> query = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype)
>>> key = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype)
>>> value = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype)
>>> seq_lens = paddle.to_tensor([seq_len, ] * batch, dtype='int32')
>>> mask = paddle.randn([batch, 1, seq_len, seq_len], dtype=dtype)

>>> scale = float(1.0 / math.sqrt(head_size))
>>> pre_cache_length = 0

>>> def naive_attention_impl(query, key, value, mask, scale):
...     qk_res = paddle.matmul(query, key, transpose_y=True)
...     attention = qk_res * scale
...     attention = attention + mask
...     softmax_result = paddle.nn.functional.softmax(attention, -1)
...     result = paddle.matmul(softmax_result, value)
...     return result

>>> out = naive_attention_impl(query, key, value, mask, scale)
>>> # equivalent to: out = variable_length_memory_efficient_attention(query, key, value, seq_lens, seq_lens, mask, scale, pre_cache_length=pre_cache_length)

>>> print(out.shape) # [batch, num_head, seq_len, head_size]
[1, 8, 256, 32]
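
The example above uses a single batch and equal query/key lengths; the variable-length case is where the fused kernel differs from plain attention. A sketch of that case, continuing from the example above (same imports, dtype and scale) with hypothetical per-batch key/value lengths; it requires a GPU with a supported SM architecture, and the printed shape follows the query layout shown in the example:

>>> batch = 2
>>> query = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype)
>>> key = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype)
>>> value = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype)
>>> mask = paddle.zeros([batch, 1, seq_len, seq_len], dtype=dtype)
>>> seq_lens = paddle.to_tensor([seq_len] * batch, dtype='int32')
>>> # only the first kv_seq_lens[i] key/value positions of batch entry i are attended to
>>> kv_seq_lens = paddle.to_tensor([128, seq_len], dtype='int32')
>>> out = variable_length_memory_efficient_attention(
...     query, key, value, seq_lens, kv_seq_lens,
...     mask=mask, scale=scale, causal=False, pre_cache_length=0)
>>> print(out.shape) # same layout as query
[2, 8, 256, 32]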