Fixed bug in EApp, cleaned a bit, added todo for disallowing mutual recursion
This commit is contained in:
parent
aa1ff630a5
commit
c4931c3996
1 changed files with 87 additions and 129 deletions
|
|
@ -27,22 +27,7 @@ import Grammar.Abs
|
||||||
import Grammar.Print (printTree)
|
import Grammar.Print (printTree)
|
||||||
import TypeChecker.TypeCheckerIr qualified as T
|
import TypeChecker.TypeCheckerIr qualified as T
|
||||||
|
|
||||||
-- TODO: Save all substition sets encountered in the program and apply
|
-- TODO: Disallow mutual recursion
|
||||||
-- to all top level functions in the end.
|
|
||||||
|
|
||||||
initCtx = Ctx mempty
|
|
||||||
initEnv = Env 0 'a' mempty mempty mempty "" mempty mempty mempty
|
|
||||||
|
|
||||||
run :: Infer a -> Either Error a
|
|
||||||
run = run' initEnv initCtx
|
|
||||||
|
|
||||||
run' :: Env -> Ctx -> Infer a -> Either Error a
|
|
||||||
run' e c =
|
|
||||||
runIdentity
|
|
||||||
. runExceptT
|
|
||||||
. flip runReaderT c
|
|
||||||
. flip evalStateT e
|
|
||||||
. runInfer
|
|
||||||
|
|
||||||
-- | Type check a program
|
-- | Type check a program
|
||||||
typecheck :: Program -> Either String (T.Program' Type)
|
typecheck :: Program -> Either String (T.Program' Type)
|
||||||
|
|
@ -73,29 +58,23 @@ preRun [] = return ()
|
||||||
preRun (x : xs) = case x of
|
preRun (x : xs) = case x of
|
||||||
DSig (Sig n t) -> do
|
DSig (Sig n t) -> do
|
||||||
collect (collectTVars t)
|
collect (collectTVars t)
|
||||||
gets (M.member (coerce n) . sigs)
|
duplicateDecl n $ Aux.do
|
||||||
>>= flip
|
"Multiple signatures of function"
|
||||||
when
|
|
||||||
( uncatchableErr $ Aux.do
|
|
||||||
"Duplicate signatures of function"
|
|
||||||
quote $ printTree n
|
quote $ printTree n
|
||||||
)
|
|
||||||
insertSig (coerce n) (Just t) >> preRun xs
|
insertSig (coerce n) (Just t) >> preRun xs
|
||||||
DBind (Bind n _ e) -> do
|
DBind (Bind n _ e) -> do
|
||||||
binds <- gets declaredBinds
|
duplicateDecl n $ Aux.do
|
||||||
when
|
"Multiple declarations of function"
|
||||||
(coerce n `S.member` binds)
|
|
||||||
( uncatchableErr $ Aux.do
|
|
||||||
"Duplicate declarations of function"
|
|
||||||
quote $ printTree n
|
quote $ printTree n
|
||||||
)
|
|
||||||
modify (\st -> st{declaredBinds = S.insert (coerce n) st.declaredBinds})
|
|
||||||
collect (collectTVars e)
|
collect (collectTVars e)
|
||||||
s <- gets sigs
|
s <- gets sigs
|
||||||
case M.lookup (coerce n) s of
|
case M.lookup (coerce n) s of
|
||||||
Nothing -> insertSig (coerce n) Nothing >> preRun xs
|
Nothing -> insertSig (coerce n) Nothing >> preRun xs
|
||||||
Just _ -> preRun xs
|
Just _ -> preRun xs
|
||||||
DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs
|
DData d@(Data t _) -> let collected = collect (collectTVars t) in checkData d collected >> preRun xs
|
||||||
|
where
|
||||||
|
-- Check if function body / signature has been declared already
|
||||||
|
duplicateDecl n msg = gets (M.member (coerce n) . sigs) >>= flip when (uncatchableErr msg)
|
||||||
|
|
||||||
checkDef :: [Def] -> Infer [T.Def' Type]
|
checkDef :: [Def] -> Infer [T.Def' Type]
|
||||||
checkDef [] = return []
|
checkDef [] = return []
|
||||||
|
|
@ -126,10 +105,10 @@ checkBind bind@(Bind name args e) = do
|
||||||
insertSig (coerce name) (Just lambda_t)
|
insertSig (coerce name) (Just lambda_t)
|
||||||
return (T.Bind (coerce name, lambda_t) [] (e, lambda_t))
|
return (T.Bind (coerce name, lambda_t) [] (e, lambda_t))
|
||||||
|
|
||||||
checkData :: Data -> Infer ()
|
checkData :: (MonadReader Ctx m, Monad m, MonadError Error m) => Data -> m () -> m ()
|
||||||
checkData err@(Data typ injs) = do
|
checkData err@(Data typ injs) ma = do
|
||||||
(name, tvars) <- go typ
|
(name, tvars) <- go typ
|
||||||
dataErr (mapM_ (\i -> checkInj i name tvars) injs) err
|
dataErr (mapM_ (\i -> checkInj i name tvars ma) injs) err
|
||||||
where
|
where
|
||||||
go = \case
|
go = \case
|
||||||
TData name typs
|
TData name typs
|
||||||
|
|
@ -140,8 +119,8 @@ checkData err@(Data typ injs) = do
|
||||||
uncatchableErr $
|
uncatchableErr $
|
||||||
unwords ["Bad data type definition: ", printTree typ]
|
unwords ["Bad data type definition: ", printTree typ]
|
||||||
|
|
||||||
checkInj :: Inj -> UIdent -> [TVar] -> Infer ()
|
checkInj :: (MonadError Error m, MonadReader Ctx m, Monad m) => Inj -> UIdent -> [TVar] -> m a -> m a
|
||||||
checkInj (Inj c inj_typ) name tvars
|
checkInj (Inj c inj_typ) name tvars ma
|
||||||
| Right False <- boundTVars tvars inj_typ =
|
| Right False <- boundTVars tvars inj_typ =
|
||||||
catchableErr "Unbound type variables"
|
catchableErr "Unbound type variables"
|
||||||
| TData name' typs <- returnType inj_typ
|
| TData name' typs <- returnType inj_typ
|
||||||
|
|
@ -156,7 +135,7 @@ checkInj (Inj c inj_typ) name tvars
|
||||||
"with type"
|
"with type"
|
||||||
quote $ printTree t
|
quote $ printTree t
|
||||||
"already exist"
|
"already exist"
|
||||||
Nothing -> insertInj (coerce c) inj_typ
|
Nothing -> insertInj (coerce c) inj_typ ma
|
||||||
| otherwise =
|
| otherwise =
|
||||||
uncatchableErr $
|
uncatchableErr $
|
||||||
unwords
|
unwords
|
||||||
|
|
@ -246,11 +225,11 @@ algoW = \case
|
||||||
return (nullSubst, (T.EVar $ coerce i, x))
|
return (nullSubst, (T.EVar $ coerce i, x))
|
||||||
Nothing -> do
|
Nothing -> do
|
||||||
sig <- gets sigs
|
sig <- gets sigs
|
||||||
|
cb <- gets currentBind
|
||||||
case M.lookup (coerce i) sig of
|
case M.lookup (coerce i) sig of
|
||||||
Just (Just t) -> return (nullSubst, (T.EVar $ coerce i, t))
|
Just (Just t) -> return (nullSubst, (T.EVar $ coerce i, t))
|
||||||
Just Nothing -> do
|
Just Nothing -> do
|
||||||
fr <- fresh
|
fr <- fresh
|
||||||
cb <- gets currentBind
|
|
||||||
modify (\st -> st{toDecide = S.insert cb st.toDecide, undecidedSigs = M.insert (coerce $ concat [[prefix], i, [delim], coerce cb]) fr st.undecidedSigs})
|
modify (\st -> st{toDecide = S.insert cb st.toDecide, undecidedSigs = M.insert (coerce $ concat [[prefix], i, [delim], coerce cb]) fr st.undecidedSigs})
|
||||||
return (nullSubst, (T.EVar $ coerce i, fr))
|
return (nullSubst, (T.EVar $ coerce i, fr))
|
||||||
Nothing ->
|
Nothing ->
|
||||||
|
|
@ -258,7 +237,7 @@ algoW = \case
|
||||||
"Unbound variable: "
|
"Unbound variable: "
|
||||||
<> printTree i
|
<> printTree i
|
||||||
EInj i -> do
|
EInj i -> do
|
||||||
constr <- gets injections
|
constr <- asks injections
|
||||||
case M.lookup (coerce i) constr of
|
case M.lookup (coerce i) constr of
|
||||||
Just t -> return (nullSubst, (T.EVar $ coerce i, t))
|
Just t -> return (nullSubst, (T.EVar $ coerce i, t))
|
||||||
Nothing ->
|
Nothing ->
|
||||||
|
|
@ -304,6 +283,8 @@ algoW = \case
|
||||||
err@(EApp e0 e1) -> do
|
err@(EApp e0 e1) -> do
|
||||||
fr <- fresh
|
fr <- fresh
|
||||||
(s0, (e0', t0)) <- algoW e0
|
(s0, (e0', t0)) <- algoW e0
|
||||||
|
applySt s0 $ do
|
||||||
|
modify (\st -> st{sigs = apply s0 st.sigs})
|
||||||
(s1, (e1', t1)) <- algoW e1
|
(s1, (e1', t1)) <- algoW e1
|
||||||
s2 <- exprErr (unify (apply s1 t0) (TFun t1 fr)) err
|
s2 <- exprErr (unify (apply s1 t0) (TFun t1 fr)) err
|
||||||
let t = apply s2 fr
|
let t = apply s2 fr
|
||||||
|
|
@ -368,8 +349,38 @@ inferBranch (Branch pat expr) = do
|
||||||
inferPattern :: Pattern -> Infer (T.Pattern' Type, Type)
|
inferPattern :: Pattern -> Infer (T.Pattern' Type, Type)
|
||||||
inferPattern = \case
|
inferPattern = \case
|
||||||
PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt)
|
PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt)
|
||||||
|
PCatch -> (T.PCatch,) <$> fresh
|
||||||
|
PVar x -> do
|
||||||
|
fr <- fresh
|
||||||
|
let pvar = T.PVar (coerce x, fr)
|
||||||
|
return (pvar, fr)
|
||||||
|
PEnum p -> do
|
||||||
|
t <- asks (M.lookup (coerce p) . injections)
|
||||||
|
t <-
|
||||||
|
maybeToRightM
|
||||||
|
( Error
|
||||||
|
( Aux.do
|
||||||
|
"Constructor:"
|
||||||
|
quote $ printTree p
|
||||||
|
"does not exist"
|
||||||
|
)
|
||||||
|
True
|
||||||
|
)
|
||||||
|
t
|
||||||
|
unless
|
||||||
|
(typeLength t == 1)
|
||||||
|
( catchableErr $ Aux.do
|
||||||
|
"The constructor"
|
||||||
|
quote $ printTree p
|
||||||
|
" should have "
|
||||||
|
show (typeLength t - 1)
|
||||||
|
" arguments but has been given 0"
|
||||||
|
)
|
||||||
|
let (TData _data _ts) = t -- nasty nasty
|
||||||
|
frs <- mapM (const fresh) _ts
|
||||||
|
return (T.PEnum $ coerce p, TData _data frs)
|
||||||
PInj constr patterns -> do
|
PInj constr patterns -> do
|
||||||
t <- gets (M.lookup (coerce constr) . injections)
|
t <- asks (M.lookup (coerce constr) . injections)
|
||||||
t <-
|
t <-
|
||||||
maybeToRightM
|
maybeToRightM
|
||||||
( Error
|
( Error
|
||||||
|
|
@ -399,36 +410,6 @@ inferPattern = \case
|
||||||
( T.PInj (coerce constr) (apply sub (map fst patterns))
|
( T.PInj (coerce constr) (apply sub (map fst patterns))
|
||||||
, apply sub ret
|
, apply sub ret
|
||||||
)
|
)
|
||||||
PCatch -> (T.PCatch,) <$> fresh
|
|
||||||
PEnum p -> do
|
|
||||||
t <- gets (M.lookup (coerce p) . injections)
|
|
||||||
t <-
|
|
||||||
maybeToRightM
|
|
||||||
( Error
|
|
||||||
( Aux.do
|
|
||||||
"Constructor:"
|
|
||||||
quote $ printTree p
|
|
||||||
"does not exist"
|
|
||||||
)
|
|
||||||
True
|
|
||||||
)
|
|
||||||
t
|
|
||||||
unless
|
|
||||||
(typeLength t == 1)
|
|
||||||
( catchableErr $ Aux.do
|
|
||||||
"The constructor"
|
|
||||||
quote $ printTree p
|
|
||||||
" should have "
|
|
||||||
show (typeLength t - 1)
|
|
||||||
" arguments but has been given 0"
|
|
||||||
)
|
|
||||||
let (TData _data _ts) = t -- nasty nasty
|
|
||||||
frs <- mapM (const fresh) _ts
|
|
||||||
return (T.PEnum $ coerce p, TData _data frs)
|
|
||||||
PVar x -> do
|
|
||||||
fr <- fresh
|
|
||||||
let pvar = T.PVar (coerce x, fr)
|
|
||||||
return (pvar, fr)
|
|
||||||
|
|
||||||
-- | Unify two types producing a new substitution
|
-- | Unify two types producing a new substitution
|
||||||
unify :: Type -> Type -> Infer Subst
|
unify :: Type -> Type -> Infer Subst
|
||||||
|
|
@ -437,7 +418,7 @@ unify t0 t1 =
|
||||||
(TFun a b, TFun c d) -> do
|
(TFun a b, TFun c d) -> do
|
||||||
s1 <- unify a c
|
s1 <- unify a c
|
||||||
s2 <- unify (apply s1 b) (apply s1 d)
|
s2 <- unify (apply s1 b) (apply s1 d)
|
||||||
return $ s1 `compose` s2
|
return $ s2 `compose` s1
|
||||||
(TVar (MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t
|
(TVar (MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t
|
||||||
(t@(TData _ _), TVar (MkTVar b)) -> return $ M.singleton (coerce b) t
|
(t@(TData _ _), TVar (MkTVar b)) -> return $ M.singleton (coerce b) t
|
||||||
(TVar (MkTVar a), t) -> occurs (coerce a) t
|
(TVar (MkTVar a), t) -> occurs (coerce a) t
|
||||||
|
|
@ -605,6 +586,9 @@ instance SubstType (Map T.Ident Type) where
|
||||||
apply :: Subst -> Map T.Ident Type -> Map T.Ident Type
|
apply :: Subst -> Map T.Ident Type -> Map T.Ident Type
|
||||||
apply = M.map . apply
|
apply = M.map . apply
|
||||||
|
|
||||||
|
instance SubstType (Map T.Ident (Maybe Type)) where
|
||||||
|
apply s = M.map (fmap $ apply s)
|
||||||
|
|
||||||
instance SubstType (T.ExpT' Type) where
|
instance SubstType (T.ExpT' Type) where
|
||||||
apply s (e, t) = (apply s e, apply s t)
|
apply s (e, t) = (apply s e, apply s t)
|
||||||
|
|
||||||
|
|
@ -688,15 +672,18 @@ insertSig :: T.Ident -> Maybe Type -> Infer ()
|
||||||
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
|
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
|
||||||
|
|
||||||
-- | Insert a constructor into the start with its type
|
-- | Insert a constructor into the start with its type
|
||||||
insertInj :: T.Ident -> Type -> Infer ()
|
insertInj :: (Monad m, MonadReader Ctx m) => T.Ident -> Type -> m a -> m a
|
||||||
insertInj i t =
|
insertInj i t =
|
||||||
modify (\st -> st{injections = M.insert i t (injections st)})
|
local (\st -> st{injections = M.insert i t (injections st)})
|
||||||
|
|
||||||
|
applySt :: Subst -> Infer a -> Infer a
|
||||||
|
applySt s = local (\st -> st{vars = apply s st.vars})
|
||||||
|
|
||||||
{- | Check if an injection (constructor of data type)
|
{- | Check if an injection (constructor of data type)
|
||||||
with an equivalent name has been declared already
|
with an equivalent name has been declared already
|
||||||
-}
|
-}
|
||||||
existInj :: T.Ident -> Infer (Maybe Type)
|
existInj :: (Monad m, MonadReader Ctx m) => T.Ident -> m (Maybe Type)
|
||||||
existInj n = gets (M.lookup n . injections)
|
existInj n = asks (M.lookup n . injections)
|
||||||
|
|
||||||
setCurrentBind :: T.Ident -> Infer ()
|
setCurrentBind :: T.Ident -> Infer ()
|
||||||
setCurrentBind i = modify (\st -> st{currentBind = i})
|
setCurrentBind i = modify (\st -> st{currentBind = i})
|
||||||
|
|
@ -705,11 +692,12 @@ solveUndecidable :: Infer Subst
|
||||||
solveUndecidable = do
|
solveUndecidable = do
|
||||||
sigs <- gets sigs
|
sigs <- gets sigs
|
||||||
undecided <- gets undecidedSigs
|
undecided <- gets undecidedSigs
|
||||||
let xs = M.toList undecided
|
|
||||||
ys <-
|
ys <-
|
||||||
maybeToRightM
|
maybeToRightM
|
||||||
(Error "SIGNATURE MISSING" False)
|
(Error "SIGNATURE MISSING" False)
|
||||||
(mapM (tupSequence . first (join . flip M.lookup sigs . getOriginal)) xs)
|
( mapM (tupSequence . first (join . flip M.lookup sigs . getOriginal)) $
|
||||||
|
M.toList undecided
|
||||||
|
)
|
||||||
composeAll <$> mapM (uncurry unify) ys
|
composeAll <$> mapM (uncurry unify) ys
|
||||||
|
|
||||||
tupSequence :: Monad m => (m a, b) -> m (a, b)
|
tupSequence :: Monad m => (m a, b) -> m (a, b)
|
||||||
|
|
@ -738,48 +726,6 @@ litType (LChar _) = char
|
||||||
int = TLit "Int"
|
int = TLit "Int"
|
||||||
char = TLit "Char"
|
char = TLit "Char"
|
||||||
|
|
||||||
typeEq :: Type -> Type -> StateT Subst (ExceptT Error Identity) ()
|
|
||||||
typeEq (TVar (MkTVar a)) t@(TVar _) = do
|
|
||||||
st <- get
|
|
||||||
case M.lookup (coerce a) st of
|
|
||||||
Nothing -> put $ M.insert (coerce a) t st
|
|
||||||
Just t' ->
|
|
||||||
unless
|
|
||||||
(t == t')
|
|
||||||
( catchableErr $ Aux.do
|
|
||||||
quote $ printTree t
|
|
||||||
"does not match with"
|
|
||||||
quote $ printTree t'
|
|
||||||
)
|
|
||||||
typeEq (TFun l r) (TFun l' r') = typeEq l l' *> typeEq r r'
|
|
||||||
typeEq (TAll _ l) (TAll _ r) = typeEq l r
|
|
||||||
typeEq t@(TLit a) t'@(TLit b) =
|
|
||||||
unless
|
|
||||||
(a == b)
|
|
||||||
( catchableErr $ Aux.do
|
|
||||||
quote $ printTree t
|
|
||||||
"does not match with"
|
|
||||||
quote $ printTree t'
|
|
||||||
)
|
|
||||||
typeEq t@(TData nameL tL) t'@(TData nameR tR) = do
|
|
||||||
unless
|
|
||||||
(nameL == nameR)
|
|
||||||
( catchableErr $ Aux.do
|
|
||||||
quote $ printTree t
|
|
||||||
"does not match with"
|
|
||||||
quote $ printTree t'
|
|
||||||
)
|
|
||||||
zipWithM_ typeEq tL tR
|
|
||||||
typeEq t@(TEVar _) t'@(TEVar _) =
|
|
||||||
catchableErr $ Aux.do
|
|
||||||
quote $ printTree t
|
|
||||||
"does not match with"
|
|
||||||
quote $ printTree t'
|
|
||||||
typeEq t t' = catchableErr $ Aux.do
|
|
||||||
quote $ printTree t
|
|
||||||
"does not match with"
|
|
||||||
quote $ printTree t'
|
|
||||||
|
|
||||||
{- | Catch an error if possible and add the given
|
{- | Catch an error if possible and add the given
|
||||||
expression as addition to the error message
|
expression as addition to the error message
|
||||||
-}
|
-}
|
||||||
|
|
@ -824,7 +770,7 @@ bindErr ma bind =
|
||||||
{- | Catch an error if possible and add the given
|
{- | Catch an error if possible and add the given
|
||||||
data as addition to the error message
|
data as addition to the error message
|
||||||
-}
|
-}
|
||||||
dataErr :: Infer a -> Data -> Infer a
|
dataErr :: (MonadError Error m, Monad m) => m a -> Data -> m a
|
||||||
dataErr ma d =
|
dataErr ma d =
|
||||||
catchError
|
catchError
|
||||||
ma
|
ma
|
||||||
|
|
@ -850,19 +796,31 @@ unzip4 =
|
||||||
)
|
)
|
||||||
([], [], [], [])
|
([], [], [], [])
|
||||||
|
|
||||||
newtype Ctx = Ctx {vars :: Map T.Ident Type}
|
initCtx = Ctx mempty mempty
|
||||||
|
initEnv = Env 0 'a' mempty mempty "" mempty mempty
|
||||||
|
|
||||||
|
run :: Infer a -> Either Error a
|
||||||
|
run = run' initEnv initCtx
|
||||||
|
|
||||||
|
run' :: Env -> Ctx -> Infer a -> Either Error a
|
||||||
|
run' e c =
|
||||||
|
runIdentity
|
||||||
|
. runExceptT
|
||||||
|
. flip runReaderT c
|
||||||
|
. flip evalStateT e
|
||||||
|
. runInfer
|
||||||
|
|
||||||
|
data Ctx = Ctx {vars :: Map T.Ident Type, injections :: Map T.Ident Type}
|
||||||
deriving (Show)
|
deriving (Show)
|
||||||
|
|
||||||
data Env = Env
|
data Env = Env
|
||||||
{ count :: Int
|
{ count :: Int
|
||||||
, nextChar :: Char
|
, nextChar :: Char
|
||||||
, sigs :: Map T.Ident (Maybe Type)
|
, sigs :: Map T.Ident (Maybe Type)
|
||||||
, injections :: Map T.Ident Type
|
|
||||||
, takenTypeVars :: Set T.Ident
|
, takenTypeVars :: Set T.Ident
|
||||||
, currentBind :: T.Ident
|
, currentBind :: T.Ident
|
||||||
, undecidedSigs :: Map T.Ident Type
|
, undecidedSigs :: Map T.Ident Type
|
||||||
, toDecide :: Set T.Ident
|
, toDecide :: Set T.Ident
|
||||||
, declaredBinds :: Set T.Ident
|
|
||||||
}
|
}
|
||||||
deriving (Show)
|
deriving (Show)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue