[ 参数不一致 ]torch.split

torch.split

torch.split(tensor,
            split_size_or_sections,
            dim=0)

paddle.split

paddle.split(x,
             num_or_sections,
             axis=0,
             name=None)

其中 Pytorch 的 split_size_or_sections 与 Paddle 的 num_or_sections 用法不一致,具体如下:

参数映射

| PyTorch | PaddlePaddle | 备注 | | ————- | ———— | —————————————————— | | tensor | x | 表示输入 Tensor ,仅参数名不一致。 | | split_size_or_sections| num_or_sections| 当类型为 int 时,torch 表示单个块大小,paddle 表示结果有多少个块,需要转写。 | | dim | axis | 表示需要分割的维度,仅参数名不一致。 |

转写示例

split_size_or_sections:单个块大小

split_size = 2
dim = 1
# Pytorch 写法
torch.split(a, split_size, dim)
# 在输入 dim 时,返回 (values, indices)

# Paddle 写法
paddle.split(a, a.shape[dims]/split_size, dim)