pca_lowrank
计算在稀疏矩阵上的线性主成分分析(PCA)。
记 \(X\) 为一个稀疏矩阵,输出结果满足:
         \[X = U * diag(S) * V^{T}\]
       
 
       参数
x (Tensor) - 输入的需要进行线性主成分分析的一个稀疏方阵,类型为 Tensor。
x的形状应为[M, N],数据类型支持 float32, float64。
q (int,可选) - 对输入 \(X\) 的秩稍微高估的预估值,默认值是 \(q=min(6,N,M)\)。
center (bool,可选) - 是否对输入矩阵进行中心化操作,类型为 bool ,默认为 True。
name (str,可选) - 具体用法请参见 api_guide_Name,一般无需设置,默认值为 None。
返回
Tensor U,形状为 N x q 的矩阵。
Tensor S,长度为 q 的向量。
Tensor V,形状为 M x q 的矩阵。
tuple (U, S, V): 对输入 \(X\) 的奇异值分解的近似最优解。
代码示例
>>> import paddle
>>> paddle.device.set_device('gpu')
>>> format = "coo"
>>> paddle.seed(2023)
>>> dense_x = paddle.randn((5, 5), dtype='float64')
>>> if format == "coo":
...     sparse_x = dense_x.to_sparse_coo(len(dense_x.shape))
>>> else:
...     sparse_x = dense_x.to_sparse_csr()
>>> print("sparse.pca_lowrank API only support CUDA 11.x")
>>> # U, S, V = None, None, None
>>> # use code blow when your device CUDA version >= 11.0
>>> U, S, V = paddle.sparse.pca_lowrank(sparse_x)
>>> print(U)
Tensor(shape=[5, 5], dtype=float64, place=Place(gpu:0), stop_gradient=True,
       [[-0.31412600,  0.44814876,  0.18390454, -0.19967630, -0.79170452],
        [-0.31412600,  0.44814876,  0.18390454, -0.58579808,  0.56877700],
        [-0.31412600,  0.44814876,  0.18390454,  0.78547437,  0.22292751],
        [-0.38082462,  0.10982129, -0.91810233,  0.00000000,  0.00000000],
        [ 0.74762770,  0.62082796, -0.23585052,  0.00000000, -0.00000000]])
>>> print(S)
Tensor(shape=[5], dtype=float64, place=Place(gpu:0), stop_gradient=True,
       [1.56031096, 1.12956227, 0.27922715, 0.00000000, 0.00000000])
>>> print(V)
Tensor(shape=[5, 5], dtype=float64, place=Place(gpu:0), stop_gradient=True,
       [[ 0.88568469, -0.29081908,  0.06163676,  0.19597228, -0.29796422],
        [-0.26169364, -0.27616183,  0.43148760, -0.42522796, -0.69874939],
        [ 0.28587685,  0.30695344, -0.47790836, -0.76982533, -0.05501437],
        [-0.23958121, -0.62770647, -0.71141770,  0.11463224, -0.17125926],
        [ 0.08918713, -0.59238761,  0.27478686, -0.41833534,  0.62498824]])