ChannelShuffle

class paddle.nn.ChannelShuffle(groups, data_format='NCHW', name=None)

Divides the channels of a tensor of shape [N, C, H, W] or [N, H, W, C] into g groups, giving a tensor of shape [N, g, C/g, H, W] or [N, H, W, g, C/g], transposes the two group axes to get [N, C/g, g, H, W] or [N, H, W, C/g, g], and then reshapes the result back to the original tensor shape. This operation improves the flow of information between channel groups, so that features are used more efficiently. For more details, please refer to the paper ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices by Zhang et al. (2017).
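The reshape, transpose, reshape sequence described above amounts to a fixed permutation of the channel indices. The following is a minimal pure-Python sketch of that permutation, not Paddle's implementation; the helper name channel_shuffle_order is hypothetical and is used only for illustration.

```python
def channel_shuffle_order(num_channels, groups):
    """Return the source channel index for each output channel.

    Hypothetical helper for illustration; it is not part of Paddle.
    """
    if num_channels % groups != 0:
        raise ValueError("num_channels must be divisible by groups")
    per_group = num_channels // groups

    # Step 1: view the channel indices as a [groups, per_group] grid.
    grid = [[g * per_group + k for k in range(per_group)]
            for g in range(groups)]

    # Steps 2 and 3: transpose to [per_group, groups] and flatten again.
    return [grid[g][k] for k in range(per_group) for g in range(groups)]


order = channel_shuffle_order(6, 3)
print(order)  # [0, 2, 4, 1, 3, 5]

# Applying the permutation to six channel values.
values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
shuffled = [values[i] for i in order]
print(shuffled)  # [0.0, 0.2, 0.4, 0.1, 0.3, 0.5]
```

With 6 channels and 3 groups, the resulting order matches the channel order printed in the Examples section of this page.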

Parameters
  • groups (int) – Number of groups to divide the channels into. The number of channels C should be divisible by groups.

  • data_format (str, optional) – The data format of the input and output data, either NCHW or NHWC. The default is NCHW. When it is NCHW, the data is stored in the order of [batch_size, input_channels, input_height, input_width]. When it is NHWC, the data is stored in the order of [batch_size, input_height, input_width, input_channels].

  • name (str, optional) – Name for the operation. The default is None. Normally there is no need for the user to set this property. For more information, please refer to Name.

Shape
  • x: 4-D tensor with shape of [N, C, H, W] or [N, H, W, C].

  • out: 4-D tensor with shape and dtype same as x.
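For the NHWC layout, the channels sit on the last axis, so the shuffle applies to the innermost dimension of each pixel. The sketch below illustrates this with plain nested lists rather than Paddle tensors; the function shuffle_last_axis is a hypothetical name introduced only for this illustration.

```python
def shuffle_last_axis(x_nhwc, groups):
    """Shuffle the channel (last) axis of a nested list shaped [N][H][W][C].

    Hypothetical helper for illustration; it is not part of Paddle.
    """
    def shuffle_pixel(channels):
        c = len(channels)
        if c % groups != 0:
            raise ValueError("channel count must be divisible by groups")
        per_group = c // groups
        # Read the [groups, per_group] view in transposed order.
        return [channels[g * per_group + k]
                for k in range(per_group)
                for g in range(groups)]

    return [[[shuffle_pixel(pixel) for pixel in row] for row in image]
            for image in x_nhwc]


# One image, height 1, width 2, 6 channels: shape [1, 1, 2, 6].
x = [[[[0, 1, 2, 3, 4, 5], [10, 11, 12, 13, 14, 15]]]]
y = shuffle_last_axis(x, 3)
print(y)  # [[[[0, 2, 4, 1, 3, 5], [10, 12, 14, 11, 13, 15]]]]
```

The output keeps the same [N, H, W, C] shape as the input; only the order of the values along the channel axis changes.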

Examples

>>> import paddle
>>> import paddle.nn as nn
>>> x = paddle.arange(0, 0.6, 0.1, 'float32')
>>> x = paddle.reshape(x, [1, 6, 1, 1])
>>> print(x)
Tensor(shape=[1, 6, 1, 1], dtype=float32, place=Place(cpu), stop_gradient=True,
[[[[0.        ]],
  [[0.10000000]],
  [[0.20000000]],
  [[0.30000001]],
  [[0.40000001]],
  [[0.50000000]]]])
>>> channel_shuffle = nn.ChannelShuffle(3)
>>> y = channel_shuffle(x)
>>> print(y)
Tensor(shape=[1, 6, 1, 1], dtype=float32, place=Place(cpu), stop_gradient=True,
[[[[0.        ]],
  [[0.20000000]],
  [[0.40000001]],
  [[0.10000000]],
  [[0.30000001]],
  [[0.50000000]]]])

forward(x)

Defines the computation performed at every call. Should be overridden by all subclasses.

Parameters
  • *inputs (tuple) – unpacked tuple arguments

  • **kwargs (dict) – unpacked dict arguments

extra_repr()

Extra representation of this layer. You can provide a custom implementation for your own layer.