Skip to content

Commit

Permalink
Better input validation for InputLayer with input_tensor provided
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Nov 26, 2024
1 parent b6d305f commit 4ca4345
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 41 deletions.
87 changes: 62 additions & 25 deletions keras/src/layers/core/input_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(
):
# TODO: support for ragged.
super().__init__(name=name)

if "input_shape" in kwargs:
warnings.warn(
"Argument `input_shape` is deprecated. Use `shape` instead."
Expand All @@ -30,40 +31,76 @@ def __init__(
if "batch_input_shape" in kwargs:
batch_shape = kwargs.pop("batch_input_shape")

if shape is not None and batch_shape is not None:
raise ValueError(
"You cannot pass both `shape` and `batch_shape` at the "
"same time."
)
if batch_size is not None and batch_shape is not None:
raise ValueError(
"You cannot pass both `batch_size` and `batch_shape` at the "
"same time."
)
if shape is None and batch_shape is None:
raise ValueError("You must pass a `shape` argument.")
if input_tensor is not None:
if not isinstance(input_tensor, backend.KerasTensor):
raise ValueError(
"Argument `input_tensor` must be a KerasTensor. "
f"Received invalid type: input_tensor={input_tensor} "
f"(of type {type(input_tensor)})"
)
if batch_size is not None:
if (
len(input_tensor.shape) < 1
or input_tensor.shape[0] != batch_size
):
raise ValueError(
"When providing the `input_tensor` argument, you "
"cannot provide an incompatible `batch_size` argument."
)
if shape is not None:
if (
len(shape) != len(input_tensor.shape) - 1
or shape != input_tensor.shape[1:]
):
raise ValueError(
"When providing the `input_tensor` argument, you "
"cannot provide an incompatible `shape` argument."
)
if batch_shape is not None and batch_shape != input_tensor.shape:
raise ValueError(
"When providing the `input_tensor` argument, you "
"cannot provide an incompatible `batch_shape` argument."
)
if dtype is not None and input_tensor.dtype != dtype:
raise ValueError(
"When providing the `input_tensor` argument, you "
"cannot provide an incompatible `dtype` argument."
)
if sparse is not None and input_tensor.sparse != sparse:
raise ValueError(
"When providing the `input_tensor` argument, you "
"cannot provide an incompatible `sparse` argument."
)
batch_shape = input_tensor.shape
dtype = input_tensor.dtype
sparse = input_tensor.sparse
else:
if shape is not None and batch_shape is not None:
raise ValueError(
"You cannot pass both `shape` and `batch_shape` at the "
"same time."
)
if batch_size is not None and batch_shape is not None:
raise ValueError(
"You cannot pass both `batch_size` and `batch_shape` "
"at the same time."
)
if shape is None and batch_shape is None:
raise ValueError("You must pass a `shape` argument.")

if shape is not None:
shape = backend.standardize_shape(shape)
batch_shape = (batch_size,) + shape

if shape is not None:
shape = backend.standardize_shape(shape)
batch_shape = (batch_size,) + shape
self._batch_shape = backend.standardize_shape(batch_shape)
self._dtype = backend.standardize_dtype(dtype)

self.sparse = bool(sparse)
if self.sparse and not backend.SUPPORTS_SPARSE_TENSORS:
raise ValueError(
"`sparse=True` is not supported with backend: "
f"{backend.backend()}"
)

if input_tensor is not None:
if not isinstance(input_tensor, backend.KerasTensor):
raise ValueError(
"Argument `input_tensor` must be a KerasTensor. "
f"Received invalid type: input_tensor={input_tensor} "
f"(of type {type(input_tensor)})"
)
else:
if input_tensor is None:
input_tensor = backend.KerasTensor(
shape=batch_shape, dtype=dtype, sparse=sparse, name=name
)
Expand Down
69 changes: 56 additions & 13 deletions keras/src/layers/core/input_layer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,25 +89,20 @@ def test_input_tensor_error(self):
# Testing happy path for layer with input tensor
def testing_input_tensor(self):
input_shape = (2, 3)
batch_size = 4
dtype = "float32"
input_tensor = KerasTensor(shape=input_shape, dtype=dtype)

values = InputLayer(
shape=input_shape,
batch_size=batch_size,
layer = InputLayer(
input_tensor=input_tensor,
dtype=dtype,
)

self.assertEqual(values.dtype, dtype)
self.assertEqual(values.batch_shape[0], batch_size)
self.assertEqual(values.batch_shape[1:], input_shape)
self.assertEqual(values.trainable, True)
self.assertIsInstance(values.output, KerasTensor)
self.assertEqual(values.output, input_tensor)
self.assertEqual(values.output.ndim, input_tensor.ndim)
self.assertEqual(values.output.dtype, dtype)
self.assertEqual(layer.dtype, dtype)
self.assertEqual(layer.batch_shape, (2, 3))
self.assertEqual(layer.trainable, True)
self.assertIsInstance(layer.output, KerasTensor)
self.assertEqual(layer.output, input_tensor)
self.assertEqual(layer.output.ndim, input_tensor.ndim)
self.assertEqual(layer.output.dtype, dtype)

def test_input_shape_deprecated(self):
input_shape = (2, 3)
Expand Down Expand Up @@ -135,3 +130,51 @@ def test_call_method(self):
def test_numpy_shape(self):
# non-python int type shapes should be ok
InputLayer(shape=(np.int64(32),))

def test_invalid_arg_combinations(self):
input_tensor = KerasTensor(shape=(2, 3), dtype="float32")

with self.assertRaisesRegex(
ValueError, "cannot provide an incompatible `shape`"
):
_ = InputLayer(
shape=(2, 4),
input_tensor=input_tensor,
)
with self.assertRaisesRegex(
ValueError, "cannot provide an incompatible `batch_shape`"
):
_ = InputLayer(
batch_shape=(2, 4),
input_tensor=input_tensor,
)
with self.assertRaisesRegex(
ValueError, "cannot provide an incompatible `batch_size`"
):
_ = InputLayer(
batch_size=5,
input_tensor=input_tensor,
)
with self.assertRaisesRegex(
ValueError, "cannot provide an incompatible `dtype`"
):
_ = InputLayer(
dtype="float16",
input_tensor=input_tensor,
)
with self.assertRaisesRegex(
ValueError, "cannot provide an incompatible `sparse`"
):
_ = InputLayer(
sparse=True,
input_tensor=input_tensor,
)

# This works
_ = InputLayer(
shape=(3,),
batch_size=2,
sparse=False,
dtype="float32",
input_tensor=input_tensor,
)
3 changes: 0 additions & 3 deletions keras/src/models/cloning.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,15 +312,12 @@ def _clone_sequential_model(model, clone_function, input_tensors=None):
)
inputs = Input(
tensor=input_tensors,
batch_shape=input_tensors.shape,
dtype=input_tensors.dtype,
name=input_name,
)
new_layers = [inputs] + new_layers
else:
if input_batch_shape is not None:
inputs = Input(
tensor=input_tensors,
batch_shape=input_batch_shape,
dtype=input_dtype,
name=input_name,
Expand Down

0 comments on commit 4ca4345

Please sign in to comment.