cleaned up implementations and added check for duplicate constructors

This commit is contained in:
sebastian 2023-03-27 22:38:39 +02:00
parent e1633ea147
commit 4b24755b93

View file

@ -13,7 +13,6 @@ import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (second)
import Data.Coerce (coerce)
import Data.Foldable (traverse_)
import Data.Function (on)
import Data.List (foldl')
import Data.List.Extra (unsnoc)
@ -22,7 +21,6 @@ import Data.Map qualified as M
import Data.Maybe (fromJust)
import Data.Set (Set)
import Data.Set qualified as S
import Data.String
import Grammar.Abs
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr qualified as T
@ -45,33 +43,61 @@ typecheck :: Program -> Either Error (T.Program' Type)
typecheck = run . checkPrg
checkData :: Data -> Infer ()
checkData d = do
case d of
(Data typ@(TData _ ts) constrs) -> do
unless
(all isPoly ts)
(throwError $ unwords ["Data type incorrectly declared"])
traverse_
( \(Inj name' t') ->
if typ == returnType t'
then insertConstr (coerce name') t'
else
throwError $
unwords
[ "return type of constructor:"
, printTree name'
, "with type:"
, printTree (returnType t')
, "does not match data: "
, printTree typ
]
)
constrs
_ ->
throwError $
"incorrectly declared data type '"
<> printTree d
<> "'"
checkData (Data typ injs) = do
(name, tvars) <- go typ
mapM_ (\i -> typecheckInj i name tvars) injs
where
go = \case
TData name typs
| Right tvars' <- mapM toTVar typs ->
pure (name, tvars')
TAll _ _ -> throwError "Explicit foralls not allowed, for now"
_ -> throwError $ unwords ["Bad data type definition: ", printTree typ]
typecheckInj :: Inj -> UIdent -> [TVar] -> Infer ()
typecheckInj (Inj c inj_typ) name tvars
| Right False <- boundTVars tvars inj_typ =
throwError "Unbound type variables"
| TData name' typs <- returnType inj_typ
, Right tvars' <- mapM toTVar typs
, name' == name
, tvars' == tvars = do
exist <- existInj (coerce c)
case exist of
Just t -> throwError $ Aux.do
"Constructor"
quote $ coerce name
"with type"
quote $ printTree t
"already exist"
Nothing -> insertInj (coerce c) inj_typ
| otherwise =
throwError $
unwords
[ "Bad type constructor: "
, show name
, "\nExpected: "
, printTree . TData name $ map TVar tvars
, "\nActual: "
, printTree $ returnType inj_typ
]
where
boundTVars :: [TVar] -> Type -> Either Error Bool
boundTVars tvars' = \case
TAll{} -> throwError "Explicit foralls not allowed, for now"
TFun t1 t2 -> do
t1' <- boundTVars tvars t1
t2' <- boundTVars tvars t2
return $ t1' && t2'
TVar tvar -> return $ tvar `elem` tvars'
TData _ typs -> and <$> mapM (boundTVars tvars) typs
TLit _ -> return True
TEVar _ -> error "TEVar in data type declaration"
toTVar :: Type -> Either String TVar
toTVar = \case
TVar tvar -> pure tvar
_ -> throwError "Not a type variable"
returnType :: Type -> Type
returnType (TFun _ t2) = returnType t2
@ -91,10 +117,9 @@ preRun (x : xs) = case x of
gets (M.member (coerce n) . sigs)
>>= flip
when
( throwError $
"Duplicate signatures for function '"
<> printTree n
<> "'"
( throwError $ Aux.do
"Duplicate signatures for function"
quote $ printTree n
)
insertSig (coerce n) (Just t) >> preRun xs
DBind (Bind n _ e) -> do
@ -126,12 +151,11 @@ checkBind (Bind name args e) = do
Just (Just t') -> do
unless
(args_t `typeEq` t')
( throwError $
"Inferred type '"
++ printTree args_t
++ " does not match specified type '"
++ printTree t'
++ "'"
( throwError $ Aux.do
"Inferred type"
quote $ printTree args_t
"does not match specified type"
quote $ printTree t'
)
return $ T.Bind (coerce name, t') [] e
_ -> do
@ -195,7 +219,7 @@ instance CollectTVars Type where
collect :: Set T.Ident -> Infer ()
collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st})
algoW :: Exp -> Infer (Subst, (T.ExpT' Type))
algoW :: Exp -> Infer (Subst, T.ExpT' Type)
algoW = \case
err@(EAnn e t) -> do
(s1, (e', t')) <- exprErr (algoW e) err
@ -209,14 +233,14 @@ algoW = \case
, printTree t'
]
)
s2 <- exprErr (unify (t) t') err
s2 <- exprErr (unify t t') err
let comp = s2 `compose` s1
return (comp, apply comp (e', t))
-- \| ------------------
-- \| Γ ⊢ i : Int, ∅
ELit lit -> return (nullSubst, (T.ELit $ lit, litType lit))
ELit lit -> return (nullSubst, (T.ELit lit, litType lit))
-- \| x : σ ∈ Γ τ = inst(σ)
-- \| ----------------------
-- \| Γ ⊢ x : τ, ∅
@ -234,7 +258,7 @@ algoW = \case
return (nullSubst, (T.EVar $ coerce i, fr))
Nothing -> throwError $ "Unbound variable: " <> printTree i
EInj i -> do
constr <- gets constructors
constr <- gets injections
case M.lookup (coerce i) constr of
Just t -> return (nullSubst, (T.EVar $ coerce i, t))
Nothing ->
@ -334,14 +358,12 @@ unify t0 t1 = do
(TLit a, TLit b) ->
if a == b
then return M.empty
else
throwError
. unwords
$ [ "Can not unify"
, "'" <> printTree (TLit a) <> "'"
, "with"
, "'" <> printTree (TLit b) <> "'"
]
else throwError $
Aux.do
"Can not unify"
quote $ printTree (TLit a)
"with"
quote $ printTree (TLit b)
(TData name t, TData name' t') ->
if name == name' && length t == length t'
then do
@ -351,41 +373,26 @@ unify t0 t1 = do
Aux.do
"Type constructor:"
printTree name
"("
printTree t
")"
quote $ printTree t
"does not match with:"
printTree name'
"("
printTree t'
")"
-- [ "Type constructor:"
-- , printTree name
-- , "(" <> printTree t <> ")"
-- , "does not match with:"
-- , printTree name'
-- , "(" <> printTree t' <> ")"
-- ]
quote $ printTree t'
(TEVar a, TEVar b) ->
if a == b
then return M.empty
else
throwError
. unwords
$ [ "Can not unify"
, "'" <> printTree (TEVar a) <> "'"
, "with"
, "'" <> printTree (TEVar b) <> "'"
]
else throwError $
Aux.do
"Can not unify"
quote $ printTree (TEVar a)
"with"
quote $ printTree (TEVar b)
(a, b) -> do
throwError
. unwords
$ [ "Can not unify"
, "'" <> printTree a <> "'"
, "with"
, "'" <> printTree b <> "'"
]
throwError $
Aux.do
"Can not unify"
quote $ printTree a
"with"
quote $ printTree b
{- | Check if a type is contained in another type.
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
@ -395,14 +402,12 @@ occurs :: T.Ident -> Type -> Infer Subst
occurs i t@(TVar _) = return (M.singleton i t)
occurs i t =
if S.member i (free t)
then
throwError $
unwords
[ "Occurs check failed, can't unify"
, printTree (TVar $ T.MkTVar (coerce i))
, "with"
, printTree t
]
then throwError $
Aux.do
"Occurs check failed, can't unify"
quote $ printTree (TVar $ T.MkTVar (coerce i))
"with"
quote $ printTree t
else return $ M.singleton i t
-- | Generalize a type over all free variables in the substitution set
@ -455,6 +460,7 @@ instance FreeVars Type where
free (TLit _) = mempty
free (TFun a b) = free a `S.union` free b
free (TData _ a) = free a
free (TEVar _) = S.empty
instance FreeVars a => FreeVars [a] where
free = let f acc x = acc `S.union` free x in foldl' f S.empty
@ -471,7 +477,10 @@ instance SubstType Type where
Nothing -> TAll (T.MkTVar i) (apply sub t)
Just _ -> apply sub t
TFun a b -> TFun (apply sub a) (apply sub b)
TData name a -> TData name (map (apply sub) a)
TData name a -> TData name (apply sub a)
TEVar (MkTEVar a) -> case M.lookup (coerce a) sub of
Nothing -> TEVar (MkTEVar a)
Just t -> t
instance FreeVars (Map T.Ident Type) where
free :: Map T.Ident Type -> Set T.Ident
@ -555,9 +564,12 @@ insertSig :: T.Ident -> Maybe Type -> Infer ()
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
-- | Insert a constructor with its data type
insertConstr :: T.Ident -> Type -> Infer ()
insertConstr i t =
modify (\st -> st{constructors = M.insert i t (constructors st)})
insertInj :: T.Ident -> Type -> Infer ()
insertInj i t =
modify (\st -> st{injections = M.insert i t (injections st)})
existInj :: T.Ident -> Infer (Maybe Type)
existInj n = gets (M.lookup n . injections)
-------- PATTERN MATCHING ---------
@ -601,37 +613,35 @@ inferPattern :: Pattern -> Infer (T.Pattern' Type, Type)
inferPattern = \case
PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt)
PInj constr patterns -> do
t <- gets (M.lookup (coerce constr) . constructors)
t <- gets (M.lookup (coerce constr) . injections)
t <- maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") t
let numArgs = typeLength t - 1
let (vs, ret) = fromJust (unsnoc $ flattenType t)
patterns <- mapM inferPattern patterns
unless
(length patterns == numArgs)
( throwError $
"The constructor '"
++ printTree constr
++ "'"
++ " should have "
++ show numArgs
++ " arguments but has been given "
++ show (length patterns)
( throwError $ Aux.do
"The constructor"
quote $ printTree constr
" should have "
show numArgs
" arguments but has been given "
show (length patterns)
)
sub <- composeAll <$> zipWithM unify vs (map snd patterns)
return (T.PInj (coerce constr) (apply sub (map fst patterns)), apply sub ret)
PCatch -> (T.PCatch,) <$> fresh
PEnum p -> do
t <- gets (M.lookup (coerce p) . constructors)
t <- gets (M.lookup (coerce p) . injections)
t <- maybeToRightM ("Constructor: " <> printTree p <> " does not exist") t
unless
(typeLength t == 1)
( throwError $
"The constructor '"
++ printTree p
++ "'"
++ " should have "
++ show (typeLength t - 1)
++ " arguments but has been given 0"
( throwError $ Aux.do
"The constructor"
quote $ printTree p
" should have "
show (typeLength t - 1)
" arguments but has been given 0"
)
let (TData _data _ts) = t -- nasty nasty
frs <- mapM (const fresh) _ts
@ -687,7 +697,7 @@ data Env = Env
{ count :: Int
, nextChar :: Char
, sigs :: Map T.Ident (Maybe Type)
, constructors :: Map T.Ident Type
, injections :: Map T.Ident Type
, takenTypeVars :: Set T.Ident
}
deriving (Show)
@ -697,3 +707,6 @@ type Subst = Map T.Ident Type
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a}
deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env)
quote :: String -> String
quote s = "'" ++ s ++ "'"