small fixed and added qualifiedDo

This commit is contained in:
sebastian 2023-03-27 21:16:48 +02:00
parent a38e96a83b
commit e1633ea147
2 changed files with 108 additions and 80 deletions

View file

@ -1,9 +1,15 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
module Auxiliary (module Auxiliary) where module Auxiliary (module Auxiliary) where
import Control.Monad.Error.Class (liftEither)
import Control.Monad.Except (MonadError) import Control.Monad.Error.Class (liftEither)
import Data.Either.Combinators (maybeToRight) import Control.Monad.Except (MonadError)
import TypeChecker.TypeCheckerIr (Type (TFun)) import Data.Either.Combinators (maybeToRight)
import TypeChecker.TypeCheckerIr (Type (TFun))
import Prelude hiding ((>>), (>>=))
(>>) a b = a ++ " " ++ b
(>>=) a f = f a
snoc :: a -> [a] -> [a] snoc :: a -> [a] -> [a]
snoc x xs = xs ++ [x] snoc x xs = xs ++ [x]
@ -15,9 +21,8 @@ mapAccumM :: Monad m => (s -> a -> m (s, b)) -> s -> [a] -> m (s, [b])
mapAccumM f = go mapAccumM f = go
where where
go acc = \case go acc = \case
[] -> pure (acc, []) [] -> pure (acc, [])
x:xs -> do x : xs -> do
(acc', x') <- f acc x (acc', x') <- f acc x
(acc'', xs') <- go acc' xs (acc'', xs') <- go acc' xs
pure (acc'', x':xs') pure (acc'', x' : xs')

View file

@ -1,10 +1,12 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QualifiedDo #-}
-- | A module for type checking and inference using algorithm W, Hindley-Milner -- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeCheckerHm where module TypeChecker.TypeCheckerHm where
import Auxiliary import Auxiliary (maybeToRightM)
import Auxiliary qualified as Aux
import Control.Monad.Except import Control.Monad.Except
import Control.Monad.Identity (Identity, runIdentity) import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader import Control.Monad.Reader
@ -28,14 +30,16 @@ import TypeChecker.TypeCheckerIr qualified as T
initCtx = Ctx mempty initCtx = Ctx mempty
initEnv = Env 0 'a' mempty mempty mempty initEnv = Env 0 'a' mempty mempty mempty
runPretty :: Exp -> Either Error String
runPretty = fmap (printTree . fst) . run . inferExp
run :: Infer a -> Either Error a run :: Infer a -> Either Error a
run = runC initEnv initCtx run = runC initEnv initCtx
runC :: Env -> Ctx -> Infer a -> Either Error a runC :: Env -> Ctx -> Infer a -> Either Error a
runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e runC e c =
runIdentity
. runExceptT
. flip runReaderT c
. flip evalStateT e
. runInfer
typecheck :: Program -> Either Error (T.Program' Type) typecheck :: Program -> Either Error (T.Program' Type)
typecheck = run . checkPrg typecheck = run . checkPrg
@ -49,15 +53,15 @@ checkData d = do
(throwError $ unwords ["Data type incorrectly declared"]) (throwError $ unwords ["Data type incorrectly declared"])
traverse_ traverse_
( \(Inj name' t') -> ( \(Inj name' t') ->
if typ == retType t' if typ == returnType t'
then insertConstr (coerce name') (t') then insertConstr (coerce name') t'
else else
throwError $ throwError $
unwords unwords
[ "return type of constructor:" [ "return type of constructor:"
, printTree name' , printTree name'
, "with type:" , "with type:"
, printTree (retType t') , printTree (returnType t')
, "does not match data: " , "does not match data: "
, printTree typ , printTree typ
] ]
@ -69,9 +73,9 @@ checkData d = do
<> printTree d <> printTree d
<> "'" <> "'"
retType :: Type -> Type returnType :: Type -> Type
retType (TFun _ t2) = retType t2 returnType (TFun _ t2) = returnType t2
retType a = a returnType a = a
checkPrg :: Program -> Infer (T.Program' Type) checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = do checkPrg (Program bs) = do
@ -92,7 +96,7 @@ preRun (x : xs) = case x of
<> printTree n <> printTree n
<> "'" <> "'"
) )
insertSig (coerce n) (Just $ t) >> preRun xs insertSig (coerce n) (Just t) >> preRun xs
DBind (Bind n _ e) -> do DBind (Bind n _ e) -> do
collect (collectTypeVars e) collect (collectTypeVars e)
s <- gets sigs s <- gets sigs
@ -107,10 +111,11 @@ checkDef (x : xs) = case x of
(DBind b) -> do (DBind b) -> do
b' <- checkBind b b' <- checkBind b
fmap (T.DBind b' :) (checkDef xs) fmap (T.DBind b' :) (checkDef xs)
(DData d) -> fmap ((T.DData (coerceData d)) :) (checkDef xs) (DData d) -> fmap (T.DData (coerceData d) :) (checkDef xs)
(DSig _) -> checkDef xs (DSig _) -> checkDef xs
where where
coerceData (Data t injs) = T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs coerceData (Data t injs) =
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
checkBind :: Bind -> Infer (T.Bind' Type) checkBind :: Bind -> Infer (T.Bind' Type)
checkBind (Bind name args e) = do checkBind (Bind name args e) = do
@ -145,11 +150,11 @@ typeEq t1 (TAll _ t2) = t1 `typeEq` t2
typeEq (TVar _) (TVar _) = True typeEq (TVar _) (TVar _) = True
typeEq _ _ = False typeEq _ _ = False
skolem :: Type -> Type skolemize :: Type -> Type
skolem (TVar (T.MkTVar a)) = TLit (coerce a) skolemize (TVar (MkTVar a)) = TEVar (MkTEVar $ coerce a)
skolem (TAll x t) = TAll x (skolem t) skolemize (TAll x t) = TAll x (skolemize t)
skolem (TFun t1 t2) = (TFun `on` skolem) t1 t2 skolemize (TFun t1 t2) = (TFun `on` skolemize) t1 t2
skolem t = t skolemize t = t
isMoreSpecificOrEq :: Type -> Type -> Bool isMoreSpecificOrEq :: Type -> Type -> Bool
isMoreSpecificOrEq t1 (TAll _ t2) = isMoreSpecificOrEq t1 t2 isMoreSpecificOrEq t1 (TAll _ t2) = isMoreSpecificOrEq t1 t2
@ -204,10 +209,9 @@ algoW = \case
, printTree t' , printTree t'
] ]
) )
applySt s1 $ do s2 <- exprErr (unify (t) t') err
s2 <- exprErr (unify (t) t') err let comp = s2 `compose` s1
let comp = s2 `compose` s1 return (comp, apply comp (e', t))
return (comp, apply comp (e', t))
-- \| ------------------ -- \| ------------------
-- \| Γ ⊢ i : Int, ∅ -- \| Γ ⊢ i : Int, ∅
@ -262,16 +266,14 @@ algoW = \case
err@(EAdd e0 e1) -> do err@(EAdd e0 e1) -> do
(s1, (e0', t0)) <- algoW e0 (s1, (e0', t0)) <- algoW e0
applySt s1 $ do (s2, (e1', t1)) <- algoW e1
(s2, (e1', t1)) <- algoW e1 s3 <- exprErr (unify (apply s2 t0) int) err
-- applySt s2 $ do s4 <- exprErr (unify (apply s3 t1) int) err
s3 <- exprErr (unify (apply s2 t0) int) err let comp = s4 `compose` s3 `compose` s2 `compose` s1
s4 <- exprErr (unify (apply s3 t1) int) err return
let comp = s4 `compose` s3 `compose` s2 `compose` s1 ( comp
return , apply comp (T.EAdd (e0', t0) (e1', t1), int)
( comp )
, apply comp (T.EAdd (e0', t0) (e1', t1), int)
)
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1 -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1
-- \| τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ') -- \| τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ')
@ -281,12 +283,11 @@ algoW = \case
err@(EApp e0 e1) -> do err@(EApp e0 e1) -> do
fr <- fresh fr <- fresh
(s0, (e0', t0)) <- algoW e0 (s0, (e0', t0)) <- algoW e0
applySt s0 $ do (s1, (e1', t1)) <- algoW e1
(s1, (e1', t1)) <- algoW e1 s2 <- exprErr (unify (apply s1 t0) (TFun t1 fr)) err
s2 <- exprErr (unify (apply s1 t0) (TFun t1 fr)) err let t = apply s2 fr
let t = apply s2 fr let comp = s2 `compose` s1 `compose` s0
let comp = s2 `compose` s1 `compose` s0 return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t))
return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t))
-- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁ -- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁
-- \| ---------------------------------------------- -- \| ----------------------------------------------
@ -346,22 +347,45 @@ unify t0 t1 = do
then do then do
xs <- zipWithM unify t t' xs <- zipWithM unify t t'
return $ foldr compose nullSubst xs return $ foldr compose nullSubst xs
else throwError $
Aux.do
"Type constructor:"
printTree name
"("
printTree t
")"
"does not match with:"
printTree name'
"("
printTree t'
")"
-- [ "Type constructor:"
-- , printTree name
-- , "(" <> printTree t <> ")"
-- , "does not match with:"
-- , printTree name'
-- , "(" <> printTree t' <> ")"
-- ]
(TEVar a, TEVar b) ->
if a == b
then return M.empty
else else
throwError $ throwError
unwords . unwords
[ "Type constructor:" $ [ "Can not unify"
, printTree name , "'" <> printTree (TEVar a) <> "'"
, "(" <> printTree t <> ")" , "with"
, "does not match with:" , "'" <> printTree (TEVar b) <> "'"
, printTree name' ]
, "(" <> printTree t' <> ")"
]
(a, b) -> do (a, b) -> do
throwError . unwords $ throwError
[ "'" <> printTree a <> "'" . unwords
, "can't be unified with" $ [ "Can not unify"
, "'" <> printTree b <> "'" , "'" <> printTree a <> "'"
] , "with"
, "'" <> printTree b <> "'"
]
{- | Check if a type is contained in another type. {- | Check if a type is contained in another type.
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
@ -415,7 +439,7 @@ composeAll = foldl' compose nullSubst
-- TODO: Split this class into two separate classes, one for free variables -- TODO: Split this class into two separate classes, one for free variables
-- and one for applying substitutions -- and one for applying substitutions
-- | A class representing free variables functions -- | A class for substitutions
class SubstType t where class SubstType t where
-- | Apply a substitution to t -- | Apply a substitution to t
apply :: Subst -> t -> t apply :: Subst -> t -> t
@ -430,9 +454,10 @@ instance FreeVars Type where
free (TAll (T.MkTVar bound) t) = S.singleton (coerce bound) `S.intersection` free t free (TAll (T.MkTVar bound) t) = S.singleton (coerce bound) `S.intersection` free t
free (TLit _) = mempty free (TLit _) = mempty
free (TFun a b) = free a `S.union` free b free (TFun a b) = free a `S.union` free b
-- \| Not guaranteed to be correct free (TData _ a) = free a
free (TData _ a) =
foldl' (\acc x -> free x `S.union` acc) S.empty a instance FreeVars a => FreeVars [a] where
free = let f acc x = acc `S.union` free x in foldl' f S.empty
instance SubstType Type where instance SubstType Type where
apply :: Subst -> Type -> Type apply :: Subst -> Type -> Type
@ -447,13 +472,14 @@ instance SubstType Type where
Just _ -> apply sub t Just _ -> apply sub t
TFun a b -> TFun (apply sub a) (apply sub b) TFun a b -> TFun (apply sub a) (apply sub b)
TData name a -> TData name (map (apply sub) a) TData name a -> TData name (map (apply sub) a)
instance FreeVars (Map T.Ident Type) where instance FreeVars (Map T.Ident Type) where
free :: Map T.Ident Type -> Set T.Ident free :: Map T.Ident Type -> Set T.Ident
free m = foldl' S.union S.empty (map free $ M.elems m) free = free . M.elems
instance SubstType (Map T.Ident Type) where instance SubstType (Map T.Ident Type) where
apply :: Subst -> Map T.Ident Type -> Map T.Ident Type apply :: Subst -> Map T.Ident Type -> Map T.Ident Type
apply s = M.map (apply s) apply = M.map . apply
instance SubstType (T.Exp' Type) where instance SubstType (T.Exp' Type) where
apply s = \case apply s = \case
@ -467,7 +493,7 @@ instance SubstType (T.Exp' Type) where
T.EAdd e1 e2 -> T.EAdd (apply s e1) (apply s e2) T.EAdd e1 e2 -> T.EAdd (apply s e1) (apply s e2)
T.EAbs ident e -> T.EAbs ident (apply s e) T.EAbs ident e -> T.EAbs ident (apply s e)
T.ECase e brnch -> T.ECase (apply s e) (apply s brnch) T.ECase e brnch -> T.ECase (apply s e) (apply s brnch)
T.EInj{} -> error "implement" T.EInj i -> T.EInj i
instance SubstType (T.Branch' Type) where instance SubstType (T.Branch' Type) where
apply s (T.Branch (i, t) e) = T.Branch (apply s i, apply s t) (apply s e) apply s (T.Branch (i, t) e) = T.Branch (apply s i, apply s t) (apply s e)
@ -489,10 +515,6 @@ instance (SubstType a, SubstType b) => SubstType (a, b) where
instance SubstType (T.Id' Type) where instance SubstType (T.Id' Type) where
apply s (name, t) = (name, apply s t) apply s (name, t) = (name, apply s t)
-- | Apply substitutions to the environment.
applySt :: Subst -> Infer a -> Infer a
applySt s = local (\st -> st{vars = apply s (vars st)})
-- | Represents the empty substition set -- | Represents the empty substition set
nullSubst :: Subst nullSubst :: Subst
nullSubst = M.empty nullSubst = M.empty
@ -513,11 +535,11 @@ fresh = do
else else
if n == 0 if n == 0
then return . TVar . T.MkTVar $ LIdent [c] then return . TVar . T.MkTVar $ LIdent [c]
else return . TVar . T.MkTVar . LIdent $ [c] ++ show n else return . TVar . T.MkTVar . LIdent $ c : show n
where
next :: Char -> Char next :: Char -> Char
next 'z' = 'a' next 'z' = 'a'
next a = succ a next a = succ a
-- | Run the monadic action with an additional binding -- | Run the monadic action with an additional binding
withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> Type -> m a -> m a withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> Type -> m a -> m a
@ -673,4 +695,5 @@ data Env = Env
type Error = String type Error = String
type Subst = Map T.Ident Type type Subst = Map T.Ident Type
type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity)) newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a}
deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env)