fused_dropout_add

paddle.incubate.nn.functional.fused_dropout_add(x, y, p=0.5, training=True, mode='upscale_in_train', name=None) [source]

Fused Dropout and Add: applies dropout to x and adds y in a single fused operation.

Parameters
  • x (Tensor) – The input tensor to which dropout is applied. The data type is bfloat16, float16, float32 or float64.

  • y (Tensor) – The tensor added to the dropout result. The data type is bfloat16, float16, float32 or float64.

  • p (float|int, optional) – Probability of setting units to zero. Default: 0.5.

  • training (bool, optional) – A flag indicating whether it is in the training phase or not. Default: True.

  • mode (str, optional) –

    [‘upscale_in_train’ (default) | ‘downscale_in_infer’]; see the unfused reference sketch below.

    1. upscale_in_train (default), scale the output up at training time

      • train: \(out = x \times \frac{mask}{(1.0 - p)} + y\)

      • inference: \(out = x + y\)

    2. downscale_in_infer, scale the output down at inference time

      • train: \(out = x \times mask + y\)

      • inference: \(out = x \times (1.0 - p) + y\)

  • name (str, optional) – Name for the operation. Default: None. For more information, please refer to Name.

Returns

A Tensor representing the fused dropout and add, with the same shape and data type as x.
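
Semantically, fused_dropout_add(x, y, p) computes dropout(x) followed by an elementwise add of y. A minimal unfused reference sketch, assuming ‘upscale_in_train’ mode (the fused kernel draws its own random mask, so a separate dropout call will not reproduce its values bit-for-bit):

>>> import paddle
>>> import paddle.nn.functional as F
>>> x = paddle.randn([4, 10], dtype="float32")
>>> y = paddle.randn([4, 10], dtype="float32")
>>> # unfused reference: dropout on x, then elementwise add of y
>>> ref = F.dropout(x, p=0.5, training=True, mode='upscale_in_train') + y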

Examples

>>> import paddle
>>> from paddle.incubate.nn.functional import fused_dropout_add

>>> paddle.set_device('gpu')  # the fused op runs on GPU
>>> paddle.seed(2023)         # fix the random mask for reproducibility
>>> x = paddle.randn([4, 10], dtype="float32")
>>> y = paddle.randn([4, 10], dtype="float32")
>>> out = fused_dropout_add(x, y, p=0.5)
>>> print(out)
Tensor(shape=[4, 10], dtype=float32, place=Place(gpu:0), stop_gradient=True,
[[-0.49133155,  0.53819323, -2.58393312,  0.06336236, -1.09908366,
   0.22085167,  2.19751787,  0.05034769,  0.53417486,  0.84864247],
 [ 0.78248203, -1.59652555, -0.14399840, -0.77985179, -0.17006736,
  -0.30991879, -0.36593807, -0.51025450,  1.46401680,  0.61627960],
 [ 4.50472546, -0.48472026,  0.60729283,  0.33509624, -0.25593102,
  -1.45173049,  1.06727099,  0.00440830, -0.77340341,  0.67393088],
 [ 1.29453969,  0.07568165,  0.71947742, -0.71768606, -2.57172823,
   1.89179027,  3.26482797,  1.10493207, -1.04569530, -1.04862499]])
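
At inference time (‘upscale_in_train’ mode with training=False), the documented formula reduces to a plain add, so the output can be checked directly; a quick sketch continuing the example above:

>>> out_eval = fused_dropout_add(x, y, p=0.5, training=False)
>>> # per the documented inference formula, out = x + y
>>> assert paddle.allclose(out_eval, x + y)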