static_pylayer

paddle.static.nn. static_pylayer ( forward_fn, inputs, backward_fn=None, name=None ) [源代码]

该 API 返回 forward_fn(inputs),并且根据传入的 forward_fnbackward_fn 的执行逻辑创建两个 sub_block, 同时创建 pylayer 算子,pylayer 算子的属性储存创建的 sub_block ID。

forward_fnbackward_fn 需要返回同样嵌套结构(nest structure)的 Tensor。 PaddlePaddle 里 Tensor 的嵌套结构是指一个 Tensor,或者 Tensor 的元组(tuple),或者 Tensor 的列表(list)。

注解

  1. 如果 backward_fn 被设置为 None,用户需要使 forward_fn 的输入数量和 backward_fn 的输出数量相同,forward_fn 的输出数量和 backward_fn 的输入数量相同。

  2. backward_fn 被设置为 None 的情况下,inputs 里所有 Variable 的 stop_gradient 属性应该被设为 True,否则可能会在反向传播(backward propagation)中得到意想不到的结果。

  3. 本 API 只能被运行在静态图模式下。

参数

  • forward_fn (callable) - 一个前向传播(forward propagation)时被调用的 callable。

  • inputs (list[Variable]) - Variable 类型列表,其含义为 forward_fn 的输入 Variable。

  • backward_fn (callable,可选) - 一个反向传播(backward propagation)时被调用的 callable。默认值:None,表示不需要进行反向传播(backward propagation)。

  • name (str,可选) - 具体用法请参见 Name,一般无需设置,默认值:None

返回

Variable|list(Variable)|tuple(Variable),该 API 返回 forward_fn(inputs)

代码示例

>>> import paddle
>>> import numpy as np

>>> paddle.enable_static()

>>> def forward_fn(x):
...     return paddle.exp(x)

>>> def backward_fn(dy):
...     return 2 * paddle.exp(dy)

>>> main_program = paddle.static.Program()
>>> start_program = paddle.static.Program()

>>> place = paddle.CPUPlace()
>>> exe = paddle.static.Executor(place)
>>> with paddle.static.program_guard(main_program, start_program):
...     data = paddle.static.data(name="X", shape=[None, 5], dtype="float32")
...     data.stop_gradient = False
...     ret = paddle.static.nn.static_pylayer(forward_fn, [data], backward_fn)
...     data_grad = paddle.static.gradients([ret], data)[0]

>>> exe.run(start_program)
>>> x = np.array([[1.0, 2.0, 3.0, 4.0, 5.0]], dtype=np.float32)
>>> x, x_grad, y = exe.run(
...     main_program,
...     feed={"X": x},
...     fetch_list=[
...         data.name,
...         data_grad.name,
...         ret.name
...     ],
... )

>>> print(x)
[[1. 2. 3. 4. 5.]]
>>> print(x_grad)
[[5.4365635 5.4365635 5.4365635 5.4365635 5.4365635]]
>>> print(y)
[[  2.7182817   7.389056   20.085537   54.59815   148.41316  ]]