save

paddle.jit.save(layer, path, input_spec=None, **configs)

Saves the input Layer or function as a model in paddle.jit.TranslatedLayer format, which can be loaded later for inference or fine-tuning.

It saves the translated program and all related persistable variables of the input Layer to the given path.

path is the prefix of the saved files. The translated program is saved with the suffix .pdmodel and the persistable variables with the suffix .pdiparams. Additional variable description information is also saved to a file with the suffix .pdiparams.info; this information is used in fine-tuning.
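As a quick illustration of the naming rule above, the following sketch derives the file names from a path prefix and reports which of them do not exist on disk. The helpers `expected_save_files` and `missing_files` are hypothetical and not part of Paddle; the suffixes are the ones listed above, and the exact set of files may differ between Paddle versions, so treat this as a sanity check rather than a specification.

```python
import os

# Suffixes as described on this page (assumption: the Paddle 2.x file layout).
SAVE_SUFFIXES = (".pdmodel", ".pdiparams", ".pdiparams.info")


def expected_save_files(path):
    """Hypothetical helper: file names that paddle.jit.save writes for a path prefix."""
    # `path` is a prefix, not a directory: each suffix is appended directly to it.
    return [path + suffix for suffix in SAVE_SUFFIXES]


def missing_files(path):
    """Hypothetical helper: expected files that are not present on disk."""
    return [name for name in expected_save_files(path) if not os.path.exists(name)]


print(expected_save_files("example_model/linear"))
# ['example_model/linear.pdmodel', 'example_model/linear.pdiparams', 'example_model/linear.pdiparams.info']
```

Because path is a prefix, "example_model/linear" places the files inside example_model/ with names starting with linear.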

The saved model can be loaded by the following APIs:
  • paddle.jit.load

  • paddle.static.load_inference_model

  • Other C++ inference APIs

Note

When paddle.jit.save is used to save a function, parameters are not saved. If you need to save parameters, pass the Layer that contains both the function and the parameters to paddle.jit.save.

Parameters
  • layer (Layer|function) – The Layer or function to be saved.

  • path (str) – The path prefix used to save the model, in the format dirname/file_prefix or file_prefix.

  • input_spec (list or tuple[InputSpec|Tensor|Python built-in variable], optional) – Describes the inputs of the saved model's forward method, either as InputSpec objects or as example Tensors. Non-tensor arguments, such as int, float, string, or a list/dict of them, can also be specified. If None, all input variables of the original Layer's forward method are used as the inputs of the saved model. Default None.

  • **configs (dict, optional) – Other save configuration options, kept for compatibility. These options are not recommended and may be removed in the future; do not use them unless necessary. Default None. The following option is currently supported: (1) output_spec (list[Tensor]): Selects the output targets of the saved model. By default, all return variables of the original Layer's forward method are kept as outputs of the saved model. If output_spec does not list all output variables, the saved model is pruned according to the given output_spec list.

Returns

None

Examples

>>> # example 1: save layer
>>> import numpy as np
>>> import paddle
>>> import paddle.nn as nn
>>> import paddle.optimizer as opt

>>> BATCH_SIZE = 16
>>> BATCH_NUM = 4
>>> EPOCH_NUM = 4

>>> IMAGE_SIZE = 784
>>> CLASS_NUM = 10

>>> # define a random dataset
>>> class RandomDataset(paddle.io.Dataset):
...     def __init__(self, num_samples):
...         self.num_samples = num_samples
...
...     def __getitem__(self, idx):
...         image = np.random.random([IMAGE_SIZE]).astype('float32')
...         # randint's upper bound is exclusive, so CLASS_NUM yields labels 0..CLASS_NUM-1
...         label = np.random.randint(0, CLASS_NUM, (1, )).astype('int64')
...         return image, label
...
...     def __len__(self):
...         return self.num_samples

>>> class LinearNet(nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)
...
...     @paddle.jit.to_static
...     def forward(self, x):
...         return self._linear(x)

>>> def train(layer, loader, loss_fn, opt):
...     for epoch_id in range(EPOCH_NUM):
...         for batch_id, (image, label) in enumerate(loader()):
...             out = layer(image)
...             loss = loss_fn(out, label)
...             loss.backward()
...             opt.step()
...             opt.clear_grad()
...             print("Epoch {} batch {}: loss = {}".format(
...                 epoch_id, batch_id, np.mean(loss.numpy())))

>>> # 1. train & save model.

>>> # create network
>>> layer = LinearNet()
>>> loss_fn = nn.CrossEntropyLoss()
>>> adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters())

>>> # create data loader
>>> dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
>>> loader = paddle.io.DataLoader(dataset,
...     batch_size=BATCH_SIZE,
...     shuffle=True,
...     drop_last=True,
...     num_workers=2
... )

>>> # train
>>> train(layer, loader, loss_fn, adam)

>>> # save
>>> path = "example_model/linear"
>>> paddle.jit.save(layer, path)

>>> # example 2: save function
>>> import paddle
>>> from paddle.static import InputSpec


>>> def save_function():
...     @paddle.jit.to_static
...     def fun(inputs):
...         return paddle.tanh(inputs)
...
...     path = 'test_jit_save_load_function_1/func'
...     inps = paddle.rand([3, 6])
...     origin = fun(inps)
...
...     paddle.jit.save(fun, path)
...     load_func = paddle.jit.load(path)
...
...     load_result = load_func(inps)
...     print((load_result - origin).abs().max() < 1e-10)

>>> save_function()
