Model
-
class paddle.Model
( network, inputs=None, labels=None ) -
A Model object is a network with training and inference features. Both dynamic graph and static graph modes are supported; switch to static graph with paddle.enable_static(). Usage is shown in the example below. Note that the switch between dynamic and static mode must happen before a Model is instantiated. The input description, i.e. paddle.static.InputSpec, is required for static graph.
- Parameters
-
network (paddle.nn.Layer) – The network is an instance of paddle.nn.Layer.
inputs (InputSpec|list|dict|None) – inputs, the entry points of the network. Can be an InputSpec instance, a list of InputSpec instances, or a dict ({name: InputSpec}). It cannot be None in static graph.
labels (InputSpec|list|None) – labels, the entry points of the network. Can be an InputSpec instance, a list of InputSpec instances, or None. For static graph, labels must be set if the loss requires them. Otherwise, it can be None.
Examples
import paddle
import paddle.nn as nn
import paddle.vision.transforms as T
from paddle.static import InputSpec

device = paddle.set_device('cpu') # or 'gpu'

net = nn.Sequential(
    nn.Flatten(1),
    nn.Linear(784, 200),
    nn.Tanh(),
    nn.Linear(200, 10))

# inputs and labels are not required for dynamic graph.
input = InputSpec([None, 784], 'float32', 'x')
label = InputSpec([None, 1], 'int64', 'label')

model = paddle.Model(net, input, label)
optim = paddle.optimizer.SGD(learning_rate=1e-3,
                             parameters=model.parameters())
model.prepare(optim,
              paddle.nn.CrossEntropyLoss(),
              paddle.metric.Accuracy())

transform = T.Compose([
    T.Transpose(),
    T.Normalize([127.5], [127.5])
])
data = paddle.vision.datasets.MNIST(mode='train', transform=transform)
model.fit(data, epochs=2, batch_size=32, verbose=1)
-
train_batch
( inputs, labels=None ) -
Run one training step on a batch of data.
- Parameters
-
inputs (numpy.ndarray|Tensor|list) – Batch of input data. It could be a numpy array or paddle.Tensor, or a list of arrays or tensors (in case the model has multiple inputs).
labels (numpy.ndarray|Tensor|list) – Batch of labels. It could be a numpy array or paddle.Tensor, or a list of arrays or tensors (in case the model has multiple labels). If there are no labels, set it to None. Default is None.
- Returns
-
A list of scalar training losses if the model has no metrics, or a tuple (list of scalar losses, list of metrics) if the model has metrics set.
Examples
import numpy as np
import paddle
import paddle.nn as nn
from paddle.static import InputSpec

device = paddle.set_device('cpu') # or 'gpu'

net = nn.Sequential(
    nn.Linear(784, 200),
    nn.Tanh(),
    nn.Linear(200, 10))

input = InputSpec([None, 784], 'float32', 'x')
label = InputSpec([None, 1], 'int64', 'label')
model = paddle.Model(net, input, label)
optim = paddle.optimizer.SGD(learning_rate=1e-3,
                             parameters=model.parameters())
model.prepare(optim, paddle.nn.CrossEntropyLoss())

data = np.random.random(size=(4,784)).astype(np.float32)
label = np.random.randint(0, 10, size=(4, 1)).astype(np.int64)
loss = model.train_batch([data], [label])
print(loss)
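The two return shapes described under Returns can be handled uniformly with a small helper. The following is only a sketch: split_batch_result is a hypothetical name that is not part of paddle, and the numbers are made-up stand-ins for real model output.

```python
def split_batch_result(result):
    """Normalize the return value of train_batch / eval_batch.

    Assumes the documented contract: a list of losses when no metrics
    are set, or a tuple (losses, metrics) when metrics are set.
    """
    if isinstance(result, tuple):
        losses, metrics = result
    else:
        losses, metrics = result, []
    return list(losses), list(metrics)


# Made-up values standing in for real model output.
print(split_batch_result([2.31]))            # no metrics configured
print(split_batch_result(([2.31], [0.25])))  # one metric configured
```

Calling code can then always unpack two lists, whether or not metrics were passed to prepare.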
-
eval_batch
( inputs, labels=None ) -
Run one evaluating step on a batch of data.
- Parameters
-
inputs (numpy.ndarray|Tensor|list) – Batch of input data. It could be a numpy array or paddle.Tensor, or a list of arrays or tensors (in case the model has multiple inputs).
labels (numpy.ndarray|Tensor|list) – Batch of labels. It could be a numpy array or paddle.Tensor, or a list of arrays or tensors (in case the model has multiple labels). If there are no labels, set it to None. Default is None.
- Returns
-
A list of scalar testing losses if the model has no metrics, or a tuple (list of scalar losses, list of metrics) if the model has metrics set.
Examples
import numpy as np
import paddle
import paddle.nn as nn
from paddle.static import InputSpec

device = paddle.set_device('cpu') # or 'gpu'

net = nn.Sequential(
    nn.Linear(784, 200),
    nn.Tanh(),
    nn.Linear(200, 10))

input = InputSpec([None, 784], 'float32', 'x')
label = InputSpec([None, 1], 'int64', 'label')
model = paddle.Model(net, input, label)
optim = paddle.optimizer.SGD(learning_rate=1e-3,
                             parameters=model.parameters())
model.prepare(optim, paddle.nn.CrossEntropyLoss())

data = np.random.random(size=(4,784)).astype(np.float32)
label = np.random.randint(0, 10, size=(4, 1)).astype(np.int64)
loss = model.eval_batch([data], [label])
print(loss)
-
predict_batch
( inputs ) -
Run one predicting step on a batch of data.
- Parameters
-
inputs (numpy.ndarray|Tensor|list) – Batch of input data. It could be a numpy array or paddle.Tensor, or a list of arrays or tensors (in case the model has multiple inputs).
- Returns
-
A list of numpy.ndarray predictions, that is, the outputs of the Model's forward pass.
Examples
import numpy as np
import paddle
import paddle.nn as nn
from paddle.static import InputSpec

device = paddle.set_device('cpu') # or 'gpu'

input = InputSpec([None, 784], 'float32', 'x')
label = InputSpec([None, 1], 'int64', 'label')

net = nn.Sequential(
    nn.Linear(784, 200),
    nn.Tanh(),
    nn.Linear(200, 10),
    nn.Softmax())

model = paddle.Model(net, input, label)
model.prepare()

data = np.random.random(size=(4,784)).astype(np.float32)
out = model.predict_batch([data])
print(out)
-
save
( path, training=True ) -
This function saves either the parameters and optimizer information, or the model and parameters for inference only, to path, depending on the parameter training.
If training is set to True, the saved parameters contain all the trainable Variables and are written to a file with the suffix “.pdparams”. The optimizer information contains all the variables used by the optimizer. For the Adam optimizer, this includes beta1, beta2, momentum, etc. All of this information is saved to a file with the suffix “.pdopt”. (If the optimizer has no variables that need saving, as with SGD, the file will not be generated.) This function silently overwrites any existing file at the target location.
If training is set to False, only the inference model will be saved.
- Parameters
-
path (str) – The file prefix used to save the model. The format is ‘dirname/file_prefix’ or ‘file_prefix’. If it is an empty string, an exception will be raised.
training (bool, optional) – Whether to save for training. If not, save for inference only. Default: True.
- Returns
-
None
Examples
import paddle
import paddle.nn as nn
import paddle.vision.transforms as T
from paddle.static import InputSpec

class Mnist(nn.Layer):
    def __init__(self):
        super(Mnist, self).__init__()
        self.net = nn.Sequential(
            nn.Flatten(1),
            nn.Linear(784, 200),
            nn.Tanh(),
            nn.Linear(200, 10),
            nn.Softmax())

    def forward(self, x):
        return self.net(x)

dynamic = True  # set to False to use static graph

if not dynamic:
    paddle.enable_static()

input = InputSpec([None, 784], 'float32', 'x')
label = InputSpec([None, 1], 'int64', 'label')
model = paddle.Model(Mnist(), input, label)
optim = paddle.optimizer.SGD(learning_rate=1e-3,
                             parameters=model.parameters())
model.prepare(optim, paddle.nn.CrossEntropyLoss())

transform = T.Compose([
    T.Transpose(),
    T.Normalize([127.5], [127.5])
])
data = paddle.vision.datasets.MNIST(mode='train', transform=transform)

model.fit(data, epochs=1, batch_size=32, verbose=0)
model.save('checkpoint/test')  # save for training
model.save('inference_model', False)  # save for inference
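The file-naming rule for training=True can be sketched without paddle. The helper below is hypothetical (it is not a paddle function) and only mirrors the suffixes named in the description of save. The exception type for an empty prefix is an assumption, since the text only says that an exception is raised.

```python
def training_checkpoint_files(path):
    """Return the file names a training save is documented to write.

    Hypothetical helper: it mirrors the naming rule in the text and
    does not call paddle.
    """
    if not path:
        # The docs say an exception is raised for an empty prefix; the
        # exception type used here is an assumption.
        raise ValueError("path must be a non-empty file prefix")
    return {
        "params": path + ".pdparams",  # trainable parameters
        "optimizer": path + ".pdopt",  # may be absent, e.g. for SGD
    }


print(training_checkpoint_files("checkpoint/test"))
```

A check like this can be useful before calling load, to see whether an optimizer state file should be expected next to the parameter file.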
-
load
( path, skip_mismatch=False, reset_optimizer=False ) -
Load from files storing the model states and optimizer states. The file for optimizer states is not necessary if there is no need to restore the optimizer.
NOTE: parameters are retrieved from the file storing model states according to their structured names.
For fine-tuning or transfer-learning models where some of the layers have changed, make sure the parameters to be restored have the same structured names in the pre-trained model and the fine-tuning model.
- Parameters
-
path (str) – The prefix of the files storing the model states and optimizer states. The files are path.pdparams and path.pdopt respectively, and the latter is not necessary when there is no need to restore the optimizer.
skip_mismatch (bool) – Whether to skip loading a mismatched parameter or raise an error when a mismatch happens (the parameter is not found in the file storing model states, or it has a mismatched shape). Default: False.
reset_optimizer (bool) – If True, ignore the provided file storing optimizer states and initialize optimizer states from scratch. Otherwise, restore optimizer states from path.pdopt if an optimizer has been set on the model. Default: False.
- Returns
-
None
Examples
import paddle
import paddle.nn as nn
from paddle.static import InputSpec

device = paddle.set_device('cpu')

input = InputSpec([None, 784], 'float32', 'x')

model = paddle.Model(nn.Sequential(
    nn.Linear(784, 200),
    nn.Tanh(),
    nn.Linear(200, 10),
    nn.Softmax()), input)

model.save('checkpoint/test')
model.load('checkpoint/test')
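The effect of skip_mismatch can be illustrated with plain dictionaries. This is a sketch of the behaviour described for load, not paddle's implementation: match_parameters is a hypothetical function, the parameter names and shapes are illustrative, and the exception type is an assumption.

```python
def match_parameters(model_shapes, saved_shapes, skip_mismatch=False):
    """Pick the saved parameters that can be loaded into the model.

    Both arguments map a structured name to a shape tuple. This is a
    hypothetical stand-in for the documented behaviour.
    """
    matched = []
    for name, shape in model_shapes.items():
        if name not in saved_shapes:
            problem = f"{name}: not found in saved state"
        elif saved_shapes[name] != shape:
            problem = f"{name}: shape {saved_shapes[name]} != {shape}"
        else:
            matched.append(name)
            continue
        if not skip_mismatch:
            raise ValueError(problem)
        # With skip_mismatch=True the parameter keeps its current value.
    return matched


# A fine-tuning scenario: the last layer changed from 10 to 5 outputs.
saved = {"0.weight": (784, 200), "2.weight": (200, 10)}
model = {"0.weight": (784, 200), "2.weight": (200, 5)}
print(match_parameters(model, saved, skip_mismatch=True))
```

With skip_mismatch=False the same call raises on the resized layer, which is why the structured names and shapes of the layers to be restored need to match between the two models.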
-
parameters
( *args, **kwargs ) -
Returns a list of parameters of the model.
- Returns
-
A list of Parameter in static graph. A list of ParamBase in dynamic graph.
Examples
import paddle
import paddle.nn as nn
from paddle.static import InputSpec

input = InputSpec([None, 784], 'float32', 'x')

model = paddle.Model(nn.Sequential(
    nn.Linear(784, 200),
    nn.Tanh(),
    nn.Linear(200, 10)), input)

params = model.parameters()
-
prepare
( optimizer=None, loss=None, metrics=None ) -
Configures the model before running.
- Parameters
-
optimizer (Optimizer|None) – The optimizer must be set for training and should be an Optimizer instance. It can be None in eval and test mode.
loss (Loss|callable function|None) – The loss function can be a paddle.nn.Layer instance or any callable function that takes the predicted values and ground truth values as input. It can be None when there is no loss.
metrics (Metric|list of Metric|None) – If metrics is set, all metrics will be calculated and output in train/eval mode.
- Returns
-
None
-
fit
( train_data=None, eval_data=None, batch_size=1, epochs=1, eval_freq=1, log_freq=10, save_dir=None, save_freq=1, verbose=2, drop_last=False, shuffle=True, num_workers=0, callbacks=None ) -
Trains the model for a fixed number of epochs. If eval_data is set, evaluation will be done at the end of each epoch.
- Parameters
-
train_data (Dataset|DataLoader) – An iterable data loader used for training. An instance of paddle.io.Dataset or paddle.io.DataLoader is recommended. Default: None.
eval_data (Dataset|DataLoader) – An iterable data loader used for evaluation at the end of an epoch. If None, no evaluation is done. An instance of paddle.io.Dataset or paddle.io.DataLoader is recommended. Default: None.
batch_size (int) – Integer number. The batch size of train_data and eval_data. When train_data and eval_data are both instances of DataLoader, this parameter is ignored. Default: 1.
epochs (int) – Integer number. The number of epochs to train the model. Default: 1.
eval_freq (int) – The frequency, in number of epochs, at which an evaluation is performed. Default: 1.
log_freq (int) – The frequency, in number of steps, the training logs are printed. Default: 10.
save_dir (str|None) – The directory to save checkpoint during training. If None, will not save checkpoint. Default: None.
save_freq (int) – The frequency, in number of epochs, to save checkpoint. Default: 1.
verbose (int) – The verbosity mode, should be 0, 1, or 2. 0 = silent, 1 = progress bar, 2 = one line per epoch. Default: 2.
drop_last (bool) – Whether to drop the last incomplete batch of train_data when the dataset size is not divisible by the batch size. When train_data is an instance of DataLoader, this parameter is ignored. Default: False.
shuffle (bool) – Whether to shuffle train_data. When train_data is an instance of DataLoader, this parameter is ignored. Default: True.
num_workers (int) – The number of subprocesses used to load data. 0 means no subprocess is used and data is loaded in the main process. When train_data and eval_data are both instances of DataLoader, this parameter is ignored. Default: 0.
callbacks (Callback|None) – A list of Callback instances to apply during training. If None, ProgBarLogger and ModelCheckpoint are automatically inserted. Default: None.
- Returns
-
None
Examples
An example that uses Dataset and sets batch size and shuffle in fit. Batching is handled internally.
import paddle
import paddle.vision.transforms as T
from paddle.vision.datasets import MNIST
from paddle.static import InputSpec

dynamic = True
if not dynamic:
    paddle.enable_static()

transform = T.Compose([
    T.Transpose(),
    T.Normalize([127.5], [127.5])
])
train_dataset = MNIST(mode='train', transform=transform)
val_dataset = MNIST(mode='test', transform=transform)

input = InputSpec([None, 1, 28, 28], 'float32', 'image')
label = InputSpec([None, 1], 'int64', 'label')

model = paddle.Model(
    paddle.vision.models.LeNet(),
    input, label)
optim = paddle.optimizer.Adam(
    learning_rate=0.001, parameters=model.parameters())
model.prepare(
    optim,
    paddle.nn.CrossEntropyLoss(),
    paddle.metric.Accuracy(topk=(1, 2)))
model.fit(train_dataset,
          val_dataset,
          epochs=2,
          batch_size=64,
          save_dir='mnist_checkpoint')
An example that uses DataLoader, where batch size and shuffle are set in the DataLoader.
import paddle
import paddle.vision.transforms as T
from paddle.vision.datasets import MNIST
from paddle.static import InputSpec

dynamic = True
if not dynamic:
    paddle.enable_static()

transform = T.Compose([
    T.Transpose(),
    T.Normalize([127.5], [127.5])
])
train_dataset = MNIST(mode='train', transform=transform)
train_loader = paddle.io.DataLoader(train_dataset,
    batch_size=64)
val_dataset = MNIST(mode='test', transform=transform)
val_loader = paddle.io.DataLoader(val_dataset,
    batch_size=64)

input = InputSpec([None, 1, 28, 28], 'float32', 'image')
label = InputSpec([None, 1], 'int64', 'label')

model = paddle.Model(
    paddle.vision.models.LeNet(), input, label)
optim = paddle.optimizer.Adam(
    learning_rate=0.001, parameters=model.parameters())
model.prepare(
    optim,
    paddle.nn.CrossEntropyLoss(),
    paddle.metric.Accuracy(topk=(1, 2)))
model.fit(train_loader,
          val_loader,
          epochs=2,
          save_dir='mnist_checkpoint')
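How batch_size and drop_last determine the number of training steps per epoch can be worked out with simple arithmetic. The sketch below does not use paddle; steps_per_epoch is a hypothetical helper name, and it assumes the usual batching rule in which an incomplete final batch is either kept or dropped.

```python
import math


def steps_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches in one pass over a dataset."""
    if drop_last:
        # The incomplete final batch is discarded.
        return num_samples // batch_size
    # The incomplete final batch is kept as a smaller batch.
    return math.ceil(num_samples / batch_size)


# 60000 is the size of the MNIST training split.
print(steps_per_epoch(60000, 64))
print(steps_per_epoch(60000, 64, drop_last=True))
```

Since log_freq is counted in steps, this number also gives a rough idea of how many log lines one epoch produces.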
-
evaluate
( eval_data, batch_size=1, log_freq=10, verbose=2, num_workers=0, callbacks=None ) -
Evaluate the loss and metrics of the model on the input dataset.
- Parameters
-
eval_data (Dataset|DataLoader) – An iterable data loader used for evaluation. An instance of paddle.io.Dataset or paddle.io.DataLoader is recommended.
batch_size (int) – Integer number. The batch size of eval_data. When eval_data is an instance of DataLoader, this argument is ignored. Default: 1.
log_freq (int) – The frequency, in number of steps, the eval logs are printed. Default: 10.
verbose (int) – The verbosity mode, should be 0, 1, or 2. 0 = silent, 1 = progress bar, 2 = one line per epoch. Default: 2.
num_workers (int) – The number of subprocesses used to load data. 0 means no subprocess is used and data is loaded in the main process. When eval_data is an instance of DataLoader, this parameter is ignored. Default: 0.
callbacks (Callback|None) – A list of Callback instances to apply during evaluation. If None, ProgBarLogger and ModelCheckpoint are automatically inserted. Default: None.
- Returns
-
Result of metrics. Each key is the name of a Metric, and each value is a scalar or numpy.array.
- Return type
-
dict
Examples
import paddle
import paddle.vision.transforms as T
from paddle.static import InputSpec

# declarative mode
transform = T.Compose([
    T.Transpose(),
    T.Normalize([127.5], [127.5])
])
val_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)

input = InputSpec([-1, 1, 28, 28], 'float32', 'image')
label = InputSpec([None, 1], 'int64', 'label')
model = paddle.Model(paddle.vision.models.LeNet(), input, label)
model.prepare(metrics=paddle.metric.Accuracy())
result = model.evaluate(val_dataset, batch_size=64)
print(result)
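Because the result is a dict whose values may be scalars or arrays, a small formatter can make it easier to log. The following is a minimal sketch: format_eval_result is a hypothetical helper, the metric names and numbers are made up, and only one level of nesting is handled.

```python
def format_eval_result(result):
    """Render an evaluate()-style dict as a single log line."""
    parts = []
    for name, value in result.items():
        if hasattr(value, "tolist"):  # numpy scalar or array
            value = value.tolist()
        # Treat a scalar as a one-element list so both cases share code.
        values = value if isinstance(value, (list, tuple)) else [value]
        parts.append(name + ": " + ", ".join(f"{v:.4f}" for v in values))
    return " - ".join(parts)


# Made-up values standing in for a real evaluation result.
print(format_eval_result({"loss": [0.0512], "acc": 0.9836}))
```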
-
predict
( test_data, batch_size=1, num_workers=0, stack_outputs=False, callbacks=None ) -
Compute the output predictions on testing data.
- Parameters
-
test_data (Dataset|DataLoader) – An iterable data loader used for prediction. An instance of paddle.io.Dataset or paddle.io.DataLoader is recommended.
batch_size (int) – Integer number. The batch size of test_data. When test_data is an instance of DataLoader, this argument is ignored. Default: 1.
num_workers (int) – The number of subprocesses used to load data. 0 means no subprocess is used and data is loaded in the main process. When test_data is an instance of DataLoader, this argument is ignored. Default: 0.
stack_outputs (bool) – Whether to stack each output field like a batch. For example, if an output field of one sample has shape [X, Y] and test_data contains N samples, the predicted output field has shape [N, X, Y] when stack_outputs is True, and is a length-N list of the form [[X, Y], [X, Y], …, [X, Y]] when stack_outputs is False. Setting stack_outputs to False is intended for LoDTensor outputs; True is recommended if the outputs contain no LoDTensor. Default: False.
callbacks (Callback) – A Callback instance, default None.
- Returns
-
Output of the model.
- Return type
-
list
Examples
import numpy as np
import paddle
from paddle.static import InputSpec

class MnistDataset(paddle.vision.datasets.MNIST):
    def __init__(self, mode, return_label=True):
        super(MnistDataset, self).__init__(mode=mode)
        self.return_label = return_label

    def __getitem__(self, idx):
        img = np.reshape(self.images[idx], [1, 28, 28])
        if self.return_label:
            return img, np.array(self.labels[idx]).astype('int64')
        return img,

    def __len__(self):
        return len(self.images)

test_dataset = MnistDataset(mode='test', return_label=False)

# imperative mode
input = InputSpec([-1, 1, 28, 28], 'float32', 'image')
model = paddle.Model(paddle.vision.models.LeNet(), input)
model.prepare()
result = model.predict(test_dataset, batch_size=64)
print(len(result[0]), result[0][0].shape)

# declarative mode
device = paddle.set_device('cpu')
paddle.enable_static()
input = InputSpec([-1, 1, 28, 28], 'float32', 'image')
model = paddle.Model(paddle.vision.models.LeNet(), input)
model.prepare()
result = model.predict(test_dataset, batch_size=64)
print(len(result[0]), result[0][0].shape)
-
summary
( input_size=None, dtype=None ) -
Prints a string summary of the network.
- Parameters
-
input_size (tuple|InputSpec|list[tuple|InputSpec], optional) – Size of the input tensor. If not set, input_size is taken from self._inputs. If the network has only one input, input_size can be a tuple or an InputSpec. If the model has multiple inputs, input_size must be a list containing every input’s shape. Default: None.
dtype (str, optional) – If dtype is None, ‘float32’ will be used. Default: None.
- Returns
-
A summary of the network, including total params and total trainable params.
- Return type
-
Dict
Examples
import paddle
from paddle.static import InputSpec

input = InputSpec([None, 1, 28, 28], 'float32', 'image')
label = InputSpec([None, 1], 'int64', 'label')

model = paddle.Model(paddle.vision.LeNet(),
    input, label)
optim = paddle.optimizer.Adam(
    learning_rate=0.001, parameters=model.parameters())
model.prepare(
    optim,
    paddle.nn.CrossEntropyLoss())

params_info = model.summary()
print(params_info)