[ 组合替代实现 ]torch.nn.modules.module.register_module_forward_hook¶
torch.nn.modules.module.register_module_forward_hook¶
torch.nn.modules.module.register_module_forward_hook(hook, *, always_call=False)
paddle.nn.Layer.register_forward_post_hook¶
paddle.nn.Layer.register_forward_post_hook(hook)
其中,PyTorch 为给全局所有 module 注册 hook,而 Paddle 为给单个 Layer 注册 hook。PyTorch 相比 Paddle 支持更多其他参数,具体如下:
参数映射¶
PyTorch | PaddlePaddle | 备注 |
---|---|---|
hook | hook | 被注册为 forward post-hook 的函数。 |
always_call | - | 是否强制调用钩子,Paddle 无此参数,一般对训练结果影响不大,可直接删除。 |
转写示例¶
# PyTorch 写法
Linear = torch.nn.Linear(2, 4)
Conv2d = torch.nn.Conv2d(3, 16, 3)
Batch2d = torch.nn.BatchNorm2d(10)
torch.nn.modules.module.register_module_forward_hook(hook)
# Paddle 写法
Linear = paddle.nn.Linear(2, 4)
Conv2d = paddle.nn.Conv2d(3, 16, 3)
Batch2d = paddle.nn.BatchNorm2D(10)
Linear.register_forward_post_hook(hook)
Conv2d.register_forward_post_hook(hook)
Batch2d.register_forward_post_hook(hook)