Fixed previously incorrect type equality check, commented code, add test

This commit is contained in:
sebastianselander 2023-03-28 14:31:20 +02:00
parent 85f31b129b
commit b1d3e31efd
3 changed files with 335 additions and 311 deletions

View file

@ -27,38 +27,12 @@ Program below should not type check
main : a -> b ;
main x = x;
```
## Pattern match on functions
Program below should not type check
## Bugged error message
```hs
data Maybe () where {
Nothing : Maybe
Just : Int -> Maybe
};
fmap : (Int -> Int) -> Maybe -> Maybe ;
fmap f ma = case ma of {
Nothing => Nothing ;
Just a => Just (f a) ;
};
pure : Int -> Maybe ;
pure x = Just x ;
ap mf ma = case mf of {
Just f => case ma of {
Nothing => Nothing;
Just a => Just (f a);
};
Nothing => Nothing;
};
return = pure;
bind ma f = case ma of {
Nothing => Nothing ;
Just a => f a ;
main = case \x. x of {
_ => 0;
};
```
```
TYPECHECKER ERROR
Inferred type '("c" -> "Int") -> "Maybe" -> "Maybe" does not match specified type '("Int" -> "Int") -> "Maybe" -> "Maybe"'

View file

@ -12,7 +12,6 @@ import Control.Monad.Except
import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (second)
import Data.Coerce (coerce)
import Data.Function (on)
import Data.List (foldl')
@ -30,16 +29,17 @@ initCtx = Ctx mempty
initEnv = Env 0 'a' mempty mempty mempty
run :: Infer a -> Either Error a
run = runC initEnv initCtx
run = run' initEnv initCtx
runC :: Env -> Ctx -> Infer a -> Either Error a
runC e c =
run' :: Env -> Ctx -> Infer a -> Either Error a
run' e c =
runIdentity
. runExceptT
. flip runReaderT c
. flip evalStateT e
. runInfer
-- | Type check a program
typecheck :: Program -> Either String (T.Program' Type)
typecheck = onLeft msg . run . checkPrg
where
@ -47,20 +47,87 @@ typecheck = onLeft msg . run . checkPrg
onLeft f (Left x) = Left $ f x
onLeft _ (Right x) = Right x
checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = do
preRun bs
bs' <- checkDef bs
return $ T.Program bs'
preRun :: [Def] -> Infer ()
preRun [] = return ()
preRun (x : xs) = case x of
DSig (Sig n t) -> do
collect (collectTVars t)
gets (M.member (coerce n) . sigs)
>>= flip
when
( uncatchableErr $ Aux.do
"Duplicate signatures for function"
quote $ printTree n
)
insertSig (coerce n) (Just t) >> preRun xs
DBind (Bind n _ e) -> do
collect (collectTVars e)
s <- gets sigs
case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs
DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs
checkDef :: [Def] -> Infer [T.Def' Type]
checkDef [] = return []
checkDef (x : xs) = case x of
(DBind b) -> do
b' <- checkBind b
xs' <- checkDef xs
return $ T.DBind b' : xs'
(DData d) -> do
xs' <- checkDef xs
return $ T.DData (coerceData d) : xs'
(DSig _) -> checkDef xs
where
coerceData (Data t injs) =
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
checkBind :: Bind -> Infer (T.Bind' Type)
checkBind (Bind name args e) = do
let lambda = makeLambda e (reverse (coerce args))
(sub0, (e, lambda_t)) <- inferExp lambda
s <- gets sigs
case M.lookup (coerce name) s of
Just (Just t') -> do
let fsig = apply sub0 t'
sub1 <- liftEither $ runIdentity $ runExceptT $ execStateT (typeEq fsig lambda_t) mempty
sub2 <- liftEither $ runIdentity $ runExceptT $ execStateT (typeEq lambda_t fsig) mempty
unless
(lambda_t == apply sub1 fsig && apply sub2 lambda_t == fsig)
( uncatchableErr $ Aux.do
"Inferred type"
quote $ printTree lambda_t
"does not match specified type"
quote $ printTree t'
)
return $ T.Bind (coerce name, lambda_t) [] (e, lambda_t)
_ -> do
insertSig (coerce name) (Just lambda_t)
return (T.Bind (coerce name, lambda_t) [] (e, lambda_t))
checkData :: Data -> Infer ()
checkData err@(Data typ injs) = do
(name, tvars) <- go typ
dataErr (mapM_ (\i -> typecheckInj i name tvars) injs) err
dataErr (mapM_ (\i -> checkInj i name tvars) injs) err
where
go = \case
TData name typs
| Right tvars' <- mapM toTVar typs ->
pure (name, tvars')
TAll _ _ -> uncatchableErr "Explicit foralls not allowed, for now"
_ -> uncatchableErr $ unwords ["Bad data type definition: ", printTree typ]
_ ->
uncatchableErr $
unwords ["Bad data type definition: ", printTree typ]
typecheckInj :: Inj -> UIdent -> [TVar] -> Infer ()
typecheckInj (Inj c inj_typ) name tvars
checkInj :: Inj -> UIdent -> [TVar] -> Infer ()
checkInj (Inj c inj_typ) name tvars
| Right False <- boundTVars tvars inj_typ =
catchableErr "Unbound type variables"
| TData name' typs <- returnType inj_typ
@ -108,109 +175,11 @@ returnType :: Type -> Type
returnType (TFun _ t2) = returnType t2
returnType a = a
checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = do
preRun bs
bs' <- checkDef bs
return $ T.Program bs'
preRun :: [Def] -> Infer ()
preRun [] = return ()
preRun (x : xs) = case x of
DSig (Sig n t) -> do
collect (collectTVars t)
gets (M.member (coerce n) . sigs)
>>= flip
when
( uncatchableErr $ Aux.do
"Duplicate signatures for function"
quote $ printTree n
)
insertSig (coerce n) (Just t) >> preRun xs
DBind (Bind n _ e) -> do
collect (collectTVars e)
s <- gets sigs
case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs
DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs
checkDef :: [Def] -> Infer [T.Def' Type]
checkDef [] = return []
checkDef (x : xs) = case x of
(DBind b) -> do
b' <- checkBind b
xs' <- checkDef xs
return $ T.DBind b' : xs'
(DData d) -> do
xs' <- checkDef xs
return $ T.DData (coerceData d) : xs'
(DSig _) -> checkDef xs
where
coerceData (Data t injs) =
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
checkBind :: Bind -> Infer (T.Bind' Type)
checkBind (Bind name args e) = do
let lambda = makeLambda e (reverse (coerce args))
(e, lambda_t) <- inferExp lambda
s <- gets sigs
case M.lookup (coerce name) s of
Just (Just t') -> do
sub1 <- unify lambda_t t'
sub2 <- unify t' lambda_t
unless
(apply sub1 lambda_t == t' && lambda_t == apply sub2 t')
( uncatchableErr $ Aux.do
"Inferred type"
quote $ printTree lambda_t
"does not match specified type"
quote $ printTree t'
)
return $ T.Bind (coerce name, t') [] (e, lambda_t)
_ -> do
insertSig (coerce name) (Just lambda_t)
return (T.Bind (coerce name, lambda_t) [] (e, lambda_t))
typeEq :: Type -> Type -> Bool
typeEq (TFun l r) (TFun l' r') = typeEq l l' && typeEq r r'
typeEq (TLit a) (TLit b) = a == b
typeEq (TData name a) (TData name' b) =
length a == length b
&& name == name'
&& and (zipWith typeEq a b)
typeEq (TAll _ t1) t2 = t1 `typeEq` t2
typeEq t1 (TAll _ t2) = t1 `typeEq` t2
typeEq (TVar _) (TVar _) = True
typeEq _ _ = False
skolemize :: Type -> Type
skolemize (TVar (MkTVar a)) = TEVar (MkTEVar $ coerce a)
skolemize (TAll x t) = TAll x (skolemize t)
skolemize (TFun t1 t2) = (TFun `on` skolemize) t1 t2
skolemize t = t
isMoreSpecificOrEq :: Type -> Type -> Bool
isMoreSpecificOrEq t1 (TAll _ t2) = isMoreSpecificOrEq t1 t2
isMoreSpecificOrEq (TFun a b) (TFun c d) =
isMoreSpecificOrEq a c && isMoreSpecificOrEq b d
isMoreSpecificOrEq (TData n1 ts1) (TData n2 ts2) =
n1 == n2
&& length ts1 == length ts2
&& and (zipWith isMoreSpecificOrEq ts1 ts2)
isMoreSpecificOrEq _ (TVar _) = True
isMoreSpecificOrEq a b = a == b
isPoly :: Type -> Bool
isPoly (TAll _ _) = True
isPoly (TVar _) = True
isPoly _ = False
inferExp :: Exp -> Infer (T.ExpT' Type)
inferExp :: Exp -> Infer (Subst, T.ExpT' Type)
inferExp e = do
(s, (e', t)) <- algoW e
let subbed = apply s t
return $ second (const subbed) (e', t)
return (s, (e', subbed))
class CollectTVars a where
collectTVars :: a -> Set T.Ident
@ -223,7 +192,8 @@ instance CollectTVars Type where
collectTVars (TVar (MkTVar i)) = S.singleton (coerce i)
collectTVars (TAll _ t) = collectTVars t
collectTVars (TFun t1 t2) = (S.union `on` collectTVars) t1 t2
collectTVars (TData _ ts) = foldl' (\acc x -> acc `S.union` collectTVars x) S.empty ts
collectTVars (TData _ ts) =
foldl' (\acc x -> acc `S.union` collectTVars x) S.empty ts
collectTVars _ = S.empty
collect :: Set T.Ident -> Infer ()
@ -232,7 +202,7 @@ collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st})
algoW :: Exp -> Infer (Subst, T.ExpT' Type)
algoW = \case
err@(EAnn e t) -> do
(s1, (e', t')) <- exprErr (algoW e) err
(sub0, (e', t')) <- exprErr (algoW e) err
sub1 <- unify t t'
sub2 <- unify t' t
unless
@ -243,8 +213,7 @@ algoW = \case
"does not match inferred type"
quote $ printTree t'
)
s2 <- exprErr (unify t t') err
let comp = s2 `compose` s1
let comp = sub2 `compose` sub1 `compose` sub0
return (comp, apply comp (e', t))
-- \| ------------------
@ -257,7 +226,9 @@ algoW = \case
EVar i -> do
var <- asks vars
case M.lookup (coerce i) var of
Just t -> inst t >>= \x -> return (nullSubst, (T.EVar $ coerce i, x))
Just t ->
inst t >>= \x ->
return (nullSubst, (T.EVar $ coerce i, x))
Nothing -> do
sig <- gets sigs
case M.lookup (coerce i) sig of
@ -266,7 +237,10 @@ algoW = \case
fr <- fresh
insertSig (coerce i) (Just fr)
return (nullSubst, (T.EVar $ coerce i, fr))
Nothing -> uncatchableErr $ "Unbound variable: " <> printTree i
Nothing ->
uncatchableErr $
"Unbound variable: "
<> printTree i
EInj i -> do
constr <- gets injections
case M.lookup (coerce i) constr of
@ -283,14 +257,11 @@ algoW = \case
err@(EAbs name e) -> do
fr <- fresh
exprErr
( withBinding (coerce name) fr $ do
(s1, (e', t')) <- exprErr (algoW e) err
let varType = apply s1 fr
let newArr = TFun varType t'
return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr))
)
err
withBinding (coerce name) fr $ do
(s1, (e', t')) <- exprErr (algoW e) err
let varType = apply s1 fr
let newArr = TFun varType t'
return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr))
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
-- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
@ -338,29 +309,120 @@ algoW = \case
(s2, (e1', t2)) <- algoW e1
let comp = s2 `compose` s1
return (comp, apply comp (T.ELet bind' (e1', t2), t2))
-- \| TODO: Add judgement
ECase caseExpr injs -> do
(sub, (e', t)) <- algoW caseExpr
(subst, injs, ret_t) <- checkCase t injs
let comp = subst `compose` sub
return (comp, apply comp (T.ECase (e', t) injs, ret_t))
makeLambda :: Exp -> [T.Ident] -> Exp
makeLambda = foldl (flip (EAbs . coerce))
checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type)
checkCase _ [] = catchableErr "Atleast one case required"
checkCase expT brnchs = do
(subs, branchTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs
let sub0 = composeAll subs
(sub1, _) <-
foldM
( \(sub, acc) x ->
(\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc
)
(nullSubst, expT)
branchTs
(sub2, returns_type) <-
foldM
( \(sub, acc) x ->
(\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc
)
(nullSubst, head returns)
(tail returns)
let comp = sub2 `compose` sub1 `compose` sub0
return (comp, apply comp injs, apply comp returns_type)
inferBranch :: Branch -> Infer (Subst, Type, T.Branch' Type, Type)
inferBranch (Branch pat expr) = do
newPat@(pat, branchT) <- inferPattern pat
(sub, newExp@(_, exprT)) <- withPattern pat (algoW expr)
return
( sub
, apply sub branchT
, T.Branch (apply sub newPat) (apply sub newExp)
, apply sub exprT
)
inferPattern :: Pattern -> Infer (T.Pattern' Type, Type)
inferPattern = \case
PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt)
PInj constr patterns -> do
t <- gets (M.lookup (coerce constr) . injections)
t <-
maybeToRightM
( Error
( Aux.do
"Constructor:"
quote $ printTree constr
"does not exist"
)
True
)
t
let numArgs = typeLength t - 1
let (vs, ret) = fromJust (unsnoc $ flattenType t)
patterns <- mapM inferPattern patterns
unless
(length patterns == numArgs)
( catchableErr $ Aux.do
"The constructor"
quote $ printTree constr
" should have "
show numArgs
" arguments but has been given "
show (length patterns)
)
sub <- composeAll <$> zipWithM unify vs (map snd patterns)
return
( T.PInj (coerce constr) (apply sub (map fst patterns))
, apply sub ret
)
PCatch -> (T.PCatch,) <$> fresh
PEnum p -> do
t <- gets (M.lookup (coerce p) . injections)
t <-
maybeToRightM
( Error
( Aux.do
"Constructor:"
quote $ printTree p
"does not exist"
)
True
)
t
unless
(typeLength t == 1)
( catchableErr $ Aux.do
"The constructor"
quote $ printTree p
" should have "
show (typeLength t - 1)
" arguments but has been given 0"
)
let (TData _data _ts) = t -- nasty nasty
frs <- mapM (const fresh) _ts
return (T.PEnum $ coerce p, TData _data frs)
PVar x -> do
fr <- fresh
let pvar = T.PVar (coerce x, fr)
return (pvar, fr)
-- | Unify two types producing a new substitution
unify :: Type -> Type -> Infer Subst
unify t0 t1 = do
unify t0 t1 =
case (t0, t1) of
(TFun a b, TFun c d) -> do
s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d)
return $ s1 `compose` s2
----------- TODO: BE CAREFUL!!!! THIS IS PROBABLY WRONG!!! -----------
(TVar (T.MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t
(t@(TData _ _), TVar (T.MkTVar b)) -> return $ M.singleton (coerce b) t
-------------------------------------------------------------------
(TVar (T.MkTVar a), t) -> occurs (coerce a) t
(t, TVar (T.MkTVar b)) -> occurs (coerce b) t
(TAll _ t, b) -> unify t b
@ -422,7 +484,12 @@ occurs i t =
)
else return $ M.singleton i t
-- | Generalize a type over all free variables in the substitution set
{- | Generalize a type over all free variables in the substitution set
Used for let bindings to allow expression that do not type check in
equivalent lambda expressions:
Type checks: let f = \x. x in (f True, f 'a')
Does not type check: (\f. (f True, f 'a')) (\x. x)
-}
generalize :: Map T.Ident Type -> Type -> Type
generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t)
where
@ -446,15 +513,27 @@ inst = \case
TFun t1 t2 -> TFun <$> inst t1 <*> inst t2
rest -> return rest
-- | Compose two substitution sets
compose :: Subst -> Subst -> Subst
compose m1 m2 = M.map (apply m1) m2 `M.union` m1
composeAll :: [Subst] -> Subst
composeAll = foldl' compose nullSubst
-- TODO: Split this class into two separate classes, one for free variables
-- and one for applying substitutions
-- | Generate a new fresh variable
fresh :: Infer Type
fresh = do
c <- gets nextChar
n <- gets count
taken <- gets takenTypeVars
if c == 'z'
then do
modify (\st -> st{count = succ (count st), nextChar = 'a'})
else modify (\st -> st{nextChar = next (nextChar st)})
if coerce [c] `S.member` taken
then do
fresh
else
if n == 0
then return . TVar . T.MkTVar $ LIdent [c]
else return . TVar . T.MkTVar . LIdent $ c : show n
where
next :: Char -> Char
next 'z' = 'a'
next a = succ a
-- | A class for substitutions
class SubstType t where
@ -468,7 +547,8 @@ class FreeVars t where
instance FreeVars Type where
free :: Type -> Set T.Ident
free (TVar (T.MkTVar a)) = S.singleton (coerce a)
free (TAll (T.MkTVar bound) t) = S.singleton (coerce bound) `S.intersection` free t
free (TAll (T.MkTVar bound) t) =
S.singleton (coerce bound) `S.intersection` free t
free (TLit _) = mempty
free (TFun a b) = free a `S.union` free b
free (TData _ a) = free a
@ -540,27 +620,19 @@ instance SubstType (T.Id' Type) where
nullSubst :: Subst
nullSubst = M.empty
-- | Generate a new fresh variable and increment the state counter
fresh :: Infer Type
fresh = do
c <- gets nextChar
n <- gets count
taken <- gets takenTypeVars
if c == 'z'
then do
modify (\st -> st{count = succ (count st), nextChar = 'a'})
else modify (\st -> st{nextChar = next (nextChar st)})
if coerce [c] `S.member` taken
then do
fresh
else
if n == 0
then return . TVar . T.MkTVar $ LIdent [c]
else return . TVar . T.MkTVar . LIdent $ c : show n
where
next :: Char -> Char
next 'z' = 'a'
next a = succ a
-- | Compose two substitution sets
compose :: Subst -> Subst -> Subst
compose m1 m2 = M.map (apply m1) m2 `M.union` m1
-- | Compose a list of substitution sets into one
composeAll :: [Subst] -> Subst
composeAll = foldl' compose nullSubst
{- | Convert a function with arguments to its pointfree version
> makeLambda (add x y = x + y) = add = \x. \y. x + y
-}
makeLambda :: Exp -> [T.Ident] -> Exp
makeLambda = foldl (flip (EAbs . coerce))
-- | Run the monadic action with an additional binding
withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> Type -> m a -> m a
@ -571,49 +643,8 @@ withBindings :: (Monad m, MonadReader Ctx m) => [(T.Ident, Type)] -> m a -> m a
withBindings xs =
local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs})
-- | Insert a function signature into the environment
insertSig :: T.Ident -> Maybe Type -> Infer ()
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
-- | Insert a constructor with its data type
insertInj :: T.Ident -> Type -> Infer ()
insertInj i t =
modify (\st -> st{injections = M.insert i t (injections st)})
existInj :: T.Ident -> Infer (Maybe Type)
existInj n = gets (M.lookup n . injections)
-------- PATTERN MATCHING ---------
checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type)
checkCase _ [] = catchableErr "Atleast one case required"
checkCase expT brnchs = do
(subs, injTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs
let sub0 = composeAll subs
(sub1, _) <-
foldM
( \(sub, acc) x ->
(\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc
)
(nullSubst, expT)
injTs
(sub2, returns_type) <-
foldM
( \(sub, acc) x ->
(\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc
)
(nullSubst, head returns)
(tail returns)
let comp = sub2 `compose` sub1 `compose` sub0
return (comp, apply comp injs, apply comp returns_type)
inferBranch :: Branch -> Infer (Subst, Type, T.Branch' Type, Type)
inferBranch (Branch pat expr) = do
newPat@(pat, branchT) <- inferPattern pat
(sub, newExp@(_, exprT)) <- withPattern pat (algoW expr)
return (sub, apply sub branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT)
withPattern :: T.Pattern' Type -> Infer a -> Infer a
-- | Run the monadic action with a pattern
withPattern :: (Monad m, MonadReader Ctx m) => T.Pattern' Type -> m a -> m a
withPattern p ma = case p of
T.PVar (x, t) -> withBinding x t ma
T.PInj _ ps -> foldl' (flip withPattern) ma ps
@ -621,74 +652,27 @@ withPattern p ma = case p of
T.PCatch -> ma
T.PEnum _ -> ma
inferPattern :: Pattern -> Infer (T.Pattern' Type, Type)
inferPattern = \case
PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt)
PInj constr patterns -> do
t <- gets (M.lookup (coerce constr) . injections)
t <-
maybeToRightM
( Error
( Aux.do
"Constructor:"
quote $ printTree constr
"does not exist"
)
True
)
t
let numArgs = typeLength t - 1
let (vs, ret) = fromJust (unsnoc $ flattenType t)
patterns <- mapM inferPattern patterns
unless
(length patterns == numArgs)
( catchableErr $ Aux.do
"The constructor"
quote $ printTree constr
" should have "
show numArgs
" arguments but has been given "
show (length patterns)
)
sub <- composeAll <$> zipWithM unify vs (map snd patterns)
return (T.PInj (coerce constr) (apply sub (map fst patterns)), apply sub ret)
PCatch -> (T.PCatch,) <$> fresh
PEnum p -> do
t <- gets (M.lookup (coerce p) . injections)
t <-
maybeToRightM
( Error
( Aux.do
"Constructor:"
quote $ printTree p
"does not exist"
)
True
)
t
unless
(typeLength t == 1)
( catchableErr $ Aux.do
"The constructor"
quote $ printTree p
" should have "
show (typeLength t - 1)
" arguments but has been given 0"
)
let (TData _data _ts) = t -- nasty nasty
frs <- mapM (const fresh) _ts
return (T.PEnum $ coerce p, TData _data frs)
PVar x -> do
fr <- fresh
let pvar = T.PVar (coerce x, fr)
return (pvar, fr)
-- | Insert a function signature into the environment
insertSig :: T.Ident -> Maybe Type -> Infer ()
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
-- | Insert a constructor into the start with its type
insertInj :: T.Ident -> Type -> Infer ()
insertInj i t =
modify (\st -> st{injections = M.insert i t (injections st)})
{- | Check if an injection (constructor of data type)
with an equivalent name has been declared already
-}
existInj :: T.Ident -> Infer (Maybe Type)
existInj n = gets (M.lookup n . injections)
flattenType :: Type -> [Type]
flattenType (TFun a b) = flattenType a <> flattenType b
flattenType a = [a]
typeLength :: Type -> Int
typeLength (TFun a b) = typeLength a + typeLength b
typeLength (TFun _ b) = 1 + typeLength b
typeLength _ = 1
litType :: Lit -> Type
@ -698,23 +682,63 @@ litType (LChar _) = char
int = TLit "Int"
char = TLit "Char"
partitionType ::
Int -> -- Number of parameters to apply
Type ->
([Type], Type)
partitionType = go []
where
go acc 0 t = (acc, t)
go acc i t = case t of
TAll tvar t' -> second (TAll tvar) $ go acc i t'
TFun t1 t2 -> go (acc <> [t1]) (i - 1) t2
_ -> error "Number of parameters and type doesn't match"
typeEq :: Type -> Type -> StateT Subst (ExceptT Error Identity) ()
typeEq (TVar (T.MkTVar a)) t@(TVar _) = do
st <- get
case M.lookup (coerce a) st of
Nothing -> put $ M.insert (coerce a) t st
Just t' -> unless (t == t') (catchableErr "TYPE MISMATCH")
typeEq (TFun l r) (TFun l' r') = typeEq l l' *> typeEq r r'
typeEq (TAll _ l) (TAll _ r) = typeEq l r
typeEq (TLit a) (TLit b) = unless (a == b) (catchableErr "TYPE MISMATCH")
typeEq (TData nameL tL) (TData nameR tR) = do
unless (nameL == nameR) (catchableErr "TYPE MISMATCH")
zipWithM_ typeEq tL tR
typeEq (TEVar _) (TEVar _) = catchableErr "TYPE MISMATCH"
typeEq _ _ = catchableErr "TYPE MISMATCH"
exprErr :: Infer a -> Exp -> Infer a
exprErr ma exp = catchError ma (\x -> if x.catchable then throwError (x{msg = x.msg <> " in expression: \n" <> printTree exp, catchable = False}) else throwError x)
{- | Catch an error if possible and add the given
expression as addition to the error message
-}
exprErr :: (Monad m, MonadError Error m) => m a -> Exp -> m a
exprErr ma exp =
catchError
ma
( \x ->
if x.catchable
then
throwError
( x
{ msg =
x.msg
<> " in expression: \n"
<> printTree exp
, catchable = False
}
)
else throwError x
)
{- | Catch an error if possible and add the given
data as addition to the error message
-}
dataErr :: Infer a -> Data -> Infer a
dataErr ma d = catchError ma (\x -> if x.catchable then throwError (x{msg = x.msg <> " in data: \n" <> printTree d}) else throwError (x{catchable = False}))
dataErr ma d =
catchError
ma
( \x ->
if x.catchable
then
throwError
( x
{ msg =
x.msg
<> " in data: \n"
<> printTree d
}
)
else throwError (x{catchable = False})
)
unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d])
unzip4 =
@ -737,6 +761,7 @@ data Env = Env
deriving (Show)
data Error = Error {msg :: String, catchable :: Bool}
deriving (Show)
type Subst = Map T.Ident Type
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a}

View file

@ -187,6 +187,31 @@ bes =
" Nil => 0 ;"
" };"
)
, testBe
"length function on int list infers correct signature"
( D.do
"data List () where {"
" Nil : List ()"
" Cons : Int -> List () -> List ()"
"};"
"length xs = case xs of {"
" Nil => 0 ;"
" Cons _ xs => 1 + length xs ;"
"};"
)
( D.do
"data List () where {"
" Nil : List ()"
" Cons : Int -> List () -> List ()"
"};"
"length : List () -> Int ;"
"length xs = case xs of {"
" Nil => 0 ;"
" Cons _ xs => 1 + length xs ;"
"};"
)
]
testSatisfy desc test satisfaction = specify desc $ run test `shouldSatisfy` satisfaction