maxout

paddle.nn.functional.maxout(x, groups, axis=1, name=None)

Maxout activation.

Assume the input shape is (N, Ci, H, W) and the output shape is (N, Co, H, W), where Co = Ci / groups. The operator is defined as follows:

\[
\begin{aligned}
out_{si+j} &= \max_{k} x_{gsi + sk + j} \\
g &= groups \\
s &= \frac{input.size}{num\_channels} \\
0 &\le i < \frac{num\_channels}{groups} \\
0 &\le j < s \\
0 &\le k < groups
\end{aligned}
\]
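To make the indexing concrete: for an NCHW input, output channel i is the maximum over the groups consecutive input channels gi + k, 0 <= k < groups. A minimal sketch of this equivalence using reshape and max (the shapes here are chosen arbitrarily for illustration):

>>> import paddle
>>> import paddle.nn.functional as F

>>> x = paddle.rand([2, 6, 3, 4])   # N=2, Ci=6, H=3, W=4
>>> out = F.maxout(x, groups=3)     # Co = Ci / groups = 2
>>> # Fold the channel axis into (Co, groups) and reduce over the groups sub-axis.
>>> ref = x.reshape([2, 2, 3, 3, 4]).max(axis=2)
>>> assert paddle.allclose(out, ref)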
Parameters
  • x (Tensor) – The input is a 4-D Tensor with shape [N, C, H, W] or [N, H, W, C]; the data type of the input is float16, float32 or float64.

  • groups (int) – The number of groups for maxout. The size of the channel dimension is divided by groups, so groups must be a factor of the number of channels.

  • axis (int, optional) – The axis along which to perform the maxout calculation. It should be 1 when the data format is NCHW, and -1 or 3 when the data format is NHWC (see the sketch after this list). If axis < 0, it is treated as \(axis + D\), where D is the number of dimensions of x. axis only supports 1, 3 or -1. Default is 1.

  • name (str, optional) – For details, please refer to Name. Generally, no setting is required. Default: None.
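The behavior of axis with channels-last data, referenced above, in a minimal sketch (shapes chosen arbitrarily):

>>> import paddle
>>> import paddle.nn.functional as F

>>> x_nhwc = paddle.rand([1, 3, 4, 6])  # N=1, H=3, W=4, C=6 (channels last)
>>> out = F.maxout(x_nhwc, groups=3, axis=-1)
>>> print(out.shape)
[1, 3, 4, 2]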

Returns

A Tensor with the same data type as x. For an input with C channels along axis, the output has C / groups channels along that axis.

Examples

>>> import paddle
>>> import paddle.nn.functional as F

>>> paddle.seed(2023)
>>> x = paddle.rand([1, 2, 3, 4])
>>> print(x)
Tensor(shape=[1, 2, 3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
[[[[0.86583614, 0.52014720, 0.25960937, 0.90525323],
   [0.42400089, 0.40641287, 0.97020894, 0.74437362],
   [0.51785129, 0.73292869, 0.97786582, 0.04315904]],
  [[0.42639419, 0.71958369, 0.20811461, 0.19731510],
   [0.38424349, 0.14603184, 0.22713774, 0.44607511],
   [0.21657862, 0.67685395, 0.46460176, 0.92382854]]]])
>>> out = F.maxout(x, groups=2)
>>> print(out)
Tensor(shape=[1, 1, 3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
[[[[0.86583614, 0.71958369, 0.25960937, 0.90525323],
   [0.42400089, 0.40641287, 0.97020894, 0.74437362],
   [0.51785129, 0.73292869, 0.97786582, 0.92382854]]]])
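
Because this input has only two channels and groups=2, the output above is simply the element-wise maximum of the two channel slices, which can be verified directly:

>>> ref = paddle.maximum(x[:, 0:1], x[:, 1:2])
>>> assert paddle.allclose(out, ref)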