cleaned up implementations and added check for duplicate constructors
This commit is contained in:
parent
e1633ea147
commit
4b24755b93
1 changed files with 123 additions and 110 deletions
|
|
@ -13,7 +13,6 @@ import Control.Monad.Reader
|
||||||
import Control.Monad.State
|
import Control.Monad.State
|
||||||
import Data.Bifunctor (second)
|
import Data.Bifunctor (second)
|
||||||
import Data.Coerce (coerce)
|
import Data.Coerce (coerce)
|
||||||
import Data.Foldable (traverse_)
|
|
||||||
import Data.Function (on)
|
import Data.Function (on)
|
||||||
import Data.List (foldl')
|
import Data.List (foldl')
|
||||||
import Data.List.Extra (unsnoc)
|
import Data.List.Extra (unsnoc)
|
||||||
|
|
@ -22,7 +21,6 @@ import Data.Map qualified as M
|
||||||
import Data.Maybe (fromJust)
|
import Data.Maybe (fromJust)
|
||||||
import Data.Set (Set)
|
import Data.Set (Set)
|
||||||
import Data.Set qualified as S
|
import Data.Set qualified as S
|
||||||
import Data.String
|
|
||||||
import Grammar.Abs
|
import Grammar.Abs
|
||||||
import Grammar.Print (printTree)
|
import Grammar.Print (printTree)
|
||||||
import TypeChecker.TypeCheckerIr qualified as T
|
import TypeChecker.TypeCheckerIr qualified as T
|
||||||
|
|
@ -45,33 +43,61 @@ typecheck :: Program -> Either Error (T.Program' Type)
|
||||||
typecheck = run . checkPrg
|
typecheck = run . checkPrg
|
||||||
|
|
||||||
checkData :: Data -> Infer ()
|
checkData :: Data -> Infer ()
|
||||||
checkData d = do
|
checkData (Data typ injs) = do
|
||||||
case d of
|
(name, tvars) <- go typ
|
||||||
(Data typ@(TData _ ts) constrs) -> do
|
mapM_ (\i -> typecheckInj i name tvars) injs
|
||||||
unless
|
where
|
||||||
(all isPoly ts)
|
go = \case
|
||||||
(throwError $ unwords ["Data type incorrectly declared"])
|
TData name typs
|
||||||
traverse_
|
| Right tvars' <- mapM toTVar typs ->
|
||||||
( \(Inj name' t') ->
|
pure (name, tvars')
|
||||||
if typ == returnType t'
|
TAll _ _ -> throwError "Explicit foralls not allowed, for now"
|
||||||
then insertConstr (coerce name') t'
|
_ -> throwError $ unwords ["Bad data type definition: ", printTree typ]
|
||||||
else
|
|
||||||
|
typecheckInj :: Inj -> UIdent -> [TVar] -> Infer ()
|
||||||
|
typecheckInj (Inj c inj_typ) name tvars
|
||||||
|
| Right False <- boundTVars tvars inj_typ =
|
||||||
|
throwError "Unbound type variables"
|
||||||
|
| TData name' typs <- returnType inj_typ
|
||||||
|
, Right tvars' <- mapM toTVar typs
|
||||||
|
, name' == name
|
||||||
|
, tvars' == tvars = do
|
||||||
|
exist <- existInj (coerce c)
|
||||||
|
case exist of
|
||||||
|
Just t -> throwError $ Aux.do
|
||||||
|
"Constructor"
|
||||||
|
quote $ coerce name
|
||||||
|
"with type"
|
||||||
|
quote $ printTree t
|
||||||
|
"already exist"
|
||||||
|
Nothing -> insertInj (coerce c) inj_typ
|
||||||
|
| otherwise =
|
||||||
throwError $
|
throwError $
|
||||||
unwords
|
unwords
|
||||||
[ "return type of constructor:"
|
[ "Bad type constructor: "
|
||||||
, printTree name'
|
, show name
|
||||||
, "with type:"
|
, "\nExpected: "
|
||||||
, printTree (returnType t')
|
, printTree . TData name $ map TVar tvars
|
||||||
, "does not match data: "
|
, "\nActual: "
|
||||||
, printTree typ
|
, printTree $ returnType inj_typ
|
||||||
]
|
]
|
||||||
)
|
where
|
||||||
constrs
|
boundTVars :: [TVar] -> Type -> Either Error Bool
|
||||||
_ ->
|
boundTVars tvars' = \case
|
||||||
throwError $
|
TAll{} -> throwError "Explicit foralls not allowed, for now"
|
||||||
"incorrectly declared data type '"
|
TFun t1 t2 -> do
|
||||||
<> printTree d
|
t1' <- boundTVars tvars t1
|
||||||
<> "'"
|
t2' <- boundTVars tvars t2
|
||||||
|
return $ t1' && t2'
|
||||||
|
TVar tvar -> return $ tvar `elem` tvars'
|
||||||
|
TData _ typs -> and <$> mapM (boundTVars tvars) typs
|
||||||
|
TLit _ -> return True
|
||||||
|
TEVar _ -> error "TEVar in data type declaration"
|
||||||
|
|
||||||
|
toTVar :: Type -> Either String TVar
|
||||||
|
toTVar = \case
|
||||||
|
TVar tvar -> pure tvar
|
||||||
|
_ -> throwError "Not a type variable"
|
||||||
|
|
||||||
returnType :: Type -> Type
|
returnType :: Type -> Type
|
||||||
returnType (TFun _ t2) = returnType t2
|
returnType (TFun _ t2) = returnType t2
|
||||||
|
|
@ -91,10 +117,9 @@ preRun (x : xs) = case x of
|
||||||
gets (M.member (coerce n) . sigs)
|
gets (M.member (coerce n) . sigs)
|
||||||
>>= flip
|
>>= flip
|
||||||
when
|
when
|
||||||
( throwError $
|
( throwError $ Aux.do
|
||||||
"Duplicate signatures for function '"
|
"Duplicate signatures for function"
|
||||||
<> printTree n
|
quote $ printTree n
|
||||||
<> "'"
|
|
||||||
)
|
)
|
||||||
insertSig (coerce n) (Just t) >> preRun xs
|
insertSig (coerce n) (Just t) >> preRun xs
|
||||||
DBind (Bind n _ e) -> do
|
DBind (Bind n _ e) -> do
|
||||||
|
|
@ -126,12 +151,11 @@ checkBind (Bind name args e) = do
|
||||||
Just (Just t') -> do
|
Just (Just t') -> do
|
||||||
unless
|
unless
|
||||||
(args_t `typeEq` t')
|
(args_t `typeEq` t')
|
||||||
( throwError $
|
( throwError $ Aux.do
|
||||||
"Inferred type '"
|
"Inferred type"
|
||||||
++ printTree args_t
|
quote $ printTree args_t
|
||||||
++ " does not match specified type '"
|
"does not match specified type"
|
||||||
++ printTree t'
|
quote $ printTree t'
|
||||||
++ "'"
|
|
||||||
)
|
)
|
||||||
return $ T.Bind (coerce name, t') [] e
|
return $ T.Bind (coerce name, t') [] e
|
||||||
_ -> do
|
_ -> do
|
||||||
|
|
@ -195,7 +219,7 @@ instance CollectTVars Type where
|
||||||
collect :: Set T.Ident -> Infer ()
|
collect :: Set T.Ident -> Infer ()
|
||||||
collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st})
|
collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st})
|
||||||
|
|
||||||
algoW :: Exp -> Infer (Subst, (T.ExpT' Type))
|
algoW :: Exp -> Infer (Subst, T.ExpT' Type)
|
||||||
algoW = \case
|
algoW = \case
|
||||||
err@(EAnn e t) -> do
|
err@(EAnn e t) -> do
|
||||||
(s1, (e', t')) <- exprErr (algoW e) err
|
(s1, (e', t')) <- exprErr (algoW e) err
|
||||||
|
|
@ -209,14 +233,14 @@ algoW = \case
|
||||||
, printTree t'
|
, printTree t'
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
s2 <- exprErr (unify (t) t') err
|
s2 <- exprErr (unify t t') err
|
||||||
let comp = s2 `compose` s1
|
let comp = s2 `compose` s1
|
||||||
return (comp, apply comp (e', t))
|
return (comp, apply comp (e', t))
|
||||||
|
|
||||||
-- \| ------------------
|
-- \| ------------------
|
||||||
-- \| Γ ⊢ i : Int, ∅
|
-- \| Γ ⊢ i : Int, ∅
|
||||||
|
|
||||||
ELit lit -> return (nullSubst, (T.ELit $ lit, litType lit))
|
ELit lit -> return (nullSubst, (T.ELit lit, litType lit))
|
||||||
-- \| x : σ ∈ Γ τ = inst(σ)
|
-- \| x : σ ∈ Γ τ = inst(σ)
|
||||||
-- \| ----------------------
|
-- \| ----------------------
|
||||||
-- \| Γ ⊢ x : τ, ∅
|
-- \| Γ ⊢ x : τ, ∅
|
||||||
|
|
@ -234,7 +258,7 @@ algoW = \case
|
||||||
return (nullSubst, (T.EVar $ coerce i, fr))
|
return (nullSubst, (T.EVar $ coerce i, fr))
|
||||||
Nothing -> throwError $ "Unbound variable: " <> printTree i
|
Nothing -> throwError $ "Unbound variable: " <> printTree i
|
||||||
EInj i -> do
|
EInj i -> do
|
||||||
constr <- gets constructors
|
constr <- gets injections
|
||||||
case M.lookup (coerce i) constr of
|
case M.lookup (coerce i) constr of
|
||||||
Just t -> return (nullSubst, (T.EVar $ coerce i, t))
|
Just t -> return (nullSubst, (T.EVar $ coerce i, t))
|
||||||
Nothing ->
|
Nothing ->
|
||||||
|
|
@ -334,14 +358,12 @@ unify t0 t1 = do
|
||||||
(TLit a, TLit b) ->
|
(TLit a, TLit b) ->
|
||||||
if a == b
|
if a == b
|
||||||
then return M.empty
|
then return M.empty
|
||||||
else
|
else throwError $
|
||||||
throwError
|
Aux.do
|
||||||
. unwords
|
"Can not unify"
|
||||||
$ [ "Can not unify"
|
quote $ printTree (TLit a)
|
||||||
, "'" <> printTree (TLit a) <> "'"
|
"with"
|
||||||
, "with"
|
quote $ printTree (TLit b)
|
||||||
, "'" <> printTree (TLit b) <> "'"
|
|
||||||
]
|
|
||||||
(TData name t, TData name' t') ->
|
(TData name t, TData name' t') ->
|
||||||
if name == name' && length t == length t'
|
if name == name' && length t == length t'
|
||||||
then do
|
then do
|
||||||
|
|
@ -351,41 +373,26 @@ unify t0 t1 = do
|
||||||
Aux.do
|
Aux.do
|
||||||
"Type constructor:"
|
"Type constructor:"
|
||||||
printTree name
|
printTree name
|
||||||
"("
|
quote $ printTree t
|
||||||
printTree t
|
|
||||||
")"
|
|
||||||
"does not match with:"
|
"does not match with:"
|
||||||
printTree name'
|
printTree name'
|
||||||
"("
|
quote $ printTree t'
|
||||||
printTree t'
|
|
||||||
")"
|
|
||||||
|
|
||||||
-- [ "Type constructor:"
|
|
||||||
-- , printTree name
|
|
||||||
-- , "(" <> printTree t <> ")"
|
|
||||||
-- , "does not match with:"
|
|
||||||
-- , printTree name'
|
|
||||||
-- , "(" <> printTree t' <> ")"
|
|
||||||
-- ]
|
|
||||||
(TEVar a, TEVar b) ->
|
(TEVar a, TEVar b) ->
|
||||||
if a == b
|
if a == b
|
||||||
then return M.empty
|
then return M.empty
|
||||||
else
|
else throwError $
|
||||||
throwError
|
Aux.do
|
||||||
. unwords
|
"Can not unify"
|
||||||
$ [ "Can not unify"
|
quote $ printTree (TEVar a)
|
||||||
, "'" <> printTree (TEVar a) <> "'"
|
"with"
|
||||||
, "with"
|
quote $ printTree (TEVar b)
|
||||||
, "'" <> printTree (TEVar b) <> "'"
|
|
||||||
]
|
|
||||||
(a, b) -> do
|
(a, b) -> do
|
||||||
throwError
|
throwError $
|
||||||
. unwords
|
Aux.do
|
||||||
$ [ "Can not unify"
|
"Can not unify"
|
||||||
, "'" <> printTree a <> "'"
|
quote $ printTree a
|
||||||
, "with"
|
"with"
|
||||||
, "'" <> printTree b <> "'"
|
quote $ printTree b
|
||||||
]
|
|
||||||
|
|
||||||
{- | Check if a type is contained in another type.
|
{- | Check if a type is contained in another type.
|
||||||
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
|
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
|
||||||
|
|
@ -395,14 +402,12 @@ occurs :: T.Ident -> Type -> Infer Subst
|
||||||
occurs i t@(TVar _) = return (M.singleton i t)
|
occurs i t@(TVar _) = return (M.singleton i t)
|
||||||
occurs i t =
|
occurs i t =
|
||||||
if S.member i (free t)
|
if S.member i (free t)
|
||||||
then
|
then throwError $
|
||||||
throwError $
|
Aux.do
|
||||||
unwords
|
"Occurs check failed, can't unify"
|
||||||
[ "Occurs check failed, can't unify"
|
quote $ printTree (TVar $ T.MkTVar (coerce i))
|
||||||
, printTree (TVar $ T.MkTVar (coerce i))
|
"with"
|
||||||
, "with"
|
quote $ printTree t
|
||||||
, printTree t
|
|
||||||
]
|
|
||||||
else return $ M.singleton i t
|
else return $ M.singleton i t
|
||||||
|
|
||||||
-- | Generalize a type over all free variables in the substitution set
|
-- | Generalize a type over all free variables in the substitution set
|
||||||
|
|
@ -455,6 +460,7 @@ instance FreeVars Type where
|
||||||
free (TLit _) = mempty
|
free (TLit _) = mempty
|
||||||
free (TFun a b) = free a `S.union` free b
|
free (TFun a b) = free a `S.union` free b
|
||||||
free (TData _ a) = free a
|
free (TData _ a) = free a
|
||||||
|
free (TEVar _) = S.empty
|
||||||
|
|
||||||
instance FreeVars a => FreeVars [a] where
|
instance FreeVars a => FreeVars [a] where
|
||||||
free = let f acc x = acc `S.union` free x in foldl' f S.empty
|
free = let f acc x = acc `S.union` free x in foldl' f S.empty
|
||||||
|
|
@ -471,7 +477,10 @@ instance SubstType Type where
|
||||||
Nothing -> TAll (T.MkTVar i) (apply sub t)
|
Nothing -> TAll (T.MkTVar i) (apply sub t)
|
||||||
Just _ -> apply sub t
|
Just _ -> apply sub t
|
||||||
TFun a b -> TFun (apply sub a) (apply sub b)
|
TFun a b -> TFun (apply sub a) (apply sub b)
|
||||||
TData name a -> TData name (map (apply sub) a)
|
TData name a -> TData name (apply sub a)
|
||||||
|
TEVar (MkTEVar a) -> case M.lookup (coerce a) sub of
|
||||||
|
Nothing -> TEVar (MkTEVar a)
|
||||||
|
Just t -> t
|
||||||
|
|
||||||
instance FreeVars (Map T.Ident Type) where
|
instance FreeVars (Map T.Ident Type) where
|
||||||
free :: Map T.Ident Type -> Set T.Ident
|
free :: Map T.Ident Type -> Set T.Ident
|
||||||
|
|
@ -555,9 +564,12 @@ insertSig :: T.Ident -> Maybe Type -> Infer ()
|
||||||
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
|
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
|
||||||
|
|
||||||
-- | Insert a constructor with its data type
|
-- | Insert a constructor with its data type
|
||||||
insertConstr :: T.Ident -> Type -> Infer ()
|
insertInj :: T.Ident -> Type -> Infer ()
|
||||||
insertConstr i t =
|
insertInj i t =
|
||||||
modify (\st -> st{constructors = M.insert i t (constructors st)})
|
modify (\st -> st{injections = M.insert i t (injections st)})
|
||||||
|
|
||||||
|
existInj :: T.Ident -> Infer (Maybe Type)
|
||||||
|
existInj n = gets (M.lookup n . injections)
|
||||||
|
|
||||||
-------- PATTERN MATCHING ---------
|
-------- PATTERN MATCHING ---------
|
||||||
|
|
||||||
|
|
@ -601,37 +613,35 @@ inferPattern :: Pattern -> Infer (T.Pattern' Type, Type)
|
||||||
inferPattern = \case
|
inferPattern = \case
|
||||||
PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt)
|
PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt)
|
||||||
PInj constr patterns -> do
|
PInj constr patterns -> do
|
||||||
t <- gets (M.lookup (coerce constr) . constructors)
|
t <- gets (M.lookup (coerce constr) . injections)
|
||||||
t <- maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") t
|
t <- maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") t
|
||||||
let numArgs = typeLength t - 1
|
let numArgs = typeLength t - 1
|
||||||
let (vs, ret) = fromJust (unsnoc $ flattenType t)
|
let (vs, ret) = fromJust (unsnoc $ flattenType t)
|
||||||
patterns <- mapM inferPattern patterns
|
patterns <- mapM inferPattern patterns
|
||||||
unless
|
unless
|
||||||
(length patterns == numArgs)
|
(length patterns == numArgs)
|
||||||
( throwError $
|
( throwError $ Aux.do
|
||||||
"The constructor '"
|
"The constructor"
|
||||||
++ printTree constr
|
quote $ printTree constr
|
||||||
++ "'"
|
" should have "
|
||||||
++ " should have "
|
show numArgs
|
||||||
++ show numArgs
|
" arguments but has been given "
|
||||||
++ " arguments but has been given "
|
show (length patterns)
|
||||||
++ show (length patterns)
|
|
||||||
)
|
)
|
||||||
sub <- composeAll <$> zipWithM unify vs (map snd patterns)
|
sub <- composeAll <$> zipWithM unify vs (map snd patterns)
|
||||||
return (T.PInj (coerce constr) (apply sub (map fst patterns)), apply sub ret)
|
return (T.PInj (coerce constr) (apply sub (map fst patterns)), apply sub ret)
|
||||||
PCatch -> (T.PCatch,) <$> fresh
|
PCatch -> (T.PCatch,) <$> fresh
|
||||||
PEnum p -> do
|
PEnum p -> do
|
||||||
t <- gets (M.lookup (coerce p) . constructors)
|
t <- gets (M.lookup (coerce p) . injections)
|
||||||
t <- maybeToRightM ("Constructor: " <> printTree p <> " does not exist") t
|
t <- maybeToRightM ("Constructor: " <> printTree p <> " does not exist") t
|
||||||
unless
|
unless
|
||||||
(typeLength t == 1)
|
(typeLength t == 1)
|
||||||
( throwError $
|
( throwError $ Aux.do
|
||||||
"The constructor '"
|
"The constructor"
|
||||||
++ printTree p
|
quote $ printTree p
|
||||||
++ "'"
|
" should have "
|
||||||
++ " should have "
|
show (typeLength t - 1)
|
||||||
++ show (typeLength t - 1)
|
" arguments but has been given 0"
|
||||||
++ " arguments but has been given 0"
|
|
||||||
)
|
)
|
||||||
let (TData _data _ts) = t -- nasty nasty
|
let (TData _data _ts) = t -- nasty nasty
|
||||||
frs <- mapM (const fresh) _ts
|
frs <- mapM (const fresh) _ts
|
||||||
|
|
@ -687,7 +697,7 @@ data Env = Env
|
||||||
{ count :: Int
|
{ count :: Int
|
||||||
, nextChar :: Char
|
, nextChar :: Char
|
||||||
, sigs :: Map T.Ident (Maybe Type)
|
, sigs :: Map T.Ident (Maybe Type)
|
||||||
, constructors :: Map T.Ident Type
|
, injections :: Map T.Ident Type
|
||||||
, takenTypeVars :: Set T.Ident
|
, takenTypeVars :: Set T.Ident
|
||||||
}
|
}
|
||||||
deriving (Show)
|
deriving (Show)
|
||||||
|
|
@ -697,3 +707,6 @@ type Subst = Map T.Ident Type
|
||||||
|
|
||||||
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a}
|
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a}
|
||||||
deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env)
|
deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env)
|
||||||
|
|
||||||
|
quote :: String -> String
|
||||||
|
quote s = "'" ++ s ++ "'"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue