load_state_dict

paddle.distributed.load_state_dict(state_dict, path, process_group=None, coordinator_rank=0) → None [source]

Load the state_dict in place from a checkpoint path.

Parameters
  • state_dict (Dict[str, paddle.Tensor]) – The state_dict to load. It will be modified in place after loading.

  • path (str) – The directory to load checkpoint files.

  • process_group (paddle.distributed.collective.Group, optional) – ProcessGroup used for cross-rank synchronization. If None, the default process group, which contains all devices, is used. Default: None.

  • coordinator_rank (int, optional) – The rank that coordinates the checkpoint. Rank 0 is used by default. Default: 0.
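
The optional arguments can also be passed explicitly. The sketch below is illustrative only and is not part of the official example: it assumes two GPUs launched with paddle.distributed.launch, a checkpoint already written to ./checkpoint by dist.save_state_dict, and a sub-group created with dist.new_group that simply mirrors the default group.

>>> # A sketch (assumptions noted above): pass process_group and coordinator_rank explicitly.
>>> import paddle
>>> import paddle.distributed as dist
>>> dist.init_parallel_env()            # initialize the default communication group
>>> group = dist.new_group([0, 1])      # explicit group; passing None falls back to the default group
>>> mesh = dist.ProcessMesh([0, 1])
>>> # replicated zeros that load_state_dict will fill in place
>>> w = dist.shard_tensor(paddle.zeros([4, 8], dtype="int64"), mesh, [dist.Replicate()])
>>> dist.load_state_dict({"w1": w}, "./checkpoint", process_group=group, coordinator_rank=0)
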

Example

>>> # Run this example with two devices, e.g. `python -m paddle.distributed.launch --gpus=0,1`.
>>> import paddle
>>> import paddle.distributed as dist
>>> ckpt_path = "./checkpoint"
>>> w1 = paddle.arange(32).reshape([4, 8])
>>> mesh = dist.ProcessMesh([0, 1])
>>> sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0)])
>>> state_dict = {"w1": sharded_w1}
>>> dist.save_state_dict(state_dict, ckpt_path)
>>> w1_to_load = paddle.zeros_like(w1)
>>> sharded_w1_to_load = dist.shard_tensor(w1_to_load, mesh, [dist.Replicate()])
>>> state_dict_to_load = {"w1": sharded_w1_to_load}
>>> dist.load_state_dict(state_dict_to_load, ckpt_path)
>>> print(f"state_dict_to_load:{state_dict_to_load}")
state_dict_to_load:{'w1': Tensor(shape=[4, 8], dtype=int64, place=Place(gpu:0), stop_gradient=True, dist_attr={process_mesh: {shape: [2], process_ids: [0,1], dim_names: [d0]}, dims_mappings: [-1,-1], batch_dim: 0, dynamic_dims: [0,0], annotated: [dims_mapping: 1,process_mesh: 1], partial: [].}, GlobalDenseTensor=
[[0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ],
 [8 , 9 , 10, 11, 12, 13, 14, 15],
 [16, 17, 18, 19, 20, 21, 22, 23],
 [24, 25, 26, 27, 28, 29, 30, 31]])}
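
Beyond a hand-built state_dict as above, a common pattern is to load a checkpoint back into a layer's own state_dict and re-apply it. The sketch below is an illustration only: it assumes a checkpoint at the hypothetical path ./layer_checkpoint that was previously written with dist.save_state_dict from a layer of the same structure.

>>> # Sketch: restore a layer, assuming ./layer_checkpoint holds a matching checkpoint.
>>> import paddle
>>> import paddle.distributed as dist
>>> layer = paddle.nn.Linear(8, 8)          # stand-in model for illustration
>>> state_dict = layer.state_dict()         # its tensors are filled in place by load_state_dict
>>> dist.load_state_dict(state_dict, "./layer_checkpoint")
>>> layer.set_state_dict(state_dict)        # apply the loaded values back to the layer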