stack

paddle.stack(x, axis=0, name=None)

Stacks all the input tensors x along the axis dimension. All tensors must have the same shape and the same dtype.

For example, given N tensors of shape [A, B]: if axis == 0, the stacked tensor has shape [N, A, B]; if axis == 1, it has shape [A, N, B]; and so on.
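The shape rule above can be sketched in plain Python, without Paddle. The helper name `stacked_shape` is hypothetical and is not part of the Paddle API; this sketch assumes a non-negative axis no larger than the number of input dimensions.

```python
def stacked_shape(input_shape, num_tensors, axis=0):
    """Return the shape obtained by stacking `num_tensors` tensors of `input_shape`.

    Hypothetical helper for illustration only; assumes 0 <= axis <= len(input_shape).
    """
    if not 0 <= axis <= len(input_shape):
        raise ValueError("axis out of range for this sketch")
    # A new dimension of size N is inserted at position `axis`;
    # the existing dimensions keep their relative order.
    return list(input_shape[:axis]) + [num_tensors] + list(input_shape[axis:])


print(stacked_shape([4, 5], 3, axis=0))  # [3, 4, 5]
print(stacked_shape([4, 5], 3, axis=1))  # [4, 3, 5]
print(stacked_shape([4, 5], 3, axis=2))  # [4, 5, 3]
```

Here [4, 5] plays the role of [A, B] and 3 the role of N.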

Case 1:

  Input:
    x[0].shape = [1, 2]
    x[0].data = [ [1.0 , 2.0 ] ]
    x[1].shape = [1, 2]
    x[1].data = [ [3.0 , 4.0 ] ]
    x[2].shape = [1, 2]
    x[2].data = [ [5.0 , 6.0 ] ]

  Attrs:
    axis = 0

  Output:
    Out.shape = [3, 1, 2]
    Out.data = [ [ [1.0, 2.0] ],
                 [ [3.0, 4.0] ],
                 [ [5.0, 6.0] ] ]


Case 2:

  Input:
    x[0].shape = [1, 2]
    x[0].data = [ [1.0 , 2.0 ] ]
    x[1].shape = [1, 2]
    x[1].data = [ [3.0 , 4.0 ] ]
    x[2].shape = [1, 2]
    x[2].data = [ [5.0 , 6.0 ] ]


  Attrs:
    axis = 1 or axis = -2  # If axis = -2, axis = axis+ndim(x[0])+1 = -2+2+1 = 1.

  Output:
    Out.shape = [1, 3, 2]
    Out.data = [ [ [1.0, 2.0],
                   [3.0, 4.0],
                   [5.0, 6.0] ] ]
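
The two cases can be reproduced element by element with nested Python lists. This is only a sketch of the semantics, not how Paddle implements the operation: `stack_nested` is a hypothetical name, and it assumes a non-negative axis and equally shaped inputs.

```python
def stack_nested(tensors, axis=0):
    """Stack equally shaped nested lists along `axis` (illustrative sketch only)."""
    if axis == 0:
        # The new outermost dimension simply enumerates the inputs.
        return list(tensors)
    # Otherwise descend one level into every input and stack one axis lower.
    return [
        stack_nested([t[i] for t in tensors], axis - 1)
        for i in range(len(tensors[0]))
    ]


x = [[[1.0, 2.0]], [[3.0, 4.0]], [[5.0, 6.0]]]  # three inputs of shape [1, 2]

print(stack_nested(x, axis=0))  # [[[1.0, 2.0]], [[3.0, 4.0]], [[5.0, 6.0]]]
print(stack_nested(x, axis=1))  # [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]]
```
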
Parameters
  • x (list[Tensor]|tuple[Tensor]) – A list or tuple of tensors. All tensors in x must have the same shape and dtype. Supported data types: float32, float64, int32, int64.

  • axis (int, optional) – The axis along which all inputs are stacked. Valid values lie in the range [-(R+1), R+1), where R is the number of dimensions of the first input tensor x[0]. If axis < 0, it is interpreted as axis + R + 1. The default value is 0.

  • name (str, optional) – Name for the operation. The default value is None. For more information, please refer to Name.
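
The axis rule in the parameter list can be checked with a small plain-Python sketch. The function name `normalize_stack_axis` is hypothetical and not part of Paddle; it only mirrors the documented range [-(R+1), R+1) and the documented conversion axis + R + 1 for negative values.

```python
def normalize_stack_axis(axis, ndim):
    """Map a possibly negative stack axis to its non-negative equivalent.

    Hypothetical helper; `ndim` is R, the number of dimensions of x[0].
    """
    if not -(ndim + 1) <= axis < ndim + 1:
        raise ValueError(
            f"axis must be in [{-(ndim + 1)}, {ndim + 1}), got {axis}"
        )
    # Negative values count from the end of the *output* shape,
    # which has one more dimension than the inputs.
    return axis + ndim + 1 if axis < 0 else axis


# For 2-D inputs (R = 2) the valid range is [-3, 3).
print(normalize_stack_axis(-2, 2))  # 1
print(normalize_stack_axis(-3, 2))  # 0
print(normalize_stack_axis(2, 2))   # 2
```

With R = 2, axis = 3 falls outside the range and raises ValueError in this sketch.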

Returns

The stacked tensor, with the same data type as the input tensors.

Return type

Tensor

Example

import paddle

x1 = paddle.to_tensor([[1.0, 2.0]])
x2 = paddle.to_tensor([[3.0, 4.0]])
x3 = paddle.to_tensor([[5.0, 6.0]])

out = paddle.stack([x1, x2, x3], axis=0)
print(out.shape)  # [3, 1, 2]
print(out)
# [[[1., 2.]],
#  [[3., 4.]],
#  [[5., 6.]]]

out = paddle.stack([x1, x2, x3], axis=-2)
print(out.shape)  # [1, 3, 2]
print(out)
# [[[1., 2.],
#   [3., 4.],
#   [5., 6.]]]