recompute

paddle.distributed.fleet.utils.recompute(function, *args, **kwargs) [source]

Saves GPU memory by recomputing intermediate activations instead of storing them.

Parameters

  • function (paddle.nn.Layer) - A sequence composed of a contiguous part of the model's forward-pass layers. Their intermediate activations are freed during the forward pass to save GPU memory and are recomputed during backward gradient computation.

  • args (Tensor) - The inputs of function.

  • kwargs (Dict) - kwargs may contain only two kinds of key-value pairs. One kind holds the keyword arguments of function; the other may contain only the keys preserve_rng_state and use_reentrant. preserve_rng_state controls whether the forward RNG state is saved; if True, the RNG state of the original forward pass is restored when the forward pass is recomputed during backpropagation. The default value of preserve_rng_state is True. use_reentrant selects the implementation of recompute: if True, recompute is implemented with PyLayer; if False, it is implemented internally with hooks. The default value is True. In some scenarios, for example when recompute is combined with data parallelism, an extra call to no_sync is otherwise required; in that case you can set use_reentrant=False to select the hook-based recompute and avoid the extra no_sync call (see the sketch below).
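
Below is a minimal, hedged sketch of passing these control keys to recompute; the Linear layer, input shape, and variable names are illustrative assumptions, not part of the API:

>>> import paddle
>>> from paddle.distributed.fleet.utils import recompute
>>> # Hypothetical layer and input, for illustration only.
>>> layer = paddle.nn.Linear(8, 8)
>>> x = paddle.rand([2, 8])
>>> x.stop_gradient = False
>>> # Default: PyLayer-based recompute; the saved forward RNG state is
>>> # restored when the forward pass is replayed during backward.
>>> y = recompute(layer, x, preserve_rng_state=True)
>>> # Hook-based recompute; avoids the extra no_sync call when recompute
>>> # is combined with data parallelism.
>>> y = recompute(layer, x, use_reentrant=False)
>>> y.mean().backward()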

Returns

The output of applying function to the inputs.

Code Example

>>> import paddle
>>> from paddle.distributed.fleet.utils import recompute
>>> import random
>>> paddle.seed(2023)
>>> def get_fc_block(block_idx, input_size, is_last=False):
...     block_name = "block_" + str(block_idx)
...     block = paddle.nn.Sequential(
...         (block_name + "_fc_0", paddle.nn.Linear(input_size, input_size, bias_attr=False)),
...         (block_name + "_dropout", paddle.nn.Dropout(p=0.5)),
...         (block_name + "_relu_1", paddle.nn.ReLU()),
...         (block_name + "_fc_1", paddle.nn.Linear(input_size, input_size, bias_attr=False)),
...         (block_name + "_relu_2", paddle.nn.ReLU()),
...     )
...     if is_last:
...         block.add_sublayer(
...             block_name + "_fc_2",
...             paddle.nn.Linear(
...                 input_size, 1, bias_attr=False
...             )
...         )
...     else:
...         block.add_sublayer(
...             block_name + "_fc_2",
...             paddle.nn.Linear(input_size, input_size, bias_attr=False)
...         )
...     return block

>>> class Naive_fc_net(paddle.nn.Layer):
...     def __init__(self, input_size=10,
...                 recompute_blocks=[1, 3],
...                 recompute_kwargs={}):
...         super().__init__()
...         self.recompute_blocks = recompute_blocks
...         self.recompute_kwargs = recompute_kwargs
...         self.runfunc0 = get_fc_block(0, input_size, is_last=False)
...         self.runfunc1 = get_fc_block(1, input_size, is_last=False)
...         self.runfunc2 = get_fc_block(2, input_size, is_last=False)
...         self.runfunc3 = get_fc_block(3, input_size, is_last=False)
...         self.runfunc4 = get_fc_block(4, input_size, is_last=True)
...         self.total_func = [self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, self.runfunc4]
...     def forward(self, inputs):
...         nums = len(self.total_func)
...         for i in range(nums):
...             if i in self.recompute_blocks:
...                 inputs = recompute(self.total_func[i], inputs, **self.recompute_kwargs)
...             else:
...                 inputs = self.total_func[i](inputs)
...         return inputs

>>> def run_model(cuda_state, recompute_block=[], recompute_kwargs={}):
...     gen = paddle.seed(10)
...     gen.manual_seed(10)
...     random.seed(10)
...     if cuda_state:
...         paddle.set_cuda_rng_state(cuda_state)
...     batch_size, input_size = 1, 10
...     model = Naive_fc_net(
...         input_size,
...         recompute_blocks=recompute_block,
...         recompute_kwargs=recompute_kwargs)
...     optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
...     loss_ = []
...     param_ = []
...     grad_ = []
...     for _ in range(5):
...         x = paddle.rand(shape=[batch_size, input_size], dtype="float32")
...         y_pred = model(x)
...         loss = y_pred.mean()
...         loss_.append(loss.item())
...         loss.backward()
...         optimizer.step()
...         param_.append(model.parameters()[9])
...         grad_.append(model.parameters()[3]._grad_ivar())
...         optimizer.clear_grad()
...     return loss_, param_, grad_

>>> cuda_state = paddle.get_cuda_rng_state()
>>> # without recompute
>>> loss_ref, param_ref, grad_ref = run_model(
...     cuda_state, recompute_block=[]
... )

>>> loss, param, grad = run_model(cuda_state, recompute_block=[1, 2])
>>> # The recompute_loss should be the same as the normal_loss.
>>> print("normal_loss: {}, recompute_loss: {}".format(loss_ref, loss))
normal_loss: [0.0018744759727269411, 0.0, 0.035971127450466156, 0.0, 0.0], recompute_loss: [0.0018744759727269411, 0.0, 0.035971127450466156, 0.0, 0.0]

Tutorials that use this API