saved_tensors_hooks

class paddle.autograd.saved_tensors_hooks(pack_hook, unpack_hook) [source]

Dynamic graph only. Registers a pair of pack / unpack hooks for tensors saved for backward.

Parameters
  • pack_hook (function) – The pack hook will be called every time a tensor from the forward inputs/outputs needs to be saved for backward; you can then move it to CPU memory or disk. The input of pack_hook is the tensor to be saved. The output of pack_hook is the information that is stored in place of the original tensor. pack_hook will also be called whenever a tensor is saved by PyLayerContext.save_for_backward. If a tensor saved for backward does not need its buffer, pack_hook will not be called. pack_hook is called only when the tensor saved for backward is a LoDTensor.

  • unpack_hook (function) – The unpack hook will be called every time the backward pass needs to use the saved inputs/outputs tensors; you can then reload the tensor and return it to the paddle framework. The input of unpack_hook is the information returned by pack_hook. The output of unpack_hook is a tensor reconstructed from that information, and it must have the same content as the original tensor passed to the corresponding pack_hook.

Returns

None

Examples


>>> # Example1
>>> import paddle
>>> def pack_hook(x):
...     print("Packing", x)
...     return x.numpy()
>>> def unpack_hook(x):
...     print("UnPacking", x)
...     return paddle.to_tensor(x)
>>> a = paddle.ones([3,3])
>>> b = paddle.ones([3,3]) * 2
>>> a.stop_gradient = False
>>> b.stop_gradient = False
>>> with paddle.autograd.saved_tensors_hooks(pack_hook, unpack_hook):
...     y = paddle.multiply(a, b)
>>> y.sum().backward()
>>> # Example2
>>> import paddle
>>> from paddle.autograd import PyLayer
>>> class cus_multiply(PyLayer):
...     @staticmethod
...     def forward(ctx, a, b):
...         y = paddle.multiply(a, b)
...         ctx.save_for_backward(a, b)
...         return y
...
...     @staticmethod
...     def backward(ctx, dy):
...         a, b = ctx.saved_tensor()
...         # gradients of y = a * b: dy/da = b, dy/db = a
...         grad_a = dy * b
...         grad_b = dy * a
...         return grad_a, grad_b
>>> def pack_hook(x):
...     print("Packing", x)
...     return x.numpy()
>>> def unpack_hook(x):
...     print("UnPacking", x)
...     return paddle.to_tensor(x)
>>> a = paddle.ones([3,3])
>>> b = paddle.ones([3,3]) * 2
>>> a.stop_gradient = False
>>> b.stop_gradient = False
>>> with paddle.autograd.saved_tensors_hooks(pack_hook, unpack_hook):
...     y = cus_multiply.apply(a, b)
>>> y.sum().backward()
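
As the pack_hook description notes, the stored information can live somewhere cheaper than device memory, for example on disk. The following sketch is not part of the upstream examples: the pack_to_disk / unpack_from_disk helpers, the temporary directory, and the file naming scheme are purely illustrative. It saves each tensor as a .npy file during forward and reloads it during backward.

>>> # Example3 (illustrative sketch)
>>> import os
>>> import tempfile
>>> import numpy as np
>>> import paddle
>>> tmp_dir = tempfile.mkdtemp()
>>> counter = [0]
>>> def pack_to_disk(x):
...     # hypothetical helper: write the saved tensor to disk and return its path
...     path = os.path.join(tmp_dir, "saved_{}.npy".format(counter[0]))
...     counter[0] += 1
...     np.save(path, x.numpy())
...     return path
>>> def unpack_from_disk(path):
...     # reload from disk; the content must match the original saved tensor
...     return paddle.to_tensor(np.load(path))
>>> a = paddle.ones([3, 3])
>>> b = paddle.ones([3, 3]) * 2
>>> a.stop_gradient = False
>>> b.stop_gradient = False
>>> with paddle.autograd.saved_tensors_hooks(pack_to_disk, unpack_from_disk):
...     y = paddle.multiply(a, b)
>>> y.sum().backward()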