forked from EleutherAI/gpt-neo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
optimizers.py
176 lines (143 loc) · 6.48 KB
/
optimizers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf
def clip_by_global_norm(grads, clip_norm):
    """Rescale a list of gradients so their joint L2 norm is at most `clip_norm`.

    The global norm is the square root of the summed squares of every
    non-None gradient. Each gradient is multiplied by
    `clip_norm / max(global_norm, clip_norm)`, i.e. by 1 when the global norm
    is already within the limit and by `clip_norm / global_norm` otherwise.

    Args:
        grads: sequence of gradient tensors; entries may be None.
        clip_norm: the norm limit (the caller in this file passes an mtf
            scalar constant).

    Returns:
        A `(clipped_grads, global_norm)` tuple. `clipped_grads` has the same
        length and order as `grads`, with None entries left as None.
    """
    # Per-tensor sum of squares, skipping variables that have no gradient.
    squared_sums = []
    for g in grads:
        if g is not None:
            squared_sums.append(mtf.reduce_sum(mtf.square(g)))
    global_norm = mtf.sqrt(mtf.add_n(squared_sums))

    # Never larger than 1: the denominator is at least clip_norm.
    scale = clip_norm / mtf.maximum(global_norm, clip_norm)

    clipped_grads = []
    for g in grads:
        clipped_grads.append(g * scale if g is not None else None)
    return clipped_grads, global_norm
def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None):
    """Creates and returns an optimizer training op.

    Builds the learning-rate schedule (optional decay plus optional linear
    warmup), obtains the gradients for every trainable variable of
    `mesh.graph`, optionally clips them by global norm, and applies them with
    either `AdamWeightDecayOptimizer` or `mtf.optimize.AdafactorOptimizer`.

    Args:
        mesh: the mtf mesh the training graph lives on.
        loss: loss tensor to differentiate; only used when `inp_var_grads`
            is None.
        params: dict of hyperparameters. Keys read here: "lr", "train_steps",
            "lr_decay", "warmup_steps", "opt_name", "weight_decay", "beta1",
            plus "beta2"/"epsilon" (Adam) or "ada_epsilon1"/"ada_epsilon2"
            (any other optimizer name). Optional keys: "lr_decay_end"
            (defaults to "train_steps") and "gradient_clipping" (missing or
            None disables clipping).
        variable_dtype: object whose `slice_dtype` is used for the learning
            rate, the clip constant and the gradient cast.
        inp_var_grads: optional precomputed gradients, one per trainable
            variable; when given, `loss` is not differentiated here.

    Returns:
        A `(learning_rate, update_ops, var_grads_fp)` tuple: the scheduled
        learning rate as a fully replicated mtf scalar, the update ops
        returned by the optimizer's `apply_grads`, and the gradients after
        casting (and clipping, if enabled).
    """
    global_step = tf.train.get_or_create_global_step()
    learning_rate = tf.constant(value=params["lr"], shape=[], dtype=variable_dtype.slice_dtype)

    if inp_var_grads is None:
        var_grads = mtf.gradients([loss], [v.outputs[0] for v in mesh.graph.trainable_variables])
    else:
        var_grads = inp_var_grads

    # Cast to full precision. Gradients may be None for variables that do not
    # contribute to the loss; keep those as None (clipping and apply_grad both
    # handle None) instead of crashing inside mtf.cast.
    var_grads_fp = [
        None if g is None else mtf.cast(g, variable_dtype.slice_dtype)
        for g in var_grads
    ]

    # decrease LR to final lr (lr*0.1) by this step - defaults to train_steps
    end_step = params.get("lr_decay_end", params["train_steps"])

    if params["lr_decay"] == "linear":
        learning_rate = tf.train.polynomial_decay(
            learning_rate,
            global_step,
            end_step,
            end_learning_rate=params["lr"] * 0.1,  # Decrease to 10% of initial LR according to GPT-3 paper
            power=1.0,
            cycle=False)
    elif params["lr_decay"] == "cosine":
        learning_rate = tf.train.cosine_decay(
            learning_rate,
            global_step,
            end_step,
            alpha=0.1  # Alpha is min lr value as a fraction of init lr.
        )
    # Any other "lr_decay" value leaves the learning rate constant.

    if params["warmup_steps"] > 0:
        global_steps_int = tf.cast(global_step, tf.int32)
        warmup_steps_int = tf.constant(params["warmup_steps"], dtype=tf.int32)

        dtype = variable_dtype.slice_dtype

        global_steps_float = tf.cast(global_steps_int, dtype)
        warmup_steps_float = tf.cast(warmup_steps_int, dtype)

        # Linear ramp from 0 to the (already decayed) learning rate.
        warmup_percent_done = global_steps_float / warmup_steps_float
        warmup_learning_rate = learning_rate * warmup_percent_done

        # is_warmup is 1.0 while step < warmup_steps and 0.0 afterwards, so
        # this selects between the two rates without a tf.cond.
        is_warmup = tf.cast(global_steps_int < warmup_steps_int, dtype)
        learning_rate = ((1.0 - is_warmup) * learning_rate +
                         is_warmup * warmup_learning_rate)

    learning_rate = mtf.import_fully_replicated(mesh, learning_rate, mtf.Shape([]), name="learning_rate")
    mtf.scalar_summary("lr", learning_rate)

    if params["opt_name"].lower() == "adam":
        optimizer = AdamWeightDecayOptimizer(
            learning_rate=learning_rate,
            weight_decay_rate=params["weight_decay"],
            beta_1=params["beta1"],
            beta_2=params["beta2"],
            epsilon=params["epsilon"],
            exclude_from_weight_decay=["norm", "bias"],
            variable_dtype=variable_dtype
        )
    else:
        # NOTE(review): this branch is given the raw params["lr"], not the
        # scheduled `learning_rate` built above, so decay/warmup do not reach
        # it while the "lr" summary and the returned value still report the
        # schedule. It also forwards params["weight_decay"] as `decay_rate`.
        # Both look unintentional - confirm against the Adafactor API before
        # changing.
        optimizer = mtf.optimize.AdafactorOptimizer(
            learning_rate=params["lr"],
            decay_rate=params["weight_decay"],
            beta1=params["beta1"],
            epsilon1=params["ada_epsilon1"],
            epsilon2=params["ada_epsilon2"]
        )

    # Only build the clip constant when clipping is enabled: mtf.constant
    # cannot be created from a None value.
    clip_norm = params.get("gradient_clipping")
    if clip_norm is not None:
        clip_value = mtf.constant(mesh, clip_norm, dtype=variable_dtype.slice_dtype)
        (var_grads_fp, _) = clip_by_global_norm(var_grads_fp, clip_norm=clip_value)

    update_ops = optimizer.apply_grads(var_grads_fp, mesh.graph.trainable_variables)
    return learning_rate, update_ops, var_grads_fp
class AdamWeightDecayOptimizer(mtf.optimize.Optimizer):
    """Adam optimizer whose L2 weight decay is applied directly to the update.

    The decay term is added to the Adam step itself rather than to the loss,
    so it does not pass through the first/second moment estimates. No bias
    correction is applied to the moments.
    """

    def __init__(self,
                 learning_rate,
                 weight_decay_rate=0.0,
                 beta_1=0.9,
                 beta_2=0.999,
                 epsilon=1e-6,
                 exclude_from_weight_decay=None,
                 variable_dtype=None):
        """Stores the hyperparameters.

        Args:
            learning_rate: scalar multiplied into every update.
            weight_decay_rate: coefficient of the weight-decay term; a falsy
                value disables decay for all variables.
            beta_1: decay factor of the first-moment moving average.
            beta_2: decay factor of the second-moment moving average.
            epsilon: added to the square root of the second moment in the
                denominator of the update.
            exclude_from_weight_decay: optional list of regex patterns;
                variables whose name matches any of them are not decayed.
            variable_dtype: stored on the instance; not read by the methods
                of this class.
        """
        self.learning_rate = learning_rate
        self.weight_decay_rate = weight_decay_rate
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.exclude_from_weight_decay = exclude_from_weight_decay
        self.variable_dtype = variable_dtype

    def _zeros_slot(self, var, suffix):
        """Returns a non-trainable, zero-initialised variable shaped like `var`."""
        # No dtype arguments are passed, so mtf.get_variable's own defaults
        # apply; self.variable_dtype is intentionally not used here.
        return mtf.get_variable(
            var.mesh, var.name + suffix, var.shape,
            initializer=tf.zeros_initializer(),
            trainable=False)

    def apply_grad(self, grad, var):
        """See base class.

        Returns an empty list (after logging a warning) when `grad` is None;
        otherwise the assignments updating `var` and its two moment slots.
        """
        if grad is None:
            tf.logging.warning("Gradient is None for variable %s" % var.name)
            return []

        grad = mtf.to_float(grad)

        moment_1 = self._zeros_slot(var, "/adam_m")
        moment_2 = self._zeros_slot(var, "/adam_v")

        # Exponential moving averages of the gradient and its square.
        next_moment_1 = self.beta_1 * moment_1 + (1.0 - self.beta_1) * grad
        next_moment_2 = self.beta_2 * moment_2 + (1.0 - self.beta_2) * mtf.square(grad)
        update = next_moment_1 / (mtf.sqrt(next_moment_2) + self.epsilon)

        # Adding the squared weights to the loss would route the decay through
        # the moment estimates above. Adding it to the finished Adam step
        # keeps it independent of them, which matches what L2 regularisation
        # does under plain (non-momentum) SGD.
        if self._do_use_weight_decay(var.name):
            update += mtf.to_float(var.value) * self.weight_decay_rate

        scaled_update = self.learning_rate * update
        return [
            mtf.assign_sub(var, scaled_update),
            mtf.assign(moment_1, next_moment_1),
            mtf.assign(moment_2, next_moment_2),
        ]

    def _do_use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if not self.weight_decay_rate:
            return False
        # A missing/empty exclusion list means nothing is excluded.
        patterns = self.exclude_from_weight_decay or ()
        return not any(re.search(pattern, param_name) for pattern in patterns)