cleaned up implementations and added check for duplicate constructors

This commit is contained in:
sebastian 2023-03-27 22:38:39 +02:00
parent e1633ea147
commit 4b24755b93

View file

@ -13,7 +13,6 @@ import Control.Monad.Reader
import Control.Monad.State import Control.Monad.State
import Data.Bifunctor (second) import Data.Bifunctor (second)
import Data.Coerce (coerce) import Data.Coerce (coerce)
import Data.Foldable (traverse_)
import Data.Function (on) import Data.Function (on)
import Data.List (foldl') import Data.List (foldl')
import Data.List.Extra (unsnoc) import Data.List.Extra (unsnoc)
@ -22,7 +21,6 @@ import Data.Map qualified as M
import Data.Maybe (fromJust) import Data.Maybe (fromJust)
import Data.Set (Set) import Data.Set (Set)
import Data.Set qualified as S import Data.Set qualified as S
import Data.String
import Grammar.Abs import Grammar.Abs
import Grammar.Print (printTree) import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr qualified as T import TypeChecker.TypeCheckerIr qualified as T
@ -45,33 +43,61 @@ typecheck :: Program -> Either Error (T.Program' Type)
typecheck = run . checkPrg typecheck = run . checkPrg
checkData :: Data -> Infer () checkData :: Data -> Infer ()
checkData d = do checkData (Data typ injs) = do
case d of (name, tvars) <- go typ
(Data typ@(TData _ ts) constrs) -> do mapM_ (\i -> typecheckInj i name tvars) injs
unless where
(all isPoly ts) go = \case
(throwError $ unwords ["Data type incorrectly declared"]) TData name typs
traverse_ | Right tvars' <- mapM toTVar typs ->
( \(Inj name' t') -> pure (name, tvars')
if typ == returnType t' TAll _ _ -> throwError "Explicit foralls not allowed, for now"
then insertConstr (coerce name') t' _ -> throwError $ unwords ["Bad data type definition: ", printTree typ]
else
throwError $ typecheckInj :: Inj -> UIdent -> [TVar] -> Infer ()
unwords typecheckInj (Inj c inj_typ) name tvars
[ "return type of constructor:" | Right False <- boundTVars tvars inj_typ =
, printTree name' throwError "Unbound type variables"
, "with type:" | TData name' typs <- returnType inj_typ
, printTree (returnType t') , Right tvars' <- mapM toTVar typs
, "does not match data: " , name' == name
, printTree typ , tvars' == tvars = do
] exist <- existInj (coerce c)
) case exist of
constrs Just t -> throwError $ Aux.do
_ -> "Constructor"
throwError $ quote $ coerce name
"incorrectly declared data type '" "with type"
<> printTree d quote $ printTree t
<> "'" "already exist"
Nothing -> insertInj (coerce c) inj_typ
| otherwise =
throwError $
unwords
[ "Bad type constructor: "
, show name
, "\nExpected: "
, printTree . TData name $ map TVar tvars
, "\nActual: "
, printTree $ returnType inj_typ
]
where
boundTVars :: [TVar] -> Type -> Either Error Bool
boundTVars tvars' = \case
TAll{} -> throwError "Explicit foralls not allowed, for now"
TFun t1 t2 -> do
t1' <- boundTVars tvars t1
t2' <- boundTVars tvars t2
return $ t1' && t2'
TVar tvar -> return $ tvar `elem` tvars'
TData _ typs -> and <$> mapM (boundTVars tvars) typs
TLit _ -> return True
TEVar _ -> error "TEVar in data type declaration"
toTVar :: Type -> Either String TVar
toTVar = \case
TVar tvar -> pure tvar
_ -> throwError "Not a type variable"
returnType :: Type -> Type returnType :: Type -> Type
returnType (TFun _ t2) = returnType t2 returnType (TFun _ t2) = returnType t2
@ -91,10 +117,9 @@ preRun (x : xs) = case x of
gets (M.member (coerce n) . sigs) gets (M.member (coerce n) . sigs)
>>= flip >>= flip
when when
( throwError $ ( throwError $ Aux.do
"Duplicate signatures for function '" "Duplicate signatures for function"
<> printTree n quote $ printTree n
<> "'"
) )
insertSig (coerce n) (Just t) >> preRun xs insertSig (coerce n) (Just t) >> preRun xs
DBind (Bind n _ e) -> do DBind (Bind n _ e) -> do
@ -126,12 +151,11 @@ checkBind (Bind name args e) = do
Just (Just t') -> do Just (Just t') -> do
unless unless
(args_t `typeEq` t') (args_t `typeEq` t')
( throwError $ ( throwError $ Aux.do
"Inferred type '" "Inferred type"
++ printTree args_t quote $ printTree args_t
++ " does not match specified type '" "does not match specified type"
++ printTree t' quote $ printTree t'
++ "'"
) )
return $ T.Bind (coerce name, t') [] e return $ T.Bind (coerce name, t') [] e
_ -> do _ -> do
@ -195,7 +219,7 @@ instance CollectTVars Type where
collect :: Set T.Ident -> Infer () collect :: Set T.Ident -> Infer ()
collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st}) collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st})
algoW :: Exp -> Infer (Subst, (T.ExpT' Type)) algoW :: Exp -> Infer (Subst, T.ExpT' Type)
algoW = \case algoW = \case
err@(EAnn e t) -> do err@(EAnn e t) -> do
(s1, (e', t')) <- exprErr (algoW e) err (s1, (e', t')) <- exprErr (algoW e) err
@ -209,14 +233,14 @@ algoW = \case
, printTree t' , printTree t'
] ]
) )
s2 <- exprErr (unify (t) t') err s2 <- exprErr (unify t t') err
let comp = s2 `compose` s1 let comp = s2 `compose` s1
return (comp, apply comp (e', t)) return (comp, apply comp (e', t))
-- \| ------------------ -- \| ------------------
-- \| Γ ⊢ i : Int, ∅ -- \| Γ ⊢ i : Int, ∅
ELit lit -> return (nullSubst, (T.ELit $ lit, litType lit)) ELit lit -> return (nullSubst, (T.ELit lit, litType lit))
-- \| x : σ ∈ Γ τ = inst(σ) -- \| x : σ ∈ Γ τ = inst(σ)
-- \| ---------------------- -- \| ----------------------
-- \| Γ ⊢ x : τ, ∅ -- \| Γ ⊢ x : τ, ∅
@ -234,7 +258,7 @@ algoW = \case
return (nullSubst, (T.EVar $ coerce i, fr)) return (nullSubst, (T.EVar $ coerce i, fr))
Nothing -> throwError $ "Unbound variable: " <> printTree i Nothing -> throwError $ "Unbound variable: " <> printTree i
EInj i -> do EInj i -> do
constr <- gets constructors constr <- gets injections
case M.lookup (coerce i) constr of case M.lookup (coerce i) constr of
Just t -> return (nullSubst, (T.EVar $ coerce i, t)) Just t -> return (nullSubst, (T.EVar $ coerce i, t))
Nothing -> Nothing ->
@ -334,14 +358,12 @@ unify t0 t1 = do
(TLit a, TLit b) -> (TLit a, TLit b) ->
if a == b if a == b
then return M.empty then return M.empty
else else throwError $
throwError Aux.do
. unwords "Can not unify"
$ [ "Can not unify" quote $ printTree (TLit a)
, "'" <> printTree (TLit a) <> "'" "with"
, "with" quote $ printTree (TLit b)
, "'" <> printTree (TLit b) <> "'"
]
(TData name t, TData name' t') -> (TData name t, TData name' t') ->
if name == name' && length t == length t' if name == name' && length t == length t'
then do then do
@ -351,41 +373,26 @@ unify t0 t1 = do
Aux.do Aux.do
"Type constructor:" "Type constructor:"
printTree name printTree name
"(" quote $ printTree t
printTree t
")"
"does not match with:" "does not match with:"
printTree name' printTree name'
"(" quote $ printTree t'
printTree t'
")"
-- [ "Type constructor:"
-- , printTree name
-- , "(" <> printTree t <> ")"
-- , "does not match with:"
-- , printTree name'
-- , "(" <> printTree t' <> ")"
-- ]
(TEVar a, TEVar b) -> (TEVar a, TEVar b) ->
if a == b if a == b
then return M.empty then return M.empty
else else throwError $
throwError Aux.do
. unwords "Can not unify"
$ [ "Can not unify" quote $ printTree (TEVar a)
, "'" <> printTree (TEVar a) <> "'" "with"
, "with" quote $ printTree (TEVar b)
, "'" <> printTree (TEVar b) <> "'"
]
(a, b) -> do (a, b) -> do
throwError throwError $
. unwords Aux.do
$ [ "Can not unify" "Can not unify"
, "'" <> printTree a <> "'" quote $ printTree a
, "with" "with"
, "'" <> printTree b <> "'" quote $ printTree b
]
{- | Check if a type is contained in another type. {- | Check if a type is contained in another type.
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
@ -395,14 +402,12 @@ occurs :: T.Ident -> Type -> Infer Subst
occurs i t@(TVar _) = return (M.singleton i t) occurs i t@(TVar _) = return (M.singleton i t)
occurs i t = occurs i t =
if S.member i (free t) if S.member i (free t)
then then throwError $
throwError $ Aux.do
unwords "Occurs check failed, can't unify"
[ "Occurs check failed, can't unify" quote $ printTree (TVar $ T.MkTVar (coerce i))
, printTree (TVar $ T.MkTVar (coerce i)) "with"
, "with" quote $ printTree t
, printTree t
]
else return $ M.singleton i t else return $ M.singleton i t
-- | Generalize a type over all free variables in the substitution set -- | Generalize a type over all free variables in the substitution set
@ -455,6 +460,7 @@ instance FreeVars Type where
free (TLit _) = mempty free (TLit _) = mempty
free (TFun a b) = free a `S.union` free b free (TFun a b) = free a `S.union` free b
free (TData _ a) = free a free (TData _ a) = free a
free (TEVar _) = S.empty
instance FreeVars a => FreeVars [a] where instance FreeVars a => FreeVars [a] where
free = let f acc x = acc `S.union` free x in foldl' f S.empty free = let f acc x = acc `S.union` free x in foldl' f S.empty
@ -471,7 +477,10 @@ instance SubstType Type where
Nothing -> TAll (T.MkTVar i) (apply sub t) Nothing -> TAll (T.MkTVar i) (apply sub t)
Just _ -> apply sub t Just _ -> apply sub t
TFun a b -> TFun (apply sub a) (apply sub b) TFun a b -> TFun (apply sub a) (apply sub b)
TData name a -> TData name (map (apply sub) a) TData name a -> TData name (apply sub a)
TEVar (MkTEVar a) -> case M.lookup (coerce a) sub of
Nothing -> TEVar (MkTEVar a)
Just t -> t
instance FreeVars (Map T.Ident Type) where instance FreeVars (Map T.Ident Type) where
free :: Map T.Ident Type -> Set T.Ident free :: Map T.Ident Type -> Set T.Ident
@ -555,9 +564,12 @@ insertSig :: T.Ident -> Maybe Type -> Infer ()
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)}) insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
-- | Insert a constructor with its data type -- | Insert a constructor with its data type
insertConstr :: T.Ident -> Type -> Infer () insertInj :: T.Ident -> Type -> Infer ()
insertConstr i t = insertInj i t =
modify (\st -> st{constructors = M.insert i t (constructors st)}) modify (\st -> st{injections = M.insert i t (injections st)})
existInj :: T.Ident -> Infer (Maybe Type)
existInj n = gets (M.lookup n . injections)
-------- PATTERN MATCHING --------- -------- PATTERN MATCHING ---------
@ -601,37 +613,35 @@ inferPattern :: Pattern -> Infer (T.Pattern' Type, Type)
inferPattern = \case inferPattern = \case
PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt) PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt)
PInj constr patterns -> do PInj constr patterns -> do
t <- gets (M.lookup (coerce constr) . constructors) t <- gets (M.lookup (coerce constr) . injections)
t <- maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") t t <- maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") t
let numArgs = typeLength t - 1 let numArgs = typeLength t - 1
let (vs, ret) = fromJust (unsnoc $ flattenType t) let (vs, ret) = fromJust (unsnoc $ flattenType t)
patterns <- mapM inferPattern patterns patterns <- mapM inferPattern patterns
unless unless
(length patterns == numArgs) (length patterns == numArgs)
( throwError $ ( throwError $ Aux.do
"The constructor '" "The constructor"
++ printTree constr quote $ printTree constr
++ "'" " should have "
++ " should have " show numArgs
++ show numArgs " arguments but has been given "
++ " arguments but has been given " show (length patterns)
++ show (length patterns)
) )
sub <- composeAll <$> zipWithM unify vs (map snd patterns) sub <- composeAll <$> zipWithM unify vs (map snd patterns)
return (T.PInj (coerce constr) (apply sub (map fst patterns)), apply sub ret) return (T.PInj (coerce constr) (apply sub (map fst patterns)), apply sub ret)
PCatch -> (T.PCatch,) <$> fresh PCatch -> (T.PCatch,) <$> fresh
PEnum p -> do PEnum p -> do
t <- gets (M.lookup (coerce p) . constructors) t <- gets (M.lookup (coerce p) . injections)
t <- maybeToRightM ("Constructor: " <> printTree p <> " does not exist") t t <- maybeToRightM ("Constructor: " <> printTree p <> " does not exist") t
unless unless
(typeLength t == 1) (typeLength t == 1)
( throwError $ ( throwError $ Aux.do
"The constructor '" "The constructor"
++ printTree p quote $ printTree p
++ "'" " should have "
++ " should have " show (typeLength t - 1)
++ show (typeLength t - 1) " arguments but has been given 0"
++ " arguments but has been given 0"
) )
let (TData _data _ts) = t -- nasty nasty let (TData _data _ts) = t -- nasty nasty
frs <- mapM (const fresh) _ts frs <- mapM (const fresh) _ts
@ -687,7 +697,7 @@ data Env = Env
{ count :: Int { count :: Int
, nextChar :: Char , nextChar :: Char
, sigs :: Map T.Ident (Maybe Type) , sigs :: Map T.Ident (Maybe Type)
, constructors :: Map T.Ident Type , injections :: Map T.Ident Type
, takenTypeVars :: Set T.Ident , takenTypeVars :: Set T.Ident
} }
deriving (Show) deriving (Show)
@ -697,3 +707,6 @@ type Subst = Map T.Ident Type
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a} newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a}
deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env) deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env)
quote :: String -> String
quote s = "'" ++ s ++ "'"