IndependentTransform

class paddle.distribution. IndependentTransform ( base, reinterpreted_batch_rank ) [源代码]

IndependentTransform 将一个基础变换 base 的部分批(batch)维度 reinterpreted_batch_rank 扩展为事件(event)维度。

IndependentTransform 不改变基础变换 forward 以及 inverse 计算结果,但会对 forward_log_det_jacobian 以及 inverse_log_det_jacobian 计算结果沿着扩展的维度进行求和。

例如,假设基础变换为 ExpTransform,其输入为一个随机采样结果 x,形状为 \((S=[4], B=[2,2], E=[3])\) , \(S\)\(B\)\(E\) 分别表示采样形状、批形状、事件形状,reinterpreted_batch_rank=1。则 IndependentTransform(ExpTransform) 变换后,x 的形状为 \((S=[4], B=[2], E=[2,3])\),即将最右侧的批维度作为事件维度。此时 forwardinverse 输出形状仍是 \([4, 2, 2, 3]\),但 forward_log_det_jacobian 以及 inverse_log_det_jacobian 输出形状为 \([4, 2]\)

参数

  • base (Transform) - 基础变换。

  • reinterpreted_batch_rank (int) - 被扩展为事件维度的最右侧批维度数量,需大于 0。

代码示例

>>> import paddle

>>> x = paddle.to_tensor([[1., 2., 3.], [4., 5., 6.]])

>>> # Exponential transform with event_rank = 1
>>> multi_exp = paddle.distribution.IndependentTransform(
...     paddle.distribution.ExpTransform(), 1)
>>> print(multi_exp.forward(x))
Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
        [[2.71828175  , 7.38905621  , 20.08553696 ],
         [54.59814835 , 148.41316223, 403.42880249]])
>>> print(multi_exp.forward_log_det_jacobian(x))
Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
        [6. , 15.])

方法

forward(x)

计算正变换 \(y=f(x)\) 的结果。

参数

  • x (Tensor) - 正变换输入参数,通常为 Distribution 的随机采样结果。

返回

Tensor,正变换的计算结果。

inverse(y)

计算逆变换 \(x = f^{-1}(y)\)

参数

  • y (Tensor) - 逆变换的输入参数。

返回

Tensor,逆变换的计算结果。

forward_log_det_jacobian(x)

计算正变换雅可比行列式绝对值的对数。

如果变换不是一一映射,则雅可比矩阵不存在,返回 NotImplementedError

参数

  • x (Tensor) - 输入参数。

返回

Tensor,正变换雅可比行列式绝对值的对数。

inverse_log_det_jacobian(y)

计算逆变换雅可比行列式绝对值的对数。

forward_log_det_jacobian 互为负数。

参数

  • y (Tensor) - 输入参数。

返回

Tensor,逆变换雅可比行列式绝对值的对数。

forward_shape(shape)

推断正变换输出形状。

参数

  • shape (Sequence[int]) - 正变换输入的形状。

返回

Sequence[int],正变换输出的形状。

inverse_shape(shape)

推断逆变换输出形状。

参数

  • shape (Sequence[int]) - 逆变换输入的形状。

返回

Sequence[int],逆变换输出的形状。