diff --git a/Grammar.cf b/Grammar.cf index 52c2353..0f5a411 100644 --- a/Grammar.cf +++ b/Grammar.cf @@ -1,24 +1,20 @@ Program. Program ::= [Bind]; - -EId. Exp3 ::= Ident; -EInt. Exp3 ::= Integer; -ELet. Exp3 ::= "let" [Bind] "in" Exp; -EApp. Exp2 ::= Exp2 Exp3; -EAdd. Exp1 ::= Exp1 "+" Exp2; -EAbs. Exp ::= "\\" Ident ":" Type "." Exp; -EAnn. Exp3 ::= "(" Exp ":" Type ")"; - +EId. Exp3 ::= Ident; +EInt. Exp3 ::= Integer; +EAnn. Exp3 ::= "(" Exp ":" Type ")"; +ELet. Exp3 ::= "let" Bind "in" Exp; +EApp. Exp2 ::= Exp2 Exp3; +EAdd. Exp1 ::= Exp1 "+" Exp2; +EAbs. Exp ::= "\\" Ident ":" Type "." Exp; ECase. Exp ::= "case" Exp "of" "{" [CaseMatch] "}"; --- CaseMatch. CaseMatch ::= Case "=>" Exp ; separator CaseMatch ","; ---terminator CaseMatch "."; CInt. Case ::= Integer ; Bind. Bind ::= Ident ":" Type ";" - Ident [Ident] "=" Exp ; + Ident [Ident] "=" Exp; separator Bind ";"; separator Ident ""; @@ -31,4 +27,4 @@ TFun. Type ::= Type1 "->" Type ; coercions Type 1 ; comment "--"; -comment "{-" "-}"; \ No newline at end of file +comment "{-" "-}"; diff --git a/sample-programs/basic-1 b/sample-programs/basic-1 index 5f0d670..f0cdcc4 100644 --- a/sample-programs/basic-1 +++ b/sample-programs/basic-1 @@ -1,20 +1,21 @@ ---tripplemagic : Int -> Int -> Int -> Int; ---tripplemagic x y z = ((\x:Int. x+x) x) + y + z; ---main : Int; ---main = tripplemagic ((\x:Int. x+x+3) ((\x:Int. x) 2)) 5 3 - ---apply : (Int -> Int) -> Int -> Int; ---apply f x = f x; --- ---main : Int; ---main = (\x : Int . x + 5) 5 +-- tripplemagic : Int -> Int -> Int -> Int; +-- tripplemagic x y z = ((\x:Int. x+x) x) + y + z; +-- main : Int; +-- main = tripplemagic ((\x:Int. x+x+3) ((\x:Int. x) 2)) 5 3 +-- answer: 22 +-- apply : (Int -> Int) -> Int -> Int; +-- apply f x = f x; +-- main : Int; +-- main = apply (\x : Int . x + 5) 5 +-- answer: 10 apply : (Int -> Int -> Int) -> Int -> Int -> Int; apply f x y = f x y; krimp: Int -> Int -> Int; krimp x y = x + y; main : Int; -main = apply (krimp) 2 3;--apply (\y: Int . (\x: Int . x + y + 2)) 5 2; +main = apply (krimp) 2 3; +-- answer: 5 diff --git a/src/Compiler.hs b/src/Compiler.hs index 0820523..8cbeb58 100644 --- a/src/Compiler.hs +++ b/src/Compiler.hs @@ -1,28 +1,24 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} module Compiler (compile) where -import Control.Monad.State (StateT, execStateT, gets, modify) -import Data.Map (Map) -import qualified Data.Map as Map -import Data.Tuple.Extra (second) -import Grammar.ErrM (Err) -import Grammar.Print (printTree) -import LlvmIr ( - LLVMIr (..), - LLVMType (..), - LLVMValue (..), - Visibility (..), - llvmIrToString, - ) -import TypeChecker (partitionType) -import TypeCheckerIr +import Control.Monad.State (StateT, execStateT, gets, modify) +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Tuple.Extra (second) +import Grammar.ErrM (Err) +import Grammar.Print (printTree) +import LlvmIr (LLVMIr (..), LLVMType (..), + LLVMValue (..), Visibility (..), + llvmIrToString) +import TypeChecker (partitionType) +import TypeCheckerIr -- | The record used as the code generator state data CodeGenerator = CodeGenerator - { instructions :: [LLVMIr] - , functions :: Map Id FunctionInfo + { instructions :: [LLVMIr] + , functions :: Map Id FunctionInfo , variableCount :: Integer } @@ -30,7 +26,7 @@ data CodeGenerator = CodeGenerator type CompilerState a = StateT CodeGenerator Err a data FunctionInfo = FunctionInfo - { numArgs :: Int + { numArgs :: Int , arguments :: [Id] } @@ -124,33 +120,29 @@ compile (Program prg) = do t_return = snd $ partitionType (length args) t go :: Exp -> CompilerState () - go (EInt int) = emitInt int - go (EAdd t e1 e2) = emitAdd t e1 e2 + go (EInt int) = emitInt int + go (EAdd t e1 e2) = emitAdd t e1 e2 go (EId (name, _)) = emitIdent name - go (EApp t e1 e2) = emitApp t e1 e2 - go (EAbs t ti e) = emitAbs t ti e - go (ELet binds e) = emitLet binds e - go (EAnn _ _) = emitEAnn + go (EApp t e1 e2) = emitApp t e1 e2 + go (EAbs t ti e) = emitAbs t ti e + go (ELet bind e) = emitLet bind e -- go (ESub e1 e2) = emitSub e1 e2 -- go (EMul e1 e2) = emitMul e1 e2 -- go (EDiv e1 e2) = emitDiv e1 e2 -- go (EMod e1 e2) = emitMod e1 e2 --- aux functions --- - emitEAnn :: CompilerState () - emitEAnn = emit . UnsafeRaw $ "why?" - emitAbs :: Type -> Id -> Exp -> CompilerState () emitAbs _t tid e = do emit . Comment $ "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e - emitLet :: [Bind] -> Exp -> CompilerState () - emitLet xs e = do + emitLet :: Bind -> Exp -> CompilerState () + emitLet b e = do emit $ Comment $ concat [ "ELet (" - , show xs + , show b , " = " , show e , ") is not implemented!" @@ -170,7 +162,7 @@ compile (Program prg) = do funcs <- gets functions let vis = case Map.lookup id funcs of Nothing -> Local - Just _ -> Global + Just _ -> Global let call = Call (type2LlvmType t) vis name ((\x -> (valueGetType x, x)) <$> args) emit $ SetVariable (Ident $ show vs) call x -> do @@ -271,19 +263,18 @@ type2LlvmType = \case where function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType]) function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s) - function2LLVMType x s = (type2LlvmType x, s) + function2LLVMType x s = (type2LlvmType x, s) getType :: Exp -> LLVMType -getType (EInt _) = I64 +getType (EInt _) = I64 getType (EAdd t _ _) = type2LlvmType t getType (EId (_, t)) = type2LlvmType t getType (EApp t _ _) = type2LlvmType t getType (EAbs t _ _) = type2LlvmType t -getType (ELet _ e) = getType e -getType (EAnn _ t) = type2LlvmType t +getType (ELet _ e) = getType e valueGetType :: LLVMValue -> LLVMType -valueGetType (VInteger _) = I64 -valueGetType (VIdent _ t) = t -valueGetType (VConstant s) = Array (length s) I8 +valueGetType (VInteger _) = I64 +valueGetType (VIdent _ t) = t +valueGetType (VConstant s) = Array (length s) I8 valueGetType (VFunction _ _ t) = t diff --git a/src/LambdaLifter.hs b/src/LambdaLifter.hs index 51a82e6..44852ec 100644 --- a/src/LambdaLifter.hs +++ b/src/LambdaLifter.hs @@ -7,12 +7,10 @@ module LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where import Auxiliary (snoc) import Control.Applicative (Applicative (liftA2)) import Control.Monad.State (MonadState (get, put), State, evalState) -import Data.Foldable.Extra (notNull) -import Data.List (mapAccumL, partition) -import Data.Set (Set, (\\)) +import Data.Set (Set) import qualified Data.Set as Set import Prelude hiding (exp) -import Renamer hiding (fromBinders) +import Renamer import TypeCheckerIr @@ -49,35 +47,22 @@ freeVarsExp localVars = \case where e' = freeVarsExp (Set.insert par localVars) e - -- Sum free variables present in binders and the expression - ELet binders e -> (Set.union binders_frees e_free, ALet binders' e') + -- Sum free variables present in bind and the expression + ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e') where - binders_frees = rhss_frees \\ names_set - e_free = freeVarsOf e' \\ names_set + binders_frees = Set.delete name $ freeVarsOf rhs' + e_free = Set.delete name $ freeVarsOf e' - rhss_frees = foldr1 Set.union (map freeVarsOf rhss') - names_set = Set.fromList names + rhs' = freeVarsExp e_localVars rhs + new_bind = ABind name parms rhs' - (names, parms, rhss) = fromBinders binders - rhss' = map (freeVarsExp e_localVars) rhss - e_localVars = Set.union localVars names_set - - binders' = zipWith3 ABind names parms rhss' - e' = freeVarsExp e_localVars e - - EAnn e t -> (freeVarsOf e', AAnn e' t) - where - e' = freeVarsExp localVars e + e' = freeVarsExp e_localVars e + e_localVars = Set.insert name localVars freeVarsOf :: AnnExp -> Set Id freeVarsOf = fst - -fromBinders :: [Bind] -> ([Id], [[Id]], [Exp]) -fromBinders bs = unzip3 [ (name, parms, rhs) | Bind name parms rhs <- bs ] - - -- AST annotated with free variables type AnnProgram = [(Id, [Id], AnnExp)] @@ -87,14 +72,11 @@ data ABind = ABind Id [Id] AnnExp deriving Show data AnnExp' = AId Id | AInt Integer - | ALet [ABind] AnnExp + | ALet ABind AnnExp | AApp Type AnnExp AnnExp | AAdd Type AnnExp AnnExp | AAbs Type Id AnnExp - | AAnn AnnExp Type deriving Show - - -- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@. -- Free variables are @v₁ v₂ .. vₙ@ are bound. abstract :: AnnProgram -> Program @@ -124,7 +106,7 @@ abstractExp (free, exp) = case exp of AInt i -> pure $ EInt i AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2) AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2) - ALet bs e -> liftA2 ELet (mapM go bs) (abstractExp e) + ALet b e -> liftA2 ELet (go b) (abstractExp e) where go (ABind name parms rhs) = do (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs @@ -141,14 +123,13 @@ abstractExp (free, exp) = case exp of rhs <- abstractExp e let sc_name = Ident ("sc_" ++ show i) - sc = ELet [Bind (sc_name, t) parms rhs] $ EId (sc_name, t) + sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t) pure $ foldl (EApp TInt) sc $ map EId freeList where freeList = Set.toList free parms = snoc parm freeList - AAnn e t -> abstractExp e >>= \e' -> pure $ EAnn e' t nextNumber :: State Int Int nextNumber = do @@ -156,7 +137,7 @@ nextNumber = do put $ succ i pure i --- | Collects supercombinators by lifting appropriate let expressions +-- | Collects supercombinators by lifting non-constant let expressions collectScs :: Program -> Program collectScs (Program scs) = Program $ concatMap collectFromRhs scs where @@ -167,49 +148,35 @@ collectScs (Program scs) = Program $ concatMap collectFromRhs scs collectScsExp :: Exp -> ([Bind], Exp) collectScsExp = \case - EId n -> ([], EId n) - EInt i -> ([], EInt i) + EId n -> ([], EId n) + EInt i -> ([], EInt i) - EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2') - where - (scs1, e1') = collectScsExp e1 - (scs2, e2') = collectScsExp e2 + EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2') + where + (scs1, e1') = collectScsExp e1 + (scs2, e2') = collectScsExp e2 - EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2') - where - (scs1, e1') = collectScsExp e1 - (scs2, e2') = collectScsExp e2 + EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2') + where + (scs1, e1') = collectScsExp e1 + (scs2, e2') = collectScsExp e2 - EAbs t par e -> (scs, EAbs t par e') - where - (scs, e') = collectScsExp e + EAbs t par e -> (scs, EAbs t par e') + where + (scs, e') = collectScsExp e - -- Collect supercombinators from binds, the rhss, and the expression. - -- - -- > f = let - -- > sc = rhs - -- > sc1 = rhs1 - -- > ... - -- > in e - -- - ELet binds e -> (binds_scs ++ rhss_scs ++ e_scs, mkEAbs non_scs' e') - where - binds_scs = [ let (rhs', parms1) = flattenLambdas rhs in - Bind n (parms ++ parms1) rhs' - | Bind n parms rhs <- scs' - ] - (rhss_scs, binds') = mapAccumL collectScsRhs [] binds - (e_scs, e') = collectScsExp e + -- Collect supercombinators from bind, the rhss, and the expression. + -- + -- > f = let sc x y = rhs in e + -- + ELet (Bind name parms rhs) e -> if null parms + then ( rhs_scs ++ e_scs, ELet bind e') + else (bind : rhs_scs ++ e_scs, e') + where + bind = Bind name parms rhs' + (rhs_scs, rhs') = collectScsExp rhs + (e_scs, e') = collectScsExp e - (scs', non_scs') = partition (\(Bind _ pars _) -> notNull pars) binds' - - collectScsRhs acc (Bind n xs rhs) = (acc ++ rhs_scs, Bind n xs rhs') - where - (rhs_scs, rhs') = collectScsExp rhs - - EAnn e t -> (scs, EAnn e' t) - where - (scs, e') = collectScsExp e -- @\x.\y.\z. e → (e, [x,y,z])@ flattenLambdas :: Exp -> (Exp, [Id]) @@ -218,7 +185,3 @@ flattenLambdas = go . (, []) go (e, acc) = case e of EAbs _ par e1 -> go (e1, snoc par acc) _ -> (e, acc) - -mkEAbs :: [Bind] -> Exp -> Exp -mkEAbs [] e = e -mkEAbs bs e = ELet bs e diff --git a/src/Renamer.hs b/src/Renamer.hs index 0b2d41e..2730883 100644 --- a/src/Renamer.hs +++ b/src/Renamer.hs @@ -2,82 +2,84 @@ module Renamer (module Renamer) where -import Data.List (mapAccumL, unzip4, zipWith4) -import Data.Map (Map) -import qualified Data.Map as Map -import Data.Maybe (fromMaybe) +import Auxiliary (mapAccumM) +import Control.Monad.State (MonadState, State, evalState, gets, + modify) +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Maybe (fromMaybe) +import Data.Tuple.Extra (dupe) import Grammar.Abs --- | Rename all supercombinators and variables +-- | Rename all variables and local binds rename :: Program -> Program -rename (Program sc) = Program $ map (renameSc 0) sc +rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0 where - renameSc i (Bind n t _ xs e) = Bind n t n xs' e' - where - (i1, xs', env) = newNames i xs - e' = snd $ renameExp env i1 e - -renameExp :: Map Ident Ident -> Int -> Exp -> (Int, Exp) -renameExp env i = \case - - EId n -> (i, EId . fromMaybe n $ Map.lookup n env) - - EInt i1 -> (i, EInt i1) - - EApp e1 e2 -> (i2, EApp e1' e2') - where - (i1, e1') = renameExp env i e1 - (i2, e2') = renameExp env i1 e2 - - EAdd e1 e2 -> (i2, EAdd e1' e2') - where - (i1, e1') = renameExp env i e1 - (i2, e2') = renameExp env i1 e2 - - ELet bs e -> (i3, ELet (zipWith4 mkBind names' types pars' es') e') - where - mkBind name t = Bind name t name - (i1, e') = renameExp e_env i e - (names, types, pars, rhss) = fromBinders bs - (i2, names', env') = newNames i1 (names ++ concat pars) - pars' = (map . map) renamePar pars - e_env = Map.union env' env - (i3, es') = mapAccumL (renameExp e_env) i2 rhss - - renamePar p = case Map.lookup p env' of - Just p' -> p' - Nothing -> error ("Can't find name for " ++ show p) + initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs + renameSc :: Names -> Bind -> Rn Bind + renameSc old_names (Bind name t _ parms rhs) = do + (new_names, parms') <- newNames old_names parms + rhs' <- snd <$> renameExp new_names rhs + pure $ Bind name t name parms' rhs' - EAbs par t e -> (i2, EAbs par' t e') - where - (i1, par', env') = newName par - (i2, e') = renameExp (Map.union env' env ) i1 e +-- | Rename monad. State holds the number of renamed names. +newtype Rn a = Rn { runRn :: State Int a } + deriving (Functor, Applicative, Monad, MonadState Int) - EAnn e t -> (i1, EAnn e' t) - where - (i1, e') = renameExp env i e +-- | Maps old to new name +type Names = Map Ident Ident +renameLocalBind :: Names -> Bind -> Rn (Names, Bind) +renameLocalBind old_names (Bind name t _ parms rhs) = do + (new_names, name') <- newName old_names name + (new_names', parms') <- newNames new_names parms + (new_names'', rhs') <- renameExp new_names' rhs + pure (new_names'', Bind name' t name' parms' rhs') -newName :: Ident -> (Int, Ident, Map Ident Ident) -newName old_name = (i, head names, env) - where (i, names, env) = newNames 1 [old_name] +renameExp :: Names -> Exp -> Rn (Names, Exp) +renameExp old_names = \case -newNames :: Int -> [Ident] -> (Int, [Ident], Map Ident Ident) -newNames i old_names = (i', new_names, env) - where - (i', new_names) = getNames i old_names - env = Map.fromList $ zip old_names new_names + EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names) -getNames :: Int -> [Ident] -> (Int, [Ident]) -getNames i ns = (i + length ss, zipWith makeName ss [i..]) - where - ss = map (\(Ident s) -> s) ns + EInt i1 -> pure (old_names, EInt i1) -makeName :: String -> Int -> Ident -makeName prefix i = Ident (prefix ++ "_" ++ show i) + EApp e1 e2 -> do + (env1, e1') <- renameExp old_names e1 + (env2, e2') <- renameExp old_names e2 + pure (Map.union env1 env2, EApp e1' e2') + EAdd e1 e2 -> do + (env1, e1') <- renameExp old_names e1 + (env2, e2') <- renameExp old_names e2 + pure (Map.union env1 env2, EAdd e1' e2') + + ELet b e -> do + (new_names, b) <- renameLocalBind old_names b + (new_names', e') <- renameExp new_names e + pure (new_names', ELet b e') + + EAbs par t e -> do + (new_names, par') <- newName old_names par + (new_names', e') <- renameExp new_names e + pure (new_names', EAbs par' t e') + + EAnn e t -> do + (new_names, e') <- renameExp old_names e + pure (new_names, EAnn e' t) + +-- | Create a new name and add it to name environment. +newName :: Names -> Ident -> Rn (Names, Ident) +newName env old_name = do + new_name <- makeName old_name + pure (Map.insert old_name new_name env, new_name) + +-- | Create multiple names and add them to the name environment +newNames :: Names -> [Ident] -> Rn (Names, [Ident]) +newNames = mapAccumM newName + +-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@. +makeName :: Ident -> Rn Ident +makeName (Ident prefix) = gets (\i -> Ident $ prefix ++ "_" ++ show i) <* modify succ -fromBinders :: [Bind] -> ([Ident], [Type], [[Ident]], [Exp]) -fromBinders bs = unzip4 [ (name, t, parms, rhs) | Bind name t _ parms rhs <- bs ] diff --git a/src/TypeChecker.hs b/src/TypeChecker.hs index f74c237..b30a360 100644 --- a/src/TypeChecker.hs +++ b/src/TypeChecker.hs @@ -87,18 +87,17 @@ infer cxt = \case let t_abs = TFun t t1 pure (T.EAbs t_abs (x, t) e', t_abs) - ELet bs e -> do - bs'' <- mapM (checkBind cxt') bs' + ELet b e -> do + let cxt' = insertBind b cxt + b' <- checkBind cxt' b (e', t) <- infer cxt' e - pure (T.ELet bs'' e', t) - where - bs' = map expandLambdas bs - cxt' = foldr (\(Bind n t _ _ _) -> insertEnv n t) cxt bs' + pure (T.ELet b' e', t) EAnn e t -> do - e' <- check cxt e t - pure (T.EAnn e' t, t) - + (e', t1) <- infer cxt e + unless (typeEq t t1) $ + throwError "Inferred type and type annotation doesn't match" + pure (e', t1) check :: Cxt -> Exp -> Type -> Err T.Exp check cxt exp typ = case exp of @@ -137,19 +136,19 @@ check cxt exp typ = case exp of unless (typeEq t1 typ) $ throwError "Wrong lamda type!" pure $ T.EAbs t1 (x, t) e' - ELet bs e -> do - bs'' <- mapM (checkBind cxt') bs' + ELet b e -> do + let cxt' = insertBind b cxt + b' <- checkBind cxt' b e' <- check cxt' e typ - pure $ T.ELet bs'' e' - where - bs' = map expandLambdas bs - cxt' = foldr (\(Bind n t _ _ _) -> insertEnv n t) cxt bs' + pure $ T.ELet b' e' EAnn e t -> do unless (typeEq t typ) $ throwError "Inferred type and type annotation doesn't match" - e' <- check cxt e t - pure $ T.EAnn e' typ + check cxt e t + +insertBind :: Bind -> Cxt -> Cxt +insertBind (Bind n t _ _ _) = insertEnv n t typeEq :: Type -> Type -> Bool typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1 diff --git a/src/TypeCheckerIr.hs b/src/TypeCheckerIr.hs index c8371c5..35b3712 100644 --- a/src/TypeCheckerIr.hs +++ b/src/TypeCheckerIr.hs @@ -16,11 +16,10 @@ newtype Program = Program [Bind] data Exp = EId Id | EInt Integer - | ELet [Bind] Exp + | ELet Bind Exp | EApp Type Exp Exp | EAdd Type Exp Exp | EAbs Type Id Exp - | EAnn Exp Type deriving (C.Eq, C.Ord, C.Show, C.Read) type Id = (Ident, Type) @@ -97,12 +96,5 @@ instance Print Exp where , doc $ showString "." , prt 0 e ] - EAnn e t -> prPrec i 3 $ concatD - [ doc $ showString "(" - , prt 0 e - , doc $ showString ":" - , prt 0 t - , doc $ showString ")" - ]