From 3335ab7a57c4420a65bce8f8de1f3692b86f417c Mon Sep 17 00:00:00 2001 From: sebastian Date: Wed, 22 Mar 2023 21:26:14 +0100 Subject: [PATCH] compatible, EId rule for parsing is not working, testing not done yet --- Grammar.cf | 4 +- src/Main.hs | 8 +- src/Renamer/Renamer.hs | 35 ++--- src/TypeChecker/TypeChecker.hs | 212 ++++++++++++++++--------------- src/TypeChecker/TypeCheckerIr.hs | 28 ++-- 5 files changed, 146 insertions(+), 141 deletions(-) diff --git a/Grammar.cf b/Grammar.cf index 27dfd05..28696c6 100644 --- a/Grammar.cf +++ b/Grammar.cf @@ -51,14 +51,14 @@ ELit. Exp4 ::= Lit ; EApp. Exp3 ::= Exp3 Exp4 ; EAdd. Exp1 ::= Exp1 "+" Exp2 ; ELet. Exp ::= "let" LIdent "=" Exp "in" Exp ; -EAbs. Exp ::= "\\" LIdent "." Exp ; +EAbs. Exp ::= "\\" Ident "." Exp ; ECase. Exp ::= "case" Exp "of" "{" [Inj] "}"; ------------------------------------------------------------------------------- -- * LITERALS ------------------------------------------------------------------------------- -LInt. Lit ::= Integer ; +LInt. Lit ::= Integer ; LChar. Lit ::= Char ; ------------------------------------------------------------------------------- diff --git a/src/Main.hs b/src/Main.hs index 0a00cd6..5a96404 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -12,7 +12,7 @@ import Renamer.Renamer (rename) import System.Environment (getArgs) import System.Exit (exitFailure, exitSuccess) --- import TypeChecker.TypeChecker (typecheck) +import TypeChecker.TypeChecker (typecheck) main :: IO () main = @@ -32,9 +32,9 @@ main' s = do renamed <- fromRenamerErr . rename $ parsed putStrLn $ printTree renamed - -- putStrLn "\n-- TypeChecker --" - -- typechecked <- fromTypeCheckerErr $ typecheck renamed - -- putStrLn $ show typechecked + putStrLn "\n-- TypeChecker --" + typechecked <- fromTypeCheckerErr $ typecheck renamed + putStrLn $ printTree typechecked -- putStrLn "\n-- Lambda Lifter --" -- let lifted = lambdaLift typechecked diff --git a/src/Renamer/Renamer.hs b/src/Renamer/Renamer.hs index aac8b16..9f69185 100644 --- a/src/Renamer/Renamer.hs +++ b/src/Renamer/Renamer.hs @@ -21,6 +21,7 @@ import Data.Map (Map) import Data.Map qualified as Map import Data.Maybe (fromMaybe) import Data.Tuple.Extra (dupe) +import Data.Coerce (coerce) import Grammar.Abs -- | Rename all variables and local binds @@ -30,15 +31,15 @@ rename (Program defs) = Program <$> renameDefs defs renameDefs :: [Def] -> Either String [Def] renameDefs defs = runIdentity $ runExceptT $ evalStateT (runRn $ mapM renameDef defs) initCxt where - initNames = Map.fromList [dupe name | DBind (Bind name _ _) <- defs] + initNames = Map.fromList [dupe (coerce name) | DBind (Bind name _ _) <- defs] renameDef :: Def -> Rn Def renameDef = \case DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ DBind (Bind name vars rhs) -> do - (new_names, vars') <- newNames initNames vars + (new_names, vars') <- newNames initNames (coerce vars) rhs' <- snd <$> renameExp new_names rhs - pure . DBind $ Bind name vars' rhs' + pure . DBind $ Bind name (coerce vars') rhs' DData (Data (Indexed cname types) constrs) -> do tvars' <- mapM nextNameTVar tvars let tvars_lt = zip tvars tvars' @@ -90,11 +91,11 @@ newtype Rn a = Rn {runRn :: StateT Cxt (ExceptT String Identity) a} deriving (Functor, Applicative, Monad, MonadState Cxt) -- | Maps old to new name -type Names = Map LIdent LIdent +type Names = Map Ident Ident renameExp :: Names -> Exp -> Rn (Names, Exp) renameExp old_names = \case - EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names) + EId n -> pure (coerce old_names, EId . fromMaybe n $ Map.lookup n (coerce old_names)) ELit lit -> pure (old_names, ELit lit) EApp e1 e2 -> do (env1, e1') <- renameExp old_names e1 @@ -107,14 +108,14 @@ renameExp old_names = \case -- TODO fix shadowing ELet name rhs e -> do - (new_names, name') <- newName old_names name + (new_names, name') <- newName old_names (coerce name) (new_names', rhs') <- renameExp new_names rhs (new_names'', e') <- renameExp new_names' e - pure (new_names'', ELet name' rhs' e') + pure (new_names'', ELet (coerce name') rhs' e') EAbs par e -> do - (new_names, par') <- newName old_names par + (new_names, par') <- newName old_names (coerce par) (new_names', e') <- renameExp new_names e - pure (new_names', EAbs par' e') + pure (new_names', EAbs (coerce par') e') EAnn e t -> do (new_names, e') <- renameExp old_names e t' <- renameTVars t @@ -138,8 +139,8 @@ renameInj ns (Inj init e) = do renameInit :: Names -> Init -> Rn (Names, Init) renameInit ns i = case i of InitConstructor cs vars -> do - (ns_new, vars') <- newNames ns vars - return (ns_new, InitConstructor cs vars') + (ns_new, vars') <- newNames ns (coerce vars) + return (ns_new, InitConstructor cs (coerce vars')) rest -> return (ns, rest) renameTVars :: Type -> Rn Type @@ -169,26 +170,26 @@ substitute tvar1 tvar2 typ = case typ of substitute' = substitute tvar1 tvar2 -- | Create a new name and add it to name environment. -newName :: Names -> LIdent -> Rn (Names, LIdent) +newName :: Names -> Ident -> Rn (Names, Ident) newName env old_name = do new_name <- makeName old_name pure (Map.insert old_name new_name env, new_name) -- | Create multiple names and add them to the name environment -newNames :: Names -> [LIdent] -> Rn (Names, [LIdent]) +newNames :: Names -> [Ident] -> Rn (Names, [Ident]) newNames = mapAccumM newName -- | Annotate name with number and increment the number @prefix ⇒ prefix_number@. -makeName :: LIdent -> Rn LIdent -makeName (LIdent prefix) = do +makeName :: Ident -> Rn Ident +makeName (Ident prefix) = do i <- gets var_counter - let name = LIdent $ prefix ++ "_" ++ show i + let name = Ident $ prefix ++ "_" ++ show i modify $ \cxt -> cxt{var_counter = succ cxt.var_counter} pure name nextNameTVar :: TVar -> Rn TVar nextNameTVar (MkTVar (LIdent s)) = do i <- gets tvar_counter - let tvar = MkTVar . LIdent $ s ++ "_" ++ show i + let tvar = MkTVar $ coerce $ s ++ "_" ++ show i modify $ \cxt -> cxt{tvar_counter = succ cxt.tvar_counter} pure tvar diff --git a/src/TypeChecker/TypeChecker.hs b/src/TypeChecker/TypeChecker.hs index 4b9269d..a3929b5 100644 --- a/src/TypeChecker/TypeChecker.hs +++ b/src/TypeChecker/TypeChecker.hs @@ -57,7 +57,7 @@ checkData d = do traverse_ ( \(Constructor name' t') -> if TIndexed typ == retType t' - then insertConstr name' t' + then insertConstr (coerce name') (toNew t') else throwError $ unwords @@ -85,7 +85,7 @@ checkPrg (Program bs) = do preRun [] = return () preRun (x : xs) = case x of -- TODO: Check for no overlapping signature definitions - DSig (Sig n t) -> insertSig n t >> preRun xs + DSig (Sig n t) -> insertSig (coerce n) (toNew t) >> preRun xs DBind (Bind{}) -> preRun xs DData d@(Data _ _) -> checkData d >> preRun xs @@ -100,13 +100,13 @@ checkPrg (Program bs) = do checkBind :: Bind -> Infer T.Bind checkBind (Bind name args e) = do - let lambda = makeLambda e (reverse args) + let lambda = makeLambda e (reverse $ coerce args) e@(_, t') <- inferExp lambda -- TODO: Check for match against existing signatures return $ T.Bind (coerce name, t') [] e -- (apply s e) where - makeLambda :: Exp -> [LIdent] -> Exp - makeLambda = foldl (flip EAbs) + makeLambda :: Exp -> [Ident] -> Exp + makeLambda = foldl (flip (EAbs . coerce)) {- | Check if two types are considered equal For the purpose of the algorithm two polymorphic types are always considered @@ -138,7 +138,7 @@ isPoly _ = False inferExp :: Exp -> Infer T.ExpT inferExp e = do - (s, t, e') <- algoW e + (s, (e', t)) <- algoW e let subbed = apply s t return $ replace subbed (e', t) @@ -151,15 +151,18 @@ class NewType a b where instance NewType Type T.Type where toNew = \case TLit i -> T.TLit $ coerce i - TVar v -> T.TVar v + TVar v -> T.TVar $ toNew v TFun t1 t2 -> T.TFun (toNew t1) (toNew t2) - TAll b t -> T.TAll b (toNew t) + TAll b t -> T.TAll (toNew b) (toNew t) TIndexed i -> T.TIndexed (toNew i) TEVar _ -> error "Should not exist after typechecker" instance NewType Indexed T.Indexed where toNew (Indexed name vars) = T.Indexed (coerce name) (map toNew vars) +instance NewType TVar T.TVar where + toNew (MkTVar i) = T.MkTVar $ coerce i + algoW :: Exp -> Infer (Subst, T.ExpT) algoW = \case -- \| TODO: More testing need to be done. Unsure of the correctness of this @@ -178,14 +181,14 @@ algoW = \case applySt s1 $ do s2 <- unify (toNew t) t' let comp = s2 `compose` s1 - return (comp, (apply comp e', toNew t)) + return (comp, apply comp (e', toNew t)) -- \| ------------------ -- \| Γ ⊢ i : Int, ∅ ELit lit -> - let lt = toNew $ litType lit - in return (nullSubst, (T.ELit lt lit, lt)) + let lt = litType lit + in return (nullSubst, (T.ELit lit, lt)) -- \| x : σ ∈ Γ   τ = inst(σ) -- \| ---------------------- -- \| Γ ⊢ x : τ, ∅ @@ -193,15 +196,15 @@ algoW = \case EId i -> do var <- asks vars case M.lookup i var of - Just t -> inst (toNew t) >>= \x -> return (nullSubst, x, T.EId (i, x)) + Just t -> inst t >>= \(x) -> return (nullSubst, (T.EId (i, x), x)) Nothing -> do sig <- gets sigs case M.lookup i sig of - Just t -> return (nullSubst, toNew t, T.EId (i, toNew t)) + Just t -> return (nullSubst, (T.EId (i, t), t)) Nothing -> do constr <- gets constructors case M.lookup i constr of - Just t -> return (nullSubst, toNew t, T.EId (i, toNew t)) + Just t -> return (nullSubst, (T.EId (i, t), t)) Nothing -> throwError $ "Unbound variable: " ++ show i @@ -212,11 +215,11 @@ algoW = \case EAbs name e -> do fr <- fresh - withBinding (coerce name) (Forall [] (toNew fr)) $ do - (s1, t', e') <- algoW e - let varType = toNew $ apply s1 fr - let newArr = T.TFun varType (toNew t') - return (s1, newArr, apply s1 $ T.EAbs newArr (coerce name, varType) (e', newArr)) + withBinding (coerce name) fr $ do + (s1, (e', t')) <- algoW e + let varType = apply s1 fr + let newArr = T.TFun varType t' + return (s1, apply s1 $ (T.EAbs (coerce name, varType) (e', newArr), newArr)) -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁ -- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int) @@ -225,17 +228,16 @@ algoW = \case -- This might be wrong EAdd e0 e1 -> do - (s1, t0, e0') <- algoW e0 + (s1, (e0', t0)) <- algoW e0 applySt s1 $ do - (s2, t1, e1') <- algoW e1 + (s2, (e1', t1)) <- algoW e1 -- applySt s2 $ do - s3 <- unify (apply s2 t0) (T.TLit "Int") - s4 <- unify (apply s3 t1) (T.TLit "Int") + s3 <- unify (apply s2 t0) int + s4 <- unify (apply s3 t1) int let comp = s4 `compose` s3 `compose` s2 `compose` s1 return ( comp - , T.TLit "Int" - , apply comp $ T.EAdd (T.TLit "Int") (e0', t0) (e1', t1) + , apply comp $ (T.EAdd (e0', t0) (e1', t1), int) ) -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1 @@ -244,15 +246,15 @@ algoW = \case -- \| Γ ⊢ e₀ e₁ : S₂τ', S₂S₁S₀ EApp e0 e1 -> do - fr <- toNew <$> fresh - (s0, t0, e0') <- algoW e0 + fr <- fresh + (s0, (e0', t0)) <- algoW e0 applySt s0 $ do - (s1, t1, e1') <- algoW e1 + (s1, (e1', t1)) <- algoW e1 -- applySt s1 $ do - s2 <- unify (apply s1 t0) (T.TFun (toNew t1) fr) + s2 <- unify (apply s1 t0) (T.TFun t1 fr) let t = apply s2 fr let comp = s2 `compose` s1 `compose` s0 - return (comp, t, apply comp $ T.EApp t (e0', t0) (e1', t1)) + return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t)) -- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁ -- \| ---------------------------------------------- @@ -261,19 +263,21 @@ algoW = \case -- The bar over S₀ and Γ means "generalize" ELet name e0 e1 -> do - (s1, t1, e0') <- algoW e0 + (s1, (e0', t1)) <- algoW e0 env <- asks vars let t' = generalize (apply s1 env) t1 - withBinding name t' $ do - (s2, t2, e1') <- algoW e1 + withBinding (coerce name) t' $ do + (s2, (e1', t2)) <- algoW e1 let comp = s2 `compose` s1 - return (comp, t2, apply comp $ T.ELet (T.Bind (name, t2) e0') e1') + return (comp, apply comp (T.ELet (T.Bind (coerce name, t2) [] (e0', t1)) (e1', t2), t2)) + + -- \| TODO: Add judgement ECase caseExpr injs -> do - (sub, t, e') <- algoW caseExpr + (sub, (e', t)) <- algoW caseExpr (subst, injs, ret_t) <- checkCase t injs let comp = subst `compose` sub let t' = apply comp ret_t - return (comp, t', T.ECase t' e' injs) + return (comp, (T.ECase (e', t) injs, t')) -- | Unify two types producing a new substitution unify :: T.Type -> T.Type -> Infer Subst @@ -283,8 +287,8 @@ unify t0 t1 = do s1 <- unify a c s2 <- unify (apply s1 b) (apply s1 d) return $ s1 `compose` s2 - (T.TVar t, b) -> occurs b t - (a, T.TVar t) -> occurs a t + (T.TVar (T.MkTVar a), t) -> occurs a t + (t, T.TVar (T.MkTVar b)) -> occurs b t (T.TAll _ t, b) -> unify t b (a, T.TAll _ t) -> unify a t (T.TLit a, T.TLit b) -> @@ -298,20 +302,20 @@ unify t0 t1 = do throwError $ unwords [ "T.Type constructor:" - , printT . Tree name - , "(" ++ printT . Tree t ++ ")" + , printTree name + , "(" ++ printTree t ++ ")" , "does not match with:" - , printT . Tree name' - , "(" ++ printT . Tree t' ++ ")" + , printTree name' + , "(" ++ printTree t' ++ ")" ] (a, b) -> do ctx <- ask env <- get throwError . unwords $ [ "T.Type:" - , printT . Tree a + , printTree a , "can't be unified with:" - , printT . Tree b + , printTree b , "\nCtx:" , show ctx , "\nEnv:" @@ -322,7 +326,7 @@ unify t0 t1 = do I.E. { a = a -> b } is an unsolvable constraint since there is no substitution where these are equal -} -occurs :: LIdent -> T.Type -> Infer Subst +occurs :: Ident -> T.Type -> Infer Subst occurs i t@(T.TVar _) = return (M.singleton i t) occurs i t = if S.member i (free t) @@ -330,26 +334,37 @@ occurs i t = throwError $ unwords [ "Occurs check failed, can't unify" - , printTree (TVar $ MkTVar i) + , printTree (T.TVar $ T.MkTVar i) , "with" , printTree t ] else return $ M.singleton i t -- | Generalize a type over all free variables in the substitution set -generalize :: Map Ident Poly -> Type -> Poly -generalize env t = Forall (S.toList $ free t S.\\ free env) t +generalize :: Map Ident T.Type -> T.Type -> T.Type +generalize env t = go freeVars $ removeForalls t + where + freeVars :: [Ident] + freeVars = S.toList $ free t S.\\ free env + go :: [Ident] -> T.Type -> T.Type + go [] t = t + go (x : xs) t = T.TAll (T.MkTVar x) (go xs t) + removeForalls :: T.Type -> T.Type + removeForalls (T.TAll _ t) = removeForalls t + removeForalls (T.TFun t1 t2) = T.TFun (removeForalls t1) (removeForalls t2) + removeForalls t = t {- | Instantiate a polymorphic type. The free type variables are substituted with fresh ones. -} inst :: T.Type -> Infer T.Type inst = \case - T.TAll bound t -> do + T.TAll (T.MkTVar bound) t -> do fr <- fresh - let s = M.singleton fr bound + let s = M.singleton bound fr apply s <$> inst t - _ -> undefined + T.TFun t1 t2 -> T.TFun <$> inst t1 <*> inst t2 + rest -> return rest -- | Compose two substitution sets compose :: Subst -> Subst -> Subst @@ -361,15 +376,15 @@ compose m1 m2 = M.map (apply m1) m2 `M.union` m1 -- | A class representing free variables functions class FreeVars t where -- | Get all free variables from t - free :: t -> Set LIdent + free :: t -> Set Ident -- | Apply a substitution to t apply :: Subst -> t -> t instance FreeVars T.Type where - free :: T.Type -> Set LIdent - free (T.TVar (MkTVar a)) = S.singleton a - free (T.TAll (MkTVar bound) t) = (S.singleton bound) `S.intersection` free t + free :: T.Type -> Set Ident + free (T.TVar (T.MkTVar a)) = S.singleton a + free (T.TAll (T.MkTVar bound) t) = (S.singleton bound) `S.intersection` free t free (T.TLit _) = mempty free (T.TFun a b) = free a `S.union` free b -- \| Not guaranteed to be correct @@ -380,53 +395,40 @@ instance FreeVars T.Type where apply sub t = do case t of T.TLit a -> T.TLit a - T.TVar (MkTVar a) -> case M.lookup a sub of - Nothing -> T.TVar (MkTVar a) + T.TVar (T.MkTVar a) -> case M.lookup a sub of + Nothing -> T.TVar (T.MkTVar $ coerce a) Just t -> t T.TAll bound t -> undefined T.TFun a b -> T.TFun (apply sub a) (apply sub b) T.TIndexed (T.Indexed name a) -> T.TIndexed (T.Indexed name (map (apply sub) a)) -instance FreeVars Poly where - free :: Poly -> Set LIdent - free (Forall xs t) = free t S.\\ S.fromList xs - apply :: Subst -> Poly -> Poly - apply s (Forall xs t) = Forall xs (apply (foldr M.delete s xs) t) - -instance FreeVars (Map LIdent Poly) where - free :: Map LIdent Poly -> Set LIdent +instance FreeVars (Map Ident T.Type) where + free :: Map Ident T.Type -> Set Ident free m = foldl' S.union S.empty (map free $ M.elems m) - apply :: Subst -> Map LIdent Poly -> Map LIdent Poly + apply :: Subst -> Map Ident T.Type -> Map Ident T.Type apply s = M.map (apply s) -instance FreeVars T.Exp where - free :: T.Exp -> Set LIdent +instance FreeVars T.ExpT where + free :: T.ExpT -> Set Ident free = error "free not implemented for T.Exp" - apply :: Subst -> T.Exp -> T.Exp + apply :: Subst -> T.ExpT -> T.ExpT apply s = \case - T.EId (ident, t) -> - T.EId (ident, apply s t) - T.ELit t lit -> - T.ELit (apply s t) lit - T.ELet (T.Bind (ident, t) args e1) e2 -> - T.ELet (T.Bind (ident, apply s t) args (apply s e1)) (apply s e2) - T.EApp t e1 e2 -> - T.EApp (apply s t) (apply s e1) (apply s e2) - T.EAdd t e1 e2 -> - T.EAdd (apply s t) (apply s e1) (apply s e2) - T.EAbs t1 (ident, t2) e -> - T.EAbs (apply s t1) (ident, apply s t2) (apply s e) - T.ECase t e injs -> - T.ECase (apply s t) (apply s e) (apply s injs) + (T.EId (i, innerT), outerT) -> (T.EId (i, apply s innerT), apply s outerT) + (T.ELit lit, t) -> (T.ELit lit, apply s t) + (T.ELet (T.Bind (ident, t1) args e1) e2, t2) -> (T.ELet (T.Bind (ident, apply s t1) args (apply s e1)) (apply s e2), apply s t2) + (T.EApp e1 e2, t) -> (T.EApp (apply s e1) (apply s e2), (apply s t)) + (T.EAdd e1 e2, t) -> (T.EAdd (apply s e1) (apply s e2), (apply s t)) + (T.EAbs (ident, t2) e, t1) -> (T.EAbs (ident, apply s t2) (apply s e), (apply s t1)) + (T.ECase e injs, t) -> (T.ECase (apply s e) (apply s injs), (apply s t)) instance FreeVars T.Inj where - free :: T.Inj -> Set LIdent + free :: T.Inj -> Set Ident free = undefined apply :: Subst -> T.Inj -> T.Inj apply s (T.Inj (i, t) e) = T.Inj (i, apply s t) (apply s e) instance FreeVars [T.Inj] where - free :: [T.Inj] -> Set LIdent + free :: [T.Inj] -> Set Ident free = foldl' (\acc x -> free x `S.union` acc) mempty apply s = map (apply s) @@ -439,33 +441,33 @@ nullSubst :: Subst nullSubst = M.empty -- | Generate a new fresh variable and increment the state counter -fresh :: Infer Type +fresh :: Infer T.Type fresh = do n <- gets count modify (\st -> st{count = n + 1}) - return . TVar . MkTVar . LIdent $ show n + return . T.TVar . T.MkTVar . Ident $ show n -- | Run the monadic action with an additional binding -withBinding :: (Monad m, MonadReader Ctx m) => Ident -> Poly -> m a -> m a +withBinding :: (Monad m, MonadReader Ctx m) => Ident -> T.Type -> m a -> m a withBinding i p = local (\st -> st{vars = M.insert i p (vars st)}) -- | Run the monadic action with several additional bindings -withBindings :: (Monad m, MonadReader Ctx m) => [(Ident, Poly)] -> m a -> m a +withBindings :: (Monad m, MonadReader Ctx m) => [(Ident, T.Type)] -> m a -> m a withBindings xs = local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs}) -- | Insert a function signature into the environment -insertSig :: LIdent -> Type -> Infer () +insertSig :: Ident -> T.Type -> Infer () insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)}) -- | Insert a constructor with its data type -insertConstr :: UIdent -> Type -> Infer () +insertConstr :: Ident -> T.Type -> Infer () insertConstr i t = modify (\st -> st{constructors = M.insert i t (constructors st)}) -------- PATTERN MATCHING --------- -checkCase :: Type -> [Inj] -> Infer (Subst, [T.Inj], Type) +checkCase :: T.Type -> [Inj] -> Infer (Subst, [T.Inj], T.Type) checkCase expT injs = do (injTs, injs, returns) <- unzip3 <$> mapM checkInj injs (sub1, _) <- @@ -487,18 +489,17 @@ checkCase expT injs = do {- | fst = type of init | snd = type of expr -} -checkInj :: Inj -> Infer (Type, T.Inj, Type) +checkInj :: Inj -> Infer (T.Type, T.Inj, T.Type) checkInj (Inj it expr) = do (initT, vars) <- inferInit it - let converted = map (second (Forall [])) vars - (exprT, e) <- withBindings converted (inferExp expr) - return (initT, T.Inj (it, initT) e, exprT) + (e, exprT) <- withBindings vars (inferExp expr) + return (initT, T.Inj (it, initT) (e, exprT), exprT) -inferInit :: Init -> Infer (Type, [T.Id]) +inferInit :: Init -> Infer (T.Type, [T.Id]) inferInit = \case InitLit lit -> return (litType lit, mempty) InitConstructor fn vars -> do - gets (M.lookup fn . constructors) >>= \case + gets (M.lookup (coerce fn) . constructors) >>= \case Nothing -> throwError $ "Constructor: " ++ printTree fn ++ " does not exist" @@ -508,14 +509,17 @@ inferInit = \case Just (vs, ret) -> case length vars `compare` length vs of EQ -> do - return (ret, zip vars vs) + return (ret, zip (coerce vars) vs) _ -> throwError "Partial pattern match not allowed" InitCatch -> (,mempty) <$> fresh -flattenType :: Type -> [Type] -flattenType (TFun a b) = flattenType a ++ flattenType b +flattenType :: T.Type -> [T.Type] +flattenType (T.TFun a b) = flattenType a ++ flattenType b flattenType a = [a] -litType :: Lit -> Type -litType (LInt _) = TLit "Int" -litType (LChar _) = TLit "Char" +litType :: Lit -> T.Type +litType (LInt _) = int +litType (LChar _) = char + +int = T.TLit "Int" +char = T.TLit "Char" diff --git a/src/TypeChecker/TypeCheckerIr.hs b/src/TypeChecker/TypeCheckerIr.hs index bfc8a6a..9cf2059 100644 --- a/src/TypeChecker/TypeCheckerIr.hs +++ b/src/TypeChecker/TypeCheckerIr.hs @@ -12,9 +12,7 @@ import Grammar.Abs ( Ident (..), Init (..), Lit (..), - TVar (..), ) -import Grammar.Abs qualified as GA (Type (..)) import Grammar.Print import Prelude import Prelude qualified as C (Eq, Ord, Read, Show) @@ -23,13 +21,13 @@ import Prelude qualified as C (Eq, Ord, Read, Show) data Poly = Forall [Ident] Type deriving (Show) -newtype Ctx = Ctx {vars :: Map Ident Poly} +newtype Ctx = Ctx {vars :: Map Ident Type} deriving (Show) data Env = Env { count :: Int - , sigs :: Map Ident GA.Type - , constructors :: Map Ident GA.Type + , sigs :: Map Ident Type + , constructors :: Map Ident Type } deriving (Show) @@ -41,6 +39,9 @@ type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity)) newtype Program = Program [Def] deriving (C.Eq, C.Ord, C.Show, C.Read) +data TVar = MkTVar Ident + deriving (Show, Eq, Ord, Read) + data Type = TLit Ident | TVar TVar @@ -130,7 +131,7 @@ prtIdP i (name, t) = instance Print Exp where prt i = \case EId n -> prPrec i 3 $ concatD [prtId 0 n, doc $ showString "\n"] - ELit _ lit -> prPrec i 3 $ concatD [prt 0 lit, doc $ showString "\n"] + ELit lit -> prPrec i 3 $ concatD [prt 0 lit, doc $ showString "\n"] ELet bs e -> prPrec i 3 $ concatD @@ -140,34 +141,31 @@ instance Print Exp where , prt 0 e , doc $ showString "\n" ] - EApp _ e1 e2 -> + EApp e1 e2 -> prPrec i 2 $ concatD [ prt 2 e1 , prt 3 e2 ] - EAdd t e1 e2 -> + EAdd e1 e2 -> prPrec i 1 $ concatD [ doc $ showString "@" - , prt 0 t , prt 1 e1 , doc $ showString "+" , prt 2 e2 , doc $ showString "\n" ] - EAbs t n e -> + EAbs n e -> prPrec i 0 $ concatD [ doc $ showString "@" - , prt 0 t - , doc $ showString "\\" , prtId 0 n , doc $ showString "." , prt 0 e , doc $ showString "\n" ] - ECase t exp injs -> + ECase exp injs -> prPrec i 0 @@ -179,7 +177,6 @@ instance Print Exp where , prt 0 injs , doc (showString "}") , doc (showString ":") - , prt 0 t , doc $ showString "\n" ] ) @@ -196,6 +193,9 @@ instance Print [Inj] where prt _ [x] = concatD [prt 0 x] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] +instance Print TVar where + prt i (MkTVar id) = prt i id + instance Print Type where prt i = \case TLit uident -> prPrec i 2 (concatD [prt 0 uident])