TreeConv(feature_size, output_size, num_filters=1, max_depth=2, act='tanh', param_attr=None, bias_attr=None, name=None, dtype='float32')
This interface is used to construct a callable object of the TreeConv class. For more details, refer to the code example below. Tree-Based Convolution is a kind of convolution that operates on tree structures. It is part of the Tree-Based Convolutional Neural Network (TBCNN), which is used to classify tree structures such as abstract syntax trees. Tree-Based Convolution introduces a data structure called the continuous binary tree, which treats a multiway tree as a binary tree. The Tree-Based Convolution operator is described in the paper: tree-based convolution.
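The continuous binary tree idea can be illustrated with a short NumPy sketch: every node in a convolution window gets its own weight matrix, formed by interpolating three shared matrices ("top", "left", "right") according to the node's depth in the window and its position among its siblings. The coefficient formulas, the helper names cbt_coefficients and convolve_window, and the use of tanh (matching the act default) are illustrative assumptions; the exact formulas in the paper and in Paddle's kernel may differ.

```python
import numpy as np


def cbt_coefficients(depth, position, num_siblings, window_depth):
    """Mixing weights (eta_t, eta_l, eta_r) for one node in a window.

    depth: 0 for the window root, 1 for its children, and so on.
    position: 1-based index of the node among its siblings.
    num_siblings: number of children of the node's parent (>= 1).
    window_depth: number of tree levels the window covers.
    These formulas are an illustrative assumption, not Paddle's kernel.
    """
    eta_t = (window_depth - depth) / window_depth  # 1.0 at the window root
    if num_siblings == 1:
        frac = 0.5  # an only child is split evenly between left and right
    else:
        frac = (position - 1) / (num_siblings - 1)  # 0 = leftmost, 1 = rightmost
    eta_r = (1.0 - eta_t) * frac
    eta_l = (1.0 - eta_t) * (1.0 - frac)
    return eta_t, eta_l, eta_r


def convolve_window(features, window, w_top, w_left, w_right, bias, window_depth):
    """Apply one tree-based convolution window (hypothetical helper).

    features: [num_nodes, feature_size] array of node vectors.
    window: list of (node, depth, position, num_siblings) tuples.
    w_top, w_left, w_right: [feature_size, output_size] weight matrices.
    """
    acc = np.zeros(w_top.shape[1])
    for node, depth, position, num_siblings in window:
        eta_t, eta_l, eta_r = cbt_coefficients(depth, position, num_siblings, window_depth)
        # Each node gets its own matrix, interpolated from the three shared ones.
        w_node = eta_t * w_top + eta_l * w_left + eta_r * w_right
        acc += features[node] @ w_node
    return np.tanh(acc + bias)


# Usage: a root (node 0) with three children (nodes 1, 2, 3).
rng = np.random.default_rng(0)
features = rng.random((4, 5))
w_top, w_left, w_right = (rng.random((5, 6)) for _ in range(3))
bias = np.zeros(6)
window = [(0, 0, 1, 1), (1, 1, 1, 3), (2, 1, 2, 3), (3, 1, 3, 3)]
out = convolve_window(features, window, w_top, w_left, w_right, bias, window_depth=2)
```

Because the three coefficients are non-negative and sum to 1, each node's matrix is a convex combination of the shared ones, so the number of parameters does not grow with the number of children a node has.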
feature_size (int) – Size of the last dimension of nodes_vector, i.e. the width of each node's feature vector.
output_size (int) – Width of the output feature vector.
num_filters (int, optional) – Number of filters. Default: 1.
max_depth (int, optional) – Maximum depth of the filters. Default: 2.
act (str, optional) – Activation function. Default: "tanh".
param_attr (ParamAttr, optional) – Parameter attribute for the filter weights. Default: None.
bias_attr (ParamAttr, optional) – Parameter attribute for the bias of this layer. Default: None.
name (str, optional) – Normally there is no need for the user to set this property. For more information, refer to Name. Default: None.
dtype (str, optional) – Data type, either "float32" or "float64". Default: "float32".
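To see what max_depth controls, here is a small pure-Python sketch that gathers the window rooted at each node. It assumes that max_depth counts tree levels including the window root (so the default of 2 covers a node and its direct children), that edge_set rows are directed (parent, child) pairs with 1-based node ids, and that a (0, 0) row ends the edge list. collect_windows is a hypothetical helper, not part of Paddle.

```python
from collections import defaultdict


def collect_windows(edges, num_nodes, max_depth):
    """Collect the sliding window rooted at every node (hypothetical helper).

    edges: iterable of (parent, child) pairs with 1-based node ids;
           a pair containing 0 is treated as padding and stops the scan
           (assumed convention).
    max_depth: assumed to count tree levels, root included.
    Returns {root: [(node, depth, position, num_siblings), ...]}.
    """
    children = defaultdict(list)
    for parent, child in edges:
        if parent == 0 or child == 0:
            break  # padding row: no more real edges
        children[parent].append(child)

    windows = {}
    for root in range(1, num_nodes + 1):
        window = [(root, 0, 1, 1)]  # the root sits at depth 0
        frontier = [root]
        # Expand one level at a time until the window spans max_depth levels.
        for depth in range(1, max_depth):
            next_frontier = []
            for node in frontier:
                kids = children[node]
                for position, kid in enumerate(kids, start=1):
                    window.append((kid, depth, position, len(kids)))
                    next_frontier.append(kid)
            frontier = next_frontier
        windows[root] = window
    return windows


# Node 1 has children 2 and 3; node 2 has child 4; (0, 0) is padding.
edges = [(1, 2), (1, 3), (2, 4), (0, 0)]
windows = collect_windows(edges, num_nodes=4, max_depth=2)
```

With max_depth=2 the window rooted at node 1 holds nodes 1, 2 and 3; raising it to 3 would also pull in the grandchild, node 4.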
weight (Parameter) – The learnable filter weights of this layer.
bias (Parameter or None) – The learnable bias of this layer, or None if the layer has no bias.
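import numpy
import paddle.fluid as fluid

with fluid.dygraph.guard():
    # 1 tree with up to 10 nodes, each described by a 5-dim feature vector.
    nodes_vector = numpy.random.random((1, 10, 5)).astype('float32')
    # Up to 9 directed (parent, child) edges. Node ids are assumed to be
    # 1-based, and a (0, 0) row is treated as padding.
    edge_set = numpy.array(
        [[[1, 2], [1, 3], [2, 4], [2, 5], [3, 6], [3, 7], [4, 8], [4, 9], [0, 0]]],
        dtype='int32')
    treeConv = fluid.dygraph.nn.TreeConv(
        feature_size=5, output_size=6, num_filters=1, max_depth=2)
    ret = treeConv(fluid.dygraph.base.to_variable(nodes_vector),
                   fluid.dygraph.base.to_variable(edge_set))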
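Building edge_set by hand is easy to get wrong, so a small helper may help. The sketch below assumes that edges are directed (parent, child) pairs, that node ids are 1-based (id k refers to row k - 1 of nodes_vector), and that unused rows are filled with (0, 0). make_edge_set is a hypothetical name, not a Paddle API.

```python
import numpy as np


def make_edge_set(edges, max_edges):
    """Pack (parent, child) pairs into a [1, max_edges, 2] int32 array.

    edges: (parent, child) pairs using 0-based node ids, i.e. row indices
           into nodes_vector.
    Ids are shifted to 1-based and unused rows stay (0, 0), following the
    assumed convention that 0 marks padding.
    """
    if len(edges) > max_edges:
        raise ValueError("more edges than max_edges")
    edge_set = np.zeros((1, max_edges, 2), dtype=np.int32)
    for i, (parent, child) in enumerate(edges):
        edge_set[0, i] = (parent + 1, child + 1)  # shift to 1-based ids
    return edge_set


# A root (row 0) with two children (rows 1 and 2); row 1 has one child (row 3).
edge_set = make_edge_set([(0, 1), (0, 2), (1, 3)], max_edges=9)
```

The resulting array can be passed through fluid.dygraph.base.to_variable in the same way as in the example above.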
forward(*inputs, **kwargs)
Defines the computation performed at every call. Should be overridden by all subclasses.
*inputs (tuple) – unpacked tuple arguments
**kwargs (dict) – unpacked dict arguments