Fixed previously incorrect type equality check, commented code, add test

This commit is contained in:
sebastianselander 2023-03-28 14:31:20 +02:00
parent 85f31b129b
commit b1d3e31efd
3 changed files with 335 additions and 311 deletions

View file

@ -27,38 +27,12 @@ Program below should not type check
main : a -> b ; main : a -> b ;
main x = x; main x = x;
``` ```
## Pattern match on functions
Program below should not type check
## Bugged error message
```hs ```hs
data Maybe () where { main = case \x. x of {
Nothing : Maybe _ => 0;
Just : Int -> Maybe
};
fmap : (Int -> Int) -> Maybe -> Maybe ;
fmap f ma = case ma of {
Nothing => Nothing ;
Just a => Just (f a) ;
};
pure : Int -> Maybe ;
pure x = Just x ;
ap mf ma = case mf of {
Just f => case ma of {
Nothing => Nothing;
Just a => Just (f a);
};
Nothing => Nothing;
};
return = pure;
bind ma f = case ma of {
Nothing => Nothing ;
Just a => f a ;
}; };
``` ```
```
TYPECHECKER ERROR
Inferred type '("c" -> "Int") -> "Maybe" -> "Maybe" does not match specified type '("Int" -> "Int") -> "Maybe" -> "Maybe"'

View file

@ -12,7 +12,6 @@ import Control.Monad.Except
import Control.Monad.Identity (Identity, runIdentity) import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader import Control.Monad.Reader
import Control.Monad.State import Control.Monad.State
import Data.Bifunctor (second)
import Data.Coerce (coerce) import Data.Coerce (coerce)
import Data.Function (on) import Data.Function (on)
import Data.List (foldl') import Data.List (foldl')
@ -30,16 +29,17 @@ initCtx = Ctx mempty
initEnv = Env 0 'a' mempty mempty mempty initEnv = Env 0 'a' mempty mempty mempty
run :: Infer a -> Either Error a run :: Infer a -> Either Error a
run = runC initEnv initCtx run = run' initEnv initCtx
runC :: Env -> Ctx -> Infer a -> Either Error a run' :: Env -> Ctx -> Infer a -> Either Error a
runC e c = run' e c =
runIdentity runIdentity
. runExceptT . runExceptT
. flip runReaderT c . flip runReaderT c
. flip evalStateT e . flip evalStateT e
. runInfer . runInfer
-- | Type check a program
typecheck :: Program -> Either String (T.Program' Type) typecheck :: Program -> Either String (T.Program' Type)
typecheck = onLeft msg . run . checkPrg typecheck = onLeft msg . run . checkPrg
where where
@ -47,20 +47,87 @@ typecheck = onLeft msg . run . checkPrg
onLeft f (Left x) = Left $ f x onLeft f (Left x) = Left $ f x
onLeft _ (Right x) = Right x onLeft _ (Right x) = Right x
checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = do
preRun bs
bs' <- checkDef bs
return $ T.Program bs'
preRun :: [Def] -> Infer ()
preRun [] = return ()
preRun (x : xs) = case x of
DSig (Sig n t) -> do
collect (collectTVars t)
gets (M.member (coerce n) . sigs)
>>= flip
when
( uncatchableErr $ Aux.do
"Duplicate signatures for function"
quote $ printTree n
)
insertSig (coerce n) (Just t) >> preRun xs
DBind (Bind n _ e) -> do
collect (collectTVars e)
s <- gets sigs
case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs
DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs
checkDef :: [Def] -> Infer [T.Def' Type]
checkDef [] = return []
checkDef (x : xs) = case x of
(DBind b) -> do
b' <- checkBind b
xs' <- checkDef xs
return $ T.DBind b' : xs'
(DData d) -> do
xs' <- checkDef xs
return $ T.DData (coerceData d) : xs'
(DSig _) -> checkDef xs
where
coerceData (Data t injs) =
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
checkBind :: Bind -> Infer (T.Bind' Type)
checkBind (Bind name args e) = do
let lambda = makeLambda e (reverse (coerce args))
(sub0, (e, lambda_t)) <- inferExp lambda
s <- gets sigs
case M.lookup (coerce name) s of
Just (Just t') -> do
let fsig = apply sub0 t'
sub1 <- liftEither $ runIdentity $ runExceptT $ execStateT (typeEq fsig lambda_t) mempty
sub2 <- liftEither $ runIdentity $ runExceptT $ execStateT (typeEq lambda_t fsig) mempty
unless
(lambda_t == apply sub1 fsig && apply sub2 lambda_t == fsig)
( uncatchableErr $ Aux.do
"Inferred type"
quote $ printTree lambda_t
"does not match specified type"
quote $ printTree t'
)
return $ T.Bind (coerce name, lambda_t) [] (e, lambda_t)
_ -> do
insertSig (coerce name) (Just lambda_t)
return (T.Bind (coerce name, lambda_t) [] (e, lambda_t))
checkData :: Data -> Infer () checkData :: Data -> Infer ()
checkData err@(Data typ injs) = do checkData err@(Data typ injs) = do
(name, tvars) <- go typ (name, tvars) <- go typ
dataErr (mapM_ (\i -> typecheckInj i name tvars) injs) err dataErr (mapM_ (\i -> checkInj i name tvars) injs) err
where where
go = \case go = \case
TData name typs TData name typs
| Right tvars' <- mapM toTVar typs -> | Right tvars' <- mapM toTVar typs ->
pure (name, tvars') pure (name, tvars')
TAll _ _ -> uncatchableErr "Explicit foralls not allowed, for now" TAll _ _ -> uncatchableErr "Explicit foralls not allowed, for now"
_ -> uncatchableErr $ unwords ["Bad data type definition: ", printTree typ] _ ->
uncatchableErr $
unwords ["Bad data type definition: ", printTree typ]
typecheckInj :: Inj -> UIdent -> [TVar] -> Infer () checkInj :: Inj -> UIdent -> [TVar] -> Infer ()
typecheckInj (Inj c inj_typ) name tvars checkInj (Inj c inj_typ) name tvars
| Right False <- boundTVars tvars inj_typ = | Right False <- boundTVars tvars inj_typ =
catchableErr "Unbound type variables" catchableErr "Unbound type variables"
| TData name' typs <- returnType inj_typ | TData name' typs <- returnType inj_typ
@ -108,109 +175,11 @@ returnType :: Type -> Type
returnType (TFun _ t2) = returnType t2 returnType (TFun _ t2) = returnType t2
returnType a = a returnType a = a
checkPrg :: Program -> Infer (T.Program' Type) inferExp :: Exp -> Infer (Subst, T.ExpT' Type)
checkPrg (Program bs) = do
preRun bs
bs' <- checkDef bs
return $ T.Program bs'
preRun :: [Def] -> Infer ()
preRun [] = return ()
preRun (x : xs) = case x of
DSig (Sig n t) -> do
collect (collectTVars t)
gets (M.member (coerce n) . sigs)
>>= flip
when
( uncatchableErr $ Aux.do
"Duplicate signatures for function"
quote $ printTree n
)
insertSig (coerce n) (Just t) >> preRun xs
DBind (Bind n _ e) -> do
collect (collectTVars e)
s <- gets sigs
case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs
DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs
checkDef :: [Def] -> Infer [T.Def' Type]
checkDef [] = return []
checkDef (x : xs) = case x of
(DBind b) -> do
b' <- checkBind b
xs' <- checkDef xs
return $ T.DBind b' : xs'
(DData d) -> do
xs' <- checkDef xs
return $ T.DData (coerceData d) : xs'
(DSig _) -> checkDef xs
where
coerceData (Data t injs) =
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
checkBind :: Bind -> Infer (T.Bind' Type)
checkBind (Bind name args e) = do
let lambda = makeLambda e (reverse (coerce args))
(e, lambda_t) <- inferExp lambda
s <- gets sigs
case M.lookup (coerce name) s of
Just (Just t') -> do
sub1 <- unify lambda_t t'
sub2 <- unify t' lambda_t
unless
(apply sub1 lambda_t == t' && lambda_t == apply sub2 t')
( uncatchableErr $ Aux.do
"Inferred type"
quote $ printTree lambda_t
"does not match specified type"
quote $ printTree t'
)
return $ T.Bind (coerce name, t') [] (e, lambda_t)
_ -> do
insertSig (coerce name) (Just lambda_t)
return (T.Bind (coerce name, lambda_t) [] (e, lambda_t))
typeEq :: Type -> Type -> Bool
typeEq (TFun l r) (TFun l' r') = typeEq l l' && typeEq r r'
typeEq (TLit a) (TLit b) = a == b
typeEq (TData name a) (TData name' b) =
length a == length b
&& name == name'
&& and (zipWith typeEq a b)
typeEq (TAll _ t1) t2 = t1 `typeEq` t2
typeEq t1 (TAll _ t2) = t1 `typeEq` t2
typeEq (TVar _) (TVar _) = True
typeEq _ _ = False
skolemize :: Type -> Type
skolemize (TVar (MkTVar a)) = TEVar (MkTEVar $ coerce a)
skolemize (TAll x t) = TAll x (skolemize t)
skolemize (TFun t1 t2) = (TFun `on` skolemize) t1 t2
skolemize t = t
isMoreSpecificOrEq :: Type -> Type -> Bool
isMoreSpecificOrEq t1 (TAll _ t2) = isMoreSpecificOrEq t1 t2
isMoreSpecificOrEq (TFun a b) (TFun c d) =
isMoreSpecificOrEq a c && isMoreSpecificOrEq b d
isMoreSpecificOrEq (TData n1 ts1) (TData n2 ts2) =
n1 == n2
&& length ts1 == length ts2
&& and (zipWith isMoreSpecificOrEq ts1 ts2)
isMoreSpecificOrEq _ (TVar _) = True
isMoreSpecificOrEq a b = a == b
isPoly :: Type -> Bool
isPoly (TAll _ _) = True
isPoly (TVar _) = True
isPoly _ = False
inferExp :: Exp -> Infer (T.ExpT' Type)
inferExp e = do inferExp e = do
(s, (e', t)) <- algoW e (s, (e', t)) <- algoW e
let subbed = apply s t let subbed = apply s t
return $ second (const subbed) (e', t) return (s, (e', subbed))
class CollectTVars a where class CollectTVars a where
collectTVars :: a -> Set T.Ident collectTVars :: a -> Set T.Ident
@ -223,7 +192,8 @@ instance CollectTVars Type where
collectTVars (TVar (MkTVar i)) = S.singleton (coerce i) collectTVars (TVar (MkTVar i)) = S.singleton (coerce i)
collectTVars (TAll _ t) = collectTVars t collectTVars (TAll _ t) = collectTVars t
collectTVars (TFun t1 t2) = (S.union `on` collectTVars) t1 t2 collectTVars (TFun t1 t2) = (S.union `on` collectTVars) t1 t2
collectTVars (TData _ ts) = foldl' (\acc x -> acc `S.union` collectTVars x) S.empty ts collectTVars (TData _ ts) =
foldl' (\acc x -> acc `S.union` collectTVars x) S.empty ts
collectTVars _ = S.empty collectTVars _ = S.empty
collect :: Set T.Ident -> Infer () collect :: Set T.Ident -> Infer ()
@ -232,7 +202,7 @@ collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st})
algoW :: Exp -> Infer (Subst, T.ExpT' Type) algoW :: Exp -> Infer (Subst, T.ExpT' Type)
algoW = \case algoW = \case
err@(EAnn e t) -> do err@(EAnn e t) -> do
(s1, (e', t')) <- exprErr (algoW e) err (sub0, (e', t')) <- exprErr (algoW e) err
sub1 <- unify t t' sub1 <- unify t t'
sub2 <- unify t' t sub2 <- unify t' t
unless unless
@ -243,8 +213,7 @@ algoW = \case
"does not match inferred type" "does not match inferred type"
quote $ printTree t' quote $ printTree t'
) )
s2 <- exprErr (unify t t') err let comp = sub2 `compose` sub1 `compose` sub0
let comp = s2 `compose` s1
return (comp, apply comp (e', t)) return (comp, apply comp (e', t))
-- \| ------------------ -- \| ------------------
@ -257,7 +226,9 @@ algoW = \case
EVar i -> do EVar i -> do
var <- asks vars var <- asks vars
case M.lookup (coerce i) var of case M.lookup (coerce i) var of
Just t -> inst t >>= \x -> return (nullSubst, (T.EVar $ coerce i, x)) Just t ->
inst t >>= \x ->
return (nullSubst, (T.EVar $ coerce i, x))
Nothing -> do Nothing -> do
sig <- gets sigs sig <- gets sigs
case M.lookup (coerce i) sig of case M.lookup (coerce i) sig of
@ -266,7 +237,10 @@ algoW = \case
fr <- fresh fr <- fresh
insertSig (coerce i) (Just fr) insertSig (coerce i) (Just fr)
return (nullSubst, (T.EVar $ coerce i, fr)) return (nullSubst, (T.EVar $ coerce i, fr))
Nothing -> uncatchableErr $ "Unbound variable: " <> printTree i Nothing ->
uncatchableErr $
"Unbound variable: "
<> printTree i
EInj i -> do EInj i -> do
constr <- gets injections constr <- gets injections
case M.lookup (coerce i) constr of case M.lookup (coerce i) constr of
@ -283,14 +257,11 @@ algoW = \case
err@(EAbs name e) -> do err@(EAbs name e) -> do
fr <- fresh fr <- fresh
exprErr withBinding (coerce name) fr $ do
( withBinding (coerce name) fr $ do
(s1, (e', t')) <- exprErr (algoW e) err (s1, (e', t')) <- exprErr (algoW e) err
let varType = apply s1 fr let varType = apply s1 fr
let newArr = TFun varType t' let newArr = TFun varType t'
return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr)) return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr))
)
err
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁ -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
-- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int) -- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
@ -338,29 +309,120 @@ algoW = \case
(s2, (e1', t2)) <- algoW e1 (s2, (e1', t2)) <- algoW e1
let comp = s2 `compose` s1 let comp = s2 `compose` s1
return (comp, apply comp (T.ELet bind' (e1', t2), t2)) return (comp, apply comp (T.ELet bind' (e1', t2), t2))
-- \| TODO: Add judgement
ECase caseExpr injs -> do ECase caseExpr injs -> do
(sub, (e', t)) <- algoW caseExpr (sub, (e', t)) <- algoW caseExpr
(subst, injs, ret_t) <- checkCase t injs (subst, injs, ret_t) <- checkCase t injs
let comp = subst `compose` sub let comp = subst `compose` sub
return (comp, apply comp (T.ECase (e', t) injs, ret_t)) return (comp, apply comp (T.ECase (e', t) injs, ret_t))
makeLambda :: Exp -> [T.Ident] -> Exp checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type)
makeLambda = foldl (flip (EAbs . coerce)) checkCase _ [] = catchableErr "Atleast one case required"
checkCase expT brnchs = do
(subs, branchTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs
let sub0 = composeAll subs
(sub1, _) <-
foldM
( \(sub, acc) x ->
(\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc
)
(nullSubst, expT)
branchTs
(sub2, returns_type) <-
foldM
( \(sub, acc) x ->
(\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc
)
(nullSubst, head returns)
(tail returns)
let comp = sub2 `compose` sub1 `compose` sub0
return (comp, apply comp injs, apply comp returns_type)
inferBranch :: Branch -> Infer (Subst, Type, T.Branch' Type, Type)
inferBranch (Branch pat expr) = do
newPat@(pat, branchT) <- inferPattern pat
(sub, newExp@(_, exprT)) <- withPattern pat (algoW expr)
return
( sub
, apply sub branchT
, T.Branch (apply sub newPat) (apply sub newExp)
, apply sub exprT
)
inferPattern :: Pattern -> Infer (T.Pattern' Type, Type)
inferPattern = \case
PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt)
PInj constr patterns -> do
t <- gets (M.lookup (coerce constr) . injections)
t <-
maybeToRightM
( Error
( Aux.do
"Constructor:"
quote $ printTree constr
"does not exist"
)
True
)
t
let numArgs = typeLength t - 1
let (vs, ret) = fromJust (unsnoc $ flattenType t)
patterns <- mapM inferPattern patterns
unless
(length patterns == numArgs)
( catchableErr $ Aux.do
"The constructor"
quote $ printTree constr
" should have "
show numArgs
" arguments but has been given "
show (length patterns)
)
sub <- composeAll <$> zipWithM unify vs (map snd patterns)
return
( T.PInj (coerce constr) (apply sub (map fst patterns))
, apply sub ret
)
PCatch -> (T.PCatch,) <$> fresh
PEnum p -> do
t <- gets (M.lookup (coerce p) . injections)
t <-
maybeToRightM
( Error
( Aux.do
"Constructor:"
quote $ printTree p
"does not exist"
)
True
)
t
unless
(typeLength t == 1)
( catchableErr $ Aux.do
"The constructor"
quote $ printTree p
" should have "
show (typeLength t - 1)
" arguments but has been given 0"
)
let (TData _data _ts) = t -- nasty nasty
frs <- mapM (const fresh) _ts
return (T.PEnum $ coerce p, TData _data frs)
PVar x -> do
fr <- fresh
let pvar = T.PVar (coerce x, fr)
return (pvar, fr)
-- | Unify two types producing a new substitution -- | Unify two types producing a new substitution
unify :: Type -> Type -> Infer Subst unify :: Type -> Type -> Infer Subst
unify t0 t1 = do unify t0 t1 =
case (t0, t1) of case (t0, t1) of
(TFun a b, TFun c d) -> do (TFun a b, TFun c d) -> do
s1 <- unify a c s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d) s2 <- unify (apply s1 b) (apply s1 d)
return $ s1 `compose` s2 return $ s1 `compose` s2
----------- TODO: BE CAREFUL!!!! THIS IS PROBABLY WRONG!!! -----------
(TVar (T.MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t (TVar (T.MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t
(t@(TData _ _), TVar (T.MkTVar b)) -> return $ M.singleton (coerce b) t (t@(TData _ _), TVar (T.MkTVar b)) -> return $ M.singleton (coerce b) t
-------------------------------------------------------------------
(TVar (T.MkTVar a), t) -> occurs (coerce a) t (TVar (T.MkTVar a), t) -> occurs (coerce a) t
(t, TVar (T.MkTVar b)) -> occurs (coerce b) t (t, TVar (T.MkTVar b)) -> occurs (coerce b) t
(TAll _ t, b) -> unify t b (TAll _ t, b) -> unify t b
@ -422,7 +484,12 @@ occurs i t =
) )
else return $ M.singleton i t else return $ M.singleton i t
-- | Generalize a type over all free variables in the substitution set {- | Generalize a type over all free variables in the substitution set
Used for let bindings to allow expression that do not type check in
equivalent lambda expressions:
Type checks: let f = \x. x in (f True, f 'a')
Does not type check: (\f. (f True, f 'a')) (\x. x)
-}
generalize :: Map T.Ident Type -> Type -> Type generalize :: Map T.Ident Type -> Type -> Type
generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t) generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t)
where where
@ -446,15 +513,27 @@ inst = \case
TFun t1 t2 -> TFun <$> inst t1 <*> inst t2 TFun t1 t2 -> TFun <$> inst t1 <*> inst t2
rest -> return rest rest -> return rest
-- | Compose two substitution sets -- | Generate a new fresh variable
compose :: Subst -> Subst -> Subst fresh :: Infer Type
compose m1 m2 = M.map (apply m1) m2 `M.union` m1 fresh = do
c <- gets nextChar
composeAll :: [Subst] -> Subst n <- gets count
composeAll = foldl' compose nullSubst taken <- gets takenTypeVars
if c == 'z'
-- TODO: Split this class into two separate classes, one for free variables then do
-- and one for applying substitutions modify (\st -> st{count = succ (count st), nextChar = 'a'})
else modify (\st -> st{nextChar = next (nextChar st)})
if coerce [c] `S.member` taken
then do
fresh
else
if n == 0
then return . TVar . T.MkTVar $ LIdent [c]
else return . TVar . T.MkTVar . LIdent $ c : show n
where
next :: Char -> Char
next 'z' = 'a'
next a = succ a
-- | A class for substitutions -- | A class for substitutions
class SubstType t where class SubstType t where
@ -468,7 +547,8 @@ class FreeVars t where
instance FreeVars Type where instance FreeVars Type where
free :: Type -> Set T.Ident free :: Type -> Set T.Ident
free (TVar (T.MkTVar a)) = S.singleton (coerce a) free (TVar (T.MkTVar a)) = S.singleton (coerce a)
free (TAll (T.MkTVar bound) t) = S.singleton (coerce bound) `S.intersection` free t free (TAll (T.MkTVar bound) t) =
S.singleton (coerce bound) `S.intersection` free t
free (TLit _) = mempty free (TLit _) = mempty
free (TFun a b) = free a `S.union` free b free (TFun a b) = free a `S.union` free b
free (TData _ a) = free a free (TData _ a) = free a
@ -540,27 +620,19 @@ instance SubstType (T.Id' Type) where
nullSubst :: Subst nullSubst :: Subst
nullSubst = M.empty nullSubst = M.empty
-- | Generate a new fresh variable and increment the state counter -- | Compose two substitution sets
fresh :: Infer Type compose :: Subst -> Subst -> Subst
fresh = do compose m1 m2 = M.map (apply m1) m2 `M.union` m1
c <- gets nextChar
n <- gets count -- | Compose a list of substitution sets into one
taken <- gets takenTypeVars composeAll :: [Subst] -> Subst
if c == 'z' composeAll = foldl' compose nullSubst
then do
modify (\st -> st{count = succ (count st), nextChar = 'a'}) {- | Convert a function with arguments to its pointfree version
else modify (\st -> st{nextChar = next (nextChar st)}) > makeLambda (add x y = x + y) = add = \x. \y. x + y
if coerce [c] `S.member` taken -}
then do makeLambda :: Exp -> [T.Ident] -> Exp
fresh makeLambda = foldl (flip (EAbs . coerce))
else
if n == 0
then return . TVar . T.MkTVar $ LIdent [c]
else return . TVar . T.MkTVar . LIdent $ c : show n
where
next :: Char -> Char
next 'z' = 'a'
next a = succ a
-- | Run the monadic action with an additional binding -- | Run the monadic action with an additional binding
withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> Type -> m a -> m a withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> Type -> m a -> m a
@ -571,49 +643,8 @@ withBindings :: (Monad m, MonadReader Ctx m) => [(T.Ident, Type)] -> m a -> m a
withBindings xs = withBindings xs =
local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs}) local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs})
-- | Insert a function signature into the environment -- | Run the monadic action with a pattern
insertSig :: T.Ident -> Maybe Type -> Infer () withPattern :: (Monad m, MonadReader Ctx m) => T.Pattern' Type -> m a -> m a
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
-- | Insert a constructor with its data type
insertInj :: T.Ident -> Type -> Infer ()
insertInj i t =
modify (\st -> st{injections = M.insert i t (injections st)})
existInj :: T.Ident -> Infer (Maybe Type)
existInj n = gets (M.lookup n . injections)
-------- PATTERN MATCHING ---------
checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type)
checkCase _ [] = catchableErr "Atleast one case required"
checkCase expT brnchs = do
(subs, injTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs
let sub0 = composeAll subs
(sub1, _) <-
foldM
( \(sub, acc) x ->
(\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc
)
(nullSubst, expT)
injTs
(sub2, returns_type) <-
foldM
( \(sub, acc) x ->
(\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc
)
(nullSubst, head returns)
(tail returns)
let comp = sub2 `compose` sub1 `compose` sub0
return (comp, apply comp injs, apply comp returns_type)
inferBranch :: Branch -> Infer (Subst, Type, T.Branch' Type, Type)
inferBranch (Branch pat expr) = do
newPat@(pat, branchT) <- inferPattern pat
(sub, newExp@(_, exprT)) <- withPattern pat (algoW expr)
return (sub, apply sub branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT)
withPattern :: T.Pattern' Type -> Infer a -> Infer a
withPattern p ma = case p of withPattern p ma = case p of
T.PVar (x, t) -> withBinding x t ma T.PVar (x, t) -> withBinding x t ma
T.PInj _ ps -> foldl' (flip withPattern) ma ps T.PInj _ ps -> foldl' (flip withPattern) ma ps
@ -621,74 +652,27 @@ withPattern p ma = case p of
T.PCatch -> ma T.PCatch -> ma
T.PEnum _ -> ma T.PEnum _ -> ma
inferPattern :: Pattern -> Infer (T.Pattern' Type, Type) -- | Insert a function signature into the environment
inferPattern = \case insertSig :: T.Ident -> Maybe Type -> Infer ()
PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt) insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
PInj constr patterns -> do
t <- gets (M.lookup (coerce constr) . injections) -- | Insert a constructor into the start with its type
t <- insertInj :: T.Ident -> Type -> Infer ()
maybeToRightM insertInj i t =
( Error modify (\st -> st{injections = M.insert i t (injections st)})
( Aux.do
"Constructor:" {- | Check if an injection (constructor of data type)
quote $ printTree constr with an equivalent name has been declared already
"does not exist" -}
) existInj :: T.Ident -> Infer (Maybe Type)
True existInj n = gets (M.lookup n . injections)
)
t
let numArgs = typeLength t - 1
let (vs, ret) = fromJust (unsnoc $ flattenType t)
patterns <- mapM inferPattern patterns
unless
(length patterns == numArgs)
( catchableErr $ Aux.do
"The constructor"
quote $ printTree constr
" should have "
show numArgs
" arguments but has been given "
show (length patterns)
)
sub <- composeAll <$> zipWithM unify vs (map snd patterns)
return (T.PInj (coerce constr) (apply sub (map fst patterns)), apply sub ret)
PCatch -> (T.PCatch,) <$> fresh
PEnum p -> do
t <- gets (M.lookup (coerce p) . injections)
t <-
maybeToRightM
( Error
( Aux.do
"Constructor:"
quote $ printTree p
"does not exist"
)
True
)
t
unless
(typeLength t == 1)
( catchableErr $ Aux.do
"The constructor"
quote $ printTree p
" should have "
show (typeLength t - 1)
" arguments but has been given 0"
)
let (TData _data _ts) = t -- nasty nasty
frs <- mapM (const fresh) _ts
return (T.PEnum $ coerce p, TData _data frs)
PVar x -> do
fr <- fresh
let pvar = T.PVar (coerce x, fr)
return (pvar, fr)
flattenType :: Type -> [Type] flattenType :: Type -> [Type]
flattenType (TFun a b) = flattenType a <> flattenType b flattenType (TFun a b) = flattenType a <> flattenType b
flattenType a = [a] flattenType a = [a]
typeLength :: Type -> Int typeLength :: Type -> Int
typeLength (TFun a b) = typeLength a + typeLength b typeLength (TFun _ b) = 1 + typeLength b
typeLength _ = 1 typeLength _ = 1
litType :: Lit -> Type litType :: Lit -> Type
@ -698,23 +682,63 @@ litType (LChar _) = char
int = TLit "Int" int = TLit "Int"
char = TLit "Char" char = TLit "Char"
partitionType :: typeEq :: Type -> Type -> StateT Subst (ExceptT Error Identity) ()
Int -> -- Number of parameters to apply typeEq (TVar (T.MkTVar a)) t@(TVar _) = do
Type -> st <- get
([Type], Type) case M.lookup (coerce a) st of
partitionType = go [] Nothing -> put $ M.insert (coerce a) t st
where Just t' -> unless (t == t') (catchableErr "TYPE MISMATCH")
go acc 0 t = (acc, t) typeEq (TFun l r) (TFun l' r') = typeEq l l' *> typeEq r r'
go acc i t = case t of typeEq (TAll _ l) (TAll _ r) = typeEq l r
TAll tvar t' -> second (TAll tvar) $ go acc i t' typeEq (TLit a) (TLit b) = unless (a == b) (catchableErr "TYPE MISMATCH")
TFun t1 t2 -> go (acc <> [t1]) (i - 1) t2 typeEq (TData nameL tL) (TData nameR tR) = do
_ -> error "Number of parameters and type doesn't match" unless (nameL == nameR) (catchableErr "TYPE MISMATCH")
zipWithM_ typeEq tL tR
typeEq (TEVar _) (TEVar _) = catchableErr "TYPE MISMATCH"
typeEq _ _ = catchableErr "TYPE MISMATCH"
exprErr :: Infer a -> Exp -> Infer a {- | Catch an error if possible and add the given
exprErr ma exp = catchError ma (\x -> if x.catchable then throwError (x{msg = x.msg <> " in expression: \n" <> printTree exp, catchable = False}) else throwError x) expression as addition to the error message
-}
exprErr :: (Monad m, MonadError Error m) => m a -> Exp -> m a
exprErr ma exp =
catchError
ma
( \x ->
if x.catchable
then
throwError
( x
{ msg =
x.msg
<> " in expression: \n"
<> printTree exp
, catchable = False
}
)
else throwError x
)
{- | Catch an error if possible and add the given
data as addition to the error message
-}
dataErr :: Infer a -> Data -> Infer a dataErr :: Infer a -> Data -> Infer a
dataErr ma d = catchError ma (\x -> if x.catchable then throwError (x{msg = x.msg <> " in data: \n" <> printTree d}) else throwError (x{catchable = False})) dataErr ma d =
catchError
ma
( \x ->
if x.catchable
then
throwError
( x
{ msg =
x.msg
<> " in data: \n"
<> printTree d
}
)
else throwError (x{catchable = False})
)
unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d]) unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d])
unzip4 = unzip4 =
@ -737,6 +761,7 @@ data Env = Env
deriving (Show) deriving (Show)
data Error = Error {msg :: String, catchable :: Bool} data Error = Error {msg :: String, catchable :: Bool}
deriving (Show)
type Subst = Map T.Ident Type type Subst = Map T.Ident Type
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a} newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a}

View file

@ -187,6 +187,31 @@ bes =
" Nil => 0 ;" " Nil => 0 ;"
" };" " };"
) )
, testBe
"length function on int list infers correct signature"
( D.do
"data List () where {"
" Nil : List ()"
" Cons : Int -> List () -> List ()"
"};"
"length xs = case xs of {"
" Nil => 0 ;"
" Cons _ xs => 1 + length xs ;"
"};"
)
( D.do
"data List () where {"
" Nil : List ()"
" Cons : Int -> List () -> List ()"
"};"
"length : List () -> Int ;"
"length xs = case xs of {"
" Nil => 0 ;"
" Cons _ xs => 1 + length xs ;"
"};"
)
] ]
testSatisfy desc test satisfaction = specify desc $ run test `shouldSatisfy` satisfaction testSatisfy desc test satisfaction = specify desc $ run test `shouldSatisfy` satisfaction