add layerwise learning rate for adamw #35569

Merged · 6 commits · Sep 14, 2021

Changes from all commits
4 changes: 4 additions & 0 deletions paddle/fluid/operators/optimizers/adam_op.cc
@@ -236,6 +236,10 @@ class AdamWOpMaker : public AdamOpMaker {
public:
void Make() {
AdamOpMaker::Make();
AddAttr<float>("lr_ratio",
Contributor:
Why add this argument to adam? Is it because adamw and adam share the same .cc file?

In that case, adamw should have its own .cc file.

Contributor Author:
AdamWOpMaker inherits AdamOpMaker, and they use the same InferShape function of AdamOp.
In this case, 'lr_ratio' has no effect on Adam.

"(float, default 1.0) "
"layerwise learning rate decay")
.SetDefault(1.0f);
AddAttr<float>("coeff",
"(float, default 0.01) "
"coeff of the weight decay")
68 changes: 34 additions & 34 deletions paddle/fluid/operators/optimizers/adamw_op.cu
@@ -20,17 +20,17 @@ namespace operators {

template <typename T, typename MT>
__global__ void AdamWKernelREG(MT beta1, MT beta2, MT epsilon, MT coeff,
MT beta1_pow_, MT beta2_pow_, const MT* moment1,
MT* moment1_out, const MT* moment2,
MT* moment2_out, const MT* lr_, const T* grad,
const T* param, T* param_out,
const MT* master_param, MT* master_param_out,
int ndim) {
MT lr = *lr_;
MT lr_ratio, MT beta1_pow_, MT beta2_pow_,
const MT* moment1, MT* moment1_out,
const MT* moment2, MT* moment2_out,
const MT* lr_, const T* grad, const T* param,
T* param_out, const MT* master_param,
MT* master_param_out, int ndim) {
MT lr = *lr_ * lr_ratio;
MT lr_orig = lr;
MT beta1_pow = beta1_pow_;
MT beta2_pow = beta2_pow_;

MT wd = static_cast<MT>(1.0) - coeff * lr;
lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
(static_cast<MT>(1.0) - beta1_pow);

@@ -43,9 +43,9 @@ __global__ void AdamWKernelREG(MT beta1, MT beta2, MT epsilon, MT coeff,
MT mom2 = moment2[id];
mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
p = wd * p -
lr * (mom1 /
(sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
p -= lr_orig * coeff * p;
p -= lr * (mom1 /
(sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));

moment1_out[id] = mom1;
moment2_out[id] = mom2;
@@ -57,18 +57,16 @@ __global__ void AdamWKernelREG(MT beta1, MT beta2, MT epsilon, MT coeff,
}

template <typename T, typename MT>
__global__ void AdamWKernelMEM(MT beta1, MT beta2, MT epsilon, MT coeff,
const MT* beta1_pow_, const MT* beta2_pow_,
const MT* moment1, MT* moment1_out,
const MT* moment2, MT* moment2_out,
const MT* lr_, const T* grad, const T* param,
T* param_out, const MT* master_param,
MT* master_param_out, int ndim) {
MT lr = *lr_;
__global__ void AdamWKernelMEM(
MT beta1, MT beta2, MT epsilon, MT coeff, MT lr_ratio, const MT* beta1_pow_,
const MT* beta2_pow_, const MT* moment1, MT* moment1_out, const MT* moment2,
MT* moment2_out, const MT* lr_, const T* grad, const T* param, T* param_out,
const MT* master_param, MT* master_param_out, int ndim) {
MT lr = *lr_ * lr_ratio;
MT lr_orig = lr;
MT beta1_pow = *beta1_pow_;
MT beta2_pow = *beta2_pow_;

MT wd = static_cast<MT>(1.0) - coeff * lr;
lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
(static_cast<MT>(1.0) - beta1_pow);

@@ -81,9 +79,9 @@ __global__ void AdamWKernelMEM(MT beta1, MT beta2, MT epsilon, MT coeff,
MT mom2 = static_cast<MT>(moment2[id]);
mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
p = wd * p -
lr * (mom1 /
(sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
p -= lr_orig * coeff * p;
p -= lr * (mom1 /
(sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));

moment1_out[id] = mom1;
moment2_out[id] = mom2;
@@ -103,16 +101,16 @@ __global__ void UpdateAdamWBetaPow(T beta1, T beta2, const T* beta1_pow_,

template <typename T, typename MT>
__global__ void SparseAdamWCUDAKernelREG(
MT beta1, MT beta2, MT epsilon, MT coeff, const MT beta1_pow,
MT beta1, MT beta2, MT epsilon, MT coeff, MT lr_ratio, const MT beta1_pow,
const MT beta2_pow, const MT* mom1_, MT* mom1_out_, const MT* mom2_,
MT* mom2_out_, const MT* lr_, const T* grad_, const T* param_,
T* param_out_, const MT* master_param, MT* master_param_out,
const int64_t* rows_, int64_t row_numel, int64_t row_count, bool lazy_mode,
int ndim) {
int id = blockIdx.x * blockDim.x + threadIdx.x;
MT lr = *lr_;
MT lr = *lr_ * lr_ratio;
MT lr_orig = lr;

MT wd = static_cast<MT>(1.0) - coeff * lr;
lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
(static_cast<MT>(1.0) - beta1_pow);

@@ -130,9 +128,9 @@ __global__ void SparseAdamWCUDAKernelREG(
: static_cast<MT>(0);
mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
p = wd * p -
lr * (mom1 / (sqrt(mom2) +
epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
p -= lr_orig * coeff * p;
p -= lr * (mom1 / (sqrt(mom2) +
epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));

// Write back to global memory
mom1_out_[id] = mom1;
@@ -165,7 +163,9 @@ class AdamWOpCUDAKernel : public framework::OpKernel<T> {
bool lazy_mode = ctx.Attr<bool>("lazy_mode");
bool use_global_beta_pow = ctx.Attr<bool>("use_global_beta_pow");
VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
float coeff = ctx.Attr<float>("coeff");

MPDType coeff = static_cast<MPDType>(ctx.Attr<float>("coeff"));
MPDType lr_ratio = static_cast<MPDType>(ctx.Attr<float>("lr_ratio"));

auto* param = ctx.Input<LoDTensor>("Param");
auto* grad_var = ctx.InputVar("Grad");
@@ -301,7 +301,7 @@ class AdamWOpCUDAKernel : public framework::OpKernel<T> {
beta2_pow->place() == platform::CPUPlace()) {
// Compute with betapow in REG
AdamWKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1, beta2, epsilon, coeff, *beta1_pow->data<MPDType>(),
beta1, beta2, epsilon, coeff, lr_ratio, *beta1_pow->data<MPDType>(),
*beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
mom2->data<MPDType>(),
@@ -318,7 +318,7 @@ class AdamWOpCUDAKernel : public framework::OpKernel<T> {
}
} else {
AdamWKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1, beta2, epsilon, coeff, beta1_pow->data<MPDType>(),
beta1, beta2, epsilon, coeff, lr_ratio, beta1_pow->data<MPDType>(),
beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
mom2->data<MPDType>(),
@@ -377,7 +377,7 @@ class AdamWOpCUDAKernel : public framework::OpKernel<T> {

SparseAdamWCUDAKernelREG<
T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1, beta2, epsilon, coeff, *beta1_pow->data<MPDType>(),
beta1, beta2, epsilon, coeff, lr_ratio, *beta1_pow->data<MPDType>(),
*beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
mom2->data<MPDType>(),
@@ -395,7 +395,7 @@ class AdamWOpCUDAKernel : public framework::OpKernel<T> {
}
} else {
SparseAdamWFunctor<T, GPUAdamW, MPDType> functor(
beta1, beta2, epsilon, coeff, beta1_pow->data<MPDType>(),
beta1, beta2, epsilon, coeff, lr_ratio, beta1_pow->data<MPDType>(),
beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
mom2->data<MPDType>(),
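The same arithmetic is repeated in AdamWKernelREG, AdamWKernelMEM, and SparseAdamWCUDAKernelREG above: the learning rate read from LearningRate is first scaled by lr_ratio, the decoupled weight decay uses that scaled rate, and the Adam step uses its bias-corrected form. A NumPy transcription of the revised kernel body, for reference only (not Paddle code):

import numpy as np

def adamw_update(p, g, mom1, mom2, lr, lr_ratio, coeff,
                 beta1, beta2, epsilon, beta1_pow, beta2_pow):
    # Reference transcription of the updated kernel math, for illustration.
    lr = lr * lr_ratio                                   # layerwise-scaled learning rate
    lr_orig = lr                                         # kept for the decay term
    lr *= np.sqrt(1.0 - beta2_pow) / (1.0 - beta1_pow)   # bias-corrected step size

    mom1 = beta1 * mom1 + (1.0 - beta1) * g
    mom2 = beta2 * mom2 + (1.0 - beta2) * g * g

    p = p - lr_orig * coeff * p                          # decoupled weight decay
    p = p - lr * (mom1 / (np.sqrt(mom2) + epsilon * np.sqrt(1.0 - beta2_pow)))
    return p, mom1, mom2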
25 changes: 15 additions & 10 deletions paddle/fluid/operators/optimizers/adamw_op.h
@@ -32,12 +32,13 @@ template <typename T>
class AdamWFunctor<T, CPUAdamW> {
private:
const T coeff_;
const T lr_ratio_;
const T* lr_;
T* param_;

public:
AdamWFunctor(const T coeff, const T* lr, T* param)
: coeff_(coeff), lr_(lr), param_(param) {}
AdamWFunctor(const T coeff, const T lr_ratio, const T* lr, T* param)
: coeff_(coeff), lr_ratio_(lr_ratio), lr_(lr), param_(param) {}

inline HOSTDEVICE void operator()(size_t numel) const {
Eigen::Map<Eigen::Array<T, 1, Eigen::Dynamic>> param{
@@ -46,7 +47,7 @@ class AdamWFunctor<T, CPUAdamW> {
T lr = *lr_;

// Calculation
param = param * (1 - lr * coeff_);
param -= lr * lr_ratio_ * coeff_ * param;
}
};

@@ -60,6 +61,7 @@ class SparseAdamWFunctor<T, GPUAdamW, MT> {
MT beta2_;
MT epsilon_;
MT coeff_;
MT lr_ratio_;

const MT* beta1_pow_;
const MT* beta2_pow_;
@@ -80,7 +82,7 @@ class SparseAdamWFunctor<T, GPUAdamW, MT> {
bool lazy_mode_;

public:
SparseAdamWFunctor(MT beta1, MT beta2, MT epsilon, MT coeff,
SparseAdamWFunctor(MT beta1, MT beta2, MT epsilon, MT coeff, MT lr_ratio,
const MT* beta1_pow, const MT* beta2_pow, const MT* mom1,
MT* mom1_out, const MT* mom2, MT* mom2_out, const MT* lr,
const T* grad, const T* param, T* param_out,
@@ -91,6 +93,7 @@ class SparseAdamWFunctor<T, GPUAdamW, MT> {
beta2_(beta2),
epsilon_(epsilon),
coeff_(coeff),
lr_ratio_(lr_ratio),
beta1_pow_(beta1_pow),
beta2_pow_(beta2_pow),
moment1_(mom1),
@@ -112,21 +115,21 @@ class SparseAdamWFunctor<T, GPUAdamW, MT> {
// The following code is the same as dense
MT mom1 = moment1_[i];
MT mom2 = moment2_[i];
MT lr = *lr_;
MT lr = *lr_ * lr_ratio_;
MT lr_orig = lr;
MT beta1_pow = *beta1_pow_;
MT beta2_pow = *beta2_pow_;
MT p = master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);

// Calculation
MT wd = static_cast<MT>(1.0) - coeff_ * lr;
lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
(static_cast<MT>(1.0) - beta1_pow);

mom1 = beta1_ * mom1 + (static_cast<MT>(1.0) - beta1_) * g;
mom2 = beta2_ * mom2 + (static_cast<MT>(1.0) - beta2_) * g * g;
p = wd * p -
lr * (mom1 /
(sqrt(mom2) + epsilon_ * sqrt(static_cast<MT>(1.0) - beta2_pow)));
p -= lr_orig * coeff_ * p;
p -= lr * (mom1 / (sqrt(mom2) +
epsilon_ * sqrt(static_cast<MT>(1.0) - beta2_pow)));

// Write back to global memory
moment1_out_[i] = mom1;
@@ -187,6 +190,7 @@ class AdamWOpKernel : public AdamOpKernel<DeviceContext, T> {
}

T coeff = static_cast<T>(ctx.Attr<float>("coeff"));
T lr_ratio = static_cast<T>(ctx.Attr<float>("lr_ratio"));
auto* lr = ctx.Input<LoDTensor>("LearningRate");

LoDTensor* param;
@@ -198,7 +202,8 @@ class AdamWOpKernel : public AdamOpKernel<DeviceContext, T> {
param = const_cast<LoDTensor*>(ctx.Input<LoDTensor>("Param"));
}

AdamWFunctor<T, CPUAdamW> functor(coeff, lr->data<T>(), param->data<T>());
AdamWFunctor<T, CPUAdamW> functor(coeff, lr_ratio, lr->data<T>(),
param->data<T>());
functor(param->numel());

AdamOpKernel<DeviceContext, T>::Compute(ctx);
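On the CPU path, the functor above applies only the decoupled weight decay term, scaled by lr * lr_ratio, and the Adam moment update is then delegated to the base AdamOpKernel, as the last line of the diff shows. A one-function NumPy sketch of the functor body:

import numpy as np

def cpu_adamw_decay(param, lr, lr_ratio, coeff):
    # Mirrors AdamWFunctor<T, CPUAdamW>::operator(): the decoupled decay only.
    return param - lr * lr_ratio * coeff * param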
87 changes: 87 additions & 0 deletions python/paddle/fluid/tests/unittests/test_adamw_op.py
@@ -16,6 +16,7 @@
import paddle
import numpy as np
import paddle.fluid as fluid
from functools import partial


class TestAdamWOp(unittest.TestCase):
@@ -148,5 +149,91 @@ def test_adamw_op_dygraph(self):
adam.clear_gradients()


def simple_lr_setting(param, decay_rate, n_layers):
if "fc_0" in param.name or "linear_1" in param.name:
depth = int(param.name.split("_")[2]) + 1
elif "fc_1" in param.name or "linear_2" in param.name:
depth = int(param.name.split("_")[2]) + 2
else:
depth = 0

return decay_rate**(n_layers + 2 - depth)


class TestAdamWOpLayerwiseLR(TestAdamWOp):
def test_adamw_op_dygraph(self):
paddle.disable_static()
value = np.arange(26).reshape(2, 13).astype("float32")
a = paddle.to_tensor(value)
linear1 = paddle.nn.Linear(13, 8)
linear2 = paddle.nn.Linear(8, 5)

simple_lr_fun = partial(simple_lr_setting, decay_rate=0.8, n_layers=2)

adam = paddle.optimizer.AdamW(
learning_rate=0.01,
parameters=[{
'params': linear1.parameters()
}, {
'params': linear2.parameters(),
}],
apply_decay_param_fun=lambda name: True,
weight_decay=0.01,
lr_ratio=simple_lr_fun)

for _ in range(2):
a1 = linear1(a)
out = linear2(a1)
out.backward()
adam.step()
adam.clear_gradients()

def test_adamw_op(self):
paddle.enable_static()
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() \
else fluid.CPUPlace()
train_prog = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(train_prog, startup):
with fluid.unique_name.guard():
x = fluid.data(name='x', shape=[None, 10], dtype='float32')
y = fluid.data(name='y', shape=[None, 1], dtype='float32')

fc1 = fluid.layers.fc(input=x, size=32, act=None)
prediction = fluid.layers.fc(input=fc1, size=1, act=None)
cost = fluid.layers.square_error_cost(input=prediction, label=y)
avg_cost = fluid.layers.mean(cost)

simple_lr_fun = partial(
simple_lr_setting, decay_rate=0.8, n_layers=2)

beta1 = fluid.layers.create_global_var(
shape=[1], value=0.85, dtype='float32', persistable=True)
beta2 = fluid.layers.create_global_var(
shape=[1], value=0.95, dtype='float32', persistable=True)
betas = [beta1, beta2]
opt = paddle.optimizer.AdamW(
learning_rate=1e-5,
beta1=beta1,
beta2=beta2,
weight_decay=0.01,
epsilon=1e-8,
lr_ratio=simple_lr_fun)
opt.minimize(avg_cost)

exe = fluid.Executor(place)
exe.run(startup)
for _ in range(2):
inputs = np.random.random(size=[8, 10]).astype('float32')
outputs = np.random.random(size=[8, 1]).astype('float32')
rets = exe.run(train_prog,
feed={"x": inputs,
"y": outputs},
fetch_list=[avg_cost])
assert rets[0] is not None

paddle.disable_static()


if __name__ == "__main__":
unittest.main()
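For reference, with decay_rate=0.8 and n_layers=2 the simple_lr_setting helper above reduces to ratio = 0.8 ** (4 - depth), so shallower layers (smaller depth) get a smaller ratio and deeper layers stay closer to the base learning rate. A standalone check of the formula, independent of how Paddle names parameters:

# ratio = decay_rate ** (n_layers + 2 - depth), as in simple_lr_setting
decay_rate, n_layers = 0.8, 2
ratios = {depth: decay_rate ** (n_layers + 2 - depth) for depth in range(4)}
print(ratios)  # approximately {0: 0.41, 1: 0.51, 2: 0.64, 3: 0.8}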
19 changes: 18 additions & 1 deletion python/paddle/optimizer/adamw.py
@@ -18,6 +18,7 @@
from ..fluid import framework
from ..fluid.framework import Variable
from ..fluid.dygraph import base as imperative_base
from collections import Callable
import paddle

_C_ops = core.ops
@@ -63,6 +64,10 @@ class AdamW(Adam):
epsilon (float, optional): A small float value for numerical stability.
The default value is 1e-08.
weight_decay (float|Tensor, optional): The weight decay coefficient, it can be float or Tensor. The default value is 0.01.
lr_ratio (function|None, optional): If it is not None,
the learning rate will be updated with layerwise learning rate ratio.
Otherwise, the learning rate is the original.
Default: None.
apply_decay_param_fun (function|None, optional): If it is not None,
only tensors that makes apply_decay_param_fun(Tensor.name)==True
will be updated with weight decay. It only works when we want to specify tensors.
@@ -140,6 +145,7 @@ def __init__(self,
epsilon=1e-8,
parameters=None,
weight_decay=0.01,
lr_ratio=None,
apply_decay_param_fun=None,
grad_clip=None,
lazy_mode=False,
@@ -163,6 +169,12 @@ def __init__(self,
self._apply_decay_param_fun = apply_decay_param_fun
self._coeff = coeff
self._lr_to_coeff = dict()
if lr_ratio is not None:
Contributor:
Should add an explanation for the new lr_ratio argument, and it should follow the explanation for "apply_decay_param_fun".

Contributor Author:
done.

assert isinstance(lr_ratio, Callable)
if core.is_compiled_with_xpu() or core.is_compiled_with_npu():
raise NotImplementedError(
"'lr_ratio' is unimplemented in XPU and NPU")
self._lr_ratio = lr_ratio
Contributor:
You should think about how many kernels are affected by "lr_ratio".
If you only want lr_ratio to affect the GPU and CPU kernels, you should raise an unimplemented error for XPU and NPU here.

Contributor Author:
done.


super(AdamW, self).__init__(
learning_rate=learning_rate,
@@ -278,6 +290,8 @@ def _append_optimize_op(self, block, param_and_grad):

# create the adamw optimize op
if framework.in_dygraph_mode():
lr_ratio_ = 1. if self._lr_ratio is None else self._lr_ratio(
param_and_grad[0])

_beta1 = self._beta1 if not isinstance(
self._beta1, Variable) else self._beta1.numpy().item(0)
@@ -288,7 +302,8 @@ def _append_optimize_op(self, block, param_and_grad):
beta1_pow_acc, beta2_pow_acc, param_and_grad[0], moment1,
moment2, beta1_pow_acc, beta2_pow_acc, 'epsilon', self._epsilon,
'lazy_mode', self._lazy_mode, 'min_row_size_to_use_multithread',
1000, 'beta1', _beta1, 'beta2', _beta2, 'coeff', self._coeff)
1000, 'beta1', _beta1, 'beta2', _beta2, 'coeff', self._coeff,
"lr_ratio", lr_ratio_)

return None

@@ -321,6 +336,8 @@ def _append_optimize_op(self, block, param_and_grad):
"multi_precision": find_master,
"with_decay": with_decay,
"coeff": self._coeff,
"lr_ratio": 1.
if self._lr_ratio is None else self._lr_ratio(param_and_grad[0])
}

if isinstance(self._beta1, Variable):
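Taken together, the Python-side contract is: lr_ratio must be a callable that maps a parameter to a float, it is evaluated once per parameter when the optimize op is appended, a ratio of 1.0 is used when it is left as None, and it raises NotImplementedError on XPU and NPU builds. A minimal dygraph usage sketch following the pattern of the new test (the helper here is hypothetical; in practice the ratio would be derived from param.name, as simple_lr_setting does):

import paddle
from functools import partial

def depth_based_ratio(param, decay_rate=0.8, n_layers=2, depth=0):
    # Hypothetical helper: a real one would infer `depth` from param.name,
    # as simple_lr_setting does in the test above.
    return decay_rate ** (n_layers + 2 - depth)

linear = paddle.nn.Linear(13, 8)
opt = paddle.optimizer.AdamW(
    learning_rate=0.01,
    parameters=linear.parameters(),
    weight_decay=0.01,
    lr_ratio=partial(depth_based_ratio, decay_rate=0.8, n_layers=2))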