Introduce Nesterov SGD for transformer (where warmup stage is used) (#229)

Summary: Pull Request resolved: #229

- Add a NAGOptimizer implementation following fairseq's https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/optim/nag.py#L43, where a specific lr schedule is used during warmup.
- This speeds up convergence considerably for Transformer LM models and gives better overall performance at the end of training.

Reviewed By: vineelpratap

Differential Revision: D24700925

fbshipit-source-id: a0d3780583c7c3961b3bbab28ccbb280e2419c50
1 parent 3fb7019 · commit 36b242e · Showing 5 changed files with 139 additions and 0 deletions.
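The heart of the change is a learning-rate-corrected momentum update: because the lr changes during warmup, step() rescales the stored velocity by lr / oldLr before applying the usual Nesterov update. Below is a minimal scalar sketch of that rule; the names are illustrative, not from the commit, and the real ArrayFire implementation follows in the diff.

// Minimal scalar sketch of the lr-corrected NAG update introduced here.
// All names are illustrative, not from the commit; see the real diff below.
#include <cstdio>

struct NagState {
  float velocity; // sign-flipped w.r.t. fairseq's velocity, as in the diff
  float oldLr;    // learning rate used on the previous step
};

// One update for a single scalar parameter.
float nagStep(float param, float grad, float lr, float mu, float wd,
              NagState& s) {
  float correctedLr = lr / s.oldLr;     // rescales momentum when lr changes
  param *= 1.0f - lr * wd;              // weight decay term
  s.velocity = mu * s.velocity * correctedLr + lr * grad;
  param -= lr * grad + mu * s.velocity; // Nesterov look-ahead step
  s.oldLr = lr;
  return param;
}

int main() {
  NagState s{0.0f, 0.1f}; // zero velocity, previous lr 0.1
  float p = 1.0f;
  // Warmup raises lr from 0.1 to 0.2 between steps; the stored velocity is
  // rescaled by 0.2 / 0.1 = 2 before being applied.
  p = nagStep(p, 0.5f, 0.1f, 0.99f, 0.0f, s);
  p = nagStep(p, 0.4f, 0.2f, 0.99f, 0.0f, s);
  std::printf("param=%f velocity=%f\n", p, s.velocity);
  return 0;
}

Note that the velocity here carries the opposite sign of fairseq's, as the in-code comment in the diff below points out.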
flashlight/fl/optim/NAGOptimizer.cpp
@@ -0,0 +1,75 @@
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "flashlight/fl/optim/NAGOptimizer.h"

#include <cmath>

using std::vector;

namespace fl {

NAGOptimizer::NAGOptimizer(
    const vector<Variable>& parameters,
    float learningRate,
    float momentum /* = 0 */,
    float weightDecay /* = 0 */)
    : FirstOrderOptimizer(parameters, learningRate),
      mu_(momentum),
      wd_(weightDecay),
      velocities_(),
      oldLr_(learningRate) {
  if (momentum <= 0) {
    throw std::runtime_error(
        "Invalid momentum for NAG optimizer, it should be > 0");
  }
  velocities_.reserve(parameters.size());
  for (const auto& parameter : parameters_) {
    velocities_.emplace_back(
        af::constant(0, parameter.dims(), parameter.type()));
    velocities_.back().eval();
  }
}

void NAGOptimizer::step() {
  float correctedLr = lr_ / oldLr_;

  for (size_t i = 0; i < parameters_.size(); i++) {
    if (!parameters_[i].isGradAvailable()) {
      continue;
    }

    af::array& grad = parameters_[i].grad().array();
    af::array& data = parameters_[i].array();

    if (wd_ != 0) {
      // Weight decay term
      data = data * (1 - lr_ * wd_);
    }
    af::array& velocity = velocities_[i];
    // this velocity corresponds to fairseq velocity * -1
    velocity = mu_ * velocity * correctedLr + lr_ * grad;
    af::eval(velocity);
    grad = grad * lr_ + velocity * mu_;
    data = data - grad;
    af::eval(data);
  }
  oldLr_ = lr_;
}

std::string NAGOptimizer::prettyString() const {
  std::ostringstream ss;
  ss << "NAG (lr=" << lr_ << " ); (previous lr=" << oldLr_ << ");";

  if (wd_ != 0) {
    ss << " (weight decay=" << wd_ << ");";
  }
  ss << " (Nesterov momentum=" << mu_ << ")";
  return ss.str();
}

} // namespace fl
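To see how this is meant to be driven, here is a hypothetical training-loop sketch; the model, loss computation, and schedule are placeholders (not part of this commit), while setLr() and zeroGrad() come from flashlight's FirstOrderOptimizer base. Note the constructor's learning rate must be nonzero, since the first step() divides by the previous lr.

// Hypothetical usage sketch (the model, schedule, and loss are placeholders,
// not part of this commit). The schedule updates lr between steps via
// setLr(); step() then rescales the stored velocity by lr / oldLr.
#include <algorithm>

#include "flashlight/fl/flashlight.h"

void trainSketch(fl::Sequential& model, int warmupSteps, float peakLr) {
  // Initial lr must be nonzero: step() divides by the previous lr.
  fl::NAGOptimizer opt(
      model.params(),
      /*learningRate=*/peakLr / warmupSteps,
      /*momentum=*/0.99,
      /*weightDecay=*/0);
  for (int iter = 1; iter <= 1000; ++iter) {
    // Linear warmup to peakLr, then constant (illustrative schedule).
    float lr = peakLr * std::min(1.0f, float(iter) / warmupSteps);
    opt.setLr(lr);

    opt.zeroGrad();
    // auto loss = criterion(model(input), target); // placeholder forward
    // loss.backward();                             // placeholder backward
    opt.step();
  }
}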
flashlight/fl/optim/NAGOptimizer.h
@@ -0,0 +1,55 @@
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include "flashlight/fl/optim/Optimizers.h"

namespace fl {

/** Nesterov Accelerated Gradient optimizer with support for a learning rate
 * that changes over time (e.g. during warmup). Implements the version from
 * https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/optim/nag.py#L43
 */
class NAGOptimizer : public FirstOrderOptimizer {
 private:
  FL_SAVE_LOAD_WITH_BASE(
      FirstOrderOptimizer,
      mu_,
      wd_,
      velocities_,
      oldLr_)

  NAGOptimizer() = default; // Intentionally private

  float mu_;
  float wd_;
  std::vector<af::array> velocities_;
  float oldLr_;

 public:
  /** NAGOptimizer constructor.
   * @param parameters The parameters from e.g. `model.parameters()`
   * @param learningRate The learning rate.
   * @param momentum The momentum.
   * @param weightDecay The amount of L2 weight decay to use for all the
   * parameters.
   */
  NAGOptimizer(
      const std::vector<Variable>& parameters,
      float learningRate,
      float momentum = 0.99,
      float weightDecay = 0);

  void step() override;

  std::string prettyString() const override;
};

} // namespace fl

CEREAL_REGISTER_TYPE(fl::NAGOptimizer)
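Because the header serializes mu_, wd_, velocities_, and oldLr_ and registers the type with cereal, optimizer state should round-trip through a checkpoint. A hedged sketch using flashlight's fl::save / fl::load helpers (the file path is illustrative):

// Hedged checkpoint sketch; fl::save / fl::load are flashlight's
// cereal-based helpers, and the file path is illustrative.
#include "flashlight/fl/flashlight.h"

void saveCheckpoint(const fl::NAGOptimizer& opt) {
  // Persists mu_, wd_, velocities_, and oldLr_ along with base-class state.
  fl::save("nag_optimizer.bin", opt);
}

void loadCheckpoint(fl::NAGOptimizer& opt) {
  // Restores the velocity buffers, so resumed training keeps its momentum.
  fl::load("nag_optimizer.bin", opt);
}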