I'm making a data structure to store disjoint sets of keys that point to values:
```lean
import Std.Data.HashMap
import Std.Data.HashSet
import Mathlib

open Std

variable {k v : Type*} [instBeq : BEq k] [instHash : Hashable k]

instance : Hashable (HashSet k) where
  hash := hash ∘ HashSet.toArray

abbrev DisjointSetMap' (k : Type u) (v : Type u') [instBeq : BEq k] [instHash : Hashable k] :=
  HashMap (HashSet k) v

structure DisjointSetMap (k : Type*) [BEq k] [Hashable k] (v : Type*) where
  /-- This maps from _sets_ of `k`s to `v`s -/
  innerMap : DisjointSetMap' k v
  /-- How to merge values together when key sets are merged -/
  mergeFn : v → v → (keysNew : HashSet k) → v
```
However, for this data structure to be correct, I want to include a proof in the `DisjointSetMap` structure that:
- none of the hash sets are empty, and
- no two distinct hash sets overlap. In other words, for every key there is at most one key set that contains it.
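As a rough sketch, the two properties above could be stated directly as `Prop`-valued definitions with a `∀` over the key sets in the map. This assumes `{k v : Type}` (rather than `Type*`) and that `BEq (HashSet k)` and `Hashable (HashSet k)` instances are available, since `HashMap (HashSet k) v` needs both; the names follow the ones used in this thread, but the bodies are a hypothetical reformulation:

```lean
import Std.Data.HashMap
import Std.Data.HashSet

open Std

-- Assumption: instances for using `HashSet k` as a map key are supplied elsewhere.
variable {k v : Type} [BEq k] [Hashable k] [BEq (HashSet k)] [Hashable (HashSet k)]

/-- Property 1 (sketch): every key set stored in the map is nonempty. -/
def NoEmptyKeys (map : HashMap (HashSet k) v) : Prop :=
  ∀ keySet ∈ map, keySet.isEmpty = false

/-- Property 2 (sketch): every key lies in at most one key set, i.e. two key sets
in the map that share a key are the same set. -/
def NoKeyOverlap (map : HashMap (HashSet k) v) : Prop :=
  ∀ s₁ ∈ map, ∀ s₂ ∈ map, ∀ key : k, key ∈ s₁ → key ∈ s₂ → s₁ = s₂
```

Both statements quantify over *every* key set in the map. By contrast, a constructor that takes a single `keySet` together with a proof that it is nonempty only witnesses that *some* key set is nonempty. The `=` here is propositional equality on `HashSet`, mirroring the `keySet1 ≠ keySet2` used below.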
Here's the full code, including my attempt at a proof so far:
```lean
import Std.Data.HashMap
import Std.Data.HashSet
import Mathlib

open Std

variable {k v : Type*} [instBeq : BEq k] [instHash : Hashable k]

instance : Hashable (HashSet k) where
  hash := hash ∘ HashSet.toArray

def Std.HashSet.intersection (a : HashSet k) (b : HashSet k) : HashSet k :=
  a.fold (init := ∅) fun s x => if b.contains x then s.insert x else s

def Std.HashSet.intersection_mem_both {k : Type u} [BEq k] [Hashable k] [EquivBEq k] [LawfulHashable k]
    (a : HashSet k) (b : HashSet k) : ∀ v ∈ (a.intersection b), v ∈ a ∧ v ∈ b := by
  rw [intersection]
  suffices ∀ a : List k, ∀ q : HashSet k, ∀ (v : k),
      v ∈ List.foldl (fun s x => if b.contains x = true then s.insert x else s) q a →
      v ∈ q ∨ a.contains v ∧ v ∈ b by
    simpa [HashSet.fold_eq_foldl_toList, HashSet.mem_iff_contains] using this a.toList ∅
  intro a q
  induction a generalizing q with
  | nil => simp
  | cons hd tl ih =>
    simp
    intro v hv
    obtain (h | h) := ih _ _ hv
    · split at h
      · next ht =>
        rw [HashSet.mem_insert] at h
        obtain (h | h) := h
        · refine Or.inr ⟨Or.inl (BEq.symm h), ?_⟩
          rwa [HashSet.mem_iff_contains, ← HashSet.contains_congr h]
        · exact Or.inl h
      · exact Or.inl h
    · exact Or.inr ⟨Or.inr h.1, h.2⟩

/-- Gets the items in `a` that are not in `b` -/
def Std.HashSet.diff (a : HashSet k) (b : HashSet k) : HashSet k :=
  a.filter (not ∘ b.contains)

abbrev DisjointSetMap' (k : Type u) (v : Type u') [instBeq : BEq k] [instHash : Hashable k] :=
  HashMap (HashSet k) v

inductive NoEmptyKeys (map : DisjointSetMap' k v) : Prop where
  | mk (keySet : HashSet k) (keySetInMap : keySet ∈ map) (keySetNotEmpty : keySet ≠ ∅) : NoEmptyKeys map

/-- @TODO: Need help with this 👇 -/
inductive NoKeyOverlap (map : DisjointSetMap' k v) : Prop where
  | mk (allPairsDisjoint : ∀ (keySet1 keySet2 : HashSet k),
        keySet1 ∈ map → keySet2 ∈ map → keySet1 ≠ ∅ → keySet2 ≠ ∅ → keySet1 ≠ keySet2 →
        ∀ (key : k), key ∈ keySet1 → key ∉ keySet2 → NoKeyOverlap map) : NoKeyOverlap map

/-- The vacuously true case for `NoKeyOverlap` where the map is empty. -/
def NoKeyOverlap.empty : NoKeyOverlap (HashMap.empty : DisjointSetMap' k v) := by
  constructor
  intro keySet1 keySet2 h1
  simp at h1

structure DisjointSetMap (k : Type*) [BEq k] [Hashable k] (v : Type*) where
  /-- This maps from _sets_ of `k`s to `v`s -/
  innerMap : DisjointSetMap' k v
  /-- A proof that the inner map has no key overlap -/
  noKeyOverlap : NoKeyOverlap innerMap
  /-- How to merge values together when key sets are merged -/
  mergeFn : v → v → (keysNew : HashSet k) → v

namespace DisjointSetMap

def getOuterMapFromInnerMap (innerMap : HashMap (HashSet k) v) : HashMap k (HashSet k × v) :=
  innerMap
    |>.fold (fun (map : HashMap k (HashSet k × v)) (keySet : HashSet k) value =>
        keySet.fold (fun map' key => map'.insert key (keySet, value)) map)
      HashMap.empty

/-- A map with a separate key for each key in `d.innerMap`'s key sets.
A projection from the inner map to a map of individual keys to key sets _and_ values; so we can:
  a) find items in the map by a single key by doing `d.outerMap[key]`
  b) which then returns the full set of keys that the single key belongs to, as well as the
     value stored for that set.

The values in this map have type `HashSet k × v` -/
def outerMap (d : DisjointSetMap k v) : HashMap k (HashSet k × v) :=
  getOuterMapFromInnerMap d.innerMap

def empty (mergeFn : v → v → (keysNew : HashSet k) → v) : DisjointSetMap k v := {
  innerMap := HashMap.empty
  noKeyOverlap := NoKeyOverlap.mk ∅ ∅ (∅ ∈ innerMap) (∅ ∈ innerMap) (by simp) (by simp)
  mergeFn
}

def addSet (d : DisjointSetMap k v) (newKeySet : HashSet k) (val : v) : DisjointSetMap k v :=
  let overlappingSets := d.innerMap.fold (init := (newKeySet, [])) fun acc currSet value =>
    let intersection := acc.1.intersection currSet
    if intersection == ∅ then
      -- The no overlap case, so there's nothing to add here
      acc
    else
      -- If there is some overlap, we snowball the current set in the accumulated set,
      -- and include the value of the newly merged set in the list of values to merge later
      let union := acc.1.union currSet
      (union, value :: acc.2)
  match overlappingSets with
  | (_, []) =>
    -- No overlaps, just insert the new set with its value
    { d with innerMap := d.innerMap.insert newKeySet val }
  | (mergedSet, valuesToMerge) =>
    -- Merge all overlapping sets and their values
    let mergedValue := valuesToMerge.foldl (init := val) (fun acc value => d.mergeFn acc value mergedSet)
    -- Remove old sets
    let newInnerMap := d.innerMap.fold (init := HashMap.empty) fun newInnerMap currSet _ =>
      let hasOverlap := currSet.any newKeySet.contains
      if hasOverlap then newInnerMap.erase currSet else newInnerMap
    -- And insert the new snowballed merged set with its combined value
    { d with innerMap := newInnerMap.insert mergedSet mergedValue }

/-- This merges multiple sets without adding a new value.
If none of the keys are in the map then this is a no-op because we have no value to set for it! -/
def union (d : DisjointSetMap k v) (keysToMerge : HashSet k) : DisjointSetMap k v :=
  let overlappingSets := d.innerMap.fold (init := (keysToMerge, [])) fun acc currSet value =>
    let intersection := acc.1.intersection currSet
    if intersection == ∅ then
      -- The no overlap case, so there's nothing to add here
      acc
    else
      -- If there is some overlap, we snowball the current set in the accumulated set,
      -- and include the value of the newly merged set in the list of values to merge later
      let union := acc.1.union currSet
      (union, value :: acc.2)
  match overlappingSets with
  | (_, []) =>
    -- No overlaps, none of the keys are in the map so we do nothing because we have no value to set for it
    d
  | (mergedSet, firstVal :: restValsToMerge) =>
    -- Merge all overlapping sets and their values
    let mergedValue := restValsToMerge.foldl (init := firstVal) (fun acc value => d.mergeFn acc value mergedSet)
    -- Remove old sets
    let newInnerMap := d.innerMap.fold (init := HashMap.empty) fun newInnerMap currSet _ =>
      let hasOverlap := currSet.any keysToMerge.contains
      if hasOverlap then newInnerMap.erase currSet else newInnerMap
    -- And insert the new snowballed merged set with its combined value
    { d with innerMap := newInnerMap.insert mergedSet mergedValue }

def find? (d : DisjointSetMap k v) (key : k) : Option v :=
  d.outerMap[key]? |>.map Prod.snd

def find (d : DisjointSetMap k v) (key : k) (h : key ∈ d.outerMap) : v :=
  d.outerMap[key]'h |>.2

end DisjointSetMap
```
But I don't think I'm going down the right path here. Is there a better, and ideally simpler, way to prove this property?
Also, notice that `NoKeyOverlap` doesn't need to check whether the sets are nonempty, because `NoEmptyKeys` already takes care of that (assuming you include `NoEmptyKeys` in the `DisjointSetMap` structure).
Good point. I made them inductives because I needed that for another data structure I wrote, but you're right: since this property isn't recursive, it can just be a regular definition.
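Following both suggestions, a minimal sketch might store the two invariants as plain `Prop` fields next to the map, with `NoKeyOverlap` carrying no nonemptiness premises. It assumes `{k v : Type}` and externally supplied `BEq (HashSet k)` / `Hashable (HashSet k)` instances, and it relies on `simp` knowing that nothing is a member of an empty `HashMap`. That lemma is in the `Std.HashMap` simp set on recent toolchains, but on other versions you may need to name it explicitly.

```lean
import Std.Data.HashMap
import Std.Data.HashSet

open Std

-- Assumption: instances for using `HashSet k` as a map key are supplied elsewhere.
variable {k v : Type} [BEq k] [Hashable k] [BEq (HashSet k)] [Hashable (HashSet k)]

/-- Every key set in the map is nonempty. -/
def NoEmptyKeys (map : HashMap (HashSet k) v) : Prop :=
  ∀ keySet ∈ map, keySet.isEmpty = false

/-- Every key lies in at most one key set. No nonemptiness premises: that is
`NoEmptyKeys`'s job, and both are stored in the structure below. -/
def NoKeyOverlap (map : HashMap (HashSet k) v) : Prop :=
  ∀ s₁ ∈ map, ∀ s₂ ∈ map, ∀ key : k, key ∈ s₁ → key ∈ s₂ → s₁ = s₂

structure DisjointSetMap (k v : Type) [BEq k] [Hashable k]
    [BEq (HashSet k)] [Hashable (HashSet k)] where
  innerMap : HashMap (HashSet k) v
  noEmptyKeys : NoEmptyKeys innerMap
  noKeyOverlap : NoKeyOverlap innerMap
  mergeFn : v → v → (keysNew : HashSet k) → v

/-- Both invariants hold vacuously for the empty map: there is no key set to check. -/
def DisjointSetMap.empty (mergeFn : v → v → (keysNew : HashSet k) → v) : DisjointSetMap k v where
  innerMap := ∅
  -- `h : s ∈ ∅` simplifies to `False`, which closes the goal.
  noEmptyKeys := by intro s h; simp at h
  noKeyOverlap := by intro s₁ h₁; simp at h₁
  mergeFn := mergeFn
```

Every operation that returns a new `DisjointSetMap` (such as `addSet` or `union`) then has to produce proofs for both fields, which is where the real work is.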
@Johannes Tantow Right, but if you don't require `keySet1` and `keySet2` to be different sets, you could never construct a valid proof, because you could not prove that `(keySet1.intersection keySet2).isEmpty` holds for all key sets, namely when `keySet1 = keySet2` (a nonempty set's intersection with itself is nonempty).
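The distinctness requirement can be checked in Lean. In the pairwise form, a `s₁ ≠ s₂` premise is needed, exactly as described above. But that form is equivalent to the "at most one key set per key" form, which concludes `s₁ = s₂` and so needs no such premise. This is a sketch assuming `{k v : Type}` and supplied `BEq (HashSet k)` / `Hashable (HashSet k)` instances; `PairwiseDisjoint` and the theorem names are hypothetical, chosen for illustration.

```lean
import Std.Data.HashMap
import Std.Data.HashSet

open Std

-- Assumption: instances for using `HashSet k` as a map key are supplied elsewhere.
variable {k v : Type} [BEq k] [Hashable k] [BEq (HashSet k)] [Hashable (HashSet k)]

/-- "At most one key set per key": equality is the conclusion, so no `s₁ ≠ s₂` premise. -/
def NoKeyOverlap (map : HashMap (HashSet k) v) : Prop :=
  ∀ s₁ ∈ map, ∀ s₂ ∈ map, ∀ key : k, key ∈ s₁ → key ∈ s₂ → s₁ = s₂

/-- Hypothetical pairwise form: only *distinct* key sets must be disjoint. -/
def PairwiseDisjoint (map : HashMap (HashSet k) v) : Prop :=
  ∀ s₁ ∈ map, ∀ s₂ ∈ map, s₁ ≠ s₂ → ∀ key : k, key ∈ s₁ → key ∉ s₂

theorem pairwiseDisjoint_of_noKeyOverlap {map : HashMap (HashSet k) v}
    (h : NoKeyOverlap map) : PairwiseDisjoint map := by
  intro s₁ h₁ s₂ h₂ hne key hk₁ hk₂
  -- A shared key forces `s₁ = s₂`, contradicting `hne`.
  exact hne (h s₁ h₁ s₂ h₂ key hk₁ hk₂)

theorem noKeyOverlap_of_pairwiseDisjoint {map : HashMap (HashSet k) v}
    (h : PairwiseDisjoint map) : NoKeyOverlap map := by
  intro s₁ h₁ s₂ h₂ key hk₁ hk₂
  -- If the sets were different, `h` would say `key ∉ s₂`, contradicting `hk₂`.
  exact Classical.byContradiction fun hne => h s₁ h₁ s₂ h₂ hne key hk₁ hk₂
```

Dropping `s₁ ≠ s₂` from `PairwiseDisjoint` would make it unprovable for any map containing a nonempty key set, since `s₁ := s₂` would require a key to be both in and not in the same set.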