Started importing Sebastian's new typechecker.

This commit is contained in:
Samuel Hammersberg 2023-03-08 11:01:07 +01:00
parent d5dd7896d8
commit 350cd3b0e9
9 changed files with 1611 additions and 1346 deletions

View file

@ -1,33 +1,51 @@
Program. Program ::= [Bind];
EId. Exp3 ::= Ident;
EInt. Exp3 ::= Integer;
EAnn. Exp3 ::= "(" Exp ":" Type ")";
ELet. Exp3 ::= "let" Bind "in" Exp;
EApp. Exp2 ::= Exp2 Exp3;
EAdd. Exp1 ::= Exp1 "+" Exp2;
ESub. Exp1 ::= Exp1 "-" Exp2;
EAbs. Exp ::= "\\" Ident ":" Type "." Exp;
ECase. Exp ::= "case" Exp "of" "{" [CaseMatch] "}" ":" Type;
CaseMatch. CaseMatch ::= Case "=>" Exp ;
separator CaseMatch ",";
Program. Program ::= [Def] ;
CInt. Case ::= Integer ;
CatchAll. Case ::= "_" ;
DBind. Def ::= Bind ;
DData. Def ::= Data ;
separator Def ";" ;
Bind. Bind ::= Ident ":" Type ";"
Ident [Ident] "=" Exp;
Ident [Ident] "=" Exp ;
separator Bind ";";
separator Ident "";
Data. Data ::= "data" Constr "where" "{" [Constructor] "}" ;
coercions Exp 3;
Constructor. Constructor ::= Ident ":" Type ;
separator nonempty Constructor "" ;
TInt. Type1 ::= "Int" ;
TPol. Type1 ::= Ident ;
TFun. Type ::= Type1 "->" Type ;
coercions Type 1 ;
TMono. Type1 ::= "_" Ident ;
TPol. Type1 ::= "'" Ident ;
TConstr. Type1 ::= Constr ;
TArr. Type ::= Type1 "->" Type ;
comment "--";
comment "{-" "-}";
Constr. Constr ::= Ident "(" [Type] ")" ;
-- TODO: Move literal to its own thing since it's reused in Init as well.
EAnn. Exp5 ::= "(" Exp ":" Type ")" ;
EId. Exp4 ::= Ident ;
ELit. Exp4 ::= Literal ;
EApp. Exp3 ::= Exp3 Exp4 ;
EAdd. Exp1 ::= Exp1 "+" Exp2 ;
ESub. Exp1 ::= Exp1 "-" Exp2 ;
ELet. Exp ::= "let" Ident "=" Exp "in" Exp ;
EAbs. Exp ::= "\\" Ident "." Exp ;
ECase. Exp ::= "case" Exp "of" "{" [Inj] "}";
LInt. Literal ::= Integer ;
Inj. Inj ::= Init "=>" Exp ;
separator nonempty Inj ";" ;
InitLit. Init ::= Literal ;
InitConstr. Init ::= Ident [Ident] ;
InitCatch. Init ::= "_" ;
separator Type " " ;
coercions Type 2 ;
separator Ident " ";
coercions Exp 5 ;
comment "--" ;
comment "{-" "-}" ;

View file

@ -1,87 +1,26 @@
-- tripplemagic : Int -> Int -> Int -> Int;
-- tripplemagic x y z = ((\x:Int. x+x) x) + y + z;
-- main : Int;
-- main = tripplemagic ((\x:Int. x+x+3) ((\x:Int. x) 2)) 5 3
-- answer: 22
-- apply : (Int -> Int) -> Int -> Int;
-- apply f x = f x;
-- main : Int;
-- main = apply (\x : Int . x + 5) 5
-- answer: 10
-- apply : (Int -> Int -> Int) -> Int -> Int -> Int;
-- apply f x y = f x y;
-- krimp: Int -> Int -> Int;
-- krimp x y = x + y;
-- main : Int;
-- main = apply (krimp) 2 3;
-- answer: 5
-- fibbonaci : Int -> Int;
-- fibbonaci x = case x of {
-- 0 => 0,
-- 1 => 1,
-- -- abusing overflows to represent negatives like a boss
-- _ => (fibbonaci (x - 2))
-- + (fibbonaci (x - 1))
-- } : Int;
-- main : Int;
-- main = fibbonaci 10;
-- answer: 55
-- succ : Int -> Int;
-- succ x = x - 1;
--
-- isZero : Int -> Int;
-- isZero x = case x of {
-- 0 => 1,
-- _ => 0
-- } : Int;
--
-- minimization : (Int -> Int) -> Int -> Int;
-- minimization p x = case p x of {
-- 1 => 0,
-- _ => minimization p (succ x)
-- } : Int;
--
-- main : Int;
-- main = minimization isZero 10;
-- answer: 0
posMul : Int -> Int -> Int;
posMul : _Int -> _Int -> _Int;
posMul a b = case b of {
0 => 0,
0 => 0;
_ => a + posMul a (b - 1)
} : Int;
};
facc : Int -> Int;
facc : _Int -> _Int;
facc a = case a of {
1 => 1,
1 => 1;
_ => posMul a (facc (a - 1))
} : Int;
-- main : Int;
-- main = facc 5
-- answer: 120
};
-- pow : Int -> Int -> Int;
-- pow a b = case b of {
-- 0 => 1,
-- _ => posMul a (pow a (b-1))
-- } : Int;
minimization : (Int -> Int) -> Int -> Int;
minimization : (_Int -> _Int) -> _Int -> _Int;
minimization p x = case p x of {
1 => x,
1 => x;
_ => minimization p (x + 1)
} : Int;
};
checkFac : Int -> Int;
checkFac : _Int -> _Int;
checkFac x = case facc x of {
0 => 1,
0 => 1;
_ => 0
} : Int;
};
main : Int;
main : _Int;
main = minimization checkFac 1

View file

@ -1,441 +1,443 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Codegen.Codegen (generateCode) where
import Auxiliary (snoc)
import Codegen.LlvmIr (CallingConvention (..),
LLVMComp (..), LLVMIr (..),
LLVMType (..), LLVMValue (..),
Visibility (..), llvmIrToString)
import Control.Monad.State (StateT, execStateT, foldM_, gets,
modify)
import qualified Data.Bifunctor as BI
import Data.List.Extra (trim)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Tuple.Extra (dupe, first, second)
import qualified Grammar.Abs as GA
import Grammar.ErrM (Err)
import System.Process.Extra (readCreateProcess, shell)
import TypeChecker.TypeCheckerIr (Bind (..), Case (..), Exp (..), Id,
Ident (..), Program (..), Type (..))
-- | The record used as the code generator state.
-- It is the state carried by 'CompilerState' during code generation.
data CodeGenerator = CodeGenerator
-- LLVM IR instructions emitted so far; 'emit' adds to this list.
{ instructions :: [LLVMIr]
-- Function table; initialised from 'getFunctions' in 'initCodeGenerator'.
, functions :: Map Id FunctionInfo
-- Constructor table; initialised from 'getConstructors' in 'initCodeGenerator'.
, constructors :: Map Id ConstructorInfo
-- Counter used to name fresh variables; 'compileScs' resets it to 0 after
-- each emitted function definition.
, variableCount :: Integer
-- Counter used to build fresh label names (see 'getNewLabel').
, labelCount :: Integer
}
-- | The code generation monad: a 'StateT' carrying the 'CodeGenerator'
-- record on top of the 'Err' monad.
type CompilerState a = StateT CodeGenerator Err a
-- | Per-function data stored in the 'functions' table of 'CodeGenerator'.
data FunctionInfo = FunctionInfo
-- Number of parameters; 'getFunctions' sets it to the length of 'arguments'.
{ numArgs :: Int
-- The parameters themselves, each an identifier paired with its type.
, arguments :: [Id]
}
-- | Per-constructor data stored in the 'constructors' table of
-- 'CodeGenerator'.
data ConstructorInfo = ConstructorInfo
-- Number of fields of the constructor.
{ numArgsCI :: Int
-- Synthetic parameters (named by 'createArgs') standing for the fields.
, argumentsCI :: [Id]
-- Zero-based position of the constructor within its data declaration;
-- 'compileScs' stores it as an I8 in the constructed value.
, numCI :: Integer
}
-- | Add a single LLVM IR instruction to the instruction list kept in the
-- 'CodeGenerator' state (via 'Auxiliary.snoc').
emit :: LLVMIr -> CompilerState ()
emit instr = modify appendInstr
  where
    appendInstr st = st { instructions = Auxiliary.snoc instr (instructions st) }
-- | Bump the variable counter of the 'CodeGenerator' state by one.
increaseVarCount :: CompilerState ()
increaseVarCount = modify bump
  where
    bump st = st { variableCount = succ (variableCount st) }
-- | Returns the current variable count from the CodeGenerator state
-- without modifying it. 'exprToValue' reads it right after 'compileExp'
-- to refer to the most recently allocated variable.
getVarCount :: CompilerState Integer
getVarCount = gets variableCount
-- | Allocate a fresh variable number: the counter is incremented first and
-- the incremented value is returned.
getNewVar :: CompilerState Integer
getNewVar = do
    increaseVarCount
    getVarCount
-- | Allocate a fresh label number: the label counter is incremented first
-- and the incremented value is returned.
getNewLabel :: CompilerState Integer
getNewLabel = modify bump >> gets labelCount
  where
    bump st = st { labelCount = 1 + labelCount st }
-- | Build the function table used during code generation.
--
-- Every 'Bind' contributes one entry keyed by its own 'Id'; every
-- 'DataStructure' contributes one entry per constructor, keyed by the
-- constructor name paired with the data type (as a 'TPol').
getFunctions :: [Bind] -> Map Id FunctionInfo
getFunctions = Map.fromList . concatMap entries
  where
    entries :: Bind -> [(Id, FunctionInfo)]
    entries (Bind name params _) =
        [(name, FunctionInfo { numArgs = length params, arguments = params })]
    entries (DataStructure tyName ctors) =
        [ ( (ctor, TPol tyName)
          , FunctionInfo { numArgs = length fields, arguments = createArgs fields }
          )
        | (ctor, fields) <- ctors
        ]
-- | Give every type in the list a synthetic parameter name.
--
-- The names are @arg_0@, @arg_1@, ... in list order, each paired with the
-- type found at that position. An empty list yields an empty list.
--
-- Implemented with 'zipWith' instead of a lazy left fold that appended with
-- @acc ++ [x]@, which was quadratic in the number of fields.
createArgs :: [Type] -> [Id]
createArgs = zipWith named [0 :: Integer ..]
  where
    named idx ty = (Ident ("arg_" <> show idx), ty)
-- | Build the constructor table used during code generation.
--
-- Only 'DataStructure' declarations contribute. Each constructor is keyed
-- by @<type>_<constructor>@ paired with the data type (as a 'TPol') and
-- records its field count, synthetic arguments and its zero-based position
-- within the declaration.
getConstructors :: [Bind] -> Map Id ConstructorInfo
getConstructors = Map.fromList . concatMap ctorEntries
  where
    -- The entries of one declaration are listed last-constructor-first,
    -- which decides the winner in 'Map.fromList' if a key is duplicated.
    ctorEntries (DataStructure (Ident tyName) ctors) =
        reverse (zipWith (mkEntry tyName) [0 ..] ctors)
    ctorEntries _ = []

    mkEntry tyName idx (Ident ctor, fields) =
        ( (Ident (tyName <> "_" <> ctor), TPol (Ident tyName))
        , ConstructorInfo
            { numArgsCI = length fields
            , argumentsCI = createArgs fields
            , numCI = idx
            }
        )
-- | Initial code generator state for a program: the instruction list starts
-- with 'defaultStart', both lookup tables are derived from the given
-- top-level declarations, and both counters start at zero.
initCodeGenerator :: [Bind] -> CodeGenerator
initCodeGenerator binds =
    CodeGenerator
        { labelCount = 0
        , variableCount = 0
        , constructors = getConstructors binds
        , functions = getFunctions binds
        , instructions = defaultStart
        }
-- | Write the generated LLVM IR to @output/llvm.ll@ and execute it by
-- feeding it to @lli@ on standard input, printing the trimmed output.
--
-- If code generation failed, the failure message carried by the 'Left' is
-- reported through 'error' instead of being replaced by a placeholder.
--
-- NOTE(review): 'writeFile' presumably fails if the @output@ directory does
-- not exist; confirm against how this helper is invoked.
run :: Err String -> IO ()
run result = do
    let ir = case result of
            Right code -> code
            Left err -> error ("run: code generation failed: " <> err)
    writeFile "output/llvm.ll" ir
    putStrLn . trim =<< readCreateProcess (shell "lli") ir
-- | A small hand-written program for trying out the code generator: one
-- data type @Craig@ with a single constructor @Bob@ holding an integer, an
-- identity-like @fibonacci@ function, and a @main@ that applies the
-- generated @Craig_Bob@ constructor to the given integer.
test :: Integer -> Program
test v = Program [craig, fibonacci, mainBind]
  where
    -- A second constructor, (Ident "Alice", [TInt, TInt]), is left disabled.
    craig = DataStructure (Ident "Craig") [(Ident "Bob", [TInt])]
    fibonacci =
        Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] (EId ("x", TInt))
    mainBind =
        Bind (Ident "main", TInt) [] $
            EApp (TPol "Craig") (EId (Ident "Craig_Bob", TPol "Craig")) (EInt v)
{- | Compile a type-checked program to LLVM IR, rendered as a string.

The simplest way to execute the result is to pipe it to @lli@.
-}
generateCode :: Program -> Err String
generateCode (Program scs) =
    render <$> execStateT (compileScs scs) (initCodeGenerator scs)
  where
    render = llvmIrToString . instructions
-- | Emits LLVM IR for a list of top-level declarations, one at a time.
--
-- * A 'Bind' becomes a function definition.
-- * A 'DataStructure' becomes a set of type definitions.
-- * Once the list is exhausted, one constructor function is emitted for
--   every entry of the 'constructors' table.
compileScs :: [Bind] -> CompilerState ()
compileScs [] = do
-- as a last step create all the constructors
c <- gets (Map.toList . constructors)
mapM_ (\((id, t), ci) -> do
-- t' is the LLVM type of the data type; x are the typed parameters.
let t' = type2LlvmType t
let x = BI.second type2LlvmType <$> argumentsCI ci
emit $ Define FastCC t' id x
top <- Ident . show <$> getNewVar
ptr <- Ident . show <$> getNewVar
-- allocate the primary type
emit $ SetVariable top (Alloca t')
-- set the first byte to the index of the constructor
emit $ SetVariable ptr $
GetElementPtrInbounds t' (Ref t')
(VIdent top I8) I32 (VInteger 0) I32 (VInteger 0)
emit $ Store I8 (VInteger $ numCI ci ) (Ref I8) ptr
-- get a pointer of the correct type
ptr' <- Ident . show <$> getNewVar
emit $ SetVariable ptr' (Bitcast (Ref t') ptr (Ref $ CustomType id))
--emit $ UnsafeRaw "\n"
-- Store every argument into its field. The running index starts at 1;
-- the per-constructor type emitted below starts with an I8, so index 0
-- is presumably that tag slot.
foldM_ (\i (Ident arg_n, arg_t)-> do
let arg_t' = type2LlvmType arg_t
emit $ Comment (show arg_t' <>" "<> arg_n <> " " <> show i )
elemPtr <- Ident . show <$> getNewVar
emit $ SetVariable elemPtr (
GetElementPtrInbounds (CustomType id) (Ref (CustomType id))
(VIdent ptr' Ptr) I32
(VInteger 0) I32 (VInteger i))
emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr elemPtr
-- %2 = getelementptr inbounds %Foo_AInteger, %Foo_AInteger* %1, i32 0, i32 1
-- store i32 42, i32* %2
pure $ i + 1-- + typeByteSize arg_t'
) 1 (argumentsCI ci)
--emit $ UnsafeRaw "\n"
-- load and return the constructed value
load <- Ident . show <$> getNewVar
emit $ SetVariable load (Load t' Ptr top)
emit $ Ret t' (VIdent load t')
emit DefineEnd
-- variable numbering restarts for the next function
modify $ \s -> s { variableCount = 0 }
) c
-- A function binding: define it, compile the body to a value, return it.
compileScs (Bind (name, _t) args exp : xs) = do
emit $ UnsafeRaw "\n"
emit . Comment $ show name <> ": " <> show exp
let args' = map (second type2LlvmType) args
-- NOTE(review): the return type is hard-coded to I64; the declared type
-- (_t) is currently ignored.
emit $ Define FastCC I64 {-(type2LlvmType t_return)-} name args'
functionBody <- exprToValue exp
-- main prints its result (see 'mainContent'); other functions return it.
if name == "main"
then mapM_ emit $ mainContent functionBody
else emit $ Ret I64 functionBody
emit DefineEnd
-- variable numbering restarts for the next function
modify $ \s -> s { variableCount = 0 }
compileScs xs
-- A data declaration: one type for the data type itself, made of an I8
-- followed by a byte array sized after the largest constructor, plus one
-- type per constructor named <type>_<constructor>.
compileScs (DataStructure id@(Ident outer_id) ts : xs) = do
-- NOTE(review): 'maximum' is partial and fails if ts is empty.
let biggest_variant = maximum ((\(_, t) -> sum $ typeByteSize . type2LlvmType <$> t) <$> ts)
emit $ Type id [I8, Array biggest_variant I8]
mapM_ (\(Ident inner_id, fi) -> do
emit $ Type (Ident $ outer_id <> "_" <> inner_id) (I8 : map type2LlvmType fi)
) ts
compileScs xs
-- where
-- _t_return = snd $ partitionType (length args) t
-- | The tail of the generated @main@ function: print the given value with
-- @printf@ using the @\@.str@ format constant, then return 0.
mainContent :: LLVMValue -> [LLVMIr]
mainContent var =
    [ UnsafeRaw printCall
    , Ret I64 (VInteger 0)
    ]
  where
    printCall =
        "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef "
            <> show var
            <> ")\n"
-- | The preamble every generated module starts with: target triple, data
-- layout, the @printf@ format constant used by 'mainContent', and the
-- declaration of @printf@ itself.
--
-- The format constant is emitted as @%ld\\0A\\00@. It is NUL-terminated
-- because @printf@ expects a C string, and it uses @%ld@ because
-- 'mainContent' passes an @i64@ argument.
--
-- NOTE(review): the data layout string uses @m:o@ while the triple names a
-- Linux target; confirm this combination is intended.
defaultStart :: [LLVMIr]
defaultStart =
    map
        UnsafeRaw
        [ "target triple = \"x86_64-pc-linux-gnu\"\n"
        , "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
        , "@.str = private unnamed_addr constant [5 x i8] c\"%ld\\0A\\00\", align 1\n"
        , "declare i32 @printf(ptr noalias nocapture, ...)\n"
        ]
-- | Emit the instructions for an expression by dispatching on its
-- constructor to the matching @emit*@ helper.
compileExp :: Exp -> CompilerState ()
compileExp expr = case expr of
    EInt int -> emitInt int
    EAdd t e1 e2 -> emitAdd t e1 e2
    ESub t e1 e2 -> emitSub t e1 e2
    EId (name, _) -> emitIdent name
    EApp t e1 e2 -> emitApp t e1 e2
    EAbs t ti e -> emitAbs t ti e
    ELet binds e -> emitLet binds e
    ECase t e cs -> emitECased t e cs
-- Multiplication, division and modulo (EMul / EDiv / EMod) are not handled.
--- aux functions ---
-- | Emits the code for a case expression.
--
-- The scrutinee is compiled to a value, a stack slot of the result type is
-- allocated, and every branch stores its result into that slot before
-- jumping to a shared @escape_N@ label, after which the slot is loaded into
-- a fresh variable. That load is the last variable allocated here, which is
-- what 'exprToValue' picks up via 'getVarCount'.
--
-- NOTE(review): if no catch-all branch is present, the last @failed_N@
-- label appears to be followed directly by the escape label without a
-- branch, and the slot may be loaded without having been stored to; confirm
-- that earlier stages guarantee a catch-all.
emitECased :: Type -> Exp -> [(Type, Case)] -> CompilerState ()
emitECased t e cases = do
-- only the branches are used; the type paired with each one is dropped
let cs = snd <$> cases
let ty = type2LlvmType t
vs <- exprToValue e
lbl <- getNewLabel
let label = Ident $ "escape_" <> show lbl
stackPtr <- getNewVar
emit $ SetVariable (Ident $ show stackPtr) (Alloca ty)
mapM_ (emitCases ty label stackPtr vs) cs
emit $ Label label
res <- getNewVar
emit $ SetVariable (Ident $ show res) (Load ty Ptr (Ident $ show stackPtr))
where
emitCases :: LLVMType -> Ident -> Integer -> LLVMValue -> Case -> CompilerState ()
-- Integer pattern: compare the scrutinee with the literal; on success store
-- the branch result and jump out, otherwise continue at the failure label.
emitCases ty label stackPtr vs (Case (GA.CInt i) exp) = do
ns <- getNewVar
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
emit $ SetVariable (Ident $ show ns) (Icmp LLEq ty vs (VInteger i))
emit $ BrCond (VIdent (Ident $ show ns) ty) lbl_succPos lbl_failPos
emit $ Label lbl_succPos
val <- exprToValue exp
emit $ Store ty val Ptr (Ident . show $ stackPtr)
emit $ Br label
emit $ Label lbl_failPos
-- Catch-all pattern: no comparison, store the branch result and jump out.
emitCases ty label stackPtr _ (Case GA.CatchAll exp) = do
val <- exprToValue exp
emit $ Store ty val Ptr (Ident . show $ stackPtr)
emit $ Br label
-- | Lambdas are not compiled here; only an IR comment is emitted recording
-- that one reached the code generator.
emitAbs :: Type -> Id -> Exp -> CompilerState ()
emitAbs _t tid e = emit (Comment message)
  where
    message = "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e
-- | Let expressions are not implemented; only an IR comment describing the
-- expression is emitted.
emitLet :: Bind -> Exp -> CompilerState ()
emitLet xs e =
    emit . Comment $
        "ELet (" <> show xs <> " = " <> show e <> ") is not implemented!"
-- | Emit a call for a (possibly nested) function application.
--
-- The application spine is unwound until its head is reached while the
-- arguments are collected in call order. An identifier head produces a call
-- whose result is bound to a fresh variable; the callee is 'Global' when it
-- is found in the function table and 'Local' otherwise. Any other head only
-- produces IR comments.
emitApp :: Type -> Exp -> Exp -> CompilerState ()
emitApp t e1 e2 = go e1 [e2]
  where
    go (EApp _ f arg) pending = go f (arg : pending)
    go (EId fid@(name, _)) pending = do
        argVals <- traverse exprToValue pending
        resultVar <- getNewVar
        known <- gets functions
        let visibility = if Map.member fid known then Global else Local
            typedArgs = [(valueGetType v, v) | v <- argVals]
            call = Call FastCC (type2LlvmType t) visibility name typedArgs
        emit $ SetVariable (Ident $ show resultVar) call
    go other _ = do
        emit . Comment $ "The unspeakable happened: "
        emit . Comment $ show other
-- | Fallback for a bare identifier reaching 'compileExp'. This is not
-- expected to happen, so a warning comment is emitted before the variable.
emitIdent :: Ident -> CompilerState ()
emitIdent id =
    mapM_
        emit
        [ Comment "This should not have happened!"
        , Variable id
        , UnsafeRaw "\n"
        ]
-- | Fallback for a bare integer literal reaching 'compileExp'. This is not
-- expected to happen; the literal is materialised in a fresh variable by
-- adding 0 to it, preceded by a warning comment.
emitInt :: Integer -> CompilerState ()
emitInt i = do
    fresh <- Ident . show <$> getNewVar
    emit (Comment "This should not have happened!")
    emit (SetVariable fresh (Add I64 (VInteger i) (VInteger 0)))
-- | Emit an addition: both operands are turned into values (left first),
-- then the sum is bound to a fresh variable.
emitAdd :: Type -> Exp -> Exp -> CompilerState ()
emitAdd t e1 e2 = do
    lhs <- exprToValue e1
    rhs <- exprToValue e2
    target <- Ident . show <$> getNewVar
    emit . SetVariable target $ Add (type2LlvmType t) lhs rhs
-- | Emit a subtraction: both operands are turned into values (left first),
-- then the difference is bound to a fresh variable.
emitSub :: Type -> Exp -> Exp -> CompilerState ()
emitSub t e1 e2 = do
    difference <- Sub (type2LlvmType t) <$> exprToValue e1 <*> exprToValue e2
    target <- getNewVar
    emit (SetVariable (Ident (show target)) difference)
-- emitMul :: Exp -> Exp -> CompilerState ()
-- emitMul e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Mul I64 v1 v2
-- emitMod :: Exp -> Exp -> CompilerState ()
-- emitMod e1 e2 = do
-- -- `let m a b = rem (abs $ b + a) b`
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- vadd <- gets variableCount
-- emit $ SetVariable $ Ident $ show vadd
-- emit $ Add I64 v1 v2
module Codegen.Codegen where
-- {-# LANGUAGE LambdaCase #-}
-- {-# LANGUAGE OverloadedStrings #-}
--
-- module Codegen.Codegen (generateCode) where
--
-- import Auxiliary (snoc)
-- import Codegen.LlvmIr (CallingConvention (..),
-- LLVMComp (..), LLVMIr (..),
-- LLVMType (..), LLVMValue (..),
-- Visibility (..), llvmIrToString)
-- import Control.Monad.State (StateT, execStateT, foldM_, gets,
-- modify)
-- import qualified Data.Bifunctor as BI
-- import Data.List.Extra (trim)
-- import Data.Map (Map)
-- import qualified Data.Map as Map
-- import Data.Tuple.Extra (dupe, first, second)
-- import qualified Grammar.Abs as GA
-- import Grammar.ErrM (Err)
-- import System.Process.Extra (readCreateProcess, shell)
-- import TypeChecker.TypeCheckerIr (Bind (..), Case (..), Exp (..), Id,
-- Ident (..), Program (..), Type (..))
-- -- | The record used as the code generator state
-- data CodeGenerator = CodeGenerator
-- { instructions :: [LLVMIr]
-- , functions :: Map Id FunctionInfo
-- , constructors :: Map Id ConstructorInfo
-- , variableCount :: Integer
-- , labelCount :: Integer
-- }
--
-- -- | A state type synonym
-- type CompilerState a = StateT CodeGenerator Err a
--
-- data FunctionInfo = FunctionInfo
-- { numArgs :: Int
-- , arguments :: [Id]
-- }
-- data ConstructorInfo = ConstructorInfo
-- { numArgsCI :: Int
-- , argumentsCI :: [Id]
-- , numCI :: Integer
-- }
--
--
-- -- | Adds an instruction to the CodeGenerator state
-- emit :: LLVMIr -> CompilerState ()
-- emit l = modify $ \t -> t { instructions = Auxiliary.snoc l $ instructions t }
--
-- -- | Increases the variable counter in the CodeGenerator state
-- increaseVarCount :: CompilerState ()
-- increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 }
--
-- -- | Returns the variable count from the CodeGenerator state
-- getVarCount :: CompilerState Integer
-- getVarCount = gets variableCount
--
-- -- | Increases the variable count and returns it from the CodeGenerator state
-- getNewVar :: CompilerState Integer
-- getNewVar = increaseVarCount >> getVarCount
--
-- -- | Increases the label count and returns a label from the CodeGenerator state
-- getNewLabel :: CompilerState Integer
-- getNewLabel = do
-- modify (\t -> t{labelCount = labelCount t + 1})
-- gets labelCount
--
-- -- | Produces a map of functions infos from a list of binds,
-- -- which contains useful data for code generation.
-- getFunctions :: [Bind] -> Map Id FunctionInfo
-- getFunctions bs = Map.fromList $ go bs
-- where
-- go [] = []
-- go (Bind id args _ : xs) =
-- (id, FunctionInfo { numArgs=length args, arguments=args })
-- : go xs
-- go (DataStructure n cons : xs) = do
-- map (\(id, xs) -> ((id, TPol n), FunctionInfo {
-- numArgs=length xs, arguments=createArgs xs
-- })) cons
-- <> go xs
--
-- createArgs :: [Type] -> [Id]
-- createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(Ident ("arg_" <> show l) , t)],l+1)) ([], 0) xs
--
-- -- | Produces a map of constructor infos from a list of binds,
-- -- which contains useful data for code generation.
-- getConstructors :: [Bind] -> Map Id ConstructorInfo
-- getConstructors bs = Map.fromList $ go bs
-- where
-- go [] = []
-- go (DataStructure (Ident n) cons : xs) = do
-- fst (foldl (\(acc,i) (Ident id, xs) -> (((Ident (n <> "_" <> id), TPol (Ident n)), ConstructorInfo {
-- numArgsCI=length xs,
-- argumentsCI=createArgs xs,
-- numCI=i
-- }) : acc, i+1)) ([],0) cons)
-- <> go xs
-- go (_: xs) = go xs
--
-- initCodeGenerator :: [Bind] -> CodeGenerator
-- initCodeGenerator scs = CodeGenerator { instructions = defaultStart
-- , functions = getFunctions scs
-- , constructors = getConstructors scs
-- , variableCount = 0
-- , labelCount = 0
-- }
--
-- run :: Err String -> IO ()
-- run s = do
-- let s' = case s of
-- Right s -> s
-- Left _ -> error "yo"
-- writeFile "output/llvm.ll" s'
-- putStrLn . trim =<< readCreateProcess (shell "lli") s'
--
-- test :: Integer -> Program
-- test v = Program [
-- DataStructure (Ident "Craig") [
-- (Ident "Bob", [TInt])--,
-- --(Ident "Alice", [TInt, TInt])
-- ],
-- Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] (EId ("x",TInt)),
-- Bind (Ident "main", TInt) [] (
-- EApp (TPol "Craig") (EId (Ident "Craig_Bob", TPol "Craig")) (EInt v) -- (EInt 92)
-- )
-- ]
--
-- {- | Compiles an AST and produces a LLVM Ir string.
-- An easy way to actually "compile" this output is to
-- Simply pipe it to LLI
-- -}
-- generateCode :: Program -> Err String
-- generateCode (Program scs) = do
-- let codegen = initCodeGenerator scs
-- llvmIrToString . instructions <$> execStateT (compileScs scs) codegen
--
-- compileScs :: [Bind] -> CompilerState ()
-- compileScs [] = do
-- -- as a last step create all the constructors
-- c <- gets (Map.toList . constructors)
-- mapM_ (\((id, t), ci) -> do
-- let t' = type2LlvmType t
-- let x = BI.second type2LlvmType <$> argumentsCI ci
-- emit $ Define FastCC t' id x
-- top <- Ident . show <$> getNewVar
-- ptr <- Ident . show <$> getNewVar
-- -- allocated the primary type
-- emit $ SetVariable top (Alloca t')
--
-- -- set the first byte to the index of the constructor
-- emit $ SetVariable ptr $
-- GetElementPtrInbounds t' (Ref t')
-- (VIdent top I8) I32 (VInteger 0) I32 (VInteger 0)
-- emit $ Store I8 (VInteger $ numCI ci ) (Ref I8) ptr
--
-- -- get a pointer of the correct type
-- ptr' <- Ident . show <$> getNewVar
-- emit $ SetVariable ptr' (Bitcast (Ref t') ptr (Ref $ CustomType id))
--
-- --emit $ UnsafeRaw "\n"
--
-- foldM_ (\i (Ident arg_n, arg_t)-> do
-- let arg_t' = type2LlvmType arg_t
-- emit $ Comment (show arg_t' <>" "<> arg_n <> " " <> show i )
-- elemPtr <- Ident . show <$> getNewVar
-- emit $ SetVariable elemPtr (
-- GetElementPtrInbounds (CustomType id) (Ref (CustomType id))
-- (VIdent ptr' Ptr) I32
-- (VInteger 0) I32 (VInteger i))
-- emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr elemPtr
-- -- %2 = getelementptr inbounds %Foo_AInteger, %Foo_AInteger* %1, i32 0, i32 1
-- -- store i32 42, i32* %2
-- pure $ i + 1-- + typeByteSize arg_t'
-- ) 1 (argumentsCI ci)
--
-- --emit $ UnsafeRaw "\n"
--
-- -- load and return the constructed value
-- load <- Ident . show <$> getNewVar
-- emit $ SetVariable load (Load t' Ptr top)
-- emit $ Ret t' (VIdent load t')
-- emit DefineEnd
--
-- modify $ \s -> s { variableCount = 0 }
-- ) c
-- compileScs (Bind (name, _t) args exp : xs) = do
-- emit $ UnsafeRaw "\n"
-- emit . Comment $ show name <> ": " <> show exp
-- let args' = map (second type2LlvmType) args
-- emit $ Define FastCC I64 {-(type2LlvmType t_return)-} name args'
-- functionBody <- exprToValue exp
-- if name == "main"
-- then mapM_ emit $ mainContent functionBody
-- else emit $ Ret I64 functionBody
-- emit DefineEnd
-- modify $ \s -> s { variableCount = 0 }
-- compileScs xs
-- compileScs (DataStructure id@(Ident outer_id) ts : xs) = do
-- let biggest_variant = maximum ((\(_, t) -> sum $ typeByteSize . type2LlvmType <$> t) <$> ts)
-- emit $ Type id [I8, Array biggest_variant I8]
-- mapM_ (\(Ident inner_id, fi) -> do
-- emit $ Type (Ident $ outer_id <> "_" <> inner_id) (I8 : map type2LlvmType fi)
-- ) ts
-- compileScs xs
--
-- -- where
-- -- _t_return = snd $ partitionType (length args) t
--
-- mainContent :: LLVMValue -> [LLVMIr]
-- mainContent var =
-- [ UnsafeRaw $
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n"
-- , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
-- -- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2")
-- -- , Label (Ident "b_1")
-- -- , UnsafeRaw
-- -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n"
-- -- , Br (Ident "end")
-- -- , Label (Ident "b_2")
-- -- , UnsafeRaw
-- -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n"
-- -- , Br (Ident "end")
-- -- , Label (Ident "end")
-- Ret I64 (VInteger 0)
-- ]
--
-- defaultStart :: [LLVMIr]
-- defaultStart = [ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n"
-- , UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
-- , UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
-- , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
-- ]
--
-- compileExp :: Exp -> CompilerState ()
-- compileExp (EInt int) = emitInt int
-- compileExp (EAdd t e1 e2) = emitAdd t e1 e2
-- compileExp (ESub t e1 e2) = emitSub t e1 e2
-- compileExp (EId (name, _)) = emitIdent name
-- compileExp (EApp t e1 e2) = emitApp t e1 e2
-- compileExp (EAbs t ti e) = emitAbs t ti e
-- compileExp (ELet binds e) = emitLet binds e
-- compileExp (ECase t e cs) = emitECased t e cs
-- -- go (EMul e1 e2) = emitMul e1 e2
-- -- go (EDiv e1 e2) = emitDiv e1 e2
-- -- go (EMod e1 e2) = emitMod e1 e2
--
-- --- aux functions ---
-- emitECased :: Type -> Exp -> [(Type, Case)] -> CompilerState ()
-- emitECased t e cases = do
-- let cs = snd <$> cases
-- let ty = type2LlvmType t
-- vs <- exprToValue e
-- lbl <- getNewLabel
-- let label = Ident $ "escape_" <> show lbl
-- stackPtr <- getNewVar
-- emit $ SetVariable (Ident $ show stackPtr) (Alloca ty)
-- mapM_ (emitCases ty label stackPtr vs) cs
-- emit $ Label label
-- res <- getNewVar
-- emit $ SetVariable (Ident $ show res) (Load ty Ptr (Ident $ show stackPtr))
-- where
-- emitCases :: LLVMType -> Ident -> Integer -> LLVMValue -> Case -> CompilerState ()
-- emitCases ty label stackPtr vs (Case (GA.CInt i) exp) = do
-- ns <- getNewVar
-- lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
-- lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
-- emit $ SetVariable (Ident $ show ns) (Icmp LLEq ty vs (VInteger i))
-- emit $ BrCond (VIdent (Ident $ show ns) ty) lbl_succPos lbl_failPos
-- emit $ Label lbl_succPos
-- val <- exprToValue exp
-- emit $ Store ty val Ptr (Ident . show $ stackPtr)
-- emit $ Br label
-- emit $ Label lbl_failPos
-- emitCases ty label stackPtr _ (Case GA.CatchAll exp) = do
-- val <- exprToValue exp
-- emit $ Store ty val Ptr (Ident . show $ stackPtr)
-- emit $ Br label
--
--
-- emitAbs :: Type -> Id -> Exp -> CompilerState ()
-- emitAbs _t tid e = do
-- emit . Comment $
-- "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e
-- emitLet :: Bind -> Exp -> CompilerState ()
-- emitLet xs e = do
-- emit $
-- Comment $
-- concat
-- [ "ELet ("
-- , show xs
-- , " = "
-- , show e
-- , ") is not implemented!"
-- ]
--
-- emitApp :: Type -> Exp -> Exp -> CompilerState ()
-- emitApp t e1 e2 = appEmitter t e1 e2 []
-- where
-- appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState ()
-- appEmitter t e1 e2 stack = do
-- let newStack = e2 : stack
-- case e1 of
-- EApp _ e1' e2' -> appEmitter t e1' e2' newStack
-- EId id@(name, _) -> do
-- args <- traverse exprToValue newStack
-- vs <- getNewVar
-- funcs <- gets functions
-- let visibility = maybe Local (const Global) $ Map.lookup id funcs
-- args' = map (first valueGetType . dupe) args
-- call = Call FastCC (type2LlvmType t) visibility name args'
-- emit $ SetVariable (Ident $ show vs) call
-- x -> do
-- emit . Comment $ "The unspeakable happened: "
-- emit . Comment $ show x
--
-- emitIdent :: Ident -> CompilerState ()
-- emitIdent id = do
-- -- !!this should never happen!!
-- emit $ Comment "This should not have happened!"
-- emit $ Variable id
-- emit $ UnsafeRaw "\n"
--
-- emitInt :: Integer -> CompilerState ()
-- emitInt i = do
-- -- !!this should never happen!!
-- varCount <- getNewVar
-- emit $ Comment "This should not have happened!"
-- emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0))
--
-- emitAdd :: Type -> Exp -> Exp -> CompilerState ()
-- emitAdd t e1 e2 = do
-- v1 <- exprToValue e1
-- v2 <- exprToValue e2
-- v <- getNewVar
-- emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2)
--
-- emitSub :: Type -> Exp -> Exp -> CompilerState ()
-- emitSub t e1 e2 = do
-- v1 <- exprToValue e1
-- v2 <- exprToValue e2
-- v <- getNewVar
-- emit $ SetVariable (Ident $ show v) (Sub (type2LlvmType t) v1 v2)
--
-- -- emitMul :: Exp -> Exp -> CompilerState ()
-- -- emitMul e1 e2 = do
-- -- (v1,v2) <- binExprToValues e1 e2
-- -- increaseVarCount
-- -- v <- gets variableCount
-- -- emit $ SetVariable $ Ident $ show v
-- -- emit $ Mul I64 v1 v2
--
-- -- emitMod :: Exp -> Exp -> CompilerState ()
-- -- emitMod e1 e2 = do
-- -- -- `let m a b = rem (abs $ b + a) b`
-- -- (v1,v2) <- binExprToValues e1 e2
-- -- increaseVarCount
-- -- vadd <- gets variableCount
-- -- emit $ SetVariable $ Ident $ show vadd
-- -- emit $ Add I64 v1 v2
-- --
-- -- increaseVarCount
-- -- vabs <- gets variableCount
-- -- emit $ SetVariable $ Ident $ show vabs
-- -- emit $ Call I64 (Ident "llvm.abs.i64")
-- -- [ (I64, VIdent (Ident $ show vadd))
-- -- , (I1, VInteger 1)
-- -- ]
-- -- increaseVarCount
-- -- v <- gets variableCount
-- -- emit $ SetVariable $ Ident $ show v
-- -- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2
--
-- -- emitDiv :: Exp -> Exp -> CompilerState ()
-- -- emitDiv e1 e2 = do
-- -- (v1,v2) <- binExprToValues e1 e2
-- -- increaseVarCount
-- -- v <- gets variableCount
-- -- emit $ SetVariable $ Ident $ show v
-- -- emit $ Div I64 v1 v2
--
-- exprToValue :: Exp -> CompilerState LLVMValue
-- exprToValue = \case
-- EInt i -> pure $ VInteger i
--
-- EId id@(name, t) -> do
-- funcs <- gets functions
-- case Map.lookup id funcs of
-- Just fi -> do
-- if numArgs fi == 0
-- then do
-- vc <- getNewVar
-- emit $ SetVariable (Ident $ show vc)
-- (Call FastCC (type2LlvmType t) Global name [])
-- pure $ VIdent (Ident $ show vc) (type2LlvmType t)
-- else pure $ VFunction name Global (type2LlvmType t)
-- Nothing -> pure $ VIdent name (type2LlvmType t)
--
-- e -> do
-- compileExp e
-- v <- getVarCount
-- pure $ VIdent (Ident $ show v) (getType e)
--
-- type2LlvmType :: Type -> LLVMType
-- type2LlvmType = \case
-- TInt -> I64
-- TFun t xs -> do
-- let (t', xs') = function2LLVMType xs [type2LlvmType t]
-- Function t' xs'
-- TPol t -> CustomType t
-- where
-- function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
-- function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s)
-- function2LLVMType x s = (type2LlvmType x, s)
--
-- getType :: Exp -> LLVMType
-- getType (EInt _) = I64
-- getType (EAdd t _ _) = type2LlvmType t
-- getType (ESub t _ _) = type2LlvmType t
-- getType (EId (_, t)) = type2LlvmType t
-- getType (EApp t _ _) = type2LlvmType t
-- getType (EAbs t _ _) = type2LlvmType t
-- getType (ELet _ e) = getType e
-- getType (ECase t _ _) = type2LlvmType t
--
-- valueGetType :: LLVMValue -> LLVMType
-- valueGetType (VInteger _) = I64
-- valueGetType (VIdent _ t) = t
-- valueGetType (VConstant s) = Array (fromIntegral $ length s) I8
-- valueGetType (VFunction _ _ t) = t
--
-- typeByteSize :: LLVMType -> Integer
-- typeByteSize I1 = 1
-- typeByteSize I8 = 1
-- typeByteSize I32 = 4
-- typeByteSize I64 = 8
-- typeByteSize Ptr = 8
-- typeByteSize (Ref _) = 8
-- typeByteSize (Function _ _) = 8
-- typeByteSize (Array n t) = n * typeByteSize t
-- typeByteSize (CustomType _) = 8
--
-- increaseVarCount
-- vabs <- gets variableCount
-- emit $ SetVariable $ Ident $ show vabs
-- emit $ Call I64 (Ident "llvm.abs.i64")
-- [ (I64, VIdent (Ident $ show vadd))
-- , (I1, VInteger 1)
-- ]
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2
-- emitDiv :: Exp -> Exp -> CompilerState ()
-- emitDiv e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Div I64 v1 v2
-- | Turn an expression into an LLVM value, emitting instructions if needed.
--
-- * Integer literals become immediate values.
-- * An identifier found in the function table becomes a call result when
--   the function takes no arguments, and a function reference otherwise.
-- * An identifier not found in the table is referenced as a plain variable.
-- * Anything else is compiled with 'compileExp'; the value then refers to
--   the most recently allocated variable.
exprToValue :: Exp -> CompilerState LLVMValue
exprToValue (EInt n) = pure (VInteger n)
exprToValue (EId ident@(name, ty)) = do
    known <- gets functions
    let llvmTy = type2LlvmType ty
    case numArgs <$> Map.lookup ident known of
        Nothing -> pure (VIdent name llvmTy)
        Just 0 -> do
            var <- Ident . show <$> getNewVar
            emit $ SetVariable var (Call FastCC llvmTy Global name [])
            pure (VIdent var llvmTy)
        Just _ -> pure (VFunction name Global llvmTy)
exprToValue e = do
    compileExp e
    var <- getVarCount
    pure $ VIdent (Ident $ show var) (getType e)
-- | Translate a type-checker type into its LLVM counterpart.
--
-- * 'TInt' becomes 'I64'.
-- * 'TPol' becomes a 'CustomType' with the same name.
-- * A (curried) 'TFun' becomes a single 'Function' whose result is the
--   final type of the arrow chain and whose parameters are the preceding
--   types in source order, so @a -> b -> c@ yields @Function c [a, b]@.
--
-- The parameters were previously returned in reverse order because the
-- accumulator was handed back without being reversed.
type2LlvmType :: Type -> LLVMType
type2LlvmType = \case
    TInt -> I64
    TPol t -> CustomType t
    TFun arg rest -> uncurry Function (unroll rest [type2LlvmType arg])
  where
    -- Walk down the arrow chain, collecting parameter types last-first;
    -- the list is reversed once the result type is reached.
    unroll :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
    unroll (TFun arg rest) acc = unroll rest (type2LlvmType arg : acc)
    unroll ret acc = (type2LlvmType ret, reverse acc)
-- | LLVM type of an expression.
--
-- Integer literals are 'I64', a @let@ has the type of its body, and every
-- other form carries a type annotation that is converted with
-- 'type2LlvmType'.
getType :: Exp -> LLVMType
getType expr = case expr of
    EInt _ -> I64
    ELet _ body -> getType body
    EId (_, t) -> type2LlvmType t
    EAdd t _ _ -> type2LlvmType t
    ESub t _ _ -> type2LlvmType t
    EApp t _ _ -> type2LlvmType t
    EAbs t _ _ -> type2LlvmType t
    ECase t _ _ -> type2LlvmType t
-- | LLVM type of a value.
--
-- Integers are 'I64', identifiers and functions carry their own type, and a
-- string constant is an 'I8' array as long as the string.
valueGetType :: LLVMValue -> LLVMType
valueGetType val = case val of
    VInteger _ -> I64
    VIdent _ ty -> ty
    VFunction _ _ ty -> ty
    VConstant str -> Array (fromIntegral (length str)) I8
-- | Size in bytes assigned to an LLVM type.
--
-- 'Ptr', 'Ref', 'Function' and 'CustomType' all count as 8 bytes; an array
-- is its length times the size of its element type.
typeByteSize :: LLVMType -> Integer
typeByteSize ty = case ty of
    I1 -> 1
    I8 -> 1
    I32 -> 4
    I64 -> 8
    Ptr -> 8
    Ref _ -> 8
    Function _ _ -> 8
    CustomType _ -> 8
    Array len el -> len * typeByteSize el

View file

@ -1,239 +1,241 @@
{-# LANGUAGE LambdaCase #-}
module Codegen.LlvmIr (
LLVMType (..),
LLVMIr (..),
llvmIrToString,
LLVMValue (..),
LLVMComp (..),
Visibility (..),
CallingConvention (..)
) where
import Data.List (intercalate)
import TypeChecker.TypeCheckerIr
-- | Calling conventions that can be attached to 'Define' and 'Call'.
data CallingConvention = TailCC | FastCC | CCC | ColdCC

-- | Renders the convention as the keyword used in the emitted LLVM text.
instance Show CallingConvention where
    show :: CallingConvention -> String
    show cc = case cc of
        TailCC -> "tailcc"
        FastCC -> "fastcc"
        CCC -> "ccc"
        ColdCC -> "coldcc"
-- | A datatype which represents some basic LLVM types
data LLVMType
    = I1
    | I8
    | I32
    | I64
    | Ptr
    | Ref LLVMType
    | Function LLVMType [LLVMType]
    | Array Integer LLVMType
    | CustomType Ident

-- | Renders a type the way it is written in the emitted LLVM text:
-- references and function types get a trailing @*@, arrays are
-- @[n x ty]@ and custom types are prefixed with @%@.
instance Show LLVMType where
    show :: LLVMType -> String
    show I1 = "i1"
    show I8 = "i8"
    show I32 = "i32"
    show I64 = "i64"
    show Ptr = "ptr"
    show (Ref inner) = show inner <> "*"
    show (Function ret args) =
        concat [show ret, " (", intercalate ", " (map show args), ")*"]
    show (Array len el) = "[" <> show len <> " x " <> show el <> "]"
    show (CustomType (Ident name)) = "%" <> name
-- | Comparison predicates accepted by the 'Icmp' instruction.
data LLVMComp
    = LLEq
    | LLNe
    | LLUgt
    | LLUge
    | LLUlt
    | LLUle
    | LLSgt
    | LLSge
    | LLSlt
    | LLSle

-- | Renders the predicate as the condition code printed after @icmp@.
instance Show LLVMComp where
    show :: LLVMComp -> String
    show LLEq = "eq"
    show LLNe = "ne"
    show LLUgt = "ugt"
    show LLUge = "uge"
    show LLUlt = "ult"
    show LLUle = "ule"
    show LLSgt = "sgt"
    show LLSge = "sge"
    show LLSlt = "slt"
    show LLSle = "sle"
-- | Whether a name is rendered as a local (@%@) or a global (@\@@) symbol.
data Visibility = Local | Global

-- | Renders the sigil that is put in front of the name.
instance Show Visibility where
    show :: Visibility -> String
    show vis = case vis of
        Local -> "%"
        Global -> "@"
-- | Represents an LLVM "value", as in an integer, a register variable,
-- a string constant or a reference to a function.
data LLVMValue
    = VInteger Integer
    | VIdent Ident LLVMType
    | VConstant String
    | VFunction Ident Visibility LLVMType

-- | Renders the value as an operand: integers as digits, identifiers with a
-- leading @%@, functions with the sigil of their 'Visibility', and string
-- constants as @c@ followed by the Haskell-quoted string.
instance Show LLVMValue where
    show :: LLVMValue -> String
    show (VInteger n) = show n
    show (VIdent (Ident name) _) = "%" <> name
    show (VFunction (Ident name) vis _) = show vis <> name
    show (VConstant str) = "c" <> show str
-- | Formal parameters of a definition: each name paired with its LLVM type.
type Params = [(Ident, LLVMType)]
-- | Arguments of a call: each LLVM type paired with the value passed.
type Args = [(LLVMType, LLVMValue)]
-- | A datatype which represents different instructions in LLVM
--
-- Values of this type are turned into text by 'insToString'. A function
-- body is the run of instructions between a 'Define' and the matching
-- 'DefineEnd'.
data LLVMIr
= Type Ident [LLVMType]
| Define CallingConvention LLVMType Ident Params
-- Closes the definition opened by the preceding 'Define' (rendered as "}").
| DefineEnd
| Declare LLVMType Ident Params
-- Binds the result of the nested instruction to a register: %name = <instr>.
| SetVariable Ident LLVMIr
| Variable Ident
| GetElementPtrInbounds LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue
| Add LLVMType LLVMValue LLVMValue
| Sub LLVMType LLVMValue LLVMValue
| Div LLVMType LLVMValue LLVMValue
| Mul LLVMType LLVMValue LLVMValue
| Srem LLVMType LLVMValue LLVMValue
| Icmp LLVMComp LLVMType LLVMValue LLVMValue
-- Branch targets and labels are plain names; the renderer adds 'lblPfx'.
| Br Ident
| BrCond LLVMValue Ident Ident
| Label Ident
| Call CallingConvention LLVMType Visibility Ident Args
| Alloca LLVMType
| Store LLVMType LLVMValue LLVMType Ident
| Load LLVMType LLVMType Ident
| Bitcast LLVMType Ident LLVMType
| Ret LLVMType LLVMValue
| Comment String
| UnsafeRaw String -- This should generally be avoided, and proper
-- instructions should be used in its place
deriving (Show)
-- | Render a list of instructions as LLVM assembly text.
--
-- The nesting depth goes up after a 'Define' and down after a 'DefineEnd';
-- those two are printed without indentation, everything else is indented by
-- the current depth.
llvmIrToString :: [LLVMIr] -> String
llvmIrToString instrs = render 0 instrs
  where
    render :: Int -> [LLVMIr] -> String
    render _ [] = ""
    render depth (ins : rest) = insToString indent ins <> render depth' rest
      where
        (depth', indent) = step depth ins

    -- Next depth paired with the indentation of the current instruction.
    step :: Int -> LLVMIr -> (Int, Int)
    step depth Define{} = (depth + 1, 0)
    step depth DefineEnd = (depth - 1, 0)
    step depth _ = (depth, depth)
{- | Converts a single LLVM instruction to a String, allowing for printing etc.
    The integer is the indentation depth; it is emitted as that many tab
    characters in front of the instruction.

    Every instruction ends with a newline except 'Variable' and 'UnsafeRaw',
    which are emitted exactly as given. 'Declare' is rendered as a @declare@
    line listing only the parameter types.
-}
{- FOURMOLU_DISABLE -}
insToString :: Int -> LLVMIr -> String
insToString i l =
    replicate i '\t' <> case l of
        (GetElementPtrInbounds t1 t2 p t3 v1 t4 v2) ->
            -- getelementptr inbounds %Foo, %Foo* %x, i32 0, i32 0
            concat
                [ "getelementptr inbounds ", show t1, ", " , show t2
                , " ", show p, ", ", show t3, " ", show v1
                , ", ", show t4, " ", show v2, "\n" ]
        (Type (Ident n) types) ->
            concat
                [ "%", n, " = type { "
                , intercalate ", " (map show types)
                , " }\n"
                ]
        (Define c t (Ident name) params) ->
            concat
                [ "define ", show c, " ", show t, " @", name
                , "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params)
                , ") {\n"
                ]
        DefineEnd -> "}\n"
        (Declare t (Ident name) params) ->
            -- A declaration has no body; only the parameter types are
            -- printed, the parameter names are dropped.
            concat
                [ "declare ", show t, " @", name
                , "(", intercalate ", " (map (show . snd) params)
                , ")\n"
                ]
        -- The nested instruction supplies its own trailing newline.
        (SetVariable (Ident name) ir) -> concat ["%", name, " = ", insToString 0 ir]
        (Add t v1 v2) ->
            concat
                [ "add ", show t, " ", show v1
                , ", ", show v2, "\n"
                ]
        (Sub t v1 v2) ->
            concat
                [ "sub ", show t, " ", show v1, ", "
                , show v2, "\n"
                ]
        (Div t v1 v2) ->
            concat
                [ "sdiv ", show t, " ", show v1, ", "
                , show v2, "\n"
                ]
        (Mul t v1 v2) ->
            concat
                [ "mul ", show t, " ", show v1
                , ", ", show v2, "\n"
                ]
        (Srem t v1 v2) ->
            concat
                [ "srem ", show t, " ", show v1, ", "
                , show v2, "\n"
                ]
        (Call c t vis (Ident name) arg) ->
            concat
                [ "call ", show c, " ", show t, " ", show vis, name, "("
                , intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg
                , ")\n"
                ]
        -- 'unwords' leaves a space before the newline; kept as is so the
        -- emitted text does not change.
        (Alloca t) -> unwords ["alloca", show t, "\n"]
        (Store t1 val t2 (Ident id2)) ->
            concat
                [ "store ", show t1, " ", show val
                , ", ", show t2 , " %", id2, "\n"
                ]
        (Load t1 t2 (Ident addr)) ->
            concat
                [ "load ", show t1, ", "
                , show t2, " %", addr, "\n"
                ]
        (Bitcast t1 (Ident name) t2) ->
            concat
                [ "bitcast ", show t1, " %"
                , name, " to ", show t2, "\n"
                ]
        (Icmp comp t v1 v2) ->
            concat
                [ "icmp ", show comp, " ", show t
                , " ", show v1, ", ", show v2, "\n"
                ]
        (Ret t v) ->
            concat
                [ "ret ", show t, " "
                , show v, "\n"
                ]
        (UnsafeRaw s) -> s
        (Label (Ident s)) -> "\n" <> lblPfx <> s <> ":\n"
        (Br (Ident s)) -> "br label %" <> lblPfx <> s <> "\n"
        (BrCond val (Ident s1) (Ident s2)) ->
            concat
                [ "br i1 ", show val, ", ", "label %"
                , lblPfx, s1, ", ", "label %", lblPfx, s2, "\n"
                ]
        (Comment s) -> "; " <> s <> "\n"
        (Variable (Ident name)) -> "%" <> name
{- FOURMOLU_ENABLE -}
-- | Prefix put in front of every label name when 'Label', 'Br' and 'BrCond'
-- are rendered.
lblPfx :: String
lblPfx = "lbl_"
module Codegen.LlvmIr where
-- {-# LANGUAGE LambdaCase #-}
--
-- module Codegen.LlvmIr (
-- LLVMType (..),
-- LLVMIr (..),
-- llvmIrToString,
-- LLVMValue (..),
-- LLVMComp (..),
-- Visibility (..),
-- CallingConvention (..)
-- ) where
--
-- import Data.List (intercalate)
-- import TypeChecker.TypeCheckerIr
--
-- data CallingConvention = TailCC | FastCC | CCC | ColdCC
-- instance Show CallingConvention where
-- show :: CallingConvention -> String
-- show TailCC = "tailcc"
-- show FastCC = "fastcc"
-- show CCC = "ccc"
-- show ColdCC = "coldcc"
--
-- -- | A datatype which represents some basic LLVM types
-- data LLVMType
-- = I1
-- | I8
-- | I32
-- | I64
-- | Ptr
-- | Ref LLVMType
-- | Function LLVMType [LLVMType]
-- | Array Integer LLVMType
-- | CustomType Ident
--
-- instance Show LLVMType where
-- show :: LLVMType -> String
-- show = \case
-- I1 -> "i1"
-- I8 -> "i8"
-- I32 -> "i32"
-- I64 -> "i64"
-- Ptr -> "ptr"
-- Ref ty -> show ty <> "*"
-- Function t xs -> show t <> " (" <> intercalate ", " (map show xs) <> ")*"
-- Array n ty -> concat ["[", show n, " x ", show ty, "]"]
-- CustomType (Ident ty) -> "%" <> ty
--
-- data LLVMComp
-- = LLEq
-- | LLNe
-- | LLUgt
-- | LLUge
-- | LLUlt
-- | LLUle
-- | LLSgt
-- | LLSge
-- | LLSlt
-- | LLSle
-- instance Show LLVMComp where
-- show :: LLVMComp -> String
-- show = \case
-- LLEq -> "eq"
-- LLNe -> "ne"
-- LLUgt -> "ugt"
-- LLUge -> "uge"
-- LLUlt -> "ult"
-- LLUle -> "ule"
-- LLSgt -> "sgt"
-- LLSge -> "sge"
-- LLSlt -> "slt"
-- LLSle -> "sle"
--
-- data Visibility = Local | Global
-- instance Show Visibility where
-- show :: Visibility -> String
-- show Local = "%"
-- show Global = "@"
--
-- -- | Represents a LLVM "value", as in an integer, a register variable,
-- -- or a string contstant
-- data LLVMValue
-- = VInteger Integer
-- | VIdent Ident LLVMType
-- | VConstant String
-- | VFunction Ident Visibility LLVMType
--
-- instance Show LLVMValue where
-- show :: LLVMValue -> String
-- show v = case v of
-- VInteger i -> show i
-- VIdent (Ident n) _ -> "%" <> n
-- VFunction (Ident n) vis _ -> show vis <> n
-- VConstant s -> "c" <> show s
--
-- type Params = [(Ident, LLVMType)]
-- type Args = [(LLVMType, LLVMValue)]
--
-- -- | A datatype which represents different instructions in LLVM
-- data LLVMIr
-- = Type Ident [LLVMType]
-- | Define CallingConvention LLVMType Ident Params
-- | DefineEnd
-- | Declare LLVMType Ident Params
-- | SetVariable Ident LLVMIr
-- | Variable Ident
-- | GetElementPtrInbounds LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue
-- | Add LLVMType LLVMValue LLVMValue
-- | Sub LLVMType LLVMValue LLVMValue
-- | Div LLVMType LLVMValue LLVMValue
-- | Mul LLVMType LLVMValue LLVMValue
-- | Srem LLVMType LLVMValue LLVMValue
-- | Icmp LLVMComp LLVMType LLVMValue LLVMValue
-- | Br Ident
-- | BrCond LLVMValue Ident Ident
-- | Label Ident
-- | Call CallingConvention LLVMType Visibility Ident Args
-- | Alloca LLVMType
-- | Store LLVMType LLVMValue LLVMType Ident
-- | Load LLVMType LLVMType Ident
-- | Bitcast LLVMType Ident LLVMType
-- | Ret LLVMType LLVMValue
-- | Comment String
-- | UnsafeRaw String -- This should generally be avoided, and proper
-- -- instructions should be used in its place
-- deriving (Show)
--
-- -- | Converts a list of LLVMIr instructions to a string
-- llvmIrToString :: [LLVMIr] -> String
-- llvmIrToString = go 0
-- where
-- go :: Int -> [LLVMIr] -> String
-- go _ [] = mempty
-- go i (x : xs) = do
-- let (i', n) = case x of
-- Define{} -> (i + 1, 0)
-- DefineEnd -> (i - 1, 0)
-- _ -> (i, i)
-- insToString n x <> go i' xs
--
-- {- | Converts a LLVM inststruction to a String, allowing for printing etc.
-- The integer represents the indentation
-- -}
-- {- FOURMOLU_DISABLE -}
-- insToString :: Int -> LLVMIr -> String
-- insToString i l =
-- replicate i '\t' <> case l of
-- (GetElementPtrInbounds t1 t2 p t3 v1 t4 v2) -> do
-- -- getelementptr inbounds %Foo, %Foo* %x, i32 0, i32 0
-- concat
-- [ "getelementptr inbounds ", show t1, ", " , show t2
-- , " ", show p, ", ", show t3, " ", show v1,
-- ", ", show t4, " ", show v2, "\n" ]
-- (Type (Ident n) types) ->
-- concat
-- [ "%", n, " = type { "
-- , intercalate ", " (map show types)
-- , " }\n"
-- ]
-- (Define c t (Ident i) params) ->
-- concat
-- [ "define ", show c, " ", show t, " @", i
-- , "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params)
-- , ") {\n"
-- ]
-- DefineEnd -> "}\n"
-- (Declare _t (Ident _i) _params) -> undefined
-- (SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir]
-- (Add t v1 v2) ->
-- concat
-- [ "add ", show t, " ", show v1
-- , ", ", show v2, "\n"
-- ]
-- (Sub t v1 v2) ->
-- concat
-- [ "sub ", show t, " ", show v1, ", "
-- , show v2, "\n"
-- ]
-- (Div t v1 v2) ->
-- concat
-- [ "sdiv ", show t, " ", show v1, ", "
-- , show v2, "\n"
-- ]
-- (Mul t v1 v2) ->
-- concat
-- [ "mul ", show t, " ", show v1
-- , ", ", show v2, "\n"
-- ]
-- (Srem t v1 v2) ->
-- concat
-- [ "srem ", show t, " ", show v1, ", "
-- , show v2, "\n"
-- ]
-- (Call c t vis (Ident i) arg) ->
-- concat
-- [ "call ", show c, " ", show t, " ", show vis, i, "("
-- , intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg
-- , ")\n"
-- ]
-- (Alloca t) -> unwords ["alloca", show t, "\n"]
-- (Store t1 val t2 (Ident id2)) ->
-- concat
-- [ "store ", show t1, " ", show val
-- , ", ", show t2 , " %", id2, "\n"
-- ]
-- (Load t1 t2 (Ident addr)) ->
-- concat
-- [ "load ", show t1, ", "
-- , show t2, " %", addr, "\n"
-- ]
-- (Bitcast t1 (Ident i) t2) ->
-- concat
-- [ "bitcast ", show t1, " %"
-- , i, " to ", show t2, "\n"
-- ]
-- (Icmp comp t v1 v2) ->
-- concat
-- [ "icmp ", show comp, " ", show t
-- , " ", show v1, ", ", show v2, "\n"
-- ]
-- (Ret t v) ->
-- concat
-- [ "ret ", show t, " "
-- , show v, "\n"
-- ]
-- (UnsafeRaw s) -> s
-- (Label (Ident s)) -> "\n" <> lblPfx <> s <> ":\n"
-- (Br (Ident s)) -> "br label %" <> lblPfx <> s <> "\n"
-- (BrCond val (Ident s1) (Ident s2)) ->
-- concat
-- [ "br i1 ", show val, ", ", "label %"
-- , lblPfx, s1, ", ", "label %", lblPfx, s2, "\n"
-- ]
-- (Comment s) -> "; " <> s <> "\n"
-- (Variable (Ident id)) -> "%" <> id
-- {- FOURMOLU_ENABLE -}
--
-- lblPfx :: String
-- lblPfx = "lbl_"
--

View file

@ -1,235 +1,192 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
--{-# LANGUAGE LambdaCase #-}
--{-# LANGUAGE OverloadedStrings #-}
module LambdaLifter.LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where
module LambdaLifter.LambdaLifter where
import Auxiliary (snoc)
import Control.Applicative (Applicative (liftA2))
import Control.Monad.State (MonadState (get, put), State,
evalState)
import Data.Set (Set)
import qualified Data.Set as Set
import Debug.Trace (trace)
import qualified Grammar.Abs as GA
import Prelude hiding (exp)
import Renamer.Renamer
import TypeChecker.TypeCheckerIr
--import Auxiliary (snoc)
--import Control.Applicative (Applicative (liftA2))
--import Control.Monad.State (MonadState (get, put), State,
-- evalState)
--import Data.Set (Set)
--import qualified Data.Set as Set
--import Prelude hiding (exp)
--import Renamer.Renamer
--import TypeChecker.TypeCheckerIr
-- | Lift lambdas and let expressions into supercombinators.
-- Runs three phases, in this order:
--
--   1. @freeVars@ annotates all the free variables.
--   2. @abstract@ converts lambdas into let expressions.
--   3. @collectScs@ moves every non-constant let expression to a top-level
--      function.
lambdaLift :: Program -> Program
lambdaLift prog = collectScs (abstract (freeVars prog))
-- | Annotate every top-level bind with free-variable information.
-- The parameters of a bind form the initial local scope of its body.
freeVars :: Program -> AnnProgram
freeVars (Program binds) = map annotate binds
  where
    annotate (Bind name params rhs) =
        (name, params, freeVarsExp (Set.fromList params) rhs)
-- | Annotate an expression (and all its sub-expressions) with the set of
-- variables that occur free in it.
--
-- @localVars@ is the set of identifiers bound in the enclosing scope. Only
-- members of that set are ever reported as free; any other identifier
-- contributes nothing.
freeVarsExp :: Set Id -> Exp -> AnnExp
freeVarsExp localVars = \case
    EId n | Set.member n localVars -> (Set.singleton n, AId n)
          | otherwise -> (mempty, AId n)
    EInt i -> (mempty, AInt i)
    EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2')
      where
        e1' = freeVarsExp localVars e1
        e2' = freeVarsExp localVars e2
    EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2')
      where
        e1' = freeVarsExp localVars e1
        e2' = freeVarsExp localVars e2
    ESub t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), ASub t e1' e2')
      where
        e1' = freeVarsExp localVars e1
        e2' = freeVarsExp localVars e2
---- | Lift lambdas and let expression into supercombinators.
---- Three phases:
---- @freeVars@ annotatss all the free variables.
---- @abstract@ converts lambdas into let expressions.
---- @collectScs@ moves every non-constant let expression to a top-level function.
--lambdaLift :: Program -> Program
--lambdaLift = collectScs . abstract . freeVars
    -- The parameter is in scope inside the body but is not free in the
    -- lambda itself.
    EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e')
      where
        e' = freeVarsExp (Set.insert par localVars) e
---- | Annotate free variables
--freeVars :: Program -> AnnProgram
--freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e)
-- | Bind n xs e <- ds
-- ]
    -- Sum free variables present in bind and the expression
    ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e')
      where
        binders_frees = Set.delete name $ freeVarsOf rhs'
        e_free = Set.delete name $ freeVarsOf e'
--freeVarsExp :: Set Id -> Exp -> AnnExp
--freeVarsExp localVars = \case
-- EId n | Set.member n localVars -> (Set.singleton n, AId n)
-- | otherwise -> (mempty, AId n)
        rhs' = freeVarsExp e_localVars rhs
        new_bind = ABind name parms rhs'
-- ELit _ (LInt i) -> (mempty, AInt i)
        e' = freeVarsExp e_localVars e
        e_localVars = Set.insert name localVars
-- EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2')
-- where
-- e1' = freeVarsExp localVars e1
-- e2' = freeVarsExp localVars e2
    ECase t e cs ->
        let e' = freeVarsExp localVars e
            -- Every branch is analysed against the enclosing scope. Using
            -- the free variables collected so far as the scope would hide
            -- locals that are used only inside a branch.
            annBranch (_, Case c rhs) (vars, acc) =
                let rhs' = freeVarsExp localVars rhs
                 in (Set.union vars (freeVarsOf rhs'), AnnCase c rhs' : acc)
            (vars', cs') = foldr annBranch (freeVarsOf e', []) cs
         in -- NOTE(review): 'foldr' already keeps the branches in source
            -- order, so this 'reverse' flips them. Left untouched because an
            -- earlier pass may rely on it -- confirm before changing.
            (vars', ACase t e' (reverse cs'))
-- EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2')
-- where
-- e1' = freeVarsExp localVars e1
-- e2' = freeVarsExp localVars e2
-- EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e')
-- where
-- e' = freeVarsExp (Set.insert par localVars) e
-- -- Sum free variables present in bind and the expression
-- ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e')
-- where
-- binders_frees = Set.delete name $ freeVarsOf rhs'
-- e_free = Set.delete name $ freeVarsOf e'
-- rhs' = freeVarsExp e_localVars rhs
-- new_bind = ABind name parms rhs'
-- e' = freeVarsExp e_localVars e
-- e_localVars = Set.insert name localVars
-- | The free-variable set stored in an annotated expression.
freeVarsOf :: AnnExp -> Set Id
freeVarsOf (frees, _) = frees
--freeVarsOf :: AnnExp -> Set Id
--freeVarsOf = fst
-- AST annotated with free variables
-- One entry per top-level bind: its name, its parameters and its annotated body.
type AnnProgram = [(Id, [Id], AnnExp)]
---- AST annotated with free variables
--type AnnProgram = [(Id, [Id], AnnExp)]
-- An expression node paired with the set of variables that occur free in it.
type AnnExp = (Set Id, AnnExp')
--type AnnExp = (Set Id, AnnExp')
-- Annotated counterpart of a bind: name, parameters and annotated right-hand side.
data ABind = ABind Id [Id] AnnExp deriving Show
--data ABind = ABind Id [Id] AnnExp deriving Show
-- | Expression nodes of the annotated AST. Each constructor mirrors the
-- expression form it is built from in 'freeVarsExp'; sub-expressions are
-- 'AnnExp', so every level carries its own free-variable set.
data AnnExp' = AId Id
| AInt Integer
| ALet ABind AnnExp
| AApp Type AnnExp AnnExp
| AAdd Type AnnExp AnnExp
| ASub Type AnnExp AnnExp
| AAbs Type Id AnnExp
| ACase Type AnnExp [AnnCase]
deriving Show
-- | One branch of an annotated case expression: the pattern from the
-- grammar and the annotated branch body.
data AnnCase = AnnCase GA.Case AnnExp
deriving Show
--data AnnExp' = AId Id
-- | AInt Integer
-- | ALet ABind AnnExp
-- | AApp Type AnnExp AnnExp
-- | AAdd Type AnnExp AnnExp
-- | AAbs Type Id AnnExp
-- deriving Show
---- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
---- Free variables are @v₁ v₂ .. vₙ@ are bound.
--abstract :: AnnProgram -> Program
--abstract prog = Program $ evalState (mapM go prog) 0
-- where
-- go :: (Id, [Id], AnnExp) -> State Int Bind
-- go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs'
-- where
-- (rhs', parms1) = flattenLambdasAnn rhs
-- | Lift lambdas to let expressions of the form @let sc = \v₁ x₁ -> e₁@,
-- where the free variables @v₁ v₂ .. vₙ@ become extra parameters.
--
-- Leading lambdas of a top-level right-hand side are folded into the
-- parameter list of the bind before the body is abstracted. The counter
-- used for fresh supercombinator names starts at 0 and is shared by all
-- binds.
abstract :: AnnProgram -> Program
abstract annotated = Program (evalState (traverse liftTop annotated) 0)
  where
    liftTop :: (Id, [Id], AnnExp) -> State Int Bind
    liftTop (name, params, rhs) = do
        let (body, lambdaParams) = flattenLambdasAnn rhs
        body' <- abstractExp body
        pure (Bind name (params ++ lambdaParams) body')
---- | Flatten nested lambdas and collect the parameters
---- @\x.\y.\z. ae → (ae, [x,y,z])@
--flattenLambdasAnn :: AnnExp -> (AnnExp, [Id])
--flattenLambdasAnn ae = go (ae, [])
-- where
-- go :: (AnnExp, [Id]) -> (AnnExp, [Id])
-- go ((free, e), acc) =
-- case e of
-- AAbs _ par (free1, e1) ->
-- go ((Set.delete par free1, e1), snoc par acc)
-- _ -> ((free, e), acc)
--abstractExp :: AnnExp -> State Int Exp
--abstractExp (free, exp) = case exp of
-- AId n -> pure $ EId n
-- AInt i -> pure $ ELit (TMono "Int") (LInt i)
-- AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2)
-- AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2)
-- ALet b e -> liftA2 ELet (go b) (abstractExp e)
-- where
-- go (ABind name parms rhs) = do
-- (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
-- pure $ Bind name (parms ++ parms1) rhs'
-- skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp
-- skipLambdas f (free, ae) = case ae of
-- AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1
-- _ -> f (free, ae)
-- -- Lift lambda into let and bind free variables
-- AAbs t parm e -> do
-- i <- nextNumber
-- rhs <- abstractExp e
-- let sc_name = Ident ("sc_" ++ show i)
-- sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t)
-- pure $ foldl (EApp $ TMono "Int") sc $ map EId freeList
-- where
-- freeList = Set.toList free
-- parms = snoc parm freeList
-- | Strip the leading lambdas off an annotated expression and return the
-- remaining body together with the stripped parameters, outermost first:
-- @\x.\y.\z. ae → (ae, [x,y,z])@
--
-- Each stripped parameter is also removed from the free-variable set of
-- the body it was bound over.
flattenLambdasAnn :: AnnExp -> (AnnExp, [Id])
flattenLambdasAnn = peel []
  where
    peel :: [Id] -> AnnExp -> (AnnExp, [Id])
    peel params (_, AAbs _ par (innerFree, inner)) =
        peel (snoc par params) (Set.delete par innerFree, inner)
    peel params other = (other, params)
--nextNumber :: State Int Int
--nextNumber = do
-- i <- get
-- put $ succ i
-- pure i
-- | Turn an annotated expression back into a plain expression, replacing
-- every lambda by a let-bound supercombinator that takes the lambda's free
-- variables as extra leading parameters. The 'State' counter supplies the
-- numbers used in the generated @sc_N@ names.
abstractExp :: AnnExp -> State Int Exp
abstractExp (free, exp) = case exp of
AId n -> pure $ EId n
AInt i -> pure $ EInt i
AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2)
AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2)
ASub t e1 e2 -> liftA2 (ESub t) (abstractExp e1) (abstractExp e2)
ALet b e -> liftA2 ELet (go b) (abstractExp e)
where
-- Lambdas at the head of a let-bound right-hand side are not lifted; they
-- are kept as lambdas by skipLambdas and then folded into the bind's
-- parameter list by flattenLambdas.
go (ABind name parms rhs) = do
(rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
pure $ Bind name (parms ++ parms1) rhs'
-- Rebuilds the leading lambdas unchanged and applies f to the first
-- non-lambda body underneath them.
skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp
skipLambdas f (free, ae) = case ae of
AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1
_ -> f (free, ae)
---- | Collects supercombinators by lifting non-constant let expressions
--collectScs :: Program -> Program
--collectScs (Program scs) = Program $ concatMap collectFromRhs scs
-- where
-- collectFromRhs (Bind name parms rhs) =
-- let (rhs_scs, rhs') = collectScsExp rhs
-- in Bind name parms rhs' : rhs_scs
-- The scrutinee is abstracted first, then the branches in list order; each
-- rebuilt branch is paired with the type of the whole case expression.
ACase t e cs -> do
e' <- abstractExp e
cs' <- mapM (\(AnnCase c e) -> do
e' <- abstractExp e
pure (t,Case c e')) cs
pure $ ECase t e' cs'
--collectScsExp :: Exp -> ([Bind], Exp)
--collectScsExp = \case
-- EId n -> ([], EId n)
-- ELit _ (LInt i) -> ([], ELit (TMono "Int") (LInt i))
-- Lift lambda into let and bind free variables
AAbs t parm e -> do
-- The number is drawn before the body is abstracted, so an outer lambda
-- gets a smaller number than the lambdas nested inside it.
i <- nextNumber
rhs <- abstractExp e
-- EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2')
-- where
-- (scs1, e1') = collectScsExp e1
-- (scs2, e2') = collectScsExp e2
let sc_name = Ident ("sc_" ++ show i)
sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t)
-- EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2')
-- where
-- (scs1, e1') = collectScsExp e1
-- (scs2, e2') = collectScsExp e2
-- The new supercombinator is applied to its free variables right away.
-- NOTE(review): every such application is annotated with TInt regardless
-- of the lambda's type t -- confirm this is intended.
pure $ foldl (EApp TInt) sc $ map EId freeList
where
-- Free variables come first in the parameter list, the lambda's own
-- parameter last.
freeList = Set.toList free
parms = snoc parm freeList
-- EAbs t par e -> (scs, EAbs t par e')
-- where
-- (scs, e') = collectScsExp e
-- -- Collect supercombinators from bind, the rhss, and the expression.
-- --
-- -- > f = let sc x y = rhs in e
-- --
-- ELet (Bind name parms rhs) e -> if null parms
-- then ( rhs_scs ++ e_scs, ELet bind e')
-- else (bind : rhs_scs ++ e_scs, e')
-- where
-- bind = Bind name parms rhs'
-- (rhs_scs, rhs') = collectScsExp rhs
-- (e_scs, e') = collectScsExp e
-- | Return the current counter value and advance the counter by one.
nextNumber :: State Int Int
nextNumber = get >>= \current -> current <$ put (succ current)
---- @\x.\y.\z. e → (e, [x,y,z])@
--flattenLambdas :: Exp -> (Exp, [Id])
--flattenLambdas = go . (, [])
-- where
-- go (e, acc) = case e of
-- EAbs _ par e1 -> go (e1, snoc par acc)
-- _ -> (e, acc)
-- | Collect supercombinators by lifting non-constant let expressions.
-- Each top-level bind is followed directly by the binds lifted out of its
-- right-hand side.
collectScs :: Program -> Program
collectScs (Program scs) = Program (concatMap liftBind scs)
  where
    liftBind (Bind name parms rhs) = Bind name parms body : lifted
      where
        (lifted, body) = collectScsExp rhs
-- | Pull every let-bound function (a let whose bind has parameters) out of
-- an expression. Returns the lifted binds together with the expression that
-- remains once they are removed; lets without parameters stay in place.
collectScsExp :: Exp -> ([Bind], Exp)
collectScsExp expr = case expr of
    EId n -> ([], EId n)
    EInt i -> ([], EInt i)
    EApp t l r -> binary (EApp t) l r
    EAdd t l r -> binary (EAdd t) l r
    ESub t l r -> binary (ESub t) l r
    EAbs t par body ->
        let (scs, body') = collectScsExp body
         in (scs, EAbs t par body')
    -- Collect supercombinators from the bind, its rhs, and the body.
    --
    -- > f = let sc x y = rhs in e
    --
    ELet (Bind name parms rhs) body
        | null parms -> (rhsScs ++ bodyScs, ELet bind body')
        | otherwise -> (bind : rhsScs ++ bodyScs, body')
      where
        bind = Bind name parms rhs'
        (rhsScs, rhs') = collectScsExp rhs
        (bodyScs, body') = collectScsExp body
    ECase t scrut branches ->
        let (scrutScs, scrut') = collectScsExp scrut
            lifted = map liftBranch branches
            -- Binds from the scrutinee come first, followed by those of
            -- the branches from the last branch to the first.
            branchScs = concatMap fst (reverse lifted)
         in (scrutScs ++ branchScs, ECase t scrut' (map snd lifted))
  where
    -- Shared handling of the two-operand forms: left binds, then right.
    binary mk l r = (lScs ++ rScs, mk l' r')
      where
        (lScs, l') = collectScsExp l
        (rScs, r') = collectScsExp r

    -- Lift the body of one case branch, keeping its type and pattern.
    liftBranch (bt, Case c rhs) =
        let (scs, rhs') = collectScsExp rhs
         in (scs, (bt, Case c rhs'))
-- | Strip the leading lambdas off an expression and return the remaining
-- body with the stripped parameters, outermost first:
-- @\x.\y.\z. e → (e, [x,y,z])@
flattenLambdas :: Exp -> (Exp, [Id])
flattenLambdas = peel []
  where
    peel params (EAbs _ par body) = peel (snoc par params) body
    peel params other = (other, params)

View file

@ -2,26 +2,26 @@
module Main where
import Codegen.Codegen (generateCode)
import GHC.IO.Handle.Text (hPutStrLn)
import Grammar.ErrM (Err)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
--import Codegen.Codegen (generateCode)
import GHC.IO.Handle.Text (hPutStrLn)
import Grammar.ErrM (Err)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
-- import Interpreter (interpret)
import Control.Monad (when)
import Data.List.Extra (isSuffixOf)
import LambdaLifter.LambdaLifter (lambdaLift)
import Renamer.Renamer (rename)
import System.Directory (createDirectory, doesPathExist,
getDirectoryContents,
removeDirectoryRecursive,
setCurrentDirectory)
import System.Environment (getArgs)
import System.Exit (exitFailure, exitSuccess)
import System.IO (stderr)
import System.Process.Extra (spawnCommand, waitForProcess)
import TypeChecker.TypeChecker (typecheck)
import Control.Monad (when)
import Data.List.Extra (isSuffixOf)
--import LambdaLifter.LambdaLifter (lambdaLift)
import Renamer.Renamer (rename)
import System.Directory (createDirectory, doesPathExist,
getDirectoryContents,
removeDirectoryRecursive,
setCurrentDirectory)
import System.Environment (getArgs)
import System.Exit (exitFailure, exitSuccess)
import System.IO (stderr)
import System.Process.Extra (spawnCommand, waitForProcess)
import TypeChecker.TypeChecker (typecheck)
main :: IO ()
main =
@ -46,19 +46,19 @@ main' debug s = do
typechecked <- fromTypeCheckerErr $ typecheck renamed
printToErr $ printTree typechecked
printToErr "\n-- Lambda Lifter --"
let lifted = lambdaLift typechecked
printToErr $ printTree lifted
printToErr "\n -- Printing compiler output to stdout --"
compiled <- fromCompilerErr $ generateCode lifted
-- printToErr "\n-- Lambda Lifter --"
-- let lifted = lambdaLift typechecked
-- printToErr $ printTree lifted
--
-- printToErr "\n -- Printing compiler output to stdout --"
-- compiled <- fromCompilerErr $ generateCode lifted
--putStrLn compiled
check <- doesPathExist "output"
when check (removeDirectoryRecursive "output")
createDirectory "output"
writeFile "output/llvm.ll" compiled
if debug then debugDotViz else putStrLn compiled
-- check <- doesPathExist "output"
-- when check (removeDirectoryRecursive "output")
-- createDirectory "output"
-- writeFile "output/llvm.ll" compiled
-- if debug then debugDotViz else putStrLn compiled
-- interpred <- fromInterpreterErr $ interpret lifted

View file

@ -1,86 +1,87 @@
{-# LANGUAGE LambdaCase #-}
module Renamer.Renamer (module Renamer.Renamer) where
module Renamer.Renamer where
import Auxiliary (mapAccumM)
import Control.Monad (foldM)
import Control.Monad.State (MonadState, State, evalState, gets,
modify)
import Data.List (foldl')
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe)
import Grammar.Abs
-- | Rename all variables and local binds
rename :: Program -> Program
rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0
where
initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs
renameSc :: Names -> Bind -> Rn Bind
renameSc old_names (Bind name t _ parms rhs) = do
-- initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs
initNames = Map.fromList $ foldl' saveIfBind [] bs
saveIfBind acc (DBind (Bind name _ _ _ _)) = dupe name : acc
saveIfBind acc _ = acc
renameSc :: Names -> Def -> Rn Def
renameSc old_names (DBind (Bind name t _ parms rhs)) = do
(new_names, parms') <- newNames old_names parms
rhs' <- snd <$> renameExp new_names rhs
pure $ Bind name t name parms' rhs'
rhs' <- snd <$> renameExp new_names rhs
pure . DBind $ Bind name t name parms' rhs'
renameSc _ def = pure def
-- | Rename monad. State holds the number of renamed names.
newtype Rn a = Rn { runRn :: State Int a }
deriving (Functor, Applicative, Monad, MonadState Int)
newtype Rn a = Rn {runRn :: State Int a}
deriving (Functor, Applicative, Monad, MonadState Int)
-- | Maps old to new name
type Names = Map Ident Ident
renameLocalBind :: Names -> Bind -> Rn (Names, Bind)
renameLocalBind old_names (Bind name t _ parms rhs) = do
(new_names, name') <- newName old_names name
(new_names, name') <- newName old_names name
(new_names', parms') <- newNames new_names parms
(new_names'', rhs') <- renameExp new_names' rhs
(new_names'', rhs') <- renameExp new_names' rhs
pure (new_names'', Bind name' t name' parms' rhs')
renameExp :: Names -> Exp -> Rn (Names, Exp)
renameExp old_names = \case
EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names)
EInt i1 -> pure (old_names, EInt i1)
EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names)
ELit (LInt i1) -> pure (old_names, ELit (LInt i1))
EApp e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EApp e1' e2')
EAdd e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EAdd e1' e2')
ESub e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, ESub e1' e2')
ELet b e -> do
(new_names, b) <- renameLocalBind old_names b
(new_names', e') <- renameExp new_names e
pure (new_names', ELet b e')
EAbs par t e -> do
ELet i e1 e2 -> do
(new_names, e1') <- renameExp old_names e1
(new_names', e2') <- renameExp new_names e2
pure (new_names', ELet i e1' e2')
EAbs par e -> do
(new_names, par') <- newName old_names par
(new_names', e') <- renameExp new_names e
pure (new_names', EAbs par' t e')
(new_names', e') <- renameExp new_names e
pure (new_names', EAbs par' e')
EAnn e t -> do
(new_names, e') <- renameExp old_names e
pure (new_names, EAnn e' t)
ECase e injs -> do
(_, e') <- renameExp old_names e
(new_names, injs') <- renameInjs old_names injs
pure (new_names, ECase e' injs')
ECase e cs t -> do
(new_names, e') <- renameExp old_names e
(new_names', cs') <- foldM (\(names, stack) (CaseMatch c exp) -> do
(nm,exp') <- renameExp names exp
pure (nm,CaseMatch c exp' : stack)
) (new_names, []) cs
pure (new_names', ECase e' cs' t)
-- | Rename the bodies of all case branches, each one starting from the same
-- name environment. The environment handed back is the one produced by the
-- first branch, or an empty one when there are no branches.
renameInjs :: Names -> [Inj] -> Rn (Names, [Inj])
renameInjs names injs = do
    renamed <- mapM (renameInj names) injs
    let injs' = map snd renamed
    pure $ case renamed of
        [] -> (mempty, injs')
        (firstNames, _) : _ -> (firstNames, injs')
-- | Rename the body of a single case branch; the pattern is left untouched.
renameInj :: Names -> Inj -> Rn (Names, Inj)
renameInj names (Inj pat rhs) = fmap (Inj pat) <$> renameExp names rhs
-- | Create a new name and add it to name environment.
newName :: Names -> Ident -> Rn (Names, Ident)
@ -95,4 +96,3 @@ newNames = mapAccumM newName
-- | Build a fresh name by appending the current counter to the given one,
-- @prefix ⇒ prefix_number@, then advance the counter.
makeName :: Ident -> Rn Ident
makeName (Ident prefix) = do
    fresh <- gets (Ident . (prefix ++) . ('_' :) . show)
    modify succ
    pure fresh

View file

@ -1,215 +1,517 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module TypeChecker.TypeChecker (typecheck, partitionType) where
-- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeChecker where
import Auxiliary (maybeToRightM, snoc)
import Control.Monad.Except (throwError, unless)
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Foldable (traverse_)
import Data.Functor.Identity (runIdentity)
import Data.List (foldl')
import Data.Map (Map)
import qualified Data.Map as Map
import qualified Data.Map as M
import Data.Set (Set)
import qualified Data.Set as S
import Debug.Trace (trace)
import Grammar.Abs
import Grammar.ErrM (Err)
import Grammar.Print (Print (prt), concatD, doc,
printTree, render)
import Prelude hiding (exp, id)
import Grammar.Print (printTree)
import qualified TypeChecker.TypeCheckerIr as T
import TypeChecker.TypeCheckerIr (Ctx (..), Env (..), Error, Infer,
Poly (..), Subst)
-- NOTE: this type checker is poorly tested
-- | The initial reader context: no local variables are bound.
initCtx :: Ctx
initCtx = Ctx mempty
-- TODO
-- Coercion
-- Type inference
-- | The initial inference state: the counter starts at 0 and no signatures
-- or constructors are known yet.
initEnv :: Env
initEnv =
    Env
        { count = 0
        , sigs = mempty
        , constructors = mempty
        }
data Cxt = Cxt
{ env :: Map Ident Type -- ^ Local scope signature
, sig :: Map Ident Type -- ^ Top-level signatures
}
-- | Infer the type of an expression starting from 'initEnv' / 'initCtx' and
-- pretty-print the inferred type.
runPretty :: Exp -> Either Error String
runPretty e = printTree . fst <$> run (inferExp e)
initCxt :: [Bind] -> Cxt
initCxt sc = Cxt { env = mempty
, sig = Map.fromList $ map (\(Bind n t _ _ _) -> (n, t)) sc
}
-- | Run an inference action with the initial state and context.
run :: Infer a -> Either Error a
run action = runC initEnv initCtx action
typecheck :: Program -> Err T.Program
typecheck (Program sc) = T.Program <$> mapM (checkBind $ initCxt sc) sc
-- | Run an inference action with an explicit initial state and context,
-- discarding the final state.
runC :: Env -> Ctx -> Infer a -> Either Error a
runC e c action =
    runIdentity (runExceptT (runReaderT (evalStateT action e) c))
-- | Check if infered rhs type matches type signature.
checkBind :: Cxt -> Bind -> Err T.Bind
checkBind cxt b =
case expandLambdas b of
Bind name t _ parms rhs -> do
(rhs', t_rhs) <- infer cxt rhs
unless (typeEq t_rhs t) . throwError $ typeErr name t t_rhs
pure $ T.Bind (name, t) (zip parms ts_parms) rhs'
where
ts_parms = fst $ partitionType (length parms) t
-- | Type check a whole program, returning either an error message or the
-- typed program.
typecheck :: Program -> Either Error T.Program
typecheck prg = run (checkPrg prg)
-- | @ f x y = rhs ⇒ f = \x.\y. rhs @
expandLambdas :: Bind -> Bind
expandLambdas (Bind name t _ parms rhs) = Bind name t name [] rhs'
{- | Replace the type variables of a data declaration with fresh ones, so that
user-chosen variable names cannot clash with other polymorphic types.

Every /distinct/ type variable gets its own fresh replacement, and the same
replacement is used in the type head and in every constructor. Data types
with several parameters therefore keep their parameters distinct.
-}
freshenData :: Data -> Infer Data
freshenData (Data (Constr name ts) constrs) = do
    -- All type variables mentioned anywhere in the declaration.
    let tvars =
            S.toList $
                foldMap free ts
                    <> foldMap (\(Constructor _ t) -> free t) constrs
    freshVars <- mapM (const fresh) tvars
    -- 'apply' substitutes in a single pass, so the renaming is simultaneous.
    let renaming = M.fromList (zip tvars freshVars)
    let rename (Constructor cname t) = Constructor cname (apply renaming t)
    return $ Data (Constr name (map (apply renaming) ts)) (map rename constrs)
{- | Replace every polymorphic variable in a type with the given identifier,
regardless of the variable's original name.

@freshenType "d" (a -> b -> c)@ becomes @(d -> d -> d)@.
-}
freshenType :: Ident -> Type -> Type
freshenType iden = go
  where
    go (TPol _) = TPol iden
    go (TArr a b) = TArr (go a) (go b)
    go (TConstr (Constr a ts)) = TConstr (Constr a (map go ts))
    go other = other
-- | Apply 'freshenType' to the type of a data constructor, keeping its name.
freshenConstr :: Ident -> Constructor -> Constructor
freshenConstr iden (Constructor name t) =
    let t' = freshenType iden t
     in Constructor name t'
{- | Validate a data declaration and register its constructors.

The declaration is first passed through 'freshenData'. It is rejected when a
parameter of the declared type is not a type variable, or when the return type
of a constructor differs from the declared type. Each accepted constructor is
recorded with 'insertConstr'.
-}
checkData :: Data -> Infer ()
checkData d = do
    d' <- freshenData d
    case d' of
        Data typ@(Constr _ ts) constrs -> do
            unless (all isPoly ts) $
                throwError "Data type incorrectly declared"
            traverse_ (register typ) constrs
  where
    register typ (Constructor cname ctype)
        | TConstr typ == retType ctype = insertConstr cname ctype
        | otherwise =
            throwError $
                unwords
                    [ "return type of constructor:"
                    , -- Report the offending constructor, not the data type.
                      printTree cname
                    , "with type:"
                    , printTree (retType ctype)
                    , "does not match data: "
                    , printTree typ
                    ]
-- | The final result type of a (possibly curried) function type; any other
-- type is returned unchanged.
retType :: Type -> Type
retType = \case
    TArr _ result -> retType result
    other -> other
-- | Type check a program: first run 'preRun' over all definitions, then
-- check them with 'checkDef'.
checkPrg :: Program -> Infer T.Program
checkPrg (Program defs) = preRun defs >> T.Program <$> checkDef defs
where
rhs' = foldr ($) rhs $ zipWith EAbs parms ts_parms
ts_parms = fst $ partitionType (length parms) t
-- | First pass over the definitions, in order: record the declared signature
-- of each binding and check each data declaration.
preRun :: [Def] -> Infer ()
preRun = traverse_ $ \case
    DBind (Bind n t _ _ _) -> insertSig n t
    DData d -> checkData d
-- | Infer type of expression.
infer :: Cxt -> Exp -> Err (T.Exp, Type)
infer cxt = \case
EId x ->
case lookupEnv x cxt of
Nothing ->
case lookupSig x cxt of
Nothing -> throwError ("Unbound variable:" ++ printTree x)
Just t -> pure (T.EId (x, t), t)
Just t -> pure (T.EId (x, t), t)
-- | Second pass over the definitions, in order: bindings are type checked
-- with 'checkBind', data declarations are passed through unchanged.
checkDef :: [Def] -> Infer [T.Def]
checkDef = traverse $ \case
    DBind b -> T.DBind <$> checkBind b
    DData d -> pure (T.DData d)
EInt i -> pure (T.EInt i, T.TInt)
{- | Type check a top-level binding.

The parameters are wrapped around the body as nested lambdas (first parameter
outermost), the type of the result is inferred and unified with the declared
signature. The binding is rejected when the signature with the unifier applied
is no longer 'typeEq' to the signature as written.
-}
checkBind :: Bind -> Infer T.Bind
checkBind (Bind n t _ args e) = do
    (inferred, e') <- inferExp (foldr EAbs e args)
    sub <- unify t inferred
    let refined = apply sub t
    if t `typeEq` refined
        then return $ T.Bind (n, t) e'
        else
            throwError $
                unwords
                    [ "Top level signature"
                    , printTree t
                    , "does not match body with inferred type:"
                    , printTree refined
                    ]
EApp e e1 -> do
(e', t) <- infer cxt e
case t of
TFun t1 t2 -> do
e1' <- check cxt e1 t1
pure (T.EApp t2 e' e1', t2)
_ -> do
throwError ("Not a function: " ++ show e)
EAdd e e1 -> do
e' <- check cxt e T.TInt
e1' <- check cxt e1 T.TInt
pure (T.EAdd T.TInt e' e1', T.TInt)
ESub e e1 -> do
e' <- check cxt e T.TInt
e1' <- check cxt e1 T.TInt
pure (T.ESub T.TInt e' e1', T.TInt)
EAbs x t e -> do
(e', t1) <- infer (insertEnv x t cxt) e
let t_abs = TFun t t1
pure (T.EAbs t_abs (x, t) e', t_abs)
ELet b e -> do
let cxt' = insertBind b cxt
b' <- checkBind cxt' b
(e', t) <- infer cxt' e
pure (T.ELet b' e', t)
EAnn e t -> do
(e', t1) <- infer cxt e
unless (typeEq t t1) $
throwError "Inferred type and type annotation doesn't match"
pure (e', t1)
ECase e cs t -> do
(e',t1) <- infer cxt e
unless (typeEq t t1) $
throwError "Inferred type and type annotation doesn't match"
case traverse (\(CaseMatch c e) -> do
-- //TODO check c as well
e' <- check cxt e t
unless (typeEq t t1) $
throwError "Inferred type and type annotation doesn't match"
pure (t1, T.Case c e')
) cs of
Right cs -> pure (T.ECase t1 e' cs,t1)
Left e -> throwError e
-- | Check infered type matches the supplied type.
check :: Cxt -> Exp -> Type -> Err T.Exp
check cxt exp typ = case exp of
EId x -> do
t <- case lookupEnv x cxt of
Nothing -> maybeToRightM
("Unbound variable:" ++ printTree x)
(lookupSig x cxt)
Just t -> pure t
unless (typeEq t typ) . throwError $ typeErr x typ t
pure $ T.EId (x, t)
EInt i -> do
unless (typeEq typ TInt) $ throwError $ typeErr i TInt typ
pure $ T.EInt i
EApp e e1 -> do
(e', t) <- infer cxt e
case t of
TFun t1 t2 -> do
e1' <- check cxt e1 t1
pure $ T.EApp t2 e' e1'
_ -> throwError ("Not a function 2: " ++ printTree e)
EAdd e e1 -> do
e' <- check cxt e T.TInt
e1' <- check cxt e1 T.TInt
pure $ T.EAdd T.TInt e' e1'
ESub e e1 -> do
e' <- check cxt e T.TInt
e1' <- check cxt e1 T.TInt
pure $ T.ESub T.TInt e' e1'
EAbs x t e -> do
(e', t_e) <- infer (insertEnv x t cxt) e
let t1 = TFun t t_e
unless (typeEq t1 typ) $ throwError "Wrong lamda type!"
pure $ T.EAbs t1 (x, t) e'
ECase e cs t -> do
(e',t1) <- infer cxt e
unless (typeEq t t1) $
throwError "Inferred type and type annotation doesn't match"
case traverse (\(CaseMatch c e) -> do
-- //TODO check c as well
e' <- check cxt e t
unless (typeEq t t1) $
throwError "Inferred type and type annotation doesn't match"
pure (t1, T.Case c e')
) cs of
Right cs -> pure $ T.ECase t1 e' cs
Left e -> throwError e
ELet b e -> do
let cxt' = insertBind b cxt
b' <- checkBind cxt' b
e' <- check cxt' e typ
pure $ T.ELet b' e'
EAnn e t -> do
unless (typeEq t typ) $
throwError "Inferred type and type annotation doesn't match"
check cxt e t
-- | Check if types are equivalent. Doesn't handle coercion or polymorphism.
{- | Check if two types are considered equal
For the purpose of the algorithm two polymorphic types are always considered
equal
-}
typeEq :: Type -> Type -> Bool
typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1
typeEq t t1 = t == t1
typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r'
typeEq (TMono a) (TMono b) = a == b
typeEq (TConstr (Constr name a)) (TConstr (Constr name' b)) =
length a == length b
&& name == name'
&& and (zipWith typeEq a b)
typeEq (TPol _) (TPol _) = True
typeEq _ _ = False
-- | Partion type into types of parameters and return type.
partitionType :: Int -- Number of parameters to apply
-> Type
-> ([Type], Type)
partitionType = go []
where
go acc 0 t = (acc, t)
go acc i t = case t of
TFun t1 t2 -> go (snoc t1 acc) (i - 1) t2
_ -> error "Number of parameters and type doesn't match"
-- | @a `isMoreSpecificOrEq` b@ holds when @b@ is a type variable, when both
-- are arrows or applications of the same constructor (same arity) whose
-- components are related pairwise, or when the two types are equal.
isMoreSpecificOrEq :: Type -> Type -> Bool
isMoreSpecificOrEq specific general = case (specific, general) of
    (_, TPol _) -> True
    (TArr a b, TArr c d) ->
        isMoreSpecificOrEq a c && isMoreSpecificOrEq b d
    (TConstr (Constr n1 ts1), TConstr (Constr n2 ts2)) ->
        and
            [ n1 == n2
            , length ts1 == length ts2
            , and (zipWith isMoreSpecificOrEq ts1 ts2)
            ]
    _ -> specific == general
insertBind :: Bind -> Cxt -> Cxt
insertBind (Bind n t _ _ _) = insertEnv n t
-- | Is the type a bare type variable?
isPoly :: Type -> Bool
isPoly = \case
    TPol _ -> True
    _ -> False
lookupEnv :: Ident -> Cxt -> Maybe Type
lookupEnv x = Map.lookup x . env
-- | Run 'algoW' on an expression, apply the resulting substitution to the
-- inferred type and store that type on the root of the typed expression.
inferExp :: Exp -> Infer (Type, T.Exp)
inferExp expr = finish <$> algoW expr
  where
    finish (sub, ty, typed) =
        let ty' = apply sub ty
         in (ty', replace ty' typed)
insertEnv :: Ident -> Type -> Cxt -> Cxt
insertEnv x t cxt = cxt { env = Map.insert x t cxt.env }
-- | Overwrite the type stored on the root node of a typed expression.
-- Sub-expressions are left untouched.
replace :: Type -> T.Exp -> T.Exp
replace t (T.ELit _ lit) = T.ELit t lit
replace t (T.EId (n, _)) = T.EId (n, t)
replace t (T.EAbs _ name e) = T.EAbs t name e
replace t (T.EApp _ e1 e2) = T.EApp t e1 e2
replace t (T.EAdd _ e1 e2) = T.EAdd t e1 e2
replace t (T.ESub _ e1 e2) = T.ESub t e1 e2
-- NOTE(review): a let has no type slot of its own, so the type lands on the
-- bound name — confirm this is the intended meaning of that annotation.
replace t (T.ELet (T.Bind (n, _) e1) e2) = T.ELet (T.Bind (n, t) e1) e2
replace t (T.ECase _ expr injs) = T.ECase t expr injs
lookupSig :: Ident -> Cxt -> Maybe Type
lookupSig x = Map.lookup x . sig
algoW :: Exp -> Infer (Subst, Type, T.Exp)
algoW = \case
-- \| TODO: More testing need to be done. Unsure of the correctness of this
EAnn e t -> do
(s1, t', e') <- algoW e
unless
(t `isMoreSpecificOrEq` t')
( throwError $
unwords
[ "Annotated type:"
, printTree t
, "does not match inferred type:"
, printTree t'
]
)
applySt s1 $ do
s2 <- unify t t'
return (s2 `compose` s1, t, e')
typeErr :: Print a => a -> Type -> Type -> String
typeErr p expected actual = render $ concatD
[ doc $ showString "Wrong type:", prt 0 p , doc $ showString "\n"
, doc $ showString "Expected:" , prt 0 expected, doc $ showString "\n"
, doc $ showString "Actual: " , prt 0 actual
]
-- \| ------------------
-- \| Γ ⊢ i : Int, ∅
ELit (LInt n) ->
return (nullSubst, TMono "Int", T.ELit (TMono "Int") (LInt n))
ELit a -> error $ "NOT IMPLEMENTED YET: ELit " ++ show a
-- \| x : σ ∈ Γ τ = inst(σ)
-- \| ----------------------
-- \| Γ ⊢ x : τ, ∅
EId i -> do
var <- asks vars
case M.lookup i var of
Just t -> inst t >>= \x -> return (nullSubst, x, T.EId (i, x))
Nothing -> do
sig <- gets sigs
case M.lookup i sig of
Just t -> return (nullSubst, t, T.EId (i, t))
Nothing -> do
constr <- gets constructors
case M.lookup i constr of
Just t -> return (nullSubst, t, T.EId (i, t))
Nothing ->
throwError $
"Unbound variable: " ++ show i
-- \| τ = newvar Γ, x : τ ⊢ e : τ', S
-- \| ---------------------------------
-- \| Γ ⊢ w λx. e : Sτ → τ', S
EAbs name e -> do
fr <- fresh
withBinding name (Forall [] fr) $ do
(s1, t', e') <- algoW e
let varType = apply s1 fr
let newArr = TArr varType t'
return (s1, newArr, T.EAbs newArr (name, varType) e')
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
-- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
-- \| ------------------------------------------
-- \| Γ ⊢ e₀ + e₁ : Int, S₃S₂S₁S₀
-- This might be wrong
EAdd e0 e1 -> do
(s1, t0, e0') <- algoW e0
applySt s1 $ do
(s2, t1, e1') <- algoW e1
-- applySt s2 $ do
s3 <- unify (apply s2 t0) (TMono "Int")
s4 <- unify (apply s3 t1) (TMono "Int")
return
( s4 `compose` s3 `compose` s2 `compose` s1
, TMono "Int"
, T.EAdd (TMono "Int") e0' e1'
)
ESub e0 e1 -> do
(s1, t0, e0') <- algoW e0
applySt s1 $ do
(s2, t1, e1') <- algoW e1
-- applySt s2 $ do
s3 <- unify (apply s2 t0) (TMono "Int")
s4 <- unify (apply s3 t1) (TMono "Int")
return
( s4 `compose` s3 `compose` s2 `compose` s1
, TMono "Int"
, T.ESub (TMono "Int") e0' e1'
)
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1
-- \| τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ')
-- \| --------------------------------------
-- \| Γ ⊢ e₀ e₁ : S₂τ', S₂S₁S₀
EApp e0 e1 -> do
fr <- fresh
(s0, t0, e0') <- algoW e0
applySt s0 $ do
(s1, t1, e1') <- algoW e1
-- applySt s1 $ do
s2 <- unify (apply s1 t0) (TArr t1 fr)
let t = apply s2 fr
return (s2 `compose` s1 `compose` s0, t, T.EApp t e0' e1')
-- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁
-- \| ----------------------------------------------
-- \| Γ ⊢ let x = e₀ in e₁ : τ', S₁S₀
-- The bar over S₀ and Γ means "generalize"
ELet name e0 e1 -> do
(s1, t1, e0') <- algoW e0
env <- asks vars
let t' = generalize (apply s1 env) t1
withBinding name t' $ do
(s2, t2, e1') <- algoW e1
return (s2 `compose` s1, t2, T.ELet (T.Bind (name, t2) e0') e1')
ECase caseExpr injs -> do
(_, t0, e0') <- algoW caseExpr
(injs', ts) <- mapAndUnzipM (checkInj t0) injs
case ts of
[] -> throwError "Case expression missing any matches"
ts -> do
unified <- zipWithM unify ts (tail ts)
let unified' = foldl' compose mempty unified
let typ = apply unified' (head ts)
return (unified', typ, T.ECase typ e0' injs')
{- | Unify two types, producing a substitution under which they are equal.

Arrow types and constructor applications are unified component by component,
left to right. The substitution found so far is applied to the remaining
components before they are unified, and the results are composed so that the
earlier substitution is applied first. Fails with 'throwError' when the types
cannot be unified.
-}
unify :: Type -> Type -> Infer Subst
unify t0 t1 = case (t0, t1) of
    (TArr a b, TArr c d) -> do
        s1 <- unify a c
        s2 <- unify (apply s1 b) (apply s1 d)
        -- s2 was computed under s1, so s1 has to be applied first:
        -- @compose m1 m2@ applies m2 and then m1.
        return $ s2 `compose` s1
    (TPol a, b) -> occurs a b
    (a, TPol b) -> occurs b a
    (TMono a, TMono b) ->
        if a == b then return M.empty else throwError "Types do not unify"
    (TConstr (Constr name t), TConstr (Constr name' t'))
        | name == name' && length t == length t' ->
            foldM unifyArg nullSubst (zip t t')
        | otherwise ->
            throwError $
                unwords
                    [ "Type constructor:"
                    , printTree name
                    , "(" ++ printTree t ++ ")"
                    , "does not match with:"
                    , printTree name'
                    , "(" ++ printTree t' ++ ")"
                    ]
    (a, b) ->
        throwError . unwords $
            [ "Type:"
            , printTree a
            , "can't be unified with:"
            , printTree b
            ]
  where
    -- Unify one pair of constructor arguments under the substitution
    -- accumulated from the previous pairs, and extend that substitution.
    unifyArg s (x, y) = do
        s' <- unify (apply s x) (apply s y)
        return $ s' `compose` s
{- | Bind a type variable to a type, after the occurs check.

A constraint such as { a = a -> b } is unsolvable because no substitution
makes the two sides equal, so binding a variable to a type that contains it is
an error. When the type is itself a type variable, the empty substitution is
returned and no binding is recorded.
-}
occurs :: Ident -> Type -> Infer Subst
occurs _ (TPol _) = return nullSubst
occurs i t
    | i `S.member` free t =
        throwError $
            unwords
                [ "Occurs check failed, can't unify"
                , printTree (TPol i)
                , "with"
                , printTree t
                ]
    | otherwise = return (M.singleton i t)
-- | Quantify a type over those of its free variables that are not free in
-- the given environment.
generalize :: Map Ident Poly -> Type -> Poly
generalize env t = Forall quantified t
  where
    quantified = S.toList (S.difference (free t) (free env))
{- | Instantiate a type scheme: every quantified variable is replaced by a
fresh type variable.
-}
inst :: Poly -> Infer Type
inst (Forall bound body) = do
    freshVars <- traverse (const fresh) bound
    pure (apply (M.fromList (zip bound freshVars)) body)
-- | Compose two substitutions: @m1@ is applied to every type in @m2@, and the
-- result is unioned with @m1@, preferring the entries that came from @m2@.
compose :: Subst -> Subst -> Subst
compose m1 m2 = M.union updated m1
  where
    updated = M.map (apply m1) m2
-- | Things that contain type variables: their free variables can be
-- collected and a substitution can be applied to them.
class FreeVars t where
    -- | The set of free type variables of @t@.
    free :: t -> Set Ident
    -- | Apply a substitution to @t@.
    apply :: Subst -> t -> t
-- | Free variables and substitution on plain types.
instance FreeVars Type where
    free :: Type -> Set Ident
    free = \case
        TPol a -> S.singleton a
        TMono _ -> mempty
        TArr a b -> free a <> free b
        -- Union of the free variables of all constructor arguments.
        TConstr (Constr _ args) -> foldMap free args

    apply :: Subst -> Type -> Type
    apply sub = go
      where
        go (TMono a) = TMono a
        -- A variable without an entry in the substitution stays as it is.
        go (TPol a) = case M.lookup a sub of
            Just replacement -> replacement
            Nothing -> TPol a
        go (TArr a b) = TArr (go a) (go b)
        go (TConstr (Constr name args)) = TConstr (Constr name (map go args))
-- | Free variables and substitution on type schemes; the quantified
-- variables are neither free nor substituted.
instance FreeVars Poly where
    free :: Poly -> Set Ident
    free (Forall bound t) = S.difference (free t) (S.fromList bound)

    apply :: Subst -> Poly -> Poly
    apply s (Forall bound t) = Forall bound (apply unbound t)
      where
        -- The substitution with all quantified variables removed.
        unbound = foldr M.delete s bound
-- | Free variables and substitution on a variable environment, element-wise.
instance FreeVars (Map Ident Poly) where
    free :: Map Ident Poly -> Set Ident
    free = foldMap free

    apply :: Subst -> Map Ident Poly -> Map Ident Poly
    apply s env = fmap (apply s) env
-- | Run an action with the substitution applied to the variable environment.
applySt :: Subst -> Infer a -> Infer a
applySt s = local substituteVars
  where
    substituteVars ctx = ctx{vars = apply s (vars ctx)}
-- | The empty substitution.
nullSubst :: Subst
nullSubst = mempty
-- | Produce a fresh type variable named after the current counter value and
-- increment the counter.
fresh :: Infer Type
fresh = do
    current <- gets count
    modify $ \env -> env{count = current + 1}
    pure (TPol (Ident (show current)))
-- | Run an action with one extra variable binding in scope.
withBinding :: (Monad m, MonadReader Ctx m) => Ident -> Poly -> m a -> m a
withBinding i p action = local bind action
  where
    bind ctx = ctx{vars = M.insert i p (vars ctx)}
-- | Record the signature of a top-level function in the state.
insertSig :: Ident -> Type -> Infer ()
insertSig i t = modify addSig
  where
    addSig env = env{sigs = M.insert i t (sigs env)}
-- | Record a data constructor together with its type in the state.
insertConstr :: Ident -> Type -> Infer ()
insertConstr i t = modify addConstr
  where
    addConstr env = env{constructors = M.insert i t (constructors env)}
-------- PATTERN MATCHING ---------
-- "case expr of", the type of 'expr' is caseType
-- | Check one case branch against the type of the scrutinee. The variables
-- bound by the pattern are in scope while the branch expression is inferred;
-- the result is the typed branch and the type of its expression.
checkInj :: Type -> Inj -> Infer (T.Inj, Type)
checkInj caseType (Inj it expr) = do
    (args, t') <- initType caseType it
    (_, t, e') <- local (extendWith args) (algoW expr)
    return (T.Inj (it, t') e', t)
  where
    -- Pattern variables take precedence over existing bindings.
    extendWith bound ctx = ctx{vars = bound `M.union` vars ctx}
{- | Check a case pattern against the expected type of the scrutinee.

On success returns the variables bound by the pattern together with the
expected type. A literal must have exactly the expected type; a constructor
pattern must name a known constructor, bind all of its arguments, and have a
return type accepted by 'isMoreSpecificOrEq'; a catch-all always succeeds.
-}
initType :: Type -> Init -> Infer (Map Ident Poly, Type)
initType expected pat = case pat of
    InitCatch -> return (mempty, expected)
    InitLit lit
        | expected == litType lit -> return (mempty, expected)
        | otherwise -> throwError (mismatch (litType lit))
    InitConstr c args -> do
        known <- gets constructors
        case M.lookup c known of
            Nothing ->
                throwError $
                    unwords
                        [ "Constructor:"
                        , printTree c
                        , "does not exist"
                        ]
            Just t
                -- The arity check comes first, as in the tuple match it
                -- replaces.
                | length (init flat) /= length args ->
                    throwError $
                        "Can't partially match on the constructor: "
                            ++ printTree c
                | not (returnType `isMoreSpecificOrEq` expected) ->
                    throwError (mismatch returnType)
                | otherwise ->
                    return
                        ( M.fromList $ zip args (map (Forall []) flat)
                        , expected
                        )
              where
                flat = flattenType t
                returnType = last flat
  where
    mismatch actual =
        unwords
            [ "Inferred type"
            , printTree actual
            , "does not match expected type:"
            , printTree expected
            ]
-- | Split a curried function type into its argument types followed by its
-- result type: @a -> b -> c@ becomes @[a, b, c]@.
--
-- Only the spine of arrows is split. An argument that is itself a function
-- type is kept whole, so the list length always equals the arity plus one.
flattenType :: Type -> [Type]
flattenType (TArr a b) = a : flattenType b
flattenType a = [a]
-- | The type of a literal.
litType :: Literal -> Type
litType = \case
    LInt _ -> TMono "Int"

View file

@ -1,139 +1,184 @@
{-# LANGUAGE LambdaCase #-}
module TypeChecker.TypeCheckerIr
( module Grammar.Abs
, module TypeChecker.TypeCheckerIr
) where
module TypeChecker.TypeCheckerIr where
import Grammar.Abs (Ident (..), Type (..))
import qualified Grammar.Abs as GA
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Functor.Identity (Identity)
import Data.Map (Map)
import Grammar.Abs (Data (..), Ident (..), Init (..),
Literal (..), Type (..))
import Grammar.Print
import Prelude
import qualified Prelude as C (Eq, Ord, Read, Show)
import qualified Prelude as C (Eq, Ord, Read, Show)
newtype Program = Program [Bind]
deriving (C.Eq, C.Ord, C.Show, C.Read)
-- | A type scheme: a type together with the list of type variables it is
-- quantified over.
data Poly = Forall [Ident] Type
    deriving (Show)
-- | The reader context of the type checker: the type scheme of every
-- variable currently in scope.
newtype Ctx = Ctx {vars :: Map Ident Poly}
-- | The mutable state of the type checker.
data Env = Env
    { count :: Int -- ^ Counter for type-checker state (presumably the fresh-variable supply; confirm against TypeChecker)
    , sigs :: Map Ident Type -- ^ Types keyed by identifier (presumably top-level signatures; confirm)
    , constructors :: Map Ident Type -- ^ Types keyed by identifier (presumably data constructors; confirm)
    }
-- | Type errors are reported as plain messages.
type Error = String
-- | A substitution from type variables to types.
type Subst = Map Ident Type
-- | The type checker monad: state ('Env') over a reader ('Ctx') over errors.
type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))
newtype Program = Program [Def]
deriving (C.Eq, C.Ord, C.Show, C.Read)
data Exp
= EId Id
| EInt Integer
= EId Id
| ELit Type Literal
| ELet Bind Exp
| EApp Type Exp Exp
| EAdd Type Exp Exp
| ESub Type Exp Exp
| EAbs Type Id Exp
| ECase Type Exp [(Type, Case)]
deriving (C.Eq, C.Ord, C.Show, C.Read)
| EAbs Type Id Exp
| ECase Type Exp [Inj]
deriving (C.Eq, C.Ord, C.Read, C.Show)
data Case = Case GA.Case Exp
deriving (C.Eq, C.Ord, C.Show, C.Read)
-- | A typed case branch: the pattern paired with a type, and the expression
-- of the branch.
data Inj = Inj (Init, Type) Exp
    deriving (C.Eq, C.Ord, C.Read, C.Show)
-- | A top-level definition: a typed binding or a data declaration.
data Def = DBind Bind | DData Data
    deriving (C.Eq, C.Ord, C.Read, C.Show)
type Id = (Ident, Type)
data Bind = Bind Id [Id] Exp | DataStructure Ident [(Ident, [Type])]
data Bind = Bind Id Exp
deriving (C.Eq, C.Ord, C.Show, C.Read)
-- | Definitions are printed one after another, each followed by a newline.
instance Print [Def] where
    prt _ = \case
        [] -> concatD []
        d : ds -> concatD [prt 0 d, doc (showString "\n"), prt 0 ds]
-- | A definition is printed as the binding or data declaration it wraps.
instance Print Def where
    prt i = \case
        DBind bind -> prt i bind
        DData d -> prt i d
instance Print Program where
prt i (Program sc) = prPrec i 0 $ prt 0 sc
instance Print Bind where
prt i (Bind name@(n, _) parms rhs) = prPrec i 0 $ concatD
[ prtId 0 name
, doc $ showString ";"
, prt 0 n
, prtIdPs 0 parms
, doc $ showString "="
, prt 0 rhs
]
prt i (DataStructure (Ident n) xs) = prPrec i 0 $ concatD
[ prt 0 n
, doc $ showString "{"
, doc . showString . show $ xs
, doc $ showString "}"
]
-- Print a binding as its signature line followed by its definition line:
--
-- > name : type
-- > name = rhs
--
-- An 'Id' is @(Ident, Type)@, so the name is the first component.
prt i (Bind (name, t) rhs) =
    prPrec i 0 $
        concatD
            [ prt 0 name
            , doc $ showString ":"
            , prt 0 t
            , doc $ showString "\n"
            , prt 0 name
            , doc $ showString "="
            , prt 0 rhs
            ]
instance Print [Bind] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), doc (showString "\n"), prt 0 xs]
prtIdPs :: Int -> [Id] -> Doc
prtIdPs i = prPrec i 0 . concatD . map (prtIdP i)
prtId :: Int -> Id -> Doc
prtId i (name, t) = prPrec i 0 $ concatD
[ prt 0 name
, doc $ showString ":"
, prt 0 t
]
prtId i (name, t) =
prPrec i 0 $
concatD
[ prt 0 name
, doc $ showString ":"
, prt 0 t
]
prtIdP :: Int -> Id -> Doc
prtIdP i (name, t) = prPrec i 0 $ concatD
[ doc $ showString "("
, prt 0 name
, doc $ showString ":"
, prt 0 t
, doc $ showString ")"
]
prtIdP i (name, t) =
prPrec i 0 $
concatD
[ doc $ showString "("
, prt 0 name
, doc $ showString ":"
, prt 0 t
, doc $ showString ")"
]
instance Print Exp where
prt i = \case
EId n -> prPrec i 3 $ concatD [prtIdP 0 n]
EInt i1 -> prPrec i 3 $ concatD [prt 0 i1]
ELet bs e -> prPrec i 3 $ concatD
[ doc $ showString "let"
, prt 0 bs
, doc $ showString "in"
, prt 0 e
]
EApp t e1 e2 -> prPrec i 2 $ concatD
[ doc $ showString "@"
, prt 0 t
, prt 2 e1
, prt 3 e2
]
EAdd t e1 e2 -> prPrec i 1 $ concatD
[ doc $ showString "@"
, prt 0 t
, prt 1 e1
, doc $ showString "+"
, prt 2 e2
]
ESub t e1 e2 -> prPrec i 1 $ concatD
[ doc $ showString "@"
, prt 0 t
, prt 1 e1
, doc $ showString "-"
, prt 2 e2
]
EAbs t n e -> prPrec i 0 $ concatD
[ doc $ showString "@"
, prt 0 t
, doc $ showString "\\"
, prtIdP 0 n
, doc $ showString "."
, prt 0 e
]
ECase t e cs -> prPrec i 0 $ concatD
[ doc $ showString "@"
, prt 0 t
, doc $ showString "("
, prt 0 e
, doc $ showString ")"
, prPrec i 0 $ concatD . printCases $ cs
]
prt i = \case
EId n -> prPrec i 3 $ concatD [prtId 0 n, doc $ showString "\n"]
ELit _ (LInt i1) -> prPrec i 3 $ concatD [prt 0 i1, doc $ showString "\n"]
ELet bs e ->
prPrec i 3 $
concatD
[ doc $ showString "let"
, prt 0 bs
, doc $ showString "in"
, prt 0 e
, doc $ showString "\n"
]
EApp _ e1 e2 ->
prPrec i 2 $
concatD
[ prt 2 e1
, prt 3 e2
]
EAdd t e1 e2 ->
prPrec i 1 $
concatD
[ doc $ showString "@"
, prt 0 t
, prt 1 e1
, doc $ showString "+"
, prt 2 e2
, doc $ showString "\n"
]
ESub t e1 e2 ->
prPrec i 1 $
concatD
[ doc $ showString "@"
, prt 0 t
, prt 1 e1
, doc $ showString "-"
, prt 2 e2
, doc $ showString "\n"
]
EAbs t n e ->
prPrec i 0 $
concatD
[ doc $ showString "@"
, prt 0 t
, doc $ showString "\\"
, prtId 0 n
, doc $ showString "."
, prt 0 e
, doc $ showString "\n"
]
ECase t exp injs ->
prPrec
i
0
( concatD
[ doc (showString "case")
, prt 0 exp
, doc (showString "of")
, doc (showString "{")
, prt 0 injs
, doc (showString "}")
, doc (showString ":")
, prt 0 t
, doc $ showString "\n"
]
)
where
printCases :: [(Type, Case)] -> [Doc]
printCases [] = []
printCases ((t, Case c e):xs) = concatD
[ doc $ showString "@"
, prt 0 t
, doc $ showString "("
, doc . showString . show $ c
, doc $ showString ")"
, doc $ showString "=>"
, prt 0 e
, doc $ showString "\n"
] : printCases xs
-- | A case branch is printed as @pattern : type => expression@.
instance Print Inj where
    prt i (Inj (pat, t) body) =
        prPrec i 0 $
            concatD
                [ prt 0 pat
                , doc (showString ":")
                , prt 0 t
                , doc (showString "=>")
                , prt 0 body
                ]
-- | Case branches are printed separated by @;@, with no trailing separator.
instance Print [Inj] where
    prt _ = \case
        [] -> concatD []
        [x] -> concatD [prt 0 x]
        x : xs -> concatD [prt 0 x, doc (showString ";"), prt 0 xs]