triu

paddle.triu(x, diagonal=0, name=None)

Return the upper triangular part of a matrix (2-D tensor) or of a batch of matrices x; all other elements of the result tensor are set to 0. The upper triangular part of a matrix consists of the elements on and above the diagonal.
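The rule above can be sketched in plain Python on nested lists. An element at index (i, j) is kept when j - i >= diagonal and replaced by 0 otherwise. This is an illustrative sketch assuming a 2-D input, not Paddle's implementation, and `triu_reference` is a hypothetical helper name.

```python
def triu_reference(matrix, diagonal=0):
    """Hypothetical helper: upper triangular part of a 2-D nested list.

    Element (i, j) is kept when j - i >= diagonal; every other element
    is replaced by 0, mirroring the rule described for paddle.triu.
    """
    return [
        [value if j - i >= diagonal else 0 for j, value in enumerate(row)]
        for i, row in enumerate(matrix)
    ]


x = [[1, 2, 3, 4],
     [5, 6, 7, 8],
     [9, 10, 11, 12]]
print(triu_reference(x))
# [[1, 2, 3, 4], [0, 6, 7, 8], [0, 0, 11, 12]]
```

The outputs match the paddle.triu examples on the same 3 x 4 input.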

Parameters
  • x (Tensor) – The input tensor. Supported data types: float64, float32, int32, int64, complex64, complex128.

  • diagonal (int, optional) – The diagonal to consider. Default: 0. If diagonal = 0, all elements on and above the main diagonal are retained. A positive value k also zeroes out the main diagonal and the k - 1 diagonals directly above it, and a negative value k also retains the |k| diagonals directly below the main diagonal. In general, the element at index \((i, j)\) is retained if and only if \(j - i \geq\) diagonal. The main diagonal is the set of indices \(\{(i, i)\}\) for \(i \in [0, \min\{d_{1}, d_{2}\} - 1]\), where \(d_{1}, d_{2}\) are the dimensions of the matrix.

  • name (str, optional) – For details, please refer to Name. Generally, no setting is required. Default: None.
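To make the diagonal offset concrete, the sketch below lists the indices that form the k-th diagonal of a d1 x d2 matrix, that is, the pairs (i, i + k) that fall inside the matrix. With k = 0 this is the main diagonal \(\{(i, i)\}\) for \(i \in [0, \min\{d_{1}, d_{2}\} - 1]\). Passing diagonal = k to triu keeps this diagonal and every diagonal above it. `diagonal_indices` is a hypothetical helper for illustration only.

```python
def diagonal_indices(d1, d2, k=0):
    """Hypothetical helper: (row, col) pairs on the k-th diagonal of a d1 x d2 matrix.

    k = 0 is the main diagonal, k > 0 lies above it, k < 0 lies below it.
    """
    return [(i, i + k) for i in range(d1) if 0 <= i + k < d2]


print(diagonal_indices(3, 4))      # main diagonal: [(0, 0), (1, 1), (2, 2)]
print(diagonal_indices(3, 4, 2))   # two above the main one: [(0, 2), (1, 3)]
print(diagonal_indices(3, 4, -1))  # one below the main one: [(1, 0), (2, 1)]
```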

Returns

A tensor containing the upper triangular part of x with respect to the specified diagonal; it has the same shape and data type as x.

Return type

Tensor

Examples

>>> import paddle

>>> x = paddle.arange(1, 13, dtype="int64").reshape([3,-1])
>>> print(x)
Tensor(shape=[3, 4], dtype=int64, place=Place(cpu), stop_gradient=True,
[[1 , 2 , 3 , 4 ],
 [5 , 6 , 7 , 8 ],
 [9 , 10, 11, 12]])

>>> # example 1, default diagonal
>>> triu1 = paddle.tensor.triu(x)
>>> print(triu1)
Tensor(shape=[3, 4], dtype=int64, place=Place(cpu), stop_gradient=True,
[[1 , 2 , 3 , 4 ],
 [0 , 6 , 7 , 8 ],
 [0 , 0 , 11, 12]])

>>> # example 2, positive diagonal value
>>> triu2 = paddle.tensor.triu(x, diagonal=2)
>>> print(triu2)
Tensor(shape=[3, 4], dtype=int64, place=Place(cpu), stop_gradient=True,
[[0, 0, 3, 4],
 [0, 0, 0, 8],
 [0, 0, 0, 0]])

>>> # example 3, negative diagonal value
>>> triu3 = paddle.tensor.triu(x, diagonal=-1)
>>> print(triu3)
Tensor(shape=[3, 4], dtype=int64, place=Place(cpu), stop_gradient=True,
[[1 , 2 , 3 , 4 ],
 [5 , 6 , 7 , 8 ],
 [0 , 10, 11, 12]])
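For a batch of matrices, the same rule is applied independently to each matrix in the batch. The plain-Python sketch below assumes a 3-D nested list of shape [batch, rows, cols]; `batched_triu` is a hypothetical helper, not part of Paddle.

```python
def batched_triu(batch, diagonal=0):
    """Hypothetical helper: apply the upper-triangular rule (keep (i, j)
    when j - i >= diagonal, else 0) to each matrix of a
    [batch, rows, cols] nested list independently."""
    return [
        [
            [value if j - i >= diagonal else 0 for j, value in enumerate(row)]
            for i, row in enumerate(matrix)
        ]
        for matrix in batch
    ]


batch = [
    [[1, 2], [3, 4]],
    [[5, 6], [7, 8]],
]
print(batched_triu(batch))
# [[[1, 2], [0, 4]], [[5, 6], [0, 8]]]
```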