py_func

paddle.static. py_func ( func, x, out, backward_func=None, skip_vars_in_backward_input=None ) [源代码]

PaddlePaddle 通过 py_func 在 Python 端注册 OP。py_func 的设计原理在于 Paddle 中的 Tensor 与 numpy 数组可以方便的互相转换,从而可使用 Python 中的 numpy API 来自定义一个 Python OP。

该自定义的 Python OP 的前向函数是 func,反向函数是 backward_func 。 Paddle 将在前向部分调用 func,并在反向部分调用 backward_func (如果 backward_func 不是 None)。 xfunc 的输入,必须为 Tensor 类型;outfunc 的输出,既可以是 Tensor 类型,也可以是 numpy 数组。

反向函数 backward_func 的输入依次为:前向输入 x 、前向输出 outout 的梯度。如果 out 的某些输出没有梯度,则 backward_func 的相关输入为 None。如果 x 的某些变量没有梯度,则用户应在 backward_func 中主动返回 None。

在调用该接口之前,还应正确设置 out 的数据类型和形状,而 outx 对应梯度的数据类型和形状将自动推断而出。

此功能还可用于调试正在运行的网络,可以通过添加没有输出的 py_func 运算,并在 func 中打印输入 x

参数

  • func (callable) - 所注册的 Python OP 的前向函数,运行网络时,将根据该函数与前向输入 x,计算前向输出 out。在 func 建议先主动将 Tensor 转换为 numpy 数组,方便灵活的使用 numpy 相关的操作,如果未转换成 numpy,则可能某些操作无法兼容。

  • x (Tensor|tuple(Tensor)|list[Tensor]) - 前向函数 func 的输入,多个 Tensor 以 tuple(Tensor)或 list[Tensor]的形式传入。

  • out (T|tuple(T)|list[T]) - 前向函数 func 的输出,可以为 T|tuple(T)|list[T],其中 T 既可以为 Tensor,也可以为 numpy 数组。由于 Paddle 无法自动推断 out 的形状和数据类型,必须应事先创建 out

  • backward_func (callable,可选) - 所注册的 Python OP 的反向函数。默认值为 None,意味着没有反向计算。若不为 None,则会在运行网络反向时调用 backward_func 计算 x 的梯度。

  • skip_vars_in_backward_input (Tensor,可选) - backward_func 的输入中不需要的变量,可以是 Tensor|tuple(Tensor)|list[Tensor]。这些变量必须是 xout 中的一个。默认值为 None,意味着没有变量需要从 xout 中去除。若不为 None,则这些变量将不是 backward_func 的输入。该参数仅在 backward_func 不为 None 时有用。

返回

Tensor|tuple(Tensor)|list[Tensor],前向函数的输出 out

代码示例 1

 >>> import paddle
 >>> import numpy as np

 >>> np.random.seed(1107)
 >>> paddle.seed(1107)

 >>> paddle.enable_static()
 >>> # Creates a forward function, Tensor can be input directly without
 >>> # being converted into numpy array.
 >>> def tanh(x):
 ...     return np.tanh(x)

 >>> # Skip x in backward function and return the gradient of x
 >>> # Tensor must be actively converted to numpy array, otherwise,
 >>> # operations such as +/- can't be used.
 >>> def tanh_grad(y, dy):
 ...     return np.array(dy) * (1 - np.square(np.array(y)))

 >>> # Creates a forward function for debugging running networks(print value)
 >>> def debug_func(x):
 ...     # print(x)
 ...     pass
 >>> def create_tmp_var(name, dtype, shape):
 ...     return paddle.static.default_main_program().current_block().create_var(
 ...         name=name, dtype=dtype, shape=shape)
 >>> def simple_net(img, label):
 ...     hidden = img
 ...     for idx in range(4):
 ...         hidden = paddle.static.nn.fc(hidden, size=200)
 ...         new_hidden = create_tmp_var(name='hidden_{}'.format(idx),
 ...             dtype=hidden.dtype, shape=hidden.shape)
 ...         # User-defined forward and backward
 ...         hidden = paddle.static.py_func(func=tanh, x=hidden,
 ...             out=new_hidden, backward_func=tanh_grad,
 ...             skip_vars_in_backward_input=hidden)
 ...         # User-defined debug functions that print out the input Tensor
 ...         paddle.static.py_func(func=debug_func, x=hidden, out=None)
 ...     prediction = paddle.static.nn.fc(hidden, size=10, activation='softmax')
 ...     ce_loss = paddle.nn.loss.CrossEntropyLoss()
 ...     return ce_loss(prediction, label)
 >>> x = paddle.static.data(name='x', shape=[1,4], dtype='float32')
 >>> y = paddle.static.data(name='y', shape=[1], dtype='int64')
 >>> res = simple_net(x, y)
 >>> exe = paddle.static.Executor(paddle.CPUPlace())
 >>> exe.run(paddle.static.default_startup_program())
 >>> input1 = np.random.random(size=[1,4]).astype('float32')
 >>> input2 = np.random.randint(1, 10, size=[1], dtype='int64')
 >>> out = exe.run(paddle.static.default_main_program(),
 ...                 feed={'x':input1, 'y':input2},
 ...                 fetch_list=[res.name])
 >>> print(out[0].shape)
 ()

代码示例 2

 >>> # This example shows how to turn Tensor into numpy array and
 >>> # use numpy API to register an Python OP
 >>> import paddle
 >>> import numpy as np

 >>> np.random.seed(1107)
 >>> paddle.seed(1107)

 >>> paddle.enable_static()
 >>> def element_wise_add(x, y):
 ...     # Tensor must be actively converted to numpy array, otherwise,
 ...     # numpy.shape can't be used.
 ...     x = np.array(x)
 ...     y = np.array(y)
 ...     if x.shape != y.shape:
 ...         raise AssertionError("the shape of inputs must be the same!")
 ...     result = np.zeros(x.shape, dtype='int32')
 ...     for i in range(len(x)):
 ...         for j in range(len(x[0])):
 ...             result[i][j] = x[i][j] + y[i][j]
 ...     return result
 >>> def create_tmp_var(name, dtype, shape):
 ...     return paddle.static.default_main_program().current_block().create_var(
 ...                 name=name, dtype=dtype, shape=shape)
 >>> def py_func_demo():
 ...     start_program = paddle.static.default_startup_program()
 ...     main_program = paddle.static.default_main_program()
 ...     # Input of the forward function
 ...     x = paddle.static.data(name='x', shape=[2, 3], dtype='int32')
 ...     y = paddle.static.data(name='y', shape=[2, 3], dtype='int32')
 ...     # Output of the forward function, name/dtype/shape must be specified
 ...     output = create_tmp_var('output','int32', [3, 1])
 ...     # Multiple Tensor should be passed in the form of tuple(Tensor) or list[Tensor]
 ...     paddle.static.py_func(func=element_wise_add, x=[x, y], out=output)
 ...     exe=paddle.static.Executor(paddle.CPUPlace())
 ...     exe.run(start_program)
 ...     # Feed numpy array to main_program
 ...     input1 = np.random.randint(1, 10, size=[2, 3], dtype='int32')
 ...     input2 = np.random.randint(1, 10, size=[2, 3], dtype='int32')
 ...     out = exe.run(main_program,
 ...                feed={'x':input1, 'y':input2},
 ...                fetch_list=[output.name])
 ...     print("{0} + {1} = {2}".format(input1, input2, out))
 >>> py_func_demo()
 >>> # [[1 5 4]   + [[3 7 7]  =  [array([[ 4, 12, 11]
 >>> #  [9 4 8]]     [2 3 9]]            [11,  7, 17]], dtype=int32)]