squeeze

paddle.squeeze(x, axis=None, name=None)

This OP removes the dimension(s) of size 1 from the shape of the input Tensor x.

Note that in dygraph mode the output Tensor shares data with the original Tensor; no copy is made. If you need an independent copy, use Tensor.clone, e.g. squeeze_clone_x = x.squeeze().clone().

If axis is provided, only the specified dimension(s) of size 1 are removed; a specified dimension whose size is not 1 remains unchanged. If axis is not provided, all dimensions of size 1 are removed.

Case1:

  Input:
    x.shape = [1, 3, 1, 5]  # If axis is not provided, all dimensions of size 1 are removed.
    axis = None
  Output:
    out.shape = [3, 5]

Case2:

  Input:
    x.shape = [1, 3, 1, 5]  # If axis is provided, only the given dimension(s) of size 1 are removed.
    axis = 0
  Output:
    out.shape = [3, 1, 5]

Case3:

  Input:
    x.shape = [1, 3, 1, 5]  # A given axis whose dimension is not of size 1 (here axis 3) remains unchanged.
    axis = [0, 2, 3]
  Output:
    out.shape = [3, 5]

Case4:

  Input:
    x.shape = [1, 3, 1, 5]  # If axis is negative, axis = axis + ndim (number of dimensions in x).
    axis = [-2]
  Output:
    out.shape = [1, 3, 5]
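To make the shape rule concrete, here is a minimal pure-Python sketch that computes only the output shape, so it runs without Paddle installed. The helper name `squeezed_shape` is hypothetical and not part of Paddle's API; it models the behavior documented above (including negative axes) rather than Paddle's actual implementation, and it does no range checking on axis.

```python
from typing import List, Optional, Sequence, Union


def squeezed_shape(shape: Sequence[int],
                   axis: Optional[Union[int, Sequence[int]]] = None) -> List[int]:
    """Hypothetical helper: output shape of squeeze under the documented rule.

    This is NOT Paddle's implementation; it only models the shape semantics.
    """
    ndim = len(shape)
    if axis is None:
        # No axis given: drop every dimension of size 1.
        return [d for d in shape if d != 1]
    # Accept a single int or a list/tuple of ints.
    axes = [axis] if isinstance(axis, int) else list(axis)
    # Negative axes count from the end: axis = axis + ndim.
    targets = {a + ndim if a < 0 else a for a in axes}
    # Drop a listed dimension only if its size is 1; listed dims of other sizes stay.
    return [d for i, d in enumerate(shape) if not (i in targets and d == 1)]
```

Each case above maps to one call; for example, squeezed_shape([1, 3, 1, 5], [0, 2, 3]) gives [3, 5], matching Case3.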

Parameters
  • x (Tensor) – The input Tensor. Supported data type: float32, float64, bool, int8, int32, int64.

  • axis (int|list|tuple, optional) – An integer, or a list/tuple of integers, indicating the dimensions to be squeezed. Default: None. The valid range of axis is \([-ndim(x), ndim(x))\). If axis is negative, \(axis = axis + ndim(x)\). If axis is None, all dimensions of x of size 1 are removed.
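The axis range rule can be sketched as a small normalization step that maps every axis into [0, ndim). The helper `normalize_axes` is hypothetical, not Paddle code. This page does not say what Paddle does with an out-of-range axis, so raising ValueError here is an assumption of the sketch.

```python
from typing import List, Sequence, Union


def normalize_axes(axis: Union[int, Sequence[int]], ndim: int) -> List[int]:
    """Hypothetical helper: validate axes against [-ndim, ndim) and make them non-negative.

    Assumption: out-of-range axes raise ValueError; Paddle's actual error may differ.
    """
    axes = [axis] if isinstance(axis, int) else list(axis)
    result = []
    for a in axes:
        # Valid range is [-ndim, ndim), as stated in the parameter description.
        if not -ndim <= a < ndim:
            raise ValueError(f"axis {a} is out of range [{-ndim}, {ndim})")
        # Negative axis: axis = axis + ndim.
        result.append(a + ndim if a < 0 else a)
    return result
```

For a 4-D input such as x.shape = [1, 3, 1, 5], normalize_axes(-2, 4) returns [2], which is why Case4 removes the dimension at index 2.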

  • name (str, optional) – For details, please refer to Name. Default: None.

Returns

The squeezed Tensor, with the same data type as the input Tensor.

Return type

Tensor

Examples

import paddle

x = paddle.rand([5, 1, 10])
output = paddle.squeeze(x, axis=1)

print(x.shape)  # [5, 1, 10]
print(output.shape)  # [5, 10]

# output shares data with x in dygraph mode
x[0, 0, 0] = 10.
print(output[0, 0]) # [10.]