Imdb
该类是对 IMDB 测试数据集的实现。
参数
data_file (str) - 保存压缩数据的路径,如果参数
download设置为 True,可设置为 None。默认为 None。mode (str) - 'train' 或'test' 模式。默认为'train'。
cutoff (int) - 构建词典的截止大小。默认为 Default 150。
download (bool) - 如果
data_file未设置,是否自动下载数据集。默认为 True。
返回
Dataset, IMDB 数据集实例。
代码示例
>>> import paddle
>>> from paddle.text.datasets import Imdb
>>> class SimpleNet(paddle.nn.Layer):
... def __init__(self):
... super().__init__()
...
... def forward(self, doc, label):
... return paddle.sum(doc), label
>>> imdb = Imdb(mode='train')
>>> for i in range(10):
... doc, label = imdb[i]
... doc = paddle.to_tensor(doc)
... label = paddle.to_tensor(label)
...
... model = SimpleNet()
... image, label = model(doc, label)
... print(doc.shape, label.shape)
paddle.Size([121]) paddle.Size([1])
paddle.Size([115]) paddle.Size([1])
paddle.Size([386]) paddle.Size([1])
paddle.Size([471]) paddle.Size([1])
paddle.Size([585]) paddle.Size([1])
paddle.Size([206]) paddle.Size([1])
paddle.Size([221]) paddle.Size([1])
paddle.Size([324]) paddle.Size([1])
paddle.Size([166]) paddle.Size([1])
paddle.Size([598]) paddle.Size([1])