Started importing Sebastian's new typechecker.
This commit is contained in:
parent
d5dd7896d8
commit
350cd3b0e9
9 changed files with 1611 additions and 1346 deletions
68
Grammar.cf
68
Grammar.cf
|
|
@ -1,33 +1,51 @@
|
||||||
Program. Program ::= [Bind];
|
|
||||||
|
|
||||||
EId. Exp3 ::= Ident;
|
Program. Program ::= [Def] ;
|
||||||
EInt. Exp3 ::= Integer;
|
|
||||||
EAnn. Exp3 ::= "(" Exp ":" Type ")";
|
|
||||||
ELet. Exp3 ::= "let" Bind "in" Exp;
|
|
||||||
EApp. Exp2 ::= Exp2 Exp3;
|
|
||||||
EAdd. Exp1 ::= Exp1 "+" Exp2;
|
|
||||||
ESub. Exp1 ::= Exp1 "-" Exp2;
|
|
||||||
EAbs. Exp ::= "\\" Ident ":" Type "." Exp;
|
|
||||||
ECase. Exp ::= "case" Exp "of" "{" [CaseMatch] "}" ":" Type;
|
|
||||||
CaseMatch. CaseMatch ::= Case "=>" Exp ;
|
|
||||||
separator CaseMatch ",";
|
|
||||||
|
|
||||||
|
DBind. Def ::= Bind ;
|
||||||
CInt. Case ::= Integer ;
|
DData. Def ::= Data ;
|
||||||
CatchAll. Case ::= "_" ;
|
separator Def ";" ;
|
||||||
|
|
||||||
Bind. Bind ::= Ident ":" Type ";"
|
Bind. Bind ::= Ident ":" Type ";"
|
||||||
Ident [Ident] "=" Exp;
|
Ident [Ident] "=" Exp ;
|
||||||
|
|
||||||
separator Bind ";";
|
Data. Data ::= "data" Constr "where" "{" [Constructor] "}" ;
|
||||||
separator Ident "";
|
|
||||||
|
|
||||||
coercions Exp 3;
|
Constructor. Constructor ::= Ident ":" Type ;
|
||||||
|
separator nonempty Constructor "" ;
|
||||||
|
|
||||||
TInt. Type1 ::= "Int" ;
|
TMono. Type1 ::= "_" Ident ;
|
||||||
TPol. Type1 ::= Ident ;
|
TPol. Type1 ::= "'" Ident ;
|
||||||
TFun. Type ::= Type1 "->" Type ;
|
TConstr. Type1 ::= Constr ;
|
||||||
coercions Type 1 ;
|
TArr. Type ::= Type1 "->" Type ;
|
||||||
|
|
||||||
comment "--";
|
Constr. Constr ::= Ident "(" [Type] ")" ;
|
||||||
comment "{-" "-}";
|
|
||||||
|
-- TODO: Move literal to its own thing since it's reused in Init as well.
|
||||||
|
EAnn. Exp5 ::= "(" Exp ":" Type ")" ;
|
||||||
|
EId. Exp4 ::= Ident ;
|
||||||
|
ELit. Exp4 ::= Literal ;
|
||||||
|
EApp. Exp3 ::= Exp3 Exp4 ;
|
||||||
|
EAdd. Exp1 ::= Exp1 "+" Exp2 ;
|
||||||
|
ESub. Exp1 ::= Exp1 "-" Exp2 ;
|
||||||
|
ELet. Exp ::= "let" Ident "=" Exp "in" Exp ;
|
||||||
|
EAbs. Exp ::= "\\" Ident "." Exp ;
|
||||||
|
ECase. Exp ::= "case" Exp "of" "{" [Inj] "}";
|
||||||
|
|
||||||
|
LInt. Literal ::= Integer ;
|
||||||
|
|
||||||
|
Inj. Inj ::= Init "=>" Exp ;
|
||||||
|
separator nonempty Inj ";" ;
|
||||||
|
|
||||||
|
InitLit. Init ::= Literal ;
|
||||||
|
InitConstr. Init ::= Ident [Ident] ;
|
||||||
|
InitCatch. Init ::= "_" ;
|
||||||
|
|
||||||
|
separator Type " " ;
|
||||||
|
coercions Type 2 ;
|
||||||
|
|
||||||
|
separator Ident " ";
|
||||||
|
|
||||||
|
coercions Exp 5 ;
|
||||||
|
|
||||||
|
comment "--" ;
|
||||||
|
comment "{-" "-}" ;
|
||||||
|
|
|
||||||
|
|
@ -1,87 +1,26 @@
|
||||||
|
posMul : _Int -> _Int -> _Int;
|
||||||
-- tripplemagic : Int -> Int -> Int -> Int;
|
|
||||||
-- tripplemagic x y z = ((\x:Int. x+x) x) + y + z;
|
|
||||||
-- main : Int;
|
|
||||||
-- main = tripplemagic ((\x:Int. x+x+3) ((\x:Int. x) 2)) 5 3
|
|
||||||
-- answer: 22
|
|
||||||
|
|
||||||
-- apply : (Int -> Int) -> Int -> Int;
|
|
||||||
-- apply f x = f x;
|
|
||||||
-- main : Int;
|
|
||||||
-- main = apply (\x : Int . x + 5) 5
|
|
||||||
-- answer: 10
|
|
||||||
|
|
||||||
-- apply : (Int -> Int -> Int) -> Int -> Int -> Int;
|
|
||||||
-- apply f x y = f x y;
|
|
||||||
-- krimp: Int -> Int -> Int;
|
|
||||||
-- krimp x y = x + y;
|
|
||||||
-- main : Int;
|
|
||||||
-- main = apply (krimp) 2 3;
|
|
||||||
-- answer: 5
|
|
||||||
|
|
||||||
-- fibbonaci : Int -> Int;
|
|
||||||
-- fibbonaci x = case x of {
|
|
||||||
-- 0 => 0,
|
|
||||||
-- 1 => 1,
|
|
||||||
-- -- abusing overflows to represent negatives like a boss
|
|
||||||
-- _ => (fibbonaci (x - 2))
|
|
||||||
-- + (fibbonaci (x - 1))
|
|
||||||
-- } : Int;
|
|
||||||
-- main : Int;
|
|
||||||
-- main = fibbonaci 10;
|
|
||||||
-- answer: 55
|
|
||||||
|
|
||||||
-- succ : Int -> Int;
|
|
||||||
-- succ x = x - 1;
|
|
||||||
--
|
|
||||||
-- isZero : Int -> Int;
|
|
||||||
-- isZero x = case x of {
|
|
||||||
-- 0 => 1,
|
|
||||||
-- _ => 0
|
|
||||||
-- } : Int;
|
|
||||||
--
|
|
||||||
-- minimization : (Int -> Int) -> Int -> Int;
|
|
||||||
-- minimization p x = case p x of {
|
|
||||||
-- 1 => 0,
|
|
||||||
-- _ => minimization p (succ x)
|
|
||||||
-- } : Int;
|
|
||||||
--
|
|
||||||
-- main : Int;
|
|
||||||
-- main = minimization isZero 10;
|
|
||||||
-- answer: 0
|
|
||||||
|
|
||||||
posMul : Int -> Int -> Int;
|
|
||||||
posMul a b = case b of {
|
posMul a b = case b of {
|
||||||
0 => 0,
|
0 => 0;
|
||||||
_ => a + posMul a (b - 1)
|
_ => a + posMul a (b - 1)
|
||||||
} : Int;
|
};
|
||||||
|
|
||||||
facc : Int -> Int;
|
facc : _Int -> _Int;
|
||||||
facc a = case a of {
|
facc a = case a of {
|
||||||
1 => 1,
|
1 => 1;
|
||||||
_ => posMul a (facc (a - 1))
|
_ => posMul a (facc (a - 1))
|
||||||
} : Int;
|
};
|
||||||
-- main : Int;
|
|
||||||
-- main = facc 5
|
|
||||||
-- answer: 120
|
|
||||||
|
|
||||||
-- pow : Int -> Int -> Int;
|
minimization : (_Int -> _Int) -> _Int -> _Int;
|
||||||
-- pow a b = case b of {
|
|
||||||
-- 0 => 1,
|
|
||||||
-- _ => posMul a (pow a (b-1))
|
|
||||||
-- } : Int;
|
|
||||||
|
|
||||||
minimization : (Int -> Int) -> Int -> Int;
|
|
||||||
minimization p x = case p x of {
|
minimization p x = case p x of {
|
||||||
1 => x,
|
1 => x;
|
||||||
_ => minimization p (x + 1)
|
_ => minimization p (x + 1)
|
||||||
} : Int;
|
};
|
||||||
|
|
||||||
checkFac : Int -> Int;
|
checkFac : _Int -> _Int;
|
||||||
checkFac x = case facc x of {
|
checkFac x = case facc x of {
|
||||||
0 => 1,
|
0 => 1;
|
||||||
_ => 0
|
_ => 0
|
||||||
} : Int;
|
};
|
||||||
|
|
||||||
main : Int;
|
main : _Int;
|
||||||
main = minimization checkFac 1
|
main = minimization checkFac 1
|
||||||
|
|
@ -1,441 +1,443 @@
|
||||||
{-# LANGUAGE LambdaCase #-}
|
module Codegen.Codegen where
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
-- {-# LANGUAGE LambdaCase #-}
|
||||||
|
-- {-# LANGUAGE OverloadedStrings #-}
|
||||||
module Codegen.Codegen (generateCode) where
|
--
|
||||||
|
-- module Codegen.Codegen (generateCode) where
|
||||||
import Auxiliary (snoc)
|
--
|
||||||
import Codegen.LlvmIr (CallingConvention (..),
|
-- import Auxiliary (snoc)
|
||||||
LLVMComp (..), LLVMIr (..),
|
-- import Codegen.LlvmIr (CallingConvention (..),
|
||||||
LLVMType (..), LLVMValue (..),
|
-- LLVMComp (..), LLVMIr (..),
|
||||||
Visibility (..), llvmIrToString)
|
-- LLVMType (..), LLVMValue (..),
|
||||||
import Control.Monad.State (StateT, execStateT, foldM_, gets,
|
-- Visibility (..), llvmIrToString)
|
||||||
modify)
|
-- import Control.Monad.State (StateT, execStateT, foldM_, gets,
|
||||||
import qualified Data.Bifunctor as BI
|
-- modify)
|
||||||
import Data.List.Extra (trim)
|
-- import qualified Data.Bifunctor as BI
|
||||||
import Data.Map (Map)
|
-- import Data.List.Extra (trim)
|
||||||
import qualified Data.Map as Map
|
-- import Data.Map (Map)
|
||||||
import Data.Tuple.Extra (dupe, first, second)
|
-- import qualified Data.Map as Map
|
||||||
import qualified Grammar.Abs as GA
|
-- import Data.Tuple.Extra (dupe, first, second)
|
||||||
import Grammar.ErrM (Err)
|
-- import qualified Grammar.Abs as GA
|
||||||
import System.Process.Extra (readCreateProcess, shell)
|
-- import Grammar.ErrM (Err)
|
||||||
import TypeChecker.TypeCheckerIr (Bind (..), Case (..), Exp (..), Id,
|
-- import System.Process.Extra (readCreateProcess, shell)
|
||||||
Ident (..), Program (..), Type (..))
|
-- import TypeChecker.TypeCheckerIr (Bind (..), Case (..), Exp (..), Id,
|
||||||
-- | The record used as the code generator state
|
-- Ident (..), Program (..), Type (..))
|
||||||
data CodeGenerator = CodeGenerator
|
-- -- | The record used as the code generator state
|
||||||
{ instructions :: [LLVMIr]
|
-- data CodeGenerator = CodeGenerator
|
||||||
, functions :: Map Id FunctionInfo
|
-- { instructions :: [LLVMIr]
|
||||||
, constructors :: Map Id ConstructorInfo
|
-- , functions :: Map Id FunctionInfo
|
||||||
, variableCount :: Integer
|
-- , constructors :: Map Id ConstructorInfo
|
||||||
, labelCount :: Integer
|
-- , variableCount :: Integer
|
||||||
}
|
-- , labelCount :: Integer
|
||||||
|
-- }
|
||||||
-- | A state type synonym
|
--
|
||||||
type CompilerState a = StateT CodeGenerator Err a
|
-- -- | A state type synonym
|
||||||
|
-- type CompilerState a = StateT CodeGenerator Err a
|
||||||
data FunctionInfo = FunctionInfo
|
--
|
||||||
{ numArgs :: Int
|
-- data FunctionInfo = FunctionInfo
|
||||||
, arguments :: [Id]
|
-- { numArgs :: Int
|
||||||
}
|
-- , arguments :: [Id]
|
||||||
data ConstructorInfo = ConstructorInfo
|
-- }
|
||||||
{ numArgsCI :: Int
|
-- data ConstructorInfo = ConstructorInfo
|
||||||
, argumentsCI :: [Id]
|
-- { numArgsCI :: Int
|
||||||
, numCI :: Integer
|
-- , argumentsCI :: [Id]
|
||||||
}
|
-- , numCI :: Integer
|
||||||
|
-- }
|
||||||
|
--
|
||||||
-- | Adds a instruction to the CodeGenerator state
|
--
|
||||||
emit :: LLVMIr -> CompilerState ()
|
-- -- | Adds a instruction to the CodeGenerator state
|
||||||
emit l = modify $ \t -> t { instructions = Auxiliary.snoc l $ instructions t }
|
-- emit :: LLVMIr -> CompilerState ()
|
||||||
|
-- emit l = modify $ \t -> t { instructions = Auxiliary.snoc l $ instructions t }
|
||||||
-- | Increases the variable counter in the CodeGenerator state
|
--
|
||||||
increaseVarCount :: CompilerState ()
|
-- -- | Increases the variable counter in the CodeGenerator state
|
||||||
increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 }
|
-- increaseVarCount :: CompilerState ()
|
||||||
|
-- increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 }
|
||||||
-- | Returns the variable count from the CodeGenerator state
|
--
|
||||||
getVarCount :: CompilerState Integer
|
-- -- | Returns the variable count from the CodeGenerator state
|
||||||
getVarCount = gets variableCount
|
-- getVarCount :: CompilerState Integer
|
||||||
|
-- getVarCount = gets variableCount
|
||||||
-- | Increases the variable count and returns it from the CodeGenerator state
|
--
|
||||||
getNewVar :: CompilerState Integer
|
-- -- | Increases the variable count and returns it from the CodeGenerator state
|
||||||
getNewVar = increaseVarCount >> getVarCount
|
-- getNewVar :: CompilerState Integer
|
||||||
|
-- getNewVar = increaseVarCount >> getVarCount
|
||||||
-- | Increses the label count and returns a label from the CodeGenerator state
|
--
|
||||||
getNewLabel :: CompilerState Integer
|
-- -- | Increses the label count and returns a label from the CodeGenerator state
|
||||||
getNewLabel = do
|
-- getNewLabel :: CompilerState Integer
|
||||||
modify (\t -> t{labelCount = labelCount t + 1})
|
-- getNewLabel = do
|
||||||
gets labelCount
|
-- modify (\t -> t{labelCount = labelCount t + 1})
|
||||||
|
-- gets labelCount
|
||||||
-- | Produces a map of functions infos from a list of binds,
|
--
|
||||||
-- which contains useful data for code generation.
|
-- -- | Produces a map of functions infos from a list of binds,
|
||||||
getFunctions :: [Bind] -> Map Id FunctionInfo
|
-- -- which contains useful data for code generation.
|
||||||
getFunctions bs = Map.fromList $ go bs
|
-- getFunctions :: [Bind] -> Map Id FunctionInfo
|
||||||
where
|
-- getFunctions bs = Map.fromList $ go bs
|
||||||
go [] = []
|
-- where
|
||||||
go (Bind id args _ : xs) =
|
-- go [] = []
|
||||||
(id, FunctionInfo { numArgs=length args, arguments=args })
|
-- go (Bind id args _ : xs) =
|
||||||
: go xs
|
-- (id, FunctionInfo { numArgs=length args, arguments=args })
|
||||||
go (DataStructure n cons : xs) = do
|
-- : go xs
|
||||||
map (\(id, xs) -> ((id, TPol n), FunctionInfo {
|
-- go (DataStructure n cons : xs) = do
|
||||||
numArgs=length xs, arguments=createArgs xs
|
-- map (\(id, xs) -> ((id, TPol n), FunctionInfo {
|
||||||
})) cons
|
-- numArgs=length xs, arguments=createArgs xs
|
||||||
<> go xs
|
-- })) cons
|
||||||
|
-- <> go xs
|
||||||
createArgs :: [Type] -> [Id]
|
--
|
||||||
createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(Ident ("arg_" <> show l) , t)],l+1)) ([], 0) xs
|
-- createArgs :: [Type] -> [Id]
|
||||||
|
-- createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(Ident ("arg_" <> show l) , t)],l+1)) ([], 0) xs
|
||||||
-- | Produces a map of functions infos from a list of binds,
|
--
|
||||||
-- which contains useful data for code generation.
|
-- -- | Produces a map of functions infos from a list of binds,
|
||||||
getConstructors :: [Bind] -> Map Id ConstructorInfo
|
-- -- which contains useful data for code generation.
|
||||||
getConstructors bs = Map.fromList $ go bs
|
-- getConstructors :: [Bind] -> Map Id ConstructorInfo
|
||||||
where
|
-- getConstructors bs = Map.fromList $ go bs
|
||||||
go [] = []
|
-- where
|
||||||
go (DataStructure (Ident n) cons : xs) = do
|
-- go [] = []
|
||||||
fst (foldl (\(acc,i) (Ident id, xs) -> (((Ident (n <> "_" <> id), TPol (Ident n)), ConstructorInfo {
|
-- go (DataStructure (Ident n) cons : xs) = do
|
||||||
numArgsCI=length xs,
|
-- fst (foldl (\(acc,i) (Ident id, xs) -> (((Ident (n <> "_" <> id), TPol (Ident n)), ConstructorInfo {
|
||||||
argumentsCI=createArgs xs,
|
-- numArgsCI=length xs,
|
||||||
numCI=i
|
-- argumentsCI=createArgs xs,
|
||||||
}) : acc, i+1)) ([],0) cons)
|
-- numCI=i
|
||||||
<> go xs
|
-- }) : acc, i+1)) ([],0) cons)
|
||||||
go (_: xs) = go xs
|
-- <> go xs
|
||||||
|
-- go (_: xs) = go xs
|
||||||
initCodeGenerator :: [Bind] -> CodeGenerator
|
--
|
||||||
initCodeGenerator scs = CodeGenerator { instructions = defaultStart
|
-- initCodeGenerator :: [Bind] -> CodeGenerator
|
||||||
, functions = getFunctions scs
|
-- initCodeGenerator scs = CodeGenerator { instructions = defaultStart
|
||||||
, constructors = getConstructors scs
|
-- , functions = getFunctions scs
|
||||||
, variableCount = 0
|
-- , constructors = getConstructors scs
|
||||||
, labelCount = 0
|
-- , variableCount = 0
|
||||||
}
|
-- , labelCount = 0
|
||||||
|
-- }
|
||||||
run :: Err String -> IO ()
|
--
|
||||||
run s = do
|
-- run :: Err String -> IO ()
|
||||||
let s' = case s of
|
-- run s = do
|
||||||
Right s -> s
|
-- let s' = case s of
|
||||||
Left _ -> error "yo"
|
-- Right s -> s
|
||||||
writeFile "output/llvm.ll" s'
|
-- Left _ -> error "yo"
|
||||||
putStrLn . trim =<< readCreateProcess (shell "lli") s'
|
-- writeFile "output/llvm.ll" s'
|
||||||
|
-- putStrLn . trim =<< readCreateProcess (shell "lli") s'
|
||||||
test :: Integer -> Program
|
--
|
||||||
test v = Program [
|
-- test :: Integer -> Program
|
||||||
DataStructure (Ident "Craig") [
|
-- test v = Program [
|
||||||
(Ident "Bob", [TInt])--,
|
-- DataStructure (Ident "Craig") [
|
||||||
--(Ident "Alice", [TInt, TInt])
|
-- (Ident "Bob", [TInt])--,
|
||||||
],
|
-- --(Ident "Alice", [TInt, TInt])
|
||||||
Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] (EId ("x",TInt)),
|
-- ],
|
||||||
Bind (Ident "main", TInt) [] (
|
-- Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] (EId ("x",TInt)),
|
||||||
EApp (TPol "Craig") (EId (Ident "Craig_Bob", TPol "Craig")) (EInt v) -- (EInt 92)
|
-- Bind (Ident "main", TInt) [] (
|
||||||
)
|
-- EApp (TPol "Craig") (EId (Ident "Craig_Bob", TPol "Craig")) (EInt v) -- (EInt 92)
|
||||||
]
|
-- )
|
||||||
|
-- ]
|
||||||
{- | Compiles an AST and produces a LLVM Ir string.
|
--
|
||||||
An easy way to actually "compile" this output is to
|
-- {- | Compiles an AST and produces a LLVM Ir string.
|
||||||
Simply pipe it to LLI
|
-- An easy way to actually "compile" this output is to
|
||||||
-}
|
-- Simply pipe it to LLI
|
||||||
generateCode :: Program -> Err String
|
-- -}
|
||||||
generateCode (Program scs) = do
|
-- generateCode :: Program -> Err String
|
||||||
let codegen = initCodeGenerator scs
|
-- generateCode (Program scs) = do
|
||||||
llvmIrToString . instructions <$> execStateT (compileScs scs) codegen
|
-- let codegen = initCodeGenerator scs
|
||||||
|
-- llvmIrToString . instructions <$> execStateT (compileScs scs) codegen
|
||||||
compileScs :: [Bind] -> CompilerState ()
|
--
|
||||||
compileScs [] = do
|
-- compileScs :: [Bind] -> CompilerState ()
|
||||||
-- as a last step create all the constructors
|
-- compileScs [] = do
|
||||||
c <- gets (Map.toList . constructors)
|
-- -- as a last step create all the constructors
|
||||||
mapM_ (\((id, t), ci) -> do
|
-- c <- gets (Map.toList . constructors)
|
||||||
let t' = type2LlvmType t
|
-- mapM_ (\((id, t), ci) -> do
|
||||||
let x = BI.second type2LlvmType <$> argumentsCI ci
|
-- let t' = type2LlvmType t
|
||||||
emit $ Define FastCC t' id x
|
-- let x = BI.second type2LlvmType <$> argumentsCI ci
|
||||||
top <- Ident . show <$> getNewVar
|
-- emit $ Define FastCC t' id x
|
||||||
ptr <- Ident . show <$> getNewVar
|
-- top <- Ident . show <$> getNewVar
|
||||||
-- allocated the primary type
|
-- ptr <- Ident . show <$> getNewVar
|
||||||
emit $ SetVariable top (Alloca t')
|
-- -- allocated the primary type
|
||||||
|
-- emit $ SetVariable top (Alloca t')
|
||||||
-- set the first byte to the index of the constructor
|
--
|
||||||
emit $ SetVariable ptr $
|
-- -- set the first byte to the index of the constructor
|
||||||
GetElementPtrInbounds t' (Ref t')
|
-- emit $ SetVariable ptr $
|
||||||
(VIdent top I8) I32 (VInteger 0) I32 (VInteger 0)
|
-- GetElementPtrInbounds t' (Ref t')
|
||||||
emit $ Store I8 (VInteger $ numCI ci ) (Ref I8) ptr
|
-- (VIdent top I8) I32 (VInteger 0) I32 (VInteger 0)
|
||||||
|
-- emit $ Store I8 (VInteger $ numCI ci ) (Ref I8) ptr
|
||||||
-- get a pointer of the correct type
|
--
|
||||||
ptr' <- Ident . show <$> getNewVar
|
-- -- get a pointer of the correct type
|
||||||
emit $ SetVariable ptr' (Bitcast (Ref t') ptr (Ref $ CustomType id))
|
-- ptr' <- Ident . show <$> getNewVar
|
||||||
|
-- emit $ SetVariable ptr' (Bitcast (Ref t') ptr (Ref $ CustomType id))
|
||||||
--emit $ UnsafeRaw "\n"
|
--
|
||||||
|
-- --emit $ UnsafeRaw "\n"
|
||||||
foldM_ (\i (Ident arg_n, arg_t)-> do
|
--
|
||||||
let arg_t' = type2LlvmType arg_t
|
-- foldM_ (\i (Ident arg_n, arg_t)-> do
|
||||||
emit $ Comment (show arg_t' <>" "<> arg_n <> " " <> show i )
|
-- let arg_t' = type2LlvmType arg_t
|
||||||
elemPtr <- Ident . show <$> getNewVar
|
-- emit $ Comment (show arg_t' <>" "<> arg_n <> " " <> show i )
|
||||||
emit $ SetVariable elemPtr (
|
-- elemPtr <- Ident . show <$> getNewVar
|
||||||
GetElementPtrInbounds (CustomType id) (Ref (CustomType id))
|
-- emit $ SetVariable elemPtr (
|
||||||
(VIdent ptr' Ptr) I32
|
-- GetElementPtrInbounds (CustomType id) (Ref (CustomType id))
|
||||||
(VInteger 0) I32 (VInteger i))
|
-- (VIdent ptr' Ptr) I32
|
||||||
emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr elemPtr
|
-- (VInteger 0) I32 (VInteger i))
|
||||||
-- %2 = getelementptr inbounds %Foo_AInteger, %Foo_AInteger* %1, i32 0, i32 1
|
-- emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr elemPtr
|
||||||
-- store i32 42, i32* %2
|
-- -- %2 = getelementptr inbounds %Foo_AInteger, %Foo_AInteger* %1, i32 0, i32 1
|
||||||
pure $ i + 1-- + typeByteSize arg_t'
|
-- -- store i32 42, i32* %2
|
||||||
) 1 (argumentsCI ci)
|
-- pure $ i + 1-- + typeByteSize arg_t'
|
||||||
|
-- ) 1 (argumentsCI ci)
|
||||||
--emit $ UnsafeRaw "\n"
|
--
|
||||||
|
-- --emit $ UnsafeRaw "\n"
|
||||||
-- load and return the constructed value
|
--
|
||||||
load <- Ident . show <$> getNewVar
|
-- -- load and return the constructed value
|
||||||
emit $ SetVariable load (Load t' Ptr top)
|
-- load <- Ident . show <$> getNewVar
|
||||||
emit $ Ret t' (VIdent load t')
|
-- emit $ SetVariable load (Load t' Ptr top)
|
||||||
emit DefineEnd
|
-- emit $ Ret t' (VIdent load t')
|
||||||
|
-- emit DefineEnd
|
||||||
modify $ \s -> s { variableCount = 0 }
|
--
|
||||||
) c
|
-- modify $ \s -> s { variableCount = 0 }
|
||||||
compileScs (Bind (name, _t) args exp : xs) = do
|
-- ) c
|
||||||
emit $ UnsafeRaw "\n"
|
-- compileScs (Bind (name, _t) args exp : xs) = do
|
||||||
emit . Comment $ show name <> ": " <> show exp
|
-- emit $ UnsafeRaw "\n"
|
||||||
let args' = map (second type2LlvmType) args
|
-- emit . Comment $ show name <> ": " <> show exp
|
||||||
emit $ Define FastCC I64 {-(type2LlvmType t_return)-} name args'
|
-- let args' = map (second type2LlvmType) args
|
||||||
functionBody <- exprToValue exp
|
-- emit $ Define FastCC I64 {-(type2LlvmType t_return)-} name args'
|
||||||
if name == "main"
|
-- functionBody <- exprToValue exp
|
||||||
then mapM_ emit $ mainContent functionBody
|
-- if name == "main"
|
||||||
else emit $ Ret I64 functionBody
|
-- then mapM_ emit $ mainContent functionBody
|
||||||
emit DefineEnd
|
-- else emit $ Ret I64 functionBody
|
||||||
modify $ \s -> s { variableCount = 0 }
|
-- emit DefineEnd
|
||||||
compileScs xs
|
-- modify $ \s -> s { variableCount = 0 }
|
||||||
compileScs (DataStructure id@(Ident outer_id) ts : xs) = do
|
-- compileScs xs
|
||||||
let biggest_variant = maximum ((\(_, t) -> sum $ typeByteSize . type2LlvmType <$> t) <$> ts)
|
-- compileScs (DataStructure id@(Ident outer_id) ts : xs) = do
|
||||||
emit $ Type id [I8, Array biggest_variant I8]
|
-- let biggest_variant = maximum ((\(_, t) -> sum $ typeByteSize . type2LlvmType <$> t) <$> ts)
|
||||||
mapM_ (\(Ident inner_id, fi) -> do
|
-- emit $ Type id [I8, Array biggest_variant I8]
|
||||||
emit $ Type (Ident $ outer_id <> "_" <> inner_id) (I8 : map type2LlvmType fi)
|
-- mapM_ (\(Ident inner_id, fi) -> do
|
||||||
) ts
|
-- emit $ Type (Ident $ outer_id <> "_" <> inner_id) (I8 : map type2LlvmType fi)
|
||||||
compileScs xs
|
-- ) ts
|
||||||
|
-- compileScs xs
|
||||||
-- where
|
--
|
||||||
-- _t_return = snd $ partitionType (length args) t
|
-- -- where
|
||||||
|
-- -- _t_return = snd $ partitionType (length args) t
|
||||||
mainContent :: LLVMValue -> [LLVMIr]
|
--
|
||||||
mainContent var =
|
-- mainContent :: LLVMValue -> [LLVMIr]
|
||||||
[ UnsafeRaw $
|
-- mainContent var =
|
||||||
"call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n"
|
-- [ UnsafeRaw $
|
||||||
, -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
|
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n"
|
||||||
-- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2")
|
-- , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
|
||||||
-- , Label (Ident "b_1")
|
-- -- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2")
|
||||||
-- , UnsafeRaw
|
-- -- , Label (Ident "b_1")
|
||||||
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n"
|
-- -- , UnsafeRaw
|
||||||
-- , Br (Ident "end")
|
-- -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n"
|
||||||
-- , Label (Ident "b_2")
|
-- -- , Br (Ident "end")
|
||||||
-- , UnsafeRaw
|
-- -- , Label (Ident "b_2")
|
||||||
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n"
|
-- -- , UnsafeRaw
|
||||||
-- , Br (Ident "end")
|
-- -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n"
|
||||||
-- , Label (Ident "end")
|
-- -- , Br (Ident "end")
|
||||||
Ret I64 (VInteger 0)
|
-- -- , Label (Ident "end")
|
||||||
]
|
-- Ret I64 (VInteger 0)
|
||||||
|
-- ]
|
||||||
defaultStart :: [LLVMIr]
|
--
|
||||||
defaultStart = [ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n"
|
-- defaultStart :: [LLVMIr]
|
||||||
, UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
|
-- defaultStart = [ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n"
|
||||||
, UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
|
-- , UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
|
||||||
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
|
-- , UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
|
||||||
]
|
-- , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
|
||||||
|
-- ]
|
||||||
compileExp :: Exp -> CompilerState ()
|
--
|
||||||
compileExp (EInt int) = emitInt int
|
-- compileExp :: Exp -> CompilerState ()
|
||||||
compileExp (EAdd t e1 e2) = emitAdd t e1 e2
|
-- compileExp (EInt int) = emitInt int
|
||||||
compileExp (ESub t e1 e2) = emitSub t e1 e2
|
-- compileExp (EAdd t e1 e2) = emitAdd t e1 e2
|
||||||
compileExp (EId (name, _)) = emitIdent name
|
-- compileExp (ESub t e1 e2) = emitSub t e1 e2
|
||||||
compileExp (EApp t e1 e2) = emitApp t e1 e2
|
-- compileExp (EId (name, _)) = emitIdent name
|
||||||
compileExp (EAbs t ti e) = emitAbs t ti e
|
-- compileExp (EApp t e1 e2) = emitApp t e1 e2
|
||||||
compileExp (ELet binds e) = emitLet binds e
|
-- compileExp (EAbs t ti e) = emitAbs t ti e
|
||||||
compileExp (ECase t e cs) = emitECased t e cs
|
-- compileExp (ELet binds e) = emitLet binds e
|
||||||
-- go (EMul e1 e2) = emitMul e1 e2
|
-- compileExp (ECase t e cs) = emitECased t e cs
|
||||||
-- go (EDiv e1 e2) = emitDiv e1 e2
|
-- -- go (EMul e1 e2) = emitMul e1 e2
|
||||||
-- go (EMod e1 e2) = emitMod e1 e2
|
-- -- go (EDiv e1 e2) = emitDiv e1 e2
|
||||||
|
-- -- go (EMod e1 e2) = emitMod e1 e2
|
||||||
--- aux functions ---
|
--
|
||||||
emitECased :: Type -> Exp -> [(Type, Case)] -> CompilerState ()
|
-- --- aux functions ---
|
||||||
emitECased t e cases = do
|
-- emitECased :: Type -> Exp -> [(Type, Case)] -> CompilerState ()
|
||||||
let cs = snd <$> cases
|
-- emitECased t e cases = do
|
||||||
let ty = type2LlvmType t
|
-- let cs = snd <$> cases
|
||||||
vs <- exprToValue e
|
-- let ty = type2LlvmType t
|
||||||
lbl <- getNewLabel
|
-- vs <- exprToValue e
|
||||||
let label = Ident $ "escape_" <> show lbl
|
-- lbl <- getNewLabel
|
||||||
stackPtr <- getNewVar
|
-- let label = Ident $ "escape_" <> show lbl
|
||||||
emit $ SetVariable (Ident $ show stackPtr) (Alloca ty)
|
-- stackPtr <- getNewVar
|
||||||
mapM_ (emitCases ty label stackPtr vs) cs
|
-- emit $ SetVariable (Ident $ show stackPtr) (Alloca ty)
|
||||||
emit $ Label label
|
-- mapM_ (emitCases ty label stackPtr vs) cs
|
||||||
res <- getNewVar
|
-- emit $ Label label
|
||||||
emit $ SetVariable (Ident $ show res) (Load ty Ptr (Ident $ show stackPtr))
|
-- res <- getNewVar
|
||||||
where
|
-- emit $ SetVariable (Ident $ show res) (Load ty Ptr (Ident $ show stackPtr))
|
||||||
emitCases :: LLVMType -> Ident -> Integer -> LLVMValue -> Case -> CompilerState ()
|
-- where
|
||||||
emitCases ty label stackPtr vs (Case (GA.CInt i) exp) = do
|
-- emitCases :: LLVMType -> Ident -> Integer -> LLVMValue -> Case -> CompilerState ()
|
||||||
ns <- getNewVar
|
-- emitCases ty label stackPtr vs (Case (GA.CInt i) exp) = do
|
||||||
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
|
-- ns <- getNewVar
|
||||||
lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
|
-- lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
|
||||||
emit $ SetVariable (Ident $ show ns) (Icmp LLEq ty vs (VInteger i))
|
-- lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
|
||||||
emit $ BrCond (VIdent (Ident $ show ns) ty) lbl_succPos lbl_failPos
|
-- emit $ SetVariable (Ident $ show ns) (Icmp LLEq ty vs (VInteger i))
|
||||||
emit $ Label lbl_succPos
|
-- emit $ BrCond (VIdent (Ident $ show ns) ty) lbl_succPos lbl_failPos
|
||||||
val <- exprToValue exp
|
-- emit $ Label lbl_succPos
|
||||||
emit $ Store ty val Ptr (Ident . show $ stackPtr)
|
-- val <- exprToValue exp
|
||||||
emit $ Br label
|
-- emit $ Store ty val Ptr (Ident . show $ stackPtr)
|
||||||
emit $ Label lbl_failPos
|
-- emit $ Br label
|
||||||
emitCases ty label stackPtr _ (Case GA.CatchAll exp) = do
|
-- emit $ Label lbl_failPos
|
||||||
val <- exprToValue exp
|
-- emitCases ty label stackPtr _ (Case GA.CatchAll exp) = do
|
||||||
emit $ Store ty val Ptr (Ident . show $ stackPtr)
|
-- val <- exprToValue exp
|
||||||
emit $ Br label
|
-- emit $ Store ty val Ptr (Ident . show $ stackPtr)
|
||||||
|
-- emit $ Br label
|
||||||
|
--
|
||||||
emitAbs :: Type -> Id -> Exp -> CompilerState ()
|
--
|
||||||
emitAbs _t tid e = do
|
-- emitAbs :: Type -> Id -> Exp -> CompilerState ()
|
||||||
emit . Comment $
|
-- emitAbs _t tid e = do
|
||||||
"Lambda escaped previous stages: \\" <> show tid <> " . " <> show e
|
-- emit . Comment $
|
||||||
emitLet :: Bind -> Exp -> CompilerState ()
|
-- "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e
|
||||||
emitLet xs e = do
|
-- emitLet :: Bind -> Exp -> CompilerState ()
|
||||||
emit $
|
-- emitLet xs e = do
|
||||||
Comment $
|
-- emit $
|
||||||
concat
|
-- Comment $
|
||||||
[ "ELet ("
|
-- concat
|
||||||
, show xs
|
-- [ "ELet ("
|
||||||
, " = "
|
-- , show xs
|
||||||
, show e
|
-- , " = "
|
||||||
, ") is not implemented!"
|
-- , show e
|
||||||
]
|
-- , ") is not implemented!"
|
||||||
|
-- ]
|
||||||
emitApp :: Type -> Exp -> Exp -> CompilerState ()
|
--
|
||||||
emitApp t e1 e2 = appEmitter t e1 e2 []
|
-- emitApp :: Type -> Exp -> Exp -> CompilerState ()
|
||||||
where
|
-- emitApp t e1 e2 = appEmitter t e1 e2 []
|
||||||
appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState ()
|
-- where
|
||||||
appEmitter t e1 e2 stack = do
|
-- appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState ()
|
||||||
let newStack = e2 : stack
|
-- appEmitter t e1 e2 stack = do
|
||||||
case e1 of
|
-- let newStack = e2 : stack
|
||||||
EApp _ e1' e2' -> appEmitter t e1' e2' newStack
|
-- case e1 of
|
||||||
EId id@(name, _) -> do
|
-- EApp _ e1' e2' -> appEmitter t e1' e2' newStack
|
||||||
args <- traverse exprToValue newStack
|
-- EId id@(name, _) -> do
|
||||||
vs <- getNewVar
|
-- args <- traverse exprToValue newStack
|
||||||
funcs <- gets functions
|
-- vs <- getNewVar
|
||||||
let visibility = maybe Local (const Global) $ Map.lookup id funcs
|
-- funcs <- gets functions
|
||||||
args' = map (first valueGetType . dupe) args
|
-- let visibility = maybe Local (const Global) $ Map.lookup id funcs
|
||||||
call = Call FastCC (type2LlvmType t) visibility name args'
|
-- args' = map (first valueGetType . dupe) args
|
||||||
emit $ SetVariable (Ident $ show vs) call
|
-- call = Call FastCC (type2LlvmType t) visibility name args'
|
||||||
x -> do
|
-- emit $ SetVariable (Ident $ show vs) call
|
||||||
emit . Comment $ "The unspeakable happened: "
|
-- x -> do
|
||||||
emit . Comment $ show x
|
-- emit . Comment $ "The unspeakable happened: "
|
||||||
|
-- emit . Comment $ show x
|
||||||
emitIdent :: Ident -> CompilerState ()
|
--
|
||||||
emitIdent id = do
|
-- emitIdent :: Ident -> CompilerState ()
|
||||||
-- !!this should never happen!!
|
-- emitIdent id = do
|
||||||
emit $ Comment "This should not have happened!"
|
-- -- !!this should never happen!!
|
||||||
emit $ Variable id
|
-- emit $ Comment "This should not have happened!"
|
||||||
emit $ UnsafeRaw "\n"
|
-- emit $ Variable id
|
||||||
|
-- emit $ UnsafeRaw "\n"
|
||||||
emitInt :: Integer -> CompilerState ()
|
--
|
||||||
emitInt i = do
|
-- emitInt :: Integer -> CompilerState ()
|
||||||
-- !!this should never happen!!
|
-- emitInt i = do
|
||||||
varCount <- getNewVar
|
-- -- !!this should never happen!!
|
||||||
emit $ Comment "This should not have happened!"
|
-- varCount <- getNewVar
|
||||||
emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0))
|
-- emit $ Comment "This should not have happened!"
|
||||||
|
-- emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0))
|
||||||
emitAdd :: Type -> Exp -> Exp -> CompilerState ()
|
--
|
||||||
emitAdd t e1 e2 = do
|
-- emitAdd :: Type -> Exp -> Exp -> CompilerState ()
|
||||||
v1 <- exprToValue e1
|
-- emitAdd t e1 e2 = do
|
||||||
v2 <- exprToValue e2
|
-- v1 <- exprToValue e1
|
||||||
v <- getNewVar
|
-- v2 <- exprToValue e2
|
||||||
emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2)
|
-- v <- getNewVar
|
||||||
|
-- emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2)
|
||||||
emitSub :: Type -> Exp -> Exp -> CompilerState ()
|
--
|
||||||
emitSub t e1 e2 = do
|
-- emitSub :: Type -> Exp -> Exp -> CompilerState ()
|
||||||
v1 <- exprToValue e1
|
-- emitSub t e1 e2 = do
|
||||||
v2 <- exprToValue e2
|
-- v1 <- exprToValue e1
|
||||||
v <- getNewVar
|
-- v2 <- exprToValue e2
|
||||||
emit $ SetVariable (Ident $ show v) (Sub (type2LlvmType t) v1 v2)
|
-- v <- getNewVar
|
||||||
|
-- emit $ SetVariable (Ident $ show v) (Sub (type2LlvmType t) v1 v2)
|
||||||
-- emitMul :: Exp -> Exp -> CompilerState ()
|
--
|
||||||
-- emitMul e1 e2 = do
|
-- -- emitMul :: Exp -> Exp -> CompilerState ()
|
||||||
-- (v1,v2) <- binExprToValues e1 e2
|
-- -- emitMul e1 e2 = do
|
||||||
-- increaseVarCount
|
-- -- (v1,v2) <- binExprToValues e1 e2
|
||||||
-- v <- gets variableCount
|
-- -- increaseVarCount
|
||||||
-- emit $ SetVariable $ Ident $ show v
|
-- -- v <- gets variableCount
|
||||||
-- emit $ Mul I64 v1 v2
|
-- -- emit $ SetVariable $ Ident $ show v
|
||||||
|
-- -- emit $ Mul I64 v1 v2
|
||||||
-- emitMod :: Exp -> Exp -> CompilerState ()
|
--
|
||||||
-- emitMod e1 e2 = do
|
-- -- emitMod :: Exp -> Exp -> CompilerState ()
|
||||||
-- -- `let m a b = rem (abs $ b + a) b`
|
-- -- emitMod e1 e2 = do
|
||||||
-- (v1,v2) <- binExprToValues e1 e2
|
-- -- -- `let m a b = rem (abs $ b + a) b`
|
||||||
-- increaseVarCount
|
-- -- (v1,v2) <- binExprToValues e1 e2
|
||||||
-- vadd <- gets variableCount
|
-- -- increaseVarCount
|
||||||
-- emit $ SetVariable $ Ident $ show vadd
|
-- -- vadd <- gets variableCount
|
||||||
-- emit $ Add I64 v1 v2
|
-- -- emit $ SetVariable $ Ident $ show vadd
|
||||||
|
-- -- emit $ Add I64 v1 v2
|
||||||
|
-- --
|
||||||
|
-- -- increaseVarCount
|
||||||
|
-- -- vabs <- gets variableCount
|
||||||
|
-- -- emit $ SetVariable $ Ident $ show vabs
|
||||||
|
-- -- emit $ Call I64 (Ident "llvm.abs.i64")
|
||||||
|
-- -- [ (I64, VIdent (Ident $ show vadd))
|
||||||
|
-- -- , (I1, VInteger 1)
|
||||||
|
-- -- ]
|
||||||
|
-- -- increaseVarCount
|
||||||
|
-- -- v <- gets variableCount
|
||||||
|
-- -- emit $ SetVariable $ Ident $ show v
|
||||||
|
-- -- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2
|
||||||
|
--
|
||||||
|
-- -- emitDiv :: Exp -> Exp -> CompilerState ()
|
||||||
|
-- -- emitDiv e1 e2 = do
|
||||||
|
-- -- (v1,v2) <- binExprToValues e1 e2
|
||||||
|
-- -- increaseVarCount
|
||||||
|
-- -- v <- gets variableCount
|
||||||
|
-- -- emit $ SetVariable $ Ident $ show v
|
||||||
|
-- -- emit $ Div I64 v1 v2
|
||||||
|
--
|
||||||
|
-- exprToValue :: Exp -> CompilerState LLVMValue
|
||||||
|
-- exprToValue = \case
|
||||||
|
-- EInt i -> pure $ VInteger i
|
||||||
|
--
|
||||||
|
-- EId id@(name, t) -> do
|
||||||
|
-- funcs <- gets functions
|
||||||
|
-- case Map.lookup id funcs of
|
||||||
|
-- Just fi -> do
|
||||||
|
-- if numArgs fi == 0
|
||||||
|
-- then do
|
||||||
|
-- vc <- getNewVar
|
||||||
|
-- emit $ SetVariable (Ident $ show vc)
|
||||||
|
-- (Call FastCC (type2LlvmType t) Global name [])
|
||||||
|
-- pure $ VIdent (Ident $ show vc) (type2LlvmType t)
|
||||||
|
-- else pure $ VFunction name Global (type2LlvmType t)
|
||||||
|
-- Nothing -> pure $ VIdent name (type2LlvmType t)
|
||||||
|
--
|
||||||
|
-- e -> do
|
||||||
|
-- compileExp e
|
||||||
|
-- v <- getVarCount
|
||||||
|
-- pure $ VIdent (Ident $ show v) (getType e)
|
||||||
|
--
|
||||||
|
-- type2LlvmType :: Type -> LLVMType
|
||||||
|
-- type2LlvmType = \case
|
||||||
|
-- TInt -> I64
|
||||||
|
-- TFun t xs -> do
|
||||||
|
-- let (t', xs') = function2LLVMType xs [type2LlvmType t]
|
||||||
|
-- Function t' xs'
|
||||||
|
-- TPol t -> CustomType t
|
||||||
|
-- where
|
||||||
|
-- function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
|
||||||
|
-- function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s)
|
||||||
|
-- function2LLVMType x s = (type2LlvmType x, s)
|
||||||
|
--
|
||||||
|
-- getType :: Exp -> LLVMType
|
||||||
|
-- getType (EInt _) = I64
|
||||||
|
-- getType (EAdd t _ _) = type2LlvmType t
|
||||||
|
-- getType (ESub t _ _) = type2LlvmType t
|
||||||
|
-- getType (EId (_, t)) = type2LlvmType t
|
||||||
|
-- getType (EApp t _ _) = type2LlvmType t
|
||||||
|
-- getType (EAbs t _ _) = type2LlvmType t
|
||||||
|
-- getType (ELet _ e) = getType e
|
||||||
|
-- getType (ECase t _ _) = type2LlvmType t
|
||||||
|
--
|
||||||
|
-- valueGetType :: LLVMValue -> LLVMType
|
||||||
|
-- valueGetType (VInteger _) = I64
|
||||||
|
-- valueGetType (VIdent _ t) = t
|
||||||
|
-- valueGetType (VConstant s) = Array (fromIntegral $ length s) I8
|
||||||
|
-- valueGetType (VFunction _ _ t) = t
|
||||||
|
--
|
||||||
|
-- typeByteSize :: LLVMType -> Integer
|
||||||
|
-- typeByteSize I1 = 1
|
||||||
|
-- typeByteSize I8 = 1
|
||||||
|
-- typeByteSize I32 = 4
|
||||||
|
-- typeByteSize I64 = 8
|
||||||
|
-- typeByteSize Ptr = 8
|
||||||
|
-- typeByteSize (Ref _) = 8
|
||||||
|
-- typeByteSize (Function _ _) = 8
|
||||||
|
-- typeByteSize (Array n t) = n * typeByteSize t
|
||||||
|
-- typeByteSize (CustomType _) = 8
|
||||||
--
|
--
|
||||||
-- increaseVarCount
|
|
||||||
-- vabs <- gets variableCount
|
|
||||||
-- emit $ SetVariable $ Ident $ show vabs
|
|
||||||
-- emit $ Call I64 (Ident "llvm.abs.i64")
|
|
||||||
-- [ (I64, VIdent (Ident $ show vadd))
|
|
||||||
-- , (I1, VInteger 1)
|
|
||||||
-- ]
|
|
||||||
-- increaseVarCount
|
|
||||||
-- v <- gets variableCount
|
|
||||||
-- emit $ SetVariable $ Ident $ show v
|
|
||||||
-- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2
|
|
||||||
|
|
||||||
-- emitDiv :: Exp -> Exp -> CompilerState ()
|
|
||||||
-- emitDiv e1 e2 = do
|
|
||||||
-- (v1,v2) <- binExprToValues e1 e2
|
|
||||||
-- increaseVarCount
|
|
||||||
-- v <- gets variableCount
|
|
||||||
-- emit $ SetVariable $ Ident $ show v
|
|
||||||
-- emit $ Div I64 v1 v2
|
|
||||||
|
|
||||||
exprToValue :: Exp -> CompilerState LLVMValue
|
|
||||||
exprToValue = \case
|
|
||||||
EInt i -> pure $ VInteger i
|
|
||||||
|
|
||||||
EId id@(name, t) -> do
|
|
||||||
funcs <- gets functions
|
|
||||||
case Map.lookup id funcs of
|
|
||||||
Just fi -> do
|
|
||||||
if numArgs fi == 0
|
|
||||||
then do
|
|
||||||
vc <- getNewVar
|
|
||||||
emit $ SetVariable (Ident $ show vc)
|
|
||||||
(Call FastCC (type2LlvmType t) Global name [])
|
|
||||||
pure $ VIdent (Ident $ show vc) (type2LlvmType t)
|
|
||||||
else pure $ VFunction name Global (type2LlvmType t)
|
|
||||||
Nothing -> pure $ VIdent name (type2LlvmType t)
|
|
||||||
|
|
||||||
e -> do
|
|
||||||
compileExp e
|
|
||||||
v <- getVarCount
|
|
||||||
pure $ VIdent (Ident $ show v) (getType e)
|
|
||||||
|
|
||||||
type2LlvmType :: Type -> LLVMType
|
|
||||||
type2LlvmType = \case
|
|
||||||
TInt -> I64
|
|
||||||
TFun t xs -> do
|
|
||||||
let (t', xs') = function2LLVMType xs [type2LlvmType t]
|
|
||||||
Function t' xs'
|
|
||||||
TPol t -> CustomType t
|
|
||||||
where
|
|
||||||
function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
|
|
||||||
function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s)
|
|
||||||
function2LLVMType x s = (type2LlvmType x, s)
|
|
||||||
|
|
||||||
getType :: Exp -> LLVMType
|
|
||||||
getType (EInt _) = I64
|
|
||||||
getType (EAdd t _ _) = type2LlvmType t
|
|
||||||
getType (ESub t _ _) = type2LlvmType t
|
|
||||||
getType (EId (_, t)) = type2LlvmType t
|
|
||||||
getType (EApp t _ _) = type2LlvmType t
|
|
||||||
getType (EAbs t _ _) = type2LlvmType t
|
|
||||||
getType (ELet _ e) = getType e
|
|
||||||
getType (ECase t _ _) = type2LlvmType t
|
|
||||||
|
|
||||||
valueGetType :: LLVMValue -> LLVMType
|
|
||||||
valueGetType (VInteger _) = I64
|
|
||||||
valueGetType (VIdent _ t) = t
|
|
||||||
valueGetType (VConstant s) = Array (fromIntegral $ length s) I8
|
|
||||||
valueGetType (VFunction _ _ t) = t
|
|
||||||
|
|
||||||
typeByteSize :: LLVMType -> Integer
|
|
||||||
typeByteSize I1 = 1
|
|
||||||
typeByteSize I8 = 1
|
|
||||||
typeByteSize I32 = 4
|
|
||||||
typeByteSize I64 = 8
|
|
||||||
typeByteSize Ptr = 8
|
|
||||||
typeByteSize (Ref _) = 8
|
|
||||||
typeByteSize (Function _ _) = 8
|
|
||||||
typeByteSize (Array n t) = n * typeByteSize t
|
|
||||||
typeByteSize (CustomType _) = 8
|
|
||||||
|
|
|
||||||
|
|
@ -1,239 +1,241 @@
|
||||||
{-# LANGUAGE LambdaCase #-}
|
module Codegen.LlvmIr where
|
||||||
|
-- {-# LANGUAGE LambdaCase #-}
|
||||||
module Codegen.LlvmIr (
|
--
|
||||||
LLVMType (..),
|
-- module Codegen.LlvmIr (
|
||||||
LLVMIr (..),
|
-- LLVMType (..),
|
||||||
llvmIrToString,
|
-- LLVMIr (..),
|
||||||
LLVMValue (..),
|
-- llvmIrToString,
|
||||||
LLVMComp (..),
|
-- LLVMValue (..),
|
||||||
Visibility (..),
|
-- LLVMComp (..),
|
||||||
CallingConvention (..)
|
-- Visibility (..),
|
||||||
) where
|
-- CallingConvention (..)
|
||||||
|
-- ) where
|
||||||
import Data.List (intercalate)
|
--
|
||||||
import TypeChecker.TypeCheckerIr
|
-- import Data.List (intercalate)
|
||||||
|
-- import TypeChecker.TypeCheckerIr
|
||||||
data CallingConvention = TailCC | FastCC | CCC | ColdCC
|
--
|
||||||
instance Show CallingConvention where
|
-- data CallingConvention = TailCC | FastCC | CCC | ColdCC
|
||||||
show :: CallingConvention -> String
|
-- instance Show CallingConvention where
|
||||||
show TailCC = "tailcc"
|
-- show :: CallingConvention -> String
|
||||||
show FastCC = "fastcc"
|
-- show TailCC = "tailcc"
|
||||||
show CCC = "ccc"
|
-- show FastCC = "fastcc"
|
||||||
show ColdCC = "coldcc"
|
-- show CCC = "ccc"
|
||||||
|
-- show ColdCC = "coldcc"
|
||||||
-- | A datatype which represents some basic LLVM types
|
--
|
||||||
data LLVMType
|
-- -- | A datatype which represents some basic LLVM types
|
||||||
= I1
|
-- data LLVMType
|
||||||
| I8
|
-- = I1
|
||||||
| I32
|
-- | I8
|
||||||
| I64
|
-- | I32
|
||||||
| Ptr
|
-- | I64
|
||||||
| Ref LLVMType
|
-- | Ptr
|
||||||
| Function LLVMType [LLVMType]
|
-- | Ref LLVMType
|
||||||
| Array Integer LLVMType
|
-- | Function LLVMType [LLVMType]
|
||||||
| CustomType Ident
|
-- | Array Integer LLVMType
|
||||||
|
-- | CustomType Ident
|
||||||
instance Show LLVMType where
|
--
|
||||||
show :: LLVMType -> String
|
-- instance Show LLVMType where
|
||||||
show = \case
|
-- show :: LLVMType -> String
|
||||||
I1 -> "i1"
|
-- show = \case
|
||||||
I8 -> "i8"
|
-- I1 -> "i1"
|
||||||
I32 -> "i32"
|
-- I8 -> "i8"
|
||||||
I64 -> "i64"
|
-- I32 -> "i32"
|
||||||
Ptr -> "ptr"
|
-- I64 -> "i64"
|
||||||
Ref ty -> show ty <> "*"
|
-- Ptr -> "ptr"
|
||||||
Function t xs -> show t <> " (" <> intercalate ", " (map show xs) <> ")*"
|
-- Ref ty -> show ty <> "*"
|
||||||
Array n ty -> concat ["[", show n, " x ", show ty, "]"]
|
-- Function t xs -> show t <> " (" <> intercalate ", " (map show xs) <> ")*"
|
||||||
CustomType (Ident ty) -> "%" <> ty
|
-- Array n ty -> concat ["[", show n, " x ", show ty, "]"]
|
||||||
|
-- CustomType (Ident ty) -> "%" <> ty
|
||||||
data LLVMComp
|
--
|
||||||
= LLEq
|
-- data LLVMComp
|
||||||
| LLNe
|
-- = LLEq
|
||||||
| LLUgt
|
-- | LLNe
|
||||||
| LLUge
|
-- | LLUgt
|
||||||
| LLUlt
|
-- | LLUge
|
||||||
| LLUle
|
-- | LLUlt
|
||||||
| LLSgt
|
-- | LLUle
|
||||||
| LLSge
|
-- | LLSgt
|
||||||
| LLSlt
|
-- | LLSge
|
||||||
| LLSle
|
-- | LLSlt
|
||||||
instance Show LLVMComp where
|
-- | LLSle
|
||||||
show :: LLVMComp -> String
|
-- instance Show LLVMComp where
|
||||||
show = \case
|
-- show :: LLVMComp -> String
|
||||||
LLEq -> "eq"
|
-- show = \case
|
||||||
LLNe -> "ne"
|
-- LLEq -> "eq"
|
||||||
LLUgt -> "ugt"
|
-- LLNe -> "ne"
|
||||||
LLUge -> "uge"
|
-- LLUgt -> "ugt"
|
||||||
LLUlt -> "ult"
|
-- LLUge -> "uge"
|
||||||
LLUle -> "ule"
|
-- LLUlt -> "ult"
|
||||||
LLSgt -> "sgt"
|
-- LLUle -> "ule"
|
||||||
LLSge -> "sge"
|
-- LLSgt -> "sgt"
|
||||||
LLSlt -> "slt"
|
-- LLSge -> "sge"
|
||||||
LLSle -> "sle"
|
-- LLSlt -> "slt"
|
||||||
|
-- LLSle -> "sle"
|
||||||
data Visibility = Local | Global
|
--
|
||||||
instance Show Visibility where
|
-- data Visibility = Local | Global
|
||||||
show :: Visibility -> String
|
-- instance Show Visibility where
|
||||||
show Local = "%"
|
-- show :: Visibility -> String
|
||||||
show Global = "@"
|
-- show Local = "%"
|
||||||
|
-- show Global = "@"
|
||||||
-- | Represents a LLVM "value", as in an integer, a register variable,
|
--
|
||||||
-- or a string contstant
|
-- -- | Represents a LLVM "value", as in an integer, a register variable,
|
||||||
data LLVMValue
|
-- -- or a string contstant
|
||||||
= VInteger Integer
|
-- data LLVMValue
|
||||||
| VIdent Ident LLVMType
|
-- = VInteger Integer
|
||||||
| VConstant String
|
-- | VIdent Ident LLVMType
|
||||||
| VFunction Ident Visibility LLVMType
|
-- | VConstant String
|
||||||
|
-- | VFunction Ident Visibility LLVMType
|
||||||
instance Show LLVMValue where
|
--
|
||||||
show :: LLVMValue -> String
|
-- instance Show LLVMValue where
|
||||||
show v = case v of
|
-- show :: LLVMValue -> String
|
||||||
VInteger i -> show i
|
-- show v = case v of
|
||||||
VIdent (Ident n) _ -> "%" <> n
|
-- VInteger i -> show i
|
||||||
VFunction (Ident n) vis _ -> show vis <> n
|
-- VIdent (Ident n) _ -> "%" <> n
|
||||||
VConstant s -> "c" <> show s
|
-- VFunction (Ident n) vis _ -> show vis <> n
|
||||||
|
-- VConstant s -> "c" <> show s
|
||||||
type Params = [(Ident, LLVMType)]
|
--
|
||||||
type Args = [(LLVMType, LLVMValue)]
|
-- type Params = [(Ident, LLVMType)]
|
||||||
|
-- type Args = [(LLVMType, LLVMValue)]
|
||||||
-- | A datatype which represents different instructions in LLVM
|
--
|
||||||
data LLVMIr
|
-- -- | A datatype which represents different instructions in LLVM
|
||||||
= Type Ident [LLVMType]
|
-- data LLVMIr
|
||||||
| Define CallingConvention LLVMType Ident Params
|
-- = Type Ident [LLVMType]
|
||||||
| DefineEnd
|
-- | Define CallingConvention LLVMType Ident Params
|
||||||
| Declare LLVMType Ident Params
|
-- | DefineEnd
|
||||||
| SetVariable Ident LLVMIr
|
-- | Declare LLVMType Ident Params
|
||||||
| Variable Ident
|
-- | SetVariable Ident LLVMIr
|
||||||
| GetElementPtrInbounds LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue
|
-- | Variable Ident
|
||||||
| Add LLVMType LLVMValue LLVMValue
|
-- | GetElementPtrInbounds LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue
|
||||||
| Sub LLVMType LLVMValue LLVMValue
|
-- | Add LLVMType LLVMValue LLVMValue
|
||||||
| Div LLVMType LLVMValue LLVMValue
|
-- | Sub LLVMType LLVMValue LLVMValue
|
||||||
| Mul LLVMType LLVMValue LLVMValue
|
-- | Div LLVMType LLVMValue LLVMValue
|
||||||
| Srem LLVMType LLVMValue LLVMValue
|
-- | Mul LLVMType LLVMValue LLVMValue
|
||||||
| Icmp LLVMComp LLVMType LLVMValue LLVMValue
|
-- | Srem LLVMType LLVMValue LLVMValue
|
||||||
| Br Ident
|
-- | Icmp LLVMComp LLVMType LLVMValue LLVMValue
|
||||||
| BrCond LLVMValue Ident Ident
|
-- | Br Ident
|
||||||
| Label Ident
|
-- | BrCond LLVMValue Ident Ident
|
||||||
| Call CallingConvention LLVMType Visibility Ident Args
|
-- | Label Ident
|
||||||
| Alloca LLVMType
|
-- | Call CallingConvention LLVMType Visibility Ident Args
|
||||||
| Store LLVMType LLVMValue LLVMType Ident
|
-- | Alloca LLVMType
|
||||||
| Load LLVMType LLVMType Ident
|
-- | Store LLVMType LLVMValue LLVMType Ident
|
||||||
| Bitcast LLVMType Ident LLVMType
|
-- | Load LLVMType LLVMType Ident
|
||||||
| Ret LLVMType LLVMValue
|
-- | Bitcast LLVMType Ident LLVMType
|
||||||
| Comment String
|
-- | Ret LLVMType LLVMValue
|
||||||
| UnsafeRaw String -- This should generally be avoided, and proper
|
-- | Comment String
|
||||||
-- instructions should be used in its place
|
-- | UnsafeRaw String -- This should generally be avoided, and proper
|
||||||
deriving (Show)
|
-- -- instructions should be used in its place
|
||||||
|
-- deriving (Show)
|
||||||
-- | Converts a list of LLVMIr instructions to a string
|
--
|
||||||
llvmIrToString :: [LLVMIr] -> String
|
-- -- | Converts a list of LLVMIr instructions to a string
|
||||||
llvmIrToString = go 0
|
-- llvmIrToString :: [LLVMIr] -> String
|
||||||
where
|
-- llvmIrToString = go 0
|
||||||
go :: Int -> [LLVMIr] -> String
|
-- where
|
||||||
go _ [] = mempty
|
-- go :: Int -> [LLVMIr] -> String
|
||||||
go i (x : xs) = do
|
-- go _ [] = mempty
|
||||||
let (i', n) = case x of
|
-- go i (x : xs) = do
|
||||||
Define{} -> (i + 1, 0)
|
-- let (i', n) = case x of
|
||||||
DefineEnd -> (i - 1, 0)
|
-- Define{} -> (i + 1, 0)
|
||||||
_ -> (i, i)
|
-- DefineEnd -> (i - 1, 0)
|
||||||
insToString n x <> go i' xs
|
-- _ -> (i, i)
|
||||||
|
-- insToString n x <> go i' xs
|
||||||
{- | Converts a LLVM inststruction to a String, allowing for printing etc.
|
--
|
||||||
The integer represents the indentation
|
-- {- | Converts a LLVM inststruction to a String, allowing for printing etc.
|
||||||
-}
|
-- The integer represents the indentation
|
||||||
{- FOURMOLU_DISABLE -}
|
-- -}
|
||||||
insToString :: Int -> LLVMIr -> String
|
-- {- FOURMOLU_DISABLE -}
|
||||||
insToString i l =
|
-- insToString :: Int -> LLVMIr -> String
|
||||||
replicate i '\t' <> case l of
|
-- insToString i l =
|
||||||
(GetElementPtrInbounds t1 t2 p t3 v1 t4 v2) -> do
|
-- replicate i '\t' <> case l of
|
||||||
-- getelementptr inbounds %Foo, %Foo* %x, i32 0, i32 0
|
-- (GetElementPtrInbounds t1 t2 p t3 v1 t4 v2) -> do
|
||||||
concat
|
-- -- getelementptr inbounds %Foo, %Foo* %x, i32 0, i32 0
|
||||||
[ "getelementptr inbounds ", show t1, ", " , show t2
|
-- concat
|
||||||
, " ", show p, ", ", show t3, " ", show v1,
|
-- [ "getelementptr inbounds ", show t1, ", " , show t2
|
||||||
", ", show t4, " ", show v2, "\n" ]
|
-- , " ", show p, ", ", show t3, " ", show v1,
|
||||||
(Type (Ident n) types) ->
|
-- ", ", show t4, " ", show v2, "\n" ]
|
||||||
concat
|
-- (Type (Ident n) types) ->
|
||||||
[ "%", n, " = type { "
|
-- concat
|
||||||
, intercalate ", " (map show types)
|
-- [ "%", n, " = type { "
|
||||||
, " }\n"
|
-- , intercalate ", " (map show types)
|
||||||
]
|
-- , " }\n"
|
||||||
(Define c t (Ident i) params) ->
|
-- ]
|
||||||
concat
|
-- (Define c t (Ident i) params) ->
|
||||||
[ "define ", show c, " ", show t, " @", i
|
-- concat
|
||||||
, "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params)
|
-- [ "define ", show c, " ", show t, " @", i
|
||||||
, ") {\n"
|
-- , "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params)
|
||||||
]
|
-- , ") {\n"
|
||||||
DefineEnd -> "}\n"
|
-- ]
|
||||||
(Declare _t (Ident _i) _params) -> undefined
|
-- DefineEnd -> "}\n"
|
||||||
(SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir]
|
-- (Declare _t (Ident _i) _params) -> undefined
|
||||||
(Add t v1 v2) ->
|
-- (SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir]
|
||||||
concat
|
-- (Add t v1 v2) ->
|
||||||
[ "add ", show t, " ", show v1
|
-- concat
|
||||||
, ", ", show v2, "\n"
|
-- [ "add ", show t, " ", show v1
|
||||||
]
|
-- , ", ", show v2, "\n"
|
||||||
(Sub t v1 v2) ->
|
-- ]
|
||||||
concat
|
-- (Sub t v1 v2) ->
|
||||||
[ "sub ", show t, " ", show v1, ", "
|
-- concat
|
||||||
, show v2, "\n"
|
-- [ "sub ", show t, " ", show v1, ", "
|
||||||
]
|
-- , show v2, "\n"
|
||||||
(Div t v1 v2) ->
|
-- ]
|
||||||
concat
|
-- (Div t v1 v2) ->
|
||||||
[ "sdiv ", show t, " ", show v1, ", "
|
-- concat
|
||||||
, show v2, "\n"
|
-- [ "sdiv ", show t, " ", show v1, ", "
|
||||||
]
|
-- , show v2, "\n"
|
||||||
(Mul t v1 v2) ->
|
-- ]
|
||||||
concat
|
-- (Mul t v1 v2) ->
|
||||||
[ "mul ", show t, " ", show v1
|
-- concat
|
||||||
, ", ", show v2, "\n"
|
-- [ "mul ", show t, " ", show v1
|
||||||
]
|
-- , ", ", show v2, "\n"
|
||||||
(Srem t v1 v2) ->
|
-- ]
|
||||||
concat
|
-- (Srem t v1 v2) ->
|
||||||
[ "srem ", show t, " ", show v1, ", "
|
-- concat
|
||||||
, show v2, "\n"
|
-- [ "srem ", show t, " ", show v1, ", "
|
||||||
]
|
-- , show v2, "\n"
|
||||||
(Call c t vis (Ident i) arg) ->
|
-- ]
|
||||||
concat
|
-- (Call c t vis (Ident i) arg) ->
|
||||||
[ "call ", show c, " ", show t, " ", show vis, i, "("
|
-- concat
|
||||||
, intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg
|
-- [ "call ", show c, " ", show t, " ", show vis, i, "("
|
||||||
, ")\n"
|
-- , intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg
|
||||||
]
|
-- , ")\n"
|
||||||
(Alloca t) -> unwords ["alloca", show t, "\n"]
|
-- ]
|
||||||
(Store t1 val t2 (Ident id2)) ->
|
-- (Alloca t) -> unwords ["alloca", show t, "\n"]
|
||||||
concat
|
-- (Store t1 val t2 (Ident id2)) ->
|
||||||
[ "store ", show t1, " ", show val
|
-- concat
|
||||||
, ", ", show t2 , " %", id2, "\n"
|
-- [ "store ", show t1, " ", show val
|
||||||
]
|
-- , ", ", show t2 , " %", id2, "\n"
|
||||||
(Load t1 t2 (Ident addr)) ->
|
-- ]
|
||||||
concat
|
-- (Load t1 t2 (Ident addr)) ->
|
||||||
[ "load ", show t1, ", "
|
-- concat
|
||||||
, show t2, " %", addr, "\n"
|
-- [ "load ", show t1, ", "
|
||||||
]
|
-- , show t2, " %", addr, "\n"
|
||||||
(Bitcast t1 (Ident i) t2) ->
|
-- ]
|
||||||
concat
|
-- (Bitcast t1 (Ident i) t2) ->
|
||||||
[ "bitcast ", show t1, " %"
|
-- concat
|
||||||
, i, " to ", show t2, "\n"
|
-- [ "bitcast ", show t1, " %"
|
||||||
]
|
-- , i, " to ", show t2, "\n"
|
||||||
(Icmp comp t v1 v2) ->
|
-- ]
|
||||||
concat
|
-- (Icmp comp t v1 v2) ->
|
||||||
[ "icmp ", show comp, " ", show t
|
-- concat
|
||||||
, " ", show v1, ", ", show v2, "\n"
|
-- [ "icmp ", show comp, " ", show t
|
||||||
]
|
-- , " ", show v1, ", ", show v2, "\n"
|
||||||
(Ret t v) ->
|
-- ]
|
||||||
concat
|
-- (Ret t v) ->
|
||||||
[ "ret ", show t, " "
|
-- concat
|
||||||
, show v, "\n"
|
-- [ "ret ", show t, " "
|
||||||
]
|
-- , show v, "\n"
|
||||||
(UnsafeRaw s) -> s
|
-- ]
|
||||||
(Label (Ident s)) -> "\n" <> lblPfx <> s <> ":\n"
|
-- (UnsafeRaw s) -> s
|
||||||
(Br (Ident s)) -> "br label %" <> lblPfx <> s <> "\n"
|
-- (Label (Ident s)) -> "\n" <> lblPfx <> s <> ":\n"
|
||||||
(BrCond val (Ident s1) (Ident s2)) ->
|
-- (Br (Ident s)) -> "br label %" <> lblPfx <> s <> "\n"
|
||||||
concat
|
-- (BrCond val (Ident s1) (Ident s2)) ->
|
||||||
[ "br i1 ", show val, ", ", "label %"
|
-- concat
|
||||||
, lblPfx, s1, ", ", "label %", lblPfx, s2, "\n"
|
-- [ "br i1 ", show val, ", ", "label %"
|
||||||
]
|
-- , lblPfx, s1, ", ", "label %", lblPfx, s2, "\n"
|
||||||
(Comment s) -> "; " <> s <> "\n"
|
-- ]
|
||||||
(Variable (Ident id)) -> "%" <> id
|
-- (Comment s) -> "; " <> s <> "\n"
|
||||||
{- FOURMOLU_ENABLE -}
|
-- (Variable (Ident id)) -> "%" <> id
|
||||||
|
-- {- FOURMOLU_ENABLE -}
|
||||||
lblPfx :: String
|
--
|
||||||
lblPfx = "lbl_"
|
-- lblPfx :: String
|
||||||
|
-- lblPfx = "lbl_"
|
||||||
|
--
|
||||||
|
|
|
||||||
|
|
@ -1,235 +1,192 @@
|
||||||
{-# LANGUAGE LambdaCase #-}
|
--{-# LANGUAGE LambdaCase #-}
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
--{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
|
||||||
|
|
||||||
module LambdaLifter.LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where
|
module LambdaLifter.LambdaLifter where
|
||||||
|
|
||||||
import Auxiliary (snoc)
|
--import Auxiliary (snoc)
|
||||||
import Control.Applicative (Applicative (liftA2))
|
--import Control.Applicative (Applicative (liftA2))
|
||||||
import Control.Monad.State (MonadState (get, put), State,
|
--import Control.Monad.State (MonadState (get, put), State,
|
||||||
evalState)
|
-- evalState)
|
||||||
import Data.Set (Set)
|
--import Data.Set (Set)
|
||||||
import qualified Data.Set as Set
|
--import qualified Data.Set as Set
|
||||||
import Debug.Trace (trace)
|
--import Prelude hiding (exp)
|
||||||
import qualified Grammar.Abs as GA
|
--import Renamer.Renamer
|
||||||
import Prelude hiding (exp)
|
--import TypeChecker.TypeCheckerIr
|
||||||
import Renamer.Renamer
|
|
||||||
import TypeChecker.TypeCheckerIr
|
|
||||||
|
|
||||||
|
|
||||||
-- | Lift lambdas and let expression into supercombinators.
|
---- | Lift lambdas and let expression into supercombinators.
|
||||||
-- Three phases:
|
---- Three phases:
|
||||||
-- @freeVars@ annotatss all the free variables.
|
---- @freeVars@ annotatss all the free variables.
|
||||||
-- @abstract@ converts lambdas into let expressions.
|
---- @abstract@ converts lambdas into let expressions.
|
||||||
-- @collectScs@ moves every non-constant let expression to a top-level function.
|
---- @collectScs@ moves every non-constant let expression to a top-level function.
|
||||||
lambdaLift :: Program -> Program
|
--lambdaLift :: Program -> Program
|
||||||
lambdaLift = collectScs . abstract . freeVars
|
--lambdaLift = collectScs . abstract . freeVars
|
||||||
|
|
||||||
-- | Annotate free variables
|
|
||||||
freeVars :: Program -> AnnProgram
|
|
||||||
freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e)
|
|
||||||
| Bind n xs e <- ds
|
|
||||||
]
|
|
||||||
|
|
||||||
freeVarsExp :: Set Id -> Exp -> AnnExp
|
|
||||||
freeVarsExp localVars = \case
|
|
||||||
EId n | Set.member n localVars -> (Set.singleton n, AId n)
|
|
||||||
| otherwise -> (mempty, AId n)
|
|
||||||
|
|
||||||
EInt i -> (mempty, AInt i)
|
|
||||||
|
|
||||||
EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2')
|
|
||||||
where
|
|
||||||
e1' = freeVarsExp localVars e1
|
|
||||||
e2' = freeVarsExp localVars e2
|
|
||||||
|
|
||||||
EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2')
|
|
||||||
where
|
|
||||||
e1' = freeVarsExp localVars e1
|
|
||||||
e2' = freeVarsExp localVars e2
|
|
||||||
|
|
||||||
ESub t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), ASub t e1' e2')
|
|
||||||
where
|
|
||||||
e1' = freeVarsExp localVars e1
|
|
||||||
e2' = freeVarsExp localVars e2
|
|
||||||
|
|
||||||
|
|
||||||
EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e')
|
---- | Annotate free variables
|
||||||
where
|
--freeVars :: Program -> AnnProgram
|
||||||
e' = freeVarsExp (Set.insert par localVars) e
|
--freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e)
|
||||||
|
-- | Bind n xs e <- ds
|
||||||
|
-- ]
|
||||||
|
|
||||||
-- Sum free variables present in bind and the expression
|
--freeVarsExp :: Set Id -> Exp -> AnnExp
|
||||||
ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e')
|
--freeVarsExp localVars = \case
|
||||||
where
|
-- EId n | Set.member n localVars -> (Set.singleton n, AId n)
|
||||||
binders_frees = Set.delete name $ freeVarsOf rhs'
|
-- | otherwise -> (mempty, AId n)
|
||||||
e_free = Set.delete name $ freeVarsOf e'
|
|
||||||
|
|
||||||
rhs' = freeVarsExp e_localVars rhs
|
-- ELit _ (LInt i) -> (mempty, AInt i)
|
||||||
new_bind = ABind name parms rhs'
|
|
||||||
|
|
||||||
e' = freeVarsExp e_localVars e
|
-- EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2')
|
||||||
e_localVars = Set.insert name localVars
|
-- where
|
||||||
|
-- e1' = freeVarsExp localVars e1
|
||||||
|
-- e2' = freeVarsExp localVars e2
|
||||||
|
|
||||||
(ECase t e cs) -> do
|
-- EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2')
|
||||||
let e' = freeVarsExp localVars e
|
-- where
|
||||||
let vars = freeVarsOf e'
|
-- e1' = freeVarsExp localVars e1
|
||||||
let (vars', cs') = foldr (\(_, Case c e) (vars,acc) -> do
|
-- e2' = freeVarsExp localVars e2
|
||||||
let e' = freeVarsExp vars e
|
|
||||||
let vars' = freeVarsOf e'
|
-- EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e')
|
||||||
(Set.union vars vars', AnnCase c e' : acc)
|
-- where
|
||||||
) (vars, []) cs
|
-- e' = freeVarsExp (Set.insert par localVars) e
|
||||||
(vars', ACase t e' (reverse cs'))
|
|
||||||
|
-- -- Sum free variables present in bind and the expression
|
||||||
|
-- ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e')
|
||||||
|
-- where
|
||||||
|
-- binders_frees = Set.delete name $ freeVarsOf rhs'
|
||||||
|
-- e_free = Set.delete name $ freeVarsOf e'
|
||||||
|
|
||||||
|
-- rhs' = freeVarsExp e_localVars rhs
|
||||||
|
-- new_bind = ABind name parms rhs'
|
||||||
|
|
||||||
|
-- e' = freeVarsExp e_localVars e
|
||||||
|
-- e_localVars = Set.insert name localVars
|
||||||
|
|
||||||
|
|
||||||
freeVarsOf :: AnnExp -> Set Id
|
--freeVarsOf :: AnnExp -> Set Id
|
||||||
freeVarsOf = fst
|
--freeVarsOf = fst
|
||||||
|
|
||||||
-- AST annotated with free variables
|
---- AST annotated with free variables
|
||||||
type AnnProgram = [(Id, [Id], AnnExp)]
|
--type AnnProgram = [(Id, [Id], AnnExp)]
|
||||||
|
|
||||||
type AnnExp = (Set Id, AnnExp')
|
--type AnnExp = (Set Id, AnnExp')
|
||||||
|
|
||||||
data ABind = ABind Id [Id] AnnExp deriving Show
|
--data ABind = ABind Id [Id] AnnExp deriving Show
|
||||||
|
|
||||||
data AnnExp' = AId Id
|
--data AnnExp' = AId Id
|
||||||
| AInt Integer
|
-- | AInt Integer
|
||||||
| ALet ABind AnnExp
|
-- | ALet ABind AnnExp
|
||||||
| AApp Type AnnExp AnnExp
|
-- | AApp Type AnnExp AnnExp
|
||||||
| AAdd Type AnnExp AnnExp
|
-- | AAdd Type AnnExp AnnExp
|
||||||
| ASub Type AnnExp AnnExp
|
-- | AAbs Type Id AnnExp
|
||||||
| AAbs Type Id AnnExp
|
-- deriving Show
|
||||||
| ACase Type AnnExp [AnnCase]
|
---- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
|
||||||
deriving Show
|
---- Free variables are @v₁ v₂ .. vₙ@ are bound.
|
||||||
data AnnCase = AnnCase GA.Case AnnExp
|
--abstract :: AnnProgram -> Program
|
||||||
deriving Show
|
--abstract prog = Program $ evalState (mapM go prog) 0
|
||||||
|
-- where
|
||||||
|
-- go :: (Id, [Id], AnnExp) -> State Int Bind
|
||||||
|
-- go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs'
|
||||||
|
-- where
|
||||||
|
-- (rhs', parms1) = flattenLambdasAnn rhs
|
||||||
|
|
||||||
|
|
||||||
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
|
---- | Flatten nested lambdas and collect the parameters
|
||||||
-- Free variables are @v₁ v₂ .. vₙ@ are bound.
|
---- @\x.\y.\z. ae → (ae, [x,y,z])@
|
||||||
abstract :: AnnProgram -> Program
|
--flattenLambdasAnn :: AnnExp -> (AnnExp, [Id])
|
||||||
abstract prog = Program $ evalState (mapM go prog) 0
|
--flattenLambdasAnn ae = go (ae, [])
|
||||||
where
|
-- where
|
||||||
go :: (Id, [Id], AnnExp) -> State Int Bind
|
-- go :: (AnnExp, [Id]) -> (AnnExp, [Id])
|
||||||
go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs'
|
-- go ((free, e), acc) =
|
||||||
where
|
-- case e of
|
||||||
(rhs', parms1) = flattenLambdasAnn rhs
|
-- AAbs _ par (free1, e1) ->
|
||||||
|
-- go ((Set.delete par free1, e1), snoc par acc)
|
||||||
|
-- _ -> ((free, e), acc)
|
||||||
|
|
||||||
|
--abstractExp :: AnnExp -> State Int Exp
|
||||||
|
--abstractExp (free, exp) = case exp of
|
||||||
|
-- AId n -> pure $ EId n
|
||||||
|
-- AInt i -> pure $ ELit (TMono "Int") (LInt i)
|
||||||
|
-- AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2)
|
||||||
|
-- AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2)
|
||||||
|
-- ALet b e -> liftA2 ELet (go b) (abstractExp e)
|
||||||
|
-- where
|
||||||
|
-- go (ABind name parms rhs) = do
|
||||||
|
-- (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
|
||||||
|
-- pure $ Bind name (parms ++ parms1) rhs'
|
||||||
|
|
||||||
|
-- skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp
|
||||||
|
-- skipLambdas f (free, ae) = case ae of
|
||||||
|
-- AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1
|
||||||
|
-- _ -> f (free, ae)
|
||||||
|
|
||||||
|
-- -- Lift lambda into let and bind free variables
|
||||||
|
-- AAbs t parm e -> do
|
||||||
|
-- i <- nextNumber
|
||||||
|
-- rhs <- abstractExp e
|
||||||
|
|
||||||
|
-- let sc_name = Ident ("sc_" ++ show i)
|
||||||
|
-- sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t)
|
||||||
|
|
||||||
|
-- pure $ foldl (EApp $ TMono "Int") sc $ map EId freeList
|
||||||
|
-- where
|
||||||
|
-- freeList = Set.toList free
|
||||||
|
-- parms = snoc parm freeList
|
||||||
|
|
||||||
|
|
||||||
-- | Flatten nested lambdas and collect the parameters
|
--nextNumber :: State Int Int
|
||||||
-- @\x.\y.\z. ae → (ae, [x,y,z])@
|
--nextNumber = do
|
||||||
flattenLambdasAnn :: AnnExp -> (AnnExp, [Id])
|
-- i <- get
|
||||||
flattenLambdasAnn ae = go (ae, [])
|
-- put $ succ i
|
||||||
where
|
-- pure i
|
||||||
go :: (AnnExp, [Id]) -> (AnnExp, [Id])
|
|
||||||
go ((free, e), acc) =
|
|
||||||
case e of
|
|
||||||
AAbs _ par (free1, e1) ->
|
|
||||||
go ((Set.delete par free1, e1), snoc par acc)
|
|
||||||
_ -> ((free, e), acc)
|
|
||||||
|
|
||||||
abstractExp :: AnnExp -> State Int Exp
|
---- | Collects supercombinators by lifting non-constant let expressions
|
||||||
abstractExp (free, exp) = case exp of
|
--collectScs :: Program -> Program
|
||||||
AId n -> pure $ EId n
|
--collectScs (Program scs) = Program $ concatMap collectFromRhs scs
|
||||||
AInt i -> pure $ EInt i
|
-- where
|
||||||
AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2)
|
-- collectFromRhs (Bind name parms rhs) =
|
||||||
AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2)
|
-- let (rhs_scs, rhs') = collectScsExp rhs
|
||||||
ASub t e1 e2 -> liftA2 (ESub t) (abstractExp e1) (abstractExp e2)
|
-- in Bind name parms rhs' : rhs_scs
|
||||||
ALet b e -> liftA2 ELet (go b) (abstractExp e)
|
|
||||||
where
|
|
||||||
go (ABind name parms rhs) = do
|
|
||||||
(rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
|
|
||||||
pure $ Bind name (parms ++ parms1) rhs'
|
|
||||||
|
|
||||||
skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp
|
|
||||||
skipLambdas f (free, ae) = case ae of
|
|
||||||
AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1
|
|
||||||
_ -> f (free, ae)
|
|
||||||
|
|
||||||
|
|
||||||
ACase t e cs -> do
|
--collectScsExp :: Exp -> ([Bind], Exp)
|
||||||
e' <- abstractExp e
|
--collectScsExp = \case
|
||||||
cs' <- mapM (\(AnnCase c e) -> do
|
-- EId n -> ([], EId n)
|
||||||
e' <- abstractExp e
|
-- ELit _ (LInt i) -> ([], ELit (TMono "Int") (LInt i))
|
||||||
pure (t,Case c e')) cs
|
|
||||||
pure $ ECase t e' cs'
|
|
||||||
|
|
||||||
-- Lift lambda into let and bind free variables
|
-- EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2')
|
||||||
AAbs t parm e -> do
|
-- where
|
||||||
i <- nextNumber
|
-- (scs1, e1') = collectScsExp e1
|
||||||
rhs <- abstractExp e
|
-- (scs2, e2') = collectScsExp e2
|
||||||
|
|
||||||
let sc_name = Ident ("sc_" ++ show i)
|
-- EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2')
|
||||||
sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t)
|
-- where
|
||||||
|
-- (scs1, e1') = collectScsExp e1
|
||||||
|
-- (scs2, e2') = collectScsExp e2
|
||||||
|
|
||||||
pure $ foldl (EApp TInt) sc $ map EId freeList
|
-- EAbs t par e -> (scs, EAbs t par e')
|
||||||
where
|
-- where
|
||||||
freeList = Set.toList free
|
-- (scs, e') = collectScsExp e
|
||||||
parms = snoc parm freeList
|
|
||||||
|
-- -- Collect supercombinators from bind, the rhss, and the expression.
|
||||||
|
-- --
|
||||||
|
-- -- > f = let sc x y = rhs in e
|
||||||
|
-- --
|
||||||
|
-- ELet (Bind name parms rhs) e -> if null parms
|
||||||
|
-- then ( rhs_scs ++ e_scs, ELet bind e')
|
||||||
|
-- else (bind : rhs_scs ++ e_scs, e')
|
||||||
|
-- where
|
||||||
|
-- bind = Bind name parms rhs'
|
||||||
|
-- (rhs_scs, rhs') = collectScsExp rhs
|
||||||
|
-- (e_scs, e') = collectScsExp e
|
||||||
|
|
||||||
|
|
||||||
nextNumber :: State Int Int
|
---- @\x.\y.\z. e → (e, [x,y,z])@
|
||||||
nextNumber = do
|
--flattenLambdas :: Exp -> (Exp, [Id])
|
||||||
i <- get
|
--flattenLambdas = go . (, [])
|
||||||
put $ succ i
|
-- where
|
||||||
pure i
|
-- go (e, acc) = case e of
|
||||||
|
-- EAbs _ par e1 -> go (e1, snoc par acc)
|
||||||
|
-- _ -> (e, acc)
|
||||||
|
|
||||||
-- | Collects supercombinators by lifting non-constant let expressions
|
|
||||||
collectScs :: Program -> Program
|
|
||||||
collectScs (Program scs) = Program $ concatMap collectFromRhs scs
|
|
||||||
where
|
|
||||||
collectFromRhs (Bind name parms rhs) =
|
|
||||||
let (rhs_scs, rhs') = collectScsExp rhs
|
|
||||||
in Bind name parms rhs' : rhs_scs
|
|
||||||
|
|
||||||
|
|
||||||
collectScsExp :: Exp -> ([Bind], Exp)
|
|
||||||
collectScsExp = \case
|
|
||||||
EId n -> ([], EId n)
|
|
||||||
EInt i -> ([], EInt i)
|
|
||||||
|
|
||||||
EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2')
|
|
||||||
where
|
|
||||||
(scs1, e1') = collectScsExp e1
|
|
||||||
(scs2, e2') = collectScsExp e2
|
|
||||||
|
|
||||||
EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2')
|
|
||||||
where
|
|
||||||
(scs1, e1') = collectScsExp e1
|
|
||||||
(scs2, e2') = collectScsExp e2
|
|
||||||
|
|
||||||
ESub t e1 e2 -> (scs1 ++ scs2, ESub t e1' e2')
|
|
||||||
where
|
|
||||||
(scs1, e1') = collectScsExp e1
|
|
||||||
(scs2, e2') = collectScsExp e2
|
|
||||||
|
|
||||||
EAbs t par e -> (scs, EAbs t par e')
|
|
||||||
where
|
|
||||||
(scs, e') = collectScsExp e
|
|
||||||
|
|
||||||
-- Collect supercombinators from bind, the rhss, and the expression.
|
|
||||||
--
|
|
||||||
-- > f = let sc x y = rhs in e
|
|
||||||
--
|
|
||||||
ELet (Bind name parms rhs) e -> if null parms
|
|
||||||
then ( rhs_scs ++ e_scs, ELet bind e')
|
|
||||||
else (bind : rhs_scs ++ e_scs, e')
|
|
||||||
where
|
|
||||||
bind = Bind name parms rhs'
|
|
||||||
(rhs_scs, rhs') = collectScsExp rhs
|
|
||||||
(e_scs, e') = collectScsExp e
|
|
||||||
ECase t e cs -> do
|
|
||||||
let (scs, e') = collectScsExp e
|
|
||||||
let (scs',cs') = foldr (\(t, Case c e) (scs, acc) -> do
|
|
||||||
let (scs', e') = collectScsExp e
|
|
||||||
(scs ++ scs', (t,Case c e') : acc)
|
|
||||||
) (scs,[]) cs
|
|
||||||
(scs', ECase t e' cs')
|
|
||||||
|
|
||||||
|
|
||||||
-- @\x.\y.\z. e → (e, [x,y,z])@
|
|
||||||
flattenLambdas :: Exp -> (Exp, [Id])
|
|
||||||
flattenLambdas = go . (, [])
|
|
||||||
where
|
|
||||||
go (e, acc) = case e of
|
|
||||||
EAbs _ par e1 -> go (e1, snoc par acc)
|
|
||||||
_ -> (e, acc)
|
|
||||||
|
|
|
||||||
58
src/Main.hs
58
src/Main.hs
|
|
@ -2,26 +2,26 @@
|
||||||
|
|
||||||
module Main where
|
module Main where
|
||||||
|
|
||||||
import Codegen.Codegen (generateCode)
|
--import Codegen.Codegen (generateCode)
|
||||||
import GHC.IO.Handle.Text (hPutStrLn)
|
import GHC.IO.Handle.Text (hPutStrLn)
|
||||||
import Grammar.ErrM (Err)
|
import Grammar.ErrM (Err)
|
||||||
import Grammar.Par (myLexer, pProgram)
|
import Grammar.Par (myLexer, pProgram)
|
||||||
import Grammar.Print (printTree)
|
import Grammar.Print (printTree)
|
||||||
|
|
||||||
-- import Interpreter (interpret)
|
-- import Interpreter (interpret)
|
||||||
import Control.Monad (when)
|
import Control.Monad (when)
|
||||||
import Data.List.Extra (isSuffixOf)
|
import Data.List.Extra (isSuffixOf)
|
||||||
import LambdaLifter.LambdaLifter (lambdaLift)
|
--import LambdaLifter.LambdaLifter (lambdaLift)
|
||||||
import Renamer.Renamer (rename)
|
import Renamer.Renamer (rename)
|
||||||
import System.Directory (createDirectory, doesPathExist,
|
import System.Directory (createDirectory, doesPathExist,
|
||||||
getDirectoryContents,
|
getDirectoryContents,
|
||||||
removeDirectoryRecursive,
|
removeDirectoryRecursive,
|
||||||
setCurrentDirectory)
|
setCurrentDirectory)
|
||||||
import System.Environment (getArgs)
|
import System.Environment (getArgs)
|
||||||
import System.Exit (exitFailure, exitSuccess)
|
import System.Exit (exitFailure, exitSuccess)
|
||||||
import System.IO (stderr)
|
import System.IO (stderr)
|
||||||
import System.Process.Extra (spawnCommand, waitForProcess)
|
import System.Process.Extra (spawnCommand, waitForProcess)
|
||||||
import TypeChecker.TypeChecker (typecheck)
|
import TypeChecker.TypeChecker (typecheck)
|
||||||
|
|
||||||
main :: IO ()
|
main :: IO ()
|
||||||
main =
|
main =
|
||||||
|
|
@ -46,19 +46,19 @@ main' debug s = do
|
||||||
typechecked <- fromTypeCheckerErr $ typecheck renamed
|
typechecked <- fromTypeCheckerErr $ typecheck renamed
|
||||||
printToErr $ printTree typechecked
|
printToErr $ printTree typechecked
|
||||||
|
|
||||||
printToErr "\n-- Lambda Lifter --"
|
-- printToErr "\n-- Lambda Lifter --"
|
||||||
let lifted = lambdaLift typechecked
|
-- let lifted = lambdaLift typechecked
|
||||||
printToErr $ printTree lifted
|
-- printToErr $ printTree lifted
|
||||||
|
--
|
||||||
printToErr "\n -- Printing compiler output to stdout --"
|
-- printToErr "\n -- Printing compiler output to stdout --"
|
||||||
compiled <- fromCompilerErr $ generateCode lifted
|
-- compiled <- fromCompilerErr $ generateCode lifted
|
||||||
--putStrLn compiled
|
--putStrLn compiled
|
||||||
|
|
||||||
check <- doesPathExist "output"
|
-- check <- doesPathExist "output"
|
||||||
when check (removeDirectoryRecursive "output")
|
-- when check (removeDirectoryRecursive "output")
|
||||||
createDirectory "output"
|
-- createDirectory "output"
|
||||||
writeFile "output/llvm.ll" compiled
|
-- writeFile "output/llvm.ll" compiled
|
||||||
if debug then debugDotViz else putStrLn compiled
|
-- if debug then debugDotViz else putStrLn compiled
|
||||||
|
|
||||||
|
|
||||||
-- interpred <- fromInterpreterErr $ interpret lifted
|
-- interpred <- fromInterpreterErr $ interpret lifted
|
||||||
|
|
|
||||||
|
|
@ -1,86 +1,87 @@
|
||||||
{-# LANGUAGE LambdaCase #-}
|
{-# LANGUAGE LambdaCase #-}
|
||||||
|
|
||||||
module Renamer.Renamer (module Renamer.Renamer) where
|
module Renamer.Renamer where
|
||||||
|
|
||||||
import Auxiliary (mapAccumM)
|
import Auxiliary (mapAccumM)
|
||||||
import Control.Monad (foldM)
|
|
||||||
import Control.Monad.State (MonadState, State, evalState, gets,
|
import Control.Monad.State (MonadState, State, evalState, gets,
|
||||||
modify)
|
modify)
|
||||||
|
import Data.List (foldl')
|
||||||
import Data.Map (Map)
|
import Data.Map (Map)
|
||||||
import qualified Data.Map as Map
|
import qualified Data.Map as Map
|
||||||
import Data.Maybe (fromMaybe)
|
import Data.Maybe (fromMaybe)
|
||||||
import Data.Tuple.Extra (dupe)
|
import Data.Tuple.Extra (dupe)
|
||||||
import Grammar.Abs
|
import Grammar.Abs
|
||||||
|
|
||||||
|
|
||||||
-- | Rename all variables and local binds
|
-- | Rename all variables and local binds
|
||||||
rename :: Program -> Program
|
rename :: Program -> Program
|
||||||
rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0
|
rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0
|
||||||
where
|
where
|
||||||
initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs
|
-- initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs
|
||||||
renameSc :: Names -> Bind -> Rn Bind
|
initNames = Map.fromList $ foldl' saveIfBind [] bs
|
||||||
renameSc old_names (Bind name t _ parms rhs) = do
|
saveIfBind acc (DBind (Bind name _ _ _ _)) = dupe name : acc
|
||||||
|
saveIfBind acc _ = acc
|
||||||
|
renameSc :: Names -> Def -> Rn Def
|
||||||
|
renameSc old_names (DBind (Bind name t _ parms rhs)) = do
|
||||||
(new_names, parms') <- newNames old_names parms
|
(new_names, parms') <- newNames old_names parms
|
||||||
rhs' <- snd <$> renameExp new_names rhs
|
rhs' <- snd <$> renameExp new_names rhs
|
||||||
pure $ Bind name t name parms' rhs'
|
pure . DBind $ Bind name t name parms' rhs'
|
||||||
|
renameSc _ def = pure def
|
||||||
|
|
||||||
-- | Rename monad. State holds the number of renamed names.
|
-- | Rename monad. State holds the number of renamed names.
|
||||||
newtype Rn a = Rn { runRn :: State Int a }
|
newtype Rn a = Rn {runRn :: State Int a}
|
||||||
deriving (Functor, Applicative, Monad, MonadState Int)
|
deriving (Functor, Applicative, Monad, MonadState Int)
|
||||||
|
|
||||||
-- | Maps old to new name
|
-- | Maps old to new name
|
||||||
type Names = Map Ident Ident
|
type Names = Map Ident Ident
|
||||||
|
|
||||||
renameLocalBind :: Names -> Bind -> Rn (Names, Bind)
|
renameLocalBind :: Names -> Bind -> Rn (Names, Bind)
|
||||||
renameLocalBind old_names (Bind name t _ parms rhs) = do
|
renameLocalBind old_names (Bind name t _ parms rhs) = do
|
||||||
(new_names, name') <- newName old_names name
|
(new_names, name') <- newName old_names name
|
||||||
(new_names', parms') <- newNames new_names parms
|
(new_names', parms') <- newNames new_names parms
|
||||||
(new_names'', rhs') <- renameExp new_names' rhs
|
(new_names'', rhs') <- renameExp new_names' rhs
|
||||||
pure (new_names'', Bind name' t name' parms' rhs')
|
pure (new_names'', Bind name' t name' parms' rhs')
|
||||||
|
|
||||||
renameExp :: Names -> Exp -> Rn (Names, Exp)
|
renameExp :: Names -> Exp -> Rn (Names, Exp)
|
||||||
renameExp old_names = \case
|
renameExp old_names = \case
|
||||||
EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names)
|
EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names)
|
||||||
|
ELit (LInt i1) -> pure (old_names, ELit (LInt i1))
|
||||||
EInt i1 -> pure (old_names, EInt i1)
|
|
||||||
|
|
||||||
EApp e1 e2 -> do
|
EApp e1 e2 -> do
|
||||||
(env1, e1') <- renameExp old_names e1
|
(env1, e1') <- renameExp old_names e1
|
||||||
(env2, e2') <- renameExp old_names e2
|
(env2, e2') <- renameExp old_names e2
|
||||||
pure (Map.union env1 env2, EApp e1' e2')
|
pure (Map.union env1 env2, EApp e1' e2')
|
||||||
|
|
||||||
EAdd e1 e2 -> do
|
EAdd e1 e2 -> do
|
||||||
(env1, e1') <- renameExp old_names e1
|
(env1, e1') <- renameExp old_names e1
|
||||||
(env2, e2') <- renameExp old_names e2
|
(env2, e2') <- renameExp old_names e2
|
||||||
pure (Map.union env1 env2, EAdd e1' e2')
|
pure (Map.union env1 env2, EAdd e1' e2')
|
||||||
|
|
||||||
ESub e1 e2 -> do
|
ESub e1 e2 -> do
|
||||||
(env1, e1') <- renameExp old_names e1
|
(env1, e1') <- renameExp old_names e1
|
||||||
(env2, e2') <- renameExp old_names e2
|
(env2, e2') <- renameExp old_names e2
|
||||||
pure (Map.union env1 env2, ESub e1' e2')
|
pure (Map.union env1 env2, ESub e1' e2')
|
||||||
|
ELet i e1 e2 -> do
|
||||||
ELet b e -> do
|
(new_names, e1') <- renameExp old_names e1
|
||||||
(new_names, b) <- renameLocalBind old_names b
|
(new_names', e2') <- renameExp new_names e2
|
||||||
(new_names', e') <- renameExp new_names e
|
pure (new_names', ELet i e1' e2')
|
||||||
pure (new_names', ELet b e')
|
EAbs par e -> do
|
||||||
|
|
||||||
EAbs par t e -> do
|
|
||||||
(new_names, par') <- newName old_names par
|
(new_names, par') <- newName old_names par
|
||||||
(new_names', e') <- renameExp new_names e
|
(new_names', e') <- renameExp new_names e
|
||||||
pure (new_names', EAbs par' t e')
|
pure (new_names', EAbs par' e')
|
||||||
|
|
||||||
EAnn e t -> do
|
EAnn e t -> do
|
||||||
(new_names, e') <- renameExp old_names e
|
(new_names, e') <- renameExp old_names e
|
||||||
pure (new_names, EAnn e' t)
|
pure (new_names, EAnn e' t)
|
||||||
|
ECase e injs -> do
|
||||||
|
(_, e') <- renameExp old_names e
|
||||||
|
(new_names, injs') <- renameInjs old_names injs
|
||||||
|
pure (new_names, ECase e' injs')
|
||||||
|
|
||||||
ECase e cs t -> do
|
renameInjs :: Names -> [Inj] -> Rn (Names, [Inj])
|
||||||
(new_names, e') <- renameExp old_names e
|
renameInjs ns xs = do
|
||||||
(new_names', cs') <- foldM (\(names, stack) (CaseMatch c exp) -> do
|
(new_names, xs') <- unzip <$> mapM (renameInj ns) xs
|
||||||
(nm,exp') <- renameExp names exp
|
if null new_names then return (mempty, xs') else return (head new_names, xs')
|
||||||
pure (nm,CaseMatch c exp' : stack)
|
|
||||||
) (new_names, []) cs
|
renameInj :: Names -> Inj -> Rn (Names, Inj)
|
||||||
pure (new_names', ECase e' cs' t)
|
renameInj ns (Inj init e) = do
|
||||||
|
(new_names, e') <- renameExp ns e
|
||||||
|
return (new_names, Inj init e')
|
||||||
|
|
||||||
-- | Create a new name and add it to name environment.
|
-- | Create a new name and add it to name environment.
|
||||||
newName :: Names -> Ident -> Rn (Names, Ident)
|
newName :: Names -> Ident -> Rn (Names, Ident)
|
||||||
|
|
@ -95,4 +96,3 @@ newNames = mapAccumM newName
|
||||||
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
|
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
|
||||||
makeName :: Ident -> Rn Ident
|
makeName :: Ident -> Rn Ident
|
||||||
makeName (Ident prefix) = gets (\i -> Ident $ prefix ++ "_" ++ show i) <* modify succ
|
makeName (Ident prefix) = gets (\i -> Ident $ prefix ++ "_" ++ show i) <* modify succ
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,215 +1,517 @@
|
||||||
{-# LANGUAGE LambdaCase #-}
|
{-# LANGUAGE LambdaCase #-}
|
||||||
{-# LANGUAGE OverloadedRecordDot #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
|
||||||
module TypeChecker.TypeChecker (typecheck, partitionType) where
|
-- | A module for type checking and inference using algorithm W, Hindley-Milner
|
||||||
|
module TypeChecker.TypeChecker where
|
||||||
|
|
||||||
import Auxiliary (maybeToRightM, snoc)
|
import Control.Monad.Except
|
||||||
import Control.Monad.Except (throwError, unless)
|
import Control.Monad.Reader
|
||||||
|
import Control.Monad.State
|
||||||
|
import Data.Foldable (traverse_)
|
||||||
|
import Data.Functor.Identity (runIdentity)
|
||||||
|
import Data.List (foldl')
|
||||||
import Data.Map (Map)
|
import Data.Map (Map)
|
||||||
import qualified Data.Map as Map
|
import qualified Data.Map as M
|
||||||
|
import Data.Set (Set)
|
||||||
|
import qualified Data.Set as S
|
||||||
|
import Debug.Trace (trace)
|
||||||
import Grammar.Abs
|
import Grammar.Abs
|
||||||
import Grammar.ErrM (Err)
|
import Grammar.Print (printTree)
|
||||||
import Grammar.Print (Print (prt), concatD, doc,
|
|
||||||
printTree, render)
|
|
||||||
import Prelude hiding (exp, id)
|
|
||||||
import qualified TypeChecker.TypeCheckerIr as T
|
import qualified TypeChecker.TypeCheckerIr as T
|
||||||
|
import TypeChecker.TypeCheckerIr (Ctx (..), Env (..), Error, Infer,
|
||||||
|
Poly (..), Subst)
|
||||||
|
|
||||||
-- NOTE: this type checker is poorly tested
|
initCtx = Ctx mempty
|
||||||
|
|
||||||
-- TODO
|
initEnv = Env 0 mempty mempty
|
||||||
-- Coercion
|
|
||||||
-- Type inference
|
|
||||||
|
|
||||||
data Cxt = Cxt
|
runPretty :: Exp -> Either Error String
|
||||||
{ env :: Map Ident Type -- ^ Local scope signature
|
runPretty = fmap (printTree . fst) . run . inferExp
|
||||||
, sig :: Map Ident Type -- ^ Top-level signatures
|
|
||||||
}
|
|
||||||
|
|
||||||
initCxt :: [Bind] -> Cxt
|
run :: Infer a -> Either Error a
|
||||||
initCxt sc = Cxt { env = mempty
|
run = runC initEnv initCtx
|
||||||
, sig = Map.fromList $ map (\(Bind n t _ _ _) -> (n, t)) sc
|
|
||||||
}
|
|
||||||
|
|
||||||
typecheck :: Program -> Err T.Program
|
runC :: Env -> Ctx -> Infer a -> Either Error a
|
||||||
typecheck (Program sc) = T.Program <$> mapM (checkBind $ initCxt sc) sc
|
runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e
|
||||||
|
|
||||||
-- | Check if infered rhs type matches type signature.
|
typecheck :: Program -> Either Error T.Program
|
||||||
checkBind :: Cxt -> Bind -> Err T.Bind
|
typecheck = run . checkPrg
|
||||||
checkBind cxt b =
|
|
||||||
case expandLambdas b of
|
|
||||||
Bind name t _ parms rhs -> do
|
|
||||||
(rhs', t_rhs) <- infer cxt rhs
|
|
||||||
unless (typeEq t_rhs t) . throwError $ typeErr name t t_rhs
|
|
||||||
pure $ T.Bind (name, t) (zip parms ts_parms) rhs'
|
|
||||||
where
|
|
||||||
ts_parms = fst $ partitionType (length parms) t
|
|
||||||
|
|
||||||
-- | @ f x y = rhs ⇒ f = \x.\y. rhs @
|
{- | Start by freshening the type variable of data types to avoid clash with
|
||||||
expandLambdas :: Bind -> Bind
|
other user defined polymorphic types
|
||||||
expandLambdas (Bind name t _ parms rhs) = Bind name t name [] rhs'
|
This might be wrong for type constructors that work over several variables
|
||||||
|
-}
|
||||||
|
freshenData :: Data -> Infer Data
|
||||||
|
freshenData (Data (Constr name ts) constrs) = do
|
||||||
|
fr <- fresh
|
||||||
|
let fr' = case fr of
|
||||||
|
TPol a -> a
|
||||||
|
-- Meh, this part assumes fresh generates a polymorphic type
|
||||||
|
_ ->
|
||||||
|
error
|
||||||
|
"Bug: implementation of \
|
||||||
|
\ fresh and freshenData are not compatible"
|
||||||
|
let new_ts = map (freshenType fr') ts
|
||||||
|
let new_constrs = map (freshenConstr fr') constrs
|
||||||
|
return $ Data (Constr name new_ts) new_constrs
|
||||||
|
|
||||||
|
{- | Freshen all polymorphic variables, regardless of name
|
||||||
|
| freshenType "d" (a -> b -> c) becomes (d -> d -> d)
|
||||||
|
-}
|
||||||
|
freshenType :: Ident -> Type -> Type
|
||||||
|
freshenType iden = \case
|
||||||
|
(TPol _) -> TPol iden
|
||||||
|
(TArr a b) -> TArr (freshenType iden a) (freshenType iden b)
|
||||||
|
(TConstr (Constr a ts)) ->
|
||||||
|
TConstr (Constr a (map (freshenType iden) ts))
|
||||||
|
rest -> rest
|
||||||
|
|
||||||
|
freshenConstr :: Ident -> Constructor -> Constructor
|
||||||
|
freshenConstr iden (Constructor name t) =
|
||||||
|
Constructor name (freshenType iden t)
|
||||||
|
|
||||||
|
checkData :: Data -> Infer ()
|
||||||
|
checkData d = do
|
||||||
|
d' <- freshenData d
|
||||||
|
case d' of
|
||||||
|
(Data typ@(Constr name ts) constrs) -> do
|
||||||
|
unless
|
||||||
|
(all isPoly ts)
|
||||||
|
(throwError $ unwords ["Data type incorrectly declared"])
|
||||||
|
traverse_
|
||||||
|
( \(Constructor name' t') ->
|
||||||
|
if TConstr typ == retType t'
|
||||||
|
then insertConstr name' t'
|
||||||
|
else
|
||||||
|
throwError $
|
||||||
|
unwords
|
||||||
|
[ "return type of constructor:"
|
||||||
|
, printTree name
|
||||||
|
, "with type:"
|
||||||
|
, printTree (retType t')
|
||||||
|
, "does not match data: "
|
||||||
|
, printTree typ
|
||||||
|
]
|
||||||
|
)
|
||||||
|
constrs
|
||||||
|
|
||||||
|
retType :: Type -> Type
|
||||||
|
retType (TArr _ t2) = retType t2
|
||||||
|
retType a = a
|
||||||
|
|
||||||
|
checkPrg :: Program -> Infer T.Program
|
||||||
|
checkPrg (Program bs) = do
|
||||||
|
preRun bs
|
||||||
|
T.Program <$> checkDef bs
|
||||||
where
|
where
|
||||||
rhs' = foldr ($) rhs $ zipWith EAbs parms ts_parms
|
preRun :: [Def] -> Infer ()
|
||||||
ts_parms = fst $ partitionType (length parms) t
|
preRun [] = return ()
|
||||||
|
preRun (x : xs) = case x of
|
||||||
|
DBind (Bind n t _ _ _) -> insertSig n t >> preRun xs
|
||||||
|
DData d@(Data _ _) -> checkData d >> preRun xs
|
||||||
|
|
||||||
-- | Infer type of expression.
|
checkDef :: [Def] -> Infer [T.Def]
|
||||||
infer :: Cxt -> Exp -> Err (T.Exp, Type)
|
checkDef [] = return []
|
||||||
infer cxt = \case
|
checkDef (x : xs) = case x of
|
||||||
EId x ->
|
(DBind b) -> do
|
||||||
case lookupEnv x cxt of
|
b' <- checkBind b
|
||||||
Nothing ->
|
fmap (T.DBind b' :) (checkDef xs)
|
||||||
case lookupSig x cxt of
|
(DData d) -> fmap (T.DData d :) (checkDef xs)
|
||||||
Nothing -> throwError ("Unbound variable:" ++ printTree x)
|
|
||||||
Just t -> pure (T.EId (x, t), t)
|
|
||||||
Just t -> pure (T.EId (x, t), t)
|
|
||||||
|
|
||||||
EInt i -> pure (T.EInt i, T.TInt)
|
checkBind :: Bind -> Infer T.Bind
|
||||||
|
checkBind (Bind n t _ args e) = do
|
||||||
|
(t', e') <- inferExp $ makeLambda e (reverse args)
|
||||||
|
s <- unify t t'
|
||||||
|
let t'' = apply s t
|
||||||
|
unless
|
||||||
|
(t `typeEq` t'')
|
||||||
|
( throwError $
|
||||||
|
unwords
|
||||||
|
[ "Top level signature"
|
||||||
|
, printTree t
|
||||||
|
, "does not match body with inferred type:"
|
||||||
|
, printTree t''
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return $ T.Bind (n, t) e'
|
||||||
|
where
|
||||||
|
makeLambda :: Exp -> [Ident] -> Exp
|
||||||
|
makeLambda = foldl (flip EAbs)
|
||||||
|
|
||||||
EApp e e1 -> do
|
{- | Check if two types are considered equal
|
||||||
(e', t) <- infer cxt e
|
For the purpose of the algorithm two polymorphic types are always considered
|
||||||
case t of
|
equal
|
||||||
TFun t1 t2 -> do
|
-}
|
||||||
e1' <- check cxt e1 t1
|
|
||||||
pure (T.EApp t2 e' e1', t2)
|
|
||||||
_ -> do
|
|
||||||
throwError ("Not a function: " ++ show e)
|
|
||||||
|
|
||||||
EAdd e e1 -> do
|
|
||||||
e' <- check cxt e T.TInt
|
|
||||||
e1' <- check cxt e1 T.TInt
|
|
||||||
pure (T.EAdd T.TInt e' e1', T.TInt)
|
|
||||||
|
|
||||||
ESub e e1 -> do
|
|
||||||
e' <- check cxt e T.TInt
|
|
||||||
e1' <- check cxt e1 T.TInt
|
|
||||||
pure (T.ESub T.TInt e' e1', T.TInt)
|
|
||||||
|
|
||||||
EAbs x t e -> do
|
|
||||||
(e', t1) <- infer (insertEnv x t cxt) e
|
|
||||||
let t_abs = TFun t t1
|
|
||||||
pure (T.EAbs t_abs (x, t) e', t_abs)
|
|
||||||
|
|
||||||
ELet b e -> do
|
|
||||||
let cxt' = insertBind b cxt
|
|
||||||
b' <- checkBind cxt' b
|
|
||||||
(e', t) <- infer cxt' e
|
|
||||||
pure (T.ELet b' e', t)
|
|
||||||
|
|
||||||
EAnn e t -> do
|
|
||||||
(e', t1) <- infer cxt e
|
|
||||||
unless (typeEq t t1) $
|
|
||||||
throwError "Inferred type and type annotation doesn't match"
|
|
||||||
pure (e', t1)
|
|
||||||
|
|
||||||
ECase e cs t -> do
|
|
||||||
(e',t1) <- infer cxt e
|
|
||||||
unless (typeEq t t1) $
|
|
||||||
throwError "Inferred type and type annotation doesn't match"
|
|
||||||
case traverse (\(CaseMatch c e) -> do
|
|
||||||
-- //TODO check c as well
|
|
||||||
e' <- check cxt e t
|
|
||||||
unless (typeEq t t1) $
|
|
||||||
throwError "Inferred type and type annotation doesn't match"
|
|
||||||
pure (t1, T.Case c e')
|
|
||||||
) cs of
|
|
||||||
Right cs -> pure (T.ECase t1 e' cs,t1)
|
|
||||||
Left e -> throwError e
|
|
||||||
|
|
||||||
-- | Check infered type matches the supplied type.
|
|
||||||
check :: Cxt -> Exp -> Type -> Err T.Exp
|
|
||||||
check cxt exp typ = case exp of
|
|
||||||
EId x -> do
|
|
||||||
t <- case lookupEnv x cxt of
|
|
||||||
Nothing -> maybeToRightM
|
|
||||||
("Unbound variable:" ++ printTree x)
|
|
||||||
(lookupSig x cxt)
|
|
||||||
Just t -> pure t
|
|
||||||
unless (typeEq t typ) . throwError $ typeErr x typ t
|
|
||||||
pure $ T.EId (x, t)
|
|
||||||
|
|
||||||
EInt i -> do
|
|
||||||
unless (typeEq typ TInt) $ throwError $ typeErr i TInt typ
|
|
||||||
pure $ T.EInt i
|
|
||||||
|
|
||||||
EApp e e1 -> do
|
|
||||||
(e', t) <- infer cxt e
|
|
||||||
case t of
|
|
||||||
TFun t1 t2 -> do
|
|
||||||
e1' <- check cxt e1 t1
|
|
||||||
pure $ T.EApp t2 e' e1'
|
|
||||||
_ -> throwError ("Not a function 2: " ++ printTree e)
|
|
||||||
|
|
||||||
EAdd e e1 -> do
|
|
||||||
e' <- check cxt e T.TInt
|
|
||||||
e1' <- check cxt e1 T.TInt
|
|
||||||
pure $ T.EAdd T.TInt e' e1'
|
|
||||||
|
|
||||||
ESub e e1 -> do
|
|
||||||
e' <- check cxt e T.TInt
|
|
||||||
e1' <- check cxt e1 T.TInt
|
|
||||||
pure $ T.ESub T.TInt e' e1'
|
|
||||||
|
|
||||||
EAbs x t e -> do
|
|
||||||
(e', t_e) <- infer (insertEnv x t cxt) e
|
|
||||||
let t1 = TFun t t_e
|
|
||||||
unless (typeEq t1 typ) $ throwError "Wrong lamda type!"
|
|
||||||
pure $ T.EAbs t1 (x, t) e'
|
|
||||||
|
|
||||||
ECase e cs t -> do
|
|
||||||
(e',t1) <- infer cxt e
|
|
||||||
unless (typeEq t t1) $
|
|
||||||
throwError "Inferred type and type annotation doesn't match"
|
|
||||||
case traverse (\(CaseMatch c e) -> do
|
|
||||||
-- //TODO check c as well
|
|
||||||
e' <- check cxt e t
|
|
||||||
unless (typeEq t t1) $
|
|
||||||
throwError "Inferred type and type annotation doesn't match"
|
|
||||||
pure (t1, T.Case c e')
|
|
||||||
) cs of
|
|
||||||
Right cs -> pure $ T.ECase t1 e' cs
|
|
||||||
Left e -> throwError e
|
|
||||||
|
|
||||||
ELet b e -> do
|
|
||||||
let cxt' = insertBind b cxt
|
|
||||||
b' <- checkBind cxt' b
|
|
||||||
e' <- check cxt' e typ
|
|
||||||
pure $ T.ELet b' e'
|
|
||||||
|
|
||||||
EAnn e t -> do
|
|
||||||
unless (typeEq t typ) $
|
|
||||||
throwError "Inferred type and type annotation doesn't match"
|
|
||||||
check cxt e t
|
|
||||||
|
|
||||||
-- | Check if types are equivalent. Doesn't handle coercion or polymorphism.
|
|
||||||
typeEq :: Type -> Type -> Bool
|
typeEq :: Type -> Type -> Bool
|
||||||
typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1
|
typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r'
|
||||||
typeEq t t1 = t == t1
|
typeEq (TMono a) (TMono b) = a == b
|
||||||
|
typeEq (TConstr (Constr name a)) (TConstr (Constr name' b)) =
|
||||||
|
length a == length b
|
||||||
|
&& name == name'
|
||||||
|
&& and (zipWith typeEq a b)
|
||||||
|
typeEq (TPol _) (TPol _) = True
|
||||||
|
typeEq _ _ = False
|
||||||
|
|
||||||
-- | Partion type into types of parameters and return type.
|
isMoreSpecificOrEq :: Type -> Type -> Bool
|
||||||
partitionType :: Int -- Number of parameters to apply
|
isMoreSpecificOrEq _ (TPol _) = True
|
||||||
-> Type
|
isMoreSpecificOrEq (TArr a b) (TArr c d) =
|
||||||
-> ([Type], Type)
|
isMoreSpecificOrEq a c && isMoreSpecificOrEq b d
|
||||||
partitionType = go []
|
isMoreSpecificOrEq (TConstr (Constr n1 ts1)) (TConstr (Constr n2 ts2)) =
|
||||||
where
|
n1 == n2
|
||||||
go acc 0 t = (acc, t)
|
&& length ts1 == length ts2
|
||||||
go acc i t = case t of
|
&& and (zipWith isMoreSpecificOrEq ts1 ts2)
|
||||||
TFun t1 t2 -> go (snoc t1 acc) (i - 1) t2
|
isMoreSpecificOrEq a b = a == b
|
||||||
_ -> error "Number of parameters and type doesn't match"
|
|
||||||
|
|
||||||
insertBind :: Bind -> Cxt -> Cxt
|
isPoly :: Type -> Bool
|
||||||
insertBind (Bind n t _ _ _) = insertEnv n t
|
isPoly (TPol _) = True
|
||||||
|
isPoly _ = False
|
||||||
|
|
||||||
lookupEnv :: Ident -> Cxt -> Maybe Type
|
inferExp :: Exp -> Infer (Type, T.Exp)
|
||||||
lookupEnv x = Map.lookup x . env
|
inferExp e = do
|
||||||
|
(s, t, e') <- algoW e
|
||||||
|
let subbed = apply s t
|
||||||
|
return (subbed, replace subbed e')
|
||||||
|
|
||||||
insertEnv :: Ident -> Type -> Cxt -> Cxt
|
replace :: Type -> T.Exp -> T.Exp
|
||||||
insertEnv x t cxt = cxt { env = Map.insert x t cxt.env }
|
replace t = \case
|
||||||
|
T.ELit _ e -> T.ELit t e
|
||||||
|
T.EId (n, _) -> T.EId (n, t)
|
||||||
|
T.EAbs _ name e -> T.EAbs t name e
|
||||||
|
T.EApp _ e1 e2 -> T.EApp t e1 e2
|
||||||
|
T.EAdd _ e1 e2 -> T.EAdd t e1 e2
|
||||||
|
T.ESub _ e1 e2 -> T.ESub t e1 e2
|
||||||
|
T.ELet (T.Bind (n, _) e1) e2 -> T.ELet (T.Bind (n, t) e1) e2
|
||||||
|
T.ECase _ expr injs -> T.ECase t expr injs
|
||||||
|
|
||||||
lookupSig :: Ident -> Cxt -> Maybe Type
|
algoW :: Exp -> Infer (Subst, Type, T.Exp)
|
||||||
lookupSig x = Map.lookup x . sig
|
algoW = \case
|
||||||
|
-- \| TODO: More testing need to be done. Unsure of the correctness of this
|
||||||
|
EAnn e t -> do
|
||||||
|
(s1, t', e') <- algoW e
|
||||||
|
unless
|
||||||
|
(t `isMoreSpecificOrEq` t')
|
||||||
|
( throwError $
|
||||||
|
unwords
|
||||||
|
[ "Annotated type:"
|
||||||
|
, printTree t
|
||||||
|
, "does not match inferred type:"
|
||||||
|
, printTree t'
|
||||||
|
]
|
||||||
|
)
|
||||||
|
applySt s1 $ do
|
||||||
|
s2 <- unify t t'
|
||||||
|
return (s2 `compose` s1, t, e')
|
||||||
|
|
||||||
typeErr :: Print a => a -> Type -> Type -> String
|
-- \| ------------------
|
||||||
typeErr p expected actual = render $ concatD
|
-- \| Γ ⊢ i : Int, ∅
|
||||||
[ doc $ showString "Wrong type:", prt 0 p , doc $ showString "\n"
|
|
||||||
, doc $ showString "Expected:" , prt 0 expected, doc $ showString "\n"
|
ELit (LInt n) ->
|
||||||
, doc $ showString "Actual: " , prt 0 actual
|
return (nullSubst, TMono "Int", T.ELit (TMono "Int") (LInt n))
|
||||||
]
|
ELit a -> error $ "NOT IMPLEMENTED YET: ELit " ++ show a
|
||||||
|
-- \| x : σ ∈ Γ τ = inst(σ)
|
||||||
|
-- \| ----------------------
|
||||||
|
-- \| Γ ⊢ x : τ, ∅
|
||||||
|
|
||||||
|
EId i -> do
|
||||||
|
var <- asks vars
|
||||||
|
case M.lookup i var of
|
||||||
|
Just t -> inst t >>= \x -> return (nullSubst, x, T.EId (i, x))
|
||||||
|
Nothing -> do
|
||||||
|
sig <- gets sigs
|
||||||
|
case M.lookup i sig of
|
||||||
|
Just t -> return (nullSubst, t, T.EId (i, t))
|
||||||
|
Nothing -> do
|
||||||
|
constr <- gets constructors
|
||||||
|
case M.lookup i constr of
|
||||||
|
Just t -> return (nullSubst, t, T.EId (i, t))
|
||||||
|
Nothing ->
|
||||||
|
throwError $
|
||||||
|
"Unbound variable: " ++ show i
|
||||||
|
|
||||||
|
-- \| τ = newvar Γ, x : τ ⊢ e : τ', S
|
||||||
|
-- \| ---------------------------------
|
||||||
|
-- \| Γ ⊢ w λx. e : Sτ → τ', S
|
||||||
|
|
||||||
|
EAbs name e -> do
|
||||||
|
fr <- fresh
|
||||||
|
withBinding name (Forall [] fr) $ do
|
||||||
|
(s1, t', e') <- algoW e
|
||||||
|
let varType = apply s1 fr
|
||||||
|
let newArr = TArr varType t'
|
||||||
|
return (s1, newArr, T.EAbs newArr (name, varType) e')
|
||||||
|
|
||||||
|
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
|
||||||
|
-- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
|
||||||
|
-- \| ------------------------------------------
|
||||||
|
-- \| Γ ⊢ e₀ + e₁ : Int, S₃S₂S₁S₀
|
||||||
|
-- This might be wrong
|
||||||
|
|
||||||
|
EAdd e0 e1 -> do
|
||||||
|
(s1, t0, e0') <- algoW e0
|
||||||
|
applySt s1 $ do
|
||||||
|
(s2, t1, e1') <- algoW e1
|
||||||
|
-- applySt s2 $ do
|
||||||
|
s3 <- unify (apply s2 t0) (TMono "Int")
|
||||||
|
s4 <- unify (apply s3 t1) (TMono "Int")
|
||||||
|
return
|
||||||
|
( s4 `compose` s3 `compose` s2 `compose` s1
|
||||||
|
, TMono "Int"
|
||||||
|
, T.EAdd (TMono "Int") e0' e1'
|
||||||
|
)
|
||||||
|
|
||||||
|
ESub e0 e1 -> do
|
||||||
|
(s1, t0, e0') <- algoW e0
|
||||||
|
applySt s1 $ do
|
||||||
|
(s2, t1, e1') <- algoW e1
|
||||||
|
-- applySt s2 $ do
|
||||||
|
s3 <- unify (apply s2 t0) (TMono "Int")
|
||||||
|
s4 <- unify (apply s3 t1) (TMono "Int")
|
||||||
|
return
|
||||||
|
( s4 `compose` s3 `compose` s2 `compose` s1
|
||||||
|
, TMono "Int"
|
||||||
|
, T.ESub (TMono "Int") e0' e1'
|
||||||
|
)
|
||||||
|
|
||||||
|
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1
|
||||||
|
-- \| τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ')
|
||||||
|
-- \| --------------------------------------
|
||||||
|
-- \| Γ ⊢ e₀ e₁ : S₂τ', S₂S₁S₀
|
||||||
|
|
||||||
|
EApp e0 e1 -> do
|
||||||
|
fr <- fresh
|
||||||
|
(s0, t0, e0') <- algoW e0
|
||||||
|
applySt s0 $ do
|
||||||
|
(s1, t1, e1') <- algoW e1
|
||||||
|
-- applySt s1 $ do
|
||||||
|
s2 <- unify (apply s1 t0) (TArr t1 fr)
|
||||||
|
let t = apply s2 fr
|
||||||
|
return (s2 `compose` s1 `compose` s0, t, T.EApp t e0' e1')
|
||||||
|
|
||||||
|
-- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁
|
||||||
|
-- \| ----------------------------------------------
|
||||||
|
-- \| Γ ⊢ let x = e₀ in e₁ : τ', S₁S₀
|
||||||
|
|
||||||
|
-- The bar over S₀ and Γ means "generalize"
|
||||||
|
|
||||||
|
ELet name e0 e1 -> do
|
||||||
|
(s1, t1, e0') <- algoW e0
|
||||||
|
env <- asks vars
|
||||||
|
let t' = generalize (apply s1 env) t1
|
||||||
|
withBinding name t' $ do
|
||||||
|
(s2, t2, e1') <- algoW e1
|
||||||
|
return (s2 `compose` s1, t2, T.ELet (T.Bind (name, t2) e0') e1')
|
||||||
|
ECase caseExpr injs -> do
|
||||||
|
(_, t0, e0') <- algoW caseExpr
|
||||||
|
(injs', ts) <- mapAndUnzipM (checkInj t0) injs
|
||||||
|
case ts of
|
||||||
|
[] -> throwError "Case expression missing any matches"
|
||||||
|
ts -> do
|
||||||
|
unified <- zipWithM unify ts (tail ts)
|
||||||
|
let unified' = foldl' compose mempty unified
|
||||||
|
let typ = apply unified' (head ts)
|
||||||
|
return (unified', typ, T.ECase typ e0' injs')
|
||||||
|
|
||||||
|
-- | Unify two types producing a new substitution
|
||||||
|
unify :: Type -> Type -> Infer Subst
|
||||||
|
unify t0 t1 = do
|
||||||
|
trace ("t0: " ++ show t0) return ()
|
||||||
|
trace ("t1: " ++ show t1) return ()
|
||||||
|
case (t0, t1) of
|
||||||
|
(TArr a b, TArr c d) -> do
|
||||||
|
s1 <- unify a c
|
||||||
|
s2 <- unify (apply s1 b) (apply s1 d)
|
||||||
|
return $ s1 `compose` s2
|
||||||
|
(TPol a, b) -> occurs a b
|
||||||
|
(a, TPol b) -> occurs b a
|
||||||
|
(TMono a, TMono b) ->
|
||||||
|
if a == b then return M.empty else throwError "Types do not unify"
|
||||||
|
-- \| TODO: Figure out a cleaner way to express the same thing
|
||||||
|
(TConstr (Constr name t), TConstr (Constr name' t')) ->
|
||||||
|
if name == name' && length t == length t'
|
||||||
|
then do
|
||||||
|
xs <- zipWithM unify t t'
|
||||||
|
return $ foldr compose nullSubst xs
|
||||||
|
else
|
||||||
|
throwError $
|
||||||
|
unwords
|
||||||
|
[ "Type constructor:"
|
||||||
|
, printTree name
|
||||||
|
, "(" ++ printTree t ++ ")"
|
||||||
|
, "does not match with:"
|
||||||
|
, printTree name'
|
||||||
|
, "(" ++ printTree t' ++ ")"
|
||||||
|
]
|
||||||
|
(a, b) ->
|
||||||
|
throwError . unwords $
|
||||||
|
[ "Type:"
|
||||||
|
, printTree a
|
||||||
|
, "can't be unified with:"
|
||||||
|
, printTree b
|
||||||
|
]
|
||||||
|
|
||||||
|
{- | Check if a type is contained in another type.
|
||||||
|
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
|
||||||
|
such that these are equal
|
||||||
|
-}
|
||||||
|
occurs :: Ident -> Type -> Infer Subst
|
||||||
|
occurs _ (TPol _) = return nullSubst
|
||||||
|
occurs i t =
|
||||||
|
if S.member i (free t)
|
||||||
|
then
|
||||||
|
throwError $
|
||||||
|
unwords
|
||||||
|
[ "Occurs check failed, can't unify"
|
||||||
|
, printTree (TPol i)
|
||||||
|
, "with"
|
||||||
|
, printTree t
|
||||||
|
]
|
||||||
|
else return $ M.singleton i t
|
||||||
|
|
||||||
|
-- | Generalize a type over all free variables in the substitution set
|
||||||
|
generalize :: Map Ident Poly -> Type -> Poly
|
||||||
|
generalize env t = Forall (S.toList $ free t S.\\ free env) t
|
||||||
|
|
||||||
|
{- | Instantiate a polymorphic type. The free type variables are substituted
|
||||||
|
with fresh ones.
|
||||||
|
-}
|
||||||
|
inst :: Poly -> Infer Type
|
||||||
|
inst (Forall xs t) = do
|
||||||
|
xs' <- mapM (const fresh) xs
|
||||||
|
let s = M.fromList $ zip xs xs'
|
||||||
|
return $ apply s t
|
||||||
|
|
||||||
|
-- | Compose two substitution sets
|
||||||
|
compose :: Subst -> Subst -> Subst
|
||||||
|
compose m1 m2 = M.map (apply m1) m2 `M.union` m1
|
||||||
|
|
||||||
|
-- | A class representing free variables functions
|
||||||
|
class FreeVars t where
|
||||||
|
-- | Get all free variables from t
|
||||||
|
free :: t -> Set Ident
|
||||||
|
|
||||||
|
-- | Apply a substitution to t
|
||||||
|
apply :: Subst -> t -> t
|
||||||
|
|
||||||
|
instance FreeVars Type where
|
||||||
|
free :: Type -> Set Ident
|
||||||
|
free (TPol a) = S.singleton a
|
||||||
|
free (TMono _) = mempty
|
||||||
|
free (TArr a b) = free a `S.union` free b
|
||||||
|
-- \| Not guaranteed to be correct
|
||||||
|
free (TConstr (Constr _ a)) =
|
||||||
|
foldl' (\acc x -> free x `S.union` acc) S.empty a
|
||||||
|
|
||||||
|
apply :: Subst -> Type -> Type
|
||||||
|
apply sub t = do
|
||||||
|
case t of
|
||||||
|
TMono a -> TMono a
|
||||||
|
TPol a -> case M.lookup a sub of
|
||||||
|
Nothing -> TPol a
|
||||||
|
Just t -> t
|
||||||
|
TArr a b -> TArr (apply sub a) (apply sub b)
|
||||||
|
TConstr (Constr name a) -> TConstr (Constr name (map (apply sub) a))
|
||||||
|
|
||||||
|
instance FreeVars Poly where
|
||||||
|
free :: Poly -> Set Ident
|
||||||
|
free (Forall xs t) = free t S.\\ S.fromList xs
|
||||||
|
apply :: Subst -> Poly -> Poly
|
||||||
|
apply s (Forall xs t) = Forall xs (apply (foldr M.delete s xs) t)
|
||||||
|
|
||||||
|
instance FreeVars (Map Ident Poly) where
|
||||||
|
free :: Map Ident Poly -> Set Ident
|
||||||
|
free m = foldl' S.union S.empty (map free $ M.elems m)
|
||||||
|
apply :: Subst -> Map Ident Poly -> Map Ident Poly
|
||||||
|
apply s = M.map (apply s)
|
||||||
|
|
||||||
|
-- | Apply substitutions to the environment.
|
||||||
|
applySt :: Subst -> Infer a -> Infer a
|
||||||
|
applySt s = local (\st -> st{vars = apply s (vars st)})
|
||||||
|
|
||||||
|
-- | Represents the empty substition set
|
||||||
|
nullSubst :: Subst
|
||||||
|
nullSubst = M.empty
|
||||||
|
|
||||||
|
-- | Generate a new fresh variable and increment the state counter
|
||||||
|
fresh :: Infer Type
|
||||||
|
fresh = do
|
||||||
|
n <- gets count
|
||||||
|
modify (\st -> st{count = n + 1})
|
||||||
|
return . TPol . Ident $ show n
|
||||||
|
|
||||||
|
-- | Run the monadic action with an additional binding
|
||||||
|
withBinding :: (Monad m, MonadReader Ctx m) => Ident -> Poly -> m a -> m a
|
||||||
|
withBinding i p = local (\st -> st{vars = M.insert i p (vars st)})
|
||||||
|
|
||||||
|
-- | Insert a function signature into the environment
|
||||||
|
insertSig :: Ident -> Type -> Infer ()
|
||||||
|
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
|
||||||
|
|
||||||
|
-- | Insert a constructor with its data type
|
||||||
|
insertConstr :: Ident -> Type -> Infer ()
|
||||||
|
insertConstr i t =
|
||||||
|
modify (\st -> st{constructors = M.insert i t (constructors st)})
|
||||||
|
|
||||||
|
-------- PATTERN MATCHING ---------
|
||||||
|
|
||||||
|
-- "case expr of", the type of 'expr' is caseType
|
||||||
|
checkInj :: Type -> Inj -> Infer (T.Inj, Type)
|
||||||
|
checkInj caseType (Inj it expr) = do
|
||||||
|
(args, t') <- initType caseType it
|
||||||
|
(_, t, e') <- local (\st -> st{vars = args `M.union` vars st}) (algoW expr)
|
||||||
|
return (T.Inj (it, t') e', t)
|
||||||
|
|
||||||
|
initType :: Type -> Init -> Infer (Map Ident Poly, Type)
|
||||||
|
initType expected = \case
|
||||||
|
InitLit lit ->
|
||||||
|
let returnType = litType lit
|
||||||
|
in if expected == returnType
|
||||||
|
then return (mempty, expected)
|
||||||
|
else
|
||||||
|
throwError $
|
||||||
|
unwords
|
||||||
|
[ "Inferred type"
|
||||||
|
, printTree returnType
|
||||||
|
, "does not match expected type:"
|
||||||
|
, printTree expected
|
||||||
|
]
|
||||||
|
InitConstr c args -> do
|
||||||
|
st <- gets constructors
|
||||||
|
case M.lookup c st of
|
||||||
|
Nothing ->
|
||||||
|
throwError $
|
||||||
|
unwords
|
||||||
|
[ "Constructor:"
|
||||||
|
, printTree c
|
||||||
|
, "does not exist"
|
||||||
|
]
|
||||||
|
Just t -> do
|
||||||
|
let flat = flattenType t
|
||||||
|
let returnType = last flat
|
||||||
|
case ( length (init flat) == length args
|
||||||
|
, returnType `isMoreSpecificOrEq` expected
|
||||||
|
) of
|
||||||
|
(True, True) ->
|
||||||
|
return
|
||||||
|
( M.fromList $ zip args (map (Forall []) flat)
|
||||||
|
, expected
|
||||||
|
)
|
||||||
|
(False, _) ->
|
||||||
|
throwError $
|
||||||
|
"Can't partially match on the constructor: "
|
||||||
|
++ printTree c
|
||||||
|
(_, False) ->
|
||||||
|
throwError $
|
||||||
|
unwords
|
||||||
|
[ "Inferred type"
|
||||||
|
, printTree returnType
|
||||||
|
, "does not match expected type:"
|
||||||
|
, printTree expected
|
||||||
|
]
|
||||||
|
InitCatch -> return (mempty, expected)
|
||||||
|
|
||||||
|
flattenType :: Type -> [Type]
|
||||||
|
flattenType (TArr a b) = flattenType a ++ flattenType b
|
||||||
|
flattenType a = [a]
|
||||||
|
|
||||||
|
litType :: Literal -> Type
|
||||||
|
litType (LInt _) = TMono "Int"
|
||||||
|
|
|
||||||
|
|
@ -1,139 +1,184 @@
|
||||||
{-# LANGUAGE LambdaCase #-}
|
{-# LANGUAGE LambdaCase #-}
|
||||||
|
|
||||||
module TypeChecker.TypeCheckerIr
|
module TypeChecker.TypeCheckerIr where
|
||||||
( module Grammar.Abs
|
|
||||||
, module TypeChecker.TypeCheckerIr
|
|
||||||
) where
|
|
||||||
|
|
||||||
import Grammar.Abs (Ident (..), Type (..))
|
import Control.Monad.Except
|
||||||
import qualified Grammar.Abs as GA
|
import Control.Monad.Reader
|
||||||
|
import Control.Monad.State
|
||||||
|
import Data.Functor.Identity (Identity)
|
||||||
|
import Data.Map (Map)
|
||||||
|
import Grammar.Abs (Data (..), Ident (..), Init (..),
|
||||||
|
Literal (..), Type (..))
|
||||||
import Grammar.Print
|
import Grammar.Print
|
||||||
import Prelude
|
import Prelude
|
||||||
import qualified Prelude as C (Eq, Ord, Read, Show)
|
import qualified Prelude as C (Eq, Ord, Read, Show)
|
||||||
|
|
||||||
newtype Program = Program [Bind]
|
-- | A data type representing type variables
|
||||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
data Poly = Forall [Ident] Type
|
||||||
|
deriving (Show)
|
||||||
|
|
||||||
|
newtype Ctx = Ctx {vars :: Map Ident Poly}
|
||||||
|
|
||||||
|
data Env = Env
|
||||||
|
{ count :: Int
|
||||||
|
, sigs :: Map Ident Type
|
||||||
|
, constructors :: Map Ident Type
|
||||||
|
}
|
||||||
|
|
||||||
|
type Error = String
|
||||||
|
type Subst = Map Ident Type
|
||||||
|
|
||||||
|
type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))
|
||||||
|
|
||||||
|
newtype Program = Program [Def]
|
||||||
|
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||||
|
|
||||||
data Exp
|
data Exp
|
||||||
= EId Id
|
= EId Id
|
||||||
| EInt Integer
|
| ELit Type Literal
|
||||||
| ELet Bind Exp
|
| ELet Bind Exp
|
||||||
| EApp Type Exp Exp
|
| EApp Type Exp Exp
|
||||||
| EAdd Type Exp Exp
|
| EAdd Type Exp Exp
|
||||||
| ESub Type Exp Exp
|
| ESub Type Exp Exp
|
||||||
| EAbs Type Id Exp
|
| EAbs Type Id Exp
|
||||||
| ECase Type Exp [(Type, Case)]
|
| ECase Type Exp [Inj]
|
||||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
deriving (C.Eq, C.Ord, C.Read, C.Show)
|
||||||
|
|
||||||
data Case = Case GA.Case Exp
|
data Inj = Inj (Init, Type) Exp
|
||||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
deriving (C.Eq, C.Ord, C.Read, C.Show)
|
||||||
|
|
||||||
|
data Def = DBind Bind | DData Data
|
||||||
|
deriving (C.Eq, C.Ord, C.Read, C.Show)
|
||||||
|
|
||||||
type Id = (Ident, Type)
|
type Id = (Ident, Type)
|
||||||
|
|
||||||
data Bind = Bind Id [Id] Exp | DataStructure Ident [(Ident, [Type])]
|
data Bind = Bind Id Exp
|
||||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||||
|
|
||||||
|
instance Print [Def] where
|
||||||
|
prt _ [] = concatD []
|
||||||
|
prt _ (x : xs) = concatD [prt 0 x, doc (showString "\n"), prt 0 xs]
|
||||||
|
|
||||||
|
instance Print Def where
|
||||||
|
prt i (DBind bind) = prt i bind
|
||||||
|
prt i (DData d) = prt i d
|
||||||
|
|
||||||
instance Print Program where
|
instance Print Program where
|
||||||
prt i (Program sc) = prPrec i 0 $ prt 0 sc
|
prt i (Program sc) = prPrec i 0 $ prt 0 sc
|
||||||
|
|
||||||
instance Print Bind where
|
instance Print Bind where
|
||||||
prt i (Bind name@(n, _) parms rhs) = prPrec i 0 $ concatD
|
prt i (Bind (t, name) rhs) =
|
||||||
[ prtId 0 name
|
prPrec i 0 $
|
||||||
, doc $ showString ";"
|
concatD
|
||||||
, prt 0 n
|
[ prt 0 name
|
||||||
, prtIdPs 0 parms
|
, doc $ showString ":"
|
||||||
, doc $ showString "="
|
, prt 0 t
|
||||||
, prt 0 rhs
|
, doc $ showString "\n"
|
||||||
]
|
, prt 0 name
|
||||||
prt i (DataStructure (Ident n) xs) = prPrec i 0 $ concatD
|
, doc $ showString "="
|
||||||
[ prt 0 n
|
, prt 0 rhs
|
||||||
, doc $ showString "{"
|
]
|
||||||
, doc . showString . show $ xs
|
|
||||||
, doc $ showString "}"
|
|
||||||
]
|
|
||||||
|
|
||||||
instance Print [Bind] where
|
instance Print [Bind] where
|
||||||
prt _ [] = concatD []
|
prt _ [] = concatD []
|
||||||
prt _ [x] = concatD [prt 0 x]
|
prt _ [x] = concatD [prt 0 x]
|
||||||
prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
|
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), doc (showString "\n"), prt 0 xs]
|
||||||
|
|
||||||
prtIdPs :: Int -> [Id] -> Doc
|
prtIdPs :: Int -> [Id] -> Doc
|
||||||
prtIdPs i = prPrec i 0 . concatD . map (prtIdP i)
|
prtIdPs i = prPrec i 0 . concatD . map (prtIdP i)
|
||||||
|
|
||||||
prtId :: Int -> Id -> Doc
|
prtId :: Int -> Id -> Doc
|
||||||
prtId i (name, t) = prPrec i 0 $ concatD
|
prtId i (name, t) =
|
||||||
[ prt 0 name
|
prPrec i 0 $
|
||||||
, doc $ showString ":"
|
concatD
|
||||||
, prt 0 t
|
[ prt 0 name
|
||||||
]
|
, doc $ showString ":"
|
||||||
|
, prt 0 t
|
||||||
|
]
|
||||||
|
|
||||||
prtIdP :: Int -> Id -> Doc
|
prtIdP :: Int -> Id -> Doc
|
||||||
prtIdP i (name, t) = prPrec i 0 $ concatD
|
prtIdP i (name, t) =
|
||||||
[ doc $ showString "("
|
prPrec i 0 $
|
||||||
, prt 0 name
|
concatD
|
||||||
, doc $ showString ":"
|
[ doc $ showString "("
|
||||||
, prt 0 t
|
, prt 0 name
|
||||||
, doc $ showString ")"
|
, doc $ showString ":"
|
||||||
]
|
, prt 0 t
|
||||||
|
, doc $ showString ")"
|
||||||
|
]
|
||||||
|
|
||||||
instance Print Exp where
|
instance Print Exp where
|
||||||
prt i = \case
|
prt i = \case
|
||||||
EId n -> prPrec i 3 $ concatD [prtIdP 0 n]
|
EId n -> prPrec i 3 $ concatD [prtId 0 n, doc $ showString "\n"]
|
||||||
EInt i1 -> prPrec i 3 $ concatD [prt 0 i1]
|
ELit _ (LInt i1) -> prPrec i 3 $ concatD [prt 0 i1, doc $ showString "\n"]
|
||||||
ELet bs e -> prPrec i 3 $ concatD
|
ELet bs e ->
|
||||||
[ doc $ showString "let"
|
prPrec i 3 $
|
||||||
, prt 0 bs
|
concatD
|
||||||
, doc $ showString "in"
|
[ doc $ showString "let"
|
||||||
, prt 0 e
|
, prt 0 bs
|
||||||
]
|
, doc $ showString "in"
|
||||||
EApp t e1 e2 -> prPrec i 2 $ concatD
|
, prt 0 e
|
||||||
[ doc $ showString "@"
|
, doc $ showString "\n"
|
||||||
, prt 0 t
|
]
|
||||||
, prt 2 e1
|
EApp _ e1 e2 ->
|
||||||
, prt 3 e2
|
prPrec i 2 $
|
||||||
]
|
concatD
|
||||||
EAdd t e1 e2 -> prPrec i 1 $ concatD
|
[ prt 2 e1
|
||||||
[ doc $ showString "@"
|
, prt 3 e2
|
||||||
, prt 0 t
|
]
|
||||||
, prt 1 e1
|
EAdd t e1 e2 ->
|
||||||
, doc $ showString "+"
|
prPrec i 1 $
|
||||||
, prt 2 e2
|
concatD
|
||||||
]
|
[ doc $ showString "@"
|
||||||
ESub t e1 e2 -> prPrec i 1 $ concatD
|
, prt 0 t
|
||||||
[ doc $ showString "@"
|
, prt 1 e1
|
||||||
, prt 0 t
|
, doc $ showString "+"
|
||||||
, prt 1 e1
|
, prt 2 e2
|
||||||
, doc $ showString "-"
|
, doc $ showString "\n"
|
||||||
, prt 2 e2
|
]
|
||||||
]
|
ESub t e1 e2 ->
|
||||||
EAbs t n e -> prPrec i 0 $ concatD
|
prPrec i 1 $
|
||||||
[ doc $ showString "@"
|
concatD
|
||||||
, prt 0 t
|
[ doc $ showString "@"
|
||||||
, doc $ showString "\\"
|
, prt 0 t
|
||||||
, prtIdP 0 n
|
, prt 1 e1
|
||||||
, doc $ showString "."
|
, doc $ showString "-"
|
||||||
, prt 0 e
|
, prt 2 e2
|
||||||
]
|
, doc $ showString "\n"
|
||||||
ECase t e cs -> prPrec i 0 $ concatD
|
]
|
||||||
[ doc $ showString "@"
|
EAbs t n e ->
|
||||||
, prt 0 t
|
prPrec i 0 $
|
||||||
, doc $ showString "("
|
concatD
|
||||||
, prt 0 e
|
[ doc $ showString "@"
|
||||||
, doc $ showString ")"
|
, prt 0 t
|
||||||
, prPrec i 0 $ concatD . printCases $ cs
|
, doc $ showString "\\"
|
||||||
]
|
, prtId 0 n
|
||||||
|
, doc $ showString "."
|
||||||
|
, prt 0 e
|
||||||
|
, doc $ showString "\n"
|
||||||
|
]
|
||||||
|
ECase t exp injs ->
|
||||||
|
prPrec
|
||||||
|
i
|
||||||
|
0
|
||||||
|
( concatD
|
||||||
|
[ doc (showString "case")
|
||||||
|
, prt 0 exp
|
||||||
|
, doc (showString "of")
|
||||||
|
, doc (showString "{")
|
||||||
|
, prt 0 injs
|
||||||
|
, doc (showString "}")
|
||||||
|
, doc (showString ":")
|
||||||
|
, prt 0 t
|
||||||
|
, doc $ showString "\n"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
where
|
instance Print Inj where
|
||||||
printCases :: [(Type, Case)] -> [Doc]
|
prt i = \case
|
||||||
printCases [] = []
|
Inj (init, t) exp -> prPrec i 0 (concatD [prt 0 init, doc (showString ":"), prt 0 t, doc (showString "=>"), prt 0 exp])
|
||||||
printCases ((t, Case c e):xs) = concatD
|
|
||||||
[ doc $ showString "@"
|
instance Print [Inj] where
|
||||||
, prt 0 t
|
prt _ [] = concatD []
|
||||||
, doc $ showString "("
|
prt _ [x] = concatD [prt 0 x]
|
||||||
, doc . showString . show $ c
|
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
|
||||||
, doc $ showString ")"
|
|
||||||
, doc $ showString "=>"
|
|
||||||
, prt 0 e
|
|
||||||
, doc $ showString "\n"
|
|
||||||
] : printCases xs
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue