AdamW

class paddle.optimizer.AdamW ( learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-08, parameters=None, weight_decay=0.01, lr_ratio=None, apply_decay_param_fun=None, grad_clip=None, lazy_mode=False, multi_precision=False, name=None )

The AdamW optimizer implements the AdamW algorithm from the paper DECOUPLED WEIGHT DECAY REGULARIZATION. It resolves the problem of L2 regularization failure in the Adam optimizer.

\[ \begin{align}\begin{aligned}t & = t + 1\\moment\_1\_out & = {\beta}_1 * moment\_1 + (1 - {\beta}_1) * grad\\moment\_2\_out & = {\beta}_2 * moment\_2 + (1 - {\beta}_2) * grad * grad\\learning\_rate & = learning\_rate * \frac{\sqrt{1 - {\beta}_2^t}}{1 - {\beta}_1^t}\\param\_out & = param - learning\_rate * (\frac{moment\_1}{\sqrt{moment\_2} + \epsilon} + \lambda * param)\end{aligned}\end{align} \]
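The update rule above can be traced by hand for a single scalar parameter. The following is a minimal sketch in plain Python, not Paddle's implementation: `adamw_step` is a hypothetical helper, and it assumes that the freshly updated moments (`moment_1_out`, `moment_2_out`) are the ones used in the parameter update.

```python
import math


def adamw_step(param, grad, moment_1, moment_2, t,
               learning_rate=0.001, beta1=0.9, beta2=0.999,
               epsilon=1e-8, weight_decay=0.01):
    """Apply one AdamW update to a scalar parameter (illustrative only)."""
    t = t + 1
    # Exponential moving averages of the gradient and the squared gradient.
    moment_1 = beta1 * moment_1 + (1 - beta1) * grad
    moment_2 = beta2 * moment_2 + (1 - beta2) * grad * grad
    # Bias-corrected step size for the current step t.
    lr_t = learning_rate * math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
    # The decay term (weight_decay * param) sits next to the Adam term
    # instead of being folded into grad, as in the formula above.
    param = param - lr_t * (moment_1 / (math.sqrt(moment_2) + epsilon)
                            + weight_decay * param)
    return param, moment_1, moment_2, t


param, m1, m2, t = adamw_step(param=1.0, grad=0.5,
                              moment_1=0.0, moment_2=0.0, t=0)
print(param, m1, m2, t)
```

Because the decay term never passes through the moment estimates, it is not rescaled by the adaptive denominator, which is the point of decoupling weight decay from the gradient.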
Parameters
  • learning_rate (float|LRScheduler, optional) – The learning rate used to update Parameter. It can be a float value or a LRScheduler. The default value is 0.001.

  • parameters (list|tuple, optional) – List/Tuple of Tensor names to update to minimize loss. This parameter is required in dygraph mode. To specify different options (such as the learning rate or weight decay) for different parameter groups, pass a list of dict instead. Note that the learning_rate in a parameter group is a scale applied to the base learning_rate. The default value is None in static graph mode, in which case all parameters will be updated.

  • beta1 (float|Tensor, optional) – The exponential decay rate for the 1st moment estimates. It should be a float number or a Tensor with shape [1] and data type as float32. The default value is 0.9.

  • beta2 (float|Tensor, optional) – The exponential decay rate for the 2nd moment estimates. It should be a float number or a Tensor with shape [1] and data type as float32. The default value is 0.999.

  • epsilon (float, optional) – A small float value for numerical stability. The default value is 1e-08.

  • weight_decay (float|Tensor, optional) – The weight decay coefficient. It can be a float or a Tensor. The default value is 0.01.

  • lr_ratio (function|None, optional) – If it is not None, the learning rate will be updated with a layer-wise learning rate ratio. Otherwise, the original learning rate is used. Default: None.

  • apply_decay_param_fun (function|None, optional) – If it is not None, only tensors for which apply_decay_param_fun(Tensor.name)==True will be updated with weight decay. Use it to restrict weight decay to specific tensors. Default: None.

  • grad_clip (GradientClipBase, optional) – Gradient clipping strategy. It is an instance of a class derived from GradientClipBase. There are three clipping strategies (ClipGradByGlobalNorm, ClipGradByNorm, ClipGradByValue). Default is None, meaning there is no gradient clipping.

  • lazy_mode (bool, optional) – The official Adam algorithm has two moving-average accumulators, which are updated at every step. In both dense mode and sparse mode, every element of the two accumulators is updated. If the parameter is very large, the update may be very slow. Lazy mode updates only the elements that have a gradient in the current mini-batch, so it is much faster. However, this mode has different semantics from the original Adam algorithm and may lead to different results. The default value is False.

  • multi_precision (bool, optional) – Whether to use multi-precision during weight updating. Default is False.

  • name (str, optional) – Normally there is no need for user to set this property. For more information, please refer to Name. The default value is None.
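
A common use of apply_decay_param_fun is to skip weight decay for biases and normalization parameters. The sketch below defines such a predicate in plain Python. The function name `should_decay`, the sample tensor names, and `model` in the trailing comment are hypothetical; real tensor names depend on how the layers were created, so inspect the names of your own parameters before relying on a pattern like this.

```python
def should_decay(name):
    """Return True if the tensor with this name should receive weight decay.

    Assumes (hypothetically) that bias names contain ".b_" and that
    normalization layers have "norm" in their names.
    """
    return ".b_" not in name and "norm" not in name


# Hypothetical tensor names, used only to exercise the predicate.
sample_names = ["linear_0.w_0", "linear_0.b_0", "layer_norm_0.w_0"]
decayed = [n for n in sample_names if should_decay(n)]
print(decayed)

# The predicate would then be passed to the optimizer, for example:
# opt = paddle.optimizer.AdamW(learning_rate=0.001,
#                              parameters=model.parameters(),
#                              apply_decay_param_fun=should_decay)
```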

Notes

Currently, AdamW doesn’t support sparse parameter optimization.

Examples

>>> import paddle

>>> linear = paddle.nn.Linear(10, 10)
>>> inp = paddle.rand([10,10], dtype="float32")
>>> out = linear(inp)
>>> loss = paddle.mean(out)

>>> beta1 = paddle.to_tensor([0.9], dtype="float32")
>>> beta2 = paddle.to_tensor([0.99], dtype="float32")

>>> opt = paddle.optimizer.AdamW(learning_rate=0.1,
...         parameters=linear.parameters(),
...         beta1=beta1,
...         beta2=beta2,
...         weight_decay=0.01
... )
>>> loss.backward()
>>> opt.step()
>>> opt.clear_grad()


>>> # Note that the learning_rate of linear_2 is 0.01.
>>> linear_1 = paddle.nn.Linear(10, 10)
>>> linear_2 = paddle.nn.Linear(10, 10)
>>> inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
>>> out = linear_1(inp)
>>> out = linear_2(out)
>>> loss = paddle.mean(out)
>>> opt = paddle.optimizer.AdamW(
...     learning_rate=0.1,
...     parameters=[{
...         'params': linear_1.parameters()
...     }, {
...         'params': linear_2.parameters(),
...         'weight_decay': 0.001,
...         'learning_rate': 0.1,
...         'beta1': 0.8
...     }],
...     weight_decay=0.01,
...     beta1=0.9
... )
>>> loss.backward()
>>> opt.step()
>>> opt.clear_grad()
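
Because the learning_rate inside a parameter group is a scale on the base learning_rate, the effective rate of each group in the example above is a simple product. A plain-Python sketch of that arithmetic follows. The `effective_lr` helper is hypothetical, and the sketch assumes that a group without its own learning_rate uses a scale of 1.0.

```python
def effective_lr(base_lr, group):
    """Scale the base learning rate by the group's learning_rate, if any."""
    # Assumption: a group without "learning_rate" uses a scale of 1.0.
    return base_lr * group.get("learning_rate", 1.0)


base_lr = 0.1
group_1 = {}                      # stands in for the linear_1 group
group_2 = {"learning_rate": 0.1}  # stands in for the linear_2 group

lr_1 = effective_lr(base_lr, group_1)  # about 0.1
lr_2 = effective_lr(base_lr, group_2)  # about 0.01
print(lr_1, lr_2)
```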
append_regularization_ops ( parameters_and_grads, regularization=None )

Creates and adds backward regularization operators in the BlockDesc. This adds the gradients of the regularizer function to the gradients of the parameters and returns the modified gradients, which is equivalent to implementing weight decay in the optimizer.

Parameters
  • parameters_and_grads – A list of (parameters, gradients) pairs that need to be regularized.

  • regularization – A global regularizer. It is applied to any parameter that does not have its own regularizer set.

Returns

A list of (parameters, gradients) pairs with the regularized gradients.

Return type

list[(Variable, Variable)]

Raises

Exception – Unknown regularization type
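
To illustrate what adding the gradients of the regularizer function means, the sketch below assumes an L2 penalty of the form 0.5 * coeff * param**2, whose gradient is coeff * param. Scalars stand in for tensors, and `add_l2_gradient` is a hypothetical helper, not the Paddle function itself.

```python
def add_l2_gradient(parameters_and_grads, coeff):
    """Return (param, grad) pairs with the L2 penalty gradient added."""
    # d/dparam of 0.5 * coeff * param**2 is coeff * param.
    return [(param, grad + coeff * param)
            for param, grad in parameters_and_grads]


pairs = [(2.0, 0.5), (-1.0, 0.25)]
regularized = add_l2_gradient(pairs, coeff=0.1)
print(regularized)
```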

clear_grad ( set_to_zero=True )

Clear the gradients of all optimized parameters of the model.

If the gradients are not cleared, new gradients accumulate on top of the previous ones.

There are two ways to clear gradients: set them to zero or delete them.

Parameters

set_to_zero (bool, optional) – Whether to set the gradients to zero. If False, the gradients are deleted instead. Default is True.

Returns

None

Examples

>>> import paddle

>>> a = paddle.arange(26, dtype="float32").reshape([2, 13])
>>> linear = paddle.nn.Linear(13, 5)
>>> # This can be any optimizer supported by dygraph.
>>> adam = paddle.optimizer.Adam(learning_rate = 0.01,
...                             parameters = linear.parameters())
>>> out = linear(a)
>>> out.backward()
>>> adam.step()
>>> adam.clear_grad()
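
The effect of skipping clear_grad, and the difference between the two clearing modes, can be sketched without Paddle. `ToyParam` and `toy_clear_grad` below are hypothetical stand-ins for a parameter and for this method. Representing a deleted gradient as None is an assumption of the sketch.

```python
class ToyParam:
    """Minimal stand-in for a parameter that accumulates gradients."""

    def __init__(self):
        self.grad = None

    def backward(self, g):
        # New gradients are added to whatever is already stored.
        self.grad = g if self.grad is None else self.grad + g


def toy_clear_grad(params, set_to_zero=True):
    """Zero the stored gradients, or drop them when set_to_zero is False."""
    for p in params:
        if p.grad is not None:
            p.grad = 0.0 if set_to_zero else None


p = ToyParam()
p.backward(1.5)
p.backward(1.5)
accumulated = p.grad  # 3.0, because nothing was cleared in between

toy_clear_grad([p])
zeroed = p.grad       # 0.0

p.backward(1.5)
toy_clear_grad([p], set_to_zero=False)
deleted = p.grad      # None
```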
get_lr ( )

Get the current learning rate of the optimizer. If LRScheduler is not used, the return value is always the same. If LRScheduler is used, the return value is the current scheduled learning rate.

Returns

The current learning rate of optimizer.

Return type

float

Examples

>>> # train on default dynamic graph mode
>>> import paddle
>>> import numpy as np
>>> emb = paddle.nn.Embedding(10, 3)

>>> ## example1: LRScheduler is not used, the return value is always the same
>>> adam = paddle.optimizer.Adam(0.01, parameters = emb.parameters())
>>> for batch in range(10):
...     input = paddle.randint(low=0, high=5, shape=[5])
...     out = emb(input)
...     out.backward()
...     print("Learning rate of step{}: {}".format(batch, adam.get_lr())) # 0.01
...     adam.step()
Learning rate of step0: 0.01
Learning rate of step1: 0.01
Learning rate of step2: 0.01
Learning rate of step3: 0.01
Learning rate of step4: 0.01
Learning rate of step5: 0.01
Learning rate of step6: 0.01
Learning rate of step7: 0.01
Learning rate of step8: 0.01
Learning rate of step9: 0.01

>>> ## example2: StepDecay is used, return the scheduled learning rate
>>> scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=2, gamma=0.1)
>>> adam = paddle.optimizer.Adam(scheduler, parameters = emb.parameters())
>>> for batch in range(10):
...     input = paddle.randint(low=0, high=5, shape=[5])
...     out = emb(input)
...     out.backward()
...     print("Learning rate of step{}: {}".format(batch, adam.get_lr())) # 0.5->0.05...
...     adam.step()
...     scheduler.step()
Learning rate of step0: 0.5
Learning rate of step1: 0.5
Learning rate of step2: 0.05
Learning rate of step3: 0.05
Learning rate of step4: 0.005000000000000001
Learning rate of step5: 0.005000000000000001
Learning rate of step6: 0.0005000000000000001
Learning rate of step7: 0.0005000000000000001
Learning rate of step8: 5.000000000000001e-05
Learning rate of step9: 5.000000000000001e-05

>>> # train on static graph mode
>>> paddle.enable_static()
>>> main_prog = paddle.static.Program()
>>> start_prog = paddle.static.Program()
>>> with paddle.static.program_guard(main_prog, start_prog):
...     x = paddle.static.data(name='x', shape=[None, 10])
...     z = paddle.static.nn.fc(x, 100)
...     loss = paddle.mean(z)
...     scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=2, gamma=0.1)
...     adam = paddle.optimizer.Adam(learning_rate=scheduler)
...     adam.minimize(loss)

>>> exe = paddle.static.Executor()
>>> exe.run(start_prog)
>>> for batch in range(10):
...     print("Learning rate of step{}: {}".format(batch, adam.get_lr())) # 0.5->0.05->0.005...
...     out = exe.run(main_prog, feed={'x': np.random.randn(3, 10).astype('float32')})
...     scheduler.step()
Learning rate of step0: 0.5
Learning rate of step1: 0.5
Learning rate of step2: 0.05
Learning rate of step3: 0.05
Learning rate of step4: 0.005000000000000001
Learning rate of step5: 0.005000000000000001
Learning rate of step6: 0.0005000000000000001
Learning rate of step7: 0.0005000000000000001
Learning rate of step8: 5.000000000000001e-05
Learning rate of step9: 5.000000000000001e-05
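
The learning rates printed above are consistent with a closed form in which the rate is multiplied by gamma once every step_size scheduler steps. The sketch below reproduces that pattern in plain Python. `step_decay_lr` is a hypothetical helper written for this illustration, not part of Paddle.

```python
def step_decay_lr(base_lr, step_size, gamma, epoch):
    """Learning rate after `epoch` scheduler steps under step decay."""
    return base_lr * gamma ** (epoch // step_size)


schedule = [step_decay_lr(0.5, step_size=2, gamma=0.1, epoch=e)
            for e in range(10)]
for e, lr in enumerate(schedule):
    print("Learning rate of step{}: {}".format(e, lr))
```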
minimize ( loss, startup_program=None, parameters=None, no_grad_set=None )

Add operations to minimize loss by updating parameters.

Parameters
  • loss (Tensor) – A Tensor containing the value to minimize.

  • startup_program (Program, optional) – Program for initializing the parameters given in parameters. The default value is None, in which case default_startup_program will be used.

  • parameters (list, optional) – List of Tensor or Tensor.name to update to minimize loss. The default value is None, in which case all parameters will be updated.

  • no_grad_set (set, optional) – Set of Tensor or Tensor.name that don’t need to be updated. The default value is None.

Returns

tuple (optimize_ops, params_grads): a list of operators appended by minimize and a list of (param, grad) tensor pairs, where param is a Parameter and grad is the gradient value corresponding to that parameter. In static graph mode, the returned tuple can be passed to fetch_list in Executor.run() to indicate program pruning. If so, the program will be pruned by feed and fetch_list before it runs; see Executor for details.

Return type

tuple

Examples

>>> import paddle
>>> linear = paddle.nn.Linear(10, 10)
>>> input = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
>>> out = linear(input)
>>> loss = paddle.mean(out)

>>> beta1 = paddle.to_tensor([0.9], dtype="float32")
>>> beta2 = paddle.to_tensor([0.99], dtype="float32")

>>> adam = paddle.optimizer.Adam(learning_rate=0.1,
...         parameters=linear.parameters(),
...         weight_decay=0.01)
>>> loss.backward()
>>> adam.minimize(loss)
>>> adam.clear_grad()
set_lr ( value )

Api_attr

imperative

Set the value of the learning rate manually in the optimizer. If the optimizer uses LRScheduler, this API cannot be invoked, because it would lead to a conflict.

Parameters

value (float) – the value of learning rate

Returns

None

Examples

>>> import paddle
>>> linear = paddle.nn.Linear(10, 10)

>>> adam = paddle.optimizer.Adam(0.1, parameters=linear.parameters())

>>> # set learning rate manually by python float value
>>> lr_list = [0.2, 0.3, 0.4, 0.5, 0.6]
>>> for i in range(5):
...     adam.set_lr(lr_list[i])
...     lr = adam.get_lr()
...     print("current lr is {}".format(lr))
current lr is 0.2
current lr is 0.3
current lr is 0.4
current lr is 0.5
current lr is 0.6
set_lr_scheduler ( scheduler )

Api_attr

imperative

Set the LRScheduler of the learning rate manually in the optimizer. If the optimizer already uses an LRScheduler, this API replaces it with the new one.

Parameters

scheduler (LRScheduler) – the LRScheduler of learning rate

Returns

None

Examples

>>> import paddle
>>> linear = paddle.nn.Linear(10, 10)

>>> adam = paddle.optimizer.Adam(0.1, parameters=linear.parameters())

>>> # set learning rate manually by class LRScheduler
>>> scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2,4,6], gamma=0.8)
>>> adam.set_lr_scheduler(scheduler)
>>> lr = adam.get_lr()
>>> print("current lr is {}".format(lr))
current lr is 0.5

>>> # set learning rate manually by another LRScheduler
>>> scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.1, step_size=5, gamma=0.6)
>>> adam.set_lr_scheduler(scheduler)
>>> lr = adam.get_lr()
>>> print("current lr is {}".format(lr))
current lr is 0.1
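
For MultiStepDecay, the rate drops by a factor of gamma each time the scheduler passes a milestone. Below is a plain-Python sketch of that rule, assuming a milestone takes effect once the step count reaches it. `multi_step_decay_lr` is a hypothetical helper, not part of Paddle.

```python
from bisect import bisect_right


def multi_step_decay_lr(base_lr, milestones, gamma, epoch):
    """Learning rate after `epoch` scheduler steps (milestones must be sorted)."""
    # Number of milestones that are <= epoch.
    passed = bisect_right(milestones, epoch)
    return base_lr * gamma ** passed


lrs = [multi_step_decay_lr(0.5, [2, 4, 6], 0.8, e) for e in range(8)]
print(lrs)
```

At step 0 no milestone has been reached, so the sketch returns the base rate, in line with the "current lr is 0.5" output above.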
set_state_dict ( state_dict )

Load the optimizer state dict. For the Adam optimizer, it contains beta1, beta2, momentum, etc. If LRScheduler has been used, global_step will be changed.

Parameters

state_dict (dict) – Dict containing all the Tensors needed by the optimizer

Returns

None

Examples

>>> import paddle

>>> emb = paddle.nn.Embedding(10, 10)

>>> layer_state_dict = emb.state_dict()
>>> paddle.save(layer_state_dict, "emb.pdparams")

>>> scheduler = paddle.optimizer.lr.NoamDecay(
...     d_model=0.01, warmup_steps=100, verbose=True)
>>> adam = paddle.optimizer.Adam(
...     learning_rate=scheduler,
...     parameters=emb.parameters())
>>> opt_state_dict = adam.state_dict()
>>> paddle.save(opt_state_dict, "adam.pdopt")

>>> opti_state_dict = paddle.load("adam.pdopt")
>>> adam.set_state_dict(opti_state_dict)
state_dict ( )

Get state dict information from the optimizer. It contains all the Tensors used by the optimizer. For the Adam optimizer, it contains beta1, beta2, momentum, etc. If LRScheduler has been used, global_step will be included in the state dict. If the optimizer has never been called (via the minimize function), the state_dict is empty.

Parameters

None

Returns

dict containing all the Tensors used by the optimizer

Return type

state_dict(dict)

Examples

>>> import paddle
>>> emb = paddle.nn.Embedding(10, 10)

>>> adam = paddle.optimizer.Adam(0.001, parameters=emb.parameters())
>>> state_dict = adam.state_dict()
step ( )

Execute the optimizer and update parameters once.

Returns

None

Examples

>>> import paddle

>>> a = paddle.rand([2,13], dtype="float32")
>>> linear = paddle.nn.Linear(13, 5)
>>> # This can be any optimizer supported by dygraph.
>>> opt = paddle.optimizer.AdamW(learning_rate = 0.01,
...                             parameters = linear.parameters())
>>> out = linear(a)
>>> out.backward()
>>> opt.step()
>>> opt.clear_grad()