TensorDataset

class paddle.io. TensorDataset [源代码]

由 Tensor 列表定义的数据集。

每个 Tensor 的形状应为[N,...],而 N 是样本数,每个 Tensor 表示样本中一个字段,TensorDataset 中通过在第一维索引 Tensor 来获取每个样本。

参数

  • tensors (list of Tensors) - Tensor 列表,这些 Tensor 的第一维形状相同

返回

Dataset,由 Tensor 列表定义的数据集

代码示例

>>> import numpy as np
>>> import paddle
>>> from paddle.io import TensorDataset


>>> input_np = np.random.random([2, 3, 4]).astype('float32')
>>> input = paddle.to_tensor(input_np)
>>> label_np = np.random.random([2, 1]).astype('int32')
>>> label = paddle.to_tensor(label_np)

>>> dataset = TensorDataset([input, label])

>>> for i in range(len(dataset)):
...     input, label = dataset[i]
...     # do something

使用本API的教程文档