[ 仅 paddle 参数更多 ]torch.utils.data.BatchSampler

torch.utils.data.BatchSampler

torch.utils.data.BatchSampler(sampler,
                              batch_size,
                              drop_last)

paddle.io.BatchSampler

paddle.io.BatchSampler(dataset=None,
                       sampler=None,
                       shuffle=Fasle,
                       batch_size=1,
                       drop_last=False)

其中 Paddle 相比 PyTorch 支持更多其他参数,具体如下:

参数映射

PyTorch PaddlePaddle 备注
sampler sampler 底层取样器,PyTorch 可为 Sampler 或 Iterable 数据类型,Paddle 可为 Sampler 数据类型。
- dataset 此参数必须是 paddle.io.Dataset 或 paddle.io.IterableDataset 的一个子类实例或实现了 len 的 Python 对象,用于生成样本下标,PyTorch 无此参数,Paddle 保持默认即可。
- shuffle 是否需要在生成样本下标时打乱顺序,PyTorch 无此参数,Paddle 保持默认即可。

转写示例

sampler(Iterable):底层取样器

# 若 sampler 为 Iterable 数据类型,则需要按如下方式转写
# PyTorch 写法
torch.utils.data.BatchSampler(sampler=[1., 2., 3., 4.], 3, False)

# Paddle 写法
sampler = paddle.io.Sampler([1., 2., 3., 4.])
paddle.io.BatchSampler(sampler, 3, False)