Zulip Chat Archive
Stream: new members
Topic: help on proof involving Vector.get/set
Angelina Chen (Jul 30 2025 at 02:26):
I'm trying to prove a simple lemma about our cache coherence protocol: when a shim processes a write instruction to an address, it increments the timestamp for that address. This should follow directly from the way shimWrite is written (see below, after the type and step function definitions), but since this is the first proof I've written for this project, I'm having trouble getting Lean to simplify all the vector operations in my functions.
For context, here are the structures/definitions that I'm working with:
structure ShimElemState : Type where
  data : Data
  ts : Timestamp
  -- [etc. remaining fields omitted]

def ShimCache (c : SystemConfig) : Type := Vector ShimElemState c.addrCount

structure Shim (c : SystemConfig) : Type where
  state : ShimCache c
  -- [etc.]

def ShimType (c : SystemConfig) : Type := Vector (Shim c) c.threads
Here is the step function that I use in the lemma:
inductive increment_step {c : SystemConfig} : IncState c → IncState c → Prop where
  | ProcessInstr : forall (s s' : IncState c) (shim : ShimId c),
      canIssueInstr shim s →
      s' = getAndIssueInstr shim s →
      increment_step s s'
  -- | [etc]
Below is the shimWrite function, which processes the write instruction. It has a conditional that handles instructions with SC strength. The timestamp is incremented regardless of which branch is taken, though I believe we still have to case on the condition in the proof.
def shimWrite {c : SystemConfig} (shim : ShimId c) (addr : Addr c) (data : Data)
    (stren : OpStrength) (shimVec : ShimType c) (e : Execution c) (net : NETOrdered c) :
    (ShimType c) × (Execution c) × (NETOrdered c) :=
  let newTs : Timestamp := ((shimVec.get shim).state.get addr).ts + 1
  let (shimVec', e') := popInstr shim shimVec e
  let shimVec'' := shimWriteCache shim CacheState.Valid data newTs addr shimVec'
  let shimNode : Node c := Fin.castSucc shim
  let CCNode : Node c := Fin.mk c.threads (Nat.lt_succ_self c.threads)
  let net' := send MType.WRITE shimNode CCNode data addr newTs stren net
  if stren = OpStrength.SC then
    let modShim := {shimVec''.get shim with pendingWSC := true}
    let shimVec''' := shimVec''.set shim modShim
    (shimVec''', e', net')
  else
    (shimVec'', e', net')
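For the SC branch, the only extra step is the record update that sets pendingWSC. A minimal sketch, assuming a hypothetical two-field structure ToyShim (not part of the project; only the field name pendingWSC is taken from the code above), suggests that such an update leaves the other fields definitionally unchanged, so projecting the state out of modShim should reduce to the state before the update:

```lean
-- Hypothetical stand-in for `Shim`; `ToyShim` and its `state : Nat` field are illustrative only.
structure ToyShim where
  state : Nat
  pendingWSC : Bool

-- Updating `pendingWSC` does not touch `state`, so the projection reduces by `rfl`.
example (s : ToyShim) :
    ({ s with pendingWSC := true } : ToyShim).state = s.state := rfl
```

If the real Shim structure behaves the same way, the isTrue case may come down to the same get/set goal as the isFalse case once the projection is simplified.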
Here is what I have for the lemma and proof so far:
def increments_ts {c : SystemConfig} (s s' : IncState c) (shim : ShimId c) : Prop :=
  ((s'.shimVec.get shim).state.get (getInstr shim s.shimVec s.execution).2.addr).ts =
  ((s.shimVec.get shim).state.get (getInstr shim s.shimVec s.execution).2.addr).ts + 1
lemma local_W_inc_ts {c : SystemConfig} :
    forall (s s' : IncState c) (shim : ShimId c),
      canIssueInstr shim s →
      s' = getAndIssueInstr shim s →
      (getInstr shim s.shimVec s.execution).2.access = PermissionType.store →
      increment_step s s' →
      increments_ts s s' shim
    := by
  intro s s' shim h_issue h_s' h_W h_step
  unfold increments_ts
  rw [h_s']
  unfold getAndIssueInstr
  simp [h_W]
  unfold shimWrite
  simp
  split
  case isTrue =>
    simp
    sorry
  case isFalse =>
    simp
    unfold shimWriteCache
    -- simp [List.Vector.get_set_same]
    simp [Vector.getElem_set_self shim.isLt]
I'm working on the isFalse (not SC) case first, and after the last simp tactic, my infoview shows the following remaining goal, which is essentially an unfolded version of the increments_ts definition used in the lemma:
⊢ (Vector.get
-- below is `s'.shimVec` from the increments_ts definition. shim is updated by shimWrite
((Vector.set (popInstr shim s.shimVec (getInstr shim s.shimVec s.execution).1).1 ↑shim
{
state := [state omitted]
ts :=
(Vector.get (Vector.get s.shimVec shim).state (getInstr shim s.shimVec s.execution).2.addr).ts + 1,
[etc. other fields])
.get
shim).state
(getInstr shim s.shimVec s.execution).2.addr).ts =
(Vector.get (Vector.get s.shimVec shim).state (getInstr shim s.shimVec s.execution).2.addr).ts + 1
I'm confused as to why simp is unable to simplify the get and set at the same index. I thought that the following theorem should be applicable:
@[simp] theorem Vector.getElem_set_self {α : Type u_1} {n : Nat} {xs : Vector α n} {i : Nat} {x : α} (hi : i < n) :
    (xs.set i x hi)[i] = x
where i = shim, and hi comes from the fact that shim < n = c.threads because shim : ShimId c = Fin c.threads. However, when I try to do simp [Vector.getElem_set_self shim.isLt], it says that the argument is unused. I was thinking that perhaps it's because the goal uses Vector.get, whereas the theorem is using bracket indexing, but I'm not sure.
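To make the suspected mismatch concrete, here is a minimal sketch in which the quoted lemma does apply. It assumes a toolchain whose core library provides Vector together with the Vector.getElem_set_self statement quoted above; the variable names are illustrative. Both the set and the lookup use the same Nat index, and the lookup is written with bracket indexing, which is the shape of the lemma's left-hand side:

```lean
-- Both the `set` and the lookup use the `Nat` index `i`, and the lookup is `xs[i]`,
-- so the goal has exactly the shape of the lemma's left-hand side.
example {α : Type} {n : Nat} (xs : Vector α n) (i : Nat) (hi : i < n) (x : α) :
    (xs.set i x hi)[i] = x :=
  Vector.getElem_set_self hi
```

The stuck goal instead has the form `Vector.get (Vector.set xs ↑shim x) shim`: the lookup goes through Vector.get with a Fin index rather than through bracket indexing. Since simp rewrites by matching the lemma's left-hand side against the goal, it would plausibly report the lemma as unused until the Vector.get is restated as a bracket lookup.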
I also originally tried simp [List.Vector.get_set_same], but that didn't work either, I assume because of the mismatch between ↑shim : Nat in the Vector.set and shim : ShimId c in the Vector.get. I'm not sure how to work around this, since that difference is also present in how Vector.set and Vector.get are defined.
I've been stuck on how to simplify the goal in this proof for a while, so if someone could give some advice on how to proceed, I'd really appreciate it. Thanks!
Edit: I changed ShimType and ShimCache to be List.Vector instead of Vector, and now it seems like those mismatches have been resolved and simp solves the goal.
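A minimal sketch of why the switch plausibly helps, assuming Mathlib is available and that List.Vector lives in Mathlib.Data.Vector.Basic (the import path is an assumption about the Mathlib version in use): with List.Vector, both set and get take a Fin index, so the goal matches List.Vector.get_set_same directly, with no coercion to Nat in between.

```lean
import Mathlib.Data.Vector.Basic

-- With `List.Vector`, `set` and `get` are both indexed by `Fin n`,
-- so the set-then-get pattern at the same index matches `List.Vector.get_set_same`.
example {α : Type} {n : Nat} (v : List.Vector α n) (i : Fin n) (a : α) :
    (v.set i a).get i = a := by
  simp [List.Vector.get_set_same]
```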
Last updated: Dec 20 2025 at 21:32 UTC