Document and fix code style

This commit is contained in:
Martin Fredin 2023-02-18 13:26:41 +01:00
parent a3e57dde7b
commit ad615cc9d8

View file

@ -14,7 +14,6 @@ import Grammar.Print (Print (prt), concatD, doc, printTree,
import Prelude hiding (exp, id) import Prelude hiding (exp, id)
import qualified TypeCheckerIr as T import qualified TypeCheckerIr as T
-- NOTE: this type checker is poorly tested -- NOTE: this type checker is poorly tested
-- TODO -- TODO
@ -22,9 +21,9 @@ import qualified TypeCheckerIr as T
-- Type inference -- Type inference
data Cxt = Cxt data Cxt = Cxt
{ env :: Map Ident Type { env :: Map Ident Type -- ^ Local scope signature
, sig :: Map Ident Type , sig :: Map Ident Type -- ^ Top-level signatures
} }
initCxt :: [Bind] -> Cxt initCxt :: [Bind] -> Cxt
initCxt sc = Cxt { env = mempty initCxt sc = Cxt { env = mempty
@ -34,133 +33,133 @@ initCxt sc = Cxt { env = mempty
typecheck :: Program -> Err T.Program typecheck :: Program -> Err T.Program
typecheck (Program sc) = T.Program <$> mapM (checkBind $ initCxt sc) sc typecheck (Program sc) = T.Program <$> mapM (checkBind $ initCxt sc) sc
-- | Check if infered rhs type matches type signature.
checkBind :: Cxt -> Bind -> Err T.Bind checkBind :: Cxt -> Bind -> Err T.Bind
checkBind cxt b = checkBind cxt b =
case expandLambdas b of case expandLambdas b of
Bind name t _ parms rhs -> do Bind name t _ parms rhs -> do
(rhs', t_rhs) <- infer cxt rhs (rhs', t_rhs) <- infer cxt rhs
unless (typeEq t_rhs t) . throwError $ typeErr name t t_rhs
unless (typeEq t_rhs t) . throwError $ typeErr name t t_rhs pure $ T.Bind (name, t) (zip parms ts_parms) rhs'
where
pure $ T.Bind (name, t) (zip parms ts_parms) rhs' ts_parms = fst $ partitionType (length parms) t
where
ts_parms = fst $ partitionType (length parms) t
-- | @ f x y = rhs ⇒ f = \x.\y. rhs @
expandLambdas :: Bind -> Bind expandLambdas :: Bind -> Bind
expandLambdas (Bind name t _ parms rhs) = Bind name t name [] rhs' expandLambdas (Bind name t _ parms rhs) = Bind name t name [] rhs'
where where
rhs' = foldr ($) rhs $ zipWith EAbs parms ts_parms rhs' = foldr ($) rhs $ zipWith EAbs parms ts_parms
ts_parms = fst $ partitionType (length parms) t ts_parms = fst $ partitionType (length parms) t
-- | Infer type of expression.
infer :: Cxt -> Exp -> Err (T.Exp, Type) infer :: Cxt -> Exp -> Err (T.Exp, Type)
infer cxt = \case infer cxt = \case
EId x ->
case lookupEnv x cxt of
Nothing ->
case lookupSig x cxt of
Nothing -> throwError ("Unbound variable:" ++ printTree x)
Just t -> pure (T.EId (x, t), t)
Just t -> pure (T.EId (x, t), t)
EId x -> EInt i -> pure (T.EInt i, T.TInt)
case lookupEnv x cxt of
Nothing ->
case lookupSig x cxt of
Nothing -> throwError ("Unbound variable:" ++ printTree x)
Just t -> pure (T.EId (x, t), t)
Just t -> pure (T.EId (x, t), t)
EInt i -> pure (T.EInt i, T.TInt) EApp e e1 -> do
(e', t) <- infer cxt e
case t of
TFun t1 t2 -> do
e1' <- check cxt e1 t1
pure (T.EApp t2 e' e1', t2)
_ -> do
throwError ("Not a function: " ++ show e)
EApp e e1 -> do EAdd e e1 -> do
(e', t) <- infer cxt e e' <- check cxt e T.TInt
case t of e1' <- check cxt e1 T.TInt
TFun t1 t2 -> do pure (T.EAdd T.TInt e' e1', T.TInt)
e1' <- check cxt e1 t1
pure (T.EApp t2 e' e1', t2)
_ -> do
throwError ("Not a function: " ++ show e)
EAdd e e1 -> do EAbs x t e -> do
e' <- check cxt e T.TInt (e', t1) <- infer (insertEnv x t cxt) e
e1' <- check cxt e1 T.TInt let t_abs = TFun t t1
pure (T.EAdd T.TInt e' e1', T.TInt) pure (T.EAbs t_abs (x, t) e', t_abs)
EAbs x t e -> do ELet b e -> do
(e', t1) <- infer (insertEnv x t cxt) e let cxt' = insertBind b cxt
let t_abs = TFun t t1 b' <- checkBind cxt' b
pure (T.EAbs t_abs (x, t) e', t_abs) (e', t) <- infer cxt' e
pure (T.ELet b' e', t)
ELet b e -> do EAnn e t -> do
let cxt' = insertBind b cxt (e', t1) <- infer cxt e
b' <- checkBind cxt' b unless (typeEq t t1) $
(e', t) <- infer cxt' e throwError "Inferred type and type annotation doesn't match"
pure (T.ELet b' e', t) pure (e', t1)
EAnn e t -> do
(e', t1) <- infer cxt e
unless (typeEq t t1) $
throwError "Inferred type and type annotation doesn't match"
pure (e', t1)
-- | Check infered type matches the supplied type.
check :: Cxt -> Exp -> Type -> Err T.Exp check :: Cxt -> Exp -> Type -> Err T.Exp
check cxt exp typ = case exp of check cxt exp typ = case exp of
EId x -> do EId x -> do
t <- case lookupEnv x cxt of t <- case lookupEnv x cxt of
Nothing -> maybeToRightM Nothing -> maybeToRightM
("Unbound variable:" ++ printTree x) ("Unbound variable:" ++ printTree x)
(lookupSig x cxt) (lookupSig x cxt)
Just t -> pure t Just t -> pure t
unless (typeEq t typ) . throwError $ typeErr x typ t
pure $ T.EId (x, t)
unless (typeEq t typ) . throwError $ typeErr x typ t EInt i -> do
unless (typeEq typ TInt) $ throwError $ typeErr i TInt typ
pure $ T.EInt i
pure $ T.EId (x, t) EApp e e1 -> do
(e', t) <- infer cxt e
case t of
TFun t1 t2 -> do
e1' <- check cxt e1 t1
pure $ T.EApp t2 e' e1'
_ -> throwError ("Not a function 2: " ++ printTree e)
EInt i -> do EAdd e e1 -> do
unless (typeEq typ TInt) $ throwError $ typeErr i TInt typ e' <- check cxt e T.TInt
pure $ T.EInt i e1' <- check cxt e1 T.TInt
pure $ T.EAdd T.TInt e' e1'
EApp e e1 -> do EAbs x t e -> do
(e', t) <- infer cxt e (e', t_e) <- infer (insertEnv x t cxt) e
case t of let t1 = TFun t t_e
TFun t1 t2 -> do unless (typeEq t1 typ) $ throwError "Wrong lamda type!"
e1' <- check cxt e1 t1 pure $ T.EAbs t1 (x, t) e'
pure $ T.EApp t2 e' e1'
_ -> throwError ("Not a function 2: " ++ printTree e)
EAdd e e1 -> do ELet b e -> do
e' <- check cxt e T.TInt let cxt' = insertBind b cxt
e1' <- check cxt e1 T.TInt b' <- checkBind cxt' b
pure $ T.EAdd T.TInt e' e1' e' <- check cxt' e typ
pure $ T.ELet b' e'
EAbs x t e -> do EAnn e t -> do
(e', t_e) <- infer (insertEnv x t cxt) e unless (typeEq t typ) $
let t1 = TFun t t_e throwError "Inferred type and type annotation doesn't match"
unless (typeEq t1 typ) $ throwError "Wrong lamda type!" check cxt e t
pure $ T.EAbs t1 (x, t) e'
ELet b e -> do
let cxt' = insertBind b cxt
b' <- checkBind cxt' b
e' <- check cxt' e typ
pure $ T.ELet b' e'
EAnn e t -> do
unless (typeEq t typ) $
throwError "Inferred type and type annotation doesn't match"
check cxt e t
insertBind :: Bind -> Cxt -> Cxt
insertBind (Bind n t _ _ _) = insertEnv n t
-- | Check if types are equivalent. Doesn't handle coercion or polymorphism.
typeEq :: Type -> Type -> Bool typeEq :: Type -> Type -> Bool
typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1 typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1
typeEq t t1 = t == t1 typeEq t t1 = t == t1
partitionType :: Int -> Type -> ([Type], Type) -- | Partion type into types of parameters and return type.
partitionType :: Int -- Number of parameters to apply
-> Type
-> ([Type], Type)
partitionType = go [] partitionType = go []
where where
go acc 0 t = (acc, t) go acc 0 t = (acc, t)
go acc i t = case t of go acc i t = case t of
TFun t1 t2 -> go (snoc t1 acc) (i - 1) t2 TFun t1 t2 -> go (snoc t1 acc) (i - 1) t2
_ -> error "Number of parameters and type doesn't match" _ -> error "Number of parameters and type doesn't match"
insertBind :: Bind -> Cxt -> Cxt
insertBind (Bind n t _ _ _) = insertEnv n t
lookupEnv :: Ident -> Cxt -> Maybe Type lookupEnv :: Ident -> Cxt -> Maybe Type
lookupEnv x = Map.lookup x . env lookupEnv x = Map.lookup x . env
@ -173,7 +172,7 @@ lookupSig x = Map.lookup x . sig
typeErr :: Print a => a -> Type -> Type -> String typeErr :: Print a => a -> Type -> Type -> String
typeErr p expected actual = render $ concatD typeErr p expected actual = render $ concatD
[ doc $ showString "Wrong type:", prt 0 p , doc $ showString "\n" [ doc $ showString "Wrong type:", prt 0 p , doc $ showString "\n"
, doc $ showString "Expected:" , prt 0 expected, doc $ showString "\n" , doc $ showString "Expected:" , prt 0 expected, doc $ showString "\n"
, doc $ showString "Actual: " , prt 0 actual , doc $ showString "Actual: " , prt 0 actual
] ]