Rprop

class paddle.optimizer. Rprop ( learning_rate=0.001, learning_rate_range=(1e-5, 50), parameters=None, etas=(0.5, 1.2), grad_clip=None, name=None ) [源代码]

注解

此优化器仅适用于 full-batch 训练。

Rprop算法的优化器。有关详细信息,请参阅:

A direct adaptive method for faster backpropagation learning : The RPROP algorithm

\[\begin{split}\begin{aligned} &\hspace{0mm} For\ all\ weights\ and\ biases\{ \\ &\hspace{5mm} \textbf{if} \: (\frac{\partial E}{\partial w_{ij}}(t-1)*\frac{\partial E}{\partial w_{ij}}(t)> 0)\ \textbf{then} \: \{ \\ &\hspace{10mm} learning\_rate_{ij}(t)=\mathrm{minimum}(learning\_rate_{ij}(t-1)*\eta^{+},learning\_rate_{max}) \\ &\hspace{10mm} \Delta w_{ij}(t)=-sign(\frac{\partial E}{\partial w_{ij}}(t))*learning\_rate_{ij}(t) \\ &\hspace{10mm} w_{ij}(t+1)=w_{ij}(t)+\Delta w_{ij}(t) \\ &\hspace{5mm} \} \\ &\hspace{5mm} \textbf{else if} \: (\frac{\partial E}{\partial w_{ij}}(t-1)*\frac{\partial E}{\partial w_{ij}}(t)< 0)\ \textbf{then} \: \{ \\ &\hspace{10mm} learning\_rate_{ij}(t)=\mathrm{maximum}(learning\_rate_{ij}(t-1)*\eta^{-},learning\_rate_{min}) \\ &\hspace{10mm} w_{ij}(t+1)=w_{ij}(t) \\ &\hspace{10mm} \frac{\partial E}{\partial w_{ij}}(t)=0 \\ &\hspace{5mm} \} \\ &\hspace{5mm} \textbf{else if} \: (\frac{\partial E}{\partial w_{ij}}(t-1)*\frac{\partial E}{\partial w_{ij}}(t)= 0)\ \textbf{then} \: \{ \\ &\hspace{10mm} \Delta w_{ij}(t)=-sign(\frac{\partial E}{\partial w_{ij}}(t))*learning\_rate_{ij}(t) \\ &\hspace{10mm} w_{ij}(t+1)=w_{ij}(t)+\Delta w_{ij}(t) \\ &\hspace{5mm} \} \\ &\hspace{0mm} \} \\ \end{aligned}\end{split}\]

参数

  • learning_rate (float|_LRScheduleri,可选) - 初始学习率,用于参数更新的计算。可以是一个浮点型值或者一个_LRScheduler 类。默认值为 0.001。

  • learning_rate_range (tuple,可选) - 学习率的范围。学习率不能小于元组的第一个元素;学习率不能大于元组的第二个元素。默认值为 (1e-5, 50)。

  • parameters (list,可选) - 指定优化器需要优化的参数。在动态图模式下必须提供该参数;在静态图模式下默认值为 None,这时所有的参数都将被优化。

  • etas (tuple,可选) - 用于更新学习率的元组。元组的第一个元素是乘法递减因子;元组的第二个元素是乘法增加因子。默认值为 (0.5, 1.2)。

  • grad_clip (GradientClipBase,可选) – 梯度裁剪的策略,支持三种裁剪策略:paddle.nn.ClipGradByGlobalNormpaddle.nn.ClipGradByNormpaddle.nn.ClipGradByValue 。 默认值为 None,此时将不进行梯度裁剪。

  • name (str,可选) - 具体用法请参见 Name,一般无需设置,默认值为 None。

代码示例

>>> import paddle

>>> inp = paddle.uniform(min=-0.1, max=0.1, shape=[1, 100], dtype='float32')
>>> linear = paddle.nn.Linear(100, 10)
>>> inp = paddle.to_tensor(inp)
>>> out = linear(inp)
>>> loss = paddle.mean(out)
>>> rprop = paddle.optimizer.Rprop(learning_rate=0.001, learning_rate_range=(0.0001,0.1), parameters=linear.parameters(), etas=(0.5,1.2))
>>> out.backward()
>>> rprop.step()
>>> rprop.clear_grad()

方法

step()

注解

该 API 只在 Dygraph 模式下生效。

执行一次优化器并进行参数更新。

返回

无。

代码示例

>>> import paddle

>>> a = paddle.arange(26, dtype="float32").reshape([2, 13])
>>> linear = paddle.nn.Linear(13, 5)
>>> # This can be any optimizer supported by dygraph.
>>> adam = paddle.optimizer.Adam(learning_rate = 0.01,
...                         parameters = linear.parameters())
>>> out = linear(a)
>>> out.backward()
>>> adam.step()
>>> adam.clear_grad()

minimize(loss, startup_program=None, parameters=None, no_grad_set=None)

为网络添加反向计算过程,并根据反向计算所得的梯度,更新 parameters 中的 Parameters,最小化网络损失值 loss。

参数

  • loss (Tensor) - 需要最小化的损失值变量

  • startup_program (Program,可选) - 用于初始化 parameters 中参数的 Program,默认值为 None,此时将使用 default_startup_program

  • parameters (list,可选) - 待更新的 Parameter 或者 Parameter.name 组成的列表,默认值为 None,此时将更新所有的 Parameter。

  • no_grad_set (set,可选) - 不需要更新的 Parameter 或者 Parameter.name 组成的集合,默认值为 None。

返回

tuple(optimize_ops, params_grads),其中 optimize_ops 为参数优化 OP 列表;param_grads 为由(param, param_grad)组成的列表,其中 param 和 param_grad 分别为参数和参数的梯度。在静态图模式下,该返回值可以加入到 Executor.run() 接口的 fetch_list 参数中,若加入,则会重写 use_prune 参数为 True,并根据 feedfetch_list 进行剪枝,详见 Executor 的文档。

代码示例

>>> import paddle
>>> linear = paddle.nn.Linear(10, 10)
>>> input = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
>>> out = linear(input)
>>> loss = paddle.mean(out)

>>> beta1 = paddle.to_tensor([0.9], dtype="float32")
>>> beta2 = paddle.to_tensor([0.99], dtype="float32")

>>> adam = paddle.optimizer.Adam(learning_rate=0.1,
...         parameters=linear.parameters(),
...         weight_decay=0.01)
>>> loss.backward()
>>> adam.minimize(loss)
>>> adam.clear_grad()

clear_grad()

注解

该 API 只在 Dygraph 模式下生效。

清除需要优化的参数的梯度。

代码示例

>>> import paddle

>>> a = paddle.arange(26, dtype="float32").reshape([2, 13])
>>> linear = paddle.nn.Linear(13, 5)
>>> # This can be any optimizer supported by dygraph.
>>> adam = paddle.optimizer.Adam(learning_rate = 0.01,
...                             parameters = linear.parameters())
>>> out = linear(a)
>>> out.backward()
>>> adam.step()
>>> adam.clear_grad()

get_lr()

注解

该 API 只在 Dygraph 模式下生效。

获取当前步骤的学习率。当不使用_LRScheduler 时,每次调用的返回值都相同,否则返回当前步骤的学习率。

返回

float,当前步骤的学习率。

代码示例

>>> # train on default dynamic graph mode
>>> import paddle
>>> import numpy as np
>>> emb = paddle.nn.Embedding(10, 3)

>>> ## example1: LRScheduler is not used, return the same value is all the same
>>> adam = paddle.optimizer.Adam(0.01, parameters = emb.parameters())
>>> for batch in range(10):
...     input = paddle.randint(low=0, high=5, shape=[5])
...     out = emb(input)
...     out.backward()
...     print("Learning rate of step{}: {}".format(batch, adam.get_lr())) # 0.01
...     adam.step()
Learning rate of step0: 0.01
Learning rate of step1: 0.01
Learning rate of step2: 0.01
Learning rate of step3: 0.01
Learning rate of step4: 0.01
Learning rate of step5: 0.01
Learning rate of step6: 0.01
Learning rate of step7: 0.01
Learning rate of step8: 0.01
Learning rate of step9: 0.01

>>> ## example2: StepDecay is used, return the scheduled learning rate
>>> scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=2, gamma=0.1)
>>> adam = paddle.optimizer.Adam(scheduler, parameters = emb.parameters())
>>> for batch in range(10):
...     input = paddle.randint(low=0, high=5, shape=[5])
...     out = emb(input)
...     out.backward()
...     print("Learning rate of step{}: {}".format(batch, adam.get_lr())) # 0.5->0.05...
...     adam.step()
...     scheduler.step()
Learning rate of step0: 0.5
Learning rate of step1: 0.5
Learning rate of step2: 0.05
Learning rate of step3: 0.05
Learning rate of step4: 0.005000000000000001
Learning rate of step5: 0.005000000000000001
Learning rate of step6: 0.0005000000000000001
Learning rate of step7: 0.0005000000000000001
Learning rate of step8: 5.000000000000001e-05
Learning rate of step9: 5.000000000000001e-05

>>> # train on static graph mode
>>> paddle.enable_static()
>>> main_prog = paddle.static.Program()
>>> start_prog = paddle.static.Program()
>>> with paddle.static.program_guard(main_prog, start_prog):
...     x = paddle.static.data(name='x', shape=[None, 10])
...     z = paddle.static.nn.fc(x, 100)
...     loss = paddle.mean(z)
...     scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=2, gamma=0.1)
...     adam = paddle.optimizer.Adam(learning_rate=scheduler)
...     adam.minimize(loss)

>>> exe = paddle.static.Executor()
>>> exe.run(start_prog)
>>> for batch in range(10):
...     print("Learning rate of step{}: {}".format(batch, adam.get_lr())) # 0.5->0.05->0.005...
...     out = exe.run(main_prog, feed={'x': np.random.randn(3, 10).astype('float32')})
...     scheduler.step()
Learning rate of step0: 0.5
Learning rate of step1: 0.5
Learning rate of step2: 0.05
Learning rate of step3: 0.05
Learning rate of step4: 0.005000000000000001
Learning rate of step5: 0.005000000000000001
Learning rate of step6: 0.0005000000000000001
Learning rate of step7: 0.0005000000000000001
Learning rate of step8: 5.000000000000001e-05
Learning rate of step9: 5.000000000000001e-05