Merge branch 'prep-tc-martin' of github.com:bachelor-group-66-systemf/language into prep-tc-martin
This commit is contained in:
commit
a4c12ede79
8 changed files with 590 additions and 640 deletions
20
Grammar.cf
20
Grammar.cf
|
|
@ -1,24 +1,20 @@
|
||||||
Program. Program ::= [Bind];
|
Program. Program ::= [Bind];
|
||||||
|
|
||||||
|
EId. Exp3 ::= Ident;
|
||||||
EId. Exp3 ::= Ident;
|
EInt. Exp3 ::= Integer;
|
||||||
EInt. Exp3 ::= Integer;
|
EAnn. Exp3 ::= "(" Exp ":" Type ")";
|
||||||
ELet. Exp3 ::= "let" [Bind] "in" Exp;
|
ELet. Exp3 ::= "let" Bind "in" Exp;
|
||||||
EApp. Exp2 ::= Exp2 Exp3;
|
EApp. Exp2 ::= Exp2 Exp3;
|
||||||
EAdd. Exp1 ::= Exp1 "+" Exp2;
|
EAdd. Exp1 ::= Exp1 "+" Exp2;
|
||||||
EAbs. Exp ::= "\\" Ident ":" Type "." Exp;
|
EAbs. Exp ::= "\\" Ident ":" Type "." Exp;
|
||||||
EAnn. Exp3 ::= "(" Exp ":" Type ")";
|
|
||||||
|
|
||||||
ECase. Exp ::= "case" Exp "of" "{" [CaseMatch] "}";
|
ECase. Exp ::= "case" Exp "of" "{" [CaseMatch] "}";
|
||||||
--
|
|
||||||
CaseMatch. CaseMatch ::= Case "=>" Exp ;
|
CaseMatch. CaseMatch ::= Case "=>" Exp ;
|
||||||
separator CaseMatch ",";
|
separator CaseMatch ",";
|
||||||
--terminator CaseMatch ".";
|
|
||||||
|
|
||||||
CInt. Case ::= Integer ;
|
CInt. Case ::= Integer ;
|
||||||
|
|
||||||
Bind. Bind ::= Ident ":" Type ";"
|
Bind. Bind ::= Ident ":" Type ";"
|
||||||
Ident [Ident] "=" Exp ;
|
Ident [Ident] "=" Exp;
|
||||||
|
|
||||||
separator Bind ";";
|
separator Bind ";";
|
||||||
separator Ident "";
|
separator Ident "";
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,21 @@
|
||||||
|
|
||||||
--tripplemagic : Int -> Int -> Int -> Int;
|
-- tripplemagic : Int -> Int -> Int -> Int;
|
||||||
--tripplemagic x y z = ((\x:Int. x+x) x) + y + z;
|
-- tripplemagic x y z = ((\x:Int. x+x) x) + y + z;
|
||||||
--main : Int;
|
-- main : Int;
|
||||||
--main = tripplemagic ((\x:Int. x+x+3) ((\x:Int. x) 2)) 5 3
|
-- main = tripplemagic ((\x:Int. x+x+3) ((\x:Int. x) 2)) 5 3
|
||||||
|
-- answer: 22
|
||||||
--apply : (Int -> Int) -> Int -> Int;
|
|
||||||
--apply f x = f x;
|
|
||||||
--
|
|
||||||
--main : Int;
|
|
||||||
--main = (\x : Int . x + 5) 5
|
|
||||||
|
|
||||||
|
-- apply : (Int -> Int) -> Int -> Int;
|
||||||
|
-- apply f x = f x;
|
||||||
|
-- main : Int;
|
||||||
|
-- main = apply (\x : Int . x + 5) 5
|
||||||
|
-- answer: 10
|
||||||
|
|
||||||
apply : (Int -> Int -> Int) -> Int -> Int -> Int;
|
apply : (Int -> Int -> Int) -> Int -> Int -> Int;
|
||||||
apply f x y = f x y;
|
apply f x y = f x y;
|
||||||
krimp: Int -> Int -> Int;
|
krimp: Int -> Int -> Int;
|
||||||
krimp x y = x + y;
|
krimp x y = x + y;
|
||||||
main : Int;
|
main : Int;
|
||||||
main = apply (krimp) 2 3;--apply (\y: Int . (\x: Int . x + y + 2)) 5 2;
|
main = apply (krimp) 2 3;
|
||||||
|
-- answer: 5
|
||||||
|
|
||||||
|
|
|
||||||
517
src/Compiler.hs
517
src/Compiler.hs
|
|
@ -3,22 +3,21 @@
|
||||||
|
|
||||||
module Compiler (compile) where
|
module Compiler (compile) where
|
||||||
|
|
||||||
import Control.Monad.State (StateT, execStateT, gets, modify)
|
import Auxiliary (snoc)
|
||||||
import Data.List.Extra (trim)
|
import Control.Monad.State (StateT, execStateT, gets, modify)
|
||||||
import Data.Map (Map)
|
--import Data.List.Extra (trim)
|
||||||
import qualified Data.Map as Map
|
import Data.Map (Map)
|
||||||
import Data.Tuple.Extra (second)
|
import qualified Data.Map as Map
|
||||||
import Grammar.ErrM (Err)
|
import Data.Tuple.Extra (dupe, first, second)
|
||||||
import Grammar.Print (printTree)
|
import Grammar.ErrM (Err)
|
||||||
import LlvmIr (LLVMComp (..), LLVMIr (..),
|
import LlvmIr (LLVMComp (..), LLVMIr (..), LLVMType (..),
|
||||||
LLVMType (..), LLVMValue (..),
|
LLVMValue (..), Visibility (..),
|
||||||
Visibility (..), llvmIrToString)
|
llvmIrToString)
|
||||||
import System.IO (stdin)
|
--import System.Process.Extra (readCreateProcess, shell)
|
||||||
import System.Process.Extra (CreateProcess (std_in),
|
import TypeChecker (partitionType)
|
||||||
StdStream (CreatePipe), createProcess,
|
import TypeCheckerIr (Bind (..), CLit (CInt, CatchAll),
|
||||||
readCreateProcess, shell)
|
Case (..), Exp (..), Id, Ident (..),
|
||||||
import TypeChecker (partitionType)
|
Program (..), Type (TFun, TInt))
|
||||||
import TypeCheckerIr
|
|
||||||
|
|
||||||
-- | The record used as the code generator state
|
-- | The record used as the code generator state
|
||||||
data CodeGenerator = CodeGenerator
|
data CodeGenerator = CodeGenerator
|
||||||
|
|
@ -38,11 +37,11 @@ data FunctionInfo = FunctionInfo
|
||||||
|
|
||||||
-- | Adds a instruction to the CodeGenerator state
|
-- | Adds a instruction to the CodeGenerator state
|
||||||
emit :: LLVMIr -> CompilerState ()
|
emit :: LLVMIr -> CompilerState ()
|
||||||
emit l = modify (\t -> t{instructions = instructions t ++ [l]})
|
emit l = modify $ \t -> t { instructions = Auxiliary.snoc l $ instructions t }
|
||||||
|
|
||||||
-- | Increases the variable counter in the CodeGenerator state
|
-- | Increases the variable counter in the CodeGenerator state
|
||||||
increaseVarCount :: CompilerState ()
|
increaseVarCount :: CompilerState ()
|
||||||
increaseVarCount = modify (\t -> t{variableCount = variableCount t + 1})
|
increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 }
|
||||||
|
|
||||||
-- | Returns the variable count from the CodeGenerator state
|
-- | Returns the variable count from the CodeGenerator state
|
||||||
getVarCount :: CompilerState Integer
|
getVarCount :: CompilerState Integer
|
||||||
|
|
@ -58,280 +57,272 @@ getNewLabel = do
|
||||||
modify (\t -> t{labelCount = labelCount t + 1})
|
modify (\t -> t{labelCount = labelCount t + 1})
|
||||||
gets labelCount
|
gets labelCount
|
||||||
|
|
||||||
{- | Produces a map of functions infos from a list of binds,
|
-- | Produces a map of functions infos from a list of binds,
|
||||||
which contains useful data for code generation.
|
-- which contains useful data for code generation.
|
||||||
-}
|
|
||||||
getFunctions :: [Bind] -> Map Id FunctionInfo
|
getFunctions :: [Bind] -> Map Id FunctionInfo
|
||||||
getFunctions xs =
|
getFunctions bs = Map.fromList $ map go bs
|
||||||
Map.fromList $
|
where
|
||||||
map
|
go (Bind id args _) =
|
||||||
( \(Bind id args _) ->
|
(id, FunctionInfo { numArgs=length args, arguments=args })
|
||||||
( id
|
|
||||||
, FunctionInfo
|
|
||||||
{ numArgs = length args
|
|
||||||
, arguments = args
|
|
||||||
}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
xs
|
|
||||||
|
|
||||||
run :: Err String -> IO ()
|
|
||||||
run s = do
|
|
||||||
let s' = case s of
|
initCodeGenerator :: [Bind] -> CodeGenerator
|
||||||
Right s -> s
|
initCodeGenerator scs = CodeGenerator { instructions = defaultStart
|
||||||
Left _ -> error "yo"
|
, functions = getFunctions scs
|
||||||
writeFile "llvm.ll" s'
|
, variableCount = 0
|
||||||
putStrLn . trim =<< readCreateProcess (shell "lli") s'
|
, labelCount = 0
|
||||||
test :: Integer -> Program
|
}
|
||||||
test v = Program [
|
|
||||||
Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] (
|
--run :: Err String -> IO ()
|
||||||
ECased (EId ("x", TInt)) [
|
--run s = do
|
||||||
Case (CInt 0) (EInt 0),
|
-- let s' = case s of
|
||||||
Case (CInt 1) (EInt 1),
|
-- Right s -> s
|
||||||
Case CatchAll (EAdd TInt
|
-- Left _ -> error "yo"
|
||||||
(EApp TInt (EId (Ident "fibonacci", TInt)) (
|
-- writeFile "llvm.ll" s'
|
||||||
EAdd TInt (EId (Ident "x", TInt))
|
-- putStrLn . trim =<< readCreateProcess (shell "lli") s'
|
||||||
(EInt (fromIntegral ((maxBound :: Int) * 2)))
|
--
|
||||||
))
|
--test :: Integer -> Program
|
||||||
(EApp TInt (EId (Ident "fibonacci", TInt)) (
|
--test v = Program [
|
||||||
EAdd TInt (EId (Ident "x", TInt))
|
-- Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] (
|
||||||
(EInt (fromIntegral ((maxBound :: Int) * 2 + 1)))
|
-- ECased (EId ("x", TInt)) [
|
||||||
))
|
-- Case (CInt 0) (EInt 0),
|
||||||
)
|
-- Case (CInt 1) (EInt 1),
|
||||||
]
|
-- Case CatchAll (EAdd TInt
|
||||||
),
|
-- (EApp TInt (EId (Ident "fibonacci", TInt)) (
|
||||||
Bind (Ident "main",TInt) [] (
|
-- EAdd TInt (EId (Ident "x", TInt))
|
||||||
EApp TInt (EId (Ident "fibonacci", TInt)) (EInt v) -- (EInt 92)
|
-- (EInt (fromIntegral ((maxBound :: Int) * 2)))
|
||||||
)
|
-- ))
|
||||||
]
|
-- (EApp TInt (EId (Ident "fibonacci", TInt)) (
|
||||||
|
-- EAdd TInt (EId (Ident "x", TInt))
|
||||||
|
-- (EInt (fromIntegral ((maxBound :: Int) * 2 + 1)))
|
||||||
|
-- ))
|
||||||
|
-- )
|
||||||
|
-- ]
|
||||||
|
-- ),
|
||||||
|
-- Bind (Ident "main",TInt) [] (
|
||||||
|
-- EApp TInt (EId (Ident "fibonacci", TInt)) (EInt v) -- (EInt 92)
|
||||||
|
-- )
|
||||||
|
-- ]
|
||||||
|
|
||||||
{- | Compiles an AST and produces a LLVM Ir string.
|
{- | Compiles an AST and produces a LLVM Ir string.
|
||||||
An easy way to actually "compile" this output is to
|
An easy way to actually "compile" this output is to
|
||||||
Simply pipe it to LLI
|
Simply pipe it to LLI
|
||||||
-}
|
-}
|
||||||
compile :: Program -> Err String
|
compile :: Program -> Err String
|
||||||
compile (Program prg) = do
|
compile (Program scs) = do
|
||||||
let s =
|
let codegen = initCodeGenerator scs
|
||||||
CodeGenerator
|
llvmIrToString . instructions <$> execStateT (compileScs scs) codegen
|
||||||
{ instructions = defaultStart
|
|
||||||
, functions = getFunctions prg
|
compileScs :: [Bind] -> CompilerState ()
|
||||||
, variableCount = 0
|
compileScs [] = pure ()
|
||||||
, labelCount = 0
|
compileScs (Bind (name, t) args exp : xs) = do
|
||||||
}
|
emit $ UnsafeRaw "\n"
|
||||||
ins <- instructions <$> execStateT (goDef prg) s
|
emit . Comment $ show name <> ": " <> show exp
|
||||||
pure $ llvmIrToString ins
|
let args' = map (second type2LlvmType) args
|
||||||
|
emit $ Define (type2LlvmType t_return) name args'
|
||||||
|
functionBody <- exprToValue exp
|
||||||
|
if name == "main"
|
||||||
|
then mapM_ emit $ mainContent functionBody
|
||||||
|
else emit $ Ret I64 functionBody
|
||||||
|
emit DefineEnd
|
||||||
|
modify $ \s -> s { variableCount = 0 }
|
||||||
|
compileScs xs
|
||||||
where
|
where
|
||||||
mainContent :: LLVMValue -> [LLVMIr]
|
t_return = snd $ partitionType (length args) t
|
||||||
mainContent var =
|
|
||||||
[ UnsafeRaw $
|
|
||||||
"call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n"
|
|
||||||
, -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
|
|
||||||
-- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2")
|
|
||||||
-- , Label (Ident "b_1")
|
|
||||||
-- , UnsafeRaw
|
|
||||||
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n"
|
|
||||||
-- , Br (Ident "end")
|
|
||||||
-- , Label (Ident "b_2")
|
|
||||||
-- , UnsafeRaw
|
|
||||||
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n"
|
|
||||||
-- , Br (Ident "end")
|
|
||||||
-- , Label (Ident "end")
|
|
||||||
Ret I64 (VInteger 0)
|
|
||||||
]
|
|
||||||
|
|
||||||
defaultStart :: [LLVMIr]
|
mainContent :: LLVMValue -> [LLVMIr]
|
||||||
defaultStart =
|
mainContent var =
|
||||||
[ Comment (show $ printTree (Program prg))
|
[ UnsafeRaw $
|
||||||
, UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
|
"call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n"
|
||||||
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
|
, -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
|
||||||
]
|
-- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2")
|
||||||
|
-- , Label (Ident "b_1")
|
||||||
|
-- , UnsafeRaw
|
||||||
|
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n"
|
||||||
|
-- , Br (Ident "end")
|
||||||
|
-- , Label (Ident "b_2")
|
||||||
|
-- , UnsafeRaw
|
||||||
|
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n"
|
||||||
|
-- , Br (Ident "end")
|
||||||
|
-- , Label (Ident "end")
|
||||||
|
Ret I64 (VInteger 0)
|
||||||
|
]
|
||||||
|
|
||||||
goDef :: [Bind] -> CompilerState ()
|
defaultStart :: [LLVMIr]
|
||||||
goDef [] = return ()
|
defaultStart = [ UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
|
||||||
goDef (Bind (name, t) args exp : xs) = do
|
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
|
||||||
emit $ UnsafeRaw "\n"
|
]
|
||||||
emit $ Comment $ show name <> ": " <> show exp
|
|
||||||
emit $ Define (I64{-type2LlvmType t_return-}) name (map (second type2LlvmType) args)
|
|
||||||
functionBody <- exprToValue exp
|
|
||||||
if name == "main"
|
|
||||||
then mapM_ emit (mainContent functionBody)
|
|
||||||
else emit $ Ret I64 functionBody
|
|
||||||
emit DefineEnd
|
|
||||||
modify (\s -> s{variableCount = 0})
|
|
||||||
goDef xs
|
|
||||||
where
|
|
||||||
t_return = snd $ partitionType (length args) t
|
|
||||||
|
|
||||||
go :: Exp -> CompilerState ()
|
compileExp :: Exp -> CompilerState ()
|
||||||
go (EInt int) = emitInt int
|
compileExp (EInt int) = emitInt int
|
||||||
go (EAdd t e1 e2) = emitAdd t e1 e2
|
compileExp (EAdd t e1 e2) = emitAdd t e1 e2
|
||||||
go (EId (name, _)) = emitIdent name
|
compileExp (EId (name, _)) = emitIdent name
|
||||||
go (EApp t e1 e2) = emitApp t e1 e2
|
compileExp (EApp t e1 e2) = emitApp t e1 e2
|
||||||
go (EAbs t ti e) = emitAbs t ti e
|
compileExp (EAbs t ti e) = emitAbs t ti e
|
||||||
go (ELet binds e) = emitLet binds e
|
compileExp (ELet binds e) = emitLet binds e
|
||||||
go (EAnn _ _) = emitEAnn
|
compileExp (ECased e c) = emitECased e c
|
||||||
go (ECased e c) = emitECased e c
|
|
||||||
-- go (ESub e1 e2) = emitSub e1 e2
|
-- go (ESub e1 e2) = emitSub e1 e2
|
||||||
-- go (EMul e1 e2) = emitMul e1 e2
|
-- go (EMul e1 e2) = emitMul e1 e2
|
||||||
-- go (EDiv e1 e2) = emitDiv e1 e2
|
-- go (EDiv e1 e2) = emitDiv e1 e2
|
||||||
-- go (EMod e1 e2) = emitMod e1 e2
|
-- go (EMod e1 e2) = emitMod e1 e2
|
||||||
|
|
||||||
--- aux functions ---
|
--- aux functions ---
|
||||||
emitECased :: Exp -> [Case] -> CompilerState ()
|
emitECased :: Exp -> [Case] -> CompilerState ()
|
||||||
emitECased e cs = do
|
emitECased e cs = do
|
||||||
vs <- exprToValue e
|
vs <- exprToValue e
|
||||||
lbl <- getNewLabel
|
lbl <- getNewLabel
|
||||||
let label = Ident $ "escape_" <> show lbl
|
let label = Ident $ "escape_" <> show lbl
|
||||||
stackPtr <- getNewVar
|
stackPtr <- getNewVar
|
||||||
emit $ SetVariable (Ident $ show stackPtr) (Alloca I64)
|
emit $ SetVariable (Ident $ show stackPtr) (Alloca I64)
|
||||||
mapM_ (emitCases label stackPtr vs) cs
|
mapM_ (emitCases label stackPtr vs) cs
|
||||||
emit $ Label label
|
emit $ Label label
|
||||||
res <- getNewVar
|
res <- getNewVar
|
||||||
emit $ SetVariable (Ident $ show res) (Load I64 Ptr (Ident $ show stackPtr))
|
emit $ SetVariable (Ident $ show res) (Load I64 Ptr (Ident $ show stackPtr))
|
||||||
where
|
where
|
||||||
emitCases :: Ident -> Integer -> LLVMValue -> Case -> CompilerState ()
|
emitCases :: Ident -> Integer -> LLVMValue -> Case -> CompilerState ()
|
||||||
emitCases label stackPtr vs (Case (CInt i) exp) = do
|
emitCases label stackPtr vs (Case (CInt i) exp) = do
|
||||||
ns <- getNewVar
|
ns <- getNewVar
|
||||||
lbl_fail <- getNewLabel
|
lbl_fail <- getNewLabel
|
||||||
lbl_succ <- getNewLabel
|
lbl_succ <- getNewLabel
|
||||||
let failed = Ident $ "failed_" <> show lbl_fail
|
let failed = Ident $ "failed_" <> show lbl_fail
|
||||||
let success = Ident $ "success_" <> show lbl_succ
|
let success = Ident $ "success_" <> show lbl_succ
|
||||||
emit $ SetVariable (Ident $ show ns) (Icmp LLEq I64 vs (VInteger i))
|
emit $ SetVariable (Ident $ show ns) (Icmp LLEq I64 vs (VInteger i))
|
||||||
emit $ BrCond (VIdent (Ident $ show ns) I64) success failed
|
emit $ BrCond (VIdent (Ident $ show ns) I64) success failed
|
||||||
emit $ Label success
|
emit $ Label success
|
||||||
val <- exprToValue exp
|
val <- exprToValue exp
|
||||||
emit $ Store I64 val Ptr (Ident . show $ stackPtr)
|
emit $ Store I64 val Ptr (Ident . show $ stackPtr)
|
||||||
emit $ Br label
|
emit $ Br label
|
||||||
emit $ Label failed
|
emit $ Label failed
|
||||||
emitCases label stackPtr _ (Case CatchAll exp) = do
|
emitCases label stackPtr _ (Case CatchAll exp) = do
|
||||||
val <- exprToValue exp
|
val <- exprToValue exp
|
||||||
emit $ Store I64 val Ptr (Ident . show $ stackPtr)
|
emit $ Store I64 val Ptr (Ident . show $ stackPtr)
|
||||||
emit $ Br label
|
emit $ Br label
|
||||||
|
|
||||||
|
|
||||||
emitEAnn :: CompilerState ()
|
emitAbs :: Type -> Id -> Exp -> CompilerState ()
|
||||||
emitEAnn = emit . UnsafeRaw $ "Annotated escaped previous stages"
|
emitAbs _t tid e = do
|
||||||
|
emit . Comment $
|
||||||
|
"Lambda escaped previous stages: \\" <> show tid <> " . " <> show e
|
||||||
|
emitLet :: Bind -> Exp -> CompilerState ()
|
||||||
|
emitLet xs e = do
|
||||||
|
emit $
|
||||||
|
Comment $
|
||||||
|
concat
|
||||||
|
[ "ELet ("
|
||||||
|
, show xs
|
||||||
|
, " = "
|
||||||
|
, show e
|
||||||
|
, ") is not implemented!"
|
||||||
|
]
|
||||||
|
|
||||||
emitAbs :: Type -> Id -> Exp -> CompilerState ()
|
emitApp :: Type -> Exp -> Exp -> CompilerState ()
|
||||||
emitAbs _t tid e = do
|
emitApp t e1 e2 = appEmitter t e1 e2 []
|
||||||
emit . Comment $
|
where
|
||||||
"Lambda escaped previous stages: \\" <> show tid <> " . " <> show e
|
appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState ()
|
||||||
emitLet :: [Bind] -> Exp -> CompilerState ()
|
appEmitter t e1 e2 stack = do
|
||||||
emitLet xs e = do
|
let newStack = e2 : stack
|
||||||
emit $
|
case e1 of
|
||||||
Comment $
|
EApp _ e1' e2' -> appEmitter t e1' e2' newStack
|
||||||
concat
|
EId id@(name, _) -> do
|
||||||
[ "ELet ("
|
args <- traverse exprToValue newStack
|
||||||
, show xs
|
vs <- getNewVar
|
||||||
, " = "
|
funcs <- gets functions
|
||||||
, show e
|
let visibility = maybe Local (const Global) $ Map.lookup id funcs
|
||||||
, ") is not implemented!"
|
args' = map (first valueGetType . dupe) args
|
||||||
]
|
call = Call (type2LlvmType t) visibility name args'
|
||||||
|
emit $ SetVariable (Ident $ show vs) call
|
||||||
|
x -> do
|
||||||
|
emit . Comment $ "The unspeakable happened: "
|
||||||
|
emit . Comment $ show x
|
||||||
|
|
||||||
emitApp :: Type -> Exp -> Exp -> CompilerState ()
|
emitIdent :: Ident -> CompilerState ()
|
||||||
emitApp t e1 e2 = appEmitter t e1 e2 []
|
emitIdent id = do
|
||||||
where
|
-- !!this should never happen!!
|
||||||
appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState ()
|
emit $ Comment "This should not have happened!"
|
||||||
appEmitter t e1 e2 stack = do
|
emit $ Variable id
|
||||||
let newStack = e2 : stack
|
emit $ UnsafeRaw "\n"
|
||||||
case e1 of
|
|
||||||
EApp _ e1' e2' -> appEmitter t e1' e2' newStack
|
|
||||||
EId id@(name, _) -> do
|
|
||||||
args <- traverse exprToValue newStack
|
|
||||||
vs <- getNewVar
|
|
||||||
funcs <- gets functions
|
|
||||||
let vis = case Map.lookup id funcs of
|
|
||||||
Nothing -> Local
|
|
||||||
Just _ -> Global
|
|
||||||
let call = Call (type2LlvmType t) vis name ((\x -> (valueGetType x, x)) <$> args)
|
|
||||||
emit $ SetVariable (Ident $ show vs) call
|
|
||||||
x -> do
|
|
||||||
emit . Comment $ "The unspeakable happened: "
|
|
||||||
emit . Comment $ show x
|
|
||||||
|
|
||||||
emitIdent :: Ident -> CompilerState ()
|
emitInt :: Integer -> CompilerState ()
|
||||||
emitIdent id = do
|
emitInt i = do
|
||||||
-- !!this should never happen!!
|
-- !!this should never happen!!
|
||||||
emit $ Comment "This should not have happened!"
|
varCount <- getNewVar
|
||||||
emit $ Variable id
|
emit $ Comment "This should not have happened!"
|
||||||
emit $ UnsafeRaw "\n"
|
emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0))
|
||||||
|
|
||||||
emitInt :: Integer -> CompilerState ()
|
emitAdd :: Type -> Exp -> Exp -> CompilerState ()
|
||||||
emitInt i = do
|
emitAdd t e1 e2 = do
|
||||||
-- !!this should never happen!!
|
v1 <- exprToValue e1
|
||||||
varCount <- getNewVar
|
v2 <- exprToValue e2
|
||||||
emit $ Comment "This should not have happened!"
|
v <- getNewVar
|
||||||
emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0))
|
emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2)
|
||||||
|
|
||||||
emitAdd :: Type -> Exp -> Exp -> CompilerState ()
|
-- emitMul :: Exp -> Exp -> CompilerState ()
|
||||||
emitAdd t e1 e2 = do
|
-- emitMul e1 e2 = do
|
||||||
v1 <- exprToValue e1
|
-- (v1,v2) <- binExprToValues e1 e2
|
||||||
v2 <- exprToValue e2
|
-- increaseVarCount
|
||||||
v <- getNewVar
|
-- v <- gets variableCount
|
||||||
emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2)
|
-- emit $ SetVariable $ Ident $ show v
|
||||||
|
-- emit $ Mul I64 v1 v2
|
||||||
|
|
||||||
-- emitMul :: Exp -> Exp -> CompilerState ()
|
-- emitMod :: Exp -> Exp -> CompilerState ()
|
||||||
-- emitMul e1 e2 = do
|
-- emitMod e1 e2 = do
|
||||||
-- (v1,v2) <- binExprToValues e1 e2
|
-- -- `let m a b = rem (abs $ b + a) b`
|
||||||
-- increaseVarCount
|
-- (v1,v2) <- binExprToValues e1 e2
|
||||||
-- v <- gets variableCount
|
-- increaseVarCount
|
||||||
-- emit $ SetVariable $ Ident $ show v
|
-- vadd <- gets variableCount
|
||||||
-- emit $ Mul I64 v1 v2
|
-- emit $ SetVariable $ Ident $ show vadd
|
||||||
|
-- emit $ Add I64 v1 v2
|
||||||
|
--
|
||||||
|
-- increaseVarCount
|
||||||
|
-- vabs <- gets variableCount
|
||||||
|
-- emit $ SetVariable $ Ident $ show vabs
|
||||||
|
-- emit $ Call I64 (Ident "llvm.abs.i64")
|
||||||
|
-- [ (I64, VIdent (Ident $ show vadd))
|
||||||
|
-- , (I1, VInteger 1)
|
||||||
|
-- ]
|
||||||
|
-- increaseVarCount
|
||||||
|
-- v <- gets variableCount
|
||||||
|
-- emit $ SetVariable $ Ident $ show v
|
||||||
|
-- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2
|
||||||
|
|
||||||
-- emitMod :: Exp -> Exp -> CompilerState ()
|
-- emitDiv :: Exp -> Exp -> CompilerState ()
|
||||||
-- emitMod e1 e2 = do
|
-- emitDiv e1 e2 = do
|
||||||
-- -- `let m a b = rem (abs $ b + a) b`
|
-- (v1,v2) <- binExprToValues e1 e2
|
||||||
-- (v1,v2) <- binExprToValues e1 e2
|
-- increaseVarCount
|
||||||
-- increaseVarCount
|
-- v <- gets variableCount
|
||||||
-- vadd <- gets variableCount
|
-- emit $ SetVariable $ Ident $ show v
|
||||||
-- emit $ SetVariable $ Ident $ show vadd
|
-- emit $ Div I64 v1 v2
|
||||||
-- emit $ Add I64 v1 v2
|
|
||||||
--
|
|
||||||
-- increaseVarCount
|
|
||||||
-- vabs <- gets variableCount
|
|
||||||
-- emit $ SetVariable $ Ident $ show vabs
|
|
||||||
-- emit $ Call I64 (Ident "llvm.abs.i64")
|
|
||||||
-- [ (I64, VIdent (Ident $ show vadd))
|
|
||||||
-- , (I1, VInteger 1)
|
|
||||||
-- ]
|
|
||||||
-- increaseVarCount
|
|
||||||
-- v <- gets variableCount
|
|
||||||
-- emit $ SetVariable $ Ident $ show v
|
|
||||||
-- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2
|
|
||||||
|
|
||||||
-- emitDiv :: Exp -> Exp -> CompilerState ()
|
-- emitSub :: Exp -> Exp -> CompilerState ()
|
||||||
-- emitDiv e1 e2 = do
|
-- emitSub e1 e2 = do
|
||||||
-- (v1,v2) <- binExprToValues e1 e2
|
-- (v1,v2) <- binExprToValues e1 e2
|
||||||
-- increaseVarCount
|
-- increaseVarCount
|
||||||
-- v <- gets variableCount
|
-- v <- gets variableCount
|
||||||
-- emit $ SetVariable $ Ident $ show v
|
-- emit $ SetVariable $ Ident $ show v
|
||||||
-- emit $ Div I64 v1 v2
|
-- emit $ Sub I64 v1 v2
|
||||||
|
|
||||||
-- emitSub :: Exp -> Exp -> CompilerState ()
|
exprToValue :: Exp -> CompilerState LLVMValue
|
||||||
-- emitSub e1 e2 = do
|
exprToValue = \case
|
||||||
-- (v1,v2) <- binExprToValues e1 e2
|
EInt i -> pure $ VInteger i
|
||||||
-- increaseVarCount
|
|
||||||
-- v <- gets variableCount
|
|
||||||
-- emit $ SetVariable $ Ident $ show v
|
|
||||||
-- emit $ Sub I64 v1 v2
|
|
||||||
|
|
||||||
exprToValue :: Exp -> CompilerState LLVMValue
|
EId id@(name, t) -> do
|
||||||
exprToValue (EInt i) = return $ VInteger i
|
|
||||||
exprToValue (EId id@(name, t)) = do
|
|
||||||
funcs <- gets functions
|
funcs <- gets functions
|
||||||
case Map.lookup id funcs of
|
case Map.lookup id funcs of
|
||||||
Just fi -> do
|
Just fi -> do
|
||||||
if numArgs fi == 0
|
if numArgs fi == 0
|
||||||
then do
|
then do
|
||||||
vc <- getNewVar
|
vc <- getNewVar
|
||||||
emit $ SetVariable (Ident $ show vc) (Call (type2LlvmType t) Global name [])
|
emit $ SetVariable (Ident $ show vc)
|
||||||
return $ VIdent (Ident $ show vc) (type2LlvmType t)
|
(Call (type2LlvmType t) Global name [])
|
||||||
else return $ VFunction name Global (type2LlvmType t)
|
pure $ VIdent (Ident $ show vc) (type2LlvmType t)
|
||||||
Nothing -> return $ VIdent name (type2LlvmType t)
|
else pure $ VFunction name Global (type2LlvmType t)
|
||||||
exprToValue e = do
|
Nothing -> pure $ VIdent name (type2LlvmType t)
|
||||||
go e
|
|
||||||
|
e -> do
|
||||||
|
compileExp e
|
||||||
v <- getVarCount
|
v <- getVarCount
|
||||||
return $ VIdent (Ident $ show v) (getType e)
|
pure $ VIdent (Ident $ show v) (getType e)
|
||||||
|
|
||||||
type2LlvmType :: Type -> LLVMType
|
type2LlvmType :: Type -> LLVMType
|
||||||
type2LlvmType = \case
|
type2LlvmType = \case
|
||||||
|
|
@ -346,13 +337,13 @@ type2LlvmType = \case
|
||||||
function2LLVMType x s = (type2LlvmType x, s)
|
function2LLVMType x s = (type2LlvmType x, s)
|
||||||
|
|
||||||
getType :: Exp -> LLVMType
|
getType :: Exp -> LLVMType
|
||||||
getType (EInt _) = I64
|
getType (EInt _) = I64
|
||||||
getType (EAdd t _ _) = type2LlvmType t
|
getType (EAdd t _ _) = type2LlvmType t
|
||||||
getType (EId (_, t)) = type2LlvmType t
|
getType (EId (_, t)) = type2LlvmType t
|
||||||
getType (EApp t _ _) = type2LlvmType t
|
getType (EApp t _ _) = type2LlvmType t
|
||||||
getType (EAbs t _ _) = type2LlvmType t
|
getType (EAbs t _ _) = type2LlvmType t
|
||||||
getType (ELet _ e) = getType e
|
getType (ELet _ e) = getType e
|
||||||
getType (EAnn _ t) = type2LlvmType t
|
getType (ECased e cs) = undefined
|
||||||
|
|
||||||
valueGetType :: LLVMValue -> LLVMType
|
valueGetType :: LLVMValue -> LLVMType
|
||||||
valueGetType (VInteger _) = I64
|
valueGetType (VInteger _) = I64
|
||||||
|
|
|
||||||
|
|
@ -7,16 +7,18 @@ module LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where
|
||||||
import Auxiliary (snoc)
|
import Auxiliary (snoc)
|
||||||
import Control.Applicative (Applicative (liftA2))
|
import Control.Applicative (Applicative (liftA2))
|
||||||
import Control.Monad.State (MonadState (get, put), State, evalState)
|
import Control.Monad.State (MonadState (get, put), State, evalState)
|
||||||
import Data.Foldable.Extra (notNull)
|
import Data.Set (Set)
|
||||||
import Data.List (mapAccumL, partition)
|
|
||||||
import Data.Set (Set, (\\))
|
|
||||||
import qualified Data.Set as Set
|
import qualified Data.Set as Set
|
||||||
import Prelude hiding (exp)
|
import Prelude hiding (exp)
|
||||||
import Renamer hiding (fromBinders)
|
import Renamer
|
||||||
import TypeCheckerIr
|
import TypeCheckerIr
|
||||||
|
|
||||||
|
|
||||||
-- | Lift lambdas and let expression into supercombinators.
|
-- | Lift lambdas and let expression into supercombinators.
|
||||||
|
-- Three phases:
|
||||||
|
-- @freeVars@ annotatss all the free variables.
|
||||||
|
-- @abstract@ converts lambdas into let expressions.
|
||||||
|
-- @collectScs@ moves every non-constant let expression to a top-level function.
|
||||||
lambdaLift :: Program -> Program
|
lambdaLift :: Program -> Program
|
||||||
lambdaLift = collectScs . abstract . freeVars
|
lambdaLift = collectScs . abstract . freeVars
|
||||||
|
|
||||||
|
|
@ -29,55 +31,41 @@ freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e)
|
||||||
|
|
||||||
freeVarsExp :: Set Id -> Exp -> AnnExp
|
freeVarsExp :: Set Id -> Exp -> AnnExp
|
||||||
freeVarsExp localVars = \case
|
freeVarsExp localVars = \case
|
||||||
|
EId n | Set.member n localVars -> (Set.singleton n, AId n)
|
||||||
|
| otherwise -> (mempty, AId n)
|
||||||
|
|
||||||
EId n | Set.member n localVars -> (Set.singleton n, AId n)
|
EInt i -> (mempty, AInt i)
|
||||||
| otherwise -> (mempty, AId n)
|
|
||||||
|
|
||||||
EInt i -> (mempty, AInt i)
|
EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2')
|
||||||
|
where
|
||||||
|
e1' = freeVarsExp localVars e1
|
||||||
|
e2' = freeVarsExp localVars e2
|
||||||
|
|
||||||
EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2')
|
EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2')
|
||||||
where
|
where
|
||||||
e1' = freeVarsExp localVars e1
|
e1' = freeVarsExp localVars e1
|
||||||
e2' = freeVarsExp localVars e2
|
e2' = freeVarsExp localVars e2
|
||||||
|
|
||||||
EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2')
|
EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e')
|
||||||
where
|
where
|
||||||
e1' = freeVarsExp localVars e1
|
e' = freeVarsExp (Set.insert par localVars) e
|
||||||
e2' = freeVarsExp localVars e2
|
|
||||||
|
|
||||||
EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e')
|
-- Sum free variables present in bind and the expression
|
||||||
where
|
ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e')
|
||||||
e' = freeVarsExp (Set.insert par localVars) e
|
where
|
||||||
|
binders_frees = Set.delete name $ freeVarsOf rhs'
|
||||||
|
e_free = Set.delete name $ freeVarsOf e'
|
||||||
|
|
||||||
-- Sum free variables present in binders and the expression
|
rhs' = freeVarsExp e_localVars rhs
|
||||||
ELet binders e -> (Set.union binders_frees e_free, ALet binders' e')
|
new_bind = ABind name parms rhs'
|
||||||
where
|
|
||||||
binders_frees = rhss_frees \\ names_set
|
|
||||||
e_free = freeVarsOf e' \\ names_set
|
|
||||||
|
|
||||||
rhss_frees = foldr1 Set.union (map freeVarsOf rhss')
|
e' = freeVarsExp e_localVars e
|
||||||
names_set = Set.fromList names
|
e_localVars = Set.insert name localVars
|
||||||
|
|
||||||
(names, parms, rhss) = fromBinders binders
|
|
||||||
rhss' = map (freeVarsExp e_localVars) rhss
|
|
||||||
e_localVars = Set.union localVars names_set
|
|
||||||
|
|
||||||
binders' = zipWith3 ABind names parms rhss'
|
|
||||||
e' = freeVarsExp e_localVars e
|
|
||||||
|
|
||||||
EAnn e t -> (freeVarsOf e', AAnn e' t)
|
|
||||||
where
|
|
||||||
e' = freeVarsExp localVars e
|
|
||||||
|
|
||||||
|
|
||||||
freeVarsOf :: AnnExp -> Set Id
|
freeVarsOf :: AnnExp -> Set Id
|
||||||
freeVarsOf = fst
|
freeVarsOf = fst
|
||||||
|
|
||||||
|
|
||||||
fromBinders :: [Bind] -> ([Id], [[Id]], [Exp])
|
|
||||||
fromBinders bs = unzip3 [ (name, parms, rhs) | Bind name parms rhs <- bs ]
|
|
||||||
|
|
||||||
|
|
||||||
-- AST annotated with free variables
|
-- AST annotated with free variables
|
||||||
type AnnProgram = [(Id, [Id], AnnExp)]
|
type AnnProgram = [(Id, [Id], AnnExp)]
|
||||||
|
|
||||||
|
|
@ -87,23 +75,20 @@ data ABind = ABind Id [Id] AnnExp deriving Show
|
||||||
|
|
||||||
data AnnExp' = AId Id
|
data AnnExp' = AId Id
|
||||||
| AInt Integer
|
| AInt Integer
|
||||||
| ALet [ABind] AnnExp
|
| ALet ABind AnnExp
|
||||||
| AApp Type AnnExp AnnExp
|
| AApp Type AnnExp AnnExp
|
||||||
| AAdd Type AnnExp AnnExp
|
| AAdd Type AnnExp AnnExp
|
||||||
| AAbs Type Id AnnExp
|
| AAbs Type Id AnnExp
|
||||||
| AAnn AnnExp Type
|
|
||||||
deriving Show
|
deriving Show
|
||||||
|
|
||||||
|
|
||||||
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
|
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
|
||||||
-- Free variables are @v₁ v₂ .. vₙ@ are bound.
|
-- Free variables are @v₁ v₂ .. vₙ@ are bound.
|
||||||
abstract :: AnnProgram -> Program
|
abstract :: AnnProgram -> Program
|
||||||
abstract prog = Program $ evalState (mapM go prog) 0
|
abstract prog = Program $ evalState (mapM go prog) 0
|
||||||
where
|
where
|
||||||
go :: (Id, [Id], AnnExp) -> State Int Bind
|
go :: (Id, [Id], AnnExp) -> State Int Bind
|
||||||
go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs'
|
go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs'
|
||||||
where
|
where
|
||||||
(rhs', parms1) = flattenLambdasAnn rhs
|
(rhs', parms1) = flattenLambdasAnn rhs
|
||||||
|
|
||||||
|
|
||||||
-- | Flatten nested lambdas and collect the parameters
|
-- | Flatten nested lambdas and collect the parameters
|
||||||
|
|
@ -113,112 +98,93 @@ flattenLambdasAnn ae = go (ae, [])
|
||||||
where
|
where
|
||||||
go :: (AnnExp, [Id]) -> (AnnExp, [Id])
|
go :: (AnnExp, [Id]) -> (AnnExp, [Id])
|
||||||
go ((free, e), acc) =
|
go ((free, e), acc) =
|
||||||
case e of
|
case e of
|
||||||
AAbs _ par (free1, e1) ->
|
AAbs _ par (free1, e1) ->
|
||||||
go ((Set.delete par free1, e1), snoc par acc)
|
go ((Set.delete par free1, e1), snoc par acc)
|
||||||
_ -> ((free, e), acc)
|
_ -> ((free, e), acc)
|
||||||
|
|
||||||
abstractExp :: AnnExp -> State Int Exp
|
abstractExp :: AnnExp -> State Int Exp
|
||||||
abstractExp (free, exp) = case exp of
|
abstractExp (free, exp) = case exp of
|
||||||
AId n -> pure $ EId n
|
AId n -> pure $ EId n
|
||||||
AInt i -> pure $ EInt i
|
AInt i -> pure $ EInt i
|
||||||
AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2)
|
AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2)
|
||||||
AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2)
|
AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2)
|
||||||
ALet bs e -> liftA2 ELet (mapM go bs) (abstractExp e)
|
ALet b e -> liftA2 ELet (go b) (abstractExp e)
|
||||||
where
|
where
|
||||||
go (ABind name parms rhs) = do
|
go (ABind name parms rhs) = do
|
||||||
(rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
|
(rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
|
||||||
pure $ Bind name (parms ++ parms1) rhs'
|
pure $ Bind name (parms ++ parms1) rhs'
|
||||||
|
|
||||||
skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp
|
skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp
|
||||||
skipLambdas f (free, ae) = case ae of
|
skipLambdas f (free, ae) = case ae of
|
||||||
AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1
|
AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1
|
||||||
_ -> f (free, ae)
|
_ -> f (free, ae)
|
||||||
|
|
||||||
-- Lift lambda into let and bind free variables
|
-- Lift lambda into let and bind free variables
|
||||||
AAbs t parm e -> do
|
AAbs t parm e -> do
|
||||||
i <- nextNumber
|
i <- nextNumber
|
||||||
rhs <- abstractExp e
|
rhs <- abstractExp e
|
||||||
|
|
||||||
let sc_name = Ident ("sc_" ++ show i)
|
let sc_name = Ident ("sc_" ++ show i)
|
||||||
sc = ELet [Bind (sc_name, t) parms rhs] $ EId (sc_name, t)
|
sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t)
|
||||||
|
|
||||||
pure $ foldl (EApp TInt) sc $ map EId freeList
|
pure $ foldl (EApp TInt) sc $ map EId freeList
|
||||||
where
|
where
|
||||||
freeList = Set.toList free
|
freeList = Set.toList free
|
||||||
parms = snoc parm freeList
|
parms = snoc parm freeList
|
||||||
|
|
||||||
AAnn e t -> abstractExp e >>= \e' -> pure $ EAnn e' t
|
|
||||||
|
|
||||||
nextNumber :: State Int Int
|
nextNumber :: State Int Int
|
||||||
nextNumber = do
|
nextNumber = do
|
||||||
i <- get
|
i <- get
|
||||||
put $ succ i
|
put $ succ i
|
||||||
pure i
|
pure i
|
||||||
|
|
||||||
-- | Collects supercombinators by lifting appropriate let expressions
|
-- | Collects supercombinators by lifting non-constant let expressions
|
||||||
collectScs :: Program -> Program
|
collectScs :: Program -> Program
|
||||||
collectScs (Program scs) = Program $ concatMap collectFromRhs scs
|
collectScs (Program scs) = Program $ concatMap collectFromRhs scs
|
||||||
where
|
where
|
||||||
collectFromRhs (Bind name parms rhs) =
|
collectFromRhs (Bind name parms rhs) =
|
||||||
let (rhs_scs, rhs') = collectScsExp rhs
|
let (rhs_scs, rhs') = collectScsExp rhs
|
||||||
in Bind name parms rhs' : rhs_scs
|
in Bind name parms rhs' : rhs_scs
|
||||||
|
|
||||||
|
|
||||||
collectScsExp :: Exp -> ([Bind], Exp)
|
collectScsExp :: Exp -> ([Bind], Exp)
|
||||||
collectScsExp = \case
|
collectScsExp = \case
|
||||||
EId n -> ([], EId n)
|
EId n -> ([], EId n)
|
||||||
EInt i -> ([], EInt i)
|
EInt i -> ([], EInt i)
|
||||||
|
|
||||||
EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2')
|
EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2')
|
||||||
where
|
where
|
||||||
(scs1, e1') = collectScsExp e1
|
(scs1, e1') = collectScsExp e1
|
||||||
(scs2, e2') = collectScsExp e2
|
(scs2, e2') = collectScsExp e2
|
||||||
|
|
||||||
EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2')
|
EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2')
|
||||||
where
|
where
|
||||||
(scs1, e1') = collectScsExp e1
|
(scs1, e1') = collectScsExp e1
|
||||||
(scs2, e2') = collectScsExp e2
|
(scs2, e2') = collectScsExp e2
|
||||||
|
|
||||||
EAbs t par e -> (scs, EAbs t par e')
|
EAbs t par e -> (scs, EAbs t par e')
|
||||||
where
|
where
|
||||||
(scs, e') = collectScsExp e
|
(scs, e') = collectScsExp e
|
||||||
|
|
||||||
-- Collect supercombinators from binds, the rhss, and the expression.
|
-- Collect supercombinators from bind, the rhss, and the expression.
|
||||||
--
|
--
|
||||||
-- > f = let
|
-- > f = let sc x y = rhs in e
|
||||||
-- > sc = rhs
|
--
|
||||||
-- > sc1 = rhs1
|
ELet (Bind name parms rhs) e -> if null parms
|
||||||
-- > ...
|
then ( rhs_scs ++ e_scs, ELet bind e')
|
||||||
-- > in e
|
else (bind : rhs_scs ++ e_scs, e')
|
||||||
--
|
where
|
||||||
ELet binds e -> (binds_scs ++ rhss_scs ++ e_scs, mkEAbs non_scs' e')
|
bind = Bind name parms rhs'
|
||||||
where
|
(rhs_scs, rhs') = collectScsExp rhs
|
||||||
binds_scs = [ let (rhs', parms1) = flattenLambdas rhs in
|
(e_scs, e') = collectScsExp e
|
||||||
Bind n (parms ++ parms1) rhs'
|
|
||||||
| Bind n parms rhs <- scs'
|
|
||||||
]
|
|
||||||
(rhss_scs, binds') = mapAccumL collectScsRhs [] binds
|
|
||||||
(e_scs, e') = collectScsExp e
|
|
||||||
|
|
||||||
(scs', non_scs') = partition (\(Bind _ pars _) -> notNull pars) binds'
|
|
||||||
|
|
||||||
collectScsRhs acc (Bind n xs rhs) = (acc ++ rhs_scs, Bind n xs rhs')
|
|
||||||
where
|
|
||||||
(rhs_scs, rhs') = collectScsExp rhs
|
|
||||||
|
|
||||||
EAnn e t -> (scs, EAnn e' t)
|
|
||||||
where
|
|
||||||
(scs, e') = collectScsExp e
|
|
||||||
|
|
||||||
-- @\x.\y.\z. e → (e, [x,y,z])@
|
-- @\x.\y.\z. e → (e, [x,y,z])@
|
||||||
flattenLambdas :: Exp -> (Exp, [Id])
|
flattenLambdas :: Exp -> (Exp, [Id])
|
||||||
flattenLambdas = go . (, [])
|
flattenLambdas = go . (, [])
|
||||||
where
|
where
|
||||||
go (e, acc) = case e of
|
go (e, acc) = case e of
|
||||||
EAbs _ par e1 -> go (e1, snoc par acc)
|
EAbs _ par e1 -> go (e1, snoc par acc)
|
||||||
_ -> (e, acc)
|
_ -> (e, acc)
|
||||||
|
|
||||||
mkEAbs :: [Bind] -> Exp -> Exp
|
|
||||||
mkEAbs [] e = e
|
|
||||||
mkEAbs bs e = ELet bs e
|
|
||||||
|
|
|
||||||
|
|
@ -51,8 +51,8 @@ data LLVMComp
|
||||||
instance Show LLVMComp where
|
instance Show LLVMComp where
|
||||||
show :: LLVMComp -> String
|
show :: LLVMComp -> String
|
||||||
show = \case
|
show = \case
|
||||||
LLEq -> "eq"
|
LLEq -> "eq"
|
||||||
LLNe -> "ne"
|
LLNe -> "ne"
|
||||||
LLUgt -> "ugt"
|
LLUgt -> "ugt"
|
||||||
LLUge -> "uge"
|
LLUge -> "uge"
|
||||||
LLUlt -> "ult"
|
LLUlt -> "ult"
|
||||||
|
|
@ -68,9 +68,8 @@ instance Show Visibility where
|
||||||
show Local = "%"
|
show Local = "%"
|
||||||
show Global = "@"
|
show Global = "@"
|
||||||
|
|
||||||
{- | Represents a LLVM "value", as in an integer, a register variable,
|
-- | Represents a LLVM "value", as in an integer, a register variable,
|
||||||
or a string contstant
|
-- or a string contstant
|
||||||
-}
|
|
||||||
data LLVMValue
|
data LLVMValue
|
||||||
= VInteger Integer
|
= VInteger Integer
|
||||||
| VIdent Ident LLVMType
|
| VIdent Ident LLVMType
|
||||||
|
|
|
||||||
129
src/Renamer.hs
129
src/Renamer.hs
|
|
@ -2,82 +2,83 @@
|
||||||
|
|
||||||
module Renamer (module Renamer) where
|
module Renamer (module Renamer) where
|
||||||
|
|
||||||
import Data.List (mapAccumL, unzip4, zipWith4)
|
import Auxiliary (mapAccumM)
|
||||||
import Data.Map (Map)
|
import Control.Monad.State (MonadState, State, evalState, gets,
|
||||||
import qualified Data.Map as Map
|
modify)
|
||||||
import Data.Maybe (fromMaybe)
|
import Data.Map (Map)
|
||||||
|
import qualified Data.Map as Map
|
||||||
|
import Data.Maybe (fromMaybe)
|
||||||
|
import Data.Tuple.Extra (dupe)
|
||||||
import Grammar.Abs
|
import Grammar.Abs
|
||||||
|
|
||||||
|
|
||||||
-- | Rename all supercombinators and variables
|
-- | Rename all variables and local binds
|
||||||
rename :: Program -> Program
|
rename :: Program -> Program
|
||||||
rename (Program sc) = Program $ map (renameSc 0) sc
|
rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0
|
||||||
where
|
where
|
||||||
renameSc i (Bind n t _ xs e) = Bind n t n xs' e'
|
initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs
|
||||||
where
|
renameSc :: Names -> Bind -> Rn Bind
|
||||||
(i1, xs', env) = newNames i xs
|
renameSc old_names (Bind name t _ parms rhs) = do
|
||||||
e' = snd $ renameExp env i1 e
|
(new_names, parms') <- newNames old_names parms
|
||||||
|
rhs' <- snd <$> renameExp new_names rhs
|
||||||
renameExp :: Map Ident Ident -> Int -> Exp -> (Int, Exp)
|
pure $ Bind name t name parms' rhs'
|
||||||
renameExp env i = \case
|
|
||||||
|
|
||||||
EId n -> (i, EId . fromMaybe n $ Map.lookup n env)
|
|
||||||
|
|
||||||
EInt i1 -> (i, EInt i1)
|
|
||||||
|
|
||||||
EApp e1 e2 -> (i2, EApp e1' e2')
|
|
||||||
where
|
|
||||||
(i1, e1') = renameExp env i e1
|
|
||||||
(i2, e2') = renameExp env i1 e2
|
|
||||||
|
|
||||||
EAdd e1 e2 -> (i2, EAdd e1' e2')
|
|
||||||
where
|
|
||||||
(i1, e1') = renameExp env i e1
|
|
||||||
(i2, e2') = renameExp env i1 e2
|
|
||||||
|
|
||||||
ELet bs e -> (i3, ELet (zipWith4 mkBind names' types pars' es') e')
|
|
||||||
where
|
|
||||||
mkBind name t = Bind name t name
|
|
||||||
(i1, e') = renameExp e_env i e
|
|
||||||
(names, types, pars, rhss) = fromBinders bs
|
|
||||||
(i2, names', env') = newNames i1 (names ++ concat pars)
|
|
||||||
pars' = (map . map) renamePar pars
|
|
||||||
e_env = Map.union env' env
|
|
||||||
(i3, es') = mapAccumL (renameExp e_env) i2 rhss
|
|
||||||
|
|
||||||
renamePar p = case Map.lookup p env' of
|
|
||||||
Just p' -> p'
|
|
||||||
Nothing -> error ("Can't find name for " ++ show p)
|
|
||||||
|
|
||||||
|
|
||||||
EAbs par t e -> (i2, EAbs par' t e')
|
-- | Rename monad. State holds the number of renamed names.
|
||||||
where
|
newtype Rn a = Rn { runRn :: State Int a }
|
||||||
(i1, par', env') = newName par
|
deriving (Functor, Applicative, Monad, MonadState Int)
|
||||||
(i2, e') = renameExp (Map.union env' env ) i1 e
|
|
||||||
|
|
||||||
EAnn e t -> (i1, EAnn e' t)
|
-- | Maps old to new name
|
||||||
where
|
type Names = Map Ident Ident
|
||||||
(i1, e') = renameExp env i e
|
|
||||||
|
|
||||||
|
renameLocalBind :: Names -> Bind -> Rn (Names, Bind)
|
||||||
|
renameLocalBind old_names (Bind name t _ parms rhs) = do
|
||||||
|
(new_names, name') <- newName old_names name
|
||||||
|
(new_names', parms') <- newNames new_names parms
|
||||||
|
(new_names'', rhs') <- renameExp new_names' rhs
|
||||||
|
pure (new_names'', Bind name' t name' parms' rhs')
|
||||||
|
|
||||||
newName :: Ident -> (Int, Ident, Map Ident Ident)
|
renameExp :: Names -> Exp -> Rn (Names, Exp)
|
||||||
newName old_name = (i, head names, env)
|
renameExp old_names = \case
|
||||||
where (i, names, env) = newNames 1 [old_name]
|
EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names)
|
||||||
|
|
||||||
newNames :: Int -> [Ident] -> (Int, [Ident], Map Ident Ident)
|
EInt i1 -> pure (old_names, EInt i1)
|
||||||
newNames i old_names = (i', new_names, env)
|
|
||||||
where
|
|
||||||
(i', new_names) = getNames i old_names
|
|
||||||
env = Map.fromList $ zip old_names new_names
|
|
||||||
|
|
||||||
getNames :: Int -> [Ident] -> (Int, [Ident])
|
EApp e1 e2 -> do
|
||||||
getNames i ns = (i + length ss, zipWith makeName ss [i..])
|
(env1, e1') <- renameExp old_names e1
|
||||||
where
|
(env2, e2') <- renameExp old_names e2
|
||||||
ss = map (\(Ident s) -> s) ns
|
pure (Map.union env1 env2, EApp e1' e2')
|
||||||
|
|
||||||
makeName :: String -> Int -> Ident
|
EAdd e1 e2 -> do
|
||||||
makeName prefix i = Ident (prefix ++ "_" ++ show i)
|
(env1, e1') <- renameExp old_names e1
|
||||||
|
(env2, e2') <- renameExp old_names e2
|
||||||
|
pure (Map.union env1 env2, EAdd e1' e2')
|
||||||
|
|
||||||
|
ELet b e -> do
|
||||||
|
(new_names, b) <- renameLocalBind old_names b
|
||||||
|
(new_names', e') <- renameExp new_names e
|
||||||
|
pure (new_names', ELet b e')
|
||||||
|
|
||||||
|
EAbs par t e -> do
|
||||||
|
(new_names, par') <- newName old_names par
|
||||||
|
(new_names', e') <- renameExp new_names e
|
||||||
|
pure (new_names', EAbs par' t e')
|
||||||
|
|
||||||
|
EAnn e t -> do
|
||||||
|
(new_names, e') <- renameExp old_names e
|
||||||
|
pure (new_names, EAnn e' t)
|
||||||
|
|
||||||
|
-- | Create a new name and add it to name environment.
|
||||||
|
newName :: Names -> Ident -> Rn (Names, Ident)
|
||||||
|
newName env old_name = do
|
||||||
|
new_name <- makeName old_name
|
||||||
|
pure (Map.insert old_name new_name env, new_name)
|
||||||
|
|
||||||
|
-- | Create multiple names and add them to the name environment
|
||||||
|
newNames :: Names -> [Ident] -> Rn (Names, [Ident])
|
||||||
|
newNames = mapAccumM newName
|
||||||
|
|
||||||
|
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
|
||||||
|
makeName :: Ident -> Rn Ident
|
||||||
|
makeName (Ident prefix) = gets (\i -> Ident $ prefix ++ "_" ++ show i) <* modify succ
|
||||||
|
|
||||||
fromBinders :: [Bind] -> ([Ident], [Type], [[Ident]], [Exp])
|
|
||||||
fromBinders bs = unzip4 [ (name, t, parms, rhs) | Bind name t _ parms rhs <- bs ]
|
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,6 @@ import Grammar.Print (Print (prt), concatD, doc, printTree,
|
||||||
import Prelude hiding (exp, id)
|
import Prelude hiding (exp, id)
|
||||||
import qualified TypeCheckerIr as T
|
import qualified TypeCheckerIr as T
|
||||||
|
|
||||||
|
|
||||||
-- NOTE: this type checker is poorly tested
|
-- NOTE: this type checker is poorly tested
|
||||||
|
|
||||||
-- TODO
|
-- TODO
|
||||||
|
|
@ -22,9 +21,9 @@ import qualified TypeCheckerIr as T
|
||||||
-- Type inference
|
-- Type inference
|
||||||
|
|
||||||
data Cxt = Cxt
|
data Cxt = Cxt
|
||||||
{ env :: Map Ident Type
|
{ env :: Map Ident Type -- ^ Local scope signature
|
||||||
, sig :: Map Ident Type
|
, sig :: Map Ident Type -- ^ Top-level signatures
|
||||||
}
|
}
|
||||||
|
|
||||||
initCxt :: [Bind] -> Cxt
|
initCxt :: [Bind] -> Cxt
|
||||||
initCxt sc = Cxt { env = mempty
|
initCxt sc = Cxt { env = mempty
|
||||||
|
|
@ -34,134 +33,133 @@ initCxt sc = Cxt { env = mempty
|
||||||
typecheck :: Program -> Err T.Program
|
typecheck :: Program -> Err T.Program
|
||||||
typecheck (Program sc) = T.Program <$> mapM (checkBind $ initCxt sc) sc
|
typecheck (Program sc) = T.Program <$> mapM (checkBind $ initCxt sc) sc
|
||||||
|
|
||||||
|
-- | Check if infered rhs type matches type signature.
|
||||||
checkBind :: Cxt -> Bind -> Err T.Bind
|
checkBind :: Cxt -> Bind -> Err T.Bind
|
||||||
checkBind cxt b =
|
checkBind cxt b =
|
||||||
case expandLambdas b of
|
case expandLambdas b of
|
||||||
Bind name t _ parms rhs -> do
|
Bind name t _ parms rhs -> do
|
||||||
(rhs', t_rhs) <- infer cxt rhs
|
(rhs', t_rhs) <- infer cxt rhs
|
||||||
|
unless (typeEq t_rhs t) . throwError $ typeErr name t t_rhs
|
||||||
unless (typeEq t_rhs t) . throwError $ typeErr name t t_rhs
|
pure $ T.Bind (name, t) (zip parms ts_parms) rhs'
|
||||||
|
where
|
||||||
pure $ T.Bind (name, t) (zip parms ts_parms) rhs'
|
ts_parms = fst $ partitionType (length parms) t
|
||||||
|
|
||||||
where
|
|
||||||
ts_parms = fst $ partitionType (length parms) t
|
|
||||||
|
|
||||||
|
-- | @ f x y = rhs ⇒ f = \x.\y. rhs @
|
||||||
expandLambdas :: Bind -> Bind
|
expandLambdas :: Bind -> Bind
|
||||||
expandLambdas (Bind name t _ parms rhs) = Bind name t name [] rhs'
|
expandLambdas (Bind name t _ parms rhs) = Bind name t name [] rhs'
|
||||||
where
|
where
|
||||||
rhs' = foldr ($) rhs $ zipWith EAbs parms ts_parms
|
rhs' = foldr ($) rhs $ zipWith EAbs parms ts_parms
|
||||||
ts_parms = fst $ partitionType (length parms) t
|
ts_parms = fst $ partitionType (length parms) t
|
||||||
|
|
||||||
|
-- | Infer type of expression.
|
||||||
infer :: Cxt -> Exp -> Err (T.Exp, Type)
|
infer :: Cxt -> Exp -> Err (T.Exp, Type)
|
||||||
infer cxt = \case
|
infer cxt = \case
|
||||||
|
EId x ->
|
||||||
|
case lookupEnv x cxt of
|
||||||
|
Nothing ->
|
||||||
|
case lookupSig x cxt of
|
||||||
|
Nothing -> throwError ("Unbound variable:" ++ printTree x)
|
||||||
|
Just t -> pure (T.EId (x, t), t)
|
||||||
|
Just t -> pure (T.EId (x, t), t)
|
||||||
|
|
||||||
EId x ->
|
EInt i -> pure (T.EInt i, T.TInt)
|
||||||
case lookupEnv x cxt of
|
|
||||||
Nothing ->
|
|
||||||
case lookupSig x cxt of
|
|
||||||
Nothing -> throwError ("Unbound variable:" ++ printTree x)
|
|
||||||
Just t -> pure (T.EId (x, t), t)
|
|
||||||
Just t -> pure (T.EId (x, t), t)
|
|
||||||
|
|
||||||
EInt i -> pure (T.EInt i, T.TInt)
|
EApp e e1 -> do
|
||||||
|
(e', t) <- infer cxt e
|
||||||
|
case t of
|
||||||
|
TFun t1 t2 -> do
|
||||||
|
e1' <- check cxt e1 t1
|
||||||
|
pure (T.EApp t2 e' e1', t2)
|
||||||
|
_ -> do
|
||||||
|
throwError ("Not a function: " ++ show e)
|
||||||
|
|
||||||
EApp e e1 -> do
|
EAdd e e1 -> do
|
||||||
(e', t) <- infer cxt e
|
e' <- check cxt e T.TInt
|
||||||
case t of
|
e1' <- check cxt e1 T.TInt
|
||||||
TFun t1 t2 -> do
|
pure (T.EAdd T.TInt e' e1', T.TInt)
|
||||||
e1' <- check cxt e1 t1
|
|
||||||
pure (T.EApp t2 e' e1', t2)
|
|
||||||
_ -> do
|
|
||||||
throwError ("Not a function: " ++ show e)
|
|
||||||
|
|
||||||
EAdd e e1 -> do
|
EAbs x t e -> do
|
||||||
e' <- check cxt e T.TInt
|
(e', t1) <- infer (insertEnv x t cxt) e
|
||||||
e1' <- check cxt e1 T.TInt
|
let t_abs = TFun t t1
|
||||||
pure (T.EAdd T.TInt e' e1', T.TInt)
|
pure (T.EAbs t_abs (x, t) e', t_abs)
|
||||||
|
|
||||||
EAbs x t e -> do
|
ELet b e -> do
|
||||||
(e', t1) <- infer (insertEnv x t cxt) e
|
let cxt' = insertBind b cxt
|
||||||
let t_abs = TFun t t1
|
b' <- checkBind cxt' b
|
||||||
pure (T.EAbs t_abs (x, t) e', t_abs)
|
(e', t) <- infer cxt' e
|
||||||
|
pure (T.ELet b' e', t)
|
||||||
ELet bs e -> do
|
|
||||||
bs'' <- mapM (checkBind cxt') bs'
|
|
||||||
(e', t) <- infer cxt' e
|
|
||||||
pure (T.ELet bs'' e', t)
|
|
||||||
where
|
|
||||||
bs' = map expandLambdas bs
|
|
||||||
cxt' = foldr (\(Bind n t _ _ _) -> insertEnv n t) cxt bs'
|
|
||||||
|
|
||||||
EAnn e t -> do
|
|
||||||
e' <- check cxt e t
|
|
||||||
pure (T.EAnn e' t, t)
|
|
||||||
|
|
||||||
|
EAnn e t -> do
|
||||||
|
(e', t1) <- infer cxt e
|
||||||
|
unless (typeEq t t1) $
|
||||||
|
throwError "Inferred type and type annotation doesn't match"
|
||||||
|
pure (e', t1)
|
||||||
|
|
||||||
|
-- | Check infered type matches the supplied type.
|
||||||
check :: Cxt -> Exp -> Type -> Err T.Exp
|
check :: Cxt -> Exp -> Type -> Err T.Exp
|
||||||
check cxt exp typ = case exp of
|
check cxt exp typ = case exp of
|
||||||
|
|
||||||
EId x -> do
|
EId x -> do
|
||||||
t <- case lookupEnv x cxt of
|
t <- case lookupEnv x cxt of
|
||||||
Nothing -> maybeToRightM
|
Nothing -> maybeToRightM
|
||||||
("Unbound variable:" ++ printTree x)
|
("Unbound variable:" ++ printTree x)
|
||||||
(lookupSig x cxt)
|
(lookupSig x cxt)
|
||||||
Just t -> pure t
|
Just t -> pure t
|
||||||
|
unless (typeEq t typ) . throwError $ typeErr x typ t
|
||||||
|
pure $ T.EId (x, t)
|
||||||
|
|
||||||
unless (typeEq t typ) . throwError $ typeErr x typ t
|
EInt i -> do
|
||||||
|
unless (typeEq typ TInt) $ throwError $ typeErr i TInt typ
|
||||||
|
pure $ T.EInt i
|
||||||
|
|
||||||
pure $ T.EId (x, t)
|
EApp e e1 -> do
|
||||||
|
(e', t) <- infer cxt e
|
||||||
|
case t of
|
||||||
|
TFun t1 t2 -> do
|
||||||
|
e1' <- check cxt e1 t1
|
||||||
|
pure $ T.EApp t2 e' e1'
|
||||||
|
_ -> throwError ("Not a function 2: " ++ printTree e)
|
||||||
|
|
||||||
EInt i -> do
|
EAdd e e1 -> do
|
||||||
unless (typeEq typ TInt) $ throwError $ typeErr i TInt typ
|
e' <- check cxt e T.TInt
|
||||||
pure $ T.EInt i
|
e1' <- check cxt e1 T.TInt
|
||||||
|
pure $ T.EAdd T.TInt e' e1'
|
||||||
|
|
||||||
EApp e e1 -> do
|
EAbs x t e -> do
|
||||||
(e', t) <- infer cxt e
|
(e', t_e) <- infer (insertEnv x t cxt) e
|
||||||
case t of
|
let t1 = TFun t t_e
|
||||||
TFun t1 t2 -> do
|
unless (typeEq t1 typ) $ throwError "Wrong lamda type!"
|
||||||
e1' <- check cxt e1 t1
|
pure $ T.EAbs t1 (x, t) e'
|
||||||
pure $ T.EApp t2 e' e1'
|
|
||||||
_ -> throwError ("Not a function 2: " ++ printTree e)
|
|
||||||
|
|
||||||
EAdd e e1 -> do
|
ELet b e -> do
|
||||||
e' <- check cxt e T.TInt
|
let cxt' = insertBind b cxt
|
||||||
e1' <- check cxt e1 T.TInt
|
b' <- checkBind cxt' b
|
||||||
pure $ T.EAdd T.TInt e' e1'
|
e' <- check cxt' e typ
|
||||||
|
pure $ T.ELet b' e'
|
||||||
|
|
||||||
EAbs x t e -> do
|
EAnn e t -> do
|
||||||
(e', t_e) <- infer (insertEnv x t cxt) e
|
unless (typeEq t typ) $
|
||||||
let t1 = TFun t t_e
|
throwError "Inferred type and type annotation doesn't match"
|
||||||
unless (typeEq t1 typ) $ throwError "Wrong lamda type!"
|
check cxt e t
|
||||||
pure $ T.EAbs t1 (x, t) e'
|
|
||||||
|
|
||||||
ELet bs e -> do
|
|
||||||
bs'' <- mapM (checkBind cxt') bs'
|
|
||||||
e' <- check cxt' e typ
|
|
||||||
pure $ T.ELet bs'' e'
|
|
||||||
where
|
|
||||||
bs' = map expandLambdas bs
|
|
||||||
cxt' = foldr (\(Bind n t _ _ _) -> insertEnv n t) cxt bs'
|
|
||||||
|
|
||||||
EAnn e t -> do
|
|
||||||
unless (typeEq t typ) $
|
|
||||||
throwError "Inferred type and type annotation doesn't match"
|
|
||||||
e' <- check cxt e t
|
|
||||||
pure $ T.EAnn e' typ
|
|
||||||
|
|
||||||
|
-- | Check if types are equivalent. Doesn't handle coercion or polymorphism.
|
||||||
typeEq :: Type -> Type -> Bool
|
typeEq :: Type -> Type -> Bool
|
||||||
typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1
|
typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1
|
||||||
typeEq t t1 = t == t1
|
typeEq t t1 = t == t1
|
||||||
|
|
||||||
partitionType :: Int -> Type -> ([Type], Type)
|
-- | Partion type into types of parameters and return type.
|
||||||
|
partitionType :: Int -- Number of parameters to apply
|
||||||
|
-> Type
|
||||||
|
-> ([Type], Type)
|
||||||
partitionType = go []
|
partitionType = go []
|
||||||
where
|
where
|
||||||
go acc 0 t = (acc, t)
|
go acc 0 t = (acc, t)
|
||||||
go acc i t = case t of
|
go acc i t = case t of
|
||||||
TFun t1 t2 -> go (snoc t1 acc) (i - 1) t2
|
TFun t1 t2 -> go (snoc t1 acc) (i - 1) t2
|
||||||
_ -> error "Number of parameters and type doesn't match"
|
_ -> error "Number of parameters and type doesn't match"
|
||||||
|
|
||||||
|
insertBind :: Bind -> Cxt -> Cxt
|
||||||
|
insertBind (Bind n t _ _ _) = insertEnv n t
|
||||||
|
|
||||||
lookupEnv :: Ident -> Cxt -> Maybe Type
|
lookupEnv :: Ident -> Cxt -> Maybe Type
|
||||||
lookupEnv x = Map.lookup x . env
|
lookupEnv x = Map.lookup x . env
|
||||||
|
|
@ -174,7 +172,7 @@ lookupSig x = Map.lookup x . sig
|
||||||
|
|
||||||
typeErr :: Print a => a -> Type -> Type -> String
|
typeErr :: Print a => a -> Type -> Type -> String
|
||||||
typeErr p expected actual = render $ concatD
|
typeErr p expected actual = render $ concatD
|
||||||
[ doc $ showString "Wrong type:", prt 0 p , doc $ showString "\n"
|
[ doc $ showString "Wrong type:", prt 0 p , doc $ showString "\n"
|
||||||
, doc $ showString "Expected:" , prt 0 expected, doc $ showString "\n"
|
, doc $ showString "Expected:" , prt 0 expected, doc $ showString "\n"
|
||||||
, doc $ showString "Actual: " , prt 0 actual
|
, doc $ showString "Actual: " , prt 0 actual
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -16,30 +16,35 @@ newtype Program = Program [Bind]
|
||||||
data Exp
|
data Exp
|
||||||
= EId Id
|
= EId Id
|
||||||
| EInt Integer
|
| EInt Integer
|
||||||
| ELet [Bind] Exp
|
| ELet Bind Exp
|
||||||
| EApp Type Exp Exp
|
| EApp Type Exp Exp
|
||||||
| EAdd Type Exp Exp
|
| EAdd Type Exp Exp
|
||||||
| EAbs Type Id Exp
|
| EAbs Type Id Exp
|
||||||
| EAnn Exp Type
|
| ECased Exp [Case]
|
||||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||||
|
|
||||||
|
data Case = Case CLit Exp
|
||||||
|
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||||
|
|
||||||
|
data CLit = CInt Integer | CatchAll
|
||||||
|
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||||
type Id = (Ident, Type)
|
type Id = (Ident, Type)
|
||||||
|
|
||||||
data Bind = Bind Id [Id] Exp
|
data Bind = Bind Id [Id] Exp
|
||||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||||
|
|
||||||
instance Print Program where
|
instance Print Program where
|
||||||
prt i (Program sc) = prPrec i 0 $ prt 0 sc
|
prt i (Program sc) = prPrec i 0 $ prt 0 sc
|
||||||
|
|
||||||
instance Print Bind where
|
instance Print Bind where
|
||||||
prt i (Bind name@(n, _) parms rhs) = prPrec i 0 $ concatD
|
prt i (Bind name@(n, _) parms rhs) = prPrec i 0 $ concatD
|
||||||
[ prtId 0 name
|
[ prtId 0 name
|
||||||
, doc $ showString ";"
|
, doc $ showString ";"
|
||||||
, prt 0 n
|
, prt 0 n
|
||||||
, prtIdPs 0 parms
|
, prtIdPs 0 parms
|
||||||
, doc $ showString "="
|
, doc $ showString "="
|
||||||
, prt 0 rhs
|
, prt 0 rhs
|
||||||
]
|
]
|
||||||
|
|
||||||
instance Print [Bind] where
|
instance Print [Bind] where
|
||||||
prt _ [] = concatD []
|
prt _ [] = concatD []
|
||||||
|
|
@ -51,58 +56,51 @@ prtIdPs i = prPrec i 0 . concatD . map (prtIdP i)
|
||||||
|
|
||||||
prtId :: Int -> Id -> Doc
|
prtId :: Int -> Id -> Doc
|
||||||
prtId i (name, t) = prPrec i 0 $ concatD
|
prtId i (name, t) = prPrec i 0 $ concatD
|
||||||
[ prt 0 name
|
[ prt 0 name
|
||||||
, doc $ showString ":"
|
, doc $ showString ":"
|
||||||
, prt 0 t
|
, prt 0 t
|
||||||
]
|
]
|
||||||
|
|
||||||
prtIdP :: Int -> Id -> Doc
|
prtIdP :: Int -> Id -> Doc
|
||||||
prtIdP i (name, t) = prPrec i 0 $ concatD
|
prtIdP i (name, t) = prPrec i 0 $ concatD
|
||||||
[ doc $ showString "("
|
[ doc $ showString "("
|
||||||
, prt 0 name
|
, prt 0 name
|
||||||
, doc $ showString ":"
|
, doc $ showString ":"
|
||||||
, prt 0 t
|
, prt 0 t
|
||||||
, doc $ showString ")"
|
, doc $ showString ")"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
instance Print Exp where
|
instance Print Exp where
|
||||||
prt i = \case
|
prt i = \case
|
||||||
EId n -> prPrec i 3 $ concatD [prtIdP 0 n]
|
EId n -> prPrec i 3 $ concatD [prtIdP 0 n]
|
||||||
EInt i1 -> prPrec i 3 $ concatD [prt 0 i1]
|
EInt i1 -> prPrec i 3 $ concatD [prt 0 i1]
|
||||||
ELet bs e -> prPrec i 3 $ concatD
|
ELet bs e -> prPrec i 3 $ concatD
|
||||||
[ doc $ showString "let"
|
[ doc $ showString "let"
|
||||||
, prt 0 bs
|
, prt 0 bs
|
||||||
, doc $ showString "in"
|
, doc $ showString "in"
|
||||||
, prt 0 e
|
, prt 0 e
|
||||||
]
|
]
|
||||||
EApp t e1 e2 -> prPrec i 2 $ concatD
|
EApp t e1 e2 -> prPrec i 2 $ concatD
|
||||||
[ doc $ showString "@"
|
[ doc $ showString "@"
|
||||||
, prt 0 t
|
, prt 0 t
|
||||||
, prt 2 e1
|
, prt 2 e1
|
||||||
, prt 3 e2
|
, prt 3 e2
|
||||||
]
|
]
|
||||||
EAdd t e1 e2 -> prPrec i 1 $ concatD
|
EAdd t e1 e2 -> prPrec i 1 $ concatD
|
||||||
[ doc $ showString "@"
|
[ doc $ showString "@"
|
||||||
, prt 0 t
|
, prt 0 t
|
||||||
, prt 1 e1
|
, prt 1 e1
|
||||||
, doc $ showString "+"
|
, doc $ showString "+"
|
||||||
, prt 2 e2
|
, prt 2 e2
|
||||||
]
|
]
|
||||||
EAbs t n e -> prPrec i 0 $ concatD
|
EAbs t n e -> prPrec i 0 $ concatD
|
||||||
[ doc $ showString "@"
|
[ doc $ showString "@"
|
||||||
, prt 0 t
|
, prt 0 t
|
||||||
, doc $ showString "\\"
|
, doc $ showString "\\"
|
||||||
, prtIdP 0 n
|
, prtIdP 0 n
|
||||||
, doc $ showString "."
|
, doc $ showString "."
|
||||||
, prt 0 e
|
, prt 0 e
|
||||||
]
|
]
|
||||||
EAnn e t -> prPrec i 3 $ concatD
|
|
||||||
[ doc $ showString "("
|
|
||||||
, prt 0 e
|
|
||||||
, doc $ showString ":"
|
|
||||||
, prt 0 t
|
|
||||||
, doc $ showString ")"
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue