recompute_sequential

paddle.incubate.distributed.fleet.recompute_sequential(ctx, functions, *args, **kwargs)

Recomputes intermediate activations to save memory for ‘Sequential’ models. Context parameters are passed through ‘ctx’, similar to the ‘recompute_hybrid’ API.
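
The idea behind recomputation (often called activation or gradient checkpointing) can be illustrated without a framework. The sketch below is a hypothetical illustration, not Paddle's implementation. `Layer`, `split_segments`, `forward_with_checkpoints` and `backward_with_recompute` are made-up names. Each layer carries a hand-written derivative in place of autograd. Splitting into contiguous chunks, with the last chunk absorbing any remainder, is also an assumption. During the forward pass only the input of each segment is kept. During the backward pass, each segment's intermediate activations are recomputed from that input before the chain rule is applied.

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass
class Layer:
    """Hypothetical scalar layer: a function and its derivative."""
    f: Callable[[float], float]
    df: Callable[[float], float]


def split_segments(n_layers: int, segments: int) -> List[range]:
    """Split layer indices into contiguous chunks (assumption: last chunk takes the remainder)."""
    size = n_layers // segments
    bounds = [range(i * size, (i + 1) * size) for i in range(segments - 1)]
    bounds.append(range((segments - 1) * size, n_layers))
    return bounds


def forward_with_checkpoints(layers: List[Layer], x: float, segments: int) -> Tuple[float, List[float]]:
    """Run the forward pass, storing only the input of each segment."""
    checkpoints = []
    for seg in split_segments(len(layers), segments):
        checkpoints.append(x)      # keep the segment input only
        for i in seg:
            x = layers[i].f(x)     # intermediate activations are not stored
    return x, checkpoints


def backward_with_recompute(layers: List[Layer], checkpoints: List[float],
                            segments: int, grad_out: float = 1.0) -> float:
    """Compute d(output)/d(input), recomputing each segment from its checkpoint."""
    grad = grad_out
    segs = split_segments(len(layers), segments)
    for seg, x in zip(reversed(segs), reversed(checkpoints)):
        # Recompute the input of every layer in this segment.
        inputs = []
        for i in seg:
            inputs.append(x)
            x = layers[i].f(x)
        # Apply the chain rule, walking the segment backwards.
        for i, xi in zip(reversed(seg), reversed(inputs)):
            grad *= layers[i].df(xi)
    return grad


# Usage: exp(sin(x)^2) as three layers split into two segments.
layers = [
    Layer(math.sin, math.cos),                # sin(x)
    Layer(lambda v: v * v, lambda v: 2 * v),  # x^2
    Layer(math.exp, math.exp),                # exp(x)
]
y, ckpts = forward_with_checkpoints(layers, 0.5, segments=2)
g = backward_with_recompute(layers, ckpts, segments=2)
print(y, g)
```

In this sketch, more segments means more stored checkpoints but shorter chains to recompute at a time.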

Parameters
  • ctx (dict) – Contains the keys ‘segments’ and ‘preserve_rng_state’. ‘segments’ (int, default 1) is the number of chunks to divide the model into. ‘preserve_rng_state’ (bool, optional, default True) indicates whether to save the forward RNG state. If True, the RNG state from the forward pass is restored when the forward computation is recomputed during backpropagation.

  • functions (paddle.nn.Sequential) – Layer or sequence of layers describing part of the model's forward pass. Its intermediate activations are released during the forward stage to save memory and recomputed during the backward stage for gradient calculation.

  • *args (Tensor) – Positional inputs (tuple) to ‘functions’.

  • **kwargs (dict) – Keyword inputs (dict) to ‘functions’.
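
To see why ‘preserve_rng_state’ matters, consider a layer such as dropout whose output depends on random draws. The pure-Python sketch below is an analogy, not Paddle code. `dropout` and `forward_then_recompute` are hypothetical helpers, and Python's `random.Random` stands in for the framework's RNG. Snapshotting the generator state before the forward pass and restoring it before recomputation makes the recomputed mask match the original one. Without the restore, the recomputed activations, and therefore the gradients, would use a different mask.

```python
import random
from typing import List, Tuple


def dropout(xs: List[float], p: float, rng: random.Random) -> List[float]:
    """Inverted dropout on a list of floats, driven by an explicit RNG."""
    return [0.0 if rng.random() < p else x / (1.0 - p) for x in xs]


def forward_then_recompute(xs: List[float], p: float, preserve_rng_state: bool,
                           seed: int = 0) -> Tuple[List[float], List[float]]:
    """Run dropout once (forward), then again as a backward-time recomputation."""
    rng = random.Random(seed)
    saved_state = rng.getstate() if preserve_rng_state else None  # snapshot before forward
    forward_out = dropout(xs, p, rng)  # original forward pass

    # Later, during backward, the same segment is run again.
    if preserve_rng_state:
        rng.setstate(saved_state)  # rewind so the same mask is drawn
    recomputed_out = dropout(xs, p, rng)
    return forward_out, recomputed_out


original, recomputed = forward_then_recompute([1.0] * 8, p=0.5, preserve_rng_state=True)
print(original == recomputed)  # True
```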

Returns

The output of ‘functions’ applied to ‘args’ and ‘kwargs’.

Examples

>>> import paddle
>>> from paddle.incubate.distributed.fleet import recompute_sequential
>>> input = paddle.ones(shape=[8, 10])
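>>> model = paddle.nn.Sequential(paddle.nn.Linear(10, 10), paddle.nn.Linear(10, 2))
>>> # Arguments: the ctx dict, the Sequential model, then its positional inputs.
>>> # 'preserve_rng_state' is omitted, so it keeps its default (True).
>>> output = recompute_sequential({'segments': 1}, model, input)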