summary

paddle.summary(net, input_size=None, dtypes=None, input=None) [源代码]

通过 input_size 或 input 打印网络 net 的基础结构和参数信息。input_size 指定网络 net 输入 Tensor 的大小,而 input 指定网络 net 的输入 Tensor;如果给出 input,那么 input_size 和 dtypes 的输入将被忽略。

参数

  • net (Layer) - 网络实例,必须是 Layer 的子类。

  • input_size (tuple|InputSpec|list[tuple|InputSpec],可选) - 输入 Tensor 的大小。如果网络只有一个输入,那么该值需要设定为 tuple 或 InputSpec;如果模型有多个输入,那么该值需要设定为 list[tuple|InputSpec],包含每个输入的 shape。默认值:None。

  • dtypes (str|list[str],可选) - 输入 Tensor 的数据类型。当网络有多个输入时,可传入 list[str],分别指定每个输入的数据类型(见代码示例 2)。如果没有给定,默认使用 float32 类型。默认值:None。

  • input (Tensor|list[Tensor]|dict,可选) - 输入的 Tensor,也可以是由 Tensor 组成的 list 或 dict(见代码示例 3)。如果给出 input,那么 input_size 和 dtypes 的输入将被忽略。默认值:None。

返回

字典,包含总参数量(键为 total_params)和可训练参数量(键为 trainable_params)。

代码示例 1

 >>> # example 1: Single Input Demo
 >>> import paddle
 >>> import paddle.nn as nn
 >>> # Define Network
 >>> class LeNet(nn.Layer):
 ...     def __init__(self, num_classes=10):
 ...         super().__init__()
 ...         self.num_classes = num_classes
 ...         self.features = nn.Sequential(
 ...             nn.Conv2D(1, 6, 3, stride=1, padding=1),
 ...             nn.ReLU(),
 ...             nn.MaxPool2D(2, 2),
 ...             nn.Conv2D(6, 16, 5, stride=1, padding=0),
 ...             nn.ReLU(),
 ...             nn.MaxPool2D(2, 2))
 ...
 ...         if num_classes > 0:
 ...             self.fc = nn.Sequential(
 ...                 nn.Linear(400, 120),
 ...                 nn.Linear(120, 84),
 ...                 nn.Linear(84, 10))
 ...
 ...     def forward(self, inputs):
 ...         x = self.features(inputs)
 ...
 ...         if self.num_classes > 0:
 ...             x = paddle.flatten(x, 1)
 ...             x = self.fc(x)
 ...         return x
 ...
 >>> lenet = LeNet()
 >>> params_info = paddle.summary(lenet, (1, 1, 28, 28)) # doctest: +NORMALIZE_WHITESPACE
 ---------------------------------------------------------------------------
  Layer (type)       Input Shape          Output Shape         Param #
 ===========================================================================
    Conv2D-1       [[1, 1, 28, 28]]      [1, 6, 28, 28]          60
     ReLU-1        [[1, 6, 28, 28]]      [1, 6, 28, 28]           0
   MaxPool2D-1     [[1, 6, 28, 28]]      [1, 6, 14, 14]           0
    Conv2D-2       [[1, 6, 14, 14]]     [1, 16, 10, 10]         2,416
     ReLU-2       [[1, 16, 10, 10]]     [1, 16, 10, 10]           0
   MaxPool2D-2    [[1, 16, 10, 10]]      [1, 16, 5, 5]            0
    Linear-1          [[1, 400]]            [1, 120]           48,120
    Linear-2          [[1, 120]]            [1, 84]            10,164
    Linear-3          [[1, 84]]             [1, 10]              850
 ===========================================================================
 Total params: 61,610
 Trainable params: 61,610
 Non-trainable params: 0
 ---------------------------------------------------------------------------
 Input size (MB): 0.00
 Forward/backward pass size (MB): 0.11
 Params size (MB): 0.24
 Estimated Total Size (MB): 0.35
 ---------------------------------------------------------------------------
 <BLANKLINE>
 >>> print(params_info)
 {'total_params': 61610, 'trainable_params': 61610}

代码示例 2

 >>> # example 2: multi input demo
 >>> import paddle
 >>> import paddle.nn as nn
 >>> class LeNetMultiInput(nn.Layer):
 ...     def __init__(self, num_classes=10):
 ...         super().__init__()
 ...         self.num_classes = num_classes
 ...         self.features = nn.Sequential(
 ...             nn.Conv2D(1, 6, 3, stride=1, padding=1),
 ...             nn.ReLU(),
 ...             nn.MaxPool2D(2, 2),
 ...             nn.Conv2D(6, 16, 5, stride=1, padding=0),
 ...             nn.ReLU(),
 ...             nn.MaxPool2D(2, 2))
 ...
 ...         if num_classes > 0:
 ...             self.fc = nn.Sequential(
 ...                 nn.Linear(400, 120),
 ...                 nn.Linear(120, 84),
 ...                 nn.Linear(84, 10))
 ...
 ...     def forward(self, inputs, y):
 ...         x = self.features(inputs)
 ...
 ...         if self.num_classes > 0:
 ...             x = paddle.flatten(x, 1)
 ...             x = self.fc(x + y)
 ...         return x
 ...
 >>> lenet_multi_input = LeNetMultiInput()

 >>> params_info = paddle.summary(lenet_multi_input,
 ...                              [(1, 1, 28, 28), (1, 400)],
 ...                              dtypes=['float32', 'float32']) # doctest: +NORMALIZE_WHITESPACE
 ---------------------------------------------------------------------------
  Layer (type)       Input Shape          Output Shape         Param #
 ===========================================================================
    Conv2D-1       [[1, 1, 28, 28]]      [1, 6, 28, 28]          60
     ReLU-1        [[1, 6, 28, 28]]      [1, 6, 28, 28]           0
   MaxPool2D-1     [[1, 6, 28, 28]]      [1, 6, 14, 14]           0
    Conv2D-2       [[1, 6, 14, 14]]     [1, 16, 10, 10]         2,416
     ReLU-2       [[1, 16, 10, 10]]     [1, 16, 10, 10]           0
   MaxPool2D-2    [[1, 16, 10, 10]]      [1, 16, 5, 5]            0
    Linear-1          [[1, 400]]            [1, 120]           48,120
    Linear-2          [[1, 120]]            [1, 84]            10,164
    Linear-3          [[1, 84]]             [1, 10]              850
 ===========================================================================
 Total params: 61,610
 Trainable params: 61,610
 Non-trainable params: 0
 ---------------------------------------------------------------------------
 Input size (MB): 0.00
 Forward/backward pass size (MB): 0.11
 Params size (MB): 0.24
 Estimated Total Size (MB): 0.35
 ---------------------------------------------------------------------------
 <BLANKLINE>
 >>> print(params_info)
 {'total_params': 61610, 'trainable_params': 61610}

代码示例 3

 >>> # example 3: List/Dict Input Demo
 >>> import paddle
 >>> import paddle.nn as nn

 >>> # list input demo
 >>> class LeNetListInput(nn.Layer):
 ...     def __init__(self, num_classes=10):
 ...         super().__init__()
 ...         self.num_classes = num_classes
 ...         self.features = nn.Sequential(
 ...             nn.Conv2D(1, 6, 3, stride=1, padding=1),
 ...             nn.ReLU(),
 ...             nn.MaxPool2D(2, 2),
 ...             nn.Conv2D(6, 16, 5, stride=1, padding=0),
 ...             nn.ReLU(),
 ...             nn.MaxPool2D(2, 2))
 ...
 ...         if num_classes > 0:
 ...             self.fc = nn.Sequential(
 ...                 nn.Linear(400, 120),
 ...                 nn.Linear(120, 84),
 ...                 nn.Linear(84, 10))
 ...
 ...     def forward(self, inputs):
 ...         x = self.features(inputs[0])
 ...
 ...         if self.num_classes > 0:
 ...             x = paddle.flatten(x, 1)
 ...             x = self.fc(x + inputs[1])
 ...         return x
 ...
 >>> lenet_list_input = LeNetListInput()
 >>> input_data = [paddle.rand([1, 1, 28, 28]), paddle.rand([1, 400])]
 >>> params_info = paddle.summary(lenet_list_input, input=input_data) # doctest: +NORMALIZE_WHITESPACE
 ---------------------------------------------------------------------------
  Layer (type)       Input Shape          Output Shape         Param #
 ===========================================================================
    Conv2D-1       [[1, 1, 28, 28]]      [1, 6, 28, 28]          60
     ReLU-1        [[1, 6, 28, 28]]      [1, 6, 28, 28]           0
   MaxPool2D-1     [[1, 6, 28, 28]]      [1, 6, 14, 14]           0
    Conv2D-2       [[1, 6, 14, 14]]     [1, 16, 10, 10]         2,416
     ReLU-2       [[1, 16, 10, 10]]     [1, 16, 10, 10]           0
   MaxPool2D-2    [[1, 16, 10, 10]]      [1, 16, 5, 5]            0
    Linear-1          [[1, 400]]            [1, 120]           48,120
    Linear-2          [[1, 120]]            [1, 84]            10,164
    Linear-3          [[1, 84]]             [1, 10]              850
 ===========================================================================
 Total params: 61,610
 Trainable params: 61,610
 Non-trainable params: 0
 ---------------------------------------------------------------------------
 Input size (MB): 0.00
 Forward/backward pass size (MB): 0.11
 Params size (MB): 0.24
 Estimated Total Size (MB): 0.35
 ---------------------------------------------------------------------------
 <BLANKLINE>
 >>> print(params_info)
 {'total_params': 61610, 'trainable_params': 61610}

 >>> # dict input demo
 >>> class LeNetDictInput(nn.Layer):
 ...     def __init__(self, num_classes=10):
 ...         super().__init__()
 ...         self.num_classes = num_classes
 ...         self.features = nn.Sequential(
 ...             nn.Conv2D(1, 6, 3, stride=1, padding=1),
 ...             nn.ReLU(),
 ...             nn.MaxPool2D(2, 2),
 ...             nn.Conv2D(6, 16, 5, stride=1, padding=0),
 ...             nn.ReLU(),
 ...             nn.MaxPool2D(2, 2))
 ...
 ...         if num_classes > 0:
 ...             self.fc = nn.Sequential(
 ...                 nn.Linear(400, 120),
 ...                 nn.Linear(120, 84),
 ...                 nn.Linear(84, 10))
 ...
 ...     def forward(self, inputs):
 ...         x = self.features(inputs['x1'])
 ...
 ...         if self.num_classes > 0:
 ...             x = paddle.flatten(x, 1)
 ...             x = self.fc(x + inputs['x2'])
 ...         return x
 ...
 >>> lenet_dict_input = LeNetDictInput()
 >>> input_data = {'x1': paddle.rand([1, 1, 28, 28]),
 ...               'x2': paddle.rand([1, 400])}
 >>> # The numeric suffix of each layer name distinguishes layers of the same type; it continues counting from the layers created in the list-input demo above
 >>> params_info = paddle.summary(lenet_dict_input, input=input_data) # doctest: +NORMALIZE_WHITESPACE
 ---------------------------------------------------------------------------
  Layer (type)       Input Shape          Output Shape         Param #
 ===========================================================================
    Conv2D-3       [[1, 1, 28, 28]]      [1, 6, 28, 28]          60
     ReLU-3        [[1, 6, 28, 28]]      [1, 6, 28, 28]           0
   MaxPool2D-3     [[1, 6, 28, 28]]      [1, 6, 14, 14]           0
    Conv2D-4       [[1, 6, 14, 14]]     [1, 16, 10, 10]         2,416
     ReLU-4       [[1, 16, 10, 10]]     [1, 16, 10, 10]           0
   MaxPool2D-4    [[1, 16, 10, 10]]      [1, 16, 5, 5]            0
    Linear-4          [[1, 400]]            [1, 120]           48,120
    Linear-5          [[1, 120]]            [1, 84]            10,164
    Linear-6          [[1, 84]]             [1, 10]              850
 ===========================================================================
 Total params: 61,610
 Trainable params: 61,610
 Non-trainable params: 0
 ---------------------------------------------------------------------------
 Input size (MB): 0.00
 Forward/backward pass size (MB): 0.11
 Params size (MB): 0.24
 Estimated Total Size (MB): 0.35
 ---------------------------------------------------------------------------
 <BLANKLINE>
 >>> print(params_info)
 {'total_params': 61610, 'trainable_params': 61610}