Zulip Chat Archive

Stream: new members

Topic: reasoning with the do syntax


Yannick Seurin (Dec 05 2024 at 17:00):

I'd like to prove that sampling uniformly at random from some Fintype α and applying a bijection f : α → β yields the uniform distribution over β. Working with the monadic structure of PMFs, I can define the process in two different ways: unif_then_bij (without the do syntax) and unif_then_bij' (with the do syntax), and they are definitionally equal.

Now, I was able to prove the result for the definition without the do syntax, but I can't figure out how to work with the alternative definition. Of course, I could just start the second proof by replacing unif_then_bij' with unif_then_bij and proceed with the first proof, but I'd like to work with more involved processes where this might become cumbersome. Any idea how I can handle the do notation directly in the second proof?

import Mathlib.Probability.Distributions.Uniform
import Mathlib.Probability.ProbabilityMassFunction.Basic
import Mathlib.Tactic

-- Shared parameters: finite, nonempty types `α` and `β`, and a map `f : α → β`
-- that is applied to a uniformly sampled element of `α`.
variable (α β : Type) [Fintype α] [Nonempty α] [Fintype β] [Nonempty β]
         (f : α → β)

/-- Sample an element of `α` uniformly, then return its image under `f`.
Written with an explicit `PMF.bind` (no `do` notation). -/
noncomputable def unif_then_bij : PMF β :=
  PMF.bind (PMF.uniformOfFintype α) fun a ↦ PMF.pure (f a)

/-- The same process as `unif_then_bij`, written with `do` notation.
`let x ← p; k x` elaborates to the generic monadic bind `p >>= fun x ↦ k x`
(`Bind.bind`), not to `PMF.bind` syntactically, although the two are
definitionally equal. -/
noncomputable def unif_then_bij' : PMF β := do
  let x ← PMF.uniformOfFintype α
  PMF.pure (f x)

omit [Fintype β] [Nonempty β] in
/-- The explicit-`bind` and `do`-notation definitions agree definitionally,
so the equation holds by `rfl`. -/
theorem equivalent (f : α → β) : unif_then_bij α β f = unif_then_bij' α β f :=
  rfl

/-- Pushing the uniform distribution on `α` through a bijection `f : α → β`
yields the uniform distribution on `β`. -/
theorem pmf_eq_uniform_of_bij (bij : Function.Bijective f) :
    PMF.uniformOfFintype β = unif_then_bij α β f := by
  let B := (Fintype.card β : ENNReal)
  ext b
  -- Evaluate both sides at `b`; the right side becomes a `tsum` over `α`
  -- of `if b = f y then (card α)⁻¹ else 0`.
  simp only [unif_then_bij, PMF.uniformOfFintype_apply, PMF.bind_apply, PMF.pure_apply, mul_ite, mul_one, mul_zero]
  -- A bijection gives `card α = card β`.
  rw [Fintype.card_of_bijective bij]
  rcases Function.bijective_iff_has_inverse.mp bij with ⟨finv, invli, invri⟩
  -- `a` is the preimage of `b` under `f`.
  let a := finv b
  classical
  -- Split `α` into `{a}` and its complement.
  let Sa : Finset α := {x | x = a}
  let Sna : Finset α := {x | x ≠ a}
  have hu : Finset.univ = Sa ∪ Sna := by
    ext x
    simp
    simp [Sa, Sna]
    apply Classical.em
  have hd : Disjoint Sa Sna := by apply Finset.disjoint_left.mpr; simp [Sa, Sna]
  rw [tsum_fintype fun y ↦ if b = f y then B⁻¹ else 0]
  rw [hu, Finset.sum_union hd]
  -- Every `x ≠ a` contributes `0`: `b = f x` would force `x = finv b = a`.
  have snazero : ∀ x ∈ Sna, (fun y ↦ if b = f y then B⁻¹ else 0) x = 0 := by
    intro x xsna
    dsimp
    rw [if_neg]
    apply Finset.mem_filter.mp at xsna
    intro bfx
    have : x = a :=
      calc
        x = finv (f x) := by exact (invli x).symm
        _ = finv b := by rw [← bfx]
        _ = a := rfl
    tauto
  -- The element `a` contributes `B⁻¹`, since `f a = b`.
  have sacardb : ∀ x ∈ Sa, (fun y ↦ if b = f y then B⁻¹ else 0) x = B⁻¹ := by
    intro x xsa
    dsimp
    rw [if_pos]
    apply Finset.mem_filter.mp at xsa
    calc
      b = f (finv b) := by exact (invri b).symm
      _ = f a := by rfl
      _ = f x := by rw [xsa.2]
  rw [Finset.sum_eq_zero snazero, add_zero, Finset.sum_congr rfl sacardb, Finset.sum_const]
  -- `Sa` is a singleton, so the remaining sum is `1 • B⁻¹`.
  have cardsa : Finset.card Sa = 1 := by
    apply (Fintype.exists_unique_iff_card_one _).mp
    exact exists_unique_eq
  rw [cardsa]
  exact Eq.symm (one_nsmul B⁻¹)

  /-- The same statement as `pmf_eq_uniform_of_bij`, for the `do`-notation definition. -/
  theorem pmf_eq_uniform_of_bij' (bij : Function.Bijective f) :
    PMF.uniformOfFintype β = unif_then_bij' α β f := by
  -- `do` elaborates `let x ← p; k x` to the generic monadic bind `p >>= k`
  -- (`Bind.bind`). Its head symbol is not `PMF.bind`, so `PMF.bind_apply` and
  -- friends never fire on the unfolded `do` block. The two binds are
  -- definitionally equal, so a one-line `rfl` rewrite turns any `do` block
  -- back into explicit `PMF.bind` terms. This works for longer `do` blocks too,
  -- and the usual `PMF` API applies afterwards.
  -- NOTE(review): Mathlib may already provide this as a simp lemma
  -- (e.g. `PMF.monad_bind_eq_bind`); check before duplicating it.
  have hbind : ∀ {γ δ : Type} (p : PMF γ) (g : γ → PMF δ), p >>= g = p.bind g :=
    fun _ _ => rfl
  simp only [unif_then_bij', hbind]
  -- The goal is now the unfolded `unif_then_bij`. One could continue with
  -- `ext b; simp only [PMF.bind_apply, ...]` as in the first proof; here we reuse it.
  exact pmf_eq_uniform_of_bij α β f bij

Last updated: May 02 2025 at 03:31 UTC