v5.4: 100x – 10'000x More Accurate Cosine Distance
The cosine similarity is one of the most common and straightforward metrics used in machine learning and information retrieval. Interestingly, there are multiple ways to shoot yourself in the foot when computing it. The cosine similarity is the cosine of the angle between two vectors, and the cosine distance is its complement: one minus the similarity.
In NumPy terms, the SimSIMD implementation is similar to:

```python
import numpy as np

def cos_numpy(a: np.ndarray, b: np.ndarray) -> float:
    # Dot product and both squared norms; fused into a single pass in SimSIMD
    ab, a2, b2 = np.dot(a, b), np.dot(a, a), np.dot(b, b)
    if a2 == 0 and b2 == 0:
        result = 0  # Two zero vectors: same in SciPy
    elif ab == 0:
        result = 1  # Orthogonal, or one zero vector: division by zero error in SciPy
    else:
        result = 1 - ab / (np.sqrt(a2) * np.sqrt(b2))  # Bigger rounding error in SciPy
    return float(result)
```
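To make the edge-case handling concrete, here is a small comparison sketch. `cos_guarded` repeats the branches of `cos_numpy`, while `cos_unguarded` is a hypothetical helper that evaluates the single expression `1 - ab / sqrt(a2 * b2)` with no special cases; it illustrates the formula, not SciPy's actual source. With NumPy scalars, the unguarded form turns a zero vector into NaN (plain Python floats would raise `ZeroDivisionError` instead).

```python
import numpy as np

def cos_guarded(a: np.ndarray, b: np.ndarray) -> float:
    # Same branch structure as cos_numpy: zero vectors are handled explicitly
    ab, a2, b2 = np.dot(a, b), np.dot(a, a), np.dot(b, b)
    if a2 == 0 and b2 == 0:
        return 0.0
    if ab == 0:
        return 1.0
    return float(1 - ab / (np.sqrt(a2) * np.sqrt(b2)))

def cos_unguarded(a: np.ndarray, b: np.ndarray) -> float:
    # Hypothetical single-expression formula 1 - ab / sqrt(a2 * b2), no special cases
    ab, a2, b2 = np.dot(a, b), np.dot(a, a), np.dot(b, b)
    with np.errstate(divide="ignore", invalid="ignore"):  # silence the 0/0 warning
        return float(1 - ab / np.sqrt(a2 * b2))

zero, ones = np.zeros(3), np.ones(3)
print(cos_guarded(zero, ones), cos_unguarded(zero, ones))  # 1.0 nan
print(cos_guarded(zero, zero), cos_unguarded(zero, zero))  # 0.0 nan
```

For non-degenerate inputs both forms agree up to rounding; they differ only in how the zero-norm cases are reported.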
In SciPy, however, the cosine distance is computed as `1 - ab / np.sqrt(a2 * b2)`. It handles the edge case of a zero and non-zero argument pair differently, resulting in a division by zero error. This is not only less efficient but also less accurate, and the way the square roots are computed matters. The C standard library provides the `sqrt` function, which is generally very accurate, but slow. The in-hardware `rsqrt` implementations are faster, but have different accuracy characteristics:
- SSE `rsqrtps` and AVX `vrsqrtps` instructions: $1.5 \times 2^{-12}$ maximal relative error.
- AVX-512 `vrsqrt14pd` instruction: $2^{-14}$ maximal relative error.
- NEON `frsqrte` instruction: no clear error bounds.
To overcome the limitations of the `rsqrt` instruction, SimSIMD uses the Newton-Raphson iteration to refine the initial estimate for high-precision floating-point numbers. For an estimate $y_n \approx 1/\sqrt{x}$, it can be defined as:

$$y_{n+1} = y_n \cdot \frac{3 - x \cdot y_n^2}{2}$$
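A minimal NumPy sketch of this refinement, assuming `float64` inputs. `rsqrt_estimate` and `newton_rsqrt` are hypothetical helper names, and the noise model is an assumption: real `rsqrt` instructions return deterministic, input-dependent estimates, whereas here the exact value is perturbed by a random relative error within the SSE/AVX bound of $1.5 \times 2^{-12}$. Substituting $y_n = (1 + e)/\sqrt{x}$ into the update gives a new relative error of about $-1.5 e^2$, so each step roughly squares the error.

```python
import numpy as np

# Hypothetical stand-in for a hardware rsqrt: the exact value perturbed by a
# random relative error within `rel_err` (default: the SSE/AVX bound).
def rsqrt_estimate(x: np.ndarray, rel_err: float = 1.5 * 2**-12) -> np.ndarray:
    rng = np.random.default_rng(0)
    noise = rng.uniform(-rel_err, rel_err, size=x.shape)
    return (1.0 / np.sqrt(x)) * (1.0 + noise)

def newton_rsqrt(x: np.ndarray, y: np.ndarray, iterations: int) -> np.ndarray:
    # Each step applies y <- y * (3 - x * y^2) / 2, mapping a relative
    # error e to roughly -1.5 * e^2
    for _ in range(iterations):
        y = y * (3.0 - x * y * y) / 2.0
    return y

x = np.linspace(0.5, 1000.0, 10_000)
exact = 1.0 / np.sqrt(x)
y0 = rsqrt_estimate(x)

def max_rel_err(y: np.ndarray) -> float:
    return float(np.max(np.abs(y - exact) / exact))

for n in range(3):
    print(f"{n} iteration(s): max relative error {max_rel_err(newton_rsqrt(x, y0, n)):.2e}")
```

Under this model, one step brings a $2^{-12}$-class estimate close to `float32` precision, which suggests why `float64` benefits from a second step.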
On 1536-dimensional inputs on an Intel Sapphire Rapids CPU, a single such iteration reduces the error by roughly two orders of magnitude for `bfloat16` and `float32` and by over four for `float64`, while `float16` barely changes:
| Datatype | NumPy Error | SimSIMD w/out Iteration | SimSIMD |
|---|---|---|---|
| bfloat16 | 1.89e-08 ± 1.59e-08 | 3.07e-07 ± 3.09e-07 | 3.53e-09 ± 2.70e-09 |
| float16 | 1.67e-02 ± 1.44e-02 | 2.68e-05 ± 1.95e-05 | 2.02e-05 ± 1.39e-05 |
| float32 | 2.21e-08 ± 1.65e-08 | 3.47e-07 ± 3.49e-07 | 3.77e-09 ± 2.84e-09 |
| float64 | 0.00e+00 ± 0.00e+00 | 3.80e-07 ± 4.50e-07 | 1.35e-11 ± 1.85e-11 |
On Arm:
| Datatype | NumPy Error | SimSIMD w/out Iteration | SimSIMD |
|---|---|---|---|
| bfloat16 | 1.55e-09 ± 1.27e-09 | 2.79e-05 ± 3.60e-05 | 2.09e-08 ± 1.50e-08 |
| float16 | 1.05e-05 ± 9.99e-06 | 4.97e-05 ± 4.33e-05 | 4.81e-05 ± 3.38e-05 |
| float32 | 2.37e-09 ± 1.88e-09 | 1.79e-05 ± 1.69e-05 | 9.02e-09 ± 7.16e-09 |
| float64 | 0.00e+00 ± 0.00e+00 | 2.54e-05 ± 2.32e-05 | 2.23e-13 ± 4.67e-13 |
Benchmarks
x86: Intel Sapphire Rapids
Baseline (without Newton-Raphson iteration)
+----+--------+------+----------+---------------------+---------------------+---------------------+---------------------+---------------------+-----------------+
| | Metric | NDim | DType | Baseline Error | SimSIMD Error | Accurate Duration | Baseline Duration | SimSIMD Duration | SimSIMD Speedup |
+----+--------+------+----------+---------------------+---------------------+---------------------+---------------------+---------------------+-----------------+
| 0 | cosine | 11 | bfloat16 | 9.86e-09 ± 1.58e-08 | 3.35e-04 ± 4.42e-04 | 2.16e+04 ± 1.18e+03 | 2.42e+04 ± 2.79e+03 | 2.51e+03 ± 4.17e+02 | 9.90x ± 2.24x |
| 1 | cosine | 11 | float16 | 1.46e-04 ± 1.83e-04 | 5.09e-04 ± 7.05e-04 | 2.16e+04 ± 1.27e+03 | 2.53e+04 ± 2.54e+03 | 1.15e+03 ± 9.31e+01 | 22.17x ± 1.76x |
| 2 | cosine | 11 | float32 | 2.13e-08 ± 2.20e-08 | 2.69e-04 ± 4.08e-04 | 2.14e+04 ± 1.51e+03 | 2.37e+04 ± 3.52e+03 | 1.96e+03 ± 6.73e+03 | 23.21x ± 3.90x |
| 3 | cosine | 11 | float64 | 0.00e+00 ± 0.00e+00 | 4.51e-04 ± 5.78e-04 | 2.57e+04 ± 1.16e+04 | 1.57e+04 ± 1.55e+03 | 1.51e+03 ± 9.03e+02 | 11.57x ± 2.21x |
| 4 | cosine | 11 | int8 | 0.00e+00 ± 0.00e+00 | 4.56e-04 ± 5.30e-04 | 1.59e+04 ± 6.32e+02 | 1.60e+04 ± 5.11e+02 | 1.72e+03 ± 6.12e+02 | 9.89x ± 1.86x |
| 5 | cosine | 97 | bfloat16 | 6.71e-09 ± 7.90e-09 | 1.31e-04 ± 1.47e-04 | 2.14e+04 ± 9.71e+02 | 2.36e+04 ± 4.33e+02 | 2.47e+03 ± 3.95e+02 | 9.82x ± 1.71x |
| 6 | cosine | 97 | float16 | 3.00e-05 ± 2.42e-05 | 1.00e-04 ± 7.79e-05 | 2.15e+04 ± 1.70e+03 | 2.70e+04 ± 2.02e+03 | 1.18e+03 ± 8.51e+01 | 22.89x ± 2.06x |
| 7 | cosine | 97 | float32 | 6.84e-09 ± 5.72e-09 | 1.13e-04 ± 1.19e-04 | 2.19e+04 ± 1.84e+03 | 2.33e+04 ± 1.91e+03 | 1.04e+03 ± 9.38e+01 | 22.44x ± 2.38x |
| 8 | cosine | 97 | float64 | 0.00e+00 ± 0.00e+00 | 9.69e-05 ± 1.54e-04 | 2.13e+04 ± 2.00e+03 | 1.54e+04 ± 1.39e+03 | 1.30e+03 ± 1.20e+02 | 11.92x ± 1.47x |
| 9 | cosine | 97 | int8 | 0.00e+00 ± 0.00e+00 | 1.14e-04 ± 1.33e-04 | 1.56e+04 ± 4.34e+02 | 1.60e+04 ± 3.64e+02 | 1.57e+03 ± 2.48e+02 | 10.43x ± 1.55x |
| 10 | cosine | 1536 | bfloat16 | 1.55e-09 ± 1.27e-09 | 2.79e-05 ± 3.60e-05 | 2.78e+04 ± 1.54e+03 | 2.73e+04 ± 4.66e+02 | 2.83e+03 ± 3.41e+02 | 9.82x ± 1.25x |
| 11 | cosine | 1536 | float16 | 1.05e-05 ± 9.99e-06 | 4.97e-05 ± 4.33e-05 | 2.56e+04 ± 2.02e+03 | 5.44e+04 ± 1.77e+03 | 1.48e+03 ± 1.78e+02 | 37.23x ± 4.42x |
| 12 | cosine | 1536 | float32 | 2.37e-09 ± 1.88e-09 | 1.79e-05 ± 1.69e-05 | 2.49e+04 ± 1.29e+03 | 2.63e+04 ± 5.41e+03 | 1.56e+03 ± 3.41e+02 | 17.46x ± 3.77x |
| 13 | cosine | 1536 | float64 | 0.00e+00 ± 0.00e+00 | 2.54e-05 ± 2.32e-05 | 2.51e+04 ± 2.21e+03 | 1.87e+04 ± 2.87e+02 | 2.39e+03 ± 6.24e+02 | 8.25x ± 1.68x |
| 14 | cosine | 1536 | int8 | 0.00e+00 ± 0.00e+00 | 3.06e-05 ± 3.12e-05 | 1.91e+04 ± 1.14e+03 | 2.18e+04 ± 1.17e+03 | 1.72e+03 ± 2.66e+02 | 13.00x ± 2.13x |
+----+--------+------+----------+---------------------+---------------------+---------------------+---------------------+---------------------+-----------------+
With 1 Iteration
+----+--------+------+----------+---------------------+---------------------+---------------------+---------------------+---------------------+-----------------+
| | Metric | NDim | DType | Baseline Error | SimSIMD Error | Accurate Duration | Baseline Duration | SimSIMD Duration | SimSIMD Speedup |
+----+--------+------+----------+---------------------+---------------------+---------------------+---------------------+---------------------+-----------------+
| 0 | cosine | 11 | bfloat16 | 3.04e-08 ± 2.53e-08 | 3.63e-09 ± 6.75e-09 | 1.24e+04 ± 8.90e+02 | 7.19e+03 ± 4.48e+02 | 2.75e+03 ± 7.66e+02 | 2.71x ± 0.37x |
| 1 | cosine | 11 | float16 | 2.61e-04 ± 2.45e-04 | 2.12e-04 ± 3.90e-04 | 1.24e+04 ± 9.59e+02 | 9.28e+03 ± 1.72e+03 | 1.27e+03 ± 5.28e+02 | 7.65x ± 1.19x |
| 2 | cosine | 11 | float32 | 2.91e-08 ± 1.81e-08 | 1.20e-08 ± 1.26e-08 | 1.36e+04 ± 4.00e+03 | 8.32e+03 ± 2.87e+03 | 1.09e+03 ± 1.63e+02 | 7.55x ± 1.54x |
| 3 | cosine | 11 | float64 | 0.00e+00 ± 0.00e+00 | 3.35e-10 ± 4.33e-10 | 1.35e+04 ± 7.24e+03 | 6.02e+03 ± 8.72e+02 | 1.58e+03 ± 1.18e+03 | 4.45x ± 1.08x |
| 4 | cosine | 11 | int8 | 0.00e+00 ± 0.00e+00 | 2.81e-03 ± 1.80e-02 | 9.17e+03 ± 4.17e+03 | 8.02e+03 ± 1.51e+03 | 1.76e+03 ± 2.00e+02 | 4.56x ± 0.81x |
| 5 | cosine | 97 | bfloat16 | 2.02e-08 ± 1.25e-08 | 3.44e-09 ± 4.63e-09 | 1.34e+04 ± 3.38e+03 | 7.79e+03 ± 2.84e+03 | 2.55e+03 ± 1.09e+02 | 3.05x ± 1.07x |
| 6 | cosine | 97 | float16 | 1.97e-04 ± 1.18e-04 | 5.37e-05 ± 4.52e-05 | 1.26e+04 ± 1.11e+03 | 1.06e+04 ± 2.95e+03 | 1.19e+03 ± 1.46e+02 | 8.90x ± 1.93x |
| 7 | cosine | 97 | float32 | 2.39e-08 ± 1.36e-08 | 5.66e-09 ± 4.83e-09 | 1.31e+04 ± 3.11e+03 | 7.78e+03 ± 1.25e+03 | 1.26e+03 ± 8.08e+02 | 6.92x ± 1.63x |
| 8 | cosine | 97 | float64 | 0.00e+00 ± 0.00e+00 | 6.84e-11 ± 1.10e-10 | 1.25e+04 ± 1.21e+03 | 6.51e+03 ± 1.69e+03 | 1.37e+03 ± 3.63e+02 | 4.89x ± 1.28x |
| 9 | cosine | 97 | int8 | 0.00e+00 ± 0.00e+00 | 1.93e-03 ± 4.20e-03 | 8.37e+03 ± 1.87e+03 | 7.89e+03 ± 7.80e+02 | 2.02e+03 ± 1.66e+03 | 4.34x ± 0.69x |
| 10 | cosine | 1536 | bfloat16 | 2.25e-08 ± 1.61e-08 | 3.53e-09 ± 2.70e-09 | 1.52e+04 ± 2.81e+03 | 8.28e+03 ± 5.26e+02 | 3.07e+03 ± 1.32e+02 | 2.70x ± 0.20x |
| 11 | cosine | 1536 | float16 | 2.00e-02 ± 1.76e-02 | 2.02e-05 ± 1.39e-05 | 1.43e+04 ± 2.37e+03 | 2.74e+04 ± 3.46e+03 | 1.38e+03 ± 1.25e+02 | 19.98x ± 2.35x |
| 12 | cosine | 1536 | float32 | 2.24e-08 ± 1.40e-08 | 3.77e-09 ± 2.84e-09 | 1.36e+04 ± 2.64e+03 | 8.64e+03 ± 8.04e+02 | 1.23e+03 ± 8.10e+01 | 7.06x ± 0.72x |
| 13 | cosine | 1536 | float64 | 0.00e+00 ± 0.00e+00 | 1.35e-11 ± 1.85e-11 | 1.34e+04 ± 1.27e+03 | 7.31e+03 ± 8.12e+02 | 1.98e+03 ± 2.02e+03 | 4.49x ± 1.01x |
| 14 | cosine | 1536 | int8 | 0.00e+00 ± 0.00e+00 | 4.20e-04 ± 4.88e-04 | 9.47e+03 ± 2.09e+03 | 1.01e+04 ± 1.11e+03 | 1.95e+03 ± 1.04e+02 | 5.19x ± 0.56x |
+----+--------+------+----------+---------------------+---------------------+---------------------+---------------------+---------------------+-----------------+
Arm: AWS Graviton 3
Baseline (without Newton-Raphson iteration)
+----+--------+------+----------+---------------------+---------------------+---------------------+---------------------+---------------------+-----------------+
| | Metric | NDim | DType | Baseline Error | SimSIMD Error | Accurate Duration | Baseline Duration | SimSIMD Duration | SimSIMD Speedup |
+----+--------+------+----------+---------------------+---------------------+---------------------+---------------------+---------------------+-----------------+
| 0 | cosine | 11 | bfloat16 | 1.15e-08 ± 1.76e-08 | 3.27e-08 ± 2.04e-08 | 2.15e+04 ± 9.10e+02 | 2.34e+04 ± 7.89e+02 | 2.68e+03 ± 2.09e+03 | 9.91x ± 2.11x |
| 1 | cosine | 11 | float16 | 1.36e-04 ± 2.03e-04 | 1.30e-04 ± 1.46e-04 | 2.12e+04 ± 1.69e+03 | 2.57e+04 ± 3.00e+03 | 9.60e+02 ± 7.11e+01 | 26.76x ± 2.32x |
| 2 | cosine | 11 | float32 | 1.87e-08 ± 1.99e-08 | 3.84e-04 ± 4.15e-04 | 2.08e+04 ± 1.68e+03 | 2.35e+04 ± 3.43e+03 | 8.79e+02 ± 8.79e+01 | 26.85x ± 2.98x |
| 3 | cosine | 11 | float64 | 0.00e+00 ± 0.00e+00 | 6.10e-04 ± 1.27e-03 | 2.50e+04 ± 1.20e+04 | 1.55e+04 ± 1.45e+03 | 1.24e+03 ± 7.82e+02 | 13.99x ± 2.64x |
| 4 | cosine | 11 | int8 | 0.00e+00 ± 0.00e+00 | 2.37e-08 ± 1.57e-08 | 1.59e+04 ± 7.21e+02 | 1.64e+04 ± 3.14e+03 | 1.48e+03 ± 2.97e+02 | 11.38x ± 2.44x |
| 5 | cosine | 97 | bfloat16 | 5.98e-09 ± 6.36e-09 | 2.19e-08 ± 1.39e-08 | 2.14e+04 ± 7.54e+02 | 2.35e+04 ± 1.15e+03 | 2.31e+03 ± 3.70e+02 | 10.45x ± 2.11x |
| 6 | cosine | 97 | float16 | 3.40e-05 ± 2.87e-05 | 5.63e-05 ± 4.57e-05 | 2.13e+04 ± 2.05e+03 | 2.67e+04 ± 1.57e+03 | 9.43e+02 ± 6.95e+01 | 28.48x ± 2.48x |
| 7 | cosine | 97 | float32 | 9.55e-09 ± 7.13e-09 | 9.71e-05 ± 1.50e-04 | 2.06e+04 ± 1.38e+03 | 2.27e+04 ± 1.03e+03 | 8.77e+02 ± 6.84e+01 | 26.02x ± 2.05x |
| 8 | cosine | 97 | float64 | 0.00e+00 ± 0.00e+00 | 1.31e-04 ± 1.89e-04 | 2.07e+04 ± 2.22e+03 | 1.53e+04 ± 1.25e+03 | 1.06e+03 ± 1.08e+02 | 14.52x ± 1.51x |
| 9 | cosine | 97 | int8 | 0.00e+00 ± 0.00e+00 | 2.06e-08 ± 1.53e-08 | 1.59e+04 ± 2.18e+03 | 1.58e+04 ± 1.79e+02 | 1.37e+03 ± 2.12e+02 | 11.81x ± 1.86x |
| 10 | cosine | 1536 | bfloat16 | 1.76e-09 ± 1.55e-09 | 2.07e-08 ± 1.44e-08 | 2.84e+04 ± 1.25e+03 | 2.77e+04 ± 7.20e+02 | 3.20e+03 ± 3.70e+02 | 8.80x ± 1.12x |
| 11 | cosine | 1536 | float16 | 8.31e-06 ± 7.39e-06 | 4.23e-05 ± 3.41e-05 | 2.50e+04 ± 1.64e+03 | 5.42e+04 ± 2.20e+03 | 1.22e+03 ± 1.41e+02 | 44.85x ± 4.00x |
| 12 | cosine | 1536 | float32 | 2.64e-09 ± 1.97e-09 | 2.61e-05 ± 3.13e-05 | 2.44e+04 ± 3.00e+03 | 2.57e+04 ± 1.63e+03 | 1.25e+03 ± 1.69e+02 | 20.80x ± 1.94x |
| 13 | cosine | 1536 | float64 | 0.00e+00 ± 0.00e+00 | 1.59e-05 ± 1.54e-05 | 2.47e+04 ± 1.90e+03 | 1.90e+04 ± 1.83e+03 | 1.90e+03 ± 4.16e+02 | 10.32x ± 1.99x |
| 14 | cosine | 1536 | int8 | 0.00e+00 ± 0.00e+00 | 2.04e-08 ± 1.39e-08 | 1.90e+04 ± 1.38e+03 | 2.15e+04 ± 3.39e+02 | 1.48e+03 ± 2.51e+02 | 14.91x ± 2.54x |
+----+--------+------+----------+---------------------+---------------------+---------------------+---------------------+---------------------+-----------------+
With 2 Iterations
+----+--------+------+----------+---------------------+---------------------+---------------------+---------------------+---------------------+-----------------+
| | Metric | NDim | DType | Baseline Error | SimSIMD Error | Accurate Duration | Baseline Duration | SimSIMD Duration | SimSIMD Speedup |
+----+--------+------+----------+---------------------+---------------------+---------------------+---------------------+---------------------+-----------------+
| 0 | cosine | 11 | bfloat16 | 1.54e-08 ± 2.76e-08 | 2.94e-08 ± 2.62e-08 | 2.09e+04 ± 1.18e+03 | 2.37e+04 ± 2.20e+03 | 2.16e+03 ± 4.40e+02 | 11.41x ± 2.71x |
| 1 | cosine | 11 | float16 | 1.32e-04 ± 1.43e-04 | 1.96e-04 ± 2.77e-04 | 2.19e+04 ± 1.90e+03 | 2.59e+04 ± 4.38e+03 | 9.59e+02 ± 9.05e+01 | 27.03x ± 3.00x |
| 2 | cosine | 11 | float32 | 3.44e-08 ± 4.95e-08 | 2.11e-08 ± 2.49e-08 | 2.11e+04 ± 1.36e+03 | 2.37e+04 ± 4.08e+03 | 8.57e+02 ± 8.07e+01 | 27.65x ± 3.71x |
| 3 | cosine | 11 | float64 | 0.00e+00 ± 0.00e+00 | 8.65e-12 ± 1.39e-11 | 2.52e+04 ± 1.22e+04 | 1.56e+04 ± 1.36e+03 | 1.32e+03 ± 7.76e+02 | 13.13x ± 2.61x |
| 4 | cosine | 11 | int8 | 0.00e+00 ± 0.00e+00 | 3.03e-08 ± 3.66e-08 | 1.61e+04 ± 1.03e+03 | 1.60e+04 ± 5.69e+02 | 1.58e+03 ± 3.06e+02 | 10.39x ± 1.62x |
| 5 | cosine | 97 | bfloat16 | 5.22e-09 ± 4.67e-09 | 2.43e-08 ± 1.48e-08 | 2.12e+04 ± 8.81e+02 | 2.38e+04 ± 1.98e+03 | 2.13e+03 ± 4.24e+02 | 11.58x ± 2.17x |
| 6 | cosine | 97 | float16 | 3.17e-05 ± 3.81e-05 | 6.11e-05 ± 5.12e-05 | 2.15e+04 ± 1.56e+03 | 2.70e+04 ± 2.32e+03 | 9.84e+02 ± 9.83e+01 | 27.66x ± 3.59x |
| 7 | cosine | 97 | float32 | 7.65e-09 ± 6.03e-09 | 8.76e-09 ± 5.92e-09 | 2.14e+04 ± 1.90e+03 | 2.31e+04 ± 1.93e+03 | 9.10e+02 ± 8.64e+01 | 25.54x ± 3.07x |
| 8 | cosine | 97 | float64 | 0.00e+00 ± 0.00e+00 | 1.48e-12 ± 2.76e-12 | 2.11e+04 ± 1.81e+03 | 1.53e+04 ± 6.54e+02 | 1.15e+03 ± 1.13e+02 | 13.34x ± 1.24x |
| 9 | cosine | 97 | int8 | 0.00e+00 ± 0.00e+00 | 2.29e-08 ± 1.49e-08 | 1.60e+04 ± 2.33e+03 | 1.61e+04 ± 2.06e+03 | 1.41e+03 ± 2.06e+02 | 11.64x ± 1.95x |
| 10 | cosine | 1536 | bfloat16 | 2.04e-09 ± 1.61e-09 | 2.09e-08 ± 1.50e-08 | 2.84e+04 ± 1.13e+03 | 2.81e+04 ± 1.77e+03 | 2.98e+03 ± 4.43e+02 | 9.62x ± 1.47x |
| 11 | cosine | 1536 | float16 | 8.23e-06 ± 8.19e-06 | 4.81e-05 ± 3.38e-05 | 2.57e+04 ± 2.31e+03 | 5.45e+04 ± 2.38e+03 | 1.23e+03 ± 1.69e+02 | 44.93x ± 5.18x |
| 12 | cosine | 1536 | float32 | 2.41e-09 ± 1.51e-09 | 9.02e-09 ± 7.16e-09 | 2.53e+04 ± 3.03e+03 | 2.59e+04 ± 1.42e+03 | 1.46e+03 ± 2.97e+02 | 18.45x ± 3.57x |
| 13 | cosine | 1536 | float64 | 0.00e+00 ± 0.00e+00 | 2.23e-13 ± 4.67e-13 | 2.57e+04 ± 4.05e+03 | 1.87e+04 ± 9.64e+02 | 2.27e+03 ± 6.14e+02 | 8.75x ± 1.99x |
| 14 | cosine | 1536 | int8 | 0.00e+00 ± 0.00e+00 | 2.32e-08 ± 1.39e-08 | 1.92e+04 ± 2.52e+03 | 2.17e+04 ± 4.07e+02 | 1.48e+03 ± 2.82e+02 | 15.13x ± 2.60x |
+----+--------+------+----------+---------------------+---------------------+---------------------+---------------------+---------------------+-----------------+