Zulip Chat Archive

Stream: lean4

Topic: Help generalizing an efficient loop combinator


Sebastian Graf (Oct 28 2025 at 09:43):

Here's a function for implementing for ... in ... do notation, an alternative to ForIn.forIn:

@[specialize 3 4] def List.foldrNonTR (f : α → β → β) (init : β) : (l : List α) → β
  | []     => init
  | a :: l => f a (foldrNonTR f init l)

@[always_inline, inline]
def List.forBreak_ {α : Type u} {m : Type w → Type x} [Monad m] (xs : List α) (s : σ) (body : α → OptionT (StateT σ (ExceptT ρ m)) PUnit) (kreturn : ρ → m γ) (kbreak : σ → m γ) : m γ :=
  List.foldrNonTR
    (fun a acc s => do
      let e ← body a s
      match e with
      | .error r => kreturn r
      | .ok (.some _, s) => acc s
      | .ok (none, s) => kbreak s)
    kbreak
    xs
    s

set_option trace.Compiler.saveBase true in
/--
trace: [Compiler.saveBase] size: 25
    def List.foldrNonTR._at_._example.spec_0 x.1 _y.2 : Nat :=
      fun _f.3 x : Nat :=
        let _x.4 := 13;
        let x := Nat.add x _x.4;
        let x := Nat.add x _x.4;
        let x := Nat.add x _x.4;
        let x := Nat.add x _x.4;
        let x := Nat.add x _x.4;
        let x := Nat.add x _x.4;
        let x := Nat.add x _x.4;
        return x;
      cases x.1 : Nat
      | List.nil =>
        let _x.5 := _f.3 _y.2;
        return _x.5
      | List.cons head.6 tail.7 =>
        let _x.8 := 0;
        let _x.9 := instDecidableEqNat _y.2 _x.8;
        cases _x.9 : Nat
        | Decidable.isFalse x.10 =>
          let _x.11 := 10;
          let _x.12 := Nat.decLt _x.11 _y.2;
          cases _x.12 : Nat
          | Decidable.isFalse x.13 =>
            let _x.14 := Nat.add _y.2 head.6;
            let _x.15 := List.foldrNonTR._at_._example.spec_0 tail.7 _x.14;
            return _x.15
          | Decidable.isTrue x.16 =>
            let _x.17 := _f.3 _y.2;
            return _x.17
        | Decidable.isTrue x.18 =>
          return _y.2
[Compiler.saveBase] size: 9
    def _example : Nat :=
      let x := 42;
      let _x.1 := 1;
      let _x.2 := 2;
      let _x.3 := 3;
      let _x.4 := @List.nil _;
      let _x.5 := @List.cons _ _x.3 _x.4;
      let _x.6 := @List.cons _ _x.2 _x.5;
      let _x.7 := @List.cons _ _x.1 _x.6;
      let _x.8 := List.foldrNonTR._at_._example.spec_0 _x.7 x;
      return _x.8
-/
#guard_msgs in
example := Id.run do
  let x := 42;
  List.forBreak_ (m:=Id) (ρ := Nat) [1, 2, 3] x (fun i => do
      let x ← get
      if x = 0 then throw (m := ExceptT Nat Id) x   -- return
      else if x > 10 then failure (f := OptionT _)  -- break
      else set (x + i) >>= fun _ => pure ())        -- continue
      pure fun x => do
  let x := x + 13;
  let x := x + 13;
  let x := x + 13;
  let x := x + 13;
  let x := x + 13;
  let x := x + 13;
  let x := x + 13;
  return x

As you can see, it generates optimal code for the given loop (well, modulo detecting a join point, https://github.com/leanprover/lean4/issues/10995), better code than the ForIn.forIn application that for ... in ... do elaborates to today.

However, I would like to generalize forBreak_ further. I would like to get rid of the explicit StateT σ (ExceptT ρ m) in favour of just n and a constraint MonadControlT m n (or something comparable), while retaining the same specialized and optimized code. I made some progress by realizing that one can define

  k : stM m (StateT σ (EarlyReturnT ρ m)) PUnit → m γ
    | .error r => kreturn r
    | .ok (⟨⟩, s) => kbreak s

and so users could pass a single k : stM m n PUnit → m γ instead of two separate continuations (although deciding that it's safe to inline/peel the match from such a k already overwhelms the optimizer).
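To make the single-continuation idea concrete, here is a minimal sketch (not from the thread) of a hypothetical helper `mergeExits` that folds `kreturn` and `kbreak` into one function on the final state. It assumes `EarlyReturnT ρ m` behaves like `ExceptT ρ m` and uses `Unit` instead of `PUnit` to avoid universe bookkeeping. With the standard `MonadControl` instances for `StateT` and `ExceptT`, `stM m (StateT σ (ExceptT ρ m)) Unit` should unfold to `Except ρ (Unit × σ)`, which is the shape the match above inspects.

```lean
/-- Hypothetical helper (not part of the thread): merge the two exit
continuations of `forBreak_` into a single continuation on the final state.
Assumes `EarlyReturnT ρ m` is `ExceptT ρ m`, so the final state of
`StateT σ (ExceptT ρ m)` is `Except ρ (Unit × σ)`. -/
def mergeExits {m : Type → Type} {ρ σ γ : Type}
    (kreturn : ρ → m γ) (kbreak : σ → m γ) : Except ρ (Unit × σ) → m γ
  | .error r => kreturn r     -- the body executed an early `return r`
  | .ok ((), s) => kbreak s   -- loop exhausted or `break`, final state `s`

-- With `m := Id`: an early return passes `r` through, a break adds one to `s`.
#guard Id.run (mergeExits (m := Id) (fun (r : Nat) => pure r) (fun (s : Nat) => pure (s + 1)) (.ok ((), 41))) == 42
#guard Id.run (mergeExits (m := Id) (fun (r : Nat) => pure r) (fun (s : Nat) => pure (s + 1)) (.error 7)) == 7
```

This only covers the type-level bookkeeping; whether the compiler can see through such a `k` and peel the match is the optimizer limitation mentioned above.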
I have no idea what kind of type class would be necessary to generalize the first argument to List.foldrNonTR. (I'm doubtful that MonadControlT m n is sufficient.)

So, in short: can you help me generalize forBreak_ so that it works for any body : α -> OptionT n PUnit, where n is related to m via some type class, and still generates optimal code for the given example?

Also, I like that forBreak_ can be defined in terms of foldr; it means that I can generalize (along a different dimension) from List to any Foldable type class, and users just need to instantiate this type class with their notion of foldr to make use of for ... in ... do syntax. This type class has a very simple theory, in contrast to something like ForIn.forIn. So it would be great if the solution would still be formulated in terms of List.foldrNonTR (List.foldr questionably csimps to Array.foldr in order to be tail-recursive at all costs).
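As an illustration of the "users only supply foldr" point, here is a minimal sketch with a hypothetical, stripped-down class `FoldrLike` (only `foldr`, no membership or `length` fields, unlike the `Foldable` class proposed later in the thread) and a hypothetical user type `BinTree`. The user writes only the in-order `foldr`; generic code such as `toList` then works unchanged.

```lean
/-- Hypothetical, stripped-down variant of a `Foldable` class: only `foldr`. -/
class FoldrLike (ρ : Type) (α : outParam Type) where
  foldr {β : Type} : (α → β → β) → β → ρ → β

/-- Hypothetical user-defined container: binary trees with elements in the nodes. -/
inductive BinTree (α : Type) where
  | leaf
  | node (l : BinTree α) (a : α) (r : BinTree α)

/-- In-order right fold: the only code the user has to write. -/
def BinTree.foldr {α β : Type} (f : α → β → β) (init : β) : BinTree α → β
  | .leaf => init
  | .node l a r => BinTree.foldr f (f a (BinTree.foldr f init r)) l

instance {α : Type} : FoldrLike (BinTree α) α where
  foldr := BinTree.foldr

/-- Generic code written once against `foldr`; it now works for `BinTree`. -/
def FoldrLike.toList {ρ α : Type} [FoldrLike ρ α] (xs : ρ) : List α :=
  FoldrLike.foldr (fun a acc => a :: acc) [] xs

#guard FoldrLike.toList (BinTree.node (.node .leaf 1 .leaf) 2 (.node .leaf 3 .leaf) : BinTree Nat) == [1, 2, 3]
```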

Jovan Gerbscheid (Oct 28 2025 at 12:08):

I didn't realize that List.foldrNonTR could actually become tail recursive depending on what it's specialized to. That's neat!

Sebastian Graf (Oct 28 2025 at 12:10):

Yeah, I think rewriting List.foldr to List.foldrTR/Array.foldr by default and unconditionally is actually bad design... A non-tail-recursive foldr can specialize to a left fold if it is properly eta-expanded afterwards. However, this currently needs the explicit @[specialize 3 4] pragma in order to specialize for init:

@[specialize]
def Foldable.foldl' [Foldable ρ α] (f : β → α → β) (init : β) (xs : ρ) : β :=
  foldr (fun a k b => k (f b a)) id xs init

set_option trace.Compiler.saveBase true in
example := Foldable.foldl' (fun a b => a * b) 1 [1, 2, 3]

(See also the definition in Haskell.)
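The foldl-via-foldr trick can be checked on plain lists. This is a sketch under assumptions: `FoldlSketch` and `foldlViaFoldr` are hypothetical names, and `foldrNonTR` is a local copy of the thread's `List.foldrNonTR` without the `@[specialize 3 4]` attribute, so it demonstrates the semantics (with a small correctness proof against `List.foldl`), not the specialized code.

```lean
namespace FoldlSketch

/-- Local copy of the thread's `List.foldrNonTR` (specialization attribute omitted). -/
def foldrNonTR {α β : Type} (f : α → β → β) (init : β) : List α → β
  | [] => init
  | a :: l => f a (foldrNonTR f init l)

/-- Left fold via a right fold: `foldrNonTR` builds a continuation `β → β`
that threads the accumulator from left to right; it is then applied to `init`. -/
def foldlViaFoldr {α β : Type} (f : β → α → β) (init : β) (xs : List α) : β :=
  foldrNonTR (β := β → β) (fun (a : α) (k : β → β) (b : β) => k (f b a)) id xs init

/-- The construction agrees with `List.foldl`. -/
theorem foldlViaFoldr_eq_foldl {α β : Type} (f : β → α → β) (init : β) (xs : List α) :
    foldlViaFoldr f init xs = List.foldl f init xs := by
  unfold foldlViaFoldr
  induction xs generalizing init with
  | nil => rfl
  | cons a l ih => simp [foldrNonTR, List.foldl, ih]

end FoldlSketch

#guard FoldlSketch.foldlViaFoldr (fun (acc x : Nat) => acc * 10 + x) 0 [1, 2, 3] == 123
```

The result of the right fold is a function awaiting the initial accumulator; per the discussion, this is the shape a specializer can turn into a loop once it eta-expands that function.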

I'm planning on using the following setup, where Foldable.foldrTR defines a reusable notion of "tail-recursive foldr":

class Foldable (ρ : Type u) (α : outParam (Type v)) extends Membership α ρ where
  foldr {β : Type w} : (α → β → β) → β → ρ → β
  foldrMem {β : Type w} : (xs : ρ) → ((a : α) → a ∈ xs → β → β) → β → β
  foldl {β : Type w} : (β → α → β) → β → ρ → β
  foldlMem {β : Type w} : (xs : ρ) → (β → (a : α) → a ∈ xs → β) → β → β
  length : ρ → Nat

@[specialize 3 4] def List.foldrNonTR (f : α → β → β) (init : β) : (l : List α) → β
  | []     => init
  | a :: l => f a (foldrNonTR f init l)

instance : Foldable (List α) α where
  foldr := List.foldrNonTR
  foldl := List.foldl
  foldlMem xs f z := List.foldl (fun b ⟨a, h⟩ => f b a h) z xs.attach
  foldrMem xs f z := List.foldr (fun ⟨a, h⟩ b => f a h b) z xs.attach
  length := List.length

instance : Foldable (Array α) α where
  foldr := Array.foldr
  foldl := Array.foldl
  foldlMem xs f z := Array.foldl (fun b ⟨a, h⟩ => f b a h) z xs.attach
  foldrMem xs f z := Array.foldr (fun ⟨a, h⟩ b => f a h b) z xs.attach
  length := Array.size

@[specialize]
def Foldable.toList [Foldable ρ α] : ρ → List α :=
  foldr (fun a acc => a :: acc) []

class LawfulFoldable (ρ : Type u) (α : outParam (Type v)) [Foldable ρ α] : Prop where
  -- Unsure whether the following law follows by parametricity.
  foldr_eq_foldr_toList (xs : ρ) (k : α → β → β) (z : β) :
    Foldable.foldr k z xs = List.foldr k z (Foldable.toList xs)

@[specialize]
def Foldable.toArray [Foldable ρ α] (xs : ρ) : Array α :=
  foldr (fun a k arr => k (arr.push a)) id xs (Array.mkEmpty (Foldable.length xs))

@[specialize 4 5]
def Foldable.foldrTR [Foldable ρ α] (f : α → β → β) (init : β) (xs : ρ) : β :=
  xs |> Foldable.toArray |>.foldr f init

-- def warmup := Foldable.toArray [3]
-- set_option trace.Compiler.saveBase true in
-- def blah := Foldable.foldrTR (fun a b => a * b) 0 [1, 2, 3]
-- set_option trace.Compiler.saveBase true in
-- example := List.foldr (fun a b => a * b) 0 [1, 2, 3]

@[inline]
def Foldable.forBreak_ {ρ : Type u} {α : Type v} [Foldable ρ α] {m : Type w → Type x} [Monad m] {σ ε γ} (xs : ρ) (s : σ) (body : α → BreakT (StateT σ (EarlyReturnT ε m)) PUnit) (kreturn : ε → m γ) (kbreak : σ → m γ) : m γ :=
  Foldable.foldr
    (fun a acc s => do
      let e ← body a s
      match e with
      | .error r => kreturn r
      | .ok (.some _, s) => acc s
      | .ok (none, s) => kbreak s)
    kbreak
    xs
    s
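A List-only sketch of the `Foldable.toArray`/`Foldable.foldrTR` pair above, with hypothetical names in a `ToArraySketch` namespace. It starts from `#[]` instead of preallocating with `Array.mkEmpty (Foldable.length xs)` and omits the specialization attributes; whether the continuation-passing `toArray` avoids deep recursion depends on the compiler specializing it, as discussed in the thread.

```lean
namespace ToArraySketch

/-- Local copy of the thread's `List.foldrNonTR` (specialization attribute omitted). -/
def foldrNonTR {α β : Type} (f : α → β → β) (init : β) : List α → β
  | [] => init
  | a :: l => f a (foldrNonTR f init l)

/-- Like `Foldable.toArray`: the right fold builds a continuation that pushes
the elements onto the array from left to right. -/
def toArray {α : Type} (xs : List α) : Array α :=
  foldrNonTR (β := Array α → Array α)
    (fun (a : α) (k : Array α → Array α) (arr : Array α) => k (arr.push a)) id xs #[]

/-- Like `Foldable.foldrTR`: materialize into an array, then run `Array.foldr`. -/
def foldrTR {α β : Type} (f : α → β → β) (init : β) (xs : List α) : β :=
  (toArray xs).foldr f init

end ToArraySketch

#guard ToArraySketch.toArray ([1, 2, 3] : List Nat) == #[1, 2, 3]
-- 10 - (4 - (1 - 0)) = 7, same as `List.foldr`
#guard ToArraySketch.foldrTR (fun (a acc : Nat) => a - acc) 0 [10, 4, 1] == 7
```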

Jovan Gerbscheid (Oct 28 2025 at 12:49):

I tried to see if I could create a foldl by making it into a foldr and then back into a foldl again, but sadly the compiler seems to not specialize lambdas well enough.

@[inline]
def Foldable.foldl' [Foldable ρ α] (f : β → α → β) (init : β) (xs : ρ) : β :=
  foldl (fun k a b => k (fun c => b (f c a))) id xs id init

set_option trace.Compiler.saveBase true in
example := Foldable.foldl' (fun a b => a * b) 1 [1, 2, 3]

Jovan Gerbscheid (Oct 28 2025 at 13:02):

What is your motivation for generalizing forBreak_ to any body : α -> OptionT n PUnit? I'd think that it would be sufficient to implement this for just one such monad n?

Sebastian Graf (Oct 28 2025 at 13:26):

> sadly the compiler seems to not specialize lambdas well enough.

Indeed, that's lean4#10924. I have a patch that gets sort of halfway there (lean4#10987), but I deprioritized it after I was able to define forBreak_ as above.

> What is your motivation for generalizing forBreak_ to any body : α -> OptionT n PUnit? I'd think that it would be sufficient to implement this for just one such monad n?

One of the stretch goals for my work on the do elaborator this quarter is to implement return/continue/break-at-label. That would require instantiating n at an arbitrary nesting of alternating StateT σ/ExceptT ε layers. I suppose I'll revisit this once I've got the rest working.

Jovan Gerbscheid (Oct 28 2025 at 13:41):

Would this allow breaking out of nested for loops? And continuing an outer for loop?

Sebastian Graf (Oct 28 2025 at 13:42):

Exactly
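One way to picture the break-at-label encoding: a plain Lean sketch (not the do elaborator's actual translation) in which the outer loop's label is an `ExceptT` layer, so throwing into it from the inner loop exits both loops. `findPair` is a hypothetical example.

```lean
/-- Hypothetical example: `break` out of both loops at once. The outer "label"
is an `ExceptT` layer; throwing in it from the inner loop aborts the whole nest. -/
def findPair (xs ys : List Nat) (target : Nat) : Option (Nat × Nat) :=
  let loops : ExceptT (Nat × Nat) Id Unit := do
    for x in xs do
      for y in ys do
        if x + y = target then
          throw (x, y)  -- behaves like `break outer`, carrying a result
  match Id.run loops.run with
  | .error p => some p  -- left early through the outer label
  | .ok () => none      -- both loops ran to completion

#guard findPair [1, 2, 3] [10, 20] 22 == some (2, 20)
#guard findPair [1] [1] 5 == none
```

Continuing an outer loop would analogously wrap each outer iteration's body in its own `ExceptT` layer and catch it there; with several labels these layers nest, interleaved with `StateT` for mutable variables.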

Robin Arnez (Nov 02 2025 at 09:38):

Right, I also had a similar idea, which I called "cfoldl" (continuation fold left); it is like "foldr" but evaluated left-to-right:

-- @[implemented_by Array.cfoldlUnsafe]
def Array.cfoldl {α : Type u} {β : Sort v} (f : α → β → β) (init : β) (as : Array α)
    (start : Nat := 0) (stop : Nat := as.size) : β :=
  let fold (stop : Nat) (h : stop ≤ as.size) :=
    let rec loop (i : Nat) (j : Nat) (b : β) : β :=
      if hlt : j < stop then
        match i with
        | 0    => b
        | i'+1 =>
          have : j < as.size := Nat.lt_of_lt_of_le hlt h
          f as[j] (loop i' (j+1) b)
      else
        b
    loop (stop - start) start init
  if h : stop ≤ as.size then
    fold stop h
  else
    fold as.size (Nat.le_refl _)

@[inline]
def Array.foldl {α : Type u} {β : Sort v} (f : β → α → β) (init : β) (as : Array α)
    (start : Nat := 0) (stop : Nat := as.size) : β :=
  as.cfoldl (fun a cont x => cont (f x a)) (fun x => x) start stop init

@[inline]
def Array.foldlM {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m]
    (f : β → α → m β) (init : β) (as : Array α)
    (start : Nat := 0) (stop : Nat := as.size) : m β :=
  as.cfoldl (fun a cont x => f x a >>= cont) pure start stop init
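A minimal List version of the same idea, showing how the continuation argument gives early exit: returning the accumulator instead of calling `cont` means the remaining continuation is never invoked. `CfoldlSketch` and `sumUntilZero` are hypothetical names; without specialization, strict evaluation still builds the closure chain for the whole list first, so the operational benefit relies on the compiler.

```lean
namespace CfoldlSketch

/-- Same definition as `List.foldrNonTR`, under Robin's name (attribute omitted). -/
def cfoldl {α β : Type} (f : α → β → β) (init : β) : List α → β
  | [] => init
  | a :: l => f a (cfoldl f init l)

/-- Sum the elements before the first `0`. Not calling `cont` ends the loop. -/
def sumUntilZero (xs : List Nat) : Nat :=
  cfoldl (β := Nat → Nat)
    (fun (a : Nat) (cont : Nat → Nat) (acc : Nat) =>
      if a = 0 then acc        -- "break": the remaining continuation is never called
      else cont (acc + a))     -- "continue": pass the new accumulator onward
    id xs 0

end CfoldlSketch

#guard CfoldlSketch.sumUntilZero [1, 2, 0, 5] == 3
#guard CfoldlSketch.sumUntilZero [1, 2, 3] == 6
```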

Robin Arnez (Nov 02 2025 at 09:39):

Maybe that name could be confusing though lol

Robin Arnez (Nov 02 2025 at 10:02):

Otherwise I don't think you need anything else but the cfoldl / foldrNonTR combinator:

@[specialize 3 4] def List.cfoldl (f : α → β → β) (init : β) : (l : List α) → β
  | []     => init
  | a :: l => f a (cfoldl f init l)

example := Id.run do
  let x := 42;
  let brk (x : Nat) : Id Nat := do
    let x := x + 13
    let x := x + 13
    let x := x + 13
    let x := x + 13
    let x := x + 13
    let x := x + 13
    let x := x + 13
    return x
  [1, 2, 3].cfoldl (fun i cont x => do
      if x = 0 then return x   -- return
      else if x > 10 then brk x  -- break
      else cont (x + i))        -- continue
    (init := brk) (x := x)

Sebastian Graf (Nov 03 2025 at 14:45):

One reason that a single, specializable forBreak_ implementation would be great is that you could write a single specification lemma such as forIn_list for it and then instantiate it at m := StateT letMuts (ExceptT earlyReturn m) and others as needed. So I guess the questions are (1) what type class on m we need to enable that, and (2) whether the resulting definition still specializes well.
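As a much smaller instance of the "prove once, instantiate many times" idea, here is a sketch of a lemma proved once for the fold combinator and then reused at the continuation type `σ → m γ` that `forBreak_` folds into. `SpecSketch` is a hypothetical namespace with a local copy of `List.foldrNonTR`; a real `forIn_list`-style lemma for `forBreak_` would be considerably more involved.

```lean
namespace SpecSketch

/-- Local copy of the thread's `List.foldrNonTR` (specialization attribute omitted). -/
def foldrNonTR {α β : Type} (f : α → β → β) (init : β) : List α → β
  | [] => init
  | a :: l => f a (foldrNonTR f init l)

/-- Proved once: the specializable combinator agrees with `List.foldr`. -/
theorem foldrNonTR_eq_foldr {α β : Type} (f : α → β → β) (init : β) (xs : List α) :
    foldrNonTR f init xs = List.foldr f init xs := by
  induction xs with
  | nil => rfl
  | cons a l ih => simp [foldrNonTR, List.foldr, ih]

/-- Reused at the continuation type `σ → m γ` for an arbitrary monad `m`. -/
example {m : Type → Type} {α σ γ : Type} (f : α → (σ → m γ) → σ → m γ)
    (kbreak : σ → m γ) (xs : List α) :
    foldrNonTR f kbreak xs = List.foldr f kbreak xs :=
  foldrNonTR_eq_foldr f kbreak xs

end SpecSketch
```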


Last updated: Dec 20 2025 at 21:32 UTC