fused_rotary_position_embedding

paddle.incubate.nn.functional.fused_rotary_position_embedding(q, k=None, v=None, sin=None, cos=None, position_ids=None, use_neox_rotary_style=True)

Applies rotary position embedding (RoPE) to q and, if given, to k and v, using a single fused operator.

Parameters
  • q (Tensor) – The query tensor. The data type is bfloat16, float16, float32 or float64. The shape must be [batch_size, seq_len, num_heads, head_dim], and head_dim must be even.

  • k (Tensor, optional) – The key tensor. The data type is bfloat16, float16, float32 or float64. The shape must be [batch_size, seq_len, num_heads, head_dim], and head_dim must be even.

  • v (Tensor, optional) – The value tensor. The data type is bfloat16, float16, float32 or float64. The shape must be [batch_size, seq_len, num_heads, head_dim], and head_dim must be even.

  • sin (Tensor, optional) – The sine table. The data type is bfloat16, float16, float32 or float64. The shape must be [seq_len, head_dim] or [1, seq_len, 1, head_dim], and head_dim must be even.

  • cos (Tensor, optional) – The cosine table. The data type is bfloat16, float16, float32 or float64. The shape must be [seq_len, head_dim] or [1, seq_len, 1, head_dim], and head_dim must be even.

  • position_ids (Tensor, optional) – The position index of each token, used to index into sin and cos. The data type is int64 and the shape must be [batch_size, seq_len].

  • use_neox_rotary_style (bool, optional) – Selects how the elements along head_dim are paired for rotation. If True, each pair of adjacent elements (2i, 2i+1) is rotated together. If False, element i in the front half is paired with element i + head_dim/2 in the back half. Default: True.
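The two pairing modes described for use_neox_rotary_style can be illustrated with a minimal pure-Python sketch that rotates a single head vector. This is not Paddle's kernel: the function name rotate_vector and its adjacent_pairs flag are hypothetical, with adjacent_pairs=True mirroring the documented use_neox_rotary_style=True behaviour. The sketch also assumes both elements of a pair share one angle, so the angle is read from sin/cos at the pair's first index; the fused kernel's exact indexing and sign convention may differ.

```python
import math
from typing import List, Sequence


def rotate_vector(x: Sequence[float], sin: Sequence[float], cos: Sequence[float],
                  adjacent_pairs: bool = True) -> List[float]:
    """Hypothetical reference: apply a RoPE-style rotation to one head vector.

    x, sin and cos all have length head_dim (one token, one head).
    adjacent_pairs=True pairs elements (2i, 2i+1); False pairs (i, i + head_dim // 2).
    Assumes both members of a pair share one angle, read at the pair's first index.
    """
    d = len(x)
    if d % 2 != 0:
        raise ValueError("head_dim must be even")
    if len(sin) != d or len(cos) != d:
        raise ValueError("sin and cos must have length head_dim")
    half = d // 2
    out = [0.0] * d
    for i in range(half):
        a, b = (2 * i, 2 * i + 1) if adjacent_pairs else (i, i + half)
        s, c = sin[a], cos[a]
        # 2-D rotation of the pair (x[a], x[b]) by the angle with sine s and cosine c
        out[a] = x[a] * c - x[b] * s
        out[b] = x[b] * c + x[a] * s
    return out


theta = math.pi / 2
x = [1.0, 0.0, 0.0, 1.0]
sin_row, cos_row = [math.sin(theta)] * 4, [math.cos(theta)] * 4
print([round(v, 6) for v in rotate_vector(x, sin_row, cos_row, adjacent_pairs=True)])
# → [0.0, 1.0, -1.0, 0.0]
print([round(v, 6) for v in rotate_vector(x, sin_row, cos_row, adjacent_pairs=False)])
# → [0.0, -1.0, 1.0, 0.0]
```

The same input and angles give different outputs because the two modes rotate different element pairs; each pair keeps its Euclidean length, as any rotation does.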

Returns

out_q/out_k/out_v (Tensor): the results of applying rotary position embedding to q, k and v, each with the same shape and data type as q.

Examples

>>> import paddle
>>> from paddle.incubate.nn.functional import fused_rotary_position_embedding

>>> paddle.set_device('gpu')

>>> # batch_size = 2
>>> # seq_len = 2
>>> # num_heads = 2
>>> # head_dim = 2

>>> paddle.seed(1204)

>>> # q, k, v: [batch_size, seq_len, num_heads, head_dim]
>>> q = paddle.randn([2, 2, 2, 2], dtype='float16')
>>> k = paddle.randn([2, 2, 2, 2], dtype='float16')
>>> v = paddle.randn([2, 2, 2, 2], dtype='float16')

>>> # sin, cos: [1, seq_len, 1, head_dim]
>>> x = paddle.randn([1, 2, 1, 2], dtype='float16')
>>> y = paddle.randn([1, 2, 1, 2], dtype='float16')
>>> sin = paddle.sin(x)
>>> cos = paddle.cos(y)

>>> # position_ids: [batch_size, seq_len]
>>> position_ids = paddle.randint(high=2, shape=[2, 2], dtype='int64')

>>> # out_q, out_k, out_v: [batch_size, seq_len, num_heads, head_dim]
>>> out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v, sin=sin, cos=cos, position_ids=position_ids, use_neox_rotary_style=False)
>>> print(out_q)
Tensor(shape=[2, 2, 2, 2], dtype=float16, place=Place(gpu:0), stop_gradient=True,
[[[[-0.54931641,  0.64990234],
   [-1.08691406,  1.18261719]],
  [[ 0.57812500,  0.11749268],
   [-0.63281250,  0.15551758]]],
 [[[-0.77050781,  0.07733154],
   [-0.73730469, -0.16735840]],
  [[ 0.07116699, -0.90966797],
   [-0.03628540, -0.20202637]]]])
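The example above fills sin and cos with random values. In practice these tables usually hold angles from a fixed frequency schedule. The pure-Python sketch below builds such tables in the documented [seq_len, head_dim] layout and shows how position_ids could pick one row per token. The helper names build_rope_tables and gather_rows, the base of 10000, and the two row layouts (each angle repeated for an adjacent pair, or the half-length angle list repeated for the front and back halves) are assumptions taken from the common RoPE formulation, not from this page. It is likewise assumed that position_ids[b][s] selects the sin/cos row used for batch b, position s.

```python
import math
from typing import List, Tuple

Table = List[List[float]]


def build_rope_tables(seq_len: int, head_dim: int, base: float = 10000.0,
                      adjacent_pairs: bool = True) -> Tuple[Table, Table]:
    """Hypothetical helper: sin/cos tables of shape [seq_len, head_dim]."""
    if head_dim % 2 != 0:
        raise ValueError("head_dim must be even")
    half = head_dim // 2
    # Assumed standard RoPE frequencies: base ** (-2i / head_dim), i = 0..half-1
    inv_freq = [base ** (-2.0 * i / head_dim) for i in range(half)]
    sin_rows: Table = []
    cos_rows: Table = []
    for pos in range(seq_len):
        angles = [pos * f for f in inv_freq]
        if adjacent_pairs:
            # Layout [t0, t0, t1, t1, ...]: pair (2i, 2i+1) shares angle t_i
            row = [t for t in angles for _ in range(2)]
        else:
            # Layout [t0, t1, ..., t0, t1, ...]: elements i and i + half share t_i
            row = angles + angles
        sin_rows.append([math.sin(t) for t in row])
        cos_rows.append([math.cos(t) for t in row])
    return sin_rows, cos_rows


def gather_rows(table: Table, position_ids: List[List[int]]) -> List[Table]:
    """Hypothetical helper: pick table[position_ids[b][s]] for every (b, s)."""
    return [[table[p] for p in row] for row in position_ids]


sin_table, cos_table = build_rope_tables(seq_len=4, head_dim=4)
per_token_sin = gather_rows(sin_table, [[0, 1, 2, 3], [3, 2, 1, 0]])
print(len(per_token_sin), len(per_token_sin[0]), len(per_token_sin[0][0]))  # → 2 4 4
```

Position 0 always yields sin = 0 and cos = 1, so tokens at that position are left unchanged by the rotation.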