From eafe0fea0b40fc5adc6b6cee3e5f243c5f2a6490 Mon Sep 17 00:00:00 2001 From: sebastianselander Date: Thu, 16 Feb 2023 16:37:36 +0100 Subject: [PATCH] Rewrote using unification-fd. Heavily inspired (aka copied) from: https://byorgey.wordpress.com/2021/09/08/implementing-hindley-milner-with-the-unification-fd-library/ --- language.cabal | 3 +- src/Main.hs | 5 +- src/TypeChecker/TypeChecker.hs | 39 +++-- src/TypeChecker/Unification.hs | 284 +++++++++++++++++++++++++++++++++ test_program | 4 +- 5 files changed, 314 insertions(+), 21 deletions(-) create mode 100644 src/TypeChecker/Unification.hs diff --git a/language.cabal b/language.cabal index 5668b83..e3d40b9 100644 --- a/language.cabal +++ b/language.cabal @@ -33,6 +33,7 @@ executable language Grammar.ErrM TypeChecker.TypeChecker TypeChecker.TypeCheckerIr + TypeChecker.Unification Renamer.Renamer Renamer.RenamerIr @@ -45,6 +46,6 @@ executable language , either , extra , array - , equivalence + , unification-fd default-language: GHC2021 diff --git a/src/Main.hs b/src/Main.hs index 3679582..0845f8c 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -5,7 +5,8 @@ import Grammar.Par (myLexer, pProgram) import Grammar.Print (printTree) import System.Environment (getArgs) import System.Exit (exitFailure, exitSuccess) -import TypeChecker.TypeChecker (typecheck) +-- import TypeChecker.TypeChecker (typecheck) +import TypeChecker.Unification (typecheck) import Renamer.Renamer (rename) import Grammar.Print (prt) @@ -43,4 +44,4 @@ main = getArgs >>= \case putStrLn "" putStrLn " ----- TYPECHECKER ----- " putStrLn "" - putStrLn . printTree $ prg + putStrLn . show $ prg diff --git a/src/TypeChecker/TypeChecker.hs b/src/TypeChecker/TypeChecker.hs index f663ec4..1584b4f 100644 --- a/src/TypeChecker/TypeChecker.hs +++ b/src/TypeChecker/TypeChecker.hs @@ -11,13 +11,14 @@ import Data.Map (Map) import qualified Data.Map as M import Grammar.ErrM (Err) import Grammar.Print +import Data.List (findIndex) import Debug.Trace (trace) import TypeChecker.TypeCheckerIr data Ctx = Ctx { vars :: Map Integer Type , sigs :: Map Ident Type - , count :: Int + , nextFresh :: Ident } deriving Show @@ -32,7 +33,7 @@ programmer. type Infer = StateT Ctx (ExceptT Error Identity) initEnv :: Ctx -initEnv = Ctx mempty mempty 0 +initEnv = Ctx mempty mempty "a" run :: Infer a -> Either Error a run = runIdentity . runExceptT . flip St.evalStateT initEnv @@ -51,7 +52,6 @@ inferBind (RBind name e) = do insertSigs name t return $ TBind name t e' - inferExp :: RExp -> Infer (Type, TExp) inferExp = \case @@ -79,14 +79,14 @@ inferExp = \case RApp expr1 expr2 -> do (typ1, expr1') <- inferExp expr1 (typ2, expr2') <- inferExp expr2 - cnt <- incCount + fvar <- fresh case typ1 of (TPoly (Ident x)) -> do - let newType = (TArrow (TPoly (Ident x)) (TPoly . Ident $ x ++ (show cnt))) + let newType = (TArrow (TPoly (Ident x)) (TPoly fvar)) specifyType expr1 newType typ1' <- apply newType typ1 return $ (typ1', TApp expr1' expr2' typ1') - _ -> (\t -> (t, TApp expr1' expr2' t)) <$> apply typ2 typ1 + _ -> (\t -> (t, TApp expr1' expr2' t)) <$> apply typ1 typ2 RAdd expr1 expr2 -> do (typ1, expr1') <- inferExp expr1 @@ -115,11 +115,22 @@ isPoly :: Type -> Bool isPoly (TPoly _) = True isPoly _ = False -incCount :: Infer Int -incCount = do - st <- St.get - St.put ( st { count = succ st.count } ) - return st.count +fresh :: Infer Ident +fresh = do + (Ident var) <- St.gets nextFresh + when (length var == 0) (throwError $ Default "fresh") + index <- case findIndex (== (head var)) alphabet of + Nothing -> throwError $ Default "fresh" + Just i -> return i + let nextIndex = (index + 1) `mod` 26 + let newVar = Ident $ [alphabet !! nextIndex] + St.modify (\st -> st { nextFresh = newVar }) + return newVar + where + alphabet = "abcdefghijklmnopqrstuvwxyz" :: [Char] + +unify :: Type -> Type -> Infer Type +unify = todo -- | Specify the type of a bound variable -- Because in lambdas we have to assume a general type and update it @@ -153,12 +164,6 @@ insertSigs i t = do st <- St.get St.put ( st { sigs = M.insert i t st.sigs } ) -union :: Type -> Type -> Infer () -union = todo - -find :: Type -> Type -find = todo - -- Have to figure out the equivalence classes for types. -- Currently this does not support more than exact matches. apply :: Type -> Type -> Infer Type diff --git a/src/TypeChecker/Unification.hs b/src/TypeChecker/Unification.hs new file mode 100644 index 0000000..1842707 --- /dev/null +++ b/src/TypeChecker/Unification.hs @@ -0,0 +1,284 @@ +{-# LANGUAGE DeriveAnyClass, PatternSynonyms, GADTs, LambdaCase, OverloadedStrings #-} + +module TypeChecker.Unification where + +import Renamer.Renamer +import Renamer.RenamerIr (Const(..), RExp(..), RBind(..), RProgram(..), Ident(..)) +import qualified Renamer.RenamerIr as R + +import Control.Monad.Reader +import Control.Monad.State +import Control.Monad.Except +import Data.Functor.Identity +import Control.Arrow ((>>>)) +import Control.Unification hiding ((=:=), applyBindings) +import qualified Control.Unification as U +import Control.Unification.IntVar +import Data.Functor.Fixedpoint +import GHC.Generics (Generic1) +import Data.Foldable (fold) +import Data.Map (Map) +import qualified Data.Map as M +import Data.Maybe (fromMaybe, fromJust) +import Data.Set (Set, (\\)) +import qualified Data.Set as S +import Debug.Trace (trace) + +type Ctx = Map Ident UPolytype + +type TypeError = String + +data TypeT a = TPolyT Ident | TMonoT Ident | TArrowT a a + deriving (Functor, Foldable, Traversable, Generic1, Unifiable) + +instance Show a => Show (TypeT a) where + show (TPolyT (Ident i)) = i + show (TMonoT (Ident i)) = i + show (TArrowT a b) = show a ++ " -> " ++ show b + +type Infer = StateT (Map Ident UPolytype) (ReaderT Ctx (ExceptT TypeError (IntBindingT TypeT Identity))) + +type Type = Fix TypeT +type UType = UTerm TypeT IntVar + +data Poly t = Forall [Ident] t + deriving (Eq, Show, Functor) + +type Polytype = Poly Type + +type UPolytype = Poly UType + +pattern TPoly :: Ident -> Type +pattern TPoly v = Fix (TPolyT v) + +pattern TMono :: Ident -> Type +pattern TMono v = Fix (TMonoT v) + +pattern TArrow :: Type -> Type -> Type +pattern TArrow t1 t2 = Fix (TArrowT t1 t2) + +pattern UTMono :: Ident -> UType +pattern UTMono v = UTerm (TMonoT v) + +pattern UTArrow :: UType -> UType -> UType +pattern UTArrow t1 t2 = UTerm (TArrowT t1 t2) + +pattern UTPoly :: Ident -> UType +pattern UTPoly v = UTerm (TPolyT v) + +data TType = TTPoly Ident | TTMono Ident | TTArrow TType TType + deriving Show + +data Program = Program [Bind] + deriving Show + +data Bind = Bind Ident Exp Polytype + deriving Show + +data Exp + = EAnn Exp Polytype + | EBound Ident Polytype + | EFree Ident Polytype + | EConst Const Polytype + | EApp Exp Exp Polytype + | EAdd Exp Exp Polytype + | EAbs Ident Exp Polytype + deriving Show + +data TExp + = TAnn TExp UType + | TFree Ident UType + | TBound Ident UType + | TConst Const UType + | TApp TExp TExp UType + | TAdd TExp TExp UType + | TAbs Ident TExp UType + deriving Show + +---------------------------------------------------------- +typecheck :: RProgram -> Either TypeError Program +typecheck = run . inferProgram + +inferProgram :: RProgram -> Infer Program +inferProgram (RProgram binds) = do + binds' <- mapM inferBind binds + return $ Program binds' + +inferBind :: RBind -> Infer Bind +inferBind (RBind i e) = do + (t,e') <- infer e + e'' <- convert fromUType e' + t' <- fromUType t + insertSigs i (Forall [] t) + return $ Bind i e'' t' + +fromUType :: UType -> Infer Polytype +fromUType = applyBindings >>> (>>= (generalize >>> fmap fromUPolytype)) + +convert :: (UType -> Infer Polytype) -> TExp -> Infer Exp +convert f = \case + (TAnn e t) -> do + e' <- convert f e + t' <- (f t) + return $ EAnn e' t' + (TFree i t) -> do + t' <- f t + return $ EFree i t' + (TBound i t) -> do + t' <- f t + return $ EBound i t' + (TConst c t) -> do + t' <- f t + return $ EConst c t' + (TApp e1 e2 t) -> do + e1' <- convert f e1 + e2' <- convert f e2 + t' <- f t + return $ EApp e1' e2' t' + (TAdd e1 e2 t) -> do + e1' <- convert f e1 + e2' <- convert f e2 + t' <- f t + return $ EAdd e1' e2' t' + (TAbs i e t) -> do + e' <- convert f e + t' <- f t + return $ EAbs i e' t' + +run :: Infer a -> Either TypeError a +run = flip evalStateT mempty + >>> flip runReaderT mempty + >>> runExceptT + >>> evalIntBindingT + >>> runIdentity + +infer :: RExp -> Infer (UType, TExp) +infer = \case + (RConst (CInt i)) -> return $ (UTMono "Int", TConst (CInt i) (UTMono "Int")) + (RConst (CStr str)) -> return $ (UTMono "String", TConst (CStr str) (UTMono "String")) + (RAdd e1 e2) -> do + (t1,e1') <- infer e2 + (t2,e2') <- infer e1 + t1 =:= (UTMono "Int") + t2 =:= (UTMono "Int") + return $ (UTMono "Int", TAdd e1' e2' (UTMono "Int")) + (RAnn e t) -> do + (t',e') <- infer e + check e t' + return (t', TAnn e' t') + (RApp e1 e2) -> do + (f,e1') <- infer e1 + (arg,e2') <- infer e2 + res <- fresh + f =:= UTArrow f arg + return (res, TApp e1' e2' res) + (RAbs _ i e) -> do + arg <- fresh + withBinding i (Forall [] arg) $ do + (res, e') <- infer e + return $ (UTArrow arg res, TAbs i e' (UTArrow arg res)) + (RFree i) -> do + t <- lookupSigsT i + return (t, TFree i t) + (RBound _ i) -> do + t <- lookupVarT i + return (t, TBound i t) + +check :: RExp -> UType -> Infer () +check expr t = do + (t', _) <- infer expr + t =:= t' + return () + +lookupVarT :: Ident -> Infer UType +lookupVarT x@(Ident i) = do + ctx <- ask + maybe (throwError $ "Var - Unbound variable: " <> i) instantiate (M.lookup x ctx) + +lookupSigsT :: Ident -> Infer UType +lookupSigsT x@(Ident i) = do + ctx <- ask + case M.lookup x ctx of + Nothing -> trace (show ctx) (throwError $ "Sigs - Unbound variable: " <> i) + Just ut -> return $ fromPolytype ut + +insertSigs :: MonadState (Map Ident UPolytype) m => Ident -> UPolytype -> m () +insertSigs x ty = modify (M.insert x ty) + +fromPolytype :: UPolytype -> UType +fromPolytype (Forall ids ut) = ut + +ucata :: Functor t => (v -> a) -> (t a -> a) -> UTerm t v -> a +ucata f _ (UVar v) = f v +ucata f g (UTerm t) = g (fmap (ucata f g) t) + +withBinding :: MonadReader Ctx m => Ident -> UPolytype -> m a -> m a +withBinding x ty = local (M.insert x ty) + +deriving instance Ord IntVar + +class FreeVars a where + freeVars :: a -> Infer (Set (Either Ident IntVar)) + +instance FreeVars UType where + freeVars ut = do + fuvs <- fmap (S.fromList . map Right) . lift . lift . lift $ getFreeVars ut + let ftvs = ucata (const S.empty) + (\case {TMonoT x -> S.singleton (Left x); f -> fold f}) + ut + return $ fuvs `S.union` ftvs + +instance FreeVars UPolytype where + freeVars (Forall xs ut) = (\\ (S.fromList (map Left xs))) <$> freeVars ut + +instance FreeVars Ctx where + freeVars = fmap S.unions . mapM freeVars . M.elems + +fresh :: Infer UType +fresh = UVar <$> lift (lift (lift freeVar)) + +instance Fallible TypeT IntVar TypeError where + occursFailure iv ut = "Infinite" + mismatchFailure iv ut = "Mismatch" + +(=:=) :: UType -> UType -> Infer UType +(=:=) s t = lift . lift $ s U.=:= t + +applyBindings :: UType -> Infer UType +applyBindings = lift . lift . U.applyBindings + +instantiate :: UPolytype -> Infer UType +instantiate (Forall xs uty) = do + xs' <- mapM (const fresh) xs + return $ substU (M.fromList (zip (map Left xs) xs')) uty + +substU :: Map (Either Ident IntVar) UType -> UType -> UType +substU m = ucata + (\v -> fromMaybe (UVar v) (M.lookup (Right v) m)) + (\case + TPolyT v -> fromMaybe (UTPoly v) (M.lookup (Left v) m) + f -> UTerm f + ) + +skolemize :: UPolytype -> Infer UType +skolemize (Forall xs uty) = do + xs' <- mapM (const fresh) xs + return $ substU (M.fromList (zip (map Left xs) (map toSkolem xs'))) uty + where + toSkolem (UVar v) = UTPoly (mkVarName "s" v) + +mkVarName :: String -> IntVar -> Ident +mkVarName nm (IntVar v) = Ident $ nm ++ show (v + (maxBound :: Int) + 1) + +generalize :: UType -> Infer UPolytype +generalize uty = do + uty' <- applyBindings uty + ctx <- ask + tmfvs <- freeVars uty' + ctxfvs <- freeVars ctx + let fvs = S.toList $ tmfvs \\ ctxfvs + xs = map (either id (mkVarName "a")) fvs + return $ Forall xs (substU (M.fromList (zip fvs (map UTPoly xs))) uty') + +fromUPolytype :: UPolytype -> Polytype +fromUPolytype = fmap (fromJust . freeze) diff --git a/test_program b/test_program index db9a44e..fdb3de4 100644 --- a/test_program +++ b/test_program @@ -1 +1,3 @@ -test f x = f x +apply w x = \y. \z. w + x + y + z ; + +main = apply 1 2 3 4 ;