feat(Data/Nat): faster computation of Nat.log #17325

Open
wants to merge 13 commits into
base: master
113 changes: 113 additions & 0 deletions Mathlib/Data/Nat/Log.lean
@@ -319,4 +319,117 @@ theorem log_le_clog (b n : ℕ) : log b n ≤ clog b n := by
exact (Nat.pow_le_pow_iff_right hb).1
((pow_log_le_self b n.succ_ne_zero).trans <| le_pow_clog hb _)

/-! ### Computing the logarithm efficiently -/
section computation

private lemma logC_aux {m b : ℕ} (hb : 1 < b) (hbm : b ≤ m) : m / (b * b) < m / b := by
  have hb' : 0 < b := zero_lt_of_lt hb
  rw [div_lt_iff_lt_mul (Nat.mul_pos hb' hb'), ← Nat.mul_assoc, ← div_lt_iff_lt_mul hb']
  exact (Nat.lt_mul_iff_one_lt_right (Nat.div_pos hbm hb')).2 hb

-- This option is necessary because of lean4#2920
set_option linter.unusedVariables false in
/--
An alternate definition for `Nat.log` which computes more efficiently. For mathematical purposes,
use `Nat.log` instead, and see `Nat.log_eq_logC`.

Note that a tail-recursive version of `Nat.log` is also possible:
```
def logTR (b n : ℕ) : ℕ :=
  let rec go : ℕ → ℕ → ℕ | n, acc => if h : b ≤ n ∧ 1 < b then go (n / b) (acc + 1) else acc
  decreasing_by
    have : n / b < n := Nat.div_lt_self (by omega) h.2
    decreasing_trivial
  go n 0
```
but it performs worse for large numbers than `Nat.logC`:
```
#eval Nat.logTR 2 (2 ^ 1000000)
#eval Nat.logC 2 (2 ^ 1000000)
Comment on lines +358 to +359
Member

@kim-em, what's the SOTA for committing benchmarks / performance tests like these? It would be great to have a test file that computes a big logarithm, and even better if we can assert it is "quick".

Contributor Author

Now that I read it again, the point of this was that `Nat.logTR` times out on this computation whereas `Nat.logC` doesn't, and `Nat.log` hits the stack limit at something like `Nat.log 2 (2 ^ 11000)`.
I wouldn't mind removing these lines either; it's largely to save a future person's time if they think "I wonder if tail-recursion would be better than this weird doubling algorithm".

Member

My comment really was "let's add a test somewhere of something that didn't work before your change"

Contributor

I'm inclined to just create tests/nat/log.lean, if these are meant to be regression tests.

If they are documentation, they should be code blocks inside a module doc next to the definition.

Contributor Author

I don't have a strong preference here. What's the distinction between having a code block as a module doc as opposed to a code block in the definition docstring?

Member

Do you mean in particular that these aren't meant as regression tests?

Member

> I'm inclined to just create tests/nat/log.lean, if these are meant to be regression tests.

We should add this file, with an invocation that took >20s before and now takes <1s.

```

The definition `Nat.logC` is not tail-recursive. However, the stack limit will only be reached
if the output size is around 2^10000, meaning the input will be around 2^(2^10000), which would
take far too long to compute in the first place.

Adapted from https://downloads.haskell.org/~ghc/9.0.1/docs/html/libraries/ghc-bignum-1.0/GHC-Num-BigNat.html#v:bigNatLogBase-35-
-/
@[pp_nodot] def logC (b m : ℕ) : ℕ :=
  if h : 1 < b then let (_, e) := step b h; e else 0 where
Member

@urkud urkud Nov 18, 2024

Optional (I didn't test if it results in the same IR):

Suggested change
- if h : 1 < b then let (_, e) := step b h; e else 0 where
+ if h : 1 < b then (step b h).2 else 0 where

  /--
  An auxiliary definition for `Nat.logC`, where the base of the logarithm is _squared_ in each
  loop. This allows significantly faster computation of the logarithm: it takes logarithmic time
  in the size of the output, rather than linear time.
  -/
  step (pw : ℕ) (hpw : 1 < pw) : ℕ × ℕ :=
    if h : m < pw then (m, 0)
    else
      let (q, e) := step (pw * pw) (Nat.mul_lt_mul_of_lt_of_lt hpw hpw)
      if q < pw then (q, 2 * e) else (q / pw, 2 * e + 1)
  termination_by m / pw
  decreasing_by
    have : m / (pw * pw) < m / pw := logC_aux hpw (le_of_not_lt h)
    decreasing_trivial
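
To make the squaring loop in `step` concrete, here is a hand trace of `Nat.logC 2 100`, written as Lean comments. It is only an illustrative sketch: it assumes the definitions from this PR are available through `Mathlib.Data.Nat.Log`, and the input values 2 and 100 are chosen purely as an example.

```lean
import Mathlib.Data.Nat.Log

-- Hand trace of `Nat.logC 2 100`, following `logC.step` with m = 100:
--   step 2     : 100 ≥ 2,   so recurse with pw = 2 * 2 = 4
--   step 4     : 100 ≥ 4,   so recurse with pw = 4 * 4 = 16
--   step 16    : 100 ≥ 16,  so recurse with pw = 16 * 16 = 256
--   step 256   : 100 < 256, so return (100, 0)
--   back at 16 : q = 100 ≥ 16, return (100 / 16, 2 * 0 + 1) = (6, 1)
--   back at 4  : q = 6 ≥ 4,    return (6 / 4, 2 * 1 + 1)    = (1, 3)
--   back at 2  : q = 1 < 2,    return (1, 2 * 3)            = (1, 6)
-- The exponent component is 6, and indeed 2 ^ 6 = 64 ≤ 100 < 128 = 2 ^ 7.
#eval Nat.logC 2 100  -- 6
```

Only four calls to `step` are needed here, whereas repeated division by 2 would take six iterations. The gap widens as the exponent grows.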

private lemma logC_step {m pw q e : ℕ} (hpw : 1 < pw) (hqe : logC.step m pw hpw = (q, e)) :
    pw ^ e * q ≤ m ∧ q < pw ∧ (m < pw ^ e * (q + 1)) ∧ (0 < m → 0 < q) := by
  induction pw, hpw using logC.step.induct m generalizing q e with
  | case1 pw hpw hmpw =>
      rw [logC.step, dif_pos hmpw] at hqe
      cases hqe
      simpa
  | case2 pw hpw hmpw q' e' hqe' hqpw ih =>
      simp only [logC.step, dif_neg hmpw, hqe', if_pos hqpw] at hqe
      cases hqe
      rw [Nat.pow_mul, Nat.pow_two]
      exact ⟨(ih hqe').1, hqpw, (ih hqe').2.2⟩
  | case3 pw hpw hmpw q' e' hqe' hqpw ih =>
      simp only [Nat.logC.step, dif_neg hmpw, hqe', if_neg hqpw] at hqe
      cases hqe
      rw [Nat.pow_succ, Nat.mul_assoc, Nat.pow_mul, Nat.pow_two, Nat.mul_assoc]
      refine ⟨(ih hqe').1.trans' (Nat.mul_le_mul_left _ (Nat.mul_div_le _ _)),
        Nat.div_lt_of_lt_mul (ih hqe').2.1, (ih hqe').2.2.1.trans_le ?_,
        fun _ => Nat.div_pos (le_of_not_lt hqpw) (by omega)⟩
      exact Nat.mul_le_mul_left _ (Nat.lt_mul_div_succ _ (zero_lt_of_lt hpw))
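
The four-part invariant proved by `logC_step` can be sanity-checked on concrete numbers. The sketch below instantiates it at `m = 100` with two pairs obtained by unfolding `step` by hand, `(q, e) = (6, 1)` for `pw = 16` and `(q, e) = (1, 6)` for `pw = 2`. These values are illustrative assumptions and are not taken from the PR.

```lean
-- Invariant `pw ^ e * q ≤ m ∧ q < pw ∧ m < pw ^ e * (q + 1) ∧ (0 < m → 0 < q)`
-- at m = 100, pw = 16, (q, e) = (6, 1): 96 ≤ 100 < 112.
example : 16 ^ 1 * 6 ≤ 100 ∧ 6 < 16 ∧ 100 < 16 ^ 1 * (6 + 1) ∧ (0 < 100 → 0 < 6) := by
  decide

-- The same invariant at m = 100, pw = 2, (q, e) = (1, 6): 64 ≤ 100 < 128.
example : 2 ^ 6 * 1 ≤ 100 ∧ 1 < 2 ∧ 100 < 2 ^ 6 * (1 + 1) ∧ (0 < 100 → 0 < 1) := by
  decide
```

Since `0 < q` and `q < pw`, the first and third conjuncts give `pw ^ e ≤ m < pw ^ (e + 1)`, which is the statement of `logC_spec` when `pw = b`.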

private lemma logC_spec {b m : ℕ} (hb : 1 < b) (hm : 0 < m) :
    b ^ logC b m ≤ m ∧ m < b ^ (logC b m + 1) := by
  rw [logC, dif_pos hb]
  split
  next q e heq =>
    obtain ⟨h₁, h₂, h₃, h₄⟩ := logC_step hb heq
    exact ⟨h₁.trans' (Nat.le_mul_of_pos_right _ (h₄ hm)), h₃.trans_le (Nat.mul_le_mul_left _ h₂)⟩

private lemma logC_of_left_le_one {b m : ℕ} (hb : b ≤ 1) : logC b m = 0 := by
  rw [logC, dif_neg hb.not_lt]

private lemma logC_zero {b : ℕ} :
    logC b 0 = 0 := by
  rcases le_or_lt b 1 with hb | hb
  case inl => exact logC_of_left_le_one hb
  case inr =>
    rw [logC, dif_pos hb]
    split
    next q e heq =>
      rw [logC.step, dif_pos (zero_lt_of_lt hb)] at heq
      rw [(Prod.mk.inj heq).2]

/--
The result of `Nat.logC` agrees with the result of `Nat.log`. The former will be computed more
efficiently, but the latter is easier to prove things about and has more lemmas.
This lemma is tagged `@[csimp]` so that the code generated for `Nat.log` uses `Nat.logC` instead.
-/
@[csimp] theorem log_eq_logC : log = logC := by
  ext b m
  rcases le_or_lt b 1 with hb | hb
  case inl => rw [logC_of_left_le_one hb, Nat.log_of_left_le_one hb]
  case inr =>
    rcases eq_or_ne m 0 with rfl | hm
    case inl => rw [Nat.log_zero_right, logC_zero]
    case inr =>
      rw [Nat.log_eq_iff (Or.inr ⟨hb, hm⟩)]
      exact logC_spec hb (zero_lt_of_ne_zero hm)

end computation

end Nat
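
The review thread asks for a regression test of something that did not work before this change. A minimal sketch of such a file follows. The path `tests/nat/log.lean` is the one floated in review and is not confirmed by the diff, and the sketch assumes that `log_eq_logC` is tagged `@[csimp]` so that evaluating `Nat.log` runs the `Nat.logC` code.

```lean
import Mathlib.Data.Nat.Log

/-!
Hypothetical contents of `tests/nat/log.lean`.
Per the review discussion, `Nat.log` hit the stack limit on inputs of roughly this
size before the `@[csimp]` lemma `Nat.log_eq_logC` was added.
-/

/-- info: 1000000 -/
#guard_msgs in
#eval Nat.log 2 (2 ^ 1000000)
```

`#guard_msgs` only checks the printed value. It does not assert that the evaluation is "quick", so a timing check would need separate tooling.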