WMT14

class paddle.text.WMT14(data_file=None, mode='train', dict_size=-1, download=True)

Implementation of the WMT14 test dataset. The original WMT14 dataset is too large, so a small subset of the data is provided. This module downloads the dataset from http://paddlemodels.bj.bcebos.com/wmt/wmt14.tgz .

Parameters
  • data_file (str) – path to the data tar file. Can be None if download is True. Default None

  • mode (str) – ‘train’, ‘test’ or ‘gen’. Default ‘train’

  • dict_size (int) – word dictionary size. Default -1.

  • download (bool) – whether to download dataset automatically if data_file is not set. Default True

Returns

Instance of WMT14 dataset. Each sample is a tuple of:
  • src_ids (np.array) - The sequence of token ids of source language.

  • trg_ids (np.array) - The sequence of token ids of target language.

  • trg_ids_next (np.array) - The next sequence of token ids of target language.

Return type

Dataset

Examples

>>> import paddle
>>> from paddle.text.datasets import WMT14

>>> class SimpleNet(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...
...     def forward(self, src_ids, trg_ids, trg_ids_next):
...         return paddle.sum(src_ids), paddle.sum(trg_ids), paddle.sum(trg_ids_next)

>>> wmt14 = WMT14(mode='train', dict_size=50)
>>> model = SimpleNet()  # create the model once, outside the loop

>>> for i in range(10):
...     # Each sample is a tuple of three NumPy arrays.
...     src_ids, trg_ids, trg_ids_next = wmt14[i]
...     src_ids = paddle.to_tensor(src_ids)
...     trg_ids = paddle.to_tensor(trg_ids)
...     trg_ids_next = paddle.to_tensor(trg_ids_next)
...
...     # SimpleNet returns the sum of the token ids in each sequence.
...     src_sum, trg_sum, trg_next_sum = model(src_ids, trg_ids, trg_ids_next)
...     print(src_sum.item(), trg_sum.item(), trg_next_sum.item())
91 38 39
123 81 82
556 229 230
182 26 27
447 242 243
116 110 111
403 288 289
258 221 222
136 34 35
281 136 137
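
In the output above, each trg_ids_next sum is exactly one larger than the matching trg_ids sum. That is consistent with trg_ids being the target sequence prefixed with a start token and trg_ids_next being the same sequence suffixed with an end token, shifted by one position. The sketch below illustrates that relationship in plain Python. The ids START_ID = 0 and END_ID = 1 and the helper make_target_pair are assumptions chosen for illustration; they are not part of the paddle.text API, and the real ids come from the dataset dictionaries.

```python
# Hypothetical ids for illustration only; not read from the WMT14 dictionaries.
START_ID = 0  # assumed id of the start-of-sequence token
END_ID = 1    # assumed id of the end-of-sequence token


def make_target_pair(target_token_ids):
    """Build a decoder input and its shifted-by-one label sequence.

    Hypothetical helper: it mirrors the relationship suggested by the
    example output, not the library's actual loading code.
    """
    tokens = list(target_token_ids)
    trg_ids = [START_ID] + tokens      # what the decoder reads
    trg_ids_next = tokens + [END_ID]   # what the decoder should predict
    return trg_ids, trg_ids_next


trg_ids, trg_ids_next = make_target_pair([7, 12, 5])
print(trg_ids)       # [0, 7, 12, 5]
print(trg_ids_next)  # [7, 12, 5, 1]

# At every position, the label is the token that follows in the input.
print(trg_ids[1:] == trg_ids_next[:-1])  # True
```

Under these assumed ids, the two sums differ by END_ID - START_ID = 1, which matches the pattern in the printed results.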

get_dict(reverse=False)


Get the source and target dictionary.

Parameters

reverse (bool) – whether to swap keys and values in the dictionaries, i.e. turn key: value into value: key. Default False

Returns

A tuple of two dictionaries: the source dictionary and the target dictionary.

Examples

>>> from paddle.text.datasets import WMT14
>>> wmt14 = WMT14(mode='train', dict_size=50)
>>> src_dict, trg_dict = wmt14.get_dict()
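
The effect of the reverse parameter can be sketched without downloading the dataset. The block below uses a hypothetical five-entry vocabulary (toy_dict) and two hypothetical helpers (reverse_dict, decode) to show how swapping keys and values turns a word-to-id mapping into an id-to-word mapping, which is the direction needed to turn token ids back into text. It assumes the default dictionaries map words to ids; the real contents come from wmt14.get_dict().

```python
def reverse_dict(mapping):
    """Swap keys and values: key: value becomes value: key."""
    return {value: key for key, value in mapping.items()}


def decode(token_ids, id_to_word, unk="<unk>"):
    """Turn a sequence of token ids into a space-separated string."""
    return " ".join(id_to_word.get(i, unk) for i in token_ids)


# Hypothetical miniature vocabulary, for illustration only.
toy_dict = {"<s>": 0, "<e>": 1, "<unk>": 2, "hello": 3, "world": 4}

id_to_word = reverse_dict(toy_dict)
print(id_to_word[3])  # hello

decoded = decode([0, 3, 4, 1], id_to_word)
print(decoded)  # <s> hello world <e>

# Ids missing from the dictionary fall back to the unknown token.
print(decode([3, 99], id_to_word))  # hello <unk>
```

With the real dataset, the analogous call would be wmt14.get_dict(reverse=True), which returns both dictionaries with keys and values swapped.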