Zulip Chat Archive

Stream: general

Topic: using_well_founded for a type mimicking programming structs


Kris Brown (Jul 15 2020 at 21:09):

I'm modeling programming datatypes, including a type with fields that contain other datatypes.

mutual inductive PrimType, Struct
  with PrimType : Type
   | pInt : nat → PrimType
   | pStr : string → PrimType
   | pBool: bool → PrimType
   | pStruct {s: string}: Struct s → PrimType

  with Struct : string → Type
   | mk (name: string) (fields : list (string × PrimType))
          : Struct name
   | nil (name: string) : Struct name
open PrimType open Struct

I'm trying to prove decidable equality on this type (that function is in a #new_members thread), but doing that straightforwardly doesn't lead to something that Lean knows will terminate. Based on Floris' suggestion, I pattern-matched off of src#lists.equiv.decidable and added this to the bottom:

using_well_founded {
  rel_tac := λ _ _, `[exact ⟨_, measure_wf _⟩], -- this last hole needs to be filled in with something specific to PrimType/Struct
  dec_tac := `[assumption] }

The type that I need to implement (it seems) is psum (Σ' (a : PrimType), PrimType) (Σ' {s : string} (a : Struct s), Struct s) → ℕ

I'm still trying to understand what this psum really means. I started writing out an implementation but keep hitting roadblocks.

def meas : psum (Σ' (a : PrimType), PrimType) (Σ' {s : string} (a : Struct s), Struct s) → ℕ
| (psum.inl ⟨pStruct (mk _ x), pStruct (mk _ y)⟩) := 1 + 0 -- something with x and y
| (psum.inl ⟨pStruct (mk _ x), _⟩) := 1 + 0 -- something with x
| (psum.inl ⟨_, pStruct (mk _ x)⟩) := 1 + 0 -- something with x
| (psum.inl ⟨_,_⟩) := 0
| (psum.inr ⟨_,⟨_,mk _ x⟩⟩) := 1 + 0 -- something with x
| (psum.inr ⟨_,⟨_,nil _⟩⟩) := 0
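For readers puzzled by the type above: as far as I understand it, when Lean 3 compiles a mutual definition by well-founded recursion, it bundles each function's arguments into a dependent pair (`Σ'`, i.e. `psigma`) and joins the alternatives with `psum`, so `psum.inl` stands for a call to the first function and `psum.inr` for a call to the second. The following is only a minimal sketch of that shape on toy types; `toy_meas` is a hypothetical name and is not part of the original code.

```lean
-- `psum α β` is a sum type whose components may live in any `Sort`.
example : psum ℕ bool := psum.inl 3
example : psum ℕ bool := psum.inr tt

-- `Σ' (a : α), β a` is notation for `psigma`: a dependent pair ⟨a, b⟩.
example : Σ' (n : ℕ), ℕ := ⟨1, 2⟩

-- Hypothetical toy measure on a packed domain of the same shape as above:
-- the left branch packs two naturals, the right branch packs two booleans.
def toy_meas : psum (Σ' (a : ℕ), ℕ) (Σ' (b : bool), bool) → ℕ
| (psum.inl ⟨a, b⟩) := a + b
| (psum.inr ⟨_, _⟩) := 0

example : toy_meas (psum.inl ⟨1, 2⟩) = 3 := rfl
```

A measure for `PrimType`/`Struct` would follow the same pattern, with each branch returning some size computed from the packed arguments.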

A simpler recursive function to prove termination for might look like this

mutual def sizeofp, sizeofs
with sizeofp: PrimType → ℕ
 | (pStruct s) := 1 + sizeofs s
 | _ := 0
with sizeofs: Π{s: string}, Struct s → ℕ
 | _ (nil _) := 0
 | _ (mk _ []) := 0
 | _ (mk s ((_, h)::t)) := sizeofp h + sizeofs (mk s t)
using_well_founded {
  rel_tac := λ _ _, `[exact ⟨_, measure_wf meas⟩],
  dec_tac := `[assumption] } -- fails to prove recursive application is decreasing

Does anyone with experience with using_well_founded have suggestions for how to tackle this?

Kris Brown (Jul 17 2020 at 00:30):

It turns out this problem could be solved without resorting to using_well_founded - it just requires adding the right proofs immediately above the recursive calls. There is one bizarre thing where I need to prove a+b<c+d and Lean seems to nondeterministically need some permutation of that (possibly a,b swapped or c,d swapped). Sometimes the checker would approve with just one, but often it fails unless all four permutations are explicitly declared.

Posting my solution here in case anyone else comes across a similar problem

import tactic.linarith

mutual inductive PrimType, Struct
  with PrimType : Type
   | pInt : nat → PrimType
   | pStr : string → PrimType
   | pBool: bool → PrimType
   | pStruct {s: string}: Struct s → PrimType

  with Struct : string → Type
   | mk (name: string) (fields : list (string × PrimType))
          : Struct name
   | nil (name: string) : Struct name

open PrimType open Struct
-------------------------------------------------------------

/-
The head of a list of pairs is smaller than the whole list
-/
lemma headsize {α β: Type} [has_sizeof α] [has_sizeof β]
                (a: α) (b: β) (t: list (α×β)):
                    sizeof b < sizeof(list.cons (a,b) t) := calc
  sizeof b <= sizeof a + sizeof b                   : by linarith
   ...     <  1 + prod.sizeof (a,b) + list.sizeof t : by {unfold prod.sizeof,
                                                          linarith}
   ...     =  1 + sizeof (a,b) + list.sizeof t      : by {unfold sizeof, refl}
   ...     =  list.sizeof(list.cons (a,b) t)        : by unfold list.sizeof

/-
The tail of a list of pairs is smaller than the whole list
-/
lemma tailsize {α β: Type} [has_sizeof α] [has_sizeof β]
                (ab: α×β)  (t: list (α×β)):
                    sizeof t < sizeof(list.cons (ab) t) := calc
  sizeof t < 1 + sizeof (ab) + sizeof t      : by {linarith}
   ...     = 1 + sizeof (ab) + list.sizeof t : by {unfold sizeof, refl}
   ...     =  list.sizeof(list.cons (ab) t)  : by unfold list.sizeof

/-
Sometimes we cannot use prod.decidable_eq, because that requires an instance of
[has_decidable_eq] for α and β, which is what we're trying to prove! Need a
version that works for specific elements, used in a proof by induction.
-/
def prod_eq_intro: Π{α β: Type} (a1 a2: α) (b1 b2: β)
                (x: decidable (a1=a2)) (y: decidable (b1=b2)),
                    decidable ((a1,b1)=(a2,b2)) := by begin
    intros,
    cases x with xfalse xtrue,
        {simp [xfalse], exact decidable.false},
    {cases y with yfalse ytrue,
        {simp [yfalse], exact decidable.false},
        {simp [xtrue, ytrue], exact decidable.true}},
end

-------------------------------------------------------------
/-
One caveat of this equality is that the NULL values for some given struct
are considered equal.
-/
mutual def P_eq, P_list_eq, S_eq
with P_eq : decidable_eq PrimType
 | (pInt _)    (pBool _)    := by {simp only [], exact decidable.false}
 | (pInt _)    (pStr _)     := by {simp only [], exact decidable.false}
 | (pInt _)    (pStruct _)  := by {simp only [], exact decidable.false}
 | (pBool _)   (pInt _)     := by {simp only [], exact decidable.false}
 | (pBool _)   (pStr _)     := by {simp only [], exact decidable.false}
 | (pBool _)   (pStruct _)  := by {simp only [], exact decidable.false}
 | (pStr _)    (pInt _)     := by {simp only [], exact decidable.false}
 | (pStr _)    (pBool _)    := by {simp only [], exact decidable.false}
 | (pStr _)     (pStruct _) := by {simp only [], exact decidable.false}
 | (pStruct _)  (pInt _)    := by {simp only [], exact decidable.false}
 | (pStruct _)  (pStr _)    := by {simp only [], exact decidable.false}
 | (pStruct _)  (pBool _)   := by {simp only [], exact decidable.false}
 | (pInt x)     (pInt y)    := by {simp only [], exact nat.decidable_eq x y}
 | (pBool x)    (pBool y)   := by {simp only [], exact bool.decidable_eq x y}
 | (pStr x)     (pStr y)    := by {simp only [],
                                   exact string.has_decidable_eq x y}

 | (@pStruct m (mk a b)) (@pStruct n (nil x)) := by {
        simp, have p: decidable (a=x), by exact string.has_decidable_eq a x,
        cases p with pf pt,
        -- two different Struct classes! definitely not equal.
        {simp [pf], exact decidable.false},
        -- Even if same class, it's different constructors; still not equal
        {simp [pt], rw pt, rw [heq_iff_eq], simp, exact decidable.false}}

 | (@pStruct _ (nil x)) (@pStruct m (mk a b)) := by {
        simp, have p: decidable (x=a), by exact string.has_decidable_eq x a,
        cases p with pf pt, -- (same as above block, with a flip)
        {simp [pf], exact decidable.false},
        {simp [pt], rw pt, rw [heq_iff_eq], simp, exact decidable.false}}

 | (@pStruct n (nil x)) (@pStruct m (nil y)) := by {simp only,
        have p: decidable (x=y), by exact string.has_decidable_eq x y,
        cases p, {simp [p], exact decidable.false}, -- different classes.
        {rw p, rw [heq_iff_eq], simp, exact decidable.true}} -- both NULL

 | (@pStruct n (mk a x)) (@pStruct m (mk b y)) := by {
        simp, have p: decidable (a=b), by exact string.has_decidable_eq a b,
        cases p with pf pt,
        {simp [pf], exact decidable.false}, -- different classes; not equal
        {rw pt, simp, exact P_list_eq x y}} -- equal if arg lists are equal

with P_list_eq : decidable_eq (list (string × PrimType))
 | a b := by begin
  cases eqa : a with a_hd a_tl,
  {cases b,
    {simp, exact decidable.true}, -- both lists nil, they're equal
    {simp, exact decidable.false}}, -- false if only one list is nil
  {cases eqb : b with b_hd b_tl,
    {simp, exact decidable.false}, simp, -- false if only one list is nil
    cases a_hd with a_hds a_hdp, -- split the head into its parts
    cases b_hd with b_hds b_hdp, -- split the head into its parts

    -- Lemmas useful for proving recursive calls terminate
    have sizea : a_hdp.sizeof < a.sizeof :=
        by {rw eqa, exact headsize a_hds a_hdp a_tl},
    have sizea': a_tl.sizeof < a.sizeof :=
        by {rw eqa, exact tailsize (a_hds, a_hdp) a_tl},
    have sizeb : b_hdp.sizeof < b.sizeof :=
        by {rw eqb, exact headsize b_hds b_hdp b_tl},
    have sizeb': b_tl.sizeof < b.sizeof :=
        by {rw eqb, exact tailsize (b_hds, b_hdp) b_tl},

    have res1: decidable (a_hdp = b_hdp) :=
        -- bizarre... but all four permutations are explicitly needed
        have H1: a_hdp.sizeof + b_hdp.sizeof < 1 + (b.sizeof + a.sizeof),
                by linarith,
        have H2: a_hdp.sizeof + b_hdp.sizeof < 1 + (a.sizeof + b.sizeof),
                by linarith [H1],
        have H3: b_hdp.sizeof + a_hdp.sizeof < 1 + (b.sizeof + a.sizeof),
                by linarith [H1],
        have H4: b_hdp.sizeof + a_hdp.sizeof < 1 + (a.sizeof + b.sizeof),
                by linarith [H1],
        P_eq a_hdp b_hdp, -- the head is decidably equal (recursive call)

    -- Above case was just for the primtype half; extend to the whole head
    have res1': decidable ((a_hds, a_hdp) = (b_hds, b_hdp)) :=
            prod_eq_intro a_hds b_hds a_hdp b_hdp
                (string.has_decidable_eq a_hds b_hds) res1,

    have res2: decidable (a_tl = b_tl) :=
        have H1: a_tl.sizeof + b_tl.sizeof < b.sizeof + a.sizeof,
                by linarith,
        have H2: b_tl.sizeof + a_tl.sizeof < b.sizeof + a.sizeof,
                by linarith [H1],
        have H3: a_tl.sizeof + b_tl.sizeof < a.sizeof + b.sizeof,
                by linarith [H1],
        have H4: b_tl.sizeof + a_tl.sizeof < a.sizeof + b.sizeof,
                by linarith [H1],
        P_list_eq a_tl b_tl, -- the tail is decidably equal (recursive call)

    -- head and tail are both decidably equal, so we're done
    exact @and.decidable ((a_hds, a_hdp) = (b_hds, b_hdp)) (a_tl = b_tl)
                         res1' res2
    }
end

with S_eq : Π(s: string), decidable_eq (Struct s)
 | _ (mk _ _)      (nil _)       := by {simp only [], exact decidable.false}
 | _ (nil _)       (mk _ _)      := by {simp only [], exact decidable.false}
 | _ (nil _)       (nil _)       := by {simp only [], exact decidable.true}
 | z (Struct.mk m x) (Struct.mk n y) := by {simp,
    cases (P_list_eq x y) with Hf Ht, -- recursively check if lists are equal
        {simp [Hf], exact decidable.false},
        {rw Ht, simp, exact decidable.true}}

instance : Π{s: string}, decidable_eq (Struct s) := by {intros s, exact S_eq s}

def f {s: string} (x: Struct s) (y: Struct s) : bool := x = y
#eval f (mk "a" []) (nil "a") --ff
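The trick used in `P_list_eq` above, stating a `have ... < ...` fact right before a recursive call so that the default decreasing-argument tactic can pick it up, may be easier to see in isolation. The sketch below is adapted from the well-known division example in Theorem Proving in Lean (Lean 3); the name `div_sketch` is a placeholder and it is not part of the solution above.

```lean
-- Division by repeated subtraction. The recursion is not structural,
-- so Lean needs evidence that the first argument gets smaller.
def div_sketch : ℕ → ℕ → ℕ
| x y :=
  if h : 0 < y ∧ y ≤ x then
    -- This local hypothesis is what the default decreasing tactic
    -- searches for when it checks the recursive call below.
    have x - y < x,
      from nat.sub_lt (lt_of_lt_of_le h.left h.right) h.left,
    div_sketch (x - y) y + 1
  else
    0
```

In the mutual case the goal that Lean generates is phrased in terms of `sizeof` of the packed arguments, which is presumably why the `H1` to `H4` variants above had to match the generated inequality so closely.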

Last updated: Dec 20 2023 at 11:08 UTC