summary

paddle.summary(net, input_size=None, dtypes=None, input=None) [源代码]

通过 input_size 或 input 打印网络 net 的基础结构和参数信息。input_size 指定网络 net 输入 Tensor 的大小,而 input 指定网络 net 的输入 Tensor;如果给出 input,那么 input_size 和 dtypes 的输入将被忽略。

参数

  • net (Layer) - 网络实例,必须是 Layer 的子类。

  • input_size (tuple|InputSpec|list[tuple|InputSpec],可选) - 输入 Tensor 的大小。如果网络只有一个输入,那么该值需要设定为 tuple 或 InputSpec;如果网络有多个输入,那么该值需要设定为 list[tuple|InputSpec],包含每个输入的 shape。默认值:None。

  • dtypes (str|list[str],可选) - 输入 Tensor 的数据类型。网络有多个输入时,可以传入与 input_size 中各 shape 一一对应的 list[str](见下方多输入示例);如果没有给定,默认使用 float32 类型。默认值:None。

  • input (Tensor|list[Tensor]|dict[str, Tensor],可选) - 网络的输入 Tensor,可以是单个 Tensor,也可以是由多个 Tensor 组成的 list 或 dict(见下方示例)。如果给出 input,那么 input_size 和 dtypes 的输入将被忽略。默认值:None。

返回

字典,包含总参数量(键 total_params)和可训练参数量(键 trainable_params)。

代码示例

>>> import paddle
>>> import paddle.nn as nn
>>> paddle.seed(2023)
>>> class LeNet(nn.Layer):
...     def __init__(self, num_classes=10):
...         super().__init__()
...         self.num_classes = num_classes
...         self.features = nn.Sequential(
...             nn.Conv2D(1, 6, 3, stride=1, padding=1),
...             nn.ReLU(),
...             nn.MaxPool2D(2, 2),
...             nn.Conv2D(6, 16, 5, stride=1, padding=0),
...             nn.ReLU(),
...             nn.MaxPool2D(2, 2))
...
...         if num_classes > 0:
...             self.fc = nn.Sequential(
...                 nn.Linear(400, 120),
...                 nn.Linear(120, 84),
...                 nn.Linear(84, 10))
...
...     def forward(self, inputs):
...         x = self.features(inputs)
...
...         if self.num_classes > 0:
...             x = paddle.flatten(x, 1)
...             x = self.fc(x)
...         return x
...
>>> lenet = LeNet()

>>> params_info = paddle.summary(lenet, (1, 1, 28, 28))
>>> print(params_info)
---------------------------------------------------------------------------
Layer (type)       Input Shape          Output Shape         Param #
===========================================================================
  Conv2D-1       [[1, 1, 28, 28]]      [1, 6, 28, 28]          60
    ReLU-1        [[1, 6, 28, 28]]      [1, 6, 28, 28]           0
  MaxPool2D-1     [[1, 6, 28, 28]]      [1, 6, 14, 14]           0
  Conv2D-2       [[1, 6, 14, 14]]     [1, 16, 10, 10]         2,416
    ReLU-2       [[1, 16, 10, 10]]     [1, 16, 10, 10]           0
  MaxPool2D-2    [[1, 16, 10, 10]]      [1, 16, 5, 5]            0
  Linear-1          [[1, 400]]            [1, 120]           48,120
  Linear-2          [[1, 120]]            [1, 84]            10,164
  Linear-3          [[1, 84]]             [1, 10]              850
===========================================================================
Total params: 61,610
Trainable params: 61,610
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.11
Params size (MB): 0.24
Estimated Total Size (MB): 0.35
---------------------------------------------------------------------------
{'total_params': 61610, 'trainable_params': 61610}
>>> # multi input demo: forward() takes two tensors, so input_size is a list
>>> # with one shape per forward() argument and dtypes is a matching list.
>>> class LeNetMultiInput(LeNet):
...     def forward(self, inputs, y):
...         x = self.features(inputs)
...
...         if self.num_classes > 0:
...             x = paddle.flatten(x, 1)
...             # y is added to the flattened features before the fc head;
...             # the summary call below declares its shape as (1, 400).
...             x = self.fc(x + y)
...         return x
...
>>> lenet_multi_input = LeNetMultiInput()

>>> params_info = paddle.summary(lenet_multi_input,
...                              [(1, 1, 28, 28), (1, 400)],
...                              dtypes=['float32', 'float32'])
>>> print(params_info)
---------------------------------------------------------------------------
Layer (type)       Input Shape          Output Shape         Param #
===========================================================================
  Conv2D-3       [[1, 1, 28, 28]]      [1, 6, 28, 28]          60
    ReLU-3        [[1, 6, 28, 28]]      [1, 6, 28, 28]           0
  MaxPool2D-3     [[1, 6, 28, 28]]      [1, 6, 14, 14]           0
  Conv2D-4       [[1, 6, 14, 14]]     [1, 16, 10, 10]         2,416
    ReLU-4       [[1, 16, 10, 10]]     [1, 16, 10, 10]           0
  MaxPool2D-4    [[1, 16, 10, 10]]      [1, 16, 5, 5]            0
  Linear-4          [[1, 400]]            [1, 120]           48,120
  Linear-5          [[1, 120]]            [1, 84]            10,164
  Linear-6          [[1, 84]]             [1, 10]              850
===========================================================================
Total params: 61,610
Trainable params: 61,610
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.11
Params size (MB): 0.24
Estimated Total Size (MB): 0.35
---------------------------------------------------------------------------
{'total_params': 61610, 'trainable_params': 61610}
>>> # list input demo: forward() takes a single list of tensors, so concrete
>>> # tensors are passed via `input` (input_size and dtypes are then ignored).
>>> class LeNetListInput(LeNet):
...     def forward(self, inputs):
...         # inputs[0] feeds the conv features; inputs[1] is added to the
...         # flattened features before the fc head.
...         x = self.features(inputs[0])
...
...         if self.num_classes > 0:
...             x = paddle.flatten(x, 1)
...             x = self.fc(x + inputs[1])
...         return x
...
>>> lenet_list_input = LeNetListInput()
>>> input_data = [paddle.rand([1, 1, 28, 28]), paddle.rand([1, 400])]
>>> params_info = paddle.summary(lenet_list_input, input=input_data)
>>> print(params_info)
---------------------------------------------------------------------------
Layer (type)       Input Shape          Output Shape         Param #
===========================================================================
  Conv2D-5       [[1, 1, 28, 28]]      [1, 6, 28, 28]          60
    ReLU-5        [[1, 6, 28, 28]]      [1, 6, 28, 28]           0
  MaxPool2D-5     [[1, 6, 28, 28]]      [1, 6, 14, 14]           0
  Conv2D-6       [[1, 6, 14, 14]]     [1, 16, 10, 10]         2,416
    ReLU-6       [[1, 16, 10, 10]]     [1, 16, 10, 10]           0
  MaxPool2D-6    [[1, 16, 10, 10]]      [1, 16, 5, 5]            0
  Linear-7          [[1, 400]]            [1, 120]           48,120
  Linear-8          [[1, 120]]            [1, 84]            10,164
  Linear-9          [[1, 84]]             [1, 10]              850
===========================================================================
Total params: 61,610
Trainable params: 61,610
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.11
Params size (MB): 0.24
Estimated Total Size (MB): 0.35
---------------------------------------------------------------------------
{'total_params': 61610, 'trainable_params': 61610}
>>> # dict input demo: forward() takes a single dict of tensors keyed by
>>> # name, passed via `input` (input_size and dtypes are then ignored).
>>> class LeNetDictInput(LeNet):
...     def forward(self, inputs):
...         # inputs['x1'] feeds the conv features; inputs['x2'] is added to
...         # the flattened features before the fc head.
...         x = self.features(inputs['x1'])
...
...         if self.num_classes > 0:
...             x = paddle.flatten(x, 1)
...             x = self.fc(x + inputs['x2'])
...         return x
...
>>> lenet_dict_input = LeNetDictInput()
>>> input_data = {'x1': paddle.rand([1, 1, 28, 28]),
...               'x2': paddle.rand([1, 400])}
>>> params_info = paddle.summary(lenet_dict_input, input=input_data)
>>> print(params_info)
---------------------------------------------------------------------------
Layer (type)       Input Shape          Output Shape         Param #
===========================================================================
  Conv2D-7       [[1, 1, 28, 28]]      [1, 6, 28, 28]          60
    ReLU-7        [[1, 6, 28, 28]]      [1, 6, 28, 28]           0
  MaxPool2D-7     [[1, 6, 28, 28]]      [1, 6, 14, 14]           0
  Conv2D-8       [[1, 6, 14, 14]]     [1, 16, 10, 10]         2,416
    ReLU-8       [[1, 16, 10, 10]]     [1, 16, 10, 10]           0
  MaxPool2D-8    [[1, 16, 10, 10]]      [1, 16, 5, 5]            0
  Linear-10         [[1, 400]]            [1, 120]           48,120
  Linear-11         [[1, 120]]            [1, 84]            10,164
  Linear-12         [[1, 84]]             [1, 10]              850
===========================================================================
Total params: 61,610
Trainable params: 61,610
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.11
Params size (MB): 0.24
Estimated Total Size (MB): 0.35
---------------------------------------------------------------------------
{'total_params': 61610, 'trainable_params': 61610}