Zulip Chat Archive

Stream: lean4

Topic: CPS and tail recursion


Horațiu Cheval (Dec 27 2022 at 16:47):

I tried to write a tail recursive implementation of the plain depth function below using continuation passing style.

inductive Exp where
| num : Nat → Exp
| plus : Exp → Exp → Exp

def Exp.depth (e : Exp) : Nat :=
  match e with
  | .num _ => 0
  | .plus e₁ e₂ => (max e₁.depth e₂.depth) + 1

def Exp.depthCPS (e : Exp) : Nat :=
  let rec depthCore (e : Exp) (k : Nat → Nat) : Nat :=
    match e with
    | .num _ => k 0
    | .plus e₁ e₂ => depthCore e₁
      <| fun d₁ => depthCore e₂
      <| fun d₂ => k (Nat.max d₁ d₂) + 1
  depthCore e id

-- ((0 + 0) + 1) + ... + 999999, i.e. a million nested `plus` nodes on the left
def largeTest : Exp := List.range 1000000 |>.mapTR (.num ·) |>.foldl Exp.plus (.num 0)

def main : IO Unit := do
  IO.println largeTest.depthCPS

However, testing it on a large input still produces a stack overflow. Is it not possible in Lean to use continuation passing style to prevent stack overflows? I should mention that I only learnt about this technique a few days ago, so I'm not sure this is the right way to do it, but to convince myself I wrote an analogous implementation in OCaml, and it successfully computed depthCPS largeTest.

Mario Carneiro (Dec 27 2022 at 16:51):

I don't think using CPS alone will prevent stack overflows. Lean still has to call the chain of closures you have built, and that will result in a call stack as deep as the nesting depth of the closures

Mario Carneiro (Dec 27 2022 at 16:52):

especially this continuation here: fun d₂ => k (Nat.max d₁ d₂) + 1. The call to k isn't even in tail position (the + 1 happens after k returns), so it definitely has to push a stack frame to call k

Mario Carneiro (Dec 27 2022 at 16:55):

in fact, looking at it some more I think that was just a typo

Horațiu Cheval (Dec 27 2022 at 16:55):

What do you mean by the continuation being tail recursive? I thought the question of tail recursion was about depthCore here, which is tail recursive, right?

Mario Carneiro (Dec 27 2022 at 16:56):

and you meant fun d₂ => k (Nat.max d₁ d₂ + 1)

Horațiu Cheval (Dec 27 2022 at 16:56):

Yes, that's what I meant

Horațiu Cheval (Dec 27 2022 at 16:57):

Though I still get a stack overflow
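For reference, a sketch of the closure-based version with the continuation corrected as suggested: only the last line of depthCore changes, so k is now called in tail position with Nat.max d₁ d₂ + 1. As reported just above, this alone does not stop the overflow on largeTest; the small #eval is only a sanity check on a tiny input.

```lean
inductive Exp where
  | num : Nat → Exp
  | plus : Exp → Exp → Exp

def Exp.depthCPS (e : Exp) : Nat :=
  let rec depthCore (e : Exp) (k : Nat → Nat) : Nat :=
    match e with
    | .num _ => k 0
    | .plus e₁ e₂ => depthCore e₁
      <| fun d₁ => depthCore e₂
      <| fun d₂ => k (Nat.max d₁ d₂ + 1)  -- `+ 1` moved inside the call to `k`
  depthCore e id

-- Depth of 1 + (2 + 3) is 2.
#eval (Exp.plus (.num 1) (.plus (.num 2) (.num 3))).depthCPS  -- 2
```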

Mario Carneiro (Dec 27 2022 at 17:04):

That lambda gets called at some point, once you have decomposed the whole expression, and you are relying on it to tail-call k, replacing its stack frame with k's. That's not what happens, because Lean doesn't do this kind of TCO (tail call optimization)

Mario Carneiro (Dec 27 2022 at 17:05):

in the example you end up building a closure which has a reference to another closure and so on a million deep, and since each closure calls the next one before it exits, you end up with a million stack frames
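A minimal sketch of the same effect, using a hypothetical chain function unrelated to Exp: each step wraps the previous continuation in a new closure, so chain n id is a closure that calls another closure, and so on n deep. Building the chain is a plain self tail call, but calling the finished closure needs one nested call per level because each call to k goes through a variable. Per the explanation above, calling it with n around a million would be expected to overflow the stack in compiled code; the #eval below only uses a tiny n.

```lean
-- Hypothetical example: `chain n k x = k (x + n)`, computed by passing
-- `x` through `n` nested closures, each of which calls the next one.
def chain : Nat → (Nat → Nat) → (Nat → Nat)
  | 0, k => k
  | n + 1, k => chain n (fun x => k (x + 1))

#eval chain 3 id 0  -- 3
```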

Mario Carneiro (Dec 27 2022 at 17:06):

currently Lean can only do TCO for a recursive call to the same function or to another function in the same mutual block, not for calls to variables like k
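By contrast, a direct tail call to the function being defined is the case that does get optimized. A minimal sketch with a hypothetical countUp function: the accumulator acc carries the result, so the recursive call is the last thing the function does and the compiled code can jump back to the start instead of pushing a new frame. Under that assumption, a compiled call such as countUp 1000000 0 should run in constant stack space.

```lean
-- Hypothetical example: the recursive call to `countUp` is in tail position,
-- so no work remains in the current frame after it returns.
def countUp : Nat → Nat → Nat
  | 0, acc => acc
  | n + 1, acc => countUp n (acc + 1)

#eval countUp 10 0  -- 10
```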

Horațiu Cheval (Dec 27 2022 at 17:08):

I see now, thanks!

Mario Carneiro (Dec 27 2022 at 17:16):

Here's a version that works. This is the same CPS style translation I use to do programming language semantics

inductive Exp where
| num : Nat → Exp
| plus : Exp → Exp → Exp

inductive Cont where
| ret : Cont
| plus₁ : Exp  Cont  Cont
| plus₂ : Nat  Cont  Cont

def Exp.depth (e : Exp) : Nat :=
  match e with
  | .num _ => 0
  | .plus e₁ e₂ => max e₁.depth e₂.depth + 1

partial def Exp.depthCPS (e : Exp) : Nat :=
  let rec
    start : Exp → Cont → Nat
    | .num _, k => ret k 0
    | .plus e₁ e₂, k => start e₁ (.plus₁ e₂ k),
    ret : Cont → Nat → Nat
    | .ret, d => d
    | .plus₁ e₂ k, d₁ => start e₂ (.plus₂ d₁ k)
    | .plus₂ d₁ k, d₂ => ret k (max d₁ d₂ + 1)
  start e .ret

-- ((0 + 0) + 1) + ... + 999999, i.e. a million nested `plus` nodes on the left
def largeTest : Exp := List.range 1000000 |>.mapTR (.num ·) |>.foldl Exp.plus (.num 0)

def main : IO Unit := do
  IO.println largeTest.depthCPS

Mario Carneiro (Dec 27 2022 at 17:16):

I think it might be possible to do this with closures too but I'm not sure

Mario Carneiro (Dec 27 2022 at 17:17):

(The partial def is only because I'm lazy; it is possible to prove this terminates)

Mario Carneiro (Dec 27 2022 at 17:24):

The Cont type corresponds very closely to the closures that are used in your example:

  • .ret is id,
  • .plus₁ e₂ k is fun d₁ => depthCore e₂ fun d₂ => k (Nat.max d₁ d₂ + 1), and
  • .plus₂ d₁ k is fun d₂ => k (Nat.max d₁ d₂ + 1)
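The correspondence can be made literal with a hypothetical Cont.toFun that sends each Cont value to the closure it replaces; replacing closures by a data type plus a function that interprets it is commonly called defunctionalization, and Mario's ret plays the role of that interpreting function. The sketch reuses the closure-based depthCore (typo fixed), written as a top-level function for brevity. Passing Cont.toFun .ret (which is id) just recovers the closure version, so this only documents the mapping and does not avoid the overflow.

```lean
inductive Exp where
  | num : Nat → Exp
  | plus : Exp → Exp → Exp

inductive Cont where
  | ret : Cont
  | plus₁ : Exp → Cont → Cont
  | plus₂ : Nat → Cont → Cont

-- Closure-based CPS from the first message, with the continuation fixed.
def depthCore (e : Exp) (k : Nat → Nat) : Nat :=
  match e with
  | .num _ => k 0
  | .plus e₁ e₂ => depthCore e₁ fun d₁ => depthCore e₂ fun d₂ => k (Nat.max d₁ d₂ + 1)

-- Hypothetical helper: the closure that each `Cont` constructor stands for.
def Cont.toFun : Cont → (Nat → Nat)
  | .ret => id
  | .plus₁ e₂ k => fun d₁ => depthCore e₂ fun d₂ => Cont.toFun k (Nat.max d₁ d₂ + 1)
  | .plus₂ d₁ k => fun d₂ => Cont.toFun k (Nat.max d₁ d₂ + 1)

-- Starting from `.ret` (i.e. `id`) gives the depth of 1 + (2 + 3).
#eval depthCore (.plus (.num 1) (.plus (.num 2) (.num 3))) (Cont.toFun .ret)  -- 2
```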

Horațiu Cheval (Dec 27 2022 at 17:50):

Thank you Mario, this is a very nice solution! I'll try to adapt it to my non-minimized function

Sebastian Ullrich (Dec 27 2022 at 19:11):

TCO through dynamic calls is dependent on https://github.com/leanprover/lean4/pull/1805, plus probably some special code generation on the caller's side


Last updated: Dec 20 2023 at 11:08 UTC