sequence_topk_avg_pooling

paddle.fluid.contrib.layers.nn.sequence_topk_avg_pooling(input, row, col, topks, channel_num)

topks is a list of increasing values. For each k in topks, this layer averages the top-k features of each channel of every input sequence and emits the result as one output feature. Both row and col are LoDTensors that provide the height and width information of the input tensor. If an input sequence has fewer features than k, it is padded with 0 at the end before averaging.
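
To make the averaging concrete, here is a minimal NumPy sketch for a single input sequence. It assumes the sequence's flattened features are laid out channel-major as [channel_num, row, col], that the top-k is taken along the col (width) axis of each row, and that the output columns are grouped channel by channel. These layout details are assumptions inferred from the shape formulas on this page, and topk_avg_pool_one_sequence is a hypothetical name, not a Paddle API.

```python
import numpy as np


def topk_avg_pool_one_sequence(features, rows, cols, topks, channel_num):
    """Hypothetical NumPy reference for one sequence (layout is an assumption).

    features: 1-D array of length channel_num * rows * cols, channel-major.
    Returns an array of shape [rows, channel_num * len(topks)].
    """
    topks = list(topks)
    max_k = topks[-1]  # topks is increasing, so the last entry is the largest
    grid = np.asarray(features, dtype=np.float64).reshape(channel_num, rows, cols)
    out = np.zeros((rows, channel_num * len(topks)))
    for c in range(channel_num):
        for r in range(rows):
            # Sort this row's features in descending order.
            ranked = np.sort(grid[c, r])[::-1]
            # If the row has fewer than max_k features, pad with 0 at the back.
            if ranked.size < max_k:
                ranked = np.concatenate([ranked, np.zeros(max_k - ranked.size)])
            prefix = np.cumsum(ranked[:max_k])
            for i, k in enumerate(topks):
                # Average of the k largest values (padding zeros included).
                out[r, c * len(topks) + i] = prefix[k - 1] / k
    return out


# One channel, one row with only 2 features, topks = [1, 3]:
# top-1 average = 4 / 1 = 4.0; top-3 average = (4 + 1 + 0) / 3 = 5/3.
print(topk_avg_pool_one_sequence([1.0, 4.0], rows=1, cols=2, topks=[1, 3], channel_num=1))
```

Dividing by k, rather than by the number of features actually present, is what makes the zero padding affect the result.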

For example, let channel_num be 2 and let the row and col LoDTensors be:
    row.lod = [[5, 4]]
    col.lod = [[6, 7]]

Then input is a LoDTensor with input.lod[0][i] = channel_num * row.lod[0][i] * col.lod[0][i]:
    input.lod = [[60, 56]]  # where 60 = channel_num * 5 * 6 and 56 = channel_num * 4 * 7
    input.dims = [116, 1]   # where 116 = 60 + 56

If topks is [1, 3, 5], the output is a 1-level LoDTensor:
    out.lod =  [[5, 4]]     # shares LoD info with the row LoDTensor
    out.dims = [9, 6]       # where 9 = 5 + 4 and 6 = len(topks) * channel_num
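
The LoD and shape arithmetic of this example can be reproduced in plain Python. The sketch treats each LoD as a list of per-sequence lengths, which is how the example writes them; topk_avg_pool_shapes is a hypothetical helper, not part of Paddle.

```python
def topk_avg_pool_shapes(row_lens, col_lens, topks, channel_num):
    """Hypothetical helper: derive input/output LoD and dims from lengths."""
    # Each input sequence holds channel_num * row * col scalar features.
    input_lens = [channel_num * r * c for r, c in zip(row_lens, col_lens)]
    input_dims = [sum(input_lens), 1]
    # The output shares row's LoD: one output row per input row,
    # with len(topks) averages for each channel.
    out_lens = list(row_lens)
    out_dims = [sum(row_lens), len(topks) * channel_num]
    return input_lens, input_dims, out_lens, out_dims


print(topk_avg_pool_shapes([5, 4], [6, 7], topks=[1, 3, 5], channel_num=2))
# ([60, 56], [116, 1], [5, 4], [9, 6])
```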
Parameters
  • input (Variable) – A 2-D LoDTensor whose dims[1] equals 1.

  • row (Variable) – A 1-level LoDTensor that provides the height information of the input tensor data.

  • col (Variable) – A 1-level LoDTensor that provides the width information of the input tensor data.

  • topks (list) – A list of increasing values; for each value k, the top-k features are averaged.

  • channel_num (int) – The number of input channels.
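
The constraints in the parameter list can be checked before building the layer. The following is a hypothetical pre-flight check, not a Paddle API; it assumes "increasing" means strictly increasing positive integers and takes the input's shape as a plain tuple.

```python
def check_topk_avg_pooling_args(input_shape, topks, channel_num):
    """Hypothetical pre-flight check mirroring the documented parameter rules."""
    # input: a 2-D LoDTensor whose dims[1] equals 1.
    if len(input_shape) != 2 or input_shape[1] != 1:
        raise ValueError("input must be a 2-D LoDTensor with dims[1] == 1")
    # topks: non-empty, positive, and (by assumption) strictly increasing.
    if not topks or any(not isinstance(k, int) or k <= 0 for k in topks):
        raise ValueError("topks must be a non-empty list of positive ints")
    if any(a >= b for a, b in zip(topks, topks[1:])):
        raise ValueError("topks must be strictly increasing")
    # channel_num: the number of input channels.
    if not isinstance(channel_num, int) or channel_num <= 0:
        raise ValueError("channel_num must be a positive int")


check_topk_avg_pooling_args((116, 1), [1, 3, 5], channel_num=2)  # passes silently
```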

Returns

The output LoDTensor of this layer.

Return type

Variable

Examples

from paddle.fluid import layers
from paddle.fluid import contrib

# 2-D input LoDTensor of flattened features (dims[1] == 1).
x_lod_tensor = layers.data(name='x', shape=[1], lod_level=1)
# 1-level LoDTensors whose LoD carries the height (row) and width (col) of each sequence.
row_lod_tensor = layers.data(name='row', shape=[6], lod_level=1)
col_lod_tensor = layers.data(name='col', shape=[6], lod_level=1)
# out has len(topks) * channel_num = 15 columns and shares row's LoD.
out = contrib.sequence_topk_avg_pooling(input=x_lod_tensor,
                                       row=row_lod_tensor,
                                       col=col_lod_tensor,
                                       topks=[1, 3, 5],
                                       channel_num=5)