Merge branch 'prep-tc-martin' of github.com:bachelor-group-66-systemf/language into prep-tc-martin

This commit is contained in:
Samuel Hammersberg 2023-02-18 15:03:11 +01:00
commit a4c12ede79
8 changed files with 590 additions and 640 deletions

View file

@ -1,24 +1,20 @@
Program. Program ::= [Bind]; Program. Program ::= [Bind];
EId. Exp3 ::= Ident;
EId. Exp3 ::= Ident; EInt. Exp3 ::= Integer;
EInt. Exp3 ::= Integer; EAnn. Exp3 ::= "(" Exp ":" Type ")";
ELet. Exp3 ::= "let" [Bind] "in" Exp; ELet. Exp3 ::= "let" Bind "in" Exp;
EApp. Exp2 ::= Exp2 Exp3; EApp. Exp2 ::= Exp2 Exp3;
EAdd. Exp1 ::= Exp1 "+" Exp2; EAdd. Exp1 ::= Exp1 "+" Exp2;
EAbs. Exp ::= "\\" Ident ":" Type "." Exp; EAbs. Exp ::= "\\" Ident ":" Type "." Exp;
EAnn. Exp3 ::= "(" Exp ":" Type ")";
ECase. Exp ::= "case" Exp "of" "{" [CaseMatch] "}"; ECase. Exp ::= "case" Exp "of" "{" [CaseMatch] "}";
--
CaseMatch. CaseMatch ::= Case "=>" Exp ; CaseMatch. CaseMatch ::= Case "=>" Exp ;
separator CaseMatch ","; separator CaseMatch ",";
--terminator CaseMatch ".";
CInt. Case ::= Integer ; CInt. Case ::= Integer ;
Bind. Bind ::= Ident ":" Type ";" Bind. Bind ::= Ident ":" Type ";"
Ident [Ident] "=" Exp ; Ident [Ident] "=" Exp;
separator Bind ";"; separator Bind ";";
separator Ident ""; separator Ident "";

View file

@ -1,20 +1,21 @@
--tripplemagic : Int -> Int -> Int -> Int; -- tripplemagic : Int -> Int -> Int -> Int;
--tripplemagic x y z = ((\x:Int. x+x) x) + y + z; -- tripplemagic x y z = ((\x:Int. x+x) x) + y + z;
--main : Int; -- main : Int;
--main = tripplemagic ((\x:Int. x+x+3) ((\x:Int. x) 2)) 5 3 -- main = tripplemagic ((\x:Int. x+x+3) ((\x:Int. x) 2)) 5 3
-- answer: 22
--apply : (Int -> Int) -> Int -> Int;
--apply f x = f x;
--
--main : Int;
--main = (\x : Int . x + 5) 5
-- apply : (Int -> Int) -> Int -> Int;
-- apply f x = f x;
-- main : Int;
-- main = apply (\x : Int . x + 5) 5
-- answer: 10
apply : (Int -> Int -> Int) -> Int -> Int -> Int; apply : (Int -> Int -> Int) -> Int -> Int -> Int;
apply f x y = f x y; apply f x y = f x y;
krimp: Int -> Int -> Int; krimp: Int -> Int -> Int;
krimp x y = x + y; krimp x y = x + y;
main : Int; main : Int;
main = apply (krimp) 2 3;--apply (\y: Int . (\x: Int . x + y + 2)) 5 2; main = apply (krimp) 2 3;
-- answer: 5

View file

@ -3,22 +3,21 @@
module Compiler (compile) where module Compiler (compile) where
import Control.Monad.State (StateT, execStateT, gets, modify) import Auxiliary (snoc)
import Data.List.Extra (trim) import Control.Monad.State (StateT, execStateT, gets, modify)
import Data.Map (Map) --import Data.List.Extra (trim)
import qualified Data.Map as Map import Data.Map (Map)
import Data.Tuple.Extra (second) import qualified Data.Map as Map
import Grammar.ErrM (Err) import Data.Tuple.Extra (dupe, first, second)
import Grammar.Print (printTree) import Grammar.ErrM (Err)
import LlvmIr (LLVMComp (..), LLVMIr (..), import LlvmIr (LLVMComp (..), LLVMIr (..), LLVMType (..),
LLVMType (..), LLVMValue (..), LLVMValue (..), Visibility (..),
Visibility (..), llvmIrToString) llvmIrToString)
import System.IO (stdin) --import System.Process.Extra (readCreateProcess, shell)
import System.Process.Extra (CreateProcess (std_in), import TypeChecker (partitionType)
StdStream (CreatePipe), createProcess, import TypeCheckerIr (Bind (..), CLit (CInt, CatchAll),
readCreateProcess, shell) Case (..), Exp (..), Id, Ident (..),
import TypeChecker (partitionType) Program (..), Type (TFun, TInt))
import TypeCheckerIr
-- | The record used as the code generator state -- | The record used as the code generator state
data CodeGenerator = CodeGenerator data CodeGenerator = CodeGenerator
@ -38,11 +37,11 @@ data FunctionInfo = FunctionInfo
-- | Adds a instruction to the CodeGenerator state -- | Adds a instruction to the CodeGenerator state
emit :: LLVMIr -> CompilerState () emit :: LLVMIr -> CompilerState ()
emit l = modify (\t -> t{instructions = instructions t ++ [l]}) emit l = modify $ \t -> t { instructions = Auxiliary.snoc l $ instructions t }
-- | Increases the variable counter in the CodeGenerator state -- | Increases the variable counter in the CodeGenerator state
increaseVarCount :: CompilerState () increaseVarCount :: CompilerState ()
increaseVarCount = modify (\t -> t{variableCount = variableCount t + 1}) increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 }
-- | Returns the variable count from the CodeGenerator state -- | Returns the variable count from the CodeGenerator state
getVarCount :: CompilerState Integer getVarCount :: CompilerState Integer
@ -58,280 +57,272 @@ getNewLabel = do
modify (\t -> t{labelCount = labelCount t + 1}) modify (\t -> t{labelCount = labelCount t + 1})
gets labelCount gets labelCount
{- | Produces a map of functions infos from a list of binds, -- | Produces a map of functions infos from a list of binds,
which contains useful data for code generation. -- which contains useful data for code generation.
-}
getFunctions :: [Bind] -> Map Id FunctionInfo getFunctions :: [Bind] -> Map Id FunctionInfo
getFunctions xs = getFunctions bs = Map.fromList $ map go bs
Map.fromList $ where
map go (Bind id args _) =
( \(Bind id args _) -> (id, FunctionInfo { numArgs=length args, arguments=args })
( id
, FunctionInfo
{ numArgs = length args
, arguments = args
}
)
)
xs
run :: Err String -> IO ()
run s = do
let s' = case s of initCodeGenerator :: [Bind] -> CodeGenerator
Right s -> s initCodeGenerator scs = CodeGenerator { instructions = defaultStart
Left _ -> error "yo" , functions = getFunctions scs
writeFile "llvm.ll" s' , variableCount = 0
putStrLn . trim =<< readCreateProcess (shell "lli") s' , labelCount = 0
test :: Integer -> Program }
test v = Program [
Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] ( --run :: Err String -> IO ()
ECased (EId ("x", TInt)) [ --run s = do
Case (CInt 0) (EInt 0), -- let s' = case s of
Case (CInt 1) (EInt 1), -- Right s -> s
Case CatchAll (EAdd TInt -- Left _ -> error "yo"
(EApp TInt (EId (Ident "fibonacci", TInt)) ( -- writeFile "llvm.ll" s'
EAdd TInt (EId (Ident "x", TInt)) -- putStrLn . trim =<< readCreateProcess (shell "lli") s'
(EInt (fromIntegral ((maxBound :: Int) * 2))) --
)) --test :: Integer -> Program
(EApp TInt (EId (Ident "fibonacci", TInt)) ( --test v = Program [
EAdd TInt (EId (Ident "x", TInt)) -- Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] (
(EInt (fromIntegral ((maxBound :: Int) * 2 + 1))) -- ECased (EId ("x", TInt)) [
)) -- Case (CInt 0) (EInt 0),
) -- Case (CInt 1) (EInt 1),
] -- Case CatchAll (EAdd TInt
), -- (EApp TInt (EId (Ident "fibonacci", TInt)) (
Bind (Ident "main",TInt) [] ( -- EAdd TInt (EId (Ident "x", TInt))
EApp TInt (EId (Ident "fibonacci", TInt)) (EInt v) -- (EInt 92) -- (EInt (fromIntegral ((maxBound :: Int) * 2)))
) -- ))
] -- (EApp TInt (EId (Ident "fibonacci", TInt)) (
-- EAdd TInt (EId (Ident "x", TInt))
-- (EInt (fromIntegral ((maxBound :: Int) * 2 + 1)))
-- ))
-- )
-- ]
-- ),
-- Bind (Ident "main",TInt) [] (
-- EApp TInt (EId (Ident "fibonacci", TInt)) (EInt v) -- (EInt 92)
-- )
-- ]
{- | Compiles an AST and produces a LLVM Ir string. {- | Compiles an AST and produces a LLVM Ir string.
An easy way to actually "compile" this output is to An easy way to actually "compile" this output is to
Simply pipe it to LLI Simply pipe it to LLI
-} -}
compile :: Program -> Err String compile :: Program -> Err String
compile (Program prg) = do compile (Program scs) = do
let s = let codegen = initCodeGenerator scs
CodeGenerator llvmIrToString . instructions <$> execStateT (compileScs scs) codegen
{ instructions = defaultStart
, functions = getFunctions prg compileScs :: [Bind] -> CompilerState ()
, variableCount = 0 compileScs [] = pure ()
, labelCount = 0 compileScs (Bind (name, t) args exp : xs) = do
} emit $ UnsafeRaw "\n"
ins <- instructions <$> execStateT (goDef prg) s emit . Comment $ show name <> ": " <> show exp
pure $ llvmIrToString ins let args' = map (second type2LlvmType) args
emit $ Define (type2LlvmType t_return) name args'
functionBody <- exprToValue exp
if name == "main"
then mapM_ emit $ mainContent functionBody
else emit $ Ret I64 functionBody
emit DefineEnd
modify $ \s -> s { variableCount = 0 }
compileScs xs
where where
mainContent :: LLVMValue -> [LLVMIr] t_return = snd $ partitionType (length args) t
mainContent var =
[ UnsafeRaw $
"call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n"
, -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
-- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2")
-- , Label (Ident "b_1")
-- , UnsafeRaw
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n"
-- , Br (Ident "end")
-- , Label (Ident "b_2")
-- , UnsafeRaw
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n"
-- , Br (Ident "end")
-- , Label (Ident "end")
Ret I64 (VInteger 0)
]
defaultStart :: [LLVMIr] mainContent :: LLVMValue -> [LLVMIr]
defaultStart = mainContent var =
[ Comment (show $ printTree (Program prg)) [ UnsafeRaw $
, UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n" "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n"
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
] -- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2")
-- , Label (Ident "b_1")
-- , UnsafeRaw
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n"
-- , Br (Ident "end")
-- , Label (Ident "b_2")
-- , UnsafeRaw
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n"
-- , Br (Ident "end")
-- , Label (Ident "end")
Ret I64 (VInteger 0)
]
goDef :: [Bind] -> CompilerState () defaultStart :: [LLVMIr]
goDef [] = return () defaultStart = [ UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
goDef (Bind (name, t) args exp : xs) = do , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
emit $ UnsafeRaw "\n" ]
emit $ Comment $ show name <> ": " <> show exp
emit $ Define (I64{-type2LlvmType t_return-}) name (map (second type2LlvmType) args)
functionBody <- exprToValue exp
if name == "main"
then mapM_ emit (mainContent functionBody)
else emit $ Ret I64 functionBody
emit DefineEnd
modify (\s -> s{variableCount = 0})
goDef xs
where
t_return = snd $ partitionType (length args) t
go :: Exp -> CompilerState () compileExp :: Exp -> CompilerState ()
go (EInt int) = emitInt int compileExp (EInt int) = emitInt int
go (EAdd t e1 e2) = emitAdd t e1 e2 compileExp (EAdd t e1 e2) = emitAdd t e1 e2
go (EId (name, _)) = emitIdent name compileExp (EId (name, _)) = emitIdent name
go (EApp t e1 e2) = emitApp t e1 e2 compileExp (EApp t e1 e2) = emitApp t e1 e2
go (EAbs t ti e) = emitAbs t ti e compileExp (EAbs t ti e) = emitAbs t ti e
go (ELet binds e) = emitLet binds e compileExp (ELet binds e) = emitLet binds e
go (EAnn _ _) = emitEAnn compileExp (ECased e c) = emitECased e c
go (ECased e c) = emitECased e c
-- go (ESub e1 e2) = emitSub e1 e2 -- go (ESub e1 e2) = emitSub e1 e2
-- go (EMul e1 e2) = emitMul e1 e2 -- go (EMul e1 e2) = emitMul e1 e2
-- go (EDiv e1 e2) = emitDiv e1 e2 -- go (EDiv e1 e2) = emitDiv e1 e2
-- go (EMod e1 e2) = emitMod e1 e2 -- go (EMod e1 e2) = emitMod e1 e2
--- aux functions --- --- aux functions ---
emitECased :: Exp -> [Case] -> CompilerState () emitECased :: Exp -> [Case] -> CompilerState ()
emitECased e cs = do emitECased e cs = do
vs <- exprToValue e vs <- exprToValue e
lbl <- getNewLabel lbl <- getNewLabel
let label = Ident $ "escape_" <> show lbl let label = Ident $ "escape_" <> show lbl
stackPtr <- getNewVar stackPtr <- getNewVar
emit $ SetVariable (Ident $ show stackPtr) (Alloca I64) emit $ SetVariable (Ident $ show stackPtr) (Alloca I64)
mapM_ (emitCases label stackPtr vs) cs mapM_ (emitCases label stackPtr vs) cs
emit $ Label label emit $ Label label
res <- getNewVar res <- getNewVar
emit $ SetVariable (Ident $ show res) (Load I64 Ptr (Ident $ show stackPtr)) emit $ SetVariable (Ident $ show res) (Load I64 Ptr (Ident $ show stackPtr))
where where
emitCases :: Ident -> Integer -> LLVMValue -> Case -> CompilerState () emitCases :: Ident -> Integer -> LLVMValue -> Case -> CompilerState ()
emitCases label stackPtr vs (Case (CInt i) exp) = do emitCases label stackPtr vs (Case (CInt i) exp) = do
ns <- getNewVar ns <- getNewVar
lbl_fail <- getNewLabel lbl_fail <- getNewLabel
lbl_succ <- getNewLabel lbl_succ <- getNewLabel
let failed = Ident $ "failed_" <> show lbl_fail let failed = Ident $ "failed_" <> show lbl_fail
let success = Ident $ "success_" <> show lbl_succ let success = Ident $ "success_" <> show lbl_succ
emit $ SetVariable (Ident $ show ns) (Icmp LLEq I64 vs (VInteger i)) emit $ SetVariable (Ident $ show ns) (Icmp LLEq I64 vs (VInteger i))
emit $ BrCond (VIdent (Ident $ show ns) I64) success failed emit $ BrCond (VIdent (Ident $ show ns) I64) success failed
emit $ Label success emit $ Label success
val <- exprToValue exp val <- exprToValue exp
emit $ Store I64 val Ptr (Ident . show $ stackPtr) emit $ Store I64 val Ptr (Ident . show $ stackPtr)
emit $ Br label emit $ Br label
emit $ Label failed emit $ Label failed
emitCases label stackPtr _ (Case CatchAll exp) = do emitCases label stackPtr _ (Case CatchAll exp) = do
val <- exprToValue exp val <- exprToValue exp
emit $ Store I64 val Ptr (Ident . show $ stackPtr) emit $ Store I64 val Ptr (Ident . show $ stackPtr)
emit $ Br label emit $ Br label
emitEAnn :: CompilerState () emitAbs :: Type -> Id -> Exp -> CompilerState ()
emitEAnn = emit . UnsafeRaw $ "Annotated escaped previous stages" emitAbs _t tid e = do
emit . Comment $
"Lambda escaped previous stages: \\" <> show tid <> " . " <> show e
emitLet :: Bind -> Exp -> CompilerState ()
emitLet xs e = do
emit $
Comment $
concat
[ "ELet ("
, show xs
, " = "
, show e
, ") is not implemented!"
]
emitAbs :: Type -> Id -> Exp -> CompilerState () emitApp :: Type -> Exp -> Exp -> CompilerState ()
emitAbs _t tid e = do emitApp t e1 e2 = appEmitter t e1 e2 []
emit . Comment $ where
"Lambda escaped previous stages: \\" <> show tid <> " . " <> show e appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState ()
emitLet :: [Bind] -> Exp -> CompilerState () appEmitter t e1 e2 stack = do
emitLet xs e = do let newStack = e2 : stack
emit $ case e1 of
Comment $ EApp _ e1' e2' -> appEmitter t e1' e2' newStack
concat EId id@(name, _) -> do
[ "ELet (" args <- traverse exprToValue newStack
, show xs vs <- getNewVar
, " = " funcs <- gets functions
, show e let visibility = maybe Local (const Global) $ Map.lookup id funcs
, ") is not implemented!" args' = map (first valueGetType . dupe) args
] call = Call (type2LlvmType t) visibility name args'
emit $ SetVariable (Ident $ show vs) call
x -> do
emit . Comment $ "The unspeakable happened: "
emit . Comment $ show x
emitApp :: Type -> Exp -> Exp -> CompilerState () emitIdent :: Ident -> CompilerState ()
emitApp t e1 e2 = appEmitter t e1 e2 [] emitIdent id = do
where -- !!this should never happen!!
appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState () emit $ Comment "This should not have happened!"
appEmitter t e1 e2 stack = do emit $ Variable id
let newStack = e2 : stack emit $ UnsafeRaw "\n"
case e1 of
EApp _ e1' e2' -> appEmitter t e1' e2' newStack
EId id@(name, _) -> do
args <- traverse exprToValue newStack
vs <- getNewVar
funcs <- gets functions
let vis = case Map.lookup id funcs of
Nothing -> Local
Just _ -> Global
let call = Call (type2LlvmType t) vis name ((\x -> (valueGetType x, x)) <$> args)
emit $ SetVariable (Ident $ show vs) call
x -> do
emit . Comment $ "The unspeakable happened: "
emit . Comment $ show x
emitIdent :: Ident -> CompilerState () emitInt :: Integer -> CompilerState ()
emitIdent id = do emitInt i = do
-- !!this should never happen!! -- !!this should never happen!!
emit $ Comment "This should not have happened!" varCount <- getNewVar
emit $ Variable id emit $ Comment "This should not have happened!"
emit $ UnsafeRaw "\n" emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0))
emitInt :: Integer -> CompilerState () emitAdd :: Type -> Exp -> Exp -> CompilerState ()
emitInt i = do emitAdd t e1 e2 = do
-- !!this should never happen!! v1 <- exprToValue e1
varCount <- getNewVar v2 <- exprToValue e2
emit $ Comment "This should not have happened!" v <- getNewVar
emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0)) emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2)
emitAdd :: Type -> Exp -> Exp -> CompilerState () -- emitMul :: Exp -> Exp -> CompilerState ()
emitAdd t e1 e2 = do -- emitMul e1 e2 = do
v1 <- exprToValue e1 -- (v1,v2) <- binExprToValues e1 e2
v2 <- exprToValue e2 -- increaseVarCount
v <- getNewVar -- v <- gets variableCount
emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2) -- emit $ SetVariable $ Ident $ show v
-- emit $ Mul I64 v1 v2
-- emitMul :: Exp -> Exp -> CompilerState () -- emitMod :: Exp -> Exp -> CompilerState ()
-- emitMul e1 e2 = do -- emitMod e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2 -- -- `let m a b = rem (abs $ b + a) b`
-- increaseVarCount -- (v1,v2) <- binExprToValues e1 e2
-- v <- gets variableCount -- increaseVarCount
-- emit $ SetVariable $ Ident $ show v -- vadd <- gets variableCount
-- emit $ Mul I64 v1 v2 -- emit $ SetVariable $ Ident $ show vadd
-- emit $ Add I64 v1 v2
--
-- increaseVarCount
-- vabs <- gets variableCount
-- emit $ SetVariable $ Ident $ show vabs
-- emit $ Call I64 (Ident "llvm.abs.i64")
-- [ (I64, VIdent (Ident $ show vadd))
-- , (I1, VInteger 1)
-- ]
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2
-- emitMod :: Exp -> Exp -> CompilerState () -- emitDiv :: Exp -> Exp -> CompilerState ()
-- emitMod e1 e2 = do -- emitDiv e1 e2 = do
-- -- `let m a b = rem (abs $ b + a) b` -- (v1,v2) <- binExprToValues e1 e2
-- (v1,v2) <- binExprToValues e1 e2 -- increaseVarCount
-- increaseVarCount -- v <- gets variableCount
-- vadd <- gets variableCount -- emit $ SetVariable $ Ident $ show v
-- emit $ SetVariable $ Ident $ show vadd -- emit $ Div I64 v1 v2
-- emit $ Add I64 v1 v2
--
-- increaseVarCount
-- vabs <- gets variableCount
-- emit $ SetVariable $ Ident $ show vabs
-- emit $ Call I64 (Ident "llvm.abs.i64")
-- [ (I64, VIdent (Ident $ show vadd))
-- , (I1, VInteger 1)
-- ]
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2
-- emitDiv :: Exp -> Exp -> CompilerState () -- emitSub :: Exp -> Exp -> CompilerState ()
-- emitDiv e1 e2 = do -- emitSub e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2 -- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount -- increaseVarCount
-- v <- gets variableCount -- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v -- emit $ SetVariable $ Ident $ show v
-- emit $ Div I64 v1 v2 -- emit $ Sub I64 v1 v2
-- emitSub :: Exp -> Exp -> CompilerState () exprToValue :: Exp -> CompilerState LLVMValue
-- emitSub e1 e2 = do exprToValue = \case
-- (v1,v2) <- binExprToValues e1 e2 EInt i -> pure $ VInteger i
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Sub I64 v1 v2
exprToValue :: Exp -> CompilerState LLVMValue EId id@(name, t) -> do
exprToValue (EInt i) = return $ VInteger i
exprToValue (EId id@(name, t)) = do
funcs <- gets functions funcs <- gets functions
case Map.lookup id funcs of case Map.lookup id funcs of
Just fi -> do Just fi -> do
if numArgs fi == 0 if numArgs fi == 0
then do then do
vc <- getNewVar vc <- getNewVar
emit $ SetVariable (Ident $ show vc) (Call (type2LlvmType t) Global name []) emit $ SetVariable (Ident $ show vc)
return $ VIdent (Ident $ show vc) (type2LlvmType t) (Call (type2LlvmType t) Global name [])
else return $ VFunction name Global (type2LlvmType t) pure $ VIdent (Ident $ show vc) (type2LlvmType t)
Nothing -> return $ VIdent name (type2LlvmType t) else pure $ VFunction name Global (type2LlvmType t)
exprToValue e = do Nothing -> pure $ VIdent name (type2LlvmType t)
go e
e -> do
compileExp e
v <- getVarCount v <- getVarCount
return $ VIdent (Ident $ show v) (getType e) pure $ VIdent (Ident $ show v) (getType e)
type2LlvmType :: Type -> LLVMType type2LlvmType :: Type -> LLVMType
type2LlvmType = \case type2LlvmType = \case
@ -346,13 +337,13 @@ type2LlvmType = \case
function2LLVMType x s = (type2LlvmType x, s) function2LLVMType x s = (type2LlvmType x, s)
getType :: Exp -> LLVMType getType :: Exp -> LLVMType
getType (EInt _) = I64 getType (EInt _) = I64
getType (EAdd t _ _) = type2LlvmType t getType (EAdd t _ _) = type2LlvmType t
getType (EId (_, t)) = type2LlvmType t getType (EId (_, t)) = type2LlvmType t
getType (EApp t _ _) = type2LlvmType t getType (EApp t _ _) = type2LlvmType t
getType (EAbs t _ _) = type2LlvmType t getType (EAbs t _ _) = type2LlvmType t
getType (ELet _ e) = getType e getType (ELet _ e) = getType e
getType (EAnn _ t) = type2LlvmType t getType (ECased e cs) = undefined
valueGetType :: LLVMValue -> LLVMType valueGetType :: LLVMValue -> LLVMType
valueGetType (VInteger _) = I64 valueGetType (VInteger _) = I64

View file

@ -7,16 +7,18 @@ module LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where
import Auxiliary (snoc) import Auxiliary (snoc)
import Control.Applicative (Applicative (liftA2)) import Control.Applicative (Applicative (liftA2))
import Control.Monad.State (MonadState (get, put), State, evalState) import Control.Monad.State (MonadState (get, put), State, evalState)
import Data.Foldable.Extra (notNull) import Data.Set (Set)
import Data.List (mapAccumL, partition)
import Data.Set (Set, (\\))
import qualified Data.Set as Set import qualified Data.Set as Set
import Prelude hiding (exp) import Prelude hiding (exp)
import Renamer hiding (fromBinders) import Renamer
import TypeCheckerIr import TypeCheckerIr
-- | Lift lambdas and let expression into supercombinators. -- | Lift lambdas and let expression into supercombinators.
-- Three phases:
-- @freeVars@ annotatss all the free variables.
-- @abstract@ converts lambdas into let expressions.
-- @collectScs@ moves every non-constant let expression to a top-level function.
lambdaLift :: Program -> Program lambdaLift :: Program -> Program
lambdaLift = collectScs . abstract . freeVars lambdaLift = collectScs . abstract . freeVars
@ -29,55 +31,41 @@ freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e)
freeVarsExp :: Set Id -> Exp -> AnnExp freeVarsExp :: Set Id -> Exp -> AnnExp
freeVarsExp localVars = \case freeVarsExp localVars = \case
EId n | Set.member n localVars -> (Set.singleton n, AId n)
| otherwise -> (mempty, AId n)
EId n | Set.member n localVars -> (Set.singleton n, AId n) EInt i -> (mempty, AInt i)
| otherwise -> (mempty, AId n)
EInt i -> (mempty, AInt i) EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2')
where
e1' = freeVarsExp localVars e1
e2' = freeVarsExp localVars e2
EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2') EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2')
where where
e1' = freeVarsExp localVars e1 e1' = freeVarsExp localVars e1
e2' = freeVarsExp localVars e2 e2' = freeVarsExp localVars e2
EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2') EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e')
where where
e1' = freeVarsExp localVars e1 e' = freeVarsExp (Set.insert par localVars) e
e2' = freeVarsExp localVars e2
EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e') -- Sum free variables present in bind and the expression
where ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e')
e' = freeVarsExp (Set.insert par localVars) e where
binders_frees = Set.delete name $ freeVarsOf rhs'
e_free = Set.delete name $ freeVarsOf e'
-- Sum free variables present in binders and the expression rhs' = freeVarsExp e_localVars rhs
ELet binders e -> (Set.union binders_frees e_free, ALet binders' e') new_bind = ABind name parms rhs'
where
binders_frees = rhss_frees \\ names_set
e_free = freeVarsOf e' \\ names_set
rhss_frees = foldr1 Set.union (map freeVarsOf rhss') e' = freeVarsExp e_localVars e
names_set = Set.fromList names e_localVars = Set.insert name localVars
(names, parms, rhss) = fromBinders binders
rhss' = map (freeVarsExp e_localVars) rhss
e_localVars = Set.union localVars names_set
binders' = zipWith3 ABind names parms rhss'
e' = freeVarsExp e_localVars e
EAnn e t -> (freeVarsOf e', AAnn e' t)
where
e' = freeVarsExp localVars e
freeVarsOf :: AnnExp -> Set Id freeVarsOf :: AnnExp -> Set Id
freeVarsOf = fst freeVarsOf = fst
fromBinders :: [Bind] -> ([Id], [[Id]], [Exp])
fromBinders bs = unzip3 [ (name, parms, rhs) | Bind name parms rhs <- bs ]
-- AST annotated with free variables -- AST annotated with free variables
type AnnProgram = [(Id, [Id], AnnExp)] type AnnProgram = [(Id, [Id], AnnExp)]
@ -87,23 +75,20 @@ data ABind = ABind Id [Id] AnnExp deriving Show
data AnnExp' = AId Id data AnnExp' = AId Id
| AInt Integer | AInt Integer
| ALet [ABind] AnnExp | ALet ABind AnnExp
| AApp Type AnnExp AnnExp | AApp Type AnnExp AnnExp
| AAdd Type AnnExp AnnExp | AAdd Type AnnExp AnnExp
| AAbs Type Id AnnExp | AAbs Type Id AnnExp
| AAnn AnnExp Type
deriving Show deriving Show
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@. -- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
-- Free variables are @v₁ v₂ .. vₙ@ are bound. -- Free variables are @v₁ v₂ .. vₙ@ are bound.
abstract :: AnnProgram -> Program abstract :: AnnProgram -> Program
abstract prog = Program $ evalState (mapM go prog) 0 abstract prog = Program $ evalState (mapM go prog) 0
where where
go :: (Id, [Id], AnnExp) -> State Int Bind go :: (Id, [Id], AnnExp) -> State Int Bind
go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs' go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs'
where where
(rhs', parms1) = flattenLambdasAnn rhs (rhs', parms1) = flattenLambdasAnn rhs
-- | Flatten nested lambdas and collect the parameters -- | Flatten nested lambdas and collect the parameters
@ -113,112 +98,93 @@ flattenLambdasAnn ae = go (ae, [])
where where
go :: (AnnExp, [Id]) -> (AnnExp, [Id]) go :: (AnnExp, [Id]) -> (AnnExp, [Id])
go ((free, e), acc) = go ((free, e), acc) =
case e of case e of
AAbs _ par (free1, e1) -> AAbs _ par (free1, e1) ->
go ((Set.delete par free1, e1), snoc par acc) go ((Set.delete par free1, e1), snoc par acc)
_ -> ((free, e), acc) _ -> ((free, e), acc)
abstractExp :: AnnExp -> State Int Exp abstractExp :: AnnExp -> State Int Exp
abstractExp (free, exp) = case exp of abstractExp (free, exp) = case exp of
AId n -> pure $ EId n AId n -> pure $ EId n
AInt i -> pure $ EInt i AInt i -> pure $ EInt i
AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2) AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2)
AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2) AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2)
ALet bs e -> liftA2 ELet (mapM go bs) (abstractExp e) ALet b e -> liftA2 ELet (go b) (abstractExp e)
where where
go (ABind name parms rhs) = do go (ABind name parms rhs) = do
(rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
pure $ Bind name (parms ++ parms1) rhs' pure $ Bind name (parms ++ parms1) rhs'
skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp
skipLambdas f (free, ae) = case ae of skipLambdas f (free, ae) = case ae of
AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1 AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1
_ -> f (free, ae) _ -> f (free, ae)
-- Lift lambda into let and bind free variables -- Lift lambda into let and bind free variables
AAbs t parm e -> do AAbs t parm e -> do
i <- nextNumber i <- nextNumber
rhs <- abstractExp e rhs <- abstractExp e
let sc_name = Ident ("sc_" ++ show i) let sc_name = Ident ("sc_" ++ show i)
sc = ELet [Bind (sc_name, t) parms rhs] $ EId (sc_name, t) sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t)
pure $ foldl (EApp TInt) sc $ map EId freeList pure $ foldl (EApp TInt) sc $ map EId freeList
where where
freeList = Set.toList free freeList = Set.toList free
parms = snoc parm freeList parms = snoc parm freeList
AAnn e t -> abstractExp e >>= \e' -> pure $ EAnn e' t
nextNumber :: State Int Int nextNumber :: State Int Int
nextNumber = do nextNumber = do
i <- get i <- get
put $ succ i put $ succ i
pure i pure i
-- | Collects supercombinators by lifting appropriate let expressions -- | Collects supercombinators by lifting non-constant let expressions
collectScs :: Program -> Program collectScs :: Program -> Program
collectScs (Program scs) = Program $ concatMap collectFromRhs scs collectScs (Program scs) = Program $ concatMap collectFromRhs scs
where where
collectFromRhs (Bind name parms rhs) = collectFromRhs (Bind name parms rhs) =
let (rhs_scs, rhs') = collectScsExp rhs let (rhs_scs, rhs') = collectScsExp rhs
in Bind name parms rhs' : rhs_scs in Bind name parms rhs' : rhs_scs
collectScsExp :: Exp -> ([Bind], Exp) collectScsExp :: Exp -> ([Bind], Exp)
collectScsExp = \case collectScsExp = \case
EId n -> ([], EId n) EId n -> ([], EId n)
EInt i -> ([], EInt i) EInt i -> ([], EInt i)
EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2') EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2')
where where
(scs1, e1') = collectScsExp e1 (scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2 (scs2, e2') = collectScsExp e2
EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2') EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2')
where where
(scs1, e1') = collectScsExp e1 (scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2 (scs2, e2') = collectScsExp e2
EAbs t par e -> (scs, EAbs t par e') EAbs t par e -> (scs, EAbs t par e')
where where
(scs, e') = collectScsExp e (scs, e') = collectScsExp e
-- Collect supercombinators from binds, the rhss, and the expression. -- Collect supercombinators from bind, the rhss, and the expression.
-- --
-- > f = let -- > f = let sc x y = rhs in e
-- > sc = rhs --
-- > sc1 = rhs1 ELet (Bind name parms rhs) e -> if null parms
-- > ... then ( rhs_scs ++ e_scs, ELet bind e')
-- > in e else (bind : rhs_scs ++ e_scs, e')
-- where
ELet binds e -> (binds_scs ++ rhss_scs ++ e_scs, mkEAbs non_scs' e') bind = Bind name parms rhs'
where (rhs_scs, rhs') = collectScsExp rhs
binds_scs = [ let (rhs', parms1) = flattenLambdas rhs in (e_scs, e') = collectScsExp e
Bind n (parms ++ parms1) rhs'
| Bind n parms rhs <- scs'
]
(rhss_scs, binds') = mapAccumL collectScsRhs [] binds
(e_scs, e') = collectScsExp e
(scs', non_scs') = partition (\(Bind _ pars _) -> notNull pars) binds'
collectScsRhs acc (Bind n xs rhs) = (acc ++ rhs_scs, Bind n xs rhs')
where
(rhs_scs, rhs') = collectScsExp rhs
EAnn e t -> (scs, EAnn e' t)
where
(scs, e') = collectScsExp e
-- @\x.\y.\z. e → (e, [x,y,z])@ -- @\x.\y.\z. e → (e, [x,y,z])@
flattenLambdas :: Exp -> (Exp, [Id]) flattenLambdas :: Exp -> (Exp, [Id])
flattenLambdas = go . (, []) flattenLambdas = go . (, [])
where where
go (e, acc) = case e of go (e, acc) = case e of
EAbs _ par e1 -> go (e1, snoc par acc) EAbs _ par e1 -> go (e1, snoc par acc)
_ -> (e, acc) _ -> (e, acc)
mkEAbs :: [Bind] -> Exp -> Exp
mkEAbs [] e = e
mkEAbs bs e = ELet bs e

View file

@ -51,8 +51,8 @@ data LLVMComp
instance Show LLVMComp where instance Show LLVMComp where
show :: LLVMComp -> String show :: LLVMComp -> String
show = \case show = \case
LLEq -> "eq" LLEq -> "eq"
LLNe -> "ne" LLNe -> "ne"
LLUgt -> "ugt" LLUgt -> "ugt"
LLUge -> "uge" LLUge -> "uge"
LLUlt -> "ult" LLUlt -> "ult"
@ -68,9 +68,8 @@ instance Show Visibility where
show Local = "%" show Local = "%"
show Global = "@" show Global = "@"
{- | Represents a LLVM "value", as in an integer, a register variable, -- | Represents a LLVM "value", as in an integer, a register variable,
or a string contstant -- or a string contstant
-}
data LLVMValue data LLVMValue
= VInteger Integer = VInteger Integer
| VIdent Ident LLVMType | VIdent Ident LLVMType

View file

@ -2,82 +2,83 @@
module Renamer (module Renamer) where module Renamer (module Renamer) where
import Data.List (mapAccumL, unzip4, zipWith4) import Auxiliary (mapAccumM)
import Data.Map (Map) import Control.Monad.State (MonadState, State, evalState, gets,
import qualified Data.Map as Map modify)
import Data.Maybe (fromMaybe) import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe)
import Grammar.Abs import Grammar.Abs
-- | Rename all supercombinators and variables -- | Rename all variables and local binds
rename :: Program -> Program rename :: Program -> Program
rename (Program sc) = Program $ map (renameSc 0) sc rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0
where where
renameSc i (Bind n t _ xs e) = Bind n t n xs' e' initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs
where renameSc :: Names -> Bind -> Rn Bind
(i1, xs', env) = newNames i xs renameSc old_names (Bind name t _ parms rhs) = do
e' = snd $ renameExp env i1 e (new_names, parms') <- newNames old_names parms
rhs' <- snd <$> renameExp new_names rhs
renameExp :: Map Ident Ident -> Int -> Exp -> (Int, Exp) pure $ Bind name t name parms' rhs'
renameExp env i = \case
EId n -> (i, EId . fromMaybe n $ Map.lookup n env)
EInt i1 -> (i, EInt i1)
EApp e1 e2 -> (i2, EApp e1' e2')
where
(i1, e1') = renameExp env i e1
(i2, e2') = renameExp env i1 e2
EAdd e1 e2 -> (i2, EAdd e1' e2')
where
(i1, e1') = renameExp env i e1
(i2, e2') = renameExp env i1 e2
ELet bs e -> (i3, ELet (zipWith4 mkBind names' types pars' es') e')
where
mkBind name t = Bind name t name
(i1, e') = renameExp e_env i e
(names, types, pars, rhss) = fromBinders bs
(i2, names', env') = newNames i1 (names ++ concat pars)
pars' = (map . map) renamePar pars
e_env = Map.union env' env
(i3, es') = mapAccumL (renameExp e_env) i2 rhss
renamePar p = case Map.lookup p env' of
Just p' -> p'
Nothing -> error ("Can't find name for " ++ show p)
EAbs par t e -> (i2, EAbs par' t e') -- | Rename monad. State holds the number of renamed names.
where newtype Rn a = Rn { runRn :: State Int a }
(i1, par', env') = newName par deriving (Functor, Applicative, Monad, MonadState Int)
(i2, e') = renameExp (Map.union env' env ) i1 e
EAnn e t -> (i1, EAnn e' t) -- | Maps old to new name
where type Names = Map Ident Ident
(i1, e') = renameExp env i e
renameLocalBind :: Names -> Bind -> Rn (Names, Bind)
renameLocalBind old_names (Bind name t _ parms rhs) = do
(new_names, name') <- newName old_names name
(new_names', parms') <- newNames new_names parms
(new_names'', rhs') <- renameExp new_names' rhs
pure (new_names'', Bind name' t name' parms' rhs')
newName :: Ident -> (Int, Ident, Map Ident Ident) renameExp :: Names -> Exp -> Rn (Names, Exp)
newName old_name = (i, head names, env) renameExp old_names = \case
where (i, names, env) = newNames 1 [old_name] EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names)
newNames :: Int -> [Ident] -> (Int, [Ident], Map Ident Ident) EInt i1 -> pure (old_names, EInt i1)
newNames i old_names = (i', new_names, env)
where
(i', new_names) = getNames i old_names
env = Map.fromList $ zip old_names new_names
getNames :: Int -> [Ident] -> (Int, [Ident]) EApp e1 e2 -> do
getNames i ns = (i + length ss, zipWith makeName ss [i..]) (env1, e1') <- renameExp old_names e1
where (env2, e2') <- renameExp old_names e2
ss = map (\(Ident s) -> s) ns pure (Map.union env1 env2, EApp e1' e2')
makeName :: String -> Int -> Ident EAdd e1 e2 -> do
makeName prefix i = Ident (prefix ++ "_" ++ show i) (env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EAdd e1' e2')
ELet b e -> do
(new_names, b) <- renameLocalBind old_names b
(new_names', e') <- renameExp new_names e
pure (new_names', ELet b e')
EAbs par t e -> do
(new_names, par') <- newName old_names par
(new_names', e') <- renameExp new_names e
pure (new_names', EAbs par' t e')
EAnn e t -> do
(new_names, e') <- renameExp old_names e
pure (new_names, EAnn e' t)
-- | Create a new name and add it to name environment.
newName :: Names -> Ident -> Rn (Names, Ident)
newName env old_name = do
new_name <- makeName old_name
pure (Map.insert old_name new_name env, new_name)
-- | Create multiple names and add them to the name environment
newNames :: Names -> [Ident] -> Rn (Names, [Ident])
newNames = mapAccumM newName
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
makeName :: Ident -> Rn Ident
makeName (Ident prefix) = gets (\i -> Ident $ prefix ++ "_" ++ show i) <* modify succ
fromBinders :: [Bind] -> ([Ident], [Type], [[Ident]], [Exp])
fromBinders bs = unzip4 [ (name, t, parms, rhs) | Bind name t _ parms rhs <- bs ]

View file

@ -14,7 +14,6 @@ import Grammar.Print (Print (prt), concatD, doc, printTree,
import Prelude hiding (exp, id) import Prelude hiding (exp, id)
import qualified TypeCheckerIr as T import qualified TypeCheckerIr as T
-- NOTE: this type checker is poorly tested -- NOTE: this type checker is poorly tested
-- TODO -- TODO
@ -22,9 +21,9 @@ import qualified TypeCheckerIr as T
-- Type inference -- Type inference
data Cxt = Cxt data Cxt = Cxt
{ env :: Map Ident Type { env :: Map Ident Type -- ^ Local scope signature
, sig :: Map Ident Type , sig :: Map Ident Type -- ^ Top-level signatures
} }
initCxt :: [Bind] -> Cxt initCxt :: [Bind] -> Cxt
initCxt sc = Cxt { env = mempty initCxt sc = Cxt { env = mempty
@ -34,134 +33,133 @@ initCxt sc = Cxt { env = mempty
typecheck :: Program -> Err T.Program typecheck :: Program -> Err T.Program
typecheck (Program sc) = T.Program <$> mapM (checkBind $ initCxt sc) sc typecheck (Program sc) = T.Program <$> mapM (checkBind $ initCxt sc) sc
-- | Check if infered rhs type matches type signature.
checkBind :: Cxt -> Bind -> Err T.Bind checkBind :: Cxt -> Bind -> Err T.Bind
checkBind cxt b = checkBind cxt b =
case expandLambdas b of case expandLambdas b of
Bind name t _ parms rhs -> do Bind name t _ parms rhs -> do
(rhs', t_rhs) <- infer cxt rhs (rhs', t_rhs) <- infer cxt rhs
unless (typeEq t_rhs t) . throwError $ typeErr name t t_rhs
unless (typeEq t_rhs t) . throwError $ typeErr name t t_rhs pure $ T.Bind (name, t) (zip parms ts_parms) rhs'
where
pure $ T.Bind (name, t) (zip parms ts_parms) rhs' ts_parms = fst $ partitionType (length parms) t
where
ts_parms = fst $ partitionType (length parms) t
-- | @ f x y = rhs ⇒ f = \x.\y. rhs @
expandLambdas :: Bind -> Bind expandLambdas :: Bind -> Bind
expandLambdas (Bind name t _ parms rhs) = Bind name t name [] rhs' expandLambdas (Bind name t _ parms rhs) = Bind name t name [] rhs'
where where
rhs' = foldr ($) rhs $ zipWith EAbs parms ts_parms rhs' = foldr ($) rhs $ zipWith EAbs parms ts_parms
ts_parms = fst $ partitionType (length parms) t ts_parms = fst $ partitionType (length parms) t
-- | Infer type of expression.
infer :: Cxt -> Exp -> Err (T.Exp, Type) infer :: Cxt -> Exp -> Err (T.Exp, Type)
infer cxt = \case infer cxt = \case
EId x ->
case lookupEnv x cxt of
Nothing ->
case lookupSig x cxt of
Nothing -> throwError ("Unbound variable:" ++ printTree x)
Just t -> pure (T.EId (x, t), t)
Just t -> pure (T.EId (x, t), t)
EId x -> EInt i -> pure (T.EInt i, T.TInt)
case lookupEnv x cxt of
Nothing ->
case lookupSig x cxt of
Nothing -> throwError ("Unbound variable:" ++ printTree x)
Just t -> pure (T.EId (x, t), t)
Just t -> pure (T.EId (x, t), t)
EInt i -> pure (T.EInt i, T.TInt) EApp e e1 -> do
(e', t) <- infer cxt e
case t of
TFun t1 t2 -> do
e1' <- check cxt e1 t1
pure (T.EApp t2 e' e1', t2)
_ -> do
throwError ("Not a function: " ++ show e)
EApp e e1 -> do EAdd e e1 -> do
(e', t) <- infer cxt e e' <- check cxt e T.TInt
case t of e1' <- check cxt e1 T.TInt
TFun t1 t2 -> do pure (T.EAdd T.TInt e' e1', T.TInt)
e1' <- check cxt e1 t1
pure (T.EApp t2 e' e1', t2)
_ -> do
throwError ("Not a function: " ++ show e)
EAdd e e1 -> do EAbs x t e -> do
e' <- check cxt e T.TInt (e', t1) <- infer (insertEnv x t cxt) e
e1' <- check cxt e1 T.TInt let t_abs = TFun t t1
pure (T.EAdd T.TInt e' e1', T.TInt) pure (T.EAbs t_abs (x, t) e', t_abs)
EAbs x t e -> do ELet b e -> do
(e', t1) <- infer (insertEnv x t cxt) e let cxt' = insertBind b cxt
let t_abs = TFun t t1 b' <- checkBind cxt' b
pure (T.EAbs t_abs (x, t) e', t_abs) (e', t) <- infer cxt' e
pure (T.ELet b' e', t)
ELet bs e -> do
bs'' <- mapM (checkBind cxt') bs'
(e', t) <- infer cxt' e
pure (T.ELet bs'' e', t)
where
bs' = map expandLambdas bs
cxt' = foldr (\(Bind n t _ _ _) -> insertEnv n t) cxt bs'
EAnn e t -> do
e' <- check cxt e t
pure (T.EAnn e' t, t)
EAnn e t -> do
(e', t1) <- infer cxt e
unless (typeEq t t1) $
throwError "Inferred type and type annotation doesn't match"
pure (e', t1)
-- | Check infered type matches the supplied type.
check :: Cxt -> Exp -> Type -> Err T.Exp check :: Cxt -> Exp -> Type -> Err T.Exp
check cxt exp typ = case exp of check cxt exp typ = case exp of
EId x -> do EId x -> do
t <- case lookupEnv x cxt of t <- case lookupEnv x cxt of
Nothing -> maybeToRightM Nothing -> maybeToRightM
("Unbound variable:" ++ printTree x) ("Unbound variable:" ++ printTree x)
(lookupSig x cxt) (lookupSig x cxt)
Just t -> pure t Just t -> pure t
unless (typeEq t typ) . throwError $ typeErr x typ t
pure $ T.EId (x, t)
unless (typeEq t typ) . throwError $ typeErr x typ t EInt i -> do
unless (typeEq typ TInt) $ throwError $ typeErr i TInt typ
pure $ T.EInt i
pure $ T.EId (x, t) EApp e e1 -> do
(e', t) <- infer cxt e
case t of
TFun t1 t2 -> do
e1' <- check cxt e1 t1
pure $ T.EApp t2 e' e1'
_ -> throwError ("Not a function 2: " ++ printTree e)
EInt i -> do EAdd e e1 -> do
unless (typeEq typ TInt) $ throwError $ typeErr i TInt typ e' <- check cxt e T.TInt
pure $ T.EInt i e1' <- check cxt e1 T.TInt
pure $ T.EAdd T.TInt e' e1'
EApp e e1 -> do EAbs x t e -> do
(e', t) <- infer cxt e (e', t_e) <- infer (insertEnv x t cxt) e
case t of let t1 = TFun t t_e
TFun t1 t2 -> do unless (typeEq t1 typ) $ throwError "Wrong lamda type!"
e1' <- check cxt e1 t1 pure $ T.EAbs t1 (x, t) e'
pure $ T.EApp t2 e' e1'
_ -> throwError ("Not a function 2: " ++ printTree e)
EAdd e e1 -> do ELet b e -> do
e' <- check cxt e T.TInt let cxt' = insertBind b cxt
e1' <- check cxt e1 T.TInt b' <- checkBind cxt' b
pure $ T.EAdd T.TInt e' e1' e' <- check cxt' e typ
pure $ T.ELet b' e'
EAbs x t e -> do EAnn e t -> do
(e', t_e) <- infer (insertEnv x t cxt) e unless (typeEq t typ) $
let t1 = TFun t t_e throwError "Inferred type and type annotation doesn't match"
unless (typeEq t1 typ) $ throwError "Wrong lamda type!" check cxt e t
pure $ T.EAbs t1 (x, t) e'
ELet bs e -> do
bs'' <- mapM (checkBind cxt') bs'
e' <- check cxt' e typ
pure $ T.ELet bs'' e'
where
bs' = map expandLambdas bs
cxt' = foldr (\(Bind n t _ _ _) -> insertEnv n t) cxt bs'
EAnn e t -> do
unless (typeEq t typ) $
throwError "Inferred type and type annotation doesn't match"
e' <- check cxt e t
pure $ T.EAnn e' typ
-- | Check if types are equivalent. Doesn't handle coercion or polymorphism.
typeEq :: Type -> Type -> Bool typeEq :: Type -> Type -> Bool
typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1 typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1
typeEq t t1 = t == t1 typeEq t t1 = t == t1
partitionType :: Int -> Type -> ([Type], Type) -- | Partion type into types of parameters and return type.
partitionType :: Int -- Number of parameters to apply
-> Type
-> ([Type], Type)
partitionType = go [] partitionType = go []
where where
go acc 0 t = (acc, t) go acc 0 t = (acc, t)
go acc i t = case t of go acc i t = case t of
TFun t1 t2 -> go (snoc t1 acc) (i - 1) t2 TFun t1 t2 -> go (snoc t1 acc) (i - 1) t2
_ -> error "Number of parameters and type doesn't match" _ -> error "Number of parameters and type doesn't match"
insertBind :: Bind -> Cxt -> Cxt
insertBind (Bind n t _ _ _) = insertEnv n t
lookupEnv :: Ident -> Cxt -> Maybe Type lookupEnv :: Ident -> Cxt -> Maybe Type
lookupEnv x = Map.lookup x . env lookupEnv x = Map.lookup x . env
@ -174,7 +172,7 @@ lookupSig x = Map.lookup x . sig
typeErr :: Print a => a -> Type -> Type -> String typeErr :: Print a => a -> Type -> Type -> String
typeErr p expected actual = render $ concatD typeErr p expected actual = render $ concatD
[ doc $ showString "Wrong type:", prt 0 p , doc $ showString "\n" [ doc $ showString "Wrong type:", prt 0 p , doc $ showString "\n"
, doc $ showString "Expected:" , prt 0 expected, doc $ showString "\n" , doc $ showString "Expected:" , prt 0 expected, doc $ showString "\n"
, doc $ showString "Actual: " , prt 0 actual , doc $ showString "Actual: " , prt 0 actual
] ]

View file

@ -16,30 +16,35 @@ newtype Program = Program [Bind]
data Exp data Exp
= EId Id = EId Id
| EInt Integer | EInt Integer
| ELet [Bind] Exp | ELet Bind Exp
| EApp Type Exp Exp | EApp Type Exp Exp
| EAdd Type Exp Exp | EAdd Type Exp Exp
| EAbs Type Id Exp | EAbs Type Id Exp
| EAnn Exp Type | ECased Exp [Case]
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
data Case = Case CLit Exp
deriving (C.Eq, C.Ord, C.Show, C.Read)
data CLit = CInt Integer | CatchAll
deriving (C.Eq, C.Ord, C.Show, C.Read)
type Id = (Ident, Type) type Id = (Ident, Type)
data Bind = Bind Id [Id] Exp data Bind = Bind Id [Id] Exp
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
instance Print Program where instance Print Program where
prt i (Program sc) = prPrec i 0 $ prt 0 sc prt i (Program sc) = prPrec i 0 $ prt 0 sc
instance Print Bind where instance Print Bind where
prt i (Bind name@(n, _) parms rhs) = prPrec i 0 $ concatD prt i (Bind name@(n, _) parms rhs) = prPrec i 0 $ concatD
[ prtId 0 name [ prtId 0 name
, doc $ showString ";" , doc $ showString ";"
, prt 0 n , prt 0 n
, prtIdPs 0 parms , prtIdPs 0 parms
, doc $ showString "=" , doc $ showString "="
, prt 0 rhs , prt 0 rhs
] ]
instance Print [Bind] where instance Print [Bind] where
prt _ [] = concatD [] prt _ [] = concatD []
@ -51,58 +56,51 @@ prtIdPs i = prPrec i 0 . concatD . map (prtIdP i)
prtId :: Int -> Id -> Doc prtId :: Int -> Id -> Doc
prtId i (name, t) = prPrec i 0 $ concatD prtId i (name, t) = prPrec i 0 $ concatD
[ prt 0 name [ prt 0 name
, doc $ showString ":" , doc $ showString ":"
, prt 0 t , prt 0 t
] ]
prtIdP :: Int -> Id -> Doc prtIdP :: Int -> Id -> Doc
prtIdP i (name, t) = prPrec i 0 $ concatD prtIdP i (name, t) = prPrec i 0 $ concatD
[ doc $ showString "(" [ doc $ showString "("
, prt 0 name , prt 0 name
, doc $ showString ":" , doc $ showString ":"
, prt 0 t , prt 0 t
, doc $ showString ")" , doc $ showString ")"
] ]
instance Print Exp where instance Print Exp where
prt i = \case prt i = \case
EId n -> prPrec i 3 $ concatD [prtIdP 0 n] EId n -> prPrec i 3 $ concatD [prtIdP 0 n]
EInt i1 -> prPrec i 3 $ concatD [prt 0 i1] EInt i1 -> prPrec i 3 $ concatD [prt 0 i1]
ELet bs e -> prPrec i 3 $ concatD ELet bs e -> prPrec i 3 $ concatD
[ doc $ showString "let" [ doc $ showString "let"
, prt 0 bs , prt 0 bs
, doc $ showString "in" , doc $ showString "in"
, prt 0 e , prt 0 e
] ]
EApp t e1 e2 -> prPrec i 2 $ concatD EApp t e1 e2 -> prPrec i 2 $ concatD
[ doc $ showString "@" [ doc $ showString "@"
, prt 0 t , prt 0 t
, prt 2 e1 , prt 2 e1
, prt 3 e2 , prt 3 e2
] ]
EAdd t e1 e2 -> prPrec i 1 $ concatD EAdd t e1 e2 -> prPrec i 1 $ concatD
[ doc $ showString "@" [ doc $ showString "@"
, prt 0 t , prt 0 t
, prt 1 e1 , prt 1 e1
, doc $ showString "+" , doc $ showString "+"
, prt 2 e2 , prt 2 e2
] ]
EAbs t n e -> prPrec i 0 $ concatD EAbs t n e -> prPrec i 0 $ concatD
[ doc $ showString "@" [ doc $ showString "@"
, prt 0 t , prt 0 t
, doc $ showString "\\" , doc $ showString "\\"
, prtIdP 0 n , prtIdP 0 n
, doc $ showString "." , doc $ showString "."
, prt 0 e , prt 0 e
] ]
EAnn e t -> prPrec i 3 $ concatD
[ doc $ showString "("
, prt 0 e
, doc $ showString ":"
, prt 0 t
, doc $ showString ")"
]