Zulip Chat Archive
Stream: lean4
Topic: Fast Modexp
Niels Voss (Jun 14 2023 at 22:23):
I spent a few hours writing a verified fast modular exponentiation algorithm. I'm sure someone else has done this before, but I just thought I would share.
import Mathlib.Data.Nat.Parity
import Mathlib.Tactic.Linarith
def fast_modexp_aux (base running modulus : ℕ) : ℕ → ℕ
| 0 => running
| (exponent+1) =>
  have : (exponent+1) / 2 < exponent + 1 := by
    apply Nat.div_lt_of_lt_mul
    linarith
  if Even exponent then
    fast_modexp_aux (base * base % modulus) (running * base % modulus) modulus ((exponent + 1) / 2)
  else
    fast_modexp_aux (base * base % modulus) running modulus ((exponent + 1) / 2)
def fast_modexp (base exponent modulus : ℕ) : ℕ :=
  match modulus with
  | 0 => base ^ exponent
  | 1 => 0
  | m => fast_modexp_aux (base % m) 1 m exponent
lemma fast_modexp_aux_spec (b r m e : ℕ) (h : r < m)
    : fast_modexp_aux b r m e = r * b ^ e % m := by
  revert b r m
  induction' e using Nat.case_strong_induction_on with e ih
  · intros b r m h
    unfold fast_modexp_aux
    simp [Nat.mod_eq_of_lt h]
  · intros b r m h₁
    have hlt : (e + 1) / 2 ≤ e := by
      apply Nat.le_of_lt_succ
      apply Nat.div_lt_of_lt_mul
      linarith
    unfold fast_modexp_aux
    by_cases he : Even e
    · rw [if_pos he]
      have he' : (e + 1) / 2 * 2 = e := by
        simpa using Nat.div_two_mul_two_add_one_of_odd (Even.add_odd he odd_one)
      change fast_modexp_aux (b * b % m) (r * b % m) m ((e + 1) / 2) = _
      rw [ih ((e + 1) / 2) hlt (b * b % m) (r * b % m) m (Nat.mod_lt _ (Nat.zero_lt_of_lt h₁))]
      rw [Nat.mul_mod (r * b % m), Nat.mod_mod, ←Nat.pow_mod]
      rw [←sq, ←pow_mul, mul_comm 2]
      rw [he']
      rw [←Nat.mul_mod, Nat.pow_succ]
      ring_nf
    · rw [if_neg he]
      have he' : (e + 1) / 2 * 2 = e + 1 := by
        rw [←Nat.odd_iff_not_even] at he
        exact Nat.div_two_mul_two_of_even (Odd.add_odd he odd_one)
      change fast_modexp_aux (b * b % m) r m ((e + 1) / 2) = _
      rw [ih ((e + 1) / 2) hlt (b * b % m) r m h₁]
      rw [Nat.mul_mod r, ←Nat.pow_mod]
      rw [←sq, ←pow_mul, mul_comm 2]
      rw [he']
      exact (Nat.mul_mod _ _ _).symm
theorem fast_modexp_spec (b e m : ℕ) : fast_modexp b e m = (b ^ e) % m := by
  rcases m with m | m | m
  · change b ^ e = _
    simp
  · change 0 = _
    exact (Nat.mod_one _).symm
  · change fast_modexp_aux _ _ _ _ = _
    rw [fast_modexp_aux_spec (b % (m + 2)) 1 (m + 2) e (by linarith)]
    rw [one_mul, ←Nat.pow_mod]
#eval fast_modexp 3081683040978812497613678567591472736846636481039624945710
  9688672230569954477594568460243311380476614093418427219148
  3862468819539322552199292737141641904852335676904946310429
-- 1604949739222115423721086562038072677142980288837751825385
It should be able to handle fairly large numbers. I haven't profiled it because I'm not really sure how to. Is there a version of fast modexp already in Mathlib?
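A note on why the recursion is correct: the lemma fast_modexp_aux_spec above amounts to the invariant below, with b, r, m, e as in its statement. The trace at the end uses the inputs 123, 10, 916 from the benchmarks later in the thread; it is worked by hand as an illustration and is not output from the author's code.

```latex
% Invariant proved by fast_modexp_aux_spec (for r < m):
\operatorname{aux}(b, r, m, e) = r \cdot b^{e} \bmod m

% Odd exponent e = 2k + 1: fold one factor of b into r, then square the base.
(r\,b)\,(b^{2})^{k} = r\,b^{2k+1} = r\,b^{e}

% Even exponent e = 2k, k \ge 1: only square the base.
r\,(b^{2})^{k} = r\,b^{2k} = r\,b^{e}

% Hand trace of (b, r, e) for 123^{10} \bmod 916:
(123,\,1,\,10) \to (473,\,1,\,5) \to (225,\,473,\,2) \to (245,\,473,\,1)
\to r = 473 \cdot 245 \bmod 916 = 469
```

Each step halves the exponent, so the number of multiplications grows with the number of bits of e rather than with e itself.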
James Gallicchio (Jun 14 2023 at 23:03):
You can use timeit for doing pretty easy timing
Niels Voss (Jun 14 2023 at 23:09):
Is there an example of how to use timeit? I tried
#eval timeit "fast modexp" (do pure (fast_modexp 3123 384016810 131231))
But this is not timing the computation because it always reports 0.0002ms
James Gallicchio (Jun 14 2023 at 23:11):
I think you need to prevent it from doing constant lifting. There's an option set_option compiler.extract_closed false that you can try?
James Gallicchio (Jun 14 2023 at 23:12):
alternatively, you can wrap it in a little function that takes the arguments so that the call is no longer a closed expression
Niels Voss (Jun 14 2023 at 23:13):
Both #evals still print 0.0002 ms
set_option compiler.extract_closed false
#eval timeit "fast modexp" (do pure (fast_modexp 3123 1243345244 131231))
#eval timeit "normal" (do pure (3 ^ 100000))
Niels Voss (Jun 14 2023 at 23:16):
I finally got it to work. Of these three commands, the third is the only one which returns a realistic time.
def run_modexp (b e m : ℕ) : IO ℕ := do
  pure (fast_modexp b e m)
#eval timeit "Fast modexp" (do pure (fast_modexp 213 1223423422342525235235 123))
#eval timeit "Fast modexp" (do pure (run_modexp 213 1231251251252142123523523643623442342345 123))
#eval timeit "Fast modexp" (run_modexp 213 1231251251252142123523523643623442342345 123)
James Gallicchio (Jun 14 2023 at 23:23):
I assume it's much faster than the standard Nat functions? :joy:
Niels Voss (Jun 14 2023 at 23:25):
Well actually only for very large exponents. I might be timing this wrong but I think fast_modexp seems to take around 0.2 ms no matter how large the exponent is, whereas doing it normally is faster for small values (I think below 10000?). I have no experience benchmarking things so I could very well be wrong.
Niels Voss (Jun 14 2023 at 23:35):
[image.png: screenshot of the benchmark output]
Here's the output of a quick benchmark
def test_cases : List (ℕ × ℕ × ℕ) :=
  [
    (123, 100, 916),
    (123, 1000, 916),
    (123, 10000, 916),
    (123, 100000, 916),
    (123, 1000000, 916)
  ]
def run_fast_modexp (b e m : ℕ) : IO ℕ := do
  pure (fast_modexp b e m)
def run_normal_modexp (b e m : ℕ) : IO ℕ := do
  pure (b ^ e % m)
def test : IO Unit := do
  for case in test_cases do
    let (b, e, m) := case
    let fast_result ← timeit s!"fast_modexp {b} {e} {m}" (run_fast_modexp b e m)
  for case in test_cases do
    let (b, e, m) := case
    let normal_result ← timeit s!"normal_modexp {b} {e} {m}" (run_normal_modexp b e m)
#eval test
Niels Voss (Jun 14 2023 at 23:51):
I guess there are also some optimizations I can make. Is dividing by 2 the same speed as bit shifting one time? And is checking whether something is Even the same speed as checking whether it is divisible by 2?
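On the halving and parity question, here is a minimal sketch, not from the thread, of the same loop written with `% 2` and a right shift on core `Nat`. The name `modexpShift` is hypothetical and the function is unverified. Whether it is measurably faster than `/ 2` and `Even` is something to benchmark rather than assume.

```lean
/-- Hypothetical, unverified variant: parity via `% 2`, halving via `>>> 1`.
The final `% m` makes the `m = 1` case return 0; with `m = 0`, Lean's
`x % 0 = x` means the result is the unreduced power. -/
def modexpShift (b e m : Nat) : Nat := Id.run do
  let mut base := b % m
  let mut acc := 1
  let mut exp := e
  while exp != 0 do
    if exp % 2 == 1 then
      acc := acc * base % m   -- lowest bit set: fold the current base into the result
    base := base * base % m   -- square the base for the next bit
    exp := exp >>> 1          -- drop the lowest bit
  return acc % m

#eval modexpShift 123 10 916  -- 469, matching the value reported later in the thread
```

Because this version uses a `while` loop, it is convenient for timing but cannot be reasoned about the way fast_modexp_aux can.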
Trebor Huang (Jun 16 2023 at 05:21):
Why is the first testcase significantly slower? Is some startup cost included there?
Niels Voss (Jun 16 2023 at 06:19):
I have no idea. I reran the test (also printing out what each function returned) and I got very different results.
[image.png: screenshot of the rerun's output]
Do you know a more reliable way to benchmark? I haven't ever done this before so I'm probably doing something wrong.
Niels Voss (Jun 16 2023 at 06:19):
In case images don't render
fast_modexp 123 100 916 0.169ms
661
fast_modexp 123 1000 916 0.0063ms
165
fast_modexp 123 10000 916 0.0073ms
289
fast_modexp 123 100000 916 0.0083ms
17
fast_modexp 123 1000000 916 0.0097ms
501
normal_modexp 123 100 916 0.0365ms
661
normal_modexp 123 1000 916 0.0256ms
165
normal_modexp 123 10000 916 4.03ms
289
normal_modexp 123 100000 916 294ms
17
normal_modexp 123 1000000 916 24.3s
501
Niels Voss (Jun 16 2023 at 06:24):
I think it is most likely some sort of startup cost, because if I shuffle the test cases around, whichever one I put first seems to take significantly longer for fast_modexp. It doesn't matter whether I test fast_modexp or normal_modexp first, though: the first test for fast_modexp is always the one that takes longer.
Mario Carneiro (Jun 16 2023 at 06:31):
you could put an extra sacrificial test case at the start
Mario Carneiro (Jun 16 2023 at 06:32):
BTW I would be very concerned about the compiler seeing through the test cases here
Mario Carneiro (Jun 16 2023 at 06:33):
one thing you could do is put the test_cases in an IO.Ref and then read from that ref
Niels Voss (Jun 16 2023 at 06:51):
I'm not sure how to use IO.Ref, as I haven't really done monadic programming before. Here's my current attempt:
def the_test_cases : List (ℕ × ℕ × ℕ) :=
  [
    (123, 100, 916),
    (123, 100, 916),
    (123, 1000, 916),
    (123, 10000, 916),
    (123, 100000, 916),
    (123, 1000000, 916)
  ]
def test_cases_ref : IO (IO.Ref (List (ℕ × ℕ × ℕ))) :=
  IO.mkRef the_test_cases
def run_fast_modexp (b e m : ℕ) : IO ℕ := do
  pure (fast_modexp b e m)
def run_normal_modexp (b e m : ℕ) : IO ℕ := do
  pure (b ^ e % m)
def test : IO Unit := do
  let test_cases ← ST.Ref.get (← test_cases_ref)
  for case in test_cases do
    let (b, e, m) := case
    let fast_result ← timeit s!"fast_modexp {b} {e} {m}" (run_fast_modexp b e m)
    println! fast_result
  for case in test_cases do
    let (b, e, m) := case
    let normal_result ← timeit s!"normal_modexp {b} {e} {m}" (run_normal_modexp b e m)
    println! normal_result
#eval test
This runs but has the same startup penalty. Should I be reading from the reference every loop? If so, would test_cases_ref need to be a List of references rather than a reference to a List?
Mario Carneiro (Jun 16 2023 at 06:52):
no I think this will work, this is what I meant
Mario Carneiro (Jun 16 2023 at 06:52):
what are the performance numbers for this?
Mario Carneiro (Jun 16 2023 at 06:53):
this was not meant to fix the startup cost, but rather to avoid the compiler getting smart and optimizing the calculation to a constant
Mario Carneiro (Jun 16 2023 at 06:55):
To fix the startup cost, you should try evaluating run_fast_modexp once before the main loop and not timing it
Mario Carneiro (Jun 16 2023 at 06:56):
that's what I meant by a "sacrificial test case"
Mario Carneiro (Jun 16 2023 at 06:56):
There are probably some values used in the calculation which are being initialized on first use
Niels Voss (Jun 16 2023 at 06:58):
Here are the results, I added a test case with an exponent of 10 at the start.
fast_modexp 123 10 916 0.189ms
469
fast_modexp 123 100 916 0.0062ms
661
fast_modexp 123 1000 916 0.007ms
165
fast_modexp 123 10000 916 0.0086ms
289
fast_modexp 123 100000 916 0.0102ms
17
fast_modexp 123 1000000 916 0.012ms
501
normal_modexp 123 10 916 0.0418ms
469
normal_modexp 123 100 916 0.0032ms
661
normal_modexp 123 1000 916 0.0288ms
165
normal_modexp 123 10000 916 4.09ms
289
normal_modexp 123 100000 916 300ms
17
normal_modexp 123 1000000 916 24.3s
501
Niels Voss (Jun 16 2023 at 07:01):
Here are the results in a slightly more readable format
| Exponent | fast_modexp | normal_modexp |
| 10 | 0.189 ms | 0.0418 ms |
| 100 | 0.0062 ms | 0.0032 ms |
| 1000 | 0.007 ms | 0.0288 ms |
| 10000 | 0.0086 ms | 4.09 ms |
| 100000 | 0.0102 ms | 300 ms |
| 1000000 | 0.012 ms | 24.3 s |
Mario Carneiro (Jun 16 2023 at 07:04):
Exponent | fast_modexp | normal_modexp |
---|---|---|
10 | 0.189 ms | 0.0418 ms |
100 | 0.0062 ms | 0.0032 ms |
1000 | 0.007 ms | 0.0288 ms |
10000 | 0.0086 ms | 4.09 ms |
100000 | 0.0102 ms | 300 ms |
1000000 | 0.012 ms | 24.3 s |
Mario Carneiro (Jun 16 2023 at 07:05):
does it not work to run it beforehand?
Niels Voss (Jun 16 2023 at 07:12):
Actually it does. I inserted println! s!"early run fast_modexp {fast_modexp 4 13 449}" as the first line in test and it seemed to mostly avoid the startup penalty, but not completely.
early run fast_modexp 426
fast_modexp 123 10 916 0.0335ms
469
fast_modexp 123 100 916 0.0043ms
661
fast_modexp 123 1000 916 0.005ms
165
fast_modexp 123 10000 916 0.0067ms
289
fast_modexp 123 100000 916 0.0079ms
17
fast_modexp 123 1000000 916 0.0091ms
501
normal_modexp 123 10 916 0.0396ms
469
normal_modexp 123 100 916 0.0017ms
661
normal_modexp 123 1000 916 0.0254ms
165
normal_modexp 123 10000 916 3.91ms
289
normal_modexp 123 100000 916 304ms
17
normal_modexp 123 1000000 916 24.7s
501
Niels Voss (Jun 16 2023 at 07:14):
At this point the measurements for fast_modexp are on the order of 4 to 30 microseconds, so I don't know how accurate my computer actually is with timing.
Mario Carneiro (Jun 16 2023 at 07:14):
how many times are you running it?
Niels Voss (Jun 16 2023 at 07:15):
Just once. Is there a way to have timeit run multiple times?
Mario Carneiro (Jun 16 2023 at 07:15):
it takes an IO action, you can just stick a loop inside
Niels Voss (Jun 16 2023 at 07:24):
Would something like this work? At first I just tried discarding the value but it seemed to optimize the whole calculation away. I also got rid of the last test case because it took 20 seconds to run with normal modexp.
def dump : IO (IO.Ref ℕ) :=
  IO.mkRef 0
def run_fast_modexp (b e m : ℕ) : IO Unit := do
  for _ in [1:100] do
    ST.Ref.set (← dump) (fast_modexp b e m)
def run_normal_modexp (b e m : ℕ) : IO Unit := do
  for _ in [1:100] do
    ST.Ref.set (← dump) (b ^ e % m)
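Two details of the loop above are worth noting. The range `[1:100]` excludes its upper bound, so the body runs 99 times, and since `dump` is an `IO` action, each `← dump` allocates a fresh reference instead of reusing one. A minimal sketch of a repeat-and-time helper that avoids both is below. The name `timeRepeated` and its parameters are hypothetical, and reading the inputs back from a reference on every iteration follows the earlier IO.Ref suggestion in this thread; it is a guess at keeping the compiler from hoisting the call, not a guarantee.

```lean
/-- Hypothetical helper: run `f b e m` exactly `reps` times inside one `timeit`. -/
def timeRepeated (label : String) (reps : Nat) (f : Nat → Nat → Nat → Nat)
    (b e m : Nat) : IO Unit := do
  let input ← IO.mkRef (b, e, m)   -- inputs live behind a reference
  let sink ← IO.mkRef (0 : Nat)    -- a single sink, created once
  timeit label do
    for _ in [0:reps] do           -- `[0:reps]` iterates exactly `reps` times
      let (b', e', m') ← input.get
      sink.set (f b' e' m')        -- store the result so the call is not dead code

#eval timeRepeated "b ^ e % m, 100 runs" 100 (fun b e m => b ^ e % m) 123 1000 916
```

Dividing the reported time by `reps` then gives the mean time per call, including the small overhead of the two reference operations.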
Niels Voss (Jun 16 2023 at 07:30):
Dividing each value by 100 should result in the mean time per run
Exponent | fast_modexp (100 runs) | normal_modexp (100 runs) |
---|---|---|
10 | 0.397 ms | 0.154 ms |
100 | 0.429 ms | 0.125 ms |
1000 | 0.574 ms | 2.53 ms |
10000 | 0.774 ms | 423 ms |
100000 | 0.91 ms | 29.7 s |
Mac Malone (Jun 18 2023 at 00:53):
@Niels Voss I would advise putting some additional test cases between 100 and 1000 (e.g., 250, 500, 750, etc.) to see if you cannot more precisely determine the cutoff point. Then, I would suggest making an e.g. smart_modexp that uses normal_modexp for the values below and fast_modexp for values above to get the best result.
Niels Voss (Jun 18 2023 at 01:28):
It seems to be overtaking normal_modexp for exponents somewhere between 350 and 400. All tests were done with the same base and modulus, so I might consider varying those to see how it impacts the results.
Niels Voss (Jun 18 2023 at 02:37):
Would something like this work?
def smart_modexp (b e m : ℕ) : ℕ :=
  if b ≤ 300 ∧ e ≤ 300 then b ^ e % m else fast_modexp b e m
theorem smart_modexp_spec (b e m : ℕ) : smart_modexp b e m = b ^ e % m := by
  unfold smart_modexp
  simp [fast_modexp_spec]
I think 300 is a bit lower than the point at which fast_modexp becomes faster than b ^ e % m, but I think the performance cost of running fast_modexp b e m when b ^ e % m was the better choice is less significant than the cost of running b ^ e % m when fast_modexp b e m should have been run.
Also, all the performance tests up to this point have been based on the #eval command in VSCode, not on kernel computation, which I don't know how to benchmark.
Jason Rute (Jun 18 2023 at 09:50):
For benchmarking, see #lean4 > speed tests
Niels Voss (Jun 18 2023 at 20:33):
Isn't this only about Münchausen numbers? I'm not quite sure how to use this to test my code.
My goal was primarily to show that it was possible to write a formally verified modexp algorithm based on the Exponentiation by Squaring method, and I'm mainly benchmarking to make sure it performs faster than the Lean builtin. However I could definitely test this against a C implementation, though I'd imagine it won't perform that well because I was focusing more on verification than optimization.
Niels Voss (Jun 19 2023 at 00:06):
I generalized the fast_modexp algorithm to work with any monoid, so it could theoretically be used to do things like matrix or polynomial exponentiation. I then used ZMod m to actually do the modular exponentiation.
import Mathlib
universe u
lemma pow_by_squaring_even {G : Type u} [Monoid G] (b : G) {n : ℕ} (hn : Even n)
    : (b ^ 2) ^ (n / 2) = b ^ n := by
  rw [←pow_mul, mul_comm 2, Nat.div_two_mul_two_of_even hn]
def pow_by_squaring_aux {G : Type u} [Monoid G] : G → G → ℕ → G
| _, running, 0 => running
| b, running, (e + 1) =>
  have : (e + 1) / 2 < e + 1 := by
    apply Nat.div_lt_of_lt_mul
    linarith
  have : e / 2 < e + 1 := by
    apply Nat.div_lt_of_lt_mul
    linarith
  if Even e then
    pow_by_squaring_aux (b ^ 2) (running * b) (e / 2)
  else
    pow_by_squaring_aux (b ^ 2) running ((e + 1) / 2)
def pow_by_squaring {G : Type u} [Monoid G] (b : G) (e : ℕ) :=
  pow_by_squaring_aux b 1 e
lemma pow_by_squaring_aux_spec {G : Type u} [Monoid G] (b running : G) (e : ℕ)
    : pow_by_squaring_aux b running e = running * (b ^ e) := by
  revert b running
  induction' e using Nat.case_strong_induction_on with e ih
  · intros b running
    rw [pow_zero, mul_one]
    rfl
  · intros b running
    unfold pow_by_squaring_aux
    by_cases h : Even e
    · have hle : e / 2 ≤ e := by
        apply Nat.le_of_lt_succ
        apply Nat.div_lt_of_lt_mul
        linarith
      rw [if_pos h]
      change pow_by_squaring_aux (b ^ 2) (running * b) (e / 2) = _
      rw [ih (e / 2) hle (b ^ 2) (running * b)]
      rw [pow_by_squaring_even b h]
      rw [pow_succ]
      rw [mul_assoc]
    · have hle : (e + 1) / 2 ≤ e := by
        apply Nat.le_of_lt_succ
        apply Nat.div_lt_of_lt_mul
        linarith
      rw [if_neg h]
      change pow_by_squaring_aux (b ^ 2) running ((e + 1) / 2) = _
      rw [ih ((e + 1) / 2) hle (b ^ 2) running]
      have : Even (e + 1) := by
        rw [←Nat.odd_iff_not_even] at h
        exact Odd.add_odd h odd_one
      rw [pow_by_squaring_even b this]
theorem pow_by_squaring_spec {G : Type u} [Monoid G] (b : G) (e : ℕ)
    : pow_by_squaring b e = b ^ e := by
  unfold pow_by_squaring
  rw [pow_by_squaring_aux_spec, one_mul]
def modexp_by_squaring (b e m : ℕ) : ℕ :=
  ZMod.val (pow_by_squaring (b : ZMod m) e)
lemma zmod_pow_val_nat {m : ℕ} (n e : ℕ)
    : ZMod.val ((↑n : ZMod m) ^ e) = ZMod.val (n ^ e : ZMod m) := by
  induction' e with e ih
  · rw [pow_zero, pow_zero]
    rw [Nat.cast_one]
  · rw [pow_succ, pow_succ]
    rw [ZMod.val_mul]
    rw [ih]
    rw [ZMod.val_nat_cast]
    rw [ZMod.val_nat_cast]
    rw [ZMod.val_nat_cast]
    rw [←Nat.mul_mod]
lemma zmod_val_pow_nat' {m : ℕ} (b e : ℕ)
    : ZMod.val ((b : ZMod m) ^ e) = b ^ e % m := by
  rw [zmod_pow_val_nat]
  exact ZMod.val_nat_cast _
theorem modexp_by_squaring_spec (b e m : ℕ) : modexp_by_squaring b e m = (b ^ e) % m := by
  unfold modexp_by_squaring
  rw [pow_by_squaring_spec]
  exact zmod_val_pow_nat' b _
Niels Voss (Jun 19 2023 at 00:10):
One caveat is that this algorithm is much slower than the fast_modexp described above. It seems to run about 5 times slower. This isn't awful, since it is still able to handle really large numbers and the time complexity (if I implemented it properly) should still be O(log n), but it doesn't start performing faster than the normal modexp until an exponent of a little over 1000.
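To illustrate the point about other monoids, here is a minimal, unverified sketch that needs only core Lean. It passes the multiplication and identity explicitly instead of using Mathlib's `Monoid`, and applies the same squaring loop to 2×2 matrices to compute Fibonacci numbers. The names `powLoop`, `M2`, `M2.mul`, and `fib` are all hypothetical and are not part of the code above.

```lean
/-- Hypothetical generic squaring loop; `mul` is assumed associative with unit `one`. -/
def powLoop {α : Type} (mul : α → α → α) (one : α) (b : α) (e : Nat) : α := Id.run do
  let mut base := b
  let mut acc := one
  let mut exp := e
  while exp != 0 do
    if exp % 2 == 1 then
      acc := mul acc base   -- fold in the current power of `b`
    base := mul base base   -- b, b^2, b^4, ...
    exp := exp / 2
  return acc

/-- 2×2 matrices over `Nat`, stored row-major as a 4-tuple. -/
abbrev M2 := Nat × Nat × Nat × Nat

def M2.mul : M2 → M2 → M2
  | (a, b, c, d), (e, f, g, h) => (a*e + b*g, a*f + b*h, c*e + d*g, c*f + d*h)

/-- Uses [[1,1],[1,0]]^n = [[F(n+1), F(n)], [F(n), F(n-1)]] and reads the top-right entry. -/
def fib (n : Nat) : Nat :=
  (powLoop M2.mul (1, 0, 0, 1) (1, 1, 1, 0) n).2.1

#eval fib 10  -- 55
```

Unlike pow_by_squaring, this loop carries no proof; it only shows the kind of use the monoid generalization makes possible.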
Last updated: Dec 20 2023 at 11:08 UTC