Started importing Sebastian's new typechecker.

This commit is contained in:
Samuel Hammersberg 2023-03-08 11:01:07 +01:00
parent d5dd7896d8
commit 350cd3b0e9
9 changed files with 1611 additions and 1346 deletions

View file

@ -1,33 +1,51 @@
Program. Program ::= [Bind];
EId. Exp3 ::= Ident; Program. Program ::= [Def] ;
EInt. Exp3 ::= Integer;
EAnn. Exp3 ::= "(" Exp ":" Type ")";
ELet. Exp3 ::= "let" Bind "in" Exp;
EApp. Exp2 ::= Exp2 Exp3;
EAdd. Exp1 ::= Exp1 "+" Exp2;
ESub. Exp1 ::= Exp1 "-" Exp2;
EAbs. Exp ::= "\\" Ident ":" Type "." Exp;
ECase. Exp ::= "case" Exp "of" "{" [CaseMatch] "}" ":" Type;
CaseMatch. CaseMatch ::= Case "=>" Exp ;
separator CaseMatch ",";
DBind. Def ::= Bind ;
CInt. Case ::= Integer ; DData. Def ::= Data ;
CatchAll. Case ::= "_" ; separator Def ";" ;
Bind. Bind ::= Ident ":" Type ";" Bind. Bind ::= Ident ":" Type ";"
Ident [Ident] "=" Exp ; Ident [Ident] "=" Exp ;
separator Bind ";"; Data. Data ::= "data" Constr "where" "{" [Constructor] "}" ;
Constructor. Constructor ::= Ident ":" Type ;
separator nonempty Constructor "" ;
TMono. Type1 ::= "_" Ident ;
TPol. Type1 ::= "'" Ident ;
TConstr. Type1 ::= Constr ;
TArr. Type ::= Type1 "->" Type ;
Constr. Constr ::= Ident "(" [Type] ")" ;
-- TODO: Move literal to its own thing since it's reused in Init as well.
EAnn. Exp5 ::= "(" Exp ":" Type ")" ;
EId. Exp4 ::= Ident ;
ELit. Exp4 ::= Literal ;
EApp. Exp3 ::= Exp3 Exp4 ;
EAdd. Exp1 ::= Exp1 "+" Exp2 ;
ESub. Exp1 ::= Exp1 "-" Exp2 ;
ELet. Exp ::= "let" Ident "=" Exp "in" Exp ;
EAbs. Exp ::= "\\" Ident "." Exp ;
ECase. Exp ::= "case" Exp "of" "{" [Inj] "}";
LInt. Literal ::= Integer ;
Inj. Inj ::= Init "=>" Exp ;
separator nonempty Inj ";" ;
InitLit. Init ::= Literal ;
InitConstr. Init ::= Ident [Ident] ;
InitCatch. Init ::= "_" ;
separator Type " " ;
coercions Type 2 ;
separator Ident " "; separator Ident " ";
coercions Exp 3; coercions Exp 5 ;
TInt. Type1 ::= "Int" ;
TPol. Type1 ::= Ident ;
TFun. Type ::= Type1 "->" Type ;
coercions Type 1 ;
comment "--" ; comment "--" ;
comment "{-" "-}" ; comment "{-" "-}" ;

View file

@ -1,87 +1,26 @@
posMul : _Int -> _Int -> _Int;
-- tripplemagic : Int -> Int -> Int -> Int;
-- tripplemagic x y z = ((\x:Int. x+x) x) + y + z;
-- main : Int;
-- main = tripplemagic ((\x:Int. x+x+3) ((\x:Int. x) 2)) 5 3
-- answer: 22
-- apply : (Int -> Int) -> Int -> Int;
-- apply f x = f x;
-- main : Int;
-- main = apply (\x : Int . x + 5) 5
-- answer: 10
-- apply : (Int -> Int -> Int) -> Int -> Int -> Int;
-- apply f x y = f x y;
-- krimp: Int -> Int -> Int;
-- krimp x y = x + y;
-- main : Int;
-- main = apply (krimp) 2 3;
-- answer: 5
-- fibbonaci : Int -> Int;
-- fibbonaci x = case x of {
-- 0 => 0,
-- 1 => 1,
-- -- abusing overflows to represent negatives like a boss
-- _ => (fibbonaci (x - 2))
-- + (fibbonaci (x - 1))
-- } : Int;
-- main : Int;
-- main = fibbonaci 10;
-- answer: 55
-- succ : Int -> Int;
-- succ x = x - 1;
--
-- isZero : Int -> Int;
-- isZero x = case x of {
-- 0 => 1,
-- _ => 0
-- } : Int;
--
-- minimization : (Int -> Int) -> Int -> Int;
-- minimization p x = case p x of {
-- 1 => 0,
-- _ => minimization p (succ x)
-- } : Int;
--
-- main : Int;
-- main = minimization isZero 10;
-- answer: 0
posMul : Int -> Int -> Int;
posMul a b = case b of { posMul a b = case b of {
0 => 0, 0 => 0;
_ => a + posMul a (b - 1) _ => a + posMul a (b - 1)
} : Int; };
facc : Int -> Int; facc : _Int -> _Int;
facc a = case a of { facc a = case a of {
1 => 1, 1 => 1;
_ => posMul a (facc (a - 1)) _ => posMul a (facc (a - 1))
} : Int; };
-- main : Int;
-- main = facc 5
-- answer: 120
-- pow : Int -> Int -> Int; minimization : (_Int -> _Int) -> _Int -> _Int;
-- pow a b = case b of {
-- 0 => 1,
-- _ => posMul a (pow a (b-1))
-- } : Int;
minimization : (Int -> Int) -> Int -> Int;
minimization p x = case p x of { minimization p x = case p x of {
1 => x, 1 => x;
_ => minimization p (x + 1) _ => minimization p (x + 1)
} : Int; };
checkFac : Int -> Int; checkFac : _Int -> _Int;
checkFac x = case facc x of { checkFac x = case facc x of {
0 => 1, 0 => 1;
_ => 0 _ => 0
} : Int; };
main : Int; main : _Int;
main = minimization checkFac 1 main = minimization checkFac 1

View file

@ -1,441 +1,443 @@
{-# LANGUAGE LambdaCase #-} module Codegen.Codegen where
{-# LANGUAGE OverloadedStrings #-} -- {-# LANGUAGE LambdaCase #-}
-- {-# LANGUAGE OverloadedStrings #-}
module Codegen.Codegen (generateCode) where
import Auxiliary (snoc)
import Codegen.LlvmIr (CallingConvention (..),
LLVMComp (..), LLVMIr (..),
LLVMType (..), LLVMValue (..),
Visibility (..), llvmIrToString)
import Control.Monad.State (StateT, execStateT, foldM_, gets,
modify)
import qualified Data.Bifunctor as BI
import Data.List.Extra (trim)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Tuple.Extra (dupe, first, second)
import qualified Grammar.Abs as GA
import Grammar.ErrM (Err)
import System.Process.Extra (readCreateProcess, shell)
import TypeChecker.TypeCheckerIr (Bind (..), Case (..), Exp (..), Id,
Ident (..), Program (..), Type (..))
-- | The record used as the code generator state
data CodeGenerator = CodeGenerator
{ instructions :: [LLVMIr]
, functions :: Map Id FunctionInfo
, constructors :: Map Id ConstructorInfo
, variableCount :: Integer
, labelCount :: Integer
}
-- | A state type synonym
type CompilerState a = StateT CodeGenerator Err a
data FunctionInfo = FunctionInfo
{ numArgs :: Int
, arguments :: [Id]
}
data ConstructorInfo = ConstructorInfo
{ numArgsCI :: Int
, argumentsCI :: [Id]
, numCI :: Integer
}
-- | Adds a instruction to the CodeGenerator state
emit :: LLVMIr -> CompilerState ()
emit l = modify $ \t -> t { instructions = Auxiliary.snoc l $ instructions t }
-- | Increases the variable counter in the CodeGenerator state
increaseVarCount :: CompilerState ()
increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 }
-- | Returns the variable count from the CodeGenerator state
getVarCount :: CompilerState Integer
getVarCount = gets variableCount
-- | Increases the variable count and returns it from the CodeGenerator state
getNewVar :: CompilerState Integer
getNewVar = increaseVarCount >> getVarCount
-- | Increses the label count and returns a label from the CodeGenerator state
getNewLabel :: CompilerState Integer
getNewLabel = do
modify (\t -> t{labelCount = labelCount t + 1})
gets labelCount
-- | Produces a map of functions infos from a list of binds,
-- which contains useful data for code generation.
getFunctions :: [Bind] -> Map Id FunctionInfo
getFunctions bs = Map.fromList $ go bs
where
go [] = []
go (Bind id args _ : xs) =
(id, FunctionInfo { numArgs=length args, arguments=args })
: go xs
go (DataStructure n cons : xs) = do
map (\(id, xs) -> ((id, TPol n), FunctionInfo {
numArgs=length xs, arguments=createArgs xs
})) cons
<> go xs
createArgs :: [Type] -> [Id]
createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(Ident ("arg_" <> show l) , t)],l+1)) ([], 0) xs
-- | Produces a map of functions infos from a list of binds,
-- which contains useful data for code generation.
getConstructors :: [Bind] -> Map Id ConstructorInfo
getConstructors bs = Map.fromList $ go bs
where
go [] = []
go (DataStructure (Ident n) cons : xs) = do
fst (foldl (\(acc,i) (Ident id, xs) -> (((Ident (n <> "_" <> id), TPol (Ident n)), ConstructorInfo {
numArgsCI=length xs,
argumentsCI=createArgs xs,
numCI=i
}) : acc, i+1)) ([],0) cons)
<> go xs
go (_: xs) = go xs
initCodeGenerator :: [Bind] -> CodeGenerator
initCodeGenerator scs = CodeGenerator { instructions = defaultStart
, functions = getFunctions scs
, constructors = getConstructors scs
, variableCount = 0
, labelCount = 0
}
run :: Err String -> IO ()
run s = do
let s' = case s of
Right s -> s
Left _ -> error "yo"
writeFile "output/llvm.ll" s'
putStrLn . trim =<< readCreateProcess (shell "lli") s'
test :: Integer -> Program
test v = Program [
DataStructure (Ident "Craig") [
(Ident "Bob", [TInt])--,
--(Ident "Alice", [TInt, TInt])
],
Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] (EId ("x",TInt)),
Bind (Ident "main", TInt) [] (
EApp (TPol "Craig") (EId (Ident "Craig_Bob", TPol "Craig")) (EInt v) -- (EInt 92)
)
]
{- | Compiles an AST and produces a LLVM Ir string.
An easy way to actually "compile" this output is to
Simply pipe it to LLI
-}
generateCode :: Program -> Err String
generateCode (Program scs) = do
let codegen = initCodeGenerator scs
llvmIrToString . instructions <$> execStateT (compileScs scs) codegen
compileScs :: [Bind] -> CompilerState ()
compileScs [] = do
-- as a last step create all the constructors
c <- gets (Map.toList . constructors)
mapM_ (\((id, t), ci) -> do
let t' = type2LlvmType t
let x = BI.second type2LlvmType <$> argumentsCI ci
emit $ Define FastCC t' id x
top <- Ident . show <$> getNewVar
ptr <- Ident . show <$> getNewVar
-- allocated the primary type
emit $ SetVariable top (Alloca t')
-- set the first byte to the index of the constructor
emit $ SetVariable ptr $
GetElementPtrInbounds t' (Ref t')
(VIdent top I8) I32 (VInteger 0) I32 (VInteger 0)
emit $ Store I8 (VInteger $ numCI ci ) (Ref I8) ptr
-- get a pointer of the correct type
ptr' <- Ident . show <$> getNewVar
emit $ SetVariable ptr' (Bitcast (Ref t') ptr (Ref $ CustomType id))
--emit $ UnsafeRaw "\n"
foldM_ (\i (Ident arg_n, arg_t)-> do
let arg_t' = type2LlvmType arg_t
emit $ Comment (show arg_t' <>" "<> arg_n <> " " <> show i )
elemPtr <- Ident . show <$> getNewVar
emit $ SetVariable elemPtr (
GetElementPtrInbounds (CustomType id) (Ref (CustomType id))
(VIdent ptr' Ptr) I32
(VInteger 0) I32 (VInteger i))
emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr elemPtr
-- %2 = getelementptr inbounds %Foo_AInteger, %Foo_AInteger* %1, i32 0, i32 1
-- store i32 42, i32* %2
pure $ i + 1-- + typeByteSize arg_t'
) 1 (argumentsCI ci)
--emit $ UnsafeRaw "\n"
-- load and return the constructed value
load <- Ident . show <$> getNewVar
emit $ SetVariable load (Load t' Ptr top)
emit $ Ret t' (VIdent load t')
emit DefineEnd
modify $ \s -> s { variableCount = 0 }
) c
compileScs (Bind (name, _t) args exp : xs) = do
emit $ UnsafeRaw "\n"
emit . Comment $ show name <> ": " <> show exp
let args' = map (second type2LlvmType) args
emit $ Define FastCC I64 {-(type2LlvmType t_return)-} name args'
functionBody <- exprToValue exp
if name == "main"
then mapM_ emit $ mainContent functionBody
else emit $ Ret I64 functionBody
emit DefineEnd
modify $ \s -> s { variableCount = 0 }
compileScs xs
compileScs (DataStructure id@(Ident outer_id) ts : xs) = do
let biggest_variant = maximum ((\(_, t) -> sum $ typeByteSize . type2LlvmType <$> t) <$> ts)
emit $ Type id [I8, Array biggest_variant I8]
mapM_ (\(Ident inner_id, fi) -> do
emit $ Type (Ident $ outer_id <> "_" <> inner_id) (I8 : map type2LlvmType fi)
) ts
compileScs xs
-- where
-- _t_return = snd $ partitionType (length args) t
mainContent :: LLVMValue -> [LLVMIr]
mainContent var =
[ UnsafeRaw $
"call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n"
, -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
-- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2")
-- , Label (Ident "b_1")
-- , UnsafeRaw
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n"
-- , Br (Ident "end")
-- , Label (Ident "b_2")
-- , UnsafeRaw
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n"
-- , Br (Ident "end")
-- , Label (Ident "end")
Ret I64 (VInteger 0)
]
defaultStart :: [LLVMIr]
defaultStart = [ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n"
, UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
, UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
]
compileExp :: Exp -> CompilerState ()
compileExp (EInt int) = emitInt int
compileExp (EAdd t e1 e2) = emitAdd t e1 e2
compileExp (ESub t e1 e2) = emitSub t e1 e2
compileExp (EId (name, _)) = emitIdent name
compileExp (EApp t e1 e2) = emitApp t e1 e2
compileExp (EAbs t ti e) = emitAbs t ti e
compileExp (ELet binds e) = emitLet binds e
compileExp (ECase t e cs) = emitECased t e cs
-- go (EMul e1 e2) = emitMul e1 e2
-- go (EDiv e1 e2) = emitDiv e1 e2
-- go (EMod e1 e2) = emitMod e1 e2
--- aux functions ---
emitECased :: Type -> Exp -> [(Type, Case)] -> CompilerState ()
emitECased t e cases = do
let cs = snd <$> cases
let ty = type2LlvmType t
vs <- exprToValue e
lbl <- getNewLabel
let label = Ident $ "escape_" <> show lbl
stackPtr <- getNewVar
emit $ SetVariable (Ident $ show stackPtr) (Alloca ty)
mapM_ (emitCases ty label stackPtr vs) cs
emit $ Label label
res <- getNewVar
emit $ SetVariable (Ident $ show res) (Load ty Ptr (Ident $ show stackPtr))
where
emitCases :: LLVMType -> Ident -> Integer -> LLVMValue -> Case -> CompilerState ()
emitCases ty label stackPtr vs (Case (GA.CInt i) exp) = do
ns <- getNewVar
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
emit $ SetVariable (Ident $ show ns) (Icmp LLEq ty vs (VInteger i))
emit $ BrCond (VIdent (Ident $ show ns) ty) lbl_succPos lbl_failPos
emit $ Label lbl_succPos
val <- exprToValue exp
emit $ Store ty val Ptr (Ident . show $ stackPtr)
emit $ Br label
emit $ Label lbl_failPos
emitCases ty label stackPtr _ (Case GA.CatchAll exp) = do
val <- exprToValue exp
emit $ Store ty val Ptr (Ident . show $ stackPtr)
emit $ Br label
emitAbs :: Type -> Id -> Exp -> CompilerState ()
emitAbs _t tid e = do
emit . Comment $
"Lambda escaped previous stages: \\" <> show tid <> " . " <> show e
emitLet :: Bind -> Exp -> CompilerState ()
emitLet xs e = do
emit $
Comment $
concat
[ "ELet ("
, show xs
, " = "
, show e
, ") is not implemented!"
]
emitApp :: Type -> Exp -> Exp -> CompilerState ()
emitApp t e1 e2 = appEmitter t e1 e2 []
where
appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState ()
appEmitter t e1 e2 stack = do
let newStack = e2 : stack
case e1 of
EApp _ e1' e2' -> appEmitter t e1' e2' newStack
EId id@(name, _) -> do
args <- traverse exprToValue newStack
vs <- getNewVar
funcs <- gets functions
let visibility = maybe Local (const Global) $ Map.lookup id funcs
args' = map (first valueGetType . dupe) args
call = Call FastCC (type2LlvmType t) visibility name args'
emit $ SetVariable (Ident $ show vs) call
x -> do
emit . Comment $ "The unspeakable happened: "
emit . Comment $ show x
emitIdent :: Ident -> CompilerState ()
emitIdent id = do
-- !!this should never happen!!
emit $ Comment "This should not have happened!"
emit $ Variable id
emit $ UnsafeRaw "\n"
emitInt :: Integer -> CompilerState ()
emitInt i = do
-- !!this should never happen!!
varCount <- getNewVar
emit $ Comment "This should not have happened!"
emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0))
emitAdd :: Type -> Exp -> Exp -> CompilerState ()
emitAdd t e1 e2 = do
v1 <- exprToValue e1
v2 <- exprToValue e2
v <- getNewVar
emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2)
emitSub :: Type -> Exp -> Exp -> CompilerState ()
emitSub t e1 e2 = do
v1 <- exprToValue e1
v2 <- exprToValue e2
v <- getNewVar
emit $ SetVariable (Ident $ show v) (Sub (type2LlvmType t) v1 v2)
-- emitMul :: Exp -> Exp -> CompilerState ()
-- emitMul e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Mul I64 v1 v2
-- emitMod :: Exp -> Exp -> CompilerState ()
-- emitMod e1 e2 = do
-- -- `let m a b = rem (abs $ b + a) b`
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- vadd <- gets variableCount
-- emit $ SetVariable $ Ident $ show vadd
-- emit $ Add I64 v1 v2
-- --
-- increaseVarCount -- module Codegen.Codegen (generateCode) where
-- vabs <- gets variableCount --
-- emit $ SetVariable $ Ident $ show vabs -- import Auxiliary (snoc)
-- emit $ Call I64 (Ident "llvm.abs.i64") -- import Codegen.LlvmIr (CallingConvention (..),
-- [ (I64, VIdent (Ident $ show vadd)) -- LLVMComp (..), LLVMIr (..),
-- , (I1, VInteger 1) -- LLVMType (..), LLVMValue (..),
-- Visibility (..), llvmIrToString)
-- import Control.Monad.State (StateT, execStateT, foldM_, gets,
-- modify)
-- import qualified Data.Bifunctor as BI
-- import Data.List.Extra (trim)
-- import Data.Map (Map)
-- import qualified Data.Map as Map
-- import Data.Tuple.Extra (dupe, first, second)
-- import qualified Grammar.Abs as GA
-- import Grammar.ErrM (Err)
-- import System.Process.Extra (readCreateProcess, shell)
-- import TypeChecker.TypeCheckerIr (Bind (..), Case (..), Exp (..), Id,
-- Ident (..), Program (..), Type (..))
-- -- | The record used as the code generator state
-- data CodeGenerator = CodeGenerator
-- { instructions :: [LLVMIr]
-- , functions :: Map Id FunctionInfo
-- , constructors :: Map Id ConstructorInfo
-- , variableCount :: Integer
-- , labelCount :: Integer
-- }
--
-- -- | A state type synonym
-- type CompilerState a = StateT CodeGenerator Err a
--
-- data FunctionInfo = FunctionInfo
-- { numArgs :: Int
-- , arguments :: [Id]
-- }
-- data ConstructorInfo = ConstructorInfo
-- { numArgsCI :: Int
-- , argumentsCI :: [Id]
-- , numCI :: Integer
-- }
--
--
-- -- | Adds a instruction to the CodeGenerator state
-- emit :: LLVMIr -> CompilerState ()
-- emit l = modify $ \t -> t { instructions = Auxiliary.snoc l $ instructions t }
--
-- -- | Increases the variable counter in the CodeGenerator state
-- increaseVarCount :: CompilerState ()
-- increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 }
--
-- -- | Returns the variable count from the CodeGenerator state
-- getVarCount :: CompilerState Integer
-- getVarCount = gets variableCount
--
-- -- | Increases the variable count and returns it from the CodeGenerator state
-- getNewVar :: CompilerState Integer
-- getNewVar = increaseVarCount >> getVarCount
--
-- -- | Increses the label count and returns a label from the CodeGenerator state
-- getNewLabel :: CompilerState Integer
-- getNewLabel = do
-- modify (\t -> t{labelCount = labelCount t + 1})
-- gets labelCount
--
-- -- | Produces a map of functions infos from a list of binds,
-- -- which contains useful data for code generation.
-- getFunctions :: [Bind] -> Map Id FunctionInfo
-- getFunctions bs = Map.fromList $ go bs
-- where
-- go [] = []
-- go (Bind id args _ : xs) =
-- (id, FunctionInfo { numArgs=length args, arguments=args })
-- : go xs
-- go (DataStructure n cons : xs) = do
-- map (\(id, xs) -> ((id, TPol n), FunctionInfo {
-- numArgs=length xs, arguments=createArgs xs
-- })) cons
-- <> go xs
--
-- createArgs :: [Type] -> [Id]
-- createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(Ident ("arg_" <> show l) , t)],l+1)) ([], 0) xs
--
-- -- | Produces a map of functions infos from a list of binds,
-- -- which contains useful data for code generation.
-- getConstructors :: [Bind] -> Map Id ConstructorInfo
-- getConstructors bs = Map.fromList $ go bs
-- where
-- go [] = []
-- go (DataStructure (Ident n) cons : xs) = do
-- fst (foldl (\(acc,i) (Ident id, xs) -> (((Ident (n <> "_" <> id), TPol (Ident n)), ConstructorInfo {
-- numArgsCI=length xs,
-- argumentsCI=createArgs xs,
-- numCI=i
-- }) : acc, i+1)) ([],0) cons)
-- <> go xs
-- go (_: xs) = go xs
--
-- initCodeGenerator :: [Bind] -> CodeGenerator
-- initCodeGenerator scs = CodeGenerator { instructions = defaultStart
-- , functions = getFunctions scs
-- , constructors = getConstructors scs
-- , variableCount = 0
-- , labelCount = 0
-- }
--
-- run :: Err String -> IO ()
-- run s = do
-- let s' = case s of
-- Right s -> s
-- Left _ -> error "yo"
-- writeFile "output/llvm.ll" s'
-- putStrLn . trim =<< readCreateProcess (shell "lli") s'
--
-- test :: Integer -> Program
-- test v = Program [
-- DataStructure (Ident "Craig") [
-- (Ident "Bob", [TInt])--,
-- --(Ident "Alice", [TInt, TInt])
-- ],
-- Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] (EId ("x",TInt)),
-- Bind (Ident "main", TInt) [] (
-- EApp (TPol "Craig") (EId (Ident "Craig_Bob", TPol "Craig")) (EInt v) -- (EInt 92)
-- )
-- ] -- ]
-- increaseVarCount --
-- v <- gets variableCount -- {- | Compiles an AST and produces a LLVM Ir string.
-- emit $ SetVariable $ Ident $ show v -- An easy way to actually "compile" this output is to
-- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2 -- Simply pipe it to LLI
-- -}
-- emitDiv :: Exp -> Exp -> CompilerState () -- generateCode :: Program -> Err String
-- emitDiv e1 e2 = do -- generateCode (Program scs) = do
-- (v1,v2) <- binExprToValues e1 e2 -- let codegen = initCodeGenerator scs
-- increaseVarCount -- llvmIrToString . instructions <$> execStateT (compileScs scs) codegen
-- v <- gets variableCount --
-- emit $ SetVariable $ Ident $ show v -- compileScs :: [Bind] -> CompilerState ()
-- emit $ Div I64 v1 v2 -- compileScs [] = do
-- -- as a last step create all the constructors
exprToValue :: Exp -> CompilerState LLVMValue -- c <- gets (Map.toList . constructors)
exprToValue = \case -- mapM_ (\((id, t), ci) -> do
EInt i -> pure $ VInteger i -- let t' = type2LlvmType t
-- let x = BI.second type2LlvmType <$> argumentsCI ci
EId id@(name, t) -> do -- emit $ Define FastCC t' id x
funcs <- gets functions -- top <- Ident . show <$> getNewVar
case Map.lookup id funcs of -- ptr <- Ident . show <$> getNewVar
Just fi -> do -- -- allocated the primary type
if numArgs fi == 0 -- emit $ SetVariable top (Alloca t')
then do --
vc <- getNewVar -- -- set the first byte to the index of the constructor
emit $ SetVariable (Ident $ show vc) -- emit $ SetVariable ptr $
(Call FastCC (type2LlvmType t) Global name []) -- GetElementPtrInbounds t' (Ref t')
pure $ VIdent (Ident $ show vc) (type2LlvmType t) -- (VIdent top I8) I32 (VInteger 0) I32 (VInteger 0)
else pure $ VFunction name Global (type2LlvmType t) -- emit $ Store I8 (VInteger $ numCI ci ) (Ref I8) ptr
Nothing -> pure $ VIdent name (type2LlvmType t) --
-- -- get a pointer of the correct type
e -> do -- ptr' <- Ident . show <$> getNewVar
compileExp e -- emit $ SetVariable ptr' (Bitcast (Ref t') ptr (Ref $ CustomType id))
v <- getVarCount --
pure $ VIdent (Ident $ show v) (getType e) -- --emit $ UnsafeRaw "\n"
--
type2LlvmType :: Type -> LLVMType -- foldM_ (\i (Ident arg_n, arg_t)-> do
type2LlvmType = \case -- let arg_t' = type2LlvmType arg_t
TInt -> I64 -- emit $ Comment (show arg_t' <>" "<> arg_n <> " " <> show i )
TFun t xs -> do -- elemPtr <- Ident . show <$> getNewVar
let (t', xs') = function2LLVMType xs [type2LlvmType t] -- emit $ SetVariable elemPtr (
Function t' xs' -- GetElementPtrInbounds (CustomType id) (Ref (CustomType id))
TPol t -> CustomType t -- (VIdent ptr' Ptr) I32
where -- (VInteger 0) I32 (VInteger i))
function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType]) -- emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr elemPtr
function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s) -- -- %2 = getelementptr inbounds %Foo_AInteger, %Foo_AInteger* %1, i32 0, i32 1
function2LLVMType x s = (type2LlvmType x, s) -- -- store i32 42, i32* %2
-- pure $ i + 1-- + typeByteSize arg_t'
getType :: Exp -> LLVMType -- ) 1 (argumentsCI ci)
getType (EInt _) = I64 --
getType (EAdd t _ _) = type2LlvmType t -- --emit $ UnsafeRaw "\n"
getType (ESub t _ _) = type2LlvmType t --
getType (EId (_, t)) = type2LlvmType t -- -- load and return the constructed value
getType (EApp t _ _) = type2LlvmType t -- load <- Ident . show <$> getNewVar
getType (EAbs t _ _) = type2LlvmType t -- emit $ SetVariable load (Load t' Ptr top)
getType (ELet _ e) = getType e -- emit $ Ret t' (VIdent load t')
getType (ECase t _ _) = type2LlvmType t -- emit DefineEnd
--
valueGetType :: LLVMValue -> LLVMType -- modify $ \s -> s { variableCount = 0 }
valueGetType (VInteger _) = I64 -- ) c
valueGetType (VIdent _ t) = t -- compileScs (Bind (name, _t) args exp : xs) = do
valueGetType (VConstant s) = Array (fromIntegral $ length s) I8 -- emit $ UnsafeRaw "\n"
valueGetType (VFunction _ _ t) = t -- emit . Comment $ show name <> ": " <> show exp
-- let args' = map (second type2LlvmType) args
typeByteSize :: LLVMType -> Integer -- emit $ Define FastCC I64 {-(type2LlvmType t_return)-} name args'
typeByteSize I1 = 1 -- functionBody <- exprToValue exp
typeByteSize I8 = 1 -- if name == "main"
typeByteSize I32 = 4 -- then mapM_ emit $ mainContent functionBody
typeByteSize I64 = 8 -- else emit $ Ret I64 functionBody
typeByteSize Ptr = 8 -- emit DefineEnd
typeByteSize (Ref _) = 8 -- modify $ \s -> s { variableCount = 0 }
typeByteSize (Function _ _) = 8 -- compileScs xs
typeByteSize (Array n t) = n * typeByteSize t -- compileScs (DataStructure id@(Ident outer_id) ts : xs) = do
typeByteSize (CustomType _) = 8 -- let biggest_variant = maximum ((\(_, t) -> sum $ typeByteSize . type2LlvmType <$> t) <$> ts)
-- emit $ Type id [I8, Array biggest_variant I8]
-- mapM_ (\(Ident inner_id, fi) -> do
-- emit $ Type (Ident $ outer_id <> "_" <> inner_id) (I8 : map type2LlvmType fi)
-- ) ts
-- compileScs xs
--
-- -- where
-- -- _t_return = snd $ partitionType (length args) t
--
-- mainContent :: LLVMValue -> [LLVMIr]
-- mainContent var =
-- [ UnsafeRaw $
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n"
-- , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
-- -- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2")
-- -- , Label (Ident "b_1")
-- -- , UnsafeRaw
-- -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n"
-- -- , Br (Ident "end")
-- -- , Label (Ident "b_2")
-- -- , UnsafeRaw
-- -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n"
-- -- , Br (Ident "end")
-- -- , Label (Ident "end")
-- Ret I64 (VInteger 0)
-- ]
--
-- defaultStart :: [LLVMIr]
-- defaultStart = [ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n"
-- , UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
-- , UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
-- , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
-- ]
--
-- compileExp :: Exp -> CompilerState ()
-- compileExp (EInt int) = emitInt int
-- compileExp (EAdd t e1 e2) = emitAdd t e1 e2
-- compileExp (ESub t e1 e2) = emitSub t e1 e2
-- compileExp (EId (name, _)) = emitIdent name
-- compileExp (EApp t e1 e2) = emitApp t e1 e2
-- compileExp (EAbs t ti e) = emitAbs t ti e
-- compileExp (ELet binds e) = emitLet binds e
-- compileExp (ECase t e cs) = emitECased t e cs
-- -- go (EMul e1 e2) = emitMul e1 e2
-- -- go (EDiv e1 e2) = emitDiv e1 e2
-- -- go (EMod e1 e2) = emitMod e1 e2
--
-- --- aux functions ---
-- emitECased :: Type -> Exp -> [(Type, Case)] -> CompilerState ()
-- emitECased t e cases = do
-- let cs = snd <$> cases
-- let ty = type2LlvmType t
-- vs <- exprToValue e
-- lbl <- getNewLabel
-- let label = Ident $ "escape_" <> show lbl
-- stackPtr <- getNewVar
-- emit $ SetVariable (Ident $ show stackPtr) (Alloca ty)
-- mapM_ (emitCases ty label stackPtr vs) cs
-- emit $ Label label
-- res <- getNewVar
-- emit $ SetVariable (Ident $ show res) (Load ty Ptr (Ident $ show stackPtr))
-- where
-- emitCases :: LLVMType -> Ident -> Integer -> LLVMValue -> Case -> CompilerState ()
-- emitCases ty label stackPtr vs (Case (GA.CInt i) exp) = do
-- ns <- getNewVar
-- lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
-- lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
-- emit $ SetVariable (Ident $ show ns) (Icmp LLEq ty vs (VInteger i))
-- emit $ BrCond (VIdent (Ident $ show ns) ty) lbl_succPos lbl_failPos
-- emit $ Label lbl_succPos
-- val <- exprToValue exp
-- emit $ Store ty val Ptr (Ident . show $ stackPtr)
-- emit $ Br label
-- emit $ Label lbl_failPos
-- emitCases ty label stackPtr _ (Case GA.CatchAll exp) = do
-- val <- exprToValue exp
-- emit $ Store ty val Ptr (Ident . show $ stackPtr)
-- emit $ Br label
--
--
-- emitAbs :: Type -> Id -> Exp -> CompilerState ()
-- emitAbs _t tid e = do
-- emit . Comment $
-- "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e
-- emitLet :: Bind -> Exp -> CompilerState ()
-- emitLet xs e = do
-- emit $
-- Comment $
-- concat
-- [ "ELet ("
-- , show xs
-- , " = "
-- , show e
-- , ") is not implemented!"
-- ]
--
-- emitApp :: Type -> Exp -> Exp -> CompilerState ()
-- emitApp t e1 e2 = appEmitter t e1 e2 []
-- where
-- appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState ()
-- appEmitter t e1 e2 stack = do
-- let newStack = e2 : stack
-- case e1 of
-- EApp _ e1' e2' -> appEmitter t e1' e2' newStack
-- EId id@(name, _) -> do
-- args <- traverse exprToValue newStack
-- vs <- getNewVar
-- funcs <- gets functions
-- let visibility = maybe Local (const Global) $ Map.lookup id funcs
-- args' = map (first valueGetType . dupe) args
-- call = Call FastCC (type2LlvmType t) visibility name args'
-- emit $ SetVariable (Ident $ show vs) call
-- x -> do
-- emit . Comment $ "The unspeakable happened: "
-- emit . Comment $ show x
--
-- emitIdent :: Ident -> CompilerState ()
-- emitIdent id = do
-- -- !!this should never happen!!
-- emit $ Comment "This should not have happened!"
-- emit $ Variable id
-- emit $ UnsafeRaw "\n"
--
-- emitInt :: Integer -> CompilerState ()
-- emitInt i = do
-- -- !!this should never happen!!
-- varCount <- getNewVar
-- emit $ Comment "This should not have happened!"
-- emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0))
--
-- emitAdd :: Type -> Exp -> Exp -> CompilerState ()
-- emitAdd t e1 e2 = do
-- v1 <- exprToValue e1
-- v2 <- exprToValue e2
-- v <- getNewVar
-- emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2)
--
-- emitSub :: Type -> Exp -> Exp -> CompilerState ()
-- emitSub t e1 e2 = do
-- v1 <- exprToValue e1
-- v2 <- exprToValue e2
-- v <- getNewVar
-- emit $ SetVariable (Ident $ show v) (Sub (type2LlvmType t) v1 v2)
--
-- -- emitMul :: Exp -> Exp -> CompilerState ()
-- -- emitMul e1 e2 = do
-- -- (v1,v2) <- binExprToValues e1 e2
-- -- increaseVarCount
-- -- v <- gets variableCount
-- -- emit $ SetVariable $ Ident $ show v
-- -- emit $ Mul I64 v1 v2
--
-- -- emitMod :: Exp -> Exp -> CompilerState ()
-- -- emitMod e1 e2 = do
-- -- -- `let m a b = rem (abs $ b + a) b`
-- -- (v1,v2) <- binExprToValues e1 e2
-- -- increaseVarCount
-- -- vadd <- gets variableCount
-- -- emit $ SetVariable $ Ident $ show vadd
-- -- emit $ Add I64 v1 v2
-- --
-- -- increaseVarCount
-- -- vabs <- gets variableCount
-- -- emit $ SetVariable $ Ident $ show vabs
-- -- emit $ Call I64 (Ident "llvm.abs.i64")
-- -- [ (I64, VIdent (Ident $ show vadd))
-- -- , (I1, VInteger 1)
-- -- ]
-- -- increaseVarCount
-- -- v <- gets variableCount
-- -- emit $ SetVariable $ Ident $ show v
-- -- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2
--
-- -- emitDiv :: Exp -> Exp -> CompilerState ()
-- -- emitDiv e1 e2 = do
-- -- (v1,v2) <- binExprToValues e1 e2
-- -- increaseVarCount
-- -- v <- gets variableCount
-- -- emit $ SetVariable $ Ident $ show v
-- -- emit $ Div I64 v1 v2
--
-- exprToValue :: Exp -> CompilerState LLVMValue
-- exprToValue = \case
-- EInt i -> pure $ VInteger i
--
-- EId id@(name, t) -> do
-- funcs <- gets functions
-- case Map.lookup id funcs of
-- Just fi -> do
-- if numArgs fi == 0
-- then do
-- vc <- getNewVar
-- emit $ SetVariable (Ident $ show vc)
-- (Call FastCC (type2LlvmType t) Global name [])
-- pure $ VIdent (Ident $ show vc) (type2LlvmType t)
-- else pure $ VFunction name Global (type2LlvmType t)
-- Nothing -> pure $ VIdent name (type2LlvmType t)
--
-- e -> do
-- compileExp e
-- v <- getVarCount
-- pure $ VIdent (Ident $ show v) (getType e)
--
-- type2LlvmType :: Type -> LLVMType
-- type2LlvmType = \case
-- TInt -> I64
-- TFun t xs -> do
-- let (t', xs') = function2LLVMType xs [type2LlvmType t]
-- Function t' xs'
-- TPol t -> CustomType t
-- where
-- function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
-- function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s)
-- function2LLVMType x s = (type2LlvmType x, s)
--
-- getType :: Exp -> LLVMType
-- getType (EInt _) = I64
-- getType (EAdd t _ _) = type2LlvmType t
-- getType (ESub t _ _) = type2LlvmType t
-- getType (EId (_, t)) = type2LlvmType t
-- getType (EApp t _ _) = type2LlvmType t
-- getType (EAbs t _ _) = type2LlvmType t
-- getType (ELet _ e) = getType e
-- getType (ECase t _ _) = type2LlvmType t
--
-- valueGetType :: LLVMValue -> LLVMType
-- valueGetType (VInteger _) = I64
-- valueGetType (VIdent _ t) = t
-- valueGetType (VConstant s) = Array (fromIntegral $ length s) I8
-- valueGetType (VFunction _ _ t) = t
--
-- typeByteSize :: LLVMType -> Integer
-- typeByteSize I1 = 1
-- typeByteSize I8 = 1
-- typeByteSize I32 = 4
-- typeByteSize I64 = 8
-- typeByteSize Ptr = 8
-- typeByteSize (Ref _) = 8
-- typeByteSize (Function _ _) = 8
-- typeByteSize (Array n t) = n * typeByteSize t
-- typeByteSize (CustomType _) = 8
--

View file

@ -1,239 +1,241 @@
{-# LANGUAGE LambdaCase #-} module Codegen.LlvmIr where
-- {-# LANGUAGE LambdaCase #-}
module Codegen.LlvmIr ( --
LLVMType (..), -- module Codegen.LlvmIr (
LLVMIr (..), -- LLVMType (..),
llvmIrToString, -- LLVMIr (..),
LLVMValue (..), -- llvmIrToString,
LLVMComp (..), -- LLVMValue (..),
Visibility (..), -- LLVMComp (..),
CallingConvention (..) -- Visibility (..),
) where -- CallingConvention (..)
-- ) where
import Data.List (intercalate) --
import TypeChecker.TypeCheckerIr -- import Data.List (intercalate)
-- import TypeChecker.TypeCheckerIr
data CallingConvention = TailCC | FastCC | CCC | ColdCC --
instance Show CallingConvention where -- data CallingConvention = TailCC | FastCC | CCC | ColdCC
show :: CallingConvention -> String -- instance Show CallingConvention where
show TailCC = "tailcc" -- show :: CallingConvention -> String
show FastCC = "fastcc" -- show TailCC = "tailcc"
show CCC = "ccc" -- show FastCC = "fastcc"
show ColdCC = "coldcc" -- show CCC = "ccc"
-- show ColdCC = "coldcc"
-- | A datatype which represents some basic LLVM types --
data LLVMType -- -- | A datatype which represents some basic LLVM types
= I1 -- data LLVMType
| I8 -- = I1
| I32 -- | I8
| I64 -- | I32
| Ptr -- | I64
| Ref LLVMType -- | Ptr
| Function LLVMType [LLVMType] -- | Ref LLVMType
| Array Integer LLVMType -- | Function LLVMType [LLVMType]
| CustomType Ident -- | Array Integer LLVMType
-- | CustomType Ident
instance Show LLVMType where --
show :: LLVMType -> String -- instance Show LLVMType where
show = \case -- show :: LLVMType -> String
I1 -> "i1" -- show = \case
I8 -> "i8" -- I1 -> "i1"
I32 -> "i32" -- I8 -> "i8"
I64 -> "i64" -- I32 -> "i32"
Ptr -> "ptr" -- I64 -> "i64"
Ref ty -> show ty <> "*" -- Ptr -> "ptr"
Function t xs -> show t <> " (" <> intercalate ", " (map show xs) <> ")*" -- Ref ty -> show ty <> "*"
Array n ty -> concat ["[", show n, " x ", show ty, "]"] -- Function t xs -> show t <> " (" <> intercalate ", " (map show xs) <> ")*"
CustomType (Ident ty) -> "%" <> ty -- Array n ty -> concat ["[", show n, " x ", show ty, "]"]
-- CustomType (Ident ty) -> "%" <> ty
data LLVMComp --
= LLEq -- data LLVMComp
| LLNe -- = LLEq
| LLUgt -- | LLNe
| LLUge -- | LLUgt
| LLUlt -- | LLUge
| LLUle -- | LLUlt
| LLSgt -- | LLUle
| LLSge -- | LLSgt
| LLSlt -- | LLSge
| LLSle -- | LLSlt
instance Show LLVMComp where -- | LLSle
show :: LLVMComp -> String -- instance Show LLVMComp where
show = \case -- show :: LLVMComp -> String
LLEq -> "eq" -- show = \case
LLNe -> "ne" -- LLEq -> "eq"
LLUgt -> "ugt" -- LLNe -> "ne"
LLUge -> "uge" -- LLUgt -> "ugt"
LLUlt -> "ult" -- LLUge -> "uge"
LLUle -> "ule" -- LLUlt -> "ult"
LLSgt -> "sgt" -- LLUle -> "ule"
LLSge -> "sge" -- LLSgt -> "sgt"
LLSlt -> "slt" -- LLSge -> "sge"
LLSle -> "sle" -- LLSlt -> "slt"
-- LLSle -> "sle"
data Visibility = Local | Global --
instance Show Visibility where -- data Visibility = Local | Global
show :: Visibility -> String -- instance Show Visibility where
show Local = "%" -- show :: Visibility -> String
show Global = "@" -- show Local = "%"
-- show Global = "@"
-- | Represents a LLVM "value", as in an integer, a register variable, --
-- or a string contstant -- -- | Represents a LLVM "value", as in an integer, a register variable,
data LLVMValue -- -- or a string contstant
= VInteger Integer -- data LLVMValue
| VIdent Ident LLVMType -- = VInteger Integer
| VConstant String -- | VIdent Ident LLVMType
| VFunction Ident Visibility LLVMType -- | VConstant String
-- | VFunction Ident Visibility LLVMType
instance Show LLVMValue where --
show :: LLVMValue -> String -- instance Show LLVMValue where
show v = case v of -- show :: LLVMValue -> String
VInteger i -> show i -- show v = case v of
VIdent (Ident n) _ -> "%" <> n -- VInteger i -> show i
VFunction (Ident n) vis _ -> show vis <> n -- VIdent (Ident n) _ -> "%" <> n
VConstant s -> "c" <> show s -- VFunction (Ident n) vis _ -> show vis <> n
-- VConstant s -> "c" <> show s
type Params = [(Ident, LLVMType)] --
type Args = [(LLVMType, LLVMValue)] -- type Params = [(Ident, LLVMType)]
-- type Args = [(LLVMType, LLVMValue)]
-- | A datatype which represents different instructions in LLVM --
data LLVMIr -- -- | A datatype which represents different instructions in LLVM
= Type Ident [LLVMType] -- data LLVMIr
| Define CallingConvention LLVMType Ident Params -- = Type Ident [LLVMType]
| DefineEnd -- | Define CallingConvention LLVMType Ident Params
| Declare LLVMType Ident Params -- | DefineEnd
| SetVariable Ident LLVMIr -- | Declare LLVMType Ident Params
| Variable Ident -- | SetVariable Ident LLVMIr
| GetElementPtrInbounds LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue -- | Variable Ident
| Add LLVMType LLVMValue LLVMValue -- | GetElementPtrInbounds LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue
| Sub LLVMType LLVMValue LLVMValue -- | Add LLVMType LLVMValue LLVMValue
| Div LLVMType LLVMValue LLVMValue -- | Sub LLVMType LLVMValue LLVMValue
| Mul LLVMType LLVMValue LLVMValue -- | Div LLVMType LLVMValue LLVMValue
| Srem LLVMType LLVMValue LLVMValue -- | Mul LLVMType LLVMValue LLVMValue
| Icmp LLVMComp LLVMType LLVMValue LLVMValue -- | Srem LLVMType LLVMValue LLVMValue
| Br Ident -- | Icmp LLVMComp LLVMType LLVMValue LLVMValue
| BrCond LLVMValue Ident Ident -- | Br Ident
| Label Ident -- | BrCond LLVMValue Ident Ident
| Call CallingConvention LLVMType Visibility Ident Args -- | Label Ident
| Alloca LLVMType -- | Call CallingConvention LLVMType Visibility Ident Args
| Store LLVMType LLVMValue LLVMType Ident -- | Alloca LLVMType
| Load LLVMType LLVMType Ident -- | Store LLVMType LLVMValue LLVMType Ident
| Bitcast LLVMType Ident LLVMType -- | Load LLVMType LLVMType Ident
| Ret LLVMType LLVMValue -- | Bitcast LLVMType Ident LLVMType
| Comment String -- | Ret LLVMType LLVMValue
| UnsafeRaw String -- This should generally be avoided, and proper -- | Comment String
-- instructions should be used in its place -- | UnsafeRaw String -- This should generally be avoided, and proper
deriving (Show) -- -- instructions should be used in its place
-- deriving (Show)
-- | Converts a list of LLVMIr instructions to a string --
llvmIrToString :: [LLVMIr] -> String -- -- | Converts a list of LLVMIr instructions to a string
llvmIrToString = go 0 -- llvmIrToString :: [LLVMIr] -> String
where -- llvmIrToString = go 0
go :: Int -> [LLVMIr] -> String -- where
go _ [] = mempty -- go :: Int -> [LLVMIr] -> String
go i (x : xs) = do -- go _ [] = mempty
let (i', n) = case x of -- go i (x : xs) = do
Define{} -> (i + 1, 0) -- let (i', n) = case x of
DefineEnd -> (i - 1, 0) -- Define{} -> (i + 1, 0)
_ -> (i, i) -- DefineEnd -> (i - 1, 0)
insToString n x <> go i' xs -- _ -> (i, i)
-- insToString n x <> go i' xs
{- | Converts a LLVM inststruction to a String, allowing for printing etc. --
The integer represents the indentation -- {- | Converts a LLVM inststruction to a String, allowing for printing etc.
-} -- The integer represents the indentation
{- FOURMOLU_DISABLE -} -- -}
insToString :: Int -> LLVMIr -> String -- {- FOURMOLU_DISABLE -}
insToString i l = -- insToString :: Int -> LLVMIr -> String
replicate i '\t' <> case l of -- insToString i l =
(GetElementPtrInbounds t1 t2 p t3 v1 t4 v2) -> do -- replicate i '\t' <> case l of
-- getelementptr inbounds %Foo, %Foo* %x, i32 0, i32 0 -- (GetElementPtrInbounds t1 t2 p t3 v1 t4 v2) -> do
concat -- -- getelementptr inbounds %Foo, %Foo* %x, i32 0, i32 0
[ "getelementptr inbounds ", show t1, ", " , show t2 -- concat
, " ", show p, ", ", show t3, " ", show v1, -- [ "getelementptr inbounds ", show t1, ", " , show t2
", ", show t4, " ", show v2, "\n" ] -- , " ", show p, ", ", show t3, " ", show v1,
(Type (Ident n) types) -> -- ", ", show t4, " ", show v2, "\n" ]
concat -- (Type (Ident n) types) ->
[ "%", n, " = type { " -- concat
, intercalate ", " (map show types) -- [ "%", n, " = type { "
, " }\n" -- , intercalate ", " (map show types)
] -- , " }\n"
(Define c t (Ident i) params) -> -- ]
concat -- (Define c t (Ident i) params) ->
[ "define ", show c, " ", show t, " @", i -- concat
, "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params) -- [ "define ", show c, " ", show t, " @", i
, ") {\n" -- , "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params)
] -- , ") {\n"
DefineEnd -> "}\n" -- ]
(Declare _t (Ident _i) _params) -> undefined -- DefineEnd -> "}\n"
(SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir] -- (Declare _t (Ident _i) _params) -> undefined
(Add t v1 v2) -> -- (SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir]
concat -- (Add t v1 v2) ->
[ "add ", show t, " ", show v1 -- concat
, ", ", show v2, "\n" -- [ "add ", show t, " ", show v1
] -- , ", ", show v2, "\n"
(Sub t v1 v2) -> -- ]
concat -- (Sub t v1 v2) ->
[ "sub ", show t, " ", show v1, ", " -- concat
, show v2, "\n" -- [ "sub ", show t, " ", show v1, ", "
] -- , show v2, "\n"
(Div t v1 v2) -> -- ]
concat -- (Div t v1 v2) ->
[ "sdiv ", show t, " ", show v1, ", " -- concat
, show v2, "\n" -- [ "sdiv ", show t, " ", show v1, ", "
] -- , show v2, "\n"
(Mul t v1 v2) -> -- ]
concat -- (Mul t v1 v2) ->
[ "mul ", show t, " ", show v1 -- concat
, ", ", show v2, "\n" -- [ "mul ", show t, " ", show v1
] -- , ", ", show v2, "\n"
(Srem t v1 v2) -> -- ]
concat -- (Srem t v1 v2) ->
[ "srem ", show t, " ", show v1, ", " -- concat
, show v2, "\n" -- [ "srem ", show t, " ", show v1, ", "
] -- , show v2, "\n"
(Call c t vis (Ident i) arg) -> -- ]
concat -- (Call c t vis (Ident i) arg) ->
[ "call ", show c, " ", show t, " ", show vis, i, "(" -- concat
, intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg -- [ "call ", show c, " ", show t, " ", show vis, i, "("
, ")\n" -- , intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg
] -- , ")\n"
(Alloca t) -> unwords ["alloca", show t, "\n"] -- ]
(Store t1 val t2 (Ident id2)) -> -- (Alloca t) -> unwords ["alloca", show t, "\n"]
concat -- (Store t1 val t2 (Ident id2)) ->
[ "store ", show t1, " ", show val -- concat
, ", ", show t2 , " %", id2, "\n" -- [ "store ", show t1, " ", show val
] -- , ", ", show t2 , " %", id2, "\n"
(Load t1 t2 (Ident addr)) -> -- ]
concat -- (Load t1 t2 (Ident addr)) ->
[ "load ", show t1, ", " -- concat
, show t2, " %", addr, "\n" -- [ "load ", show t1, ", "
] -- , show t2, " %", addr, "\n"
(Bitcast t1 (Ident i) t2) -> -- ]
concat -- (Bitcast t1 (Ident i) t2) ->
[ "bitcast ", show t1, " %" -- concat
, i, " to ", show t2, "\n" -- [ "bitcast ", show t1, " %"
] -- , i, " to ", show t2, "\n"
(Icmp comp t v1 v2) -> -- ]
concat -- (Icmp comp t v1 v2) ->
[ "icmp ", show comp, " ", show t -- concat
, " ", show v1, ", ", show v2, "\n" -- [ "icmp ", show comp, " ", show t
] -- , " ", show v1, ", ", show v2, "\n"
(Ret t v) -> -- ]
concat -- (Ret t v) ->
[ "ret ", show t, " " -- concat
, show v, "\n" -- [ "ret ", show t, " "
] -- , show v, "\n"
(UnsafeRaw s) -> s -- ]
(Label (Ident s)) -> "\n" <> lblPfx <> s <> ":\n" -- (UnsafeRaw s) -> s
(Br (Ident s)) -> "br label %" <> lblPfx <> s <> "\n" -- (Label (Ident s)) -> "\n" <> lblPfx <> s <> ":\n"
(BrCond val (Ident s1) (Ident s2)) -> -- (Br (Ident s)) -> "br label %" <> lblPfx <> s <> "\n"
concat -- (BrCond val (Ident s1) (Ident s2)) ->
[ "br i1 ", show val, ", ", "label %" -- concat
, lblPfx, s1, ", ", "label %", lblPfx, s2, "\n" -- [ "br i1 ", show val, ", ", "label %"
] -- , lblPfx, s1, ", ", "label %", lblPfx, s2, "\n"
(Comment s) -> "; " <> s <> "\n" -- ]
(Variable (Ident id)) -> "%" <> id -- (Comment s) -> "; " <> s <> "\n"
{- FOURMOLU_ENABLE -} -- (Variable (Ident id)) -> "%" <> id
-- {- FOURMOLU_ENABLE -}
lblPfx :: String --
lblPfx = "lbl_" -- lblPfx :: String
-- lblPfx = "lbl_"
--

View file

@ -1,235 +1,192 @@
{-# LANGUAGE LambdaCase #-} --{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-} --{-# LANGUAGE OverloadedStrings #-}
module LambdaLifter.LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where module LambdaLifter.LambdaLifter where
import Auxiliary (snoc) --import Auxiliary (snoc)
import Control.Applicative (Applicative (liftA2)) --import Control.Applicative (Applicative (liftA2))
import Control.Monad.State (MonadState (get, put), State, --import Control.Monad.State (MonadState (get, put), State,
evalState) -- evalState)
import Data.Set (Set) --import Data.Set (Set)
import qualified Data.Set as Set --import qualified Data.Set as Set
import Debug.Trace (trace) --import Prelude hiding (exp)
import qualified Grammar.Abs as GA --import Renamer.Renamer
import Prelude hiding (exp) --import TypeChecker.TypeCheckerIr
import Renamer.Renamer
import TypeChecker.TypeCheckerIr
-- | Lift lambdas and let expression into supercombinators. ---- | Lift lambdas and let expression into supercombinators.
-- Three phases: ---- Three phases:
-- @freeVars@ annotatss all the free variables. ---- @freeVars@ annotatss all the free variables.
-- @abstract@ converts lambdas into let expressions. ---- @abstract@ converts lambdas into let expressions.
-- @collectScs@ moves every non-constant let expression to a top-level function. ---- @collectScs@ moves every non-constant let expression to a top-level function.
lambdaLift :: Program -> Program --lambdaLift :: Program -> Program
lambdaLift = collectScs . abstract . freeVars --lambdaLift = collectScs . abstract . freeVars
-- | Annotate free variables
freeVars :: Program -> AnnProgram
freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e)
| Bind n xs e <- ds
]
freeVarsExp :: Set Id -> Exp -> AnnExp
freeVarsExp localVars = \case
EId n | Set.member n localVars -> (Set.singleton n, AId n)
| otherwise -> (mempty, AId n)
EInt i -> (mempty, AInt i)
EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2')
where
e1' = freeVarsExp localVars e1
e2' = freeVarsExp localVars e2
EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2')
where
e1' = freeVarsExp localVars e1
e2' = freeVarsExp localVars e2
ESub t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), ASub t e1' e2')
where
e1' = freeVarsExp localVars e1
e2' = freeVarsExp localVars e2
EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e') ---- | Annotate free variables
where --freeVars :: Program -> AnnProgram
e' = freeVarsExp (Set.insert par localVars) e --freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e)
-- | Bind n xs e <- ds
-- ]
-- Sum free variables present in bind and the expression --freeVarsExp :: Set Id -> Exp -> AnnExp
ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e') --freeVarsExp localVars = \case
where -- EId n | Set.member n localVars -> (Set.singleton n, AId n)
binders_frees = Set.delete name $ freeVarsOf rhs' -- | otherwise -> (mempty, AId n)
e_free = Set.delete name $ freeVarsOf e'
rhs' = freeVarsExp e_localVars rhs -- ELit _ (LInt i) -> (mempty, AInt i)
new_bind = ABind name parms rhs'
e' = freeVarsExp e_localVars e -- EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2')
e_localVars = Set.insert name localVars -- where
-- e1' = freeVarsExp localVars e1
-- e2' = freeVarsExp localVars e2
(ECase t e cs) -> do -- EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2')
let e' = freeVarsExp localVars e -- where
let vars = freeVarsOf e' -- e1' = freeVarsExp localVars e1
let (vars', cs') = foldr (\(_, Case c e) (vars,acc) -> do -- e2' = freeVarsExp localVars e2
let e' = freeVarsExp vars e
let vars' = freeVarsOf e' -- EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e')
(Set.union vars vars', AnnCase c e' : acc) -- where
) (vars, []) cs -- e' = freeVarsExp (Set.insert par localVars) e
(vars', ACase t e' (reverse cs'))
-- -- Sum free variables present in bind and the expression
-- ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e')
-- where
-- binders_frees = Set.delete name $ freeVarsOf rhs'
-- e_free = Set.delete name $ freeVarsOf e'
-- rhs' = freeVarsExp e_localVars rhs
-- new_bind = ABind name parms rhs'
-- e' = freeVarsExp e_localVars e
-- e_localVars = Set.insert name localVars
freeVarsOf :: AnnExp -> Set Id --freeVarsOf :: AnnExp -> Set Id
freeVarsOf = fst --freeVarsOf = fst
-- AST annotated with free variables ---- AST annotated with free variables
type AnnProgram = [(Id, [Id], AnnExp)] --type AnnProgram = [(Id, [Id], AnnExp)]
type AnnExp = (Set Id, AnnExp') --type AnnExp = (Set Id, AnnExp')
data ABind = ABind Id [Id] AnnExp deriving Show --data ABind = ABind Id [Id] AnnExp deriving Show
data AnnExp' = AId Id --data AnnExp' = AId Id
| AInt Integer -- | AInt Integer
| ALet ABind AnnExp -- | ALet ABind AnnExp
| AApp Type AnnExp AnnExp -- | AApp Type AnnExp AnnExp
| AAdd Type AnnExp AnnExp -- | AAdd Type AnnExp AnnExp
| ASub Type AnnExp AnnExp -- | AAbs Type Id AnnExp
| AAbs Type Id AnnExp -- deriving Show
| ACase Type AnnExp [AnnCase] ---- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
deriving Show ---- Free variables are @v₁ v₂ .. vₙ@ are bound.
data AnnCase = AnnCase GA.Case AnnExp --abstract :: AnnProgram -> Program
deriving Show --abstract prog = Program $ evalState (mapM go prog) 0
-- where
-- go :: (Id, [Id], AnnExp) -> State Int Bind
-- go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs'
-- where
-- (rhs', parms1) = flattenLambdasAnn rhs
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@. ---- | Flatten nested lambdas and collect the parameters
-- Free variables are @v₁ v₂ .. vₙ@ are bound. ---- @\x.\y.\z. ae → (ae, [x,y,z])@
abstract :: AnnProgram -> Program --flattenLambdasAnn :: AnnExp -> (AnnExp, [Id])
abstract prog = Program $ evalState (mapM go prog) 0 --flattenLambdasAnn ae = go (ae, [])
where -- where
go :: (Id, [Id], AnnExp) -> State Int Bind -- go :: (AnnExp, [Id]) -> (AnnExp, [Id])
go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs' -- go ((free, e), acc) =
where -- case e of
(rhs', parms1) = flattenLambdasAnn rhs -- AAbs _ par (free1, e1) ->
-- go ((Set.delete par free1, e1), snoc par acc)
-- _ -> ((free, e), acc)
--abstractExp :: AnnExp -> State Int Exp
--abstractExp (free, exp) = case exp of
-- AId n -> pure $ EId n
-- AInt i -> pure $ ELit (TMono "Int") (LInt i)
-- AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2)
-- AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2)
-- ALet b e -> liftA2 ELet (go b) (abstractExp e)
-- where
-- go (ABind name parms rhs) = do
-- (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
-- pure $ Bind name (parms ++ parms1) rhs'
-- skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp
-- skipLambdas f (free, ae) = case ae of
-- AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1
-- _ -> f (free, ae)
-- -- Lift lambda into let and bind free variables
-- AAbs t parm e -> do
-- i <- nextNumber
-- rhs <- abstractExp e
-- let sc_name = Ident ("sc_" ++ show i)
-- sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t)
-- pure $ foldl (EApp $ TMono "Int") sc $ map EId freeList
-- where
-- freeList = Set.toList free
-- parms = snoc parm freeList
-- | Flatten nested lambdas and collect the parameters --nextNumber :: State Int Int
-- @\x.\y.\z. ae → (ae, [x,y,z])@ --nextNumber = do
flattenLambdasAnn :: AnnExp -> (AnnExp, [Id]) -- i <- get
flattenLambdasAnn ae = go (ae, []) -- put $ succ i
where -- pure i
go :: (AnnExp, [Id]) -> (AnnExp, [Id])
go ((free, e), acc) =
case e of
AAbs _ par (free1, e1) ->
go ((Set.delete par free1, e1), snoc par acc)
_ -> ((free, e), acc)
abstractExp :: AnnExp -> State Int Exp ---- | Collects supercombinators by lifting non-constant let expressions
abstractExp (free, exp) = case exp of --collectScs :: Program -> Program
AId n -> pure $ EId n --collectScs (Program scs) = Program $ concatMap collectFromRhs scs
AInt i -> pure $ EInt i -- where
AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2) -- collectFromRhs (Bind name parms rhs) =
AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2) -- let (rhs_scs, rhs') = collectScsExp rhs
ASub t e1 e2 -> liftA2 (ESub t) (abstractExp e1) (abstractExp e2) -- in Bind name parms rhs' : rhs_scs
ALet b e -> liftA2 ELet (go b) (abstractExp e)
where
go (ABind name parms rhs) = do
(rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
pure $ Bind name (parms ++ parms1) rhs'
skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp
skipLambdas f (free, ae) = case ae of
AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1
_ -> f (free, ae)
ACase t e cs -> do --collectScsExp :: Exp -> ([Bind], Exp)
e' <- abstractExp e --collectScsExp = \case
cs' <- mapM (\(AnnCase c e) -> do -- EId n -> ([], EId n)
e' <- abstractExp e -- ELit _ (LInt i) -> ([], ELit (TMono "Int") (LInt i))
pure (t,Case c e')) cs
pure $ ECase t e' cs'
-- Lift lambda into let and bind free variables -- EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2')
AAbs t parm e -> do -- where
i <- nextNumber -- (scs1, e1') = collectScsExp e1
rhs <- abstractExp e -- (scs2, e2') = collectScsExp e2
let sc_name = Ident ("sc_" ++ show i) -- EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2')
sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t) -- where
-- (scs1, e1') = collectScsExp e1
-- (scs2, e2') = collectScsExp e2
pure $ foldl (EApp TInt) sc $ map EId freeList -- EAbs t par e -> (scs, EAbs t par e')
where -- where
freeList = Set.toList free -- (scs, e') = collectScsExp e
parms = snoc parm freeList
-- -- Collect supercombinators from bind, the rhss, and the expression.
-- --
-- -- > f = let sc x y = rhs in e
-- --
-- ELet (Bind name parms rhs) e -> if null parms
-- then ( rhs_scs ++ e_scs, ELet bind e')
-- else (bind : rhs_scs ++ e_scs, e')
-- where
-- bind = Bind name parms rhs'
-- (rhs_scs, rhs') = collectScsExp rhs
-- (e_scs, e') = collectScsExp e
nextNumber :: State Int Int ---- @\x.\y.\z. e → (e, [x,y,z])@
nextNumber = do --flattenLambdas :: Exp -> (Exp, [Id])
i <- get --flattenLambdas = go . (, [])
put $ succ i -- where
pure i -- go (e, acc) = case e of
-- EAbs _ par e1 -> go (e1, snoc par acc)
-- _ -> (e, acc)
-- | Collects supercombinators by lifting non-constant let expressions
collectScs :: Program -> Program
collectScs (Program scs) = Program $ concatMap collectFromRhs scs
where
collectFromRhs (Bind name parms rhs) =
let (rhs_scs, rhs') = collectScsExp rhs
in Bind name parms rhs' : rhs_scs
collectScsExp :: Exp -> ([Bind], Exp)
collectScsExp = \case
EId n -> ([], EId n)
EInt i -> ([], EInt i)
EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2')
where
(scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2
EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2')
where
(scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2
ESub t e1 e2 -> (scs1 ++ scs2, ESub t e1' e2')
where
(scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2
EAbs t par e -> (scs, EAbs t par e')
where
(scs, e') = collectScsExp e
-- Collect supercombinators from bind, the rhss, and the expression.
--
-- > f = let sc x y = rhs in e
--
ELet (Bind name parms rhs) e -> if null parms
then ( rhs_scs ++ e_scs, ELet bind e')
else (bind : rhs_scs ++ e_scs, e')
where
bind = Bind name parms rhs'
(rhs_scs, rhs') = collectScsExp rhs
(e_scs, e') = collectScsExp e
ECase t e cs -> do
let (scs, e') = collectScsExp e
let (scs',cs') = foldr (\(t, Case c e) (scs, acc) -> do
let (scs', e') = collectScsExp e
(scs ++ scs', (t,Case c e') : acc)
) (scs,[]) cs
(scs', ECase t e' cs')
-- @\x.\y.\z. e → (e, [x,y,z])@
flattenLambdas :: Exp -> (Exp, [Id])
flattenLambdas = go . (, [])
where
go (e, acc) = case e of
EAbs _ par e1 -> go (e1, snoc par acc)
_ -> (e, acc)

View file

@ -2,7 +2,7 @@
module Main where module Main where
import Codegen.Codegen (generateCode) --import Codegen.Codegen (generateCode)
import GHC.IO.Handle.Text (hPutStrLn) import GHC.IO.Handle.Text (hPutStrLn)
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
import Grammar.Par (myLexer, pProgram) import Grammar.Par (myLexer, pProgram)
@ -11,7 +11,7 @@ import Grammar.Print (printTree)
-- import Interpreter (interpret) -- import Interpreter (interpret)
import Control.Monad (when) import Control.Monad (when)
import Data.List.Extra (isSuffixOf) import Data.List.Extra (isSuffixOf)
import LambdaLifter.LambdaLifter (lambdaLift) --import LambdaLifter.LambdaLifter (lambdaLift)
import Renamer.Renamer (rename) import Renamer.Renamer (rename)
import System.Directory (createDirectory, doesPathExist, import System.Directory (createDirectory, doesPathExist,
getDirectoryContents, getDirectoryContents,
@ -46,19 +46,19 @@ main' debug s = do
typechecked <- fromTypeCheckerErr $ typecheck renamed typechecked <- fromTypeCheckerErr $ typecheck renamed
printToErr $ printTree typechecked printToErr $ printTree typechecked
printToErr "\n-- Lambda Lifter --" -- printToErr "\n-- Lambda Lifter --"
let lifted = lambdaLift typechecked -- let lifted = lambdaLift typechecked
printToErr $ printTree lifted -- printToErr $ printTree lifted
--
printToErr "\n -- Printing compiler output to stdout --" -- printToErr "\n -- Printing compiler output to stdout --"
compiled <- fromCompilerErr $ generateCode lifted -- compiled <- fromCompilerErr $ generateCode lifted
--putStrLn compiled --putStrLn compiled
check <- doesPathExist "output" -- check <- doesPathExist "output"
when check (removeDirectoryRecursive "output") -- when check (removeDirectoryRecursive "output")
createDirectory "output" -- createDirectory "output"
writeFile "output/llvm.ll" compiled -- writeFile "output/llvm.ll" compiled
if debug then debugDotViz else putStrLn compiled -- if debug then debugDotViz else putStrLn compiled
-- interpred <- fromInterpreterErr $ interpret lifted -- interpred <- fromInterpreterErr $ interpret lifted

View file

@ -1,29 +1,31 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
module Renamer.Renamer (module Renamer.Renamer) where module Renamer.Renamer where
import Auxiliary (mapAccumM) import Auxiliary (mapAccumM)
import Control.Monad (foldM)
import Control.Monad.State (MonadState, State, evalState, gets, import Control.Monad.State (MonadState, State, evalState, gets,
modify) modify)
import Data.List (foldl')
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as Map import qualified Data.Map as Map
import Data.Maybe (fromMaybe) import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe) import Data.Tuple.Extra (dupe)
import Grammar.Abs import Grammar.Abs
-- | Rename all variables and local binds -- | Rename all variables and local binds
rename :: Program -> Program rename :: Program -> Program
rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0 rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0
where where
initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs -- initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs
renameSc :: Names -> Bind -> Rn Bind initNames = Map.fromList $ foldl' saveIfBind [] bs
renameSc old_names (Bind name t _ parms rhs) = do saveIfBind acc (DBind (Bind name _ _ _ _)) = dupe name : acc
saveIfBind acc _ = acc
renameSc :: Names -> Def -> Rn Def
renameSc old_names (DBind (Bind name t _ parms rhs)) = do
(new_names, parms') <- newNames old_names parms (new_names, parms') <- newNames old_names parms
rhs' <- snd <$> renameExp new_names rhs rhs' <- snd <$> renameExp new_names rhs
pure $ Bind name t name parms' rhs' pure . DBind $ Bind name t name parms' rhs'
renameSc _ def = pure def
-- | Rename monad. State holds the number of renamed names. -- | Rename monad. State holds the number of renamed names.
newtype Rn a = Rn {runRn :: State Int a} newtype Rn a = Rn {runRn :: State Int a}
@ -42,45 +44,44 @@ renameLocalBind old_names (Bind name t _ parms rhs) = do
renameExp :: Names -> Exp -> Rn (Names, Exp) renameExp :: Names -> Exp -> Rn (Names, Exp)
renameExp old_names = \case renameExp old_names = \case
EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names) EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names)
ELit (LInt i1) -> pure (old_names, ELit (LInt i1))
EInt i1 -> pure (old_names, EInt i1)
EApp e1 e2 -> do EApp e1 e2 -> do
(env1, e1') <- renameExp old_names e1 (env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2 (env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EApp e1' e2') pure (Map.union env1 env2, EApp e1' e2')
EAdd e1 e2 -> do EAdd e1 e2 -> do
(env1, e1') <- renameExp old_names e1 (env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2 (env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EAdd e1' e2') pure (Map.union env1 env2, EAdd e1' e2')
ESub e1 e2 -> do ESub e1 e2 -> do
(env1, e1') <- renameExp old_names e1 (env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2 (env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, ESub e1' e2') pure (Map.union env1 env2, ESub e1' e2')
ELet i e1 e2 -> do
ELet b e -> do (new_names, e1') <- renameExp old_names e1
(new_names, b) <- renameLocalBind old_names b (new_names', e2') <- renameExp new_names e2
(new_names', e') <- renameExp new_names e pure (new_names', ELet i e1' e2')
pure (new_names', ELet b e') EAbs par e -> do
EAbs par t e -> do
(new_names, par') <- newName old_names par (new_names, par') <- newName old_names par
(new_names', e') <- renameExp new_names e (new_names', e') <- renameExp new_names e
pure (new_names', EAbs par' t e') pure (new_names', EAbs par' e')
EAnn e t -> do EAnn e t -> do
(new_names, e') <- renameExp old_names e (new_names, e') <- renameExp old_names e
pure (new_names, EAnn e' t) pure (new_names, EAnn e' t)
ECase e injs -> do
(_, e') <- renameExp old_names e
(new_names, injs') <- renameInjs old_names injs
pure (new_names, ECase e' injs')
ECase e cs t -> do renameInjs :: Names -> [Inj] -> Rn (Names, [Inj])
(new_names, e') <- renameExp old_names e renameInjs ns xs = do
(new_names', cs') <- foldM (\(names, stack) (CaseMatch c exp) -> do (new_names, xs') <- unzip <$> mapM (renameInj ns) xs
(nm,exp') <- renameExp names exp if null new_names then return (mempty, xs') else return (head new_names, xs')
pure (nm,CaseMatch c exp' : stack)
) (new_names, []) cs renameInj :: Names -> Inj -> Rn (Names, Inj)
pure (new_names', ECase e' cs' t) renameInj ns (Inj init e) = do
(new_names, e') <- renameExp ns e
return (new_names, Inj init e')
-- | Create a new name and add it to name environment. -- | Create a new name and add it to name environment.
newName :: Names -> Ident -> Rn (Names, Ident) newName :: Names -> Ident -> Rn (Names, Ident)
@ -95,4 +96,3 @@ newNames = mapAccumM newName
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@. -- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
makeName :: Ident -> Rn Ident makeName :: Ident -> Rn Ident
makeName (Ident prefix) = gets (\i -> Ident $ prefix ++ "_" ++ show i) <* modify succ makeName (Ident prefix) = gets (\i -> Ident $ prefix ++ "_" ++ show i) <* modify succ

View file

@ -1,215 +1,517 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-} {-# LANGUAGE OverloadedStrings #-}
module TypeChecker.TypeChecker (typecheck, partitionType) where -- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeChecker where
import Auxiliary (maybeToRightM, snoc) import Control.Monad.Except
import Control.Monad.Except (throwError, unless) import Control.Monad.Reader
import Control.Monad.State
import Data.Foldable (traverse_)
import Data.Functor.Identity (runIdentity)
import Data.List (foldl')
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as Map import qualified Data.Map as M
import Data.Set (Set)
import qualified Data.Set as S
import Debug.Trace (trace)
import Grammar.Abs import Grammar.Abs
import Grammar.ErrM (Err) import Grammar.Print (printTree)
import Grammar.Print (Print (prt), concatD, doc,
printTree, render)
import Prelude hiding (exp, id)
import qualified TypeChecker.TypeCheckerIr as T import qualified TypeChecker.TypeCheckerIr as T
import TypeChecker.TypeCheckerIr (Ctx (..), Env (..), Error, Infer,
Poly (..), Subst)
-- NOTE: this type checker is poorly tested initCtx = Ctx mempty
-- TODO initEnv = Env 0 mempty mempty
-- Coercion
-- Type inference
data Cxt = Cxt runPretty :: Exp -> Either Error String
{ env :: Map Ident Type -- ^ Local scope signature runPretty = fmap (printTree . fst) . run . inferExp
, sig :: Map Ident Type -- ^ Top-level signatures
}
initCxt :: [Bind] -> Cxt run :: Infer a -> Either Error a
initCxt sc = Cxt { env = mempty run = runC initEnv initCtx
, sig = Map.fromList $ map (\(Bind n t _ _ _) -> (n, t)) sc
}
typecheck :: Program -> Err T.Program runC :: Env -> Ctx -> Infer a -> Either Error a
typecheck (Program sc) = T.Program <$> mapM (checkBind $ initCxt sc) sc runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e
-- | Check if infered rhs type matches type signature. typecheck :: Program -> Either Error T.Program
checkBind :: Cxt -> Bind -> Err T.Bind typecheck = run . checkPrg
checkBind cxt b =
case expandLambdas b of
Bind name t _ parms rhs -> do
(rhs', t_rhs) <- infer cxt rhs
unless (typeEq t_rhs t) . throwError $ typeErr name t t_rhs
pure $ T.Bind (name, t) (zip parms ts_parms) rhs'
where
ts_parms = fst $ partitionType (length parms) t
-- | @ f x y = rhs ⇒ f = \x.\y. rhs @ {- | Start by freshening the type variable of data types to avoid clash with
expandLambdas :: Bind -> Bind other user defined polymorphic types
expandLambdas (Bind name t _ parms rhs) = Bind name t name [] rhs' This might be wrong for type constructors that work over several variables
where -}
rhs' = foldr ($) rhs $ zipWith EAbs parms ts_parms freshenData :: Data -> Infer Data
ts_parms = fst $ partitionType (length parms) t freshenData (Data (Constr name ts) constrs) = do
fr <- fresh
let fr' = case fr of
TPol a -> a
-- Meh, this part assumes fresh generates a polymorphic type
_ ->
error
"Bug: implementation of \
\ fresh and freshenData are not compatible"
let new_ts = map (freshenType fr') ts
let new_constrs = map (freshenConstr fr') constrs
return $ Data (Constr name new_ts) new_constrs
-- | Infer type of expression. {- | Freshen all polymorphic variables, regardless of name
infer :: Cxt -> Exp -> Err (T.Exp, Type) | freshenType "d" (a -> b -> c) becomes (d -> d -> d)
infer cxt = \case -}
EId x -> freshenType :: Ident -> Type -> Type
case lookupEnv x cxt of freshenType iden = \case
Nothing -> (TPol _) -> TPol iden
case lookupSig x cxt of (TArr a b) -> TArr (freshenType iden a) (freshenType iden b)
Nothing -> throwError ("Unbound variable:" ++ printTree x) (TConstr (Constr a ts)) ->
Just t -> pure (T.EId (x, t), t) TConstr (Constr a (map (freshenType iden) ts))
Just t -> pure (T.EId (x, t), t) rest -> rest
EInt i -> pure (T.EInt i, T.TInt) freshenConstr :: Ident -> Constructor -> Constructor
freshenConstr iden (Constructor name t) =
Constructor name (freshenType iden t)
EApp e e1 -> do checkData :: Data -> Infer ()
(e', t) <- infer cxt e checkData d = do
case t of d' <- freshenData d
TFun t1 t2 -> do case d' of
e1' <- check cxt e1 t1 (Data typ@(Constr name ts) constrs) -> do
pure (T.EApp t2 e' e1', t2) unless
_ -> do (all isPoly ts)
throwError ("Not a function: " ++ show e) (throwError $ unwords ["Data type incorrectly declared"])
traverse_
EAdd e e1 -> do ( \(Constructor name' t') ->
e' <- check cxt e T.TInt if TConstr typ == retType t'
e1' <- check cxt e1 T.TInt then insertConstr name' t'
pure (T.EAdd T.TInt e' e1', T.TInt) else
throwError $
ESub e e1 -> do unwords
e' <- check cxt e T.TInt [ "return type of constructor:"
e1' <- check cxt e1 T.TInt , printTree name
pure (T.ESub T.TInt e' e1', T.TInt) , "with type:"
, printTree (retType t')
EAbs x t e -> do , "does not match data: "
(e', t1) <- infer (insertEnv x t cxt) e , printTree typ
let t_abs = TFun t t1
pure (T.EAbs t_abs (x, t) e', t_abs)
ELet b e -> do
let cxt' = insertBind b cxt
b' <- checkBind cxt' b
(e', t) <- infer cxt' e
pure (T.ELet b' e', t)
EAnn e t -> do
(e', t1) <- infer cxt e
unless (typeEq t t1) $
throwError "Inferred type and type annotation doesn't match"
pure (e', t1)
ECase e cs t -> do
(e',t1) <- infer cxt e
unless (typeEq t t1) $
throwError "Inferred type and type annotation doesn't match"
case traverse (\(CaseMatch c e) -> do
-- //TODO check c as well
e' <- check cxt e t
unless (typeEq t t1) $
throwError "Inferred type and type annotation doesn't match"
pure (t1, T.Case c e')
) cs of
Right cs -> pure (T.ECase t1 e' cs,t1)
Left e -> throwError e
-- | Check infered type matches the supplied type.
check :: Cxt -> Exp -> Type -> Err T.Exp
check cxt exp typ = case exp of
EId x -> do
t <- case lookupEnv x cxt of
Nothing -> maybeToRightM
("Unbound variable:" ++ printTree x)
(lookupSig x cxt)
Just t -> pure t
unless (typeEq t typ) . throwError $ typeErr x typ t
pure $ T.EId (x, t)
EInt i -> do
unless (typeEq typ TInt) $ throwError $ typeErr i TInt typ
pure $ T.EInt i
EApp e e1 -> do
(e', t) <- infer cxt e
case t of
TFun t1 t2 -> do
e1' <- check cxt e1 t1
pure $ T.EApp t2 e' e1'
_ -> throwError ("Not a function 2: " ++ printTree e)
EAdd e e1 -> do
e' <- check cxt e T.TInt
e1' <- check cxt e1 T.TInt
pure $ T.EAdd T.TInt e' e1'
ESub e e1 -> do
e' <- check cxt e T.TInt
e1' <- check cxt e1 T.TInt
pure $ T.ESub T.TInt e' e1'
EAbs x t e -> do
(e', t_e) <- infer (insertEnv x t cxt) e
let t1 = TFun t t_e
unless (typeEq t1 typ) $ throwError "Wrong lamda type!"
pure $ T.EAbs t1 (x, t) e'
ECase e cs t -> do
(e',t1) <- infer cxt e
unless (typeEq t t1) $
throwError "Inferred type and type annotation doesn't match"
case traverse (\(CaseMatch c e) -> do
-- //TODO check c as well
e' <- check cxt e t
unless (typeEq t t1) $
throwError "Inferred type and type annotation doesn't match"
pure (t1, T.Case c e')
) cs of
Right cs -> pure $ T.ECase t1 e' cs
Left e -> throwError e
ELet b e -> do
let cxt' = insertBind b cxt
b' <- checkBind cxt' b
e' <- check cxt' e typ
pure $ T.ELet b' e'
EAnn e t -> do
unless (typeEq t typ) $
throwError "Inferred type and type annotation doesn't match"
check cxt e t
-- | Check if types are equivalent. Doesn't handle coercion or polymorphism.
typeEq :: Type -> Type -> Bool
typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1
typeEq t t1 = t == t1
-- | Partion type into types of parameters and return type.
partitionType :: Int -- Number of parameters to apply
-> Type
-> ([Type], Type)
partitionType = go []
where
go acc 0 t = (acc, t)
go acc i t = case t of
TFun t1 t2 -> go (snoc t1 acc) (i - 1) t2
_ -> error "Number of parameters and type doesn't match"
insertBind :: Bind -> Cxt -> Cxt
insertBind (Bind n t _ _ _) = insertEnv n t
lookupEnv :: Ident -> Cxt -> Maybe Type
lookupEnv x = Map.lookup x . env
insertEnv :: Ident -> Type -> Cxt -> Cxt
insertEnv x t cxt = cxt { env = Map.insert x t cxt.env }
lookupSig :: Ident -> Cxt -> Maybe Type
lookupSig x = Map.lookup x . sig
typeErr :: Print a => a -> Type -> Type -> String
typeErr p expected actual = render $ concatD
[ doc $ showString "Wrong type:", prt 0 p , doc $ showString "\n"
, doc $ showString "Expected:" , prt 0 expected, doc $ showString "\n"
, doc $ showString "Actual: " , prt 0 actual
] ]
)
constrs
retType :: Type -> Type
retType (TArr _ t2) = retType t2
retType a = a
checkPrg :: Program -> Infer T.Program
checkPrg (Program bs) = do
preRun bs
T.Program <$> checkDef bs
where
preRun :: [Def] -> Infer ()
preRun [] = return ()
preRun (x : xs) = case x of
DBind (Bind n t _ _ _) -> insertSig n t >> preRun xs
DData d@(Data _ _) -> checkData d >> preRun xs
checkDef :: [Def] -> Infer [T.Def]
checkDef [] = return []
checkDef (x : xs) = case x of
(DBind b) -> do
b' <- checkBind b
fmap (T.DBind b' :) (checkDef xs)
(DData d) -> fmap (T.DData d :) (checkDef xs)
checkBind :: Bind -> Infer T.Bind
checkBind (Bind n t _ args e) = do
(t', e') <- inferExp $ makeLambda e (reverse args)
s <- unify t t'
let t'' = apply s t
unless
(t `typeEq` t'')
( throwError $
unwords
[ "Top level signature"
, printTree t
, "does not match body with inferred type:"
, printTree t''
]
)
return $ T.Bind (n, t) e'
where
makeLambda :: Exp -> [Ident] -> Exp
makeLambda = foldl (flip EAbs)
{- | Check if two types are considered equal
For the purpose of the algorithm two polymorphic types are always considered
equal
-}
typeEq :: Type -> Type -> Bool
typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r'
typeEq (TMono a) (TMono b) = a == b
typeEq (TConstr (Constr name a)) (TConstr (Constr name' b)) =
length a == length b
&& name == name'
&& and (zipWith typeEq a b)
typeEq (TPol _) (TPol _) = True
typeEq _ _ = False
isMoreSpecificOrEq :: Type -> Type -> Bool
isMoreSpecificOrEq _ (TPol _) = True
isMoreSpecificOrEq (TArr a b) (TArr c d) =
isMoreSpecificOrEq a c && isMoreSpecificOrEq b d
isMoreSpecificOrEq (TConstr (Constr n1 ts1)) (TConstr (Constr n2 ts2)) =
n1 == n2
&& length ts1 == length ts2
&& and (zipWith isMoreSpecificOrEq ts1 ts2)
isMoreSpecificOrEq a b = a == b
isPoly :: Type -> Bool
isPoly (TPol _) = True
isPoly _ = False
inferExp :: Exp -> Infer (Type, T.Exp)
inferExp e = do
(s, t, e') <- algoW e
let subbed = apply s t
return (subbed, replace subbed e')
replace :: Type -> T.Exp -> T.Exp
replace t = \case
T.ELit _ e -> T.ELit t e
T.EId (n, _) -> T.EId (n, t)
T.EAbs _ name e -> T.EAbs t name e
T.EApp _ e1 e2 -> T.EApp t e1 e2
T.EAdd _ e1 e2 -> T.EAdd t e1 e2
T.ESub _ e1 e2 -> T.ESub t e1 e2
T.ELet (T.Bind (n, _) e1) e2 -> T.ELet (T.Bind (n, t) e1) e2
T.ECase _ expr injs -> T.ECase t expr injs
algoW :: Exp -> Infer (Subst, Type, T.Exp)
algoW = \case
-- \| TODO: More testing need to be done. Unsure of the correctness of this
EAnn e t -> do
(s1, t', e') <- algoW e
unless
(t `isMoreSpecificOrEq` t')
( throwError $
unwords
[ "Annotated type:"
, printTree t
, "does not match inferred type:"
, printTree t'
]
)
applySt s1 $ do
s2 <- unify t t'
return (s2 `compose` s1, t, e')
-- \| ------------------
-- \| Γ ⊢ i : Int, ∅
ELit (LInt n) ->
return (nullSubst, TMono "Int", T.ELit (TMono "Int") (LInt n))
ELit a -> error $ "NOT IMPLEMENTED YET: ELit " ++ show a
-- \| x : σ ∈ Γ τ = inst(σ)
-- \| ----------------------
-- \| Γ ⊢ x : τ, ∅
EId i -> do
var <- asks vars
case M.lookup i var of
Just t -> inst t >>= \x -> return (nullSubst, x, T.EId (i, x))
Nothing -> do
sig <- gets sigs
case M.lookup i sig of
Just t -> return (nullSubst, t, T.EId (i, t))
Nothing -> do
constr <- gets constructors
case M.lookup i constr of
Just t -> return (nullSubst, t, T.EId (i, t))
Nothing ->
throwError $
"Unbound variable: " ++ show i
-- \| τ = newvar Γ, x : τ ⊢ e : τ', S
-- \| ---------------------------------
-- \| Γ ⊢ w λx. e : Sτ → τ', S
EAbs name e -> do
fr <- fresh
withBinding name (Forall [] fr) $ do
(s1, t', e') <- algoW e
let varType = apply s1 fr
let newArr = TArr varType t'
return (s1, newArr, T.EAbs newArr (name, varType) e')
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
-- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
-- \| ------------------------------------------
-- \| Γ ⊢ e₀ + e₁ : Int, S₃S₂S₁S₀
-- This might be wrong
EAdd e0 e1 -> do
(s1, t0, e0') <- algoW e0
applySt s1 $ do
(s2, t1, e1') <- algoW e1
-- applySt s2 $ do
s3 <- unify (apply s2 t0) (TMono "Int")
s4 <- unify (apply s3 t1) (TMono "Int")
return
( s4 `compose` s3 `compose` s2 `compose` s1
, TMono "Int"
, T.EAdd (TMono "Int") e0' e1'
)
ESub e0 e1 -> do
(s1, t0, e0') <- algoW e0
applySt s1 $ do
(s2, t1, e1') <- algoW e1
-- applySt s2 $ do
s3 <- unify (apply s2 t0) (TMono "Int")
s4 <- unify (apply s3 t1) (TMono "Int")
return
( s4 `compose` s3 `compose` s2 `compose` s1
, TMono "Int"
, T.ESub (TMono "Int") e0' e1'
)
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1
-- \| τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ')
-- \| --------------------------------------
-- \| Γ ⊢ e₀ e₁ : S₂τ', S₂S₁S₀
EApp e0 e1 -> do
fr <- fresh
(s0, t0, e0') <- algoW e0
applySt s0 $ do
(s1, t1, e1') <- algoW e1
-- applySt s1 $ do
s2 <- unify (apply s1 t0) (TArr t1 fr)
let t = apply s2 fr
return (s2 `compose` s1 `compose` s0, t, T.EApp t e0' e1')
-- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁
-- \| ----------------------------------------------
-- \| Γ ⊢ let x = e₀ in e₁ : τ', S₁S₀
-- The bar over S₀ and Γ means "generalize"
ELet name e0 e1 -> do
(s1, t1, e0') <- algoW e0
env <- asks vars
let t' = generalize (apply s1 env) t1
withBinding name t' $ do
(s2, t2, e1') <- algoW e1
return (s2 `compose` s1, t2, T.ELet (T.Bind (name, t2) e0') e1')
ECase caseExpr injs -> do
(_, t0, e0') <- algoW caseExpr
(injs', ts) <- mapAndUnzipM (checkInj t0) injs
case ts of
[] -> throwError "Case expression missing any matches"
ts -> do
unified <- zipWithM unify ts (tail ts)
let unified' = foldl' compose mempty unified
let typ = apply unified' (head ts)
return (unified', typ, T.ECase typ e0' injs')
-- | Unify two types producing a new substitution
unify :: Type -> Type -> Infer Subst
unify t0 t1 = do
trace ("t0: " ++ show t0) return ()
trace ("t1: " ++ show t1) return ()
case (t0, t1) of
(TArr a b, TArr c d) -> do
s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d)
return $ s1 `compose` s2
(TPol a, b) -> occurs a b
(a, TPol b) -> occurs b a
(TMono a, TMono b) ->
if a == b then return M.empty else throwError "Types do not unify"
-- \| TODO: Figure out a cleaner way to express the same thing
(TConstr (Constr name t), TConstr (Constr name' t')) ->
if name == name' && length t == length t'
then do
xs <- zipWithM unify t t'
return $ foldr compose nullSubst xs
else
throwError $
unwords
[ "Type constructor:"
, printTree name
, "(" ++ printTree t ++ ")"
, "does not match with:"
, printTree name'
, "(" ++ printTree t' ++ ")"
]
(a, b) ->
throwError . unwords $
[ "Type:"
, printTree a
, "can't be unified with:"
, printTree b
]
{- | Check if a type is contained in another type.
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
such that these are equal
-}
occurs :: Ident -> Type -> Infer Subst
occurs _ (TPol _) = return nullSubst
occurs i t =
if S.member i (free t)
then
throwError $
unwords
[ "Occurs check failed, can't unify"
, printTree (TPol i)
, "with"
, printTree t
]
else return $ M.singleton i t
-- | Generalize a type over all free variables in the substitution set
generalize :: Map Ident Poly -> Type -> Poly
generalize env t = Forall (S.toList $ free t S.\\ free env) t
{- | Instantiate a polymorphic type. The free type variables are substituted
with fresh ones.
-}
inst :: Poly -> Infer Type
inst (Forall xs t) = do
xs' <- mapM (const fresh) xs
let s = M.fromList $ zip xs xs'
return $ apply s t
-- | Compose two substitution sets
compose :: Subst -> Subst -> Subst
compose m1 m2 = M.map (apply m1) m2 `M.union` m1
-- | A class representing free variables functions
class FreeVars t where
-- | Get all free variables from t
free :: t -> Set Ident
-- | Apply a substitution to t
apply :: Subst -> t -> t
instance FreeVars Type where
free :: Type -> Set Ident
free (TPol a) = S.singleton a
free (TMono _) = mempty
free (TArr a b) = free a `S.union` free b
-- \| Not guaranteed to be correct
free (TConstr (Constr _ a)) =
foldl' (\acc x -> free x `S.union` acc) S.empty a
apply :: Subst -> Type -> Type
apply sub t = do
case t of
TMono a -> TMono a
TPol a -> case M.lookup a sub of
Nothing -> TPol a
Just t -> t
TArr a b -> TArr (apply sub a) (apply sub b)
TConstr (Constr name a) -> TConstr (Constr name (map (apply sub) a))
instance FreeVars Poly where
free :: Poly -> Set Ident
free (Forall xs t) = free t S.\\ S.fromList xs
apply :: Subst -> Poly -> Poly
apply s (Forall xs t) = Forall xs (apply (foldr M.delete s xs) t)
instance FreeVars (Map Ident Poly) where
free :: Map Ident Poly -> Set Ident
free m = foldl' S.union S.empty (map free $ M.elems m)
apply :: Subst -> Map Ident Poly -> Map Ident Poly
apply s = M.map (apply s)
-- | Apply substitutions to the environment.
applySt :: Subst -> Infer a -> Infer a
applySt s = local (\st -> st{vars = apply s (vars st)})
-- | Represents the empty substition set
nullSubst :: Subst
nullSubst = M.empty
-- | Generate a new fresh variable and increment the state counter
fresh :: Infer Type
fresh = do
n <- gets count
modify (\st -> st{count = n + 1})
return . TPol . Ident $ show n
-- | Run the monadic action with an additional binding
withBinding :: (Monad m, MonadReader Ctx m) => Ident -> Poly -> m a -> m a
withBinding i p = local (\st -> st{vars = M.insert i p (vars st)})
-- | Insert a function signature into the environment
insertSig :: Ident -> Type -> Infer ()
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
-- | Insert a constructor with its data type
insertConstr :: Ident -> Type -> Infer ()
insertConstr i t =
modify (\st -> st{constructors = M.insert i t (constructors st)})
-------- PATTERN MATCHING ---------
-- "case expr of", the type of 'expr' is caseType
checkInj :: Type -> Inj -> Infer (T.Inj, Type)
checkInj caseType (Inj it expr) = do
(args, t') <- initType caseType it
(_, t, e') <- local (\st -> st{vars = args `M.union` vars st}) (algoW expr)
return (T.Inj (it, t') e', t)
initType :: Type -> Init -> Infer (Map Ident Poly, Type)
initType expected = \case
InitLit lit ->
let returnType = litType lit
in if expected == returnType
then return (mempty, expected)
else
throwError $
unwords
[ "Inferred type"
, printTree returnType
, "does not match expected type:"
, printTree expected
]
InitConstr c args -> do
st <- gets constructors
case M.lookup c st of
Nothing ->
throwError $
unwords
[ "Constructor:"
, printTree c
, "does not exist"
]
Just t -> do
let flat = flattenType t
let returnType = last flat
case ( length (init flat) == length args
, returnType `isMoreSpecificOrEq` expected
) of
(True, True) ->
return
( M.fromList $ zip args (map (Forall []) flat)
, expected
)
(False, _) ->
throwError $
"Can't partially match on the constructor: "
++ printTree c
(_, False) ->
throwError $
unwords
[ "Inferred type"
, printTree returnType
, "does not match expected type:"
, printTree expected
]
InitCatch -> return (mempty, expected)
flattenType :: Type -> [Type]
flattenType (TArr a b) = flattenType a ++ flattenType b
flattenType a = [a]
litType :: Literal -> Type
litType (LInt _) = TMono "Int"

View file

@ -1,74 +1,105 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
module TypeChecker.TypeCheckerIr module TypeChecker.TypeCheckerIr where
( module Grammar.Abs
, module TypeChecker.TypeCheckerIr
) where
import Grammar.Abs (Ident (..), Type (..)) import Control.Monad.Except
import qualified Grammar.Abs as GA import Control.Monad.Reader
import Control.Monad.State
import Data.Functor.Identity (Identity)
import Data.Map (Map)
import Grammar.Abs (Data (..), Ident (..), Init (..),
Literal (..), Type (..))
import Grammar.Print import Grammar.Print
import Prelude import Prelude
import qualified Prelude as C (Eq, Ord, Read, Show) import qualified Prelude as C (Eq, Ord, Read, Show)
newtype Program = Program [Bind] -- | A data type representing type variables
data Poly = Forall [Ident] Type
deriving (Show)
newtype Ctx = Ctx {vars :: Map Ident Poly}
data Env = Env
{ count :: Int
, sigs :: Map Ident Type
, constructors :: Map Ident Type
}
type Error = String
type Subst = Map Ident Type
type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))
newtype Program = Program [Def]
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
data Exp data Exp
= EId Id = EId Id
| EInt Integer | ELit Type Literal
| ELet Bind Exp | ELet Bind Exp
| EApp Type Exp Exp | EApp Type Exp Exp
| EAdd Type Exp Exp | EAdd Type Exp Exp
| ESub Type Exp Exp | ESub Type Exp Exp
| EAbs Type Id Exp | EAbs Type Id Exp
| ECase Type Exp [(Type, Case)] | ECase Type Exp [Inj]
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Read, C.Show)
data Case = Case GA.Case Exp data Inj = Inj (Init, Type) Exp
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Read, C.Show)
data Def = DBind Bind | DData Data
deriving (C.Eq, C.Ord, C.Read, C.Show)
type Id = (Ident, Type) type Id = (Ident, Type)
data Bind = Bind Id [Id] Exp | DataStructure Ident [(Ident, [Type])] data Bind = Bind Id Exp
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
instance Print [Def] where
prt _ [] = concatD []
prt _ (x : xs) = concatD [prt 0 x, doc (showString "\n"), prt 0 xs]
instance Print Def where
prt i (DBind bind) = prt i bind
prt i (DData d) = prt i d
instance Print Program where instance Print Program where
prt i (Program sc) = prPrec i 0 $ prt 0 sc prt i (Program sc) = prPrec i 0 $ prt 0 sc
instance Print Bind where instance Print Bind where
prt i (Bind name@(n, _) parms rhs) = prPrec i 0 $ concatD prt i (Bind (t, name) rhs) =
[ prtId 0 name prPrec i 0 $
, doc $ showString ";" concatD
, prt 0 n [ prt 0 name
, prtIdPs 0 parms , doc $ showString ":"
, prt 0 t
, doc $ showString "\n"
, prt 0 name
, doc $ showString "=" , doc $ showString "="
, prt 0 rhs , prt 0 rhs
] ]
prt i (DataStructure (Ident n) xs) = prPrec i 0 $ concatD
[ prt 0 n
, doc $ showString "{"
, doc . showString . show $ xs
, doc $ showString "}"
]
instance Print [Bind] where instance Print [Bind] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x] prt _ [x] = concatD [prt 0 x]
prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), doc (showString "\n"), prt 0 xs]
prtIdPs :: Int -> [Id] -> Doc prtIdPs :: Int -> [Id] -> Doc
prtIdPs i = prPrec i 0 . concatD . map (prtIdP i) prtIdPs i = prPrec i 0 . concatD . map (prtIdP i)
prtId :: Int -> Id -> Doc prtId :: Int -> Id -> Doc
prtId i (name, t) = prPrec i 0 $ concatD prtId i (name, t) =
prPrec i 0 $
concatD
[ prt 0 name [ prt 0 name
, doc $ showString ":" , doc $ showString ":"
, prt 0 t , prt 0 t
] ]
prtIdP :: Int -> Id -> Doc prtIdP :: Int -> Id -> Doc
prtIdP i (name, t) = prPrec i 0 $ concatD prtIdP i (name, t) =
prPrec i 0 $
concatD
[ doc $ showString "(" [ doc $ showString "("
, prt 0 name , prt 0 name
, doc $ showString ":" , doc $ showString ":"
@ -76,64 +107,78 @@ prtIdP i (name, t) = prPrec i 0 $ concatD
, doc $ showString ")" , doc $ showString ")"
] ]
instance Print Exp where instance Print Exp where
prt i = \case prt i = \case
EId n -> prPrec i 3 $ concatD [prtIdP 0 n] EId n -> prPrec i 3 $ concatD [prtId 0 n, doc $ showString "\n"]
EInt i1 -> prPrec i 3 $ concatD [prt 0 i1] ELit _ (LInt i1) -> prPrec i 3 $ concatD [prt 0 i1, doc $ showString "\n"]
ELet bs e -> prPrec i 3 $ concatD ELet bs e ->
prPrec i 3 $
concatD
[ doc $ showString "let" [ doc $ showString "let"
, prt 0 bs , prt 0 bs
, doc $ showString "in" , doc $ showString "in"
, prt 0 e , prt 0 e
, doc $ showString "\n"
] ]
EApp t e1 e2 -> prPrec i 2 $ concatD EApp _ e1 e2 ->
[ doc $ showString "@" prPrec i 2 $
, prt 0 t concatD
, prt 2 e1 [ prt 2 e1
, prt 3 e2 , prt 3 e2
] ]
EAdd t e1 e2 -> prPrec i 1 $ concatD EAdd t e1 e2 ->
prPrec i 1 $
concatD
[ doc $ showString "@" [ doc $ showString "@"
, prt 0 t , prt 0 t
, prt 1 e1 , prt 1 e1
, doc $ showString "+" , doc $ showString "+"
, prt 2 e2 , prt 2 e2
, doc $ showString "\n"
] ]
ESub t e1 e2 -> prPrec i 1 $ concatD ESub t e1 e2 ->
prPrec i 1 $
concatD
[ doc $ showString "@" [ doc $ showString "@"
, prt 0 t , prt 0 t
, prt 1 e1 , prt 1 e1
, doc $ showString "-" , doc $ showString "-"
, prt 2 e2 , prt 2 e2
, doc $ showString "\n"
] ]
EAbs t n e -> prPrec i 0 $ concatD EAbs t n e ->
prPrec i 0 $
concatD
[ doc $ showString "@" [ doc $ showString "@"
, prt 0 t , prt 0 t
, doc $ showString "\\" , doc $ showString "\\"
, prtIdP 0 n , prtId 0 n
, doc $ showString "." , doc $ showString "."
, prt 0 e , prt 0 e
]
ECase t e cs -> prPrec i 0 $ concatD
[ doc $ showString "@"
, prt 0 t
, doc $ showString "("
, prt 0 e
, doc $ showString ")"
, prPrec i 0 $ concatD . printCases $ cs
]
where
printCases :: [(Type, Case)] -> [Doc]
printCases [] = []
printCases ((t, Case c e):xs) = concatD
[ doc $ showString "@"
, prt 0 t
, doc $ showString "("
, doc . showString . show $ c
, doc $ showString ")"
, doc $ showString "=>"
, prt 0 e
, doc $ showString "\n" , doc $ showString "\n"
] : printCases xs ]
ECase t exp injs ->
prPrec
i
0
( concatD
[ doc (showString "case")
, prt 0 exp
, doc (showString "of")
, doc (showString "{")
, prt 0 injs
, doc (showString "}")
, doc (showString ":")
, prt 0 t
, doc $ showString "\n"
]
)
instance Print Inj where
prt i = \case
Inj (init, t) exp -> prPrec i 0 (concatD [prt 0 init, doc (showString ":"), prt 0 t, doc (showString "=>"), prt 0 exp])
instance Print [Inj] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]