register_hook

paddle.Tensor.register_hook(self, hook)

Registers a backward hook for the current Tensor.

The hook will be called every time the gradient of the current Tensor is computed.

The hook should not modify the input gradient Tensor. It may optionally return a new gradient Tensor, which will then be used in place of the current Tensor's gradient; if it returns None, the original gradient is kept unchanged.

The hook should have the following signature:

hook(grad) -> Tensor or None
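For example, a gradient-clipping hook that matches this signature could look like the sketch below (clip_hook and its bounds are illustrative names, not part of the API; paddle.clip is the standard clipping op):

def clip_hook(grad):
    # returning a new Tensor replaces the gradient; returning None keeps it
    return paddle.clip(grad, min=-1.0, max=1.0)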

Parameters

hook (function) – A backward hook to be registered for the current Tensor's gradient.

Returns

A helper object that can be used to remove the registered hook by calling its remove() method.

Return type

TensorHookRemoveHelper

Examples

>>> import paddle

>>> # a hook function that returns None
>>> def print_hook_fn(grad):
...     print(grad)
...
>>> # a hook function that returns a new Tensor
>>> def double_hook_fn(grad):
...     grad = grad * 2
...     return grad
...
>>> x = paddle.to_tensor([0., 1., 2., 3.], stop_gradient=False)
>>> y = paddle.to_tensor([4., 5., 6., 7.], stop_gradient=False)
>>> z = paddle.to_tensor([1., 2., 3., 4.])

>>> # multiple hooks can be registered on one Tensor
>>> h = x.register_hook(print_hook_fn)
>>> _ = x.register_hook(double_hook_fn)  # assign to suppress the helper's repr

>>> w = x + y
>>> # register hook by lambda function
>>> _ = w.register_hook(lambda grad: grad * 2)

>>> o = z.matmul(w)
>>> # during backward, print_hook_fn prints the gradient reaching x
>>> o.backward()
Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=False,
[2., 4., 6., 8.])

>>> print("w.grad:", w.grad)
w.grad: None
>>> print("x.grad:", x.grad)
x.grad: Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=False,
[4. , 8. , 12., 16.])
>>> print("y.grad:", y.grad)
y.grad: Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=False,
[2., 4., 6., 8.])

>>> # remove hook
>>> h.remove()
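
>>> # once removed, the hook no longer fires in later backward passes;
>>> # a minimal sketch on a fresh Tensor (p and h2 are illustrative names)
>>> p = paddle.to_tensor([1., 2.], stop_gradient=False)
>>> h2 = p.register_hook(print_hook_fn)
>>> h2.remove()
>>> p.sum().backward()  # prints nothing: the hook was removed before backward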