
Commit 553521e

Improve `keras.Variable` by exposing docstrings and ensuring consistency in the codebase (#20544)

* Improve `keras.Variable` by exposing docstrings and ensuring consistency in the codebase

* Fix CI

* Update docstrings
james77777778 authored Nov 26, 2024
1 parent e0369f6 commit 553521e
Showing 10 changed files with 91 additions and 53 deletions.
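For context, a minimal sketch of the user-facing surface this commit touches, assuming a Keras 3 install on any backend; the `repr` format shown matches the new one introduced in `variables.py` below:

```python
import numpy as np
import keras

# `keras.Variable` now subclasses the backend-agnostic common `Variable`,
# so the docstrings added in this commit are exposed on the public class.
v = keras.Variable(initializer=np.array([1.0, 2.0, 3.0]), name="my_var")

print(v.name)   # my_var
print(v.dtype)  # float32
print(v.shape)  # (3,)
print(repr(v))  # <Variable path=my_var, shape=(3,), dtype=float32, value=[1. 2. 3.]>
```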
10 changes: 6 additions & 4 deletions keras/src/backend/__init__.py
@@ -19,6 +19,7 @@
from keras.src.backend.common.symbolic_scope import SymbolicScope
from keras.src.backend.common.symbolic_scope import in_symbolic_scope
from keras.src.backend.common.variables import AutocastScope
from keras.src.backend.common.variables import Variable
from keras.src.backend.common.variables import get_autocast_scope
from keras.src.backend.common.variables import is_float_dtype
from keras.src.backend.common.variables import is_int_dtype
@@ -35,25 +36,26 @@
# Import backend functions.
if backend() == "tensorflow":
from keras.src.backend.tensorflow import * # noqa: F403
from keras.src.backend.tensorflow.core import Variable as BackendVariable
elif backend() == "jax":
from keras.src.backend.jax import * # noqa: F403
from keras.src.backend.jax.core import Variable as BackendVariable
elif backend() == "torch":
from keras.src.backend.torch import * # noqa: F403
from keras.src.backend.torch.core import Variable as BackendVariable

distribution_lib = None
elif backend() == "numpy":
from keras.src.backend.numpy import * # noqa: F403
from keras.src.backend.numpy.core import Variable as BackendVariable

distribution_lib = None
else:
raise ValueError(f"Unable to import backend : {backend()}")


BackendVariable = Variable # noqa: F405


@keras_export("keras.Variable")
class Variable(BackendVariable):
class Variable(BackendVariable): # noqa: F811
pass


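The block above relies on Python name shadowing: the star import may rebind `Variable` to a backend-specific subclass, `BackendVariable = Variable` captures whichever class won, and the public `class Variable(BackendVariable)` then redefines the name (hence the `F405` and `F811` suppressions). A self-contained sketch of the same pattern, using hypothetical stand-in classes rather than real Keras ones:

```python
# Hypothetical stand-ins for the common and backend-specific Variable classes.
class CommonVariable:
    def backend_name(self):
        return "common"

class TorchVariable(CommonVariable):
    def backend_name(self):
        return "torch"

Variable = CommonVariable   # explicit import of the common class
Variable = TorchVariable    # a star import then shadows the name

BackendVariable = Variable  # capture whichever class is currently bound

class Variable(BackendVariable):  # intentional redefinition of the public name
    pass

print(Variable().backend_name())  # torch
```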
2 changes: 1 addition & 1 deletion keras/src/backend/common/__init__.py
@@ -1,7 +1,7 @@
from keras.src.backend.common import backend_utils
from keras.src.backend.common.dtypes import result_type
from keras.src.backend.common.variables import AutocastScope
from keras.src.backend.common.variables import KerasVariable
from keras.src.backend.common.variables import Variable as KerasVariable
from keras.src.backend.common.variables import get_autocast_scope
from keras.src.backend.common.variables import is_float_dtype
from keras.src.backend.common.variables import is_int_dtype
10 changes: 5 additions & 5 deletions keras/src/backend/common/stateless_scope.py
@@ -8,7 +8,7 @@ class StatelessScope:
The values of variables to be used inside the scope
should be passed via the `state_mapping` argument, a
list of tuples `(k, v)` where `k` is a `KerasVariable`
list of tuples `(k, v)` where `k` is a `Variable`
and `v` is the intended value for this variable
(a backend tensor).
@@ -39,21 +39,21 @@ def __init__(
initialize_variables=True,
):
from keras.src import backend
from keras.src.backend.common.variables import KerasVariable
from keras.src.backend.common.variables import Variable

self.collect_losses = collect_losses
self.initialize_variables = initialize_variables
self.losses = []
self.state_mapping = {}
state_mapping = state_mapping or {}
for k, v in state_mapping:
if not isinstance(k, KerasVariable):
if not isinstance(k, Variable):
raise ValueError(
"Invalid reference variable in StatelessScope: "
"all keys in argument `mapping` must be KerasVariable "
"all keys in argument `mapping` must be Variable "
f"instances. Received instead: {k}"
)
if isinstance(v, KerasVariable):
if isinstance(v, Variable):
v = backend.cast(v.value, dtype=k.dtype)
else:
v = backend.convert_to_tensor(v, dtype=k.dtype)
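A usage sketch for the renamed check, assuming the public `keras.StatelessScope` export (which wraps this class in Keras 3):

```python
import numpy as np
import keras

v = keras.Variable(initializer=np.zeros((2,)), name="state")
new_value = keras.ops.ones((2,))

# Keys in `state_mapping` must be `Variable` instances -- the check whose
# error message this diff rewords.
with keras.StatelessScope(state_mapping=[(v, new_value)]):
    # Inside the scope, reads of `v.value` return the mapped tensor;
    # the variable itself is never mutated.
    print(keras.ops.convert_to_numpy(v.value))  # [1. 1.]

print(keras.ops.convert_to_numpy(v.value))  # [0. 0.]
```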
2 changes: 1 addition & 1 deletion keras/src/backend/common/stateless_scope_test.py
@@ -41,7 +41,7 @@ def test_invalid_key_in_state_mapping(self):
value1 = ops.ones(shape=(2,))

with self.assertRaisesRegex(
ValueError, "all keys in argument `mapping` must be KerasVariable"
ValueError, "all keys in argument `mapping` must be Variable"
):
StatelessScope(state_mapping=[(invalid_key, value1)])

65 changes: 47 additions & 18 deletions keras/src/backend/common/variables.py
@@ -12,7 +12,7 @@
from keras.src.utils.naming import auto_name


class KerasVariable:
class Variable:
"""Represents a backend-agnostic variable in Keras.
A `Variable` acts as a container for state. It holds a tensor value and can
@@ -30,17 +30,25 @@ class KerasVariable:
dtype type (`"float32"` if never configured).
trainable: Optional. Boolean indicating if variable is trainable.
Defaults to `True`.
autocast: Optional. Boolean indicating whether the variable supports
autocasting. If `True`, the layer may first convert the variable
to the compute data type when accessed. Defaults to `True`.
aggregation: Optional. String specifying how a distributed variable will
be aggregated. This serves as a semantic annotation, to be taken
into account by downstream backends or users. Defaults to `"mean"`.
name: Optional. A unique name for the variable. Automatically generated
if not set.
Attributes:
name: The name of the variable (string).
path: The path of the variable within the Keras model or layer (string).
dtype: The data type of the variable (string).
shape: The shape of the variable (tuple of integers).
ndim: The number of dimensions of the variable (integer).
dtype: The data type of the variable (string).
trainable: Whether the variable is trainable (boolean).
autocast: Whether the variable supports autocasting (boolean).
aggregation: How a distributed variable will be aggregated (string).
value: The current value of the variable (NumPy array or tensor).
name: The name of the variable (string).
path: The path of the variable within the Keras model or layer (string).
Examples:
@@ -101,20 +109,19 @@ def __init__(
"one of {'none', 'mean', 'sum', 'only_first_replica'}. "
f"Received: aggregation={aggregation}"
)
self.name = name
self._name = name
parent_path = current_path()
if parent_path:
self.path = current_path() + "/" + self.name
self._path = current_path() + "/" + name
else:
self.path = self.name
dtype = standardize_dtype(dtype)
self._dtype = dtype
self._path = name
self._dtype = standardize_dtype(dtype)
self._shape = None
self._initializer = None
self._regularizer = None
self._constraint = None
self._trainable = trainable
self._autocast = autocast
self._trainable = bool(trainable)
self._autocast = bool(autocast)
self._aggregation = aggregation
# `self._overwrite_with_gradient` is an internal property to determine
# whether this variable should be overwritten by the computed gradient.
@@ -163,7 +170,7 @@ def __init__(
self._initialize_with_initializer(initializer)
else:
self._initialize(initializer)
self._shape = tuple(self._value.shape)
self._shape = self._validate_shape(self._value.shape)
self._ndim = len(self._shape)

def _deferred_initialize(self):
@@ -201,10 +208,12 @@ def numpy(self):

@property
def aggregation(self):
"""The strategy for aggregating this variable."""
return self._aggregation

@property
def value(self):
"""The current value of the variable (numpy array or backend tensor)."""
if in_stateless_scope():
scope = get_stateless_scope()
value = scope.get_current_value(self)
@@ -246,30 +255,46 @@ def assign_sub(self, value):

@property
def dtype(self):
"""The data type of the variable."""
autocast_scope = get_autocast_scope()
if (
self._autocast
and autocast_scope is not None
and is_float_dtype(self._dtype)
):
return autocast_scope.dtype
return self._dtype
dtype = autocast_scope.dtype
else:
dtype = self._dtype
return backend.standardize_dtype(dtype)

@property
def shape(self):
"""The shape of the variable."""
return self._shape

@property
def ndim(self):
"""The number of dimensions of the variable."""
return self._ndim

@property
def trainable(self):
"""Whether the variable is trainable."""
return self._trainable

@trainable.setter
def trainable(self, value):
self._trainable = value
self._trainable = bool(value)

@property
def name(self):
"""The name of the variable."""
return self._name

@property
def path(self):
"""The path of the variable within the Keras model or layer."""
return self._path

@property
def overwrite_with_gradient(self):
@@ -326,9 +351,13 @@ def constraint(self, value):
self._constraint = value

def __repr__(self):
value = None
if hasattr(self, "_value") and self._value is not None:
value = backend.core.convert_to_numpy(self._value)
value_str = f", value={value}" if value is not None else ""
return (
f"<KerasVariable shape={self.shape}, dtype={self.dtype}, "
f"path={self.path}>"
f"<Variable path={self.path}, shape={self.shape}, "
f"dtype={self.dtype}{value_str}>"
)

def _initialize(self, value):
@@ -573,7 +602,7 @@ def get_autocast_scope():
class AutocastScope:
"""Context manager that enables the autocasting of float variables.
Under this context manager, float `KerasVariables`s will be cast to `dtype`
Under this context manager, float `Variables`s will be cast to `dtype`
(note that `dtype` must also be float).
"""

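A short sketch of the autocast behavior the reworked `dtype` property implements; note that `AutocastScope` is an internal API, imported here from `keras.src`:

```python
import numpy as np
import keras
from keras.src.backend.common.variables import AutocastScope

v = keras.Variable(initializer=np.ones((2,)), dtype="float32")

# Float variables with autocast=True (the default) report the scope's
# dtype while an AutocastScope is active...
with AutocastScope("float16"):
    print(v.dtype)  # float16

# ...and their own dtype otherwise.
print(v.dtype)  # float32
```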
34 changes: 22 additions & 12 deletions keras/src/backend/common/variables_test.py
@@ -8,7 +8,6 @@
from keras.src import initializers
from keras.src.backend.common import dtypes
from keras.src.backend.common.variables import AutocastScope
from keras.src.backend.common.variables import KerasVariable
from keras.src.backend.common.variables import shape_equal
from keras.src.backend.common.variables import standardize_dtype
from keras.src.backend.common.variables import standardize_shape
@@ -17,7 +16,7 @@


class VariableInitializationTest(test_case.TestCase):
"""Tests for KerasVariable.__init__()"""
"""Tests for Variable.__init__()"""

def test_deferred_initialization(self):
"""Tests deferred initialization of variables."""
@@ -73,17 +72,16 @@ def test_variable_initialize(self):
self.assertAllClose(v.value, init_value)

def test_variable_without_shape_from_callable_initializer(self):
"""Test that KerasVariable raises error
"""Test that Variable raises error
if shape is not provided for callable initializer."""
with self.assertRaisesRegex(
ValueError, "When creating a Variable from an initializer"
):
KerasVariable(initializer=lambda: np.ones((2, 2)))
backend.Variable(initializer=lambda: np.ones((2, 2)))


class VariablePropertiesTest(test_case.TestCase):
"""Tests for KerasVariable._deferred_initialize
KerasVariable._maybe_autocast"""
"""Tests for Variable._deferred_initialize Variable._maybe_autocast"""

def test_deferred_assignment(self):
"""Tests deferred assignment to variables."""
@@ -204,10 +202,12 @@ def test_name_validation(self):
with self.assertRaisesRegex(
ValueError, "Argument `name` must be a string"
):
KerasVariable(initializer=initializers.RandomNormal(), name=12345)
backend.Variable(
initializer=initializers.RandomNormal(), name=12345
)

with self.assertRaisesRegex(ValueError, "cannot contain character `/`"):
KerasVariable(
backend.Variable(
initializer=initializers.RandomNormal(), name="invalid/name"
)

@@ -272,8 +272,7 @@ def test_overwrite_with_gradient_setter(self):


class VariableNumpyValueAndAssignmentTest(test_case.TestCase):
"""tests for KerasVariable.numpy(), KerasVariable.value()
and KerasVariable.assign()"""
"""tests for Variable.numpy(), Variable.value() and Variable.assign()"""

def test_variable_numpy(self):
"""Test retrieving the value of a variable as a numpy array."""
@@ -373,10 +372,21 @@ def test_variable_repr(self):
"""Test the string representation of a variable."""
v = backend.Variable(initializer=np.array([1, 2, 3]), name="test_var")
expected_repr = (
"<KerasVariable shape=(3,), dtype=float32, path=test_var>"
"<Variable path=test_var, shape=(3,), dtype=float32, "
"value=[1. 2. 3.]>"
)
self.assertEqual(repr(v), expected_repr)

# Test with `backend.StatelessScope()`
with backend.StatelessScope():
v = backend.Variable(
initializer="zeros", shape=(3,), name="test_var"
)
expected_repr = (
"<Variable path=test_var, shape=(3,), dtype=float32>"
)
self.assertEqual(repr(v), expected_repr)

def test_variable_getitem(self):
"""Test getting an item from a variable."""
v = backend.Variable(initializer=np.array([1, 2, 3]))
@@ -408,7 +418,7 @@ def test_variable_array(self):


class VariableOpsCorrectnessTest(test_case.TestCase):
"""Tests for operations on KerasVariable."""
"""Tests for operations on Variable."""

def test_int(self):
v = backend.Variable(initializer=np.array(-1.1))
2 changes: 1 addition & 1 deletion keras/src/backend/jax/trainer.py
@@ -940,7 +940,7 @@ def _purge_model_variables(
During JAX training, since the training function are stateless, we have
to pass in and get the model weights over and over, during which the
copy of the weights that attached to the KerasVariable are still and
copy of the weights that attached to the Variable are still and
occupying extra memory. We remove those variable to save memory (for
better memory utilization) at the beginning of the epoch, and reattach
the value back to variables at the end of the epoch, via
7 changes: 3 additions & 4 deletions keras/src/backend/tensorflow/optimizer.py
@@ -12,7 +12,6 @@
import tensorflow as tf

from keras.src import backend
from keras.src.backend.common import KerasVariable
from keras.src.backend.tensorflow.trackable import KerasAutoTrackable
from keras.src.optimizers import base_optimizer

@@ -46,7 +45,7 @@ def stateless_apply(self, optimizer_variables, grads, trainable_variables):
)

def assign(self, variable, value):
if isinstance(variable, KerasVariable):
if isinstance(variable, backend.Variable):
variable = variable.value
value = tf.cast(value, variable.dtype)
if isinstance(value, tf.IndexedSlices):
@@ -55,7 +54,7 @@ def assign(self, variable, value):
variable.assign(value)

def assign_add(self, variable, value):
if isinstance(variable, KerasVariable):
if isinstance(variable, backend.Variable):
variable = variable.value
value = tf.cast(value, variable.dtype)
if isinstance(value, tf.IndexedSlices):
@@ -64,7 +63,7 @@ def assign_add(self, variable, value):
variable.assign_add(value)

def assign_sub(self, variable, value):
if isinstance(variable, KerasVariable):
if isinstance(variable, backend.Variable):
variable = variable.value
value = tf.cast(value, variable.dtype)
if isinstance(value, tf.IndexedSlices):
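The switch from `KerasVariable` to `backend.Variable` is consistent with the rest of the commit: `KerasVariable` is now just an alias of the common `Variable`, which the public backend class subclasses, so any variable created through `backend.Variable` passes both checks. A quick sanity check, assuming the post-commit module layout:

```python
import numpy as np
from keras.src import backend
from keras.src.backend.common import KerasVariable  # alias of the common Variable

v = backend.Variable(initializer=np.zeros((2,)))

# The public backend class subclasses the common base, so both checks pass.
print(isinstance(v, backend.Variable))  # True
print(isinstance(v, KerasVariable))     # True
```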