diff --git a/language.cabal b/language.cabal index 5653f08..7b21b60 100644 --- a/language.cabal +++ b/language.cabal @@ -36,6 +36,7 @@ executable language -- TypeChecker.TypeCheckerIr -- TypeChecker.Unification TypeChecker.HM + TypeChecker.AlgoW TypeChecker.HMIr Renamer.RenamerM -- Renamer.Renamer diff --git a/src/TypeChecker/AlgoW.hs b/src/TypeChecker/AlgoW.hs new file mode 100644 index 0000000..e630da2 --- /dev/null +++ b/src/TypeChecker/AlgoW.hs @@ -0,0 +1,123 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} + +module TypeChecker.AlgoW where + +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.State +import Data.Bifunctor (bimap, second) +import Data.Functor.Identity (Identity, runIdentity) +import Data.List (intersect) +import Data.Map (Map) +import qualified Data.Map as M +import Data.Maybe (fromMaybe) + +import Grammar.Abs +import qualified TypeChecker.HMIr as T + +data Poly = Forall [Ident] Type + deriving Show + +a = TPol "a" +b = TPol "b" +int = TMono "int" +arr = TArr + +data Ctx = Ctx { vars :: Map Ident Poly + , sigs :: Map Ident Poly } + +data Env = Env { counter :: Int + , substitutions :: Map Type Type + } + +type Subst = Map Type Type +type Error = String + +newtype Infer a = Infer { runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a } + deriving (Functor, Applicative, Monad, MonadState Env, MonadReader Ctx, MonadError Error) + +initCtx :: Ctx +initCtx = Ctx mempty mempty + +initEnv :: Env +initEnv = Env 0 mempty + +run :: Ctx -> Env -> Infer a -> Either Error a +run c e = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e . runInfer + +w :: Exp -> Infer Type +w = \case + EInt n -> return int + EId i -> (\(Forall _ t) -> t) <$> (lookupVar i >>= inst) + EAbs var e -> do + fr <- fresh + withBinding var (Forall [] (TPol fr)) $ do + t' <- w e + subst (Forall [] $ TArr (TPol fr) t') + EApp e0 e1 -> do + t0 <- substCtx (w e0) + t1 <- w e1 + undefined + +substCtx :: Infer Type -> Infer Type +substCtx m = do + vs <- asks (M.toList . vars) + ks <- traverse (subst . snd) vs + let x = map fst vs + local (\st -> st { vars = M.fromList $ zip x ks }) m + +subst :: Poly -> Infer Poly +subst (Forall xs t) = do + subs <- gets substitutions + case t of + TPol a -> case M.lookup (TPol a) subs of + Nothing -> return $ Forall xs t + Just t' -> return $ Forall (remove a xs) t' + TMono a -> case M.lookup (TMono a) subs of + Nothing -> return $ Forall xs t + Just t' -> return $ Forall (remove a xs) t' + TArr a b -> do + (Forall xs' a') <- subst (Forall xs a) + (Forall xs'' b') <- subst (Forall xs b) + return $ Forall (xs' `intersect` xs'') (TArr a' b') + + +remove :: Ord a => a -> [a] -> [a] +remove a = foldr (\x acc -> if x == a then acc else x : acc) [] + +inst :: Poly -> Infer Poly +inst (Forall xs t) = do + xs' <- mapM (const fresh) xs + let sub = zip xs xs' + let subst' t = case t of + TMono a -> return $ TMono a + TPol a -> case lookup a sub of + Nothing -> return $ TPol a + Just t -> return $ TPol t + TArr a b -> TArr <$> subst' a <*> subst' b + Forall [] <$> subst' t + +-- | Generate a new fresh variable and increment the state +fresh :: Infer Ident +fresh = do + n <- gets counter + modify (\st -> st { counter = n + 1 }) + return . Ident $ "t" ++ show n + +insertSub :: Type -> Type -> Infer () +insertSub t1 t2 = modify (\st -> st { substitutions = M.insert t1 t2 (substitutions st) }) + +withBinding :: Ident -> Poly -> Infer Poly -> Infer Type +withBinding i t m = (\(Forall _ t) -> t) <$> local (\re -> re { vars = M.insert i t (vars re) }) m + +lookupVar :: Ident -> Infer Poly +lookupVar i = do + m <- asks vars + case M.lookup i m of + Just t -> return t + Nothing -> throwError $ "Unbound variable: " ++ show i + +{-# WARNING todo "TODO IN CODE" #-} +todo :: a +todo = error "TODO in code" diff --git a/src/TypeChecker/HM.hs b/src/TypeChecker/HM.hs index 63072d1..7b33cbe 100644 --- a/src/TypeChecker/HM.hs +++ b/src/TypeChecker/HM.hs @@ -47,7 +47,14 @@ inferPrg (Program bs) = do inferBind :: Bind -> Infer T.Bind inferBind (Bind i t _ params rhs) = do (t',e') <- inferExp (makeLambda rhs (reverse params)) - when (t /= t') (throwError $ "Signature of function " ++ show i ++ " with type: " ++ show t ++ " does not match inferred type " ++ show t' ++ " of expression: " ++ show e') + when (t /= t') (throwError . unwords $ [ "Signature of function" + , show i + , "with type:" + , show t + , "does not match inferred type" + , show t' + , "of expression:" + , show e']) return $ T.Bind (t,i) [] e' makeLambda :: Exp -> [Ident] -> Exp @@ -126,7 +133,6 @@ fresh = do -- b = int -- thus when solving constraints it must be the case that -- a = int -> int --- addConstraint :: Type -> Type -> Infer () addConstraint t1 t2 = do modify (\st -> st { constr = M.insert t1 t2 (constr st) })