jvp¶
计算函数 func 在 xs 处的雅可比矩阵与向量 v 的乘积。
警告
该 API 目前为 Beta 版本,函数签名在未来版本可能发生变化。
参数¶
- func (Callable) - Python 函数,输入参数为 - xs,输出为 Tensor 或 Tensor 序列。
- xs (Tensor|Sequence[Tensor]) - 函数 - func的输入参数,数据类型为 Tensor 或 Tensor 序列。
- v (Tensor|Sequence[Tensor]|None,可选) - 用于计算 - jvp的输入向量,形状要求 与- xs一致。默认值为- None,即相当于形状与- xs一致,值全为 1 的 Tensor 或 Tensor 序列。
返回¶
- func_out (Tensor|tuple[Tensor]) - 函数 - func(xs)的输出。
- jvp (Tensor|tuple[Tensor]) - - jvp计算结果。
代码示例¶
>>> import paddle
>>> def func(x):
...     return paddle.matmul(x, x)
...
>>> x = paddle.ones(shape=[2, 2], dtype='float32')
>>> _, jvp_result = paddle.incubate.autograd.jvp(func, x)
>>> print(jvp_result)
Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=False,
       [[4., 4.],
        [4., 4.]])
>>> v = paddle.to_tensor([[1.0, 0.0], [0.0, 0.0]])
>>> _, jvp_result = paddle.incubate.autograd.jvp(func, x, v)
>>> print(jvp_result)
Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=False,
       [[2., 1.],
        [1., 0.]])