diff --git a/.gitignore b/.gitignore index 0b0f588..897dce2 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ dist-newstyle *.y *.x *.bak +Grammar.tex src/Grammar language @@ -13,4 +14,10 @@ src/GC/lib/*.o src/GC/lib/*.so src/GC/lib/*.a src/GC/tests/*.out -src/GC/tests/logs \ No newline at end of file +src/GC/tests/logs +test_program_result +output/ +*.o +*.out +*.aux +*.log diff --git a/Grammar.cf b/Grammar.cf index 0b4785f..2db8b14 100644 --- a/Grammar.cf +++ b/Grammar.cf @@ -1,25 +1,103 @@ -Program. Program ::= [Bind]; +------------------------------------------------------------------------------- +-- * TOP-LEVEL +------------------------------------------------------------------------------- -EId. Exp3 ::= Ident; -EInt. Exp3 ::= Integer; -EAnn. Exp3 ::= "(" Exp ":" Type ")"; -ELet. Exp3 ::= "let" Bind "in" Exp; -EApp. Exp2 ::= Exp2 Exp3; -EAdd. Exp1 ::= Exp1 "+" Exp2; -EAbs. Exp ::= "\\" Ident ":" Type "." Exp; +DBind. Def ::= Bind; +DSig. Def ::= Sig; +DData. Def ::= Data; -Bind. Bind ::= Ident ":" Type ";" - Ident [Ident] "=" Exp; +internal Sig. Sig ::= LIdent ":" Type; + SigS. Sig ::= VarName ":" Type; +internal Bind. Bind ::= LIdent [LIdent] "=" Exp; + BindS. Bind ::= VarName [LIdent] "=" Exp; -separator Bind ";"; -separator Ident ""; +------------------------------------------------------------------------------- +-- * Types +------------------------------------------------------------------------------- -coercions Exp 3; +internal TLit. Type3 ::= UIdent; -- τ + TIdent. Type3 ::= UIdent; + TVar. Type3 ::= TVar; -- α + TApp. Type2 ::= Type2 Type3 ; + TFun. Type1 ::= Type1 "->" Type; -- A → A + TAll. Type ::= "forall" TVar "." Type; -- ∀α. A +internal TEVar. Type1 ::= TEVar; -- ά +internal TData. Type1 ::= UIdent "(" [Type] ")"; -- D () -TInt. Type1 ::= "Int" ; -TPol. Type1 ::= Ident ; -TFun. Type ::= Type1 "->" Type ; -coercions Type 1 ; + MkTVar. TVar ::= LIdent; +internal MkTEVar. TEVar ::= LIdent; + +------------------------------------------------------------------------------- +-- * DATA TYPES +------------------------------------------------------------------------------- + +Data. Data ::= "data" Type "where" "{" [Inj] "}" ; + +Inj. Inj ::= UIdent ":" Type ; + +------------------------------------------------------------------------------- +-- * PATTERN MATCHING +------------------------------------------------------------------------------- + +Branch. Branch ::= Pattern "=>" Exp ; + +PVar. Pattern1 ::= LIdent; +PLit. Pattern1 ::= Lit; +PCatch. Pattern1 ::= "_"; +PEnum. Pattern1 ::= UIdent; +PInj. Pattern ::= UIdent [Pattern1]; + +------------------------------------------------------------------------------- +-- * Expressions +------------------------------------------------------------------------------- + +internal EVar. Exp4 ::= LIdent; + EVarS. Exp4 ::= VarName ; + EInj. Exp4 ::= UIdent; + ELit. Exp4 ::= Lit; + EApp. Exp3 ::= Exp3 Exp4; + EAdd. Exp2 ::= Exp2 "+" Exp3; + ELet. Exp1 ::= "let" Bind "in" Exp1; + EAbs. Exp1 ::= "\\" LIdent "." Exp1; + ECase. Exp1 ::= "case" Exp "of" "{" [Branch] "}"; + EAnn. Exp ::= Exp1 ":" Type; + +VSymbol. VarName ::= "." Symbol; +VIdent. VarName ::= LIdent; + +infixSymbol. Exp2 ::= Exp2 Symbol Exp3 ; +define infixSymbol e1 vn e3 = EApp (EApp (EVarS (VSymbol vn)) e1) e3; + +------------------------------------------------------------------------------- +-- * LITERALS +------------------------------------------------------------------------------- + +LInt. Lit ::= Integer; +LChar. Lit ::= Char; + +------------------------------------------------------------------------------- +-- * AUX +------------------------------------------------------------------------------- + +layout "of", "where"; +layout toplevel; + +separator Def ";"; +separator Branch ";" ; +separator Inj ";"; + +separator LIdent ""; +separator Type " "; +separator TVar " "; +separator nonempty Pattern1 " "; + +coercions Pattern 1; +coercions Exp 4; +coercions Type 3 ; + +token UIdent (upper (letter | digit | '_')*) ; +token LIdent (lower (letter | digit | '_')*) ; +token Symbol (["@#%^&*_-+=|?/<>,•:[]"]+) ; comment "--"; comment "{-" "-}"; diff --git a/Grammar.pdf b/Grammar.pdf new file mode 100644 index 0000000..f7f7a70 Binary files /dev/null and b/Grammar.pdf differ diff --git a/Justfile b/Justfile new file mode 100644 index 0000000..a7acacd --- /dev/null +++ b/Justfile @@ -0,0 +1,35 @@ +# build from scratch +build: + bnfc -o src -d Grammar.cf + cabal install --installdir=. --overwrite-policy=always + +# clean the generated directories +clean: + rm -r src/Grammar + rm language + rm -r dist-newstyle/ + +# run all tests +test: + cabal test + +debug FILE: + cabal run language -- -d {{FILE}} + +hm FILE: + cabal run language -- -t hm {{FILE}} + +bi FILE: + cabal run language -- -t bi {{FILE}} + +hmd FILE: + cabal run language -- -d -t hm {{FILE}} + +bid FILE: + cabal run language -- -d -t bi {{FILE}} + +hmdm FILE: + cabal run language -- -d -t hm -m {{FILE}} + +bidm FILE: + cabal run language -- -d -t bi -m {{FILE}} \ No newline at end of file diff --git a/Makefile b/Makefile index e63a1e6..6c1ebde 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ language : src/Grammar/Test cabal install --installdir=. --overwrite-policy=always -src/Grammar/Test.hs src/Grammar/Lex.x src/Grammar/Par.y : Grammar.cf +src/Grammar/Test.hs src/Grammar/Lex.x src/Grammar/Par.y src/Grammar/Layout : Grammar.cf bnfc -o src -d $< src/Grammar/Par.hs : src/Grammar/Par.y @@ -15,23 +15,25 @@ src/Grammar/Lex.hs : src/Grammar/Lex.x src/Grammar/%.y : Grammar.cf bnfc -o src -d $< -src/Grammar/Test : src/Grammar/Test.hs src/Grammar/Par.hs src/Grammar/Lex.hs - ghc src/Grammar/Test.hs src/Grammar/Par.hs src/Grammar/Lex.hs src/Grammar/Abs.hs src/Grammar/Skel.hs src/Grammar/Print.hs -o src/Grammar/test +src/Grammar/Test : src/Grammar/Test.hs src/Grammar/Par.hs src/Grammar/Lex.hs src/Grammar/Layout + ghc src/Grammar/Test.hs src/Grammar/Par.hs src/Grammar/Lex.hs src/Grammar/Abs.hs src/Grammar/Skel.hs src/Grammar/Print.hs src/Grammar/Layout -o src/Grammar/test + +Grammar.tex : + bnfc --latex Grammar.cf + +Grammar.pdf : Grammar.tex + pdflatex Grammar.tex + rm Grammar.aux Grammar.log + +pdf : Grammar.pdf clean : rm -r src/Grammar rm language + rm -rf dist-newstyles + rm Grammar.aux Grammar.fdb_latexmk Grammar.fls Grammar.log Grammar.synctex.gz Grammar.tex test : - ./language ./sample-programs/basic-1 - ./language ./sample-programs/basic-2 - ./language ./sample-programs/basic-3 - ./language ./sample-programs/basic-4 - ./language ./sample-programs/basic-5 - ./language ./sample-programs/basic-5 - ./language ./sample-programs/basic-6 - ./language ./sample-programs/basic-7 - ./language ./sample-programs/basic-8 - ./language ./sample-programs/basic-9 + cabal v2-test # EOF diff --git a/README.md b/README.md index 1cfb72a..7cb234e 100644 --- a/README.md +++ b/README.md @@ -1 +1,244 @@ -# language \ No newline at end of file +# Build +First generate the parser using [BNFC](https://bnfc.digitalgrammars.com/), +this is done using the command `bnfc -o src -d Grammar.cf` + +Churf can then be built using `cabal install` + +Using the tool [make](https://www.gnu.org/software/make/) the entire thing can be built by running `make` +or using [just](https://github.com/casey/just), `just build` + +# Dependencies +If you have Nix installed, simply run `nix-shell --pure shell.nix` to get into an environment +with the right versions of packages. Then run `make` and the compiler should build. + +# Compiling a program + +Using the Hindley-Milner type checker: `./language -t hm example.crf` + +Using the bidirectional type checker: `./language -t bi example.crf` + +The program to compile has to have the file extension `.crf` +# Syntax and quirks + +See Grammar.pdf for the full syntax. + +The syntactic requirements differ a bit using the different type checkers. +The bidirectional type checker require explicit `forall` everywhere a type +forall quantified type variable is declared. In the Hindley-Milner type checker +all type variables are assumed to be forall quantified. + +Currently for the code generator and monomorphizer to work correctly it is +expected that the function `main` exist with either explicitly given type `Int` +or inferrable. + +Single line comments are written using `--` +Multi line comments are written using `{-` and `-}` + +Braches and semicolons are optional. + +## Program + +A program is a list of defs separated by semicolons, which in turn is either a bind, a signature, or a data types +`Program ::= [Def]` + +```hs +data Test () where + Test : Test () +test : Int +test = 0 +``` + +## Bind + +A bind is a name followed by a white space separated list of arguments, then an equal sign followed by an expression. +Both name and arguments have to start with lower case letters + +`Bind ::= LIdent [LIdent] "=" Exp` + +```hs +example x y = x + y +``` + +## Signature +A signature is a name followed by a colon and then the type +The name has to start with a lowe case letter + +`Sig ::= LIdent ":" Type` + +```hs +const : a -> b -> a +``` + +## Data type +A data type is declared as follows + +`Data ::= "data" Type "where" "{" [Inj] "}"` + +The words in quotes are necessary keywords +The type can be any type for parsing, but only `TData` will type check. + +The list of Inj is separated by white space. Using new lines is recommended for ones own sanity. + +```hs +data Maybe (a) where + Nothing : Maybe (a) + Just : a -> Maybe (a) +``` +The parens are necessary for every data type to make the grammar unambiguous. +Thus in `data Bool () where ...` the parens *do* *not* represent Unit + +### Inj +An inj is a constructor for the data type + +It is declared like a signature, except the name has to start with a lower case letter. +The return type of the constructor also has match the type of the data type to type check. + +`Inj ::= UIdent ":" Type` + +## Type + +A type can be either a type literal, type variable, function type, explicit forall quantified type or a type representing a data type +A type literal have to start with an upper case letter, type variables have to start with a lower case letter, +data types have to start with an upper case letter, a function type is two types separated by an arrow (arrows right associative), +and foralls take one type variable followed by a type. + +`TLit ::= UIdent` + +`TVar ::= LIdent` + +`TData ::= UIdent "(" [Type] ")"` + +`TFun ::= Type "->" Type` + +`TAll ::= "forall" LIdent "." Type` + +```hs +exampleLit : Int +exampleVar : a +exampleData : Maybe (a) +exampleFun : Int -> a +exampleAll : forall a. forall b. a -> b +``` + +## Expressions + +There are a couple different expressions, probably best explained by their rules + +Type annotated expression + +`EAnn ::= "(" Exp ":" Type ")"` + +Variable + +`EVar ::= LIdent` +```hs +x +``` + +Constructor + +`EInj ::= UIdent` +```hs +Just +``` + +Literal + +`ELit ::= Lit` +```hs +0 +``` + +Function application + +`EApp ::= Exp2 Exp3` +```hs +f 0 +``` + +Addition + +`EAdd ::= Exp1 "+" Exp2` +```hs +3 + 5 +``` + +Let expression + +`ELet ::= "let" Bind "in" Exp ` +```hs +let f x = x in f 0 +``` + +Abstraction, known as lambda or closure + +`EAbs ::= "\\" LIdent "." Exp` +```hs +\x. x +``` + +Case expression consist of a list semicolon separated list of Branches + +`ECase ::= "case" Exp "of" "{" [Branch] "}"` + +```hs +case xs of + Cons x xs => 1 + Nil => 0 +``` + +### Branch +A branch is a pattern followed by the fat arrow and then an expression + +`Branch ::= Pattern "=>" Exp` + +### Pattern +A pattern can be either a variable, literal, a wildcard represented by `_`, an enum constructor (constructor with zero arguments) +, or a constructor followed by a recursive list of patterns. + +Variable match + +`PVar ::= LIdent` + +The x in the following example +```hs +x => 0 +``` +Literal match + +`PLit ::= Lit` + +The 1 in the following example +```hs +1 => 0 +``` +A wildcard match + +`PCatch ::= "_"` + +The underscore in the following example +```hs +_ => 0 +``` +A constructor without arguments + +`PEnum ::= UIdent` + +The Nothing in the following example +```hs +Nothing => 0 +``` +The recursive match on a constructor + +`PInj ::= UIdent [Pattern1]` + +The outer Just represents the UIdent and the rest is the recursive match +```hs +Just (Just 0) => 1 +``` + +For simplicity sake a user does not need to consider these last two cases as different in parsing. +We allow arbitrarily deep pattern matching. + +## Literal +We currently allow two different literals: Integer and Char diff --git a/benchmark.py b/benchmark.py new file mode 100755 index 0000000..40f0a15 --- /dev/null +++ b/benchmark.py @@ -0,0 +1,21 @@ +#!/bin/env/python3 + +import sys +import os +import time + +if __name__ == "__main__": + args = sys.argv + if len(args) == 1: + print ("first arg is number of loops second is exe") + else: + total = 0 + iter = int(args[1]) + for i in range(iter): + time_pre = time.time() + os.system("./" + args[2] + "> /dev/null") + time_post = time.time() + calc = time_post - time_pre + total += calc + + print ("File: " + args[2] + ", " + str(iter) + " runs gave average: " + str(total / iter) + "s") diff --git a/benchmark.txt b/benchmark.txt new file mode 100644 index 0000000..c12461e --- /dev/null +++ b/benchmark.txt @@ -0,0 +1,9 @@ +# Full optimization Churf +File: output/hello_world, 100 runs gave average: 0.025261127948760988s + +# O2 Haskell +File: ./Bench, 100 runs gave average: 0.05629507303237915s + +# 03 Haskell +File: ./Bench, 100 runs gave average: 0.05490849256515503s +File: ./Bench, 100 runs gave average: 0.05323728561401367s diff --git a/cabal.project.local b/cabal.project.local deleted file mode 100644 index 0432756..0000000 --- a/cabal.project.local +++ /dev/null @@ -1,2 +0,0 @@ -ignore-project: False -tests: True diff --git a/fourmolu.yaml b/fourmolu.yaml index f15300e..8b96b58 100644 --- a/fourmolu.yaml +++ b/fourmolu.yaml @@ -1,14 +1 @@ -indentation: 4 -function-arrows: trailing -comma-style: leading -import-export-style: diff-friendly indent-wheres: false -record-brace-space: false -newlines-between-decls: 1 -haddock-style: multi-line -haddock-style-module: -let-style: auto -in-style: right-align -respectful: true -fixities: [] -unicode: never diff --git a/language.cabal b/language.cabal index 8b958a5..af7178c 100644 --- a/language.cabal +++ b/language.cabal @@ -12,11 +12,9 @@ build-type: Simple extra-doc-files: CHANGELOG.md - extra-source-files: Grammar.cf - common warnings ghc-options: -W @@ -32,14 +30,33 @@ executable language Grammar.Print Grammar.Skel Grammar.ErrM - LambdaLifter + Grammar.ErrM + Grammar.Layout Auxiliary - Renamer - TypeChecker - TypeCheckerIr --- Interpreter + Renamer.Renamer + TypeChecker.TypeChecker + AnnForall + OrderDefs + TypeChecker.TypeCheckerHm + TypeChecker.TypeCheckerBidir + TypeChecker.TypeCheckerIr + TypeChecker.ReportTEVar + TypeChecker.RemoveForall + LambdaLifter + Monomorphizer.Monomorphizer + Monomorphizer.MonomorphizerIr + Monomorphizer.MorbIr + Monomorphizer.DataTypeRemover + Codegen.Codegen + Codegen.LlvmIr + Codegen.Auxillary + Codegen.CompilerState + Codegen.Emits Compiler - LlvmIr + Renamer.Renamer + TreeConverter + Desugar.Desugar + hs-source-dirs: src build-depends: @@ -47,6 +64,65 @@ executable language , mtl , containers , either - , array , extra + , array + , hspec + , QuickCheck + , directory + , process + default-language: GHC2021 + +Test-suite language-testsuite + type: exitcode-stdio-1.0 + main-is: Main.hs + + other-modules: + TestTypeCheckerBidir + TestTypeCheckerHm + TestAnnForall + TestReportForall + TestRenamer + TestLambdaLifter + DoStrings + + Grammar.Abs + Grammar.Lex + Grammar.Par + Grammar.Print + Grammar.Skel + Grammar.ErrM + Grammar.Layout + OrderDefs + Auxiliary + Monomorphizer.Monomorphizer + Monomorphizer.MonomorphizerIr + Renamer.Renamer + TypeChecker.TypeChecker + AnnForall + ReportForall + TypeChecker.TypeCheckerHm + TypeChecker.TypeCheckerBidir + TypeChecker.ReportTEVar + TypeChecker.RemoveForall + TypeChecker.TypeCheckerIr + Compiler + + hs-source-dirs: src, tests + + build-depends: + base >=4.16 + , mtl + , containers + , either + , extra + , array + , hspec + , QuickCheck + , process + , bytestring + , hspec + , directory + + default-language: GHC2021 + diff --git a/pipeline.txt b/pipeline.txt new file mode 100644 index 0000000..1872562 --- /dev/null +++ b/pipeline.txt @@ -0,0 +1,27 @@ + + Parser + | + ReportForall Report unnecessary foralls. Hm: report rank>2 foralls + | + AnnotateForall Annotate all unbound type variables with foralls + | + Renamer Rename type variables and term variables + | + / \ + / \ + TypeCheckHm TypeCheckBi + \ / + \ / + | + ReportTEVar Report type existential variables and change type AST + | + RemoveForall RemoveForall and change type AST + | + Monomorpher + | + Desugar + | + CodeGen + + + diff --git a/sample-programs/basic-1 b/sample-programs/basic-1 deleted file mode 100644 index f0cdcc4..0000000 --- a/sample-programs/basic-1 +++ /dev/null @@ -1,21 +0,0 @@ - --- tripplemagic : Int -> Int -> Int -> Int; --- tripplemagic x y z = ((\x:Int. x+x) x) + y + z; --- main : Int; --- main = tripplemagic ((\x:Int. x+x+3) ((\x:Int. x) 2)) 5 3 --- answer: 22 - --- apply : (Int -> Int) -> Int -> Int; --- apply f x = f x; --- main : Int; --- main = apply (\x : Int . x + 5) 5 --- answer: 10 - -apply : (Int -> Int -> Int) -> Int -> Int -> Int; -apply f x y = f x y; -krimp: Int -> Int -> Int; -krimp x y = x + y; -main : Int; -main = apply (krimp) 2 3; --- answer: 5 - diff --git a/sample-programs/bubble-sort.chrf b/sample-programs/bubble-sort.chrf new file mode 100644 index 0000000..59e6598 --- /dev/null +++ b/sample-programs/bubble-sort.chrf @@ -0,0 +1,11 @@ +data List (a) where + Cons : a -> List (a) -> List (a) + Nil : List (a) + +bubblesort : List (a) -> List (a) +bubblesort xs = case xs of + Nil => Nil + Cons x => case x of + Nil => Cons x Nil + Cons y => + diff --git a/sample-programs/insertion-sort.chrf b/sample-programs/insertion-sort.chrf new file mode 100644 index 0000000..fc61691 --- /dev/null +++ b/sample-programs/insertion-sort.chrf @@ -0,0 +1,30 @@ +data List (a) where + Nil : List (a) + Cons : a -> List (a) -> List (a) + +insert : Int -> List (Int) -> List (Int) +insert x xs = case xs of + Cons z zs => case (lt x z) of + True => Cons x (Cons z zs) + False => Cons z (insert x zs) + Nil => Cons x Nil + +insertionSort : List (Int) -> List (Int) +insertionSort xs = case xs of + Cons y ys => case ys of + _ => insert y (insertionSort ys) + Nil => xs + Nil => Nil + +main = head (insertionSort (revRange 1250)) + +head xs = case xs of + Cons x _ => x + +revRange x = case x of + 0 => Cons x Nil + x => Cons x (revRange (x + minusOne)) + +-- represents minus one :) +minusOne : Int ; +minusOne = 9223372036854775807 + 9223372036854775807 + 1; \ No newline at end of file diff --git a/sample-programs/loop.crf b/sample-programs/loop.crf new file mode 100644 index 0000000..e3c3c38 --- /dev/null +++ b/sample-programs/loop.crf @@ -0,0 +1,18 @@ +main = for 0 1000 + +for x n = case n of + 0 => 0 + n => for (revRange 1000) (n + minusOne) + +data List (a) where + Nil : List (a) + Cons : a -> List (a) -> List (a) + +-- create a list of x to 0 +revRange x = case x of + 0 => Cons x Nil + x => Cons x (revRange (x + minusOne)) + +-- represents minus one :) +minusOne : Int ; +minusOne = 9223372036854775807 + 9223372036854775807 + 1; \ No newline at end of file diff --git a/sample-programs/lt_testing.crf b/sample-programs/lt_testing.crf new file mode 100644 index 0000000..5edc1c9 --- /dev/null +++ b/sample-programs/lt_testing.crf @@ -0,0 +1,3 @@ +main = case (lt 3 5) of + True => 1 + False => 0 diff --git a/sample-programs/mono-1.crf b/sample-programs/mono-1.crf new file mode 100644 index 0000000..c41e9b6 --- /dev/null +++ b/sample-programs/mono-1.crf @@ -0,0 +1,8 @@ +const2 : a -> b -> a +const2 x y = x + +f : a -> a +f x = (const2 x 'c') + +main = f 5 + diff --git a/sample-programs/mono-2.crf b/sample-programs/mono-2.crf new file mode 100644 index 0000000..76a92c2 --- /dev/null +++ b/sample-programs/mono-2.crf @@ -0,0 +1,17 @@ +data Either (a b) where + Left : a -> Either (a b) + Right : b -> Either (a b) + +unwrapLeft : Either (a b) -> a +unwrapLeft x = case x of + Left y => y + +unwrapRight : Either (a b) -> b +unwrapRight x = case x of + Right y => y + +wow : Either (Int Char) +wow = Left 5 + +main = unwrapLeft wow + diff --git a/sample-programs/mono-3.crf b/sample-programs/mono-3.crf new file mode 100644 index 0000000..a51df2c --- /dev/null +++ b/sample-programs/mono-3.crf @@ -0,0 +1,11 @@ +data Number() where + One: Number () + Two: Number () + +numberToInt : Number () -> Int +numberToInt n = case n of + One => 1 + Two => 2 + +main = numberToInt One + diff --git a/sample-programs/mono-4.chrf b/sample-programs/mono-4.chrf new file mode 100644 index 0000000..79d1495 --- /dev/null +++ b/sample-programs/mono-4.chrf @@ -0,0 +1,12 @@ +data Either (a b) where + Left : a -> Either (a b) + Right : b -> Either (a b) + +unwrap : Either (a a) -> a +unwrap x = case x of + Left y => y + Right y => y + +main : Int +main = unwrap (Left 3) + diff --git a/shell.nix b/shell.nix index 0af8c7b..c8cc7a8 100644 --- a/shell.nix +++ b/shell.nix @@ -6,15 +6,20 @@ pkgs.haskellPackages.developPackage { withHoogle = true; modifier = drv: pkgs.haskell.lib.addBuildTools drv ( - (with pkgs; [ hlint haskell-language-server ghc jasmin llvmPackages_15.libllvm]) + (with pkgs; [ hlint + haskell-language-server + ghc + jasmin + llvmPackages_15.libllvm +# texlive.combined.scheme-full + graphviz + ]) ++ - (with pkgs.haskellPackages; [ - cabal-install - stylish-haskell - BNFC - alex - happy - ]) - ); + (with pkgs.haskellPackages; [ cabal-install + stylish-haskell + BNFC + alex + happy + ])); } diff --git a/spec.txt b/spec.txt new file mode 100644 index 0000000..2273846 --- /dev/null +++ b/spec.txt @@ -0,0 +1,121 @@ +--------------------------------------------------------------------------- +-- * Parser +--------------------------------------------------------------------------- + +data Program = Program [Def] + +data Def = DSig Ident Type | DBind Bind + +data Bind = Bind Ident [Ident] Exp + +data Exp + = EId Ident + | ELit Lit + | EAnn Exp Type + | ELet Ident Exp Exp + | EApp Exp Exp + | EAdd Exp Exp + | EAbs Ident Exp + +data Lit = LInt Integer + | LChar Character + +data Type + = TLit Ident -- τ + | TVar TVar -- α + | TFun Type Type -- A → A + | TAll TVar Type -- ∀α. A + | TEVar TEVar -- ά (internal) + +data TVar = MkTVar Ident +data TEVar = MkTEVar Ident + +--------------------------------------------------------------------------- +-- * Type checker +--------------------------------------------------------------------------- + +-- • Def and DSig are removed in favor on just Bind +-- • Typed expressions +-- • TEVar is removed (NOT IMPLEMENTED) + +newtype Program = Program [Bind] + +data Bind = Bind Id [Id] ExpT + +data Exp + = EId Ident + | ELit Lit + | ELet Bind ExpT + | EApp ExpT ExpT + | EAdd ExpT ExpT + | EAbs Ident ExpT + +type Id = (Ident, Type) +type ExpT = (Exp, Type) + + +data Lit = LInt Integer + | LChar Character + +data Type + = TLit Ident -- τ + | TVar TVar -- α + | TFun Type Type -- A → A + | TAll TVar Type -- ∀α. A + +data TVar = MkTVar Ident + +--------------------------------------------------------------------------- +-- * Lambda lifter +--------------------------------------------------------------------------- +-- • EAbs are removed (NOT IMPLEMENTED) +-- • ELet only allow constant expressions (NOT IMPLEMENTED) + +newtype Program = Program [Bind] + +data Bind = Bind Id [Id] ExpT + +data Exp + = EId Ident + | ELit Lit + | ELet Ident ExpT ExpT + | EApp ExpT ExpT + | EAdd ExpT ExpT + +type Id = (Ident, Type) +type ExpT = (Exp, Type) + +data Lit = LInt Integer + | LChar Character + +data Type + = TLit Ident -- τ + | TVar TVar -- α + | TFun Type Type -- A → A + | TAll TVar Type -- ∀α. A + +data TVar = MkTVar Ident + +--------------------------------------------------------------------------- +-- * Monomorpher +--------------------------------------------------------------------------- +-- • Polymorphic types are removed (NOT IMPLEMENTED) + +newtype Program = Program [Bind] + +data Bind = Bind Id [Id] ExpT + +data Exp + = EId Ident + | ELit Lit + | ELet Ident ExpT ExpT + | EApp ExpT ExpT + | EAdd ExpT ExpT + +type Id = (Ident, Type) +type ExpT = (Exp, Type) + +data Lit = LInt Integer + | LChar Character + +data Type = Type Ident diff --git a/src/AnnForall.hs b/src/AnnForall.hs new file mode 100644 index 0000000..16222bd --- /dev/null +++ b/src/AnnForall.hs @@ -0,0 +1,100 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedRecordDot #-} + +module AnnForall (annotateForall) where + +import Auxiliary (partitionDefs) +import Control.Applicative (Applicative (liftA2)) +import Control.Monad.Except (throwError) +import Data.Function (on) +import Data.Set (Set) +import qualified Data.Set as Set +import Grammar.Abs +import Grammar.ErrM (Err) + +annotateForall :: Program -> Err Program +annotateForall (Program defs) = do + ds' <- mapM (fmap DData . annData) ds + bs' <- mapM (fmap DBind . annBind) bs + pure $ Program (ds' ++ ss' ++ bs') + where + ss' = map (DSig . annSig) ss + (ds, ss, bs) = partitionDefs defs + + +annData :: Data -> Err Data +annData (Data typ injs) = do + (typ', tvars) <- annTyp typ + pure (Data typ' $ map (annInj tvars) injs) + + where + annTyp typ = do + (bounded, ts) <- boundedTVars mempty typ + unbounded <- Set.fromList <$> mapM assertTVar ts + let diff = unbounded Set.\\ bounded + typ' = foldr TAll typ diff + (typ', ) . fst <$> boundedTVars mempty typ' + where + boundedTVars tvars typ = case typ of + TAll tvar t -> boundedTVars (Set.insert tvar tvars) t + TData _ ts -> pure (tvars, ts) + _ -> throwError "Misformed data declaration" + + assertTVar typ = case typ of + TVar tvar -> pure tvar + _ -> throwError $ unwords [ "Misformed data declaration:" + , "Non type variable argument" + ] + annInj tvars (Inj n t) = + Inj n $ foldr TAll t (unboundedTVars t Set.\\ tvars) + +annSig :: Sig -> Sig +annSig (Sig name typ) = Sig name $ annType typ + +annBind :: Bind -> Err Bind +annBind (Bind name vars exp) = Bind name vars <$> annExp exp + where + annExp = \case + EAnn e t -> flip EAnn (annType t) <$> annExp e + EApp e1 e2 -> liftA2 EApp (annExp e1) (annExp e2) + EAdd e1 e2 -> liftA2 EAdd (annExp e1) (annExp e2) + ELet bind e -> liftA2 ELet (annBind bind) (annExp e) + EAbs x e -> EAbs x <$> annExp e + ECase e bs -> liftA2 ECase (annExp e) (mapM annBranch bs) + e -> pure e + annBranch (Branch p e) = Branch p <$> annExp e + +annType :: Type -> Type +annType typ = go $ unboundedTVars typ + where + go us + | null us = typ + | otherwise = foldr TAll typ us + +unboundedTVars :: Type -> Set TVar +unboundedTVars = unboundedTVars' mempty + +unboundedTVars' :: Set TVar -> Type -> Set TVar +unboundedTVars' bs typ = tvars.unbounded Set.\\ tvars.bounded + where + tvars = gatherTVars typ + gatherTVars = \case + TAll tvar t -> TVars { bounded = Set.singleton tvar + , unbounded = unboundedTVars' (Set.insert tvar bs) t + } + TVar tvar -> uTVars $ Set.singleton tvar + TFun t1 t2 -> uTVars $ on Set.union (unboundedTVars' bs) t1 t2 + TData _ typs -> uTVars $ foldr (Set.union . unboundedTVars' bs) mempty typs + _ -> TVars { bounded = mempty, unbounded = mempty } + +data TVars = TVars + { bounded :: Set TVar + , unbounded :: Set TVar + } deriving (Eq, Show, Ord) + +uTVars :: Set TVar -> TVars +uTVars us = TVars + { bounded = mempty + , unbounded = us + } + diff --git a/src/Auxiliary.hs b/src/Auxiliary.hs index 735d804..22095aa 100644 --- a/src/Auxiliary.hs +++ b/src/Auxiliary.hs @@ -1,8 +1,18 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE Rank2Types #-} + module Auxiliary (module Auxiliary) where + import Control.Monad.Error.Class (liftEither) -import Control.Monad.Except (MonadError) +import Control.Monad.Except (MonadError, liftM2) import Data.Either.Combinators (maybeToRight) +import Data.List (foldl') +import Grammar.Abs +import Prelude hiding ((>>), (>>=)) + +(>>) a b = a ++ " " ++ b +(>>=) a f = f a snoc :: a -> [a] -> [a] snoc x xs = xs ++ [x] @@ -14,8 +24,52 @@ mapAccumM :: Monad m => (s -> a -> m (s, b)) -> s -> [a] -> m (s, [b]) mapAccumM f = go where go acc = \case - [] -> pure (acc, []) - x:xs -> do - (acc', x') <- f acc x - (acc'', xs') <- go acc' xs - pure (acc'', x':xs') + [] -> pure (acc, []) + x : xs -> do + (acc', x') <- f acc x + (acc'', xs') <- go acc' xs + pure (acc'', x' : xs') + +onMM :: Monad m => (b -> b -> m c) -> (a -> m b) -> a -> a -> m c +onMM f g x y = liftMM2 f (g x) (g y) + +onM :: Monad m => (b -> b -> c) -> (a -> m b) -> a -> a -> m c +onM f g x y = liftM2 f (g x) (g y) + +unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d]) +unzip4 = + foldl' + ( \(as, bs, cs, ds) (a, b, c, d) -> + (as ++ [a], bs ++ [b], cs ++ [c], ds ++ [d]) + ) + ([], [], [], []) + +liftMM2 :: Monad m => (a -> b -> m c) -> m a -> m b -> m c +liftMM2 f m1 m2 = do + x1 <- m1 + x2 <- m2 + f x1 x2 + +litType :: Lit -> Type +litType (LInt _) = int +litType (LChar _) = char + +int = TLit "Int" +char = TLit "Char" + +tupSequence :: Monad m => (m a, b) -> m (a, b) +tupSequence (ma, b) = (,b) <$> ma + +fst_ :: (a, b, c) -> a +snd_ :: (a, b, c) -> b +trd_ :: (a, b, c) -> c +snd_ (_, a, _) = a +fst_ (a, _, _) = a +trd_ (_, _, a) = a + +partitionDefs :: [Def] -> ([Data], [Sig], [Bind]) +partitionDefs defs = (datas, sigs, binds) + where + datas = [ d | DData d <- defs ] + sigs = [ s | DSig s <- defs ] + binds = [ b | DBind b <- defs ] diff --git a/src/CaseDesugar/CaseDesugar.hs b/src/CaseDesugar/CaseDesugar.hs new file mode 100644 index 0000000..e1db55e --- /dev/null +++ b/src/CaseDesugar/CaseDesugar.hs @@ -0,0 +1,83 @@ +{-# LANGUAGE LambdaCase #-} + +module CaseDesugar.CaseDesugar (desuga) where + +import CaseDesugar.CaseDesugarIr qualified as CIR +import TypeChecker.TypeCheckerIr qualified as TIR + +desuga :: TIR.Program -> CIR.Program +desuga (TIR.Program x) = CIR.Program $ desugaDef <$> x + +desugaDef :: TIR.Def -> CIR.Def +desugaDef (TIR.DBind bin@TIR.Bind{}) = CIR.DBind $ desugaBind bin +desugaDef (TIR.DData dat@TIR.Data{}) = CIR.DData $ desugaData dat + +desugaData :: TIR.Data -> CIR.Data +desugaData (TIR.Data t injs) = CIR.Data (desugaType t) (desugaInj <$> injs) + +desugaType :: TIR.Type -> CIR.Type +desugaType (TIR.TLit (TIR.Ident s)) = CIR.TLit (CIR.Ident s) +desugaType (TIR.TVar tv) = CIR.TVar (desugaTVar tv) +desugaType (TIR.TData (TIR.Ident s) ts) = CIR.TData (CIR.Ident s) (desugaType <$> ts) +desugaType (TIR.TFun t1 t2) = CIR.TFun (desugaType t1) (desugaType t2) +desugaType (TIR.TAll _ t1) = desugaType t1 + +desugaTVar :: TIR.TVar -> CIR.TVar +desugaTVar (TIR.MkTVar (TIR.Ident s)) = CIR.MkTVar (CIR.Ident s) + +desugaInj :: TIR.Inj -> CIR.Inj +desugaInj (TIR.Inj (TIR.Ident s) t) = CIR.Inj (CIR.Ident s) (desugaType t) + +desugaId :: TIR.Id -> CIR.Id +desugaId (TIR.Ident s, t) = (CIR.Ident s, desugaType t) + +desugaBind :: TIR.Bind -> CIR.Bind +desugaBind (TIR.Bind id args exp) = + CIR.Bind (desugaId id) (desugaId <$> args) (desugaExpT exp) + +desugaExpT :: TIR.ExpT -> CIR.ExpT +desugaExpT (exp, t) = (desugaExp exp, desugaType t) + +desugaExp :: TIR.Exp -> CIR.Exp +desugaExp (TIR.EVar (TIR.Ident s)) = CIR.EVar (CIR.Ident s) +desugaExp (TIR.EInj (TIR.Ident s)) = CIR.EInj (CIR.Ident s) +desugaExp (TIR.ELit lit) = CIR.ELit lit +desugaExp (TIR.ELet b e) = CIR.ELet (desugaBind b) (desugaExpT e) +desugaExp (TIR.EApp e1 e2) = CIR.EApp (desugaExpT e1) (desugaExpT e2) +desugaExp (TIR.EAdd e1 e2) = CIR.EAdd (desugaExpT e1) (desugaExpT e2) +desugaExp (TIR.EAbs (TIR.Ident s) e) = CIR.EAbs (CIR.Ident s) (desugaExpT e) +desugaExp (TIR.ECase e branches) = CIR.ECase (desugaExpT e) (desugaBranches branches) + +desugaBranches :: [TIR.Branch] -> [CIR.Branch] +desugaBranches bs = do + let injections = filter (\case (TIR.Branch (TIR.PInj{}, _) _) -> True; _ -> False) bs + let patterns = filter (\case (TIR.Branch (TIR.PInj{}, _) _) -> True; _ -> False) bs + undefined + +desugaBranch :: TIR.Branch -> CIR.Branch +desugaBranch (TIR.Branch (TIR.PInj (TIR.Ident s) ps, pt) e) = do + undefined +desugaBranch (TIR.Branch (p, pt) e) = do + CIR.Branch + ( case p of + TIR.PVar id -> (CIR.PVar (desugaId id), desugaType pt) + TIR.PLit (lit, t) -> (CIR.PLit (lit, desugaType t), desugaType pt) + TIR.PCatch -> (CIR.PCatch, desugaType pt) + TIR.PEnum (TIR.Ident s) -> (CIR.PEnum (CIR.Ident s), desugaType pt) + ) + (desugaExpT e) + +{- +case (Tupli 5 5) of + Tupli 6 5 => 1 + Tupli _ x => 3 + x => 1 +=== +case (Tupli 5 5) of + Tupli x y => case x of + 6 => case y of + 5 => 1 + x => 3 + _ => case y of + x => 3 +-} \ No newline at end of file diff --git a/src/CaseDesugar/CaseDesugarIr.hs b/src/CaseDesugar/CaseDesugarIr.hs new file mode 100644 index 0000000..dd9864f --- /dev/null +++ b/src/CaseDesugar/CaseDesugarIr.hs @@ -0,0 +1,226 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PatternSynonyms #-} + +module CaseDesugar.CaseDesugarIr ( + module Grammar.Abs, + module CaseDesugar.CaseDesugarIr, +) where + +import Data.String (IsString) +import Grammar.Abs (Lit (..)) +import Grammar.Print +import Prelude +import Prelude qualified as C (Eq, Ord, Read, Show) + +newtype Program' t = Program [Def' t] + deriving (C.Eq, C.Ord, C.Show, C.Read) + +data Def' t + = DBind (Bind' t) + | DData (Data' t) + deriving (C.Eq, C.Ord, C.Show, C.Read) + +data Type + = TLit Ident + | TVar TVar + | TData Ident [Type] + | TFun Type Type + deriving (C.Eq, C.Ord, C.Show, C.Read) + +data Data' t = Data t [Inj' t] + deriving (C.Eq, C.Ord, C.Show, C.Read) + +data Inj' t = Inj Ident t + deriving (C.Eq, C.Ord, C.Show, C.Read) + +newtype Ident = Ident String + deriving (C.Eq, C.Ord, C.Show, C.Read, IsString) + +data Pattern' t + = PVar (Id' t) -- TODO should be Ident + | PLit (Lit, t) -- TODO should be Lit + | PCatch + | PEnum Ident + deriving (C.Eq, C.Ord, C.Show, C.Read) + +data Exp' t + = EVar Ident + | EInj Ident + | ELit Lit + | ELet (Bind' t) (ExpT' t) + | EApp (ExpT' t) (ExpT' t) + | EAdd (ExpT' t) (ExpT' t) + | EAbs Ident (ExpT' t) + | ECase (ExpT' t) [Branch' t] + deriving (C.Eq, C.Ord, C.Show, C.Read) + +newtype TVar = MkTVar Ident + deriving (C.Eq, C.Ord, C.Show, C.Read) + +type Id' t = (Ident, t) +type ExpT' t = (Exp' t, t) + +data Bind' t = Bind (Id' t) [Id' t] (ExpT' t) + deriving (C.Eq, C.Ord, C.Show, C.Read) + +data Branch' t = Branch (Pattern' t, t) (ExpT' t) + deriving (C.Eq, C.Ord, C.Show, C.Read) + +instance Print Ident where + prt _ (Ident s) = doc $ showString s + +instance Print t => Print (Program' t) where + prt i (Program sc) = prPrec i 0 $ prt 0 sc + +instance Print t => Print (Bind' t) where + prt i (Bind sig@(name, _) parms rhs) = + prPrec i 0 $ + concatD + [ prtSig sig + , prt 0 name + , prtIdPs 0 parms + , doc $ showString "=" + , prt 0 rhs + ] + +prtSig :: Print t => Id' t -> Doc +prtSig (name, t) = + concatD + [ prt 0 name + , doc $ showString ":" + , prt 0 t + , doc $ showString ";" + ] + +instance Print t => Print (ExpT' t) where + prt i (e, t) = + concatD + [ doc $ showString "(" + , prt i e + , doc $ showString "," + , prt i t + , doc $ showString ")" + ] + +instance Print t => Print [Bind' t] where + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] + +prtIdPs :: Print t => Int -> [Id' t] -> Doc +prtIdPs i = prPrec i 0 . concatD . map (prt i) + +instance Print t => Print (Id' t) where + prt i (name, t) = + concatD + [ doc $ showString "(" + , prt i name + , doc $ showString "," + , prt i t + , doc $ showString ")" + ] + +instance Print t => Print (Exp' t) where + prt i = \case + EVar name -> prPrec i 3 $ prt 0 name + EInj name -> prPrec i 3 $ prt 0 name + ELit lit -> prPrec i 3 $ prt 0 lit + ELet b e -> + prPrec i 3 $ + concatD + [ doc $ showString "let" + , prt 0 b + , doc $ showString "in" + , prt 0 e + ] + EApp e1 e2 -> + prPrec i 2 $ + concatD + [ prt 2 e1 + , prt 3 e2 + ] + EAdd e1 e2 -> + prPrec i 1 $ + concatD + [ prt 1 e1 + , doc $ showString "+" + , prt 2 e2 + ] + EAbs v e -> + prPrec i 0 $ + concatD + [ doc $ showString "\\" + , prt 0 v + , doc $ showString "." + , prt 0 e + ] + ECase e branches -> + prPrec i 0 $ + concatD + [ doc $ showString "case" + , prt 0 e + , doc $ showString "of" + , doc $ showString "{" + , prt 0 branches + , doc $ showString "}" + ] + +instance Print t => Print (Branch' t) where + prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp]) + +instance Print t => Print [Branch' t] where + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] + +instance Print t => Print (Def' t) where + prt i = \case + DBind bind -> prPrec i 0 (concatD [prt 0 bind]) + DData data_ -> prPrec i 0 (concatD [prt 0 data_]) + +instance Print t => Print (Data' t) where + prt i = \case + Data type_ injs -> prPrec i 0 (concatD [doc (showString "data"), prt 0 type_, doc (showString "where"), doc (showString "{"), prt 0 injs, doc (showString "}")]) + +instance Print t => Print (Inj' t) where + prt i = \case + Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_]) + +instance Print t => Print (Pattern' t) where + prt i = \case + PVar name -> prPrec i 1 (concatD [prt 0 name]) + PLit (lit, _) -> prPrec i 1 (concatD [prt 0 lit]) + PCatch -> prPrec i 1 (concatD [doc (showString "_")]) + PEnum name -> prPrec i 1 (concatD [prt 0 name]) + +instance Print t => Print [Def' t] where + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] + +instance Print [Type] where + prt _ [] = concatD [] + prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs] + +instance Print Type where + prt i = \case + TLit uident -> prPrec i 1 (concatD [prt 0 uident]) + TVar tvar -> prPrec i 1 (concatD [prt 0 tvar]) + TData uident types -> prPrec i 1 (concatD [prt 0 uident, doc (showString "("), prt 0 types, doc (showString ")")]) + TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2]) + +instance Print TVar where + prt i (MkTVar ident) = prt i ident + +type Program = Program' Type +type Def = Def' Type +type Data = Data' Type +type Bind = Bind' Type +type Branch = Branch' Type +type Pattern = Pattern' Type +type Inj = Inj' Type +type Exp = Exp' Type +type ExpT = ExpT' Type +type Id = Id' Type +pattern DBind' id vars expt = DBind (Bind id vars expt) +pattern DData' typ injs = DData (Data typ injs) diff --git a/src/Codegen/Auxillary.hs b/src/Codegen/Auxillary.hs new file mode 100644 index 0000000..c95be39 --- /dev/null +++ b/src/Codegen/Auxillary.hs @@ -0,0 +1,51 @@ +module Codegen.Auxillary where + +import Codegen.LlvmIr (LLVMType (..), LLVMValue (..)) +import Control.Monad (foldM_) +import Monomorphizer.MonomorphizerIr as MIR (ExpT, Type (..)) +import TypeChecker.TypeCheckerIr qualified as TIR + +type2LlvmType :: MIR.Type -> LLVMType +type2LlvmType (MIR.TLit id@(TIR.Ident name)) = case name of + "Int" -> I64 + "Char" -> I8 + "Bool" -> I1 + _ -> CustomType id +type2LlvmType (MIR.TFun t xs) = do + let (t', xs') = function2LLVMType xs [type2LlvmType t] + Function t' xs' + where + function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType]) + function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s) + function2LLVMType x s = (type2LlvmType x, s) + +getType :: ExpT -> LLVMType +getType (_, t) = type2LlvmType t + +extractTypeName :: MIR.Type -> TIR.Ident +extractTypeName (MIR.TLit id) = id +extractTypeName (MIR.TFun t xs) = + let (TIR.Ident i) = extractTypeName t + (TIR.Ident is) = extractTypeName xs + in TIR.Ident $ i <> "_$_" <> is + +valueGetType :: LLVMValue -> LLVMType +valueGetType (VInteger _) = I64 +valueGetType (VChar _) = I8 +valueGetType (VIdent _ t) = t +valueGetType (VConstant s) = Array (fromIntegral $ length s) I8 +valueGetType (VFunction _ _ t) = t + +typeByteSize :: LLVMType -> Integer +typeByteSize I1 = 1 +typeByteSize I8 = 1 +typeByteSize I32 = 4 +typeByteSize I64 = 8 +typeByteSize Ptr = 8 +typeByteSize (Ref _) = 8 +typeByteSize (Function _ _) = 8 +typeByteSize (Array n t) = n * typeByteSize t +typeByteSize (CustomType _) = 8 + +enumerateOneM_ :: Monad m => (Integer -> a -> m b) -> [a] -> m () +enumerateOneM_ f = foldM_ (\i a -> f i a >> pure (i + 1)) 1 diff --git a/src/Codegen/Codegen.hs b/src/Codegen/Codegen.hs new file mode 100644 index 0000000..be92a35 --- /dev/null +++ b/src/Codegen/Codegen.hs @@ -0,0 +1,35 @@ +module Codegen.Codegen (generateCode) where + +import Codegen.CompilerState ( + CodeGenerator (instructions), + initCodeGenerator, + ) +import Codegen.Emits (compileScs) +import Codegen.LlvmIr as LIR (llvmIrToString) +import Control.Monad.State ( + execStateT, + ) +import Data.List (sortBy) +import Grammar.ErrM (Err) +import Monomorphizer.MonomorphizerIr as MIR (Bind (..), Data (..), Def (DBind, DData), Program (..), Type (TLit)) +import TypeChecker.TypeCheckerIr (Ident (..)) + +{- | Compiles an AST and produces a LLVM Ir string. + An easy way to actually "compile" this output is to + Simply pipe it to LLI +-} +generateCode :: MIR.Program -> Bool -> Err String +generateCode (MIR.Program scs) addGc = do + let tree = filter (not . detectPrelude) (sortBy lowData scs) + let codegen = initCodeGenerator addGc tree + llvmIrToString . instructions <$> execStateT (compileScs tree) codegen + +detectPrelude :: Def -> Bool +detectPrelude (DData (Data (TLit (Ident "Bool")) _)) = True +detectPrelude (DBind (Bind (Ident ('l' : 't' : '$' : _), _) _ _)) = True +detectPrelude _ = False + +lowData :: Def -> Def -> Ordering +lowData (DData _) (DBind _) = LT +lowData (DBind _) (DData _) = GT +lowData _ _ = EQ \ No newline at end of file diff --git a/src/Codegen/CompilerState.hs b/src/Codegen/CompilerState.hs new file mode 100644 index 0000000..523cc54 --- /dev/null +++ b/src/Codegen/CompilerState.hs @@ -0,0 +1,147 @@ +module Codegen.CompilerState where + +import Auxiliary (snoc) +import Codegen.Auxillary (type2LlvmType, typeByteSize) +import Codegen.LlvmIr as LIR (LLVMIr (UnsafeRaw), + LLVMType) +import Control.Monad.State (StateT, gets, modify) +import Data.Map (Map) +import qualified Data.Map as Map +import Grammar.ErrM (Err) +import Monomorphizer.MonomorphizerIr as MIR +import qualified TypeChecker.TypeCheckerIr as TIR + +-- | The record used as the code generator state +data CodeGenerator = CodeGenerator + { instructions :: [LLVMIr] + , functions :: Map MIR.Id FunctionInfo + , customTypes :: Map LLVMType Integer + , constructors :: Map TIR.Ident ConstructorInfo + , variableCount :: Integer + , labelCount :: Integer + , gcEnabled :: Bool + } + +-- | A state type synonym +type CompilerState a = StateT CodeGenerator Err a + +data FunctionInfo = FunctionInfo + { numArgs :: Int + , arguments :: [Id] + } + deriving (Show) +data ConstructorInfo = ConstructorInfo + { numArgsCI :: Int + , argumentsCI :: [Id] + , numCI :: Integer + , returnTypeCI :: MIR.Type + } + deriving (Show) + +-- | Adds a instruction to the CodeGenerator state +emit :: LLVMIr -> CompilerState () +emit l = modify $ \t -> t{instructions = Auxiliary.snoc l $ instructions t} + +-- | Increases the variable counter in the CodeGenerator state +increaseVarCount :: CompilerState () +increaseVarCount = modify $ \t -> t{variableCount = variableCount t + 1} + +-- | Returns the variable count from the CodeGenerator state +getVarCount :: CompilerState Integer +getVarCount = gets variableCount + +-- | Increases the variable count and returns it from the CodeGenerator state +getNewVar :: CompilerState TIR.Ident +getNewVar = TIR.Ident . show <$> (increaseVarCount >> getVarCount) + +-- | Increses the label count and returns a label from the CodeGenerator state +getNewLabel :: CompilerState Integer +getNewLabel = do + modify (\t -> t{labelCount = labelCount t + 1}) + gets labelCount + +{- | Produces a map of functions infos from a list of binds, + which contains useful data for code generation. +-} +getFunctions :: [MIR.Def] -> Map Id FunctionInfo +getFunctions bs = Map.fromList $ go bs + where + go [] = [] + go (MIR.DBind (MIR.Bind id args _) : xs) = + (id, FunctionInfo{numArgs = length args, arguments = args}) + : go xs + go (_ : xs) = go xs + +createArgs :: [MIR.Type] -> [Id] +createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(TIR.Ident ("arg_" <> show l), t)], l + 1)) ([], 0) xs + +{- | Produces a map of functions infos from a list of binds, + which contains useful data for code generation. +-} +getConstructors :: [MIR.Def] -> Map TIR.Ident ConstructorInfo +getConstructors bs = Map.fromList $ go bs + where + go [] = [] + go (MIR.DData (MIR.Data t cons) : xs) = + fst + ( foldl + ( \(acc, i) (Inj id xs) -> + ( ( id + , ConstructorInfo + { numArgsCI = length (init . flattenType $ xs) + , argumentsCI = createArgs (init . flattenType $ xs) + , numCI = i + , returnTypeCI = t -- last . flattenType $ xs + } + ) + : acc + , i + 1 + ) + ) + ([], 0) + cons + ) + <> go xs + go (_ : xs) = go xs + +getTypes :: [MIR.Def] -> Map LLVMType Integer +getTypes bs = Map.fromList $ go bs + where + go [] = [] + go (MIR.DData (MIR.Data t ts) : xs) = (type2LlvmType t, biggestVariant ts) : go xs + go (_ : xs) = go xs + variantTypes fi = init $ map type2LlvmType (flattenType fi) + biggestVariant ts = 8 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts) + +initCodeGenerator :: Bool -> [MIR.Def] -> CodeGenerator +initCodeGenerator addGc scs = + CodeGenerator + { instructions = defaultStart <> if addGc then gcStart else [] + , functions = getFunctions scs + , constructors = getConstructors scs + , customTypes = getTypes scs + , variableCount = 0 + , labelCount = 0 + , gcEnabled = addGc + } + +defaultStart :: [LLVMIr] +defaultStart = + [ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n" + , UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n" + , UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n" + , UnsafeRaw "@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c\"Non-exhaustive patterns in case at %i:%i\n\"\n" + , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" + , UnsafeRaw "declare i32 @exit(i32 noundef)\n" + , UnsafeRaw "declare ptr @malloc(i32 noundef)\n" + ] + +gcStart :: [LLVMIr] +gcStart = + [ UnsafeRaw "declare external void @cheap_init()\n" + , UnsafeRaw "declare external ptr @cheap_alloc(i64)\n" + , UnsafeRaw "declare external void @cheap_dispose()\n" + , UnsafeRaw "declare external ptr @cheap_the()\n" + , UnsafeRaw "declare external void @cheap_set_profiler(ptr, i1)\n" + , UnsafeRaw "declare external void @cheap_profiler_log_options(ptr, i64)\n" + ] diff --git a/src/Codegen/Emits.hs b/src/Codegen/Emits.hs new file mode 100644 index 0000000..bc19f87 --- /dev/null +++ b/src/Codegen/Emits.hs @@ -0,0 +1,392 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} + +module Codegen.Emits where + +import Codegen.Auxillary +import Codegen.CompilerState +import Codegen.LlvmIr as LIR +import Control.Applicative ((<|>)) +import Control.Monad (when) +import Control.Monad.State (gets, modify) +import Data.Bifunctor qualified as BI +import Data.Char (ord) +import Data.Coerce (coerce) +import Data.Map qualified as Map +import Data.Maybe (fromJust, fromMaybe, isNothing) +import Data.Tuple.Extra (dupe, first, second) +import Debug.Trace (trace, traceShow) +import Grammar.Print +import Monomorphizer.MonomorphizerIr as MIR +import TypeChecker.TypeCheckerIr qualified as TIR + +compileScs :: [MIR.Def] -> CompilerState () +compileScs [] = do + emit $ UnsafeRaw "\n" + -- as a last step create all the constructors + -- //TODO maybe merge this with the data type match? + c <- gets (Map.toList . constructors) + mapM_ + ( \(id, ci) -> do + let t = returnTypeCI ci + let t' = type2LlvmType t + let x = BI.second type2LlvmType <$> argumentsCI ci + emit $ Define FastCC t' id x + top <- getNewVar + ptr <- getNewVar + -- allocated the primary type + emit $ SetVariable top (Alloca t') + + -- set the first byte to the index of the constructor + emit $ + SetVariable ptr $ + GetElementPtr + t' + (Ref t') + (VIdent top I8) + I64 + (VInteger 0) + I32 + (VInteger 0) + emit $ Store I8 (VInteger $ numCI ci) (Ref I8) ptr + + -- get a pointer of the correct type + ptr' <- getNewVar + emit $ SetVariable ptr' (Bitcast (Ref t') (VIdent top Ptr) (Ref $ CustomType id)) + cTypes <- gets customTypes + + enumerateOneM_ + ( \i (TIR.Ident arg_n, arg_t) -> do + let arg_t' = type2LlvmType arg_t + emit $ Comment (toIr arg_t' <> " " <> arg_n <> " " <> show i) + elemPtr <- getNewVar + emit $ + SetVariable + elemPtr + ( GetElementPtr + (CustomType id) + (Ref (CustomType id)) + (VIdent ptr' Ptr) + I64 + (VInteger 0) + I32 + (VInteger i) + ) + case Map.lookup arg_t' cTypes of + Just s -> do + emit $ Comment "Malloc and store" + heapPtr <- getNewVar + useGc <- gets gcEnabled + emit $ SetVariable heapPtr (if useGc then GcMalloc s else Malloc s) + emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr heapPtr + emit $ Store (Ref arg_t') (VIdent heapPtr arg_t') Ptr elemPtr + Nothing -> do + emit $ Comment "Just store" + emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr elemPtr + ) + (argumentsCI ci) + + -- load and return the constructed value + emit $ Comment "Return the newly constructed value" + load <- getNewVar + emit $ SetVariable load (Load t' Ptr top) + emit $ Ret t' (VIdent load t') + emit DefineEnd + emit $ UnsafeRaw "\n" + + modify $ \s -> s{variableCount = 0} + ) + c +compileScs (MIR.DBind (MIR.Bind (name, t) args exp) : xs) = do + let t_return = type2LlvmType . last . flattenType $ t + emit $ UnsafeRaw "\n" + emit . Comment $ show name <> ": " <> show exp + let args' = map (second type2LlvmType) args + emit $ Define FastCC t_return name args' + useGc <- gets gcEnabled + when (name == "main") (mapM_ emit (firstMainContent useGc)) + functionBody <- exprToValue exp + if name == "main" + then mapM_ emit $ lastMainContent useGc functionBody + else emit $ Ret t_return functionBody + emit DefineEnd + modify $ \s -> s{variableCount = 0} + compileScs xs +compileScs (MIR.DData (MIR.Data typ ts) : xs) = do + let (TIR.Ident outer_id) = extractTypeName typ + -- //TODO this could be extracted from the customTypes map + let variantTypes fi = init $ map type2LlvmType (flattenType fi) + let biggestVariant = 7 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts) + emit $ LIR.Type (TIR.Ident outer_id) [I8, Array biggestVariant I8] + typeSets <- gets customTypes + mapM_ + ( \(Inj inner_id fi) -> do + let types = (\s -> if Map.member s typeSets then Ref s else s) <$> variantTypes fi + emit $ LIR.Type inner_id (I8 : types) + ) + ts + compileScs xs + +firstMainContent :: Bool -> [LLVMIr] +firstMainContent True = + [ UnsafeRaw "%prof = call ptr @cheap_the()\n" + , UnsafeRaw "call void @cheap_set_profiler(ptr %prof, i1 true)\n" + , UnsafeRaw "call void @cheap_profiler_log_options(ptr %prof, i64 30)\n" + , UnsafeRaw "call void @cheap_init()\n" + ] +firstMainContent False = [] + +lastMainContent :: Bool -> LLVMValue -> [LLVMIr] +lastMainContent True var = + [ UnsafeRaw $ + "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> toIr var <> ")\n" + , UnsafeRaw "call void @cheap_dispose()\n" + , Ret I64 (VInteger 0) + ] +lastMainContent False var = + [ UnsafeRaw $ + "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> toIr var <> ")\n" + , Ret I64 (VInteger 0) + ] + +compileExp :: ExpT -> CompilerState () +compileExp (MIR.ELit lit, _t) = emitLit lit +compileExp (MIR.EAdd e1 e2, t) = emitAdd t e1 e2 +compileExp (MIR.EVar name, _t) = emitIdent name +compileExp (MIR.EApp e1 e2, t) = emitApp t e1 e2 +compileExp (MIR.ELet bind e, _) = emitLet bind e +compileExp (MIR.ECase e cs, t) = emitECased t e (map (t,) cs) + +emitLet :: MIR.Bind -> ExpT -> CompilerState () +emitLet (MIR.Bind id [] innerExp) e = do + evaled <- exprToValue innerExp + tempVar <- getNewVar + let t = type2LlvmType . snd $ innerExp + emit $ SetVariable tempVar (Alloca t) + emit $ Store (type2LlvmType . snd $ innerExp) evaled Ptr tempVar + emit $ SetVariable (fst id) (Load t Ptr tempVar) + compileExp e +emitLet b _ = error $ "Non empty argument list in let-bind " <> show b + +emitECased :: MIR.Type -> ExpT -> [(MIR.Type, Branch)] -> CompilerState () +emitECased t e cases = do + let cs = snd <$> cases + let ty = type2LlvmType t + let rt = type2LlvmType (snd e) + vs <- exprToValue e + lbl <- getNewLabel + let label = TIR.Ident $ "escape_" <> show lbl + stackPtr <- getNewVar + emit $ SetVariable stackPtr (Alloca ty) + mapM_ (emitCases rt ty label stackPtr vs) cs + -- crashLbl <- TIR.Ident . ("crash_" <>) . show <$> getNewLabel + -- emit $ Label crashLbl + var_num <- getVarCount + emit . UnsafeRaw $ "call i32 (ptr, ...) @printf(ptr noundef @.non_exhaustive_patterns, i64 noundef " <> show var_num <> ", i64 noundef 6)\n" + useGc <- gets gcEnabled + when useGc (emit . UnsafeRaw $ "call void @cheap_dispose()\n") + emit . UnsafeRaw $ "call i32 @exit(i32 noundef 1)\n" + mapM_ (const increaseVarCount) [0 .. 1] + emit $ Br label + emit $ Label label + res <- getNewVar + emit $ SetVariable res (Load ty Ptr stackPtr) + where + emitCases :: LLVMType -> LLVMType -> TIR.Ident -> TIR.Ident -> LLVMValue -> Branch -> CompilerState () + emitCases rt ty label stackPtr vs (Branch (MIR.PInj consId cs, _t) exp) = do + emit $ Comment "Inj" + cons <- gets constructors + let r = fromJust $ Map.lookup consId cons + + lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel + lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel + + consVal <- getNewVar + emit $ SetVariable consVal (ExtractValue rt vs 0) + + consCheck <- getNewVar + emit $ SetVariable consCheck (Icmp LLEq I8 (VIdent consVal I8) (VInteger $ numCI r)) + emit $ BrCond (VIdent consCheck ty) lbl_succPos lbl_failPos + emit $ Label lbl_succPos + + castPtr <- getNewVar + casted <- getNewVar + emit $ SetVariable castPtr (Alloca rt) + emit $ Store rt vs Ptr castPtr + emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castPtr) + enumerateOneM_ + ( \i c -> do + case c of + PVar (x, topT) -> do + let topT' = type2LlvmType topT + let botT' = CustomType (coerce consId) + emit . Comment $ "ident " <> toIr topT' + cTypes <- gets customTypes + if Map.member topT' cTypes + then do + deref <- getNewVar + emit $ SetVariable deref (ExtractValue botT' (VIdent casted Ptr) i) + emit $ SetVariable x (Load topT' Ptr deref) + else emit $ SetVariable x (ExtractValue botT' (VIdent casted Ptr) i) + PLit (_l, _t) -> error "Nested pattern matching to be implemented" + PInj _id _ps -> error "Nested pattern matching to be implemented" + PCatch -> pure () + PEnum _id -> error "Nested pattern matching to be implemented" + ) + cs + val <- exprToValue exp + emit $ Store ty val Ptr stackPtr + emit $ Br label + emit $ Label lbl_failPos + emitCases _rt ty label stackPtr vs (Branch (MIR.PLit (i, ct), t) exp) = do + emit $ Comment "Plit" + let i' = case i of + MIR.LInt i -> VInteger i + MIR.LChar i -> VChar (ord i) + ns <- getNewVar + lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel + lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel + emit $ SetVariable ns (Icmp LLEq (type2LlvmType ct) vs i') + emit $ BrCond (VIdent ns ty) lbl_succPos lbl_failPos + emit $ Label lbl_succPos + val <- exprToValue exp + emit $ Store ty val Ptr stackPtr + emit $ Br label + emit $ Label lbl_failPos + emitCases rt ty label stackPtr vs (Branch (MIR.PVar (id, _), _) exp) = do + emit $ Comment "Pvar" + -- //TODO this is pretty disgusting and would heavily benefit from a rewrite + valPtr <- getNewVar + emit $ SetVariable valPtr (Alloca rt) + emit $ Store rt vs Ptr valPtr + emit $ SetVariable id (Load rt Ptr valPtr) + val <- exprToValue exp + emit $ Store ty val Ptr stackPtr + emit $ Br label + lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel + emit $ Label lbl_failPos + emitCases rt ty label stackPtr vs (Branch (MIR.PEnum (TIR.Ident "True$Bool"), t) exp) = do + emitCases rt ty label stackPtr vs (Branch (MIR.PLit (MIR.LInt 1, TLit "Bool"), t) exp) + emitCases rt ty label stackPtr vs (Branch (MIR.PEnum (TIR.Ident "False$Bool"), _) exp) = do + emitCases rt ty label stackPtr vs (Branch (MIR.PLit (MIR.LInt 0, TLit "Bool"), t) exp) + emitCases rt ty label stackPtr vs br@(Branch (MIR.PEnum consId, _) exp) = do + emit $ Comment "Penum" + cons <- gets constructors + let r = Map.lookup consId cons + when (isNothing r) (error $ "Constructor: '" ++ printTree consId ++ "' does not exist in cons state:\n" ++ show cons ++ "\nin pattern\n'" ++ printTree br ++ "'\n") + + lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel + lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel + + consVal <- getNewVar + emit $ SetVariable consVal (ExtractValue rt vs 0) + + consCheck <- getNewVar + emit $ SetVariable consCheck (Icmp LLEq I8 (VIdent consVal I8) (VInteger $ numCI (fromJust r))) + emit $ BrCond (VIdent consCheck ty) lbl_succPos lbl_failPos + emit $ Label lbl_succPos + + castPtr <- getNewVar + casted <- getNewVar + emit $ SetVariable castPtr (Alloca rt) + emit $ Store rt vs Ptr castPtr + emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castPtr) + val <- exprToValue exp + emit $ Store ty val Ptr stackPtr + emit $ Br label + emit $ Label lbl_failPos + emitCases _ ty label stackPtr _ (Branch (MIR.PCatch, _) exp) = do + emit $ Comment "Pcatch" + val <- exprToValue exp + emit $ Store ty val Ptr stackPtr + emit $ Br label + lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel + emit $ Label lbl_failPos + +emitApp :: MIR.Type -> ExpT -> ExpT -> CompilerState () +emitApp rt e1 e2 = appEmitter e1 e2 [] + where + appEmitter :: ExpT -> ExpT -> [ExpT] -> CompilerState () + appEmitter e1 e2 stack = do + let newStack = e2 : stack + case e1 of + (MIR.EApp e1' e2', _) -> appEmitter e1' e2' newStack + (MIR.EVar name, t) -> do + args <- traverse exprToValue newStack + vs <- getNewVar + funcs <- gets functions + consts <- gets constructors + let visibility = + fromMaybe Local $ + Global <$ Map.lookup name consts + <|> Global <$ Map.lookup (name, t) funcs + -- this piece of code could probably be improved, i.e remove the double `const Global` + args' = map (first valueGetType . dupe) args + let call = + case name of + TIR.Ident ('l' : 't' : '$' : _) -> Icmp LLSlt I64 (snd (head args')) (snd (args' !! 1)) + TIR.Ident ('$' : 'm' : 'i' : 'n' : 'u' : 's' : '$' : '$' : _) -> Sub I64 (snd (head args')) (snd (args' !! 1)) + _ -> Call FastCC (type2LlvmType rt) visibility name args' + emit $ Comment $ show rt + emit $ SetVariable vs call + x -> error $ "The unspeakable happened: " <> show x + +emitIdent :: TIR.Ident -> CompilerState () +emitIdent id = do + -- !!this should never happen!! + emit $ Comment "This should not have happened!" + emit $ Variable id + emit $ UnsafeRaw "\n" + +emitLit :: MIR.Lit -> CompilerState () +emitLit i = do + -- !!this should never happen!! + let (i', t) = case i of + (MIR.LInt i'') -> (VInteger i'', I64) + (MIR.LChar i'') -> (VChar $ ord i'', I8) + varCount <- getNewVar + emit $ Comment "This should not have happened!" + emit $ SetVariable varCount (Add t i' (VInteger 0)) + +emitAdd :: MIR.Type -> ExpT -> ExpT -> CompilerState () +emitAdd t e1 e2 = do + v1 <- exprToValue e1 + v2 <- exprToValue e2 + v <- getNewVar + emit $ SetVariable v (Add (type2LlvmType t) v1 v2) + +exprToValue :: ExpT -> CompilerState LLVMValue +exprToValue = \case + (MIR.ELit i, _t) -> pure $ case i of + (MIR.LInt i) -> VInteger i + (MIR.LChar i) -> VChar $ ord i + (MIR.EVar (TIR.Ident "True$Bool"), _t) -> pure $ VInteger 1 + (MIR.EVar (TIR.Ident "False$Bool"), _t) -> pure $ VInteger 0 + (MIR.EVar name, t) -> do + funcs <- gets functions + cons <- gets constructors + let res = + Map.lookup (name, t) funcs + <|> ( \c -> + FunctionInfo + { numArgs = numArgsCI c + , arguments = argumentsCI c + } + ) + <$> Map.lookup name cons + case res of + Just fi -> do + if numArgs fi == 0 + then do + vc <- getNewVar + emit $ + SetVariable + vc + (Call FastCC (type2LlvmType t) Global name []) + pure $ VIdent vc (type2LlvmType t) + else pure $ VFunction name Global (type2LlvmType t) + Nothing -> pure $ VIdent name (type2LlvmType t) + e -> do + compileExp e + v <- getVarCount + pure $ VIdent (TIR.Ident $ show v) (getType e) diff --git a/src/Codegen/LlvmIr.hs b/src/Codegen/LlvmIr.hs new file mode 100644 index 0000000..cc77cf9 --- /dev/null +++ b/src/Codegen/LlvmIr.hs @@ -0,0 +1,271 @@ +{-# LANGUAGE LambdaCase #-} + +module Codegen.LlvmIr ( + LLVMType (..), + LLVMIr (..), + llvmIrToString, + LLVMValue (..), + LLVMComp (..), + Visibility (..), + CallingConvention (..), + ToIr (..), +) where + +import Data.List (intercalate) +import TypeChecker.TypeCheckerIr (Ident (..)) + +data CallingConvention = TailCC | FastCC | CCC | ColdCC deriving (Show, Eq, Ord) +instance ToIr CallingConvention where + toIr :: CallingConvention -> String + toIr TailCC = "tailcc" + toIr FastCC = "fastcc" + toIr CCC = "ccc" + toIr ColdCC = "coldcc" + +-- | A datatype which represents some basic LLVM types +data LLVMType + = I1 + | I8 + | I32 + | I64 + | Ptr + | Ref LLVMType + | Function LLVMType [LLVMType] + | Array Integer LLVMType + | CustomType Ident + deriving (Show, Eq, Ord) + +class ToIr a where + toIr :: a -> String + +instance ToIr LLVMType where + toIr :: LLVMType -> String + toIr = \case + I1 -> "i1" + I8 -> "i8" + I32 -> "i32" + I64 -> "i64" + Ptr -> "ptr" + Ref ty -> toIr ty <> "*" + Function t xs -> toIr t <> " (" <> intercalate ", " (map toIr xs) <> ")*" + Array n ty -> concat ["[", show n, " x ", toIr ty, "]"] + CustomType (Ident ty) -> "%" <> ty + +data LLVMComp + = LLEq + | LLNe + | LLUgt + | LLUge + | LLUlt + | LLUle + | LLSgt + | LLSge + | LLSlt + | LLSle + deriving (Show, Eq, Ord) +instance ToIr LLVMComp where + toIr :: LLVMComp -> String + toIr = \case + LLEq -> "eq" + LLNe -> "ne" + LLUgt -> "ugt" + LLUge -> "uge" + LLUlt -> "ult" + LLUle -> "ule" + LLSgt -> "sgt" + LLSge -> "sge" + LLSlt -> "slt" + LLSle -> "sle" + +data Visibility = Local | Global deriving (Show, Eq, Ord) +instance ToIr Visibility where + toIr :: Visibility -> String + toIr Local = "%" + toIr Global = "@" + +{- | Represents a LLVM "value", as in an integer, a register variable, +or a string contstant +-} +data LLVMValue + = VInteger Integer + | VChar Int + | VIdent Ident LLVMType + | VConstant String + | VFunction Ident Visibility LLVMType + deriving (Show, Eq, Ord) + +instance ToIr LLVMValue where + toIr :: LLVMValue -> String + toIr v = case v of + VInteger i -> show i + VChar i -> show i + VIdent (Ident n) _ -> "%" <> n + VFunction (Ident n) vis _ -> toIr vis <> n + VConstant s -> "c" <> show s + +type Params = [(Ident, LLVMType)] +type Args = [(LLVMType, LLVMValue)] + +-- | A datatype which represents different instructions in LLVM +data LLVMIr + = Type Ident [LLVMType] + | Define CallingConvention LLVMType Ident Params + | DefineEnd + | Declare LLVMType Ident Params + | SetVariable Ident LLVMIr + | Variable Ident + | ExtractValue LLVMType LLVMValue Integer + | GetElementPtr LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue + | GetElementPtrInbounds LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue + | Add LLVMType LLVMValue LLVMValue + | Sub LLVMType LLVMValue LLVMValue + | Div LLVMType LLVMValue LLVMValue + | Mul LLVMType LLVMValue LLVMValue + | Srem LLVMType LLVMValue LLVMValue + | Icmp LLVMComp LLVMType LLVMValue LLVMValue + | Br Ident + | BrCond LLVMValue Ident Ident + | Label Ident + | Call CallingConvention LLVMType Visibility Ident Args + | Alloca LLVMType + | Store LLVMType LLVMValue LLVMType Ident + | Load LLVMType LLVMType Ident + | Bitcast LLVMType LLVMValue LLVMType + | Ret LLVMType LLVMValue + | Comment String + | Malloc Integer + | GcMalloc Integer + | UnsafeRaw String -- This should generally be avoided, and proper + -- instructions should be used in its place + deriving (Show, Eq, Ord) + +-- | Converts a list of LLVMIr instructions to a string +llvmIrToString :: [LLVMIr] -> String +llvmIrToString = go 0 + where + go :: Int -> [LLVMIr] -> String + go _ [] = mempty + go i (x : xs) = do + let (i', n) = case x of + Define{} -> (i + 1, 0) + DefineEnd -> (i - 1, 0) + _ -> (i, i) + insToString n x <> go i' xs + +-- \| Converts a LLVM inststruction to a String, allowing for printing etc. +-- The integer represents the indentation +-- +{- FOURMOLU_DISABLE -} + insToString :: Int -> LLVMIr -> String + insToString i l = + replicate i '\t' <> case l of + (GetElementPtr t1 t2 p t3 v1 t4 v2) -> do + -- getelementptr inbounds %Foo, %Foo* %x, i32 0, i32 0 + concat + [ "getelementptr ", toIr t1, ", " , toIr t2 + , " ", toIr p, ", ", toIr t3, " ", toIr v1 + , ", ", toIr t4, " ", toIr v2, "\n" + ] + (ExtractValue t1 v i) -> do + concat + [ "extractvalue ", toIr t1, " " + , toIr v, ", ", show i, "\n" + ] + (GetElementPtrInbounds t1 t2 p t3 v1 t4 v2) -> do + -- getelementptr inbounds %Foo, %Foo* %x, i32 0, i32 0 + concat + [ "getelementptr inbounds ", toIr t1, ", " , toIr t2 + , " ", toIr p, ", ", toIr t3, " ", toIr v1, + ", ", toIr t4, " ", toIr v2, "\n" ] + (Type (Ident n) types) -> + concat + [ "%", n, " = type { " + , intercalate ", " (map toIr types) + , " }\n" + ] + (Define c t (Ident i) params) -> + concat + [ "define ", toIr c, " ", toIr t, " @", i + , "(", intercalate ", " (map (\(Ident y, x) -> unwords [toIr x, "%" <> y]) params) + , ") {\n" + ] + DefineEnd -> "}\n" + (Declare _t (Ident _i) _params) -> undefined + (SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir] + (Add t v1 v2) -> + concat + [ "add ", toIr t, " ", toIr v1 + , ", ", toIr v2, "\n" + ] + (Sub t v1 v2) -> + concat + [ "sub ", toIr t, " ", toIr v1, ", " + , toIr v2, "\n" + ] + (Div t v1 v2) -> + concat + [ "sdiv ", toIr t, " ", toIr v1, ", " + , toIr v2, "\n" + ] + (Mul t v1 v2) -> + concat + [ "mul ", toIr t, " ", toIr v1 + , ", ", toIr v2, "\n" + ] + (Srem t v1 v2) -> + concat + [ "srem ", toIr t, " ", toIr v1, ", " + , toIr v2, "\n" + ] + (Call c t vis (Ident i) arg) -> + concat + [ "call ", toIr c, " ", toIr t, " ", toIr vis, i, "(" + , intercalate ", " $ Prelude.map (\(x, y) -> toIr x <> " " <> toIr y) arg + , ")\n" + ] + (Alloca t) -> unwords ["alloca", toIr t, "\n"] + (Malloc t) -> + concat + [ "call ptr @malloc(i64 ", show t, ")\n"] + (GcMalloc t) -> + concat + [ "call ptr @cheap_alloc(i64 ", show t, ")\n"] + (Store t1 val t2 (Ident id2)) -> + concat + [ "store ", toIr t1, " ", toIr val + , ", ", toIr t2 , " %", id2, "\n" + ] + (Load t1 t2 (Ident addr)) -> + concat + [ "load ", toIr t1, ", " + , toIr t2, " %", addr, "\n" + ] + (Bitcast t1 v t2) -> + concat + [ "bitcast ", toIr t1, " " + , toIr v, " to ", toIr t2, "\n" + ] + (Icmp comp t v1 v2) -> + concat + [ "icmp ", toIr comp, " ", toIr t + , " ", toIr v1, ", ", toIr v2, "\n" + ] + (Ret t v) -> + concat + [ "ret ", toIr t, " " + , toIr v, "\n" + ] + (UnsafeRaw s) -> s + (Label (Ident s)) -> "\n" <> lblPfx <> s <> ":\n" + (Br (Ident s)) -> "br label %" <> lblPfx <> s <> "\n" + (BrCond val (Ident s1) (Ident s2)) -> + concat + [ "br i1 ", toIr val, ", ", "label %" + , lblPfx, s1, ", ", "label %", lblPfx, s2, "\n" + ] + (Comment s) -> "; " <> s <> "\n" + (Variable (Ident id)) -> "%" <> id +{- FOURMOLU_ENABLE -} + +lblPfx :: String +lblPfx = "lbl_" diff --git a/src/Compiler.hs b/src/Compiler.hs index fd6b6bc..3fb1fe1 100644 --- a/src/Compiler.hs +++ b/src/Compiler.hs @@ -1,266 +1,43 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedStrings #-} - module Compiler (compile) where -import Auxiliary (snoc) -import Control.Monad.State (StateT, execStateT, gets, modify) -import Data.Map (Map) -import qualified Data.Map as Map -import Data.Tuple.Extra (dupe, first, second) -import Grammar.ErrM (Err) -import LlvmIr (LLVMIr (..), LLVMType (..), - LLVMValue (..), Visibility (..), - llvmIrToString) -import TypeChecker (partitionType) -import TypeCheckerIr +import System.Process.Extra (readCreateProcess, shell) --- | The record used as the code generator state -data CodeGenerator = CodeGenerator - { instructions :: [LLVMIr] - , functions :: Map Id FunctionInfo - , variableCount :: Integer - } +-- spawnWait s = spawnCommand s >>= \s >>= waitForProcess --- | A state type synonym -type CompilerState a = StateT CodeGenerator Err a +optimize :: String -> IO String +optimize = readCreateProcess (shell "opt --O3 --tailcallopt -S") -data FunctionInfo = FunctionInfo - { numArgs :: Int - , arguments :: [Id] - } +compileClang :: Bool -> String -> IO String +compileClang False = + readCreateProcess . shell $ + unwords + [ "clang++" -- , "-Lsrc/GC/lib/", "-l:libgcoll.a" + , "-fno-rtti" + , "-x" + , "ir" -- , "-Lsrc/GC/lib -l:gcoll.a" + , "-o" + , "output/hello_world" + , "-" + ] +compileClang True = + readCreateProcess . shell $ + unwords + [ "clang++" -- , "-Lsrc/GC/lib/", "-l:libgcoll.a" + , "-fno-rtti" + , "src/GC/lib/cheap.cpp" + , "src/GC/lib/event.cpp" + , "src/GC/lib/heap.cpp" + , "src/GC/lib/profiler.cpp" + , "-Wall -Wextra -g -std=gnu++20 -stdlib=libstdc++" + , "-O3" + --, "-tailcallopt" + , "-Isrc/GC/include" + , "-x" + , "ir" -- , "-Lsrc/GC/lib -l:gcoll.a" + , "-o" + , "output/hello_world" + , "-" + ] --- | Adds a instruction to the CodeGenerator state -emit :: LLVMIr -> CompilerState () -emit l = modify $ \t -> t { instructions = snoc l $ instructions t } - --- | Increases the variable counter in the CodeGenerator state -increaseVarCount :: CompilerState () -increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 } - --- | Returns the variable count from the CodeGenerator state -getVarCount :: CompilerState Integer -getVarCount = gets variableCount - --- | Increases the variable count and returns it from the CodeGenerator state -getNewVar :: CompilerState Integer -getNewVar = increaseVarCount >> getVarCount - --- | Produces a map of functions infos from a list of binds, --- which contains useful data for code generation. -getFunctions :: [Bind] -> Map Id FunctionInfo -getFunctions bs = Map.fromList $ map go bs - where - go (Bind id args _) = - (id, FunctionInfo { numArgs=length args, arguments=args }) - - - -initCodeGenerator :: [Bind] -> CodeGenerator -initCodeGenerator scs = CodeGenerator { instructions = defaultStart - , functions = getFunctions scs - , variableCount = 0 - } - --- | Compiles an AST and produces a LLVM Ir string. --- An easy way to actually "compile" this output is to --- Simply pipe it to lli -compile :: Program -> Err String -compile (Program scs) = do - let codegen = initCodeGenerator scs - llvmIrToString . instructions <$> execStateT (compileScs scs) codegen - -compileScs :: [Bind] -> CompilerState () -compileScs [] = pure () -compileScs (Bind (name, t) args exp : xs) = do - emit $ UnsafeRaw "\n" - emit . Comment $ show name <> ": " <> show exp - let args' = map (second type2LlvmType) args - emit $ Define (type2LlvmType t_return) name args' - functionBody <- exprToValue exp - if name == "main" - then mapM_ emit $ mainContent functionBody - else emit $ Ret I64 functionBody - emit DefineEnd - modify $ \s -> s { variableCount = 0 } - compileScs xs - where - t_return = snd $ partitionType (length args) t - -mainContent :: LLVMValue -> [LLVMIr] -mainContent var = - [ UnsafeRaw $ - "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n" - , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2)) - -- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2") - -- , Label (Ident "b_1") - -- , UnsafeRaw - -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n" - -- , Br (Ident "end") - -- , Label (Ident "b_2") - -- , UnsafeRaw - -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n" - -- , Br (Ident "end") - -- , Label (Ident "end") - Ret I64 (VInteger 0) - ] - -defaultStart :: [LLVMIr] -defaultStart = [ UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n" - , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" - ] - -compileExp :: Exp -> CompilerState () -compileExp = \case - EInt i -> emitInt i - EAdd t e1 e2 -> emitAdd t e1 e2 - EId (name, _) -> emitIdent name - EApp t e1 e2 -> emitApp t e1 e2 - EAbs t ti e -> emitAbs t ti e - ELet bind e -> emitLet bind e - ---- aux functions --- -emitAbs :: Type -> Id -> Exp -> CompilerState () -emitAbs _t tid e = emit . Comment $ "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e - -emitLet :: Bind -> Exp -> CompilerState () -emitLet b e = emit . Comment $ concat [ "ELet (" - , show b - , " = " - , show e - , ") is not implemented!" - ] - -emitApp :: Type -> Exp -> Exp -> CompilerState () -emitApp t e1 e2 = appEmitter t e1 e2 [] - where - appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState () - appEmitter t e1 e2 stack = do - let newStack = e2 : stack - case e1 of - EApp _ e1' e2' -> appEmitter t e1' e2' newStack - EId id@(name, _) -> do - args <- traverse exprToValue newStack - vs <- getNewVar - funcs <- gets functions - let visibility = maybe Local (const Global) $ Map.lookup id funcs - args' = map (first valueGetType . dupe) args - call = Call (type2LlvmType t) visibility name args' - emit $ SetVariable (Ident $ show vs) call - x -> do - emit . Comment $ "The unspeakable happened: " - emit . Comment $ show x - -emitIdent :: Ident -> CompilerState () -emitIdent id = do - -- !!this should never happen!! - emit $ Comment "This should not have happened!" - emit $ Variable id - emit $ UnsafeRaw "\n" - -emitInt :: Integer -> CompilerState () -emitInt i = do - -- !!this should never happen!! - varCount <- getNewVar - emit $ Comment "This should not have happened!" - emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0)) - -emitAdd :: Type -> Exp -> Exp -> CompilerState () -emitAdd t e1 e2 = do - v1 <- exprToValue e1 - v2 <- exprToValue e2 - v <- getNewVar - emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2) - --- emitMul :: Exp -> Exp -> CompilerState () --- emitMul e1 e2 = do --- (v1,v2) <- binExprToValues e1 e2 --- increaseVarCount --- v <- gets variableCount --- emit $ SetVariable $ Ident $ show v --- emit $ Mul I64 v1 v2 - --- emitMod :: Exp -> Exp -> CompilerState () --- emitMod e1 e2 = do --- -- `let m a b = rem (abs $ b + a) b` --- (v1,v2) <- binExprToValues e1 e2 --- increaseVarCount --- vadd <- gets variableCount --- emit $ SetVariable $ Ident $ show vadd --- emit $ Add I64 v1 v2 --- --- increaseVarCount --- vabs <- gets variableCount --- emit $ SetVariable $ Ident $ show vabs --- emit $ Call I64 (Ident "llvm.abs.i64") --- [ (I64, VIdent (Ident $ show vadd)) --- , (I1, VInteger 1) --- ] --- increaseVarCount --- v <- gets variableCount --- emit $ SetVariable $ Ident $ show v --- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2 - --- emitDiv :: Exp -> Exp -> CompilerState () --- emitDiv e1 e2 = do --- (v1,v2) <- binExprToValues e1 e2 --- increaseVarCount --- v <- gets variableCount --- emit $ SetVariable $ Ident $ show v --- emit $ Div I64 v1 v2 - --- emitSub :: Exp -> Exp -> CompilerState () --- emitSub e1 e2 = do --- (v1,v2) <- binExprToValues e1 e2 --- increaseVarCount --- v <- gets variableCount --- emit $ SetVariable $ Ident $ show v --- emit $ Sub I64 v1 v2 - -exprToValue :: Exp -> CompilerState LLVMValue -exprToValue = \case - EInt i -> pure $ VInteger i - - EId id@(name, t) -> do - funcs <- gets functions - case Map.lookup id funcs of - Just fi -> do - if numArgs fi == 0 - then do - vc <- getNewVar - emit $ SetVariable (Ident $ show vc) - (Call (type2LlvmType t) Global name []) - pure $ VIdent (Ident $ show vc) (type2LlvmType t) - else pure $ VFunction name Global (type2LlvmType t) - Nothing -> pure $ VIdent name (type2LlvmType t) - - e -> do - compileExp e - v <- getVarCount - pure $ VIdent (Ident $ show v) (getType e) - -type2LlvmType :: Type -> LLVMType -type2LlvmType = \case - TInt -> I64 - TFun t xs -> do - let (t', xs') = function2LLVMType xs [type2LlvmType t] - Function t' xs' - t -> CustomType $ Ident ("\"" ++ show t ++ "\"") - where - function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType]) - function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s) - function2LLVMType x s = (type2LlvmType x, s) - -getType :: Exp -> LLVMType -getType (EInt _) = I64 -getType (EAdd t _ _) = type2LlvmType t -getType (EId (_, t)) = type2LlvmType t -getType (EApp t _ _) = type2LlvmType t -getType (EAbs t _ _) = type2LlvmType t -getType (ELet _ e) = getType e - -valueGetType :: LLVMValue -> LLVMType -valueGetType (VInteger _) = I64 -valueGetType (VIdent _ t) = t -valueGetType (VConstant s) = Array (length s) I8 -valueGetType (VFunction _ _ t) = t +compile :: String -> Bool -> IO String +compile s addGc = optimize s >>= compileClang addGc diff --git a/src/Desugar/Desugar.hs b/src/Desugar/Desugar.hs new file mode 100644 index 0000000..550d7c3 --- /dev/null +++ b/src/Desugar/Desugar.hs @@ -0,0 +1,117 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} + +module Desugar.Desugar (desugar) where + +import Grammar.Abs + +{- + +The entire module should never have any catch all pattern matches as that +will disble warnings for when the grammar is expanded. + +-} + +desugar :: Program -> Program +desugar (Program defs) = Program (map desugarDef defs) + +desugarVarName :: VarName -> LIdent +desugarVarName (VSymbol (Symbol i)) = LIdent $ fixName i +desugarVarName (VIdent i) = i + +desugarDef :: Def -> Def +desugarDef = \case + DBind b -> DBind (desugarBind b) + DSig sig -> DSig (desugarSig sig) + DData d -> DData (desugarData d) + +desugarBind :: Bind -> Bind +desugarBind (BindS name args e) = Bind (desugarVarName name) args (desugarExp e) +desugarBind (Bind name args e) = Bind name args (desugarExp e) + +desugarSig :: Sig -> Sig +desugarSig (SigS ident typ) = Sig (desugarVarName ident) (desugarType typ) +desugarSig (Sig ident typ) = Sig ident (desugarType typ) + +desugarData :: Data -> Data +desugarData (Data typ injs) = Data (desugarType typ) (map desugarInj injs) + +desugarType :: Type -> Type +desugarType = \case + TIdent (UIdent "Int") -> TLit "Int" + TIdent (UIdent "Char") -> TLit "Char" + TIdent ident -> TData ident [] + TApp t1 t2 -> + let (name : tvars) = flatten t1 ++ [t2] + in case name of + TIdent ident -> TData ident (map desugarType tvars) + _ -> error "desugarType is not implemented correctly" + TLit l -> TLit l + TVar v -> TVar v + (TAll i t) -> TAll i (desugarType t) + TFun t1 t2 -> TFun (desugarType t1) (desugarType t2) + TEVar v -> TEVar v + TData ident typ -> TData ident (map desugarType typ) + where + flatten :: Type -> [Type] + flatten (TApp a b) = flatten a <> flatten b + flatten a = [a] + +desugarInj :: Inj -> Inj +desugarInj (Inj ident typ) = Inj ident (desugarType typ) + +desugarExp :: Exp -> Exp +desugarExp = \case + EApp e1 e2 -> EApp (desugarExp e1) (desugarExp e2) + EAdd e1 e2 -> EAdd (desugarExp e1) (desugarExp e2) + EAbs i e -> EAbs i (desugarExp e) + ELet b e -> ELet (desugarBind b) (desugarExp e) + ECase e br -> ECase (desugarExp e) (map desugarBranch br) + EAnn e t -> EAnn (desugarExp e) t + EVarS (VSymbol (Symbol symb)) -> EVar (LIdent $ fixName symb) + EVarS (VIdent (LIdent ident)) -> EVar $ LIdent $ fixName ident + EVar i -> EVar i + ELit l -> ELit l + EInj i -> EInj i + +desugarBranch :: Branch -> Branch +desugarBranch (Branch p e) = Branch (desugarPattern p) (desugarExp e) + +desugarPattern :: Pattern -> Pattern +desugarPattern = \case + PVar ident -> PVar ident + PLit lit -> PLit (desugarLit lit) + PCatch -> PCatch + PEnum ident -> PEnum ident + PInj ident patterns -> PInj ident (map desugarPattern patterns) + +desugarLit :: Lit -> Lit +desugarLit (LInt i) = LInt i +desugarLit (LChar c) = LChar c + +fixName :: String -> String +fixName = concatMap mapSymbols + where + mapSymbols :: Char -> String + mapSymbols c = case c of + '@' -> "$at$" + '#' -> "$octothorpe$" + '%' -> "$percent$" + '^' -> "$hat$" + '&' -> "$and$" + '*' -> "$star$" + '_' -> "$underscore$" + '-' -> "$minus$" + '+' -> "$plus$" + '=' -> "$equals$" + '|' -> "$pipe$" + '?' -> "$questionmark$" + '/' -> "$fslash$" + '<' -> "$langle$" + '>' -> "$rangle$" + ',' -> "$comma$" + '•' -> "$bullet$" + ':' -> "$semicolon$" + '[' -> "$lbracket$" + ']' -> "$rbracket$" + c -> c : "" diff --git a/src/Interpreter.hs b/src/Interpreter.hs deleted file mode 100644 index 37d46a7..0000000 --- a/src/Interpreter.hs +++ /dev/null @@ -1,116 +0,0 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedRecordDot #-} -{-# LANGUAGE OverloadedStrings #-} -module Interpreter where - -import Auxiliary (maybeToRightM) -import Control.Applicative (Applicative) -import Control.Monad.Except (Except, MonadError (throwError), - liftEither) -import Control.Monad.State (MonadState, StateT, evalStateT) -import Data.Either.Combinators (maybeToRight) -import Data.Map (Map) -import qualified Data.Map as Map -import Data.Maybe (maybe) -import Grammar.Abs -import Grammar.ErrM (Err) -import Grammar.Print (printTree) - -interpret :: Program -> Err Integer -interpret (Program scs) = do - main <- findMain scs - eval (initCxt scs) main >>= - \case - VClosure {} -> throwError "main evaluated to a function" - VInt i -> pure i - - -initCxt :: [Bind] -> Cxt -initCxt scs = - Cxt { env = mempty - , sig = foldr insert mempty $ map expandLambdas scs - } - where insert (Bind name _ rhs) = Map.insert name rhs - -expandLambdas :: Bind -> Bind -expandLambdas (Bind name parms rhs) = Bind name [] $ foldr EAbs rhs parms - -findMain :: [Bind] -> Err Exp -findMain [] = throwError "No main!" -findMain (sc:scs) = case sc of - Bind "main" _ rhs -> pure rhs - _ -> findMain scs - -data Val = VInt Integer - | VClosure Env Ident Exp - deriving (Show, Eq) - -type Env = Map Ident Val -type Sig = Map Ident Exp - -data Cxt = Cxt - { env :: Map Ident Val - , sig :: Map Ident Exp - } deriving (Show, Eq) - -eval :: Cxt -> Exp -> Err Val -eval cxt = \case - - -- ------------ x ∈ γ - -- γ ⊢ x ⇓ γ(x) - - EId x -> do - case Map.lookup x cxt.env of - Just e -> pure e - Nothing -> - case Map.lookup x cxt.sig of - Just e -> eval (emptyEnv cxt) e - Nothing -> throwError ("Unbound variable: " ++ printTree x) - - -- --------- - -- γ ⊢ i ⇓ i - - EInt i -> pure $ VInt i - - -- γ ⊢ e ⇓ let δ in λx. f - -- γ ⊢ e₁ ⇓ v - -- δ,x=v ⊢ f ⇓ v₁ - -- ------------------------------ - -- γ ⊢ e e₁ ⇓ v₁ - - EApp e e1 -> - eval cxt e >>= \case - VInt _ -> throwError "Not a function" - VClosure delta x f -> do - v <- eval cxt e1 - let cxt' = putEnv (Map.insert x v delta) cxt - eval cxt' f - - - -- - -- ----------------------------- - -- γ ⊢ λx. f ⇓ let γ in λx. f - - EAbs par e -> pure $ VClosure cxt.env par e - - - -- γ ⊢ e ⇓ v - -- γ ⊢ e₁ ⇓ v₁ - -- ------------------ - -- γ ⊢ e e₁ ⇓ v + v₁ - - EAdd e e1 -> do - v <- eval cxt e - v1 <- eval cxt e1 - case (v, v1) of - (VInt i, VInt i1) -> pure $ VInt (i + i1) - _ -> throwError "Can't add a function" - - ELet _ _ -> throwError "ELet pattern match should never occur!" - - -emptyEnv :: Cxt -> Cxt -emptyEnv cxt = cxt { env = mempty } - -putEnv :: Env -> Cxt -> Cxt -putEnv env cxt = cxt { env = env } diff --git a/src/LambdaLifter.hs b/src/LambdaLifter.hs index 015e7f3..5581814 100644 --- a/src/LambdaLifter.hs +++ b/src/LambdaLifter.hs @@ -1,138 +1,249 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedRecordDot #-} +{-# LANGUAGE OverloadedStrings #-} -module LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where +module LambdaLifter where -import Auxiliary (snoc) -import Control.Applicative (Applicative (liftA2)) -import Control.Monad.State (MonadState (get, put), State, evalState) -import Data.Set (Set) -import qualified Data.Set as Set -import Prelude hiding (exp) -import Renamer -import TypeCheckerIr +import Auxiliary (onM, snoc) +import Control.Applicative (Applicative (liftA2)) +import Control.Monad.State (MonadState (get, put), State, + evalState) +import Data.Function (on) +import Data.List (delete, mapAccumL, (\\)) +import Prelude hiding (exp) +import TypeChecker.TypeCheckerIr -- | Lift lambdas and let expression into supercombinators. -- Three phases: --- @freeVars@ annotatss all the free variables. +-- @freeVars@ annotates all the free variables. -- @abstract@ converts lambdas into let expressions. -- @collectScs@ moves every non-constant let expression to a top-level function. +-- lambdaLift :: Program -> Program -lambdaLift = collectScs . abstract . freeVars - +lambdaLift (Program ds) = Program (datatypes ++ binds) + where + datatypes = flip filter ds $ \case DData _ -> True + _ -> False + binds = map DBind $ (collectScs . abstract . freeVars) [b | DBind b <- ds] -- | Annotate free variables -freeVars :: Program -> AnnProgram -freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e) - | Bind n xs e <- ds - ] +freeVars :: [Bind] -> [ABind] +freeVars binds = [ let ae = freeVarsExp [] e + ae' = ae { frees = ae.frees \\ xs } + in ABind n xs ae' + | Bind n xs e <- binds + ] -freeVarsExp :: Set Id -> Exp -> AnnExp -freeVarsExp localVars = \case - EId n | Set.member n localVars -> (Set.singleton n, AId n) - | otherwise -> (mempty, AId n) +freeVarsExp :: Frees -> ExpT -> Ann AExpT +freeVarsExp localVars (ae, t) = case ae of + EVar n | elem (n,t) localVars -> Ann { frees = [(n, t)] + , term = (AVar n, t) + } + | otherwise -> Ann { frees = [] + , term = (AVar n, t) + } - EInt i -> (mempty, AInt i) + EInj n -> Ann { frees = [], term = (AInj n, t) } - EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2') + ELit lit -> Ann { frees = [], term = (ALit lit, t) } + + EApp e1 e2 -> Ann { frees = annae1.frees <|| annae2.frees + , term = (AApp annae1 annae2, t) + } where - e1' = freeVarsExp localVars e1 - e2' = freeVarsExp localVars e2 + (annae1, annae2) = on (,) (freeVarsExp localVars) e1 e2 - EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2') + EAdd e1 e2 -> Ann { frees = annae1.frees <|| annae2.frees + , term = (AAdd annae1 annae2, t) + } where - e1' = freeVarsExp localVars e1 - e2' = freeVarsExp localVars e2 + (annae1, annae2) = on (,) (freeVarsExp localVars) e1 e2 - EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e') + + EAbs x e -> Ann { frees = delete (x,t_x) $ annae.frees + , term = (AAbs x annae, t) } where - e' = freeVarsExp (Set.insert par localVars) e + annae = freeVarsExp (localVars <| (x,t_x)) e + t_x = case t of TFun t _ -> t + _ -> error "Impossible" -- Sum free variables present in bind and the expression - ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e') + -- let f x = x + y in f 5 + z → frees: y, z + ELet bind@(Bind n _ _) e -> + Ann { frees = delete n annae.frees <|| annbind.frees + , term = (ALet annbind annae, t) + } where - binders_frees = Set.delete name $ freeVarsOf rhs' - e_free = Set.delete name $ freeVarsOf e' + annae = freeVarsExp (localVars <| n) e + annbind = freeVarsBind localVars bind - rhs' = freeVarsExp e_localVars rhs - new_bind = ABind name parms rhs' - - e' = freeVarsExp e_localVars e - e_localVars = Set.insert name localVars + ECase e branches -> + Ann { frees = foldl (<||) annae.frees (map frees annbranches) + , term = (ACase annae annbranches, t) + } + where + annae = freeVarsExp localVars e + annbranches = map (freeVarsBranch localVars) branches -freeVarsOf :: AnnExp -> Set Id -freeVarsOf = fst +freeVarsBind :: Frees -> Bind -> Ann ABind +freeVarsBind localVars (Bind name vars e) = + Ann { frees = annae.frees \\ vars + , term = ABind name vars annae + } + where + annae = freeVarsExp (localVars <|| vars) e + + +freeVarsBranch :: Frees -> Branch -> Ann ABranch +freeVarsBranch localVars (Branch pt e) = + Ann { frees = annae.frees \\ varsInPattern + , term = ABranch pt annae + } + where + annae = freeVarsExp localVars e + varsInPattern = go [] pt + where + go acc (p, t) = case p of + PVar n -> acc <| (n, t) + PInj _ ps -> foldl go acc ps + _ -> [] + -- AST annotated with free variables -type AnnProgram = [(Id, [Id], AnnExp)] -type AnnExp = (Set Id, AnnExp') +type Frees = [(Ident, Type)] -data ABind = ABind Id [Id] AnnExp deriving Show +data Ann a = Ann + { frees :: Frees + , term :: a + } deriving (Show, Eq) -data AnnExp' = AId Id - | AInt Integer - | ALet ABind AnnExp - | AApp Type AnnExp AnnExp - | AAdd Type AnnExp AnnExp - | AAbs Type Id AnnExp - deriving Show --- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@. --- Free variables are @v₁ v₂ .. vₙ@ are bound. -abstract :: AnnProgram -> Program -abstract prog = Program $ evalState (mapM go prog) 0 +data ABind = ABind Id [Id] (Ann AExpT) deriving (Show, Eq) +data ABranch = ABranch (Pattern, Type) (Ann AExpT) deriving (Show, Eq) + +type AExpT = (AExp, Type) + +data AExp = AVar Ident + | AInj Ident + | ALit Lit + | ALet (Ann ABind) (Ann AExpT) + | AApp (Ann AExpT) (Ann AExpT) + | AAdd (Ann AExpT) (Ann AExpT) + | AAbs Ident (Ann AExpT) + | ACase (Ann AExpT) [Ann ABranch] + deriving (Show, Eq) + +abstract :: [ABind] -> [Bind] +abstract bs = evalState (mapM (abstractAnnBind . Ann []) bs) 0 + +abstractAnnBind :: Ann ABind -> State Int Bind +abstractAnnBind Ann { term = ABind name vars annae } = + Bind name (vars' <|| vars) <$> abstractAnnExp annae' where - go :: (Id, [Id], AnnExp) -> State Int Bind - go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs' + (annae', vars') = go [] annae where - (rhs', parms1) = flattenLambdasAnn rhs + go acc = \case + Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae + ae -> (ae, acc) +abstractAnnExp :: Ann AExpT -> State Int ExpT +abstractAnnExp Ann {frees, term = (annae, typ) } = case annae of + AVar n -> pure (EVar n, typ) + AInj n -> pure (EInj n, typ) + ALit lit -> pure (ELit lit, typ) + AApp annae1 annae2 -> (, typ) <$> onM EApp abstractAnnExp annae1 annae2 + AAdd annae1 annae2 -> (, typ) <$> onM EAdd abstractAnnExp annae1 annae2 --- | Flatten nested lambdas and collect the parameters --- @\x.\y.\z. ae → (ae, [x,y,z])@ -flattenLambdasAnn :: AnnExp -> (AnnExp, [Id]) -flattenLambdasAnn ae = go (ae, []) - where - go :: (AnnExp, [Id]) -> (AnnExp, [Id]) - go ((free, e), acc) = - case e of - AAbs _ par (free1, e1) -> - go ((Set.delete par free1, e1), snoc par acc) - _ -> ((free, e), acc) - -abstractExp :: AnnExp -> State Int Exp -abstractExp (free, exp) = case exp of - AId n -> pure $ EId n - AInt i -> pure $ EInt i - AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2) - AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2) - ALet b e -> liftA2 ELet (go b) (abstractExp e) - where - go (ABind name parms rhs) = do - (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs - pure $ Bind name (parms ++ parms1) rhs' - - skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp - skipLambdas f (free, ae) = case ae of - AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1 - _ -> f (free, ae) - - -- Lift lambda into let and bind free variables - AAbs t parm e -> do + -- \x. \y. x + y + z ⇒ let sc x y z = x + y + z in sc + AAbs x annae' -> do i <- nextNumber - rhs <- abstractExp e + rhs <- abstractAnnExp annae'' + let sc_name = Ident ("sc_" ++ show i) + e@(_, t) = foldl applyFree (EVar sc_name, typ) frees + pure (ELet (Bind (sc_name, typ) vars rhs) e ,t) - let sc_name = Ident ("sc_" ++ show i) - sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t) - - pure $ foldl (EApp TInt) sc $ map EId freeList where - freeList = Set.toList free - parms = snoc parm freeList + vars = frees <| (x, t_x) <|| ys + t_x = case typ of TFun t _ -> t + _ -> error "Impossible" + (annae'', ys) = go [] annae' + where + go acc = \case + Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae + ae -> (ae, acc) + + + applyFree :: (Exp' Type, Type) -> (Ident, Type) -> (Exp' Type, Type) + applyFree (e, t_e) (x, t_x) = (EApp (e, t_e) (EVar x, t_x), t_e') + where + t_e' = case t_e of TFun _ t -> t + _ -> error "Impossible" + + ACase annae' bs -> do + bs <- mapM go bs + e <- abstractAnnExp annae' + pure (ECase e bs, typ) + where + go Ann { term = ABranch p annae } = Branch p <$> abstractAnnExp annae + + ALet b annae' -> + (, typ) <$> liftA2 ELet (abstractAnnBind b) (abstractAnnExp annae') + + +-- | Collects supercombinators by lifting non-constant let expressions +collectScs :: [Bind] -> [Bind] +collectScs = concatMap collectFromRhs + where + collectFromRhs (Bind name parms rhs) = + let (rhs_scs, rhs') = collectScsExp rhs + in Bind name parms rhs' : rhs_scs + + +collectScsExp :: ExpT -> ([Bind], ExpT) +collectScsExp expT@(exp, typ) = case exp of + EVar _ -> ([], expT) + EInj _ -> ([], expT) + ELit _ -> ([], expT) + + EApp e1 e2 -> (scs1 ++ scs2, (EApp e1' e2', typ)) + where + (scs1, e1') = collectScsExp e1 + (scs2, e2') = collectScsExp e2 + + EAdd e1 e2 -> (scs1 ++ scs2, (EAdd e1' e2', typ)) + where + (scs1, e1') = collectScsExp e1 + (scs2, e2') = collectScsExp e2 + + EAbs par e -> (scs, (EAbs par e', typ)) + where + (scs, e') = collectScsExp e + + ECase e branches -> (scs ++ scs_e, (ECase e' branches', typ)) + where + (scs, branches') = mapAccumL f [] branches + (scs_e, e') = collectScsExp e + f acc b = (acc ++ acc', b') + where (acc', b') = collectScsBranch b + + -- Collect supercombinators from bind, the rhss, and the expression. + -- + -- > f = let sc x y = rhs in e + -- + ELet (Bind name parms rhs) e + | null parms -> (rhs_scs ++ et_scs, (ELet bind et', snd et')) + | otherwise -> (bind : rhs_scs ++ et_scs, et') + where + bind = Bind name parms rhs' + (rhs_scs, rhs') = collectScsExp rhs + (et_scs, et') = collectScsExp e + +collectScsBranch (Branch patt exp) = (scs, Branch patt exp') + where (scs, exp') = collectScsExp exp nextNumber :: State Int Int nextNumber = do @@ -140,51 +251,11 @@ nextNumber = do put $ succ i pure i --- | Collects supercombinators by lifting non-constant let expressions -collectScs :: Program -> Program -collectScs (Program scs) = Program $ concatMap collectFromRhs scs - where - collectFromRhs (Bind name parms rhs) = - let (rhs_scs, rhs') = collectScsExp rhs - in Bind name parms rhs' : rhs_scs +(<|) :: Eq a => [a] -> a -> [a] +xs <| x | elem x xs = xs + | otherwise = snoc x xs -collectScsExp :: Exp -> ([Bind], Exp) -collectScsExp = \case - EId n -> ([], EId n) - EInt i -> ([], EInt i) +(<||) :: Eq a => [a] -> [a] -> [a] +xs <|| ys = foldl (<|) xs ys - EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2') - where - (scs1, e1') = collectScsExp e1 - (scs2, e2') = collectScsExp e2 - - EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2') - where - (scs1, e1') = collectScsExp e1 - (scs2, e2') = collectScsExp e2 - - EAbs t par e -> (scs, EAbs t par e') - where - (scs, e') = collectScsExp e - - -- Collect supercombinators from bind, the rhss, and the expression. - -- - -- > f = let sc x y = rhs in e - -- - ELet (Bind name parms rhs) e -> if null parms - then ( rhs_scs ++ e_scs, ELet bind e') - else (bind : rhs_scs ++ e_scs, e') - where - bind = Bind name parms rhs' - (rhs_scs, rhs') = collectScsExp rhs - (e_scs, e') = collectScsExp e - - --- @\x.\y.\z. e → (e, [x,y,z])@ -flattenLambdas :: Exp -> (Exp, [Id]) -flattenLambdas = go . (, []) - where - go (e, acc) = case e of - EAbs _ par e1 -> go (e1, snoc par acc) - _ -> (e, acc) diff --git a/src/LlvmIr.hs b/src/LlvmIr.hs deleted file mode 100644 index d340ddc..0000000 --- a/src/LlvmIr.hs +++ /dev/null @@ -1,204 +0,0 @@ -{-# LANGUAGE LambdaCase #-} - -module LlvmIr ( - LLVMType (..), - LLVMIr (..), - llvmIrToString, - LLVMValue (..), - LLVMComp (..), - Visibility (..), -) where - -import Data.List (intercalate) -import TypeCheckerIr - --- | A datatype which represents some basic LLVM types -data LLVMType - = I1 - | I8 - | I32 - | I64 - | Ptr - | Ref LLVMType - | Function LLVMType [LLVMType] - | Array Int LLVMType - | CustomType Ident - -instance Show LLVMType where - show :: LLVMType -> String - show = \case - I1 -> "i1" - I8 -> "i8" - I32 -> "i32" - I64 -> "i64" - Ptr -> "ptr" - Ref ty -> show ty <> "*" - Function t xs -> show t <> " (" <> intercalate ", " (map show xs) <> ")*" - Array n ty -> concat ["[", show n, " x ", show ty, "]"] - CustomType (Ident ty) -> ty - -data LLVMComp - = LLEq - | LLNe - | LLUgt - | LLUge - | LLUlt - | LLUle - | LLSgt - | LLSge - | LLSlt - | LLSle -instance Show LLVMComp where - show :: LLVMComp -> String - show = \case - LLEq -> "eq" - LLNe -> "ne" - LLUgt -> "ugt" - LLUge -> "uge" - LLUlt -> "ult" - LLUle -> "ule" - LLSgt -> "sgt" - LLSge -> "sge" - LLSlt -> "slt" - LLSle -> "sle" - -data Visibility = Local | Global -instance Show Visibility where - show :: Visibility -> String - show Local = "%" - show Global = "@" - --- | Represents a LLVM "value", as in an integer, a register variable, --- or a string contstant -data LLVMValue - = VInteger Integer - | VIdent Ident LLVMType - | VConstant String - | VFunction Ident Visibility LLVMType - -instance Show LLVMValue where - show :: LLVMValue -> String - show v = case v of - VInteger i -> show i - VIdent (Ident n) _ -> "%" <> n - VFunction (Ident n) vis _ -> show vis <> n - VConstant s -> "c" <> show s - -type Params = [(Ident, LLVMType)] -type Args = [(LLVMType, LLVMValue)] - --- | A datatype which represents different instructions in LLVM -data LLVMIr - = Define LLVMType Ident Params - | DefineEnd - | Declare LLVMType Ident Params - | SetVariable Ident LLVMIr - | Variable Ident - | Add LLVMType LLVMValue LLVMValue - | Sub LLVMType LLVMValue LLVMValue - | Div LLVMType LLVMValue LLVMValue - | Mul LLVMType LLVMValue LLVMValue - | Srem LLVMType LLVMValue LLVMValue - | Icmp LLVMComp LLVMType LLVMValue LLVMValue - | Br Ident - | BrCond LLVMValue Ident Ident - | Label Ident - | Call LLVMType Visibility Ident Args - | Alloca LLVMType - | Store LLVMType Ident LLVMType Ident - | Bitcast LLVMType Ident LLVMType - | Ret LLVMType LLVMValue - | Comment String - | UnsafeRaw String -- This should generally be avoided, and proper - -- instructions should be used in its place - deriving (Show) - --- | Converts a list of LLVMIr instructions to a string -llvmIrToString :: [LLVMIr] -> String -llvmIrToString = go 0 - where - go :: Int -> [LLVMIr] -> String - go _ [] = mempty - go i (x : xs) = do - let (i', n) = case x of - Define{} -> (i + 1, 0) - DefineEnd -> (i - 1, 0) - _ -> (i, i) - insToString n x <> go i' xs - --- | Converts a LLVM inststruction to a String, allowing for printing etc. --- The integer represents the indentation -insToString :: Int -> LLVMIr -> String -insToString i l = - replicate i '\t' <> case l of - (Define t (Ident i) params) -> - concat - [ "define ", show t, " @", i - , "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params) - , ") {\n" - ] - DefineEnd -> "}\n" - (Declare _t (Ident _i) _params) -> undefined - (SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir] - (Add t v1 v2) -> - concat - [ "add ", show t, " ", show v1 - , ", ", show v2, "\n" - ] - (Sub t v1 v2) -> - concat - [ "sub ", show t, " ", show v1, ", " - , show v2, "\n" - ] - (Div t v1 v2) -> - concat - [ "sdiv ", show t, " ", show v1, ", " - , show v2, "\n" - ] - (Mul t v1 v2) -> - concat - [ "mul ", show t, " ", show v1 - , ", ", show v2, "\n" - ] - (Srem t v1 v2) -> - concat - [ "srem ", show t, " ", show v1, ", " - , show v2, "\n" - ] - (Call t vis (Ident i) arg) -> - concat - [ "call ", show t, " ", show vis, i, "(" - , intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg - , ")\n" - ] - (Alloca t) -> unwords ["alloca", show t, "\n"] - (Store t1 (Ident id1) t2 (Ident id2)) -> - concat - [ "store ", show t1, " %", id1 - , ", ", show t2 , " %", id2, "\n" - ] - (Bitcast t1 (Ident i) t2) -> - concat - [ "bitcast ", show t1, " %" - , i, " to ", show t2, "\n" - ] - (Icmp comp t v1 v2) -> - concat - [ "icmp ", show comp, " ", show t - , " ", show v1, ", ", show v2, "\n" - ] - (Ret t v) -> - concat - [ "ret ", show t, " " - , show v, "\n" - ] - (UnsafeRaw s) -> s - (Label (Ident s)) -> "\nlabel_" <> s <> ":\n" - (Br (Ident s)) -> "br label %label_" <> s <> "\n" - (BrCond val (Ident s1) (Ident s2)) -> - concat - [ "br i1 ", show val, ", ", "label %" - , "label_", s1, ", ", "label %", "label_", s2, "\n" - ] - (Comment s) -> "; " <> s <> "\n" - (Variable (Ident id)) -> "%" <> id diff --git a/src/Main.hs b/src/Main.hs index 1831428..b487222 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -1,97 +1,196 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedRecordDot #-} module Main where +import AnnForall (annotateForall) +import Codegen.Codegen (generateCode) import Compiler (compile) +import Control.Monad (when, (<=<)) +import Data.List.Extra (isSuffixOf) +import Data.Maybe (fromJust, isNothing) +import Desugar.Desugar (desugar) import GHC.IO.Handle.Text (hPutStrLn) import Grammar.ErrM (Err) +import Grammar.Layout (resolveLayout) import Grammar.Par (myLexer, pProgram) -import Grammar.Print (printTree) - --- import Interpreter (interpret) +import Grammar.Print (Print, printTree) import LambdaLifter (lambdaLift) -import Renamer (rename) +import Monomorphizer.Monomorphizer (monomorphize) +import OrderDefs (orderDefs) +import Renamer.Renamer (rename) +import ReportForall (reportForall) +import System.Console.GetOpt ( + ArgDescr (NoArg, ReqArg), + ArgOrder (RequireOrder), + OptDescr (Option), + getOpt, + usageInfo, + ) +import System.Directory ( + createDirectory, + doesPathExist, + getDirectoryContents, + removeDirectoryRecursive, + setCurrentDirectory, + ) import System.Environment (getArgs) -import System.Exit (exitFailure, exitSuccess) +import System.Exit ( + ExitCode (ExitFailure), + exitFailure, + exitSuccess, + exitWith, + ) import System.IO (stderr) -import TypeChecker (typecheck) +import System.Process (spawnCommand, waitForProcess) +import TypeChecker.TypeChecker (TypeChecker (Bi, Hm), typecheck) main :: IO () -main = - getArgs >>= \case - [] -> print "Required file path missing" - (s : _) -> main' s +main = getArgs >>= parseArgs >>= uncurry main' -main' :: String -> IO () -main' s = do - file <- readFile s +parseArgs :: [String] -> IO (Options, String) +parseArgs argv = case getOpt RequireOrder flags argv of + (os, f : _, []) + | opts.help || isNothing opts.typechecker -> do + hPutStrLn stderr (usageInfo header flags) + exitSuccess + | otherwise -> pure (opts, f) + where + opts = foldr ($) initOpts os + (_, _, errs) -> do + hPutStrLn stderr (concat errs ++ usageInfo header flags) + exitWith (ExitFailure 1) + where + header = "Usage: language [--help] [-d|--debug] [-t|type-checker bi/hm] FILE \n" - printToErr "-- Parse Tree -- " - parsed <- fromSyntaxErr . pProgram $ myLexer file - printToErr $ printTree parsed +flags :: [OptDescr (Options -> Options)] +flags = + [ Option ['d'] ["debug"] (NoArg enableDebug) "Print debug messages." + , Option ['t'] ["type-checker"] (ReqArg chooseTypechecker "bi/hm") "Choose type checker. Possible options are bi and hm" + , Option ['m'] ["disable-gc"] (NoArg disableGC) "Disables the garbage collector and uses malloc instead." + , Option [] ["help"] (NoArg enableHelp) "Print this help message" + ] - printToErr "\n-- Renamer --" - let renamed = rename parsed - printToErr $ printTree renamed +initOpts :: Options +initOpts = + Options + { help = False + , debug = False + , gc = True + , typechecker = Nothing + } - printToErr "\n-- TypeChecker --" - typechecked <- fromTypeCheckerErr $ typecheck renamed - printToErr $ printTree typechecked +enableHelp :: Options -> Options +enableHelp opts = opts{help = True} - printToErr "\n-- Lambda Lifter --" - let lifted = lambdaLift typechecked - printToErr $ printTree lifted +enableDebug :: Options -> Options +enableDebug opts = opts{debug = True} - printToErr "\n -- Printing compiler output to stdout --" - compiled <- fromCompilerErr $ compile lifted - putStrLn compiled - writeFile "llvm.ll" compiled +disableGC :: Options -> Options +disableGC opts = opts{gc = False} - -- interpred <- fromInterpreterErr $ interpret lifted - -- putStrLn "\n-- interpret" - -- print interpred +chooseTypechecker :: String -> Options -> Options +chooseTypechecker s options = options{typechecker = tc} + where + tc = case s of + "hm" -> pure Hm + "bi" -> pure Bi + _ -> Nothing - exitSuccess +data Options = Options + { help :: Bool + , debug :: Bool + , gc :: Bool + , typechecker :: Maybe TypeChecker + } + +main' :: Options -> String -> IO () +main' opts s = + let + log :: (Print a, Show a) => a -> IO () + log = printToErr . if opts.debug then show else printTree + in + do + file <- readFile s + + printToErr "-- Parse Tree -- " + parsed <- fromErr . pProgram . resolveLayout True $ myLexer (file ++ prelude) + log parsed + + printToErr "-- Desugar --" + let desugared = desugar parsed + log desugared + + printToErr "\n-- Renamer --" + _ <- fromErr $ reportForall (fromJust opts.typechecker) desugared + renamed <- fromErr $ (rename <=< annotateForall) desugared + log renamed + + printToErr "\n-- TypeChecker --" + typechecked <- fromErr $ typecheck (fromJust opts.typechecker) (orderDefs renamed) + log typechecked + + printToErr "\n-- Lambda Lifter --" + let lifted = lambdaLift typechecked + log lifted + + printToErr "\n -- Monomorphizer --" + let monomorphized = monomorphize lifted + log monomorphized + + printToErr "\n -- Compiler --" + -- generatedCode <- fromErr $ generateCode monomorphized (gc opts) + generatedCode <- fromErr $ generateCode monomorphized False + + check <- doesPathExist "output" + when check (removeDirectoryRecursive "output") + createDirectory "output" + createDirectory "output/logs" + when opts.debug $ do + writeFile "output/llvm.ll" generatedCode + debugDotViz + + -- compile generatedCode (gc opts) + compile generatedCode False + printToErr "Compilation done!" + printToErr "\n-- Program output --" + print =<< spawnWait "./output/hello_world" + + exitSuccess + +debugDotViz :: IO () +debugDotViz = do + setCurrentDirectory "output" + spawnWait "opt -dot-cfg llvm.ll -disable-output" + content <- filter (isSuffixOf ".dot") <$> getDirectoryContents "." + let commands = (\p -> "dot " <> p <> " -Tpng -o" <> p <> ".png") <$> content + mapM_ spawnWait commands + setCurrentDirectory ".." + return () + +spawnWait :: String -> IO ExitCode +spawnWait s = spawnCommand s >>= waitForProcess printToErr :: String -> IO () printToErr = hPutStrLn stderr -fromCompilerErr :: Err a -> IO a -fromCompilerErr = - either - ( \err -> do - putStrLn "\nCOMPILER ERROR" - putStrLn err - exitFailure - ) - pure +fromErr :: Err a -> IO a +fromErr = either (\s -> printToErr s >> exitFailure) pure -fromSyntaxErr :: Err a -> IO a -fromSyntaxErr = - either - ( \err -> do - putStrLn "\nSYNTAX ERROR" - putStrLn err - exitFailure - ) - pure - -fromTypeCheckerErr :: Err a -> IO a -fromTypeCheckerErr = - either - ( \err -> do - putStrLn "\nTYPECHECKER ERROR" - putStrLn err - exitFailure - ) - pure - -fromInterpreterErr :: Err a -> IO a -fromInterpreterErr = - either - ( \err -> do - putStrLn "\nINTERPRETER ERROR" - putStrLn err - exitFailure - ) - pure +prelude :: String +prelude = + unlines + [ "\n" + , "data Bool where" + , " False : Bool" + , " True : Bool" + , -- The function body of lt is replaced during code gen. It exists here for type checking purposes. + "lt : Int -> Int -> Bool" + , "lt x y = case x of" + , " _ => True" + , " _ => False" + , "\n" + , -- The function body of - is replaced during code gen. It exists here for type checking purposes. + ".- : Int -> Int -> Int" + , ".- x y = 0" + , "\n" + ] diff --git a/src/Monomorphizer/DataTypeRemover.hs b/src/Monomorphizer/DataTypeRemover.hs new file mode 100644 index 0000000..e4caef0 --- /dev/null +++ b/src/Monomorphizer/DataTypeRemover.hs @@ -0,0 +1,60 @@ +module Monomorphizer.DataTypeRemover (removeDataTypes) where + +import Monomorphizer.MonomorphizerIr qualified as M2 +import Monomorphizer.MorbIr qualified as M1 +import TypeChecker.TypeCheckerIr (Ident (Ident)) + +removeDataTypes :: M1.Program -> M2.Program +removeDataTypes (M1.Program defs) = M2.Program (map pDef defs) + +pDef :: M1.Def -> M2.Def +pDef (M1.DBind b) = M2.DBind (pBind b) +pDef (M1.DData d) = M2.DData (pData d) + +pData :: M1.Data -> M2.Data +pData (M1.Data t cs) = M2.Data (pType t) (map pCons cs) + +pCons :: M1.Inj -> M2.Inj +pCons (M1.Inj ident t) = M2.Inj ident (pType t) + +pType :: M1.Type -> M2.Type +pType (M1.TLit ident) = M2.TLit ident +pType (M1.TFun t1 t2) = M2.TFun (pType t1) (pType t2) +pType (M1.TData (Ident "Bool") _) = M2.TLit (Ident "Bool") +pType d = M2.TLit (Ident (newName d)) -- This is the step + +newName :: M1.Type -> String +newName (M1.TLit (Ident str)) = str +newName (M1.TFun t1 t2) = newName t1 ++ newName t2 +newName (M1.TData (Ident str) args) = str ++ concatMap newName args + +pBind :: M1.Bind -> M2.Bind +pBind (M1.Bind id argIds expt) = M2.Bind (pId id) (map pId argIds) (pExpT expt) + +pId :: (Ident, M1.Type) -> (Ident, M2.Type) +pId (ident, t) = (ident, pType t) + +pExpT :: M1.ExpT -> M2.ExpT +pExpT (exp, t) = (pExp exp, pType t) + +pExp :: M1.Exp -> M2.Exp +pExp (M1.EVar ident) = M2.EVar ident +pExp (M1.ELit lit) = M2.ELit (pLit lit) +pExp (M1.ELet bind expt) = M2.ELet (pBind bind) (pExpT expt) +pExp (M1.EApp e1 e2) = M2.EApp (pExpT e1) (pExpT e2) +pExp (M1.EAdd e1 e2) = M2.EAdd (pExpT e1) (pExpT e2) +pExp (M1.ECase expT branches) = M2.ECase (pExpT expT) (map pBranch branches) + +pBranch :: M1.Branch -> M2.Branch +pBranch (M1.Branch (patt, t) expt) = M2.Branch (pPattern patt, pType t) (pExpT expt) + +pPattern :: M1.Pattern -> M2.Pattern +pPattern (M1.PVar id) = M2.PVar (pId id) +pPattern (M1.PLit (lit, t)) = M2.PLit (pLit lit, pType t) +pPattern (M1.PInj ident patts) = M2.PInj ident (map pPattern patts) +pPattern M1.PCatch = M2.PCatch +pPattern (M1.PEnum ident) = M2.PEnum ident + +pLit :: M1.Lit -> M2.Lit +pLit (M1.LInt v) = M2.LInt v +pLit (M1.LChar c) = M2.LChar c diff --git a/src/Monomorphizer/Monomorphizer.hs b/src/Monomorphizer/Monomorphizer.hs new file mode 100644 index 0000000..3a8bd9e --- /dev/null +++ b/src/Monomorphizer/Monomorphizer.hs @@ -0,0 +1,419 @@ +{-# LANGUAGE LambdaCase #-} + +{- | For now, converts polymorphic functions to concrete ones based on usage. +Assumes lambdas are lifted. + +This step of compilation is as follows: + +Split all function bindings into monomorphic and polymorphic binds. The +monomorphic bindings will be part of this compilation step. +Apply the following monomorphization function on all monomorphic binds, with +their type as an additional argument. + +The function that transforms Binds operates on both monomorphic and +polymorphic functions, creates a context in which all possible polymorphic types +are mapped to concrete types, created using the additional argument. +Expressions are then recursively processed. The type of these expressions +are changed to using the mapped generic types. The expected type provided +in the recursion is changed depending on the different nodes. + +When an external bind is encountered (with EId), it is checked whether it +exists in outputed binds or not. If it does, nothing further is evaluated. +If not, the bind transformer function is called on it with the +expected type in this context. The result of this computation (a monomorphic +bind) is added to the resulting set of binds. +-} +module Monomorphizer.Monomorphizer (monomorphize, morphExp, morphBind) where + +import Monomorphizer.DataTypeRemover (removeDataTypes) +import Monomorphizer.MonomorphizerIr qualified as O +import Monomorphizer.MorbIr qualified as M +import TypeChecker.TypeCheckerIr (Ident (Ident)) +import TypeChecker.TypeCheckerIr qualified as T + +import Control.Monad.Reader ( + MonadReader (ask, local), + Reader, + asks, + runReader, + ) +import Control.Monad.State ( + MonadState (get), + StateT (runStateT), + gets, + modify, + ) +import Data.Coerce (coerce) +import Data.Map qualified as Map +import Data.Maybe (catMaybes) +import Data.Set qualified as Set +import Grammar.Print (printTree) +import Debug.Trace (trace) + +{- | EnvM is the monad containing the read-only state as well as the +output state containing monomorphized functions and to-be monomorphized +data type declarations. +-} +newtype EnvM a = EnvM (StateT Output (Reader Env) a) + deriving (Functor, Applicative, Monad, MonadState Output, MonadReader Env) + +type Output = Map.Map Ident Outputted + +{- | Data structure describing outputted top-level information, that is +Binds, Polymorphic Data types (monomorphized in a later step) and +Marked bind, which means that it is in the process of monomorphization +and should not be monomorphized again. +-} +data Outputted = Marked | Complete M.Bind | Data M.Type T.Data deriving (Show) + +-- | Static environment. +data Env = Env + { input :: Map.Map Ident T.Bind + -- ^ All binds in the program. + , dataDefs :: Map.Map Ident T.Data + -- ^ All constructors mapped to their respective polymorphic data def + -- which includes all other constructors. + , polys :: Map.Map Ident M.Type + -- ^ Maps polymorphic identifiers with concrete types. + , locals :: Set.Set Ident + -- ^ Local variables. + } + +-- | Determines if the identifier describes a local variable in the given context. +localExists :: Ident -> EnvM Bool +localExists ident = asks (Set.member ident . locals) + +-- | Gets a polymorphic bind from an id. +getInputBind :: Ident -> EnvM (Maybe T.Bind) +getInputBind ident = asks (Map.lookup ident . input) + +-- | Add monomorphic function derived from a polymorphic one, to env. +addOutputBind :: M.Bind -> EnvM () +addOutputBind b@(M.Bind (ident, _) _ _) = modify (Map.insert ident (Complete b)) + +{- | Marks a global bind as being processed, meaning that when encountered again, +it should not be recursively processed. +-} +markBind :: Ident -> EnvM () +markBind ident = modify (Map.insert ident Marked) + +-- | Check if bind has been touched or not. +isBindMarked :: Ident -> EnvM Bool +isBindMarked ident = gets (Map.member ident) + +-- | Checks if constructor is outputted. +isConsMarked :: Ident -> EnvM Bool +isConsMarked ident = gets (Map.member ident) + +-- | Finds main bind. +getMain :: EnvM T.Bind +getMain = asks (\env -> case Map.lookup (T.Ident "main") (input env) of + Just mainBind -> mainBind + Nothing -> error "main not found in monomorphizer!" + ) + +{- | Makes a kv pair list of polymorphic to monomorphic mappings, throws runtime +error when encountering different structures between the two arguments. Debug: +First argument is the name of the bind. +-} +mapTypes :: Ident -> T.Type -> M.Type -> [(Ident, M.Type)] +mapTypes _ident (T.TLit _) (M.TLit _) = [] +mapTypes _ident (T.TVar (T.MkTVar i1)) tm = [(i1, tm)] +mapTypes ident (T.TFun pt1 pt2) (M.TFun mt1 mt2) = + mapTypes ident pt1 mt1 + ++ mapTypes ident pt2 mt2 +mapTypes ident (T.TData tIdent pTs) (M.TData mIdent mTs) = + if tIdent /= mIdent + then error "the data type names of monomorphic and polymorphic data types does not match" + else foldl (\xs (p, m) -> mapTypes ident p m ++ xs) [] (zip pTs mTs) +mapTypes ident t1 t2 = error $ "in bind: '" ++ printTree ident ++ "', " ++ + "structure of types not the same: '" ++ printTree t1 ++ "', '" ++ printTree t2 ++ "'" + +-- | Gets the mapped monomorphic type of a polymorphic type in the current context. +getMonoFromPoly :: T.Type -> EnvM M.Type +getMonoFromPoly t = do + env <- ask + return $ getMono (polys env) t + where + getMono :: Map.Map Ident M.Type -> T.Type -> M.Type + getMono polys t = case t of + (T.TLit ident) -> M.TLit (coerce ident) + (T.TFun t1 t2) -> M.TFun (getMono polys t1) (getMono polys t2) + (T.TVar (T.MkTVar ident)) -> case Map.lookup ident polys of + Just concrete -> concrete + Nothing -> M.TLit (Ident "void") + -- error $ "type not found! type: " ++ show ident ++ ", error in previous compilation steps" + (T.TData ident args) -> M.TData ident (map (getMono polys) args) + +{- | If ident not already in env's output, morphed bind to output +(and all referenced binds within this bind). +Returns the annotated bind name. +-} +morphBind :: M.Type -> T.Bind -> EnvM Ident +morphBind expectedType b@(T.Bind (ident, btype) args (exp, expt)) = do + -- The "new name" is used to find out if it is already marked or not. + let name' = newFuncName expectedType b + bindMarked <- isBindMarked (coerce name') + local + ( \env -> + env + { locals = Set.fromList (map fst args) + , polys = Map.fromList (mapTypes ident btype expectedType) + } + ) + $ do + -- Return with right name if already marked + if bindMarked + then return name' + else do + -- Mark so that this bind will not be processed in recursive or cyclic + -- function calls + markBind (coerce name') + expt' <- getMonoFromPoly expt + exp' <- morphExp expt' exp + -- Get monomorphic type sof args + args' <- mapM morphArg args + addOutputBind $ + M.Bind + (coerce name', expectedType) + args' + (exp', expt') + return name' + +-- | Monomorphizes arguments of a bind. +morphArg :: (Ident, T.Type) -> EnvM (Ident, M.Type) +morphArg (ident, t) = do + t' <- getMonoFromPoly t + return (ident, t') + +-- | Gets the data bind from the name of a constructor. +getInputData :: Ident -> EnvM (Maybe T.Data) +getInputData ident = do + env <- ask + return $ Map.lookup ident (dataDefs env) + +{- | Monomorphize a constructor using it's global name. Constructors may +appear as expressions in the tree, or as patterns in case-expressions. +'newIdent' has a unique name while 'ident' has a general name. +-} +morphCons :: M.Type -> Ident -> Ident -> EnvM () +morphCons expectedType ident newIdent = do + --trace ("Tjofras:" ++ show (newName expectedType ident)) $ return () + maybeD <- getInputData ident + case maybeD of + Nothing -> error $ "identifier '" ++ show ident ++ "' not found" + Just d -> do + modify (\output -> Map.insert newIdent (Data expectedType d) output) + +-- | Converts literals from input to output tree. +convertLit :: T.Lit -> M.Lit +convertLit (T.LInt v) = M.LInt v +convertLit (T.LChar v) = M.LChar v + +-- | Monomorphizes an expression, given an expected type. +morphExp :: M.Type -> T.Exp -> EnvM M.Exp +morphExp expectedType exp = case exp of + T.ELit lit -> return $ M.ELit (convertLit lit) + -- Constructor + T.EInj ident -> do + let ident' = newName (getDataType expectedType) ident + morphCons expectedType ident ident' + return $ M.EVar ident' + T.EApp (e1, _t1) (e2, t2) -> do + t2' <- getMonoFromPoly t2 + e2' <- morphExp t2' e2 + e1' <- morphExp (M.TFun t2' expectedType) e1 + return $ M.EApp (e1', M.TFun t2' expectedType) (e2', t2') + T.EAdd (e1, t1) (e2, t2) -> do + t1' <- getMonoFromPoly t1 + t2' <- getMonoFromPoly t2 + e1' <- morphExp t1' e1 + e2' <- morphExp t2' e2 + return $ M.EAdd (e1', expectedType) (e2', expectedType) + T.EAbs ident (exp, t) -> local (\env -> env{locals = Set.insert ident (locals env)}) $ do + t' <- getMonoFromPoly t + morphExp t' exp + T.ECase (exp, t) bs -> do + t' <- getMonoFromPoly t + exp' <- morphExp t' exp + bs' <- mapM morphBranch bs + return $ M.ECase (exp', t') (catMaybes bs') + -- Ideally constructors should be EInj, though this code handles them + -- as well. + T.EVar ident -> do + isLocal <- localExists ident + if isLocal + then do + return $ M.EVar (coerce ident) + else do + bind <- getInputBind ident + case bind of + Nothing -> error $ "unbound variable: '" ++ printTree ident ++ "'" + Just bind' -> do + -- New bind to process + newBindName <- morphBind expectedType bind' + return $ M.EVar (coerce newBindName) + T.ELet (T.Bind (identB, tB) args (expB, tExpB)) (exp, tExp) -> + if length args > 0 then error "only constants in lets allowed" + else do + tB' <- getMonoFromPoly tB + tExpB' <- getMonoFromPoly tExpB + tExp' <- getMonoFromPoly tExp + expB' <- morphExp tExpB' expB + exp' <- morphExp tExp' exp + return $ M.ELet (M.Bind (identB, tB') [] (expB', tExpB')) (exp', tExp') + +-- | Monomorphizes case-of branches. +morphBranch :: T.Branch -> EnvM (Maybe M.Branch) +morphBranch (T.Branch (p, pt) (e, et)) = do + pt' <- getMonoFromPoly pt + et' <- getMonoFromPoly et + env <- ask + maybeMorphedPattern <- morphPattern p pt' + case maybeMorphedPattern of + Nothing -> return Nothing + Just (p', newLocals) -> + local (const env { locals = Set.union (locals env) newLocals }) $ do + e' <- morphExp et' e + return $ Just (M.Branch (p', pt') (e', et')) + +morphPattern :: T.Pattern -> M.Type -> EnvM (Maybe (M.Pattern, Set.Set Ident)) +morphPattern p expectedType = case p of + T.PVar ident -> return $ Just (M.PVar (ident, expectedType), Set.singleton ident) + T.PLit lit -> return $ Just (M.PLit (convertLit lit, expectedType), Set.empty) + T.PCatch -> return $ Just (M.PCatch, Set.empty) + T.PEnum ident -> return $ Just (M.PEnum (newName expectedType ident), Set.empty) + T.PInj ident pts -> do let newIdent = newName expectedType ident + outEnv <- get + trace ("WOW: " ++ show (newName expectedType ident)) $ return () + trace ("WOW2: " ++ show (outEnv)) $ return () + isMarked <- isConsMarked newIdent + if isMarked + then do + trace ("WOW3") $ return () + ts' <- mapM (getMonoFromPoly . snd) pts + let pts' = zip (map fst pts) ts' + psSets <- mapM (uncurry morphPattern) pts' + let maybePsSets = sequence psSets + case maybePsSets of + Nothing -> return Nothing + Just psSets' -> return $ Just + (M.PInj newIdent (map fst psSets'), Set.unions $ map snd psSets') + else return Nothing + +-- | Creates a new identifier for a function with an assigned type. +newFuncName :: M.Type -> T.Bind -> Ident +newFuncName t (T.Bind (ident@(Ident bindName), _) _ _) = + if bindName == "main" + then Ident bindName + else newName t ident + +newName :: M.Type -> Ident -> Ident +newName t (Ident str) = Ident $ str ++ "$" ++ newName' t + where + newName' :: M.Type -> String + newName' (M.TLit (Ident str)) = str + newName' (M.TFun t1 t2) = newName' t1 ++ "_" ++ newName' t2 + newName' (M.TData (Ident str) ts) = str ++ foldl (\s t -> s ++ "." ++ newName' t) "" ts + +-- | Monomorphization step. +monomorphize :: T.Program -> O.Program +monomorphize (T.Program defs) = + removeDataTypes $ + M.Program + ( getDefsFromOutput + (runEnvM Map.empty (createEnv defs) monomorphize') + ) + where + monomorphize' :: EnvM () + monomorphize' = do + main <- getMain + morphBind (M.TLit $ Ident "Int") main + return () + +-- | Runs and gives the output binds. +runEnvM :: Output -> Env -> EnvM () -> Output +runEnvM o env (EnvM stateM) = snd $ runReader (runStateT stateM o) env + +-- | Creates the environment based on the input binds. +createEnv :: [T.Def] -> Env +createEnv defs = + Env + { input = Map.fromList bindPairs + , dataDefs = Map.fromList dataPairs + , polys = Map.empty + , locals = Set.empty + } + where + bindPairs = (map (\b -> (getBindName b, b)) . getBindsFromDefs) defs + dataPairs :: [(Ident, T.Data)] + dataPairs = (foldl (\acc d@(T.Data _ cs) -> map ((,d) . getConsName) cs ++ acc) [] . getDataFromDefs) defs + +-- | Gets a top-lefel function name. +getBindName :: T.Bind -> Ident +getBindName (T.Bind (ident, _) _ _) = ident + +-- Helper functions +-- Gets custom data declarations form defs. +getDataFromDefs :: [T.Def] -> [T.Data] +getDataFromDefs = + foldl + ( \bs -> \case + T.DBind _ -> bs + T.DData d -> d : bs + ) + [] + +getConsName :: T.Inj -> Ident +getConsName (T.Inj ident _) = ident + +getBindsFromDefs :: [T.Def] -> [T.Bind] +getBindsFromDefs = + foldl + ( \bs -> \case + T.DBind b -> b : bs + T.DData _ -> bs + ) + [] + +getDefsFromOutput :: Output -> [M.Def] +getDefsFromOutput o = + map M.DBind binds + ++ (map (M.DData . snd) . Map.toList) (createNewData dataInput Map.empty) + where + (binds, dataInput) = splitBindsAndData o + +-- | Splits the output into binds and data declaration components (used in createNewData) +splitBindsAndData :: Output -> ([M.Bind], [(Ident, M.Type, T.Data)]) +splitBindsAndData output = + foldl + ( \(oBinds, oData) (ident, o) -> case o of + Marked -> error "internal bug in monomorphizer" + Complete b -> (b : oBinds, oData) + Data t d -> (oBinds, (ident, t, d) : oData) + ) + ([], []) + (Map.toList output) + +-- | Converts all found constructors to monomorphic data declarations. +createNewData :: [(Ident, M.Type, T.Data)] -> Map.Map Ident M.Data -> Map.Map Ident M.Data +createNewData [] o = o +createNewData ((consIdent, consType, polyData) : input) o = + createNewData input $ + Map.insertWith + (\_ (M.Data _ cs) -> M.Data newDataType (newCons : cs)) + newDataName + (M.Data newDataType [newCons]) + o + where + T.Data (T.TData polyDataIdent _) _ = polyData + newDataType = getDataType consType + newDataName = newName newDataType polyDataIdent + newCons = M.Inj consIdent consType + +-- | Gets the Data Type of a constructor type (a -> Just a becomes Just a). +getDataType :: M.Type -> M.Type +getDataType (M.TFun _t1 t2) = getDataType t2 +getDataType tData@(M.TData _ _) = tData +getDataType _ = error "???" + diff --git a/src/Monomorphizer/MonomorphizerIr.hs b/src/Monomorphizer/MonomorphizerIr.hs new file mode 100644 index 0000000..052cdc1 --- /dev/null +++ b/src/Monomorphizer/MonomorphizerIr.hs @@ -0,0 +1,182 @@ +{-# LANGUAGE LambdaCase #-} + +module Monomorphizer.MonomorphizerIr (module Monomorphizer.MonomorphizerIr) where + +import Grammar.Print +import TypeChecker.TypeCheckerIr qualified as TIR (Ident (..)) + +type Id = (TIR.Ident, Type) + +newtype Program = Program [Def] + deriving (Show, Ord, Eq) + +data Def = DBind Bind | DData Data + deriving (Show, Ord, Eq) + +data Data = Data Type [Inj] + deriving (Show, Ord, Eq) + +data Bind = Bind Id [Id] ExpT + deriving (Show, Ord, Eq) + +data Exp + = EVar TIR.Ident + | ELit Lit + | ELet Bind ExpT + | EApp ExpT ExpT + | EAdd ExpT ExpT + | ECase ExpT [Branch] + deriving (Show, Ord, Eq) + +data Pattern + = PVar Id + | PLit (Lit, Type) + | PInj TIR.Ident [Pattern] + | PCatch + | PEnum TIR.Ident + deriving (Eq, Ord, Show) + +data Branch = Branch (Pattern, Type) ExpT + deriving (Eq, Ord, Show) + +type ExpT = (Exp, Type) + +data Inj = Inj TIR.Ident Type + deriving (Show, Ord, Eq) + +data Lit + = LInt Integer + | LChar Char + deriving (Show, Ord, Eq) + +data Type = TLit TIR.Ident | TFun Type Type + deriving (Show, Ord, Eq) + +flattenType :: Type -> [Type] +flattenType (TFun t1 t2) = t1 : flattenType t2 +flattenType x = [x] + +instance Print Program where + prt i (Program sc) = prPrec i 0 $ prt 0 sc + +instance Print (Bind) where + prt i (Bind sig@(name, _) parms rhs) = + prPrec i 0 $ + concatD + [ prtSig sig + , prt 0 name + , prtIdPs 0 parms + , doc $ showString "=" + , prt 0 rhs + ] + +prtSig :: Id -> Doc +prtSig (name, t) = + concatD + [ prt 0 name + , doc $ showString ":" + , prt 0 t + , doc $ showString ";" + ] + +instance Print (ExpT) where + prt i (e, t) = + concatD + [ doc $ showString "(" + , prt i e + , doc $ showString "," + , prt i t + , doc $ showString ")" + ] + +instance Print [Bind] where + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] + +prtIdPs :: Int -> [Id] -> Doc +prtIdPs i = prPrec i 0 . concatD . map (prt i) + +instance Print Exp where + prt i = \case + EVar name -> prPrec i 3 $ prt 0 name + ELit lit -> prPrec i 3 $ prt 0 lit + ELet b e -> + prPrec i 3 $ + concatD + [ doc $ showString "let" + , prt 0 b + , doc $ showString "in" + , prt 0 e + ] + EApp e1 e2 -> + prPrec i 2 $ + concatD + [ prt 2 e1 + , prt 3 e2 + ] + EAdd e1 e2 -> + prPrec i 1 $ + concatD + [ prt 1 e1 + , doc $ showString "+" + , prt 2 e2 + ] + ECase e branches -> + prPrec i 0 $ + concatD + [ doc $ showString "case" + , prt 0 e + , doc $ showString "of" + , doc $ showString "{" + , prt 0 branches + , doc $ showString "}" + ] + +instance Print Branch where + prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp]) + +instance Print [Branch] where + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] + +instance Print Def where + prt i = \case + DBind bind -> prPrec i 0 (concatD [prt 0 bind]) + DData data_ -> prPrec i 0 (concatD [prt 0 data_]) + +instance Print Data where + prt i = \case + Data type_ injs -> prPrec i 0 (concatD [doc (showString "data"), prt 0 type_, doc (showString "where"), doc (showString "{"), prt 0 injs, doc (showString "}")]) + +instance Print Inj where + prt i = \case + Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_]) + +instance Print Pattern where + prt i = \case + PVar name -> prPrec i 1 (concatD [prt 0 name]) + PLit (lit, _) -> prPrec i 1 (concatD [prt 0 lit]) + PCatch -> prPrec i 1 (concatD [doc (showString "_")]) + PEnum name -> prPrec i 1 (concatD [prt 0 name]) + PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns]) + +instance Print [Def] where + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] + +instance Print [Type] where + prt _ [] = concatD [] + prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs] + +instance Print Type where + prt i = \case + TLit uident -> prPrec i 1 (concatD [prt 0 uident]) + TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2]) + +instance Print Lit where + prt i = \case + LInt int -> prt i int + LChar char -> prt i char diff --git a/src/Monomorphizer/MorbIr.hs b/src/Monomorphizer/MorbIr.hs new file mode 100644 index 0000000..3e5db6b --- /dev/null +++ b/src/Monomorphizer/MorbIr.hs @@ -0,0 +1,184 @@ +{-# LANGUAGE LambdaCase #-} +module Monomorphizer.MorbIr where + +import Grammar.Print +import TypeChecker.TypeCheckerIr qualified as TIR (Ident (..)) + +type Id = (TIR.Ident, Type) + +newtype Program = Program [Def] + deriving (Show, Ord, Eq) + +data Def = DBind Bind | DData Data + deriving (Show, Ord, Eq) + +data Data = Data Type [Inj] + deriving (Show, Ord, Eq) + +data Bind = Bind Id [Id] ExpT + deriving (Show, Ord, Eq) + +data Exp + = EVar TIR.Ident + | ELit Lit + | ELet Bind ExpT + | EApp ExpT ExpT + | EAdd ExpT ExpT + | ECase ExpT [Branch] + deriving (Show, Ord, Eq) + +data Pattern + = PVar Id + | PLit (Lit, Type) + | PInj TIR.Ident [Pattern] + | PCatch + | PEnum TIR.Ident + deriving (Eq, Ord, Show) + +data Branch = Branch (Pattern, Type) ExpT + deriving (Eq, Ord, Show) + +type ExpT = (Exp, Type) + +data Inj = Inj TIR.Ident Type + deriving (Show, Ord, Eq) + +data Lit + = LInt Integer + | LChar Char + deriving (Show, Ord, Eq) + +data Type = TLit TIR.Ident | TFun Type Type | TData TIR.Ident [Type] + + deriving (Show, Ord, Eq) + +flattenType :: Type -> [Type] +flattenType (TFun t1 t2) = t1 : flattenType t2 +flattenType x = [x] + +instance Print Program where + prt i (Program sc) = prPrec i 0 $ prt 0 sc + +instance Print (Bind) where + prt i (Bind sig@(name, _) parms rhs) = + prPrec i 0 $ + concatD + [ prtSig sig + , prt 0 name + , prtIdPs 0 parms + , doc $ showString "=" + , prt 0 rhs + ] + +prtSig :: Id -> Doc +prtSig (name, t) = + concatD + [ prt 0 name + , doc $ showString ":" + , prt 0 t + , doc $ showString ";" + ] + +instance Print (ExpT) where + prt i (e, t) = + concatD + [ doc $ showString "(" + , prt i e + , doc $ showString "," + , prt i t + , doc $ showString ")" + ] + +instance Print [Bind] where + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] + +prtIdPs :: Int -> [Id] -> Doc +prtIdPs i = prPrec i 0 . concatD . map (prt i) + +instance Print Exp where + prt i = \case + EVar name -> prPrec i 3 $ prt 0 name + ELit lit -> prPrec i 3 $ prt 0 lit + ELet b e -> + prPrec i 3 $ + concatD + [ doc $ showString "let" + , prt 0 b + , doc $ showString "in" + , prt 0 e + ] + EApp e1 e2 -> + prPrec i 2 $ + concatD + [ prt 2 e1 + , prt 3 e2 + ] + EAdd e1 e2 -> + prPrec i 1 $ + concatD + [ prt 1 e1 + , doc $ showString "+" + , prt 2 e2 + ] + ECase e branches -> + prPrec i 0 $ + concatD + [ doc $ showString "case" + , prt 0 e + , doc $ showString "of" + , doc $ showString "{" + , prt 0 branches + , doc $ showString "}" + ] + +instance Print Branch where + prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp]) + +instance Print [Branch] where + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] + +instance Print Def where + prt i = \case + DBind bind -> prPrec i 0 (concatD [prt 0 bind]) + DData data_ -> prPrec i 0 (concatD [prt 0 data_]) + +instance Print Data where + prt i = \case + Data type_ injs -> prPrec i 0 (concatD [doc (showString "data"), prt 0 type_, doc (showString "where"), doc (showString "{"), prt 0 injs, doc (showString "}")]) + +instance Print Inj where + prt i = \case + Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_]) + +instance Print Pattern where + prt i = \case + PVar name -> prPrec i 1 (concatD [prt 0 name]) + PLit (lit, _) -> prPrec i 1 (concatD [prt 0 lit]) + PCatch -> prPrec i 1 (concatD [doc (showString "_")]) + PEnum name -> prPrec i 1 (concatD [prt 0 name]) + PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns]) + +instance Print [Def] where + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] + +instance Print [Type] where + prt _ [] = concatD [] + prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs] + +instance Print Type where + prt i = \case + TLit uident -> prPrec i 1 (concatD [prt 0 uident]) + TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2]) + TData uident types -> prPrec i 1 (concatD [prt 0 uident, doc (showString "("), prt 0 types, doc (showString ")")]) + +instance Print Lit where + prt i = \case + LInt int -> prt i int + LChar char -> prt i char + diff --git a/src/OrderDefs.hs b/src/OrderDefs.hs new file mode 100644 index 0000000..079512b --- /dev/null +++ b/src/OrderDefs.hs @@ -0,0 +1,43 @@ +{-# LANGUAGE LambdaCase #-} + +module OrderDefs where + +import Control.Monad.State (State, execState, get, modify, when) +import Data.Function (on) +import Data.List (partition, sortBy) +import Data.Set (Set) +import qualified Data.Set as Set +import Grammar.Abs + +orderDefs :: Program -> Program +orderDefs (Program defs) = + Program $ not_binds ++ map DBind (has_sig ++ orderBinds no_sig) + + where + (has_sig, no_sig) = partition (\(Bind n _ _) -> elem n sig_names) + [ b | DBind b <- defs] + sig_names = [ n | DSig (Sig n _) <- defs ] + not_binds = flip filter defs $ \case DBind _ -> False + _ -> True + +orderBinds :: [Bind] -> [Bind] +orderBinds binds = sortBy (on compare countUniqueCalls) binds + where + bind_names = [ n | Bind n _ _ <- binds] + + countUniqueCalls :: Bind -> Int + countUniqueCalls (Bind n _ e) = + Set.size $ execState (go e) (Set.singleton n) + where + go :: Exp -> State (Set LIdent) () + go exp = get >>= \called -> case exp of + EVar x -> when (Set.notMember x called && elem x bind_names) $ + modify (Set.insert x) + EApp e1 e2 -> on (>>) go e1 e2 + EAdd e1 e2 -> on (>>) go e1 e2 + ELet (Bind _ _ e) e' -> on (>>) go e e' + EAbs _ e -> go e + ECase e bs -> go e >> mapM_ (\(Branch _ e) -> go e) bs + EAnn e _ -> go e + EInj _ -> pure () + ELit _ -> pure () diff --git a/src/Renamer.hs b/src/Renamer.hs deleted file mode 100644 index b284e92..0000000 --- a/src/Renamer.hs +++ /dev/null @@ -1,84 +0,0 @@ -{-# LANGUAGE LambdaCase #-} - -module Renamer (module Renamer) where - -import Auxiliary (mapAccumM) -import Control.Monad.State (MonadState, State, evalState, gets, - modify) -import Data.Map (Map) -import qualified Data.Map as Map -import Data.Maybe (fromMaybe) -import Data.Tuple.Extra (dupe) -import Grammar.Abs - - --- | Rename all variables and local binds -rename :: Program -> Program -rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0 - where - initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs - renameSc :: Names -> Bind -> Rn Bind - renameSc old_names (Bind name t _ parms rhs) = do - (new_names, parms') <- newNames old_names parms - rhs' <- snd <$> renameExp new_names rhs - pure $ Bind name t name parms' rhs' - - --- | Rename monad. State holds the number of renamed names. -newtype Rn a = Rn { runRn :: State Int a } - deriving (Functor, Applicative, Monad, MonadState Int) - --- | Maps old to new name -type Names = Map Ident Ident - -renameLocalBind :: Names -> Bind -> Rn (Names, Bind) -renameLocalBind old_names (Bind name t _ parms rhs) = do - (new_names, name') <- newName old_names name - (new_names', parms') <- newNames new_names parms - (new_names'', rhs') <- renameExp new_names' rhs - pure (new_names'', Bind name' t name' parms' rhs') - -renameExp :: Names -> Exp -> Rn (Names, Exp) -renameExp old_names = \case - EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names) - - EInt i1 -> pure (old_names, EInt i1) - - EApp e1 e2 -> do - (env1, e1') <- renameExp old_names e1 - (env2, e2') <- renameExp old_names e2 - pure (Map.union env1 env2, EApp e1' e2') - - EAdd e1 e2 -> do - (env1, e1') <- renameExp old_names e1 - (env2, e2') <- renameExp old_names e2 - pure (Map.union env1 env2, EAdd e1' e2') - - ELet b e -> do - (new_names, b) <- renameLocalBind old_names b - (new_names', e') <- renameExp new_names e - pure (new_names', ELet b e') - - EAbs par t e -> do - (new_names, par') <- newName old_names par - (new_names', e') <- renameExp new_names e - pure (new_names', EAbs par' t e') - - EAnn e t -> do - (new_names, e') <- renameExp old_names e - pure (new_names, EAnn e' t) - --- | Create a new name and add it to name environment. -newName :: Names -> Ident -> Rn (Names, Ident) -newName env old_name = do - new_name <- makeName old_name - pure (Map.insert old_name new_name env, new_name) - --- | Create multiple names and add them to the name environment -newNames :: Names -> [Ident] -> Rn (Names, [Ident]) -newNames = mapAccumM newName - --- | Annotate name with number and increment the number @prefix ⇒ prefix_number@. -makeName :: Ident -> Rn Ident -makeName (Ident prefix) = gets (\i -> Ident $ prefix ++ "_" ++ show i) <* modify succ - diff --git a/src/Renamer/Renamer.hs b/src/Renamer/Renamer.hs new file mode 100644 index 0000000..1eee3f0 --- /dev/null +++ b/src/Renamer/Renamer.hs @@ -0,0 +1,112 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedRecordDot #-} + +module Renamer.Renamer (rename) where + +import Auxiliary (maybeToRightM, onM, partitionDefs) +import Control.Applicative (liftA2) +import Control.Monad.Except (ExceptT, MonadError, runExceptT) +import Control.Monad.State (MonadState, State, evalState, gets, + modify) +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Tuple.Extra (dupe) +import Grammar.Abs +import Grammar.ErrM (Err) +import Grammar.Print (printTree) + +-- | Rename all variables and local binds +rename :: Program -> Err Program +rename (Program defs) = rename' $ do + ds' <- mapM (fmap DData . rnData) ds + ss' <- mapM (fmap DSig . rnSig) ss + bs' <- mapM (fmap DBind . rnTopBind) bs + pure $ Program (ds' ++ ss' ++ bs') + where + (ds, ss, bs) = partitionDefs defs + rename' = flip evalState initCxt + . runExceptT + . runRn + initCxt = Cxt + { counter = 0 + , names = Map.fromList $ [ dupe n | Sig n _ <- ss ] + ++ [ dupe n | Bind n _ _ <- bs ] + } +rnData :: Data -> Rn Data +rnData (Data typ injs) = liftA2 Data (rnType typ) (mapM rnInj injs) + where + rnInj (Inj name t) = Inj name <$> rnType t + +rnSig :: Sig -> Rn Sig +rnSig (Sig name typ) = liftA2 Sig (getName name) (rnType typ) + +rnType :: Type -> Rn Type +rnType = \case + TVar (MkTVar name) -> TVar . MkTVar <$> getName name + TData name ts -> TData name <$> localNames (mapM rnType ts) + TFun t1 t2 -> onM TFun (localNames . rnType) t1 t2 + TAll (MkTVar name) t -> liftA2 (TAll . MkTVar) (newName name) (rnType t) + typ -> pure typ + +rnTopBind :: Bind -> Rn Bind +rnTopBind = rnBind' False + +rnLocalBind :: Bind -> Rn Bind +rnLocalBind = rnBind' True + +rnBind' :: Bool -> Bind -> Rn Bind +rnBind' isLocal (Bind name vars rhs) = do + name' <- if isLocal then newName name else getName name + (vars', rhs') <- localNames $ liftA2 (,) (mapM newName vars) (rnExp rhs) + pure (Bind name' vars' rhs') + +rnExp :: Exp -> Rn Exp +rnExp = \case + EVar x -> EVar <$> getName x + EInj x -> pure (EInj x) + ELit lit -> pure (ELit lit) + EApp e1 e2 -> onM EApp (localNames . rnExp) e1 e2 + EAdd e1 e2 -> onM EAdd (localNames . rnExp) e1 e2 + ELet bind e -> liftA2 ELet (rnLocalBind bind) (rnExp e) + EAbs x e -> liftA2 EAbs (newName x) (rnExp e) + EAnn e t -> liftA2 EAnn (rnExp e) (rnType t) + ECase e bs -> liftA2 ECase (rnExp e) (mapM (localNames . rnBranch) bs) + +rnBranch :: Branch -> Rn Branch +rnBranch (Branch p e) = liftA2 Branch (rnPattern p) (rnExp e) + +rnPattern :: Pattern -> Rn Pattern +rnPattern = \case + PVar x -> PVar <$> newName x + PLit lit -> pure (PLit lit) + PCatch -> pure PCatch + PEnum name -> pure (PEnum name) + PInj name ps -> PInj name <$> mapM rnPattern ps + +data Cxt = Cxt + { counter :: Int + , names :: Map LIdent LIdent + } + +-- | Rename monad. State holds the number of renamed names. +newtype Rn a = Rn {runRn :: ExceptT String (State Cxt) a} + deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String) + +getName :: LIdent -> Rn LIdent +getName name = maybeToRightM err =<< gets (Map.lookup name . names) + where err = "Can't find new name " ++ printTree name + +newName :: LIdent -> Rn LIdent +newName name = do + name' <- gets (mk name . counter) + modify $ \cxt -> cxt { counter = succ cxt.counter + , names = Map.insert name name' cxt.names + } + pure name' + where + mk (LIdent name) i = LIdent ("$" ++ show i ++ name) + +localNames :: MonadState Cxt m => m b -> m b +localNames m = do + old_names <- gets names + m <* modify ( \cxt' -> cxt' { names = old_names }) diff --git a/src/ReportForall.hs b/src/ReportForall.hs new file mode 100644 index 0000000..8b5e9db --- /dev/null +++ b/src/ReportForall.hs @@ -0,0 +1,68 @@ +{-# LANGUAGE LambdaCase #-} + +module ReportForall (reportForall) where + +import Auxiliary (partitionDefs) +import Control.Monad (unless, void, when) +import Control.Monad.Except (MonadError (throwError)) +import Data.Either.Combinators (mapRight) +import Data.Foldable (foldlM) +import Data.Function (on) +import Data.List (delete) +import Grammar.Abs +import Grammar.ErrM (Err) +import TypeChecker.TypeChecker (TypeChecker (Bi, Hm)) + +reportForall :: TypeChecker -> Program -> Err () +reportForall tc p = do + when (tc == Hm) $ rpProgram rpaType p + rpProgram rpuType p + +rpuType :: Type -> Err () +rpuType typ = do + tvars <- go [] typ + unless (null tvars) $ throwError "Unused forall" + where + go tvars = \case + TAll tvar t + | tvar `elem` tvars -> throwError "Unused forall" + | otherwise -> go (tvar : tvars) t + TVar tvar -> pure (delete tvar tvars) + TFun t1 t2 -> go tvars t1 >>= (`go` t2) + TData _ typs -> foldlM go tvars typs + _ -> pure tvars + + +rpaType :: Type -> Err () +rpaType = rpForall . skipForall + where + skipForall = \case + TAll _ t -> skipForall t + t -> t + rpForall = \case + TAll {} -> throwError "Higher rank forall not allowed" + TFun t1 t2 -> on (>>) rpForall t1 t2 + TData _ typs -> mapM_ rpForall typs + _ -> pure () + +rpProgram :: (Type -> Err ()) -> Program -> Err () +rpProgram rf (Program defs) = do + mapM_ rpuBind bs + mapM_ rpuData ds + mapM_ rpuSig ss + where + (ds, ss, bs) = partitionDefs defs + rpuSig (Sig _ typ) = rf typ + rpuData (Data typ injs) = rf typ >> mapM rpuInj injs + rpuInj (Inj _ typ) = rf typ + rpuBind (Bind _ _ rhs) = rpuExp rhs + rpuBranch (Branch _ e) = rpuExp e + rpuExp = \case + EAnn e t -> rpuExp e >> rf t + EApp e1 e2 -> on (>>) rpuExp e1 e2 + EAdd e1 e2 -> on (>>) rpuExp e1 e2 + ELet bind e -> rpuBind bind >> rpuExp e + EAbs _ e -> rpuExp e + ECase e bs -> rpuExp e >> mapM_ rpuBranch bs + _ -> pure () + diff --git a/src/TreeConverter.hs b/src/TreeConverter.hs new file mode 100644 index 0000000..2dfa7d2 --- /dev/null +++ b/src/TreeConverter.hs @@ -0,0 +1,13 @@ +module TreeConverter where + +--import qualified Grammar.Abs as G +--import qualified TypeChecker.TypeCheckerIr as T +-- +--convertToTypecheckerIR :: G.Program -> Either String T.Program +--convertToTypecheckerIR (G.Program defs) = T.Program (map convertDef defs) +-- +--convertDef :: G.Bind -> T.Bind +--convertDef (G.Bind name t _ args exp) = T.Bind (name, t) (map (\i -> (i, T.TMono "Int"))) (convertExp exp) +-- +-- + diff --git a/src/TypeChecker.hs b/src/TypeChecker.hs deleted file mode 100644 index 1e44888..0000000 --- a/src/TypeChecker.hs +++ /dev/null @@ -1,178 +0,0 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedRecordDot #-} - -module TypeChecker (typecheck, partitionType) where - -import Auxiliary (maybeToRightM, snoc) -import Control.Monad.Except (throwError, unless) -import Data.Map (Map) -import qualified Data.Map as Map -import Grammar.Abs -import Grammar.ErrM (Err) -import Grammar.Print (Print (prt), concatD, doc, printTree, - render) -import Prelude hiding (exp, id) -import qualified TypeCheckerIr as T - --- NOTE: this type checker is poorly tested - --- TODO --- Coercion --- Type inference - -data Cxt = Cxt - { env :: Map Ident Type -- ^ Local scope signature - , sig :: Map Ident Type -- ^ Top-level signatures - } - -initCxt :: [Bind] -> Cxt -initCxt sc = Cxt { env = mempty - , sig = Map.fromList $ map (\(Bind n t _ _ _) -> (n, t)) sc - } - -typecheck :: Program -> Err T.Program -typecheck (Program sc) = T.Program <$> mapM (checkBind $ initCxt sc) sc - --- | Check if infered rhs type matches type signature. -checkBind :: Cxt -> Bind -> Err T.Bind -checkBind cxt b = - case expandLambdas b of - Bind name t _ parms rhs -> do - (rhs', t_rhs) <- infer cxt rhs - unless (typeEq t_rhs t) . throwError $ typeErr name t t_rhs - pure $ T.Bind (name, t) (zip parms ts_parms) rhs' - where - ts_parms = fst $ partitionType (length parms) t - --- | @ f x y = rhs ⇒ f = \x.\y. rhs @ -expandLambdas :: Bind -> Bind -expandLambdas (Bind name t _ parms rhs) = Bind name t name [] rhs' - where - rhs' = foldr ($) rhs $ zipWith EAbs parms ts_parms - ts_parms = fst $ partitionType (length parms) t - --- | Infer type of expression. -infer :: Cxt -> Exp -> Err (T.Exp, Type) -infer cxt = \case - EId x -> - case lookupEnv x cxt of - Nothing -> - case lookupSig x cxt of - Nothing -> throwError ("Unbound variable:" ++ printTree x) - Just t -> pure (T.EId (x, t), t) - Just t -> pure (T.EId (x, t), t) - - EInt i -> pure (T.EInt i, T.TInt) - - EApp e e1 -> do - (e', t) <- infer cxt e - case t of - TFun t1 t2 -> do - e1' <- check cxt e1 t1 - pure (T.EApp t2 e' e1', t2) - _ -> do - throwError ("Not a function: " ++ show e) - - EAdd e e1 -> do - e' <- check cxt e T.TInt - e1' <- check cxt e1 T.TInt - pure (T.EAdd T.TInt e' e1', T.TInt) - - EAbs x t e -> do - (e', t1) <- infer (insertEnv x t cxt) e - let t_abs = TFun t t1 - pure (T.EAbs t_abs (x, t) e', t_abs) - - ELet b e -> do - let cxt' = insertBind b cxt - b' <- checkBind cxt' b - (e', t) <- infer cxt' e - pure (T.ELet b' e', t) - - EAnn e t -> do - (e', t1) <- infer cxt e - unless (typeEq t t1) $ - throwError "Inferred type and type annotation doesn't match" - pure (e', t1) - --- | Check infered type matches the supplied type. -check :: Cxt -> Exp -> Type -> Err T.Exp -check cxt exp typ = case exp of - - EId x -> do - t <- case lookupEnv x cxt of - Nothing -> maybeToRightM - ("Unbound variable:" ++ printTree x) - (lookupSig x cxt) - Just t -> pure t - unless (typeEq t typ) . throwError $ typeErr x typ t - pure $ T.EId (x, t) - - EInt i -> do - unless (typeEq typ TInt) $ throwError $ typeErr i TInt typ - pure $ T.EInt i - - EApp e e1 -> do - (e', t) <- infer cxt e - case t of - TFun t1 t2 -> do - e1' <- check cxt e1 t1 - pure $ T.EApp t2 e' e1' - _ -> throwError ("Not a function 2: " ++ printTree e) - - EAdd e e1 -> do - e' <- check cxt e T.TInt - e1' <- check cxt e1 T.TInt - pure $ T.EAdd T.TInt e' e1' - - EAbs x t e -> do - (e', t_e) <- infer (insertEnv x t cxt) e - let t1 = TFun t t_e - unless (typeEq t1 typ) $ throwError "Wrong lamda type!" - pure $ T.EAbs t1 (x, t) e' - - ELet b e -> do - let cxt' = insertBind b cxt - b' <- checkBind cxt' b - e' <- check cxt' e typ - pure $ T.ELet b' e' - - EAnn e t -> do - unless (typeEq t typ) $ - throwError "Inferred type and type annotation doesn't match" - check cxt e t - --- | Check if types are equivalent. Doesn't handle coercion or polymorphism. -typeEq :: Type -> Type -> Bool -typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1 -typeEq t t1 = t == t1 - --- | Partion type into types of parameters and return type. -partitionType :: Int -- Number of parameters to apply - -> Type - -> ([Type], Type) -partitionType = go [] - where - go acc 0 t = (acc, t) - go acc i t = case t of - TFun t1 t2 -> go (snoc t1 acc) (i - 1) t2 - _ -> error "Number of parameters and type doesn't match" - -insertBind :: Bind -> Cxt -> Cxt -insertBind (Bind n t _ _ _) = insertEnv n t - -lookupEnv :: Ident -> Cxt -> Maybe Type -lookupEnv x = Map.lookup x . env - -insertEnv :: Ident -> Type -> Cxt -> Cxt -insertEnv x t cxt = cxt { env = Map.insert x t cxt.env } - -lookupSig :: Ident -> Cxt -> Maybe Type -lookupSig x = Map.lookup x . sig - -typeErr :: Print a => a -> Type -> Type -> String -typeErr p expected actual = render $ concatD - [ doc $ showString "Wrong type:", prt 0 p , doc $ showString "\n" - , doc $ showString "Expected:" , prt 0 expected, doc $ showString "\n" - , doc $ showString "Actual: " , prt 0 actual - ] diff --git a/src/TypeChecker/Bugs.md b/src/TypeChecker/Bugs.md new file mode 100644 index 0000000..d1fd70d --- /dev/null +++ b/src/TypeChecker/Bugs.md @@ -0,0 +1,48 @@ +# Bugs + +## Using uninstantiated type variables + +Program below should not type check + +```hs +data Test (a) where { + Test : b -> Test (a) + }; +``` + +## Duplicate definitions of functions + +Program below should not type check + +```hs +id x = x ; +id x = x ; +``` + +## What? + +Program below should not type check + +```hs +main : a -> b ; +main x = x; +``` +## Pattern match on functions + +Program below should not type check + +```hs +main = case \x. x of { + _ => 0; +}; +``` + +# Inference should not depend on order + +This one is really tough, strangely +Spent many hours on this so far + +```hs +main = id 0 ; +id x = x; +``` diff --git a/src/TypeChecker/RemoveForall.hs b/src/TypeChecker/RemoveForall.hs new file mode 100644 index 0000000..886ecb0 --- /dev/null +++ b/src/TypeChecker/RemoveForall.hs @@ -0,0 +1,49 @@ +{-# LANGUAGE LambdaCase #-} + +module TypeChecker.RemoveForall (removeForall) where + +import Auxiliary (onM) +import Control.Applicative (Applicative (liftA2)) +import Data.Function (on) +import Data.List (partition) +import Data.Tuple.Extra (second) +import Grammar.ErrM (Err) +import qualified TypeChecker.ReportTEVar as R +import TypeChecker.TypeCheckerIr + +removeForall :: Program' R.Type -> Program +removeForall (Program defs) = Program $ map (DData . rfData) ds + ++ map (DBind . rfBind) bs + where + (ds, bs) = ([d | DData d <- defs ], [ b | DBind b <- defs ]) + rfData (Data typ injs) = Data (rfType typ) (map rfInj injs) + rfInj (Inj name typ) = Inj name (rfType typ) + rfBind (Bind name vars rhs) = Bind (rfId name) (map rfId vars) (rfExpT rhs) + rfId = second rfType + rfExpT (e, t) = (rfExp e, rfType t) + rfExp = \case + EApp e1 e2 -> on EApp rfExpT e1 e2 + EAdd e1 e2 -> on EAdd rfExpT e1 e2 + ELet bind e -> ELet (rfBind bind) (rfExpT e) + EAbs name e -> EAbs name (rfExpT e) + ECase e bs -> ECase (rfExpT e) (map rfBranch bs) + ELit lit -> ELit lit + EVar name -> EVar name + EInj name -> EInj name + rfBranch (Branch p e) = Branch (rfPatternT p) (rfExpT e) + rfPatternT (p, t) = (rfPattern p, rfType t) + rfPattern = \case + PVar name -> PVar name + PLit lit -> PLit lit + PCatch -> PCatch + PEnum name -> PEnum name + PInj name ps -> PInj name (map rfPatternT ps) + +rfType :: R.Type -> Type +rfType = \case + R.TAll _ t -> rfType t + R.TFun t1 t2 -> on TFun rfType t1 t2 + R.TData name ts -> TData name (map rfType ts) + R.TLit lit -> TLit lit + R.TVar tvar -> TVar tvar + diff --git a/src/TypeChecker/ReportTEVar.hs b/src/TypeChecker/ReportTEVar.hs new file mode 100644 index 0000000..62cd301 --- /dev/null +++ b/src/TypeChecker/ReportTEVar.hs @@ -0,0 +1,84 @@ +{-# LANGUAGE LambdaCase #-} + +module TypeChecker.ReportTEVar where + +import Auxiliary (onM) +import Control.Applicative (Applicative (liftA2), liftA3) +import Control.Monad.Except (MonadError (throwError)) +import Data.Coerce (coerce) +import Data.Tuple.Extra (secondM) +import Grammar.Abs qualified as G +import Grammar.ErrM (Err) +import Grammar.Print (printTree) +import TypeChecker.TypeCheckerIr hiding (Type (..)) + +data Type + = TLit Ident + | TVar TVar + | TData Ident [Type] + | TFun Type Type + | TAll TVar Type + deriving (Eq, Ord, Show, Read) + +class ReportTEVar a b where + reportTEVar :: a -> Err b + +instance ReportTEVar (Program' G.Type) (Program' Type) where + reportTEVar (Program defs) = Program <$> reportTEVar defs + +instance ReportTEVar (Def' G.Type) (Def' Type) where + reportTEVar = \case + DBind bind -> DBind <$> reportTEVar bind + DData dat -> DData <$> reportTEVar dat + +instance ReportTEVar (Bind' G.Type) (Bind' Type) where + reportTEVar (Bind id vars rhs) = liftA3 Bind (reportTEVar id) (reportTEVar vars) (reportTEVar rhs) + +instance ReportTEVar (Exp' G.Type) (Exp' Type) where + reportTEVar exp = case exp of + EVar name -> pure $ EVar name + EInj name -> pure $ EInj name + ELit lit -> pure $ ELit lit + ELet bind e -> liftA2 ELet (reportTEVar bind) (reportTEVar e) + EApp e1 e2 -> onM EApp reportTEVar e1 e2 + EAdd e1 e2 -> onM EAdd reportTEVar e1 e2 + EAbs name e -> EAbs name <$> reportTEVar e + ECase e branches -> liftA2 ECase (reportTEVar e) (reportTEVar branches) + +instance ReportTEVar (Branch' G.Type) (Branch' Type) where + reportTEVar (Branch (patt, t_patt) e) = liftA2 Branch (liftA2 (,) (reportTEVar patt) (reportTEVar t_patt)) (reportTEVar e) + +instance ReportTEVar (Pattern' G.Type, G.Type) (Pattern' Type, Type) where + reportTEVar (p, t) = liftA2 (,) (reportTEVar p) (reportTEVar t) + +instance ReportTEVar (Pattern' G.Type) (Pattern' Type) where + reportTEVar = \case + PVar name -> pure $ PVar name + PLit lit -> pure $ PLit lit + PCatch -> pure PCatch + PEnum name -> pure $ PEnum name + PInj name ps -> PInj name <$> reportTEVar ps + +instance ReportTEVar (Data' G.Type) (Data' Type) where + reportTEVar (Data typ injs) = liftA2 Data (reportTEVar typ) (reportTEVar injs) + +instance ReportTEVar (Inj' G.Type) (Inj' Type) where + reportTEVar (Inj name typ) = Inj name <$> reportTEVar typ + +instance ReportTEVar (Id' G.Type) (Id' Type) where + reportTEVar = secondM reportTEVar + +instance ReportTEVar (ExpT' G.Type) (ExpT' Type) where + reportTEVar (exp, typ) = liftA2 (,) (reportTEVar exp) (reportTEVar typ) + +instance ReportTEVar a b => ReportTEVar [a] [b] where + reportTEVar = mapM reportTEVar + +instance ReportTEVar G.Type Type where + reportTEVar = \case + G.TLit lit -> pure $ TLit (coerce lit) + G.TVar (G.MkTVar i) -> pure $ TVar (MkTVar $ coerce i) + G.TData name typs -> TData (coerce name) <$> reportTEVar typs + G.TFun t1 t2 -> liftA2 TFun (reportTEVar t1) (reportTEVar t2) + G.TAll (G.MkTVar i) t -> TAll (MkTVar $ coerce i) <$> reportTEVar t + G.TEVar tevar -> throwError ("Found TEVar: " ++ printTree tevar) diff --git a/src/TypeChecker/TypeChecker.hs b/src/TypeChecker/TypeChecker.hs new file mode 100644 index 0000000..7f3d67a --- /dev/null +++ b/src/TypeChecker/TypeChecker.hs @@ -0,0 +1,20 @@ +module TypeChecker.TypeChecker (typecheck, TypeChecker (..)) where + +import Control.Monad ((<=<)) +import qualified Grammar.Abs as G +import Grammar.ErrM (Err) +import TypeChecker.RemoveForall (removeForall) +import qualified TypeChecker.ReportTEVar as R +import TypeChecker.ReportTEVar (reportTEVar) +import qualified TypeChecker.TypeCheckerBidir as Bi +import qualified TypeChecker.TypeCheckerHm as Hm +import TypeChecker.TypeCheckerIr + +data TypeChecker = Bi | Hm deriving Eq + +typecheck :: TypeChecker -> G.Program -> Err Program +typecheck tc = fmap removeForall . (reportTEVar <=< f) + where + f = case tc of + Bi -> Bi.typecheck + Hm -> fmap fst . Hm.typecheck diff --git a/src/TypeChecker/TypeCheckerBidir.hs b/src/TypeChecker/TypeCheckerBidir.hs new file mode 100644 index 0000000..04a8d91 --- /dev/null +++ b/src/TypeChecker/TypeCheckerBidir.hs @@ -0,0 +1,858 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedRecordDot #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} +{-# OPTIONS_GHC -Wno-incomplete-patterns #-} + +module TypeChecker.TypeCheckerBidir (typecheck) where + +import Auxiliary (int, litType, maybeToRightM, snoc) +import Control.Applicative (Applicative (liftA2), (<|>)) +import Control.Monad.Except (ExceptT, MonadError (throwError), + forM, runExceptT, unless, zipWithM, + zipWithM_) +import Control.Monad.Extra (fromMaybeM, ifM) +import Control.Monad.State (MonadState, State, evalState, gets, + modify) +import Data.Coerce (coerce) +import Data.Foldable (foldlM) +import Data.Function (on) +import Data.List (intercalate) +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Maybe (fromMaybe, isNothing) +import Data.Sequence (Seq (..)) +import qualified Data.Sequence as S +import qualified Data.Set as Set +import Data.Tuple.Extra (second) +import Debug.Trace (trace) +import Grammar.Abs +import Grammar.ErrM +import Grammar.Print (printTree) +import Prelude hiding (exp) +import qualified TypeChecker.TypeCheckerIr as T + +-- Implementation is derived from the paper (Dunfield and Krishnaswami 2013) +-- https://doi.org/10.1145/2500365.2500582 +-- +-- TODO +-- • Fix problems with types in Pattern/Branch in TypeCheckerIr +-- • Remove EAdd +-- • Add kinds!! + +data EnvElem = EnvVar LIdent Type -- ^ Term variable typing. x : A + | EnvTVar TVar -- ^ Universal type variable. α + | EnvTEVar TEVar -- ^ Existential unsolved type variable. ά + | EnvTEVarSolved TEVar Type -- ^ Existential solved type variable. ά = τ + | EnvMark TEVar -- ^ Scoping Marker. ▶ ά + deriving (Eq, Show) + +type Env = Seq EnvElem + +-- | Ordered context +-- Γ ::= ・| Γ, α | Γ, ά | Γ, ▶ ά | Γ, x:A +data Cxt = Cxt + { env :: Env -- ^ Local scope context Γ + , sig :: Map LIdent Type -- ^ Top-level signatures x : A + , binds :: Map LIdent Exp -- ^ Top-level binds x : e + , next_tevar :: Int -- ^ Counter to distinguish ά + , data_injs :: Map UIdent Type -- ^ Data injections (constructors) K/inj : A + , currentBind :: LIdent -- ^ Used for recursive functions + } deriving (Show, Eq) + +newtype Tc a = Tc { runTc :: ExceptT String (State Cxt) a } + deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String) + + +initCxt :: [Def] -> Cxt +initCxt defs = Cxt + { env = mempty + , sig = Map.fromList [ (name, t) + | DSig' name t <- defs + ] + , binds = Map.fromList [ (name, foldr EAbs rhs vars) + | DBind' name vars rhs <- defs + ] + , next_tevar = 0 + , data_injs = Map.fromList [ (name, foldr TAll t $ unboundedTVars t) + | DData (Data _ injs) <- defs + , Inj name t <- injs + ] + , currentBind = "" + } + where + unboundedTVars = uncurry (Set.\\) . go (mempty, mempty) + where + go (unbounded, bounded) = \case + TAll tvar t -> go (unbounded, Set.insert tvar bounded) t + TVar tvar -> (Set.insert tvar unbounded, bounded) + TFun t1 t2 -> foldl go (unbounded, bounded) [t1, t2] + TData _ typs -> foldl go (unbounded, bounded) typs + _ -> (unbounded, bounded) + +typecheck :: Program -> Err (T.Program' Type) +typecheck (Program defs) = do + dataTypes' <- mapM typecheckDataType [ d | DData d <- defs ] + binds' <- typecheckBinds (initCxt defs) [b | DBind b <- defs] + pure . T.Program $ map T.DData dataTypes' ++ map T.DBind binds' + +typecheckBinds :: Cxt -> [Bind] -> Err [T.Bind' Type] +typecheckBinds cxt = flip evalState cxt + . runExceptT + . runTc + . mapM typecheckBind + +typecheckBind :: Bind -> Tc (T.Bind' Type) +typecheckBind (Bind name vars rhs) = do + modify $ \cxt -> cxt { currentBind = name } + bind'@(T.Bind (name, typ) _ _) <- lookupSig name >>= \case + Just t -> do + (rhs', _) <- check (foldr EAbs rhs vars) t + pure (T.Bind (coerce name, t) [] (rhs', t)) + Nothing -> do + (e, t) <- apply =<< infer (foldr EAbs rhs vars) + pure (T.Bind (coerce name, t) [] (e, t)) + env <- gets env + unless (isComplete env) err + insertSig (coerce name) typ + putEnv Empty + pure bind' + where + err = throwError $ unlines + [ "Type inference failed: " ++ printTree (Bind name vars rhs) + , "Did you forget to add type annotation to a polymorphic function?" + ] + +-- TODO remove some checks +typecheckDataType :: Data -> Err (T.Data' Type) +typecheckDataType (Data typ injs) = do + (name, tvars) <- go [] typ + injs' <- mapM (\i -> typecheckInj i name tvars) injs + pure (T.Data typ injs') + where + go tvars = \case + TAll tvar t -> go (tvar:tvars) t + TData name typs + | Right tvars' <- mapM toTVar typs + , all (`elem` tvars) tvars' + -> pure (name, tvars') + _ -> throwError $ unwords ["Bad data type definition: ", ppT typ] + +-- TODO remove some checks +typecheckInj :: Inj -> UIdent -> [TVar] -> Err (T.Inj' Type) +typecheckInj (Inj inj_name inj_typ) name tvars + | not $ boundTVars tvars inj_typ + = throwError "Unbound type variables" + | TData name' typs <- getDataId inj_typ + , name' == name + , Right tvars' <- mapM toTVar typs + , all (`elem` tvars) tvars' + = pure $ T.Inj (coerce inj_name) (foldr TAll inj_typ tvars') + | otherwise + = throwError $ unwords + ["Bad type constructor: ", show name + , "\nExpected: ", ppT . TData name $ map TVar tvars + , "\nActual: ", ppT $ getDataId inj_typ + ] + where + boundTVars :: [TVar] -> Type -> Bool + boundTVars tvars' = \case + TAll tvar t -> boundTVars (tvar:tvars') t + TFun t1 t2 -> on (&&) (boundTVars tvars') t1 t2 + TVar tvar -> elem tvar tvars' + TData _ typs -> all (boundTVars tvars) typs + TLit _ -> True + TEVar _ -> error "TEVar in data type declaration" + + + +--------------------------------------------------------------------------- +-- * Typing rules +--------------------------------------------------------------------------- + +-- | Γ ⊢ e ↑ A ⊣ Δ +-- Under input context Γ, e checks against input type A, with output context ∆ +check :: Exp -> Type -> Tc (T.ExpT' Type) + +-- Γ,α ⊢ e ↑ A ⊣ Δ,α,Θ +-- ------------------- ∀I +-- Γ ⊢ e ↑ ∀α.A ⊣ Δ +check e (TAll alpha a) = do + let env_tvar = EnvTVar alpha + insertEnv env_tvar + e' <- check e a + (env_l, _) <- gets (splitOn env_tvar . env) + putEnv env_l + apply e' + +-- Γ,(x:A) ⊢ e ↑ B ⊢ Δ,(x:A),Θ +-- --------------------------- →I +-- Γ ⊢ λx.e ↑ A → B ⊣ Δ +check (EAbs x e) (TFun a b) = do + let env_var = EnvVar x a + insertEnv env_var + e' <- check e b + (env_l, _) <- gets (splitOn env_var . env) + putEnv env_l + apply (T.EAbs (coerce x) e', TFun a b) + + --FIXME +-- Γ ⊢ e ↑ A ⊣ Θ Θ ⊢ Π ∷ [Θ]A ↓ C ⊣ Δ +-- ------------------------------------ Case +-- Γ ⊢ case e of Π ↓ C ⊣ Δ +check (ECase scrut pi) c = do + (scrut', a) <- infer scrut + case pi of + [] -> do + subtype a c + apply (T.ECase (scrut', a) [], a) + _ -> do + pi' <- forM pi $ \(Branch p e) -> do + p' <- checkPattern p =<< apply a + e' <- check e c + pure (T.Branch p' e') + apply (T.ECase (scrut', a) pi', c) + where + go (pi, b) (Branch p e) = do + p' <- checkPattern p =<< apply a + e'@(_, b') <- infer e + subtype b' b + apply (T.Branch p' e' : pi, b') + + +-- Γ,α ⊢ e ↓ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]B ⊣ Δ +-- -------------------------------------- Sub +-- Γ ⊢ e ↑ B ⊣ Δ +check e b = do + (e', a) <- infer e + b' <- apply b + subtype a b' + apply (e', b) + + + + +checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type) +checkPattern patt t_patt = case patt of + + -- ------------------- + -- Γ ⊢ x ↑ A ⊣ Γ,(x:A) + PVar x -> do + insertEnv $ EnvVar x t_patt + apply (T.PVar (coerce x), t_patt) + + -- ------------- + -- Γ ⊢ _ ↑ A ⊣ Γ + PCatch -> apply (T.PCatch, t_patt) + + -- Γ ⊢ τ ↓ A ⊣ Γ Γ ⊢ A <: B ⊣ Δ + -- ------------------------------ + -- Γ ⊢ τ ↑ B ⊣ Δ + PLit lit -> do + subtype (litType lit) t_patt + apply (T.PLit lit, t_patt) + + -- Γ ∋ (K : A) Γ ⊢ A <: B ⊣ Δ + -- --------------------------- + -- Γ ⊢ K ↑ B ⊣ Δ + PEnum name -> do + t <- maybeToRightM ("Unknown constructor " ++ show name) + =<< lookupInj name + subtype t t_patt + apply (T.PEnum (coerce name), t_patt) + + -- Example + -- Γ ∋ (K : A) let A = ∀α. A₁ -> A₂ -> Tτs + -- Γ ⊢ [ά/α]Tτs <: B ⊣ Θ₁ + -- Θ ⊢ p₁ ↑ [Θ][ά/α]A₁ ⊣ Θ₂ + -- Θ₂ ⊢ p₂ ↑ [Θ₂][ά/α]A₂ ⊣ Δ + -- --------------------------- + -- Γ ⊢ K p₁ p₂ ↑ B ⊣ Δ + PInj name ps -> do + t_inj <- maybeToRightM "unknown constructor" =<< lookupInj name + let ts = getArgs t_inj + unless (length ts == length ps) + $ throwError "Wrong number of arguments!" + + -- [ά/α] + sub <- substituteTVarsOf t_inj + subtype (sub $ getDataId t_inj) t_patt + let check p t = checkPattern p =<< apply (sub t) + ps' <- zipWithM check ps ts + apply (T.PInj (coerce name) ps', t_patt) + where + substituteTVarsOf = \case + TAll tvar t -> do + tevar <- fresh + (substitute tvar tevar .) <$> substituteTVarsOf t + _ -> pure id + + getArgs = \case + TAll _ t -> getArgs t + t -> go [] t + where + go acc = \case + TFun t1 t2 -> go (snoc t1 acc) t2 + _ -> acc + +-- | Γ ⊢ e ↓ A ⊣ Δ +-- Under input context Γ, e infers output type A, with output context ∆ +infer :: Exp -> Tc (T.ExpT' Type) +infer (ELit lit) = apply (T.ELit lit, litType lit) + +-- Γ ∋ (x : A) Γ ⊢ rec(x) +-- ------------- Var --------------------- VarRec +-- Γ ⊢ x ↓ A ⊣ Γ Γ ⊢ x ↓ ά ⊣ Γ,(x : ά) +infer (EVar x) = do + a <- ifM (gets $ (x==) . currentBind) varRec var + apply (T.EVar (coerce x), a) + where + var = maybeToRightM "Can't infer" =<< + liftA2 (<|>) (lookupEnv x) (lookupSig x) + varRec = do + alpha <- TEVar <$> fresh + insertEnv (EnvVar x alpha) + pure alpha + +infer (EInj kappa) = do + t <- maybeToRightM ("Unknown constructor: " ++ show kappa) + =<< lookupInj kappa + apply (T.EInj $ coerce kappa, t) + +-- Γ ⊢ A Γ ⊢ e ↑ A ⊣ Δ +-- --------------------- Anno +-- Γ ⊢ (e : A) ↓ A ⊣ Δ +infer (EAnn e a) = do + _ <- gets $ (`wellFormed` a) . env + (e', _) <- check e a + apply (e', a) + +-- Γ ⊢ e₁ ↓ A ⊣ Θ Γ ⊢ [Θ]A • ⇓ C ⊣ Δ +-- ----------------------------------- →E +-- Γ ⊢ e₁ e₂ ↓ C ⊣ Δ +infer (EApp e1 e2) = do + e1'@(_, a) <- infer e1 + (e2', c) <- applyInfer a e2 + apply (T.EApp e1' e2', c) + +-- Γ,ά,έ,(x:ά) ⊢ e ↑ έ ⊣ Δ,(x:ά),Θ +-- ------------------------------- →I +-- Γ ⊢ λx.e ↓ ά → έ ⊣ Δ +infer (EAbs name e) = do + alpha <- fresh + epsilon <- fresh + insertEnv $ EnvTEVar alpha + insertEnv $ EnvTEVar epsilon + let env_var = EnvVar name (TEVar alpha) + insertEnv env_var + e' <- check e $ TEVar epsilon + dropTrailing env_var + apply (T.EAbs (coerce name) e', on TFun TEVar alpha epsilon) + +-- Γ ⊢ rhs ↓ A ⊣ Θ Θ,(x:A) ⊢ e ↑ C ⊣ Δ,(x:A),Θ +-- -------------------------------------------- LetI +-- Γ ⊢ let x = rhs in e ↑ C ⊣ Δ +infer (ELet (Bind x vars rhs) e) = do + (rhs', a) <- infer $ foldr EAbs rhs vars + let env_var = EnvVar x a + insertEnv env_var + e'@(_, c) <- infer e + (env_l, _) <- gets (splitOn env_var . env) + putEnv env_l + apply (T.ELet (T.Bind (coerce x, a) [] (rhs', a)) e', c) + +-- Γ ⊢ e₁ ↑ Int ⊣ Θ Θ ⊢ e₂ ↑ Int +-- --------------------------- +I +-- Γ ⊢ e₁ + e₂ ↓ Int ⊣ Δ +infer (EAdd e1 e2) = do + e1' <- check e1 int + e2' <- check e2 int + apply (T.EAdd e1' e2', int) + + --FIXME +-- Γ ⊢ e ↑ A ⊣ Θ Θ ⊢ Π ∷ [Θ]A ↑ C ⊣ Δ +-- ------------------------------------ Case +-- Γ ⊢ case e of Π ↓ C ⊣ Δ +infer (ECase scrut pi) = do + (scrut', a) <- infer scrut + case pi of + [] -> apply (T.ECase (scrut', a) [], a) + (Branch _ e):_ -> do + (_, b)<- infer e + (pi', b') <- foldlM go ([], b) pi + apply (T.ECase (scrut', a) pi', b') + where + go (pi, b) (Branch p e) = do + p' <- checkPattern p =<< apply a + e'@(_, b') <- infer e + subtype b' b + apply (T.Branch p' e' : pi, b') + +-- | Γ ⊢ A • e ⇓ C ⊣ Δ +-- Under input context Γ , applying a function of type A to e infers type C, with output context ∆ +-- Instantiate existential type variables until there is an arrow type. +applyInfer :: Type -> Exp -> Tc (T.ExpT' Type, Type) + +-- Γ,ά ⊢ [ά/α]A • e ⇓ C ⊣ Δ +-- ------------------------ ∀App +-- Γ ⊢ ∀α.A • e ⇓ C ⊣ Δ +applyInfer (TAll alpha a) e = do + alpha' <- fresh + insertEnv $ EnvTEVar alpha' + applyInfer (substitute alpha alpha' a) e + +-- Γ[ά₂,ά₁,(ά=ά₁→ά₂)] ⊢ e ↑ ά₁ ⊣ Δ +-- ------------------------------- άApp +-- Γ[ά] ⊢ ά • e ⇓ ά₂ ⊣ Δ +applyInfer (TEVar alpha) e = do + alpha1 <- fresh + alpha2 <- fresh + (env_l, env_r) <- gets (splitOn (EnvTEVar alpha) . env) + putEnv $ (env_l + :|> EnvTEVar alpha2 + :|> EnvTEVar alpha1 + :|> EnvTEVarSolved alpha (on TFun TEVar alpha1 alpha2) + ) <> env_r + e' <- check e $ TEVar alpha1 + apply (e', TEVar alpha2) + +-- Γ ⊢ e ↑ A ⊣ Δ +-- --------------------- →App +-- Γ ⊢ A → C • e ⇓ C ⊣ Δ +applyInfer (TFun a c) e = do + exp' <- check e a + apply (exp', c) + +applyInfer a e = throwError ("Cannot apply type " ++ show a ++ " with expression " ++ show e) + +--------------------------------------------------------------------------- +-- * Subtyping rules +--------------------------------------------------------------------------- + +-- | Γ ⊢ A <: B ⊣ Δ +-- Under input context Γ, type A is a subtype of B, with output context ∆ +subtype :: Type -> Type -> Tc () +subtype (TLit lit1) (TLit lit2) | lit1 == lit2 = pure () + +-- -------------------- <:Var +-- Γ[α] ⊢ α <: α ⊣ Γ[α] +subtype (TVar alpha) (TVar alpha') | alpha == alpha' = pure () + +-- -------------------- <:Exvar +-- Γ[ά] ⊢ ά <: ά ⊣ Γ[ά] +subtype (TEVar alpha) (TEVar alpha') | alpha == alpha' = pure () + +-- Γ ⊢ B₁ <: A₁ ⊣ Θ Θ ⊢ [Θ]A₂ <: [Θ]B₂ ⊣ Δ +-- ----------------------------------------- <:→ +-- Γ ⊢ A₁ → A₂ <: B₁ → B₂ ⊣ Δ +subtype (TFun a1 a2) (TFun b1 b2) = do + subtype b1 a1 + a2' <- apply a2 + b2' <- apply b2 + subtype a2' b2' + +-- Γ, α ⊢ A <: B ⊣ Δ,α,Θ +-- --------------------- <:∀R +-- Γ ⊢ A <: ∀α. B ⊣ Δ +subtype a (TAll alpha b) = do + let env_tvar = EnvTVar alpha + insertEnv env_tvar + subtype a b + dropTrailing env_tvar + +-- Γ,▶ ά,ά ⊢ [ά/α]A <: B ⊣ Δ,▶ ά,Θ +-- ------------------------------- <:∀L +-- Γ ⊢ ∀α.A <: B ⊣ Δ +subtype (TAll alpha a) b = do + alpha' <- fresh + let env_marker = EnvMark alpha' + insertEnv env_marker + insertEnv $ EnvTEVar alpha' + let a' = substitute alpha alpha' a + subtype a' b + dropTrailing env_marker + +-- ά ∉ FV(A) Γ[ά] ⊢ ά :=< A ⊣ Δ +-- ------------------------------ <:instantiateL +-- Γ[ά] ⊢ ά <: A ⊣ Δ +subtype (TEVar alpha) a | notElem alpha $ frees a = instantiateL alpha a + +-- ά ∉ FV(A) Γ[ά] ⊢ A =:< ά ⊣ Δ +-- ------------------------------ <:instantiateR +-- Γ[ά] ⊢ A <: ά ⊣ Δ +subtype a (TEVar alpha) | notElem alpha $ frees a = instantiateR a alpha + + +subtype (TData name1 typs1) (TData name2 typs2) + + -- D₁ = D₂ + -- ---------------- + -- Γ ⊢ D₁ () <: D₂ () + | name1 == name2 + , [] <- typs1 + , [] <- typs2 + = pure () + + -- Γ ⊢ ά₁ <: έ₁ ⊣ Θ₁ + -- ... + -- D₁ = D₂ Θₙ₋₁ ⊢ [Θₙ₋₁]άₙ <: [Θₙ₋₁]έₙ ⊣ Δ + -- ------------------------------------------- + -- Γ ⊢ D (ά₁ ‥ άₙ) <: D (έ₁ ‥ έₙ) ⊣ Δ + | name1 == name2 + , t1:t1s <- typs1 + , t2:t2s <- typs2 + = do + subtype t1 t2 + zipWithM_ go t1s t2s + where + go t1' t2' = do + t1'' <- apply t1' + t2'' <- apply t2' + subtype t1'' t2'' + +subtype (TIdent t1) (TIdent t2) | t1 == t2 = pure () + +subtype t1 t2 = throwError $ unwords ["Types", show t1, "and", show t2, "doesn't match!"] + +--------------------------------------------------------------------------- +-- * Instantiation rules +--------------------------------------------------------------------------- + +-- | Γ ⊢ ά :=< A ⊣ Δ +-- Under input context Γ, instantiate ά such that ά <: A, with output context ∆ +instantiateL :: TEVar -> Type -> Tc () +instantiateL alpha a = gets env >>= \env -> go env alpha a + where + go env alpha tau + | isMono tau + , (env_l, env_r) <- splitOn (EnvTEVar alpha) env + , Right _ <- wellFormed env_l tau + = putEnv $ (env_l :|> EnvTEVarSolved alpha tau) <> env_r + + -- Γ ⊢ τ + -- ----------------------------- InstLSolve + -- Γ,ά,Γ' ⊢ ά :=< τ ⊣ Γ,(ά=τ),Γ' + go env alpha tau + | isMono tau + , (env_l, env_r) <- splitOn (EnvTEVar alpha) env + , Right _ <- wellFormed env_l tau + = putEnv $ (env_l :|> EnvTEVarSolved alpha tau) <> env_r + + -- ----------------------------- InstLReach + -- Γ[ά][έ] ⊢ ά :=< έ ⊣ Γ[ά][έ=ά] + go env alpha (TEVar epsilon) = do + let (env_l, env_r) = splitOn (EnvTEVar epsilon) env + putEnv $ (env_l :|> EnvTEVarSolved epsilon (TEVar alpha)) <> env_r + + -- Γ[ά₂ά₁,(ά=ά₁→ά₂)] ⊢ A₁ =:< ά₁ ⊣ Θ Θ ⊢ ά₂ :=< [Θ]A₂ ⊣ Δ + -- ------------------------------------------------------- InstLArr + -- Γ[ά] ⊢ ά :=< A₁ → A₂ ⊣ Δ + go _ alpha (TFun a1 a2) = do + alpha1 <- fresh + alpha2 <- fresh + insertEnv $ EnvTEVar alpha2 + insertEnv $ EnvTEVar alpha1 + insertEnv $ EnvTEVarSolved alpha (on TFun TEVar alpha1 alpha2) + instantiateR a1 alpha1 + instantiateL alpha2 =<< apply a2 + + -- Γ[ά],ε ⊢ ά :=< E ⊣ Δ,ε,Δ' + -- ------------------------- InstLAIIR + -- Γ[ά] ⊢ ά :=< ∀ε.Ε ⊣ Δ + go env tevar (TAll tvar t) = do + instantiateL tevar t + let (env_l, _) = splitOn (EnvTVar tvar) env + putEnv env_l + + go _ alpha a = error $ "Trying to instantiateL: " ++ ppT (TEVar alpha) + ++ " <: " ++ ppT a + +-- | Γ ⊢ A =:< ά ⊣ Δ +-- Under input context Γ, instantiate ά such that A <: ά, with output context ∆ +instantiateR :: Type -> TEVar -> Tc () +instantiateR a alpha = gets env >>= \env -> go env a alpha + where + -- Γ ⊢ τ + -- ----------------------------- InstRSolve + -- Γ,ά,Γ' ⊢ τ =:< ά ⊣ Γ,(ά=τ),Γ' + go env tau alpha + | isMono tau + , (env_l, env_r) <- splitOn (EnvTEVar alpha) env + , Right _ <- wellFormed env_l tau + = putEnv $ (env_l :|> EnvTEVarSolved alpha tau) <> env_r + + -- + -- ----------------------------- InstRReach + -- Γ[ά][έ] ⊢ έ =:< ά ⊣ Γ[ά][έ=ά] + go env (TEVar epsilon) alpha = do + let (env_l, env_r) = splitOn (EnvTEVar epsilon) env + putEnv $ (env_l :|> EnvTEVarSolved epsilon (TEVar alpha)) <> env_r + + -- Γ[ά₂ά₁,(ά=ά₁→ά₂)] ⊢ A₁ :=< ά₁ ⊣ Θ Θ ⊢ ά₂ =:< [Θ]A₂ ⊣ Δ + -- ------------------------------------------------------- InstRArr + -- Γ[ά] ⊢ A₁ → A₂ =:< ά ⊣ Δ + go _ (TFun a1 a2) alpha = do + alpha1 <- fresh + alpha2 <- fresh + insertEnv $ EnvTEVar alpha2 + insertEnv $ EnvTEVar alpha1 + insertEnv $ EnvTEVarSolved alpha (on TFun TEVar alpha1 alpha2) + instantiateL alpha1 a1 + a2' <- apply a2 + instantiateR a2' alpha2 + + -- Γ[ά],▶έ,ε ⊢ [έ/ε]E =:< ά ⊣ Δ,▶έ,Δ' + -- ---------------------------------- InstRAIIL + -- Γ[ά] ⊢ ∀ε.Ε =:< ά ⊣ Δ + go env (TAll epsilon e) alpha = do + epsilon' <- fresh + insertEnv $ EnvMark epsilon' + insertEnv $ EnvTVar epsilon + instantiateR (substitute epsilon epsilon' e) alpha + let (env_l, _) = splitOn (EnvMark epsilon') env + putEnv env_l + + go _ a alpha = throwError $ "Trying to instantiateR: " ++ ppT a ++ " <: " + ++ ppT (TEVar alpha) + +--------------------------------------------------------------------------- +-- * Auxiliary +--------------------------------------------------------------------------- + +frees :: Type -> [TEVar] +frees = \case + TLit _ -> [] + TVar _ -> [] + TEVar tevar -> [tevar] + TFun t1 t2 -> on (++) frees t1 t2 + TAll _ t -> frees t + TData _ typs -> concatMap frees typs + +-- | [ά/α]A +substitute :: TVar -- α + -> TEVar -- ά + -> Type -- A + -> Type -- [ά/α]A +substitute tvar tevar typ = case typ of + TLit _ -> typ + TVar tvar' | tvar' == tvar -> TEVar tevar + | otherwise -> typ + TEVar _ -> typ + TFun t1 t2 -> on TFun substitute' t1 t2 + TAll tvar' t -> TAll tvar' (substitute' t) + TData name typs -> TData name $ map substitute' typs + where + substitute' = substitute tvar tevar + +-- | Γ,x,Γ' → (Γ, Γ') +splitOn :: EnvElem -> Env -> (Env, Env) +splitOn x env = second (S.drop 1) $ S.breakl (==x) env + +-- | Drop frontmost elements until and including element @x@. +dropTrailing :: EnvElem -> Tc () +dropTrailing x = modifyEnv $ S.takeWhileL (/= x) + + +findSolved :: TEVar -> Env -> Maybe Type +findSolved _ Empty = Nothing +findSolved tevar (xs :|> x) = case x of + EnvTEVarSolved tevar' t | tevar == tevar' -> Just t + _ -> findSolved tevar xs + +-- | Γ ⊢ A +-- Under context Γ, type A is well-formed +wellFormed :: Env -> Type -> Err () +wellFormed env = \case + TLit _ -> pure () + + -- -------- UvarWF + -- Γ[α] ⊢ α + TVar tvar -> unless (EnvTVar tvar `elem` env) $ + throwError ("Unbound type variable: " ++ show tvar) + -- Γ ⊢ A Γ ⊢ B + -- ------------- ArrowWF + -- Γ ⊢ A → B + TFun t1 t2 -> do { wellFormed env t1; wellFormed env t2 } + + -- Γ,α ⊢ A + -- -------- ForallWF + -- Γ ⊢ ∀α.A + TAll tvar t -> wellFormed (env :|> EnvTVar tvar) t + + TEVar tevar + -- ---------- EvarWF + -- Γ[ά] ⊢ ά + | EnvTEVar tevar `elem` env -> pure () + + -- ---------- SolvedEvarWF + -- Γ[ά=τ] ⊢ ά + | Just _ <- findSolved tevar env -> pure () + | otherwise -> throwError ("Can't find type: " ++ show tevar) + + TData _ typs -> mapM_ (wellFormed env) typs + +isMono :: Type -> Bool +isMono = \case + TAll{} -> False + TFun t1 t2 -> on (&&) isMono t1 t2 + TData _ typs -> all isMono typs + TVar _ -> True + TEVar _ -> True + TLit _ -> True + +fresh :: Tc TEVar +fresh = do + tevar <- gets (MkTEVar . LIdent . show . next_tevar) + modify $ \cxt -> cxt { next_tevar = succ cxt.next_tevar } + pure tevar + + +isComplete :: Env -> Bool +isComplete = isNothing . S.findIndexL unSolvedTEVar + where + unSolvedTEVar = \case + EnvTEVar _ -> True + _ -> False + +getDataId :: Type -> Type +getDataId typ = case typ of + TAll _ t -> getDataId t + TFun _ t -> getDataId t + TData {} -> typ + +toTVar :: Type -> Err TVar +toTVar = \case + TVar tvar -> pure tvar + _ -> throwError "Not a type variable" + +insertEnv :: EnvElem -> Tc () +insertEnv x = modifyEnv (:|> x) + +lookupSig :: LIdent -> Tc (Maybe Type) +lookupSig x = gets (Map.lookup x . sig) + +insertSig :: LIdent -> Type -> Tc () +insertSig name t = modify $ \cxt -> cxt { sig = Map.insert name t cxt.sig } + +lookupEnv :: LIdent -> Tc (Maybe Type) +lookupEnv x = gets (findId . env) + where + findId Empty = Nothing + findId (ys :|> y) = case y of + EnvVar x' t | x==x' -> Just t + _ -> findId ys + +lookupInj :: UIdent -> Tc (Maybe Type) +lookupInj x = gets (Map.lookup x . data_injs) + +putEnv :: Env -> Tc () +putEnv = modifyEnv . const + +modifyEnv :: (Env -> Env) -> Tc () +modifyEnv f = + modify $ \cxt -> {- trace (ppEnv (f cxt.env)) -} cxt { env = f cxt.env } + +pattern DBind' name vars exp = DBind (Bind name vars exp) +pattern DSig' name typ = DSig (Sig name typ) + +--------------------------------------------------------------------------- +-- * Apply +--------------------------------------------------------------------------- + +class Apply a where + apply :: a -> Tc a + +instance Apply Type where apply = applyType +instance Apply (T.Exp' Type) where apply = applyExp +instance Apply (T.Branch' Type) where apply = applyBranch +instance Apply (T.Pattern' Type) where apply = applyPattern +instance Apply a => Apply [a] where apply = mapM apply +instance (Apply a, Apply b) => Apply (a, b) where apply = applyPair +instance Apply T.Ident where apply = pure + +applyType :: Type -> Tc Type +applyType t = gets $ (`applyType'` t) . env + +-- | [Γ]A. Applies context to type until fully applied. +applyType' :: Env -> Type -> Type +applyType' cxt typ | typ == typ' = typ' + | otherwise = applyType' cxt typ' + where + typ' = case typ of + TLit _ -> typ + TData name typs -> TData name $ map (applyType' cxt) typs + -- [Γ]α = α + TVar _ -> typ + -- [Γ[ά=τ]]ά = [Γ[ά=τ]]τ + -- [Γ[ά]]ά = [Γ[ά]]ά + TEVar tevar -> fromMaybe typ $ findSolved tevar cxt + -- [Γ](A → B) = [Γ]A → [Γ]B + TFun t1 t2 -> on TFun (applyType' cxt) t1 t2 + -- [Γ](∀α. A) = (∀α. [Γ]A) + TAll tvar t -> TAll tvar $ applyType' cxt t + TIdent t -> typ + +applyExp :: T.Exp' Type -> Tc (T.Exp' Type) +applyExp exp = case exp of + T.ELet (T.Bind id vars rhs) exp -> do + id <- apply id + vars' <- mapM apply vars + rhs' <- apply rhs + exp' <- apply exp + pure $ T.ELet (T.Bind id vars' rhs') exp' + T.EApp e1 e2 -> liftA2 T.EApp (apply e1) (apply e2) + T.EAdd e1 e2 -> liftA2 T.EAdd (apply e1) (apply e2) + T.EAbs name e -> T.EAbs name <$> apply e + T.ECase e branches -> liftA2 T.ECase (apply e) + (mapM apply branches) + _ -> pure exp + +applyBranch :: T.Branch' Type -> Tc (T.Branch' Type) +applyBranch (T.Branch (p, t) e) = do + pt <- liftA2 (,) (apply p) (apply t) + e' <- apply e + pure $ T.Branch pt e' + +applyPattern :: T.Pattern' Type -> Tc (T.Pattern' Type) +applyPattern = \case + T.PVar id -> T.PVar <$> apply id + T.PInj name ps -> T.PInj name <$> apply ps + p -> pure p + +applyPair :: (Apply a, Apply b) => (a, b) -> Tc (a, b) +applyPair (x, y) = liftA2 (,) (apply x) (apply y) + +--------------------------------------------------------------------------- +-- * Debug +--------------------------------------------------------------------------- + +traceEnv s = do + env <- gets env + trace (s ++ " " ++ ppEnv env) pure () + +traceD s x = trace (s ++ " " ++ show x) pure () + +traceT s x = trace (s ++ " : " ++ ppT x) pure () + +traceTs s xs = trace (s ++ " [ " ++ intercalate ", " (map ppT xs) ++ " ]") pure () + +ppT = \case + TLit (UIdent s) -> s + TVar (MkTVar (LIdent s)) -> "tvar_" ++ s + TFun t1 t2 -> ppT t1 ++ "->" ++ ppT t2 + TAll (MkTVar (LIdent s)) t -> "forall " ++ s ++ ". " ++ ppT t + TEVar (MkTEVar (LIdent s)) -> "tevar_" ++ s + TData (UIdent name) typs -> name ++ " (" ++ unwords (map ppT typs) + ++ " )" + TIdent (UIdent name) -> name + +ppEnvElem = \case + EnvVar (LIdent s) t -> s ++ ":" ++ ppT t + EnvTVar (MkTVar (LIdent s)) -> "tvar_" ++ s + EnvTEVar (MkTEVar (LIdent s)) -> "tevar_" ++ s + EnvTEVarSolved (MkTEVar (LIdent s)) t -> "tevar_" ++ s ++ "=" ++ ppT t + EnvMark (MkTEVar (LIdent s)) -> "▶" ++ "tevar_" ++ s + +ppEnv = \case + Empty -> "·" + (xs :|> x) -> ppEnv xs ++ " (" ++ ppEnvElem x ++ ")" diff --git a/src/TypeChecker/TypeCheckerHm.hs b/src/TypeChecker/TypeCheckerHm.hs new file mode 100644 index 0000000..f4ec70a --- /dev/null +++ b/src/TypeChecker/TypeCheckerHm.hs @@ -0,0 +1,945 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedRecordDot #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QualifiedDo #-} +{-# OPTIONS_GHC -Wno-incomplete-patterns #-} + +-- | A module for type checking and inference using algorithm W, Hindley-Milner +module TypeChecker.TypeCheckerHm where + +import Auxiliary (int, litType, maybeToRightM, unzip4) +import Auxiliary qualified as Aux +import Control.Monad.Except +import Control.Monad.Identity (Identity, runIdentity) +import Control.Monad.Reader +import Control.Monad.State +import Control.Monad.Writer +import Data.Coerce (coerce) +import Data.Function (on) +import Data.List (foldl', nub, sortOn) +import Data.List.Extra (unsnoc) +import Data.Map (Map) +import Data.Map qualified as M +import Data.Maybe (fromJust) +import Data.Set (Set) +import Data.Set qualified as S +import Debug.Trace (trace, traceShow) +import Grammar.Abs +import Grammar.Print (printTree) +import TypeChecker.TypeCheckerIr qualified as T + +{- +TODO +Prettifying the types of generated variables does only need to be done when +presenting the types to the user, i.e, when the user has made a mistake. +For succesfully typed programs the types only need to match. + +-} + +-- | Type check a program +typecheck :: Program -> Either String (T.Program' Type, [Warning]) +typecheck = onLeft msg . run . checkPrg + where + onLeft :: (Error -> String) -> Either Error a -> Either String a + onLeft f (Left x) = Left $ f x + onLeft _ (Right x) = Right x + +checkPrg :: Program -> Infer (T.Program' Type) +checkPrg (Program bs) = do + preRun bs + -- sgs <- gets sigs + bs <- map snd . sortOn fst <$> bindCount bs + bs <- checkDef bs + -- return . prettify sgs . T.Program $ bs + return . T.Program $ bs + +-- | Send the map of user declared signatures to not rename stuff the user defined +prettify :: Map T.Ident (Maybe Type) -> T.Program' Type -> T.Program' Type +prettify s (T.Program defs) = T.Program $ map (go s) defs + where + go :: Map T.Ident (Maybe Type) -> T.Def' Type -> T.Def' Type + go _ (T.DData d) = T.DData d + go m b@(T.DBind (T.Bind (name, t) args (e, et))) + | Just (Just _) <- M.lookup name m = b + | otherwise = + let fvs = nub $ freeOrdered t + m = M.fromList $ zip fvs letters + in T.DBind $ T.Bind (name, replace m t) args (fmap (replace m) e, replace m et) + +replace :: Map T.Ident T.Ident -> Type -> Type +replace m def@(TVar (MkTVar (LIdent a))) = case M.lookup (coerce a) m of + Just t -> TVar . MkTVar . LIdent $ coerce t + Nothing -> def +replace m (TFun t1 t2) = (TFun `on` replace m) t1 t2 +replace m (TData name ts) = TData name (map (replace m) ts) +replace m def@(TAll (MkTVar forall_) t) = case M.lookup (coerce forall_) m of + Just found -> TAll (MkTVar $ coerce found) (replace m t) + Nothing -> def +replace _ t = t + +bindCount :: [Def] -> Infer [(Int, Def)] +bindCount [] = return [] +bindCount (x : xs) = do + (o, d) <- go x + b <- bindCount xs + return $ (o, d) : b + where + go :: Def -> Infer (Int, Def) + go b@(DBind (Bind _ _ e)) = do + db <- gets declaredBinds + let n = runIdentity $ evalStateT (countBinds db e) mempty + return (n, b) + go (DSig sig) = pure (0, DSig sig) + go (DData data_) = pure (-1, DData data_) + + countBinds :: Set T.Ident -> Exp -> StateT (Set T.Ident) Identity Int + countBinds declared = \case + EVar i -> do + found <- get + if coerce i `S.member` declared && not (coerce i `S.member` found) + then put (S.insert (coerce i) found) >> return 1 + else return 0 + ELet _ e -> countBinds declared e + EApp e1 e2 -> (+) <$> countBinds declared e1 <*> countBinds declared e2 + EAdd e1 e2 -> (+) <$> countBinds declared e1 <*> countBinds declared e2 + EAbs _ e -> countBinds declared e + ECase e1 brnchs -> do + let f (Branch _ e2) = countBinds declared e2 + (+) . sum <$> mapM f brnchs <*> countBinds declared e1 + _ -> return 0 + +preRun :: [Def] -> Infer () +preRun [] = return () +preRun (x : xs) = case x of + DSig (Sig n t) -> do + collect (collectTVars t) + s <- gets (M.keys . sigs) + duplicateDecl n s $ Aux.do + "Multiple signatures of function" + quote $ printTree n + insertSig (coerce n) (Just t) >> preRun xs + DBind (Bind n _ e) -> do + s <- gets (S.toList . declaredBinds) + duplicateDecl n s $ Aux.do + "Multiple declarations of function" + quote $ printTree n + collect (collectTVars e) + insertBind $ coerce n + s <- gets sigs + case M.lookup (coerce n) s of + Nothing -> insertSig (coerce n) Nothing >> preRun xs + Just _ -> preRun xs + DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs + where + -- Check if function body / signature has been declared already + duplicateDecl n env msg = when (coerce n `elem` env) (uncatchableErr msg) + +checkDef :: [Def] -> Infer [T.Def' Type] +checkDef [] = return [] +checkDef (x : xs) = case x of + (DBind b) -> do + b' <- checkBind b + xs' <- checkDef xs + return $ T.DBind b' : xs' + (DData d) -> do + xs' <- checkDef xs + return $ T.DData (coerceData d) : xs' + (DSig _) -> checkDef xs + where + coerceData (Data t injs) = + T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs + +freeOrdered :: Type -> [T.Ident] +freeOrdered (TVar (MkTVar a)) = return (coerce a) +freeOrdered (TAll (MkTVar bound) t) = return (coerce bound) ++ freeOrdered t +freeOrdered (TFun a b) = freeOrdered a ++ freeOrdered b +freeOrdered (TData _ a) = concatMap freeOrdered a +freeOrdered _ = mempty + +-- Much cleaner implementation, unfortunately one minor bug +-- checkBind :: Bind -> Infer (T.Bind' Type) +-- checkBind (Bind name args expr) = do +-- fr <- fresh +-- let lambda = makeLambda expr (reverse (coerce args)) +-- withBinding (coerce name) fr $ do +-- (sub, (e, infSig)) <- algoW lambda +-- env <- asks vars +-- let genInfSig = generalize (apply sub env) infSig +-- maybeSig <- gets (join . M.lookup (coerce name) . sigs) +-- case maybeSig of +-- Just typSig -> do +-- unless +-- (genInfSig <<= typSig) +-- ( throwError $ +-- Error +-- ( Aux.do +-- "Inferred type" +-- quote $ printTree infSig +-- "doesn't match given type" +-- quote $ printTree typSig +-- ) +-- False +-- ) +-- return $ T.Bind (coerce name, typSig) [] (apply sub e, typSig) +-- _ -> do +-- insertSig (coerce name) (Just genInfSig) +-- return $ T.Bind (coerce name, genInfSig) [] (apply sub e, genInfSig) + +checkBind :: Bind -> Infer (T.Bind' Type) +checkBind (Bind name args e) = do + let lambda = makeLambda e (reverse (coerce args)) + (e, infSig) <- inferExp lambda + s <- gets sigs + case M.lookup (coerce name) s of + Just (Just typSig) -> do + env <- asks vars + trace ("ENV IN CHECKBIND: " ++ show env) pure () + let genInfSig = generalize mempty infSig + sub <- genInfSig `unify` typSig + unless + (genInfSig <<= typSig) + ( throwError $ + Error + ( Aux.do + "Inferred type" + quote $ printTree infSig + "doesn't match given type" + quote $ printTree typSig + ) + False + ) + -- Applying sub to typSig will worsen error messages. + -- Unfortunately I do not know a better solution at the moment. + return $ T.Bind (coerce name, apply sub typSig) [] (apply sub e, typSig) + _ -> do + insertSig (coerce name) (Just infSig) + return (T.Bind (coerce name, infSig) [] (e, infSig)) + +checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m () +checkData err@(Data typ injs) = do + (name, tvars) <- go (skipForalls typ) + dataErr (mapM_ (\i -> checkInj i name tvars) injs) err + where + go = \case + TData name typs + | Right tvars' <- mapM toTVar typs -> + pure (name, tvars') + _ -> + uncatchableErr $ + unwords ["Bad data type definition: ", printTree typ] + +checkInj :: (MonadError Error m, MonadState Env m, Monad m) => Inj -> UIdent -> [TVar] -> m () +checkInj (Inj c inj_typ) name tvars + | TData name' typs <- returnType inj_typ + , Right tvars' <- mapM toTVar typs + , name' == name + , tvars' == tvars = do + exist <- existInj (coerce c) + case exist of + Just t -> uncatchableErr $ Aux.do + "Constructor" + quote $ coerce name + "with type" + quote $ printTree t + "already exist" + Nothing -> insertInj (coerce c) inj_typ + | otherwise = + uncatchableErr $ + unwords + [ "Bad type constructor: " + , show name + , "\nExpected: " + , printTree . TData name $ map TVar tvars + , "\nActual: " + , printTree $ returnType inj_typ + ] + +toTVar :: Type -> Either Error TVar +toTVar = \case + TVar tvar -> pure tvar + _ -> uncatchableErr "Not a type variable" + +returnType :: Type -> Type +returnType (TFun _ t2) = returnType t2 +returnType a = a + +inferExp :: Exp -> Infer (T.ExpT' Type) +inferExp e = do + (s, (e', t)) <- algoW e + let subbed = apply s t + return (e', subbed) + +class CollectTVars a where + collectTVars :: a -> Set T.Ident + +instance CollectTVars Exp where + collectTVars (EAnn e t) = collectTVars t `S.union` collectTVars e + collectTVars _ = S.empty + +instance CollectTVars Type where + collectTVars (TVar (MkTVar i)) = S.singleton (coerce i) + collectTVars (TAll _ t) = collectTVars t + collectTVars (TFun t1 t2) = (S.union `on` collectTVars) t1 t2 + collectTVars (TData _ ts) = + foldl' (\acc x -> acc `S.union` collectTVars x) S.empty ts + collectTVars _ = S.empty + +collect :: Set T.Ident -> Infer () +collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st}) + +algoW :: Exp -> Infer (Subst, T.ExpT' Type) +algoW = \case + err@(EAnn e t) -> do + (sub0, (e', t')) <- exprErr (algoW e) err + sub1 <- unify t t' + sub2 <- unify t' t + unless + (apply sub1 t <<= apply sub2 t') + ( uncatchableErr $ Aux.do + "Annotated type" + quote $ printTree t + "does not match inferred type" + quote $ printTree t' + ) + let comp = sub2 `compose` sub1 `compose` sub0 + return (comp, (apply comp e', t)) + + -- \| ------------------ + -- \| Γ ⊢ i : Int, ∅ + + ELit lit -> return (nullSubst, (T.ELit lit, litType lit)) + -- \| x : σ ∈ Γ   τ = inst(σ) + -- \| ---------------------- + -- \| Γ ⊢ x : τ, ∅ + EVar (LIdent i) -> do + var <- asks vars + case M.lookup (coerce i) var of + Just t -> + inst t >>= \x -> + return (nullSubst, (T.EVar $ coerce i, x)) + Nothing -> do + sig <- gets sigs + case M.lookup (coerce i) sig of + Just (Just t) -> do + t <- freshen t + return (nullSubst, (T.EVar $ coerce i, t)) + Just Nothing -> do + fr <- fresh + return (nullSubst, (T.EVar $ coerce i, fr)) + Nothing -> + uncatchableErr $ + "Unbound variable: " + <> printTree i + EInj i -> do + constr <- gets injections + case M.lookup (coerce i) constr of + Just t -> do + t <- freshen t + return (nullSubst, (T.EInj $ coerce i, t)) + Nothing -> + uncatchableErr $ Aux.do + "Constructor:" + quote $ printTree i + "is not defined" + + -- \| τ = newvar Γ, x : τ ⊢ e : τ', S + -- \| --------------------------------- + -- \| Γ ⊢ w λx. e : Sτ → τ', S + + err@(EAbs name e) -> do + fr <- fresh + withBinding (coerce name) fr $ do + (s1, (e', t')) <- exprErr (algoW e) err + let varType = apply s1 fr + let newArr = TFun varType t' + return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr)) + + -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁ + -- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int) + -- \| ------------------------------------------ + -- \| Γ ⊢ e₀ + e₁ : Int, S₃S₂S₁S₀ + -- This might be wrong + + err@(EAdd e0 e1) -> do + (s1, (e0', t0)) <- algoW e0 + (s2, (e1', t1)) <- algoW e1 + s3 <- exprErr (unify t0 int) err + s4 <- exprErr (unify t1 int) err + let comp = s4 `compose` s3 `compose` s2 `compose` s1 + return + ( comp + , apply comp (T.EAdd (e0', t0) (e1', t1), int) + ) + + -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1 + -- \| τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ') + -- \| -------------------------------------- + -- \| Γ ⊢ e₀ e₁ : S₂τ', S₂S₁S₀ + + EApp e0 e1 -> do + fr <- fresh + (s0, (e0', t0)) <- algoW e0 + applySt s0 $ do + (s1, (e1', t1)) <- algoW e1 + s2 <- unify (apply s1 t0) (TFun t1 fr) + let t = apply s2 fr + let comp = s2 `compose` s1 `compose` s0 + return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t)) + + -- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁ + -- \| ---------------------------------------------- + -- \| Γ ⊢ let x = e₀ in e₁ : τ', S₁S₀ + + -- The bar over S₀ and Γ means "generalize" + + ELet (Bind name args e) e1 -> do + fr <- fresh + withBinding (coerce name) fr $ do + (s1, e@(_, t0)) <- algoW (makeLambda e (coerce args)) + env <- asks vars + let t' = generalize (apply s1 env) t0 + withBinding (coerce name) t' $ do + (s2, (e1', t2)) <- algoW e1 + let comp = s2 `compose` s1 + return + ( comp + , apply + comp + (T.ELet (T.Bind (coerce name, t0) [] e) (e1', t2), t2) + ) + ECase caseExpr injs -> do + (sub, (e', t)) <- algoW caseExpr + (subst, injs, ret_t) <- checkCase t injs + let comp = subst `compose` sub + return (comp, apply comp (T.ECase (e', t) injs, ret_t)) + +checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type) +checkCase _ [] = do + fr <- fresh + return (nullSubst, [], fr) +checkCase expT brnchs = do + (subs, branchTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs + let sub0 = composeAll subs + (sub1, _) <- + foldM + ( \(sub, acc) x -> + (\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc + ) + (nullSubst, expT) + branchTs + (sub2, returns_type) <- + foldM + ( \(sub, acc) x -> + (\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc + ) + (nullSubst, head returns) + (tail returns) + let comp = sub2 `compose` sub1 `compose` sub0 + return (comp, apply comp injs, apply comp returns_type) + +inferBranch :: Branch -> Infer (Subst, Type, T.Branch' Type, Type) +inferBranch err@(Branch pat expr) = do + pat@(_, branchT) <- inferPattern pat + (sub, newExp@(_, exprT)) <- catchError (withPattern pat (algoW expr)) (\x -> throwError Error{msg = x.msg <> " in pattern '" <> printTree err <> "'", catchable = False}) + return + ( sub + , apply sub branchT + , T.Branch (apply sub pat) (apply sub newExp) + , apply sub exprT + ) + +inferPattern :: Pattern -> Infer (T.Pattern' Type, Type) +inferPattern = \case + PLit lit -> let lt = litType lit in return (T.PLit lit, lt) + PCatch -> (T.PCatch,) <$> fresh + PVar x -> do + fr <- fresh + let pvar = T.PVar (coerce x) + return (pvar, fr) + PEnum p -> do + t <- gets (M.lookup (coerce p) . injections) + t <- + maybeToRightM + ( Error + ( Aux.do + "Constructor:" + quote $ printTree p + "does not exist" + ) + True + ) + t + unless + (typeLength t == 1) + ( catchableErr $ Aux.do + "The constructor" + quote $ printTree p + " should have " + show (typeLength t - 1) + " arguments but has been given 0" + ) + let (TData _data _ts) = t -- nasty nasty + frs <- mapM (const fresh) _ts + return (T.PEnum $ coerce p, TData _data frs) + PInj constr patterns -> do + t <- gets (M.lookup (coerce constr) . injections) + t <- + maybeToRightM + ( Error + ( Aux.do + "Constructor:" + quote $ printTree constr + "does not exist" + ) + True + ) + t + let numArgs = typeLength t - 1 + let (vs, ret) = fromJust (unsnoc $ flattenType t) + patterns <- mapM inferPattern patterns + unless + (length patterns == numArgs) + ( catchableErr $ Aux.do + "The constructor" + quote $ printTree constr + " should have " + show numArgs + " arguments but has been given " + show (length patterns) + ) + sub <- composeAll <$> zipWithM unify vs (map snd patterns) + return + ( T.PInj (coerce constr) (apply sub patterns) + , apply sub ret + ) + +-- | Unify two types producing a new substitution +unify :: Type -> Type -> Infer Subst +unify t0 t1 = + let fvs = S.toList $ free t0 `S.union` free t1 + m = M.fromList $ zip fvs letters + in case (t0, t1) of + (TFun a b, TFun c d) -> do + s1 <- unify a c + s2 <- unify (apply s1 b) (apply s1 d) + return $ s2 `compose` s1 + (TVar (MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t + (t@(TData _ _), TVar (MkTVar b)) -> return $ M.singleton (coerce b) t + (TVar (MkTVar a), t) -> occurs (coerce a) t + (t, TVar (MkTVar b)) -> occurs (coerce b) t + -- Forall unification should change + (TAll _ t, b) -> unify t b + (a, TAll _ t) -> unify a t + (TLit a, TLit b) -> + if a == b + then return M.empty + else catchableErr $ + Aux.do + "Can not unify" + quote $ printTree (TLit a) + "with" + quote $ printTree (TLit b) + (TData name t, TData name' t') -> + if name == name' && length t == length t' + then do + xs <- zipWithM unify t t' + return $ foldr compose nullSubst xs + else catchableErr $ + Aux.do + "Type constructor:" + printTree name + quote $ printTree $ map (replace m) t + "does not match with:" + printTree name' + quote $ printTree $ map (replace m) t' + (TEVar a, TEVar b) -> + if a == b + then return M.empty + else catchableErr $ + Aux.do + "Can not unify" + quote $ printTree (TEVar a) + "with" + quote $ printTree (TEVar b) + (a, b) -> do + catchableErr $ + Aux.do + "Can not unify" + quote $ printTree $ replace m a + "with" + quote $ printTree $ replace m b + +{- | Check if a type is contained in another type. +I.E. { a = a -> b } is an unsolvable constraint since there is no substitution +where these are equal +-} +occurs :: T.Ident -> Type -> Infer Subst +occurs i t@(TEVar _) = return (M.singleton i t) +occurs i t@(TVar _) = return (M.singleton i t) +occurs i t = + let fvs = S.toList $ free t + m = M.fromList $ zip fvs letters + in if S.member i (free t) + then + catchableErr + ( Aux.do + "Occurs check failed, can't unify" + quote $ printTree $ replace m (TVar $ MkTVar (coerce i)) + "with" + quote $ printTree $ replace m t + ) + else return $ M.singleton i t + +{- | Generalize a type over all free variables in the substitution set + Used for let bindings to allow expression that do not type check in + equivalent lambda expressions: + Type checks: let f = \x. x in (f True, f 'a') + Does not type check: (\f. (f True, f 'a')) (\x. x) +-} +generalize :: Map T.Ident Type -> Type -> Type +generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t) + where + go :: [T.Ident] -> Type -> Type + go [] t = t + go (x : xs) t = TAll (MkTVar (coerce x)) (go xs t) + removeForalls :: Type -> Type + removeForalls (TAll _ t) = removeForalls t + removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2) + removeForalls t = t + +{- | Instantiate a polymorphic type. The free type variables are substituted +with fresh ones. +-} +inst :: Type -> Infer Type +inst = \case + TAll (MkTVar bound) t -> do + fr <- fresh + let s = M.singleton (coerce bound) fr + apply s <$> inst t + TFun t1 t2 -> TFun <$> inst t1 <*> inst t2 + rest -> return rest + +-- | Generate a new fresh variable +fresh :: Infer Type +fresh = do + n <- gets count + modify (\st -> st{count = succ (count st)}) + return $ TVar $ MkTVar $ LIdent $ show n + +-- Is the left a subtype of the right +(<<=) :: Type -> Type -> Bool +(<<=) a b = case (a, b) of + (TVar _, _) -> True + (TFun a b, TFun c d) -> a <<= c && b <<= d + (TAll tvar1 t1, TAll tvar2 t2) -> ungo [tvar1, tvar2] t1 t2 + (TAll tvar t1, t2) -> ungo [tvar] t1 t2 + (t1, TAll tvar t2) -> ungo [tvar] t1 t2 + (TData n1 ts1, TData n2 ts2) -> + n1 == n2 + && length ts1 == length ts2 + && and (zipWith (<<=) ts1 ts2) + (t1, t2) -> t1 == t2 + where + ungo :: [TVar] -> Type -> Type -> Bool + ungo tvars t1 t2 = case run (go tvars t1 t2) of + Right (b, _) -> b + _ -> False + -- TODO: Fix the following + -- Maybe locally using the Infer monad can cause trouble. + -- Since the fresh count starts from zero + go :: [TVar] -> Type -> Type -> Infer Bool + go tvars t1 t2 = do + fr <- fresh + let sub = M.fromList [(coerce x, fr) | (MkTVar x) <- tvars] + return (apply sub t1 <<= apply sub t2) + +skipForalls :: Type -> Type +skipForalls = \case + TAll _ t -> skipForalls t + t -> t + +freshen :: Type -> Infer Type +freshen t = do + let frees = S.toList (free t) + xs <- mapM (const fresh) frees + let sub = M.fromList $ zip frees xs + return $ apply sub t + +{- + +a = TVar $ MkTVar "a" +single = TData "single" [a] +arr = a `TFun` single + +-} + +-- | A class for substitutions +class SubstType t where + -- | Apply a substitution to t + apply :: Subst -> t -> t + +class FreeVars t where + -- | Get all free variables from t + free :: t -> Set T.Ident + +instance FreeVars (T.Bind' Type) where + free (T.Bind (_, t) _ _) = free t + +instance FreeVars Type where + free :: Type -> Set T.Ident + free (TVar (MkTVar a)) = S.singleton (coerce a) + free (TAll (MkTVar bound) t) = + S.singleton (coerce bound) `S.intersection` free t + free (TLit _) = mempty + free (TFun a b) = free a `S.union` free b + free (TData _ a) = free a + free (TEVar _) = S.empty + +instance FreeVars a => FreeVars [a] where + free = let f acc x = acc `S.union` free x in foldl' f S.empty + +instance SubstType Type where + apply :: Subst -> Type -> Type + apply sub t = do + case t of + TLit _ -> t + TVar (MkTVar a) -> case M.lookup (coerce a) sub of + Nothing -> TVar (MkTVar $ coerce a) + Just t -> t + TAll (MkTVar i) t -> case M.lookup (coerce i) sub of + Nothing -> TAll (MkTVar i) (apply sub t) + Just _ -> apply sub t + TFun a b -> TFun (apply sub a) (apply sub b) + TData name a -> TData name (apply sub a) + TEVar (MkTEVar _) -> t + +instance FreeVars (Map T.Ident Type) where + free :: Map T.Ident Type -> Set T.Ident + free = free . M.elems + +instance SubstType (Map T.Ident Type) where + apply :: Subst -> Map T.Ident Type -> Map T.Ident Type + apply = M.map . apply + +instance SubstType (Map T.Ident (Maybe Type)) where + apply s = M.map (fmap $ apply s) + +instance SubstType (T.ExpT' Type) where + apply s (e, t) = (apply s e, apply s t) + +instance SubstType (T.Exp' Type) where + apply s = \case + T.EVar i -> T.EVar i + T.ELit lit -> T.ELit lit + T.ELet (T.Bind (ident, t1) args e1) e2 -> + T.ELet + (T.Bind (ident, apply s t1) args (apply s e1)) + (apply s e2) + T.EApp e1 e2 -> T.EApp (apply s e1) (apply s e2) + T.EAdd e1 e2 -> T.EAdd (apply s e1) (apply s e2) + T.EAbs ident e -> T.EAbs ident (apply s e) + T.ECase e brnch -> T.ECase (apply s e) (apply s brnch) + T.EInj i -> T.EInj i + +instance SubstType (T.Def' Type) where + apply s = \case + T.DBind (T.Bind name args e) -> + T.DBind $ T.Bind (apply s name) (apply s args) (apply s e) + d -> d + +instance SubstType (T.Branch' Type) where + apply s (T.Branch (i, t) e) = T.Branch (apply s i, apply s t) (apply s e) + +instance SubstType (T.Pattern' Type) where + apply s = \case + T.PVar iden -> T.PVar iden + T.PLit lit -> T.PLit lit + T.PInj i ps -> T.PInj i $ apply s ps + T.PCatch -> T.PCatch + T.PEnum i -> T.PEnum i + +instance SubstType (T.Pattern' Type, Type) where + apply s (p, t) = (apply s p, apply s t) + +instance SubstType a => SubstType [a] where + apply s = map (apply s) + +instance SubstType (T.Id' Type) where + apply s (name, t) = (name, apply s t) + +-- | Represents the empty substition set +nullSubst :: Subst +nullSubst = mempty + +-- | Compose two substitution sets +compose :: Subst -> Subst -> Subst +compose m1 m2 = M.map (apply m1) m2 `M.union` m1 + +-- | Compose a list of substitution sets into one +composeAll :: [Subst] -> Subst +composeAll = foldl' compose nullSubst + +{- | Convert a function with arguments to its pointfree version +> makeLambda (add x y = x + y) = add = \x. \y. x + y +-} +makeLambda :: Exp -> [T.Ident] -> Exp +makeLambda = foldl (flip (EAbs . coerce)) + +-- | Run the monadic action with an additional binding +withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> Type -> m a -> m a +withBinding i p = local (\st -> st{vars = M.insert i p (vars st)}) + +-- | Run the monadic action with several additional bindings +withBindings :: (Monad m, MonadReader Ctx m) => [(T.Ident, Type)] -> m a -> m a +withBindings xs = + local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs}) + +-- | Run the monadic action with a pattern +withPattern :: (Monad m, MonadReader Ctx m) => (T.Pattern' Type, Type) -> m a -> m a +withPattern (p, t) ma = case p of + T.PVar x -> withBinding x t ma + T.PInj _ ps -> foldl' (flip withPattern) ma ps + T.PLit _ -> ma + T.PCatch -> ma + T.PEnum _ -> ma + +-- | Insert a function signature into the environment +insertSig :: T.Ident -> Maybe Type -> Infer () +insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)}) + +insertBind :: T.Ident -> Infer () +insertBind i = modify (\st -> st{declaredBinds = S.insert i st.declaredBinds}) + +-- | Insert a constructor into the start with its type +insertInj :: (Monad m, MonadState Env m) => T.Ident -> Type -> m () +insertInj i t = + modify (\st -> st{injections = M.insert i t (injections st)}) + +applySt :: Subst -> Infer a -> Infer a +applySt s = local (\st -> st{vars = apply s st.vars}) + +{- | Check if an injection (constructor of data type) +with an equivalent name has been declared already +-} +existInj :: (Monad m, MonadState Env m) => T.Ident -> m (Maybe Type) +existInj n = gets (M.lookup n . injections) + +flattenType :: Type -> [Type] +flattenType (TFun a b) = flattenType a <> flattenType b +flattenType a = [a] + +typeLength :: Type -> Int +typeLength (TFun _ b) = 1 + typeLength b +typeLength _ = 1 + +{- | Catch an error if possible and add the given +expression as addition to the error message +-} +exprErr :: (Monad m, MonadError Error m) => m a -> Exp -> m a +exprErr ma exp = + catchError + ma + ( \err -> + if err.catchable + then + throwError + ( err + { msg = + err.msg + <> " in expression: \n" + <> printTree exp + , catchable = False + } + ) + else throwError err + ) + +bindErr :: (Monad m, MonadError Error m) => m a -> Bind -> m a +bindErr ma bind = + catchError + ma + ( \err -> + if err.catchable + then + throwError + ( err + { msg = + err.msg + <> " in function: \n" + <> printTree bind + , catchable = False + } + ) + else throwError err + ) + +{- | Catch an error if possible and add the given +data as addition to the error message +-} +dataErr :: (MonadError Error m, Monad m) => m a -> Data -> m a +dataErr ma d = + catchError + ma + ( \err -> + if err.catchable + then + throwError + ( err + { msg = + err.msg + <> " in data: \n" + <> printTree d + } + ) + else throwError (err{catchable = False}) + ) + +initCtx = Ctx mempty +initEnv = Env 0 'a' mempty mempty mempty mempty + +run :: Infer a -> Either Error (a, [Warning]) +run = run' initEnv initCtx + +run' :: Env -> Ctx -> Infer a -> Either Error (a, [Warning]) +run' e c = + runIdentity + . runExceptT + . runWriterT + . flip runReaderT c + . flip evalStateT e + . runInfer + +newtype Ctx = Ctx {vars :: Map T.Ident Type} + deriving (Show) + +data Env = Env + { count :: Int + , nextChar :: Char + , sigs :: Map T.Ident (Maybe Type) + , takenTypeVars :: Set T.Ident + , injections :: Map T.Ident Type + , declaredBinds :: Set T.Ident + } + deriving (Show) + +data Error = Error {msg :: String, catchable :: Bool} + deriving (Show) +type Subst = Map T.Ident Type + +newtype Warning = NonExhaustive String + deriving (Show) + +newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (WriterT [Warning] (ExceptT Error Identity))) a} + deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env) + +catchableErr :: MonadError Error m => String -> m a +catchableErr msg = throwError $ Error msg True + +uncatchableErr :: MonadError Error m => String -> m a +uncatchableErr msg = throwError $ Error msg False + +quote :: String -> String +quote s = "'" ++ s ++ "'" + +letters :: [T.Ident] +letters = map T.Ident $ [1 ..] >>= flip replicateM ['a' .. 'z'] diff --git a/src/TypeChecker/TypeCheckerIr.hs b/src/TypeChecker/TypeCheckerIr.hs new file mode 100644 index 0000000..a956ff3 --- /dev/null +++ b/src/TypeChecker/TypeCheckerIr.hs @@ -0,0 +1,196 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PatternSynonyms #-} + +module TypeChecker.TypeCheckerIr ( + module Grammar.Abs, + module TypeChecker.TypeCheckerIr, +) where + +import Data.String (IsString) +import Grammar.Abs (Lit (..)) +import Grammar.Print +import Prelude +import qualified Prelude as C (Eq, Ord, Read, Show) + +newtype Program' t = Program [Def' t] + deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) + +data Def' t + = DBind (Bind' t) + | DData (Data' t) + deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) + +data Type + = TLit Ident + | TVar TVar + | TData Ident [Type] + | TFun Type Type + deriving (Eq, Ord, Show, Read) + +data Data' t = Data t [Inj' t] + deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) + +data Inj' t = Inj Ident t + deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) + +newtype Ident = Ident String + deriving (C.Eq, C.Ord, C.Show, C.Read, IsString) + +data Pattern' t + = PVar Ident + | PLit Lit + | PCatch + | PEnum Ident + | PInj Ident [(Pattern' t, t)] + deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) + +data Exp' t + = EVar Ident + | EInj Ident + | ELit Lit + | ELet (Bind' t) (ExpT' t) + | EApp (ExpT' t) (ExpT' t) + | EAdd (ExpT' t) (ExpT' t) + | EAbs Ident (ExpT' t) + | ECase (ExpT' t) [Branch' t] + deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) + +newtype TVar = MkTVar Ident + deriving (C.Eq, C.Ord, C.Show, C.Read) + +type Id' t = (Ident, t) +type ExpT' t = (Exp' t, t) + +data Bind' t = Bind (Id' t) [Id' t] (ExpT' t) + deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) + +data Branch' t = Branch (Pattern' t, t) (ExpT' t) + deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) + +instance Print Ident where + prt _ (Ident s) = doc $ showString s + +instance Print t => Print (Program' t) where + prt i (Program sc) = prt i sc + +instance Print t => Print (Bind' t) where + prt i (Bind sig parms rhs) = concatD + [ prtSig sig + , prt i parms + , doc $ showString "=" + , prt i rhs + ] + +prtSig :: Print t => Id' t -> Doc +prtSig (name, t) = + concatD + [ prt 0 name + , doc $ showString ":" + , prt 0 t + ] + +instance Print t => Print (ExpT' t) where + prt i (e, t) = + concatD + [ doc $ showString "(" + , prt i e + , doc $ showString ":" + , prt 0 t + , doc $ showString ")" + ] + +instance Print t => Print [Bind' t] where + prt _ [] = concatD [] + prt i [x] = concatD [prt i x] + prt i (x : xs) = concatD [prt i x, doc (showString ";"), prt i xs] + +instance Print t => Print (Id' t) where + prt i (name, t) = + concatD + [ doc $ showString "(" + , prt i name + , doc $ showString "," + , prt i t + , doc $ showString ")" + ] + +instance Print t => Print (Exp' t) where + prt i = \case + EVar lident -> prPrec i 3 (concatD [prt 0 lident]) + EInj uident -> prPrec i 3 (concatD [prt 0 uident]) + ELit lit -> prPrec i 3 (concatD [prt 0 lit]) + EApp exp1 exp2 -> prPrec i 2 (concatD [prt 2 exp1, prt 3 exp2]) + EAdd exp1 exp2 -> prPrec i 1 (concatD [prt 1 exp1, doc (showString "+"), prt 2 exp2]) + ELet bind exp -> prPrec i 0 (concatD [doc (showString "let"), prt 0 bind, doc (showString "in"), prt 0 exp]) + EAbs lident exp -> prPrec i 0 (concatD [doc (showString "\\"), prt 0 lident, doc (showString "."), prt 0 exp]) + ECase exp branchs -> prPrec i 0 (concatD [doc (showString "case"), prt 0 exp, doc (showString "of"), doc (showString "{"), prt 0 branchs, doc (showString "}")]) + +instance Print t => Print (Branch' t) where + prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp]) + +instance Print t => Print [Branch' t] where + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] + +instance Print t => Print (Def' t) where + prt i = \case + DBind bind -> prPrec i 0 (concatD [prt 0 bind]) + DData data_ -> prPrec i 0 (concatD [prt 0 data_]) + +instance Print t => Print (Data' t) where + prt i = \case + Data type_ injs -> prPrec i 0 (concatD [doc (showString "data"), prt 0 type_, doc (showString "where"), doc (showString "{"), prt 0 injs, doc (showString "}")]) + +instance Print t => Print (Inj' t) where + prt i = \case + Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_]) + +instance Print t => Print [Inj' t] where + prt _ [] = concatD [] + prt i [x] = prt i x + prt i (x : xs) = prPrec i 0 $ concatD [prt i x, doc $ showString "\n ", prt i xs] + +instance Print t => Print (Pattern' t, t) where + prt i (p, t) = prPrec i 1 (concatD [prt i p, prt i t]) + +instance Print t => Print (Pattern' t) where + prt i = \case + PVar name -> prPrec i 1 (concatD [prt 0 name]) + PLit lit -> prPrec i 1 (concatD [prt 0 lit]) + PCatch -> prPrec i 1 (concatD [doc (showString "_")]) + PEnum name -> prPrec i 1 (concatD [prt 0 name]) + PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns]) + +instance Print t => Print [Def' t] where + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] + +instance Print [Type] where + prt _ [] = concatD [] + prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs] + +instance Print Type where + prt i = \case + TLit uident -> prPrec i 1 (concatD [prt 0 uident]) + TVar tvar -> prPrec i 1 (concatD [prt 0 tvar]) + TData uident types -> prPrec i 1 (concatD [prt 0 uident, doc (showString "("), prt 0 types, doc (showString ")")]) + TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2]) + +instance Print TVar where + prt i (MkTVar ident) = prt i ident + +type Program = Program' Type +type Def = Def' Type +type Data = Data' Type +type Bind = Bind' Type +type Branch = Branch' Type +type Pattern = Pattern' Type +type Inj = Inj' Type +type Exp = Exp' Type +type ExpT = ExpT' Type +type Id = Id' Type +pattern TVar' s = TVar (MkTVar s) +pattern DBind' id vars expt = DBind (Bind id vars expt) +pattern DData' typ injs = DData (Data typ injs) diff --git a/src/TypeCheckerIr.hs b/src/TypeCheckerIr.hs deleted file mode 100644 index f6e3ec6..0000000 --- a/src/TypeCheckerIr.hs +++ /dev/null @@ -1,100 +0,0 @@ -{-# LANGUAGE LambdaCase #-} - -module TypeCheckerIr - ( module Grammar.Abs - , module TypeCheckerIr - ) where - -import Grammar.Abs (Ident (..), Type (..)) -import Grammar.Print -import Prelude -import qualified Prelude as C (Eq, Ord, Read, Show) - -newtype Program = Program [Bind] - deriving (C.Eq, C.Ord, C.Show, C.Read) - -data Exp - = EId Id - | EInt Integer - | ELet Bind Exp - | EApp Type Exp Exp - | EAdd Type Exp Exp - | EAbs Type Id Exp - deriving (C.Eq, C.Ord, C.Show, C.Read) - -type Id = (Ident, Type) - -data Bind = Bind Id [Id] Exp - deriving (C.Eq, C.Ord, C.Show, C.Read) - -instance Print Program where - prt i (Program sc) = prPrec i 0 $ prt 0 sc - -instance Print Bind where - prt i (Bind name@(n, _) parms rhs) = prPrec i 0 $ concatD - [ prtId 0 name - , doc $ showString ";" - , prt 0 n - , prtIdPs 0 parms - , doc $ showString "=" - , prt 0 rhs - ] - -instance Print [Bind] where - prt _ [] = concatD [] - prt _ [x] = concatD [prt 0 x] - prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] - -prtIdPs :: Int -> [Id] -> Doc -prtIdPs i = prPrec i 0 . concatD . map (prtIdP i) - -prtId :: Int -> Id -> Doc -prtId i (name, t) = prPrec i 0 $ concatD - [ prt 0 name - , doc $ showString ":" - , prt 0 t - ] - -prtIdP :: Int -> Id -> Doc -prtIdP i (name, t) = prPrec i 0 $ concatD - [ doc $ showString "(" - , prt 0 name - , doc $ showString ":" - , prt 0 t - , doc $ showString ")" - ] - - -instance Print Exp where - prt i = \case - EId n -> prPrec i 3 $ concatD [prtIdP 0 n] - EInt i1 -> prPrec i 3 $ concatD [prt 0 i1] - ELet bs e -> prPrec i 3 $ concatD - [ doc $ showString "let" - , prt 0 bs - , doc $ showString "in" - , prt 0 e - ] - EApp t e1 e2 -> prPrec i 2 $ concatD - [ doc $ showString "@" - , prt 0 t - , prt 2 e1 - , prt 3 e2 - ] - EAdd t e1 e2 -> prPrec i 1 $ concatD - [ doc $ showString "@" - , prt 0 t - , prt 1 e1 - , doc $ showString "+" - , prt 2 e2 - ] - EAbs t n e -> prPrec i 0 $ concatD - [ doc $ showString "@" - , prt 0 t - , doc $ showString "\\" - , prtIdP 0 n - , doc $ showString "." - , prt 0 e - ] - - diff --git a/test_program.crf b/test_program.crf new file mode 100644 index 0000000..6e528dc --- /dev/null +++ b/test_program.crf @@ -0,0 +1,30 @@ +-- Peano naturals +data Nat where + Zero : Nat + Succ : Nat -> Nat + +toInt : Nat -> Int +toInt a = case a of + Succ n => 1 + toInt n + Zero => 0 + +fromInt a = case a of + 0 => Zero + n => Succ (fromInt (a - 1)) + +-- Peano arithmetic -- + +-- Peano addition +add : Nat -> Nat -> Nat +add left right = case left of + Zero => right + Succ n => Succ (add n right) + +-- Peano multiplication +mul : Nat -> Nat -> Nat +mul left right = case right of + Zero => Zero + Succ n => add left (mul left n) + +-- Returns 10_000 +main = toInt (mul (fromInt 100) (fromInt 100)) diff --git a/tests/DoStrings.hs b/tests/DoStrings.hs new file mode 100644 index 0000000..73580f8 --- /dev/null +++ b/tests/DoStrings.hs @@ -0,0 +1,4 @@ +module DoStrings where + +(>>) str1 str2 = str1 ++ "\n" ++ str2 +(>>=) str1 f = f str1 diff --git a/tests/Main.hs b/tests/Main.hs new file mode 100644 index 0000000..da4acf7 --- /dev/null +++ b/tests/Main.hs @@ -0,0 +1,16 @@ +module Main where + +import Test.Hspec +import TestAnnForall (testAnnForall) +import TestRenamer (testRenamer) +import TestReportForall (testReportForall) +import TestTypeCheckerBidir (testTypeCheckerBidir) +import TestTypeCheckerHm (testTypeCheckerHm) + +main = hspec $ do + testReportForall + testAnnForall + testRenamer + testTypeCheckerBidir + testTypeCheckerHm + diff --git a/tests/TestAnnForall.hs b/tests/TestAnnForall.hs new file mode 100644 index 0000000..9280f33 --- /dev/null +++ b/tests/TestAnnForall.hs @@ -0,0 +1,128 @@ +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE QualifiedDo #-} +{-# HLINT ignore "Use camelCase" #-} +{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} + +module TestAnnForall (testAnnForall, test) where + +import AnnForall (annotateForall) +import Control.Monad ((<=<)) +import Desugar.Desugar (desugar) +import DoStrings qualified as D +import Grammar.ErrM (Err, pattern Bad, pattern Ok) +import Grammar.Layout (resolveLayout) +import Grammar.Par (myLexer, pProgram) +import Grammar.Print (printTree) +import Renamer.Renamer (rename) +import ReportForall (reportForall) +import Test.Hspec ( + describe, + hspec, + shouldBe, + shouldNotSatisfy, + shouldSatisfy, + shouldThrow, + specify, + ) +import TypeChecker.ReportTEVar (reportTEVar) +import TypeChecker.TypeChecker (TypeChecker (Bi, Hm)) +import TypeChecker.TypeCheckerBidir (typecheck) +import TypeChecker.TypeCheckerIr qualified as T + +test = hspec testAnnForall + +testAnnForall = describe "Test AnnForall" $ do + ann_data1 + ann_data2 + ann_bad_data1 + ann_bad_data2 + ann_bad_data3 + ann_sig1 + ann_sig2 + ann_bind + +ann_data1 = + specify "Annotate data type" $ + D.do + "data Either a b where" + " Left : a -> Either a b" + " Right : b -> Either a b" + `shouldBePrg` D.do + "data forall a. forall b. Either a b where" + " Left : a -> Either a b" + " Right : b -> Either a b" + +ann_data2 = + specify "Annotate constructor with additional type variable" $ + D.do + "data forall a. forall b. Either a b where" + " Left : c -> a -> Either a b" + " Right : b -> Either a b" + `shouldBePrg` D.do + "data forall a. forall b. Either a b where" + " Left : forall c. c -> a -> Either a b" + " Right : b -> Either a b" + +ann_bad_data1 = + specify "Bad data type variables" $ + D.do + "data Either Int b where" + " Left : a -> Either a b" + " Right : b -> Either a b" + `shouldBeErr` "Misformed data declaration: Non type variable argument" + +ann_bad_data2 = + specify "Bad data identifer" $ + D.do + "data Int -> Either a b where" + " Left : a -> Either a b" + " Right : b -> Either a b" + `shouldBeErr` "Misformed data declaration" + +ann_bad_data3 = + specify "Constructor forall duplicate" $ + D.do + "data Int -> Either a b where" + " Left : forall a. a -> Either a b" + " Right : b -> Either a b" + `shouldBeErr` "Misformed data declaration" + +ann_sig1 = + specify "Annotate signature" $ + "f : a -> b -> (forall a. a -> a) -> a" + `shouldBePrg` "f : forall a. forall b. a -> b -> (forall a. a -> a) -> a" + +ann_sig2 = + specify "Annotate signature 2" $ + D.do + "const : forall a. forall b. a -> b -> a" + "const x y = x" + "main = const 'a' 65" + `shouldBePrg` D.do + "const : forall a. forall b. a -> b -> a" + "const x y = x" + "main = const 'a' 65" + +ann_bind = + specify "Annotate bind" $ + "f = (\\x.\\y. x : a -> b -> a) 4" + `shouldBePrg` "f = (\\x.\\y. x : forall a. forall b. a -> b -> a) 4" + +shouldBeErr s err = run s `shouldBe` Bad err + +shouldBePrg s1 s2 + | Ok p2 <- run' s2 = run s1 `shouldBe` Ok p2 + | otherwise = error ("Faulty expectation \n" ++ show (run' s2)) + +run = annotateForall <=< run' +run' s = do + p <- run'' s + reportForall Bi p + pure p +run'' = fmap desugar . pProgram . resolveLayout True . myLexer + +runPrint = (putStrLn . either show printTree . run) $ + D.do + "data forall a. forall b. Either a b where" + " Left : c -> a -> Either a b" + " Right : b -> Either a b" diff --git a/tests/TestLambdaLifter.hs b/tests/TestLambdaLifter.hs new file mode 100644 index 0000000..d10e7ee --- /dev/null +++ b/tests/TestLambdaLifter.hs @@ -0,0 +1,106 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE QualifiedDo #-} +{-# HLINT ignore "Use camelCase" #-} +{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} + +module TestLambdaLifter where + +import Test.Hspec + +import AnnForall (annotateForall) +import Control.Monad ((<=<)) +import Control.Monad.Error.Class (liftEither) +import Control.Monad.Extra (eitherM) +import Desugar.Desugar (desugar) +import Grammar.ErrM (Err, pattern Bad, pattern Ok) +import Grammar.Layout (resolveLayout) +import Grammar.Par (myLexer, pProgram) +import Grammar.Print (printTree) +import LambdaLifter +import Renamer.Renamer (rename) +import ReportForall (reportForall) +import TypeChecker.RemoveForall (removeForall) +import TypeChecker.ReportTEVar (reportTEVar) +import TypeChecker.TypeChecker (TypeChecker (Bi)) +import TypeChecker.TypeCheckerBidir (typecheck) +import TypeChecker.TypeCheckerIr + +test = hspec testLambdaLifter + +testLambdaLifter = describe "Test Lambda Lifter" $ do + undefined + +-- frees_exp1 + +-- frees_exp1 = specify "Free variables 1" $ +-- freeVarsExp [] (EAbs "x" (EVar "x", TVar' "a"), TVar' "a") +-- `shouldBe` answer +-- where +-- answer = Ann { frees = [] +-- , term = (AAbs (Ident "x") (Ann { frees = [Ident "x"] +-- , term = (AVar (Ident "x"),TVar (MkTVar (Ident "a"))) +-- } +-- ),TVar (MkTVar (Ident "a"))) +-- } + +abs_1 = undefined + where + input = + unlines + [ "data List a where" + , " Nil : List a" + , " Cons : a -> List a -> List a" + , "map : (a -> b) -> List a -> List b" + , "add : Int -> Int -> Int" + , "f : List Int" + , "f = (\\x.\\ys. map (\\y. add y x) ys) 4 (Cons 1 (Cons 2 Nil))" + ] + +runFreeVars = either putStrLn print (runFree s2) +runAbstract = either putStrLn (putStrLn . printTree) (runAbs s2) +runCollect = either putStrLn (putStrLn . printTree) (run s2) + +s1 = + unlines + [ "add : Int -> Int -> Int" + , "f : Int -> Int -> Int" + , "f x y = add x y" + , "f = \\x. (\\y. add x y)" + ] + +s2 = + unlines + [ "data List a where" + , " Nil : List (a)" + , " Cons : a -> List a -> List a" + , "add : Int -> Int -> Int" + , "map : (a -> b) -> List a -> List b" + , -- , "map f xs = case xs of" + -- , " Nil => Nil" + -- , " Cons x xs => Cons (f x) (map f xs)" + + "f : List Int" + , "f = (\\x.\\ys. map (\\y. add y x) ys) 4 (Cons 1 (Cons 2 Nil))" + ] + +s3 = "main = (\\plussq. (\\f. f (f 0)) (plussq 3)) (\\x. \\y. y + x + x)" + +run = fmap collectScs . runAbs + +runAbs = fmap abstract . runFree + +runFree s = do + Program ds <- run' s + pure $ freeVars [b | DBind b <- ds] + +run' = + fmap removeForall + . reportTEVar + <=< typecheck + <=< run'' + +run'' s = do + p <- (fmap desugar . pProgram . resolveLayout True . myLexer) s + reportForall Bi p + (rename <=< annotateForall) p diff --git a/tests/TestRenamer.hs b/tests/TestRenamer.hs new file mode 100644 index 0000000..dc71d38 --- /dev/null +++ b/tests/TestRenamer.hs @@ -0,0 +1,114 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE QualifiedDo #-} +{-# HLINT ignore "Use camelCase" #-} +{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} + +module TestRenamer (testRenamer, test, runPrint) where + +import AnnForall (annotateForall) +import Control.Exception ( + ErrorCall (ErrorCall), + Exception (displayException), + SomeException (SomeException), + evaluate, + try, + ) +import Control.Exception.Extra (try_) +import Control.Monad (unless, (<=<)) +import Control.Monad.Except (throwError) +import Data.Either.Extra (fromEither) +import Desugar.Desugar (desugar) +import DoStrings qualified as D +import GHC.Generics (Generic, Generic1) +import Grammar.Abs (Program (Program)) +import Grammar.ErrM (Err, pattern Bad, pattern Ok) +import Grammar.Layout (resolveLayout) +import Grammar.Par (myLexer, pProgram) +import Grammar.Print (printTree) +import Renamer.Renamer (rename) +import System.IO.Error (catchIOError, tryIOError) +import Test.Hspec ( + anyErrorCall, + anyException, + describe, + hspec, + shouldBe, + shouldNotSatisfy, + shouldReturn, + shouldSatisfy, + shouldThrow, + specify, + ) +import TypeChecker.ReportTEVar (reportTEVar) +import TypeChecker.TypeCheckerBidir (typecheck) +import TypeChecker.TypeCheckerIr qualified as T + +-- FIXME tests sucks + +test = hspec testRenamer + +testRenamer = describe "Test Renamer" $ do + rn_data1 + rn_data2 + rn_sig + rn_bind1 + rn_bind2 + +rn_data1 = specify "Rename data type" . shouldSatisfyOk $ + D.do + "data forall a. forall b. Either a b where" + " Left : a -> Either a b" + " Right : b -> Either a b" + +rn_data2 = specify "Rename data type forall in constructor " . shouldSatisfyOk $ + D.do + "data forall a. forall b. Either a b where" + " Left : forall c. c -> a -> Either a b" + " Right : b -> Either a b" + +rn_sig = + specify "Rename signature" $ + shouldSatisfyOk + "f : forall a. forall b. a -> b -> (forall a. a -> a) -> a" + +rn_bind1 = + specify "Rename simple bind" $ + shouldSatisfyOk + "f x = (\\y. let y2 = y + 1 in y2) (x + 1)" + +rn_bind2 = specify "Rename bind with case" . shouldSatisfyOk $ + D.do + "data forall a. List a where" + " Nil : List a " + " Cons : a -> List a -> List a" + + "length : forall a. List a -> Int" + "length list = case list of" + " Nil => 0" + " Cons x Nil => 1" + " Cons x (Cons y ys) => 2 + length ys" + +runPrint = putStrLn . either show printTree . run $ + D.do + "data forall a. List a where" + " Nil : List a " + " Cons : a -> List a -> List a" + + "length : forall a. List a -> Int" + "length list = case list of" + " Nil => 0" + " Cons x Nil => 1" + " Cons x (Cons y ys) => 2 + length ys" + +shouldSatisfyOk s = run s `shouldSatisfy` ok + +ok = \case + Ok !_ -> True + Bad !_ -> False + +shouldBeErr s err = run s `shouldBe` Bad err + +run = rename <=< run' +run' = fmap desugar . pProgram . resolveLayout True . myLexer diff --git a/tests/TestReportForall.hs b/tests/TestReportForall.hs new file mode 100644 index 0000000..2d3371d --- /dev/null +++ b/tests/TestReportForall.hs @@ -0,0 +1,54 @@ +{-# LANGUAGE PatternSynonyms #-} +{-# HLINT ignore "Use camelCase" #-} +{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} + +module TestReportForall (testReportForall, test) where + +import AnnForall (annotateForall) +import Control.Monad ((<=<)) +import Desugar.Desugar (desugar) +import DoStrings qualified as D +import Grammar.ErrM (Err, pattern Bad, pattern Ok) +import Grammar.Layout (resolveLayout) +import Grammar.Par (myLexer, pProgram) +import Grammar.Print (printTree) +import Renamer.Renamer (rename) +import ReportForall (reportForall) +import Test.Hspec ( + describe, + hspec, + shouldBe, + shouldNotSatisfy, + shouldSatisfy, + shouldThrow, + specify, + ) +import TypeChecker.TypeChecker (TypeChecker (Bi, Hm)) + +testReportForall = describe "Test ReportForall" $ do + rp_unused1 + rp_unused2 + rp_forall + +test = hspec testReportForall + +rp_unused1 = + specify "Unused forall 1" $ + "g : forall a. forall a. a -> (forall a. a -> a) -> a" + `shouldBeErrBi` "Unused forall" + +rp_unused2 = + specify "Unused forall 2" $ + "g : forall a. (forall a. a -> a) -> Int" + `shouldBeErrBi` "Unused forall" + +rp_forall = + specify "Rank2 forall with Hm" $ + "f : a -> b -> (forall a. a -> a) -> a" + `shouldBeErrHm` "Higher rank forall not allowed" + +shouldBeErrBi = shouldBeErr Bi +shouldBeErrHm = shouldBeErr Hm +shouldBeErr tc s err = run tc s `shouldBe` Bad err + +run tc = reportForall tc <=< fmap desugar . pProgram . resolveLayout True . myLexer diff --git a/tests/TestTypeCheckerBidir.hs b/tests/TestTypeCheckerBidir.hs new file mode 100644 index 0000000..15e0c1f --- /dev/null +++ b/tests/TestTypeCheckerBidir.hs @@ -0,0 +1,333 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} +{-# HLINT ignore "Use camelCase" #-} +{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} + +module TestTypeCheckerBidir (test, testTypeCheckerBidir) where + +import Test.Hspec + +import AnnForall (annotateForall) +import Control.Monad ((<=<)) +import Desugar.Desugar (desugar) +import Grammar.Abs (Program) +import Grammar.ErrM (Err, pattern Bad, pattern Ok) +import Grammar.Layout (resolveLayout) +import Grammar.Par (myLexer, pProgram) +import Grammar.Print (printTree) +import Renamer.Renamer (rename) +import ReportForall (reportForall) +import TypeChecker.RemoveForall (removeForall) +import TypeChecker.ReportTEVar (reportTEVar) +import TypeChecker.TypeChecker (TypeChecker (Bi)) +import TypeChecker.TypeCheckerBidir (typecheck) +import TypeChecker.TypeCheckerIr qualified as T + +test = hspec testTypeCheckerBidir + +testTypeCheckerBidir = describe "Test Bidirectional type checker" $ do + tc_id + tc_double + tc_add_lam + tc_const + tc_simple_rank2 + tc_rank2 + tc_identity + tc_pair + tc_tree + tc_mono_case + tc_pol_case + tc_infer_case + tc_rec1 + tc_rec2 + +tc_id = + specify "Basic identity function polymorphism" $ + run + [ "id : a -> a" + , "id x = x" + , "main = id 4" + ] + `shouldSatisfy` ok + +tc_double = + specify "Addition inference" $ + run + ["double x = x + x"] + `shouldSatisfy` ok + +tc_add_lam = + specify "Addition lambda inference" $ + run + ["four = (\\x. x + x) 2"] + `shouldSatisfy` ok + +tc_const = + specify "Basic polymorphism with multiple type variables" $ + run + [ "const : a -> b -> a" + , "const x y = x" + , "main = const 'a' 65" + ] + `shouldSatisfy` ok + +tc_simple_rank2 = + specify "Simple rank two polymorphism" $ + run + [ "id : a -> a" + , "id x = x" + , "f : a -> (forall b. b -> b) -> a" + , "f x g = g x" + , "main = f 4 id" + ] + `shouldSatisfy` ok + +tc_rank2 = + specify "Rank two polymorphism is ok" $ + run + [ "const : a -> b -> a" + , "const x y = x" + , "rank2 : a -> (forall c. c -> Int) -> b -> Int" + , "rank2 x f y = f x + f y" + , "main = rank2 3 (\\x. const 5 x : a -> Int) 'h'" + ] + `shouldSatisfy` ok + +tc_identity = describe "(∀b. b → b) should only accept the identity function" $ do + specify "identityᵢₙₜ is rejected" $ run (fs ++ id_int) `shouldNotSatisfy` ok + specify "identity is accepted" $ run (fs ++ id) `shouldSatisfy` ok + where + fs = + [ "f : a -> (forall b. b -> b) -> a" + , "f x g = g x" + , "id : a -> a" + , "id x = x" + , "id_int : Int -> Int" + , "id_int x = x" + ] + id = + [ "main : Int" + , "main = f 4 id" + ] + id_int = + [ "main : Int" + , "main = f 4 id_int" + ] + +tc_pair = describe "Pair. Type variables in Pair a b typechecked" $ do + specify "Wrong arguments are rejected" $ run (fs ++ wrong) `shouldNotSatisfy` ok + specify "Correct arguments are accepted" $ run (fs ++ correct) `shouldSatisfy` ok + where + fs = + [ "data Pair a b where" + , " Pair : a -> b -> Pair a b" + , "main : Pair Int Char" + ] + wrong = ["main = Pair 'a' 65"] + correct = ["main = Pair 65 'a'"] + +tc_tree = describe "Tree. Recursive data type" $ do + specify "Wrong tree is rejected" $ run (fs ++ wrong) `shouldNotSatisfy` ok + specify "Correct tree is accepted" $ run (fs ++ correct) `shouldSatisfy` ok + where + fs = + [ "data Tree a where" + , " Node : a -> Tree a -> Tree a -> Tree a" + , " Leaf : a -> Tree a" + ] + wrong = ["tree = Node 1 (Node 2 (Node 4) (Leaf 5)) (Leaf 3)"] + correct = ["tree = Node 1 (Node 2 (Leaf 4) (Leaf 5)) (Leaf 3)"] + +tc_mono_case = describe "Monomorphic pattern matching" $ do + specify "First wrong case expression rejected" $ + run wrong1 `shouldNotSatisfy` ok + specify "Second wrong case expression rejected" $ + run wrong2 `shouldNotSatisfy` ok + specify "Third wrong case expression rejected" $ + run wrong3 `shouldNotSatisfy` ok + specify "First correct case expression accepted" $ + run correct1 `shouldSatisfy` ok + specify "Second correct case expression accepted" $ + run correct2 `shouldSatisfy` ok + where + wrong1 = + [ "simple : Int -> Int" + , "simple c = case c of" + , " 'F' => 0" + , " 'T' => 1" + ] + wrong2 = + [ "simple : Char -> Int" + , "simple c = case c of" + , " 'F' => 0" + , " 1 => 1" + ] + wrong3 = + [ "simple : Char -> Int" + , "simple c = case c of" + , " 'F' => 0" + , " 'T' => '1'" + ] + correct1 = + [ "simple : Char -> Int" + , "simple c = case c of" + , " 'F' => 0" + , " 'T' => 1" + ] + correct2 = + [ "simple : Char -> Int" + , "simple c = case c of" + , " 'F' => 0" + , " _ => 1" + ] + +tc_pol_case = describe "Polymophic and recursive pattern matching" $ do + specify "First wrong case expression rejected" $ + run (fs ++ wrong1) `shouldNotSatisfy` ok + specify "Second wrong case expression rejected" $ + run (fs ++ wrong2) `shouldNotSatisfy` ok + specify "Third wrong case expression rejected" $ + run (fs ++ wrong3) `shouldNotSatisfy` ok + specify "Forth wrong case expression rejected" $ + run (fs ++ wrong4) `shouldNotSatisfy` ok + specify "First correct case expression accepted" $ + run (fs ++ correct1) `shouldSatisfy` ok + specify "Second correct case expression accepted" $ + run (fs ++ correct2) `shouldSatisfy` ok + specify "Third correct case expression accepted" $ + run (fs ++ correct3) `shouldSatisfy` ok + specify "Forth correct case expression accepted" $ + run (fs ++ correct4) `shouldSatisfy` ok + where + fs = + [ "data List a where" + , " Nil : List a" + , " Cons : a -> List a -> List a" + ] + wrong1 = + [ "length : List c -> Int" + , "length = \\list. case list of" + , " Nil => 0" + , " Cons 6 xs => 1 + length xs" + ] + wrong2 = + [ "length : List c -> Int" + , "length = \\list. case list of" + , " Cons => 0" + , " Cons x xs => 1 + length xs" + ] + wrong3 = + [ "length : List c -> Int" + , "length = \\list. case list of" + , " 0 => 0" + , " Cons x xs => 1 + length xs" + ] + wrong4 = + [ "elems : List (List c) -> Int" + , "elems = \\list. case list of" + , " Nil => 0" + , " Cons Nil Nil => 0" + , " Cons Nil xs => elems xs" + , " Cons (Cons Nil ys) xs => 1 + elems (Cons ys xs)" + ] + correct1 = + [ "length : List c -> Int" + , "length = \\list. case list of" + , " Nil => 0" + , " Cons x xs => 1 + length xs" + , " Cons x (Cons y Nil) => 2" + ] + correct2 = + [ "length : List c -> Int" + , "length = \\list. case list of" + , " Nil => 0" + , " non_empty => 1" + ] + correct3 = + [ "length : List Int -> Int" + , "length = \\list. case list of" + , " Nil => 0" + , " Cons 1 Nil => 1" + , " Cons x (Cons 2 xs) => 2 + length xs" + ] + correct4 = + [ "elems : List (List c) -> Int" + , "elems = \\list. case list of" + , " Nil => 0" + , " Cons Nil Nil => 0" + , " Cons Nil xs => elems xs" + , " Cons (Cons _ ys) xs => 1 + elems (Cons ys xs)" + ] + +tc_if = specify "Test if else case expression" $ do + run + [ "data Bool where" + , " True : Bool" + , " False : Bool" + , "ifThenElse : Bool -> a -> a -> a" + , "ifThenElse b if else = case b of" + , " True => if" + , " False => else" + ] + `shouldSatisfy` ok + +tc_infer_case = describe "Infer case expression" $ do + specify "Wrong case expression rejected" $ + run (fs ++ wrong) `shouldNotSatisfy` ok + specify "Correct case expression accepted" $ + run (fs ++ correct) `shouldSatisfy` ok + where + fs = + [ "data Bool where" + , " True : Bool" + , " False : Bool" + ] + + correct = + [ "toBool = case 0 of" + , " 0 => False" + , " _ => True" + ] + + wrong = + [ "toBool = case 0 of" + , " 0 => False" + , " _ => 1" + ] + +tc_rec1 = + specify "Infer simple recursive definition" $ + run ["test x = 1 + test (x + 1)"] `shouldSatisfy` ok + +tc_rec2 = + specify "Infer recursive definition with pattern matching" $ + run + [ "data Bool where" + , " False : Bool" + , " True : Bool" + , "test = \\x. case x of" + , " 10 => True" + , " _ => test (x+1)" + ] + `shouldSatisfy` ok + +run :: [String] -> Err T.Program +run = + fmap removeForall + . reportTEVar + <=< typecheck + <=< run' + +run' s = do + p <- (fmap desugar . pProgram . resolveLayout True . myLexer . unlines) s + reportForall Bi p + (rename <=< annotateForall) p + +runPrint = + (putStrLn . either show printTree . run') + ["double x = x + x"] + +ok = \case + Ok _ -> True + Bad _ -> False diff --git a/tests/TestTypeCheckerHm.hs b/tests/TestTypeCheckerHm.hs new file mode 100644 index 0000000..8137937 --- /dev/null +++ b/tests/TestTypeCheckerHm.hs @@ -0,0 +1,265 @@ +{-# LANGUAGE QualifiedDo #-} + +module TestTypeCheckerHm where + +import Control.Monad (sequence_, (<=<)) +import Test.Hspec + +import AnnForall (annotateForall) +import Desugar.Desugar (desugar) +import DoStrings qualified as D +import Grammar.Layout (resolveLayout) +import Grammar.Par (myLexer, pProgram) +import Grammar.Print (printTree) +import Renamer.Renamer (rename) +import ReportForall (reportForall) +import TypeChecker.TypeChecker (TypeChecker (Hm)) +import TypeChecker.TypeCheckerHm (typecheck) +import TypeChecker.TypeCheckerIr (Program) + +testTypeCheckerHm = describe "Hindley-Milner type checker test" $ do + sequence_ goods + sequence_ bads + sequence_ bes + +goods = + [ testSatisfy + "Basic polymorphism with multiple type variables" + ( D.do + _const + "main = const 'a' 65 ;" + ) + ok + , testSatisfy + "Head with a correct signature is accepted" + ( D.do + _List + _headSig + _head + ) + ok + , testSatisfy + "Most simple inference possible" + ( D.do + _id + ) + ok + , testSatisfy + "Pattern matching on a nested list" + ( D.do + _List + "main : List (List a) -> Int;" + "main xs = case xs of {" + " Cons Nil _ => 1;" + " _ => 0;" + "};" + ) + ok + ] + +bads = + [ testSatisfy + "Infinite type unification should not succeed" + ( D.do + "main = \\x. x x ;" + ) + bad + , testSatisfy + "Pattern matching using different types should not succeed" + ( D.do + _List + "bad xs = case xs of {" + " 1 => 0 ;" + " Nil => 0 ;" + "};" + ) + bad + , testSatisfy + "Using a concrete function (data type) on a skolem variable should not succeed" + ( D.do + _Bool + _not + "f : a -> Bool ;" + "f x = not x ;" + ) + bad + , testSatisfy + "Using a concrete function (primitive type) on a skolem variable should not succeed" + ( D.do + "plusOne : Int -> Int ;" + "plusOne x = x + 1 ;" + "f : a -> Int ;" + "f x = plusOne x ;" + ) + bad + , testSatisfy + "A function without signature used in an incompatible context should not succeed" + ( D.do + "main = _id 1 2 ;" + "_id x = x ;" + ) + bad + , testSatisfy + "Pattern matching on literal and _List should not succeed" + ( D.do + _List + "length : List c -> Int;" + "length _List = case _List of {" + " 0 => 0;" + " Cons x xs => 1 + length xs;" + "};" + ) + bad + , testSatisfy + "List of function Int -> Int functions should not be usable on Char" + ( D.do + _List + "main : List (Int -> Int) -> Int ;" + "main xs = case xs of {" + " Cons f _ => f 'a' ;" + " Nil => 0 ;" + " };" + ) + bad + -- FIXME FAILING TEST + -- , testSatisfy + -- "id with incorrect signature" + -- ( D.do + -- "id : a -> b;" + -- "id x = x;" + -- ) + -- bad + -- FIXME FAILING TEST + -- , testSatisfy + -- "incorrect signature on const" + -- ( D.do + -- "const : a -> b -> b;" + -- "const x y = x" + -- ) + -- bad + -- FIXME FAILING TEST + -- , testSatisfy + -- "incorrect type signature on id lambda" + -- ( D.do + -- "id = ((\\x. x) : a -> b);" + -- ) + -- bad + ] + +bes = + [ testBe + "A basic arithmetic function should be able to be inferred" + ( D.do + "plusOne x = x + 1 ;" + "main x = plusOne x ;" + ) + ( D.do + "plusOne : Int -> Int ;" + "plusOne x = x + 1 ;" + "main : Int -> Int ;" + "main x = plusOne x ;" + ) + , testBe + "A basic arithmetic function should be able to be inferred" + ( D.do + "plusOne x = x + 1 ;" + ) + ( D.do + "plusOne : Int -> Int ;" + "plusOne x = x + 1 ;" + ) + , testBe + "List of function Int -> Int functions should be inferred corretly" + ( D.do + _List + "main xs = case xs of {" + " Cons f _ => f 1 ;" + " Nil => 0 ;" + " };" + ) + ( D.do + _List + "main : List (Int -> Int) -> Int ;" + "main xs = case xs of {" + " Cons f _ => f 1 ;" + " Nil => 0 ;" + " };" + ) + , testBe + "length function on int list infers correct signature" + ( D.do + "data List where " + " Nil : List" + " Cons : Int -> List -> List" + + "length xs = case xs of" + " Nil => 0" + " Cons _ xs => 1 + length xs" + ) + ( D.do + "data List where" + " Nil : List" + " Cons : Int -> List -> List" + + "length : List -> Int" + "length xs = case xs of" + " Nil => 0" + " Cons _ xs => 1 + length xs" + ) + ] + +testSatisfy desc test satisfaction = specify desc $ run test `shouldSatisfy` satisfaction +testBe desc test shouldbe = specify desc $ run test `shouldBe` run shouldbe + +run = fmap (printTree . fst) . typecheck <=< fmap desugar . pProgram . myLexer + +ok (Right _) = True +ok (Left _) = False + +bad = not . ok + +-- FUNCTIONS + +_const = D.do + "const : a -> b -> a ;" + "const x y = x ;" +_List = D.do + "data List a where {" + " Nil : List a;" + " Cons : a -> List a -> List a;" + "};" + +_headSig = D.do + "head : List a -> a ;" + +_head = D.do + "head xs = " + " case xs of {" + " Cons x xs => x ;" + " };" + +_Bool = D.do + "data Bool where {" + " True : Bool" + " False : Bool" + "};" + +_not = D.do + "not : Bool -> Bool ;" + "not x = case x of {" + " True => False ;" + " False => True ;" + "};" +_id = "id x = x ;" + +_Maybe = D.do + "data Maybe a where {" + " Nothing : Maybe a" + " Just : a -> Maybe a" + " };" + +_fmap = D.do + "fmap f ma = case ma of {" + " Nothing => Nothing ;" + " Just a => Just (f a) ;" + "};"