amp_guard

paddle.fluid.dygraph.amp.auto_cast.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O1', dtype='float16')
Api_attr: imperative

Creates a context that enables automatic mixed precision (AMP) for operators executed in dynamic graph (imperative) mode. When enabled, the autocast algorithm decides the input data type of each operator (float32, or the low-precision type selected by dtype) to improve performance.

It is commonly used together with GradScaler to train with automatic mixed precision in imperative mode, and together with the decorator (paddle.amp.decorate) to train in pure fp16 in imperative mode.
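A "context" here means the AMP settings apply only inside a with block and are undone when the block exits. The pure-Python sketch below models that scoping. It is a toy, not Paddle's implementation: AmpState, current_amp_state and toy_amp_guard are hypothetical names, and the real guard also changes how the tracer casts operator inputs.

```python
import contextlib
from dataclasses import dataclass


@dataclass(frozen=True)
class AmpState:
    """Hypothetical record of the AMP settings active for the current scope."""
    enabled: bool = False
    level: str = "O1"
    dtype: str = "float16"


# Module-level state; AMP is off until a guard turns it on.
_current = AmpState()


def current_amp_state() -> AmpState:
    return _current


@contextlib.contextmanager
def toy_amp_guard(enable=True, level="O1", dtype="float16"):
    """Toy stand-in for amp_guard: swap in new settings, restore the old ones on exit."""
    global _current
    if level not in ("O1", "O2"):
        raise ValueError(f"level must be 'O1' or 'O2', got {level!r}")
    if dtype not in ("float16", "bfloat16"):
        raise ValueError(f"dtype must be 'float16' or 'bfloat16', got {dtype!r}")
    previous = _current
    _current = AmpState(enabled=enable, level=level, dtype=dtype)
    try:
        yield
    finally:
        # Restore even if the body raised, so the settings never leak out of the block.
        _current = previous
```

Saving and restoring the previous state, rather than simply switching AMP off on exit, is what lets a toy_amp_guard(enable=False) block sit inside an enabled one and hand control back afterwards.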

Parameters
  • enable (bool, optional) – Whether to enable automatic mixed precision. Default is True.

  • custom_white_list (set|list|tuple, optional) – A custom white list: the set of ops that support fp16 computation and are considered numerically safe and performance-critical. These ops will be converted to fp16.

  • custom_black_list (set|list|tuple, optional) – A custom black list: the set of ops that support fp16 computation but are considered numerically dangerous, and whose effects may also be observed in downstream ops. These ops will not be converted to fp16.

  • level (str, optional) – Automatic mixed precision level. Accepted values are “O1” and “O2”. O1 means mixed precision: the input data type of each operator is cast according to the white list and black list. O2 means pure fp16: the parameters and input data of all operators are cast to fp16, except for operators in the black list, operators without an fp16 kernel, and batch norm. Default is “O1” (AMP).

  • dtype (str, optional) – The low-precision data type to use, either ‘float16’ or ‘bfloat16’. Default is ‘float16’.
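The rules in the parameter list can be modeled in a few lines of plain Python. This is a sketch under stated assumptions, not Paddle's code: DEFAULT_WHITE and DEFAULT_BLACK are a tiny hypothetical subset of the built-in lists, merge_lists and op_dtype are hypothetical helpers, raising an error when an op appears in both custom lists is an assumed design choice, and under O1 ops in neither list are assumed to keep their input dtype (the parameter description does not say).

```python
from typing import Iterable, Optional, Set, Tuple

# Hypothetical defaults: a tiny illustrative subset, NOT Paddle's built-in lists.
DEFAULT_WHITE = frozenset({"conv2d", "matmul"})
DEFAULT_BLACK = frozenset({"exp", "softmax_with_cross_entropy"})


def merge_lists(
    custom_white: Optional[Iterable[str]] = None,
    custom_black: Optional[Iterable[str]] = None,
) -> Tuple[Set[str], Set[str]]:
    """Combine defaults with user lists; a custom entry overrides the opposite default list."""
    cw = set(custom_white or ())
    cb = set(custom_black or ())
    overlap = cw & cb
    if overlap:
        # Assumed behavior: an op cannot be both forced to and forbidden from fp16.
        raise ValueError(f"op(s) in both custom lists: {sorted(overlap)}")
    white = (set(DEFAULT_WHITE) | cw) - cb
    black = (set(DEFAULT_BLACK) | cb) - cw
    return white, black


def op_dtype(op: str, input_dtype: str, level: str, white: Set[str], black: Set[str],
             has_low_precision_kernel: bool = True, low: str = "float16") -> str:
    """Return the dtype an op would run in under this simplified model."""
    if level == "O1":
        if op in white:
            return low
        if op in black:
            return "float32"
        return input_dtype  # assumption: ops in neither list keep their input dtype
    if level == "O2":
        # Pure low precision, except black-listed ops, ops lacking a kernel, and batch norm.
        if op in black or not has_low_precision_kernel or op == "batch_norm":
            return "float32"
        return low
    raise ValueError(f"level must be 'O1' or 'O2', got {level!r}")
```

For example, merge_lists(custom_white=["exp"]) moves exp out of the default black list, so op_dtype("exp", "float32", "O1", *merge_lists(custom_white=["exp"])) returns "float16".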

Examples

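import numpy as np
import paddle

# Random NCHW input: batch of 10, 3 channels, 32x32 pixels
data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
with paddle.fluid.dygraph.guard():
    # Conv2D(num_channels=3, num_filters=2, filter_size=3)
    conv2d = paddle.fluid.dygraph.Conv2D(3, 2, 3)
    data = paddle.fluid.dygraph.to_variable(data)
    # conv2d is in the default white list, so AMP runs it in fp16
    with paddle.fluid.dygraph.amp_guard():
        conv = conv2d(data)
        print(conv.dtype) # FP16 (expected on a GPU; on CPU, AMP may have no effect)
    # With AMP disabled, the op runs in the input's original fp32
    with paddle.fluid.dygraph.amp_guard(enable=False):
        conv = conv2d(data)
        print(conv.dtype) # FP32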
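The black list exists because fp16 has a narrow range (its largest finite value is 65504). The stdlib-only sketch below, an illustration rather than anything Paddle executes, rounds each intermediate of a naive softmax to half precision using Python's struct "e" format. exp(12) ≈ 162755 overflows to inf, and that inf then becomes NaN in the downstream division, the kind of propagated effect the custom_black_list description warns about. to_fp16 and naive_softmax are hypothetical helpers.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct rejects finite values beyond the fp16 range; saturate to +/-inf
        # as a round-to-nearest cast to fp16 would.
        return math.copysign(math.inf, x)


def naive_softmax(logits, cast=lambda v: v):
    """Softmax without max-subtraction; `cast` rounds every intermediate result."""
    exps = [cast(math.exp(v)) for v in logits]
    total = cast(sum(exps))
    return [cast(e / total) for e in exps]


reference = naive_softmax([12.0, 1.0])               # Python floats (double precision)
half = naive_softmax([12.0, 1.0], cast=to_fp16)      # every step rounded to fp16
print(reference)
print(half)
```

The double-precision result sums to 1, while the fp16 version yields NaN for the first class and 0.0 for the second: the overflow in exp corrupts every op that consumes its output.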