static_pylayer¶
该 API 返回 forward_fn(inputs),并且根据传入的 forward_fn 和 backward_fn 的执行逻辑创建两个 sub_block, 同时创建 pylayer 算子,pylayer 算子的属性储存创建的 sub_block ID。
forward_fn 和 backward_fn 需要返回同样嵌套结构(nest structure)的 Tensor。 PaddlePaddle 里 Tensor 的嵌套结构是指一个 Tensor,或者 Tensor 的元组(tuple),或者 Tensor 的列表(list)。
注解
- 如果 - backward_fn不为 None,用户需要使- forward_fn的输入 Tensor 的数量和- backward_fn的输出 Tensor 的数量相同,- forward_fn的输出 Tensor 的数量和- backward_fn的输入 Tensor 的数量相同。
- 在 - backward_fn被设置为- None的情况下,- inputs里所有 Variable 的- stop_gradient属性应该被设为- True,否则可能会在反向传播(backward propagation)中得到意想不到的结果。
- 本 API 只能被运行在静态图模式下。 
参数¶
forward_fn (callable) - 一个前向传播(forward propagation)时被调用的 callable。
inputs (list[Variable]) - Variable 类型列表,其含义为
forward_fn的输入 Variable。
backward_fn (callable,可选) - 一个反向传播(backward propagation)时被调用的 callable。默认值:
None,表示不需要进行反向传播(backward propagation)。
name (str,可选) - 具体用法请参见 Name,一般无需设置,默认值:
None。
返回¶
Variable|list(Variable)|tuple(Variable),该 API 返回 forward_fn(inputs)。
代码示例¶
>>> import paddle
>>> import numpy as np
>>> paddle.enable_static()
>>> def forward_fn(x):
...     return paddle.exp(x)
>>> def backward_fn(dy):
...     return 2 * paddle.exp(dy)
>>> main_program = paddle.static.Program()
>>> start_program = paddle.static.Program()
>>> place = paddle.CPUPlace()
>>> exe = paddle.static.Executor(place)
>>> with paddle.static.program_guard(main_program, start_program):
...     data = paddle.static.data(name="X", shape=[None, 5], dtype="float32")
...     data.stop_gradient = False
...     ret = paddle.static.nn.static_pylayer(forward_fn, [data], backward_fn)
...     data_grad = paddle.static.gradients([ret], data)[0]
>>> exe.run(start_program)
>>> x = np.array([[1.0, 2.0, 3.0, 4.0, 5.0]], dtype=np.float32)
>>> x, x_grad, y = exe.run(
...     main_program,
...     feed={"X": x},
...     fetch_list=[
...         data.name,
...         data_grad.name,
...         ret.name
...     ],
... )
>>> print(x)
[[1. 2. 3. 4. 5.]]
>>> print(x_grad)
[[5.4365635 5.4365635 5.4365635 5.4365635 5.4365635]]
>>> print(y)
[[  2.7182817   7.389056   20.085537   54.59815   148.41316  ]]