Merge llvm_testing, and use TypeCheckerIr instead of Abs

This commit is contained in:
Martin Fredin 2023-02-16 02:17:07 +01:00
commit 7ef7090aa5
21 changed files with 499 additions and 101 deletions

3
.gitignore vendored
View file

@ -3,4 +3,5 @@ dist-newstyle
*.x
*.bak
src/Grammar
/language
language
llvm.ll

View file

@ -1,5 +1,3 @@
Program. Program ::= [Bind];
EId. Exp3 ::= Ident;
@ -24,5 +22,4 @@ TFun. Type ::= Type1 "->" Type ;
coercions Type 1 ;
comment "--";
comment "{-" "-}";
comment "{-" "-}";

2
cabal.project.local Normal file
View file

@ -0,0 +1,2 @@
ignore-project: False
tests: True

View file

@ -1,4 +1,4 @@
cabal-version: 3.0
cabal-version: 3.4
name: language
@ -12,18 +12,19 @@ build-type: Simple
extra-doc-files: CHANGELOG.md
extra-source-files:
Grammar.cf
common warnings
ghc-options: -Wall
ghc-options: -W
executable language
import: warnings
main-is: Main.hs
other-modules:
Grammar.Abs
Grammar.Lex
@ -33,11 +34,12 @@ executable language
Grammar.ErrM
LambdaLifter
Auxiliary
-- Interpreter
Renamer
TypeChecker
TypeCheckerIr
-- Interpreter
Compiler
LlvmIr
hs-source-dirs: src
build-depends:
@ -47,5 +49,4 @@ executable language
, either
, array
, extra
default-language: GHC2021

View file

@ -1,3 +1,6 @@
f : Int -> Int;
f = \x:Int. x+1;
tripplemagic : Int -> Int -> Int -> Int;
tripplemagic x y z = ((\x:Int. x+x) x) + y + z;
main : Int;
main = tripplemagic ((\x:Int. x+x+3) ((\x:Int. x) 2)) 5 3

View file

@ -1,3 +0,0 @@
main : Int -> Int -> Int;
main x y = (x : Int) + y;

View file

@ -1,7 +0,0 @@
add : Int -> Int -> Int;
add x = \y:Int. x+y;
main : Int;
main = (\z:Int. z+z) ((add 4) 6);

View file

@ -1,3 +0,0 @@
main : Int;
main = (\x:Int. x+x+3) ((\x:Int. x) 2);

View file

@ -1,7 +0,0 @@
f : Int -> Int;
f x = let
g : Int -> Int;
g = (\y:Int. y+1);
in
g (g x);

View file

@ -1,14 +0,0 @@
id : Int -> Int;
id x = x;
add : Int -> Int -> Int;
add x y = x + y;
double : Int -> Int;
double n = n + n;
apply : (Int -> Int -> Int) -> Int -> Int -> Int;
apply f x = \y:Int. f x y;
main : Int;
main = apply add ((\x:Int. x + 1) 1) (double (id 3));

View file

@ -1,4 +0,0 @@
f : Int -> Int -> Int;
f = \x:Int.\y:Int. x+y;

View file

@ -1,8 +0,0 @@
add : Int -> Int -> Int;
add x y = x + y;
apply : (Int -> Int) -> Int -> Int;
apply f x = f x;
main : Int;
main = apply (add 4) 6;

View file

@ -1,7 +0,0 @@
f : Int -> Int;
f x = let
double : Int -> Int;
double = \y:Int. y+y
in
double (x + 4);

View file

@ -1,5 +0,0 @@
main : Int;
main = (\f:Int -> Int.\x:Int.\y:Int. f x + f y) (\x:Int. x+x) ((\x:Int. x+1) ((\x:Int. x+3) 2)) 4

View file

@ -1,5 +1,5 @@
let
pkgs = import (fetchTarball https://github.com/NixOS/nixpkgs/archive/8c619a1f3cedd16ea172146e30645e703d21bfc1.tar.gz) { }; # pin the channel to ensure reproducibility!
pkgs = import (fetchTarball "https://github.com/NixOS/nixpkgs/archive/747927516efcb5e31ba03b7ff32f61f6d47e7d87.zip") { }; # pin the channel to ensure reproducibility!
in
pkgs.haskellPackages.developPackage {
root = ./.;

View file

@ -0,0 +1,259 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Compiler (compile) where
import Control.Monad.State (StateT, execStateT, gets, modify)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Tuple.Extra (second)
import Grammar.ErrM (Err)
import Grammar.Print (printTree)
import LlvmIr (LLVMIr (..), LLVMType (..),
LLVMValue (..), llvmIrToString)
import TypeChecker (partitionType)
import TypeCheckerIr
-- | The record used as the code generator state
data CodeGenerator = CodeGenerator
{ instructions :: [LLVMIr]
, functions :: Map Id FunctionInfo
, variableCount :: Integer
}
-- | A state type synonym
type CompilerState a = StateT CodeGenerator Err a
data FunctionInfo = FunctionInfo
{ numArgs :: Int
, arguments :: [Id]
}
-- | Adds a instruction to the CodeGenerator state
emit :: LLVMIr -> CompilerState ()
emit l = modify (\t -> t{instructions = instructions t ++ [l]})
-- | Increases the variable counter in the CodeGenerator state
increaseVarCount :: CompilerState ()
increaseVarCount = modify (\t -> t{variableCount = variableCount t + 1})
-- | Returns the variable count from the CodeGenerator state
getVarCount :: CompilerState Integer
getVarCount = gets variableCount
-- | Increases the variable count and returns it from the CodeGenerator state
getNewVar :: CompilerState Integer
getNewVar = increaseVarCount >> getVarCount
{- | Produces a map of functions infos from a list of binds,
which contains useful data for code generation.
-}
getFunctions :: [Bind] -> Map Id FunctionInfo
getFunctions xs =
Map.fromList $
map
( \(Bind id args _) ->
( id
, FunctionInfo
{ numArgs = length args
, arguments = args
}
)
)
xs
{- | Compiles an AST and produces a LLVM Ir string.
An easy way to actually "compile" this output is to
Simply pipe it to LLI
-}
compile :: Program -> Err String
compile (Program prg) = do
let s =
CodeGenerator
{ instructions = defaultStart
, functions = getFunctions prg
, variableCount = 0
}
ins <- instructions <$> execStateT (goDef prg) s
pure $ llvmIrToString ins
where
mainContent :: LLVMValue -> [LLVMIr]
mainContent var =
[ UnsafeRaw $
"call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n"
, -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
-- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2")
-- , Label (Ident "b_1")
-- , UnsafeRaw
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n"
-- , Br (Ident "end")
-- , Label (Ident "b_2")
-- , UnsafeRaw
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n"
-- , Br (Ident "end")
-- , Label (Ident "end")
Ret I64 (VInteger 0)
]
defaultStart :: [LLVMIr]
defaultStart =
[ Comment (show $ printTree (Program prg))
, UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
]
goDef :: [Bind] -> CompilerState ()
goDef [] = return ()
goDef (Bind id@(name, t) args exp : xs) = do
emit $ UnsafeRaw "\n"
emit $ Comment $ show name <> ": " <> show exp
emit $ Define (type2LlvmType t_return) name (map (second type2LlvmType) args)
functionBody <- exprToValue exp
if name == "main"
then mapM_ emit (mainContent functionBody)
else emit $ Ret I64 functionBody
emit DefineEnd
modify (\s -> s{variableCount = 0})
goDef xs
where
t_return = snd $ partitionType (length args) t
go :: Exp -> CompilerState ()
go (EInt int) = emitInt int
go (EAdd t e1 e2) = emitAdd e1 e2
-- go (ESub e1 e2) = emitSub e1 e2
-- go (EMul e1 e2) = emitMul e1 e2
-- go (EDiv e1 e2) = emitDiv e1 e2
-- go (EMod e1 e2) = emitMod e1 e2
go (EId (name, _)) = emitIdent name
go (EApp t e1 e2) = emitApp e1 e2
--- aux functions ---
emitAbs :: Ident -> Exp -> CompilerState ()
emitAbs id e = do
emit $
Comment $
concat
[ "EAbs ("
, show id
, ", "
, show I64
, ", "
, show e
, ") is not implemented!"
]
emitLet :: [Bind] -> Exp -> CompilerState ()
emitLet xs e = do
emit $
Comment $
concat
[ "ELet ("
, show xs
, " = "
, show e
, ") is not implemented!"
]
emitApp :: Exp -> Exp -> CompilerState ()
emitApp e1 e2 = appEmitter e1 e2 []
where
appEmitter :: Exp -> Exp -> [Exp] -> CompilerState ()
appEmitter e1 e2 stack = do
let newStack = e2 : stack
case e1 of
EApp t e1' e2' -> appEmitter e1' e2' newStack
EId (name, _) -> do
args <- traverse exprToValue newStack
vs <- getNewVar
emit $ SetVariable (Ident $ show vs) (Call I64 name (map (I64,) args))
x -> do
emit . Comment $ "The unspeakable happened: "
emit . Comment $ show x
emitIdent :: Ident -> CompilerState ()
emitIdent id = do
-- !!this should never happen!!
emit $ Comment "This should not have happened!"
emit $ Variable id
emit $ UnsafeRaw "\n"
emitInt :: Integer -> CompilerState ()
emitInt i = do
-- !!this should never happen!!
varCount <- getNewVar
emit $ Comment "This should not have happened!"
emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0))
emitAdd :: Exp -> Exp -> CompilerState ()
emitAdd e1 e2 = do
v1 <- exprToValue e1
v2 <- exprToValue e2
v <- getNewVar
emit $ SetVariable (Ident $ show v) (Add I64 v1 v2)
-- emitMul :: Exp -> Exp -> CompilerState ()
-- emitMul e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Mul I64 v1 v2
-- emitMod :: Exp -> Exp -> CompilerState ()
-- emitMod e1 e2 = do
-- -- `let m a b = rem (abs $ b + a) b`
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- vadd <- gets variableCount
-- emit $ SetVariable $ Ident $ show vadd
-- emit $ Add I64 v1 v2
--
-- increaseVarCount
-- vabs <- gets variableCount
-- emit $ SetVariable $ Ident $ show vabs
-- emit $ Call I64 (Ident "llvm.abs.i64")
-- [ (I64, VIdent (Ident $ show vadd))
-- , (I1, VInteger 1)
-- ]
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2
-- emitDiv :: Exp -> Exp -> CompilerState ()
-- emitDiv e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Div I64 v1 v2
-- emitSub :: Exp -> Exp -> CompilerState ()
-- emitSub e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Sub I64 v1 v2
exprToValue :: Exp -> CompilerState LLVMValue
exprToValue (EInt i) = return $ VInteger i
exprToValue (EId id@(name, t)) = do
funcs <- gets functions
case Map.lookup id funcs of
Just _ -> do
vc <- getNewVar
emit $ SetVariable (Ident $ show vc) (Call (type2LlvmType t) name [])
return $ VIdent (Ident $ show vc, t)
Nothing -> return $ VIdent id
exprToValue e = do
go e
v <- getVarCount
return $ VIdent (Ident $ show v, TInt)
type2LlvmType :: Type -> LLVMType
type2LlvmType = \case
TInt -> I64
t -> error $ "missing type case: " ++ show t

View file

@ -35,7 +35,6 @@ initCxt scs =
expandLambdas :: Bind -> Bind
expandLambdas (Bind name parms rhs) = Bind name [] $ foldr EAbs rhs parms
findMain :: [Bind] -> Err Exp
findMain [] = throwError "No main!"
findMain (sc:scs) = case sc of

View file

@ -101,15 +101,9 @@ abstract :: AnnProgram -> Program
abstract prog = Program $ evalState (mapM go prog) 0
where
go :: (Id, [Id], AnnExp) -> State Int Bind
go (name, parms, rhs@(_, e)) =
case e of
AAbs _ parm e1 -> do
e2' <- abstractExp e2
pure $ Bind name (snoc parm parms ++ parms2) e2'
where
(e2, parms2) = flattenLambdasAnn e1
_ -> Bind name parms <$> abstractExp rhs
go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs'
where
(rhs', parms1) = flattenLambdasAnn rhs
-- | Flatten nested lambdas and collect the parameters
@ -147,12 +141,11 @@ abstractExp (free, exp) = case exp of
rhs <- abstractExp e
let sc_name = Ident ("sc_" ++ show i)
sc = ELet [Bind (sc_name, t_bind) parms rhs] $ EId (sc_name, t)
sc = ELet [Bind (sc_name, t) parms rhs] $ EId (sc_name, t)
pure $ foldl (EApp TInt) sc $ map EId freeList
where
freeList = Set.toList free
t_bind = typeApplyPars (length parm) t
parms = snoc parm freeList
AAnn e t -> abstractExp e >>= \e' -> pure $ EAnn e' t
@ -163,15 +156,6 @@ nextNumber = do
put $ succ i
pure i
typeApplyPars :: Int -> Type -> Type
typeApplyPars 0 t = t
typeApplyPars i t =
case t of
TFun _ t1 -> typeApplyPars (i-1) t1
_ -> error "Number of applied pars and type not matching"
-- | Collects supercombinators by lifting appropriate let expressions
collectScs :: Program -> Program
collectScs (Program scs) = Program $ concatMap collectFromRhs scs

192
src/LlvmIr.hs Normal file
View file

@ -0,0 +1,192 @@
{-# LANGUAGE LambdaCase #-}
module LlvmIr (LLVMType (..), LLVMIr (..), llvmIrToString, LLVMValue (..), LLVMComp (..)) where
import Data.List (intercalate)
import TypeCheckerIr
-- | A datatype which represents some basic LLVM types
data LLVMType
= I1
| I8
| I32
| I64
| Ptr
| Ref LLVMType
| Array Integer LLVMType
| CustomType Ident
instance Show LLVMType where
show :: LLVMType -> String
show = \case
I1 -> "i1"
I8 -> "i8"
I32 -> "i32"
I64 -> "i64"
Ptr -> "ptr"
Ref ty -> show ty <> "*"
Array n ty -> concat ["[", show n, " x ", show ty, "]"]
CustomType (Ident ty) -> ty
data LLVMComp
= LLEq
| LLNe
| LLUgt
| LLUge
| LLUlt
| LLUle
| LLSgt
| LLSge
| LLSlt
| LLSle
instance Show LLVMComp where
show :: LLVMComp -> String
show = \case
LLEq -> "eq"
LLNe -> "ne"
LLUgt -> "ugt"
LLUge -> "uge"
LLUlt -> "ult"
LLUle -> "ule"
LLSgt -> "sgt"
LLSge -> "sge"
LLSlt -> "slt"
LLSle -> "sle"
{- | Represents a LLVM "value", as in an integer, a register variable,
or a string contstant
-}
data LLVMValue = VInteger Integer | VIdent Id | VConstant String
instance Show LLVMValue where
show :: LLVMValue -> String
show v = case v of
VInteger i -> show i
VIdent (n, _) -> "%" <> fromIdent n
VConstant s -> "c" <> show s
type Params = [(Ident, LLVMType)]
type Args = [(LLVMType, LLVMValue)]
-- | A datatype which represents different instructions in LLVM
data LLVMIr
= Define LLVMType Ident Params
| DefineEnd
| Declare LLVMType Ident Params
| SetVariable Ident LLVMIr
| Variable Ident
| Add LLVMType LLVMValue LLVMValue
| Sub LLVMType LLVMValue LLVMValue
| Div LLVMType LLVMValue LLVMValue
| Mul LLVMType LLVMValue LLVMValue
| Srem LLVMType LLVMValue LLVMValue
| Icmp LLVMComp LLVMType LLVMValue LLVMValue
| Br Ident
| BrCond LLVMValue Ident Ident
| Label Ident
| Call LLVMType Ident Args
| Alloca LLVMType
| Store LLVMType Ident LLVMType Ident
| Bitcast LLVMType Ident LLVMType
| Ret LLVMType LLVMValue
| Comment String
| UnsafeRaw String -- This should generally be avoided, and proper
-- instructions should be used in its place
deriving (Show)
-- | Converts a list of LLVMIr instructions to a string
llvmIrToString :: [LLVMIr] -> String
llvmIrToString = go 0
where
go :: Int -> [LLVMIr] -> String
go _ [] = mempty
go i (x : xs) = do
let (i', n) = case x of
Define{} -> (i + 1, 0)
DefineEnd -> (i - 1, 0)
_ -> (i, i)
insToString n x <> go i' xs
{- | Converts a LLVM inststruction to a String, allowing for printing etc.
The integer represents the indentation
-}
{- FOURMOLU_DISABLE -}
insToString :: Int -> LLVMIr -> String
insToString i l =
replicate i '\t' <> case l of
(Define t (Ident i) params) ->
concat
[ "define ", show t, " @", i
, "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params)
, ") {\n"
]
DefineEnd -> "}\n"
(Declare _t (Ident _i) _params) -> undefined
(SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir]
(Add t v1 v2) ->
concat
[ "add ", show t, " ", show v1
, ", ", show v2, "\n"
]
(Sub t v1 v2) ->
concat
[ "sub ", show t, " ", show v1, ", "
, show v2, "\n"
]
(Div t v1 v2) ->
concat
[ "sdiv ", show t, " ", show v1, ", "
, show v2, "\n"
]
(Mul t v1 v2) ->
concat
[ "mul ", show t, " ", show v1
, ", ", show v2, "\n"
]
(Srem t v1 v2) ->
concat
[ "srem ", show t, " ", show v1, ", "
, show v2, "\n"
]
(Call t (Ident i) arg) ->
concat
[ "call ", show t, " @", i, "("
, intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg
, ")\n"
]
(Alloca t) -> unwords ["alloca", show t, "\n"]
(Store t1 (Ident id1) t2 (Ident id2)) ->
concat
[ "store ", show t1, " %", id1
, ", ", show t2 , " %", id2, "\n"
]
(Bitcast t1 (Ident i) t2) ->
concat
[ "bitcast ", show t1, " %"
, i, " to ", show t2, "\n"
]
(Icmp comp t v1 v2) ->
concat
[ "icmp ", show comp, " ", show t
, " ", show v1, ", ", show v2, "\n"
]
(Ret t v) ->
concat
[ "ret ", show t, " "
, show v, "\n"
]
(UnsafeRaw s) -> s
(Label (Ident s)) -> "\nlabel_" <> s <> ":\n"
(Br (Ident s)) -> "br label %label_" <> s <> "\n"
(BrCond val (Ident s1) (Ident s2)) ->
concat
[ "br i1 ", show val, ", ", "label %"
, "label_", s1, ", ", "label %", "label_", s2, "\n"
]
(Comment s) -> "; " <> s <> "\n"
(Variable (Ident id)) -> "%" <> id
{- FOURMOLU_ENABLE -}
fromIdent :: Ident -> String
fromIdent (Ident s) = s

View file

@ -1,14 +1,17 @@
{-# LANGUAGE LambdaCase #-}
module Main where
import Compiler (compile)
import GHC.IO.Handle.Text (hPutStrLn)
import Grammar.ErrM (Err)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
--import Interpreter (interpret)
import LambdaLifter (abstract, freeVars, lambdaLift)
import LambdaLifter (lambdaLift)
import Renamer (rename)
import System.Environment (getArgs)
import System.Exit (exitFailure, exitSuccess)
import System.IO (stderr)
import TypeChecker (typecheck)
main :: IO ()
@ -20,21 +23,26 @@ main' :: String -> IO ()
main' s = do
file <- readFile s
putStrLn "\n-- Parser"
printToErr "-- Parse Tree -- "
parsed <- fromSyntaxErr . pProgram $ myLexer file
putStrLn $ printTree parsed
printToErr $ printTree parsed
putStrLn "\n-- Renamer"
putStrLn "\n-- Renamer --"
let renamed = rename parsed
putStrLn $ printTree renamed
putStrLn "\n-- TypeChecker"
putStrLn "\n-- TypeChecker --"
typechecked <- fromTypeCheckerErr $ typecheck renamed
putStrLn $ printTree typechecked
putStrLn "\n-- Lambda Lifter"
printToErr "\n-- Lambda Lifter --"
let lifted = lambdaLift typechecked
putStrLn $ printTree lifted
printToErr $ printTree lifted
printToErr "\n -- Printing compiler output to stdout --"
compiled <- fromCompilerErr $ compile lifted
putStrLn compiled
writeFile "llvm.ll" compiled
-- interpred <- fromInterpreterErr $ interpret lifted
-- putStrLn "\n-- interpret"
@ -42,6 +50,16 @@ main' s = do
exitSuccess
printToErr :: String -> IO ()
printToErr = hPutStrLn stderr
fromCompilerErr :: Err a -> IO a
fromCompilerErr = either
(\err -> do
putStrLn "\nCOMPILER ERROR"
putStrLn err
exitFailure)
pure
fromSyntaxErr :: Err a -> IO a
fromSyntaxErr = either

View file

@ -1,7 +1,7 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
module TypeChecker (typecheck) where
module TypeChecker (typecheck, partitionType) where
import Auxiliary (maybeToRightM, snoc)
import Control.Monad.Except (throwError, unless)