Zulip Chat Archive

Stream: lean4

Topic: Strange UInt32 behavior on coercion to Nat


Amitayush Thakur (Mar 07 2025 at 08:51):

import Mathlib
import Mathlib.Data.UInt


instance : Coe UInt32 Nat where
  coe u := UInt32.toNat u

def temp := ((UInt32.ofNat 4294967295) + (UInt32.ofNat 1))
#eval                        temp
-- Prints 0 (expected)
#eval       ((UInt32.ofNat 4294967295) + (UInt32.ofNat 1))
-- Prints 0 (expected)
#eval                      (temp : Nat)
-- Prints 0 (expected)
#eval       (((UInt32.ofNat 4294967295) + (UInt32.ofNat 1)) : Nat)
-- Prints 4294967296 (NOT expected!!)

theorem add_uint32_coe_bound (a b : UInt32)
  :((a + b) : Nat) = (a : Nat) + (b : Nat) := by
  simp

I'm having difficulty understanding why the behavior is different when we create a temporary definition.

Amitayush Thakur (Mar 07 2025 at 08:53):

IMO the theorem should not compile. But maybe I'm understanding something differently.

Markus Himmel (Mar 07 2025 at 09:05):

The behavior you're observing is due to how the elaborator inserts coercions. The key point is this: coercions are one of the last things the elaborator tries to make something typecheck. The notation (term : type) does not necessarily correspond to insertion of a coercion, but rather just specifies the expected type.

So when the elaborator sees (((UInt32.ofNat 4294967295) + (UInt32.ofNat 1)) : Nat), it will try to make sense of the term in a way where the resulting type is Nat. So what it does is look for meanings of + that return a Nat, and it finds the addition on natural numbers, so now it wants the arguments to + to be natural numbers. Then it will try to make sense of (UInt32.ofNat 4294967295) in a way where the resulting type is a natural number, and only now does it turn to the coercion that you have defined, so it turns the term into (UInt32.ofNat 4294967295).toNat, and the entire term comes out to (UInt32.ofNat 4294967295).toNat + (UInt32.ofNat 1).toNat.

This also explains why making the temporary definition makes a difference. Here, the type of temp is fixed to UInt32, and so the only thing the elaborator can do is create temp.toNat.
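
To make the two readings concrete, the terms described above can be written out by hand with explicit .toNat calls, so that no coercion instance is involved. This is only a sketch of the shape the elaborator arrives at in each case, not its literal output:

```lean
-- Expected type Nat reaches the `+`: each argument is converted first,
-- then the sum is taken in Nat, so nothing wraps around.
#eval (UInt32.ofNat 4294967295).toNat + (UInt32.ofNat 1).toNat
-- 4294967296

-- Type already fixed to UInt32 (as with `temp`): the sum is taken in
-- UInt32, wraps to 0, and only then is converted to Nat.
#eval ((UInt32.ofNat 4294967295) + (UInt32.ofNat 1)).toNat
-- 0
```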

Markus Himmel (Mar 07 2025 at 09:11):

You can prevent the expected type from affecting elaboration using the syntax (term :). So

#eval       (((UInt32.ofNat 4294967295) + (UInt32.ofNat 1) :) : Nat)

will print 0.
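
A small side-by-side sketch of the options, assuming the Coe instance from the original question (repeated here so the snippet can be run on its own in a fresh file). The explicit (term : UInt32) ascription in the last line is an alternative not mentioned above; it pins the inner type directly instead of merely blocking the expected type:

```lean
instance : Coe UInt32 Nat where
  coe u := UInt32.toNat u

-- Expected type Nat is pushed into the sum: addition happens in Nat.
#eval (((UInt32.ofNat 4294967295) + (UInt32.ofNat 1)) : Nat)
-- 4294967296

-- `(term :)` elaborates the sum without an expected type, so it is a
-- UInt32 sum; the coercion is applied to the wrapped result.
#eval (((UInt32.ofNat 4294967295) + (UInt32.ofNat 1) :) : Nat)
-- 0

-- Pinning the inner type explicitly has the same effect here.
#eval (((UInt32.ofNat 4294967295) + (UInt32.ofNat 1) : UInt32) : Nat)
-- 0
```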

Amitayush Thakur (Mar 07 2025 at 09:17):

Thank you so much for the help.
I actually wanted to state something like:

theorem add_uint32_coe_bound (a b : UInt32)
  (h_a_b: (a : Nat) + (b : Nat) < 4294967296)
  : ((a + b) : Nat) = (a : Nat) + (b : Nat) :=

But I observed that I was able to prove it even without the h_a_b assumption.
However, when I change to:

 ((a + b:) : Nat) = (a : Nat) + (b : Nat)

then the simp proof alone no longer works; I had to do something like:

theorem add_uint32_coe_bound (a b : UInt32)
  (h_a_b: (a : Nat) + (b : Nat) < 4294967296)
  :((a + b:) : Nat) = (a : Nat) + (b : Nat) := by
  simp
  assumption

Do you see any potential problem using this further to prove more things about programs that use UInt32 instead of Nat?

Markus Himmel (Mar 07 2025 at 09:24):

I'm not sure how much good the coercion is doing you here. I don't know what your actual application looks like, but at least the lemma actually looks clearer to me without the coercion:

theorem add_uint32_coe_bound (a b : UInt32)
    (h_a_b : a.toNat + b.toNat < 4294967296) :
    (a + b).toNat = a.toNat + b.toNat := by
  simp
  assumption
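
The hypothesis is needed because UInt32 addition wraps modulo 2^32 = 4294967296. A quick check with arbitrarily chosen values (purely illustrative, not taken from the discussion) shows the two sides of the equation diverging once the bound is violated:

```lean
-- 4000000000 + 1000000000 = 5000000000, which exceeds 2^32 - 1.
#eval ((4000000000 : UInt32) + 1000000000).toNat
-- 705032704  (the UInt32 sum has wrapped around)

#eval (4000000000 : UInt32).toNat + (1000000000 : UInt32).toNat
-- 5000000000  (the Nat sum has not)

-- The wrapped value is the Nat sum reduced modulo 2^32.
#eval ((4000000000 : UInt32).toNat + (1000000000 : UInt32).toNat) % 4294967296
-- 705032704
```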

Amitayush Thakur (Mar 07 2025 at 09:26):

I guess you are right,
I wanted to prove something like:

theorem add_uint32_coe_bound (a b : UInt32)
  (h_a_b: (a : Nat) + (b : Nat) < 4294967296)
  :((a + b:) : Nat) = (a : Nat) + (b : Nat) := by
  simp
  assumption

def double_uint32 (a: UInt32): UInt32
:= a + a

theorem double_uint32_coe_bound
(n: Nat)
(h_n: n < 2147483648)
: (double_uint32 (UInt32.ofNat n)).toNat = 2*n := by
  rw [double_uint32]
  rw [add_uint32_coe_bound (UInt32.ofNat n) (UInt32.ofNat n) _]
  simp
  rw [Nat.mod_eq_of_lt]
  linarith
  linarith
  simp
  rw [Nat.mod_eq_of_lt]
  linarith
  linarith

Amitayush Thakur (Mar 07 2025 at 09:30):

But you are right: I tested it, and the same thing can be done without the coercion, so maybe defining it is not that useful.

theorem add_uint32_coe_bound1 (a b : UInt32)
    (h_a_b : a.toNat + b.toNat < 4294967296) :
    (a + b).toNat = a.toNat + b.toNat := by
  simp
  assumption

def double_uint32 (a: UInt32): UInt32
:= a + a

theorem double_uint32_coe_bound
(n: Nat)
(h_n: n < 2147483648)
: (double_uint32 (UInt32.ofNat n)).toNat = 2*n := by
  rw [double_uint32]
  rw [add_uint32_coe_bound1 (UInt32.ofNat n) (UInt32.ofNat n) _]
  simp
  rw [Nat.mod_eq_of_lt]
  linarith
  linarith
  simp
  rw [Nat.mod_eq_of_lt]
  linarith
  linarith

This works perfectly. However, I guess I wanted some nice syntactic sugar instead of writing .toNat every time. I didn't know the behavior changes with just a : added.

Robin Arnez (Mar 07 2025 at 12:19):

I might also note that this works too:

def doubleUInt32 (a : UInt32) : UInt32 := a + a

theorem double_uint32_coe_bound (n : Nat) (h_n : n < 2147483648) : (doubleUInt32 (UInt32.ofNat n)).toNat = 2*n := by
  unfold doubleUInt32
  simp
  omega
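
As a sanity check on the bound n < 2147483648 (that is, 2^31), the definition can be evaluated at the edge of the range. This is a sketch for a fresh file; the definition is repeated only so that the snippet stands on its own:

```lean
def doubleUInt32 (a : UInt32) : UInt32 := a + a

-- Largest n allowed by the hypothesis: the result is still 2 * n.
#eval (doubleUInt32 (UInt32.ofNat 2147483647)).toNat
-- 4294967294

-- First n outside the hypothesis: 2 * n = 2^32 wraps to 0.
#eval (doubleUInt32 (UInt32.ofNat 2147483648)).toNat
-- 0
```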

Last updated: May 02 2025 at 03:31 UTC