tensor_split
将输入 Tensor 沿着轴 axis 分割成多个子 Tensor,允许进行不等长地分割。
如下图,Tenser x 的 shape 为[6],经过 paddle.tensor_split(x, num_or_indices=4) 变换后,得到 out0,out1,out2,out3 四个子 Tensor :
其中,由于 x 在 axis = 0 方向上的长度 6 不能被 num_or_indices = 4 整除,故分割后前 int(6 % 4) 个部分的大小将是 int(6 / 4) + 1 ,其余部分的大小将是 int(6 / 4) 。
参数
x (Tensor) - 输入变量,数据类型为 bool、bfloat16、float16、float32、float64、uint8、int8、int32、int64 的多维 Tensor,其维度必须大于 0。 别名:
inputnum_or_indices (int|list|tuple) - 如果
num_or_indices是一个整数n,则x沿axis拆分为n部分。如果x可被n整除,则每个部分都是x.shape[axis]/n。如果x不能被n整除,则前int(x.shape[axis]%n)个部分的大小将是int(x.shape[axis]/n)+1,其余部分的大小将是int(x.shape[axis]/n)。如果num_or_indices是整数索引的列表或元组,则在每个索引处沿axis分割x。例如,num_or_indices=[2, 4]在axis=0时将沿轴 0 将x拆分为x[:2]、x[2:4]和x[4:]。 别名:indices(类型为 list 或 tuple 时) 或sections(类型为 int 时)axis (int|Tensor,可选) - 整数或者形状为[]的 0-D Tensor,数据类型为 int32 或 int64。表示需要分割的维度。如果
axis < 0,则划分的维度为rank(x) + axis。默认值为 0。 别名:dimname (str,可选) - 具体用法请参见 api_guide_Name,一般无需设置,默认值为 None。
返回
list[Tensor],分割后的 Tensor 列表。
代码示例 1
>>> import paddle
>>> # evenly split
>>> # x is a Tensor of shape [8]
>>> x = paddle.rand([8])
>>> out0, out1 = paddle.tensor_split(x, num_or_indices=2)
>>> print(out0.shape)
[4]
>>> print(out1.shape)
[4]
代码示例 2
>>> import paddle
>>> # not evenly split
>>> # x is a Tensor of shape [8]
>>> x = paddle.rand([8])
>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=3)
>>> print(out0.shape)
[3]
>>> print(out1.shape)
[3]
>>> print(out2.shape)
[2]
代码示例 3
>>> import paddle
>>> # split with indices
>>> # x is a Tensor of shape [8]
>>> x = paddle.rand([8])
>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=[2, 3])
>>> print(out0.shape)
[2]
>>> print(out1.shape)
[1]
>>> print(out2.shape)
[5]
代码示例 4
>>> import paddle
>>> # split along axis
>>> # x is a Tensor of shape [7, 8]
>>> x = paddle.rand([7, 8])
>>> out0, out1 = paddle.tensor_split(x, num_or_indices=2, axis=1)
>>> print(out0.shape)
[7, 4]
>>> print(out1.shape)
[7, 4]