[ 仅 paddle 参数更多 ]torch.utils.data.BatchSampler

torch.utils.data.BatchSampler

torch.utils.data.BatchSampler(sampler,
                              batch_size,
                              drop_last)

paddle.io.BatchSampler

paddle.io.BatchSampler(dataset=None,
                       sampler=None,
                       shuffle=Fasle,
                       batch_size=1,
                       drop_last=False)

其中 Paddle 相比 Pytorch 支持更多其他参数,具体如下:

参数映射

| PyTorch | PaddlePaddle | 备注 | | ——— | ————-| ———- | | sampler | sampler | 底层取样器,PyTorch 可为 Sampler 或 Iterable 数据类型,Paddle 可为 Sampler 数据类型。 | | - | dataset | 此参数必须是 paddle.io.Dataset 或 paddle.io.IterableDataset 的一个子类实例或实现了 len 的 Python 对象,用于生成样本下标,PyTorch 无此参数,Paddle 保持默认即可。 | | - | shuffle | 是否需要在生成样本下标时打乱顺序,PyTorch 无此参数,Paddle 保持默认即可。 |

转写示例

sampler(Iterable):底层取样器

# 若 sampler 为 Iterable 数据类型,则需要按如下方式转写
# Pytorch 写法
torch.utils.data.BatchSampler(sampler=[1., 2., 3., 4.], 3, False)

# Paddle 写法
sampler = paddle.io.Sampler([1., 2., 3., 4.])
paddle.io.BatchSampler(sampler, 3, False)