Add bidirectional type checker, lambda lifter.

This commit is contained in:
Martin Fredin 2023-02-18 14:49:33 +01:00
parent 2fa30faa87
commit ac3f222753
22 changed files with 2440 additions and 577 deletions

View file

@ -3,94 +3,94 @@
-- * PROGRAM
-------------------------------------------------------------------------------
Program. Program ::= [Def] ;
Program. Program ::= [Def];
-------------------------------------------------------------------------------
-- * TOP-LEVEL
-------------------------------------------------------------------------------
DBind. Def ::= Bind ;
DSig. Def ::= Sig ;
DData. Def ::= Data ;
DBind. Def ::= Bind;
DSig. Def ::= Sig;
DData. Def ::= Data;
Sig. Sig ::= LIdent ":" Type ;
Bind. Bind ::= LIdent [LIdent] "=" Exp ;
Sig. Sig ::= LIdent ":" Type;
Bind. Bind ::= LIdent [LIdent] "=" Exp;
-------------------------------------------------------------------------------
-- * TYPES
-- * Types
-------------------------------------------------------------------------------
TLit. Type2 ::= UIdent ;
TVar. Type2 ::= TVar ;
TAll. Type1 ::= "forall" TVar "." Type ;
TData. Type1 ::= UIdent "(" [Type] ")" ;
internal TEVar. Type1 ::= TEVar ;
TFun. Type ::= Type1 "->" Type ;
TLit. Type1 ::= UIdent; -- τ
TVar. Type1 ::= TVar; -- α
internal TEVar. Type1 ::= TEVar; -- ά
TData. Type1 ::= UIdent "(" [Type] ")"; -- D ()
TFun. Type ::= Type1 "->" Type; -- A A
TAll. Type ::= "forall" TVar "." Type; -- α. A
MkTVar. TVar ::= LIdent ;
internal MkTEVar. TEVar ::= LIdent ;
MkTVar. TVar ::= LIdent;
internal MkTEVar. TEVar ::= LIdent;
-------------------------------------------------------------------------------
-- * DATA TYPES
-------------------------------------------------------------------------------
Constructor. Constructor ::= UIdent ":" Type ;
Data. Data ::= "data" Type "where" "{" [Inj] "}" ;
Data. Data ::= "data" Type "where" "{" [Constructor] "}" ;
Inj. Inj ::= UIdent ":" Type ;
separator nonempty Inj " " ;
-------------------------------------------------------------------------------
-- * EXPRESSIONS
-- * Expressions
-------------------------------------------------------------------------------
EAnn. Exp4 ::= "(" Exp ":" Type ")" ;
EVar. Exp3 ::= LIdent ;
EInj. Exp3 ::= UIdent ;
ELit. Exp3 ::= Lit ;
EApp. Exp2 ::= Exp2 Exp3 ;
EAdd. Exp1 ::= Exp1 "+" Exp2 ;
ELet. Exp ::= "let" Bind "in" Exp ;
EAbs. Exp ::= "\\" LIdent "." Exp ;
ECase. Exp ::= "case" Exp "of" "{" [Branch] "}";
EAnn. Exp4 ::= "(" Exp ":" Type ")";
EVar. Exp3 ::= LIdent;
EInj. Exp3 ::= UIdent;
ELit. Exp3 ::= Lit;
EApp. Exp2 ::= Exp2 Exp3;
EAdd. Exp1 ::= Exp1 "+" Exp2;
ELet. Exp ::= "let" Bind "in" Exp;
EAbs. Exp ::= "\\" LIdent "." Exp;
ECase. Exp ::= "case" Exp "of" "{" [Branch] "}";
-------------------------------------------------------------------------------
-- * LITERALS
-------------------------------------------------------------------------------
LInt. Lit ::= Integer ;
LChar. Lit ::= Char ;
LInt. Lit ::= Integer;
LChar. Lit ::= Character;
-------------------------------------------------------------------------------
-- * CASE
-- * PATTERN MATCHING
-------------------------------------------------------------------------------
Branch. Branch ::= Pattern "=>" Exp ;
PVar. Pattern1 ::= LIdent ;
PLit. Pattern1 ::= Lit ;
PCatch. Pattern1 ::= "_" ;
PEnum. Pattern1 ::= UIdent ;
PInj. Pattern ::= UIdent [Pattern1] ;
PVar. Pattern1 ::= LIdent;
PLit. Pattern1 ::= Lit;
PCatch. Pattern1 ::= "_";
PEnum. Pattern1 ::= UIdent;
PInj. Pattern ::= UIdent [Pattern1];
-------------------------------------------------------------------------------
-- * AUX
-------------------------------------------------------------------------------
terminator Def ";" ;
separator nonempty Constructor "" ;
separator Type " " ;
separator nonempty Pattern1 " " ;
terminator Def ";";
terminator Branch ";" ;
separator Ident " ";
separator LIdent " ";
separator TVar " " ;
coercions Exp 4 ;
coercions Type 2 ;
coercions Pattern 1 ;
separator LIdent "";
separator Type " ";
separator TVar " ";
separator nonempty Pattern1 " ";
coercions Pattern 1;
coercions Exp 4;
coercions Type 1 ;
token Character '\''(char)'\'' ;
token UIdent (upper (letter | digit | '_')*) ;
token LIdent (lower (letter | digit | '_')*) ;
comment "--" ;
comment "{-" "-}" ;
comment "--";
comment "{-" "-}";

View file

@ -31,13 +31,18 @@ executable language
Grammar.Skel
Grammar.ErrM
Auxiliary
Renamer.Renamer
TypeChecker.TypeChecker
TypeChecker.TypeCheckerHm
TypeChecker.TypeCheckerBidir
TypeChecker.TypeCheckerIr
TypeChecker.RemoveTEVar
LambdaLifter
Monomorphizer.Monomorphizer
Monomorphizer.MonomorphizerIr
Renamer.Renamer
Codegen.Codegen
Codegen.LlvmIr
Compiler
hs-source-dirs: src
@ -60,6 +65,9 @@ Test-suite language-testsuite
main-is: Tests.hs
other-modules:
TestTypeCheckerBidir
TestTypeCheckerHm
Grammar.Abs
Grammar.Lex
Grammar.Par
@ -67,9 +75,11 @@ Test-suite language-testsuite
Grammar.Skel
Grammar.ErrM
Auxiliary
TypeChecker.TypeChecker
TypeChecker.TypeCheckerIr
Renamer.Renamer
TypeChecker.TypeCheckerHm
TypeChecker.TypeCheckerBidir
TypeChecker.RemoveTEVar
TypeChecker.TypeCheckerIr
Compiler
hs-source-dirs: src, tests, tests/TypecheckingHM
@ -87,3 +97,4 @@ Test-suite language-testsuite
, bytestring
default-language: GHC2021

11
sample-programs/basic-0 Normal file
View file

@ -0,0 +1,11 @@
data forall a. List (a) where {
Nil : List (a)
Cons : a -> List (a) -> List (a)
};
length : forall c. List (c) -> Int;
length = \list. case list of {
Nil => 0;
Cons x xs => 1 + length xs;
Cons x (Cons y Nil) => 2;
};

View file

@ -3,3 +3,4 @@ add x = \y. x+y;
main : Int ;
main = (\z. z+z) ((add 4) 6) ;

121
spec.txt Normal file
View file

@ -0,0 +1,121 @@
---------------------------------------------------------------------------
-- * Parser
---------------------------------------------------------------------------
data Program = Program [Def]
data Def = DSig Ident Type | DBind Bind
data Bind = Bind Ident [Ident] Exp
data Exp
= EId Ident
| ELit Lit
| EAnn Exp Type
| ELet Ident Exp Exp
| EApp Exp Exp
| EAdd Exp Exp
| EAbs Ident Exp
data Lit = LInt Integer
| LChar Character
data Type
= TLit Ident -- τ
| TVar TVar -- α
| TFun Type Type -- A → A
| TAll TVar Type -- ∀α. A
| TEVar TEVar -- ά (internal)
data TVar = MkTVar Ident
data TEVar = MkTEVar Ident
---------------------------------------------------------------------------
-- * Type checker
---------------------------------------------------------------------------
-- • Def and DSig are removed in favor on just Bind
-- • Typed expressions
-- • TEVar is removed (NOT IMPLEMENTED)
newtype Program = Program [Bind]
data Bind = Bind Id [Id] ExpT
data Exp
= EId Ident
| ELit Lit
| ELet Bind ExpT
| EApp ExpT ExpT
| EAdd ExpT ExpT
| EAbs Ident ExpT
type Id = (Ident, Type)
type ExpT = (Exp, Type)
data Lit = LInt Integer
| LChar Character
data Type
= TLit Ident -- τ
| TVar TVar -- α
| TFun Type Type -- A → A
| TAll TVar Type -- ∀α. A
data TVar = MkTVar Ident
---------------------------------------------------------------------------
-- * Lambda lifter
---------------------------------------------------------------------------
-- • EAbs are removed (NOT IMPLEMENTED)
-- • ELet only allow constant expressions (NOT IMPLEMENTED)
newtype Program = Program [Bind]
data Bind = Bind Id [Id] ExpT
data Exp
= EId Ident
| ELit Lit
| ELet Ident ExpT ExpT
| EApp ExpT ExpT
| EAdd ExpT ExpT
type Id = (Ident, Type)
type ExpT = (Exp, Type)
data Lit = LInt Integer
| LChar Character
data Type
= TLit Ident -- τ
| TVar TVar -- α
| TFun Type Type -- A → A
| TAll TVar Type -- ∀α. A
data TVar = MkTVar Ident
---------------------------------------------------------------------------
-- * Monomorpher
---------------------------------------------------------------------------
-- • Polymorphic types are removed (NOT IMPLEMENTED)
newtype Program = Program [Bind]
data Bind = Bind Id [Id] ExpT
data Exp
= EId Ident
| ELit Lit
| ELet Ident ExpT ExpT
| EApp ExpT ExpT
| EAdd ExpT ExpT
type Id = (Ident, Type)
type ExpT = (Exp, Type)
data Lit = LInt Integer
| LChar Character
data Type = Type Ident

View file

@ -3,6 +3,7 @@ module Auxiliary (module Auxiliary) where
import Control.Monad.Error.Class (liftEither)
import Control.Monad.Except (MonadError)
import Data.Either.Combinators (maybeToRight)
import TypeChecker.TypeCheckerIr (Type (TFun))
snoc :: a -> [a] -> [a]
snoc x xs = xs ++ [x]
@ -19,3 +20,4 @@ mapAccumM f = go
(acc', x') <- f acc x
(acc'', xs') <- go acc' xs
pure (acc'', x':xs')

View file

@ -17,6 +17,7 @@ import Data.Tuple.Extra (dupe, first, second)
import Debug.Trace (trace)
import qualified Grammar.Abs as GA
import Grammar.ErrM (Err)
import Monomorphizer.MonomorphizerIr (Ident (..))
import Monomorphizer.MonomorphizerIr as MIR
-- | The record used as the code generator state
@ -57,8 +58,13 @@ getVarCount :: CompilerState Integer
getVarCount = gets variableCount
-- | Increases the variable count and returns it from the CodeGenerator state
<<<<<<< HEAD
getNewVar :: CompilerState GA.Ident
getNewVar = GA.Ident . show <$> (increaseVarCount >> getVarCount)
=======
getNewVar :: CompilerState Ident
getNewVar = (Ident . show) <$> (increaseVarCount >> getVarCount)
>>>>>>> da28c6d (Add bidirectional type checker, lambda lifter.)
-- | Increses the label count and returns a label from the CodeGenerator state
getNewLabel :: CompilerState Integer
@ -76,10 +82,25 @@ getFunctions bs = Map.fromList $ go bs
go (MIR.DBind (MIR.Bind id args _) : xs) =
(id, FunctionInfo{numArgs = length args, arguments = args})
: go xs
<<<<<<< HEAD
go (_ : xs) = go xs
=======
go (MIR.DData (MIR.Data n cons) : xs) =
do map
( \(Inj id xs) ->
( (coerce id, MIR.TLit (extractTypeName n))
, FunctionInfo
{ numArgs = undefined -- TODO
, arguments = createArgs (snd <$> undefined) -- TODO
}
)
)
cons
<> go xs
>>>>>>> da28c6d (Add bidirectional type checker, lambda lifter.)
createArgs :: [MIR.Type] -> [Id]
createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(GA.Ident ("arg_" <> show l), t)], l + 1)) ([], 0) xs
createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(Ident ("arg_" <> show l), t)], l + 1)) ([], 0) xs
{- | Produces a map of functions infos from a list of binds,
which contains useful data for code generation.
@ -89,6 +110,7 @@ getConstructors bs = Map.fromList $ go bs
where
go [] = []
go (MIR.DData (MIR.Data t cons) : xs) =
<<<<<<< HEAD
fst
( foldl
( \(acc, i) (Constructor id xs) ->
@ -96,6 +118,17 @@ getConstructors bs = Map.fromList $ go bs
, ConstructorInfo
{ numArgsCI = length (init . flattenType $ xs)
, argumentsCI = createArgs (init . flattenType $ xs)
=======
do
let (Ident n) = extractTypeName t
fst
( foldl
( \(acc, i) (Inj (Ident id) xs) ->
( ( (Ident (n <> "_" <> id), MIR.TLit (coerce n))
, ConstructorInfo
{ numArgsCI = undefined -- TODO
, argumentsCI = createArgs (snd <$> undefined) -- TODO
>>>>>>> da28c6d (Add bidirectional type checker, lambda lifter.)
, numCI = i
, returnTypeCI = t --last . flattenType $ xs
}
@ -133,30 +166,30 @@ test :: Integer -> Program
test v =
Program
[ DataType
(GA.Ident "Craig")
[ Constructor (GA.Ident "Bob") [MIR.Type (GA.Ident "_Int")]
, Constructor (GA.Ident "Betty") [MIR.Type (GA.Ident "_Int")]
(Ident "Craig")
[ Constructor (Ident "Bob") [MIR.Type (Ident "_Int")]
, Constructor (Ident "Betty") [MIR.Type (Ident "_Int")]
]
, DataType
(GA.Ident "Alice")
[ Constructor (GA.Ident "Eve") [MIR.Type (GA.Ident "_Int")] -- ,
-- (GA.Ident "Alice", [TInt, TInt])
(Ident "Alice")
[ Constructor (Ident "Eve") [MIR.Type (Ident "_Int")] -- ,
-- (Ident "Alice", [TInt, TInt])
]
, Bind (GA.Ident "fibonacci", MIR.Type (GA.Ident "_Int")) [(GA.Ident "x", MIR.Type (GA.Ident "_Int"))] (EId ("x", MIR.Type (GA.Ident "Craig")), MIR.Type (GA.Ident "Craig"))
, Bind (GA.Ident "main", MIR.Type (GA.Ident "_Int")) []
-- (EApp (MIR.Type (GA.Ident "Craig")) (EId (GA.Ident "Craig_Bob", MIR.Type (GA.Ident "Craig")), MIR.Type (GA.Ident "Craig")) (ELit (LInt v), MIR.Type (GA.Ident "_Int")), MIR.Type (GA.Ident "Craig"))-- (EInt 92)
, Bind (Ident "fibonacci", MIR.Type (Ident "_Int")) [(Ident "x", MIR.Type (Ident "_Int"))] (EVar ("x", MIR.Type (Ident "Craig")), MIR.Type (Ident "Craig"))
, Bind (Ident "main", MIR.Type (Ident "_Int")) []
-- (EApp (MIR.Type (Ident "Craig")) (EVar (Ident "Craig_Bob", MIR.Type (Ident "Craig")), MIR.Type (Ident "Craig")) (ELit (LInt v), MIR.Type (Ident "_Int")), MIR.Type (Ident "Craig"))-- (EInt 92)
$
eCaseInt
(EApp (MIR.TLit (GA.Ident "Craig")) (EId (GA.Ident "Craig_Bob", MIR.TLit (GA.Ident "Craig")), MIR.TLit (GA.Ident "Craig")) (ELit (LInt v), MIR.Type (GA.Ident "_Int")), MIR.Type (GA.Ident "Craig"))
[ injectionCons "Craig_Bob" "Craig" [CIdent (GA.Ident "x")] (EId (GA.Ident "x", MIR.Type (GA.Ident "_Int")), MIR.Type (GA.Ident "_Int"))
(EApp (MIR.TLit (Ident "Craig")) (EVar (Ident "Craig_Bob", MIR.TLit (Ident "Craig")), MIR.TLit (Ident "Craig")) (ELit (LInt v), MIR.Type (Ident "_Int")), MIR.Type (Ident "Craig"))
[ injectionCons "Craig_Bob" "Craig" [CIdent (Ident "x")] (EVar (Ident "x", MIR.Type (Ident "_Int")), MIR.Type (Ident "_Int"))
, injectionCons "Craig_Betty" "Craig" [CLit (LInt 5)] (int 2)
, Injection (CIdent (GA.Ident "z")) (int 3)
, Injection (CIdent (Ident "z")) (int 3)
, -- , injectionInt 5 (int 6)
injectionCatchAll (int 10)
]
]
where
injectionCons x y xs = Injection (CCons (GA.Ident x, MIR.Type (GA.Ident y)) xs)
injectionCons x y xs = Injection (CCons (Ident x, MIR.Type (Ident y)) xs)
injectionInt x = Injection (CLit (LInt x))
injectionCatchAll = Injection CatchAll
eCaseInt x xs = (ECase (MIR.TLit (MIR.Ident "_Int")) x xs, MIR.TLit (MIR.Ident "_Int"))
@ -206,7 +239,7 @@ compileScs [] = do
emit $ SetVariable ptr' (Bitcast (Ref t') (VIdent top Ptr) (Ref $ CustomType id))
enumerateOneM_
( \i (GA.Ident arg_n, arg_t) -> do
( \i (Ident arg_n, arg_t) -> do
let arg_t' = type2LlvmType arg_t
emit $ Comment (toIr arg_t' <> " " <> arg_n <> " " <> show i)
elemPtr <- getNewVar
@ -222,7 +255,7 @@ compileScs [] = do
I32
(VInteger i)
)
emit $ Store arg_t' (VIdent (GA.Ident arg_n) arg_t') Ptr elemPtr
emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr elemPtr
)
(argumentsCI ci)
@ -255,8 +288,13 @@ compileScs (MIR.DData (MIR.Data typ ts) : xs) = do
let biggestVariant = 7 + maximum (sum . (\(Constructor _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
emit $ LIR.Type (Ident outer_id) [I8, Array biggestVariant I8]
mapM_
<<<<<<< HEAD
( \(Constructor inner_id fi) -> do
emit $ LIR.Type inner_id (I8 : variantTypes fi)
=======
( \(Inj (Ident inner_id) fi) -> do
emit $ LIR.Type (Ident $ outer_id <> "_" <> inner_id) (I8 : map type2LlvmType (snd <$> undefined)) -- TODO
>>>>>>> da28c6d (Add bidirectional type checker, lambda lifter.)
)
ts
compileScs xs
@ -282,17 +320,17 @@ mainContent var =
-- " %4 = load i72, ptr %3\n" <>
-- " call i32 (ptr, ...) @printf(ptr noundef @.str, i72 noundef %4)\n"
"call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> toIr var <> ")\n"
, -- , SetVariable (GA.Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
-- , BrCond (VIdent (GA.Ident "p")) (GA.Ident "b_1") (GA.Ident "b_2")
-- , Label (GA.Ident "b_1")
, -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
-- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2")
-- , Label (Ident "b_1")
-- , UnsafeRaw
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n"
-- , Br (GA.Ident "end")
-- , Label (GA.Ident "b_2")
-- , Br (Ident "end")
-- , Label (Ident "b_2")
-- , UnsafeRaw
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n"
-- , Br (GA.Ident "end")
-- , Label (GA.Ident "end")
-- , Br (Ident "end")
-- , Label (Ident "end")
Ret I64 (VInteger 0)
]
@ -310,7 +348,7 @@ compileExp :: ExpT -> CompilerState ()
compileExp (MIR.ELit lit,t) = emitLit lit
compileExp (MIR.EAdd e1 e2,t) = emitAdd t e1 e2
-- compileExp (ESub t e1 e2) = emitSub t e1 e2
compileExp (MIR.EId name,t) = emitIdent name
compileExp (MIR.EVar name,t) = emitIdent name
compileExp (MIR.EApp e1 e2,t) = emitApp t e1 e2
-- compileExp (EAbs t ti e) = emitAbs t ti e
compileExp (MIR.ELet binds e,t) = undefined -- emitLet binds (fst e)
@ -328,7 +366,7 @@ emitECased t e cases = do
let rt = type2LlvmType (snd e)
vs <- exprToValue e
lbl <- getNewLabel
let label = GA.Ident $ "escape_" <> show lbl
let label = Ident $ "escape_" <> show lbl
stackPtr <- getNewVar
emit $ SetVariable stackPtr (Alloca ty)
mapM_ (emitCases rt ty label stackPtr vs) cs
@ -341,13 +379,13 @@ emitECased t e cases = do
res <- getNewVar
emit $ SetVariable res (Load ty Ptr stackPtr)
where
emitCases :: LLVMType -> LLVMType -> GA.Ident -> GA.Ident -> LLVMValue -> Branch -> CompilerState ()
emitCases :: LLVMType -> LLVMType -> Ident -> Ident -> LLVMValue -> Branch -> CompilerState ()
emitCases rt ty label stackPtr vs (Branch (MIR.PInj consId cs, t) exp) = do
cons <- gets constructors
let r = fromJust $ Map.lookup consId cons
lbl_failPos <- (\x -> GA.Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> GA.Ident $ "success_" <> show x) <$> getNewLabel
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
consVal <- getNewVar
emit $ SetVariable consVal (ExtractValue rt vs 0)
@ -397,8 +435,8 @@ emitECased t e cases = do
(MIR.LInt i, _) -> VInteger i
(MIR.LChar i, _) -> VChar i
ns <- getNewVar
lbl_failPos <- (\x -> GA.Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> GA.Ident $ "success_" <> show x) <$> getNewLabel
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
emit $ SetVariable ns (Icmp LLEq ty vs i')
emit $ BrCond (VIdent ns ty) lbl_succPos lbl_failPos
emit $ Label lbl_succPos
@ -444,8 +482,13 @@ emitApp rt e1 e2 = appEmitter e1 e2 []
appEmitter e1 e2 stack = do
let newStack = e2 : stack
case e1 of
<<<<<<< HEAD
(MIR.EApp e1' e2', _) -> appEmitter e1' e2' newStack
(MIR.EId name, t) -> do
=======
(MIR.EApp e1' e2', t) -> appEmitter e1' e2' newStack
(MIR.EVar name, t) -> do
>>>>>>> da28c6d (Add bidirectional type checker, lambda lifter.)
args <- traverse exprToValue newStack
vs <- getNewVar
funcs <- gets functions
@ -462,7 +505,7 @@ emitApp rt e1 e2 = appEmitter e1 e2 []
emit $ SetVariable vs call
x -> error $ "The unspeakable happened: " <> show x
emitIdent :: GA.Ident -> CompilerState ()
emitIdent :: Ident -> CompilerState ()
emitIdent id = do
-- !!this should never happen!!
emit $ Comment "This should not have happened!"
@ -477,14 +520,14 @@ emitLit i = do
(MIR.LChar i'') -> (VChar i'', I8)
varCount <- getNewVar
emit $ Comment "This should not have happened!"
emit $ SetVariable (GA.Ident (show varCount)) (Add t i' (VInteger 0))
emit $ SetVariable (Ident (show varCount)) (Add t i' (VInteger 0))
emitAdd :: MIR.Type -> ExpT -> ExpT -> CompilerState ()
emitAdd t e1 e2 = do
v1 <- exprToValue e1
v2 <- exprToValue e2
v <- getNewVar
emit $ SetVariable (GA.Ident $ show v) (Add (type2LlvmType t) v1 v2)
emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2)
emitSub :: MIR.Type -> ExpT -> ExpT -> CompilerState ()
emitSub t e1 e2 = do
@ -498,7 +541,7 @@ exprToValue = \case
(MIR.ELit i, t) -> pure $ case i of
(MIR.LInt i) -> VInteger i
(MIR.LChar i) -> VChar i
(MIR.EId name, t) -> do
(MIR.EVar name, t) -> do
funcs <- gets functions
case Map.lookup (name, t) funcs of
Just fi -> do
@ -515,7 +558,7 @@ exprToValue = \case
e -> do
compileExp e
v <- getVarCount
pure $ VIdent (GA.Ident $ show v) (getType e)
pure $ VIdent (Ident $ show v) (getType e)
type2LlvmType :: MIR.Type -> LLVMType
type2LlvmType (MIR.TLit id@(Ident name)) = case name of
@ -558,3 +601,4 @@ typeByteSize (CustomType _) = 8
enumerateOneM_ :: Monad m => (Integer -> a -> m b) -> [a] -> m ()
enumerateOneM_ f = foldM_ (\i a -> f i a >> pure (i + 1)) 1

View file

@ -11,8 +11,9 @@ module Codegen.LlvmIr (
ToIr(..)
) where
import Data.List (intercalate)
import Grammar.Abs (Ident (..))
import Data.List (intercalate)
import Grammar.Abs (Character)
import TypeChecker.TypeCheckerIr (Ident (..))
data CallingConvention = TailCC | FastCC | CCC | ColdCC deriving Show
instance ToIr CallingConvention where
@ -87,7 +88,7 @@ instance ToIr Visibility where
-- or a string contstant
data LLVMValue
= VInteger Integer
| VChar Char
| VChar Character
| VIdent Ident LLVMType
| VConstant String
| VFunction Ident Visibility LLVMType

242
src/LambdaLifter.hs Normal file
View file

@ -0,0 +1,242 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module LambdaLifter (lambdaLift, freeVars, abstract, collectScs) where
import Auxiliary (snoc)
import Control.Applicative (Applicative (liftA2))
import Control.Monad.State (MonadState (get, put), State,
evalState)
import Data.List (partition)
import Data.Set (Set)
import qualified Data.Set as Set
import Prelude hiding (exp)
import TypeChecker.TypeCheckerIr
-- | Lift lambdas and let expression into supercombinators.
-- Three phases:
-- @freeVars@ annotates all the free variables.
-- @abstract@ converts lambdas into let expressions.
-- @collectScs@ moves every non-constant let expression to a top-level function.
lambdaLift :: Program -> Program
lambdaLift (Program defs) = Program $ datatypes ++ ll binds
where
ll = map DBind . collectScs . abstract . freeVars . map (\(DBind b) -> b)
(binds, datatypes) = partition isBind defs
isBind = \case
DBind _ -> True
_ -> False
-- | Annotate free variables
freeVars :: [Bind] -> AnnBinds
freeVars binds = [ (n, xs, freeVarsExp (Set.fromList $ map fst xs) e)
| Bind n xs e <- binds
]
freeVarsExp :: Set Ident -> ExpT -> AnnExpT
freeVarsExp localVars (exp, t) = case exp of
EVar n | Set.member n localVars -> (Set.singleton n, (AVar n, t))
| otherwise -> (mempty, (AVar n, t))
ELit lit -> (mempty, (ALit lit, t))
EApp e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), (AApp e1' e2', t))
where
e1' = freeVarsExp localVars e1
e2' = freeVarsExp localVars e2
EAdd e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), (AAdd e1' e2', t))
where
e1' = freeVarsExp localVars e1
e2' = freeVarsExp localVars e2
EAbs par e -> (Set.delete par $ freeVarsOf e', (AAbs par e', t))
where
e' = freeVarsExp (Set.insert par localVars) e
-- Sum free variables present in bind and the expression
ELet (Bind (name, t_bind) parms rhs) e -> (Set.union binders_frees e_free, (ALet new_bind e', t))
where
binders_frees = Set.delete name $ freeVarsOf rhs'
e_free = Set.delete name $ freeVarsOf e'
rhs' = freeVarsExp e_localVars rhs
new_bind = ABind (name, t_bind) parms rhs'
e' = freeVarsExp e_localVars e
e_localVars = Set.insert name localVars
freeVarsOf :: AnnExpT -> Set Ident
freeVarsOf = fst
-- AST annotated with free variables
type AnnBinds = [(Id, [Id], AnnExpT)]
type AnnExpT = (Set Ident, AnnExpT')
data ABind = ABind Id [Id] AnnExpT deriving Show
type AnnExpT' = (AnnExp, Type)
data AnnExp = AVar Ident
| AInj Ident
| ALit Lit
| ALet ABind AnnExpT
| AApp AnnExpT AnnExpT
| AAdd AnnExpT AnnExpT
| AAbs Ident AnnExpT
deriving Show
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
-- Free variables are @v₁ v₂ .. vₙ@ are bound.
abstract :: AnnBinds -> [Bind]
abstract prog = evalState (mapM go prog) 0
where
go :: (Id, [Id], AnnExpT) -> State Int Bind
go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs'
where
(rhs', parms1) = flattenLambdasAnn rhs
-- | Flatten nested lambdas and collect the parameters
-- @\x.\y.\z. ae → (ae, [x,y,z])@
flattenLambdasAnn :: AnnExpT -> (AnnExpT, [Id])
flattenLambdasAnn ae = go (ae, [])
where
go :: (AnnExpT, [Id]) -> (AnnExpT, [Id])
go ((free, (e, t)), acc)
| AAbs par (free1, e1) <- e
, TFun t_par _ <- t
= go ((Set.delete par free1, e1), snoc (par, t_par) acc)
| otherwise = ((free, (e, t)), acc)
abstractExp :: AnnExpT -> State Int ExpT
abstractExp (free, (exp, typ)) = case exp of
AVar n -> pure (EVar n, typ)
ALit lit -> pure (ELit lit, typ)
AApp e1 e2 -> (, typ) <$> liftA2 EApp (abstractExp e1) (abstractExp e2)
AAdd e1 e2 -> (, typ) <$> liftA2 EAdd (abstractExp e1) (abstractExp e2)
ALet b e -> (, typ) <$> liftA2 ELet (go b) (abstractExp e)
where
go (ABind name parms rhs) = do
(rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
pure $ Bind name (parms ++ parms1) rhs'
skipLambdas :: (AnnExpT -> State Int ExpT) -> AnnExpT -> State Int ExpT
skipLambdas f (free, (ae, t)) = case ae of
AAbs par ae1 -> do
ae1' <- skipLambdas f ae1
pure (EAbs par ae1', t)
_ -> f (free, (ae, t))
-- Lift lambda into let and bind free variables
AAbs parm e -> do
i <- nextNumber
rhs <- abstractExp e
let sc_name = Ident ("sc_" ++ show i)
sc = (ELet (Bind (sc_name, typ) vars rhs) (EVar sc_name, typ), typ)
pure $ foldl applyVars sc freeList
where
freeList = Set.toList free
vars = zip names $ getVars typ
names = snoc parm freeList
applyVars (e, t) name = (EApp (e, t) (EVar name, t_var), t_return)
where
(t_var, t_return) = applyVarType t
applyVarType :: Type -> (Type, Type)
applyVarType typ = (t1, foldr ($) t2 foralls)
where
(t1, t2) = case typ' of
TFun t1 t2 -> (t1, t2)
_ -> error "Not a function!"
(foralls, typ') = skipForalls [] typ
skipForalls acc = \case
TAll tvar t -> skipForalls (snoc (TAll tvar) acc) t
t -> (acc, t)
nextNumber :: State Int Int
nextNumber = do
i <- get
put $ succ i
pure i
-- | Collects supercombinators by lifting non-constant let expressions
collectScs :: [Bind] -> [Bind]
collectScs = concatMap collectFromRhs
where
collectFromRhs (Bind name parms rhs) =
let (rhs_scs, rhs') = collectScsExp rhs
in Bind name parms rhs' : rhs_scs
collectScsExp :: ExpT -> ([Bind], ExpT)
collectScsExp expT@(exp, typ) = case exp of
EVar _ -> ([], expT)
ELit _ -> ([], expT)
EApp e1 e2 -> (scs1 ++ scs2, (EApp e1' e2', typ))
where
(scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2
EAdd e1 e2 -> (scs1 ++ scs2, (EAdd e1' e2', typ))
where
(scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2
EAbs par e -> (scs, (EAbs par e', typ))
where
(scs, e') = collectScsExp e
-- Collect supercombinators from bind, the rhss, and the expression.
--
-- > f = let sc x y = rhs in e
--
ELet (Bind name parms rhs) e -> if null parms
then ( rhs_scs ++ et_scs, (ELet bind et', snd et'))
else (bind : rhs_scs ++ et_scs, et')
where
bind = Bind name parms rhs'
(rhs_scs, rhs') = collectScsExp rhs
(et_scs, et') = collectScsExp e
-- @\x.\y.\z. e → (e, [x,y,z])@
flattenLambdas :: ExpT -> (ExpT, [Id])
flattenLambdas = go . (, [])
where
go ((e, t), acc) = case e of
EAbs name e1 -> go (e1, snoc (name, t_var) acc)
where t_var = head $ getVars t
_ -> ((e, t), acc)
getVars :: Type -> [Type]
getVars = fst . partitionType
partitionType :: Type -> ([Type], Type)
partitionType = go [] . skipForalls'
where
go acc t = case t of
TFun t1 t2 -> go (snoc t1 acc) t2
_ -> (acc, t)
skipForalls' :: Type -> Type
skipForalls' = snd . skipForalls
skipForalls :: Type -> ([Type -> Type], Type)
skipForalls = go []
where
go acc typ = case typ of
TAll tvar t -> go (snoc (TAll tvar) acc) t
_ -> (acc, typ)

View file

@ -1,66 +1,114 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
module Main where
import Codegen.Codegen (generateCode)
import Data.Bool (bool)
import GHC.IO.Handle.Text (hPutStrLn)
import Grammar.ErrM (Err)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
import Monomorphizer.Monomorphizer (monomorphize)
import Control.Monad (when)
import Data.Bool (bool)
import Data.List.Extra (isSuffixOf)
import Compiler (compile)
import Renamer.Renamer (rename)
import Data.Maybe (fromJust, isNothing)
import GHC.IO.Handle.Text (hPutStrLn)
import System.Console.GetOpt (ArgDescr (NoArg, ReqArg),
ArgOrder (RequireOrder),
OptDescr (Option), getOpt,
usageInfo)
import System.Directory (createDirectory, doesPathExist,
getDirectoryContents,
removeDirectoryRecursive,
setCurrentDirectory)
import System.Environment (getArgs)
import System.Exit (ExitCode, exitFailure,
exitSuccess)
import System.Exit (ExitCode (ExitFailure),
exitFailure, exitSuccess,
exitWith)
import System.IO (stderr)
import System.Process.Extra (readCreateProcess, shell,
spawnCommand, waitForProcess)
import TypeChecker.TypeChecker (typecheck)
import Codegen.Codegen (generateCode)
import Compiler (compile)
import Grammar.ErrM (Err)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
import LambdaLifter (lambdaLift)
import Monomorphizer.Monomorphizer (monomorphize)
import Renamer.Renamer (rename)
import System.Process (spawnCommand, waitForProcess)
import TypeChecker.TypeChecker (TypeChecker (Bi, Hm), typecheck)
main :: IO ()
main =
getArgs >>= \case
[] -> putStrLn "Required file path missing"
["-d", s] -> do
when (".crf" `isSuffixOf` s) (main' True s)
putStrLn $ "File '" ++ s ++ "' is not a churf file"
[s] -> do
when (".crf" `isSuffixOf` s) (main' False s)
putStrLn $ "File '" ++ s ++ "' is not a churf file"
xs -> putStrLn $ "Can't process: " ++ unwords xs
main = getArgs >>= parseArgs >>= uncurry main'
main' :: Bool -> String -> IO ()
main' debug s = do
parseArgs :: [String] -> IO (Options, String)
parseArgs argv = case getOpt RequireOrder flags argv of
(os, f:_, [])
| opts.help || isNothing opts.typechecker -> do
hPutStrLn stderr (usageInfo header flags)
exitSuccess
| otherwise -> pure (opts, f)
where
opts = foldr ($) initOpts os
(_, _, errs) -> do
hPutStrLn stderr (concat errs ++ usageInfo header flags)
exitWith (ExitFailure 1)
where
header = "Usage: language [--help] [-d|--debug] [-t|type-checker bi/hm] FILE \n"
flags :: [OptDescr (Options -> Options)]
flags =
[ Option ['d'] ["debug"] (NoArg enableDebug) "Print debug messages."
, Option ['t'] ["type-checker"] (ReqArg chooseTypechecker "bi/hm") "Choose type checker. Possible options are bi and hm"
, Option [] ["help"] (NoArg enableHelp) "Print this help message"
]
initOpts :: Options
initOpts = Options { help = False
, debug = False
, typechecker = Nothing
}
enableHelp :: Options -> Options
enableHelp opts = opts { help = True }
enableDebug :: Options -> Options
enableDebug opts = opts { debug = True }
chooseTypechecker :: String -> Options -> Options
chooseTypechecker s options = options { typechecker = tc }
where
tc = case s of
"hm" -> pure Hm
"bi" -> pure Bi
_ -> Nothing
data Options = Options
{ help :: Bool
, debug :: Bool
, typechecker :: Maybe TypeChecker
}
main' :: Options -> String -> IO ()
main' opts s = do
file <- readFile s
printToErr "-- Parse Tree -- "
parsed <- fromSyntaxErr . pProgram $ myLexer file
bool (printToErr $ printTree parsed) (printToErr $ show parsed) debug
bool (printToErr $ printTree parsed) (printToErr $ show parsed) opts.debug
printToErr "\n-- Renamer --"
renamed <- fromRenamerErr . rename $ parsed
bool (printToErr $ printTree renamed) (printToErr $ show renamed) debug
bool (printToErr $ printTree renamed) (printToErr $ show renamed) opts.debug
printToErr "\n-- TypeChecker --"
typechecked <- fromTypeCheckerErr $ typecheck renamed
bool (printToErr $ printTree typechecked) (printToErr $ show typechecked) debug
typechecked <- fromTypeCheckerErr $ typecheck (fromJust opts.typechecker) renamed
bool (printToErr $ printTree typechecked) (printToErr $ show typechecked) opts.debug
printToErr "\n-- Lambda Lifter --"
let lifted = lambdaLift typechecked
printToErr $ printTree lifted
-- printToErr "\n-- Lambda Lifter --"
-- let lifted = lambdaLift typechecked
-- printToErr $ printTree lifted
--
--printToErr "\n -- Compiler --"
printToErr "\n -- Compiler --"
generatedCode <- fromCompilerErr $ generateCode (monomorphize typechecked)
--putStrLn generatedCode

View file

@ -1,12 +1,13 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Monomorphizer.Monomorphizer (monomorphize) where
import Data.Coerce (coerce)
import Data.Coerce (coerce)
import Monomorphizer.MonomorphizerIr qualified as M
import TypeChecker.TypeCheckerIr qualified as T
import qualified Monomorphizer.MonomorphizerIr as M
import qualified TypeChecker.TypeCheckerIr as T
import TypeChecker.TypeCheckerIr (Ident (..))
monomorphize :: T.Program -> M.Program
monomorphize (T.Program ds) = M.Program $ monoDefs ds
@ -16,40 +17,40 @@ monoDefs = map monoDef
monoDef :: T.Def -> M.Def
monoDef (T.DBind bind) = M.DBind $ monoBind bind
monoDef (T.DData d) = M.DData $ monoData d
--monoDef (T.DData d) = M.DData $ monoData d
monoBind :: T.Bind -> M.Bind
monoBind (T.Bind name args (e, t)) = M.Bind (monoId name) (map monoId args) (monoExpr e, monoType t)
monoData :: T.Data -> M.Data
monoData (T.Data (T.Ident id) cs) = M.Data (M.TLit (M.Ident id)) (map monoConstructor cs)
--monoData :: T.Data -> M.Data
--monoData (T.Data (Ident id) cs) = M.Data (M.TLit (M.Ident id)) (map monoConstructor cs)
monoConstructor :: T.Constructor -> M.Constructor
monoConstructor (T.Constructor (T.Ident i) t) = M.Constructor (M.Ident i) (monoType t)
monoConstructor :: T.Inj -> M.Inj
monoConstructor (T.Inj (Ident i) t) = M.Inj (M.Ident i) (monoType t)
monoExpr :: T.Exp -> M.Exp
monoExpr = \case
T.EId (T.Ident i) -> M.EId (M.Ident i)
T.ELit lit -> M.ELit $ monoLit lit
T.ELet bind expt -> M.ELet (monoBind bind) (monoexpt expt)
T.EVar (Ident i) -> M.EVar (M.Ident i)
T.ELit lit -> M.ELit $ monoLit lit
T.ELet bind expt -> M.ELet (monoBind bind) (monoexpt expt)
T.EApp expt1 expt2 -> M.EApp (monoexpt expt1) (monoexpt expt2)
T.EAdd expt1 expt2 -> M.EAdd (monoexpt expt1) (monoexpt expt2)
T.EAbs _i _expt -> error "BUG"
T.ECase expt injs -> M.ECase (monoexpt expt) (monoInjs injs)
T.EAbs _i _expt -> error "BUG"
T.ECase expt injs -> M.ECase (monoexpt expt) (monoInjs injs)
monoAbsType :: T.Type -> M.Type
monoAbsType (T.TLit u) = M.TLit (coerce u)
monoAbsType (T.TVar _v) = M.TLit "Int"
monoAbsType (T.TLit u) = M.TLit (coerce u)
monoAbsType (T.TVar _v) = M.TLit "Int"
monoAbsType (T.TAll _v _t) = error "NOT ALL TYPES"
monoAbsType (T.TFun t1 t2) = M.TFun (monoAbsType t1) (monoAbsType t2)
monoAbsType (T.TData _ _) = error "NOT INDEXED TYPES"
monoAbsType (T.TData _ _) = error "NOT INDEXED TYPES"
monoType :: T.Type -> M.Type
monoType (T.TAll _ t) = monoType t
monoType (T.TAll _ t) = monoType t
monoType (T.TVar (T.MkTVar i)) = M.TLit "Int"
monoType (T.TLit (T.Ident i)) = M.TLit (M.Ident i)
monoType (T.TFun t1 t2) = M.TFun (monoType t1) (monoType t2)
monoType (T.TData (T.Ident n) t) = M.TLit (M.Ident (n ++ concatMap show t))
monoType (T.TLit (Ident i)) = M.TLit (M.Ident i)
monoType (T.TFun t1 t2) = M.TFun (monoType t1) (monoType t2)
monoType (T.TData (Ident n) t) = M.TLit (M.Ident (n ++ concatMap show t))
monoexpt :: T.ExpT -> M.ExpT
monoexpt (e, t) = (monoExpr e, monoType t)
@ -58,19 +59,19 @@ monoId :: T.Id -> M.Id
monoId (n, t) = (coerce n, monoType t)
monoLit :: T.Lit -> M.Lit
monoLit (T.LInt i) = M.LInt i
monoLit (T.LInt i) = M.LInt i
monoLit (T.LChar c) = M.LChar c
monoInjs :: [T.Branch] -> [M.Branch]
monoInjs = map monoInj
monoInj :: T.Branch -> M.Branch
monoInj (T.Branch (init, t) expt) = M.Branch (monoInit init, monoType t) (monoexpt expt)
monoInj (T.Branch (patt, t) expt) = M.Branch (monoPattern patt, monoType t) (monoexpt expt)
monoInit :: T.Pattern -> M.Pattern
monoInit (T.PVar (id, t)) = M.PVar (coerce id, monoType t)
monoInit (T.PLit (lit, t)) = M.PLit (monoLit lit, monoType t)
monoInit (T.PInj id ps) = M.PInj (coerce id) (monoInit <$> ps)
monoPattern :: T.Pattern -> M.Pattern
monoPattern (T.PVar (id, t)) = M.PVar (id, monoType t)
monoPattern (T.PLit (lit, t)) = M.PLit (monoLit lit, monoType t)
monoPattern (T.PInj id ps) = M.PInj (coerce id) (map monoPattern ps)
-- DO NOT DO THIS FOR REAL THOUGH
monoInit (T.PEnum (T.Ident i)) = M.PInj (M.Ident i) []
monoInit T.PCatch = M.PCatch
monoPattern (T.PEnum (Ident i)) = M.PInj (M.Ident i) []
monoPattern T.PCatch = M.PCatch

View file

@ -11,14 +11,14 @@ newtype Program = Program [Def]
data Def = DBind Bind | DData Data
deriving (Show, Ord, Eq)
data Data = Data Type [Constructor]
data Data = Data Type [Inj]
deriving (Show, Ord, Eq)
data Bind = Bind Id [Id] ExpT
deriving (Show, Ord, Eq)
data Exp
= EId Ident
= EVar Ident
| ELit Lit
| ELet Bind ExpT
| EApp ExpT ExpT
@ -35,12 +35,12 @@ data Branch = Branch (Pattern, Type) ExpT
type ExpT = (Exp, Type)
data Constructor = Constructor Ident Type
data Inj = Inj Ident Type
deriving (Show, Ord, Eq)
data Lit
= LInt Integer
| LChar Char
| LChar Character
deriving (Show, Ord, Eq)
data Type = TLit Ident | TFun Type Type

View file

@ -1,131 +1,124 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use mapAndUnzipM" #-}
module Renamer.Renamer (rename) where
import Auxiliary (mapAccumM)
import Control.Applicative (Applicative (liftA2))
import Control.Monad (foldM)
import Control.Monad.Except (ExceptT, MonadError, runExceptT, throwError)
import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.State (
MonadState,
StateT,
evalStateT,
gets,
modify,
)
import Data.Coerce (coerce)
import Data.Function (on)
import Data.Map (Map)
import Data.Map qualified as Map
import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe)
import Grammar.Abs
import Auxiliary (mapAccumM)
import Control.Applicative (Applicative (liftA2))
import Control.Monad.Except (ExceptT, MonadError (throwError),
runExceptT)
import Control.Monad.State (MonadState, State, evalState, gets,
mapAndUnzipM, modify)
import Data.Function (on)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe, second)
import Grammar.Abs
import Grammar.ErrM (Err)
-- | Rename all variables and local binds
rename :: Program -> Either String Program
rename :: Program -> Err Program
rename (Program defs) = Program <$> renameDefs defs
renameDefs :: [Def] -> Either String [Def]
renameDefs defs = runIdentity $ runExceptT $ evalStateT (runRn $ mapM renameDef defs) initCxt
where
initNames = Map.fromList [dupe (coerce name) | DBind (Bind name _ _) <- defs]
renameDef :: Def -> Rn Def
renameDef = \case
DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ
DBind bind -> DBind . snd <$> renameBind initNames bind
DData (Data (TData cname types) constrs) -> do
tvars_ <- tvars
tvars' <- mapM nextNameTVar tvars_
let tvars_lt = zip tvars_ tvars'
typ' = map (substituteTVar tvars_lt) types
constrs' = map (renameConstr tvars_lt) constrs
pure . DData $ Data (TData cname typ') constrs'
where
tvars = concat <$> mapM (collectTVars []) types
collectTVars :: [TVar] -> Type -> Rn [TVar]
collectTVars tvars = \case
TAll tvar t -> collectTVars (tvar : tvars) t
TData _ _ -> return tvars
-- Should be monad error
TVar v -> return [v]
_ -> throwError ("Bad data type definition: " ++ show types)
DData (Data types _) -> throwError ("Bad data type definition: " ++ show types)
renameConstr :: [(TVar, TVar)] -> Constructor -> Constructor
renameConstr new_types (Constructor name typ) =
Constructor name $ substituteTVar new_types typ
renameBind :: Names -> Bind -> Rn (Names, Bind)
renameBind old_names (Bind name vars rhs) = do
(new_names, vars') <- newNames old_names (coerce vars)
(newer_names, rhs') <- renameExp new_names rhs
pure (newer_names, Bind name (coerce vars') rhs')
substituteTVar :: [(TVar, TVar)] -> Type -> Type
substituteTVar new_names typ = case typ of
TLit _ -> typ
TVar tvar
| Just tvar' <- lookup tvar new_names ->
TVar tvar'
| otherwise ->
typ
TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t
| Just tvar' <- lookup tvar new_names ->
TAll tvar' $ substitute' t
| otherwise ->
TAll tvar $ substitute' t
TData name typs -> TData name $ map substitute' typs
_ -> error ("Impossible " ++ show typ)
where
substitute' = substituteTVar new_names
initCxt :: Cxt
initCxt = Cxt 0 0
data Cxt = Cxt
{ var_counter :: Int
, tvar_counter :: Int
}
data Cxt = Cxt { var_counter :: Int
, tvar_counter :: Int
}
-- | Rename monad. State holds the number of renamed names.
newtype Rn a = Rn {runRn :: StateT Cxt (ExceptT String Identity) a}
deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String)
newtype Rn a = Rn { runRn :: ExceptT String (State Cxt) a }
deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String)
-- | Maps old to new name
type Names = Map LIdent LIdent
type Names = Map String String
renameDefs :: [Def] -> Err [Def]
renameDefs defs = evalState (runExceptT (runRn $ mapM renameDef defs)) initCxt
where
initNames = Map.fromList [ dupe s | DBind (Bind (LIdent s) _ _) <- defs]
renameDef :: Def -> Rn Def
renameDef = \case
DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ
DBind (Bind name vars rhs) -> do
(new_names, vars') <- newNamesL initNames vars
rhs' <- snd <$> renameExp new_names rhs
pure . DBind $ Bind name vars' rhs'
DData (Data typ injs) -> do
tvars <- collectTVars [] typ
tvars' <- mapM nextNameTVar tvars
let tvars_lt = zip tvars tvars'
typ' = substituteTVar tvars_lt typ
injs' = map (renameInj tvars_lt) injs
pure . DData $ Data typ' injs'
where
collectTVars tvars = \case
TAll tvar t -> collectTVars (tvar:tvars) t
TData _ _ -> pure tvars
_ -> throwError ("Bad data type definition: " ++ show typ)
renameInj :: [(TVar, TVar)] -> Inj -> Inj
renameInj new_types (Inj name typ) =
Inj name $ substituteTVar new_types typ
substituteTVar :: [(TVar, TVar)] -> Type -> Type
substituteTVar new_names typ = case typ of
TLit _ -> typ
TVar tvar | Just tvar' <- lookup tvar new_names
-> TVar tvar'
| otherwise
-> typ
TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t | Just tvar' <- lookup tvar new_names
-> TAll tvar' $ substitute' t
| otherwise
-> TAll tvar $ substitute' t
TData name typs -> TData name $ map substitute' typs
_ -> error ("Impossible " ++ show typ)
where
substitute' = substituteTVar new_names
renameExp :: Names -> Exp -> Rn (Names, Exp)
renameExp old_names = \case
EVar n -> pure (coerce old_names, EVar . fromMaybe n $ Map.lookup n old_names)
EInj n -> pure (old_names, EInj n)
ELit lit -> pure (old_names, ELit lit)
EVar (LIdent n) -> pure (old_names, EVar . LIdent . fromMaybe n $ Map.lookup n old_names)
EInj (UIdent n) -> pure (old_names, EInj . UIdent . fromMaybe n $ Map.lookup n old_names)
ELit lit -> pure (old_names, ELit lit)
EApp e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EApp e1' e2')
EAdd e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EAdd e1' e2')
-- TODO fix shadowing
ELet bind e -> do
(new_names, bind') <- renameBind old_names bind
(new_names', e') <- renameExp new_names e
pure (new_names', ELet bind' e')
EAbs par e -> do
(new_names, par') <- newName old_names (coerce par)
(new_names', e') <- renameExp new_names e
pure (new_names', EAbs (coerce par') e')
ELet (Bind name vars rhs) e -> do
(new_names, name') <- newNameL old_names name
(new_names', vars') <- newNamesL new_names vars
(new_names'', rhs') <- renameExp new_names' rhs
(new_names''', e') <- renameExp new_names'' e
pure (new_names''', ELet (Bind name' vars' rhs') e')
EAbs par e -> do
(new_names, par') <- newNameL old_names par
(new_names', e') <- renameExp new_names e
pure (new_names', EAbs par' e')
EAnn e t -> do
(new_names, e') <- renameExp old_names e
t' <- renameTVars t
@ -137,26 +130,23 @@ renameExp old_names = \case
renameBranches :: Names -> [Branch] -> Rn (Names, [Branch])
renameBranches ns xs = do
(new_names, xs') <- unzip <$> mapM (renameBranch ns) xs
(new_names, xs') <- mapAndUnzipM (renameBranch ns) xs
if null new_names then return (mempty, xs') else return (head new_names, xs')
renameBranch :: Names -> Branch -> Rn (Names, Branch)
renameBranch ns (Branch init e) = do
(new_names, init') <- renamePattern ns init
renameBranch ns (Branch patt e) = do
(new_names, patt') <- renamePattern ns patt
(new_names', e') <- renameExp new_names e
return (new_names', Branch init' e')
return (new_names', Branch patt' e')
renamePattern :: Names -> Pattern -> Rn (Names, Pattern)
renamePattern ns i = case i of
renamePattern ns p = case p of
PInj cs ps -> do
(ns_new, ps) <- renamePatterns ns ps
return (ns_new, PInj cs ps)
rest -> return (ns, rest)
(ns_new, ps') <- mapAccumM renamePattern ns ps
return (ns_new, PInj cs ps')
PVar name -> second PVar <$> newNameL ns name
_ -> return (ns, p)
renamePatterns :: Names -> [Pattern] -> Rn (Names, [Pattern])
renamePatterns ns xs = do
(new_names, xs') <- unzip <$> mapM (renamePattern ns) xs
if null new_names then return (mempty, xs') else return (head new_names, xs')
renameTVars :: Type -> Rn Type
renameTVars typ = case typ of
@ -167,44 +157,57 @@ renameTVars typ = case typ of
TFun t1 t2 -> liftA2 TFun (renameTVars t1) (renameTVars t2)
_ -> pure typ
substitute ::
TVar -> -- α
TVar -> -- α_n
Type -> -- A
Type -- [α_n/α]A
substitute :: TVar -- α
-> TVar -- α_n
-> Type -- A
-> Type -- [α_n/α]A
substitute tvar1 tvar2 typ = case typ of
TLit _ -> typ
TVar tvar'
| tvar' == tvar1 -> TVar tvar2
| otherwise -> typ
TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t -> TAll tvar $ substitute' t
TData name typs -> TData name $ map substitute' typs
_ -> error "Impossible"
TLit _ -> typ
TVar tvar | tvar == tvar1 -> TVar tvar2
| otherwise -> typ
TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t | tvar == tvar1 -> TAll tvar2 $ substitute' t
| otherwise -> TAll tvar $ substitute' t
TData name typs -> TData name $ map substitute' typs
_ -> error "Impossible"
where
substitute' = substitute tvar1 tvar2
-- | Create a new name and add it to name environment.
newName :: Names -> LIdent -> Rn (Names, LIdent)
newName env old_name = do
new_name <- makeName old_name
pure (Map.insert old_name new_name env, new_name)
-- | Create multiple names and add them to the name environment
newNames :: Names -> [LIdent] -> Rn (Names, [LIdent])
newNames = mapAccumM newName
newNamesL :: Names -> [LIdent] -> Rn (Names, [LIdent])
newNamesL = mapAccumM newNameL
-- | Create a new name and add it to name environment.
newNameL :: Names -> LIdent -> Rn (Names, LIdent)
newNameL env (LIdent old_name) = do
new_name <- makeName old_name
pure (Map.insert old_name new_name env, LIdent new_name)
-- | Create multiple names and add them to the name environment
newNamesU :: Names -> [UIdent] -> Rn (Names, [UIdent])
newNamesU = mapAccumM newNameU
-- | Create a new name and add it to name environment.
newNameU :: Names -> UIdent -> Rn (Names, UIdent)
newNameU env (UIdent old_name) = do
new_name <- makeName old_name
pure (Map.insert old_name new_name env, UIdent new_name)
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
makeName :: LIdent -> Rn LIdent
makeName (LIdent prefix) = do
i <- gets var_counter
let name = LIdent $ prefix ++ "_" ++ show i
modify $ \cxt -> cxt{var_counter = succ cxt.var_counter}
pure name
makeName :: String -> Rn String
makeName prefix = do
i <- gets var_counter
let name = prefix ++ "_" ++ show i
modify $ \cxt -> cxt { var_counter = succ cxt.var_counter}
pure name
nextNameTVar :: TVar -> Rn TVar
nextNameTVar (MkTVar (LIdent s)) = do
i <- gets tvar_counter
let tvar = MkTVar $ coerce $ s ++ "_" ++ show i
modify $ \cxt -> cxt{tvar_counter = succ cxt.tvar_counter}
pure tvar
nextNameTVar (MkTVar (LIdent s))= do
i <- gets tvar_counter
let tvar = MkTVar . LIdent $ s ++ "_" ++ show i
modify $ \cxt -> cxt { tvar_counter = succ cxt.tvar_counter}
pure tvar

206
src/Renamer/RenamerOld.hs Normal file
View file

@ -0,0 +1,206 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use mapAndUnzipM" #-}
module Renamer.Renamer (rename) where
import Auxiliary (mapAccumM)
import Control.Applicative (Applicative (liftA2))
import Control.Monad (foldM)
import Control.Monad.Except (ExceptT, MonadError, runExceptT,
throwError)
import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.State (MonadState, StateT, evalStateT, gets,
modify)
import Data.Coerce (coerce)
import Data.Function (on)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe)
import Grammar.Abs
-- | Rename all variables and local binds
rename :: Program -> Either String Program
rename (Program defs) = Program <$> renameDefs defs
renameDefs :: [Def] -> Either String [Def]
renameDefs defs = runIdentity $ runExceptT $ evalStateT (runRn $ mapM renameDef defs) initCxt
where
initNames = Map.fromList [dupe (coerce name) | DBind (Bind name _ _) <- defs]
renameDef :: Def -> Rn Def
renameDef = \case
DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ
DBind bind -> DBind . snd <$> renameBind initNames bind
DData (Data (TData cname types) constrs) -> do
tvars_ <- tvars
tvars' <- mapM nextNameTVar tvars_
let tvars_lt = zip tvars_ tvars'
typ' = map (substituteTVar tvars_lt) types
constrs' = map (renameConstr tvars_lt) constrs
pure . DData $ Data (TData cname typ') constrs'
where
tvars = concat <$> mapM (collectTVars []) types
collectTVars :: [TVar] -> Type -> Rn [TVar]
collectTVars tvars = \case
TAll tvar t -> collectTVars (tvar : tvars) t
TData _ _ -> return tvars
-- Should be monad error
TVar v -> return [v]
_ -> throwError ("Bad data type definition: " ++ show types)
DData (Data types _) -> throwError ("Bad data type definition: " ++ show types)
renameConstr :: [(TVar, TVar)] -> Inj -> Inj
renameConstr new_types (Inj name typ) =
Inj name $ substituteTVar new_types typ
renameBind :: Names -> Bind -> Rn (Names, Bind)
renameBind old_names (Bind name vars rhs) = do
(new_names, vars') <- newNames old_names (coerce vars)
(newer_names, rhs') <- renameExp new_names rhs
pure (newer_names, Bind name (coerce vars') rhs')
substituteTVar :: [(TVar, TVar)] -> Type -> Type
substituteTVar new_names typ = case typ of
TLit _ -> typ
TVar tvar
| Just tvar' <- lookup tvar new_names ->
TVar tvar'
| otherwise ->
typ
TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t
| Just tvar' <- lookup tvar new_names ->
TAll tvar' $ substitute' t
| otherwise ->
TAll tvar $ substitute' t
TData name typs -> TData name $ map substitute' typs
_ -> error ("Impossible " ++ show typ)
where
substitute' = substituteTVar new_names
initCxt :: Cxt
initCxt = Cxt 0 0
data Cxt = Cxt
{ var_counter :: Int
, tvar_counter :: Int
}
-- | Rename monad. State holds the number of renamed names.
newtype Rn a = Rn {runRn :: StateT Cxt (ExceptT String Identity) a}
deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String)
-- | Maps old to new name
type Names = Map LIdent LIdent
renameExp :: Names -> Exp -> Rn (Names, Exp)
renameExp old_names = \case
EVar n -> pure (coerce old_names, EVar . fromMaybe n $ Map.lookup n old_names)
EInj n -> pure (old_names, EInj n)
ELit lit -> pure (old_names, ELit lit)
EApp e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EApp e1' e2')
EAdd e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EAdd e1' e2')
-- TODO fix shadowing
ELet bind e -> do
(new_names, bind') <- renameBind old_names bind
(new_names', e') <- renameExp new_names e
pure (new_names', ELet bind' e')
EAbs par e -> do
(new_names, par') <- newName old_names (coerce par)
(new_names', e') <- renameExp new_names e
pure (new_names', EAbs (coerce par') e')
EAnn e t -> do
(new_names, e') <- renameExp old_names e
t' <- renameTVars t
pure (new_names, EAnn e' t')
ECase e injs -> do
(new_names, e') <- renameExp old_names e
(new_names', injs') <- renameBranches new_names injs
pure (new_names', ECase e' injs')
renameBranches :: Names -> [Branch] -> Rn (Names, [Branch])
renameBranches ns xs = do
(new_names, xs') <- unzip <$> mapM (renameBranch ns) xs
if null new_names then return (mempty, xs') else return (head new_names, xs')
renameBranch :: Names -> Branch -> Rn (Names, Branch)
renameBranch ns (Branch init e) = do
(new_names, init') <- renamePattern ns init
(new_names', e') <- renameExp new_names e
return (new_names', Branch init' e')
renamePattern :: Names -> Pattern -> Rn (Names, Pattern)
renamePattern ns i = case i of
PInj cs ps -> do
(ns_new, ps) <- renamePatterns ns ps
return (ns_new, PInj cs ps)
rest -> return (ns, rest)
renamePatterns :: Names -> [Pattern] -> Rn (Names, [Pattern])
renamePatterns ns xs = do
(new_names, xs') <- unzip <$> mapM (renamePattern ns) xs
if null new_names then return (mempty, xs') else return (head new_names, xs')
renameTVars :: Type -> Rn Type
renameTVars typ = case typ of
TAll tvar t -> do
tvar' <- nextNameTVar tvar
t' <- renameTVars $ substitute tvar tvar' t
pure $ TAll tvar' t'
TFun t1 t2 -> liftA2 TFun (renameTVars t1) (renameTVars t2)
_ -> pure typ
substitute ::
TVar -> -- α
TVar -> -- α_n
Type -> -- A
Type -- [α_n/α]A
substitute tvar1 tvar2 typ = case typ of
TLit _ -> typ
TVar tvar'
| tvar' == tvar1 -> TVar tvar2
| otherwise -> typ
TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t -> TAll tvar $ substitute' t
TData name typs -> TData name $ map substitute' typs
_ -> error "Impossible"
where
substitute' = substitute tvar1 tvar2
-- | Create a new name and add it to name environment.
newName :: Names -> LIdent -> Rn (Names, LIdent)
newName env old_name = do
new_name <- makeName old_name
pure (Map.insert old_name new_name env, new_name)
-- | Create multiple names and add them to the name environment
newNames :: Names -> [LIdent] -> Rn (Names, [LIdent])
newNames = mapAccumM newName
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
makeName :: LIdent -> Rn LIdent
makeName (LIdent prefix) = do
i <- gets var_counter
let name = LIdent $ prefix ++ "_" ++ show i
modify $ \cxt -> cxt{var_counter = succ cxt.var_counter}
pure name
nextNameTVar :: TVar -> Rn TVar
nextNameTVar (MkTVar (LIdent s)) = do
i <- gets tvar_counter
let tvar = MkTVar $ coerce $ s ++ "_" ++ show i
modify $ \cxt -> cxt{tvar_counter = succ cxt.tvar_counter}
pure tvar

View file

@ -0,0 +1,73 @@
{-# LANGUAGE LambdaCase #-}
module TypeChecker.RemoveTEVar where
import Control.Applicative (Applicative (liftA2), liftA3)
import Control.Arrow (Arrow (second))
import Control.Monad.Error (MonadError (throwError))
import Data.Coerce (coerce)
import Data.Function (on)
import Data.Tuple.Extra (secondM)
import Grammar.Abs
import Grammar.ErrM (Err)
import qualified TypeChecker.TypeCheckerIr as T
class RemoveTEVar a b where
rmTEVar :: a -> Err b
instance RemoveTEVar (T.Program' Type) (T.Program' T.Type) where
rmTEVar (T.Program defs) = T.Program <$> rmTEVar defs
instance RemoveTEVar (T.Def' Type) (T.Def' T.Type) where
rmTEVar = \case
T.DBind bind -> T.DBind <$> rmTEVar bind
T.DData dat -> T.DData <$> rmTEVar dat
instance RemoveTEVar (T.Bind' Type) (T.Bind' T.Type) where
rmTEVar (T.Bind id vars rhs) = liftA3 T.Bind (rmTEVar id) (rmTEVar vars) (rmTEVar rhs)
instance RemoveTEVar (T.Exp' Type) (T.Exp' T.Type) where
rmTEVar exp = case exp of
T.EVar name -> pure $ T.EVar name
T.EInj name -> pure $ T.EInj name
T.ELit lit -> pure $ T.ELit lit
T.ELet bind e -> liftA2 T.ELet (rmTEVar bind) (rmTEVar e)
T.EApp e1 e2 -> liftA2 T.EApp (rmTEVar e1) (rmTEVar e2)
T.EAdd e1 e2 -> liftA2 T.EApp (rmTEVar e1) (rmTEVar e2)
T.EAbs name e -> T.EAbs name <$> rmTEVar e
T.ECase e branches -> liftA2 T.ECase (rmTEVar e) (rmTEVar branches)
instance RemoveTEVar (T.Branch' Type) (T.Branch' T.Type) where
rmTEVar (T.Branch (patt, t_patt) e) = liftA2 T.Branch (liftA2 (,) (rmTEVar patt) (rmTEVar t_patt)) (rmTEVar e)
instance RemoveTEVar (T.Pattern' Type) (T.Pattern' T.Type) where
rmTEVar = \case
T.PVar (name, t) -> T.PVar . (name,) <$> rmTEVar t
T.PLit (lit, t) -> T.PLit . (lit,) <$> rmTEVar t
T.PCatch -> pure T.PCatch
T.PEnum name -> pure $ T.PEnum name
T.PInj name ps -> T.PInj name <$> rmTEVar ps
instance RemoveTEVar (T.Data' Type) (T.Data' T.Type) where
rmTEVar (T.Data typ injs) = liftA2 T.Data (rmTEVar typ) (rmTEVar injs)
instance RemoveTEVar (T.Inj' Type) (T.Inj' T.Type) where
rmTEVar (T.Inj name typ) = T.Inj name <$> rmTEVar typ
instance RemoveTEVar (T.Id' Type) (T.Id' T.Type) where
rmTEVar = secondM rmTEVar
instance RemoveTEVar (T.ExpT' Type) (T.ExpT' T.Type) where
rmTEVar (exp, typ) = liftA2 (,) (rmTEVar exp) (rmTEVar typ)
instance RemoveTEVar a b => RemoveTEVar [a] [b] where
rmTEVar = mapM rmTEVar
instance RemoveTEVar Type T.Type where
rmTEVar = \case
TLit lit -> pure $ T.TLit (coerce lit)
TVar tvar -> pure $ T.TVar tvar
TData name typs -> T.TData (coerce name) <$> rmTEVar typs
TFun t1 t2 -> liftA2 T.TFun (rmTEVar t1) (rmTEVar t2)
TAll tvar t -> T.TAll tvar <$> rmTEVar t
TEVar _ -> throwError "NewType TEVar!"

View file

@ -0,0 +1,858 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# OPTIONS_GHC -Wno-incomplete-patterns #-}
module TypeChecker.TypeCheckerBidir (typecheck, getVars) where
import Auxiliary (maybeToRightM, snoc)
import Control.Applicative (Alternative, Applicative (liftA2),
(<|>))
import Control.Monad.Except (ExceptT, MonadError (throwError),
runExceptT, unless, zipWithM,
zipWithM_)
import Control.Monad.State (MonadState (get, put), State,
evalState, gets, modify)
import Data.Coerce (coerce)
import Data.Function (on)
import Data.List (intercalate)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe, isNothing)
import Data.Sequence (Seq (..))
import qualified Data.Sequence as S
import Data.Tuple.Extra (second, secondM)
import Debug.Trace (trace)
import Grammar.Abs
import Grammar.ErrM
import Grammar.Print (printTree)
import Prelude hiding (exp, id)
import qualified TypeChecker.TypeCheckerIr as T
-- Implementation is derived from the paper (Dunfield and Krishnaswami 2013)
-- https://doi.org/10.1145/2500365.2500582
data EnvElem = EnvVar LIdent Type -- ^ Term variable typing. x : A
| EnvTVar TVar -- ^ Universal type variable. α
| EnvTEVar TEVar -- ^ Existential unsolved type variable. ά
| EnvTEVarSolved TEVar Type -- ^ Existential solved type variable. ά = τ
| EnvMark TEVar -- ^ Scoping Marker. ▶ ά
deriving (Eq, Show)
type Env = Seq EnvElem
-- | Ordered context
-- Γ ::= ・| Γ, α | Γ, ά | Γ, ▶ ά | Γ, x:A
data Cxt = Cxt
{ env :: Env -- ^ Local scope context Γ
, sig :: Map LIdent Type -- ^ Top-level signatures x : A
, binds :: Map LIdent Exp -- ^ Top-level binds x : e
, next_tevar :: Int -- ^ Counter to distinguish ά
, data_injs :: Map UIdent Type -- ^ Data injections (constructors) K
} deriving (Show, Eq)
newtype Tc a = Tc { runTc :: ExceptT String (State Cxt) a }
deriving (Functor, Applicative, Monad, Alternative, MonadState Cxt, MonadError String)
typecheck :: Program -> Err (T.Program' Type)
typecheck (Program defs) = do
datatypes <- mapM typecheckDataType [ d | DData d <- defs ]
let initCxt = Cxt
{ env = mempty
, sig = Map.fromList [ (name, t)
| DSig' name t <- defs
]
, binds = Map.fromList [ (name, foldr EAbs rhs vars)
| DBind' name vars rhs <- defs
]
, next_tevar = 0
, data_injs = Map.fromList [ (name, typ)
| Data _ injs <- datatypes
, Inj name typ <- injs
]
}
binds' <- evalState (runExceptT (runTc $ mapM typecheckBind binds)) initCxt;
pure . T.Program $ map T.DData (coerceData datatypes) ++ map T.DBind binds'
where
binds = [ b | DBind b <- defs ]
coerceData = map (\(Data t injs) -> T.Data t $ map
(\(Inj name typ) -> T.Inj (coerce name) typ) injs)
typecheckBind :: Bind -> Tc (T.Bind' Type)
typecheckBind (Bind name vars rhs) = do
bind' <- lookupSig name >>= \case
-- TODO These Judgment aren't accurate
-- (f:A → B) ∈ Γ
-- Γ,(xs:A) ⊢ e ↑ Β ⊣ Δ
---------------------------
-- Γ ⊢ f xs = e ↓ Α → B ⊣ Δ
Just t -> do
(rhs', _) <- check (foldr EAbs rhs vars) t
pure (T.Bind (coerce name, t) (coerce vars') (rhs', t))
where
vars' = zip vars $ getVars t
-- Γ ⊢ (λxs. e) ↓ A → B ⊣ Δ
-- ------------------------------
-- Γ ⊢ f xs = e ↓ [Γ]A → [Γ]B ⊣ Δ
Nothing -> do
(e, t) <- infer $ foldr EAbs rhs vars
t' <- applyEnv t
e' <- applyEnvExp e
let rhs' = skipLambdas (length vars) e'
vars' = zip vars $ getVars t'
pure (T.Bind (coerce name, t') (coerce vars') (rhs', t'))
env <- gets env
unless (isComplete env) err
putEnv Empty
pure bind'
where
err = throwError $ unlines
[ "Type inference failed: " ++ printTree (Bind name vars rhs)
, "Did you forget to add type annotation to a polymorphic function?"
]
typecheckDataType :: Data -> Err Data
typecheckDataType (Data typ injs) = do
(name, tvars) <- go [] typ
injs' <- mapM (\i -> typecheckInj i name tvars) injs
pure (Data typ injs')
where
go tvars = \case
TAll tvar t -> go (tvar:tvars) t
TData name typs
| Right tvars' <- mapM toTVar typs
, all (`elem` tvars) tvars'
-> pure (name, tvars')
_ -> throwError $ unwords ["Bad data type definition: ", ppT typ]
typecheckInj :: Inj -> UIdent -> [TVar] -> Err Inj
typecheckInj (Inj inj_name inj_typ) name tvars
| not $ boundTVars tvars inj_typ
= throwError "Unbound type variables"
| TData name' typs <- getReturn inj_typ
, name' == name
, Right tvars' <- mapM toTVar typs
, tvars' == tvars
= pure (Inj inj_name $ foldr TAll inj_typ tvars)
| otherwise
= throwError $ unwords
["Bad type constructor: ", show name
, "\nExpected: ", ppT . TData name $ map TVar tvars
, "\nActual: ", ppT $ getReturn inj_typ
]
where
boundTVars :: [TVar] -> Type -> Bool
boundTVars tvars' = \case
TAll tvar t -> boundTVars (tvar:tvars') t
TFun t1 t2 -> on (&&) (boundTVars tvars') t1 t2
TVar tvar -> elem tvar tvars'
TData _ typs -> all (boundTVars tvars) typs
TLit _ -> True
TEVar _ -> error "TEVar in data type declaration"
---------------------------------------------------------------------------
-- * Subtyping rules
---------------------------------------------------------------------------
-- | Γ ⊢ A <: B ⊣ Δ
-- Under input context Γ, type A is a subtype of B, with output context ∆
subtype :: Type -> Type -> Tc ()
subtype t1 t2 = case (t1, t2) of
(TLit lit1, TLit lit2) | lit1 == lit2 -> pure ()
-- -------------------- <:Var
-- Γ[α] ⊢ α <: α ⊣ Γ[α]
(TVar tvar1, TVar tvar2) | tvar1 == tvar2 -> pure ()
-- -------------------- <:Exvar
-- Γ[ά] ⊢ ά <: ά ⊣ Γ[ά]
(TEVar tevar1, TEVar tevar2) | tevar1 == tevar2 -> pure ()
-- Γ ⊢ B₁ <: A₁ ⊣ Θ Θ ⊢ [Θ]A₂ <: [Θ]B₂ ⊣ Δ
-- ----------------------------------------- <:→
-- Γ ⊢ A₁ → A₂ <: B₁ → B₂ ⊣ Δ
(TFun a1 a2, TFun b1 b2) -> do
subtype b1 a1
a2' <- applyEnv a2
b2' <- applyEnv b2
subtype a2' b2'
-- Γ, α ⊢ A <: B ⊣ Δ,α
-- --------------------- <:∀R
-- Γ ⊢ A <: ∀α. B ⊣ Δ
(a, TAll tvar b) -> do
let env_tvar = EnvTVar tvar
insertEnv env_tvar
subtype a b
dropTrailing env_tvar
-- Γ,▶ ά,ά ⊢ [ά/α]A <: B ⊣ Δ,▶ ά,Θ
-- ------------------------------- <:∀L
-- Γ ⊢ ∀α.A <: B ⊣ Δ
(TAll tvar a, b) -> do
tevar <- fresh
let env_marker = EnvMark tevar
env_tevar = EnvTEVar tevar
insertEnv env_marker
insertEnv env_tevar
let a' = substitute tvar tevar a
subtype a' b
dropTrailing env_marker
-- ά ∉ FV(A) Γ[ά] ⊢ ά :=< A ⊣ Δ
-- ------------------------------ <:instantiateL
-- Γ[ά] ⊢ ά <: A ⊣ Δ
(TEVar tevar, typ) | notElem tevar $ frees typ -> instantiateL tevar typ
-- ά ∉ FV(A) Γ[ά] ⊢ A =:< ά ⊣ Δ
-- ------------------------------ <:instantiateR
-- Γ[ά] ⊢ A <: ά ⊣ Δ
(typ, TEVar tevar) | notElem tevar $ frees typ -> instantiateR typ tevar
(TData name1 typs1, TData name2 typs2)
-- D₁ = D₂
-- ----------------
-- Γ ⊢ D₁ () <: D₂ ()
| name1 == name2
, [] <- typs1
, [] <- typs2
-> pure ()
-- Γ ⊢ ά₁ <: έ₁ ⊣ Θ₁
-- ...
-- D₁ = D₂ Θₙ₋₁ ⊢ [Θₙ₋₁]άₙ <: [Θₙ₋₁]έₙ ⊣ Δ
-- -------------------------------------------
-- Γ ⊢ D (ά₁ ‥ άₙ) <: D (έ₁ ‥ έₙ) ⊣ Δ
| name1 == name2
, t1:t1s <- typs1
, t2:t2s <- typs2
-> do
subtype t1 t2
zipWithM_ go t1s t2s
where
go t1' t2' = do
t1'' <- applyEnv t1'
t2'' <- applyEnv t2'
subtype t1'' t2''
_ -> throwError $ unwords ["Types", ppT t1, "and", ppT t2, "doesn't match!"]
---------------------------------------------------------------------------
-- * Instantiation rules
---------------------------------------------------------------------------
-- | Γ ⊢ ά :=< A ⊣ Δ
-- Under input context Γ, instantiate ά such that ά <: A, with output context ∆
instantiateL :: TEVar -> Type -> Tc ()
instantiateL tevar typ = gets env >>= go
where
go env
-- Γ ⊢ τ
-- ----------------------------- InstLSolve
-- Γ,ά,Γ' ⊢ ά :=< τ ⊣ Γ,(ά=τ),Γ'
| isMono typ
, (env_l, env_r) <- splitOn (EnvTEVar tevar) env
, Right _ <- wellFormed env_l typ
= putEnv $ (env_l :|> EnvTEVarSolved tevar typ) <> env_r
| TEVar tevar' <- typ = instReach tevar tevar'
-- Γ[ά₂ά₁,(ά=ά₁→ά₂)] ⊢ A₁ =:< ά₁ ⊣ Θ Θ ⊢ ά₂ :=< [Θ]A₂ ⊣ Δ
-- ------------------------------------------------------- InstLArr
-- Γ[ά] ⊢ ά :=< A₁ → A₂ ⊣ Δ
| TFun t1 t2 <- typ = do
tevar1 <- fresh
tevar2 <- fresh
insertEnv $ EnvTEVar tevar2
insertEnv $ EnvTEVar tevar1
insertEnv $ EnvTEVarSolved tevar (on TFun TEVar tevar1 tevar2)
instantiateR t1 tevar1
instantiateL tevar2 =<< applyEnv t2
-- Γ[ά],ε ⊢ ά :=< E ⊣ Δ,ε,Δ'
-- ------------------------- InstLAIIR
-- Γ[ά] ⊢ ά :=< ∀ε.Ε ⊣ Δ
| TAll tvar t <- typ = do
instantiateL tevar t
let (env_l, _) = splitOn (EnvTVar tvar) env
putEnv env_l
| otherwise = error $ "Trying to instantiateL: " ++ ppT (TEVar tevar)
++ " <: " ++ ppT typ
-- | Γ ⊢ A =:< ά ⊣ Δ
-- Under input context Γ, instantiate ά such that A <: ά, with output context ∆
instantiateR :: Type -> TEVar -> Tc ()
instantiateR typ tevar = gets env >>= go
where
go env
-- Γ ⊢ τ
-- ----------------------------- InstRSolve
-- Γ,ά,Γ' ⊢ τ =:< ά ⊣ Γ,(ά=τ),Γ'
| isMono typ
, (env_l, env_r) <- splitOn (EnvTEVar tevar) env
, Right _ <- wellFormed env_l typ
= putEnv $ (env_l :|> EnvTEVarSolved tevar typ) <> env_r
| TEVar tevar' <- typ = instReach tevar tevar'
-- Γ[ά₂ά₁,(ά=ά₁→ά₂)] ⊢ A₁ :=< ά₁ ⊣ Θ Θ ⊢ ά₂ =:< [Θ]A₂ ⊣ Δ
-- ------------------------------------------------------- InstRArr
-- Γ[ά] ⊢ ά =:< A₁ → A₂ ⊣ Δ
| TFun t1 t2 <- typ = do
tevar1 <- fresh
tevar2 <- fresh
insertEnv $ EnvTEVar tevar2
insertEnv $ EnvTEVar tevar1
insertEnv $ EnvTEVarSolved tevar (on TFun TEVar tevar1 tevar2)
instantiateL tevar1 t1
t2' <- applyEnv t2
instantiateR t2' tevar2
-- Γ[ά],▶έ,ε ⊢ [έ/ε]E =:< ά ⊣ Δ,▶έ,Δ'
-- ---------------------------------- InstRAIIL
-- Γ[ά] ⊢ ∀ε.Ε =:< ά ⊣ Δ
| TAll tvar t <- typ = do
tevar' <- fresh
insertEnv $ EnvMark tevar'
insertEnv $ EnvTVar tvar
let t' = substitute tvar tevar' t
instantiateR t' tevar
let (env_l, _) = splitOn (EnvTVar tvar) env
putEnv env_l
| otherwise = error $ "Trying to instantiateR: " ++ ppT typ ++ " <: "
++ ppT (TEVar tevar)
-- ----------------------------- InstLReach
-- Γ[ά][έ] ⊢ ά :=< έ ⊣ Γ[ά][έ=ά]
--
-- ----------------------------- InstRReach
-- Γ[ά][έ] ⊢ έ =:< ά ⊣ Γ[ά][έ=ά]
instReach :: TEVar -> TEVar -> Tc ()
instReach tevar tevar' = do
(env_l, env_r) <- gets (splitOn (EnvTEVar tevar') . env)
let env_solved = EnvTEVarSolved tevar' $ TEVar tevar
putEnv $ (env_l :|> env_solved) <> env_r
---------------------------------------------------------------------------
-- * Typing rules
---------------------------------------------------------------------------
-- | Γ ⊢ e ↑ A ⊣ Δ
-- Under input context Γ, e checks against input type A, with output context ∆
check :: Exp -> Type -> Tc (T.ExpT' Type)
check exp typ
-- Γ,α ⊢ e ↑ A ⊣ Δ,α
-- ------------------- ∀I
-- Γ ⊢ e ↑ ∀α.A ⊣ Δ
| TAll tvar t <- typ = do
let env_tvar = EnvTVar tvar
insertEnv env_tvar
exp' <- check exp t
(env_l, _) <- gets (splitOn env_tvar . env)
putEnv env_l
pure exp'
-- Γ,(x:A) ⊢ e ↑ B ⊢ Δ,(x:A),Θ
-- --------------------------- →I
-- Γ ⊢ λx.e ↑ A → B ⊣ Δ
| EAbs name e <- exp
, TFun t1 t2 <- typ = do
let env_id = EnvVar name t1
insertEnv env_id
e' <- check e t2
(env_l, _) <- gets (splitOn env_id . env)
putEnv env_l
pure (T.EAbs (coerce name) e', typ)
-- Θ ⊢ Π ∷ [Θ]A ↑ [Θ]C ⊣ Δ
-- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO
-- ---------------------------------------
-- Γ ⊢ case e of Π ↑ C ⊣ Δ
| ECase scrut branches <- exp = do
(scrut', t_scrut) <- infer scrut
t_scrut' <- applyEnv t_scrut
typ' <- applyEnv typ
branches' <- mapM (\b -> checkBranch b t_scrut' typ') branches
pure (T.ECase (scrut', t_scrut') branches', typ')
| otherwise = subsumption
where
-- Γ,α ⊢ e ↓ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]B ⊣ Δ
-- -------------------------------------- Sub
-- Γ ⊢ e ↑ B ⊣ Δ
subsumption = do
(exp', t) <- infer exp
exp'' <- applyEnvExp exp'
t' <- applyEnv t
typ' <- applyEnv typ
subtype t' typ'
pure (exp'', t')
-- | Γ ⊢ e ↓ A ⊣ Δ
-- Under input context Γ, e infers output type A, with output context ∆
infer :: Exp -> Tc (T.ExpT' Type)
infer = \case
ELit lit -> pure (T.ELit lit, inferLit lit)
-- (x : A) ∈ Γ
-- ------------- Var
-- Γ ⊢ x ↓ A ⊣ Γ
EVar name -> do
t <- liftA2 (<|>) (lookupEnv name) (lookupSig name) >>= \case
Just t -> pure t
Nothing -> do
e <- maybeToRightM
("Unbound variable " ++ show name)
=<< lookupBind name
snd <$> infer e
pure (T.EVar (coerce name), t)
EInj name -> do
t <- maybeToRightM ("Unknown constructor: " ++ show name) =<< lookupInj name
pure (T.EInj $ coerce name, t)
-- Γ ⊢ A Γ ⊢ e ↑ A ⊣ Δ
-- --------------------- Anno
-- Γ ⊢ (e : A) ↓ A ⊣ Δ
EAnn e t -> do
_ <- gets $ (`wellFormed` t) . env
(e', _) <- check e t
pure (e', t)
-- Γ ⊢ e₁ ↓ A ⊣ Θ Γ ⊢ [Θ]A • ⇓ C ⊣ Δ
-- ----------------------------------- →E
-- Γ ⊢ e₁ e₂ ↓ C ⊣ Δ
EApp e1 e2 -> do
(e1', t) <- infer e1
t' <- applyEnv t
e1'' <- applyEnvExp e1'
(e2', t'') <- apply t' e2
pure (T.EApp (e1'', t) e2', t'')
-- Γ,ά,έ,(x:ά) ⊢ e ↑ έ ⊣ Δ,(x:ά),Θ
-- ------------------------------- →I
-- Γ ⊢ λx.e ↓ ά → έ ⊣ Δ
EAbs name e -> do
tevar1 <- fresh
tevar2 <- fresh
insertEnv $ EnvTEVar tevar1
insertEnv $ EnvTEVar tevar2
let env_id = EnvVar name (TEVar tevar1)
insertEnv env_id
e' <- check e $ TEVar tevar2
dropTrailing env_id
let t_exp = on TFun TEVar tevar1 tevar2
pure (T.EAbs (coerce name) e', t_exp)
-- Γ ⊢ e ↓ A ⊣ Θ Θ,(x:A) ⊢ e' ↑ C ⊣ Δ,(x:A),Θ
-- -------------------------------------------- LetI
-- Γ ⊢ let x=e in e' ↑ C ⊣ Δ
ELet (Bind name [] rhs) e -> do -- TODO vars
(rhs', t_rhs) <- infer rhs
let env_id = EnvVar name t_rhs
insertEnv env_id
(e', t) <- infer e
(env_l, _) <- gets (splitOn env_id . env)
putEnv env_l
pure (T.ELet (T.Bind (coerce name, t_rhs) [] (rhs', t_rhs)) (e',t), t)
-- Γ ⊢ e₁ ↑ Int Γ ⊢ e₁ ↑ Int
-- --------------------------- +I
-- Γ ⊢ e₁ + e₂ ↓ Int ⊣ Δ
EAdd e1 e2 -> do
cxt <- get
let t = TLit "Int"
e1' <- check e1 t
put cxt
e2' <- check e2 t
pure (T.EAdd e1' e2', t)
-- | Γ ⊢ A • e ⇓ C ⊣ Δ
-- Under input context Γ , applying a function of type A to e infers type C, with output context ∆
-- Instantiate existential type variables until there is an arrow type.
apply :: Type -> Exp -> Tc (T.ExpT' Type, Type)
apply typ exp = case typ of
-- Γ,ά ⊢ [ά/α]A • e ⇓ C ⊣ Δ
-- ------------------------ ∀App
-- Γ ⊢ ∀α.A • e ⇓ C ⊣ Δ
TAll tvar t -> do
tevar <- fresh
insertEnv $ EnvTEVar tevar
let t' = substitute tvar tevar t
apply t' exp
-- Γ[ά₂,ά₁,(ά=ά₁→ά₂)] ⊢ e ↑ ά₁ ⊣ Δ
-- ------------------------------- άApp
-- Γ[ά] ⊢ ά • e ⇓ ά₂ ⊣ Δ
TEVar tevar -> do
tevar1 <- fresh
tevar2 <- fresh
let env_tevar1 = EnvTEVar tevar1
env_tevar2 = EnvTEVar tevar2
t_fun = on TFun TEVar tevar1 tevar2
env_tevar_solved = EnvTEVarSolved tevar t_fun
(env_l, env_r) <- gets (splitOn (EnvTEVar tevar) . env)
putEnv $
(env_l :|> env_tevar2 :|> env_tevar1 :|> env_tevar_solved) <> env_r
expT' <- check exp $ TEVar tevar1
pure (expT', TEVar tevar2)
-- Γ ⊢ e ↑ A ⊣ Δ
-- --------------------- →App
-- Γ ⊢ A → C • e ⇓ C ⊣ Δ
TFun t1 t2 -> do
expt' <- check exp t1
pure (expt', t2)
_ -> throwError ("Cannot apply type " ++ show typ ++ " with expression " ++ show exp)
---------------------------------------------------------------------------
-- * Pattern matching
---------------------------------------------------------------------------
-- | Γ ⊢ p ⇒ e ∷ A ↑ C
-- Under context Γ, check branch p ⇒ e of type A and bodies of type C
checkBranch :: Branch -> Type -> Type -> Tc (T.Branch' Type)
checkBranch (Branch patt exp) t_patt t_exp = do
env_marker <- EnvMark <$> fresh
insertEnv env_marker
patt' <- checkPattern patt t_patt
t_exp' <- applyEnv t_exp
(exp, t_exp) <- check exp t_exp'
(env_l, _) <- gets (splitOn env_marker . env)
putEnv env_l
pure (T.Branch patt' (exp, t_exp))
checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type)
checkPattern patt t_patt = case patt of
PVar x -> do
insertEnv $ EnvVar x t_patt
pure (T.PVar (coerce x, dummy), dummy) -- TODO
PCatch -> pure (T.PCatch, dummy) -- TODO
PLit lit | inferLit lit == t_patt -> let
t = inferLit lit
in
pure (T.PLit (lit, t), t)
| otherwise -> throwError "Literal in pattern have wrong type"
PEnum name -> do
t <- maybeToRightM ("Unknown constructor " ++ show name)
=<< lookupInj name
subtype t t_patt
pure (T.PEnum (coerce name), dummy) -- TODO
PInj name ps -> do
t <- maybeToRightM ("Unknown constructor " ++ show name)
=<< lookupInj name
let (t_ps, t_return) = partitionTypeWithForall t
unless (length ps == length t_ps) $
throwError "Wrong number of variables"
subtype t_return t_patt
ps' <- zipWithM (\p t -> checkPattern p =<< applyEnv t) ps t_ps
let ps'' = map fst ps' -- TODO
pure (T.PInj (coerce name) ps'', dummy)
---------------------------------------------------------------------------
-- * Auxiliary
---------------------------------------------------------------------------
frees :: Type -> [TEVar]
frees = \case
TLit _ -> []
TVar _ -> []
TEVar tevar -> [tevar]
TFun t1 t2 -> on (++) frees t1 t2
TAll _ t -> frees t
TData _ typs -> concatMap frees typs
-- | [ά/α]A
substitute :: TVar -- α
-> TEVar -- ά
-> Type -- A
-> Type -- [ά/α]A
substitute tvar tevar typ = case typ of
TLit _ -> typ
TVar tvar' | tvar' == tvar -> TEVar tevar
| otherwise -> typ
TEVar _ -> typ
TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar' t -> TAll tvar' (substitute' t)
TData name typs -> TData name $ map substitute' typs
where
substitute' = substitute tvar tevar
-- | Γ,x,Γ' → (Γ, Γ')
splitOn :: EnvElem -> Env -> (Env, Env)
splitOn x env = second (S.drop 1) $ S.breakl (==x) env
-- | Drop frontmost elements until and including element @x@.
dropTrailing :: EnvElem -> Tc ()
dropTrailing x = modifyEnv $ S.takeWhileL (/= x)
applyEnvExp :: T.Exp' Type -> Tc (T.Exp' Type)
applyEnvExp exp = case exp of
T.ELet (T.Bind id vars rhs) exp -> do
id <- applyEnvId id
vars' <- mapM applyEnvId vars
rhs' <- applyEnvExpT rhs
exp' <- applyEnvExpT exp
pure $ T.ELet (T.Bind id vars' rhs') exp'
T.EApp e1 e2 -> liftA2 T.EApp (applyEnvExpT e1) (applyEnvExpT e2)
T.EAdd e1 e2 -> liftA2 T.EAdd (applyEnvExpT e1) (applyEnvExpT e2)
T.EAbs name e -> T.EAbs name <$> applyEnvExpT e
T.ECase e branches -> liftA2 T.ECase (applyEnvExpT e)
(mapM applyEnvBranch branches)
_ -> pure exp
where
applyEnvExpT (e, t) = liftA2 (,) (applyEnvExp e) (applyEnv t)
applyEnvId = secondM applyEnv
applyEnvBranch (T.Branch (p, t) e) = do
pt <- liftA2 (,) (applyEnvPattern p) (applyEnv t)
e' <- applyEnvExpT e
pure $ T.Branch pt e'
applyEnvPattern = \case
T.PVar id -> T.PVar <$> applyEnvId id
T.PLit (lit, t) -> T.PLit . (lit, ) <$> applyEnv t
T.PInj name ps -> T.PInj name <$> mapM applyEnvPattern ps
p -> pure p
applyEnv :: Type -> Tc Type
applyEnv t = gets $ (`applyEnv'` t) . env
-- | [Γ]A. Applies context to type until fully applied.
applyEnv' :: Env -> Type -> Type
applyEnv' cxt typ | typ == typ' = typ'
| otherwise = applyEnv' cxt typ'
where
typ' = case typ of
TLit _ -> typ
TData name typs -> TData name $ map (applyEnv' cxt) typs
-- [Γ]α = α
TVar _ -> typ
-- [Γ[ά=τ]]ά = [Γ[ά=τ]]τ
-- [Γ[ά]]ά = [Γ[ά]]ά
TEVar tevar -> fromMaybe typ $ findSolved tevar cxt
-- [Γ](A → B) = [Γ]A → [Γ]B
TFun t1 t2 -> on TFun (applyEnv' cxt) t1 t2
-- [Γ](∀α. A) = (∀α. [Γ]A)
TAll tvar t -> TAll tvar $ applyEnv' cxt t
findSolved :: TEVar -> Env -> Maybe Type
findSolved _ Empty = Nothing
findSolved tevar (xs :|> x) = case x of
EnvTEVarSolved tevar' t | tevar == tevar' -> Just t
_ -> findSolved tevar xs
-- | Γ ⊢ A
-- Under context Γ, type A is well-formed
wellFormed :: Env -> Type -> Err ()
wellFormed env = \case
TLit _ -> pure ()
-- -------- UvarWF
-- Γ[α] ⊢ α
TVar tvar -> unless (EnvTVar tvar `elem` env) $
throwError ("Unbound type variable: " ++ show tvar)
-- Γ ⊢ A Γ ⊢ B
-- ------------- ArrowWF
-- Γ ⊢ A → B
TFun t1 t2 -> do { wellFormed env t1; wellFormed env t2 }
-- Γ,α ⊢ A
-- -------- ForallWF
-- Γ ⊢ ∀α.A
TAll tvar t -> wellFormed (env :|> EnvTVar tvar) t
TEVar tevar
-- ---------- EvarWF
-- Γ[ά] ⊢ ά
| EnvTEVar tevar `elem` env -> pure ()
-- ---------- SolvedEvarWF
-- Γ[ά=τ] ⊢ ά
| Just _ <- findSolved tevar env -> pure ()
| otherwise -> throwError ("Can't find type: " ++ show tevar)
TData _ typs -> mapM_ (wellFormed env) typs
isMono :: Type -> Bool
isMono = \case
TAll{} -> False
TFun t1 t2 -> on (&&) isMono t1 t2
TData _ typs -> all isMono typs
TVar _ -> True
TEVar _ -> True
TLit _ -> True
inferLit :: Lit -> Type
inferLit = \case
LInt _ -> TLit "Int"
LChar _ -> TLit "Char"
fresh :: Tc TEVar
fresh = do
tevar <- gets (MkTEVar . LIdent . ("a#" ++) . show . next_tevar)
modify $ \cxt -> cxt { next_tevar = succ cxt.next_tevar }
pure tevar
getVars :: Type -> [Type]
getVars = fst . partitionType
getReturn :: Type -> Type
getReturn = snd . partitionType
-- | Partion type into variable types and return type.
--
-- ∀a.∀b. a → (∀c. c → c) → b
-- ([a, ∀c. c → c], b)
--
-- Unsure if foralls should be added to the return type or not.
partitionType :: Type -> ([Type], Type)
partitionType = go [] . skipForalls'
where
go acc t = case t of
TFun t1 t2 -> go (snoc t1 acc) t2
_ -> (acc, t)
skipForalls' :: Type -> Type
skipForalls' = snd . skipForalls
skipForalls :: Type -> ([Type -> Type], Type)
skipForalls = go []
where
go acc typ = case typ of
TAll tvar t -> go (snoc (TAll tvar) acc) t
_ -> (acc, typ)
partitionTypeWithForall :: Type -> ([Type], Type)
partitionTypeWithForall typ = (t_vars', t_return')
where
t_vars' = map (\t -> foldr applyForall t foralls) t_vars
t_return' = foldr applyForall t_return foralls
applyForall fa t | usesTVar tvar t = fa t
| otherwise = t
where TAll tvar _ = fa t
(t_vars, t_return) = go [] typ'
(foralls, typ') = skipForalls typ
go acc t = case t of
TFun t1 t2 -> go (snoc t1 acc) t2
_ -> (acc, t)
usesTVar :: TVar -> Type -> Bool
usesTVar tvar = \case
TLit _ -> False
TVar tvar' | tvar' == tvar -> True
| otherwise -> False
TFun t1 t2 -> on (||) usesTVar' t1 t2
TAll tvar' t | tvar' == tvar -> error "Redeclaration of TVar"
| otherwise -> usesTVar' t
TData _ typs -> any usesTVar' typs
_ -> error "Impossible"
where
usesTVar' = usesTVar tvar
skipLambdas :: Int -> T.Exp' Type -> T.Exp' Type
skipLambdas i exp
| i == 0 = exp
| T.EAbs _ (e, _) <- exp = skipLambdas (i-1) e
| otherwise = error "Number of expected lambdas doesn't match expression"
isComplete :: Env -> Bool
isComplete = isNothing . S.findIndexL unSolvedTEVar
where
unSolvedTEVar = \case
EnvTEVar _ -> True
_ -> False
toTVar :: Type -> Err TVar
toTVar = \case
TVar tvar -> pure tvar
_ -> throwError "Not a type variable"
insertEnv :: EnvElem -> Tc ()
insertEnv x = modifyEnv (:|> x)
lookupBind :: LIdent -> Tc (Maybe Exp)
lookupBind x = gets (Map.lookup x . binds)
lookupSig :: LIdent -> Tc (Maybe Type)
lookupSig x = gets (Map.lookup x . sig)
lookupEnv :: LIdent -> Tc (Maybe Type)
lookupEnv x = gets (findId . env)
where
findId Empty = Nothing
findId (ys :|> y) = case y of
EnvVar x' t | x==x' -> Just t
_ -> findId ys
lookupInj :: UIdent -> Tc (Maybe Type)
lookupInj x = gets (Map.lookup x . data_injs)
putEnv :: Env -> Tc ()
putEnv = modifyEnv . const
modifyEnv :: (Env -> Env) -> Tc ()
modifyEnv f =
modify $ \cxt -> {- trace (ppEnv (f cxt.env)) -} cxt { env = f cxt.env }
pattern DBind' name vars exp = DBind (Bind name vars exp)
pattern DSig' name typ = DSig (Sig name typ)
dummy = TLit "Int"
---------------------------------------------------------------------------
-- * Debug
---------------------------------------------------------------------------
traceEnv s = do
env <- gets env
trace (s ++ " " ++ show env) pure ()
traceD s x = trace (s ++ " " ++ show x) pure ()
traceT s x = trace (s ++ " " ++ ppT x) pure ()
traceTs s xs = trace (s ++ " [ " ++ intercalate ", " (map ppT xs) ++ " ]") pure ()
ppT = \case
TLit (UIdent s) -> s
TVar (MkTVar (LIdent s)) -> "α_" ++ s
TFun t1 t2 -> ppT t1 ++ "" ++ ppT t2
TAll (MkTVar (LIdent s)) t -> "forall " ++ s ++ ". " ++ ppT t
TEVar (MkTEVar (LIdent s)) -> "ά_" ++ s
TData (UIdent name) typs -> name ++ " (" ++ unwords (map ppT typs)
++ " )"
ppEnvElem = \case
EnvVar (LIdent s) t -> s ++ ":" ++ ppT t
EnvTVar (MkTVar (LIdent s)) -> "α_" ++ s
EnvTEVar (MkTEVar (LIdent s)) -> "ά_" ++ s
EnvTEVarSolved (MkTEVar (LIdent s)) t -> "ά_" ++ s ++ "=" ++ ppT t
EnvMark (MkTEVar (LIdent s)) -> "" ++ "ά_" ++ s
ppEnv = \case
Empty -> "·"
(xs :|> x) -> ppEnv xs ++ " (" ++ ppEnvElem x ++ ")"

View file

@ -1,36 +1,31 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
-- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeChecker where
import Auxiliary
import Control.Monad.Except
import Control.Monad.Identity (runIdentity)
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (second)
import Data.Coerce (coerce)
import Data.Foldable (traverse_)
import Data.Function (on)
import Data.List (foldl')
import Data.List.Extra (unsnoc)
import Data.Map (Map)
import Data.Map qualified as M
import Data.Maybe (fromJust)
import Data.Set (Set)
import Data.Set qualified as S
import Debug.Trace (trace)
import Grammar.Abs
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr (
Ctx (..),
Env (..),
Error,
Infer,
Subst,
)
import TypeChecker.TypeCheckerIr qualified as T
import Auxiliary
import Control.Monad.Except
import Control.Monad.Identity (runIdentity)
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (second)
import Data.Coerce (coerce)
import Data.Foldable (traverse_)
import Data.Function (on)
import Data.List (foldl')
import Data.List.Extra (unsnoc)
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe (fromJust)
import Data.Set (Set)
import qualified Data.Set as S
import Debug.Trace (trace)
import Grammar.Abs
import Grammar.Print (printTree)
import qualified TypeChecker.TypeCheckerIr as T
import TypeChecker.TypeCheckerIr (Ctx (..), Env (..), Error, Infer,
Subst)
initCtx = Ctx mempty
initEnv = Env 0 'a' mempty mempty mempty
@ -78,7 +73,7 @@ checkData d = do
retType :: Type -> Type
retType (TFun _ t2) = retType t2
retType a = a
retType a = a
checkPrg :: Program -> Infer T.Program
checkPrg (Program bs) = do
@ -105,7 +100,7 @@ preRun (x : xs) = case x of
s <- gets sigs
case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs
Just _ -> preRun xs
DData d@(Data t _) -> collect (collectTypeVars t) >> checkData d >> preRun xs
checkDef :: [Def] -> Infer [T.Def]
@ -152,9 +147,9 @@ typeEq _ _ = False
skolem :: T.Type -> T.Type
skolem (T.TVar (T.MkTVar a)) = T.TLit a
skolem (T.TAll x t) = T.TAll x (skolem t)
skolem (T.TFun t1 t2) = (T.TFun `on` skolem) t1 t2
skolem t = t
skolem (T.TAll x t) = T.TAll x (skolem t)
skolem (T.TFun t1 t2) = (T.TFun `on` skolem) t1 t2
skolem t = t
isMoreSpecificOrEq :: T.Type -> T.Type -> Bool
isMoreSpecificOrEq t1 (T.TAll _ t2) = isMoreSpecificOrEq t1 t2
@ -169,8 +164,8 @@ isMoreSpecificOrEq a b = a == b
isPoly :: Type -> Bool
isPoly (TAll _ _) = True
isPoly (TVar _) = True
isPoly _ = False
isPoly (TVar _) = True
isPoly _ = False
inferExp :: Exp -> Infer T.ExpT
inferExp e = do
@ -183,7 +178,7 @@ class CollectTVars a where
instance CollectTVars Exp where
collectTypeVars (EAnn e t) = collectTypeVars t `S.union` collectTypeVars e
collectTypeVars _ = S.empty
collectTypeVars _ = S.empty
instance CollectTVars Type where
collectTypeVars (TVar (MkTVar i)) = S.singleton (coerce i)
@ -200,15 +195,15 @@ class NewType a b where
instance NewType Type T.Type where
toNew = \case
TLit i -> T.TLit $ coerce i
TVar v -> T.TVar $ toNew v
TLit i -> T.TLit $ coerce i
TVar v -> T.TVar $ toNew v
TFun t1 t2 -> (T.TFun `on` toNew) t1 t2
TAll b t -> T.TAll (toNew b) (toNew t)
TAll b t -> T.TAll (toNew b) (toNew t)
TData i ts -> T.TData (coerce i) (map toNew ts)
TEVar _ -> error "Should not exist after typechecker"
TEVar _ -> error "Should not exist after typechecker"
instance NewType Lit T.Lit where
toNew (LInt i) = T.LInt i
toNew (LInt i) = T.LInt i
toNew (LChar i) = T.LChar i
instance NewType Data T.Data where
@ -422,12 +417,12 @@ generalize :: Map T.Ident T.Type -> T.Type -> T.Type
generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t)
where
go :: [T.Ident] -> T.Type -> T.Type
go [] t = t
go [] t = t
go (x : xs) t = T.TAll (T.MkTVar x) (go xs t)
removeForalls :: T.Type -> T.Type
removeForalls (T.TAll _ t) = removeForalls t
removeForalls (T.TAll _ t) = removeForalls t
removeForalls (T.TFun t1 t2) = T.TFun (removeForalls t1) (removeForalls t2)
removeForalls t = t
removeForalls t = t
{- | Instantiate a polymorphic type. The free type variables are substituted
with fresh ones.
@ -477,10 +472,10 @@ instance SubstType T.Type where
T.TLit a -> T.TLit a
T.TVar (T.MkTVar a) -> case M.lookup a sub of
Nothing -> T.TVar (T.MkTVar $ coerce a)
Just t -> t
Just t -> t
T.TAll (T.MkTVar i) t -> case M.lookup i sub of
Nothing -> T.TAll (T.MkTVar i) (apply sub t)
Just _ -> apply sub t
Just _ -> apply sub t
T.TFun a b -> T.TFun (apply sub a) (apply sub b)
T.TData name a -> T.TData name (map (apply sub) a)
instance FreeVars (Map T.Ident T.Type) where
@ -513,10 +508,10 @@ instance SubstType T.Pattern where
apply :: Subst -> T.Pattern -> T.Pattern
apply s = \case
T.PVar (iden, t) -> T.PVar (iden, apply s t)
T.PLit (lit, t) -> T.PLit (lit, apply s t)
T.PInj i ps -> T.PInj i $ apply s ps
T.PCatch -> T.PCatch
T.PEnum i -> T.PEnum i
T.PLit (lit, t) -> T.PLit (lit, apply s t)
T.PInj i ps -> T.PInj i $ apply s ps
T.PCatch -> T.PCatch
T.PEnum i -> T.PEnum i
instance SubstType a => SubstType [a] where
apply s = map (apply s)
@ -555,7 +550,7 @@ fresh = do
next :: Char -> Char
next 'z' = 'a'
next a = succ a
next a = succ a
-- | Run the monadic action with an additional binding
withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> T.Type -> m a -> m a
@ -608,10 +603,10 @@ inferBranch (Branch pat expr) = do
withPattern :: T.Pattern -> Infer a -> Infer a
withPattern p ma = case p of
T.PVar (x, t) -> withBinding x t ma
T.PInj _ ps -> foldl' (flip withPattern) ma ps
T.PLit _ -> ma
T.PCatch -> ma
T.PEnum _ -> ma
T.PInj _ ps -> foldl' (flip withPattern) ma ps
T.PLit _ -> ma
T.PCatch -> ma
T.PEnum _ -> ma
inferPattern :: Pattern -> Infer (T.Pattern, T.Type)
inferPattern = \case
@ -659,14 +654,14 @@ inferPattern = \case
flattenType :: T.Type -> [T.Type]
flattenType (T.TFun a b) = flattenType a <> flattenType b
flattenType a = [a]
flattenType a = [a]
typeLength :: T.Type -> Int
typeLength (T.TFun a b) = typeLength a + typeLength b
typeLength _ = 1
typeLength _ = 1
litType :: Lit -> T.Type
litType (LInt _) = int
litType (LInt _) = int
litType (LChar _) = char
int = T.TLit "Int"
@ -681,8 +676,8 @@ partitionType = go []
go acc 0 t = (acc, t)
go acc i t = case t of
TAll tvar t' -> second (TAll tvar) $ go acc i t'
TFun t1 t2 -> go (acc <> [t1]) (i - 1) t2
_ -> error "Number of parameters and type doesn't match"
TFun t1 t2 -> go (acc <> [t1]) (i - 1) t2
_ -> error "Number of parameters and type doesn't match"
exprErr :: Infer a -> Exp -> Infer a
exprErr ma exp =
@ -695,3 +690,4 @@ unzip4 =
(as ++ [a], bs ++ [b], cs ++ [c], ds ++ [d])
)
([], [], [], [])

View file

@ -1,245 +1,135 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}
module TypeChecker.TypeCheckerIr (
module TypeChecker.TypeCheckerIr,
) where
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Char (isDigit)
import Data.Functor.Identity (Identity)
import Data.Map (Map)
import Data.Set (Set)
import Data.String qualified
import Grammar.Print
import Prelude
import Prelude qualified as C (Eq, Ord, Read, Show)
module TypeChecker.TypeCheckerIr
( module Grammar.Abs
, module TypeChecker.TypeCheckerIr
) where
newtype Ctx = Ctx {vars :: Map Ident Type}
deriving (Show)
import Data.String (IsString)
import Grammar.Abs (Character (..), Lit (..), TVar (..))
import Grammar.Print
import Prelude
import qualified Prelude as C (Eq, Ord, Read, Show)
data Env = Env
{ count :: Int
, nextChar :: Char
, sigs :: Map Ident (Maybe Type)
, constructors :: Map Ident Type
, takenTypeVars :: Set Ident
}
deriving (Show)
newtype Program' t = Program [Def' t]
deriving (C.Eq, C.Ord, C.Show, C.Read)
type Error = String
type Subst = Map Ident Type
type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))
newtype Program = Program [Def]
deriving (C.Eq, C.Ord, C.Show, C.Read)
data Data = Data Ident [Constructor]
deriving (Show, Eq, Ord, Read)
data Constructor = Constructor Ident Type
deriving (Show, Eq, Ord, Read)
newtype TVar = MkTVar Ident
deriving (Show, Eq, Ord, Read)
data Def' t = DBind (Bind' t)
| DData (Data' t)
deriving (C.Eq, C.Ord, C.Show, C.Read)
data Type
= TLit Ident
| TVar TVar
| TData Ident [Type]
| TFun Type Type
| TAll TVar Type
| TData Ident [Type]
deriving (Show, Eq, Ord, Read)
deriving (C.Eq, C.Ord, C.Show, C.Read)
data Exp
= EId Ident
| ELit Lit
| ELet Bind ExpT
| EApp ExpT ExpT
| EAdd ExpT ExpT
| EAbs Ident ExpT
| ECase ExpT [Branch]
deriving (C.Eq, C.Ord, C.Read, C.Show)
data Data' t = Data t [Inj' t]
deriving (C.Eq, C.Ord, C.Show, C.Read)
type ExpT = (Exp, Type)
data Branch = Branch (Pattern, Type) ExpT
deriving (C.Eq, C.Ord, C.Read, C.Show)
data Pattern = PVar Id | PLit (Lit, Type) | PInj Ident [Pattern] | PCatch | PEnum Ident
deriving (C.Eq, C.Ord, C.Show, C.Read)
data Def = DBind Bind | DData Data
deriving (C.Eq, C.Ord, C.Read, C.Show)
type Id = (Ident, Type)
data Inj' t = Inj Ident t
deriving (C.Eq, C.Ord, C.Show, C.Read)
newtype Ident = Ident String
deriving (C.Eq, C.Ord, C.Show, C.Read, Data.String.IsString)
deriving (C.Eq, C.Ord, C.Show, C.Read, IsString)
data Lit = LInt Integer | LChar Char
deriving (Show, Eq, Ord, Read)
data Pattern' t
= PVar (Id' t) -- TODO should be Ident
| PLit (Lit, t) -- TODO should be Lit
| PCatch
| PEnum Ident
| PInj Ident [Pattern' t] -- TODO should be (Pattern' t, t)
deriving (C.Eq, C.Ord, C.Show, C.Read)
data Bind = Bind Id [Id] ExpT
data Exp' t
= EVar Ident
| EInj Ident
| ELit Lit
| ELet (Bind' t) (ExpT' t)
| EApp (ExpT' t) (ExpT' t)
| EAdd (ExpT' t) (ExpT' t)
| EAbs Ident (ExpT' t)
| ECase (ExpT' t) [Branch' t]
deriving (C.Eq, C.Ord, C.Show, C.Read)
type Id' t = (Ident, t)
type ExpT' t = (Exp' t, t)
data Bind' t = Bind (Id' t) [Id' t] (ExpT' t)
deriving (C.Eq, C.Ord, C.Show, C.Read)
data Branch' t = Branch (Pattern' t, t) (ExpT' t)
deriving (C.Eq, C.Ord, C.Show, C.Read)
instance Print Ident where
prt _ (Ident str) = doc . showString $ str
prt i (Ident s) = prt i s
instance Print [Def] where
prt _ [] = concatD []
prt _ (x : xs) = concatD [prt 0 x, doc (showString "\n\n"), prt 0 xs]
instance Print Data where
prt i = \case
Data type_ constructors -> prPrec i 0 (concatD [doc (showString "data"), prt 0 type_, doc (showString "where"), doc (showString "{"), prt 0 constructors, doc (showString "}")])
instance Print Constructor where
prt i = \case
Constructor uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_])
instance Print [Constructor] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, prt 0 xs]
instance Print Def where
prt i (DBind bind) = prt i bind
prt i (DData d) = prt i d
instance Print Program where
instance Print t => Print (Program' t) where
prt i (Program sc) = prPrec i 0 $ prt 0 sc
instance Print Bind where
prt i (Bind (name, t) args rhs) =
prPrec i 0 $
concatD
[ prt 0 name
, doc $ showString ":"
, prt 0 t
, doc $ showString "\n"
, prt 0 name
, prtIdPs 0 args
, doc $ showString "="
, prt 0 rhs
]
instance Print t => Print (Bind' t) where
prt i (Bind sig@(name, _) parms rhs) = prPrec i 0 $ concatD
[ prtSig sig
, prt 0 name
, prtIdPs 0 parms
, doc $ showString "="
, prt 0 rhs
]
instance Print [Bind] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), doc (showString "\n"), prt 0 xs]
prtSig :: Print t => Id' t -> Doc
prtSig (name, t) = concatD [ prt 0 name
, doc $ showString ":"
, prt 0 t
, doc $ showString ";"
]
prtIdPs :: Int -> [Id] -> Doc
prtIdPs i = prPrec i 0 . concatD . map (prtIdP i)
instance Print t => Print (ExpT' t) where
prt i (e, t) = concatD [ doc $ showString "("
, prt i e
, doc $ showString ","
, prt i t
, doc $ showString ")"
]
prtId :: Int -> Id -> Doc
prtId i (name, t) =
prPrec i 0 $
concatD
[ doc $ showString "("
, prt 0 name
, doc $ showString ":"
, prt 0 t
, doc $ showString ")"
]
instance Print t => Print [Bind' t] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
prtIdP :: Int -> Id -> Doc
prtIdP i (name, t) =
prPrec i 0 $
concatD
[ doc $ showString "("
, prt 0 name
, doc $ showString ":"
, prt 0 t
, doc $ showString ")"
]
prtIdPs :: Print t => Int -> [Id' t] -> Doc
prtIdPs i = prPrec i 0 . concatD . map (prt i)
instance Print Exp where
prt i = \case
EId n -> prPrec i 3 $ concatD [prt 0 n]
ELit lit -> prPrec i 3 $ concatD [prt 0 lit]
ELet bs e ->
prPrec i 3 $
concatD
[ doc $ showString "let"
, prt 0 bs
, doc $ showString "in"
, prt 0 e
]
EApp e1 e2 ->
prPrec i 2 $
concatD
[ prt 2 e1
, prt 3 e2
]
EAdd e1 e2 ->
prPrec i 1 $
concatD
[ doc $ showString "@"
, prt 1 e1
, doc $ showString "+"
, prt 2 e2
]
EAbs n e ->
prPrec i 0 $
concatD
[ doc $ showString "λ"
, prt 0 n
, doc $ showString "."
, prt 0 e
]
ECase exp injs ->
prPrec
i
0
( concatD
[ doc (showString "case")
, prt 0 exp
, doc (showString "of")
, doc (showString "{")
, prt 0 injs
, doc (showString "}")
, doc (showString ":")
]
)
instance Print t => Print (Id' t) where
prt i (name, t) = concatD [ doc $ showString "("
, prt i name
, doc $ showString ","
, prt i t
, doc $ showString ")"
]
instance Print ExpT where
prt i (e, t) = concatD [doc $ showString "(", prt i e, doc (showString ":"), prt i t, doc $ showString ")"]
instance Print Branch where
prt i = \case
Branch (init, t) exp -> prPrec i 0 (concatD [prt 0 init, doc (showString ":"), prt 0 t, doc (showString "=>"), prt 0 exp])
instance Print Pattern where
prt i = \case
PVar lident -> prPrec i 0 (concatD [prtId 0 lident])
PLit (lit, typ) -> prPrec i 0 (concatD [doc $ showString "(", prt 0 lit, doc $ showString ",", prt 0 typ, doc $ showString ")"])
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 0 patterns])
PCatch -> prPrec i 0 (concatD [doc (showString "_")])
PEnum p -> prt i p
instance Print [Branch] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
instance Print TVar where
prt i (MkTVar id) = prt i id
instance Print Type where
prt i = \case
TLit uident -> prPrec i 2 (concatD [prt 0 uident])
TVar tvar@(MkTVar (Ident iden)) ->
if all isDigit iden
then prPrec i 2 (concatD [prt 0 $ TVar (MkTVar (Ident ("a" <> iden)))])
else prPrec i 2 (concatD [prt 0 tvar])
TAll tvar type_ -> prPrec i 1 (concatD [doc (showString "forall"), prt 0 tvar, doc (showString "."), prt 0 type_])
TData ident types -> prPrec i 1 (concatD [prt 0 ident, prt 0 types])
TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2])
instance Print Lit where
prt i = \case
LInt n -> prPrec i 0 (concatD [prt 0 n])
LChar c -> prPrec i 0 (concatD [prt 0 c])
instance Print t => Print (Exp' t) where
prt i = \case
EVar name -> prPrec i 3 $ prt 0 name
EInj name -> prPrec i 3 $ prt 0 name
ELit lit -> prPrec i 3 $ prt 0 lit
ELet b e -> prPrec i 3 $ concatD
[ doc $ showString "let"
, prt 0 b
, doc $ showString "in"
, prt 0 e
]
EApp e1 e2 -> prPrec i 2 $ concatD
[ prt 2 e1
, prt 3 e2
]
EAdd e1 e2 -> prPrec i 1 $ concatD
[ prt 1 e1
, doc $ showString "+"
, prt 2 e2
]
EAbs v e -> prPrec i 0 $ concatD
[ doc $ showString "\\"

View file

@ -0,0 +1,232 @@
{-# LANGUAGE OverloadedStrings #-}
{-# HLINT ignore "Use camelCase" #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}
module TestTypeCheckerBidir (testTypeCheckerBidir) where
import Test.Hspec
import Control.Monad ((<=<))
import Grammar.ErrM (Err, pattern Bad, pattern Ok)
import Grammar.Par (myLexer, pProgram)
import Renamer.Renamer (rename)
import TypeChecker.RemoveTEVar (RemoveTEVar (rmTEVar))
import TypeChecker.TypeCheckerBidir (typecheck)
import qualified TypeChecker.TypeCheckerIr as T
testTypeCheckerBidir = describe "Bidirectional type checker test" $ do
tc_id
tc_double
tc_add_lam
tc_const
tc_simple_rank2
tc_rank2
tc_identity
tc_pair
tc_tree
tc_mono_case
tc_pol_case
tc_id = specify "Basic identity function polymorphism" $ run
[ "id : forall a. a -> a;"
, "id x = x;"
, "main = id 4;"
] `shouldSatisfy` ok
tc_double = specify "Addition inference" $ run
["double x = x + x;"] `shouldSatisfy` ok
tc_add_lam = specify "Addition lambda inference" $ run
["four = (\\x. x + x) 2;"] `shouldSatisfy` ok
tc_const = specify "Basic polymorphism with multiple type variables" $ run
[ "const : forall a. forall b. a -> b -> a;"
, "const x y = x;"
, "main = const 'a' 65;"
] `shouldSatisfy` ok
tc_simple_rank2 = specify "Simple rank two polymorphism" $ run
[ "id : forall a. a -> a;"
, "id x = x;"
, "f : forall a. a -> (forall b. b -> b) -> a;"
, "f x g = g x;"
, "main = f 4 id;"
] `shouldSatisfy` ok
tc_rank2 = specify "Rank two polymorphism is ok" $ run
[ "const : forall a. forall b. a -> b -> a;"
, "const x y = x;"
, "rank2 : forall a. forall b. a -> (forall c. c -> Int) -> b -> Int;"
, "rank2 x f y = f x + f y;"
, "main = rank2 3 (\\x. const 5 x : forall a. a -> Int) 'h';"
] `shouldSatisfy` ok
tc_identity = describe "(∀b. b → b) should only accept the identity function" $ do
specify "identityᵢₙₜ is rejected" $ run (fs ++ id_int) `shouldNotSatisfy` ok
specify "identity is accepted" $ run (fs ++ id) `shouldSatisfy` ok
where
fs =
[ "f : forall a. a -> (forall b. b -> b) -> a;"
, "f x g = g x;"
, "id : forall a. a -> a;"
, "id x = x;"
, "id_int : Int -> Int;"
, "id_int x = x;"
]
id =
[ "main : Int;"
, "main = f 4 id;"
]
id_int =
[ "main : Int;"
, "main = f 4 id_int;"
]
tc_pair = describe "Pair. Type variables in Pair a b typechecked" $ do
specify "Wrong arguments are rejected" $ run (fs ++ wrong) `shouldNotSatisfy` ok
specify "Correct arguments are accepted" $ run (fs ++ correct) `shouldSatisfy` ok
where
fs =
[ "data forall a. forall b. Pair (a b) where {"
, " Pair : a -> b -> Pair (a b)"
, "};"
, "main : Pair (Int Char);"
]
wrong = ["main = Pair 'a' 65;"]
correct = ["main = Pair 65 'a';"]
tc_tree = describe "Tree. Recursive data type" $ do
specify "Wrong tree is rejected" $ run (fs ++ wrong) `shouldNotSatisfy` ok
specify "Correct tree is accepted" $ run (fs ++ correct) `shouldSatisfy` ok
where
fs =
[ "data forall a. Tree (a) where {"
, " Node : a -> Tree (a) -> Tree (a) -> Tree (a)"
, " Leaf : a -> Tree (a)"
, "};"
]
wrong = ["tree = Node 1 (Node 2 (Node 4) (Leaf 5)) (Leaf 3);"]
correct = ["tree = Node 1 (Node 2 (Leaf 4) (Leaf 5)) (Leaf 3);"]
tc_mono_case = describe "Monomorphic pattern matching" $ do
specify "First wrong case expression rejected"
$ run wrong1 `shouldNotSatisfy` ok
specify "Second wrong case expression rejected"
$ run wrong2 `shouldNotSatisfy` ok
specify "Third wrong case expression rejected"
$ run wrong3 `shouldNotSatisfy` ok
specify "First correct case expression accepted"
$ run correct1 `shouldSatisfy` ok
specify "Second correct case expression accepted"
$ run correct2 `shouldSatisfy` ok
where
wrong1 =
[ "simple : Int -> Int;"
, "simple c = case c of {"
, " 'F' => 0;"
, " 'T' => 1;"
, "};"
]
wrong2 =
[ "simple : Char -> Int;"
, "simple c = case c of {"
, " 'F' => 0;"
, " 1 => 1;"
, "};"
]
wrong3 =
[ "simple : Char -> Int;"
, "simple c = case c of {"
, " 'F' => 0;"
, " 'T' => '1';"
, "};"
]
correct1 =
[ "simple : Char -> Int;"
, "simple c = case c of {"
, " 'F' => 0;"
, " 'T' => 1;"
, "};"
]
correct2 =
[ "simple : Char -> Int;"
, "simple c = case c of {"
, " 'F' => 0;"
, " _ => 1;"
, "};"
]
tc_pol_case = describe "Polymophic pattern matching" $ do
specify "First wrong case expression rejected"
$ run (fs ++ wrong1) `shouldNotSatisfy` ok
specify "Second wrong case expression rejected"
$ run (fs ++ wrong2) `shouldNotSatisfy` ok
specify "Third wrong case expression rejected"
$ run (fs ++ wrong3) `shouldNotSatisfy` ok
specify "First correct case expression accepted"
$ run (fs ++ correct1) `shouldSatisfy` ok
specify "Second correct case expression accepted"
$ run (fs ++ correct2) `shouldSatisfy` ok
where
fs =
[ "data forall a. List (a) where {"
, " Nil : List (a)"
, " Cons : a -> List (a) -> List (a)"
, "};"
]
wrong1 =
[ "length : forall c. List (c) -> Int;"
, "length = \\list. case list of {"
, " Nil => 0;"
, " Cons 6 xs => 1 + length xs;"
, "};"
]
wrong2 =
[ "length : forall c. List (c) -> Int;"
, "length = \\list. case list of {"
, " Cons => 0;"
, " Cons x xs => 1 + length xs;"
, "};"
]
wrong3 =
[ "length : forall c. List (c) -> Int;"
, "length = \\list. case list of {"
, " 0 => 0;"
, " Cons x xs => 1 + length xs;"
, "};"
]
correct1 =
[ "length : forall c. List (c) -> Int;"
, "length = \\list. case list of {"
, " Nil => 0;"
, " Cons x xs => 1 + length xs;"
, " Cons x (Cons y Nil) => 2;"
, "};"
]
correct2 =
[ "length : forall c. List (c) -> Int;"
, "length = \\list. case list of {"
, " Nil => 0;"
, " non_empty => 1;"
, "};"
]
run :: [String] -> Err T.Program
run = rmTEVar <=< typecheck <=< pProgram . myLexer . unlines
ok = \case
Ok _ -> True
Bad _ -> False

113
tests/TestTypeCheckerHm.hs Normal file
View file

@ -0,0 +1,113 @@
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE QualifiedDo #-}
module TestTypeCheckerHm (testTypeCheckerHm) where
import Control.Monad ((<=<))
import qualified DoStrings as D
import Grammar.Par (myLexer, pProgram)
import Prelude (Bool (..), Either (..), IO, fmap,
not, ($), (.))
import Test.Hspec
-- import Test.QuickCheck
import TypeChecker.TypeCheckerHm (typecheck)
testTypeCheckerHm = describe "Hillner Milner type checker test" $ do
ok1
ok2
bad1
bad2
-- bad3
ok1 =
specify "Basic polymorphism with multiple type variables" $
run
( D.do
const
"main = const 'a' 65 ;"
)
`shouldSatisfy` ok
ok2 =
specify "Head with a correct signature is accepted" $
run
( D.do
list
headSig
head
)
`shouldSatisfy` ok
bad1 =
specify "Infinite type unification should not succeed" $
run
( D.do
"main = \\x. x x ;"
)
`shouldSatisfy` bad
bad2 =
specify "Pattern matching using different types should not succeed" $
run
( D.do
list
"bad xs = case xs of {"
" 1 => 0 ;"
" Nil => 0 ;"
"};"
)
`shouldSatisfy` bad
bad3 =
specify "Using a concrete function on a skolem variable should not succeed" $
run
( D.do
bool
_not
"f : a -> Bool () ;"
" f x = not x ;"
)
`shouldSatisfy` bad
run = typecheck <=< pProgram . myLexer
ok (Right _) = True
ok (Left _) = False
bad = not . ok
-- FUNCTIONS
const = D.do
"const : a -> b -> a ;"
"const x y = x ;"
list = D.do
"data List (a) where"
" {"
" Nil : List (a)"
" Cons : a -> List (a) -> List (a)"
" };"
headSig = D.do
"head : List (a) -> a ;"
head = D.do
"head xs = "
" case xs of {"
" Cons x xs => x ;"
" };"
bool = D.do
"data Bool () where {"
" True : Bool ()"
" False : Bool ()"
"};"
_not = D.do
"not : Bool () -> Bool () ;"
"not x = case x of {"
" True => False ;"
" False => True ;"
"};"

View file

@ -1,6 +1,6 @@
module DoStrings where
import Prelude hiding ((>>), (>>=))
import Prelude hiding ((>>), (>>=))
(>>) :: String -> String -> String
(>>) str1 str2 = str1 ++ "\n" ++ str2

10
tests/Tests.hs Normal file
View file

@ -0,0 +1,10 @@
module Main where
import Test.Hspec
import TestTypeCheckerBidir (testTypeCheckerBidir)
import TestTypeCheckerHm (testTypeCheckerHm)
main = hspec $ do
testTypeCheckerBidir
testTypeCheckerHm