variable_length_memory_efficient_attention
paddle.incubate.nn.functional.variable_length_memory_efficient_attention(query, key, value, seq_lens, kv_seq_lens, mask=None, scale=None, causal=False, pre_cache_length=0)
Cutlass memory-efficient attention for variable-length sequences. This method requires the GPU architecture (SM_ARCH) to be sm70, sm75, or sm80.
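To guard against running on an unsupported device, the architecture can be checked at run time. Below is a minimal sketch using paddle.device.cuda.get_device_capability; the exact (major, minor) whitelist simply mirrors the sm70/sm75/sm80 list above and is an assumption, not an official compatibility check:

>>> import paddle
>>> # Map sm70/sm75/sm80 to (major, minor) compute capabilities -- assumed whitelist
>>> major, minor = paddle.device.cuda.get_device_capability()
>>> assert (major, minor) in [(7, 0), (7, 5), (8, 0)], "unsupported SM_ARCH"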
Parameters
query (Tensor) – The Query Tensor. Its shape is [batchsize, num_head, seq_len, head_size].
key (Tensor) – The Key Tensor. Its shape is [batchsize, num_head, seq_len, head_size].
value (Tensor) – The Value Tensor. Its shape is [batchsize, num_head, seq_len, head_size].
seq_lens (Tensor) – The sequence lengths of the sequences in the batch, used to index query. Its shape is [batchsize, 1].
kv_seq_lens (Tensor) – The sequence lengths of the sequences in the batch, used to index key and value. Its shape is [batchsize, 1]. A construction sketch is given below, after the Return type entry.
mask (Tensor, optional) – The mask Tensor, added to the attention scores before the softmax. Its shape is [batchsize, 1, query_seq_len, key_seq_len]. Default is None.
scale (Float) – The scaling factor applied to the attention scores before the softmax. Default is sqrt(1.0 / head_size).
causal (Bool) – Whether to apply causal masking. Default is False.
pre_cache_length (Int) – The length of the pre-cache. Default is 0.
Returns
The output Tensor, with the same shape as query.
Return type
Tensor
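As referenced in the seq_lens and kv_seq_lens descriptions above, the two length tensors describe a ragged batch. A minimal sketch with illustrative lengths; passing the same tensor twice (as the example below also does) covers the common case where queries and keys span the same tokens:

>>> import paddle
>>> # Three padded sequences with illustrative true lengths 200, 256 and 128
>>> seq_lens = paddle.to_tensor([200, 256, 128], dtype='int32')
>>> kv_seq_lens = seq_lens  # same tokens on the query and key/value side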
Examples
>>> import math
>>> import paddle
>>> from paddle.incubate.nn.functional import variable_length_memory_efficient_attention

>>> paddle.device.set_device('gpu')

>>> batch = 1
>>> num_head = 8
>>> seq_len = 256
>>> head_size = 32
>>> dtype = paddle.float16

>>> query = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype)
>>> key = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype)
>>> value = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype)
>>> seq_lens = paddle.to_tensor([seq_len] * batch, dtype='int32')
>>> mask = paddle.randn([batch, 1, seq_len, seq_len], dtype=dtype)
>>> scale = float(1.0 / math.sqrt(head_size))
>>> pre_cache_length = 0

>>> def naive_attention_impl(query, key, value, mask, scale):
...     qk_res = paddle.matmul(query, key, transpose_y=True)
...     attention = qk_res * scale
...     attention = attention + mask
...     softmax_result = paddle.nn.functional.softmax(attention, -1)
...     result = paddle.matmul(softmax_result, value)
...     return result

>>> out = naive_attention_impl(query, key, value, mask, scale)
>>> # Equivalent to:
>>> # out = variable_length_memory_efficient_attention(
>>> #     query, key, value, seq_lens, seq_lens, mask, scale,
>>> #     pre_cache_length=pre_cache_length)
>>> out.shape  # [batch, num_head, seq_len, head_size]
[1, 8, 256, 32]
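For a genuinely ragged batch, the call looks like the following sketch (requires a supported GPU; the lengths and the additive padding-mask construction are illustrative assumptions, not the kernel's mandated mask semantics):

>>> import math
>>> import paddle
>>> from paddle.incubate.nn.functional import variable_length_memory_efficient_attention
>>> paddle.device.set_device('gpu')
>>> batch, num_head, max_len, head_size = 2, 8, 256, 32
>>> dtype = paddle.float16
>>> query = paddle.randn([batch, num_head, max_len, head_size], dtype=dtype)
>>> key = paddle.randn([batch, num_head, max_len, head_size], dtype=dtype)
>>> value = paddle.randn([batch, num_head, max_len, head_size], dtype=dtype)
>>> # Second sequence only occupies 128 of the 256 padded positions
>>> seq_lens = paddle.to_tensor([256, 128], dtype='int32')
>>> # Additive mask: 0 on valid key positions, large negative on padding
>>> pos = paddle.arange(max_len, dtype='int32')
>>> valid = pos.unsqueeze(0) < seq_lens.unsqueeze(1)          # [batch, max_len]
>>> mask = ((1.0 - valid.astype(dtype)) * -1e4).reshape([batch, 1, 1, max_len])
>>> mask = mask.expand([batch, 1, max_len, max_len])
>>> scale = float(1.0 / math.sqrt(head_size))
>>> out = variable_length_memory_efficient_attention(
...     query, key, value, seq_lens, seq_lens, mask, scale)
>>> out.shape  # [batch, num_head, max_len, head_size]
[2, 8, 256, 32]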