Zulip Chat Archive

Stream: general

Topic: arrays and performance


Johan Commelin (Apr 09 2020 at 21:43):

I've never worried about performance of Lean code, and I have no intuition for it at all.
Keeley wrote the following code:

meta def map_copy_aux {α : Type u} {β : Type v} {n m : ℕ} (f : α → β) :
    ℕ → array n α → array m β → array m β
| r x y := if h : r < n ∧ r < m then
             let fn : fin n := ⟨r, and.elim_left h⟩ in
             let fm : fin m := ⟨r, and.elim_right h⟩ in
             map_copy_aux (r + 1) x $ y.write fm (f $ x.read fn)
           else y

meta def map_copy {α : Type u} {β : Type v} {n m : ℕ} (x : array n α) (y : array m β) (f : α → β) :
  array m β :=
map_copy_aux f 0 x y

and I massaged it into

def map_copy_aux (f : α → β) : Π (r : ℕ), r < n → r < m → array n α → array m β → array m β
| 0     hn hm x y := y.write ⟨0, hm⟩ (f $ x.read ⟨0, hn⟩)
| (r+1) hn hm x y :=
  (map_copy_aux r (lt.trans r.lt_succ_self hn) (lt.trans r.lt_succ_self hm) x y)
  .write ⟨r+1, hm⟩ (f $ x.read ⟨r+1, hn⟩)

def map_copy : Π {n m : ℕ} (x : array n α) (y : array m β) (f : α → β), array m β
| n 0 x y f := y
| 0 m x y f := y
| (n+1) (m+1) x y f := map_copy_aux f (min n m)
  (lt_of_le_of_lt (min_le_left n m) n.lt_succ_self)
  (lt_of_le_of_lt (min_le_right n m) m.lt_succ_self)
  x y

The benefit is that the meta is gone. But did my rewrite damage the performance?

Reid Barton (Apr 09 2020 at 21:50):

Normally this would be bad but I think @Gabriel Ebner said the Lean 3 VM doesn't implement tail recursion, so maybe it doesn't matter.

Johan Commelin (Apr 09 2020 at 21:52):

But I have the same type of recursion as Keeley's code, right?

Reid Barton (Apr 09 2020 at 21:53):

Keeley's looks tail recursive to me, yours isn't because the last thing you do is a write, not a recursive call.
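
A smaller Lean 3 sketch of the same distinction (not from the thread; `sum_to` and `sum_to_acc` are hypothetical names). In the first definition the addition still has to run after the recursive call returns; in the second the recursive call is the last step, and the running total travels in an accumulator.

```lean
-- Not tail recursive: `+` is applied to the result of the recursive call.
def sum_to : ℕ → ℕ
| 0       := 0
| (n + 1) := (n + 1) + sum_to n

-- Tail recursive: the recursive call is the final step; `acc` carries the total.
def sum_to_acc : ℕ → ℕ → ℕ
| acc 0       := acc
| acc (n + 1) := sum_to_acc (acc + (n + 1)) n

#eval sum_to 10        -- 55
#eval sum_to_acc 0 10  -- 55
```

Keeley's `map_copy_aux` has the shape of `sum_to_acc`, with the array `y` playing the role of the accumulator, while the structurally recursive rewrite has the shape of `sum_to`.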

Johan Commelin (Apr 09 2020 at 21:54):

Ooh, you are right

Chris Hughes (Apr 09 2020 at 21:55):

You can make the first one not meta quite easily by saying m - r is decreasing.

Johan Commelin (Apr 09 2020 at 21:56):

Hmm, but I don't know which wf.* I need to invoke for that. The error message said that it couldn't figure out how to prove r + 1 < r... which is clearly not the thing it should be trying.

Chris Hughes (Apr 09 2020 at 22:12):

import data.nat.basic

universes u v

def map_copy_aux {α : Type u} {β : Type v} {n m : ℕ} (f : α → β) :
    ℕ → array n α → array m β → array m β
| r x y := if h : r < n ∧ r < m then
             let fn : fin n := ⟨r, and.elim_left h⟩ in
             let fm : fin m := ⟨r, and.elim_right h⟩ in
             have wf : m - (r + 1) < m - r,
               from (nat.sub_lt_sub_left_iff h.2).2 (nat.lt_succ_self _),
             map_copy_aux (r + 1) x $ y.write fm (f $ x.read fn)
           else y
using_well_founded {rel_tac := λ _ _, `[exact ⟨_, measure_wf (λ a, m - a.1)⟩]}

Chris Hughes (Apr 09 2020 at 22:16):

Import-free version:

universes u v

def map_copy_aux {α : Type u} {β : Type v} {n m : ℕ} (f : α → β) :
    ℕ → array n α → array m β → array m β
| r x y := if h : r < n ∧ r < m then
             let fn : fin n := ⟨r, and.elim_left h⟩ in
             let fm : fin m := ⟨r, and.elim_right h⟩ in
             have wf : m - (r + 1) < m - r,
               from nat.lt_of_succ_le $ by rw [← nat.succ_sub h.2, nat.succ_sub_succ],
             map_copy_aux (r + 1) x $ y.write fm (f $ x.read fn)
           else y
using_well_founded {rel_tac := λ _ _, `[exact ⟨_, measure_wf (λ a, m - a.1)⟩]}

Mario Carneiro (Apr 09 2020 at 22:26):

Using well founded recursion like this is the best option for VM performance because the VM bytecode compiler ignores the messy wf induction produced for the benefit of lean by the equation compiler, and just blindly uses the equation as it would with the meta def

Mario Carneiro (Apr 09 2020 at 22:27):

that is, Chris's version should produce exactly the same code as Keeley's
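
For reference, a sketch of how the well-founded definition can sit behind a non-`meta` wrapper. The definition is Chris's import-free version, and the wrapper is the original `map_copy` with `meta` removed. The `#eval` line and its sample arrays are illustrative assumptions, not from the thread.

```lean
universes u v

-- Chris's import-free definition: recursion on `r`, justified by `m - r` decreasing.
def map_copy_aux {α : Type u} {β : Type v} {n m : ℕ} (f : α → β) :
    ℕ → array n α → array m β → array m β
| r x y := if h : r < n ∧ r < m then
             let fn : fin n := ⟨r, and.elim_left h⟩ in
             let fm : fin m := ⟨r, and.elim_right h⟩ in
             have wf : m - (r + 1) < m - r,
               from nat.lt_of_succ_le $ by rw [← nat.succ_sub h.2, nat.succ_sub_succ],
             map_copy_aux (r + 1) x $ y.write fm (f $ x.read fn)
           else y
using_well_founded {rel_tac := λ _ _, `[exact ⟨_, measure_wf (λ a, m - a.1)⟩]}

-- The wrapper no longer needs `meta`, since `map_copy_aux` is an ordinary definition.
def map_copy {α : Type u} {β : Type v} {n m : ℕ}
  (x : array n α) (y : array m β) (f : α → β) : array m β :=
map_copy_aux f 0 x y

-- Illustrative call: map `+ 1` over a length-3 array into a length-5 array.
#eval (map_copy (mk_array 3 (2 : ℕ)) (mk_array 5 (0 : ℕ)) (λ a, a + 1)).to_list
```

Only the first `min n m` entries of `y` are overwritten; the remaining entries keep their original values.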

Reid Barton (Apr 09 2020 at 22:37):

Thanks, that was going to be my next question.

Johan Commelin (Apr 10 2020 at 05:28):

@Chris Hughes Thanks a lot!

Johan Commelin (Apr 10 2020 at 05:36):

The monadic version is

def mmap_copy_aux (f : α → k β) : ℕ → array n α → array m β → k (array m β)
| r x y := do if h : r < n ∧ r < m then do
                let fn : fin n := ⟨r, and.elim_left h⟩,
                let fm : fin m := ⟨r, and.elim_right h⟩,
                y ← y.write fm <$> f (x.read fn),
                have wf : m - (r + 1) < m - r,
                  from nat.lt_of_succ_le $ by rw [← nat.succ_sub h.2, nat.succ_sub_succ],
                mmap_copy_aux (r + 1) x y
              else return y
using_well_founded {rel_tac := λ _ _, `[exact ⟨_, measure_wf (λ a, m - a.1)⟩]}

Can I deduce the "regular" version by just applying the monadic version with k = id? Or would that still take a (minor) performance hit? I would hope that the compiler is smart enough to get rid of the ids...

Reid Barton (Apr 10 2020 at 13:32):

A real compiler should be able to do that but I don't think Lean 3 has a real compiler.
Actually, without that inlining, your function is no longer tail recursive so it might be more than a minor performance hit (though that depends on the relative costs of things, especially f)
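
One way to see why: `do`-notation turns the final line into an argument of `bind`, so the recursive call lives inside a function handed to the monad rather than in tail position. A tiny Lean 3 sketch, assuming a generic monad `k` on `Type` (the names `a` and `g` are placeholders, not from the thread):

```lean
-- `do y ← a, g y` is notation for `a >>= λ y, g y`;
-- the call to `g` happens inside the continuation passed to `>>=`.
example {k : Type → Type} [monad k] (a : k ℕ) (g : ℕ → k ℕ) :
  (do y ← a, g y) = (a >>= λ y, g y) :=
rfl
```

In `mmap_copy_aux`, `g` corresponds to `λ y, mmap_copy_aux (r + 1) x y`. Unless the `bind` of `id` is inlined away, that call is made by `bind` rather than directly by the function itself.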


Last updated: Dec 20 2023 at 11:08 UTC