Zulip Chat Archive
Stream: lean4
Topic: memoization of strong recursion on Nat
Joachim Breitner (May 10 2023 at 07:24):
I just defined a function Nat → Nat via recursion that calls itself on many smaller values (like fibonacci), and of course it gets very slow very quickly. In Haskell I’d define a lazy list to speed this up via memoization. What’s the usual way to do that in Lean? Is there maybe already an efficient
memo : (∀ n : Nat, (Fin n → a) → a) → (Nat → a)
somewhere, and if not, what would be the best way to implement it (in a way that I can prove it equal to the naive version)?
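For concreteness, a minimal sketch (not from the thread) of the kind of definition meant: fibonacci by recursion on smaller values, where each call recomputes everything below it, so naive evaluation is exponential:
def naiveFib : Nat → Nat
  | 0 | 1 => 1
  | n + 2 => naiveFib n + naiveFib (n + 1)

#eval naiveFib 30 -- already noticeably slow in the interpreter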
Floris van Doorn (May 10 2023 at 09:14):
I think you're looking for docs4#memoFix
Floris van Doorn (May 10 2023 at 09:15):
Oh, but you cannot prove this one equal to the naive version...
Joachim Breitner (May 10 2023 at 11:12):
This is what I came up with so far, which already helps a bit:
def memoVec : ∀ {a}, (∀ n, Vector a n -> a) -> (∀ n, Vector a n)
  | _, _, 0 => Vector.nil
  | _, f, succ n =>
    let v := memoVec f n
    Vector.cons (f n v) v
def memo : ∀ {a}, (∀ n : Nat, (Fin n → a) → a) → (Nat → a) := fun {a} f n =>
  have f' : ∀ n, Vector a n -> a := fun n v =>
    f n (fun ⟨i, hi⟩ => Vector.get v ⟨n - succ i,
      Nat.sub_lt_self (Nat.zero_lt_succ _) (Nat.succ_le_of_lt hi)⟩)
  Vector.get (memoVec f' (succ n)) 0
Is using #eval the fastest way to evaluate code, or is there something “better”?
Will this execute natural numbers efficiently by default, or is it still using the unary representation? (My code does arithmetic with Nat and Rat, and the numbers get large.)
Jannis Limperg (May 10 2023 at 11:27):
Afaik #eval uses bignums for Nat and is the fastest available interpreter. But if you package your code as a program and compile it, you should get an additional (big?) speedup.
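To make this concrete, a minimal sketch (not from the thread; assumes a fib like the ones defined below, in a standard lake project):
def main : IO Unit :=
  IO.println (fib 1000)
-- `lake build`, then run the produced binary; the compiled code
-- avoids the interpreter overhead entirely.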
Kyle Miller (May 10 2023 at 12:57):
@Joachim Breitner Maybe this version would interest you. It uses an Array instead, which is a bit more efficient since it's a contiguous block of memory rather than a linked list; it also makes the index calculations simpler, since you can efficiently push to the end.
import Lean
import Mathlib.Data.Fin.Basic
/-- Arrays of a given size. -/
def SArray (α : Type _) (n : Nat) := {a : Array α // a.size = n}
namespace SArray
protected def push (a : SArray α n) (x : α) : SArray α (n + 1) :=
  ⟨a.1.push x, by rw [Array.size_push, a.2]⟩
protected def get (a : SArray α n) (i : Fin n) : α := a.1.get (a.2.symm ▸ i)
protected def empty : SArray α 0 := ⟨Array.empty, rfl⟩
end SArray
def memoVec (f : (n : Nat) → SArray α n → α) : (n : Nat) → SArray α n
  | 0 => .empty
  | n + 1 =>
    let v := memoVec f n
    v.push (f n v)

def memo (f : (n : Nat) → (Fin n → α) → α) (n : Nat) : α :=
  let f' (n : Nat) (v : SArray α n) : α := f n v.get
  (memoVec f' (n + 1)).get n
def fib (n : Nat) : Nat := memo (n := n) fun
  | 0, _ | 1, _ => 1
  | n + 2, f => f n + f (n + 1)
#eval fib 100
-- 573147844013817084101
#eval fib 1000
-- 70330367711422815821...5245323403501
Joachim Breitner (May 10 2023 at 13:40):
Thanks, that would have been my next question: better data structures. Very nice. Is Vector.push better than linear, or still linear?
Kyle Miller (May 10 2023 at 13:42):
Both SArray.push and Vector.push are (amortized) constant time, so long as there's no sharing of the SArray; if there's sharing, it's linear time to make a copy.
Kyle Miller (May 10 2023 at 13:43):
SArray.push is only amortized constant time unless you pre-allocate an array using docs4#Array.mkEmpty to get enough capacity to never need to re-allocate.
Joachim Breitner (May 10 2023 at 13:44):
Thanks!
Kyle Miller (May 10 2023 at 13:46):
If you want to incorporate that, what you could do is switch SArray.empty to
protected def empty (capacity : Nat := 0) : SArray α 0 := ⟨Array.mkEmpty capacity, rfl⟩
and then pass in a capacity:
def memoVec (capacity : Nat) (f : (n : Nat) → SArray α n → α) : (n : Nat) → SArray α n
  | 0 => .empty capacity
  | n + 1 =>
    let v := memoVec capacity f n
    v.push (f n v)
def memo (f : (n : Nat) → (Fin n → α) → α) (n : Nat) : α :=
  let f' (n : Nat) (v : SArray α n) : α := f n v.get
  (memoVec (n + 1) f' (n + 1)).get n
Joachim Breitner (May 10 2023 at 13:48):
It may be a bit faster, but something is still slower than it ought to be (I think; maybe I am just expecting too much) if I use the memo like this:
def vf (p : ℚ) (n : ℕ) (r : Fin n -> ℚ) : ℚ :=
  if hn : n = 0 then 0 else
    /- all tails -/
    (1-p)^n * r ⟨n-1, by aesop (add safe Nat.sub_lt, safe Nat.zero_lt_of_ne_zero)⟩ +
    ∑ j : Fin n,
      /- j < n tails, so n-j heads -/
      Nat.choose n j * p^(n-j) * (1-p)^(j : ℕ) *
      /- Pick best next step -/
      Finset.sup' (Finset.range (n-j))
        (by simp)
        (fun i => (1+i) + r ⟨n-(1+i), by aesop (add safe Nat.sub_lt, safe Nat.zero_lt_of_ne_zero)⟩)
def fast_v (n : ℕ) (p : ℚ) : ℚ :=
  memo (vf p) n
#eval fast_v 15 0.5
If I increase the 15 to, say, 20, it quickly gets much slower.
Kyle Miller (May 10 2023 at 13:51):
I wouldn't be surprised if docs4#Nat.choose were slow (take a look at how it's defined)
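For reference, Nat.choose is defined by Pascal's rule, paraphrased here (renamed to avoid clashing with the real one); evaluated naively this performs exponentially many additions, which is why it is slow:
def chooseDef : Nat → Nat → Nat
  | _,     0     => 1
  | 0,     _ + 1 => 0
  | n + 1, k + 1 => chooseDef n k + chooseDef n (k + 1)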
Kyle Miller (May 10 2023 at 13:52):
It'd be worth replacing Nat.choose n j with 1 to see if things speed up.
Henrik Böving (May 10 2023 at 13:52):
We could csimp optimize that I guess?
Joachim Breitner (May 10 2023 at 13:55):
Is csimp a mechanism to tell the system to use a different, faster implementation for a given function?
(If so, then you just answered the question I’d have once I have successfully defined the faster version of the plain recursive v function :-))
Henrik Böving (May 10 2023 at 13:55):
Yes
Joachim Breitner (May 10 2023 at 13:55):
Cool :-)
Eric Wieser (May 10 2023 at 13:55):
Yes; but it requires the signature of the two functions to be identical
Kyle Miller (May 10 2023 at 13:56):
which is fine here (Eric is probably referencing for example that more complicated functions only have more efficient implementations with additional typeclass assumptions)
Joachim Breitner (May 10 2023 at 14:01):
Indeed, these three magic lines greatly speed it up:
def fast_choose n k := Nat.descFactorial n k / Nat.factorial k

@[csimp] lemma choose_eq_fast_choose : Nat.choose = fast_choose :=
  funext (fun _ => funext (Nat.choose_eq_descFactorial_div_factorial _))
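With the @[csimp] lemma in place, both compiled code and #eval route Nat.choose through fast_choose, so for example the following is now fast even though the naive definition would take forever:
#eval Nat.choose 1000 500 -- now computed via fast_choose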
Kyle Miller (May 10 2023 at 14:04):
A bit better of an algorithm is probably
def Nat.fast_choose (n k : Nat) : Nat := Id.run do
  let mut r : Nat := 1
  for i in [0 : k] do
    r := (r * (n - i)) / (i + 1)
  return r
since it involves smaller intermediate values, though your fast_choose is great since it doesn't require proving anything new :smile:
Kyle Miller (May 10 2023 at 14:04):
I don't actually know if it's better though, since it's possible that doing a single division with large numbers is better than doing k divisions :shrug:
Johan Commelin (May 10 2023 at 14:08):
I don't know a thing about algorithms, but I had naively thought that doing the recursive (but memoized!) version of choose would be faster than all those multiplications and divisions.
Johan Commelin (May 10 2023 at 14:10):
But I guess O(n * k) additions can also become pretty slow pretty fast.
Kyle Miller (May 10 2023 at 14:12):
Another thing to consider is that the more that can be offloaded to individual nat operations, the more is being handled by some optimized C code from the GMP library
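An aside to illustrate this (not from the thread): a single operation on large Nats is one GMP call and fast even interpreted; the cost in the recursive definitions above comes from the sheer number of calls, not from the bignum arithmetic itself.
#eval (2 ^ 10000) % 1000003 -- a ~3000-digit intermediate, still instant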
Johan Commelin (May 10 2023 at 14:12):
And I guess the same is true for Int operations, but not for Rat, maybe?
Henrik Böving (May 10 2023 at 14:13):
It is not necessarily the additions. Each function call carries an inherent overhead with it unless it is possible to perform tail recursion (which is not the case in the naive version of the Nat.choose function, and also not if you memoize it). So calling a function does not only perform an addition; it also performs:
- some bookkeeping of CPU state in memory
- a lookup in memory for whether an input pair has been memoized yet
- the actual computation
- more recursive calls
- once all those recursive calls have finished, restoring the part of the CPU state that we bookkept in memory before
- the actual return
And on the other hand, the for loop should, I think, be optimized into basically precisely a loop via tail recursion elimination and inlining, such that it can for the most part omit steps 1, 2, 4, 5 and 6 and just compute.
Now I can of course not tell (without further profiling) if this is the reason that it is slower, but it certainly plays a part.
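To make this concrete, here is a hedged sketch (not Lean's actual output) of roughly what the for loop becomes: a helper whose recursive call is the last thing it does, written here with an explicit fuel counter so the recursion is structural; the compiler can turn this into a plain loop with no stack bookkeeping.
def chooseLoop (n k : Nat) : Nat := go k 0 1
where
  -- fuel counts remaining iterations; i is the loop index, r the accumulator
  go : Nat → Nat → Nat → Nat
    | 0, _, r => r
    | fuel + 1, i, r => go fuel (i + 1) ((r * (n - i)) / (i + 1))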
Johan Commelin (May 10 2023 at 14:17):
Thanks! That's helpful.
Joachim Breitner (May 10 2023 at 15:31):
My Lean has become very rusty… @Kyle Miller, can you maybe give me a hint about how to prove a lemma like this:
/-- Arrays of a given size, H'T Kyle Miller -/
def SArray (α : Type _) (n : Nat) := {a : Array α // a.size = n}
namespace SArray
protected def push {α n} (a : SArray α n) (x : α) : SArray α (n + 1) :=
  ⟨a.1.push x, by rw [Array.size_push, a.2]⟩

protected def get {α n} (a : SArray α n) (i : Fin n) : α := a.1.get (a.2.symm ▸ i)

theorem get_push {α n} (a : SArray α n) (x : α) (i : Nat) (hi : i < succ n) :
    (a.push x).get (⟨i, hi⟩) = (if h : i < n then a.get ⟨i, h⟩ else x) := by
  cases' a with v hv
  simp [SArray.get, SArray.push]
  rw [Array.get_push]
  cases (i < n) -- doesn’t work
Or, more generally, if I have a dependent if h : p then … else … in the goal, how do I do case analysis on that?
Mauricio Collares (May 10 2023 at 15:35):
Does by_cases instead of cases work?
Kyle Miller (May 10 2023 at 15:36):
Lean 4's split tactic does case analysis on an if.
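A minimal illustration (a made-up example, not from the thread): split on a goal containing a dependent if yields one goal per branch, with the branch hypothesis available in each.
example (n i : Nat) (f : i < n → Nat) :
    0 ≤ if h : i < n then f h else 0 := by
  split <;> exact Nat.zero_le _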
Kyle Miller (May 10 2023 at 15:39):
That rw I put into get that rewrites i itself makes things a little tricky. Here's an alternative:
protected def get (a : SArray α n) (i : Fin n) : α := a.1.get ⟨i, a.2.symm ▸ i.2⟩
theorem get_push {α n} (a : SArray α n) (x : α) (i : Nat) (hi : i < n + 1) :
    (a.push x).get (⟨i, hi⟩) = (if h : i < n then a.get ⟨i, h⟩ else x) := by
  simp [SArray.get, SArray.push, Array.get_push, a.2]
Joachim Breitner (May 10 2023 at 20:04):
Very nice:
@[csimp]
lemma v_fast_v : v = fast_v := by
  apply funext; intro n
  apply funext; intro p
  rw [fast_v, memo_spec, v_fix_vf]

#eval v 50 0.5
(It seems I can’t use ext n p here because ext seems to overshoot and applies Rat.ext or something like that.)
Mario Carneiro (May 10 2023 at 20:08):
you can probably use funext n p
Mario Carneiro (May 10 2023 at 20:26):
also ext has a : 2 argument you can add if it goes too far
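An illustrative sketch (made-up example, assuming Mathlib is imported for ℚ and the ext tactic): with : 2, ext stops after two extensionality steps instead of also descending into the codomain via Rat.ext.
example (f g : Nat → Nat → ℚ) (h : ∀ n p, f n p = g n p) : f = g := by
  ext n p : 2
  exact h n p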
Joachim Breitner (May 11 2023 at 05:11):
Ah, that's good to know. I was trying ext 2. I thought I read the tooltip; if it's not there I'll maybe PR a doc improvement.
Joachim Breitner (May 11 2023 at 06:48):
Kyle Miller said:
That rw I put into get that rewrites i itself makes things a little tricky. Here's an alternative:

protected def get (a : SArray α n) (i : Fin n) : α := a.1.get ⟨i, a.2.symm ▸ i.2⟩

theorem get_push {α n} (a : SArray α n) (x : α) (i : Nat) (hi : i < n + 1) :
    (a.push x).get (⟨i, hi⟩) = (if h : i < n then a.get ⟨i, h⟩ else x) := by
  simp [SArray.get, SArray.push, Array.get_push, a.2]
This proof is surprisingly brittle. If I remove the unrelated-looking import Mathlib.Data.Nat.Factorial.Basic, it breaks. Too bad there is no squeeze_simp in Lean 4 yet; staring at the proof term didn’t yet tell me which lemma is lost.
Maybe it’s not some lemma, but some other simp-affecting setting that disappears when not importing that file?
Heather Macbeth (May 11 2023 at 06:52):
There is simp? as an alternative.
Joachim Breitner (May 11 2023 at 06:55):
Ah, using trial and error with the imports, I learned that import Std.Data.Array.Lemmas is needed; import Std.Data.Array.Init.Lemmas is not enough.
Thanks, TIL simp?! This shows that the missing lemma is Array.get_eq_getElem from that module. All izz well.
Mario Carneiro (May 11 2023 at 07:21):
also FYI, if you hover over Array.get_eq_getElem it should mention that it is defined in Std.Data.Array.Lemmas
Joachim Breitner (May 11 2023 at 07:37):
Right, but (without simp?) I didn’t even know which lemma simp was picking up. So simp? is very useful here, thanks!
Joachim Breitner (May 11 2023 at 07:38):
I’ve made the NatMemo code independent of mathlib (only std4), and with an example it looks like this:
https://gist.github.com/nomeata/b0929f2503fcab4d35717e92b5ba5e58
Is this something worth contributing to std4?
Joachim Breitner (May 11 2023 at 07:55):
Joachim Breitner said:
Indeed, these three magic lines greatly speed it up:
def fast_choose n k := Nat.descFactorial n k / Nat.factorial k

@[csimp] lemma choose_eq_fast_choose : Nat.choose = fast_choose :=
  funext (fun _ => funext (Nat.choose_eq_descFactorial_div_factorial _))
PR’ed in https://github.com/leanprover-community/mathlib4/pull/3915; not sure if such csimp tweaks should be in mathlib though.
Joachim Breitner (May 12 2023 at 15:29):
A function that’s defined recursively with strong recursion on Nat gets elaborated to a call to WellFounded.fix with suitable parameters, one of which is the F that we can also pass to the memo function above.
So I think it should be possible to automate the process of extracting that “functorial”, defining the _fast variant using memo, and proving the csimp lemma, so that in the end, instead of
def slow (n : Nat) : Nat :=
  1 + List.foldl (fun a i => a + (if _ : i<n then slow i else 0)) 0 (List.range n)

def fast (n : Nat) : Nat :=
  NatMemo.memo (fun n r =>
    1 + List.foldl (fun a i => a + (if h : i<n then r i h else 0)) 0 (List.range n)
  ) n

@[csimp]
theorem slow_is_fast : slow = fast := by
  apply NatMemo.memo_spec
  intro n
  rw [slow]
one can just write
@[memo_csimp]
def slow (n : Nat) : Nat :=
  1 + List.foldl (fun a i => a + (if _ : i<n then slow i else 0)) 0 (List.range n)
I started coding this, and got the definition’s RHS via getConstInfoDefn and its .value field of type Expr. But getting hold of the 5th argument by just manual pattern matching with Expr.app is quite tedious. Is there a better way to implement “I expect a term of the following shape”? Like we can do on Syntax very nicely using the `(…) feature?
Joachim Breitner (May 12 2023 at 15:37):
Ah, Lean.Expr.getAppFn may help
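For instance, a hedged sketch of that approach (fifthFixArg? is a made-up name; assumes import Lean; getAppArgs is zero-indexed, so [4]? is the 5th argument): instead of matching nested Expr.app constructors, use the application-spine API.
open Lean in
def fifthFixArg? (e : Expr) : Option Expr := do
  -- check the head of the application spine, then index into its arguments
  guard (e.getAppFn.isConstOf ``WellFounded.fix)
  e.getAppArgs[4]?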
Joachim Breitner (May 12 2023 at 17:12):
It works!
@[memo]
def slow2 (n : Nat) : Nat :=
  1 + List.foldl (fun a i => a + (if _ : i<n then slow2 i else 0)) 0 (List.range n)
does all the things automatically! At least in this very particular case…
Joachim Breitner (May 12 2023 at 17:12):
Or here is another obvious example:
@[memo]
def fib : Nat → Nat
  | 0 | 1 => 1
  | n + 2 => fib n + fib (n + 1)
termination_by fib n => n
#eval fib 100 --fast
Joachim Breitner (May 13 2023 at 06:14):
Is there a dependent version of Array around somewhere? (With value types depending on the indices?)
Joachim Breitner (May 13 2023 at 07:43):
I have put the NatMemo code as a lakeified library at https://github.com/nomeata/lean4-memo-nat
There is example code at https://github.com/nomeata/lean4-memo-nat/blob/master/MemoNat/Demo.lean
The (probably usable) main memo code is at https://github.com/nomeata/lean4-memo-nat/blob/master/MemoNat.lean
And the (certainly experimental) attribute is in https://github.com/nomeata/lean4-memo-nat/blob/master/MemoNat/Attr.lean
I’d love to get some critical feedback on the attribute code; this is my first and probably very naive excursion into Lean metaprogramming :-)
Joachim Breitner (May 13 2023 at 07:43):
Also, is there a repository, or at least a list, of Lean 4 libraries somewhere, for better discoverability?
And is there a “best practice” GitHub repo setup that I can cargo-cult, with CI and documentation pushed to GitHub Pages?
James Gallicchio (May 17 2023 at 17:27):
Joachim Breitner said:
Is there a dependent version of Array around somewhere? (With value types depending on the indices?)
Not that I know of. Back when I was making alternative Array primitives, I had played around with an implementation that was modeled by a dependent function from the indices. But I threw it out because I didn't have a use case and it was a pain to work with :joy:
Joachim Breitner (May 17 2023 at 17:42):
I only have an indirect use case: I had a non-dependent recursive function, which I sped up with memoization using Array; the technique applies to any use of WellFounded.fix, and that allows dependent functions, so I had an inclination to generalize my memoization library as well… but since _that_ generalization has no use case yet, it’s not at all pressing.
François G. Dorais (May 20 2023 at 06:12):
@Joachim Breitner A quick hack is to use something like this:
inductive Any.{u} : Type u
  | mk {α : Sort u} : α → Any

protected abbrev Any.Sort : Any → Sort _
  | @mk α _ => α

protected abbrev Any.val : (a : Any) → a.Sort
  | mk x => x

def test : Array Any := #[.mk 1, .mk "Hello", .mk true]
#eval test[0].val -- 1
#eval test[1].val -- "Hello"
#eval test[2].val -- true
Joachim Breitner (May 20 2023 at 17:17):
That seems to lose the type information, though, so it probably can’t be (safely) used to memoize a function of dependent type ∀ n, C n.
François G. Dorais (May 21 2023 at 15:37):
The type info is still there, as the three evals at the end show. For a specific dependent type, it's simpler to just wrap your function values in a Sigma type.
Joachim Breitner (May 22 2023 at 16:03):
Ah, I think I now get it. I can put values of Any into the vectors, and separately keep around the information that a[i].Sort = C i for some C : Nat → Sort _!
This indeed allows me to define
namespace DArray

protected def push {n C} (a : DArray n C) (x : C n) : DArray (n + 1) C :=
  ⟨a.1.push (Any.mk x), by … ⟩

protected def get {n C} (a : DArray n C) (i : Fin n) : C i :=
  a.2 i ▸ (a.1.get i).val

protected theorem get_push {n C} (a : DArray n C) (x : C n) (i : Nat) (hi : i < n + 1) :
    (a.push x).get ⟨i, hi⟩ =
      if h : i < n
      then a.get ⟨i, h⟩
      else (Nat.le_antisymm (Nat.le_of_lt_succ hi) (Nat.le_of_not_lt h) ▸ x : C i) := by …

protected def empty {C} (cap : Nat) : DArray 0 C := ⟨SArray.empty cap, λ i => Fin.elim0 i⟩

end DArray
and then generalize the memo code at https://github.com/nomeata/lean4-memo-nat/blob/master/MemoNat.lean to dependent recursion.
(I didn’t manage to conclude the proof of get_push because the ▸s get in the way, unfortunately, so I’ll park it at https://github.com/nomeata/lean4-memo-nat/pull/2 for now, until someone actually needs this.)
Joachim Breitner (Jun 04 2023 at 15:56):
I tried a bit further. It seems I am hitting universe level issues: playing around with
inductive Any.{u} : Type u
  | mk {α : Sort u} : α → Any

protected abbrev Any.Sort : Any → Sort _
  | @mk α _ => α

protected abbrev Any.val : (a : Any) → a.Sort
  | mk x => x

structure DArray.{u} (n : Nat) (C : Nat → Type u) : Type (u+1) :=
  arr : Array Any.{u+1}
  size_eq : arr.size = n
  types : ∀ (i : Fin arr.size), (arr.get i).Sort = C i

protected def get.{u} {n} {C : Nat → Type u} (a : DArray n C) (i : Fin n) : C i := by
  rcases a with ⟨a, rfl, types⟩
  exact (types i).symm ▸ a.get i
I get
type mismatch
  Array.get a i
has type
  Any : Type (u + 1)
but is expected to have type
  Any.Sort (Array.get a i) : Type u
Is this approach, using a subtype, doomed to fail or is there a way?
Joachim Breitner (Jun 04 2023 at 15:59):
(The explicit universe levels in the structure DArray
are the inferred ones.)
Joachim Breitner (Jun 04 2023 at 16:08):
Oh, here is some random staring-Lean-down that worked:
protected def get.{u} {n} {C : Nat → Type u} (a : DArray n C) (i : Fin n) : C i := by
  rcases a with ⟨a, rfl, types⟩
  have h := types i
  revert h
  generalize Array.get a i = x
  cases x with | _ x =>
    rintro rfl
    exact x
So it works, and it can maybe be made pretty (because in the current form, proving things about it will be quite ugly).
Mario Carneiro (Jun 04 2023 at 16:08):
it still has a universe bump problem though
Joachim Breitner (Jun 04 2023 at 16:10):
You mean the definition of DArray? Is that an avoidable one?
Mario Carneiro (Jun 04 2023 at 16:16):
here's one way to do it without a universe bump, although it wastes some memory unless Lean gets support for Erased:
structure DArray (n : Nat) (C : Nat → Type u) : Type u :=
  arr : Array (Σ i, C i)
  size_eq : arr.size = n
  types : ∀ (i : Fin arr.size), (arr.get i).1 = i

protected def DArray.get {n} {C : Nat → Type u} (a : DArray n C) (i : Fin n) : C i := by
  refine (?b : i.1 = _) ▸ (a.arr.get ⟨i, ?a⟩).2
  case a => exact a.2.symm ▸ i.2
  case b => exact (a.3 ⟨i, ?a⟩).symm ▸ rfl
Joachim Breitner (Jun 04 2023 at 20:11):
Thanks! I was thinking about that variant earlier, but the extra (unerased) field is a bit disappointing.
But maybe I’ll play some more with this, thanks!
James Gallicchio (Jun 26 2023 at 08:19):
I ran into something I need DArray for, so I put together an implementation to PR to Std: https://github.com/leanprover/std4/pull/166
It relies on a TypeErased type for values whose type is noncomputable. Unsure if the interface is sound; I would appreciate a second set of eyes on it.
Both the interface and the implementations are deep in cast hell, so suggestions are highly appreciated, but I think most of the casting cannot be avoided.
James Gallicchio (Jun 26 2023 at 08:21):
(The TypeErased is essentially equivalent to the Any and Σ i, C i types above, but has no runtime overhead because it is implemented with unsafeCasts.)
Joachim Breitner (Oct 22 2023 at 14:33):
Small update: I finished support for memoizing functions recursing on Nat also when they have a dependent result type, and the @[memo] attribute now also works when the function is elaborated to use Nat.brecOn (not just WellFounded.fix).
Here is an example from Demo.lean that calculates Pascal’s triangle, and gets much faster with @[memo]:
@[memo]
def pascal : (i : Nat) → SArray Nat i
  | 0 => SArray.empty
  | n + 1 => SArray.ofFn (n + 1) fun i =>
    pad_left 1 (pascal n) i + pad_right (pascal n) 1 i
(It would also get much faster if I had used a let there, but that’s not the point :-)).
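For comparison, a sketch of the let variant alluded to above (assuming the same Demo.lean helpers SArray.ofFn, pad_left and pad_right): binding pascal n once makes the plain recursion linear instead of exponential, independently of memoization.
def pascal' : (i : Nat) → SArray Nat i
  | 0 => SArray.empty
  | n + 1 =>
    let p := pascal' n  -- computed once, shared by all entries of the new row
    SArray.ofFn (n + 1) fun i => pad_left 1 p i + pad_right p 1 i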
I am using the variant of dependent arrays using the Any type that Kyle suggested, and am living with the universe bump (DArray); it would be interesting to know if this can be avoided (with the same efficient representation).
On top of that, I also have dependent arrays with the length as an index (DSArray), which is incidentally isomorphic to Nat.below. I need this conversion when replacing Nat.brecOn with my memo combinator, although I fear that this is rather expensive. Maybe I should just use the Nat.below data structure directly in these cases.
I doubt that this is very useful, but it is a useful learning exercise for me…
Last updated: Dec 20 2023 at 11:08 UTC