churf/src/TypeChecker/TypeCheckerHm.hs

967 lines
34 KiB
Haskell
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QualifiedDo #-}
{-# OPTIONS_GHC -Wno-incomplete-patterns #-}
-- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeCheckerHm where
import Auxiliary (int, maybeToRightM, typeof, unzip4)
import qualified Auxiliary as Aux
import Control.Monad.Except
import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.Coerce (coerce)
import Data.Function (on)
import Data.List (foldl', nub, sortOn)
import Data.List.Extra (unsnoc)
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe (fromJust)
import Data.Set (Set)
import qualified Data.Set as S
import Debug.Trace (trace, traceShow)
import Grammar.Abs
import Grammar.Print (printTree)
import qualified TypeChecker.TypeCheckerIr as T
import TypeChecker.TypeCheckerIr (T, T')
{-
TODO
Prettifying the types of generated variables only needs to be done when
presenting the types to the user, i.e., when the user has made a mistake.
For successfully typed programs the types only need to match.
-}
{- | Type check a whole program.

Runs the inference pipeline over the program. On success the typed program is
returned together with the collected warnings; on failure the internal
'Error' is reduced to its message string.
-}
typecheck :: Program -> Either String (T.Program' Type, [Warning])
typecheck prg = either (Left . msg) Right (run (checkPrg prg))
{- | Check every definition of a program.

First registers all declarations ('preRun'), then orders the definitions by
the number computed by 'bindCount' (ascending, stable) and finally checks
them in that order. Note that 'prettify' is currently not applied to the
result.
-}
checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program defs) = do
    preRun defs
    counted <- bindCount defs
    let ordered = [d | (_, d) <- sortOn fst counted]
    T.Program <$> checkDef ordered
{- | Rename the type variables of binds to consecutive 'letters'.

The map holds the user declared signatures: a bind that has a user written
signature is returned unchanged so that nothing the user named is renamed.
Data definitions are returned unchanged as well.
-}
prettify :: Map T.Ident (Maybe Type) -> T.Program' Type -> T.Program' Type
prettify userSigs (T.Program defs) = T.Program (map rename defs)
  where
    rename :: T.Def' Type -> T.Def' Type
    rename def = case def of
        T.DData _ -> def
        T.DBind (T.Bind (name, t) args (e, et))
            | Just (Just _) <- M.lookup name userSigs -> def
            | otherwise ->
                -- The renaming is built from the variables of the bind's own
                -- type, in order of first appearance.
                let renaming = M.fromList (zip (nub (freeOrdered t)) letters)
                    pretty = replace renaming
                 in T.DBind (T.Bind (name, pretty t) args (fmap pretty e, pretty et))
{- | Rename type variables in a type according to the given mapping.

Variables without an entry are kept. A quantifier whose bound variable has no
entry is returned as is, without descending into its body.
-}
replace :: Map T.Ident T.Ident -> Type -> Type
replace renaming typ = case typ of
    TVar (MkTVar (LIdent a)) ->
        maybe typ (TVar . MkTVar . LIdent . coerce) (M.lookup (coerce a) renaming)
    TFun dom cod -> TFun (replace renaming dom) (replace renaming cod)
    TData name args -> TData name (map (replace renaming) args)
    TAll (MkTVar bound) body ->
        case M.lookup (coerce bound) renaming of
            Just new -> TAll (MkTVar (coerce new)) (replace renaming body)
            Nothing -> typ
    _ -> typ
{- | Pair every definition with an ordering weight.

Binds get the number of distinct declared binds their body refers to (see
'countBinds'), signatures get 0 and data definitions get -1.
-}
bindCount :: [Def] -> Infer [(Int, Def)]
bindCount = mapM weigh
  where
    weigh :: Def -> Infer (Int, Def)
    weigh def = case def of
        DBind (Bind _ _ body) -> do
            declared <- gets declaredBinds
            pure (evalState (countBinds declared body) mempty, def)
        DSig _ -> pure (0, def)
        DData _ -> pure (-1, def)
{- | Count how many distinct identifiers from the declared set occur in an
expression. The state holds the identifiers already counted, so each one
contributes at most once. For a let only the body is inspected.
-}
countBinds :: Set T.Ident -> Exp -> StateT (Set T.Ident) Identity Int
countBinds declared = walk
  where
    walk expr = case expr of
        EVar i -> do
            let ident = coerce i :: T.Ident
            seen <- get
            if ident `S.member` declared && ident `S.notMember` seen
                then 1 <$ put (S.insert ident seen)
                else pure 0
        ELet _ body -> walk body
        EApp f x -> both f x
        EAdd l r -> both l r
        EAbs _ body -> walk body
        ECase scrut branches -> do
            -- Branch bodies are visited before the scrutinee.
            inBranches <- sum <$> mapM (\(Branch _ rhs) -> walk rhs) branches
            inScrut <- walk scrut
            pure (inBranches + inScrut)
        _ -> pure 0
    both a b = (+) <$> walk a <*> walk b
{- | Register every top-level declaration before checking starts.

Signatures and binds are recorded in the environment (a bind without a
signature gets a 'Nothing' entry), the type variables they mention are
collected, and data definitions are validated with 'checkData'. Declaring a
signature or a bind twice is an uncatchable error.
-}
preRun :: [Def] -> Infer ()
preRun = mapM_ register
  where
    register :: Def -> Infer ()
    register = \case
        DSig (Sig n t) -> do
            collect (collectTVars t)
            known <- gets (M.keys . sigs)
            duplicateDecl n known $ Aux.do
                "Multiple signatures of function"
                quote $ printTree n
            insertSig (coerce n) (Just t)
        DBind (Bind n _ e) -> do
            declared <- gets (S.toList . declaredBinds)
            duplicateDecl n declared $ Aux.do
                "Multiple declarations of function"
                quote $ printTree n
            collect (collectTVars e)
            insertBind $ coerce n
            hasSig <- gets (M.member (coerce n) . sigs)
            unless hasSig (insertSig (coerce n) Nothing)
        DData d@(Data t _) -> do
            collect (collectTVars t)
            checkData d

    -- Fail if the name is already among the given declared names
    duplicateDecl n env msg = when (coerce n `elem` env) (uncatchableErr msg)
{- | Check the definitions in order. Binds are type checked, data definitions
are converted to their typed representation and signatures are dropped.
-}
checkDef :: [Def] -> Infer [T.Def' Type]
checkDef defs = concat <$> mapM check defs
  where
    check :: Def -> Infer [T.Def' Type]
    check = \case
        DBind b -> (: []) . T.DBind <$> checkBind b
        DData (Data t injs) ->
            pure [T.DData (T.Data t [T.Inj (coerce c) typ | Inj c typ <- injs])]
        DSig _ -> pure []
{- | List the type variables of a type in order of appearance, duplicates
included. Variables bound by a quantifier are listed too.
-}
freeOrdered :: Type -> [T.Ident]
freeOrdered = \case
    TVar (MkTVar a) -> [coerce a]
    TAll (MkTVar bound) t -> coerce bound : freeOrdered t
    TFun a b -> freeOrdered a <> freeOrdered b
    TData _ args -> concatMap freeOrdered args
    _ -> []
-- Much cleaner implementation, unfortunately one minor bug
-- checkBind :: Bind -> Infer (T.Bind' Type)
-- checkBind (Bind name args expr) = do
-- fr <- fresh
-- let lambda = makeLambda expr (reverse (coerce args))
-- withBinding (coerce name) fr $ do
-- (sub, (e, infSig)) <- algoW lambda
-- env <- asks vars
-- let genInfSig = generalize (apply sub env) infSig
-- maybeSig <- gets (join . M.lookup (coerce name) . sigs)
-- case maybeSig of
-- Just typSig -> do
-- unless
-- (genInfSig <<= typSig)
-- ( throwError $
-- Error
-- ( Aux.do
-- "Inferred type"
-- quote $ printTree infSig
-- "doesn't match given type"
-- quote $ printTree typSig
-- )
-- False
-- )
-- return $ T.Bind (coerce name, typSig) [] (apply sub e, typSig)
-- _ -> do
-- insertSig (coerce name) (Just genInfSig)
-- return $ T.Bind (coerce name, genInfSig) [] (apply sub e, genInfSig)
{- | Type check a top-level bind.

The bind is turned into a lambda expression and its type is inferred. If the
user declared a signature, the generalized inferred type is unified with it
(a mismatch surfaces as the error thrown by 'unify') and the declared
signature is used for the bind. Otherwise the inferred type is recorded as
the signature of the bind.
-}
checkBind :: Bind -> Infer (T.Bind' Type)
checkBind (Bind name args e) = do
    -- 'makeLambda' wraps the body innermost-first, hence the reversal.
    let lambda = makeLambda e (reverse (coerce args))
    (e', infSig) <- inferExp lambda
    declared <- gets (M.lookup (coerce name) . sigs)
    case declared of
        Just (Just typSig) -> do
            let genInfSig = generalize mempty infSig
            sub <- genInfSig `unify` typSig
            -- Applying sub to typSig will worsen error messages.
            -- Unfortunately I do not know a better solution at the moment.
            return $ T.Bind (coerce name, apply sub typSig) [] (apply sub e', typSig)
        _ -> do
            insertSig (coerce name) (Just infSig)
            return (T.Bind (coerce name, infSig) [] (e', infSig))
{- | Validate a data type definition and register its constructors.

The declared type must be a (possibly quantified) data type applied only to
type variables, all of which are bound by the quantifiers; otherwise an
uncatchable error is raised. Every constructor is then checked with
'checkInj'.
-}
checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m ()
checkData dat@(Data typ injs) = do
    (name, tvars) <- header [] typ
    dataErr (mapM_ (\inj -> checkInj inj name tvars) injs) dat
  where
    -- Strip the quantifiers, remembering the bound variables, and return the
    -- data type name together with its argument variables.
    header bound = \case
        TAll tvar t -> header (tvar : bound) t
        TData name args
            | Right vars <- mapM toTVar args
            , all (`elem` bound) vars ->
                pure (name, vars)
        _ ->
            uncatchableErr $
                unwords ["Bad data type definition: ", printTree typ]
{- | Check a single constructor of a data type and register it.

The return type of the constructor must be the data type being defined,
applied to exactly the declared type variables. A constructor that has
already been registered is an uncatchable error; the message names the
offending constructor.
-}
checkInj :: (MonadError Error m, MonadState Env m, Monad m) => Inj -> UIdent -> [TVar] -> m ()
checkInj (Inj c inj_typ) name tvars
    | TData name' typs <- returnType inj_typ
    , Right tvars' <- mapM toTVar typs
    , name' == name
    , tvars' == tvars = do
        exist <- existInj (coerce c)
        case exist of
            Just t -> uncatchableErr $ Aux.do
                "Constructor"
                -- Report the constructor itself, not the data type it belongs to
                quote $ coerce c
                "with type"
                quote $ printTree t
                "already exist"
            Nothing -> insertInj (coerce c) inj_typ
    | otherwise =
        uncatchableErr $
            unwords
                [ "Bad type constructor: "
                , show name
                , "\nExpected: "
                , printTree . TData name $ map TVar tvars
                , "\nActual: "
                , printTree $ returnType inj_typ
                ]
-- | Extract the type variable from a type, failing with an uncatchable
-- error if the type is anything other than a type variable.
toTVar :: Type -> Either Error TVar
toTVar (TVar tvar) = pure tvar
toTVar _ = uncatchableErr "Not a type variable"
-- | The final result type of a (curried) function type; any other type is
-- returned unchanged.
returnType :: Type -> Type
returnType = \case
    TFun _ result -> returnType result
    other -> other
-- | Infer the type of an expression with 'algoW' and apply the resulting
-- substitution to the inferred type.
inferExp :: Exp -> Infer (T' T.Exp' Type)
inferExp e = do
    (sub, (typed, t)) <- algoW e
    pure (typed, apply sub t)
-- | Things from which the names of user written type variables can be
-- collected.
class CollectTVars a where
-- | The set of type variable names mentioned in the value
collectTVars :: a -> Set T.Ident
-- | Only an annotation at the top of the expression contributes type
-- variables; every other expression form yields the empty set.
instance CollectTVars Exp where
    collectTVars = \case
        EAnn e t -> S.union (collectTVars t) (collectTVars e)
        _ -> S.empty
-- | Collects every type variable occurring in the type. The variable bound
-- by a quantifier is only included if it occurs in the body.
instance CollectTVars Type where
    collectTVars = \case
        TVar (MkTVar i) -> S.singleton (coerce i)
        TAll _ body -> collectTVars body
        TFun dom cod -> collectTVars dom <> collectTVars cod
        TData _ args -> S.unions (map collectTVars args)
        _ -> S.empty
-- | Add the given names to the set of taken type variable names.
collect :: Set T.Ident -> Infer ()
collect names =
    modify $ \st -> st{takenTypeVars = S.union names (takenTypeVars st)}
{- | Algorithm W: infer the type of an expression.

Returns the substitution produced while inferring together with the typed
expression and its type. Errors raised while checking a sub-expression are,
where 'exprErr' is used, extended with the offending expression.
-}
algoW :: Exp -> Infer (Subst, T' T.Exp' Type)
algoW = \case
    err@(EAnn e t) -> do
        (sub0, (e', t')) <- exprErr (algoW e) err
        sub1 <- unify t t'
        sub2 <- unify t' t
        b <- (apply sub1 t <<= apply sub2 t')
        unless
            b
            ( uncatchableErr $ Aux.do
                "Annotated type"
                quote $ printTree t
                "does not match inferred type"
                quote $ printTree t'
            )
        let comp = sub2 `compose` sub1 `compose` sub0
        return (comp, (apply comp e', t))
    -- \| ------------------
    -- \| Γ ⊢ i : Int, ∅
    ELit lit -> return (nullSubst, (T.ELit lit, typeof lit))
    -- \| x : σ ∈ Γ τ = inst(σ)
    -- \| ----------------------
    -- \| Γ ⊢ x : τ, ∅
    EVar (LIdent i) -> do
        var <- asks vars
        case M.lookup (coerce i) var of
            Just t ->
                inst t >>= \x ->
                    return (nullSubst, (T.EVar $ coerce i, x))
            -- Not a local variable: fall back to the top-level signatures
            Nothing -> do
                sig <- gets sigs
                case M.lookup (coerce i) sig of
                    Just (Just t) -> do
                        t <- freshen t
                        return (nullSubst, (T.EVar $ coerce i, t))
                    -- Declared, but its type is not known yet
                    Just Nothing -> do
                        fr <- fresh
                        return (nullSubst, (T.EVar $ coerce i, fr))
                    Nothing ->
                        uncatchableErr $
                            "Unbound variable: "
                                <> printTree i
    EInj i -> do
        constr <- gets injections
        case M.lookup (coerce i) constr of
            Just t -> do
                t <- freshen t
                return (nullSubst, (T.EInj $ coerce i, t))
            Nothing ->
                uncatchableErr $ Aux.do
                    "Constructor:"
                    quote $ printTree i
                    "is not defined"
    -- \| τ = newvar Γ, x : τ ⊢ e : τ', S
    -- \| ---------------------------------
    -- \| Γ ⊢ w λx. e : Sτ → τ', S
    err@(EAbs name e) -> do
        fr <- fresh
        withBinding (coerce name) fr $ do
            (s1, (e', t')) <- exprErr (algoW e) err
            let varType = apply s1 fr
            let newArr = TFun varType t'
            return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr))
    -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
    -- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
    -- \| ------------------------------------------
    -- \| Γ ⊢ e₀ + e₁ : Int, S₃S₂S₁S₀
    -- This might be wrong
    err@(EAdd e0 e1) -> do
        (s1, (e0', t0)) <- algoW e0
        (s2, (e1', t1)) <- algoW e1
        s3 <- exprErr (unify t0 int) err
        s4 <- exprErr (unify t1 int) err
        let comp = s4 `compose` s3 `compose` s2 `compose` s1
        return
            ( comp
            , apply comp (T.EAdd (e0', t0) (e1', t1), int)
            )
    -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1
    -- \| τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ')
    -- \| --------------------------------------
    -- \| Γ ⊢ e₀ e₁ : S₂τ', S₂S₁S₀
    EApp e0 e1 -> do
        fr <- fresh
        (s0, (e0', t0)) <- algoW e0
        applySt s0 $ do
            (s1, (e1', t1)) <- algoW e1
            s2 <- unify (apply s1 t0) (TFun t1 fr)
            let t = apply s2 fr
            let comp = s2 `compose` s1 `compose` s0
            return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t))
    -- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁
    -- \| ----------------------------------------------
    -- \| Γ ⊢ let x = e₀ in e₁ : τ', S₁S₀
    -- The bar over S₀ and Γ means "generalize"
    ELet (Bind name args e) e1 -> do
        fr <- fresh
        withBinding (coerce name) fr $ do
            -- 'makeLambda' makes the last identifier of its list the
            -- outermost parameter, so the arguments are reversed here, just
            -- as for top-level binds in 'checkBind'.
            (s1, e@(_, t0)) <- algoW (makeLambda e (reverse (coerce args)))
            env <- asks vars
            let t' = generalize (apply s1 env) t0
            withBinding (coerce name) t' $ do
                (s2, (e1', t2)) <- algoW e1
                let comp = s2 `compose` s1
                return
                    ( comp
                    , apply
                        comp
                        (T.ELet (T.Bind (coerce name, t0) [] e) (e1', t2), t2)
                    )
    ECase caseExpr injs -> do
        (sub, (e', t)) <- algoW caseExpr
        (subst, injs, ret_t) <- checkCase t injs
        let comp = subst `compose` sub
        return (comp, apply comp (T.ECase (e', t) injs, ret_t))
{- | Check the branches of a case expression against the type of the
scrutinee.

Every branch is inferred on its own; afterwards the pattern types are unified
with the scrutinee type and the branch result types are unified with each
other. Returns the combined substitution, the typed branches and the result
type. A case without branches gets a fresh type variable as result type.
-}
checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type)
checkCase scrutT branches
    | null branches = (\fr -> (nullSubst, [], fr)) <$> fresh
    | otherwise = do
        (subs, patTs, typedBranches, bodyTs) <- unzip4 <$> mapM inferBranch branches
        (patSub, _) <- foldM unifyInto (nullSubst, scrutT) patTs
        -- There is at least one branch here, so the list is non-empty
        let firstT : restTs = bodyTs
        (bodySub, resultT) <- foldM unifyInto (nullSubst, firstT) restTs
        let total = bodySub `compose` patSub `compose` composeAll subs
        pure (total, apply total typedBranches, apply total resultT)
  where
    -- Unify the next type with the accumulated one, refining both the
    -- accumulated substitution and the accumulated type
    unifyInto (acc, t) next = do
        s <- unify next t
        pure (s `compose` acc, apply s t)
{- | Infer a single case branch.

Returns the substitution from inferring the branch body, the type of the
pattern, the typed branch and the type of the body, the last three with the
substitution applied. Any error from the body is rethrown as uncatchable with
the branch appended to its message.
-}
inferBranch :: Branch -> Infer (Subst, Type, T.Branch' Type, Type)
inferBranch branch@(Branch pat body) = do
    typedPat@(_, patT) <- inferPattern pat
    (sub, typedBody@(_, bodyT)) <-
        withPattern typedPat (algoW body) `catchError` \e ->
            throwError Error{msg = e.msg <> " in pattern '" <> printTree branch <> "'", catchable = False}
    pure
        ( sub
        , apply sub patT
        , T.Branch (apply sub typedPat) (apply sub typedBody)
        , apply sub bodyT
        )
-- | Infer the type of a pattern, returning the typed pattern together with
-- the type of values it matches.
inferPattern :: Pattern -> Infer (T.Pattern' Type, Type)
inferPattern = \case
-- A literal pattern has the type of its literal
PLit lit -> let lt = typeof lit in return (T.PLit lit, lt)
-- Wildcards and variables match anything: give them a fresh type variable
PCatch -> (T.PCatch,) <$> fresh
PVar x -> do
fr <- fresh
let pvar = T.PVar (coerce x)
return (pvar, fr)
-- A constructor used without arguments
PEnum p -> do
t <- gets (M.lookup (coerce p) . injections)
-- presumably fails with the given (catchable) error when the lookup
-- found nothing; 'maybeToRightM' is defined in Auxiliary - verify there
t <-
maybeToRightM
( Error
( Aux.do
"Constructor:"
quote $ printTree p
"does not exist"
)
True
)
t
-- The constructor type must not be a function type, i.e. take no arguments
unless
(typeLength t == 1)
( catchableErr $ Aux.do
"The constructor"
quote $ printTree p
" should have "
show (typeLength t - 1)
" arguments but has been given 0"
)
-- NOTE(review): lazy, partial match; assumes a non-function constructor
-- type is always a TData - confirm
let (TData _data _ts) = t -- nasty nasty
-- The type arguments of the data type are replaced by fresh variables
frs <- mapM (const fresh) _ts
return (T.PEnum $ coerce p, TData _data frs)
-- A constructor applied to sub-patterns
PInj constr patterns -> do
t <- gets (M.lookup (coerce constr) . injections)
t <-
maybeToRightM
( Error
( Aux.do
"Constructor:"
quote $ printTree constr
"does not exist"
)
True
)
t
let numArgs = typeLength t - 1
-- Split the constructor type into argument types and result type;
-- 'flattenType' never returns an empty list, so 'fromJust' cannot fail
let (vs, ret) = fromJust (unsnoc $ flattenType t)
patterns <- mapM inferPattern patterns
unless
(length patterns == numArgs)
( catchableErr $ Aux.do
"The constructor"
quote $ printTree constr
" should have "
show numArgs
" arguments but has been given "
show (length patterns)
)
-- Each argument type is unified with its sub-pattern type independently;
-- the resulting substitutions are only combined afterwards
sub <- composeAll <$> zipWithM unify vs (map snd patterns)
return
( T.PInj (coerce constr) (apply sub patterns)
, apply sub ret
)
-- | Unify two types producing a new substitution
--
-- Fails with a catchable error when the two types can not be made equal.
unify :: Type -> Type -> Infer Subst
unify t0 t1 =
-- fvs and m are only used to print type variables as letters in error messages
let fvs = S.toList $ free t0 `S.union` free t1
m = M.fromList $ zip fvs letters
in case (t0, t1) of
-- Functions: unify the arguments, then the results under that substitution
(TFun a b, TFun c d) -> do
s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d)
return $ s2 `compose` s1
-- NOTE(review): a variable against a data type is bound directly, without
-- going through 'occurs', so e.g. a ~ T a is accepted here - confirm intended
(TVar (MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t
(t@(TData _ _), TVar (MkTVar b)) -> return $ M.singleton (coerce b) t
(TVar (MkTVar a), t) -> occurs (coerce a) t
(t, TVar (MkTVar b)) -> occurs (coerce b) t
-- Forall unification should change
-- (currently the quantifier is dropped and the body is unified)
(TAll _ t, b) -> unify t b
(a, TAll _ t) -> unify a t
(TLit a, TLit b) ->
if a == b
then return M.empty
else catchableErr $
Aux.do
"Can not unify"
quote $ printTree (TLit a)
"with"
quote $ printTree (TLit b)
-- Data types: same name and arity, arguments unified pairwise. Each pair is
-- unified on its own; substitutions are not applied between arguments.
(TData name t, TData name' t') ->
if name == name' && length t == length t'
then do
xs <- zipWithM unify t t'
return $ foldr compose nullSubst xs
else catchableErr $
Aux.do
"Type constructor:"
printTree name
quote $ printTree $ map (replace m) t
"does not match with:"
printTree name'
quote $ printTree $ map (replace m) t'
(TEVar a, TEVar b) ->
if a == b
then return M.empty
else catchableErr $
Aux.do
"Can not unify"
quote $ printTree (TEVar a)
"with"
quote $ printTree (TEVar b)
-- Every remaining combination is a mismatch
(a, b) -> do
catchableErr $
Aux.do
"Can not unify"
quote $ printTree $ replace m a
"with"
quote $ printTree $ replace m b
{- | Bind a type variable to a type, unless the variable occurs in the type.
E.g. { a = a -> b } is an unsolvable constraint since there is no
substitution that makes both sides equal. Binding to another type variable
or to an existential variable always succeeds.
-}
occurs :: T.Ident -> Type -> Infer Subst
occurs i t = case t of
    TEVar _ -> solved
    TVar _ -> solved
    _
        | i `S.member` fvs ->
            catchableErr
                ( Aux.do
                    "Occurs check failed, can't unify"
                    quote $ printTree $ replace pretty (TVar $ MkTVar (coerce i))
                    "with"
                    quote $ printTree $ replace pretty t
                )
        | otherwise -> solved
  where
    solved = return (M.singleton i t)
    fvs = free t
    -- Only used to print the variables as letters in the error message
    pretty = M.fromList $ zip (S.toList fvs) letters
{- | Generalize a type: quantify over every variable reported by 'free' for
the type but not for the environment. Existing quantifiers are removed from
the type first.

Used for let bindings to allow expressions that do not type check in the
equivalent lambda expression:
    Type checks: let f = \x. x in (f True, f 'a')
    Does not type check: (\f. (f True, f 'a')) (\x. x)
-}
generalize :: Map T.Ident Type -> Type -> Type
generalize env t = foldr quantify (removeForalls t) quantified
  where
    quantified = S.toList (free t S.\\ free env)
    quantify v body = TAll (MkTVar (coerce v)) body
-- | Drop the quantifiers found at the top of a type and on either side of
-- function arrows. Arguments of data types are not visited.
removeForalls :: Type -> Type
removeForalls = \case
    TAll _ body -> removeForalls body
    TFun dom cod -> TFun (removeForalls dom) (removeForalls cod)
    other -> other
{- | Instantiate a polymorphic type: every quantifier is removed and its
bound variable is replaced by a fresh type variable. Function types are
instantiated on both sides, other types are returned as they are.
-}
inst :: Type -> Infer Type
inst = \case
    TAll (MkTVar bound) body -> do
        fr <- fresh
        body' <- inst body
        pure (apply (M.singleton (coerce bound) fr) body')
    TFun dom cod -> do
        dom' <- inst dom
        cod' <- inst cod
        pure (TFun dom' cod')
    rest -> pure rest
-- Only one of 'freshen' and 'inst' should be needed but something doesn't work
-- when I remove either.

-- | Replace every variable reported by 'free' for the type with a fresh
-- type variable.
freshen :: Type -> Infer Type
freshen t = do
    renaming <- mapM (\v -> (,) v <$> fresh) (S.toList (free t))
    pure (apply (M.fromList renaming) t)
-- | Generate a fresh type variable. It is named after the current value of
-- the counter in the environment, which is incremented afterwards.
fresh :: Infer Type
fresh = do
    st <- get
    put st{count = succ (count st)}
    pure . TVar . MkTVar . LIdent . show $ count st
{- | Is the left type a subtype of the right one?

A type variable on the left matches another type variable only if they are
equal, and matches every other type. Function and data types are compared
component-wise. If either side is quantified, the outermost quantifiers of
both sides are removed and their bound variables replaced by fresh type
variables before comparing.
-}
(<<=) :: Type -> Type -> Infer Bool
(<<=) a b = case (a, b) of
    (TVar x, TVar y) -> return $ x == y
    (TVar _, _) -> return True
    (TFun a1 b1, TFun a2 b2) -> (&&) <$> (a1 <<= a2) <*> (b1 <<= b2)
    (TData n1 ts1, TData n2 ts2) -> do
        argsOk <- and <$> zipWithM (<<=) ts1 ts2
        return (argsOk && n1 == n2 && length ts1 == length ts2)
    (TAll _ _, _) -> underFreshVars
    (_, TAll _ _) -> underFreshVars
    (t1, t2) -> return $ t1 == t2
  where
    -- Strip the outermost quantifiers of both sides and compare the bodies
    underFreshVars :: Infer Bool
    underFreshVars =
        let (tvars1, a') = gatherTVars [] a
            (tvars2, b') = gatherTVars [] b
         in go (tvars1 ++ tvars2) a' b'

    -- Replace the given variables by fresh ones in both types, then compare
    go :: [TVar] -> Type -> Type -> Infer Bool
    go tvars t1 t2 = do
        freshies <- mapM (const fresh) tvars
        let sub = M.fromList $ zip [coerce x | MkTVar x <- tvars] freshies
        apply sub t1 <<= apply sub t2
{-
Renaming: a -> b -> a and c -> d -> c
gives 0 -> 1 -> 0 and 2 -> 3 -> 2.
They have to be given the same names: alpha-renaming in the subtype check is
done incorrectly.
-}
-- | Strip the outermost quantifiers of a type, pushing their bound
-- variables onto the given list (so the innermost one ends up first).
--
-- Pre-condition: All TAlls are outermost
gatherTVars :: [TVar] -> Type -> ([TVar], Type)
gatherTVars acc = \case
    TAll tvar body -> gatherTVars (tvar : acc) body
    body -> (acc, body)
-- | A class for substitutions: everything a 'Subst' can be applied to
class SubstType t where
-- | Apply a substitution to t
apply :: Subst -> t -> t
-- | A class for everything free type variables can be collected from
class FreeVars t where
-- | Get all free variables from t
free :: t -> Set T.Ident
-- | The free variables of a bind are those of its declared type
instance FreeVars (T.Bind' Type) where
    free (T.Bind sig _ _) = free (snd sig)
-- | Type variables of a type. Literal types and existential variables have
-- none.
instance FreeVars Type where
free :: Type -> Set T.Ident
free (TVar (MkTVar a)) = S.singleton (coerce a)
-- NOTE(review): for a quantified type this yields the bound variable itself
-- (when it occurs in the body) and nothing else, rather than the variables of
-- the body minus the bound one. 'freshen' appears to rely on bound variables
-- being reported here - confirm before changing.
free (TAll (MkTVar bound) t) =
S.singleton (coerce bound) `S.intersection` free t
free (TLit _) = mempty
free (TFun a b) = free a `S.union` free b
free (TData _ a) = free a
free (TEVar _) = S.empty
-- | The free variables of a list are the union over its elements
instance FreeVars a => FreeVars [a] where
    free xs = S.unions (map free xs)
-- | Replace the (existential) type variables of a type that have an entry
-- in the substitution. A quantifier whose bound variable has an entry is
-- removed.
instance SubstType Type where
    apply :: Subst -> Type -> Type
    apply sub = \case
        lit@(TLit _) -> lit
        var@(TVar (MkTVar a)) -> M.findWithDefault var (coerce a) sub
        TAll (MkTVar i) body
            | coerce i `M.member` sub -> apply sub body
            | otherwise -> TAll (MkTVar i) (apply sub body)
        TFun dom cod -> TFun (apply sub dom) (apply sub cod)
        TData name args -> TData name (apply sub args)
        evar@(TEVar (MkTEVar a)) -> M.findWithDefault evar (coerce a) sub
-- | The free variables of all types stored in the map
instance FreeVars (Map T.Ident Type) where
    free :: Map T.Ident Type -> Set T.Ident
    free env = free (M.elems env)
-- | Apply the substitution to every type stored in the map
instance SubstType (Map T.Ident Type) where
    apply :: Subst -> Map T.Ident Type -> Map T.Ident Type
    apply s env = M.map (apply s) env
-- | Apply the substitution to every type present in the map
instance SubstType (Map T.Ident (Maybe Type)) where
    apply s sigMap = fmap (apply s) <$> sigMap
-- | Apply the substitution to a typed expression and to its type
instance SubstType (T' T.Exp' Type) where
    apply s typed = (apply s (fst typed), apply s (snd typed))
-- | Apply the substitution to all type annotations inside an expression.
-- Variables, literals and constructors carry no annotations of their own.
instance SubstType (T.Exp' Type) where
    apply s expr = case expr of
        T.EVar _ -> expr
        T.ELit _ -> expr
        T.EInj _ -> expr
        T.ELet (T.Bind (ident, t) args rhs) body ->
            T.ELet (T.Bind (ident, apply s t) args (apply s rhs)) (apply s body)
        T.EApp f x -> T.EApp (apply s f) (apply s x)
        T.EAdd l r -> T.EAdd (apply s l) (apply s r)
        T.EAbs ident body -> T.EAbs ident (apply s body)
        T.ECase scrut branches -> T.ECase (apply s scrut) (apply s branches)
-- | Apply the substitution to a bind; other definitions are left unchanged
instance SubstType (T.Def' Type) where
    apply s def = case def of
        T.DBind (T.Bind name args e) ->
            T.DBind (T.Bind (apply s name) (apply s args) (apply s e))
        _ -> def
-- | Apply the substitution to the typed pattern and the body of a branch
instance SubstType (T.Branch' Type) where
    apply s (T.Branch typedPat body) = T.Branch (apply s typedPat) (apply s body)
-- | Only constructor patterns contain types (in their sub-patterns); every
-- other pattern is returned as it is.
instance SubstType (T.Pattern' Type) where
    apply s pat = case pat of
        T.PInj i ps -> T.PInj i (apply s ps)
        T.PVar _ -> pat
        T.PLit _ -> pat
        T.PCatch -> pat
        T.PEnum _ -> pat
-- | Apply the substitution to a pattern and to its type
instance SubstType (T.Pattern' Type, Type) where
    apply s typedPat = (apply s (fst typedPat), apply s (snd typedPat))
-- | Apply the substitution to every element of the list
instance SubstType a => SubstType [a] where
    apply s xs = [apply s x | x <- xs]
-- | Apply the substitution to the type of a typed identifier
instance SubstType (T T.Ident Type) where
    apply s = fmap (apply s)
-- | The empty substitution
nullSubst :: Subst
nullSubst = M.empty
-- | Compose two substitutions: the first is applied to every type of the
-- second, and on equal keys the (updated) entry of the second wins.
compose :: Subst -> Subst -> Subst
compose outer inner = M.union (M.map (apply outer) inner) outer
-- | Compose a list of substitutions into one, folding from the left
composeAll :: [Subst] -> Subst
composeAll subs = foldl' compose nullSubst subs
{- | Wrap an expression in one lambda per identifier.

The identifiers are folded from the left, so the FIRST identifier of the list
becomes the INNERMOST lambda and the last one the outermost. To turn a
function with arguments into its pointfree version the parameters therefore
have to be passed in reverse order:

> makeLambda (x + y) [y, x] = \x. \y. x + y
-}
makeLambda :: Exp -> [T.Ident] -> Exp
makeLambda = foldl' (\body param -> EAbs (coerce param) body)
-- | Run the monadic action with one more variable in the context. An
-- existing binding of the same name is shadowed.
withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> Type -> m a -> m a
withBinding i p = local extend
  where
    extend ctx = ctx{vars = M.insert i p (vars ctx)}
-- | Run the monadic action with several more variables in the context.
-- They are inserted in list order, so later entries win over earlier ones.
withBindings :: (Monad m, MonadReader Ctx m) => [(T.Ident, Type)] -> m a -> m a
withBindings xs = local extend
  where
    extend ctx = ctx{vars = foldl' (\acc (i, t) -> M.insert i t acc) (vars ctx) xs}
-- | Run the monadic action with the variables bound by a typed pattern in
-- the context. Only variable patterns (also nested inside constructor
-- patterns) bind anything.
withPattern :: (Monad m, MonadReader Ctx m) => (T.Pattern' Type, Type) -> m a -> m a
withPattern (T.PVar x, t) ma = withBinding x t ma
withPattern (T.PInj _ ps, _) ma = foldl' (flip withPattern) ma ps
withPattern (T.PLit _, _) ma = ma
withPattern (T.PCatch, _) ma = ma
withPattern (T.PEnum _, _) ma = ma
-- | Record the (possibly still unknown) signature of a function in the
-- environment, replacing any previous entry.
insertSig :: T.Ident -> Maybe Type -> Infer ()
insertSig i t = modify $ \env -> env{sigs = M.insert i t (sigs env)}
-- | Mark a function name as declared
insertBind :: T.Ident -> Infer ()
insertBind i = modify $ \env -> env{declaredBinds = S.insert i (declaredBinds env)}
-- | Insert a constructor into the state together with its type
insertInj :: (Monad m, MonadState Env m) => T.Ident -> Type -> m ()
insertInj i t = modify $ \env -> env{injections = M.insert i t (injections env)}
-- | Run the action with the substitution applied to the types of all
-- variables in the context.
applySt :: Subst -> Infer a -> Infer a
applySt s = local $ \ctx -> ctx{vars = apply s (vars ctx)}
{- | Look up an injection (constructor of a data type) by name, returning its
type if a constructor with that name has been declared already.
-}
existInj :: (Monad m, MonadState Env m) => T.Ident -> m (Maybe Type)
existInj n = M.lookup n <$> gets injections
-- | Split a function type at every arrow (on both sides) into the list of
-- its non-function components. The result is never empty.
flattenType :: Type -> [Type]
flattenType = \case
    TFun a b -> flattenType a ++ flattenType b
    t -> [t]
-- | Number of arrows along the right spine of a type, plus one. A
-- non-function type has length 1.
typeLength :: Type -> Int
typeLength = \case
    TFun _ b -> succ (typeLength b)
    _ -> 1
{- | Run the action and, if it fails with a catchable error, append the given
expression to the error message and mark the error as uncatchable.
Uncatchable errors are rethrown unchanged.
-}
exprErr :: (Monad m, MonadError Error m) => m a -> Exp -> m a
exprErr ma expr = ma `catchError` (throwError . annotate)
  where
    annotate :: Error -> Error
    annotate err
        | err.catchable =
            err
                { msg = err.msg <> " in expression: \n" <> printTree expr
                , catchable = False
                }
        | otherwise = err
{- | Run the action and, if it fails with a catchable error, append the given
bind to the error message and mark the error as uncatchable. Uncatchable
errors are rethrown unchanged.
-}
bindErr :: (Monad m, MonadError Error m) => m a -> Bind -> m a
bindErr ma bind = ma `catchError` (throwError . annotate)
  where
    annotate :: Error -> Error
    annotate err
        | err.catchable =
            err
                { msg = err.msg <> " in function: \n" <> printTree bind
                , catchable = False
                }
        | otherwise = err
{- | Run the action and, if it fails with a catchable error, append the given
data definition to the error message. Unlike 'exprErr', the annotated error
stays catchable; errors that were already uncatchable are rethrown as such.
-}
dataErr :: (MonadError Error m, Monad m) => m a -> Data -> m a
dataErr ma d = ma `catchError` (throwError . annotate)
  where
    annotate :: Error -> Error
    annotate err
        | err.catchable = err{msg = err.msg <> " in data: \n" <> printTree d}
        | otherwise = err{catchable = False}
-- | The initial context: no variables in scope
initCtx :: Ctx
initCtx = Ctx{vars = mempty}
-- | The initial environment: counter at 0 and nothing declared yet
initEnv :: Env
initEnv =
    Env
        { count = 0
        , nextChar = 'a'
        , sigs = mempty
        , takenTypeVars = mempty
        , injections = mempty
        , declaredBinds = mempty
        }
-- | Run an inference action from the initial environment and context
run :: Infer a -> Either Error (a, [Warning])
run action = run' initEnv initCtx action
-- | Run an inference action from the given environment and context,
-- yielding either the error or the result with the emitted warnings. The
-- final environment is discarded.
run' :: Env -> Ctx -> Infer a -> Either Error (a, [Warning])
run' env ctx action =
    runIdentity (runExceptT (runWriterT (runReaderT stateless ctx)))
  where
    stateless = evalStateT (runInfer action) env
-- | The read-only context of the type checker: the types of the variables
-- currently in scope (extended locally, e.g. by 'withBinding').
newtype Ctx = Ctx {vars :: Map T.Ident Type}
deriving (Show)
-- | The mutable state of the type checker
data Env = Env
{ count :: Int
-- ^ Counter used by 'fresh' to name new type variables
, nextChar :: Char
-- ^ NOTE(review): not read anywhere in this module as far as visible - confirm whether it is still needed
, sigs :: Map T.Ident (Maybe Type)
-- ^ Signatures of top-level functions; 'Nothing' for a declared function whose type is not known yet
, takenTypeVars :: Set T.Ident
-- ^ Type variable names gathered by 'collect'
, injections :: Map T.Ident Type
-- ^ Constructors of data types together with their types
, declaredBinds :: Set T.Ident
-- ^ Names of the functions that have a definition
}
deriving (Show)
-- | A type checking error. The message is what is finally shown to the
-- user; the flag tells handlers such as 'exprErr' whether they may still
-- extend the message with context.
data Error = Error {msg :: String, catchable :: Bool}
deriving (Show)
-- | A substitution: a mapping from type variable names to types
type Subst = Map T.Ident Type
-- | Warnings that can be emitted next to a successful result.
-- NOTE(review): presumably about non-exhaustive pattern matches; nothing in
-- this module emits one - confirm
newtype Warning = NonExhaustive String
deriving (Show)
-- | The monad of the type checker: state ('Env') over a reader ('Ctx') over
-- a writer of warnings over errors ('Error').
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (WriterT [Warning] (ExceptT Error Identity))) a}
deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env)
-- | Throw an error that handlers may still extend with context
catchableErr :: MonadError Error m => String -> m a
catchableErr text = throwError Error{msg = text, catchable = True}
-- | Throw an error that is final and will not be extended with context
uncatchableErr :: MonadError Error m => String -> m a
uncatchableErr text = throwError Error{msg = text, catchable = False}
-- | Wrap a string in single quotes
quote :: String -> String
quote s = concat ["'", s, "'"]
-- | An infinite supply of names: a, b, ..., z, aa, ab, ... in order of
-- increasing length.
letters :: [T.Ident]
letters = [T.Ident name | n <- [1 ..], name <- replicateM n ['a' .. 'z']]
{-
first = TAll (MkTVar (LIdent "a")) (TAll (MkTVar (LIdent "b")) (TFun (TVar (MkTVar (LIdent "a"))) (TFun (TVar (MkTVar (LIdent "b"))) (TVar (MkTVar (LIdent "b"))))))
second = TAll (MkTVar (LIdent "a")) (TAll (MkTVar (LIdent "b")) (TFun (TVar (MkTVar (LIdent "a"))) (TFun (TVar (MkTVar (LIdent "b"))) (TVar (MkTVar (LIdent "a"))))))
-}