diff --git a/.gitignore b/.gitignore index d0ab5db..1daddd6 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,9 @@ dist-newstyle *.x *.bak src/Grammar + +language +llvm.ll /language .vscode/ diff --git a/Grammar.cf b/Grammar.cf index b258446..0b4785f 100644 --- a/Grammar.cf +++ b/Grammar.cf @@ -1,15 +1,25 @@ +Program. Program ::= [Bind]; +EId. Exp3 ::= Ident; +EInt. Exp3 ::= Integer; +EAnn. Exp3 ::= "(" Exp ":" Type ")"; +ELet. Exp3 ::= "let" Bind "in" Exp; +EApp. Exp2 ::= Exp2 Exp3; +EAdd. Exp1 ::= Exp1 "+" Exp2; +EAbs. Exp ::= "\\" Ident ":" Type "." Exp; -Program. Program ::= "main" "=" Exp ; +Bind. Bind ::= Ident ":" Type ";" + Ident [Ident] "=" Exp; -EId. Exp3 ::= Ident ; -EInt. Exp3 ::= Integer ; -EApp. Exp2 ::= Exp2 Exp3 ; -EAdd. Exp1 ::= Exp1 "+" Exp2 ; -EAbs. Exp ::= "\\" Ident "->" Exp ; +separator Bind ";"; +separator Ident ""; -coercions Exp 3 ; +coercions Exp 3; -comment "--" ; -comment "{-" "-}" ; +TInt. Type1 ::= "Int" ; +TPol. Type1 ::= Ident ; +TFun. Type ::= Type1 "->" Type ; +coercions Type 1 ; +comment "--"; +comment "{-" "-}"; diff --git a/Makefile b/Makefile index 16b753d..e63a1e6 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ .PHONY : sdist clean language : src/Grammar/Test - cabal install --installdir=. + cabal install --installdir=. --overwrite-policy=always src/Grammar/Test.hs src/Grammar/Lex.x src/Grammar/Par.y : Grammar.cf bnfc -o src -d $< @@ -22,4 +22,16 @@ clean : rm -r src/Grammar rm language +test : + ./language ./sample-programs/basic-1 + ./language ./sample-programs/basic-2 + ./language ./sample-programs/basic-3 + ./language ./sample-programs/basic-4 + ./language ./sample-programs/basic-5 + ./language ./sample-programs/basic-5 + ./language ./sample-programs/basic-6 + ./language ./sample-programs/basic-7 + ./language ./sample-programs/basic-8 + ./language ./sample-programs/basic-9 + # EOF diff --git a/cabal.project.local b/cabal.project.local new file mode 100644 index 0000000..0432756 --- /dev/null +++ b/cabal.project.local @@ -0,0 +1,2 @@ +ignore-project: False +tests: True diff --git a/language.cabal b/language.cabal index 5734655..8b958a5 100644 --- a/language.cabal +++ b/language.cabal @@ -1,4 +1,4 @@ -cabal-version: 3.0 +cabal-version: 3.4 name: language @@ -12,26 +12,34 @@ build-type: Simple extra-doc-files: CHANGELOG.md + extra-source-files: Grammar.cf common warnings - ghc-options: -Wall + ghc-options: -W executable language import: warnings main-is: Main.hs - + other-modules: Grammar.Abs Grammar.Lex Grammar.Par Grammar.Print Grammar.Skel - Interpreter - + Grammar.ErrM + LambdaLifter + Auxiliary + Renamer + TypeChecker + TypeCheckerIr +-- Interpreter + Compiler + LlvmIr hs-source-dirs: src build-depends: @@ -40,5 +48,5 @@ executable language , containers , either , array - + , extra default-language: GHC2021 diff --git a/sample-programs/basic-1 b/sample-programs/basic-1 new file mode 100644 index 0000000..f0cdcc4 --- /dev/null +++ b/sample-programs/basic-1 @@ -0,0 +1,21 @@ + +-- tripplemagic : Int -> Int -> Int -> Int; +-- tripplemagic x y z = ((\x:Int. x+x) x) + y + z; +-- main : Int; +-- main = tripplemagic ((\x:Int. x+x+3) ((\x:Int. x) 2)) 5 3 +-- answer: 22 + +-- apply : (Int -> Int) -> Int -> Int; +-- apply f x = f x; +-- main : Int; +-- main = apply (\x : Int . x + 5) 5 +-- answer: 10 + +apply : (Int -> Int -> Int) -> Int -> Int -> Int; +apply f x y = f x y; +krimp: Int -> Int -> Int; +krimp x y = x + y; +main : Int; +main = apply (krimp) 2 3; +-- answer: 5 + diff --git a/shell.nix b/shell.nix index 0c7624a..0af8c7b 100644 --- a/shell.nix +++ b/shell.nix @@ -1,12 +1,12 @@ let - pkgs = import { }; # pin the channel to ensure reproducibility! + pkgs = import (fetchTarball "https://github.com/NixOS/nixpkgs/archive/747927516efcb5e31ba03b7ff32f61f6d47e7d87.zip") { }; # pin the channel to ensure reproducibility! in pkgs.haskellPackages.developPackage { root = ./.; withHoogle = true; modifier = drv: pkgs.haskell.lib.addBuildTools drv ( - (with pkgs; [ hlint haskell-language-server ghc jasmin ]) + (with pkgs; [ hlint haskell-language-server ghc jasmin llvmPackages_15.libllvm]) ++ (with pkgs.haskellPackages; [ cabal-install diff --git a/src/Auxiliary.hs b/src/Auxiliary.hs new file mode 100644 index 0000000..735d804 --- /dev/null +++ b/src/Auxiliary.hs @@ -0,0 +1,21 @@ +{-# LANGUAGE LambdaCase #-} +module Auxiliary (module Auxiliary) where +import Control.Monad.Error.Class (liftEither) +import Control.Monad.Except (MonadError) +import Data.Either.Combinators (maybeToRight) + +snoc :: a -> [a] -> [a] +snoc x xs = xs ++ [x] + +maybeToRightM :: MonadError l m => l -> Maybe r -> m r +maybeToRightM err = liftEither . maybeToRight err + +mapAccumM :: Monad m => (s -> a -> m (s, b)) -> s -> [a] -> m (s, [b]) +mapAccumM f = go + where + go acc = \case + [] -> pure (acc, []) + x:xs -> do + (acc', x') <- f acc x + (acc'', xs') <- go acc' xs + pure (acc'', x':xs') diff --git a/src/Compiler.hs b/src/Compiler.hs index e69de29..fd6b6bc 100644 --- a/src/Compiler.hs +++ b/src/Compiler.hs @@ -0,0 +1,266 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} + +module Compiler (compile) where + +import Auxiliary (snoc) +import Control.Monad.State (StateT, execStateT, gets, modify) +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Tuple.Extra (dupe, first, second) +import Grammar.ErrM (Err) +import LlvmIr (LLVMIr (..), LLVMType (..), + LLVMValue (..), Visibility (..), + llvmIrToString) +import TypeChecker (partitionType) +import TypeCheckerIr + +-- | The record used as the code generator state +data CodeGenerator = CodeGenerator + { instructions :: [LLVMIr] + , functions :: Map Id FunctionInfo + , variableCount :: Integer + } + +-- | A state type synonym +type CompilerState a = StateT CodeGenerator Err a + +data FunctionInfo = FunctionInfo + { numArgs :: Int + , arguments :: [Id] + } + +-- | Adds a instruction to the CodeGenerator state +emit :: LLVMIr -> CompilerState () +emit l = modify $ \t -> t { instructions = snoc l $ instructions t } + +-- | Increases the variable counter in the CodeGenerator state +increaseVarCount :: CompilerState () +increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 } + +-- | Returns the variable count from the CodeGenerator state +getVarCount :: CompilerState Integer +getVarCount = gets variableCount + +-- | Increases the variable count and returns it from the CodeGenerator state +getNewVar :: CompilerState Integer +getNewVar = increaseVarCount >> getVarCount + +-- | Produces a map of functions infos from a list of binds, +-- which contains useful data for code generation. +getFunctions :: [Bind] -> Map Id FunctionInfo +getFunctions bs = Map.fromList $ map go bs + where + go (Bind id args _) = + (id, FunctionInfo { numArgs=length args, arguments=args }) + + + +initCodeGenerator :: [Bind] -> CodeGenerator +initCodeGenerator scs = CodeGenerator { instructions = defaultStart + , functions = getFunctions scs + , variableCount = 0 + } + +-- | Compiles an AST and produces a LLVM Ir string. +-- An easy way to actually "compile" this output is to +-- Simply pipe it to lli +compile :: Program -> Err String +compile (Program scs) = do + let codegen = initCodeGenerator scs + llvmIrToString . instructions <$> execStateT (compileScs scs) codegen + +compileScs :: [Bind] -> CompilerState () +compileScs [] = pure () +compileScs (Bind (name, t) args exp : xs) = do + emit $ UnsafeRaw "\n" + emit . Comment $ show name <> ": " <> show exp + let args' = map (second type2LlvmType) args + emit $ Define (type2LlvmType t_return) name args' + functionBody <- exprToValue exp + if name == "main" + then mapM_ emit $ mainContent functionBody + else emit $ Ret I64 functionBody + emit DefineEnd + modify $ \s -> s { variableCount = 0 } + compileScs xs + where + t_return = snd $ partitionType (length args) t + +mainContent :: LLVMValue -> [LLVMIr] +mainContent var = + [ UnsafeRaw $ + "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n" + , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2)) + -- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2") + -- , Label (Ident "b_1") + -- , UnsafeRaw + -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n" + -- , Br (Ident "end") + -- , Label (Ident "b_2") + -- , UnsafeRaw + -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n" + -- , Br (Ident "end") + -- , Label (Ident "end") + Ret I64 (VInteger 0) + ] + +defaultStart :: [LLVMIr] +defaultStart = [ UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n" + , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" + ] + +compileExp :: Exp -> CompilerState () +compileExp = \case + EInt i -> emitInt i + EAdd t e1 e2 -> emitAdd t e1 e2 + EId (name, _) -> emitIdent name + EApp t e1 e2 -> emitApp t e1 e2 + EAbs t ti e -> emitAbs t ti e + ELet bind e -> emitLet bind e + +--- aux functions --- +emitAbs :: Type -> Id -> Exp -> CompilerState () +emitAbs _t tid e = emit . Comment $ "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e + +emitLet :: Bind -> Exp -> CompilerState () +emitLet b e = emit . Comment $ concat [ "ELet (" + , show b + , " = " + , show e + , ") is not implemented!" + ] + +emitApp :: Type -> Exp -> Exp -> CompilerState () +emitApp t e1 e2 = appEmitter t e1 e2 [] + where + appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState () + appEmitter t e1 e2 stack = do + let newStack = e2 : stack + case e1 of + EApp _ e1' e2' -> appEmitter t e1' e2' newStack + EId id@(name, _) -> do + args <- traverse exprToValue newStack + vs <- getNewVar + funcs <- gets functions + let visibility = maybe Local (const Global) $ Map.lookup id funcs + args' = map (first valueGetType . dupe) args + call = Call (type2LlvmType t) visibility name args' + emit $ SetVariable (Ident $ show vs) call + x -> do + emit . Comment $ "The unspeakable happened: " + emit . Comment $ show x + +emitIdent :: Ident -> CompilerState () +emitIdent id = do + -- !!this should never happen!! + emit $ Comment "This should not have happened!" + emit $ Variable id + emit $ UnsafeRaw "\n" + +emitInt :: Integer -> CompilerState () +emitInt i = do + -- !!this should never happen!! + varCount <- getNewVar + emit $ Comment "This should not have happened!" + emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0)) + +emitAdd :: Type -> Exp -> Exp -> CompilerState () +emitAdd t e1 e2 = do + v1 <- exprToValue e1 + v2 <- exprToValue e2 + v <- getNewVar + emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2) + +-- emitMul :: Exp -> Exp -> CompilerState () +-- emitMul e1 e2 = do +-- (v1,v2) <- binExprToValues e1 e2 +-- increaseVarCount +-- v <- gets variableCount +-- emit $ SetVariable $ Ident $ show v +-- emit $ Mul I64 v1 v2 + +-- emitMod :: Exp -> Exp -> CompilerState () +-- emitMod e1 e2 = do +-- -- `let m a b = rem (abs $ b + a) b` +-- (v1,v2) <- binExprToValues e1 e2 +-- increaseVarCount +-- vadd <- gets variableCount +-- emit $ SetVariable $ Ident $ show vadd +-- emit $ Add I64 v1 v2 +-- +-- increaseVarCount +-- vabs <- gets variableCount +-- emit $ SetVariable $ Ident $ show vabs +-- emit $ Call I64 (Ident "llvm.abs.i64") +-- [ (I64, VIdent (Ident $ show vadd)) +-- , (I1, VInteger 1) +-- ] +-- increaseVarCount +-- v <- gets variableCount +-- emit $ SetVariable $ Ident $ show v +-- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2 + +-- emitDiv :: Exp -> Exp -> CompilerState () +-- emitDiv e1 e2 = do +-- (v1,v2) <- binExprToValues e1 e2 +-- increaseVarCount +-- v <- gets variableCount +-- emit $ SetVariable $ Ident $ show v +-- emit $ Div I64 v1 v2 + +-- emitSub :: Exp -> Exp -> CompilerState () +-- emitSub e1 e2 = do +-- (v1,v2) <- binExprToValues e1 e2 +-- increaseVarCount +-- v <- gets variableCount +-- emit $ SetVariable $ Ident $ show v +-- emit $ Sub I64 v1 v2 + +exprToValue :: Exp -> CompilerState LLVMValue +exprToValue = \case + EInt i -> pure $ VInteger i + + EId id@(name, t) -> do + funcs <- gets functions + case Map.lookup id funcs of + Just fi -> do + if numArgs fi == 0 + then do + vc <- getNewVar + emit $ SetVariable (Ident $ show vc) + (Call (type2LlvmType t) Global name []) + pure $ VIdent (Ident $ show vc) (type2LlvmType t) + else pure $ VFunction name Global (type2LlvmType t) + Nothing -> pure $ VIdent name (type2LlvmType t) + + e -> do + compileExp e + v <- getVarCount + pure $ VIdent (Ident $ show v) (getType e) + +type2LlvmType :: Type -> LLVMType +type2LlvmType = \case + TInt -> I64 + TFun t xs -> do + let (t', xs') = function2LLVMType xs [type2LlvmType t] + Function t' xs' + t -> CustomType $ Ident ("\"" ++ show t ++ "\"") + where + function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType]) + function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s) + function2LLVMType x s = (type2LlvmType x, s) + +getType :: Exp -> LLVMType +getType (EInt _) = I64 +getType (EAdd t _ _) = type2LlvmType t +getType (EId (_, t)) = type2LlvmType t +getType (EApp t _ _) = type2LlvmType t +getType (EAbs t _ _) = type2LlvmType t +getType (ELet _ e) = getType e + +valueGetType :: LLVMValue -> LLVMType +valueGetType (VInteger _) = I64 +valueGetType (VIdent _ t) = t +valueGetType (VConstant s) = Array (length s) I8 +valueGetType (VFunction _ _ t) = t diff --git a/src/GC/Makefile b/src/GC/Makefile index 354adad..92b02e8 100644 --- a/src/GC/Makefile +++ b/src/GC/Makefile @@ -9,8 +9,8 @@ STDFLAGS = -std=gnu++20 -stdlib=libc++ WFLAGS = -Wall -Wextra DBGFLAGS = -g -test_test: - echo "$(shell pwd)" +advance: + $(CC) $(WFLAGS) $(STDFLAGS) tests/advance.cpp -o tests/advance.out heap: $(CC) $(WFLAGS) $(STDFLAGS) $(LIB_INCL) lib/heap.cpp diff --git a/src/GC/include/heap.hpp b/src/GC/include/heap.hpp index 30f1a7a..c70ee54 100644 --- a/src/GC/include/heap.hpp +++ b/src/GC/include/heap.hpp @@ -2,9 +2,9 @@ #include #include +#include #include #include -#include #include "chunk.hpp" @@ -30,26 +30,10 @@ namespace GC { m_allocated_size = 0; } - void collect(Heap *heap); - void sweep(Heap *heap); - uintptr_t *try_recycle_chunks(Heap *heap, size_t size); - void free(Heap *heap); - void free_overlap(Heap *heap); - void mark(uintptr_t *start, const uintptr_t *end, std::vector worklist); - void print_line(Chunk *chunk); - void print_worklist(std::vector list); - - inline static Heap *m_instance = nullptr; - const char *m_heap; - size_t m_size; - size_t m_allocated_size; - uintptr_t *m_stack_end = nullptr; - - // maybe change to std::list - std::vector m_allocated_chunks; - std::vector m_freed_chunks; - - public: + // BEWARE only for testing, this should be adressed + ~Heap() { + std::free((char *)m_heap); + } static inline Heap *the() { // TODO: make private if (m_instance) // if m_instance is not a nullptr @@ -58,11 +42,35 @@ namespace GC { return m_instance; } - // BEWARE only for testing, this should be adressed - ~Heap() { - std::free((char *)m_heap); + static inline Chunk *getAt(std::list list, size_t n) { + auto iter = list.begin(); + if (!n) + return *iter; + std::advance(iter, n); + return *iter; } + void collect(); + void sweep(Heap *heap); + uintptr_t *try_recycle_chunks(size_t size); + void free(Heap* heap); + void free_overlap(Heap *heap); + void mark(uintptr_t *start, const uintptr_t *end, std::list worklist); + void print_line(Chunk *chunk); + void print_worklist(std::list list); + + inline static Heap *m_instance = nullptr; + const char *m_heap; + size_t m_size; + size_t m_allocated_size; + uintptr_t *m_stack_top = nullptr; + + // maybe change to std::list + std::list m_allocated_chunks; + std::list m_freed_chunks; + + public: + /** * These are the only two functions which are exposed * as the API for LLVM. At the absolute start of the @@ -70,8 +78,9 @@ namespace GC { * that the address of the topmost stack frame is * saved as the limit for scanning the stack in collect. */ - void *alloc(size_t size); // TODO: make static - void init(); // TODO: make static + static void init(); // TODO: make static + static void dispose(); // -||- + static void *alloc(size_t size); // -||- // DEBUG ONLY void collect(uint flags); // conditional collection diff --git a/src/GC/lib/heap.cpp b/src/GC/lib/heap.cpp index 3d37a3c..c17f680 100644 --- a/src/GC/lib/heap.cpp +++ b/src/GC/lib/heap.cpp @@ -14,13 +14,21 @@ namespace GC { /** * Initialises the heap singleton and saves the address - * of the calling stack frame as the stack_end. Presumeably + * of the calling stack frame as the stack_top. Presumeably * this address points to the stack frame of the compiled * LLVM executable after linking. */ void Heap::init() { Heap *heap = Heap::the(); - heap->m_stack_end = reinterpret_cast(__builtin_frame_address(1)); + heap->m_stack_top = reinterpret_cast(__builtin_frame_address(1)); + } + + /** + * Disposes the heap at program exit. + */ + void Heap::dispose() { + Heap *heap = Heap::the(); + delete heap; } /** @@ -43,13 +51,13 @@ namespace GC { } if (heap->m_size + size > HEAP_SIZE) { - collect(heap); + heap->collect(); // If collect failed, crash with OOM error assert(heap->m_size + size <= HEAP_SIZE && "Heap: Out Of Memory"); } // If a chunk was recycled, return the old chunk address - uintptr_t *reused_chunk = try_recycle_chunks(heap, size); + uintptr_t *reused_chunk = heap->try_recycle_chunks(size); if (reused_chunk != nullptr) { return (void *)reused_chunk; } @@ -58,7 +66,7 @@ namespace GC { // then create a new chunk auto new_chunk = new Chunk; new_chunk->size = size; - new_chunk->start = (uintptr_t *)(heap->m_heap + m_size); + new_chunk->start = (uintptr_t *)(heap->m_heap + heap->m_size); heap->m_size += size; @@ -75,8 +83,6 @@ namespace GC { * objects slightly which saves time from malloc'ing * memory from the OS. * - * @param heap Pointer to the singleton Heap instance - * * @param size Amount of bytes needed for the object * which is about to be allocated. * @@ -86,10 +92,14 @@ namespace GC { * nullptr is returned to signify no * chunks were found. */ - uintptr_t *Heap::try_recycle_chunks(Heap *heap, size_t size) { + uintptr_t *Heap::try_recycle_chunks(size_t size) { + auto heap = Heap::the(); // Check if there are any freed chunks large enough for current request for (size_t i = 0; i < heap->m_freed_chunks.size(); i++) { - auto cp = heap->m_freed_chunks.at(i); + // auto cp = heap->m_freed_chunks.at(i); + auto cp = getAt(heap->m_freed_chunks, i); + auto iter = heap->m_freed_chunks.begin(); + advance(iter, i); if (cp->size > size) { // Split the chunk, use one part and add the remaining part to @@ -100,7 +110,7 @@ namespace GC { chunk_complement->size = diff; chunk_complement->start = cp->start + cp->size; - heap->m_freed_chunks.erase(m_freed_chunks.begin() + i); + heap->m_freed_chunks.erase(iter); heap->m_freed_chunks.push_back(chunk_complement); heap->m_allocated_chunks.push_back(cp); @@ -109,7 +119,7 @@ namespace GC { else if (cp->size == size) { // Reuse the whole chunk - heap->m_freed_chunks.erase(m_freed_chunks.begin() + i); + heap->m_freed_chunks.erase(iter); heap->m_allocated_chunks.push_back(cp); return cp->start; } @@ -123,24 +133,23 @@ namespace GC { * left on the heap, a collection is triggered. This * function is private so that the user cannot trigger * a collection unneccessarily. - * - * @param heap Heap singleton instance, only for avoiding - * redundant calls to the singleton get */ - void Heap::collect(Heap *heap) { + void Heap::collect() { + // Get instance + auto heap = Heap::the(); // get current stack - auto stack_start = reinterpret_cast(__builtin_frame_address(0)); + auto stack_bottom = reinterpret_cast(__builtin_frame_address(0)); // fix this block, it's nästy - uintptr_t *stack_end; - if (heap->m_stack_end != nullptr) - stack_end = heap->m_stack_end; + uintptr_t *stack_top; + if (heap->m_stack_top != nullptr) + stack_top = heap->m_stack_top; else - stack_end = (uintptr_t *)0; // temporary + stack_top = (uintptr_t *)0; // temporary auto work_list = heap->m_allocated_chunks; - mark(stack_start, stack_end, work_list); + mark(stack_bottom, stack_top, work_list); sweep(heap); @@ -148,24 +157,26 @@ namespace GC { } /** - * Iterates through the stack, if an element on the stack points to a chunk - * that chunk is marked (i.e. reachable). It only marks element which are directly - * reachable from the chunk, so no chain of pointers from the stack are detected. + * Iterates through the stack, if an element on the stack points to a chunk, + * called a root chunk, that chunk is marked (i.e. reachable). + * Then it recursively follows all chunks which are possibly reachable from + * the root chunk and mark those chunks. * If a chunk is marked it is removed from the worklist, since it's no longer of * concern for this method. * * @param start Pointer to the start of the stack frame. * @param end Pointer to the end of the stack frame. - * @param worklist The currently allocated chunks. + * @param worklist The currently allocated chunks, which haven't been marked. */ - void Heap::mark(uintptr_t *start, const uintptr_t *end, vector worklist) { + void Heap::mark(uintptr_t *start, const uintptr_t *end, list worklist) { int counter = 0; // To find adresses thats in the worklist for (; start < end; start++) { counter++; - // all pointers must be aligned as double words - - for (auto it = worklist.begin(); it != worklist.end();) { + auto it = worklist.begin(); + auto stop = worklist.end(); + // for (auto it = worklist.begin(); it != worklist.end();) { + while (it != stop) { Chunk *chunk = *it; auto c_start = reinterpret_cast(chunk->start); @@ -181,7 +192,9 @@ namespace GC { if (!chunk->marked) { chunk->marked = true; + // Remove the marked chunk from the worklist it = worklist.erase(it); + // Recursively call mark, to see if the reachable chunk further points to another chunk mark((uintptr_t*) c_start, (uintptr_t*) c_end, worklist); } else { @@ -198,24 +211,27 @@ namespace GC { /** * Sweeps the heap, unmarks the marked chunks for the next cycle, - * adds the unmarked nodes to the vector of freed chunks; to be freed. + * adds the unmarked nodes to the list of freed chunks; to be freed. * - * @param heap Pointer to the heap to oporate on. + * @param heap Pointer to the heap singleton instance. */ void Heap::sweep(Heap *heap) { - for (auto it = heap->m_allocated_chunks.begin(); it != heap->m_allocated_chunks.end();) { - Chunk *chunk = *it; + auto iter = heap->m_allocated_chunks.begin(); + auto stop = heap->m_allocated_chunks.end(); + // for (auto it = heap->m_allocated_chunks.begin(); it != heap->m_allocated_chunks.end();) { + while (iter != stop) { + Chunk *chunk = *iter; // Unmark the marked chunks for the next iteration. if (chunk->marked) { chunk->marked = false; - ++it; + ++iter; } else { // Add the unmarked chunks to freed chunks and remove from // the list of allocated chunks heap->m_freed_chunks.push_back(chunk); - it = heap->m_allocated_chunks.erase(it); + iter = heap->m_allocated_chunks.erase(iter); } } } @@ -257,13 +273,15 @@ namespace GC { * larger chunks. */ void Heap::free_overlap(Heap *heap) { - std::vector filtered; + std::list filtered; size_t i = 0; - filtered.push_back(heap->m_freed_chunks.at(i++)); + // filtered.push_back(heap->m_freed_chunks.at(i++)); + filtered.push_back(getAt(heap->m_freed_chunks, i++)); cout << filtered.back()->start << endl; for (; i < heap->m_freed_chunks.size(); i++) { auto prev = filtered.back(); - auto next = heap->m_freed_chunks.at(i); + // auto next = heap->m_freed_chunks.at(i); + auto next = getAt(heap->m_freed_chunks, i); auto p_start = (uintptr_t)(prev->start); auto p_size = (uintptr_t)(prev->size); auto n_start = (uintptr_t)(next->start); @@ -283,9 +301,9 @@ namespace GC { void Heap::check_init() { auto heap = Heap::the(); cout << "Heap addr:\t" << heap << endl; - cout << "GC m_stack_end:\t" << heap->m_stack_end << endl; - auto stack_start = reinterpret_cast(__builtin_frame_address(0)); - cout << "GC stack_start:\t" << stack_start << endl; + cout << "GC m_stack_top:\t" << heap->m_stack_top << endl; + auto stack_bottom = reinterpret_cast(__builtin_frame_address(0)); + cout << "GC stack_bottom:\t" << stack_bottom << endl; } /** @@ -307,20 +325,20 @@ namespace GC { auto heap = Heap::the(); // get the frame adress, whwere local variables and saved registers are located - auto stack_start = reinterpret_cast(__builtin_frame_address(0)); - cout << "Stack start in collect:\t" << stack_start << endl; - uintptr_t *stack_end; + auto stack_bottom = reinterpret_cast(__builtin_frame_address(0)); + cout << "Stack bottom in collect:\t" << stack_bottom << endl; + uintptr_t *stack_top; - if (heap->m_stack_end != nullptr) - stack_end = heap->m_stack_end; + if (heap->m_stack_top != nullptr) + stack_top = heap->m_stack_top; else - stack_end = (uintptr_t *) stack_start - 80; // dummy value + stack_top = (uintptr_t *) stack_bottom + 80; // dummy value - cout << "Stack end in collect:\t " << stack_end << endl; + cout << "Stack end in collect:\t " << stack_top << endl; auto work_list = heap->m_allocated_chunks; if (flags & MARK) { - mark(stack_start, stack_end, work_list); + mark(stack_bottom, stack_top, work_list); } if (flags & SWEEP) { @@ -366,7 +384,7 @@ namespace GC { cout << "Marked: " << chunk->marked << "\nStart adr: " << chunk->start << "\nSize: " << chunk->size << " B\n" << endl; } - void Heap::print_worklist(std::vector list) { + void Heap::print_worklist(std::list list) { for (auto cp : list) { cout << "Chunk at:\t" << cp->start << "\nSize:\t\t" << cp->size << endl; } diff --git a/src/GC/tests/advance.cpp b/src/GC/tests/advance.cpp new file mode 100644 index 0000000..0a8a177 --- /dev/null +++ b/src/GC/tests/advance.cpp @@ -0,0 +1,34 @@ +#include +#include +#include + +using namespace std; + +int main() { + list l; + char c = 'a'; + for (int i = 1; i <= 5; i++) { + l.push_back(c++); + } + + auto iter = l.begin(); + auto stop = l.end(); + + while (iter != stop) { + cout << *iter << " "; + + iter++; + } + cout << endl; + iter = l.begin(); + while (*iter != *stop) { + cout << *iter << " "; + iter++; + } + cout << endl; + + cout << "rebased" << endl; + // cout << "iter: " << *iter << "\nstop: " << *stop << endl; + + return 0; +} \ No newline at end of file diff --git a/src/GC/todo.md b/src/GC/todo.md index dba3eee..f9492da 100644 --- a/src/GC/todo.md +++ b/src/GC/todo.md @@ -1,18 +1,15 @@ # Garbage collection - - ## Project Goal for next week (24/2): -- Debug +- Write more complex tests ## GC TODO: - Merge to main branch -- Switch std::vector to std::list -- Make alloc and init static, move the() to private -- stack_end, stack_start -> stack_top, stack_bottom - Double check m_heap_size functionality and when a collection is triggered +- Kolla vektor vs list complexity ## Tests TODO +- Write complex datastructures for tests with larger programs diff --git a/src/Interpreter.hs b/src/Interpreter.hs index bdbd8d2..37d46a7 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -1,45 +1,78 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedRecordDot #-} +{-# LANGUAGE OverloadedStrings #-} module Interpreter where +import Auxiliary (maybeToRightM) import Control.Applicative (Applicative) import Control.Monad.Except (Except, MonadError (throwError), liftEither) +import Control.Monad.State (MonadState, StateT, evalStateT) import Data.Either.Combinators (maybeToRight) import Data.Map (Map) import qualified Data.Map as Map +import Data.Maybe (maybe) import Grammar.Abs +import Grammar.ErrM (Err) import Grammar.Print (printTree) -interpret :: Program -> Except String Integer -interpret (Program e) = - eval mempty e >>= \case - VClosure {} -> throwError "main evaluated to a function" - VInt i -> pure i +interpret :: Program -> Err Integer +interpret (Program scs) = do + main <- findMain scs + eval (initCxt scs) main >>= + \case + VClosure {} -> throwError "main evaluated to a function" + VInt i -> pure i +initCxt :: [Bind] -> Cxt +initCxt scs = + Cxt { env = mempty + , sig = foldr insert mempty $ map expandLambdas scs + } + where insert (Bind name _ rhs) = Map.insert name rhs + +expandLambdas :: Bind -> Bind +expandLambdas (Bind name parms rhs) = Bind name [] $ foldr EAbs rhs parms + +findMain :: [Bind] -> Err Exp +findMain [] = throwError "No main!" +findMain (sc:scs) = case sc of + Bind "main" _ rhs -> pure rhs + _ -> findMain scs + data Val = VInt Integer - | VClosure Cxt Ident Exp + | VClosure Env Ident Exp + deriving (Show, Eq) -type Cxt = Map Ident Val +type Env = Map Ident Val +type Sig = Map Ident Exp -eval :: Cxt -> Exp -> Except String Val +data Cxt = Cxt + { env :: Map Ident Val + , sig :: Map Ident Exp + } deriving (Show, Eq) + +eval :: Cxt -> Exp -> Err Val eval cxt = \case - -- ------------ x ∈ γ -- γ ⊢ x ⇓ γ(x) - EId x -> - maybeToRightM - ("Unbound variable:" ++ printTree x) - $ Map.lookup x cxt + EId x -> do + case Map.lookup x cxt.env of + Just e -> pure e + Nothing -> + case Map.lookup x cxt.sig of + Just e -> eval (emptyEnv cxt) e + Nothing -> throwError ("Unbound variable: " ++ printTree x) -- --------- -- γ ⊢ i ⇓ i EInt i -> pure $ VInt i - -- γ ⊢ e ⇓ let δ in λx → f + -- γ ⊢ e ⇓ let δ in λx. f -- γ ⊢ e₁ ⇓ v -- δ,x=v ⊢ f ⇓ v₁ -- ------------------------------ @@ -50,13 +83,15 @@ eval cxt = \case VInt _ -> throwError "Not a function" VClosure delta x f -> do v <- eval cxt e1 - eval (Map.insert x v delta) f + let cxt' = putEnv (Map.insert x v delta) cxt + eval cxt' f + -- -- ----------------------------- - -- γ ⊢ λx → f ⇓ let γ in λx → f + -- γ ⊢ λx. f ⇓ let γ in λx. f - EAbs x e -> pure $ VClosure cxt x e + EAbs par e -> pure $ VClosure cxt.env par e -- γ ⊢ e ⇓ v @@ -71,8 +106,11 @@ eval cxt = \case (VInt i, VInt i1) -> pure $ VInt (i + i1) _ -> throwError "Can't add a function" + ELet _ _ -> throwError "ELet pattern match should never occur!" -maybeToRightM :: MonadError l m => l -> Maybe r -> m r -maybeToRightM err = liftEither . maybeToRight err +emptyEnv :: Cxt -> Cxt +emptyEnv cxt = cxt { env = mempty } +putEnv :: Env -> Cxt -> Cxt +putEnv env cxt = cxt { env = env } diff --git a/src/LambdaLifter.hs b/src/LambdaLifter.hs new file mode 100644 index 0000000..015e7f3 --- /dev/null +++ b/src/LambdaLifter.hs @@ -0,0 +1,190 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} + + +module LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where + +import Auxiliary (snoc) +import Control.Applicative (Applicative (liftA2)) +import Control.Monad.State (MonadState (get, put), State, evalState) +import Data.Set (Set) +import qualified Data.Set as Set +import Prelude hiding (exp) +import Renamer +import TypeCheckerIr + + +-- | Lift lambdas and let expression into supercombinators. +-- Three phases: +-- @freeVars@ annotatss all the free variables. +-- @abstract@ converts lambdas into let expressions. +-- @collectScs@ moves every non-constant let expression to a top-level function. +lambdaLift :: Program -> Program +lambdaLift = collectScs . abstract . freeVars + + +-- | Annotate free variables +freeVars :: Program -> AnnProgram +freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e) + | Bind n xs e <- ds + ] + +freeVarsExp :: Set Id -> Exp -> AnnExp +freeVarsExp localVars = \case + EId n | Set.member n localVars -> (Set.singleton n, AId n) + | otherwise -> (mempty, AId n) + + EInt i -> (mempty, AInt i) + + EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2') + where + e1' = freeVarsExp localVars e1 + e2' = freeVarsExp localVars e2 + + EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2') + where + e1' = freeVarsExp localVars e1 + e2' = freeVarsExp localVars e2 + + EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e') + where + e' = freeVarsExp (Set.insert par localVars) e + + -- Sum free variables present in bind and the expression + ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e') + where + binders_frees = Set.delete name $ freeVarsOf rhs' + e_free = Set.delete name $ freeVarsOf e' + + rhs' = freeVarsExp e_localVars rhs + new_bind = ABind name parms rhs' + + e' = freeVarsExp e_localVars e + e_localVars = Set.insert name localVars + + +freeVarsOf :: AnnExp -> Set Id +freeVarsOf = fst + +-- AST annotated with free variables +type AnnProgram = [(Id, [Id], AnnExp)] + +type AnnExp = (Set Id, AnnExp') + +data ABind = ABind Id [Id] AnnExp deriving Show + +data AnnExp' = AId Id + | AInt Integer + | ALet ABind AnnExp + | AApp Type AnnExp AnnExp + | AAdd Type AnnExp AnnExp + | AAbs Type Id AnnExp + deriving Show +-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@. +-- Free variables are @v₁ v₂ .. vₙ@ are bound. +abstract :: AnnProgram -> Program +abstract prog = Program $ evalState (mapM go prog) 0 + where + go :: (Id, [Id], AnnExp) -> State Int Bind + go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs' + where + (rhs', parms1) = flattenLambdasAnn rhs + + +-- | Flatten nested lambdas and collect the parameters +-- @\x.\y.\z. ae → (ae, [x,y,z])@ +flattenLambdasAnn :: AnnExp -> (AnnExp, [Id]) +flattenLambdasAnn ae = go (ae, []) + where + go :: (AnnExp, [Id]) -> (AnnExp, [Id]) + go ((free, e), acc) = + case e of + AAbs _ par (free1, e1) -> + go ((Set.delete par free1, e1), snoc par acc) + _ -> ((free, e), acc) + +abstractExp :: AnnExp -> State Int Exp +abstractExp (free, exp) = case exp of + AId n -> pure $ EId n + AInt i -> pure $ EInt i + AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2) + AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2) + ALet b e -> liftA2 ELet (go b) (abstractExp e) + where + go (ABind name parms rhs) = do + (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs + pure $ Bind name (parms ++ parms1) rhs' + + skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp + skipLambdas f (free, ae) = case ae of + AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1 + _ -> f (free, ae) + + -- Lift lambda into let and bind free variables + AAbs t parm e -> do + i <- nextNumber + rhs <- abstractExp e + + let sc_name = Ident ("sc_" ++ show i) + sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t) + + pure $ foldl (EApp TInt) sc $ map EId freeList + where + freeList = Set.toList free + parms = snoc parm freeList + + +nextNumber :: State Int Int +nextNumber = do + i <- get + put $ succ i + pure i + +-- | Collects supercombinators by lifting non-constant let expressions +collectScs :: Program -> Program +collectScs (Program scs) = Program $ concatMap collectFromRhs scs + where + collectFromRhs (Bind name parms rhs) = + let (rhs_scs, rhs') = collectScsExp rhs + in Bind name parms rhs' : rhs_scs + + +collectScsExp :: Exp -> ([Bind], Exp) +collectScsExp = \case + EId n -> ([], EId n) + EInt i -> ([], EInt i) + + EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2') + where + (scs1, e1') = collectScsExp e1 + (scs2, e2') = collectScsExp e2 + + EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2') + where + (scs1, e1') = collectScsExp e1 + (scs2, e2') = collectScsExp e2 + + EAbs t par e -> (scs, EAbs t par e') + where + (scs, e') = collectScsExp e + + -- Collect supercombinators from bind, the rhss, and the expression. + -- + -- > f = let sc x y = rhs in e + -- + ELet (Bind name parms rhs) e -> if null parms + then ( rhs_scs ++ e_scs, ELet bind e') + else (bind : rhs_scs ++ e_scs, e') + where + bind = Bind name parms rhs' + (rhs_scs, rhs') = collectScsExp rhs + (e_scs, e') = collectScsExp e + + +-- @\x.\y.\z. e → (e, [x,y,z])@ +flattenLambdas :: Exp -> (Exp, [Id]) +flattenLambdas = go . (, []) + where + go (e, acc) = case e of + EAbs _ par e1 -> go (e1, snoc par acc) + _ -> (e, acc) diff --git a/src/LlvmIr.hs b/src/LlvmIr.hs new file mode 100644 index 0000000..d340ddc --- /dev/null +++ b/src/LlvmIr.hs @@ -0,0 +1,204 @@ +{-# LANGUAGE LambdaCase #-} + +module LlvmIr ( + LLVMType (..), + LLVMIr (..), + llvmIrToString, + LLVMValue (..), + LLVMComp (..), + Visibility (..), +) where + +import Data.List (intercalate) +import TypeCheckerIr + +-- | A datatype which represents some basic LLVM types +data LLVMType + = I1 + | I8 + | I32 + | I64 + | Ptr + | Ref LLVMType + | Function LLVMType [LLVMType] + | Array Int LLVMType + | CustomType Ident + +instance Show LLVMType where + show :: LLVMType -> String + show = \case + I1 -> "i1" + I8 -> "i8" + I32 -> "i32" + I64 -> "i64" + Ptr -> "ptr" + Ref ty -> show ty <> "*" + Function t xs -> show t <> " (" <> intercalate ", " (map show xs) <> ")*" + Array n ty -> concat ["[", show n, " x ", show ty, "]"] + CustomType (Ident ty) -> ty + +data LLVMComp + = LLEq + | LLNe + | LLUgt + | LLUge + | LLUlt + | LLUle + | LLSgt + | LLSge + | LLSlt + | LLSle +instance Show LLVMComp where + show :: LLVMComp -> String + show = \case + LLEq -> "eq" + LLNe -> "ne" + LLUgt -> "ugt" + LLUge -> "uge" + LLUlt -> "ult" + LLUle -> "ule" + LLSgt -> "sgt" + LLSge -> "sge" + LLSlt -> "slt" + LLSle -> "sle" + +data Visibility = Local | Global +instance Show Visibility where + show :: Visibility -> String + show Local = "%" + show Global = "@" + +-- | Represents a LLVM "value", as in an integer, a register variable, +-- or a string contstant +data LLVMValue + = VInteger Integer + | VIdent Ident LLVMType + | VConstant String + | VFunction Ident Visibility LLVMType + +instance Show LLVMValue where + show :: LLVMValue -> String + show v = case v of + VInteger i -> show i + VIdent (Ident n) _ -> "%" <> n + VFunction (Ident n) vis _ -> show vis <> n + VConstant s -> "c" <> show s + +type Params = [(Ident, LLVMType)] +type Args = [(LLVMType, LLVMValue)] + +-- | A datatype which represents different instructions in LLVM +data LLVMIr + = Define LLVMType Ident Params + | DefineEnd + | Declare LLVMType Ident Params + | SetVariable Ident LLVMIr + | Variable Ident + | Add LLVMType LLVMValue LLVMValue + | Sub LLVMType LLVMValue LLVMValue + | Div LLVMType LLVMValue LLVMValue + | Mul LLVMType LLVMValue LLVMValue + | Srem LLVMType LLVMValue LLVMValue + | Icmp LLVMComp LLVMType LLVMValue LLVMValue + | Br Ident + | BrCond LLVMValue Ident Ident + | Label Ident + | Call LLVMType Visibility Ident Args + | Alloca LLVMType + | Store LLVMType Ident LLVMType Ident + | Bitcast LLVMType Ident LLVMType + | Ret LLVMType LLVMValue + | Comment String + | UnsafeRaw String -- This should generally be avoided, and proper + -- instructions should be used in its place + deriving (Show) + +-- | Converts a list of LLVMIr instructions to a string +llvmIrToString :: [LLVMIr] -> String +llvmIrToString = go 0 + where + go :: Int -> [LLVMIr] -> String + go _ [] = mempty + go i (x : xs) = do + let (i', n) = case x of + Define{} -> (i + 1, 0) + DefineEnd -> (i - 1, 0) + _ -> (i, i) + insToString n x <> go i' xs + +-- | Converts a LLVM inststruction to a String, allowing for printing etc. +-- The integer represents the indentation +insToString :: Int -> LLVMIr -> String +insToString i l = + replicate i '\t' <> case l of + (Define t (Ident i) params) -> + concat + [ "define ", show t, " @", i + , "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params) + , ") {\n" + ] + DefineEnd -> "}\n" + (Declare _t (Ident _i) _params) -> undefined + (SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir] + (Add t v1 v2) -> + concat + [ "add ", show t, " ", show v1 + , ", ", show v2, "\n" + ] + (Sub t v1 v2) -> + concat + [ "sub ", show t, " ", show v1, ", " + , show v2, "\n" + ] + (Div t v1 v2) -> + concat + [ "sdiv ", show t, " ", show v1, ", " + , show v2, "\n" + ] + (Mul t v1 v2) -> + concat + [ "mul ", show t, " ", show v1 + , ", ", show v2, "\n" + ] + (Srem t v1 v2) -> + concat + [ "srem ", show t, " ", show v1, ", " + , show v2, "\n" + ] + (Call t vis (Ident i) arg) -> + concat + [ "call ", show t, " ", show vis, i, "(" + , intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg + , ")\n" + ] + (Alloca t) -> unwords ["alloca", show t, "\n"] + (Store t1 (Ident id1) t2 (Ident id2)) -> + concat + [ "store ", show t1, " %", id1 + , ", ", show t2 , " %", id2, "\n" + ] + (Bitcast t1 (Ident i) t2) -> + concat + [ "bitcast ", show t1, " %" + , i, " to ", show t2, "\n" + ] + (Icmp comp t v1 v2) -> + concat + [ "icmp ", show comp, " ", show t + , " ", show v1, ", ", show v2, "\n" + ] + (Ret t v) -> + concat + [ "ret ", show t, " " + , show v, "\n" + ] + (UnsafeRaw s) -> s + (Label (Ident s)) -> "\nlabel_" <> s <> ":\n" + (Br (Ident s)) -> "br label %label_" <> s <> "\n" + (BrCond val (Ident s1) (Ident s2)) -> + concat + [ "br i1 ", show val, ", ", "label %" + , "label_", s1, ", ", "label %", "label_", s2, "\n" + ] + (Comment s) -> "; " <> s <> "\n" + (Variable (Ident id)) -> "%" <> id diff --git a/src/Main.hs b/src/Main.hs index ed753f2..1831428 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -1,30 +1,97 @@ {-# LANGUAGE LambdaCase #-} + module Main where -import Control.Monad.Except (runExcept) -import Grammar.Par (myLexer, pProgram) -import Interpreter (interpret) -import System.Environment (getArgs) -import System.Exit (exitFailure, exitSuccess) +import Compiler (compile) +import GHC.IO.Handle.Text (hPutStrLn) +import Grammar.ErrM (Err) +import Grammar.Par (myLexer, pProgram) +import Grammar.Print (printTree) + +-- import Interpreter (interpret) +import LambdaLifter (lambdaLift) +import Renamer (rename) +import System.Environment (getArgs) +import System.Exit (exitFailure, exitSuccess) +import System.IO (stderr) +import TypeChecker (typecheck) main :: IO () -main = getArgs >>= \case - [] -> print "Required file path missing" - (x:_) -> do - file <- readFile x - case pProgram (myLexer file) of - Left err -> do - putStrLn "SYNTAX ERROR" - putStrLn err - exitFailure - Right prg -> case runExcept $ interpret prg of - Left err -> do - putStrLn "INTERPRETER ERROR" - putStrLn err - exitFailure - Right i -> do - print i - exitSuccess +main = + getArgs >>= \case + [] -> print "Required file path missing" + (s : _) -> main' s +main' :: String -> IO () +main' s = do + file <- readFile s + printToErr "-- Parse Tree -- " + parsed <- fromSyntaxErr . pProgram $ myLexer file + printToErr $ printTree parsed + printToErr "\n-- Renamer --" + let renamed = rename parsed + printToErr $ printTree renamed + + printToErr "\n-- TypeChecker --" + typechecked <- fromTypeCheckerErr $ typecheck renamed + printToErr $ printTree typechecked + + printToErr "\n-- Lambda Lifter --" + let lifted = lambdaLift typechecked + printToErr $ printTree lifted + + printToErr "\n -- Printing compiler output to stdout --" + compiled <- fromCompilerErr $ compile lifted + putStrLn compiled + writeFile "llvm.ll" compiled + + -- interpred <- fromInterpreterErr $ interpret lifted + -- putStrLn "\n-- interpret" + -- print interpred + + exitSuccess + +printToErr :: String -> IO () +printToErr = hPutStrLn stderr + +fromCompilerErr :: Err a -> IO a +fromCompilerErr = + either + ( \err -> do + putStrLn "\nCOMPILER ERROR" + putStrLn err + exitFailure + ) + pure + +fromSyntaxErr :: Err a -> IO a +fromSyntaxErr = + either + ( \err -> do + putStrLn "\nSYNTAX ERROR" + putStrLn err + exitFailure + ) + pure + +fromTypeCheckerErr :: Err a -> IO a +fromTypeCheckerErr = + either + ( \err -> do + putStrLn "\nTYPECHECKER ERROR" + putStrLn err + exitFailure + ) + pure + +fromInterpreterErr :: Err a -> IO a +fromInterpreterErr = + either + ( \err -> do + putStrLn "\nINTERPRETER ERROR" + putStrLn err + exitFailure + ) + pure diff --git a/src/Renamer.hs b/src/Renamer.hs new file mode 100644 index 0000000..b284e92 --- /dev/null +++ b/src/Renamer.hs @@ -0,0 +1,84 @@ +{-# LANGUAGE LambdaCase #-} + +module Renamer (module Renamer) where + +import Auxiliary (mapAccumM) +import Control.Monad.State (MonadState, State, evalState, gets, + modify) +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Maybe (fromMaybe) +import Data.Tuple.Extra (dupe) +import Grammar.Abs + + +-- | Rename all variables and local binds +rename :: Program -> Program +rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0 + where + initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs + renameSc :: Names -> Bind -> Rn Bind + renameSc old_names (Bind name t _ parms rhs) = do + (new_names, parms') <- newNames old_names parms + rhs' <- snd <$> renameExp new_names rhs + pure $ Bind name t name parms' rhs' + + +-- | Rename monad. State holds the number of renamed names. +newtype Rn a = Rn { runRn :: State Int a } + deriving (Functor, Applicative, Monad, MonadState Int) + +-- | Maps old to new name +type Names = Map Ident Ident + +renameLocalBind :: Names -> Bind -> Rn (Names, Bind) +renameLocalBind old_names (Bind name t _ parms rhs) = do + (new_names, name') <- newName old_names name + (new_names', parms') <- newNames new_names parms + (new_names'', rhs') <- renameExp new_names' rhs + pure (new_names'', Bind name' t name' parms' rhs') + +renameExp :: Names -> Exp -> Rn (Names, Exp) +renameExp old_names = \case + EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names) + + EInt i1 -> pure (old_names, EInt i1) + + EApp e1 e2 -> do + (env1, e1') <- renameExp old_names e1 + (env2, e2') <- renameExp old_names e2 + pure (Map.union env1 env2, EApp e1' e2') + + EAdd e1 e2 -> do + (env1, e1') <- renameExp old_names e1 + (env2, e2') <- renameExp old_names e2 + pure (Map.union env1 env2, EAdd e1' e2') + + ELet b e -> do + (new_names, b) <- renameLocalBind old_names b + (new_names', e') <- renameExp new_names e + pure (new_names', ELet b e') + + EAbs par t e -> do + (new_names, par') <- newName old_names par + (new_names', e') <- renameExp new_names e + pure (new_names', EAbs par' t e') + + EAnn e t -> do + (new_names, e') <- renameExp old_names e + pure (new_names, EAnn e' t) + +-- | Create a new name and add it to name environment. +newName :: Names -> Ident -> Rn (Names, Ident) +newName env old_name = do + new_name <- makeName old_name + pure (Map.insert old_name new_name env, new_name) + +-- | Create multiple names and add them to the name environment +newNames :: Names -> [Ident] -> Rn (Names, [Ident]) +newNames = mapAccumM newName + +-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@. +makeName :: Ident -> Rn Ident +makeName (Ident prefix) = gets (\i -> Ident $ prefix ++ "_" ++ show i) <* modify succ + diff --git a/src/TypeChecker.hs b/src/TypeChecker.hs index e69de29..1e44888 100644 --- a/src/TypeChecker.hs +++ b/src/TypeChecker.hs @@ -0,0 +1,178 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedRecordDot #-} + +module TypeChecker (typecheck, partitionType) where + +import Auxiliary (maybeToRightM, snoc) +import Control.Monad.Except (throwError, unless) +import Data.Map (Map) +import qualified Data.Map as Map +import Grammar.Abs +import Grammar.ErrM (Err) +import Grammar.Print (Print (prt), concatD, doc, printTree, + render) +import Prelude hiding (exp, id) +import qualified TypeCheckerIr as T + +-- NOTE: this type checker is poorly tested + +-- TODO +-- Coercion +-- Type inference + +data Cxt = Cxt + { env :: Map Ident Type -- ^ Local scope signature + , sig :: Map Ident Type -- ^ Top-level signatures + } + +initCxt :: [Bind] -> Cxt +initCxt sc = Cxt { env = mempty + , sig = Map.fromList $ map (\(Bind n t _ _ _) -> (n, t)) sc + } + +typecheck :: Program -> Err T.Program +typecheck (Program sc) = T.Program <$> mapM (checkBind $ initCxt sc) sc + +-- | Check if infered rhs type matches type signature. +checkBind :: Cxt -> Bind -> Err T.Bind +checkBind cxt b = + case expandLambdas b of + Bind name t _ parms rhs -> do + (rhs', t_rhs) <- infer cxt rhs + unless (typeEq t_rhs t) . throwError $ typeErr name t t_rhs + pure $ T.Bind (name, t) (zip parms ts_parms) rhs' + where + ts_parms = fst $ partitionType (length parms) t + +-- | @ f x y = rhs ⇒ f = \x.\y. rhs @ +expandLambdas :: Bind -> Bind +expandLambdas (Bind name t _ parms rhs) = Bind name t name [] rhs' + where + rhs' = foldr ($) rhs $ zipWith EAbs parms ts_parms + ts_parms = fst $ partitionType (length parms) t + +-- | Infer type of expression. +infer :: Cxt -> Exp -> Err (T.Exp, Type) +infer cxt = \case + EId x -> + case lookupEnv x cxt of + Nothing -> + case lookupSig x cxt of + Nothing -> throwError ("Unbound variable:" ++ printTree x) + Just t -> pure (T.EId (x, t), t) + Just t -> pure (T.EId (x, t), t) + + EInt i -> pure (T.EInt i, T.TInt) + + EApp e e1 -> do + (e', t) <- infer cxt e + case t of + TFun t1 t2 -> do + e1' <- check cxt e1 t1 + pure (T.EApp t2 e' e1', t2) + _ -> do + throwError ("Not a function: " ++ show e) + + EAdd e e1 -> do + e' <- check cxt e T.TInt + e1' <- check cxt e1 T.TInt + pure (T.EAdd T.TInt e' e1', T.TInt) + + EAbs x t e -> do + (e', t1) <- infer (insertEnv x t cxt) e + let t_abs = TFun t t1 + pure (T.EAbs t_abs (x, t) e', t_abs) + + ELet b e -> do + let cxt' = insertBind b cxt + b' <- checkBind cxt' b + (e', t) <- infer cxt' e + pure (T.ELet b' e', t) + + EAnn e t -> do + (e', t1) <- infer cxt e + unless (typeEq t t1) $ + throwError "Inferred type and type annotation doesn't match" + pure (e', t1) + +-- | Check infered type matches the supplied type. +check :: Cxt -> Exp -> Type -> Err T.Exp +check cxt exp typ = case exp of + + EId x -> do + t <- case lookupEnv x cxt of + Nothing -> maybeToRightM + ("Unbound variable:" ++ printTree x) + (lookupSig x cxt) + Just t -> pure t + unless (typeEq t typ) . throwError $ typeErr x typ t + pure $ T.EId (x, t) + + EInt i -> do + unless (typeEq typ TInt) $ throwError $ typeErr i TInt typ + pure $ T.EInt i + + EApp e e1 -> do + (e', t) <- infer cxt e + case t of + TFun t1 t2 -> do + e1' <- check cxt e1 t1 + pure $ T.EApp t2 e' e1' + _ -> throwError ("Not a function 2: " ++ printTree e) + + EAdd e e1 -> do + e' <- check cxt e T.TInt + e1' <- check cxt e1 T.TInt + pure $ T.EAdd T.TInt e' e1' + + EAbs x t e -> do + (e', t_e) <- infer (insertEnv x t cxt) e + let t1 = TFun t t_e + unless (typeEq t1 typ) $ throwError "Wrong lamda type!" + pure $ T.EAbs t1 (x, t) e' + + ELet b e -> do + let cxt' = insertBind b cxt + b' <- checkBind cxt' b + e' <- check cxt' e typ + pure $ T.ELet b' e' + + EAnn e t -> do + unless (typeEq t typ) $ + throwError "Inferred type and type annotation doesn't match" + check cxt e t + +-- | Check if types are equivalent. Doesn't handle coercion or polymorphism. +typeEq :: Type -> Type -> Bool +typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1 +typeEq t t1 = t == t1 + +-- | Partion type into types of parameters and return type. +partitionType :: Int -- Number of parameters to apply + -> Type + -> ([Type], Type) +partitionType = go [] + where + go acc 0 t = (acc, t) + go acc i t = case t of + TFun t1 t2 -> go (snoc t1 acc) (i - 1) t2 + _ -> error "Number of parameters and type doesn't match" + +insertBind :: Bind -> Cxt -> Cxt +insertBind (Bind n t _ _ _) = insertEnv n t + +lookupEnv :: Ident -> Cxt -> Maybe Type +lookupEnv x = Map.lookup x . env + +insertEnv :: Ident -> Type -> Cxt -> Cxt +insertEnv x t cxt = cxt { env = Map.insert x t cxt.env } + +lookupSig :: Ident -> Cxt -> Maybe Type +lookupSig x = Map.lookup x . sig + +typeErr :: Print a => a -> Type -> Type -> String +typeErr p expected actual = render $ concatD + [ doc $ showString "Wrong type:", prt 0 p , doc $ showString "\n" + , doc $ showString "Expected:" , prt 0 expected, doc $ showString "\n" + , doc $ showString "Actual: " , prt 0 actual + ] diff --git a/src/TypeCheckerIr.hs b/src/TypeCheckerIr.hs new file mode 100644 index 0000000..f6e3ec6 --- /dev/null +++ b/src/TypeCheckerIr.hs @@ -0,0 +1,100 @@ +{-# LANGUAGE LambdaCase #-} + +module TypeCheckerIr + ( module Grammar.Abs + , module TypeCheckerIr + ) where + +import Grammar.Abs (Ident (..), Type (..)) +import Grammar.Print +import Prelude +import qualified Prelude as C (Eq, Ord, Read, Show) + +newtype Program = Program [Bind] + deriving (C.Eq, C.Ord, C.Show, C.Read) + +data Exp + = EId Id + | EInt Integer + | ELet Bind Exp + | EApp Type Exp Exp + | EAdd Type Exp Exp + | EAbs Type Id Exp + deriving (C.Eq, C.Ord, C.Show, C.Read) + +type Id = (Ident, Type) + +data Bind = Bind Id [Id] Exp + deriving (C.Eq, C.Ord, C.Show, C.Read) + +instance Print Program where + prt i (Program sc) = prPrec i 0 $ prt 0 sc + +instance Print Bind where + prt i (Bind name@(n, _) parms rhs) = prPrec i 0 $ concatD + [ prtId 0 name + , doc $ showString ";" + , prt 0 n + , prtIdPs 0 parms + , doc $ showString "=" + , prt 0 rhs + ] + +instance Print [Bind] where + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] + +prtIdPs :: Int -> [Id] -> Doc +prtIdPs i = prPrec i 0 . concatD . map (prtIdP i) + +prtId :: Int -> Id -> Doc +prtId i (name, t) = prPrec i 0 $ concatD + [ prt 0 name + , doc $ showString ":" + , prt 0 t + ] + +prtIdP :: Int -> Id -> Doc +prtIdP i (name, t) = prPrec i 0 $ concatD + [ doc $ showString "(" + , prt 0 name + , doc $ showString ":" + , prt 0 t + , doc $ showString ")" + ] + + +instance Print Exp where + prt i = \case + EId n -> prPrec i 3 $ concatD [prtIdP 0 n] + EInt i1 -> prPrec i 3 $ concatD [prt 0 i1] + ELet bs e -> prPrec i 3 $ concatD + [ doc $ showString "let" + , prt 0 bs + , doc $ showString "in" + , prt 0 e + ] + EApp t e1 e2 -> prPrec i 2 $ concatD + [ doc $ showString "@" + , prt 0 t + , prt 2 e1 + , prt 3 e2 + ] + EAdd t e1 e2 -> prPrec i 1 $ concatD + [ doc $ showString "@" + , prt 0 t + , prt 1 e1 + , doc $ showString "+" + , prt 2 e2 + ] + EAbs t n e -> prPrec i 0 $ concatD + [ doc $ showString "@" + , prt 0 t + , doc $ showString "\\" + , prtIdP 0 n + , doc $ showString "." + , prt 0 e + ] + + diff --git a/test_program b/test_program deleted file mode 100644 index 83f3e9a..0000000 --- a/test_program +++ /dev/null @@ -1,5 +0,0 @@ - - - - -main = (\x -> x + x + 3) ((\x -> x) 2)