DistributedStrategy

class paddle.distributed.fleet.DistributedStrategy [source]
save_to_prototxt ( output )

save_to_prototxt

Serialize current DistributedStrategy to string and save to output file

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.dgc = True
>>> strategy.recompute = True
>>> strategy.recompute_configs = {"checkpoints": ["x"]}
>>> strategy.save_to_prototxt("dist_strategy.prototxt")
load_from_prototxt ( pb_file )

load_from_prototxt

Load from prototxt file for DistributedStrategy initialization

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.dgc = True
>>> strategy.recompute = True
>>> strategy.recompute_configs = {"checkpoints": ["x"]}
>>> strategy.save_to_prototxt("dist_strategy.prototxt")

>>> strategy.load_from_prototxt("dist_strategy.prototxt")
property execution_strategy

Configure ExecutionStrategy for DistributedStrategy

Examples

>>> import paddle
>>> exe_strategy = paddle.static.ExecutionStrategy()
>>> exe_strategy.num_threads = 10
>>> exe_strategy.num_iteration_per_drop_scope = 10
>>> exe_strategy.num_iteration_per_run = 10

>>> strategy = paddle.distributed.fleet.DistributedStrategy()
>>> strategy.execution_strategy = exe_strategy
property build_strategy

Configure BuildStrategy for DistributedStrategy. Note that a property of BuildStrategy takes effect in DistributedStrategy only if it is a non-distributed strategy.

Examples

>>> import paddle
>>> build_strategy = paddle.static.BuildStrategy()
>>> build_strategy.enable_sequential_execution = True
>>> build_strategy.fuse_elewise_add_act_ops = True
>>> build_strategy.fuse_bn_act_ops = True
>>> build_strategy.enable_auto_fusion = True
>>> build_strategy.fuse_relu_depthwise_conv = True
>>> build_strategy.fuse_broadcast_ops = True
>>> build_strategy.fuse_all_optimizer_ops = True
>>> build_strategy.enable_inplace = True

>>> strategy = paddle.distributed.fleet.DistributedStrategy()
>>> strategy.build_strategy = build_strategy
property gradient_scale_configs

Set the strategy of gradient scale

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.gradient_scale_configs = {'scale_strategy': 'avg'}

Note that scale_strategy must be one of ‘avg’, ‘sum’ or ‘customized’.

property a_sync

Indicating whether we are using asynchronous stochastic gradient descent updates for training. This property is valid when we are using parameter server training, which is implied by setting an appropriate RoleMaker. Default value: True

Examples

>>> import paddle.distributed.fleet as fleet
>>> role_maker = fleet.PaddleCloudRoleMaker()
>>> fleet.init(role_maker)

>>> strategy = fleet.DistributedStrategy()
>>> strategy.a_sync = True  # by default this is True

>>> # code block for defining loss and local optimizer
>>> # sgd = fleet.distributed_optimizer(optimizer, strategy)
property a_sync_configs

Set a_sync update configurations. In general, asynchronous parameter server training has several configurable settings that can be configured through a dict.

Notes:

k_steps(int): number of local optimization updates before communication

max_merge_var_num(int): maximum number of merged gradients before communication

send_queue_size(int): a buffer size of worker communication

independent_recv_thread(bool): if we are using independent recv thread for communication

thread_pool_size(int): number of threads in the thread pool

send_wait_times(int): waiting time for sending gradients

runtime_split_send_recv(bool): if we are using Tensor split for send and recv during runtime

Examples

>>> import paddle.distributed.fleet as fleet
>>> role_maker = fleet.PaddleCloudRoleMaker()
>>> fleet.init(role_maker)

>>> strategy = fleet.DistributedStrategy()
>>> strategy.a_sync = True  # by default this is True
>>> configs = {"k_steps": 1024, "send_queue_size": 32}
>>> strategy.a_sync_configs = configs

>>> # code block for defining loss and local optimizer
>>> # sgd = fleet.distributed_optimizer(optimizer, strategy)
property adam_d2sum

Set adam_d2sum. Default value: False

Examples

>>> import paddle.distributed.fleet as fleet
>>> role_maker = fleet.PaddleCloudRoleMaker()
>>> fleet.init(role_maker)

>>> strategy = fleet.DistributedStrategy()
>>> strategy.adam_d2sum = True  # by default this is False

>>> # code block for defining loss and local optimizer
>>> # sgd = fleet.distributed_optimizer(optimizer, strategy)
property trainer_desc_configs

Set trainer desc configurations.

Notes:

dump_fields_path(str): the path of dump fields

dump_fields(list(str)): the fields that you want to dump

dump_param(list(str)): the param that you want to dump

stat_var_names(list(str)):

Examples

>>> import paddle.distributed.fleet as fleet
>>> role_maker = fleet.PaddleCloudRoleMaker()
>>> fleet.init(role_maker)

>>> strategy = fleet.DistributedStrategy()
>>> configs = {"dump_fields_path": "./dump_data", "dump_fields": ["xxx", "yyy"]}
>>> strategy.trainer_desc_configs = configs

>>> # code block for defining loss and local optimizer
>>> # sgd = fleet.distributed_optimizer(optimizer, strategy)
property fs_client_param

Set fs client configurations.

Note

uri(str): the uri of fs client

user(str): the user_name of fs client

passwd(str): the passwd of fs client

hadoop_bin(str):

Examples

>>> import paddle.distributed.fleet as fleet
>>> role_maker = fleet.PaddleCloudRoleMaker()
>>> fleet.init(role_maker)
>>> strategy = fleet.DistributedStrategy()
>>> configs = {"uri": "xxx", "user": "xxx", "passwd": "xxx"}
>>> strategy.fs_client_param = configs
>>> # code block for defining loss and local optimizer
>>> # sgd = fleet.distributed_optimizer(optimizer, strategy)
property amp

Indicating whether we are using automatic mixed precision training. Default Value: False

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.amp = True # by default this is false
property amp_configs

Set automatic mixed precision training configurations. In general, amp has several configurable settings that can be configured through a dict.

Notes:

init_loss_scaling(float): The initial loss scaling factor. Default 32768.

use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling. Default True.

incr_every_n_steps(int): Increases loss scaling every n consecutive steps with finite gradients. Default 1000.

decr_every_n_nan_or_inf(int): Decreases loss scaling every n accumulated steps with nan or inf gradients. Default 2.

incr_ratio(float): The multiplier to use when increasing the loss scaling. Default 2.0.

decr_ratio(float): The less-than-one-multiplier to use when decreasing the loss scaling. Default 0.5.

custom_white_list(list[str]): Users’ custom white list of ops that are always executed in fp16.

custom_black_list(list[str]): Users’ custom black list of ops that are never executed in fp16.

custom_black_varnames(list[str]): Users’ custom black list of variable names.

use_pure_fp16(bool): Whether to use the pure fp16 training. Default False.

use_pure_bf16(bool): Whether to use the pure bf16 training. Default False.

use_fp16_guard(bool): Whether to use fp16_guard when constructing the program. Default True. Only takes effect when use_pure_fp16 is turned on.

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.amp = True
>>> strategy.amp_configs = {
...     "init_loss_scaling": 32768,
...     "custom_white_list": ['conv2d']
... }
>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.amp = True
>>> # pure fp16
>>> strategy.amp_configs = {
...     "init_loss_scaling": 32768,
...     "use_pure_fp16": True
... }
property asp

Indicating whether we are using automatic sparsity training. Default Value: False

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.asp = True # by default this is false
property sync_nccl_allreduce

Indicating whether we are using synchronized allreduce in each communication thread. We note that system overhead is usually lower when sync_nccl_allreduce = True.

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.sync_nccl_allreduce = True
property use_hierarchical_allreduce

Indicating whether we are using hierarchical allreduce in collective communication. Hierarchical allreduce typically performs allreduce within each node group first, and then performs allreduce among the leaders of each group.

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.use_hierarchical_allreduce = True
property hierarchical_allreduce_inter_nranks

Number of ranks for low-level node groups in hierarchical allreduce. Default value: number of GPU cards on each single GPU machine

Example

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.hierarchical_allreduce_inter_nranks = 8
property sync_batch_norm

Indicating whether we are using sync_batch_norm to do synchronous batch normalization among all training nodes.

Default value: False

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.sync_batch_norm = True
property fuse_all_reduce_ops

Indicating whether we are using fuse_all_reduce_ops for gradient fusion during the backward phase of training. Default value: True

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.fuse_all_reduce_ops = False
property fuse_grad_size_in_MB

Specifying the size of gradient to fuse in Mega-Bytes

Default value: 32

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.fuse_grad_size_in_MB = 50
property last_comm_group_size_MB

Specifying the size of gradient to fuse in Mega-Bytes when the last group of each batch communicates. Making the last group small is useful to improve performance.

Default value: 1

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.last_comm_group_size_MB = 2
property find_unused_parameters

Indicating whether we are using find_unused_parameters to find unused parameters in DataParallel.

Default value: False

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.find_unused_parameters = True
property nccl_comm_num

Specifying the number of NCCL communicators

Default value: 1

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.nccl_comm_num = 2
property recompute [source]

Indicating whether we are using forward recomputation for memory optimization. Default value: False

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.recompute = True
>>> # suppose x and y are names of checkpoint tensors for recomputation
>>> strategy.recompute_configs = {"checkpoints": ["x", "y"]}
property recompute_configs

Set recompute configurations.

Note: checkpoints(list): list of string names of checkpoints. In general, the current implementation of the recompute strategy requires some manually assigned checkpoints.

enable_offload(bool): enable the recompute checkpoint offload feature. This feature offloads the checkpoints to host memory to allow an even larger batch size. Since the memcpy from host to device takes time, it is a trade-off between larger batch size and training speed.

checkpoint_shape(list): list of ints that specifies the shape of the checkpoints. So far, recompute-offload requires all checkpoints to have the same shape, and every dimension specified here must be determined (“-1” is not allowed).

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.recompute = True
>>> strategy.recompute_configs = {
...     "checkpoints": ["x", "y"],
...     "enable_offload": True,
...     "checkpoint_shape": [100, 512, 1024]
... }
property sharding

Indicating whether we are using the sharding Optimizer for memory optimization. We implement the sharding optimizer following the ZeRO-DP idea from [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://arxiv.org/abs/1910.02054). Model parameters and Optimizer State are sharded into different ranks, allowing larger models to fit.

In Hybrid parallelism scenario, we use sharding config as uniform API to set each parallelism.

Default value: False

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.sharding = True
property sharding_configs

Set sharding configurations.

Note:

sharding_segment_strategy(string, optional): strategy used to segment the program (forward & backward operations). Two strategies are available: “segment_broadcast_MB” and “segment_anchors”. Segment is a concept used in sharding to overlap computation and communication. Default is segment_broadcast_MB.

segment_broadcast_MB(float, optional): segment by the parameter broadcast volume. Sharding introduces parameter broadcast operations into the program, and after every segment_broadcast_MB of parameters has been broadcast, the program is cut into one segment. This configuration affects the communication speed in sharding training, and should be an empirical value decided by your model size and network topology. Only enabled when sharding_segment_strategy = segment_broadcast_MB. Default is 32.0.

segment_anchors(list): list of anchors used to segment the program, which allows finer control of program segmentation. This strategy is experimental for now. Only enabled when sharding_segment_strategy = segment_anchors.

sharding_degree(int, optional): specifies the number of GPUs within each sharding parallelism group; sharding will be turned off if sharding_degree=1. Default is 8.

gradient_merge_acc_step(int, optional): specifies the accumulation steps in gradient merge; gradient merge will be turned off if gradient_merge_acc_step=1. Default is 1.

optimize_offload(bool, optional): enable the optimizer offload, which offloads the moment vars to Host memory in order to save GPU memory for fitting larger models. The moment vars will be prefetched from and offloaded to Host memory during the update stage. It is a strategy that trades off between training speed and GPU memory, and is recommended to be turned on only when gradient_merge_acc_step is large, so that the number of update stages is relatively small compared with the number of forward & backward passes. Default is False.

dp_degree(int, optional): specifies the number of data parallelism groups; when dp_degree >= 2, it introduces dp_degree-way data parallelism as the outer parallelism for the inner parallelism. The user is responsible for ensuring global_world_size = mp_degree * sharding_degree * pp_degree * dp_degree. Default is 1.

mp_degree(int, optional): [Hybrid parallelism ONLY] specifies the number of GPUs within each megatron parallelism group; megatron parallelism will be turned off if mp_degree=1. Default is 1.

pp_degree(int, optional): [Hybrid parallelism ONLY] specifies the number of GPUs within each pipeline parallelism group; pipeline parallelism will be turned off if pp_degree=1. Default is 1.

pp_allreduce_in_optimize(bool, optional): [Hybrid parallelism ONLY] move the allreduce operations from the backward stage to the update (optimize) stage when pipeline parallelism is on. This configuration affects the communication speed of Hybrid parallelism training, depending on network topology. This strategy is experimental for now. Default is False.

optimize_cast(bool, optional): [Hybrid parallelism ONLY] Move the AMP cast op, which casts fp32 params to fp16 params, into the optimizer. optimize_cast persists the fp16 params, so it takes more memory but runs faster, trading space for time. Recommended to be turned on only when using pipeline or when gradient_merge_acc_step is large.

Examples

>>> # sharding-DP, 2 nodes with 8 gpus per node
>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.sharding = True
>>> strategy.sharding_configs = {
...     "sharding_segment_strategy": "segment_broadcast_MB",
...     "segment_broadcast_MB": 32,
...     "sharding_degree": 8,
...     "dp_degree": 2,
...     "gradient_merge_acc_step": 4,
... }
property without_graph_optimization

Run the program using Executor rather than ParallelExecutor.

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.without_graph_optimization = True
property fuse_grad_merge

Set whether to fuse the gradients for gradient merge. Note: this flag only affects gradient merge under pipeline mode. The default value for fuse_grad_merge is False.

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.fuse_grad_merge = True
property fuse_grad_size_in_num

This is used by the raw_program_optimizer program and specifies the number of gradients fused into each fused allreduce op.

Examples

>>> import paddle.distributed.fleet as fleet

>>> strategy = fleet.DistributedStrategy()
>>> strategy.fuse_grad_size_in_num = 2
property pipeline

Indicating whether we are using pipeline parallelism for distributed training. The current implementation mainly focuses on pipeline parallelism within a single GPU machine and data parallelism across GPU machines. The pipeline information is indicated through device_guard information in the user-defined program.

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.pipeline = True
property pipeline_configs

Set pipeline parallelism configurations. In pipeline parallelism, different parts of neural networks run on different GPUs. There are Tensor queue buffers between each pair of neighboring GPUs that are responsible for synchronizing hidden Tensor results between GPUs. Pipeline parallelism consists of several producer-consumer style hardware pairs, such as GPU-GPU, CPU-GPU, GPU-XPU. The best way to speed up pipeline parallelism is to make the size of Tensors in the Tensor queue smaller, so that we will have a faster producer for downstream consumers.

Notes:

Detailed arguments for pipeline_configs

micro_batch_size: the number of small batches in each user defined batch

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.pipeline = True
>>> strategy.pipeline_configs = {"micro_batch_size": 12}
property tensor_parallel

Indicating whether we are using tensor parallel for distributed training.

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.tensor_parallel = True
property tensor_parallel_configs

Set tensor_parallel configurations.

Notes:

Detailed arguments for tensor_parallel_configs

tensor_parallel_degree: degree of tensor parallel

tensor_init_seed: parameter initialization random seed

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.tensor_parallel = True
>>> strategy.tensor_parallel_configs = {"tensor_parallel_degree": 4,
...                                     "tensor_init_seed": 123}
property hybrid_configs

Dynamic graph hybrid parallel strategy configuration. Five-way hybrid parallelism needs to satisfy the following relationship:

total_number_GPUs = dp_degree * mp_degree * pp_degree * sharding_degree * sep_degree

Note:
dp_degree(int): set number of GPUs in a data parallel group. Default -1.

This value should be an integer greater than 0. If it is not set, or set to -1, its value will be inferred based on the total number of cards.

mp_degree(int): set number of GPUs in a model parallel group. Default 1

pp_degree(int): set number of GPUs in a pipeline parallel group. Default 1

sep_degree(int): set number of GPUs in a sep parallel group. Default 1

sharding_degree(int): set number of GPUs in a sharding parallel group. Default 1

order(list(string)): set hybrid parallel dimensions, the order is from outside to inside. Default [‘dp’,’pp’,’sharding’,’sep’, ‘mp’]

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.hybrid_configs = {
...     "dp_degree": 1,
...     "mp_degree": 2,
...     "pp_degree": 1,
...     "order":['dp','pp','sharding', 'sep', 'mp']
... }
property localsgd

Indicating whether we are using Local SGD training. Default Value: False. For more details, please refer to Don’t Use Large Mini-Batches, Use Local SGD.

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.localsgd = True # by default this is false

property localsgd_configs

Set LocalSGD training configurations. LocalSGD has a configurable setting that can be configured through a dict.

Notes:

k_steps(int): The local steps for training before parameter synchronization. Default 1.

begin_step(int): The step at which training with localsgd begins. Default 1.

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.localsgd = True
>>> strategy.localsgd_configs = {"k_steps": 4,
...                             "begin_step": 30}
property adaptive_localsgd

Indicating whether we are using Adaptive Local SGD training. Default Value: False. For more details, please refer to Adaptive Communication Strategies to Achieve the Best Error-Runtime Trade-off in Local-Update SGD.

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.adaptive_localsgd = True # by default this is false

property adaptive_localsgd_configs

Set AdaptiveLocalSGD training configurations. AdaptiveLocalSGD has a configurable setting that can be configured through a dict.

Notes:
init_k_steps(int): The initial steps for training before adaptive localsgd.

Then, the adaptive localsgd method will modify init_k_steps automatically. Default 1.

begin_step(int): The step at which training with adaptive localsgd begins. Default 1.

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.adaptive_localsgd = True
>>> strategy.adaptive_localsgd_configs = {"init_k_steps": 1,
...                                       "begin_step": 30}
property dgc

Indicating whether we are using Deep Gradient Compression training. For more details, please refer to [Deep Gradient Compression](https://arxiv.org/abs/1712.01887).

Default Value: False

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.dgc = True # by default this is false
property dgc_configs

Set Deep Gradient Compression training configurations. In general, dgc has several configurable settings that can be configured through a dict.

Notes:

rampup_begin_step(int): The beginning step from which gradient compression is implemented. Default 0.

rampup_step(int): Time steps used in sparsity warm-up periods. Default is 1.

For example, if the sparsity is [0.75, 0.9375, 0.984375, 0.996, 0.999] and the rampup_step is 100, it will use 0.75 at steps 0~19, 0.9375 at steps 20~39, and so on. Once the end of the sparsity array is reached, it will use 0.999 from then on.

sparsity(list[float]): Keep only the most important elements of the gradient tensor; the ratio kept is (1 - sparsity).

Default is [0.999]. For example, if the sparsity is [0.99, 0.999], the top [1%, 0.1%] most important elements will be transmitted.

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.dgc = True
>>> strategy.dgc_configs = {"rampup_begin_step": 1252}
property fp16_allreduce

Indicating whether we are using fp16 gradient allreduce training. Default Value: False

Examples

>>> import paddle.distributed.fleet as fleet

>>> strategy = fleet.DistributedStrategy()
>>> strategy.fp16_allreduce = True # by default this is false
property gradient_merge

Gradient Merge, also called Gradient Accumulation, is a strategy for large batch training. With this strategy, model parameters will not be updated until a user-defined number of steps has run. For each step, the forward network and the backward network will run to calculate the gradient of model parameters. Every k steps, the optimization network will run, applying a specific optimization method (such as SGD, Adam) to model parameters.

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.gradient_merge = True
>>> strategy.gradient_merge_configs = {"k_steps": 4, "avg": True}
property gradient_merge_configs

The key-value configs of the gradient merge strategy.

Note:

k_steps(int): the update period of the parameters.

avg(bool): whether to average the gradients of each mini-batch, the default value is True

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.gradient_merge = True
>>> strategy.gradient_merge_configs = {"k_steps": 4, "avg": True}
property lars

Set lars configurations. lars is used to deal with the convergence problems when the global batch size is larger than 8k. For more details, please refer to [Large Batch Training of Convolutional Networks](https://arxiv.org/abs/1708.03888).

Default Value: False

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.lars = True # by default this is false
property lars_configs

Set Lars training configurations.

Notes: lars_coeff (float): trust ratio in lars formula. lars_weight_decay (float): weight decay coefficient in lars formula. epsilon (float): argument used to avoid potential division-by-zero when computing the local lr. exclude_from_weight_decay ([string]): a list of name strings of layers which will be excluded from weight decay in lars formula.

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.lars = True
>>> strategy.lars_configs = {
...             "lars_coeff": 0.01,
...             "lars_weight_decay": 0.0005,
...             "epsilon": 0,
...             "exclude_from_weight_decay": ['batch_norm', '.b_0']
... }
property lamb

Set lamb configurations. lamb is used to deal with the convergence problems of large batch size training, especially for attention-related models like BERT. For more details, please refer to [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/abs/1904.00962).

Default Value: False

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.lamb = True # by default this is false
property lamb_configs

Set Lamb training configurations.

Notes: lamb_weight_decay (float): weight decay coefficient in lamb formula. exclude_from_weight_decay ([string]): a list of name strings of layers which will be excluded from weight decay in lamb formula.

Examples

>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.lamb = True
>>> strategy.lamb_configs = {
...         'lamb_weight_decay': 0.01,
...         'exclude_from_weight_decay': [],
... }
property elastic [source]

Indicating whether we want to do the current distributed training on clusters with elastic resources. Currently, this configuration is not valid.

property auto [source]

Indicating whether we are using auto-parallel configuration. This feature is currently experimental. Currently, auto-parallelism can be used only when a user does not set any other strategy configs except auto. For details, please refer to the following code example. Default Value: False

Examples

>>> import paddle
>>> paddle.enable_static()
>>> import paddle.distributed.fleet as fleet

>>> strategy = fleet.DistributedStrategy()
>>> strategy.auto = True
>>> # if other strategies are set at the same time, auto will not apply
>>> # strategy.amp = True

>>> optimizer = paddle.optimizer.SGD(learning_rate=0.01)
>>> optimizer = fleet.distributed_optimizer(optimizer, strategy)
property semi_auto

Indicating whether we are using semi-auto parallel function. This feature is currently experimental. Currently, auto-parallelism can be used only when a user does not set any other strategy configs except semi-auto. For details, please refer to the following code example. Default Value: False

Examples

>>> import paddle
>>> paddle.enable_static()
>>> import paddle.distributed.fleet as fleet

>>> strategy = fleet.DistributedStrategy()
>>> strategy.semi_auto = True
>>> # if other strategies are set at the same time, semi_auto will not apply
>>> # strategy.amp = True

>>> optimizer = paddle.optimizer.SGD(learning_rate=0.01)
>>> optimizer = fleet.distributed_optimizer(optimizer, strategy)
property auto_search

Indicating whether we are using auto-search parallel function. For details, please refer to the following code example. Default Value: False

Examples

>>> import paddle

>>> paddle.enable_static()
>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.auto_search = True
property split_data

Indicating whether we split the data. If True, we split the data. Default Value: True

Examples

>>> import paddle

>>> paddle.enable_static()
>>> import paddle.distributed.fleet as fleet
>>> strategy = fleet.DistributedStrategy()
>>> strategy.split_data = True
property qat

Indicating whether we are using quantization training. Default Value: False

property qat_configs

Set quantization training configurations. In general, qat has several configurable settings that can be configured through a dict.

Notes:

channel_wise_abs_max(bool): Whether to use per_channel quantization training. Default is True.

weight_bits(int): quantization bit number for weight. Default is 8.

activation_bits(int): quantization bit number for activation. Default is 8.

not_quant_pattern(list[str]): When the skip pattern is detected in an op’s name scope, the corresponding op will not be quantized.

algo(str): Other quantization training algorithm.

Examples

>>> import paddle.distributed.fleet as fleet

>>> strategy = fleet.DistributedStrategy()
>>> strategy.qat = True
>>> strategy.qat_configs = {
...     "channel_wise_abs_max": True,
...     "weight_bits": 8,
...     "activation_bits": 8,
...     "not_quant_pattern": ['skip_quant']
... }
property heter_ccl_mode

Indicating whether we are using heter_ccl_mode for model training. This feature is currently an experimental feature. Currently, heter_ccl_mode can be used only for dataparallel with dygraph mode. Default Value: False

Examples

>>> import paddle
>>> import paddle.distributed.fleet as fleet

>>> strategy = fleet.DistributedStrategy()
>>> strategy.heter_ccl_mode = True

>>> # for initialize parallel env, only need to call
>>> paddle.distributed.init_parallel_env()
>>> # then the heterogeneous context will be created.
property cudnn_exhaustive_search

Indicating whether to use the exhaustive search method to choose convolution algorithms. Exhaustive search attempts all cuDNN algorithms to choose the fastest one. This method is time-consuming; the chosen algorithm will be cached for the given layer specifications. Once the layer specifications (like batch size, feature map size) are changed, it will search again. Default Value: True

Examples

>>> import paddle
>>> paddle.enable_static()
>>> import paddle.distributed.fleet as fleet

>>> strategy = fleet.DistributedStrategy()
>>> strategy.cudnn_exhaustive_search = False

>>> optimizer = paddle.optimizer.SGD(learning_rate=0.01)
>>> optimizer = fleet.distributed_optimizer(optimizer, strategy)
property conv_workspace_size_limit

The workspace limit size, in MB, for choosing cuDNN convolution algorithms. Internally, cuDNN obtains the fastest suitable algorithm that fits within this memory limit. Usually, a large workspace size may lead to choosing faster algorithms, but it significantly increases the memory workspace. Users need to trade off between memory and speed. Default Value: 4000

Examples

>>> import paddle
>>> paddle.enable_static()
>>> import paddle.distributed.fleet as fleet

>>> strategy = fleet.DistributedStrategy()
>>> strategy.conv_workspace_size_limit = 1024

>>> optimizer = paddle.optimizer.SGD(learning_rate=0.01)
>>> optimizer = fleet.distributed_optimizer(optimizer, strategy)
property cudnn_batchnorm_spatial_persistent

Indicates whether to use the mode CUDNN_BATCHNORM_SPATIAL_PERSISTENT function in batchnorm. This is only useful in cudnn. Default Value: True

Examples

>>> import paddle
>>> paddle.enable_static()
>>> import paddle.distributed.fleet as fleet

>>> strategy = fleet.DistributedStrategy()
>>> strategy.cudnn_batchnorm_spatial_persistent = True

>>> optimizer = paddle.optimizer.SGD(learning_rate=0.01)
>>> optimizer = fleet.distributed_optimizer(optimizer, strategy)

Used in the guide/tutorials