Zulip Chat Archive

Stream: lean4

Topic: memoization of strong recursion on nat


Joachim Breitner (May 10 2023 at 07:24):

I just defined a function Nat → Nat via recursion that calls itself on many smaller values (like fibonacci), and of course it gets very slow very quickly. In Haskell I’d define a lazy list to speed this up using memoize. What’s the usual way to do that in Lean? Is there maybe already an efficient

memo : (∀ n : Nat, (Fin n → a) → a) → (Nat → a)

somewhere, and if not, what would be the best way to implement it (in a way that I can prove it equal to the naive version)?
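
For reference, the naive version can be written down directly, and it is the behaviour any `memo` should provably agree with. The following is a minimal sketch, assuming a recent Lean 4 toolchain (for the `termination_by n` syntax) and using the hypothetical names `naiveMemo` and `fibNaive`. It caches nothing: every call `r ⟨i, _⟩` recomputes `naiveMemo f i` from scratch, so fibonacci-style recursions take exponential time.

```lean
-- Hypothetical reference semantics for `memo`: no caching at all.
-- Each call `r ⟨i, _⟩` recomputes `naiveMemo f i` from scratch.
def naiveMemo {α : Type} (f : (n : Nat) → (Fin n → α) → α) (n : Nat) : α :=
  f n fun ⟨i, _hi⟩ => naiveMemo f i  -- `_hi : i < n` justifies termination
termination_by n

-- Fibonacci (with fib 0 = fib 1 = 1) phrased in the `(Fin n → α) → α` style.
def fibNaive (n : Nat) : Nat :=
  naiveMemo (α := Nat)
    (fun m r => if h : m < 2 then 1 else r ⟨m - 1, by omega⟩ + r ⟨m - 2, by omega⟩) n

#eval fibNaive 20  -- 10946; fine for 20, hopeless for 100
```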

Floris van Doorn (May 10 2023 at 09:14):

I think you're looking for docs4#memoFix

Floris van Doorn (May 10 2023 at 09:15):

Oh, but you cannot prove this one equal to the naive version...

Joachim Breitner (May 10 2023 at 11:12):

This is what I came up with so far, which already helps a bit:

def memoVec : ∀ {a}, (∀ n, Vector a n -> a) -> (∀ n, Vector a n)
 | _, _, 0 => Vector.nil
 | _, f, succ n =>
   let v := memoVec f n
   Vector.cons (f n v) v

def memo : ∀ {a}, (∀ n : Nat, (Fin n → a) → a) → (Nat → a) := fun {a} f n =>
  have f' : ∀ n, Vector a n -> a := fun n v =>
    f n (fun ⟨i, hi⟩ => Vector.get v ⟨n - succ i,
      Nat.sub_lt_self (Nat.zero_lt_succ _) (Nat.succ_le_of_lt hi)
    ⟩)
  Vector.get (memoVec f' (succ n)) 0

Is using #eval the fastest way to evaluate code, or is there something “better”?
Will this represent natural numbers efficiently by default, or is it still using the unary representation?
(My code does arithmetic with Nat and Rat and the numbers get large.)

Jannis Limperg (May 10 2023 at 11:27):

Afaik #eval uses bignums for Nat and is the fastest available interpreter. But if you package your code as a program and compile it, you should get an additional (big?) speedup.
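
A minimal sketch of the “package it as a program” route, assuming a Lake project whose executable target has this file as its root (project and target names are up to you). `fibLoop` is a hypothetical stand-in workload, not code from this thread. In both `#eval` and compiled code, Lean's runtime stores small `Nat` values as unboxed machine integers and large ones as arbitrary-precision bignums, so the unary `Nat.succ` representation is only the logical model.

```lean
-- Hypothetical stand-in workload: iterative Fibonacci with fib 0 = fib 1 = 1.
def fibLoop (n : Nat) : Nat := Id.run do
  let mut a := 1  -- holds fib i
  let mut b := 1  -- holds fib (i + 1)
  for _ in [0:n] do
    let next := a + b
    a := b
    b := next
  return a

-- Entry point of the compiled executable (e.g. built with `lake build`).
def main : IO Unit :=
  IO.println (fibLoop 1000)
```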

Kyle Miller (May 10 2023 at 12:57):

@Joachim Breitner Maybe this version would interest you. It uses an Array instead, which is a bit more efficient since it's a contiguous block of memory rather than a linked list, and also it makes the indexing calculations simpler since you can efficiently push to the end.

import Lean
import Mathlib.Data.Fin.Basic

/-- Arrays of a given size. -/
def SArray (α : Type _) (n : Nat) := {a : Array α // a.size = n}

namespace SArray

protected def push (a : SArray α n) (x : α) : SArray α (n + 1) :=
  ⟨a.1.push x, by rw [Array.size_push, a.2]⟩

protected def get (a : SArray α n) (i : Fin n) : α := a.1.get (a.2.symm ▸ i)

protected def empty : SArray α 0 := ⟨Array.empty, rfl⟩

end SArray

def memoVec (f : (n : Nat) → SArray α n → α) : (n : Nat) → SArray α n
  | 0 => .empty
  | n + 1 =>
    let v := memoVec f n
    v.push (f n v)

def memo (f : (n : Nat) → (Fin n → α) → α) (n : Nat) : α :=
  let f' (n : Nat) (v : SArray α n) : α := f n v.get
  (memoVec f' (n + 1)).get n

def fib (n : Nat) : Nat := memo (n := n) fun
  | 0, _ | 1, _ => 1
  | n + 2, f => f n + f (n + 1)

#eval fib 100
-- 573147844013817084101
#eval fib 1000
-- 70330367711422815821...5245323403501

Joachim Breitner (May 10 2023 at 13:40):

Thanks, that would have been my next question: better data structures. Very nice. Is Vector.push better than linear, or still linear?

Kyle Miller (May 10 2023 at 13:42):

Both SArray.push and Vector.push are (amortized) constant time (so long as there's no sharing of the SArray; if there's sharing, it's linear time to make a copy).

Kyle Miller (May 10 2023 at 13:43):

SArray.push is only amortized constant time unless you pre-allocate an array using docs4#Array.mkEmpty to get enough capacity to never need to re-allocate.

Joachim Breitner (May 10 2023 at 13:44):

Thanks!

Kyle Miller (May 10 2023 at 13:46):

If you want to incorporate that, what you could do is switch SArray.empty to

protected def empty (capacity : Nat := 0) : SArray α 0 := ⟨Array.mkEmpty capacity, rfl⟩

and then pass in a capacity:

def memoVec (capacity : Nat) (f : (n : Nat) → SArray α n → α) : (n : Nat) → SArray α n
  | 0 => .empty capacity
  | n + 1 =>
    let v := memoVec capacity f n
    v.push (f n v)

def memo (f : (n : Nat) → (Fin n → α) → α) (n : Nat) : α :=
  let f' (n : Nat) (v : SArray α n) : α := f n v.get
  (memoVec (n + 1) f' (n + 1)).get n

Joachim Breitner (May 10 2023 at 13:48):

It may be a bit faster, but something is still slower than it ought to be (I think, maybe I am just expecting too much), if I use the memo like this:

def vf (p : ℚ) (n : ℕ) (r : Fin n -> ℚ) : ℚ :=
  if hn : n = 0 then 0 else
  /- all tails -/
  (1-p)^n * r ⟨n-1, by aesop (add safe Nat.sub_lt, safe Nat.zero_lt_of_ne_zero)⟩ +
  ∑ j : Fin n,
    /- j < n tails, so n-j heads -/
    Nat.choose n j * p^(n-j) * (1-p)^(j : ℕ) *
    /- Pick best next step -/
    Finset.sup' (Finset.range (n-j))
      (by simp)
      (fun i => (1+i) + r ⟨n-(1+i), by aesop (add safe Nat.sub_lt, safe Nat.zero_lt_of_ne_zero)⟩)

def fast_v (n : ℕ) (p : ℚ) : ℚ :=
  memo (vf p) n

#eval fast_v 15 0.5

If I increase the 15 to, say, 20, it gets much slower very quickly.

Kyle Miller (May 10 2023 at 13:51):

I wouldn't be surprised if docs4#Nat.choose were slow (take a look at how it's defined)

Kyle Miller (May 10 2023 at 13:52):

It'd be worth replacing Nat.choose n j with 1 to see if things speed up
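
For context, Mathlib defines `Nat.choose` directly by Pascal's rule. The sketch below copies those equations under a hypothetical name, `pascalChooseRec`, using only core Lean. Each call with `k > 0` makes two recursive calls and nothing is shared; since the result is a sum of the `1`s returned by base cases, the number of calls is at least the value being computed.

```lean
-- Same equations as Mathlib's `Nat.choose` (Pascal's rule), under a hypothetical
-- name. No sharing: the two recursive calls recompute overlapping subproblems.
def pascalChooseRec : Nat → Nat → Nat
  | _, 0 => 1
  | 0, _ + 1 => 0
  | n + 1, k + 1 => pascalChooseRec n k + pascalChooseRec n (k + 1)

#eval pascalChooseRec 20 10  -- 184756, a sum of at least that many leaf calls
```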

Henrik Böving (May 10 2023 at 13:52):

We could csimp optimize that I guess?

Joachim Breitner (May 10 2023 at 13:55):

Is csimp a mechanism to tell the system to use a different, faster implementation for a given function?
(If so, then you just answered the question I’d have had once I successfully defined the faster version of the plain recursive v function :-))

Henrik Böving (May 10 2023 at 13:55):

Yes

Joachim Breitner (May 10 2023 at 13:55):

Cool :-)

Eric Wieser (May 10 2023 at 13:55):

Yes; but it requires the signatures of the two functions to be identical

Kyle Miller (May 10 2023 at 13:56):

which is fine here (Eric is probably referencing for example that more complicated functions only have more efficient implementations with additional typeclass assumptions)
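
A minimal, self-contained illustration of `@[csimp]` (all names hypothetical; the functions are deliberately trivial so that the proof stays short). The theorem has to equate two constants of exactly the same type, stated as `@f = @g`; code compiled afterwards that refers to `slowTriple` should then run `fastTriple` instead.

```lean
-- Reference definition: easy to reason about, one call per unit of `n`.
def slowTriple : Nat → Nat
  | 0 => 0
  | n + 1 => slowTriple n + 3

-- Faster implementation with exactly the same type `Nat → Nat`.
def fastTriple (n : Nat) : Nat := 3 * n

-- Tells the compiler to replace `slowTriple` by `fastTriple` in compiled code.
@[csimp] theorem slowTriple_eq_fastTriple : @slowTriple = @fastTriple := by
  funext n
  induction n with
  | zero => rfl
  | succ n ih =>
    -- goal: slowTriple n + 3 = 3 * (n + 1), using ih : slowTriple n = 3 * n
    simp only [slowTriple, fastTriple] at ih ⊢
    omega
```

With the attribute in place, something like `#eval slowTriple 1000000` should be answered by `fastTriple` rather than by a million nested calls.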

Joachim Breitner (May 10 2023 at 14:01):

Indeed, these three magic lines greatly speed it up:

def fast_choose n k := Nat.descFactorial n k / Nat.factorial k
@[csimp] lemma choose_eq_fast_choose : Nat.choose = fast_choose :=
  funext (fun _ => funext (Nat.choose_eq_descFactorial_div_factorial _))

Kyle Miller (May 10 2023 at 14:04):

A bit better of an algorithm is probably

def Nat.fast_choose (n k : Nat) : Nat := Id.run do
  let mut r : Nat := 1
  for i in [0 : k] do
    r := (r * (n - i)) / (i + 1)
  return r

since it involves smaller intermediate values, though your fast_choose is great since it involves not having to prove anything new :smile:

Kyle Miller (May 10 2023 at 14:04):

I don't actually know if it's better though, since it's possible that doing a single division with large numbers is better than doing k divisions :shrug:

Johan Commelin (May 10 2023 at 14:08):

I don't know a thing about algorithms, but I had naively thought that doing the recursive (but memoized!) version of choose would be faster than all those multiplications and divisions.

Johan Commelin (May 10 2023 at 14:10):

But I guess O(n * k) additions can also become pretty slow pretty fast.
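
The “recursive but memoized” choose can be sketched bottom-up: fill in Pascal's triangle one row at a time, keeping a single row of `k + 1` entries in an `Array`. `dpChoose` is a hypothetical helper in core Lean, not Mathlib's implementation; it performs roughly `n * k` additions, which is the O(n * k) cost mentioned here.

```lean
-- Hypothetical bottom-up Pascal's rule. Invariant: after processing `i` rows,
-- `row[j]! = choose i j` for every `j ≤ k`.
def dpChoose (n k : Nat) : Nat := Id.run do
  -- row 0 of Pascal's triangle, truncated to columns 0..k: [1, 0, ..., 0]
  let mut row : Array Nat := ((List.replicate (k + 1) 0).toArray).set! 0 1
  for _ in [0:n] do
    -- update right to left so `row[j - 1]!` still holds the previous row's value
    for j' in [0:k] do
      let j := k - j'
      row := row.set! j (row[j]! + row[j - 1]!)
  return row[k]!

#eval dpChoose 5 2  -- 10
```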

Kyle Miller (May 10 2023 at 14:12):

Another thing to consider is that the more that can be offloaded to individual nat operations, the more is being handled by some optimized C code from the GMP library

Johan Commelin (May 10 2023 at 14:12):

And I guess the same is true for Int operations, but not for Rat, maybe?

Henrik Böving (May 10 2023 at 14:13):

It is not necessarily the additions. Each function call carries an inherent overhead unless tail-call elimination is possible (which is not the case for the naive version of Nat.choose, nor for a memoized one). So a call does not just perform an addition; it also:

  1. saves some CPU state to memory
  2. looks up in memory whether the input pair has already been memoized
  3. does the actual computation
  4. makes more recursive calls
  5. once all those recursive calls have finished, restores the CPU state that was saved in step 1
  6. only then returns

The for loop, on the other hand, should (I think) be compiled into essentially a plain loop via tail-call elimination and inlining, so for the most part it can skip steps 1, 2, 4, 5 and 6 and just compute.

Of course I can't tell without further profiling whether this is the reason it is slower, but it certainly plays a part.
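
To make the tail-call point concrete, here is the multiplicative loop written as an explicitly self-tail-recursive helper (the names `chooseTail` and `go` are hypothetical). As far as I know, Lean's compiler turns a call of a function to itself in tail position into a jump, so this should run as a plain loop without a new stack frame per step, much like the `for` version. The naive Pascal recursion cannot be compiled this way because it still has to add the results of two calls after they return.

```lean
-- Hypothetical tail-recursive version of the multiplicative formula.
-- Invariant of `go fuel i r`: `r = choose n i`, and `fuel` steps remain.
def chooseTail (n k : Nat) : Nat := go k 0 1
where
  go : Nat → Nat → Nat → Nat
    | 0, _, r => r
    -- the recursive call is the last thing evaluated, i.e. a tail call;
    -- the division is exact because r * (n - i) = choose n (i + 1) * (i + 1)
    | fuel + 1, i, r => go fuel (i + 1) (r * (n - i) / (i + 1))

#eval chooseTail 20 10  -- 184756
```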

Johan Commelin (May 10 2023 at 14:17):

Thanks! That's helpful.

Joachim Breitner (May 10 2023 at 15:31):

My Lean has become very rusty… @Kyle Miller, can you maybe give me a hint about how to prove a lemma like this:

/-- Arrays of a given size, H'T Kyle Miller -/
def SArray (α : Type _) (n : Nat) := {a : Array α // a.size = n}

namespace SArray

protected def push {α n} (a : SArray α n) (x : α) : SArray α (n + 1) :=
  ⟨a.1.push x, by rw [Array.size_push, a.2]⟩

protected def get {α n} (a : SArray α n) (i : Fin n) : α := a.1.get (a.2.symm ▸ i)

theorem get_push {α n} (a : SArray α n) (x : α) (i : Nat) (hi : i < succ n) :
  (a.push x).get (⟨i, hi⟩) = (if h : i < n then a.get ⟨i, h⟩ else x) := by
  cases' a with v hv
  simp [SArray.get, SArray.push]
  rw [Array.get_push]
  cases (i < n) -- doesn’t work

Or, more generally, if I have a dependent if h : p then … else … in the goal, how do I do case analysis on that?

Mauricio Collares (May 10 2023 at 15:35):

Does by_cases instead of cases work?

Kyle Miller (May 10 2023 at 15:36):

Lean 4's split tactic does case analysis on an if
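
A minimal example of `split` on a dependent `if` (core Lean only; the theorem name is hypothetical). It leaves one goal per branch, and each goal has the condition, or its negation, as an extra hypothesis that `omega` can use here.

```lean
theorem dite_lt_succ_example (n i x : Nat) (hx : x = 0) :
    (if h : i < n then i else x) < n + 1 := by
  split
  · -- `then` branch: goal `i < n + 1`, with a hypothesis `i < n` (shown as `h✝`)
    omega
  · -- `else` branch: goal `x < n + 1`, with a hypothesis `¬i < n`
    omega
```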

Kyle Miller (May 10 2023 at 15:39):

That rw I put into get that rewrites i itself makes things a little tricky. Here's an alternative:

protected def get (a : SArray α n) (i : Fin n) : α := a.1.get ⟨i, a.2.symm ▸ i.2⟩

theorem get_push {α n} (a : SArray α n) (x : α) (i : Nat) (hi : i < n + 1) :
    (a.push x).get (⟨i, hi⟩) = (if h : i < n then a.get ⟨i, h⟩ else x) := by
  simp [SArray.get, SArray.push, Array.get_push, a.2]

Joachim Breitner (May 10 2023 at 20:04):

Very nice:

@[csimp]
lemma v_fast_v : v = fast_v := by
  apply funext; intro n
  apply funext; intro p
  rw [fast_v, memo_spec, v_fix_vf]

#eval v 50 0.5

(It seems I can’t use ext n p here because ext seems to overshoot and applies Rat.ext or something like that.)

Mario Carneiro (May 10 2023 at 20:08):

you can probably use funext n p
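
A minimal sketch of the `funext n p` tactic on a hypothetical two-argument goal: it applies function extensionality once per name given, so it cannot overshoot into the result type the way an unbounded `ext` can.

```lean
theorem funext_two_example (f g : Nat → Nat → Nat) (h : ∀ n p, f n p = g n p) :
    f = g := by
  funext n p   -- goal is now `f n p = g n p`
  exact h n p
```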

Mario Carneiro (May 10 2023 at 20:26):

also ext has a : 2 argument you can add if it goes too far

Joachim Breitner (May 11 2023 at 05:11):

Ah, that's good to know. I was trying ext 2. I thought I read the tooltip; if it's not there I'll maybe PR a doc improvement.

Joachim Breitner (May 11 2023 at 06:48):

Kyle Miller said:

That rw I put into get that rewrites i itself makes things a little tricky. Here's an alternative:

protected def get (a : SArray α n) (i : Fin n) : α := a.1.get ⟨i, a.2.symm ▸ i.2⟩

theorem get_push {α n} (a : SArray α n) (x : α) (i : Nat) (hi : i < n + 1) :
    (a.push x).get (⟨i, hi⟩) = (if h : i < n then a.get ⟨i, h⟩ else x) := by
  simp [SArray.get, SArray.push, Array.get_push, a.2]

This proof is surprisingly brittle. If I remove the unrelated looking import Mathlib.Data.Nat.Factorial.Basic it breaks. Too bad there is no squeeze_simp in lean4 yet; staring at the proof term didn’t yet tell me which lemma is lost.

Maybe it’s not some lemma, but some other simp-affecting setting that disappears when not importing that file?

Heather Macbeth (May 11 2023 at 06:52):

There is simp? as an alternative.
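
A usage sketch for `simp?` (the goal is a hypothetical example): it behaves like `simp` but also suggests an equivalent `simp only [...]` call listing the lemmas that were actually used, which is how one finds out which lemma a proof silently depends on. The exact list it prints depends on the Lean version and imports, so it is not reproduced here.

```lean
theorem simp_query_example (xs : List Nat) : (xs ++ []).length = xs.length := by
  simp?  -- closes the goal and reports "Try this: simp only [...]"
```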

Joachim Breitner (May 11 2023 at 06:55):

Ah, using trial and error with import I learned that
import Std.Data.Array.Lemmas is needed; import Std.Data.Array.Init.Lemmas is not enough.

Thanks, TIL simp?! This shows that the missing lemma is Array.get_eq_getElem from that module. All izz well.

Mario Carneiro (May 11 2023 at 07:21):

also FYI if you hover on Array.get_eq_getElem it should mention that it is defined in Std.Data.Array.Lemmas

Joachim Breitner (May 11 2023 at 07:37):

Right, but (without simp?) I didn’t even know which lemma simp was picking up. So simp? is very useful here, thanks!

Joachim Breitner (May 11 2023 at 07:38):

I’ve made the NatMemo code independent of mathlib (only std4), and with an example it looks like this:
https://gist.github.com/nomeata/b0929f2503fcab4d35717e92b5ba5e58
Is this something worth contributing to std4?

Joachim Breitner (May 11 2023 at 07:55):

Joachim Breitner said:

Indeed, these three magic lines greatly speed it up:

def fast_choose n k := Nat.descFactorial n k / Nat.factorial k
@[csimp] lemma choose_eq_fast_choose : Nat.choose = fast_choose :=
  funext (fun _ => funext (Nat.choose_eq_descFactorial_div_factorial _))

PR’ed in https://github.com/leanprover-community/mathlib4/pull/3915; not sure if such csimp tweaks should be in mathlib though.

Joachim Breitner (May 12 2023 at 15:29):

A function that’s defined recursively with strong recursion on nat gets elaborated to a call to WellFounded.fix with suitable parameters, one of which is the F which we can also pass to the memo function above.

So I think it should be possible to automate the process of extracting that “functorial”, defining the _fast variant using memo, and proving the csimp lemma, so in the end instead of

def slow (n : Nat) : Nat :=
  1 + List.foldl (fun a i => a + (if _ : i<n then slow i else 0)) 0 (List.range n)

def fast (n : Nat) : Nat :=
  NatMemo.memo (fun n r =>
    1 + List.foldl (fun a i => a + (if h : i<n then r i h else 0)) 0 (List.range n)
  ) n

@[csimp]
theorem slow_is_fast: slow = fast := by
  apply NatMemo.memo_spec
  intro n
  rw [slow]

one can just write

@[memo_csimp]
def slow (n : Nat) : Nat :=
  1 + List.foldl (fun a i => a + (if _ : i<n then slow i else 0)) 0 (List.range n)

I started coding this, and got the definition’s RHS via getConstInfoDefn and .value of type Expr. But getting hold of the 5th argument by just manual pattern matching with Expr.app is quite tedious. Is there a better way to implement “I expect a term of the following shape”? Like we can do very nicely on Syntax using the `(…) feature?

Joachim Breitner (May 12 2023 at 15:37):

Ah, Lean.Expr.getAppFn may help
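
A minimal sketch of that approach (the helper name is hypothetical): `Expr.getAppFn` returns the head of an application spine and `Expr.getAppArgs` its arguments, so there is no need to peel `Expr.app` nodes by hand. In `WellFounded.fix {α} {C} {r} hwf F x` the functional `F` is the fifth argument, i.e. index 4.

```lean
import Lean
open Lean

-- Hypothetical helper: if `e` is an application of `WellFounded.fix` with at
-- least five arguments, return the fifth one (the functional `F`).
def wfFixFunctional? (e : Expr) : Option Expr :=
  -- `getAppArgs` flattens `f a₁ … aₙ` (nested `Expr.app`) into `#[a₁, …, aₙ]`
  let args := e.getAppArgs
  if e.getAppFn.isConstOf ``WellFounded.fix && decide (5 ≤ args.size) then
    some args[4]!
  else
    none
```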

Joachim Breitner (May 12 2023 at 17:12):

It works!

@[memo]
def slow2 (n : Nat) : Nat :=
  1 + List.foldl (fun a i => a + (if _ : i<n then slow2 i else 0)) 0 (List.range n)

does all the things automatically! At least in this very particular case…

Joachim Breitner (May 12 2023 at 17:12):

Or here is another obvious example:

@[memo]
def fib : Nat → Nat
  | 0 | 1 => 1
  | n + 2 => fib n + fib (n + 1)
termination_by fib n => n

#eval fib 100 --fast

Joachim Breitner (May 13 2023 at 06:14):

Is there a dependent version of Array around somewhere? (With value types depending on the indices?)

Joachim Breitner (May 13 2023 at 07:43):

I have put the NatMemo code as a lakeified library on https://github.com/nomeata/lean4-memo-nat
there is example code at https://github.com/nomeata/lean4-memo-nat/blob/master/MemoNat/Demo.lean
The (probably usable) main memo code is at https://github.com/nomeata/lean4-memo-nat/blob/master/MemoNat.lean
And the (certainly experimental) attribute is in https://github.com/nomeata/lean4-memo-nat/blob/master/MemoNat/Attr.lean

I’d love to get some critical feedback of the attribute code; this is my first and probably very naive excursion into Lean metaprogramming :-)

Joachim Breitner (May 13 2023 at 07:43):

Also, is there a repository or at least list of lean4 libraries somewhere, for better discoverability?
And is there a “best practice” github repo setup that I can cargo cult, with CI and documentation pushing to github pages?

James Gallicchio (May 17 2023 at 17:27):

Joachim Breitner said:

Is there a dependent version of Array around somewhere? (With value types depending on the indices?)

Not that I know of. Back when I was making alternative Array primitives, I had played around with an implementation that was modeled by a dependent function from the indices. But I threw it out because I didn't have a use case and it was a pain to work with :joy:

Joachim Breitner (May 17 2023 at 17:42):

I only have an indirect use case: I had a non-dependent recursive function, which I sped up by memoizing with an Array. The technique applies to any use of WellFounded.fix, which allows dependent functions, so I was inclined to generalize my memoization library as well… but since _that_ generalization has no use case yet, it’s not at all pressing.

François G. Dorais (May 20 2023 at 06:12):

@Joachim Breitner A quick hack is to use something like this:

inductive Any.{u} : Type u
| mk {α : Sort u} : α → Any

protected abbrev Any.Sort : Any → Sort _
| @mk α _ => α

protected abbrev Any.val : (a : Any) → a.Sort
| mk x => x

def test : Array Any := #[.mk 1, .mk "Hello", .mk true]

#eval test[0].val -- 1
#eval test[1].val -- "Hello"
#eval test[2].val -- true

Joachim Breitner (May 20 2023 at 17:17):

That seems to lose the type information, though, so probably can’t be (safely) used to memoize a function of dependent type ∀ n, C n

François G. Dorais (May 21 2023 at 15:37):

The type info is still there, as the three evals at the end show. For a specific dependent type, it's simpler to just wrap your function values in a Sigma type.

Joachim Breitner (May 22 2023 at 16:03):

Ah, I think I now get it. I can put values of Any into the Vectors, and separately keep around the information that a[i].Sort = C i for some C : Nat → Sort _!

This indeed allows me to define

namespace DArray

protected def push {n C} (a : DArray n C) (x : C n) : DArray (n + 1) C :=
  a.1.push (Any.mk x), by     

protected def get {n C} (a : DArray n C) (i : Fin n) : C i :=
  a.2 i ▸ (a.1.get i).val

protected theorem get_push {n C} (a : DArray n C) (x : C n) (i : Nat) (hi : i < n + 1) :
    (a.push x).get ⟨i, hi⟩ =
      if h : i < n
      then a.get ⟨i, h⟩
      else (Nat.le_antisymm (Nat.le_of_lt_succ hi) (Nat.le_of_not_lt h) ▸ x : C i) := by

protected def empty {C} (cap : Nat) : DArray 0 C := ⟨SArray.empty cap, λ i => Fin.elim0 i⟩

end DArray

and then generalize the memo code at https://github.com/nomeata/lean4-memo-nat/blob/master/MemoNat.lean to dependent recursion.

(I didn’t manage to conclude the proof of get_push because the ▸ casts get in the way, unfortunately, so I’ll park it at https://github.com/nomeata/lean4-memo-nat/pull/2 for now, until someone actually needs this.)

Joachim Breitner (Jun 04 2023 at 15:56):

Tried some further. It seems I am hitting universe level issues: Playing around with

inductive Any.{u} : Type u
| mk {α : Sort u} : α → Any

protected abbrev Any.Sort : Any → Sort _
| @mk α _ => α

protected abbrev Any.val : (a : Any) → a.Sort
| mk x => x

structure DArray.{u} (n : Nat) (C : Nat → Type u) : Type (u+1) :=
  arr : Array Any.{u+1}
  size_eq : arr.size = n
  types : ∀ (i : Fin arr.size), (arr.get i).Sort = C i

protected def get.{u} {n} {C : Nat → Type u} (a : DArray n C) (i : Fin n) : C i := by
  rcases a with ⟨a, rfl, types⟩
  exact (types i).symm ▸ a.get i

I get

type mismatch
  Array.get a i
has type
  Any : Type (u + 1)
but is expected to have type
  Any.Sort (Array.get a i) : Type u

Is this approach, using a subtype, doomed to fail or is there a way?

Joachim Breitner (Jun 04 2023 at 15:59):

(The explicit universe levels in the structure DArray are the inferred ones.)

Joachim Breitner (Jun 04 2023 at 16:08):

Oh, here is something that worked, found by staring Lean down more or less at random:

protected def get.{u} {n} {C : Nat → Type u} (a : DArray n C) (i : Fin n) : C i := by
  rcases a with ⟨a, rfl, types⟩
  have h := types i
  revert h
  generalize Array.get a i = x
  cases x with | _ x =>
  rintro rfl
  exact x

So it works, and can maybe be made pretty (because in the current form proving things about it will be quite ugly)

Mario Carneiro (Jun 04 2023 at 16:08):

it still has a universe bump problem though

Joachim Breitner (Jun 04 2023 at 16:10):

You mean the definition of DArray? Is that an avoidable one?

Mario Carneiro (Jun 04 2023 at 16:16):

here's one way to do it without a universe bump, although it wastes some memory unless lean gets support for Erased:

structure DArray (n : Nat) (C : Nat → Type u) : Type u :=
  arr : Array (Σ i, C i)
  size_eq : arr.size = n
  types : ∀ (i : Fin arr.size), (arr.get i).1 = i

protected def DArray.get {n} {C : Nat → Type u} (a : DArray n C) (i : Fin n) : C i := by
  refine (?b : i.1 = _) ▸ (a.arr.get ⟨i, ?a⟩).2
  case a => exact a.2.symm ▸ i.2
  case b => exact (a.3 ⟨i, ?a⟩).symm ▸ rfl

Joachim Breitner (Jun 04 2023 at 20:11):

Thanks! I was thinking about that variant earlier, but it’s a bit disappointing with the extra (unerased) field.
But maybe I’ll play some more with this, thanks!

James Gallicchio (Jun 26 2023 at 08:19):

I ran into something I need DArray for, so I put together an implementation to PR to Std: https://github.com/leanprover/std4/pull/166

It relies on a TypeErased type for values whose type is noncomputable. Unsure if the interface is sound, would appreciate a second set of eyes on it.

Both the interface and the implementations are deep in cast hell, so suggestions are highly appreciated, but I think most of the casting cannot be avoided.

James Gallicchio (Jun 26 2023 at 08:21):

(The TypeErased is essentially equivalent to the Any and Σ i, C i types above, but has no runtime overhead because it is implemented with unsafeCasts)

Joachim Breitner (Oct 22 2023 at 14:33):

Small update: I finished support for memoizing functions recursing on nat also when they have a dependent result type, and the @[memo] attribute now also works when the function is elaborated to use Nat.brecOn (not just WellFounded.fix):

Here is an example from Demo.lean that calculates Pascal’s triangle, and gets much faster with @[memo]:

@[memo]
def pascal : (i : Nat) → SArray Nat i
  | 0 => SArray.empty
  | n + 1 => SArray.ofFn (n + 1) fun i =>
      pad_left 1 (pascal n) i + pad_right (pascal n) 1 i

(It would also get much faster if I had used a let there, but that’s not the point :-)).
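
The remark about `let` can be illustrated without the library. `SArray.ofFn`, `pad_left` and `pad_right` come from Demo.lean and are not reproduced here, so this hypothetical `pascalRow` uses a plain `Array Nat` (row `n` has length `n + 1`, unlike `pascal`, whose row `i` has length `i`). Binding the previous row once with `let`, outside the per-entry function, means it is computed once per level instead of potentially being recomputed for every entry.

```lean
-- Hypothetical analogue of `pascal`: row `n` of Pascal's triangle.
def pascalRow : Nat → Array Nat
  | 0 => #[1]
  | n + 1 =>
    -- computed once here, then shared by every entry of the new row
    let prev := pascalRow n
    Array.ofFn fun (j : Fin (n + 2)) =>
      -- entry j is prev[j - 1] + prev[j], reading out-of-range entries as 0
      (if j.val = 0 then 0 else prev[j.val - 1]!) + prev.getD j.val 0

#eval pascalRow 4  -- #[1, 4, 6, 4, 1]
```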

I am using the variant of dependent arrays using the Any type that Kyle suggested, and am living with the universe bump (DArray); would be interesting to know if this can be avoided (with the same efficient representation).

On top of that, I also have dependent arrays with the length as an index (DSArray), which incidentally is isomorphic to Nat.below. I need this conversion when replacing Nat.brecOn with my memo combinator, although I fear that this is rather expensive. Maybe I should just use the Nat.below data structure directly in these cases.

I doubt that this is very useful, but it is a useful learning exercise for me…


Last updated: Dec 20 2023 at 11:08 UTC