SyncBatchNorm

class paddle.nn.SyncBatchNorm(num_features, momentum=0.9, epsilon=1e-05, weight_attr=None, bias_attr=None, data_format='NCHW', name=None) [source]

This interface is used to construct a callable object of the SyncBatchNorm class. It implements the function of the Cross-GPU Synchronized Batch Normalization layer and can be used as a normalizer function for other operations, such as conv2d and fully connected operations. The data is normalized by the mean and variance of the channel computed over the whole mini-batch, including the data on all GPUs. Refer to Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift for more details.

When the model is in training mode, \(\mu_{\beta}\) and \(\sigma_{\beta}^{2}\) are the statistics of the whole mini-batch data across all GPUs, calculated as follows:

\[\begin{split}
\mu_{\beta} &\gets \frac{1}{m} \sum_{i=1}^{m} x_i \qquad &//\ \text{mini-batch mean} \\
\sigma_{\beta}^{2} &\gets \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_{\beta})^2 \qquad &//\ \text{mini-batch variance}
\end{split}\]
  • \(x\) : the whole mini-batch data across all GPUs

  • \(m\) : the size of the whole mini-batch data

When the model is in evaluation mode, \(\mu_{\beta}\) and \(\sigma_{\beta}^{2}\) are global statistics (moving_mean and moving_variance, which are usually obtained from a pre-trained model). The global statistics are calculated as follows:

\[\begin{split}
moving\_mean &= moving\_mean * momentum + \mu_{\beta} * (1. - momentum) \qquad &//\ \text{global mean} \\
moving\_variance &= moving\_variance * momentum + \sigma_{\beta}^{2} * (1. - momentum) \qquad &//\ \text{global variance}
\end{split}\]
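As a one-step numeric illustration of this update (the numbers here are made up): with momentum = 0.9, a stored moving_mean of 0.5, and a mini-batch mean \(\mu_{\beta} = 0.3\),

\[moving\_mean = 0.5 * 0.9 + 0.3 * (1. - 0.9) = 0.48\]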

The formula of normalization is as follows:

\[\begin{split}
\hat{x_i} &\gets \frac{x_i - \mu_{\beta}}{\sqrt{\sigma_{\beta}^{2} + \epsilon}} \qquad &//\ \text{normalize} \\
y_i &\gets \gamma \hat{x_i} + \beta \qquad &//\ \text{scale and shift}
\end{split}\]
  • \(\epsilon\) : a small value added to the variance to prevent division by zero

  • \(\gamma\) : trainable scale parameter vector

  • \(\beta\) : trainable shift parameter vector
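The following is a minimal single-process sketch of the formulas above using plain tensor ops (illustrative only; the real layer runs a fused kernel and, in distributed mode, additionally aggregates \(\mu_{\beta}\) and \(\sigma_{\beta}^{2}\) across all GPUs):

>>> import paddle
>>> x = paddle.rand([4, 3, 8, 8])                             # NCHW input
>>> mu = x.mean(axis=[0, 2, 3], keepdim=True)                 # mini-batch mean per channel
>>> var = ((x - mu) ** 2).mean(axis=[0, 2, 3], keepdim=True)  # mini-batch variance per channel
>>> x_hat = (x - mu) / paddle.sqrt(var + 1e-5)                # normalize
>>> gamma = paddle.ones([1, 3, 1, 1])                         # scale (gamma), initialized with ones
>>> beta = paddle.zeros([1, 3, 1, 1])                         # shift (beta), initialized with zeros
>>> y = gamma * x_hat + beta                                  # scale and shift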

Note

If you want to use a container to pack your model and the model contains SyncBatchNorm used in the evaluation phase, please use LayerList or Sequential instead of a plain Python list to pack the model.
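For example, a hedged sketch of this guidance (the model here is made up): sub-layers held in a LayerList are registered with the parent layer, so model.eval() switches the contained SyncBatchNorm layers to evaluation mode, whereas layers hidden in a plain Python list would be skipped:

>>> import paddle.nn as nn

>>> class Model(nn.Layer):
...     def __init__(self):
...         super().__init__()
...         # Registered container: eval(), train() and parameters() see these layers.
...         self.blocks = nn.LayerList([nn.SyncBatchNorm(8), nn.SyncBatchNorm(8)])
...
...     def forward(self, x):
...         for block in self.blocks:
...             x = block(x)
...         return x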

Parameters
  • num_features (int) – Indicates the number of channels of the input Tensor.

  • epsilon (float, optional) – The small value added to the variance to prevent division by zero. Default: 1e-5.

  • momentum (float, optional) – The value used for the moving_mean and moving_variance computation. Default: 0.9.

  • weight_attr (ParamAttr|bool, optional) – The parameter attribute for the scale parameter of this layer. If it is set to None or one attribute of ParamAttr, this layer will create ParamAttr as param_attr. If the Initializer of the param_attr is not set, the parameter is initialized with ones. If it is set to False, this layer will not have a trainable scale parameter. Default: None.

  • bias_attr (ParamAttr|bool, optional) – The parameter attribute for the bias of this layer. If it is set to None or one attribute of ParamAttr, this layer will create ParamAttr as bias_attr. If the Initializer of the bias_attr is not set, the bias is initialized to zero. If it is set to False, this layer will not have a trainable bias parameter. Default: None.

  • data_format (str, optional) – Specify the input data format. Default: 'NCHW'.

  • name (str, optional) – Name for the layer. Normally there is no need to set this property. Default: None.

Shapes:
  • input: a Tensor whose dimension is from 2 to 5 (a 2-D to 5-D Tensor).

  • output: Tensor with the same shape as input.

Examples

>>> import paddle
>>> import paddle.nn as nn
>>> if paddle.is_compiled_with_cuda():
...     paddle.device.set_device('gpu')
...     x = paddle.to_tensor([[[[0.3, 0.4], [0.3, 0.07]], [[0.83, 0.37], [0.18, 0.93]]]]).astype('float32')
...     sync_batch_norm = nn.SyncBatchNorm(2)
...     hidden1 = sync_batch_norm(x)
...     print(hidden1)
Tensor(shape=[1, 2, 2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=False,
[[[[ 0.26824948,  1.09363246],
   [ 0.26824948, -1.63013160]],
  [[ 0.80956620, -0.66528702],
   [-1.27446556,  1.13018656]]]])
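
Cross-GPU synchronization only takes effect when the program runs in data-parallel mode; in a single process the layer behaves like ordinary batch normalization. Below is a hedged sketch of a typical multi-GPU setup, assuming the script is started with python -m paddle.distributed.launch:

>>> import paddle
>>> import paddle.nn as nn
>>> import paddle.distributed as dist

>>> if paddle.is_compiled_with_cuda():
...     dist.init_parallel_env()                    # set up the process group
...     model = nn.Sequential(nn.Conv2D(3, 8, 3), nn.SyncBatchNorm(8))
...     model = paddle.DataParallel(model)          # wrap for data-parallel training
...     x = paddle.rand([2, 3, 16, 16])
...     y = model(x)                                # batch statistics are synced across GPUs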
forward(x)

Defines the computation performed at every call. Should be overridden by all subclasses.

Parameters
  • x (Tensor) – The input tensor, with dimension from 2 to 5.

classmethod convert_sync_batchnorm(layer)

Helper function to convert paddle.nn.BatchNorm*d layers in the model to paddle.nn.SyncBatchNorm layers.

Parameters

layer (paddle.nn.Layer) – The model containing one or more BatchNorm*d layers.

Returns

The original model with its BatchNorm*d layers replaced by SyncBatchNorm layers.

Examples

>>> import paddle
>>> import paddle.nn as nn

>>> model = nn.Sequential(nn.Conv2D(3, 5, 3), nn.BatchNorm2D(5))
>>> sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
>>> print(sync_model)
Sequential(
    (0): Conv2D(3, 5, kernel_size=[3, 3], data_format=NCHW)
    (1): SyncBatchNorm(num_features=5, momentum=0.9, epsilon=1e-05)
)
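
The conversion recurses through nested containers and, for each replaced layer, carries over its parameters and running statistics (an assumption based on the behavior described above; verify against your Paddle version), so a pre-trained model can typically be converted without re-initialization:

>>> deep = nn.Sequential(nn.Sequential(nn.BatchNorm1D(4)), nn.BatchNorm3D(4))
>>> converted = nn.SyncBatchNorm.convert_sync_batchnorm(deep)
>>> print(isinstance(converted[0][0], nn.SyncBatchNorm))
True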