diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py index 6ee32b19932..450163a0e24 100644 --- a/keras/api/_tf_keras/keras/ops/__init__.py +++ b/keras/api/_tf_keras/keras/ops/__init__.py @@ -172,6 +172,7 @@ from keras.src.ops.numpy import hstack from keras.src.ops.numpy import identity from keras.src.ops.numpy import imag +from keras.src.ops.numpy import inner from keras.src.ops.numpy import isclose from keras.src.ops.numpy import isfinite from keras.src.ops.numpy import isinf diff --git a/keras/api/_tf_keras/keras/ops/numpy/__init__.py b/keras/api/_tf_keras/keras/ops/numpy/__init__.py index dae8b8a297c..3e1381f8229 100644 --- a/keras/api/_tf_keras/keras/ops/numpy/__init__.py +++ b/keras/api/_tf_keras/keras/ops/numpy/__init__.py @@ -75,6 +75,7 @@ from keras.src.ops.numpy import hstack from keras.src.ops.numpy import identity from keras.src.ops.numpy import imag +from keras.src.ops.numpy import inner from keras.src.ops.numpy import isclose from keras.src.ops.numpy import isfinite from keras.src.ops.numpy import isinf diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py index 6ee32b19932..450163a0e24 100644 --- a/keras/api/ops/__init__.py +++ b/keras/api/ops/__init__.py @@ -172,6 +172,7 @@ from keras.src.ops.numpy import hstack from keras.src.ops.numpy import identity from keras.src.ops.numpy import imag +from keras.src.ops.numpy import inner from keras.src.ops.numpy import isclose from keras.src.ops.numpy import isfinite from keras.src.ops.numpy import isinf diff --git a/keras/api/ops/numpy/__init__.py b/keras/api/ops/numpy/__init__.py index dae8b8a297c..3e1381f8229 100644 --- a/keras/api/ops/numpy/__init__.py +++ b/keras/api/ops/numpy/__init__.py @@ -75,6 +75,7 @@ from keras.src.ops.numpy import hstack from keras.src.ops.numpy import identity from keras.src.ops.numpy import imag +from keras.src.ops.numpy import inner from keras.src.ops.numpy import isclose from keras.src.ops.numpy import isfinite from keras.src.ops.numpy import isinf diff --git a/keras/src/backend/jax/numpy.py b/keras/src/backend/jax/numpy.py index cc7d8d565c3..d0f8e1ee59a 100644 --- a/keras/src/backend/jax/numpy.py +++ b/keras/src/backend/jax/numpy.py @@ -1102,6 +1102,12 @@ def vdot(x1, x2): return jnp.vdot(x1, x2) +def inner(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.inner(x1, x2) + + def vstack(xs): return jnp.vstack(xs) diff --git a/keras/src/backend/numpy/numpy.py b/keras/src/backend/numpy/numpy.py index 11edd49b7a7..98b9f99f5c8 100644 --- a/keras/src/backend/numpy/numpy.py +++ b/keras/src/backend/numpy/numpy.py @@ -1006,6 +1006,15 @@ def vdot(x1, x2): return np.vdot(x1, x2) +def inner(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = x1.astype(dtype) + x2 = x2.astype(dtype) + return np.inner(x1, x2) + + def vstack(xs): dtype_set = set([getattr(x, "dtype", type(x)) for x in xs]) if len(dtype_set) > 1: diff --git a/keras/src/backend/tensorflow/numpy.py b/keras/src/backend/tensorflow/numpy.py index 5634e39f6a7..b90eb8b2d50 100644 --- a/keras/src/backend/tensorflow/numpy.py +++ b/keras/src/backend/tensorflow/numpy.py @@ -2285,6 +2285,24 @@ def vdot(x1, x2): return tf.cast(dot(x1, x2), result_dtype) +def inner(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + result_dtype = dtypes.result_type(x1.dtype, x2.dtype) + compute_dtype = dtypes.result_type(result_dtype, float) + x1 = tf.cast(x1, compute_dtype) + x2 = tf.cast(x2, compute_dtype) + x = tf.cond( + tf.math.logical_or( + tf.math.equal(tf.rank(x1), 0), + tf.math.equal(tf.rank(x2), 0), + ), + lambda: x1 * x2, + lambda: tf.tensordot(x1, x2, axes=[[-1], [-1]]), + ) + return tf.cast(x, result_dtype) + + def vstack(xs): dtype_set = set([getattr(x, "dtype", type(x)) for x in xs]) if len(dtype_set) > 1: diff --git a/keras/src/backend/torch/numpy.py b/keras/src/backend/torch/numpy.py index 6a524a93c81..58ad1ffb6ff 100644 --- a/keras/src/backend/torch/numpy.py +++ b/keras/src/backend/torch/numpy.py @@ -1487,6 +1487,20 @@ def vdot(x1, x2): return cast(torch.vdot(x1, x2), result_dtype) +def inner(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + result_dtype = dtypes.result_type(x1.dtype, x2.dtype) + compute_dtype = dtypes.result_type(result_dtype, float) + + if get_device() == "cpu" and compute_dtype == "float16": + compute_dtype = "float32" + + x1 = cast(x1, compute_dtype) + x2 = cast(x2, compute_dtype) + return cast(torch.inner(x1, x2), result_dtype) + + def vstack(xs): xs = [convert_to_tensor(x) for x in xs] return torch.vstack(xs) diff --git a/keras/src/ops/numpy.py b/keras/src/ops/numpy.py index e5ec6b7b1ac..98c0e6aa7af 100644 --- a/keras/src/ops/numpy.py +++ b/keras/src/ops/numpy.py @@ -5678,6 +5678,45 @@ def vdot(x1, x2): return backend.numpy.vdot(x1, x2) +class Inner(Operation): + def call(self, x1, x2): + return backend.numpy.inner(x1, x2) + + def compute_output_spec(self, x1, x2): + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + return KerasTensor([], dtype=dtype) + + +@keras_export(["keras.ops.inner", "keras.ops.numpy.inner"]) +def inner(x1, x2): + """Return the inner product of two tensors. + + Ordinary inner product of vectors for 1-D tensors + (without complex conjugation), in higher dimensions + a sum product over the last axes. + + Multidimensional arrays are treated as vectors by flattening + all but their last axes. The resulting dot product is performed + over their last axes. + + Args: + x1: First input tensor. + x2: Second input tensor. The last dimension of `x1` and `x2` + must match. + + Returns: + Output tensor. The shape of the output is determined by + broadcasting the shapes of `x1` and `x2` after removing + their last axes. + """ + if any_symbolic_tensors((x1, x2)): + return Inner().symbolic_call(x1, x2) + return backend.numpy.inner(x1, x2) + + @keras_export(["keras.ops.vectorize", "keras.ops.numpy.vectorize"]) def vectorize(pyfunc, *, excluded=None, signature=None): """Turn a function into a vectorized function. diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py index 819ef39bca0..23337744d4e 100644 --- a/keras/src/ops/numpy_test.py +++ b/keras/src/ops/numpy_test.py @@ -299,6 +299,11 @@ def test_vdot(self): y = KerasTensor((None, 3, 3)) self.assertEqual(knp.vdot(x, y).shape, ()) + def test_inner(self): + x = KerasTensor((None,)) + y = KerasTensor((3,)) + self.assertEqual(knp.inner(x, y).shape, ()) + def test_where(self): condition = KerasTensor((2, None, 1)) x = KerasTensor((None, 1)) @@ -875,6 +880,11 @@ def test_vdot(self): y = KerasTensor((2, 3)) self.assertEqual(knp.vdot(x, y).shape, ()) + def test_inner(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.inner(x, y).shape, ()) + def test_where(self): condition = KerasTensor((2, 3)) x = KerasTensor((2, 3)) @@ -2975,6 +2985,12 @@ def test_vdot(self): self.assertAllClose(knp.vdot(x, y), np.vdot(x, y)) self.assertAllClose(knp.Vdot()(x, y), np.vdot(x, y)) + def test_inner(self): + x = np.array([1.0, 2.0, 3.0]) + y = np.array([4.0, 5.0, 6.0]) + self.assertAllClose(knp.inner(x, y), np.inner(x, y)) + self.assertAllClose(knp.Inner()(x, y), np.inner(x, y)) + def test_where(self): x = np.array([1, 2, 3]) y = np.array([4, 5, 6]) @@ -8249,6 +8265,26 @@ def test_vdot(self, dtypes): ) self.assertEqual(knp.Vdot().symbolic_call(x1, x2).dtype, expected_dtype) + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_inner(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1,), dtype=dtype1) + x2 = knp.ones((1,), dtype=dtype2) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.inner(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.inner(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + knp.Inner().symbolic_call(x1, x2).dtype, expected_dtype + ) + @parameterized.named_parameters( named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) )