import Std.Data.HashMap import Lean.Expr import Lean.Exception import Specimen.Idents import Specimen.TSyntaxCombinators open Lean Idents Elab -- Adapted from "Generating Good Generators for Inductive Relations", POPL '18 -- and the QuickChick source code -- https://github.com/QuickChick/QuickChick/blob/internal-rewrite/plugin/newUnifyQC.ml.cppo /-- An `Unknown` is a Lean `Name` correpsonding to a variable -/ abbrev Unknown := Name deriving instance Repr, BEq for Unknown /-- `Ord` instance for `Unknown`s (which just uses the `Ord` instance for their underlying string representations) -/ instance : Ord Unknown where compare u1 u2 := compare u1.toString u2.toString /-- *Ranges* represent sets of potential values that a variable can take on during generation -/ inductive Range /-- Undefined value, parameterized by a type `ty` (represented as a Lean `Expr`) - The `unknowns` we want to generate (e.g. a `Tree` in the BST example) start out with `Undef` ranges -/ | Undef (ty : Expr) /-- A `Fixed` range means the corresponding unknown's value serves as an input at runtime (e.g. `lo` and `hi` in the BST example) -/ | Fixed /-- The `Range` of an `Unknown` can be another `Unknown` (to facilitate sharing)-/ | Unknown (u : Unknown) /-- A `Range` can be a constructor `C` fully applied to a list of `Range`s -/ | Ctor (ctor : Name) (rs : List Range) | Lit (l : Literal) deriving Repr, Inhabited, BEq deriving instance Ord for Literal /-- A `Pattern` is either an unknown or a fully-applied constructor -/ inductive Pattern | UnknownPattern : Unknown -> Pattern | CtorPattern : Name -> List Pattern -> Pattern | LitPattern : Literal → Pattern deriving Repr, Inhabited, BEq, Ord /-- A *constructor expression* (`ConstructorExpr`) is either a variable (represented by its `Name`), or a constructor (identified by its `Name`) applied to some list of arguments, each of which are themselves `ConstructorExpr`s - Note: this type is isomorphic to `Pattern`, but we define a separate type to avoid confusing `ConstructorExpr` with `Pattern` since they are used in different parts of the algorithm -/ inductive ConstructorExpr | Unknown : Name → ConstructorExpr | Ctor : Name → List ConstructorExpr → ConstructorExpr | FuncApp : Name → List ConstructorExpr → ConstructorExpr | TyCtor : Name → List ConstructorExpr → ConstructorExpr /- A TyCtor is an inductive family applied to arguments. Used for instance in Prod.mk which requires two types as arguments. -/ | Lit : Literal → ConstructorExpr | CSort : Level → ConstructorExpr /-- A placeholder emitted as an anonymous hole (`_`) and re-inferred during elaboration. Used for implicit/instance-implicit arguments that cannot otherwise be represented (e.g. the `GetElem?` validity lambda in `getElem?`), which are always inferable from the explicit arguments. Unlike `Unknown`, a `Hole` is treated as a constant by the variable-collection and unification machinery — never a fresh unknown to generate. -/ | Hole : ConstructorExpr deriving Repr, BEq, Inhabited /-- Converts a `ConstructorExpr` to a Lean `Expr` -/ partial def constructorExprToExpr (ctorExpr : ConstructorExpr) : Expr := match ctorExpr with | .Unknown name => mkConst name | .FuncApp name [] => mkConst name | .Hole => -- A *dangling* metavariable (not registered in any `MetavarContext`). Fine -- here since this conversion is only used for delaboration/pretty-printing; -- generator emission renders holes as `_` (see `constructorExprToTSyntaxTerm`). .mvar ⟨`_specimenHole⟩ | .Ctor ctorName ctorArgs | .TyCtor ctorName ctorArgs | .FuncApp ctorName ctorArgs => mkAppN (mkConst ctorName) (constructorExprToExpr <$> ctorArgs.toArray) | .Lit l => .lit l | .CSort lvl => .sort lvl /-- `ToExpr` instance for `ConstructorExpr` -/ instance : ToExpr ConstructorExpr where toExpr := constructorExprToExpr toTypeExpr := mkConst ``Expr /-- Converts a `ConstructorExpr` to a Lean `TSyntax term` -/ partial def constructorExprToTSyntaxTerm (ctorExpr : ConstructorExpr) : MetaM (TSyntax `term) := match ctorExpr with | .Unknown name => `($(mkIdent name)) | .Hole => `(_) | .Ctor ctorName ctorArgs | .TyCtor ctorName ctorArgs | .FuncApp ctorName ctorArgs => do let argTerms ← ctorArgs.toArray.mapM constructorExprToTSyntaxTerm `($(mkIdent ctorName) $argTerms:term*) | .Lit l => mkLiteral l | .CSort _lvl => `(Sort _) /- Questionable choice to ignore the level of the sort, and let lean infer it. -/ /-- Converts a `Pattern` to an equivalent `ConstructorExpr` -/ partial def constructorExprOfPattern (pattern : Pattern) : ConstructorExpr := match pattern with | .UnknownPattern u => .Unknown u | .CtorPattern ctorName args => .Ctor ctorName (constructorExprOfPattern <$> args) | .LitPattern l => .Lit l /-- Converts a `ConstructorExpr` to an equivalent `Pattern` -/ partial def patternOfConstructorExpr (ctorExpr : ConstructorExpr) : Option Pattern := match ctorExpr with | .Unknown u => some $ .UnknownPattern u | .Ctor ctorName args | .TyCtor ctorName args => .CtorPattern ctorName <$> (List.mapM patternOfConstructorExpr args) | .FuncApp _ _ => none | .Lit l => some $ .LitPattern l | .CSort _lvl => none | .Hole => none /-- An `UnknownMap` maps `Unknown`s to `Range`s -/ abbrev UnknownMap := Std.HashMap Unknown Range /-- `UnifyState` stores the current state of the unification algorithm -/ structure UnifyState where /-- `constraints` is an `UnknownMap` that maps unknowns to ranges -/ constraints : UnknownMap /-- `equalities` keeps track of equalities between unknowns that need to be checked -/ equalities : Std.HashSet (Unknown × Unknown) /-- `patterns` contains a list of necessary pattern matches that need to be performed -/ patterns : List (Unknown × Pattern) /-- A set of all existing `Unknown`s -/ unknowns : Std.HashSet Unknown /-- The names of the output variables (variables to be generated) -/ outputNames : List Name /-- The types of the output variables (variables to be generated) -/ outputTypes : List Expr /-- All inputs (top-level arguments) to the generator -/ inputNames : List Name /-- The list of hypotheses in the constructor's type (excluding the constructor's conclusion) Each hypothesis is represented as a pair `(constructor name, constructor args)` i.e. a constructor name applied to a list of arguments, each of which are `ConstructorExpr`s -/ hypotheses : List (Name × List ConstructorExpr) deriving Repr /-- Initial (empty) unification state - Note that the dummy `outputNames` and `outputTypes` are never used (it will be updated when callers invoke the unification algorithm through this monad) -/ def emptyUnifyState : UnifyState := { constraints := ∅, equalities := ∅, patterns := [], unknowns := ∅, outputNames := [], outputTypes := [] inputNames := [] hypotheses := [] } --------------------------------------------------------------- -- `ToMessageData` instances for pretty-printing --------------------------------------------------------------- /-- Pretty-prints a `Range` as a `MessageData` -/ partial def toMessageDataRange (range : Range) : MessageData := match range with | .Undef tyExpr => m!"Undef {tyExpr}" | .Ctor c rs => let rendredCtorArgs := toMessageDataRange <$> rs m!"Ctor ({c} {rendredCtorArgs})" | .Unknown u => m!"Unknown {u}" | .Fixed => m!"Fixed" | .Lit l => m!"Lit {repr l}" instance : ToMessageData Range where toMessageData := toMessageDataRange /-- Pretty-prints a `Pattern` as a `MessageData` -/ partial def toMessageDataPattern (p : Pattern) : MessageData := match p with | .UnknownPattern u => m!"UnknownPattern {u}" | .CtorPattern c ps => let renderedCtorArgs := toMessageDataPattern <$> ps m!"CtorPattern ({c} {renderedCtorArgs})" | .LitPattern l => m!"LitPattern {repr l}" instance : ToMessageData Pattern where toMessageData pattern := toMessageDataPattern pattern /-- Pretty-prints a `ConstructorExpr` as a `MessageData` -/ partial def toMessageDataConstructorExpr (ctorExpr : ConstructorExpr) : MessageData := match ctorExpr with | .Unknown x => m!"Unknown {x}" | .Ctor c args => let renderedArgs := toMessageDataConstructorExpr <$> args m!"Ctor ({c} {renderedArgs})" | .FuncApp f args => let renderedArgs := toMessageDataConstructorExpr <$> args m!"FuncApp ({f} {renderedArgs})" | .Lit l => m!"Lit {repr l}" | .TyCtor c args => let renderedArgs := toMessageDataConstructorExpr <$> args m!"TyCtor ({c} {renderedArgs})" | .CSort lvl => m!"CSort {lvl}" | .Hole => m!"Hole" instance : ToMessageData ConstructorExpr where toMessageData := toMessageDataConstructorExpr instance : ToMessageData UnifyState where toMessageData unifyState := let constraints := unifyState.constraints.toList -- Sort the constraints map based on the ordering of its keys let sortedConstraints := List.mergeSort constraints (fun (u1, _) (u2, _) => Ordering.isLE (compare u1 u2)) let formattedConstraints := sortedConstraints.map $ fun (u, r) => m!"{u} ↦ {r}" let equalities := unifyState.equalities.toList.map $ fun (u1, u2) => m!"{u1} = {u2}" let patterns := unifyState.patterns.map $ fun (u, pat) => m!"{u} ≡ {pat}" let unknowns := unifyState.unknowns.toList.map $ fun u => m!"{u}" let hyps := unifyState.hypotheses.map $ fun hyp => m!"{hyp}" m!"⟨\n constraints := \n{formattedConstraints},\n equalities := {equalities},\n patterns := {patterns},\n unknowns := {unknowns}\n, hypotheses := {hyps}\n⟩" --------------------------------------------------------------- -- Unification monad (fig. 2 in Generating Good Generators) --------------------------------------------------------------- /-- `UnifyM` is a monad for unification + code generation. Note that the definition of `UnifyM` (after unfolding) is: ``` UnifyM α := UnifyState → MetaM (Option (α × UnifyState)) ``` -/ abbrev UnifyM (α : Type) := StateT UnifyState (OptionT MetaM) α namespace UnifyM /-- `update u r` sets the range of the unknown `u` to be `r` -/ def update (u : Unknown) (r : Range) : UnifyM Unit := do modify $ fun s => let k := s.constraints { s with constraints := k.insert u r } /-- Applies a function `f` to the `constraints` map stored in `UnifyState` - This function allows us to fetch the `constraints` map without needing a seperate `get` call inside the State monad -/ def withConstraints (f : UnknownMap → UnifyM α) : UnifyM α := do let unifyState ← get f unifyState.constraints /-- Applies a function `f` to the list of `hypotheses` stored in `UnifyState` - This function allows us to fetch the `hypotheses` field without needing a seperate `get` call inside the State monad -/ def withHypotheses (f : List (Name × List ConstructorExpr) → UnifyM α) : UnifyM α := do let unifyState ← get f unifyState.hypotheses /-- Applies a function `f` to the set of `unknowns` stored in `UnifyState` - This function allows us to fetch the `hypotheses` field without needing a seperate `get` call inside the State monad -/ def withUnknowns (f : Std.HashSet Unknown → UnifyM α) : UnifyM α := do let unifyState ← get f unifyState.unknowns /-- Directly applies a function `f` to the `constraint` map in the state -/ def modifyConstraints (f : UnknownMap → UnknownMap) : UnifyM Unit := modify $ fun s => { s with constraints := f s.constraints } /-- Batches multiple constraint updates together for performance -/ def updateMany (updates : List (Unknown × Range)) : UnifyM Unit := modifyConstraints $ fun constraints => updates.foldl (fun acc (u, r) => acc.insert u r) constraints /-- Registers a new equality check between unknowns `u1` and `u2` -/ def registerEquality (u1 : Unknown) (u2 : Unknown) : UnifyM Unit := modify $ fun s => let eqs := s.equalities { s with equalities := eqs.union {(u1, u2)} } /-- Adds a new pattern match -/ def addPattern (u : Unknown) (p : Pattern) : UnifyM Unit := modify $ fun s => let ps := s.patterns { s with patterns := (u, p) :: ps } /-- Produces a fresh unknown that is guaranteed not to be in the existing set of `unknowns` -/ def freshUnknown (unknowns : Std.HashSet Unknown) : Unknown := let existingNames := unknowns.toArray genFreshName existingNames `u /-- Produces and registers a new unknown in the `UnifyState` -/ def registerFreshUnknown : UnifyM Unknown := do modifyGet $ fun s => let us := s.unknowns let u := freshUnknown us (u, { s with unknowns := us.union { u } }) /-- Inserts an unknown `u` into the set of existing `unknowns` in `UnifyState` -/ def insertUnknown (u : Unknown) : UnifyM Unit := do modify $ fun s => { s with unknowns := s.unknowns.insert u} /-- Runs a `UnifyM` computation and discards the resulting state, with the result returned in the `MetaM` monad as an `Option` -/ def runInMetaM (action : UnifyM α) (st : UnifyState) : MetaM (Option α) := do OptionT.run (StateT.run' action st) /-- Runs a `UnifyM α` to something in `MetaM`, keeping the unification state. -/ def runUnifyM (action : UnifyM α) (st : UnifyState) : MetaM (Option (α × UnifyState)) := do OptionT.run (StateT.run action st) /-- Finds the `Range` corresponding to an `Unknown` `u` in the `UnknownMap` `k`, returning an informative error message if `u ∉ k.keys` -/ def findCorrespondingRange (k : UnknownMap) (u : Unknown) : UnifyM Range := match k[u]? with | some r => return r | none => throwError m!"unable to find unknown {u} in UnknownMap {k.toList}" /-- Determines if an unknown `u` has a `Range` of `Undef τ` for some type `τ` in the constraints map -/ def hasUndefRange (u : Unknown) : UnifyM Bool := do UnifyM.withConstraints (fun k => do let r ← findCorrespondingRange k u match r with | .Undef _ => return true | _ => return false) /-- Determines whether a `Range` is fixed with respect to the constraint map `k` -/ partial def hasFixedRange (k : UnknownMap) (r : Range) : UnifyM Bool := do match r with | .Undef _ => return false | .Fixed => return true | .Unknown u => do let range ← findCorrespondingRange k u hasFixedRange k range | .Ctor _ rs => rs.allM (hasFixedRange k) | .Lit _ => return true /-- `findUnknownsWithUndefRanges unknowns` returns all `u ∈ unknowns` such that `u ↦ Undef τ` for some type `τ` in the `constraints` map stored within `UnifyState` -/ def findUnknownsWithUndefRanges (unknowns : List Unknown) : UnifyM (List Unknown) := do let state ← get let k := state.constraints List.foldlM (fun acc u => do let r ← findCorrespondingRange k u match r with | .Undef _ => return (u :: acc) | _ => return acc) [] unknowns /-- Updates the `constraint` map so that for each `u ∈ unknowns`, we have the binding `u ↦ Fixed` in the `UnknownMap` `constraints` - Note: this doesn't handle chains of `Unknown`s in `constraints` -/ def fixRanges (unknowns : List Unknown) : UnifyM Unit := do updateMany (unknowns.map (fun u => (u, .Fixed))) /-- `fixRangeHandleUnknownChains u` updates the `constraint` map so that we have the binding `u ↦ Fixed` in the `UnknownMap` `constraints` - This handles chains of `Unknown`s in the `UnknownMap`, i.e. cases where `u ↦ Unknown u'` where `u'` is another key in the `UnknownMap` (in this case, we recursively fix the range corresponding to `u'`) - This function corresponds to `fixVariable` / `fixRange` in the QuickChick codebase -/ partial def fixRangeHandleUnknownChains (u : Unknown) : UnifyM Unit := do let s ← get let k := s.constraints let r ← findCorrespondingRange k u fixRange u r where /-- `fixRange u r` updates the binding for `u` `UnknownMap` so that corresponding range is `Fixed` if `r` is not already `Fixed` -/ fixRange (u : Unknown) (r : Range) : UnifyM Unit := match r with | .Fixed => return () | .Undef _ => update u .Fixed | .Unknown u' => do let s ← get let r' ← findCorrespondingRange s.constraints u' fixRange u' r' | .Ctor _ rs => for range in rs do fixRange `unusedParameter range | .Lit _ => return () /--`findCanonicalUnknown k u` finds the *canonical* representation of the unknown `u` based on the `ConstraintMap` `k`. Specifically: - If `u ↦ Unknown u'` in `k`, then we recursively look up the canonical representation of `u'` by traversing the unification graph formed by the `constraints` map in `UnifyState` - If `u ↦ r` (where `r` is any `Range` that is not some `Unknown`), then `u` is its own canonical representation - If `u ∉ k`, then we just return `u` as is. - This function is used to handle cases in `constraints` where an unknown maps to another unknown. - Note: this function corresponds to `correct_var` in the QuickChick code. -/ partial def findCanonicalUnknown (k : UnknownMap) (u : Unknown) : UnifyM Unknown := try (do let r ← UnifyM.findCorrespondingRange k u match r with | .Unknown u' => findCanonicalUnknown k u' | _ => return u) catch _ => return u /-- `updateConstructorArg k ctorArg` uses the `UnknownMap` `k` to rewrite any unknowns that appear in the `ConstructorExpr` `ctorArg`, substituting each `Unknown` with its canonical representation (determined by calling `updateUnknown`) - See `updateHypothesesWithUnificationResult` for an example of how this function is used. - Note: this function corresponds to `correct_rocq_constr` in the QuickChick code. -/ partial def updateConstructorArg (k : UnknownMap) (ctorArg : ConstructorExpr) : UnifyM ConstructorExpr := do match ctorArg with | .Unknown arg => let canonicalUnknown ← findCanonicalUnknown k arg if arg != canonicalUnknown then return (.Unknown canonicalUnknown) else return (.Unknown arg) | .Ctor ctorName args => return .Ctor ctorName (← args.mapM $ updateConstructorArg k) | .TyCtor ctorName args => return .TyCtor ctorName (← args.mapM $ updateConstructorArg k) | .FuncApp ctorName args => return .FuncApp ctorName (← args.mapM $ updateConstructorArg k) | .Lit l => return .Lit l | .CSort lvl => return .CSort lvl | .Hole => return .Hole /-- `updatePattern k p` uses the `UnknownMap` `k` to rewrite any unknowns that appear in the `Pattern` `p`, substituting each `Unknown` with its canonical representation (determined by calling `updateUnknown`) - Note: this function corresponds to `correct_pat` in the QuickChick code -/ partial def updatePattern (k : UnknownMap) (p : Pattern) : UnifyM Pattern := do match p with | .UnknownPattern u => return .UnknownPattern (← findCanonicalUnknown k u) | .CtorPattern c args => do let updatedArgs ← args.mapM (updatePattern k) return .CtorPattern c updatedArgs | .LitPattern l => return .LitPattern l /-- Extends the current state with the contents of the fields in `newState` (Note that the `outputNames` and `outputTypes` of `newState` are used) -/ def extendState (newState : UnifyState) : UnifyM Unit := do modify $ fun s => { s with constraints := s.constraints.union newState.constraints equalities := s.equalities.union newState.equalities patterns := s.patterns ++ newState.patterns unknowns := s.unknowns.union newState.unknowns outputNames := newState.outputNames outputTypes := newState.outputTypes inputNames := s.inputNames ++ newState.inputNames hypotheses := s.hypotheses ++ newState.hypotheses } /-- Accumulates all the `Unknown`s in a `ConstructorExpr` -/ def collectUnknownsInConstructorExpr (ctorExpr : ConstructorExpr) : List Unknown := match ctorExpr with | .Unknown u => [u] | .Ctor c args | .TyCtor c args | .FuncApp c args => c :: List.flatMap collectUnknownsInConstructorExpr args | .Lit _ => [] | .CSort _ => [] | .Hole => [] mutual /-- Evaluates a `Range`, returning a `ConstructorExpr`. Note that if the `Range` is `Fixed` or `Undef`, we return `none` (via `failure`). -/ partial def evaluateRange (r : Range) : UnifyM ConstructorExpr := do match r with | .Ctor c args => do let args' ← List.mapM evaluateRange args return (ConstructorExpr.Ctor c args') | .Unknown u => evaluateUnknown u | .Lit l => return .Lit l | .Fixed | .Undef _ => throwError m!"unable to evaluate range {r}" /-- Evaluates an `Unknown` based on the bindings in the `UnknownMap`, returning a `ConstructorExpr`. Precondition: there must not be any cycles of `Unknown`s in the `UnknownMap`. -/ partial def evaluateUnknown (v : Unknown) : UnifyM ConstructorExpr := do withConstraints $ fun k => do let r ← findCorrespondingRange k v match r with | .Undef _ | .Fixed => return ConstructorExpr.Unknown v | .Unknown u => (evaluateUnknown u) <|> (return ConstructorExpr.Unknown v) | .Ctor c args => do let args' ← args.mapM evaluateRange return (ConstructorExpr.Ctor c args') | .Lit l => return .Lit l end /-- Determines whether a `Range` is `Fixed`. If the `Range` is in the form `Unknown u`, we check if the range corresponding to `u` in the `UnknownMap` is `Fixed` (this handles chains of `Unknowns` in the `UnknownMap`) -/ partial def isRangeFixed (r : Range) : UnifyM Bool := match r with | .Fixed => return true | .Undef _ => return false | .Unknown u => do withConstraints $ fun k => do match k[u]? with | none => return false | some r' => isRangeFixed r' | .Ctor _ args => List.allM isRangeFixed args | .Lit _ => return true /-- Determines if an `Unknown` has a `Fixed` `Range` in the `UnknownMap` (this handles chains of `Unknowns` in the `UnknownMap`) -/ def isUnknownFixed (u : Unknown) : UnifyM Bool := isRangeFixed (.Unknown u) end UnifyM ------------------------------------------------------------------ -- Unification algorithm (fig. 3 of Generating Good Generators) ------------------------------------------------------------------ mutual /-- Top-level unification function which unifies the ranges mapped to by two unknowns -/ partial def unify (range1 : Range) (range2 : Range) : UnifyM Unit := do match range1, range2 with | .Unknown u1, .Unknown u2 => if u1 == u2 then return () else UnifyM.withConstraints $ fun k => do let r1 ← UnifyM.findCorrespondingRange k u1 let r2 ← UnifyM.findCorrespondingRange k u2 unifyR (u1, r1) (u2, r2) | c1@(.Ctor _ _), c2@(.Ctor _ _) => unifyC c1 c2 | .Unknown u1, c2@(.Ctor _ _) => UnifyM.withConstraints $ fun k => do let r1 ← UnifyM.findCorrespondingRange k u1 unifyRC (u1, r1) c2 | c1@(.Ctor _ _), .Unknown u2 => UnifyM.withConstraints $ fun k => do let r2 ← UnifyM.findCorrespondingRange k u2 unifyRC (u2, r2) c1 | .Lit l, .Lit l' => unifyL l l' | .Unknown u1, .Lit l => UnifyM.withConstraints $ fun k => do let r1 ← UnifyM.findCorrespondingRange k u1 unifyRC (u1, r1) (.Lit l) | .Lit l, .Unknown u2 => UnifyM.withConstraints $ fun k => do let r2 ← UnifyM.findCorrespondingRange k u2 unifyRC (u2, r2) (.Lit l) | r1, r2 => throwError "Unable to unify {r1} with {r2}" /-- Takes two `(Unknown, Range)` pairs & unifies them based on their `Range`s -/ partial def unifyR : Unknown × Range → Unknown × Range → UnifyM Unit -- If the range of an unknown (e.g. `u1`) is undefined, -- we update `u1` to point to the range of `u2` | (u1, .Undef _), (u2, _) => UnifyM.update u1 (.Unknown u2) | (u1, _), (u2, .Undef _) => UnifyM.update u2 (.Unknown u1) | (_, u1'@(.Unknown _)), (u2, _) => unify u1' (.Unknown u2) | (u1, _), (_, u2'@(.Unknown _)) => unify (.Unknown u1) u2' | (_, c1@(.Ctor _ _)), (_, c2@(.Ctor _ _)) => unifyC c1 c2 | (u1, .Fixed), (u2, .Fixed) => do -- Assert that whatever the values of `u1` and `u2` are, they are equal -- Record this equality check using `equality`, then update `u1`'s range to the other UnifyM.registerEquality u1 u2 UnifyM.update u1 (.Unknown u2) | (u1, .Fixed), (_, c2@(.Ctor _ _)) => handleMatch u1 c2 | (_, c1@(.Ctor _ _)), (u2, .Fixed) => handleMatch u2 c1 | (_u1, .Lit l), (_u2, .Lit l') => unifyL l l' | (u, .Fixed), (_u', .Lit l) => handleMatch u (.Lit l) | (_, .Lit l), (u, .Fixed) => handleMatch u (.Lit l) | (_u1, .Lit l), (_u2, c@(.Ctor _ _)) => throwError m!"unifyC: unable to unify literal {repr l} with constructor {c}" | (_u1, c@(.Ctor _ _)), (_u2, .Lit l) => throwError m!"unifyC: unable to unify constructor {c} with literal {repr l}" partial def unifyL (l l' : Literal) : UnifyM Unit := if l == l' then return () else do let st ← get let constraints := st.constraints throwError m!"unifyC: unable to unify {repr $ Range.Lit l} with {repr $ Range.Lit l'}, constraints = {constraints.toList}, literals not equal." /-- Unifies two `Range`s that are both constructors -/ partial def unifyC (r1 : Range) (r2 : Range) : UnifyM Unit := match r1, r2 with | .Ctor c1 rs1, .Ctor c2 rs2 => -- Recursively unify each of the constructor arguments -- Invariant: all ranges that appear as constructor args contain only constructors & unknowns if c1 == c2 && rs1.length == rs2.length then for (r1, r2) in (List.zip rs1 rs2) do unify r1 r2 else do let st ← get let constraints := st.constraints throwError m!"unifyC: unable to unify {r1} with {r2}, constraints = {constraints.toList}" | _, _ => throwError m!"unifyC: unable to unify r1 = {r1}, r2 = {r2} which are not both constructors" /-- Unifies an `(Unknown, Range)` pair with a `Range` -/ partial def unifyRC : Unknown × Range → Range → UnifyM Unit | (u1, .Undef _), c2@(.Ctor _ _) | (u1, .Undef _), c2@(.Lit _) => UnifyM.update u1 c2 | (_, .Unknown u'), c2@(.Ctor _ _) | (_, .Unknown u'), c2@(.Lit _) => UnifyM.withConstraints $ fun k => do let r ← UnifyM.findCorrespondingRange k u' unifyRC (u', r) c2 | (u, .Fixed), c2@(.Ctor _ _) | (u, .Fixed), c2@(.Lit _) => handleMatch u c2 | (_, c1@(.Ctor _ _)), c2@(.Ctor _ _) => unifyC c1 c2 | (_, .Lit l), .Lit l' => unifyL l l' | (u, r), r' => throwError m!"unifyRC called, unable to unify (unknown {u}, range {r}) with range {r'}" /-- Corresponds to `match` in the pseudocode (we call this `handleMatch` since `match` is a reserved keyword in Lean) -/ partial def handleMatch (unknown : Unknown) (range : Range) : UnifyM Unit := do match unknown, range with | u, .Ctor c rs => do let ps ← rs.mapM matchAux UnifyM.addPattern u (Pattern.CtorPattern c ps) | u, .Lit l => UnifyM.addPattern u (Pattern.LitPattern l) | u, r => throwError m!"handleMatch called, unable to match unknown {u} with range {r} which is not in constructor form" /-- `matchAux` traverses a `Range` and converts it into a pattern which can be used in a `match` expression -/ partial def matchAux (range : Range) : UnifyM Pattern := do match range with | .Ctor c rs => do -- Recursively handle ranges let ps ← rs.mapM matchAux return .CtorPattern c ps | .Unknown u => UnifyM.withConstraints $ fun k => do let r ← UnifyM.findCorrespondingRange k u match r with | .Undef _ => do -- Unknown becomes a pattern variable (bound by the pattern match) -- (i.e. the unknown serves as an input at runtime) -- We update `u`'s range to be `fixed` -- (since we're extracting information out of the scrutinee) UnifyM.update u .Fixed return .UnknownPattern u | .Fixed => do -- Handles non-linear patterns: -- produce a fresh unknown `u'`, use `u'` as the pattern variable -- & enforce an equality check between `u` and `u'` let u' ← UnifyM.registerFreshUnknown UnifyM.registerEquality u' u UnifyM.update u' r return .UnknownPattern u' | u'@(.Unknown _) => matchAux u' | .Ctor c rs => do let ps ← rs.mapM matchAux return .CtorPattern c ps | .Lit l => return .LitPattern l | .Lit l => return .LitPattern l | _ => throwError m!"matchAux called with unexpected range {range}" end