Zulip Chat Archive
Stream: lean4 dev
Topic: kabstract that works under binders
Jovan Gerbscheid (Jan 26 2024 at 11:26):
Would you want to have a version of `kabstract` in Lean that can also match subterms containing bound variables? This would be especially useful for rewriting under binders. I've made a function like this, so I thought I should share it. I was not quite sure what the return type should be, because abstracting a term that contains bound variables doesn't work nicely.
What I did was to still abstract the instances of the pattern as usual, and to additionally specify a `SubExpr.Pos` at which all used bound variables have been bound, together with a list of `LocalDecl`s for the free variables that stand for the bound variables introduced down to this `Pos`. So to do a substitution, I edit the subexpression at this `Pos` by first instantiating it with these free variables, then instantiating it with the replacement term that I want (which is allowed to contain these free variables), and then abstracting these free variables again. Here's my code:
/-- The result of `PatternAbstract`. -/
structure PatternAbstractResult where
  /-- The position closest to the root such that all bound variables appearing in the
  instantiated pattern are bound at this position. It is the root when the instantiated
  pattern mentions no bound variables. -/
  pos : SubExpr.Pos := .root
  /-- The original expression with instances of the pattern abstracted, i.e. each selected
  instance is replaced by a loose bound variable. -/
  expr : Expr
  /-- The declarations of the free variables introduced by viewing position `pos` in `expr`,
  innermost binder first. Empty when `pos` is the root. -/
  fvarDecls : List LocalDecl := []
/-- Set the `LocalContext` of every metavariable in `mvarIds` to the current `LocalContext`.
Every `mvarId` must already be declared in the metavariable context (it is looked up with
`find!`). -/
def updateMVarLCtxs (mvarIds : Array MVarId) : MetaM Unit := do
  let mctx ← getMCtx
  let lctx ← getLCtx
  let mut decls := mctx.decls
  for mvarId in mvarIds do
    decls := decls.insert mvarId { decls.find! mvarId with lctx }
  setMCtx { mctx with decls }
/-- Monad of the traversal in `PatternAbstract`: a reader holding the free variables introduced
for the binders entered so far (innermost first), a `PatternProgress` state, and a `Nat` state
counting the pattern matches found so far (used to select occurrences). -/
private abbrev M := ReaderT (List FVarId) StateRefT PatternProgress StateRefT Nat MetaM
/--
Find all occurrences of the pattern `p` in `e`, abstracting the locations of this pattern,
also allowing matches that contain bound variables of `e`. Candidate subterms are compared
with `p` using `isDefEq` at reducible transparency; only the occurrences selected by `occs`
are abstracted.

The bound variables of the match are replaced by free variables, which are recorded in the
field `.fvarDecls` (innermost binder first). These are exactly the variables introduced by
viewing the returned expression at position `.pos`.

Returns `none` if no selected occurrence is found. Otherwise returns the abstraction result
together with the instantiated pattern, which may mention the free variables in `.fvarDecls`.
-/
partial def PatternAbstract (e : Expr) (p : AbstractMVarsResult) (occs : Occurrences := .all) : MetaM (Option (PatternAbstractResult × Expr)) := do
  let e ← instantiateMVars e
  withNewMCtxDepth do
  withReducible do
  let (mvars, _, p) ← openAbstractMVarsResult p
  let mvarIds := mvars.map Expr.mvarId!
  if p.isFVar && occs == Occurrences.all then
    -- Fast path: a free-variable pattern can be abstracted syntactically.
    -- Report `none` when it does not occur, consistently with the general path below.
    if e.containsFVar p.fvarId! then
      return some ({ expr := e.abstract #[p] }, p)
    else
      return none
  else
    let pHeadIdx := p.toHeadIndex
    let pNumArgs := p.headNumArgs
    -- Reader: free variables introduced for the binders entered so far, innermost first.
    -- States: the match progress, and the index of the next matching occurrence.
    let rec visit (pos : SubExpr.Pos) (e : Expr) : M Expr := do
      -- Visit the body `b` of a binder named `n` with domain `d`, replacing its bound
      -- variable by a fresh free variable `fvar`, and abstracting `fvar` again afterwards.
      let introFVar (pos : SubExpr.Pos) (n : Name) (d b : Expr) : M Expr :=
        withLocalDeclD n d fun fvar =>
        withReader (fvar.fvarId! :: ·) do
          if (← get) matches PatternProgress.noMatch then
            -- The pattern is not instantiated yet: allow its metavariables to be
            -- assigned terms mentioning `fvar`.
            updateMVarLCtxs mvarIds
            let e ← visit pos (b.instantiate1 fvar)
            match ← get with
            | .noMatch =>
              -- Nothing was abstracted inside, so the body is unchanged.
              return b
            | .someMatch pattern =>
              -- This is the innermost binder whose variable occurs in the instantiated
              -- pattern, so `pos` is the position closest to the root at which all of the
              -- pattern's bound variables are in scope: record it and stop searching.
              if pattern.containsFVar fvar.fvarId! then
                let fvarDecls ← liftM $ (← read).mapM FVarId.getDecl
                set (PatternProgress.finished pattern pos fvarDecls)
              return e.abstract #[fvar]
            | .finished .. =>
              return e.abstract #[fvar]
          else
            let e ← visit pos (b.instantiate1 fvar)
            return e.abstract #[fvar]
      let visitChildren : Unit → M Expr := fun _ => do
        match e with
        | .app f a         => return e.updateApp! (← visit pos.pushAppFn f) (← visit pos.pushAppArg a)
        | .mdata _ b       => return e.updateMData! (← visit pos b)
        | .proj _ _ b      => return e.updateProj! (← visit pos.pushProj b)
        | .letE n t v b _  => return e.updateLet! (← visit pos.pushLetVarType t) (← visit pos.pushLetValue v) (← introFVar pos.pushLetBody n t b)
        | .lam n d b _     => return e.updateLambdaE! (← visit pos.pushBindingDomain d) (← introFVar pos.pushBindingBody n d b)
        | .forallE n d b _ => return e.updateForallE! (← visit pos.pushBindingDomain d) (← introFVar pos.pushBindingBody n d b)
        | e                => return e
      let progress ← get
      if progress matches .finished .. then
        return e
      else if e.toHeadIndex != pHeadIdx || e.headNumArgs != pNumArgs then
        visitChildren ()
      else
        -- Save the metavariable context so that it can be rolled back
        -- when this occurrence is not selected by `occs`.
        let mctx ← getMCtx
        if ← isDefEq e p then
          let i ← getThe Nat
          set (i+1)
          if occs.contains i then
            -- Record the pattern only for occurrences that are actually abstracted:
            -- the assignments made for a skipped occurrence are rolled back below.
            if progress matches .noMatch then
              set (PatternProgress.someMatch (← instantiateMVars p))
            return .bvar (← read).length
          else
            setMCtx mctx
            visitChildren ()
        else
          visitChildren ()
    let (expr, progress) ← visit SubExpr.Pos.root e |>.run [] |>.run .noMatch |>.run' 0
    match progress with
    | .finished pattern pos fvarDecls =>
      return some ({ expr, pos, fvarDecls }, pattern)
    | .someMatch pattern =>
      return some ({ expr }, pattern)
    | .noMatch => return none
And the code for instantiating:
/- This section follows the definition of `Lean.Meta.replaceSubexpr`, but works on raw
expressions: binders on the way to the position are not instantiated, so a position under a
binder is handed the body with its loose bound variables. -/
-- The lenses below are generic over any monad `M` that can lift and control `MetaM`.
variable {M} [Monad M] [MonadLiftT MetaM M] [MonadControlT MetaM M] [MonadError M]
/-- Run `g` on the child of an expression selected by a coordinate and rebuild the expression
with the result. Coordinates: `0`/`1` select the function/argument of an application and the
domain/body of a `fun` or `∀`; `0`/`1`/`2` select the type/value/body of a `let`; `0` selects
the structure of a projection. Binder bodies are passed to `g` as-is, with loose bound variables.
Mdata nodes are transparent: the coordinate is applied to their contents.
Coordinate `3` (the type of the expression) throws, since types cannot be replaced; any other
unsupported coordinate also throws.
See also `Lean.Meta.transform`, `Lean.Meta.traverseChildren`. -/
private def lensCoordRaw (g : Expr → M Expr) : Nat → Expr → M Expr
  | 0, e@(.app fn arg)             => return e.updateApp! (← g fn) arg
  | 1, e@(.app fn arg)             => return e.updateApp! fn (← g arg)
  | 0, e@(.lam _ dom body _)       => return e.updateLambdaE! (← g dom) body
  | 1, e@(.lam _ dom body _)       => return e.updateLambdaE! dom (← g body)
  | 0, e@(.forallE _ dom body _)   => return e.updateForallE! (← g dom) body
  | 1, e@(.forallE _ dom body _)   => return e.updateForallE! dom (← g body)
  | 0, e@(.letE _ ty val body _)   => return e.updateLet! (← g ty) val body
  | 1, e@(.letE _ ty val body _)   => return e.updateLet! ty (← g val) body
  | 2, e@(.letE _ ty val body _)   => return e.updateLet! ty val (← g body)
  | 0, e@(.proj _ _ struct)        => return e.updateProj! (← g struct)
  | n, e@(.mdata _ inner)          => return e.updateMData! (← lensCoordRaw g n inner)
  | 3, _                           => throwError "Lensing on types is not supported"
  | c, e                           => throwError "Invalid coordinate {c} for {e}"
/-- Chain the coordinate lenses along `coords` (outermost coordinate first) and run `g` on the
subexpression reached at the end. With no coordinates, `g` runs on the expression itself. -/
private def lensRawAux (g : Expr → M Expr) (coords : List Nat) : Expr → M Expr :=
  coords.foldr (fun coord inner => lensCoordRaw inner coord) g
/-- Replace the subexpression of `root` at position `p` by the result of running `replace` on it.
Binders on the way are not instantiated, so `replace` may receive an expression with loose bound
variables. Throws if `p` is not a valid position in `root`, or if it points to a type. -/
def replaceSubexprRaw (p : SubExpr.Pos) (root : Expr) (replace : (subexpr : Expr) → M Expr) : M Expr :=
  let coords := p.toArray.toList
  lensRawAux replace coords root
/-- Instantiate the abstracted pattern occurrences of `p` with `e`.
The subexpression at `p.pos` is opened with the free variables of `p.fvarDecls`. Its remaining
loose bound variable (the pattern placeholder) is replaced by `e`, which may mention those free
variables. Finally, the free variables are abstracted again. -/
def PatternAbstractResult.instantiate (p : PatternAbstractResult) (e : Expr) : M Expr :=
  -- `fvarDecls` is innermost-first; `instantiateRev`/`abstract` expect outermost-first.
  let fvars : Array Expr := p.fvarDecls.toArray.reverse.map fun decl => Expr.fvar decl.fvarId
  replaceSubexprRaw p.pos p.expr fun subexpr =>
    let opened := subexpr.instantiateRev fvars
    return (opened.instantiate1 e).abstract fvars
Last updated: May 02 2025 at 03:31 UTC