From 4b24755b9324fd6bed5b22ce300c66bc72ab15e1 Mon Sep 17 00:00:00 2001 From: sebastian Date: Mon, 27 Mar 2023 22:38:39 +0200 Subject: [PATCH] cleaned up implementations and added check for duplicate constructors --- src/TypeChecker/TypeCheckerHm.hs | 233 ++++++++++++++++--------------- 1 file changed, 123 insertions(+), 110 deletions(-) diff --git a/src/TypeChecker/TypeCheckerHm.hs b/src/TypeChecker/TypeCheckerHm.hs index e7dff50..026810f 100644 --- a/src/TypeChecker/TypeCheckerHm.hs +++ b/src/TypeChecker/TypeCheckerHm.hs @@ -13,7 +13,6 @@ import Control.Monad.Reader import Control.Monad.State import Data.Bifunctor (second) import Data.Coerce (coerce) -import Data.Foldable (traverse_) import Data.Function (on) import Data.List (foldl') import Data.List.Extra (unsnoc) @@ -22,7 +21,6 @@ import Data.Map qualified as M import Data.Maybe (fromJust) import Data.Set (Set) import Data.Set qualified as S -import Data.String import Grammar.Abs import Grammar.Print (printTree) import TypeChecker.TypeCheckerIr qualified as T @@ -45,33 +43,61 @@ typecheck :: Program -> Either Error (T.Program' Type) typecheck = run . checkPrg checkData :: Data -> Infer () -checkData d = do - case d of - (Data typ@(TData _ ts) constrs) -> do - unless - (all isPoly ts) - (throwError $ unwords ["Data type incorrectly declared"]) - traverse_ - ( \(Inj name' t') -> - if typ == returnType t' - then insertConstr (coerce name') t' - else - throwError $ - unwords - [ "return type of constructor:" - , printTree name' - , "with type:" - , printTree (returnType t') - , "does not match data: " - , printTree typ - ] - ) - constrs - _ -> - throwError $ - "incorrectly declared data type '" - <> printTree d - <> "'" +checkData (Data typ injs) = do + (name, tvars) <- go typ + mapM_ (\i -> typecheckInj i name tvars) injs + where + go = \case + TData name typs + | Right tvars' <- mapM toTVar typs -> + pure (name, tvars') + TAll _ _ -> throwError "Explicit foralls not allowed, for now" + _ -> throwError $ unwords ["Bad data type definition: ", printTree typ] + +typecheckInj :: Inj -> UIdent -> [TVar] -> Infer () +typecheckInj (Inj c inj_typ) name tvars + | Right False <- boundTVars tvars inj_typ = + throwError "Unbound type variables" + | TData name' typs <- returnType inj_typ + , Right tvars' <- mapM toTVar typs + , name' == name + , tvars' == tvars = do + exist <- existInj (coerce c) + case exist of + Just t -> throwError $ Aux.do + "Constructor" + quote $ coerce name + "with type" + quote $ printTree t + "already exist" + Nothing -> insertInj (coerce c) inj_typ + | otherwise = + throwError $ + unwords + [ "Bad type constructor: " + , show name + , "\nExpected: " + , printTree . TData name $ map TVar tvars + , "\nActual: " + , printTree $ returnType inj_typ + ] + where + boundTVars :: [TVar] -> Type -> Either Error Bool + boundTVars tvars' = \case + TAll{} -> throwError "Explicit foralls not allowed, for now" + TFun t1 t2 -> do + t1' <- boundTVars tvars t1 + t2' <- boundTVars tvars t2 + return $ t1' && t2' + TVar tvar -> return $ tvar `elem` tvars' + TData _ typs -> and <$> mapM (boundTVars tvars) typs + TLit _ -> return True + TEVar _ -> error "TEVar in data type declaration" + +toTVar :: Type -> Either String TVar +toTVar = \case + TVar tvar -> pure tvar + _ -> throwError "Not a type variable" returnType :: Type -> Type returnType (TFun _ t2) = returnType t2 @@ -91,10 +117,9 @@ preRun (x : xs) = case x of gets (M.member (coerce n) . sigs) >>= flip when - ( throwError $ - "Duplicate signatures for function '" - <> printTree n - <> "'" + ( throwError $ Aux.do + "Duplicate signatures for function" + quote $ printTree n ) insertSig (coerce n) (Just t) >> preRun xs DBind (Bind n _ e) -> do @@ -126,12 +151,11 @@ checkBind (Bind name args e) = do Just (Just t') -> do unless (args_t `typeEq` t') - ( throwError $ - "Inferred type '" - ++ printTree args_t - ++ " does not match specified type '" - ++ printTree t' - ++ "'" + ( throwError $ Aux.do + "Inferred type" + quote $ printTree args_t + "does not match specified type" + quote $ printTree t' ) return $ T.Bind (coerce name, t') [] e _ -> do @@ -195,7 +219,7 @@ instance CollectTVars Type where collect :: Set T.Ident -> Infer () collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st}) -algoW :: Exp -> Infer (Subst, (T.ExpT' Type)) +algoW :: Exp -> Infer (Subst, T.ExpT' Type) algoW = \case err@(EAnn e t) -> do (s1, (e', t')) <- exprErr (algoW e) err @@ -209,14 +233,14 @@ algoW = \case , printTree t' ] ) - s2 <- exprErr (unify (t) t') err + s2 <- exprErr (unify t t') err let comp = s2 `compose` s1 return (comp, apply comp (e', t)) -- \| ------------------ -- \| Γ ⊢ i : Int, ∅ - ELit lit -> return (nullSubst, (T.ELit $ lit, litType lit)) + ELit lit -> return (nullSubst, (T.ELit lit, litType lit)) -- \| x : σ ∈ Γ   τ = inst(σ) -- \| ---------------------- -- \| Γ ⊢ x : τ, ∅ @@ -234,7 +258,7 @@ algoW = \case return (nullSubst, (T.EVar $ coerce i, fr)) Nothing -> throwError $ "Unbound variable: " <> printTree i EInj i -> do - constr <- gets constructors + constr <- gets injections case M.lookup (coerce i) constr of Just t -> return (nullSubst, (T.EVar $ coerce i, t)) Nothing -> @@ -334,14 +358,12 @@ unify t0 t1 = do (TLit a, TLit b) -> if a == b then return M.empty - else - throwError - . unwords - $ [ "Can not unify" - , "'" <> printTree (TLit a) <> "'" - , "with" - , "'" <> printTree (TLit b) <> "'" - ] + else throwError $ + Aux.do + "Can not unify" + quote $ printTree (TLit a) + "with" + quote $ printTree (TLit b) (TData name t, TData name' t') -> if name == name' && length t == length t' then do @@ -351,41 +373,26 @@ unify t0 t1 = do Aux.do "Type constructor:" printTree name - "(" - printTree t - ")" + quote $ printTree t "does not match with:" printTree name' - "(" - printTree t' - ")" - - -- [ "Type constructor:" - -- , printTree name - -- , "(" <> printTree t <> ")" - -- , "does not match with:" - -- , printTree name' - -- , "(" <> printTree t' <> ")" - -- ] + quote $ printTree t' (TEVar a, TEVar b) -> if a == b then return M.empty - else - throwError - . unwords - $ [ "Can not unify" - , "'" <> printTree (TEVar a) <> "'" - , "with" - , "'" <> printTree (TEVar b) <> "'" - ] + else throwError $ + Aux.do + "Can not unify" + quote $ printTree (TEVar a) + "with" + quote $ printTree (TEVar b) (a, b) -> do - throwError - . unwords - $ [ "Can not unify" - , "'" <> printTree a <> "'" - , "with" - , "'" <> printTree b <> "'" - ] + throwError $ + Aux.do + "Can not unify" + quote $ printTree a + "with" + quote $ printTree b {- | Check if a type is contained in another type. I.E. { a = a -> b } is an unsolvable constraint since there is no substitution @@ -395,14 +402,12 @@ occurs :: T.Ident -> Type -> Infer Subst occurs i t@(TVar _) = return (M.singleton i t) occurs i t = if S.member i (free t) - then - throwError $ - unwords - [ "Occurs check failed, can't unify" - , printTree (TVar $ T.MkTVar (coerce i)) - , "with" - , printTree t - ] + then throwError $ + Aux.do + "Occurs check failed, can't unify" + quote $ printTree (TVar $ T.MkTVar (coerce i)) + "with" + quote $ printTree t else return $ M.singleton i t -- | Generalize a type over all free variables in the substitution set @@ -455,6 +460,7 @@ instance FreeVars Type where free (TLit _) = mempty free (TFun a b) = free a `S.union` free b free (TData _ a) = free a + free (TEVar _) = S.empty instance FreeVars a => FreeVars [a] where free = let f acc x = acc `S.union` free x in foldl' f S.empty @@ -471,7 +477,10 @@ instance SubstType Type where Nothing -> TAll (T.MkTVar i) (apply sub t) Just _ -> apply sub t TFun a b -> TFun (apply sub a) (apply sub b) - TData name a -> TData name (map (apply sub) a) + TData name a -> TData name (apply sub a) + TEVar (MkTEVar a) -> case M.lookup (coerce a) sub of + Nothing -> TEVar (MkTEVar a) + Just t -> t instance FreeVars (Map T.Ident Type) where free :: Map T.Ident Type -> Set T.Ident @@ -555,9 +564,12 @@ insertSig :: T.Ident -> Maybe Type -> Infer () insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)}) -- | Insert a constructor with its data type -insertConstr :: T.Ident -> Type -> Infer () -insertConstr i t = - modify (\st -> st{constructors = M.insert i t (constructors st)}) +insertInj :: T.Ident -> Type -> Infer () +insertInj i t = + modify (\st -> st{injections = M.insert i t (injections st)}) + +existInj :: T.Ident -> Infer (Maybe Type) +existInj n = gets (M.lookup n . injections) -------- PATTERN MATCHING --------- @@ -601,37 +613,35 @@ inferPattern :: Pattern -> Infer (T.Pattern' Type, Type) inferPattern = \case PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt) PInj constr patterns -> do - t <- gets (M.lookup (coerce constr) . constructors) + t <- gets (M.lookup (coerce constr) . injections) t <- maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") t let numArgs = typeLength t - 1 let (vs, ret) = fromJust (unsnoc $ flattenType t) patterns <- mapM inferPattern patterns unless (length patterns == numArgs) - ( throwError $ - "The constructor '" - ++ printTree constr - ++ "'" - ++ " should have " - ++ show numArgs - ++ " arguments but has been given " - ++ show (length patterns) + ( throwError $ Aux.do + "The constructor" + quote $ printTree constr + " should have " + show numArgs + " arguments but has been given " + show (length patterns) ) sub <- composeAll <$> zipWithM unify vs (map snd patterns) return (T.PInj (coerce constr) (apply sub (map fst patterns)), apply sub ret) PCatch -> (T.PCatch,) <$> fresh PEnum p -> do - t <- gets (M.lookup (coerce p) . constructors) + t <- gets (M.lookup (coerce p) . injections) t <- maybeToRightM ("Constructor: " <> printTree p <> " does not exist") t unless (typeLength t == 1) - ( throwError $ - "The constructor '" - ++ printTree p - ++ "'" - ++ " should have " - ++ show (typeLength t - 1) - ++ " arguments but has been given 0" + ( throwError $ Aux.do + "The constructor" + quote $ printTree p + " should have " + show (typeLength t - 1) + " arguments but has been given 0" ) let (TData _data _ts) = t -- nasty nasty frs <- mapM (const fresh) _ts @@ -687,7 +697,7 @@ data Env = Env { count :: Int , nextChar :: Char , sigs :: Map T.Ident (Maybe Type) - , constructors :: Map T.Ident Type + , injections :: Map T.Ident Type , takenTypeVars :: Set T.Ident } deriving (Show) @@ -697,3 +707,6 @@ type Subst = Map T.Ident Type newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a} deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env) + +quote :: String -> String +quote s = "'" ++ s ++ "'"