Started importing Sebastian's new typechecker.
This commit is contained in:
parent
d5dd7896d8
commit
350cd3b0e9
9 changed files with 1611 additions and 1346 deletions
68
Grammar.cf
68
Grammar.cf
|
|
@ -1,33 +1,51 @@
|
|||
Program. Program ::= [Bind];
|
||||
|
||||
EId. Exp3 ::= Ident;
|
||||
EInt. Exp3 ::= Integer;
|
||||
EAnn. Exp3 ::= "(" Exp ":" Type ")";
|
||||
ELet. Exp3 ::= "let" Bind "in" Exp;
|
||||
EApp. Exp2 ::= Exp2 Exp3;
|
||||
EAdd. Exp1 ::= Exp1 "+" Exp2;
|
||||
ESub. Exp1 ::= Exp1 "-" Exp2;
|
||||
EAbs. Exp ::= "\\" Ident ":" Type "." Exp;
|
||||
ECase. Exp ::= "case" Exp "of" "{" [CaseMatch] "}" ":" Type;
|
||||
CaseMatch. CaseMatch ::= Case "=>" Exp ;
|
||||
separator CaseMatch ",";
|
||||
Program. Program ::= [Def] ;
|
||||
|
||||
|
||||
CInt. Case ::= Integer ;
|
||||
CatchAll. Case ::= "_" ;
|
||||
DBind. Def ::= Bind ;
|
||||
DData. Def ::= Data ;
|
||||
separator Def ";" ;
|
||||
|
||||
Bind. Bind ::= Ident ":" Type ";"
|
||||
Ident [Ident] "=" Exp;
|
||||
Ident [Ident] "=" Exp ;
|
||||
|
||||
separator Bind ";";
|
||||
separator Ident "";
|
||||
Data. Data ::= "data" Constr "where" "{" [Constructor] "}" ;
|
||||
|
||||
coercions Exp 3;
|
||||
Constructor. Constructor ::= Ident ":" Type ;
|
||||
separator nonempty Constructor "" ;
|
||||
|
||||
TInt. Type1 ::= "Int" ;
|
||||
TPol. Type1 ::= Ident ;
|
||||
TFun. Type ::= Type1 "->" Type ;
|
||||
coercions Type 1 ;
|
||||
TMono. Type1 ::= "_" Ident ;
|
||||
TPol. Type1 ::= "'" Ident ;
|
||||
TConstr. Type1 ::= Constr ;
|
||||
TArr. Type ::= Type1 "->" Type ;
|
||||
|
||||
comment "--";
|
||||
comment "{-" "-}";
|
||||
Constr. Constr ::= Ident "(" [Type] ")" ;
|
||||
|
||||
-- TODO: Move literal to its own thing since it's reused in Init as well.
|
||||
EAnn. Exp5 ::= "(" Exp ":" Type ")" ;
|
||||
EId. Exp4 ::= Ident ;
|
||||
ELit. Exp4 ::= Literal ;
|
||||
EApp. Exp3 ::= Exp3 Exp4 ;
|
||||
EAdd. Exp1 ::= Exp1 "+" Exp2 ;
|
||||
ESub. Exp1 ::= Exp1 "-" Exp2 ;
|
||||
ELet. Exp ::= "let" Ident "=" Exp "in" Exp ;
|
||||
EAbs. Exp ::= "\\" Ident "." Exp ;
|
||||
ECase. Exp ::= "case" Exp "of" "{" [Inj] "}";
|
||||
|
||||
LInt. Literal ::= Integer ;
|
||||
|
||||
Inj. Inj ::= Init "=>" Exp ;
|
||||
separator nonempty Inj ";" ;
|
||||
|
||||
InitLit. Init ::= Literal ;
|
||||
InitConstr. Init ::= Ident [Ident] ;
|
||||
InitCatch. Init ::= "_" ;
|
||||
|
||||
separator Type " " ;
|
||||
coercions Type 2 ;
|
||||
|
||||
separator Ident " ";
|
||||
|
||||
coercions Exp 5 ;
|
||||
|
||||
comment "--" ;
|
||||
comment "{-" "-}" ;
|
||||
|
|
|
|||
|
|
@ -1,87 +1,26 @@
|
|||
|
||||
-- tripplemagic : Int -> Int -> Int -> Int;
|
||||
-- tripplemagic x y z = ((\x:Int. x+x) x) + y + z;
|
||||
-- main : Int;
|
||||
-- main = tripplemagic ((\x:Int. x+x+3) ((\x:Int. x) 2)) 5 3
|
||||
-- answer: 22
|
||||
|
||||
-- apply : (Int -> Int) -> Int -> Int;
|
||||
-- apply f x = f x;
|
||||
-- main : Int;
|
||||
-- main = apply (\x : Int . x + 5) 5
|
||||
-- answer: 10
|
||||
|
||||
-- apply : (Int -> Int -> Int) -> Int -> Int -> Int;
|
||||
-- apply f x y = f x y;
|
||||
-- krimp: Int -> Int -> Int;
|
||||
-- krimp x y = x + y;
|
||||
-- main : Int;
|
||||
-- main = apply (krimp) 2 3;
|
||||
-- answer: 5
|
||||
|
||||
-- fibbonaci : Int -> Int;
|
||||
-- fibbonaci x = case x of {
|
||||
-- 0 => 0,
|
||||
-- 1 => 1,
|
||||
-- -- abusing overflows to represent negatives like a boss
|
||||
-- _ => (fibbonaci (x - 2))
|
||||
-- + (fibbonaci (x - 1))
|
||||
-- } : Int;
|
||||
-- main : Int;
|
||||
-- main = fibbonaci 10;
|
||||
-- answer: 55
|
||||
|
||||
-- succ : Int -> Int;
|
||||
-- succ x = x - 1;
|
||||
--
|
||||
-- isZero : Int -> Int;
|
||||
-- isZero x = case x of {
|
||||
-- 0 => 1,
|
||||
-- _ => 0
|
||||
-- } : Int;
|
||||
--
|
||||
-- minimization : (Int -> Int) -> Int -> Int;
|
||||
-- minimization p x = case p x of {
|
||||
-- 1 => 0,
|
||||
-- _ => minimization p (succ x)
|
||||
-- } : Int;
|
||||
--
|
||||
-- main : Int;
|
||||
-- main = minimization isZero 10;
|
||||
-- answer: 0
|
||||
|
||||
posMul : Int -> Int -> Int;
|
||||
posMul : _Int -> _Int -> _Int;
|
||||
posMul a b = case b of {
|
||||
0 => 0,
|
||||
0 => 0;
|
||||
_ => a + posMul a (b - 1)
|
||||
} : Int;
|
||||
};
|
||||
|
||||
facc : Int -> Int;
|
||||
facc : _Int -> _Int;
|
||||
facc a = case a of {
|
||||
1 => 1,
|
||||
1 => 1;
|
||||
_ => posMul a (facc (a - 1))
|
||||
} : Int;
|
||||
-- main : Int;
|
||||
-- main = facc 5
|
||||
-- answer: 120
|
||||
};
|
||||
|
||||
-- pow : Int -> Int -> Int;
|
||||
-- pow a b = case b of {
|
||||
-- 0 => 1,
|
||||
-- _ => posMul a (pow a (b-1))
|
||||
-- } : Int;
|
||||
|
||||
minimization : (Int -> Int) -> Int -> Int;
|
||||
minimization : (_Int -> _Int) -> _Int -> _Int;
|
||||
minimization p x = case p x of {
|
||||
1 => x,
|
||||
1 => x;
|
||||
_ => minimization p (x + 1)
|
||||
} : Int;
|
||||
};
|
||||
|
||||
checkFac : Int -> Int;
|
||||
checkFac : _Int -> _Int;
|
||||
checkFac x = case facc x of {
|
||||
0 => 1,
|
||||
0 => 1;
|
||||
_ => 0
|
||||
} : Int;
|
||||
};
|
||||
|
||||
main : Int;
|
||||
main : _Int;
|
||||
main = minimization checkFac 1
|
||||
|
|
@ -1,441 +1,443 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Codegen.Codegen (generateCode) where
|
||||
|
||||
import Auxiliary (snoc)
|
||||
import Codegen.LlvmIr (CallingConvention (..),
|
||||
LLVMComp (..), LLVMIr (..),
|
||||
LLVMType (..), LLVMValue (..),
|
||||
Visibility (..), llvmIrToString)
|
||||
import Control.Monad.State (StateT, execStateT, foldM_, gets,
|
||||
modify)
|
||||
import qualified Data.Bifunctor as BI
|
||||
import Data.List.Extra (trim)
|
||||
import Data.Map (Map)
|
||||
import qualified Data.Map as Map
|
||||
import Data.Tuple.Extra (dupe, first, second)
|
||||
import qualified Grammar.Abs as GA
|
||||
import Grammar.ErrM (Err)
|
||||
import System.Process.Extra (readCreateProcess, shell)
|
||||
import TypeChecker.TypeCheckerIr (Bind (..), Case (..), Exp (..), Id,
|
||||
Ident (..), Program (..), Type (..))
|
||||
-- | The record used as the code generator state
|
||||
data CodeGenerator = CodeGenerator
|
||||
{ instructions :: [LLVMIr]
|
||||
, functions :: Map Id FunctionInfo
|
||||
, constructors :: Map Id ConstructorInfo
|
||||
, variableCount :: Integer
|
||||
, labelCount :: Integer
|
||||
}
|
||||
|
||||
-- | A state type synonym
|
||||
type CompilerState a = StateT CodeGenerator Err a
|
||||
|
||||
data FunctionInfo = FunctionInfo
|
||||
{ numArgs :: Int
|
||||
, arguments :: [Id]
|
||||
}
|
||||
data ConstructorInfo = ConstructorInfo
|
||||
{ numArgsCI :: Int
|
||||
, argumentsCI :: [Id]
|
||||
, numCI :: Integer
|
||||
}
|
||||
|
||||
|
||||
-- | Adds a instruction to the CodeGenerator state
|
||||
emit :: LLVMIr -> CompilerState ()
|
||||
emit l = modify $ \t -> t { instructions = Auxiliary.snoc l $ instructions t }
|
||||
|
||||
-- | Increases the variable counter in the CodeGenerator state
|
||||
increaseVarCount :: CompilerState ()
|
||||
increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 }
|
||||
|
||||
-- | Returns the variable count from the CodeGenerator state
|
||||
getVarCount :: CompilerState Integer
|
||||
getVarCount = gets variableCount
|
||||
|
||||
-- | Increases the variable count and returns it from the CodeGenerator state
|
||||
getNewVar :: CompilerState Integer
|
||||
getNewVar = increaseVarCount >> getVarCount
|
||||
|
||||
-- | Increses the label count and returns a label from the CodeGenerator state
|
||||
getNewLabel :: CompilerState Integer
|
||||
getNewLabel = do
|
||||
modify (\t -> t{labelCount = labelCount t + 1})
|
||||
gets labelCount
|
||||
|
||||
-- | Produces a map of functions infos from a list of binds,
|
||||
-- which contains useful data for code generation.
|
||||
getFunctions :: [Bind] -> Map Id FunctionInfo
|
||||
getFunctions bs = Map.fromList $ go bs
|
||||
where
|
||||
go [] = []
|
||||
go (Bind id args _ : xs) =
|
||||
(id, FunctionInfo { numArgs=length args, arguments=args })
|
||||
: go xs
|
||||
go (DataStructure n cons : xs) = do
|
||||
map (\(id, xs) -> ((id, TPol n), FunctionInfo {
|
||||
numArgs=length xs, arguments=createArgs xs
|
||||
})) cons
|
||||
<> go xs
|
||||
|
||||
createArgs :: [Type] -> [Id]
|
||||
createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(Ident ("arg_" <> show l) , t)],l+1)) ([], 0) xs
|
||||
|
||||
-- | Produces a map of functions infos from a list of binds,
|
||||
-- which contains useful data for code generation.
|
||||
getConstructors :: [Bind] -> Map Id ConstructorInfo
|
||||
getConstructors bs = Map.fromList $ go bs
|
||||
where
|
||||
go [] = []
|
||||
go (DataStructure (Ident n) cons : xs) = do
|
||||
fst (foldl (\(acc,i) (Ident id, xs) -> (((Ident (n <> "_" <> id), TPol (Ident n)), ConstructorInfo {
|
||||
numArgsCI=length xs,
|
||||
argumentsCI=createArgs xs,
|
||||
numCI=i
|
||||
}) : acc, i+1)) ([],0) cons)
|
||||
<> go xs
|
||||
go (_: xs) = go xs
|
||||
|
||||
initCodeGenerator :: [Bind] -> CodeGenerator
|
||||
initCodeGenerator scs = CodeGenerator { instructions = defaultStart
|
||||
, functions = getFunctions scs
|
||||
, constructors = getConstructors scs
|
||||
, variableCount = 0
|
||||
, labelCount = 0
|
||||
}
|
||||
|
||||
run :: Err String -> IO ()
|
||||
run s = do
|
||||
let s' = case s of
|
||||
Right s -> s
|
||||
Left _ -> error "yo"
|
||||
writeFile "output/llvm.ll" s'
|
||||
putStrLn . trim =<< readCreateProcess (shell "lli") s'
|
||||
|
||||
test :: Integer -> Program
|
||||
test v = Program [
|
||||
DataStructure (Ident "Craig") [
|
||||
(Ident "Bob", [TInt])--,
|
||||
--(Ident "Alice", [TInt, TInt])
|
||||
],
|
||||
Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] (EId ("x",TInt)),
|
||||
Bind (Ident "main", TInt) [] (
|
||||
EApp (TPol "Craig") (EId (Ident "Craig_Bob", TPol "Craig")) (EInt v) -- (EInt 92)
|
||||
)
|
||||
]
|
||||
|
||||
{- | Compiles an AST and produces a LLVM Ir string.
|
||||
An easy way to actually "compile" this output is to
|
||||
Simply pipe it to LLI
|
||||
-}
|
||||
generateCode :: Program -> Err String
|
||||
generateCode (Program scs) = do
|
||||
let codegen = initCodeGenerator scs
|
||||
llvmIrToString . instructions <$> execStateT (compileScs scs) codegen
|
||||
|
||||
compileScs :: [Bind] -> CompilerState ()
|
||||
compileScs [] = do
|
||||
-- as a last step create all the constructors
|
||||
c <- gets (Map.toList . constructors)
|
||||
mapM_ (\((id, t), ci) -> do
|
||||
let t' = type2LlvmType t
|
||||
let x = BI.second type2LlvmType <$> argumentsCI ci
|
||||
emit $ Define FastCC t' id x
|
||||
top <- Ident . show <$> getNewVar
|
||||
ptr <- Ident . show <$> getNewVar
|
||||
-- allocated the primary type
|
||||
emit $ SetVariable top (Alloca t')
|
||||
|
||||
-- set the first byte to the index of the constructor
|
||||
emit $ SetVariable ptr $
|
||||
GetElementPtrInbounds t' (Ref t')
|
||||
(VIdent top I8) I32 (VInteger 0) I32 (VInteger 0)
|
||||
emit $ Store I8 (VInteger $ numCI ci ) (Ref I8) ptr
|
||||
|
||||
-- get a pointer of the correct type
|
||||
ptr' <- Ident . show <$> getNewVar
|
||||
emit $ SetVariable ptr' (Bitcast (Ref t') ptr (Ref $ CustomType id))
|
||||
|
||||
--emit $ UnsafeRaw "\n"
|
||||
|
||||
foldM_ (\i (Ident arg_n, arg_t)-> do
|
||||
let arg_t' = type2LlvmType arg_t
|
||||
emit $ Comment (show arg_t' <>" "<> arg_n <> " " <> show i )
|
||||
elemPtr <- Ident . show <$> getNewVar
|
||||
emit $ SetVariable elemPtr (
|
||||
GetElementPtrInbounds (CustomType id) (Ref (CustomType id))
|
||||
(VIdent ptr' Ptr) I32
|
||||
(VInteger 0) I32 (VInteger i))
|
||||
emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr elemPtr
|
||||
-- %2 = getelementptr inbounds %Foo_AInteger, %Foo_AInteger* %1, i32 0, i32 1
|
||||
-- store i32 42, i32* %2
|
||||
pure $ i + 1-- + typeByteSize arg_t'
|
||||
) 1 (argumentsCI ci)
|
||||
|
||||
--emit $ UnsafeRaw "\n"
|
||||
|
||||
-- load and return the constructed value
|
||||
load <- Ident . show <$> getNewVar
|
||||
emit $ SetVariable load (Load t' Ptr top)
|
||||
emit $ Ret t' (VIdent load t')
|
||||
emit DefineEnd
|
||||
|
||||
modify $ \s -> s { variableCount = 0 }
|
||||
) c
|
||||
compileScs (Bind (name, _t) args exp : xs) = do
|
||||
emit $ UnsafeRaw "\n"
|
||||
emit . Comment $ show name <> ": " <> show exp
|
||||
let args' = map (second type2LlvmType) args
|
||||
emit $ Define FastCC I64 {-(type2LlvmType t_return)-} name args'
|
||||
functionBody <- exprToValue exp
|
||||
if name == "main"
|
||||
then mapM_ emit $ mainContent functionBody
|
||||
else emit $ Ret I64 functionBody
|
||||
emit DefineEnd
|
||||
modify $ \s -> s { variableCount = 0 }
|
||||
compileScs xs
|
||||
compileScs (DataStructure id@(Ident outer_id) ts : xs) = do
|
||||
let biggest_variant = maximum ((\(_, t) -> sum $ typeByteSize . type2LlvmType <$> t) <$> ts)
|
||||
emit $ Type id [I8, Array biggest_variant I8]
|
||||
mapM_ (\(Ident inner_id, fi) -> do
|
||||
emit $ Type (Ident $ outer_id <> "_" <> inner_id) (I8 : map type2LlvmType fi)
|
||||
) ts
|
||||
compileScs xs
|
||||
|
||||
-- where
|
||||
-- _t_return = snd $ partitionType (length args) t
|
||||
|
||||
mainContent :: LLVMValue -> [LLVMIr]
|
||||
mainContent var =
|
||||
[ UnsafeRaw $
|
||||
"call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n"
|
||||
, -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
|
||||
-- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2")
|
||||
-- , Label (Ident "b_1")
|
||||
-- , UnsafeRaw
|
||||
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n"
|
||||
-- , Br (Ident "end")
|
||||
-- , Label (Ident "b_2")
|
||||
-- , UnsafeRaw
|
||||
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n"
|
||||
-- , Br (Ident "end")
|
||||
-- , Label (Ident "end")
|
||||
Ret I64 (VInteger 0)
|
||||
]
|
||||
|
||||
defaultStart :: [LLVMIr]
|
||||
defaultStart = [ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n"
|
||||
, UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
|
||||
, UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
|
||||
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
|
||||
]
|
||||
|
||||
compileExp :: Exp -> CompilerState ()
|
||||
compileExp (EInt int) = emitInt int
|
||||
compileExp (EAdd t e1 e2) = emitAdd t e1 e2
|
||||
compileExp (ESub t e1 e2) = emitSub t e1 e2
|
||||
compileExp (EId (name, _)) = emitIdent name
|
||||
compileExp (EApp t e1 e2) = emitApp t e1 e2
|
||||
compileExp (EAbs t ti e) = emitAbs t ti e
|
||||
compileExp (ELet binds e) = emitLet binds e
|
||||
compileExp (ECase t e cs) = emitECased t e cs
|
||||
-- go (EMul e1 e2) = emitMul e1 e2
|
||||
-- go (EDiv e1 e2) = emitDiv e1 e2
|
||||
-- go (EMod e1 e2) = emitMod e1 e2
|
||||
|
||||
--- aux functions ---
|
||||
emitECased :: Type -> Exp -> [(Type, Case)] -> CompilerState ()
|
||||
emitECased t e cases = do
|
||||
let cs = snd <$> cases
|
||||
let ty = type2LlvmType t
|
||||
vs <- exprToValue e
|
||||
lbl <- getNewLabel
|
||||
let label = Ident $ "escape_" <> show lbl
|
||||
stackPtr <- getNewVar
|
||||
emit $ SetVariable (Ident $ show stackPtr) (Alloca ty)
|
||||
mapM_ (emitCases ty label stackPtr vs) cs
|
||||
emit $ Label label
|
||||
res <- getNewVar
|
||||
emit $ SetVariable (Ident $ show res) (Load ty Ptr (Ident $ show stackPtr))
|
||||
where
|
||||
emitCases :: LLVMType -> Ident -> Integer -> LLVMValue -> Case -> CompilerState ()
|
||||
emitCases ty label stackPtr vs (Case (GA.CInt i) exp) = do
|
||||
ns <- getNewVar
|
||||
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
|
||||
lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
|
||||
emit $ SetVariable (Ident $ show ns) (Icmp LLEq ty vs (VInteger i))
|
||||
emit $ BrCond (VIdent (Ident $ show ns) ty) lbl_succPos lbl_failPos
|
||||
emit $ Label lbl_succPos
|
||||
val <- exprToValue exp
|
||||
emit $ Store ty val Ptr (Ident . show $ stackPtr)
|
||||
emit $ Br label
|
||||
emit $ Label lbl_failPos
|
||||
emitCases ty label stackPtr _ (Case GA.CatchAll exp) = do
|
||||
val <- exprToValue exp
|
||||
emit $ Store ty val Ptr (Ident . show $ stackPtr)
|
||||
emit $ Br label
|
||||
|
||||
|
||||
emitAbs :: Type -> Id -> Exp -> CompilerState ()
|
||||
emitAbs _t tid e = do
|
||||
emit . Comment $
|
||||
"Lambda escaped previous stages: \\" <> show tid <> " . " <> show e
|
||||
emitLet :: Bind -> Exp -> CompilerState ()
|
||||
emitLet xs e = do
|
||||
emit $
|
||||
Comment $
|
||||
concat
|
||||
[ "ELet ("
|
||||
, show xs
|
||||
, " = "
|
||||
, show e
|
||||
, ") is not implemented!"
|
||||
]
|
||||
|
||||
emitApp :: Type -> Exp -> Exp -> CompilerState ()
|
||||
emitApp t e1 e2 = appEmitter t e1 e2 []
|
||||
where
|
||||
appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState ()
|
||||
appEmitter t e1 e2 stack = do
|
||||
let newStack = e2 : stack
|
||||
case e1 of
|
||||
EApp _ e1' e2' -> appEmitter t e1' e2' newStack
|
||||
EId id@(name, _) -> do
|
||||
args <- traverse exprToValue newStack
|
||||
vs <- getNewVar
|
||||
funcs <- gets functions
|
||||
let visibility = maybe Local (const Global) $ Map.lookup id funcs
|
||||
args' = map (first valueGetType . dupe) args
|
||||
call = Call FastCC (type2LlvmType t) visibility name args'
|
||||
emit $ SetVariable (Ident $ show vs) call
|
||||
x -> do
|
||||
emit . Comment $ "The unspeakable happened: "
|
||||
emit . Comment $ show x
|
||||
|
||||
emitIdent :: Ident -> CompilerState ()
|
||||
emitIdent id = do
|
||||
-- !!this should never happen!!
|
||||
emit $ Comment "This should not have happened!"
|
||||
emit $ Variable id
|
||||
emit $ UnsafeRaw "\n"
|
||||
|
||||
emitInt :: Integer -> CompilerState ()
|
||||
emitInt i = do
|
||||
-- !!this should never happen!!
|
||||
varCount <- getNewVar
|
||||
emit $ Comment "This should not have happened!"
|
||||
emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0))
|
||||
|
||||
emitAdd :: Type -> Exp -> Exp -> CompilerState ()
|
||||
emitAdd t e1 e2 = do
|
||||
v1 <- exprToValue e1
|
||||
v2 <- exprToValue e2
|
||||
v <- getNewVar
|
||||
emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2)
|
||||
|
||||
emitSub :: Type -> Exp -> Exp -> CompilerState ()
|
||||
emitSub t e1 e2 = do
|
||||
v1 <- exprToValue e1
|
||||
v2 <- exprToValue e2
|
||||
v <- getNewVar
|
||||
emit $ SetVariable (Ident $ show v) (Sub (type2LlvmType t) v1 v2)
|
||||
|
||||
-- emitMul :: Exp -> Exp -> CompilerState ()
|
||||
-- emitMul e1 e2 = do
|
||||
-- (v1,v2) <- binExprToValues e1 e2
|
||||
-- increaseVarCount
|
||||
-- v <- gets variableCount
|
||||
-- emit $ SetVariable $ Ident $ show v
|
||||
-- emit $ Mul I64 v1 v2
|
||||
|
||||
-- emitMod :: Exp -> Exp -> CompilerState ()
|
||||
-- emitMod e1 e2 = do
|
||||
-- -- `let m a b = rem (abs $ b + a) b`
|
||||
-- (v1,v2) <- binExprToValues e1 e2
|
||||
-- increaseVarCount
|
||||
-- vadd <- gets variableCount
|
||||
-- emit $ SetVariable $ Ident $ show vadd
|
||||
-- emit $ Add I64 v1 v2
|
||||
module Codegen.Codegen where
|
||||
-- {-# LANGUAGE LambdaCase #-}
|
||||
-- {-# LANGUAGE OverloadedStrings #-}
|
||||
--
|
||||
-- module Codegen.Codegen (generateCode) where
|
||||
--
|
||||
-- import Auxiliary (snoc)
|
||||
-- import Codegen.LlvmIr (CallingConvention (..),
|
||||
-- LLVMComp (..), LLVMIr (..),
|
||||
-- LLVMType (..), LLVMValue (..),
|
||||
-- Visibility (..), llvmIrToString)
|
||||
-- import Control.Monad.State (StateT, execStateT, foldM_, gets,
|
||||
-- modify)
|
||||
-- import qualified Data.Bifunctor as BI
|
||||
-- import Data.List.Extra (trim)
|
||||
-- import Data.Map (Map)
|
||||
-- import qualified Data.Map as Map
|
||||
-- import Data.Tuple.Extra (dupe, first, second)
|
||||
-- import qualified Grammar.Abs as GA
|
||||
-- import Grammar.ErrM (Err)
|
||||
-- import System.Process.Extra (readCreateProcess, shell)
|
||||
-- import TypeChecker.TypeCheckerIr (Bind (..), Case (..), Exp (..), Id,
|
||||
-- Ident (..), Program (..), Type (..))
|
||||
-- -- | The record used as the code generator state
|
||||
-- data CodeGenerator = CodeGenerator
|
||||
-- { instructions :: [LLVMIr]
|
||||
-- , functions :: Map Id FunctionInfo
|
||||
-- , constructors :: Map Id ConstructorInfo
|
||||
-- , variableCount :: Integer
|
||||
-- , labelCount :: Integer
|
||||
-- }
|
||||
--
|
||||
-- -- | A state type synonym
|
||||
-- type CompilerState a = StateT CodeGenerator Err a
|
||||
--
|
||||
-- data FunctionInfo = FunctionInfo
|
||||
-- { numArgs :: Int
|
||||
-- , arguments :: [Id]
|
||||
-- }
|
||||
-- data ConstructorInfo = ConstructorInfo
|
||||
-- { numArgsCI :: Int
|
||||
-- , argumentsCI :: [Id]
|
||||
-- , numCI :: Integer
|
||||
-- }
|
||||
--
|
||||
--
|
||||
-- -- | Adds a instruction to the CodeGenerator state
|
||||
-- emit :: LLVMIr -> CompilerState ()
|
||||
-- emit l = modify $ \t -> t { instructions = Auxiliary.snoc l $ instructions t }
|
||||
--
|
||||
-- -- | Increases the variable counter in the CodeGenerator state
|
||||
-- increaseVarCount :: CompilerState ()
|
||||
-- increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 }
|
||||
--
|
||||
-- -- | Returns the variable count from the CodeGenerator state
|
||||
-- getVarCount :: CompilerState Integer
|
||||
-- getVarCount = gets variableCount
|
||||
--
|
||||
-- -- | Increases the variable count and returns it from the CodeGenerator state
|
||||
-- getNewVar :: CompilerState Integer
|
||||
-- getNewVar = increaseVarCount >> getVarCount
|
||||
--
|
||||
-- -- | Increses the label count and returns a label from the CodeGenerator state
|
||||
-- getNewLabel :: CompilerState Integer
|
||||
-- getNewLabel = do
|
||||
-- modify (\t -> t{labelCount = labelCount t + 1})
|
||||
-- gets labelCount
|
||||
--
|
||||
-- -- | Produces a map of functions infos from a list of binds,
|
||||
-- -- which contains useful data for code generation.
|
||||
-- getFunctions :: [Bind] -> Map Id FunctionInfo
|
||||
-- getFunctions bs = Map.fromList $ go bs
|
||||
-- where
|
||||
-- go [] = []
|
||||
-- go (Bind id args _ : xs) =
|
||||
-- (id, FunctionInfo { numArgs=length args, arguments=args })
|
||||
-- : go xs
|
||||
-- go (DataStructure n cons : xs) = do
|
||||
-- map (\(id, xs) -> ((id, TPol n), FunctionInfo {
|
||||
-- numArgs=length xs, arguments=createArgs xs
|
||||
-- })) cons
|
||||
-- <> go xs
|
||||
--
|
||||
-- createArgs :: [Type] -> [Id]
|
||||
-- createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(Ident ("arg_" <> show l) , t)],l+1)) ([], 0) xs
|
||||
--
|
||||
-- -- | Produces a map of functions infos from a list of binds,
|
||||
-- -- which contains useful data for code generation.
|
||||
-- getConstructors :: [Bind] -> Map Id ConstructorInfo
|
||||
-- getConstructors bs = Map.fromList $ go bs
|
||||
-- where
|
||||
-- go [] = []
|
||||
-- go (DataStructure (Ident n) cons : xs) = do
|
||||
-- fst (foldl (\(acc,i) (Ident id, xs) -> (((Ident (n <> "_" <> id), TPol (Ident n)), ConstructorInfo {
|
||||
-- numArgsCI=length xs,
|
||||
-- argumentsCI=createArgs xs,
|
||||
-- numCI=i
|
||||
-- }) : acc, i+1)) ([],0) cons)
|
||||
-- <> go xs
|
||||
-- go (_: xs) = go xs
|
||||
--
|
||||
-- initCodeGenerator :: [Bind] -> CodeGenerator
|
||||
-- initCodeGenerator scs = CodeGenerator { instructions = defaultStart
|
||||
-- , functions = getFunctions scs
|
||||
-- , constructors = getConstructors scs
|
||||
-- , variableCount = 0
|
||||
-- , labelCount = 0
|
||||
-- }
|
||||
--
|
||||
-- run :: Err String -> IO ()
|
||||
-- run s = do
|
||||
-- let s' = case s of
|
||||
-- Right s -> s
|
||||
-- Left _ -> error "yo"
|
||||
-- writeFile "output/llvm.ll" s'
|
||||
-- putStrLn . trim =<< readCreateProcess (shell "lli") s'
|
||||
--
|
||||
-- test :: Integer -> Program
|
||||
-- test v = Program [
|
||||
-- DataStructure (Ident "Craig") [
|
||||
-- (Ident "Bob", [TInt])--,
|
||||
-- --(Ident "Alice", [TInt, TInt])
|
||||
-- ],
|
||||
-- Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] (EId ("x",TInt)),
|
||||
-- Bind (Ident "main", TInt) [] (
|
||||
-- EApp (TPol "Craig") (EId (Ident "Craig_Bob", TPol "Craig")) (EInt v) -- (EInt 92)
|
||||
-- )
|
||||
-- ]
|
||||
--
|
||||
-- {- | Compiles an AST and produces a LLVM Ir string.
|
||||
-- An easy way to actually "compile" this output is to
|
||||
-- Simply pipe it to LLI
|
||||
-- -}
|
||||
-- generateCode :: Program -> Err String
|
||||
-- generateCode (Program scs) = do
|
||||
-- let codegen = initCodeGenerator scs
|
||||
-- llvmIrToString . instructions <$> execStateT (compileScs scs) codegen
|
||||
--
|
||||
-- compileScs :: [Bind] -> CompilerState ()
|
||||
-- compileScs [] = do
|
||||
-- -- as a last step create all the constructors
|
||||
-- c <- gets (Map.toList . constructors)
|
||||
-- mapM_ (\((id, t), ci) -> do
|
||||
-- let t' = type2LlvmType t
|
||||
-- let x = BI.second type2LlvmType <$> argumentsCI ci
|
||||
-- emit $ Define FastCC t' id x
|
||||
-- top <- Ident . show <$> getNewVar
|
||||
-- ptr <- Ident . show <$> getNewVar
|
||||
-- -- allocated the primary type
|
||||
-- emit $ SetVariable top (Alloca t')
|
||||
--
|
||||
-- -- set the first byte to the index of the constructor
|
||||
-- emit $ SetVariable ptr $
|
||||
-- GetElementPtrInbounds t' (Ref t')
|
||||
-- (VIdent top I8) I32 (VInteger 0) I32 (VInteger 0)
|
||||
-- emit $ Store I8 (VInteger $ numCI ci ) (Ref I8) ptr
|
||||
--
|
||||
-- -- get a pointer of the correct type
|
||||
-- ptr' <- Ident . show <$> getNewVar
|
||||
-- emit $ SetVariable ptr' (Bitcast (Ref t') ptr (Ref $ CustomType id))
|
||||
--
|
||||
-- --emit $ UnsafeRaw "\n"
|
||||
--
|
||||
-- foldM_ (\i (Ident arg_n, arg_t)-> do
|
||||
-- let arg_t' = type2LlvmType arg_t
|
||||
-- emit $ Comment (show arg_t' <>" "<> arg_n <> " " <> show i )
|
||||
-- elemPtr <- Ident . show <$> getNewVar
|
||||
-- emit $ SetVariable elemPtr (
|
||||
-- GetElementPtrInbounds (CustomType id) (Ref (CustomType id))
|
||||
-- (VIdent ptr' Ptr) I32
|
||||
-- (VInteger 0) I32 (VInteger i))
|
||||
-- emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr elemPtr
|
||||
-- -- %2 = getelementptr inbounds %Foo_AInteger, %Foo_AInteger* %1, i32 0, i32 1
|
||||
-- -- store i32 42, i32* %2
|
||||
-- pure $ i + 1-- + typeByteSize arg_t'
|
||||
-- ) 1 (argumentsCI ci)
|
||||
--
|
||||
-- --emit $ UnsafeRaw "\n"
|
||||
--
|
||||
-- -- load and return the constructed value
|
||||
-- load <- Ident . show <$> getNewVar
|
||||
-- emit $ SetVariable load (Load t' Ptr top)
|
||||
-- emit $ Ret t' (VIdent load t')
|
||||
-- emit DefineEnd
|
||||
--
|
||||
-- modify $ \s -> s { variableCount = 0 }
|
||||
-- ) c
|
||||
-- compileScs (Bind (name, _t) args exp : xs) = do
|
||||
-- emit $ UnsafeRaw "\n"
|
||||
-- emit . Comment $ show name <> ": " <> show exp
|
||||
-- let args' = map (second type2LlvmType) args
|
||||
-- emit $ Define FastCC I64 {-(type2LlvmType t_return)-} name args'
|
||||
-- functionBody <- exprToValue exp
|
||||
-- if name == "main"
|
||||
-- then mapM_ emit $ mainContent functionBody
|
||||
-- else emit $ Ret I64 functionBody
|
||||
-- emit DefineEnd
|
||||
-- modify $ \s -> s { variableCount = 0 }
|
||||
-- compileScs xs
|
||||
-- compileScs (DataStructure id@(Ident outer_id) ts : xs) = do
|
||||
-- let biggest_variant = maximum ((\(_, t) -> sum $ typeByteSize . type2LlvmType <$> t) <$> ts)
|
||||
-- emit $ Type id [I8, Array biggest_variant I8]
|
||||
-- mapM_ (\(Ident inner_id, fi) -> do
|
||||
-- emit $ Type (Ident $ outer_id <> "_" <> inner_id) (I8 : map type2LlvmType fi)
|
||||
-- ) ts
|
||||
-- compileScs xs
|
||||
--
|
||||
-- -- where
|
||||
-- -- _t_return = snd $ partitionType (length args) t
|
||||
--
|
||||
-- mainContent :: LLVMValue -> [LLVMIr]
|
||||
-- mainContent var =
|
||||
-- [ UnsafeRaw $
|
||||
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n"
|
||||
-- , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
|
||||
-- -- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2")
|
||||
-- -- , Label (Ident "b_1")
|
||||
-- -- , UnsafeRaw
|
||||
-- -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n"
|
||||
-- -- , Br (Ident "end")
|
||||
-- -- , Label (Ident "b_2")
|
||||
-- -- , UnsafeRaw
|
||||
-- -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n"
|
||||
-- -- , Br (Ident "end")
|
||||
-- -- , Label (Ident "end")
|
||||
-- Ret I64 (VInteger 0)
|
||||
-- ]
|
||||
--
|
||||
-- defaultStart :: [LLVMIr]
|
||||
-- defaultStart = [ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n"
|
||||
-- , UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
|
||||
-- , UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
|
||||
-- , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
|
||||
-- ]
|
||||
--
|
||||
-- compileExp :: Exp -> CompilerState ()
|
||||
-- compileExp (EInt int) = emitInt int
|
||||
-- compileExp (EAdd t e1 e2) = emitAdd t e1 e2
|
||||
-- compileExp (ESub t e1 e2) = emitSub t e1 e2
|
||||
-- compileExp (EId (name, _)) = emitIdent name
|
||||
-- compileExp (EApp t e1 e2) = emitApp t e1 e2
|
||||
-- compileExp (EAbs t ti e) = emitAbs t ti e
|
||||
-- compileExp (ELet binds e) = emitLet binds e
|
||||
-- compileExp (ECase t e cs) = emitECased t e cs
|
||||
-- -- go (EMul e1 e2) = emitMul e1 e2
|
||||
-- -- go (EDiv e1 e2) = emitDiv e1 e2
|
||||
-- -- go (EMod e1 e2) = emitMod e1 e2
|
||||
--
|
||||
-- --- aux functions ---
|
||||
-- emitECased :: Type -> Exp -> [(Type, Case)] -> CompilerState ()
|
||||
-- emitECased t e cases = do
|
||||
-- let cs = snd <$> cases
|
||||
-- let ty = type2LlvmType t
|
||||
-- vs <- exprToValue e
|
||||
-- lbl <- getNewLabel
|
||||
-- let label = Ident $ "escape_" <> show lbl
|
||||
-- stackPtr <- getNewVar
|
||||
-- emit $ SetVariable (Ident $ show stackPtr) (Alloca ty)
|
||||
-- mapM_ (emitCases ty label stackPtr vs) cs
|
||||
-- emit $ Label label
|
||||
-- res <- getNewVar
|
||||
-- emit $ SetVariable (Ident $ show res) (Load ty Ptr (Ident $ show stackPtr))
|
||||
-- where
|
||||
-- emitCases :: LLVMType -> Ident -> Integer -> LLVMValue -> Case -> CompilerState ()
|
||||
-- emitCases ty label stackPtr vs (Case (GA.CInt i) exp) = do
|
||||
-- ns <- getNewVar
|
||||
-- lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
|
||||
-- lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
|
||||
-- emit $ SetVariable (Ident $ show ns) (Icmp LLEq ty vs (VInteger i))
|
||||
-- emit $ BrCond (VIdent (Ident $ show ns) ty) lbl_succPos lbl_failPos
|
||||
-- emit $ Label lbl_succPos
|
||||
-- val <- exprToValue exp
|
||||
-- emit $ Store ty val Ptr (Ident . show $ stackPtr)
|
||||
-- emit $ Br label
|
||||
-- emit $ Label lbl_failPos
|
||||
-- emitCases ty label stackPtr _ (Case GA.CatchAll exp) = do
|
||||
-- val <- exprToValue exp
|
||||
-- emit $ Store ty val Ptr (Ident . show $ stackPtr)
|
||||
-- emit $ Br label
|
||||
--
|
||||
--
|
||||
-- emitAbs :: Type -> Id -> Exp -> CompilerState ()
|
||||
-- emitAbs _t tid e = do
|
||||
-- emit . Comment $
|
||||
-- "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e
|
||||
-- emitLet :: Bind -> Exp -> CompilerState ()
|
||||
-- emitLet xs e = do
|
||||
-- emit $
|
||||
-- Comment $
|
||||
-- concat
|
||||
-- [ "ELet ("
|
||||
-- , show xs
|
||||
-- , " = "
|
||||
-- , show e
|
||||
-- , ") is not implemented!"
|
||||
-- ]
|
||||
--
|
||||
-- emitApp :: Type -> Exp -> Exp -> CompilerState ()
|
||||
-- emitApp t e1 e2 = appEmitter t e1 e2 []
|
||||
-- where
|
||||
-- appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState ()
|
||||
-- appEmitter t e1 e2 stack = do
|
||||
-- let newStack = e2 : stack
|
||||
-- case e1 of
|
||||
-- EApp _ e1' e2' -> appEmitter t e1' e2' newStack
|
||||
-- EId id@(name, _) -> do
|
||||
-- args <- traverse exprToValue newStack
|
||||
-- vs <- getNewVar
|
||||
-- funcs <- gets functions
|
||||
-- let visibility = maybe Local (const Global) $ Map.lookup id funcs
|
||||
-- args' = map (first valueGetType . dupe) args
|
||||
-- call = Call FastCC (type2LlvmType t) visibility name args'
|
||||
-- emit $ SetVariable (Ident $ show vs) call
|
||||
-- x -> do
|
||||
-- emit . Comment $ "The unspeakable happened: "
|
||||
-- emit . Comment $ show x
|
||||
--
|
||||
-- emitIdent :: Ident -> CompilerState ()
|
||||
-- emitIdent id = do
|
||||
-- -- !!this should never happen!!
|
||||
-- emit $ Comment "This should not have happened!"
|
||||
-- emit $ Variable id
|
||||
-- emit $ UnsafeRaw "\n"
|
||||
--
|
||||
-- emitInt :: Integer -> CompilerState ()
|
||||
-- emitInt i = do
|
||||
-- -- !!this should never happen!!
|
||||
-- varCount <- getNewVar
|
||||
-- emit $ Comment "This should not have happened!"
|
||||
-- emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0))
|
||||
--
|
||||
-- emitAdd :: Type -> Exp -> Exp -> CompilerState ()
|
||||
-- emitAdd t e1 e2 = do
|
||||
-- v1 <- exprToValue e1
|
||||
-- v2 <- exprToValue e2
|
||||
-- v <- getNewVar
|
||||
-- emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2)
|
||||
--
|
||||
-- emitSub :: Type -> Exp -> Exp -> CompilerState ()
|
||||
-- emitSub t e1 e2 = do
|
||||
-- v1 <- exprToValue e1
|
||||
-- v2 <- exprToValue e2
|
||||
-- v <- getNewVar
|
||||
-- emit $ SetVariable (Ident $ show v) (Sub (type2LlvmType t) v1 v2)
|
||||
--
|
||||
-- -- emitMul :: Exp -> Exp -> CompilerState ()
|
||||
-- -- emitMul e1 e2 = do
|
||||
-- -- (v1,v2) <- binExprToValues e1 e2
|
||||
-- -- increaseVarCount
|
||||
-- -- v <- gets variableCount
|
||||
-- -- emit $ SetVariable $ Ident $ show v
|
||||
-- -- emit $ Mul I64 v1 v2
|
||||
--
|
||||
-- -- emitMod :: Exp -> Exp -> CompilerState ()
|
||||
-- -- emitMod e1 e2 = do
|
||||
-- -- -- `let m a b = rem (abs $ b + a) b`
|
||||
-- -- (v1,v2) <- binExprToValues e1 e2
|
||||
-- -- increaseVarCount
|
||||
-- -- vadd <- gets variableCount
|
||||
-- -- emit $ SetVariable $ Ident $ show vadd
|
||||
-- -- emit $ Add I64 v1 v2
|
||||
-- --
|
||||
-- -- increaseVarCount
|
||||
-- -- vabs <- gets variableCount
|
||||
-- -- emit $ SetVariable $ Ident $ show vabs
|
||||
-- -- emit $ Call I64 (Ident "llvm.abs.i64")
|
||||
-- -- [ (I64, VIdent (Ident $ show vadd))
|
||||
-- -- , (I1, VInteger 1)
|
||||
-- -- ]
|
||||
-- -- increaseVarCount
|
||||
-- -- v <- gets variableCount
|
||||
-- -- emit $ SetVariable $ Ident $ show v
|
||||
-- -- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2
|
||||
--
|
||||
-- -- emitDiv :: Exp -> Exp -> CompilerState ()
|
||||
-- -- emitDiv e1 e2 = do
|
||||
-- -- (v1,v2) <- binExprToValues e1 e2
|
||||
-- -- increaseVarCount
|
||||
-- -- v <- gets variableCount
|
||||
-- -- emit $ SetVariable $ Ident $ show v
|
||||
-- -- emit $ Div I64 v1 v2
|
||||
--
|
||||
-- exprToValue :: Exp -> CompilerState LLVMValue
|
||||
-- exprToValue = \case
|
||||
-- EInt i -> pure $ VInteger i
|
||||
--
|
||||
-- EId id@(name, t) -> do
|
||||
-- funcs <- gets functions
|
||||
-- case Map.lookup id funcs of
|
||||
-- Just fi -> do
|
||||
-- if numArgs fi == 0
|
||||
-- then do
|
||||
-- vc <- getNewVar
|
||||
-- emit $ SetVariable (Ident $ show vc)
|
||||
-- (Call FastCC (type2LlvmType t) Global name [])
|
||||
-- pure $ VIdent (Ident $ show vc) (type2LlvmType t)
|
||||
-- else pure $ VFunction name Global (type2LlvmType t)
|
||||
-- Nothing -> pure $ VIdent name (type2LlvmType t)
|
||||
--
|
||||
-- e -> do
|
||||
-- compileExp e
|
||||
-- v <- getVarCount
|
||||
-- pure $ VIdent (Ident $ show v) (getType e)
|
||||
--
|
||||
-- type2LlvmType :: Type -> LLVMType
|
||||
-- type2LlvmType = \case
|
||||
-- TInt -> I64
|
||||
-- TFun t xs -> do
|
||||
-- let (t', xs') = function2LLVMType xs [type2LlvmType t]
|
||||
-- Function t' xs'
|
||||
-- TPol t -> CustomType t
|
||||
-- where
|
||||
-- function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
|
||||
-- function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s)
|
||||
-- function2LLVMType x s = (type2LlvmType x, s)
|
||||
--
|
||||
-- getType :: Exp -> LLVMType
|
||||
-- getType (EInt _) = I64
|
||||
-- getType (EAdd t _ _) = type2LlvmType t
|
||||
-- getType (ESub t _ _) = type2LlvmType t
|
||||
-- getType (EId (_, t)) = type2LlvmType t
|
||||
-- getType (EApp t _ _) = type2LlvmType t
|
||||
-- getType (EAbs t _ _) = type2LlvmType t
|
||||
-- getType (ELet _ e) = getType e
|
||||
-- getType (ECase t _ _) = type2LlvmType t
|
||||
--
|
||||
-- valueGetType :: LLVMValue -> LLVMType
|
||||
-- valueGetType (VInteger _) = I64
|
||||
-- valueGetType (VIdent _ t) = t
|
||||
-- valueGetType (VConstant s) = Array (fromIntegral $ length s) I8
|
||||
-- valueGetType (VFunction _ _ t) = t
|
||||
--
|
||||
-- typeByteSize :: LLVMType -> Integer
|
||||
-- typeByteSize I1 = 1
|
||||
-- typeByteSize I8 = 1
|
||||
-- typeByteSize I32 = 4
|
||||
-- typeByteSize I64 = 8
|
||||
-- typeByteSize Ptr = 8
|
||||
-- typeByteSize (Ref _) = 8
|
||||
-- typeByteSize (Function _ _) = 8
|
||||
-- typeByteSize (Array n t) = n * typeByteSize t
|
||||
-- typeByteSize (CustomType _) = 8
|
||||
--
|
||||
-- increaseVarCount
|
||||
-- vabs <- gets variableCount
|
||||
-- emit $ SetVariable $ Ident $ show vabs
|
||||
-- emit $ Call I64 (Ident "llvm.abs.i64")
|
||||
-- [ (I64, VIdent (Ident $ show vadd))
|
||||
-- , (I1, VInteger 1)
|
||||
-- ]
|
||||
-- increaseVarCount
|
||||
-- v <- gets variableCount
|
||||
-- emit $ SetVariable $ Ident $ show v
|
||||
-- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2
|
||||
|
||||
-- emitDiv :: Exp -> Exp -> CompilerState ()
|
||||
-- emitDiv e1 e2 = do
|
||||
-- (v1,v2) <- binExprToValues e1 e2
|
||||
-- increaseVarCount
|
||||
-- v <- gets variableCount
|
||||
-- emit $ SetVariable $ Ident $ show v
|
||||
-- emit $ Div I64 v1 v2
|
||||
|
||||
exprToValue :: Exp -> CompilerState LLVMValue
|
||||
exprToValue = \case
|
||||
EInt i -> pure $ VInteger i
|
||||
|
||||
EId id@(name, t) -> do
|
||||
funcs <- gets functions
|
||||
case Map.lookup id funcs of
|
||||
Just fi -> do
|
||||
if numArgs fi == 0
|
||||
then do
|
||||
vc <- getNewVar
|
||||
emit $ SetVariable (Ident $ show vc)
|
||||
(Call FastCC (type2LlvmType t) Global name [])
|
||||
pure $ VIdent (Ident $ show vc) (type2LlvmType t)
|
||||
else pure $ VFunction name Global (type2LlvmType t)
|
||||
Nothing -> pure $ VIdent name (type2LlvmType t)
|
||||
|
||||
e -> do
|
||||
compileExp e
|
||||
v <- getVarCount
|
||||
pure $ VIdent (Ident $ show v) (getType e)
|
||||
|
||||
type2LlvmType :: Type -> LLVMType
|
||||
type2LlvmType = \case
|
||||
TInt -> I64
|
||||
TFun t xs -> do
|
||||
let (t', xs') = function2LLVMType xs [type2LlvmType t]
|
||||
Function t' xs'
|
||||
TPol t -> CustomType t
|
||||
where
|
||||
function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
|
||||
function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s)
|
||||
function2LLVMType x s = (type2LlvmType x, s)
|
||||
|
||||
getType :: Exp -> LLVMType
|
||||
getType (EInt _) = I64
|
||||
getType (EAdd t _ _) = type2LlvmType t
|
||||
getType (ESub t _ _) = type2LlvmType t
|
||||
getType (EId (_, t)) = type2LlvmType t
|
||||
getType (EApp t _ _) = type2LlvmType t
|
||||
getType (EAbs t _ _) = type2LlvmType t
|
||||
getType (ELet _ e) = getType e
|
||||
getType (ECase t _ _) = type2LlvmType t
|
||||
|
||||
valueGetType :: LLVMValue -> LLVMType
|
||||
valueGetType (VInteger _) = I64
|
||||
valueGetType (VIdent _ t) = t
|
||||
valueGetType (VConstant s) = Array (fromIntegral $ length s) I8
|
||||
valueGetType (VFunction _ _ t) = t
|
||||
|
||||
typeByteSize :: LLVMType -> Integer
|
||||
typeByteSize I1 = 1
|
||||
typeByteSize I8 = 1
|
||||
typeByteSize I32 = 4
|
||||
typeByteSize I64 = 8
|
||||
typeByteSize Ptr = 8
|
||||
typeByteSize (Ref _) = 8
|
||||
typeByteSize (Function _ _) = 8
|
||||
typeByteSize (Array n t) = n * typeByteSize t
|
||||
typeByteSize (CustomType _) = 8
|
||||
|
|
|
|||
|
|
@ -1,239 +1,241 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
|
||||
module Codegen.LlvmIr (
|
||||
LLVMType (..),
|
||||
LLVMIr (..),
|
||||
llvmIrToString,
|
||||
LLVMValue (..),
|
||||
LLVMComp (..),
|
||||
Visibility (..),
|
||||
CallingConvention (..)
|
||||
) where
|
||||
|
||||
import Data.List (intercalate)
|
||||
import TypeChecker.TypeCheckerIr
|
||||
|
||||
data CallingConvention = TailCC | FastCC | CCC | ColdCC
|
||||
instance Show CallingConvention where
|
||||
show :: CallingConvention -> String
|
||||
show TailCC = "tailcc"
|
||||
show FastCC = "fastcc"
|
||||
show CCC = "ccc"
|
||||
show ColdCC = "coldcc"
|
||||
|
||||
-- | A datatype which represents some basic LLVM types
|
||||
data LLVMType
|
||||
= I1
|
||||
| I8
|
||||
| I32
|
||||
| I64
|
||||
| Ptr
|
||||
| Ref LLVMType
|
||||
| Function LLVMType [LLVMType]
|
||||
| Array Integer LLVMType
|
||||
| CustomType Ident
|
||||
|
||||
instance Show LLVMType where
|
||||
show :: LLVMType -> String
|
||||
show = \case
|
||||
I1 -> "i1"
|
||||
I8 -> "i8"
|
||||
I32 -> "i32"
|
||||
I64 -> "i64"
|
||||
Ptr -> "ptr"
|
||||
Ref ty -> show ty <> "*"
|
||||
Function t xs -> show t <> " (" <> intercalate ", " (map show xs) <> ")*"
|
||||
Array n ty -> concat ["[", show n, " x ", show ty, "]"]
|
||||
CustomType (Ident ty) -> "%" <> ty
|
||||
|
||||
data LLVMComp
|
||||
= LLEq
|
||||
| LLNe
|
||||
| LLUgt
|
||||
| LLUge
|
||||
| LLUlt
|
||||
| LLUle
|
||||
| LLSgt
|
||||
| LLSge
|
||||
| LLSlt
|
||||
| LLSle
|
||||
instance Show LLVMComp where
|
||||
show :: LLVMComp -> String
|
||||
show = \case
|
||||
LLEq -> "eq"
|
||||
LLNe -> "ne"
|
||||
LLUgt -> "ugt"
|
||||
LLUge -> "uge"
|
||||
LLUlt -> "ult"
|
||||
LLUle -> "ule"
|
||||
LLSgt -> "sgt"
|
||||
LLSge -> "sge"
|
||||
LLSlt -> "slt"
|
||||
LLSle -> "sle"
|
||||
|
||||
data Visibility = Local | Global
|
||||
instance Show Visibility where
|
||||
show :: Visibility -> String
|
||||
show Local = "%"
|
||||
show Global = "@"
|
||||
|
||||
-- | Represents a LLVM "value", as in an integer, a register variable,
|
||||
-- or a string contstant
|
||||
data LLVMValue
|
||||
= VInteger Integer
|
||||
| VIdent Ident LLVMType
|
||||
| VConstant String
|
||||
| VFunction Ident Visibility LLVMType
|
||||
|
||||
instance Show LLVMValue where
|
||||
show :: LLVMValue -> String
|
||||
show v = case v of
|
||||
VInteger i -> show i
|
||||
VIdent (Ident n) _ -> "%" <> n
|
||||
VFunction (Ident n) vis _ -> show vis <> n
|
||||
VConstant s -> "c" <> show s
|
||||
|
||||
type Params = [(Ident, LLVMType)]
|
||||
type Args = [(LLVMType, LLVMValue)]
|
||||
|
||||
-- | A datatype which represents different instructions in LLVM
|
||||
data LLVMIr
|
||||
= Type Ident [LLVMType]
|
||||
| Define CallingConvention LLVMType Ident Params
|
||||
| DefineEnd
|
||||
| Declare LLVMType Ident Params
|
||||
| SetVariable Ident LLVMIr
|
||||
| Variable Ident
|
||||
| GetElementPtrInbounds LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue
|
||||
| Add LLVMType LLVMValue LLVMValue
|
||||
| Sub LLVMType LLVMValue LLVMValue
|
||||
| Div LLVMType LLVMValue LLVMValue
|
||||
| Mul LLVMType LLVMValue LLVMValue
|
||||
| Srem LLVMType LLVMValue LLVMValue
|
||||
| Icmp LLVMComp LLVMType LLVMValue LLVMValue
|
||||
| Br Ident
|
||||
| BrCond LLVMValue Ident Ident
|
||||
| Label Ident
|
||||
| Call CallingConvention LLVMType Visibility Ident Args
|
||||
| Alloca LLVMType
|
||||
| Store LLVMType LLVMValue LLVMType Ident
|
||||
| Load LLVMType LLVMType Ident
|
||||
| Bitcast LLVMType Ident LLVMType
|
||||
| Ret LLVMType LLVMValue
|
||||
| Comment String
|
||||
| UnsafeRaw String -- This should generally be avoided, and proper
|
||||
-- instructions should be used in its place
|
||||
deriving (Show)
|
||||
|
||||
-- | Converts a list of LLVMIr instructions to a string
|
||||
llvmIrToString :: [LLVMIr] -> String
|
||||
llvmIrToString = go 0
|
||||
where
|
||||
go :: Int -> [LLVMIr] -> String
|
||||
go _ [] = mempty
|
||||
go i (x : xs) = do
|
||||
let (i', n) = case x of
|
||||
Define{} -> (i + 1, 0)
|
||||
DefineEnd -> (i - 1, 0)
|
||||
_ -> (i, i)
|
||||
insToString n x <> go i' xs
|
||||
|
||||
{- | Converts a LLVM inststruction to a String, allowing for printing etc.
|
||||
The integer represents the indentation
|
||||
-}
|
||||
{- FOURMOLU_DISABLE -}
|
||||
insToString :: Int -> LLVMIr -> String
|
||||
insToString i l =
|
||||
replicate i '\t' <> case l of
|
||||
(GetElementPtrInbounds t1 t2 p t3 v1 t4 v2) -> do
|
||||
-- getelementptr inbounds %Foo, %Foo* %x, i32 0, i32 0
|
||||
concat
|
||||
[ "getelementptr inbounds ", show t1, ", " , show t2
|
||||
, " ", show p, ", ", show t3, " ", show v1,
|
||||
", ", show t4, " ", show v2, "\n" ]
|
||||
(Type (Ident n) types) ->
|
||||
concat
|
||||
[ "%", n, " = type { "
|
||||
, intercalate ", " (map show types)
|
||||
, " }\n"
|
||||
]
|
||||
(Define c t (Ident i) params) ->
|
||||
concat
|
||||
[ "define ", show c, " ", show t, " @", i
|
||||
, "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params)
|
||||
, ") {\n"
|
||||
]
|
||||
DefineEnd -> "}\n"
|
||||
(Declare _t (Ident _i) _params) -> undefined
|
||||
(SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir]
|
||||
(Add t v1 v2) ->
|
||||
concat
|
||||
[ "add ", show t, " ", show v1
|
||||
, ", ", show v2, "\n"
|
||||
]
|
||||
(Sub t v1 v2) ->
|
||||
concat
|
||||
[ "sub ", show t, " ", show v1, ", "
|
||||
, show v2, "\n"
|
||||
]
|
||||
(Div t v1 v2) ->
|
||||
concat
|
||||
[ "sdiv ", show t, " ", show v1, ", "
|
||||
, show v2, "\n"
|
||||
]
|
||||
(Mul t v1 v2) ->
|
||||
concat
|
||||
[ "mul ", show t, " ", show v1
|
||||
, ", ", show v2, "\n"
|
||||
]
|
||||
(Srem t v1 v2) ->
|
||||
concat
|
||||
[ "srem ", show t, " ", show v1, ", "
|
||||
, show v2, "\n"
|
||||
]
|
||||
(Call c t vis (Ident i) arg) ->
|
||||
concat
|
||||
[ "call ", show c, " ", show t, " ", show vis, i, "("
|
||||
, intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg
|
||||
, ")\n"
|
||||
]
|
||||
(Alloca t) -> unwords ["alloca", show t, "\n"]
|
||||
(Store t1 val t2 (Ident id2)) ->
|
||||
concat
|
||||
[ "store ", show t1, " ", show val
|
||||
, ", ", show t2 , " %", id2, "\n"
|
||||
]
|
||||
(Load t1 t2 (Ident addr)) ->
|
||||
concat
|
||||
[ "load ", show t1, ", "
|
||||
, show t2, " %", addr, "\n"
|
||||
]
|
||||
(Bitcast t1 (Ident i) t2) ->
|
||||
concat
|
||||
[ "bitcast ", show t1, " %"
|
||||
, i, " to ", show t2, "\n"
|
||||
]
|
||||
(Icmp comp t v1 v2) ->
|
||||
concat
|
||||
[ "icmp ", show comp, " ", show t
|
||||
, " ", show v1, ", ", show v2, "\n"
|
||||
]
|
||||
(Ret t v) ->
|
||||
concat
|
||||
[ "ret ", show t, " "
|
||||
, show v, "\n"
|
||||
]
|
||||
(UnsafeRaw s) -> s
|
||||
(Label (Ident s)) -> "\n" <> lblPfx <> s <> ":\n"
|
||||
(Br (Ident s)) -> "br label %" <> lblPfx <> s <> "\n"
|
||||
(BrCond val (Ident s1) (Ident s2)) ->
|
||||
concat
|
||||
[ "br i1 ", show val, ", ", "label %"
|
||||
, lblPfx, s1, ", ", "label %", lblPfx, s2, "\n"
|
||||
]
|
||||
(Comment s) -> "; " <> s <> "\n"
|
||||
(Variable (Ident id)) -> "%" <> id
|
||||
{- FOURMOLU_ENABLE -}
|
||||
|
||||
lblPfx :: String
|
||||
lblPfx = "lbl_"
|
||||
module Codegen.LlvmIr where
|
||||
-- {-# LANGUAGE LambdaCase #-}
|
||||
--
|
||||
-- module Codegen.LlvmIr (
|
||||
-- LLVMType (..),
|
||||
-- LLVMIr (..),
|
||||
-- llvmIrToString,
|
||||
-- LLVMValue (..),
|
||||
-- LLVMComp (..),
|
||||
-- Visibility (..),
|
||||
-- CallingConvention (..)
|
||||
-- ) where
|
||||
--
|
||||
-- import Data.List (intercalate)
|
||||
-- import TypeChecker.TypeCheckerIr
|
||||
--
|
||||
-- data CallingConvention = TailCC | FastCC | CCC | ColdCC
|
||||
-- instance Show CallingConvention where
|
||||
-- show :: CallingConvention -> String
|
||||
-- show TailCC = "tailcc"
|
||||
-- show FastCC = "fastcc"
|
||||
-- show CCC = "ccc"
|
||||
-- show ColdCC = "coldcc"
|
||||
--
|
||||
-- -- | A datatype which represents some basic LLVM types
|
||||
-- data LLVMType
|
||||
-- = I1
|
||||
-- | I8
|
||||
-- | I32
|
||||
-- | I64
|
||||
-- | Ptr
|
||||
-- | Ref LLVMType
|
||||
-- | Function LLVMType [LLVMType]
|
||||
-- | Array Integer LLVMType
|
||||
-- | CustomType Ident
|
||||
--
|
||||
-- instance Show LLVMType where
|
||||
-- show :: LLVMType -> String
|
||||
-- show = \case
|
||||
-- I1 -> "i1"
|
||||
-- I8 -> "i8"
|
||||
-- I32 -> "i32"
|
||||
-- I64 -> "i64"
|
||||
-- Ptr -> "ptr"
|
||||
-- Ref ty -> show ty <> "*"
|
||||
-- Function t xs -> show t <> " (" <> intercalate ", " (map show xs) <> ")*"
|
||||
-- Array n ty -> concat ["[", show n, " x ", show ty, "]"]
|
||||
-- CustomType (Ident ty) -> "%" <> ty
|
||||
--
|
||||
-- data LLVMComp
|
||||
-- = LLEq
|
||||
-- | LLNe
|
||||
-- | LLUgt
|
||||
-- | LLUge
|
||||
-- | LLUlt
|
||||
-- | LLUle
|
||||
-- | LLSgt
|
||||
-- | LLSge
|
||||
-- | LLSlt
|
||||
-- | LLSle
|
||||
-- instance Show LLVMComp where
|
||||
-- show :: LLVMComp -> String
|
||||
-- show = \case
|
||||
-- LLEq -> "eq"
|
||||
-- LLNe -> "ne"
|
||||
-- LLUgt -> "ugt"
|
||||
-- LLUge -> "uge"
|
||||
-- LLUlt -> "ult"
|
||||
-- LLUle -> "ule"
|
||||
-- LLSgt -> "sgt"
|
||||
-- LLSge -> "sge"
|
||||
-- LLSlt -> "slt"
|
||||
-- LLSle -> "sle"
|
||||
--
|
||||
-- data Visibility = Local | Global
|
||||
-- instance Show Visibility where
|
||||
-- show :: Visibility -> String
|
||||
-- show Local = "%"
|
||||
-- show Global = "@"
|
||||
--
|
||||
-- -- | Represents a LLVM "value", as in an integer, a register variable,
|
||||
-- -- or a string contstant
|
||||
-- data LLVMValue
|
||||
-- = VInteger Integer
|
||||
-- | VIdent Ident LLVMType
|
||||
-- | VConstant String
|
||||
-- | VFunction Ident Visibility LLVMType
|
||||
--
|
||||
-- instance Show LLVMValue where
|
||||
-- show :: LLVMValue -> String
|
||||
-- show v = case v of
|
||||
-- VInteger i -> show i
|
||||
-- VIdent (Ident n) _ -> "%" <> n
|
||||
-- VFunction (Ident n) vis _ -> show vis <> n
|
||||
-- VConstant s -> "c" <> show s
|
||||
--
|
||||
-- type Params = [(Ident, LLVMType)]
|
||||
-- type Args = [(LLVMType, LLVMValue)]
|
||||
--
|
||||
-- -- | A datatype which represents different instructions in LLVM
|
||||
-- data LLVMIr
|
||||
-- = Type Ident [LLVMType]
|
||||
-- | Define CallingConvention LLVMType Ident Params
|
||||
-- | DefineEnd
|
||||
-- | Declare LLVMType Ident Params
|
||||
-- | SetVariable Ident LLVMIr
|
||||
-- | Variable Ident
|
||||
-- | GetElementPtrInbounds LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue
|
||||
-- | Add LLVMType LLVMValue LLVMValue
|
||||
-- | Sub LLVMType LLVMValue LLVMValue
|
||||
-- | Div LLVMType LLVMValue LLVMValue
|
||||
-- | Mul LLVMType LLVMValue LLVMValue
|
||||
-- | Srem LLVMType LLVMValue LLVMValue
|
||||
-- | Icmp LLVMComp LLVMType LLVMValue LLVMValue
|
||||
-- | Br Ident
|
||||
-- | BrCond LLVMValue Ident Ident
|
||||
-- | Label Ident
|
||||
-- | Call CallingConvention LLVMType Visibility Ident Args
|
||||
-- | Alloca LLVMType
|
||||
-- | Store LLVMType LLVMValue LLVMType Ident
|
||||
-- | Load LLVMType LLVMType Ident
|
||||
-- | Bitcast LLVMType Ident LLVMType
|
||||
-- | Ret LLVMType LLVMValue
|
||||
-- | Comment String
|
||||
-- | UnsafeRaw String -- This should generally be avoided, and proper
|
||||
-- -- instructions should be used in its place
|
||||
-- deriving (Show)
|
||||
--
|
||||
-- -- | Converts a list of LLVMIr instructions to a string
|
||||
-- llvmIrToString :: [LLVMIr] -> String
|
||||
-- llvmIrToString = go 0
|
||||
-- where
|
||||
-- go :: Int -> [LLVMIr] -> String
|
||||
-- go _ [] = mempty
|
||||
-- go i (x : xs) = do
|
||||
-- let (i', n) = case x of
|
||||
-- Define{} -> (i + 1, 0)
|
||||
-- DefineEnd -> (i - 1, 0)
|
||||
-- _ -> (i, i)
|
||||
-- insToString n x <> go i' xs
|
||||
--
|
||||
-- {- | Converts a LLVM inststruction to a String, allowing for printing etc.
|
||||
-- The integer represents the indentation
|
||||
-- -}
|
||||
-- {- FOURMOLU_DISABLE -}
|
||||
-- insToString :: Int -> LLVMIr -> String
|
||||
-- insToString i l =
|
||||
-- replicate i '\t' <> case l of
|
||||
-- (GetElementPtrInbounds t1 t2 p t3 v1 t4 v2) -> do
|
||||
-- -- getelementptr inbounds %Foo, %Foo* %x, i32 0, i32 0
|
||||
-- concat
|
||||
-- [ "getelementptr inbounds ", show t1, ", " , show t2
|
||||
-- , " ", show p, ", ", show t3, " ", show v1,
|
||||
-- ", ", show t4, " ", show v2, "\n" ]
|
||||
-- (Type (Ident n) types) ->
|
||||
-- concat
|
||||
-- [ "%", n, " = type { "
|
||||
-- , intercalate ", " (map show types)
|
||||
-- , " }\n"
|
||||
-- ]
|
||||
-- (Define c t (Ident i) params) ->
|
||||
-- concat
|
||||
-- [ "define ", show c, " ", show t, " @", i
|
||||
-- , "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params)
|
||||
-- , ") {\n"
|
||||
-- ]
|
||||
-- DefineEnd -> "}\n"
|
||||
-- (Declare _t (Ident _i) _params) -> undefined
|
||||
-- (SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir]
|
||||
-- (Add t v1 v2) ->
|
||||
-- concat
|
||||
-- [ "add ", show t, " ", show v1
|
||||
-- , ", ", show v2, "\n"
|
||||
-- ]
|
||||
-- (Sub t v1 v2) ->
|
||||
-- concat
|
||||
-- [ "sub ", show t, " ", show v1, ", "
|
||||
-- , show v2, "\n"
|
||||
-- ]
|
||||
-- (Div t v1 v2) ->
|
||||
-- concat
|
||||
-- [ "sdiv ", show t, " ", show v1, ", "
|
||||
-- , show v2, "\n"
|
||||
-- ]
|
||||
-- (Mul t v1 v2) ->
|
||||
-- concat
|
||||
-- [ "mul ", show t, " ", show v1
|
||||
-- , ", ", show v2, "\n"
|
||||
-- ]
|
||||
-- (Srem t v1 v2) ->
|
||||
-- concat
|
||||
-- [ "srem ", show t, " ", show v1, ", "
|
||||
-- , show v2, "\n"
|
||||
-- ]
|
||||
-- (Call c t vis (Ident i) arg) ->
|
||||
-- concat
|
||||
-- [ "call ", show c, " ", show t, " ", show vis, i, "("
|
||||
-- , intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg
|
||||
-- , ")\n"
|
||||
-- ]
|
||||
-- (Alloca t) -> unwords ["alloca", show t, "\n"]
|
||||
-- (Store t1 val t2 (Ident id2)) ->
|
||||
-- concat
|
||||
-- [ "store ", show t1, " ", show val
|
||||
-- , ", ", show t2 , " %", id2, "\n"
|
||||
-- ]
|
||||
-- (Load t1 t2 (Ident addr)) ->
|
||||
-- concat
|
||||
-- [ "load ", show t1, ", "
|
||||
-- , show t2, " %", addr, "\n"
|
||||
-- ]
|
||||
-- (Bitcast t1 (Ident i) t2) ->
|
||||
-- concat
|
||||
-- [ "bitcast ", show t1, " %"
|
||||
-- , i, " to ", show t2, "\n"
|
||||
-- ]
|
||||
-- (Icmp comp t v1 v2) ->
|
||||
-- concat
|
||||
-- [ "icmp ", show comp, " ", show t
|
||||
-- , " ", show v1, ", ", show v2, "\n"
|
||||
-- ]
|
||||
-- (Ret t v) ->
|
||||
-- concat
|
||||
-- [ "ret ", show t, " "
|
||||
-- , show v, "\n"
|
||||
-- ]
|
||||
-- (UnsafeRaw s) -> s
|
||||
-- (Label (Ident s)) -> "\n" <> lblPfx <> s <> ":\n"
|
||||
-- (Br (Ident s)) -> "br label %" <> lblPfx <> s <> "\n"
|
||||
-- (BrCond val (Ident s1) (Ident s2)) ->
|
||||
-- concat
|
||||
-- [ "br i1 ", show val, ", ", "label %"
|
||||
-- , lblPfx, s1, ", ", "label %", lblPfx, s2, "\n"
|
||||
-- ]
|
||||
-- (Comment s) -> "; " <> s <> "\n"
|
||||
-- (Variable (Ident id)) -> "%" <> id
|
||||
-- {- FOURMOLU_ENABLE -}
|
||||
--
|
||||
-- lblPfx :: String
|
||||
-- lblPfx = "lbl_"
|
||||
--
|
||||
|
|
|
|||
|
|
@ -1,235 +1,192 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
--{-# LANGUAGE LambdaCase #-}
|
||||
--{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
|
||||
module LambdaLifter.LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where
|
||||
module LambdaLifter.LambdaLifter where
|
||||
|
||||
import Auxiliary (snoc)
|
||||
import Control.Applicative (Applicative (liftA2))
|
||||
import Control.Monad.State (MonadState (get, put), State,
|
||||
evalState)
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as Set
|
||||
import Debug.Trace (trace)
|
||||
import qualified Grammar.Abs as GA
|
||||
import Prelude hiding (exp)
|
||||
import Renamer.Renamer
|
||||
import TypeChecker.TypeCheckerIr
|
||||
--import Auxiliary (snoc)
|
||||
--import Control.Applicative (Applicative (liftA2))
|
||||
--import Control.Monad.State (MonadState (get, put), State,
|
||||
-- evalState)
|
||||
--import Data.Set (Set)
|
||||
--import qualified Data.Set as Set
|
||||
--import Prelude hiding (exp)
|
||||
--import Renamer.Renamer
|
||||
--import TypeChecker.TypeCheckerIr
|
||||
|
||||
|
||||
-- | Lift lambdas and let expression into supercombinators.
|
||||
-- Three phases:
|
||||
-- @freeVars@ annotatss all the free variables.
|
||||
-- @abstract@ converts lambdas into let expressions.
|
||||
-- @collectScs@ moves every non-constant let expression to a top-level function.
|
||||
lambdaLift :: Program -> Program
|
||||
lambdaLift = collectScs . abstract . freeVars
|
||||
|
||||
-- | Annotate free variables
|
||||
freeVars :: Program -> AnnProgram
|
||||
freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e)
|
||||
| Bind n xs e <- ds
|
||||
]
|
||||
|
||||
freeVarsExp :: Set Id -> Exp -> AnnExp
|
||||
freeVarsExp localVars = \case
|
||||
EId n | Set.member n localVars -> (Set.singleton n, AId n)
|
||||
| otherwise -> (mempty, AId n)
|
||||
|
||||
EInt i -> (mempty, AInt i)
|
||||
|
||||
EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2')
|
||||
where
|
||||
e1' = freeVarsExp localVars e1
|
||||
e2' = freeVarsExp localVars e2
|
||||
|
||||
EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2')
|
||||
where
|
||||
e1' = freeVarsExp localVars e1
|
||||
e2' = freeVarsExp localVars e2
|
||||
|
||||
ESub t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), ASub t e1' e2')
|
||||
where
|
||||
e1' = freeVarsExp localVars e1
|
||||
e2' = freeVarsExp localVars e2
|
||||
---- | Lift lambdas and let expression into supercombinators.
|
||||
---- Three phases:
|
||||
---- @freeVars@ annotatss all the free variables.
|
||||
---- @abstract@ converts lambdas into let expressions.
|
||||
---- @collectScs@ moves every non-constant let expression to a top-level function.
|
||||
--lambdaLift :: Program -> Program
|
||||
--lambdaLift = collectScs . abstract . freeVars
|
||||
|
||||
|
||||
EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e')
|
||||
where
|
||||
e' = freeVarsExp (Set.insert par localVars) e
|
||||
---- | Annotate free variables
|
||||
--freeVars :: Program -> AnnProgram
|
||||
--freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e)
|
||||
-- | Bind n xs e <- ds
|
||||
-- ]
|
||||
|
||||
-- Sum free variables present in bind and the expression
|
||||
ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e')
|
||||
where
|
||||
binders_frees = Set.delete name $ freeVarsOf rhs'
|
||||
e_free = Set.delete name $ freeVarsOf e'
|
||||
--freeVarsExp :: Set Id -> Exp -> AnnExp
|
||||
--freeVarsExp localVars = \case
|
||||
-- EId n | Set.member n localVars -> (Set.singleton n, AId n)
|
||||
-- | otherwise -> (mempty, AId n)
|
||||
|
||||
rhs' = freeVarsExp e_localVars rhs
|
||||
new_bind = ABind name parms rhs'
|
||||
-- ELit _ (LInt i) -> (mempty, AInt i)
|
||||
|
||||
e' = freeVarsExp e_localVars e
|
||||
e_localVars = Set.insert name localVars
|
||||
-- EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2')
|
||||
-- where
|
||||
-- e1' = freeVarsExp localVars e1
|
||||
-- e2' = freeVarsExp localVars e2
|
||||
|
||||
(ECase t e cs) -> do
|
||||
let e' = freeVarsExp localVars e
|
||||
let vars = freeVarsOf e'
|
||||
let (vars', cs') = foldr (\(_, Case c e) (vars,acc) -> do
|
||||
let e' = freeVarsExp vars e
|
||||
let vars' = freeVarsOf e'
|
||||
(Set.union vars vars', AnnCase c e' : acc)
|
||||
) (vars, []) cs
|
||||
(vars', ACase t e' (reverse cs'))
|
||||
-- EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2')
|
||||
-- where
|
||||
-- e1' = freeVarsExp localVars e1
|
||||
-- e2' = freeVarsExp localVars e2
|
||||
|
||||
-- EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e')
|
||||
-- where
|
||||
-- e' = freeVarsExp (Set.insert par localVars) e
|
||||
|
||||
-- -- Sum free variables present in bind and the expression
|
||||
-- ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e')
|
||||
-- where
|
||||
-- binders_frees = Set.delete name $ freeVarsOf rhs'
|
||||
-- e_free = Set.delete name $ freeVarsOf e'
|
||||
|
||||
-- rhs' = freeVarsExp e_localVars rhs
|
||||
-- new_bind = ABind name parms rhs'
|
||||
|
||||
-- e' = freeVarsExp e_localVars e
|
||||
-- e_localVars = Set.insert name localVars
|
||||
|
||||
|
||||
freeVarsOf :: AnnExp -> Set Id
|
||||
freeVarsOf = fst
|
||||
--freeVarsOf :: AnnExp -> Set Id
|
||||
--freeVarsOf = fst
|
||||
|
||||
-- AST annotated with free variables
|
||||
type AnnProgram = [(Id, [Id], AnnExp)]
|
||||
---- AST annotated with free variables
|
||||
--type AnnProgram = [(Id, [Id], AnnExp)]
|
||||
|
||||
type AnnExp = (Set Id, AnnExp')
|
||||
--type AnnExp = (Set Id, AnnExp')
|
||||
|
||||
data ABind = ABind Id [Id] AnnExp deriving Show
|
||||
--data ABind = ABind Id [Id] AnnExp deriving Show
|
||||
|
||||
data AnnExp' = AId Id
|
||||
| AInt Integer
|
||||
| ALet ABind AnnExp
|
||||
| AApp Type AnnExp AnnExp
|
||||
| AAdd Type AnnExp AnnExp
|
||||
| ASub Type AnnExp AnnExp
|
||||
| AAbs Type Id AnnExp
|
||||
| ACase Type AnnExp [AnnCase]
|
||||
deriving Show
|
||||
data AnnCase = AnnCase GA.Case AnnExp
|
||||
deriving Show
|
||||
--data AnnExp' = AId Id
|
||||
-- | AInt Integer
|
||||
-- | ALet ABind AnnExp
|
||||
-- | AApp Type AnnExp AnnExp
|
||||
-- | AAdd Type AnnExp AnnExp
|
||||
-- | AAbs Type Id AnnExp
|
||||
-- deriving Show
|
||||
---- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
|
||||
---- Free variables are @v₁ v₂ .. vₙ@ are bound.
|
||||
--abstract :: AnnProgram -> Program
|
||||
--abstract prog = Program $ evalState (mapM go prog) 0
|
||||
-- where
|
||||
-- go :: (Id, [Id], AnnExp) -> State Int Bind
|
||||
-- go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs'
|
||||
-- where
|
||||
-- (rhs', parms1) = flattenLambdasAnn rhs
|
||||
|
||||
|
||||
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
|
||||
-- Free variables are @v₁ v₂ .. vₙ@ are bound.
|
||||
abstract :: AnnProgram -> Program
|
||||
abstract prog = Program $ evalState (mapM go prog) 0
|
||||
where
|
||||
go :: (Id, [Id], AnnExp) -> State Int Bind
|
||||
go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs'
|
||||
where
|
||||
(rhs', parms1) = flattenLambdasAnn rhs
|
||||
---- | Flatten nested lambdas and collect the parameters
|
||||
---- @\x.\y.\z. ae → (ae, [x,y,z])@
|
||||
--flattenLambdasAnn :: AnnExp -> (AnnExp, [Id])
|
||||
--flattenLambdasAnn ae = go (ae, [])
|
||||
-- where
|
||||
-- go :: (AnnExp, [Id]) -> (AnnExp, [Id])
|
||||
-- go ((free, e), acc) =
|
||||
-- case e of
|
||||
-- AAbs _ par (free1, e1) ->
|
||||
-- go ((Set.delete par free1, e1), snoc par acc)
|
||||
-- _ -> ((free, e), acc)
|
||||
|
||||
--abstractExp :: AnnExp -> State Int Exp
|
||||
--abstractExp (free, exp) = case exp of
|
||||
-- AId n -> pure $ EId n
|
||||
-- AInt i -> pure $ ELit (TMono "Int") (LInt i)
|
||||
-- AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2)
|
||||
-- AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2)
|
||||
-- ALet b e -> liftA2 ELet (go b) (abstractExp e)
|
||||
-- where
|
||||
-- go (ABind name parms rhs) = do
|
||||
-- (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
|
||||
-- pure $ Bind name (parms ++ parms1) rhs'
|
||||
|
||||
-- skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp
|
||||
-- skipLambdas f (free, ae) = case ae of
|
||||
-- AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1
|
||||
-- _ -> f (free, ae)
|
||||
|
||||
-- -- Lift lambda into let and bind free variables
|
||||
-- AAbs t parm e -> do
|
||||
-- i <- nextNumber
|
||||
-- rhs <- abstractExp e
|
||||
|
||||
-- let sc_name = Ident ("sc_" ++ show i)
|
||||
-- sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t)
|
||||
|
||||
-- pure $ foldl (EApp $ TMono "Int") sc $ map EId freeList
|
||||
-- where
|
||||
-- freeList = Set.toList free
|
||||
-- parms = snoc parm freeList
|
||||
|
||||
|
||||
-- | Flatten nested lambdas and collect the parameters
|
||||
-- @\x.\y.\z. ae → (ae, [x,y,z])@
|
||||
flattenLambdasAnn :: AnnExp -> (AnnExp, [Id])
|
||||
flattenLambdasAnn ae = go (ae, [])
|
||||
where
|
||||
go :: (AnnExp, [Id]) -> (AnnExp, [Id])
|
||||
go ((free, e), acc) =
|
||||
case e of
|
||||
AAbs _ par (free1, e1) ->
|
||||
go ((Set.delete par free1, e1), snoc par acc)
|
||||
_ -> ((free, e), acc)
|
||||
--nextNumber :: State Int Int
|
||||
--nextNumber = do
|
||||
-- i <- get
|
||||
-- put $ succ i
|
||||
-- pure i
|
||||
|
||||
abstractExp :: AnnExp -> State Int Exp
|
||||
abstractExp (free, exp) = case exp of
|
||||
AId n -> pure $ EId n
|
||||
AInt i -> pure $ EInt i
|
||||
AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2)
|
||||
AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2)
|
||||
ASub t e1 e2 -> liftA2 (ESub t) (abstractExp e1) (abstractExp e2)
|
||||
ALet b e -> liftA2 ELet (go b) (abstractExp e)
|
||||
where
|
||||
go (ABind name parms rhs) = do
|
||||
(rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
|
||||
pure $ Bind name (parms ++ parms1) rhs'
|
||||
|
||||
skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp
|
||||
skipLambdas f (free, ae) = case ae of
|
||||
AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1
|
||||
_ -> f (free, ae)
|
||||
---- | Collects supercombinators by lifting non-constant let expressions
|
||||
--collectScs :: Program -> Program
|
||||
--collectScs (Program scs) = Program $ concatMap collectFromRhs scs
|
||||
-- where
|
||||
-- collectFromRhs (Bind name parms rhs) =
|
||||
-- let (rhs_scs, rhs') = collectScsExp rhs
|
||||
-- in Bind name parms rhs' : rhs_scs
|
||||
|
||||
|
||||
ACase t e cs -> do
|
||||
e' <- abstractExp e
|
||||
cs' <- mapM (\(AnnCase c e) -> do
|
||||
e' <- abstractExp e
|
||||
pure (t,Case c e')) cs
|
||||
pure $ ECase t e' cs'
|
||||
--collectScsExp :: Exp -> ([Bind], Exp)
|
||||
--collectScsExp = \case
|
||||
-- EId n -> ([], EId n)
|
||||
-- ELit _ (LInt i) -> ([], ELit (TMono "Int") (LInt i))
|
||||
|
||||
-- Lift lambda into let and bind free variables
|
||||
AAbs t parm e -> do
|
||||
i <- nextNumber
|
||||
rhs <- abstractExp e
|
||||
-- EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2')
|
||||
-- where
|
||||
-- (scs1, e1') = collectScsExp e1
|
||||
-- (scs2, e2') = collectScsExp e2
|
||||
|
||||
let sc_name = Ident ("sc_" ++ show i)
|
||||
sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t)
|
||||
-- EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2')
|
||||
-- where
|
||||
-- (scs1, e1') = collectScsExp e1
|
||||
-- (scs2, e2') = collectScsExp e2
|
||||
|
||||
pure $ foldl (EApp TInt) sc $ map EId freeList
|
||||
where
|
||||
freeList = Set.toList free
|
||||
parms = snoc parm freeList
|
||||
-- EAbs t par e -> (scs, EAbs t par e')
|
||||
-- where
|
||||
-- (scs, e') = collectScsExp e
|
||||
|
||||
-- -- Collect supercombinators from bind, the rhss, and the expression.
|
||||
-- --
|
||||
-- -- > f = let sc x y = rhs in e
|
||||
-- --
|
||||
-- ELet (Bind name parms rhs) e -> if null parms
|
||||
-- then ( rhs_scs ++ e_scs, ELet bind e')
|
||||
-- else (bind : rhs_scs ++ e_scs, e')
|
||||
-- where
|
||||
-- bind = Bind name parms rhs'
|
||||
-- (rhs_scs, rhs') = collectScsExp rhs
|
||||
-- (e_scs, e') = collectScsExp e
|
||||
|
||||
|
||||
nextNumber :: State Int Int
|
||||
nextNumber = do
|
||||
i <- get
|
||||
put $ succ i
|
||||
pure i
|
||||
---- @\x.\y.\z. e → (e, [x,y,z])@
|
||||
--flattenLambdas :: Exp -> (Exp, [Id])
|
||||
--flattenLambdas = go . (, [])
|
||||
-- where
|
||||
-- go (e, acc) = case e of
|
||||
-- EAbs _ par e1 -> go (e1, snoc par acc)
|
||||
-- _ -> (e, acc)
|
||||
|
||||
-- | Collects supercombinators by lifting non-constant let expressions
|
||||
collectScs :: Program -> Program
|
||||
collectScs (Program scs) = Program $ concatMap collectFromRhs scs
|
||||
where
|
||||
collectFromRhs (Bind name parms rhs) =
|
||||
let (rhs_scs, rhs') = collectScsExp rhs
|
||||
in Bind name parms rhs' : rhs_scs
|
||||
|
||||
|
||||
collectScsExp :: Exp -> ([Bind], Exp)
|
||||
collectScsExp = \case
|
||||
EId n -> ([], EId n)
|
||||
EInt i -> ([], EInt i)
|
||||
|
||||
EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2')
|
||||
where
|
||||
(scs1, e1') = collectScsExp e1
|
||||
(scs2, e2') = collectScsExp e2
|
||||
|
||||
EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2')
|
||||
where
|
||||
(scs1, e1') = collectScsExp e1
|
||||
(scs2, e2') = collectScsExp e2
|
||||
|
||||
ESub t e1 e2 -> (scs1 ++ scs2, ESub t e1' e2')
|
||||
where
|
||||
(scs1, e1') = collectScsExp e1
|
||||
(scs2, e2') = collectScsExp e2
|
||||
|
||||
EAbs t par e -> (scs, EAbs t par e')
|
||||
where
|
||||
(scs, e') = collectScsExp e
|
||||
|
||||
-- Collect supercombinators from bind, the rhss, and the expression.
|
||||
--
|
||||
-- > f = let sc x y = rhs in e
|
||||
--
|
||||
ELet (Bind name parms rhs) e -> if null parms
|
||||
then ( rhs_scs ++ e_scs, ELet bind e')
|
||||
else (bind : rhs_scs ++ e_scs, e')
|
||||
where
|
||||
bind = Bind name parms rhs'
|
||||
(rhs_scs, rhs') = collectScsExp rhs
|
||||
(e_scs, e') = collectScsExp e
|
||||
ECase t e cs -> do
|
||||
let (scs, e') = collectScsExp e
|
||||
let (scs',cs') = foldr (\(t, Case c e) (scs, acc) -> do
|
||||
let (scs', e') = collectScsExp e
|
||||
(scs ++ scs', (t,Case c e') : acc)
|
||||
) (scs,[]) cs
|
||||
(scs', ECase t e' cs')
|
||||
|
||||
|
||||
-- @\x.\y.\z. e → (e, [x,y,z])@
|
||||
flattenLambdas :: Exp -> (Exp, [Id])
|
||||
flattenLambdas = go . (, [])
|
||||
where
|
||||
go (e, acc) = case e of
|
||||
EAbs _ par e1 -> go (e1, snoc par acc)
|
||||
_ -> (e, acc)
|
||||
|
|
|
|||
58
src/Main.hs
58
src/Main.hs
|
|
@ -2,26 +2,26 @@
|
|||
|
||||
module Main where
|
||||
|
||||
import Codegen.Codegen (generateCode)
|
||||
import GHC.IO.Handle.Text (hPutStrLn)
|
||||
import Grammar.ErrM (Err)
|
||||
import Grammar.Par (myLexer, pProgram)
|
||||
import Grammar.Print (printTree)
|
||||
--import Codegen.Codegen (generateCode)
|
||||
import GHC.IO.Handle.Text (hPutStrLn)
|
||||
import Grammar.ErrM (Err)
|
||||
import Grammar.Par (myLexer, pProgram)
|
||||
import Grammar.Print (printTree)
|
||||
|
||||
-- import Interpreter (interpret)
|
||||
import Control.Monad (when)
|
||||
import Data.List.Extra (isSuffixOf)
|
||||
import LambdaLifter.LambdaLifter (lambdaLift)
|
||||
import Renamer.Renamer (rename)
|
||||
import System.Directory (createDirectory, doesPathExist,
|
||||
getDirectoryContents,
|
||||
removeDirectoryRecursive,
|
||||
setCurrentDirectory)
|
||||
import System.Environment (getArgs)
|
||||
import System.Exit (exitFailure, exitSuccess)
|
||||
import System.IO (stderr)
|
||||
import System.Process.Extra (spawnCommand, waitForProcess)
|
||||
import TypeChecker.TypeChecker (typecheck)
|
||||
import Control.Monad (when)
|
||||
import Data.List.Extra (isSuffixOf)
|
||||
--import LambdaLifter.LambdaLifter (lambdaLift)
|
||||
import Renamer.Renamer (rename)
|
||||
import System.Directory (createDirectory, doesPathExist,
|
||||
getDirectoryContents,
|
||||
removeDirectoryRecursive,
|
||||
setCurrentDirectory)
|
||||
import System.Environment (getArgs)
|
||||
import System.Exit (exitFailure, exitSuccess)
|
||||
import System.IO (stderr)
|
||||
import System.Process.Extra (spawnCommand, waitForProcess)
|
||||
import TypeChecker.TypeChecker (typecheck)
|
||||
|
||||
main :: IO ()
|
||||
main =
|
||||
|
|
@ -46,19 +46,19 @@ main' debug s = do
|
|||
typechecked <- fromTypeCheckerErr $ typecheck renamed
|
||||
printToErr $ printTree typechecked
|
||||
|
||||
printToErr "\n-- Lambda Lifter --"
|
||||
let lifted = lambdaLift typechecked
|
||||
printToErr $ printTree lifted
|
||||
|
||||
printToErr "\n -- Printing compiler output to stdout --"
|
||||
compiled <- fromCompilerErr $ generateCode lifted
|
||||
-- printToErr "\n-- Lambda Lifter --"
|
||||
-- let lifted = lambdaLift typechecked
|
||||
-- printToErr $ printTree lifted
|
||||
--
|
||||
-- printToErr "\n -- Printing compiler output to stdout --"
|
||||
-- compiled <- fromCompilerErr $ generateCode lifted
|
||||
--putStrLn compiled
|
||||
|
||||
check <- doesPathExist "output"
|
||||
when check (removeDirectoryRecursive "output")
|
||||
createDirectory "output"
|
||||
writeFile "output/llvm.ll" compiled
|
||||
if debug then debugDotViz else putStrLn compiled
|
||||
-- check <- doesPathExist "output"
|
||||
-- when check (removeDirectoryRecursive "output")
|
||||
-- createDirectory "output"
|
||||
-- writeFile "output/llvm.ll" compiled
|
||||
-- if debug then debugDotViz else putStrLn compiled
|
||||
|
||||
|
||||
-- interpred <- fromInterpreterErr $ interpret lifted
|
||||
|
|
|
|||
|
|
@ -1,86 +1,87 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
|
||||
module Renamer.Renamer (module Renamer.Renamer) where
|
||||
module Renamer.Renamer where
|
||||
|
||||
import Auxiliary (mapAccumM)
|
||||
import Control.Monad (foldM)
|
||||
import Control.Monad.State (MonadState, State, evalState, gets,
|
||||
modify)
|
||||
import Data.List (foldl')
|
||||
import Data.Map (Map)
|
||||
import qualified Data.Map as Map
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.Tuple.Extra (dupe)
|
||||
import Grammar.Abs
|
||||
|
||||
|
||||
-- | Rename all variables and local binds
|
||||
rename :: Program -> Program
|
||||
rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0
|
||||
where
|
||||
initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs
|
||||
renameSc :: Names -> Bind -> Rn Bind
|
||||
renameSc old_names (Bind name t _ parms rhs) = do
|
||||
-- initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs
|
||||
initNames = Map.fromList $ foldl' saveIfBind [] bs
|
||||
saveIfBind acc (DBind (Bind name _ _ _ _)) = dupe name : acc
|
||||
saveIfBind acc _ = acc
|
||||
renameSc :: Names -> Def -> Rn Def
|
||||
renameSc old_names (DBind (Bind name t _ parms rhs)) = do
|
||||
(new_names, parms') <- newNames old_names parms
|
||||
rhs' <- snd <$> renameExp new_names rhs
|
||||
pure $ Bind name t name parms' rhs'
|
||||
|
||||
rhs' <- snd <$> renameExp new_names rhs
|
||||
pure . DBind $ Bind name t name parms' rhs'
|
||||
renameSc _ def = pure def
|
||||
|
||||
-- | Rename monad. State holds the number of renamed names.
|
||||
newtype Rn a = Rn { runRn :: State Int a }
|
||||
deriving (Functor, Applicative, Monad, MonadState Int)
|
||||
newtype Rn a = Rn {runRn :: State Int a}
|
||||
deriving (Functor, Applicative, Monad, MonadState Int)
|
||||
|
||||
-- | Maps old to new name
|
||||
type Names = Map Ident Ident
|
||||
|
||||
renameLocalBind :: Names -> Bind -> Rn (Names, Bind)
|
||||
renameLocalBind old_names (Bind name t _ parms rhs) = do
|
||||
(new_names, name') <- newName old_names name
|
||||
(new_names, name') <- newName old_names name
|
||||
(new_names', parms') <- newNames new_names parms
|
||||
(new_names'', rhs') <- renameExp new_names' rhs
|
||||
(new_names'', rhs') <- renameExp new_names' rhs
|
||||
pure (new_names'', Bind name' t name' parms' rhs')
|
||||
|
||||
renameExp :: Names -> Exp -> Rn (Names, Exp)
|
||||
renameExp old_names = \case
|
||||
EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names)
|
||||
|
||||
EInt i1 -> pure (old_names, EInt i1)
|
||||
|
||||
EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names)
|
||||
ELit (LInt i1) -> pure (old_names, ELit (LInt i1))
|
||||
EApp e1 e2 -> do
|
||||
(env1, e1') <- renameExp old_names e1
|
||||
(env2, e2') <- renameExp old_names e2
|
||||
pure (Map.union env1 env2, EApp e1' e2')
|
||||
|
||||
EAdd e1 e2 -> do
|
||||
(env1, e1') <- renameExp old_names e1
|
||||
(env2, e2') <- renameExp old_names e2
|
||||
pure (Map.union env1 env2, EAdd e1' e2')
|
||||
|
||||
ESub e1 e2 -> do
|
||||
(env1, e1') <- renameExp old_names e1
|
||||
(env2, e2') <- renameExp old_names e2
|
||||
pure (Map.union env1 env2, ESub e1' e2')
|
||||
|
||||
ELet b e -> do
|
||||
(new_names, b) <- renameLocalBind old_names b
|
||||
(new_names', e') <- renameExp new_names e
|
||||
pure (new_names', ELet b e')
|
||||
|
||||
EAbs par t e -> do
|
||||
ELet i e1 e2 -> do
|
||||
(new_names, e1') <- renameExp old_names e1
|
||||
(new_names', e2') <- renameExp new_names e2
|
||||
pure (new_names', ELet i e1' e2')
|
||||
EAbs par e -> do
|
||||
(new_names, par') <- newName old_names par
|
||||
(new_names', e') <- renameExp new_names e
|
||||
pure (new_names', EAbs par' t e')
|
||||
|
||||
(new_names', e') <- renameExp new_names e
|
||||
pure (new_names', EAbs par' e')
|
||||
EAnn e t -> do
|
||||
(new_names, e') <- renameExp old_names e
|
||||
pure (new_names, EAnn e' t)
|
||||
ECase e injs -> do
|
||||
(_, e') <- renameExp old_names e
|
||||
(new_names, injs') <- renameInjs old_names injs
|
||||
pure (new_names, ECase e' injs')
|
||||
|
||||
ECase e cs t -> do
|
||||
(new_names, e') <- renameExp old_names e
|
||||
(new_names', cs') <- foldM (\(names, stack) (CaseMatch c exp) -> do
|
||||
(nm,exp') <- renameExp names exp
|
||||
pure (nm,CaseMatch c exp' : stack)
|
||||
) (new_names, []) cs
|
||||
pure (new_names', ECase e' cs' t)
|
||||
renameInjs :: Names -> [Inj] -> Rn (Names, [Inj])
|
||||
renameInjs ns xs = do
|
||||
(new_names, xs') <- unzip <$> mapM (renameInj ns) xs
|
||||
if null new_names then return (mempty, xs') else return (head new_names, xs')
|
||||
|
||||
renameInj :: Names -> Inj -> Rn (Names, Inj)
|
||||
renameInj ns (Inj init e) = do
|
||||
(new_names, e') <- renameExp ns e
|
||||
return (new_names, Inj init e')
|
||||
|
||||
-- | Create a new name and add it to name environment.
|
||||
newName :: Names -> Ident -> Rn (Names, Ident)
|
||||
|
|
@ -95,4 +96,3 @@ newNames = mapAccumM newName
|
|||
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
|
||||
makeName :: Ident -> Rn Ident
|
||||
makeName (Ident prefix) = gets (\i -> Ident $ prefix ++ "_" ++ show i) <* modify succ
|
||||
|
||||
|
|
|
|||
|
|
@ -1,215 +1,517 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedRecordDot #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module TypeChecker.TypeChecker (typecheck, partitionType) where
|
||||
-- | A module for type checking and inference using algorithm W, Hindley-Milner
|
||||
module TypeChecker.TypeChecker where
|
||||
|
||||
import Auxiliary (maybeToRightM, snoc)
|
||||
import Control.Monad.Except (throwError, unless)
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Reader
|
||||
import Control.Monad.State
|
||||
import Data.Foldable (traverse_)
|
||||
import Data.Functor.Identity (runIdentity)
|
||||
import Data.List (foldl')
|
||||
import Data.Map (Map)
|
||||
import qualified Data.Map as Map
|
||||
import qualified Data.Map as M
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
import Debug.Trace (trace)
|
||||
import Grammar.Abs
|
||||
import Grammar.ErrM (Err)
|
||||
import Grammar.Print (Print (prt), concatD, doc,
|
||||
printTree, render)
|
||||
import Prelude hiding (exp, id)
|
||||
import Grammar.Print (printTree)
|
||||
import qualified TypeChecker.TypeCheckerIr as T
|
||||
import TypeChecker.TypeCheckerIr (Ctx (..), Env (..), Error, Infer,
|
||||
Poly (..), Subst)
|
||||
|
||||
-- NOTE: this type checker is poorly tested
|
||||
initCtx = Ctx mempty
|
||||
|
||||
-- TODO
|
||||
-- Coercion
|
||||
-- Type inference
|
||||
initEnv = Env 0 mempty mempty
|
||||
|
||||
data Cxt = Cxt
|
||||
{ env :: Map Ident Type -- ^ Local scope signature
|
||||
, sig :: Map Ident Type -- ^ Top-level signatures
|
||||
}
|
||||
runPretty :: Exp -> Either Error String
|
||||
runPretty = fmap (printTree . fst) . run . inferExp
|
||||
|
||||
initCxt :: [Bind] -> Cxt
|
||||
initCxt sc = Cxt { env = mempty
|
||||
, sig = Map.fromList $ map (\(Bind n t _ _ _) -> (n, t)) sc
|
||||
}
|
||||
run :: Infer a -> Either Error a
|
||||
run = runC initEnv initCtx
|
||||
|
||||
typecheck :: Program -> Err T.Program
|
||||
typecheck (Program sc) = T.Program <$> mapM (checkBind $ initCxt sc) sc
|
||||
runC :: Env -> Ctx -> Infer a -> Either Error a
|
||||
runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e
|
||||
|
||||
-- | Check if infered rhs type matches type signature.
|
||||
checkBind :: Cxt -> Bind -> Err T.Bind
|
||||
checkBind cxt b =
|
||||
case expandLambdas b of
|
||||
Bind name t _ parms rhs -> do
|
||||
(rhs', t_rhs) <- infer cxt rhs
|
||||
unless (typeEq t_rhs t) . throwError $ typeErr name t t_rhs
|
||||
pure $ T.Bind (name, t) (zip parms ts_parms) rhs'
|
||||
where
|
||||
ts_parms = fst $ partitionType (length parms) t
|
||||
typecheck :: Program -> Either Error T.Program
|
||||
typecheck = run . checkPrg
|
||||
|
||||
-- | @ f x y = rhs ⇒ f = \x.\y. rhs @
|
||||
expandLambdas :: Bind -> Bind
|
||||
expandLambdas (Bind name t _ parms rhs) = Bind name t name [] rhs'
|
||||
{- | Start by freshening the type variable of data types to avoid clash with
|
||||
other user defined polymorphic types
|
||||
This might be wrong for type constructors that work over several variables
|
||||
-}
|
||||
freshenData :: Data -> Infer Data
|
||||
freshenData (Data (Constr name ts) constrs) = do
|
||||
fr <- fresh
|
||||
let fr' = case fr of
|
||||
TPol a -> a
|
||||
-- Meh, this part assumes fresh generates a polymorphic type
|
||||
_ ->
|
||||
error
|
||||
"Bug: implementation of \
|
||||
\ fresh and freshenData are not compatible"
|
||||
let new_ts = map (freshenType fr') ts
|
||||
let new_constrs = map (freshenConstr fr') constrs
|
||||
return $ Data (Constr name new_ts) new_constrs
|
||||
|
||||
{- | Freshen all polymorphic variables, regardless of name
|
||||
| freshenType "d" (a -> b -> c) becomes (d -> d -> d)
|
||||
-}
|
||||
freshenType :: Ident -> Type -> Type
|
||||
freshenType iden = \case
|
||||
(TPol _) -> TPol iden
|
||||
(TArr a b) -> TArr (freshenType iden a) (freshenType iden b)
|
||||
(TConstr (Constr a ts)) ->
|
||||
TConstr (Constr a (map (freshenType iden) ts))
|
||||
rest -> rest
|
||||
|
||||
freshenConstr :: Ident -> Constructor -> Constructor
|
||||
freshenConstr iden (Constructor name t) =
|
||||
Constructor name (freshenType iden t)
|
||||
|
||||
checkData :: Data -> Infer ()
|
||||
checkData d = do
|
||||
d' <- freshenData d
|
||||
case d' of
|
||||
(Data typ@(Constr name ts) constrs) -> do
|
||||
unless
|
||||
(all isPoly ts)
|
||||
(throwError $ unwords ["Data type incorrectly declared"])
|
||||
traverse_
|
||||
( \(Constructor name' t') ->
|
||||
if TConstr typ == retType t'
|
||||
then insertConstr name' t'
|
||||
else
|
||||
throwError $
|
||||
unwords
|
||||
[ "return type of constructor:"
|
||||
, printTree name
|
||||
, "with type:"
|
||||
, printTree (retType t')
|
||||
, "does not match data: "
|
||||
, printTree typ
|
||||
]
|
||||
)
|
||||
constrs
|
||||
|
||||
retType :: Type -> Type
|
||||
retType (TArr _ t2) = retType t2
|
||||
retType a = a
|
||||
|
||||
checkPrg :: Program -> Infer T.Program
|
||||
checkPrg (Program bs) = do
|
||||
preRun bs
|
||||
T.Program <$> checkDef bs
|
||||
where
|
||||
rhs' = foldr ($) rhs $ zipWith EAbs parms ts_parms
|
||||
ts_parms = fst $ partitionType (length parms) t
|
||||
preRun :: [Def] -> Infer ()
|
||||
preRun [] = return ()
|
||||
preRun (x : xs) = case x of
|
||||
DBind (Bind n t _ _ _) -> insertSig n t >> preRun xs
|
||||
DData d@(Data _ _) -> checkData d >> preRun xs
|
||||
|
||||
-- | Infer type of expression.
|
||||
infer :: Cxt -> Exp -> Err (T.Exp, Type)
|
||||
infer cxt = \case
|
||||
EId x ->
|
||||
case lookupEnv x cxt of
|
||||
Nothing ->
|
||||
case lookupSig x cxt of
|
||||
Nothing -> throwError ("Unbound variable:" ++ printTree x)
|
||||
Just t -> pure (T.EId (x, t), t)
|
||||
Just t -> pure (T.EId (x, t), t)
|
||||
checkDef :: [Def] -> Infer [T.Def]
|
||||
checkDef [] = return []
|
||||
checkDef (x : xs) = case x of
|
||||
(DBind b) -> do
|
||||
b' <- checkBind b
|
||||
fmap (T.DBind b' :) (checkDef xs)
|
||||
(DData d) -> fmap (T.DData d :) (checkDef xs)
|
||||
|
||||
EInt i -> pure (T.EInt i, T.TInt)
|
||||
checkBind :: Bind -> Infer T.Bind
|
||||
checkBind (Bind n t _ args e) = do
|
||||
(t', e') <- inferExp $ makeLambda e (reverse args)
|
||||
s <- unify t t'
|
||||
let t'' = apply s t
|
||||
unless
|
||||
(t `typeEq` t'')
|
||||
( throwError $
|
||||
unwords
|
||||
[ "Top level signature"
|
||||
, printTree t
|
||||
, "does not match body with inferred type:"
|
||||
, printTree t''
|
||||
]
|
||||
)
|
||||
return $ T.Bind (n, t) e'
|
||||
where
|
||||
makeLambda :: Exp -> [Ident] -> Exp
|
||||
makeLambda = foldl (flip EAbs)
|
||||
|
||||
EApp e e1 -> do
|
||||
(e', t) <- infer cxt e
|
||||
case t of
|
||||
TFun t1 t2 -> do
|
||||
e1' <- check cxt e1 t1
|
||||
pure (T.EApp t2 e' e1', t2)
|
||||
_ -> do
|
||||
throwError ("Not a function: " ++ show e)
|
||||
|
||||
EAdd e e1 -> do
|
||||
e' <- check cxt e T.TInt
|
||||
e1' <- check cxt e1 T.TInt
|
||||
pure (T.EAdd T.TInt e' e1', T.TInt)
|
||||
|
||||
ESub e e1 -> do
|
||||
e' <- check cxt e T.TInt
|
||||
e1' <- check cxt e1 T.TInt
|
||||
pure (T.ESub T.TInt e' e1', T.TInt)
|
||||
|
||||
EAbs x t e -> do
|
||||
(e', t1) <- infer (insertEnv x t cxt) e
|
||||
let t_abs = TFun t t1
|
||||
pure (T.EAbs t_abs (x, t) e', t_abs)
|
||||
|
||||
ELet b e -> do
|
||||
let cxt' = insertBind b cxt
|
||||
b' <- checkBind cxt' b
|
||||
(e', t) <- infer cxt' e
|
||||
pure (T.ELet b' e', t)
|
||||
|
||||
EAnn e t -> do
|
||||
(e', t1) <- infer cxt e
|
||||
unless (typeEq t t1) $
|
||||
throwError "Inferred type and type annotation doesn't match"
|
||||
pure (e', t1)
|
||||
|
||||
ECase e cs t -> do
|
||||
(e',t1) <- infer cxt e
|
||||
unless (typeEq t t1) $
|
||||
throwError "Inferred type and type annotation doesn't match"
|
||||
case traverse (\(CaseMatch c e) -> do
|
||||
-- //TODO check c as well
|
||||
e' <- check cxt e t
|
||||
unless (typeEq t t1) $
|
||||
throwError "Inferred type and type annotation doesn't match"
|
||||
pure (t1, T.Case c e')
|
||||
) cs of
|
||||
Right cs -> pure (T.ECase t1 e' cs,t1)
|
||||
Left e -> throwError e
|
||||
|
||||
-- | Check infered type matches the supplied type.
|
||||
check :: Cxt -> Exp -> Type -> Err T.Exp
|
||||
check cxt exp typ = case exp of
|
||||
EId x -> do
|
||||
t <- case lookupEnv x cxt of
|
||||
Nothing -> maybeToRightM
|
||||
("Unbound variable:" ++ printTree x)
|
||||
(lookupSig x cxt)
|
||||
Just t -> pure t
|
||||
unless (typeEq t typ) . throwError $ typeErr x typ t
|
||||
pure $ T.EId (x, t)
|
||||
|
||||
EInt i -> do
|
||||
unless (typeEq typ TInt) $ throwError $ typeErr i TInt typ
|
||||
pure $ T.EInt i
|
||||
|
||||
EApp e e1 -> do
|
||||
(e', t) <- infer cxt e
|
||||
case t of
|
||||
TFun t1 t2 -> do
|
||||
e1' <- check cxt e1 t1
|
||||
pure $ T.EApp t2 e' e1'
|
||||
_ -> throwError ("Not a function 2: " ++ printTree e)
|
||||
|
||||
EAdd e e1 -> do
|
||||
e' <- check cxt e T.TInt
|
||||
e1' <- check cxt e1 T.TInt
|
||||
pure $ T.EAdd T.TInt e' e1'
|
||||
|
||||
ESub e e1 -> do
|
||||
e' <- check cxt e T.TInt
|
||||
e1' <- check cxt e1 T.TInt
|
||||
pure $ T.ESub T.TInt e' e1'
|
||||
|
||||
EAbs x t e -> do
|
||||
(e', t_e) <- infer (insertEnv x t cxt) e
|
||||
let t1 = TFun t t_e
|
||||
unless (typeEq t1 typ) $ throwError "Wrong lamda type!"
|
||||
pure $ T.EAbs t1 (x, t) e'
|
||||
|
||||
ECase e cs t -> do
|
||||
(e',t1) <- infer cxt e
|
||||
unless (typeEq t t1) $
|
||||
throwError "Inferred type and type annotation doesn't match"
|
||||
case traverse (\(CaseMatch c e) -> do
|
||||
-- //TODO check c as well
|
||||
e' <- check cxt e t
|
||||
unless (typeEq t t1) $
|
||||
throwError "Inferred type and type annotation doesn't match"
|
||||
pure (t1, T.Case c e')
|
||||
) cs of
|
||||
Right cs -> pure $ T.ECase t1 e' cs
|
||||
Left e -> throwError e
|
||||
|
||||
ELet b e -> do
|
||||
let cxt' = insertBind b cxt
|
||||
b' <- checkBind cxt' b
|
||||
e' <- check cxt' e typ
|
||||
pure $ T.ELet b' e'
|
||||
|
||||
EAnn e t -> do
|
||||
unless (typeEq t typ) $
|
||||
throwError "Inferred type and type annotation doesn't match"
|
||||
check cxt e t
|
||||
|
||||
-- | Check if types are equivalent. Doesn't handle coercion or polymorphism.
|
||||
{- | Check if two types are considered equal
|
||||
For the purpose of the algorithm two polymorphic types are always considered
|
||||
equal
|
||||
-}
|
||||
typeEq :: Type -> Type -> Bool
|
||||
typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1
|
||||
typeEq t t1 = t == t1
|
||||
typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r'
|
||||
typeEq (TMono a) (TMono b) = a == b
|
||||
typeEq (TConstr (Constr name a)) (TConstr (Constr name' b)) =
|
||||
length a == length b
|
||||
&& name == name'
|
||||
&& and (zipWith typeEq a b)
|
||||
typeEq (TPol _) (TPol _) = True
|
||||
typeEq _ _ = False
|
||||
|
||||
-- | Partion type into types of parameters and return type.
|
||||
partitionType :: Int -- Number of parameters to apply
|
||||
-> Type
|
||||
-> ([Type], Type)
|
||||
partitionType = go []
|
||||
where
|
||||
go acc 0 t = (acc, t)
|
||||
go acc i t = case t of
|
||||
TFun t1 t2 -> go (snoc t1 acc) (i - 1) t2
|
||||
_ -> error "Number of parameters and type doesn't match"
|
||||
isMoreSpecificOrEq :: Type -> Type -> Bool
|
||||
isMoreSpecificOrEq _ (TPol _) = True
|
||||
isMoreSpecificOrEq (TArr a b) (TArr c d) =
|
||||
isMoreSpecificOrEq a c && isMoreSpecificOrEq b d
|
||||
isMoreSpecificOrEq (TConstr (Constr n1 ts1)) (TConstr (Constr n2 ts2)) =
|
||||
n1 == n2
|
||||
&& length ts1 == length ts2
|
||||
&& and (zipWith isMoreSpecificOrEq ts1 ts2)
|
||||
isMoreSpecificOrEq a b = a == b
|
||||
|
||||
insertBind :: Bind -> Cxt -> Cxt
|
||||
insertBind (Bind n t _ _ _) = insertEnv n t
|
||||
isPoly :: Type -> Bool
|
||||
isPoly (TPol _) = True
|
||||
isPoly _ = False
|
||||
|
||||
lookupEnv :: Ident -> Cxt -> Maybe Type
|
||||
lookupEnv x = Map.lookup x . env
|
||||
inferExp :: Exp -> Infer (Type, T.Exp)
|
||||
inferExp e = do
|
||||
(s, t, e') <- algoW e
|
||||
let subbed = apply s t
|
||||
return (subbed, replace subbed e')
|
||||
|
||||
insertEnv :: Ident -> Type -> Cxt -> Cxt
|
||||
insertEnv x t cxt = cxt { env = Map.insert x t cxt.env }
|
||||
replace :: Type -> T.Exp -> T.Exp
|
||||
replace t = \case
|
||||
T.ELit _ e -> T.ELit t e
|
||||
T.EId (n, _) -> T.EId (n, t)
|
||||
T.EAbs _ name e -> T.EAbs t name e
|
||||
T.EApp _ e1 e2 -> T.EApp t e1 e2
|
||||
T.EAdd _ e1 e2 -> T.EAdd t e1 e2
|
||||
T.ESub _ e1 e2 -> T.ESub t e1 e2
|
||||
T.ELet (T.Bind (n, _) e1) e2 -> T.ELet (T.Bind (n, t) e1) e2
|
||||
T.ECase _ expr injs -> T.ECase t expr injs
|
||||
|
||||
lookupSig :: Ident -> Cxt -> Maybe Type
|
||||
lookupSig x = Map.lookup x . sig
|
||||
algoW :: Exp -> Infer (Subst, Type, T.Exp)
|
||||
algoW = \case
|
||||
-- \| TODO: More testing need to be done. Unsure of the correctness of this
|
||||
EAnn e t -> do
|
||||
(s1, t', e') <- algoW e
|
||||
unless
|
||||
(t `isMoreSpecificOrEq` t')
|
||||
( throwError $
|
||||
unwords
|
||||
[ "Annotated type:"
|
||||
, printTree t
|
||||
, "does not match inferred type:"
|
||||
, printTree t'
|
||||
]
|
||||
)
|
||||
applySt s1 $ do
|
||||
s2 <- unify t t'
|
||||
return (s2 `compose` s1, t, e')
|
||||
|
||||
typeErr :: Print a => a -> Type -> Type -> String
|
||||
typeErr p expected actual = render $ concatD
|
||||
[ doc $ showString "Wrong type:", prt 0 p , doc $ showString "\n"
|
||||
, doc $ showString "Expected:" , prt 0 expected, doc $ showString "\n"
|
||||
, doc $ showString "Actual: " , prt 0 actual
|
||||
]
|
||||
-- \| ------------------
|
||||
-- \| Γ ⊢ i : Int, ∅
|
||||
|
||||
ELit (LInt n) ->
|
||||
return (nullSubst, TMono "Int", T.ELit (TMono "Int") (LInt n))
|
||||
ELit a -> error $ "NOT IMPLEMENTED YET: ELit " ++ show a
|
||||
-- \| x : σ ∈ Γ τ = inst(σ)
|
||||
-- \| ----------------------
|
||||
-- \| Γ ⊢ x : τ, ∅
|
||||
|
||||
EId i -> do
|
||||
var <- asks vars
|
||||
case M.lookup i var of
|
||||
Just t -> inst t >>= \x -> return (nullSubst, x, T.EId (i, x))
|
||||
Nothing -> do
|
||||
sig <- gets sigs
|
||||
case M.lookup i sig of
|
||||
Just t -> return (nullSubst, t, T.EId (i, t))
|
||||
Nothing -> do
|
||||
constr <- gets constructors
|
||||
case M.lookup i constr of
|
||||
Just t -> return (nullSubst, t, T.EId (i, t))
|
||||
Nothing ->
|
||||
throwError $
|
||||
"Unbound variable: " ++ show i
|
||||
|
||||
-- \| τ = newvar Γ, x : τ ⊢ e : τ', S
|
||||
-- \| ---------------------------------
|
||||
-- \| Γ ⊢ w λx. e : Sτ → τ', S
|
||||
|
||||
EAbs name e -> do
|
||||
fr <- fresh
|
||||
withBinding name (Forall [] fr) $ do
|
||||
(s1, t', e') <- algoW e
|
||||
let varType = apply s1 fr
|
||||
let newArr = TArr varType t'
|
||||
return (s1, newArr, T.EAbs newArr (name, varType) e')
|
||||
|
||||
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
|
||||
-- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
|
||||
-- \| ------------------------------------------
|
||||
-- \| Γ ⊢ e₀ + e₁ : Int, S₃S₂S₁S₀
|
||||
-- This might be wrong
|
||||
|
||||
EAdd e0 e1 -> do
|
||||
(s1, t0, e0') <- algoW e0
|
||||
applySt s1 $ do
|
||||
(s2, t1, e1') <- algoW e1
|
||||
-- applySt s2 $ do
|
||||
s3 <- unify (apply s2 t0) (TMono "Int")
|
||||
s4 <- unify (apply s3 t1) (TMono "Int")
|
||||
return
|
||||
( s4 `compose` s3 `compose` s2 `compose` s1
|
||||
, TMono "Int"
|
||||
, T.EAdd (TMono "Int") e0' e1'
|
||||
)
|
||||
|
||||
ESub e0 e1 -> do
|
||||
(s1, t0, e0') <- algoW e0
|
||||
applySt s1 $ do
|
||||
(s2, t1, e1') <- algoW e1
|
||||
-- applySt s2 $ do
|
||||
s3 <- unify (apply s2 t0) (TMono "Int")
|
||||
s4 <- unify (apply s3 t1) (TMono "Int")
|
||||
return
|
||||
( s4 `compose` s3 `compose` s2 `compose` s1
|
||||
, TMono "Int"
|
||||
, T.ESub (TMono "Int") e0' e1'
|
||||
)
|
||||
|
||||
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1
|
||||
-- \| τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ')
|
||||
-- \| --------------------------------------
|
||||
-- \| Γ ⊢ e₀ e₁ : S₂τ', S₂S₁S₀
|
||||
|
||||
EApp e0 e1 -> do
|
||||
fr <- fresh
|
||||
(s0, t0, e0') <- algoW e0
|
||||
applySt s0 $ do
|
||||
(s1, t1, e1') <- algoW e1
|
||||
-- applySt s1 $ do
|
||||
s2 <- unify (apply s1 t0) (TArr t1 fr)
|
||||
let t = apply s2 fr
|
||||
return (s2 `compose` s1 `compose` s0, t, T.EApp t e0' e1')
|
||||
|
||||
-- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁
|
||||
-- \| ----------------------------------------------
|
||||
-- \| Γ ⊢ let x = e₀ in e₁ : τ', S₁S₀
|
||||
|
||||
-- The bar over S₀ and Γ means "generalize"
|
||||
|
||||
ELet name e0 e1 -> do
|
||||
(s1, t1, e0') <- algoW e0
|
||||
env <- asks vars
|
||||
let t' = generalize (apply s1 env) t1
|
||||
withBinding name t' $ do
|
||||
(s2, t2, e1') <- algoW e1
|
||||
return (s2 `compose` s1, t2, T.ELet (T.Bind (name, t2) e0') e1')
|
||||
ECase caseExpr injs -> do
|
||||
(_, t0, e0') <- algoW caseExpr
|
||||
(injs', ts) <- mapAndUnzipM (checkInj t0) injs
|
||||
case ts of
|
||||
[] -> throwError "Case expression missing any matches"
|
||||
ts -> do
|
||||
unified <- zipWithM unify ts (tail ts)
|
||||
let unified' = foldl' compose mempty unified
|
||||
let typ = apply unified' (head ts)
|
||||
return (unified', typ, T.ECase typ e0' injs')
|
||||
|
||||
-- | Unify two types producing a new substitution
|
||||
unify :: Type -> Type -> Infer Subst
|
||||
unify t0 t1 = do
|
||||
trace ("t0: " ++ show t0) return ()
|
||||
trace ("t1: " ++ show t1) return ()
|
||||
case (t0, t1) of
|
||||
(TArr a b, TArr c d) -> do
|
||||
s1 <- unify a c
|
||||
s2 <- unify (apply s1 b) (apply s1 d)
|
||||
return $ s1 `compose` s2
|
||||
(TPol a, b) -> occurs a b
|
||||
(a, TPol b) -> occurs b a
|
||||
(TMono a, TMono b) ->
|
||||
if a == b then return M.empty else throwError "Types do not unify"
|
||||
-- \| TODO: Figure out a cleaner way to express the same thing
|
||||
(TConstr (Constr name t), TConstr (Constr name' t')) ->
|
||||
if name == name' && length t == length t'
|
||||
then do
|
||||
xs <- zipWithM unify t t'
|
||||
return $ foldr compose nullSubst xs
|
||||
else
|
||||
throwError $
|
||||
unwords
|
||||
[ "Type constructor:"
|
||||
, printTree name
|
||||
, "(" ++ printTree t ++ ")"
|
||||
, "does not match with:"
|
||||
, printTree name'
|
||||
, "(" ++ printTree t' ++ ")"
|
||||
]
|
||||
(a, b) ->
|
||||
throwError . unwords $
|
||||
[ "Type:"
|
||||
, printTree a
|
||||
, "can't be unified with:"
|
||||
, printTree b
|
||||
]
|
||||
|
||||
{- | Check if a type is contained in another type.
|
||||
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
|
||||
such that these are equal
|
||||
-}
|
||||
occurs :: Ident -> Type -> Infer Subst
|
||||
occurs _ (TPol _) = return nullSubst
|
||||
occurs i t =
|
||||
if S.member i (free t)
|
||||
then
|
||||
throwError $
|
||||
unwords
|
||||
[ "Occurs check failed, can't unify"
|
||||
, printTree (TPol i)
|
||||
, "with"
|
||||
, printTree t
|
||||
]
|
||||
else return $ M.singleton i t
|
||||
|
||||
-- | Generalize a type over all free variables in the substitution set
|
||||
generalize :: Map Ident Poly -> Type -> Poly
|
||||
generalize env t = Forall (S.toList $ free t S.\\ free env) t
|
||||
|
||||
{- | Instantiate a polymorphic type. The free type variables are substituted
|
||||
with fresh ones.
|
||||
-}
|
||||
inst :: Poly -> Infer Type
|
||||
inst (Forall xs t) = do
|
||||
xs' <- mapM (const fresh) xs
|
||||
let s = M.fromList $ zip xs xs'
|
||||
return $ apply s t
|
||||
|
||||
-- | Compose two substitution sets
|
||||
compose :: Subst -> Subst -> Subst
|
||||
compose m1 m2 = M.map (apply m1) m2 `M.union` m1
|
||||
|
||||
-- | A class representing free variables functions
|
||||
class FreeVars t where
|
||||
-- | Get all free variables from t
|
||||
free :: t -> Set Ident
|
||||
|
||||
-- | Apply a substitution to t
|
||||
apply :: Subst -> t -> t
|
||||
|
||||
instance FreeVars Type where
|
||||
free :: Type -> Set Ident
|
||||
free (TPol a) = S.singleton a
|
||||
free (TMono _) = mempty
|
||||
free (TArr a b) = free a `S.union` free b
|
||||
-- \| Not guaranteed to be correct
|
||||
free (TConstr (Constr _ a)) =
|
||||
foldl' (\acc x -> free x `S.union` acc) S.empty a
|
||||
|
||||
apply :: Subst -> Type -> Type
|
||||
apply sub t = do
|
||||
case t of
|
||||
TMono a -> TMono a
|
||||
TPol a -> case M.lookup a sub of
|
||||
Nothing -> TPol a
|
||||
Just t -> t
|
||||
TArr a b -> TArr (apply sub a) (apply sub b)
|
||||
TConstr (Constr name a) -> TConstr (Constr name (map (apply sub) a))
|
||||
|
||||
instance FreeVars Poly where
|
||||
free :: Poly -> Set Ident
|
||||
free (Forall xs t) = free t S.\\ S.fromList xs
|
||||
apply :: Subst -> Poly -> Poly
|
||||
apply s (Forall xs t) = Forall xs (apply (foldr M.delete s xs) t)
|
||||
|
||||
instance FreeVars (Map Ident Poly) where
|
||||
free :: Map Ident Poly -> Set Ident
|
||||
free m = foldl' S.union S.empty (map free $ M.elems m)
|
||||
apply :: Subst -> Map Ident Poly -> Map Ident Poly
|
||||
apply s = M.map (apply s)
|
||||
|
||||
-- | Apply substitutions to the environment.
|
||||
applySt :: Subst -> Infer a -> Infer a
|
||||
applySt s = local (\st -> st{vars = apply s (vars st)})
|
||||
|
||||
-- | Represents the empty substition set
|
||||
nullSubst :: Subst
|
||||
nullSubst = M.empty
|
||||
|
||||
-- | Generate a new fresh variable and increment the state counter
|
||||
fresh :: Infer Type
|
||||
fresh = do
|
||||
n <- gets count
|
||||
modify (\st -> st{count = n + 1})
|
||||
return . TPol . Ident $ show n
|
||||
|
||||
-- | Run the monadic action with an additional binding
|
||||
withBinding :: (Monad m, MonadReader Ctx m) => Ident -> Poly -> m a -> m a
|
||||
withBinding i p = local (\st -> st{vars = M.insert i p (vars st)})
|
||||
|
||||
-- | Insert a function signature into the environment
|
||||
insertSig :: Ident -> Type -> Infer ()
|
||||
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
|
||||
|
||||
-- | Insert a constructor with its data type
|
||||
insertConstr :: Ident -> Type -> Infer ()
|
||||
insertConstr i t =
|
||||
modify (\st -> st{constructors = M.insert i t (constructors st)})
|
||||
|
||||
-------- PATTERN MATCHING ---------
|
||||
|
||||
-- "case expr of", the type of 'expr' is caseType
|
||||
checkInj :: Type -> Inj -> Infer (T.Inj, Type)
|
||||
checkInj caseType (Inj it expr) = do
|
||||
(args, t') <- initType caseType it
|
||||
(_, t, e') <- local (\st -> st{vars = args `M.union` vars st}) (algoW expr)
|
||||
return (T.Inj (it, t') e', t)
|
||||
|
||||
initType :: Type -> Init -> Infer (Map Ident Poly, Type)
|
||||
initType expected = \case
|
||||
InitLit lit ->
|
||||
let returnType = litType lit
|
||||
in if expected == returnType
|
||||
then return (mempty, expected)
|
||||
else
|
||||
throwError $
|
||||
unwords
|
||||
[ "Inferred type"
|
||||
, printTree returnType
|
||||
, "does not match expected type:"
|
||||
, printTree expected
|
||||
]
|
||||
InitConstr c args -> do
|
||||
st <- gets constructors
|
||||
case M.lookup c st of
|
||||
Nothing ->
|
||||
throwError $
|
||||
unwords
|
||||
[ "Constructor:"
|
||||
, printTree c
|
||||
, "does not exist"
|
||||
]
|
||||
Just t -> do
|
||||
let flat = flattenType t
|
||||
let returnType = last flat
|
||||
case ( length (init flat) == length args
|
||||
, returnType `isMoreSpecificOrEq` expected
|
||||
) of
|
||||
(True, True) ->
|
||||
return
|
||||
( M.fromList $ zip args (map (Forall []) flat)
|
||||
, expected
|
||||
)
|
||||
(False, _) ->
|
||||
throwError $
|
||||
"Can't partially match on the constructor: "
|
||||
++ printTree c
|
||||
(_, False) ->
|
||||
throwError $
|
||||
unwords
|
||||
[ "Inferred type"
|
||||
, printTree returnType
|
||||
, "does not match expected type:"
|
||||
, printTree expected
|
||||
]
|
||||
InitCatch -> return (mempty, expected)
|
||||
|
||||
flattenType :: Type -> [Type]
|
||||
flattenType (TArr a b) = flattenType a ++ flattenType b
|
||||
flattenType a = [a]
|
||||
|
||||
litType :: Literal -> Type
|
||||
litType (LInt _) = TMono "Int"
|
||||
|
|
|
|||
|
|
@ -1,139 +1,184 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
|
||||
module TypeChecker.TypeCheckerIr
|
||||
( module Grammar.Abs
|
||||
, module TypeChecker.TypeCheckerIr
|
||||
) where
|
||||
module TypeChecker.TypeCheckerIr where
|
||||
|
||||
import Grammar.Abs (Ident (..), Type (..))
|
||||
import qualified Grammar.Abs as GA
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Reader
|
||||
import Control.Monad.State
|
||||
import Data.Functor.Identity (Identity)
|
||||
import Data.Map (Map)
|
||||
import Grammar.Abs (Data (..), Ident (..), Init (..),
|
||||
Literal (..), Type (..))
|
||||
import Grammar.Print
|
||||
import Prelude
|
||||
import qualified Prelude as C (Eq, Ord, Read, Show)
|
||||
import qualified Prelude as C (Eq, Ord, Read, Show)
|
||||
|
||||
newtype Program = Program [Bind]
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||
-- | A data type representing type variables
|
||||
data Poly = Forall [Ident] Type
|
||||
deriving (Show)
|
||||
|
||||
newtype Ctx = Ctx {vars :: Map Ident Poly}
|
||||
|
||||
data Env = Env
|
||||
{ count :: Int
|
||||
, sigs :: Map Ident Type
|
||||
, constructors :: Map Ident Type
|
||||
}
|
||||
|
||||
type Error = String
|
||||
type Subst = Map Ident Type
|
||||
|
||||
type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))
|
||||
|
||||
newtype Program = Program [Def]
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||
|
||||
data Exp
|
||||
= EId Id
|
||||
| EInt Integer
|
||||
= EId Id
|
||||
| ELit Type Literal
|
||||
| ELet Bind Exp
|
||||
| EApp Type Exp Exp
|
||||
| EAdd Type Exp Exp
|
||||
| ESub Type Exp Exp
|
||||
| EAbs Type Id Exp
|
||||
| ECase Type Exp [(Type, Case)]
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||
| EAbs Type Id Exp
|
||||
| ECase Type Exp [Inj]
|
||||
deriving (C.Eq, C.Ord, C.Read, C.Show)
|
||||
|
||||
data Case = Case GA.Case Exp
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||
data Inj = Inj (Init, Type) Exp
|
||||
deriving (C.Eq, C.Ord, C.Read, C.Show)
|
||||
|
||||
data Def = DBind Bind | DData Data
|
||||
deriving (C.Eq, C.Ord, C.Read, C.Show)
|
||||
|
||||
type Id = (Ident, Type)
|
||||
|
||||
data Bind = Bind Id [Id] Exp | DataStructure Ident [(Ident, [Type])]
|
||||
data Bind = Bind Id Exp
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||
|
||||
instance Print [Def] where
|
||||
prt _ [] = concatD []
|
||||
prt _ (x : xs) = concatD [prt 0 x, doc (showString "\n"), prt 0 xs]
|
||||
|
||||
instance Print Def where
|
||||
prt i (DBind bind) = prt i bind
|
||||
prt i (DData d) = prt i d
|
||||
|
||||
instance Print Program where
|
||||
prt i (Program sc) = prPrec i 0 $ prt 0 sc
|
||||
|
||||
instance Print Bind where
|
||||
prt i (Bind name@(n, _) parms rhs) = prPrec i 0 $ concatD
|
||||
[ prtId 0 name
|
||||
, doc $ showString ";"
|
||||
, prt 0 n
|
||||
, prtIdPs 0 parms
|
||||
, doc $ showString "="
|
||||
, prt 0 rhs
|
||||
]
|
||||
prt i (DataStructure (Ident n) xs) = prPrec i 0 $ concatD
|
||||
[ prt 0 n
|
||||
, doc $ showString "{"
|
||||
, doc . showString . show $ xs
|
||||
, doc $ showString "}"
|
||||
]
|
||||
prt i (Bind (t, name) rhs) =
|
||||
prPrec i 0 $
|
||||
concatD
|
||||
[ prt 0 name
|
||||
, doc $ showString ":"
|
||||
, prt 0 t
|
||||
, doc $ showString "\n"
|
||||
, prt 0 name
|
||||
, doc $ showString "="
|
||||
, prt 0 rhs
|
||||
]
|
||||
|
||||
instance Print [Bind] where
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), doc (showString "\n"), prt 0 xs]
|
||||
|
||||
prtIdPs :: Int -> [Id] -> Doc
|
||||
prtIdPs i = prPrec i 0 . concatD . map (prtIdP i)
|
||||
|
||||
prtId :: Int -> Id -> Doc
|
||||
prtId i (name, t) = prPrec i 0 $ concatD
|
||||
[ prt 0 name
|
||||
, doc $ showString ":"
|
||||
, prt 0 t
|
||||
]
|
||||
prtId i (name, t) =
|
||||
prPrec i 0 $
|
||||
concatD
|
||||
[ prt 0 name
|
||||
, doc $ showString ":"
|
||||
, prt 0 t
|
||||
]
|
||||
|
||||
prtIdP :: Int -> Id -> Doc
|
||||
prtIdP i (name, t) = prPrec i 0 $ concatD
|
||||
[ doc $ showString "("
|
||||
, prt 0 name
|
||||
, doc $ showString ":"
|
||||
, prt 0 t
|
||||
, doc $ showString ")"
|
||||
]
|
||||
|
||||
prtIdP i (name, t) =
|
||||
prPrec i 0 $
|
||||
concatD
|
||||
[ doc $ showString "("
|
||||
, prt 0 name
|
||||
, doc $ showString ":"
|
||||
, prt 0 t
|
||||
, doc $ showString ")"
|
||||
]
|
||||
|
||||
instance Print Exp where
|
||||
prt i = \case
|
||||
EId n -> prPrec i 3 $ concatD [prtIdP 0 n]
|
||||
EInt i1 -> prPrec i 3 $ concatD [prt 0 i1]
|
||||
ELet bs e -> prPrec i 3 $ concatD
|
||||
[ doc $ showString "let"
|
||||
, prt 0 bs
|
||||
, doc $ showString "in"
|
||||
, prt 0 e
|
||||
]
|
||||
EApp t e1 e2 -> prPrec i 2 $ concatD
|
||||
[ doc $ showString "@"
|
||||
, prt 0 t
|
||||
, prt 2 e1
|
||||
, prt 3 e2
|
||||
]
|
||||
EAdd t e1 e2 -> prPrec i 1 $ concatD
|
||||
[ doc $ showString "@"
|
||||
, prt 0 t
|
||||
, prt 1 e1
|
||||
, doc $ showString "+"
|
||||
, prt 2 e2
|
||||
]
|
||||
ESub t e1 e2 -> prPrec i 1 $ concatD
|
||||
[ doc $ showString "@"
|
||||
, prt 0 t
|
||||
, prt 1 e1
|
||||
, doc $ showString "-"
|
||||
, prt 2 e2
|
||||
]
|
||||
EAbs t n e -> prPrec i 0 $ concatD
|
||||
[ doc $ showString "@"
|
||||
, prt 0 t
|
||||
, doc $ showString "\\"
|
||||
, prtIdP 0 n
|
||||
, doc $ showString "."
|
||||
, prt 0 e
|
||||
]
|
||||
ECase t e cs -> prPrec i 0 $ concatD
|
||||
[ doc $ showString "@"
|
||||
, prt 0 t
|
||||
, doc $ showString "("
|
||||
, prt 0 e
|
||||
, doc $ showString ")"
|
||||
, prPrec i 0 $ concatD . printCases $ cs
|
||||
]
|
||||
prt i = \case
|
||||
EId n -> prPrec i 3 $ concatD [prtId 0 n, doc $ showString "\n"]
|
||||
ELit _ (LInt i1) -> prPrec i 3 $ concatD [prt 0 i1, doc $ showString "\n"]
|
||||
ELet bs e ->
|
||||
prPrec i 3 $
|
||||
concatD
|
||||
[ doc $ showString "let"
|
||||
, prt 0 bs
|
||||
, doc $ showString "in"
|
||||
, prt 0 e
|
||||
, doc $ showString "\n"
|
||||
]
|
||||
EApp _ e1 e2 ->
|
||||
prPrec i 2 $
|
||||
concatD
|
||||
[ prt 2 e1
|
||||
, prt 3 e2
|
||||
]
|
||||
EAdd t e1 e2 ->
|
||||
prPrec i 1 $
|
||||
concatD
|
||||
[ doc $ showString "@"
|
||||
, prt 0 t
|
||||
, prt 1 e1
|
||||
, doc $ showString "+"
|
||||
, prt 2 e2
|
||||
, doc $ showString "\n"
|
||||
]
|
||||
ESub t e1 e2 ->
|
||||
prPrec i 1 $
|
||||
concatD
|
||||
[ doc $ showString "@"
|
||||
, prt 0 t
|
||||
, prt 1 e1
|
||||
, doc $ showString "-"
|
||||
, prt 2 e2
|
||||
, doc $ showString "\n"
|
||||
]
|
||||
EAbs t n e ->
|
||||
prPrec i 0 $
|
||||
concatD
|
||||
[ doc $ showString "@"
|
||||
, prt 0 t
|
||||
, doc $ showString "\\"
|
||||
, prtId 0 n
|
||||
, doc $ showString "."
|
||||
, prt 0 e
|
||||
, doc $ showString "\n"
|
||||
]
|
||||
ECase t exp injs ->
|
||||
prPrec
|
||||
i
|
||||
0
|
||||
( concatD
|
||||
[ doc (showString "case")
|
||||
, prt 0 exp
|
||||
, doc (showString "of")
|
||||
, doc (showString "{")
|
||||
, prt 0 injs
|
||||
, doc (showString "}")
|
||||
, doc (showString ":")
|
||||
, prt 0 t
|
||||
, doc $ showString "\n"
|
||||
]
|
||||
)
|
||||
|
||||
where
|
||||
printCases :: [(Type, Case)] -> [Doc]
|
||||
printCases [] = []
|
||||
printCases ((t, Case c e):xs) = concatD
|
||||
[ doc $ showString "@"
|
||||
, prt 0 t
|
||||
, doc $ showString "("
|
||||
, doc . showString . show $ c
|
||||
, doc $ showString ")"
|
||||
, doc $ showString "=>"
|
||||
, prt 0 e
|
||||
, doc $ showString "\n"
|
||||
] : printCases xs
|
||||
instance Print Inj where
|
||||
prt i = \case
|
||||
Inj (init, t) exp -> prPrec i 0 (concatD [prt 0 init, doc (showString ":"), prt 0 t, doc (showString "=>"), prt 0 exp])
|
||||
|
||||
instance Print [Inj] where
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue