dtensor_from_fn

paddle.distributed.dtensor_from_fn(fn, mesh, placements, *args, **kwargs)

Construct a distributed tensor by calling fn with the given arguments and distributing the resulting tensor over mesh according to placements.
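The control flow can be modeled in plain Python. This is a minimal sketch, not Paddle's implementation: DistTensorSketch, dtensor_from_fn_sketch, and ones are hypothetical names. The sketch assumes the distribution step corresponds to paddle.distributed.shard_tensor, and it does not import Paddle. What it shows is how the arguments are routed. mesh and placements are consumed by the wrapper, and everything after them is forwarded to fn.

```python
from dataclasses import dataclass
from typing import Any, Callable, List


@dataclass
class DistTensorSketch:
    """Hypothetical stand-in for a distributed tensor (not a Paddle class)."""
    data: Any
    mesh: List[int]
    placements: List[str]


def dtensor_from_fn_sketch(fn: Callable[..., Any], mesh, placements, *args, **kwargs):
    """Hypothetical model of dtensor_from_fn's control flow."""
    # Only the extra positional and keyword arguments reach fn;
    # mesh and placements are consumed by the wrapper itself.
    tensor = fn(*args, **kwargs)
    # Stand-in for the distribution step (assumed here to correspond to
    # paddle.distributed.shard_tensor in the real implementation).
    return DistTensorSketch(tensor, list(mesh), list(placements))


def ones(shape):
    """Hypothetical creation function: a flat list of 1.0 values."""
    count = 1
    for size in shape:
        count *= size
    return [1.0] * count


d = dtensor_from_fn_sketch(ones, [0, 1], ["Replicate"], shape=[2, 3])
print(len(d.data), d.mesh, d.placements)  # 6 [0, 1] ['Replicate']
```

fn, mesh, and placements are named parameters of the wrapper. For that reason, fn cannot receive keyword arguments with those names, and passing one raises a TypeError for multiple values of the same argument.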

Parameters
  • fn (callable) – A tensor-creation function, such as paddle.ones, that is called with *args and **kwargs and returns a Tensor.

  • mesh (paddle.distributed.ProcessMesh) – The ProcessMesh object that describes the Cartesian topology of the processes used.

  • placements (list[paddle.distributed.Placement]) – Describes how the tensor is placed on the ProcessMesh. Each placement can be Shard, Replicate, or Partial.

  • *args (tuple) – Positional arguments passed to fn.

  • **kwargs (dict) – Keyword arguments passed to fn.
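Placements determine what each process stores locally. The sketch below is a simplified model of that idea, not Paddle's API. Its assumptions are:

* Placements are encoded as hypothetical tuples rather than Paddle's Shard, Replicate, and Partial classes.
* There is one placement per mesh dimension.
* Shard only ever splits a dimension evenly.

Under those assumptions, Shard(dim) divides tensor dimension dim across the processes of that mesh dimension. Replicate and Partial leave the local shape equal to the global shape. Partial differs from Replicate only in that the local values still need to be reduced.

```python
from typing import List, Sequence

# Hypothetical placement encoding (not Paddle's classes):
#   ("shard", dim)  -> split tensor dimension `dim` across this mesh dimension
#   ("replicate",)  -> every process on this mesh dimension holds the full data
#   ("partial",)    -> full-shape local tensors whose values must be reduced


def local_shape(global_shape: Sequence[int], mesh_shape: Sequence[int],
                placements: Sequence[tuple]) -> List[int]:
    """Return the per-process tensor shape under the simplified model."""
    if len(placements) != len(mesh_shape):
        raise ValueError("need one placement per mesh dimension")
    shape = list(global_shape)
    for mesh_size, placement in zip(mesh_shape, placements):
        if placement[0] == "shard":
            dim = placement[1]
            if shape[dim] % mesh_size != 0:
                raise ValueError("this sketch only handles even splits")
            shape[dim] //= mesh_size
        # "replicate" and "partial" leave the local shape unchanged.
    return shape


print(local_shape([4, 6], [2], [("shard", 0)]))  # [2, 6]
print(local_shape([4, 6], [2, 3], [("shard", 0), ("shard", 1)]))  # [2, 2]
print(local_shape([1], [2], [("replicate",)]))  # [1]
```

For a 1-D mesh of two processes with a single Replicate placement, each process holds the full tensor.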

Returns

Tensor: The tensor returned by fn, distributed over mesh according to placements.

Examples

>>> import paddle
>>> import paddle.distributed as dist
>>> # Create a 1-D process mesh of two processes
>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
>>> # Create a tensor of ones with shape [1], replicated on every process in mesh
>>> d_tensor = dist.dtensor_from_fn(paddle.ones, mesh, [dist.Replicate()], shape=[1])
>>> print(d_tensor)