MultiHeadAttention

class paddle.nn.MultiHeadAttention(embed_dim, num_heads, dropout=0.0, kdim=None, vdim=None, need_weights=False, weight_attr=None, bias_attr=None)

Attention maps queries and a set of key-value pairs to outputs. Multi-head attention runs several attention functions in parallel, which lets the model jointly attend to information from different representation subspaces.

Please refer to Attention Is All You Need for more details.
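To make the description above concrete, here is a minimal NumPy sketch of multi-head scaled dot-product attention. It is not Paddle's implementation: it omits biases, dropout and masking, and the names multi_head_attention, wq, wk, wv and wo are hypothetical. Each head attends independently over a head_dim = embed_dim // num_heads slice, with the 1/sqrt(head_dim) scaling from the paper cited above.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(query, key, value, wq, wk, wv, wo, num_heads):
    """Illustrative multi-head attention (no bias, dropout or mask).

    query: [batch, q_len, embed_dim]; key: [batch, k_len, kdim];
    value: [batch, k_len, vdim]. wq/wk/wv project to embed_dim and
    wo maps embed_dim back to embed_dim.
    """
    batch, q_len, embed_dim = query.shape
    head_dim = embed_dim // num_heads

    def split_heads(x):
        # [batch, len, embed_dim] -> [batch, num_heads, len, head_dim]
        b, l, _ = x.shape
        return x.reshape(b, l, num_heads, head_dim).transpose(0, 2, 1, 3)

    q = split_heads(query @ wq)
    k = split_heads(key @ wk)
    v = split_heads(value @ wv)

    # Scaled dot-product attention, computed independently per head.
    scores = (q * head_dim ** -0.5) @ k.transpose(0, 1, 3, 2)  # [b, h, q_len, k_len]
    weights = softmax(scores, axis=-1)
    out = weights @ v  # [b, h, q_len, head_dim]

    # Merge heads: [b, h, q_len, head_dim] -> [b, q_len, embed_dim]
    out = out.transpose(0, 2, 1, 3).reshape(batch, q_len, embed_dim)
    return out @ wo, weights

rng = np.random.default_rng(0)
embed_dim, num_heads = 8, 2
x = rng.standard_normal((2, 4, embed_dim))
w = [rng.standard_normal((embed_dim, embed_dim)) for _ in range(4)]
out, weights = multi_head_attention(x, x, x, *w, num_heads=num_heads)
print(out.shape, weights.shape)  # (2, 4, 8) (2, 2, 4, 4)
```

Splitting one embed_dim-wide projection into heads, rather than keeping separate small weight matrices per head, is a common way to run all heads with a single matrix multiply.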

Parameters
  • embed_dim (int) – The expected feature size in the input and output.

  • num_heads (int) – The number of heads in multi-head attention.

  • dropout (float, optional) – The dropout probability applied to the attention weights, randomly dropping some attention targets. 0 means no dropout. Default 0.0.

  • kdim (int, optional) – The feature size of key. If None, it is assumed to equal embed_dim. Default None.

  • vdim (int, optional) – The feature size of value. If None, it is assumed to equal embed_dim. Default None.

  • need_weights (bool, optional) – Whether to return the attention weights. Default False.

  • weight_attr (ParamAttr, optional) – Specifies the weight parameter attributes. Default: None, which means the default weight parameter attributes are used. See ParamAttr for usage details.

  • bias_attr (ParamAttr|bool, optional) – Specifies the bias parameter attributes. Default: None, which means the default bias parameter attributes are used. If set to False, this layer has no trainable bias parameter. See ParamAttr for usage details.

Examples

>>> import paddle

>>> # encoder input: [batch_size, sequence_length, d_model]
>>> query = paddle.rand((2, 4, 128))
>>> # self attention mask: [batch_size, num_heads, query_len, query_len]
>>> attn_mask = paddle.rand((2, 2, 4, 4))
>>> multi_head_attn = paddle.nn.MultiHeadAttention(128, 2)
>>> output = multi_head_attn(query, None, None, attn_mask=attn_mask)
>>> print(output.shape)
[2, 4, 128]
class Cache(k, v)
k

Alias for field number 0

v

Alias for field number 1

class StaticCache(k, v)
k

Alias for field number 0

v

Alias for field number 1

compute_kv(key, value)

Applies linear projections to the input keys and values, then splits heads (reshape and transpose) to get keys and values from different representation subspaces. The results are used as key-value pairs for the subsequent parallel attention computations.

It is part of the multi-head attention calculation and is exposed as a separate method so that these results can be pre-computed and prefetched, for example to construct a cache for inference.

Parameters
  • key (Tensor) – The keys for multi-head attention. It is a tensor with shape [batch_size, sequence_length, kdim]. The data type should be float32 or float64.

  • value (Tensor) – The values for multi-head attention. It is a tensor with shape [batch_size, sequence_length, vdim]. The data type should be float32 or float64.

Returns

Tuple. A tuple of the transformed keys and values. Both have shape [batch_size, num_heads, sequence_length, embed_dim // num_heads] and the same data type as the inputs.
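The following is a hypothetical NumPy mirror of compute_kv, showing how a linear projection followed by reshape and transpose yields the per-head layout described above. It is a sketch, not Paddle's code: the function name compute_kv_sketch and the weight and bias arguments (wk, bk, wv, bv) are assumptions, since the real layer holds its own projection parameters.

```python
import numpy as np

def compute_kv_sketch(key, value, wk, bk, wv, bv, num_heads):
    """Project keys/values, then split them into heads.

    key:   [batch, seq_len, kdim]   wk: [kdim, embed_dim], bk: [embed_dim]
    value: [batch, seq_len, vdim]   wv: [vdim, embed_dim], bv: [embed_dim]
    Returns k, v shaped [batch, num_heads, seq_len, embed_dim // num_heads].
    """
    embed_dim = wk.shape[1]
    head_dim = embed_dim // num_heads

    def split_heads(x):
        # [batch, seq_len, embed_dim] -> [batch, seq_len, num_heads, head_dim]
        # -> [batch, num_heads, seq_len, head_dim]
        b, l, _ = x.shape
        return x.reshape(b, l, num_heads, head_dim).transpose(0, 2, 1, 3)

    return split_heads(key @ wk + bk), split_heads(value @ wv + bv)

# kdim and vdim differ from embed_dim to show that both are projected to embed_dim.
rng = np.random.default_rng(0)
batch, seq_len, kdim, vdim, embed_dim, num_heads = 2, 5, 6, 10, 8, 4
key = rng.standard_normal((batch, seq_len, kdim))
value = rng.standard_normal((batch, seq_len, vdim))
wk, bk = rng.standard_normal((kdim, embed_dim)), np.zeros(embed_dim)
wv, bv = rng.standard_normal((vdim, embed_dim)), np.zeros(embed_dim)
k, v = compute_kv_sketch(key, value, wk, bk, wv, bv, num_heads)
print(k.shape, v.shape)  # (2, 4, 5, 2) (2, 4, 5, 2)
```

Head h of the result is simply columns h * head_dim to (h + 1) * head_dim of the projected tensor, which is why the split costs only a reshape and a transpose.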

gen_cache(key, value=None, type=MultiHeadAttention.Cache)

Generates a cache for use in forward during inference, according to the arguments. The generated cache is an instance of MultiHeadAttention.Cache or MultiHeadAttention.StaticCache.

Cache and StaticCache are namedtuples with k and v as fields. They store tensors shaped [batch_size, num_heads, length, embed_dim // num_heads], which are the results of the linear projection, reshape and transpose calculations in MultiHeadAttention.

If the generated cache is an instance of Cache, the k and v fields hold the intermediate result tensors of previous positions, and these tensors grow incrementally across decoding steps. This is mostly used for decoder self-attention.

If the generated cache is an instance of StaticCache, the k and v fields are used in forward as the precomputed results on keys and values, and the tensors stay unchanged across decoding steps. This is mostly used for decoder-encoder cross attention.

The cache is generated as follows:

1. If type is StaticCache, apply compute_kv(key, value) and use the results to create an instance of StaticCache.

2. If type is Cache and value is None, generate empty tensors shaped [batch_size, num_heads, 0, embed_dim // num_heads] and use the results to create an instance of Cache, where batch_size is from the first dimension of key.

3. If type is Cache and value is not None, use key and value directly to create an instance of Cache.
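The three cases above can be sketched in NumPy as follows. This is a hypothetical mirror of the branching, not Paddle's source: gen_cache_sketch, toy_compute_kv (an identity projection that only splits heads) and the cache_type argument (standing in for type, renamed to avoid shadowing the builtin) are illustrative names, and the namedtuples are local stand-ins for MultiHeadAttention.Cache and MultiHeadAttention.StaticCache.

```python
from collections import namedtuple
import numpy as np

# Local stand-ins for MultiHeadAttention.Cache / MultiHeadAttention.StaticCache.
Cache = namedtuple("Cache", ["k", "v"])
StaticCache = namedtuple("StaticCache", ["k", "v"])

def toy_compute_kv(key, value, num_heads=2):
    # Hypothetical compute_kv with an identity projection: only the head split.
    def split(x):
        b, l, d = x.shape
        return x.reshape(b, l, num_heads, d // num_heads).transpose(0, 2, 1, 3)
    return split(key), split(value)

def gen_cache_sketch(key, value=None, cache_type=Cache, num_heads=2, head_dim=4):
    if cache_type is StaticCache:
        # Case 1: project once; the result is reused unchanged at every decoding step.
        k, v = toy_compute_kv(key, value, num_heads)
        return StaticCache(k, v)
    if value is None:
        # Case 2: empty cache of length 0; key only supplies batch size and dtype.
        shape = (key.shape[0], num_heads, 0, head_dim)
        return Cache(np.zeros(shape, dtype=key.dtype), np.zeros(shape, dtype=key.dtype))
    # Case 3: key and value are used as given.
    return Cache(key, value)

enc_out = np.ones((3, 7, 8), dtype=np.float32)  # [batch, src_len, embed_dim]
static = gen_cache_sketch(enc_out, enc_out, StaticCache)
empty = gen_cache_sketch(enc_out)
given = gen_cache_sketch(static.k, static.v)
print(static.k.shape, empty.k.shape, given.k.shape)
# (3, 2, 7, 4) (3, 2, 0, 4) (3, 2, 7, 4)
```

Note that in case 3 the tensors are stored without any projection, so they are expected to already have the [batch_size, num_heads, length, embed_dim // num_heads] layout.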

Parameters
  • key (Tensor) – The keys for multi-head attention. It is a tensor with shape [batch_size, key_length, kdim]. The data type should be float32 or float64. If value is None, key is used only to determine the batch size and data type.

  • value (Tensor, optional) – The values for multi-head attention. It is a tensor with shape [batch_size, value_length, vdim]. The data type should be float32 or float64. If None, key is used only as a reference for the batch size and data type. Default None.

  • type (type) – Either MultiHeadAttention.StaticCache or MultiHeadAttention.Cache, indicating the cache type to generate. Default MultiHeadAttention.Cache.

Returns

An instance of Cache or StaticCache, according to type.

Return type

namedtuple

forward(query, key=None, value=None, attn_mask=None, cache=None)

Applies multi-head attention to map queries and a set of key-value pairs to outputs.

Parameters
  • query (Tensor) – The queries for multi-head attention. It is a tensor with shape [batch_size, query_length, embed_dim]. The data type should be float32 or float64.

  • key (Tensor, optional) – The keys for multi-head attention. It is a tensor with shape [batch_size, key_length, kdim]. The data type should be float32 or float64. If None, use query as key. Default None.

  • value (Tensor, optional) – The values for multi-head attention. It is a tensor with shape [batch_size, value_length, vdim]. The data type should be float32 or float64. If None, use query as value. Default None.

  • attn_mask (Tensor, optional) – A tensor used in multi-head attention to prevent attention to unwanted positions, usually paddings or subsequent positions. Its shape must be broadcastable to [batch_size, num_heads, query_length, key_length]. When the data type is bool, unwanted positions have False values and the others have True values. When the data type is int, unwanted positions have 0 values and the others have 1 values. When the data type is float, unwanted positions have -INF values and the others have 0 values. It can be None when no positions need to be masked. Default None.

  • cache (MultiHeadAttention.Cache|MultiHeadAttention.StaticCache, optional) – A namedtuple with k and v as fields that stores tensors shaped [batch_size, num_heads, length, embed_dim // num_heads], the results of the linear projection, reshape and transpose calculations in MultiHeadAttention. If it is an instance of Cache, the k and v fields hold intermediate results of previous positions; this is mostly used for decoder self-attention. If it is an instance of StaticCache, the key and value arguments are ignored and the k and v fields are used as the precomputed results on key and value; this is mostly used for decoder-encoder cross attention. It is only used for inference and should be None for training. Default None.
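The three attn_mask conventions above all amount to an additive mask applied to the attention scores before the softmax. The following NumPy sketch converts a bool or int mask to that form and applies a causal (subsequent-position) mask. to_additive_mask is a hypothetical helper, and using the finite value -1e9 in place of -INF for blocked positions is an assumption made to keep the arithmetic free of NaNs.

```python
import numpy as np

def to_additive_mask(attn_mask, dtype=np.float32):
    """Convert a mask to additive form: 0.0 = keep, large negative = block.

    bool/int masks use True/1 = keep and False/0 = block; float masks are
    assumed to be additive already (0 = keep, -inf = block).
    """
    if attn_mask.dtype == np.bool_ or np.issubdtype(attn_mask.dtype, np.integer):
        # keep -> 0.0, block -> -1e9 (a large finite stand-in for -INF)
        return (attn_mask.astype(dtype) - 1.0) * 1e9
    return attn_mask.astype(dtype)

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# Causal mask for a length-4 sequence, shaped [1, 1, query_length, key_length]
# so that it broadcasts over batch_size and num_heads.
causal = np.tril(np.ones((4, 4), dtype=bool))[None, None]
scores = np.zeros((2, 2, 4, 4), dtype=np.float32)  # [batch, num_heads, q_len, k_len]
weights = softmax(scores + to_additive_mask(causal))
print(weights[0, 0])  # row i spreads its weight evenly over positions 0..i
```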

Returns

Tensor|tuple. If need_weights is False and cache is None, a tensor with the same shape and data type as query, representing the attention output. Otherwise, a tuple whose first element is the attention output. If need_weights is True, the tuple also includes the attention weights tensor shaped [batch_size, num_heads, query_length, key_length]. If cache is not None, the tuple then includes the new cache, which has the same type as cache: a StaticCache is returned unchanged, while a new Cache holds tensors that concatenate the cached tensors with the intermediate results of the current query.
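The Cache behaviour described in the Returns section can be sketched as follows: at each decoding step the current step's keys and values are concatenated onto the cached tensors along the length axis, and the outputs are packed into a tuple only when extra items are requested. update_cache and pack_outputs are hypothetical NumPy helpers, not Paddle functions, and the ordering (output, then weights, then cache) is an assumption that follows the order in which the description lists them.

```python
from collections import namedtuple
import numpy as np

Cache = namedtuple("Cache", ["k", "v"])  # stand-in for MultiHeadAttention.Cache

def update_cache(cache, k_step, v_step):
    """Append the current step's keys/values along the length axis (axis 2)."""
    return Cache(np.concatenate([cache.k, k_step], axis=2),
                 np.concatenate([cache.v, v_step], axis=2))

def pack_outputs(out, weights, new_cache, need_weights):
    """Return out alone, or a tuple (out[, weights][, new_cache])."""
    outs = [out]
    if need_weights:
        outs.append(weights)
    if new_cache is not None:
        outs.append(new_cache)
    return out if len(outs) == 1 else tuple(outs)

batch, num_heads, head_dim = 2, 2, 4
empty = np.zeros((batch, num_heads, 0, head_dim))
cache = Cache(empty, empty)  # like an empty Cache from gen_cache(key) with value=None
rng = np.random.default_rng(0)
lengths = []
for step in range(3):
    # Hypothetical projected k/v for one new decoder token at this step.
    k_step = rng.standard_normal((batch, num_heads, 1, head_dim))
    v_step = rng.standard_normal((batch, num_heads, 1, head_dim))
    cache = update_cache(cache, k_step, v_step)
    lengths.append(cache.k.shape[2])
print(lengths)  # [1, 2, 3]

out = np.zeros((batch, 1, num_heads * head_dim))
result = pack_outputs(out, None, cache, need_weights=False)
print(type(result).__name__, len(result))  # tuple 2
```

Because the cache grows by concatenation, each step only projects the newest token instead of recomputing keys and values for the whole prefix.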