[ Inconsistent input parameter usage ] torch.distributed.all_gather_object

torch.distributed.all_gather_object

torch.distributed.all_gather_object(object_list, obj, group=None)

paddle.distributed.all_gather_object

paddle.distributed.all_gather_object(object_list, obj, group=None)

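The two APIs provide the same functionality and take nearly identical parameters, but object_list must be initialized differently. Details are as follows: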

Parameter mapping

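| PyTorch | PaddlePaddle | Notes |
| ------- | ------------ | ----- |
| object_list | object_list | The list that stores the gathered results. PyTorch requires a list pre-allocated to the same length as the group; Paddle requires an empty list. The code must be converted. |
| obj | obj | The object to gather. |
| group | group | The process group instance on which the operation runs. |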

Conversion example

object_list: the list that stores the gathered results

# PyTorch usage
import torch.distributed as dist
object_list = [{}, {}] # NOTE: world size is 2
if dist.get_rank() == 0:
    obj = {"foo": [1, 2, 3]}
else:
    obj = {"bar": [4, 5, 6]}
dist.all_gather_object(object_list, obj)

# Paddle usage
import paddle.distributed as dist
object_list = [] # No need to pre-allocate
if dist.get_rank() == 0:
    obj = {"foo": [1, 2, 3]}
else:
    obj = {"bar": [4, 5, 6]}
dist.all_gather_object(object_list, obj)
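
To show why the initialization matters, the following is a minimal pure-Python sketch of the two calling conventions. It does not call torch or paddle. The helpers `torch_style_all_gather` and `paddle_style_all_gather` are hypothetical stand-ins, and `objs_by_rank` simulates the objects contributed by each rank. It assumes that Paddle appends the gathered objects to the list it receives, which is consistent with the empty-list requirement above but is not confirmed here against the Paddle implementation. The length check in the PyTorch-style helper is a modeling choice, not the library's actual error behavior.

```python
# Pure-Python model of the two calling conventions; no torch or paddle required.
# Both helper functions are hypothetical stand-ins, not library APIs.

def torch_style_all_gather(object_list, objs_by_rank):
    """Model of the PyTorch convention: fill pre-allocated slots in place."""
    # Modeling choice: reject a list whose length differs from the group size.
    if len(object_list) != len(objs_by_rank):
        raise ValueError("object_list must have one slot per rank")
    for rank, obj in enumerate(objs_by_rank):
        object_list[rank] = obj


def paddle_style_all_gather(object_list, objs_by_rank):
    """Model of the assumed Paddle convention: append one result per rank."""
    for obj in objs_by_rank:
        object_list.append(obj)


# Objects contributed by rank 0 and rank 1 (world size 2).
objs_by_rank = [{"foo": [1, 2, 3]}, {"bar": [4, 5, 6]}]

# PyTorch style: one placeholder slot per rank, overwritten by the gather.
torch_result = [None] * len(objs_by_rank)
torch_style_all_gather(torch_result, objs_by_rank)

# Paddle style: start from an empty list.
paddle_result = []
paddle_style_all_gather(paddle_result, objs_by_rank)

# Under the append assumption, a pre-allocated list keeps its placeholders
# and ends up twice as long as the group.
stale = [None, None]
paddle_style_all_gather(stale, objs_by_rank)
```

Under these assumptions, both `torch_result` and `paddle_result` end up holding one object per rank in rank order, while `stale` still carries its leading placeholders. This is why PyTorch code that pre-allocates the list should be rewritten to pass an empty list when converting to Paddle.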