Type inference/checking on ADTs mostly complete(?). Still have to test
This commit is contained in:
parent
2f45f39435
commit
bbf6e159c7
8 changed files with 563 additions and 467 deletions
18
Grammar.cf
18
Grammar.cf
|
|
@ -3,7 +3,7 @@ Program. Program ::= [Def] ;
|
|||
|
||||
DBind. Def ::= Bind ;
|
||||
DData. Def ::= Data ;
|
||||
terminator Def ";" ;
|
||||
separator Def ";" ;
|
||||
|
||||
Bind. Bind ::= Ident ":" Type ";"
|
||||
Ident [Ident] "=" Exp ;
|
||||
|
|
@ -31,16 +31,19 @@ IMatch. Match ::= Ident ;
|
|||
InitMatch. Match ::= Ident Match ;
|
||||
separator Match " " ;
|
||||
|
||||
TMono. Type1 ::= "_" Ident ;
|
||||
TPol. Type1 ::= "'" Ident ;
|
||||
TArr. Type ::= Type1 "->" Type ;
|
||||
TMono. Type1 ::= "_" Ident ;
|
||||
TPol. Type1 ::= "'" Ident ;
|
||||
TConstr. Type1 ::= Ident "(" [Type] ")" ;
|
||||
TArr. Type ::= Type1 "->" Type ;
|
||||
|
||||
separator Type " " ;
|
||||
coercions Type 2 ;
|
||||
|
||||
-- shift/reduce problem here
|
||||
Data. Data ::= "data" Ident [Type] "where" ";"
|
||||
Data. Data ::= "data" Type "where" ";"
|
||||
[Constructor];
|
||||
|
||||
terminator Constructor ";" ;
|
||||
separator Constructor "," ;
|
||||
|
||||
Constructor. Constructor ::= Ident ":" Type ;
|
||||
|
||||
|
|
@ -48,10 +51,9 @@ Constructor. Constructor ::= Ident ":" Type ;
|
|||
-- token Poly upper (letter | digit | '_')* ;
|
||||
-- token Mono lower (letter | digit | '_')* ;
|
||||
|
||||
terminator Bind ";" ;
|
||||
separator Bind ";" ;
|
||||
separator Ident " ";
|
||||
|
||||
coercions Type 1 ;
|
||||
coercions Exp 5 ;
|
||||
|
||||
comment "--" ;
|
||||
|
|
|
|||
|
|
@ -34,9 +34,9 @@ executable language
|
|||
TypeChecker.TypeChecker
|
||||
TypeChecker.TypeCheckerIr
|
||||
Renamer.Renamer
|
||||
LambdaLifter.LambdaLifter
|
||||
Codegen.Codegen
|
||||
Codegen.LlvmIr
|
||||
-- LambdaLifter.LambdaLifter
|
||||
-- Codegen.Codegen
|
||||
-- Codegen.LlvmIr
|
||||
|
||||
hs-source-dirs: src
|
||||
|
||||
|
|
|
|||
|
|
@ -1,277 +1,277 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
--{-# LANGUAGE LambdaCase #-}
|
||||
--{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Codegen.Codegen (compile) where
|
||||
module Codegen.Codegen where
|
||||
|
||||
import Auxiliary (snoc)
|
||||
import Codegen.LlvmIr (LLVMIr (..), LLVMType (..),
|
||||
LLVMValue (..), Visibility (..),
|
||||
llvmIrToString)
|
||||
import Control.Monad.State (StateT, execStateT, gets, modify)
|
||||
import Data.Map (Map)
|
||||
import qualified Data.Map as Map
|
||||
import Data.Tuple.Extra (dupe, first, second)
|
||||
import Grammar.ErrM (Err)
|
||||
import TypeChecker.TypeChecker
|
||||
import TypeChecker.TypeCheckerIr
|
||||
--import Auxiliary (snoc)
|
||||
--import Codegen.LlvmIr (LLVMIr (..), LLVMType (..),
|
||||
-- LLVMValue (..), Visibility (..),
|
||||
-- llvmIrToString)
|
||||
--import Control.Monad.State (StateT, execStateT, gets, modify)
|
||||
--import Data.Map (Map)
|
||||
--import qualified Data.Map as Map
|
||||
--import Data.Tuple.Extra (dupe, first, second)
|
||||
--import Grammar.ErrM (Err)
|
||||
--import TypeChecker.TypeCheckerIr
|
||||
|
||||
-- | The record used as the code generator state
|
||||
data CodeGenerator = CodeGenerator
|
||||
{ instructions :: [LLVMIr]
|
||||
, functions :: Map Id FunctionInfo
|
||||
, variableCount :: Integer
|
||||
}
|
||||
---- | The record used as the code generator state
|
||||
--data CodeGenerator = CodeGenerator
|
||||
-- { instructions :: [LLVMIr]
|
||||
-- , functions :: Map Id FunctionInfo
|
||||
-- , variableCount :: Integer
|
||||
-- }
|
||||
|
||||
-- | A state type synonym
|
||||
type CompilerState a = StateT CodeGenerator Err a
|
||||
---- | A state type synonym
|
||||
--type CompilerState a = StateT CodeGenerator Err a
|
||||
|
||||
data FunctionInfo = FunctionInfo
|
||||
{ numArgs :: Int
|
||||
, arguments :: [Id]
|
||||
}
|
||||
--data FunctionInfo = FunctionInfo
|
||||
-- { numArgs :: Int
|
||||
-- , arguments :: [Id]
|
||||
-- }
|
||||
|
||||
-- | Adds a instruction to the CodeGenerator state
|
||||
emit :: LLVMIr -> CompilerState ()
|
||||
emit l = modify $ \t -> t { instructions = snoc l $ instructions t }
|
||||
---- | Adds a instruction to the CodeGenerator state
|
||||
--emit :: LLVMIr -> CompilerState ()
|
||||
--emit l = modify $ \t -> t { instructions = snoc l $ instructions t }
|
||||
|
||||
-- | Increases the variable counter in the CodeGenerator state
|
||||
increaseVarCount :: CompilerState ()
|
||||
increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 }
|
||||
---- | Increases the variable counter in the CodeGenerator state
|
||||
--increaseVarCount :: CompilerState ()
|
||||
--increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 }
|
||||
|
||||
-- | Returns the variable count from the CodeGenerator state
|
||||
getVarCount :: CompilerState Integer
|
||||
getVarCount = gets variableCount
|
||||
---- | Returns the variable count from the CodeGenerator state
|
||||
--getVarCount :: CompilerState Integer
|
||||
--getVarCount = gets variableCount
|
||||
|
||||
-- | Increases the variable count and returns it from the CodeGenerator state
|
||||
getNewVar :: CompilerState Integer
|
||||
getNewVar = increaseVarCount >> getVarCount
|
||||
---- | Increases the variable count and returns it from the CodeGenerator state
|
||||
--getNewVar :: CompilerState Integer
|
||||
--getNewVar = increaseVarCount >> getVarCount
|
||||
|
||||
-- | Produces a map of functions infos from a list of binds,
|
||||
-- which contains useful data for code generation.
|
||||
getFunctions :: [Bind] -> Map Id FunctionInfo
|
||||
getFunctions bs = Map.fromList $ map go bs
|
||||
where
|
||||
go (Bind id args _) =
|
||||
(id, FunctionInfo { numArgs=length args, arguments=args })
|
||||
---- | Produces a map of functions infos from a list of binds,
|
||||
---- which contains useful data for code generation.
|
||||
--getFunctions :: [Bind] -> Map Id FunctionInfo
|
||||
--getFunctions bs = Map.fromList $ map go bs
|
||||
-- where
|
||||
-- go (Bind id args _) =
|
||||
-- (id, FunctionInfo { numArgs=length args, arguments=args })
|
||||
|
||||
|
||||
|
||||
initCodeGenerator :: [Bind] -> CodeGenerator
|
||||
initCodeGenerator scs = CodeGenerator { instructions = defaultStart
|
||||
, functions = getFunctions scs
|
||||
, variableCount = 0
|
||||
}
|
||||
--initCodeGenerator :: [Bind] -> CodeGenerator
|
||||
--initCodeGenerator scs = CodeGenerator { instructions = defaultStart
|
||||
-- , functions = getFunctions scs
|
||||
-- , variableCount = 0
|
||||
-- }
|
||||
|
||||
-- | Compiles an AST and produces a LLVM Ir string.
|
||||
-- An easy way to actually "compile" this output is to
|
||||
-- Simply pipe it to lli
|
||||
compile :: Program -> Err String
|
||||
compile (Program scs) = do
|
||||
let codegen = initCodeGenerator scs
|
||||
llvmIrToString . instructions <$> execStateT (compileScs scs) codegen
|
||||
---- | Compiles an AST and produces a LLVM Ir string.
|
||||
---- An easy way to actually "compile" this output is to
|
||||
---- Simply pipe it to lli
|
||||
--compile :: Program -> Err String
|
||||
--compile (Program scs) = do
|
||||
-- let codegen = initCodeGenerator scs
|
||||
-- llvmIrToString . instructions <$> execStateT (compileScs scs) codegen
|
||||
|
||||
compileScs :: [Bind] -> CompilerState ()
|
||||
compileScs [] = pure ()
|
||||
compileScs (Bind (name, t) args exp : xs) = do
|
||||
emit $ UnsafeRaw "\n"
|
||||
emit . Comment $ show name <> ": " <> show exp
|
||||
let args' = map (second type2LlvmType) args
|
||||
emit $ Define (type2LlvmType t_return) name args'
|
||||
functionBody <- exprToValue exp
|
||||
if name == "main"
|
||||
then mapM_ emit $ mainContent functionBody
|
||||
else emit $ Ret I64 functionBody
|
||||
emit DefineEnd
|
||||
modify $ \s -> s { variableCount = 0 }
|
||||
compileScs xs
|
||||
where
|
||||
t_return = snd $ partitionType (length args) t
|
||||
--compileScs :: [Bind] -> CompilerState ()
|
||||
--compileScs [] = pure ()
|
||||
--compileScs (Bind (name, t) args exp : xs) = do
|
||||
-- emit $ UnsafeRaw "\n"
|
||||
-- emit . Comment $ show name <> ": " <> show exp
|
||||
-- let args' = map (second type2LlvmType) args
|
||||
-- emit $ Define (type2LlvmType t_return) name args'
|
||||
-- functionBody <- exprToValue exp
|
||||
-- if name == "main"
|
||||
-- then mapM_ emit $ mainContent functionBody
|
||||
-- else emit $ Ret I64 functionBody
|
||||
-- emit DefineEnd
|
||||
-- modify $ \s -> s { variableCount = 0 }
|
||||
-- compileScs xs
|
||||
-- where
|
||||
-- t_return = snd $ partitionType (length args) t
|
||||
|
||||
mainContent :: LLVMValue -> [LLVMIr]
|
||||
mainContent var =
|
||||
[ UnsafeRaw $
|
||||
"call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n"
|
||||
, -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
|
||||
-- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2")
|
||||
-- , Label (Ident "b_1")
|
||||
-- , UnsafeRaw
|
||||
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n"
|
||||
-- , Br (Ident "end")
|
||||
-- , Label (Ident "b_2")
|
||||
-- , UnsafeRaw
|
||||
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n"
|
||||
-- , Br (Ident "end")
|
||||
-- , Label (Ident "end")
|
||||
Ret I64 (VInteger 0)
|
||||
]
|
||||
--mainContent :: LLVMValue -> [LLVMIr]
|
||||
--mainContent var =
|
||||
-- [ UnsafeRaw $
|
||||
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n"
|
||||
-- , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
|
||||
-- -- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2")
|
||||
-- -- , Label (Ident "b_1")
|
||||
-- -- , UnsafeRaw
|
||||
-- -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n"
|
||||
-- -- , Br (Ident "end")
|
||||
-- -- , Label (Ident "b_2")
|
||||
-- -- , UnsafeRaw
|
||||
-- -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n"
|
||||
-- -- , Br (Ident "end")
|
||||
-- -- , Label (Ident "end")
|
||||
-- Ret I64 (VInteger 0)
|
||||
-- ]
|
||||
|
||||
defaultStart :: [LLVMIr]
|
||||
defaultStart = [ UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
|
||||
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
|
||||
]
|
||||
--defaultStart :: [LLVMIr]
|
||||
--defaultStart = [ UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
|
||||
-- , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
|
||||
-- ]
|
||||
|
||||
compileExp :: Exp -> CompilerState ()
|
||||
compileExp = \case
|
||||
ELit _ (LInt i) -> emitInt i
|
||||
EAdd t e1 e2 -> emitAdd t e1 e2
|
||||
EId (name, _) -> emitIdent name
|
||||
EApp t e1 e2 -> emitApp t e1 e2
|
||||
EAbs t ti e -> emitAbs t ti e
|
||||
ELet bind e -> emitLet bind e
|
||||
--compileExp :: Exp -> CompilerState ()
|
||||
--compileExp = \case
|
||||
-- ELit _ (LInt i) -> emitInt i
|
||||
-- EAdd t e1 e2 -> emitAdd t e1 e2
|
||||
-- EId (name, _) -> emitIdent name
|
||||
-- EApp t e1 e2 -> emitApp t e1 e2
|
||||
-- EAbs t ti e -> emitAbs t ti e
|
||||
-- ELet bind e -> emitLet bind e
|
||||
|
||||
--- aux functions ---
|
||||
emitAbs :: Type -> Id -> Exp -> CompilerState ()
|
||||
emitAbs _t tid e = emit . Comment $ "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e
|
||||
----- aux functions ---
|
||||
--emitAbs :: Type -> Id -> Exp -> CompilerState ()
|
||||
--emitAbs _t tid e = emit . Comment $ "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e
|
||||
|
||||
emitLet :: Bind -> Exp -> CompilerState ()
|
||||
emitLet b e = emit . Comment $ concat [ "ELet ("
|
||||
, show b
|
||||
, " = "
|
||||
, show e
|
||||
, ") is not implemented!"
|
||||
]
|
||||
--emitLet :: Bind -> Exp -> CompilerState ()
|
||||
--emitLet b e = emit . Comment $ concat [ "ELet ("
|
||||
-- , show b
|
||||
-- , " = "
|
||||
-- , show e
|
||||
-- , ") is not implemented!"
|
||||
-- ]
|
||||
|
||||
emitApp :: Type -> Exp -> Exp -> CompilerState ()
|
||||
emitApp t e1 e2 = appEmitter t e1 e2 []
|
||||
where
|
||||
appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState ()
|
||||
appEmitter t e1 e2 stack = do
|
||||
let newStack = e2 : stack
|
||||
case e1 of
|
||||
EApp _ e1' e2' -> appEmitter t e1' e2' newStack
|
||||
EId id@(name, _) -> do
|
||||
args <- traverse exprToValue newStack
|
||||
vs <- getNewVar
|
||||
funcs <- gets functions
|
||||
let visibility = maybe Local (const Global) $ Map.lookup id funcs
|
||||
args' = map (first valueGetType . dupe) args
|
||||
call = Call (type2LlvmType t) visibility name args'
|
||||
emit $ SetVariable (Ident $ show vs) call
|
||||
x -> do
|
||||
emit . Comment $ "The unspeakable happened: "
|
||||
emit . Comment $ show x
|
||||
--emitApp :: Type -> Exp -> Exp -> CompilerState ()
|
||||
--emitApp t e1 e2 = appEmitter t e1 e2 []
|
||||
-- where
|
||||
-- appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState ()
|
||||
-- appEmitter t e1 e2 stack = do
|
||||
-- let newStack = e2 : stack
|
||||
-- case e1 of
|
||||
-- EApp _ e1' e2' -> appEmitter t e1' e2' newStack
|
||||
-- EId id@(name, _) -> do
|
||||
-- args <- traverse exprToValue newStack
|
||||
-- vs <- getNewVar
|
||||
-- funcs <- gets functions
|
||||
-- let visibility = maybe Local (const Global) $ Map.lookup id funcs
|
||||
-- args' = map (first valueGetType . dupe) args
|
||||
-- call = Call (type2LlvmType t) visibility name args'
|
||||
-- emit $ SetVariable (Ident $ show vs) call
|
||||
-- x -> do
|
||||
-- emit . Comment $ "The unspeakable happened: "
|
||||
-- emit . Comment $ show x
|
||||
|
||||
emitIdent :: Ident -> CompilerState ()
|
||||
emitIdent id = do
|
||||
-- !!this should never happen!!
|
||||
emit $ Comment "This should not have happened!"
|
||||
emit $ Variable id
|
||||
emit $ UnsafeRaw "\n"
|
||||
--emitIdent :: Ident -> CompilerState ()
|
||||
--emitIdent id = do
|
||||
-- -- !!this should never happen!!
|
||||
-- emit $ Comment "This should not have happened!"
|
||||
-- emit $ Variable id
|
||||
-- emit $ UnsafeRaw "\n"
|
||||
|
||||
emitInt :: Integer -> CompilerState ()
|
||||
emitInt i = do
|
||||
-- !!this should never happen!!
|
||||
varCount <- getNewVar
|
||||
emit $ Comment "This should not have happened!"
|
||||
emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0))
|
||||
--emitInt :: Integer -> CompilerState ()
|
||||
--emitInt i = do
|
||||
-- -- !!this should never happen!!
|
||||
-- varCount <- getNewVar
|
||||
-- emit $ Comment "This should not have happened!"
|
||||
-- emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0))
|
||||
|
||||
emitAdd :: Type -> Exp -> Exp -> CompilerState ()
|
||||
emitAdd t e1 e2 = do
|
||||
v1 <- exprToValue e1
|
||||
v2 <- exprToValue e2
|
||||
v <- getNewVar
|
||||
emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2)
|
||||
--emitAdd :: Type -> Exp -> Exp -> CompilerState ()
|
||||
--emitAdd t e1 e2 = do
|
||||
-- v1 <- exprToValue e1
|
||||
-- v2 <- exprToValue e2
|
||||
-- v <- getNewVar
|
||||
-- emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2)
|
||||
|
||||
-- emitMul :: Exp -> Exp -> CompilerState ()
|
||||
-- emitMul e1 e2 = do
|
||||
-- (v1,v2) <- binExprToValues e1 e2
|
||||
-- increaseVarCount
|
||||
-- v <- gets variableCount
|
||||
-- emit $ SetVariable $ Ident $ show v
|
||||
-- emit $ Mul I64 v1 v2
|
||||
---- emitMul :: Exp -> Exp -> CompilerState ()
|
||||
---- emitMul e1 e2 = do
|
||||
---- (v1,v2) <- binExprToValues e1 e2
|
||||
---- increaseVarCount
|
||||
---- v <- gets variableCount
|
||||
---- emit $ SetVariable $ Ident $ show v
|
||||
---- emit $ Mul I64 v1 v2
|
||||
|
||||
-- emitMod :: Exp -> Exp -> CompilerState ()
|
||||
-- emitMod e1 e2 = do
|
||||
-- -- `let m a b = rem (abs $ b + a) b`
|
||||
-- (v1,v2) <- binExprToValues e1 e2
|
||||
-- increaseVarCount
|
||||
-- vadd <- gets variableCount
|
||||
-- emit $ SetVariable $ Ident $ show vadd
|
||||
-- emit $ Add I64 v1 v2
|
||||
--
|
||||
-- increaseVarCount
|
||||
-- vabs <- gets variableCount
|
||||
-- emit $ SetVariable $ Ident $ show vabs
|
||||
-- emit $ Call I64 (Ident "llvm.abs.i64")
|
||||
-- [ (I64, VIdent (Ident $ show vadd))
|
||||
-- , (I1, VInteger 1)
|
||||
-- ]
|
||||
-- increaseVarCount
|
||||
-- v <- gets variableCount
|
||||
-- emit $ SetVariable $ Ident $ show v
|
||||
-- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2
|
||||
---- emitMod :: Exp -> Exp -> CompilerState ()
|
||||
---- emitMod e1 e2 = do
|
||||
---- -- `let m a b = rem (abs $ b + a) b`
|
||||
---- (v1,v2) <- binExprToValues e1 e2
|
||||
---- increaseVarCount
|
||||
---- vadd <- gets variableCount
|
||||
---- emit $ SetVariable $ Ident $ show vadd
|
||||
---- emit $ Add I64 v1 v2
|
||||
----
|
||||
---- increaseVarCount
|
||||
---- vabs <- gets variableCount
|
||||
---- emit $ SetVariable $ Ident $ show vabs
|
||||
---- emit $ Call I64 (Ident "llvm.abs.i64")
|
||||
---- [ (I64, VIdent (Ident $ show vadd))
|
||||
---- , (I1, VInteger 1)
|
||||
---- ]
|
||||
---- increaseVarCount
|
||||
---- v <- gets variableCount
|
||||
---- emit $ SetVariable $ Ident $ show v
|
||||
---- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2
|
||||
|
||||
-- emitDiv :: Exp -> Exp -> CompilerState ()
|
||||
-- emitDiv e1 e2 = do
|
||||
-- (v1,v2) <- binExprToValues e1 e2
|
||||
-- increaseVarCount
|
||||
-- v <- gets variableCount
|
||||
-- emit $ SetVariable $ Ident $ show v
|
||||
-- emit $ Div I64 v1 v2
|
||||
---- emitDiv :: Exp -> Exp -> CompilerState ()
|
||||
---- emitDiv e1 e2 = do
|
||||
---- (v1,v2) <- binExprToValues e1 e2
|
||||
---- increaseVarCount
|
||||
---- v <- gets variableCount
|
||||
---- emit $ SetVariable $ Ident $ show v
|
||||
---- emit $ Div I64 v1 v2
|
||||
|
||||
-- emitSub :: Exp -> Exp -> CompilerState ()
|
||||
-- emitSub e1 e2 = do
|
||||
-- (v1,v2) <- binExprToValues e1 e2
|
||||
-- increaseVarCount
|
||||
-- v <- gets variableCount
|
||||
-- emit $ SetVariable $ Ident $ show v
|
||||
-- emit $ Sub I64 v1 v2
|
||||
---- emitSub :: Exp -> Exp -> CompilerState ()
|
||||
---- emitSub e1 e2 = do
|
||||
---- (v1,v2) <- binExprToValues e1 e2
|
||||
---- increaseVarCount
|
||||
---- v <- gets variableCount
|
||||
---- emit $ SetVariable $ Ident $ show v
|
||||
---- emit $ Sub I64 v1 v2
|
||||
|
||||
exprToValue :: Exp -> CompilerState LLVMValue
|
||||
exprToValue = \case
|
||||
ELit _ (LInt i) -> pure $ VInteger i
|
||||
--exprToValue :: Exp -> CompilerState LLVMValue
|
||||
--exprToValue = \case
|
||||
-- ELit _ (LInt i) -> pure $ VInteger i
|
||||
|
||||
EId id@(name, t) -> do
|
||||
funcs <- gets functions
|
||||
case Map.lookup id funcs of
|
||||
Just fi -> do
|
||||
if numArgs fi == 0
|
||||
then do
|
||||
vc <- getNewVar
|
||||
emit $ SetVariable (Ident $ show vc)
|
||||
(Call (type2LlvmType t) Global name [])
|
||||
pure $ VIdent (Ident $ show vc) (type2LlvmType t)
|
||||
else pure $ VFunction name Global (type2LlvmType t)
|
||||
Nothing -> pure $ VIdent name (type2LlvmType t)
|
||||
-- EId id@(name, t) -> do
|
||||
-- funcs <- gets functions
|
||||
-- case Map.lookup id funcs of
|
||||
-- Just fi -> do
|
||||
-- if numArgs fi == 0
|
||||
-- then do
|
||||
-- vc <- getNewVar
|
||||
-- emit $ SetVariable (Ident $ show vc)
|
||||
-- (Call (type2LlvmType t) Global name [])
|
||||
-- pure $ VIdent (Ident $ show vc) (type2LlvmType t)
|
||||
-- else pure $ VFunction name Global (type2LlvmType t)
|
||||
-- Nothing -> pure $ VIdent name (type2LlvmType t)
|
||||
|
||||
e -> do
|
||||
compileExp e
|
||||
v <- getVarCount
|
||||
pure $ VIdent (Ident $ show v) (getType e)
|
||||
-- e -> do
|
||||
-- compileExp e
|
||||
-- v <- getVarCount
|
||||
-- pure $ VIdent (Ident $ show v) (getType e)
|
||||
|
||||
type2LlvmType :: Type -> LLVMType
|
||||
type2LlvmType = \case
|
||||
(TMono "Int") -> I64
|
||||
TArr t xs -> do
|
||||
let (t', xs') = function2LLVMType xs [type2LlvmType t]
|
||||
Function t' xs'
|
||||
t -> I64 --CustomType $ Ident ("\"" ++ show t ++ "\"")
|
||||
where
|
||||
function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
|
||||
function2LLVMType (TArr t xs) s = function2LLVMType xs (type2LlvmType t : s)
|
||||
function2LLVMType x s = (type2LlvmType x, s)
|
||||
--type2LlvmType :: Type -> LLVMType
|
||||
--type2LlvmType = \case
|
||||
-- (TMono "Int") -> I64
|
||||
-- TArr t xs -> do
|
||||
-- let (t', xs') = function2LLVMType xs [type2LlvmType t]
|
||||
-- Function t' xs'
|
||||
-- -- This part will not work as we don't have a monomorphization step yet
|
||||
-- t -> CustomType $ Ident ("\"" ++ show t ++ "\"")
|
||||
-- where
|
||||
-- function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
|
||||
-- function2LLVMType (TArr t xs) s = function2LLVMType xs (type2LlvmType t : s)
|
||||
-- function2LLVMType x s = (type2LlvmType x, s)
|
||||
|
||||
getType :: Exp -> LLVMType
|
||||
getType (ELit _ (LInt _)) = I64
|
||||
getType (EAdd t _ _) = type2LlvmType t
|
||||
getType (EId (_, t)) = type2LlvmType t
|
||||
getType (EApp t _ _) = type2LlvmType t
|
||||
getType (EAbs t _ _) = type2LlvmType t
|
||||
getType (ELet _ e) = getType e
|
||||
--getType :: Exp -> LLVMType
|
||||
--getType (ELit _ (LInt _)) = I64
|
||||
--getType (EAdd t _ _) = type2LlvmType t
|
||||
--getType (EId (_, t)) = type2LlvmType t
|
||||
--getType (EApp t _ _) = type2LlvmType t
|
||||
--getType (EAbs t _ _) = type2LlvmType t
|
||||
--getType (ELet _ e) = getType e
|
||||
|
||||
valueGetType :: LLVMValue -> LLVMType
|
||||
valueGetType (VInteger _) = I64
|
||||
valueGetType (VIdent _ t) = t
|
||||
valueGetType (VConstant s) = Array (length s) I8
|
||||
valueGetType (VFunction _ _ t) = t
|
||||
--valueGetType :: LLVMValue -> LLVMType
|
||||
--valueGetType (VInteger _) = I64
|
||||
--valueGetType (VIdent _ t) = t
|
||||
--valueGetType (VConstant s) = Array (length s) I8
|
||||
--valueGetType (VFunction _ _ t) = t
|
||||
|
||||
-- | Partion type into types of parameters and return type.
|
||||
partitionType :: Int -- Number of parameters to apply
|
||||
-> Type
|
||||
-> ([Type], Type)
|
||||
partitionType = go []
|
||||
where
|
||||
go acc 0 t = (acc, t)
|
||||
go acc i t = case t of
|
||||
TArr t1 t2 -> go (snoc t1 acc) (i - 1) t2
|
||||
_ -> error "Number of parameters and type doesn't match"
|
||||
---- | Partion type into types of parameters and return type.
|
||||
--partitionType :: Int -- Number of parameters to apply
|
||||
-- -> Type
|
||||
-- -> ([Type], Type)
|
||||
--partitionType = go []
|
||||
-- where
|
||||
-- go acc 0 t = (acc, t)
|
||||
-- go acc i t = case t of
|
||||
-- TArr t1 t2 -> go (snoc t1 acc) (i - 1) t2
|
||||
-- _ -> error "Number of parameters and type doesn't match"
|
||||
|
|
|
|||
|
|
@ -1,192 +1,192 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
--{-# LANGUAGE LambdaCase #-}
|
||||
--{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
|
||||
module LambdaLifter.LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where
|
||||
module LambdaLifter.LambdaLifter where
|
||||
|
||||
import Auxiliary (snoc)
|
||||
import Control.Applicative (Applicative (liftA2))
|
||||
import Control.Monad.State (MonadState (get, put), State,
|
||||
evalState)
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as Set
|
||||
import Prelude hiding (exp)
|
||||
import Renamer.Renamer
|
||||
import TypeChecker.TypeCheckerIr
|
||||
--import Auxiliary (snoc)
|
||||
--import Control.Applicative (Applicative (liftA2))
|
||||
--import Control.Monad.State (MonadState (get, put), State,
|
||||
-- evalState)
|
||||
--import Data.Set (Set)
|
||||
--import qualified Data.Set as Set
|
||||
--import Prelude hiding (exp)
|
||||
--import Renamer.Renamer
|
||||
--import TypeChecker.TypeCheckerIr
|
||||
|
||||
|
||||
-- | Lift lambdas and let expression into supercombinators.
|
||||
-- Three phases:
|
||||
-- @freeVars@ annotatss all the free variables.
|
||||
-- @abstract@ converts lambdas into let expressions.
|
||||
-- @collectScs@ moves every non-constant let expression to a top-level function.
|
||||
lambdaLift :: Program -> Program
|
||||
lambdaLift = collectScs . abstract . freeVars
|
||||
---- | Lift lambdas and let expression into supercombinators.
|
||||
---- Three phases:
|
||||
---- @freeVars@ annotatss all the free variables.
|
||||
---- @abstract@ converts lambdas into let expressions.
|
||||
---- @collectScs@ moves every non-constant let expression to a top-level function.
|
||||
--lambdaLift :: Program -> Program
|
||||
--lambdaLift = collectScs . abstract . freeVars
|
||||
|
||||
|
||||
-- | Annotate free variables
|
||||
freeVars :: Program -> AnnProgram
|
||||
freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e)
|
||||
| Bind n xs e <- ds
|
||||
]
|
||||
---- | Annotate free variables
|
||||
--freeVars :: Program -> AnnProgram
|
||||
--freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e)
|
||||
-- | Bind n xs e <- ds
|
||||
-- ]
|
||||
|
||||
freeVarsExp :: Set Id -> Exp -> AnnExp
|
||||
freeVarsExp localVars = \case
|
||||
EId n | Set.member n localVars -> (Set.singleton n, AId n)
|
||||
| otherwise -> (mempty, AId n)
|
||||
--freeVarsExp :: Set Id -> Exp -> AnnExp
|
||||
--freeVarsExp localVars = \case
|
||||
-- EId n | Set.member n localVars -> (Set.singleton n, AId n)
|
||||
-- | otherwise -> (mempty, AId n)
|
||||
|
||||
ELit _ (LInt i) -> (mempty, AInt i)
|
||||
-- ELit _ (LInt i) -> (mempty, AInt i)
|
||||
|
||||
EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2')
|
||||
where
|
||||
e1' = freeVarsExp localVars e1
|
||||
e2' = freeVarsExp localVars e2
|
||||
-- EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2')
|
||||
-- where
|
||||
-- e1' = freeVarsExp localVars e1
|
||||
-- e2' = freeVarsExp localVars e2
|
||||
|
||||
EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2')
|
||||
where
|
||||
e1' = freeVarsExp localVars e1
|
||||
e2' = freeVarsExp localVars e2
|
||||
-- EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2')
|
||||
-- where
|
||||
-- e1' = freeVarsExp localVars e1
|
||||
-- e2' = freeVarsExp localVars e2
|
||||
|
||||
EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e')
|
||||
where
|
||||
e' = freeVarsExp (Set.insert par localVars) e
|
||||
-- EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e')
|
||||
-- where
|
||||
-- e' = freeVarsExp (Set.insert par localVars) e
|
||||
|
||||
-- Sum free variables present in bind and the expression
|
||||
ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e')
|
||||
where
|
||||
binders_frees = Set.delete name $ freeVarsOf rhs'
|
||||
e_free = Set.delete name $ freeVarsOf e'
|
||||
-- -- Sum free variables present in bind and the expression
|
||||
-- ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e')
|
||||
-- where
|
||||
-- binders_frees = Set.delete name $ freeVarsOf rhs'
|
||||
-- e_free = Set.delete name $ freeVarsOf e'
|
||||
|
||||
rhs' = freeVarsExp e_localVars rhs
|
||||
new_bind = ABind name parms rhs'
|
||||
-- rhs' = freeVarsExp e_localVars rhs
|
||||
-- new_bind = ABind name parms rhs'
|
||||
|
||||
e' = freeVarsExp e_localVars e
|
||||
e_localVars = Set.insert name localVars
|
||||
-- e' = freeVarsExp e_localVars e
|
||||
-- e_localVars = Set.insert name localVars
|
||||
|
||||
|
||||
freeVarsOf :: AnnExp -> Set Id
|
||||
freeVarsOf = fst
|
||||
--freeVarsOf :: AnnExp -> Set Id
|
||||
--freeVarsOf = fst
|
||||
|
||||
-- AST annotated with free variables
|
||||
type AnnProgram = [(Id, [Id], AnnExp)]
|
||||
---- AST annotated with free variables
|
||||
--type AnnProgram = [(Id, [Id], AnnExp)]
|
||||
|
||||
type AnnExp = (Set Id, AnnExp')
|
||||
--type AnnExp = (Set Id, AnnExp')
|
||||
|
||||
data ABind = ABind Id [Id] AnnExp deriving Show
|
||||
--data ABind = ABind Id [Id] AnnExp deriving Show
|
||||
|
||||
data AnnExp' = AId Id
|
||||
| AInt Integer
|
||||
| ALet ABind AnnExp
|
||||
| AApp Type AnnExp AnnExp
|
||||
| AAdd Type AnnExp AnnExp
|
||||
| AAbs Type Id AnnExp
|
||||
deriving Show
|
||||
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
|
||||
-- Free variables are @v₁ v₂ .. vₙ@ are bound.
|
||||
abstract :: AnnProgram -> Program
|
||||
abstract prog = Program $ evalState (mapM go prog) 0
|
||||
where
|
||||
go :: (Id, [Id], AnnExp) -> State Int Bind
|
||||
go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs'
|
||||
where
|
||||
(rhs', parms1) = flattenLambdasAnn rhs
|
||||
--data AnnExp' = AId Id
|
||||
-- | AInt Integer
|
||||
-- | ALet ABind AnnExp
|
||||
-- | AApp Type AnnExp AnnExp
|
||||
-- | AAdd Type AnnExp AnnExp
|
||||
-- | AAbs Type Id AnnExp
|
||||
-- deriving Show
|
||||
---- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
|
||||
---- Free variables are @v₁ v₂ .. vₙ@ are bound.
|
||||
--abstract :: AnnProgram -> Program
|
||||
--abstract prog = Program $ evalState (mapM go prog) 0
|
||||
-- where
|
||||
-- go :: (Id, [Id], AnnExp) -> State Int Bind
|
||||
-- go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs'
|
||||
-- where
|
||||
-- (rhs', parms1) = flattenLambdasAnn rhs
|
||||
|
||||
|
||||
-- | Flatten nested lambdas and collect the parameters
|
||||
-- @\x.\y.\z. ae → (ae, [x,y,z])@
|
||||
flattenLambdasAnn :: AnnExp -> (AnnExp, [Id])
|
||||
flattenLambdasAnn ae = go (ae, [])
|
||||
where
|
||||
go :: (AnnExp, [Id]) -> (AnnExp, [Id])
|
||||
go ((free, e), acc) =
|
||||
case e of
|
||||
AAbs _ par (free1, e1) ->
|
||||
go ((Set.delete par free1, e1), snoc par acc)
|
||||
_ -> ((free, e), acc)
|
||||
---- | Flatten nested lambdas and collect the parameters
|
||||
---- @\x.\y.\z. ae → (ae, [x,y,z])@
|
||||
--flattenLambdasAnn :: AnnExp -> (AnnExp, [Id])
|
||||
--flattenLambdasAnn ae = go (ae, [])
|
||||
-- where
|
||||
-- go :: (AnnExp, [Id]) -> (AnnExp, [Id])
|
||||
-- go ((free, e), acc) =
|
||||
-- case e of
|
||||
-- AAbs _ par (free1, e1) ->
|
||||
-- go ((Set.delete par free1, e1), snoc par acc)
|
||||
-- _ -> ((free, e), acc)
|
||||
|
||||
abstractExp :: AnnExp -> State Int Exp
|
||||
abstractExp (free, exp) = case exp of
|
||||
AId n -> pure $ EId n
|
||||
AInt i -> pure $ ELit (TMono "Int") (LInt i)
|
||||
AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2)
|
||||
AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2)
|
||||
ALet b e -> liftA2 ELet (go b) (abstractExp e)
|
||||
where
|
||||
go (ABind name parms rhs) = do
|
||||
(rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
|
||||
pure $ Bind name (parms ++ parms1) rhs'
|
||||
--abstractExp :: AnnExp -> State Int Exp
|
||||
--abstractExp (free, exp) = case exp of
|
||||
-- AId n -> pure $ EId n
|
||||
-- AInt i -> pure $ ELit (TMono "Int") (LInt i)
|
||||
-- AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2)
|
||||
-- AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2)
|
||||
-- ALet b e -> liftA2 ELet (go b) (abstractExp e)
|
||||
-- where
|
||||
-- go (ABind name parms rhs) = do
|
||||
-- (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
|
||||
-- pure $ Bind name (parms ++ parms1) rhs'
|
||||
|
||||
skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp
|
||||
skipLambdas f (free, ae) = case ae of
|
||||
AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1
|
||||
_ -> f (free, ae)
|
||||
-- skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp
|
||||
-- skipLambdas f (free, ae) = case ae of
|
||||
-- AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1
|
||||
-- _ -> f (free, ae)
|
||||
|
||||
-- Lift lambda into let and bind free variables
|
||||
AAbs t parm e -> do
|
||||
i <- nextNumber
|
||||
rhs <- abstractExp e
|
||||
-- -- Lift lambda into let and bind free variables
|
||||
-- AAbs t parm e -> do
|
||||
-- i <- nextNumber
|
||||
-- rhs <- abstractExp e
|
||||
|
||||
let sc_name = Ident ("sc_" ++ show i)
|
||||
sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t)
|
||||
-- let sc_name = Ident ("sc_" ++ show i)
|
||||
-- sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t)
|
||||
|
||||
pure $ foldl (EApp $ TMono "Int") sc $ map EId freeList
|
||||
where
|
||||
freeList = Set.toList free
|
||||
parms = snoc parm freeList
|
||||
-- pure $ foldl (EApp $ TMono "Int") sc $ map EId freeList
|
||||
-- where
|
||||
-- freeList = Set.toList free
|
||||
-- parms = snoc parm freeList
|
||||
|
||||
|
||||
nextNumber :: State Int Int
|
||||
nextNumber = do
|
||||
i <- get
|
||||
put $ succ i
|
||||
pure i
|
||||
--nextNumber :: State Int Int
|
||||
--nextNumber = do
|
||||
-- i <- get
|
||||
-- put $ succ i
|
||||
-- pure i
|
||||
|
||||
-- | Collects supercombinators by lifting non-constant let expressions
|
||||
collectScs :: Program -> Program
|
||||
collectScs (Program scs) = Program $ concatMap collectFromRhs scs
|
||||
where
|
||||
collectFromRhs (Bind name parms rhs) =
|
||||
let (rhs_scs, rhs') = collectScsExp rhs
|
||||
in Bind name parms rhs' : rhs_scs
|
||||
---- | Collects supercombinators by lifting non-constant let expressions
|
||||
--collectScs :: Program -> Program
|
||||
--collectScs (Program scs) = Program $ concatMap collectFromRhs scs
|
||||
-- where
|
||||
-- collectFromRhs (Bind name parms rhs) =
|
||||
-- let (rhs_scs, rhs') = collectScsExp rhs
|
||||
-- in Bind name parms rhs' : rhs_scs
|
||||
|
||||
|
||||
collectScsExp :: Exp -> ([Bind], Exp)
|
||||
collectScsExp = \case
|
||||
EId n -> ([], EId n)
|
||||
ELit _ (LInt i) -> ([], ELit (TMono "Int") (LInt i))
|
||||
--collectScsExp :: Exp -> ([Bind], Exp)
|
||||
--collectScsExp = \case
|
||||
-- EId n -> ([], EId n)
|
||||
-- ELit _ (LInt i) -> ([], ELit (TMono "Int") (LInt i))
|
||||
|
||||
EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2')
|
||||
where
|
||||
(scs1, e1') = collectScsExp e1
|
||||
(scs2, e2') = collectScsExp e2
|
||||
-- EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2')
|
||||
-- where
|
||||
-- (scs1, e1') = collectScsExp e1
|
||||
-- (scs2, e2') = collectScsExp e2
|
||||
|
||||
EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2')
|
||||
where
|
||||
(scs1, e1') = collectScsExp e1
|
||||
(scs2, e2') = collectScsExp e2
|
||||
-- EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2')
|
||||
-- where
|
||||
-- (scs1, e1') = collectScsExp e1
|
||||
-- (scs2, e2') = collectScsExp e2
|
||||
|
||||
EAbs t par e -> (scs, EAbs t par e')
|
||||
where
|
||||
(scs, e') = collectScsExp e
|
||||
-- EAbs t par e -> (scs, EAbs t par e')
|
||||
-- where
|
||||
-- (scs, e') = collectScsExp e
|
||||
|
||||
-- Collect supercombinators from bind, the rhss, and the expression.
|
||||
--
|
||||
-- > f = let sc x y = rhs in e
|
||||
--
|
||||
ELet (Bind name parms rhs) e -> if null parms
|
||||
then ( rhs_scs ++ e_scs, ELet bind e')
|
||||
else (bind : rhs_scs ++ e_scs, e')
|
||||
where
|
||||
bind = Bind name parms rhs'
|
||||
(rhs_scs, rhs') = collectScsExp rhs
|
||||
(e_scs, e') = collectScsExp e
|
||||
-- -- Collect supercombinators from bind, the rhss, and the expression.
|
||||
-- --
|
||||
-- -- > f = let sc x y = rhs in e
|
||||
-- --
|
||||
-- ELet (Bind name parms rhs) e -> if null parms
|
||||
-- then ( rhs_scs ++ e_scs, ELet bind e')
|
||||
-- else (bind : rhs_scs ++ e_scs, e')
|
||||
-- where
|
||||
-- bind = Bind name parms rhs'
|
||||
-- (rhs_scs, rhs') = collectScsExp rhs
|
||||
-- (e_scs, e') = collectScsExp e
|
||||
|
||||
|
||||
-- @\x.\y.\z. e → (e, [x,y,z])@
|
||||
flattenLambdas :: Exp -> (Exp, [Id])
|
||||
flattenLambdas = go . (, [])
|
||||
where
|
||||
go (e, acc) = case e of
|
||||
EAbs _ par e1 -> go (e1, snoc par acc)
|
||||
_ -> (e, acc)
|
||||
---- @\x.\y.\z. e → (e, [x,y,z])@
|
||||
--flattenLambdas :: Exp -> (Exp, [Id])
|
||||
--flattenLambdas = go . (, [])
|
||||
-- where
|
||||
-- go (e, acc) = case e of
|
||||
-- EAbs _ par e1 -> go (e1, snoc par acc)
|
||||
-- _ -> (e, acc)
|
||||
|
||||
|
|
|
|||
36
src/Main.hs
36
src/Main.hs
|
|
@ -2,18 +2,18 @@
|
|||
|
||||
module Main where
|
||||
|
||||
import Codegen.Codegen (compile)
|
||||
import GHC.IO.Handle.Text (hPutStrLn)
|
||||
import Grammar.ErrM (Err)
|
||||
import Grammar.Par (myLexer, pProgram)
|
||||
import Grammar.Print (printTree)
|
||||
-- import Codegen.Codegen (compile)
|
||||
import GHC.IO.Handle.Text (hPutStrLn)
|
||||
import Grammar.ErrM (Err)
|
||||
import Grammar.Par (myLexer, pProgram)
|
||||
import Grammar.Print (printTree)
|
||||
|
||||
import LambdaLifter.LambdaLifter (lambdaLift)
|
||||
import Renamer.Renamer (rename)
|
||||
import System.Environment (getArgs)
|
||||
import System.Exit (exitFailure, exitSuccess)
|
||||
import System.IO (stderr)
|
||||
import TypeChecker.TypeChecker (typecheck)
|
||||
-- import LambdaLifter.LambdaLifter (lambdaLift)
|
||||
import Renamer.Renamer (rename)
|
||||
import System.Environment (getArgs)
|
||||
import System.Exit (exitFailure, exitSuccess)
|
||||
import System.IO (stderr)
|
||||
import TypeChecker.TypeChecker (typecheck)
|
||||
|
||||
main :: IO ()
|
||||
main =
|
||||
|
|
@ -37,14 +37,14 @@ main' s = do
|
|||
typechecked <- fromTypeCheckerErr $ typecheck renamed
|
||||
printToErr $ printTree typechecked
|
||||
|
||||
printToErr "\n-- Lambda Lifter --"
|
||||
let lifted = lambdaLift typechecked
|
||||
printToErr $ printTree lifted
|
||||
-- printToErr "\n-- Lambda Lifter --"
|
||||
-- let lifted = lambdaLift typechecked
|
||||
-- printToErr $ printTree lifted
|
||||
|
||||
printToErr "\n -- Printing compiler output to stdout --"
|
||||
compiled <- fromCompilerErr $ compile lifted
|
||||
putStrLn compiled
|
||||
writeFile "llvm.ll" compiled
|
||||
-- printToErr "\n -- Printing compiler output to stdout --"
|
||||
-- compiled <- fromCompilerErr $ compile lifted
|
||||
-- putStrLn compiled
|
||||
-- writeFile "llvm.ll" compiled
|
||||
|
||||
exitSuccess
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
|
||||
{-# HLINT ignore "Use traverse_" #-}
|
||||
{-# OPTIONS_GHC -Wno-overlapping-patterns #-}
|
||||
{-# HLINT ignore "Use zipWithM" #-}
|
||||
|
||||
module TypeChecker.TypeChecker where
|
||||
|
||||
|
|
@ -16,6 +17,7 @@ import qualified Data.Map as M
|
|||
import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
|
||||
import Data.Foldable (traverse_)
|
||||
import Grammar.Abs
|
||||
import Grammar.Print (printTree)
|
||||
import qualified TypeChecker.TypeCheckerIr as T
|
||||
|
|
@ -24,10 +26,12 @@ import qualified TypeChecker.TypeCheckerIr as T
|
|||
data Poly = Forall [Ident] Type
|
||||
deriving Show
|
||||
|
||||
newtype Ctx = Ctx { vars :: Map Ident Poly }
|
||||
newtype Ctx = Ctx { vars :: Map Ident Poly
|
||||
}
|
||||
|
||||
data Env = Env { count :: Int
|
||||
, sigs :: Map Ident Type
|
||||
data Env = Env { count :: Int
|
||||
, sigs :: Map Ident Type
|
||||
, dtypes :: Map Ident Type
|
||||
}
|
||||
|
||||
type Error = String
|
||||
|
|
@ -36,7 +40,7 @@ type Subst = Map Ident Type
|
|||
type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))
|
||||
|
||||
initCtx = Ctx mempty
|
||||
initEnv = Env 0 mempty
|
||||
initEnv = Env 0 mempty mempty
|
||||
|
||||
runPretty :: Exp -> Either Error String
|
||||
runPretty = fmap (printTree . fst). run . inferExp
|
||||
|
|
@ -50,21 +54,44 @@ runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e
|
|||
typecheck :: Program -> Either Error T.Program
|
||||
typecheck = run . checkPrg
|
||||
|
||||
checkData :: Data -> Infer ()
|
||||
checkData d = case d of
|
||||
(Data typ@(TConstr name _) constrs) -> do
|
||||
traverse_ (\(Constructor name' t')
|
||||
-> if typ == retType t'
|
||||
then insertConstr name' t' else
|
||||
throwError $
|
||||
unwords
|
||||
[ "return type of constructor:"
|
||||
, printTree name
|
||||
, "with type:"
|
||||
, printTree (retType t')
|
||||
, "does not match data: "
|
||||
, printTree typ]) constrs
|
||||
_ -> throwError "Data type incorrectly declared"
|
||||
where
|
||||
retType :: Type -> Type
|
||||
retType (TArr _ t2) = retType t2
|
||||
retType a = a
|
||||
|
||||
checkPrg :: Program -> Infer T.Program
|
||||
checkPrg (Program bs) = do
|
||||
let bs' = getBinds bs
|
||||
traverse (\(Bind n t _ _ _) -> insertSig n t) bs'
|
||||
bs' <- mapM checkBind bs'
|
||||
return $ T.Program bs'
|
||||
preRun bs
|
||||
T.Program <$> checkDef bs
|
||||
where
|
||||
getBinds :: [Def] -> [Bind]
|
||||
getBinds = map toBind . filter isBind
|
||||
isBind :: Def -> Bool
|
||||
isBind (DBind _) = True
|
||||
isBind _ = True
|
||||
toBind :: Def -> Bind
|
||||
toBind (DBind bind) = bind
|
||||
toBind _ = error "Can't convert DData to Bind"
|
||||
preRun :: [Def] -> Infer ()
|
||||
preRun [] = return ()
|
||||
preRun (x:xs) = case x of
|
||||
DBind (Bind n t _ _ _ ) -> insertSig n t >> preRun xs
|
||||
DData d@(Data _ _) -> checkData d >> preRun xs
|
||||
|
||||
checkDef :: [Def] -> Infer [T.Def]
|
||||
checkDef [] = return []
|
||||
checkDef (x:xs) = case x of
|
||||
(DBind b) -> do
|
||||
b' <- checkBind b
|
||||
fmap (T.DBind b' :) (checkDef xs)
|
||||
(DData d) -> fmap (T.DData d :) (checkDef xs)
|
||||
|
||||
checkBind :: Bind -> Infer T.Bind
|
||||
checkBind (Bind n t _ args e) = do
|
||||
|
|
@ -77,15 +104,18 @@ checkBind (Bind n t _ args e) = do
|
|||
makeLambda :: Exp -> [Ident] -> Exp
|
||||
makeLambda = foldl (flip EAbs)
|
||||
|
||||
-- | Check if two types are considered equal
|
||||
-- For the purpose of the algorithm two polymorphic types are always considered equal
|
||||
typeEq :: Type -> Type -> Bool
|
||||
typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r'
|
||||
typeEq (TMono a) (TMono b) = a == b
|
||||
typeEq (TPol _) (TPol _) = True
|
||||
typeEq _ _ = False
|
||||
typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r'
|
||||
typeEq (TMono a) (TMono b) = a == b
|
||||
typeEq (TConstr name a) (TConstr name' b) = name == name' && and (zipWith typeEq a b)
|
||||
typeEq (TPol _) (TPol _) = True
|
||||
typeEq _ _ = False
|
||||
|
||||
inferExp :: Exp -> Infer (Type, T.Exp)
|
||||
inferExp e = do
|
||||
(s, t, e') <- w e
|
||||
(s, t, e') <- algoW e
|
||||
let subbed = apply s t
|
||||
return (subbed, replace subbed e')
|
||||
|
||||
|
|
@ -98,19 +128,26 @@ replace t = \case
|
|||
T.EAdd _ e1 e2 -> T.EAdd t e1 e2
|
||||
T.ELet (T.Bind (n, _) args e1) e2 -> T.ELet (T.Bind (n, t) args e1) e2
|
||||
|
||||
w :: Exp -> Infer (Subst, Type, T.Exp)
|
||||
w = \case
|
||||
algoW :: Exp -> Infer (Subst, Type, T.Exp)
|
||||
algoW = \case
|
||||
|
||||
EAnn e t -> do
|
||||
(s1, t', e') <- w e
|
||||
(s1, t', e') <- algoW e
|
||||
applySt s1 $ do
|
||||
s2 <- unify (apply s1 t) t'
|
||||
return (s2 `compose` s1, t, e')
|
||||
|
||||
-- | ------------------
|
||||
-- | Γ ⊢ e₀ : Int, ∅
|
||||
|
||||
ELit (LInt n) -> return (nullSubst, TMono "Int", T.ELit (TMono "Int") (LInt n))
|
||||
|
||||
ELit a -> error $ "NOT IMPLEMENTED YET: ELit " ++ show a
|
||||
|
||||
-- | x : σ ∈ Γ τ = inst(σ)
|
||||
-- | ----------------------
|
||||
-- | Γ ⊢ x : τ, ∅
|
||||
|
||||
EId i -> do
|
||||
var <- asks vars
|
||||
case M.lookup i var of
|
||||
|
|
@ -118,42 +155,67 @@ w = \case
|
|||
Nothing -> do
|
||||
sig <- gets sigs
|
||||
case M.lookup i sig of
|
||||
Nothing -> throwError $ "Unbound variable: " ++ show i
|
||||
Just t -> return (nullSubst, t, T.EId (i, t))
|
||||
Nothing -> do
|
||||
constr <- gets dtypes
|
||||
case M.lookup i constr of
|
||||
Just t -> return (nullSubst, t, T.EId (i, t))
|
||||
Nothing -> throwError $ "Unbound variable: " ++ show i
|
||||
|
||||
-- | τ = newvar Γ, x : τ ⊢ e : τ', S
|
||||
-- | ---------------------------------
|
||||
-- | Γ ⊢ w λx. e : Sτ → τ', S
|
||||
|
||||
EAbs name e -> do
|
||||
fr <- fresh
|
||||
withBinding name (Forall [] fr) $ do
|
||||
(s1, t', e') <- w e
|
||||
(s1, t', e') <- algoW e
|
||||
let varType = apply s1 fr
|
||||
let newArr = TArr varType t'
|
||||
return (s1, newArr, T.EAbs newArr (name, varType) e')
|
||||
|
||||
-- | Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
|
||||
-- | s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
|
||||
-- | ------------------------------------------
|
||||
-- | Γ ⊢ e₀ + e₁ : Int, S₃S₂S₁S₀
|
||||
-- This might be wrong
|
||||
|
||||
EAdd e0 e1 -> do
|
||||
(s1, t0, e0') <- w e0
|
||||
(s1, t0, e0') <- algoW e0
|
||||
applySt s1 $ do
|
||||
(s2, t1, e1') <- w e1
|
||||
applySt s2 $ do
|
||||
s3 <- unify (apply s2 t0) (TMono "Int")
|
||||
s4 <- unify (apply s3 t1) (TMono "Int")
|
||||
return (s4 `compose` s3 `compose` s2 `compose` s1, TMono "Int", T.EAdd (TMono "Int") e0' e1')
|
||||
(s2, t1, e1') <- algoW e1
|
||||
-- applySt s2 $ do
|
||||
s3 <- unify (apply s2 t0) (TMono "Int")
|
||||
s4 <- unify (apply s3 t1) (TMono "Int")
|
||||
return (s4 `compose` s3 `compose` s2 `compose` s1, TMono "Int", T.EAdd (TMono "Int") e0' e1')
|
||||
|
||||
-- | Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1
|
||||
-- | τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ')
|
||||
-- | --------------------------------------
|
||||
-- | Γ ⊢ e₀ e₁ : S₂τ', S₂S₁S₀
|
||||
|
||||
EApp e0 e1 -> do
|
||||
fr <- fresh
|
||||
(s0, t0, e0') <- w e0
|
||||
(s0, t0, e0') <- algoW e0
|
||||
applySt s0 $ do
|
||||
(s1, t1, e1') <- w e1
|
||||
(s1, t1, e1') <- algoW e1
|
||||
-- applySt s1 $ do
|
||||
s2 <- unify (apply s1 t0) (TArr t1 fr)
|
||||
let t = apply s2 fr
|
||||
return (s2 `compose` s1 `compose` s0, t, T.EApp t e0' e1')
|
||||
|
||||
-- | Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁
|
||||
-- | ----------------------------------------------
|
||||
-- | Γ ⊢ let x = e₀ in e₁ : τ', S₁S₀
|
||||
|
||||
-- The bar over S₀ and Γ means "generalize"
|
||||
|
||||
ELet name e0 e1 -> do
|
||||
(s1, t1, e0') <- w e0
|
||||
(s1, t1, e0') <- algoW e0
|
||||
env <- asks vars
|
||||
let t' = generalize (apply s1 env) t1
|
||||
withBinding name t' $ do
|
||||
(s2, t2, e1') <- w e1
|
||||
(s2, t2, e1') <- algoW e1
|
||||
return (s2 `compose` s1, t2, T.ELet (T.Bind (name,t2) [] e0') e1' )
|
||||
|
||||
ECase a b -> error $ "NOT IMPLEMENTED YET: ECase" ++ show a ++ " " ++ show b
|
||||
|
|
@ -168,6 +230,12 @@ unify t0 t1 = case (t0, t1) of
|
|||
(TPol a, b) -> occurs a b
|
||||
(a, TPol b) -> occurs b a
|
||||
(TMono a, TMono b) -> if a == b then return M.empty else throwError "Types do not unify"
|
||||
-- | TODO: Figure out a cleaner way to express the same thing
|
||||
(TConstr name t, TConstr name' t') -> if name == name' && length t == length t'
|
||||
then do
|
||||
xs <- sequence $ zipWith unify t t'
|
||||
return $ foldr compose nullSubst xs
|
||||
else throwError $ unwords ["Type constructor:", printTree name, "(" ++ printTree t ++ ")", "does not match with:", printTree name', "(" ++ printTree t' ++ ")"]
|
||||
(a, b) -> throwError . unwords $ ["Type:", printTree a, "can't be unified with:", printTree b]
|
||||
|
||||
-- | Check if a type is contained in another type.
|
||||
|
|
@ -202,9 +270,11 @@ class FreeVars t where
|
|||
|
||||
instance FreeVars Type where
|
||||
free :: Type -> Set Ident
|
||||
free (TPol a) = S.singleton a
|
||||
free (TMono _) = mempty
|
||||
free (TArr a b) = free a `S.union` free b
|
||||
free (TPol a) = S.singleton a
|
||||
free (TMono _) = mempty
|
||||
free (TArr a b) = free a `S.union` free b
|
||||
-- | Not guaranteed to be correct
|
||||
free (TConstr _ a) = foldl' (\acc x -> free x `S.union` acc) S.empty a
|
||||
apply :: Subst -> Type -> Type
|
||||
apply sub t = do
|
||||
case t of
|
||||
|
|
@ -213,6 +283,7 @@ instance FreeVars Type where
|
|||
Nothing -> TPol a
|
||||
Just t -> t
|
||||
TArr a b -> TArr (apply sub a) (apply sub b)
|
||||
TConstr name a -> TConstr name (map (apply sub) a)
|
||||
|
||||
instance FreeVars Poly where
|
||||
free :: Poly -> Set Ident
|
||||
|
|
@ -248,3 +319,7 @@ withBinding i p = local (\st -> st { vars = M.insert i p (vars st) })
|
|||
-- | Insert a function signature into the environment
|
||||
insertSig :: Ident -> Type -> Infer ()
|
||||
insertSig i t = modify (\st -> st { sigs = M.insert i t (sigs st) })
|
||||
|
||||
-- | Insert a constructor with its data type
|
||||
insertConstr :: Ident -> Type -> Infer ()
|
||||
insertConstr i t = modify (\st -> st { dtypes = M.insert i t (dtypes st) })
|
||||
|
|
|
|||
|
|
@ -5,12 +5,12 @@ module TypeChecker.TypeCheckerIr
|
|||
, module TypeChecker.TypeCheckerIr
|
||||
) where
|
||||
|
||||
import Grammar.Abs (Ident (..), Literal (..), Type (..))
|
||||
import Grammar.Abs (Data (..), Ident (..), Literal (..), Type (..))
|
||||
import Grammar.Print
|
||||
import Prelude
|
||||
import qualified Prelude as C (Eq, Ord, Read, Show)
|
||||
|
||||
newtype Program = Program [Bind]
|
||||
newtype Program = Program [Def]
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||
|
||||
data Exp
|
||||
|
|
@ -22,11 +22,18 @@ data Exp
|
|||
| EAbs Type Id Exp
|
||||
deriving (C.Eq, C.Ord, C.Read, C.Show)
|
||||
|
||||
data Def = DBind Bind | DData Data
|
||||
deriving (C.Eq, C.Ord, C.Read, C.Show)
|
||||
|
||||
type Id = (Ident, Type)
|
||||
|
||||
data Bind = Bind Id [Id] Exp
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||
|
||||
instance Print Def where
|
||||
prt i (DBind bind) = prt i bind
|
||||
prt i (DData d) = prt i d
|
||||
|
||||
instance Print Program where
|
||||
prt i (Program sc) = prPrec i 0 $ prt 0 sc
|
||||
|
||||
|
|
@ -75,7 +82,7 @@ instance Print Exp where
|
|||
, doc $ showString "in"
|
||||
, prt 0 e
|
||||
]
|
||||
EApp t e1 e2 -> prPrec i 2 $ concatD
|
||||
EApp _ e1 e2 -> prPrec i 2 $ concatD
|
||||
[ prt 2 e1
|
||||
, prt 3 e2
|
||||
]
|
||||
|
|
|
|||
16
test_program
16
test_program
|
|
@ -1,2 +1,14 @@
|
|||
main : _Int ;
|
||||
main = 3 + 3 ;
|
||||
data List ('a) where;
|
||||
Nil : List ('a),
|
||||
Cons : 'a -> List ('a) -> List ('a) ;
|
||||
|
||||
main : List (_Int) ;
|
||||
main = Cons 1 (Cons 0 Nil) ;
|
||||
|
||||
data Bool () where;
|
||||
True : Bool (),
|
||||
False : Bool ();
|
||||
|
||||
boolean : Bool (_Int);
|
||||
boolean = True ;
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue