Fixed errors in tc hm

This commit is contained in:
sebastianselander 2023-03-27 16:48:23 +02:00
parent 847ec37117
commit 6e54378327
4 changed files with 346 additions and 345 deletions

View file

@ -1,31 +1,29 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
-- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeCheckerHm where
import Auxiliary
import Control.Monad.Except
import Control.Monad.Identity (runIdentity)
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (second)
import Data.Coerce (coerce)
import Data.Foldable (traverse_)
import Data.Function (on)
import Data.List (foldl')
import Data.List.Extra (unsnoc)
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe (fromJust)
import Data.Set (Set)
import qualified Data.Set as S
import Debug.Trace (trace)
import Grammar.Abs
import Grammar.Print (printTree)
import qualified TypeChecker.TypeCheckerIr as T
import TypeChecker.TypeCheckerIr (Ctx (..), Env (..), Error, Infer,
Subst)
import Auxiliary
import Control.Monad.Except
import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (second)
import Data.Coerce (coerce)
import Data.Foldable (traverse_)
import Data.Function (on)
import Data.List (foldl')
import Data.List.Extra (unsnoc)
import Data.Map (Map)
import Data.Map qualified as M
import Data.Maybe (fromJust)
import Data.Set (Set)
import Data.Set qualified as S
import Data.String
import Grammar.Abs
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr qualified as T
initCtx = Ctx mempty
initEnv = Env 0 'a' mempty mempty mempty
@ -39,7 +37,7 @@ run = runC initEnv initCtx
runC :: Env -> Ctx -> Infer a -> Either Error a
runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e
typecheck :: Program -> Either Error T.Program
typecheck :: Program -> Either Error (T.Program' Type)
typecheck = run . checkPrg
checkData :: Data -> Infer ()
@ -50,9 +48,9 @@ checkData d = do
(all isPoly ts)
(throwError $ unwords ["Data type incorrectly declared"])
traverse_
( \(Constructor name' t') ->
( \(Inj name' t') ->
if typ == retType t'
then insertConstr (coerce name') (toNew t')
then insertConstr (coerce name') (t')
else
throwError $
unwords
@ -73,9 +71,9 @@ checkData d = do
retType :: Type -> Type
retType (TFun _ t2) = retType t2
retType a = a
retType a = a
checkPrg :: Program -> Infer T.Program
checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = do
preRun bs
bs' <- checkDef bs
@ -94,25 +92,27 @@ preRun (x : xs) = case x of
<> printTree n
<> "'"
)
insertSig (coerce n) (Just $ toNew t) >> preRun xs
insertSig (coerce n) (Just $ t) >> preRun xs
DBind (Bind n _ e) -> do
collect (collectTypeVars e)
s <- gets sigs
case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs
Just _ -> preRun xs
DData d@(Data t _) -> collect (collectTypeVars t) >> checkData d >> preRun xs
checkDef :: [Def] -> Infer [T.Def]
checkDef :: [Def] -> Infer [T.Def' Type]
checkDef [] = return []
checkDef (x : xs) = case x of
(DBind b) -> do
b' <- checkBind b
fmap (T.DBind b' :) (checkDef xs)
(DData d) -> fmap (T.DData (toNew d) :) (checkDef xs)
(DData d) -> fmap ((T.DData (coerceData d)) :) (checkDef xs)
(DSig _) -> checkDef xs
where
coerceData (Data t injs) = T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
checkBind :: Bind -> Infer T.Bind
checkBind :: Bind -> Infer (T.Bind' Type)
checkBind (Bind name args e) = do
let lambda = makeLambda e (reverse (coerce args))
e@(_, args_t) <- inferExp lambda
@ -133,41 +133,41 @@ checkBind (Bind name args e) = do
insertSig (coerce name) (Just args_t)
return (T.Bind (coerce name, args_t) [] e)
typeEq :: T.Type -> T.Type -> Bool
typeEq (T.TFun l r) (T.TFun l' r') = typeEq l l' && typeEq r r'
typeEq (T.TLit a) (T.TLit b) = a == b
typeEq (T.TData name a) (T.TData name' b) =
typeEq :: Type -> Type -> Bool
typeEq (TFun l r) (TFun l' r') = typeEq l l' && typeEq r r'
typeEq (TLit a) (TLit b) = a == b
typeEq (TData name a) (TData name' b) =
length a == length b
&& name == name'
&& and (zipWith typeEq a b)
typeEq (T.TAll _ t1) t2 = t1 `typeEq` t2
typeEq t1 (T.TAll _ t2) = t1 `typeEq` t2
typeEq (T.TVar _) (T.TVar _) = True
typeEq (TAll _ t1) t2 = t1 `typeEq` t2
typeEq t1 (TAll _ t2) = t1 `typeEq` t2
typeEq (TVar _) (TVar _) = True
typeEq _ _ = False
skolem :: T.Type -> T.Type
skolem (T.TVar (T.MkTVar a)) = T.TLit a
skolem (T.TAll x t) = T.TAll x (skolem t)
skolem (T.TFun t1 t2) = (T.TFun `on` skolem) t1 t2
skolem t = t
skolem :: Type -> Type
skolem (TVar (T.MkTVar a)) = TLit (coerce a)
skolem (TAll x t) = TAll x (skolem t)
skolem (TFun t1 t2) = (TFun `on` skolem) t1 t2
skolem t = t
isMoreSpecificOrEq :: T.Type -> T.Type -> Bool
isMoreSpecificOrEq t1 (T.TAll _ t2) = isMoreSpecificOrEq t1 t2
isMoreSpecificOrEq (T.TFun a b) (T.TFun c d) =
isMoreSpecificOrEq :: Type -> Type -> Bool
isMoreSpecificOrEq t1 (TAll _ t2) = isMoreSpecificOrEq t1 t2
isMoreSpecificOrEq (TFun a b) (TFun c d) =
isMoreSpecificOrEq a c && isMoreSpecificOrEq b d
isMoreSpecificOrEq (T.TData n1 ts1) (T.TData n2 ts2) =
isMoreSpecificOrEq (TData n1 ts1) (TData n2 ts2) =
n1 == n2
&& length ts1 == length ts2
&& and (zipWith isMoreSpecificOrEq ts1 ts2)
isMoreSpecificOrEq _ (T.TVar _) = True
isMoreSpecificOrEq _ (TVar _) = True
isMoreSpecificOrEq a b = a == b
isPoly :: Type -> Bool
isPoly (TAll _ _) = True
isPoly (TVar _) = True
isPoly _ = False
isPoly (TVar _) = True
isPoly _ = False
inferExp :: Exp -> Infer T.ExpT
inferExp :: Exp -> Infer (T.ExpT' Type)
inferExp e = do
(s, (e', t)) <- algoW e
let subbed = apply s t
@ -178,7 +178,7 @@ class CollectTVars a where
instance CollectTVars Exp where
collectTypeVars (EAnn e t) = collectTypeVars t `S.union` collectTypeVars e
collectTypeVars _ = S.empty
collectTypeVars _ = S.empty
instance CollectTVars Type where
collectTypeVars (TVar (MkTVar i)) = S.singleton (coerce i)
@ -190,43 +190,12 @@ instance CollectTVars Type where
collect :: Set T.Ident -> Infer ()
collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st})
class NewType a b where
toNew :: a -> b
instance NewType Type T.Type where
toNew = \case
TLit i -> T.TLit $ coerce i
TVar v -> T.TVar $ toNew v
TFun t1 t2 -> (T.TFun `on` toNew) t1 t2
TAll b t -> T.TAll (toNew b) (toNew t)
TData i ts -> T.TData (coerce i) (map toNew ts)
TEVar _ -> error "Should not exist after typechecker"
instance NewType Lit T.Lit where
toNew (LInt i) = T.LInt i
toNew (LChar i) = T.LChar i
instance NewType Data T.Data where
toNew (Data t xs) = T.Data (name $ retType t) (toNew xs)
where
name (TData n _) = coerce n
name _ = error "Bug: Data types should not be able to be typed over non type variables"
instance NewType Constructor T.Constructor where
toNew (Constructor name xs) = T.Constructor (coerce name) (toNew xs)
instance NewType TVar T.TVar where
toNew (MkTVar i) = T.MkTVar $ coerce i
instance NewType a b => NewType [a] [b] where
toNew = map toNew
algoW :: Exp -> Infer (Subst, T.ExpT)
algoW :: Exp -> Infer (Subst, (T.ExpT' Type))
algoW = \case
err@(EAnn e t) -> do
(s1, (e', t')) <- exprErr (algoW e) err
unless
(toNew t `isMoreSpecificOrEq` t')
(t `isMoreSpecificOrEq` t')
( throwError $
unwords
[ "Annotated type:"
@ -236,34 +205,34 @@ algoW = \case
]
)
applySt s1 $ do
s2 <- exprErr (unify (toNew t) t') err
s2 <- exprErr (unify (t) t') err
let comp = s2 `compose` s1
return (comp, apply comp (e', toNew t))
return (comp, apply comp (e', t))
-- \| ------------------
-- \| Γ ⊢ i : Int, ∅
ELit lit -> return (nullSubst, (T.ELit $ toNew lit, litType lit))
ELit lit -> return (nullSubst, (T.ELit $ lit, litType lit))
-- \| x : σ ∈ Γ τ = inst(σ)
-- \| ----------------------
-- \| Γ ⊢ x : τ, ∅
EVar i -> do
var <- asks vars
case M.lookup (coerce i) var of
Just t -> inst t >>= \x -> return (nullSubst, (T.EId $ coerce i, x))
Just t -> inst t >>= \x -> return (nullSubst, (T.EVar $ coerce i, x))
Nothing -> do
sig <- gets sigs
case M.lookup (coerce i) sig of
Just (Just t) -> return (nullSubst, (T.EId $ coerce i, t))
Just (Just t) -> return (nullSubst, (T.EVar $ coerce i, t))
Just Nothing -> do
fr <- fresh
insertSig (coerce i) (Just fr)
return (nullSubst, (T.EId $ coerce i, fr))
return (nullSubst, (T.EVar $ coerce i, fr))
Nothing -> throwError $ "Unbound variable: " <> printTree i
EInj i -> do
constr <- gets constructors
case M.lookup (coerce i) constr of
Just t -> return (nullSubst, (T.EId $ coerce i, t))
Just t -> return (nullSubst, (T.EVar $ coerce i, t))
Nothing ->
throwError $
"Constructor: '"
@ -280,7 +249,7 @@ algoW = \case
( withBinding (coerce name) fr $ do
(s1, (e', t')) <- exprErr (algoW e) err
let varType = apply s1 fr
let newArr = T.TFun varType t'
let newArr = TFun varType t'
return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr))
)
err
@ -314,7 +283,7 @@ algoW = \case
(s0, (e0', t0)) <- algoW e0
applySt s0 $ do
(s1, (e1', t1)) <- algoW e1
s2 <- exprErr (unify (apply s1 t0) (T.TFun t1 fr)) err
s2 <- exprErr (unify (apply s1 t0) (TFun t1 fr)) err
let t = apply s2 fr
let comp = s2 `compose` s1 `compose` s0
return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t))
@ -346,33 +315,33 @@ makeLambda :: Exp -> [T.Ident] -> Exp
makeLambda = foldl (flip (EAbs . coerce))
-- | Unify two types producing a new substitution
unify :: T.Type -> T.Type -> Infer Subst
unify :: Type -> Type -> Infer Subst
unify t0 t1 = do
case (t0, t1) of
(T.TFun a b, T.TFun c d) -> do
(TFun a b, TFun c d) -> do
s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d)
return $ s1 `compose` s2
----------- TODO: BE CAREFUL!!!! THIS IS PROBABLY WRONG!!! -----------
(T.TVar (T.MkTVar a), t@(T.TData _ _)) -> return $ M.singleton a t
(t@(T.TData _ _), T.TVar (T.MkTVar b)) -> return $ M.singleton b t
(TVar (T.MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t
(t@(TData _ _), TVar (T.MkTVar b)) -> return $ M.singleton (coerce b) t
-------------------------------------------------------------------
(T.TVar (T.MkTVar a), t) -> occurs a t
(t, T.TVar (T.MkTVar b)) -> occurs b t
(T.TAll _ t, b) -> unify t b
(a, T.TAll _ t) -> unify a t
(T.TLit a, T.TLit b) ->
(TVar (T.MkTVar a), t) -> occurs (coerce a) t
(t, TVar (T.MkTVar b)) -> occurs (coerce b) t
(TAll _ t, b) -> unify t b
(a, TAll _ t) -> unify a t
(TLit a, TLit b) ->
if a == b
then return M.empty
else
throwError
. unwords
$ [ "Can not unify"
, "'" <> printTree (T.TLit a) <> "'"
, "'" <> printTree (TLit a) <> "'"
, "with"
, "'" <> printTree (T.TLit b) <> "'"
, "'" <> printTree (TLit b) <> "'"
]
(T.TData name t, T.TData name' t') ->
(TData name t, TData name' t') ->
if name == name' && length t == length t'
then do
xs <- zipWithM unify t t'
@ -380,7 +349,7 @@ unify t0 t1 = do
else
throwError $
unwords
[ "T.Type constructor:"
[ "Type constructor:"
, printTree name
, "(" <> printTree t <> ")"
, "does not match with:"
@ -398,42 +367,42 @@ unify t0 t1 = do
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
where these are equal
-}
occurs :: T.Ident -> T.Type -> Infer Subst
occurs i t@(T.TVar _) = return (M.singleton i t)
occurs :: T.Ident -> Type -> Infer Subst
occurs i t@(TVar _) = return (M.singleton i t)
occurs i t =
if S.member i (free t)
then
throwError $
unwords
[ "Occurs check failed, can't unify"
, printTree (T.TVar $ T.MkTVar i)
, printTree (TVar $ T.MkTVar (coerce i))
, "with"
, printTree t
]
else return $ M.singleton i t
-- | Generalize a type over all free variables in the substitution set
generalize :: Map T.Ident T.Type -> T.Type -> T.Type
generalize :: Map T.Ident Type -> Type -> Type
generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t)
where
go :: [T.Ident] -> T.Type -> T.Type
go [] t = t
go (x : xs) t = T.TAll (T.MkTVar x) (go xs t)
removeForalls :: T.Type -> T.Type
removeForalls (T.TAll _ t) = removeForalls t
removeForalls (T.TFun t1 t2) = T.TFun (removeForalls t1) (removeForalls t2)
removeForalls t = t
go :: [T.Ident] -> Type -> Type
go [] t = t
go (x : xs) t = TAll (T.MkTVar (coerce x)) (go xs t)
removeForalls :: Type -> Type
removeForalls (TAll _ t) = removeForalls t
removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2)
removeForalls t = t
{- | Instantiate a polymorphic type. The free type variables are substituted
with fresh ones.
-}
inst :: T.Type -> Infer T.Type
inst :: Type -> Infer Type
inst = \case
T.TAll (T.MkTVar bound) t -> do
TAll (T.MkTVar bound) t -> do
fr <- fresh
let s = M.singleton bound fr
let s = M.singleton (coerce bound) fr
apply s <$> inst t
T.TFun t1 t2 -> T.TFun <$> inst t1 <*> inst t2
TFun t1 t2 -> TFun <$> inst t1 <*> inst t2
rest -> return rest
-- | Compose two substitution sets
@ -455,41 +424,40 @@ class FreeVars t where
-- | Get all free variables from t
free :: t -> Set T.Ident
instance FreeVars T.Type where
free :: T.Type -> Set T.Ident
free (T.TVar (T.MkTVar a)) = S.singleton a
free (T.TAll (T.MkTVar bound) t) = S.singleton bound `S.intersection` free t
free (T.TLit _) = mempty
free (T.TFun a b) = free a `S.union` free b
instance FreeVars Type where
free :: Type -> Set T.Ident
free (TVar (T.MkTVar a)) = S.singleton (coerce a)
free (TAll (T.MkTVar bound) t) = S.singleton (coerce bound) `S.intersection` free t
free (TLit _) = mempty
free (TFun a b) = free a `S.union` free b
-- \| Not guaranteed to be correct
free (T.TData _ a) =
free (TData _ a) =
foldl' (\acc x -> free x `S.union` acc) S.empty a
instance SubstType T.Type where
apply :: Subst -> T.Type -> T.Type
instance SubstType Type where
apply :: Subst -> Type -> Type
apply sub t = do
case t of
T.TLit a -> T.TLit a
T.TVar (T.MkTVar a) -> case M.lookup a sub of
Nothing -> T.TVar (T.MkTVar $ coerce a)
Just t -> t
T.TAll (T.MkTVar i) t -> case M.lookup i sub of
Nothing -> T.TAll (T.MkTVar i) (apply sub t)
Just _ -> apply sub t
T.TFun a b -> T.TFun (apply sub a) (apply sub b)
T.TData name a -> T.TData name (map (apply sub) a)
instance FreeVars (Map T.Ident T.Type) where
free :: Map T.Ident T.Type -> Set T.Ident
TLit a -> TLit a
TVar (T.MkTVar a) -> case M.lookup (coerce a) sub of
Nothing -> TVar (T.MkTVar $ coerce a)
Just t -> t
TAll (T.MkTVar i) t -> case M.lookup (coerce i) sub of
Nothing -> TAll (T.MkTVar i) (apply sub t)
Just _ -> apply sub t
TFun a b -> TFun (apply sub a) (apply sub b)
TData name a -> TData name (map (apply sub) a)
instance FreeVars (Map T.Ident Type) where
free :: Map T.Ident Type -> Set T.Ident
free m = foldl' S.union S.empty (map free $ M.elems m)
instance SubstType (Map T.Ident T.Type) where
apply :: Subst -> Map T.Ident T.Type -> Map T.Ident T.Type
instance SubstType (Map T.Ident Type) where
apply :: Subst -> Map T.Ident Type -> Map T.Ident Type
apply s = M.map (apply s)
instance SubstType T.Exp where
apply :: Subst -> T.Exp -> T.Exp
instance SubstType (T.Exp' Type) where
apply s = \case
T.EId i -> T.EId i
T.EVar i -> T.EVar i
T.ELit lit -> T.ELit lit
T.ELet (T.Bind (ident, t1) args e1) e2 ->
T.ELet
@ -499,19 +467,18 @@ instance SubstType T.Exp where
T.EAdd e1 e2 -> T.EAdd (apply s e1) (apply s e2)
T.EAbs ident e -> T.EAbs ident (apply s e)
T.ECase e brnch -> T.ECase (apply s e) (apply s brnch)
T.EInj{} -> error "implement"
instance SubstType T.Branch where
apply :: Subst -> T.Branch -> T.Branch
instance SubstType (T.Branch' Type) where
apply s (T.Branch (i, t) e) = T.Branch (apply s i, apply s t) (apply s e)
instance SubstType T.Pattern where
apply :: Subst -> T.Pattern -> T.Pattern
instance SubstType (T.Pattern' Type) where
apply s = \case
T.PVar (iden, t) -> T.PVar (iden, apply s t)
T.PLit (lit, t) -> T.PLit (lit, apply s t)
T.PInj i ps -> T.PInj i $ apply s ps
T.PCatch -> T.PCatch
T.PEnum i -> T.PEnum i
T.PLit (lit, t) -> T.PLit (lit, apply s t)
T.PInj i ps -> T.PInj i $ apply s ps
T.PCatch -> T.PCatch
T.PEnum i -> T.PEnum i
instance SubstType a => SubstType [a] where
apply s = map (apply s)
@ -519,7 +486,7 @@ instance SubstType a => SubstType [a] where
instance (SubstType a, SubstType b) => SubstType (a, b) where
apply s (a, b) = (apply s a, apply s b)
instance SubstType T.Id where
instance SubstType (T.Id' Type) where
apply s (name, t) = (name, apply s t)
-- | Apply substitutions to the environment.
@ -531,7 +498,7 @@ nullSubst :: Subst
nullSubst = M.empty
-- | Generate a new fresh variable and increment the state counter
fresh :: Infer T.Type
fresh :: Infer Type
fresh = do
c <- gets nextChar
n <- gets count
@ -545,34 +512,34 @@ fresh = do
fresh
else
if n == 0
then return . T.TVar . T.MkTVar . T.Ident $ [c]
else return . T.TVar . T.MkTVar . T.Ident $ [c] ++ show n
then return . TVar . T.MkTVar $ LIdent [c]
else return . TVar . T.MkTVar . LIdent $ [c] ++ show n
next :: Char -> Char
next 'z' = 'a'
next a = succ a
next a = succ a
-- | Run the monadic action with an additional binding
withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> T.Type -> m a -> m a
withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> Type -> m a -> m a
withBinding i p = local (\st -> st{vars = M.insert i p (vars st)})
-- | Run the monadic action with several additional bindings
withBindings :: (Monad m, MonadReader Ctx m) => [(T.Ident, T.Type)] -> m a -> m a
withBindings :: (Monad m, MonadReader Ctx m) => [(T.Ident, Type)] -> m a -> m a
withBindings xs =
local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs})
-- | Insert a function signature into the environment
insertSig :: T.Ident -> Maybe T.Type -> Infer ()
insertSig :: T.Ident -> Maybe Type -> Infer ()
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
-- | Insert a constructor with its data type
insertConstr :: T.Ident -> T.Type -> Infer ()
insertConstr :: T.Ident -> Type -> Infer ()
insertConstr i t =
modify (\st -> st{constructors = M.insert i t (constructors st)})
-------- PATTERN MATCHING ---------
checkCase :: T.Type -> [Branch] -> Infer (Subst, [T.Branch], T.Type)
checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type)
checkCase _ [] = throwError "Atleast one case required"
checkCase expT brnchs = do
(subs, injTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs
@ -594,23 +561,23 @@ checkCase expT brnchs = do
let comp = sub2 `compose` sub1 `compose` sub0
return (comp, apply comp injs, apply comp returns_type)
inferBranch :: Branch -> Infer (Subst, T.Type, T.Branch, T.Type)
inferBranch :: Branch -> Infer (Subst, Type, T.Branch' Type, Type)
inferBranch (Branch pat expr) = do
newPat@(pat, branchT) <- inferPattern pat
(sub, newExp@(_, exprT)) <- withPattern pat (algoW expr)
return (sub, apply sub branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT)
withPattern :: T.Pattern -> Infer a -> Infer a
withPattern :: T.Pattern' Type -> Infer a -> Infer a
withPattern p ma = case p of
T.PVar (x, t) -> withBinding x t ma
T.PInj _ ps -> foldl' (flip withPattern) ma ps
T.PLit _ -> ma
T.PCatch -> ma
T.PEnum _ -> ma
T.PInj _ ps -> foldl' (flip withPattern) ma ps
T.PLit _ -> ma
T.PCatch -> ma
T.PEnum _ -> ma
inferPattern :: Pattern -> Infer (T.Pattern, T.Type)
inferPattern :: Pattern -> Infer (T.Pattern' Type, Type)
inferPattern = \case
PLit lit -> let lt = litType lit in return (T.PLit (toNew lit, lt), lt)
PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt)
PInj constr patterns -> do
t <- gets (M.lookup (coerce constr) . constructors)
t <- maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") t
@ -644,28 +611,28 @@ inferPattern = \case
++ show (typeLength t - 1)
++ " arguments but has been given 0"
)
let (T.TData _data _ts) = t -- nasty nasty
let (TData _data _ts) = t -- nasty nasty
frs <- mapM (const fresh) _ts
return (T.PEnum $ coerce p, T.TData _data frs)
return (T.PEnum $ coerce p, TData _data frs)
PVar x -> do
fr <- fresh
let pvar = T.PVar (coerce x, fr)
return (pvar, fr)
flattenType :: T.Type -> [T.Type]
flattenType (T.TFun a b) = flattenType a <> flattenType b
flattenType a = [a]
flattenType :: Type -> [Type]
flattenType (TFun a b) = flattenType a <> flattenType b
flattenType a = [a]
typeLength :: T.Type -> Int
typeLength (T.TFun a b) = typeLength a + typeLength b
typeLength _ = 1
typeLength :: Type -> Int
typeLength (TFun a b) = typeLength a + typeLength b
typeLength _ = 1
litType :: Lit -> T.Type
litType (LInt _) = int
litType :: Lit -> Type
litType (LInt _) = int
litType (LChar _) = char
int = T.TLit "Int"
char = T.TLit "Char"
int = TLit "Int"
char = TLit "Char"
partitionType ::
Int -> -- Number of parameters to apply
@ -676,8 +643,8 @@ partitionType = go []
go acc 0 t = (acc, t)
go acc i t = case t of
TAll tvar t' -> second (TAll tvar) $ go acc i t'
TFun t1 t2 -> go (acc <> [t1]) (i - 1) t2
_ -> error "Number of parameters and type doesn't match"
TFun t1 t2 -> go (acc <> [t1]) (i - 1) t2
_ -> error "Number of parameters and type doesn't match"
exprErr :: Infer a -> Exp -> Infer a
exprErr ma exp =
@ -691,3 +658,19 @@ unzip4 =
)
([], [], [], [])
newtype Ctx = Ctx {vars :: Map T.Ident Type}
deriving (Show)
data Env = Env
{ count :: Int
, nextChar :: Char
, sigs :: Map T.Ident (Maybe Type)
, constructors :: Map T.Ident Type
, takenTypeVars :: Set T.Ident
}
deriving (Show)
type Error = String
type Subst = Map T.Ident Type
type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))