shard_tensor

paddle.distributed.shard_tensor(data, mesh, placements, dtype=None, place=None, stop_gradient=None) [source]

Creates a distributed Tensor (i.e., Tensor with distributed attributes or DistTensor for short) from the input data, which can be a scalar, tuple, list, numpy.ndarray, or paddle.Tensor.

If the data is already a Tensor, it will be transformed into a distributed Tensor.

Parameters
  • data (scalar|tuple|list|ndarray|Tensor) – Initial data for the tensor. Can be a scalar, list, tuple, numpy.ndarray, or paddle.Tensor.

  • mesh (paddle.distributed.ProcessMesh) – The ProcessMesh object that describes the Cartesian topology of the processes used.

  • placements (list[paddle.distributed.Placement]) – The placements describe how the tensor is placed on the ProcessMesh; each placement can be Shard, Replicate, or Partial. See the sketch after this parameter list.

  • dtype (str|np.dtype, optional) – The desired data type of the returned tensor. Can be ‘bool’, ‘float16’, ‘float32’, ‘float64’, ‘int8’, ‘int16’, ‘int32’, ‘int64’, ‘uint8’, ‘complex64’, or ‘complex128’. Default: None. If None, the dtype is inferred from data, except for Python float numbers, in which case the dtype is inferred from get_default_dtype.

  • place (CPUPlace|CUDAPinnedPlace|CUDAPlace|str, optional) – The place to allocate the Tensor. Can be CPUPlace, CUDAPinnedPlace, or CUDAPlace. Default: None, which means the global place. If place is a string, it can be cpu, gpu:x, or gpu_pinned, where x is the index of the GPU.

  • stop_gradient (bool, optional) – Whether to block gradient propagation in Autograd. If stop_gradient is None, the returned Tensor’s stop_gradient is set to data.stop_gradient when data has a stop_gradient attribute, and to True otherwise. Default: None.
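
A minimal sketch of how placements map onto mesh dimensions (the 2 x 2 mesh and the 4-process launch, e.g. via python -m paddle.distributed.launch, are assumptions for illustration, not part of the API): Shard(0) splits the tensor’s dimension 0 across mesh dimension 'x', while Replicate() keeps a full copy along mesh dimension 'y'.

>>> import paddle
>>> import paddle.distributed as dist

>>> # assumed 2 x 2 mesh over processes 0-3
>>> mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=['x', 'y'])
>>> w = paddle.randn([4, 8])

>>> # shard dim 0 of w along mesh dim 'x', replicate along mesh dim 'y'
>>> d_w = dist.shard_tensor(w, mesh, [dist.Shard(0), dist.Replicate()])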

Returns

A Tensor constructed from data with distributed attributes.

Return type

Tensor

Examples

>>> import paddle
>>> import paddle.distributed as dist

>>> mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=['x', 'y'])

>>> # dense tensor
>>> a = paddle.to_tensor([[1,2,3],
...                       [5,6,7]])

>>> 
>>> # distributed tensor
>>> d_tensor = dist.shard_tensor(a, mesh, [dist.Shard(0), dist.Shard(1)])

>>> print(d_tensor)
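
A further sketch under the same mesh (and the assumption of a matching multi-process launch): a distributed Tensor can also be built directly from a Python list, with an explicit dtype and every mesh dimension replicated.

>>> # replicated distributed tensor built directly from a Python list
>>> r_tensor = dist.shard_tensor([[1., 2., 3.],
...                               [5., 6., 7.]],
...                              mesh,
...                              [dist.Replicate(), dist.Replicate()],
...                              dtype='float32')

>>> print(r_tensor)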