WMT14
该类是对 WMT14 测试数据集实现。 由于原始 WMT14 数据集太大,我们在这里提供了一组小数据集。该类将从 http://paddlemodels.bj.bcebos.com/wmt/wmt14.tgz 下载数据集。
参数
data_file (str)- 保存数据集压缩文件的路径,如果参数
download设置为 True,可设置为 None。默认为 None。mode (str)- 'train','test' 或'gen'。默认为'train'。
dict_size (int)- 词典大小。默认为-1。
download (bool)- 如果
data_file未设置,是否自动下载数据集。默认为 True。
返回值
Dataset,WMT14 数据集实例。
src_ids (np.array) - 源语言当前的 token id 序列。
trg_ids (np.array) - 目标语言当前的 token id 序列。
trg_ids_next (np.array) - 目标语言下一段的 token id 序列。
代码示例
>>> import paddle
>>> from paddle.text.datasets import WMT14
>>> class SimpleNet(paddle.nn.Layer):
... def __init__(self):
... super().__init__()
...
... def forward(self, src_ids, trg_ids, trg_ids_next):
... return paddle.sum(src_ids), paddle.sum(trg_ids), paddle.sum(trg_ids_next)
>>> wmt14 = WMT14(mode='train', dict_size=50)
>>> for i in range(10):
... src_ids, trg_ids, trg_ids_next = wmt14[i]
... src_ids = paddle.to_tensor(src_ids)
... trg_ids = paddle.to_tensor(trg_ids)
... trg_ids_next = paddle.to_tensor(trg_ids_next)
...
... model = SimpleNet()
... src_ids, trg_ids, trg_ids_next = model(src_ids, trg_ids, trg_ids_next)
... print(src_ids.item(), trg_ids.item(), trg_ids_next.item())
91 38 39
123 81 82
556 229 230
182 26 27
447 242 243
116 110 111
403 288 289
258 221 222
136 34 35
281 136 137