tensordot
Tensor 缩并运算(Tensor Contraction),即沿着 axes 给定的多个轴对两个 Tensor 对应元素的乘积进行加和操作。
图解说明
可以选择沿一个或多个轴进行点积操作,操作后返回的结果张量维度是 a 和 b 上未参与点积的维度的并集。 图例中展示了一个 shape = [2,2,2]的 a 张量,和 shape = [2,3]的 b 张量。 shape = [2,2,3]的 res 张量为 a,b 两个张量沿着 a 张量的最后一个轴和 b 张量的第一个轴进行缩并的结果(即参数 axes = 1 的情况)
 
 
       参数
x (Tensor)- 缩并运算操作的左 Tensor,数据类型为
float16或float32或float64。
y (Tensor)- 缩并运算操作的右 Tensor,与
x具有相同的数据类型。
axes (int|tuple|list|Tensor)- 指定对
x和y做缩并运算的轴,默认值为整数 2。
axes可以是一个非负整数。若输入的是一个整数n,则表示对x的后n个轴和对y的前n个轴进行缩并运算。
axes可以是一个一维的整数 tuple 或 list,表示x和y沿着相同的轴方向进行缩并运算。例如,axes=[0, 1]表示x的前两个轴和y的前两个轴对应进行缩并运算。
axes可以是一个 tuple 或 list,其中包含一个或两个一维的整数 tuple|list|Tensor。如果axes包含一个 tuple|list|Tensor,则对x和y的相同轴做缩并运算,具体轴下标由该 tuple|list|Tensor 中的整数值指定。如果axes包含两个 tuple|list|Tensor,则第一个指定x做缩并运算的轴下标,第二个指定y的对应轴下标。如果axes包含两个以上的 tuple|list|Tensor,只有前两个会被作为轴下标序列使用,其它的将被忽略。
axes可以是一个 Tensor,这种情况下该 Tensor 会被转换成 list,然后应用前述规则确定做缩并运算的轴。请注意,输入 Tensor 类型的axes只在动态图模式下可用。
name (str,可选) - 具体用法请参见 api_guide_Name,一般无需设置,默认值为 None。
返回
一个 Tensor,表示 Tensor 缩并的结果,数据类型与 x 和 y 相同。一般情况下,有 \(output.ndim = x.ndim + y.ndim - 2 \times n_{axes}\),其中 \(n_{axes}\) 表示做 Tensor 缩并的轴数量。
备注
- 本 API 支持 Tensor 维度广播, - x和- y做缩并操作的对应维度 size 必须相等,或适用于广播规则。
- 本 API 支持 axes 扩展,当指定的 - x和- y两个轴序列长短不一时,短的序列会自动在末尾补充和长序列相同的轴下标。例如,如果输入- axes=[[0, 1, 2, 3], [1, 0]],则指定- x的轴序列是[0, 1, 2, 3],对应- y的轴序列会自动从[1,0]扩展成[1, 0, 2, 3]。
代码示例
>>> import paddle
>>> from typing import Literal
>>> data_type: Literal["float64"] = 'float64'
>>> # For two 2-d tensor x and y, the case axes=0 is equivalent to outer product.
>>> # Note that tensordot supports empty axis sequence, so all the axes=0, axes=[], axes=[[]], and axes=[[],[]] are equivalent cases.
>>> x = paddle.arange(4, dtype=data_type).reshape([2, 2])
>>> y = paddle.arange(4, dtype=data_type).reshape([2, 2])
>>> z = paddle.tensordot(x, y, axes=0)
>>> print(z)
Tensor(shape=[2, 2, 2, 2], dtype=float64, place=Place(cpu), stop_gradient=True,
 [[[[0., 0.],
    [0., 0.]],
   [[0., 1.],
    [2., 3.]]],
  [[[0., 2.],
    [4., 6.]],
   [[0., 3.],
    [6., 9.]]]])
>>> # For two 1-d tensor x and y, the case axes=1 is equivalent to inner product.
>>> x = paddle.arange(10, dtype=data_type)
>>> y = paddle.arange(10, dtype=data_type)
>>> z1 = paddle.tensordot(x, y, axes=1)
>>> z2 = paddle.dot(x, y)
>>> print(z1)
Tensor(shape=[], dtype=float64, place=Place(cpu), stop_gradient=True,
285.)
>>> print(z2)
Tensor(shape=[], dtype=float64, place=Place(cpu), stop_gradient=True,
285.)
>>> # For two 2-d tensor x and y, the case axes=1 is equivalent to matrix multiplication.
>>> x = paddle.arange(6, dtype=data_type).reshape([2, 3])
>>> y = paddle.arange(12, dtype=data_type).reshape([3, 4])
>>> z1 = paddle.tensordot(x, y, axes=1)
>>> z2 = paddle.matmul(x, y)
>>> print(z1)
Tensor(shape=[2, 4], dtype=float64, place=Place(cpu), stop_gradient=True,
[[20., 23., 26., 29.],
 [56., 68., 80., 92.]])
>>> print(z2)
Tensor(shape=[2, 4], dtype=float64, place=Place(cpu), stop_gradient=True,
[[20., 23., 26., 29.],
 [56., 68., 80., 92.]])
>>> # When axes is a 1-d int list, x and y will be contracted along the same given axes.
>>> # Note that axes=[1, 2] is equivalent to axes=[[1, 2]], axes=[[1, 2], []], axes=[[1, 2], [1]], and axes=[[1, 2], [1, 2]].
>>> x = paddle.arange(24, dtype=data_type).reshape([2, 3, 4])
>>> y = paddle.arange(36, dtype=data_type).reshape([3, 3, 4])
>>> z = paddle.tensordot(x, y, axes=[1, 2])
>>> print(z)
Tensor(shape=[2, 3], dtype=float64, place=Place(cpu), stop_gradient=True,
[[506. , 1298., 2090.],
 [1298., 3818., 6338.]])
>>> # When axes is a list containing two 1-d int list, the first will be applied to x and the second to y.
>>> x = paddle.arange(60, dtype=data_type).reshape([3, 4, 5])
>>> y = paddle.arange(24, dtype=data_type).reshape([4, 3, 2])
>>> z = paddle.tensordot(x, y, axes=([1, 0], [0, 1]))
>>> print(z)
Tensor(shape=[5, 2], dtype=float64, place=Place(cpu), stop_gradient=True,
[[4400., 4730.],
 [4532., 4874.],
 [4664., 5018.],
 [4796., 5162.],
 [4928., 5306.]])
>>> # Thanks to the support of axes expansion, axes=[[0, 1, 3, 4], [1, 0, 3, 4]] can be abbreviated as axes= [[0, 1, 3, 4], [1, 0]].
>>> x = paddle.arange(720, dtype=data_type).reshape([2, 3, 4, 5, 6])
>>> y = paddle.arange(720, dtype=data_type).reshape([3, 2, 4, 5, 6])
>>> z = paddle.tensordot(x, y, axes=[[0, 1, 3, 4], [1, 0]])
>>> print(z)
Tensor(shape=[4, 4], dtype=float64, place=Place(cpu), stop_gradient=True,
[[23217330., 24915630., 26613930., 28312230.],
 [24915630., 26775930., 28636230., 30496530.],
 [26613930., 28636230., 30658530., 32680830.],
 [28312230., 30496530., 32680830., 34865130.]])