churf/src/TypeChecker/TypeChecker.hs

636 lines
22 KiB
Haskell
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
-- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeChecker where
import Auxiliary
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (second)
import Data.Coerce (coerce)
import Data.Foldable (traverse_)
import Data.Functor.Identity (runIdentity)
import Data.List (foldl')
import Data.List.Extra (unsnoc)
import Data.Map (Map)
import Data.Map qualified as M
import Data.Set (Set)
import Data.Set qualified as S
import Debug.Trace (trace)
import Grammar.Abs
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr (
Ctx (..),
Env (..),
Error,
Infer,
Subst,
)
import TypeChecker.TypeCheckerIr qualified as T
-- | The initial typing context: no variables are in scope.
initCtx :: Ctx
initCtx = Ctx mempty
-- | The initial inference state: a numeric field set to 0 and two empty
-- maps. NOTE(review): presumably the fresh-variable counter followed by the
-- signature and constructor maps — confirm against the 'Env' definition.
initEnv :: Env
initEnv = Env 0 mempty mempty
-- | Infer a single expression in the empty environment and pretty-print the
-- resulting typed expression (without its outer type).
runPretty :: Exp -> Either Error String
runPretty expr = printTree . fst <$> run (inferExp expr)
-- | Run an inference action from the initial state and the initial context.
run :: Infer a -> Either Error a
run action = runC initEnv initCtx action
-- | Run an inference action with an explicit starting state and context,
-- discarding the final state.
runC :: Env -> Ctx -> Infer a -> Either Error a
runC env ctx action =
  runIdentity (runExceptT (runReaderT (evalStateT action env) ctx))
-- | Type check a whole program, producing the typed IR program.
typecheck :: Program -> Either Error T.Program
typecheck prog = run (checkPrg prog)
-- | Validate a data declaration and register its constructors.
--
-- Every type parameter of the declared data type must satisfy 'isPoly'.
-- The return type of every constructor must equal the declared data type.
-- Each accepted constructor is inserted into the constructor map with its
-- converted type. Throws on the first violation.
checkData :: Data -> Infer ()
checkData d = case d of
  Data typ@(TData _ ts) constrs -> do
    unless (all isPoly ts) $
      throwError "Data type incorrectly declared"
    traverse_ (checkConstr typ) constrs
  _ ->
    throwError $
      "incorrectly declared data type '"
        <> printTree d
        <> "'"
  where
    -- Check one constructor against the declared data type and register it.
    checkConstr :: Type -> Constructor -> Infer ()
    checkConstr typ (Constructor name' t')
      | typ == retType t' = insertConstr (coerce name') (toNew t')
      | otherwise =
          throwError $
            unwords
              [ "return type of constructor:"
              , -- Report the offending constructor, not the data type.
                printTree name'
              , "with type:"
              , printTree (retType t')
              , "does not match data: "
              , printTree typ
              ]
-- | The final result type of a curried function type: follows the right
-- spine of 'TFun' and returns whatever is at its end.
retType :: Type -> Type
retType t = case t of
  TFun _ rest -> retType rest
  _ -> t
-- | Type check a program.
--
-- A pre-pass does three things: it records every signature (rejecting
-- duplicates), registers a placeholder for every binding without a
-- signature, and validates data declarations. The definitions are then
-- checked twice. The first pass records inferred top-level types so the
-- second pass can use them regardless of definition order.
checkPrg :: Program -> Infer T.Program
checkPrg (Program bs) = do
  preRun bs
  _ <- checkDef bs
  T.Program <$> checkDef bs
  where
    preRun :: [Def] -> Infer ()
    preRun [] = return ()
    preRun (x : xs) = case x of
      DSig (Sig n t) -> do
        existing <- gets (M.lookup (coerce n) . sigs)
        -- A binding that appears before its signature leaves a 'Nothing'
        -- placeholder; only an existing 'Just' is a second signature.
        case existing of
          Just (Just _) ->
            throwError $
              "Duplicate signatures for function '"
                <> printTree n
                <> "'"
          _ -> return ()
        insertSig (coerce n) (Just $ toNew t)
        preRun xs
      DBind (Bind n _ _) -> do
        s <- gets sigs
        unless (M.member (coerce n) s) $
          insertSig (coerce n) Nothing
        preRun xs
      DData d -> checkData d >> preRun xs

    checkDef :: [Def] -> Infer [T.Def]
    checkDef [] = return []
    checkDef (x : xs) = case x of
      DBind b -> do
        b' <- checkBind b
        (T.DBind b' :) <$> checkDef xs
      DData d -> (T.DData (toNew d) :) <$> checkDef xs
      DSig _ -> checkDef xs
-- | Type check a binding @name args = body@.
--
-- The arguments are turned into nested lambdas around the body, and the
-- result is inferred. If a type is already recorded for the name, the
-- inferred type must agree with it (see 'typeEq'), and the recorded type is
-- used. Otherwise the inferred type is recorded as the name's signature.
checkBind :: Bind -> Infer T.Bind
checkBind (Bind name args body) = do
  -- 'makeLambda' makes the last identifier the outermost lambda, so the
  -- arguments are reversed to keep them in source order.
  let lambda = makeLambda body (reverse (coerce args))
  typedExp@(_, inferredT) <- inferExp lambda
  s <- gets sigs
  case M.lookup (coerce name) s of
    Just (Just sigT) -> do
      unless (inferredT `typeEq` sigT) $
        throwError $
          "Inferred type '"
            ++ printTree inferredT
            ++ "' does not match specified type '"
            ++ printTree sigT
            ++ "'"
      return $ T.Bind (coerce name, sigT) [] typedExp
    _ -> do
      insertSig (coerce name) (Just inferredT)
      return $ T.Bind (coerce name, inferredT) [] typedExp
-- | Structural equality of types in which any two type variables count as
-- equal and the variables bound by 'T.TAll' are ignored. Literal types,
-- data-type names and arities must match exactly.
typeEq :: T.Type -> T.Type -> Bool
typeEq lhs rhs = case (lhs, rhs) of
  (T.TFun l r, T.TFun l' r') -> typeEq l l' && typeEq r r'
  (T.TLit x, T.TLit y) -> x == y
  (T.TData n xs, T.TData n' ys) ->
    length xs == length ys && n == n' && and (zipWith typeEq xs ys)
  (T.TAll _ body, T.TAll _ body') -> body `typeEq` body'
  (T.TVar _, T.TVar _) -> True
  _ -> False
-- | Compare a type against a possibly quantified one.
--
-- Quantifiers on the second argument are skipped. Function and data types
-- are compared component-wise, and every other combination requires '=='.
-- NOTE(review): despite the name, a concrete type is not accepted as more
-- specific than a bare type variable here.
isMoreSpecificOrEq :: T.Type -> T.Type -> Bool
isMoreSpecificOrEq specific general = case (specific, general) of
  (_, T.TAll _ inner) -> isMoreSpecificOrEq specific inner
  (T.TFun a b, T.TFun c d) ->
    isMoreSpecificOrEq a c && isMoreSpecificOrEq b d
  (T.TData n1 ts1, T.TData n2 ts2) ->
    n1 == n2
      && length ts1 == length ts2
      && and (zipWith isMoreSpecificOrEq ts1 ts2)
  _ -> specific == general
-- | Whether a surface type is a type variable or a quantified type.
isPoly :: Type -> Bool
isPoly = \case
  TAll _ _ -> True
  TVar _ -> True
  _ -> False
-- | Infer a typed expression with algorithm W. The resulting substitution
-- is applied to the outermost type only, and is then discarded.
inferExp :: Exp -> Infer T.ExpT
inferExp expr = do
  (sub, typed@(_, ty)) <- algoW expr
  return (replace (apply sub ty) typed)
-- | Overwrite the type component of a typed expression.
replace :: T.Type -> T.ExpT -> T.ExpT
replace t ~(expr, _) = (expr, t)
-- | Conversion from the parser's syntax ("Grammar.Abs") to the typed IR
-- ("TypeChecker.TypeCheckerIr").
class NewType a b where
  -- | Convert a value to its IR counterpart.
  toNew :: a -> b
-- | Surface types map one-to-one onto IR types. 'TEVar' has no IR
-- counterpart and aborts the conversion.
instance NewType Type T.Type where
  toNew ty = case ty of
    TLit lit -> T.TLit (coerce lit)
    TVar tv -> T.TVar (toNew tv)
    TFun from to -> T.TFun (toNew from) (toNew to)
    TAll binder body -> T.TAll (toNew binder) (toNew body)
    TData ident params -> T.TData (coerce ident) (map toNew params)
    TEVar _ -> error "Should not exist after typechecker"
-- | Literals are copied constructor by constructor.
instance NewType Lit T.Lit where
  toNew = \case
    LInt n -> T.LInt n
    LChar c -> T.LChar c
-- | A data declaration's name is taken from the result type of its declared
-- type; anything other than a 'TData' there is an internal error.
instance NewType Data T.Data where
  toNew (Data t constrs) = T.Data dataName (toNew constrs)
    where
      dataName = case retType t of
        TData n _ -> coerce n
        _ -> error "Bug in toNew Data -> T.Data"
-- | Constructors keep their name and have their type converted.
instance NewType Constructor T.Constructor where
  toNew (Constructor cName cType) = T.Constructor (coerce cName) (toNew cType)
-- | Type variables keep their identifier.
instance NewType TVar T.TVar where
  toNew (MkTVar ident) = T.MkTVar (coerce ident)
-- | Lists are converted element-wise.
instance NewType a b => NewType [a] [b] where
  toNew xs = [toNew x | x <- xs]
-- | Algorithm W (Hindley–Milner inference) for a single expression.
--
-- Returns the substitution accumulated during inference together with the
-- expression annotated with its type. Where an inference rule is written
-- above an alternative, that alternative implements it.
algoW :: Exp -> Infer (Subst, T.ExpT)
algoW = \case
  -- TODO: More testing needs to be done. Unsure of the correctness of this.
  -- The annotation is checked with 'isMoreSpecificOrEq' and then unified
  -- with the inferred type.
  err@(EAnn e t) -> do
    (s1, (e', t')) <- exprErr (algoW e) err
    unless
      (toNew t `isMoreSpecificOrEq` t')
      ( throwError $
          unwords
            [ "Annotated type:"
            , printTree t
            , "does not match inferred type:"
            , printTree t'
            ]
      )
    applySt s1 $ do
      s2 <- exprErr (unify (toNew t) t') err
      let comp = s2 `compose` s1
      return (comp, apply comp (e', toNew t))
  -- ------------------
  -- Γ ⊢ i : Int, ∅
  ELit lit -> return (nullSubst, (T.ELit $ toNew lit, litType lit))
  -- x : σ ∈ Γ   τ = inst(σ)
  -- ----------------------
  -- Γ ⊢ x : τ, ∅
  EVar i -> do
    var <- asks vars
    case M.lookup (coerce i) var of
      Just t -> inst t >>= \x -> return (nullSubst, (T.EId $ coerce i, x))
      Nothing -> do
        sig <- gets sigs
        case M.lookup (coerce i) sig of
          -- Signatures may be quantified; instantiate them so each use
          -- gets its own type variables.
          Just (Just t) -> do
            t' <- inst t
            return (nullSubst, (T.EId $ coerce i, t'))
          -- A known top-level name without a type yet gets a fresh one.
          Just Nothing -> do
            fr <- fresh
            insertSig (coerce i) (Just fr)
            return (nullSubst, (T.EId $ coerce i, fr))
          Nothing -> throwError $ "Unbound variable: " <> printTree i
  -- Constructors are looked up in the constructor map.
  EInj i -> do
    constr <- gets constructors
    case M.lookup (coerce i) constr of
      Just t -> return (nullSubst, (T.EId $ coerce i, t))
      Nothing ->
        throwError $
          "Constructor: '"
            <> printTree i
            <> "' is not defined"
  -- τ = newvar   Γ, x : τ ⊢ e : τ', S
  -- ---------------------------------
  -- Γ ⊢ λx. e : Sτ → τ', S
  err@(EAbs name e) -> do
    fr <- fresh
    exprErr
      ( withBinding (coerce name) fr $ do
          (s1, (e', t')) <- exprErr (algoW e) err
          let varType = apply s1 fr
          let newArr = T.TFun varType t'
          return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr))
      )
      err
  -- Γ ⊢ e₀ : τ₀, S₀   S₀Γ ⊢ e₁ : τ₁, S₁
  -- s₂ = mgu(s₁τ₀, Int)   s₃ = mgu(s₂τ₁, Int)
  -- ------------------------------------------
  -- Γ ⊢ e₀ + e₁ : Int, S₃S₂S₁S₀
  -- NOTE: flagged by the original author as possibly wrong.
  err@(EAdd e0 e1) -> do
    (s1, (e0', t0)) <- algoW e0
    applySt s1 $ do
      (s2, (e1', t1)) <- algoW e1
      s3 <- exprErr (unify (apply s2 t0) int) err
      s4 <- exprErr (unify (apply s3 t1) int) err
      let comp = s4 `compose` s3 `compose` s2 `compose` s1
      return
        ( comp
        , apply comp (T.EAdd (e0', t0) (e1', t1), int)
        )
  -- Γ ⊢ e₀ : τ₀, S₀   S₀Γ ⊢ e₁ : τ₁, S₁
  -- τ' = newvar   S₂ = mgu(S₁τ₀, τ₁ → τ')
  -- --------------------------------------
  -- Γ ⊢ e₀ e₁ : S₂τ', S₂S₁S₀
  err@(EApp e0 e1) -> do
    fr <- fresh
    (s0, (e0', t0)) <- algoW e0
    applySt s0 $ do
      (s1, (e1', t1)) <- algoW e1
      s2 <- exprErr (unify (apply s1 t0) (T.TFun t1 fr)) err
      let t = apply s2 fr
      let comp = s2 `compose` s1 `compose` s0
      return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t))
  -- Γ ⊢ e₀ : τ, S₀   S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁
  -- ----------------------------------------------
  -- Γ ⊢ let x = e₀ in e₁ : τ', S₁S₀
  -- The bar over S₀ and Γ means "generalize".
  err@(ELet b@(Bind name args e) e1) -> do
    -- Reverse the arguments so the lambdas come out in source order, as in
    -- 'checkBind'.
    (s1, (_, t0)) <- algoW (makeLambda e (reverse (coerce args)))
    bind' <- exprErr (checkBind b) err
    env <- asks vars
    let t' = generalize (apply s1 env) t0
    -- The body is inferred in S₀Γ extended with the generalized binding.
    applySt s1 . withBinding (coerce name) t' $ do
      (s2, (e1', t2)) <- algoW e1
      let comp = s2 `compose` s1
      return (comp, apply comp (T.ELet bind' (e1', t2), t2))
  -- TODO: Add judgement
  ECase caseExpr injs -> do
    (sub, (e', t)) <- algoW caseExpr
    (subst, injs', ret_t) <- checkCase t injs
    let comp = subst `compose` sub
    let t' = apply comp ret_t
    return (comp, apply comp (T.ECase (e', t) injs', t'))
-- | Wrap an expression in one lambda per identifier. The list is folded from
-- the left, so the LAST identifier becomes the OUTERMOST lambda. Pass the
-- arguments reversed to obtain lambdas in source order.
makeLambda :: Exp -> [T.Ident] -> Exp
makeLambda body idents = foldl (\acc x -> EAbs (coerce x) acc) body idents
-- | Unify two types, producing a substitution that makes them equal, or
-- throwing an error describing why they cannot be unified.
--
-- Quantifiers ('T.TAll') on either side are stripped, not instantiated.
unify :: T.Type -> T.Type -> Infer Subst
unify t0 t1 = case (t0, t1) of
  (T.TFun a b, T.TFun c d) -> do
    s1 <- unify a c
    s2 <- unify (apply s1 b) (apply s1 d)
    -- s2 was solved on types already rewritten by s1, so s1 is applied
    -- first ('compose' takes the later substitution on the left).
    return $ s2 `compose` s1
  -- Type variables always go through the occurs check, including against
  -- data types, so that e.g. @a ~ List a@ is rejected.
  (T.TVar (T.MkTVar a), t) -> occurs a t
  (t, T.TVar (T.MkTVar b)) -> occurs b t
  (T.TAll _ t, b) -> unify t b
  (a, T.TAll _ t) -> unify a t
  (T.TLit a, T.TLit b) ->
    if a == b
      then return M.empty
      else
        throwError $
          unwords
            [ "Can not unify"
            , "'" <> printTree (T.TLit a) <> "'"
            , "with"
            , "'" <> printTree (T.TLit b) <> "'"
            ]
  (T.TData name ts, T.TData name' ts') ->
    if name == name' && length ts == length ts'
      then unifyPairs (zip ts ts')
      else
        throwError $
          unwords
            [ "T.Type constructor:"
            , printTree name
            , "(" <> printTree ts <> ")"
            , "does not match with:"
            , printTree name'
            , "(" <> printTree ts' <> ")"
            ]
  (a, b) ->
    throwError $
      unwords
        [ "'" <> printTree a <> "'"
        , "can't be unified with"
        , "'" <> printTree b <> "'"
        ]
  where
    -- Unify argument pairs left to right. The substitution found so far is
    -- applied to each later pair before it is unified.
    unifyPairs :: [(T.Type, T.Type)] -> Infer Subst
    unifyPairs = foldM step nullSubst

    step :: Subst -> (T.Type, T.Type) -> Infer Subst
    step sub (x, y) = do
      s <- unify (apply sub x) (apply sub y)
      return (s `compose` sub)
-- | Bind a type variable to a type, after checking that the variable does
-- not occur free in that type. For example, @a = a -> b@ has no solution.
-- Binding a variable to another variable always succeeds.
occurs :: T.Ident -> T.Type -> Infer Subst
occurs i t = case t of
  T.TVar _ -> return (M.singleton i t)
  _
    | S.member i (free t) ->
        throwError $
          unwords
            [ "Occurs check failed, can't unify"
            , printTree (T.TVar $ T.MkTVar i)
            , "with"
            , printTree t
            ]
    | otherwise -> return (M.singleton i t)
-- | Generalize a type over its free variables that are not free in the
-- environment.
--
-- Free variables are computed on the original type. Existing quantifiers,
-- at the top and inside function types, are removed before the new
-- quantifiers are added outermost.
generalize :: Map T.Ident T.Type -> T.Type -> T.Type
generalize env t = foldr (T.TAll . T.MkTVar) (removeForalls t) freeVars
  where
    freeVars :: [T.Ident]
    freeVars = S.toList (free t S.\\ free env)

    removeForalls :: T.Type -> T.Type
    removeForalls ty = case ty of
      T.TAll _ inner -> removeForalls inner
      T.TFun a b -> T.TFun (removeForalls a) (removeForalls b)
      _ -> ty
-- | Instantiate a polymorphic type. Every variable bound by a 'T.TAll' is
-- replaced with a fresh type variable. The traversal descends into
-- function types and stops at any other constructor.
inst :: T.Type -> Infer T.Type
inst ty = case ty of
  T.TAll (T.MkTVar bound) body -> do
    fr <- fresh
    body' <- inst body
    return (apply (M.singleton bound fr) body')
  T.TFun a b -> T.TFun <$> inst a <*> inst b
  _ -> return ty
-- | Compose two substitutions. @compose later earlier@ behaves like applying
-- @earlier@ and then @later@. When both bind the same variable, the mapping
-- from @earlier@ (with @later@ applied to it) wins.
compose :: Subst -> Subst -> Subst
compose later earlier = M.union (M.map (apply later) earlier) later
-- | Things a substitution can be applied to. Free-variable computation
-- lives in the separate 'FreeVars' class.
class SubstType t where
  -- | Apply a substitution to @t@.
  apply :: Subst -> t -> t
-- | Things that contain type variables.
class FreeVars t where
  -- | The set of type variables occurring free in @t@.
  free :: t -> Set T.Ident
-- | Free type variables of a type. A variable bound by 'T.TAll' is free
-- neither in the quantified type nor in its body.
instance FreeVars T.Type where
  free :: T.Type -> Set T.Ident
  free = \case
    T.TVar (T.MkTVar a) -> S.singleton a
    -- Remove the bound variable. The previous 'S.intersection' returned
    -- only the bound variable, which is the opposite of the free set.
    T.TAll (T.MkTVar bound) t -> S.delete bound (free t)
    T.TLit _ -> mempty
    T.TFun a b -> free a `S.union` free b
    T.TData _ ts -> foldMap free ts
-- | Substitution on types.
--
-- NOTE: if the substitution maps the variable bound by a 'T.TAll', the
-- quantifier is dropped and the substitution is applied to the body,
-- including occurrences of that variable.
instance SubstType T.Type where
  apply :: Subst -> T.Type -> T.Type
  apply sub = go
    where
      go ty = case ty of
        T.TLit a -> T.TLit a
        T.TVar (T.MkTVar a) -> M.findWithDefault ty a sub
        T.TAll tv@(T.MkTVar i) body
          | i `M.member` sub -> go body
          | otherwise -> T.TAll tv (go body)
        T.TFun a b -> T.TFun (go a) (go b)
        T.TData name ts -> T.TData name (map go ts)
-- | Free type variables of an environment: the union over all its types.
instance FreeVars (Map T.Ident T.Type) where
  free :: Map T.Ident T.Type -> Set T.Ident
  free = foldMap free
-- | Substitution on an environment: applied to every type in it.
instance SubstType (Map T.Ident T.Type) where
  apply :: Subst -> Map T.Ident T.Type -> Map T.Ident T.Type
  apply s env = fmap (apply s) env
-- | Substitution on typed expressions: applied to the outer type and
-- recursively to every typed sub-expression. In a let binding, the bound
-- name's type is substituted but its argument list is left untouched.
instance SubstType T.ExpT where
  apply :: Subst -> T.ExpT -> T.ExpT
  apply s (expr, ty) = case expr of
    T.EId i -> (T.EId i, ty')
    T.ELit lit -> (T.ELit lit, ty')
    T.ELet (T.Bind (ident, t1) args e1) e2 ->
      ( T.ELet (T.Bind (ident, apply s t1) args (apply s e1)) (apply s e2)
      , ty'
      )
    T.EApp e1 e2 -> (T.EApp (apply s e1) (apply s e2), ty')
    T.EAdd e1 e2 -> (T.EAdd (apply s e1) (apply s e2), ty')
    T.EAbs ident e -> (T.EAbs ident (apply s e), ty')
    T.ECase e brnch -> (T.ECase (apply s e) (apply s brnch), ty')
    where
      ty' = apply s ty
-- | Substitution on a case branch: its pattern, the pattern's type, and its
-- right-hand side.
instance SubstType T.Branch where
  apply :: Subst -> T.Branch -> T.Branch
  apply s (T.Branch (pat, patT) rhs) =
    T.Branch (apply s pat, apply s patT) (apply s rhs)
-- | Substitution on patterns: rewrites the types attached to variable and
-- literal patterns, and recurses into constructor patterns.
instance SubstType T.Pattern where
  apply :: Subst -> T.Pattern -> T.Pattern
  apply s pat = case pat of
    T.PVar (iden, t) -> T.PVar (iden, apply s t)
    T.PLit (lit, t) -> T.PLit (lit, apply s t)
    T.PInj i ps -> T.PInj i (map (apply s) ps)
    T.PCatch -> T.PCatch
    T.PEnum i -> T.PEnum i
-- | Lists are substituted element-wise.
instance SubstType a => SubstType [a] where
  apply s xs = [apply s x | x <- xs]
-- | A typed identifier keeps its name; only its type is substituted.
instance SubstType T.Id where
  apply s = fmap (apply s)
-- | Run an action with the substitution applied to every variable type in
-- the context.
applySt :: Subst -> Infer a -> Infer a
applySt s = local $ \ctx -> ctx {vars = apply s (vars ctx)}
-- | The empty substitution, which leaves every type unchanged.
nullSubst :: Subst
nullSubst = mempty
-- | Produce a type variable named after the current counter value, then
-- advance the counter so the next call yields a different name.
fresh :: Infer T.Type
fresh = do
  n <- gets count
  modify $ \env -> env {count = n + 1}
  let name = T.Ident (show n)
  pure (T.TVar (T.MkTVar name))
-- | Run an action with one extra variable binding in scope. It shadows any
-- existing binding of the same name.
withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> T.Type -> m a -> m a
withBinding i p = withBindings [(i, p)]
-- | Run an action with several extra variable bindings in scope. Later
-- entries in the list win over earlier ones, and all of them shadow
-- existing bindings.
withBindings :: (Monad m, MonadReader Ctx m) => [(T.Ident, T.Type)] -> m a -> m a
withBindings xs = local $ \ctx -> ctx {vars = M.union (M.fromList xs) (vars ctx)}
-- | Record (or overwrite) the signature entry for a top-level name.
-- 'Nothing' marks a name that is known but whose type is not yet known.
insertSig :: T.Ident -> Maybe T.Type -> Infer ()
insertSig i t = modify $ \env -> env {sigs = M.insert i t (sigs env)}
-- | Record (or overwrite) the type of a data constructor.
insertConstr :: T.Ident -> T.Type -> Infer ()
insertConstr i t =
  modify $ \env -> env {constructors = M.insert i t (constructors env)}
-------- PATTERN MATCHING ---------
-- | Type check the branches of a case expression whose scrutinee has type
-- @expT@.
--
-- Every branch pattern is unified with the scrutinee type, and all branch
-- bodies are unified with each other. Returns the combined substitution,
-- the typed branches and the common result type. Throws when there are no
-- branches.
checkCase :: T.Type -> [Branch] -> Infer (Subst, [T.Branch], T.Type)
checkCase expT injs = do
  (injTs, injs', returns) <- unzip3 <$> mapM inferBranch injs
  (sub1, _) <- unifyAll expT injTs
  -- Pattern unification may have refined variables used in the bodies.
  case map (apply sub1) returns of
    [] -> throwError "Case expression must have at least one branch"
    r : rs -> do
      (sub2, returnsType) <- unifyAll r rs
      return (sub2 `compose` sub1, injs', returnsType)
  where
    -- Unify every type with a running common type, threading the
    -- substitution so each new type is rewritten by what is already known.
    unifyAll :: T.Type -> [T.Type] -> Infer (Subst, T.Type)
    unifyAll start = foldM step (nullSubst, start)

    step :: (Subst, T.Type) -> T.Type -> Infer (Subst, T.Type)
    step (sub, acc) x = do
      s <- unify (apply sub x) acc
      return (s `compose` sub, apply s acc)
-- | Infer a single case branch.
--
-- Returns a triple: the type of the branch's pattern, the typed branch, and
-- the type of its right-hand side. Variables bound by the pattern are in
-- scope while the right-hand side is inferred.
inferBranch :: Branch -> Infer (T.Type, T.Branch, T.Type)
inferBranch (Branch pat expr) = do
  typedPat@(pat', patT) <- inferPattern pat
  typedExp@(_, exprT) <- withPattern pat' (inferExp expr)
  return (patT, T.Branch typedPat typedExp, exprT)
-- | Run an action with the variables bound by a pattern in scope. For a
-- constructor pattern, an earlier sub-pattern's binding shadows a later
-- one with the same name.
withPattern :: T.Pattern -> Infer a -> Infer a
withPattern pat action = case pat of
  T.PVar (x, t) -> withBinding x t action
  T.PInj _ subPats -> foldl' (\acc sp -> withPattern sp acc) action subPats
  T.PLit _ -> action
  T.PCatch -> action
  T.PEnum _ -> action
-- | Infer the type of a pattern, returning the typed pattern and its type.
--
-- Variables and wildcards get fresh type variables. A constructor pattern
-- must give exactly one sub-pattern per constructor argument. Each
-- sub-pattern is unified with its argument type, threading the
-- substitution, and the result is applied to the sub-patterns and to the
-- constructor's result type.
inferPattern :: Pattern -> Infer (T.Pattern, T.Type)
inferPattern = \case
  PLit lit -> let lt = litType lit in return (T.PLit (toNew lit, lt), lt)
  PInj constr patterns -> do
    mConstrT <- gets (M.lookup (coerce constr) . constructors)
    constrT <-
      maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") mConstrT
    (argTs, ret) <-
      maybeToRightM "Partial pattern match not allowed" (unsnoc $ flattenType constrT)
    typedPats <- mapM inferPattern patterns
    -- zipping would silently drop surplus or missing arguments
    unless (length argTs == length typedPats) $
      throwError $
        "Constructor: "
          <> printTree constr
          <> " expects "
          <> show (length argTs)
          <> " argument(s) but the pattern has "
          <> show (length typedPats)
    sub <- foldM unifyArg nullSubst (zip argTs (map snd typedPats))
    return (T.PInj (coerce constr) (apply sub (map fst typedPats)), apply sub ret)
  PCatch -> (T.PCatch,) <$> fresh
  PEnum p -> do
    mt <- gets (M.lookup (coerce p) . constructors)
    t <- maybeToRightM ("Constructor: " <> printTree p <> " does not exist") mt
    return (T.PEnum $ coerce p, t)
  PVar x -> do
    fr <- fresh
    let pvar = T.PVar (coerce x, fr)
    return (pvar, fr)
  where
    -- Unify one constructor argument type with its sub-pattern's type,
    -- after rewriting both with the substitution found so far.
    unifyArg :: Subst -> (T.Type, T.Type) -> Infer Subst
    unifyArg sub (expected, actual) = do
      s <- unify (apply sub expected) (apply sub actual)
      return (s `compose` sub)
-- | Split a curried function type into its argument types followed by its
-- final result type. Only the right spine is followed, so an argument that
-- is itself a function type stays a single element:
-- @(a -> b) -> c -> d@ becomes @[a -> b, c, d]@.
flattenType :: T.Type -> [T.Type]
flattenType (T.TFun a b) = a : flattenType b
flattenType a = [a]
-- | The built-in type of a literal.
litType :: Lit -> T.Type
litType = \case
  LInt _ -> int
  LChar _ -> char
-- | The built-in integer type.
int :: T.Type
int = T.TLit "Int"
-- | The built-in character type.
char :: T.Type
char = T.TLit "Char"
-- | Peel the given number of argument types off a surface function type,
-- returning them in order together with the remaining type. Quantifiers
-- met on the way are kept around the remaining type. Calls 'error' if the
-- type has fewer arguments than requested.
partitionType ::
  Int -> -- Number of parameters to apply
  Type ->
  ([Type], Type)
partitionType = go []
  where
    go params 0 t = (reverse params, t)
    go params n t = case t of
      TAll tvar body -> second (TAll tvar) (go params n body)
      TFun arg rest -> go (arg : params) (n - 1) rest
      _ -> error "Number of parameters and type doesn't match"
-- | Re-throw any error from an action with the offending expression
-- appended to the message.
exprErr :: Infer a -> Exp -> Infer a
exprErr ma expr =
  ma `catchError` \msg -> throwError (msg <> " on expression: " <> printTree expr)
-- | Re-throw any error from an action with the offending binding appended
-- to the message. The wording matches 'exprErr'.
bindErr :: Infer a -> Bind -> Infer a
bindErr ma bind =
  ma `catchError` \msg -> throwError (msg <> " on expression: " <> printTree bind)