paddle.fluid.contrib.mixed_precision.bf16.amp_utils.bf16_guard() [source]

In pure bf16 training, if users set use_bf16_guard to True, only those ops created within the bf16_guard context manager will be cast to bfloat16.


import paddle
import paddle.nn.functional as F

paddle.enable_static()

data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)

with paddle.static.amp.bf16_guard():
    bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
    pool = F.max_pool2d(bn, kernel_size=2, stride=2)
    hidden = paddle.static.nn.fc(pool, size=10)
    loss = paddle.mean(hidden)