added tc as well

This commit is contained in:
sebastianselander 2023-03-24 16:10:46 +01:00
parent 38680a4dcb
commit 481667f2d8

View file

@ -4,6 +4,7 @@
-- | A module for type checking and inference using algorithm W, Hindley-Milner -- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeChecker where module TypeChecker.TypeChecker where
import Auxiliary
import Control.Monad.Except import Control.Monad.Except
import Control.Monad.Reader import Control.Monad.Reader
import Control.Monad.State import Control.Monad.State
@ -113,7 +114,7 @@ checkPrg (Program bs) = do
(DBind b) -> do (DBind b) -> do
b' <- checkBind b b' <- checkBind b
fmap (T.DBind b' :) (checkDef xs) fmap (T.DBind b' :) (checkDef xs)
(DData d) -> fmap (T.DData d :) (checkDef xs) (DData d) -> fmap (T.DData (toNew d) :) (checkDef xs)
(DSig _) -> checkDef xs (DSig _) -> checkDef xs
checkBind :: Bind -> Infer T.Bind checkBind :: Bind -> Infer T.Bind
@ -136,7 +137,7 @@ checkBind (Bind name args e) = do
insertSig (coerce name) (Just lambdaT) insertSig (coerce name) (Just lambdaT)
return (T.Bind (coerce name, lambdaT) (map coerce args) e) -- (apply s e) return (T.Bind (coerce name, lambdaT) (map coerce args) e) -- (apply s e)
-- where -- where
-- getFunctionTypes :: Map Ident (Maybe T.Type) -> T.ExpT -> [(Ident, T.Type)] -- getFunctionTypes :: Map T.Ident (Maybe T.Type) -> T.ExpT -> [(T.Ident, T.Type)]
-- getFunctionTypes s = \case -- getFunctionTypes s = \case
-- (T.EId b, t) -> case M.lookup b s of -- (T.EId b, t) -> case M.lookup b s of
-- Just Nothing -> return (b, t) -- Just Nothing -> return (b, t)
@ -184,12 +185,25 @@ instance NewType Type T.Type where
TData i ts -> T.TData (coerce i) (map toNew ts) TData i ts -> T.TData (coerce i) (map toNew ts)
TEVar _ -> error "Should not exist after typechecker" TEVar _ -> error "Should not exist after typechecker"
-- instance NewType Indexed T.TData where instance NewType Lit T.Lit where
-- toNew (Indexed name vars) = T.TData (coerce name) (map toNew vars) toNew (LInt i) = T.LInt i
toNew (LChar i) = T.LChar i
instance NewType Data T.Data where
toNew (Data t xs) = T.Data (name $ retType t) (toNew xs)
where
name (TData n _) = coerce n
name _ = error "Bug in toNew Data -> T.Data"
instance NewType Constructor T.Constructor where
toNew (Constructor name xs) = T.Constructor (coerce name) (toNew xs)
instance NewType TVar T.TVar where instance NewType TVar T.TVar where
toNew (MkTVar i) = T.MkTVar $ coerce i toNew (MkTVar i) = T.MkTVar $ coerce i
instance NewType a b => NewType [a] [b] where
toNew = map toNew
algoW :: Exp -> Infer (Subst, T.ExpT) algoW :: Exp -> Infer (Subst, T.ExpT)
algoW = \case algoW = \case
-- \| TODO: More testing need to be done. Unsure of the correctness of this -- \| TODO: More testing need to be done. Unsure of the correctness of this
@ -213,7 +227,7 @@ algoW = \case
-- \| ------------------ -- \| ------------------
-- \| Γ ⊢ i : Int, ∅ -- \| Γ ⊢ i : Int, ∅
ELit lit -> return (nullSubst, (T.ELit lit, litType lit)) ELit lit -> return (nullSubst, (T.ELit $ toNew lit, litType lit))
-- \| x : σ ∈ Γ τ = inst(σ) -- \| x : σ ∈ Γ τ = inst(σ)
-- \| ---------------------- -- \| ----------------------
-- \| Γ ⊢ x : τ, ∅ -- \| Γ ⊢ x : τ, ∅
@ -228,7 +242,7 @@ algoW = \case
Just Nothing -> Just Nothing ->
(\x -> (nullSubst, (T.EId $ coerce i, x))) <$> fresh (\x -> (nullSubst, (T.EId $ coerce i, x))) <$> fresh
Nothing -> throwError $ "Unbound variable: " <> printTree i Nothing -> throwError $ "Unbound variable: " <> printTree i
ECons i -> do EInj i -> do
constr <- gets constructors constr <- gets constructors
case M.lookup (coerce i) constr of case M.lookup (coerce i) constr of
Just t -> return (nullSubst, (T.EId $ coerce i, t)) Just t -> return (nullSubst, (T.EId $ coerce i, t))
@ -311,7 +325,7 @@ algoW = \case
let t' = apply comp ret_t let t' = apply comp ret_t
return (comp, (T.ECase (e', t) injs, t')) return (comp, (T.ECase (e', t) injs, t'))
makeLambda :: Exp -> [Ident] -> Exp makeLambda :: Exp -> [T.Ident] -> Exp
makeLambda = foldl (flip (EAbs . coerce)) makeLambda = foldl (flip (EAbs . coerce))
-- | Unify two types producing a new substitution -- | Unify two types producing a new substitution
@ -364,7 +378,7 @@ unify t0 t1 = do
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
where these are equal where these are equal
-} -}
occurs :: Ident -> T.Type -> Infer Subst occurs :: T.Ident -> T.Type -> Infer Subst
occurs i t@(T.TVar _) = return (M.singleton i t) occurs i t@(T.TVar _) = return (M.singleton i t)
occurs i t = occurs i t =
if S.member i (free t) if S.member i (free t)
@ -379,12 +393,12 @@ occurs i t =
else return $ M.singleton i t else return $ M.singleton i t
-- | Generalize a type over all free variables in the substitution set -- | Generalize a type over all free variables in the substitution set
generalize :: Map Ident T.Type -> T.Type -> T.Type generalize :: Map T.Ident T.Type -> T.Type -> T.Type
generalize env t = go freeVars $ removeForalls t generalize env t = go freeVars $ removeForalls t
where where
freeVars :: [Ident] freeVars :: [T.Ident]
freeVars = S.toList $ free t S.\\ free env freeVars = S.toList $ free t S.\\ free env
go :: [Ident] -> T.Type -> T.Type go :: [T.Ident] -> T.Type -> T.Type
go [] t = t go [] t = t
go (x : xs) t = T.TAll (T.MkTVar x) (go xs t) go (x : xs) t = T.TAll (T.MkTVar x) (go xs t)
removeForalls :: T.Type -> T.Type removeForalls :: T.Type -> T.Type
@ -414,13 +428,13 @@ compose m1 m2 = M.map (apply m1) m2 `M.union` m1
-- | A class representing free variables functions -- | A class representing free variables functions
class FreeVars t where class FreeVars t where
-- | Get all free variables from t -- | Get all free variables from t
free :: t -> Set Ident free :: t -> Set T.Ident
-- | Apply a substitution to t -- | Apply a substitution to t
apply :: Subst -> t -> t apply :: Subst -> t -> t
instance FreeVars T.Type where instance FreeVars T.Type where
free :: T.Type -> Set Ident free :: T.Type -> Set T.Ident
free (T.TVar (T.MkTVar a)) = S.singleton a free (T.TVar (T.MkTVar a)) = S.singleton a
free (T.TAll (T.MkTVar bound) t) = S.singleton bound `S.intersection` free t free (T.TAll (T.MkTVar bound) t) = S.singleton bound `S.intersection` free t
free (T.TLit _) = mempty free (T.TLit _) = mempty
@ -442,14 +456,14 @@ instance FreeVars T.Type where
T.TFun a b -> T.TFun (apply sub a) (apply sub b) T.TFun a b -> T.TFun (apply sub a) (apply sub b)
T.TData name a -> T.TData name (map (apply sub) a) T.TData name a -> T.TData name (map (apply sub) a)
instance FreeVars (Map Ident T.Type) where instance FreeVars (Map T.Ident T.Type) where
free :: Map Ident T.Type -> Set Ident free :: Map T.Ident T.Type -> Set T.Ident
free m = foldl' S.union S.empty (map free $ M.elems m) free m = foldl' S.union S.empty (map free $ M.elems m)
apply :: Subst -> Map Ident T.Type -> Map Ident T.Type apply :: Subst -> Map T.Ident T.Type -> Map T.Ident T.Type
apply s = M.map (apply s) apply s = M.map (apply s)
instance FreeVars T.ExpT where instance FreeVars T.ExpT where
free :: T.ExpT -> Set Ident free :: T.ExpT -> Set T.Ident
free = error "free not implemented for T.Exp" free = error "free not implemented for T.Exp"
apply :: Subst -> T.ExpT -> T.ExpT apply :: Subst -> T.ExpT -> T.ExpT
apply s = \case apply s = \case
@ -466,14 +480,14 @@ instance FreeVars T.ExpT where
(T.EAbs ident e, t1) -> (T.EAbs ident (apply s e), apply s t1) (T.EAbs ident e, t1) -> (T.EAbs ident (apply s e), apply s t1)
(T.ECase e injs, t) -> (T.ECase (apply s e) (apply s injs), apply s t) (T.ECase e injs, t) -> (T.ECase (apply s e) (apply s injs), apply s t)
instance FreeVars T.Inj where instance FreeVars T.Branch where
free :: T.Inj -> Set Ident free :: T.Branch -> Set T.Ident
free = undefined free = undefined
apply :: Subst -> T.Inj -> T.Inj apply :: Subst -> T.Branch -> T.Branch
apply s (T.Inj (i, t) e) = T.Inj (i, apply s t) (apply s e) apply s (T.Branch (i, t) e) = T.Branch (i, apply s t) (apply s e)
instance FreeVars [T.Inj] where instance FreeVars [T.Branch] where
free :: [T.Inj] -> Set Ident free :: [T.Branch] -> Set T.Ident
free = foldl' (\acc x -> free x `S.union` acc) mempty free = foldl' (\acc x -> free x `S.union` acc) mempty
apply s = map (apply s) apply s = map (apply s)
@ -490,31 +504,31 @@ fresh :: Infer T.Type
fresh = do fresh = do
n <- gets count n <- gets count
modify (\st -> st{count = n + 1}) modify (\st -> st{count = n + 1})
return . T.TVar . T.MkTVar . Ident $ show n return . T.TVar . T.MkTVar . T.Ident $ show n
-- | Run the monadic action with an additional binding -- | Run the monadic action with an additional binding
withBinding :: (Monad m, MonadReader Ctx m) => Ident -> T.Type -> m a -> m a withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> T.Type -> m a -> m a
withBinding i p = local (\st -> st{vars = M.insert i p (vars st)}) withBinding i p = local (\st -> st{vars = M.insert i p (vars st)})
-- | Run the monadic action with several additional bindings -- | Run the monadic action with several additional bindings
withBindings :: (Monad m, MonadReader Ctx m) => [(Ident, T.Type)] -> m a -> m a withBindings :: (Monad m, MonadReader Ctx m) => [(T.Ident, T.Type)] -> m a -> m a
withBindings xs = withBindings xs =
local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs}) local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs})
-- | Insert a function signature into the environment -- | Insert a function signature into the environment
insertSig :: Ident -> Maybe T.Type -> Infer () insertSig :: T.Ident -> Maybe T.Type -> Infer ()
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)}) insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
-- | Insert a constructor with its data type -- | Insert a constructor with its data type
insertConstr :: Ident -> T.Type -> Infer () insertConstr :: T.Ident -> T.Type -> Infer ()
insertConstr i t = insertConstr i t =
modify (\st -> st{constructors = M.insert i t (constructors st)}) modify (\st -> st{constructors = M.insert i t (constructors st)})
-------- PATTERN MATCHING --------- -------- PATTERN MATCHING ---------
checkCase :: T.Type -> [Inj] -> Infer (Subst, [T.Inj], T.Type) checkCase :: T.Type -> [Branch] -> Infer (Subst, [T.Branch], T.Type)
checkCase expT injs = do checkCase expT injs = do
(injTs, injs, returns) <- unzip3 <$> mapM checkInj injs (injTs, injs, returns) <- unzip3 <$> mapM checkBranch injs
(sub1, _) <- (sub1, _) <-
foldM foldM
( \(sub, acc) x -> ( \(sub, acc) x ->
@ -534,29 +548,23 @@ checkCase expT injs = do
{- | fst = type of init {- | fst = type of init
| snd = type of expr | snd = type of expr
-} -}
checkInj :: Inj -> Infer (T.Type, T.Inj, T.Type) inferBranch :: Branch -> Infer (T.Type, T.Branch, T.Type)
checkInj (Inj it expr) = do inferBranch (Branch it expr) = do
(initT, vars) <- inferInit it (initT, vars) <- inferPattern it
(e, exprT) <- withBindings vars (inferExp expr) (e, exprT) <- withBindings vars (inferExp expr)
return (initT, T.Inj (it, initT) (e, exprT), exprT) return (initT, T.Branch (it, initT) (e, exprT), exprT)
inferInit :: Init -> Infer (T.Type, [T.Id]) inferPattern :: Pattern -> Infer (T.Pattern, T.Type)
inferInit = \case inferPattern = \case
InitLit lit -> return (litType lit, mempty) PLit lit -> return (T.PLit $ toNew lit, litType lit)
InitConstructor fn vars -> do PInj constr patterns -> do
gets (M.lookup (coerce fn) . constructors) >>= \case t <- gets (M.lookup (coerce constr) . constructors)
Nothing -> t <- maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") t
throwError $ (vs, ret) <- maybeToRightM (throwError "Partial pattern match not allowed") (unsnoc $ flattenType t)
"Constructor: " <> printTree fn <> " does not exist" patterns <- mapM inferPattern patterns
Just a -> do undefined
case unsnoc $ flattenType a of PCatch -> (T.PCatch,) <$> fresh
Nothing -> throwError "Partial pattern match not allowed" PVar x -> undefined
Just (vs, ret) ->
case length vars `compare` length vs of
EQ -> do
return (ret, zip (coerce vars) vs)
_ -> throwError "Partial pattern match not allowed"
InitCatch -> (,mempty) <$> fresh
flattenType :: T.Type -> [T.Type] flattenType :: T.Type -> [T.Type]
flattenType (T.TFun a b) = flattenType a <> flattenType b flattenType (T.TFun a b) = flattenType a <> flattenType b