Merge branch 'g-collection' of https://github.com/bachelor-group-66-systemf/language into g-collection

This commit is contained in:
valtermiari 2023-02-24 14:25:08 +01:00
commit 87f5d7fe74
22 changed files with 1406 additions and 149 deletions

3
.gitignore vendored
View file

@ -3,6 +3,9 @@ dist-newstyle
*.x
*.bak
src/Grammar
language
llvm.ll
/language
.vscode/

View file

@ -1,15 +1,25 @@
Program. Program ::= [Bind];
EId. Exp3 ::= Ident;
EInt. Exp3 ::= Integer;
EAnn. Exp3 ::= "(" Exp ":" Type ")";
ELet. Exp3 ::= "let" Bind "in" Exp;
EApp. Exp2 ::= Exp2 Exp3;
EAdd. Exp1 ::= Exp1 "+" Exp2;
EAbs. Exp ::= "\\" Ident ":" Type "." Exp;
Program. Program ::= "main" "=" Exp ;
Bind. Bind ::= Ident ":" Type ";"
Ident [Ident] "=" Exp;
EId. Exp3 ::= Ident ;
EInt. Exp3 ::= Integer ;
EApp. Exp2 ::= Exp2 Exp3 ;
EAdd. Exp1 ::= Exp1 "+" Exp2 ;
EAbs. Exp ::= "\\" Ident "->" Exp ;
separator Bind ";";
separator Ident "";
coercions Exp 3 ;
coercions Exp 3;
comment "--" ;
comment "{-" "-}" ;
TInt. Type1 ::= "Int" ;
TPol. Type1 ::= Ident ;
TFun. Type ::= Type1 "->" Type ;
coercions Type 1 ;
comment "--";
comment "{-" "-}";

View file

@ -1,7 +1,7 @@
.PHONY : sdist clean
language : src/Grammar/Test
cabal install --installdir=.
cabal install --installdir=. --overwrite-policy=always
src/Grammar/Test.hs src/Grammar/Lex.x src/Grammar/Par.y : Grammar.cf
bnfc -o src -d $<
@ -22,4 +22,16 @@ clean :
rm -r src/Grammar
rm language
test :
./language ./sample-programs/basic-1
./language ./sample-programs/basic-2
./language ./sample-programs/basic-3
./language ./sample-programs/basic-4
./language ./sample-programs/basic-5
./language ./sample-programs/basic-5
./language ./sample-programs/basic-6
./language ./sample-programs/basic-7
./language ./sample-programs/basic-8
./language ./sample-programs/basic-9
# EOF

2
cabal.project.local Normal file
View file

@ -0,0 +1,2 @@
ignore-project: False
tests: True

View file

@ -1,4 +1,4 @@
cabal-version: 3.0
cabal-version: 3.4
name: language
@ -12,26 +12,34 @@ build-type: Simple
extra-doc-files: CHANGELOG.md
extra-source-files:
Grammar.cf
common warnings
ghc-options: -Wall
ghc-options: -W
executable language
import: warnings
main-is: Main.hs
other-modules:
Grammar.Abs
Grammar.Lex
Grammar.Par
Grammar.Print
Grammar.Skel
Interpreter
Grammar.ErrM
LambdaLifter
Auxiliary
Renamer
TypeChecker
TypeCheckerIr
-- Interpreter
Compiler
LlvmIr
hs-source-dirs: src
build-depends:
@ -40,5 +48,5 @@ executable language
, containers
, either
, array
, extra
default-language: GHC2021

21
sample-programs/basic-1 Normal file
View file

@ -0,0 +1,21 @@
-- tripplemagic : Int -> Int -> Int -> Int;
-- tripplemagic x y z = ((\x:Int. x+x) x) + y + z;
-- main : Int;
-- main = tripplemagic ((\x:Int. x+x+3) ((\x:Int. x) 2)) 5 3
-- answer: 22
-- apply : (Int -> Int) -> Int -> Int;
-- apply f x = f x;
-- main : Int;
-- main = apply (\x : Int . x + 5) 5
-- answer: 10
apply : (Int -> Int -> Int) -> Int -> Int -> Int;
apply f x y = f x y;
krimp: Int -> Int -> Int;
krimp x y = x + y;
main : Int;
main = apply (krimp) 2 3;
-- answer: 5

View file

@ -1,12 +1,12 @@
let
pkgs = import <nixpkgs> { }; # pin the channel to ensure reproducibility!
pkgs = import (fetchTarball "https://github.com/NixOS/nixpkgs/archive/747927516efcb5e31ba03b7ff32f61f6d47e7d87.zip") { }; # pin the channel to ensure reproducibility!
in
pkgs.haskellPackages.developPackage {
root = ./.;
withHoogle = true;
modifier = drv:
pkgs.haskell.lib.addBuildTools drv (
(with pkgs; [ hlint haskell-language-server ghc jasmin ])
(with pkgs; [ hlint haskell-language-server ghc jasmin llvmPackages_15.libllvm])
++
(with pkgs.haskellPackages; [
cabal-install

21
src/Auxiliary.hs Normal file
View file

@ -0,0 +1,21 @@
{-# LANGUAGE LambdaCase #-}
module Auxiliary (module Auxiliary) where
import Control.Monad.Error.Class (liftEither)
import Control.Monad.Except (MonadError)
import Data.Either.Combinators (maybeToRight)
snoc :: a -> [a] -> [a]
snoc x xs = xs ++ [x]
maybeToRightM :: MonadError l m => l -> Maybe r -> m r
maybeToRightM err = liftEither . maybeToRight err
mapAccumM :: Monad m => (s -> a -> m (s, b)) -> s -> [a] -> m (s, [b])
mapAccumM f = go
where
go acc = \case
[] -> pure (acc, [])
x:xs -> do
(acc', x') <- f acc x
(acc'', xs') <- go acc' xs
pure (acc'', x':xs')

View file

@ -0,0 +1,266 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Compiler (compile) where
import Auxiliary (snoc)
import Control.Monad.State (StateT, execStateT, gets, modify)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Tuple.Extra (dupe, first, second)
import Grammar.ErrM (Err)
import LlvmIr (LLVMIr (..), LLVMType (..),
LLVMValue (..), Visibility (..),
llvmIrToString)
import TypeChecker (partitionType)
import TypeCheckerIr
-- | The record used as the code generator state
data CodeGenerator = CodeGenerator
{ instructions :: [LLVMIr]
, functions :: Map Id FunctionInfo
, variableCount :: Integer
}
-- | A state type synonym
type CompilerState a = StateT CodeGenerator Err a
data FunctionInfo = FunctionInfo
{ numArgs :: Int
, arguments :: [Id]
}
-- | Adds a instruction to the CodeGenerator state
emit :: LLVMIr -> CompilerState ()
emit l = modify $ \t -> t { instructions = snoc l $ instructions t }
-- | Increases the variable counter in the CodeGenerator state
increaseVarCount :: CompilerState ()
increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 }
-- | Returns the variable count from the CodeGenerator state
getVarCount :: CompilerState Integer
getVarCount = gets variableCount
-- | Increases the variable count and returns it from the CodeGenerator state
getNewVar :: CompilerState Integer
getNewVar = increaseVarCount >> getVarCount
-- | Produces a map of functions infos from a list of binds,
-- which contains useful data for code generation.
getFunctions :: [Bind] -> Map Id FunctionInfo
getFunctions bs = Map.fromList $ map go bs
where
go (Bind id args _) =
(id, FunctionInfo { numArgs=length args, arguments=args })
initCodeGenerator :: [Bind] -> CodeGenerator
initCodeGenerator scs = CodeGenerator { instructions = defaultStart
, functions = getFunctions scs
, variableCount = 0
}
-- | Compiles an AST and produces a LLVM Ir string.
-- An easy way to actually "compile" this output is to
-- Simply pipe it to lli
compile :: Program -> Err String
compile (Program scs) = do
let codegen = initCodeGenerator scs
llvmIrToString . instructions <$> execStateT (compileScs scs) codegen
compileScs :: [Bind] -> CompilerState ()
compileScs [] = pure ()
compileScs (Bind (name, t) args exp : xs) = do
emit $ UnsafeRaw "\n"
emit . Comment $ show name <> ": " <> show exp
let args' = map (second type2LlvmType) args
emit $ Define (type2LlvmType t_return) name args'
functionBody <- exprToValue exp
if name == "main"
then mapM_ emit $ mainContent functionBody
else emit $ Ret I64 functionBody
emit DefineEnd
modify $ \s -> s { variableCount = 0 }
compileScs xs
where
t_return = snd $ partitionType (length args) t
mainContent :: LLVMValue -> [LLVMIr]
mainContent var =
[ UnsafeRaw $
"call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n"
, -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
-- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2")
-- , Label (Ident "b_1")
-- , UnsafeRaw
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n"
-- , Br (Ident "end")
-- , Label (Ident "b_2")
-- , UnsafeRaw
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n"
-- , Br (Ident "end")
-- , Label (Ident "end")
Ret I64 (VInteger 0)
]
defaultStart :: [LLVMIr]
defaultStart = [ UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
]
compileExp :: Exp -> CompilerState ()
compileExp = \case
EInt i -> emitInt i
EAdd t e1 e2 -> emitAdd t e1 e2
EId (name, _) -> emitIdent name
EApp t e1 e2 -> emitApp t e1 e2
EAbs t ti e -> emitAbs t ti e
ELet bind e -> emitLet bind e
--- aux functions ---
emitAbs :: Type -> Id -> Exp -> CompilerState ()
emitAbs _t tid e = emit . Comment $ "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e
emitLet :: Bind -> Exp -> CompilerState ()
emitLet b e = emit . Comment $ concat [ "ELet ("
, show b
, " = "
, show e
, ") is not implemented!"
]
emitApp :: Type -> Exp -> Exp -> CompilerState ()
emitApp t e1 e2 = appEmitter t e1 e2 []
where
appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState ()
appEmitter t e1 e2 stack = do
let newStack = e2 : stack
case e1 of
EApp _ e1' e2' -> appEmitter t e1' e2' newStack
EId id@(name, _) -> do
args <- traverse exprToValue newStack
vs <- getNewVar
funcs <- gets functions
let visibility = maybe Local (const Global) $ Map.lookup id funcs
args' = map (first valueGetType . dupe) args
call = Call (type2LlvmType t) visibility name args'
emit $ SetVariable (Ident $ show vs) call
x -> do
emit . Comment $ "The unspeakable happened: "
emit . Comment $ show x
emitIdent :: Ident -> CompilerState ()
emitIdent id = do
-- !!this should never happen!!
emit $ Comment "This should not have happened!"
emit $ Variable id
emit $ UnsafeRaw "\n"
emitInt :: Integer -> CompilerState ()
emitInt i = do
-- !!this should never happen!!
varCount <- getNewVar
emit $ Comment "This should not have happened!"
emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0))
emitAdd :: Type -> Exp -> Exp -> CompilerState ()
emitAdd t e1 e2 = do
v1 <- exprToValue e1
v2 <- exprToValue e2
v <- getNewVar
emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2)
-- emitMul :: Exp -> Exp -> CompilerState ()
-- emitMul e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Mul I64 v1 v2
-- emitMod :: Exp -> Exp -> CompilerState ()
-- emitMod e1 e2 = do
-- -- `let m a b = rem (abs $ b + a) b`
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- vadd <- gets variableCount
-- emit $ SetVariable $ Ident $ show vadd
-- emit $ Add I64 v1 v2
--
-- increaseVarCount
-- vabs <- gets variableCount
-- emit $ SetVariable $ Ident $ show vabs
-- emit $ Call I64 (Ident "llvm.abs.i64")
-- [ (I64, VIdent (Ident $ show vadd))
-- , (I1, VInteger 1)
-- ]
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2
-- emitDiv :: Exp -> Exp -> CompilerState ()
-- emitDiv e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Div I64 v1 v2
-- emitSub :: Exp -> Exp -> CompilerState ()
-- emitSub e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Sub I64 v1 v2
exprToValue :: Exp -> CompilerState LLVMValue
exprToValue = \case
EInt i -> pure $ VInteger i
EId id@(name, t) -> do
funcs <- gets functions
case Map.lookup id funcs of
Just fi -> do
if numArgs fi == 0
then do
vc <- getNewVar
emit $ SetVariable (Ident $ show vc)
(Call (type2LlvmType t) Global name [])
pure $ VIdent (Ident $ show vc) (type2LlvmType t)
else pure $ VFunction name Global (type2LlvmType t)
Nothing -> pure $ VIdent name (type2LlvmType t)
e -> do
compileExp e
v <- getVarCount
pure $ VIdent (Ident $ show v) (getType e)
type2LlvmType :: Type -> LLVMType
type2LlvmType = \case
TInt -> I64
TFun t xs -> do
let (t', xs') = function2LLVMType xs [type2LlvmType t]
Function t' xs'
t -> CustomType $ Ident ("\"" ++ show t ++ "\"")
where
function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s)
function2LLVMType x s = (type2LlvmType x, s)
getType :: Exp -> LLVMType
getType (EInt _) = I64
getType (EAdd t _ _) = type2LlvmType t
getType (EId (_, t)) = type2LlvmType t
getType (EApp t _ _) = type2LlvmType t
getType (EAbs t _ _) = type2LlvmType t
getType (ELet _ e) = getType e
valueGetType :: LLVMValue -> LLVMType
valueGetType (VInteger _) = I64
valueGetType (VIdent _ t) = t
valueGetType (VConstant s) = Array (length s) I8
valueGetType (VFunction _ _ t) = t

View file

@ -9,8 +9,8 @@ STDFLAGS = -std=gnu++20 -stdlib=libc++
WFLAGS = -Wall -Wextra
DBGFLAGS = -g
test_test:
echo "$(shell pwd)"
advance:
$(CC) $(WFLAGS) $(STDFLAGS) tests/advance.cpp -o tests/advance.out
heap:
$(CC) $(WFLAGS) $(STDFLAGS) $(LIB_INCL) lib/heap.cpp

View file

@ -2,9 +2,9 @@
#include <assert.h>
#include <iostream>
#include <list>
#include <setjmp.h>
#include <stdlib.h>
#include <vector>
#include "chunk.hpp"
@ -30,26 +30,10 @@ namespace GC {
m_allocated_size = 0;
}
void collect(Heap *heap);
void sweep(Heap *heap);
uintptr_t *try_recycle_chunks(Heap *heap, size_t size);
void free(Heap *heap);
void free_overlap(Heap *heap);
void mark(uintptr_t *start, const uintptr_t *end, std::vector<Chunk *> worklist);
void print_line(Chunk *chunk);
void print_worklist(std::vector<Chunk *> list);
inline static Heap *m_instance = nullptr;
const char *m_heap;
size_t m_size;
size_t m_allocated_size;
uintptr_t *m_stack_end = nullptr;
// maybe change to std::list
std::vector<Chunk *> m_allocated_chunks;
std::vector<Chunk *> m_freed_chunks;
public:
// BEWARE only for testing, this should be adressed
~Heap() {
std::free((char *)m_heap);
}
static inline Heap *the() { // TODO: make private
if (m_instance) // if m_instance is not a nullptr
@ -58,11 +42,35 @@ namespace GC {
return m_instance;
}
// BEWARE only for testing, this should be adressed
~Heap() {
std::free((char *)m_heap);
static inline Chunk *getAt(std::list<Chunk *> list, size_t n) {
auto iter = list.begin();
if (!n)
return *iter;
std::advance(iter, n);
return *iter;
}
void collect();
void sweep(Heap *heap);
uintptr_t *try_recycle_chunks(size_t size);
void free(Heap* heap);
void free_overlap(Heap *heap);
void mark(uintptr_t *start, const uintptr_t *end, std::list<Chunk *> worklist);
void print_line(Chunk *chunk);
void print_worklist(std::list<Chunk *> list);
inline static Heap *m_instance = nullptr;
const char *m_heap;
size_t m_size;
size_t m_allocated_size;
uintptr_t *m_stack_top = nullptr;
// maybe change to std::list
std::list<Chunk *> m_allocated_chunks;
std::list<Chunk *> m_freed_chunks;
public:
/**
* These are the only two functions which are exposed
* as the API for LLVM. At the absolute start of the
@ -70,8 +78,9 @@ namespace GC {
* that the address of the topmost stack frame is
* saved as the limit for scanning the stack in collect.
*/
void *alloc(size_t size); // TODO: make static
void init(); // TODO: make static
static void init(); // TODO: make static
static void dispose(); // -||-
static void *alloc(size_t size); // -||-
// DEBUG ONLY
void collect(uint flags); // conditional collection

View file

@ -14,13 +14,21 @@ namespace GC {
/**
* Initialises the heap singleton and saves the address
* of the calling stack frame as the stack_end. Presumeably
* of the calling stack frame as the stack_top. Presumeably
* this address points to the stack frame of the compiled
* LLVM executable after linking.
*/
void Heap::init() {
Heap *heap = Heap::the();
heap->m_stack_end = reinterpret_cast<uintptr_t *>(__builtin_frame_address(1));
heap->m_stack_top = reinterpret_cast<uintptr_t *>(__builtin_frame_address(1));
}
/**
* Disposes the heap at program exit.
*/
void Heap::dispose() {
Heap *heap = Heap::the();
delete heap;
}
/**
@ -43,13 +51,13 @@ namespace GC {
}
if (heap->m_size + size > HEAP_SIZE) {
collect(heap);
heap->collect();
// If collect failed, crash with OOM error
assert(heap->m_size + size <= HEAP_SIZE && "Heap: Out Of Memory");
}
// If a chunk was recycled, return the old chunk address
uintptr_t *reused_chunk = try_recycle_chunks(heap, size);
uintptr_t *reused_chunk = heap->try_recycle_chunks(size);
if (reused_chunk != nullptr) {
return (void *)reused_chunk;
}
@ -58,7 +66,7 @@ namespace GC {
// then create a new chunk
auto new_chunk = new Chunk;
new_chunk->size = size;
new_chunk->start = (uintptr_t *)(heap->m_heap + m_size);
new_chunk->start = (uintptr_t *)(heap->m_heap + heap->m_size);
heap->m_size += size;
@ -75,8 +83,6 @@ namespace GC {
* objects slightly which saves time from malloc'ing
* memory from the OS.
*
* @param heap Pointer to the singleton Heap instance
*
* @param size Amount of bytes needed for the object
* which is about to be allocated.
*
@ -86,10 +92,14 @@ namespace GC {
* nullptr is returned to signify no
* chunks were found.
*/
uintptr_t *Heap::try_recycle_chunks(Heap *heap, size_t size) {
uintptr_t *Heap::try_recycle_chunks(size_t size) {
auto heap = Heap::the();
// Check if there are any freed chunks large enough for current request
for (size_t i = 0; i < heap->m_freed_chunks.size(); i++) {
auto cp = heap->m_freed_chunks.at(i);
// auto cp = heap->m_freed_chunks.at(i);
auto cp = getAt(heap->m_freed_chunks, i);
auto iter = heap->m_freed_chunks.begin();
advance(iter, i);
if (cp->size > size)
{
// Split the chunk, use one part and add the remaining part to
@ -100,7 +110,7 @@ namespace GC {
chunk_complement->size = diff;
chunk_complement->start = cp->start + cp->size;
heap->m_freed_chunks.erase(m_freed_chunks.begin() + i);
heap->m_freed_chunks.erase(iter);
heap->m_freed_chunks.push_back(chunk_complement);
heap->m_allocated_chunks.push_back(cp);
@ -109,7 +119,7 @@ namespace GC {
else if (cp->size == size)
{
// Reuse the whole chunk
heap->m_freed_chunks.erase(m_freed_chunks.begin() + i);
heap->m_freed_chunks.erase(iter);
heap->m_allocated_chunks.push_back(cp);
return cp->start;
}
@ -123,24 +133,23 @@ namespace GC {
* left on the heap, a collection is triggered. This
* function is private so that the user cannot trigger
* a collection unneccessarily.
*
* @param heap Heap singleton instance, only for avoiding
* redundant calls to the singleton get
*/
void Heap::collect(Heap *heap) {
void Heap::collect() {
// Get instance
auto heap = Heap::the();
// get current stack
auto stack_start = reinterpret_cast<uintptr_t *>(__builtin_frame_address(0));
auto stack_bottom = reinterpret_cast<uintptr_t *>(__builtin_frame_address(0));
// fix this block, it's nästy
uintptr_t *stack_end;
if (heap->m_stack_end != nullptr)
stack_end = heap->m_stack_end;
uintptr_t *stack_top;
if (heap->m_stack_top != nullptr)
stack_top = heap->m_stack_top;
else
stack_end = (uintptr_t *)0; // temporary
stack_top = (uintptr_t *)0; // temporary
auto work_list = heap->m_allocated_chunks;
mark(stack_start, stack_end, work_list);
mark(stack_bottom, stack_top, work_list);
sweep(heap);
@ -148,24 +157,26 @@ namespace GC {
}
/**
* Iterates through the stack, if an element on the stack points to a chunk
* that chunk is marked (i.e. reachable). It only marks element which are directly
* reachable from the chunk, so no chain of pointers from the stack are detected.
* Iterates through the stack, if an element on the stack points to a chunk,
* called a root chunk, that chunk is marked (i.e. reachable).
* Then it recursively follows all chunks which are possibly reachable from
* the root chunk and mark those chunks.
* If a chunk is marked it is removed from the worklist, since it's no longer of
* concern for this method.
*
* @param start Pointer to the start of the stack frame.
* @param end Pointer to the end of the stack frame.
* @param worklist The currently allocated chunks.
* @param worklist The currently allocated chunks, which haven't been marked.
*/
void Heap::mark(uintptr_t *start, const uintptr_t *end, vector<Chunk*> worklist) {
void Heap::mark(uintptr_t *start, const uintptr_t *end, list<Chunk*> worklist) {
int counter = 0;
// To find adresses thats in the worklist
for (; start < end; start++) {
counter++;
// all pointers must be aligned as double words
for (auto it = worklist.begin(); it != worklist.end();) {
auto it = worklist.begin();
auto stop = worklist.end();
// for (auto it = worklist.begin(); it != worklist.end();) {
while (it != stop) {
Chunk *chunk = *it;
auto c_start = reinterpret_cast<uintptr_t>(chunk->start);
@ -181,7 +192,9 @@ namespace GC {
if (!chunk->marked) {
chunk->marked = true;
// Remove the marked chunk from the worklist
it = worklist.erase(it);
// Recursively call mark, to see if the reachable chunk further points to another chunk
mark((uintptr_t*) c_start, (uintptr_t*) c_end, worklist);
}
else {
@ -198,24 +211,27 @@ namespace GC {
/**
* Sweeps the heap, unmarks the marked chunks for the next cycle,
* adds the unmarked nodes to the vector of freed chunks; to be freed.
* adds the unmarked nodes to the list of freed chunks; to be freed.
*
* @param heap Pointer to the heap to oporate on.
* @param heap Pointer to the heap singleton instance.
*/
void Heap::sweep(Heap *heap) {
for (auto it = heap->m_allocated_chunks.begin(); it != heap->m_allocated_chunks.end();) {
Chunk *chunk = *it;
auto iter = heap->m_allocated_chunks.begin();
auto stop = heap->m_allocated_chunks.end();
// for (auto it = heap->m_allocated_chunks.begin(); it != heap->m_allocated_chunks.end();) {
while (iter != stop) {
Chunk *chunk = *iter;
// Unmark the marked chunks for the next iteration.
if (chunk->marked) {
chunk->marked = false;
++it;
++iter;
}
else {
// Add the unmarked chunks to freed chunks and remove from
// the list of allocated chunks
heap->m_freed_chunks.push_back(chunk);
it = heap->m_allocated_chunks.erase(it);
iter = heap->m_allocated_chunks.erase(iter);
}
}
}
@ -257,13 +273,15 @@ namespace GC {
* larger chunks.
*/
void Heap::free_overlap(Heap *heap) {
std::vector<Chunk *> filtered;
std::list<Chunk *> filtered;
size_t i = 0;
filtered.push_back(heap->m_freed_chunks.at(i++));
// filtered.push_back(heap->m_freed_chunks.at(i++));
filtered.push_back(getAt(heap->m_freed_chunks, i++));
cout << filtered.back()->start << endl;
for (; i < heap->m_freed_chunks.size(); i++) {
auto prev = filtered.back();
auto next = heap->m_freed_chunks.at(i);
// auto next = heap->m_freed_chunks.at(i);
auto next = getAt(heap->m_freed_chunks, i);
auto p_start = (uintptr_t)(prev->start);
auto p_size = (uintptr_t)(prev->size);
auto n_start = (uintptr_t)(next->start);
@ -283,9 +301,9 @@ namespace GC {
void Heap::check_init() {
auto heap = Heap::the();
cout << "Heap addr:\t" << heap << endl;
cout << "GC m_stack_end:\t" << heap->m_stack_end << endl;
auto stack_start = reinterpret_cast<uintptr_t *>(__builtin_frame_address(0));
cout << "GC stack_start:\t" << stack_start << endl;
cout << "GC m_stack_top:\t" << heap->m_stack_top << endl;
auto stack_bottom = reinterpret_cast<uintptr_t *>(__builtin_frame_address(0));
cout << "GC stack_bottom:\t" << stack_bottom << endl;
}
/**
@ -307,20 +325,20 @@ namespace GC {
auto heap = Heap::the();
// get the frame adress, whwere local variables and saved registers are located
auto stack_start = reinterpret_cast<uintptr_t *>(__builtin_frame_address(0));
cout << "Stack start in collect:\t" << stack_start << endl;
uintptr_t *stack_end;
auto stack_bottom = reinterpret_cast<uintptr_t *>(__builtin_frame_address(0));
cout << "Stack bottom in collect:\t" << stack_bottom << endl;
uintptr_t *stack_top;
if (heap->m_stack_end != nullptr)
stack_end = heap->m_stack_end;
if (heap->m_stack_top != nullptr)
stack_top = heap->m_stack_top;
else
stack_end = (uintptr_t *) stack_start - 80; // dummy value
stack_top = (uintptr_t *) stack_bottom + 80; // dummy value
cout << "Stack end in collect:\t " << stack_end << endl;
cout << "Stack end in collect:\t " << stack_top << endl;
auto work_list = heap->m_allocated_chunks;
if (flags & MARK) {
mark(stack_start, stack_end, work_list);
mark(stack_bottom, stack_top, work_list);
}
if (flags & SWEEP) {
@ -366,7 +384,7 @@ namespace GC {
cout << "Marked: " << chunk->marked << "\nStart adr: " << chunk->start << "\nSize: " << chunk->size << " B\n" << endl;
}
void Heap::print_worklist(std::vector<Chunk *> list) {
void Heap::print_worklist(std::list<Chunk *> list) {
for (auto cp : list) {
cout << "Chunk at:\t" << cp->start << "\nSize:\t\t" << cp->size << endl;
}

34
src/GC/tests/advance.cpp Normal file
View file

@ -0,0 +1,34 @@
#include <iostream>
#include <list>
#include <stdlib.h>
using namespace std;
int main() {
list<char> l;
char c = 'a';
for (int i = 1; i <= 5; i++) {
l.push_back(c++);
}
auto iter = l.begin();
auto stop = l.end();
while (iter != stop) {
cout << *iter << " ";
iter++;
}
cout << endl;
iter = l.begin();
while (*iter != *stop) {
cout << *iter << " ";
iter++;
}
cout << endl;
cout << "rebased" << endl;
// cout << "iter: " << *iter << "\nstop: " << *stop << endl;
return 0;
}

View file

@ -1,18 +1,15 @@
# Garbage collection
## Project
Goal for next week (24/2):
- Debug
- Write more complex tests
## GC TODO:
- Merge to main branch
- Switch std::vector to std::list
- Make alloc and init static, move the() to private
- stack_end, stack_start -> stack_top, stack_bottom
- Double check m_heap_size functionality and when a collection is triggered
- Kolla vektor vs list complexity
## Tests TODO
- Write complex datastructures for tests with larger programs

View file

@ -1,45 +1,78 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
module Interpreter where
import Auxiliary (maybeToRightM)
import Control.Applicative (Applicative)
import Control.Monad.Except (Except, MonadError (throwError),
liftEither)
import Control.Monad.State (MonadState, StateT, evalStateT)
import Data.Either.Combinators (maybeToRight)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (maybe)
import Grammar.Abs
import Grammar.ErrM (Err)
import Grammar.Print (printTree)
interpret :: Program -> Except String Integer
interpret (Program e) =
eval mempty e >>= \case
VClosure {} -> throwError "main evaluated to a function"
VInt i -> pure i
interpret :: Program -> Err Integer
interpret (Program scs) = do
main <- findMain scs
eval (initCxt scs) main >>=
\case
VClosure {} -> throwError "main evaluated to a function"
VInt i -> pure i
initCxt :: [Bind] -> Cxt
initCxt scs =
Cxt { env = mempty
, sig = foldr insert mempty $ map expandLambdas scs
}
where insert (Bind name _ rhs) = Map.insert name rhs
expandLambdas :: Bind -> Bind
expandLambdas (Bind name parms rhs) = Bind name [] $ foldr EAbs rhs parms
findMain :: [Bind] -> Err Exp
findMain [] = throwError "No main!"
findMain (sc:scs) = case sc of
Bind "main" _ rhs -> pure rhs
_ -> findMain scs
data Val = VInt Integer
| VClosure Cxt Ident Exp
| VClosure Env Ident Exp
deriving (Show, Eq)
type Cxt = Map Ident Val
type Env = Map Ident Val
type Sig = Map Ident Exp
eval :: Cxt -> Exp -> Except String Val
data Cxt = Cxt
{ env :: Map Ident Val
, sig :: Map Ident Exp
} deriving (Show, Eq)
eval :: Cxt -> Exp -> Err Val
eval cxt = \case
-- ------------ x ∈ γ
-- γ ⊢ x ⇓ γ(x)
EId x ->
maybeToRightM
("Unbound variable:" ++ printTree x)
$ Map.lookup x cxt
EId x -> do
case Map.lookup x cxt.env of
Just e -> pure e
Nothing ->
case Map.lookup x cxt.sig of
Just e -> eval (emptyEnv cxt) e
Nothing -> throwError ("Unbound variable: " ++ printTree x)
-- ---------
-- γ ⊢ i ⇓ i
EInt i -> pure $ VInt i
-- γ ⊢ e ⇓ let δ in λx f
-- γ ⊢ e ⇓ let δ in λx. f
-- γ ⊢ e₁ ⇓ v
-- δ,x=v ⊢ f ⇓ v₁
-- ------------------------------
@ -50,13 +83,15 @@ eval cxt = \case
VInt _ -> throwError "Not a function"
VClosure delta x f -> do
v <- eval cxt e1
eval (Map.insert x v delta) f
let cxt' = putEnv (Map.insert x v delta) cxt
eval cxt' f
--
-- -----------------------------
-- γ ⊢ λx → f ⇓ let γ in λx → f
-- γ ⊢ λx. f ⇓ let γ in λx. f
EAbs x e -> pure $ VClosure cxt x e
EAbs par e -> pure $ VClosure cxt.env par e
-- γ ⊢ e ⇓ v
@ -71,8 +106,11 @@ eval cxt = \case
(VInt i, VInt i1) -> pure $ VInt (i + i1)
_ -> throwError "Can't add a function"
ELet _ _ -> throwError "ELet pattern match should never occur!"
maybeToRightM :: MonadError l m => l -> Maybe r -> m r
maybeToRightM err = liftEither . maybeToRight err
emptyEnv :: Cxt -> Cxt
emptyEnv cxt = cxt { env = mempty }
putEnv :: Env -> Cxt -> Cxt
putEnv env cxt = cxt { env = env }

190
src/LambdaLifter.hs Normal file
View file

@ -0,0 +1,190 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where
import Auxiliary (snoc)
import Control.Applicative (Applicative (liftA2))
import Control.Monad.State (MonadState (get, put), State, evalState)
import Data.Set (Set)
import qualified Data.Set as Set
import Prelude hiding (exp)
import Renamer
import TypeCheckerIr
-- | Lift lambdas and let expression into supercombinators.
-- Three phases:
-- @freeVars@ annotatss all the free variables.
-- @abstract@ converts lambdas into let expressions.
-- @collectScs@ moves every non-constant let expression to a top-level function.
lambdaLift :: Program -> Program
lambdaLift = collectScs . abstract . freeVars
-- | Annotate free variables
freeVars :: Program -> AnnProgram
freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e)
| Bind n xs e <- ds
]
freeVarsExp :: Set Id -> Exp -> AnnExp
freeVarsExp localVars = \case
EId n | Set.member n localVars -> (Set.singleton n, AId n)
| otherwise -> (mempty, AId n)
EInt i -> (mempty, AInt i)
EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2')
where
e1' = freeVarsExp localVars e1
e2' = freeVarsExp localVars e2
EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2')
where
e1' = freeVarsExp localVars e1
e2' = freeVarsExp localVars e2
EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e')
where
e' = freeVarsExp (Set.insert par localVars) e
-- Sum free variables present in bind and the expression
ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e')
where
binders_frees = Set.delete name $ freeVarsOf rhs'
e_free = Set.delete name $ freeVarsOf e'
rhs' = freeVarsExp e_localVars rhs
new_bind = ABind name parms rhs'
e' = freeVarsExp e_localVars e
e_localVars = Set.insert name localVars
freeVarsOf :: AnnExp -> Set Id
freeVarsOf = fst
-- AST annotated with free variables
type AnnProgram = [(Id, [Id], AnnExp)]
type AnnExp = (Set Id, AnnExp')
data ABind = ABind Id [Id] AnnExp deriving Show
data AnnExp' = AId Id
| AInt Integer
| ALet ABind AnnExp
| AApp Type AnnExp AnnExp
| AAdd Type AnnExp AnnExp
| AAbs Type Id AnnExp
deriving Show
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
-- Free variables are @v₁ v₂ .. vₙ@ are bound.
abstract :: AnnProgram -> Program
abstract prog = Program $ evalState (mapM go prog) 0
where
go :: (Id, [Id], AnnExp) -> State Int Bind
go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs'
where
(rhs', parms1) = flattenLambdasAnn rhs
-- | Flatten nested lambdas and collect the parameters
-- @\x.\y.\z. ae → (ae, [x,y,z])@
flattenLambdasAnn :: AnnExp -> (AnnExp, [Id])
flattenLambdasAnn ae = go (ae, [])
where
go :: (AnnExp, [Id]) -> (AnnExp, [Id])
go ((free, e), acc) =
case e of
AAbs _ par (free1, e1) ->
go ((Set.delete par free1, e1), snoc par acc)
_ -> ((free, e), acc)
abstractExp :: AnnExp -> State Int Exp
abstractExp (free, exp) = case exp of
AId n -> pure $ EId n
AInt i -> pure $ EInt i
AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2)
AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2)
ALet b e -> liftA2 ELet (go b) (abstractExp e)
where
go (ABind name parms rhs) = do
(rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
pure $ Bind name (parms ++ parms1) rhs'
skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp
skipLambdas f (free, ae) = case ae of
AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1
_ -> f (free, ae)
-- Lift lambda into let and bind free variables
AAbs t parm e -> do
i <- nextNumber
rhs <- abstractExp e
let sc_name = Ident ("sc_" ++ show i)
sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t)
pure $ foldl (EApp TInt) sc $ map EId freeList
where
freeList = Set.toList free
parms = snoc parm freeList
nextNumber :: State Int Int
nextNumber = do
i <- get
put $ succ i
pure i
-- | Collects supercombinators by lifting non-constant let expressions
collectScs :: Program -> Program
collectScs (Program scs) = Program $ concatMap collectFromRhs scs
where
collectFromRhs (Bind name parms rhs) =
let (rhs_scs, rhs') = collectScsExp rhs
in Bind name parms rhs' : rhs_scs
collectScsExp :: Exp -> ([Bind], Exp)
collectScsExp = \case
EId n -> ([], EId n)
EInt i -> ([], EInt i)
EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2')
where
(scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2
EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2')
where
(scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2
EAbs t par e -> (scs, EAbs t par e')
where
(scs, e') = collectScsExp e
-- Collect supercombinators from bind, the rhss, and the expression.
--
-- > f = let sc x y = rhs in e
--
ELet (Bind name parms rhs) e -> if null parms
then ( rhs_scs ++ e_scs, ELet bind e')
else (bind : rhs_scs ++ e_scs, e')
where
bind = Bind name parms rhs'
(rhs_scs, rhs') = collectScsExp rhs
(e_scs, e') = collectScsExp e
-- @\x.\y.\z. e → (e, [x,y,z])@
flattenLambdas :: Exp -> (Exp, [Id])
flattenLambdas = go . (, [])
where
go (e, acc) = case e of
EAbs _ par e1 -> go (e1, snoc par acc)
_ -> (e, acc)

204
src/LlvmIr.hs Normal file
View file

@ -0,0 +1,204 @@
{-# LANGUAGE LambdaCase #-}
module LlvmIr (
LLVMType (..),
LLVMIr (..),
llvmIrToString,
LLVMValue (..),
LLVMComp (..),
Visibility (..),
) where
import Data.List (intercalate)
import TypeCheckerIr
-- | A datatype which represents some basic LLVM types
data LLVMType
= I1
| I8
| I32
| I64
| Ptr
| Ref LLVMType
| Function LLVMType [LLVMType]
| Array Int LLVMType
| CustomType Ident
instance Show LLVMType where
show :: LLVMType -> String
show = \case
I1 -> "i1"
I8 -> "i8"
I32 -> "i32"
I64 -> "i64"
Ptr -> "ptr"
Ref ty -> show ty <> "*"
Function t xs -> show t <> " (" <> intercalate ", " (map show xs) <> ")*"
Array n ty -> concat ["[", show n, " x ", show ty, "]"]
CustomType (Ident ty) -> ty
data LLVMComp
= LLEq
| LLNe
| LLUgt
| LLUge
| LLUlt
| LLUle
| LLSgt
| LLSge
| LLSlt
| LLSle
instance Show LLVMComp where
show :: LLVMComp -> String
show = \case
LLEq -> "eq"
LLNe -> "ne"
LLUgt -> "ugt"
LLUge -> "uge"
LLUlt -> "ult"
LLUle -> "ule"
LLSgt -> "sgt"
LLSge -> "sge"
LLSlt -> "slt"
LLSle -> "sle"
data Visibility = Local | Global
instance Show Visibility where
show :: Visibility -> String
show Local = "%"
show Global = "@"
-- | Represents a LLVM "value", as in an integer, a register variable,
-- or a string contstant
data LLVMValue
= VInteger Integer
| VIdent Ident LLVMType
| VConstant String
| VFunction Ident Visibility LLVMType
instance Show LLVMValue where
show :: LLVMValue -> String
show v = case v of
VInteger i -> show i
VIdent (Ident n) _ -> "%" <> n
VFunction (Ident n) vis _ -> show vis <> n
VConstant s -> "c" <> show s
type Params = [(Ident, LLVMType)]
type Args = [(LLVMType, LLVMValue)]
-- | A datatype which represents different instructions in LLVM
data LLVMIr
= Define LLVMType Ident Params
| DefineEnd
| Declare LLVMType Ident Params
| SetVariable Ident LLVMIr
| Variable Ident
| Add LLVMType LLVMValue LLVMValue
| Sub LLVMType LLVMValue LLVMValue
| Div LLVMType LLVMValue LLVMValue
| Mul LLVMType LLVMValue LLVMValue
| Srem LLVMType LLVMValue LLVMValue
| Icmp LLVMComp LLVMType LLVMValue LLVMValue
| Br Ident
| BrCond LLVMValue Ident Ident
| Label Ident
| Call LLVMType Visibility Ident Args
| Alloca LLVMType
| Store LLVMType Ident LLVMType Ident
| Bitcast LLVMType Ident LLVMType
| Ret LLVMType LLVMValue
| Comment String
| UnsafeRaw String -- This should generally be avoided, and proper
-- instructions should be used in its place
deriving (Show)
-- | Converts a list of LLVMIr instructions to a string
llvmIrToString :: [LLVMIr] -> String
llvmIrToString = go 0
where
go :: Int -> [LLVMIr] -> String
go _ [] = mempty
go i (x : xs) = do
let (i', n) = case x of
Define{} -> (i + 1, 0)
DefineEnd -> (i - 1, 0)
_ -> (i, i)
insToString n x <> go i' xs
-- | Converts a LLVM inststruction to a String, allowing for printing etc.
-- The integer represents the indentation
insToString :: Int -> LLVMIr -> String
insToString i l =
replicate i '\t' <> case l of
(Define t (Ident i) params) ->
concat
[ "define ", show t, " @", i
, "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params)
, ") {\n"
]
DefineEnd -> "}\n"
(Declare _t (Ident _i) _params) -> undefined
(SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir]
(Add t v1 v2) ->
concat
[ "add ", show t, " ", show v1
, ", ", show v2, "\n"
]
(Sub t v1 v2) ->
concat
[ "sub ", show t, " ", show v1, ", "
, show v2, "\n"
]
(Div t v1 v2) ->
concat
[ "sdiv ", show t, " ", show v1, ", "
, show v2, "\n"
]
(Mul t v1 v2) ->
concat
[ "mul ", show t, " ", show v1
, ", ", show v2, "\n"
]
(Srem t v1 v2) ->
concat
[ "srem ", show t, " ", show v1, ", "
, show v2, "\n"
]
(Call t vis (Ident i) arg) ->
concat
[ "call ", show t, " ", show vis, i, "("
, intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg
, ")\n"
]
(Alloca t) -> unwords ["alloca", show t, "\n"]
(Store t1 (Ident id1) t2 (Ident id2)) ->
concat
[ "store ", show t1, " %", id1
, ", ", show t2 , " %", id2, "\n"
]
(Bitcast t1 (Ident i) t2) ->
concat
[ "bitcast ", show t1, " %"
, i, " to ", show t2, "\n"
]
(Icmp comp t v1 v2) ->
concat
[ "icmp ", show comp, " ", show t
, " ", show v1, ", ", show v2, "\n"
]
(Ret t v) ->
concat
[ "ret ", show t, " "
, show v, "\n"
]
(UnsafeRaw s) -> s
(Label (Ident s)) -> "\nlabel_" <> s <> ":\n"
(Br (Ident s)) -> "br label %label_" <> s <> "\n"
(BrCond val (Ident s1) (Ident s2)) ->
concat
[ "br i1 ", show val, ", ", "label %"
, "label_", s1, ", ", "label %", "label_", s2, "\n"
]
(Comment s) -> "; " <> s <> "\n"
(Variable (Ident id)) -> "%" <> id

View file

@ -1,30 +1,97 @@
{-# LANGUAGE LambdaCase #-}
module Main where
import Control.Monad.Except (runExcept)
import Grammar.Par (myLexer, pProgram)
import Interpreter (interpret)
import System.Environment (getArgs)
import System.Exit (exitFailure, exitSuccess)
import Compiler (compile)
import GHC.IO.Handle.Text (hPutStrLn)
import Grammar.ErrM (Err)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
-- import Interpreter (interpret)
import LambdaLifter (lambdaLift)
import Renamer (rename)
import System.Environment (getArgs)
import System.Exit (exitFailure, exitSuccess)
import System.IO (stderr)
import TypeChecker (typecheck)
main :: IO ()
main = getArgs >>= \case
[] -> print "Required file path missing"
(x:_) -> do
file <- readFile x
case pProgram (myLexer file) of
Left err -> do
putStrLn "SYNTAX ERROR"
putStrLn err
exitFailure
Right prg -> case runExcept $ interpret prg of
Left err -> do
putStrLn "INTERPRETER ERROR"
putStrLn err
exitFailure
Right i -> do
print i
exitSuccess
main =
getArgs >>= \case
[] -> print "Required file path missing"
(s : _) -> main' s
main' :: String -> IO ()
main' s = do
file <- readFile s
printToErr "-- Parse Tree -- "
parsed <- fromSyntaxErr . pProgram $ myLexer file
printToErr $ printTree parsed
printToErr "\n-- Renamer --"
let renamed = rename parsed
printToErr $ printTree renamed
printToErr "\n-- TypeChecker --"
typechecked <- fromTypeCheckerErr $ typecheck renamed
printToErr $ printTree typechecked
printToErr "\n-- Lambda Lifter --"
let lifted = lambdaLift typechecked
printToErr $ printTree lifted
printToErr "\n -- Printing compiler output to stdout --"
compiled <- fromCompilerErr $ compile lifted
putStrLn compiled
writeFile "llvm.ll" compiled
-- interpred <- fromInterpreterErr $ interpret lifted
-- putStrLn "\n-- interpret"
-- print interpred
exitSuccess
printToErr :: String -> IO ()
printToErr = hPutStrLn stderr
fromCompilerErr :: Err a -> IO a
fromCompilerErr =
either
( \err -> do
putStrLn "\nCOMPILER ERROR"
putStrLn err
exitFailure
)
pure
fromSyntaxErr :: Err a -> IO a
fromSyntaxErr =
either
( \err -> do
putStrLn "\nSYNTAX ERROR"
putStrLn err
exitFailure
)
pure
fromTypeCheckerErr :: Err a -> IO a
fromTypeCheckerErr =
either
( \err -> do
putStrLn "\nTYPECHECKER ERROR"
putStrLn err
exitFailure
)
pure
fromInterpreterErr :: Err a -> IO a
fromInterpreterErr =
either
( \err -> do
putStrLn "\nINTERPRETER ERROR"
putStrLn err
exitFailure
)
pure

84
src/Renamer.hs Normal file
View file

@ -0,0 +1,84 @@
{-# LANGUAGE LambdaCase #-}
module Renamer (module Renamer) where
import Auxiliary (mapAccumM)
import Control.Monad.State (MonadState, State, evalState, gets,
modify)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe)
import Grammar.Abs
-- | Rename all variables and local binds
rename :: Program -> Program
rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0
where
initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs
renameSc :: Names -> Bind -> Rn Bind
renameSc old_names (Bind name t _ parms rhs) = do
(new_names, parms') <- newNames old_names parms
rhs' <- snd <$> renameExp new_names rhs
pure $ Bind name t name parms' rhs'
-- | Rename monad. State holds the number of renamed names.
newtype Rn a = Rn { runRn :: State Int a }
deriving (Functor, Applicative, Monad, MonadState Int)
-- | Maps old to new name
type Names = Map Ident Ident
renameLocalBind :: Names -> Bind -> Rn (Names, Bind)
renameLocalBind old_names (Bind name t _ parms rhs) = do
(new_names, name') <- newName old_names name
(new_names', parms') <- newNames new_names parms
(new_names'', rhs') <- renameExp new_names' rhs
pure (new_names'', Bind name' t name' parms' rhs')
renameExp :: Names -> Exp -> Rn (Names, Exp)
renameExp old_names = \case
EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names)
EInt i1 -> pure (old_names, EInt i1)
EApp e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EApp e1' e2')
EAdd e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EAdd e1' e2')
ELet b e -> do
(new_names, b) <- renameLocalBind old_names b
(new_names', e') <- renameExp new_names e
pure (new_names', ELet b e')
EAbs par t e -> do
(new_names, par') <- newName old_names par
(new_names', e') <- renameExp new_names e
pure (new_names', EAbs par' t e')
EAnn e t -> do
(new_names, e') <- renameExp old_names e
pure (new_names, EAnn e' t)
-- | Create a new name and add it to name environment.
newName :: Names -> Ident -> Rn (Names, Ident)
newName env old_name = do
new_name <- makeName old_name
pure (Map.insert old_name new_name env, new_name)
-- | Create multiple names and add them to the name environment
newNames :: Names -> [Ident] -> Rn (Names, [Ident])
newNames = mapAccumM newName
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
makeName :: Ident -> Rn Ident
makeName (Ident prefix) = gets (\i -> Ident $ prefix ++ "_" ++ show i) <* modify succ

View file

@ -0,0 +1,178 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
module TypeChecker (typecheck, partitionType) where
import Auxiliary (maybeToRightM, snoc)
import Control.Monad.Except (throwError, unless)
import Data.Map (Map)
import qualified Data.Map as Map
import Grammar.Abs
import Grammar.ErrM (Err)
import Grammar.Print (Print (prt), concatD, doc, printTree,
render)
import Prelude hiding (exp, id)
import qualified TypeCheckerIr as T
-- NOTE: this type checker is poorly tested
-- TODO
-- Coercion
-- Type inference
data Cxt = Cxt
{ env :: Map Ident Type -- ^ Local scope signature
, sig :: Map Ident Type -- ^ Top-level signatures
}
initCxt :: [Bind] -> Cxt
initCxt sc = Cxt { env = mempty
, sig = Map.fromList $ map (\(Bind n t _ _ _) -> (n, t)) sc
}
typecheck :: Program -> Err T.Program
typecheck (Program sc) = T.Program <$> mapM (checkBind $ initCxt sc) sc
-- | Check if infered rhs type matches type signature.
checkBind :: Cxt -> Bind -> Err T.Bind
checkBind cxt b =
case expandLambdas b of
Bind name t _ parms rhs -> do
(rhs', t_rhs) <- infer cxt rhs
unless (typeEq t_rhs t) . throwError $ typeErr name t t_rhs
pure $ T.Bind (name, t) (zip parms ts_parms) rhs'
where
ts_parms = fst $ partitionType (length parms) t
-- | @ f x y = rhs ⇒ f = \x.\y. rhs @
expandLambdas :: Bind -> Bind
expandLambdas (Bind name t _ parms rhs) = Bind name t name [] rhs'
where
rhs' = foldr ($) rhs $ zipWith EAbs parms ts_parms
ts_parms = fst $ partitionType (length parms) t
-- | Infer type of expression.
infer :: Cxt -> Exp -> Err (T.Exp, Type)
infer cxt = \case
EId x ->
case lookupEnv x cxt of
Nothing ->
case lookupSig x cxt of
Nothing -> throwError ("Unbound variable:" ++ printTree x)
Just t -> pure (T.EId (x, t), t)
Just t -> pure (T.EId (x, t), t)
EInt i -> pure (T.EInt i, T.TInt)
EApp e e1 -> do
(e', t) <- infer cxt e
case t of
TFun t1 t2 -> do
e1' <- check cxt e1 t1
pure (T.EApp t2 e' e1', t2)
_ -> do
throwError ("Not a function: " ++ show e)
EAdd e e1 -> do
e' <- check cxt e T.TInt
e1' <- check cxt e1 T.TInt
pure (T.EAdd T.TInt e' e1', T.TInt)
EAbs x t e -> do
(e', t1) <- infer (insertEnv x t cxt) e
let t_abs = TFun t t1
pure (T.EAbs t_abs (x, t) e', t_abs)
ELet b e -> do
let cxt' = insertBind b cxt
b' <- checkBind cxt' b
(e', t) <- infer cxt' e
pure (T.ELet b' e', t)
EAnn e t -> do
(e', t1) <- infer cxt e
unless (typeEq t t1) $
throwError "Inferred type and type annotation doesn't match"
pure (e', t1)
-- | Check infered type matches the supplied type.
check :: Cxt -> Exp -> Type -> Err T.Exp
check cxt exp typ = case exp of
EId x -> do
t <- case lookupEnv x cxt of
Nothing -> maybeToRightM
("Unbound variable:" ++ printTree x)
(lookupSig x cxt)
Just t -> pure t
unless (typeEq t typ) . throwError $ typeErr x typ t
pure $ T.EId (x, t)
EInt i -> do
unless (typeEq typ TInt) $ throwError $ typeErr i TInt typ
pure $ T.EInt i
EApp e e1 -> do
(e', t) <- infer cxt e
case t of
TFun t1 t2 -> do
e1' <- check cxt e1 t1
pure $ T.EApp t2 e' e1'
_ -> throwError ("Not a function 2: " ++ printTree e)
EAdd e e1 -> do
e' <- check cxt e T.TInt
e1' <- check cxt e1 T.TInt
pure $ T.EAdd T.TInt e' e1'
EAbs x t e -> do
(e', t_e) <- infer (insertEnv x t cxt) e
let t1 = TFun t t_e
unless (typeEq t1 typ) $ throwError "Wrong lamda type!"
pure $ T.EAbs t1 (x, t) e'
ELet b e -> do
let cxt' = insertBind b cxt
b' <- checkBind cxt' b
e' <- check cxt' e typ
pure $ T.ELet b' e'
EAnn e t -> do
unless (typeEq t typ) $
throwError "Inferred type and type annotation doesn't match"
check cxt e t
-- | Check if types are equivalent. Doesn't handle coercion or polymorphism.
typeEq :: Type -> Type -> Bool
typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1
typeEq t t1 = t == t1
-- | Partion type into types of parameters and return type.
partitionType :: Int -- Number of parameters to apply
-> Type
-> ([Type], Type)
partitionType = go []
where
go acc 0 t = (acc, t)
go acc i t = case t of
TFun t1 t2 -> go (snoc t1 acc) (i - 1) t2
_ -> error "Number of parameters and type doesn't match"
insertBind :: Bind -> Cxt -> Cxt
insertBind (Bind n t _ _ _) = insertEnv n t
lookupEnv :: Ident -> Cxt -> Maybe Type
lookupEnv x = Map.lookup x . env
insertEnv :: Ident -> Type -> Cxt -> Cxt
insertEnv x t cxt = cxt { env = Map.insert x t cxt.env }
lookupSig :: Ident -> Cxt -> Maybe Type
lookupSig x = Map.lookup x . sig
typeErr :: Print a => a -> Type -> Type -> String
typeErr p expected actual = render $ concatD
[ doc $ showString "Wrong type:", prt 0 p , doc $ showString "\n"
, doc $ showString "Expected:" , prt 0 expected, doc $ showString "\n"
, doc $ showString "Actual: " , prt 0 actual
]

100
src/TypeCheckerIr.hs Normal file
View file

@ -0,0 +1,100 @@
{-# LANGUAGE LambdaCase #-}
module TypeCheckerIr
( module Grammar.Abs
, module TypeCheckerIr
) where
import Grammar.Abs (Ident (..), Type (..))
import Grammar.Print
import Prelude
import qualified Prelude as C (Eq, Ord, Read, Show)
newtype Program = Program [Bind]
deriving (C.Eq, C.Ord, C.Show, C.Read)
data Exp
= EId Id
| EInt Integer
| ELet Bind Exp
| EApp Type Exp Exp
| EAdd Type Exp Exp
| EAbs Type Id Exp
deriving (C.Eq, C.Ord, C.Show, C.Read)
type Id = (Ident, Type)
data Bind = Bind Id [Id] Exp
deriving (C.Eq, C.Ord, C.Show, C.Read)
instance Print Program where
prt i (Program sc) = prPrec i 0 $ prt 0 sc
instance Print Bind where
prt i (Bind name@(n, _) parms rhs) = prPrec i 0 $ concatD
[ prtId 0 name
, doc $ showString ";"
, prt 0 n
, prtIdPs 0 parms
, doc $ showString "="
, prt 0 rhs
]
instance Print [Bind] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
prtIdPs :: Int -> [Id] -> Doc
prtIdPs i = prPrec i 0 . concatD . map (prtIdP i)
prtId :: Int -> Id -> Doc
prtId i (name, t) = prPrec i 0 $ concatD
[ prt 0 name
, doc $ showString ":"
, prt 0 t
]
prtIdP :: Int -> Id -> Doc
prtIdP i (name, t) = prPrec i 0 $ concatD
[ doc $ showString "("
, prt 0 name
, doc $ showString ":"
, prt 0 t
, doc $ showString ")"
]
instance Print Exp where
prt i = \case
EId n -> prPrec i 3 $ concatD [prtIdP 0 n]
EInt i1 -> prPrec i 3 $ concatD [prt 0 i1]
ELet bs e -> prPrec i 3 $ concatD
[ doc $ showString "let"
, prt 0 bs
, doc $ showString "in"
, prt 0 e
]
EApp t e1 e2 -> prPrec i 2 $ concatD
[ doc $ showString "@"
, prt 0 t
, prt 2 e1
, prt 3 e2
]
EAdd t e1 e2 -> prPrec i 1 $ concatD
[ doc $ showString "@"
, prt 0 t
, prt 1 e1
, doc $ showString "+"
, prt 2 e2
]
EAbs t n e -> prPrec i 0 $ concatD
[ doc $ showString "@"
, prt 0 t
, doc $ showString "\\"
, prtIdP 0 n
, doc $ showString "."
, prt 0 e
]

View file

@ -1,5 +0,0 @@
main = (\x -> x + x + 3) ((\x -> x) 2)