Zulip Chat Archive

Stream: lean4 dev

Topic: kabstract that works under binders


Jovan Gerbscheid (Jan 26 2024 at 11:26):

Would you want to have a version of `kabstract` in Lean that can also match subterms containing bound variables? This would be especially useful for rewriting under binders. I've written a function like this, so I thought I should share it. I was not quite sure what the return type should be, because abstracting a term that contains bound variables doesn't work nicely.

What I did was to still abstract the instances of the pattern as usual, but additionally to record a `SubExpr.Pos` at which all bound variables used by the pattern have been bound, together with a list of `LocalDecl`s for the bound variables in scope at this `Pos`. To do a substitution, I edit the subexpression at this `Pos` in three steps. First I instantiate it with these free variables. Then I instantiate it with the replacement term I want, which is allowed to contain these free variables. Finally I abstract these free variables again. Here's my code:

/-- The result of `PatternAbstract`. -/
structure PatternAbstractResult where
  /-- The position closest to the root such that all bound variables appearing in the
  instantiated pattern are bound at this position. Defaults to the root. -/
  pos : SubExpr.Pos := .root
  /-- The original expression with the instances of the pattern abstracted. -/
  expr : Expr
  /-- The declarations of the free variables standing for the variables that are bound above
  position `pos` in `expr`, listed innermost binder first. Empty when `pos` is the root. -/
  fvarDecls : List LocalDecl := []

/--
Replace the `LocalContext` and the local instances of each metavariable in `mvarIds`
with the current ones.

`PatternAbstract` calls this right after introducing a free variable for a binder, so that the
pattern's metavariables may be assigned terms mentioning that free variable.
Every `mvarId` must be declared in the current metavariable context (`find!` panics otherwise).
-/
def updateMVarLCtxs (mvarIds : Array MVarId) : MetaM Unit := do
  let lctx ← getLCtx
  let localInstances ← getLocalInstances
  let mctx ← getMCtx
  -- Keep `lctx` and `localInstances` in sync: a newly introduced binder may be an instance.
  let decls := mvarIds.foldl (init := mctx.decls) fun decls mvarId =>
    decls.insert mvarId { decls.find! mvarId with lctx, localInstances }
  setMCtx { mctx with decls }

/-- Monad for the traversal in `PatternAbstract`:
* reader: the free variables introduced for the binders entered so far, innermost first;
* outer state: the `PatternProgress` of the search;
* inner state: the number of matches of the pattern found so far. -/
private abbrev M := ReaderT (List FVarId) (StateRefT PatternProgress (StateRefT Nat MetaM))

/--
Find the occurrences of the pattern `p` in `e` and abstract them, as `kabstract` does. Unlike
`kabstract`, the matched instance of the pattern may also contain variables that are bound in `e`.
Matching uses `isDefEq` at reducible transparency.

Returns `none` if no selected occurrence is found. Otherwise returns the abstraction result
together with the instantiated pattern, to which every abstracted occurrence is defeq.

If the instantiated pattern mentions bound variables of `e`, these appear in it as free
variables whose declarations are recorded in `.fvarDecls`. These are the variables bound
above position `.pos` of `.expr`, innermost first. In that case no occurrences outside of
`.pos` are abstracted.

As in `kabstract`, the occurrences selected by `occs` are counted from 1.
-/
partial def PatternAbstract (e : Expr) (p : AbstractMVarsResult) (occs : Occurrences := .all) :
    MetaM (Option (PatternAbstractResult × Expr)) := do
  let e ← instantiateMVars e
  withNewMCtxDepth do
  withReducible do
  let (mvars, _, p) ← openAbstractMVarsResult p
  let mvarIds := mvars.map Expr.mvarId!
  if p.isFVar && occs == Occurrences.all then
    -- Easy case: a free-variable pattern cannot mention bound variables.
    -- Report `none` when it does not occur, consistently with the general case below.
    if e.containsFVar p.fvarId! then
      return some ({ expr := e.abstract #[p] }, p)
    else
      return none
  else
    let pHeadIdx := p.toHeadIndex
    let pNumArgs := p.headNumArgs
    let rec visit (pos : SubExpr.Pos) (e : Expr) : M Expr := do

      -- Visit the body `b` of a binder named `n` with domain `d`, where `pos` is the position
      -- of the body. The bound variable is replaced by a fresh free variable while visiting.
      let introFVar (pos : SubExpr.Pos) (n : Name) (d b : Expr) : M Expr :=
        withLocalDeclD n d fun fvar =>
        withReader (fvar.fvarId! :: ·) do
          if (← get) matches PatternProgress.noMatch then
            -- No instance has been fixed yet: allow the pattern's metavariables to be
            -- assigned terms that mention `fvar`.
            updateMVarLCtxs mvarIds
            let e ← visit pos (b.instantiate1 fvar)
            match ← get with
            | .noMatch =>
              -- Nothing was abstracted below, so the body is unchanged.
              return b
            | .someMatch pattern =>
              if pattern.containsFVar fvar.fvarId! then
                -- The instance mentions `fvar`, so it is only meaningful below this binder:
                -- record this position and the variables bound above it, and stop searching.
                let fvarDecls ← liftM $ (← read).mapM FVarId.getDecl
                set (PatternProgress.finished pattern pos fvarDecls)
              return e.abstract #[fvar]
            | .finished .. =>
              return e.abstract #[fvar]

          else
            let e ← visit pos (b.instantiate1 fvar)
            return e.abstract #[fvar]

      let visitChildren : Unit → M Expr := fun _ => do
        match e with
        | .app f a         => return e.updateApp! (← visit pos.pushAppFn f) (← visit pos.pushAppArg a)
        | .mdata _ b       => return e.updateMData! (← visit pos b)
        | .proj _ _ b      => return e.updateProj! (← visit pos.pushProj b)
        | .letE n t v b _  => return e.updateLet! (← visit pos.pushLetVarType t) (← visit pos.pushLetValue v) (← introFVar pos.pushLetBody n t b)
        | .lam n d b _     => return e.updateLambdaE! (← visit pos.pushBindingDomain d) (← introFVar pos.pushBindingBody n d b)
        | .forallE n d b _ => return e.updateForallE! (← visit pos.pushBindingDomain d) (← introFVar pos.pushBindingBody n d b)
        | e                => return e

      let progress ← get
      if progress matches .finished .. then
        -- The instance has been fixed to a binder this traversal has already left.
        return e
      else if e.toHeadIndex != pHeadIdx || e.headNumArgs != pNumArgs then
        visitChildren ()
      else
        -- Save the metavariable context so that the assignments made by `isDefEq`
        -- can be rolled back if this occurrence is not selected.
        let mctx ← getMCtx
        if ← isDefEq e p then
          let i ← getThe Nat
          set (i+1)
          if occs.contains i then
            -- Record the instance only for a selected occurrence. A skipped occurrence's
            -- assignments are rolled back below, so recording it would report a pattern
            -- that the abstracted occurrences need not be instances of.
            if progress matches .noMatch then
              set (PatternProgress.someMatch (← instantiateMVars p))
            -- One reader entry per enclosing binder, so this index points past all of them.
            return .bvar (← read).length
          else
            setMCtx mctx
            visitChildren ()
        else
          visitChildren ()
    -- Occurrences are counted from 1, as in `kabstract` and `Occurrences.pos`.
    let (expr, progress) ← visit SubExpr.Pos.root e |>.run [] |>.run .noMatch |>.run' 1
    match progress with
    | .finished pattern pos fvarDecls =>
      return some ({ expr, pos, fvarDecls }, pattern)
    | .someMatch pattern =>
      return some ({ expr }, pattern)
    | .noMatch => return none

And here is the code for instantiating the result:

/- This section follows the definition of `Lean.Meta.replaceSubexpr`.
The helpers below are generic in a monad `M` that can lift and control `MetaM` and throw errors.
NOTE(review): this section variable `M` has the same name as the `private abbrev M` used by
`PatternAbstract`; presumably the section variable takes precedence in the declarations below.
Confirm this, or wrap these helpers in their own `section`. -/
variable {M} [Monad M] [MonadLiftT MetaM M] [MonadControlT MetaM M] [MonadError M]

/-- Run `g` on the child of an expression selected by a `SubExpr.Pos` coordinate, and put the
result back in its place.

The coordinates are as follows:
* applications: `0` is the function and `1` is the argument;
* binders: `0` is the domain and `1` is the body;
* `let`: `0` is the type, `1` is the value and `2` is the body;
* projections: `0` is the projected structure.

`mdata` nodes are looked through. Coordinate `3` stands for the type of the expression and
throws, since types cannot be replaced. Any other invalid coordinate throws as well.

Binder and `let` bodies are passed to `g` as they are, so they may contain loose bound
variables.

See also `Lean.Meta.transform`, `Lean.Meta.traverseChildren`. -/
private def lensCoordRaw (g : Expr → M Expr) : Nat → Expr → M Expr
  | 0, e@(Expr.app f a)         => return e.updateApp! (← g f) a
  | 1, e@(Expr.app f a)         => return e.updateApp! f (← g a)
  | 0, e@(Expr.lam _ y b _)     => return e.updateLambdaE! (← g y) b
  | 1, e@(Expr.lam _ y b _)     => return e.updateLambdaE! y (← g b)
  | 0, e@(Expr.forallE _ y b _) => return e.updateForallE! (← g y) b
  | 1, e@(Expr.forallE _ y b _) => return e.updateForallE! y (← g b)
  | 0, e@(Expr.letE _ y a b _)  => return e.updateLet! (← g y) a b
  | 1, e@(Expr.letE _ y a b _)  => return e.updateLet! y (← g a) b
  | 2, e@(Expr.letE _ y a b _)  => return e.updateLet! y a (← g b)
  | 0, e@(Expr.proj _ _ b)      => e.updateProj! <$> g b
  | n, e@(Expr.mdata _ a)       => e.updateMData! <$> lensCoordRaw g n a
  | 3, _                        => throwError "Lensing on types is not supported"
  | c, e                        => throwError "Invalid coordinate {c} for {e}"

/-- Follow the list of coordinates from `e` downwards, with the first coordinate applied at the
root. Run `g` on the subexpression reached and rebuild `e` around the result. An empty list runs
`g` on `e` itself. Throws on an invalid coordinate (see `lensCoordRaw`). -/
private def lensRawAux (g : Expr → M Expr) : List Nat → Expr → M Expr
  | []        , e => g e
  | head::tail, e => lensCoordRaw (lensRawAux g tail) head e

/-- Run the given `replace` function on the subexpression of `root` at position `p`, and put the
result back in its place.

The subexpression is passed to `replace` as it is, so it may contain loose bound variables for
the binders above `p`. Throws if `p` is not a valid position in `root` or points to a type. -/
def replaceSubexprRaw (p : SubExpr.Pos) (root : Expr) (replace : (subexpr : Expr) → M Expr) : M Expr :=
  lensRawAux replace p.toArray.toList root

/-- Put `e` in place of the abstracted instances of the pattern in `p`.

`e` may mention the free variables recorded in `p.fvarDecls`. The subexpression at `p.pos` is
processed in three steps:
1. its loose bound variables for the enclosing binders are replaced by those free variables;
2. the abstracted pattern is instantiated with `e`;
3. the free variables are abstracted again. -/
def PatternAbstractResult.instantiate (p : PatternAbstractResult) (e : Expr) : M Expr :=
  replaceSubexprRaw p.pos p.expr fun subexpr => do
    -- `fvarDecls` is stored innermost first; `instantiateRev`/`abstract` want outermost first.
    let binderFVars := (p.fvarDecls.map fun decl => Expr.fvar decl.fvarId).toArray.reverse
    let opened := subexpr.instantiateRev binderFVars
    return (opened.instantiate1 e).abstract binderFVars

Last updated: May 02 2025 at 03:31 UTC