From 481667f2d8a04d72e1cb955da358972d5d60e9a6 Mon Sep 17 00:00:00 2001 From: sebastianselander Date: Fri, 24 Mar 2023 16:10:46 +0100 Subject: [PATCH] added tc as well --- src/TypeChecker/TypeChecker.hs | 110 ++++++++++++++++++--------------- 1 file changed, 59 insertions(+), 51 deletions(-) diff --git a/src/TypeChecker/TypeChecker.hs b/src/TypeChecker/TypeChecker.hs index 2a19b6e..5b22999 100644 --- a/src/TypeChecker/TypeChecker.hs +++ b/src/TypeChecker/TypeChecker.hs @@ -4,6 +4,7 @@ -- | A module for type checking and inference using algorithm W, Hindley-Milner module TypeChecker.TypeChecker where +import Auxiliary import Control.Monad.Except import Control.Monad.Reader import Control.Monad.State @@ -113,7 +114,7 @@ checkPrg (Program bs) = do (DBind b) -> do b' <- checkBind b fmap (T.DBind b' :) (checkDef xs) - (DData d) -> fmap (T.DData d :) (checkDef xs) + (DData d) -> fmap (T.DData (toNew d) :) (checkDef xs) (DSig _) -> checkDef xs checkBind :: Bind -> Infer T.Bind @@ -136,7 +137,7 @@ checkBind (Bind name args e) = do insertSig (coerce name) (Just lambdaT) return (T.Bind (coerce name, lambdaT) (map coerce args) e) -- (apply s e) -- where - -- getFunctionTypes :: Map Ident (Maybe T.Type) -> T.ExpT -> [(Ident, T.Type)] + -- getFunctionTypes :: Map T.Ident (Maybe T.Type) -> T.ExpT -> [(T.Ident, T.Type)] -- getFunctionTypes s = \case -- (T.EId b, t) -> case M.lookup b s of -- Just Nothing -> return (b, t) @@ -184,12 +185,25 @@ instance NewType Type T.Type where TData i ts -> T.TData (coerce i) (map toNew ts) TEVar _ -> error "Should not exist after typechecker" --- instance NewType Indexed T.TData where --- toNew (Indexed name vars) = T.TData (coerce name) (map toNew vars) +instance NewType Lit T.Lit where + toNew (LInt i) = T.LInt i + toNew (LChar i) = T.LChar i + +instance NewType Data T.Data where + toNew (Data t xs) = T.Data (name $ retType t) (toNew xs) + where + name (TData n _) = coerce n + name _ = error "Bug in toNew Data -> T.Data" + +instance NewType Constructor T.Constructor where + toNew (Constructor name xs) = T.Constructor (coerce name) (toNew xs) instance NewType TVar T.TVar where toNew (MkTVar i) = T.MkTVar $ coerce i +instance NewType a b => NewType [a] [b] where + toNew = map toNew + algoW :: Exp -> Infer (Subst, T.ExpT) algoW = \case -- \| TODO: More testing need to be done. Unsure of the correctness of this @@ -213,7 +227,7 @@ algoW = \case -- \| ------------------ -- \| Γ ⊢ i : Int, ∅ - ELit lit -> return (nullSubst, (T.ELit lit, litType lit)) + ELit lit -> return (nullSubst, (T.ELit $ toNew lit, litType lit)) -- \| x : σ ∈ Γ   τ = inst(σ) -- \| ---------------------- -- \| Γ ⊢ x : τ, ∅ @@ -228,7 +242,7 @@ algoW = \case Just Nothing -> (\x -> (nullSubst, (T.EId $ coerce i, x))) <$> fresh Nothing -> throwError $ "Unbound variable: " <> printTree i - ECons i -> do + EInj i -> do constr <- gets constructors case M.lookup (coerce i) constr of Just t -> return (nullSubst, (T.EId $ coerce i, t)) @@ -311,7 +325,7 @@ algoW = \case let t' = apply comp ret_t return (comp, (T.ECase (e', t) injs, t')) -makeLambda :: Exp -> [Ident] -> Exp +makeLambda :: Exp -> [T.Ident] -> Exp makeLambda = foldl (flip (EAbs . coerce)) -- | Unify two types producing a new substitution @@ -364,7 +378,7 @@ unify t0 t1 = do I.E. { a = a -> b } is an unsolvable constraint since there is no substitution where these are equal -} -occurs :: Ident -> T.Type -> Infer Subst +occurs :: T.Ident -> T.Type -> Infer Subst occurs i t@(T.TVar _) = return (M.singleton i t) occurs i t = if S.member i (free t) @@ -379,12 +393,12 @@ occurs i t = else return $ M.singleton i t -- | Generalize a type over all free variables in the substitution set -generalize :: Map Ident T.Type -> T.Type -> T.Type +generalize :: Map T.Ident T.Type -> T.Type -> T.Type generalize env t = go freeVars $ removeForalls t where - freeVars :: [Ident] + freeVars :: [T.Ident] freeVars = S.toList $ free t S.\\ free env - go :: [Ident] -> T.Type -> T.Type + go :: [T.Ident] -> T.Type -> T.Type go [] t = t go (x : xs) t = T.TAll (T.MkTVar x) (go xs t) removeForalls :: T.Type -> T.Type @@ -414,13 +428,13 @@ compose m1 m2 = M.map (apply m1) m2 `M.union` m1 -- | A class representing free variables functions class FreeVars t where -- | Get all free variables from t - free :: t -> Set Ident + free :: t -> Set T.Ident -- | Apply a substitution to t apply :: Subst -> t -> t instance FreeVars T.Type where - free :: T.Type -> Set Ident + free :: T.Type -> Set T.Ident free (T.TVar (T.MkTVar a)) = S.singleton a free (T.TAll (T.MkTVar bound) t) = S.singleton bound `S.intersection` free t free (T.TLit _) = mempty @@ -442,14 +456,14 @@ instance FreeVars T.Type where T.TFun a b -> T.TFun (apply sub a) (apply sub b) T.TData name a -> T.TData name (map (apply sub) a) -instance FreeVars (Map Ident T.Type) where - free :: Map Ident T.Type -> Set Ident +instance FreeVars (Map T.Ident T.Type) where + free :: Map T.Ident T.Type -> Set T.Ident free m = foldl' S.union S.empty (map free $ M.elems m) - apply :: Subst -> Map Ident T.Type -> Map Ident T.Type + apply :: Subst -> Map T.Ident T.Type -> Map T.Ident T.Type apply s = M.map (apply s) instance FreeVars T.ExpT where - free :: T.ExpT -> Set Ident + free :: T.ExpT -> Set T.Ident free = error "free not implemented for T.Exp" apply :: Subst -> T.ExpT -> T.ExpT apply s = \case @@ -466,14 +480,14 @@ instance FreeVars T.ExpT where (T.EAbs ident e, t1) -> (T.EAbs ident (apply s e), apply s t1) (T.ECase e injs, t) -> (T.ECase (apply s e) (apply s injs), apply s t) -instance FreeVars T.Inj where - free :: T.Inj -> Set Ident +instance FreeVars T.Branch where + free :: T.Branch -> Set T.Ident free = undefined - apply :: Subst -> T.Inj -> T.Inj - apply s (T.Inj (i, t) e) = T.Inj (i, apply s t) (apply s e) + apply :: Subst -> T.Branch -> T.Branch + apply s (T.Branch (i, t) e) = T.Branch (i, apply s t) (apply s e) -instance FreeVars [T.Inj] where - free :: [T.Inj] -> Set Ident +instance FreeVars [T.Branch] where + free :: [T.Branch] -> Set T.Ident free = foldl' (\acc x -> free x `S.union` acc) mempty apply s = map (apply s) @@ -490,31 +504,31 @@ fresh :: Infer T.Type fresh = do n <- gets count modify (\st -> st{count = n + 1}) - return . T.TVar . T.MkTVar . Ident $ show n + return . T.TVar . T.MkTVar . T.Ident $ show n -- | Run the monadic action with an additional binding -withBinding :: (Monad m, MonadReader Ctx m) => Ident -> T.Type -> m a -> m a +withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> T.Type -> m a -> m a withBinding i p = local (\st -> st{vars = M.insert i p (vars st)}) -- | Run the monadic action with several additional bindings -withBindings :: (Monad m, MonadReader Ctx m) => [(Ident, T.Type)] -> m a -> m a +withBindings :: (Monad m, MonadReader Ctx m) => [(T.Ident, T.Type)] -> m a -> m a withBindings xs = local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs}) -- | Insert a function signature into the environment -insertSig :: Ident -> Maybe T.Type -> Infer () +insertSig :: T.Ident -> Maybe T.Type -> Infer () insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)}) -- | Insert a constructor with its data type -insertConstr :: Ident -> T.Type -> Infer () +insertConstr :: T.Ident -> T.Type -> Infer () insertConstr i t = modify (\st -> st{constructors = M.insert i t (constructors st)}) -------- PATTERN MATCHING --------- -checkCase :: T.Type -> [Inj] -> Infer (Subst, [T.Inj], T.Type) +checkCase :: T.Type -> [Branch] -> Infer (Subst, [T.Branch], T.Type) checkCase expT injs = do - (injTs, injs, returns) <- unzip3 <$> mapM checkInj injs + (injTs, injs, returns) <- unzip3 <$> mapM checkBranch injs (sub1, _) <- foldM ( \(sub, acc) x -> @@ -534,29 +548,23 @@ checkCase expT injs = do {- | fst = type of init | snd = type of expr -} -checkInj :: Inj -> Infer (T.Type, T.Inj, T.Type) -checkInj (Inj it expr) = do - (initT, vars) <- inferInit it +inferBranch :: Branch -> Infer (T.Type, T.Branch, T.Type) +inferBranch (Branch it expr) = do + (initT, vars) <- inferPattern it (e, exprT) <- withBindings vars (inferExp expr) - return (initT, T.Inj (it, initT) (e, exprT), exprT) + return (initT, T.Branch (it, initT) (e, exprT), exprT) -inferInit :: Init -> Infer (T.Type, [T.Id]) -inferInit = \case - InitLit lit -> return (litType lit, mempty) - InitConstructor fn vars -> do - gets (M.lookup (coerce fn) . constructors) >>= \case - Nothing -> - throwError $ - "Constructor: " <> printTree fn <> " does not exist" - Just a -> do - case unsnoc $ flattenType a of - Nothing -> throwError "Partial pattern match not allowed" - Just (vs, ret) -> - case length vars `compare` length vs of - EQ -> do - return (ret, zip (coerce vars) vs) - _ -> throwError "Partial pattern match not allowed" - InitCatch -> (,mempty) <$> fresh +inferPattern :: Pattern -> Infer (T.Pattern, T.Type) +inferPattern = \case + PLit lit -> return (T.PLit $ toNew lit, litType lit) + PInj constr patterns -> do + t <- gets (M.lookup (coerce constr) . constructors) + t <- maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") t + (vs, ret) <- maybeToRightM (throwError "Partial pattern match not allowed") (unsnoc $ flattenType t) + patterns <- mapM inferPattern patterns + undefined + PCatch -> (T.PCatch,) <$> fresh + PVar x -> undefined flattenType :: T.Type -> [T.Type] flattenType (T.TFun a b) = flattenType a <> flattenType b