diff --git a/Grammar.cf b/Grammar.cf index dddab37..a55e8c4 100644 --- a/Grammar.cf +++ b/Grammar.cf @@ -1,33 +1,51 @@ -Program. Program ::= [Bind]; -EId. Exp3 ::= Ident; -EInt. Exp3 ::= Integer; -EAnn. Exp3 ::= "(" Exp ":" Type ")"; -ELet. Exp3 ::= "let" Bind "in" Exp; -EApp. Exp2 ::= Exp2 Exp3; -EAdd. Exp1 ::= Exp1 "+" Exp2; -ESub. Exp1 ::= Exp1 "-" Exp2; -EAbs. Exp ::= "\\" Ident ":" Type "." Exp; -ECase. Exp ::= "case" Exp "of" "{" [CaseMatch] "}" ":" Type; -CaseMatch. CaseMatch ::= Case "=>" Exp ; -separator CaseMatch ","; +Program. Program ::= [Def] ; - -CInt. Case ::= Integer ; -CatchAll. Case ::= "_" ; +DBind. Def ::= Bind ; +DData. Def ::= Data ; +separator Def ";" ; Bind. Bind ::= Ident ":" Type ";" - Ident [Ident] "=" Exp; + Ident [Ident] "=" Exp ; -separator Bind ";"; -separator Ident ""; +Data. Data ::= "data" Constr "where" "{" [Constructor] "}" ; -coercions Exp 3; +Constructor. Constructor ::= Ident ":" Type ; +separator nonempty Constructor "" ; -TInt. Type1 ::= "Int" ; -TPol. Type1 ::= Ident ; -TFun. Type ::= Type1 "->" Type ; -coercions Type 1 ; +TMono. Type1 ::= "_" Ident ; +TPol. Type1 ::= "'" Ident ; +TConstr. Type1 ::= Constr ; +TArr. Type ::= Type1 "->" Type ; -comment "--"; -comment "{-" "-}"; \ No newline at end of file +Constr. Constr ::= Ident "(" [Type] ")" ; + +-- TODO: Move literal to its own thing since it's reused in Init as well. +EAnn. Exp5 ::= "(" Exp ":" Type ")" ; +EId. Exp4 ::= Ident ; +ELit. Exp4 ::= Literal ; +EApp. Exp3 ::= Exp3 Exp4 ; +EAdd. Exp1 ::= Exp1 "+" Exp2 ; +ESub. Exp1 ::= Exp1 "-" Exp2 ; +ELet. Exp ::= "let" Ident "=" Exp "in" Exp ; +EAbs. Exp ::= "\\" Ident "." Exp ; +ECase. Exp ::= "case" Exp "of" "{" [Inj] "}"; + +LInt. Literal ::= Integer ; + +Inj. Inj ::= Init "=>" Exp ; +separator nonempty Inj ";" ; + +InitLit. Init ::= Literal ; +InitConstr. Init ::= Ident [Ident] ; +InitCatch. Init ::= "_" ; + +separator Type " " ; +coercions Type 2 ; + +separator Ident " "; + +coercions Exp 5 ; + +comment "--" ; +comment "{-" "-}" ; diff --git a/sample-programs/basic-1 b/sample-programs/basic-1 index 113c8b7..57ce1d9 100644 --- a/sample-programs/basic-1 +++ b/sample-programs/basic-1 @@ -1,87 +1,26 @@ - --- tripplemagic : Int -> Int -> Int -> Int; --- tripplemagic x y z = ((\x:Int. x+x) x) + y + z; --- main : Int; --- main = tripplemagic ((\x:Int. x+x+3) ((\x:Int. x) 2)) 5 3 --- answer: 22 - --- apply : (Int -> Int) -> Int -> Int; --- apply f x = f x; --- main : Int; --- main = apply (\x : Int . x + 5) 5 --- answer: 10 - --- apply : (Int -> Int -> Int) -> Int -> Int -> Int; --- apply f x y = f x y; --- krimp: Int -> Int -> Int; --- krimp x y = x + y; --- main : Int; --- main = apply (krimp) 2 3; --- answer: 5 - --- fibbonaci : Int -> Int; --- fibbonaci x = case x of { --- 0 => 0, --- 1 => 1, --- -- abusing overflows to represent negatives like a boss --- _ => (fibbonaci (x - 2)) --- + (fibbonaci (x - 1)) --- } : Int; --- main : Int; --- main = fibbonaci 10; --- answer: 55 - --- succ : Int -> Int; --- succ x = x - 1; --- --- isZero : Int -> Int; --- isZero x = case x of { --- 0 => 1, --- _ => 0 --- } : Int; --- --- minimization : (Int -> Int) -> Int -> Int; --- minimization p x = case p x of { --- 1 => 0, --- _ => minimization p (succ x) --- } : Int; --- --- main : Int; --- main = minimization isZero 10; --- answer: 0 - -posMul : Int -> Int -> Int; +posMul : _Int -> _Int -> _Int; posMul a b = case b of { - 0 => 0, + 0 => 0; _ => a + posMul a (b - 1) -} : Int; +}; -facc : Int -> Int; +facc : _Int -> _Int; facc a = case a of { - 1 => 1, + 1 => 1; _ => posMul a (facc (a - 1)) -} : Int; --- main : Int; --- main = facc 5 --- answer: 120 +}; --- pow : Int -> Int -> Int; --- pow a b = case b of { --- 0 => 1, --- _ => posMul a (pow a (b-1)) --- } : Int; - -minimization : (Int -> Int) -> Int -> Int; +minimization : (_Int -> _Int) -> _Int -> _Int; minimization p x = case p x of { - 1 => x, + 1 => x; _ => minimization p (x + 1) -} : Int; +}; -checkFac : Int -> Int; +checkFac : _Int -> _Int; checkFac x = case facc x of { - 0 => 1, + 0 => 1; _ => 0 -} : Int; +}; -main : Int; +main : _Int; main = minimization checkFac 1 \ No newline at end of file diff --git a/src/Codegen/Codegen.hs b/src/Codegen/Codegen.hs index 174d0b1..9d3b034 100644 --- a/src/Codegen/Codegen.hs +++ b/src/Codegen/Codegen.hs @@ -1,441 +1,443 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedStrings #-} - -module Codegen.Codegen (generateCode) where - -import Auxiliary (snoc) -import Codegen.LlvmIr (CallingConvention (..), - LLVMComp (..), LLVMIr (..), - LLVMType (..), LLVMValue (..), - Visibility (..), llvmIrToString) -import Control.Monad.State (StateT, execStateT, foldM_, gets, - modify) -import qualified Data.Bifunctor as BI -import Data.List.Extra (trim) -import Data.Map (Map) -import qualified Data.Map as Map -import Data.Tuple.Extra (dupe, first, second) -import qualified Grammar.Abs as GA -import Grammar.ErrM (Err) -import System.Process.Extra (readCreateProcess, shell) -import TypeChecker.TypeCheckerIr (Bind (..), Case (..), Exp (..), Id, - Ident (..), Program (..), Type (..)) --- | The record used as the code generator state -data CodeGenerator = CodeGenerator - { instructions :: [LLVMIr] - , functions :: Map Id FunctionInfo - , constructors :: Map Id ConstructorInfo - , variableCount :: Integer - , labelCount :: Integer - } - --- | A state type synonym -type CompilerState a = StateT CodeGenerator Err a - -data FunctionInfo = FunctionInfo - { numArgs :: Int - , arguments :: [Id] - } -data ConstructorInfo = ConstructorInfo - { numArgsCI :: Int - , argumentsCI :: [Id] - , numCI :: Integer - } - - --- | Adds a instruction to the CodeGenerator state -emit :: LLVMIr -> CompilerState () -emit l = modify $ \t -> t { instructions = Auxiliary.snoc l $ instructions t } - --- | Increases the variable counter in the CodeGenerator state -increaseVarCount :: CompilerState () -increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 } - --- | Returns the variable count from the CodeGenerator state -getVarCount :: CompilerState Integer -getVarCount = gets variableCount - --- | Increases the variable count and returns it from the CodeGenerator state -getNewVar :: CompilerState Integer -getNewVar = increaseVarCount >> getVarCount - --- | Increses the label count and returns a label from the CodeGenerator state -getNewLabel :: CompilerState Integer -getNewLabel = do - modify (\t -> t{labelCount = labelCount t + 1}) - gets labelCount - --- | Produces a map of functions infos from a list of binds, --- which contains useful data for code generation. -getFunctions :: [Bind] -> Map Id FunctionInfo -getFunctions bs = Map.fromList $ go bs - where - go [] = [] - go (Bind id args _ : xs) = - (id, FunctionInfo { numArgs=length args, arguments=args }) - : go xs - go (DataStructure n cons : xs) = do - map (\(id, xs) -> ((id, TPol n), FunctionInfo { - numArgs=length xs, arguments=createArgs xs - })) cons - <> go xs - -createArgs :: [Type] -> [Id] -createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(Ident ("arg_" <> show l) , t)],l+1)) ([], 0) xs - --- | Produces a map of functions infos from a list of binds, --- which contains useful data for code generation. -getConstructors :: [Bind] -> Map Id ConstructorInfo -getConstructors bs = Map.fromList $ go bs - where - go [] = [] - go (DataStructure (Ident n) cons : xs) = do - fst (foldl (\(acc,i) (Ident id, xs) -> (((Ident (n <> "_" <> id), TPol (Ident n)), ConstructorInfo { - numArgsCI=length xs, - argumentsCI=createArgs xs, - numCI=i - }) : acc, i+1)) ([],0) cons) - <> go xs - go (_: xs) = go xs - -initCodeGenerator :: [Bind] -> CodeGenerator -initCodeGenerator scs = CodeGenerator { instructions = defaultStart - , functions = getFunctions scs - , constructors = getConstructors scs - , variableCount = 0 - , labelCount = 0 - } - -run :: Err String -> IO () -run s = do - let s' = case s of - Right s -> s - Left _ -> error "yo" - writeFile "output/llvm.ll" s' - putStrLn . trim =<< readCreateProcess (shell "lli") s' - -test :: Integer -> Program -test v = Program [ - DataStructure (Ident "Craig") [ - (Ident "Bob", [TInt])--, - --(Ident "Alice", [TInt, TInt]) - ], - Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] (EId ("x",TInt)), - Bind (Ident "main", TInt) [] ( - EApp (TPol "Craig") (EId (Ident "Craig_Bob", TPol "Craig")) (EInt v) -- (EInt 92) - ) - ] - -{- | Compiles an AST and produces a LLVM Ir string. - An easy way to actually "compile" this output is to - Simply pipe it to LLI --} -generateCode :: Program -> Err String -generateCode (Program scs) = do - let codegen = initCodeGenerator scs - llvmIrToString . instructions <$> execStateT (compileScs scs) codegen - -compileScs :: [Bind] -> CompilerState () -compileScs [] = do - -- as a last step create all the constructors - c <- gets (Map.toList . constructors) - mapM_ (\((id, t), ci) -> do - let t' = type2LlvmType t - let x = BI.second type2LlvmType <$> argumentsCI ci - emit $ Define FastCC t' id x - top <- Ident . show <$> getNewVar - ptr <- Ident . show <$> getNewVar - -- allocated the primary type - emit $ SetVariable top (Alloca t') - - -- set the first byte to the index of the constructor - emit $ SetVariable ptr $ - GetElementPtrInbounds t' (Ref t') - (VIdent top I8) I32 (VInteger 0) I32 (VInteger 0) - emit $ Store I8 (VInteger $ numCI ci ) (Ref I8) ptr - - -- get a pointer of the correct type - ptr' <- Ident . show <$> getNewVar - emit $ SetVariable ptr' (Bitcast (Ref t') ptr (Ref $ CustomType id)) - - --emit $ UnsafeRaw "\n" - - foldM_ (\i (Ident arg_n, arg_t)-> do - let arg_t' = type2LlvmType arg_t - emit $ Comment (show arg_t' <>" "<> arg_n <> " " <> show i ) - elemPtr <- Ident . show <$> getNewVar - emit $ SetVariable elemPtr ( - GetElementPtrInbounds (CustomType id) (Ref (CustomType id)) - (VIdent ptr' Ptr) I32 - (VInteger 0) I32 (VInteger i)) - emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr elemPtr - -- %2 = getelementptr inbounds %Foo_AInteger, %Foo_AInteger* %1, i32 0, i32 1 - -- store i32 42, i32* %2 - pure $ i + 1-- + typeByteSize arg_t' - ) 1 (argumentsCI ci) - - --emit $ UnsafeRaw "\n" - - -- load and return the constructed value - load <- Ident . show <$> getNewVar - emit $ SetVariable load (Load t' Ptr top) - emit $ Ret t' (VIdent load t') - emit DefineEnd - - modify $ \s -> s { variableCount = 0 } - ) c -compileScs (Bind (name, _t) args exp : xs) = do - emit $ UnsafeRaw "\n" - emit . Comment $ show name <> ": " <> show exp - let args' = map (second type2LlvmType) args - emit $ Define FastCC I64 {-(type2LlvmType t_return)-} name args' - functionBody <- exprToValue exp - if name == "main" - then mapM_ emit $ mainContent functionBody - else emit $ Ret I64 functionBody - emit DefineEnd - modify $ \s -> s { variableCount = 0 } - compileScs xs -compileScs (DataStructure id@(Ident outer_id) ts : xs) = do - let biggest_variant = maximum ((\(_, t) -> sum $ typeByteSize . type2LlvmType <$> t) <$> ts) - emit $ Type id [I8, Array biggest_variant I8] - mapM_ (\(Ident inner_id, fi) -> do - emit $ Type (Ident $ outer_id <> "_" <> inner_id) (I8 : map type2LlvmType fi) - ) ts - compileScs xs - - -- where - -- _t_return = snd $ partitionType (length args) t - -mainContent :: LLVMValue -> [LLVMIr] -mainContent var = - [ UnsafeRaw $ - "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n" - , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2)) - -- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2") - -- , Label (Ident "b_1") - -- , UnsafeRaw - -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n" - -- , Br (Ident "end") - -- , Label (Ident "b_2") - -- , UnsafeRaw - -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n" - -- , Br (Ident "end") - -- , Label (Ident "end") - Ret I64 (VInteger 0) - ] - -defaultStart :: [LLVMIr] -defaultStart = [ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n" - , UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n" - , UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n" - , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" - ] - -compileExp :: Exp -> CompilerState () -compileExp (EInt int) = emitInt int -compileExp (EAdd t e1 e2) = emitAdd t e1 e2 -compileExp (ESub t e1 e2) = emitSub t e1 e2 -compileExp (EId (name, _)) = emitIdent name -compileExp (EApp t e1 e2) = emitApp t e1 e2 -compileExp (EAbs t ti e) = emitAbs t ti e -compileExp (ELet binds e) = emitLet binds e -compileExp (ECase t e cs) = emitECased t e cs - -- go (EMul e1 e2) = emitMul e1 e2 - -- go (EDiv e1 e2) = emitDiv e1 e2 - -- go (EMod e1 e2) = emitMod e1 e2 - ---- aux functions --- -emitECased :: Type -> Exp -> [(Type, Case)] -> CompilerState () -emitECased t e cases = do - let cs = snd <$> cases - let ty = type2LlvmType t - vs <- exprToValue e - lbl <- getNewLabel - let label = Ident $ "escape_" <> show lbl - stackPtr <- getNewVar - emit $ SetVariable (Ident $ show stackPtr) (Alloca ty) - mapM_ (emitCases ty label stackPtr vs) cs - emit $ Label label - res <- getNewVar - emit $ SetVariable (Ident $ show res) (Load ty Ptr (Ident $ show stackPtr)) - where - emitCases :: LLVMType -> Ident -> Integer -> LLVMValue -> Case -> CompilerState () - emitCases ty label stackPtr vs (Case (GA.CInt i) exp) = do - ns <- getNewVar - lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel - lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel - emit $ SetVariable (Ident $ show ns) (Icmp LLEq ty vs (VInteger i)) - emit $ BrCond (VIdent (Ident $ show ns) ty) lbl_succPos lbl_failPos - emit $ Label lbl_succPos - val <- exprToValue exp - emit $ Store ty val Ptr (Ident . show $ stackPtr) - emit $ Br label - emit $ Label lbl_failPos - emitCases ty label stackPtr _ (Case GA.CatchAll exp) = do - val <- exprToValue exp - emit $ Store ty val Ptr (Ident . show $ stackPtr) - emit $ Br label - - -emitAbs :: Type -> Id -> Exp -> CompilerState () -emitAbs _t tid e = do - emit . Comment $ - "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e -emitLet :: Bind -> Exp -> CompilerState () -emitLet xs e = do - emit $ - Comment $ - concat - [ "ELet (" - , show xs - , " = " - , show e - , ") is not implemented!" - ] - -emitApp :: Type -> Exp -> Exp -> CompilerState () -emitApp t e1 e2 = appEmitter t e1 e2 [] - where - appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState () - appEmitter t e1 e2 stack = do - let newStack = e2 : stack - case e1 of - EApp _ e1' e2' -> appEmitter t e1' e2' newStack - EId id@(name, _) -> do - args <- traverse exprToValue newStack - vs <- getNewVar - funcs <- gets functions - let visibility = maybe Local (const Global) $ Map.lookup id funcs - args' = map (first valueGetType . dupe) args - call = Call FastCC (type2LlvmType t) visibility name args' - emit $ SetVariable (Ident $ show vs) call - x -> do - emit . Comment $ "The unspeakable happened: " - emit . Comment $ show x - -emitIdent :: Ident -> CompilerState () -emitIdent id = do - -- !!this should never happen!! - emit $ Comment "This should not have happened!" - emit $ Variable id - emit $ UnsafeRaw "\n" - -emitInt :: Integer -> CompilerState () -emitInt i = do - -- !!this should never happen!! - varCount <- getNewVar - emit $ Comment "This should not have happened!" - emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0)) - -emitAdd :: Type -> Exp -> Exp -> CompilerState () -emitAdd t e1 e2 = do - v1 <- exprToValue e1 - v2 <- exprToValue e2 - v <- getNewVar - emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2) - -emitSub :: Type -> Exp -> Exp -> CompilerState () -emitSub t e1 e2 = do - v1 <- exprToValue e1 - v2 <- exprToValue e2 - v <- getNewVar - emit $ SetVariable (Ident $ show v) (Sub (type2LlvmType t) v1 v2) - --- emitMul :: Exp -> Exp -> CompilerState () --- emitMul e1 e2 = do --- (v1,v2) <- binExprToValues e1 e2 --- increaseVarCount --- v <- gets variableCount --- emit $ SetVariable $ Ident $ show v --- emit $ Mul I64 v1 v2 - --- emitMod :: Exp -> Exp -> CompilerState () --- emitMod e1 e2 = do --- -- `let m a b = rem (abs $ b + a) b` --- (v1,v2) <- binExprToValues e1 e2 --- increaseVarCount --- vadd <- gets variableCount --- emit $ SetVariable $ Ident $ show vadd --- emit $ Add I64 v1 v2 +module Codegen.Codegen where +-- {-# LANGUAGE LambdaCase #-} +-- {-# LANGUAGE OverloadedStrings #-} +-- +-- module Codegen.Codegen (generateCode) where +-- +-- import Auxiliary (snoc) +-- import Codegen.LlvmIr (CallingConvention (..), +-- LLVMComp (..), LLVMIr (..), +-- LLVMType (..), LLVMValue (..), +-- Visibility (..), llvmIrToString) +-- import Control.Monad.State (StateT, execStateT, foldM_, gets, +-- modify) +-- import qualified Data.Bifunctor as BI +-- import Data.List.Extra (trim) +-- import Data.Map (Map) +-- import qualified Data.Map as Map +-- import Data.Tuple.Extra (dupe, first, second) +-- import qualified Grammar.Abs as GA +-- import Grammar.ErrM (Err) +-- import System.Process.Extra (readCreateProcess, shell) +-- import TypeChecker.TypeCheckerIr (Bind (..), Case (..), Exp (..), Id, +-- Ident (..), Program (..), Type (..)) +-- -- | The record used as the code generator state +-- data CodeGenerator = CodeGenerator +-- { instructions :: [LLVMIr] +-- , functions :: Map Id FunctionInfo +-- , constructors :: Map Id ConstructorInfo +-- , variableCount :: Integer +-- , labelCount :: Integer +-- } +-- +-- -- | A state type synonym +-- type CompilerState a = StateT CodeGenerator Err a +-- +-- data FunctionInfo = FunctionInfo +-- { numArgs :: Int +-- , arguments :: [Id] +-- } +-- data ConstructorInfo = ConstructorInfo +-- { numArgsCI :: Int +-- , argumentsCI :: [Id] +-- , numCI :: Integer +-- } +-- +-- +-- -- | Adds a instruction to the CodeGenerator state +-- emit :: LLVMIr -> CompilerState () +-- emit l = modify $ \t -> t { instructions = Auxiliary.snoc l $ instructions t } +-- +-- -- | Increases the variable counter in the CodeGenerator state +-- increaseVarCount :: CompilerState () +-- increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 } +-- +-- -- | Returns the variable count from the CodeGenerator state +-- getVarCount :: CompilerState Integer +-- getVarCount = gets variableCount +-- +-- -- | Increases the variable count and returns it from the CodeGenerator state +-- getNewVar :: CompilerState Integer +-- getNewVar = increaseVarCount >> getVarCount +-- +-- -- | Increses the label count and returns a label from the CodeGenerator state +-- getNewLabel :: CompilerState Integer +-- getNewLabel = do +-- modify (\t -> t{labelCount = labelCount t + 1}) +-- gets labelCount +-- +-- -- | Produces a map of functions infos from a list of binds, +-- -- which contains useful data for code generation. +-- getFunctions :: [Bind] -> Map Id FunctionInfo +-- getFunctions bs = Map.fromList $ go bs +-- where +-- go [] = [] +-- go (Bind id args _ : xs) = +-- (id, FunctionInfo { numArgs=length args, arguments=args }) +-- : go xs +-- go (DataStructure n cons : xs) = do +-- map (\(id, xs) -> ((id, TPol n), FunctionInfo { +-- numArgs=length xs, arguments=createArgs xs +-- })) cons +-- <> go xs +-- +-- createArgs :: [Type] -> [Id] +-- createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(Ident ("arg_" <> show l) , t)],l+1)) ([], 0) xs +-- +-- -- | Produces a map of functions infos from a list of binds, +-- -- which contains useful data for code generation. +-- getConstructors :: [Bind] -> Map Id ConstructorInfo +-- getConstructors bs = Map.fromList $ go bs +-- where +-- go [] = [] +-- go (DataStructure (Ident n) cons : xs) = do +-- fst (foldl (\(acc,i) (Ident id, xs) -> (((Ident (n <> "_" <> id), TPol (Ident n)), ConstructorInfo { +-- numArgsCI=length xs, +-- argumentsCI=createArgs xs, +-- numCI=i +-- }) : acc, i+1)) ([],0) cons) +-- <> go xs +-- go (_: xs) = go xs +-- +-- initCodeGenerator :: [Bind] -> CodeGenerator +-- initCodeGenerator scs = CodeGenerator { instructions = defaultStart +-- , functions = getFunctions scs +-- , constructors = getConstructors scs +-- , variableCount = 0 +-- , labelCount = 0 +-- } +-- +-- run :: Err String -> IO () +-- run s = do +-- let s' = case s of +-- Right s -> s +-- Left _ -> error "yo" +-- writeFile "output/llvm.ll" s' +-- putStrLn . trim =<< readCreateProcess (shell "lli") s' +-- +-- test :: Integer -> Program +-- test v = Program [ +-- DataStructure (Ident "Craig") [ +-- (Ident "Bob", [TInt])--, +-- --(Ident "Alice", [TInt, TInt]) +-- ], +-- Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] (EId ("x",TInt)), +-- Bind (Ident "main", TInt) [] ( +-- EApp (TPol "Craig") (EId (Ident "Craig_Bob", TPol "Craig")) (EInt v) -- (EInt 92) +-- ) +-- ] +-- +-- {- | Compiles an AST and produces a LLVM Ir string. +-- An easy way to actually "compile" this output is to +-- Simply pipe it to LLI +-- -} +-- generateCode :: Program -> Err String +-- generateCode (Program scs) = do +-- let codegen = initCodeGenerator scs +-- llvmIrToString . instructions <$> execStateT (compileScs scs) codegen +-- +-- compileScs :: [Bind] -> CompilerState () +-- compileScs [] = do +-- -- as a last step create all the constructors +-- c <- gets (Map.toList . constructors) +-- mapM_ (\((id, t), ci) -> do +-- let t' = type2LlvmType t +-- let x = BI.second type2LlvmType <$> argumentsCI ci +-- emit $ Define FastCC t' id x +-- top <- Ident . show <$> getNewVar +-- ptr <- Ident . show <$> getNewVar +-- -- allocated the primary type +-- emit $ SetVariable top (Alloca t') +-- +-- -- set the first byte to the index of the constructor +-- emit $ SetVariable ptr $ +-- GetElementPtrInbounds t' (Ref t') +-- (VIdent top I8) I32 (VInteger 0) I32 (VInteger 0) +-- emit $ Store I8 (VInteger $ numCI ci ) (Ref I8) ptr +-- +-- -- get a pointer of the correct type +-- ptr' <- Ident . show <$> getNewVar +-- emit $ SetVariable ptr' (Bitcast (Ref t') ptr (Ref $ CustomType id)) +-- +-- --emit $ UnsafeRaw "\n" +-- +-- foldM_ (\i (Ident arg_n, arg_t)-> do +-- let arg_t' = type2LlvmType arg_t +-- emit $ Comment (show arg_t' <>" "<> arg_n <> " " <> show i ) +-- elemPtr <- Ident . show <$> getNewVar +-- emit $ SetVariable elemPtr ( +-- GetElementPtrInbounds (CustomType id) (Ref (CustomType id)) +-- (VIdent ptr' Ptr) I32 +-- (VInteger 0) I32 (VInteger i)) +-- emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr elemPtr +-- -- %2 = getelementptr inbounds %Foo_AInteger, %Foo_AInteger* %1, i32 0, i32 1 +-- -- store i32 42, i32* %2 +-- pure $ i + 1-- + typeByteSize arg_t' +-- ) 1 (argumentsCI ci) +-- +-- --emit $ UnsafeRaw "\n" +-- +-- -- load and return the constructed value +-- load <- Ident . show <$> getNewVar +-- emit $ SetVariable load (Load t' Ptr top) +-- emit $ Ret t' (VIdent load t') +-- emit DefineEnd +-- +-- modify $ \s -> s { variableCount = 0 } +-- ) c +-- compileScs (Bind (name, _t) args exp : xs) = do +-- emit $ UnsafeRaw "\n" +-- emit . Comment $ show name <> ": " <> show exp +-- let args' = map (second type2LlvmType) args +-- emit $ Define FastCC I64 {-(type2LlvmType t_return)-} name args' +-- functionBody <- exprToValue exp +-- if name == "main" +-- then mapM_ emit $ mainContent functionBody +-- else emit $ Ret I64 functionBody +-- emit DefineEnd +-- modify $ \s -> s { variableCount = 0 } +-- compileScs xs +-- compileScs (DataStructure id@(Ident outer_id) ts : xs) = do +-- let biggest_variant = maximum ((\(_, t) -> sum $ typeByteSize . type2LlvmType <$> t) <$> ts) +-- emit $ Type id [I8, Array biggest_variant I8] +-- mapM_ (\(Ident inner_id, fi) -> do +-- emit $ Type (Ident $ outer_id <> "_" <> inner_id) (I8 : map type2LlvmType fi) +-- ) ts +-- compileScs xs +-- +-- -- where +-- -- _t_return = snd $ partitionType (length args) t +-- +-- mainContent :: LLVMValue -> [LLVMIr] +-- mainContent var = +-- [ UnsafeRaw $ +-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n" +-- , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2)) +-- -- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2") +-- -- , Label (Ident "b_1") +-- -- , UnsafeRaw +-- -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n" +-- -- , Br (Ident "end") +-- -- , Label (Ident "b_2") +-- -- , UnsafeRaw +-- -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n" +-- -- , Br (Ident "end") +-- -- , Label (Ident "end") +-- Ret I64 (VInteger 0) +-- ] +-- +-- defaultStart :: [LLVMIr] +-- defaultStart = [ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n" +-- , UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n" +-- , UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n" +-- , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" +-- ] +-- +-- compileExp :: Exp -> CompilerState () +-- compileExp (EInt int) = emitInt int +-- compileExp (EAdd t e1 e2) = emitAdd t e1 e2 +-- compileExp (ESub t e1 e2) = emitSub t e1 e2 +-- compileExp (EId (name, _)) = emitIdent name +-- compileExp (EApp t e1 e2) = emitApp t e1 e2 +-- compileExp (EAbs t ti e) = emitAbs t ti e +-- compileExp (ELet binds e) = emitLet binds e +-- compileExp (ECase t e cs) = emitECased t e cs +-- -- go (EMul e1 e2) = emitMul e1 e2 +-- -- go (EDiv e1 e2) = emitDiv e1 e2 +-- -- go (EMod e1 e2) = emitMod e1 e2 +-- +-- --- aux functions --- +-- emitECased :: Type -> Exp -> [(Type, Case)] -> CompilerState () +-- emitECased t e cases = do +-- let cs = snd <$> cases +-- let ty = type2LlvmType t +-- vs <- exprToValue e +-- lbl <- getNewLabel +-- let label = Ident $ "escape_" <> show lbl +-- stackPtr <- getNewVar +-- emit $ SetVariable (Ident $ show stackPtr) (Alloca ty) +-- mapM_ (emitCases ty label stackPtr vs) cs +-- emit $ Label label +-- res <- getNewVar +-- emit $ SetVariable (Ident $ show res) (Load ty Ptr (Ident $ show stackPtr)) +-- where +-- emitCases :: LLVMType -> Ident -> Integer -> LLVMValue -> Case -> CompilerState () +-- emitCases ty label stackPtr vs (Case (GA.CInt i) exp) = do +-- ns <- getNewVar +-- lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel +-- lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel +-- emit $ SetVariable (Ident $ show ns) (Icmp LLEq ty vs (VInteger i)) +-- emit $ BrCond (VIdent (Ident $ show ns) ty) lbl_succPos lbl_failPos +-- emit $ Label lbl_succPos +-- val <- exprToValue exp +-- emit $ Store ty val Ptr (Ident . show $ stackPtr) +-- emit $ Br label +-- emit $ Label lbl_failPos +-- emitCases ty label stackPtr _ (Case GA.CatchAll exp) = do +-- val <- exprToValue exp +-- emit $ Store ty val Ptr (Ident . show $ stackPtr) +-- emit $ Br label +-- +-- +-- emitAbs :: Type -> Id -> Exp -> CompilerState () +-- emitAbs _t tid e = do +-- emit . Comment $ +-- "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e +-- emitLet :: Bind -> Exp -> CompilerState () +-- emitLet xs e = do +-- emit $ +-- Comment $ +-- concat +-- [ "ELet (" +-- , show xs +-- , " = " +-- , show e +-- , ") is not implemented!" +-- ] +-- +-- emitApp :: Type -> Exp -> Exp -> CompilerState () +-- emitApp t e1 e2 = appEmitter t e1 e2 [] +-- where +-- appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState () +-- appEmitter t e1 e2 stack = do +-- let newStack = e2 : stack +-- case e1 of +-- EApp _ e1' e2' -> appEmitter t e1' e2' newStack +-- EId id@(name, _) -> do +-- args <- traverse exprToValue newStack +-- vs <- getNewVar +-- funcs <- gets functions +-- let visibility = maybe Local (const Global) $ Map.lookup id funcs +-- args' = map (first valueGetType . dupe) args +-- call = Call FastCC (type2LlvmType t) visibility name args' +-- emit $ SetVariable (Ident $ show vs) call +-- x -> do +-- emit . Comment $ "The unspeakable happened: " +-- emit . Comment $ show x +-- +-- emitIdent :: Ident -> CompilerState () +-- emitIdent id = do +-- -- !!this should never happen!! +-- emit $ Comment "This should not have happened!" +-- emit $ Variable id +-- emit $ UnsafeRaw "\n" +-- +-- emitInt :: Integer -> CompilerState () +-- emitInt i = do +-- -- !!this should never happen!! +-- varCount <- getNewVar +-- emit $ Comment "This should not have happened!" +-- emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0)) +-- +-- emitAdd :: Type -> Exp -> Exp -> CompilerState () +-- emitAdd t e1 e2 = do +-- v1 <- exprToValue e1 +-- v2 <- exprToValue e2 +-- v <- getNewVar +-- emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2) +-- +-- emitSub :: Type -> Exp -> Exp -> CompilerState () +-- emitSub t e1 e2 = do +-- v1 <- exprToValue e1 +-- v2 <- exprToValue e2 +-- v <- getNewVar +-- emit $ SetVariable (Ident $ show v) (Sub (type2LlvmType t) v1 v2) +-- +-- -- emitMul :: Exp -> Exp -> CompilerState () +-- -- emitMul e1 e2 = do +-- -- (v1,v2) <- binExprToValues e1 e2 +-- -- increaseVarCount +-- -- v <- gets variableCount +-- -- emit $ SetVariable $ Ident $ show v +-- -- emit $ Mul I64 v1 v2 +-- +-- -- emitMod :: Exp -> Exp -> CompilerState () +-- -- emitMod e1 e2 = do +-- -- -- `let m a b = rem (abs $ b + a) b` +-- -- (v1,v2) <- binExprToValues e1 e2 +-- -- increaseVarCount +-- -- vadd <- gets variableCount +-- -- emit $ SetVariable $ Ident $ show vadd +-- -- emit $ Add I64 v1 v2 +-- -- +-- -- increaseVarCount +-- -- vabs <- gets variableCount +-- -- emit $ SetVariable $ Ident $ show vabs +-- -- emit $ Call I64 (Ident "llvm.abs.i64") +-- -- [ (I64, VIdent (Ident $ show vadd)) +-- -- , (I1, VInteger 1) +-- -- ] +-- -- increaseVarCount +-- -- v <- gets variableCount +-- -- emit $ SetVariable $ Ident $ show v +-- -- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2 +-- +-- -- emitDiv :: Exp -> Exp -> CompilerState () +-- -- emitDiv e1 e2 = do +-- -- (v1,v2) <- binExprToValues e1 e2 +-- -- increaseVarCount +-- -- v <- gets variableCount +-- -- emit $ SetVariable $ Ident $ show v +-- -- emit $ Div I64 v1 v2 +-- +-- exprToValue :: Exp -> CompilerState LLVMValue +-- exprToValue = \case +-- EInt i -> pure $ VInteger i +-- +-- EId id@(name, t) -> do +-- funcs <- gets functions +-- case Map.lookup id funcs of +-- Just fi -> do +-- if numArgs fi == 0 +-- then do +-- vc <- getNewVar +-- emit $ SetVariable (Ident $ show vc) +-- (Call FastCC (type2LlvmType t) Global name []) +-- pure $ VIdent (Ident $ show vc) (type2LlvmType t) +-- else pure $ VFunction name Global (type2LlvmType t) +-- Nothing -> pure $ VIdent name (type2LlvmType t) +-- +-- e -> do +-- compileExp e +-- v <- getVarCount +-- pure $ VIdent (Ident $ show v) (getType e) +-- +-- type2LlvmType :: Type -> LLVMType +-- type2LlvmType = \case +-- TInt -> I64 +-- TFun t xs -> do +-- let (t', xs') = function2LLVMType xs [type2LlvmType t] +-- Function t' xs' +-- TPol t -> CustomType t +-- where +-- function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType]) +-- function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s) +-- function2LLVMType x s = (type2LlvmType x, s) +-- +-- getType :: Exp -> LLVMType +-- getType (EInt _) = I64 +-- getType (EAdd t _ _) = type2LlvmType t +-- getType (ESub t _ _) = type2LlvmType t +-- getType (EId (_, t)) = type2LlvmType t +-- getType (EApp t _ _) = type2LlvmType t +-- getType (EAbs t _ _) = type2LlvmType t +-- getType (ELet _ e) = getType e +-- getType (ECase t _ _) = type2LlvmType t +-- +-- valueGetType :: LLVMValue -> LLVMType +-- valueGetType (VInteger _) = I64 +-- valueGetType (VIdent _ t) = t +-- valueGetType (VConstant s) = Array (fromIntegral $ length s) I8 +-- valueGetType (VFunction _ _ t) = t +-- +-- typeByteSize :: LLVMType -> Integer +-- typeByteSize I1 = 1 +-- typeByteSize I8 = 1 +-- typeByteSize I32 = 4 +-- typeByteSize I64 = 8 +-- typeByteSize Ptr = 8 +-- typeByteSize (Ref _) = 8 +-- typeByteSize (Function _ _) = 8 +-- typeByteSize (Array n t) = n * typeByteSize t +-- typeByteSize (CustomType _) = 8 -- --- increaseVarCount --- vabs <- gets variableCount --- emit $ SetVariable $ Ident $ show vabs --- emit $ Call I64 (Ident "llvm.abs.i64") --- [ (I64, VIdent (Ident $ show vadd)) --- , (I1, VInteger 1) --- ] --- increaseVarCount --- v <- gets variableCount --- emit $ SetVariable $ Ident $ show v --- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2 - --- emitDiv :: Exp -> Exp -> CompilerState () --- emitDiv e1 e2 = do --- (v1,v2) <- binExprToValues e1 e2 --- increaseVarCount --- v <- gets variableCount --- emit $ SetVariable $ Ident $ show v --- emit $ Div I64 v1 v2 - -exprToValue :: Exp -> CompilerState LLVMValue -exprToValue = \case - EInt i -> pure $ VInteger i - - EId id@(name, t) -> do - funcs <- gets functions - case Map.lookup id funcs of - Just fi -> do - if numArgs fi == 0 - then do - vc <- getNewVar - emit $ SetVariable (Ident $ show vc) - (Call FastCC (type2LlvmType t) Global name []) - pure $ VIdent (Ident $ show vc) (type2LlvmType t) - else pure $ VFunction name Global (type2LlvmType t) - Nothing -> pure $ VIdent name (type2LlvmType t) - - e -> do - compileExp e - v <- getVarCount - pure $ VIdent (Ident $ show v) (getType e) - -type2LlvmType :: Type -> LLVMType -type2LlvmType = \case - TInt -> I64 - TFun t xs -> do - let (t', xs') = function2LLVMType xs [type2LlvmType t] - Function t' xs' - TPol t -> CustomType t - where - function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType]) - function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s) - function2LLVMType x s = (type2LlvmType x, s) - -getType :: Exp -> LLVMType -getType (EInt _) = I64 -getType (EAdd t _ _) = type2LlvmType t -getType (ESub t _ _) = type2LlvmType t -getType (EId (_, t)) = type2LlvmType t -getType (EApp t _ _) = type2LlvmType t -getType (EAbs t _ _) = type2LlvmType t -getType (ELet _ e) = getType e -getType (ECase t _ _) = type2LlvmType t - -valueGetType :: LLVMValue -> LLVMType -valueGetType (VInteger _) = I64 -valueGetType (VIdent _ t) = t -valueGetType (VConstant s) = Array (fromIntegral $ length s) I8 -valueGetType (VFunction _ _ t) = t - -typeByteSize :: LLVMType -> Integer -typeByteSize I1 = 1 -typeByteSize I8 = 1 -typeByteSize I32 = 4 -typeByteSize I64 = 8 -typeByteSize Ptr = 8 -typeByteSize (Ref _) = 8 -typeByteSize (Function _ _) = 8 -typeByteSize (Array n t) = n * typeByteSize t -typeByteSize (CustomType _) = 8 diff --git a/src/Codegen/LlvmIr.hs b/src/Codegen/LlvmIr.hs index 08cd69d..4a649c3 100644 --- a/src/Codegen/LlvmIr.hs +++ b/src/Codegen/LlvmIr.hs @@ -1,239 +1,241 @@ -{-# LANGUAGE LambdaCase #-} - -module Codegen.LlvmIr ( - LLVMType (..), - LLVMIr (..), - llvmIrToString, - LLVMValue (..), - LLVMComp (..), - Visibility (..), - CallingConvention (..) -) where - -import Data.List (intercalate) -import TypeChecker.TypeCheckerIr - -data CallingConvention = TailCC | FastCC | CCC | ColdCC -instance Show CallingConvention where - show :: CallingConvention -> String - show TailCC = "tailcc" - show FastCC = "fastcc" - show CCC = "ccc" - show ColdCC = "coldcc" - --- | A datatype which represents some basic LLVM types -data LLVMType - = I1 - | I8 - | I32 - | I64 - | Ptr - | Ref LLVMType - | Function LLVMType [LLVMType] - | Array Integer LLVMType - | CustomType Ident - -instance Show LLVMType where - show :: LLVMType -> String - show = \case - I1 -> "i1" - I8 -> "i8" - I32 -> "i32" - I64 -> "i64" - Ptr -> "ptr" - Ref ty -> show ty <> "*" - Function t xs -> show t <> " (" <> intercalate ", " (map show xs) <> ")*" - Array n ty -> concat ["[", show n, " x ", show ty, "]"] - CustomType (Ident ty) -> "%" <> ty - -data LLVMComp - = LLEq - | LLNe - | LLUgt - | LLUge - | LLUlt - | LLUle - | LLSgt - | LLSge - | LLSlt - | LLSle -instance Show LLVMComp where - show :: LLVMComp -> String - show = \case - LLEq -> "eq" - LLNe -> "ne" - LLUgt -> "ugt" - LLUge -> "uge" - LLUlt -> "ult" - LLUle -> "ule" - LLSgt -> "sgt" - LLSge -> "sge" - LLSlt -> "slt" - LLSle -> "sle" - -data Visibility = Local | Global -instance Show Visibility where - show :: Visibility -> String - show Local = "%" - show Global = "@" - --- | Represents a LLVM "value", as in an integer, a register variable, --- or a string contstant -data LLVMValue - = VInteger Integer - | VIdent Ident LLVMType - | VConstant String - | VFunction Ident Visibility LLVMType - -instance Show LLVMValue where - show :: LLVMValue -> String - show v = case v of - VInteger i -> show i - VIdent (Ident n) _ -> "%" <> n - VFunction (Ident n) vis _ -> show vis <> n - VConstant s -> "c" <> show s - -type Params = [(Ident, LLVMType)] -type Args = [(LLVMType, LLVMValue)] - --- | A datatype which represents different instructions in LLVM -data LLVMIr - = Type Ident [LLVMType] - | Define CallingConvention LLVMType Ident Params - | DefineEnd - | Declare LLVMType Ident Params - | SetVariable Ident LLVMIr - | Variable Ident - | GetElementPtrInbounds LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue - | Add LLVMType LLVMValue LLVMValue - | Sub LLVMType LLVMValue LLVMValue - | Div LLVMType LLVMValue LLVMValue - | Mul LLVMType LLVMValue LLVMValue - | Srem LLVMType LLVMValue LLVMValue - | Icmp LLVMComp LLVMType LLVMValue LLVMValue - | Br Ident - | BrCond LLVMValue Ident Ident - | Label Ident - | Call CallingConvention LLVMType Visibility Ident Args - | Alloca LLVMType - | Store LLVMType LLVMValue LLVMType Ident - | Load LLVMType LLVMType Ident - | Bitcast LLVMType Ident LLVMType - | Ret LLVMType LLVMValue - | Comment String - | UnsafeRaw String -- This should generally be avoided, and proper - -- instructions should be used in its place - deriving (Show) - --- | Converts a list of LLVMIr instructions to a string -llvmIrToString :: [LLVMIr] -> String -llvmIrToString = go 0 - where - go :: Int -> [LLVMIr] -> String - go _ [] = mempty - go i (x : xs) = do - let (i', n) = case x of - Define{} -> (i + 1, 0) - DefineEnd -> (i - 1, 0) - _ -> (i, i) - insToString n x <> go i' xs - -{- | Converts a LLVM inststruction to a String, allowing for printing etc. - The integer represents the indentation --} -{- FOURMOLU_DISABLE -} - insToString :: Int -> LLVMIr -> String - insToString i l = - replicate i '\t' <> case l of - (GetElementPtrInbounds t1 t2 p t3 v1 t4 v2) -> do - -- getelementptr inbounds %Foo, %Foo* %x, i32 0, i32 0 - concat - [ "getelementptr inbounds ", show t1, ", " , show t2 - , " ", show p, ", ", show t3, " ", show v1, - ", ", show t4, " ", show v2, "\n" ] - (Type (Ident n) types) -> - concat - [ "%", n, " = type { " - , intercalate ", " (map show types) - , " }\n" - ] - (Define c t (Ident i) params) -> - concat - [ "define ", show c, " ", show t, " @", i - , "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params) - , ") {\n" - ] - DefineEnd -> "}\n" - (Declare _t (Ident _i) _params) -> undefined - (SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir] - (Add t v1 v2) -> - concat - [ "add ", show t, " ", show v1 - , ", ", show v2, "\n" - ] - (Sub t v1 v2) -> - concat - [ "sub ", show t, " ", show v1, ", " - , show v2, "\n" - ] - (Div t v1 v2) -> - concat - [ "sdiv ", show t, " ", show v1, ", " - , show v2, "\n" - ] - (Mul t v1 v2) -> - concat - [ "mul ", show t, " ", show v1 - , ", ", show v2, "\n" - ] - (Srem t v1 v2) -> - concat - [ "srem ", show t, " ", show v1, ", " - , show v2, "\n" - ] - (Call c t vis (Ident i) arg) -> - concat - [ "call ", show c, " ", show t, " ", show vis, i, "(" - , intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg - , ")\n" - ] - (Alloca t) -> unwords ["alloca", show t, "\n"] - (Store t1 val t2 (Ident id2)) -> - concat - [ "store ", show t1, " ", show val - , ", ", show t2 , " %", id2, "\n" - ] - (Load t1 t2 (Ident addr)) -> - concat - [ "load ", show t1, ", " - , show t2, " %", addr, "\n" - ] - (Bitcast t1 (Ident i) t2) -> - concat - [ "bitcast ", show t1, " %" - , i, " to ", show t2, "\n" - ] - (Icmp comp t v1 v2) -> - concat - [ "icmp ", show comp, " ", show t - , " ", show v1, ", ", show v2, "\n" - ] - (Ret t v) -> - concat - [ "ret ", show t, " " - , show v, "\n" - ] - (UnsafeRaw s) -> s - (Label (Ident s)) -> "\n" <> lblPfx <> s <> ":\n" - (Br (Ident s)) -> "br label %" <> lblPfx <> s <> "\n" - (BrCond val (Ident s1) (Ident s2)) -> - concat - [ "br i1 ", show val, ", ", "label %" - , lblPfx, s1, ", ", "label %", lblPfx, s2, "\n" - ] - (Comment s) -> "; " <> s <> "\n" - (Variable (Ident id)) -> "%" <> id -{- FOURMOLU_ENABLE -} - -lblPfx :: String -lblPfx = "lbl_" +module Codegen.LlvmIr where +-- {-# LANGUAGE LambdaCase #-} +-- +-- module Codegen.LlvmIr ( +-- LLVMType (..), +-- LLVMIr (..), +-- llvmIrToString, +-- LLVMValue (..), +-- LLVMComp (..), +-- Visibility (..), +-- CallingConvention (..) +-- ) where +-- +-- import Data.List (intercalate) +-- import TypeChecker.TypeCheckerIr +-- +-- data CallingConvention = TailCC | FastCC | CCC | ColdCC +-- instance Show CallingConvention where +-- show :: CallingConvention -> String +-- show TailCC = "tailcc" +-- show FastCC = "fastcc" +-- show CCC = "ccc" +-- show ColdCC = "coldcc" +-- +-- -- | A datatype which represents some basic LLVM types +-- data LLVMType +-- = I1 +-- | I8 +-- | I32 +-- | I64 +-- | Ptr +-- | Ref LLVMType +-- | Function LLVMType [LLVMType] +-- | Array Integer LLVMType +-- | CustomType Ident +-- +-- instance Show LLVMType where +-- show :: LLVMType -> String +-- show = \case +-- I1 -> "i1" +-- I8 -> "i8" +-- I32 -> "i32" +-- I64 -> "i64" +-- Ptr -> "ptr" +-- Ref ty -> show ty <> "*" +-- Function t xs -> show t <> " (" <> intercalate ", " (map show xs) <> ")*" +-- Array n ty -> concat ["[", show n, " x ", show ty, "]"] +-- CustomType (Ident ty) -> "%" <> ty +-- +-- data LLVMComp +-- = LLEq +-- | LLNe +-- | LLUgt +-- | LLUge +-- | LLUlt +-- | LLUle +-- | LLSgt +-- | LLSge +-- | LLSlt +-- | LLSle +-- instance Show LLVMComp where +-- show :: LLVMComp -> String +-- show = \case +-- LLEq -> "eq" +-- LLNe -> "ne" +-- LLUgt -> "ugt" +-- LLUge -> "uge" +-- LLUlt -> "ult" +-- LLUle -> "ule" +-- LLSgt -> "sgt" +-- LLSge -> "sge" +-- LLSlt -> "slt" +-- LLSle -> "sle" +-- +-- data Visibility = Local | Global +-- instance Show Visibility where +-- show :: Visibility -> String +-- show Local = "%" +-- show Global = "@" +-- +-- -- | Represents a LLVM "value", as in an integer, a register variable, +-- -- or a string contstant +-- data LLVMValue +-- = VInteger Integer +-- | VIdent Ident LLVMType +-- | VConstant String +-- | VFunction Ident Visibility LLVMType +-- +-- instance Show LLVMValue where +-- show :: LLVMValue -> String +-- show v = case v of +-- VInteger i -> show i +-- VIdent (Ident n) _ -> "%" <> n +-- VFunction (Ident n) vis _ -> show vis <> n +-- VConstant s -> "c" <> show s +-- +-- type Params = [(Ident, LLVMType)] +-- type Args = [(LLVMType, LLVMValue)] +-- +-- -- | A datatype which represents different instructions in LLVM +-- data LLVMIr +-- = Type Ident [LLVMType] +-- | Define CallingConvention LLVMType Ident Params +-- | DefineEnd +-- | Declare LLVMType Ident Params +-- | SetVariable Ident LLVMIr +-- | Variable Ident +-- | GetElementPtrInbounds LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue +-- | Add LLVMType LLVMValue LLVMValue +-- | Sub LLVMType LLVMValue LLVMValue +-- | Div LLVMType LLVMValue LLVMValue +-- | Mul LLVMType LLVMValue LLVMValue +-- | Srem LLVMType LLVMValue LLVMValue +-- | Icmp LLVMComp LLVMType LLVMValue LLVMValue +-- | Br Ident +-- | BrCond LLVMValue Ident Ident +-- | Label Ident +-- | Call CallingConvention LLVMType Visibility Ident Args +-- | Alloca LLVMType +-- | Store LLVMType LLVMValue LLVMType Ident +-- | Load LLVMType LLVMType Ident +-- | Bitcast LLVMType Ident LLVMType +-- | Ret LLVMType LLVMValue +-- | Comment String +-- | UnsafeRaw String -- This should generally be avoided, and proper +-- -- instructions should be used in its place +-- deriving (Show) +-- +-- -- | Converts a list of LLVMIr instructions to a string +-- llvmIrToString :: [LLVMIr] -> String +-- llvmIrToString = go 0 +-- where +-- go :: Int -> [LLVMIr] -> String +-- go _ [] = mempty +-- go i (x : xs) = do +-- let (i', n) = case x of +-- Define{} -> (i + 1, 0) +-- DefineEnd -> (i - 1, 0) +-- _ -> (i, i) +-- insToString n x <> go i' xs +-- +-- {- | Converts a LLVM inststruction to a String, allowing for printing etc. +-- The integer represents the indentation +-- -} +-- {- FOURMOLU_DISABLE -} +-- insToString :: Int -> LLVMIr -> String +-- insToString i l = +-- replicate i '\t' <> case l of +-- (GetElementPtrInbounds t1 t2 p t3 v1 t4 v2) -> do +-- -- getelementptr inbounds %Foo, %Foo* %x, i32 0, i32 0 +-- concat +-- [ "getelementptr inbounds ", show t1, ", " , show t2 +-- , " ", show p, ", ", show t3, " ", show v1, +-- ", ", show t4, " ", show v2, "\n" ] +-- (Type (Ident n) types) -> +-- concat +-- [ "%", n, " = type { " +-- , intercalate ", " (map show types) +-- , " }\n" +-- ] +-- (Define c t (Ident i) params) -> +-- concat +-- [ "define ", show c, " ", show t, " @", i +-- , "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params) +-- , ") {\n" +-- ] +-- DefineEnd -> "}\n" +-- (Declare _t (Ident _i) _params) -> undefined +-- (SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir] +-- (Add t v1 v2) -> +-- concat +-- [ "add ", show t, " ", show v1 +-- , ", ", show v2, "\n" +-- ] +-- (Sub t v1 v2) -> +-- concat +-- [ "sub ", show t, " ", show v1, ", " +-- , show v2, "\n" +-- ] +-- (Div t v1 v2) -> +-- concat +-- [ "sdiv ", show t, " ", show v1, ", " +-- , show v2, "\n" +-- ] +-- (Mul t v1 v2) -> +-- concat +-- [ "mul ", show t, " ", show v1 +-- , ", ", show v2, "\n" +-- ] +-- (Srem t v1 v2) -> +-- concat +-- [ "srem ", show t, " ", show v1, ", " +-- , show v2, "\n" +-- ] +-- (Call c t vis (Ident i) arg) -> +-- concat +-- [ "call ", show c, " ", show t, " ", show vis, i, "(" +-- , intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg +-- , ")\n" +-- ] +-- (Alloca t) -> unwords ["alloca", show t, "\n"] +-- (Store t1 val t2 (Ident id2)) -> +-- concat +-- [ "store ", show t1, " ", show val +-- , ", ", show t2 , " %", id2, "\n" +-- ] +-- (Load t1 t2 (Ident addr)) -> +-- concat +-- [ "load ", show t1, ", " +-- , show t2, " %", addr, "\n" +-- ] +-- (Bitcast t1 (Ident i) t2) -> +-- concat +-- [ "bitcast ", show t1, " %" +-- , i, " to ", show t2, "\n" +-- ] +-- (Icmp comp t v1 v2) -> +-- concat +-- [ "icmp ", show comp, " ", show t +-- , " ", show v1, ", ", show v2, "\n" +-- ] +-- (Ret t v) -> +-- concat +-- [ "ret ", show t, " " +-- , show v, "\n" +-- ] +-- (UnsafeRaw s) -> s +-- (Label (Ident s)) -> "\n" <> lblPfx <> s <> ":\n" +-- (Br (Ident s)) -> "br label %" <> lblPfx <> s <> "\n" +-- (BrCond val (Ident s1) (Ident s2)) -> +-- concat +-- [ "br i1 ", show val, ", ", "label %" +-- , lblPfx, s1, ", ", "label %", lblPfx, s2, "\n" +-- ] +-- (Comment s) -> "; " <> s <> "\n" +-- (Variable (Ident id)) -> "%" <> id +-- {- FOURMOLU_ENABLE -} +-- +-- lblPfx :: String +-- lblPfx = "lbl_" +-- diff --git a/src/LambdaLifter/LambdaLifter.hs b/src/LambdaLifter/LambdaLifter.hs index 661b95a..271cc70 100644 --- a/src/LambdaLifter/LambdaLifter.hs +++ b/src/LambdaLifter/LambdaLifter.hs @@ -1,235 +1,192 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedStrings #-} +--{-# LANGUAGE LambdaCase #-} +--{-# LANGUAGE OverloadedStrings #-} -module LambdaLifter.LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where +module LambdaLifter.LambdaLifter where -import Auxiliary (snoc) -import Control.Applicative (Applicative (liftA2)) -import Control.Monad.State (MonadState (get, put), State, - evalState) -import Data.Set (Set) -import qualified Data.Set as Set -import Debug.Trace (trace) -import qualified Grammar.Abs as GA -import Prelude hiding (exp) -import Renamer.Renamer -import TypeChecker.TypeCheckerIr +--import Auxiliary (snoc) +--import Control.Applicative (Applicative (liftA2)) +--import Control.Monad.State (MonadState (get, put), State, +-- evalState) +--import Data.Set (Set) +--import qualified Data.Set as Set +--import Prelude hiding (exp) +--import Renamer.Renamer +--import TypeChecker.TypeCheckerIr --- | Lift lambdas and let expression into supercombinators. --- Three phases: --- @freeVars@ annotatss all the free variables. --- @abstract@ converts lambdas into let expressions. --- @collectScs@ moves every non-constant let expression to a top-level function. -lambdaLift :: Program -> Program -lambdaLift = collectScs . abstract . freeVars - --- | Annotate free variables -freeVars :: Program -> AnnProgram -freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e) - | Bind n xs e <- ds - ] - -freeVarsExp :: Set Id -> Exp -> AnnExp -freeVarsExp localVars = \case - EId n | Set.member n localVars -> (Set.singleton n, AId n) - | otherwise -> (mempty, AId n) - - EInt i -> (mempty, AInt i) - - EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2') - where - e1' = freeVarsExp localVars e1 - e2' = freeVarsExp localVars e2 - - EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2') - where - e1' = freeVarsExp localVars e1 - e2' = freeVarsExp localVars e2 - - ESub t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), ASub t e1' e2') - where - e1' = freeVarsExp localVars e1 - e2' = freeVarsExp localVars e2 +---- | Lift lambdas and let expression into supercombinators. +---- Three phases: +---- @freeVars@ annotatss all the free variables. +---- @abstract@ converts lambdas into let expressions. +---- @collectScs@ moves every non-constant let expression to a top-level function. +--lambdaLift :: Program -> Program +--lambdaLift = collectScs . abstract . freeVars - EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e') - where - e' = freeVarsExp (Set.insert par localVars) e +---- | Annotate free variables +--freeVars :: Program -> AnnProgram +--freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e) +-- | Bind n xs e <- ds +-- ] - -- Sum free variables present in bind and the expression - ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e') - where - binders_frees = Set.delete name $ freeVarsOf rhs' - e_free = Set.delete name $ freeVarsOf e' +--freeVarsExp :: Set Id -> Exp -> AnnExp +--freeVarsExp localVars = \case +-- EId n | Set.member n localVars -> (Set.singleton n, AId n) +-- | otherwise -> (mempty, AId n) - rhs' = freeVarsExp e_localVars rhs - new_bind = ABind name parms rhs' +-- ELit _ (LInt i) -> (mempty, AInt i) - e' = freeVarsExp e_localVars e - e_localVars = Set.insert name localVars +-- EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2') +-- where +-- e1' = freeVarsExp localVars e1 +-- e2' = freeVarsExp localVars e2 - (ECase t e cs) -> do - let e' = freeVarsExp localVars e - let vars = freeVarsOf e' - let (vars', cs') = foldr (\(_, Case c e) (vars,acc) -> do - let e' = freeVarsExp vars e - let vars' = freeVarsOf e' - (Set.union vars vars', AnnCase c e' : acc) - ) (vars, []) cs - (vars', ACase t e' (reverse cs')) +-- EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2') +-- where +-- e1' = freeVarsExp localVars e1 +-- e2' = freeVarsExp localVars e2 + +-- EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e') +-- where +-- e' = freeVarsExp (Set.insert par localVars) e + +-- -- Sum free variables present in bind and the expression +-- ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e') +-- where +-- binders_frees = Set.delete name $ freeVarsOf rhs' +-- e_free = Set.delete name $ freeVarsOf e' + +-- rhs' = freeVarsExp e_localVars rhs +-- new_bind = ABind name parms rhs' + +-- e' = freeVarsExp e_localVars e +-- e_localVars = Set.insert name localVars -freeVarsOf :: AnnExp -> Set Id -freeVarsOf = fst +--freeVarsOf :: AnnExp -> Set Id +--freeVarsOf = fst --- AST annotated with free variables -type AnnProgram = [(Id, [Id], AnnExp)] +---- AST annotated with free variables +--type AnnProgram = [(Id, [Id], AnnExp)] -type AnnExp = (Set Id, AnnExp') +--type AnnExp = (Set Id, AnnExp') -data ABind = ABind Id [Id] AnnExp deriving Show +--data ABind = ABind Id [Id] AnnExp deriving Show -data AnnExp' = AId Id - | AInt Integer - | ALet ABind AnnExp - | AApp Type AnnExp AnnExp - | AAdd Type AnnExp AnnExp - | ASub Type AnnExp AnnExp - | AAbs Type Id AnnExp - | ACase Type AnnExp [AnnCase] - deriving Show -data AnnCase = AnnCase GA.Case AnnExp - deriving Show +--data AnnExp' = AId Id +-- | AInt Integer +-- | ALet ABind AnnExp +-- | AApp Type AnnExp AnnExp +-- | AAdd Type AnnExp AnnExp +-- | AAbs Type Id AnnExp +-- deriving Show +---- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@. +---- Free variables are @v₁ v₂ .. vₙ@ are bound. +--abstract :: AnnProgram -> Program +--abstract prog = Program $ evalState (mapM go prog) 0 +-- where +-- go :: (Id, [Id], AnnExp) -> State Int Bind +-- go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs' +-- where +-- (rhs', parms1) = flattenLambdasAnn rhs --- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@. --- Free variables are @v₁ v₂ .. vₙ@ are bound. -abstract :: AnnProgram -> Program -abstract prog = Program $ evalState (mapM go prog) 0 - where - go :: (Id, [Id], AnnExp) -> State Int Bind - go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs' - where - (rhs', parms1) = flattenLambdasAnn rhs +---- | Flatten nested lambdas and collect the parameters +---- @\x.\y.\z. ae → (ae, [x,y,z])@ +--flattenLambdasAnn :: AnnExp -> (AnnExp, [Id]) +--flattenLambdasAnn ae = go (ae, []) +-- where +-- go :: (AnnExp, [Id]) -> (AnnExp, [Id]) +-- go ((free, e), acc) = +-- case e of +-- AAbs _ par (free1, e1) -> +-- go ((Set.delete par free1, e1), snoc par acc) +-- _ -> ((free, e), acc) + +--abstractExp :: AnnExp -> State Int Exp +--abstractExp (free, exp) = case exp of +-- AId n -> pure $ EId n +-- AInt i -> pure $ ELit (TMono "Int") (LInt i) +-- AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2) +-- AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2) +-- ALet b e -> liftA2 ELet (go b) (abstractExp e) +-- where +-- go (ABind name parms rhs) = do +-- (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs +-- pure $ Bind name (parms ++ parms1) rhs' + +-- skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp +-- skipLambdas f (free, ae) = case ae of +-- AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1 +-- _ -> f (free, ae) + +-- -- Lift lambda into let and bind free variables +-- AAbs t parm e -> do +-- i <- nextNumber +-- rhs <- abstractExp e + +-- let sc_name = Ident ("sc_" ++ show i) +-- sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t) + +-- pure $ foldl (EApp $ TMono "Int") sc $ map EId freeList +-- where +-- freeList = Set.toList free +-- parms = snoc parm freeList --- | Flatten nested lambdas and collect the parameters --- @\x.\y.\z. ae → (ae, [x,y,z])@ -flattenLambdasAnn :: AnnExp -> (AnnExp, [Id]) -flattenLambdasAnn ae = go (ae, []) - where - go :: (AnnExp, [Id]) -> (AnnExp, [Id]) - go ((free, e), acc) = - case e of - AAbs _ par (free1, e1) -> - go ((Set.delete par free1, e1), snoc par acc) - _ -> ((free, e), acc) +--nextNumber :: State Int Int +--nextNumber = do +-- i <- get +-- put $ succ i +-- pure i -abstractExp :: AnnExp -> State Int Exp -abstractExp (free, exp) = case exp of - AId n -> pure $ EId n - AInt i -> pure $ EInt i - AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2) - AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2) - ASub t e1 e2 -> liftA2 (ESub t) (abstractExp e1) (abstractExp e2) - ALet b e -> liftA2 ELet (go b) (abstractExp e) - where - go (ABind name parms rhs) = do - (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs - pure $ Bind name (parms ++ parms1) rhs' - - skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp - skipLambdas f (free, ae) = case ae of - AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1 - _ -> f (free, ae) +---- | Collects supercombinators by lifting non-constant let expressions +--collectScs :: Program -> Program +--collectScs (Program scs) = Program $ concatMap collectFromRhs scs +-- where +-- collectFromRhs (Bind name parms rhs) = +-- let (rhs_scs, rhs') = collectScsExp rhs +-- in Bind name parms rhs' : rhs_scs - ACase t e cs -> do - e' <- abstractExp e - cs' <- mapM (\(AnnCase c e) -> do - e' <- abstractExp e - pure (t,Case c e')) cs - pure $ ECase t e' cs' +--collectScsExp :: Exp -> ([Bind], Exp) +--collectScsExp = \case +-- EId n -> ([], EId n) +-- ELit _ (LInt i) -> ([], ELit (TMono "Int") (LInt i)) - -- Lift lambda into let and bind free variables - AAbs t parm e -> do - i <- nextNumber - rhs <- abstractExp e +-- EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2') +-- where +-- (scs1, e1') = collectScsExp e1 +-- (scs2, e2') = collectScsExp e2 - let sc_name = Ident ("sc_" ++ show i) - sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t) +-- EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2') +-- where +-- (scs1, e1') = collectScsExp e1 +-- (scs2, e2') = collectScsExp e2 - pure $ foldl (EApp TInt) sc $ map EId freeList - where - freeList = Set.toList free - parms = snoc parm freeList +-- EAbs t par e -> (scs, EAbs t par e') +-- where +-- (scs, e') = collectScsExp e + +-- -- Collect supercombinators from bind, the rhss, and the expression. +-- -- +-- -- > f = let sc x y = rhs in e +-- -- +-- ELet (Bind name parms rhs) e -> if null parms +-- then ( rhs_scs ++ e_scs, ELet bind e') +-- else (bind : rhs_scs ++ e_scs, e') +-- where +-- bind = Bind name parms rhs' +-- (rhs_scs, rhs') = collectScsExp rhs +-- (e_scs, e') = collectScsExp e -nextNumber :: State Int Int -nextNumber = do - i <- get - put $ succ i - pure i +---- @\x.\y.\z. e → (e, [x,y,z])@ +--flattenLambdas :: Exp -> (Exp, [Id]) +--flattenLambdas = go . (, []) +-- where +-- go (e, acc) = case e of +-- EAbs _ par e1 -> go (e1, snoc par acc) +-- _ -> (e, acc) --- | Collects supercombinators by lifting non-constant let expressions -collectScs :: Program -> Program -collectScs (Program scs) = Program $ concatMap collectFromRhs scs - where - collectFromRhs (Bind name parms rhs) = - let (rhs_scs, rhs') = collectScsExp rhs - in Bind name parms rhs' : rhs_scs - - -collectScsExp :: Exp -> ([Bind], Exp) -collectScsExp = \case - EId n -> ([], EId n) - EInt i -> ([], EInt i) - - EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2') - where - (scs1, e1') = collectScsExp e1 - (scs2, e2') = collectScsExp e2 - - EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2') - where - (scs1, e1') = collectScsExp e1 - (scs2, e2') = collectScsExp e2 - - ESub t e1 e2 -> (scs1 ++ scs2, ESub t e1' e2') - where - (scs1, e1') = collectScsExp e1 - (scs2, e2') = collectScsExp e2 - - EAbs t par e -> (scs, EAbs t par e') - where - (scs, e') = collectScsExp e - - -- Collect supercombinators from bind, the rhss, and the expression. - -- - -- > f = let sc x y = rhs in e - -- - ELet (Bind name parms rhs) e -> if null parms - then ( rhs_scs ++ e_scs, ELet bind e') - else (bind : rhs_scs ++ e_scs, e') - where - bind = Bind name parms rhs' - (rhs_scs, rhs') = collectScsExp rhs - (e_scs, e') = collectScsExp e - ECase t e cs -> do - let (scs, e') = collectScsExp e - let (scs',cs') = foldr (\(t, Case c e) (scs, acc) -> do - let (scs', e') = collectScsExp e - (scs ++ scs', (t,Case c e') : acc) - ) (scs,[]) cs - (scs', ECase t e' cs') - - --- @\x.\y.\z. e → (e, [x,y,z])@ -flattenLambdas :: Exp -> (Exp, [Id]) -flattenLambdas = go . (, []) - where - go (e, acc) = case e of - EAbs _ par e1 -> go (e1, snoc par acc) - _ -> (e, acc) diff --git a/src/Main.hs b/src/Main.hs index 7390341..c82f6a5 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -2,26 +2,26 @@ module Main where -import Codegen.Codegen (generateCode) -import GHC.IO.Handle.Text (hPutStrLn) -import Grammar.ErrM (Err) -import Grammar.Par (myLexer, pProgram) -import Grammar.Print (printTree) +--import Codegen.Codegen (generateCode) +import GHC.IO.Handle.Text (hPutStrLn) +import Grammar.ErrM (Err) +import Grammar.Par (myLexer, pProgram) +import Grammar.Print (printTree) -- import Interpreter (interpret) -import Control.Monad (when) -import Data.List.Extra (isSuffixOf) -import LambdaLifter.LambdaLifter (lambdaLift) -import Renamer.Renamer (rename) -import System.Directory (createDirectory, doesPathExist, - getDirectoryContents, - removeDirectoryRecursive, - setCurrentDirectory) -import System.Environment (getArgs) -import System.Exit (exitFailure, exitSuccess) -import System.IO (stderr) -import System.Process.Extra (spawnCommand, waitForProcess) -import TypeChecker.TypeChecker (typecheck) +import Control.Monad (when) +import Data.List.Extra (isSuffixOf) +--import LambdaLifter.LambdaLifter (lambdaLift) +import Renamer.Renamer (rename) +import System.Directory (createDirectory, doesPathExist, + getDirectoryContents, + removeDirectoryRecursive, + setCurrentDirectory) +import System.Environment (getArgs) +import System.Exit (exitFailure, exitSuccess) +import System.IO (stderr) +import System.Process.Extra (spawnCommand, waitForProcess) +import TypeChecker.TypeChecker (typecheck) main :: IO () main = @@ -46,19 +46,19 @@ main' debug s = do typechecked <- fromTypeCheckerErr $ typecheck renamed printToErr $ printTree typechecked - printToErr "\n-- Lambda Lifter --" - let lifted = lambdaLift typechecked - printToErr $ printTree lifted - - printToErr "\n -- Printing compiler output to stdout --" - compiled <- fromCompilerErr $ generateCode lifted + -- printToErr "\n-- Lambda Lifter --" + -- let lifted = lambdaLift typechecked + -- printToErr $ printTree lifted +-- + -- printToErr "\n -- Printing compiler output to stdout --" + -- compiled <- fromCompilerErr $ generateCode lifted --putStrLn compiled - check <- doesPathExist "output" - when check (removeDirectoryRecursive "output") - createDirectory "output" - writeFile "output/llvm.ll" compiled - if debug then debugDotViz else putStrLn compiled + -- check <- doesPathExist "output" + -- when check (removeDirectoryRecursive "output") + -- createDirectory "output" + -- writeFile "output/llvm.ll" compiled + -- if debug then debugDotViz else putStrLn compiled -- interpred <- fromInterpreterErr $ interpret lifted diff --git a/src/Renamer/Renamer.hs b/src/Renamer/Renamer.hs index 3c426b4..1def35e 100644 --- a/src/Renamer/Renamer.hs +++ b/src/Renamer/Renamer.hs @@ -1,86 +1,87 @@ {-# LANGUAGE LambdaCase #-} -module Renamer.Renamer (module Renamer.Renamer) where +module Renamer.Renamer where import Auxiliary (mapAccumM) -import Control.Monad (foldM) import Control.Monad.State (MonadState, State, evalState, gets, modify) +import Data.List (foldl') import Data.Map (Map) import qualified Data.Map as Map import Data.Maybe (fromMaybe) import Data.Tuple.Extra (dupe) import Grammar.Abs - -- | Rename all variables and local binds rename :: Program -> Program rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0 where - initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs - renameSc :: Names -> Bind -> Rn Bind - renameSc old_names (Bind name t _ parms rhs) = do + -- initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs + initNames = Map.fromList $ foldl' saveIfBind [] bs + saveIfBind acc (DBind (Bind name _ _ _ _)) = dupe name : acc + saveIfBind acc _ = acc + renameSc :: Names -> Def -> Rn Def + renameSc old_names (DBind (Bind name t _ parms rhs)) = do (new_names, parms') <- newNames old_names parms - rhs' <- snd <$> renameExp new_names rhs - pure $ Bind name t name parms' rhs' - + rhs' <- snd <$> renameExp new_names rhs + pure . DBind $ Bind name t name parms' rhs' + renameSc _ def = pure def -- | Rename monad. State holds the number of renamed names. -newtype Rn a = Rn { runRn :: State Int a } - deriving (Functor, Applicative, Monad, MonadState Int) +newtype Rn a = Rn {runRn :: State Int a} + deriving (Functor, Applicative, Monad, MonadState Int) -- | Maps old to new name type Names = Map Ident Ident renameLocalBind :: Names -> Bind -> Rn (Names, Bind) renameLocalBind old_names (Bind name t _ parms rhs) = do - (new_names, name') <- newName old_names name + (new_names, name') <- newName old_names name (new_names', parms') <- newNames new_names parms - (new_names'', rhs') <- renameExp new_names' rhs + (new_names'', rhs') <- renameExp new_names' rhs pure (new_names'', Bind name' t name' parms' rhs') renameExp :: Names -> Exp -> Rn (Names, Exp) renameExp old_names = \case - EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names) - - EInt i1 -> pure (old_names, EInt i1) - + EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names) + ELit (LInt i1) -> pure (old_names, ELit (LInt i1)) EApp e1 e2 -> do (env1, e1') <- renameExp old_names e1 (env2, e2') <- renameExp old_names e2 pure (Map.union env1 env2, EApp e1' e2') - EAdd e1 e2 -> do (env1, e1') <- renameExp old_names e1 (env2, e2') <- renameExp old_names e2 pure (Map.union env1 env2, EAdd e1' e2') - ESub e1 e2 -> do (env1, e1') <- renameExp old_names e1 (env2, e2') <- renameExp old_names e2 pure (Map.union env1 env2, ESub e1' e2') - - ELet b e -> do - (new_names, b) <- renameLocalBind old_names b - (new_names', e') <- renameExp new_names e - pure (new_names', ELet b e') - - EAbs par t e -> do + ELet i e1 e2 -> do + (new_names, e1') <- renameExp old_names e1 + (new_names', e2') <- renameExp new_names e2 + pure (new_names', ELet i e1' e2') + EAbs par e -> do (new_names, par') <- newName old_names par - (new_names', e') <- renameExp new_names e - pure (new_names', EAbs par' t e') - + (new_names', e') <- renameExp new_names e + pure (new_names', EAbs par' e') EAnn e t -> do (new_names, e') <- renameExp old_names e pure (new_names, EAnn e' t) + ECase e injs -> do + (_, e') <- renameExp old_names e + (new_names, injs') <- renameInjs old_names injs + pure (new_names, ECase e' injs') - ECase e cs t -> do - (new_names, e') <- renameExp old_names e - (new_names', cs') <- foldM (\(names, stack) (CaseMatch c exp) -> do - (nm,exp') <- renameExp names exp - pure (nm,CaseMatch c exp' : stack) - ) (new_names, []) cs - pure (new_names', ECase e' cs' t) +renameInjs :: Names -> [Inj] -> Rn (Names, [Inj]) +renameInjs ns xs = do + (new_names, xs') <- unzip <$> mapM (renameInj ns) xs + if null new_names then return (mempty, xs') else return (head new_names, xs') + +renameInj :: Names -> Inj -> Rn (Names, Inj) +renameInj ns (Inj init e) = do + (new_names, e') <- renameExp ns e + return (new_names, Inj init e') -- | Create a new name and add it to name environment. newName :: Names -> Ident -> Rn (Names, Ident) @@ -95,4 +96,3 @@ newNames = mapAccumM newName -- | Annotate name with number and increment the number @prefix ⇒ prefix_number@. makeName :: Ident -> Rn Ident makeName (Ident prefix) = gets (\i -> Ident $ prefix ++ "_" ++ show i) <* modify succ - diff --git a/src/TypeChecker/TypeChecker.hs b/src/TypeChecker/TypeChecker.hs index 3d6bba8..c9a4ac4 100644 --- a/src/TypeChecker/TypeChecker.hs +++ b/src/TypeChecker/TypeChecker.hs @@ -1,215 +1,517 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedRecordDot #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} -module TypeChecker.TypeChecker (typecheck, partitionType) where +-- | A module for type checking and inference using algorithm W, Hindley-Milner +module TypeChecker.TypeChecker where -import Auxiliary (maybeToRightM, snoc) -import Control.Monad.Except (throwError, unless) +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.State +import Data.Foldable (traverse_) +import Data.Functor.Identity (runIdentity) +import Data.List (foldl') import Data.Map (Map) -import qualified Data.Map as Map +import qualified Data.Map as M +import Data.Set (Set) +import qualified Data.Set as S +import Debug.Trace (trace) import Grammar.Abs -import Grammar.ErrM (Err) -import Grammar.Print (Print (prt), concatD, doc, - printTree, render) -import Prelude hiding (exp, id) +import Grammar.Print (printTree) import qualified TypeChecker.TypeCheckerIr as T +import TypeChecker.TypeCheckerIr (Ctx (..), Env (..), Error, Infer, + Poly (..), Subst) --- NOTE: this type checker is poorly tested +initCtx = Ctx mempty --- TODO --- Coercion --- Type inference +initEnv = Env 0 mempty mempty -data Cxt = Cxt - { env :: Map Ident Type -- ^ Local scope signature - , sig :: Map Ident Type -- ^ Top-level signatures - } +runPretty :: Exp -> Either Error String +runPretty = fmap (printTree . fst) . run . inferExp -initCxt :: [Bind] -> Cxt -initCxt sc = Cxt { env = mempty - , sig = Map.fromList $ map (\(Bind n t _ _ _) -> (n, t)) sc - } +run :: Infer a -> Either Error a +run = runC initEnv initCtx -typecheck :: Program -> Err T.Program -typecheck (Program sc) = T.Program <$> mapM (checkBind $ initCxt sc) sc +runC :: Env -> Ctx -> Infer a -> Either Error a +runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e --- | Check if infered rhs type matches type signature. -checkBind :: Cxt -> Bind -> Err T.Bind -checkBind cxt b = - case expandLambdas b of - Bind name t _ parms rhs -> do - (rhs', t_rhs) <- infer cxt rhs - unless (typeEq t_rhs t) . throwError $ typeErr name t t_rhs - pure $ T.Bind (name, t) (zip parms ts_parms) rhs' - where - ts_parms = fst $ partitionType (length parms) t +typecheck :: Program -> Either Error T.Program +typecheck = run . checkPrg --- | @ f x y = rhs ⇒ f = \x.\y. rhs @ -expandLambdas :: Bind -> Bind -expandLambdas (Bind name t _ parms rhs) = Bind name t name [] rhs' +{- | Start by freshening the type variable of data types to avoid clash with +other user defined polymorphic types +This might be wrong for type constructors that work over several variables +-} +freshenData :: Data -> Infer Data +freshenData (Data (Constr name ts) constrs) = do + fr <- fresh + let fr' = case fr of + TPol a -> a + -- Meh, this part assumes fresh generates a polymorphic type + _ -> + error + "Bug: implementation of \ + \ fresh and freshenData are not compatible" + let new_ts = map (freshenType fr') ts + let new_constrs = map (freshenConstr fr') constrs + return $ Data (Constr name new_ts) new_constrs + +{- | Freshen all polymorphic variables, regardless of name +| freshenType "d" (a -> b -> c) becomes (d -> d -> d) +-} +freshenType :: Ident -> Type -> Type +freshenType iden = \case + (TPol _) -> TPol iden + (TArr a b) -> TArr (freshenType iden a) (freshenType iden b) + (TConstr (Constr a ts)) -> + TConstr (Constr a (map (freshenType iden) ts)) + rest -> rest + +freshenConstr :: Ident -> Constructor -> Constructor +freshenConstr iden (Constructor name t) = + Constructor name (freshenType iden t) + +checkData :: Data -> Infer () +checkData d = do + d' <- freshenData d + case d' of + (Data typ@(Constr name ts) constrs) -> do + unless + (all isPoly ts) + (throwError $ unwords ["Data type incorrectly declared"]) + traverse_ + ( \(Constructor name' t') -> + if TConstr typ == retType t' + then insertConstr name' t' + else + throwError $ + unwords + [ "return type of constructor:" + , printTree name + , "with type:" + , printTree (retType t') + , "does not match data: " + , printTree typ + ] + ) + constrs + +retType :: Type -> Type +retType (TArr _ t2) = retType t2 +retType a = a + +checkPrg :: Program -> Infer T.Program +checkPrg (Program bs) = do + preRun bs + T.Program <$> checkDef bs where - rhs' = foldr ($) rhs $ zipWith EAbs parms ts_parms - ts_parms = fst $ partitionType (length parms) t + preRun :: [Def] -> Infer () + preRun [] = return () + preRun (x : xs) = case x of + DBind (Bind n t _ _ _) -> insertSig n t >> preRun xs + DData d@(Data _ _) -> checkData d >> preRun xs --- | Infer type of expression. -infer :: Cxt -> Exp -> Err (T.Exp, Type) -infer cxt = \case - EId x -> - case lookupEnv x cxt of - Nothing -> - case lookupSig x cxt of - Nothing -> throwError ("Unbound variable:" ++ printTree x) - Just t -> pure (T.EId (x, t), t) - Just t -> pure (T.EId (x, t), t) + checkDef :: [Def] -> Infer [T.Def] + checkDef [] = return [] + checkDef (x : xs) = case x of + (DBind b) -> do + b' <- checkBind b + fmap (T.DBind b' :) (checkDef xs) + (DData d) -> fmap (T.DData d :) (checkDef xs) - EInt i -> pure (T.EInt i, T.TInt) +checkBind :: Bind -> Infer T.Bind +checkBind (Bind n t _ args e) = do + (t', e') <- inferExp $ makeLambda e (reverse args) + s <- unify t t' + let t'' = apply s t + unless + (t `typeEq` t'') + ( throwError $ + unwords + [ "Top level signature" + , printTree t + , "does not match body with inferred type:" + , printTree t'' + ] + ) + return $ T.Bind (n, t) e' + where + makeLambda :: Exp -> [Ident] -> Exp + makeLambda = foldl (flip EAbs) - EApp e e1 -> do - (e', t) <- infer cxt e - case t of - TFun t1 t2 -> do - e1' <- check cxt e1 t1 - pure (T.EApp t2 e' e1', t2) - _ -> do - throwError ("Not a function: " ++ show e) - - EAdd e e1 -> do - e' <- check cxt e T.TInt - e1' <- check cxt e1 T.TInt - pure (T.EAdd T.TInt e' e1', T.TInt) - - ESub e e1 -> do - e' <- check cxt e T.TInt - e1' <- check cxt e1 T.TInt - pure (T.ESub T.TInt e' e1', T.TInt) - - EAbs x t e -> do - (e', t1) <- infer (insertEnv x t cxt) e - let t_abs = TFun t t1 - pure (T.EAbs t_abs (x, t) e', t_abs) - - ELet b e -> do - let cxt' = insertBind b cxt - b' <- checkBind cxt' b - (e', t) <- infer cxt' e - pure (T.ELet b' e', t) - - EAnn e t -> do - (e', t1) <- infer cxt e - unless (typeEq t t1) $ - throwError "Inferred type and type annotation doesn't match" - pure (e', t1) - - ECase e cs t -> do - (e',t1) <- infer cxt e - unless (typeEq t t1) $ - throwError "Inferred type and type annotation doesn't match" - case traverse (\(CaseMatch c e) -> do - -- //TODO check c as well - e' <- check cxt e t - unless (typeEq t t1) $ - throwError "Inferred type and type annotation doesn't match" - pure (t1, T.Case c e') - ) cs of - Right cs -> pure (T.ECase t1 e' cs,t1) - Left e -> throwError e - --- | Check infered type matches the supplied type. -check :: Cxt -> Exp -> Type -> Err T.Exp -check cxt exp typ = case exp of - EId x -> do - t <- case lookupEnv x cxt of - Nothing -> maybeToRightM - ("Unbound variable:" ++ printTree x) - (lookupSig x cxt) - Just t -> pure t - unless (typeEq t typ) . throwError $ typeErr x typ t - pure $ T.EId (x, t) - - EInt i -> do - unless (typeEq typ TInt) $ throwError $ typeErr i TInt typ - pure $ T.EInt i - - EApp e e1 -> do - (e', t) <- infer cxt e - case t of - TFun t1 t2 -> do - e1' <- check cxt e1 t1 - pure $ T.EApp t2 e' e1' - _ -> throwError ("Not a function 2: " ++ printTree e) - - EAdd e e1 -> do - e' <- check cxt e T.TInt - e1' <- check cxt e1 T.TInt - pure $ T.EAdd T.TInt e' e1' - - ESub e e1 -> do - e' <- check cxt e T.TInt - e1' <- check cxt e1 T.TInt - pure $ T.ESub T.TInt e' e1' - - EAbs x t e -> do - (e', t_e) <- infer (insertEnv x t cxt) e - let t1 = TFun t t_e - unless (typeEq t1 typ) $ throwError "Wrong lamda type!" - pure $ T.EAbs t1 (x, t) e' - - ECase e cs t -> do - (e',t1) <- infer cxt e - unless (typeEq t t1) $ - throwError "Inferred type and type annotation doesn't match" - case traverse (\(CaseMatch c e) -> do - -- //TODO check c as well - e' <- check cxt e t - unless (typeEq t t1) $ - throwError "Inferred type and type annotation doesn't match" - pure (t1, T.Case c e') - ) cs of - Right cs -> pure $ T.ECase t1 e' cs - Left e -> throwError e - - ELet b e -> do - let cxt' = insertBind b cxt - b' <- checkBind cxt' b - e' <- check cxt' e typ - pure $ T.ELet b' e' - - EAnn e t -> do - unless (typeEq t typ) $ - throwError "Inferred type and type annotation doesn't match" - check cxt e t - --- | Check if types are equivalent. Doesn't handle coercion or polymorphism. +{- | Check if two types are considered equal + For the purpose of the algorithm two polymorphic types are always considered + equal +-} typeEq :: Type -> Type -> Bool -typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1 -typeEq t t1 = t == t1 +typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r' +typeEq (TMono a) (TMono b) = a == b +typeEq (TConstr (Constr name a)) (TConstr (Constr name' b)) = + length a == length b + && name == name' + && and (zipWith typeEq a b) +typeEq (TPol _) (TPol _) = True +typeEq _ _ = False --- | Partion type into types of parameters and return type. -partitionType :: Int -- Number of parameters to apply - -> Type - -> ([Type], Type) -partitionType = go [] - where - go acc 0 t = (acc, t) - go acc i t = case t of - TFun t1 t2 -> go (snoc t1 acc) (i - 1) t2 - _ -> error "Number of parameters and type doesn't match" +isMoreSpecificOrEq :: Type -> Type -> Bool +isMoreSpecificOrEq _ (TPol _) = True +isMoreSpecificOrEq (TArr a b) (TArr c d) = + isMoreSpecificOrEq a c && isMoreSpecificOrEq b d +isMoreSpecificOrEq (TConstr (Constr n1 ts1)) (TConstr (Constr n2 ts2)) = + n1 == n2 + && length ts1 == length ts2 + && and (zipWith isMoreSpecificOrEq ts1 ts2) +isMoreSpecificOrEq a b = a == b -insertBind :: Bind -> Cxt -> Cxt -insertBind (Bind n t _ _ _) = insertEnv n t +isPoly :: Type -> Bool +isPoly (TPol _) = True +isPoly _ = False -lookupEnv :: Ident -> Cxt -> Maybe Type -lookupEnv x = Map.lookup x . env +inferExp :: Exp -> Infer (Type, T.Exp) +inferExp e = do + (s, t, e') <- algoW e + let subbed = apply s t + return (subbed, replace subbed e') -insertEnv :: Ident -> Type -> Cxt -> Cxt -insertEnv x t cxt = cxt { env = Map.insert x t cxt.env } +replace :: Type -> T.Exp -> T.Exp +replace t = \case + T.ELit _ e -> T.ELit t e + T.EId (n, _) -> T.EId (n, t) + T.EAbs _ name e -> T.EAbs t name e + T.EApp _ e1 e2 -> T.EApp t e1 e2 + T.EAdd _ e1 e2 -> T.EAdd t e1 e2 + T.ESub _ e1 e2 -> T.ESub t e1 e2 + T.ELet (T.Bind (n, _) e1) e2 -> T.ELet (T.Bind (n, t) e1) e2 + T.ECase _ expr injs -> T.ECase t expr injs -lookupSig :: Ident -> Cxt -> Maybe Type -lookupSig x = Map.lookup x . sig +algoW :: Exp -> Infer (Subst, Type, T.Exp) +algoW = \case + -- \| TODO: More testing need to be done. Unsure of the correctness of this + EAnn e t -> do + (s1, t', e') <- algoW e + unless + (t `isMoreSpecificOrEq` t') + ( throwError $ + unwords + [ "Annotated type:" + , printTree t + , "does not match inferred type:" + , printTree t' + ] + ) + applySt s1 $ do + s2 <- unify t t' + return (s2 `compose` s1, t, e') -typeErr :: Print a => a -> Type -> Type -> String -typeErr p expected actual = render $ concatD - [ doc $ showString "Wrong type:", prt 0 p , doc $ showString "\n" - , doc $ showString "Expected:" , prt 0 expected, doc $ showString "\n" - , doc $ showString "Actual: " , prt 0 actual - ] + -- \| ------------------ + -- \| Γ ⊢ i : Int, ∅ + + ELit (LInt n) -> + return (nullSubst, TMono "Int", T.ELit (TMono "Int") (LInt n)) + ELit a -> error $ "NOT IMPLEMENTED YET: ELit " ++ show a + -- \| x : σ ∈ Γ   τ = inst(σ) + -- \| ---------------------- + -- \| Γ ⊢ x : τ, ∅ + + EId i -> do + var <- asks vars + case M.lookup i var of + Just t -> inst t >>= \x -> return (nullSubst, x, T.EId (i, x)) + Nothing -> do + sig <- gets sigs + case M.lookup i sig of + Just t -> return (nullSubst, t, T.EId (i, t)) + Nothing -> do + constr <- gets constructors + case M.lookup i constr of + Just t -> return (nullSubst, t, T.EId (i, t)) + Nothing -> + throwError $ + "Unbound variable: " ++ show i + + -- \| τ = newvar Γ, x : τ ⊢ e : τ', S + -- \| --------------------------------- + -- \| Γ ⊢ w λx. e : Sτ → τ', S + + EAbs name e -> do + fr <- fresh + withBinding name (Forall [] fr) $ do + (s1, t', e') <- algoW e + let varType = apply s1 fr + let newArr = TArr varType t' + return (s1, newArr, T.EAbs newArr (name, varType) e') + + -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁ + -- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int) + -- \| ------------------------------------------ + -- \| Γ ⊢ e₀ + e₁ : Int, S₃S₂S₁S₀ + -- This might be wrong + + EAdd e0 e1 -> do + (s1, t0, e0') <- algoW e0 + applySt s1 $ do + (s2, t1, e1') <- algoW e1 + -- applySt s2 $ do + s3 <- unify (apply s2 t0) (TMono "Int") + s4 <- unify (apply s3 t1) (TMono "Int") + return + ( s4 `compose` s3 `compose` s2 `compose` s1 + , TMono "Int" + , T.EAdd (TMono "Int") e0' e1' + ) + + ESub e0 e1 -> do + (s1, t0, e0') <- algoW e0 + applySt s1 $ do + (s2, t1, e1') <- algoW e1 + -- applySt s2 $ do + s3 <- unify (apply s2 t0) (TMono "Int") + s4 <- unify (apply s3 t1) (TMono "Int") + return + ( s4 `compose` s3 `compose` s2 `compose` s1 + , TMono "Int" + , T.ESub (TMono "Int") e0' e1' + ) + + -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1 + -- \| τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ') + -- \| -------------------------------------- + -- \| Γ ⊢ e₀ e₁ : S₂τ', S₂S₁S₀ + + EApp e0 e1 -> do + fr <- fresh + (s0, t0, e0') <- algoW e0 + applySt s0 $ do + (s1, t1, e1') <- algoW e1 + -- applySt s1 $ do + s2 <- unify (apply s1 t0) (TArr t1 fr) + let t = apply s2 fr + return (s2 `compose` s1 `compose` s0, t, T.EApp t e0' e1') + + -- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁ + -- \| ---------------------------------------------- + -- \| Γ ⊢ let x = e₀ in e₁ : τ', S₁S₀ + + -- The bar over S₀ and Γ means "generalize" + + ELet name e0 e1 -> do + (s1, t1, e0') <- algoW e0 + env <- asks vars + let t' = generalize (apply s1 env) t1 + withBinding name t' $ do + (s2, t2, e1') <- algoW e1 + return (s2 `compose` s1, t2, T.ELet (T.Bind (name, t2) e0') e1') + ECase caseExpr injs -> do + (_, t0, e0') <- algoW caseExpr + (injs', ts) <- mapAndUnzipM (checkInj t0) injs + case ts of + [] -> throwError "Case expression missing any matches" + ts -> do + unified <- zipWithM unify ts (tail ts) + let unified' = foldl' compose mempty unified + let typ = apply unified' (head ts) + return (unified', typ, T.ECase typ e0' injs') + +-- | Unify two types producing a new substitution +unify :: Type -> Type -> Infer Subst +unify t0 t1 = do + trace ("t0: " ++ show t0) return () + trace ("t1: " ++ show t1) return () + case (t0, t1) of + (TArr a b, TArr c d) -> do + s1 <- unify a c + s2 <- unify (apply s1 b) (apply s1 d) + return $ s1 `compose` s2 + (TPol a, b) -> occurs a b + (a, TPol b) -> occurs b a + (TMono a, TMono b) -> + if a == b then return M.empty else throwError "Types do not unify" + -- \| TODO: Figure out a cleaner way to express the same thing + (TConstr (Constr name t), TConstr (Constr name' t')) -> + if name == name' && length t == length t' + then do + xs <- zipWithM unify t t' + return $ foldr compose nullSubst xs + else + throwError $ + unwords + [ "Type constructor:" + , printTree name + , "(" ++ printTree t ++ ")" + , "does not match with:" + , printTree name' + , "(" ++ printTree t' ++ ")" + ] + (a, b) -> + throwError . unwords $ + [ "Type:" + , printTree a + , "can't be unified with:" + , printTree b + ] + +{- | Check if a type is contained in another type. +I.E. { a = a -> b } is an unsolvable constraint since there is no substitution +such that these are equal +-} +occurs :: Ident -> Type -> Infer Subst +occurs _ (TPol _) = return nullSubst +occurs i t = + if S.member i (free t) + then + throwError $ + unwords + [ "Occurs check failed, can't unify" + , printTree (TPol i) + , "with" + , printTree t + ] + else return $ M.singleton i t + +-- | Generalize a type over all free variables in the substitution set +generalize :: Map Ident Poly -> Type -> Poly +generalize env t = Forall (S.toList $ free t S.\\ free env) t + +{- | Instantiate a polymorphic type. The free type variables are substituted +with fresh ones. +-} +inst :: Poly -> Infer Type +inst (Forall xs t) = do + xs' <- mapM (const fresh) xs + let s = M.fromList $ zip xs xs' + return $ apply s t + +-- | Compose two substitution sets +compose :: Subst -> Subst -> Subst +compose m1 m2 = M.map (apply m1) m2 `M.union` m1 + +-- | A class representing free variables functions +class FreeVars t where + -- | Get all free variables from t + free :: t -> Set Ident + + -- | Apply a substitution to t + apply :: Subst -> t -> t + +instance FreeVars Type where + free :: Type -> Set Ident + free (TPol a) = S.singleton a + free (TMono _) = mempty + free (TArr a b) = free a `S.union` free b + -- \| Not guaranteed to be correct + free (TConstr (Constr _ a)) = + foldl' (\acc x -> free x `S.union` acc) S.empty a + + apply :: Subst -> Type -> Type + apply sub t = do + case t of + TMono a -> TMono a + TPol a -> case M.lookup a sub of + Nothing -> TPol a + Just t -> t + TArr a b -> TArr (apply sub a) (apply sub b) + TConstr (Constr name a) -> TConstr (Constr name (map (apply sub) a)) + +instance FreeVars Poly where + free :: Poly -> Set Ident + free (Forall xs t) = free t S.\\ S.fromList xs + apply :: Subst -> Poly -> Poly + apply s (Forall xs t) = Forall xs (apply (foldr M.delete s xs) t) + +instance FreeVars (Map Ident Poly) where + free :: Map Ident Poly -> Set Ident + free m = foldl' S.union S.empty (map free $ M.elems m) + apply :: Subst -> Map Ident Poly -> Map Ident Poly + apply s = M.map (apply s) + +-- | Apply substitutions to the environment. +applySt :: Subst -> Infer a -> Infer a +applySt s = local (\st -> st{vars = apply s (vars st)}) + +-- | Represents the empty substition set +nullSubst :: Subst +nullSubst = M.empty + +-- | Generate a new fresh variable and increment the state counter +fresh :: Infer Type +fresh = do + n <- gets count + modify (\st -> st{count = n + 1}) + return . TPol . Ident $ show n + +-- | Run the monadic action with an additional binding +withBinding :: (Monad m, MonadReader Ctx m) => Ident -> Poly -> m a -> m a +withBinding i p = local (\st -> st{vars = M.insert i p (vars st)}) + +-- | Insert a function signature into the environment +insertSig :: Ident -> Type -> Infer () +insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)}) + +-- | Insert a constructor with its data type +insertConstr :: Ident -> Type -> Infer () +insertConstr i t = + modify (\st -> st{constructors = M.insert i t (constructors st)}) + +-------- PATTERN MATCHING --------- + +-- "case expr of", the type of 'expr' is caseType +checkInj :: Type -> Inj -> Infer (T.Inj, Type) +checkInj caseType (Inj it expr) = do + (args, t') <- initType caseType it + (_, t, e') <- local (\st -> st{vars = args `M.union` vars st}) (algoW expr) + return (T.Inj (it, t') e', t) + +initType :: Type -> Init -> Infer (Map Ident Poly, Type) +initType expected = \case + InitLit lit -> + let returnType = litType lit + in if expected == returnType + then return (mempty, expected) + else + throwError $ + unwords + [ "Inferred type" + , printTree returnType + , "does not match expected type:" + , printTree expected + ] + InitConstr c args -> do + st <- gets constructors + case M.lookup c st of + Nothing -> + throwError $ + unwords + [ "Constructor:" + , printTree c + , "does not exist" + ] + Just t -> do + let flat = flattenType t + let returnType = last flat + case ( length (init flat) == length args + , returnType `isMoreSpecificOrEq` expected + ) of + (True, True) -> + return + ( M.fromList $ zip args (map (Forall []) flat) + , expected + ) + (False, _) -> + throwError $ + "Can't partially match on the constructor: " + ++ printTree c + (_, False) -> + throwError $ + unwords + [ "Inferred type" + , printTree returnType + , "does not match expected type:" + , printTree expected + ] + InitCatch -> return (mempty, expected) + +flattenType :: Type -> [Type] +flattenType (TArr a b) = flattenType a ++ flattenType b +flattenType a = [a] + +litType :: Literal -> Type +litType (LInt _) = TMono "Int" diff --git a/src/TypeChecker/TypeCheckerIr.hs b/src/TypeChecker/TypeCheckerIr.hs index 7dfe3be..31d89b4 100644 --- a/src/TypeChecker/TypeCheckerIr.hs +++ b/src/TypeChecker/TypeCheckerIr.hs @@ -1,139 +1,184 @@ {-# LANGUAGE LambdaCase #-} -module TypeChecker.TypeCheckerIr - ( module Grammar.Abs - , module TypeChecker.TypeCheckerIr - ) where +module TypeChecker.TypeCheckerIr where -import Grammar.Abs (Ident (..), Type (..)) -import qualified Grammar.Abs as GA +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.State +import Data.Functor.Identity (Identity) +import Data.Map (Map) +import Grammar.Abs (Data (..), Ident (..), Init (..), + Literal (..), Type (..)) import Grammar.Print import Prelude -import qualified Prelude as C (Eq, Ord, Read, Show) +import qualified Prelude as C (Eq, Ord, Read, Show) -newtype Program = Program [Bind] - deriving (C.Eq, C.Ord, C.Show, C.Read) +-- | A data type representing type variables +data Poly = Forall [Ident] Type + deriving (Show) + +newtype Ctx = Ctx {vars :: Map Ident Poly} + +data Env = Env + { count :: Int + , sigs :: Map Ident Type + , constructors :: Map Ident Type + } + +type Error = String +type Subst = Map Ident Type + +type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity)) + +newtype Program = Program [Def] + deriving (C.Eq, C.Ord, C.Show, C.Read) data Exp - = EId Id - | EInt Integer + = EId Id + | ELit Type Literal | ELet Bind Exp | EApp Type Exp Exp | EAdd Type Exp Exp | ESub Type Exp Exp - | EAbs Type Id Exp - | ECase Type Exp [(Type, Case)] - deriving (C.Eq, C.Ord, C.Show, C.Read) + | EAbs Type Id Exp + | ECase Type Exp [Inj] + deriving (C.Eq, C.Ord, C.Read, C.Show) -data Case = Case GA.Case Exp - deriving (C.Eq, C.Ord, C.Show, C.Read) +data Inj = Inj (Init, Type) Exp + deriving (C.Eq, C.Ord, C.Read, C.Show) + +data Def = DBind Bind | DData Data + deriving (C.Eq, C.Ord, C.Read, C.Show) type Id = (Ident, Type) -data Bind = Bind Id [Id] Exp | DataStructure Ident [(Ident, [Type])] +data Bind = Bind Id Exp deriving (C.Eq, C.Ord, C.Show, C.Read) +instance Print [Def] where + prt _ [] = concatD [] + prt _ (x : xs) = concatD [prt 0 x, doc (showString "\n"), prt 0 xs] + +instance Print Def where + prt i (DBind bind) = prt i bind + prt i (DData d) = prt i d + instance Print Program where prt i (Program sc) = prPrec i 0 $ prt 0 sc instance Print Bind where - prt i (Bind name@(n, _) parms rhs) = prPrec i 0 $ concatD - [ prtId 0 name - , doc $ showString ";" - , prt 0 n - , prtIdPs 0 parms - , doc $ showString "=" - , prt 0 rhs - ] - prt i (DataStructure (Ident n) xs) = prPrec i 0 $ concatD - [ prt 0 n - , doc $ showString "{" - , doc . showString . show $ xs - , doc $ showString "}" - ] + prt i (Bind (t, name) rhs) = + prPrec i 0 $ + concatD + [ prt 0 name + , doc $ showString ":" + , prt 0 t + , doc $ showString "\n" + , prt 0 name + , doc $ showString "=" + , prt 0 rhs + ] instance Print [Bind] where - prt _ [] = concatD [] - prt _ [x] = concatD [prt 0 x] - prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), doc (showString "\n"), prt 0 xs] prtIdPs :: Int -> [Id] -> Doc prtIdPs i = prPrec i 0 . concatD . map (prtIdP i) prtId :: Int -> Id -> Doc -prtId i (name, t) = prPrec i 0 $ concatD - [ prt 0 name - , doc $ showString ":" - , prt 0 t - ] +prtId i (name, t) = + prPrec i 0 $ + concatD + [ prt 0 name + , doc $ showString ":" + , prt 0 t + ] prtIdP :: Int -> Id -> Doc -prtIdP i (name, t) = prPrec i 0 $ concatD - [ doc $ showString "(" - , prt 0 name - , doc $ showString ":" - , prt 0 t - , doc $ showString ")" - ] - +prtIdP i (name, t) = + prPrec i 0 $ + concatD + [ doc $ showString "(" + , prt 0 name + , doc $ showString ":" + , prt 0 t + , doc $ showString ")" + ] instance Print Exp where - prt i = \case - EId n -> prPrec i 3 $ concatD [prtIdP 0 n] - EInt i1 -> prPrec i 3 $ concatD [prt 0 i1] - ELet bs e -> prPrec i 3 $ concatD - [ doc $ showString "let" - , prt 0 bs - , doc $ showString "in" - , prt 0 e - ] - EApp t e1 e2 -> prPrec i 2 $ concatD - [ doc $ showString "@" - , prt 0 t - , prt 2 e1 - , prt 3 e2 - ] - EAdd t e1 e2 -> prPrec i 1 $ concatD - [ doc $ showString "@" - , prt 0 t - , prt 1 e1 - , doc $ showString "+" - , prt 2 e2 - ] - ESub t e1 e2 -> prPrec i 1 $ concatD - [ doc $ showString "@" - , prt 0 t - , prt 1 e1 - , doc $ showString "-" - , prt 2 e2 - ] - EAbs t n e -> prPrec i 0 $ concatD - [ doc $ showString "@" - , prt 0 t - , doc $ showString "\\" - , prtIdP 0 n - , doc $ showString "." - , prt 0 e - ] - ECase t e cs -> prPrec i 0 $ concatD - [ doc $ showString "@" - , prt 0 t - , doc $ showString "(" - , prt 0 e - , doc $ showString ")" - , prPrec i 0 $ concatD . printCases $ cs - ] + prt i = \case + EId n -> prPrec i 3 $ concatD [prtId 0 n, doc $ showString "\n"] + ELit _ (LInt i1) -> prPrec i 3 $ concatD [prt 0 i1, doc $ showString "\n"] + ELet bs e -> + prPrec i 3 $ + concatD + [ doc $ showString "let" + , prt 0 bs + , doc $ showString "in" + , prt 0 e + , doc $ showString "\n" + ] + EApp _ e1 e2 -> + prPrec i 2 $ + concatD + [ prt 2 e1 + , prt 3 e2 + ] + EAdd t e1 e2 -> + prPrec i 1 $ + concatD + [ doc $ showString "@" + , prt 0 t + , prt 1 e1 + , doc $ showString "+" + , prt 2 e2 + , doc $ showString "\n" + ] + ESub t e1 e2 -> + prPrec i 1 $ + concatD + [ doc $ showString "@" + , prt 0 t + , prt 1 e1 + , doc $ showString "-" + , prt 2 e2 + , doc $ showString "\n" + ] + EAbs t n e -> + prPrec i 0 $ + concatD + [ doc $ showString "@" + , prt 0 t + , doc $ showString "\\" + , prtId 0 n + , doc $ showString "." + , prt 0 e + , doc $ showString "\n" + ] + ECase t exp injs -> + prPrec + i + 0 + ( concatD + [ doc (showString "case") + , prt 0 exp + , doc (showString "of") + , doc (showString "{") + , prt 0 injs + , doc (showString "}") + , doc (showString ":") + , prt 0 t + , doc $ showString "\n" + ] + ) - where - printCases :: [(Type, Case)] -> [Doc] - printCases [] = [] - printCases ((t, Case c e):xs) = concatD - [ doc $ showString "@" - , prt 0 t - , doc $ showString "(" - , doc . showString . show $ c - , doc $ showString ")" - , doc $ showString "=>" - , prt 0 e - , doc $ showString "\n" - ] : printCases xs +instance Print Inj where + prt i = \case + Inj (init, t) exp -> prPrec i 0 (concatD [prt 0 init, doc (showString ":"), prt 0 t, doc (showString "=>"), prt 0 exp]) + +instance Print [Inj] where + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]