Zulip Chat Archive

Stream: general

Topic: Strassen's algorithm with int and znum


Seul Baek (Feb 11 2019 at 04:37):

I've implemented Strassen's algorithm in Lean (https://github.com/skbaek/strassen). The hope was that it would speed up multiplication of large matrices, but its behavior is a bit strange.

1. When multiplying znum matrices, I get a nice U-shaped curve when plotting execution time against the crossover point, and the optimal crossover point (in number of rows/columns) seems to be in the 32-128 range. With int matrices, however, the algorithm shows only negligible improvement over naive multiplication. Since Strassen's algorithm saves multiplications at the cost of extra matrix additions/subtractions, I think this means that addition/subtraction is relatively more expensive (compared to multiplication) for int than for znum. Am I correct about this?

2. What's really puzzling is that multiplying int matrices is about 5 times faster than multiplying znum matrices. At first I thought this was because the matrix entries were not large enough, but the difference persists even when the matrices are scaled by constant factors of 1000-100000. The only case where znum outperforms int is when the computation involves many additions of the form z+z, which only require a bit shift. I thought znum was the better datatype for arithmetic with large numbers. Am I doing something wrong here?
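A rough way to see the trade-off in point 1 is a simplified operation count that ignores allocation, tupling and other data-structure overhead. Write m and a for the cost of one scalar multiplication and one scalar addition, h = n/2 for the block size, and P(h) for the cost of a naive h×h product (h^3 multiplications, h^2(h-1) additions). One recursion step of the naive block algorithm does 8 half-size products and 4 half-size additions; the Winograd variant of Strassen's algorithm (the S/T/P/U scheme in Mario's `strassen` later in this thread) does 7 products and 15 additions:

```latex
\begin{aligned}
P(h) &= h^3 m + h^2 (h-1)\, a \\
C_{\text{naive}}(n) &= 8\,P(h) + 4h^2 a, \qquad h = n/2 \\
C_{\text{SW}}(n) &= 7\,P(h) + 15h^2 a \\
C_{\text{SW}}(n) < C_{\text{naive}}(n) &\iff P(h) > 11h^2 a \iff h\,(m + a) > 12\,a
\end{aligned}
```

In this model one Strassen step pays off once h exceeds about 12a/(m + a), so the more expensive addition is relative to multiplication, the larger the crossover. With m ≥ a the model predicts a crossover of n = 2h ≤ 12, so an observed optimum of 32-128 presumably reflects costs outside this count.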

Simon Hudon (Feb 11 2019 at 05:15):

Are you doing the computation in kernel reductions or in the VM?

Mario Carneiro (Feb 11 2019 at 05:16):

vector.halve can be constructed in one pass
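A minimal sketch of the one-pass idea on plain lists. The name `split_at_once` is hypothetical and Seul's actual `vector.halve` is not shown here. Rather than computing, say, `list.take n l` and `list.drop n l` separately (which walks the prefix twice), a single recursion returns both parts at once. A vector version would additionally have to produce the two length proofs.

```lean
-- Hypothetical helper: split a list after its first n elements,
-- returning both parts from a single traversal of the prefix.
def split_at_once {α : Type} : ℕ → list α → list α × list α
| 0       l        := ([], l)
| (n + 1) []       := ([], [])
| (n + 1) (a :: l) :=
  let p := split_at_once n l in (a :: p.1, p.2)

-- Halving a list of known even length 2 * n is then `split_at_once n`.
#eval split_at_once 2 [1, 2, 3, 4]
```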

Mario Carneiro (Feb 11 2019 at 05:17):

As a kernel data structure, I would use balanced binary trees instead of vectors of length 2^k

Seul Baek (Feb 11 2019 at 05:17):

@Simon Hudon In the kernel

Seul Baek (Feb 11 2019 at 05:19):

@Mario Carneiro That's a good point, I should have used trees

Mario Carneiro (Feb 11 2019 at 05:20):

def big_vector (α : Type) : ℕ → Type
| 0 := α
| (n+1) := big_vector n × big_vector n

Mario Carneiro (Feb 11 2019 at 05:21):

I'm a bit worried about the kernel having to compute with the proofs

Mario Carneiro (Feb 11 2019 at 05:21):

since vectors have a proof component
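To illustrate the proof component: in Lean 3, `vector α n` is defined as the subtype `{l : list α // l.length = n}`, so every vector value pairs a list with a length proof, and operations such as `vector.cons` build a new proof as they go. The values `v` and `w` below are only illustrative; Mario's worry is that kernel reduction of vector code may have to process these proof terms too, whereas `big_vector` has no proof component at all.

```lean
import data.vector

-- A vector is a list together with a proof about its length.
def v : vector ℕ 2 := ⟨[1, 2], rfl⟩

-- `vector.cons` returns a new list and a new length proof.
def w : vector ℕ 3 := vector.cons 0 v

#eval w.to_list
```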

Seul Baek (Feb 11 2019 at 05:22):

Although it can't be trees all the way down, since we have to resort to regular matrix multiplication at some crossover point > 1 for efficiency

Mario Carneiro (Feb 11 2019 at 05:22):

You can also implement regular matrix multiplication on these kinds of matrices

Mario Carneiro (Feb 11 2019 at 05:23):

although perhaps the size being a power of 2 is a bit restrictive

Mario Carneiro (Feb 11 2019 at 05:23):

but I guess Strassen has that built in, so it should be fine
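When the size is not a power of 2, one common workaround (an assumption here, not something proposed in this thread) is to zero-pad the matrix up to the next power of 2. A sketch of a hypothetical helper computing the quad-tree depth k needed for an n×n matrix:

```lean
-- Hypothetical helper: the smallest k with n ≤ 2^k, found by counting up
-- from k = 0. The fuel argument (initially n) makes the recursion structural;
-- if the fuel runs out, k has reached n, and n ≤ 2^n, so the answer is still correct.
def pad_depth_aux (n : ℕ) : ℕ → ℕ → ℕ
| 0          k := k
| (fuel + 1) k := if n ≤ 2 ^ k then k else pad_depth_aux fuel (k + 1)

def pad_depth (n : ℕ) : ℕ := pad_depth_aux n n 0

-- A 100×100 matrix would be padded to 128×128, i.e. depth 7.
#eval pad_depth 100
```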

Mario Carneiro (Feb 11 2019 at 05:24):

You have more tupling and untupling than you need

Mario Carneiro (Feb 11 2019 at 05:25):

like there is no reason quadruple and double_cols should be uncurried
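To make the curried/uncurried distinction concrete, here are two hypothetical stand-ins (not the `quadruple` or `double_cols` from Seul's repository). In the uncurried form the caller must build a nested tuple and the callee must take it apart again; the curried form passes the arguments directly.

```lean
-- Uncurried: one argument, a nested tuple (a, (b, (c, d))).
def quad_list_uncurried {α : Type} (p : α × α × α × α) : list α :=
[p.1, p.2.1, p.2.2.1, p.2.2.2]

-- Curried: four separate arguments, no intermediate tuple.
def quad_list_curried {α : Type} (a b c d : α) : list α :=
[a, b, c, d]

#eval quad_list_curried 1 2 3 4
```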

Seul Baek (Feb 11 2019 at 05:26):

The only reason they are uncurried is that it makes statements of correctness lemmas simpler, should I want to verify them

Seul Baek (Feb 11 2019 at 05:27):

Do the (un)tuplings significantly slow things down? In that case, I guess they should be curried

Mario Carneiro (Feb 11 2019 at 05:28):

Not sure by how much, but I expect so. I don't think I've identified any big problems yet, just linear slowdowns

Mario Carneiro (Feb 11 2019 at 05:31):

all your tests use #eval?

Seul Baek (Feb 11 2019 at 05:32):

@Mario Carneiro Yes, I used #eval

Mario Carneiro (Feb 11 2019 at 05:33):

well, it's no surprise int is fast then

Seul Baek (Feb 11 2019 at 05:38):

Oops... you're right, I forgot #eval is VM by default
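A small sketch of the distinction, with arbitrary example expressions. As far as I know, `#eval` runs compiled code in the VM, where `nat` and `int` operations have native arbitrary-precision implementations, whereas `znum` is an ordinary inductive type of binary digits with no such special support. That would explain why `int` wins under `#eval`. `#reduce` instead normalizes the term by unfolding definitions, so `int` gets no special treatment there.

```lean
-- VM evaluation: `int` multiplication uses native bignum arithmetic.
#eval (123456789 : ℤ) * 987654321

-- Normalization by unfolding definitions: keep the numbers small.
#reduce (12 : ℤ) * 34
```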

Mario Carneiro (Feb 11 2019 at 06:17):

Here's an implementation using quad-trees. How does it compare to your numbers?

structure quad (α : Type*) := (a b c d : α)

def smatrix (α : Type*) : ℕ → Type*
| 0 := α
| (n+1) := quad (smatrix n)

namespace smatrix

def const {α} (c : α) : ∀ k, smatrix α k
| 0 := c
| (k+1) := ⟨const k, const k, const k, const k⟩

def map {α β} (f : α → β) : ∀ {k}, smatrix α k → smatrix β k
| 0 a := f a
| (k+1) ⟨A₁₁, A₁₂, A₂₁, A₂₂⟩ := ⟨map A₁₁, map A₁₂, map A₂₁, map A₂₂⟩

def map₂ {α β γ} (f : α → β → γ) : ∀ {k}, smatrix α k → smatrix β k → smatrix γ k
| 0     A B := f A B
| (k+1) ⟨A₁₁, A₁₂, A₂₁, A₂₂⟩ ⟨B₁₁, B₁₂, B₂₁, B₂₂⟩ :=
  ⟨map₂ A₁₁ B₁₁, map₂ A₁₂ B₁₂, map₂ A₂₁ B₂₁, map₂ A₂₂ B₂₂⟩

instance {α} [has_add α] {k} : has_add (smatrix α k) := ⟨map₂ (+)⟩
instance {α} [has_neg α] {k} : has_neg (smatrix α k) := ⟨map has_neg.neg⟩
instance {α} [has_sub α] {k} : has_sub (smatrix α k) := ⟨map₂ has_sub.sub⟩

-- naive block multiplication: 8 recursive products and 4 block additions
def mul {α} [has_add α] [has_mul α] :
  ∀ {k : nat}, smatrix α k → smatrix α k → smatrix α k
| 0     A B := @has_mul.mul α _ A B
| (k+1) ⟨A₁₁, A₁₂, A₂₁, A₂₂⟩ ⟨B₁₁, B₁₂, B₂₁, B₂₂⟩ :=
  ⟨mul A₁₁ B₁₁ + mul A₁₂ B₂₁, mul A₁₁ B₁₂ + mul A₁₂ B₂₂,
   mul A₂₁ B₁₁ + mul A₂₂ B₂₁, mul A₂₁ B₁₂ + mul A₂₂ B₂₂⟩

instance {α} [has_add α] [has_mul α] {k} : has_mul (smatrix α k) := ⟨mul⟩

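-- Strassen-Winograd (7 products, 15 block additions); matrices of size
-- 2^(k+1) ≤ 2^t fall back to the naive `mul`
def strassen {α} [ring α] (t : nat) : ∀ {k : nat}, smatrix α k → smatrix α k → smatrix α k
| 0 A B := A * B
| (k+1) A@⟨A₁₁, A₁₂, A₂₁, A₂₂⟩ B@⟨B₁₁, B₁₂, B₂₁, B₂₂⟩ :=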
  if k < t
  then @has_mul.mul (smatrix α (k+1)) _ A B
  else
  let S₁ := A₂₁ + A₂₂, S₂ := S₁ - A₁₁, S₃ := A₁₁ - A₂₁, S₄ := A₁₂ - S₂ in
  let T₁ := B₁₂ - B₁₁, T₂ := B₂₂ - T₁, T₃ := B₂₂ - B₁₂, T₄ := T₂ - B₂₁ in
  let P₁ := strassen A₁₁ B₁₁,
      P₂ := strassen A₁₂ B₂₁,
      P₃ := strassen S₄ B₂₂,
      P₄ := strassen A₂₂ T₄,
      P₅ := strassen S₁ T₁,
      P₆ := strassen S₂ T₂,
      P₇ := strassen S₃ T₃ in
  let U₁ := P₁ + P₂,
      U₂ := P₁ + P₆,
      U₃ := U₂ + P₇,
      U₄ := U₂ + P₅,
      U₅ := U₄ + P₃,
      U₆ := U₃ - P₄,
      U₇ := U₃ + P₅ in
  ⟨U₁, U₅, U₆, U₇⟩

end smatrix

-- arith_prog_mat_core k n is the 2^k × 2^k matrix whose (i, j) entry is n + i + j
def arith_prog_mat_core : ∀ k : nat, ℕ → smatrix int k
| 0     n := (n:ℤ)
| (k+1) n := let t := 2^k in
  ⟨arith_prog_mat_core k n,
   arith_prog_mat_core k (n+t),
   arith_prog_mat_core k (n+t),
   arith_prog_mat_core k (n+2*t)⟩

def arith_prog_mat (m : nat) : smatrix int m :=
arith_prog_mat_core m 0

def size : nat := 6

def test_mat := smatrix.map (λ x, (100000 : int) * x) (arith_prog_mat size)

open smatrix
def test  := test_mat * test_mat
def test2 := strassen 2 test_mat test_mat
def test3 := strassen 3 test_mat test_mat
def test4 := strassen 4 test_mat test_mat
def test5 := strassen 5 test_mat test_mat
def test6 := strassen 6 test_mat test_mat
def test7 := strassen 7 test_mat test_mat

set_option profiler true

#eval test2
#eval test3
#eval test4
#eval test5
#eval test6
#eval test7
#eval test
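As a hand-checkable sanity test of the S/T/P/U formulas in `strassen`, here is a hypothetical scalar version `winograd_2x2` of one Strassen-Winograd step, with integers in place of sub-matrices, next to the naive 2×2 product. For A = [[1, 2], [3, 4]] and B = [[5, 6], [7, 8]], both give C = [[19, 22], [43, 50]].

```lean
-- Hypothetical scalar version of one Strassen-Winograd step: same formulas
-- as `strassen`, with integers in place of sub-matrices.
def winograd_2x2 (a₁₁ a₁₂ a₂₁ a₂₂ b₁₁ b₁₂ b₂₁ b₂₂ : ℤ) : ℤ × ℤ × ℤ × ℤ :=
let s₁ := a₂₁ + a₂₂, s₂ := s₁ - a₁₁, s₃ := a₁₁ - a₂₁, s₄ := a₁₂ - s₂ in
let t₁ := b₁₂ - b₁₁, t₂ := b₂₂ - t₁, t₃ := b₂₂ - b₁₂, t₄ := t₂ - b₂₁ in
-- 7 multiplications instead of 8
let p₁ := a₁₁ * b₁₁, p₂ := a₁₂ * b₂₁, p₃ := s₄ * b₂₂, p₄ := a₂₂ * t₄,
    p₅ := s₁ * t₁, p₆ := s₂ * t₂, p₇ := s₃ * t₃ in
let u₁ := p₁ + p₂, u₂ := p₁ + p₆, u₃ := u₂ + p₇, u₄ := u₂ + p₅,
    u₅ := u₄ + p₃, u₆ := u₃ - p₄, u₇ := u₃ + p₅ in
(u₁, u₅, u₆, u₇)   -- (C₁₁, C₁₂, C₂₁, C₂₂)

-- The naive 2×2 product for comparison: 8 multiplications, 4 additions.
def naive_2x2 (a₁₁ a₁₂ a₂₁ a₂₂ b₁₁ b₁₂ b₂₁ b₂₂ : ℤ) : ℤ × ℤ × ℤ × ℤ :=
(a₁₁ * b₁₁ + a₁₂ * b₂₁, a₁₁ * b₁₂ + a₁₂ * b₂₂,
 a₂₁ * b₁₁ + a₂₂ * b₂₁, a₂₁ * b₁₂ + a₂₂ * b₂₂)

#eval winograd_2x2 1 2 3 4 5 6 7 8
#eval naive_2x2 1 2 3 4 5 6 7 8
```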

Mario Carneiro (Feb 11 2019 at 06:18):

@Seul Baek

Seul Baek (Feb 11 2019 at 06:40):

On my computer I'm seeing improvements by a factor of 2-3. It's surprising how much small things add up.

Seul Baek (Feb 11 2019 at 06:50):

Is there a way to measure the execution time for #reduce?

Mario Carneiro (Feb 11 2019 at 06:53):

sadly no. It's on my todo list (see: https://leanprover.zulipchat.com/#narrow/stream/113488-general/topic/lean.20community.20fork/near/157996991)

Mario Carneiro (Feb 11 2019 at 06:53):

does timeit work?

Seul Baek (Feb 11 2019 at 07:06):

It reports a time, but I think it's giving the VM evaluation time (way too fast to be the kernel)
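For reference, a minimal use of `timeit` from Lean 3 core, which takes a label and a thunk and reports how long forcing the thunk took. Under `#eval` this times VM execution, consistent with what Seul observed, so it does not measure `#reduce`. The `int_sum` workload is a hypothetical example.

```lean
-- Hypothetical workload: 0 + 1 + ... + (n - 1), accumulated as an integer.
def int_sum (n : ℕ) : ℤ :=
list.foldl (λ (acc : ℤ) (i : ℕ), acc + (i : ℤ)) 0 (list.range n)

-- The argument is wrapped into a thunk automatically; this times VM evaluation.
#eval timeit "int_sum in the VM" (int_sum 100000)
```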


Last updated: Dec 20 2023 at 11:08 UTC