dtensor_from_fn

class paddle.distributed. dtensor_from_fn ( fn, mesh, placements, *args, **kwargs ) [源代码]

通过一个 paddle API(一般是 Tensor 创建类的 API )结合分布式属性 placements 创建一个带分布式属性的 Tensor。

参数

  • fn - paddle 公开的可创建 Tensor 的 API。例如: emptyoneszeros 等 paddle API。

  • mesh (paddle.distributed.ProcessMesh) - 表示进程拓扑信息的 ProcessMesh 对象。

  • placements (list(Placement)) - 分布式 Tensor 的切分表示列表,描述 Tensor 在 mesh 上如何切分。

  • *args - fn 函数的输入参数( Tuple 形式)

  • **kwargs - fn 函数的输入参数( Dict 形式)

返回

带有分布式信息的 Tensor

代码示例

>>> import paddle
>>> import paddle.distributed as dist
>>> # Create a distributed attribute
>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
>>> # Call the function dtensor_from_fn with dist_attr parameter
>>> d_tensor = dist.dtensor_from_fn(paddle.ones, mesh, [dist.Replicate()], shape=[1])
>>> print(d_tensor)