tensor_split
paddle.tensor_split(x, num_or_indices, axis=0, name=None)
Split the input tensor into multiple sub-Tensors along axis; the sub-Tensors do not need to be of equal size.
Parameters
x (Tensor) – A Tensor whose number of dimensions must be greater than 0. The data type is bool, bfloat16, float16, float32, float64, uint8, int32 or int64.
num_or_indices (int|list|tuple) – If num_or_indices is an int n, x is split into n sections along axis. If x.shape[axis] is divisible by n, each section has size x.shape[axis] / n. If x.shape[axis] is not divisible by n, the first int(x.shape[axis] % n) sections have size int(x.shape[axis] / n) + 1, and the rest have size int(x.shape[axis] / n). If num_or_indices is a list or tuple of integer indices, x is split along axis at each of the indices. For instance, num_or_indices=[2, 4] with axis=0 splits x into x[:2], x[2:4] and x[4:] along axis 0.
axis (int|Tensor, optional) – The axis along which to split. It can be an integer or a 0-D Tensor with shape [] and data type int32 or int64. If \(axis < 0\), the axis to split along is \(rank(x) + axis\). Default is 0.
name (str, optional) – The default value is None. Normally there is no need for the user to set this property. For more information, please refer to Name.
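The sizing rule for num_or_indices can be checked without Paddle. The sketch below is a minimal pure-Python illustration on a plain list, not Paddle's implementation; split_sizes and split_list are hypothetical helpers, and the index branch assumes non-negative, ascending indices within range.

```python
from typing import List, Sequence, Union


def split_sizes(dim_size: int, n: int) -> List[int]:
    """Section lengths for an integer num_or_indices (hypothetical helper).

    The first dim_size % n sections get one extra element; the rest get
    dim_size // n elements, matching the rule described for num_or_indices.
    """
    if n <= 0:
        raise ValueError("n must be a positive integer")
    base, extra = divmod(dim_size, n)
    return [base + 1 if i < extra else base for i in range(n)]


def split_list(x: Sequence, num_or_indices: Union[int, Sequence[int]]) -> List[list]:
    """Hypothetical 1-D analogue of paddle.tensor_split on a Python sequence."""
    if isinstance(num_or_indices, int):
        # Turn the section lengths into cumulative cut points.
        cuts, start = [], 0
        for size in split_sizes(len(x), num_or_indices)[:-1]:
            start += size
            cuts.append(start)
    else:
        # Assumption: indices are non-negative, ascending and within range.
        cuts = list(num_or_indices)
    bounds = [0] + cuts + [len(x)]
    # Each piece is x[bounds[k]:bounds[k + 1]], i.e. x[:c0], x[c0:c1], ..., x[cN:].
    return [list(x[a:b]) for a, b in zip(bounds, bounds[1:])]


data = list(range(8))
print(split_sizes(8, 3))         # [3, 3, 2]
print(split_list(data, 3))       # [[0, 1, 2], [3, 4, 5], [6, 7]]
print(split_list(data, [2, 4]))  # [[0, 1], [2, 3], [4, 5, 6, 7]]
```

For an integer argument this is the same rule that NumPy documents for numpy.array_split, which can be a convenient cross-check when porting code.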
Returns
list[Tensor], the list of segmented Tensors.
Examples
>>> import paddle

>>> # x is a Tensor of shape [8]
>>> # evenly split
>>> x = paddle.rand([8])
>>> out0, out1 = paddle.tensor_split(x, num_or_indices=2)
>>> print(out0.shape)
[4]
>>> print(out1.shape)
[4]

>>> # not evenly split
>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=3)
>>> print(out0.shape)
[3]
>>> print(out1.shape)
[3]
>>> print(out2.shape)
[2]

>>> # split with indices
>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=[2, 3])
>>> print(out0.shape)
[2]
>>> print(out1.shape)
[1]
>>> print(out2.shape)
[5]

>>> # split along axis
>>> # x is a Tensor of shape [7, 8]
>>> x = paddle.rand([7, 8])
>>> out0, out1 = paddle.tensor_split(x, num_or_indices=2, axis=1)
>>> print(out0.shape)
[7, 4]
>>> print(out1.shape)
[7, 4]
>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=[2, 3], axis=1)
>>> print(out0.shape)
[7, 2]
>>> print(out1.shape)
[7, 1]
>>> print(out2.shape)
[7, 5]
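The shapes printed in these examples can be predicted from x.shape alone. The following is a sketch under stated assumptions: tensor_split_shapes is a hypothetical pure-Python helper (not part of the Paddle API) that applies the sizing rule and the rank(x) + axis normalization for negative axis, assuming non-negative, ascending indices that do not exceed x.shape[axis].

```python
from typing import List, Sequence, Union


def tensor_split_shapes(shape: Sequence[int],
                        num_or_indices: Union[int, Sequence[int]],
                        axis: int = 0) -> List[List[int]]:
    """Predict output shapes of tensor_split (hypothetical helper)."""
    rank = len(shape)
    if rank == 0:
        raise ValueError("x must have at least one dimension")
    if axis < 0:
        axis += rank  # rank(x) + axis, as described for the axis parameter
    if not 0 <= axis < rank:
        raise ValueError(f"axis out of range for rank {rank}")
    dim = shape[axis]
    if isinstance(num_or_indices, int):
        # First dim % n sections get one extra element.
        base, extra = divmod(dim, num_or_indices)
        lengths = [base + 1 if i < extra else base for i in range(num_or_indices)]
    else:
        # Assumption: non-negative, ascending indices that do not exceed dim.
        bounds = [0] + list(num_or_indices) + [dim]
        lengths = [b - a for a, b in zip(bounds, bounds[1:])]
    out = []
    for length in lengths:
        s = list(shape)
        s[axis] = length  # only the split axis changes size
        out.append(s)
    return out


print(tensor_split_shapes([8], 3))                   # [[3], [3], [2]]
print(tensor_split_shapes([7, 8], [2, 3], axis=1))   # [[7, 2], [7, 1], [7, 5]]
print(tensor_split_shapes([7, 8], 2, axis=-1))       # [[7, 4], [7, 4]]
```

The last call uses axis=-1, which resolves to axis 1 for a rank-2 input, so it reproduces the shapes of the axis=1 example above.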
