PrepareLayerInput¶
使用用户提供的函数,对标记 Layer 的输入进行处理。
参数¶
fn (callable,可选) - 用来处理标记 Layer 输入的函数,该函数需要接受并且仅接受一个参数 process_mesh ,并返回真正用来处理输入的函数。默认为 None。
代码示例
>>> import paddle
>>> import paddle.distributed as dist
>>> class MLP(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.fc1 = paddle.nn.Linear(8, 8)
...         self.fc2 = paddle.nn.Linear(8, 8)
...
...     def forward(self, input):
...         return self.fc2(self.fc1(input))
>>> def layer_input_hook(process_mesh):
...     def hook(layer, input, output):
...         return input
...     return hook
>>> layer = MLP()
>>> mp_config = {
...     'fc1': dist.PrepareLayerOutput(layer_input_hook)
... }