cleaned up implementations and added check for duplicate constructors
This commit is contained in:
parent
e1633ea147
commit
4b24755b93
1 changed files with 123 additions and 110 deletions
|
|
@ -13,7 +13,6 @@ import Control.Monad.Reader
|
|||
import Control.Monad.State
|
||||
import Data.Bifunctor (second)
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Foldable (traverse_)
|
||||
import Data.Function (on)
|
||||
import Data.List (foldl')
|
||||
import Data.List.Extra (unsnoc)
|
||||
|
|
@ -22,7 +21,6 @@ import Data.Map qualified as M
|
|||
import Data.Maybe (fromJust)
|
||||
import Data.Set (Set)
|
||||
import Data.Set qualified as S
|
||||
import Data.String
|
||||
import Grammar.Abs
|
||||
import Grammar.Print (printTree)
|
||||
import TypeChecker.TypeCheckerIr qualified as T
|
||||
|
|
@ -45,33 +43,61 @@ typecheck :: Program -> Either Error (T.Program' Type)
|
|||
typecheck = run . checkPrg
|
||||
|
||||
checkData :: Data -> Infer ()
|
||||
checkData d = do
|
||||
case d of
|
||||
(Data typ@(TData _ ts) constrs) -> do
|
||||
unless
|
||||
(all isPoly ts)
|
||||
(throwError $ unwords ["Data type incorrectly declared"])
|
||||
traverse_
|
||||
( \(Inj name' t') ->
|
||||
if typ == returnType t'
|
||||
then insertConstr (coerce name') t'
|
||||
else
|
||||
throwError $
|
||||
unwords
|
||||
[ "return type of constructor:"
|
||||
, printTree name'
|
||||
, "with type:"
|
||||
, printTree (returnType t')
|
||||
, "does not match data: "
|
||||
, printTree typ
|
||||
]
|
||||
)
|
||||
constrs
|
||||
_ ->
|
||||
throwError $
|
||||
"incorrectly declared data type '"
|
||||
<> printTree d
|
||||
<> "'"
|
||||
checkData (Data typ injs) = do
|
||||
(name, tvars) <- go typ
|
||||
mapM_ (\i -> typecheckInj i name tvars) injs
|
||||
where
|
||||
go = \case
|
||||
TData name typs
|
||||
| Right tvars' <- mapM toTVar typs ->
|
||||
pure (name, tvars')
|
||||
TAll _ _ -> throwError "Explicit foralls not allowed, for now"
|
||||
_ -> throwError $ unwords ["Bad data type definition: ", printTree typ]
|
||||
|
||||
typecheckInj :: Inj -> UIdent -> [TVar] -> Infer ()
|
||||
typecheckInj (Inj c inj_typ) name tvars
|
||||
| Right False <- boundTVars tvars inj_typ =
|
||||
throwError "Unbound type variables"
|
||||
| TData name' typs <- returnType inj_typ
|
||||
, Right tvars' <- mapM toTVar typs
|
||||
, name' == name
|
||||
, tvars' == tvars = do
|
||||
exist <- existInj (coerce c)
|
||||
case exist of
|
||||
Just t -> throwError $ Aux.do
|
||||
"Constructor"
|
||||
quote $ coerce name
|
||||
"with type"
|
||||
quote $ printTree t
|
||||
"already exist"
|
||||
Nothing -> insertInj (coerce c) inj_typ
|
||||
| otherwise =
|
||||
throwError $
|
||||
unwords
|
||||
[ "Bad type constructor: "
|
||||
, show name
|
||||
, "\nExpected: "
|
||||
, printTree . TData name $ map TVar tvars
|
||||
, "\nActual: "
|
||||
, printTree $ returnType inj_typ
|
||||
]
|
||||
where
|
||||
boundTVars :: [TVar] -> Type -> Either Error Bool
|
||||
boundTVars tvars' = \case
|
||||
TAll{} -> throwError "Explicit foralls not allowed, for now"
|
||||
TFun t1 t2 -> do
|
||||
t1' <- boundTVars tvars t1
|
||||
t2' <- boundTVars tvars t2
|
||||
return $ t1' && t2'
|
||||
TVar tvar -> return $ tvar `elem` tvars'
|
||||
TData _ typs -> and <$> mapM (boundTVars tvars) typs
|
||||
TLit _ -> return True
|
||||
TEVar _ -> error "TEVar in data type declaration"
|
||||
|
||||
toTVar :: Type -> Either String TVar
|
||||
toTVar = \case
|
||||
TVar tvar -> pure tvar
|
||||
_ -> throwError "Not a type variable"
|
||||
|
||||
returnType :: Type -> Type
|
||||
returnType (TFun _ t2) = returnType t2
|
||||
|
|
@ -91,10 +117,9 @@ preRun (x : xs) = case x of
|
|||
gets (M.member (coerce n) . sigs)
|
||||
>>= flip
|
||||
when
|
||||
( throwError $
|
||||
"Duplicate signatures for function '"
|
||||
<> printTree n
|
||||
<> "'"
|
||||
( throwError $ Aux.do
|
||||
"Duplicate signatures for function"
|
||||
quote $ printTree n
|
||||
)
|
||||
insertSig (coerce n) (Just t) >> preRun xs
|
||||
DBind (Bind n _ e) -> do
|
||||
|
|
@ -126,12 +151,11 @@ checkBind (Bind name args e) = do
|
|||
Just (Just t') -> do
|
||||
unless
|
||||
(args_t `typeEq` t')
|
||||
( throwError $
|
||||
"Inferred type '"
|
||||
++ printTree args_t
|
||||
++ " does not match specified type '"
|
||||
++ printTree t'
|
||||
++ "'"
|
||||
( throwError $ Aux.do
|
||||
"Inferred type"
|
||||
quote $ printTree args_t
|
||||
"does not match specified type"
|
||||
quote $ printTree t'
|
||||
)
|
||||
return $ T.Bind (coerce name, t') [] e
|
||||
_ -> do
|
||||
|
|
@ -195,7 +219,7 @@ instance CollectTVars Type where
|
|||
collect :: Set T.Ident -> Infer ()
|
||||
collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st})
|
||||
|
||||
algoW :: Exp -> Infer (Subst, (T.ExpT' Type))
|
||||
algoW :: Exp -> Infer (Subst, T.ExpT' Type)
|
||||
algoW = \case
|
||||
err@(EAnn e t) -> do
|
||||
(s1, (e', t')) <- exprErr (algoW e) err
|
||||
|
|
@ -209,14 +233,14 @@ algoW = \case
|
|||
, printTree t'
|
||||
]
|
||||
)
|
||||
s2 <- exprErr (unify (t) t') err
|
||||
s2 <- exprErr (unify t t') err
|
||||
let comp = s2 `compose` s1
|
||||
return (comp, apply comp (e', t))
|
||||
|
||||
-- \| ------------------
|
||||
-- \| Γ ⊢ i : Int, ∅
|
||||
|
||||
ELit lit -> return (nullSubst, (T.ELit $ lit, litType lit))
|
||||
ELit lit -> return (nullSubst, (T.ELit lit, litType lit))
|
||||
-- \| x : σ ∈ Γ τ = inst(σ)
|
||||
-- \| ----------------------
|
||||
-- \| Γ ⊢ x : τ, ∅
|
||||
|
|
@ -234,7 +258,7 @@ algoW = \case
|
|||
return (nullSubst, (T.EVar $ coerce i, fr))
|
||||
Nothing -> throwError $ "Unbound variable: " <> printTree i
|
||||
EInj i -> do
|
||||
constr <- gets constructors
|
||||
constr <- gets injections
|
||||
case M.lookup (coerce i) constr of
|
||||
Just t -> return (nullSubst, (T.EVar $ coerce i, t))
|
||||
Nothing ->
|
||||
|
|
@ -334,14 +358,12 @@ unify t0 t1 = do
|
|||
(TLit a, TLit b) ->
|
||||
if a == b
|
||||
then return M.empty
|
||||
else
|
||||
throwError
|
||||
. unwords
|
||||
$ [ "Can not unify"
|
||||
, "'" <> printTree (TLit a) <> "'"
|
||||
, "with"
|
||||
, "'" <> printTree (TLit b) <> "'"
|
||||
]
|
||||
else throwError $
|
||||
Aux.do
|
||||
"Can not unify"
|
||||
quote $ printTree (TLit a)
|
||||
"with"
|
||||
quote $ printTree (TLit b)
|
||||
(TData name t, TData name' t') ->
|
||||
if name == name' && length t == length t'
|
||||
then do
|
||||
|
|
@ -351,41 +373,26 @@ unify t0 t1 = do
|
|||
Aux.do
|
||||
"Type constructor:"
|
||||
printTree name
|
||||
"("
|
||||
printTree t
|
||||
")"
|
||||
quote $ printTree t
|
||||
"does not match with:"
|
||||
printTree name'
|
||||
"("
|
||||
printTree t'
|
||||
")"
|
||||
|
||||
-- [ "Type constructor:"
|
||||
-- , printTree name
|
||||
-- , "(" <> printTree t <> ")"
|
||||
-- , "does not match with:"
|
||||
-- , printTree name'
|
||||
-- , "(" <> printTree t' <> ")"
|
||||
-- ]
|
||||
quote $ printTree t'
|
||||
(TEVar a, TEVar b) ->
|
||||
if a == b
|
||||
then return M.empty
|
||||
else
|
||||
throwError
|
||||
. unwords
|
||||
$ [ "Can not unify"
|
||||
, "'" <> printTree (TEVar a) <> "'"
|
||||
, "with"
|
||||
, "'" <> printTree (TEVar b) <> "'"
|
||||
]
|
||||
else throwError $
|
||||
Aux.do
|
||||
"Can not unify"
|
||||
quote $ printTree (TEVar a)
|
||||
"with"
|
||||
quote $ printTree (TEVar b)
|
||||
(a, b) -> do
|
||||
throwError
|
||||
. unwords
|
||||
$ [ "Can not unify"
|
||||
, "'" <> printTree a <> "'"
|
||||
, "with"
|
||||
, "'" <> printTree b <> "'"
|
||||
]
|
||||
throwError $
|
||||
Aux.do
|
||||
"Can not unify"
|
||||
quote $ printTree a
|
||||
"with"
|
||||
quote $ printTree b
|
||||
|
||||
{- | Check if a type is contained in another type.
|
||||
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
|
||||
|
|
@ -395,14 +402,12 @@ occurs :: T.Ident -> Type -> Infer Subst
|
|||
occurs i t@(TVar _) = return (M.singleton i t)
|
||||
occurs i t =
|
||||
if S.member i (free t)
|
||||
then
|
||||
throwError $
|
||||
unwords
|
||||
[ "Occurs check failed, can't unify"
|
||||
, printTree (TVar $ T.MkTVar (coerce i))
|
||||
, "with"
|
||||
, printTree t
|
||||
]
|
||||
then throwError $
|
||||
Aux.do
|
||||
"Occurs check failed, can't unify"
|
||||
quote $ printTree (TVar $ T.MkTVar (coerce i))
|
||||
"with"
|
||||
quote $ printTree t
|
||||
else return $ M.singleton i t
|
||||
|
||||
-- | Generalize a type over all free variables in the substitution set
|
||||
|
|
@ -455,6 +460,7 @@ instance FreeVars Type where
|
|||
free (TLit _) = mempty
|
||||
free (TFun a b) = free a `S.union` free b
|
||||
free (TData _ a) = free a
|
||||
free (TEVar _) = S.empty
|
||||
|
||||
instance FreeVars a => FreeVars [a] where
|
||||
free = let f acc x = acc `S.union` free x in foldl' f S.empty
|
||||
|
|
@ -471,7 +477,10 @@ instance SubstType Type where
|
|||
Nothing -> TAll (T.MkTVar i) (apply sub t)
|
||||
Just _ -> apply sub t
|
||||
TFun a b -> TFun (apply sub a) (apply sub b)
|
||||
TData name a -> TData name (map (apply sub) a)
|
||||
TData name a -> TData name (apply sub a)
|
||||
TEVar (MkTEVar a) -> case M.lookup (coerce a) sub of
|
||||
Nothing -> TEVar (MkTEVar a)
|
||||
Just t -> t
|
||||
|
||||
instance FreeVars (Map T.Ident Type) where
|
||||
free :: Map T.Ident Type -> Set T.Ident
|
||||
|
|
@ -555,9 +564,12 @@ insertSig :: T.Ident -> Maybe Type -> Infer ()
|
|||
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
|
||||
|
||||
-- | Insert a constructor with its data type
|
||||
insertConstr :: T.Ident -> Type -> Infer ()
|
||||
insertConstr i t =
|
||||
modify (\st -> st{constructors = M.insert i t (constructors st)})
|
||||
insertInj :: T.Ident -> Type -> Infer ()
|
||||
insertInj i t =
|
||||
modify (\st -> st{injections = M.insert i t (injections st)})
|
||||
|
||||
existInj :: T.Ident -> Infer (Maybe Type)
|
||||
existInj n = gets (M.lookup n . injections)
|
||||
|
||||
-------- PATTERN MATCHING ---------
|
||||
|
||||
|
|
@ -601,37 +613,35 @@ inferPattern :: Pattern -> Infer (T.Pattern' Type, Type)
|
|||
inferPattern = \case
|
||||
PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt)
|
||||
PInj constr patterns -> do
|
||||
t <- gets (M.lookup (coerce constr) . constructors)
|
||||
t <- gets (M.lookup (coerce constr) . injections)
|
||||
t <- maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") t
|
||||
let numArgs = typeLength t - 1
|
||||
let (vs, ret) = fromJust (unsnoc $ flattenType t)
|
||||
patterns <- mapM inferPattern patterns
|
||||
unless
|
||||
(length patterns == numArgs)
|
||||
( throwError $
|
||||
"The constructor '"
|
||||
++ printTree constr
|
||||
++ "'"
|
||||
++ " should have "
|
||||
++ show numArgs
|
||||
++ " arguments but has been given "
|
||||
++ show (length patterns)
|
||||
( throwError $ Aux.do
|
||||
"The constructor"
|
||||
quote $ printTree constr
|
||||
" should have "
|
||||
show numArgs
|
||||
" arguments but has been given "
|
||||
show (length patterns)
|
||||
)
|
||||
sub <- composeAll <$> zipWithM unify vs (map snd patterns)
|
||||
return (T.PInj (coerce constr) (apply sub (map fst patterns)), apply sub ret)
|
||||
PCatch -> (T.PCatch,) <$> fresh
|
||||
PEnum p -> do
|
||||
t <- gets (M.lookup (coerce p) . constructors)
|
||||
t <- gets (M.lookup (coerce p) . injections)
|
||||
t <- maybeToRightM ("Constructor: " <> printTree p <> " does not exist") t
|
||||
unless
|
||||
(typeLength t == 1)
|
||||
( throwError $
|
||||
"The constructor '"
|
||||
++ printTree p
|
||||
++ "'"
|
||||
++ " should have "
|
||||
++ show (typeLength t - 1)
|
||||
++ " arguments but has been given 0"
|
||||
( throwError $ Aux.do
|
||||
"The constructor"
|
||||
quote $ printTree p
|
||||
" should have "
|
||||
show (typeLength t - 1)
|
||||
" arguments but has been given 0"
|
||||
)
|
||||
let (TData _data _ts) = t -- nasty nasty
|
||||
frs <- mapM (const fresh) _ts
|
||||
|
|
@ -687,7 +697,7 @@ data Env = Env
|
|||
{ count :: Int
|
||||
, nextChar :: Char
|
||||
, sigs :: Map T.Ident (Maybe Type)
|
||||
, constructors :: Map T.Ident Type
|
||||
, injections :: Map T.Ident Type
|
||||
, takenTypeVars :: Set T.Ident
|
||||
}
|
||||
deriving (Show)
|
||||
|
|
@ -697,3 +707,6 @@ type Subst = Map T.Ident Type
|
|||
|
||||
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a}
|
||||
deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env)
|
||||
|
||||
quote :: String -> String
|
||||
quote s = "'" ++ s ++ "'"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue