From ad615cc9d89ce1e4ed66430c3a4084bcf4945df7 Mon Sep 17 00:00:00 2001 From: Martin Fredin Date: Sat, 18 Feb 2023 13:26:41 +0100 Subject: [PATCH] Document and fix code style --- src/TypeChecker.hs | 197 ++++++++++++++++++++++----------------------- 1 file changed, 98 insertions(+), 99 deletions(-) diff --git a/src/TypeChecker.hs b/src/TypeChecker.hs index b30a360..380b009 100644 --- a/src/TypeChecker.hs +++ b/src/TypeChecker.hs @@ -14,7 +14,6 @@ import Grammar.Print (Print (prt), concatD, doc, printTree, import Prelude hiding (exp, id) import qualified TypeCheckerIr as T - -- NOTE: this type checker is poorly tested -- TODO @@ -22,9 +21,9 @@ import qualified TypeCheckerIr as T -- Type inference data Cxt = Cxt - { env :: Map Ident Type - , sig :: Map Ident Type - } + { env :: Map Ident Type -- ^ Local scope signature + , sig :: Map Ident Type -- ^ Top-level signatures + } initCxt :: [Bind] -> Cxt initCxt sc = Cxt { env = mempty @@ -34,133 +33,133 @@ initCxt sc = Cxt { env = mempty typecheck :: Program -> Err T.Program typecheck (Program sc) = T.Program <$> mapM (checkBind $ initCxt sc) sc - +-- | Check if infered rhs type matches type signature. checkBind :: Cxt -> Bind -> Err T.Bind checkBind cxt b = - case expandLambdas b of - Bind name t _ parms rhs -> do - (rhs', t_rhs) <- infer cxt rhs - - unless (typeEq t_rhs t) . throwError $ typeErr name t t_rhs - - pure $ T.Bind (name, t) (zip parms ts_parms) rhs' - - where - ts_parms = fst $ partitionType (length parms) t + case expandLambdas b of + Bind name t _ parms rhs -> do + (rhs', t_rhs) <- infer cxt rhs + unless (typeEq t_rhs t) . throwError $ typeErr name t t_rhs + pure $ T.Bind (name, t) (zip parms ts_parms) rhs' + where + ts_parms = fst $ partitionType (length parms) t +-- | @ f x y = rhs ⇒ f = \x.\y. rhs @ expandLambdas :: Bind -> Bind expandLambdas (Bind name t _ parms rhs) = Bind name t name [] rhs' where rhs' = foldr ($) rhs $ zipWith EAbs parms ts_parms ts_parms = fst $ partitionType (length parms) t - +-- | Infer type of expression. infer :: Cxt -> Exp -> Err (T.Exp, Type) infer cxt = \case + EId x -> + case lookupEnv x cxt of + Nothing -> + case lookupSig x cxt of + Nothing -> throwError ("Unbound variable:" ++ printTree x) + Just t -> pure (T.EId (x, t), t) + Just t -> pure (T.EId (x, t), t) - EId x -> - case lookupEnv x cxt of - Nothing -> - case lookupSig x cxt of - Nothing -> throwError ("Unbound variable:" ++ printTree x) - Just t -> pure (T.EId (x, t), t) - Just t -> pure (T.EId (x, t), t) + EInt i -> pure (T.EInt i, T.TInt) - EInt i -> pure (T.EInt i, T.TInt) + EApp e e1 -> do + (e', t) <- infer cxt e + case t of + TFun t1 t2 -> do + e1' <- check cxt e1 t1 + pure (T.EApp t2 e' e1', t2) + _ -> do + throwError ("Not a function: " ++ show e) - EApp e e1 -> do - (e', t) <- infer cxt e - case t of - TFun t1 t2 -> do - e1' <- check cxt e1 t1 - pure (T.EApp t2 e' e1', t2) - _ -> do - throwError ("Not a function: " ++ show e) + EAdd e e1 -> do + e' <- check cxt e T.TInt + e1' <- check cxt e1 T.TInt + pure (T.EAdd T.TInt e' e1', T.TInt) - EAdd e e1 -> do - e' <- check cxt e T.TInt - e1' <- check cxt e1 T.TInt - pure (T.EAdd T.TInt e' e1', T.TInt) + EAbs x t e -> do + (e', t1) <- infer (insertEnv x t cxt) e + let t_abs = TFun t t1 + pure (T.EAbs t_abs (x, t) e', t_abs) - EAbs x t e -> do - (e', t1) <- infer (insertEnv x t cxt) e - let t_abs = TFun t t1 - pure (T.EAbs t_abs (x, t) e', t_abs) + ELet b e -> do + let cxt' = insertBind b cxt + b' <- checkBind cxt' b + (e', t) <- infer cxt' e + pure (T.ELet b' e', t) - ELet b e -> do - let cxt' = insertBind b cxt - b' <- checkBind cxt' b - (e', t) <- infer cxt' e - pure (T.ELet b' e', t) - - EAnn e t -> do - (e', t1) <- infer cxt e - unless (typeEq t t1) $ - throwError "Inferred type and type annotation doesn't match" - pure (e', t1) + EAnn e t -> do + (e', t1) <- infer cxt e + unless (typeEq t t1) $ + throwError "Inferred type and type annotation doesn't match" + pure (e', t1) +-- | Check infered type matches the supplied type. check :: Cxt -> Exp -> Type -> Err T.Exp check cxt exp typ = case exp of - EId x -> do - t <- case lookupEnv x cxt of - Nothing -> maybeToRightM - ("Unbound variable:" ++ printTree x) - (lookupSig x cxt) - Just t -> pure t + EId x -> do + t <- case lookupEnv x cxt of + Nothing -> maybeToRightM + ("Unbound variable:" ++ printTree x) + (lookupSig x cxt) + Just t -> pure t + unless (typeEq t typ) . throwError $ typeErr x typ t + pure $ T.EId (x, t) - unless (typeEq t typ) . throwError $ typeErr x typ t + EInt i -> do + unless (typeEq typ TInt) $ throwError $ typeErr i TInt typ + pure $ T.EInt i - pure $ T.EId (x, t) + EApp e e1 -> do + (e', t) <- infer cxt e + case t of + TFun t1 t2 -> do + e1' <- check cxt e1 t1 + pure $ T.EApp t2 e' e1' + _ -> throwError ("Not a function 2: " ++ printTree e) - EInt i -> do - unless (typeEq typ TInt) $ throwError $ typeErr i TInt typ - pure $ T.EInt i + EAdd e e1 -> do + e' <- check cxt e T.TInt + e1' <- check cxt e1 T.TInt + pure $ T.EAdd T.TInt e' e1' - EApp e e1 -> do - (e', t) <- infer cxt e - case t of - TFun t1 t2 -> do - e1' <- check cxt e1 t1 - pure $ T.EApp t2 e' e1' - _ -> throwError ("Not a function 2: " ++ printTree e) + EAbs x t e -> do + (e', t_e) <- infer (insertEnv x t cxt) e + let t1 = TFun t t_e + unless (typeEq t1 typ) $ throwError "Wrong lamda type!" + pure $ T.EAbs t1 (x, t) e' - EAdd e e1 -> do - e' <- check cxt e T.TInt - e1' <- check cxt e1 T.TInt - pure $ T.EAdd T.TInt e' e1' + ELet b e -> do + let cxt' = insertBind b cxt + b' <- checkBind cxt' b + e' <- check cxt' e typ + pure $ T.ELet b' e' - EAbs x t e -> do - (e', t_e) <- infer (insertEnv x t cxt) e - let t1 = TFun t t_e - unless (typeEq t1 typ) $ throwError "Wrong lamda type!" - pure $ T.EAbs t1 (x, t) e' - - ELet b e -> do - let cxt' = insertBind b cxt - b' <- checkBind cxt' b - e' <- check cxt' e typ - pure $ T.ELet b' e' - - EAnn e t -> do - unless (typeEq t typ) $ - throwError "Inferred type and type annotation doesn't match" - check cxt e t - -insertBind :: Bind -> Cxt -> Cxt -insertBind (Bind n t _ _ _) = insertEnv n t + EAnn e t -> do + unless (typeEq t typ) $ + throwError "Inferred type and type annotation doesn't match" + check cxt e t +-- | Check if types are equivalent. Doesn't handle coercion or polymorphism. typeEq :: Type -> Type -> Bool typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1 typeEq t t1 = t == t1 -partitionType :: Int -> Type -> ([Type], Type) +-- | Partion type into types of parameters and return type. +partitionType :: Int -- Number of parameters to apply + -> Type + -> ([Type], Type) partitionType = go [] where go acc 0 t = (acc, t) go acc i t = case t of - TFun t1 t2 -> go (snoc t1 acc) (i - 1) t2 - _ -> error "Number of parameters and type doesn't match" + TFun t1 t2 -> go (snoc t1 acc) (i - 1) t2 + _ -> error "Number of parameters and type doesn't match" + +insertBind :: Bind -> Cxt -> Cxt +insertBind (Bind n t _ _ _) = insertEnv n t lookupEnv :: Ident -> Cxt -> Maybe Type lookupEnv x = Map.lookup x . env @@ -173,7 +172,7 @@ lookupSig x = Map.lookup x . sig typeErr :: Print a => a -> Type -> Type -> String typeErr p expected actual = render $ concatD - [ doc $ showString "Wrong type:", prt 0 p , doc $ showString "\n" - , doc $ showString "Expected:" , prt 0 expected, doc $ showString "\n" - , doc $ showString "Actual: " , prt 0 actual - ] + [ doc $ showString "Wrong type:", prt 0 p , doc $ showString "\n" + , doc $ showString "Expected:" , prt 0 expected, doc $ showString "\n" + , doc $ showString "Actual: " , prt 0 actual + ]