Zulip Chat Archive
Stream: general
Topic: Strassen's algorithm with int and znum
Seul Baek (Feb 11 2019 at 04:37):
I've implemented Strassen's algorithm in Lean (https://github.com/skbaek/strassen). The hope was that it would help with multiplication of large matrices, but its behavior is a bit strange.
1. When multiplying znum matrices I get a nice U-shaped curve of execution time against crossover point, with the optimum crossover point (in numbers of rows/columns) somewhere in the 32-128 range. But with int matrices the algorithm shows only negligible improvements over naive multiplication. Since Strassen's algorithm cuts down on multiplications at the expense of more additions/subtractions of matrices, I think this means that the cost of addition/subtraction relative to that of multiplication is higher in int than in znum. Am I correct about this?
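A rough operation-count model makes the trade-off in question 1 concrete. It is only a sketch: it charges a cost M per scalar multiplication and A per scalar addition or subtraction, ignores allocation and all other bookkeeping, and assumes the Winograd variant of Strassen (7 block products and 15 block additions/subtractions per level, as in the strassen function Mario posts later in this thread; the original Strassen scheme uses 18). Splitting an n×n product with n = 2m once, the naive scheme does 8 block products and 4 block additions, while Strassen-Winograd does 7 and 15; C(m) is the cost of a naive m×m product:

```latex
\begin{aligned}
C_{\text{naive}}(2m)    &= 8\,C(m) + 4m^2 A \\
C_{\text{Strassen}}(2m) &= 7\,C(m) + 15m^2 A \\
C(m)                    &= m^3 M + m^2(m-1)\,A \\[4pt]
C_{\text{Strassen}}(2m) < C_{\text{naive}}(2m)
  &\iff C(m) > 11m^2 A \\
  &\iff m(M + A) > 12A \\
  &\iff m > \frac{12A}{M + A}
\end{aligned}
```

In this model a larger ratio A/M does raise the break-even block size, in line with the intuition in question 1. The bound stays below 12 for any M > 0, though, so much larger observed crossover points presumably reflect costs the model leaves out, such as allocating the intermediate matrices.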
2. What's really puzzling is that multiplication of int matrices is faster than that of znum matrices, by about a factor of 5. At first I thought this was because the matrix entries were not large enough, but the difference persists even when the matrices are multiplied by constant factors of 1000-100000. The only time znum outperforms int is when the computation involves a lot of additions of the form z+z, which just require a bit shift. I thought znum was the better datatype for arithmetic with large numbers, so I wonder if I'm doing something wrong here?
Simon Hudon (Feb 11 2019 at 05:15):
Are you doing the computation in kernel reductions or in the VM?
Mario Carneiro (Feb 11 2019 at 05:16):
vector.halve can be constructed in one pass
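Assuming vector.halve splits a vector of length 2n into its first and second halves (for example with a take followed by a drop, which walks the data twice), a one-pass version can be sketched on plain lists. split_halves is a hypothetical name, and the length proof that a vector carries is left out:

```lean
-- Hypothetical helper: split a list into its first n elements and the rest,
-- producing both parts in a single traversal (unlike take followed by drop).
def split_halves {α : Type} : ℕ → list α → list α × list α
| 0     l        := ([], l)
| (n+1) []       := ([], [])
| (n+1) (a :: l) :=
  let p := split_halves n l in (a :: p.1, p.2)

#eval split_halves 2 [1, 2, 3, 4]
```

On [1, 2, 3, 4] with n = 2 this returns the pair of [1, 2] and [3, 4]; each cons cell is visited once.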
Mario Carneiro (Feb 11 2019 at 05:17):
As a kernel data structure, I would use balanced binary trees instead of vectors of length 2^k
Seul Baek (Feb 11 2019 at 05:17):
@Simon Hudon In the kernel
Seul Baek (Feb 11 2019 at 05:19):
@Mario Carneiro That's a good point, I should have used trees
Mario Carneiro (Feb 11 2019 at 05:20):
def big_vector (α : Type) : ℕ → Type
| 0     := α
| (n+1) := big_vector n × big_vector n
Mario Carneiro (Feb 11 2019 at 05:21):
I'm a bit worried about the kernel having to compute with the proofs
Mario Carneiro (Feb 11 2019 at 05:21):
since vectors have a proof component
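For context: in Lean 3 core, vector α n is a subtype, a list bundled with a proof that its length is n, which is where the proof component in Mario's worry comes from. Mario's big_vector has no proof component; the size invariant lives entirely in the type. A minimal sketch of a structural map over it (big_vector.map is a hypothetical helper, not from the thread):

```lean
-- Lean 3 core, for comparison:
--   def vector (α : Type u) (n : ℕ) := { l : list α // l.length = n }

-- Mario's definition: a perfect binary tree with 2^n leaves, no proof component
def big_vector (α : Type) : ℕ → Type
| 0     := α
| (n+1) := big_vector n × big_vector n

-- Hypothetical helper: apply f to every leaf, recursing on the depth n
def big_vector.map {α β : Type} (f : α → β) :
  ∀ {n : ℕ}, big_vector α n → big_vector β n
| 0     a      := f a
| (n+1) ⟨l, r⟩ := ⟨big_vector.map l, big_vector.map r⟩
```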
Seul Baek (Feb 11 2019 at 05:22):
Although it can't be trees all the way down, since we have to resort to regular matrix multiplication at some crossover point > 1 for efficiency
Mario Carneiro (Feb 11 2019 at 05:22):
You can also implement regular matrix multiplication on this kind of matrix
Mario Carneiro (Feb 11 2019 at 05:23):
although perhaps the size being a power of 2 is a bit restrictive
Mario Carneiro (Feb 11 2019 at 05:23):
but I guess strassen has that built in so it should be fine
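One standard way around the power-of-2 restriction (not something proposed in the thread) is zero-padding: place an n×n matrix in the top-left corner of a 2^k×2^k matrix with 2^k ≥ n, multiply, and read the product back from the same corner, since the zero blocks contribute nothing:

```latex
\begin{pmatrix} A & 0 \\ 0 & 0 \end{pmatrix}
\begin{pmatrix} B & 0 \\ 0 & 0 \end{pmatrix}
=
\begin{pmatrix} AB & 0 \\ 0 & 0 \end{pmatrix}
```

The price is that the smallest such 2^k can be nearly 2n, so the padded matrices may have almost four times as many entries.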
Mario Carneiro (Feb 11 2019 at 05:24):
You have more tupling and untupling than you need
Mario Carneiro (Feb 11 2019 at 05:25):
like there is no reason quadruple and double_cols should be uncurried
Seul Baek (Feb 11 2019 at 05:26):
The only reason they are uncurried is that it makes statements of correctness lemmas simpler, should I want to verify them
Seul Baek (Feb 11 2019 at 05:27):
Do the (un)tuplings significantly slow things down? If so, I guess they should be curried
Mario Carneiro (Feb 11 2019 at 05:28):
Not sure by how much, but I expect so. I don't think I've identified any big problems yet, just linear slowdowns
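The difference Mario is pointing at can be sketched with a hypothetical pair of functions (the real quadruple and double_cols live in Seul's repository and are not reproduced here). The uncurried form takes its arguments packed in a tuple, so every call builds a nested pair that the callee immediately takes apart again; the curried form receives the arguments directly:

```lean
structure quad (α : Type) := (a b c d : α)

-- Hypothetical uncurried form: the caller packs four blocks into a nested
-- pair (a, (b, (c, d))) and the function pattern-matches it apart again.
def mk_quad_uncurried {α : Type} : α × α × α × α → quad α
| (a, b, c, d) := ⟨a, b, c, d⟩

-- Hypothetical curried form: the four blocks are separate arguments,
-- so no intermediate tuple is built at the call site.
def mk_quad_curried {α : Type} (a b c d : α) : quad α := ⟨a, b, c, d⟩
```

If tupled statements are still preferred for correctness lemmas, the uncurried version can be kept as a thin wrapper over the curried one, so only the lemmas see the tuple.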
Mario Carneiro (Feb 11 2019 at 05:31):
all your tests use #eval?
Seul Baek (Feb 11 2019 at 05:32):
@Mario Carneiro Yes, I used #eval
Mario Carneiro (Feb 11 2019 at 05:33):
well, it's no surprise int is fast then
Seul Baek (Feb 11 2019 at 05:38):
Oops... you're right, I forgot #eval is VM by default
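This accounts for question 2. #eval runs in Lean 3's virtual machine, which has built-in support for int arithmetic on arbitrary-precision integers, whereas znum is an ordinary inductive type of binary digits whose operations the VM runs as compiled Lean code. Under kernel reduction neither gets native support, which is presumably why znum, being a binary representation, is the one suggested for kernel computation. A minimal comparison, assuming mathlib's data.num.basic provides znum together with its numeral, arithmetic and has_repr instances:

```lean
import data.num.basic   -- mathlib: defines `znum` and its arithmetic

set_option profiler true

-- VM evaluation: `int` multiplication is handled natively by the VM
#eval (123456789 * 987654321 : ℤ)

-- VM evaluation of the same product in `znum`: the bit-level Lean
-- definitions of `znum` arithmetic are executed instead
#eval (123456789 * 987654321 : znum)
```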
Mario Carneiro (Feb 11 2019 at 06:17):
Here's an implementation using quad-trees. How does it compare to your numbers?
structure quad (α : Type*) := (a b c d : α)

def smatrix (α : Type*) : ℕ → Type*
| 0     := α
| (n+1) := quad (smatrix n)

namespace smatrix

def const {α} (c : α) : ∀ k, smatrix α k
| 0     := c
| (k+1) := ⟨const k, const k, const k, const k⟩

def map {α β} (f : α → β) : ∀ {k}, smatrix α k → smatrix β k
| 0     a := f a
| (k+1) ⟨A₁₁, A₁₂, A₂₁, A₂₂⟩ := ⟨map A₁₁, map A₁₂, map A₂₁, map A₂₂⟩

def map₂ {α β γ} (f : α → β → γ) : ∀ {k}, smatrix α k → smatrix β k → smatrix γ k
| 0     A B := f A B
| (k+1) ⟨A₁₁, A₁₂, A₂₁, A₂₂⟩ ⟨B₁₁, B₁₂, B₂₁, B₂₂⟩ :=
  ⟨map₂ A₁₁ B₁₁, map₂ A₁₂ B₁₂, map₂ A₂₁ B₂₁, map₂ A₂₂ B₂₂⟩

instance {α} [has_add α] {k} : has_add (smatrix α k) := ⟨map₂ (+)⟩
instance {α} [has_neg α] {k} : has_neg (smatrix α k) := ⟨map has_neg.neg⟩
instance {α} [has_sub α] {k} : has_sub (smatrix α k) := ⟨map₂ has_sub.sub⟩

-- naive block multiplication: 8 recursive products and 4 block additions per level
def mul {α} [has_add α] [has_mul α] : ∀ {k : nat}, smatrix α k → smatrix α k → smatrix α k
| 0     A B := @has_mul.mul α _ A B
| (k+1) ⟨A₁₁, A₁₂, A₂₁, A₂₂⟩ ⟨B₁₁, B₁₂, B₂₁, B₂₂⟩ :=
  ⟨mul A₁₁ B₁₁ + mul A₁₂ B₂₁, mul A₁₁ B₁₂ + mul A₁₂ B₂₂,
   mul A₂₁ B₁₁ + mul A₂₂ B₂₁, mul A₂₁ B₁₂ + mul A₂₂ B₂₂⟩

instance {α} [has_add α] [has_mul α] {k} : has_mul (smatrix α k) := ⟨mul⟩

-- Strassen (Winograd variant): 7 recursive products per level,
-- falling back to the naive `mul` once k < t
def strassen {α} [ring α] (t : nat) : ∀ {k : nat}, smatrix α k → smatrix α k → smatrix α k
| 0     A B := A * B
| (k+1) A@⟨A₁₁, A₁₂, A₂₁, A₂₂⟩ B@⟨B₁₁, B₁₂, B₂₁, B₂₂⟩ :=
  if k < t then @has_mul.mul (smatrix α (k+1)) _ A B else
  let S₁ := A₂₁ + A₂₂, S₂ := S₁ - A₁₁, S₃ := A₁₁ - A₂₁, S₄ := A₁₂ - S₂ in
  let T₁ := B₁₂ - B₁₁, T₂ := B₂₂ - T₁, T₃ := B₂₂ - B₁₂, T₄ := T₂ - B₂₁ in
  let P₁ := strassen A₁₁ B₁₁, P₂ := strassen A₁₂ B₂₁, P₃ := strassen S₄ B₂₂,
      P₄ := strassen A₂₂ T₄, P₅ := strassen S₁ T₁, P₆ := strassen S₂ T₂,
      P₇ := strassen S₃ T₃ in
  let U₁ := P₁ + P₂, U₂ := P₁ + P₆, U₃ := U₂ + P₇, U₄ := U₂ + P₅,
      U₅ := U₄ + P₃, U₆ := U₃ - P₄, U₇ := U₃ + P₅ in
  ⟨U₁, U₅, U₆, U₇⟩

end smatrix

-- entry (i, j) of the 2^k × 2^k result is n + i + j
def arith_prog_mat_core : ∀ k : nat, ℕ → smatrix int k
| 0     n := (n:ℤ)
| (k+1) n := let t := 2^k in
  ⟨arith_prog_mat_core k n, arith_prog_mat_core k (n+t),
   arith_prog_mat_core k (n+t), arith_prog_mat_core k (n+2*t)⟩

def arith_prog_mat (m : nat) : smatrix int m := arith_prog_mat_core m 0

def size : nat := 6
def test_mat := smatrix.map (λ x, (100000 : int) * x) (arith_prog_mat size)

open smatrix

def test  := test_mat * test_mat
def test2 := strassen 2 test_mat test_mat
def test3 := strassen 3 test_mat test_mat
def test4 := strassen 4 test_mat test_mat
def test5 := strassen 5 test_mat test_mat
def test6 := strassen 6 test_mat test_mat
def test7 := strassen 7 test_mat test_mat

set_option profiler true
#eval test2
#eval test3
#eval test4
#eval test5
#eval test6
#eval test7
#eval test
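As a worked check on the code above, count scalar multiplications on smatrix α K (a 2^K × 2^K matrix). The naive mul does 8 recursive products per level; strassen t does 7 per level, but for a block smatrix α (k+1) it switches to naive mul whenever k < t, i.e. for blocks of size at most 2^t. So for K ≥ t:

```latex
M_{\text{naive}}(K) = 8^{K}, \qquad
M_{\text{strassen}}(t, K) = 7^{K-t}\, 8^{t}
```

With size = 6 (64×64 matrices), test2 (t = 2) needs 7^4 · 8^2 = 153664 scalar multiplications against 8^6 = 262144 for test, roughly 59%. For t ≥ 6 the top-level call already takes the naive branch, so test6 and test7 perform exactly the same multiplications as test.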
Mario Carneiro (Feb 11 2019 at 06:18):
@Seul Baek
Seul Baek (Feb 11 2019 at 06:40):
On my computer I'm seeing improvements by a factor of 2-3. It's surprising how small things add up
Seul Baek (Feb 11 2019 at 06:50):
Is there a way to measure the execution time for #reduce?
Mario Carneiro (Feb 11 2019 at 06:53):
sadly no. It's on my todo list (see: https://leanprover.zulipchat.com/#narrow/stream/113488-general/topic/lean.20community.20fork/near/157996991)
Mario Carneiro (Feb 11 2019 at 06:53):
does timeit work?
Seul Baek (Feb 11 2019 at 07:06):
It reports a time, but I think it's giving the VM evaluation time (way too fast to be the kernel)
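For reference, Lean 3 core's timeit takes a label and a thunk and, when run under #eval, reports how long evaluating the thunk took; as observed above, that is VM time, not kernel time. A minimal sketch, relying on Lean 3 wrapping the second argument into a thunk automatically:

```lean
-- `timeit` prints the label and the time spent evaluating its thunk in the VM
#eval timeit "int product" ((123456789 : ℤ) * 987654321)
```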
Last updated: Dec 20 2023 at 11:08 UTC