[torch 参数更多 ]torch.gather¶
torch.gather¶
torch.gather(input,
dim,
index,
*,
sparse_grad=False,
out=None)
paddle.take_along_axis¶
paddle.take_along_axis(arr,
indices,
axis,
broadcast=True)
PyTorch 相比 Paddle 支持更多其他参数,具体如下:
参数映射¶
PyTorch | PaddlePaddle | 备注 |
---|---|---|
input | x | 表示输入 Tensor ,仅参数名不一致。 |
dim | axis | 用于指定 index 获取输入的维度,仅参数名不一致。 |
index | indices | 聚合元素的索引矩阵,维度和输入 (input) 的维度一致,仅参数名不一致。 |
sparse_grad | - | 表示是否对梯度稀疏化,Paddle 无此参数,一般对网络训练结果影响不大,可直接删除。 |
out | - | 表示目标 Tensor , Paddle 无此参数,需要转写。 |