diff --git a/src/Auxiliary.hs b/src/Auxiliary.hs index d27ac24..fb0b8cb 100644 --- a/src/Auxiliary.hs +++ b/src/Auxiliary.hs @@ -1,9 +1,15 @@ {-# LANGUAGE LambdaCase #-} + module Auxiliary (module Auxiliary) where -import Control.Monad.Error.Class (liftEither) -import Control.Monad.Except (MonadError) -import Data.Either.Combinators (maybeToRight) -import TypeChecker.TypeCheckerIr (Type (TFun)) + +import Control.Monad.Error.Class (liftEither) +import Control.Monad.Except (MonadError) +import Data.Either.Combinators (maybeToRight) +import TypeChecker.TypeCheckerIr (Type (TFun)) +import Prelude hiding ((>>), (>>=)) + +(>>) a b = a ++ " " ++ b +(>>=) a f = f a snoc :: a -> [a] -> [a] snoc x xs = xs ++ [x] @@ -15,9 +21,8 @@ mapAccumM :: Monad m => (s -> a -> m (s, b)) -> s -> [a] -> m (s, [b]) mapAccumM f = go where go acc = \case - [] -> pure (acc, []) - x:xs -> do - (acc', x') <- f acc x - (acc'', xs') <- go acc' xs - pure (acc'', x':xs') - + [] -> pure (acc, []) + x : xs -> do + (acc', x') <- f acc x + (acc'', xs') <- go acc' xs + pure (acc'', x' : xs') diff --git a/src/TypeChecker/TypeCheckerHm.hs b/src/TypeChecker/TypeCheckerHm.hs index 1254a87..e7dff50 100644 --- a/src/TypeChecker/TypeCheckerHm.hs +++ b/src/TypeChecker/TypeCheckerHm.hs @@ -1,10 +1,12 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QualifiedDo #-} -- | A module for type checking and inference using algorithm W, Hindley-Milner module TypeChecker.TypeCheckerHm where -import Auxiliary +import Auxiliary (maybeToRightM) +import Auxiliary qualified as Aux import Control.Monad.Except import Control.Monad.Identity (Identity, runIdentity) import Control.Monad.Reader @@ -28,14 +30,16 @@ import TypeChecker.TypeCheckerIr qualified as T initCtx = Ctx mempty initEnv = Env 0 'a' mempty mempty mempty -runPretty :: Exp -> Either Error String -runPretty = fmap (printTree . fst) . run . inferExp - run :: Infer a -> Either Error a run = runC initEnv initCtx runC :: Env -> Ctx -> Infer a -> Either Error a -runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e +runC e c = + runIdentity + . runExceptT + . flip runReaderT c + . flip evalStateT e + . runInfer typecheck :: Program -> Either Error (T.Program' Type) typecheck = run . checkPrg @@ -49,15 +53,15 @@ checkData d = do (throwError $ unwords ["Data type incorrectly declared"]) traverse_ ( \(Inj name' t') -> - if typ == retType t' - then insertConstr (coerce name') (t') + if typ == returnType t' + then insertConstr (coerce name') t' else throwError $ unwords [ "return type of constructor:" , printTree name' , "with type:" - , printTree (retType t') + , printTree (returnType t') , "does not match data: " , printTree typ ] @@ -69,9 +73,9 @@ checkData d = do <> printTree d <> "'" -retType :: Type -> Type -retType (TFun _ t2) = retType t2 -retType a = a +returnType :: Type -> Type +returnType (TFun _ t2) = returnType t2 +returnType a = a checkPrg :: Program -> Infer (T.Program' Type) checkPrg (Program bs) = do @@ -92,7 +96,7 @@ preRun (x : xs) = case x of <> printTree n <> "'" ) - insertSig (coerce n) (Just $ t) >> preRun xs + insertSig (coerce n) (Just t) >> preRun xs DBind (Bind n _ e) -> do collect (collectTypeVars e) s <- gets sigs @@ -107,10 +111,11 @@ checkDef (x : xs) = case x of (DBind b) -> do b' <- checkBind b fmap (T.DBind b' :) (checkDef xs) - (DData d) -> fmap ((T.DData (coerceData d)) :) (checkDef xs) + (DData d) -> fmap (T.DData (coerceData d) :) (checkDef xs) (DSig _) -> checkDef xs where - coerceData (Data t injs) = T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs + coerceData (Data t injs) = + T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs checkBind :: Bind -> Infer (T.Bind' Type) checkBind (Bind name args e) = do @@ -145,11 +150,11 @@ typeEq t1 (TAll _ t2) = t1 `typeEq` t2 typeEq (TVar _) (TVar _) = True typeEq _ _ = False -skolem :: Type -> Type -skolem (TVar (T.MkTVar a)) = TLit (coerce a) -skolem (TAll x t) = TAll x (skolem t) -skolem (TFun t1 t2) = (TFun `on` skolem) t1 t2 -skolem t = t +skolemize :: Type -> Type +skolemize (TVar (MkTVar a)) = TEVar (MkTEVar $ coerce a) +skolemize (TAll x t) = TAll x (skolemize t) +skolemize (TFun t1 t2) = (TFun `on` skolemize) t1 t2 +skolemize t = t isMoreSpecificOrEq :: Type -> Type -> Bool isMoreSpecificOrEq t1 (TAll _ t2) = isMoreSpecificOrEq t1 t2 @@ -204,10 +209,9 @@ algoW = \case , printTree t' ] ) - applySt s1 $ do - s2 <- exprErr (unify (t) t') err - let comp = s2 `compose` s1 - return (comp, apply comp (e', t)) + s2 <- exprErr (unify (t) t') err + let comp = s2 `compose` s1 + return (comp, apply comp (e', t)) -- \| ------------------ -- \| Γ ⊢ i : Int, ∅ @@ -262,16 +266,14 @@ algoW = \case err@(EAdd e0 e1) -> do (s1, (e0', t0)) <- algoW e0 - applySt s1 $ do - (s2, (e1', t1)) <- algoW e1 - -- applySt s2 $ do - s3 <- exprErr (unify (apply s2 t0) int) err - s4 <- exprErr (unify (apply s3 t1) int) err - let comp = s4 `compose` s3 `compose` s2 `compose` s1 - return - ( comp - , apply comp (T.EAdd (e0', t0) (e1', t1), int) - ) + (s2, (e1', t1)) <- algoW e1 + s3 <- exprErr (unify (apply s2 t0) int) err + s4 <- exprErr (unify (apply s3 t1) int) err + let comp = s4 `compose` s3 `compose` s2 `compose` s1 + return + ( comp + , apply comp (T.EAdd (e0', t0) (e1', t1), int) + ) -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1 -- \| τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ') @@ -281,12 +283,11 @@ algoW = \case err@(EApp e0 e1) -> do fr <- fresh (s0, (e0', t0)) <- algoW e0 - applySt s0 $ do - (s1, (e1', t1)) <- algoW e1 - s2 <- exprErr (unify (apply s1 t0) (TFun t1 fr)) err - let t = apply s2 fr - let comp = s2 `compose` s1 `compose` s0 - return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t)) + (s1, (e1', t1)) <- algoW e1 + s2 <- exprErr (unify (apply s1 t0) (TFun t1 fr)) err + let t = apply s2 fr + let comp = s2 `compose` s1 `compose` s0 + return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t)) -- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁ -- \| ---------------------------------------------- @@ -346,22 +347,45 @@ unify t0 t1 = do then do xs <- zipWithM unify t t' return $ foldr compose nullSubst xs + else throwError $ + Aux.do + "Type constructor:" + printTree name + "(" + printTree t + ")" + "does not match with:" + printTree name' + "(" + printTree t' + ")" + + -- [ "Type constructor:" + -- , printTree name + -- , "(" <> printTree t <> ")" + -- , "does not match with:" + -- , printTree name' + -- , "(" <> printTree t' <> ")" + -- ] + (TEVar a, TEVar b) -> + if a == b + then return M.empty else - throwError $ - unwords - [ "Type constructor:" - , printTree name - , "(" <> printTree t <> ")" - , "does not match with:" - , printTree name' - , "(" <> printTree t' <> ")" - ] + throwError + . unwords + $ [ "Can not unify" + , "'" <> printTree (TEVar a) <> "'" + , "with" + , "'" <> printTree (TEVar b) <> "'" + ] (a, b) -> do - throwError . unwords $ - [ "'" <> printTree a <> "'" - , "can't be unified with" - , "'" <> printTree b <> "'" - ] + throwError + . unwords + $ [ "Can not unify" + , "'" <> printTree a <> "'" + , "with" + , "'" <> printTree b <> "'" + ] {- | Check if a type is contained in another type. I.E. { a = a -> b } is an unsolvable constraint since there is no substitution @@ -415,7 +439,7 @@ composeAll = foldl' compose nullSubst -- TODO: Split this class into two separate classes, one for free variables -- and one for applying substitutions --- | A class representing free variables functions +-- | A class for substitutions class SubstType t where -- | Apply a substitution to t apply :: Subst -> t -> t @@ -430,9 +454,10 @@ instance FreeVars Type where free (TAll (T.MkTVar bound) t) = S.singleton (coerce bound) `S.intersection` free t free (TLit _) = mempty free (TFun a b) = free a `S.union` free b - -- \| Not guaranteed to be correct - free (TData _ a) = - foldl' (\acc x -> free x `S.union` acc) S.empty a + free (TData _ a) = free a + +instance FreeVars a => FreeVars [a] where + free = let f acc x = acc `S.union` free x in foldl' f S.empty instance SubstType Type where apply :: Subst -> Type -> Type @@ -447,13 +472,14 @@ instance SubstType Type where Just _ -> apply sub t TFun a b -> TFun (apply sub a) (apply sub b) TData name a -> TData name (map (apply sub) a) + instance FreeVars (Map T.Ident Type) where free :: Map T.Ident Type -> Set T.Ident - free m = foldl' S.union S.empty (map free $ M.elems m) + free = free . M.elems instance SubstType (Map T.Ident Type) where apply :: Subst -> Map T.Ident Type -> Map T.Ident Type - apply s = M.map (apply s) + apply = M.map . apply instance SubstType (T.Exp' Type) where apply s = \case @@ -467,7 +493,7 @@ instance SubstType (T.Exp' Type) where T.EAdd e1 e2 -> T.EAdd (apply s e1) (apply s e2) T.EAbs ident e -> T.EAbs ident (apply s e) T.ECase e brnch -> T.ECase (apply s e) (apply s brnch) - T.EInj{} -> error "implement" + T.EInj i -> T.EInj i instance SubstType (T.Branch' Type) where apply s (T.Branch (i, t) e) = T.Branch (apply s i, apply s t) (apply s e) @@ -489,10 +515,6 @@ instance (SubstType a, SubstType b) => SubstType (a, b) where instance SubstType (T.Id' Type) where apply s (name, t) = (name, apply s t) --- | Apply substitutions to the environment. -applySt :: Subst -> Infer a -> Infer a -applySt s = local (\st -> st{vars = apply s (vars st)}) - -- | Represents the empty substition set nullSubst :: Subst nullSubst = M.empty @@ -513,11 +535,11 @@ fresh = do else if n == 0 then return . TVar . T.MkTVar $ LIdent [c] - else return . TVar . T.MkTVar . LIdent $ [c] ++ show n - -next :: Char -> Char -next 'z' = 'a' -next a = succ a + else return . TVar . T.MkTVar . LIdent $ c : show n + where + next :: Char -> Char + next 'z' = 'a' + next a = succ a -- | Run the monadic action with an additional binding withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> Type -> m a -> m a @@ -673,4 +695,5 @@ data Env = Env type Error = String type Subst = Map T.Ident Type -type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity)) +newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a} + deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env)