decorate

paddle.amp.decorate(models, optimizers=None, level='O1', dtype='float16', master_weight=None, save_dtype=None, master_grad=False, excluded_layers=None) [source]

Decorate models and optimizers for auto mixed precision (AMP). When level is 'O1' (AMP), decorate does nothing. When level is 'O2' (pure float16/bfloat16), decorate casts all parameters of the models to float16/bfloat16, except for BatchNorm, InstanceNorm and LayerNorm layers.

Commonly, it is used together with auto_cast to achieve pure float16/bfloat16 training in imperative mode.

Parameters
  • models (Layer|list of Layer) – The models defined by the user; either a single model or a list of models.

  • optimizers (Optimizer|list of Optimizer, optional) – The optimizers defined by the user; either a single optimizer or a list of optimizers. Default is None.

  • level (str, optional) – Auto mixed precision level. Accepted values are 'O1' and 'O2': 'O1' means mixed precision, in which the decorator does nothing; 'O2' means pure float16/bfloat16, in which the decorator casts all parameters of the models to float16/bfloat16, except for BatchNorm, InstanceNorm and LayerNorm. Default is 'O1' (AMP).

  • dtype (str, optional) – The low-precision data type, either 'float16' or 'bfloat16'. Default is 'float16'.

  • master_weight (bool, optional) – For level='O2', whether to use multi-precision (float32 master weights) during weight updates. If master_weight is None, the optimizer uses multi-precision at the O2 level. Default is None.

  • save_dtype (str, optional) – The dtype of saved model parameters when using paddle.save or paddle.jit.save; one of 'float16', 'bfloat16', 'float32', 'float64' or None. save_dtype does not change the dtype of the model parameters, only the dtype of the state_dict. When save_dtype is None, parameters are saved in the model's dtype. Default is None.

  • master_grad (bool, optional) – For level='O2', whether to use float32 weight gradients for computations such as gradient clipping, weight decay, and weight updates. If master_grad is enabled, the weight gradients are kept in float32 after backpropagation. Default is False, in which case only float16 weight gradients exist.

  • excluded_layers (Layer|list of Layer, optional) – The layers that should not be decorated. The weights of these layers are always kept in float32 when level is 'O2'. excluded_layers can be a Layer instance/type or a list of Layer instances/types; see the illustrative sketch in Demo4 below. Default is None, in which case the weights of the whole model are cast to float16 or bfloat16.

Examples

>>> # Demo1: single model and optimizer:
>>> import paddle
>>> paddle.device.set_device('gpu')

>>> model = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
>>> optimizer = paddle.optimizer.SGD(parameters=model.parameters())

>>> model, optimizer = paddle.amp.decorate(models=model, optimizers=optimizer, level='O2')

>>> data = paddle.rand([10, 3, 32, 32])

>>> with paddle.amp.auto_cast(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
...     output = model(data)
...     print(output.dtype)
paddle.float16
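
>>> # A minimal training-step sketch (not part of the verified demos): it assumes a
>>> # GPU and combines decorate/auto_cast with paddle.amp.GradScaler for loss scaling.
>>> scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
>>> with paddle.amp.auto_cast(enable=True, level='O2'):
...     loss = model(data).mean()
>>> scaled = scaler.scale(loss)          # scale the loss before backward
>>> scaled.backward()                    # low-precision backward pass
>>> scaler.minimize(optimizer, scaled)   # unscale gradients and update weights
>>> optimizer.clear_grad()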

>>> # Demo2: multi models and optimizers:
>>> model2 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
>>> optimizer2 = paddle.optimizer.Adam(parameters=model2.parameters())

>>> models, optimizers = paddle.amp.decorate(models=[model, model2], optimizers=[optimizer, optimizer2], level='O2')

>>> data = paddle.rand([10, 3, 32, 32])

>>> with paddle.amp.auto_cast(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
...     output = models[0](data)
...     output2 = models[1](data)
...     print(output.dtype)
...     print(output2.dtype)
paddle.float16
paddle.float16

>>> # Demo3: optimizers is None:
>>> model3 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
>>> optimizer3 = paddle.optimizer.Adam(parameters=model3.parameters())

>>> model = paddle.amp.decorate(models=model3, level='O2')

>>> data = paddle.rand([10, 3, 32, 32])

>>> with paddle.amp.auto_cast(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
...     output = model(data)
...     print(output.dtype)
paddle.float16
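
>>> # Demo4 (illustrative sketch, not output-verified): keep Conv2D weights in
>>> # float32 via excluded_layers, keep float32 weight gradients via master_grad,
>>> # and save float32 parameters via save_dtype.
>>> model4 = paddle.nn.Sequential(
...     paddle.nn.Conv2D(3, 2, 3, bias_attr=False),
...     paddle.nn.BatchNorm2D(2),
... )
>>> optimizer4 = paddle.optimizer.Adam(parameters=model4.parameters())
>>> model4, optimizer4 = paddle.amp.decorate(
...     models=model4,
...     optimizers=optimizer4,
...     level='O2',
...     master_grad=True,
...     save_dtype='float32',
...     excluded_layers=[paddle.nn.Conv2D],
... )
>>> # The Conv2D weights stay float32 because Conv2D is listed in excluded_layers;
>>> # paddle.save(model4.state_dict(), ...) would store float32 tensors because
>>> # save_dtype='float32'.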
