Pytorch中常用的优化器

1、SGD

  • 特点

    • 随机梯度下降(SGD):是一种基本的优化器,通过计算梯度并沿着梯度反方向更新参数,以最小化损失函数。
  • 代码实现

    1
    2
    3
    4
    import torch
    import torch.optim as optim

    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
  • 参数含义:

    • lr:学习率(learning rate),即参数更新的步长。
    • momentum:动量(momentum),用于加速收敛和平滑梯度更新的方向。

2、Momentum

  • 特点

    • 动量优化器(Momentum):通过累积之前的梯度信息,来决定参数更新的方向和大小,从而加速收敛。
  • 代码实现

    1
    2
    3
    4
    import torch
    import torch.optim as optim

    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
  • 参数含义

    • lr:学习率。
    • momentum:动量。

3、AdaGrad

  • 特点

    • AdaGrad:根据每个参数的历史梯度信息来更新参数,对于经常出现的参数梯度较大的参数,会降低学习率,从而使得模型更加稳定。
  • 代码实现

    1
    2
    3
    4
    import torch
    import torch.optim as optim

    optimizer = optim.Adagrad(model.parameters(), lr=0.01)
  • 参数含义

    • lr:学习率。
    • lr_decay:学习率衰减。
    • eps:数值稳定性。

4、RMSprop

  • 特点

    • RMSprop:通过对梯度平方的指数加权平均来更新参数,从而对不同的参数赋予不同的学习率,使其更加稳定。
  • 代码实现

    1
    2
    3
    4
    import torch
    import torch.optim as optim

    optimizer = optim.RMSprop(model.parameters(), lr=0.01, alpha=0.99, eps=1e-08)
  • 参数含义

    • lr:学习率。
    • alpha:衰减因子,用于控制历史梯度的权重。
    • eps:数值稳定性。

5、Adam

  • 特点

    • Adam:结合动量优化器和 RMSprop 的优点,通过计算梯度的均值和方差来更新参数,从而使得学习率更加自适应。
  • 代码实现

    1
    2
    3
    4
    import torch
    import torch.optim as optim

    optimizer = optim.Adam(model.parameters(), lr=0.01, betas=(0.9, 0.999), eps=1e-08)
  • 参数含义

    • lr:学习率。
    • betas:用于计算梯度的一阶矩估计和二阶矩估计的指数衰减率。
    • eps:数值稳定性。

6、Adadelta

  • 特点

    • Adadelta:通过对梯度平方的指数加权平均和参数更新的平方的指数加权平均来更新参数,从而避免了学习率的选择和调整。
  • 代码实现

    1
    2
    3
    4
    import torch
    import torch.optim as optim

    optimizer = optim.Adadelta(model.parameters(), lr=1.0, rho=0.9, eps=1e-06)
  • 参数含义

    • rho:衰减因子,用于控制历史梯度的权重。
    • eps:数值稳定性。

7、Nadam

  • 特点

    • Nadam:结合了 Adam 和 Nesterov 动量的优点,具有更好的性能和收敛速度。
  • 代码实现

    1
    2
    3
    4
    import torch
    import torch.optim as optim

    optimizer = optim.Nadam(model.parameters(), lr=0.01, betas=(0.9, 0.999), eps=1e-08)
  • 参数含义

    • lr:学习率。
    • betas:用于计算梯度的一阶矩估计和二阶矩估计的指数衰减率。
    • eps:数值稳定性。