[ torch 参数更多 ]torch.utils.checkpoint.checkpoint
torch.utils.checkpoint.checkpoint
torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=None, context_fn=<function noop_context_fn>, determinism_check='default', debug=False, preserve_rng_state=True, **kwargs)
paddle.distributed.fleet.utils.recompute
paddle.distributed.fleet.utils.recompute(function, preserve_rng_state=True, use_reentrant=True, *args, **kwargs)
PyTorch 相比 Paddle 支持更多其他参数,具体如下:
参数映射
| PyTorch | PaddlePaddle | 备注 |
|---|---|---|
| function | function | 模型前向传播的部分连续的层函数组成的序列。 |
| *args | *args | function 的输入。 |
| use_reentrant | use_reentrant | recompute 的实现方式。 |
| context_fn | - | 控制梯度检查点的执行上下文, 一般对训练结果影响不大,可直接删除。 |
| determinism_check | - | 控制是否在反向传播时检查操作的确定性, 一般对训练结果影响不大,可直接删除。 |
| debug | - | 是否启用调试模式, 一般对训练结果影响不大,可直接删除。 |
| preserve_rng_state | preserve_rng_state | 是否保存前向的 rng。 |
| **kwargs | **kwargs | 用于指定 Extension 的其他参数,支持的参数与 setuptools.Extension 一致。 |