Fixed previously incorrect type equality check, commented code, add test
This commit is contained in:
parent
85f31b129b
commit
b1d3e31efd
3 changed files with 335 additions and 311 deletions
|
|
@ -12,7 +12,6 @@ import Control.Monad.Except
|
|||
import Control.Monad.Identity (Identity, runIdentity)
|
||||
import Control.Monad.Reader
|
||||
import Control.Monad.State
|
||||
import Data.Bifunctor (second)
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Function (on)
|
||||
import Data.List (foldl')
|
||||
|
|
@ -30,16 +29,17 @@ initCtx = Ctx mempty
|
|||
initEnv = Env 0 'a' mempty mempty mempty
|
||||
|
||||
run :: Infer a -> Either Error a
|
||||
run = runC initEnv initCtx
|
||||
run = run' initEnv initCtx
|
||||
|
||||
runC :: Env -> Ctx -> Infer a -> Either Error a
|
||||
runC e c =
|
||||
run' :: Env -> Ctx -> Infer a -> Either Error a
|
||||
run' e c =
|
||||
runIdentity
|
||||
. runExceptT
|
||||
. flip runReaderT c
|
||||
. flip evalStateT e
|
||||
. runInfer
|
||||
|
||||
-- | Type check a program
|
||||
typecheck :: Program -> Either String (T.Program' Type)
|
||||
typecheck = onLeft msg . run . checkPrg
|
||||
where
|
||||
|
|
@ -47,20 +47,87 @@ typecheck = onLeft msg . run . checkPrg
|
|||
onLeft f (Left x) = Left $ f x
|
||||
onLeft _ (Right x) = Right x
|
||||
|
||||
checkPrg :: Program -> Infer (T.Program' Type)
|
||||
checkPrg (Program bs) = do
|
||||
preRun bs
|
||||
bs' <- checkDef bs
|
||||
return $ T.Program bs'
|
||||
|
||||
preRun :: [Def] -> Infer ()
|
||||
preRun [] = return ()
|
||||
preRun (x : xs) = case x of
|
||||
DSig (Sig n t) -> do
|
||||
collect (collectTVars t)
|
||||
gets (M.member (coerce n) . sigs)
|
||||
>>= flip
|
||||
when
|
||||
( uncatchableErr $ Aux.do
|
||||
"Duplicate signatures for function"
|
||||
quote $ printTree n
|
||||
)
|
||||
insertSig (coerce n) (Just t) >> preRun xs
|
||||
DBind (Bind n _ e) -> do
|
||||
collect (collectTVars e)
|
||||
s <- gets sigs
|
||||
case M.lookup (coerce n) s of
|
||||
Nothing -> insertSig (coerce n) Nothing >> preRun xs
|
||||
Just _ -> preRun xs
|
||||
DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs
|
||||
|
||||
checkDef :: [Def] -> Infer [T.Def' Type]
|
||||
checkDef [] = return []
|
||||
checkDef (x : xs) = case x of
|
||||
(DBind b) -> do
|
||||
b' <- checkBind b
|
||||
xs' <- checkDef xs
|
||||
return $ T.DBind b' : xs'
|
||||
(DData d) -> do
|
||||
xs' <- checkDef xs
|
||||
return $ T.DData (coerceData d) : xs'
|
||||
(DSig _) -> checkDef xs
|
||||
where
|
||||
coerceData (Data t injs) =
|
||||
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
|
||||
|
||||
checkBind :: Bind -> Infer (T.Bind' Type)
|
||||
checkBind (Bind name args e) = do
|
||||
let lambda = makeLambda e (reverse (coerce args))
|
||||
(sub0, (e, lambda_t)) <- inferExp lambda
|
||||
s <- gets sigs
|
||||
case M.lookup (coerce name) s of
|
||||
Just (Just t') -> do
|
||||
let fsig = apply sub0 t'
|
||||
sub1 <- liftEither $ runIdentity $ runExceptT $ execStateT (typeEq fsig lambda_t) mempty
|
||||
sub2 <- liftEither $ runIdentity $ runExceptT $ execStateT (typeEq lambda_t fsig) mempty
|
||||
unless
|
||||
(lambda_t == apply sub1 fsig && apply sub2 lambda_t == fsig)
|
||||
( uncatchableErr $ Aux.do
|
||||
"Inferred type"
|
||||
quote $ printTree lambda_t
|
||||
"does not match specified type"
|
||||
quote $ printTree t'
|
||||
)
|
||||
return $ T.Bind (coerce name, lambda_t) [] (e, lambda_t)
|
||||
_ -> do
|
||||
insertSig (coerce name) (Just lambda_t)
|
||||
return (T.Bind (coerce name, lambda_t) [] (e, lambda_t))
|
||||
|
||||
checkData :: Data -> Infer ()
|
||||
checkData err@(Data typ injs) = do
|
||||
(name, tvars) <- go typ
|
||||
dataErr (mapM_ (\i -> typecheckInj i name tvars) injs) err
|
||||
dataErr (mapM_ (\i -> checkInj i name tvars) injs) err
|
||||
where
|
||||
go = \case
|
||||
TData name typs
|
||||
| Right tvars' <- mapM toTVar typs ->
|
||||
pure (name, tvars')
|
||||
TAll _ _ -> uncatchableErr "Explicit foralls not allowed, for now"
|
||||
_ -> uncatchableErr $ unwords ["Bad data type definition: ", printTree typ]
|
||||
_ ->
|
||||
uncatchableErr $
|
||||
unwords ["Bad data type definition: ", printTree typ]
|
||||
|
||||
typecheckInj :: Inj -> UIdent -> [TVar] -> Infer ()
|
||||
typecheckInj (Inj c inj_typ) name tvars
|
||||
checkInj :: Inj -> UIdent -> [TVar] -> Infer ()
|
||||
checkInj (Inj c inj_typ) name tvars
|
||||
| Right False <- boundTVars tvars inj_typ =
|
||||
catchableErr "Unbound type variables"
|
||||
| TData name' typs <- returnType inj_typ
|
||||
|
|
@ -108,109 +175,11 @@ returnType :: Type -> Type
|
|||
returnType (TFun _ t2) = returnType t2
|
||||
returnType a = a
|
||||
|
||||
checkPrg :: Program -> Infer (T.Program' Type)
|
||||
checkPrg (Program bs) = do
|
||||
preRun bs
|
||||
bs' <- checkDef bs
|
||||
return $ T.Program bs'
|
||||
|
||||
preRun :: [Def] -> Infer ()
|
||||
preRun [] = return ()
|
||||
preRun (x : xs) = case x of
|
||||
DSig (Sig n t) -> do
|
||||
collect (collectTVars t)
|
||||
gets (M.member (coerce n) . sigs)
|
||||
>>= flip
|
||||
when
|
||||
( uncatchableErr $ Aux.do
|
||||
"Duplicate signatures for function"
|
||||
quote $ printTree n
|
||||
)
|
||||
insertSig (coerce n) (Just t) >> preRun xs
|
||||
DBind (Bind n _ e) -> do
|
||||
collect (collectTVars e)
|
||||
s <- gets sigs
|
||||
case M.lookup (coerce n) s of
|
||||
Nothing -> insertSig (coerce n) Nothing >> preRun xs
|
||||
Just _ -> preRun xs
|
||||
DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs
|
||||
|
||||
checkDef :: [Def] -> Infer [T.Def' Type]
|
||||
checkDef [] = return []
|
||||
checkDef (x : xs) = case x of
|
||||
(DBind b) -> do
|
||||
b' <- checkBind b
|
||||
xs' <- checkDef xs
|
||||
return $ T.DBind b' : xs'
|
||||
(DData d) -> do
|
||||
xs' <- checkDef xs
|
||||
return $ T.DData (coerceData d) : xs'
|
||||
(DSig _) -> checkDef xs
|
||||
where
|
||||
coerceData (Data t injs) =
|
||||
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
|
||||
|
||||
checkBind :: Bind -> Infer (T.Bind' Type)
|
||||
checkBind (Bind name args e) = do
|
||||
let lambda = makeLambda e (reverse (coerce args))
|
||||
(e, lambda_t) <- inferExp lambda
|
||||
s <- gets sigs
|
||||
case M.lookup (coerce name) s of
|
||||
Just (Just t') -> do
|
||||
sub1 <- unify lambda_t t'
|
||||
sub2 <- unify t' lambda_t
|
||||
unless
|
||||
(apply sub1 lambda_t == t' && lambda_t == apply sub2 t')
|
||||
( uncatchableErr $ Aux.do
|
||||
"Inferred type"
|
||||
quote $ printTree lambda_t
|
||||
"does not match specified type"
|
||||
quote $ printTree t'
|
||||
)
|
||||
return $ T.Bind (coerce name, t') [] (e, lambda_t)
|
||||
_ -> do
|
||||
insertSig (coerce name) (Just lambda_t)
|
||||
return (T.Bind (coerce name, lambda_t) [] (e, lambda_t))
|
||||
|
||||
typeEq :: Type -> Type -> Bool
|
||||
typeEq (TFun l r) (TFun l' r') = typeEq l l' && typeEq r r'
|
||||
typeEq (TLit a) (TLit b) = a == b
|
||||
typeEq (TData name a) (TData name' b) =
|
||||
length a == length b
|
||||
&& name == name'
|
||||
&& and (zipWith typeEq a b)
|
||||
typeEq (TAll _ t1) t2 = t1 `typeEq` t2
|
||||
typeEq t1 (TAll _ t2) = t1 `typeEq` t2
|
||||
typeEq (TVar _) (TVar _) = True
|
||||
typeEq _ _ = False
|
||||
|
||||
skolemize :: Type -> Type
|
||||
skolemize (TVar (MkTVar a)) = TEVar (MkTEVar $ coerce a)
|
||||
skolemize (TAll x t) = TAll x (skolemize t)
|
||||
skolemize (TFun t1 t2) = (TFun `on` skolemize) t1 t2
|
||||
skolemize t = t
|
||||
|
||||
isMoreSpecificOrEq :: Type -> Type -> Bool
|
||||
isMoreSpecificOrEq t1 (TAll _ t2) = isMoreSpecificOrEq t1 t2
|
||||
isMoreSpecificOrEq (TFun a b) (TFun c d) =
|
||||
isMoreSpecificOrEq a c && isMoreSpecificOrEq b d
|
||||
isMoreSpecificOrEq (TData n1 ts1) (TData n2 ts2) =
|
||||
n1 == n2
|
||||
&& length ts1 == length ts2
|
||||
&& and (zipWith isMoreSpecificOrEq ts1 ts2)
|
||||
isMoreSpecificOrEq _ (TVar _) = True
|
||||
isMoreSpecificOrEq a b = a == b
|
||||
|
||||
isPoly :: Type -> Bool
|
||||
isPoly (TAll _ _) = True
|
||||
isPoly (TVar _) = True
|
||||
isPoly _ = False
|
||||
|
||||
inferExp :: Exp -> Infer (T.ExpT' Type)
|
||||
inferExp :: Exp -> Infer (Subst, T.ExpT' Type)
|
||||
inferExp e = do
|
||||
(s, (e', t)) <- algoW e
|
||||
let subbed = apply s t
|
||||
return $ second (const subbed) (e', t)
|
||||
return (s, (e', subbed))
|
||||
|
||||
class CollectTVars a where
|
||||
collectTVars :: a -> Set T.Ident
|
||||
|
|
@ -223,7 +192,8 @@ instance CollectTVars Type where
|
|||
collectTVars (TVar (MkTVar i)) = S.singleton (coerce i)
|
||||
collectTVars (TAll _ t) = collectTVars t
|
||||
collectTVars (TFun t1 t2) = (S.union `on` collectTVars) t1 t2
|
||||
collectTVars (TData _ ts) = foldl' (\acc x -> acc `S.union` collectTVars x) S.empty ts
|
||||
collectTVars (TData _ ts) =
|
||||
foldl' (\acc x -> acc `S.union` collectTVars x) S.empty ts
|
||||
collectTVars _ = S.empty
|
||||
|
||||
collect :: Set T.Ident -> Infer ()
|
||||
|
|
@ -232,7 +202,7 @@ collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st})
|
|||
algoW :: Exp -> Infer (Subst, T.ExpT' Type)
|
||||
algoW = \case
|
||||
err@(EAnn e t) -> do
|
||||
(s1, (e', t')) <- exprErr (algoW e) err
|
||||
(sub0, (e', t')) <- exprErr (algoW e) err
|
||||
sub1 <- unify t t'
|
||||
sub2 <- unify t' t
|
||||
unless
|
||||
|
|
@ -243,8 +213,7 @@ algoW = \case
|
|||
"does not match inferred type"
|
||||
quote $ printTree t'
|
||||
)
|
||||
s2 <- exprErr (unify t t') err
|
||||
let comp = s2 `compose` s1
|
||||
let comp = sub2 `compose` sub1 `compose` sub0
|
||||
return (comp, apply comp (e', t))
|
||||
|
||||
-- \| ------------------
|
||||
|
|
@ -257,7 +226,9 @@ algoW = \case
|
|||
EVar i -> do
|
||||
var <- asks vars
|
||||
case M.lookup (coerce i) var of
|
||||
Just t -> inst t >>= \x -> return (nullSubst, (T.EVar $ coerce i, x))
|
||||
Just t ->
|
||||
inst t >>= \x ->
|
||||
return (nullSubst, (T.EVar $ coerce i, x))
|
||||
Nothing -> do
|
||||
sig <- gets sigs
|
||||
case M.lookup (coerce i) sig of
|
||||
|
|
@ -266,7 +237,10 @@ algoW = \case
|
|||
fr <- fresh
|
||||
insertSig (coerce i) (Just fr)
|
||||
return (nullSubst, (T.EVar $ coerce i, fr))
|
||||
Nothing -> uncatchableErr $ "Unbound variable: " <> printTree i
|
||||
Nothing ->
|
||||
uncatchableErr $
|
||||
"Unbound variable: "
|
||||
<> printTree i
|
||||
EInj i -> do
|
||||
constr <- gets injections
|
||||
case M.lookup (coerce i) constr of
|
||||
|
|
@ -283,14 +257,11 @@ algoW = \case
|
|||
|
||||
err@(EAbs name e) -> do
|
||||
fr <- fresh
|
||||
exprErr
|
||||
( withBinding (coerce name) fr $ do
|
||||
(s1, (e', t')) <- exprErr (algoW e) err
|
||||
let varType = apply s1 fr
|
||||
let newArr = TFun varType t'
|
||||
return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr))
|
||||
)
|
||||
err
|
||||
withBinding (coerce name) fr $ do
|
||||
(s1, (e', t')) <- exprErr (algoW e) err
|
||||
let varType = apply s1 fr
|
||||
let newArr = TFun varType t'
|
||||
return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr))
|
||||
|
||||
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
|
||||
-- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
|
||||
|
|
@ -338,29 +309,120 @@ algoW = \case
|
|||
(s2, (e1', t2)) <- algoW e1
|
||||
let comp = s2 `compose` s1
|
||||
return (comp, apply comp (T.ELet bind' (e1', t2), t2))
|
||||
|
||||
-- \| TODO: Add judgement
|
||||
ECase caseExpr injs -> do
|
||||
(sub, (e', t)) <- algoW caseExpr
|
||||
(subst, injs, ret_t) <- checkCase t injs
|
||||
let comp = subst `compose` sub
|
||||
return (comp, apply comp (T.ECase (e', t) injs, ret_t))
|
||||
|
||||
makeLambda :: Exp -> [T.Ident] -> Exp
|
||||
makeLambda = foldl (flip (EAbs . coerce))
|
||||
checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type)
|
||||
checkCase _ [] = catchableErr "Atleast one case required"
|
||||
checkCase expT brnchs = do
|
||||
(subs, branchTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs
|
||||
let sub0 = composeAll subs
|
||||
(sub1, _) <-
|
||||
foldM
|
||||
( \(sub, acc) x ->
|
||||
(\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc
|
||||
)
|
||||
(nullSubst, expT)
|
||||
branchTs
|
||||
(sub2, returns_type) <-
|
||||
foldM
|
||||
( \(sub, acc) x ->
|
||||
(\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc
|
||||
)
|
||||
(nullSubst, head returns)
|
||||
(tail returns)
|
||||
let comp = sub2 `compose` sub1 `compose` sub0
|
||||
return (comp, apply comp injs, apply comp returns_type)
|
||||
|
||||
inferBranch :: Branch -> Infer (Subst, Type, T.Branch' Type, Type)
|
||||
inferBranch (Branch pat expr) = do
|
||||
newPat@(pat, branchT) <- inferPattern pat
|
||||
(sub, newExp@(_, exprT)) <- withPattern pat (algoW expr)
|
||||
return
|
||||
( sub
|
||||
, apply sub branchT
|
||||
, T.Branch (apply sub newPat) (apply sub newExp)
|
||||
, apply sub exprT
|
||||
)
|
||||
|
||||
inferPattern :: Pattern -> Infer (T.Pattern' Type, Type)
|
||||
inferPattern = \case
|
||||
PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt)
|
||||
PInj constr patterns -> do
|
||||
t <- gets (M.lookup (coerce constr) . injections)
|
||||
t <-
|
||||
maybeToRightM
|
||||
( Error
|
||||
( Aux.do
|
||||
"Constructor:"
|
||||
quote $ printTree constr
|
||||
"does not exist"
|
||||
)
|
||||
True
|
||||
)
|
||||
t
|
||||
let numArgs = typeLength t - 1
|
||||
let (vs, ret) = fromJust (unsnoc $ flattenType t)
|
||||
patterns <- mapM inferPattern patterns
|
||||
unless
|
||||
(length patterns == numArgs)
|
||||
( catchableErr $ Aux.do
|
||||
"The constructor"
|
||||
quote $ printTree constr
|
||||
" should have "
|
||||
show numArgs
|
||||
" arguments but has been given "
|
||||
show (length patterns)
|
||||
)
|
||||
sub <- composeAll <$> zipWithM unify vs (map snd patterns)
|
||||
return
|
||||
( T.PInj (coerce constr) (apply sub (map fst patterns))
|
||||
, apply sub ret
|
||||
)
|
||||
PCatch -> (T.PCatch,) <$> fresh
|
||||
PEnum p -> do
|
||||
t <- gets (M.lookup (coerce p) . injections)
|
||||
t <-
|
||||
maybeToRightM
|
||||
( Error
|
||||
( Aux.do
|
||||
"Constructor:"
|
||||
quote $ printTree p
|
||||
"does not exist"
|
||||
)
|
||||
True
|
||||
)
|
||||
t
|
||||
unless
|
||||
(typeLength t == 1)
|
||||
( catchableErr $ Aux.do
|
||||
"The constructor"
|
||||
quote $ printTree p
|
||||
" should have "
|
||||
show (typeLength t - 1)
|
||||
" arguments but has been given 0"
|
||||
)
|
||||
let (TData _data _ts) = t -- nasty nasty
|
||||
frs <- mapM (const fresh) _ts
|
||||
return (T.PEnum $ coerce p, TData _data frs)
|
||||
PVar x -> do
|
||||
fr <- fresh
|
||||
let pvar = T.PVar (coerce x, fr)
|
||||
return (pvar, fr)
|
||||
|
||||
-- | Unify two types producing a new substitution
|
||||
unify :: Type -> Type -> Infer Subst
|
||||
unify t0 t1 = do
|
||||
unify t0 t1 =
|
||||
case (t0, t1) of
|
||||
(TFun a b, TFun c d) -> do
|
||||
s1 <- unify a c
|
||||
s2 <- unify (apply s1 b) (apply s1 d)
|
||||
return $ s1 `compose` s2
|
||||
----------- TODO: BE CAREFUL!!!! THIS IS PROBABLY WRONG!!! -----------
|
||||
(TVar (T.MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t
|
||||
(t@(TData _ _), TVar (T.MkTVar b)) -> return $ M.singleton (coerce b) t
|
||||
-------------------------------------------------------------------
|
||||
(TVar (T.MkTVar a), t) -> occurs (coerce a) t
|
||||
(t, TVar (T.MkTVar b)) -> occurs (coerce b) t
|
||||
(TAll _ t, b) -> unify t b
|
||||
|
|
@ -422,7 +484,12 @@ occurs i t =
|
|||
)
|
||||
else return $ M.singleton i t
|
||||
|
||||
-- | Generalize a type over all free variables in the substitution set
|
||||
{- | Generalize a type over all free variables in the substitution set
|
||||
Used for let bindings to allow expression that do not type check in
|
||||
equivalent lambda expressions:
|
||||
Type checks: let f = \x. x in (f True, f 'a')
|
||||
Does not type check: (\f. (f True, f 'a')) (\x. x)
|
||||
-}
|
||||
generalize :: Map T.Ident Type -> Type -> Type
|
||||
generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t)
|
||||
where
|
||||
|
|
@ -446,15 +513,27 @@ inst = \case
|
|||
TFun t1 t2 -> TFun <$> inst t1 <*> inst t2
|
||||
rest -> return rest
|
||||
|
||||
-- | Compose two substitution sets
|
||||
compose :: Subst -> Subst -> Subst
|
||||
compose m1 m2 = M.map (apply m1) m2 `M.union` m1
|
||||
|
||||
composeAll :: [Subst] -> Subst
|
||||
composeAll = foldl' compose nullSubst
|
||||
|
||||
-- TODO: Split this class into two separate classes, one for free variables
|
||||
-- and one for applying substitutions
|
||||
-- | Generate a new fresh variable
|
||||
fresh :: Infer Type
|
||||
fresh = do
|
||||
c <- gets nextChar
|
||||
n <- gets count
|
||||
taken <- gets takenTypeVars
|
||||
if c == 'z'
|
||||
then do
|
||||
modify (\st -> st{count = succ (count st), nextChar = 'a'})
|
||||
else modify (\st -> st{nextChar = next (nextChar st)})
|
||||
if coerce [c] `S.member` taken
|
||||
then do
|
||||
fresh
|
||||
else
|
||||
if n == 0
|
||||
then return . TVar . T.MkTVar $ LIdent [c]
|
||||
else return . TVar . T.MkTVar . LIdent $ c : show n
|
||||
where
|
||||
next :: Char -> Char
|
||||
next 'z' = 'a'
|
||||
next a = succ a
|
||||
|
||||
-- | A class for substitutions
|
||||
class SubstType t where
|
||||
|
|
@ -468,7 +547,8 @@ class FreeVars t where
|
|||
instance FreeVars Type where
|
||||
free :: Type -> Set T.Ident
|
||||
free (TVar (T.MkTVar a)) = S.singleton (coerce a)
|
||||
free (TAll (T.MkTVar bound) t) = S.singleton (coerce bound) `S.intersection` free t
|
||||
free (TAll (T.MkTVar bound) t) =
|
||||
S.singleton (coerce bound) `S.intersection` free t
|
||||
free (TLit _) = mempty
|
||||
free (TFun a b) = free a `S.union` free b
|
||||
free (TData _ a) = free a
|
||||
|
|
@ -540,27 +620,19 @@ instance SubstType (T.Id' Type) where
|
|||
nullSubst :: Subst
|
||||
nullSubst = M.empty
|
||||
|
||||
-- | Generate a new fresh variable and increment the state counter
|
||||
fresh :: Infer Type
|
||||
fresh = do
|
||||
c <- gets nextChar
|
||||
n <- gets count
|
||||
taken <- gets takenTypeVars
|
||||
if c == 'z'
|
||||
then do
|
||||
modify (\st -> st{count = succ (count st), nextChar = 'a'})
|
||||
else modify (\st -> st{nextChar = next (nextChar st)})
|
||||
if coerce [c] `S.member` taken
|
||||
then do
|
||||
fresh
|
||||
else
|
||||
if n == 0
|
||||
then return . TVar . T.MkTVar $ LIdent [c]
|
||||
else return . TVar . T.MkTVar . LIdent $ c : show n
|
||||
where
|
||||
next :: Char -> Char
|
||||
next 'z' = 'a'
|
||||
next a = succ a
|
||||
-- | Compose two substitution sets
|
||||
compose :: Subst -> Subst -> Subst
|
||||
compose m1 m2 = M.map (apply m1) m2 `M.union` m1
|
||||
|
||||
-- | Compose a list of substitution sets into one
|
||||
composeAll :: [Subst] -> Subst
|
||||
composeAll = foldl' compose nullSubst
|
||||
|
||||
{- | Convert a function with arguments to its pointfree version
|
||||
> makeLambda (add x y = x + y) = add = \x. \y. x + y
|
||||
-}
|
||||
makeLambda :: Exp -> [T.Ident] -> Exp
|
||||
makeLambda = foldl (flip (EAbs . coerce))
|
||||
|
||||
-- | Run the monadic action with an additional binding
|
||||
withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> Type -> m a -> m a
|
||||
|
|
@ -571,49 +643,8 @@ withBindings :: (Monad m, MonadReader Ctx m) => [(T.Ident, Type)] -> m a -> m a
|
|||
withBindings xs =
|
||||
local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs})
|
||||
|
||||
-- | Insert a function signature into the environment
|
||||
insertSig :: T.Ident -> Maybe Type -> Infer ()
|
||||
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
|
||||
|
||||
-- | Insert a constructor with its data type
|
||||
insertInj :: T.Ident -> Type -> Infer ()
|
||||
insertInj i t =
|
||||
modify (\st -> st{injections = M.insert i t (injections st)})
|
||||
|
||||
existInj :: T.Ident -> Infer (Maybe Type)
|
||||
existInj n = gets (M.lookup n . injections)
|
||||
|
||||
-------- PATTERN MATCHING ---------
|
||||
|
||||
checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type)
|
||||
checkCase _ [] = catchableErr "Atleast one case required"
|
||||
checkCase expT brnchs = do
|
||||
(subs, injTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs
|
||||
let sub0 = composeAll subs
|
||||
(sub1, _) <-
|
||||
foldM
|
||||
( \(sub, acc) x ->
|
||||
(\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc
|
||||
)
|
||||
(nullSubst, expT)
|
||||
injTs
|
||||
(sub2, returns_type) <-
|
||||
foldM
|
||||
( \(sub, acc) x ->
|
||||
(\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc
|
||||
)
|
||||
(nullSubst, head returns)
|
||||
(tail returns)
|
||||
let comp = sub2 `compose` sub1 `compose` sub0
|
||||
return (comp, apply comp injs, apply comp returns_type)
|
||||
|
||||
inferBranch :: Branch -> Infer (Subst, Type, T.Branch' Type, Type)
|
||||
inferBranch (Branch pat expr) = do
|
||||
newPat@(pat, branchT) <- inferPattern pat
|
||||
(sub, newExp@(_, exprT)) <- withPattern pat (algoW expr)
|
||||
return (sub, apply sub branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT)
|
||||
|
||||
withPattern :: T.Pattern' Type -> Infer a -> Infer a
|
||||
-- | Run the monadic action with a pattern
|
||||
withPattern :: (Monad m, MonadReader Ctx m) => T.Pattern' Type -> m a -> m a
|
||||
withPattern p ma = case p of
|
||||
T.PVar (x, t) -> withBinding x t ma
|
||||
T.PInj _ ps -> foldl' (flip withPattern) ma ps
|
||||
|
|
@ -621,74 +652,27 @@ withPattern p ma = case p of
|
|||
T.PCatch -> ma
|
||||
T.PEnum _ -> ma
|
||||
|
||||
inferPattern :: Pattern -> Infer (T.Pattern' Type, Type)
|
||||
inferPattern = \case
|
||||
PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt)
|
||||
PInj constr patterns -> do
|
||||
t <- gets (M.lookup (coerce constr) . injections)
|
||||
t <-
|
||||
maybeToRightM
|
||||
( Error
|
||||
( Aux.do
|
||||
"Constructor:"
|
||||
quote $ printTree constr
|
||||
"does not exist"
|
||||
)
|
||||
True
|
||||
)
|
||||
t
|
||||
let numArgs = typeLength t - 1
|
||||
let (vs, ret) = fromJust (unsnoc $ flattenType t)
|
||||
patterns <- mapM inferPattern patterns
|
||||
unless
|
||||
(length patterns == numArgs)
|
||||
( catchableErr $ Aux.do
|
||||
"The constructor"
|
||||
quote $ printTree constr
|
||||
" should have "
|
||||
show numArgs
|
||||
" arguments but has been given "
|
||||
show (length patterns)
|
||||
)
|
||||
sub <- composeAll <$> zipWithM unify vs (map snd patterns)
|
||||
return (T.PInj (coerce constr) (apply sub (map fst patterns)), apply sub ret)
|
||||
PCatch -> (T.PCatch,) <$> fresh
|
||||
PEnum p -> do
|
||||
t <- gets (M.lookup (coerce p) . injections)
|
||||
t <-
|
||||
maybeToRightM
|
||||
( Error
|
||||
( Aux.do
|
||||
"Constructor:"
|
||||
quote $ printTree p
|
||||
"does not exist"
|
||||
)
|
||||
True
|
||||
)
|
||||
t
|
||||
unless
|
||||
(typeLength t == 1)
|
||||
( catchableErr $ Aux.do
|
||||
"The constructor"
|
||||
quote $ printTree p
|
||||
" should have "
|
||||
show (typeLength t - 1)
|
||||
" arguments but has been given 0"
|
||||
)
|
||||
let (TData _data _ts) = t -- nasty nasty
|
||||
frs <- mapM (const fresh) _ts
|
||||
return (T.PEnum $ coerce p, TData _data frs)
|
||||
PVar x -> do
|
||||
fr <- fresh
|
||||
let pvar = T.PVar (coerce x, fr)
|
||||
return (pvar, fr)
|
||||
-- | Insert a function signature into the environment
|
||||
insertSig :: T.Ident -> Maybe Type -> Infer ()
|
||||
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
|
||||
|
||||
-- | Insert a constructor into the start with its type
|
||||
insertInj :: T.Ident -> Type -> Infer ()
|
||||
insertInj i t =
|
||||
modify (\st -> st{injections = M.insert i t (injections st)})
|
||||
|
||||
{- | Check if an injection (constructor of data type)
|
||||
with an equivalent name has been declared already
|
||||
-}
|
||||
existInj :: T.Ident -> Infer (Maybe Type)
|
||||
existInj n = gets (M.lookup n . injections)
|
||||
|
||||
flattenType :: Type -> [Type]
|
||||
flattenType (TFun a b) = flattenType a <> flattenType b
|
||||
flattenType a = [a]
|
||||
|
||||
typeLength :: Type -> Int
|
||||
typeLength (TFun a b) = typeLength a + typeLength b
|
||||
typeLength (TFun _ b) = 1 + typeLength b
|
||||
typeLength _ = 1
|
||||
|
||||
litType :: Lit -> Type
|
||||
|
|
@ -698,23 +682,63 @@ litType (LChar _) = char
|
|||
int = TLit "Int"
|
||||
char = TLit "Char"
|
||||
|
||||
partitionType ::
|
||||
Int -> -- Number of parameters to apply
|
||||
Type ->
|
||||
([Type], Type)
|
||||
partitionType = go []
|
||||
where
|
||||
go acc 0 t = (acc, t)
|
||||
go acc i t = case t of
|
||||
TAll tvar t' -> second (TAll tvar) $ go acc i t'
|
||||
TFun t1 t2 -> go (acc <> [t1]) (i - 1) t2
|
||||
_ -> error "Number of parameters and type doesn't match"
|
||||
typeEq :: Type -> Type -> StateT Subst (ExceptT Error Identity) ()
|
||||
typeEq (TVar (T.MkTVar a)) t@(TVar _) = do
|
||||
st <- get
|
||||
case M.lookup (coerce a) st of
|
||||
Nothing -> put $ M.insert (coerce a) t st
|
||||
Just t' -> unless (t == t') (catchableErr "TYPE MISMATCH")
|
||||
typeEq (TFun l r) (TFun l' r') = typeEq l l' *> typeEq r r'
|
||||
typeEq (TAll _ l) (TAll _ r) = typeEq l r
|
||||
typeEq (TLit a) (TLit b) = unless (a == b) (catchableErr "TYPE MISMATCH")
|
||||
typeEq (TData nameL tL) (TData nameR tR) = do
|
||||
unless (nameL == nameR) (catchableErr "TYPE MISMATCH")
|
||||
zipWithM_ typeEq tL tR
|
||||
typeEq (TEVar _) (TEVar _) = catchableErr "TYPE MISMATCH"
|
||||
typeEq _ _ = catchableErr "TYPE MISMATCH"
|
||||
|
||||
exprErr :: Infer a -> Exp -> Infer a
|
||||
exprErr ma exp = catchError ma (\x -> if x.catchable then throwError (x{msg = x.msg <> " in expression: \n" <> printTree exp, catchable = False}) else throwError x)
|
||||
{- | Catch an error if possible and add the given
|
||||
expression as addition to the error message
|
||||
-}
|
||||
exprErr :: (Monad m, MonadError Error m) => m a -> Exp -> m a
|
||||
exprErr ma exp =
|
||||
catchError
|
||||
ma
|
||||
( \x ->
|
||||
if x.catchable
|
||||
then
|
||||
throwError
|
||||
( x
|
||||
{ msg =
|
||||
x.msg
|
||||
<> " in expression: \n"
|
||||
<> printTree exp
|
||||
, catchable = False
|
||||
}
|
||||
)
|
||||
else throwError x
|
||||
)
|
||||
|
||||
{- | Catch an error if possible and add the given
|
||||
data as addition to the error message
|
||||
-}
|
||||
dataErr :: Infer a -> Data -> Infer a
|
||||
dataErr ma d = catchError ma (\x -> if x.catchable then throwError (x{msg = x.msg <> " in data: \n" <> printTree d}) else throwError (x{catchable = False}))
|
||||
dataErr ma d =
|
||||
catchError
|
||||
ma
|
||||
( \x ->
|
||||
if x.catchable
|
||||
then
|
||||
throwError
|
||||
( x
|
||||
{ msg =
|
||||
x.msg
|
||||
<> " in data: \n"
|
||||
<> printTree d
|
||||
}
|
||||
)
|
||||
else throwError (x{catchable = False})
|
||||
)
|
||||
|
||||
unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d])
|
||||
unzip4 =
|
||||
|
|
@ -737,6 +761,7 @@ data Env = Env
|
|||
deriving (Show)
|
||||
|
||||
data Error = Error {msg :: String, catchable :: Bool}
|
||||
deriving (Show)
|
||||
type Subst = Map T.Ident Type
|
||||
|
||||
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue