improved the idea of error messages, still not very clean

This commit is contained in:
sebastianselander 2023-03-28 10:46:04 +02:00
parent 54f7d54bf9
commit 1558c98d10

View file

@ -1,4 +1,5 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QualifiedDo #-} {-# LANGUAGE QualifiedDo #-}
@ -39,32 +40,36 @@ runC e c =
. flip evalStateT e . flip evalStateT e
. runInfer . runInfer
typecheck :: Program -> Either Error (T.Program' Type) typecheck :: Program -> Either String (T.Program' Type)
typecheck = run . checkPrg typecheck = onLeft msg . run . checkPrg
where
onLeft :: (Error -> String) -> Either Error a -> Either String a
onLeft f (Left x) = Left $ f x
onLeft _ (Right x) = Right x
checkData :: Data -> Infer () checkData :: Data -> Infer ()
checkData (Data typ injs) = do checkData err@(Data typ injs) = do
(name, tvars) <- go typ (name, tvars) <- go typ
mapM_ (\i -> typecheckInj i name tvars) injs dataErr (mapM_ (\i -> typecheckInj i name tvars) injs) err
where where
go = \case go = \case
TData name typs TData name typs
| Right tvars' <- mapM toTVar typs -> | Right tvars' <- mapM toTVar typs ->
pure (name, tvars') pure (name, tvars')
TAll _ _ -> throwError "Explicit foralls not allowed, for now" TAll _ _ -> uncatchableErr "Explicit foralls not allowed, for now"
_ -> throwError $ unwords ["Bad data type definition: ", printTree typ] _ -> uncatchableErr $ unwords ["Bad data type definition: ", printTree typ]
typecheckInj :: Inj -> UIdent -> [TVar] -> Infer () typecheckInj :: Inj -> UIdent -> [TVar] -> Infer ()
typecheckInj (Inj c inj_typ) name tvars typecheckInj (Inj c inj_typ) name tvars
| Right False <- boundTVars tvars inj_typ = | Right False <- boundTVars tvars inj_typ =
throwError "Unbound type variables" catchableErr "Unbound type variables"
| TData name' typs <- returnType inj_typ | TData name' typs <- returnType inj_typ
, Right tvars' <- mapM toTVar typs , Right tvars' <- mapM toTVar typs
, name' == name , name' == name
, tvars' == tvars = do , tvars' == tvars = do
exist <- existInj (coerce c) exist <- existInj (coerce c)
case exist of case exist of
Just t -> throwError $ Aux.do Just t -> uncatchableErr $ Aux.do
"Constructor" "Constructor"
quote $ coerce name quote $ coerce name
"with type" "with type"
@ -72,7 +77,7 @@ typecheckInj (Inj c inj_typ) name tvars
"already exist" "already exist"
Nothing -> insertInj (coerce c) inj_typ Nothing -> insertInj (coerce c) inj_typ
| otherwise = | otherwise =
throwError $ uncatchableErr $
unwords unwords
[ "Bad type constructor: " [ "Bad type constructor: "
, show name , show name
@ -84,7 +89,7 @@ typecheckInj (Inj c inj_typ) name tvars
where where
boundTVars :: [TVar] -> Type -> Either Error Bool boundTVars :: [TVar] -> Type -> Either Error Bool
boundTVars tvars' = \case boundTVars tvars' = \case
TAll{} -> throwError "Explicit foralls not allowed, for now" TAll{} -> uncatchableErr "Explicit foralls not allowed, for now"
TFun t1 t2 -> do TFun t1 t2 -> do
t1' <- boundTVars tvars t1 t1' <- boundTVars tvars t1
t2' <- boundTVars tvars t2 t2' <- boundTVars tvars t2
@ -94,10 +99,10 @@ typecheckInj (Inj c inj_typ) name tvars
TLit _ -> return True TLit _ -> return True
TEVar _ -> error "TEVar in data type declaration" TEVar _ -> error "TEVar in data type declaration"
toTVar :: Type -> Either String TVar toTVar :: Type -> Either Error TVar
toTVar = \case toTVar = \case
TVar tvar -> pure tvar TVar tvar -> pure tvar
_ -> throwError "Not a type variable" _ -> uncatchableErr "Not a type variable"
returnType :: Type -> Type returnType :: Type -> Type
returnType (TFun _ t2) = returnType t2 returnType (TFun _ t2) = returnType t2
@ -117,7 +122,7 @@ preRun (x : xs) = case x of
gets (M.member (coerce n) . sigs) gets (M.member (coerce n) . sigs)
>>= flip >>= flip
when when
( throwError $ Aux.do ( uncatchableErr $ Aux.do
"Duplicate signatures for function" "Duplicate signatures for function"
quote $ printTree n quote $ printTree n
) )
@ -156,7 +161,7 @@ checkBind (Bind name args e) = do
sub2 <- unify t' lambda_t sub2 <- unify t' lambda_t
unless unless
(apply sub1 lambda_t == t' && lambda_t == apply sub2 t') (apply sub1 lambda_t == t' && lambda_t == apply sub2 t')
( throwError $ Aux.do ( uncatchableErr $ Aux.do
"Inferred type" "Inferred type"
quote $ printTree lambda_t quote $ printTree lambda_t
"does not match specified type" "does not match specified type"
@ -232,13 +237,11 @@ algoW = \case
sub2 <- unify t' t sub2 <- unify t' t
unless unless
(apply sub1 t == t' && apply sub2 t' == t) (apply sub1 t == t' && apply sub2 t' == t)
( throwError $ ( uncatchableErr $ Aux.do
unwords "Annotated type"
[ "Annotated type:" quote $ printTree t
, printTree t "does not match inferred type"
, "does not match inferred type:" quote $ printTree t'
, printTree t'
]
) )
s2 <- exprErr (unify t t') err s2 <- exprErr (unify t t') err
let comp = s2 `compose` s1 let comp = s2 `compose` s1
@ -263,16 +266,16 @@ algoW = \case
fr <- fresh fr <- fresh
insertSig (coerce i) (Just fr) insertSig (coerce i) (Just fr)
return (nullSubst, (T.EVar $ coerce i, fr)) return (nullSubst, (T.EVar $ coerce i, fr))
Nothing -> throwError $ "Unbound variable: " <> printTree i Nothing -> uncatchableErr $ "Unbound variable: " <> printTree i
EInj i -> do EInj i -> do
constr <- gets injections constr <- gets injections
case M.lookup (coerce i) constr of case M.lookup (coerce i) constr of
Just t -> return (nullSubst, (T.EVar $ coerce i, t)) Just t -> return (nullSubst, (T.EVar $ coerce i, t))
Nothing -> Nothing ->
throwError $ uncatchableErr $ Aux.do
"Constructor: '" "Constructor:"
<> printTree i quote $ printTree i
<> "' is not defined" "is not defined"
-- \| τ = newvar Γ, x : τ ⊢ e : τ', S -- \| τ = newvar Γ, x : τ ⊢ e : τ', S
-- \| --------------------------------- -- \| ---------------------------------
@ -365,7 +368,7 @@ unify t0 t1 = do
(TLit a, TLit b) -> (TLit a, TLit b) ->
if a == b if a == b
then return M.empty then return M.empty
else throwError $ else catchableErr $
Aux.do Aux.do
"Can not unify" "Can not unify"
quote $ printTree (TLit a) quote $ printTree (TLit a)
@ -376,7 +379,7 @@ unify t0 t1 = do
then do then do
xs <- zipWithM unify t t' xs <- zipWithM unify t t'
return $ foldr compose nullSubst xs return $ foldr compose nullSubst xs
else throwError $ else catchableErr $
Aux.do Aux.do
"Type constructor:" "Type constructor:"
printTree name printTree name
@ -387,14 +390,14 @@ unify t0 t1 = do
(TEVar a, TEVar b) -> (TEVar a, TEVar b) ->
if a == b if a == b
then return M.empty then return M.empty
else throwError $ else catchableErr $
Aux.do Aux.do
"Can not unify" "Can not unify"
quote $ printTree (TEVar a) quote $ printTree (TEVar a)
"with" "with"
quote $ printTree (TEVar b) quote $ printTree (TEVar b)
(a, b) -> do (a, b) -> do
throwError $ catchableErr $
Aux.do Aux.do
"Can not unify" "Can not unify"
quote $ printTree a quote $ printTree a
@ -409,12 +412,14 @@ occurs :: T.Ident -> Type -> Infer Subst
occurs i t@(TVar _) = return (M.singleton i t) occurs i t@(TVar _) = return (M.singleton i t)
occurs i t = occurs i t =
if S.member i (free t) if S.member i (free t)
then throwError $ then
Aux.do catchableErr
"Occurs check failed, can't unify" ( Aux.do
quote $ printTree (TVar $ T.MkTVar (coerce i)) "Occurs check failed, can't unify"
"with" quote $ printTree (TVar $ T.MkTVar (coerce i))
quote $ printTree t "with"
quote $ printTree t
)
else return $ M.singleton i t else return $ M.singleton i t
-- | Generalize a type over all free variables in the substitution set -- | Generalize a type over all free variables in the substitution set
@ -581,7 +586,7 @@ existInj n = gets (M.lookup n . injections)
-------- PATTERN MATCHING --------- -------- PATTERN MATCHING ---------
checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type) checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type)
checkCase _ [] = throwError "Atleast one case required" checkCase _ [] = catchableErr "Atleast one case required"
checkCase expT brnchs = do checkCase expT brnchs = do
(subs, injTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs (subs, injTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs
let sub0 = composeAll subs let sub0 = composeAll subs
@ -621,13 +626,23 @@ inferPattern = \case
PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt) PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt)
PInj constr patterns -> do PInj constr patterns -> do
t <- gets (M.lookup (coerce constr) . injections) t <- gets (M.lookup (coerce constr) . injections)
t <- maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") t t <-
maybeToRightM
( Error
( Aux.do
"Constructor:"
quote $ printTree constr
"does not exist"
)
True
)
t
let numArgs = typeLength t - 1 let numArgs = typeLength t - 1
let (vs, ret) = fromJust (unsnoc $ flattenType t) let (vs, ret) = fromJust (unsnoc $ flattenType t)
patterns <- mapM inferPattern patterns patterns <- mapM inferPattern patterns
unless unless
(length patterns == numArgs) (length patterns == numArgs)
( throwError $ Aux.do ( catchableErr $ Aux.do
"The constructor" "The constructor"
quote $ printTree constr quote $ printTree constr
" should have " " should have "
@ -640,10 +655,20 @@ inferPattern = \case
PCatch -> (T.PCatch,) <$> fresh PCatch -> (T.PCatch,) <$> fresh
PEnum p -> do PEnum p -> do
t <- gets (M.lookup (coerce p) . injections) t <- gets (M.lookup (coerce p) . injections)
t <- maybeToRightM ("Constructor: " <> printTree p <> " does not exist") t t <-
maybeToRightM
( Error
( Aux.do
"Constructor:"
quote $ printTree p
"does not exist"
)
True
)
t
unless unless
(typeLength t == 1) (typeLength t == 1)
( throwError $ Aux.do ( catchableErr $ Aux.do
"The constructor" "The constructor"
quote $ printTree p quote $ printTree p
" should have " " should have "
@ -686,8 +711,10 @@ partitionType = go []
_ -> error "Number of parameters and type doesn't match" _ -> error "Number of parameters and type doesn't match"
exprErr :: Infer a -> Exp -> Infer a exprErr :: Infer a -> Exp -> Infer a
exprErr ma exp = exprErr ma exp = catchError ma (\x -> if x.catchable then throwError (x{msg = x.msg <> " in expression: \n" <> printTree exp, catchable = False}) else throwError x)
catchError ma (\x -> throwError $ x <> " in expression: \n" <> printTree exp)
dataErr :: Infer a -> Data -> Infer a
dataErr ma d = catchError ma (\x -> if x.catchable then throwError (x{msg = x.msg <> " in data: \n" <> printTree d}) else throwError (x{catchable = False}))
unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d]) unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d])
unzip4 = unzip4 =
@ -709,11 +736,17 @@ data Env = Env
} }
deriving (Show) deriving (Show)
type Error = String data Error = Error {msg :: String, catchable :: Bool}
type Subst = Map T.Ident Type type Subst = Map T.Ident Type
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a} newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a}
deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env) deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env)
catchableErr :: MonadError Error m => String -> m a
catchableErr msg = throwError $ Error msg True
uncatchableErr :: MonadError Error m => String -> m a
uncatchableErr msg = throwError $ Error msg False
quote :: String -> String quote :: String -> String
quote s = "'" ++ s ++ "'" quote s = "'" ++ s ++ "'"