## [ 仅 API 调用方式不一致 ]torch.utils.data.Sampler
### torch.utils.data.Sampler
```python
torch.utils.data.Sampler(data_source)
```
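### 转写示例

两者功能一致，区别仅在于子类的写法：Paddle 子类调用 `super().__init__(data_source)`，由基类保存 `self.data_source`；PyTorch 子类则需自行保存 `self.data_source`。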
# PyTorch version
class MySampler(torch.utils.data.Sampler):
    """Sampler that yields the indices 0 .. len(data_source) - 1 in order."""

    def __init__(self, data_source):
        # Unlike the Paddle version, the subclass stores data_source itself
        # instead of passing it to the base-class constructor.
        self.data_source = data_source

    def __iter__(self):
        """Return an iterator over sequential indices into data_source."""
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        """Return the number of indices produced (the size of data_source)."""
        return len(self.data_source)
# Paddle version
class MySampler(paddle.io.Sampler):
    """Sampler that yields the indices 0 .. len(data_source) - 1 in order."""

    def __init__(self, data_source):
        # data_source is handed to the base class; the methods below read it
        # back as self.data_source without assigning it here.
        super().__init__(data_source)

    def __iter__(self):
        """Return an iterator over sequential indices into data_source."""
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        """Return the number of indices produced (the size of data_source)."""
        return len(self.data_source)