continued work on pattern matching v2
This commit is contained in:
parent
c3ea343d00
commit
9cd2cdb511
3 changed files with 279 additions and 60 deletions
|
|
@ -8,13 +8,16 @@ import Control.Monad.Except
|
|||
import Control.Monad.Reader
|
||||
import Control.Monad.State
|
||||
import Data.Foldable (traverse_)
|
||||
import Data.Function (on)
|
||||
import Data.Functor.Identity (runIdentity)
|
||||
import Data.List (foldl')
|
||||
import Data.List.Extra (allSame)
|
||||
import Data.Map (Map)
|
||||
import Data.Map qualified as M
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.Set (Set)
|
||||
import Data.Set qualified as S
|
||||
import Data.Tree (flatten)
|
||||
import Debug.Trace (trace)
|
||||
import Grammar.Abs
|
||||
import Grammar.Print (printTree)
|
||||
|
|
@ -204,9 +207,9 @@ algoW = \case
|
|||
-- \| ------------------
|
||||
-- \| Γ ⊢ i : Int, ∅
|
||||
|
||||
ELit (LInt n) ->
|
||||
return (nullSubst, TMono "Int", T.ELit (TMono "Int") (LInt n))
|
||||
ELit a -> error $ "NOT IMPLEMENTED YET: ELit " ++ show a
|
||||
ELit lit ->
|
||||
let lt = litType lit
|
||||
in return (nullSubst, lt, T.ELit lt lit)
|
||||
-- \| x : σ ∈ Γ τ = inst(σ)
|
||||
-- \| ----------------------
|
||||
-- \| Γ ⊢ x : τ, ∅
|
||||
|
|
@ -289,16 +292,14 @@ algoW = \case
|
|||
(s2, t2, e1') <- algoW e1
|
||||
let composition = s2 `compose` s1
|
||||
return (composition, t2, apply composition $ T.ELet (T.Bind (name, t2) e0') e1')
|
||||
|
||||
-- TODO: give caseExpr a concrete type before proceeding
|
||||
-- probably by returning substitutions in the functions used in this body
|
||||
ECase caseExpr injs -> do
|
||||
(_, t0, e0') <- algoW caseExpr
|
||||
(injs', ts) <- mapAndUnzipM (checkInj t0) injs
|
||||
case ts of
|
||||
[] -> throwError "Case expression missing any matches"
|
||||
ts -> do
|
||||
unified <- zipWithM unify ts (tail ts)
|
||||
let composition = foldl' compose mempty unified
|
||||
let typ = apply composition (head ts)
|
||||
return (composition, typ, apply composition $ T.ECase typ e0' injs')
|
||||
(sub, _, e') <- algoW caseExpr
|
||||
trace ("SUB: " ++ show sub) return ()
|
||||
t <- checkCase caseExpr injs
|
||||
return (sub, t, T.ECase t e' (map (\(Inj i _) -> T.Inj (i, t) e') injs))
|
||||
|
||||
-- | Unify two types producing a new substitution
|
||||
unify :: Type -> Type -> Infer Subst
|
||||
|
|
@ -312,7 +313,6 @@ unify t0 t1 = do
|
|||
(a, TPol b) -> occurs b a
|
||||
(TMono a, TMono b) ->
|
||||
if a == b then return M.empty else throwError "Types do not unify"
|
||||
-- \| TODO: Figure out a cleaner way to express the same thing
|
||||
(TConstr (Constr name t), TConstr (Constr name' t')) ->
|
||||
if name == name' && length t == length t'
|
||||
then do
|
||||
|
|
@ -464,52 +464,45 @@ insertConstr i t =
|
|||
|
||||
-------- PATTERN MATCHING ---------
|
||||
|
||||
-- "case expr of", the type of 'expr' is caseType
|
||||
checkInj :: Type -> Inj -> Infer (T.Inj, Type)
|
||||
checkInj caseType (Inj it expr) = do
|
||||
(args, t') <- initType caseType it
|
||||
subst <- unify caseType t'
|
||||
applySt subst $ do
|
||||
(_, t, e') <- local (\st -> st{vars = args `M.union` vars st}) (algoW expr)
|
||||
return (T.Inj (it, t') e', t)
|
||||
unifyAll :: [Type] -> Infer [Subst]
|
||||
unifyAll [] = return []
|
||||
unifyAll [_] = return []
|
||||
unifyAll (x : y : xs) = do
|
||||
uni <- unify x y
|
||||
all <- unifyAll (y : xs)
|
||||
return $ uni : all
|
||||
|
||||
initType :: Type -> Init -> Infer (Map Ident Poly, Type)
|
||||
initType expected = \case
|
||||
InitLit lit -> error "Pattern match on literals not implemented yet"
|
||||
InitConstr c args -> do
|
||||
st <- gets constructors
|
||||
case M.lookup c st of
|
||||
Nothing ->
|
||||
throwError $
|
||||
unwords
|
||||
[ "Constructor:"
|
||||
, printTree c
|
||||
, "does not exist"
|
||||
]
|
||||
Just t -> do
|
||||
let flat = flattenType t
|
||||
let returnType = last flat
|
||||
case ( length (init flat) == length args
|
||||
, returnType `isMoreSpecificOrEq` expected
|
||||
) of
|
||||
(True, True) ->
|
||||
return
|
||||
( M.fromList $ zip args (map (Forall []) flat)
|
||||
, expected
|
||||
)
|
||||
(False, _) ->
|
||||
throwError $
|
||||
"Can't partially match on the constructor: "
|
||||
++ printTree c
|
||||
(_, False) ->
|
||||
throwError $
|
||||
unwords
|
||||
[ "Inferred type"
|
||||
, printTree returnType
|
||||
, "does not match expected type:"
|
||||
, printTree expected
|
||||
]
|
||||
InitCatch -> return (mempty, expected)
|
||||
checkCase :: Exp -> [Inj] -> Infer Type
|
||||
checkCase e injs = do
|
||||
expT <- fst <$> inferExp e
|
||||
(injTs, returns) <- mapAndUnzipM checkInj injs
|
||||
unifyAll (expT : injTs)
|
||||
subst <- foldl1 compose <$> zipWithM unify returns (tail returns)
|
||||
let substed = map (apply subst) returns
|
||||
unless (allSame substed || null substed) (throwError "Different return types of case, or no cases")
|
||||
return $ head substed
|
||||
|
||||
{- | fst = type of init
|
||||
| snd = type of expr
|
||||
-}
|
||||
checkInj :: Inj -> Infer (Type, Type)
|
||||
checkInj (Inj it expr) = do
|
||||
initT <- inferInit it
|
||||
(exprT, _) <- inferExp expr
|
||||
return (initT, exprT)
|
||||
|
||||
inferInit :: Init -> Infer Type
|
||||
inferInit = \case
|
||||
InitLit lit -> return $ litType lit
|
||||
InitConstr fn vars -> do
|
||||
gets (M.lookup fn . constructors) >>= \case
|
||||
Nothing -> throwError $ "Constructor: " ++ printTree fn ++ " does not exist"
|
||||
Just a -> do
|
||||
let ft = init $ flattenType a
|
||||
case compare (length vars) (length ft) of
|
||||
EQ -> return . last $ flattenType a
|
||||
_ -> throwError "Partial pattern match not allowed"
|
||||
InitCatch -> fresh
|
||||
|
||||
flattenType :: Type -> [Type]
|
||||
flattenType (TArr a b) = flattenType a ++ flattenType b
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue