Zulip Chat Archive

Stream: general

Topic: casing on a fin type


Jeremy Avigad (Oct 14 2021 at 19:06):

Suppose I define a function like foo, and want to prove bar:

def foo : fin 4 → string
  | ⟨0, _⟩ := "dog"
  | ⟨1, _⟩ := "cat"
  | ⟨2, _⟩ := "parakeet"
  | ⟨3, _⟩ := "tasmanian devil"
  | ⟨_ + 4, h⟩ := by linarith

theorem bar (x : fin 4) : (foo x).length < 16 := sorry

Here is one way to do it:

@[simp] lemma foo0 : foo 0 = "dog" := rfl
@[simp] lemma foo1 : foo 1 = "cat" := rfl
@[simp] lemma foo2 : foo 2 = "parakeet" := rfl
@[simp] lemma foo3 : foo 3 = "tasmanian devil" := rfl
@[simp] lemma l0 : "dog".length = 3 := rfl
@[simp] lemma l1 : "cat".length = 3 := rfl
@[simp] lemma l2 : "parakeet".length = 8 := rfl
@[simp] lemma l3 : "tasmanian devil".length = 15 := rfl

theorem bar (x : fin 4) : (foo x).length < 16 :=
by { fin_cases x; simp }

Is there a way to do it without having to declare all the simp lemmas?

I realize that there are two questions here: how to get simp to simplify a definition by cases, and how to get the simplifier to unpack strings. Information on either will be helpful.

Julian Berman (Oct 14 2021 at 19:13):

Is the question specifically about simp? It works for me without the simp lemmas with dec_trivial if that's a suitable answer?

I.e.:

import tactic
def foo : fin 4 → string
  | ⟨0, _⟩ := "dog"
  | ⟨1, _⟩ := "cat"
  | ⟨2, _⟩ := "parakeet"
  | ⟨3, _⟩ := "tasmanian devil"
  | ⟨_ + 4, h⟩ := by linarith

theorem bar (x : fin 4) : (foo x).length < 16 := by fin_cases x; dec_trivial

Rob Lewis (Oct 14 2021 at 19:52):

Slightly rephrased, you don't even need fin_cases.

import tactic

def foo : fin 4 → string
  | ⟨0, _⟩ := "dog"
  | ⟨1, _⟩ := "cat"
  | ⟨2, _⟩ := "parakeet"
  | ⟨3, _⟩ := "tasmanian devil"
  | ⟨_ + 4, h⟩ := by linarith

theorem bar : ∀ (x : fin 4), (foo x).length < 16 := dec_trivial

Jeremy Avigad (Oct 14 2021 at 20:03):

Thanks! This is really helpful. I have a use case where we are performing permutations on concrete matrices and want to verify the results. I worry that things like dec_trivial and reflexivity will hit performance problems, but maybe not. I'd still be curious to know whether the simplifier can be made to do explicit calculations like these (length of a string, value of a function on an element of a fintype). But, again, this is really helpful.
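
A minimal sketch of where that cost would come from, assuming only `import tactic` as in the examples above: dec_trivial asks Lean to evaluate the decidable instance for the goal, and for a ∀ over a fintype that instance checks every element in turn, so the work grows with the number of cases. The statements below are made-up toy examples, not taken from the matrix use case.

```lean
import tactic

-- Hypothetical toy statement: the instance checks x = 0, 1, 2, 3 one by one.
example : ∀ x : fin 4, x.val * x.val < 10 := dec_trivial

-- A closed arithmetic fact is decided the same way, by evaluation.
example : 2 + 2 = 4 := dec_trivial
```

For four cases this is instant; whether it stays fast on larger concrete matrices would have to be measured.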

Yakov Pechersky (Oct 14 2021 at 20:06):

We wrote a norm_num plugin to deal with normalizing swaps, norm_swap.

Yakov Pechersky (Oct 14 2021 at 20:08):

From what I remember Mario saying, such explicit calculations shouldn't be the simplifier's role, because often you're expanding and increasing terms, which might not be confluent. Instead, a DSL-based proof-constructing algorithm can be employed, a la norm_num plugins.

Jeremy Avigad (Oct 15 2021 at 01:34):

For what it's worth, I managed to get simp to do the computation. For strings, it's easy:

namespace string

@[simp] lemma length_str (s : string) (c : char) :
  (str s c).length = s.length + 1 :=
by { cases s, simp [str, length, push] }

@[simp] lemma length_empty : length empty = 0 := rfl

end string

example : string.length "tasmanian devil" < 16 := by norm_num

It's harder to simplify a fin numeral like 3 to ⟨3, _⟩, because the simplifier is wired to go the other way. So you have to be careful to use simp only.

namespace fin

run_cmd mk_simp_attr `fin_num_simps

@[fin_num_simps] lemma zero_eq (n : nat) : (0 : fin n.succ) = fin.mk 0 (nat.zero_lt_succ _) := rfl

@[fin_num_simps] lemma one_eq (n : nat) :
  (1 : fin n.succ) = fin.mk (1 % n.succ) (nat.mod_lt 1 (nat.zero_lt_succ _)) := rfl

@[fin_num_simps] lemma bit0_eq (n m : nat) (h : m < n.succ) :
  bit0 (fin.mk m h) = fin.mk ((m + m) % n.succ) (nat.mod_lt _ (nat.zero_lt_succ _)) := rfl

lemma add_one_eq (n m : nat) (h : m < n.succ) :
  fin.mk m h + 1 = fin.mk ((m + (1 % n.succ)) % n.succ) (nat.mod_lt _ (nat.zero_lt_succ _)) := rfl

@[fin_num_simps] lemma bit1_eq (n m : nat) (h : m < n.succ) :
  bit1 (fin.mk m h) = fin.mk ((m + m + 1) % n.succ) (nat.mod_lt _ (nat.zero_lt_succ _)) :=
by { rw [bit1, bit0_eq, add_one_eq], simp }

end fin

meta def tactic.interactive.fin_num_simp : tactic unit :=
`[ simp only with fin_num_simps , try { norm_num1 }, simp only [fin.mk_eq_subtype_mk] ]

example : foo 3 = "tasmanian devil" :=
by { fin_num_simp, simp only [foo] }

Putting it together, we have

theorem bar (x : fin 4) : (foo x).length < 16 :=
by { fin_cases x; fin_num_simp; simp only [foo]; norm_num }

Yakov Pechersky (Oct 15 2021 at 02:07):

Dealing with fins has its own plugin! norm_fin

Yakov Pechersky (Oct 15 2021 at 02:07):

Which will also properly deal with the periodic properties of fin.
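
A small sketch of what using the plugin might look like, assuming mathlib's `tactic.norm_fin` is imported so that the norm_num extension is registered. The example statement is the one recalled from the plugin's documentation; treat it as illustrative rather than verified against a particular mathlib version.

```lean
import tactic.norm_fin

-- Assumption: with the import above, norm_num can normalize equalities
-- between `fin` expressions, including ones built with `fin.succ`.
example : (5 : fin 7) = fin.succ (fin.succ 3) := by norm_num
```

As far as I recall, the norm_num extension only acts on fin expressions inside equalities and inequalities, while the standalone norm_fin tactic rewrites them anywhere in the goal.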

Anne Baanen (Oct 15 2021 at 10:22):

For what it's worth, you might also want to check out data.matrix.notation, which allows you to write:

import data.matrix.notation

def foo : fin 4 → string := !["dog", "cat", "parakeet", "tasmanian devil"]

-- The following still works, by the way:
theorem bar : ∀ (x : fin 4), (foo x).length < 16 := dec_trivial
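
A short sketch of evaluating such a vector at a concrete index with simp, assuming the simp lemmas that come with data.matrix.notation (such as cons_val_zero, cons_val_one and head_cons) are in scope. The vector of numbers is an arbitrary illustration, not part of the example above.

```lean
import data.matrix.notation

-- Index 0 picks the head of the vector (cons_val_zero).
example : (![10, 20, 30] : fin 3 → ℕ) 0 = 10 := by simp

-- Index 1 picks the head of the tail (cons_val_one, then head_cons).
example : (![10, 20, 30] : fin 3 → ℕ) 1 = 20 := by simp
```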

Last updated: Dec 20 2023 at 11:08 UTC