Added support for pattern matching on ints. Might need a lookover.

This commit is contained in:
Samuel Hammersberg 2023-02-20 14:39:43 +01:00
parent 18e0a92fe0
commit 6749650223
7 changed files with 157 additions and 64 deletions

1
.gitignore vendored
View file

@ -5,3 +5,4 @@ dist-newstyle
src/Grammar src/Grammar
language language
llvm.ll llvm.ll
output

View file

@ -11,11 +11,25 @@
-- main = apply (\x : Int . x + 5) 5 -- main = apply (\x : Int . x + 5) 5
-- answer: 10 -- answer: 10
apply : (Int -> Int -> Int) -> Int -> Int -> Int; -- apply : (Int -> Int -> Int) -> Int -> Int -> Int;
apply f x y = f x y; -- apply f x y = f x y;
krimp: Int -> Int -> Int; -- krimp: Int -> Int -> Int;
krimp x y = x + y; -- krimp x y = x + y;
main : Int; -- main : Int;
main = apply (krimp) 2 3; -- main = apply (krimp) 2 3;
-- answer: 5 -- answer: 5
fibbonaci : Int -> Int;
fibbonaci x = case x of {
0 => 0,
1 => 1,
-- abusing overflows to represent negatives like a boss
_ => (fibbonaci (x + 9223372036854775807 + 9223372036854775807))
+ (fibbonaci (x + 9223372036854775807 + 9223372036854775807 + 1))
} : Int;
faccer : Int -> Int;
main : Int;
main = fibbonaci 10;
-- answer: 55

View file

@ -5,19 +5,18 @@ module Compiler (compile) where
import Auxiliary (snoc) import Auxiliary (snoc)
import Control.Monad.State (StateT, execStateT, gets, modify) import Control.Monad.State (StateT, execStateT, gets, modify)
--import Data.List.Extra (trim)
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as Map import qualified Data.Map as Map
import Data.Tuple.Extra (dupe, first, second) import Data.Tuple.Extra (dupe, first, second)
import qualified Grammar.Abs as GA
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
import LlvmIr (LLVMComp (..), LLVMIr (..), LLVMType (..), import LlvmIr (LLVMComp (..), LLVMIr (..), LLVMType (..),
LLVMValue (..), Visibility (..), LLVMValue (..), Visibility (..),
llvmIrToString) llvmIrToString)
--import System.Process.Extra (readCreateProcess, shell)
import TypeChecker (partitionType) import TypeChecker (partitionType)
import TypeCheckerIr (Bind (..), CLit (CInt, CatchAll), import TypeCheckerIr (Bind (..), Case (..), Exp (..), Id,
Case (..), Exp (..), Id, Ident (..), Ident (..), Program (..),
Program (..), Type (TFun, TInt)) Type (TFun, TInt))
-- | The record used as the code generator state -- | The record used as the code generator state
data CodeGenerator = CodeGenerator data CodeGenerator = CodeGenerator
@ -73,38 +72,38 @@ initCodeGenerator scs = CodeGenerator { instructions = defaultStart
, variableCount = 0 , variableCount = 0
, labelCount = 0 , labelCount = 0
} }
{-
run :: Err String -> IO ()
run s = do
let s' = case s of
Right s -> s
Left _ -> error "yo"
writeFile "llvm.ll" s'
putStrLn . trim =<< readCreateProcess (shell "lli") s'
--run :: Err String -> IO () test :: Integer -> Program
--run s = do test v = Program [
-- let s' = case s of Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] (
-- Right s -> s ECase TInt (EId ("x", TInt)) [
-- Left _ -> error "yo" (TInt,Case (CInt 0) (EInt 0)),
-- writeFile "llvm.ll" s' Case (CInt 1) (EInt 1),
-- putStrLn . trim =<< readCreateProcess (shell "lli") s' Case CatchAll (EAdd TInt
-- (EApp TInt (EId (Ident "fibonacci", TInt)) (
--test :: Integer -> Program EAdd TInt (EId (Ident "x", TInt))
--test v = Program [ (EInt (fromIntegral ((maxBound :: Int) * 2)))
-- Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] ( ))
-- ECased (EId ("x", TInt)) [ (EApp TInt (EId (Ident "fibonacci", TInt)) (
-- Case (CInt 0) (EInt 0), EAdd TInt (EId (Ident "x", TInt))
-- Case (CInt 1) (EInt 1), (EInt (fromIntegral ((maxBound :: Int) * 2 + 1)))
-- Case CatchAll (EAdd TInt ))
-- (EApp TInt (EId (Ident "fibonacci", TInt)) ( )
-- EAdd TInt (EId (Ident "x", TInt)) ]
-- (EInt (fromIntegral ((maxBound :: Int) * 2))) ),
-- )) Bind (Ident "main", TInt) [] (
-- (EApp TInt (EId (Ident "fibonacci", TInt)) ( EApp TInt (EId (Ident "fibonacci", TInt)) (EInt v) -- (EInt 92)
-- EAdd TInt (EId (Ident "x", TInt)) )
-- (EInt (fromIntegral ((maxBound :: Int) * 2 + 1))) ]
-- )) -}
-- )
-- ]
-- ),
-- Bind (Ident "main",TInt) [] (
-- EApp TInt (EId (Ident "fibonacci", TInt)) (EInt v) -- (EInt 92)
-- )
-- ]
{- | Compiles an AST and produces a LLVM Ir string. {- | Compiles an AST and produces a LLVM Ir string.
An easy way to actually "compile" this output is to An easy way to actually "compile" this output is to
Simply pipe it to LLI Simply pipe it to LLI
@ -120,7 +119,7 @@ compileScs (Bind (name, t) args exp : xs) = do
emit $ UnsafeRaw "\n" emit $ UnsafeRaw "\n"
emit . Comment $ show name <> ": " <> show exp emit . Comment $ show name <> ": " <> show exp
let args' = map (second type2LlvmType) args let args' = map (second type2LlvmType) args
emit $ Define (type2LlvmType t_return) name args' emit $ Define I64 {-(type2LlvmType t_return)-} name args'
functionBody <- exprToValue exp functionBody <- exprToValue exp
if name == "main" if name == "main"
then mapM_ emit $ mainContent functionBody then mapM_ emit $ mainContent functionBody
@ -161,42 +160,44 @@ compileExp (EId (name, _)) = emitIdent name
compileExp (EApp t e1 e2) = emitApp t e1 e2 compileExp (EApp t e1 e2) = emitApp t e1 e2
compileExp (EAbs t ti e) = emitAbs t ti e compileExp (EAbs t ti e) = emitAbs t ti e
compileExp (ELet binds e) = emitLet binds e compileExp (ELet binds e) = emitLet binds e
compileExp (ECased e c) = emitECased e c compileExp (ECase t e cs) = emitECased t e cs
-- go (ESub e1 e2) = emitSub e1 e2 -- go (ESub e1 e2) = emitSub e1 e2
-- go (EMul e1 e2) = emitMul e1 e2 -- go (EMul e1 e2) = emitMul e1 e2
-- go (EDiv e1 e2) = emitDiv e1 e2 -- go (EDiv e1 e2) = emitDiv e1 e2
-- go (EMod e1 e2) = emitMod e1 e2 -- go (EMod e1 e2) = emitMod e1 e2
--- aux functions --- --- aux functions ---
emitECased :: Exp -> [Case] -> CompilerState () emitECased :: Type -> Exp -> [(Type, Case)] -> CompilerState ()
emitECased e cs = do emitECased t e cases = do
let cs = snd <$> cases
let ty = type2LlvmType t
vs <- exprToValue e vs <- exprToValue e
lbl <- getNewLabel lbl <- getNewLabel
let label = Ident $ "escape_" <> show lbl let label = Ident $ "escape_" <> show lbl
stackPtr <- getNewVar stackPtr <- getNewVar
emit $ SetVariable (Ident $ show stackPtr) (Alloca I64) emit $ SetVariable (Ident $ show stackPtr) (Alloca ty)
mapM_ (emitCases label stackPtr vs) cs mapM_ (emitCases ty label stackPtr vs) cs
emit $ Label label emit $ Label label
res <- getNewVar res <- getNewVar
emit $ SetVariable (Ident $ show res) (Load I64 Ptr (Ident $ show stackPtr)) emit $ SetVariable (Ident $ show res) (Load ty Ptr (Ident $ show stackPtr))
where where
emitCases :: Ident -> Integer -> LLVMValue -> Case -> CompilerState () emitCases :: LLVMType -> Ident -> Integer -> LLVMValue -> Case -> CompilerState ()
emitCases label stackPtr vs (Case (CInt i) exp) = do emitCases ty label stackPtr vs (Case (GA.CInt i) exp) = do
ns <- getNewVar ns <- getNewVar
lbl_fail <- getNewLabel lbl_fail <- getNewLabel
lbl_succ <- getNewLabel lbl_succ <- getNewLabel
let failed = Ident $ "failed_" <> show lbl_fail let failed = Ident $ "failed_" <> show lbl_fail
let success = Ident $ "success_" <> show lbl_succ let success = Ident $ "success_" <> show lbl_succ
emit $ SetVariable (Ident $ show ns) (Icmp LLEq I64 vs (VInteger i)) emit $ SetVariable (Ident $ show ns) (Icmp LLEq ty vs (VInteger i))
emit $ BrCond (VIdent (Ident $ show ns) I64) success failed emit $ BrCond (VIdent (Ident $ show ns) ty) success failed
emit $ Label success emit $ Label success
val <- exprToValue exp val <- exprToValue exp
emit $ Store I64 val Ptr (Ident . show $ stackPtr) emit $ Store ty val Ptr (Ident . show $ stackPtr)
emit $ Br label emit $ Br label
emit $ Label failed emit $ Label failed
emitCases label stackPtr _ (Case CatchAll exp) = do emitCases ty label stackPtr _ (Case GA.CatchAll exp) = do
val <- exprToValue exp val <- exprToValue exp
emit $ Store I64 val Ptr (Ident . show $ stackPtr) emit $ Store ty val Ptr (Ident . show $ stackPtr)
emit $ Br label emit $ Br label
@ -343,7 +344,7 @@ getType (EId (_, t)) = type2LlvmType t
getType (EApp t _ _) = type2LlvmType t getType (EApp t _ _) = type2LlvmType t
getType (EAbs t _ _) = type2LlvmType t getType (EAbs t _ _) = type2LlvmType t
getType (ELet _ e) = getType e getType (ELet _ e) = getType e
getType (ECased e cs) = undefined getType (ECase t _ _) = type2LlvmType t
valueGetType :: LLVMValue -> LLVMType valueGetType :: LLVMValue -> LLVMType
valueGetType (VInteger _) = I64 valueGetType (VInteger _) = I64

View file

@ -9,6 +9,8 @@ import Control.Applicative (Applicative (liftA2))
import Control.Monad.State (MonadState (get, put), State, evalState) import Control.Monad.State (MonadState (get, put), State, evalState)
import Data.Set (Set) import Data.Set (Set)
import qualified Data.Set as Set import qualified Data.Set as Set
import Debug.Trace (trace)
import qualified Grammar.Abs as GA
import Prelude hiding (exp) import Prelude hiding (exp)
import Renamer import Renamer
import TypeCheckerIr import TypeCheckerIr
@ -22,7 +24,6 @@ import TypeCheckerIr
lambdaLift :: Program -> Program lambdaLift :: Program -> Program
lambdaLift = collectScs . abstract . freeVars lambdaLift = collectScs . abstract . freeVars
-- | Annotate free variables -- | Annotate free variables
freeVars :: Program -> AnnProgram freeVars :: Program -> AnnProgram
freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e) freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e)
@ -62,6 +63,16 @@ freeVarsExp localVars = \case
e' = freeVarsExp e_localVars e e' = freeVarsExp e_localVars e
e_localVars = Set.insert name localVars e_localVars = Set.insert name localVars
(ECase t e cs) -> do
let e' = freeVarsExp localVars e
let vars = freeVarsOf e'
let (vars', cs') = foldr (\(_, Case c e) (vars,acc) -> do
let e' = freeVarsExp vars e
let vars' = freeVarsOf e'
(Set.union vars vars', AnnCase c e' : acc)
) (vars, []) cs
(vars', ACase t e' (reverse cs'))
freeVarsOf :: AnnExp -> Set Id freeVarsOf :: AnnExp -> Set Id
freeVarsOf = fst freeVarsOf = fst
@ -79,7 +90,12 @@ data AnnExp' = AId Id
| AApp Type AnnExp AnnExp | AApp Type AnnExp AnnExp
| AAdd Type AnnExp AnnExp | AAdd Type AnnExp AnnExp
| AAbs Type Id AnnExp | AAbs Type Id AnnExp
| ACase Type AnnExp [AnnCase]
deriving Show deriving Show
data AnnCase = AnnCase GA.Case AnnExp
deriving Show
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@. -- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
-- Free variables are @v₁ v₂ .. vₙ@ are bound. -- Free variables are @v₁ v₂ .. vₙ@ are bound.
abstract :: AnnProgram -> Program abstract :: AnnProgram -> Program
@ -120,6 +136,14 @@ abstractExp (free, exp) = case exp of
AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1 AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1
_ -> f (free, ae) _ -> f (free, ae)
ACase t e cs -> do
e' <- abstractExp e
cs' <- mapM (\(AnnCase c e) -> do
e' <- abstractExp e
pure (t,Case c e')) cs
pure $ ECase t e' cs'
-- Lift lambda into let and bind free variables -- Lift lambda into let and bind free variables
AAbs t parm e -> do AAbs t parm e -> do
i <- nextNumber i <- nextNumber
@ -179,6 +203,13 @@ collectScsExp = \case
bind = Bind name parms rhs' bind = Bind name parms rhs'
(rhs_scs, rhs') = collectScsExp rhs (rhs_scs, rhs') = collectScsExp rhs
(e_scs, e') = collectScsExp e (e_scs, e') = collectScsExp e
ECase t e cs -> do
let (scs, e') = collectScsExp e
let (scs',cs') = foldr (\(t, Case c e) (scs, acc) -> do
let (scs', e') = collectScsExp e
(scs ++ scs', (t,Case c e') : acc)
) (scs,[]) cs
(scs', ECase t e' cs')
-- @\x.\y.\z. e → (e, [x,y,z])@ -- @\x.\y.\z. e → (e, [x,y,z])@

View file

@ -3,6 +3,7 @@
module Renamer (module Renamer) where module Renamer (module Renamer) where
import Auxiliary (mapAccumM) import Auxiliary (mapAccumM)
import Control.Monad (foldM)
import Control.Monad.State (MonadState, State, evalState, gets, import Control.Monad.State (MonadState, State, evalState, gets,
modify) modify)
import Data.Map (Map) import Data.Map (Map)
@ -68,6 +69,14 @@ renameExp old_names = \case
(new_names, e') <- renameExp old_names e (new_names, e') <- renameExp old_names e
pure (new_names, EAnn e' t) pure (new_names, EAnn e' t)
ECase e cs t -> do
(new_names, e') <- renameExp old_names e
(new_names', cs') <- foldM (\(names, stack) (CaseMatch c exp) -> do
(nm,exp') <- renameExp names exp
pure (nm,CaseMatch c exp' : stack)
) (new_names, []) cs
pure (new_names', ECase e' cs' t)
-- | Create a new name and add it to name environment. -- | Create a new name and add it to name environment.
newName :: Names -> Ident -> Rn (Names, Ident) newName :: Names -> Ident -> Rn (Names, Ident)
newName env old_name = do newName env old_name = do

View file

@ -95,10 +95,23 @@ infer cxt = \case
throwError "Inferred type and type annotation doesn't match" throwError "Inferred type and type annotation doesn't match"
pure (e', t1) pure (e', t1)
ECase e cs t -> do
(e',t1) <- infer cxt e
unless (typeEq t t1) $
throwError "Inferred type and type annotation doesn't match"
case traverse (\(CaseMatch c e) -> do
-- //TODO check c as well
e' <- check cxt e t
unless (typeEq t t1) $
throwError "Inferred type and type annotation doesn't match"
pure (t1, T.Case c e')
) cs of
Right cs -> pure (T.ECase t1 e' cs,t1)
Left e -> throwError e
-- | Check infered type matches the supplied type. -- | Check infered type matches the supplied type.
check :: Cxt -> Exp -> Type -> Err T.Exp check :: Cxt -> Exp -> Type -> Err T.Exp
check cxt exp typ = case exp of check cxt exp typ = case exp of
EId x -> do EId x -> do
t <- case lookupEnv x cxt of t <- case lookupEnv x cxt of
Nothing -> maybeToRightM Nothing -> maybeToRightM
@ -142,6 +155,11 @@ check cxt exp typ = case exp of
throwError "Inferred type and type annotation doesn't match" throwError "Inferred type and type annotation doesn't match"
check cxt e t check cxt e t
ECase e _ t -> do
unless (typeEq t typ) $
throwError "Inferred type and type annotation doesn't match"
check cxt e t
-- | Check if types are equivalent. Doesn't handle coercion or polymorphism. -- | Check if types are equivalent. Doesn't handle coercion or polymorphism.
typeEq :: Type -> Type -> Bool typeEq :: Type -> Type -> Bool
typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1 typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1

View file

@ -6,6 +6,7 @@ module TypeCheckerIr
) where ) where
import Grammar.Abs (Ident (..), Type (..)) import Grammar.Abs (Ident (..), Type (..))
import qualified Grammar.Abs as GA
import Grammar.Print import Grammar.Print
import Prelude import Prelude
import qualified Prelude as C (Eq, Ord, Read, Show) import qualified Prelude as C (Eq, Ord, Read, Show)
@ -20,14 +21,12 @@ data Exp
| EApp Type Exp Exp | EApp Type Exp Exp
| EAdd Type Exp Exp | EAdd Type Exp Exp
| EAbs Type Id Exp | EAbs Type Id Exp
| ECased Exp [Case] | ECase Type Exp [(Type, Case)]
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
data Case = Case CLit Exp data Case = Case GA.Case Exp
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
data CLit = CInt Integer | CatchAll
deriving (C.Eq, C.Ord, C.Show, C.Read)
type Id = (Ident, Type) type Id = (Ident, Type)
data Bind = Bind Id [Id] Exp data Bind = Bind Id [Id] Exp
@ -102,5 +101,25 @@ instance Print Exp where
, doc $ showString "." , doc $ showString "."
, prt 0 e , prt 0 e
] ]
ECase t e cs -> prPrec i 0 $ concatD
[ doc $ showString "@"
, prt 0 t
, doc $ showString "("
, prt 0 e
, doc $ showString ")"
, prPrec i 0 $ concatD . printCases $ cs
]
where
printCases :: [(Type, Case)] -> [Doc]
printCases [] = []
printCases ((t, Case c e):xs) = concatD
[ doc $ showString "@"
, prt 0 t
, doc $ showString "("
, doc . showString . show $ c
, doc $ showString ")"
, doc $ showString "=>"
, prt 0 e
, doc $ showString "\n"
] : printCases xs