Skip to content

Commit

Permalink
feat(Data/Nat): faster computation of Nat.log (#17325)
Browse files Browse the repository at this point in the history
Co-authored-by: Eric Wieser <[email protected]>
  • Loading branch information
b-mehta and eric-wieser committed Jan 21, 2025
1 parent c01816f commit 59b91d3
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 0 deletions.
113 changes: 113 additions & 0 deletions Mathlib/Data/Nat/Log.lean
Original file line number Diff line number Diff line change
Expand Up @@ -330,4 +330,117 @@ theorem log_le_clog (b n : ℕ) : log b n ≤ clog b n := by
exact (Nat.pow_le_pow_iff_right hb).1
((pow_log_le_self b n.succ_ne_zero).trans <| le_pow_clog hb _)

/-! ### Computing the logarithm efficiently -/
section computation

/-- Decrease lemma used for the termination of `Nat.logC.step`: when `1 < b` and `b ≤ m`,
dividing `m` by `b * b` yields a strictly smaller quotient than dividing `m` by `b`. -/
private lemma logC_aux {m b : ℕ} (hb : 1 < b) (hbm : b ≤ m) : m / (b * b) < m / b := by
  have hpos : 0 < b := zero_lt_of_lt hb
  -- `b ≤ m` guarantees that the quotient `m / b` is positive.
  have hquot : 0 < m / b := Nat.div_pos hbm hpos
  -- Turn the goal into `m / b < m / b * b`.
  rw [div_lt_iff_lt_mul (Nat.mul_pos hpos hpos), ← Nat.mul_assoc, ← div_lt_iff_lt_mul hpos]
  exact (Nat.lt_mul_iff_one_lt_right hquot).mpr hb

-- The unused-variables linter has to be disabled here because of lean4#2920
set_option linter.unusedVariables false in
/--
A version of `Nat.log` intended for fast evaluation. For mathematical statements and proofs,
work with `Nat.log` and use `Nat.log_eq_logC` to relate the two.

A tail-recursive implementation of `Nat.log` can also be written:
```
def logTR (b n : ℕ) : ℕ :=
  let rec go : ℕ → ℕ → ℕ | n, acc => if h : b ≤ n ∧ 1 < b then go (n / b) (acc + 1) else acc
  decreasing_by
    have : n / b < n := Nat.div_lt_self (by omega) h.2
    decreasing_trivial
  go n 0
```
but it performs worse than `Nat.logC` for large numbers:
```
#eval Nat.logTR 2 (2 ^ 1000000)
#eval Nat.logC 2 (2 ^ 1000000)
```
`Nat.logC` itself is not tail-recursive. However, the stack limit will only be reached if the
output size is around 2^10000, meaning the input will be around 2^(2^10000), which will take far
too long to compute in the first place.

Adapted from https://downloads.haskell.org/~ghc/9.0.1/docs/html/libraries/ghc-bignum-1.0/GHC-Num-BigNat.html#v:bigNatLogBase-35-
-/
@[pp_nodot] def logC (b m : ℕ) : ℕ :=
  -- For `b ≤ 1` the result is `0`; otherwise keep only the exponent returned by `step`.
  if hb : 1 < b then let (_, e) := step b hb; e else 0 where
  /--
  Helper for `Nat.logC` in which the base `pw` is _squared_ on every recursive call. Squaring
  makes the computation significantly faster: it takes logarithmic time in the size of the
  output, rather than linear time.

  The result is a pair `(q, e)`; `logC_step` shows that `pw ^ e * q ≤ m < pw ^ e * (q + 1)` and
  `q < pw`.
  -/
  step (pw : ℕ) (hpw : 1 < pw) : ℕ × ℕ :=
    if h : m < pw then (m, 0)
    else
      -- First solve the problem for the squared base, then refine the answer by one bit.
      let (quot, exp) := step (pw * pw) (Nat.mul_lt_mul_of_lt_of_lt hpw hpw)
      if quot < pw then (quot, 2 * exp) else (quot / pw, 2 * exp + 1)
  termination_by m / pw
  decreasing_by
    have : m / (pw * pw) < m / pw := logC_aux hpw (le_of_not_lt h)
    decreasing_trivial

/-- Specification of `logC.step`: if `logC.step m pw hpw = (q, e)` then `q` and `e` satisfy
`pw ^ e * q ≤ m < pw ^ e * (q + 1)` and `q < pw`, and `q` is positive whenever `m` is.
The proof follows the three branches of the functional induction principle of `logC.step`. -/
private lemma logC_step {m pw q e : ℕ} (hpw : 1 < pw) (hqe : logC.step m pw hpw = (q, e)) :
    pw ^ e * q ≤ m ∧ q < pw ∧ (m < pw ^ e * (q + 1)) ∧ (0 < m → 0 < q) := by
  induction pw, hpw using logC.step.induct m generalizing q e with
  | case1 pw hpw hmpw =>
    -- Base case `m < pw`: the step returns `(m, 0)`.
    rw [logC.step, dif_pos hmpw] at hqe
    cases hqe
    simpa
  | case2 pw hpw hmpw q' e' hqe' hqpw ih =>
    -- Recursive case with `q' < pw`: the step returns `(q', 2 * e')`.
    simp only [logC.step, dif_neg hmpw, hqe', if_pos hqpw] at hqe
    cases hqe
    rw [Nat.pow_mul, Nat.pow_two]
    exact ⟨(ih hqe').1, hqpw, (ih hqe').2.2⟩
  | case3 pw hpw hmpw q' e' hqe' hqpw ih =>
    -- Recursive case with `pw ≤ q'`: the step returns `(q' / pw, 2 * e' + 1)`.
    simp only [Nat.logC.step, dif_neg hmpw, hqe', if_neg hqpw] at hqe
    cases hqe
    rw [Nat.pow_succ, Nat.mul_assoc, Nat.pow_mul, Nat.pow_two, Nat.mul_assoc]
    refine ⟨(ih hqe').1.trans' (Nat.mul_le_mul_left _ (Nat.mul_div_le _ _)),
      Nat.div_lt_of_lt_mul (ih hqe').2.1, (ih hqe').2.2.1.trans_le ?_,
      fun _ => Nat.div_pos (le_of_not_lt hqpw) (by omega)⟩
    exact Nat.mul_le_mul_left _ (Nat.lt_mul_div_succ _ (zero_lt_of_lt hpw))

/-- For a base `1 < b` and a positive `m`, the value `logC b m` is an exponent `k` with
`b ^ k ≤ m < b ^ (k + 1)`. -/
private lemma logC_spec {b m : ℕ} (hb : 1 < b) (hm : 0 < m) :
    b ^ logC b m ≤ m ∧ m < b ^ (logC b m + 1) := by
  rw [logC, dif_pos hb]
  split
  next q e heq =>
    -- Unpack the four facts that `logC_step` provides about the pair `(q, e)`.
    obtain ⟨hlower, hq_lt, hupper, hq_pos⟩ := logC_step hb heq
    refine ⟨?_, ?_⟩
    · exact hlower.trans' (Nat.le_mul_of_pos_right _ (hq_pos hm))
    · exact hupper.trans_le (Nat.mul_le_mul_left _ hq_lt)

/-- When the base satisfies `b ≤ 1`, `logC b m` is `0`. -/
private lemma logC_of_left_le_one {b m : ℕ} (hb : b ≤ 1) : logC b m = 0 := by
  -- The guard `1 < b` in the definition of `logC` fails, so the `else` branch is taken.
  have hnot : ¬1 < b := hb.not_lt
  rw [logC, dif_neg hnot]

/-- `logC b 0 = 0` for every base `b`. -/
private lemma logC_zero {b : ℕ} : logC b 0 = 0 := by
  rcases le_or_lt b 1 with hle | hlt
  · -- Base at most one: `logC` is `0` regardless of the argument.
    exact logC_of_left_le_one hle
  · -- Base greater than one: `step` stops immediately because `0 < b`.
    rw [logC, dif_pos hlt]
    split
    next q e heq =>
      rw [logC.step, dif_pos (zero_lt_of_lt hlt)] at heq
      rw [(Prod.mk.inj heq).2]

/--
`Nat.log` and `Nat.logC` are the same function. `Nat.logC` is computed more efficiently, while
`Nat.log` is easier to prove things about and has more lemmas.
This lemma is tagged `@[csimp]` so that the code generated for `Nat.log` uses `Nat.logC` instead.
-/
@[csimp] theorem log_eq_logC : log = logC := by
  ext b m
  rcases le_or_lt b 1 with hle | hlt
  · -- Base at most one: both sides are `0`.
    rw [logC_of_left_le_one hle, Nat.log_of_left_le_one hle]
  · rcases eq_or_ne m 0 with rfl | hm
    · -- Argument zero: both sides are `0`.
      rw [Nat.log_zero_right, logC_zero]
    · -- Main case: `logC b m` satisfies the bounds that characterise `log b m`.
      rw [Nat.log_eq_iff (Or.inr ⟨hlt, hm⟩)]
      exact logC_spec hlt (zero_lt_of_ne_zero hm)

end computation

end Nat
3 changes: 3 additions & 0 deletions test/Nat/log.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import Mathlib.Data.Nat.Log

-- Evaluates `Nat.log` on a very large power of two. `Nat.log_eq_logC` is tagged `@[csimp]`, so
-- the compiled code for `Nat.log` is expected to run the faster `Nat.logC` implementation.
#eval Nat.log 2 (2 ^ 10000000)

0 comments on commit 59b91d3

Please sign in to comment.