Fixed previously incorrect type equality check, commented code, add test
This commit is contained in:
parent
85f31b129b
commit
b1d3e31efd
3 changed files with 335 additions and 311 deletions
|
|
@ -27,38 +27,12 @@ Program below should not type check
|
||||||
main : a -> b ;
|
main : a -> b ;
|
||||||
main x = x;
|
main x = x;
|
||||||
```
|
```
|
||||||
|
## Pattern match on functions
|
||||||
|
|
||||||
|
Program below should not type check
|
||||||
|
|
||||||
## Bugged error message
|
|
||||||
```hs
|
```hs
|
||||||
data Maybe () where {
|
main = case \x. x of {
|
||||||
Nothing : Maybe
|
_ => 0;
|
||||||
Just : Int -> Maybe
|
|
||||||
};
|
|
||||||
|
|
||||||
fmap : (Int -> Int) -> Maybe -> Maybe ;
|
|
||||||
fmap f ma = case ma of {
|
|
||||||
Nothing => Nothing ;
|
|
||||||
Just a => Just (f a) ;
|
|
||||||
};
|
|
||||||
|
|
||||||
pure : Int -> Maybe ;
|
|
||||||
pure x = Just x ;
|
|
||||||
|
|
||||||
ap mf ma = case mf of {
|
|
||||||
Just f => case ma of {
|
|
||||||
Nothing => Nothing;
|
|
||||||
Just a => Just (f a);
|
|
||||||
};
|
|
||||||
Nothing => Nothing;
|
|
||||||
};
|
|
||||||
|
|
||||||
return = pure;
|
|
||||||
|
|
||||||
bind ma f = case ma of {
|
|
||||||
Nothing => Nothing ;
|
|
||||||
Just a => f a ;
|
|
||||||
};
|
};
|
||||||
```
|
```
|
||||||
```
|
|
||||||
TYPECHECKER ERROR
|
|
||||||
Inferred type '("c" -> "Int") -> "Maybe" -> "Maybe" does not match specified type '("Int" -> "Int") -> "Maybe" -> "Maybe"'
|
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,6 @@ import Control.Monad.Except
|
||||||
import Control.Monad.Identity (Identity, runIdentity)
|
import Control.Monad.Identity (Identity, runIdentity)
|
||||||
import Control.Monad.Reader
|
import Control.Monad.Reader
|
||||||
import Control.Monad.State
|
import Control.Monad.State
|
||||||
import Data.Bifunctor (second)
|
|
||||||
import Data.Coerce (coerce)
|
import Data.Coerce (coerce)
|
||||||
import Data.Function (on)
|
import Data.Function (on)
|
||||||
import Data.List (foldl')
|
import Data.List (foldl')
|
||||||
|
|
@ -30,16 +29,17 @@ initCtx = Ctx mempty
|
||||||
initEnv = Env 0 'a' mempty mempty mempty
|
initEnv = Env 0 'a' mempty mempty mempty
|
||||||
|
|
||||||
run :: Infer a -> Either Error a
|
run :: Infer a -> Either Error a
|
||||||
run = runC initEnv initCtx
|
run = run' initEnv initCtx
|
||||||
|
|
||||||
runC :: Env -> Ctx -> Infer a -> Either Error a
|
run' :: Env -> Ctx -> Infer a -> Either Error a
|
||||||
runC e c =
|
run' e c =
|
||||||
runIdentity
|
runIdentity
|
||||||
. runExceptT
|
. runExceptT
|
||||||
. flip runReaderT c
|
. flip runReaderT c
|
||||||
. flip evalStateT e
|
. flip evalStateT e
|
||||||
. runInfer
|
. runInfer
|
||||||
|
|
||||||
|
-- | Type check a program
|
||||||
typecheck :: Program -> Either String (T.Program' Type)
|
typecheck :: Program -> Either String (T.Program' Type)
|
||||||
typecheck = onLeft msg . run . checkPrg
|
typecheck = onLeft msg . run . checkPrg
|
||||||
where
|
where
|
||||||
|
|
@ -47,20 +47,87 @@ typecheck = onLeft msg . run . checkPrg
|
||||||
onLeft f (Left x) = Left $ f x
|
onLeft f (Left x) = Left $ f x
|
||||||
onLeft _ (Right x) = Right x
|
onLeft _ (Right x) = Right x
|
||||||
|
|
||||||
|
checkPrg :: Program -> Infer (T.Program' Type)
|
||||||
|
checkPrg (Program bs) = do
|
||||||
|
preRun bs
|
||||||
|
bs' <- checkDef bs
|
||||||
|
return $ T.Program bs'
|
||||||
|
|
||||||
|
preRun :: [Def] -> Infer ()
|
||||||
|
preRun [] = return ()
|
||||||
|
preRun (x : xs) = case x of
|
||||||
|
DSig (Sig n t) -> do
|
||||||
|
collect (collectTVars t)
|
||||||
|
gets (M.member (coerce n) . sigs)
|
||||||
|
>>= flip
|
||||||
|
when
|
||||||
|
( uncatchableErr $ Aux.do
|
||||||
|
"Duplicate signatures for function"
|
||||||
|
quote $ printTree n
|
||||||
|
)
|
||||||
|
insertSig (coerce n) (Just t) >> preRun xs
|
||||||
|
DBind (Bind n _ e) -> do
|
||||||
|
collect (collectTVars e)
|
||||||
|
s <- gets sigs
|
||||||
|
case M.lookup (coerce n) s of
|
||||||
|
Nothing -> insertSig (coerce n) Nothing >> preRun xs
|
||||||
|
Just _ -> preRun xs
|
||||||
|
DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs
|
||||||
|
|
||||||
|
checkDef :: [Def] -> Infer [T.Def' Type]
|
||||||
|
checkDef [] = return []
|
||||||
|
checkDef (x : xs) = case x of
|
||||||
|
(DBind b) -> do
|
||||||
|
b' <- checkBind b
|
||||||
|
xs' <- checkDef xs
|
||||||
|
return $ T.DBind b' : xs'
|
||||||
|
(DData d) -> do
|
||||||
|
xs' <- checkDef xs
|
||||||
|
return $ T.DData (coerceData d) : xs'
|
||||||
|
(DSig _) -> checkDef xs
|
||||||
|
where
|
||||||
|
coerceData (Data t injs) =
|
||||||
|
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
|
||||||
|
|
||||||
|
checkBind :: Bind -> Infer (T.Bind' Type)
|
||||||
|
checkBind (Bind name args e) = do
|
||||||
|
let lambda = makeLambda e (reverse (coerce args))
|
||||||
|
(sub0, (e, lambda_t)) <- inferExp lambda
|
||||||
|
s <- gets sigs
|
||||||
|
case M.lookup (coerce name) s of
|
||||||
|
Just (Just t') -> do
|
||||||
|
let fsig = apply sub0 t'
|
||||||
|
sub1 <- liftEither $ runIdentity $ runExceptT $ execStateT (typeEq fsig lambda_t) mempty
|
||||||
|
sub2 <- liftEither $ runIdentity $ runExceptT $ execStateT (typeEq lambda_t fsig) mempty
|
||||||
|
unless
|
||||||
|
(lambda_t == apply sub1 fsig && apply sub2 lambda_t == fsig)
|
||||||
|
( uncatchableErr $ Aux.do
|
||||||
|
"Inferred type"
|
||||||
|
quote $ printTree lambda_t
|
||||||
|
"does not match specified type"
|
||||||
|
quote $ printTree t'
|
||||||
|
)
|
||||||
|
return $ T.Bind (coerce name, lambda_t) [] (e, lambda_t)
|
||||||
|
_ -> do
|
||||||
|
insertSig (coerce name) (Just lambda_t)
|
||||||
|
return (T.Bind (coerce name, lambda_t) [] (e, lambda_t))
|
||||||
|
|
||||||
checkData :: Data -> Infer ()
|
checkData :: Data -> Infer ()
|
||||||
checkData err@(Data typ injs) = do
|
checkData err@(Data typ injs) = do
|
||||||
(name, tvars) <- go typ
|
(name, tvars) <- go typ
|
||||||
dataErr (mapM_ (\i -> typecheckInj i name tvars) injs) err
|
dataErr (mapM_ (\i -> checkInj i name tvars) injs) err
|
||||||
where
|
where
|
||||||
go = \case
|
go = \case
|
||||||
TData name typs
|
TData name typs
|
||||||
| Right tvars' <- mapM toTVar typs ->
|
| Right tvars' <- mapM toTVar typs ->
|
||||||
pure (name, tvars')
|
pure (name, tvars')
|
||||||
TAll _ _ -> uncatchableErr "Explicit foralls not allowed, for now"
|
TAll _ _ -> uncatchableErr "Explicit foralls not allowed, for now"
|
||||||
_ -> uncatchableErr $ unwords ["Bad data type definition: ", printTree typ]
|
_ ->
|
||||||
|
uncatchableErr $
|
||||||
|
unwords ["Bad data type definition: ", printTree typ]
|
||||||
|
|
||||||
typecheckInj :: Inj -> UIdent -> [TVar] -> Infer ()
|
checkInj :: Inj -> UIdent -> [TVar] -> Infer ()
|
||||||
typecheckInj (Inj c inj_typ) name tvars
|
checkInj (Inj c inj_typ) name tvars
|
||||||
| Right False <- boundTVars tvars inj_typ =
|
| Right False <- boundTVars tvars inj_typ =
|
||||||
catchableErr "Unbound type variables"
|
catchableErr "Unbound type variables"
|
||||||
| TData name' typs <- returnType inj_typ
|
| TData name' typs <- returnType inj_typ
|
||||||
|
|
@ -108,109 +175,11 @@ returnType :: Type -> Type
|
||||||
returnType (TFun _ t2) = returnType t2
|
returnType (TFun _ t2) = returnType t2
|
||||||
returnType a = a
|
returnType a = a
|
||||||
|
|
||||||
checkPrg :: Program -> Infer (T.Program' Type)
|
inferExp :: Exp -> Infer (Subst, T.ExpT' Type)
|
||||||
checkPrg (Program bs) = do
|
|
||||||
preRun bs
|
|
||||||
bs' <- checkDef bs
|
|
||||||
return $ T.Program bs'
|
|
||||||
|
|
||||||
preRun :: [Def] -> Infer ()
|
|
||||||
preRun [] = return ()
|
|
||||||
preRun (x : xs) = case x of
|
|
||||||
DSig (Sig n t) -> do
|
|
||||||
collect (collectTVars t)
|
|
||||||
gets (M.member (coerce n) . sigs)
|
|
||||||
>>= flip
|
|
||||||
when
|
|
||||||
( uncatchableErr $ Aux.do
|
|
||||||
"Duplicate signatures for function"
|
|
||||||
quote $ printTree n
|
|
||||||
)
|
|
||||||
insertSig (coerce n) (Just t) >> preRun xs
|
|
||||||
DBind (Bind n _ e) -> do
|
|
||||||
collect (collectTVars e)
|
|
||||||
s <- gets sigs
|
|
||||||
case M.lookup (coerce n) s of
|
|
||||||
Nothing -> insertSig (coerce n) Nothing >> preRun xs
|
|
||||||
Just _ -> preRun xs
|
|
||||||
DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs
|
|
||||||
|
|
||||||
checkDef :: [Def] -> Infer [T.Def' Type]
|
|
||||||
checkDef [] = return []
|
|
||||||
checkDef (x : xs) = case x of
|
|
||||||
(DBind b) -> do
|
|
||||||
b' <- checkBind b
|
|
||||||
xs' <- checkDef xs
|
|
||||||
return $ T.DBind b' : xs'
|
|
||||||
(DData d) -> do
|
|
||||||
xs' <- checkDef xs
|
|
||||||
return $ T.DData (coerceData d) : xs'
|
|
||||||
(DSig _) -> checkDef xs
|
|
||||||
where
|
|
||||||
coerceData (Data t injs) =
|
|
||||||
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
|
|
||||||
|
|
||||||
checkBind :: Bind -> Infer (T.Bind' Type)
|
|
||||||
checkBind (Bind name args e) = do
|
|
||||||
let lambda = makeLambda e (reverse (coerce args))
|
|
||||||
(e, lambda_t) <- inferExp lambda
|
|
||||||
s <- gets sigs
|
|
||||||
case M.lookup (coerce name) s of
|
|
||||||
Just (Just t') -> do
|
|
||||||
sub1 <- unify lambda_t t'
|
|
||||||
sub2 <- unify t' lambda_t
|
|
||||||
unless
|
|
||||||
(apply sub1 lambda_t == t' && lambda_t == apply sub2 t')
|
|
||||||
( uncatchableErr $ Aux.do
|
|
||||||
"Inferred type"
|
|
||||||
quote $ printTree lambda_t
|
|
||||||
"does not match specified type"
|
|
||||||
quote $ printTree t'
|
|
||||||
)
|
|
||||||
return $ T.Bind (coerce name, t') [] (e, lambda_t)
|
|
||||||
_ -> do
|
|
||||||
insertSig (coerce name) (Just lambda_t)
|
|
||||||
return (T.Bind (coerce name, lambda_t) [] (e, lambda_t))
|
|
||||||
|
|
||||||
typeEq :: Type -> Type -> Bool
|
|
||||||
typeEq (TFun l r) (TFun l' r') = typeEq l l' && typeEq r r'
|
|
||||||
typeEq (TLit a) (TLit b) = a == b
|
|
||||||
typeEq (TData name a) (TData name' b) =
|
|
||||||
length a == length b
|
|
||||||
&& name == name'
|
|
||||||
&& and (zipWith typeEq a b)
|
|
||||||
typeEq (TAll _ t1) t2 = t1 `typeEq` t2
|
|
||||||
typeEq t1 (TAll _ t2) = t1 `typeEq` t2
|
|
||||||
typeEq (TVar _) (TVar _) = True
|
|
||||||
typeEq _ _ = False
|
|
||||||
|
|
||||||
skolemize :: Type -> Type
|
|
||||||
skolemize (TVar (MkTVar a)) = TEVar (MkTEVar $ coerce a)
|
|
||||||
skolemize (TAll x t) = TAll x (skolemize t)
|
|
||||||
skolemize (TFun t1 t2) = (TFun `on` skolemize) t1 t2
|
|
||||||
skolemize t = t
|
|
||||||
|
|
||||||
isMoreSpecificOrEq :: Type -> Type -> Bool
|
|
||||||
isMoreSpecificOrEq t1 (TAll _ t2) = isMoreSpecificOrEq t1 t2
|
|
||||||
isMoreSpecificOrEq (TFun a b) (TFun c d) =
|
|
||||||
isMoreSpecificOrEq a c && isMoreSpecificOrEq b d
|
|
||||||
isMoreSpecificOrEq (TData n1 ts1) (TData n2 ts2) =
|
|
||||||
n1 == n2
|
|
||||||
&& length ts1 == length ts2
|
|
||||||
&& and (zipWith isMoreSpecificOrEq ts1 ts2)
|
|
||||||
isMoreSpecificOrEq _ (TVar _) = True
|
|
||||||
isMoreSpecificOrEq a b = a == b
|
|
||||||
|
|
||||||
isPoly :: Type -> Bool
|
|
||||||
isPoly (TAll _ _) = True
|
|
||||||
isPoly (TVar _) = True
|
|
||||||
isPoly _ = False
|
|
||||||
|
|
||||||
inferExp :: Exp -> Infer (T.ExpT' Type)
|
|
||||||
inferExp e = do
|
inferExp e = do
|
||||||
(s, (e', t)) <- algoW e
|
(s, (e', t)) <- algoW e
|
||||||
let subbed = apply s t
|
let subbed = apply s t
|
||||||
return $ second (const subbed) (e', t)
|
return (s, (e', subbed))
|
||||||
|
|
||||||
class CollectTVars a where
|
class CollectTVars a where
|
||||||
collectTVars :: a -> Set T.Ident
|
collectTVars :: a -> Set T.Ident
|
||||||
|
|
@ -223,7 +192,8 @@ instance CollectTVars Type where
|
||||||
collectTVars (TVar (MkTVar i)) = S.singleton (coerce i)
|
collectTVars (TVar (MkTVar i)) = S.singleton (coerce i)
|
||||||
collectTVars (TAll _ t) = collectTVars t
|
collectTVars (TAll _ t) = collectTVars t
|
||||||
collectTVars (TFun t1 t2) = (S.union `on` collectTVars) t1 t2
|
collectTVars (TFun t1 t2) = (S.union `on` collectTVars) t1 t2
|
||||||
collectTVars (TData _ ts) = foldl' (\acc x -> acc `S.union` collectTVars x) S.empty ts
|
collectTVars (TData _ ts) =
|
||||||
|
foldl' (\acc x -> acc `S.union` collectTVars x) S.empty ts
|
||||||
collectTVars _ = S.empty
|
collectTVars _ = S.empty
|
||||||
|
|
||||||
collect :: Set T.Ident -> Infer ()
|
collect :: Set T.Ident -> Infer ()
|
||||||
|
|
@ -232,7 +202,7 @@ collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st})
|
||||||
algoW :: Exp -> Infer (Subst, T.ExpT' Type)
|
algoW :: Exp -> Infer (Subst, T.ExpT' Type)
|
||||||
algoW = \case
|
algoW = \case
|
||||||
err@(EAnn e t) -> do
|
err@(EAnn e t) -> do
|
||||||
(s1, (e', t')) <- exprErr (algoW e) err
|
(sub0, (e', t')) <- exprErr (algoW e) err
|
||||||
sub1 <- unify t t'
|
sub1 <- unify t t'
|
||||||
sub2 <- unify t' t
|
sub2 <- unify t' t
|
||||||
unless
|
unless
|
||||||
|
|
@ -243,8 +213,7 @@ algoW = \case
|
||||||
"does not match inferred type"
|
"does not match inferred type"
|
||||||
quote $ printTree t'
|
quote $ printTree t'
|
||||||
)
|
)
|
||||||
s2 <- exprErr (unify t t') err
|
let comp = sub2 `compose` sub1 `compose` sub0
|
||||||
let comp = s2 `compose` s1
|
|
||||||
return (comp, apply comp (e', t))
|
return (comp, apply comp (e', t))
|
||||||
|
|
||||||
-- \| ------------------
|
-- \| ------------------
|
||||||
|
|
@ -257,7 +226,9 @@ algoW = \case
|
||||||
EVar i -> do
|
EVar i -> do
|
||||||
var <- asks vars
|
var <- asks vars
|
||||||
case M.lookup (coerce i) var of
|
case M.lookup (coerce i) var of
|
||||||
Just t -> inst t >>= \x -> return (nullSubst, (T.EVar $ coerce i, x))
|
Just t ->
|
||||||
|
inst t >>= \x ->
|
||||||
|
return (nullSubst, (T.EVar $ coerce i, x))
|
||||||
Nothing -> do
|
Nothing -> do
|
||||||
sig <- gets sigs
|
sig <- gets sigs
|
||||||
case M.lookup (coerce i) sig of
|
case M.lookup (coerce i) sig of
|
||||||
|
|
@ -266,7 +237,10 @@ algoW = \case
|
||||||
fr <- fresh
|
fr <- fresh
|
||||||
insertSig (coerce i) (Just fr)
|
insertSig (coerce i) (Just fr)
|
||||||
return (nullSubst, (T.EVar $ coerce i, fr))
|
return (nullSubst, (T.EVar $ coerce i, fr))
|
||||||
Nothing -> uncatchableErr $ "Unbound variable: " <> printTree i
|
Nothing ->
|
||||||
|
uncatchableErr $
|
||||||
|
"Unbound variable: "
|
||||||
|
<> printTree i
|
||||||
EInj i -> do
|
EInj i -> do
|
||||||
constr <- gets injections
|
constr <- gets injections
|
||||||
case M.lookup (coerce i) constr of
|
case M.lookup (coerce i) constr of
|
||||||
|
|
@ -283,14 +257,11 @@ algoW = \case
|
||||||
|
|
||||||
err@(EAbs name e) -> do
|
err@(EAbs name e) -> do
|
||||||
fr <- fresh
|
fr <- fresh
|
||||||
exprErr
|
withBinding (coerce name) fr $ do
|
||||||
( withBinding (coerce name) fr $ do
|
|
||||||
(s1, (e', t')) <- exprErr (algoW e) err
|
(s1, (e', t')) <- exprErr (algoW e) err
|
||||||
let varType = apply s1 fr
|
let varType = apply s1 fr
|
||||||
let newArr = TFun varType t'
|
let newArr = TFun varType t'
|
||||||
return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr))
|
return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr))
|
||||||
)
|
|
||||||
err
|
|
||||||
|
|
||||||
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
|
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
|
||||||
-- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
|
-- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
|
||||||
|
|
@ -338,29 +309,120 @@ algoW = \case
|
||||||
(s2, (e1', t2)) <- algoW e1
|
(s2, (e1', t2)) <- algoW e1
|
||||||
let comp = s2 `compose` s1
|
let comp = s2 `compose` s1
|
||||||
return (comp, apply comp (T.ELet bind' (e1', t2), t2))
|
return (comp, apply comp (T.ELet bind' (e1', t2), t2))
|
||||||
|
|
||||||
-- \| TODO: Add judgement
|
|
||||||
ECase caseExpr injs -> do
|
ECase caseExpr injs -> do
|
||||||
(sub, (e', t)) <- algoW caseExpr
|
(sub, (e', t)) <- algoW caseExpr
|
||||||
(subst, injs, ret_t) <- checkCase t injs
|
(subst, injs, ret_t) <- checkCase t injs
|
||||||
let comp = subst `compose` sub
|
let comp = subst `compose` sub
|
||||||
return (comp, apply comp (T.ECase (e', t) injs, ret_t))
|
return (comp, apply comp (T.ECase (e', t) injs, ret_t))
|
||||||
|
|
||||||
makeLambda :: Exp -> [T.Ident] -> Exp
|
checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type)
|
||||||
makeLambda = foldl (flip (EAbs . coerce))
|
checkCase _ [] = catchableErr "Atleast one case required"
|
||||||
|
checkCase expT brnchs = do
|
||||||
|
(subs, branchTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs
|
||||||
|
let sub0 = composeAll subs
|
||||||
|
(sub1, _) <-
|
||||||
|
foldM
|
||||||
|
( \(sub, acc) x ->
|
||||||
|
(\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc
|
||||||
|
)
|
||||||
|
(nullSubst, expT)
|
||||||
|
branchTs
|
||||||
|
(sub2, returns_type) <-
|
||||||
|
foldM
|
||||||
|
( \(sub, acc) x ->
|
||||||
|
(\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc
|
||||||
|
)
|
||||||
|
(nullSubst, head returns)
|
||||||
|
(tail returns)
|
||||||
|
let comp = sub2 `compose` sub1 `compose` sub0
|
||||||
|
return (comp, apply comp injs, apply comp returns_type)
|
||||||
|
|
||||||
|
inferBranch :: Branch -> Infer (Subst, Type, T.Branch' Type, Type)
|
||||||
|
inferBranch (Branch pat expr) = do
|
||||||
|
newPat@(pat, branchT) <- inferPattern pat
|
||||||
|
(sub, newExp@(_, exprT)) <- withPattern pat (algoW expr)
|
||||||
|
return
|
||||||
|
( sub
|
||||||
|
, apply sub branchT
|
||||||
|
, T.Branch (apply sub newPat) (apply sub newExp)
|
||||||
|
, apply sub exprT
|
||||||
|
)
|
||||||
|
|
||||||
|
inferPattern :: Pattern -> Infer (T.Pattern' Type, Type)
|
||||||
|
inferPattern = \case
|
||||||
|
PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt)
|
||||||
|
PInj constr patterns -> do
|
||||||
|
t <- gets (M.lookup (coerce constr) . injections)
|
||||||
|
t <-
|
||||||
|
maybeToRightM
|
||||||
|
( Error
|
||||||
|
( Aux.do
|
||||||
|
"Constructor:"
|
||||||
|
quote $ printTree constr
|
||||||
|
"does not exist"
|
||||||
|
)
|
||||||
|
True
|
||||||
|
)
|
||||||
|
t
|
||||||
|
let numArgs = typeLength t - 1
|
||||||
|
let (vs, ret) = fromJust (unsnoc $ flattenType t)
|
||||||
|
patterns <- mapM inferPattern patterns
|
||||||
|
unless
|
||||||
|
(length patterns == numArgs)
|
||||||
|
( catchableErr $ Aux.do
|
||||||
|
"The constructor"
|
||||||
|
quote $ printTree constr
|
||||||
|
" should have "
|
||||||
|
show numArgs
|
||||||
|
" arguments but has been given "
|
||||||
|
show (length patterns)
|
||||||
|
)
|
||||||
|
sub <- composeAll <$> zipWithM unify vs (map snd patterns)
|
||||||
|
return
|
||||||
|
( T.PInj (coerce constr) (apply sub (map fst patterns))
|
||||||
|
, apply sub ret
|
||||||
|
)
|
||||||
|
PCatch -> (T.PCatch,) <$> fresh
|
||||||
|
PEnum p -> do
|
||||||
|
t <- gets (M.lookup (coerce p) . injections)
|
||||||
|
t <-
|
||||||
|
maybeToRightM
|
||||||
|
( Error
|
||||||
|
( Aux.do
|
||||||
|
"Constructor:"
|
||||||
|
quote $ printTree p
|
||||||
|
"does not exist"
|
||||||
|
)
|
||||||
|
True
|
||||||
|
)
|
||||||
|
t
|
||||||
|
unless
|
||||||
|
(typeLength t == 1)
|
||||||
|
( catchableErr $ Aux.do
|
||||||
|
"The constructor"
|
||||||
|
quote $ printTree p
|
||||||
|
" should have "
|
||||||
|
show (typeLength t - 1)
|
||||||
|
" arguments but has been given 0"
|
||||||
|
)
|
||||||
|
let (TData _data _ts) = t -- nasty nasty
|
||||||
|
frs <- mapM (const fresh) _ts
|
||||||
|
return (T.PEnum $ coerce p, TData _data frs)
|
||||||
|
PVar x -> do
|
||||||
|
fr <- fresh
|
||||||
|
let pvar = T.PVar (coerce x, fr)
|
||||||
|
return (pvar, fr)
|
||||||
|
|
||||||
-- | Unify two types producing a new substitution
|
-- | Unify two types producing a new substitution
|
||||||
unify :: Type -> Type -> Infer Subst
|
unify :: Type -> Type -> Infer Subst
|
||||||
unify t0 t1 = do
|
unify t0 t1 =
|
||||||
case (t0, t1) of
|
case (t0, t1) of
|
||||||
(TFun a b, TFun c d) -> do
|
(TFun a b, TFun c d) -> do
|
||||||
s1 <- unify a c
|
s1 <- unify a c
|
||||||
s2 <- unify (apply s1 b) (apply s1 d)
|
s2 <- unify (apply s1 b) (apply s1 d)
|
||||||
return $ s1 `compose` s2
|
return $ s1 `compose` s2
|
||||||
----------- TODO: BE CAREFUL!!!! THIS IS PROBABLY WRONG!!! -----------
|
|
||||||
(TVar (T.MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t
|
(TVar (T.MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t
|
||||||
(t@(TData _ _), TVar (T.MkTVar b)) -> return $ M.singleton (coerce b) t
|
(t@(TData _ _), TVar (T.MkTVar b)) -> return $ M.singleton (coerce b) t
|
||||||
-------------------------------------------------------------------
|
|
||||||
(TVar (T.MkTVar a), t) -> occurs (coerce a) t
|
(TVar (T.MkTVar a), t) -> occurs (coerce a) t
|
||||||
(t, TVar (T.MkTVar b)) -> occurs (coerce b) t
|
(t, TVar (T.MkTVar b)) -> occurs (coerce b) t
|
||||||
(TAll _ t, b) -> unify t b
|
(TAll _ t, b) -> unify t b
|
||||||
|
|
@ -422,7 +484,12 @@ occurs i t =
|
||||||
)
|
)
|
||||||
else return $ M.singleton i t
|
else return $ M.singleton i t
|
||||||
|
|
||||||
-- | Generalize a type over all free variables in the substitution set
|
{- | Generalize a type over all free variables in the substitution set
|
||||||
|
Used for let bindings to allow expression that do not type check in
|
||||||
|
equivalent lambda expressions:
|
||||||
|
Type checks: let f = \x. x in (f True, f 'a')
|
||||||
|
Does not type check: (\f. (f True, f 'a')) (\x. x)
|
||||||
|
-}
|
||||||
generalize :: Map T.Ident Type -> Type -> Type
|
generalize :: Map T.Ident Type -> Type -> Type
|
||||||
generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t)
|
generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t)
|
||||||
where
|
where
|
||||||
|
|
@ -446,15 +513,27 @@ inst = \case
|
||||||
TFun t1 t2 -> TFun <$> inst t1 <*> inst t2
|
TFun t1 t2 -> TFun <$> inst t1 <*> inst t2
|
||||||
rest -> return rest
|
rest -> return rest
|
||||||
|
|
||||||
-- | Compose two substitution sets
|
-- | Generate a new fresh variable
|
||||||
compose :: Subst -> Subst -> Subst
|
fresh :: Infer Type
|
||||||
compose m1 m2 = M.map (apply m1) m2 `M.union` m1
|
fresh = do
|
||||||
|
c <- gets nextChar
|
||||||
composeAll :: [Subst] -> Subst
|
n <- gets count
|
||||||
composeAll = foldl' compose nullSubst
|
taken <- gets takenTypeVars
|
||||||
|
if c == 'z'
|
||||||
-- TODO: Split this class into two separate classes, one for free variables
|
then do
|
||||||
-- and one for applying substitutions
|
modify (\st -> st{count = succ (count st), nextChar = 'a'})
|
||||||
|
else modify (\st -> st{nextChar = next (nextChar st)})
|
||||||
|
if coerce [c] `S.member` taken
|
||||||
|
then do
|
||||||
|
fresh
|
||||||
|
else
|
||||||
|
if n == 0
|
||||||
|
then return . TVar . T.MkTVar $ LIdent [c]
|
||||||
|
else return . TVar . T.MkTVar . LIdent $ c : show n
|
||||||
|
where
|
||||||
|
next :: Char -> Char
|
||||||
|
next 'z' = 'a'
|
||||||
|
next a = succ a
|
||||||
|
|
||||||
-- | A class for substitutions
|
-- | A class for substitutions
|
||||||
class SubstType t where
|
class SubstType t where
|
||||||
|
|
@ -468,7 +547,8 @@ class FreeVars t where
|
||||||
instance FreeVars Type where
|
instance FreeVars Type where
|
||||||
free :: Type -> Set T.Ident
|
free :: Type -> Set T.Ident
|
||||||
free (TVar (T.MkTVar a)) = S.singleton (coerce a)
|
free (TVar (T.MkTVar a)) = S.singleton (coerce a)
|
||||||
free (TAll (T.MkTVar bound) t) = S.singleton (coerce bound) `S.intersection` free t
|
free (TAll (T.MkTVar bound) t) =
|
||||||
|
S.singleton (coerce bound) `S.intersection` free t
|
||||||
free (TLit _) = mempty
|
free (TLit _) = mempty
|
||||||
free (TFun a b) = free a `S.union` free b
|
free (TFun a b) = free a `S.union` free b
|
||||||
free (TData _ a) = free a
|
free (TData _ a) = free a
|
||||||
|
|
@ -540,27 +620,19 @@ instance SubstType (T.Id' Type) where
|
||||||
nullSubst :: Subst
|
nullSubst :: Subst
|
||||||
nullSubst = M.empty
|
nullSubst = M.empty
|
||||||
|
|
||||||
-- | Generate a new fresh variable and increment the state counter
|
-- | Compose two substitution sets
|
||||||
fresh :: Infer Type
|
compose :: Subst -> Subst -> Subst
|
||||||
fresh = do
|
compose m1 m2 = M.map (apply m1) m2 `M.union` m1
|
||||||
c <- gets nextChar
|
|
||||||
n <- gets count
|
-- | Compose a list of substitution sets into one
|
||||||
taken <- gets takenTypeVars
|
composeAll :: [Subst] -> Subst
|
||||||
if c == 'z'
|
composeAll = foldl' compose nullSubst
|
||||||
then do
|
|
||||||
modify (\st -> st{count = succ (count st), nextChar = 'a'})
|
{- | Convert a function with arguments to its pointfree version
|
||||||
else modify (\st -> st{nextChar = next (nextChar st)})
|
> makeLambda (add x y = x + y) = add = \x. \y. x + y
|
||||||
if coerce [c] `S.member` taken
|
-}
|
||||||
then do
|
makeLambda :: Exp -> [T.Ident] -> Exp
|
||||||
fresh
|
makeLambda = foldl (flip (EAbs . coerce))
|
||||||
else
|
|
||||||
if n == 0
|
|
||||||
then return . TVar . T.MkTVar $ LIdent [c]
|
|
||||||
else return . TVar . T.MkTVar . LIdent $ c : show n
|
|
||||||
where
|
|
||||||
next :: Char -> Char
|
|
||||||
next 'z' = 'a'
|
|
||||||
next a = succ a
|
|
||||||
|
|
||||||
-- | Run the monadic action with an additional binding
|
-- | Run the monadic action with an additional binding
|
||||||
withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> Type -> m a -> m a
|
withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> Type -> m a -> m a
|
||||||
|
|
@ -571,49 +643,8 @@ withBindings :: (Monad m, MonadReader Ctx m) => [(T.Ident, Type)] -> m a -> m a
|
||||||
withBindings xs =
|
withBindings xs =
|
||||||
local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs})
|
local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs})
|
||||||
|
|
||||||
-- | Insert a function signature into the environment
|
-- | Run the monadic action with a pattern
|
||||||
insertSig :: T.Ident -> Maybe Type -> Infer ()
|
withPattern :: (Monad m, MonadReader Ctx m) => T.Pattern' Type -> m a -> m a
|
||||||
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
|
|
||||||
|
|
||||||
-- | Insert a constructor with its data type
|
|
||||||
insertInj :: T.Ident -> Type -> Infer ()
|
|
||||||
insertInj i t =
|
|
||||||
modify (\st -> st{injections = M.insert i t (injections st)})
|
|
||||||
|
|
||||||
existInj :: T.Ident -> Infer (Maybe Type)
|
|
||||||
existInj n = gets (M.lookup n . injections)
|
|
||||||
|
|
||||||
-------- PATTERN MATCHING ---------
|
|
||||||
|
|
||||||
checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type)
|
|
||||||
checkCase _ [] = catchableErr "Atleast one case required"
|
|
||||||
checkCase expT brnchs = do
|
|
||||||
(subs, injTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs
|
|
||||||
let sub0 = composeAll subs
|
|
||||||
(sub1, _) <-
|
|
||||||
foldM
|
|
||||||
( \(sub, acc) x ->
|
|
||||||
(\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc
|
|
||||||
)
|
|
||||||
(nullSubst, expT)
|
|
||||||
injTs
|
|
||||||
(sub2, returns_type) <-
|
|
||||||
foldM
|
|
||||||
( \(sub, acc) x ->
|
|
||||||
(\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc
|
|
||||||
)
|
|
||||||
(nullSubst, head returns)
|
|
||||||
(tail returns)
|
|
||||||
let comp = sub2 `compose` sub1 `compose` sub0
|
|
||||||
return (comp, apply comp injs, apply comp returns_type)
|
|
||||||
|
|
||||||
inferBranch :: Branch -> Infer (Subst, Type, T.Branch' Type, Type)
|
|
||||||
inferBranch (Branch pat expr) = do
|
|
||||||
newPat@(pat, branchT) <- inferPattern pat
|
|
||||||
(sub, newExp@(_, exprT)) <- withPattern pat (algoW expr)
|
|
||||||
return (sub, apply sub branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT)
|
|
||||||
|
|
||||||
withPattern :: T.Pattern' Type -> Infer a -> Infer a
|
|
||||||
withPattern p ma = case p of
|
withPattern p ma = case p of
|
||||||
T.PVar (x, t) -> withBinding x t ma
|
T.PVar (x, t) -> withBinding x t ma
|
||||||
T.PInj _ ps -> foldl' (flip withPattern) ma ps
|
T.PInj _ ps -> foldl' (flip withPattern) ma ps
|
||||||
|
|
@ -621,74 +652,27 @@ withPattern p ma = case p of
|
||||||
T.PCatch -> ma
|
T.PCatch -> ma
|
||||||
T.PEnum _ -> ma
|
T.PEnum _ -> ma
|
||||||
|
|
||||||
inferPattern :: Pattern -> Infer (T.Pattern' Type, Type)
|
-- | Insert a function signature into the environment
|
||||||
inferPattern = \case
|
insertSig :: T.Ident -> Maybe Type -> Infer ()
|
||||||
PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt)
|
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
|
||||||
PInj constr patterns -> do
|
|
||||||
t <- gets (M.lookup (coerce constr) . injections)
|
-- | Insert a constructor into the start with its type
|
||||||
t <-
|
insertInj :: T.Ident -> Type -> Infer ()
|
||||||
maybeToRightM
|
insertInj i t =
|
||||||
( Error
|
modify (\st -> st{injections = M.insert i t (injections st)})
|
||||||
( Aux.do
|
|
||||||
"Constructor:"
|
{- | Check if an injection (constructor of data type)
|
||||||
quote $ printTree constr
|
with an equivalent name has been declared already
|
||||||
"does not exist"
|
-}
|
||||||
)
|
existInj :: T.Ident -> Infer (Maybe Type)
|
||||||
True
|
existInj n = gets (M.lookup n . injections)
|
||||||
)
|
|
||||||
t
|
|
||||||
let numArgs = typeLength t - 1
|
|
||||||
let (vs, ret) = fromJust (unsnoc $ flattenType t)
|
|
||||||
patterns <- mapM inferPattern patterns
|
|
||||||
unless
|
|
||||||
(length patterns == numArgs)
|
|
||||||
( catchableErr $ Aux.do
|
|
||||||
"The constructor"
|
|
||||||
quote $ printTree constr
|
|
||||||
" should have "
|
|
||||||
show numArgs
|
|
||||||
" arguments but has been given "
|
|
||||||
show (length patterns)
|
|
||||||
)
|
|
||||||
sub <- composeAll <$> zipWithM unify vs (map snd patterns)
|
|
||||||
return (T.PInj (coerce constr) (apply sub (map fst patterns)), apply sub ret)
|
|
||||||
PCatch -> (T.PCatch,) <$> fresh
|
|
||||||
PEnum p -> do
|
|
||||||
t <- gets (M.lookup (coerce p) . injections)
|
|
||||||
t <-
|
|
||||||
maybeToRightM
|
|
||||||
( Error
|
|
||||||
( Aux.do
|
|
||||||
"Constructor:"
|
|
||||||
quote $ printTree p
|
|
||||||
"does not exist"
|
|
||||||
)
|
|
||||||
True
|
|
||||||
)
|
|
||||||
t
|
|
||||||
unless
|
|
||||||
(typeLength t == 1)
|
|
||||||
( catchableErr $ Aux.do
|
|
||||||
"The constructor"
|
|
||||||
quote $ printTree p
|
|
||||||
" should have "
|
|
||||||
show (typeLength t - 1)
|
|
||||||
" arguments but has been given 0"
|
|
||||||
)
|
|
||||||
let (TData _data _ts) = t -- nasty nasty
|
|
||||||
frs <- mapM (const fresh) _ts
|
|
||||||
return (T.PEnum $ coerce p, TData _data frs)
|
|
||||||
PVar x -> do
|
|
||||||
fr <- fresh
|
|
||||||
let pvar = T.PVar (coerce x, fr)
|
|
||||||
return (pvar, fr)
|
|
||||||
|
|
||||||
flattenType :: Type -> [Type]
|
flattenType :: Type -> [Type]
|
||||||
flattenType (TFun a b) = flattenType a <> flattenType b
|
flattenType (TFun a b) = flattenType a <> flattenType b
|
||||||
flattenType a = [a]
|
flattenType a = [a]
|
||||||
|
|
||||||
typeLength :: Type -> Int
|
typeLength :: Type -> Int
|
||||||
typeLength (TFun a b) = typeLength a + typeLength b
|
typeLength (TFun _ b) = 1 + typeLength b
|
||||||
typeLength _ = 1
|
typeLength _ = 1
|
||||||
|
|
||||||
litType :: Lit -> Type
|
litType :: Lit -> Type
|
||||||
|
|
@ -698,23 +682,63 @@ litType (LChar _) = char
|
||||||
int = TLit "Int"
|
int = TLit "Int"
|
||||||
char = TLit "Char"
|
char = TLit "Char"
|
||||||
|
|
||||||
partitionType ::
|
typeEq :: Type -> Type -> StateT Subst (ExceptT Error Identity) ()
|
||||||
Int -> -- Number of parameters to apply
|
typeEq (TVar (T.MkTVar a)) t@(TVar _) = do
|
||||||
Type ->
|
st <- get
|
||||||
([Type], Type)
|
case M.lookup (coerce a) st of
|
||||||
partitionType = go []
|
Nothing -> put $ M.insert (coerce a) t st
|
||||||
where
|
Just t' -> unless (t == t') (catchableErr "TYPE MISMATCH")
|
||||||
go acc 0 t = (acc, t)
|
typeEq (TFun l r) (TFun l' r') = typeEq l l' *> typeEq r r'
|
||||||
go acc i t = case t of
|
typeEq (TAll _ l) (TAll _ r) = typeEq l r
|
||||||
TAll tvar t' -> second (TAll tvar) $ go acc i t'
|
typeEq (TLit a) (TLit b) = unless (a == b) (catchableErr "TYPE MISMATCH")
|
||||||
TFun t1 t2 -> go (acc <> [t1]) (i - 1) t2
|
typeEq (TData nameL tL) (TData nameR tR) = do
|
||||||
_ -> error "Number of parameters and type doesn't match"
|
unless (nameL == nameR) (catchableErr "TYPE MISMATCH")
|
||||||
|
zipWithM_ typeEq tL tR
|
||||||
|
typeEq (TEVar _) (TEVar _) = catchableErr "TYPE MISMATCH"
|
||||||
|
typeEq _ _ = catchableErr "TYPE MISMATCH"
|
||||||
|
|
||||||
exprErr :: Infer a -> Exp -> Infer a
|
{- | Catch an error if possible and add the given
|
||||||
exprErr ma exp = catchError ma (\x -> if x.catchable then throwError (x{msg = x.msg <> " in expression: \n" <> printTree exp, catchable = False}) else throwError x)
|
expression as addition to the error message
|
||||||
|
-}
|
||||||
|
exprErr :: (Monad m, MonadError Error m) => m a -> Exp -> m a
|
||||||
|
exprErr ma exp =
|
||||||
|
catchError
|
||||||
|
ma
|
||||||
|
( \x ->
|
||||||
|
if x.catchable
|
||||||
|
then
|
||||||
|
throwError
|
||||||
|
( x
|
||||||
|
{ msg =
|
||||||
|
x.msg
|
||||||
|
<> " in expression: \n"
|
||||||
|
<> printTree exp
|
||||||
|
, catchable = False
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else throwError x
|
||||||
|
)
|
||||||
|
|
||||||
|
{- | Catch an error if possible and add the given
|
||||||
|
data as addition to the error message
|
||||||
|
-}
|
||||||
dataErr :: Infer a -> Data -> Infer a
|
dataErr :: Infer a -> Data -> Infer a
|
||||||
dataErr ma d = catchError ma (\x -> if x.catchable then throwError (x{msg = x.msg <> " in data: \n" <> printTree d}) else throwError (x{catchable = False}))
|
dataErr ma d =
|
||||||
|
catchError
|
||||||
|
ma
|
||||||
|
( \x ->
|
||||||
|
if x.catchable
|
||||||
|
then
|
||||||
|
throwError
|
||||||
|
( x
|
||||||
|
{ msg =
|
||||||
|
x.msg
|
||||||
|
<> " in data: \n"
|
||||||
|
<> printTree d
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else throwError (x{catchable = False})
|
||||||
|
)
|
||||||
|
|
||||||
unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d])
|
unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d])
|
||||||
unzip4 =
|
unzip4 =
|
||||||
|
|
@ -737,6 +761,7 @@ data Env = Env
|
||||||
deriving (Show)
|
deriving (Show)
|
||||||
|
|
||||||
data Error = Error {msg :: String, catchable :: Bool}
|
data Error = Error {msg :: String, catchable :: Bool}
|
||||||
|
deriving (Show)
|
||||||
type Subst = Map T.Ident Type
|
type Subst = Map T.Ident Type
|
||||||
|
|
||||||
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a}
|
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a}
|
||||||
|
|
|
||||||
|
|
@ -187,6 +187,31 @@ bes =
|
||||||
" Nil => 0 ;"
|
" Nil => 0 ;"
|
||||||
" };"
|
" };"
|
||||||
)
|
)
|
||||||
|
, testBe
|
||||||
|
"length function on int list infers correct signature"
|
||||||
|
( D.do
|
||||||
|
"data List () where {"
|
||||||
|
" Nil : List ()"
|
||||||
|
" Cons : Int -> List () -> List ()"
|
||||||
|
"};"
|
||||||
|
|
||||||
|
"length xs = case xs of {"
|
||||||
|
" Nil => 0 ;"
|
||||||
|
" Cons _ xs => 1 + length xs ;"
|
||||||
|
"};"
|
||||||
|
)
|
||||||
|
( D.do
|
||||||
|
"data List () where {"
|
||||||
|
" Nil : List ()"
|
||||||
|
" Cons : Int -> List () -> List ()"
|
||||||
|
"};"
|
||||||
|
|
||||||
|
"length : List () -> Int ;"
|
||||||
|
"length xs = case xs of {"
|
||||||
|
" Nil => 0 ;"
|
||||||
|
" Cons _ xs => 1 + length xs ;"
|
||||||
|
"};"
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
testSatisfy desc test satisfaction = specify desc $ run test `shouldSatisfy` satisfaction
|
testSatisfy desc test satisfaction = specify desc $ run test `shouldSatisfy` satisfaction
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue