send_uv

paddle.geometric.send_uv(x, y, src_index, dst_index, message_op='add', name=None) [source]

Graph Learning message passing API.

This API is mainly used in the Graph Learning domain, and its main purpose is to reduce intermediate memory consumption during message passing. Take x as the source node feature tensor and y as the destination node feature tensor. src_index and dst_index are used to gather the corresponding rows of x and y, which are then combined into edge features with one of the message_ops: add, sub, mul, or div.

Given:

x = [[0, 2, 3],
     [1, 4, 5],
     [2, 6, 7]]

y = [[0, 1, 2],
     [2, 3, 4],
     [4, 5, 6]]

src_index = [0, 1, 2, 0]

dst_index = [1, 2, 1, 0]

message_op = "add"

Then:

out = [[2, 5, 7],
       [5, 9, 11],
       [4, 9, 11],
       [0, 3, 5]]

Parameters
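The worked example above amounts to two row gathers followed by an elementwise op. A minimal NumPy sketch of the same semantics (a reference for the math only, not the Paddle implementation, which avoids materializing the gathered rows):

```python
import numpy as np

def send_uv_ref(x, y, src_index, dst_index, message_op="add"):
    # Gather one source-node row and one destination-node row per edge,
    # then combine them elementwise to form the edge features.
    xs = x[src_index]  # shape: [num_edges, feature_dim]
    yd = y[dst_index]  # shape: [num_edges, feature_dim]
    ops = {"add": np.add, "sub": np.subtract,
           "mul": np.multiply, "div": np.divide}
    return ops[message_op](xs, yd)

x = np.array([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype=np.float32)
y = np.array([[0, 1, 2], [2, 3, 4], [4, 5, 6]], dtype=np.float32)
src_index = np.array([0, 1, 2, 0])
dst_index = np.array([1, 2, 1, 0])

out = send_uv_ref(x, y, src_index, dst_index, "add")
# out matches the table above: [[2, 5, 7], [5, 9, 11], [4, 9, 11], [0, 3, 5]]
```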
  • x (Tensor) – The source node feature tensor. The available data types are float32, float64, int32, and int64; float16 is also supported in the GPU version.

  • y (Tensor) – The destination node feature tensor. The available data types are float32, float64, int32, and int64; float16 is also supported in the GPU version.

  • src_index (Tensor) – A 1-D tensor of edge source indices. The available data types are int32 and int64.

  • dst_index (Tensor) – A 1-D tensor of edge destination indices, with the same shape as src_index. The available data types are int32 and int64.

  • message_op (str) – The elementwise op applied to the gathered x and y features: add, sub, mul, or div. Default is add.

  • name (str, optional) – Name for the operation (optional, default is None). For more information, please refer to Name.

Returns

  • out (Tensor) – The computed edge feature tensor, with one row per entry in src_index.

Examples

>>> import paddle

>>> x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
>>> y = paddle.to_tensor([[0, 1, 2], [2, 3, 4], [4, 5, 6]], dtype="float32")
>>> indexes = paddle.to_tensor([[0, 1], [1, 2], [2, 1], [0, 0]], dtype="int32")
>>> src_index = indexes[:, 0]
>>> dst_index = indexes[:, 1]
>>> out = paddle.geometric.send_uv(x, y, src_index, dst_index, message_op="add")
>>> print(out.numpy())
[[ 2.  5.  7.]
 [ 5.  9. 11.]
 [ 4.  9. 11.]
 [ 0.  3.  5.]]
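The other message_ops follow the same per-edge pattern. As a sketch, here is what message_op="sub" yields for the same tensors, computed with NumPy for self-containment (the paddle.geometric.send_uv call with message_op="sub" is expected to produce the same values):

```python
import numpy as np

x = np.array([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype=np.float32)
y = np.array([[0, 1, 2], [2, 3, 4], [4, 5, 6]], dtype=np.float32)
src_index = np.array([0, 1, 2, 0])
dst_index = np.array([1, 2, 1, 0])

# message_op="sub": out[i] = x[src_index[i]] - y[dst_index[i]]
out = x[src_index] - y[dst_index]
print(out)
# [[-2. -1. -1.]
#  [-3. -1. -1.]
#  [ 0.  3.  3.]
#  [ 0.  1.  1.]]
```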