recompute_hybrid

paddle.incubate.distributed.fleet.recompute_hybrid(ctx, function, *args, **kwargs)

Recompute intermediate activations to save memory in hybrid-parallel training.

Note (shenliang03): the current hybrid-parallel recompute has the following limitations:

1. It cannot handle outputs of the recomputed function that include tensors which do not require gradients.
2. It cannot handle a forward output tensor that has no gradient. This can be worked around temporarily by calling detach() on that tensor.
3. It decides whether an output tensor needs a gradient solely from its dtype: only floating-point outputs are treated as requiring gradients.
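The recompute idea itself (often called activation checkpointing) can be illustrated without Paddle. The following is a minimal pure-Python sketch, not Paddle's implementation: each hypothetical "layer" is a (function, derivative) pair on scalars. A normal forward pass keeps every intermediate input for backward, while the recompute variant keeps only the segment's input and re-runs the forward pass inside backward to rebuild the activations. The hybrid-parallel options in ctx (mp_group, offload, partition) are deliberately ignored here.

```python
import math
from typing import Callable, List, Tuple

# Hypothetical scalar "layer": (forward function, its derivative).
Layer = Tuple[Callable[[float], float], Callable[[float], float]]


def forward_storing_all(layers: List[Layer], x: float) -> Tuple[float, List[float]]:
    """Ordinary forward pass: keep the input of every layer for backward."""
    acts: List[float] = []
    for f, _ in layers:
        acts.append(x)
        x = f(x)
    return x, acts


def backward_from_stored(layers: List[Layer], acts: List[float]) -> float:
    """Chain rule over stored activations: returns d(output)/d(input)."""
    grad = 1.0
    for (_, df), a in zip(reversed(layers), reversed(acts)):
        grad *= df(a)
    return grad


def forward_recompute(layers: List[Layer], x: float) -> Tuple[float, float]:
    """Recompute-style forward: save only the segment input, drop intermediates."""
    saved_input = x
    for f, _ in layers:
        x = f(x)
    return x, saved_input


def backward_recompute(layers: List[Layer], saved_input: float) -> float:
    """Re-run the forward pass from the saved input, then apply the chain rule."""
    _, acts = forward_storing_all(layers, saved_input)
    return backward_from_stored(layers, acts)


if __name__ == "__main__":
    layers = [(math.sin, math.cos), (lambda v: v * v, lambda v: 2 * v), (math.exp, math.exp)]
    out, saved = forward_recompute(layers, 0.5)
    print(out, backward_recompute(layers, saved))
```

Both backward paths produce the same gradient; the recompute variant trades one extra forward pass for storing a single value instead of one value per layer, which is the memory saving recompute_hybrid aims for.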

Parameters
  • ctx (dict) – Contains the keys ‘mp_group’, ‘offload’, and ‘partition’. ‘mp_group’ (Group) is the model-parallel group across which activations are split. ‘offload’ (bool, optional, default=False) controls whether activations are offloaded to CPU. ‘partition’ (bool, optional, default=False) controls whether activations are split across the ranks of mp_group.

  • function (paddle.nn.Layer) – A layer or sequence of layers describing part of the model's forward pass. Its intermediate activations are released during the forward pass to save memory and recomputed during the backward pass for gradient calculation.

  • *args (Tensor) – Positional inputs (tuple) to function.

  • **kwargs (Dict) – Keyword inputs (dict) to function.
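To illustrate what ‘partition’ describes, here is a minimal pure-Python sketch that simulates the ranks of mp_group inside one process. It assumes the saved activation is flattened and split into equal contiguous slices, one per rank, and that the full activation is reassembled by gathering every slice (an all-gather) before recomputation. split_activation and merge_activation are hypothetical helpers for illustration, not Paddle APIs; in real use the ctx would be something like {'mp_group': <Group>, 'offload': False, 'partition': True}.

```python
from typing import List


def split_activation(flat: List[float], nranks: int, rank: int) -> List[float]:
    """Hypothetical helper: return the contiguous slice that `rank` keeps."""
    if nranks < 2:
        return flat  # a single rank keeps the whole activation
    if len(flat) % nranks != 0:
        raise ValueError("activation size must be divisible by the group size")
    part = len(flat) // nranks
    start = part * rank
    return flat[start:start + part]


def merge_activation(parts: List[List[float]]) -> List[float]:
    """Hypothetical helper simulating an all-gather: concatenate slices in rank order."""
    merged: List[float] = []
    for p in parts:
        merged.extend(p)
    return merged


if __name__ == "__main__":
    activation = [float(i) for i in range(8)]
    # Each simulated rank of a 4-rank group keeps 2 of the 8 values.
    shards = [split_activation(activation, 4, r) for r in range(4)]
    print(shards, merge_activation(shards) == activation)
```

Under these assumptions each rank stores only 1/nranks of every saved activation, at the cost of one gather before the backward-pass recomputation.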

Returns

The output of function called with args and kwargs.