Introduce Nesterov SGD for transformer (where warmup stage is used) (#229)

Summary:
Pull Request resolved: #229

- Add a NAGOptimizer implementation following fairseq (https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/optim/nag.py#L43), intended for lr schedules with a warmup stage; the resulting update rule is sketched below.
- This speeds up convergence considerably for Transformer LM models and yields better overall performance at the end of training.
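
For reference, the per-parameter update implemented in NAGOptimizer::step() (restated from the diff below, with p the parameter, g its gradient, v the velocity buffer, lr the current learning rate, oldLr the learning rate used at the previous step, mu the momentum, and wd the weight decay) is:

  p <- p * (1 - lr * wd)                 (only when wd != 0)
  v <- mu * (lr / oldLr) * v + lr * g
  p <- p - (lr * g + mu * v)

The lr / oldLr factor rescales the velocity whenever the learning rate changes between steps, which keeps the momentum buffer consistent under a warmup schedule; with a constant lr this reduces to plain Nesterov SGD.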

Reviewed By: vineelpratap

Differential Revision: D24700925

fbshipit-source-id: a0d3780583c7c3961b3bbab28ccbb280e2419c50
Tatiana Likhomanenko authored and facebook-github-bot committed Nov 4, 2020
1 parent 3fb7019 commit 36b242e
Showing 5 changed files with 139 additions and 0 deletions.
1 change: 1 addition & 0 deletions flashlight/fl/optim/CMakeLists.txt
@@ -10,6 +10,7 @@ set(
${CMAKE_CURRENT_LIST_DIR}/AdadeltaOptimizer.cpp
${CMAKE_CURRENT_LIST_DIR}/AdagradOptimizer.cpp
${CMAKE_CURRENT_LIST_DIR}/AMSgradOptimizer.cpp
${CMAKE_CURRENT_LIST_DIR}/NAGOptimizer.cpp
${CMAKE_CURRENT_LIST_DIR}/NovogradOptimizer.cpp
${CMAKE_CURRENT_LIST_DIR}/RMSPropOptimizer.cpp
${CMAKE_CURRENT_LIST_DIR}/SGDOptimizer.cpp
75 changes: 75 additions & 0 deletions flashlight/fl/optim/NAGOptimizer.cpp
@@ -0,0 +1,75 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "flashlight/fl/optim/NAGOptimizer.h"

#include <cmath>

using std::vector;

namespace fl {

NAGOptimizer::NAGOptimizer(
const vector<Variable>& parameters,
float learningRate,
float momentum /* = 0 */,
float weightDecay /* = 0 */)
: FirstOrderOptimizer(parameters, learningRate),
mu_(momentum),
wd_(weightDecay),
velocities_(),
oldLr_(learningRate) {
if (momentum <= 0) {
throw std::runtime_error(
"Invalid momentum for NAG optimizer, it should be > 0");
}
velocities_.reserve(parameters.size());
for (const auto& parameter : parameters_) {
velocities_.emplace_back(
af::constant(0, parameter.dims(), parameter.type()));
velocities_.back().eval();
}
}

void NAGOptimizer::step() {
// Rescale the accumulated velocity when the learning rate changed since the
// previous step (e.g. during warmup), as in the fairseq NAG implementation.
float correctedLr = lr_ / oldLr_;

for (size_t i = 0; i < parameters_.size(); i++) {
if (!parameters_[i].isGradAvailable()) {
continue;
}

af::array& grad = parameters_[i].grad().array();
af::array& data = parameters_[i].array();

if (wd_ != 0) {
// Weight decay term
data = data * (1 - lr_ * wd_);
}
af::array& velocity = velocities_[i];
// this velocity corresponds to fairseq velocity * -1
velocity = mu_ * velocity * correctedLr + lr_ * grad;
af::eval(velocity);
// Nesterov update: gradient step plus a look-ahead along the updated velocity
grad = grad * lr_ + velocity * mu_;
data = data - grad;
af::eval(data);
}
oldLr_ = lr_;
}

std::string NAGOptimizer::prettyString() const {
std::ostringstream ss;
ss << "NAG (lr=" << lr_ << " ); (previous lr=" << oldLr_ << ");";

if (wd_ != 0) {
ss << " (weight decay=" << wd_ << ");";
}
ss << " (Nesterov momentum=" << mu_ << ")";
return ss.str();
}

} // namespace fl
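
A minimal usage sketch (not part of this commit; the toy objective, the linear warmup via setLr(), and the umbrella headers below are assumptions about the surrounding flashlight API):

#include <algorithm>
#include <iostream>

#include <arrayfire.h>

#include "flashlight/fl/autograd/autograd.h"
#include "flashlight/fl/optim/optim.h"

int main() {
  af::info();

  // A single parameter, as in OptimBenchmark.cpp.
  auto w = fl::Variable(af::randn(1, 10), /* calcGrad = */ true);

  // Start with a small lr; momentum must be > 0 (defaults to 0.99).
  fl::NAGOptimizer opt({w}, /* learningRate = */ 1e-4, /* momentum = */ 0.99);

  for (int iter = 1; iter <= 100; ++iter) {
    // Linear warmup to 1e-2; NAGOptimizer rescales its velocity by
    // lr / oldLr on the next step(), keeping the momentum buffer consistent.
    opt.setLr(std::min(1e-2, 1e-4 * iter));

    auto loss = fl::sum(w * w, /* axes = */ {0, 1}); // toy quadratic objective
    opt.zeroGrad();
    loss.backward();
    opt.step();
  }

  std::cout << opt.prettyString() << std::endl;
  return 0;
}
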
55 changes: 55 additions & 0 deletions flashlight/fl/optim/NAGOptimizer.h
@@ -0,0 +1,55 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include "flashlight/fl/optim/Optimizers.h"

namespace fl {

/** Nesterov Accelerated Gradient (NAG) with a modification that supports a
* learning rate which changes over time (e.g. a warmup schedule). Implements
* the version from
* https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/optim/nag.py#L43
*/
class NAGOptimizer : public FirstOrderOptimizer {
private:
FL_SAVE_LOAD_WITH_BASE(
FirstOrderOptimizer,
mu_,
wd_,
velocities_,
oldLr_)

NAGOptimizer() = default; // Intentionally private

float mu_;
float wd_;
std::vector<af::array> velocities_;
float oldLr_;

public:
/** NAGOptimizer constructor.
* @param parameters The parameters from e.g. `model.parameters()`
* @param learningRate The learning rate.
* @param momentum The Nesterov momentum coefficient; must be > 0.
* @param weightDecay The amount of L2 weight decay to use for all the
* parameters.
*/
NAGOptimizer(
const std::vector<Variable>& parameters,
float learningRate,
float momentum = 0.99,
float weightDecay = 0);

void step() override;

std::string prettyString() const override;
};

} // namespace fl

CEREAL_REGISTER_TYPE(fl::NAGOptimizer)
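
The FL_SAVE_LOAD_WITH_BASE / CEREAL_REGISTER_TYPE machinery makes the optimizer state (including velocities_ and oldLr_) serializable through a FirstOrderOptimizer pointer. A rough sketch, assuming the usual fl::save / fl::load helpers, the listed header paths, and a hypothetical checkpoint file name:

#include <memory>

#include "flashlight/fl/common/Serialization.h"
#include "flashlight/fl/nn/nn.h"
#include "flashlight/fl/optim/optim.h"

int main() {
  fl::Linear model(/* inputSize = */ 10, /* outputSize = */ 10);
  std::shared_ptr<fl::FirstOrderOptimizer> opt =
      std::make_shared<fl::NAGOptimizer>(model.params(), /* learningRate = */ 1e-2);

  // ... training ...

  // Saving through the base-class pointer relies on CEREAL_REGISTER_TYPE above.
  fl::save("nag_checkpoint.bin", opt);

  std::shared_ptr<fl::FirstOrderOptimizer> restored;
  fl::load("nag_checkpoint.bin", restored); // mu_, wd_, velocities_, oldLr_ come back
  return 0;
}
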
1 change: 1 addition & 0 deletions flashlight/fl/optim/optim.h
@@ -19,6 +19,7 @@
#include "flashlight/fl/optim/AdadeltaOptimizer.h"
#include "flashlight/fl/optim/AdagradOptimizer.h"
#include "flashlight/fl/optim/AdamOptimizer.h"
#include "flashlight/fl/optim/NAGOptimizer.h"
#include "flashlight/fl/optim/NovogradOptimizer.h"
#include "flashlight/fl/optim/Optimizers.h"
#include "flashlight/fl/optim/RMSPropOptimizer.h"
7 changes: 7 additions & 0 deletions flashlight/fl/test/optim/OptimBenchmark.cpp
@@ -72,9 +72,16 @@ double adadelta() {
return optloop(opt, w);
}

double nag() {
auto w = Variable(af::randn(1, 10), true);
auto opt = NAGOptimizer({w}, 1e-3);
return optloop(opt, w);
}

int main() {
af::info();
TIME(sgd);
TIME(nag);
TIME(adam);
TIME(rmsprop);
TIME(adadelta);
