
v5.4: 100x – 10'000x More Accurate Cosine Distance

@ashvardanian released this on 18 Sep, 01:45

The cosine similarity is the most common and straightforward metric used in machine learning and information retrieval. Interestingly, there are multiple ways to shoot yourself in the foot when computing it. The cosine similarity is the cosine of the angle between two vectors, and the cosine distance is its complement: one minus the similarity.

$$\text{CosineSimilarity}(a, b) = \frac{a \cdot b}{\|a\| \cdot \|b\|}$$ $$\text{CosineDistance}(a, b) = 1 - \frac{a \cdot b}{\|a\| \cdot \|b\|}$$
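As a quick worked example (vectors chosen purely for illustration), take $a = (1, 0)$ and $b = (1, 1)$, which meet at 45 degrees:

$$a \cdot b = 1, \quad \|a\| = 1, \quad \|b\| = \sqrt{2} \quad\Rightarrow\quad \text{CosineSimilarity}(a, b) = \frac{1}{\sqrt{2}} \approx 0.7071, \quad \text{CosineDistance}(a, b) \approx 0.2929$$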

In NumPy terms, the SimSIMD implementation is similar to:

```python
import numpy as np

def cos_numpy(a: np.ndarray, b: np.ndarray) -> float:
    ab, a2, b2 = np.dot(a, b), np.dot(a, a), np.dot(b, b)  # Fused in SimSIMD
    if a2 == 0 and b2 == 0: result = 0.0                    # Same in SciPy
    elif ab == 0: result = 1.0                              # Zero/non-zero pair: division by zero in SciPy
    else: result = 1 - ab / (np.sqrt(a2) * np.sqrt(b2))     # Bigger rounding error in SciPy
    return float(result)
```
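To make the comments above concrete, here is a minimal pure-Python sketch. It uses Python floats instead of SIMD registers, and both function names (cos_fused, cos_single_sqrt) are hypothetical. The first function accumulates all three dot products in a single pass, which is the idea behind "Fused in SimSIMD". The second re-implements the 1 - ab / sqrt(a2 * b2) formulation attributed to SciPy in this release; it is not a call into SciPy, so SciPy's own error reporting may differ. In plain Python the zero/non-zero pair raises ZeroDivisionError. The large-magnitude pair shows one way that multiplying the squared norms before the square root can go wrong.

```python
import math
from typing import Sequence


def cos_fused(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine distance with all three dot products accumulated in one pass."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    ab = a2 = b2 = 0.0
    for x, y in zip(a, b):  # each element is read exactly once
        ab += x * y
        a2 += x * x
        b2 += y * y
    if a2 == 0 and b2 == 0:
        return 0.0  # both inputs are zero vectors
    if ab == 0:
        return 1.0  # orthogonal inputs, or exactly one zero vector
    return 1 - ab / (math.sqrt(a2) * math.sqrt(b2))


def cos_single_sqrt(a: Sequence[float], b: Sequence[float]) -> float:
    """The 1 - ab / sqrt(a2 * b2) formulation: separate passes, no zero/non-zero guard."""
    ab = sum(x * y for x, y in zip(a, b))
    a2 = sum(x * x for x in a)
    b2 = sum(y * y for y in b)
    if a2 == 0 and b2 == 0:
        return 0.0  # mirrors the "Same in SciPy" comment in the code above
    return 1 - ab / math.sqrt(a2 * b2)


print(cos_fused([0.0, 0.0], [1.0, 1.0]))  # 1.0
try:
    cos_single_sqrt([0.0, 0.0], [1.0, 1.0])
except ZeroDivisionError:
    print("division by zero")

big = [1e100, 1e100]  # a2 = 2e200 is finite, but a2 * b2 overflows to inf
print(cos_fused(big, big))        # close to 0.0, as expected for identical vectors
print(cos_single_sqrt(big, big))  # 1.0, because ab / inf == 0.0
```

The overflow case is only one illustrative failure mode. For typical embedding magnitudes, the two formulations differ mainly in rounding.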

In SciPy, however, the cosine distance is computed as 1 - ab / np.sqrt(a2 * b2). It handles the edge case of a zero and non-zero argument pair differently, which results in a division by zero error. This formulation is not only less efficient but also less accurate, and multiplying a2 * b2 before taking the square root can overflow for large-magnitude inputs.

How the square roots are computed matters as well. The C standard library provides the sqrt function, which is generally very accurate but slow. Hardware reciprocal square root (rsqrt) instructions are faster but have different accuracy characteristics:

  • SSE rsqrtps and AVX vrsqrtps: at most $1.5 \times 2^{-12}$ relative error.
  • AVX-512 vrsqrt14pd: at most $2^{-14}$ relative error.
  • NEON frsqrte: no clearly documented error bound.
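For a sense of scale, the two documented bounds convert to decimal relative errors as follows, which corresponds to roughly three to four correct significant digits:

$$1.5 \times 2^{-12} = \frac{1.5}{4096} \approx 3.66 \times 10^{-4}, \qquad 2^{-14} = \frac{1}{16384} \approx 6.10 \times 10^{-5}$$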

To overcome the limitations of the rsqrt instruction, SimSIMD uses the Newton-Raphson iteration to refine the initial estimate for high-precision floating-point numbers. It can be defined as:

$$x_{n+1} = \frac{x_n \, (3 - a \, x_n^2)}{2}$$

where $a$ is the input and $x_n$ is the current estimate of $1/\sqrt{a}$.
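Why one step helps so much: if the estimate of $1/\sqrt{a}$ has relative error $\varepsilon$, so that $x_n = (1 + \varepsilon)/\sqrt{a}$, then substituting into the update gives $x_{n+1} = (1 - \tfrac{3}{2}\varepsilon^2 - \tfrac{1}{2}\varepsilon^3)/\sqrt{a}$. The error roughly squares with every iteration. The sketch below checks this numerically in float64, starting from an estimate that is deliberately off by the SSE/AVX worst-case bound. The helper names are hypothetical, and the perturbation only emulates a hardware estimate; it does not call any rsqrt instruction.

```python
import math
from typing import List


def rsqrt_newton_step(a: float, x: float) -> float:
    """One Newton-Raphson step refining x, an estimate of 1 / sqrt(a)."""
    return x * (3.0 - a * x * x) / 2.0


def relative_errors(a: float, initial_rel_error: float, steps: int) -> List[float]:
    """Relative error of the estimate after 0, 1, ..., `steps` iterations."""
    exact = 1.0 / math.sqrt(a)
    x = exact * (1.0 + initial_rel_error)  # emulate an imprecise hardware estimate
    errors = [abs(x - exact) / exact]
    for _ in range(steps):
        x = rsqrt_newton_step(a, x)
        errors.append(abs(x - exact) / exact)
    return errors


# Start from the worst-case SSE/AVX bound of 1.5 * 2**-12 (about 3.66e-4).
for n, err in enumerate(relative_errors(2.0, 1.5 * 2**-12, steps=2)):
    print(f"{n} iteration(s): relative error {err:.2e}")
```

With this starting point, one step should land near 2e-7 and a second near 6e-14. This is consistent with the idea that float64 can benefit from more than one iteration.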

On 1536-dimensional inputs on an Intel Sapphire Rapids CPU, a single such iteration reduces the relative error by about two orders of magnitude for bfloat16 and float32 and by more than four for float64, while float16 barely improves:

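+----------+---------------------+------------------------------+---------------------+
|  DType   |     NumPy Error     | SimSIMD Error (No Iteration) |    SimSIMD Error    |
+----------+---------------------+------------------------------+---------------------+
| bfloat16 | 1.89e-08 ± 1.59e-08 |     3.07e-07 ± 3.09e-07      | 3.53e-09 ± 2.70e-09 |
| float16  | 1.67e-02 ± 1.44e-02 |     2.68e-05 ± 1.95e-05      | 2.02e-05 ± 1.39e-05 |
| float32  | 2.21e-08 ± 1.65e-08 |     3.47e-07 ± 3.49e-07      | 3.77e-09 ± 2.84e-09 |
| float64  | 0.00e+00 ± 0.00e+00 |     3.80e-07 ± 4.50e-07      | 1.35e-11 ± 1.85e-11 |
+----------+---------------------+------------------------------+---------------------+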

On Arm:

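+----------+---------------------+------------------------------+---------------------+
|  DType   |     NumPy Error     | SimSIMD Error (No Iteration) |    SimSIMD Error    |
+----------+---------------------+------------------------------+---------------------+
| bfloat16 | 1.55e-09 ± 1.27e-09 |     2.79e-05 ± 3.60e-05      | 2.09e-08 ± 1.50e-08 |
| float16  | 1.05e-05 ± 9.99e-06 |     4.97e-05 ± 4.33e-05      | 4.81e-05 ± 3.38e-05 |
| float32  | 2.37e-09 ± 1.88e-09 |     1.79e-05 ± 1.69e-05      | 9.02e-09 ± 7.16e-09 |
| float64  | 0.00e+00 ± 0.00e+00 |     2.54e-05 ± 2.32e-05      | 2.23e-13 ± 4.67e-13 |
+----------+---------------------+------------------------------+---------------------+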
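One way to produce mean ± std error figures like these is to compare a reduced-precision result against a float64 reference. The release does not spell out its methodology. The 0.00e+00 NumPy errors in the float64 rows suggest that the reference is itself a float64 computation, and the sketch below simply assumes that. The function names, the uniform [0, 1) inputs, and the trial count are illustrative choices, so the printed numbers are not expected to match the tables.

```python
import numpy as np


def cos_distance(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine distance computed in the inputs' own dtype (no zero-vector guards)."""
    ab, a2, b2 = np.dot(a, b), np.dot(a, a), np.dot(b, b)
    return float(1 - ab / (np.sqrt(a2) * np.sqrt(b2)))


def error_stats(ndim: int = 1536, trials: int = 100, seed: int = 0):
    """Mean and std of the relative error of float32 results vs. a float64 reference."""
    rng = np.random.default_rng(seed)
    errors = []
    for _ in range(trials):
        a = rng.random(ndim, dtype=np.float32)
        b = rng.random(ndim, dtype=np.float32)
        approx = cos_distance(a, b)  # all arithmetic in float32
        exact = cos_distance(a.astype(np.float64), b.astype(np.float64))
        errors.append(abs(approx - exact) / abs(exact))
    return float(np.mean(errors)), float(np.std(errors))


mean, std = error_stats()
print(f"float32: {mean:.2e} ± {std:.2e}")
```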

Benchmarks

x86: Intel Sapphire Rapids

Baseline (no Newton-Raphson iteration)

+----+--------+------+----------+---------------------+---------------------+---------------------+---------------------+---------------------+-----------------+
|    | Metric | NDim |  DType   |   Baseline Error    |    SimSIMD Error    |  Accurate Duration  |  Baseline Duration  |  SimSIMD Duration   | SimSIMD Speedup |
+----+--------+------+----------+---------------------+---------------------+---------------------+---------------------+---------------------+-----------------+
| 0  | cosine |  11  | bfloat16 | 9.86e-09 ± 1.58e-08 | 3.35e-04 ± 4.42e-04 | 2.16e+04 ± 1.18e+03 | 2.42e+04 ± 2.79e+03 | 2.51e+03 ± 4.17e+02 |  9.90x ± 2.24x  |
| 1  | cosine |  11  | float16  | 1.46e-04 ± 1.83e-04 | 5.09e-04 ± 7.05e-04 | 2.16e+04 ± 1.27e+03 | 2.53e+04 ± 2.54e+03 | 1.15e+03 ± 9.31e+01 | 22.17x ± 1.76x  |
| 2  | cosine |  11  | float32  | 2.13e-08 ± 2.20e-08 | 2.69e-04 ± 4.08e-04 | 2.14e+04 ± 1.51e+03 | 2.37e+04 ± 3.52e+03 | 1.96e+03 ± 6.73e+03 | 23.21x ± 3.90x  |
| 3  | cosine |  11  | float64  | 0.00e+00 ± 0.00e+00 | 4.51e-04 ± 5.78e-04 | 2.57e+04 ± 1.16e+04 | 1.57e+04 ± 1.55e+03 | 1.51e+03 ± 9.03e+02 | 11.57x ± 2.21x  |
| 4  | cosine |  11  |   int8   | 0.00e+00 ± 0.00e+00 | 4.56e-04 ± 5.30e-04 | 1.59e+04 ± 6.32e+02 | 1.60e+04 ± 5.11e+02 | 1.72e+03 ± 6.12e+02 |  9.89x ± 1.86x  |
| 5  | cosine |  97  | bfloat16 | 6.71e-09 ± 7.90e-09 | 1.31e-04 ± 1.47e-04 | 2.14e+04 ± 9.71e+02 | 2.36e+04 ± 4.33e+02 | 2.47e+03 ± 3.95e+02 |  9.82x ± 1.71x  |
| 6  | cosine |  97  | float16  | 3.00e-05 ± 2.42e-05 | 1.00e-04 ± 7.79e-05 | 2.15e+04 ± 1.70e+03 | 2.70e+04 ± 2.02e+03 | 1.18e+03 ± 8.51e+01 | 22.89x ± 2.06x  |
| 7  | cosine |  97  | float32  | 6.84e-09 ± 5.72e-09 | 1.13e-04 ± 1.19e-04 | 2.19e+04 ± 1.84e+03 | 2.33e+04 ± 1.91e+03 | 1.04e+03 ± 9.38e+01 | 22.44x ± 2.38x  |
| 8  | cosine |  97  | float64  | 0.00e+00 ± 0.00e+00 | 9.69e-05 ± 1.54e-04 | 2.13e+04 ± 2.00e+03 | 1.54e+04 ± 1.39e+03 | 1.30e+03 ± 1.20e+02 | 11.92x ± 1.47x  |
| 9  | cosine |  97  |   int8   | 0.00e+00 ± 0.00e+00 | 1.14e-04 ± 1.33e-04 | 1.56e+04 ± 4.34e+02 | 1.60e+04 ± 3.64e+02 | 1.57e+03 ± 2.48e+02 | 10.43x ± 1.55x  |
| 10 | cosine | 1536 | bfloat16 | 1.55e-09 ± 1.27e-09 | 2.79e-05 ± 3.60e-05 | 2.78e+04 ± 1.54e+03 | 2.73e+04 ± 4.66e+02 | 2.83e+03 ± 3.41e+02 |  9.82x ± 1.25x  |
| 11 | cosine | 1536 | float16  | 1.05e-05 ± 9.99e-06 | 4.97e-05 ± 4.33e-05 | 2.56e+04 ± 2.02e+03 | 5.44e+04 ± 1.77e+03 | 1.48e+03 ± 1.78e+02 | 37.23x ± 4.42x  |
| 12 | cosine | 1536 | float32  | 2.37e-09 ± 1.88e-09 | 1.79e-05 ± 1.69e-05 | 2.49e+04 ± 1.29e+03 | 2.63e+04 ± 5.41e+03 | 1.56e+03 ± 3.41e+02 | 17.46x ± 3.77x  |
| 13 | cosine | 1536 | float64  | 0.00e+00 ± 0.00e+00 | 2.54e-05 ± 2.32e-05 | 2.51e+04 ± 2.21e+03 | 1.87e+04 ± 2.87e+02 | 2.39e+03 ± 6.24e+02 |  8.25x ± 1.68x  |
| 14 | cosine | 1536 |   int8   | 0.00e+00 ± 0.00e+00 | 3.06e-05 ± 3.12e-05 | 1.91e+04 ± 1.14e+03 | 2.18e+04 ± 1.17e+03 | 1.72e+03 ± 2.66e+02 | 13.00x ± 2.13x  |
+----+--------+------+----------+---------------------+---------------------+---------------------+---------------------+---------------------+-----------------+
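The SimSIMD Speedup column appears to track Baseline Duration divided by SimSIMD Duration, averaged over repeated runs. For example, in row 1, 2.53e+04 / 1.15e+03 ≈ 22, against the reported 22.17x. The durations carry no stated unit. Below is a hypothetical harness that reports a speedup in the same mean ± std form. The two pure-Python metrics it compares are stand-ins, not SimSIMD or SciPy, so the resulting ratio says nothing about either library.

```python
import math
import statistics
import time
from typing import Callable, Sequence, Tuple

Metric = Callable[[Sequence[float], Sequence[float]], float]


def mean_call_ns(fn: Metric, a: Sequence[float], b: Sequence[float], calls: int = 50) -> float:
    """Average wall-clock nanoseconds per call of fn(a, b)."""
    start = time.perf_counter_ns()
    for _ in range(calls):
        fn(a, b)
    return (time.perf_counter_ns() - start) / calls


def speedup(baseline: Metric, candidate: Metric, a, b, runs: int = 10) -> Tuple[float, float]:
    """Mean and standard deviation of per-run baseline/candidate duration ratios."""
    ratios = [mean_call_ns(baseline, a, b) / mean_call_ns(candidate, a, b) for _ in range(runs)]
    return statistics.mean(ratios), statistics.stdev(ratios)


def three_pass(a: Sequence[float], b: Sequence[float]) -> float:
    """Stand-in baseline: three separate passes, one square root of the product."""
    ab = sum(x * y for x, y in zip(a, b))
    a2 = sum(x * x for x in a)
    b2 = sum(y * y for y in b)
    return 1 - ab / math.sqrt(a2 * b2)


def one_pass(a: Sequence[float], b: Sequence[float]) -> float:
    """Stand-in candidate: all three sums accumulated in a single pass."""
    ab = a2 = b2 = 0.0
    for x, y in zip(a, b):
        ab += x * y
        a2 += x * x
        b2 += y * y
    return 1 - ab / (math.sqrt(a2) * math.sqrt(b2))


a = [float(i % 7 + 1) for i in range(1536)]
b = [float(i % 5 + 1) for i in range(1536)]
mean, std = speedup(three_pass, one_pass, a, b)
print(f"{mean:.2f}x ± {std:.2f}x")
```

Averaging per-run ratios, rather than dividing mean durations, is one plausible reading of the "x ± y" format. It slightly favors larger ratios when timings are noisy.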

With 1 Iteration

+----+--------+------+----------+---------------------+---------------------+---------------------+---------------------+---------------------+-----------------+
|    | Metric | NDim |  DType   |   Baseline Error    |    SimSIMD Error    |  Accurate Duration  |  Baseline Duration  |  SimSIMD Duration   | SimSIMD Speedup |
+----+--------+------+----------+---------------------+---------------------+---------------------+---------------------+---------------------+-----------------+
| 0  | cosine |  11  | bfloat16 | 3.04e-08 ± 2.53e-08 | 3.63e-09 ± 6.75e-09 | 1.24e+04 ± 8.90e+02 | 7.19e+03 ± 4.48e+02 | 2.75e+03 ± 7.66e+02 |  2.71x ± 0.37x  |
| 1  | cosine |  11  | float16  | 2.61e-04 ± 2.45e-04 | 2.12e-04 ± 3.90e-04 | 1.24e+04 ± 9.59e+02 | 9.28e+03 ± 1.72e+03 | 1.27e+03 ± 5.28e+02 |  7.65x ± 1.19x  |
| 2  | cosine |  11  | float32  | 2.91e-08 ± 1.81e-08 | 1.20e-08 ± 1.26e-08 | 1.36e+04 ± 4.00e+03 | 8.32e+03 ± 2.87e+03 | 1.09e+03 ± 1.63e+02 |  7.55x ± 1.54x  |
| 3  | cosine |  11  | float64  | 0.00e+00 ± 0.00e+00 | 3.35e-10 ± 4.33e-10 | 1.35e+04 ± 7.24e+03 | 6.02e+03 ± 8.72e+02 | 1.58e+03 ± 1.18e+03 |  4.45x ± 1.08x  |
| 4  | cosine |  11  |   int8   | 0.00e+00 ± 0.00e+00 | 2.81e-03 ± 1.80e-02 | 9.17e+03 ± 4.17e+03 | 8.02e+03 ± 1.51e+03 | 1.76e+03 ± 2.00e+02 |  4.56x ± 0.81x  |
| 5  | cosine |  97  | bfloat16 | 2.02e-08 ± 1.25e-08 | 3.44e-09 ± 4.63e-09 | 1.34e+04 ± 3.38e+03 | 7.79e+03 ± 2.84e+03 | 2.55e+03 ± 1.09e+02 |  3.05x ± 1.07x  |
| 6  | cosine |  97  | float16  | 1.97e-04 ± 1.18e-04 | 5.37e-05 ± 4.52e-05 | 1.26e+04 ± 1.11e+03 | 1.06e+04 ± 2.95e+03 | 1.19e+03 ± 1.46e+02 |  8.90x ± 1.93x  |
| 7  | cosine |  97  | float32  | 2.39e-08 ± 1.36e-08 | 5.66e-09 ± 4.83e-09 | 1.31e+04 ± 3.11e+03 | 7.78e+03 ± 1.25e+03 | 1.26e+03 ± 8.08e+02 |  6.92x ± 1.63x  |
| 8  | cosine |  97  | float64  | 0.00e+00 ± 0.00e+00 | 6.84e-11 ± 1.10e-10 | 1.25e+04 ± 1.21e+03 | 6.51e+03 ± 1.69e+03 | 1.37e+03 ± 3.63e+02 |  4.89x ± 1.28x  |
| 9  | cosine |  97  |   int8   | 0.00e+00 ± 0.00e+00 | 1.93e-03 ± 4.20e-03 | 8.37e+03 ± 1.87e+03 | 7.89e+03 ± 7.80e+02 | 2.02e+03 ± 1.66e+03 |  4.34x ± 0.69x  |
| 10 | cosine | 1536 | bfloat16 | 2.25e-08 ± 1.61e-08 | 3.53e-09 ± 2.70e-09 | 1.52e+04 ± 2.81e+03 | 8.28e+03 ± 5.26e+02 | 3.07e+03 ± 1.32e+02 |  2.70x ± 0.20x  |
| 11 | cosine | 1536 | float16  | 2.00e-02 ± 1.76e-02 | 2.02e-05 ± 1.39e-05 | 1.43e+04 ± 2.37e+03 | 2.74e+04 ± 3.46e+03 | 1.38e+03 ± 1.25e+02 | 19.98x ± 2.35x  |
| 12 | cosine | 1536 | float32  | 2.24e-08 ± 1.40e-08 | 3.77e-09 ± 2.84e-09 | 1.36e+04 ± 2.64e+03 | 8.64e+03 ± 8.04e+02 | 1.23e+03 ± 8.10e+01 |  7.06x ± 0.72x  |
| 13 | cosine | 1536 | float64  | 0.00e+00 ± 0.00e+00 | 1.35e-11 ± 1.85e-11 | 1.34e+04 ± 1.27e+03 | 7.31e+03 ± 8.12e+02 | 1.98e+03 ± 2.02e+03 |  4.49x ± 1.01x  |
| 14 | cosine | 1536 |   int8   | 0.00e+00 ± 0.00e+00 | 4.20e-04 ± 4.88e-04 | 9.47e+03 ± 2.09e+03 | 1.01e+04 ± 1.11e+03 | 1.95e+03 ± 1.04e+02 |  5.19x ± 0.56x  |
+----+--------+------+----------+---------------------+---------------------+---------------------+---------------------+---------------------+-----------------+

Arm: AWS Graviton 3

Baseline (no Newton-Raphson iteration)

+----+--------+------+----------+---------------------+---------------------+---------------------+---------------------+---------------------+-----------------+
|    | Metric | NDim |  DType   |   Baseline Error    |    SimSIMD Error    |  Accurate Duration  |  Baseline Duration  |  SimSIMD Duration   | SimSIMD Speedup |
+----+--------+------+----------+---------------------+---------------------+---------------------+---------------------+---------------------+-----------------+
| 0  | cosine |  11  | bfloat16 | 1.15e-08 ± 1.76e-08 | 3.27e-08 ± 2.04e-08 | 2.15e+04 ± 9.10e+02 | 2.34e+04 ± 7.89e+02 | 2.68e+03 ± 2.09e+03 |  9.91x ± 2.11x  |
| 1  | cosine |  11  | float16  | 1.36e-04 ± 2.03e-04 | 1.30e-04 ± 1.46e-04 | 2.12e+04 ± 1.69e+03 | 2.57e+04 ± 3.00e+03 | 9.60e+02 ± 7.11e+01 | 26.76x ± 2.32x  |
| 2  | cosine |  11  | float32  | 1.87e-08 ± 1.99e-08 | 3.84e-04 ± 4.15e-04 | 2.08e+04 ± 1.68e+03 | 2.35e+04 ± 3.43e+03 | 8.79e+02 ± 8.79e+01 | 26.85x ± 2.98x  |
| 3  | cosine |  11  | float64  | 0.00e+00 ± 0.00e+00 | 6.10e-04 ± 1.27e-03 | 2.50e+04 ± 1.20e+04 | 1.55e+04 ± 1.45e+03 | 1.24e+03 ± 7.82e+02 | 13.99x ± 2.64x  |
| 4  | cosine |  11  |   int8   | 0.00e+00 ± 0.00e+00 | 2.37e-08 ± 1.57e-08 | 1.59e+04 ± 7.21e+02 | 1.64e+04 ± 3.14e+03 | 1.48e+03 ± 2.97e+02 | 11.38x ± 2.44x  |
| 5  | cosine |  97  | bfloat16 | 5.98e-09 ± 6.36e-09 | 2.19e-08 ± 1.39e-08 | 2.14e+04 ± 7.54e+02 | 2.35e+04 ± 1.15e+03 | 2.31e+03 ± 3.70e+02 | 10.45x ± 2.11x  |
| 6  | cosine |  97  | float16  | 3.40e-05 ± 2.87e-05 | 5.63e-05 ± 4.57e-05 | 2.13e+04 ± 2.05e+03 | 2.67e+04 ± 1.57e+03 | 9.43e+02 ± 6.95e+01 | 28.48x ± 2.48x  |
| 7  | cosine |  97  | float32  | 9.55e-09 ± 7.13e-09 | 9.71e-05 ± 1.50e-04 | 2.06e+04 ± 1.38e+03 | 2.27e+04 ± 1.03e+03 | 8.77e+02 ± 6.84e+01 | 26.02x ± 2.05x  |
| 8  | cosine |  97  | float64  | 0.00e+00 ± 0.00e+00 | 1.31e-04 ± 1.89e-04 | 2.07e+04 ± 2.22e+03 | 1.53e+04 ± 1.25e+03 | 1.06e+03 ± 1.08e+02 | 14.52x ± 1.51x  |
| 9  | cosine |  97  |   int8   | 0.00e+00 ± 0.00e+00 | 2.06e-08 ± 1.53e-08 | 1.59e+04 ± 2.18e+03 | 1.58e+04 ± 1.79e+02 | 1.37e+03 ± 2.12e+02 | 11.81x ± 1.86x  |
| 10 | cosine | 1536 | bfloat16 | 1.76e-09 ± 1.55e-09 | 2.07e-08 ± 1.44e-08 | 2.84e+04 ± 1.25e+03 | 2.77e+04 ± 7.20e+02 | 3.20e+03 ± 3.70e+02 |  8.80x ± 1.12x  |
| 11 | cosine | 1536 | float16  | 8.31e-06 ± 7.39e-06 | 4.23e-05 ± 3.41e-05 | 2.50e+04 ± 1.64e+03 | 5.42e+04 ± 2.20e+03 | 1.22e+03 ± 1.41e+02 | 44.85x ± 4.00x  |
| 12 | cosine | 1536 | float32  | 2.64e-09 ± 1.97e-09 | 2.61e-05 ± 3.13e-05 | 2.44e+04 ± 3.00e+03 | 2.57e+04 ± 1.63e+03 | 1.25e+03 ± 1.69e+02 | 20.80x ± 1.94x  |
| 13 | cosine | 1536 | float64  | 0.00e+00 ± 0.00e+00 | 1.59e-05 ± 1.54e-05 | 2.47e+04 ± 1.90e+03 | 1.90e+04 ± 1.83e+03 | 1.90e+03 ± 4.16e+02 | 10.32x ± 1.99x  |
| 14 | cosine | 1536 |   int8   | 0.00e+00 ± 0.00e+00 | 2.04e-08 ± 1.39e-08 | 1.90e+04 ± 1.38e+03 | 2.15e+04 ± 3.39e+02 | 1.48e+03 ± 2.51e+02 | 14.91x ± 2.54x  |
+----+--------+------+----------+---------------------+---------------------+---------------------+---------------------+---------------------+-----------------+

With 2 Iterations

+----+--------+------+----------+---------------------+---------------------+---------------------+---------------------+---------------------+-----------------+
|    | Metric | NDim |  DType   |   Baseline Error    |    SimSIMD Error    |  Accurate Duration  |  Baseline Duration  |  SimSIMD Duration   | SimSIMD Speedup |
+----+--------+------+----------+---------------------+---------------------+---------------------+---------------------+---------------------+-----------------+
| 0  | cosine |  11  | bfloat16 | 1.54e-08 ± 2.76e-08 | 2.94e-08 ± 2.62e-08 | 2.09e+04 ± 1.18e+03 | 2.37e+04 ± 2.20e+03 | 2.16e+03 ± 4.40e+02 | 11.41x ± 2.71x  |
| 1  | cosine |  11  | float16  | 1.32e-04 ± 1.43e-04 | 1.96e-04 ± 2.77e-04 | 2.19e+04 ± 1.90e+03 | 2.59e+04 ± 4.38e+03 | 9.59e+02 ± 9.05e+01 | 27.03x ± 3.00x  |
| 2  | cosine |  11  | float32  | 3.44e-08 ± 4.95e-08 | 2.11e-08 ± 2.49e-08 | 2.11e+04 ± 1.36e+03 | 2.37e+04 ± 4.08e+03 | 8.57e+02 ± 8.07e+01 | 27.65x ± 3.71x  |
| 3  | cosine |  11  | float64  | 0.00e+00 ± 0.00e+00 | 8.65e-12 ± 1.39e-11 | 2.52e+04 ± 1.22e+04 | 1.56e+04 ± 1.36e+03 | 1.32e+03 ± 7.76e+02 | 13.13x ± 2.61x  |
| 4  | cosine |  11  |   int8   | 0.00e+00 ± 0.00e+00 | 3.03e-08 ± 3.66e-08 | 1.61e+04 ± 1.03e+03 | 1.60e+04 ± 5.69e+02 | 1.58e+03 ± 3.06e+02 | 10.39x ± 1.62x  |
| 5  | cosine |  97  | bfloat16 | 5.22e-09 ± 4.67e-09 | 2.43e-08 ± 1.48e-08 | 2.12e+04 ± 8.81e+02 | 2.38e+04 ± 1.98e+03 | 2.13e+03 ± 4.24e+02 | 11.58x ± 2.17x  |
| 6  | cosine |  97  | float16  | 3.17e-05 ± 3.81e-05 | 6.11e-05 ± 5.12e-05 | 2.15e+04 ± 1.56e+03 | 2.70e+04 ± 2.32e+03 | 9.84e+02 ± 9.83e+01 | 27.66x ± 3.59x  |
| 7  | cosine |  97  | float32  | 7.65e-09 ± 6.03e-09 | 8.76e-09 ± 5.92e-09 | 2.14e+04 ± 1.90e+03 | 2.31e+04 ± 1.93e+03 | 9.10e+02 ± 8.64e+01 | 25.54x ± 3.07x  |
| 8  | cosine |  97  | float64  | 0.00e+00 ± 0.00e+00 | 1.48e-12 ± 2.76e-12 | 2.11e+04 ± 1.81e+03 | 1.53e+04 ± 6.54e+02 | 1.15e+03 ± 1.13e+02 | 13.34x ± 1.24x  |
| 9  | cosine |  97  |   int8   | 0.00e+00 ± 0.00e+00 | 2.29e-08 ± 1.49e-08 | 1.60e+04 ± 2.33e+03 | 1.61e+04 ± 2.06e+03 | 1.41e+03 ± 2.06e+02 | 11.64x ± 1.95x  |
| 10 | cosine | 1536 | bfloat16 | 2.04e-09 ± 1.61e-09 | 2.09e-08 ± 1.50e-08 | 2.84e+04 ± 1.13e+03 | 2.81e+04 ± 1.77e+03 | 2.98e+03 ± 4.43e+02 |  9.62x ± 1.47x  |
| 11 | cosine | 1536 | float16  | 8.23e-06 ± 8.19e-06 | 4.81e-05 ± 3.38e-05 | 2.57e+04 ± 2.31e+03 | 5.45e+04 ± 2.38e+03 | 1.23e+03 ± 1.69e+02 | 44.93x ± 5.18x  |
| 12 | cosine | 1536 | float32  | 2.41e-09 ± 1.51e-09 | 9.02e-09 ± 7.16e-09 | 2.53e+04 ± 3.03e+03 | 2.59e+04 ± 1.42e+03 | 1.46e+03 ± 2.97e+02 | 18.45x ± 3.57x  |
| 13 | cosine | 1536 | float64  | 0.00e+00 ± 0.00e+00 | 2.23e-13 ± 4.67e-13 | 2.57e+04 ± 4.05e+03 | 1.87e+04 ± 9.64e+02 | 2.27e+03 ± 6.14e+02 |  8.75x ± 1.99x  |
| 14 | cosine | 1536 |   int8   | 0.00e+00 ± 0.00e+00 | 2.32e-08 ± 1.39e-08 | 1.92e+04 ± 2.52e+03 | 2.17e+04 ± 4.07e+02 | 1.48e+03 ± 2.82e+02 | 15.13x ± 2.60x  |
+----+--------+------+----------+---------------------+---------------------+---------------------+---------------------+---------------------+-----------------+