Skip to content

Commit

Permalink
Propagate the aggregation property when creating a tf.Variable (#20541)
Browse files Browse the repository at this point in the history

* Fix TF variable aggregation

* Add `none` to aggregation
  • Loading branch information
james77777778 authored Nov 23, 2024
1 parent 5bec656 commit 28d39c0
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 3 deletions.
4 changes: 2 additions & 2 deletions keras/src/backend/common/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,10 @@ def __init__(
"cannot contain character `/`. "
f"Received: name={name}"
)
if aggregation not in ("mean", "sum", "only_first_replica"):
if aggregation not in ("none", "mean", "sum", "only_first_replica"):
raise ValueError(
"Invalid valid for argument `aggregation`. Expected "
"one of {'mean', 'sum', 'only_first_replica'}. "
"one of {'none', 'mean', 'sum', 'only_first_replica'}. "
f"Received: aggregation={aggregation}"
)
self.name = name
Expand Down
16 changes: 15 additions & 1 deletion keras/src/backend/tensorflow/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ def handle(self):

# Build the backing `tf.Variable` from `value` and store it on `self._value`.
def _initialize(self, value):
self._value = tf.Variable(
# NOTE(review): this view looks like a rendered diff with indentation removed;
# the next line is presumably the removed single-line argument list and the
# five lines after it its replacement, which additionally passes
# `aggregation` -- confirm against the real file.
value, dtype=self._dtype, trainable=self.trainable, name=self.name
value,
dtype=self._dtype,
trainable=self.trainable,
name=self.name,
# The Keras aggregation string is converted by `_map_aggregation` before
# being handed to `tf.Variable`.
aggregation=self._map_aggregation(self.aggregation),
)

def _initialize_with_initializer(self, initializer):
Expand All @@ -45,6 +49,7 @@ def _initialize_with_initializer(self, initializer):
dtype=self._dtype,
trainable=self.trainable,
name=self.name,
aggregation=self._map_aggregation(self.aggregation),
)

def _deferred_initialize(self):
Expand Down Expand Up @@ -113,6 +118,15 @@ def _export_to_saved_model_graph(
def _write_object_proto(self, proto, options):
    """Forward the call to `_write_object_proto` on the object in `self.value`.

    Args:
        proto: Passed through unchanged to the wrapped object.
        options: Passed through unchanged to the wrapped object.

    Returns:
        Whatever the wrapped object's `_write_object_proto` returns.
    """
    wrapped = self.value
    return wrapped._write_object_proto(proto, options)

def _map_aggregation(self, aggregation):
    """Translate a Keras aggregation name into a `tf.VariableAggregation` member.

    Args:
        aggregation: One of `"none"`, `"sum"`, `"mean"` or
            `"only_first_replica"`.

    Returns:
        The `tf.VariableAggregation` member registered for that name.

    Raises:
        KeyError: If `aggregation` is not one of the names listed above.
    """
    enum_cls = tf.VariableAggregation
    by_name = dict(
        none=enum_cls.NONE,
        sum=enum_cls.SUM,
        mean=enum_cls.MEAN,
        only_first_replica=enum_cls.ONLY_FIRST_REPLICA,
    )
    return by_name[aggregation]


def convert_to_tensor(x, dtype=None, sparse=None):
if isinstance(x, tf.SparseTensor) and sparse is not None and not sparse:
Expand Down
13 changes: 13 additions & 0 deletions keras/src/backend/tensorflow/distribute_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,16 @@ def test_epoch_iterator(self):
self.assertEqual(y.values[0].shape, [2, 4])
self.assertEqual(sample_weight.values[0].shape, [2])
self.assertEqual(steps_seen, [0, 1, 2, 3, 4, 5, 6])

def test_variable_aggregation(self):
    """Check that a variable's `aggregation` reaches the wrapped `tf.Variable`.

    Inside a `MirroredStrategy` scope over "CPU:0" and "CPU:1", a variable
    created without an explicit `aggregation` is expected to report "mean",
    and one created with `aggregation="sum"` is expected to report "sum".
    In both cases the object in `.value` must carry the matching
    `tf.VariableAggregation` member.
    """
    mirrored = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])
    # (extra constructor kwargs, expected Keras name, expected TF member)
    cases = (
        ({}, "mean", tf.VariableAggregation.MEAN),
        ({"aggregation": "sum"}, "sum", tf.VariableAggregation.SUM),
    )

    with mirrored.scope():
        data = np.random.random((4, 4))
        for extra_kwargs, keras_name, tf_member in cases:
            var = backend.Variable(data, dtype="float32", **extra_kwargs)
            self.assertEqual(var.aggregation, keras_name)
            self.assertEqual(var.value.aggregation, tf_member)
2 changes: 2 additions & 0 deletions keras/src/optimizers/loss_scale_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,14 @@ def build(self, var_list):
shape=(),
dtype="int",
initializer=initializers.Zeros(),
aggregation="none",
name="step_counter",
)
self.dynamic_scale = self.add_variable(
shape=(),
dtype="float32",
initializer=initializers.Constant(self.initial_scale),
aggregation="none",
name="dynamic_scale",
)
self.inner_optimizer.build(var_list)
Expand Down

0 comments on commit 28d39c0

Please sign in to comment.