diff --git a/src/TypeChecker/TypeCheckerHm.hs b/src/TypeChecker/TypeCheckerHm.hs index a371977..f17e589 100644 --- a/src/TypeChecker/TypeCheckerHm.hs +++ b/src/TypeChecker/TypeCheckerHm.hs @@ -47,35 +47,10 @@ typecheck = onLeft msg . run . checkPrg checkPrg :: Program -> Infer (T.Program' Type) checkPrg (Program bs) = do preRun bs - -- sgs <- gets sigs bs <- checkDef bs - -- return . prettify sgs . T.Program $ bs return . T.Program $ bs -- | Send the map of user declared signatures to not rename stuff the user defined -prettify :: Map T.Ident (Maybe Type) -> T.Program' Type -> T.Program' Type -prettify s (T.Program defs) = T.Program $ map (go s) defs - where - go :: Map T.Ident (Maybe Type) -> T.Def' Type -> T.Def' Type - go _ (T.DData d) = T.DData d - go m b@(T.DBind (T.Bind (name, t) args (e, et))) - | Just (Just _) <- M.lookup name m = b - | otherwise = - let fvs = nub $ freeOrdered t - m = M.fromList $ zip fvs letters - in T.DBind $ T.Bind (name, replace m t) args (fmap (replace m) e, replace m et) - -replace :: Map T.Ident T.Ident -> Type -> Type -replace m def@(TVar (MkTVar (LIdent a))) = case M.lookup (coerce a) m of - Just t -> TVar . MkTVar . LIdent $ coerce t - Nothing -> def -replace m (TFun t1 t2) = (TFun `on` replace m) t1 t2 -replace m (TData name ts) = TData name (map (replace m) ts) -replace m def@(TAll (MkTVar forall_) t) = case M.lookup (coerce forall_) m of - Just found -> TAll (MkTVar $ coerce found) (replace m t) - Nothing -> def -replace _ t = t - preRun :: [Def] -> Infer () preRun [] = return () preRun (x : xs) = case x of @@ -484,80 +459,74 @@ inferPattern = \case -- | Unify two types producing a new substitution unify :: Type -> Type -> Infer Subst -unify t0 t1 = - let fvs = S.toList $ free t0 `S.union` free t1 - m = M.fromList $ zip fvs letters - in case (t0, t1) of - (TFun a b, TFun c d) -> do - s1 <- unify a c - s2 <- unify (apply s1 b) (apply s1 d) - return $ s2 `compose` s1 - (TVar (MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t - (t@(TData _ _), TVar (MkTVar b)) -> return $ M.singleton (coerce b) t - (TVar (MkTVar a), t) -> occurs (coerce a) t - (t, TVar (MkTVar b)) -> occurs (coerce b) t - -- Forall unification should change - (TAll _ t, b) -> unify t b - (a, TAll _ t) -> unify a t - (TLit a, TLit b) -> - if a == b - then return M.empty - else catchableErr $ - Aux.do - "Can not unify" - quote $ printTree (TLit a) - "with" - quote $ printTree (TLit b) - (TData name t, TData name' t') -> - if name == name' && length t == length t' - then do - xs <- zipWithM unify t t' - return $ foldr compose nullSubst xs - else catchableErr $ - Aux.do - "Type constructor:" - printTree name - quote $ printTree $ map (replace m) t - "does not match with:" - printTree name' - quote $ printTree $ map (replace m) t' - (TEVar a, TEVar b) -> - if a == b - then return M.empty - else catchableErr $ - Aux.do - "Can not unify" - quote $ printTree (TEVar a) - "with" - quote $ printTree (TEVar b) - (a, b) -> do - catchableErr $ - Aux.do - "Can not unify" - quote $ printTree $ replace m a - "with" - quote $ printTree $ replace m b +unify t0 t1 = case (t0, t1) of + (TFun a b, TFun c d) -> do + s1 <- unify a c + s2 <- unify (apply s1 b) (apply s1 d) + return $ s2 `compose` s1 + (TVar a, t@(TData _ _)) -> return $ singleton a t + (t@(TData _ _), TVar b) -> return $ singleton b t + (TVar a, t) -> occurs a t + (t, TVar b) -> occurs b t + -- Forall unification should change + (TAll _ t, b) -> unify t b + (a, TAll _ t) -> unify a t + (TLit a, TLit b) -> + if a == b + then return nullSubst + else catchableErr $ + Aux.do + "Can not unify" + quote $ printTree (TLit a) + "with" + quote $ printTree (TLit b) + (TData name t, TData name' t') -> + if name == name' && length t == length t' + then do + xs <- zipWithM unify t t' + return $ foldr compose nullSubst xs + else catchableErr $ + Aux.do + "Type constructor:" + printTree name + quote $ printTree t + "does not match with:" + printTree name' + quote $ printTree t' + (TEVar a, TEVar b) -> + if a == b + then return nullSubst + else catchableErr $ + Aux.do + "Can not unify" + quote $ printTree (TEVar a) + "with" + quote $ printTree (TEVar b) + (a, b) -> do + catchableErr $ + Aux.do + "Can not unify" + quote $ printTree a + "with" + quote $ printTree b {- | Check if a type is contained in another type. I.E. { a = a -> b } is an unsolvable constraint since there is no substitution where these are equal -} -occurs :: T.Ident -> Type -> Infer Subst -occurs i t@(TEVar _) = return (M.singleton i t) -occurs i t@(TVar _) = return (M.singleton i t) -occurs i t = - let fvs = S.toList $ free t - m = M.fromList $ zip fvs letters - in if S.member i (free t) - then - catchableErr - ( Aux.do - "Occurs check failed, can't unify" - quote $ printTree $ replace m (TVar $ MkTVar (coerce i)) - "with" - quote $ printTree $ replace m t - ) - else return $ M.singleton i t +occurs :: TVar -> Type -> Infer Subst +occurs i t@(TEVar _) = return (singleton i t) +occurs i t@(TVar _) = return (singleton i t) +occurs i t + | S.member i (free t) = + catchableErr + ( Aux.do + "Occurs check failed, can't unify" + quote $ printTree (TVar i) + "with" + quote $ printTree t + ) + | otherwise = return $ singleton i t {- | Generalize a type over all free variables in the substitution set Used for let bindings to allow expression that do not type check in @@ -568,9 +537,9 @@ occurs i t = generalize :: Map T.Ident Type -> Type -> Type generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t) where - go :: [T.Ident] -> Type -> Type + go :: [TVar] -> Type -> Type go [] t = t - go (x : xs) t = TAll (MkTVar (coerce x)) (go xs t) + go (x : xs) t = TAll x (go xs t) removeForalls :: Type -> Type removeForalls (TAll _ t) = removeForalls t removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2) @@ -581,9 +550,9 @@ with fresh ones. -} inst :: Type -> Infer Type inst = \case - TAll (MkTVar bound) t -> do + TAll bound t -> do fr <- fresh - let s = M.singleton (coerce bound) fr + let s = singleton bound fr apply s <$> inst t TFun t1 t2 -> TFun <$> inst t1 <*> inst t2 rest -> return rest @@ -598,7 +567,7 @@ freshen :: Type -> Infer Type freshen t = do let frees = S.toList (free t) xs <- mapM (const fresh) frees - let sub = M.fromList $ zip frees xs + let sub = Subst . M.fromList $ zip frees xs return $ apply sub t -- | Generate a new fresh variable @@ -633,10 +602,10 @@ fresh = do go tvars t1 t2 = do -- probably not necessary freshies <- mapM (const fresh) tvars - let sub = M.fromList $ zip [coerce x | (MkTVar x) <- tvars] freshies + let sub = Subst . M.fromList $ zip tvars freshies let t1' = apply sub t1 let t2' = apply sub t2 - let alph = execState (alpha t1' t2') mempty + let alph = Subst $ execState (alpha t1' t2') mempty return $ apply alph t1' == t2' -- Pre-condition: All TAlls are outermost @@ -646,11 +615,11 @@ fresh = do -- Alpha rename the first type's type variable to match second. -- Pre-condition: No TAll are checked - alpha :: Type -> Type -> State (Map T.Ident Type) () + alpha :: Type -> Type -> State (Map TVar Type) () alpha t1 t2 = case (t1, t2) of - (TVar (MkTVar (LIdent i)), t2) -> do + (TVar i, t2) -> do m <- get - put (M.insert (coerce i) t2 m) + put (M.insert i t2 m) (TFun t1 t2, TFun t3 t4) -> do alpha t1 t3 alpha t2 t4 @@ -664,16 +633,16 @@ class SubstType t where class FreeVars t where -- | Get all free variables from t - free :: t -> Set T.Ident + free :: t -> Set TVar instance FreeVars (T.Bind' Type) where free (T.Bind (_, t) _ _) = free t instance FreeVars Type where - free :: Type -> Set T.Ident - free (TVar (MkTVar a)) = S.singleton (coerce a) - free (TAll (MkTVar bound) t) = - S.singleton (coerce bound) `S.intersection` free t + free :: Type -> Set TVar + free (TVar a) = S.singleton a + free (TAll bound t) = + S.singleton bound `S.intersection` free t free (TLit _) = mempty free (TFun a b) = free a `S.union` free b free (TData _ a) = free a @@ -687,22 +656,26 @@ instance SubstType Type where apply sub t = do case t of TLit _ -> t - TVar (MkTVar a) -> case M.lookup (coerce a) sub of - Nothing -> TVar (MkTVar $ coerce a) + TVar a -> case find a sub of + Nothing -> TVar a Just t -> t - TAll (MkTVar i) t -> case M.lookup (coerce i) sub of - Nothing -> TAll (MkTVar i) (apply sub t) + TAll i t -> case find i sub of + Nothing -> TAll i (apply sub t) Just _ -> apply sub t TFun a b -> TFun (apply sub a) (apply sub b) TData name a -> TData name (apply sub a) - TEVar (MkTEVar a) -> case M.lookup (coerce a) sub of + TEVar (MkTEVar a) -> case find (MkTVar a) sub of Nothing -> TEVar (MkTEVar $ coerce a) Just t -> t instance FreeVars (Map T.Ident Type) where - free :: Map T.Ident Type -> Set T.Ident + free :: Map T.Ident Type -> Set TVar free = free . M.elems +instance SubstType (Map TVar Type) where + apply :: Subst -> Map TVar Type -> Map TVar Type + apply = M.map . apply + instance SubstType (Map T.Ident Type) where apply :: Subst -> Map T.Ident Type -> Map T.Ident Type apply = M.map . apply @@ -759,7 +732,7 @@ nullSubst = mempty -- | Compose two substitution sets compose :: Subst -> Subst -> Subst -compose m1 m2 = M.map (apply m1) m2 `M.union` m1 +compose m1@(Subst m1') (Subst m2) = Subst $ M.map (apply m1) m2 `M.union` m1' -- | Compose a list of substitution sets into one composeAll :: [Subst] -> Subst @@ -910,7 +883,21 @@ data Env = Env data Error = Error {msg :: String, catchable :: Bool} deriving (Show) -type Subst = Map T.Ident Type + +newtype Subst = Subst {unSubst :: Map TVar Type} + deriving (Eq, Ord, Show) + +singleton :: TVar -> Type -> Subst +singleton a b = Subst (M.singleton a b) + +find :: TVar -> Subst -> Maybe Type +find tvar (Subst s) = M.lookup tvar s + +instance Semigroup Subst where + (<>) = compose + +instance Monoid Subst where + mempty = Subst mempty newtype Warning = NonExhaustive String deriving (Show) diff --git a/tests/TestRenamer.hs b/tests/TestRenamer.hs index dc71d38..d56781e 100644 --- a/tests/TestRenamer.hs +++ b/tests/TestRenamer.hs @@ -76,14 +76,18 @@ rn_sig = rn_bind1 = specify "Rename simple bind" $ shouldSatisfyOk - "f x = (\\y. let y2 = y + 1 in y2) (x + 1)" + ( unlines + [ ".+ x y = x" + , "f x = (\\y. let y2 = y + 1 in y2) (x + 1)" + ] + ) rn_bind2 = specify "Rename bind with case" . shouldSatisfyOk $ D.do "data forall a. List a where" " Nil : List a " " Cons : a -> List a -> List a" - + ".+ x y = x" "length : forall a. List a -> Int" "length list = case list of" " Nil => 0" diff --git a/tests/TestTypeCheckerBidir.hs b/tests/TestTypeCheckerBidir.hs index 00d6472..44e4745 100644 --- a/tests/TestTypeCheckerBidir.hs +++ b/tests/TestTypeCheckerBidir.hs @@ -1,28 +1,28 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PatternSynonyms #-} {-# HLINT ignore "Use camelCase" #-} {-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} module TestTypeCheckerBidir (test, testTypeCheckerBidir) where -import Test.Hspec +import Test.Hspec -import AnnForall (annotateForall) -import Control.Monad ((<=<)) -import Desugar.Desugar (desugar) -import Grammar.Abs (Program) -import Grammar.ErrM (Err, pattern Bad, pattern Ok) -import Grammar.Layout (resolveLayout) -import Grammar.Par (myLexer, pProgram) -import Grammar.Print (printTree) -import Renamer.Renamer (rename) -import ReportForall (reportForall) -import TypeChecker.RemoveForall (removeForall) -import TypeChecker.ReportTEVar (reportTEVar) -import TypeChecker.TypeChecker (TypeChecker (Bi)) -import TypeChecker.TypeCheckerBidir (typecheck) -import qualified TypeChecker.TypeCheckerIr as T +import AnnForall (annotateForall) +import Control.Monad ((<=<)) +import Desugar.Desugar (desugar) +import Grammar.Abs (Program) +import Grammar.ErrM (Err, pattern Bad, pattern Ok) +import Grammar.Layout (resolveLayout) +import Grammar.Par (myLexer, pProgram) +import Grammar.Print (printTree) +import Renamer.Renamer (rename) +import ReportForall (reportForall) +import TypeChecker.RemoveForall (removeForall) +import TypeChecker.ReportTEVar (reportTEVar) +import TypeChecker.TypeChecker (TypeChecker (Bi)) +import TypeChecker.TypeCheckerBidir (typecheck) +import TypeChecker.TypeCheckerIr qualified as T test = hspec testTypeCheckerBidir @@ -54,13 +54,19 @@ tc_id = tc_double = specify "Addition inference" $ run - ["double x = x + x"] + [ ".+ : Int -> Int -> Int" + , ".+ x y = x" + , "double x = x + x" + ] `shouldSatisfy` ok tc_add_lam = specify "Addition lambda inference" $ run - ["four = (\\x. x + x) 2"] + [ ".+ : Int -> Int -> Int" + , ".+ x y = x" + , "four = (\\x. x + x) 2" + ] `shouldSatisfy` ok tc_const = @@ -88,6 +94,8 @@ tc_rank2 = run [ "const : a -> b -> a" , "const x y = x" + , ".+ : Int -> Int -> Int" + , ".+ x y = x" , "rank2 : a -> (forall c. c -> Int) -> b -> Int" , "rank2 x f y = f x + f y" , "main = rank2 3 (\\x. const 5 x : a -> Int) 'h'" @@ -195,18 +203,21 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do -- run (fs ++ correct1) `shouldSatisfy` ok specify "Second correct case expression accepted" $ run (fs ++ correct2) `shouldSatisfy` ok + where -- specify "Third correct case expression accepted" $ -- run (fs ++ correct3) `shouldSatisfy` ok -- specify "Forth correct case expression accepted" $ -- run (fs ++ correct4) `shouldSatisfy` ok - where + fs = [ "data List a where" , " Nil : List a" , " Cons : a -> List a -> List a" ] wrong1 = - [ "length : List c -> Int" + [ ".+ : Int -> Int -> Int" + , ".+ x y = x" + , "length : List c -> Int" , "length = \\list. case list of" , " Nil => 0" , " Cons 6 xs => 1 + length xs" @@ -254,10 +265,10 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do correct4 = [ "elems : List (List c) -> Int" , "elems = \\list. case list of" - --, " Nil => 0" - --, " Cons Nil Nil => 0" - --, " Cons Nil xs => elems xs" - , " Cons (Cons _ ys) xs => 1 + elems (Cons ys xs)" + , -- , " Nil => 0" + -- , " Cons Nil Nil => 0" + -- , " Cons Nil xs => elems xs" + " Cons (Cons _ ys) xs => 1 + elems (Cons ys xs)" ] tc_if = specify "Test if else case expression" $ do @@ -298,12 +309,19 @@ tc_infer_case = describe "Infer case expression" $ do tc_rec1 = specify "Infer simple recursive definition" $ - run ["test x = 1 + test (x + 1)"] `shouldSatisfy` ok + run + [ ".+ : Int -> Int -> Int" + , ".+ x y = x" + , "test x = 1 + test (x + 1)" + ] + `shouldSatisfy` ok tc_rec2 = specify "Infer recursive definition with pattern matching" $ run - [ "data Bool where" + [ ".+ : Int -> Int -> Int" + , ".+ x y = x" + , "data Bool where" , " False : Bool" , " True : Bool" , "test = \\x. case x of" @@ -329,5 +347,5 @@ runPrint = ["double x = x + x"] ok = \case - Ok _ -> True + Ok _ -> True Bad _ -> False diff --git a/tests/TestTypeCheckerHm.hs b/tests/TestTypeCheckerHm.hs index 9a14e76..5d59ca6 100644 --- a/tests/TestTypeCheckerHm.hs +++ b/tests/TestTypeCheckerHm.hs @@ -57,6 +57,8 @@ goods = , testSatisfy "A basic arithmetic function should be able to be inferred" ( D.do + ".+ : Int -> Int -> Int" + ".+ x y = x" "plusOne x = x + 1 ;" "main x = plusOne x ;" ) @@ -74,6 +76,8 @@ goods = , testSatisfy "length function on int list infers correct signature" ( D.do + ".+ : Int -> Int -> Int" + ".+ x y = x" "data List where " " Nil : List" " Cons : Int -> List -> List" @@ -114,6 +118,8 @@ bads = , testSatisfy "Using a concrete function (primitive type) on a skolem variable should not succeed" ( D.do + ".+ : Int -> Int -> Int" + ".+ x y = x" "plusOne : Int -> Int ;" "plusOne x = x + 1 ;" "f : a -> Int ;" @@ -131,6 +137,8 @@ bads = "Pattern matching on literal and _List should not succeed" ( D.do _List + ".+ : Int -> Int -> Int" + ".+ x y = x" "length : List c -> Int;" "length _List = case _List of {" " 0 => 0;"