diff --git a/cabal.project.local b/cabal.project.local deleted file mode 100644 index 0432756..0000000 --- a/cabal.project.local +++ /dev/null @@ -1,2 +0,0 @@ -ignore-project: False -tests: True diff --git a/cabal.project.local~ b/cabal.project.local~ deleted file mode 100644 index 40fdf41..0000000 --- a/cabal.project.local~ +++ /dev/null @@ -1,2 +0,0 @@ -ignore-project: False -tests: False diff --git a/src/TypeChecker/TypeChecker.hs b/src/TypeChecker/TypeChecker.hs index 2b53760..9cc37ee 100644 --- a/src/TypeChecker/TypeChecker.hs +++ b/src/TypeChecker/TypeChecker.hs @@ -19,6 +19,7 @@ import Data.Map qualified as M import Data.Maybe (fromJust) import Data.Set (Set) import Data.Set qualified as S +import Debug.Trace (trace) import Grammar.Abs import Grammar.Print (printTree) import TypeChecker.TypeCheckerIr ( @@ -31,8 +32,7 @@ import TypeChecker.TypeCheckerIr ( import TypeChecker.TypeCheckerIr qualified as T initCtx = Ctx mempty - -initEnv = Env 0 mempty mempty +initEnv = Env 0 'a' mempty mempty mempty runPretty :: Exp -> Either Error String runPretty = fmap (printTree . fst) . run . inferExp @@ -82,39 +82,39 @@ retType a = a checkPrg :: Program -> Infer T.Program checkPrg (Program bs) = do preRun bs - -- Type check the program twice to produce all top-level types in the first pass through - _ <- checkDef bs - bs'' <- checkDef bs - return $ T.Program bs'' - where - preRun :: [Def] -> Infer () - preRun [] = return () - preRun (x : xs) = case x of - DSig (Sig n t) -> do - gets (M.member (coerce n) . sigs) - >>= flip - when - ( throwError $ - "Duplicate signatures for function '" - <> printTree n - <> "'" - ) - insertSig (coerce n) (Just $ toNew t) >> preRun xs - DBind (Bind n _ _) -> do - s <- gets sigs - case M.lookup (coerce n) s of - Nothing -> insertSig (coerce n) Nothing >> preRun xs - Just _ -> preRun xs - DData d@(Data _ _) -> checkData d >> preRun xs + bs' <- checkDef bs + return $ T.Program bs' - checkDef :: [Def] -> Infer [T.Def] - checkDef [] = return [] - checkDef (x : xs) = case x of - (DBind b) -> do - b' <- checkBind b - fmap (T.DBind b' :) (checkDef xs) - (DData d) -> fmap (T.DData (toNew d) :) (checkDef xs) - (DSig _) -> checkDef xs +preRun :: [Def] -> Infer () +preRun [] = return () +preRun (x : xs) = case x of + DSig (Sig n t) -> do + collect (collectTypeVars t) + gets (M.member (coerce n) . sigs) + >>= flip + when + ( throwError $ + "Duplicate signatures for function '" + <> printTree n + <> "'" + ) + insertSig (coerce n) (Just $ toNew t) >> preRun xs + DBind (Bind n _ e) -> do + collect (collectTypeVars e) + s <- gets sigs + case M.lookup (coerce n) s of + Nothing -> insertSig (coerce n) Nothing >> preRun xs + Just _ -> preRun xs + DData d@(Data t _) -> collect (collectTypeVars t) >> checkData d >> preRun xs + +checkDef :: [Def] -> Infer [T.Def] +checkDef [] = return [] +checkDef (x : xs) = case x of + (DBind b) -> do + b' <- checkBind b + fmap (T.DBind b' :) (checkDef xs) + (DData d) -> fmap (T.DData (toNew d) :) (checkDef xs) + (DSig _) -> checkDef xs checkBind :: Bind -> Infer T.Bind checkBind (Bind name args e) = do @@ -171,6 +171,23 @@ inferExp e = do let subbed = apply s t return $ second (const subbed) (e', t) +class CollectTVars a where + collectTypeVars :: a -> Set T.Ident + +instance CollectTVars Exp where + collectTypeVars (EAnn e t) = collectTypeVars t `S.union` collectTypeVars e + collectTypeVars _ = S.empty + +instance CollectTVars Type where + collectTypeVars (TVar (MkTVar i)) = S.singleton (coerce i) + collectTypeVars (TAll _ t) = collectTypeVars t + collectTypeVars (TFun t1 t2) = collectTypeVars t1 `S.union` collectTypeVars t2 + collectTypeVars (TData _ ts) = foldl' (\acc x -> acc `S.union` collectTypeVars x) S.empty ts + collectTypeVars _ = S.empty + +collect :: Set T.Ident -> Infer () +collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st}) + class NewType a b where toNew :: a -> b @@ -321,8 +338,9 @@ algoW = \case (sub, (e', t)) <- algoW caseExpr (subst, injs, ret_t) <- checkCase t injs let comp = subst `compose` sub - let t' = apply comp ret_t - return (comp, apply comp (T.ECase (e', t) injs, t')) + trace ("EXPR: " ++ show (apply comp t)) pure () + trace ("CASES: " ++ show (apply comp ret_t)) pure () + return (comp, apply comp (T.ECase (e', t) injs, ret_t)) makeLambda :: Exp -> [T.Ident] -> Exp makeLambda = foldl (flip (EAbs . coerce)) @@ -335,7 +353,7 @@ unify t0 t1 = do s1 <- unify a c s2 <- unify (apply s1 b) (apply s1 d) return $ s1 `compose` s2 - ----------- TODO: CAREFUL!!!! THIS IS PROBABLY WRONG!!! ----------- + ----------- TODO: BE CAREFUL!!!! THIS IS PROBABLY WRONG!!! ----------- (T.TVar (T.MkTVar a), t@(T.TData _ _)) -> return $ M.singleton a t (t@(T.TData _ _), T.TVar (T.MkTVar b)) -> return $ M.singleton b t ------------------------------------------------------------------- @@ -517,9 +535,24 @@ nullSubst = M.empty -- | Generate a new fresh variable and increment the state counter fresh :: Infer T.Type fresh = do + c <- gets nextChar n <- gets count - modify (\st -> st{count = n + 1}) - return . T.TVar . T.MkTVar . T.Ident $ show n + taken <- gets takenTypeVars + if c == 'z' + then do + modify (\st -> st{count = succ (count st), nextChar = 'a'}) + else modify (\st -> st{nextChar = next (nextChar st)}) + if coerce [c] `S.member` taken + then do + fresh + else + if n == 0 + then return . T.TVar . T.MkTVar . T.Ident $ [c] + else return . T.TVar . T.MkTVar . T.Ident $ [c] ++ show n + +next :: Char -> Char +next 'z' = 'a' +next a = succ a -- | Run the monadic action with an additional binding withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> T.Type -> m a -> m a @@ -567,7 +600,7 @@ inferBranch :: Branch -> Infer (Subst, T.Type, T.Branch, T.Type) inferBranch (Branch pat expr) = do newPat@(pat, branchT) <- inferPattern pat (sub, newExp@(_, exprT)) <- withPattern pat (algoW expr) - return (sub, branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT) + return (sub, apply sub branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT) withPattern :: T.Pattern -> Infer a -> Infer a withPattern p ma = case p of @@ -586,15 +619,36 @@ inferPattern = \case let numArgs = typeLength t - 1 let (vs, ret) = fromJust (unsnoc $ flattenType t) patterns <- mapM inferPattern patterns - unless (length patterns == numArgs) (throwError $ "The constructor '" ++ printTree constr ++ "'" ++ " should have " ++ show numArgs ++ " arguments but has been given " ++ show (length patterns)) + unless + (length patterns == numArgs) + ( throwError $ + "The constructor '" + ++ printTree constr + ++ "'" + ++ " should have " + ++ show numArgs + ++ " arguments but has been given " + ++ show (length patterns) + ) sub <- composeAll <$> zipWithM unify vs (map snd patterns) return (T.PInj (coerce constr) (map fst patterns), apply sub ret) PCatch -> (T.PCatch,) <$> fresh PEnum p -> do t <- gets (M.lookup (coerce p) . constructors) t <- maybeToRightM ("Constructor: " <> printTree p <> " does not exist") t - unless (typeLength t == 1) (throwError $ "The constructor '" ++ printTree p ++ "'" ++ " should have " ++ show (typeLength t - 1) ++ " arguments but has been given 0") - return (T.PEnum $ coerce p, t) + unless + (typeLength t == 1) + ( throwError $ + "The constructor '" + ++ printTree p + ++ "'" + ++ " should have " + ++ show (typeLength t - 1) + ++ " arguments but has been given 0" + ) + let (T.TData _data _ts) = t -- nasty nasty + frs <- mapM (const fresh) _ts + return (T.PEnum $ coerce p, T.TData _data frs) PVar x -> do fr <- fresh let pvar = T.PVar (coerce x, fr) @@ -632,4 +686,9 @@ exprErr ma exp = catchError ma (\x -> throwError $ x <> " in expression: \n" <> printTree exp) unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d]) -unzip4 = foldl' (\(as, bs, cs, ds) (a, b, c, d) -> (as ++ [a], bs ++ [b], cs ++ [c], ds ++ [d])) ([], [], [], []) +unzip4 = + foldl' + ( \(as, bs, cs, ds) (a, b, c, d) -> + (as ++ [a], bs ++ [b], cs ++ [c], ds ++ [d]) + ) + ([], [], [], []) diff --git a/src/TypeChecker/TypeCheckerIr.hs b/src/TypeChecker/TypeCheckerIr.hs index 692fec8..f2419d5 100644 --- a/src/TypeChecker/TypeCheckerIr.hs +++ b/src/TypeChecker/TypeCheckerIr.hs @@ -10,6 +10,7 @@ import Control.Monad.State import Data.Char (isDigit) import Data.Functor.Identity (Identity) import Data.Map (Map) +import Data.Set (Set) import Data.String qualified import Grammar.Print import Prelude @@ -20,8 +21,10 @@ newtype Ctx = Ctx {vars :: Map Ident Type} data Env = Env { count :: Int + , nextChar :: Char , sigs :: Map Ident (Maybe Type) , constructors :: Map Ident Type + , takenTypeVars :: Set Ident } deriving (Show) diff --git a/test_program b/test_program index b43a99a..ac209ea 100644 --- a/test_program +++ b/test_program @@ -1,9 +1,28 @@ -data Bool () where { - True : Bool () - False : Bool () - }; +data Maybe (a) where { + Nothing : Maybe (a) + Just : a -> Maybe (a) + }; -main = case True of { - True => 1; - False => 0; - }; +fmap : (a -> b) -> Maybe (a) -> Maybe (b) ; +fmap f ma = case ma of { + Nothing => Nothing ; + Just a => Just (f a) ; +}; + +pure : a -> Maybe (a) ; +pure x = Just x ; + +ap mf ma = case mf of { + Just f => case ma of { + Nothing => Nothing; + Just a => Just (f a); + }; + Nothing => Nothing; +}; + +return = pure; + +bind ma f = case ma of { + Nothing => Nothing ; + Just a => f a ; +}; diff --git a/tests/Tests.hs b/tests/Tests.hs index d1b87a6..eb28db8 100644 --- a/tests/Tests.hs +++ b/tests/Tests.hs @@ -16,149 +16,153 @@ main :: IO () main = do mapM_ hspec goods mapM_ hspec bads + mapM_ hspec bes goods = - [ specify "Basic polymorphism with multiple type variables" $ - run - ( D.do - _const - "main = const 'a' 65 ;" - ) - `shouldSatisfy` ok - , specify "Head with a correct signature is accepted" $ - run - ( D.do - _list - _headSig - _head - ) - `shouldSatisfy` ok - , specify "A basic arithmetic function should be able to be inferred" $ - run - ( D.do - "plusOne x = x + 1 ;" - "main x = plusOne x ;" - ) - `shouldBe` run - ( D.do - "plusOne : Int -> Int ;" - "plusOne x = x + 1 ;" - "main : Int -> Int ;" - "main x = plusOne x ;" - ) - , specify "A basic arithmetic function should be able to be inferred" $ - run - ( D.do - "plusOne x = x + 1 ;" - ) - `shouldBe` run - ( D.do - "plusOne : Int -> Int ;" - "plusOne x = x + 1 ;" - ) - , specify "Most simple inference possible" $ - run - ( D.do - _id - ) - `shouldSatisfy` ok - , specify "Pattern matching on a nested list" $ - run - ( D.do - _list - "main : List (List (a)) -> Int ;" - "main xs = case xs of {" - " Cons Nil _ => 1 ;" - " _ => 0 ;" - "};" - ) - `shouldSatisfy` ok - , specify "List of function Int -> Int functions should be inferred corretly" $ - run - ( D.do - _list - "main xs = case xs of {" - " Cons f _ => f 1 ;" - " Nil => 0 ;" - " };" - ) - `shouldBe` run - ( D.do - _list - "main : List (Int -> Int) -> Int ;" - "main xs = case xs of {" - " Cons f _ => f 1 ;" - " Nil => 0 ;" - " };" - ) + [ testSatisfy + "Basic polymorphism with multiple type variables" + ( D.do + _const + "main = const 'a' 65 ;" + ) + ok + , testSatisfy + "Head with a correct signature is accepted" + ( D.do + _List + _headSig + _head + ) + ok + , testSatisfy + "Most simple inference possible" + ( D.do + _id + ) + ok + , testSatisfy + "Pattern matching on a nested list" + ( D.do + _List + "main : List (List (a)) -> Int ;" + "main xs = case xs of {" + " Cons Nil _ => 1 ;" + " _ => 0 ;" + "};" + ) + ok ] bads = - [ specify "Infinite type unification should not succeed" $ - run - ( D.do - "main = \\x. x x ;" - ) - `shouldSatisfy` bad - , specify "Pattern matching using different types should not succeed" $ - run - ( D.do - _list - "bad xs = case xs of {" - " 1 => 0 ;" - " Nil => 0 ;" - "};" - ) - `shouldSatisfy` bad - , specify "Using a concrete function (data type) on a skolem variable should not succeed" $ - run - ( D.do - _bool - _not - "f : a -> Bool () ;" - "f x = not x ;" - ) - `shouldSatisfy` bad - , specify "Using a concrete function (primitive type) on a skolem variable should not succeed" $ - run - ( D.do - "plusOne : Int -> Int ;" - "plusOne x = x + 1 ;" - "f : a -> Int ;" - " f x = plusOne x ;" - ) - `shouldSatisfy` bad - , specify "A function without signature used in an incompatible context should not succeed" $ - run - ( D.do - "main = _id 1 2 ;" - "_id x = x ;" - ) - `shouldSatisfy` bad - , specify "Pattern matching on literal and _list should not succeed" $ - run - ( D.do - _list - "length : List (c) -> Int;" - "length _list = case _list of {" - " 0 => 0;" - " Cons x xs => 1 + length xs;" - "};" - ) - `shouldSatisfy` bad - , specify "List of function Int -> Int functions should not be usable on Char" $ - run - ( D.do - _list - "main : List (Int -> Int) -> Int ;" - "main xs = case xs of {" - " Cons f _ => f 'a' ;" - " Nil => 0 ;" - " };" - ) - `shouldSatisfy` bad + [ testSatisfy + "Infinite type unification should not succeed" + ( D.do + "main = \\x. x x ;" + ) + bad + , testSatisfy + "Pattern matching using different types should not succeed" + ( D.do + _List + "bad xs = case xs of {" + " 1 => 0 ;" + " Nil => 0 ;" + "};" + ) + bad + , testSatisfy + "Using a concrete function (data type) on a skolem variable should not succeed" + ( D.do + _Bool + _not + "f : a -> Bool () ;" + "f x = not x ;" + ) + bad + , testSatisfy + "Using a concrete function (primitive type) on a skolem variable should not succeed" + ( D.do + "plusOne : Int -> Int ;" + "plusOne x = x + 1 ;" + "f : a -> Int ;" + "f x = plusOne x ;" + ) + bad + , testSatisfy + "A function without signature used in an incompatible context should not succeed" + ( D.do + "main = _id 1 2 ;" + "_id x = x ;" + ) + bad + , testSatisfy + "Pattern matching on literal and _List should not succeed" + ( D.do + _List + "length : List (c) -> Int;" + "length _List = case _List of {" + " 0 => 0;" + " Cons x xs => 1 + length xs;" + "};" + ) + bad + , testSatisfy + "List of function Int -> Int functions should not be usable on Char" + ( D.do + _List + "main : List (Int -> Int) -> Int ;" + "main xs = case xs of {" + " Cons f _ => f 'a' ;" + " Nil => 0 ;" + " };" + ) + bad ] +bes = + [ testBe + "A basic arithmetic function should be able to be inferred" + ( D.do + "plusOne x = x + 1 ;" + "main x = plusOne x ;" + ) + ( D.do + "plusOne : Int -> Int ;" + "plusOne x = x + 1 ;" + "main : Int -> Int ;" + "main x = plusOne x ;" + ) + , testBe + "A basic arithmetic function should be able to be inferred" + ( D.do + "plusOne x = x + 1 ;" + ) + ( D.do + "plusOne : Int -> Int ;" + "plusOne x = x + 1 ;" + ) + , testBe + "List of function Int -> Int functions should be inferred corretly" + ( D.do + _List + "main xs = case xs of {" + " Cons f _ => f 1 ;" + " Nil => 0 ;" + " };" + ) + ( D.do + _List + "main : List (Int -> Int) -> Int ;" + "main xs = case xs of {" + " Cons f _ => f 1 ;" + " Nil => 0 ;" + " };" + ) + ] + +testSatisfy desc test satisfaction = specify desc $ run test `shouldSatisfy` satisfaction +testBe desc test shouldbe = specify desc $ run test `shouldBe` run shouldbe + run = typecheck <=< pProgram . myLexer ok (Right _) = True @@ -171,7 +175,7 @@ bad = not . ok _const = D.do "const : a -> b -> a ;" "const x y = x ;" -_list = D.do +_List = D.do "data List (a) where" " {" " Nil : List (a)" @@ -187,7 +191,7 @@ _head = D.do " Cons x xs => x ;" " };" -_bool = D.do +_Bool = D.do "data Bool () where {" " True : Bool ()" " False : Bool ()" @@ -200,3 +204,15 @@ _not = D.do " False => True ;" "};" _id = "id x = x ;" + +_Maybe = D.do + "data Maybe (a) where {" + " Nothing : Maybe (a)" + " Just : a -> Maybe (a)" + " };" + +_fmap = D.do + "fmap f ma = case ma of {" + " Nothing => Nothing ;" + " Just a => Just (f a) ;" + "};"