Zulip Chat Archive
Stream: new members
Topic: Proof review: CSES Counting Bits
Huỳnh Trần Khanh (Feb 09 2021 at 03:47):
This thread is a sequel to this thread. The main lemma I proved is `dp_eq_total_popcount`.
import data.nat.basic
import data.list.range
import tactic.linarith
import data.nat.modeq
import tactic.slim_check
import data.nat.parity
open list
/-- Counting recurrence over a list of bits, consumed from the head, with a boolean mode flag.

* With flag `tt`, every bit doubles the result, so the value is `2 ^ length` of the list.
* With flag `ff`, a leading `ff` bit keeps the flag, while a leading `tt` bit adds the
  `tt`-mode and `ff`-mode results for the tail.

`is_bijective` (later in this file) shows `dp_cardinality (to_binary n) ff = n`. -/
def dp_cardinality : list bool → bool → ℕ
| []         tt := 1
| []         ff := 0
| (ff :: bs) ff := dp_cardinality bs ff
| (tt :: bs) ff := dp_cardinality bs tt + dp_cardinality bs ff
| (_ :: bs)  tt := 2 * dp_cardinality bs tt
/-- Companion recurrence to `dp_cardinality`, over the same bit list and mode flag.

The empty list gives `0` in either mode. In `ff` mode a leading `ff` bit keeps the flag;
a leading `tt` bit adds the `tt`-mode result for the tail to the `ff`-mode result plus
`dp_cardinality tail ff`. In `tt` mode every bit contributes the `tt`-mode result for the
tail twice, plus `dp_cardinality tail tt`.

`dp_eq_total_popcount` (later in this file) relates `dp_popcount (to_binary n) ff` to
`total_popcount n`. -/
def dp_popcount : list bool → bool → ℕ
| []         _  := 0
| (ff :: bs) ff := dp_popcount bs ff
| (tt :: bs) ff := dp_popcount bs tt + (dp_popcount bs ff + dp_cardinality bs ff)
| (_ :: bs)  tt := dp_popcount bs tt + (dp_popcount bs tt + dp_cardinality bs tt)
/-- Halving a positive natural number strictly decreases it.

Used as the decreasing-argument proof for the well-founded recursion in `to_binary`
and in the recursive lemmas that recurse on `n / 2`.

Proof: `n / 2 < n ↔ n < n * 2 ↔ n < n + n ↔ 0 < n`, and the last statement is `h`.
This avoids the case split and `linarith` calls of the longer proof. -/
lemma to_binary_terminates (n : ℕ) (h : 0 < n) : n / 2 < n :=
by rwa [nat.div_lt_iff_lt_mul' zero_lt_two, mul_two, lt_add_iff_pos_left]
/-- Binary digits of a natural number as a list of booleans, most significant bit first.

`0` maps to the empty list. For a positive number the digits of its half are computed
recursively and the parity bit is appended at the end (`ff` for even, `tt` for odd). -/
def to_binary : ℕ → list bool
| 0 := []
-- The `have` gives the equation compiler the proof that the recursive argument decreases.
| (n + 1) := have (n + 1) / 2 < (n + 1) := to_binary_terminates (n + 1) (nat.succ_pos n), if (n + 1) % 2 = 0 then to_binary ((n + 1) / 2) ++ [ff] else to_binary ((n + 1) / 2) ++ [tt]
/-- A number of the form `2 * n + 1` leaves remainder `1` modulo `2`.

Proof: commute the sum to `1 + 2 * n`, drop the multiple of `2` with
`nat.add_mul_mod_self_left`, and finish `1 % 2 = 1` by computation. -/
lemma last_bit_one (n : ℕ) : (2 * n + 1) % 2 = 1 :=
begin
  rw [add_comm, nat.add_mul_mod_self_left],
  refl,
end
/-- A number of the form `2 * n + 1 + 1` leaves remainder `0` modulo `2`.

The statement keeps the `2 * n + 1 + 1` shape because `append_false` rewrites with it. -/
lemma last_bit_zero (n : ℕ) : (2 * n + 1 + 1) % 2 = 0 :=
begin
  -- `2 * n + 1 + 1` and `2 * n + 2` are definitionally equal, so restate the goal.
  show (2 * n + 2) % 2 = 0,
  -- Commute to `2 + 2 * n`, drop the multiple of `2`, then `2 % 2 = 0`.
  rw [add_comm, nat.add_mul_mod_self_left, nat.mod_self],
end
/-- The binary digits of `2 * n + 1` are the digits of `n` followed by a `tt` bit. -/
lemma append_true : ∀ n, to_binary (2 * n + 1) = to_binary n ++ [tt] := begin
intro n,
-- Unfold one step of `to_binary`, then split on its parity test.
rw to_binary,
split_ifs,
{
-- Even branch is contradictory: `last_bit_one` says the remainder is 1, not 0.
have := last_bit_one n,
linarith,
},
{
-- Odd branch: the recursive argument `(2 * n + 1) / 2` is `n`.
have : (2 * n + 1) / 2 = n := begin
-- Start from the division identity for `2 * n + 1` and `2`, then cancel.
have := nat.div_add_mod (2 * n + 1) 2,
rw last_bit_one at this,
have := nat.succ.inj this,
have := nat.eq_of_mul_eq_mul_left zero_lt_two this,
assumption,
end,
rw this,
}
end
/-- For positive `n`, the binary digits of `2 * n` are the digits of `n` followed by an `ff` bit.

The positivity hypothesis is needed because `to_binary 0 = []`. -/
lemma append_false : ∀ n, 0 < n → to_binary (2 * n) = to_binary n ++ [ff] := begin
intro n,
intro guard,
cases n,
{
-- `n = 0` contradicts the hypothesis `0 < n`.
linarith,
},
{
-- Put the argument in successor form so that the `(n + 1)` equation of `to_binary` applies.
have : 2 * (n + 1) = (2 * n + 1) + 1 := by linarith,
rw this,
rw to_binary,
split_ifs,
{
-- Even branch: the recursive argument `(2 * n + 1 + 1) / 2` is `n + 1`.
have : (2 * n + 1 + 1) / 2 = n + 1 := begin
have := nat.div_add_mod (2 * n + 1 + 1) 2,
rw last_bit_zero at this,
norm_num at this,
have distribute : 2 * n + 1 + 1 = 2 * (n + 1) := by ring,
rw distribute at this,
rw distribute,
have := nat.eq_of_mul_eq_mul_left zero_lt_two this,
assumption,
end,
rw this,
},
{
-- Odd branch is contradictory: `h` (named by `split_ifs`) denies the remainder is 0.
have := last_bit_zero n,
exfalso,
exact h this,
}
}
end
/-- The quantity computed by the digit recurrence for `n`: `dp_popcount` run in `ff` mode
on the binary digits of `n`. -/
def dp (n : ℕ) : ℕ :=
dp_popcount (to_binary n) ff
/-- Number of `tt` entries in a list of bits. -/
def popcount : list bool → ℕ
| []         := 0
| (tt :: bs) := 1 + popcount bs
| (ff :: bs) := popcount bs
/-- Sum of `popcount (to_binary k)` over all `k < n`; this is the specification that
`dp` is proved equal to. -/
def total_popcount : ℕ → ℕ
| 0       := 0
| (m + 1) := total_popcount m + popcount (to_binary m)
/-- In `tt` mode `dp_cardinality` does not depend on the value of the final bit. -/
lemma cardinality_ignores_string :
  ∀ string, dp_cardinality (string ++ [ff]) tt = dp_cardinality (string ++ [tt]) tt
| []         := rfl
| (tt :: bs) := by simp [dp_cardinality, cardinality_ignores_string]
| (ff :: bs) := by simp [dp_cardinality, cardinality_ignores_string]
/-- In `tt` mode, appending an `ff` bit at the end gives the same `dp_cardinality`
as prepending a `tt` bit at the front. -/
lemma cardinality_ignores_string' :
  ∀ string, dp_cardinality (string ++ [ff]) tt = dp_cardinality (tt::string) tt
| []         := rfl
| (tt :: bs) := by simp [dp_cardinality, cardinality_ignores_string']
| (ff :: bs) := by simp [dp_cardinality, cardinality_ignores_string']
/-- In `tt` mode, appending a `tt` bit at the end gives the same `dp_cardinality`
as prepending a `tt` bit at the front. -/
lemma cardinality_ignores_string'' :
  ∀ string, dp_cardinality (string ++ [tt]) tt = dp_cardinality (tt::string) tt
| []         := rfl
| (tt :: bs) := by simp [dp_cardinality, cardinality_ignores_string'']
| (ff :: bs) := by simp [dp_cardinality, cardinality_ignores_string'']
/-- In `tt` mode `dp_popcount` does not depend on the value of the final bit. -/
lemma popcount_ignores_string :
  ∀ string, dp_popcount (string ++ [tt]) tt = dp_popcount (string ++ [ff]) tt
| []         := rfl
| (tt :: bs) := by simp [dp_popcount, popcount_ignores_string, cardinality_ignores_string]
| (ff :: bs) := by simp [dp_popcount, popcount_ignores_string, cardinality_ignores_string]
/-- In `tt` mode, appending an `ff` bit at the end gives the same `dp_popcount`
as prepending a `tt` bit at the front. -/
lemma popcount_ignores_string' :
  ∀ string, dp_popcount (string ++ [ff]) tt = dp_popcount (tt::string) tt
| []         := rfl
| (tt :: bs) :=
  by simp [dp_popcount, popcount_ignores_string', cardinality_ignores_string',
    cardinality_ignores_string, cardinality_ignores_string'', dp_cardinality]
| (ff :: bs) :=
  by simp [dp_popcount, popcount_ignores_string', cardinality_ignores_string',
    cardinality_ignores_string, cardinality_ignores_string'', dp_cardinality]
/-- In `ff` mode, appending an `ff` bit doubles `dp_cardinality`. -/
lemma double_cardinality : ∀ string, dp_cardinality (string ++ [ff]) ff = 2 * dp_cardinality string ff
| [] := rfl
| (tt::rest) := begin
rw dp_cardinality,
-- Move the cons outside the append so the defining equations of `dp_cardinality` apply.
have : tt::rest ++ [ff] = tt::(rest ++ [ff]) := cons_append tt rest [ff],
rw this,
rw dp_cardinality,
-- Recursive use of the lemma on the shorter list `rest`.
rw double_cardinality rest,
rw cardinality_ignores_string',
rw dp_cardinality,
ring,
end
| (ff::rest) := begin
rw dp_cardinality,
-- Move the cons outside the append so the defining equations of `dp_cardinality` apply.
have : ff::rest ++ [ff] = ff::(rest ++ [ff]) := cons_append ff rest [ff],
rw this,
rw dp_cardinality,
-- Recursive use of the lemma on the shorter list `rest`.
rw double_cardinality rest,
end
/-- In `ff` mode, ending the list with `tt` instead of `ff` increases `dp_cardinality` by one. -/
lemma dp_cardinality_odd :
  ∀ string, dp_cardinality (string ++ [tt]) ff = dp_cardinality (string ++ [ff]) ff + 1
| []         := rfl
| (tt :: bs) :=
  begin
    simp [dp_cardinality, dp_cardinality_odd, cardinality_ignores_string,
      cardinality_ignores_string', cardinality_ignores_string''],
    ring,
  end
| (ff :: bs) :=
  by simp [dp_cardinality, dp_cardinality_odd, cardinality_ignores_string,
    cardinality_ignores_string', cardinality_ignores_string'']
(cont'd)
Huỳnh Trần Khanh (Feb 09 2021 at 03:47):
/-- A natural number whose remainder modulo `2` is not `0` has remainder `1`.

The goal is turned into `¬ n % 2 = 0` through `nat.odd_iff`, `nat.odd_iff_not_even`
and `nat.even_iff`, which is exactly `h`. -/
lemma not_even_is_odd (n : ℕ) (h : ¬(n % 2 = 0)) : n % 2 = 1 :=
by rwa [← nat.odd_iff, nat.odd_iff_not_even, nat.even_iff]
-- Mutually recursive pair: in `ff` mode, `dp_cardinality` of the binary digits of an
-- even number `2 * n` (resp. an odd number `2 * n + 1`) is that number itself.
-- Both proofs recurse on `n / 2`, choosing the even or odd lemma by the parity of `n`;
-- the `have n / 2 < n` terms supply the decreasing proof for each recursive call.
mutual lemma is_bijective_even, is_bijective_odd
with is_bijective_even : ∀ n, dp_cardinality (to_binary (2 * n)) ff = 2 * n
| n := begin
exact if is_zero: n = 0 then begin simp [is_zero, dp_cardinality, to_binary], end else begin
-- `n` is positive here: peel off the trailing `ff` bit and halve the claim.
rw append_false n (pos_iff_ne_zero.mpr is_zero),
rw double_cardinality,
have := nat.div_add_mod n 2,
exact if is_even: n % 2 = 0 then begin
-- `n = 2 * (n / 2)`: recurse through the even lemma.
rw is_even at this,
norm_num at this,
rw this.symm,
norm_num,
exact have n / 2 < n := to_binary_terminates n (pos_iff_ne_zero.mpr is_zero), is_bijective_even (n / 2),
end
else
begin
-- `n = 2 * (n / 2) + 1`: recurse through the odd lemma.
rw not_even_is_odd n is_even at this,
rw this.symm,
norm_num,
exact have n / 2 < n := to_binary_terminates n (pos_iff_ne_zero.mpr is_zero), is_bijective_odd (n / 2),
end
end,
end
with is_bijective_odd : ∀ n, dp_cardinality (to_binary (2 * n + 1)) ff = 2 * n + 1
| n := begin
-- Peel off the trailing `tt` bit and reduce to the claim for `to_binary n ++ [ff]`.
rw append_true,
rw dp_cardinality_odd,
norm_num,
have := nat.div_add_mod n 2,
exact if is_zero: n = 0 then begin
simp [is_zero, dp_cardinality, to_binary],
end else if is_even: n % 2 = 0 then begin
-- `n = 2 * (n / 2)`: recurse through the even lemma.
rw is_even at this,
norm_num at this,
rw this.symm,
rw double_cardinality,
norm_num,
exact have n / 2 < n := to_binary_terminates n (pos_iff_ne_zero.mpr is_zero), is_bijective_even (n / 2),
end
else
begin
-- `n = 2 * (n / 2) + 1`: recurse through the odd lemma.
rw not_even_is_odd n is_even at this,
rw this.symm,
rw double_cardinality,
norm_num,
exact have n / 2 < n := to_binary_terminates n (pos_iff_ne_zero.mpr is_zero), is_bijective_odd (n / 2),
end
end
/-- In `ff` mode, `dp_cardinality` of the binary digits of `n` is `n`.

Splits on the parity of `n`, rewrites `n` as `2 * (n / 2)` or `2 * (n / 2) + 1`, and
applies `is_bijective_even` or `is_bijective_odd`. -/
lemma is_bijective: ∀ n, dp_cardinality (to_binary n) ff = n := begin
intro n,
-- Division identity for `n` and `2`, specialised below once the remainder is known.
have := nat.div_add_mod n 2,
exact if is_even: n % 2 = 0 then begin
rw is_even at this,
norm_num at this,
rw this.symm,
exact is_bijective_even (n / 2),
end else begin
have is_odd := not_even_is_odd n is_even,
rw is_odd at this,
rw this.symm,
exact is_bijective_odd (n / 2),
end
end
/-- Effect of appending an `ff` bit on `dp_popcount` in `ff` mode: the result is
`dp_cardinality string ff` plus twice `dp_popcount string ff`. -/
lemma helper : ∀ string, dp_popcount (string ++ [ff]) ff = dp_cardinality string ff + 2 * dp_popcount string ff
| [] := rfl
| (tt::rest) := begin
-- Move the cons outside the append so the defining equations apply.
have : tt::rest ++ [ff] = tt::(rest ++ [ff]) := cons_append tt rest [ff],
rw this,
rw dp_popcount,
-- Recursive use of the lemma on the shorter list `rest`.
rw helper rest,
rw double_cardinality,
rw dp_cardinality,
rw dp_popcount,
rw popcount_ignores_string',
rw dp_popcount,
ring,
end
| (ff::rest) := begin
-- Move the cons outside the append so the defining equations apply.
have : ff::rest ++ [ff] = ff::(rest ++ [ff]) := cons_append ff rest [ff],
rw this,
rw dp_popcount,
-- Recursive use of the lemma on the shorter list `rest`.
rw helper rest,
rw dp_cardinality,
rw dp_popcount,
end
/-- Doubling recurrence for `dp`: `dp (2 * n) = n + 2 * dp n`.

For `n = 0` both sides compute to the same value. Otherwise the trailing `ff` bit is
peeled off with `append_false`, `helper` expands the result, and `is_bijective`
identifies the cardinality term with `n`. -/
lemma even_dp_induction (n : ℕ) : dp (2 * n) = n + 2 * dp n :=
if is_zero : n = 0 then
  by { rw is_zero, refl }
else
  by rw [dp, append_false n (pos_iff_ne_zero.mpr is_zero), dp, helper, is_bijective]
/-- Appending an `ff` bit leaves `popcount` unchanged. -/
lemma popcount_even :
  ∀ string, popcount (string ++ [ff]) = popcount string
| []         := rfl
| (tt :: bs) := by simp [popcount, popcount_even]
| (ff :: bs) := by simp [popcount, popcount_even]
/-- Appending a `tt` bit increases `popcount` by one. -/
lemma popcount_odd :
  ∀ string, popcount (string ++ [tt]) = popcount string + 1
| []         := rfl
| (tt :: bs) := by simp [popcount, popcount_odd, add_assoc]
| (ff :: bs) := by simp [popcount, popcount_odd, add_assoc]
/-- Doubling recurrence for `total_popcount`: `total_popcount (2 * n) = n + 2 * total_popcount n`.

This mirrors `even_dp_induction`, which states the same recurrence for `dp`. -/
lemma even_popcount_induction (n : ℕ) : total_popcount (2 * n) = n + 2 * total_popcount n := begin
induction n with n induction_hypothesis,
{
refl,
},
{
rw total_popcount,
-- Write `2 * (n + 1)` as two successors so `total_popcount` can be unfolded twice.
have : 2 * n.succ = 2 * n + 1 + 1 := by ring,
rw this,
rw total_popcount,
rw total_popcount,
rw append_true,
rw induction_hypothesis,
exact if is_zero: n = 0 then begin
rw is_zero,
refl,
end else begin
-- `n` is positive here, so `append_false` applies to `to_binary (2 * n)`.
rw append_false n (pos_iff_ne_zero.mpr is_zero),
-- State the goal after `popcount_even` / `popcount_odd` have removed the appended bits.
suffices : n + 2 * total_popcount n + popcount (to_binary n) + (popcount (to_binary n) + 1) = n.succ + 2 * (total_popcount n + popcount (to_binary n)), from begin
simp [popcount_even, popcount_odd],
assumption,
end,
rw nat.succ_eq_add_one,
ring,
end,
}
end
/-- In `ff` mode, ending the list with `tt` instead of `ff` increases `dp_cardinality` by one.

Despite its name this lemma is about `dp_cardinality`, and its statement is identical to
`dp_cardinality_odd`. It is kept under this name because `dp_popcount_identity` refers to
it, but it is now proved by `dp_cardinality_odd` rather than by a second induction. -/
lemma dp_popcount_odd_add_one :
  ∀ string, dp_cardinality (string ++ [tt]) ff = dp_cardinality (string ++ [ff]) ff + 1 :=
dp_cardinality_odd
/-- In `ff` mode, ending the list with `tt` instead of `ff` increases `dp_popcount` by
`popcount (string ++ [ff])`. -/
lemma dp_popcount_identity : ∀ string, dp_popcount (string ++ [tt]) ff = dp_popcount (string ++ [ff]) ff + popcount (string ++ [ff])
| [] := rfl
| (tt::rest) := begin
-- State the goal in the form `simp` reaches after unfolding and using the recursive hypothesis.
suffices : dp_popcount rest tt + (dp_popcount rest tt + dp_cardinality rest tt) + (dp_popcount (rest ++ [ff]) ff + popcount (rest ++ [ff]) + dp_cardinality (rest ++ [tt]) ff) = dp_popcount rest tt + (dp_popcount rest tt + dp_cardinality rest tt) + (dp_popcount (rest ++ [ff]) ff + dp_cardinality (rest ++ [ff]) ff) + (1 + popcount (rest ++ [ff])), from begin
simp [dp_popcount, dp_popcount_identity, dp_cardinality, popcount, popcount_ignores_string, popcount_ignores_string', cardinality_ignores_string, cardinality_ignores_string', cardinality_ignores_string''],
assumption,
end,
ring,
norm_num,
-- What remains is the `+ 1` relation between the `tt`-terminated and `ff`-terminated cardinalities.
rw dp_popcount_odd_add_one,
end
| (ff::rest) := by simp [dp_popcount, dp_popcount_identity, dp_cardinality, popcount, popcount_ignores_string, popcount_ignores_string', cardinality_ignores_string, cardinality_ignores_string', cardinality_ignores_string'']
/-- Successor step for `dp` at an even argument:
`dp (2 * n + 1) = dp (2 * n) + popcount (to_binary (2 * n))`. -/
lemma odd_dp_induction (n : ℕ) : dp (2 * n + 1) = dp (2 * n) + popcount (to_binary (2 * n)) := begin
rw dp,
rw append_true,
exact if is_zero: n = 0 then begin
rw is_zero, refl,
end else begin
rw dp,
-- `n` is positive here, so `append_false` applies; `dp_popcount_identity` then closes the goal.
rw append_false n (pos_iff_ne_zero.mpr is_zero),
rw dp_popcount_identity,
end,
end
/-- Main result: the digit recurrence `dp` agrees with the specification `total_popcount`.

The proof recurses on `n / 2`. Both parities of `n` are reduced to the claim for `n / 2`
using the doubling recurrences `even_dp_induction` and `even_popcount_induction`; the odd
case first uses `odd_dp_induction` and the definition of `total_popcount`. -/
lemma dp_eq_total_popcount : ∀ n, dp n = total_popcount n
| n := if is_even: n % 2 = 0 then begin
-- Even case: rewrite `n` as `2 * (n / 2)`.
have := nat.div_add_mod n 2,
rw is_even at this,
norm_num at this,
rw this.symm,
rw even_dp_induction (n / 2),
rw even_popcount_induction (n / 2),
norm_num,
exact if is_zero: n = 0 then begin
rw is_zero,
refl,
end
else begin
-- The `have` supplies the decreasing proof for the recursive call on `n / 2`.
exact have n / 2 < n := to_binary_terminates n (pos_iff_ne_zero.mpr is_zero), dp_eq_total_popcount (n / 2),
end
end else begin
-- Odd case: rewrite `n` as `2 * (n / 2) + 1`.
have is_odd := not_even_is_odd n is_even,
have := nat.div_add_mod n 2,
rw is_odd at this,
rw this.symm,
rw odd_dp_induction (n / 2),
rw total_popcount,
norm_num,
rw even_dp_induction (n / 2),
rw even_popcount_induction (n / 2),
norm_num,
exact if is_zero: n = 0 then begin
rw is_zero,
refl,
end
else begin
-- The `have` supplies the decreasing proof for the recursive call on `n / 2`.
exact have n / 2 < n := to_binary_terminates n (pos_iff_ne_zero.mpr is_zero), dp_eq_total_popcount (n / 2),
end
end
Mario Carneiro (Feb 09 2021 at 03:55):
You can use a gist if you want to share code above zulip's file size limit
Huỳnh Trần Khanh (Feb 09 2021 at 05:35):
I'm reluctant to do that—link rot is my perpetual enemy. Thanks for the suggestion though.
Mario Carneiro (Feb 09 2021 at 05:46):
gists don't bitrot though? They aren't very discoverable but they should last as long as github does
Kevin Buzzard (Feb 09 2021 at 07:42):
My experience is that `rw A, rw B, rw C, ...` can be decidedly slower than `rw [A, B, C, ...]`.
Kevin Buzzard (Feb 09 2021 at 09:01):
I have to spend today preparing for class so can't give detailed feedback but just to say that this looks pretty good. `to_binary_terminates` should be called `nat.div_two_lt_of_pos` (to match `div_two_lt_of_pos`) and should be PR'ed (I was surprised it wasn't there, but `library_search` didn't find it). It's a bit disappointing that `omega` doesn't solve this (`omega` takes a while to run, but less time to write!). Here's a shorter proof:
/-- Halving a positive natural number strictly decreases it.

The goal is rewritten as `n / 2 < n ↔ n < n * 2 ↔ n < n + n ↔ 0 < n`, which is `h`. -/
lemma nat.div_two_lt_of_pos (n : ℕ) (h : 0 < n) : n / 2 < n :=
by rwa [nat.div_lt_iff_lt_mul' (zero_lt_two), mul_two, lt_add_iff_pos_left]
which also compiles in about 15ms on my machine -- your proof takes nearly 200ms, probably because of the appeals to `linarith`. In fact let me just run through the `nat` stuff you did:
/-- A number of the form `2 * n + 1` leaves remainder `1` modulo `2`.

The sum is commuted to `1 + 2 * n`, the multiple of `2` is dropped, and `1 % 2 = 1`
holds by computation. -/
lemma last_bit_one (n : ℕ) : (2 * n + 1) % 2 = 1 :=
by { rw [add_comm, nat.add_mul_mod_self_left], refl }
and it should probably be called something like `two_mul_add_one_mod_two` or something :-) `last_bit_zero` shouldn't really be stated like that, `2*n+1+1` should be replaced by `2*n+2` (they are defeq):
/-- A number of the form `2 * n + 2` leaves remainder `0` modulo `2`.

The sum is commuted to `2 + 2 * n`, the multiple of `2` is dropped, and `2 % 2`
is rewritten to `0`. -/
lemma two_mul_add_two_mod_two (n : ℕ) : (2 * n + 2) % 2 = 0 :=
by rw [add_comm, nat.add_mul_mod_self_left, nat.mod_self]
/-- A number of the form `2 * n + 1 + 1` leaves remainder `0` modulo `2`.

Proved directly by `two_mul_add_two_mod_two`: `2 * n + 1 + 1` and `2 * n + 2` are
definitionally equal, so the term type-checks without any rewriting. -/
lemma last_bit_zero (n : ℕ) : (2 * n + 1 + 1) % 2 = 0 :=
two_mul_add_two_mod_two n
In general lemmas should be stated in their "simplest form". Here's a sublemma you needed (which you should really factor out -- two short proofs is better than one big one, this is a style tip):
/-- Dividing `2 * n + 1` by `2` gives `n`. -/
lemma two_mul_add_one_div_two {n : ℕ} : (2 * n + 1) / 2 = n :=
begin
-- Commute to `1 + 2 * n` and split off the multiple of `2`, leaving `1 / 2 + n = n`.
rw [add_comm, nat.add_mul_div_left _ _ (zero_lt_two)],
-- `1 / 2` reduces to `0` by computation, so `zero_add n` matches up to definitional equality.
exact zero_add n,
end
Note now I abuse definitional equality.
Mario Carneiro (Feb 09 2021 at 09:50):
This theorem looks a lot like docs#nat.div2_bit
Last updated: Dec 20 2023 at 11:08 UTC