Add bidirectional type checker, lambda lifter.

This commit is contained in:
Martin Fredin 2023-02-18 14:49:33 +01:00
parent 2fa30faa87
commit ac3f222753
22 changed files with 2440 additions and 577 deletions

View file

@ -3,94 +3,94 @@
-- * PROGRAM -- * PROGRAM
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
Program. Program ::= [Def] ; Program. Program ::= [Def];
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
-- * TOP-LEVEL -- * TOP-LEVEL
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
DBind. Def ::= Bind ; DBind. Def ::= Bind;
DSig. Def ::= Sig ; DSig. Def ::= Sig;
DData. Def ::= Data ; DData. Def ::= Data;
Sig. Sig ::= LIdent ":" Type ; Sig. Sig ::= LIdent ":" Type;
Bind. Bind ::= LIdent [LIdent] "=" Exp;
Bind. Bind ::= LIdent [LIdent] "=" Exp ;
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
-- * TYPES -- * Types
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
TLit. Type2 ::= UIdent ; TLit. Type1 ::= UIdent; -- τ
TVar. Type2 ::= TVar ; TVar. Type1 ::= TVar; -- α
TAll. Type1 ::= "forall" TVar "." Type ; internal TEVar. Type1 ::= TEVar; -- ά
TData. Type1 ::= UIdent "(" [Type] ")" ; TData. Type1 ::= UIdent "(" [Type] ")"; -- D ()
internal TEVar. Type1 ::= TEVar ; TFun. Type ::= Type1 "->" Type; -- A A
TFun. Type ::= Type1 "->" Type ; TAll. Type ::= "forall" TVar "." Type; -- α. A
MkTVar. TVar ::= LIdent ; MkTVar. TVar ::= LIdent;
internal MkTEVar. TEVar ::= LIdent ; internal MkTEVar. TEVar ::= LIdent;
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
-- * DATA TYPES -- * DATA TYPES
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
Constructor. Constructor ::= UIdent ":" Type ; Data. Data ::= "data" Type "where" "{" [Inj] "}" ;
Data. Data ::= "data" Type "where" "{" [Constructor] "}" ; Inj. Inj ::= UIdent ":" Type ;
separator nonempty Inj " " ;
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
-- * EXPRESSIONS -- * Expressions
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
EAnn. Exp4 ::= "(" Exp ":" Type ")" ; EAnn. Exp4 ::= "(" Exp ":" Type ")";
EVar. Exp3 ::= LIdent ; EVar. Exp3 ::= LIdent;
EInj. Exp3 ::= UIdent ; EInj. Exp3 ::= UIdent;
ELit. Exp3 ::= Lit ; ELit. Exp3 ::= Lit;
EApp. Exp2 ::= Exp2 Exp3 ; EApp. Exp2 ::= Exp2 Exp3;
EAdd. Exp1 ::= Exp1 "+" Exp2 ; EAdd. Exp1 ::= Exp1 "+" Exp2;
ELet. Exp ::= "let" Bind "in" Exp ; ELet. Exp ::= "let" Bind "in" Exp;
EAbs. Exp ::= "\\" LIdent "." Exp ; EAbs. Exp ::= "\\" LIdent "." Exp;
ECase. Exp ::= "case" Exp "of" "{" [Branch] "}"; ECase. Exp ::= "case" Exp "of" "{" [Branch] "}";
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
-- * LITERALS -- * LITERALS
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
LInt. Lit ::= Integer ; LInt. Lit ::= Integer;
LChar. Lit ::= Char ; LChar. Lit ::= Character;
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
-- * CASE -- * PATTERN MATCHING
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
Branch. Branch ::= Pattern "=>" Exp ; Branch. Branch ::= Pattern "=>" Exp ;
PVar. Pattern1 ::= LIdent ; PVar. Pattern1 ::= LIdent;
PLit. Pattern1 ::= Lit ; PLit. Pattern1 ::= Lit;
PCatch. Pattern1 ::= "_" ; PCatch. Pattern1 ::= "_";
PEnum. Pattern1 ::= UIdent ; PEnum. Pattern1 ::= UIdent;
PInj. Pattern ::= UIdent [Pattern1] ; PInj. Pattern ::= UIdent [Pattern1];
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
-- * AUX -- * AUX
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
terminator Def ";" ; terminator Def ";";
separator nonempty Constructor "" ;
separator Type " " ;
separator nonempty Pattern1 " " ;
terminator Branch ";" ; terminator Branch ";" ;
separator Ident " ";
separator LIdent " ";
separator TVar " " ;
coercions Exp 4 ; separator LIdent "";
coercions Type 2 ; separator Type " ";
coercions Pattern 1 ; separator TVar " ";
separator nonempty Pattern1 " ";
coercions Pattern 1;
coercions Exp 4;
coercions Type 1 ;
token Character '\''(char)'\'' ;
token UIdent (upper (letter | digit | '_')*) ; token UIdent (upper (letter | digit | '_')*) ;
token LIdent (lower (letter | digit | '_')*) ; token LIdent (lower (letter | digit | '_')*) ;
comment "--" ; comment "--";
comment "{-" "-}" ; comment "{-" "-}";

View file

@ -31,13 +31,18 @@ executable language
Grammar.Skel Grammar.Skel
Grammar.ErrM Grammar.ErrM
Auxiliary Auxiliary
Renamer.Renamer
TypeChecker.TypeChecker TypeChecker.TypeChecker
TypeChecker.TypeCheckerHm
TypeChecker.TypeCheckerBidir
TypeChecker.TypeCheckerIr TypeChecker.TypeCheckerIr
TypeChecker.RemoveTEVar
LambdaLifter
Monomorphizer.Monomorphizer Monomorphizer.Monomorphizer
Monomorphizer.MonomorphizerIr Monomorphizer.MonomorphizerIr
Renamer.Renamer
Codegen.Codegen Codegen.Codegen
Codegen.LlvmIr Codegen.LlvmIr
Compiler
hs-source-dirs: src hs-source-dirs: src
@ -60,6 +65,9 @@ Test-suite language-testsuite
main-is: Tests.hs main-is: Tests.hs
other-modules: other-modules:
TestTypeCheckerBidir
TestTypeCheckerHm
Grammar.Abs Grammar.Abs
Grammar.Lex Grammar.Lex
Grammar.Par Grammar.Par
@ -67,9 +75,11 @@ Test-suite language-testsuite
Grammar.Skel Grammar.Skel
Grammar.ErrM Grammar.ErrM
Auxiliary Auxiliary
TypeChecker.TypeChecker
TypeChecker.TypeCheckerIr
Renamer.Renamer Renamer.Renamer
TypeChecker.TypeCheckerHm
TypeChecker.TypeCheckerBidir
TypeChecker.RemoveTEVar
TypeChecker.TypeCheckerIr
Compiler Compiler
hs-source-dirs: src, tests, tests/TypecheckingHM hs-source-dirs: src, tests, tests/TypecheckingHM
@ -87,3 +97,4 @@ Test-suite language-testsuite
, bytestring , bytestring
default-language: GHC2021 default-language: GHC2021

11
sample-programs/basic-0 Normal file
View file

@ -0,0 +1,11 @@
data forall a. List (a) where {
Nil : List (a)
Cons : a -> List (a) -> List (a)
};
length : forall c. List (c) -> Int;
length = \list. case list of {
Nil => 0;
Cons x xs => 1 + length xs;
Cons x (Cons y Nil) => 2;
};

View file

@ -3,3 +3,4 @@ add x = \y. x+y;
main : Int ; main : Int ;
main = (\z. z+z) ((add 4) 6) ; main = (\z. z+z) ((add 4) 6) ;

121
spec.txt Normal file
View file

@ -0,0 +1,121 @@
---------------------------------------------------------------------------
-- * Parser
---------------------------------------------------------------------------
data Program = Program [Def]
data Def = DSig Ident Type | DBind Bind
data Bind = Bind Ident [Ident] Exp
data Exp
= EId Ident
| ELit Lit
| EAnn Exp Type
| ELet Ident Exp Exp
| EApp Exp Exp
| EAdd Exp Exp
| EAbs Ident Exp
data Lit = LInt Integer
| LChar Character
data Type
= TLit Ident -- τ
| TVar TVar -- α
| TFun Type Type -- A → A
| TAll TVar Type -- ∀α. A
| TEVar TEVar -- ά (internal)
data TVar = MkTVar Ident
data TEVar = MkTEVar Ident
---------------------------------------------------------------------------
-- * Type checker
---------------------------------------------------------------------------
-- • Def and DSig are removed in favor of just Bind
-- • Typed expressions
-- • TEVar is removed (NOT IMPLEMENTED)
newtype Program = Program [Bind]
data Bind = Bind Id [Id] ExpT
data Exp
= EId Ident
| ELit Lit
| ELet Bind ExpT
| EApp ExpT ExpT
| EAdd ExpT ExpT
| EAbs Ident ExpT
type Id = (Ident, Type)
type ExpT = (Exp, Type)
data Lit = LInt Integer
| LChar Character
data Type
= TLit Ident -- τ
| TVar TVar -- α
| TFun Type Type -- A → A
| TAll TVar Type -- ∀α. A
data TVar = MkTVar Ident
---------------------------------------------------------------------------
-- * Lambda lifter
---------------------------------------------------------------------------
-- • EAbs are removed (NOT IMPLEMENTED)
-- • ELet only allow constant expressions (NOT IMPLEMENTED)
newtype Program = Program [Bind]
data Bind = Bind Id [Id] ExpT
data Exp
= EId Ident
| ELit Lit
| ELet Ident ExpT ExpT
| EApp ExpT ExpT
| EAdd ExpT ExpT
type Id = (Ident, Type)
type ExpT = (Exp, Type)
data Lit = LInt Integer
| LChar Character
data Type
= TLit Ident -- τ
| TVar TVar -- α
| TFun Type Type -- A → A
| TAll TVar Type -- ∀α. A
data TVar = MkTVar Ident
---------------------------------------------------------------------------
-- * Monomorphizer
---------------------------------------------------------------------------
-- • Polymorphic types are removed (NOT IMPLEMENTED)
newtype Program = Program [Bind]
data Bind = Bind Id [Id] ExpT
data Exp
= EId Ident
| ELit Lit
| ELet Ident ExpT ExpT
| EApp ExpT ExpT
| EAdd ExpT ExpT
type Id = (Ident, Type)
type ExpT = (Exp, Type)
data Lit = LInt Integer
| LChar Character
data Type = Type Ident

View file

@ -3,6 +3,7 @@ module Auxiliary (module Auxiliary) where
import Control.Monad.Error.Class (liftEither) import Control.Monad.Error.Class (liftEither)
import Control.Monad.Except (MonadError) import Control.Monad.Except (MonadError)
import Data.Either.Combinators (maybeToRight) import Data.Either.Combinators (maybeToRight)
import TypeChecker.TypeCheckerIr (Type (TFun))
snoc :: a -> [a] -> [a] snoc :: a -> [a] -> [a]
snoc x xs = xs ++ [x] snoc x xs = xs ++ [x]
@ -19,3 +20,4 @@ mapAccumM f = go
(acc', x') <- f acc x (acc', x') <- f acc x
(acc'', xs') <- go acc' xs (acc'', xs') <- go acc' xs
pure (acc'', x':xs') pure (acc'', x':xs')

View file

@ -17,6 +17,7 @@ import Data.Tuple.Extra (dupe, first, second)
import Debug.Trace (trace) import Debug.Trace (trace)
import qualified Grammar.Abs as GA import qualified Grammar.Abs as GA
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
import Monomorphizer.MonomorphizerIr (Ident (..))
import Monomorphizer.MonomorphizerIr as MIR import Monomorphizer.MonomorphizerIr as MIR
-- | The record used as the code generator state -- | The record used as the code generator state
@ -57,8 +58,13 @@ getVarCount :: CompilerState Integer
getVarCount = gets variableCount getVarCount = gets variableCount
-- | Increases the variable count and returns it from the CodeGenerator state -- | Increases the variable count and returns it from the CodeGenerator state
<<<<<<< HEAD
getNewVar :: CompilerState GA.Ident getNewVar :: CompilerState GA.Ident
getNewVar = GA.Ident . show <$> (increaseVarCount >> getVarCount) getNewVar = GA.Ident . show <$> (increaseVarCount >> getVarCount)
=======
getNewVar :: CompilerState Ident
getNewVar = (Ident . show) <$> (increaseVarCount >> getVarCount)
>>>>>>> da28c6d (Add bidirectional type checker, lambda lifter.)
-- | Increases the label count and returns a label from the CodeGenerator state -- | Increases the label count and returns a label from the CodeGenerator state
getNewLabel :: CompilerState Integer getNewLabel :: CompilerState Integer
@ -76,10 +82,25 @@ getFunctions bs = Map.fromList $ go bs
go (MIR.DBind (MIR.Bind id args _) : xs) = go (MIR.DBind (MIR.Bind id args _) : xs) =
(id, FunctionInfo{numArgs = length args, arguments = args}) (id, FunctionInfo{numArgs = length args, arguments = args})
: go xs : go xs
<<<<<<< HEAD
go (_ : xs) = go xs go (_ : xs) = go xs
=======
go (MIR.DData (MIR.Data n cons) : xs) =
do map
( \(Inj id xs) ->
( (coerce id, MIR.TLit (extractTypeName n))
, FunctionInfo
{ numArgs = undefined -- TODO
, arguments = createArgs (snd <$> undefined) -- TODO
}
)
)
cons
<> go xs
>>>>>>> da28c6d (Add bidirectional type checker, lambda lifter.)
createArgs :: [MIR.Type] -> [Id] createArgs :: [MIR.Type] -> [Id]
createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(GA.Ident ("arg_" <> show l), t)], l + 1)) ([], 0) xs createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(Ident ("arg_" <> show l), t)], l + 1)) ([], 0) xs
{- | Produces a map of functions infos from a list of binds, {- | Produces a map of functions infos from a list of binds,
which contains useful data for code generation. which contains useful data for code generation.
@ -89,6 +110,7 @@ getConstructors bs = Map.fromList $ go bs
where where
go [] = [] go [] = []
go (MIR.DData (MIR.Data t cons) : xs) = go (MIR.DData (MIR.Data t cons) : xs) =
<<<<<<< HEAD
fst fst
( foldl ( foldl
( \(acc, i) (Constructor id xs) -> ( \(acc, i) (Constructor id xs) ->
@ -96,6 +118,17 @@ getConstructors bs = Map.fromList $ go bs
, ConstructorInfo , ConstructorInfo
{ numArgsCI = length (init . flattenType $ xs) { numArgsCI = length (init . flattenType $ xs)
, argumentsCI = createArgs (init . flattenType $ xs) , argumentsCI = createArgs (init . flattenType $ xs)
=======
do
let (Ident n) = extractTypeName t
fst
( foldl
( \(acc, i) (Inj (Ident id) xs) ->
( ( (Ident (n <> "_" <> id), MIR.TLit (coerce n))
, ConstructorInfo
{ numArgsCI = undefined -- TODO
, argumentsCI = createArgs (snd <$> undefined) -- TODO
>>>>>>> da28c6d (Add bidirectional type checker, lambda lifter.)
, numCI = i , numCI = i
, returnTypeCI = t --last . flattenType $ xs , returnTypeCI = t --last . flattenType $ xs
} }
@ -133,30 +166,30 @@ test :: Integer -> Program
test v = test v =
Program Program
[ DataType [ DataType
(GA.Ident "Craig") (Ident "Craig")
[ Constructor (GA.Ident "Bob") [MIR.Type (GA.Ident "_Int")] [ Constructor (Ident "Bob") [MIR.Type (Ident "_Int")]
, Constructor (GA.Ident "Betty") [MIR.Type (GA.Ident "_Int")] , Constructor (Ident "Betty") [MIR.Type (Ident "_Int")]
] ]
, DataType , DataType
(GA.Ident "Alice") (Ident "Alice")
[ Constructor (GA.Ident "Eve") [MIR.Type (GA.Ident "_Int")] -- , [ Constructor (Ident "Eve") [MIR.Type (Ident "_Int")] -- ,
-- (GA.Ident "Alice", [TInt, TInt]) -- (Ident "Alice", [TInt, TInt])
] ]
, Bind (GA.Ident "fibonacci", MIR.Type (GA.Ident "_Int")) [(GA.Ident "x", MIR.Type (GA.Ident "_Int"))] (EId ("x", MIR.Type (GA.Ident "Craig")), MIR.Type (GA.Ident "Craig")) , Bind (Ident "fibonacci", MIR.Type (Ident "_Int")) [(Ident "x", MIR.Type (Ident "_Int"))] (EVar ("x", MIR.Type (Ident "Craig")), MIR.Type (Ident "Craig"))
, Bind (GA.Ident "main", MIR.Type (GA.Ident "_Int")) [] , Bind (Ident "main", MIR.Type (Ident "_Int")) []
-- (EApp (MIR.Type (GA.Ident "Craig")) (EId (GA.Ident "Craig_Bob", MIR.Type (GA.Ident "Craig")), MIR.Type (GA.Ident "Craig")) (ELit (LInt v), MIR.Type (GA.Ident "_Int")), MIR.Type (GA.Ident "Craig"))-- (EInt 92) -- (EApp (MIR.Type (Ident "Craig")) (EVar (Ident "Craig_Bob", MIR.Type (Ident "Craig")), MIR.Type (Ident "Craig")) (ELit (LInt v), MIR.Type (Ident "_Int")), MIR.Type (Ident "Craig"))-- (EInt 92)
$ $
eCaseInt eCaseInt
(EApp (MIR.TLit (GA.Ident "Craig")) (EId (GA.Ident "Craig_Bob", MIR.TLit (GA.Ident "Craig")), MIR.TLit (GA.Ident "Craig")) (ELit (LInt v), MIR.Type (GA.Ident "_Int")), MIR.Type (GA.Ident "Craig")) (EApp (MIR.TLit (Ident "Craig")) (EVar (Ident "Craig_Bob", MIR.TLit (Ident "Craig")), MIR.TLit (Ident "Craig")) (ELit (LInt v), MIR.Type (Ident "_Int")), MIR.Type (Ident "Craig"))
[ injectionCons "Craig_Bob" "Craig" [CIdent (GA.Ident "x")] (EId (GA.Ident "x", MIR.Type (GA.Ident "_Int")), MIR.Type (GA.Ident "_Int")) [ injectionCons "Craig_Bob" "Craig" [CIdent (Ident "x")] (EVar (Ident "x", MIR.Type (Ident "_Int")), MIR.Type (Ident "_Int"))
, injectionCons "Craig_Betty" "Craig" [CLit (LInt 5)] (int 2) , injectionCons "Craig_Betty" "Craig" [CLit (LInt 5)] (int 2)
, Injection (CIdent (GA.Ident "z")) (int 3) , Injection (CIdent (Ident "z")) (int 3)
, -- , injectionInt 5 (int 6) , -- , injectionInt 5 (int 6)
injectionCatchAll (int 10) injectionCatchAll (int 10)
] ]
] ]
where where
injectionCons x y xs = Injection (CCons (GA.Ident x, MIR.Type (GA.Ident y)) xs) injectionCons x y xs = Injection (CCons (Ident x, MIR.Type (Ident y)) xs)
injectionInt x = Injection (CLit (LInt x)) injectionInt x = Injection (CLit (LInt x))
injectionCatchAll = Injection CatchAll injectionCatchAll = Injection CatchAll
eCaseInt x xs = (ECase (MIR.TLit (MIR.Ident "_Int")) x xs, MIR.TLit (MIR.Ident "_Int")) eCaseInt x xs = (ECase (MIR.TLit (MIR.Ident "_Int")) x xs, MIR.TLit (MIR.Ident "_Int"))
@ -206,7 +239,7 @@ compileScs [] = do
emit $ SetVariable ptr' (Bitcast (Ref t') (VIdent top Ptr) (Ref $ CustomType id)) emit $ SetVariable ptr' (Bitcast (Ref t') (VIdent top Ptr) (Ref $ CustomType id))
enumerateOneM_ enumerateOneM_
( \i (GA.Ident arg_n, arg_t) -> do ( \i (Ident arg_n, arg_t) -> do
let arg_t' = type2LlvmType arg_t let arg_t' = type2LlvmType arg_t
emit $ Comment (toIr arg_t' <> " " <> arg_n <> " " <> show i) emit $ Comment (toIr arg_t' <> " " <> arg_n <> " " <> show i)
elemPtr <- getNewVar elemPtr <- getNewVar
@ -222,7 +255,7 @@ compileScs [] = do
I32 I32
(VInteger i) (VInteger i)
) )
emit $ Store arg_t' (VIdent (GA.Ident arg_n) arg_t') Ptr elemPtr emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr elemPtr
) )
(argumentsCI ci) (argumentsCI ci)
@ -255,8 +288,13 @@ compileScs (MIR.DData (MIR.Data typ ts) : xs) = do
let biggestVariant = 7 + maximum (sum . (\(Constructor _ fi) -> typeByteSize <$> variantTypes fi) <$> ts) let biggestVariant = 7 + maximum (sum . (\(Constructor _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
emit $ LIR.Type (Ident outer_id) [I8, Array biggestVariant I8] emit $ LIR.Type (Ident outer_id) [I8, Array biggestVariant I8]
mapM_ mapM_
<<<<<<< HEAD
( \(Constructor inner_id fi) -> do ( \(Constructor inner_id fi) -> do
emit $ LIR.Type inner_id (I8 : variantTypes fi) emit $ LIR.Type inner_id (I8 : variantTypes fi)
=======
( \(Inj (Ident inner_id) fi) -> do
emit $ LIR.Type (Ident $ outer_id <> "_" <> inner_id) (I8 : map type2LlvmType (snd <$> undefined)) -- TODO
>>>>>>> da28c6d (Add bidirectional type checker, lambda lifter.)
) )
ts ts
compileScs xs compileScs xs
@ -282,17 +320,17 @@ mainContent var =
-- " %4 = load i72, ptr %3\n" <> -- " %4 = load i72, ptr %3\n" <>
-- " call i32 (ptr, ...) @printf(ptr noundef @.str, i72 noundef %4)\n" -- " call i32 (ptr, ...) @printf(ptr noundef @.str, i72 noundef %4)\n"
"call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> toIr var <> ")\n" "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> toIr var <> ")\n"
, -- , SetVariable (GA.Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2)) , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
-- , BrCond (VIdent (GA.Ident "p")) (GA.Ident "b_1") (GA.Ident "b_2") -- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2")
-- , Label (GA.Ident "b_1") -- , Label (Ident "b_1")
-- , UnsafeRaw -- , UnsafeRaw
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n" -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n"
-- , Br (GA.Ident "end") -- , Br (Ident "end")
-- , Label (GA.Ident "b_2") -- , Label (Ident "b_2")
-- , UnsafeRaw -- , UnsafeRaw
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n" -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n"
-- , Br (GA.Ident "end") -- , Br (Ident "end")
-- , Label (GA.Ident "end") -- , Label (Ident "end")
Ret I64 (VInteger 0) Ret I64 (VInteger 0)
] ]
@ -310,7 +348,7 @@ compileExp :: ExpT -> CompilerState ()
compileExp (MIR.ELit lit,t) = emitLit lit compileExp (MIR.ELit lit,t) = emitLit lit
compileExp (MIR.EAdd e1 e2,t) = emitAdd t e1 e2 compileExp (MIR.EAdd e1 e2,t) = emitAdd t e1 e2
-- compileExp (ESub t e1 e2) = emitSub t e1 e2 -- compileExp (ESub t e1 e2) = emitSub t e1 e2
compileExp (MIR.EId name,t) = emitIdent name compileExp (MIR.EVar name,t) = emitIdent name
compileExp (MIR.EApp e1 e2,t) = emitApp t e1 e2 compileExp (MIR.EApp e1 e2,t) = emitApp t e1 e2
-- compileExp (EAbs t ti e) = emitAbs t ti e -- compileExp (EAbs t ti e) = emitAbs t ti e
compileExp (MIR.ELet binds e,t) = undefined -- emitLet binds (fst e) compileExp (MIR.ELet binds e,t) = undefined -- emitLet binds (fst e)
@ -328,7 +366,7 @@ emitECased t e cases = do
let rt = type2LlvmType (snd e) let rt = type2LlvmType (snd e)
vs <- exprToValue e vs <- exprToValue e
lbl <- getNewLabel lbl <- getNewLabel
let label = GA.Ident $ "escape_" <> show lbl let label = Ident $ "escape_" <> show lbl
stackPtr <- getNewVar stackPtr <- getNewVar
emit $ SetVariable stackPtr (Alloca ty) emit $ SetVariable stackPtr (Alloca ty)
mapM_ (emitCases rt ty label stackPtr vs) cs mapM_ (emitCases rt ty label stackPtr vs) cs
@ -341,13 +379,13 @@ emitECased t e cases = do
res <- getNewVar res <- getNewVar
emit $ SetVariable res (Load ty Ptr stackPtr) emit $ SetVariable res (Load ty Ptr stackPtr)
where where
emitCases :: LLVMType -> LLVMType -> GA.Ident -> GA.Ident -> LLVMValue -> Branch -> CompilerState () emitCases :: LLVMType -> LLVMType -> Ident -> Ident -> LLVMValue -> Branch -> CompilerState ()
emitCases rt ty label stackPtr vs (Branch (MIR.PInj consId cs, t) exp) = do emitCases rt ty label stackPtr vs (Branch (MIR.PInj consId cs, t) exp) = do
cons <- gets constructors cons <- gets constructors
let r = fromJust $ Map.lookup consId cons let r = fromJust $ Map.lookup consId cons
lbl_failPos <- (\x -> GA.Ident $ "failed_" <> show x) <$> getNewLabel lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> GA.Ident $ "success_" <> show x) <$> getNewLabel lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
consVal <- getNewVar consVal <- getNewVar
emit $ SetVariable consVal (ExtractValue rt vs 0) emit $ SetVariable consVal (ExtractValue rt vs 0)
@ -397,8 +435,8 @@ emitECased t e cases = do
(MIR.LInt i, _) -> VInteger i (MIR.LInt i, _) -> VInteger i
(MIR.LChar i, _) -> VChar i (MIR.LChar i, _) -> VChar i
ns <- getNewVar ns <- getNewVar
lbl_failPos <- (\x -> GA.Ident $ "failed_" <> show x) <$> getNewLabel lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> GA.Ident $ "success_" <> show x) <$> getNewLabel lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
emit $ SetVariable ns (Icmp LLEq ty vs i') emit $ SetVariable ns (Icmp LLEq ty vs i')
emit $ BrCond (VIdent ns ty) lbl_succPos lbl_failPos emit $ BrCond (VIdent ns ty) lbl_succPos lbl_failPos
emit $ Label lbl_succPos emit $ Label lbl_succPos
@ -444,8 +482,13 @@ emitApp rt e1 e2 = appEmitter e1 e2 []
appEmitter e1 e2 stack = do appEmitter e1 e2 stack = do
let newStack = e2 : stack let newStack = e2 : stack
case e1 of case e1 of
<<<<<<< HEAD
(MIR.EApp e1' e2', _) -> appEmitter e1' e2' newStack (MIR.EApp e1' e2', _) -> appEmitter e1' e2' newStack
(MIR.EId name, t) -> do (MIR.EId name, t) -> do
=======
(MIR.EApp e1' e2', t) -> appEmitter e1' e2' newStack
(MIR.EVar name, t) -> do
>>>>>>> da28c6d (Add bidirectional type checker, lambda lifter.)
args <- traverse exprToValue newStack args <- traverse exprToValue newStack
vs <- getNewVar vs <- getNewVar
funcs <- gets functions funcs <- gets functions
@ -462,7 +505,7 @@ emitApp rt e1 e2 = appEmitter e1 e2 []
emit $ SetVariable vs call emit $ SetVariable vs call
x -> error $ "The unspeakable happened: " <> show x x -> error $ "The unspeakable happened: " <> show x
emitIdent :: GA.Ident -> CompilerState () emitIdent :: Ident -> CompilerState ()
emitIdent id = do emitIdent id = do
-- !!this should never happen!! -- !!this should never happen!!
emit $ Comment "This should not have happened!" emit $ Comment "This should not have happened!"
@ -477,14 +520,14 @@ emitLit i = do
(MIR.LChar i'') -> (VChar i'', I8) (MIR.LChar i'') -> (VChar i'', I8)
varCount <- getNewVar varCount <- getNewVar
emit $ Comment "This should not have happened!" emit $ Comment "This should not have happened!"
emit $ SetVariable (GA.Ident (show varCount)) (Add t i' (VInteger 0)) emit $ SetVariable (Ident (show varCount)) (Add t i' (VInteger 0))
emitAdd :: MIR.Type -> ExpT -> ExpT -> CompilerState () emitAdd :: MIR.Type -> ExpT -> ExpT -> CompilerState ()
emitAdd t e1 e2 = do emitAdd t e1 e2 = do
v1 <- exprToValue e1 v1 <- exprToValue e1
v2 <- exprToValue e2 v2 <- exprToValue e2
v <- getNewVar v <- getNewVar
emit $ SetVariable (GA.Ident $ show v) (Add (type2LlvmType t) v1 v2) emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2)
emitSub :: MIR.Type -> ExpT -> ExpT -> CompilerState () emitSub :: MIR.Type -> ExpT -> ExpT -> CompilerState ()
emitSub t e1 e2 = do emitSub t e1 e2 = do
@ -498,7 +541,7 @@ exprToValue = \case
(MIR.ELit i, t) -> pure $ case i of (MIR.ELit i, t) -> pure $ case i of
(MIR.LInt i) -> VInteger i (MIR.LInt i) -> VInteger i
(MIR.LChar i) -> VChar i (MIR.LChar i) -> VChar i
(MIR.EId name, t) -> do (MIR.EVar name, t) -> do
funcs <- gets functions funcs <- gets functions
case Map.lookup (name, t) funcs of case Map.lookup (name, t) funcs of
Just fi -> do Just fi -> do
@ -515,7 +558,7 @@ exprToValue = \case
e -> do e -> do
compileExp e compileExp e
v <- getVarCount v <- getVarCount
pure $ VIdent (GA.Ident $ show v) (getType e) pure $ VIdent (Ident $ show v) (getType e)
type2LlvmType :: MIR.Type -> LLVMType type2LlvmType :: MIR.Type -> LLVMType
type2LlvmType (MIR.TLit id@(Ident name)) = case name of type2LlvmType (MIR.TLit id@(Ident name)) = case name of
@ -558,3 +601,4 @@ typeByteSize (CustomType _) = 8
enumerateOneM_ :: Monad m => (Integer -> a -> m b) -> [a] -> m () enumerateOneM_ :: Monad m => (Integer -> a -> m b) -> [a] -> m ()
enumerateOneM_ f = foldM_ (\i a -> f i a >> pure (i + 1)) 1 enumerateOneM_ f = foldM_ (\i a -> f i a >> pure (i + 1)) 1

View file

@ -12,7 +12,8 @@ module Codegen.LlvmIr (
) where ) where
import Data.List (intercalate) import Data.List (intercalate)
import Grammar.Abs (Ident (..)) import Grammar.Abs (Character)
import TypeChecker.TypeCheckerIr (Ident (..))
data CallingConvention = TailCC | FastCC | CCC | ColdCC deriving Show data CallingConvention = TailCC | FastCC | CCC | ColdCC deriving Show
instance ToIr CallingConvention where instance ToIr CallingConvention where
@ -87,7 +88,7 @@ instance ToIr Visibility where
-- or a string constant -- or a string constant
data LLVMValue data LLVMValue
= VInteger Integer = VInteger Integer
| VChar Char | VChar Character
| VIdent Ident LLVMType | VIdent Ident LLVMType
| VConstant String | VConstant String
| VFunction Ident Visibility LLVMType | VFunction Ident Visibility LLVMType

242
src/LambdaLifter.hs Normal file
View file

@ -0,0 +1,242 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module LambdaLifter (lambdaLift, freeVars, abstract, collectScs) where
import Auxiliary (snoc)
import Control.Applicative (Applicative (liftA2))
import Control.Monad.State (MonadState (get, put), State,
evalState)
import Data.List (partition)
import Data.Set (Set)
import qualified Data.Set as Set
import Prelude hiding (exp)
import TypeChecker.TypeCheckerIr
-- | Lift lambdas and let expressions into supercombinators.
--
-- The function bindings of the program go through three phases:
--
-- 1. 'freeVars' annotates every expression with its free variables.
-- 2. 'abstract' converts lambdas into let expressions.
-- 3. 'collectScs' moves every non-constant let expression to a top-level
--    function.
--
-- Definitions that are not 'DBind' are passed through untouched and are
-- placed before the lifted bindings in the resulting program.
lambdaLift :: Program -> Program
lambdaLift (Program defs) = Program (others ++ map DBind lifted)
  where
    -- Split the definitions; both halves keep their original relative order.
    (binds, others) = partition isBind defs

    -- The pattern in the generator is total: anything that is not a 'DBind'
    -- is skipped instead of crashing, unlike a lambda of the form
    -- @\\(DBind b) -> b@.
    lifted = collectScs . abstract . freeVars $ [b | DBind b <- binds]

    isBind = \case
      DBind _ -> True
      _       -> False
-- | Annotate the right-hand side of every binding with free-variable sets.
-- The parameters of a binding form the initial set of local variables.
freeVars :: [Bind] -> AnnBinds
freeVars = map annotate
  where
    annotate (Bind name parms rhs) =
      (name, parms, freeVarsExp (Set.fromList (map fst parms)) rhs)
-- | Annotate an expression, and all of its sub-expressions, with the set of
-- local variables that occur free in it.
--
-- The first argument is the set of variables regarded as local at this point;
-- a variable that is not a member of it contributes nothing to the result.
-- Only 'EVar', 'ELit', 'EApp', 'EAdd', 'EAbs' and 'ELet' are matched here.
freeVarsExp :: Set Ident -> ExpT -> AnnExpT
freeVarsExp locals (expr, typ) = case expr of
  EVar name
    | name `Set.member` locals -> annotated (Set.singleton name) (AVar name)
    | otherwise                -> annotated Set.empty (AVar name)
  ELit lit   -> annotated Set.empty (ALit lit)
  EApp e1 e2 -> binary AApp e1 e2
  EAdd e1 e2 -> binary AAdd e1 e2
  -- The parameter is local inside the body but bound by the lambda itself.
  EAbs par body ->
    let body' = freeVarsExp (Set.insert par locals) body
     in annotated (Set.delete par (freeVarsOf body')) (AAbs par body')
  -- Sum the free variables of the bound expression and of the let body; the
  -- let-bound name is local in both and is removed from the result.
  ELet (Bind (name, bindType) parms rhs) body ->
    let inner    = Set.insert name locals
        rhs'     = freeVarsExp inner rhs
        body'    = freeVarsExp inner body
        rhsFree  = Set.delete name (freeVarsOf rhs')
        bodyFree = Set.delete name (freeVarsOf body')
        bind     = ABind (name, bindType) parms rhs'
     in annotated (Set.union rhsFree bodyFree) (ALet bind body')
  where
    -- Attach a free-variable set and the expression's type to a node.
    annotated frees node = (frees, (node, typ))

    -- Shared handling of the two binary constructors.
    binary con l r =
      let l' = freeVarsExp locals l
          r' = freeVarsExp locals r
       in annotated (Set.union (freeVarsOf l') (freeVarsOf r')) (con l' r')
-- | Project the free-variable annotation out of an annotated expression.
freeVarsOf :: AnnExpT -> Set Ident
freeVarsOf (frees, _) = frees
-- AST annotated with free variables

-- | Annotated top-level bindings: binder, parameters and annotated
-- right-hand side.
type AnnBinds = [(Id, [Id], AnnExpT)]

-- | A typed expression node paired with its free-variable annotation.
type AnnExpT = (Set Ident, AnnExpT')

-- | An annotated expression node together with its type.
type AnnExpT' = (AnnExp, Type)

-- | A let binding whose right-hand side carries annotations.
data ABind = ABind Id [Id] AnnExpT
  deriving (Show)

-- | Expression nodes whose sub-expressions are annotated.
data AnnExp
  = AVar Ident
  | AInj Ident
  | ALit Lit
  | ALet ABind AnnExpT
  | AApp AnnExpT AnnExpT
  | AAdd AnnExpT AnnExpT
  | AAbs Ident AnnExpT
  deriving (Show)
-- | Lift lambdas to let expressions of the form @let sc = \\v₁ x₁ -> e₁@,
-- in which the free variables @v₁ v₂ .. vₙ@ become bound parameters.
--
-- Leading lambdas of a top-level right-hand side are not lifted; their
-- parameters are appended to the parameters of the binding instead.
-- A single counter, starting at 0, is threaded through all bindings.
abstract :: AnnBinds -> [Bind]
abstract prog = evalState (traverse liftBind prog) 0
  where
    liftBind :: (Id, [Id], AnnExpT) -> State Int Bind
    liftBind (name, parms, rhs) = do
      let (body, lambdaParms) = flattenLambdasAnn rhs
      body' <- abstractExp body
      pure (Bind name (parms ++ lambdaParms) body')
-- | Flatten nested lambdas and collect the parameters
-- @\\x.\\y.\\z. ae → (ae, [x,y,z])@
--
-- A lambda is only peeled off while the type attached to it is a 'TFun';
-- the parameter receives the argument side of that function type.
flattenLambdasAnn :: AnnExpT -> (AnnExpT, [Id])
flattenLambdasAnn = peel []
  where
    peel :: [Id] -> AnnExpT -> (AnnExpT, [Id])
    peel seen annotated = case annotated of
      (_, (AAbs par (bodyFree, body), TFun parType _)) ->
        -- The parameter is now bound by the enclosing binding, so it is
        -- dropped from the annotation of the body.
        peel (snoc (par, parType) seen) (Set.delete par bodyFree, body)
      _ -> (annotated, seen)
-- | Convert an annotated expression back to a plain typed expression,
-- replacing every lambda by a let-bound supercombinator @sc_N@ that is
-- applied to the lambda's free variables.
--
-- The counter in the state supplies the number @N@ of each supercombinator.
--
-- NOTE(review): 'AInj' has no branch here. Nothing visible in this module
-- constructs 'AInj', but the match is incomplete — confirm how injections
-- should be translated.
abstractExp :: AnnExpT -> State Int ExpT
abstractExp (free, (expr, typ)) = case expr of
  AVar name -> pure (withType (EVar name))
  ALit lit -> pure (withType (ELit lit))
  AApp fun arg -> binary EApp fun arg
  AAdd lhs rhs -> binary EAdd lhs rhs
  -- Leading lambdas of a let-bound expression are kept and then flattened
  -- into parameters of the binding; only what lies beneath them is lifted.
  ALet (ABind name parms rhs) body -> do
    (rhs', lambdaParms) <- flattenLambdas <$> keepLeadingLambdas rhs
    body' <- abstractExp body
    pure (withType (ELet (Bind name (parms ++ lambdaParms) rhs') body'))
  -- Lift lambda into let and bind free variables
  AAbs parm body -> do
    i <- nextNumber
    rhs <- abstractExp body
    let scName = Ident ("sc_" ++ show i)
        -- NOTE(review): the names (free variables followed by the lambda
        -- parameter) are zipped with the argument types of the lambda's own
        -- type — confirm that these are the intended parameter types.
        scParms = zip (snoc parm captured) (getVars typ)
        sc = withType (ELet (Bind (scName, typ) scParms rhs) (EVar scName, typ))
    pure (foldl applyCaptured sc captured)
  where
    -- Pair a node with the type of the expression being translated.
    withType node = (node, typ)

    -- Translate both operands, left before right, and rebuild the node.
    binary con l r = do
      l' <- abstractExp l
      r' <- abstractExp r
      pure (withType (con l' r'))

    -- Free variables of the current expression, in ascending order.
    captured = Set.toList free

    -- Rebuild a chain of leading lambdas unchanged and translate the first
    -- non-lambda expression below them.
    keepLeadingLambdas (frees, (ae, t)) = case ae of
      AAbs par inner -> do
        inner' <- keepLeadingLambdas inner
        pure (EAbs par inner', t)
      _ -> abstractExp (frees, (ae, t))

    -- Apply an expression to one captured variable, peeling one argument
    -- off the expression's type with 'applyVarType'.
    applyCaptured (e, t) name = (EApp (e, t) (EVar name, argType), resultType)
      where
        (argType, resultType) = applyVarType t
-- | Split a (possibly quantified) function type into the type of its first
-- argument and the type that remains after applying that argument.
--
-- Leading 'TAll' quantifiers are stripped before the split and wrapped again,
-- in their original order, around the remaining type.
--
-- Both components call 'error' with @"Not a function!"@ when forced if the
-- type underneath the quantifiers is not a 'TFun'; the pair itself is always
-- produced lazily.
applyVarType :: Type -> (Type, Type)
applyVarType typ = (argType, foldr ($) resultType quantifiers)
  where
    -- Reuse the module-level 'skipForalls' rather than a shadowing local copy.
    (quantifiers, body) = skipForalls typ

    (argType, resultType) = case body of
      TFun arg result -> (arg, result)
      _               -> error "Not a function!"
-- | Return the current counter value and store its successor in the state.
nextNumber :: State Int Int
nextNumber = get >>= \current -> current <$ put (succ current)
-- | Collect supercombinators by lifting non-constant let expressions.
-- Every input binding is emitted first, directly followed by the
-- supercombinators that were collected from its right-hand side.
collectScs :: [Bind] -> [Bind]
collectScs binds =
  [ sc
  | Bind name parms rhs <- binds
  , let (lifted, rhs') = collectScsExp rhs
  , sc <- Bind name parms rhs' : lifted
  ]
-- | Pull let-bound supercombinators out of an expression. Returns the lifted
-- bindings and the expression that remains.
--
-- A let binding with parameters is lifted and replaced by the let body:
--
-- > f = let sc x y = rhs in e
--
-- A parameterless let stays in place; only the supercombinators found in its
-- right-hand side and body are collected.
collectScsExp :: ExpT -> ([Bind], ExpT)
collectScsExp expT@(exp, typ) = case exp of
    EVar _     -> leaf
    ELit _     -> leaf
    EApp e1 e2 -> binary EApp e1 e2
    EAdd e1 e2 -> binary EAdd e1 e2
    EAbs par e ->
        let (scs, e') = collectScsExp e
         in (scs, (EAbs par e', typ))
    ELet (Bind name parms rhs) e ->
        let (rhs_scs, rhs') = collectScsExp rhs
            (et_scs, et')   = collectScsExp e
            bind            = Bind name parms rhs'
         in if null parms
                then (rhs_scs ++ et_scs, (ELet bind et', snd et'))
                else (bind : rhs_scs ++ et_scs, et')
  where
    -- Nothing to lift from variables and literals.
    leaf = ([], expT)
    -- Collect from both operands, left before right, and rebuild the node.
    binary con l r = (scs_l ++ scs_r, (con l' r', typ))
      where
        (scs_l, l') = collectScsExp l
        (scs_r, r') = collectScsExp r
-- | Strip the leading lambdas off a typed expression and return the body
-- together with the stripped parameters, in source order:
--
-- @\x.\y.\z. e → (e, [x,y,z])@
--
-- Each parameter is paired with the first argument type of the lambda's
-- type (as computed by 'getVars'). A lambda whose type has no argument type
-- is a broken invariant and is reported with a descriptive 'error' rather
-- than an anonymous empty-list failure.
flattenLambdas :: ExpT -> (ExpT, [Id])
flattenLambdas = go []
  where
    go acc (EAbs name body, t) = case getVars t of
        t_var : _ -> go (snoc (name, t_var) acc) body
        []        -> error "flattenLambdas: lambda does not have a function type"
    go acc expT = (expT, acc)
-- | Argument types of a function type, i.e. the first component of
-- 'partitionType'.
getVars :: Type -> [Type]
getVars typ = params
  where
    (params, _) = partitionType typ
-- | Split a type into its argument types and its final result type. Leading
-- @forall@s are dropped first; the spine of 'TFun' is then walked to the right.
partitionType :: Type -> ([Type], Type)
partitionType = collect [] . skipForalls'
  where
    collect args (TFun t_arg t_res) = collect (snoc t_arg args) t_res
    collect args t_ret              = (args, t_ret)
-- | The type that remains after dropping all leading @forall@s.
skipForalls' :: Type -> Type
skipForalls' typ = snd (skipForalls typ)
-- | Peel the leading @forall@s off a type. Returns one @TAll tvar@ wrapper
-- per peeled quantifier, accumulated with 'snoc', and the type underneath.
skipForalls :: Type -> ([Type -> Type], Type)
skipForalls = peel []
  where
    peel wrappers (TAll tvar t) = peel (snoc (TAll tvar) wrappers) t
    peel wrappers t             = (wrappers, t)

View file

@ -1,66 +1,114 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedRecordDot #-}
module Main where module Main where
import Codegen.Codegen (generateCode)
import Data.Bool (bool)
import GHC.IO.Handle.Text (hPutStrLn)
import Grammar.ErrM (Err)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
import Monomorphizer.Monomorphizer (monomorphize)
import Control.Monad (when) import Control.Monad (when)
import Data.Bool (bool)
import Data.List.Extra (isSuffixOf) import Data.List.Extra (isSuffixOf)
import Data.Maybe (fromJust, isNothing)
import Compiler (compile) import GHC.IO.Handle.Text (hPutStrLn)
import Renamer.Renamer (rename) import System.Console.GetOpt (ArgDescr (NoArg, ReqArg),
ArgOrder (RequireOrder),
OptDescr (Option), getOpt,
usageInfo)
import System.Directory (createDirectory, doesPathExist, import System.Directory (createDirectory, doesPathExist,
getDirectoryContents, getDirectoryContents,
removeDirectoryRecursive, removeDirectoryRecursive,
setCurrentDirectory) setCurrentDirectory)
import System.Environment (getArgs) import System.Environment (getArgs)
import System.Exit (ExitCode, exitFailure, import System.Exit (ExitCode (ExitFailure),
exitSuccess) exitFailure, exitSuccess,
exitWith)
import System.IO (stderr) import System.IO (stderr)
import System.Process.Extra (readCreateProcess, shell,
spawnCommand, waitForProcess)
import TypeChecker.TypeChecker (typecheck) import Codegen.Codegen (generateCode)
import Compiler (compile)
import Grammar.ErrM (Err)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
import LambdaLifter (lambdaLift)
import Monomorphizer.Monomorphizer (monomorphize)
import Renamer.Renamer (rename)
import System.Process (spawnCommand, waitForProcess)
import TypeChecker.TypeChecker (TypeChecker (Bi, Hm), typecheck)
main :: IO () main :: IO ()
main = main = getArgs >>= parseArgs >>= uncurry main'
getArgs >>= \case
[] -> putStrLn "Required file path missing"
["-d", s] -> do
when (".crf" `isSuffixOf` s) (main' True s)
putStrLn $ "File '" ++ s ++ "' is not a churf file"
[s] -> do
when (".crf" `isSuffixOf` s) (main' False s)
putStrLn $ "File '" ++ s ++ "' is not a churf file"
xs -> putStrLn $ "Can't process: " ++ unwords xs
main' :: Bool -> String -> IO () parseArgs :: [String] -> IO (Options, String)
main' debug s = do parseArgs argv = case getOpt RequireOrder flags argv of
(os, f:_, [])
| opts.help || isNothing opts.typechecker -> do
hPutStrLn stderr (usageInfo header flags)
exitSuccess
| otherwise -> pure (opts, f)
where
opts = foldr ($) initOpts os
(_, _, errs) -> do
hPutStrLn stderr (concat errs ++ usageInfo header flags)
exitWith (ExitFailure 1)
where
header = "Usage: language [--help] [-d|--debug] [-t|type-checker bi/hm] FILE \n"
-- | Command-line options of the compiler driver. Every option contributes an
-- update function that is applied to the 'Options' record.
flags :: [OptDescr (Options -> Options)]
flags = [debugFlag, typeCheckerFlag, helpFlag]
  where
    debugFlag =
        Option ['d'] ["debug"] (NoArg enableDebug) "Print debug messages."
    typeCheckerFlag =
        Option ['t'] ["type-checker"] (ReqArg chooseTypechecker "bi/hm")
            "Choose type checker. Possible options are bi and hm"
    helpFlag =
        Option [] ["help"] (NoArg enableHelp) "Print this help message"
-- | Default options: no help requested, debug output off, no type checker
-- selected.
initOpts :: Options
initOpts = Options False False Nothing
-- | Set the @help@ field.
enableHelp :: Options -> Options
enableHelp o = o { help = True }
-- | Set the @debug@ field.
enableDebug :: Options -> Options
enableDebug o = o { debug = True }
-- | Record the type checker named by the option argument: @"hm"@ selects
-- 'Hm', @"bi"@ selects 'Bi'. Any other argument resets the field to 'Nothing'.
chooseTypechecker :: String -> Options -> Options
chooseTypechecker s options = options { typechecker = parseChecker s }
  where
    parseChecker "hm" = Just Hm
    parseChecker "bi" = Just Bi
    parseChecker _    = Nothing
-- | Settings gathered from the command line.
data Options = Options
-- Set by the @--help@ flag.
{ help :: Bool
-- Set by the @-d@ / @--debug@ flag.
, debug :: Bool
-- Chosen by @-t@ / @--type-checker@; 'Nothing' until a recognised name is given.
, typechecker :: Maybe TypeChecker
}
main' :: Options -> String -> IO ()
main' opts s = do
file <- readFile s file <- readFile s
printToErr "-- Parse Tree -- " printToErr "-- Parse Tree -- "
parsed <- fromSyntaxErr . pProgram $ myLexer file parsed <- fromSyntaxErr . pProgram $ myLexer file
bool (printToErr $ printTree parsed) (printToErr $ show parsed) debug bool (printToErr $ printTree parsed) (printToErr $ show parsed) opts.debug
printToErr "\n-- Renamer --" printToErr "\n-- Renamer --"
renamed <- fromRenamerErr . rename $ parsed renamed <- fromRenamerErr . rename $ parsed
bool (printToErr $ printTree renamed) (printToErr $ show renamed) debug bool (printToErr $ printTree renamed) (printToErr $ show renamed) opts.debug
printToErr "\n-- TypeChecker --" printToErr "\n-- TypeChecker --"
typechecked <- fromTypeCheckerErr $ typecheck renamed typechecked <- fromTypeCheckerErr $ typecheck (fromJust opts.typechecker) renamed
bool (printToErr $ printTree typechecked) (printToErr $ show typechecked) debug bool (printToErr $ printTree typechecked) (printToErr $ show typechecked) opts.debug
printToErr "\n-- Lambda Lifter --"
let lifted = lambdaLift typechecked
printToErr $ printTree lifted
-- printToErr "\n-- Lambda Lifter --" -- printToErr "\n-- Lambda Lifter --"
-- let lifted = lambdaLift typechecked -- let lifted = lambdaLift typechecked
-- printToErr $ printTree lifted -- printToErr $ printTree lifted
-- --
--printToErr "\n -- Compiler --" printToErr "\n -- Compiler --"
generatedCode <- fromCompilerErr $ generateCode (monomorphize typechecked) generatedCode <- fromCompilerErr $ generateCode (monomorphize typechecked)
--putStrLn generatedCode --putStrLn generatedCode

View file

@ -5,8 +5,9 @@ module Monomorphizer.Monomorphizer (monomorphize) where
import Data.Coerce (coerce) import Data.Coerce (coerce)
import Monomorphizer.MonomorphizerIr qualified as M import qualified Monomorphizer.MonomorphizerIr as M
import TypeChecker.TypeCheckerIr qualified as T import qualified TypeChecker.TypeCheckerIr as T
import TypeChecker.TypeCheckerIr (Ident (..))
monomorphize :: T.Program -> M.Program monomorphize :: T.Program -> M.Program
monomorphize (T.Program ds) = M.Program $ monoDefs ds monomorphize (T.Program ds) = M.Program $ monoDefs ds
@ -16,20 +17,20 @@ monoDefs = map monoDef
monoDef :: T.Def -> M.Def monoDef :: T.Def -> M.Def
monoDef (T.DBind bind) = M.DBind $ monoBind bind monoDef (T.DBind bind) = M.DBind $ monoBind bind
monoDef (T.DData d) = M.DData $ monoData d --monoDef (T.DData d) = M.DData $ monoData d
monoBind :: T.Bind -> M.Bind monoBind :: T.Bind -> M.Bind
monoBind (T.Bind name args (e, t)) = M.Bind (monoId name) (map monoId args) (monoExpr e, monoType t) monoBind (T.Bind name args (e, t)) = M.Bind (monoId name) (map monoId args) (monoExpr e, monoType t)
monoData :: T.Data -> M.Data --monoData :: T.Data -> M.Data
monoData (T.Data (T.Ident id) cs) = M.Data (M.TLit (M.Ident id)) (map monoConstructor cs) --monoData (T.Data (Ident id) cs) = M.Data (M.TLit (M.Ident id)) (map monoConstructor cs)
monoConstructor :: T.Constructor -> M.Constructor monoConstructor :: T.Inj -> M.Inj
monoConstructor (T.Constructor (T.Ident i) t) = M.Constructor (M.Ident i) (monoType t) monoConstructor (T.Inj (Ident i) t) = M.Inj (M.Ident i) (monoType t)
monoExpr :: T.Exp -> M.Exp monoExpr :: T.Exp -> M.Exp
monoExpr = \case monoExpr = \case
T.EId (T.Ident i) -> M.EId (M.Ident i) T.EVar (Ident i) -> M.EVar (M.Ident i)
T.ELit lit -> M.ELit $ monoLit lit T.ELit lit -> M.ELit $ monoLit lit
T.ELet bind expt -> M.ELet (monoBind bind) (monoexpt expt) T.ELet bind expt -> M.ELet (monoBind bind) (monoexpt expt)
T.EApp expt1 expt2 -> M.EApp (monoexpt expt1) (monoexpt expt2) T.EApp expt1 expt2 -> M.EApp (monoexpt expt1) (monoexpt expt2)
@ -47,9 +48,9 @@ monoAbsType (T.TData _ _) = error "NOT INDEXED TYPES"
monoType :: T.Type -> M.Type monoType :: T.Type -> M.Type
monoType (T.TAll _ t) = monoType t monoType (T.TAll _ t) = monoType t
monoType (T.TVar (T.MkTVar i)) = M.TLit "Int" monoType (T.TVar (T.MkTVar i)) = M.TLit "Int"
monoType (T.TLit (T.Ident i)) = M.TLit (M.Ident i) monoType (T.TLit (Ident i)) = M.TLit (M.Ident i)
monoType (T.TFun t1 t2) = M.TFun (monoType t1) (monoType t2) monoType (T.TFun t1 t2) = M.TFun (monoType t1) (monoType t2)
monoType (T.TData (T.Ident n) t) = M.TLit (M.Ident (n ++ concatMap show t)) monoType (T.TData (Ident n) t) = M.TLit (M.Ident (n ++ concatMap show t))
monoexpt :: T.ExpT -> M.ExpT monoexpt :: T.ExpT -> M.ExpT
monoexpt (e, t) = (monoExpr e, monoType t) monoexpt (e, t) = (monoExpr e, monoType t)
@ -65,12 +66,12 @@ monoInjs :: [T.Branch] -> [M.Branch]
monoInjs = map monoInj monoInjs = map monoInj
monoInj :: T.Branch -> M.Branch monoInj :: T.Branch -> M.Branch
monoInj (T.Branch (init, t) expt) = M.Branch (monoInit init, monoType t) (monoexpt expt) monoInj (T.Branch (patt, t) expt) = M.Branch (monoPattern patt, monoType t) (monoexpt expt)
monoInit :: T.Pattern -> M.Pattern monoPattern :: T.Pattern -> M.Pattern
monoInit (T.PVar (id, t)) = M.PVar (coerce id, monoType t) monoPattern (T.PVar (id, t)) = M.PVar (id, monoType t)
monoInit (T.PLit (lit, t)) = M.PLit (monoLit lit, monoType t) monoPattern (T.PLit (lit, t)) = M.PLit (monoLit lit, monoType t)
monoInit (T.PInj id ps) = M.PInj (coerce id) (monoInit <$> ps) monoPattern (T.PInj id ps) = M.PInj (coerce id) (map monoPattern ps)
-- DO NOT DO THIS FOR REAL THOUGH -- DO NOT DO THIS FOR REAL THOUGH
monoInit (T.PEnum (T.Ident i)) = M.PInj (M.Ident i) [] monoPattern (T.PEnum (Ident i)) = M.PInj (M.Ident i) []
monoInit T.PCatch = M.PCatch monoPattern T.PCatch = M.PCatch

View file

@ -11,14 +11,14 @@ newtype Program = Program [Def]
data Def = DBind Bind | DData Data data Def = DBind Bind | DData Data
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
data Data = Data Type [Constructor] data Data = Data Type [Inj]
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
data Bind = Bind Id [Id] ExpT data Bind = Bind Id [Id] ExpT
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
data Exp data Exp
= EId Ident = EVar Ident
| ELit Lit | ELit Lit
| ELet Bind ExpT | ELet Bind ExpT
| EApp ExpT ExpT | EApp ExpT ExpT
@ -35,12 +35,12 @@ data Branch = Branch (Pattern, Type) ExpT
type ExpT = (Exp, Type) type ExpT = (Exp, Type)
data Constructor = Constructor Ident Type data Inj = Inj Ident Type
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
data Lit data Lit
= LInt Integer = LInt Integer
| LChar Char | LChar Character
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
data Type = TLit Ident | TFun Type Type data Type = TLit Ident | TFun Type Type

View file

@ -1,131 +1,124 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-} {-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use mapAndUnzipM" #-}
module Renamer.Renamer (rename) where module Renamer.Renamer (rename) where
import Auxiliary (mapAccumM) import Auxiliary (mapAccumM)
import Control.Applicative (Applicative (liftA2)) import Control.Applicative (Applicative (liftA2))
import Control.Monad (foldM) import Control.Monad.Except (ExceptT, MonadError (throwError),
import Control.Monad.Except (ExceptT, MonadError, runExceptT, throwError) runExceptT)
import Control.Monad.Identity (Identity, runIdentity) import Control.Monad.State (MonadState, State, evalState, gets,
import Control.Monad.State ( mapAndUnzipM, modify)
MonadState,
StateT,
evalStateT,
gets,
modify,
)
import Data.Coerce (coerce)
import Data.Function (on) import Data.Function (on)
import Data.Map (Map) import Data.Map (Map)
import Data.Map qualified as Map import qualified Data.Map as Map
import Data.Maybe (fromMaybe) import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe) import Data.Tuple.Extra (dupe, second)
import Grammar.Abs import Grammar.Abs
import Grammar.ErrM (Err)
-- | Rename all variables and local binds -- | Rename all variables and local binds
rename :: Program -> Either String Program rename :: Program -> Err Program
rename (Program defs) = Program <$> renameDefs defs rename (Program defs) = Program <$> renameDefs defs
renameDefs :: [Def] -> Either String [Def] initCxt :: Cxt
renameDefs defs = runIdentity $ runExceptT $ evalStateT (runRn $ mapM renameDef defs) initCxt initCxt = Cxt 0 0
-- | Renamer state: two counters used as numeric suffixes for fresh names.
-- @var_counter@ is read and incremented by 'makeName', @tvar_counter@ by
-- 'nextNameTVar'.
data Cxt = Cxt { var_counter :: Int
, tvar_counter :: Int
}
-- | Rename monad: 'ExceptT' with 'String' errors layered over 'State' 'Cxt'.
-- The state holds the counters used to number renamed names.
newtype Rn a = Rn { runRn :: ExceptT String (State Cxt) a }
deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String)
-- | Name environment: maps an original name to the name it was renamed to.
-- Keys and values are the raw strings inside the identifier wrappers.
type Names = Map String String
renameDefs :: [Def] -> Err [Def]
renameDefs defs = evalState (runExceptT (runRn $ mapM renameDef defs)) initCxt
where where
initNames = Map.fromList [dupe (coerce name) | DBind (Bind name _ _) <- defs] initNames = Map.fromList [ dupe s | DBind (Bind (LIdent s) _ _) <- defs]
renameDef :: Def -> Rn Def renameDef :: Def -> Rn Def
renameDef = \case renameDef = \case
DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ
DBind bind -> DBind . snd <$> renameBind initNames bind DBind (Bind name vars rhs) -> do
DData (Data (TData cname types) constrs) -> do (new_names, vars') <- newNamesL initNames vars
tvars_ <- tvars rhs' <- snd <$> renameExp new_names rhs
tvars' <- mapM nextNameTVar tvars_ pure . DBind $ Bind name vars' rhs'
let tvars_lt = zip tvars_ tvars' DData (Data typ injs) -> do
typ' = map (substituteTVar tvars_lt) types tvars <- collectTVars [] typ
constrs' = map (renameConstr tvars_lt) constrs tvars' <- mapM nextNameTVar tvars
pure . DData $ Data (TData cname typ') constrs' let tvars_lt = zip tvars tvars'
typ' = substituteTVar tvars_lt typ
injs' = map (renameInj tvars_lt) injs
pure . DData $ Data typ' injs'
where where
tvars = concat <$> mapM (collectTVars []) types
collectTVars :: [TVar] -> Type -> Rn [TVar]
collectTVars tvars = \case collectTVars tvars = \case
TAll tvar t -> collectTVars (tvar : tvars) t TAll tvar t -> collectTVars (tvar:tvars) t
TData _ _ -> return tvars TData _ _ -> pure tvars
-- Should be monad error _ -> throwError ("Bad data type definition: " ++ show typ)
TVar v -> return [v]
_ -> throwError ("Bad data type definition: " ++ show types)
DData (Data types _) -> throwError ("Bad data type definition: " ++ show types)
renameConstr :: [(TVar, TVar)] -> Constructor -> Constructor renameInj :: [(TVar, TVar)] -> Inj -> Inj
renameConstr new_types (Constructor name typ) = renameInj new_types (Inj name typ) =
Constructor name $ substituteTVar new_types typ Inj name $ substituteTVar new_types typ
renameBind :: Names -> Bind -> Rn (Names, Bind)
renameBind old_names (Bind name vars rhs) = do
(new_names, vars') <- newNames old_names (coerce vars)
(newer_names, rhs') <- renameExp new_names rhs
pure (newer_names, Bind name (coerce vars') rhs')
substituteTVar :: [(TVar, TVar)] -> Type -> Type substituteTVar :: [(TVar, TVar)] -> Type -> Type
substituteTVar new_names typ = case typ of substituteTVar new_names typ = case typ of
TLit _ -> typ TLit _ -> typ
TVar tvar
| Just tvar' <- lookup tvar new_names -> TVar tvar | Just tvar' <- lookup tvar new_names
TVar tvar' -> TVar tvar'
| otherwise -> | otherwise
typ -> typ
TFun t1 t2 -> on TFun substitute' t1 t2 TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t
| Just tvar' <- lookup tvar new_names -> TAll tvar t | Just tvar' <- lookup tvar new_names
TAll tvar' $ substitute' t -> TAll tvar' $ substitute' t
| otherwise -> | otherwise
TAll tvar $ substitute' t -> TAll tvar $ substitute' t
TData name typs -> TData name $ map substitute' typs TData name typs -> TData name $ map substitute' typs
_ -> error ("Impossible " ++ show typ) _ -> error ("Impossible " ++ show typ)
where where
substitute' = substituteTVar new_names substitute' = substituteTVar new_names
initCxt :: Cxt
initCxt = Cxt 0 0
data Cxt = Cxt
{ var_counter :: Int
, tvar_counter :: Int
}
-- | Rename monad. State holds the number of renamed names.
newtype Rn a = Rn {runRn :: StateT Cxt (ExceptT String Identity) a}
deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String)
-- | Maps old to new name
type Names = Map LIdent LIdent
renameExp :: Names -> Exp -> Rn (Names, Exp) renameExp :: Names -> Exp -> Rn (Names, Exp)
renameExp old_names = \case renameExp old_names = \case
EVar n -> pure (coerce old_names, EVar . fromMaybe n $ Map.lookup n old_names) EVar (LIdent n) -> pure (old_names, EVar . LIdent . fromMaybe n $ Map.lookup n old_names)
EInj n -> pure (old_names, EInj n) EInj (UIdent n) -> pure (old_names, EInj . UIdent . fromMaybe n $ Map.lookup n old_names)
ELit lit -> pure (old_names, ELit lit) ELit lit -> pure (old_names, ELit lit)
EApp e1 e2 -> do EApp e1 e2 -> do
(env1, e1') <- renameExp old_names e1 (env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2 (env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EApp e1' e2') pure (Map.union env1 env2, EApp e1' e2')
EAdd e1 e2 -> do EAdd e1 e2 -> do
(env1, e1') <- renameExp old_names e1 (env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2 (env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EAdd e1' e2') pure (Map.union env1 env2, EAdd e1' e2')
-- TODO fix shadowing -- TODO fix shadowing
ELet bind e -> do ELet (Bind name vars rhs) e -> do
(new_names, bind') <- renameBind old_names bind (new_names, name') <- newNameL old_names name
(new_names', e') <- renameExp new_names e (new_names', vars') <- newNamesL new_names vars
pure (new_names', ELet bind' e') (new_names'', rhs') <- renameExp new_names' rhs
(new_names''', e') <- renameExp new_names'' e
pure (new_names''', ELet (Bind name' vars' rhs') e')
EAbs par e -> do EAbs par e -> do
(new_names, par') <- newName old_names (coerce par) (new_names, par') <- newNameL old_names par
(new_names', e') <- renameExp new_names e (new_names', e') <- renameExp new_names e
pure (new_names', EAbs (coerce par') e') pure (new_names', EAbs par' e')
EAnn e t -> do EAnn e t -> do
(new_names, e') <- renameExp old_names e (new_names, e') <- renameExp old_names e
t' <- renameTVars t t' <- renameTVars t
@ -137,26 +130,23 @@ renameExp old_names = \case
renameBranches :: Names -> [Branch] -> Rn (Names, [Branch]) renameBranches :: Names -> [Branch] -> Rn (Names, [Branch])
renameBranches ns xs = do renameBranches ns xs = do
(new_names, xs') <- unzip <$> mapM (renameBranch ns) xs (new_names, xs') <- mapAndUnzipM (renameBranch ns) xs
if null new_names then return (mempty, xs') else return (head new_names, xs') if null new_names then return (mempty, xs') else return (head new_names, xs')
renameBranch :: Names -> Branch -> Rn (Names, Branch) renameBranch :: Names -> Branch -> Rn (Names, Branch)
renameBranch ns (Branch init e) = do renameBranch ns (Branch patt e) = do
(new_names, init') <- renamePattern ns init (new_names, patt') <- renamePattern ns patt
(new_names', e') <- renameExp new_names e (new_names', e') <- renameExp new_names e
return (new_names', Branch init' e') return (new_names', Branch patt' e')
renamePattern :: Names -> Pattern -> Rn (Names, Pattern) renamePattern :: Names -> Pattern -> Rn (Names, Pattern)
renamePattern ns i = case i of renamePattern ns p = case p of
PInj cs ps -> do PInj cs ps -> do
(ns_new, ps) <- renamePatterns ns ps (ns_new, ps') <- mapAccumM renamePattern ns ps
return (ns_new, PInj cs ps) return (ns_new, PInj cs ps')
rest -> return (ns, rest) PVar name -> second PVar <$> newNameL ns name
_ -> return (ns, p)
renamePatterns :: Names -> [Pattern] -> Rn (Names, [Pattern])
renamePatterns ns xs = do
(new_names, xs') <- unzip <$> mapM (renamePattern ns) xs
if null new_names then return (mempty, xs') else return (head new_names, xs')
renameTVars :: Type -> Rn Type renameTVars :: Type -> Rn Type
renameTVars typ = case typ of renameTVars typ = case typ of
@ -167,44 +157,57 @@ renameTVars typ = case typ of
TFun t1 t2 -> liftA2 TFun (renameTVars t1) (renameTVars t2) TFun t1 t2 -> liftA2 TFun (renameTVars t1) (renameTVars t2)
_ -> pure typ _ -> pure typ
substitute :: substitute :: TVar -- α
TVar -> -- α -> TVar -- α_n
TVar -> -- α_n -> Type -- A
Type -> -- A -> Type -- [α_n/α]A
Type -- [α_n/α]A
substitute tvar1 tvar2 typ = case typ of substitute tvar1 tvar2 typ = case typ of
TLit _ -> typ TLit _ -> typ
TVar tvar' TVar tvar | tvar == tvar1 -> TVar tvar2
| tvar' == tvar1 -> TVar tvar2
| otherwise -> typ | otherwise -> typ
TFun t1 t2 -> on TFun substitute' t1 t2 TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t -> TAll tvar $ substitute' t TAll tvar t | tvar == tvar1 -> TAll tvar2 $ substitute' t
| otherwise -> TAll tvar $ substitute' t
TData name typs -> TData name $ map substitute' typs TData name typs -> TData name $ map substitute' typs
_ -> error "Impossible" _ -> error "Impossible"
where where
substitute' = substitute tvar1 tvar2 substitute' = substitute tvar1 tvar2
-- | Create a new name and add it to name environment.
newName :: Names -> LIdent -> Rn (Names, LIdent)
newName env old_name = do
new_name <- makeName old_name
pure (Map.insert old_name new_name env, new_name)
-- | Create multiple names and add them to the name environment -- | Create multiple names and add them to the name environment
newNames :: Names -> [LIdent] -> Rn (Names, [LIdent]) newNamesL :: Names -> [LIdent] -> Rn (Names, [LIdent])
newNames = mapAccumM newName newNamesL = mapAccumM newNameL
-- | Generate a fresh name for a lower-case identifier and record the
-- old-to-new mapping in the environment.
newNameL :: Names -> LIdent -> Rn (Names, LIdent)
newNameL env (LIdent old_name) =
    (\fresh -> (Map.insert old_name fresh env, LIdent fresh)) <$> makeName old_name
-- | Generate fresh names for several upper-case identifiers, threading the
-- name environment through 'newNameU'.
newNamesU :: Names -> [UIdent] -> Rn (Names, [UIdent])
newNamesU env idents = mapAccumM newNameU env idents
-- | Generate a fresh name for an upper-case identifier and record the
-- old-to-new mapping in the environment.
newNameU :: Names -> UIdent -> Rn (Names, UIdent)
newNameU env (UIdent old_name) =
    (\fresh -> (Map.insert old_name fresh env, UIdent fresh)) <$> makeName old_name
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@. -- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
makeName :: LIdent -> Rn LIdent makeName :: String -> Rn String
makeName (LIdent prefix) = do makeName prefix = do
i <- gets var_counter i <- gets var_counter
let name = LIdent $ prefix ++ "_" ++ show i let name = prefix ++ "_" ++ show i
modify $ \cxt -> cxt{var_counter = succ cxt.var_counter} modify $ \cxt -> cxt { var_counter = succ cxt.var_counter}
pure name pure name
nextNameTVar :: TVar -> Rn TVar nextNameTVar :: TVar -> Rn TVar
nextNameTVar (MkTVar (LIdent s)) = do nextNameTVar (MkTVar (LIdent s))= do
i <- gets tvar_counter i <- gets tvar_counter
let tvar = MkTVar $ coerce $ s ++ "_" ++ show i let tvar = MkTVar . LIdent $ s ++ "_" ++ show i
modify $ \cxt -> cxt{tvar_counter = succ cxt.tvar_counter} modify $ \cxt -> cxt { tvar_counter = succ cxt.tvar_counter}
pure tvar pure tvar

206
src/Renamer/RenamerOld.hs Normal file
View file

@ -0,0 +1,206 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use mapAndUnzipM" #-}
module Renamer.Renamer (rename) where
import Auxiliary (mapAccumM)
import Control.Applicative (Applicative (liftA2))
import Control.Monad (foldM)
import Control.Monad.Except (ExceptT, MonadError, runExceptT,
throwError)
import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.State (MonadState, StateT, evalStateT, gets,
modify)
import Data.Coerce (coerce)
import Data.Function (on)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe)
import Grammar.Abs
-- | Rename all variables and local binds of a program. Fails with a message
-- when 'renameDefs' fails.
rename :: Program -> Either String Program
rename (Program defs) = fmap Program (renameDefs defs)
-- | Rename every top-level definition, starting from 'initCxt'. Runs the 'Rn'
-- computation and returns either an error message or the renamed definitions.
renameDefs :: [Def] -> Either String [Def]
renameDefs defs = runIdentity $ runExceptT $ evalStateT (runRn $ mapM renameDef defs) initCxt
where
-- Every top-level bind name initially maps to itself.
initNames = Map.fromList [dupe (coerce name) | DBind (Bind name _ _) <- defs]
renameDef :: Def -> Rn Def
renameDef = \case
-- Signatures: only the quantified type variables are renamed.
DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ
-- Binds: renamed against the top-level names; the resulting environment is dropped.
DBind bind -> DBind . snd <$> renameBind initNames bind
-- Data declarations headed by TData: each collected type variable gets a
-- fresh name, which is substituted in the type arguments and in every
-- constructor type.
DData (Data (TData cname types) constrs) -> do
tvars_ <- tvars
tvars' <- mapM nextNameTVar tvars_
let tvars_lt = zip tvars_ tvars'
typ' = map (substituteTVar tvars_lt) types
constrs' = map (renameConstr tvars_lt) constrs
pure . DData $ Data (TData cname typ') constrs'
where
tvars = concat <$> mapM (collectTVars []) types
-- Gather type variables from one type argument; any shape other than
-- TAll, TData or TVar is reported as an error.
collectTVars :: [TVar] -> Type -> Rn [TVar]
collectTVars tvars = \case
TAll tvar t -> collectTVars (tvar : tvars) t
TData _ _ -> return tvars
-- NOTE(review): this returns only @v@ and drops the variables accumulated so far — confirm this is intended.
TVar v -> return [v]
_ -> throwError ("Bad data type definition: " ++ show types)
-- Any data declaration whose head is not a TData is rejected.
DData (Data types _) -> throwError ("Bad data type definition: " ++ show types)
-- | Apply a type-variable renaming to the type of a constructor.
renameConstr :: [(TVar, TVar)] -> Inj -> Inj
renameConstr new_types (Inj name typ) = Inj name renamed
  where
    renamed = substituteTVar new_types typ
-- | Rename the parameters of a binding and then its right-hand side under the
-- extended environment. The binding's own name is left untouched. Returns the
-- environment produced by renaming the right-hand side.
renameBind :: Names -> Bind -> Rn (Names, Bind)
renameBind old_names (Bind name vars rhs) = do
    (with_vars, vars') <- newNames old_names (coerce vars)
    -- The inner fmap rebuilds the binding in the second tuple component.
    fmap (Bind name (coerce vars')) <$> renameExp with_vars rhs
-- | Replace type variables according to an association list of
-- (old, new) pairs. Both occurrences ('TVar') and binders ('TAll') are
-- replaced; variables without an entry are kept. Calls 'error' on type
-- constructors other than TLit, TVar, TFun, TAll and TData.
substituteTVar :: [(TVar, TVar)] -> Type -> Type
substituteTVar new_names = go
  where
    go typ = case typ of
        TLit _          -> typ
        TVar tvar       -> TVar (renamed tvar)
        TFun t1 t2      -> TFun (go t1) (go t2)
        TAll tvar t     -> TAll (renamed tvar) (go t)
        TData name typs -> TData name (map go typs)
        _               -> error ("Impossible " ++ show typ)
    -- The new name of a variable, or the variable itself when it has none.
    renamed tvar = fromMaybe tvar (lookup tvar new_names)
-- | Initial renamer state: both counters start at zero.
initCxt :: Cxt
initCxt = Cxt { var_counter = 0, tvar_counter = 0 }
-- | Renamer state: two counters used as numeric suffixes for fresh names.
data Cxt = Cxt
-- Read and incremented by 'makeName'.
{ var_counter :: Int
-- Read and incremented by 'nextNameTVar'.
, tvar_counter :: Int
}
-- | Rename monad: 'StateT' 'Cxt' layered over 'ExceptT' with 'String' errors.
-- The state holds the counters used to number renamed names.
newtype Rn a = Rn {runRn :: StateT Cxt (ExceptT String Identity) a}
deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String)
-- | Name environment: maps an original identifier to the identifier it was
-- renamed to.
type Names = Map LIdent LIdent
-- | Rename an expression under the given name environment. Returns the
-- renamed expression together with a name environment.
--
-- Variables are replaced by their mapping when one exists. Let bindings and
-- lambdas extend the environment with fresh names for what they bind.
-- Binary nodes rename both operands under the same incoming environment and
-- return the union of the two resulting environments.
renameExp :: Names -> Exp -> Rn (Names, Exp)
renameExp old_names = \case
    EVar n     -> unchanged (EVar (Map.findWithDefault n n old_names))
    EInj n     -> unchanged (EInj n)
    ELit lit   -> unchanged (ELit lit)
    EApp e1 e2 -> binary EApp e1 e2
    EAdd e1 e2 -> binary EAdd e1 e2
    -- TODO fix shadowing
    ELet bind e -> do
        (bind_names, bind') <- renameBind old_names bind
        fmap (ELet bind') <$> renameExp bind_names e
    EAbs par e -> do
        (par_names, par') <- newName old_names (coerce par)
        fmap (EAbs (coerce par')) <$> renameExp par_names e
    EAnn e t -> do
        (new_names, e') <- renameExp old_names e
        t' <- renameTVars t
        pure (new_names, EAnn e' t')
    ECase e injs -> do
        (scrut_names, e') <- renameExp old_names e
        fmap (ECase e') <$> renameBranches scrut_names injs
  where
    -- Leaf expressions leave the environment as it was.
    unchanged e = pure (old_names, e)
    -- Rename left then right operand, both under the incoming environment.
    binary con l r = do
        (env_l, l') <- renameExp old_names l
        (env_r, r') <- renameExp old_names r
        pure (Map.union env_l env_r, con l' r')
-- | Rename every case branch under the same incoming environment. The
-- environment returned is the one produced by the first branch, or the empty
-- environment when there are no branches.
renameBranches :: Names -> [Branch] -> Rn (Names, [Branch])
renameBranches ns xs = do
    (envs, xs') <- unzip <$> mapM (renameBranch ns) xs
    pure $ case envs of
        []      -> (mempty, xs')
        env : _ -> (env, xs')
-- | Rename one case branch: first the pattern, then the branch expression
-- under the environment the pattern produced.
renameBranch :: Names -> Branch -> Rn (Names, Branch)
renameBranch ns (Branch patt e) = do
    (patt_names, patt') <- renamePattern ns patt
    fmap (Branch patt') <$> renameExp patt_names e
-- | Rename a pattern. Constructor patterns have their sub-patterns renamed
-- via 'renamePatterns'; every other pattern is returned unchanged along with
-- the incoming environment.
renamePattern :: Names -> Pattern -> Rn (Names, Pattern)
renamePattern ns = \case
    PInj cs ps -> do
        (ns_new, ps') <- renamePatterns ns ps
        pure (ns_new, PInj cs ps')
    other -> pure (ns, other)
-- | Rename a list of patterns from left to right, threading the name
-- environment from each pattern to the next.
--
-- An empty list returns the incoming environment unchanged. Returning the
-- empty environment there (as a head-of-results scheme does) would discard
-- every name already in scope for the branch body, and taking only the first
-- result would ignore names introduced by later sibling patterns.
renamePatterns :: Names -> [Pattern] -> Rn (Names, [Pattern])
renamePatterns = mapAccumM renamePattern
-- | Give every @forall@-bound type variable a fresh name. The fresh name is
-- substituted into the quantified body before the body itself is processed.
-- Function types are traversed on both sides; all other types are returned
-- as they are.
renameTVars :: Type -> Rn Type
renameTVars = \case
    TAll tvar t -> do
        tvar' <- nextNameTVar tvar
        TAll tvar' <$> renameTVars (substitute tvar tvar' t)
    TFun t1 t2 -> TFun <$> renameTVars t1 <*> renameTVars t2
    typ -> pure typ
-- | @substitute α α_n A@ computes @[α_n/α]A@: every occurrence of the type
-- variable @α@ in @A@ is replaced by @α_n@. Binders of nested @forall@s are
-- left as they are; only their bodies are traversed. Calls 'error' on type
-- constructors other than TLit, TVar, TFun, TAll and TData.
substitute ::
    TVar -> -- α
    TVar -> -- α_n
    Type -> -- A
    Type -- [α_n/α]A
substitute tvar1 tvar2 = go
  where
    go typ = case typ of
        TLit _          -> typ
        TVar tvar'      -> if tvar' == tvar1 then TVar tvar2 else typ
        TFun t1 t2      -> TFun (go t1) (go t2)
        TAll tvar t     -> TAll tvar (go t)
        TData name typs -> TData name (go <$> typs)
        _               -> error "Impossible"
-- | Generate a fresh name for an identifier and record the old-to-new
-- mapping in the environment.
newName :: Names -> LIdent -> Rn (Names, LIdent)
newName env old_name =
    (\fresh -> (Map.insert old_name fresh env, fresh)) <$> makeName old_name
-- | Generate fresh names for several identifiers, threading the name
-- environment through 'newName'.
newNames :: Names -> [LIdent] -> Rn (Names, [LIdent])
newNames env idents = mapAccumM newName env idents
-- | Suffix a name with the current variable counter and advance the counter:
-- @prefix ⇒ prefix_number@.
makeName :: LIdent -> Rn LIdent
makeName (LIdent prefix) = do
    i <- gets var_counter
    modify $ \cxt -> cxt{var_counter = succ (var_counter cxt)}
    pure . LIdent $ prefix ++ "_" ++ show i
-- | Suffix a type variable's name with the current type-variable counter and
-- advance that counter.
nextNameTVar :: TVar -> Rn TVar
nextNameTVar (MkTVar (LIdent s)) = do
    i <- gets tvar_counter
    modify $ \cxt -> cxt{tvar_counter = succ (tvar_counter cxt)}
    pure (MkTVar (LIdent (s ++ "_" ++ show i)))

View file

@ -0,0 +1,73 @@
{-# LANGUAGE LambdaCase #-}
module TypeChecker.RemoveTEVar where
import Control.Applicative (Applicative (liftA2), liftA3)
import Control.Arrow (Arrow (second))
import Control.Monad.Error (MonadError (throwError))
import Data.Coerce (coerce)
import Data.Function (on)
import Data.Tuple.Extra (secondM)
import Grammar.Abs
import Grammar.ErrM (Err)
import qualified TypeChecker.TypeCheckerIr as T
-- | Conversion from a structure @a@ (annotated with the surface 'Type',
-- which can contain existential variables) to its counterpart @b@
-- (annotated with 'T.Type', which has no existential constructor).
class RemoveTEVar a b where
    -- | Convert the structure; the 'Type' instance fails when it meets a
    -- 'TEVar'.
    rmTEVar :: a -> Err b
-- | Convert every top-level definition of the program.
instance RemoveTEVar (T.Program' Type) (T.Program' T.Type) where
    rmTEVar (T.Program defs) = do
        defs' <- rmTEVar defs
        pure (T.Program defs')
-- | Convert a top-level definition, keeping its constructor.
instance RemoveTEVar (T.Def' Type) (T.Def' T.Type) where
    rmTEVar (T.DBind bind) = T.DBind <$> rmTEVar bind
    rmTEVar (T.DData dat) = T.DData <$> rmTEVar dat
-- | Convert the bound identifier, the parameters and the right-hand side,
-- in that order.
instance RemoveTEVar (T.Bind' Type) (T.Bind' T.Type) where
    rmTEVar (T.Bind name params body) =
        T.Bind <$> rmTEVar name <*> rmTEVar params <*> rmTEVar body
-- | Convert an expression by converting every typed sub-expression.
-- Leaves ('T.EVar', 'T.EInj', 'T.ELit') carry no type and are rebuilt
-- as they are.
instance RemoveTEVar (T.Exp' Type) (T.Exp' T.Type) where
    rmTEVar exp = case exp of
        T.EVar name -> pure $ T.EVar name
        T.EInj name -> pure $ T.EInj name
        T.ELit lit -> pure $ T.ELit lit
        T.ELet bind e -> liftA2 T.ELet (rmTEVar bind) (rmTEVar e)
        T.EApp e1 e2 -> liftA2 T.EApp (rmTEVar e1) (rmTEVar e2)
        -- An addition must stay an addition (it was rebuilt as 'T.EApp').
        T.EAdd e1 e2 -> liftA2 T.EAdd (rmTEVar e1) (rmTEVar e2)
        T.EAbs name e -> T.EAbs name <$> rmTEVar e
        T.ECase e branches -> liftA2 T.ECase (rmTEVar e) (rmTEVar branches)
-- | Convert the pattern, the pattern's type and the branch body, in that
-- order.
instance RemoveTEVar (T.Branch' Type) (T.Branch' T.Type) where
    rmTEVar (T.Branch (patt, t_patt) e) = do
        patt' <- rmTEVar patt
        t_patt' <- rmTEVar t_patt
        T.Branch (patt', t_patt') <$> rmTEVar e
-- | Convert the types stored inside a pattern; patterns without a type
-- annotation are rebuilt unchanged.
instance RemoveTEVar (T.Pattern' Type) (T.Pattern' T.Type) where
    rmTEVar patt = case patt of
        T.PCatch -> pure T.PCatch
        T.PEnum name -> pure (T.PEnum name)
        T.PVar (name, t) -> do
            t' <- rmTEVar t
            pure (T.PVar (name, t'))
        T.PLit (lit, t) -> do
            t' <- rmTEVar t
            pure (T.PLit (lit, t'))
        T.PInj name ps -> T.PInj name <$> rmTEVar ps
-- | Convert the declared type and then every constructor.
instance RemoveTEVar (T.Data' Type) (T.Data' T.Type) where
    rmTEVar (T.Data typ injs) = T.Data <$> rmTEVar typ <*> rmTEVar injs
-- | Convert the type of a data constructor, keeping its name.
instance RemoveTEVar (T.Inj' Type) (T.Inj' T.Type) where
    rmTEVar (T.Inj name typ) = do
        typ' <- rmTEVar typ
        pure (T.Inj name typ')
-- | Convert the type component of a typed identifier.
instance RemoveTEVar (T.Id' Type) (T.Id' T.Type) where
    rmTEVar (name, typ) = (name,) <$> rmTEVar typ
-- | Convert a typed expression: first the expression, then its type.
instance RemoveTEVar (T.ExpT' Type) (T.ExpT' T.Type) where
    rmTEVar (exp, typ) = (,) <$> rmTEVar exp <*> rmTEVar typ
-- | Convert a list element by element, stopping at the first failure.
instance RemoveTEVar a b => RemoveTEVar [a] [b] where
    rmTEVar xs = mapM rmTEVar xs
-- | Translate a surface 'Type' into 'T.Type'. An existential variable
-- has no counterpart in the target type, so meeting one is an error.
instance RemoveTEVar Type T.Type where
    rmTEVar typ = case typ of
        TEVar _ -> throwError "NewType TEVar!"
        TLit lit -> pure (T.TLit (coerce lit))
        TVar tvar -> pure (T.TVar tvar)
        TFun t1 t2 -> T.TFun <$> rmTEVar t1 <*> rmTEVar t2
        TAll tvar t -> T.TAll tvar <$> rmTEVar t
        TData name typs -> T.TData (coerce name) <$> rmTEVar typs

View file

@ -0,0 +1,858 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# OPTIONS_GHC -Wno-incomplete-patterns #-}
module TypeChecker.TypeCheckerBidir (typecheck, getVars) where
import Auxiliary (maybeToRightM, snoc)
import Control.Applicative (Alternative, Applicative (liftA2),
(<|>))
import Control.Monad.Except (ExceptT, MonadError (throwError),
runExceptT, unless, zipWithM,
zipWithM_)
import Control.Monad.State (MonadState (get, put), State,
evalState, gets, modify)
import Data.Coerce (coerce)
import Data.Function (on)
import Data.List (intercalate)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe, isNothing)
import Data.Sequence (Seq (..))
import qualified Data.Sequence as S
import Data.Tuple.Extra (second, secondM)
import Debug.Trace (trace)
import Grammar.Abs
import Grammar.ErrM
import Grammar.Print (printTree)
import Prelude hiding (exp, id)
import qualified TypeChecker.TypeCheckerIr as T
-- Implementation is derived from the paper (Dunfield and Krishnaswami 2013)
-- https://doi.org/10.1145/2500365.2500582
-- | One entry of the ordered typing context Γ.
data EnvElem = EnvVar         LIdent Type -- ^ Term variable typing. x : A
             | EnvTVar        TVar        -- ^ Universal type variable. α
             | EnvTEVar       TEVar       -- ^ Existential unsolved type variable. ά
             | EnvTEVarSolved TEVar Type  -- ^ Existential solved type variable. ά = τ
             | EnvMark        TEVar       -- ^ Scoping Marker. ▶ ά
               deriving (Eq, Show)

-- | The ordered context as a sequence; 'insertEnv' appends new entries on
-- the right.
type Env = Seq EnvElem
-- | Type checker state: the ordered context together with the top-level
-- tables and the fresh-name counter.
--
-- Ordered context
-- Γ ::= ・| Γ, α  | Γ, ά | Γ, ▶ ά | Γ, x:A
data Cxt = Cxt
    { env        :: Env               -- ^ Local scope context  Γ
    , sig        :: Map LIdent Type   -- ^ Top-level signatures x : A
    , binds      :: Map LIdent Exp    -- ^ Top-level binds      x = e
    , next_tevar :: Int               -- ^ Counter to distinguish ά
    , data_injs  :: Map UIdent Type   -- ^ Data injections (constructors) K
    } deriving (Show, Eq)
-- | The type checker monad: state 'Cxt' with 'String' errors layered on
-- top, so the state is unwound by 'evalState' after errors are collected
-- by 'runExceptT'.
newtype Tc a = Tc { runTc :: ExceptT String (State Cxt) a }
    deriving (Functor, Applicative, Monad, Alternative, MonadState Cxt, MonadError String)
-- | Type check a whole program.
--
-- Data declarations are validated first; their constructors, together
-- with all signatures and bindings, seed the initial state in which every
-- binding is then checked. The result lists the data declarations
-- followed by the typed bindings.
typecheck :: Program -> Err (T.Program' Type)
typecheck (Program defs) = do
    datatypes <- mapM typecheckDataType [ d | DData d <- defs ]
    typed_binds <- evalState
        (runExceptT . runTc $ mapM typecheckBind bind_defs)
        (initialCxt datatypes)
    pure $ T.Program (map (T.DData . toIrData) datatypes ++ map T.DBind typed_binds)
  where
    bind_defs = [ b | DBind b <- defs ]

    -- State before any binding has been checked.
    initialCxt datatypes = Cxt
        { env        = mempty
        , sig        = Map.fromList [ (name, t) | DSig' name t <- defs ]
        , binds      = Map.fromList
            [ (name, foldr EAbs rhs vars) | DBind' name vars rhs <- defs ]
        , next_tevar = 0
        , data_injs  = Map.fromList
            [ (name, typ) | Data _ injs <- datatypes, Inj name typ <- injs ]
        }

    -- Re-wrap a checked data declaration in the IR constructors.
    toIrData (Data t injs) =
        T.Data t [ T.Inj (coerce name) typ | Inj name typ <- injs ]
-- | Type check one top-level binding.
--
-- With a signature the binding (as nested lambdas) is checked against it;
-- without one its type is inferred and the context is applied to the
-- result. Afterwards the context must contain no unsolved existential,
-- otherwise an error is raised; the local context is then reset to empty.
typecheckBind :: Bind -> Tc (T.Bind' Type)
typecheckBind (Bind name vars rhs) = do
    bind' <- lookupSig name >>= \case
        -- TODO These Judgment aren't accurate
        -- (f:A → B) ∈ Γ
        -- Γ,(xs:A) ⊢ e ↑ Β ⊣ Δ
        ---------------------------
        -- Γ ⊢ f xs = e ↓ Α → B ⊣ Δ
        -- NOTE(review): unlike the inferred case below, the checked
        -- right-hand side here keeps its lambdas and is not passed through
        -- 'applyEnvExp' — confirm this asymmetry is intended.
        Just t  -> do
            (rhs', _) <- check (foldr EAbs rhs vars) t
            pure (T.Bind (coerce name, t) (coerce vars') (rhs', t))
          where
            vars' = zip vars $ getVars t

        --     Γ ⊢ (λxs. e) ↓ A → B ⊣ Δ
        -- ------------------------------
        -- Γ ⊢ f xs = e ↓ [Γ]A → [Γ]B ⊣ Δ
        Nothing -> do
            (e, t) <- infer $ foldr EAbs rhs vars
            t'     <- applyEnv t
            e'     <- applyEnvExp e
            -- The parameters are listed in the bind, so the lambdas that
            -- were wrapped around the body for inference are peeled again.
            let rhs'  = skipLambdas (length vars) e'
                vars' = zip vars $ getVars t'
            pure (T.Bind (coerce name, t') (coerce vars') (rhs', t'))
    -- Any existential left unsolved means the type could not be determined.
    env <- gets env
    unless (isComplete env) err
    putEnv Empty
    pure bind'
  where
    err = throwError $ unlines
              [ "Type inference failed: " ++ printTree (Bind name vars rhs)
              , "Did you forget to add type annotation to a polymorphic function?"
              ]
-- | Validate a data declaration.
--
-- The declared type must be a (possibly @forall@-quantified) data type
-- whose arguments are all type variables bound by those quantifiers. Each
-- constructor is then validated with 'typecheckInj'.
typecheckDataType :: Data -> Err Data
typecheckDataType (Data typ injs) = do
    (name, tvars) <- header [] typ
    Data typ <$> mapM (\inj -> typecheckInj inj name tvars) injs
  where
    -- Strip the quantifiers, collecting the variables they bind, and
    -- return the data type's name with its parameter variables.
    header bound = \case
        TAll tvar body -> header (tvar : bound) body
        TData name args
            | Right params <- mapM toTVar args
            , all (`elem` bound) params
            -> pure (name, params)
        _ -> throwError $ unwords ["Bad data type definition: ", ppT typ]
-- | Validate one constructor of the data type @name@ with parameters
-- @tvars@.
--
-- Every type variable in the constructor's type must be bound (by the
-- data type's parameters or by a @forall@ inside the constructor type),
-- and the constructor must return exactly @name (tvars)@. On success the
-- constructor's type is closed over the data type's parameters.
typecheckInj :: Inj -> UIdent -> [TVar] -> Err Inj
typecheckInj (Inj inj_name inj_typ) name tvars
    | not $ boundTVars tvars inj_typ
    = throwError "Unbound type variables"
    | TData name' typs <- getReturn inj_typ
    , name'            == name
    , Right tvars'     <- mapM toTVar typs
    , tvars'           == tvars
    = pure (Inj inj_name $ foldr TAll inj_typ tvars)
    | otherwise
    = throwError $ unwords
        ["Bad type constructor: ", show name
        , "\nExpected: ", ppT . TData name $ map TVar tvars
        , "\nActual: ", ppT $ getReturn inj_typ
        ]
  where
    -- Are all type variables of the type bound by @tvars'@ or by an
    -- enclosing quantifier?
    boundTVars :: [TVar] -> Type -> Bool
    boundTVars tvars' = \case
        TAll tvar t  -> boundTVars (tvar:tvars') t
        TFun t1 t2   -> on (&&) (boundTVars tvars') t1 t2
        TVar tvar    -> elem tvar tvars'
        -- Keep the variables accumulated so far (not the outer @tvars@),
        -- so arguments may mention variables bound by an inner forall.
        TData _ typs -> all (boundTVars tvars') typs
        TLit _       -> True
        TEVar _      -> error "TEVar in data type declaration"
---------------------------------------------------------------------------
-- * Subtyping rules
---------------------------------------------------------------------------
-- | Γ ⊢ A <: B ⊣ Δ
-- Under input context Γ, type A is a subtype of B, with output context ∆
--
-- Solves existential variables in the context as a side effect and
-- throws an error when the two types cannot be related.
subtype :: Type -> Type -> Tc ()
subtype t1 t2 = case (t1, t2) of

    (TLit lit1, TLit lit2) | lit1 == lit2 -> pure ()

    -- -------------------- <:Var
    -- Γ[α] ⊢ α <: α ⊣ Γ[α]
    (TVar tvar1, TVar tvar2) | tvar1 == tvar2 -> pure ()

    -- -------------------- <:Exvar
    -- Γ[ά] ⊢ ά <: ά ⊣ Γ[ά]
    (TEVar tevar1, TEVar tevar2) | tevar1 == tevar2 -> pure ()

    -- Γ ⊢ B₁ <: A₁ ⊣ Θ   Θ ⊢ [Θ]A₂ <: [Θ]B₂ ⊣ Δ
    -- ----------------------------------------- <:→
    -- Γ ⊢ A₁ → A₂ <: B₁ → B₂ ⊣ Δ
    (TFun a1 a2, TFun b1 b2) -> do
        subtype b1 a1
        a2' <- applyEnv a2
        b2' <- applyEnv b2
        subtype a2' b2'

    -- Γ, α ⊢ A <: B ⊣ Δ,α
    -- --------------------- <:∀R
    -- Γ ⊢ A <: ∀α. B ⊣ Δ
    (a, TAll tvar b) -> do
        let env_tvar = EnvTVar tvar
        insertEnv env_tvar
        subtype a b
        dropTrailing env_tvar

    -- Γ,▶ ά,ά ⊢ [ά/α]A <: B ⊣ Δ,▶ ά,Θ
    -- ------------------------------- <:∀L
    -- Γ ⊢ ∀α.A <: B ⊣ Δ
    (TAll tvar a, b) -> do
        tevar <- fresh
        let env_marker = EnvMark tevar
            env_tevar  = EnvTEVar tevar
        insertEnv env_marker
        insertEnv env_tevar
        let a' = substitute tvar tevar a
        subtype a' b
        dropTrailing env_marker

    -- ά ∉ FV(A)   Γ[ά] ⊢ ά :=< A ⊣ Δ
    -- ------------------------------ <:instantiateL
    -- Γ[ά] ⊢ ά <: A ⊣ Δ
    (TEVar tevar, typ) | notElem tevar $ frees typ -> instantiateL tevar typ

    -- ά ∉ FV(A)   Γ[ά] ⊢ A =:< ά ⊣ Δ
    -- ------------------------------ <:instantiateR
    -- Γ[ά] ⊢ A <: ά ⊣ Δ
    (typ, TEVar tevar) | notElem tevar $ frees typ -> instantiateR typ tevar

    (TData name1 typs1, TData name2 typs2)

        --       D₁ = D₂
        -- ----------------
        -- Γ ⊢ D₁ () <: D₂ ()
        | name1 == name2
        , [] <- typs1
        , [] <- typs2
        -> pure ()

        -- Γ ⊢ ά₁ <: έ₁ ⊣ Θ₁
        -- ...
        -- D₁ = D₂   Θₙ₋₁ ⊢ [Θₙ₋₁]άₙ <: [Θₙ₋₁]έₙ ⊣ Δ
        -- -------------------------------------------
        -- Γ ⊢ D (ά₁ ‥ άₙ) <: D (έ₁ ‥ έₙ) ⊣ Δ
        | name1 == name2
        , t1':t1s <- typs1
        , t2':t2s <- typs2
        -- 'zipWithM_' stops at the shorter list, so without this guard a
        -- different number of arguments would be accepted silently.
        , length t1s == length t2s
        -> do
            subtype t1' t2'
            zipWithM_ go t1s t2s
      where
        -- Later arguments are compared under the solutions found so far.
        go a b = do
            a' <- applyEnv a
            b' <- applyEnv b
            subtype a' b'

    _ -> throwError $ unwords ["Types", ppT t1, "and", ppT t2, "doesn't match!"]
---------------------------------------------------------------------------
-- * Instantiation rules
---------------------------------------------------------------------------
-- | Γ ⊢ ά :=< A ⊣ Δ
-- Under input context Γ, instantiate ά such that ά <: A, with output context ∆
--
-- Throws an error when no instantiation rule applies.
instantiateL :: TEVar -> Type -> Tc ()
instantiateL tevar typ = gets env >>= go
  where
    go env

        --           Γ ⊢ τ
        -- ----------------------------- InstLSolve
        -- Γ,ά,Γ' ⊢ ά :=< τ ⊣ Γ,(ά=τ),Γ'
        | isMono typ
        , (env_l, env_r) <- splitOn (EnvTEVar tevar) env
        , Right _ <- wellFormed env_l typ
        = putEnv $ (env_l :|> EnvTEVarSolved tevar typ) <> env_r

        | TEVar tevar' <- typ = instReach tevar tevar'

        -- Γ[ά₂ά₁,(ά=ά₁→ά₂)] ⊢ A₁ =:< ά₁ ⊣ Θ   Θ ⊢ ά₂ :=< [Θ]A₂ ⊣ Δ
        -- ------------------------------------------------------- InstLArr
        --             Γ[ά] ⊢ ά :=< A₁ → A₂ ⊣ Δ
        | TFun t1 t2 <- typ = do
            tevar1 <- fresh
            tevar2 <- fresh
            -- Replace ά in place by ά₂,ά₁,(ά=ά₁→ά₂). Appending at the end
            -- would leave the unsolved ά in the context. ('fresh' does not
            -- touch the context, so @env@ is still current here.)
            let (env_l, env_r) = splitOn (EnvTEVar tevar) env
                t_fun          = on TFun TEVar tevar1 tevar2
            putEnv $
                (env_l :|> EnvTEVar tevar2
                       :|> EnvTEVar tevar1
                       :|> EnvTEVarSolved tevar t_fun) <> env_r
            instantiateR t1 tevar1
            instantiateL tevar2 =<< applyEnv t2

        -- Γ[ά],ε ⊢ ά :=< E ⊣ Δ,ε,Δ'
        -- ------------------------- InstLAIIR
        -- Γ[ά] ⊢ ά :=< ∀ε.Ε ⊣ Δ
        | TAll tvar t <- typ = do
            -- Bring ε into scope, instantiate, then drop ε and everything
            -- after it. Solutions recorded before ε are kept.
            let env_tvar = EnvTVar tvar
            insertEnv env_tvar
            instantiateL tevar t
            dropTrailing env_tvar

        | otherwise = throwError $ "Trying to instantiateL: " ++ ppT (TEVar tevar)
                                     ++ " <: " ++ ppT typ
-- | Γ ⊢ A =:< ά ⊣ Δ
-- Under input context Γ, instantiate ά such that A <: ά, with output context ∆
--
-- Throws an error when no instantiation rule applies.
instantiateR :: Type -> TEVar -> Tc ()
instantiateR typ tevar = gets env >>= go
  where
    go env

        --           Γ ⊢ τ
        -- ----------------------------- InstRSolve
        -- Γ,ά,Γ' ⊢ τ =:< ά ⊣ Γ,(ά=τ),Γ'
        | isMono typ
        , (env_l, env_r) <- splitOn (EnvTEVar tevar) env
        , Right _ <- wellFormed env_l typ
        = putEnv $ (env_l :|> EnvTEVarSolved tevar typ) <> env_r

        | TEVar tevar' <- typ = instReach tevar tevar'

        -- Γ[ά₂ά₁,(ά=ά₁→ά₂)] ⊢ A₁ :=< ά₁ ⊣ Θ   Θ ⊢ ά₂ =:< [Θ]A₂ ⊣ Δ
        -- ------------------------------------------------------- InstRArr
        --              Γ[ά] ⊢ ά =:< A₁ → A₂ ⊣ Δ
        | TFun t1 t2 <- typ = do
            tevar1 <- fresh
            tevar2 <- fresh
            -- Replace ά in place by ά₂,ά₁,(ά=ά₁→ά₂). Appending at the end
            -- would leave the unsolved ά in the context. ('fresh' does not
            -- touch the context, so @env@ is still current here.)
            let (env_l, env_r) = splitOn (EnvTEVar tevar) env
                t_fun          = on TFun TEVar tevar1 tevar2
            putEnv $
                (env_l :|> EnvTEVar tevar2
                       :|> EnvTEVar tevar1
                       :|> EnvTEVarSolved tevar t_fun) <> env_r
            instantiateL tevar1 t1
            t2' <- applyEnv t2
            instantiateR t2' tevar2

        -- Γ[ά],▶έ,έ ⊢ [έ/ε]E =:< ά ⊣ Δ,▶έ,Δ'
        -- ---------------------------------- InstRAIIL
        -- Γ[ά] ⊢ ∀ε.Ε =:< ά ⊣ Δ
        | TAll tvar t <- typ = do
            tevar' <- fresh
            -- The bound variable is replaced by the fresh existential έ,
            -- so it is έ (behind its marker) that must be in scope. The
            -- marker and everything after it are dropped afterwards.
            let env_marker = EnvMark tevar'
            insertEnv env_marker
            insertEnv $ EnvTEVar tevar'
            let t' = substitute tvar tevar' t
            instantiateR t' tevar
            dropTrailing env_marker

        | otherwise = throwError $ "Trying to instantiateR: " ++ ppT typ ++ " <: "
                                     ++ ppT (TEVar tevar)
-- | Solve the second existential to the first one, in place.
--
-- ----------------------------- InstLReach
-- Γ[ά][έ] ⊢ ά :=< έ ⊣ Γ[ά][έ=ά]
--
-- ----------------------------- InstRReach
-- Γ[ά][έ] ⊢ έ =:< ά ⊣ Γ[ά][έ=ά]
instReach :: TEVar -> TEVar -> Tc ()
instReach tevar tevar' = modifyEnv solveInPlace
  where
    solveInPlace gamma =
        let (before, after) = splitOn (EnvTEVar tevar') gamma
        in  (before :|> EnvTEVarSolved tevar' (TEVar tevar)) <> after
---------------------------------------------------------------------------
-- * Typing rules
---------------------------------------------------------------------------
-- | Γ ⊢ e ↑ A ⊣ Δ
-- Under input context Γ, e checks against input type A, with output context ∆
--
-- The rules are tried in order: ∀I, →I, case, and finally subsumption.
check :: Exp -> Type -> Tc (T.ExpT' Type)
check exp typ = case (exp, typ) of

    --  Γ,α ⊢ e ↑ A ⊣ Δ,α
    --  ------------------- ∀I
    --  Γ ⊢ e ↑ ∀α.A ⊣ Δ
    (_, TAll tvar body) -> do
        let scope = EnvTVar tvar
        insertEnv scope
        checked <- check exp body
        dropTrailing scope
        pure checked

    --  Γ,(x:A) ⊢ e ↑ B ⊢ Δ,(x:A),Θ
    --  --------------------------- →I
    --  Γ ⊢ λx.e ↑ A → B ⊣ Δ
    (EAbs name e, TFun t_arg t_res) -> do
        let binding = EnvVar name t_arg
        insertEnv binding
        e' <- check e t_res
        dropTrailing binding
        pure (T.EAbs (coerce name) e', typ)

    --                  Θ ⊢ Π ∷ [Θ]A ↑ [Θ]C ⊣ Δ
    --  Γ ⊢ e ↓ A ⊣ Θ   Δ ⊢ Π covers [Δ]A TODO
    --  ---------------------------------------
    --  Γ ⊢ case e of Π ↑ C ⊣ Δ
    (ECase scrut branches, _) -> do
        (scrut', t_scrut) <- infer scrut
        t_scrut' <- applyEnv t_scrut
        typ' <- applyEnv typ
        branches' <- mapM (\branch -> checkBranch branch t_scrut' typ') branches
        pure (T.ECase (scrut', t_scrut') branches', typ')

    --  Γ,α ⊢ e ↓ A ⊣ Θ   Θ ⊢ [Θ]A <: [Θ]B ⊣ Δ
    --  -------------------------------------- Sub
    --  Γ ⊢ e ↑ B ⊣ Δ
    _ -> do
        (inferred, t_inferred) <- infer exp
        inferred' <- applyEnvExp inferred
        t_inferred' <- applyEnv t_inferred
        t_expected <- applyEnv typ
        subtype t_inferred' t_expected
        pure (inferred', t_inferred')
-- | Γ ⊢ e ↓ A ⊣ Δ
-- Under input context Γ, e infers output type A, with output context ∆
--
-- Returns the typed expression together with its inferred type.
infer :: Exp -> Tc (T.ExpT' Type)
infer = \case

    ELit lit -> pure (T.ELit lit, inferLit lit)

    --  (x : A) ∈ Γ
    --  ------------- Var
    --  Γ ⊢ x ↓ A ⊣ Γ
    EVar name -> do
        -- Local context first, then signatures; as a last resort infer
        -- the type from the top-level definition of the name.
        t <- liftA2 (<|>) (lookupEnv name) (lookupSig name) >>= \case
            Just t  -> pure t
            Nothing -> do
                e <- maybeToRightM
                        ("Unbound variable " ++ show name)
                        =<< lookupBind name
                snd <$> infer e
        pure (T.EVar (coerce name), t)

    EInj name -> do
        t <- maybeToRightM ("Unknown constructor: " ++ show name) =<< lookupInj name
        pure (T.EInj $ coerce name, t)

    --  Γ ⊢ A   Γ ⊢ e ↑ A ⊣ Δ
    --  --------------------- Anno
    --  Γ ⊢ (e : A) ↓ A ⊣ Δ
    EAnn e t -> do
        -- Enforce the premise Γ ⊢ A: the result of 'wellFormed' has to be
        -- lifted into 'Tc', otherwise it is computed and thrown away.
        cxt_env <- gets env
        either throwError pure (wellFormed cxt_env t)
        (e', _) <- check e t
        pure (e', t)

    --  Γ ⊢ e₁ ↓ A ⊣ Θ   Γ ⊢ [Θ]A • ⇓ C ⊣ Δ
    --  ----------------------------------- →E
    --  Γ ⊢ e₁ e₂ ↓ C ⊣ Δ
    EApp e1 e2 -> do
        (e1', t) <- infer e1
        t' <- applyEnv t
        e1'' <- applyEnvExp e1'
        (e2', t'') <- apply t' e2
        pure (T.EApp (e1'', t) e2', t'')

    --  Γ,ά,έ,(x:ά) ⊢ e ↑ έ ⊣ Δ,(x:ά),Θ
    --  ------------------------------- →I
    --  Γ ⊢ λx.e ↓ ά → έ ⊣ Δ
    EAbs name e -> do
        tevar1 <- fresh
        tevar2 <- fresh
        insertEnv $ EnvTEVar tevar1
        insertEnv $ EnvTEVar tevar2
        let env_id = EnvVar name (TEVar tevar1)
        insertEnv env_id
        e' <- check e $ TEVar tevar2
        dropTrailing env_id
        let t_exp = on TFun TEVar tevar1 tevar2
        pure (T.EAbs (coerce name) e', t_exp)

    --  Γ ⊢ e ↓ A ⊣ Θ   Θ,(x:A) ⊢ e' ↑ C ⊣ Δ,(x:A),Θ
    --  -------------------------------------------- LetI
    --  Γ ⊢ let x=e in e' ↑ C ⊣ Δ
    ELet (Bind name [] rhs) e -> do -- TODO vars
        (rhs', t_rhs) <- infer rhs
        let env_id = EnvVar name t_rhs
        insertEnv env_id
        (e', t) <- infer e
        (env_l, _) <- gets (splitOn env_id . env)
        putEnv env_l
        pure (T.ELet (T.Bind (coerce name, t_rhs) [] (rhs', t_rhs)) (e',t), t)

    --  Γ ⊢ e₁ ↑ Int ⊣ Θ   Θ ⊢ e₂ ↑ Int ⊣ Δ
    --  ------------------------------------ +I
    --  Γ ⊢ e₁ + e₂ ↓ Int ⊣ Δ
    EAdd e1 e2 -> do
        -- The context is threaded from the first operand to the second.
        -- Restoring the state in between would forget existentials solved
        -- by e₁ and rewind the counter used by 'fresh'.
        let t = TLit "Int"
        e1' <- check e1 t
        e2' <- check e2 t
        pure (T.EAdd e1' e2', t)
-- | Γ ⊢ A • e ⇓ C ⊣ Δ
-- Under input context Γ , applying a function of type A to e infers type C, with output context ∆
-- Instantiate existential type variables until there is an arrow type.
--
-- Returns the checked argument together with the result type.
apply :: Type -> Exp -> Tc (T.ExpT' Type, Type)
apply typ exp = case typ of

    --  Γ ⊢ e ↑ A ⊣ Δ
    --  --------------------- →App
    --  Γ ⊢ A → C • e ⇓ C ⊣ Δ
    TFun t_arg t_res -> do
        arg <- check exp t_arg
        pure (arg, t_res)

    --  Γ,ά ⊢ [ά/α]A • e ⇓ C ⊣ Δ
    --  ------------------------ ∀App
    --  Γ ⊢ ∀α.A • e ⇓ C ⊣ Δ
    TAll tvar body -> do
        tevar <- fresh
        insertEnv (EnvTEVar tevar)
        apply (substitute tvar tevar body) exp

    --  Γ[ά₂,ά₁,(ά=ά₁→ά₂)] ⊢ e ↑ ά₁ ⊣ Δ
    --  ------------------------------- άApp
    --  Γ[ά] ⊢ ά • e ⇓ ά₂ ⊣ Δ
    TEVar tevar -> do
        t_arg <- fresh
        t_res <- fresh
        -- Replace ά in place by the two new existentials and its solution.
        modifyEnv $ \gamma ->
            let (before, after) = splitOn (EnvTEVar tevar) gamma
                solved = EnvTEVarSolved tevar (TFun (TEVar t_arg) (TEVar t_res))
            in  (before :|> EnvTEVar t_res :|> EnvTEVar t_arg :|> solved) <> after
        arg <- check exp (TEVar t_arg)
        pure (arg, TEVar t_res)

    _ -> throwError ("Cannot apply type " ++ show typ ++ " with expression " ++ show exp)
---------------------------------------------------------------------------
-- * Pattern matching
---------------------------------------------------------------------------
-- | Γ ⊢ p ⇒ e ∷ A ↑ C
-- Under context Γ, check branch p ⇒ e of type A and bodies of type C
--
-- A fresh marker delimits the branch: everything the pattern and body add
-- to the context after it is removed again before returning.
checkBranch :: Branch -> Type -> Type -> Tc (T.Branch' Type)
checkBranch (Branch patt exp) t_patt t_exp = do
    scope <- EnvMark <$> fresh
    insertEnv scope
    patt' <- checkPattern patt t_patt
    t_body <- applyEnv t_exp
    body <- check exp t_body
    dropTrailing scope
    pure (T.Branch patt' body)
-- | Check a pattern against the type of the scrutinee, adding the
-- variables it binds to the context.
--
-- Several cases still return the placeholder type 'dummy' instead of the
-- pattern's real type (marked TODO).
checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type)
checkPattern patt t_patt = case patt of
    -- A variable matches anything and is bound at the scrutinee's type.
    PVar x -> do
        insertEnv $ EnvVar x t_patt
        pure (T.PVar (coerce x, dummy), dummy) -- TODO
    PCatch -> pure (T.PCatch, dummy) -- TODO
    -- Literals are compared with plain equality, not with 'subtype'.
    PLit lit | inferLit lit == t_patt -> let
                  t = inferLit lit
                in
                  pure (T.PLit (lit, t), t)
             | otherwise -> throwError "Literal in pattern have wrong type"
    PEnum name -> do
        t <- maybeToRightM ("Unknown constructor " ++ show name)
               =<< lookupInj name
        subtype t t_patt
        pure (T.PEnum (coerce name), dummy) -- TODO
    PInj name ps -> do
        t <- maybeToRightM ("Unknown constructor " ++ show name)
               =<< lookupInj name
        -- Split the constructor type into argument types and return type.
        let (t_ps, t_return) = partitionTypeWithForall t
        unless (length ps == length t_ps) $
            throwError "Wrong number of variables"
        subtype t_return t_patt
        -- Sub-patterns are checked under the solutions found so far.
        ps' <- zipWithM (\p t -> checkPattern p =<< applyEnv t) ps t_ps
        let ps'' = map fst ps' -- TODO
        pure (T.PInj (coerce name) ps'', dummy)
---------------------------------------------------------------------------
-- * Auxiliary
---------------------------------------------------------------------------
-- | All existential variables occurring in a type, left to right and
-- with duplicates.
frees :: Type -> [TEVar]
frees typ = case typ of
    TEVar tevar  -> [tevar]
    TFun t1 t2   -> frees t1 ++ frees t2
    TAll _ body  -> frees body
    TData _ args -> concatMap frees args
    TLit _       -> []
    TVar _       -> []
-- | [ά/α]A
--
-- Replace the free occurrences of the universal variable α by the
-- existential ά. A nested @forall@ that rebinds α shadows it, so the
-- substitution does not continue below such a binder.
substitute :: TVar  -- α
           -> TEVar -- ά
           -> Type  -- A
           -> Type  -- [ά/α]A
substitute tvar tevar typ = case typ of
    TLit _                     -> typ
    TVar tvar' | tvar' == tvar -> TEVar tevar
               | otherwise     -> typ
    TEVar _                    -> typ
    TFun t1 t2                 -> on TFun substitute' t1 t2
    TAll tvar' t
        | tvar' == tvar        -> typ
        | otherwise            -> TAll tvar' (substitute' t)
    TData name typs            -> TData name $ map substitute' typs
  where
    substitute' = substitute tvar tevar
-- | Γ,x,Γ' → (Γ, Γ')
--
-- Split the context at the leftmost occurrence of @x@, which is removed.
-- When @x@ does not occur the result is @(Γ, empty)@.
splitOn :: EnvElem -> Env -> (Env, Env)
splitOn x env = (before, S.drop 1 from_x)
  where
    (before, from_x) = S.breakl (== x) env
-- | Keep only the part of the context that precedes the leftmost
-- occurrence of @x@, i.e. drop @x@ and every entry after it.
dropTrailing :: EnvElem -> Tc ()
dropTrailing x = modifyEnv keepBefore
  where
    keepBefore = S.takeWhileL (/= x)
-- | Apply the current context to every type annotation stored inside a
-- typed expression. Expressions without annotations are returned as is.
applyEnvExp :: T.Exp' Type -> Tc (T.Exp' Type)
applyEnvExp = \case
    T.ELet (T.Bind name params rhs) body -> do
        bind <- T.Bind <$> onId name <*> mapM onId params <*> onExpT rhs
        T.ELet bind <$> onExpT body
    T.EApp e1 e2       -> T.EApp <$> onExpT e1 <*> onExpT e2
    T.EAdd e1 e2       -> T.EAdd <$> onExpT e1 <*> onExpT e2
    T.EAbs name e      -> T.EAbs name <$> onExpT e
    T.ECase e branches -> T.ECase <$> onExpT e <*> mapM onBranch branches
    other              -> pure other
  where
    -- Typed expression: rewrite the expression, then its type.
    onExpT (e, t) = (,) <$> applyEnvExp e <*> applyEnv t

    -- Typed identifier: only the type changes.
    onId = secondM applyEnv

    onBranch (T.Branch (p, t) e) = do
        p' <- onPattern p
        t' <- applyEnv t
        T.Branch (p', t') <$> onExpT e

    onPattern = \case
        T.PVar ident    -> T.PVar <$> onId ident
        T.PLit (lit, t) -> (\t' -> T.PLit (lit, t')) <$> applyEnv t
        T.PInj name ps  -> T.PInj name <$> mapM onPattern ps
        p               -> pure p
-- | Apply the current local context to a type (monadic wrapper around
-- 'applyEnv'').
applyEnv :: Type -> Tc Type
applyEnv t = do
    gamma <- gets env
    pure (applyEnv' gamma t)
-- | [Γ]A. Applies context to type until fully applied.
--
-- One pass replaces solved existentials by their solutions; passes are
-- repeated until the type no longer changes.
applyEnv' :: Env -> Type -> Type
applyEnv' cxt typ =
    if stepped == typ then stepped else applyEnv' cxt stepped
  where
    again = applyEnv' cxt

    stepped = case typ of
        -- [Γ[ά=τ]]ά = [Γ[ά=τ]]τ
        -- [Γ[ά]]ά   = [Γ[ά]]ά
        TEVar tevar -> fromMaybe typ (findSolved tevar cxt)
        -- [Γ](A → B) = [Γ]A → [Γ]B
        TFun t1 t2 -> TFun (again t1) (again t2)
        -- [Γ](∀α. A) = (∀α. [Γ]A)
        TAll tvar t -> TAll tvar (again t)
        TData name typs -> TData name (map again typs)
        -- [Γ]α = α
        TVar _ -> typ
        TLit _ -> typ
-- | Look up the solution of an existential variable. When the variable is
-- solved more than once, the rightmost entry of the context wins.
findSolved :: TEVar -> Env -> Maybe Type
findSolved tevar = foldr pick Nothing
  where
    -- Entries further to the right (@later@) take precedence.
    pick entry later = later <|> solution entry

    solution = \case
        EnvTEVarSolved tevar' t | tevar' == tevar -> Just t
        _                                         -> Nothing
-- | Γ ⊢ A
-- Under context Γ, type A is well-formed
--
-- Fails with a message naming the first variable that is not in scope.
wellFormed :: Env -> Type -> Err ()
wellFormed env typ = case typ of
    TLit _ -> pure ()

    -- -------- UvarWF
    -- Γ[α] ⊢ α
    TVar tvar
        | EnvTVar tvar `elem` env -> pure ()
        | otherwise -> throwError ("Unbound type variable: " ++ show tvar)

    -- Γ ⊢ A   Γ ⊢ B
    -- ------------- ArrowWF
    -- Γ ⊢ A → B
    TFun t1 t2 -> mapM_ (wellFormed env) [t1, t2]

    -- Γ,α ⊢ A
    -- -------- ForallWF
    -- Γ ⊢ ∀α.A
    TAll tvar body -> wellFormed (env :|> EnvTVar tvar) body

    -- ---------- EvarWF       ---------- SolvedEvarWF
    -- Γ[ά] ⊢ ά                Γ[ά=τ] ⊢ ά
    TEVar tevar
        | declared || solved -> pure ()
        | otherwise -> throwError ("Can't find type: " ++ show tevar)
      where
        declared = EnvTEVar tevar `elem` env
        solved   = not . isNothing $ findSolved tevar env

    TData _ args -> mapM_ (wellFormed env) args
-- | Is the type a monotype, i.e. free of @forall@ at every depth?
isMono :: Type -> Bool
isMono typ = case typ of
    TAll _ _     -> False
    TFun t1 t2   -> isMono t1 && isMono t2
    TData _ args -> all isMono args
    TVar _       -> True
    TEVar _      -> True
    TLit _       -> True
-- | The built-in type of a literal.
inferLit :: Lit -> Type
inferLit (LInt _)  = TLit "Int"
inferLit (LChar _) = TLit "Char"
-- | Produce the existential variable @a#n@ for the current counter @n@
-- and bump the counter.
fresh :: Tc TEVar
fresh = do
    n <- gets next_tevar
    modify $ \cxt -> cxt { next_tevar = succ cxt.next_tevar }
    pure $ MkTEVar (LIdent ("a#" ++ show n))
-- | Argument types of a function type (see 'partitionType').
getVars :: Type -> [Type]
getVars typ = fst (partitionType typ)

-- | Return type of a function type (see 'partitionType').
getReturn :: Type -> Type
getReturn typ = snd (partitionType typ)
-- | Partition a type into its argument types and its return type, after
-- skipping the leading foralls.
--
--  ∀a.∀b. a → (∀c. c → c) → b
--  ([a, ∀c. c → c], b)
--
-- Unsure if foralls should be added to the return type or not.
partitionType :: Type -> ([Type], Type)
partitionType typ = collect [] (skipForalls' typ)
  where
    -- Walk down the right spine of arrows, accumulating the arguments.
    collect args = \case
        TFun t_arg t_rest -> collect (snoc t_arg args) t_rest
        t_return          -> (args, t_return)
-- | The type underneath its leading foralls.
skipForalls' :: Type -> Type
skipForalls' typ = snd (skipForalls typ)

-- | Strip the leading foralls of a type, returning them as functions
-- that re-wrap a type in the corresponding quantifier, together with the
-- remaining type.
skipForalls :: Type -> ([Type -> Type], Type)
skipForalls = strip []
  where
    strip wrappers = \case
        TAll tvar body -> strip (snoc (TAll tvar) wrappers) body
        rest           -> (wrappers, rest)
-- | Like 'partitionType', but every resulting argument type and the
-- return type is re-wrapped in those leading foralls whose variable it
-- actually uses.
partitionTypeWithForall :: Type -> ([Type], Type)
partitionTypeWithForall typ = (t_vars', t_return')
  where
    t_vars'   = map (\t -> foldr applyForall t foralls) t_vars
    t_return' = foldr applyForall t_return foralls
    -- Wrap @t@ in the quantifier @fa@ only when @t@ mentions its variable.
    -- The variable is recovered by applying @fa@ and matching on the
    -- result; the match is lazy and partial, and relies on every element
    -- of @foralls@ being a 'TAll' wrapper, as produced by 'skipForalls'.
    applyForall fa t | usesTVar tvar t = fa t
                     | otherwise       = t
      where TAll tvar _ = fa t
    (t_vars, t_return) = go [] typ'
    (foralls, typ')    = skipForalls typ
    -- Collect the argument types along the right spine of arrows.
    go acc t = case t of
        TFun t1 t2 -> go (snoc t1 acc) t2
        _          -> (acc, t)
-- | Does the universal variable occur in the type? Calls 'error' when
-- the variable is bound again inside the type, or on an existential.
usesTVar :: TVar -> Type -> Bool
usesTVar tvar = go
  where
    go = \case
        TLit _       -> False
        TVar tvar'   -> tvar' == tvar
        TFun t1 t2   -> go t1 || go t2
        TAll tvar' t
            | tvar' == tvar -> error "Redeclaration of TVar"
            | otherwise     -> go t
        TData _ typs -> any go typs
        _            -> error "Impossible"
-- | Peel the given number of leading lambdas off an expression. Calls
-- 'error' when the expression has fewer leading lambdas than requested.
skipLambdas :: Int -> T.Exp' Type -> T.Exp' Type
skipLambdas 0 exp                 = exp
skipLambdas i (T.EAbs _ (body, _)) = skipLambdas (i - 1) body
skipLambdas _ _ = error "Number of expected lambdas doesn't match expression"
-- | A context is complete when it contains no unsolved existential.
isComplete :: Env -> Bool
isComplete gamma = not (any unsolved gamma)
  where
    unsolved (EnvTEVar _) = True
    unsolved _            = False
-- | Extract the variable of a type that is a bare type variable; any
-- other type is an error.
toTVar :: Type -> Err TVar
toTVar (TVar tvar) = pure tvar
toTVar _           = throwError "Not a type variable"
-- | Append an entry at the right end of the local context.
insertEnv :: EnvElem -> Tc ()
insertEnv x = modifyEnv (\gamma -> gamma :|> x)
-- | Look up the definition of a top-level binding.
lookupBind :: LIdent -> Tc (Maybe Exp)
lookupBind x = Map.lookup x <$> gets binds
-- | Look up the declared signature of a top-level name.
lookupSig :: LIdent -> Tc (Maybe Type)
lookupSig x = Map.lookup x <$> gets sig
-- | Look up the type of a term variable in the local context. The
-- rightmost (most recently added) binding of the name wins.
lookupEnv :: LIdent -> Tc (Maybe Type)
lookupEnv x = gets (foldr pick Nothing . env)
  where
    -- Entries further to the right (@later@) take precedence.
    pick entry later = later <|> typing entry

    typing = \case
        EnvVar x' t | x == x' -> Just t
        _                     -> Nothing
-- | Look up the type of a data constructor.
lookupInj :: UIdent -> Tc (Maybe Type)
lookupInj x = Map.lookup x <$> gets data_injs
-- | Replace the local context.
putEnv :: Env -> Tc ()
putEnv new_env = modifyEnv (const new_env)

-- | Transform the local context, leaving the rest of the state untouched.
modifyEnv :: (Env -> Env) -> Tc ()
modifyEnv f = modify update
  where
    -- For debugging, wrap the result in: trace (ppEnv (f cxt.env))
    update cxt = cxt { env = f cxt.env }
-- | Match a top-level binding and expose its name, parameters and body.
pattern DBind' name vars exp = DBind (Bind name vars exp)
-- | Match a top-level signature and expose its name and type.
pattern DSig' name typ = DSig (Sig name typ)

-- | Placeholder type used where the real type is not produced yet (see
-- the TODOs in 'checkPattern').
dummy = TLit "Int"
---------------------------------------------------------------------------
-- * Debug
---------------------------------------------------------------------------
-- | Debug helper: trace the label followed by the current local context.
traceEnv s = do
  env <- gets env
  trace (s ++ " " ++ show env) pure ()

-- | Debug helper: trace the label followed by a 'show'n value.
traceD s x = trace (s ++ " " ++ show x) pure ()

-- | Debug helper: trace the label followed by a pretty-printed type.
traceT s x = trace (s ++ " " ++ ppT x) pure ()

-- | Debug helper: trace the label followed by a list of pretty-printed
-- types.
traceTs s xs = trace (s ++ " [ " ++ intercalate ", " (map ppT xs) ++ " ]") pure ()
-- | Render a type for error and debug messages.
--
-- Function types are printed with an explicit arrow; a function or
-- quantified type on the left of an arrow is parenthesised so that the
-- output is unambiguous.
ppT :: Type -> String
ppT = \case
    TLit (UIdent s)            -> s
    TVar (MkTVar (LIdent s))   -> "α_" ++ s
    TFun t1 t2                 -> ppArg t1 ++ " -> " ++ ppT t2
    TAll (MkTVar (LIdent s)) t -> "forall " ++ s ++ ". " ++ ppT t
    TEVar (MkTEVar (LIdent s)) -> "ά_" ++ s
    TData (UIdent name) typs   -> name ++ " (" ++ unwords (map ppT typs)
                                       ++ " )"
  where
    -- Left operand of an arrow.
    ppArg t = case t of
        TFun _ _ -> "(" ++ ppT t ++ ")"
        TAll _ _ -> "(" ++ ppT t ++ ")"
        _        -> ppT t
-- | Render one context entry for debug output. Scoping markers get a
-- @▶@ prefix so they can be told apart from unsolved existentials.
ppEnvElem :: EnvElem -> String
ppEnvElem = \case
    EnvVar         (LIdent s) t           -> s ++ ":" ++ ppT t
    EnvTVar        (MkTVar  (LIdent s))   -> "α_" ++ s
    EnvTEVar       (MkTEVar (LIdent s))   -> "ά_" ++ s
    EnvTEVarSolved (MkTEVar (LIdent s)) t -> "ά_" ++ s ++ "=" ++ ppT t
    EnvMark        (MkTEVar (LIdent s))   -> "▶ " ++ "ά_" ++ s
-- | Render a context for debug output: a leading @·@ followed by every
-- entry, left to right, each in parentheses.
ppEnv env = "·" ++ concatMap entry env
  where
    entry x = " (" ++ ppEnvElem x ++ ")"

View file

@ -16,21 +16,16 @@ import Data.Function (on)
import Data.List (foldl') import Data.List (foldl')
import Data.List.Extra (unsnoc) import Data.List.Extra (unsnoc)
import Data.Map (Map) import Data.Map (Map)
import Data.Map qualified as M import qualified Data.Map as M
import Data.Maybe (fromJust) import Data.Maybe (fromJust)
import Data.Set (Set) import Data.Set (Set)
import Data.Set qualified as S import qualified Data.Set as S
import Debug.Trace (trace) import Debug.Trace (trace)
import Grammar.Abs import Grammar.Abs
import Grammar.Print (printTree) import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr ( import qualified TypeChecker.TypeCheckerIr as T
Ctx (..), import TypeChecker.TypeCheckerIr (Ctx (..), Env (..), Error, Infer,
Env (..), Subst)
Error,
Infer,
Subst,
)
import TypeChecker.TypeCheckerIr qualified as T
initCtx = Ctx mempty initCtx = Ctx mempty
initEnv = Env 0 'a' mempty mempty mempty initEnv = Env 0 'a' mempty mempty mempty
@ -695,3 +690,4 @@ unzip4 =
(as ++ [a], bs ++ [b], cs ++ [c], ds ++ [d]) (as ++ [a], bs ++ [b], cs ++ [c], ds ++ [d])
) )
([], [], [], []) ([], [], [], [])

View file

@ -1,245 +1,135 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}
module TypeChecker.TypeCheckerIr (
module TypeChecker.TypeCheckerIr,
) where
import Control.Monad.Except module TypeChecker.TypeCheckerIr
import Control.Monad.Reader ( module Grammar.Abs
import Control.Monad.State , module TypeChecker.TypeCheckerIr
import Data.Char (isDigit) ) where
import Data.Functor.Identity (Identity)
import Data.Map (Map) import Data.String (IsString)
import Data.Set (Set) import Grammar.Abs (Character (..), Lit (..), TVar (..))
import Data.String qualified
import Grammar.Print import Grammar.Print
import Prelude import Prelude
import Prelude qualified as C (Eq, Ord, Read, Show) import qualified Prelude as C (Eq, Ord, Read, Show)
newtype Ctx = Ctx {vars :: Map Ident Type} newtype Program' t = Program [Def' t]
deriving (Show)
data Env = Env
{ count :: Int
, nextChar :: Char
, sigs :: Map Ident (Maybe Type)
, constructors :: Map Ident Type
, takenTypeVars :: Set Ident
}
deriving (Show)
type Error = String
type Subst = Map Ident Type
type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))
newtype Program = Program [Def]
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
data Data = Data Ident [Constructor] data Def' t = DBind (Bind' t)
deriving (Show, Eq, Ord, Read) | DData (Data' t)
deriving (C.Eq, C.Ord, C.Show, C.Read)
data Constructor = Constructor Ident Type
deriving (Show, Eq, Ord, Read)
newtype TVar = MkTVar Ident
deriving (Show, Eq, Ord, Read)
data Type data Type
= TLit Ident = TLit Ident
| TVar TVar | TVar TVar
| TData Ident [Type]
| TFun Type Type | TFun Type Type
| TAll TVar Type | TAll TVar Type
| TData Ident [Type]
deriving (Show, Eq, Ord, Read)
data Exp
= EId Ident
| ELit Lit
| ELet Bind ExpT
| EApp ExpT ExpT
| EAdd ExpT ExpT
| EAbs Ident ExpT
| ECase ExpT [Branch]
deriving (C.Eq, C.Ord, C.Read, C.Show)
type ExpT = (Exp, Type)
data Branch = Branch (Pattern, Type) ExpT
deriving (C.Eq, C.Ord, C.Read, C.Show)
data Pattern = PVar Id | PLit (Lit, Type) | PInj Ident [Pattern] | PCatch | PEnum Ident
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
data Def = DBind Bind | DData Data data Data' t = Data t [Inj' t]
deriving (C.Eq, C.Ord, C.Read, C.Show) deriving (C.Eq, C.Ord, C.Show, C.Read)
type Id = (Ident, Type) data Inj' t = Inj Ident t
deriving (C.Eq, C.Ord, C.Show, C.Read)
newtype Ident = Ident String newtype Ident = Ident String
deriving (C.Eq, C.Ord, C.Show, C.Read, Data.String.IsString) deriving (C.Eq, C.Ord, C.Show, C.Read, IsString)
data Lit = LInt Integer | LChar Char data Pattern' t
deriving (Show, Eq, Ord, Read) = PVar (Id' t) -- TODO should be Ident
| PLit (Lit, t) -- TODO should be Lit
| PCatch
| PEnum Ident
| PInj Ident [Pattern' t] -- TODO should be (Pattern' t, t)
deriving (C.Eq, C.Ord, C.Show, C.Read)
data Bind = Bind Id [Id] ExpT data Exp' t
= EVar Ident
| EInj Ident
| ELit Lit
| ELet (Bind' t) (ExpT' t)
| EApp (ExpT' t) (ExpT' t)
| EAdd (ExpT' t) (ExpT' t)
| EAbs Ident (ExpT' t)
| ECase (ExpT' t) [Branch' t]
deriving (C.Eq, C.Ord, C.Show, C.Read)
type Id' t = (Ident, t)
type ExpT' t = (Exp' t, t)
data Bind' t = Bind (Id' t) [Id' t] (ExpT' t)
deriving (C.Eq, C.Ord, C.Show, C.Read)
data Branch' t = Branch (Pattern' t, t) (ExpT' t)
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
instance Print Ident where instance Print Ident where
prt _ (Ident str) = doc . showString $ str prt i (Ident s) = prt i s
instance Print [Def] where instance Print t => Print (Program' t) where
prt _ [] = concatD []
prt _ (x : xs) = concatD [prt 0 x, doc (showString "\n\n"), prt 0 xs]
instance Print Data where
prt i = \case
Data type_ constructors -> prPrec i 0 (concatD [doc (showString "data"), prt 0 type_, doc (showString "where"), doc (showString "{"), prt 0 constructors, doc (showString "}")])
instance Print Constructor where
prt i = \case
Constructor uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_])
instance Print [Constructor] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, prt 0 xs]
instance Print Def where
prt i (DBind bind) = prt i bind
prt i (DData d) = prt i d
instance Print Program where
prt i (Program sc) = prPrec i 0 $ prt 0 sc prt i (Program sc) = prPrec i 0 $ prt 0 sc
instance Print Bind where instance Print t => Print (Bind' t) where
prt i (Bind (name, t) args rhs) = prt i (Bind sig@(name, _) parms rhs) = prPrec i 0 $ concatD
prPrec i 0 $ [ prtSig sig
concatD
[ prt 0 name
, doc $ showString ":"
, prt 0 t
, doc $ showString "\n"
, prt 0 name , prt 0 name
, prtIdPs 0 args , prtIdPs 0 parms
, doc $ showString "=" , doc $ showString "="
, prt 0 rhs , prt 0 rhs
] ]
instance Print [Bind] where prtSig :: Print t => Id' t -> Doc
prtSig (name, t) = concatD [ prt 0 name
, doc $ showString ":"
, prt 0 t
, doc $ showString ";"
]
instance Print t => Print (ExpT' t) where
prt i (e, t) = concatD [ doc $ showString "("
, prt i e
, doc $ showString ","
, prt i t
, doc $ showString ")"
]
instance Print t => Print [Bind' t] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x] prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), doc (showString "\n"), prt 0 xs] prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
prtIdPs :: Int -> [Id] -> Doc prtIdPs :: Print t => Int -> [Id' t] -> Doc
prtIdPs i = prPrec i 0 . concatD . map (prtIdP i) prtIdPs i = prPrec i 0 . concatD . map (prt i)
prtId :: Int -> Id -> Doc instance Print t => Print (Id' t) where
prtId i (name, t) = prt i (name, t) = concatD [ doc $ showString "("
prPrec i 0 $ , prt i name
concatD , doc $ showString ","
[ doc $ showString "(" , prt i t
, prt 0 name
, doc $ showString ":"
, prt 0 t
, doc $ showString ")" , doc $ showString ")"
] ]
prtIdP :: Int -> Id -> Doc instance Print t => Print (Exp' t) where
prtIdP i (name, t) =
prPrec i 0 $
concatD
[ doc $ showString "("
, prt 0 name
, doc $ showString ":"
, prt 0 t
, doc $ showString ")"
]
instance Print Exp where
prt i = \case prt i = \case
EId n -> prPrec i 3 $ concatD [prt 0 n] EVar name -> prPrec i 3 $ prt 0 name
ELit lit -> prPrec i 3 $ concatD [prt 0 lit] EInj name -> prPrec i 3 $ prt 0 name
ELet bs e -> ELit lit -> prPrec i 3 $ prt 0 lit
prPrec i 3 $ ELet b e -> prPrec i 3 $ concatD
concatD
[ doc $ showString "let" [ doc $ showString "let"
, prt 0 bs , prt 0 b
, doc $ showString "in" , doc $ showString "in"
, prt 0 e , prt 0 e
] ]
EApp e1 e2 -> EApp e1 e2 -> prPrec i 2 $ concatD
prPrec i 2 $
concatD
[ prt 2 e1 [ prt 2 e1
, prt 3 e2 , prt 3 e2
] ]
EAdd e1 e2 -> EAdd e1 e2 -> prPrec i 1 $ concatD
prPrec i 1 $ [ prt 1 e1
concatD
[ doc $ showString "@"
, prt 1 e1
, doc $ showString "+" , doc $ showString "+"
, prt 2 e2 , prt 2 e2
] ]
EAbs n e -> EAbs v e -> prPrec i 0 $ concatD
prPrec i 0 $ [ doc $ showString "\\"
concatD
[ doc $ showString "λ"
, prt 0 n
, doc $ showString "."
, prt 0 e
]
ECase exp injs ->
prPrec
i
0
( concatD
[ doc (showString "case")
, prt 0 exp
, doc (showString "of")
, doc (showString "{")
, prt 0 injs
, doc (showString "}")
, doc (showString ":")
]
)
-- | Prints a typed expression: the expression, a ":" separator and its type,
-- wrapped in parentheses. Both components are printed at the precedence level
-- handed in by the caller.
instance Print ExpT where
    prt prec (expr, typ) =
        concatD
            [ doc (showString "(")
            , prt prec expr
            , doc (showString ":")
            , prt prec typ
            , doc (showString ")")
            ]
-- | Prints one case branch: the pattern, ":", the pattern's type, "=>" and
-- the branch body. Every component is printed at precedence 0.
instance Print Branch where
    prt prec (Branch (pat, patType) body) =
        prPrec prec 0 $
            concatD
                [ prt 0 pat
                , doc $ showString ":"
                , prt 0 patType
                , doc $ showString "=>"
                , prt 0 body
                ]
-- | Prints a case pattern.
--
-- * 'PVar' is delegated to 'prtId'.
-- * 'PLit' is printed as the literal and its type, comma separated, in
--   parentheses.
-- * 'PInj' is printed as the constructor name followed by its sub-patterns.
-- * 'PCatch' is printed as "_".
-- * 'PEnum' is printed as the bare identifier at the caller's precedence.
instance Print Pattern where
    prt prec (PVar var) =
        prPrec prec 0 $ concatD [prtId 0 var]
    prt prec (PLit (lit, typ)) =
        prPrec prec 0 $
            concatD
                [ doc (showString "(")
                , prt 0 lit
                , doc (showString ",")
                , prt 0 typ
                , doc (showString ")")
                ]
    prt prec (PInj con subPatterns) =
        prPrec prec 0 $ concatD [prt 0 con, prt 0 subPatterns]
    prt prec PCatch =
        prPrec prec 0 $ concatD [doc $ showString "_"]
    prt prec (PEnum name) =
        prt prec name
-- | Prints a list of case branches at precedence 0, placing ";" between
-- consecutive branches (no trailing separator after the last one).
instance Print [Branch] where
    prt _ = \case
        [] -> concatD []
        [branch] -> concatD [prt 0 branch]
        branch : rest -> concatD [prt 0 branch, doc $ showString ";", prt 0 rest]
-- | A type variable prints exactly like the identifier it wraps.
instance Print TVar where
    prt prec (MkTVar name) = prt prec name
-- | Prints a type.
--
-- A type variable whose name consists solely of digits is re-printed with an
-- "a" prefix (so @1@ becomes @a1@); any other variable is printed as is.
-- Function types print their domain at precedence 1 and their codomain at
-- precedence 0.
instance Print Type where
    prt prec (TLit name) =
        prPrec prec 2 $ concatD [prt 0 name]
    prt prec (TVar tvar@(MkTVar (Ident raw)))
        -- Prefix purely numeric names, then print the renamed variable.
        | all isDigit raw =
            prPrec prec 2 $ concatD [prt 0 $ TVar (MkTVar (Ident ("a" <> raw)))]
        | otherwise =
            prPrec prec 2 $ concatD [prt 0 tvar]
    prt prec (TAll tvar body) =
        prPrec prec 1 $
            concatD
                [ doc $ showString "forall"
                , prt 0 tvar
                , doc $ showString "."
                , prt 0 body
                ]
    prt prec (TData name args) =
        prPrec prec 1 $ concatD [prt 0 name, prt 0 args]
    prt prec (TFun domain codomain) =
        prPrec prec 0 $
            concatD
                [ prt 1 domain
                , doc $ showString "->"
                , prt 0 codomain
                ]
-- | Prints a literal by printing the wrapped integer or character at
-- precedence 0.
instance Print Lit where
    prt prec (LInt int) = prPrec prec 0 $ concatD [prt 0 int]
    prt prec (LChar char) = prPrec prec 0 $ concatD [prt 0 char]

View file

@ -0,0 +1,232 @@
{-# LANGUAGE OverloadedStrings #-}
{-# HLINT ignore "Use camelCase" #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}
module TestTypeCheckerBidir (testTypeCheckerBidir) where
import Test.Hspec
import Control.Monad ((<=<))
import Grammar.ErrM (Err, pattern Bad, pattern Ok)
import Grammar.Par (myLexer, pProgram)
import Renamer.Renamer (rename)
import TypeChecker.RemoveTEVar (RemoveTEVar (rmTEVar))
import TypeChecker.TypeCheckerBidir (typecheck)
import qualified TypeChecker.TypeCheckerIr as T
-- | Entry point of this module: groups every @tc_*@ spec defined below under
-- a single heading and runs them in the order listed.
testTypeCheckerBidir :: Spec
testTypeCheckerBidir = describe "Bidirectional type checker test" $ do
    tc_id
    tc_double
    tc_add_lam
    tc_const
    tc_simple_rank2
    tc_rank2
    tc_identity
    tc_pair
    tc_tree
    tc_mono_case
    tc_pol_case
-- | A program that declares the polymorphic identity function and applies it
-- to an integer literal must be accepted by 'run'.
tc_id :: Spec
tc_id =
    specify "Basic identity function polymorphism" $
        run
            [ "id : forall a. a -> a;"
            , "id x = x;"
            , "main = id 4;"
            ]
            `shouldSatisfy` ok
-- | A definition without a signature whose body adds its argument to itself
-- must be accepted by 'run'.
tc_double :: Spec
tc_double =
    specify "Addition inference" $
        run ["double x = x + x;"] `shouldSatisfy` ok
-- | An unannotated lambda that adds its argument to itself, applied to an
-- integer literal, must be accepted by 'run'.
tc_add_lam :: Spec
tc_add_lam =
    specify "Addition lambda inference" $
        run ["four = (\\x. x + x) 2;"] `shouldSatisfy` ok
-- | @const@, quantified over two type variables, applied to a character and
-- an integer literal must be accepted by 'run'.
tc_const :: Spec
tc_const =
    specify "Basic polymorphism with multiple type variables" $
        run
            [ "const : forall a. forall b. a -> b -> a;"
            , "const x y = x;"
            , "main = const 'a' 65;"
            ]
            `shouldSatisfy` ok
-- | A function whose second parameter has the polymorphic type
-- @forall b. b -> b@ must accept the polymorphic identity as that argument.
tc_simple_rank2 :: Spec
tc_simple_rank2 =
    specify "Simple rank two polymorphism" $
        run
            [ "id : forall a. a -> a;"
            , "id x = x;"
            , "f : forall a. a -> (forall b. b -> b) -> a;"
            , "f x g = g x;"
            , "main = f 4 id;"
            ]
            `shouldSatisfy` ok
-- | A function taking an argument of type @forall c. c -> Int@ and using it
-- at two different argument types must be accepted when it is passed a
-- lambda annotated with @forall a. a -> Int@.
tc_rank2 :: Spec
tc_rank2 =
    specify "Rank two polymorphism is ok" $
        run
            [ "const : forall a. forall b. a -> b -> a;"
            , "const x y = x;"
            , "rank2 : forall a. forall b. a -> (forall c. c -> Int) -> b -> Int;"
            , "rank2 x f y = f x + f y;"
            , "main = rank2 3 (\\x. const 5 x : forall a. a -> Int) 'h';"
            ]
            `shouldSatisfy` ok
-- | A parameter of type @forall b. b -> b@ must reject the monomorphic
-- @Int -> Int@ identity and accept the polymorphic one.
--
-- Both cases share the declarations in @fs@ and differ only in which
-- function @main@ passes to @f@.
tc_identity :: Spec
tc_identity = describe "(∀b. b → b) should only accept the identity function" $ do
    specify "identityᵢₙₜ is rejected" $ run (fs ++ mainWithIdInt) `shouldNotSatisfy` ok
    specify "identity is accepted" $ run (fs ++ mainWithId) `shouldSatisfy` ok
  where
    -- Declarations common to both programs.
    fs =
        [ "f : forall a. a -> (forall b. b -> b) -> a;"
        , "f x g = g x;"
        , "id : forall a. a -> a;"
        , "id x = x;"
        , "id_int : Int -> Int;"
        , "id_int x = x;"
        ]
    -- Named so that it does not shadow Prelude's 'id'.
    mainWithId =
        [ "main : Int;"
        , "main = f 4 id;"
        ]
    mainWithIdInt =
        [ "main : Int;"
        , "main = f 4 id_int;"
        ]
-- | With @main@ declared as @Pair (Int Char)@, constructing the pair with the
-- arguments in the wrong order must be rejected and in the right order
-- accepted.
tc_pair :: Spec
tc_pair = describe "Pair. Type variables in Pair a b typechecked" $ do
    specify "Wrong arguments are rejected" $ run (fs ++ wrong) `shouldNotSatisfy` ok
    specify "Correct arguments are accepted" $ run (fs ++ correct) `shouldSatisfy` ok
  where
    -- Data declaration and signature shared by both programs.
    fs =
        [ "data forall a. forall b. Pair (a b) where {"
        , " Pair : a -> b -> Pair (a b)"
        , "};"
        , "main : Pair (Int Char);"
        ]
    wrong = ["main = Pair 'a' 65;"]
    correct = ["main = Pair 65 'a';"]
-- | Recursive data type: a tree value in which @Node@ is applied to a single
-- argument must be rejected, while the well-formed tree must be accepted.
tc_tree :: Spec
tc_tree = describe "Tree. Recursive data type" $ do
    specify "Wrong tree is rejected" $ run (fs ++ wrong) `shouldNotSatisfy` ok
    specify "Correct tree is accepted" $ run (fs ++ correct) `shouldSatisfy` ok
  where
    -- Data declaration shared by both programs.
    fs =
        [ "data forall a. Tree (a) where {"
        , " Node : a -> Tree (a) -> Tree (a) -> Tree (a)"
        , " Leaf : a -> Tree (a)"
        , "};"
        ]
    -- The inner @Node 4@ is missing its two sub-trees.
    wrong = ["tree = Node 1 (Node 2 (Node 4) (Leaf 5)) (Leaf 3);"]
    correct = ["tree = Node 1 (Node 2 (Leaf 4) (Leaf 5)) (Leaf 3);"]
-- | Case expressions over a monomorphic scrutinee.
--
-- The three rejected programs each contain one mismatch: the scrutinee type
-- (@wrong1@), one pattern literal (@wrong2@) or one branch result
-- (@wrong3@). The accepted programs match on character literals, the second
-- one with a @_@ catch-all branch.
tc_mono_case :: Spec
tc_mono_case = describe "Monomorphic pattern matching" $ do
    specify "First wrong case expression rejected" $
        run wrong1 `shouldNotSatisfy` ok
    specify "Second wrong case expression rejected" $
        run wrong2 `shouldNotSatisfy` ok
    specify "Third wrong case expression rejected" $
        run wrong3 `shouldNotSatisfy` ok
    specify "First correct case expression accepted" $
        run correct1 `shouldSatisfy` ok
    specify "Second correct case expression accepted" $
        run correct2 `shouldSatisfy` ok
  where
    -- Scrutinee declared as Int but matched against character literals.
    wrong1 =
        [ "simple : Int -> Int;"
        , "simple c = case c of {"
        , " 'F' => 0;"
        , " 'T' => 1;"
        , "};"
        ]
    -- Second pattern is an integer literal although the scrutinee is a Char.
    wrong2 =
        [ "simple : Char -> Int;"
        , "simple c = case c of {"
        , " 'F' => 0;"
        , " 1 => 1;"
        , "};"
        ]
    -- Second branch yields a character although the result type is Int.
    wrong3 =
        [ "simple : Char -> Int;"
        , "simple c = case c of {"
        , " 'F' => 0;"
        , " 'T' => '1';"
        , "};"
        ]
    correct1 =
        [ "simple : Char -> Int;"
        , "simple c = case c of {"
        , " 'F' => 0;"
        , " 'T' => 1;"
        , "};"
        ]
    correct2 =
        [ "simple : Char -> Int;"
        , "simple c = case c of {"
        , " 'F' => 0;"
        , " _ => 1;"
        , "};"
        ]
-- | Case expressions over the polymorphic @List (a)@ type declared in @fs@.
--
-- The rejected programs use an integer literal where an element of the
-- quantified type is matched (@wrong1@), a constructor pattern without its
-- arguments (@wrong2@) and an integer literal pattern against a list
-- scrutinee (@wrong3@). The accepted programs use constructor patterns
-- (including a nested one) and a variable pattern.
tc_pol_case :: Spec
tc_pol_case = describe "Polymorphic pattern matching" $ do
    specify "First wrong case expression rejected" $
        run (fs ++ wrong1) `shouldNotSatisfy` ok
    specify "Second wrong case expression rejected" $
        run (fs ++ wrong2) `shouldNotSatisfy` ok
    specify "Third wrong case expression rejected" $
        run (fs ++ wrong3) `shouldNotSatisfy` ok
    specify "First correct case expression accepted" $
        run (fs ++ correct1) `shouldSatisfy` ok
    specify "Second correct case expression accepted" $
        run (fs ++ correct2) `shouldSatisfy` ok
  where
    -- Data declaration shared by every program in this group.
    fs =
        [ "data forall a. List (a) where {"
        , " Nil : List (a)"
        , " Cons : a -> List (a) -> List (a)"
        , "};"
        ]
    wrong1 =
        [ "length : forall c. List (c) -> Int;"
        , "length = \\list. case list of {"
        , " Nil => 0;"
        , " Cons 6 xs => 1 + length xs;"
        , "};"
        ]
    wrong2 =
        [ "length : forall c. List (c) -> Int;"
        , "length = \\list. case list of {"
        , " Cons => 0;"
        , " Cons x xs => 1 + length xs;"
        , "};"
        ]
    wrong3 =
        [ "length : forall c. List (c) -> Int;"
        , "length = \\list. case list of {"
        , " 0 => 0;"
        , " Cons x xs => 1 + length xs;"
        , "};"
        ]
    correct1 =
        [ "length : forall c. List (c) -> Int;"
        , "length = \\list. case list of {"
        , " Nil => 0;"
        , " Cons x xs => 1 + length xs;"
        , " Cons x (Cons y Nil) => 2;"
        , "};"
        ]
    correct2 =
        [ "length : forall c. List (c) -> Int;"
        , "length = \\list. case list of {"
        , " Nil => 0;"
        , " non_empty => 1;"
        , "};"
        ]
-- | Runs the pipeline under test on a program given as a list of source
-- lines: the lines are joined with 'unlines', lexed and parsed, passed to
-- 'typecheck' and finally to 'rmTEVar'. The first failing stage determines
-- the result.
run :: [String] -> Err T.Program
run sourceLines =
    pProgram (myLexer (unlines sourceLines))
        >>= typecheck
        >>= rmTEVar
-- | Predicate for @shouldSatisfy@ / @shouldNotSatisfy@: 'True' for an 'Ok'
-- result, 'False' for a 'Bad' one. The payload is ignored in both cases.
ok :: Err a -> Bool
ok = \case
    Ok _ -> True
    Bad _ -> False

113
tests/TestTypeCheckerHm.hs Normal file
View file

@ -0,0 +1,113 @@
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE QualifiedDo #-}
module TestTypeCheckerHm (testTypeCheckerHm) where
import Control.Monad ((<=<))
import qualified DoStrings as D
import Grammar.Par (myLexer, pProgram)
import Prelude (Bool (..), Either (..), IO, fmap,
not, ($), (.))
import Test.Hspec
-- import Test.QuickCheck
import TypeChecker.TypeCheckerHm (typecheck)
-- | Entry point of this module: groups the specs defined below under a
-- single heading and runs them in the order listed.
testTypeCheckerHm :: Spec
testTypeCheckerHm = describe "Hindley-Milner type checker test" $ do
    ok1
    ok2
    bad1
    bad2
    -- bad3 is defined in this module but is currently not run.
    -- bad3
-- | The @const@ fixture applied to a character and an integer literal must
-- be accepted.
ok1 :: Spec
ok1 =
    specify "Basic polymorphism with multiple type variables" $
        run
            ( D.do
                const
                "main = const 'a' 65 ;"
            )
            `shouldSatisfy` ok
-- | The @list@ data declaration followed by the signature and definition of
-- @head@ must be accepted.
ok2 :: Spec
ok2 =
    specify "Head with a correct signature is accepted" $
        run
            ( D.do
                list
                headSig
                head
            )
            `shouldSatisfy` ok
-- | Self-application (@\\x. x x@) must be rejected.
bad1 :: Spec
bad1 =
    specify "Infinite type unification should not succeed" $
        run
            ( D.do
                "main = \\x. x x ;"
            )
            `shouldSatisfy` bad
-- | A case expression whose branches match an integer literal and the list
-- constructor @Nil@ against the same scrutinee must be rejected.
bad2 :: Spec
bad2 =
    specify "Pattern matching using different types should not succeed" $
        run
            ( D.do
                list
                "bad xs = case xs of {"
                " 1 => 0 ;"
                " Nil => 0 ;"
                "};"
            )
            `shouldSatisfy` bad
-- | Applying @not@ (declared on @Bool ()@) to an argument whose declared type
-- is the type variable @a@ must be rejected.
--
-- This spec is not listed in 'testTypeCheckerHm' (its entry there is
-- commented out), so it is currently not run.
bad3 :: Spec
bad3 =
    specify "Using a concrete function on a skolem variable should not succeed" $
        run
            ( D.do
                bool
                _not
                "f : a -> Bool () ;"
                " f x = not x ;"
            )
            `shouldSatisfy` bad
-- | Lexes and parses the given program source and, if parsing succeeds,
-- passes the result to 'typecheck'.
run source = (typecheck <=< pProgram) (myLexer source)
-- | Predicate for @shouldSatisfy@: 'True' for a 'Right' result, 'False' for a
-- 'Left' one. The payload is ignored in both cases.
ok :: Either e a -> Bool
ok (Right _) = True
ok (Left _) = False
-- | Negation of 'ok': 'True' exactly when the result is a 'Left'.
bad :: Either e a -> Bool
bad = not . ok
-- FUNCTIONS
-- | Source fixture: signature and definition of @const@ in the language
-- under test. The string lines are sequenced with the qualified do-syntax
-- of the @DoStrings@ module (imported as @D@); how they are joined is
-- defined there, not in this file.
const = D.do
    "const : a -> b -> a ;"
    "const x y = x ;"
-- | Source fixture: declaration of the @List (a)@ data type with the
-- constructors @Nil@ and @Cons@. The lines are combined by the @D.do@ syntax
-- from the @DoStrings@ module.
list = D.do
    "data List (a) where"
    " {"
    " Nil : List (a)"
    " Cons : a -> List (a) -> List (a)"
    " };"
-- | Source fixture: the type signature of @head@, kept separate from its
-- definition (see 'head').
headSig = D.do
    "head : List (a) -> a ;"
-- | Source fixture: the definition of @head@, a case expression with a single
-- @Cons@ branch. The lines are combined by the @D.do@ syntax from the
-- @DoStrings@ module.
head = D.do
    "head xs = "
    " case xs of {"
    " Cons x xs => x ;"
    " };"
-- | Source fixture: declaration of the @Bool ()@ data type with the
-- constructors @True@ and @False@. The lines are combined by the @D.do@
-- syntax from the @DoStrings@ module.
bool = D.do
    "data Bool () where {"
    " True : Bool ()"
    " False : Bool ()"
    "};"
-- | Source fixture: signature and definition of @not@ on @Bool ()@, written
-- as a case expression over the two constructors. The lines are combined by
-- the @D.do@ syntax from the @DoStrings@ module.
_not = D.do
    "not : Bool () -> Bool () ;"
    "not x = case x of {"
    " True => False ;"
    " False => True ;"
    "};"

10
tests/Tests.hs Normal file
View file

@ -0,0 +1,10 @@
module Main where
import Test.Hspec
import TestTypeCheckerBidir (testTypeCheckerBidir)
import TestTypeCheckerHm (testTypeCheckerHm)
-- | Test-suite entry point: runs the bidirectional type checker specs
-- followed by the Hindley-Milner type checker specs under hspec.
main :: IO ()
main = hspec $ do
    testTypeCheckerBidir
    testTypeCheckerHm