From a3e57dde7b87509a7fe08ecb5f499ebbd9ec55dd Mon Sep 17 00:00:00 2001 From: Martin Fredin Date: Sat, 18 Feb 2023 12:57:23 +0100 Subject: [PATCH 1/6] Change grammar: only one bind in let and no EAnn for typed syntax --- Grammar.cf | 22 +++---- sample-programs/basic-1 | 23 +++---- src/Compiler.hs | 69 ++++++++++----------- src/LambdaLifter.hs | 113 ++++++++++++---------------------- src/Renamer.hs | 130 ++++++++++++++++++++-------------------- src/TypeChecker.hs | 33 +++++----- src/TypeCheckerIr.hs | 10 +--- 7 files changed, 172 insertions(+), 228 deletions(-) diff --git a/Grammar.cf b/Grammar.cf index 52c2353..0f5a411 100644 --- a/Grammar.cf +++ b/Grammar.cf @@ -1,24 +1,20 @@ Program. Program ::= [Bind]; - -EId. Exp3 ::= Ident; -EInt. Exp3 ::= Integer; -ELet. Exp3 ::= "let" [Bind] "in" Exp; -EApp. Exp2 ::= Exp2 Exp3; -EAdd. Exp1 ::= Exp1 "+" Exp2; -EAbs. Exp ::= "\\" Ident ":" Type "." Exp; -EAnn. Exp3 ::= "(" Exp ":" Type ")"; - +EId. Exp3 ::= Ident; +EInt. Exp3 ::= Integer; +EAnn. Exp3 ::= "(" Exp ":" Type ")"; +ELet. Exp3 ::= "let" Bind "in" Exp; +EApp. Exp2 ::= Exp2 Exp3; +EAdd. Exp1 ::= Exp1 "+" Exp2; +EAbs. Exp ::= "\\" Ident ":" Type "." Exp; ECase. Exp ::= "case" Exp "of" "{" [CaseMatch] "}"; --- CaseMatch. CaseMatch ::= Case "=>" Exp ; separator CaseMatch ","; ---terminator CaseMatch "."; CInt. Case ::= Integer ; Bind. Bind ::= Ident ":" Type ";" - Ident [Ident] "=" Exp ; + Ident [Ident] "=" Exp; separator Bind ";"; separator Ident ""; @@ -31,4 +27,4 @@ TFun. Type ::= Type1 "->" Type ; coercions Type 1 ; comment "--"; -comment "{-" "-}"; \ No newline at end of file +comment "{-" "-}"; diff --git a/sample-programs/basic-1 b/sample-programs/basic-1 index 5f0d670..f0cdcc4 100644 --- a/sample-programs/basic-1 +++ b/sample-programs/basic-1 @@ -1,20 +1,21 @@ ---tripplemagic : Int -> Int -> Int -> Int; ---tripplemagic x y z = ((\x:Int. x+x) x) + y + z; ---main : Int; ---main = tripplemagic ((\x:Int. x+x+3) ((\x:Int. x) 2)) 5 3 - ---apply : (Int -> Int) -> Int -> Int; ---apply f x = f x; --- ---main : Int; ---main = (\x : Int . x + 5) 5 +-- tripplemagic : Int -> Int -> Int -> Int; +-- tripplemagic x y z = ((\x:Int. x+x) x) + y + z; +-- main : Int; +-- main = tripplemagic ((\x:Int. x+x+3) ((\x:Int. x) 2)) 5 3 +-- answer: 22 +-- apply : (Int -> Int) -> Int -> Int; +-- apply f x = f x; +-- main : Int; +-- main = apply (\x : Int . x + 5) 5 +-- answer: 10 apply : (Int -> Int -> Int) -> Int -> Int -> Int; apply f x y = f x y; krimp: Int -> Int -> Int; krimp x y = x + y; main : Int; -main = apply (krimp) 2 3;--apply (\y: Int . (\x: Int . x + y + 2)) 5 2; +main = apply (krimp) 2 3; +-- answer: 5 diff --git a/src/Compiler.hs b/src/Compiler.hs index 0820523..8cbeb58 100644 --- a/src/Compiler.hs +++ b/src/Compiler.hs @@ -1,28 +1,24 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} module Compiler (compile) where -import Control.Monad.State (StateT, execStateT, gets, modify) -import Data.Map (Map) -import qualified Data.Map as Map -import Data.Tuple.Extra (second) -import Grammar.ErrM (Err) -import Grammar.Print (printTree) -import LlvmIr ( - LLVMIr (..), - LLVMType (..), - LLVMValue (..), - Visibility (..), - llvmIrToString, - ) -import TypeChecker (partitionType) -import TypeCheckerIr +import Control.Monad.State (StateT, execStateT, gets, modify) +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Tuple.Extra (second) +import Grammar.ErrM (Err) +import Grammar.Print (printTree) +import LlvmIr (LLVMIr (..), LLVMType (..), + LLVMValue (..), Visibility (..), + llvmIrToString) +import TypeChecker (partitionType) +import TypeCheckerIr -- | The record used as the code generator state data CodeGenerator = CodeGenerator - { instructions :: [LLVMIr] - , functions :: Map Id FunctionInfo + { instructions :: [LLVMIr] + , functions :: Map Id FunctionInfo , variableCount :: Integer } @@ -30,7 +26,7 @@ data CodeGenerator = CodeGenerator type CompilerState a = StateT CodeGenerator Err a data FunctionInfo = FunctionInfo - { numArgs :: Int + { numArgs :: Int , arguments :: [Id] } @@ -124,33 +120,29 @@ compile (Program prg) = do t_return = snd $ partitionType (length args) t go :: Exp -> CompilerState () - go (EInt int) = emitInt int - go (EAdd t e1 e2) = emitAdd t e1 e2 + go (EInt int) = emitInt int + go (EAdd t e1 e2) = emitAdd t e1 e2 go (EId (name, _)) = emitIdent name - go (EApp t e1 e2) = emitApp t e1 e2 - go (EAbs t ti e) = emitAbs t ti e - go (ELet binds e) = emitLet binds e - go (EAnn _ _) = emitEAnn + go (EApp t e1 e2) = emitApp t e1 e2 + go (EAbs t ti e) = emitAbs t ti e + go (ELet bind e) = emitLet bind e -- go (ESub e1 e2) = emitSub e1 e2 -- go (EMul e1 e2) = emitMul e1 e2 -- go (EDiv e1 e2) = emitDiv e1 e2 -- go (EMod e1 e2) = emitMod e1 e2 --- aux functions --- - emitEAnn :: CompilerState () - emitEAnn = emit . UnsafeRaw $ "why?" - emitAbs :: Type -> Id -> Exp -> CompilerState () emitAbs _t tid e = do emit . Comment $ "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e - emitLet :: [Bind] -> Exp -> CompilerState () - emitLet xs e = do + emitLet :: Bind -> Exp -> CompilerState () + emitLet b e = do emit $ Comment $ concat [ "ELet (" - , show xs + , show b , " = " , show e , ") is not implemented!" @@ -170,7 +162,7 @@ compile (Program prg) = do funcs <- gets functions let vis = case Map.lookup id funcs of Nothing -> Local - Just _ -> Global + Just _ -> Global let call = Call (type2LlvmType t) vis name ((\x -> (valueGetType x, x)) <$> args) emit $ SetVariable (Ident $ show vs) call x -> do @@ -271,19 +263,18 @@ type2LlvmType = \case where function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType]) function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s) - function2LLVMType x s = (type2LlvmType x, s) + function2LLVMType x s = (type2LlvmType x, s) getType :: Exp -> LLVMType -getType (EInt _) = I64 +getType (EInt _) = I64 getType (EAdd t _ _) = type2LlvmType t getType (EId (_, t)) = type2LlvmType t getType (EApp t _ _) = type2LlvmType t getType (EAbs t _ _) = type2LlvmType t -getType (ELet _ e) = getType e -getType (EAnn _ t) = type2LlvmType t +getType (ELet _ e) = getType e valueGetType :: LLVMValue -> LLVMType -valueGetType (VInteger _) = I64 -valueGetType (VIdent _ t) = t -valueGetType (VConstant s) = Array (length s) I8 +valueGetType (VInteger _) = I64 +valueGetType (VIdent _ t) = t +valueGetType (VConstant s) = Array (length s) I8 valueGetType (VFunction _ _ t) = t diff --git a/src/LambdaLifter.hs b/src/LambdaLifter.hs index 51a82e6..44852ec 100644 --- a/src/LambdaLifter.hs +++ b/src/LambdaLifter.hs @@ -7,12 +7,10 @@ module LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where import Auxiliary (snoc) import Control.Applicative (Applicative (liftA2)) import Control.Monad.State (MonadState (get, put), State, evalState) -import Data.Foldable.Extra (notNull) -import Data.List (mapAccumL, partition) -import Data.Set (Set, (\\)) +import Data.Set (Set) import qualified Data.Set as Set import Prelude hiding (exp) -import Renamer hiding (fromBinders) +import Renamer import TypeCheckerIr @@ -49,35 +47,22 @@ freeVarsExp localVars = \case where e' = freeVarsExp (Set.insert par localVars) e - -- Sum free variables present in binders and the expression - ELet binders e -> (Set.union binders_frees e_free, ALet binders' e') + -- Sum free variables present in bind and the expression + ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e') where - binders_frees = rhss_frees \\ names_set - e_free = freeVarsOf e' \\ names_set + binders_frees = Set.delete name $ freeVarsOf rhs' + e_free = Set.delete name $ freeVarsOf e' - rhss_frees = foldr1 Set.union (map freeVarsOf rhss') - names_set = Set.fromList names + rhs' = freeVarsExp e_localVars rhs + new_bind = ABind name parms rhs' - (names, parms, rhss) = fromBinders binders - rhss' = map (freeVarsExp e_localVars) rhss - e_localVars = Set.union localVars names_set - - binders' = zipWith3 ABind names parms rhss' - e' = freeVarsExp e_localVars e - - EAnn e t -> (freeVarsOf e', AAnn e' t) - where - e' = freeVarsExp localVars e + e' = freeVarsExp e_localVars e + e_localVars = Set.insert name localVars freeVarsOf :: AnnExp -> Set Id freeVarsOf = fst - -fromBinders :: [Bind] -> ([Id], [[Id]], [Exp]) -fromBinders bs = unzip3 [ (name, parms, rhs) | Bind name parms rhs <- bs ] - - -- AST annotated with free variables type AnnProgram = [(Id, [Id], AnnExp)] @@ -87,14 +72,11 @@ data ABind = ABind Id [Id] AnnExp deriving Show data AnnExp' = AId Id | AInt Integer - | ALet [ABind] AnnExp + | ALet ABind AnnExp | AApp Type AnnExp AnnExp | AAdd Type AnnExp AnnExp | AAbs Type Id AnnExp - | AAnn AnnExp Type deriving Show - - -- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@. -- Free variables are @v₁ v₂ .. vₙ@ are bound. abstract :: AnnProgram -> Program @@ -124,7 +106,7 @@ abstractExp (free, exp) = case exp of AInt i -> pure $ EInt i AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2) AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2) - ALet bs e -> liftA2 ELet (mapM go bs) (abstractExp e) + ALet b e -> liftA2 ELet (go b) (abstractExp e) where go (ABind name parms rhs) = do (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs @@ -141,14 +123,13 @@ abstractExp (free, exp) = case exp of rhs <- abstractExp e let sc_name = Ident ("sc_" ++ show i) - sc = ELet [Bind (sc_name, t) parms rhs] $ EId (sc_name, t) + sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t) pure $ foldl (EApp TInt) sc $ map EId freeList where freeList = Set.toList free parms = snoc parm freeList - AAnn e t -> abstractExp e >>= \e' -> pure $ EAnn e' t nextNumber :: State Int Int nextNumber = do @@ -156,7 +137,7 @@ nextNumber = do put $ succ i pure i --- | Collects supercombinators by lifting appropriate let expressions +-- | Collects supercombinators by lifting non-constant let expressions collectScs :: Program -> Program collectScs (Program scs) = Program $ concatMap collectFromRhs scs where @@ -167,49 +148,35 @@ collectScs (Program scs) = Program $ concatMap collectFromRhs scs collectScsExp :: Exp -> ([Bind], Exp) collectScsExp = \case - EId n -> ([], EId n) - EInt i -> ([], EInt i) + EId n -> ([], EId n) + EInt i -> ([], EInt i) - EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2') - where - (scs1, e1') = collectScsExp e1 - (scs2, e2') = collectScsExp e2 + EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2') + where + (scs1, e1') = collectScsExp e1 + (scs2, e2') = collectScsExp e2 - EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2') - where - (scs1, e1') = collectScsExp e1 - (scs2, e2') = collectScsExp e2 + EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2') + where + (scs1, e1') = collectScsExp e1 + (scs2, e2') = collectScsExp e2 - EAbs t par e -> (scs, EAbs t par e') - where - (scs, e') = collectScsExp e + EAbs t par e -> (scs, EAbs t par e') + where + (scs, e') = collectScsExp e - -- Collect supercombinators from binds, the rhss, and the expression. - -- - -- > f = let - -- > sc = rhs - -- > sc1 = rhs1 - -- > ... - -- > in e - -- - ELet binds e -> (binds_scs ++ rhss_scs ++ e_scs, mkEAbs non_scs' e') - where - binds_scs = [ let (rhs', parms1) = flattenLambdas rhs in - Bind n (parms ++ parms1) rhs' - | Bind n parms rhs <- scs' - ] - (rhss_scs, binds') = mapAccumL collectScsRhs [] binds - (e_scs, e') = collectScsExp e + -- Collect supercombinators from bind, the rhss, and the expression. + -- + -- > f = let sc x y = rhs in e + -- + ELet (Bind name parms rhs) e -> if null parms + then ( rhs_scs ++ e_scs, ELet bind e') + else (bind : rhs_scs ++ e_scs, e') + where + bind = Bind name parms rhs' + (rhs_scs, rhs') = collectScsExp rhs + (e_scs, e') = collectScsExp e - (scs', non_scs') = partition (\(Bind _ pars _) -> notNull pars) binds' - - collectScsRhs acc (Bind n xs rhs) = (acc ++ rhs_scs, Bind n xs rhs') - where - (rhs_scs, rhs') = collectScsExp rhs - - EAnn e t -> (scs, EAnn e' t) - where - (scs, e') = collectScsExp e -- @\x.\y.\z. e → (e, [x,y,z])@ flattenLambdas :: Exp -> (Exp, [Id]) @@ -218,7 +185,3 @@ flattenLambdas = go . (, []) go (e, acc) = case e of EAbs _ par e1 -> go (e1, snoc par acc) _ -> (e, acc) - -mkEAbs :: [Bind] -> Exp -> Exp -mkEAbs [] e = e -mkEAbs bs e = ELet bs e diff --git a/src/Renamer.hs b/src/Renamer.hs index 0b2d41e..2730883 100644 --- a/src/Renamer.hs +++ b/src/Renamer.hs @@ -2,82 +2,84 @@ module Renamer (module Renamer) where -import Data.List (mapAccumL, unzip4, zipWith4) -import Data.Map (Map) -import qualified Data.Map as Map -import Data.Maybe (fromMaybe) +import Auxiliary (mapAccumM) +import Control.Monad.State (MonadState, State, evalState, gets, + modify) +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Maybe (fromMaybe) +import Data.Tuple.Extra (dupe) import Grammar.Abs --- | Rename all supercombinators and variables +-- | Rename all variables and local binds rename :: Program -> Program -rename (Program sc) = Program $ map (renameSc 0) sc +rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0 where - renameSc i (Bind n t _ xs e) = Bind n t n xs' e' - where - (i1, xs', env) = newNames i xs - e' = snd $ renameExp env i1 e - -renameExp :: Map Ident Ident -> Int -> Exp -> (Int, Exp) -renameExp env i = \case - - EId n -> (i, EId . fromMaybe n $ Map.lookup n env) - - EInt i1 -> (i, EInt i1) - - EApp e1 e2 -> (i2, EApp e1' e2') - where - (i1, e1') = renameExp env i e1 - (i2, e2') = renameExp env i1 e2 - - EAdd e1 e2 -> (i2, EAdd e1' e2') - where - (i1, e1') = renameExp env i e1 - (i2, e2') = renameExp env i1 e2 - - ELet bs e -> (i3, ELet (zipWith4 mkBind names' types pars' es') e') - where - mkBind name t = Bind name t name - (i1, e') = renameExp e_env i e - (names, types, pars, rhss) = fromBinders bs - (i2, names', env') = newNames i1 (names ++ concat pars) - pars' = (map . map) renamePar pars - e_env = Map.union env' env - (i3, es') = mapAccumL (renameExp e_env) i2 rhss - - renamePar p = case Map.lookup p env' of - Just p' -> p' - Nothing -> error ("Can't find name for " ++ show p) + initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs + renameSc :: Names -> Bind -> Rn Bind + renameSc old_names (Bind name t _ parms rhs) = do + (new_names, parms') <- newNames old_names parms + rhs' <- snd <$> renameExp new_names rhs + pure $ Bind name t name parms' rhs' - EAbs par t e -> (i2, EAbs par' t e') - where - (i1, par', env') = newName par - (i2, e') = renameExp (Map.union env' env ) i1 e +-- | Rename monad. State holds the number of renamed names. +newtype Rn a = Rn { runRn :: State Int a } + deriving (Functor, Applicative, Monad, MonadState Int) - EAnn e t -> (i1, EAnn e' t) - where - (i1, e') = renameExp env i e +-- | Maps old to new name +type Names = Map Ident Ident +renameLocalBind :: Names -> Bind -> Rn (Names, Bind) +renameLocalBind old_names (Bind name t _ parms rhs) = do + (new_names, name') <- newName old_names name + (new_names', parms') <- newNames new_names parms + (new_names'', rhs') <- renameExp new_names' rhs + pure (new_names'', Bind name' t name' parms' rhs') -newName :: Ident -> (Int, Ident, Map Ident Ident) -newName old_name = (i, head names, env) - where (i, names, env) = newNames 1 [old_name] +renameExp :: Names -> Exp -> Rn (Names, Exp) +renameExp old_names = \case -newNames :: Int -> [Ident] -> (Int, [Ident], Map Ident Ident) -newNames i old_names = (i', new_names, env) - where - (i', new_names) = getNames i old_names - env = Map.fromList $ zip old_names new_names + EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names) -getNames :: Int -> [Ident] -> (Int, [Ident]) -getNames i ns = (i + length ss, zipWith makeName ss [i..]) - where - ss = map (\(Ident s) -> s) ns + EInt i1 -> pure (old_names, EInt i1) -makeName :: String -> Int -> Ident -makeName prefix i = Ident (prefix ++ "_" ++ show i) + EApp e1 e2 -> do + (env1, e1') <- renameExp old_names e1 + (env2, e2') <- renameExp old_names e2 + pure (Map.union env1 env2, EApp e1' e2') + EAdd e1 e2 -> do + (env1, e1') <- renameExp old_names e1 + (env2, e2') <- renameExp old_names e2 + pure (Map.union env1 env2, EAdd e1' e2') + + ELet b e -> do + (new_names, b) <- renameLocalBind old_names b + (new_names', e') <- renameExp new_names e + pure (new_names', ELet b e') + + EAbs par t e -> do + (new_names, par') <- newName old_names par + (new_names', e') <- renameExp new_names e + pure (new_names', EAbs par' t e') + + EAnn e t -> do + (new_names, e') <- renameExp old_names e + pure (new_names, EAnn e' t) + +-- | Create a new name and add it to name environment. +newName :: Names -> Ident -> Rn (Names, Ident) +newName env old_name = do + new_name <- makeName old_name + pure (Map.insert old_name new_name env, new_name) + +-- | Create multiple names and add them to the name environment +newNames :: Names -> [Ident] -> Rn (Names, [Ident]) +newNames = mapAccumM newName + +-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@. +makeName :: Ident -> Rn Ident +makeName (Ident prefix) = gets (\i -> Ident $ prefix ++ "_" ++ show i) <* modify succ -fromBinders :: [Bind] -> ([Ident], [Type], [[Ident]], [Exp]) -fromBinders bs = unzip4 [ (name, t, parms, rhs) | Bind name t _ parms rhs <- bs ] diff --git a/src/TypeChecker.hs b/src/TypeChecker.hs index f74c237..b30a360 100644 --- a/src/TypeChecker.hs +++ b/src/TypeChecker.hs @@ -87,18 +87,17 @@ infer cxt = \case let t_abs = TFun t t1 pure (T.EAbs t_abs (x, t) e', t_abs) - ELet bs e -> do - bs'' <- mapM (checkBind cxt') bs' + ELet b e -> do + let cxt' = insertBind b cxt + b' <- checkBind cxt' b (e', t) <- infer cxt' e - pure (T.ELet bs'' e', t) - where - bs' = map expandLambdas bs - cxt' = foldr (\(Bind n t _ _ _) -> insertEnv n t) cxt bs' + pure (T.ELet b' e', t) EAnn e t -> do - e' <- check cxt e t - pure (T.EAnn e' t, t) - + (e', t1) <- infer cxt e + unless (typeEq t t1) $ + throwError "Inferred type and type annotation doesn't match" + pure (e', t1) check :: Cxt -> Exp -> Type -> Err T.Exp check cxt exp typ = case exp of @@ -137,19 +136,19 @@ check cxt exp typ = case exp of unless (typeEq t1 typ) $ throwError "Wrong lamda type!" pure $ T.EAbs t1 (x, t) e' - ELet bs e -> do - bs'' <- mapM (checkBind cxt') bs' + ELet b e -> do + let cxt' = insertBind b cxt + b' <- checkBind cxt' b e' <- check cxt' e typ - pure $ T.ELet bs'' e' - where - bs' = map expandLambdas bs - cxt' = foldr (\(Bind n t _ _ _) -> insertEnv n t) cxt bs' + pure $ T.ELet b' e' EAnn e t -> do unless (typeEq t typ) $ throwError "Inferred type and type annotation doesn't match" - e' <- check cxt e t - pure $ T.EAnn e' typ + check cxt e t + +insertBind :: Bind -> Cxt -> Cxt +insertBind (Bind n t _ _ _) = insertEnv n t typeEq :: Type -> Type -> Bool typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1 diff --git a/src/TypeCheckerIr.hs b/src/TypeCheckerIr.hs index c8371c5..35b3712 100644 --- a/src/TypeCheckerIr.hs +++ b/src/TypeCheckerIr.hs @@ -16,11 +16,10 @@ newtype Program = Program [Bind] data Exp = EId Id | EInt Integer - | ELet [Bind] Exp + | ELet Bind Exp | EApp Type Exp Exp | EAdd Type Exp Exp | EAbs Type Id Exp - | EAnn Exp Type deriving (C.Eq, C.Ord, C.Show, C.Read) type Id = (Ident, Type) @@ -97,12 +96,5 @@ instance Print Exp where , doc $ showString "." , prt 0 e ] - EAnn e t -> prPrec i 3 $ concatD - [ doc $ showString "(" - , prt 0 e - , doc $ showString ":" - , prt 0 t - , doc $ showString ")" - ] From ad615cc9d89ce1e4ed66430c3a4084bcf4945df7 Mon Sep 17 00:00:00 2001 From: Martin Fredin Date: Sat, 18 Feb 2023 13:26:41 +0100 Subject: [PATCH 2/6] Document and fix code style --- src/TypeChecker.hs | 197 ++++++++++++++++++++++----------------------- 1 file changed, 98 insertions(+), 99 deletions(-) diff --git a/src/TypeChecker.hs b/src/TypeChecker.hs index b30a360..380b009 100644 --- a/src/TypeChecker.hs +++ b/src/TypeChecker.hs @@ -14,7 +14,6 @@ import Grammar.Print (Print (prt), concatD, doc, printTree, import Prelude hiding (exp, id) import qualified TypeCheckerIr as T - -- NOTE: this type checker is poorly tested -- TODO @@ -22,9 +21,9 @@ import qualified TypeCheckerIr as T -- Type inference data Cxt = Cxt - { env :: Map Ident Type - , sig :: Map Ident Type - } + { env :: Map Ident Type -- ^ Local scope signature + , sig :: Map Ident Type -- ^ Top-level signatures + } initCxt :: [Bind] -> Cxt initCxt sc = Cxt { env = mempty @@ -34,133 +33,133 @@ initCxt sc = Cxt { env = mempty typecheck :: Program -> Err T.Program typecheck (Program sc) = T.Program <$> mapM (checkBind $ initCxt sc) sc - +-- | Check if infered rhs type matches type signature. checkBind :: Cxt -> Bind -> Err T.Bind checkBind cxt b = - case expandLambdas b of - Bind name t _ parms rhs -> do - (rhs', t_rhs) <- infer cxt rhs - - unless (typeEq t_rhs t) . throwError $ typeErr name t t_rhs - - pure $ T.Bind (name, t) (zip parms ts_parms) rhs' - - where - ts_parms = fst $ partitionType (length parms) t + case expandLambdas b of + Bind name t _ parms rhs -> do + (rhs', t_rhs) <- infer cxt rhs + unless (typeEq t_rhs t) . throwError $ typeErr name t t_rhs + pure $ T.Bind (name, t) (zip parms ts_parms) rhs' + where + ts_parms = fst $ partitionType (length parms) t +-- | @ f x y = rhs ⇒ f = \x.\y. rhs @ expandLambdas :: Bind -> Bind expandLambdas (Bind name t _ parms rhs) = Bind name t name [] rhs' where rhs' = foldr ($) rhs $ zipWith EAbs parms ts_parms ts_parms = fst $ partitionType (length parms) t - +-- | Infer type of expression. infer :: Cxt -> Exp -> Err (T.Exp, Type) infer cxt = \case + EId x -> + case lookupEnv x cxt of + Nothing -> + case lookupSig x cxt of + Nothing -> throwError ("Unbound variable:" ++ printTree x) + Just t -> pure (T.EId (x, t), t) + Just t -> pure (T.EId (x, t), t) - EId x -> - case lookupEnv x cxt of - Nothing -> - case lookupSig x cxt of - Nothing -> throwError ("Unbound variable:" ++ printTree x) - Just t -> pure (T.EId (x, t), t) - Just t -> pure (T.EId (x, t), t) + EInt i -> pure (T.EInt i, T.TInt) - EInt i -> pure (T.EInt i, T.TInt) + EApp e e1 -> do + (e', t) <- infer cxt e + case t of + TFun t1 t2 -> do + e1' <- check cxt e1 t1 + pure (T.EApp t2 e' e1', t2) + _ -> do + throwError ("Not a function: " ++ show e) - EApp e e1 -> do - (e', t) <- infer cxt e - case t of - TFun t1 t2 -> do - e1' <- check cxt e1 t1 - pure (T.EApp t2 e' e1', t2) - _ -> do - throwError ("Not a function: " ++ show e) + EAdd e e1 -> do + e' <- check cxt e T.TInt + e1' <- check cxt e1 T.TInt + pure (T.EAdd T.TInt e' e1', T.TInt) - EAdd e e1 -> do - e' <- check cxt e T.TInt - e1' <- check cxt e1 T.TInt - pure (T.EAdd T.TInt e' e1', T.TInt) + EAbs x t e -> do + (e', t1) <- infer (insertEnv x t cxt) e + let t_abs = TFun t t1 + pure (T.EAbs t_abs (x, t) e', t_abs) - EAbs x t e -> do - (e', t1) <- infer (insertEnv x t cxt) e - let t_abs = TFun t t1 - pure (T.EAbs t_abs (x, t) e', t_abs) + ELet b e -> do + let cxt' = insertBind b cxt + b' <- checkBind cxt' b + (e', t) <- infer cxt' e + pure (T.ELet b' e', t) - ELet b e -> do - let cxt' = insertBind b cxt - b' <- checkBind cxt' b - (e', t) <- infer cxt' e - pure (T.ELet b' e', t) - - EAnn e t -> do - (e', t1) <- infer cxt e - unless (typeEq t t1) $ - throwError "Inferred type and type annotation doesn't match" - pure (e', t1) + EAnn e t -> do + (e', t1) <- infer cxt e + unless (typeEq t t1) $ + throwError "Inferred type and type annotation doesn't match" + pure (e', t1) +-- | Check infered type matches the supplied type. check :: Cxt -> Exp -> Type -> Err T.Exp check cxt exp typ = case exp of - EId x -> do - t <- case lookupEnv x cxt of - Nothing -> maybeToRightM - ("Unbound variable:" ++ printTree x) - (lookupSig x cxt) - Just t -> pure t + EId x -> do + t <- case lookupEnv x cxt of + Nothing -> maybeToRightM + ("Unbound variable:" ++ printTree x) + (lookupSig x cxt) + Just t -> pure t + unless (typeEq t typ) . throwError $ typeErr x typ t + pure $ T.EId (x, t) - unless (typeEq t typ) . throwError $ typeErr x typ t + EInt i -> do + unless (typeEq typ TInt) $ throwError $ typeErr i TInt typ + pure $ T.EInt i - pure $ T.EId (x, t) + EApp e e1 -> do + (e', t) <- infer cxt e + case t of + TFun t1 t2 -> do + e1' <- check cxt e1 t1 + pure $ T.EApp t2 e' e1' + _ -> throwError ("Not a function 2: " ++ printTree e) - EInt i -> do - unless (typeEq typ TInt) $ throwError $ typeErr i TInt typ - pure $ T.EInt i + EAdd e e1 -> do + e' <- check cxt e T.TInt + e1' <- check cxt e1 T.TInt + pure $ T.EAdd T.TInt e' e1' - EApp e e1 -> do - (e', t) <- infer cxt e - case t of - TFun t1 t2 -> do - e1' <- check cxt e1 t1 - pure $ T.EApp t2 e' e1' - _ -> throwError ("Not a function 2: " ++ printTree e) + EAbs x t e -> do + (e', t_e) <- infer (insertEnv x t cxt) e + let t1 = TFun t t_e + unless (typeEq t1 typ) $ throwError "Wrong lamda type!" + pure $ T.EAbs t1 (x, t) e' - EAdd e e1 -> do - e' <- check cxt e T.TInt - e1' <- check cxt e1 T.TInt - pure $ T.EAdd T.TInt e' e1' + ELet b e -> do + let cxt' = insertBind b cxt + b' <- checkBind cxt' b + e' <- check cxt' e typ + pure $ T.ELet b' e' - EAbs x t e -> do - (e', t_e) <- infer (insertEnv x t cxt) e - let t1 = TFun t t_e - unless (typeEq t1 typ) $ throwError "Wrong lamda type!" - pure $ T.EAbs t1 (x, t) e' - - ELet b e -> do - let cxt' = insertBind b cxt - b' <- checkBind cxt' b - e' <- check cxt' e typ - pure $ T.ELet b' e' - - EAnn e t -> do - unless (typeEq t typ) $ - throwError "Inferred type and type annotation doesn't match" - check cxt e t - -insertBind :: Bind -> Cxt -> Cxt -insertBind (Bind n t _ _ _) = insertEnv n t + EAnn e t -> do + unless (typeEq t typ) $ + throwError "Inferred type and type annotation doesn't match" + check cxt e t +-- | Check if types are equivalent. Doesn't handle coercion or polymorphism. typeEq :: Type -> Type -> Bool typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1 typeEq t t1 = t == t1 -partitionType :: Int -> Type -> ([Type], Type) +-- | Partion type into types of parameters and return type. +partitionType :: Int -- Number of parameters to apply + -> Type + -> ([Type], Type) partitionType = go [] where go acc 0 t = (acc, t) go acc i t = case t of - TFun t1 t2 -> go (snoc t1 acc) (i - 1) t2 - _ -> error "Number of parameters and type doesn't match" + TFun t1 t2 -> go (snoc t1 acc) (i - 1) t2 + _ -> error "Number of parameters and type doesn't match" + +insertBind :: Bind -> Cxt -> Cxt +insertBind (Bind n t _ _ _) = insertEnv n t lookupEnv :: Ident -> Cxt -> Maybe Type lookupEnv x = Map.lookup x . env @@ -173,7 +172,7 @@ lookupSig x = Map.lookup x . sig typeErr :: Print a => a -> Type -> Type -> String typeErr p expected actual = render $ concatD - [ doc $ showString "Wrong type:", prt 0 p , doc $ showString "\n" - , doc $ showString "Expected:" , prt 0 expected, doc $ showString "\n" - , doc $ showString "Actual: " , prt 0 actual - ] + [ doc $ showString "Wrong type:", prt 0 p , doc $ showString "\n" + , doc $ showString "Expected:" , prt 0 expected, doc $ showString "\n" + , doc $ showString "Actual: " , prt 0 actual + ] From 21fb6bf5ed80199f5552398eb2cb73ac3defad63 Mon Sep 17 00:00:00 2001 From: Martin Fredin Date: Sat, 18 Feb 2023 13:27:58 +0100 Subject: [PATCH 3/6] Fix indentation --- src/Renamer.hs | 43 +++++++++++++++++++++---------------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/src/Renamer.hs b/src/Renamer.hs index 2730883..744d9ad 100644 --- a/src/Renamer.hs +++ b/src/Renamer.hs @@ -40,34 +40,33 @@ renameLocalBind old_names (Bind name t _ parms rhs) = do renameExp :: Names -> Exp -> Rn (Names, Exp) renameExp old_names = \case + EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names) - EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names) + EInt i1 -> pure (old_names, EInt i1) - EInt i1 -> pure (old_names, EInt i1) + EApp e1 e2 -> do + (env1, e1') <- renameExp old_names e1 + (env2, e2') <- renameExp old_names e2 + pure (Map.union env1 env2, EApp e1' e2') - EApp e1 e2 -> do - (env1, e1') <- renameExp old_names e1 - (env2, e2') <- renameExp old_names e2 - pure (Map.union env1 env2, EApp e1' e2') + EAdd e1 e2 -> do + (env1, e1') <- renameExp old_names e1 + (env2, e2') <- renameExp old_names e2 + pure (Map.union env1 env2, EAdd e1' e2') - EAdd e1 e2 -> do - (env1, e1') <- renameExp old_names e1 - (env2, e2') <- renameExp old_names e2 - pure (Map.union env1 env2, EAdd e1' e2') + ELet b e -> do + (new_names, b) <- renameLocalBind old_names b + (new_names', e') <- renameExp new_names e + pure (new_names', ELet b e') - ELet b e -> do - (new_names, b) <- renameLocalBind old_names b - (new_names', e') <- renameExp new_names e - pure (new_names', ELet b e') + EAbs par t e -> do + (new_names, par') <- newName old_names par + (new_names', e') <- renameExp new_names e + pure (new_names', EAbs par' t e') - EAbs par t e -> do - (new_names, par') <- newName old_names par - (new_names', e') <- renameExp new_names e - pure (new_names', EAbs par' t e') - - EAnn e t -> do - (new_names, e') <- renameExp old_names e - pure (new_names, EAnn e' t) + EAnn e t -> do + (new_names, e') <- renameExp old_names e + pure (new_names, EAnn e' t) -- | Create a new name and add it to name environment. newName :: Names -> Ident -> Rn (Names, Ident) From b8aedd541d8df0ddd925a58bac64c4ddeac751fd Mon Sep 17 00:00:00 2001 From: Martin Fredin Date: Sat, 18 Feb 2023 13:35:33 +0100 Subject: [PATCH 4/6] Document and fix code style --- src/LambdaLifter.hs | 127 +++++++++++++++++++++++--------------------- 1 file changed, 65 insertions(+), 62 deletions(-) diff --git a/src/LambdaLifter.hs b/src/LambdaLifter.hs index 44852ec..015e7f3 100644 --- a/src/LambdaLifter.hs +++ b/src/LambdaLifter.hs @@ -15,6 +15,10 @@ import TypeCheckerIr -- | Lift lambdas and let expression into supercombinators. +-- Three phases: +-- @freeVars@ annotatss all the free variables. +-- @abstract@ converts lambdas into let expressions. +-- @collectScs@ moves every non-constant let expression to a top-level function. lambdaLift :: Program -> Program lambdaLift = collectScs . abstract . freeVars @@ -27,37 +31,36 @@ freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e) freeVarsExp :: Set Id -> Exp -> AnnExp freeVarsExp localVars = \case + EId n | Set.member n localVars -> (Set.singleton n, AId n) + | otherwise -> (mempty, AId n) - EId n | Set.member n localVars -> (Set.singleton n, AId n) - | otherwise -> (mempty, AId n) + EInt i -> (mempty, AInt i) - EInt i -> (mempty, AInt i) + EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2') + where + e1' = freeVarsExp localVars e1 + e2' = freeVarsExp localVars e2 - EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2') - where - e1' = freeVarsExp localVars e1 - e2' = freeVarsExp localVars e2 + EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2') + where + e1' = freeVarsExp localVars e1 + e2' = freeVarsExp localVars e2 - EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2') - where - e1' = freeVarsExp localVars e1 - e2' = freeVarsExp localVars e2 + EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e') + where + e' = freeVarsExp (Set.insert par localVars) e - EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e') - where - e' = freeVarsExp (Set.insert par localVars) e + -- Sum free variables present in bind and the expression + ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e') + where + binders_frees = Set.delete name $ freeVarsOf rhs' + e_free = Set.delete name $ freeVarsOf e' - -- Sum free variables present in bind and the expression - ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e') - where - binders_frees = Set.delete name $ freeVarsOf rhs' - e_free = Set.delete name $ freeVarsOf e' + rhs' = freeVarsExp e_localVars rhs + new_bind = ABind name parms rhs' - rhs' = freeVarsExp e_localVars rhs - new_bind = ABind name parms rhs' - - e' = freeVarsExp e_localVars e - e_localVars = Set.insert name localVars + e' = freeVarsExp e_localVars e + e_localVars = Set.insert name localVars freeVarsOf :: AnnExp -> Set Id @@ -82,10 +85,10 @@ data AnnExp' = AId Id abstract :: AnnProgram -> Program abstract prog = Program $ evalState (mapM go prog) 0 where - go :: (Id, [Id], AnnExp) -> State Int Bind - go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs' - where - (rhs', parms1) = flattenLambdasAnn rhs + go :: (Id, [Id], AnnExp) -> State Int Bind + go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs' + where + (rhs', parms1) = flattenLambdasAnn rhs -- | Flatten nested lambdas and collect the parameters @@ -95,55 +98,55 @@ flattenLambdasAnn ae = go (ae, []) where go :: (AnnExp, [Id]) -> (AnnExp, [Id]) go ((free, e), acc) = - case e of - AAbs _ par (free1, e1) -> - go ((Set.delete par free1, e1), snoc par acc) - _ -> ((free, e), acc) + case e of + AAbs _ par (free1, e1) -> + go ((Set.delete par free1, e1), snoc par acc) + _ -> ((free, e), acc) abstractExp :: AnnExp -> State Int Exp abstractExp (free, exp) = case exp of - AId n -> pure $ EId n - AInt i -> pure $ EInt i - AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2) - AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2) - ALet b e -> liftA2 ELet (go b) (abstractExp e) - where - go (ABind name parms rhs) = do - (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs - pure $ Bind name (parms ++ parms1) rhs' + AId n -> pure $ EId n + AInt i -> pure $ EInt i + AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2) + AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2) + ALet b e -> liftA2 ELet (go b) (abstractExp e) + where + go (ABind name parms rhs) = do + (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs + pure $ Bind name (parms ++ parms1) rhs' - skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp - skipLambdas f (free, ae) = case ae of - AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1 - _ -> f (free, ae) + skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp + skipLambdas f (free, ae) = case ae of + AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1 + _ -> f (free, ae) - -- Lift lambda into let and bind free variables - AAbs t parm e -> do - i <- nextNumber - rhs <- abstractExp e + -- Lift lambda into let and bind free variables + AAbs t parm e -> do + i <- nextNumber + rhs <- abstractExp e - let sc_name = Ident ("sc_" ++ show i) - sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t) + let sc_name = Ident ("sc_" ++ show i) + sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t) - pure $ foldl (EApp TInt) sc $ map EId freeList - where - freeList = Set.toList free - parms = snoc parm freeList + pure $ foldl (EApp TInt) sc $ map EId freeList + where + freeList = Set.toList free + parms = snoc parm freeList nextNumber :: State Int Int nextNumber = do - i <- get - put $ succ i - pure i + i <- get + put $ succ i + pure i -- | Collects supercombinators by lifting non-constant let expressions collectScs :: Program -> Program collectScs (Program scs) = Program $ concatMap collectFromRhs scs where collectFromRhs (Bind name parms rhs) = - let (rhs_scs, rhs') = collectScsExp rhs - in Bind name parms rhs' : rhs_scs + let (rhs_scs, rhs') = collectScsExp rhs + in Bind name parms rhs' : rhs_scs collectScsExp :: Exp -> ([Bind], Exp) @@ -183,5 +186,5 @@ flattenLambdas :: Exp -> (Exp, [Id]) flattenLambdas = go . (, []) where go (e, acc) = case e of - EAbs _ par e1 -> go (e1, snoc par acc) - _ -> (e, acc) + EAbs _ par e1 -> go (e1, snoc par acc) + _ -> (e, acc) From 3efb27ac0c50924a446df98c2f398a7c8ea5cae8 Mon Sep 17 00:00:00 2001 From: Martin Fredin Date: Sat, 18 Feb 2023 13:41:38 +0100 Subject: [PATCH 5/6] Document and fix code style --- src/Renamer.hs | 2 +- src/TypeChecker.hs | 4 +- src/TypeCheckerIr.hs | 100 +++++++++++++++++++++---------------------- 3 files changed, 53 insertions(+), 53 deletions(-) diff --git a/src/Renamer.hs b/src/Renamer.hs index 744d9ad..b284e92 100644 --- a/src/Renamer.hs +++ b/src/Renamer.hs @@ -55,7 +55,7 @@ renameExp old_names = \case pure (Map.union env1 env2, EAdd e1' e2') ELet b e -> do - (new_names, b) <- renameLocalBind old_names b + (new_names, b) <- renameLocalBind old_names b (new_names', e') <- renameExp new_names e pure (new_names', ELet b e') diff --git a/src/TypeChecker.hs b/src/TypeChecker.hs index 380b009..1e44888 100644 --- a/src/TypeChecker.hs +++ b/src/TypeChecker.hs @@ -41,8 +41,8 @@ checkBind cxt b = (rhs', t_rhs) <- infer cxt rhs unless (typeEq t_rhs t) . throwError $ typeErr name t t_rhs pure $ T.Bind (name, t) (zip parms ts_parms) rhs' - where - ts_parms = fst $ partitionType (length parms) t + where + ts_parms = fst $ partitionType (length parms) t -- | @ f x y = rhs ⇒ f = \x.\y. rhs @ expandLambdas :: Bind -> Bind diff --git a/src/TypeCheckerIr.hs b/src/TypeCheckerIr.hs index 35b3712..f6e3ec6 100644 --- a/src/TypeCheckerIr.hs +++ b/src/TypeCheckerIr.hs @@ -20,25 +20,25 @@ data Exp | EApp Type Exp Exp | EAdd Type Exp Exp | EAbs Type Id Exp - deriving (C.Eq, C.Ord, C.Show, C.Read) + deriving (C.Eq, C.Ord, C.Show, C.Read) type Id = (Ident, Type) data Bind = Bind Id [Id] Exp - deriving (C.Eq, C.Ord, C.Show, C.Read) + deriving (C.Eq, C.Ord, C.Show, C.Read) instance Print Program where - prt i (Program sc) = prPrec i 0 $ prt 0 sc + prt i (Program sc) = prPrec i 0 $ prt 0 sc instance Print Bind where - prt i (Bind name@(n, _) parms rhs) = prPrec i 0 $ concatD - [ prtId 0 name - , doc $ showString ";" - , prt 0 n - , prtIdPs 0 parms - , doc $ showString "=" - , prt 0 rhs - ] + prt i (Bind name@(n, _) parms rhs) = prPrec i 0 $ concatD + [ prtId 0 name + , doc $ showString ";" + , prt 0 n + , prtIdPs 0 parms + , doc $ showString "=" + , prt 0 rhs + ] instance Print [Bind] where prt _ [] = concatD [] @@ -50,51 +50,51 @@ prtIdPs i = prPrec i 0 . concatD . map (prtIdP i) prtId :: Int -> Id -> Doc prtId i (name, t) = prPrec i 0 $ concatD - [ prt 0 name - , doc $ showString ":" - , prt 0 t - ] + [ prt 0 name + , doc $ showString ":" + , prt 0 t + ] prtIdP :: Int -> Id -> Doc prtIdP i (name, t) = prPrec i 0 $ concatD - [ doc $ showString "(" - , prt 0 name - , doc $ showString ":" - , prt 0 t - , doc $ showString ")" - ] + [ doc $ showString "(" + , prt 0 name + , doc $ showString ":" + , prt 0 t + , doc $ showString ")" + ] instance Print Exp where prt i = \case - EId n -> prPrec i 3 $ concatD [prtIdP 0 n] - EInt i1 -> prPrec i 3 $ concatD [prt 0 i1] - ELet bs e -> prPrec i 3 $ concatD - [ doc $ showString "let" - , prt 0 bs - , doc $ showString "in" - , prt 0 e - ] - EApp t e1 e2 -> prPrec i 2 $ concatD - [ doc $ showString "@" - , prt 0 t - , prt 2 e1 - , prt 3 e2 - ] - EAdd t e1 e2 -> prPrec i 1 $ concatD - [ doc $ showString "@" - , prt 0 t - , prt 1 e1 - , doc $ showString "+" - , prt 2 e2 - ] - EAbs t n e -> prPrec i 0 $ concatD - [ doc $ showString "@" - , prt 0 t - , doc $ showString "\\" - , prtIdP 0 n - , doc $ showString "." - , prt 0 e - ] + EId n -> prPrec i 3 $ concatD [prtIdP 0 n] + EInt i1 -> prPrec i 3 $ concatD [prt 0 i1] + ELet bs e -> prPrec i 3 $ concatD + [ doc $ showString "let" + , prt 0 bs + , doc $ showString "in" + , prt 0 e + ] + EApp t e1 e2 -> prPrec i 2 $ concatD + [ doc $ showString "@" + , prt 0 t + , prt 2 e1 + , prt 3 e2 + ] + EAdd t e1 e2 -> prPrec i 1 $ concatD + [ doc $ showString "@" + , prt 0 t + , prt 1 e1 + , doc $ showString "+" + , prt 2 e2 + ] + EAbs t n e -> prPrec i 0 $ concatD + [ doc $ showString "@" + , prt 0 t + , doc $ showString "\\" + , prtIdP 0 n + , doc $ showString "." + , prt 0 e + ] From 4ab6681f681df9d5671106e8c5a43e971bb9f4f2 Mon Sep 17 00:00:00 2001 From: Martin Fredin Date: Sat, 18 Feb 2023 14:36:59 +0100 Subject: [PATCH 6/6] Rearrange code --- src/Compiler.hs | 362 +++++++++++++++++++++++------------------------- src/LlvmIr.hs | 180 ++++++++++++------------ 2 files changed, 262 insertions(+), 280 deletions(-) diff --git a/src/Compiler.hs b/src/Compiler.hs index 8cbeb58..fd6b6bc 100644 --- a/src/Compiler.hs +++ b/src/Compiler.hs @@ -3,12 +3,12 @@ module Compiler (compile) where +import Auxiliary (snoc) import Control.Monad.State (StateT, execStateT, gets, modify) import Data.Map (Map) import qualified Data.Map as Map -import Data.Tuple.Extra (second) +import Data.Tuple.Extra (dupe, first, second) import Grammar.ErrM (Err) -import Grammar.Print (printTree) import LlvmIr (LLVMIr (..), LLVMType (..), LLVMValue (..), Visibility (..), llvmIrToString) @@ -32,11 +32,11 @@ data FunctionInfo = FunctionInfo -- | Adds a instruction to the CodeGenerator state emit :: LLVMIr -> CompilerState () -emit l = modify (\t -> t{instructions = instructions t ++ [l]}) +emit l = modify $ \t -> t { instructions = snoc l $ instructions t } -- | Increases the variable counter in the CodeGenerator state increaseVarCount :: CompilerState () -increaseVarCount = modify (\t -> t{variableCount = variableCount t + 1}) +increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 } -- | Returns the variable count from the CodeGenerator state getVarCount :: CompilerState Integer @@ -46,212 +46,198 @@ getVarCount = gets variableCount getNewVar :: CompilerState Integer getNewVar = increaseVarCount >> getVarCount -{- | Produces a map of functions infos from a list of binds, - which contains useful data for code generation. --} +-- | Produces a map of functions infos from a list of binds, +-- which contains useful data for code generation. getFunctions :: [Bind] -> Map Id FunctionInfo -getFunctions xs = - Map.fromList $ - map - ( \(Bind id args _) -> - ( id - , FunctionInfo - { numArgs = length args - , arguments = args - } - ) - ) - xs - -{- | Compiles an AST and produces a LLVM Ir string. - An easy way to actually "compile" this output is to - Simply pipe it to LLI --} -compile :: Program -> Err String -compile (Program prg) = do - let s = - CodeGenerator - { instructions = defaultStart - , functions = getFunctions prg - , variableCount = 0 - } - ins <- instructions <$> execStateT (goDef prg) s - pure $ llvmIrToString ins +getFunctions bs = Map.fromList $ map go bs where - mainContent :: LLVMValue -> [LLVMIr] - mainContent var = - [ UnsafeRaw $ - "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n" - , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2)) - -- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2") - -- , Label (Ident "b_1") - -- , UnsafeRaw - -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n" - -- , Br (Ident "end") - -- , Label (Ident "b_2") - -- , UnsafeRaw - -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n" - -- , Br (Ident "end") - -- , Label (Ident "end") - Ret I64 (VInteger 0) - ] + go (Bind id args _) = + (id, FunctionInfo { numArgs=length args, arguments=args }) - defaultStart :: [LLVMIr] - defaultStart = - [ Comment (show $ printTree (Program prg)) - , UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n" - , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" - ] - goDef :: [Bind] -> CompilerState () - goDef [] = return () - goDef (Bind (name, t) args exp : xs) = do - emit $ UnsafeRaw "\n" - emit $ Comment $ show name <> ": " <> show exp - emit $ Define (type2LlvmType t_return) name (map (second type2LlvmType) args) - functionBody <- exprToValue exp - if name == "main" - then mapM_ emit (mainContent functionBody) - else emit $ Ret I64 functionBody - emit DefineEnd - modify (\s -> s{variableCount = 0}) - goDef xs - where - t_return = snd $ partitionType (length args) t - go :: Exp -> CompilerState () - go (EInt int) = emitInt int - go (EAdd t e1 e2) = emitAdd t e1 e2 - go (EId (name, _)) = emitIdent name - go (EApp t e1 e2) = emitApp t e1 e2 - go (EAbs t ti e) = emitAbs t ti e - go (ELet bind e) = emitLet bind e - -- go (ESub e1 e2) = emitSub e1 e2 - -- go (EMul e1 e2) = emitMul e1 e2 - -- go (EDiv e1 e2) = emitDiv e1 e2 - -- go (EMod e1 e2) = emitMod e1 e2 +initCodeGenerator :: [Bind] -> CodeGenerator +initCodeGenerator scs = CodeGenerator { instructions = defaultStart + , functions = getFunctions scs + , variableCount = 0 + } - --- aux functions --- - emitAbs :: Type -> Id -> Exp -> CompilerState () - emitAbs _t tid e = do - emit . Comment $ - "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e - emitLet :: Bind -> Exp -> CompilerState () - emitLet b e = do - emit $ - Comment $ - concat - [ "ELet (" - , show b - , " = " - , show e - , ") is not implemented!" - ] +-- | Compiles an AST and produces a LLVM Ir string. +-- An easy way to actually "compile" this output is to +-- Simply pipe it to lli +compile :: Program -> Err String +compile (Program scs) = do + let codegen = initCodeGenerator scs + llvmIrToString . instructions <$> execStateT (compileScs scs) codegen - emitApp :: Type -> Exp -> Exp -> CompilerState () - emitApp t e1 e2 = appEmitter t e1 e2 [] - where - appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState () - appEmitter t e1 e2 stack = do - let newStack = e2 : stack - case e1 of - EApp _ e1' e2' -> appEmitter t e1' e2' newStack - EId id@(name, _) -> do - args <- traverse exprToValue newStack - vs <- getNewVar - funcs <- gets functions - let vis = case Map.lookup id funcs of - Nothing -> Local - Just _ -> Global - let call = Call (type2LlvmType t) vis name ((\x -> (valueGetType x, x)) <$> args) - emit $ SetVariable (Ident $ show vs) call - x -> do - emit . Comment $ "The unspeakable happened: " - emit . Comment $ show x +compileScs :: [Bind] -> CompilerState () +compileScs [] = pure () +compileScs (Bind (name, t) args exp : xs) = do + emit $ UnsafeRaw "\n" + emit . Comment $ show name <> ": " <> show exp + let args' = map (second type2LlvmType) args + emit $ Define (type2LlvmType t_return) name args' + functionBody <- exprToValue exp + if name == "main" + then mapM_ emit $ mainContent functionBody + else emit $ Ret I64 functionBody + emit DefineEnd + modify $ \s -> s { variableCount = 0 } + compileScs xs + where + t_return = snd $ partitionType (length args) t - emitIdent :: Ident -> CompilerState () - emitIdent id = do - -- !!this should never happen!! - emit $ Comment "This should not have happened!" - emit $ Variable id - emit $ UnsafeRaw "\n" +mainContent :: LLVMValue -> [LLVMIr] +mainContent var = + [ UnsafeRaw $ + "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n" + , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2)) + -- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2") + -- , Label (Ident "b_1") + -- , UnsafeRaw + -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n" + -- , Br (Ident "end") + -- , Label (Ident "b_2") + -- , UnsafeRaw + -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n" + -- , Br (Ident "end") + -- , Label (Ident "end") + Ret I64 (VInteger 0) + ] - emitInt :: Integer -> CompilerState () - emitInt i = do - -- !!this should never happen!! - varCount <- getNewVar - emit $ Comment "This should not have happened!" - emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0)) +defaultStart :: [LLVMIr] +defaultStart = [ UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n" + , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" + ] - emitAdd :: Type -> Exp -> Exp -> CompilerState () - emitAdd t e1 e2 = do - v1 <- exprToValue e1 - v2 <- exprToValue e2 - v <- getNewVar - emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2) +compileExp :: Exp -> CompilerState () +compileExp = \case + EInt i -> emitInt i + EAdd t e1 e2 -> emitAdd t e1 e2 + EId (name, _) -> emitIdent name + EApp t e1 e2 -> emitApp t e1 e2 + EAbs t ti e -> emitAbs t ti e + ELet bind e -> emitLet bind e - -- emitMul :: Exp -> Exp -> CompilerState () - -- emitMul e1 e2 = do - -- (v1,v2) <- binExprToValues e1 e2 - -- increaseVarCount - -- v <- gets variableCount - -- emit $ SetVariable $ Ident $ show v - -- emit $ Mul I64 v1 v2 +--- aux functions --- +emitAbs :: Type -> Id -> Exp -> CompilerState () +emitAbs _t tid e = emit . Comment $ "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e - -- emitMod :: Exp -> Exp -> CompilerState () - -- emitMod e1 e2 = do - -- -- `let m a b = rem (abs $ b + a) b` - -- (v1,v2) <- binExprToValues e1 e2 - -- increaseVarCount - -- vadd <- gets variableCount - -- emit $ SetVariable $ Ident $ show vadd - -- emit $ Add I64 v1 v2 - -- - -- increaseVarCount - -- vabs <- gets variableCount - -- emit $ SetVariable $ Ident $ show vabs - -- emit $ Call I64 (Ident "llvm.abs.i64") - -- [ (I64, VIdent (Ident $ show vadd)) - -- , (I1, VInteger 1) - -- ] - -- increaseVarCount - -- v <- gets variableCount - -- emit $ SetVariable $ Ident $ show v - -- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2 +emitLet :: Bind -> Exp -> CompilerState () +emitLet b e = emit . Comment $ concat [ "ELet (" + , show b + , " = " + , show e + , ") is not implemented!" + ] - -- emitDiv :: Exp -> Exp -> CompilerState () - -- emitDiv e1 e2 = do - -- (v1,v2) <- binExprToValues e1 e2 - -- increaseVarCount - -- v <- gets variableCount - -- emit $ SetVariable $ Ident $ show v - -- emit $ Div I64 v1 v2 +emitApp :: Type -> Exp -> Exp -> CompilerState () +emitApp t e1 e2 = appEmitter t e1 e2 [] + where + appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState () + appEmitter t e1 e2 stack = do + let newStack = e2 : stack + case e1 of + EApp _ e1' e2' -> appEmitter t e1' e2' newStack + EId id@(name, _) -> do + args <- traverse exprToValue newStack + vs <- getNewVar + funcs <- gets functions + let visibility = maybe Local (const Global) $ Map.lookup id funcs + args' = map (first valueGetType . dupe) args + call = Call (type2LlvmType t) visibility name args' + emit $ SetVariable (Ident $ show vs) call + x -> do + emit . Comment $ "The unspeakable happened: " + emit . Comment $ show x - -- emitSub :: Exp -> Exp -> CompilerState () - -- emitSub e1 e2 = do - -- (v1,v2) <- binExprToValues e1 e2 - -- increaseVarCount - -- v <- gets variableCount - -- emit $ SetVariable $ Ident $ show v - -- emit $ Sub I64 v1 v2 +emitIdent :: Ident -> CompilerState () +emitIdent id = do + -- !!this should never happen!! + emit $ Comment "This should not have happened!" + emit $ Variable id + emit $ UnsafeRaw "\n" - exprToValue :: Exp -> CompilerState LLVMValue - exprToValue (EInt i) = return $ VInteger i - exprToValue (EId id@(name, t)) = do +emitInt :: Integer -> CompilerState () +emitInt i = do + -- !!this should never happen!! + varCount <- getNewVar + emit $ Comment "This should not have happened!" + emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0)) + +emitAdd :: Type -> Exp -> Exp -> CompilerState () +emitAdd t e1 e2 = do + v1 <- exprToValue e1 + v2 <- exprToValue e2 + v <- getNewVar + emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2) + +-- emitMul :: Exp -> Exp -> CompilerState () +-- emitMul e1 e2 = do +-- (v1,v2) <- binExprToValues e1 e2 +-- increaseVarCount +-- v <- gets variableCount +-- emit $ SetVariable $ Ident $ show v +-- emit $ Mul I64 v1 v2 + +-- emitMod :: Exp -> Exp -> CompilerState () +-- emitMod e1 e2 = do +-- -- `let m a b = rem (abs $ b + a) b` +-- (v1,v2) <- binExprToValues e1 e2 +-- increaseVarCount +-- vadd <- gets variableCount +-- emit $ SetVariable $ Ident $ show vadd +-- emit $ Add I64 v1 v2 +-- +-- increaseVarCount +-- vabs <- gets variableCount +-- emit $ SetVariable $ Ident $ show vabs +-- emit $ Call I64 (Ident "llvm.abs.i64") +-- [ (I64, VIdent (Ident $ show vadd)) +-- , (I1, VInteger 1) +-- ] +-- increaseVarCount +-- v <- gets variableCount +-- emit $ SetVariable $ Ident $ show v +-- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2 + +-- emitDiv :: Exp -> Exp -> CompilerState () +-- emitDiv e1 e2 = do +-- (v1,v2) <- binExprToValues e1 e2 +-- increaseVarCount +-- v <- gets variableCount +-- emit $ SetVariable $ Ident $ show v +-- emit $ Div I64 v1 v2 + +-- emitSub :: Exp -> Exp -> CompilerState () +-- emitSub e1 e2 = do +-- (v1,v2) <- binExprToValues e1 e2 +-- increaseVarCount +-- v <- gets variableCount +-- emit $ SetVariable $ Ident $ show v +-- emit $ Sub I64 v1 v2 + +exprToValue :: Exp -> CompilerState LLVMValue +exprToValue = \case + EInt i -> pure $ VInteger i + + EId id@(name, t) -> do funcs <- gets functions case Map.lookup id funcs of Just fi -> do if numArgs fi == 0 then do vc <- getNewVar - emit $ SetVariable (Ident $ show vc) (Call (type2LlvmType t) Global name []) - return $ VIdent (Ident $ show vc) (type2LlvmType t) - else return $ VFunction name Global (type2LlvmType t) - Nothing -> return $ VIdent name (type2LlvmType t) - exprToValue e = do - go e + emit $ SetVariable (Ident $ show vc) + (Call (type2LlvmType t) Global name []) + pure $ VIdent (Ident $ show vc) (type2LlvmType t) + else pure $ VFunction name Global (type2LlvmType t) + Nothing -> pure $ VIdent name (type2LlvmType t) + + e -> do + compileExp e v <- getVarCount - return $ VIdent (Ident $ show v) (getType e) + pure $ VIdent (Ident $ show v) (getType e) type2LlvmType :: Type -> LLVMType type2LlvmType = \case diff --git a/src/LlvmIr.hs b/src/LlvmIr.hs index b29f296..d340ddc 100644 --- a/src/LlvmIr.hs +++ b/src/LlvmIr.hs @@ -9,8 +9,8 @@ module LlvmIr ( Visibility (..), ) where -import Data.List (intercalate) -import TypeCheckerIr +import Data.List (intercalate) +import TypeCheckerIr -- | A datatype which represents some basic LLVM types data LLVMType @@ -51,8 +51,8 @@ data LLVMComp instance Show LLVMComp where show :: LLVMComp -> String show = \case - LLEq -> "eq" - LLNe -> "ne" + LLEq -> "eq" + LLNe -> "ne" LLUgt -> "ugt" LLUge -> "uge" LLUlt -> "ult" @@ -65,12 +65,11 @@ instance Show LLVMComp where data Visibility = Local | Global instance Show Visibility where show :: Visibility -> String - show Local = "%" + show Local = "%" show Global = "@" -{- | Represents a LLVM "value", as in an integer, a register variable, - or a string contstant --} +-- | Represents a LLVM "value", as in an integer, a register variable, +-- or a string contstant data LLVMValue = VInteger Integer | VIdent Ident LLVMType @@ -80,10 +79,10 @@ data LLVMValue instance Show LLVMValue where show :: LLVMValue -> String show v = case v of - VInteger i -> show i - VIdent (Ident n) _ -> "%" <> n + VInteger i -> show i + VIdent (Ident n) _ -> "%" <> n VFunction (Ident n) vis _ -> show vis <> n - VConstant s -> "c" <> show s + VConstant s -> "c" <> show s type Params = [(Ident, LLVMType)] type Args = [(LLVMType, LLVMValue)] @@ -122,87 +121,84 @@ llvmIrToString = go 0 go _ [] = mempty go i (x : xs) = do let (i', n) = case x of - Define{} -> (i + 1, 0) + Define{} -> (i + 1, 0) DefineEnd -> (i - 1, 0) - _ -> (i, i) + _ -> (i, i) insToString n x <> go i' xs -{- | Converts a LLVM inststruction to a String, allowing for printing etc. - The integer represents the indentation --} -{- FOURMOLU_DISABLE -} - insToString :: Int -> LLVMIr -> String - insToString i l = - replicate i '\t' <> case l of - (Define t (Ident i) params) -> - concat - [ "define ", show t, " @", i - , "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params) - , ") {\n" - ] - DefineEnd -> "}\n" - (Declare _t (Ident _i) _params) -> undefined - (SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir] - (Add t v1 v2) -> - concat - [ "add ", show t, " ", show v1 - , ", ", show v2, "\n" - ] - (Sub t v1 v2) -> - concat - [ "sub ", show t, " ", show v1, ", " - , show v2, "\n" - ] - (Div t v1 v2) -> - concat - [ "sdiv ", show t, " ", show v1, ", " - , show v2, "\n" - ] - (Mul t v1 v2) -> - concat - [ "mul ", show t, " ", show v1 - , ", ", show v2, "\n" - ] - (Srem t v1 v2) -> - concat - [ "srem ", show t, " ", show v1, ", " - , show v2, "\n" - ] - (Call t vis (Ident i) arg) -> - concat - [ "call ", show t, " ", show vis, i, "(" - , intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg - , ")\n" - ] - (Alloca t) -> unwords ["alloca", show t, "\n"] - (Store t1 (Ident id1) t2 (Ident id2)) -> - concat - [ "store ", show t1, " %", id1 - , ", ", show t2 , " %", id2, "\n" - ] - (Bitcast t1 (Ident i) t2) -> - concat - [ "bitcast ", show t1, " %" - , i, " to ", show t2, "\n" - ] - (Icmp comp t v1 v2) -> - concat - [ "icmp ", show comp, " ", show t - , " ", show v1, ", ", show v2, "\n" - ] - (Ret t v) -> - concat - [ "ret ", show t, " " - , show v, "\n" - ] - (UnsafeRaw s) -> s - (Label (Ident s)) -> "\nlabel_" <> s <> ":\n" - (Br (Ident s)) -> "br label %label_" <> s <> "\n" - (BrCond val (Ident s1) (Ident s2)) -> - concat - [ "br i1 ", show val, ", ", "label %" - , "label_", s1, ", ", "label %", "label_", s2, "\n" - ] - (Comment s) -> "; " <> s <> "\n" - (Variable (Ident id)) -> "%" <> id -{- FOURMOLU_ENABLE -} +-- | Converts a LLVM inststruction to a String, allowing for printing etc. +-- The integer represents the indentation +insToString :: Int -> LLVMIr -> String +insToString i l = + replicate i '\t' <> case l of + (Define t (Ident i) params) -> + concat + [ "define ", show t, " @", i + , "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params) + , ") {\n" + ] + DefineEnd -> "}\n" + (Declare _t (Ident _i) _params) -> undefined + (SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir] + (Add t v1 v2) -> + concat + [ "add ", show t, " ", show v1 + , ", ", show v2, "\n" + ] + (Sub t v1 v2) -> + concat + [ "sub ", show t, " ", show v1, ", " + , show v2, "\n" + ] + (Div t v1 v2) -> + concat + [ "sdiv ", show t, " ", show v1, ", " + , show v2, "\n" + ] + (Mul t v1 v2) -> + concat + [ "mul ", show t, " ", show v1 + , ", ", show v2, "\n" + ] + (Srem t v1 v2) -> + concat + [ "srem ", show t, " ", show v1, ", " + , show v2, "\n" + ] + (Call t vis (Ident i) arg) -> + concat + [ "call ", show t, " ", show vis, i, "(" + , intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg + , ")\n" + ] + (Alloca t) -> unwords ["alloca", show t, "\n"] + (Store t1 (Ident id1) t2 (Ident id2)) -> + concat + [ "store ", show t1, " %", id1 + , ", ", show t2 , " %", id2, "\n" + ] + (Bitcast t1 (Ident i) t2) -> + concat + [ "bitcast ", show t1, " %" + , i, " to ", show t2, "\n" + ] + (Icmp comp t v1 v2) -> + concat + [ "icmp ", show comp, " ", show t + , " ", show v1, ", ", show v2, "\n" + ] + (Ret t v) -> + concat + [ "ret ", show t, " " + , show v, "\n" + ] + (UnsafeRaw s) -> s + (Label (Ident s)) -> "\nlabel_" <> s <> ":\n" + (Br (Ident s)) -> "br label %label_" <> s <> "\n" + (BrCond val (Ident s1) (Ident s2)) -> + concat + [ "br i1 ", show val, ", ", "label %" + , "label_", s1, ", ", "label %", "label_", s2, "\n" + ] + (Comment s) -> "; " <> s <> "\n" + (Variable (Ident id)) -> "%" <> id