Zulip Chat Archive
Stream: lean4
Topic: CPS and tail recursion
Horațiu Cheval (Dec 27 2022 at 16:47):
I tried to write a tail recursive implementation of the plain depth
function below using continuation passing style.
inductive Exp where
  | num : Nat → Exp
  | plus : Exp → Exp → Exp

def Exp.depth (e : Exp) : Nat :=
  match e with
  | .num _ => 0
  | .plus e₁ e₂ => (max e₁.depth e₂.depth) + 1

def Exp.depthCPS (e : Exp) : Nat :=
  let rec depthCore (e : Exp) (k : Nat → Nat) : Nat :=
    match e with
    | .num _ => k 0
    | .plus e₁ e₂ => depthCore e₁
      <| fun d₁ => depthCore e₂
      <| fun d₂ => k (Nat.max d₁ d₂) + 1
  depthCore e id
-- 0 + (1 + (2 + ( + ... + 1000000)...)
def largeTest : Exp := List.range 1000000 |>.mapTR (.num ·) |>.foldl Exp.plus (.num 0)

def main : IO Unit := do
  IO.println largeTest.depthCPS
However, testing it on a large input still produces a stack overflow. Is it not possible in Lean to use continuation passing style in order to prevent stack overflows? I should mention that I only learnt about this a few days ago, so I'm not sure that this is the right way to do it, but to convince myself I wrote an analogous implementation in OCaml and it successfully computed depthCPS largeTest.
Mario Carneiro (Dec 27 2022 at 16:51):
I don't think using CPS alone will prevent stack overflows. lean still has to call the function stack you have built and that will result in a call stack as large as the nesting depth of the closure
Mario Carneiro (Dec 27 2022 at 16:52):
especially this continuation here -> fun d₂ => k (Nat.max d₁ d₂) + 1 - that's not even tail recursive so it definitely has to push a stack frame to call k
Mario Carneiro (Dec 27 2022 at 16:55):
in fact, looking at it some more I think that was just a typo
Horațiu Cheval (Dec 27 2022 at 16:55):
What do you mean by the continuation being tail recursive? I thought the question of tail recursion was about depthCore here, which is tail recursive, right?
Mario Carneiro (Dec 27 2022 at 16:56):
and you meant fun d₂ => k (Nat.max d₁ d₂ + 1)
Horațiu Cheval (Dec 27 2022 at 16:56):
Yes, that's what I meant
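The difference between the two spellings comes down to precedence: function application binds tighter than +, so k (Nat.max d₁ d₂) + 1 is read as (k (Nat.max d₁ d₂)) + 1. A minimal sketch, using a hypothetical stand-in tenTimes for k (it is not part of the code in this thread):

```lean
-- Hypothetical stand-in for the continuation `k`.
def tenTimes (n : Nat) : Nat := n * 10

-- Parsed as `(tenTimes (Nat.max 1 2)) + 1`: the `+ 1` runs after the call
-- returns, so the call is not in tail position.
#eval tenTimes (Nat.max 1 2) + 1   -- 21

-- Here the call is the last thing that happens: `tenTimes` receives the whole sum.
#eval tenTimes (Nat.max 1 2 + 1)   -- 30
```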
Horațiu Cheval (Dec 27 2022 at 16:57):
Though I still get a stack overflow
Mario Carneiro (Dec 27 2022 at 17:04):
That lambda gets called at some point, once you have decomposed the whole expression, and you are relying on it to tail-call k, replacing its stack frame with k's, and that's not what happens because lean doesn't do this kind of TCO
Mario Carneiro (Dec 27 2022 at 17:05):
in the example you end up building a closure which has a reference to another closure and so on a million deep, and since each closure calls the next one before it exits, you end up with a million stack frames
Mario Carneiro (Dec 27 2022 at 17:06):
currently lean can only do TCO to a recursive invocation of the same function or another function in the same mutual block, not to variables like k
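For contrast, a minimal sketch of the kind of call that, per the message above, Lean can optimize: a function whose last action is a call to itself. sumTo is a hypothetical helper, not something from this thread, and the claim that it runs in constant stack space rests on the statement above rather than on a measurement.

```lean
-- Hypothetical example: sums 1 + 2 + ... + n using an accumulator.
-- The recursive call is the last thing the function does and it targets
-- `sumTo` itself, which is the case described above as eligible for TCO.
def sumTo (n acc : Nat) : Nat :=
  match n with
  | 0 => acc
  | n + 1 => sumTo n (acc + n + 1)

#eval sumTo 10 0   -- 55
```

A call through a variable such as k does not fit this pattern, because the target is only known at run time.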
Horațiu Cheval (Dec 27 2022 at 17:08):
I see now, thanks!
Mario Carneiro (Dec 27 2022 at 17:16):
Here's a version that works. This is the same CPS style translation I use to do programming language semantics
inductive Exp where
  | num : Nat → Exp
  | plus : Exp → Exp → Exp

inductive Cont where
  | ret : Cont
  | plus₁ : Exp → Cont → Cont
  | plus₂ : Nat → Cont → Cont

def Exp.depth (e : Exp) : Nat :=
  match e with
  | .num _ => 0
  | .plus e₁ e₂ => max e₁.depth e₂.depth + 1

partial def Exp.depthCPS (e : Exp) : Nat :=
  let rec
    start : Exp → Cont → Nat
      | .num _, k => ret k 0
      | .plus e₁ e₂, k => start e₁ (.plus₁ e₂ k),
    ret : Cont → Nat → Nat
      | .ret, d => d
      | .plus₁ e₂ k, d₁ => start e₂ (.plus₂ d₁ k)
      | .plus₂ d₁ k, d₂ => ret k (max d₁ d₂ + 1)
  start e .ret
-- 0 + (1 + (2 + ( + ... + 1000000)...)
def largeTest : Exp := List.range 1000000 |>.mapTR (.num ·) |>.foldl Exp.plus (.num 0)

def main : IO Unit := do
  IO.println largeTest.depthCPS
Mario Carneiro (Dec 27 2022 at 17:16):
I think it might be possible to do this with closures too but I'm not sure
Mario Carneiro (Dec 27 2022 at 17:17):
(The partial def is only because I'm lazy, it is possible to prove this terminates)
Mario Carneiro (Dec 27 2022 at 17:24):
The Cont type corresponds very closely to the closures that are used in your example:
.ret is id,
.plus₁ e₂ k is fun d₁ => depthCore e₂ fun d₂ => k (Nat.max d₁ d₂ + 1), and
.plus₂ d₁ k is fun d₂ => k (Nat.max d₁ d₂ + 1)
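One way to make this correspondence concrete is to write down the function each Cont value stands for. The sketch below is an illustration only: Cont.apply is a hypothetical name, and it uses the plain recursive Exp.depth for the e₂ case, so it is meant as a specification of what each constructor means, not as a stack-safe replacement for ret. The Exp and Cont declarations are repeated so the block stands alone; drop them if pasting next to the code above.

```lean
inductive Exp where
  | num : Nat → Exp
  | plus : Exp → Exp → Exp

inductive Cont where
  | ret : Cont
  | plus₁ : Exp → Cont → Cont
  | plus₂ : Nat → Cont → Cont

def Exp.depth : Exp → Nat
  | .num _ => 0
  | .plus e₁ e₂ => max e₁.depth e₂.depth + 1

-- Hypothetical helper: the `Nat → Nat` function that each `Cont` represents.
def Cont.apply : Cont → Nat → Nat
  | .ret, d => d                                        -- `id`
  | .plus₁ e₂ k, d₁ => k.apply (max d₁ e₂.depth + 1)    -- still has to compute the depth of e₂
  | .plus₂ d₁ k, d₂ => k.apply (max d₁ d₂ + 1)          -- both depths are known

#eval (Cont.plus₂ 3 .ret).apply 1           -- 4
#eval (Cont.plus₁ (.num 7) .ret).apply 2    -- 3
```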
Horațiu Cheval (Dec 27 2022 at 17:50):
Thank you Mario, this is a very nice solution! I'll try to adapt it to my non-minimized function
Sebastian Ullrich (Dec 27 2022 at 19:11):
TCO through dynamic calls is dependent on https://github.com/leanprover/lean4/pull/1805, plus probably some special code generation on the caller's side
Last updated: Dec 20 2023 at 11:08 UTC