Started importing Sebastian's new typechecker.

This commit is contained in:
Samuel Hammersberg 2023-03-08 11:01:07 +01:00
parent d5dd7896d8
commit 350cd3b0e9
9 changed files with 1611 additions and 1346 deletions

View file

@ -1,33 +1,51 @@
Program. Program ::= [Bind];
EId. Exp3 ::= Ident; Program. Program ::= [Def] ;
EInt. Exp3 ::= Integer;
EAnn. Exp3 ::= "(" Exp ":" Type ")";
ELet. Exp3 ::= "let" Bind "in" Exp;
EApp. Exp2 ::= Exp2 Exp3;
EAdd. Exp1 ::= Exp1 "+" Exp2;
ESub. Exp1 ::= Exp1 "-" Exp2;
EAbs. Exp ::= "\\" Ident ":" Type "." Exp;
ECase. Exp ::= "case" Exp "of" "{" [CaseMatch] "}" ":" Type;
CaseMatch. CaseMatch ::= Case "=>" Exp ;
separator CaseMatch ",";
DBind. Def ::= Bind ;
CInt. Case ::= Integer ; DData. Def ::= Data ;
CatchAll. Case ::= "_" ; separator Def ";" ;
Bind. Bind ::= Ident ":" Type ";" Bind. Bind ::= Ident ":" Type ";"
Ident [Ident] "=" Exp; Ident [Ident] "=" Exp ;
separator Bind ";"; Data. Data ::= "data" Constr "where" "{" [Constructor] "}" ;
separator Ident "";
coercions Exp 3; Constructor. Constructor ::= Ident ":" Type ;
separator nonempty Constructor "" ;
TInt. Type1 ::= "Int" ; TMono. Type1 ::= "_" Ident ;
TPol. Type1 ::= Ident ; TPol. Type1 ::= "'" Ident ;
TFun. Type ::= Type1 "->" Type ; TConstr. Type1 ::= Constr ;
coercions Type 1 ; TArr. Type ::= Type1 "->" Type ;
comment "--"; Constr. Constr ::= Ident "(" [Type] ")" ;
comment "{-" "-}";
-- TODO: Move literal to its own thing since it's reused in Init as well.
EAnn. Exp5 ::= "(" Exp ":" Type ")" ;
EId. Exp4 ::= Ident ;
ELit. Exp4 ::= Literal ;
EApp. Exp3 ::= Exp3 Exp4 ;
EAdd. Exp1 ::= Exp1 "+" Exp2 ;
ESub. Exp1 ::= Exp1 "-" Exp2 ;
ELet. Exp ::= "let" Ident "=" Exp "in" Exp ;
EAbs. Exp ::= "\\" Ident "." Exp ;
ECase. Exp ::= "case" Exp "of" "{" [Inj] "}";
LInt. Literal ::= Integer ;
Inj. Inj ::= Init "=>" Exp ;
separator nonempty Inj ";" ;
InitLit. Init ::= Literal ;
InitConstr. Init ::= Ident [Ident] ;
InitCatch. Init ::= "_" ;
separator Type " " ;
coercions Type 2 ;
separator Ident " ";
coercions Exp 5 ;
comment "--" ;
comment "{-" "-}" ;

View file

@ -1,87 +1,26 @@
posMul : _Int -> _Int -> _Int;
-- tripplemagic : Int -> Int -> Int -> Int;
-- tripplemagic x y z = ((\x:Int. x+x) x) + y + z;
-- main : Int;
-- main = tripplemagic ((\x:Int. x+x+3) ((\x:Int. x) 2)) 5 3
-- answer: 22
-- apply : (Int -> Int) -> Int -> Int;
-- apply f x = f x;
-- main : Int;
-- main = apply (\x : Int . x + 5) 5
-- answer: 10
-- apply : (Int -> Int -> Int) -> Int -> Int -> Int;
-- apply f x y = f x y;
-- krimp: Int -> Int -> Int;
-- krimp x y = x + y;
-- main : Int;
-- main = apply (krimp) 2 3;
-- answer: 5
-- fibbonaci : Int -> Int;
-- fibbonaci x = case x of {
-- 0 => 0,
-- 1 => 1,
-- -- abusing overflows to represent negatives like a boss
-- _ => (fibbonaci (x - 2))
-- + (fibbonaci (x - 1))
-- } : Int;
-- main : Int;
-- main = fibbonaci 10;
-- answer: 55
-- succ : Int -> Int;
-- succ x = x - 1;
--
-- isZero : Int -> Int;
-- isZero x = case x of {
-- 0 => 1,
-- _ => 0
-- } : Int;
--
-- minimization : (Int -> Int) -> Int -> Int;
-- minimization p x = case p x of {
-- 1 => 0,
-- _ => minimization p (succ x)
-- } : Int;
--
-- main : Int;
-- main = minimization isZero 10;
-- answer: 0
posMul : Int -> Int -> Int;
posMul a b = case b of { posMul a b = case b of {
0 => 0, 0 => 0;
_ => a + posMul a (b - 1) _ => a + posMul a (b - 1)
} : Int; };
facc : Int -> Int; facc : _Int -> _Int;
facc a = case a of { facc a = case a of {
1 => 1, 1 => 1;
_ => posMul a (facc (a - 1)) _ => posMul a (facc (a - 1))
} : Int; };
-- main : Int;
-- main = facc 5
-- answer: 120
-- pow : Int -> Int -> Int; minimization : (_Int -> _Int) -> _Int -> _Int;
-- pow a b = case b of {
-- 0 => 1,
-- _ => posMul a (pow a (b-1))
-- } : Int;
minimization : (Int -> Int) -> Int -> Int;
minimization p x = case p x of { minimization p x = case p x of {
1 => x, 1 => x;
_ => minimization p (x + 1) _ => minimization p (x + 1)
} : Int; };
checkFac : Int -> Int; checkFac : _Int -> _Int;
checkFac x = case facc x of { checkFac x = case facc x of {
0 => 1, 0 => 1;
_ => 0 _ => 0
} : Int; };
main : Int; main : _Int;
main = minimization checkFac 1 main = minimization checkFac 1

View file

@ -1,441 +1,443 @@
{-# LANGUAGE LambdaCase #-} module Codegen.Codegen where
{-# LANGUAGE OverloadedStrings #-} -- {-# LANGUAGE LambdaCase #-}
-- {-# LANGUAGE OverloadedStrings #-}
module Codegen.Codegen (generateCode) where --
-- module Codegen.Codegen (generateCode) where
import Auxiliary (snoc) --
import Codegen.LlvmIr (CallingConvention (..), -- import Auxiliary (snoc)
LLVMComp (..), LLVMIr (..), -- import Codegen.LlvmIr (CallingConvention (..),
LLVMType (..), LLVMValue (..), -- LLVMComp (..), LLVMIr (..),
Visibility (..), llvmIrToString) -- LLVMType (..), LLVMValue (..),
import Control.Monad.State (StateT, execStateT, foldM_, gets, -- Visibility (..), llvmIrToString)
modify) -- import Control.Monad.State (StateT, execStateT, foldM_, gets,
import qualified Data.Bifunctor as BI -- modify)
import Data.List.Extra (trim) -- import qualified Data.Bifunctor as BI
import Data.Map (Map) -- import Data.List.Extra (trim)
import qualified Data.Map as Map -- import Data.Map (Map)
import Data.Tuple.Extra (dupe, first, second) -- import qualified Data.Map as Map
import qualified Grammar.Abs as GA -- import Data.Tuple.Extra (dupe, first, second)
import Grammar.ErrM (Err) -- import qualified Grammar.Abs as GA
import System.Process.Extra (readCreateProcess, shell) -- import Grammar.ErrM (Err)
import TypeChecker.TypeCheckerIr (Bind (..), Case (..), Exp (..), Id, -- import System.Process.Extra (readCreateProcess, shell)
Ident (..), Program (..), Type (..)) -- import TypeChecker.TypeCheckerIr (Bind (..), Case (..), Exp (..), Id,
-- | The record used as the code generator state -- Ident (..), Program (..), Type (..))
data CodeGenerator = CodeGenerator -- -- | The record used as the code generator state
{ instructions :: [LLVMIr] -- data CodeGenerator = CodeGenerator
, functions :: Map Id FunctionInfo -- { instructions :: [LLVMIr]
, constructors :: Map Id ConstructorInfo -- , functions :: Map Id FunctionInfo
, variableCount :: Integer -- , constructors :: Map Id ConstructorInfo
, labelCount :: Integer -- , variableCount :: Integer
} -- , labelCount :: Integer
-- }
-- | A state type synonym --
type CompilerState a = StateT CodeGenerator Err a -- -- | A state type synonym
-- type CompilerState a = StateT CodeGenerator Err a
data FunctionInfo = FunctionInfo --
{ numArgs :: Int -- data FunctionInfo = FunctionInfo
, arguments :: [Id] -- { numArgs :: Int
} -- , arguments :: [Id]
data ConstructorInfo = ConstructorInfo -- }
{ numArgsCI :: Int -- data ConstructorInfo = ConstructorInfo
, argumentsCI :: [Id] -- { numArgsCI :: Int
, numCI :: Integer -- , argumentsCI :: [Id]
} -- , numCI :: Integer
-- }
--
-- | Adds a instruction to the CodeGenerator state --
emit :: LLVMIr -> CompilerState () -- -- | Adds a instruction to the CodeGenerator state
emit l = modify $ \t -> t { instructions = Auxiliary.snoc l $ instructions t } -- emit :: LLVMIr -> CompilerState ()
-- emit l = modify $ \t -> t { instructions = Auxiliary.snoc l $ instructions t }
-- | Increases the variable counter in the CodeGenerator state --
increaseVarCount :: CompilerState () -- -- | Increases the variable counter in the CodeGenerator state
increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 } -- increaseVarCount :: CompilerState ()
-- increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 }
-- | Returns the variable count from the CodeGenerator state --
getVarCount :: CompilerState Integer -- -- | Returns the variable count from the CodeGenerator state
getVarCount = gets variableCount -- getVarCount :: CompilerState Integer
-- getVarCount = gets variableCount
-- | Increases the variable count and returns it from the CodeGenerator state --
getNewVar :: CompilerState Integer -- -- | Increases the variable count and returns it from the CodeGenerator state
getNewVar = increaseVarCount >> getVarCount -- getNewVar :: CompilerState Integer
-- getNewVar = increaseVarCount >> getVarCount
-- | Increses the label count and returns a label from the CodeGenerator state --
getNewLabel :: CompilerState Integer -- -- | Increses the label count and returns a label from the CodeGenerator state
getNewLabel = do -- getNewLabel :: CompilerState Integer
modify (\t -> t{labelCount = labelCount t + 1}) -- getNewLabel = do
gets labelCount -- modify (\t -> t{labelCount = labelCount t + 1})
-- gets labelCount
-- | Produces a map of functions infos from a list of binds, --
-- which contains useful data for code generation. -- -- | Produces a map of functions infos from a list of binds,
getFunctions :: [Bind] -> Map Id FunctionInfo -- -- which contains useful data for code generation.
getFunctions bs = Map.fromList $ go bs -- getFunctions :: [Bind] -> Map Id FunctionInfo
where -- getFunctions bs = Map.fromList $ go bs
go [] = [] -- where
go (Bind id args _ : xs) = -- go [] = []
(id, FunctionInfo { numArgs=length args, arguments=args }) -- go (Bind id args _ : xs) =
: go xs -- (id, FunctionInfo { numArgs=length args, arguments=args })
go (DataStructure n cons : xs) = do -- : go xs
map (\(id, xs) -> ((id, TPol n), FunctionInfo { -- go (DataStructure n cons : xs) = do
numArgs=length xs, arguments=createArgs xs -- map (\(id, xs) -> ((id, TPol n), FunctionInfo {
})) cons -- numArgs=length xs, arguments=createArgs xs
<> go xs -- })) cons
-- <> go xs
createArgs :: [Type] -> [Id] --
createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(Ident ("arg_" <> show l) , t)],l+1)) ([], 0) xs -- createArgs :: [Type] -> [Id]
-- createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(Ident ("arg_" <> show l) , t)],l+1)) ([], 0) xs
-- | Produces a map of functions infos from a list of binds, --
-- which contains useful data for code generation. -- -- | Produces a map of functions infos from a list of binds,
getConstructors :: [Bind] -> Map Id ConstructorInfo -- -- which contains useful data for code generation.
getConstructors bs = Map.fromList $ go bs -- getConstructors :: [Bind] -> Map Id ConstructorInfo
where -- getConstructors bs = Map.fromList $ go bs
go [] = [] -- where
go (DataStructure (Ident n) cons : xs) = do -- go [] = []
fst (foldl (\(acc,i) (Ident id, xs) -> (((Ident (n <> "_" <> id), TPol (Ident n)), ConstructorInfo { -- go (DataStructure (Ident n) cons : xs) = do
numArgsCI=length xs, -- fst (foldl (\(acc,i) (Ident id, xs) -> (((Ident (n <> "_" <> id), TPol (Ident n)), ConstructorInfo {
argumentsCI=createArgs xs, -- numArgsCI=length xs,
numCI=i -- argumentsCI=createArgs xs,
}) : acc, i+1)) ([],0) cons) -- numCI=i
<> go xs -- }) : acc, i+1)) ([],0) cons)
go (_: xs) = go xs -- <> go xs
-- go (_: xs) = go xs
initCodeGenerator :: [Bind] -> CodeGenerator --
initCodeGenerator scs = CodeGenerator { instructions = defaultStart -- initCodeGenerator :: [Bind] -> CodeGenerator
, functions = getFunctions scs -- initCodeGenerator scs = CodeGenerator { instructions = defaultStart
, constructors = getConstructors scs -- , functions = getFunctions scs
, variableCount = 0 -- , constructors = getConstructors scs
, labelCount = 0 -- , variableCount = 0
} -- , labelCount = 0
-- }
run :: Err String -> IO () --
run s = do -- run :: Err String -> IO ()
let s' = case s of -- run s = do
Right s -> s -- let s' = case s of
Left _ -> error "yo" -- Right s -> s
writeFile "output/llvm.ll" s' -- Left _ -> error "yo"
putStrLn . trim =<< readCreateProcess (shell "lli") s' -- writeFile "output/llvm.ll" s'
-- putStrLn . trim =<< readCreateProcess (shell "lli") s'
test :: Integer -> Program --
test v = Program [ -- test :: Integer -> Program
DataStructure (Ident "Craig") [ -- test v = Program [
(Ident "Bob", [TInt])--, -- DataStructure (Ident "Craig") [
--(Ident "Alice", [TInt, TInt]) -- (Ident "Bob", [TInt])--,
], -- --(Ident "Alice", [TInt, TInt])
Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] (EId ("x",TInt)), -- ],
Bind (Ident "main", TInt) [] ( -- Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] (EId ("x",TInt)),
EApp (TPol "Craig") (EId (Ident "Craig_Bob", TPol "Craig")) (EInt v) -- (EInt 92) -- Bind (Ident "main", TInt) [] (
) -- EApp (TPol "Craig") (EId (Ident "Craig_Bob", TPol "Craig")) (EInt v) -- (EInt 92)
] -- )
-- ]
{- | Compiles an AST and produces a LLVM Ir string. --
An easy way to actually "compile" this output is to -- {- | Compiles an AST and produces a LLVM Ir string.
Simply pipe it to LLI -- An easy way to actually "compile" this output is to
-} -- Simply pipe it to LLI
generateCode :: Program -> Err String -- -}
generateCode (Program scs) = do -- generateCode :: Program -> Err String
let codegen = initCodeGenerator scs -- generateCode (Program scs) = do
llvmIrToString . instructions <$> execStateT (compileScs scs) codegen -- let codegen = initCodeGenerator scs
-- llvmIrToString . instructions <$> execStateT (compileScs scs) codegen
compileScs :: [Bind] -> CompilerState () --
compileScs [] = do -- compileScs :: [Bind] -> CompilerState ()
-- as a last step create all the constructors -- compileScs [] = do
c <- gets (Map.toList . constructors) -- -- as a last step create all the constructors
mapM_ (\((id, t), ci) -> do -- c <- gets (Map.toList . constructors)
let t' = type2LlvmType t -- mapM_ (\((id, t), ci) -> do
let x = BI.second type2LlvmType <$> argumentsCI ci -- let t' = type2LlvmType t
emit $ Define FastCC t' id x -- let x = BI.second type2LlvmType <$> argumentsCI ci
top <- Ident . show <$> getNewVar -- emit $ Define FastCC t' id x
ptr <- Ident . show <$> getNewVar -- top <- Ident . show <$> getNewVar
-- allocated the primary type -- ptr <- Ident . show <$> getNewVar
emit $ SetVariable top (Alloca t') -- -- allocated the primary type
-- emit $ SetVariable top (Alloca t')
-- set the first byte to the index of the constructor --
emit $ SetVariable ptr $ -- -- set the first byte to the index of the constructor
GetElementPtrInbounds t' (Ref t') -- emit $ SetVariable ptr $
(VIdent top I8) I32 (VInteger 0) I32 (VInteger 0) -- GetElementPtrInbounds t' (Ref t')
emit $ Store I8 (VInteger $ numCI ci ) (Ref I8) ptr -- (VIdent top I8) I32 (VInteger 0) I32 (VInteger 0)
-- emit $ Store I8 (VInteger $ numCI ci ) (Ref I8) ptr
-- get a pointer of the correct type --
ptr' <- Ident . show <$> getNewVar -- -- get a pointer of the correct type
emit $ SetVariable ptr' (Bitcast (Ref t') ptr (Ref $ CustomType id)) -- ptr' <- Ident . show <$> getNewVar
-- emit $ SetVariable ptr' (Bitcast (Ref t') ptr (Ref $ CustomType id))
--emit $ UnsafeRaw "\n" --
-- --emit $ UnsafeRaw "\n"
foldM_ (\i (Ident arg_n, arg_t)-> do --
let arg_t' = type2LlvmType arg_t -- foldM_ (\i (Ident arg_n, arg_t)-> do
emit $ Comment (show arg_t' <>" "<> arg_n <> " " <> show i ) -- let arg_t' = type2LlvmType arg_t
elemPtr <- Ident . show <$> getNewVar -- emit $ Comment (show arg_t' <>" "<> arg_n <> " " <> show i )
emit $ SetVariable elemPtr ( -- elemPtr <- Ident . show <$> getNewVar
GetElementPtrInbounds (CustomType id) (Ref (CustomType id)) -- emit $ SetVariable elemPtr (
(VIdent ptr' Ptr) I32 -- GetElementPtrInbounds (CustomType id) (Ref (CustomType id))
(VInteger 0) I32 (VInteger i)) -- (VIdent ptr' Ptr) I32
emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr elemPtr -- (VInteger 0) I32 (VInteger i))
-- %2 = getelementptr inbounds %Foo_AInteger, %Foo_AInteger* %1, i32 0, i32 1 -- emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr elemPtr
-- store i32 42, i32* %2 -- -- %2 = getelementptr inbounds %Foo_AInteger, %Foo_AInteger* %1, i32 0, i32 1
pure $ i + 1-- + typeByteSize arg_t' -- -- store i32 42, i32* %2
) 1 (argumentsCI ci) -- pure $ i + 1-- + typeByteSize arg_t'
-- ) 1 (argumentsCI ci)
--emit $ UnsafeRaw "\n" --
-- --emit $ UnsafeRaw "\n"
-- load and return the constructed value --
load <- Ident . show <$> getNewVar -- -- load and return the constructed value
emit $ SetVariable load (Load t' Ptr top) -- load <- Ident . show <$> getNewVar
emit $ Ret t' (VIdent load t') -- emit $ SetVariable load (Load t' Ptr top)
emit DefineEnd -- emit $ Ret t' (VIdent load t')
-- emit DefineEnd
modify $ \s -> s { variableCount = 0 } --
) c -- modify $ \s -> s { variableCount = 0 }
compileScs (Bind (name, _t) args exp : xs) = do -- ) c
emit $ UnsafeRaw "\n" -- compileScs (Bind (name, _t) args exp : xs) = do
emit . Comment $ show name <> ": " <> show exp -- emit $ UnsafeRaw "\n"
let args' = map (second type2LlvmType) args -- emit . Comment $ show name <> ": " <> show exp
emit $ Define FastCC I64 {-(type2LlvmType t_return)-} name args' -- let args' = map (second type2LlvmType) args
functionBody <- exprToValue exp -- emit $ Define FastCC I64 {-(type2LlvmType t_return)-} name args'
if name == "main" -- functionBody <- exprToValue exp
then mapM_ emit $ mainContent functionBody -- if name == "main"
else emit $ Ret I64 functionBody -- then mapM_ emit $ mainContent functionBody
emit DefineEnd -- else emit $ Ret I64 functionBody
modify $ \s -> s { variableCount = 0 } -- emit DefineEnd
compileScs xs -- modify $ \s -> s { variableCount = 0 }
compileScs (DataStructure id@(Ident outer_id) ts : xs) = do -- compileScs xs
let biggest_variant = maximum ((\(_, t) -> sum $ typeByteSize . type2LlvmType <$> t) <$> ts) -- compileScs (DataStructure id@(Ident outer_id) ts : xs) = do
emit $ Type id [I8, Array biggest_variant I8] -- let biggest_variant = maximum ((\(_, t) -> sum $ typeByteSize . type2LlvmType <$> t) <$> ts)
mapM_ (\(Ident inner_id, fi) -> do -- emit $ Type id [I8, Array biggest_variant I8]
emit $ Type (Ident $ outer_id <> "_" <> inner_id) (I8 : map type2LlvmType fi) -- mapM_ (\(Ident inner_id, fi) -> do
) ts -- emit $ Type (Ident $ outer_id <> "_" <> inner_id) (I8 : map type2LlvmType fi)
compileScs xs -- ) ts
-- compileScs xs
-- where --
-- _t_return = snd $ partitionType (length args) t -- -- where
-- -- _t_return = snd $ partitionType (length args) t
mainContent :: LLVMValue -> [LLVMIr] --
mainContent var = -- mainContent :: LLVMValue -> [LLVMIr]
[ UnsafeRaw $ -- mainContent var =
"call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n" -- [ UnsafeRaw $
, -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2)) -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n"
-- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2") -- , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
-- , Label (Ident "b_1") -- -- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2")
-- , UnsafeRaw -- -- , Label (Ident "b_1")
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n" -- -- , UnsafeRaw
-- , Br (Ident "end") -- -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n"
-- , Label (Ident "b_2") -- -- , Br (Ident "end")
-- , UnsafeRaw -- -- , Label (Ident "b_2")
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n" -- -- , UnsafeRaw
-- , Br (Ident "end") -- -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n"
-- , Label (Ident "end") -- -- , Br (Ident "end")
Ret I64 (VInteger 0) -- -- , Label (Ident "end")
] -- Ret I64 (VInteger 0)
-- ]
defaultStart :: [LLVMIr] --
defaultStart = [ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n" -- defaultStart :: [LLVMIr]
, UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n" -- defaultStart = [ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n"
, UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n" -- , UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" -- , UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
] -- , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
-- ]
compileExp :: Exp -> CompilerState () --
compileExp (EInt int) = emitInt int -- compileExp :: Exp -> CompilerState ()
compileExp (EAdd t e1 e2) = emitAdd t e1 e2 -- compileExp (EInt int) = emitInt int
compileExp (ESub t e1 e2) = emitSub t e1 e2 -- compileExp (EAdd t e1 e2) = emitAdd t e1 e2
compileExp (EId (name, _)) = emitIdent name -- compileExp (ESub t e1 e2) = emitSub t e1 e2
compileExp (EApp t e1 e2) = emitApp t e1 e2 -- compileExp (EId (name, _)) = emitIdent name
compileExp (EAbs t ti e) = emitAbs t ti e -- compileExp (EApp t e1 e2) = emitApp t e1 e2
compileExp (ELet binds e) = emitLet binds e -- compileExp (EAbs t ti e) = emitAbs t ti e
compileExp (ECase t e cs) = emitECased t e cs -- compileExp (ELet binds e) = emitLet binds e
-- go (EMul e1 e2) = emitMul e1 e2 -- compileExp (ECase t e cs) = emitECased t e cs
-- go (EDiv e1 e2) = emitDiv e1 e2 -- -- go (EMul e1 e2) = emitMul e1 e2
-- go (EMod e1 e2) = emitMod e1 e2 -- -- go (EDiv e1 e2) = emitDiv e1 e2
-- -- go (EMod e1 e2) = emitMod e1 e2
--- aux functions --- --
emitECased :: Type -> Exp -> [(Type, Case)] -> CompilerState () -- --- aux functions ---
emitECased t e cases = do -- emitECased :: Type -> Exp -> [(Type, Case)] -> CompilerState ()
let cs = snd <$> cases -- emitECased t e cases = do
let ty = type2LlvmType t -- let cs = snd <$> cases
vs <- exprToValue e -- let ty = type2LlvmType t
lbl <- getNewLabel -- vs <- exprToValue e
let label = Ident $ "escape_" <> show lbl -- lbl <- getNewLabel
stackPtr <- getNewVar -- let label = Ident $ "escape_" <> show lbl
emit $ SetVariable (Ident $ show stackPtr) (Alloca ty) -- stackPtr <- getNewVar
mapM_ (emitCases ty label stackPtr vs) cs -- emit $ SetVariable (Ident $ show stackPtr) (Alloca ty)
emit $ Label label -- mapM_ (emitCases ty label stackPtr vs) cs
res <- getNewVar -- emit $ Label label
emit $ SetVariable (Ident $ show res) (Load ty Ptr (Ident $ show stackPtr)) -- res <- getNewVar
where -- emit $ SetVariable (Ident $ show res) (Load ty Ptr (Ident $ show stackPtr))
emitCases :: LLVMType -> Ident -> Integer -> LLVMValue -> Case -> CompilerState () -- where
emitCases ty label stackPtr vs (Case (GA.CInt i) exp) = do -- emitCases :: LLVMType -> Ident -> Integer -> LLVMValue -> Case -> CompilerState ()
ns <- getNewVar -- emitCases ty label stackPtr vs (Case (GA.CInt i) exp) = do
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel -- ns <- getNewVar
lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel -- lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
emit $ SetVariable (Ident $ show ns) (Icmp LLEq ty vs (VInteger i)) -- lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
emit $ BrCond (VIdent (Ident $ show ns) ty) lbl_succPos lbl_failPos -- emit $ SetVariable (Ident $ show ns) (Icmp LLEq ty vs (VInteger i))
emit $ Label lbl_succPos -- emit $ BrCond (VIdent (Ident $ show ns) ty) lbl_succPos lbl_failPos
val <- exprToValue exp -- emit $ Label lbl_succPos
emit $ Store ty val Ptr (Ident . show $ stackPtr) -- val <- exprToValue exp
emit $ Br label -- emit $ Store ty val Ptr (Ident . show $ stackPtr)
emit $ Label lbl_failPos -- emit $ Br label
emitCases ty label stackPtr _ (Case GA.CatchAll exp) = do -- emit $ Label lbl_failPos
val <- exprToValue exp -- emitCases ty label stackPtr _ (Case GA.CatchAll exp) = do
emit $ Store ty val Ptr (Ident . show $ stackPtr) -- val <- exprToValue exp
emit $ Br label -- emit $ Store ty val Ptr (Ident . show $ stackPtr)
-- emit $ Br label
--
emitAbs :: Type -> Id -> Exp -> CompilerState () --
emitAbs _t tid e = do -- emitAbs :: Type -> Id -> Exp -> CompilerState ()
emit . Comment $ -- emitAbs _t tid e = do
"Lambda escaped previous stages: \\" <> show tid <> " . " <> show e -- emit . Comment $
emitLet :: Bind -> Exp -> CompilerState () -- "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e
emitLet xs e = do -- emitLet :: Bind -> Exp -> CompilerState ()
emit $ -- emitLet xs e = do
Comment $ -- emit $
concat -- Comment $
[ "ELet (" -- concat
, show xs -- [ "ELet ("
, " = " -- , show xs
, show e -- , " = "
, ") is not implemented!" -- , show e
] -- , ") is not implemented!"
-- ]
emitApp :: Type -> Exp -> Exp -> CompilerState () --
emitApp t e1 e2 = appEmitter t e1 e2 [] -- emitApp :: Type -> Exp -> Exp -> CompilerState ()
where -- emitApp t e1 e2 = appEmitter t e1 e2 []
appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState () -- where
appEmitter t e1 e2 stack = do -- appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState ()
let newStack = e2 : stack -- appEmitter t e1 e2 stack = do
case e1 of -- let newStack = e2 : stack
EApp _ e1' e2' -> appEmitter t e1' e2' newStack -- case e1 of
EId id@(name, _) -> do -- EApp _ e1' e2' -> appEmitter t e1' e2' newStack
args <- traverse exprToValue newStack -- EId id@(name, _) -> do
vs <- getNewVar -- args <- traverse exprToValue newStack
funcs <- gets functions -- vs <- getNewVar
let visibility = maybe Local (const Global) $ Map.lookup id funcs -- funcs <- gets functions
args' = map (first valueGetType . dupe) args -- let visibility = maybe Local (const Global) $ Map.lookup id funcs
call = Call FastCC (type2LlvmType t) visibility name args' -- args' = map (first valueGetType . dupe) args
emit $ SetVariable (Ident $ show vs) call -- call = Call FastCC (type2LlvmType t) visibility name args'
x -> do -- emit $ SetVariable (Ident $ show vs) call
emit . Comment $ "The unspeakable happened: " -- x -> do
emit . Comment $ show x -- emit . Comment $ "The unspeakable happened: "
-- emit . Comment $ show x
emitIdent :: Ident -> CompilerState () --
emitIdent id = do -- emitIdent :: Ident -> CompilerState ()
-- !!this should never happen!! -- emitIdent id = do
emit $ Comment "This should not have happened!" -- -- !!this should never happen!!
emit $ Variable id -- emit $ Comment "This should not have happened!"
emit $ UnsafeRaw "\n" -- emit $ Variable id
-- emit $ UnsafeRaw "\n"
emitInt :: Integer -> CompilerState () --
emitInt i = do -- emitInt :: Integer -> CompilerState ()
-- !!this should never happen!! -- emitInt i = do
varCount <- getNewVar -- -- !!this should never happen!!
emit $ Comment "This should not have happened!" -- varCount <- getNewVar
emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0)) -- emit $ Comment "This should not have happened!"
-- emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0))
emitAdd :: Type -> Exp -> Exp -> CompilerState () --
emitAdd t e1 e2 = do -- emitAdd :: Type -> Exp -> Exp -> CompilerState ()
v1 <- exprToValue e1 -- emitAdd t e1 e2 = do
v2 <- exprToValue e2 -- v1 <- exprToValue e1
v <- getNewVar -- v2 <- exprToValue e2
emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2) -- v <- getNewVar
-- emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2)
emitSub :: Type -> Exp -> Exp -> CompilerState () --
emitSub t e1 e2 = do -- emitSub :: Type -> Exp -> Exp -> CompilerState ()
v1 <- exprToValue e1 -- emitSub t e1 e2 = do
v2 <- exprToValue e2 -- v1 <- exprToValue e1
v <- getNewVar -- v2 <- exprToValue e2
emit $ SetVariable (Ident $ show v) (Sub (type2LlvmType t) v1 v2) -- v <- getNewVar
-- emit $ SetVariable (Ident $ show v) (Sub (type2LlvmType t) v1 v2)
-- emitMul :: Exp -> Exp -> CompilerState () --
-- emitMul e1 e2 = do -- -- emitMul :: Exp -> Exp -> CompilerState ()
-- (v1,v2) <- binExprToValues e1 e2 -- -- emitMul e1 e2 = do
-- increaseVarCount -- -- (v1,v2) <- binExprToValues e1 e2
-- v <- gets variableCount -- -- increaseVarCount
-- emit $ SetVariable $ Ident $ show v -- -- v <- gets variableCount
-- emit $ Mul I64 v1 v2 -- -- emit $ SetVariable $ Ident $ show v
-- -- emit $ Mul I64 v1 v2
-- emitMod :: Exp -> Exp -> CompilerState () --
-- emitMod e1 e2 = do -- -- emitMod :: Exp -> Exp -> CompilerState ()
-- -- `let m a b = rem (abs $ b + a) b` -- -- emitMod e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2 -- -- -- `let m a b = rem (abs $ b + a) b`
-- increaseVarCount -- -- (v1,v2) <- binExprToValues e1 e2
-- vadd <- gets variableCount -- -- increaseVarCount
-- emit $ SetVariable $ Ident $ show vadd -- -- vadd <- gets variableCount
-- emit $ Add I64 v1 v2 -- -- emit $ SetVariable $ Ident $ show vadd
-- -- emit $ Add I64 v1 v2
-- --
-- -- increaseVarCount
-- -- vabs <- gets variableCount
-- -- emit $ SetVariable $ Ident $ show vabs
-- -- emit $ Call I64 (Ident "llvm.abs.i64")
-- -- [ (I64, VIdent (Ident $ show vadd))
-- -- , (I1, VInteger 1)
-- -- ]
-- -- increaseVarCount
-- -- v <- gets variableCount
-- -- emit $ SetVariable $ Ident $ show v
-- -- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2
--
-- -- emitDiv :: Exp -> Exp -> CompilerState ()
-- -- emitDiv e1 e2 = do
-- -- (v1,v2) <- binExprToValues e1 e2
-- -- increaseVarCount
-- -- v <- gets variableCount
-- -- emit $ SetVariable $ Ident $ show v
-- -- emit $ Div I64 v1 v2
--
-- exprToValue :: Exp -> CompilerState LLVMValue
-- exprToValue = \case
-- EInt i -> pure $ VInteger i
--
-- EId id@(name, t) -> do
-- funcs <- gets functions
-- case Map.lookup id funcs of
-- Just fi -> do
-- if numArgs fi == 0
-- then do
-- vc <- getNewVar
-- emit $ SetVariable (Ident $ show vc)
-- (Call FastCC (type2LlvmType t) Global name [])
-- pure $ VIdent (Ident $ show vc) (type2LlvmType t)
-- else pure $ VFunction name Global (type2LlvmType t)
-- Nothing -> pure $ VIdent name (type2LlvmType t)
--
-- e -> do
-- compileExp e
-- v <- getVarCount
-- pure $ VIdent (Ident $ show v) (getType e)
--
-- type2LlvmType :: Type -> LLVMType
-- type2LlvmType = \case
-- TInt -> I64
-- TFun t xs -> do
-- let (t', xs') = function2LLVMType xs [type2LlvmType t]
-- Function t' xs'
-- TPol t -> CustomType t
-- where
-- function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
-- function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s)
-- function2LLVMType x s = (type2LlvmType x, s)
--
-- getType :: Exp -> LLVMType
-- getType (EInt _) = I64
-- getType (EAdd t _ _) = type2LlvmType t
-- getType (ESub t _ _) = type2LlvmType t
-- getType (EId (_, t)) = type2LlvmType t
-- getType (EApp t _ _) = type2LlvmType t
-- getType (EAbs t _ _) = type2LlvmType t
-- getType (ELet _ e) = getType e
-- getType (ECase t _ _) = type2LlvmType t
--
-- valueGetType :: LLVMValue -> LLVMType
-- valueGetType (VInteger _) = I64
-- valueGetType (VIdent _ t) = t
-- valueGetType (VConstant s) = Array (fromIntegral $ length s) I8
-- valueGetType (VFunction _ _ t) = t
--
-- typeByteSize :: LLVMType -> Integer
-- typeByteSize I1 = 1
-- typeByteSize I8 = 1
-- typeByteSize I32 = 4
-- typeByteSize I64 = 8
-- typeByteSize Ptr = 8
-- typeByteSize (Ref _) = 8
-- typeByteSize (Function _ _) = 8
-- typeByteSize (Array n t) = n * typeByteSize t
-- typeByteSize (CustomType _) = 8
-- --
-- increaseVarCount
-- vabs <- gets variableCount
-- emit $ SetVariable $ Ident $ show vabs
-- emit $ Call I64 (Ident "llvm.abs.i64")
-- [ (I64, VIdent (Ident $ show vadd))
-- , (I1, VInteger 1)
-- ]
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2
-- emitDiv :: Exp -> Exp -> CompilerState ()
-- emitDiv e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Div I64 v1 v2
exprToValue :: Exp -> CompilerState LLVMValue
exprToValue = \case
EInt i -> pure $ VInteger i
EId id@(name, t) -> do
funcs <- gets functions
case Map.lookup id funcs of
Just fi -> do
if numArgs fi == 0
then do
vc <- getNewVar
emit $ SetVariable (Ident $ show vc)
(Call FastCC (type2LlvmType t) Global name [])
pure $ VIdent (Ident $ show vc) (type2LlvmType t)
else pure $ VFunction name Global (type2LlvmType t)
Nothing -> pure $ VIdent name (type2LlvmType t)
e -> do
compileExp e
v <- getVarCount
pure $ VIdent (Ident $ show v) (getType e)
type2LlvmType :: Type -> LLVMType
type2LlvmType = \case
TInt -> I64
TFun t xs -> do
let (t', xs') = function2LLVMType xs [type2LlvmType t]
Function t' xs'
TPol t -> CustomType t
where
function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s)
function2LLVMType x s = (type2LlvmType x, s)
getType :: Exp -> LLVMType
getType (EInt _) = I64
getType (EAdd t _ _) = type2LlvmType t
getType (ESub t _ _) = type2LlvmType t
getType (EId (_, t)) = type2LlvmType t
getType (EApp t _ _) = type2LlvmType t
getType (EAbs t _ _) = type2LlvmType t
getType (ELet _ e) = getType e
getType (ECase t _ _) = type2LlvmType t
valueGetType :: LLVMValue -> LLVMType
valueGetType (VInteger _) = I64
valueGetType (VIdent _ t) = t
valueGetType (VConstant s) = Array (fromIntegral $ length s) I8
valueGetType (VFunction _ _ t) = t
typeByteSize :: LLVMType -> Integer
typeByteSize I1 = 1
typeByteSize I8 = 1
typeByteSize I32 = 4
typeByteSize I64 = 8
typeByteSize Ptr = 8
typeByteSize (Ref _) = 8
typeByteSize (Function _ _) = 8
typeByteSize (Array n t) = n * typeByteSize t
typeByteSize (CustomType _) = 8

View file

@ -1,239 +1,241 @@
{-# LANGUAGE LambdaCase #-} module Codegen.LlvmIr where
-- {-# LANGUAGE LambdaCase #-}
module Codegen.LlvmIr ( --
LLVMType (..), -- module Codegen.LlvmIr (
LLVMIr (..), -- LLVMType (..),
llvmIrToString, -- LLVMIr (..),
LLVMValue (..), -- llvmIrToString,
LLVMComp (..), -- LLVMValue (..),
Visibility (..), -- LLVMComp (..),
CallingConvention (..) -- Visibility (..),
) where -- CallingConvention (..)
-- ) where
import Data.List (intercalate) --
import TypeChecker.TypeCheckerIr -- import Data.List (intercalate)
-- import TypeChecker.TypeCheckerIr
data CallingConvention = TailCC | FastCC | CCC | ColdCC --
instance Show CallingConvention where -- data CallingConvention = TailCC | FastCC | CCC | ColdCC
show :: CallingConvention -> String -- instance Show CallingConvention where
show TailCC = "tailcc" -- show :: CallingConvention -> String
show FastCC = "fastcc" -- show TailCC = "tailcc"
show CCC = "ccc" -- show FastCC = "fastcc"
show ColdCC = "coldcc" -- show CCC = "ccc"
-- show ColdCC = "coldcc"
-- | A datatype which represents some basic LLVM types --
data LLVMType -- -- | A datatype which represents some basic LLVM types
= I1 -- data LLVMType
| I8 -- = I1
| I32 -- | I8
| I64 -- | I32
| Ptr -- | I64
| Ref LLVMType -- | Ptr
| Function LLVMType [LLVMType] -- | Ref LLVMType
| Array Integer LLVMType -- | Function LLVMType [LLVMType]
| CustomType Ident -- | Array Integer LLVMType
-- | CustomType Ident
instance Show LLVMType where --
show :: LLVMType -> String -- instance Show LLVMType where
show = \case -- show :: LLVMType -> String
I1 -> "i1" -- show = \case
I8 -> "i8" -- I1 -> "i1"
I32 -> "i32" -- I8 -> "i8"
I64 -> "i64" -- I32 -> "i32"
Ptr -> "ptr" -- I64 -> "i64"
Ref ty -> show ty <> "*" -- Ptr -> "ptr"
Function t xs -> show t <> " (" <> intercalate ", " (map show xs) <> ")*" -- Ref ty -> show ty <> "*"
Array n ty -> concat ["[", show n, " x ", show ty, "]"] -- Function t xs -> show t <> " (" <> intercalate ", " (map show xs) <> ")*"
CustomType (Ident ty) -> "%" <> ty -- Array n ty -> concat ["[", show n, " x ", show ty, "]"]
-- CustomType (Ident ty) -> "%" <> ty
data LLVMComp --
= LLEq -- data LLVMComp
| LLNe -- = LLEq
| LLUgt -- | LLNe
| LLUge -- | LLUgt
| LLUlt -- | LLUge
| LLUle -- | LLUlt
| LLSgt -- | LLUle
| LLSge -- | LLSgt
| LLSlt -- | LLSge
| LLSle -- | LLSlt
instance Show LLVMComp where -- | LLSle
show :: LLVMComp -> String -- instance Show LLVMComp where
show = \case -- show :: LLVMComp -> String
LLEq -> "eq" -- show = \case
LLNe -> "ne" -- LLEq -> "eq"
LLUgt -> "ugt" -- LLNe -> "ne"
LLUge -> "uge" -- LLUgt -> "ugt"
LLUlt -> "ult" -- LLUge -> "uge"
LLUle -> "ule" -- LLUlt -> "ult"
LLSgt -> "sgt" -- LLUle -> "ule"
LLSge -> "sge" -- LLSgt -> "sgt"
LLSlt -> "slt" -- LLSge -> "sge"
LLSle -> "sle" -- LLSlt -> "slt"
-- LLSle -> "sle"
data Visibility = Local | Global --
instance Show Visibility where -- data Visibility = Local | Global
show :: Visibility -> String -- instance Show Visibility where
show Local = "%" -- show :: Visibility -> String
show Global = "@" -- show Local = "%"
-- show Global = "@"
-- | Represents a LLVM "value", as in an integer, a register variable, --
-- or a string contstant -- -- | Represents a LLVM "value", as in an integer, a register variable,
data LLVMValue -- -- or a string contstant
= VInteger Integer -- data LLVMValue
| VIdent Ident LLVMType -- = VInteger Integer
| VConstant String -- | VIdent Ident LLVMType
| VFunction Ident Visibility LLVMType -- | VConstant String
-- | VFunction Ident Visibility LLVMType
instance Show LLVMValue where --
show :: LLVMValue -> String -- instance Show LLVMValue where
show v = case v of -- show :: LLVMValue -> String
VInteger i -> show i -- show v = case v of
VIdent (Ident n) _ -> "%" <> n -- VInteger i -> show i
VFunction (Ident n) vis _ -> show vis <> n -- VIdent (Ident n) _ -> "%" <> n
VConstant s -> "c" <> show s -- VFunction (Ident n) vis _ -> show vis <> n
-- VConstant s -> "c" <> show s
type Params = [(Ident, LLVMType)] --
type Args = [(LLVMType, LLVMValue)] -- type Params = [(Ident, LLVMType)]
-- type Args = [(LLVMType, LLVMValue)]
-- | A datatype which represents different instructions in LLVM --
data LLVMIr -- -- | A datatype which represents different instructions in LLVM
= Type Ident [LLVMType] -- data LLVMIr
| Define CallingConvention LLVMType Ident Params -- = Type Ident [LLVMType]
| DefineEnd -- | Define CallingConvention LLVMType Ident Params
| Declare LLVMType Ident Params -- | DefineEnd
| SetVariable Ident LLVMIr -- | Declare LLVMType Ident Params
| Variable Ident -- | SetVariable Ident LLVMIr
| GetElementPtrInbounds LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue -- | Variable Ident
| Add LLVMType LLVMValue LLVMValue -- | GetElementPtrInbounds LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue
| Sub LLVMType LLVMValue LLVMValue -- | Add LLVMType LLVMValue LLVMValue
| Div LLVMType LLVMValue LLVMValue -- | Sub LLVMType LLVMValue LLVMValue
| Mul LLVMType LLVMValue LLVMValue -- | Div LLVMType LLVMValue LLVMValue
| Srem LLVMType LLVMValue LLVMValue -- | Mul LLVMType LLVMValue LLVMValue
| Icmp LLVMComp LLVMType LLVMValue LLVMValue -- | Srem LLVMType LLVMValue LLVMValue
| Br Ident -- | Icmp LLVMComp LLVMType LLVMValue LLVMValue
| BrCond LLVMValue Ident Ident -- | Br Ident
| Label Ident -- | BrCond LLVMValue Ident Ident
| Call CallingConvention LLVMType Visibility Ident Args -- | Label Ident
| Alloca LLVMType -- | Call CallingConvention LLVMType Visibility Ident Args
| Store LLVMType LLVMValue LLVMType Ident -- | Alloca LLVMType
| Load LLVMType LLVMType Ident -- | Store LLVMType LLVMValue LLVMType Ident
| Bitcast LLVMType Ident LLVMType -- | Load LLVMType LLVMType Ident
| Ret LLVMType LLVMValue -- | Bitcast LLVMType Ident LLVMType
| Comment String -- | Ret LLVMType LLVMValue
| UnsafeRaw String -- This should generally be avoided, and proper -- | Comment String
-- instructions should be used in its place -- | UnsafeRaw String -- This should generally be avoided, and proper
deriving (Show) -- -- instructions should be used in its place
-- deriving (Show)
-- | Converts a list of LLVMIr instructions to a string --
llvmIrToString :: [LLVMIr] -> String -- -- | Converts a list of LLVMIr instructions to a string
llvmIrToString = go 0 -- llvmIrToString :: [LLVMIr] -> String
where -- llvmIrToString = go 0
go :: Int -> [LLVMIr] -> String -- where
go _ [] = mempty -- go :: Int -> [LLVMIr] -> String
go i (x : xs) = do -- go _ [] = mempty
let (i', n) = case x of -- go i (x : xs) = do
Define{} -> (i + 1, 0) -- let (i', n) = case x of
DefineEnd -> (i - 1, 0) -- Define{} -> (i + 1, 0)
_ -> (i, i) -- DefineEnd -> (i - 1, 0)
insToString n x <> go i' xs -- _ -> (i, i)
-- insToString n x <> go i' xs
{- | Converts a LLVM inststruction to a String, allowing for printing etc. --
The integer represents the indentation -- {- | Converts a LLVM inststruction to a String, allowing for printing etc.
-} -- The integer represents the indentation
{- FOURMOLU_DISABLE -} -- -}
insToString :: Int -> LLVMIr -> String -- {- FOURMOLU_DISABLE -}
insToString i l = -- insToString :: Int -> LLVMIr -> String
replicate i '\t' <> case l of -- insToString i l =
(GetElementPtrInbounds t1 t2 p t3 v1 t4 v2) -> do -- replicate i '\t' <> case l of
-- getelementptr inbounds %Foo, %Foo* %x, i32 0, i32 0 -- (GetElementPtrInbounds t1 t2 p t3 v1 t4 v2) -> do
concat -- -- getelementptr inbounds %Foo, %Foo* %x, i32 0, i32 0
[ "getelementptr inbounds ", show t1, ", " , show t2 -- concat
, " ", show p, ", ", show t3, " ", show v1, -- [ "getelementptr inbounds ", show t1, ", " , show t2
", ", show t4, " ", show v2, "\n" ] -- , " ", show p, ", ", show t3, " ", show v1,
(Type (Ident n) types) -> -- ", ", show t4, " ", show v2, "\n" ]
concat -- (Type (Ident n) types) ->
[ "%", n, " = type { " -- concat
, intercalate ", " (map show types) -- [ "%", n, " = type { "
, " }\n" -- , intercalate ", " (map show types)
] -- , " }\n"
(Define c t (Ident i) params) -> -- ]
concat -- (Define c t (Ident i) params) ->
[ "define ", show c, " ", show t, " @", i -- concat
, "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params) -- [ "define ", show c, " ", show t, " @", i
, ") {\n" -- , "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params)
] -- , ") {\n"
DefineEnd -> "}\n" -- ]
(Declare _t (Ident _i) _params) -> undefined -- DefineEnd -> "}\n"
(SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir] -- (Declare _t (Ident _i) _params) -> undefined
(Add t v1 v2) -> -- (SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir]
concat -- (Add t v1 v2) ->
[ "add ", show t, " ", show v1 -- concat
, ", ", show v2, "\n" -- [ "add ", show t, " ", show v1
] -- , ", ", show v2, "\n"
(Sub t v1 v2) -> -- ]
concat -- (Sub t v1 v2) ->
[ "sub ", show t, " ", show v1, ", " -- concat
, show v2, "\n" -- [ "sub ", show t, " ", show v1, ", "
] -- , show v2, "\n"
(Div t v1 v2) -> -- ]
concat -- (Div t v1 v2) ->
[ "sdiv ", show t, " ", show v1, ", " -- concat
, show v2, "\n" -- [ "sdiv ", show t, " ", show v1, ", "
] -- , show v2, "\n"
(Mul t v1 v2) -> -- ]
concat -- (Mul t v1 v2) ->
[ "mul ", show t, " ", show v1 -- concat
, ", ", show v2, "\n" -- [ "mul ", show t, " ", show v1
] -- , ", ", show v2, "\n"
(Srem t v1 v2) -> -- ]
concat -- (Srem t v1 v2) ->
[ "srem ", show t, " ", show v1, ", " -- concat
, show v2, "\n" -- [ "srem ", show t, " ", show v1, ", "
] -- , show v2, "\n"
(Call c t vis (Ident i) arg) -> -- ]
concat -- (Call c t vis (Ident i) arg) ->
[ "call ", show c, " ", show t, " ", show vis, i, "(" -- concat
, intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg -- [ "call ", show c, " ", show t, " ", show vis, i, "("
, ")\n" -- , intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg
] -- , ")\n"
(Alloca t) -> unwords ["alloca", show t, "\n"] -- ]
(Store t1 val t2 (Ident id2)) -> -- (Alloca t) -> unwords ["alloca", show t, "\n"]
concat -- (Store t1 val t2 (Ident id2)) ->
[ "store ", show t1, " ", show val -- concat
, ", ", show t2 , " %", id2, "\n" -- [ "store ", show t1, " ", show val
] -- , ", ", show t2 , " %", id2, "\n"
(Load t1 t2 (Ident addr)) -> -- ]
concat -- (Load t1 t2 (Ident addr)) ->
[ "load ", show t1, ", " -- concat
, show t2, " %", addr, "\n" -- [ "load ", show t1, ", "
] -- , show t2, " %", addr, "\n"
(Bitcast t1 (Ident i) t2) -> -- ]
concat -- (Bitcast t1 (Ident i) t2) ->
[ "bitcast ", show t1, " %" -- concat
, i, " to ", show t2, "\n" -- [ "bitcast ", show t1, " %"
] -- , i, " to ", show t2, "\n"
(Icmp comp t v1 v2) -> -- ]
concat -- (Icmp comp t v1 v2) ->
[ "icmp ", show comp, " ", show t -- concat
, " ", show v1, ", ", show v2, "\n" -- [ "icmp ", show comp, " ", show t
] -- , " ", show v1, ", ", show v2, "\n"
(Ret t v) -> -- ]
concat -- (Ret t v) ->
[ "ret ", show t, " " -- concat
, show v, "\n" -- [ "ret ", show t, " "
] -- , show v, "\n"
(UnsafeRaw s) -> s -- ]
(Label (Ident s)) -> "\n" <> lblPfx <> s <> ":\n" -- (UnsafeRaw s) -> s
(Br (Ident s)) -> "br label %" <> lblPfx <> s <> "\n" -- (Label (Ident s)) -> "\n" <> lblPfx <> s <> ":\n"
(BrCond val (Ident s1) (Ident s2)) -> -- (Br (Ident s)) -> "br label %" <> lblPfx <> s <> "\n"
concat -- (BrCond val (Ident s1) (Ident s2)) ->
[ "br i1 ", show val, ", ", "label %" -- concat
, lblPfx, s1, ", ", "label %", lblPfx, s2, "\n" -- [ "br i1 ", show val, ", ", "label %"
] -- , lblPfx, s1, ", ", "label %", lblPfx, s2, "\n"
(Comment s) -> "; " <> s <> "\n" -- ]
(Variable (Ident id)) -> "%" <> id -- (Comment s) -> "; " <> s <> "\n"
{- FOURMOLU_ENABLE -} -- (Variable (Ident id)) -> "%" <> id
-- {- FOURMOLU_ENABLE -}
lblPfx :: String --
lblPfx = "lbl_" -- lblPfx :: String
-- lblPfx = "lbl_"
--

View file

@ -1,235 +1,192 @@
{-# LANGUAGE LambdaCase #-} --{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-} --{-# LANGUAGE OverloadedStrings #-}
module LambdaLifter.LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where module LambdaLifter.LambdaLifter where
import Auxiliary (snoc) --import Auxiliary (snoc)
import Control.Applicative (Applicative (liftA2)) --import Control.Applicative (Applicative (liftA2))
import Control.Monad.State (MonadState (get, put), State, --import Control.Monad.State (MonadState (get, put), State,
evalState) -- evalState)
import Data.Set (Set) --import Data.Set (Set)
import qualified Data.Set as Set --import qualified Data.Set as Set
import Debug.Trace (trace) --import Prelude hiding (exp)
import qualified Grammar.Abs as GA --import Renamer.Renamer
import Prelude hiding (exp) --import TypeChecker.TypeCheckerIr
import Renamer.Renamer
import TypeChecker.TypeCheckerIr
-- | Lift lambdas and let expression into supercombinators. ---- | Lift lambdas and let expression into supercombinators.
-- Three phases: ---- Three phases:
-- @freeVars@ annotatss all the free variables. ---- @freeVars@ annotatss all the free variables.
-- @abstract@ converts lambdas into let expressions. ---- @abstract@ converts lambdas into let expressions.
-- @collectScs@ moves every non-constant let expression to a top-level function. ---- @collectScs@ moves every non-constant let expression to a top-level function.
lambdaLift :: Program -> Program --lambdaLift :: Program -> Program
lambdaLift = collectScs . abstract . freeVars --lambdaLift = collectScs . abstract . freeVars
-- | Annotate free variables
freeVars :: Program -> AnnProgram
freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e)
| Bind n xs e <- ds
]
freeVarsExp :: Set Id -> Exp -> AnnExp
freeVarsExp localVars = \case
EId n | Set.member n localVars -> (Set.singleton n, AId n)
| otherwise -> (mempty, AId n)
EInt i -> (mempty, AInt i)
EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2')
where
e1' = freeVarsExp localVars e1
e2' = freeVarsExp localVars e2
EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2')
where
e1' = freeVarsExp localVars e1
e2' = freeVarsExp localVars e2
ESub t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), ASub t e1' e2')
where
e1' = freeVarsExp localVars e1
e2' = freeVarsExp localVars e2
EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e') ---- | Annotate free variables
where --freeVars :: Program -> AnnProgram
e' = freeVarsExp (Set.insert par localVars) e --freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e)
-- | Bind n xs e <- ds
-- ]
-- Sum free variables present in bind and the expression --freeVarsExp :: Set Id -> Exp -> AnnExp
ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e') --freeVarsExp localVars = \case
where -- EId n | Set.member n localVars -> (Set.singleton n, AId n)
binders_frees = Set.delete name $ freeVarsOf rhs' -- | otherwise -> (mempty, AId n)
e_free = Set.delete name $ freeVarsOf e'
rhs' = freeVarsExp e_localVars rhs -- ELit _ (LInt i) -> (mempty, AInt i)
new_bind = ABind name parms rhs'
e' = freeVarsExp e_localVars e -- EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2')
e_localVars = Set.insert name localVars -- where
-- e1' = freeVarsExp localVars e1
-- e2' = freeVarsExp localVars e2
(ECase t e cs) -> do -- EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2')
let e' = freeVarsExp localVars e -- where
let vars = freeVarsOf e' -- e1' = freeVarsExp localVars e1
let (vars', cs') = foldr (\(_, Case c e) (vars,acc) -> do -- e2' = freeVarsExp localVars e2
let e' = freeVarsExp vars e
let vars' = freeVarsOf e' -- EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e')
(Set.union vars vars', AnnCase c e' : acc) -- where
) (vars, []) cs -- e' = freeVarsExp (Set.insert par localVars) e
(vars', ACase t e' (reverse cs'))
-- -- Sum free variables present in bind and the expression
-- ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e')
-- where
-- binders_frees = Set.delete name $ freeVarsOf rhs'
-- e_free = Set.delete name $ freeVarsOf e'
-- rhs' = freeVarsExp e_localVars rhs
-- new_bind = ABind name parms rhs'
-- e' = freeVarsExp e_localVars e
-- e_localVars = Set.insert name localVars
freeVarsOf :: AnnExp -> Set Id --freeVarsOf :: AnnExp -> Set Id
freeVarsOf = fst --freeVarsOf = fst
-- AST annotated with free variables ---- AST annotated with free variables
type AnnProgram = [(Id, [Id], AnnExp)] --type AnnProgram = [(Id, [Id], AnnExp)]
type AnnExp = (Set Id, AnnExp') --type AnnExp = (Set Id, AnnExp')
data ABind = ABind Id [Id] AnnExp deriving Show --data ABind = ABind Id [Id] AnnExp deriving Show
data AnnExp' = AId Id --data AnnExp' = AId Id
| AInt Integer -- | AInt Integer
| ALet ABind AnnExp -- | ALet ABind AnnExp
| AApp Type AnnExp AnnExp -- | AApp Type AnnExp AnnExp
| AAdd Type AnnExp AnnExp -- | AAdd Type AnnExp AnnExp
| ASub Type AnnExp AnnExp -- | AAbs Type Id AnnExp
| AAbs Type Id AnnExp -- deriving Show
| ACase Type AnnExp [AnnCase] ---- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
deriving Show ---- Free variables are @v₁ v₂ .. vₙ@ are bound.
data AnnCase = AnnCase GA.Case AnnExp --abstract :: AnnProgram -> Program
deriving Show --abstract prog = Program $ evalState (mapM go prog) 0
-- where
-- go :: (Id, [Id], AnnExp) -> State Int Bind
-- go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs'
-- where
-- (rhs', parms1) = flattenLambdasAnn rhs
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@. ---- | Flatten nested lambdas and collect the parameters
-- Free variables are @v₁ v₂ .. vₙ@ are bound. ---- @\x.\y.\z. ae → (ae, [x,y,z])@
abstract :: AnnProgram -> Program --flattenLambdasAnn :: AnnExp -> (AnnExp, [Id])
abstract prog = Program $ evalState (mapM go prog) 0 --flattenLambdasAnn ae = go (ae, [])
where -- where
go :: (Id, [Id], AnnExp) -> State Int Bind -- go :: (AnnExp, [Id]) -> (AnnExp, [Id])
go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs' -- go ((free, e), acc) =
where -- case e of
(rhs', parms1) = flattenLambdasAnn rhs -- AAbs _ par (free1, e1) ->
-- go ((Set.delete par free1, e1), snoc par acc)
-- _ -> ((free, e), acc)
--abstractExp :: AnnExp -> State Int Exp
--abstractExp (free, exp) = case exp of
-- AId n -> pure $ EId n
-- AInt i -> pure $ ELit (TMono "Int") (LInt i)
-- AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2)
-- AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2)
-- ALet b e -> liftA2 ELet (go b) (abstractExp e)
-- where
-- go (ABind name parms rhs) = do
-- (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
-- pure $ Bind name (parms ++ parms1) rhs'
-- skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp
-- skipLambdas f (free, ae) = case ae of
-- AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1
-- _ -> f (free, ae)
-- -- Lift lambda into let and bind free variables
-- AAbs t parm e -> do
-- i <- nextNumber
-- rhs <- abstractExp e
-- let sc_name = Ident ("sc_" ++ show i)
-- sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t)
-- pure $ foldl (EApp $ TMono "Int") sc $ map EId freeList
-- where
-- freeList = Set.toList free
-- parms = snoc parm freeList
-- | Flatten nested lambdas and collect the parameters --nextNumber :: State Int Int
-- @\x.\y.\z. ae → (ae, [x,y,z])@ --nextNumber = do
flattenLambdasAnn :: AnnExp -> (AnnExp, [Id]) -- i <- get
flattenLambdasAnn ae = go (ae, []) -- put $ succ i
where -- pure i
go :: (AnnExp, [Id]) -> (AnnExp, [Id])
go ((free, e), acc) =
case e of
AAbs _ par (free1, e1) ->
go ((Set.delete par free1, e1), snoc par acc)
_ -> ((free, e), acc)
abstractExp :: AnnExp -> State Int Exp ---- | Collects supercombinators by lifting non-constant let expressions
abstractExp (free, exp) = case exp of --collectScs :: Program -> Program
AId n -> pure $ EId n --collectScs (Program scs) = Program $ concatMap collectFromRhs scs
AInt i -> pure $ EInt i -- where
AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2) -- collectFromRhs (Bind name parms rhs) =
AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2) -- let (rhs_scs, rhs') = collectScsExp rhs
ASub t e1 e2 -> liftA2 (ESub t) (abstractExp e1) (abstractExp e2) -- in Bind name parms rhs' : rhs_scs
ALet b e -> liftA2 ELet (go b) (abstractExp e)
where
go (ABind name parms rhs) = do
(rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
pure $ Bind name (parms ++ parms1) rhs'
skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp
skipLambdas f (free, ae) = case ae of
AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1
_ -> f (free, ae)
ACase t e cs -> do --collectScsExp :: Exp -> ([Bind], Exp)
e' <- abstractExp e --collectScsExp = \case
cs' <- mapM (\(AnnCase c e) -> do -- EId n -> ([], EId n)
e' <- abstractExp e -- ELit _ (LInt i) -> ([], ELit (TMono "Int") (LInt i))
pure (t,Case c e')) cs
pure $ ECase t e' cs'
-- Lift lambda into let and bind free variables -- EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2')
AAbs t parm e -> do -- where
i <- nextNumber -- (scs1, e1') = collectScsExp e1
rhs <- abstractExp e -- (scs2, e2') = collectScsExp e2
let sc_name = Ident ("sc_" ++ show i) -- EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2')
sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t) -- where
-- (scs1, e1') = collectScsExp e1
-- (scs2, e2') = collectScsExp e2
pure $ foldl (EApp TInt) sc $ map EId freeList -- EAbs t par e -> (scs, EAbs t par e')
where -- where
freeList = Set.toList free -- (scs, e') = collectScsExp e
parms = snoc parm freeList
-- -- Collect supercombinators from bind, the rhss, and the expression.
-- --
-- -- > f = let sc x y = rhs in e
-- --
-- ELet (Bind name parms rhs) e -> if null parms
-- then ( rhs_scs ++ e_scs, ELet bind e')
-- else (bind : rhs_scs ++ e_scs, e')
-- where
-- bind = Bind name parms rhs'
-- (rhs_scs, rhs') = collectScsExp rhs
-- (e_scs, e') = collectScsExp e
nextNumber :: State Int Int ---- @\x.\y.\z. e → (e, [x,y,z])@
nextNumber = do --flattenLambdas :: Exp -> (Exp, [Id])
i <- get --flattenLambdas = go . (, [])
put $ succ i -- where
pure i -- go (e, acc) = case e of
-- EAbs _ par e1 -> go (e1, snoc par acc)
-- _ -> (e, acc)
-- | Collects supercombinators by lifting non-constant let expressions
collectScs :: Program -> Program
collectScs (Program scs) = Program $ concatMap collectFromRhs scs
where
collectFromRhs (Bind name parms rhs) =
let (rhs_scs, rhs') = collectScsExp rhs
in Bind name parms rhs' : rhs_scs
collectScsExp :: Exp -> ([Bind], Exp)
collectScsExp = \case
EId n -> ([], EId n)
EInt i -> ([], EInt i)
EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2')
where
(scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2
EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2')
where
(scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2
ESub t e1 e2 -> (scs1 ++ scs2, ESub t e1' e2')
where
(scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2
EAbs t par e -> (scs, EAbs t par e')
where
(scs, e') = collectScsExp e
-- Collect supercombinators from bind, the rhss, and the expression.
--
-- > f = let sc x y = rhs in e
--
ELet (Bind name parms rhs) e -> if null parms
then ( rhs_scs ++ e_scs, ELet bind e')
else (bind : rhs_scs ++ e_scs, e')
where
bind = Bind name parms rhs'
(rhs_scs, rhs') = collectScsExp rhs
(e_scs, e') = collectScsExp e
ECase t e cs -> do
let (scs, e') = collectScsExp e
let (scs',cs') = foldr (\(t, Case c e) (scs, acc) -> do
let (scs', e') = collectScsExp e
(scs ++ scs', (t,Case c e') : acc)
) (scs,[]) cs
(scs', ECase t e' cs')
-- @\x.\y.\z. e → (e, [x,y,z])@
flattenLambdas :: Exp -> (Exp, [Id])
flattenLambdas = go . (, [])
where
go (e, acc) = case e of
EAbs _ par e1 -> go (e1, snoc par acc)
_ -> (e, acc)

View file

@ -2,26 +2,26 @@
module Main where module Main where
import Codegen.Codegen (generateCode) --import Codegen.Codegen (generateCode)
import GHC.IO.Handle.Text (hPutStrLn) import GHC.IO.Handle.Text (hPutStrLn)
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
import Grammar.Par (myLexer, pProgram) import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree) import Grammar.Print (printTree)
-- import Interpreter (interpret) -- import Interpreter (interpret)
import Control.Monad (when) import Control.Monad (when)
import Data.List.Extra (isSuffixOf) import Data.List.Extra (isSuffixOf)
import LambdaLifter.LambdaLifter (lambdaLift) --import LambdaLifter.LambdaLifter (lambdaLift)
import Renamer.Renamer (rename) import Renamer.Renamer (rename)
import System.Directory (createDirectory, doesPathExist, import System.Directory (createDirectory, doesPathExist,
getDirectoryContents, getDirectoryContents,
removeDirectoryRecursive, removeDirectoryRecursive,
setCurrentDirectory) setCurrentDirectory)
import System.Environment (getArgs) import System.Environment (getArgs)
import System.Exit (exitFailure, exitSuccess) import System.Exit (exitFailure, exitSuccess)
import System.IO (stderr) import System.IO (stderr)
import System.Process.Extra (spawnCommand, waitForProcess) import System.Process.Extra (spawnCommand, waitForProcess)
import TypeChecker.TypeChecker (typecheck) import TypeChecker.TypeChecker (typecheck)
main :: IO () main :: IO ()
main = main =
@ -46,19 +46,19 @@ main' debug s = do
typechecked <- fromTypeCheckerErr $ typecheck renamed typechecked <- fromTypeCheckerErr $ typecheck renamed
printToErr $ printTree typechecked printToErr $ printTree typechecked
printToErr "\n-- Lambda Lifter --" -- printToErr "\n-- Lambda Lifter --"
let lifted = lambdaLift typechecked -- let lifted = lambdaLift typechecked
printToErr $ printTree lifted -- printToErr $ printTree lifted
--
printToErr "\n -- Printing compiler output to stdout --" -- printToErr "\n -- Printing compiler output to stdout --"
compiled <- fromCompilerErr $ generateCode lifted -- compiled <- fromCompilerErr $ generateCode lifted
--putStrLn compiled --putStrLn compiled
check <- doesPathExist "output" -- check <- doesPathExist "output"
when check (removeDirectoryRecursive "output") -- when check (removeDirectoryRecursive "output")
createDirectory "output" -- createDirectory "output"
writeFile "output/llvm.ll" compiled -- writeFile "output/llvm.ll" compiled
if debug then debugDotViz else putStrLn compiled -- if debug then debugDotViz else putStrLn compiled
-- interpred <- fromInterpreterErr $ interpret lifted -- interpred <- fromInterpreterErr $ interpret lifted

View file

@ -1,86 +1,87 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
module Renamer.Renamer (module Renamer.Renamer) where module Renamer.Renamer where
import Auxiliary (mapAccumM) import Auxiliary (mapAccumM)
import Control.Monad (foldM)
import Control.Monad.State (MonadState, State, evalState, gets, import Control.Monad.State (MonadState, State, evalState, gets,
modify) modify)
import Data.List (foldl')
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as Map import qualified Data.Map as Map
import Data.Maybe (fromMaybe) import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe) import Data.Tuple.Extra (dupe)
import Grammar.Abs import Grammar.Abs
-- | Rename all variables and local binds -- | Rename all variables and local binds
rename :: Program -> Program rename :: Program -> Program
rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0 rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0
where where
initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs -- initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs
renameSc :: Names -> Bind -> Rn Bind initNames = Map.fromList $ foldl' saveIfBind [] bs
renameSc old_names (Bind name t _ parms rhs) = do saveIfBind acc (DBind (Bind name _ _ _ _)) = dupe name : acc
saveIfBind acc _ = acc
renameSc :: Names -> Def -> Rn Def
renameSc old_names (DBind (Bind name t _ parms rhs)) = do
(new_names, parms') <- newNames old_names parms (new_names, parms') <- newNames old_names parms
rhs' <- snd <$> renameExp new_names rhs rhs' <- snd <$> renameExp new_names rhs
pure $ Bind name t name parms' rhs' pure . DBind $ Bind name t name parms' rhs'
renameSc _ def = pure def
-- | Rename monad. State holds the number of renamed names. -- | Rename monad. State holds the number of renamed names.
newtype Rn a = Rn { runRn :: State Int a } newtype Rn a = Rn {runRn :: State Int a}
deriving (Functor, Applicative, Monad, MonadState Int) deriving (Functor, Applicative, Monad, MonadState Int)
-- | Maps old to new name -- | Maps old to new name
type Names = Map Ident Ident type Names = Map Ident Ident
renameLocalBind :: Names -> Bind -> Rn (Names, Bind) renameLocalBind :: Names -> Bind -> Rn (Names, Bind)
renameLocalBind old_names (Bind name t _ parms rhs) = do renameLocalBind old_names (Bind name t _ parms rhs) = do
(new_names, name') <- newName old_names name (new_names, name') <- newName old_names name
(new_names', parms') <- newNames new_names parms (new_names', parms') <- newNames new_names parms
(new_names'', rhs') <- renameExp new_names' rhs (new_names'', rhs') <- renameExp new_names' rhs
pure (new_names'', Bind name' t name' parms' rhs') pure (new_names'', Bind name' t name' parms' rhs')
renameExp :: Names -> Exp -> Rn (Names, Exp) renameExp :: Names -> Exp -> Rn (Names, Exp)
renameExp old_names = \case renameExp old_names = \case
EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names) EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names)
ELit (LInt i1) -> pure (old_names, ELit (LInt i1))
EInt i1 -> pure (old_names, EInt i1)
EApp e1 e2 -> do EApp e1 e2 -> do
(env1, e1') <- renameExp old_names e1 (env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2 (env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EApp e1' e2') pure (Map.union env1 env2, EApp e1' e2')
EAdd e1 e2 -> do EAdd e1 e2 -> do
(env1, e1') <- renameExp old_names e1 (env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2 (env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EAdd e1' e2') pure (Map.union env1 env2, EAdd e1' e2')
ESub e1 e2 -> do ESub e1 e2 -> do
(env1, e1') <- renameExp old_names e1 (env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2 (env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, ESub e1' e2') pure (Map.union env1 env2, ESub e1' e2')
ELet i e1 e2 -> do
ELet b e -> do (new_names, e1') <- renameExp old_names e1
(new_names, b) <- renameLocalBind old_names b (new_names', e2') <- renameExp new_names e2
(new_names', e') <- renameExp new_names e pure (new_names', ELet i e1' e2')
pure (new_names', ELet b e') EAbs par e -> do
EAbs par t e -> do
(new_names, par') <- newName old_names par (new_names, par') <- newName old_names par
(new_names', e') <- renameExp new_names e (new_names', e') <- renameExp new_names e
pure (new_names', EAbs par' t e') pure (new_names', EAbs par' e')
EAnn e t -> do EAnn e t -> do
(new_names, e') <- renameExp old_names e (new_names, e') <- renameExp old_names e
pure (new_names, EAnn e' t) pure (new_names, EAnn e' t)
ECase e injs -> do
(_, e') <- renameExp old_names e
(new_names, injs') <- renameInjs old_names injs
pure (new_names, ECase e' injs')
ECase e cs t -> do renameInjs :: Names -> [Inj] -> Rn (Names, [Inj])
(new_names, e') <- renameExp old_names e renameInjs ns xs = do
(new_names', cs') <- foldM (\(names, stack) (CaseMatch c exp) -> do (new_names, xs') <- unzip <$> mapM (renameInj ns) xs
(nm,exp') <- renameExp names exp if null new_names then return (mempty, xs') else return (head new_names, xs')
pure (nm,CaseMatch c exp' : stack)
) (new_names, []) cs renameInj :: Names -> Inj -> Rn (Names, Inj)
pure (new_names', ECase e' cs' t) renameInj ns (Inj init e) = do
(new_names, e') <- renameExp ns e
return (new_names, Inj init e')
-- | Create a new name and add it to name environment. -- | Create a new name and add it to name environment.
newName :: Names -> Ident -> Rn (Names, Ident) newName :: Names -> Ident -> Rn (Names, Ident)
@ -95,4 +96,3 @@ newNames = mapAccumM newName
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@. -- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
makeName :: Ident -> Rn Ident makeName :: Ident -> Rn Ident
makeName (Ident prefix) = gets (\i -> Ident $ prefix ++ "_" ++ show i) <* modify succ makeName (Ident prefix) = gets (\i -> Ident $ prefix ++ "_" ++ show i) <* modify succ

View file

@ -1,215 +1,517 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-} {-# LANGUAGE OverloadedStrings #-}
module TypeChecker.TypeChecker (typecheck, partitionType) where -- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeChecker where
import Auxiliary (maybeToRightM, snoc) import Control.Monad.Except
import Control.Monad.Except (throwError, unless) import Control.Monad.Reader
import Control.Monad.State
import Data.Foldable (traverse_)
import Data.Functor.Identity (runIdentity)
import Data.List (foldl')
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as Map import qualified Data.Map as M
import Data.Set (Set)
import qualified Data.Set as S
import Debug.Trace (trace)
import Grammar.Abs import Grammar.Abs
import Grammar.ErrM (Err) import Grammar.Print (printTree)
import Grammar.Print (Print (prt), concatD, doc,
printTree, render)
import Prelude hiding (exp, id)
import qualified TypeChecker.TypeCheckerIr as T import qualified TypeChecker.TypeCheckerIr as T
import TypeChecker.TypeCheckerIr (Ctx (..), Env (..), Error, Infer,
Poly (..), Subst)
-- NOTE: this type checker is poorly tested initCtx = Ctx mempty
-- TODO initEnv = Env 0 mempty mempty
-- Coercion
-- Type inference
data Cxt = Cxt runPretty :: Exp -> Either Error String
{ env :: Map Ident Type -- ^ Local scope signature runPretty = fmap (printTree . fst) . run . inferExp
, sig :: Map Ident Type -- ^ Top-level signatures
}
initCxt :: [Bind] -> Cxt run :: Infer a -> Either Error a
initCxt sc = Cxt { env = mempty run = runC initEnv initCtx
, sig = Map.fromList $ map (\(Bind n t _ _ _) -> (n, t)) sc
}
typecheck :: Program -> Err T.Program runC :: Env -> Ctx -> Infer a -> Either Error a
typecheck (Program sc) = T.Program <$> mapM (checkBind $ initCxt sc) sc runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e
-- | Check if infered rhs type matches type signature. typecheck :: Program -> Either Error T.Program
checkBind :: Cxt -> Bind -> Err T.Bind typecheck = run . checkPrg
checkBind cxt b =
case expandLambdas b of
Bind name t _ parms rhs -> do
(rhs', t_rhs) <- infer cxt rhs
unless (typeEq t_rhs t) . throwError $ typeErr name t t_rhs
pure $ T.Bind (name, t) (zip parms ts_parms) rhs'
where
ts_parms = fst $ partitionType (length parms) t
-- | @ f x y = rhs ⇒ f = \x.\y. rhs @ {- | Start by freshening the type variable of data types to avoid clash with
expandLambdas :: Bind -> Bind other user defined polymorphic types
expandLambdas (Bind name t _ parms rhs) = Bind name t name [] rhs' This might be wrong for type constructors that work over several variables
-}
freshenData :: Data -> Infer Data
freshenData (Data (Constr name ts) constrs) = do
fr <- fresh
let fr' = case fr of
TPol a -> a
-- Meh, this part assumes fresh generates a polymorphic type
_ ->
error
"Bug: implementation of \
\ fresh and freshenData are not compatible"
let new_ts = map (freshenType fr') ts
let new_constrs = map (freshenConstr fr') constrs
return $ Data (Constr name new_ts) new_constrs
{- | Freshen all polymorphic variables, regardless of name
| freshenType "d" (a -> b -> c) becomes (d -> d -> d)
-}
freshenType :: Ident -> Type -> Type
freshenType iden = \case
(TPol _) -> TPol iden
(TArr a b) -> TArr (freshenType iden a) (freshenType iden b)
(TConstr (Constr a ts)) ->
TConstr (Constr a (map (freshenType iden) ts))
rest -> rest
freshenConstr :: Ident -> Constructor -> Constructor
freshenConstr iden (Constructor name t) =
Constructor name (freshenType iden t)
checkData :: Data -> Infer ()
checkData d = do
d' <- freshenData d
case d' of
(Data typ@(Constr name ts) constrs) -> do
unless
(all isPoly ts)
(throwError $ unwords ["Data type incorrectly declared"])
traverse_
( \(Constructor name' t') ->
if TConstr typ == retType t'
then insertConstr name' t'
else
throwError $
unwords
[ "return type of constructor:"
, printTree name
, "with type:"
, printTree (retType t')
, "does not match data: "
, printTree typ
]
)
constrs
retType :: Type -> Type
retType (TArr _ t2) = retType t2
retType a = a
checkPrg :: Program -> Infer T.Program
checkPrg (Program bs) = do
preRun bs
T.Program <$> checkDef bs
where where
rhs' = foldr ($) rhs $ zipWith EAbs parms ts_parms preRun :: [Def] -> Infer ()
ts_parms = fst $ partitionType (length parms) t preRun [] = return ()
preRun (x : xs) = case x of
DBind (Bind n t _ _ _) -> insertSig n t >> preRun xs
DData d@(Data _ _) -> checkData d >> preRun xs
-- | Infer type of expression. checkDef :: [Def] -> Infer [T.Def]
infer :: Cxt -> Exp -> Err (T.Exp, Type) checkDef [] = return []
infer cxt = \case checkDef (x : xs) = case x of
EId x -> (DBind b) -> do
case lookupEnv x cxt of b' <- checkBind b
Nothing -> fmap (T.DBind b' :) (checkDef xs)
case lookupSig x cxt of (DData d) -> fmap (T.DData d :) (checkDef xs)
Nothing -> throwError ("Unbound variable:" ++ printTree x)
Just t -> pure (T.EId (x, t), t)
Just t -> pure (T.EId (x, t), t)
EInt i -> pure (T.EInt i, T.TInt) checkBind :: Bind -> Infer T.Bind
checkBind (Bind n t _ args e) = do
(t', e') <- inferExp $ makeLambda e (reverse args)
s <- unify t t'
let t'' = apply s t
unless
(t `typeEq` t'')
( throwError $
unwords
[ "Top level signature"
, printTree t
, "does not match body with inferred type:"
, printTree t''
]
)
return $ T.Bind (n, t) e'
where
makeLambda :: Exp -> [Ident] -> Exp
makeLambda = foldl (flip EAbs)
EApp e e1 -> do {- | Check if two types are considered equal
(e', t) <- infer cxt e For the purpose of the algorithm two polymorphic types are always considered
case t of equal
TFun t1 t2 -> do -}
e1' <- check cxt e1 t1
pure (T.EApp t2 e' e1', t2)
_ -> do
throwError ("Not a function: " ++ show e)
EAdd e e1 -> do
e' <- check cxt e T.TInt
e1' <- check cxt e1 T.TInt
pure (T.EAdd T.TInt e' e1', T.TInt)
ESub e e1 -> do
e' <- check cxt e T.TInt
e1' <- check cxt e1 T.TInt
pure (T.ESub T.TInt e' e1', T.TInt)
EAbs x t e -> do
(e', t1) <- infer (insertEnv x t cxt) e
let t_abs = TFun t t1
pure (T.EAbs t_abs (x, t) e', t_abs)
ELet b e -> do
let cxt' = insertBind b cxt
b' <- checkBind cxt' b
(e', t) <- infer cxt' e
pure (T.ELet b' e', t)
EAnn e t -> do
(e', t1) <- infer cxt e
unless (typeEq t t1) $
throwError "Inferred type and type annotation doesn't match"
pure (e', t1)
ECase e cs t -> do
(e',t1) <- infer cxt e
unless (typeEq t t1) $
throwError "Inferred type and type annotation doesn't match"
case traverse (\(CaseMatch c e) -> do
-- //TODO check c as well
e' <- check cxt e t
unless (typeEq t t1) $
throwError "Inferred type and type annotation doesn't match"
pure (t1, T.Case c e')
) cs of
Right cs -> pure (T.ECase t1 e' cs,t1)
Left e -> throwError e
-- | Check infered type matches the supplied type.
check :: Cxt -> Exp -> Type -> Err T.Exp
check cxt exp typ = case exp of
EId x -> do
t <- case lookupEnv x cxt of
Nothing -> maybeToRightM
("Unbound variable:" ++ printTree x)
(lookupSig x cxt)
Just t -> pure t
unless (typeEq t typ) . throwError $ typeErr x typ t
pure $ T.EId (x, t)
EInt i -> do
unless (typeEq typ TInt) $ throwError $ typeErr i TInt typ
pure $ T.EInt i
EApp e e1 -> do
(e', t) <- infer cxt e
case t of
TFun t1 t2 -> do
e1' <- check cxt e1 t1
pure $ T.EApp t2 e' e1'
_ -> throwError ("Not a function 2: " ++ printTree e)
EAdd e e1 -> do
e' <- check cxt e T.TInt
e1' <- check cxt e1 T.TInt
pure $ T.EAdd T.TInt e' e1'
ESub e e1 -> do
e' <- check cxt e T.TInt
e1' <- check cxt e1 T.TInt
pure $ T.ESub T.TInt e' e1'
EAbs x t e -> do
(e', t_e) <- infer (insertEnv x t cxt) e
let t1 = TFun t t_e
unless (typeEq t1 typ) $ throwError "Wrong lamda type!"
pure $ T.EAbs t1 (x, t) e'
ECase e cs t -> do
(e',t1) <- infer cxt e
unless (typeEq t t1) $
throwError "Inferred type and type annotation doesn't match"
case traverse (\(CaseMatch c e) -> do
-- //TODO check c as well
e' <- check cxt e t
unless (typeEq t t1) $
throwError "Inferred type and type annotation doesn't match"
pure (t1, T.Case c e')
) cs of
Right cs -> pure $ T.ECase t1 e' cs
Left e -> throwError e
ELet b e -> do
let cxt' = insertBind b cxt
b' <- checkBind cxt' b
e' <- check cxt' e typ
pure $ T.ELet b' e'
EAnn e t -> do
unless (typeEq t typ) $
throwError "Inferred type and type annotation doesn't match"
check cxt e t
-- | Check if types are equivalent. Doesn't handle coercion or polymorphism.
typeEq :: Type -> Type -> Bool typeEq :: Type -> Type -> Bool
typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1 typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r'
typeEq t t1 = t == t1 typeEq (TMono a) (TMono b) = a == b
typeEq (TConstr (Constr name a)) (TConstr (Constr name' b)) =
length a == length b
&& name == name'
&& and (zipWith typeEq a b)
typeEq (TPol _) (TPol _) = True
typeEq _ _ = False
-- | Partion type into types of parameters and return type. isMoreSpecificOrEq :: Type -> Type -> Bool
partitionType :: Int -- Number of parameters to apply isMoreSpecificOrEq _ (TPol _) = True
-> Type isMoreSpecificOrEq (TArr a b) (TArr c d) =
-> ([Type], Type) isMoreSpecificOrEq a c && isMoreSpecificOrEq b d
partitionType = go [] isMoreSpecificOrEq (TConstr (Constr n1 ts1)) (TConstr (Constr n2 ts2)) =
where n1 == n2
go acc 0 t = (acc, t) && length ts1 == length ts2
go acc i t = case t of && and (zipWith isMoreSpecificOrEq ts1 ts2)
TFun t1 t2 -> go (snoc t1 acc) (i - 1) t2 isMoreSpecificOrEq a b = a == b
_ -> error "Number of parameters and type doesn't match"
insertBind :: Bind -> Cxt -> Cxt isPoly :: Type -> Bool
insertBind (Bind n t _ _ _) = insertEnv n t isPoly (TPol _) = True
isPoly _ = False
lookupEnv :: Ident -> Cxt -> Maybe Type inferExp :: Exp -> Infer (Type, T.Exp)
lookupEnv x = Map.lookup x . env inferExp e = do
(s, t, e') <- algoW e
let subbed = apply s t
return (subbed, replace subbed e')
insertEnv :: Ident -> Type -> Cxt -> Cxt replace :: Type -> T.Exp -> T.Exp
insertEnv x t cxt = cxt { env = Map.insert x t cxt.env } replace t = \case
T.ELit _ e -> T.ELit t e
T.EId (n, _) -> T.EId (n, t)
T.EAbs _ name e -> T.EAbs t name e
T.EApp _ e1 e2 -> T.EApp t e1 e2
T.EAdd _ e1 e2 -> T.EAdd t e1 e2
T.ESub _ e1 e2 -> T.ESub t e1 e2
T.ELet (T.Bind (n, _) e1) e2 -> T.ELet (T.Bind (n, t) e1) e2
T.ECase _ expr injs -> T.ECase t expr injs
lookupSig :: Ident -> Cxt -> Maybe Type algoW :: Exp -> Infer (Subst, Type, T.Exp)
lookupSig x = Map.lookup x . sig algoW = \case
-- \| TODO: More testing need to be done. Unsure of the correctness of this
EAnn e t -> do
(s1, t', e') <- algoW e
unless
(t `isMoreSpecificOrEq` t')
( throwError $
unwords
[ "Annotated type:"
, printTree t
, "does not match inferred type:"
, printTree t'
]
)
applySt s1 $ do
s2 <- unify t t'
return (s2 `compose` s1, t, e')
typeErr :: Print a => a -> Type -> Type -> String -- \| ------------------
typeErr p expected actual = render $ concatD -- \| Γ ⊢ i : Int, ∅
[ doc $ showString "Wrong type:", prt 0 p , doc $ showString "\n"
, doc $ showString "Expected:" , prt 0 expected, doc $ showString "\n" ELit (LInt n) ->
, doc $ showString "Actual: " , prt 0 actual return (nullSubst, TMono "Int", T.ELit (TMono "Int") (LInt n))
] ELit a -> error $ "NOT IMPLEMENTED YET: ELit " ++ show a
-- \| x : σ ∈ Γ τ = inst(σ)
-- \| ----------------------
-- \| Γ ⊢ x : τ, ∅
EId i -> do
var <- asks vars
case M.lookup i var of
Just t -> inst t >>= \x -> return (nullSubst, x, T.EId (i, x))
Nothing -> do
sig <- gets sigs
case M.lookup i sig of
Just t -> return (nullSubst, t, T.EId (i, t))
Nothing -> do
constr <- gets constructors
case M.lookup i constr of
Just t -> return (nullSubst, t, T.EId (i, t))
Nothing ->
throwError $
"Unbound variable: " ++ show i
-- \| τ = newvar Γ, x : τ ⊢ e : τ', S
-- \| ---------------------------------
-- \| Γ ⊢ w λx. e : Sτ → τ', S
EAbs name e -> do
fr <- fresh
withBinding name (Forall [] fr) $ do
(s1, t', e') <- algoW e
let varType = apply s1 fr
let newArr = TArr varType t'
return (s1, newArr, T.EAbs newArr (name, varType) e')
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
-- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
-- \| ------------------------------------------
-- \| Γ ⊢ e₀ + e₁ : Int, S₃S₂S₁S₀
-- This might be wrong
EAdd e0 e1 -> do
(s1, t0, e0') <- algoW e0
applySt s1 $ do
(s2, t1, e1') <- algoW e1
-- applySt s2 $ do
s3 <- unify (apply s2 t0) (TMono "Int")
s4 <- unify (apply s3 t1) (TMono "Int")
return
( s4 `compose` s3 `compose` s2 `compose` s1
, TMono "Int"
, T.EAdd (TMono "Int") e0' e1'
)
ESub e0 e1 -> do
(s1, t0, e0') <- algoW e0
applySt s1 $ do
(s2, t1, e1') <- algoW e1
-- applySt s2 $ do
s3 <- unify (apply s2 t0) (TMono "Int")
s4 <- unify (apply s3 t1) (TMono "Int")
return
( s4 `compose` s3 `compose` s2 `compose` s1
, TMono "Int"
, T.ESub (TMono "Int") e0' e1'
)
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1
-- \| τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ')
-- \| --------------------------------------
-- \| Γ ⊢ e₀ e₁ : S₂τ', S₂S₁S₀
EApp e0 e1 -> do
fr <- fresh
(s0, t0, e0') <- algoW e0
applySt s0 $ do
(s1, t1, e1') <- algoW e1
-- applySt s1 $ do
s2 <- unify (apply s1 t0) (TArr t1 fr)
let t = apply s2 fr
return (s2 `compose` s1 `compose` s0, t, T.EApp t e0' e1')
-- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁
-- \| ----------------------------------------------
-- \| Γ ⊢ let x = e₀ in e₁ : τ', S₁S₀
-- The bar over S₀ and Γ means "generalize"
ELet name e0 e1 -> do
(s1, t1, e0') <- algoW e0
env <- asks vars
let t' = generalize (apply s1 env) t1
withBinding name t' $ do
(s2, t2, e1') <- algoW e1
return (s2 `compose` s1, t2, T.ELet (T.Bind (name, t2) e0') e1')
ECase caseExpr injs -> do
(_, t0, e0') <- algoW caseExpr
(injs', ts) <- mapAndUnzipM (checkInj t0) injs
case ts of
[] -> throwError "Case expression missing any matches"
ts -> do
unified <- zipWithM unify ts (tail ts)
let unified' = foldl' compose mempty unified
let typ = apply unified' (head ts)
return (unified', typ, T.ECase typ e0' injs')
-- | Unify two types producing a new substitution
unify :: Type -> Type -> Infer Subst
unify t0 t1 = do
trace ("t0: " ++ show t0) return ()
trace ("t1: " ++ show t1) return ()
case (t0, t1) of
(TArr a b, TArr c d) -> do
s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d)
return $ s1 `compose` s2
(TPol a, b) -> occurs a b
(a, TPol b) -> occurs b a
(TMono a, TMono b) ->
if a == b then return M.empty else throwError "Types do not unify"
-- \| TODO: Figure out a cleaner way to express the same thing
(TConstr (Constr name t), TConstr (Constr name' t')) ->
if name == name' && length t == length t'
then do
xs <- zipWithM unify t t'
return $ foldr compose nullSubst xs
else
throwError $
unwords
[ "Type constructor:"
, printTree name
, "(" ++ printTree t ++ ")"
, "does not match with:"
, printTree name'
, "(" ++ printTree t' ++ ")"
]
(a, b) ->
throwError . unwords $
[ "Type:"
, printTree a
, "can't be unified with:"
, printTree b
]
{- | Check if a type is contained in another type.
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
such that these are equal
-}
occurs :: Ident -> Type -> Infer Subst
occurs _ (TPol _) = return nullSubst
occurs i t =
if S.member i (free t)
then
throwError $
unwords
[ "Occurs check failed, can't unify"
, printTree (TPol i)
, "with"
, printTree t
]
else return $ M.singleton i t
-- | Generalize a type over all free variables in the substitution set
generalize :: Map Ident Poly -> Type -> Poly
generalize env t = Forall (S.toList $ free t S.\\ free env) t
{- | Instantiate a polymorphic type. The free type variables are substituted
with fresh ones.
-}
inst :: Poly -> Infer Type
inst (Forall xs t) = do
xs' <- mapM (const fresh) xs
let s = M.fromList $ zip xs xs'
return $ apply s t
-- | Compose two substitution sets
compose :: Subst -> Subst -> Subst
compose m1 m2 = M.map (apply m1) m2 `M.union` m1
-- | A class representing free variables functions
class FreeVars t where
-- | Get all free variables from t
free :: t -> Set Ident
-- | Apply a substitution to t
apply :: Subst -> t -> t
instance FreeVars Type where
free :: Type -> Set Ident
free (TPol a) = S.singleton a
free (TMono _) = mempty
free (TArr a b) = free a `S.union` free b
-- \| Not guaranteed to be correct
free (TConstr (Constr _ a)) =
foldl' (\acc x -> free x `S.union` acc) S.empty a
apply :: Subst -> Type -> Type
apply sub t = do
case t of
TMono a -> TMono a
TPol a -> case M.lookup a sub of
Nothing -> TPol a
Just t -> t
TArr a b -> TArr (apply sub a) (apply sub b)
TConstr (Constr name a) -> TConstr (Constr name (map (apply sub) a))
instance FreeVars Poly where
free :: Poly -> Set Ident
free (Forall xs t) = free t S.\\ S.fromList xs
apply :: Subst -> Poly -> Poly
apply s (Forall xs t) = Forall xs (apply (foldr M.delete s xs) t)
instance FreeVars (Map Ident Poly) where
free :: Map Ident Poly -> Set Ident
free m = foldl' S.union S.empty (map free $ M.elems m)
apply :: Subst -> Map Ident Poly -> Map Ident Poly
apply s = M.map (apply s)
-- | Apply substitutions to the environment.
applySt :: Subst -> Infer a -> Infer a
applySt s = local (\st -> st{vars = apply s (vars st)})
-- | Represents the empty substition set
nullSubst :: Subst
nullSubst = M.empty
-- | Generate a new fresh variable and increment the state counter
fresh :: Infer Type
fresh = do
n <- gets count
modify (\st -> st{count = n + 1})
return . TPol . Ident $ show n
-- | Run the monadic action with an additional binding
withBinding :: (Monad m, MonadReader Ctx m) => Ident -> Poly -> m a -> m a
withBinding i p = local (\st -> st{vars = M.insert i p (vars st)})
-- | Insert a function signature into the environment
insertSig :: Ident -> Type -> Infer ()
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
-- | Insert a constructor with its data type
insertConstr :: Ident -> Type -> Infer ()
insertConstr i t =
modify (\st -> st{constructors = M.insert i t (constructors st)})
-------- PATTERN MATCHING ---------
-- "case expr of", the type of 'expr' is caseType
checkInj :: Type -> Inj -> Infer (T.Inj, Type)
checkInj caseType (Inj it expr) = do
(args, t') <- initType caseType it
(_, t, e') <- local (\st -> st{vars = args `M.union` vars st}) (algoW expr)
return (T.Inj (it, t') e', t)
initType :: Type -> Init -> Infer (Map Ident Poly, Type)
initType expected = \case
InitLit lit ->
let returnType = litType lit
in if expected == returnType
then return (mempty, expected)
else
throwError $
unwords
[ "Inferred type"
, printTree returnType
, "does not match expected type:"
, printTree expected
]
InitConstr c args -> do
st <- gets constructors
case M.lookup c st of
Nothing ->
throwError $
unwords
[ "Constructor:"
, printTree c
, "does not exist"
]
Just t -> do
let flat = flattenType t
let returnType = last flat
case ( length (init flat) == length args
, returnType `isMoreSpecificOrEq` expected
) of
(True, True) ->
return
( M.fromList $ zip args (map (Forall []) flat)
, expected
)
(False, _) ->
throwError $
"Can't partially match on the constructor: "
++ printTree c
(_, False) ->
throwError $
unwords
[ "Inferred type"
, printTree returnType
, "does not match expected type:"
, printTree expected
]
InitCatch -> return (mempty, expected)
flattenType :: Type -> [Type]
flattenType (TArr a b) = flattenType a ++ flattenType b
flattenType a = [a]
litType :: Literal -> Type
litType (LInt _) = TMono "Int"

View file

@ -1,139 +1,184 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
module TypeChecker.TypeCheckerIr module TypeChecker.TypeCheckerIr where
( module Grammar.Abs
, module TypeChecker.TypeCheckerIr
) where
import Grammar.Abs (Ident (..), Type (..)) import Control.Monad.Except
import qualified Grammar.Abs as GA import Control.Monad.Reader
import Control.Monad.State
import Data.Functor.Identity (Identity)
import Data.Map (Map)
import Grammar.Abs (Data (..), Ident (..), Init (..),
Literal (..), Type (..))
import Grammar.Print import Grammar.Print
import Prelude import Prelude
import qualified Prelude as C (Eq, Ord, Read, Show) import qualified Prelude as C (Eq, Ord, Read, Show)
newtype Program = Program [Bind] -- | A data type representing type variables
deriving (C.Eq, C.Ord, C.Show, C.Read) data Poly = Forall [Ident] Type
deriving (Show)
newtype Ctx = Ctx {vars :: Map Ident Poly}
data Env = Env
{ count :: Int
, sigs :: Map Ident Type
, constructors :: Map Ident Type
}
type Error = String
type Subst = Map Ident Type
type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))
newtype Program = Program [Def]
deriving (C.Eq, C.Ord, C.Show, C.Read)
data Exp data Exp
= EId Id = EId Id
| EInt Integer | ELit Type Literal
| ELet Bind Exp | ELet Bind Exp
| EApp Type Exp Exp | EApp Type Exp Exp
| EAdd Type Exp Exp | EAdd Type Exp Exp
| ESub Type Exp Exp | ESub Type Exp Exp
| EAbs Type Id Exp | EAbs Type Id Exp
| ECase Type Exp [(Type, Case)] | ECase Type Exp [Inj]
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Read, C.Show)
data Case = Case GA.Case Exp data Inj = Inj (Init, Type) Exp
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Read, C.Show)
data Def = DBind Bind | DData Data
deriving (C.Eq, C.Ord, C.Read, C.Show)
type Id = (Ident, Type) type Id = (Ident, Type)
data Bind = Bind Id [Id] Exp | DataStructure Ident [(Ident, [Type])] data Bind = Bind Id Exp
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
instance Print [Def] where
prt _ [] = concatD []
prt _ (x : xs) = concatD [prt 0 x, doc (showString "\n"), prt 0 xs]
instance Print Def where
prt i (DBind bind) = prt i bind
prt i (DData d) = prt i d
instance Print Program where instance Print Program where
prt i (Program sc) = prPrec i 0 $ prt 0 sc prt i (Program sc) = prPrec i 0 $ prt 0 sc
instance Print Bind where instance Print Bind where
prt i (Bind name@(n, _) parms rhs) = prPrec i 0 $ concatD prt i (Bind (t, name) rhs) =
[ prtId 0 name prPrec i 0 $
, doc $ showString ";" concatD
, prt 0 n [ prt 0 name
, prtIdPs 0 parms , doc $ showString ":"
, doc $ showString "=" , prt 0 t
, prt 0 rhs , doc $ showString "\n"
] , prt 0 name
prt i (DataStructure (Ident n) xs) = prPrec i 0 $ concatD , doc $ showString "="
[ prt 0 n , prt 0 rhs
, doc $ showString "{" ]
, doc . showString . show $ xs
, doc $ showString "}"
]
instance Print [Bind] where instance Print [Bind] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x] prt _ [x] = concatD [prt 0 x]
prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), doc (showString "\n"), prt 0 xs]
prtIdPs :: Int -> [Id] -> Doc prtIdPs :: Int -> [Id] -> Doc
prtIdPs i = prPrec i 0 . concatD . map (prtIdP i) prtIdPs i = prPrec i 0 . concatD . map (prtIdP i)
prtId :: Int -> Id -> Doc prtId :: Int -> Id -> Doc
prtId i (name, t) = prPrec i 0 $ concatD prtId i (name, t) =
[ prt 0 name prPrec i 0 $
, doc $ showString ":" concatD
, prt 0 t [ prt 0 name
] , doc $ showString ":"
, prt 0 t
]
prtIdP :: Int -> Id -> Doc prtIdP :: Int -> Id -> Doc
prtIdP i (name, t) = prPrec i 0 $ concatD prtIdP i (name, t) =
[ doc $ showString "(" prPrec i 0 $
, prt 0 name concatD
, doc $ showString ":" [ doc $ showString "("
, prt 0 t , prt 0 name
, doc $ showString ")" , doc $ showString ":"
] , prt 0 t
, doc $ showString ")"
]
instance Print Exp where instance Print Exp where
prt i = \case prt i = \case
EId n -> prPrec i 3 $ concatD [prtIdP 0 n] EId n -> prPrec i 3 $ concatD [prtId 0 n, doc $ showString "\n"]
EInt i1 -> prPrec i 3 $ concatD [prt 0 i1] ELit _ (LInt i1) -> prPrec i 3 $ concatD [prt 0 i1, doc $ showString "\n"]
ELet bs e -> prPrec i 3 $ concatD ELet bs e ->
[ doc $ showString "let" prPrec i 3 $
, prt 0 bs concatD
, doc $ showString "in" [ doc $ showString "let"
, prt 0 e , prt 0 bs
] , doc $ showString "in"
EApp t e1 e2 -> prPrec i 2 $ concatD , prt 0 e
[ doc $ showString "@" , doc $ showString "\n"
, prt 0 t ]
, prt 2 e1 EApp _ e1 e2 ->
, prt 3 e2 prPrec i 2 $
] concatD
EAdd t e1 e2 -> prPrec i 1 $ concatD [ prt 2 e1
[ doc $ showString "@" , prt 3 e2
, prt 0 t ]
, prt 1 e1 EAdd t e1 e2 ->
, doc $ showString "+" prPrec i 1 $
, prt 2 e2 concatD
] [ doc $ showString "@"
ESub t e1 e2 -> prPrec i 1 $ concatD , prt 0 t
[ doc $ showString "@" , prt 1 e1
, prt 0 t , doc $ showString "+"
, prt 1 e1 , prt 2 e2
, doc $ showString "-" , doc $ showString "\n"
, prt 2 e2 ]
] ESub t e1 e2 ->
EAbs t n e -> prPrec i 0 $ concatD prPrec i 1 $
[ doc $ showString "@" concatD
, prt 0 t [ doc $ showString "@"
, doc $ showString "\\" , prt 0 t
, prtIdP 0 n , prt 1 e1
, doc $ showString "." , doc $ showString "-"
, prt 0 e , prt 2 e2
] , doc $ showString "\n"
ECase t e cs -> prPrec i 0 $ concatD ]
[ doc $ showString "@" EAbs t n e ->
, prt 0 t prPrec i 0 $
, doc $ showString "(" concatD
, prt 0 e [ doc $ showString "@"
, doc $ showString ")" , prt 0 t
, prPrec i 0 $ concatD . printCases $ cs , doc $ showString "\\"
] , prtId 0 n
, doc $ showString "."
, prt 0 e
, doc $ showString "\n"
]
ECase t exp injs ->
prPrec
i
0
( concatD
[ doc (showString "case")
, prt 0 exp
, doc (showString "of")
, doc (showString "{")
, prt 0 injs
, doc (showString "}")
, doc (showString ":")
, prt 0 t
, doc $ showString "\n"
]
)
where instance Print Inj where
printCases :: [(Type, Case)] -> [Doc] prt i = \case
printCases [] = [] Inj (init, t) exp -> prPrec i 0 (concatD [prt 0 init, doc (showString ":"), prt 0 t, doc (showString "=>"), prt 0 exp])
printCases ((t, Case c e):xs) = concatD
[ doc $ showString "@" instance Print [Inj] where
, prt 0 t prt _ [] = concatD []
, doc $ showString "(" prt _ [x] = concatD [prt 0 x]
, doc . showString . show $ c prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
, doc $ showString ")"
, doc $ showString "=>"
, prt 0 e
, doc $ showString "\n"
] : printCases xs