affine_grid

paddle.nn.functional.affine_grid(theta, out_shape, align_corners=True, name=None)

Generates a grid of (x, y) or (x, y, z) coordinates from the parameters of an affine transformation. Each grid point gives the location at which the input feature map should be sampled to produce the corresponding point of the transformed output feature map.
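For intuition, the 2-D case can be sketched in plain Python. This is a hypothetical re-implementation, not Paddle's source. It assumes normalized coordinates in [-1, 1], with x varying along the width and y along the height, and it maps each base point (x, y, 1) through the 2 x 3 matrix in theta. The names `base_coords` and `affine_grid_2d` are illustrative only.

```python
def base_coords(n, align_corners):
    """Normalized [-1, 1] positions of n pixel centers along one axis.

    Assumes n >= 2 when align_corners is True.
    """
    if align_corners:
        # The outermost pixel centers map exactly to -1 and 1.
        return [-1.0 + 2.0 * i / (n - 1) for i in range(n)]
    # Pixel centers sit half a pixel inside the [-1, 1] border.
    return [(2.0 * i + 1.0) / n - 1.0 for i in range(n)]


def affine_grid_2d(theta, out_shape, align_corners=True):
    """Hypothetical pure-Python sketch of the 2-D case.

    theta:     nested list of shape [N][2][3]
    out_shape: [N, C, H, W]
    returns:   nested list of shape [N][H][W][2] holding (x, y) pairs
    """
    n, _, h, w = out_shape
    xs = base_coords(w, align_corners)  # x runs along the width
    ys = base_coords(h, align_corners)  # y runs along the height
    grid = []
    for k in range(n):
        (t00, t01, t02), (t10, t11, t12) = theta[k]
        # Each output point is theta[k] @ [x, y, 1]; rows follow y, columns follow x.
        grid.append([[[t00 * x + t01 * y + t02, t10 * x + t11 * y + t12]
                      for x in xs]
                     for y in ys])
    return grid


theta = [[[-0.7, -0.4, 0.3],
          [0.6, 0.5, 1.5]]]
grid = affine_grid_2d(theta, [1, 2, 3, 3], align_corners=False)
print(grid[0][0][0])  # approximately [1.0333, 0.7667]
```

With the same theta and out_shape as in the Examples section, this sketch reproduces the printed grid up to float32 rounding.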

Parameters
  • theta (Tensor) – A batch of affine transformation matrices with shape [N, 2, 3] (2-D case) or [N, 3, 4] (3-D case). The data type can be float32 or float64.

  • out_shape (Tensor | list | tuple) – The shape of the output of the affine transformation, given as a 1-D Tensor, list, or tuple in the format [N, C, H, W] or [N, C, D, H, W], where N is the batch size, C the number of channels, D the depth, H the height and W the width. The data type must be int32.

  • align_corners (bool, optional) – If True, aligns the centers of the 4 (4-D) or 8 (5-D) corner pixels of the input and output tensors and preserves the values at the corner pixels. Default: True.

  • name (str, optional) – The default value is None. Normally there is no need for the user to set this property. For more information, please refer to Name.
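The effect of align_corners on the normalized sampling positions along a single axis can be sketched as follows. This is an illustrative helper (the name `normalized_positions` is hypothetical), assuming the convention that with align_corners=True the corner pixel centers land on -1 and 1, while with align_corners=False -1 and 1 mark the outer edges of the corner pixels. The align_corners=False values for a size-3 axis agree with the example output on this page.

```python
def normalized_positions(n, align_corners):
    """Sketch: normalized [-1, 1] coordinates of n pixel centers along one axis.

    Assumes n >= 2 when align_corners is True.
    """
    if align_corners:
        # Outermost pixel centers map exactly to -1 and 1.
        return [-1.0 + 2.0 * i / (n - 1) for i in range(n)]
    # -1 and 1 refer to the outer edges of the corner pixels, so the
    # pixel centers lie half a pixel inside: (2i + 1) / n - 1.
    return [(2.0 * i + 1.0) / n - 1.0 for i in range(n)]


print(normalized_positions(3, True))   # [-1.0, 0.0, 1.0]
print(normalized_positions(3, False))  # approximately [-0.667, 0.0, 0.667]
```

For a fixed theta, switching align_corners therefore rescales the base grid, which shifts every sampling location except the center.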

Returns

Tensor, a grid with shape [N, H, W, 2] or [N, D, H, W, 3], where D, H and W are the depth, height and width of the feature map in the affine transformation. The data type is the same as theta.
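For the 5-D case the same idea extends to (x, y, z): theta holds one 3 x 4 matrix per batch item, and the grid comes out as [N, D, H, W, 3]. The pure-Python sketch below illustrates that layout. It is not Paddle's implementation, the helper names are hypothetical, and it assumes the last axis is ordered (x, y, z). An identity theta makes the layout easy to check, since every grid entry then equals its own base coordinate.

```python
def base_coords(n, align_corners):
    """Normalized [-1, 1] positions of n pixel centers (n >= 2 if align_corners)."""
    if align_corners:
        return [-1.0 + 2.0 * i / (n - 1) for i in range(n)]
    return [(2.0 * i + 1.0) / n - 1.0 for i in range(n)]


def affine_grid_3d(theta, out_shape, align_corners=True):
    """Hypothetical pure-Python sketch of the 3-D case.

    theta:     nested list of shape [N][3][4]
    out_shape: [N, C, D, H, W]
    returns:   nested list of shape [N][D][H][W][3] holding (x, y, z)
    """
    n, _, d, h, w = out_shape
    xs = base_coords(w, align_corners)
    ys = base_coords(h, align_corners)
    zs = base_coords(d, align_corners)
    grid = []
    for k in range(n):
        t = theta[k]
        volume = []
        for z in zs:
            plane = []
            for y in ys:
                row = []
                for x in xs:
                    p = (x, y, z, 1.0)
                    # Each output point is theta[k] @ [x, y, z, 1].
                    row.append([sum(t[r][c] * p[c] for c in range(4))
                                for r in range(3)])
                plane.append(row)
            volume.append(plane)
        grid.append(volume)
    return grid


identity = [[[1.0, 0.0, 0.0, 0.0],
             [0.0, 1.0, 0.0, 0.0],
             [0.0, 0.0, 1.0, 0.0]]]
grid = affine_grid_3d(identity, [1, 1, 2, 3, 4])
print(len(grid[0]), len(grid[0][0]), len(grid[0][0][0]), len(grid[0][0][0][0]))  # 2 3 4 3
print(grid[0][0][0][0])  # [-1.0, -1.0, -1.0]
```

The fourth column of each theta matrix acts as a translation: setting its first entry to 0.5 shifts every x coordinate by 0.5.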

Examples

>>> import paddle
>>> import paddle.nn.functional as F
>>> # theta.shape = [1, 2, 3]
>>> theta = paddle.to_tensor([[[-0.7, -0.4, 0.3],
...                            [ 0.6,  0.5, 1.5]]], dtype="float32")
>>> y_t = F.affine_grid(
...     theta,
...     [1, 2, 3, 3],
...     align_corners=False
... )
>>> print(y_t)
Tensor(shape=[1, 3, 3, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
[[[[ 1.03333330,  0.76666665],
   [ 0.56666672,  1.16666663],
   [ 0.10000002,  1.56666672]],
  [[ 0.76666665,  1.09999990],
   [ 0.30000001,  1.50000000],
   [-0.16666666,  1.90000010]],
  [[ 0.50000000,  1.43333328],
   [ 0.03333333,  1.83333337],
   [-0.43333334,  2.23333335]]]])