[ 仅参数名不一致 ]torch.Tensor.take_along_dim

torch.Tensor.take_along_dim

torch.Tensor.take_along_dim(indices,
                    dim)

paddle.Tensor.take_along_axis

paddle.Tensor.take_along_axis(indices,
                    axis,
                    broadcast=True)

两者功能一致,参数名不一致,具体如下:

参数映射

| PyTorch | PaddlePaddle | 备注 | | ————- | ———— | —————————————————— | | indices | indices | 索引矩阵,包含沿轴提取 1d 切片的下标,必须和 arr 矩阵有相同的维度,需要能够 broadcast 与 arr 矩阵对齐,数据类型为:int、int64。 | | dim | axis | 指定沿着哪个维度获取对应的值,数据类型为:int,仅参数名不一致。 |