From 6e54378327cbf6aa8c5b3d6cd53d1ba0c8b555a1 Mon Sep 17 00:00:00 2001 From: sebastianselander Date: Mon, 27 Mar 2023 16:48:23 +0200 Subject: [PATCH] Fixed errors in tc hm --- src/Codegen/LlvmIr.hs | 60 ++--- src/Monomorphizer/MonomorphizerIr.hs | 15 +- src/TypeChecker/TypeCheckerHm.hs | 361 +++++++++++++-------------- src/TypeChecker/TypeCheckerIr.hs | 255 ++++++++++--------- 4 files changed, 346 insertions(+), 345 deletions(-) diff --git a/src/Codegen/LlvmIr.hs b/src/Codegen/LlvmIr.hs index 0baf35a..59850b6 100644 --- a/src/Codegen/LlvmIr.hs +++ b/src/Codegen/LlvmIr.hs @@ -8,19 +8,18 @@ module Codegen.LlvmIr ( LLVMComp (..), Visibility (..), CallingConvention (..), - ToIr(..) + ToIr (..), ) where -import Data.List (intercalate) -import Grammar.Abs (Character) -import TypeChecker.TypeCheckerIr (Ident (..)) +import Data.List (intercalate) +import TypeChecker.TypeCheckerIr (Ident (..)) -data CallingConvention = TailCC | FastCC | CCC | ColdCC deriving Show +data CallingConvention = TailCC | FastCC | CCC | ColdCC deriving (Show) instance ToIr CallingConvention where toIr :: CallingConvention -> String toIr TailCC = "tailcc" toIr FastCC = "fastcc" - toIr CCC = "ccc" + toIr CCC = "ccc" toIr ColdCC = "coldcc" -- | A datatype which represents some basic LLVM types @@ -34,7 +33,7 @@ data LLVMType | Function LLVMType [LLVMType] | Array Integer LLVMType | CustomType Ident - deriving Show + deriving (Show) class ToIr a where toIr :: a -> String @@ -63,12 +62,12 @@ data LLVMComp | LLSge | LLSlt | LLSle - deriving Show + deriving (Show) instance ToIr LLVMComp where toIr :: LLVMComp -> String toIr = \case - LLEq -> "eq" - LLNe -> "ne" + LLEq -> "eq" + LLNe -> "ne" LLUgt -> "ugt" LLUge -> "uge" LLUlt -> "ult" @@ -78,30 +77,31 @@ instance ToIr LLVMComp where LLSlt -> "slt" LLSle -> "sle" -data Visibility = Local | Global deriving Show +data Visibility = Local | Global deriving (Show) instance ToIr Visibility where toIr :: Visibility -> String - toIr Local = "%" + toIr Local = "%" toIr Global = "@" --- | Represents a LLVM "value", as in an integer, a register variable, --- or a string contstant +{- | Represents a LLVM "value", as in an integer, a register variable, +or a string contstant +-} data LLVMValue = VInteger Integer - | VChar Character + | VChar Char | VIdent Ident LLVMType | VConstant String | VFunction Ident Visibility LLVMType - deriving Show + deriving (Show) instance ToIr LLVMValue where toIr :: LLVMValue -> String toIr v = case v of - VInteger i -> show i - VChar i -> show i - VIdent (Ident n) _ -> "%" <> n + VInteger i -> show i + VChar i -> show i + VIdent (Ident n) _ -> "%" <> n VFunction (Ident n) vis _ -> toIr vis <> n - VConstant s -> "c" <> show s + VConstant s -> "c" <> show s type Params = [(Ident, LLVMType)] type Args = [(LLVMType, LLVMValue)] @@ -114,8 +114,8 @@ data LLVMIr | Declare LLVMType Ident Params | SetVariable Ident LLVMIr | Variable Ident - -- extractvalue , {, }* - | ExtractValue LLVMType LLVMValue Integer + | -- extractvalue , {, }* + ExtractValue LLVMType LLVMValue Integer | GetElementPtr LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue | GetElementPtrInbounds LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue | Add LLVMType LLVMValue LLVMValue @@ -136,7 +136,7 @@ data LLVMIr | Comment String | UnsafeRaw String -- This should generally be avoided, and proper -- instructions should be used in its place - deriving Show + deriving (Show) -- | Converts a list of LLVMIr instructions to a string llvmIrToString :: [LLVMIr] -> String @@ -146,14 +146,15 @@ llvmIrToString = go 0 go _ [] = mempty go i (x : xs) = do let (i', n) = case x of - Define{} -> (i + 1, 0) + Define{} -> (i + 1, 0) DefineEnd -> (i - 1, 0) - _ -> (i, i) + _ -> (i, i) insToString n x <> go i' xs - {- | Converts a LLVM inststruction to a String, allowing for printing etc. - The integer represents the indentation - -} - {- FOURMOLU_DISABLE -} + +-- \| Converts a LLVM inststruction to a String, allowing for printing etc. +-- The integer represents the indentation +-- +{- FOURMOLU_DISABLE -} insToString :: Int -> LLVMIr -> String insToString i l = replicate i '\t' <> case l of @@ -261,4 +262,3 @@ llvmIrToString = go 0 lblPfx :: String lblPfx = "lbl_" - diff --git a/src/Monomorphizer/MonomorphizerIr.hs b/src/Monomorphizer/MonomorphizerIr.hs index 383e9fc..66888c0 100644 --- a/src/Monomorphizer/MonomorphizerIr.hs +++ b/src/Monomorphizer/MonomorphizerIr.hs @@ -1,7 +1,6 @@ -module Monomorphizer.MonomorphizerIr (module Monomorphizer.MonomorphizerIr, module GA) where +module Monomorphizer.MonomorphizerIr (module Monomorphizer.MonomorphizerIr) where -import qualified Grammar.Abs as GA (Ident (..)) -import qualified TypeChecker.TypeCheckerIr as TIR (Ident (..)) +import TypeChecker.TypeCheckerIr qualified as TIR (Ident (..)) type Id = (TIR.Ident, Type) @@ -26,8 +25,12 @@ data Exp | ECase ExpT [Branch] deriving (Show, Ord, Eq) -data Pattern = PVar Id | PLit (Lit, Type) | PInj TIR.Ident [Pattern] - | PCatch | PEnum TIR.Ident +data Pattern + = PVar Id + | PLit (Lit, Type) + | PInj TIR.Ident [Pattern] + | PCatch + | PEnum TIR.Ident deriving (Eq, Ord, Show) data Branch = Branch (Pattern, Type) ExpT @@ -48,4 +51,4 @@ data Type = TLit TIR.Ident | TFun Type Type flattenType :: Type -> [Type] flattenType (TFun t1 t2) = t1 : flattenType t2 -flattenType x = [x] +flattenType x = [x] diff --git a/src/TypeChecker/TypeCheckerHm.hs b/src/TypeChecker/TypeCheckerHm.hs index a24a0b7..1254a87 100644 --- a/src/TypeChecker/TypeCheckerHm.hs +++ b/src/TypeChecker/TypeCheckerHm.hs @@ -1,31 +1,29 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} -- | A module for type checking and inference using algorithm W, Hindley-Milner module TypeChecker.TypeCheckerHm where -import Auxiliary -import Control.Monad.Except -import Control.Monad.Identity (runIdentity) -import Control.Monad.Reader -import Control.Monad.State -import Data.Bifunctor (second) -import Data.Coerce (coerce) -import Data.Foldable (traverse_) -import Data.Function (on) -import Data.List (foldl') -import Data.List.Extra (unsnoc) -import Data.Map (Map) -import qualified Data.Map as M -import Data.Maybe (fromJust) -import Data.Set (Set) -import qualified Data.Set as S -import Debug.Trace (trace) -import Grammar.Abs -import Grammar.Print (printTree) -import qualified TypeChecker.TypeCheckerIr as T -import TypeChecker.TypeCheckerIr (Ctx (..), Env (..), Error, Infer, - Subst) +import Auxiliary +import Control.Monad.Except +import Control.Monad.Identity (Identity, runIdentity) +import Control.Monad.Reader +import Control.Monad.State +import Data.Bifunctor (second) +import Data.Coerce (coerce) +import Data.Foldable (traverse_) +import Data.Function (on) +import Data.List (foldl') +import Data.List.Extra (unsnoc) +import Data.Map (Map) +import Data.Map qualified as M +import Data.Maybe (fromJust) +import Data.Set (Set) +import Data.Set qualified as S +import Data.String +import Grammar.Abs +import Grammar.Print (printTree) +import TypeChecker.TypeCheckerIr qualified as T initCtx = Ctx mempty initEnv = Env 0 'a' mempty mempty mempty @@ -39,7 +37,7 @@ run = runC initEnv initCtx runC :: Env -> Ctx -> Infer a -> Either Error a runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e -typecheck :: Program -> Either Error T.Program +typecheck :: Program -> Either Error (T.Program' Type) typecheck = run . checkPrg checkData :: Data -> Infer () @@ -50,9 +48,9 @@ checkData d = do (all isPoly ts) (throwError $ unwords ["Data type incorrectly declared"]) traverse_ - ( \(Constructor name' t') -> + ( \(Inj name' t') -> if typ == retType t' - then insertConstr (coerce name') (toNew t') + then insertConstr (coerce name') (t') else throwError $ unwords @@ -73,9 +71,9 @@ checkData d = do retType :: Type -> Type retType (TFun _ t2) = retType t2 -retType a = a +retType a = a -checkPrg :: Program -> Infer T.Program +checkPrg :: Program -> Infer (T.Program' Type) checkPrg (Program bs) = do preRun bs bs' <- checkDef bs @@ -94,25 +92,27 @@ preRun (x : xs) = case x of <> printTree n <> "'" ) - insertSig (coerce n) (Just $ toNew t) >> preRun xs + insertSig (coerce n) (Just $ t) >> preRun xs DBind (Bind n _ e) -> do collect (collectTypeVars e) s <- gets sigs case M.lookup (coerce n) s of Nothing -> insertSig (coerce n) Nothing >> preRun xs - Just _ -> preRun xs + Just _ -> preRun xs DData d@(Data t _) -> collect (collectTypeVars t) >> checkData d >> preRun xs -checkDef :: [Def] -> Infer [T.Def] +checkDef :: [Def] -> Infer [T.Def' Type] checkDef [] = return [] checkDef (x : xs) = case x of (DBind b) -> do b' <- checkBind b fmap (T.DBind b' :) (checkDef xs) - (DData d) -> fmap (T.DData (toNew d) :) (checkDef xs) + (DData d) -> fmap ((T.DData (coerceData d)) :) (checkDef xs) (DSig _) -> checkDef xs + where + coerceData (Data t injs) = T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs -checkBind :: Bind -> Infer T.Bind +checkBind :: Bind -> Infer (T.Bind' Type) checkBind (Bind name args e) = do let lambda = makeLambda e (reverse (coerce args)) e@(_, args_t) <- inferExp lambda @@ -133,41 +133,41 @@ checkBind (Bind name args e) = do insertSig (coerce name) (Just args_t) return (T.Bind (coerce name, args_t) [] e) -typeEq :: T.Type -> T.Type -> Bool -typeEq (T.TFun l r) (T.TFun l' r') = typeEq l l' && typeEq r r' -typeEq (T.TLit a) (T.TLit b) = a == b -typeEq (T.TData name a) (T.TData name' b) = +typeEq :: Type -> Type -> Bool +typeEq (TFun l r) (TFun l' r') = typeEq l l' && typeEq r r' +typeEq (TLit a) (TLit b) = a == b +typeEq (TData name a) (TData name' b) = length a == length b && name == name' && and (zipWith typeEq a b) -typeEq (T.TAll _ t1) t2 = t1 `typeEq` t2 -typeEq t1 (T.TAll _ t2) = t1 `typeEq` t2 -typeEq (T.TVar _) (T.TVar _) = True +typeEq (TAll _ t1) t2 = t1 `typeEq` t2 +typeEq t1 (TAll _ t2) = t1 `typeEq` t2 +typeEq (TVar _) (TVar _) = True typeEq _ _ = False -skolem :: T.Type -> T.Type -skolem (T.TVar (T.MkTVar a)) = T.TLit a -skolem (T.TAll x t) = T.TAll x (skolem t) -skolem (T.TFun t1 t2) = (T.TFun `on` skolem) t1 t2 -skolem t = t +skolem :: Type -> Type +skolem (TVar (T.MkTVar a)) = TLit (coerce a) +skolem (TAll x t) = TAll x (skolem t) +skolem (TFun t1 t2) = (TFun `on` skolem) t1 t2 +skolem t = t -isMoreSpecificOrEq :: T.Type -> T.Type -> Bool -isMoreSpecificOrEq t1 (T.TAll _ t2) = isMoreSpecificOrEq t1 t2 -isMoreSpecificOrEq (T.TFun a b) (T.TFun c d) = +isMoreSpecificOrEq :: Type -> Type -> Bool +isMoreSpecificOrEq t1 (TAll _ t2) = isMoreSpecificOrEq t1 t2 +isMoreSpecificOrEq (TFun a b) (TFun c d) = isMoreSpecificOrEq a c && isMoreSpecificOrEq b d -isMoreSpecificOrEq (T.TData n1 ts1) (T.TData n2 ts2) = +isMoreSpecificOrEq (TData n1 ts1) (TData n2 ts2) = n1 == n2 && length ts1 == length ts2 && and (zipWith isMoreSpecificOrEq ts1 ts2) -isMoreSpecificOrEq _ (T.TVar _) = True +isMoreSpecificOrEq _ (TVar _) = True isMoreSpecificOrEq a b = a == b isPoly :: Type -> Bool isPoly (TAll _ _) = True -isPoly (TVar _) = True -isPoly _ = False +isPoly (TVar _) = True +isPoly _ = False -inferExp :: Exp -> Infer T.ExpT +inferExp :: Exp -> Infer (T.ExpT' Type) inferExp e = do (s, (e', t)) <- algoW e let subbed = apply s t @@ -178,7 +178,7 @@ class CollectTVars a where instance CollectTVars Exp where collectTypeVars (EAnn e t) = collectTypeVars t `S.union` collectTypeVars e - collectTypeVars _ = S.empty + collectTypeVars _ = S.empty instance CollectTVars Type where collectTypeVars (TVar (MkTVar i)) = S.singleton (coerce i) @@ -190,43 +190,12 @@ instance CollectTVars Type where collect :: Set T.Ident -> Infer () collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st}) -class NewType a b where - toNew :: a -> b - -instance NewType Type T.Type where - toNew = \case - TLit i -> T.TLit $ coerce i - TVar v -> T.TVar $ toNew v - TFun t1 t2 -> (T.TFun `on` toNew) t1 t2 - TAll b t -> T.TAll (toNew b) (toNew t) - TData i ts -> T.TData (coerce i) (map toNew ts) - TEVar _ -> error "Should not exist after typechecker" - -instance NewType Lit T.Lit where - toNew (LInt i) = T.LInt i - toNew (LChar i) = T.LChar i - -instance NewType Data T.Data where - toNew (Data t xs) = T.Data (name $ retType t) (toNew xs) - where - name (TData n _) = coerce n - name _ = error "Bug: Data types should not be able to be typed over non type variables" - -instance NewType Constructor T.Constructor where - toNew (Constructor name xs) = T.Constructor (coerce name) (toNew xs) - -instance NewType TVar T.TVar where - toNew (MkTVar i) = T.MkTVar $ coerce i - -instance NewType a b => NewType [a] [b] where - toNew = map toNew - -algoW :: Exp -> Infer (Subst, T.ExpT) +algoW :: Exp -> Infer (Subst, (T.ExpT' Type)) algoW = \case err@(EAnn e t) -> do (s1, (e', t')) <- exprErr (algoW e) err unless - (toNew t `isMoreSpecificOrEq` t') + (t `isMoreSpecificOrEq` t') ( throwError $ unwords [ "Annotated type:" @@ -236,34 +205,34 @@ algoW = \case ] ) applySt s1 $ do - s2 <- exprErr (unify (toNew t) t') err + s2 <- exprErr (unify (t) t') err let comp = s2 `compose` s1 - return (comp, apply comp (e', toNew t)) + return (comp, apply comp (e', t)) -- \| ------------------ -- \| Γ ⊢ i : Int, ∅ - ELit lit -> return (nullSubst, (T.ELit $ toNew lit, litType lit)) + ELit lit -> return (nullSubst, (T.ELit $ lit, litType lit)) -- \| x : σ ∈ Γ   τ = inst(σ) -- \| ---------------------- -- \| Γ ⊢ x : τ, ∅ EVar i -> do var <- asks vars case M.lookup (coerce i) var of - Just t -> inst t >>= \x -> return (nullSubst, (T.EId $ coerce i, x)) + Just t -> inst t >>= \x -> return (nullSubst, (T.EVar $ coerce i, x)) Nothing -> do sig <- gets sigs case M.lookup (coerce i) sig of - Just (Just t) -> return (nullSubst, (T.EId $ coerce i, t)) + Just (Just t) -> return (nullSubst, (T.EVar $ coerce i, t)) Just Nothing -> do fr <- fresh insertSig (coerce i) (Just fr) - return (nullSubst, (T.EId $ coerce i, fr)) + return (nullSubst, (T.EVar $ coerce i, fr)) Nothing -> throwError $ "Unbound variable: " <> printTree i EInj i -> do constr <- gets constructors case M.lookup (coerce i) constr of - Just t -> return (nullSubst, (T.EId $ coerce i, t)) + Just t -> return (nullSubst, (T.EVar $ coerce i, t)) Nothing -> throwError $ "Constructor: '" @@ -280,7 +249,7 @@ algoW = \case ( withBinding (coerce name) fr $ do (s1, (e', t')) <- exprErr (algoW e) err let varType = apply s1 fr - let newArr = T.TFun varType t' + let newArr = TFun varType t' return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr)) ) err @@ -314,7 +283,7 @@ algoW = \case (s0, (e0', t0)) <- algoW e0 applySt s0 $ do (s1, (e1', t1)) <- algoW e1 - s2 <- exprErr (unify (apply s1 t0) (T.TFun t1 fr)) err + s2 <- exprErr (unify (apply s1 t0) (TFun t1 fr)) err let t = apply s2 fr let comp = s2 `compose` s1 `compose` s0 return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t)) @@ -346,33 +315,33 @@ makeLambda :: Exp -> [T.Ident] -> Exp makeLambda = foldl (flip (EAbs . coerce)) -- | Unify two types producing a new substitution -unify :: T.Type -> T.Type -> Infer Subst +unify :: Type -> Type -> Infer Subst unify t0 t1 = do case (t0, t1) of - (T.TFun a b, T.TFun c d) -> do + (TFun a b, TFun c d) -> do s1 <- unify a c s2 <- unify (apply s1 b) (apply s1 d) return $ s1 `compose` s2 ----------- TODO: BE CAREFUL!!!! THIS IS PROBABLY WRONG!!! ----------- - (T.TVar (T.MkTVar a), t@(T.TData _ _)) -> return $ M.singleton a t - (t@(T.TData _ _), T.TVar (T.MkTVar b)) -> return $ M.singleton b t + (TVar (T.MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t + (t@(TData _ _), TVar (T.MkTVar b)) -> return $ M.singleton (coerce b) t ------------------------------------------------------------------- - (T.TVar (T.MkTVar a), t) -> occurs a t - (t, T.TVar (T.MkTVar b)) -> occurs b t - (T.TAll _ t, b) -> unify t b - (a, T.TAll _ t) -> unify a t - (T.TLit a, T.TLit b) -> + (TVar (T.MkTVar a), t) -> occurs (coerce a) t + (t, TVar (T.MkTVar b)) -> occurs (coerce b) t + (TAll _ t, b) -> unify t b + (a, TAll _ t) -> unify a t + (TLit a, TLit b) -> if a == b then return M.empty else throwError . unwords $ [ "Can not unify" - , "'" <> printTree (T.TLit a) <> "'" + , "'" <> printTree (TLit a) <> "'" , "with" - , "'" <> printTree (T.TLit b) <> "'" + , "'" <> printTree (TLit b) <> "'" ] - (T.TData name t, T.TData name' t') -> + (TData name t, TData name' t') -> if name == name' && length t == length t' then do xs <- zipWithM unify t t' @@ -380,7 +349,7 @@ unify t0 t1 = do else throwError $ unwords - [ "T.Type constructor:" + [ "Type constructor:" , printTree name , "(" <> printTree t <> ")" , "does not match with:" @@ -398,42 +367,42 @@ unify t0 t1 = do I.E. { a = a -> b } is an unsolvable constraint since there is no substitution where these are equal -} -occurs :: T.Ident -> T.Type -> Infer Subst -occurs i t@(T.TVar _) = return (M.singleton i t) +occurs :: T.Ident -> Type -> Infer Subst +occurs i t@(TVar _) = return (M.singleton i t) occurs i t = if S.member i (free t) then throwError $ unwords [ "Occurs check failed, can't unify" - , printTree (T.TVar $ T.MkTVar i) + , printTree (TVar $ T.MkTVar (coerce i)) , "with" , printTree t ] else return $ M.singleton i t -- | Generalize a type over all free variables in the substitution set -generalize :: Map T.Ident T.Type -> T.Type -> T.Type +generalize :: Map T.Ident Type -> Type -> Type generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t) where - go :: [T.Ident] -> T.Type -> T.Type - go [] t = t - go (x : xs) t = T.TAll (T.MkTVar x) (go xs t) - removeForalls :: T.Type -> T.Type - removeForalls (T.TAll _ t) = removeForalls t - removeForalls (T.TFun t1 t2) = T.TFun (removeForalls t1) (removeForalls t2) - removeForalls t = t + go :: [T.Ident] -> Type -> Type + go [] t = t + go (x : xs) t = TAll (T.MkTVar (coerce x)) (go xs t) + removeForalls :: Type -> Type + removeForalls (TAll _ t) = removeForalls t + removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2) + removeForalls t = t {- | Instantiate a polymorphic type. The free type variables are substituted with fresh ones. -} -inst :: T.Type -> Infer T.Type +inst :: Type -> Infer Type inst = \case - T.TAll (T.MkTVar bound) t -> do + TAll (T.MkTVar bound) t -> do fr <- fresh - let s = M.singleton bound fr + let s = M.singleton (coerce bound) fr apply s <$> inst t - T.TFun t1 t2 -> T.TFun <$> inst t1 <*> inst t2 + TFun t1 t2 -> TFun <$> inst t1 <*> inst t2 rest -> return rest -- | Compose two substitution sets @@ -455,41 +424,40 @@ class FreeVars t where -- | Get all free variables from t free :: t -> Set T.Ident -instance FreeVars T.Type where - free :: T.Type -> Set T.Ident - free (T.TVar (T.MkTVar a)) = S.singleton a - free (T.TAll (T.MkTVar bound) t) = S.singleton bound `S.intersection` free t - free (T.TLit _) = mempty - free (T.TFun a b) = free a `S.union` free b +instance FreeVars Type where + free :: Type -> Set T.Ident + free (TVar (T.MkTVar a)) = S.singleton (coerce a) + free (TAll (T.MkTVar bound) t) = S.singleton (coerce bound) `S.intersection` free t + free (TLit _) = mempty + free (TFun a b) = free a `S.union` free b -- \| Not guaranteed to be correct - free (T.TData _ a) = + free (TData _ a) = foldl' (\acc x -> free x `S.union` acc) S.empty a -instance SubstType T.Type where - apply :: Subst -> T.Type -> T.Type +instance SubstType Type where + apply :: Subst -> Type -> Type apply sub t = do case t of - T.TLit a -> T.TLit a - T.TVar (T.MkTVar a) -> case M.lookup a sub of - Nothing -> T.TVar (T.MkTVar $ coerce a) - Just t -> t - T.TAll (T.MkTVar i) t -> case M.lookup i sub of - Nothing -> T.TAll (T.MkTVar i) (apply sub t) - Just _ -> apply sub t - T.TFun a b -> T.TFun (apply sub a) (apply sub b) - T.TData name a -> T.TData name (map (apply sub) a) -instance FreeVars (Map T.Ident T.Type) where - free :: Map T.Ident T.Type -> Set T.Ident + TLit a -> TLit a + TVar (T.MkTVar a) -> case M.lookup (coerce a) sub of + Nothing -> TVar (T.MkTVar $ coerce a) + Just t -> t + TAll (T.MkTVar i) t -> case M.lookup (coerce i) sub of + Nothing -> TAll (T.MkTVar i) (apply sub t) + Just _ -> apply sub t + TFun a b -> TFun (apply sub a) (apply sub b) + TData name a -> TData name (map (apply sub) a) +instance FreeVars (Map T.Ident Type) where + free :: Map T.Ident Type -> Set T.Ident free m = foldl' S.union S.empty (map free $ M.elems m) -instance SubstType (Map T.Ident T.Type) where - apply :: Subst -> Map T.Ident T.Type -> Map T.Ident T.Type +instance SubstType (Map T.Ident Type) where + apply :: Subst -> Map T.Ident Type -> Map T.Ident Type apply s = M.map (apply s) -instance SubstType T.Exp where - apply :: Subst -> T.Exp -> T.Exp +instance SubstType (T.Exp' Type) where apply s = \case - T.EId i -> T.EId i + T.EVar i -> T.EVar i T.ELit lit -> T.ELit lit T.ELet (T.Bind (ident, t1) args e1) e2 -> T.ELet @@ -499,19 +467,18 @@ instance SubstType T.Exp where T.EAdd e1 e2 -> T.EAdd (apply s e1) (apply s e2) T.EAbs ident e -> T.EAbs ident (apply s e) T.ECase e brnch -> T.ECase (apply s e) (apply s brnch) + T.EInj{} -> error "implement" -instance SubstType T.Branch where - apply :: Subst -> T.Branch -> T.Branch +instance SubstType (T.Branch' Type) where apply s (T.Branch (i, t) e) = T.Branch (apply s i, apply s t) (apply s e) -instance SubstType T.Pattern where - apply :: Subst -> T.Pattern -> T.Pattern +instance SubstType (T.Pattern' Type) where apply s = \case T.PVar (iden, t) -> T.PVar (iden, apply s t) - T.PLit (lit, t) -> T.PLit (lit, apply s t) - T.PInj i ps -> T.PInj i $ apply s ps - T.PCatch -> T.PCatch - T.PEnum i -> T.PEnum i + T.PLit (lit, t) -> T.PLit (lit, apply s t) + T.PInj i ps -> T.PInj i $ apply s ps + T.PCatch -> T.PCatch + T.PEnum i -> T.PEnum i instance SubstType a => SubstType [a] where apply s = map (apply s) @@ -519,7 +486,7 @@ instance SubstType a => SubstType [a] where instance (SubstType a, SubstType b) => SubstType (a, b) where apply s (a, b) = (apply s a, apply s b) -instance SubstType T.Id where +instance SubstType (T.Id' Type) where apply s (name, t) = (name, apply s t) -- | Apply substitutions to the environment. @@ -531,7 +498,7 @@ nullSubst :: Subst nullSubst = M.empty -- | Generate a new fresh variable and increment the state counter -fresh :: Infer T.Type +fresh :: Infer Type fresh = do c <- gets nextChar n <- gets count @@ -545,34 +512,34 @@ fresh = do fresh else if n == 0 - then return . T.TVar . T.MkTVar . T.Ident $ [c] - else return . T.TVar . T.MkTVar . T.Ident $ [c] ++ show n + then return . TVar . T.MkTVar $ LIdent [c] + else return . TVar . T.MkTVar . LIdent $ [c] ++ show n next :: Char -> Char next 'z' = 'a' -next a = succ a +next a = succ a -- | Run the monadic action with an additional binding -withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> T.Type -> m a -> m a +withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> Type -> m a -> m a withBinding i p = local (\st -> st{vars = M.insert i p (vars st)}) -- | Run the monadic action with several additional bindings -withBindings :: (Monad m, MonadReader Ctx m) => [(T.Ident, T.Type)] -> m a -> m a +withBindings :: (Monad m, MonadReader Ctx m) => [(T.Ident, Type)] -> m a -> m a withBindings xs = local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs}) -- | Insert a function signature into the environment -insertSig :: T.Ident -> Maybe T.Type -> Infer () +insertSig :: T.Ident -> Maybe Type -> Infer () insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)}) -- | Insert a constructor with its data type -insertConstr :: T.Ident -> T.Type -> Infer () +insertConstr :: T.Ident -> Type -> Infer () insertConstr i t = modify (\st -> st{constructors = M.insert i t (constructors st)}) -------- PATTERN MATCHING --------- -checkCase :: T.Type -> [Branch] -> Infer (Subst, [T.Branch], T.Type) +checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type) checkCase _ [] = throwError "Atleast one case required" checkCase expT brnchs = do (subs, injTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs @@ -594,23 +561,23 @@ checkCase expT brnchs = do let comp = sub2 `compose` sub1 `compose` sub0 return (comp, apply comp injs, apply comp returns_type) -inferBranch :: Branch -> Infer (Subst, T.Type, T.Branch, T.Type) +inferBranch :: Branch -> Infer (Subst, Type, T.Branch' Type, Type) inferBranch (Branch pat expr) = do newPat@(pat, branchT) <- inferPattern pat (sub, newExp@(_, exprT)) <- withPattern pat (algoW expr) return (sub, apply sub branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT) -withPattern :: T.Pattern -> Infer a -> Infer a +withPattern :: T.Pattern' Type -> Infer a -> Infer a withPattern p ma = case p of T.PVar (x, t) -> withBinding x t ma - T.PInj _ ps -> foldl' (flip withPattern) ma ps - T.PLit _ -> ma - T.PCatch -> ma - T.PEnum _ -> ma + T.PInj _ ps -> foldl' (flip withPattern) ma ps + T.PLit _ -> ma + T.PCatch -> ma + T.PEnum _ -> ma -inferPattern :: Pattern -> Infer (T.Pattern, T.Type) +inferPattern :: Pattern -> Infer (T.Pattern' Type, Type) inferPattern = \case - PLit lit -> let lt = litType lit in return (T.PLit (toNew lit, lt), lt) + PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt) PInj constr patterns -> do t <- gets (M.lookup (coerce constr) . constructors) t <- maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") t @@ -644,28 +611,28 @@ inferPattern = \case ++ show (typeLength t - 1) ++ " arguments but has been given 0" ) - let (T.TData _data _ts) = t -- nasty nasty + let (TData _data _ts) = t -- nasty nasty frs <- mapM (const fresh) _ts - return (T.PEnum $ coerce p, T.TData _data frs) + return (T.PEnum $ coerce p, TData _data frs) PVar x -> do fr <- fresh let pvar = T.PVar (coerce x, fr) return (pvar, fr) -flattenType :: T.Type -> [T.Type] -flattenType (T.TFun a b) = flattenType a <> flattenType b -flattenType a = [a] +flattenType :: Type -> [Type] +flattenType (TFun a b) = flattenType a <> flattenType b +flattenType a = [a] -typeLength :: T.Type -> Int -typeLength (T.TFun a b) = typeLength a + typeLength b -typeLength _ = 1 +typeLength :: Type -> Int +typeLength (TFun a b) = typeLength a + typeLength b +typeLength _ = 1 -litType :: Lit -> T.Type -litType (LInt _) = int +litType :: Lit -> Type +litType (LInt _) = int litType (LChar _) = char -int = T.TLit "Int" -char = T.TLit "Char" +int = TLit "Int" +char = TLit "Char" partitionType :: Int -> -- Number of parameters to apply @@ -676,8 +643,8 @@ partitionType = go [] go acc 0 t = (acc, t) go acc i t = case t of TAll tvar t' -> second (TAll tvar) $ go acc i t' - TFun t1 t2 -> go (acc <> [t1]) (i - 1) t2 - _ -> error "Number of parameters and type doesn't match" + TFun t1 t2 -> go (acc <> [t1]) (i - 1) t2 + _ -> error "Number of parameters and type doesn't match" exprErr :: Infer a -> Exp -> Infer a exprErr ma exp = @@ -691,3 +658,19 @@ unzip4 = ) ([], [], [], []) +newtype Ctx = Ctx {vars :: Map T.Ident Type} + deriving (Show) + +data Env = Env + { count :: Int + , nextChar :: Char + , sigs :: Map T.Ident (Maybe Type) + , constructors :: Map T.Ident Type + , takenTypeVars :: Set T.Ident + } + deriving (Show) + +type Error = String +type Subst = Map T.Ident Type + +type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity)) diff --git a/src/TypeChecker/TypeCheckerIr.hs b/src/TypeChecker/TypeCheckerIr.hs index 05949c9..46d1127 100644 --- a/src/TypeChecker/TypeCheckerIr.hs +++ b/src/TypeChecker/TypeCheckerIr.hs @@ -1,24 +1,24 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE PatternSynonyms #-} +module TypeChecker.TypeCheckerIr ( + module Grammar.Abs, + module TypeChecker.TypeCheckerIr, +) where -module TypeChecker.TypeCheckerIr - ( module Grammar.Abs - , module TypeChecker.TypeCheckerIr - ) where - -import Data.String (IsString) -import Grammar.Abs (Lit (..), TVar (..)) -import Grammar.Print -import Prelude -import qualified Prelude as C (Eq, Ord, Read, Show) +import Data.String (IsString) +import Grammar.Abs (Lit (..), TVar (..)) +import Grammar.Print +import Prelude +import Prelude qualified as C (Eq, Ord, Read, Show) newtype Program' t = Program [Def' t] - deriving (C.Eq, C.Ord, C.Show, C.Read) + deriving (C.Eq, C.Ord, C.Show, C.Read) -data Def' t = DBind (Bind' t) - | DData (Data' t) - deriving (C.Eq, C.Ord, C.Show, C.Read) +data Def' t + = DBind (Bind' t) + | DData (Data' t) + deriving (C.Eq, C.Ord, C.Show, C.Read) data Type = TLit Ident @@ -26,24 +26,24 @@ data Type | TData Ident [Type] | TFun Type Type | TAll TVar Type - deriving (C.Eq, C.Ord, C.Show, C.Read) + deriving (C.Eq, C.Ord, C.Show, C.Read) data Data' t = Data t [Inj' t] - deriving (C.Eq, C.Ord, C.Show, C.Read) + deriving (C.Eq, C.Ord, C.Show, C.Read) data Inj' t = Inj Ident t - deriving (C.Eq, C.Ord, C.Show, C.Read) + deriving (C.Eq, C.Ord, C.Show, C.Read) newtype Ident = Ident String - deriving (C.Eq, C.Ord, C.Show, C.Read, IsString) + deriving (C.Eq, C.Ord, C.Show, C.Read, IsString) data Pattern' t - = PVar (Id' t) -- TODO should be Ident - | PLit (Lit, t) -- TODO should be Lit + = PVar (Id' t) -- TODO should be Ident + | PLit (Lit, t) -- TODO should be Lit | PCatch | PEnum Ident | PInj Ident [Pattern' t] -- TODO should be (Pattern' t, t) - deriving (C.Eq, C.Ord, C.Show, C.Read) + deriving (C.Eq, C.Ord, C.Show, C.Read) data Exp' t = EVar Ident @@ -52,18 +52,18 @@ data Exp' t | ELet (Bind' t) (ExpT' t) | EApp (ExpT' t) (ExpT' t) | EAdd (ExpT' t) (ExpT' t) - | EAbs Ident (ExpT' t) + | EAbs Ident (ExpT' t) | ECase (ExpT' t) [Branch' t] - deriving (C.Eq, C.Ord, C.Show, C.Read) + deriving (C.Eq, C.Ord, C.Show, C.Read) -type Id' t = (Ident, t) +type Id' t = (Ident, t) type ExpT' t = (Exp' t, t) data Bind' t = Bind (Id' t) [Id' t] (ExpT' t) deriving (C.Eq, C.Ord, C.Show, C.Read) data Branch' t = Branch (Pattern' t, t) (ExpT' t) - deriving (C.Eq, C.Ord, C.Show, C.Read) + deriving (C.Eq, C.Ord, C.Show, C.Read) instance Print Ident where prt i (Ident s) = prt i s @@ -72,127 +72,143 @@ instance Print t => Print (Program' t) where prt i (Program sc) = prPrec i 0 $ prt 0 sc instance Print t => Print (Bind' t) where - prt i (Bind sig@(name, _) parms rhs) = prPrec i 0 $ concatD - [ prtSig sig - , prt 0 name - , prtIdPs 0 parms - , doc $ showString "=" - , prt 0 rhs - ] + prt i (Bind sig@(name, _) parms rhs) = + prPrec i 0 $ + concatD + [ prtSig sig + , prt 0 name + , prtIdPs 0 parms + , doc $ showString "=" + , prt 0 rhs + ] prtSig :: Print t => Id' t -> Doc -prtSig (name, t) = concatD [ prt 0 name - , doc $ showString ":" - , prt 0 t - , doc $ showString ";" - ] +prtSig (name, t) = + concatD + [ prt 0 name + , doc $ showString ":" + , prt 0 t + , doc $ showString ";" + ] instance Print t => Print (ExpT' t) where - prt i (e, t) = concatD [ doc $ showString "(" - , prt i e - , doc $ showString "," - , prt i t - , doc $ showString ")" - ] + prt i (e, t) = + concatD + [ doc $ showString "(" + , prt i e + , doc $ showString "," + , prt i t + , doc $ showString ")" + ] instance Print t => Print [Bind' t] where - prt _ [] = concatD [] - prt _ [x] = concatD [prt 0 x] - prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] prtIdPs :: Print t => Int -> [Id' t] -> Doc prtIdPs i = prPrec i 0 . concatD . map (prt i) instance Print t => Print (Id' t) where - prt i (name, t) = concatD [ doc $ showString "(" - , prt i name - , doc $ showString "," - , prt i t - , doc $ showString ")" - ] + prt i (name, t) = + concatD + [ doc $ showString "(" + , prt i name + , doc $ showString "," + , prt i t + , doc $ showString ")" + ] instance Print t => Print (Exp' t) where - prt i = \case - EVar name -> prPrec i 3 $ prt 0 name - EInj name -> prPrec i 3 $ prt 0 name - ELit lit -> prPrec i 3 $ prt 0 lit - ELet b e -> prPrec i 3 $ concatD - [ doc $ showString "let" - , prt 0 b - , doc $ showString "in" - , prt 0 e - ] - EApp e1 e2 -> prPrec i 2 $ concatD - [ prt 2 e1 - , prt 3 e2 - ] - EAdd e1 e2 -> prPrec i 1 $ concatD - [ prt 1 e1 - , doc $ showString "+" - , prt 2 e2 - ] - EAbs v e -> prPrec i 0 $ concatD - [ doc $ showString "\\" - , prt 0 v - , doc $ showString "." - , prt 0 e - ] - - ECase e branches -> prPrec i 0 $ concatD - [ doc $ showString "case" - , prt 0 e - , doc $ showString "of" - , doc $ showString "{" - , prt 0 branches - , doc $ showString "}" - ] - + prt i = \case + EVar name -> prPrec i 3 $ prt 0 name + EInj name -> prPrec i 3 $ prt 0 name + ELit lit -> prPrec i 3 $ prt 0 lit + ELet b e -> + prPrec i 3 $ + concatD + [ doc $ showString "let" + , prt 0 b + , doc $ showString "in" + , prt 0 e + ] + EApp e1 e2 -> + prPrec i 2 $ + concatD + [ prt 2 e1 + , prt 3 e2 + ] + EAdd e1 e2 -> + prPrec i 1 $ + concatD + [ prt 1 e1 + , doc $ showString "+" + , prt 2 e2 + ] + EAbs v e -> + prPrec i 0 $ + concatD + [ doc $ showString "\\" + , prt 0 v + , doc $ showString "." + , prt 0 e + ] + ECase e branches -> + prPrec i 0 $ + concatD + [ doc $ showString "case" + , prt 0 e + , doc $ showString "of" + , doc $ showString "{" + , prt 0 branches + , doc $ showString "}" + ] instance Print t => Print (Branch' t) where - prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp]) + prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp]) instance Print t => Print [Branch' t] where - prt _ [] = concatD [] - prt _ [x] = concatD [prt 0 x] - prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] instance Print t => Print (Def' t) where - prt i = \case - DBind bind -> prPrec i 0 (concatD [prt 0 bind]) - DData data_ -> prPrec i 0 (concatD [prt 0 data_]) + prt i = \case + DBind bind -> prPrec i 0 (concatD [prt 0 bind]) + DData data_ -> prPrec i 0 (concatD [prt 0 data_]) instance Print t => Print (Data' t) where - prt i = \case - Data type_ injs -> prPrec i 0 (concatD [doc (showString "data"), prt 0 type_, doc (showString "where"), doc (showString "{"), prt 0 injs, doc (showString "}")]) + prt i = \case + Data type_ injs -> prPrec i 0 (concatD [doc (showString "data"), prt 0 type_, doc (showString "where"), doc (showString "{"), prt 0 injs, doc (showString "}")]) instance Print t => Print (Inj' t) where - prt i = \case - Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_]) + prt i = \case + Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_]) instance Print t => Print (Pattern' t) where - prt i = \case - PVar name -> prPrec i 1 (concatD [prt 0 name]) - PLit (lit, _) -> prPrec i 1 (concatD [prt 0 lit]) - PCatch -> prPrec i 1 (concatD [doc (showString "_")]) - PEnum name -> prPrec i 1 (concatD [prt 0 name]) - PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns]) + prt i = \case + PVar name -> prPrec i 1 (concatD [prt 0 name]) + PLit (lit, _) -> prPrec i 1 (concatD [prt 0 lit]) + PCatch -> prPrec i 1 (concatD [doc (showString "_")]) + PEnum name -> prPrec i 1 (concatD [prt 0 name]) + PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns]) instance Print t => Print [Def' t] where - prt _ [] = concatD [] - prt _ [x] = concatD [prt 0 x] - prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] instance Print [Type] where - prt _ [] = concatD [] - prt _ (x:xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs] + prt _ [] = concatD [] + prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs] instance Print Type where - prt i = \case - TLit uident -> prPrec i 1 (concatD [prt 0 uident]) - TVar tvar -> prPrec i 1 (concatD [prt 0 tvar]) - TData uident types -> prPrec i 1 (concatD [prt 0 uident, doc (showString "("), prt 0 types, doc (showString ")")]) - TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2]) - TAll tvar type_ -> prPrec i 0 (concatD [doc (showString "forall"), prt 0 tvar, doc (showString "."), prt 0 type_]) + prt i = \case + TLit uident -> prPrec i 1 (concatD [prt 0 uident]) + TVar tvar -> prPrec i 1 (concatD [prt 0 tvar]) + TData uident types -> prPrec i 1 (concatD [prt 0 uident, doc (showString "("), prt 0 types, doc (showString ")")]) + TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2]) + TAll tvar type_ -> prPrec i 0 (concatD [doc (showString "forall"), prt 0 tvar, doc (showString "."), prt 0 type_]) type Program = Program' Type type Def = Def' Type @@ -201,9 +217,8 @@ type Bind = Bind' Type type Branch = Branch' Type type Pattern = Pattern' Type type Inj = Inj' Type -type Exp = Exp' Type +type Exp = Exp' Type type ExpT = ExpT' Type -type Id = Id' Type +type Id = Id' Type pattern DBind' id vars expt = DBind (Bind id vars expt) -pattern DData' typ injs = DData (Data typ injs) - +pattern DData' typ injs = DData (Data typ injs)