diff --git a/fourmolu.yaml b/fourmolu.yaml index f15300e..cf7ab2f 100644 --- a/fourmolu.yaml +++ b/fourmolu.yaml @@ -3,12 +3,12 @@ function-arrows: trailing comma-style: leading import-export-style: diff-friendly indent-wheres: false -record-brace-space: false +record-brace-space: true newlines-between-decls: 1 haddock-style: multi-line haddock-style-module: let-style: auto in-style: right-align -respectful: true +respectful: false fixities: [] unicode: never diff --git a/language.cabal b/language.cabal index e3d40b9..36b63c7 100644 --- a/language.cabal +++ b/language.cabal @@ -17,7 +17,7 @@ extra-source-files: common warnings - ghc-options: -Wdefault + ghc-options: -W executable language import: warnings diff --git a/src/TypeChecker/TypeChecker.hs b/src/TypeChecker/TypeChecker.hs index 1584b4f..9b94f55 100644 --- a/src/TypeChecker/TypeChecker.hs +++ b/src/TypeChecker/TypeChecker.hs @@ -1,26 +1,30 @@ -{-# LANGUAGE LambdaCase, OverloadedStrings, OverloadedRecordDot #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedRecordDot #-} +{-# LANGUAGE OverloadedStrings #-} +{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} + +{-# HLINT ignore "Use camelCase" #-} module TypeChecker.TypeChecker where -import Control.Monad (when, void) -import Control.Monad.Except (ExceptT, throwError, runExceptT) +import Control.Monad (void) +import Control.Monad.Except (ExceptT, runExceptT, throwError) import Control.Monad.State (StateT) -import qualified Control.Monad.State as St +import Control.Monad.State qualified as St import Data.Functor.Identity (Identity, runIdentity) import Data.Map (Map) -import qualified Data.Map as M -import Grammar.ErrM (Err) -import Grammar.Print -import Data.List (findIndex) - -import Debug.Trace (trace) +import Data.Map qualified as M import TypeChecker.TypeCheckerIr -data Ctx = Ctx { vars :: Map Integer Type - , sigs :: Map Ident Type - , nextFresh :: Ident - } - deriving Show +data Ctx = Ctx + { vars :: Map Integer Type + , sigs :: Map Ident Type + , nextFresh :: Int + } + deriving (Show) + +-- Perhaps swap over to reader monad instead for vars and sigs. +type Infer = StateT Ctx (ExceptT Error Identity) {- @@ -28,18 +32,20 @@ The type checker will assume we first rename all variables to unique name, as to have to care about scoping. It significantly improves the quality of life of the programmer. +TODOs: + Add skolemization variables. i.e + { \x. 3 : forall a. a -> a } + should not type check + + Generalize. Not really sure what that means though + -} -type Infer = StateT Ctx (ExceptT Error Identity) - -initEnv :: Ctx -initEnv = Ctx mempty mempty "a" - run :: Infer a -> Either Error a -run = runIdentity . runExceptT . flip St.evalStateT initEnv +run = runIdentity . runExceptT . flip St.evalStateT (Ctx mempty mempty 0) typecheck :: RProgram -> Either Error TProgram -typecheck = run . inferPrg +typecheck = run . inferPrg inferPrg :: RProgram -> Infer TProgram inferPrg (RProgram xs) = do @@ -54,122 +60,83 @@ inferBind (RBind name e) = do inferExp :: RExp -> Infer (Type, TExp) inferExp = \case - RAnn expr typ -> do - (t,expr') <- inferExp expr - when (not (t == typ || isPoly t)) (throwError $ AnnotatedMismatch "inferExp, RAnn") - return (typ,expr') - - -- Name is only here for proper error messages - RBound num name -> - M.lookup num <$> St.gets vars >>= \case - Nothing -> throwError $ UnboundVar "RBound" - Just t -> return (t, TBound num name t) - + (t, expr') <- inferExp expr + void $ t =:= typ + return (typ, expr') + RBound num name -> do + t <- lookupVars num + return (t, TBound num name t) RFree name -> do - M.lookup name <$> St.gets sigs >>= \case - Nothing -> throwError $ UnboundVar "RFree" - Just t -> return (t, TFree name t) - - RConst (CInt i) -> return $ (TMono "Int", TConst (CInt i) (TMono "Int")) - - RConst (CStr str) -> return $ (TMono "Str", TConst (CStr str) (TMono "Str")) - - -- Should do proper unification using union-find. Some nice libs exist - RApp expr1 expr2 -> do - (typ1, expr1') <- inferExp expr1 - (typ2, expr2') <- inferExp expr2 - fvar <- fresh - case typ1 of - (TPoly (Ident x)) -> do - let newType = (TArrow (TPoly (Ident x)) (TPoly fvar)) - specifyType expr1 newType - typ1' <- apply newType typ1 - return $ (typ1', TApp expr1' expr2' typ1') - _ -> (\t -> (t, TApp expr1' expr2' t)) <$> apply typ1 typ2 - + t <- lookupSigs name + return (t, TFree name t) + RConst (CInt i) -> return (TMono "Int", TConst (CInt i) (TMono "Int")) + RConst (CStr str) -> return (TMono "Str", TConst (CStr str) (TMono "Str")) RAdd expr1 expr2 -> do - (typ1, expr1') <- inferExp expr1 - (typ2, expr2') <- inferExp expr2 - when (not $ (isInt typ1 || isPoly typ1) && (isInt typ2 || isPoly typ2)) (throwError $ TypeMismatch "inferExp, RAdd") - specifyType expr1 (TMono "Int") - specifyType expr2 (TMono "Int") - return (TMono "Int", TAdd expr1' expr2' (TMono "Int")) - + (typ1, expr1') <- check expr1 (TMono "Int") + (_, expr2') <- check expr2 (TMono "Int") + return (typ1, TAdd expr1' expr2' typ1) + RApp expr1 expr2 -> do + (fn_t, expr1') <- inferExp expr1 + (arg_t, expr2') <- inferExp expr2 + res <- fresh + -- TODO: Double check if this is correct behavior. + -- It might be the case that we should return res, rather than new_t + new_t <- fn_t =:= TArrow arg_t res + return (new_t, TApp expr1' expr2' new_t) RAbs num name expr -> do - insertVars num (TPoly "a") + arg <- fresh + insertVars num arg (typ, expr') <- inferExp expr - newTyp <- lookupVars num - return $ (TArrow newTyp typ, TAbs num name expr' typ) + return (TArrow arg typ, TAbs num name expr' typ) --- Aux -isInt :: Type -> Bool -isInt (TMono "Int") = True -isInt _ = False +check :: RExp -> Type -> Infer (Type, TExp) +check e t = do + (t', e') <- inferExp e + t'' <- t' =:= t + return (t'', e') -isArrow :: Type -> Bool -isArrow (TArrow _ _) = True -isArrow _ = False - -isPoly :: Type -> Bool -isPoly (TPoly _) = True -isPoly _ = False - -fresh :: Infer Ident +fresh :: Infer Type fresh = do - (Ident var) <- St.gets nextFresh - when (length var == 0) (throwError $ Default "fresh") - index <- case findIndex (== (head var)) alphabet of - Nothing -> throwError $ Default "fresh" - Just i -> return i - let nextIndex = (index + 1) `mod` 26 - let newVar = Ident $ [alphabet !! nextIndex] - St.modify (\st -> st { nextFresh = newVar }) - return newVar - where - alphabet = "abcdefghijklmnopqrstuvwxyz" :: [Char] + var <- St.gets nextFresh + St.modify (\st -> st {nextFresh = succ var}) + return (TPoly $ Ident (show var)) -unify :: Type -> Type -> Infer Type -unify = todo - --- | Specify the type of a bound variable --- Because in lambdas we have to assume a general type and update it -specifyType :: RExp -> Type -> Infer () -specifyType (RBound num name) typ = do - insertVars num typ - return () -specifyType _ _ = return () +-- | Unify two types. +(=:=) :: Type -> Type -> Infer Type +(=:=) (TPoly _) b = return b +(=:=) a (TPoly _) = return a +(=:=) (TMono a) (TMono b) | a == b = return (TMono a) +(=:=) (TArrow a b) (TArrow c d) = do + t1 <- a =:= c + t2 <- b =:= d + return $ TArrow t1 t2 +(=:=) a b = throwError (TypeMismatch $ unwords ["Can not unify type", show a, "with", show b]) +-- Unused currently lookupVars :: Integer -> Infer Type lookupVars i = do st <- St.gets vars case M.lookup i st of - Just t -> return t - Nothing -> throwError $ UnboundVar "lookupVars" + Just t -> return t + Nothing -> throwError $ UnboundVar "lookupVars" insertVars :: Integer -> Type -> Infer () insertVars i t = do st <- St.get - St.put ( st { vars = M.insert i t st.vars } ) + St.put (st {vars = M.insert i t st.vars}) lookupSigs :: Ident -> Infer Type lookupSigs i = do st <- St.gets sigs case M.lookup i st of - Just t -> return t - Nothing -> throwError $ UnboundVar "lookupSigs" + Just t -> return t + Nothing -> throwError $ UnboundVar "lookupSigs" insertSigs :: Ident -> Type -> Infer () insertSigs i t = do st <- St.get - St.put ( st { sigs = M.insert i t st.sigs } ) - --- Have to figure out the equivalence classes for types. --- Currently this does not support more than exact matches. -apply :: Type -> Type -> Infer Type -apply (TArrow t1 t2) t3 - | t1 == t3 = return t2 -apply t1 t2 = throwError $ TypeMismatch "apply" + St.put (st {sigs = M.insert i t st.sigs}) {-# WARNING todo "TODO IN CODE" #-} todo :: a @@ -183,21 +150,30 @@ data Error | UnboundVar String | AnnotatedMismatch String | Default String - deriving Show + deriving (Show) -- Tests -lambda = RAbs 0 "x" (RAdd (RBound 0 "x") (RBound 0 "x")) -lambda2 = RAbs 0 "x" (RAnn (RBound 0 "x") (TArrow (TMono "Int") (TMono "String"))) +-- (\x. x + 1) 1 +app_lambda :: RExp +app_lambda = app lambda one -fn_on_var = RAbs 0 (Ident "f") (RAbs 1 (Ident "x") (RApp (RBound 0 (Ident "f")) (RBound 1 (Ident "x")))) +lambda :: RExp +lambda = RAbs 0 "x" $ add bound one +add :: RExp -> RExp -> RExp +add = RAdd ---add x = \y. x+y; -add = RAbs 0 "x" (RAbs 1 "y" (RAdd (RBound 0 "x") (RBound 1 "y"))) --- main = (\z. z+z) ((add 4) 6); -main = RApp (RAbs 0 "z" (RAdd (RBound 0 "z") (RBound 0 "z"))) applyAdd -four = RConst (CInt 4) -six = RConst (CInt 6) -applyAdd = (RApp (RApp add four) six) -partialAdd = RApp add four +bound = RBound 0 "x" + +app :: RExp -> RExp -> RExp +app = RApp + +one :: RExp +one = RConst (CInt 1) + +fn_t = TArrow (TPoly (Ident "0")) (TMono (Ident "Int")) + +arr_t = TArrow (TMono "Int") (TPoly "1") + +f_x = RAbs 0 "f" (RAbs 1 "x" (RApp (RBound 0 "f") (RBound 1 "x"))) diff --git a/src/TypeChecker/Unification.hs b/src/TypeChecker/Unification.hs index 1842707..6c86a70 100644 --- a/src/TypeChecker/Unification.hs +++ b/src/TypeChecker/Unification.hs @@ -1,28 +1,31 @@ -{-# LANGUAGE DeriveAnyClass, PatternSynonyms, GADTs, LambdaCase, OverloadedStrings #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} module TypeChecker.Unification where -import Renamer.Renamer -import Renamer.RenamerIr (Const(..), RExp(..), RBind(..), RProgram(..), Ident(..)) -import qualified Renamer.RenamerIr as R - -import Control.Monad.Reader -import Control.Monad.State -import Control.Monad.Except -import Data.Functor.Identity -import Control.Arrow ((>>>)) -import Control.Unification hiding ((=:=), applyBindings) -import qualified Control.Unification as U -import Control.Unification.IntVar -import Data.Functor.Fixedpoint -import GHC.Generics (Generic1) -import Data.Foldable (fold) -import Data.Map (Map) -import qualified Data.Map as M -import Data.Maybe (fromMaybe, fromJust) -import Data.Set (Set, (\\)) -import qualified Data.Set as S -import Debug.Trace (trace) +import Control.Arrow ((>>>)) +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.State +import Control.Unification hiding (applyBindings, (=:=)) +import Control.Unification qualified as U +import Control.Unification.IntVar +import Data.Foldable (fold) +import Data.Functor.Fixedpoint +import Data.Functor.Identity +import Data.Map (Map) +import Data.Map qualified as M +import Data.Maybe (fromJust, fromMaybe) +import Data.Set (Set, (\\)) +import Data.Set qualified as S +import Debug.Trace (trace) +import GHC.Generics (Generic1) +import Renamer.Renamer +import Renamer.RenamerIr (Const (..), Ident (..), RBind (..), RExp (..), RProgram (..)) +import Renamer.RenamerIr qualified as R type Ctx = Map Ident UPolytype @@ -39,12 +42,13 @@ instance Show a => Show (TypeT a) where type Infer = StateT (Map Ident UPolytype) (ReaderT Ctx (ExceptT TypeError (IntBindingT TypeT Identity))) type Type = Fix TypeT + type UType = UTerm TypeT IntVar data Poly t = Forall [Ident] t - deriving (Eq, Show, Functor) + deriving (Eq, Show, Functor) -type Polytype = Poly Type +type Polytype = Poly Type type UPolytype = Poly UType @@ -67,23 +71,23 @@ pattern UTPoly :: Ident -> UType pattern UTPoly v = UTerm (TPolyT v) data TType = TTPoly Ident | TTMono Ident | TTArrow TType TType - deriving Show + deriving (Show) -data Program = Program [Bind] - deriving Show +newtype Program = Program [Bind] + deriving (Show) data Bind = Bind Ident Exp Polytype - deriving Show + deriving (Show) data Exp = EAnn Exp Polytype | EBound Ident Polytype - | EFree Ident Polytype + | EFree Ident Polytype | EConst Const Polytype | EApp Exp Exp Polytype | EAdd Exp Exp Polytype | EAbs Ident Exp Polytype - deriving Show + deriving (Show) data TExp = TAnn TExp UType @@ -93,11 +97,11 @@ data TExp | TApp TExp TExp UType | TAdd TExp TExp UType | TAbs Ident TExp UType - deriving Show + deriving (Show) ---------------------------------------------------------- typecheck :: RProgram -> Either TypeError Program -typecheck = run . inferProgram +typecheck = run . inferProgram inferProgram :: RProgram -> Infer Program inferProgram (RProgram binds) = do @@ -106,7 +110,7 @@ inferProgram (RProgram binds) = do inferBind :: RBind -> Infer Bind inferBind (RBind i e) = do - (t,e') <- infer e + (t, e') <- infer e e'' <- convert fromUType e' t' <- fromUType t insertSigs i (Forall [] t) @@ -114,20 +118,19 @@ inferBind (RBind i e) = do fromUType :: UType -> Infer Polytype fromUType = applyBindings >>> (>>= (generalize >>> fmap fromUPolytype)) - + convert :: (UType -> Infer Polytype) -> TExp -> Infer Exp convert f = \case (TAnn e t) -> do e' <- convert f e - t' <- (f t) - return $ EAnn e' t' + EAnn e' <$> f t (TFree i t) -> do t' <- f t return $ EFree i t' (TBound i t) -> do t' <- f t return $ EBound i t' - (TConst c t) -> do + (TConst c t) -> do t' <- f t return $ EConst c t' (TApp e1 e2 t) -> do @@ -135,42 +138,43 @@ convert f = \case e2' <- convert f e2 t' <- f t return $ EApp e1' e2' t' - (TAdd e1 e2 t) -> do + (TAdd e1 e2 t) -> do e1' <- convert f e1 e2' <- convert f e2 t' <- f t return $ EAdd e1' e2' t' - (TAbs i e t) -> do + (TAbs i e t) -> do e' <- convert f e t' <- f t return $ EAbs i e' t' run :: Infer a -> Either TypeError a -run = flip evalStateT mempty - >>> flip runReaderT mempty - >>> runExceptT - >>> evalIntBindingT - >>> runIdentity +run = + flip evalStateT mempty + >>> flip runReaderT mempty + >>> runExceptT + >>> evalIntBindingT + >>> runIdentity infer :: RExp -> Infer (UType, TExp) infer = \case - (RConst (CInt i)) -> return $ (UTMono "Int", TConst (CInt i) (UTMono "Int")) - (RConst (CStr str)) -> return $ (UTMono "String", TConst (CStr str) (UTMono "String")) + (RConst (CInt i)) -> return (UTMono "Int", TConst (CInt i) (UTMono "Int")) + (RConst (CStr str)) -> return (UTMono "String", TConst (CStr str) (UTMono "String")) (RAdd e1 e2) -> do - (t1,e1') <- infer e2 - (t2,e2') <- infer e1 - t1 =:= (UTMono "Int") - t2 =:= (UTMono "Int") - return $ (UTMono "Int", TAdd e1' e2' (UTMono "Int")) + (t1, e1') <- infer e2 + (t2, e2') <- infer e1 + t1 =:= UTMono "Int" + t2 =:= UTMono "Int" + return (UTMono "Int", TAdd e1' e2' (UTMono "Int")) (RAnn e t) -> do - (t',e') <- infer e + (t', e') <- infer e check e t' return (t', TAnn e' t') (RApp e1 e2) -> do - (f,e1') <- infer e1 - (arg,e2') <- infer e2 + (f, e1') <- infer e1 + (arg, e2') <- infer e2 res <- fresh - f =:= UTArrow f arg + f =:= UTArrow arg res return (res, TApp e1' e2' res) (RAbs _ i e) -> do arg <- fresh @@ -199,8 +203,8 @@ lookupSigsT :: Ident -> Infer UType lookupSigsT x@(Ident i) = do ctx <- ask case M.lookup x ctx of - Nothing -> trace (show ctx) (throwError $ "Sigs - Unbound variable: " <> i) - Just ut -> return $ fromPolytype ut + Nothing -> trace (show ctx) (throwError $ "Sigs - Unbound variable: " <> i) + Just ut -> return $ fromPolytype ut insertSigs :: MonadState (Map Ident UPolytype) m => Ident -> UPolytype -> m () insertSigs x ty = modify (M.insert x ty) @@ -218,21 +222,23 @@ withBinding x ty = local (M.insert x ty) deriving instance Ord IntVar class FreeVars a where - freeVars :: a -> Infer (Set (Either Ident IntVar)) + freeVars :: a -> Infer (Set (Either Ident IntVar)) instance FreeVars UType where - freeVars ut = do - fuvs <- fmap (S.fromList . map Right) . lift . lift . lift $ getFreeVars ut - let ftvs = ucata (const S.empty) - (\case {TMonoT x -> S.singleton (Left x); f -> fold f}) - ut - return $ fuvs `S.union` ftvs + freeVars ut = do + fuvs <- fmap (S.fromList . map Right) . lift . lift . lift $ getFreeVars ut + let ftvs = + ucata + (const S.empty) + (\case TMonoT x -> S.singleton (Left x); f -> fold f) + ut + return $ fuvs `S.union` ftvs instance FreeVars UPolytype where - freeVars (Forall xs ut) = (\\ (S.fromList (map Left xs))) <$> freeVars ut + freeVars (Forall xs ut) = (\\ (S.fromList (map Left xs))) <$> freeVars ut instance FreeVars Ctx where - freeVars = fmap S.unions . mapM freeVars . M.elems + freeVars = fmap S.unions . mapM freeVars . M.elems fresh :: Infer UType fresh = UVar <$> lift (lift (lift freeVar)) @@ -249,21 +255,22 @@ applyBindings = lift . lift . U.applyBindings instantiate :: UPolytype -> Infer UType instantiate (Forall xs uty) = do - xs' <- mapM (const fresh) xs - return $ substU (M.fromList (zip (map Left xs) xs')) uty + xs' <- mapM (const fresh) xs + return $ substU (M.fromList (zip (map Left xs) xs')) uty substU :: Map (Either Ident IntVar) UType -> UType -> UType -substU m = ucata - (\v -> fromMaybe (UVar v) (M.lookup (Right v) m)) - (\case - TPolyT v -> fromMaybe (UTPoly v) (M.lookup (Left v) m) - f -> UTerm f - ) +substU m = + ucata + (\v -> fromMaybe (UVar v) (M.lookup (Right v) m)) + ( \case + TPolyT v -> fromMaybe (UTPoly v) (M.lookup (Left v) m) + f -> UTerm f + ) skolemize :: UPolytype -> Infer UType skolemize (Forall xs uty) = do - xs' <- mapM (const fresh) xs - return $ substU (M.fromList (zip (map Left xs) (map toSkolem xs'))) uty + xs' <- mapM (const fresh) xs + return $ substU (M.fromList (zip (map Left xs) (map toSkolem xs'))) uty where toSkolem (UVar v) = UTPoly (mkVarName "s" v) @@ -272,13 +279,13 @@ mkVarName nm (IntVar v) = Ident $ nm ++ show (v + (maxBound :: Int) + 1) generalize :: UType -> Infer UPolytype generalize uty = do - uty' <- applyBindings uty - ctx <- ask - tmfvs <- freeVars uty' - ctxfvs <- freeVars ctx - let fvs = S.toList $ tmfvs \\ ctxfvs - xs = map (either id (mkVarName "a")) fvs - return $ Forall xs (substU (M.fromList (zip fvs (map UTPoly xs))) uty') + uty' <- applyBindings uty + ctx <- ask + tmfvs <- freeVars uty' + ctxfvs <- freeVars ctx + let fvs = S.toList $ tmfvs \\ ctxfvs + xs = map (either id (mkVarName "a")) fvs + return $ Forall xs (substU (M.fromList (zip fvs (map UTPoly xs))) uty') fromUPolytype :: UPolytype -> Polytype fromUPolytype = fmap (fromJust . freeze)