From 210e55bb15d3596afe70e26cebadd1de921607a2 Mon Sep 17 00:00:00 2001 From: Martin Fredin Date: Wed, 15 Feb 2023 23:55:16 +0100 Subject: [PATCH] Adjust old type checker to new syntax, and refactor lambda lifter to use typed AST --- Grammar.cf | 14 ++- language.cabal | 5 +- sample-programs/basic-1 | 3 +- sample-programs/basic-10 | 3 + sample-programs/basic-2 | 7 +- sample-programs/basic-3 | 3 +- sample-programs/basic-4 | 7 +- sample-programs/basic-5 | 9 +- sample-programs/basic-6 | 3 +- sample-programs/basic-7 | 3 + sample-programs/basic-8 | 7 +- sample-programs/basic-9 | 3 +- src/Auxiliary.hs | 12 ++- src/LambdaLifter.hs | 225 +++++++++++++++++---------------------- src/Main.hs | 24 ++++- src/Renamer.hs | 83 +++++++++++++++ src/TypeChecker.hs | 180 +++++++++++++++++++++++++++++++ src/TypeCheckerIr.hs | 108 +++++++++++++++++++ 18 files changed, 554 insertions(+), 145 deletions(-) create mode 100644 sample-programs/basic-10 create mode 100644 src/Renamer.hs create mode 100644 src/TypeCheckerIr.hs diff --git a/Grammar.cf b/Grammar.cf index 410d11d..bbddf2f 100644 --- a/Grammar.cf +++ b/Grammar.cf @@ -7,14 +7,22 @@ EInt. Exp3 ::= Integer; ELet. Exp3 ::= "let" [Bind] "in" Exp; EApp. Exp2 ::= Exp2 Exp3; EAdd. Exp1 ::= Exp1 "+" Exp2; -EAbs. Exp ::= "\\" Ident "." Exp; +EAbs. Exp ::= "\\" Ident ":" Type "." Exp; +EAnn. Exp3 ::= "(" Exp ":" Type ")"; + +Bind. Bind ::= Ident ":" Type ";" + Ident [Ident] "=" Exp ; -Bind. Bind ::= Ident [Ident] "=" Exp; separator Bind ";"; -separator Ident " "; +separator Ident ""; coercions Exp 3; +TInt. Type1 ::= "Int" ; +TPol. Type1 ::= Ident ; +TFun. Type ::= Type1 "->" Type ; +coercions Type 1 ; + comment "--"; comment "{-" "-}"; diff --git a/language.cabal b/language.cabal index 0577abe..8d0e109 100644 --- a/language.cabal +++ b/language.cabal @@ -33,7 +33,10 @@ executable language Grammar.ErrM LambdaLifter Auxiliary - Interpreter + -- Interpreter + Renamer + TypeChecker + TypeCheckerIr hs-source-dirs: src diff --git a/sample-programs/basic-1 b/sample-programs/basic-1 index f109950..84a2499 100644 --- a/sample-programs/basic-1 +++ b/sample-programs/basic-1 @@ -1,2 +1,3 @@ -f = \x. x+1; +f : Int -> Int; +f = \x:Int. x+1; diff --git a/sample-programs/basic-10 b/sample-programs/basic-10 new file mode 100644 index 0000000..e12cb39 --- /dev/null +++ b/sample-programs/basic-10 @@ -0,0 +1,3 @@ + +main : Int -> Int -> Int; +main x y = (x : Int) + y; diff --git a/sample-programs/basic-2 b/sample-programs/basic-2 index 4b8ead0..7ece283 100644 --- a/sample-programs/basic-2 +++ b/sample-programs/basic-2 @@ -1,4 +1,7 @@ -add x = \y. x+y; -main = (\z. z+z) ((add 4) 6); +add : Int -> Int -> Int; +add x = \y:Int. x+y; + +main : Int; +main = (\z:Int. z+z) ((add 4) 6); diff --git a/sample-programs/basic-3 b/sample-programs/basic-3 index 9443439..2110141 100644 --- a/sample-programs/basic-3 +++ b/sample-programs/basic-3 @@ -1,2 +1,3 @@ -main = (\x. x+x+3) ((\x. x) 2) +main : Int; +main = (\x:Int. x+x+3) ((\x:Int. x) 2); diff --git a/sample-programs/basic-4 b/sample-programs/basic-4 index 1de7a8c..71e257f 100644 --- a/sample-programs/basic-4 +++ b/sample-programs/basic-4 @@ -1,2 +1,7 @@ -f x = let g = (\y. y+1) in g (g x) +f : Int -> Int; +f x = let + g : Int -> Int; + g = (\y:Int. y+1); + in + g (g x); diff --git a/sample-programs/basic-5 b/sample-programs/basic-5 index 9984ddd..f5e8154 100644 --- a/sample-programs/basic-5 +++ b/sample-programs/basic-5 @@ -1,9 +1,14 @@ +id : Int -> Int; id x = x; +add : Int -> Int -> Int; add x y = x + y; +double : Int -> Int; double n = n + n; -apply f x = \y. f x y; +apply : (Int -> Int -> Int) -> Int -> Int -> Int; +apply f x = \y:Int. f x y; -main = apply (id add) ((\x. x + 1) 1) (double 3); +main : Int; +main = apply add ((\x:Int. x + 1) 1) (double (id 3)); diff --git a/sample-programs/basic-6 b/sample-programs/basic-6 index 511ae10..73ee1b5 100644 --- a/sample-programs/basic-6 +++ b/sample-programs/basic-6 @@ -1,3 +1,4 @@ -f = \x.\y. x+y +f : Int -> Int -> Int; +f = \x:Int.\y:Int. x+y; diff --git a/sample-programs/basic-7 b/sample-programs/basic-7 index b3769b9..763d271 100644 --- a/sample-programs/basic-7 +++ b/sample-programs/basic-7 @@ -1,5 +1,8 @@ +add : Int -> Int -> Int; add x y = x + y; +apply : (Int -> Int) -> Int -> Int; apply f x = f x; +main : Int; main = apply (add 4) 6; diff --git a/sample-programs/basic-8 b/sample-programs/basic-8 index 59abdac..8e8162f 100644 --- a/sample-programs/basic-8 +++ b/sample-programs/basic-8 @@ -1,2 +1,7 @@ -f x = let double = \y. y+y in (\x. x+y) 4; +f : Int -> Int; +f x = let + double : Int -> Int; + double = \y:Int. y+y + in + double (x + 4); diff --git a/sample-programs/basic-9 b/sample-programs/basic-9 index ba9ebdc..d214a8e 100644 --- a/sample-programs/basic-9 +++ b/sample-programs/basic-9 @@ -1,4 +1,5 @@ -main = (\f.\x.\y. f x + f y) (\x. x+x) ((\x. x+1) ((\x. x+3) 2)) 4 +main : Int; +main = (\f:Int -> Int.\x:Int.\y:Int. f x + f y) (\x:Int. x+x) ((\x:Int. x+1) ((\x:Int. x+3) 2)) 4 diff --git a/src/Auxiliary.hs b/src/Auxiliary.hs index 2de36a7..735d804 100644 --- a/src/Auxiliary.hs +++ b/src/Auxiliary.hs @@ -1,4 +1,4 @@ - +{-# LANGUAGE LambdaCase #-} module Auxiliary (module Auxiliary) where import Control.Monad.Error.Class (liftEither) import Control.Monad.Except (MonadError) @@ -9,3 +9,13 @@ snoc x xs = xs ++ [x] maybeToRightM :: MonadError l m => l -> Maybe r -> m r maybeToRightM err = liftEither . maybeToRight err + +mapAccumM :: Monad m => (s -> a -> m (s, b)) -> s -> [a] -> m (s, [b]) +mapAccumM f = go + where + go acc = \case + [] -> pure (acc, []) + x:xs -> do + (acc', x') <- f acc x + (acc'', xs') <- go acc' xs + pure (acc'', x':xs') diff --git a/src/LambdaLifter.hs b/src/LambdaLifter.hs index 3d9595d..eb8845a 100644 --- a/src/LambdaLifter.hs +++ b/src/LambdaLifter.hs @@ -5,21 +5,20 @@ module LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where import Auxiliary (snoc) +import Control.Applicative (Applicative (liftA2)) +import Control.Monad.State (MonadState (get, put), State, evalState) import Data.Foldable.Extra (notNull) -import Data.List (mapAccumL, mapAccumR, partition) -import Data.Map (Map) -import qualified Data.Map as Map -import Data.Maybe (fromMaybe, mapMaybe) +import Data.List (mapAccumL, partition) import Data.Set (Set, (\\)) import qualified Data.Set as Set -import Data.Tuple.Extra (uncurry3) -import Grammar.Abs import Prelude hiding (exp) +import Renamer hiding (fromBinders) +import TypeCheckerIr -- | Lift lambdas and let expression into supercombinators. lambdaLift :: Program -> Program -lambdaLift = collectScs . rename . abstract . freeVars +lambdaLift = collectScs . abstract . freeVars -- | Annotate free variables @@ -28,25 +27,25 @@ freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e) | Bind n xs e <- ds ] -freeVarsExp :: Set Ident -> Exp -> AnnExp +freeVarsExp :: Set Id -> Exp -> AnnExp freeVarsExp localVars = \case - EId n | Set.member n localVars -> (Set.singleton n, AId n) - | otherwise -> (mempty, AId n) + EId n | Set.member n localVars -> (Set.singleton n, AId n) + | otherwise -> (mempty, AId n) EInt i -> (mempty, AInt i) - EApp e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp e1' e2') + EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2') where e1' = freeVarsExp localVars e1 e2' = freeVarsExp localVars e2 - EAdd e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd e1' e2') + EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2') where e1' = freeVarsExp localVars e1 e2' = freeVarsExp localVars e2 - EAbs par e -> (Set.delete par $ freeVarsOf e', AAbs par e') + EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e') where e' = freeVarsExp (Set.insert par localVars) e @@ -66,143 +65,111 @@ freeVarsExp localVars = \case binders' = zipWith3 ABind names parms rhss' e' = freeVarsExp e_localVars e + EAnn e t -> (freeVarsOf e', AAnn e' t) + where + e' = freeVarsExp localVars e -freeVarsOf :: AnnExp -> Set Ident + +freeVarsOf :: AnnExp -> Set Id freeVarsOf = fst -fromBinders :: [Bind] -> ([Ident], [[Ident]], [Exp]) + +fromBinders :: [Bind] -> ([Id], [[Id]], [Exp]) fromBinders bs = unzip3 [ (name, parms, rhs) | Bind name parms rhs <- bs ] + -- AST annotated with free variables -type AnnProgram = [(Ident, [Ident], AnnExp)] +type AnnProgram = [(Id, [Id], AnnExp)] -type AnnExp = (Set Ident, AnnExp') +type AnnExp = (Set Id, AnnExp') -data ABind = ABind Ident [Ident] AnnExp deriving Show +data ABind = ABind Id [Id] AnnExp deriving Show -data AnnExp' = AId Ident +data AnnExp' = AId Id | AInt Integer - | AApp AnnExp AnnExp - | AAdd AnnExp AnnExp - | AAbs Ident AnnExp | ALet [ABind] AnnExp + | AApp Type AnnExp AnnExp + | AAdd Type AnnExp AnnExp + | AAbs Type Id AnnExp + | AAnn AnnExp Type deriving Show + -- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@. -- Free variables are @v₁ v₂ .. vₙ@ are bound. abstract :: AnnProgram -> Program -abstract prog = Program $ map go prog +abstract prog = Program $ evalState (mapM go prog) 0 where - go :: (Ident, [Ident], AnnExp) -> Bind - go (name, pars, rhs@(_, e)) = + go :: (Id, [Id], AnnExp) -> State Int Bind + go (name, parms, rhs@(_, e)) = case e of - AAbs par e1 -> Bind name (snoc par pars ++ pars2) $ abstractExp e2 + AAbs _ parm e1 -> do + e2' <- abstractExp e2 + pure $ Bind name (snoc parm parms ++ parms2) e2' where - (e2, pars2) = flattenLambdasAnn e1 - _ -> Bind name pars $ abstractExp rhs + (e2, parms2) = flattenLambdasAnn e1 + + _ -> Bind name parms <$> abstractExp rhs -- | Flatten nested lambdas and collect the parameters -- @\x.\y.\z. ae → (ae, [x,y,z])@ -flattenLambdasAnn :: AnnExp -> (AnnExp, [Ident]) +flattenLambdasAnn :: AnnExp -> (AnnExp, [Id]) flattenLambdasAnn ae = go (ae, []) where - go :: (AnnExp, [Ident]) -> (AnnExp, [Ident]) + go :: (AnnExp, [Id]) -> (AnnExp, [Id]) go ((free, e), acc) = case e of - AAbs par (free1, e1) -> go ((Set.delete par free1, e1), snoc par acc) - _ -> ((free, e), acc) + AAbs _ par (free1, e1) -> + go ((Set.delete par free1, e1), snoc par acc) + _ -> ((free, e), acc) - -abstractExp :: AnnExp -> Exp +abstractExp :: AnnExp -> State Int Exp abstractExp (free, exp) = case exp of - AId n -> EId n - AInt i -> EInt i - AApp e1 e2 -> EApp (abstractExp e1) (abstractExp e2) - AAdd e1 e2 -> EAdd (abstractExp e1) (abstractExp e2) - ALet bs e -> ELet (map go bs) $ abstractExp e + AId n -> pure $ EId n + AInt i -> pure $ EInt i + AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2) + AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2) + ALet bs e -> liftA2 ELet (mapM go bs) (abstractExp e) where - go (ABind name parms rhs) = - let - (rhs', parms1) = flattenLambdas $ skipLambdas abstractExp rhs - in - Bind name (parms ++ parms1) rhs' + go (ABind name parms rhs) = do + (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs + pure $ Bind name (parms ++ parms1) rhs' - skipLambdas :: (AnnExp -> Exp) -> AnnExp -> Exp + skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp skipLambdas f (free, ae) = case ae of - AAbs name ae1 -> EAbs name $ skipLambdas f ae1 - _ -> f (free, ae) + AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1 + _ -> f (free, ae) -- Lift lambda into let and bind free variables - AAbs par e -> foldl EApp sc $ map EId freeList + AAbs t parm e -> do + i <- nextNumber + rhs <- abstractExp e + + let sc_name = Ident ("sc_" ++ show i) + sc = ELet [Bind (sc_name, t_bind) parms rhs] $ EId (sc_name, t) + + pure $ foldl (EApp TInt) sc $ map EId freeList where freeList = Set.toList free - sc = ELet [Bind "sc" (snoc par freeList) $ abstractExp e] $ EId "sc" + t_bind = typeApplyPars (length parm) t + parms = snoc parm freeList --- | Rename all supercombinators and variables -rename :: Program -> Program -rename (Program ds) = Program $ map (uncurry3 Bind) tuples - where - tuples = snd (mapAccumL renameSc 0 ds) - renameSc i (Bind n xs e) = (i2, (n, xs', e')) - where - (i1, xs', env) = newNames i xs - (i2, e') = renameExp env i1 e + AAnn e t -> abstractExp e >>= \e' -> pure $ EAnn e' t -renameExp :: Map Ident Ident -> Int -> Exp -> (Int, Exp) -renameExp env i = \case - - EId n -> (i, EId . fromMaybe n $ Map.lookup n env) - - EInt i1 -> (i, EInt i1) - - EApp e1 e2 -> (i2, EApp e1' e2') - where - (i1, e1') = renameExp env i e1 - (i2, e2') = renameExp env i1 e2 - - EAdd e1 e2 -> (i2, EAdd e1' e2') - where - (i1, e1') = renameExp env i e1 - (i2, e2') = renameExp env i1 e2 - - ELet bs e -> (i3, ELet (zipWith3 Bind ns' pars' es') e') - where - (i1, e') = renameExp e_env i e - (names, pars, rhss) = fromBinders bs - (i2, ns', env') = newNames i1 (names ++ concat pars) - pars' = (map . map) renamePar pars - e_env = Map.union env' env - (i3, es') = mapAccumL (renameExp e_env) i2 rhss - - renamePar p = case Map.lookup p env' of - Just p' -> p' - Nothing -> error ("Can't find name for " ++ show p) +nextNumber :: State Int Int +nextNumber = do + i <- get + put $ succ i + pure i - EAbs par e -> (i2, EAbs par' e') - where - (i1, par', env') = newName par - (i2, e') = renameExp (Map.union env' env ) i1 e - - -newName :: Ident -> (Int, Ident, Map Ident Ident) -newName old_name = (i, head names, env) - where (i, names, env) = newNames 1 [old_name] - -newNames :: Int -> [Ident] -> (Int, [Ident], Map Ident Ident) -newNames i old_names = (i', new_names, env) - where - (i', new_names) = getNames i old_names - env = Map.fromList $ zip old_names new_names - -getNames :: Int -> [Ident] -> (Int, [Ident]) -getNames i ns = (i + length ss, zipWith makeName ss [i..]) - where - ss = map (\(Ident s) -> s) ns - -makeName :: String -> Int -> Ident -makeName prefix i = Ident (prefix ++ "_" ++ show i) +typeApplyPars :: Int -> Type -> Type +typeApplyPars 0 t = t +typeApplyPars i t = + case t of + TFun _ t1 -> typeApplyPars (i-1) t1 + _ -> error "Number of applied pars and type not matching" -- | Collects supercombinators by lifting appropriate let expressions @@ -216,20 +183,20 @@ collectScs (Program scs) = Program $ concatMap collectFromRhs scs collectScsExp :: Exp -> ([Bind], Exp) collectScsExp = \case - EId n -> ([], EId n) - EInt i -> ([], EInt i) + EId n -> ([], EId n) + EInt i -> ([], EInt i) - EApp e1 e2 -> (scs1 ++ scs2, EApp e1' e2') + EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2') where (scs1, e1') = collectScsExp e1 (scs2, e2') = collectScsExp e2 - EAdd e1 e2 -> (scs1 ++ scs2, EAdd e1' e2') + EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2') where (scs1, e1') = collectScsExp e1 (scs2, e2') = collectScsExp e2 - EAbs x e -> (scs, EAbs x e') + EAbs t par e -> (scs, EAbs t par e') where (scs, e') = collectScsExp e @@ -241,28 +208,32 @@ collectScsExp = \case -- > ... -- > in e -- - ELet binds e -> (binds_scs ++ rhss_scs ++ e_scs, mkEAbs non_scs' e') + ELet binds e -> (binds_scs ++ rhss_scs ++ e_scs, mkEAbs non_scs' e') where - binds_scs = [ let (rhs', parms1) = flattenLambdas rhs in - Bind n (parms ++ parms1) rhs' - | Bind n parms rhs <- scs' - ] - (rhss_scs, binds') = mapAccumL collectScsRhs [] binds - (e_scs, e') = collectScsExp e + binds_scs = [ let (rhs', parms1) = flattenLambdas rhs in + Bind n (parms ++ parms1) rhs' + | Bind n parms rhs <- scs' + ] + (rhss_scs, binds') = mapAccumL collectScsRhs [] binds + (e_scs, e') = collectScsExp e - (scs', non_scs') = partition (\(Bind _ pars _) -> notNull pars) binds' + (scs', non_scs') = partition (\(Bind _ pars _) -> notNull pars) binds' collectScsRhs acc (Bind n xs rhs) = (acc ++ rhs_scs, Bind n xs rhs') where (rhs_scs, rhs') = collectScsExp rhs + EAnn e t -> (scs, EAnn e' t) + where + (scs, e') = collectScsExp e + -- @\x.\y.\z. e → (e, [x,y,z])@ -flattenLambdas :: Exp -> (Exp, [Ident]) -flattenLambdas e = go (e, []) +flattenLambdas :: Exp -> (Exp, [Id]) +flattenLambdas = go . (, []) where go (e, acc) = case e of - EAbs par e1 -> go (e1, snoc par acc) - _ -> (e, acc) + EAbs _ par e1 -> go (e1, snoc par acc) + _ -> (e, acc) mkEAbs :: [Bind] -> Exp -> Exp mkEAbs [] e = e diff --git a/src/Main.hs b/src/Main.hs index 41379fc..574fc0c 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -4,10 +4,12 @@ module Main where import Grammar.ErrM (Err) import Grammar.Par (myLexer, pProgram) import Grammar.Print (printTree) -import Interpreter (interpret) +--import Interpreter (interpret) import LambdaLifter (abstract, freeVars, lambdaLift) +import Renamer (rename) import System.Environment (getArgs) import System.Exit (exitFailure, exitSuccess) +import TypeChecker (typecheck) main :: IO () main = getArgs >>= \case @@ -18,12 +20,20 @@ main' :: String -> IO () main' s = do file <- readFile s - putStrLn "\n-- parse" + putStrLn "\n-- Parser" parsed <- fromSyntaxErr . pProgram $ myLexer file putStrLn $ printTree parsed + putStrLn "\n-- Renamer" + let renamed = rename parsed + putStrLn $ printTree renamed + + putStrLn "\n-- TypeChecker" + typechecked <- fromTypeCheckerErr $ typecheck renamed + putStrLn $ printTree typechecked + putStrLn "\n-- Lambda Lifter" - let lifted = lambdaLift parsed + let lifted = lambdaLift typechecked putStrLn $ printTree lifted -- interpred <- fromInterpreterErr $ interpret lifted @@ -41,6 +51,14 @@ fromSyntaxErr = either exitFailure) pure +fromTypeCheckerErr :: Err a -> IO a +fromTypeCheckerErr = either + (\err -> do + putStrLn "\nTYPECHECKER ERROR" + putStrLn err + exitFailure) + pure + fromInterpreterErr :: Err a -> IO a fromInterpreterErr = either (\err -> do diff --git a/src/Renamer.hs b/src/Renamer.hs new file mode 100644 index 0000000..0b2d41e --- /dev/null +++ b/src/Renamer.hs @@ -0,0 +1,83 @@ +{-# LANGUAGE LambdaCase #-} + +module Renamer (module Renamer) where + +import Data.List (mapAccumL, unzip4, zipWith4) +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Maybe (fromMaybe) +import Grammar.Abs + + +-- | Rename all supercombinators and variables +rename :: Program -> Program +rename (Program sc) = Program $ map (renameSc 0) sc + where + renameSc i (Bind n t _ xs e) = Bind n t n xs' e' + where + (i1, xs', env) = newNames i xs + e' = snd $ renameExp env i1 e + +renameExp :: Map Ident Ident -> Int -> Exp -> (Int, Exp) +renameExp env i = \case + + EId n -> (i, EId . fromMaybe n $ Map.lookup n env) + + EInt i1 -> (i, EInt i1) + + EApp e1 e2 -> (i2, EApp e1' e2') + where + (i1, e1') = renameExp env i e1 + (i2, e2') = renameExp env i1 e2 + + EAdd e1 e2 -> (i2, EAdd e1' e2') + where + (i1, e1') = renameExp env i e1 + (i2, e2') = renameExp env i1 e2 + + ELet bs e -> (i3, ELet (zipWith4 mkBind names' types pars' es') e') + where + mkBind name t = Bind name t name + (i1, e') = renameExp e_env i e + (names, types, pars, rhss) = fromBinders bs + (i2, names', env') = newNames i1 (names ++ concat pars) + pars' = (map . map) renamePar pars + e_env = Map.union env' env + (i3, es') = mapAccumL (renameExp e_env) i2 rhss + + renamePar p = case Map.lookup p env' of + Just p' -> p' + Nothing -> error ("Can't find name for " ++ show p) + + + EAbs par t e -> (i2, EAbs par' t e') + where + (i1, par', env') = newName par + (i2, e') = renameExp (Map.union env' env ) i1 e + + EAnn e t -> (i1, EAnn e' t) + where + (i1, e') = renameExp env i e + + +newName :: Ident -> (Int, Ident, Map Ident Ident) +newName old_name = (i, head names, env) + where (i, names, env) = newNames 1 [old_name] + +newNames :: Int -> [Ident] -> (Int, [Ident], Map Ident Ident) +newNames i old_names = (i', new_names, env) + where + (i', new_names) = getNames i old_names + env = Map.fromList $ zip old_names new_names + +getNames :: Int -> [Ident] -> (Int, [Ident]) +getNames i ns = (i + length ss, zipWith makeName ss [i..]) + where + ss = map (\(Ident s) -> s) ns + +makeName :: String -> Int -> Ident +makeName prefix i = Ident (prefix ++ "_" ++ show i) + + +fromBinders :: [Bind] -> ([Ident], [Type], [[Ident]], [Exp]) +fromBinders bs = unzip4 [ (name, t, parms, rhs) | Bind name t _ parms rhs <- bs ] diff --git a/src/TypeChecker.hs b/src/TypeChecker.hs index e69de29..059b375 100644 --- a/src/TypeChecker.hs +++ b/src/TypeChecker.hs @@ -0,0 +1,180 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedRecordDot #-} + +module TypeChecker (typecheck) where + +import Auxiliary (maybeToRightM, snoc) +import Control.Monad.Except (throwError, unless) +import Data.Map (Map) +import qualified Data.Map as Map +import Grammar.Abs +import Grammar.ErrM (Err) +import Grammar.Print (Print (prt), concatD, doc, printTree, + render) +import Prelude hiding (exp, id) +import qualified TypeCheckerIr as T + + +-- NOTE: this type checker is poorly tested + +-- TODO +-- Coercion +-- Type inference + +data Cxt = Cxt + { env :: Map Ident Type + , sig :: Map Ident Type + } + +initCxt :: [Bind] -> Cxt +initCxt sc = Cxt { env = mempty + , sig = Map.fromList $ map (\(Bind n t _ _ _) -> (n, t)) sc + } + +typecheck :: Program -> Err T.Program +typecheck (Program sc) = T.Program <$> mapM (checkBind $ initCxt sc) sc + + +checkBind :: Cxt -> Bind -> Err T.Bind +checkBind cxt b = + case expandLambdas b of + Bind name t _ parms rhs -> do + (rhs', t_rhs) <- infer cxt rhs + + unless (typeEq t_rhs t) . throwError $ typeErr name t t_rhs + + pure $ T.Bind (name, t) (zip parms ts_parms) rhs' + + where + ts_parms = fst $ partitionType (length parms) t + +expandLambdas :: Bind -> Bind +expandLambdas (Bind name t _ parms rhs) = Bind name t name [] rhs' + where + rhs' = foldr ($) rhs $ zipWith EAbs parms ts_parms + ts_parms = fst $ partitionType (length parms) t + + +infer :: Cxt -> Exp -> Err (T.Exp, Type) +infer cxt = \case + + EId x -> + case lookupEnv x cxt of + Nothing -> + case lookupSig x cxt of + Nothing -> throwError ("Unbound variable:" ++ printTree x) + Just t -> pure (T.EId (x, t), t) + Just t -> pure (T.EId (x, t), t) + + EInt i -> pure (T.EInt i, T.TInt) + + EApp e e1 -> do + (e', t) <- infer cxt e + case t of + TFun t1 t2 -> do + e1' <- check cxt e1 t1 + pure (T.EApp t2 e' e1', t2) + _ -> do + throwError ("Not a function: " ++ show e) + + EAdd e e1 -> do + e' <- check cxt e T.TInt + e1' <- check cxt e1 T.TInt + pure (T.EAdd T.TInt e' e1', T.TInt) + + EAbs x t e -> do + (e', t1) <- infer (insertEnv x t cxt) e + let t_abs = TFun t t1 + pure (T.EAbs t_abs (x, t) e', t_abs) + + ELet bs e -> do + bs'' <- mapM (checkBind cxt') bs' + (e', t) <- infer cxt' e + pure (T.ELet bs'' e', t) + where + bs' = map expandLambdas bs + cxt' = foldr (\(Bind n t _ _ _) -> insertEnv n t) cxt bs' + + EAnn e t -> do + e' <- check cxt e t + pure (T.EAnn e' t, t) + + +check :: Cxt -> Exp -> Type -> Err T.Exp +check cxt exp typ = case exp of + + EId x -> do + t <- case lookupEnv x cxt of + Nothing -> maybeToRightM + ("Unbound variable:" ++ printTree x) + (lookupSig x cxt) + Just t -> pure t + + unless (typeEq t typ) . throwError $ typeErr x typ t + + pure $ T.EId (x, t) + + EInt i -> do + unless (typeEq typ TInt) $ throwError $ typeErr i TInt typ + pure $ T.EInt i + + EApp e e1 -> do + (e', t) <- infer cxt e + case t of + TFun t1 t2 -> do + e1' <- check cxt e1 t1 + pure $ T.EApp t2 e' e1' + _ -> throwError ("Not a function 2: " ++ printTree e) + + EAdd e e1 -> do + e' <- check cxt e T.TInt + e1' <- check cxt e1 T.TInt + pure $ T.EAdd T.TInt e' e1' + + EAbs x t e -> do + (e', t_e) <- infer (insertEnv x t cxt) e + let t1 = TFun t t_e + unless (typeEq t1 typ) $ throwError "Wrong lamda type!" + pure $ T.EAbs t1 (x, t) e' + + ELet bs e -> do + bs'' <- mapM (checkBind cxt') bs' + e' <- check cxt' e typ + pure $ T.ELet bs'' e' + where + bs' = map expandLambdas bs + cxt' = foldr (\(Bind n t _ _ _) -> insertEnv n t) cxt bs' + + EAnn e t -> do + unless (typeEq t typ) $ + throwError "Inferred type and type annotation doesn't match" + e' <- check cxt e t + pure $ T.EAnn e' typ + +typeEq :: Type -> Type -> Bool +typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1 +typeEq t t1 = t == t1 + +partitionType :: Int -> Type -> ([Type], Type) +partitionType = go [] + where + go acc 0 t = (acc, t) + go acc i t = case t of + TFun t1 t2 -> go (snoc t1 acc) (i - 1) t2 + _ -> error "Number of parameters and type doesn't match" + +lookupEnv :: Ident -> Cxt -> Maybe Type +lookupEnv x = Map.lookup x . env + +insertEnv :: Ident -> Type -> Cxt -> Cxt +insertEnv x t cxt = cxt { env = Map.insert x t cxt.env } + +lookupSig :: Ident -> Cxt -> Maybe Type +lookupSig x = Map.lookup x . sig + +typeErr :: Print a => a -> Type -> Type -> String +typeErr p expected actual = render $ concatD + [ doc $ showString "Wrong type:", prt 0 p , doc $ showString "\n" + , doc $ showString "Expected:" , prt 0 expected, doc $ showString "\n" + , doc $ showString "Actual: " , prt 0 actual + ] diff --git a/src/TypeCheckerIr.hs b/src/TypeCheckerIr.hs new file mode 100644 index 0000000..c8371c5 --- /dev/null +++ b/src/TypeCheckerIr.hs @@ -0,0 +1,108 @@ +{-# LANGUAGE LambdaCase #-} + +module TypeCheckerIr + ( module Grammar.Abs + , module TypeCheckerIr + ) where + +import Grammar.Abs (Ident (..), Type (..)) +import Grammar.Print +import Prelude +import qualified Prelude as C (Eq, Ord, Read, Show) + +newtype Program = Program [Bind] + deriving (C.Eq, C.Ord, C.Show, C.Read) + +data Exp + = EId Id + | EInt Integer + | ELet [Bind] Exp + | EApp Type Exp Exp + | EAdd Type Exp Exp + | EAbs Type Id Exp + | EAnn Exp Type + deriving (C.Eq, C.Ord, C.Show, C.Read) + +type Id = (Ident, Type) + +data Bind = Bind Id [Id] Exp + deriving (C.Eq, C.Ord, C.Show, C.Read) + +instance Print Program where + prt i (Program sc) = prPrec i 0 $ prt 0 sc + +instance Print Bind where + prt i (Bind name@(n, _) parms rhs) = prPrec i 0 $ concatD + [ prtId 0 name + , doc $ showString ";" + , prt 0 n + , prtIdPs 0 parms + , doc $ showString "=" + , prt 0 rhs + ] + +instance Print [Bind] where + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] + +prtIdPs :: Int -> [Id] -> Doc +prtIdPs i = prPrec i 0 . concatD . map (prtIdP i) + +prtId :: Int -> Id -> Doc +prtId i (name, t) = prPrec i 0 $ concatD + [ prt 0 name + , doc $ showString ":" + , prt 0 t + ] + +prtIdP :: Int -> Id -> Doc +prtIdP i (name, t) = prPrec i 0 $ concatD + [ doc $ showString "(" + , prt 0 name + , doc $ showString ":" + , prt 0 t + , doc $ showString ")" + ] + + +instance Print Exp where + prt i = \case + EId n -> prPrec i 3 $ concatD [prtIdP 0 n] + EInt i1 -> prPrec i 3 $ concatD [prt 0 i1] + ELet bs e -> prPrec i 3 $ concatD + [ doc $ showString "let" + , prt 0 bs + , doc $ showString "in" + , prt 0 e + ] + EApp t e1 e2 -> prPrec i 2 $ concatD + [ doc $ showString "@" + , prt 0 t + , prt 2 e1 + , prt 3 e2 + ] + EAdd t e1 e2 -> prPrec i 1 $ concatD + [ doc $ showString "@" + , prt 0 t + , prt 1 e1 + , doc $ showString "+" + , prt 2 e2 + ] + EAbs t n e -> prPrec i 0 $ concatD + [ doc $ showString "@" + , prt 0 t + , doc $ showString "\\" + , prtIdP 0 n + , doc $ showString "." + , prt 0 e + ] + EAnn e t -> prPrec i 3 $ concatD + [ doc $ showString "(" + , prt 0 e + , doc $ showString ":" + , prt 0 t + , doc $ showString ")" + ] + +