Zulip Chat Archive
Stream: new members
Topic: reasoning with the do syntax
Yannick Seurin (Dec 05 2024 at 17:00):
I'd like to prove that sampling uniformly at random from some `Fintype α` and applying a bijection `f : α → β` yields the uniform distribution over `β`. Working with the monadic structure of PMFs, I can define the process in two different ways: `unif_then_bij` (without the `do` syntax) and `unif_then_bij'` (with the `do` syntax), and they are definitionally equal.
Now, I was able to prove the result for the definition without the `do` syntax, but I can't figure out how to work with the alternative definition. Of course, I could just start the second proof by replacing `unif_then_bij'` with `unif_then_bij` and proceed with the first proof, but I'd like to work with more involved processes where this might become cumbersome. Any idea how I can deal with the `do` notation in the second proof?
import Mathlib.Probability.Distributions.Uniform
import Mathlib.Probability.ProbabilityMassFunction.Basic
import Mathlib.Tactic
-- Sample space `α` and target `β`, both finite and nonempty as required by
-- `PMF.uniformOfFintype`; `f` is the map applied to the uniform sample.
variable (α β : Type) [Fintype α] [Nonempty α] [Fintype β] [Nonempty β] (f : α → β)
/-- Draw `x` uniformly from `α`, then return the point mass at `f x`.
This version composes the steps with `PMF.bind` directly (no `do` notation). -/
noncomputable def unif_then_bij : PMF β :=
  PMF.bind (PMF.uniformOfFintype α) (fun x ↦ PMF.pure (f x))
/-- The same process as `unif_then_bij`, written in `do` notation: draw `a`
uniformly from `α`, then return the point mass at `f a`. -/
noncomputable def unif_then_bij' : PMF β := do
  let a ← PMF.uniformOfFintype α
  PMF.pure (f a)
omit [Fintype β] [Nonempty β] in
/-- The `bind` formulation and the `do` formulation of the process are
definitionally equal, so `rfl` closes the goal. -/
theorem equivalent : unif_then_bij α β f = unif_then_bij' α β f := rfl
/-- Sampling uniformly from `α` and then applying a bijection `f : α → β`
(the process `unif_then_bij`) yields the uniform distribution on `β`. -/
theorem pmf_eq_uniform_of_bij (bij : Function.Bijective f) :
    PMF.uniformOfFintype β = unif_then_bij α β f := by
  -- `B` is `card β` viewed in `ENNReal`; each point of `β` should get mass `B⁻¹`.
  let B := (Fintype.card β : ENNReal)
  -- Compare the two PMFs pointwise at an arbitrary `b : β`.
  ext b
  -- Unfold the bind; the right-hand side becomes a `tsum` over `α` of
  -- `if b = f x then (card α)⁻¹ else 0`.
  simp only [unif_then_bij, PMF.uniformOfFintype_apply, PMF.bind_apply, PMF.pure_apply, mul_ite, mul_one, mul_zero]
  -- A bijection forces `card α = card β`, so the summand's mass becomes `B⁻¹`.
  rw [Fintype.card_of_bijective bij]
  -- `finv` is a two-sided inverse of `f` (`invli`: `finv (f x) = x`,
  -- `invri`: `f (finv y) = y`); `a` is the preimage of `b`.
  rcases Function.bijective_iff_has_inverse.mp bij with ⟨finv, invli, invri⟩
  let a := finv b
  classical
  -- Partition `α` into `Sa = {a}` and its complement `Sna`.
  let Sa : Finset α := {x | x = a}
  let Sna : Finset α := {x | x ≠ a}
  have hu : Finset.univ = Sa ∪ Sna := by
    ext x
    simp
    simp [Sa, Sna]
    -- What remains is excluded middle for `x = a`.
    apply Classical.em
  have hd : Disjoint Sa Sna := by apply Finset.disjoint_left.mpr; simp [Sa, Sna]
  -- `α` is finite, so the `tsum` is a `Finset.sum` over `univ`; split it
  -- along the disjoint union `univ = Sa ∪ Sna`.
  rw [tsum_fintype fun y ↦ if b = f y then B⁻¹ else 0]
  rw [hu, Finset.sum_union hd]
  -- Points `x ≠ a` contribute `0`: `b = f x` would force `x = finv b = a`.
  have snazero: ∀ x ∈ Sna, (fun y ↦ if b = f y then B⁻¹ else 0) x = 0 := by
    intro x xsna
    dsimp
    -- Leaves the side goal `¬ b = f x`.
    rw [if_neg]
    apply Finset.mem_filter.mp at xsna
    intro bfx
    have : x = a :=
      calc
        x = finv (f x) := by exact (invli x).symm
        _ = finv b := by rw [← bfx]
        _ = a := rfl
    -- Contradicts the `x ≠ a` component of `xsna`.
    tauto
  -- The point `a` itself contributes `B⁻¹`, because `f a = b`.
  have sacardb : ∀ x ∈ Sa, (fun y ↦ if b = f y then B⁻¹ else 0) x = B⁻¹ := by
    intro x xsa
    dsimp
    -- Leaves the side goal `b = f x`.
    rw [if_pos]
    apply Finset.mem_filter.mp at xsa
    calc
      b = f (finv b) := by exact (invri b).symm
      _ = f a := by rfl
      _ = f x := by rw [xsa.2]
  -- The `Sna` part vanishes; the `Sa` part is `card Sa • B⁻¹`.
  rw [Finset.sum_eq_zero snazero, add_zero, Finset.sum_congr rfl sacardb, Finset.sum_const]
  -- `Sa` has exactly one element, so the result is `1 • B⁻¹ = B⁻¹`.
  have cardsa : Finset.card Sa = 1 := by
    apply (Fintype.exists_unique_iff_card_one _).mp
    exact exists_unique_eq
  rw [cardsa]
  exact Eq.symm (one_nsmul B⁻¹)
/-- `do`-notation version of `pmf_eq_uniform_of_bij`: sampling uniformly from `α`
and applying a bijection `f` (the process `unif_then_bij'`) yields the uniform
distribution on `β`. -/
theorem pmf_eq_uniform_of_bij' (bij : Function.Bijective f) :
    PMF.uniformOfFintype β = unif_then_bij' α β f := by
  -- A `do` block elaborates to the generic `Monad` bind (`>>=`) provided by the
  -- `Monad PMF` instance, not to the constant `PMF.bind`. Simp lemmas match
  -- syntactically, so `PMF.bind_apply` / `PMF.pure_apply` never fire after
  -- unfolding `unif_then_bij'`.
  -- Both formulations are definitionally equal (`equivalent` is proved by `rfl`),
  -- so rewrite the goal into the `PMF.bind` form and reuse the first proof.
  -- NOTE(review): for larger `do` blocks, the same translation can presumably be
  -- done in place by adding Mathlib's `PMF.monad_bind_eq_bind` (and
  -- `PMF.monad_pure_eq_pure` for `pure`/`return`) to the `simp only` set before
  -- `PMF.bind_apply` — confirm these lemma names against your Mathlib version.
  rw [← equivalent]
  exact pmf_eq_uniform_of_bij α β f bij
Last updated: May 02 2025 at 03:31 UTC