Zulip Chat Archive
Stream: lean4 dev
Topic: kabstract that works under binders
Jovan Gerbscheid (Jan 26 2024 at 11:26):
Would you want to have a version of `kabstract` in Lean that can also match with bound variables? This would be especially useful for rewriting under binders. I've made a function like this, so I thought I should share. I was not quite sure what the return type should be, because abstracting a term with bound variables doesn't work nicely.
What I did was to still abstract the instances of the pattern as usual, and additionally to record a `SubExpr.Pos` at which all used bound variables have been bound, together with a list of free-variable declarations (`LocalDecl`s) for the bound variables introduced at this `Pos`. So to do a substitution, I edit the subexpression at this `Pos`: I first instantiate it with these free variables, then instantiate it with the expression I want to substitute (which is allowed to contain these free variables), and finally abstract these free variables again. Here's my code:
/-- The value produced by `PatternAbstract`. -/
structure PatternAbstractResult where
  /-- The position closest to the root at which every bound variable occurring in the
  instantiated pattern is already bound. It is `SubExpr.Pos.root` when the pattern mentions
  no bound variable. -/
  pos : SubExpr.Pos := SubExpr.Pos.root
  /-- The original expression in which the selected instances of the pattern are abstracted. -/
  expr : Expr
  /-- The declarations of the free variables that stand for the binders enclosing `pos` in
  `expr`, innermost binder first. Empty when `pos` is the root. -/
  fvarDecls : List LocalDecl := List.nil
/-- Replace the `LocalContext` and the local instances of each metavariable in `mvarIds` with
the current ones. This lets those metavariables be assigned terms that mention free variables
introduced after they were created, including local instances for class-typed binders.

Panics (through `find!`) if some `mvarId` is not declared in the current `MetavarContext`. -/
def updateMVarLCtxs (mvarIds : Array MVarId) : MetaM Unit := do
  let lctx ← getLCtx
  -- `withLocalDecl` registers class-typed locals as local instances; keep the mvars in sync so
  -- that instance problems under such binders can see them.
  let localInstances ← getLocalInstances
  let mctx ← getMCtx
  let decls := mvarIds.foldl (init := mctx.decls) fun decls mvarId =>
    decls.insert mvarId { decls.find! mvarId with lctx, localInstances }
  setMCtx { mctx with decls }
/-- Search state of `PatternAbstract`. -/
inductive PatternProgress where
  /-- No selected occurrence of the pattern has been found yet. -/
  | noMatch
  /-- A selected occurrence has been found; `pattern` is the instantiated pattern. -/
  | someMatch (pattern : Expr)
  /-- The instantiated `pattern` mentions bound variables. `pos` is the position of the body of
  the innermost binder it depends on, and `fvarDecls` are the declarations of the free
  variables for all binders enclosing `pos`, innermost first. -/
  | finished (pattern : Expr) (pos : SubExpr.Pos) (fvarDecls : List LocalDecl)
  deriving Inhabited

/-- The monad used by `PatternAbstract`:
* the reader holds the free variables introduced for the binders entered so far, innermost first;
* the outer state is the search progress;
* the inner state counts the occurrences of the pattern found so far. -/
private abbrev M := ReaderT (List FVarId) (StateRefT PatternProgress (StateRefT Nat MetaM))
/--
Find all occurrences of a pattern, abstracting the locations of this pattern,
also allowing for bound variables. The bound variables are replaced by free variables
which are recorded in the field `.fvarDecls`.
These are exactly the variables introduced in the returned outer expression.

* `e`: the expression to search in (its metavariables are instantiated first).
* `p`: the pattern; its abstracted metavariables are opened at a new metavariable-context
  depth and get instantiated by the first selected occurrence.
* `occs`: which occurrences to abstract, counted from `1` as in `kabstract`.

Returns `none` if no selected occurrence is found, and otherwise the abstraction result
together with the instantiated pattern.
-/
partial def PatternAbstract (e : Expr) (p : AbstractMVarsResult) (occs : Occurrences := .all) :
    MetaM (Option (PatternAbstractResult × Expr)) := do
  let e ← instantiateMVars e
  withNewMCtxDepth do
    withReducible do
      let (mvars, _, p) ← openAbstractMVarsResult p
      let mvarIds := mvars.map Expr.mvarId!
      if p.isFVar && occs == Occurrences.all then
        -- Easy case: a free-variable pattern needs no unification.
        -- Like the general case, report `none` when the pattern does not occur.
        if e.containsFVar p.fvarId! then
          return some ({ expr := e.abstract #[p] }, p)
        else
          return none
      let pHeadIdx := p.toHeadIndex
      let pNumArgs := p.headNumArgs
      let rec visit (pos : SubExpr.Pos) (e : Expr) : M Expr := do
        -- Enter a binder: replace its bound variable by a fresh local, visit the body,
        -- and abstract the local again.
        let introFVar (pos : SubExpr.Pos) (n : Name) (d b : Expr) : M Expr :=
          withLocalDeclD n d fun fvar =>
            withReader (fvar.fvarId! :: ·) do
              if (← get) matches PatternProgress.noMatch then
                -- The pattern is not fixed yet, so allow its metavariables to mention `fvar`.
                updateMVarLCtxs mvarIds
                let e ← visit pos (b.instantiate1 fvar)
                match ← get with
                | .noMatch =>
                  -- Nothing was abstracted, so the original body is unchanged.
                  return b
                | .someMatch pattern =>
                  if pattern.containsFVar fvar.fvarId! then
                    -- This is the innermost binder the pattern depends on: fix the position.
                    let fvarIds ← read
                    let fvarDecls ← (fvarIds.mapM FVarId.getDecl : MetaM (List LocalDecl))
                    set (PatternProgress.finished pattern pos fvarDecls)
                  return e.abstract #[fvar]
                | .finished .. =>
                  return e.abstract #[fvar]
              else
                let e ← visit pos (b.instantiate1 fvar)
                return e.abstract #[fvar]
        let visitChildren : Unit → M Expr := fun _ => do
          match e with
          | .app f a =>
            return e.updateApp! (← visit pos.pushAppFn f) (← visit pos.pushAppArg a)
          | .mdata _ b => return e.updateMData! (← visit pos b)
          | .proj _ _ b => return e.updateProj! (← visit pos.pushProj b)
          | .letE n t v b _ =>
            return e.updateLet! (← visit pos.pushLetVarType t) (← visit pos.pushLetValue v)
              (← introFVar pos.pushLetBody n t b)
          | .lam n d b _ =>
            return e.updateLambdaE! (← visit pos.pushBindingDomain d)
              (← introFVar pos.pushBindingBody n d b)
          | .forallE n d b _ =>
            return e.updateForallE! (← visit pos.pushBindingDomain d)
              (← introFVar pos.pushBindingBody n d b)
          | e => return e
        let progress ← get
        if progress matches .finished .. then
          -- The pattern mentions a bound variable whose scope has been left,
          -- so nothing visited from now on can match it.
          return e
        else if e.hasLooseBVars || e.toHeadIndex != pHeadIdx || e.headNumArgs != pNumArgs then
          visitChildren ()
        else
          -- Save the metavariable context so it can be rolled back for unselected occurrences.
          let mctx ← getMCtx
          if (← isDefEq e p) then
            -- Occurrences are numbered from 1; the counter lives in the inner `Nat` state.
            let i ← modifyGetThe Nat fun i => (i, i + 1)
            if occs.contains i then
              -- Only a selected occurrence may fix the instantiation of the pattern.
              if progress matches .noMatch then
                set (PatternProgress.someMatch (← instantiateMVars p))
              return .bvar (← read).length
            else
              setMCtx mctx
              visitChildren ()
          else
            visitChildren ()
      let (expr, progress) ← visit SubExpr.Pos.root e |>.run [] |>.run .noMatch |>.run' 1
      match progress with
      | .finished pattern pos fvarDecls =>
        return some ({ expr, pos, fvarDecls }, pattern)
      | .someMatch pattern =>
        return some ({ expr }, pattern)
      | .noMatch => return none
And the code for instantiating:
/- This section follows the definition of `Lean.Meta.replaceSubexpr`, except that binder
   bodies are handed over as they are, without instantiating their bound variables.
   The definitions below are generic over a monad `M` that can lift and control `MetaM`
   and throw errors. -/
variable {M} [Monad M] [MonadLiftT MetaM M] [MonadControlT MetaM M] [MonadError M]
/-- Apply `g` to the child of `e` selected by the coordinate `coord`, and rebuild `e` with the
result. Binder bodies are passed to `g` unchanged, so they may contain loose bound variables.

Coordinates: for `app`, `0` is the function and `1` the argument; for `lam`/`forallE`, `0` is
the binder domain and `1` the body; for `letE`, `0` is the type, `1` the value and `2` the body;
for `proj`, `0` is the projected structure. `mdata` is skipped transparently.
Coordinate `3` (the type of the expression) throws, since types cannot be replaced; any other
invalid coordinate throws as well.
See also `Lean.Meta.transform`, `Lean.Meta.traverseChildren`. -/
private def lensCoordRaw (g : Expr → M Expr) (coord : Nat) (e : Expr) : M Expr :=
  match e, coord with
  | .mdata _ inner, _ => e.updateMData! <$> lensCoordRaw g coord inner
  | .app fn arg, 0 => return e.updateApp! (← g fn) arg
  | .app fn arg, 1 => return e.updateApp! fn (← g arg)
  | .lam _ dom body _, 0 => return e.updateLambdaE! (← g dom) body
  | .lam _ dom body _, 1 => return e.updateLambdaE! dom (← g body)
  | .forallE _ dom body _, 0 => return e.updateForallE! (← g dom) body
  | .forallE _ dom body _, 1 => return e.updateForallE! dom (← g body)
  | .letE _ ty val body _, 0 => return e.updateLet! (← g ty) val body
  | .letE _ ty val body _, 1 => return e.updateLet! ty (← g val) body
  | .letE _ ty val body _, 2 => return e.updateLet! ty val (← g body)
  | .proj _ _ s, 0 => e.updateProj! <$> g s
  | _, 3 => throwError "Lensing on types is not supported"
  | _, c => throwError "Invalid coordinate {c} for {e}"
/-- Follow the coordinates in the list (the head is applied at the outermost level) down into
an expression and apply `g` to the subexpression reached; with no coordinates, `g` is applied
to the whole expression. -/
private def lensRawAux (g : Expr → M Expr) : List Nat → Expr → M Expr :=
  fun coords => coords.foldr (fun coord inner => lensCoordRaw inner coord) g
/-- Run `replace` on the subexpression of `root` at position `p`, and return `root` with that
subexpression replaced by the result. Binders on the way down are not instantiated, so
`replace` may receive an expression with loose bound variables.
Throws if the position is invalid or points to a type. -/
def replaceSubexprRaw (p : SubExpr.Pos) (root : Expr) (replace : (subexpr : Expr) → M Expr) : M Expr :=
  let coords : List Nat := p.toArray.toList
  lensRawAux replace coords root
/-- Instantiate the `PatternAbstractResult` `p` with `e`. At position `p.pos` in `p.expr`, the
bound variables of the enclosing binders are first replaced by the free variables of
`p.fvarDecls`; then the abstracted pattern occurrences are replaced by `e` (which may mention
those free variables); and finally the free variables are abstracted again. -/
def PatternAbstractResult.instantiate (p : PatternAbstractResult) (e : Expr) : M Expr :=
  -- `fvarDecls` is stored innermost first; reversing gives the outermost-first order expected
  -- by `instantiateRev` and `abstract`.
  let fvars : Array Expr := (p.fvarDecls.map fun decl => Expr.fvar decl.fvarId).toArray.reverse
  replaceSubexprRaw p.pos p.expr fun subexpr =>
    pure (((subexpr.instantiateRev fvars).instantiate1 e).abstract fvars)
Last updated: May 02 2025 at 03:31 UTC