split

paddle.compat. split ( tensor, split_size_or_sections, dim=0 ) [源代码]

PyTorch 兼容的 split 版本,允许了非整除的 split_size_or_sections 输入

使用前请详细参考:【输入参数用法不一致】torch.split 以确定是否使用此模块。

备注

此 API 遵循 torch.split 的函数签名和行为以实现 PyTorch 兼容。 如需使用 Paddle 原生实现,请参考 split

参数

  • tensor (Tensor) - 输入 N 维 Tensor,支持 bool、bfloat16、float16、float32、float64、uint8、int8、int32 或 int64 数据类型

  • split_size_or_sections (int|list|tuple) - 若为整数,则将 Tensor 均匀分割为指定大小的块,与 split 不同,本 API 不要求此参数整除对应维度的通道数:非整除情况下输出元组的最后一个 tensor 对应维度将为余数大小,小于此值。若为列表/元组,则按指定尺寸分割,禁止使用负值(例如对 9 通道的维度,[2,3,-1] 会被拒绝)

  • dim (int|Tensor, 可选) - 分割维度,可为整数或形状为[]的 0-D Tensor(数据类型需为 int32int64)。若 dim < 0,则实际维度为 rank(x) + dim。默认值:0

返回

tuple(Tensor),分割后的 Tensor 元组

代码示例

>>> import paddle

>>> # x is a Tensor of shape [3, 8, 5]
>>> x = paddle.rand([3, 8, 5])

>>> out0, out1, out2 = paddle.compat.split(x, split_size_or_sections=3, dim=1)
>>> print(out0.shape)
[3, 3, 5]
>>> print(out1.shape)
[3, 3, 5]
>>> print(out2.shape)
[3, 2, 5]

>>> out0, out1, out2 = paddle.compat.split(x, split_size_or_sections=[1, 2, 5], dim=1)
>>> print(out0.shape)
[3, 1, 5]
>>> print(out1.shape)
[3, 2, 5]
>>> print(out2.shape)
[3, 5, 5]

>>> # dim is negative, the real dim is (rank(x) + dim)=1
>>> out0, out1, out2 = paddle.compat.split(x, split_size_or_sections=3, dim=-2)
>>> print(out0.shape)
[3, 3, 5]
>>> print(out1.shape)
[3, 3, 5]
>>> print(out2.shape)
[3, 2, 5]