Zulip Chat Archive

Stream: lean4

Topic: Fast Modexp


Niels Voss (Jun 14 2023 at 22:23):

I spent a few hours writing a verified fast modular exponentiation algorithm. I'm sure someone else has done this before, but I just thought I would share.

import Mathlib.Data.Nat.Parity
import Mathlib.Tactic.Linarith

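/-- Computes `running * base ^ n % modulus` for the final argument `n` (when `running < modulus`) by repeated squaring. -/
def fast_modexp_aux (base running modulus : ℕ) : ℕ → ℕ
| 0 => running
| (exponent+1) =>
  -- termination hint: the exponent roughly halves on every recursive call
  have : (exponent+1) / 2 < exponent + 1 := by
    apply Nat.div_lt_of_lt_mul
    linarith
  -- `exponent + 1` is odd exactly when `exponent` is even: fold one factor of `base` into `running`
  if Even exponent then
    fast_modexp_aux (base * base % modulus) (running * base % modulus) modulus ((exponent + 1) / 2)
  else
    fast_modexp_aux (base * base % modulus) running modulus ((exponent + 1) / 2)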

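def fast_modexp (base exponent modulus : ℕ) : ℕ :=
match modulus with
| 0 => base ^ exponent  -- `x % 0 = x` for `Nat`, so no reduction is needed
| 1 => 0                -- every number is `0` mod `1`
| m => fast_modexp_aux (base % m) 1 m exponent  -- here `1 < m`, as the aux spec requires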

lemma fast_modexp_aux_spec (b r m e : ℕ) (h : r < m)
  : fast_modexp_aux b r m e = r * b ^ e % m := by
  revert b r m
  induction' e using Nat.case_strong_induction_on with e ih
  · intros b r m h
    unfold fast_modexp_aux
    simp [Nat.mod_eq_of_lt h]
  · intros b r m h₁
    have hlt : (e + 1) / 2 ≤ e := by
      apply Nat.le_of_lt_succ
      apply Nat.div_lt_of_lt_mul
      linarith
    unfold fast_modexp_aux
    by_cases he : Even e
    · rw [if_pos he]
      have he' : (e + 1) / 2 * 2 = e := by
        simpa using Nat.div_two_mul_two_add_one_of_odd (Even.add_odd he odd_one)
      change fast_modexp_aux (b * b % m) (r * b % m) m ((e + 1) / 2) = _
      rw [ih ((e + 1) / 2) hlt (b * b % m) (r * b % m) m (Nat.mod_lt _ (Nat.zero_lt_of_lt h₁))]
      rw [Nat.mul_mod (r * b % m), Nat.mod_mod, ← Nat.pow_mod]
      rw [← sq, ← pow_mul, mul_comm 2]
      rw [he']
      rw [← Nat.mul_mod, Nat.pow_succ]
      ring_nf
    · rw [if_neg he]
      have he' : (e + 1) / 2 * 2 = e + 1 := by
        rw [← Nat.odd_iff_not_even] at he
        exact Nat.div_two_mul_two_of_even (Odd.add_odd he odd_one)
      change fast_modexp_aux (b * b % m) r m ((e + 1) / 2) = _
      rw [ih ((e + 1) / 2) hlt (b * b % m) r m h₁]
      rw [Nat.mul_mod r, ← Nat.pow_mod]
      rw [← sq, ← pow_mul, mul_comm 2]
      rw [he']
      exact (Nat.mul_mod _ _ _).symm

theorem fast_modexp_spec (b e m : ℕ) : fast_modexp b e m = (b ^ e) % m := by
  rcases m with m | m | m
  · change b ^ e = _
    simp
  · change 0 = _
    exact (Nat.mod_one _).symm
  · change fast_modexp_aux _ _ _ _ = _
    rw [fast_modexp_aux_spec (b % (m + 2)) 1 (m + 2) e (by linarith)]
    rw [one_mul, ← Nat.pow_mod]

#eval fast_modexp 3081683040978812497613678567591472736846636481039624945710
  9688672230569954477594568460243311380476614093418427219148
  3862468819539322552199292737141641904852335676904946310429
  -- 1604949739222115423721086562038072677142980288837751825385
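To make the recursion concrete, here is a hand trace of `fast_modexp 3 13 7` (small inputs picked for illustration, not taken from the thread). Each call squares `base`; when the current exponent is odd, it first multiplies `running` by the current `base`. The `example` at the end assumes the definitions and `fast_modexp_spec` above and closes the remaining arithmetic with `decide`.

```lean
-- Hand trace of `fast_modexp 3 13 7`, i.e. `fast_modexp_aux 3 1 7 13` (13 = 0b1101).
--   base  running  exponent
--    3      1        13     odd:  running := 1 * 3 % 7 = 3
--    2      3         6     even: running unchanged
--    4      3         3     odd:  running := 3 * 4 % 7 = 5
--    2      5         1     odd:  running := 5 * 2 % 7 = 3
--    4      3         0     return running = 3
#eval fast_modexp 3 13 7  -- 3, since 3 ^ 13 = 1594323 = 7 * 227760 + 3

-- The same fact, obtained through the verified spec instead of evaluation:
example : fast_modexp 3 13 7 = 3 :=
  (fast_modexp_spec 3 13 7).trans (by decide)
```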

It should be able to handle fairly large numbers. I haven't profiled it because I'm not really sure how to. Is there a version of fast modexp already in Mathlib?

James Gallicchio (Jun 14 2023 at 23:03):

You can use timeit for doing pretty easy timing

Niels Voss (Jun 14 2023 at 23:09):

Is there an example of how to use timeit? I tried

#eval timeit "fast modexp" (do pure (fast_modexp 3123 384016810 131231))

But this is not timing the computation because it always reports 0.0002ms

James Gallicchio (Jun 14 2023 at 23:11):

I think you need to prevent it from doing constant lifting. There's an option set_option compiler.extract_closed false that you can try?

James Gallicchio (Jun 14 2023 at 23:12):

alternatively, you can wrap it in a little function that takes the arguments so that the call is no longer a closed expression

Niels Voss (Jun 14 2023 at 23:13):

Both #evals still print 0.0002 ms

set_option compiler.extract_closed false
#eval timeit "fast modexp" (do pure (fast_modexp 3123 1243345244 131231))
#eval timeit "normal" (do pure (3 ^ 100000))

Niels Voss (Jun 14 2023 at 23:16):

I finally got it to work. Of these three commands, the third is the only one which returns a realistic time.

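-- Taking the inputs as arguments keeps the call from being a closed term
def run_modexp (b e m : ℕ) : IO ℕ := do
  pure (fast_modexp b e m)
-- closed term: the compiler may compute it before `timeit` starts
#eval timeit "Fast modexp" (do pure (fast_modexp 213 1223423422342525235235 123))
-- only returns the (unrun) IO action, so nothing is computed inside `timeit`
#eval timeit "Fast modexp" (do pure (run_modexp 213 1231251251252142123523523643623442342345 123))
-- runs the IO action inside `timeit`, so the work is actually timed
#eval timeit "Fast modexp" (run_modexp 213 1231251251252142123523523643623442342345 123)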

James Gallicchio (Jun 14 2023 at 23:23):

I assume it's much faster than the standard Nat functions? :joy:

Niels Voss (Jun 14 2023 at 23:25):

Well actually only for very large exponents. I might be timing this wrong but I think fast_modexp seems to take around 0.2 ms no matter how large the exponent is, whereas doing it normally is faster for small values (I think below 10000?). I have no experience benchmarking things so I could very well be wrong

Niels Voss (Jun 14 2023 at 23:35):

Here's the output of a quick benchmark

def test_cases : List (ℕ × ℕ × ℕ) :=
[
  (123, 100, 916),
  (123, 1000, 916),
  (123, 10000, 916),
  (123, 100000, 916),
  (123, 1000000, 916)
]

def run_fast_modexp (b e m : ℕ) : IO ℕ := do
  pure (fast_modexp b e m)
def run_normal_modexp (b e m : ℕ) : IO ℕ := do
  pure (b ^ e % m)

def test : IO Unit := do
  for case in test_cases do
    let (b, e, m) := case
    let fast_result ← timeit s!"fast_modexp {b} {e} {m}" (run_fast_modexp b e m)
  for case in test_cases do
    let (b, e, m) := case
    let normal_result ← timeit s!"normal_modexp {b} {e} {m}" (run_normal_modexp b e m)

#eval test

Niels Voss (Jun 14 2023 at 23:51):

I guess there are also some optimizations I can make. Is dividing by 2 the same speed as bit shifting one time? And is checking whether something is Even the same speed as checking whether it is divisible by 2?
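One way to answer this empirically is to time an unverified variant that reads the low bit with `e % 2` and halves with the `Nat` right shift `e >>> 1`, using the same harness as above. The name `fast_modexp_shift` is hypothetical, and `partial` skips the termination proof, so this is a benchmarking sketch only: nothing can be proved about it, and whether it is measurably faster than `Even` and `/ 2` is something to measure rather than assume.

```lean
/-- Hypothetical, unverified variant of `fast_modexp_aux`, for benchmarking only.
Intended to return `running * base ^ e % modulus` (for `running < modulus`),
testing the low bit with `e % 2` and halving with the right shift `e >>> 1`. -/
partial def fast_modexp_shift (base running modulus e : ℕ) : ℕ :=
  if e == 0 then running
  else
    -- Odd exponent: fold one factor of `base` into the accumulator first.
    let running' := if e % 2 == 1 then running * base % modulus else running
    fast_modexp_shift (base * base % modulus) running' modulus (e >>> 1)

-- Both lines should print the same value, `123 ^ 100 % 916`.
#eval fast_modexp_shift (123 % 916) 1 916 100
#eval fast_modexp 123 100 916
```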

Trebor Huang (Jun 16 2023 at 05:21):

Why is the first testcase significantly slower? Is some startup cost included there?

Niels Voss (Jun 16 2023 at 06:19):

I have no idea. I reran the test (also printing out what each function returned) and I got very different results.
Do you know a more reliable way to benchmark? I haven't ever done this before so I'm probably doing something wrong.

Niels Voss (Jun 16 2023 at 06:19):

In case images don't render

fast_modexp 123 100 916 0.169ms
661
fast_modexp 123 1000 916 0.0063ms
165
fast_modexp 123 10000 916 0.0073ms
289
fast_modexp 123 100000 916 0.0083ms
17
fast_modexp 123 1000000 916 0.0097ms
501
normal_modexp 123 100 916 0.0365ms
661
normal_modexp 123 1000 916 0.0256ms
165
normal_modexp 123 10000 916 4.03ms
289
normal_modexp 123 100000 916 294ms
17
normal_modexp 123 1000000 916 24.3s
501

Niels Voss (Jun 16 2023 at 06:24):

I think it is most likely some sort of startup cost because if I shuffle the test cases around whichever one I put first seems to take significantly longer for the fast_modexp. Although it doesn't matter whether I test fast_modexp or normal_modexp first, the first test for fast_modexp is always the one that takes longer.

Mario Carneiro (Jun 16 2023 at 06:31):

you could put an extra sacrificial test case at the start

Mario Carneiro (Jun 16 2023 at 06:32):

BTW I would be very concerned about the compiler seeing through the test cases here

Mario Carneiro (Jun 16 2023 at 06:33):

one thing you could do is put the test_cases in an IO.Ref and then read from that ref

Niels Voss (Jun 16 2023 at 06:51):

I'm not sure how to use IO.Ref, as I haven't really done monadic programming before. Here's my current attempt:

def the_test_cases : List (ℕ × ℕ × ℕ) :=
[
  (123, 100, 916),
  (123, 100, 916),
  (123, 1000, 916),
  (123, 10000, 916),
  (123, 100000, 916),
  (123, 1000000, 916)
]

def test_cases_ref : IO (IO.Ref (List (ℕ × ℕ × ℕ))) :=
IO.mkRef the_test_cases

def run_fast_modexp (b e m : ℕ) : IO ℕ := do
  pure (fast_modexp b e m)
def run_normal_modexp (b e m : ℕ) : IO ℕ := do
  pure (b ^ e % m)

def test : IO Unit := do
  let test_cases ← ST.Ref.get (← test_cases_ref)
  for case in test_cases do
    let (b, e, m) := case
    let fast_result ← timeit s!"fast_modexp {b} {e} {m}" (run_fast_modexp b e m)
    println! fast_result
  for case in test_cases do
    let (b, e, m) := case
    let normal_result ← timeit s!"normal_modexp {b} {e} {m}" (run_normal_modexp b e m)
    println! normal_result

#eval test

This runs but has the same startup penalty. Should I be reading from the reference every loop? If so, would test_cases_ref need to be a List of references rather than a reference to a List?

Mario Carneiro (Jun 16 2023 at 06:52):

no I think this will work, this is what I meant

Mario Carneiro (Jun 16 2023 at 06:52):

what are the performance numbers for this?

Mario Carneiro (Jun 16 2023 at 06:53):

this was not meant to fix the startup cost, but rather to avoid the compiler getting smart and optimizing the calculation to a constant

Mario Carneiro (Jun 16 2023 at 06:55):

To fix the startup cost, you should try evaluating run_fast_modexp once before the main loop and not timing it

Mario Carneiro (Jun 16 2023 at 06:56):

that's what I meant by a "sacrificial test case"

Mario Carneiro (Jun 16 2023 at 06:56):

There are probably some values used in the calculation which are being initialized on first use

Niels Voss (Jun 16 2023 at 06:58):

Here are the results, I added a test case with an exponent of 10 at the start.

fast_modexp 123 10 916 0.189ms
469
fast_modexp 123 100 916 0.0062ms
661
fast_modexp 123 1000 916 0.007ms
165
fast_modexp 123 10000 916 0.0086ms
289
fast_modexp 123 100000 916 0.0102ms
17
fast_modexp 123 1000000 916 0.012ms
501
normal_modexp 123 10 916 0.0418ms
469
normal_modexp 123 100 916 0.0032ms
661
normal_modexp 123 1000 916 0.0288ms
165
normal_modexp 123 10000 916 4.09ms
289
normal_modexp 123 100000 916 300ms
17
normal_modexp 123 1000000 916 24.3s
501

Niels Voss (Jun 16 2023 at 07:01):

Here are the results in a slightly more readable format

| Exponent | fast_modexp | normal_modexp |
|       10 | 0.189 ms    | 0.0418 ms     |
|      100 | 0.0062 ms   | 0.0032 ms     |
|     1000 | 0.007 ms    | 0.0288 ms     |
|    10000 | 0.0086 ms   | 4.09 ms       |
|   100000 | 0.0102 ms   | 300 ms        |
|  1000000 | 0.012 ms    | 24.3 s        |

Mario Carneiro (Jun 16 2023 at 07:04):

| Exponent | fast_modexp | normal_modexp |
|       10 | 0.189 ms    | 0.0418 ms     |
|      100 | 0.0062 ms   | 0.0032 ms     |
|     1000 | 0.007 ms    | 0.0288 ms     |
|    10000 | 0.0086 ms   | 4.09 ms       |
|   100000 | 0.0102 ms   | 300 ms        |
|  1000000 | 0.012 ms    | 24.3 s        |

Mario Carneiro (Jun 16 2023 at 07:05):

does it not work to run it beforehand?

Niels Voss (Jun 16 2023 at 07:12):

Actually it does. I inserted println! s!"early run fast_modexp {fast_modexp 4 13 449}" as the first line in test and it seemed to mostly avoid the startup penalty, but not completely.

early run fast_modexp 426
fast_modexp 123 10 916 0.0335ms
469
fast_modexp 123 100 916 0.0043ms
661
fast_modexp 123 1000 916 0.005ms
165
fast_modexp 123 10000 916 0.0067ms
289
fast_modexp 123 100000 916 0.0079ms
17
fast_modexp 123 1000000 916 0.0091ms
501
normal_modexp 123 10 916 0.0396ms
469
normal_modexp 123 100 916 0.0017ms
661
normal_modexp 123 1000 916 0.0254ms
165
normal_modexp 123 10000 916 3.91ms
289
normal_modexp 123 100000 916 304ms
17
normal_modexp 123 1000000 916 24.7s
501

Niels Voss (Jun 16 2023 at 07:14):

At this point the measurements for fast_modexp are on the order of 4 to 30 microseconds so I don't know how accurate my computer actually is with timing.

Mario Carneiro (Jun 16 2023 at 07:14):

how many times are you running it?

Niels Voss (Jun 16 2023 at 07:15):

Just once. Is there a way to have timeit run multiple times?

Mario Carneiro (Jun 16 2023 at 07:15):

it takes an IO action, you can just stick a loop inside
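A minimal sketch of that idea, relying only on Lean core's `timeit` and `IO.Ref`; the helper name `timeRepeated` is hypothetical. It keeps a single `timeit` around the whole loop and writes every result into a ref so the work is not dead code. Whether the computation is really repeated on each iteration depends on passing a genuine `IO` action (such as `run_fast_modexp b e m`) rather than a value that was already computed.

```lean
/-- Hypothetical helper: runs `act` `n` times inside a single `timeit` call.
Each result is written to an `IO.Ref` so it cannot simply be discarded; the
last result (or `0` when `n = 0`) is returned afterwards. -/
def timeRepeated (label : String) (n : ℕ) (act : IO ℕ) : IO ℕ := do
  let sink ← IO.mkRef (0 : ℕ)
  timeit label do
    for _ in [0:n] do
      sink.set (← act)
  sink.get

-- Assumes the `run_fast_modexp : ℕ → ℕ → ℕ → IO ℕ` wrapper from the benchmark above.
#eval timeRepeated "fast_modexp 123 10000 916, 100 runs" 100 (run_fast_modexp 123 10000 916)
```

Dividing the reported time by `n` gives a rough mean per call.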

Niels Voss (Jun 16 2023 at 07:24):

Would something like this work? At first I just tried discarding the value but it seemed to optimize the whole calculation away. I also got rid of the last test case because it took 20 seconds to run with normal modexp.

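-- Writing each result into a ref keeps the compiler from optimizing the work away
def dump : IO (IO.Ref ℕ) :=
IO.mkRef 0

def run_fast_modexp (b e m : ℕ) : IO Unit := do
  for _ in [0:100] do  -- `[0:100]` runs 100 times (`[1:100]` would run only 99)
    ST.Ref.set (← dump) (fast_modexp b e m)

def run_normal_modexp (b e m : ℕ) : IO Unit := do
  for _ in [0:100] do
    ST.Ref.set (← dump) (b ^ e % m)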

Niels Voss (Jun 16 2023 at 07:30):

Dividing each value by 100 should result in the mean time per run

| Exponent | fast_modexp (100 runs) | normal_modexp (100 runs) |
|       10 | 0.397 ms               | 0.154 ms                 |
|      100 | 0.429 ms               | 0.125 ms                 |
|     1000 | 0.574 ms               | 2.53 ms                  |
|    10000 | 0.774 ms               | 423 ms                   |
|   100000 | 0.91 ms                | 29.7 s                   |

Mac Malone (Jun 18 2023 at 00:53):

@Niels Voss I would advise adding some test cases between 100 and 1000 (e.g., 250, 500, 750) to see whether you can determine the cutoff point more precisely. Then I would suggest writing something like a smart_modexp that uses normal_modexp for values below the cutoff and fast_modexp for values above it, to get the best of both.
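A sketch of such a sweep, under stated assumptions: `timeNanos` and `scanCutoff` are hypothetical names, timing uses Lean core's `IO.monoNanosNow`, and the `IO ℕ` wrappers `run_fast_modexp` and `run_normal_modexp` from the benchmarks above are assumed to be in scope. An untimed warm-up call comes first, following the earlier advice about startup cost. Single runs are noisy, so the crossover it reveals should be treated as approximate.

```lean
/-- Hypothetical helper: runs `act` once and returns its result together with
the elapsed time in nanoseconds, measured with `IO.monoNanosNow`. -/
def timeNanos (act : IO ℕ) : IO (ℕ × ℕ) := do
  let t0 ← IO.monoNanosNow
  let r ← act
  let t1 ← IO.monoNanosNow
  return (r, t1 - t0)

/-- Hypothetical sweep over exponents 100, 150, ..., 1000 for a fixed base and
modulus, printing both timings and whether the two results agree. -/
def scanCutoff (b m : ℕ) : IO Unit := do
  let _ ← run_fast_modexp b 10 m  -- untimed warm-up call
  for e in [100:1001:50] do
    let (r₁, tFast) ← timeNanos (run_fast_modexp b e m)
    let (r₂, tNormal) ← timeNanos (run_normal_modexp b e m)
    IO.println s!"e = {e}: fast {tFast} ns, normal {tNormal} ns, agree: {r₁ == r₂}"

#eval scanCutoff 123 916
```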

Niels Voss (Jun 18 2023 at 01:28):

It seems to be overtaking normal_modexp for exponents somewhere between 350 and 400. All tests were done with the same base and modulus, so I might consider varying those to see how it impacts the results.

Niels Voss (Jun 18 2023 at 02:37):

Would something like this work?

def smart_modexp (b e m : ℕ) : ℕ :=
if b ≤ 300 ∧ e ≤ 300 then b ^ e % m else fast_modexp b e m

theorem smart_modexp_spec (b e m : ℕ) : smart_modexp b e m = b ^ e % m := by
  unfold smart_modexp
  simp [fast_modexp_spec]

I think 300 is a bit lower than the point at which fast_modexp becomes faster than b ^ e % m, but I think the performance cost of running fast_modexp b e m when b ^ e % m was a better choice is less significant than the performance cost of running b ^ e % m when fast_modexp b e m should have been run.
Also all the performance tests up to this point have been based on the #eval command in VSCode, but not kernel computation, which I don't know how to benchmark.

Jason Rute (Jun 18 2023 at 09:50):

For benchmarking, see #lean4 > speed tests

Niels Voss (Jun 18 2023 at 20:33):

Isn't this only about Münchausen numbers? I'm not quite sure how to use this to test my code.
My goal was primarily to show that it was possible to write a formally verified modexp algorithm based on the Exponentiation by Squaring method, and I'm mainly benchmarking to make sure it performs faster than the Lean builtin. However I could definitely test this against a C implementation, though I'd imagine it won't perform that well because I was focusing more on verification than optimization.

Niels Voss (Jun 19 2023 at 00:06):

I generalized the fast_modexp algorithm to work with any monoid, so it could theoretically be used to do things like matrix or polynomial exponentiation. I then used ZMod m to actually do the modular exponentiation.

import Mathlib

universe u

lemma pow_by_squaring_even {G : Type u} [Monoid G] (b : G) {n : ℕ} (hn : Even n)
: (b ^ 2) ^ (n / 2) = b ^ n := by
  rw [← pow_mul, mul_comm 2, Nat.div_two_mul_two_of_even hn]

def pow_by_squaring_aux {G : Type u} [Monoid G] : G → G → ℕ → G
| _, running, 0 => running
| b, running, (e + 1) =>
  have : (e + 1) / 2 < e + 1 := by
    apply Nat.div_lt_of_lt_mul
    linarith
  have : e / 2 < e + 1 := by
    apply Nat.div_lt_of_lt_mul
    linarith
  if Even e then
    pow_by_squaring_aux (b ^ 2) (running * b) (e / 2)
  else
    pow_by_squaring_aux (b ^ 2) running ((e + 1) / 2)

def pow_by_squaring {G : Type u} [Monoid G] (b : G) (e : ℕ) :=
pow_by_squaring_aux b 1 e

lemma pow_by_squaring_aux_spec {G : Type u} [Monoid G] (b running : G) (e : ℕ)
: pow_by_squaring_aux b running e = running * (b ^ e) := by
  revert b running
  induction' e using Nat.case_strong_induction_on with e ih
  · intros b running
    rw [pow_zero, mul_one]
    rfl
  · intros b running
    unfold pow_by_squaring_aux
    by_cases h : Even e
    · have hle : e / 2 ≤ e := by
        apply Nat.le_of_lt_succ
        apply Nat.div_lt_of_lt_mul
        linarith
      rw [if_pos h]
      change pow_by_squaring_aux (b ^ 2) (running * b) (e / 2) = _
      rw [ih (e / 2) hle (b ^ 2) (running * b)]
      rw [pow_by_squaring_even b h]
      rw [pow_succ]
      rw [mul_assoc]
    · have hle : (e + 1) / 2 ≤ e := by
        apply Nat.le_of_lt_succ
        apply Nat.div_lt_of_lt_mul
        linarith
      rw [if_neg h]
      change pow_by_squaring_aux (b ^ 2) running ((e + 1) / 2) = _
      rw [ih ((e + 1) / 2) hle (b ^ 2) running]
      have : Even (e + 1) := by
        rw [← Nat.odd_iff_not_even] at h
        exact Odd.add_odd h odd_one
      rw [pow_by_squaring_even b this]

theorem pow_by_squaring_spec {G : Type u} [Monoid G] (b : G) (e : ℕ)
: pow_by_squaring b e = b ^ e := by
  unfold pow_by_squaring
  rw [pow_by_squaring_aux_spec, one_mul]

def modexp_by_squaring (b e m : ℕ) : ℕ :=
ZMod.val (pow_by_squaring (b : ZMod m) e)

lemma zmod_pow_val_nat {m : ℕ} (n e : ℕ)
: ZMod.val ((n : ZMod m) ^ e) = ZMod.val ((n ^ e : ℕ) : ZMod m) := by
  induction' e with e ih
  · rw [pow_zero, pow_zero]
    rw [Nat.cast_one]
  · rw [pow_succ, pow_succ]
    rw [ZMod.val_mul]
    rw [ih]
    rw [ZMod.val_nat_cast]
    rw [ZMod.val_nat_cast]
    rw [ZMod.val_nat_cast]
    rw [← Nat.mul_mod]

lemma zmod_val_pow_nat' {m : ℕ} (b e : ℕ)
: ZMod.val ((b : ZMod m) ^ e) = b ^ e % m := by
  rw [zmod_pow_val_nat]
  exact ZMod.val_nat_cast _

theorem modexp_by_squaring_spec (b e m : ) : modexp_by_squaring b e m = (b ^ e) % m := by
  unfold modexp_by_squaring
  rw [pow_by_squaring_spec]
  exact zmod_val_pow_nat' b _
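A few illustrative uses of the generic version (not from the thread), assuming the definitions above are in scope. The snippet evaluates it in `ℤ` and through `modexp_by_squaring`, then uses `pow_by_squaring_spec` to carry an ordinary `pow` lemma across. The final `example` holds in any monoid, for instance `Matrix (Fin 2) (Fin 2) ℤ`.

```lean
-- `ℤ` under multiplication.
#eval pow_by_squaring (2 : ℤ) 10  -- 1024

-- Modular exponentiation through `ZMod`; 3 ^ 13 = 1594323 = 7 * 227760 + 3.
#eval modexp_by_squaring 3 13 7  -- 3

-- The spec lets the usual `pow` lemmas transfer to `pow_by_squaring` in any monoid.
example {G : Type*} [Monoid G] (b : G) (n : ℕ) :
    pow_by_squaring b (2 * n) = pow_by_squaring (b * b) n := by
  rw [pow_by_squaring_spec, pow_by_squaring_spec, pow_mul, sq]
```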

Niels Voss (Jun 19 2023 at 00:10):

One caveat is that this algorithm is much slower than the fast_modexp described above. It seems to run about 5 times slower. This isn't awful since it is still able to handle really large numbers and the time complexity (if I implemented it properly) should still be O(log n), but it doesn't start performing faster than the normal modexp until an exponent of a little over 1000.


Last updated: Dec 20 2023 at 11:08 UTC