Zulip Chat Archive

Stream: lean4

Topic: Unification with let bindings


Tomas Skrivan (May 13 2022 at 15:48):

To do automatic differentiation, you need to differentiate smartly through let bindings to get efficient code. Turning off zeta reduction in the simplifier gets me partway there, but not fully.

The crucial simplification rule is:

theorem D_let (g : α → β) (f : α → β → γ)
  : D (λ x =>
         let tmp := g x
         f x tmp)
    = λ x dx =>
         let tmp  := g x
         let dtmp := D g x dx
         D (f x) tmp dtmp + D f x dx tmp

It seems that when the simplifier looks for the pattern, it removes the let binding on the left-hand side, i.e. it applies zeta during unification. It does not matter whether you write D (fun x => let tmp := g x; f x tmp) or D (fun x => f x (g x)) on the left-hand side. Is there a way to turn zeta off when simp searches for patterns to match on?
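To illustrate why the two left-hand sides look the same to the matcher, here is a small sketch (not from the thread; it uses Nat in place of the abstract types). The let form and the inlined form are definitionally equal by zeta reduction, so rfl closes the goal, and any matching based on definitional equality cannot tell them apart.

```lean
-- The two sides differ only by a `let`. Zeta reduction makes them
-- definitionally equal, so `rfl` is accepted.
example (g : Nat → Nat) (f : Nat → Nat → Nat) :
    (fun x => let tmp := g x; f x tmp) = (fun x => f x (g x)) := rfl
```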

For example, the expression

D (λ x =>
         let gx := (g x)
         f gx gx)

gets simplified with simp (config := {zeta := false}) to:

fun x dx =>
    let tmp := g x;
    let dtmp := D g x dx;
    D (f (g x)) tmp dtmp +
      (let tmp := g x;
        let dtmp := D g x dx;
        D f tmp dtmp)
        tmp

You can see that g x and D g x dx are computed twice.

I would like to get the following result:

fun x dx =>
  let gx := (g x)
  let dgx := D g x dx
  D (f gx) gx dgx + D f gx dgx gx

Here is an MWE. Look at the last example, where I apply D_let manually to the lhs and let simp apply it automatically to the rhs.

constant D (f : α → β) : α → α → β := λ x _ => f x

variable {α β β₁ β₂ γ δ : Type}
variable [Inhabited α] [Inhabited β] [Inhabited γ] [Inhabited δ] [Inhabited β₁] [Inhabited β₂]
variable [Add α] [Add β] [Add γ] [Add δ] [Add β₁] [Add β₂]

instance {α β : Type} [Add β] : Add (α → β) := ⟨λ f g x => f x + g x⟩
instance {α β : Type} [Add α] [Add β] : Add (α × β) := ⟨λ (a,b) (a',b') => (a+a', b+b')⟩

-- default plays the role of zero
@[simp]
axiom add_default {α} [Inhabited α] [Add α] (x : α) : x + default = x
@[simp]
axiom default_add {α} [Inhabited α] [Add α] (x : α) : default + x = x
-- derivative in zero direction is zero
@[simp]
axiom D_default (f : α → β) (x) : D f x default = default

@[simp]
theorem default_eval {α} [Inhabited β] (x : α) : (default : α → β) x = default :=
by
  unfold default; unfold instInhabitedForAll_1; unfold default; simp; done

-- Basic combinators
@[simp]
axiom D_I
  : D (λ x : α => x) = λ x dx => dx
@[simp]
axiom D_K (y : β)
  : D (λ x : α => y) = λ x dx => default
@[simp]
axiom D_S (f : α → β → γ) (g : α → β)
  : D (λ x => f x (g x)) = λ x dx => D (f x) (g x) (D g x dx) + D f x dx (g x)

@[simp]
axiom D_add_1 [Add α] : D (λ x y : α => x + y) = λ x dx y => dx

@[simp]
axiom D_add_2 [Add α] (x : α) : D (λ y : α => x + y) = λ y dy => dy

@[simp high]
theorem D_let (g : α → β) (f : α → β → γ)
  : D (λ x => f x (g x))
    = λ x dx =>
         let tmp  := g x
         let dtmp := D g x dx
         D (f x) tmp dtmp + D f x dx tmp
  :=
by
  funext x dx
  simp

example (g : α → α) (f : α → α → γ)
  : D (λ x =>
         let gx := (g x)
         f gx gx)
    =
    D (λ x =>
         let gx := (g x)
         f gx gx)
  :=
by
  conv =>
    lhs
    rw[D_let g (λ x y => f y y)]
    simp (config := {zeta := false}) -- this simplifies the derivative of the constant function `D (λ x => f y y)`
    trace_state

-- | fun x dx =>
--     let tmp := g x;
--     let dtmp := D g x dx;
--     let tmp_1 := tmp;
--     let dtmp_1 := dtmp;
--     D (f tmp) tmp_1 dtmp_1 + D (fun x => f x) tmp dtmp tmp_1

  conv =>
    rhs
    simp (config := {zeta := false})
    trace_state

-- | fun x dx =>
--     let tmp := g x;
--     let dtmp := D (fun x => g x) x dx;
--     D (f (g x)) tmp dtmp +
--       (let tmp := g x;
--         let dtmp := D (fun x => g x) x dx;
--         D f tmp dtmp)
--         tmp

Tomas Skrivan (May 13 2022 at 16:13):

Would my requirement mean modifying the isDefEq call at the beginning of tryTheoremCore?

private def tryTheoremCore (lhs : Expr) (xs : Array Expr) (bis : Array BinderInfo) (val : Expr) (type : Expr) (e : Expr) (thm : SimpTheorem) (numExtraArgs : Nat) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option Result) := do
  let rec go (e : Expr) : SimpM (Option Result) := do
    if (← isDefEq lhs e) then
...

However, isDefEq calls whnfCore, which looks like:

partial def whnfCore (e : Expr) : MetaM Expr :=
  whnfEasyCases e fun e => do
    trace[Meta.whnf] e
    match e with
    | Expr.const ..  => pure e
    | Expr.letE _ _ v b _ => whnfCore $ b.instantiate1 v
....

So it always reduces let bindings. Well, it is called whnf after all. Hence there is no option to turn off zeta.
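The behaviour described above can be observed directly from MetaM. This is a minimal sketch for a standalone file, assuming a recent Lean 4 toolchain. It builds the let expression by hand, runs whnfCore on it, and asks isDefEq whether it equals the inlined form. The exact signature of whnfCore has varied across Lean versions (newer ones may take an optional configuration), so treat this as illustrative only.

```lean
import Lean
open Lean Meta

-- Build `let y : Nat := 1; Nat.add y y` by hand and compare it with its
-- zeta-reduced form `Nat.add 1 1`.
#eval show MetaM Unit from do
  let nat := mkConst ``Nat
  let one := mkNatLit 1
  let body := mkApp2 (mkConst ``Nat.add) (.bvar 0) (.bvar 0)  -- `y` is bvar 0
  let e := mkLet `y nat one body
  -- `whnfCore` substitutes the let value into the body (the `letE` case quoted above)
  let e' ← whnfCore e
  -- `isDefEq` likewise identifies the let form with the inlined form
  let same ← isDefEq e (mkApp2 (mkConst ``Nat.add) one one)
  logInfo m!"before: {e}\nafter whnfCore: {e'}\nisDefEq with inlined form: {toString same}"
```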

Tomas Skrivan (May 13 2022 at 16:20):

I will probably write my own tactic that handles let bindings properly as changing these internals is a bad idea.
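One possible starting point for such a tactic is to match the let syntactically instead of going through isDefEq. The sketch below assumes the current Lean 4 Expr API, where lam carries a BinderInfo and letE a nonDep flag. It uses a hypothetical helper, splitLetLambda?, that is not part of Lean or of the thread. The helper splits fun x => let y := v; b into the g and f that D_let expects, without any zeta reduction.

```lean
import Lean
open Lean Meta

-- Hypothetical helper: split `fun x => let y := v; b` into
-- `g := fun x => v` and `f := fun x y => b` by matching the `letE` node
-- directly, so no zeta reduction takes place.
def splitLetLambda? (e : Expr) : Option (Expr × Expr) :=
  match e with
  | .lam x xTy (.letE y yTy v b _) bi =>
    -- `D_let` needs `g : α → β`, so the let's type must not mention `x`
    if yTy.hasLooseBVars then none
    else
      -- in `v`, `x` is bvar 0; in `b`, `x` is bvar 1 and `y` is bvar 0,
      -- so both bodies can be reused under the new binders without shifting
      some (Expr.lam x xTy v bi, Expr.lam x xTy (Expr.lam y yTy b .default) bi)
  | _ => none

-- Usage on `fun x : Nat => let tmp : Nat := Nat.succ x; Nat.add x tmp`
#eval show MetaM Unit from do
  let nat := mkConst ``Nat
  let e := Expr.lam `x nat
    (.letE `tmp nat (mkApp (mkConst ``Nat.succ) (.bvar 0))
      (mkApp2 (mkConst ``Nat.add) (.bvar 1) (.bvar 0)) false) .default
  match splitLetLambda? e with
  | some (g, f) => logInfo m!"g = {g}\nf = {f}"
  | none => logInfo "no top-level let found"
```

A real tactic would still have to instantiate D_let with these pieces and recurse into the result. This sketch only shows how the let can be detected without it being reduced away.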


Last updated: Dec 20 2023 at 11:08 UTC