random_split

class paddle.io. random_split ( dataset, lengths, generator=None ) [源代码]

给定子集合 dataset 的长度数组,随机切分出原数据集合的非重复子集合。

参数

  • dataset (Dataset) - 此参数必须是 paddle.io.Datasetpaddle.io.IterableDataset 的一个子类实例或实现了 __len__ 的 Python 对象,用于生成样本下标。默认值为 None。

  • lengths (list) - 总和为原数组长度,表示子集合长度数组;或总和为 1.0,表示子集合长度占比的数组。

  • generator (Generator,可选) - 指定采样 data_source 的采样器。默认值为 None。

返回

list,返回按给定长度数组描述随机分割的原数据集合的非重复子集合。

代码示例

>>> import paddle

>>> paddle.seed(2023)
>>> a_list = paddle.io.random_split(range(10), [3, 7])
>>> print(len(a_list))
2

>>> # output of the first subset
>>> for idx, v in enumerate(a_list[0]):
...     print(idx, v) 
0 7
1 6
2 5

>>> # output of the second subset
>>> for idx, v in enumerate(a_list[1]):
...     print(idx, v) 
0 1
1 9
2 4
3 2
4 0
5 3
6 8