TensorDataset

class paddle.io.TensorDataset(tensors)

Dataset defined by a list of tensors.

Each tensor should have shape [N, …], where N is the number of samples, and each tensor holds one field of a sample. TensorDataset retrieves a sample by indexing every tensor along the first dimension.
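The indexing behaviour described above can be sketched in plain Python. This is a hypothetical stand-in named `SimpleTensorDataset`, not Paddle's source. It assumes only that each wrapped object supports `len()` and integer indexing (nested lists here, so the sketch runs without Paddle installed). Sample `i` is the tuple formed by indexing every wrapped tensor at position `i` of its first dimension.

```python
class SimpleTensorDataset:
    """Hypothetical stand-in for TensorDataset (not Paddle's implementation).

    Sample i is the tuple (tensors[0][i], tensors[1][i], ...).
    """

    def __init__(self, tensors):
        self.tensors = list(tensors)

    def __len__(self):
        # N: size of the first dimension (taken from the first tensor)
        return len(self.tensors[0])

    def __getitem__(self, index):
        # Index every tensor along dimension 0; one field per tensor
        return tuple(t[index] for t in self.tensors)


features = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # shape [3, 2]
labels = [[0], [1], [0]]                          # shape [3, 1]
ds = SimpleTensorDataset([features, labels])
print(len(ds))  # 3
print(ds[1])    # ([0.3, 0.4], [1])
```

Each tensor contributes exactly one field to every sample, which is why the example below can unpack `dataset[i]` into an input and a label.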

Parameters

tensors (list of Tensor) – tensors that all have the same size N in the first dimension.
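The requirement that all tensors share the first dimension can be expressed as a small check. The helper `check_same_first_dim` below is hypothetical and only illustrates the constraint. Whether and how `TensorDataset` itself reports a mismatch is not stated on this page. Like the sketch above, it assumes inputs that support `len()`.

```python
def check_same_first_dim(tensors):
    """Hypothetical helper: return N if every tensor has N rows, else raise."""
    if not tensors:
        raise ValueError("expected at least one tensor")
    sizes = [len(t) for t in tensors]  # first-dimension size of each tensor
    if len(set(sizes)) > 1:
        raise ValueError(f"first dimensions differ: {sizes}")
    return sizes[0]


print(check_same_first_dim([[[1, 2], [3, 4]], [[0], [1]]]))  # 2
```

Only dimension 0 is compared. The trailing dimensions (the "…" in [N, …]) may differ freely between tensors, as with the [2, 3, 4] inputs and [2, 1] labels in the example below.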

Returns

a Dataset instance wrapping tensors.

Return type

Dataset

Examples

import numpy as np
import paddle
from paddle.io import TensorDataset

paddle.disable_static()

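# Two samples: each input has shape [3, 4], each label has shape [1]
input_np = np.random.random([2, 3, 4]).astype('float32')
inputs = paddle.to_tensor(input_np)
# Integer labels in [0, 10); casting random floats in [0, 1) to int32 would give all zeros
label_np = np.random.randint(0, 10, [2, 1]).astype('int32')
labels = paddle.to_tensor(label_np)

# Both tensors share N = 2 in the first dimension
dataset = TensorDataset([inputs, labels])

for i in range(len(dataset)):
    # Each sample holds one slice from each wrapped tensor
    x, y = dataset[i]
    print(x, y)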