Zulip Chat Archive

Stream: new members

Topic: Proof review: CSES Counting Bits


Huỳnh Trần Khanh (Feb 09 2021 at 03:47):

This thread is a sequel to this thread. The main lemma I proved is dp_eq_total_popcount.

import data.nat.basic
import data.list.range
import tactic.linarith
import data.nat.modeq
import tactic.slim_check
import data.nat.parity
open list

/--
Counting half of the digit DP over a bit list, consumed head-first.

The `bool` argument is a mode flag:
* `tt` — the end of the list yields `1` and every list element doubles the
  result, whatever its value, so the result is `2 ^ length`.
* `ff` — the end of the list yields `0`; a leading `ff` is skipped; a leading
  `tt` adds the `tt`-mode count of the tail to the `ff`-mode count of the tail.

`is_bijective` (later in this file) states `dp_cardinality (to_binary n) ff = n`.
-/
def dp_cardinality : list bool → bool → ℕ
| [] tt := 1
| [] ff := 0
| (ff::the_rest) ff := dp_cardinality the_rest ff
| (tt::the_rest) ff := dp_cardinality the_rest tt + dp_cardinality the_rest ff
| (_::the_rest) tt := 2 * dp_cardinality the_rest tt

/--
Popcount half of the digit DP; it takes the same bit list and mode flag as
`dp_cardinality`.

* The empty list yields `0` in either mode.
* Mode `ff`, leading `ff`: recurse on the tail in mode `ff`.
* Mode `ff`, leading `tt`: the tail's `tt`-mode popcount, plus the tail's
  `ff`-mode popcount, plus the tail's `ff`-mode cardinality.
* Mode `tt`, any leading bit: twice the tail's `tt`-mode popcount plus the
  tail's `tt`-mode cardinality.
-/
def dp_popcount : list bool → bool → ℕ
| [] _ := 0
| (ff::the_rest) ff := dp_popcount the_rest ff
| (tt::the_rest) ff := dp_popcount the_rest tt + (dp_popcount the_rest ff + dp_cardinality the_rest ff)
| (_::the_rest) tt := dp_popcount the_rest tt + (dp_popcount the_rest tt + dp_cardinality the_rest tt)

/--
Halving a positive natural number strictly decreases it.

This is the decreasing-measure fact supplied to the equation compiler by
`to_binary` and by the later proofs that recurse on `n / 2`.
-/
lemma to_binary_terminates (n : ℕ) (h : 0 < n) : n / 2 < n := begin
  -- `n / 2 ≤ n` always holds; split it into the strict case and the equality case.
  have := lt_or_eq_of_le (nat.div_le_self n 2),
  cases this,
  {
    assumption,
  },
  {
    -- Equality case: `nat.div_eq_self` turns `n / 2 = n` into a disjunction and
    -- `linarith` refutes each alternative.
    exfalso,
    have := (nat.div_eq_self).mp this,
    cases this,
    {
      linarith,
    },
    {
      linarith,
    }
  }
end

/--
Binary digits of a natural number as a `list bool`.

`0` maps to the empty list. For `n + 1` the digits of `(n + 1) / 2` come first
and the parity bit of `n + 1` is appended at the end (`ff` when even, `tt` when
odd), so the least significant bit is the last element.
-/
def to_binary : ℕ → list bool
| 0 := []
| (n + 1) :=
  -- Termination hint for the equation compiler: the recursive argument shrinks.
  have (n + 1) / 2 < (n + 1) := to_binary_terminates (n + 1) (nat.succ_pos n),
  if (n + 1) % 2 = 0
  then to_binary ((n + 1) / 2) ++ [ff]
  else to_binary ((n + 1) / 2) ++ [tt]

/-- An odd number `2 * n + 1` has remainder `1` modulo `2`. -/
lemma last_bit_one (n : ℕ) : (2 * n + 1) % 2 = 1 := begin
  -- Start from a congruence modulo 2 for `2 * n`, simplified by `2 * n % 2 = 0`.
  have initial := (nat.modeq.mod_modeq (2 * n) 2).symm,
  rw nat.mul_mod_right 2 n at initial,
  -- Add `1` to both sides of the congruence, then unfold it to an equation on `%`.
  have := nat.modeq.modeq_add initial (nat.modeq.refl 1),
  norm_num at this,
  rw nat.modeq at this,
  norm_num at this,
  assumption,
end

/-- The successor of an odd number, `2 * n + 1 + 1`, has remainder `0` modulo `2`. -/
lemma last_bit_zero (n : ℕ) : (2 * n + 1 + 1) % 2 = 0 := begin
  -- Congruence modulo 2 for `2 * n + 1`, simplified with `last_bit_one`.
  have initial := (nat.modeq.mod_modeq (2 * n + 1) 2).symm,
  rw last_bit_one n at initial,
  -- Add `1` to both sides and unfold the congruence to an equation on `%`.
  have := nat.modeq.modeq_add initial (nat.modeq.refl 1),
  rw nat.modeq at this,
  have eq_zero : (1 + 1) % 2 = 0 := rfl,
  rw eq_zero at this,
  assumption,
end

/-- The digits of `2 * n + 1` are the digits of `n` followed by a `tt` bit. -/
lemma append_true : ∀ n, to_binary (2 * n + 1) = to_binary n ++ [tt] := begin
  intro n,
  rw to_binary,
  split_ifs,
  {
    -- The "even" branch is impossible because `(2 * n + 1) % 2 = 1`.
    have := last_bit_one n,
    linarith,
  },
  {
    -- Odd branch: it remains to identify `(2 * n + 1) / 2` with `n`.
    have : (2 * n + 1) / 2 = n := begin
      have := nat.div_add_mod (2 * n + 1) 2,
      rw last_bit_one at this,
      have := nat.succ.inj this,
      have := nat.eq_of_mul_eq_mul_left zero_lt_two this,
      assumption,
    end,

    rw this,
  }
end

/--
For positive `n`, the digits of `2 * n` are the digits of `n` followed by an
`ff` bit. The positivity hypothesis is part of the statement; the `n = 0` case
is discharged from it.
-/
lemma append_false : ∀ n, 0 < n → to_binary (2 * n) = to_binary n ++ [ff] := begin
  intro n,
  intro guard,
  cases n,
  {
    -- `guard : 0 < 0` is contradictory.
    linarith,
  },
  {
    -- Put the argument in successor form so the `n + 1` equation of `to_binary` applies.
    have : 2 * (n + 1) = (2 * n + 1) + 1 := by linarith,
    rw this,
    rw to_binary,
    split_ifs,
    {
      -- Even branch: it remains to identify `(2 * n + 1 + 1) / 2` with `n + 1`.
      have : (2 * n + 1 + 1) / 2 = n + 1 := begin
        have := nat.div_add_mod (2 * n + 1 + 1) 2,
        rw last_bit_zero at this,
        norm_num at this,
        have distribute : 2 * n + 1 + 1 = 2 * (n + 1) := by ring,
        rw distribute at this,
        rw distribute,
        have := nat.eq_of_mul_eq_mul_left zero_lt_two this,
        assumption,
      end,

      rw this,
    },
    {
      -- The "odd" branch contradicts `last_bit_zero`; `h` is the hypothesis
      -- introduced by `split_ifs`.
      have := last_bit_zero n,
      exfalso,
      exact h this,
    }
  }
end

/--
The digit-DP answer for `n`: `dp_popcount` run on the binary digits of `n` in
mode `ff`. `dp_eq_total_popcount` relates it to `total_popcount n`.
-/
def dp (n : ℕ) := dp_popcount (to_binary n) ff

/-- Number of `tt` entries in a bit list. -/
def popcount : list bool → ℕ
| [] := 0
| (tt::the_rest) := 1 + popcount the_rest
| (ff::the_rest) := popcount the_rest

/--
Reference specification: `total_popcount n` is the sum of
`popcount (to_binary k)` over all `k < n`. Each successor step adds the
popcount of the previous number, starting from `0`.
-/
def total_popcount : ℕ → ℕ
| 0 := 0
| (n + 1) := total_popcount n + popcount (to_binary n)

/-- In mode `tt`, `dp_cardinality` gives the same result whether the last bit is `ff` or `tt`. -/
lemma cardinality_ignores_string : ∀ string, dp_cardinality (string ++ [ff]) tt = dp_cardinality (string ++ [tt]) tt
| [] := rfl
| (tt::rest) := by simp [dp_cardinality, cardinality_ignores_string]
| (ff::rest) := by simp [dp_cardinality, cardinality_ignores_string]

/--
In mode `tt`, appending an `ff` bit at the end gives the same `dp_cardinality`
as prepending a `tt` bit at the front.
-/
lemma cardinality_ignores_string' : ∀ string, dp_cardinality (string ++ [ff]) tt = dp_cardinality (tt::string) tt
| [] := rfl
| (tt::rest) := by simp [dp_cardinality, cardinality_ignores_string']
| (ff::rest) := by simp [dp_cardinality, cardinality_ignores_string']

/--
In mode `tt`, appending a `tt` bit at the end gives the same `dp_cardinality`
as prepending a `tt` bit at the front.
-/
lemma cardinality_ignores_string'' : ∀ string, dp_cardinality (string ++ [tt]) tt = dp_cardinality (tt::string) tt
| [] := rfl
| (tt::rest) := by simp [dp_cardinality, cardinality_ignores_string'']
| (ff::rest) := by simp [dp_cardinality, cardinality_ignores_string'']

/--
In mode `tt`, `dp_popcount` gives the same result whether the last bit is `tt`
or `ff`. Despite the name, this is a statement about `dp_popcount`, not about
`popcount`.
-/
lemma popcount_ignores_string : ∀ string, dp_popcount (string ++ [tt]) tt = dp_popcount (string ++ [ff]) tt
| [] := rfl
| (tt::rest) := by simp [dp_popcount, popcount_ignores_string, cardinality_ignores_string]
| (ff::rest) := by simp [dp_popcount, popcount_ignores_string, cardinality_ignores_string]

/--
In mode `tt`, appending an `ff` bit at the end gives the same `dp_popcount` as
prepending a `tt` bit at the front.
-/
lemma popcount_ignores_string' : ∀ string, dp_popcount (string ++ [ff]) tt = dp_popcount (tt::string) tt
| [] := rfl
| (tt::rest) := by simp [dp_popcount, popcount_ignores_string', cardinality_ignores_string', cardinality_ignores_string, cardinality_ignores_string'', dp_cardinality]
| (ff::rest) := by simp [dp_popcount, popcount_ignores_string', cardinality_ignores_string', cardinality_ignores_string, cardinality_ignores_string'', dp_cardinality]

/-- In mode `ff`, appending an `ff` bit doubles `dp_cardinality`. -/
lemma double_cardinality : ∀ string, dp_cardinality (string ++ [ff]) ff = 2 * dp_cardinality string ff
| [] := rfl
| (tt::rest) := begin
  rw dp_cardinality,
  -- Expose the head of the appended list so the defining equations can fire.
  have : tt::rest ++ [ff] = tt::(rest ++ [ff]) := cons_append tt rest [ff],
  rw this,
  rw dp_cardinality,
  -- Recursive call on the tail.
  rw double_cardinality rest,
  rw cardinality_ignores_string',
  rw dp_cardinality,
  ring,
end
| (ff::rest) := begin
  rw dp_cardinality,
  -- Expose the head of the appended list so the defining equations can fire.
  have : ff::rest ++ [ff] = ff::(rest ++ [ff]) := cons_append ff rest [ff],
  rw this,
  rw dp_cardinality,
  -- Recursive call on the tail.
  rw double_cardinality rest,
end

/--
In mode `ff`, ending the list with `tt` instead of `ff` increases
`dp_cardinality` by exactly one.
-/
lemma dp_cardinality_odd : ∀ string, dp_cardinality (string ++ [tt]) ff = dp_cardinality (string ++ [ff]) ff + 1
| [] := rfl
| (tt::rest) := begin
  simp [dp_cardinality, dp_cardinality_odd, cardinality_ignores_string, cardinality_ignores_string', cardinality_ignores_string''],
  ring,
end
| (ff::rest) := begin
  simp [dp_cardinality, dp_cardinality_odd, cardinality_ignores_string, cardinality_ignores_string', cardinality_ignores_string''],
end

(cont'd)

Huỳnh Trần Khanh (Feb 09 2021 at 03:47):

/-- A natural number whose remainder modulo `2` is not `0` has remainder `1`. -/
lemma not_even_is_odd (n : ℕ) (h : ¬(n % 2 = 0)) : n % 2 = 1 := begin
  -- Goal: `n % 2 = 1` ⇝ `odd n` ⇝ `¬ even n` ⇝ `¬ n % 2 = 0`, which is `h`.
  rwa [← nat.odd_iff, nat.odd_iff_not_even, nat.even_iff],
end

/-
Mode-`ff` `dp_cardinality` of the digits of `2 * n` is `2 * n`, and of the
digits of `2 * n + 1` is `2 * n + 1`.

The two statements are proved by mutual recursion: each case strips the last
bit and then appeals to whichever of the two lemmas matches the parity of `n`,
applied to `n / 2`. The `have n / 2 < n` terms are the termination hints.
-/
mutual lemma is_bijective_even, is_bijective_odd
with is_bijective_even : ∀ n, dp_cardinality (to_binary (2 * n)) ff = 2 * n
| n := begin
  exact if is_zero: n = 0 then begin simp [is_zero, dp_cardinality, to_binary], end else begin
    -- Strip the trailing `ff` bit of `2 * n`.
    rw append_false n (pos_iff_ne_zero.mpr is_zero),
    rw double_cardinality,
    have := nat.div_add_mod n 2,
    exact if is_even: n % 2 = 0 then begin
      -- `n` is even: rewrite `n` as `2 * (n / 2)` and recurse on the even lemma.
      rw is_even at this,
      norm_num at this,
      rw this.symm,
      norm_num,
      exact have n / 2 < n := to_binary_terminates n (pos_iff_ne_zero.mpr is_zero), is_bijective_even (n / 2),
    end
    else
    begin
      -- `n` is odd: rewrite `n` as `2 * (n / 2) + 1` and recurse on the odd lemma.
      rw not_even_is_odd n is_even at this,
      rw this.symm,
      norm_num,
      exact have n / 2 < n := to_binary_terminates n (pos_iff_ne_zero.mpr is_zero), is_bijective_odd (n / 2),
    end
  end,
end
with is_bijective_odd : ∀ n, dp_cardinality (to_binary (2 * n + 1)) ff = 2 * n + 1
| n := begin
  -- Strip the trailing `tt` bit of `2 * n + 1`.
  rw append_true,
  rw dp_cardinality_odd,
  norm_num,
  have := nat.div_add_mod n 2,
  exact if is_zero: n = 0 then begin
    simp [is_zero, dp_cardinality, to_binary],
  end else if is_even: n % 2 = 0 then begin
    -- `n` is even: rewrite `n` as `2 * (n / 2)` and recurse on the even lemma.
    rw is_even at this,
    norm_num at this,
    rw this.symm,
    rw double_cardinality,
    norm_num,
    exact have n / 2 < n := to_binary_terminates n (pos_iff_ne_zero.mpr is_zero), is_bijective_even (n / 2),
  end
  else
  begin
    -- `n` is odd: rewrite `n` as `2 * (n / 2) + 1` and recurse on the odd lemma.
    rw not_even_is_odd n is_even at this,
    rw this.symm,
    rw double_cardinality,
    norm_num,
    exact have n / 2 < n := to_binary_terminates n (pos_iff_ne_zero.mpr is_zero), is_bijective_odd (n / 2),
  end
end

/--
Mode-`ff` `dp_cardinality` of the digits of `n` is `n` itself. Obtained by a
parity split from `is_bijective_even` and `is_bijective_odd`.
-/
lemma is_bijective: ∀ n, dp_cardinality (to_binary n) ff = n := begin
  intro n,
  have := nat.div_add_mod n 2,
  exact if is_even: n % 2 = 0 then begin
    -- Even: `n = 2 * (n / 2)`.
    rw is_even at this,
    norm_num at this,
    rw this.symm,
    exact is_bijective_even (n / 2),
  end else begin
    -- Odd: `n = 2 * (n / 2) + 1`.
    have is_odd := not_even_is_odd n is_even,
    rw is_odd at this,
    rw this.symm,
    exact is_bijective_odd (n / 2),
  end
end

/--
Effect of appending an `ff` bit on mode-`ff` `dp_popcount`: the result is the
mode-`ff` cardinality of the original list plus twice its mode-`ff` popcount.
-/
lemma helper : ∀ string, dp_popcount (string ++ [ff]) ff = dp_cardinality string ff + 2 * dp_popcount string ff
| [] := rfl
| (tt::rest) := begin
  -- Expose the head of the appended list so the defining equations can fire.
  have : tt::rest ++ [ff] = tt::(rest ++ [ff]) := cons_append tt rest [ff],
  rw this,
  rw dp_popcount,
  -- Recursive call on the tail.
  rw helper rest,
  rw double_cardinality,
  rw dp_cardinality,
  rw dp_popcount,
  rw popcount_ignores_string',
  rw dp_popcount,
  ring,
end
| (ff::rest) := begin
  -- Expose the head of the appended list so the defining equations can fire.
  have : ff::rest ++ [ff] = ff::(rest ++ [ff]) := cons_append ff rest [ff],
  rw this,
  rw dp_popcount,
  -- Recursive call on the tail.
  rw helper rest,
  rw dp_cardinality,
  rw dp_popcount,
end

/-- Doubling recurrence for `dp`: `dp (2 * n) = n + 2 * dp n`. -/
lemma even_dp_induction (n : ℕ) : dp (2 * n) = n + 2 * dp n := if is_zero: n = 0 then begin
  rw is_zero,
  refl,
end else begin
  rw dp,
  -- `append_false` needs `0 < n`, which is why `n = 0` is handled separately.
  rw append_false n (pos_iff_ne_zero.mpr is_zero),
  rw dp,
  rw helper,
  rw is_bijective,
end

/-- Appending an `ff` bit does not change `popcount`. -/
lemma popcount_even : ∀ string, popcount (string ++ [ff]) = popcount string
| [] := rfl
| (tt::rest) := by simp [popcount, popcount_even]
| (ff::rest) := by simp [popcount, popcount_even]

/-- Appending a `tt` bit increases `popcount` by one. -/
lemma popcount_odd : ∀ string, popcount (string ++ [tt]) = popcount string + 1
| [] := rfl
| (tt::rest) := by simp [popcount, popcount_odd, add_assoc]
| (ff::rest) := by simp [popcount, popcount_odd, add_assoc]

/--
Doubling recurrence for the specification:
`total_popcount (2 * n) = n + 2 * total_popcount n`. It has the same shape as
`even_dp_induction`.
-/
lemma even_popcount_induction (n : ℕ) : total_popcount (2 * n) = n + 2 * total_popcount n := begin
  induction n with n induction_hypothesis,
  {
    refl,
  },
  {
    rw total_popcount,
    -- Write `2 * (n + 1)` as two successor steps so `total_popcount` unfolds twice.
    have : 2 * n.succ = 2 * n + 1 + 1 := by ring,
    rw this,
    rw total_popcount,
    rw total_popcount,
    rw append_true,
    rw induction_hypothesis,
    exact if is_zero: n = 0 then begin
      rw is_zero,
      refl,
    end else begin
      -- `append_false` needs `0 < n`, which is why `n = 0` is handled separately.
      rw append_false n (pos_iff_ne_zero.mpr is_zero),
      suffices : n + 2 * total_popcount n + popcount (to_binary n) + (popcount (to_binary n) + 1) = n.succ + 2 * (total_popcount n + popcount (to_binary n)), from begin
        simp [popcount_even, popcount_odd],
        assumption,
      end,
      rw nat.succ_eq_add_one,
      ring,
    end,
  }
end

/--
In mode `ff`, ending the list with `tt` instead of `ff` increases
`dp_cardinality` by exactly one.

NOTE(review): despite the name this is a statement about `dp_cardinality`, and
it is word-for-word the statement of `dp_cardinality_odd` earlier in this file;
the two could be merged.
-/
lemma dp_popcount_odd_add_one: ∀ string, dp_cardinality (string ++ [tt]) ff = dp_cardinality (string ++ [ff]) ff + 1
| [] := rfl
| (tt::rest) := begin
  suffices : dp_cardinality (rest ++ [tt]) tt + (dp_cardinality (rest ++ [ff]) ff + 1) = dp_cardinality (rest ++ [ff]) tt + dp_cardinality (rest ++ [ff]) ff + 1, from begin
    simp [dp_popcount_odd_add_one, dp_cardinality],
    assumption,
  end,
  rw cardinality_ignores_string,
  simp [add_assoc],
end
| (ff::rest) := begin
  suffices : dp_cardinality (rest ++ [tt]) tt + (dp_cardinality (rest ++ [ff]) ff + 1) = dp_cardinality (rest ++ [ff]) tt + dp_cardinality (rest ++ [ff]) ff + 1, from begin
    simp [dp_popcount_odd_add_one, dp_cardinality],
  end,
  rw cardinality_ignores_string,
  simp [add_assoc],
end

/--
In mode `ff`, ending the list with `tt` instead of `ff` increases `dp_popcount`
by the `popcount` of the `ff`-terminated list. `odd_dp_induction` uses this to
step from `2 * n` to `2 * n + 1`.
-/
lemma dp_popcount_identity : ∀ string, dp_popcount (string ++ [tt]) ff = dp_popcount (string ++ [ff]) ff + popcount (string ++ [ff])
| [] := rfl
| (tt::rest) := begin
  -- State the goal in the shape `simp` produces, then prove that shape directly.
  suffices : dp_popcount rest tt + (dp_popcount rest tt + dp_cardinality rest tt) + (dp_popcount (rest ++ [ff]) ff + popcount (rest ++ [ff]) + dp_cardinality (rest ++ [tt]) ff) = dp_popcount rest tt + (dp_popcount rest tt + dp_cardinality rest tt) + (dp_popcount (rest ++ [ff]) ff + dp_cardinality (rest ++ [ff]) ff) + (1 + popcount (rest ++ [ff])), from begin
    simp [dp_popcount, dp_popcount_identity, dp_cardinality, popcount, popcount_ignores_string, popcount_ignores_string', cardinality_ignores_string, cardinality_ignores_string', cardinality_ignores_string''],
    assumption,
  end,
  ring,
  norm_num,
  rw dp_popcount_odd_add_one,
end
| (ff::rest) := by simp [dp_popcount, dp_popcount_identity, dp_cardinality, popcount, popcount_ignores_string, popcount_ignores_string', cardinality_ignores_string, cardinality_ignores_string', cardinality_ignores_string'']

/--
Successor step for `dp` at an even argument:
`dp (2 * n + 1) = dp (2 * n) + popcount (to_binary (2 * n))`. It has the same
shape as the successor equation of `total_popcount`.
-/
lemma odd_dp_induction (n : ℕ) : dp (2 * n + 1) = dp (2 * n) + popcount (to_binary (2 * n)) := begin
  rw dp,
  rw append_true,
  exact if is_zero: n = 0 then begin
    rw is_zero, refl,
  end else begin
    rw dp,
    -- `append_false` needs `0 < n`, which is why `n = 0` is handled separately.
    rw append_false n (pos_iff_ne_zero.mpr is_zero),
    rw dp_popcount_identity,
  end,
end

/--
Main result: the digit DP agrees with the specification, `dp n = total_popcount n`.

The proof recurses on `n / 2`. Both sides are rewritten with their doubling
recurrences (`even_dp_induction`, `even_popcount_induction`); in the odd case
the successor step `odd_dp_induction` is applied first. The `have n / 2 < n`
terms are the termination hints.
-/
lemma dp_eq_total_popcount : ∀ n, dp n = total_popcount n
| n := if is_even: n % 2 = 0 then begin
    -- Even case: rewrite `n` as `2 * (n / 2)`.
    have := nat.div_add_mod n 2,
    rw is_even at this,
    norm_num at this,
    rw this.symm,
    rw even_dp_induction (n / 2),
    rw even_popcount_induction (n / 2),
    norm_num,
    exact if is_zero: n = 0 then begin
      rw is_zero,
      refl,
    end
    else begin
      exact have n / 2 < n := to_binary_terminates n (pos_iff_ne_zero.mpr is_zero), dp_eq_total_popcount (n / 2),
    end
  end else begin
    -- Odd case: rewrite `n` as `2 * (n / 2) + 1` and peel off the successor step.
    have is_odd := not_even_is_odd n is_even,
    have := nat.div_add_mod n 2,
    rw is_odd at this,
    rw this.symm,
    rw odd_dp_induction (n / 2),
    rw total_popcount,
    norm_num,

    rw even_dp_induction (n / 2),
    rw even_popcount_induction (n / 2),
    norm_num,
    exact if is_zero: n = 0 then begin
      rw is_zero,
      refl,
    end
    else begin
      exact have n / 2 < n := to_binary_terminates n (pos_iff_ne_zero.mpr is_zero), dp_eq_total_popcount (n / 2),
    end
  end

Mario Carneiro (Feb 09 2021 at 03:55):

You can use a gist if you want to share code above zulip's file size limit

Huỳnh Trần Khanh (Feb 09 2021 at 05:35):

I'm reluctant to do that—link rot is my perpetual enemy. Thanks for the suggestion though.

Mario Carneiro (Feb 09 2021 at 05:46):

gists don't bitrot though? They aren't very discoverable but they should last as long as github does

Kevin Buzzard (Feb 09 2021 at 07:42):

My experience is that rw A, rw B, rw C,... can be decidedly slower than rw [A, B, C, ...]

Kevin Buzzard (Feb 09 2021 at 09:01):

I have to spend today preparing for class so can't give detailed feedback but just to say that this looks pretty good. to_binary_terminates should be called nat.div_two_lt_of_pos (to match div_two_lt_of_pos) and should be PR'ed (I was surprised it wasn't there, but library_search didn't find it). It's a bit disappointing that omega doesn't solve this (omega takes a while to run, but less time to write!). Here's a shorter proof:

/-- Halving a positive natural number strictly decreases it (shorter proof of `to_binary_terminates`). -/
lemma nat.div_two_lt_of_pos (n : ℕ) (h : 0 < n) : n / 2 < n :=
by rwa [nat.div_lt_iff_lt_mul' (zero_lt_two), mul_two, lt_add_iff_pos_left]

which also compiles in about 15ms on my machine -- your proof takes nearly 200ms, probably because of the appeals to linarith. In fact let me just run through the nat stuff you did:

/-- An odd number `2 * n + 1` has remainder `1` modulo `2` (shorter proof). -/
lemma last_bit_one (n : ℕ) : (2 * n + 1) % 2 = 1 := begin
  -- Commute to `1 + 2 * n`, drop the multiple of 2, and finish `1 % 2 = 1` by `refl`.
  rw [add_comm, nat.add_mul_mod_self_left],
  refl,
end

and it should probably be called something like two_mul_add_one_mod_two or something :-) last_bit_zero shouldn't really be stated like that, 2*n+1+1 should be replaced by 2*n+2 (they are defeq):

/-- `2 * n + 2` has remainder `0` modulo `2`. -/
lemma two_mul_add_two_mod_two (n : ℕ) : (2 * n + 2) % 2 = 0 :=
begin
  -- Commute to `2 + 2 * n`, drop the multiple of 2, then use `2 % 2 = 0`.
  rw [add_comm, nat.add_mul_mod_self_left, nat.mod_self],
end

/--
`2 * n + 1 + 1` has remainder `0` modulo `2`. The term `two_mul_add_two_mod_two n`
is accepted as a proof because `2 * n + 1 + 1` and `2 * n + 2` are
definitionally equal.
-/
lemma last_bit_zero (n : ℕ) : (2 * n + 1 + 1) % 2 = 0 :=
two_mul_add_two_mod_two n

In general lemmas should be stated in their "simplest form". Here's a sublemma you needed (which you should really factor out -- two short proofs is better than one big one, this is a style tip):

/-- Halving an odd number `2 * n + 1` yields `n`. -/
lemma two_mul_add_one_div_two {n : ℕ} : (2 * n + 1) / 2 = n :=
begin
  -- Commute to `1 + 2 * n` and split off the multiple of 2 from the division.
  rw [add_comm, nat.add_mul_div_left _ _ (zero_lt_two)],
  -- The remaining goal is closed by `zero_add n` up to definitional equality.
  exact zero_add n,
end

Note now I abuse definitional equality.

Mario Carneiro (Feb 09 2021 at 09:50):

This theorem looks a lot like docs#nat.div2_bit


Last updated: Dec 20 2023 at 11:08 UTC