Add bidirectional type checker, lambda lifter.
This commit is contained in:
parent
2fa30faa87
commit
ac3f222753
22 changed files with 2440 additions and 577 deletions
96
Grammar.cf
96
Grammar.cf
|
|
@ -3,94 +3,94 @@
|
|||
-- * PROGRAM
|
||||
-------------------------------------------------------------------------------
|
||||
|
||||
Program. Program ::= [Def] ;
|
||||
Program. Program ::= [Def];
|
||||
|
||||
-------------------------------------------------------------------------------
|
||||
-- * TOP-LEVEL
|
||||
-------------------------------------------------------------------------------
|
||||
|
||||
DBind. Def ::= Bind ;
|
||||
DSig. Def ::= Sig ;
|
||||
DData. Def ::= Data ;
|
||||
DBind. Def ::= Bind;
|
||||
DSig. Def ::= Sig;
|
||||
DData. Def ::= Data;
|
||||
|
||||
Sig. Sig ::= LIdent ":" Type ;
|
||||
|
||||
Bind. Bind ::= LIdent [LIdent] "=" Exp ;
|
||||
Sig. Sig ::= LIdent ":" Type;
|
||||
Bind. Bind ::= LIdent [LIdent] "=" Exp;
|
||||
|
||||
-------------------------------------------------------------------------------
|
||||
-- * TYPES
|
||||
-- * Types
|
||||
-------------------------------------------------------------------------------
|
||||
|
||||
TLit. Type2 ::= UIdent ;
|
||||
TVar. Type2 ::= TVar ;
|
||||
TAll. Type1 ::= "forall" TVar "." Type ;
|
||||
TData. Type1 ::= UIdent "(" [Type] ")" ;
|
||||
internal TEVar. Type1 ::= TEVar ;
|
||||
TFun. Type ::= Type1 "->" Type ;
|
||||
TLit. Type1 ::= UIdent; -- τ
|
||||
TVar. Type1 ::= TVar; -- α
|
||||
internal TEVar. Type1 ::= TEVar; -- ά
|
||||
TData. Type1 ::= UIdent "(" [Type] ")"; -- D ()
|
||||
TFun. Type ::= Type1 "->" Type; -- A → A
|
||||
TAll. Type ::= "forall" TVar "." Type; -- ∀α. A
|
||||
|
||||
MkTVar. TVar ::= LIdent ;
|
||||
internal MkTEVar. TEVar ::= LIdent ;
|
||||
MkTVar. TVar ::= LIdent;
|
||||
internal MkTEVar. TEVar ::= LIdent;
|
||||
|
||||
-------------------------------------------------------------------------------
|
||||
-- * DATA TYPES
|
||||
-------------------------------------------------------------------------------
|
||||
|
||||
Constructor. Constructor ::= UIdent ":" Type ;
|
||||
Data. Data ::= "data" Type "where" "{" [Inj] "}" ;
|
||||
|
||||
Data. Data ::= "data" Type "where" "{" [Constructor] "}" ;
|
||||
Inj. Inj ::= UIdent ":" Type ;
|
||||
separator nonempty Inj " " ;
|
||||
|
||||
-------------------------------------------------------------------------------
|
||||
-- * EXPRESSIONS
|
||||
-- * Expressions
|
||||
-------------------------------------------------------------------------------
|
||||
|
||||
EAnn. Exp4 ::= "(" Exp ":" Type ")" ;
|
||||
EVar. Exp3 ::= LIdent ;
|
||||
EInj. Exp3 ::= UIdent ;
|
||||
ELit. Exp3 ::= Lit ;
|
||||
EApp. Exp2 ::= Exp2 Exp3 ;
|
||||
EAdd. Exp1 ::= Exp1 "+" Exp2 ;
|
||||
ELet. Exp ::= "let" Bind "in" Exp ;
|
||||
EAbs. Exp ::= "\\" LIdent "." Exp ;
|
||||
ECase. Exp ::= "case" Exp "of" "{" [Branch] "}";
|
||||
EAnn. Exp4 ::= "(" Exp ":" Type ")";
|
||||
EVar. Exp3 ::= LIdent;
|
||||
EInj. Exp3 ::= UIdent;
|
||||
ELit. Exp3 ::= Lit;
|
||||
EApp. Exp2 ::= Exp2 Exp3;
|
||||
EAdd. Exp1 ::= Exp1 "+" Exp2;
|
||||
ELet. Exp ::= "let" Bind "in" Exp;
|
||||
EAbs. Exp ::= "\\" LIdent "." Exp;
|
||||
ECase. Exp ::= "case" Exp "of" "{" [Branch] "}";
|
||||
|
||||
-------------------------------------------------------------------------------
|
||||
-- * LITERALS
|
||||
-------------------------------------------------------------------------------
|
||||
|
||||
LInt. Lit ::= Integer ;
|
||||
LChar. Lit ::= Char ;
|
||||
LInt. Lit ::= Integer;
|
||||
LChar. Lit ::= Character;
|
||||
|
||||
-------------------------------------------------------------------------------
|
||||
-- * CASE
|
||||
-- * PATTERN MATCHING
|
||||
-------------------------------------------------------------------------------
|
||||
|
||||
Branch. Branch ::= Pattern "=>" Exp ;
|
||||
|
||||
PVar. Pattern1 ::= LIdent ;
|
||||
PLit. Pattern1 ::= Lit ;
|
||||
PCatch. Pattern1 ::= "_" ;
|
||||
PEnum. Pattern1 ::= UIdent ;
|
||||
PInj. Pattern ::= UIdent [Pattern1] ;
|
||||
PVar. Pattern1 ::= LIdent;
|
||||
PLit. Pattern1 ::= Lit;
|
||||
PCatch. Pattern1 ::= "_";
|
||||
PEnum. Pattern1 ::= UIdent;
|
||||
PInj. Pattern ::= UIdent [Pattern1];
|
||||
|
||||
-------------------------------------------------------------------------------
|
||||
-- * AUX
|
||||
-------------------------------------------------------------------------------
|
||||
|
||||
terminator Def ";" ;
|
||||
separator nonempty Constructor "" ;
|
||||
separator Type " " ;
|
||||
separator nonempty Pattern1 " " ;
|
||||
terminator Def ";";
|
||||
terminator Branch ";" ;
|
||||
separator Ident " ";
|
||||
separator LIdent " ";
|
||||
separator TVar " " ;
|
||||
|
||||
coercions Exp 4 ;
|
||||
coercions Type 2 ;
|
||||
coercions Pattern 1 ;
|
||||
separator LIdent "";
|
||||
separator Type " ";
|
||||
separator TVar " ";
|
||||
separator nonempty Pattern1 " ";
|
||||
|
||||
coercions Pattern 1;
|
||||
coercions Exp 4;
|
||||
coercions Type 1 ;
|
||||
|
||||
token Character '\''(char)'\'' ;
|
||||
token UIdent (upper (letter | digit | '_')*) ;
|
||||
token LIdent (lower (letter | digit | '_')*) ;
|
||||
|
||||
comment "--" ;
|
||||
comment "{-" "-}" ;
|
||||
comment "--";
|
||||
comment "{-" "-}";
|
||||
|
|
|
|||
|
|
@ -31,13 +31,18 @@ executable language
|
|||
Grammar.Skel
|
||||
Grammar.ErrM
|
||||
Auxiliary
|
||||
Renamer.Renamer
|
||||
TypeChecker.TypeChecker
|
||||
TypeChecker.TypeCheckerHm
|
||||
TypeChecker.TypeCheckerBidir
|
||||
TypeChecker.TypeCheckerIr
|
||||
TypeChecker.RemoveTEVar
|
||||
LambdaLifter
|
||||
Monomorphizer.Monomorphizer
|
||||
Monomorphizer.MonomorphizerIr
|
||||
Renamer.Renamer
|
||||
Codegen.Codegen
|
||||
Codegen.LlvmIr
|
||||
Compiler
|
||||
|
||||
hs-source-dirs: src
|
||||
|
||||
|
|
@ -60,6 +65,9 @@ Test-suite language-testsuite
|
|||
main-is: Tests.hs
|
||||
|
||||
other-modules:
|
||||
TestTypeCheckerBidir
|
||||
TestTypeCheckerHm
|
||||
|
||||
Grammar.Abs
|
||||
Grammar.Lex
|
||||
Grammar.Par
|
||||
|
|
@ -67,9 +75,11 @@ Test-suite language-testsuite
|
|||
Grammar.Skel
|
||||
Grammar.ErrM
|
||||
Auxiliary
|
||||
TypeChecker.TypeChecker
|
||||
TypeChecker.TypeCheckerIr
|
||||
Renamer.Renamer
|
||||
TypeChecker.TypeCheckerHm
|
||||
TypeChecker.TypeCheckerBidir
|
||||
TypeChecker.RemoveTEVar
|
||||
TypeChecker.TypeCheckerIr
|
||||
Compiler
|
||||
|
||||
hs-source-dirs: src, tests, tests/TypecheckingHM
|
||||
|
|
@ -87,3 +97,4 @@ Test-suite language-testsuite
|
|||
, bytestring
|
||||
|
||||
default-language: GHC2021
|
||||
|
||||
|
|
|
|||
11
sample-programs/basic-0
Normal file
11
sample-programs/basic-0
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
data forall a. List (a) where {
|
||||
Nil : List (a)
|
||||
Cons : a -> List (a) -> List (a)
|
||||
};
|
||||
|
||||
length : forall c. List (c) -> Int;
|
||||
length = \list. case list of {
|
||||
Nil => 0;
|
||||
Cons x xs => 1 + length xs;
|
||||
Cons x (Cons y Nil) => 2;
|
||||
};
|
||||
|
|
@ -3,3 +3,4 @@ add x = \y. x+y;
|
|||
|
||||
main : Int ;
|
||||
main = (\z. z+z) ((add 4) 6) ;
|
||||
|
||||
|
|
|
|||
121
spec.txt
Normal file
121
spec.txt
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
---------------------------------------------------------------------------
|
||||
-- * Parser
|
||||
---------------------------------------------------------------------------
|
||||
|
||||
data Program = Program [Def]
|
||||
|
||||
data Def = DSig Ident Type | DBind Bind
|
||||
|
||||
data Bind = Bind Ident [Ident] Exp
|
||||
|
||||
data Exp
|
||||
= EId Ident
|
||||
| ELit Lit
|
||||
| EAnn Exp Type
|
||||
| ELet Ident Exp Exp
|
||||
| EApp Exp Exp
|
||||
| EAdd Exp Exp
|
||||
| EAbs Ident Exp
|
||||
|
||||
data Lit = LInt Integer
|
||||
| LChar Character
|
||||
|
||||
data Type
|
||||
= TLit Ident -- τ
|
||||
| TVar TVar -- α
|
||||
| TFun Type Type -- A → A
|
||||
| TAll TVar Type -- ∀α. A
|
||||
| TEVar TEVar -- ά (internal)
|
||||
|
||||
data TVar = MkTVar Ident
|
||||
data TEVar = MkTEVar Ident
|
||||
|
||||
---------------------------------------------------------------------------
|
||||
-- * Type checker
|
||||
---------------------------------------------------------------------------
|
||||
|
||||
-- • Def and DSig are removed in favor on just Bind
|
||||
-- • Typed expressions
|
||||
-- • TEVar is removed (NOT IMPLEMENTED)
|
||||
|
||||
newtype Program = Program [Bind]
|
||||
|
||||
data Bind = Bind Id [Id] ExpT
|
||||
|
||||
data Exp
|
||||
= EId Ident
|
||||
| ELit Lit
|
||||
| ELet Bind ExpT
|
||||
| EApp ExpT ExpT
|
||||
| EAdd ExpT ExpT
|
||||
| EAbs Ident ExpT
|
||||
|
||||
type Id = (Ident, Type)
|
||||
type ExpT = (Exp, Type)
|
||||
|
||||
|
||||
data Lit = LInt Integer
|
||||
| LChar Character
|
||||
|
||||
data Type
|
||||
= TLit Ident -- τ
|
||||
| TVar TVar -- α
|
||||
| TFun Type Type -- A → A
|
||||
| TAll TVar Type -- ∀α. A
|
||||
|
||||
data TVar = MkTVar Ident
|
||||
|
||||
---------------------------------------------------------------------------
|
||||
-- * Lambda lifter
|
||||
---------------------------------------------------------------------------
|
||||
-- • EAbs are removed (NOT IMPLEMENTED)
|
||||
-- • ELet only allow constant expressions (NOT IMPLEMENTED)
|
||||
|
||||
newtype Program = Program [Bind]
|
||||
|
||||
data Bind = Bind Id [Id] ExpT
|
||||
|
||||
data Exp
|
||||
= EId Ident
|
||||
| ELit Lit
|
||||
| ELet Ident ExpT ExpT
|
||||
| EApp ExpT ExpT
|
||||
| EAdd ExpT ExpT
|
||||
|
||||
type Id = (Ident, Type)
|
||||
type ExpT = (Exp, Type)
|
||||
|
||||
data Lit = LInt Integer
|
||||
| LChar Character
|
||||
|
||||
data Type
|
||||
= TLit Ident -- τ
|
||||
| TVar TVar -- α
|
||||
| TFun Type Type -- A → A
|
||||
| TAll TVar Type -- ∀α. A
|
||||
|
||||
data TVar = MkTVar Ident
|
||||
|
||||
---------------------------------------------------------------------------
|
||||
-- * Monomorpher
|
||||
---------------------------------------------------------------------------
|
||||
-- • Polymorphic types are removed (NOT IMPLEMENTED)
|
||||
|
||||
newtype Program = Program [Bind]
|
||||
|
||||
data Bind = Bind Id [Id] ExpT
|
||||
|
||||
data Exp
|
||||
= EId Ident
|
||||
| ELit Lit
|
||||
| ELet Ident ExpT ExpT
|
||||
| EApp ExpT ExpT
|
||||
| EAdd ExpT ExpT
|
||||
|
||||
type Id = (Ident, Type)
|
||||
type ExpT = (Exp, Type)
|
||||
|
||||
data Lit = LInt Integer
|
||||
| LChar Character
|
||||
|
||||
data Type = Type Ident
|
||||
|
|
@ -3,6 +3,7 @@ module Auxiliary (module Auxiliary) where
|
|||
import Control.Monad.Error.Class (liftEither)
|
||||
import Control.Monad.Except (MonadError)
|
||||
import Data.Either.Combinators (maybeToRight)
|
||||
import TypeChecker.TypeCheckerIr (Type (TFun))
|
||||
|
||||
snoc :: a -> [a] -> [a]
|
||||
snoc x xs = xs ++ [x]
|
||||
|
|
@ -19,3 +20,4 @@ mapAccumM f = go
|
|||
(acc', x') <- f acc x
|
||||
(acc'', xs') <- go acc' xs
|
||||
pure (acc'', x':xs')
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ import Data.Tuple.Extra (dupe, first, second)
|
|||
import Debug.Trace (trace)
|
||||
import qualified Grammar.Abs as GA
|
||||
import Grammar.ErrM (Err)
|
||||
import Monomorphizer.MonomorphizerIr (Ident (..))
|
||||
import Monomorphizer.MonomorphizerIr as MIR
|
||||
|
||||
-- | The record used as the code generator state
|
||||
|
|
@ -57,8 +58,13 @@ getVarCount :: CompilerState Integer
|
|||
getVarCount = gets variableCount
|
||||
|
||||
-- | Increases the variable count and returns it from the CodeGenerator state
|
||||
<<<<<<< HEAD
|
||||
getNewVar :: CompilerState GA.Ident
|
||||
getNewVar = GA.Ident . show <$> (increaseVarCount >> getVarCount)
|
||||
=======
|
||||
getNewVar :: CompilerState Ident
|
||||
getNewVar = (Ident . show) <$> (increaseVarCount >> getVarCount)
|
||||
>>>>>>> da28c6d (Add bidirectional type checker, lambda lifter.)
|
||||
|
||||
-- | Increses the label count and returns a label from the CodeGenerator state
|
||||
getNewLabel :: CompilerState Integer
|
||||
|
|
@ -76,10 +82,25 @@ getFunctions bs = Map.fromList $ go bs
|
|||
go (MIR.DBind (MIR.Bind id args _) : xs) =
|
||||
(id, FunctionInfo{numArgs = length args, arguments = args})
|
||||
: go xs
|
||||
<<<<<<< HEAD
|
||||
go (_ : xs) = go xs
|
||||
=======
|
||||
go (MIR.DData (MIR.Data n cons) : xs) =
|
||||
do map
|
||||
( \(Inj id xs) ->
|
||||
( (coerce id, MIR.TLit (extractTypeName n))
|
||||
, FunctionInfo
|
||||
{ numArgs = undefined -- TODO
|
||||
, arguments = createArgs (snd <$> undefined) -- TODO
|
||||
}
|
||||
)
|
||||
)
|
||||
cons
|
||||
<> go xs
|
||||
>>>>>>> da28c6d (Add bidirectional type checker, lambda lifter.)
|
||||
|
||||
createArgs :: [MIR.Type] -> [Id]
|
||||
createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(GA.Ident ("arg_" <> show l), t)], l + 1)) ([], 0) xs
|
||||
createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(Ident ("arg_" <> show l), t)], l + 1)) ([], 0) xs
|
||||
|
||||
{- | Produces a map of functions infos from a list of binds,
|
||||
which contains useful data for code generation.
|
||||
|
|
@ -89,6 +110,7 @@ getConstructors bs = Map.fromList $ go bs
|
|||
where
|
||||
go [] = []
|
||||
go (MIR.DData (MIR.Data t cons) : xs) =
|
||||
<<<<<<< HEAD
|
||||
fst
|
||||
( foldl
|
||||
( \(acc, i) (Constructor id xs) ->
|
||||
|
|
@ -96,6 +118,17 @@ getConstructors bs = Map.fromList $ go bs
|
|||
, ConstructorInfo
|
||||
{ numArgsCI = length (init . flattenType $ xs)
|
||||
, argumentsCI = createArgs (init . flattenType $ xs)
|
||||
=======
|
||||
do
|
||||
let (Ident n) = extractTypeName t
|
||||
fst
|
||||
( foldl
|
||||
( \(acc, i) (Inj (Ident id) xs) ->
|
||||
( ( (Ident (n <> "_" <> id), MIR.TLit (coerce n))
|
||||
, ConstructorInfo
|
||||
{ numArgsCI = undefined -- TODO
|
||||
, argumentsCI = createArgs (snd <$> undefined) -- TODO
|
||||
>>>>>>> da28c6d (Add bidirectional type checker, lambda lifter.)
|
||||
, numCI = i
|
||||
, returnTypeCI = t --last . flattenType $ xs
|
||||
}
|
||||
|
|
@ -133,30 +166,30 @@ test :: Integer -> Program
|
|||
test v =
|
||||
Program
|
||||
[ DataType
|
||||
(GA.Ident "Craig")
|
||||
[ Constructor (GA.Ident "Bob") [MIR.Type (GA.Ident "_Int")]
|
||||
, Constructor (GA.Ident "Betty") [MIR.Type (GA.Ident "_Int")]
|
||||
(Ident "Craig")
|
||||
[ Constructor (Ident "Bob") [MIR.Type (Ident "_Int")]
|
||||
, Constructor (Ident "Betty") [MIR.Type (Ident "_Int")]
|
||||
]
|
||||
, DataType
|
||||
(GA.Ident "Alice")
|
||||
[ Constructor (GA.Ident "Eve") [MIR.Type (GA.Ident "_Int")] -- ,
|
||||
-- (GA.Ident "Alice", [TInt, TInt])
|
||||
(Ident "Alice")
|
||||
[ Constructor (Ident "Eve") [MIR.Type (Ident "_Int")] -- ,
|
||||
-- (Ident "Alice", [TInt, TInt])
|
||||
]
|
||||
, Bind (GA.Ident "fibonacci", MIR.Type (GA.Ident "_Int")) [(GA.Ident "x", MIR.Type (GA.Ident "_Int"))] (EId ("x", MIR.Type (GA.Ident "Craig")), MIR.Type (GA.Ident "Craig"))
|
||||
, Bind (GA.Ident "main", MIR.Type (GA.Ident "_Int")) []
|
||||
-- (EApp (MIR.Type (GA.Ident "Craig")) (EId (GA.Ident "Craig_Bob", MIR.Type (GA.Ident "Craig")), MIR.Type (GA.Ident "Craig")) (ELit (LInt v), MIR.Type (GA.Ident "_Int")), MIR.Type (GA.Ident "Craig"))-- (EInt 92)
|
||||
, Bind (Ident "fibonacci", MIR.Type (Ident "_Int")) [(Ident "x", MIR.Type (Ident "_Int"))] (EVar ("x", MIR.Type (Ident "Craig")), MIR.Type (Ident "Craig"))
|
||||
, Bind (Ident "main", MIR.Type (Ident "_Int")) []
|
||||
-- (EApp (MIR.Type (Ident "Craig")) (EVar (Ident "Craig_Bob", MIR.Type (Ident "Craig")), MIR.Type (Ident "Craig")) (ELit (LInt v), MIR.Type (Ident "_Int")), MIR.Type (Ident "Craig"))-- (EInt 92)
|
||||
$
|
||||
eCaseInt
|
||||
(EApp (MIR.TLit (GA.Ident "Craig")) (EId (GA.Ident "Craig_Bob", MIR.TLit (GA.Ident "Craig")), MIR.TLit (GA.Ident "Craig")) (ELit (LInt v), MIR.Type (GA.Ident "_Int")), MIR.Type (GA.Ident "Craig"))
|
||||
[ injectionCons "Craig_Bob" "Craig" [CIdent (GA.Ident "x")] (EId (GA.Ident "x", MIR.Type (GA.Ident "_Int")), MIR.Type (GA.Ident "_Int"))
|
||||
(EApp (MIR.TLit (Ident "Craig")) (EVar (Ident "Craig_Bob", MIR.TLit (Ident "Craig")), MIR.TLit (Ident "Craig")) (ELit (LInt v), MIR.Type (Ident "_Int")), MIR.Type (Ident "Craig"))
|
||||
[ injectionCons "Craig_Bob" "Craig" [CIdent (Ident "x")] (EVar (Ident "x", MIR.Type (Ident "_Int")), MIR.Type (Ident "_Int"))
|
||||
, injectionCons "Craig_Betty" "Craig" [CLit (LInt 5)] (int 2)
|
||||
, Injection (CIdent (GA.Ident "z")) (int 3)
|
||||
, Injection (CIdent (Ident "z")) (int 3)
|
||||
, -- , injectionInt 5 (int 6)
|
||||
injectionCatchAll (int 10)
|
||||
]
|
||||
]
|
||||
where
|
||||
injectionCons x y xs = Injection (CCons (GA.Ident x, MIR.Type (GA.Ident y)) xs)
|
||||
injectionCons x y xs = Injection (CCons (Ident x, MIR.Type (Ident y)) xs)
|
||||
injectionInt x = Injection (CLit (LInt x))
|
||||
injectionCatchAll = Injection CatchAll
|
||||
eCaseInt x xs = (ECase (MIR.TLit (MIR.Ident "_Int")) x xs, MIR.TLit (MIR.Ident "_Int"))
|
||||
|
|
@ -206,7 +239,7 @@ compileScs [] = do
|
|||
emit $ SetVariable ptr' (Bitcast (Ref t') (VIdent top Ptr) (Ref $ CustomType id))
|
||||
|
||||
enumerateOneM_
|
||||
( \i (GA.Ident arg_n, arg_t) -> do
|
||||
( \i (Ident arg_n, arg_t) -> do
|
||||
let arg_t' = type2LlvmType arg_t
|
||||
emit $ Comment (toIr arg_t' <> " " <> arg_n <> " " <> show i)
|
||||
elemPtr <- getNewVar
|
||||
|
|
@ -222,7 +255,7 @@ compileScs [] = do
|
|||
I32
|
||||
(VInteger i)
|
||||
)
|
||||
emit $ Store arg_t' (VIdent (GA.Ident arg_n) arg_t') Ptr elemPtr
|
||||
emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr elemPtr
|
||||
)
|
||||
(argumentsCI ci)
|
||||
|
||||
|
|
@ -255,8 +288,13 @@ compileScs (MIR.DData (MIR.Data typ ts) : xs) = do
|
|||
let biggestVariant = 7 + maximum (sum . (\(Constructor _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
|
||||
emit $ LIR.Type (Ident outer_id) [I8, Array biggestVariant I8]
|
||||
mapM_
|
||||
<<<<<<< HEAD
|
||||
( \(Constructor inner_id fi) -> do
|
||||
emit $ LIR.Type inner_id (I8 : variantTypes fi)
|
||||
=======
|
||||
( \(Inj (Ident inner_id) fi) -> do
|
||||
emit $ LIR.Type (Ident $ outer_id <> "_" <> inner_id) (I8 : map type2LlvmType (snd <$> undefined)) -- TODO
|
||||
>>>>>>> da28c6d (Add bidirectional type checker, lambda lifter.)
|
||||
)
|
||||
ts
|
||||
compileScs xs
|
||||
|
|
@ -282,17 +320,17 @@ mainContent var =
|
|||
-- " %4 = load i72, ptr %3\n" <>
|
||||
-- " call i32 (ptr, ...) @printf(ptr noundef @.str, i72 noundef %4)\n"
|
||||
"call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> toIr var <> ")\n"
|
||||
, -- , SetVariable (GA.Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
|
||||
-- , BrCond (VIdent (GA.Ident "p")) (GA.Ident "b_1") (GA.Ident "b_2")
|
||||
-- , Label (GA.Ident "b_1")
|
||||
, -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
|
||||
-- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2")
|
||||
-- , Label (Ident "b_1")
|
||||
-- , UnsafeRaw
|
||||
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n"
|
||||
-- , Br (GA.Ident "end")
|
||||
-- , Label (GA.Ident "b_2")
|
||||
-- , Br (Ident "end")
|
||||
-- , Label (Ident "b_2")
|
||||
-- , UnsafeRaw
|
||||
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n"
|
||||
-- , Br (GA.Ident "end")
|
||||
-- , Label (GA.Ident "end")
|
||||
-- , Br (Ident "end")
|
||||
-- , Label (Ident "end")
|
||||
Ret I64 (VInteger 0)
|
||||
]
|
||||
|
||||
|
|
@ -310,7 +348,7 @@ compileExp :: ExpT -> CompilerState ()
|
|||
compileExp (MIR.ELit lit,t) = emitLit lit
|
||||
compileExp (MIR.EAdd e1 e2,t) = emitAdd t e1 e2
|
||||
-- compileExp (ESub t e1 e2) = emitSub t e1 e2
|
||||
compileExp (MIR.EId name,t) = emitIdent name
|
||||
compileExp (MIR.EVar name,t) = emitIdent name
|
||||
compileExp (MIR.EApp e1 e2,t) = emitApp t e1 e2
|
||||
-- compileExp (EAbs t ti e) = emitAbs t ti e
|
||||
compileExp (MIR.ELet binds e,t) = undefined -- emitLet binds (fst e)
|
||||
|
|
@ -328,7 +366,7 @@ emitECased t e cases = do
|
|||
let rt = type2LlvmType (snd e)
|
||||
vs <- exprToValue e
|
||||
lbl <- getNewLabel
|
||||
let label = GA.Ident $ "escape_" <> show lbl
|
||||
let label = Ident $ "escape_" <> show lbl
|
||||
stackPtr <- getNewVar
|
||||
emit $ SetVariable stackPtr (Alloca ty)
|
||||
mapM_ (emitCases rt ty label stackPtr vs) cs
|
||||
|
|
@ -341,13 +379,13 @@ emitECased t e cases = do
|
|||
res <- getNewVar
|
||||
emit $ SetVariable res (Load ty Ptr stackPtr)
|
||||
where
|
||||
emitCases :: LLVMType -> LLVMType -> GA.Ident -> GA.Ident -> LLVMValue -> Branch -> CompilerState ()
|
||||
emitCases :: LLVMType -> LLVMType -> Ident -> Ident -> LLVMValue -> Branch -> CompilerState ()
|
||||
emitCases rt ty label stackPtr vs (Branch (MIR.PInj consId cs, t) exp) = do
|
||||
cons <- gets constructors
|
||||
let r = fromJust $ Map.lookup consId cons
|
||||
|
||||
lbl_failPos <- (\x -> GA.Ident $ "failed_" <> show x) <$> getNewLabel
|
||||
lbl_succPos <- (\x -> GA.Ident $ "success_" <> show x) <$> getNewLabel
|
||||
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
|
||||
lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
|
||||
|
||||
consVal <- getNewVar
|
||||
emit $ SetVariable consVal (ExtractValue rt vs 0)
|
||||
|
|
@ -397,8 +435,8 @@ emitECased t e cases = do
|
|||
(MIR.LInt i, _) -> VInteger i
|
||||
(MIR.LChar i, _) -> VChar i
|
||||
ns <- getNewVar
|
||||
lbl_failPos <- (\x -> GA.Ident $ "failed_" <> show x) <$> getNewLabel
|
||||
lbl_succPos <- (\x -> GA.Ident $ "success_" <> show x) <$> getNewLabel
|
||||
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
|
||||
lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
|
||||
emit $ SetVariable ns (Icmp LLEq ty vs i')
|
||||
emit $ BrCond (VIdent ns ty) lbl_succPos lbl_failPos
|
||||
emit $ Label lbl_succPos
|
||||
|
|
@ -444,8 +482,13 @@ emitApp rt e1 e2 = appEmitter e1 e2 []
|
|||
appEmitter e1 e2 stack = do
|
||||
let newStack = e2 : stack
|
||||
case e1 of
|
||||
<<<<<<< HEAD
|
||||
(MIR.EApp e1' e2', _) -> appEmitter e1' e2' newStack
|
||||
(MIR.EId name, t) -> do
|
||||
=======
|
||||
(MIR.EApp e1' e2', t) -> appEmitter e1' e2' newStack
|
||||
(MIR.EVar name, t) -> do
|
||||
>>>>>>> da28c6d (Add bidirectional type checker, lambda lifter.)
|
||||
args <- traverse exprToValue newStack
|
||||
vs <- getNewVar
|
||||
funcs <- gets functions
|
||||
|
|
@ -462,7 +505,7 @@ emitApp rt e1 e2 = appEmitter e1 e2 []
|
|||
emit $ SetVariable vs call
|
||||
x -> error $ "The unspeakable happened: " <> show x
|
||||
|
||||
emitIdent :: GA.Ident -> CompilerState ()
|
||||
emitIdent :: Ident -> CompilerState ()
|
||||
emitIdent id = do
|
||||
-- !!this should never happen!!
|
||||
emit $ Comment "This should not have happened!"
|
||||
|
|
@ -477,14 +520,14 @@ emitLit i = do
|
|||
(MIR.LChar i'') -> (VChar i'', I8)
|
||||
varCount <- getNewVar
|
||||
emit $ Comment "This should not have happened!"
|
||||
emit $ SetVariable (GA.Ident (show varCount)) (Add t i' (VInteger 0))
|
||||
emit $ SetVariable (Ident (show varCount)) (Add t i' (VInteger 0))
|
||||
|
||||
emitAdd :: MIR.Type -> ExpT -> ExpT -> CompilerState ()
|
||||
emitAdd t e1 e2 = do
|
||||
v1 <- exprToValue e1
|
||||
v2 <- exprToValue e2
|
||||
v <- getNewVar
|
||||
emit $ SetVariable (GA.Ident $ show v) (Add (type2LlvmType t) v1 v2)
|
||||
emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2)
|
||||
|
||||
emitSub :: MIR.Type -> ExpT -> ExpT -> CompilerState ()
|
||||
emitSub t e1 e2 = do
|
||||
|
|
@ -498,7 +541,7 @@ exprToValue = \case
|
|||
(MIR.ELit i, t) -> pure $ case i of
|
||||
(MIR.LInt i) -> VInteger i
|
||||
(MIR.LChar i) -> VChar i
|
||||
(MIR.EId name, t) -> do
|
||||
(MIR.EVar name, t) -> do
|
||||
funcs <- gets functions
|
||||
case Map.lookup (name, t) funcs of
|
||||
Just fi -> do
|
||||
|
|
@ -515,7 +558,7 @@ exprToValue = \case
|
|||
e -> do
|
||||
compileExp e
|
||||
v <- getVarCount
|
||||
pure $ VIdent (GA.Ident $ show v) (getType e)
|
||||
pure $ VIdent (Ident $ show v) (getType e)
|
||||
|
||||
type2LlvmType :: MIR.Type -> LLVMType
|
||||
type2LlvmType (MIR.TLit id@(Ident name)) = case name of
|
||||
|
|
@ -558,3 +601,4 @@ typeByteSize (CustomType _) = 8
|
|||
|
||||
enumerateOneM_ :: Monad m => (Integer -> a -> m b) -> [a] -> m ()
|
||||
enumerateOneM_ f = foldM_ (\i a -> f i a >> pure (i + 1)) 1
|
||||
|
||||
|
|
|
|||
|
|
@ -11,8 +11,9 @@ module Codegen.LlvmIr (
|
|||
ToIr(..)
|
||||
) where
|
||||
|
||||
import Data.List (intercalate)
|
||||
import Grammar.Abs (Ident (..))
|
||||
import Data.List (intercalate)
|
||||
import Grammar.Abs (Character)
|
||||
import TypeChecker.TypeCheckerIr (Ident (..))
|
||||
|
||||
data CallingConvention = TailCC | FastCC | CCC | ColdCC deriving Show
|
||||
instance ToIr CallingConvention where
|
||||
|
|
@ -87,7 +88,7 @@ instance ToIr Visibility where
|
|||
-- or a string contstant
|
||||
data LLVMValue
|
||||
= VInteger Integer
|
||||
| VChar Char
|
||||
| VChar Character
|
||||
| VIdent Ident LLVMType
|
||||
| VConstant String
|
||||
| VFunction Ident Visibility LLVMType
|
||||
|
|
|
|||
242
src/LambdaLifter.hs
Normal file
242
src/LambdaLifter.hs
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
|
||||
module LambdaLifter (lambdaLift, freeVars, abstract, collectScs) where
|
||||
|
||||
import Auxiliary (snoc)
|
||||
import Control.Applicative (Applicative (liftA2))
|
||||
import Control.Monad.State (MonadState (get, put), State,
|
||||
evalState)
|
||||
import Data.List (partition)
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as Set
|
||||
import Prelude hiding (exp)
|
||||
import TypeChecker.TypeCheckerIr
|
||||
|
||||
|
||||
-- | Lift lambdas and let expression into supercombinators.
|
||||
-- Three phases:
|
||||
-- @freeVars@ annotates all the free variables.
|
||||
-- @abstract@ converts lambdas into let expressions.
|
||||
-- @collectScs@ moves every non-constant let expression to a top-level function.
|
||||
lambdaLift :: Program -> Program
|
||||
lambdaLift (Program defs) = Program $ datatypes ++ ll binds
|
||||
where
|
||||
ll = map DBind . collectScs . abstract . freeVars . map (\(DBind b) -> b)
|
||||
(binds, datatypes) = partition isBind defs
|
||||
isBind = \case
|
||||
DBind _ -> True
|
||||
_ -> False
|
||||
|
||||
-- | Annotate free variables
|
||||
freeVars :: [Bind] -> AnnBinds
|
||||
freeVars binds = [ (n, xs, freeVarsExp (Set.fromList $ map fst xs) e)
|
||||
| Bind n xs e <- binds
|
||||
]
|
||||
|
||||
freeVarsExp :: Set Ident -> ExpT -> AnnExpT
|
||||
freeVarsExp localVars (exp, t) = case exp of
|
||||
EVar n | Set.member n localVars -> (Set.singleton n, (AVar n, t))
|
||||
| otherwise -> (mempty, (AVar n, t))
|
||||
|
||||
ELit lit -> (mempty, (ALit lit, t))
|
||||
|
||||
EApp e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), (AApp e1' e2', t))
|
||||
where
|
||||
e1' = freeVarsExp localVars e1
|
||||
e2' = freeVarsExp localVars e2
|
||||
|
||||
EAdd e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), (AAdd e1' e2', t))
|
||||
where
|
||||
e1' = freeVarsExp localVars e1
|
||||
e2' = freeVarsExp localVars e2
|
||||
|
||||
EAbs par e -> (Set.delete par $ freeVarsOf e', (AAbs par e', t))
|
||||
where
|
||||
e' = freeVarsExp (Set.insert par localVars) e
|
||||
|
||||
-- Sum free variables present in bind and the expression
|
||||
ELet (Bind (name, t_bind) parms rhs) e -> (Set.union binders_frees e_free, (ALet new_bind e', t))
|
||||
where
|
||||
binders_frees = Set.delete name $ freeVarsOf rhs'
|
||||
e_free = Set.delete name $ freeVarsOf e'
|
||||
|
||||
rhs' = freeVarsExp e_localVars rhs
|
||||
new_bind = ABind (name, t_bind) parms rhs'
|
||||
|
||||
e' = freeVarsExp e_localVars e
|
||||
e_localVars = Set.insert name localVars
|
||||
|
||||
|
||||
freeVarsOf :: AnnExpT -> Set Ident
|
||||
freeVarsOf = fst
|
||||
|
||||
-- AST annotated with free variables
|
||||
type AnnBinds = [(Id, [Id], AnnExpT)]
|
||||
|
||||
type AnnExpT = (Set Ident, AnnExpT')
|
||||
|
||||
data ABind = ABind Id [Id] AnnExpT deriving Show
|
||||
|
||||
type AnnExpT' = (AnnExp, Type)
|
||||
|
||||
data AnnExp = AVar Ident
|
||||
| AInj Ident
|
||||
| ALit Lit
|
||||
| ALet ABind AnnExpT
|
||||
| AApp AnnExpT AnnExpT
|
||||
| AAdd AnnExpT AnnExpT
|
||||
| AAbs Ident AnnExpT
|
||||
deriving Show
|
||||
|
||||
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
|
||||
-- Free variables are @v₁ v₂ .. vₙ@ are bound.
|
||||
abstract :: AnnBinds -> [Bind]
|
||||
abstract prog = evalState (mapM go prog) 0
|
||||
where
|
||||
go :: (Id, [Id], AnnExpT) -> State Int Bind
|
||||
go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs'
|
||||
where
|
||||
(rhs', parms1) = flattenLambdasAnn rhs
|
||||
|
||||
|
||||
-- | Flatten nested lambdas and collect the parameters
|
||||
-- @\x.\y.\z. ae → (ae, [x,y,z])@
|
||||
flattenLambdasAnn :: AnnExpT -> (AnnExpT, [Id])
|
||||
flattenLambdasAnn ae = go (ae, [])
|
||||
where
|
||||
go :: (AnnExpT, [Id]) -> (AnnExpT, [Id])
|
||||
go ((free, (e, t)), acc)
|
||||
| AAbs par (free1, e1) <- e
|
||||
, TFun t_par _ <- t
|
||||
= go ((Set.delete par free1, e1), snoc (par, t_par) acc)
|
||||
| otherwise = ((free, (e, t)), acc)
|
||||
|
||||
abstractExp :: AnnExpT -> State Int ExpT
|
||||
abstractExp (free, (exp, typ)) = case exp of
|
||||
AVar n -> pure (EVar n, typ)
|
||||
ALit lit -> pure (ELit lit, typ)
|
||||
AApp e1 e2 -> (, typ) <$> liftA2 EApp (abstractExp e1) (abstractExp e2)
|
||||
AAdd e1 e2 -> (, typ) <$> liftA2 EAdd (abstractExp e1) (abstractExp e2)
|
||||
ALet b e -> (, typ) <$> liftA2 ELet (go b) (abstractExp e)
|
||||
where
|
||||
go (ABind name parms rhs) = do
|
||||
(rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
|
||||
pure $ Bind name (parms ++ parms1) rhs'
|
||||
|
||||
skipLambdas :: (AnnExpT -> State Int ExpT) -> AnnExpT -> State Int ExpT
|
||||
skipLambdas f (free, (ae, t)) = case ae of
|
||||
AAbs par ae1 -> do
|
||||
ae1' <- skipLambdas f ae1
|
||||
pure (EAbs par ae1', t)
|
||||
_ -> f (free, (ae, t))
|
||||
|
||||
-- Lift lambda into let and bind free variables
|
||||
AAbs parm e -> do
|
||||
i <- nextNumber
|
||||
rhs <- abstractExp e
|
||||
|
||||
let sc_name = Ident ("sc_" ++ show i)
|
||||
sc = (ELet (Bind (sc_name, typ) vars rhs) (EVar sc_name, typ), typ)
|
||||
pure $ foldl applyVars sc freeList
|
||||
|
||||
where
|
||||
freeList = Set.toList free
|
||||
vars = zip names $ getVars typ
|
||||
names = snoc parm freeList
|
||||
applyVars (e, t) name = (EApp (e, t) (EVar name, t_var), t_return)
|
||||
where
|
||||
(t_var, t_return) = applyVarType t
|
||||
|
||||
applyVarType :: Type -> (Type, Type)
|
||||
applyVarType typ = (t1, foldr ($) t2 foralls)
|
||||
|
||||
where
|
||||
(t1, t2) = case typ' of
|
||||
TFun t1 t2 -> (t1, t2)
|
||||
_ -> error "Not a function!"
|
||||
|
||||
(foralls, typ') = skipForalls [] typ
|
||||
|
||||
|
||||
skipForalls acc = \case
|
||||
TAll tvar t -> skipForalls (snoc (TAll tvar) acc) t
|
||||
t -> (acc, t)
|
||||
|
||||
nextNumber :: State Int Int
|
||||
nextNumber = do
|
||||
i <- get
|
||||
put $ succ i
|
||||
pure i
|
||||
|
||||
-- | Collects supercombinators by lifting non-constant let expressions
|
||||
collectScs :: [Bind] -> [Bind]
|
||||
collectScs = concatMap collectFromRhs
|
||||
where
|
||||
collectFromRhs (Bind name parms rhs) =
|
||||
let (rhs_scs, rhs') = collectScsExp rhs
|
||||
in Bind name parms rhs' : rhs_scs
|
||||
|
||||
|
||||
collectScsExp :: ExpT -> ([Bind], ExpT)
|
||||
collectScsExp expT@(exp, typ) = case exp of
|
||||
EVar _ -> ([], expT)
|
||||
ELit _ -> ([], expT)
|
||||
|
||||
EApp e1 e2 -> (scs1 ++ scs2, (EApp e1' e2', typ))
|
||||
where
|
||||
(scs1, e1') = collectScsExp e1
|
||||
(scs2, e2') = collectScsExp e2
|
||||
|
||||
EAdd e1 e2 -> (scs1 ++ scs2, (EAdd e1' e2', typ))
|
||||
where
|
||||
(scs1, e1') = collectScsExp e1
|
||||
(scs2, e2') = collectScsExp e2
|
||||
|
||||
EAbs par e -> (scs, (EAbs par e', typ))
|
||||
where
|
||||
(scs, e') = collectScsExp e
|
||||
|
||||
-- Collect supercombinators from bind, the rhss, and the expression.
|
||||
--
|
||||
-- > f = let sc x y = rhs in e
|
||||
--
|
||||
ELet (Bind name parms rhs) e -> if null parms
|
||||
then ( rhs_scs ++ et_scs, (ELet bind et', snd et'))
|
||||
else (bind : rhs_scs ++ et_scs, et')
|
||||
where
|
||||
bind = Bind name parms rhs'
|
||||
(rhs_scs, rhs') = collectScsExp rhs
|
||||
(et_scs, et') = collectScsExp e
|
||||
|
||||
|
||||
-- @\x.\y.\z. e → (e, [x,y,z])@
|
||||
flattenLambdas :: ExpT -> (ExpT, [Id])
|
||||
flattenLambdas = go . (, [])
|
||||
where
|
||||
go ((e, t), acc) = case e of
|
||||
EAbs name e1 -> go (e1, snoc (name, t_var) acc)
|
||||
where t_var = head $ getVars t
|
||||
_ -> ((e, t), acc)
|
||||
|
||||
getVars :: Type -> [Type]
|
||||
getVars = fst . partitionType
|
||||
|
||||
partitionType :: Type -> ([Type], Type)
|
||||
partitionType = go [] . skipForalls'
|
||||
where
|
||||
|
||||
go acc t = case t of
|
||||
TFun t1 t2 -> go (snoc t1 acc) t2
|
||||
_ -> (acc, t)
|
||||
|
||||
skipForalls' :: Type -> Type
|
||||
skipForalls' = snd . skipForalls
|
||||
|
||||
skipForalls :: Type -> ([Type -> Type], Type)
|
||||
skipForalls = go []
|
||||
where
|
||||
go acc typ = case typ of
|
||||
TAll tvar t -> go (snoc (TAll tvar) acc) t
|
||||
_ -> (acc, typ)
|
||||
118
src/Main.hs
118
src/Main.hs
|
|
@ -1,66 +1,114 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedRecordDot #-}
|
||||
|
||||
module Main where
|
||||
|
||||
import Codegen.Codegen (generateCode)
|
||||
import Data.Bool (bool)
|
||||
import GHC.IO.Handle.Text (hPutStrLn)
|
||||
import Grammar.ErrM (Err)
|
||||
import Grammar.Par (myLexer, pProgram)
|
||||
import Grammar.Print (printTree)
|
||||
|
||||
import Monomorphizer.Monomorphizer (monomorphize)
|
||||
|
||||
import Control.Monad (when)
|
||||
import Data.Bool (bool)
|
||||
import Data.List.Extra (isSuffixOf)
|
||||
|
||||
import Compiler (compile)
|
||||
import Renamer.Renamer (rename)
|
||||
import Data.Maybe (fromJust, isNothing)
|
||||
import GHC.IO.Handle.Text (hPutStrLn)
|
||||
import System.Console.GetOpt (ArgDescr (NoArg, ReqArg),
|
||||
ArgOrder (RequireOrder),
|
||||
OptDescr (Option), getOpt,
|
||||
usageInfo)
|
||||
import System.Directory (createDirectory, doesPathExist,
|
||||
getDirectoryContents,
|
||||
removeDirectoryRecursive,
|
||||
setCurrentDirectory)
|
||||
import System.Environment (getArgs)
|
||||
import System.Exit (ExitCode, exitFailure,
|
||||
exitSuccess)
|
||||
import System.Exit (ExitCode (ExitFailure),
|
||||
exitFailure, exitSuccess,
|
||||
exitWith)
|
||||
import System.IO (stderr)
|
||||
import System.Process.Extra (readCreateProcess, shell,
|
||||
spawnCommand, waitForProcess)
|
||||
import TypeChecker.TypeChecker (typecheck)
|
||||
|
||||
|
||||
import Codegen.Codegen (generateCode)
|
||||
import Compiler (compile)
|
||||
import Grammar.ErrM (Err)
|
||||
import Grammar.Par (myLexer, pProgram)
|
||||
import Grammar.Print (printTree)
|
||||
import LambdaLifter (lambdaLift)
|
||||
import Monomorphizer.Monomorphizer (monomorphize)
|
||||
import Renamer.Renamer (rename)
|
||||
import System.Process (spawnCommand, waitForProcess)
|
||||
import TypeChecker.TypeChecker (TypeChecker (Bi, Hm), typecheck)
|
||||
|
||||
main :: IO ()
|
||||
main =
|
||||
getArgs >>= \case
|
||||
[] -> putStrLn "Required file path missing"
|
||||
["-d", s] -> do
|
||||
when (".crf" `isSuffixOf` s) (main' True s)
|
||||
putStrLn $ "File '" ++ s ++ "' is not a churf file"
|
||||
[s] -> do
|
||||
when (".crf" `isSuffixOf` s) (main' False s)
|
||||
putStrLn $ "File '" ++ s ++ "' is not a churf file"
|
||||
xs -> putStrLn $ "Can't process: " ++ unwords xs
|
||||
main = getArgs >>= parseArgs >>= uncurry main'
|
||||
|
||||
main' :: Bool -> String -> IO ()
|
||||
main' debug s = do
|
||||
parseArgs :: [String] -> IO (Options, String)
|
||||
parseArgs argv = case getOpt RequireOrder flags argv of
|
||||
(os, f:_, [])
|
||||
| opts.help || isNothing opts.typechecker -> do
|
||||
hPutStrLn stderr (usageInfo header flags)
|
||||
exitSuccess
|
||||
| otherwise -> pure (opts, f)
|
||||
where
|
||||
opts = foldr ($) initOpts os
|
||||
(_, _, errs) -> do
|
||||
hPutStrLn stderr (concat errs ++ usageInfo header flags)
|
||||
exitWith (ExitFailure 1)
|
||||
where
|
||||
header = "Usage: language [--help] [-d|--debug] [-t|type-checker bi/hm] FILE \n"
|
||||
|
||||
flags :: [OptDescr (Options -> Options)]
|
||||
flags =
|
||||
[ Option ['d'] ["debug"] (NoArg enableDebug) "Print debug messages."
|
||||
, Option ['t'] ["type-checker"] (ReqArg chooseTypechecker "bi/hm") "Choose type checker. Possible options are bi and hm"
|
||||
, Option [] ["help"] (NoArg enableHelp) "Print this help message"
|
||||
]
|
||||
|
||||
initOpts :: Options
|
||||
initOpts = Options { help = False
|
||||
, debug = False
|
||||
, typechecker = Nothing
|
||||
}
|
||||
|
||||
enableHelp :: Options -> Options
|
||||
enableHelp opts = opts { help = True }
|
||||
|
||||
enableDebug :: Options -> Options
|
||||
enableDebug opts = opts { debug = True }
|
||||
|
||||
chooseTypechecker :: String -> Options -> Options
|
||||
chooseTypechecker s options = options { typechecker = tc }
|
||||
where
|
||||
tc = case s of
|
||||
"hm" -> pure Hm
|
||||
"bi" -> pure Bi
|
||||
_ -> Nothing
|
||||
|
||||
data Options = Options
|
||||
{ help :: Bool
|
||||
, debug :: Bool
|
||||
, typechecker :: Maybe TypeChecker
|
||||
}
|
||||
|
||||
main' :: Options -> String -> IO ()
|
||||
main' opts s = do
|
||||
file <- readFile s
|
||||
|
||||
printToErr "-- Parse Tree -- "
|
||||
parsed <- fromSyntaxErr . pProgram $ myLexer file
|
||||
bool (printToErr $ printTree parsed) (printToErr $ show parsed) debug
|
||||
bool (printToErr $ printTree parsed) (printToErr $ show parsed) opts.debug
|
||||
|
||||
printToErr "\n-- Renamer --"
|
||||
renamed <- fromRenamerErr . rename $ parsed
|
||||
bool (printToErr $ printTree renamed) (printToErr $ show renamed) debug
|
||||
bool (printToErr $ printTree renamed) (printToErr $ show renamed) opts.debug
|
||||
|
||||
printToErr "\n-- TypeChecker --"
|
||||
typechecked <- fromTypeCheckerErr $ typecheck renamed
|
||||
bool (printToErr $ printTree typechecked) (printToErr $ show typechecked) debug
|
||||
typechecked <- fromTypeCheckerErr $ typecheck (fromJust opts.typechecker) renamed
|
||||
bool (printToErr $ printTree typechecked) (printToErr $ show typechecked) opts.debug
|
||||
|
||||
printToErr "\n-- Lambda Lifter --"
|
||||
let lifted = lambdaLift typechecked
|
||||
printToErr $ printTree lifted
|
||||
|
||||
-- printToErr "\n-- Lambda Lifter --"
|
||||
-- let lifted = lambdaLift typechecked
|
||||
-- printToErr $ printTree lifted
|
||||
--
|
||||
--printToErr "\n -- Compiler --"
|
||||
printToErr "\n -- Compiler --"
|
||||
generatedCode <- fromCompilerErr $ generateCode (monomorphize typechecked)
|
||||
--putStrLn generatedCode
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Monomorphizer.Monomorphizer (monomorphize) where
|
||||
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Coerce (coerce)
|
||||
|
||||
import Monomorphizer.MonomorphizerIr qualified as M
|
||||
import TypeChecker.TypeCheckerIr qualified as T
|
||||
import qualified Monomorphizer.MonomorphizerIr as M
|
||||
import qualified TypeChecker.TypeCheckerIr as T
|
||||
import TypeChecker.TypeCheckerIr (Ident (..))
|
||||
|
||||
monomorphize :: T.Program -> M.Program
|
||||
monomorphize (T.Program ds) = M.Program $ monoDefs ds
|
||||
|
|
@ -16,40 +17,40 @@ monoDefs = map monoDef
|
|||
|
||||
monoDef :: T.Def -> M.Def
|
||||
monoDef (T.DBind bind) = M.DBind $ monoBind bind
|
||||
monoDef (T.DData d) = M.DData $ monoData d
|
||||
--monoDef (T.DData d) = M.DData $ monoData d
|
||||
|
||||
monoBind :: T.Bind -> M.Bind
|
||||
monoBind (T.Bind name args (e, t)) = M.Bind (monoId name) (map monoId args) (monoExpr e, monoType t)
|
||||
|
||||
monoData :: T.Data -> M.Data
|
||||
monoData (T.Data (T.Ident id) cs) = M.Data (M.TLit (M.Ident id)) (map monoConstructor cs)
|
||||
--monoData :: T.Data -> M.Data
|
||||
--monoData (T.Data (Ident id) cs) = M.Data (M.TLit (M.Ident id)) (map monoConstructor cs)
|
||||
|
||||
monoConstructor :: T.Constructor -> M.Constructor
|
||||
monoConstructor (T.Constructor (T.Ident i) t) = M.Constructor (M.Ident i) (monoType t)
|
||||
monoConstructor :: T.Inj -> M.Inj
|
||||
monoConstructor (T.Inj (Ident i) t) = M.Inj (M.Ident i) (monoType t)
|
||||
|
||||
monoExpr :: T.Exp -> M.Exp
|
||||
monoExpr = \case
|
||||
T.EId (T.Ident i) -> M.EId (M.Ident i)
|
||||
T.ELit lit -> M.ELit $ monoLit lit
|
||||
T.ELet bind expt -> M.ELet (monoBind bind) (monoexpt expt)
|
||||
T.EVar (Ident i) -> M.EVar (M.Ident i)
|
||||
T.ELit lit -> M.ELit $ monoLit lit
|
||||
T.ELet bind expt -> M.ELet (monoBind bind) (monoexpt expt)
|
||||
T.EApp expt1 expt2 -> M.EApp (monoexpt expt1) (monoexpt expt2)
|
||||
T.EAdd expt1 expt2 -> M.EAdd (monoexpt expt1) (monoexpt expt2)
|
||||
T.EAbs _i _expt -> error "BUG"
|
||||
T.ECase expt injs -> M.ECase (monoexpt expt) (monoInjs injs)
|
||||
T.EAbs _i _expt -> error "BUG"
|
||||
T.ECase expt injs -> M.ECase (monoexpt expt) (monoInjs injs)
|
||||
|
||||
monoAbsType :: T.Type -> M.Type
|
||||
monoAbsType (T.TLit u) = M.TLit (coerce u)
|
||||
monoAbsType (T.TVar _v) = M.TLit "Int"
|
||||
monoAbsType (T.TLit u) = M.TLit (coerce u)
|
||||
monoAbsType (T.TVar _v) = M.TLit "Int"
|
||||
monoAbsType (T.TAll _v _t) = error "NOT ALL TYPES"
|
||||
monoAbsType (T.TFun t1 t2) = M.TFun (monoAbsType t1) (monoAbsType t2)
|
||||
monoAbsType (T.TData _ _) = error "NOT INDEXED TYPES"
|
||||
monoAbsType (T.TData _ _) = error "NOT INDEXED TYPES"
|
||||
|
||||
monoType :: T.Type -> M.Type
|
||||
monoType (T.TAll _ t) = monoType t
|
||||
monoType (T.TAll _ t) = monoType t
|
||||
monoType (T.TVar (T.MkTVar i)) = M.TLit "Int"
|
||||
monoType (T.TLit (T.Ident i)) = M.TLit (M.Ident i)
|
||||
monoType (T.TFun t1 t2) = M.TFun (monoType t1) (monoType t2)
|
||||
monoType (T.TData (T.Ident n) t) = M.TLit (M.Ident (n ++ concatMap show t))
|
||||
monoType (T.TLit (Ident i)) = M.TLit (M.Ident i)
|
||||
monoType (T.TFun t1 t2) = M.TFun (monoType t1) (monoType t2)
|
||||
monoType (T.TData (Ident n) t) = M.TLit (M.Ident (n ++ concatMap show t))
|
||||
|
||||
monoexpt :: T.ExpT -> M.ExpT
|
||||
monoexpt (e, t) = (monoExpr e, monoType t)
|
||||
|
|
@ -58,19 +59,19 @@ monoId :: T.Id -> M.Id
|
|||
monoId (n, t) = (coerce n, monoType t)
|
||||
|
||||
monoLit :: T.Lit -> M.Lit
|
||||
monoLit (T.LInt i) = M.LInt i
|
||||
monoLit (T.LInt i) = M.LInt i
|
||||
monoLit (T.LChar c) = M.LChar c
|
||||
|
||||
monoInjs :: [T.Branch] -> [M.Branch]
|
||||
monoInjs = map monoInj
|
||||
|
||||
monoInj :: T.Branch -> M.Branch
|
||||
monoInj (T.Branch (init, t) expt) = M.Branch (monoInit init, monoType t) (monoexpt expt)
|
||||
monoInj (T.Branch (patt, t) expt) = M.Branch (monoPattern patt, monoType t) (monoexpt expt)
|
||||
|
||||
monoInit :: T.Pattern -> M.Pattern
|
||||
monoInit (T.PVar (id, t)) = M.PVar (coerce id, monoType t)
|
||||
monoInit (T.PLit (lit, t)) = M.PLit (monoLit lit, monoType t)
|
||||
monoInit (T.PInj id ps) = M.PInj (coerce id) (monoInit <$> ps)
|
||||
monoPattern :: T.Pattern -> M.Pattern
|
||||
monoPattern (T.PVar (id, t)) = M.PVar (id, monoType t)
|
||||
monoPattern (T.PLit (lit, t)) = M.PLit (monoLit lit, monoType t)
|
||||
monoPattern (T.PInj id ps) = M.PInj (coerce id) (map monoPattern ps)
|
||||
-- DO NOT DO THIS FOR REAL THOUGH
|
||||
monoInit (T.PEnum (T.Ident i)) = M.PInj (M.Ident i) []
|
||||
monoInit T.PCatch = M.PCatch
|
||||
monoPattern (T.PEnum (Ident i)) = M.PInj (M.Ident i) []
|
||||
monoPattern T.PCatch = M.PCatch
|
||||
|
|
|
|||
|
|
@ -11,14 +11,14 @@ newtype Program = Program [Def]
|
|||
data Def = DBind Bind | DData Data
|
||||
deriving (Show, Ord, Eq)
|
||||
|
||||
data Data = Data Type [Constructor]
|
||||
data Data = Data Type [Inj]
|
||||
deriving (Show, Ord, Eq)
|
||||
|
||||
data Bind = Bind Id [Id] ExpT
|
||||
deriving (Show, Ord, Eq)
|
||||
|
||||
data Exp
|
||||
= EId Ident
|
||||
= EVar Ident
|
||||
| ELit Lit
|
||||
| ELet Bind ExpT
|
||||
| EApp ExpT ExpT
|
||||
|
|
@ -35,12 +35,12 @@ data Branch = Branch (Pattern, Type) ExpT
|
|||
|
||||
type ExpT = (Exp, Type)
|
||||
|
||||
data Constructor = Constructor Ident Type
|
||||
data Inj = Inj Ident Type
|
||||
deriving (Show, Ord, Eq)
|
||||
|
||||
data Lit
|
||||
= LInt Integer
|
||||
| LChar Char
|
||||
| LChar Character
|
||||
deriving (Show, Ord, Eq)
|
||||
|
||||
data Type = TLit Ident | TFun Type Type
|
||||
|
|
|
|||
|
|
@ -1,131 +1,124 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedRecordDot #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
|
||||
|
||||
{-# HLINT ignore "Use mapAndUnzipM" #-}
|
||||
|
||||
module Renamer.Renamer (rename) where
|
||||
|
||||
import Auxiliary (mapAccumM)
|
||||
import Control.Applicative (Applicative (liftA2))
|
||||
import Control.Monad (foldM)
|
||||
import Control.Monad.Except (ExceptT, MonadError, runExceptT, throwError)
|
||||
import Control.Monad.Identity (Identity, runIdentity)
|
||||
import Control.Monad.State (
|
||||
MonadState,
|
||||
StateT,
|
||||
evalStateT,
|
||||
gets,
|
||||
modify,
|
||||
)
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Function (on)
|
||||
import Data.Map (Map)
|
||||
import Data.Map qualified as Map
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.Tuple.Extra (dupe)
|
||||
import Grammar.Abs
|
||||
import Auxiliary (mapAccumM)
|
||||
import Control.Applicative (Applicative (liftA2))
|
||||
import Control.Monad.Except (ExceptT, MonadError (throwError),
|
||||
runExceptT)
|
||||
import Control.Monad.State (MonadState, State, evalState, gets,
|
||||
mapAndUnzipM, modify)
|
||||
import Data.Function (on)
|
||||
import Data.Map (Map)
|
||||
import qualified Data.Map as Map
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.Tuple.Extra (dupe, second)
|
||||
import Grammar.Abs
|
||||
import Grammar.ErrM (Err)
|
||||
|
||||
|
||||
-- | Rename all variables and local binds
|
||||
rename :: Program -> Either String Program
|
||||
rename :: Program -> Err Program
|
||||
rename (Program defs) = Program <$> renameDefs defs
|
||||
|
||||
renameDefs :: [Def] -> Either String [Def]
|
||||
renameDefs defs = runIdentity $ runExceptT $ evalStateT (runRn $ mapM renameDef defs) initCxt
|
||||
where
|
||||
initNames = Map.fromList [dupe (coerce name) | DBind (Bind name _ _) <- defs]
|
||||
|
||||
renameDef :: Def -> Rn Def
|
||||
renameDef = \case
|
||||
DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ
|
||||
DBind bind -> DBind . snd <$> renameBind initNames bind
|
||||
DData (Data (TData cname types) constrs) -> do
|
||||
tvars_ <- tvars
|
||||
tvars' <- mapM nextNameTVar tvars_
|
||||
let tvars_lt = zip tvars_ tvars'
|
||||
typ' = map (substituteTVar tvars_lt) types
|
||||
constrs' = map (renameConstr tvars_lt) constrs
|
||||
pure . DData $ Data (TData cname typ') constrs'
|
||||
where
|
||||
tvars = concat <$> mapM (collectTVars []) types
|
||||
collectTVars :: [TVar] -> Type -> Rn [TVar]
|
||||
collectTVars tvars = \case
|
||||
TAll tvar t -> collectTVars (tvar : tvars) t
|
||||
TData _ _ -> return tvars
|
||||
-- Should be monad error
|
||||
TVar v -> return [v]
|
||||
_ -> throwError ("Bad data type definition: " ++ show types)
|
||||
DData (Data types _) -> throwError ("Bad data type definition: " ++ show types)
|
||||
|
||||
renameConstr :: [(TVar, TVar)] -> Constructor -> Constructor
|
||||
renameConstr new_types (Constructor name typ) =
|
||||
Constructor name $ substituteTVar new_types typ
|
||||
|
||||
renameBind :: Names -> Bind -> Rn (Names, Bind)
|
||||
renameBind old_names (Bind name vars rhs) = do
|
||||
(new_names, vars') <- newNames old_names (coerce vars)
|
||||
(newer_names, rhs') <- renameExp new_names rhs
|
||||
pure (newer_names, Bind name (coerce vars') rhs')
|
||||
|
||||
substituteTVar :: [(TVar, TVar)] -> Type -> Type
|
||||
substituteTVar new_names typ = case typ of
|
||||
TLit _ -> typ
|
||||
TVar tvar
|
||||
| Just tvar' <- lookup tvar new_names ->
|
||||
TVar tvar'
|
||||
| otherwise ->
|
||||
typ
|
||||
TFun t1 t2 -> on TFun substitute' t1 t2
|
||||
TAll tvar t
|
||||
| Just tvar' <- lookup tvar new_names ->
|
||||
TAll tvar' $ substitute' t
|
||||
| otherwise ->
|
||||
TAll tvar $ substitute' t
|
||||
TData name typs -> TData name $ map substitute' typs
|
||||
_ -> error ("Impossible " ++ show typ)
|
||||
where
|
||||
substitute' = substituteTVar new_names
|
||||
|
||||
initCxt :: Cxt
|
||||
initCxt = Cxt 0 0
|
||||
|
||||
data Cxt = Cxt
|
||||
{ var_counter :: Int
|
||||
, tvar_counter :: Int
|
||||
}
|
||||
data Cxt = Cxt { var_counter :: Int
|
||||
, tvar_counter :: Int
|
||||
}
|
||||
|
||||
|
||||
-- | Rename monad. State holds the number of renamed names.
|
||||
newtype Rn a = Rn {runRn :: StateT Cxt (ExceptT String Identity) a}
|
||||
deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String)
|
||||
newtype Rn a = Rn { runRn :: ExceptT String (State Cxt) a }
|
||||
deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String)
|
||||
|
||||
-- | Maps old to new name
|
||||
type Names = Map LIdent LIdent
|
||||
type Names = Map String String
|
||||
|
||||
renameDefs :: [Def] -> Err [Def]
|
||||
renameDefs defs = evalState (runExceptT (runRn $ mapM renameDef defs)) initCxt
|
||||
where
|
||||
initNames = Map.fromList [ dupe s | DBind (Bind (LIdent s) _ _) <- defs]
|
||||
|
||||
renameDef :: Def -> Rn Def
|
||||
renameDef = \case
|
||||
DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ
|
||||
DBind (Bind name vars rhs) -> do
|
||||
(new_names, vars') <- newNamesL initNames vars
|
||||
rhs' <- snd <$> renameExp new_names rhs
|
||||
pure . DBind $ Bind name vars' rhs'
|
||||
DData (Data typ injs) -> do
|
||||
tvars <- collectTVars [] typ
|
||||
tvars' <- mapM nextNameTVar tvars
|
||||
let tvars_lt = zip tvars tvars'
|
||||
typ' = substituteTVar tvars_lt typ
|
||||
injs' = map (renameInj tvars_lt) injs
|
||||
pure . DData $ Data typ' injs'
|
||||
where
|
||||
collectTVars tvars = \case
|
||||
TAll tvar t -> collectTVars (tvar:tvars) t
|
||||
TData _ _ -> pure tvars
|
||||
_ -> throwError ("Bad data type definition: " ++ show typ)
|
||||
|
||||
renameInj :: [(TVar, TVar)] -> Inj -> Inj
|
||||
renameInj new_types (Inj name typ) =
|
||||
Inj name $ substituteTVar new_types typ
|
||||
|
||||
|
||||
substituteTVar :: [(TVar, TVar)] -> Type -> Type
|
||||
substituteTVar new_names typ = case typ of
|
||||
TLit _ -> typ
|
||||
|
||||
TVar tvar | Just tvar' <- lookup tvar new_names
|
||||
-> TVar tvar'
|
||||
| otherwise
|
||||
-> typ
|
||||
|
||||
TFun t1 t2 -> on TFun substitute' t1 t2
|
||||
|
||||
TAll tvar t | Just tvar' <- lookup tvar new_names
|
||||
-> TAll tvar' $ substitute' t
|
||||
| otherwise
|
||||
-> TAll tvar $ substitute' t
|
||||
|
||||
TData name typs -> TData name $ map substitute' typs
|
||||
_ -> error ("Impossible " ++ show typ)
|
||||
where
|
||||
substitute' = substituteTVar new_names
|
||||
|
||||
|
||||
renameExp :: Names -> Exp -> Rn (Names, Exp)
|
||||
renameExp old_names = \case
|
||||
EVar n -> pure (coerce old_names, EVar . fromMaybe n $ Map.lookup n old_names)
|
||||
EInj n -> pure (old_names, EInj n)
|
||||
ELit lit -> pure (old_names, ELit lit)
|
||||
EVar (LIdent n) -> pure (old_names, EVar . LIdent . fromMaybe n $ Map.lookup n old_names)
|
||||
EInj (UIdent n) -> pure (old_names, EInj . UIdent . fromMaybe n $ Map.lookup n old_names)
|
||||
|
||||
ELit lit -> pure (old_names, ELit lit)
|
||||
|
||||
EApp e1 e2 -> do
|
||||
(env1, e1') <- renameExp old_names e1
|
||||
(env2, e2') <- renameExp old_names e2
|
||||
pure (Map.union env1 env2, EApp e1' e2')
|
||||
|
||||
EAdd e1 e2 -> do
|
||||
(env1, e1') <- renameExp old_names e1
|
||||
(env2, e2') <- renameExp old_names e2
|
||||
pure (Map.union env1 env2, EAdd e1' e2')
|
||||
|
||||
-- TODO fix shadowing
|
||||
ELet bind e -> do
|
||||
(new_names, bind') <- renameBind old_names bind
|
||||
(new_names', e') <- renameExp new_names e
|
||||
pure (new_names', ELet bind' e')
|
||||
EAbs par e -> do
|
||||
(new_names, par') <- newName old_names (coerce par)
|
||||
(new_names', e') <- renameExp new_names e
|
||||
pure (new_names', EAbs (coerce par') e')
|
||||
ELet (Bind name vars rhs) e -> do
|
||||
(new_names, name') <- newNameL old_names name
|
||||
(new_names', vars') <- newNamesL new_names vars
|
||||
(new_names'', rhs') <- renameExp new_names' rhs
|
||||
(new_names''', e') <- renameExp new_names'' e
|
||||
pure (new_names''', ELet (Bind name' vars' rhs') e')
|
||||
|
||||
EAbs par e -> do
|
||||
(new_names, par') <- newNameL old_names par
|
||||
(new_names', e') <- renameExp new_names e
|
||||
pure (new_names', EAbs par' e')
|
||||
|
||||
EAnn e t -> do
|
||||
(new_names, e') <- renameExp old_names e
|
||||
t' <- renameTVars t
|
||||
|
|
@ -137,26 +130,23 @@ renameExp old_names = \case
|
|||
|
||||
renameBranches :: Names -> [Branch] -> Rn (Names, [Branch])
|
||||
renameBranches ns xs = do
|
||||
(new_names, xs') <- unzip <$> mapM (renameBranch ns) xs
|
||||
(new_names, xs') <- mapAndUnzipM (renameBranch ns) xs
|
||||
if null new_names then return (mempty, xs') else return (head new_names, xs')
|
||||
|
||||
renameBranch :: Names -> Branch -> Rn (Names, Branch)
|
||||
renameBranch ns (Branch init e) = do
|
||||
(new_names, init') <- renamePattern ns init
|
||||
renameBranch ns (Branch patt e) = do
|
||||
(new_names, patt') <- renamePattern ns patt
|
||||
(new_names', e') <- renameExp new_names e
|
||||
return (new_names', Branch init' e')
|
||||
return (new_names', Branch patt' e')
|
||||
|
||||
renamePattern :: Names -> Pattern -> Rn (Names, Pattern)
|
||||
renamePattern ns i = case i of
|
||||
renamePattern ns p = case p of
|
||||
PInj cs ps -> do
|
||||
(ns_new, ps) <- renamePatterns ns ps
|
||||
return (ns_new, PInj cs ps)
|
||||
rest -> return (ns, rest)
|
||||
(ns_new, ps') <- mapAccumM renamePattern ns ps
|
||||
return (ns_new, PInj cs ps')
|
||||
PVar name -> second PVar <$> newNameL ns name
|
||||
_ -> return (ns, p)
|
||||
|
||||
renamePatterns :: Names -> [Pattern] -> Rn (Names, [Pattern])
|
||||
renamePatterns ns xs = do
|
||||
(new_names, xs') <- unzip <$> mapM (renamePattern ns) xs
|
||||
if null new_names then return (mempty, xs') else return (head new_names, xs')
|
||||
|
||||
renameTVars :: Type -> Rn Type
|
||||
renameTVars typ = case typ of
|
||||
|
|
@ -167,44 +157,57 @@ renameTVars typ = case typ of
|
|||
TFun t1 t2 -> liftA2 TFun (renameTVars t1) (renameTVars t2)
|
||||
_ -> pure typ
|
||||
|
||||
substitute ::
|
||||
TVar -> -- α
|
||||
TVar -> -- α_n
|
||||
Type -> -- A
|
||||
Type -- [α_n/α]A
|
||||
substitute :: TVar -- α
|
||||
-> TVar -- α_n
|
||||
-> Type -- A
|
||||
-> Type -- [α_n/α]A
|
||||
substitute tvar1 tvar2 typ = case typ of
|
||||
TLit _ -> typ
|
||||
TVar tvar'
|
||||
| tvar' == tvar1 -> TVar tvar2
|
||||
| otherwise -> typ
|
||||
TFun t1 t2 -> on TFun substitute' t1 t2
|
||||
TAll tvar t -> TAll tvar $ substitute' t
|
||||
TData name typs -> TData name $ map substitute' typs
|
||||
_ -> error "Impossible"
|
||||
TLit _ -> typ
|
||||
TVar tvar | tvar == tvar1 -> TVar tvar2
|
||||
| otherwise -> typ
|
||||
TFun t1 t2 -> on TFun substitute' t1 t2
|
||||
TAll tvar t | tvar == tvar1 -> TAll tvar2 $ substitute' t
|
||||
| otherwise -> TAll tvar $ substitute' t
|
||||
TData name typs -> TData name $ map substitute' typs
|
||||
_ -> error "Impossible"
|
||||
where
|
||||
substitute' = substitute tvar1 tvar2
|
||||
|
||||
-- | Create a new name and add it to name environment.
|
||||
newName :: Names -> LIdent -> Rn (Names, LIdent)
|
||||
newName env old_name = do
|
||||
new_name <- makeName old_name
|
||||
pure (Map.insert old_name new_name env, new_name)
|
||||
|
||||
|
||||
-- | Create multiple names and add them to the name environment
|
||||
newNames :: Names -> [LIdent] -> Rn (Names, [LIdent])
|
||||
newNames = mapAccumM newName
|
||||
newNamesL :: Names -> [LIdent] -> Rn (Names, [LIdent])
|
||||
newNamesL = mapAccumM newNameL
|
||||
|
||||
-- | Create a new name and add it to name environment.
|
||||
newNameL :: Names -> LIdent -> Rn (Names, LIdent)
|
||||
newNameL env (LIdent old_name) = do
|
||||
new_name <- makeName old_name
|
||||
pure (Map.insert old_name new_name env, LIdent new_name)
|
||||
|
||||
|
||||
-- | Create multiple names and add them to the name environment
|
||||
newNamesU :: Names -> [UIdent] -> Rn (Names, [UIdent])
|
||||
newNamesU = mapAccumM newNameU
|
||||
|
||||
-- | Create a new name and add it to name environment.
|
||||
newNameU :: Names -> UIdent -> Rn (Names, UIdent)
|
||||
newNameU env (UIdent old_name) = do
|
||||
new_name <- makeName old_name
|
||||
pure (Map.insert old_name new_name env, UIdent new_name)
|
||||
|
||||
|
||||
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
|
||||
makeName :: LIdent -> Rn LIdent
|
||||
makeName (LIdent prefix) = do
|
||||
i <- gets var_counter
|
||||
let name = LIdent $ prefix ++ "_" ++ show i
|
||||
modify $ \cxt -> cxt{var_counter = succ cxt.var_counter}
|
||||
pure name
|
||||
makeName :: String -> Rn String
|
||||
makeName prefix = do
|
||||
i <- gets var_counter
|
||||
let name = prefix ++ "_" ++ show i
|
||||
modify $ \cxt -> cxt { var_counter = succ cxt.var_counter}
|
||||
pure name
|
||||
|
||||
nextNameTVar :: TVar -> Rn TVar
|
||||
nextNameTVar (MkTVar (LIdent s)) = do
|
||||
i <- gets tvar_counter
|
||||
let tvar = MkTVar $ coerce $ s ++ "_" ++ show i
|
||||
modify $ \cxt -> cxt{tvar_counter = succ cxt.tvar_counter}
|
||||
pure tvar
|
||||
nextNameTVar (MkTVar (LIdent s))= do
|
||||
i <- gets tvar_counter
|
||||
let tvar = MkTVar . LIdent $ s ++ "_" ++ show i
|
||||
modify $ \cxt -> cxt { tvar_counter = succ cxt.tvar_counter}
|
||||
pure tvar
|
||||
|
|
|
|||
206
src/Renamer/RenamerOld.hs
Normal file
206
src/Renamer/RenamerOld.hs
Normal file
|
|
@ -0,0 +1,206 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedRecordDot #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
|
||||
|
||||
{-# HLINT ignore "Use mapAndUnzipM" #-}
|
||||
|
||||
module Renamer.Renamer (rename) where
|
||||
|
||||
import Auxiliary (mapAccumM)
|
||||
import Control.Applicative (Applicative (liftA2))
|
||||
import Control.Monad (foldM)
|
||||
import Control.Monad.Except (ExceptT, MonadError, runExceptT,
|
||||
throwError)
|
||||
import Control.Monad.Identity (Identity, runIdentity)
|
||||
import Control.Monad.State (MonadState, StateT, evalStateT, gets,
|
||||
modify)
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Function (on)
|
||||
import Data.Map (Map)
|
||||
import qualified Data.Map as Map
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.Tuple.Extra (dupe)
|
||||
import Grammar.Abs
|
||||
|
||||
-- | Rename all variables and local binds
|
||||
rename :: Program -> Either String Program
|
||||
rename (Program defs) = Program <$> renameDefs defs
|
||||
|
||||
renameDefs :: [Def] -> Either String [Def]
|
||||
renameDefs defs = runIdentity $ runExceptT $ evalStateT (runRn $ mapM renameDef defs) initCxt
|
||||
where
|
||||
initNames = Map.fromList [dupe (coerce name) | DBind (Bind name _ _) <- defs]
|
||||
|
||||
renameDef :: Def -> Rn Def
|
||||
renameDef = \case
|
||||
DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ
|
||||
DBind bind -> DBind . snd <$> renameBind initNames bind
|
||||
DData (Data (TData cname types) constrs) -> do
|
||||
tvars_ <- tvars
|
||||
tvars' <- mapM nextNameTVar tvars_
|
||||
let tvars_lt = zip tvars_ tvars'
|
||||
typ' = map (substituteTVar tvars_lt) types
|
||||
constrs' = map (renameConstr tvars_lt) constrs
|
||||
pure . DData $ Data (TData cname typ') constrs'
|
||||
where
|
||||
tvars = concat <$> mapM (collectTVars []) types
|
||||
collectTVars :: [TVar] -> Type -> Rn [TVar]
|
||||
collectTVars tvars = \case
|
||||
TAll tvar t -> collectTVars (tvar : tvars) t
|
||||
TData _ _ -> return tvars
|
||||
-- Should be monad error
|
||||
TVar v -> return [v]
|
||||
_ -> throwError ("Bad data type definition: " ++ show types)
|
||||
DData (Data types _) -> throwError ("Bad data type definition: " ++ show types)
|
||||
|
||||
renameConstr :: [(TVar, TVar)] -> Inj -> Inj
|
||||
renameConstr new_types (Inj name typ) =
|
||||
Inj name $ substituteTVar new_types typ
|
||||
|
||||
renameBind :: Names -> Bind -> Rn (Names, Bind)
|
||||
renameBind old_names (Bind name vars rhs) = do
|
||||
(new_names, vars') <- newNames old_names (coerce vars)
|
||||
(newer_names, rhs') <- renameExp new_names rhs
|
||||
pure (newer_names, Bind name (coerce vars') rhs')
|
||||
|
||||
substituteTVar :: [(TVar, TVar)] -> Type -> Type
|
||||
substituteTVar new_names typ = case typ of
|
||||
TLit _ -> typ
|
||||
TVar tvar
|
||||
| Just tvar' <- lookup tvar new_names ->
|
||||
TVar tvar'
|
||||
| otherwise ->
|
||||
typ
|
||||
TFun t1 t2 -> on TFun substitute' t1 t2
|
||||
TAll tvar t
|
||||
| Just tvar' <- lookup tvar new_names ->
|
||||
TAll tvar' $ substitute' t
|
||||
| otherwise ->
|
||||
TAll tvar $ substitute' t
|
||||
TData name typs -> TData name $ map substitute' typs
|
||||
_ -> error ("Impossible " ++ show typ)
|
||||
where
|
||||
substitute' = substituteTVar new_names
|
||||
|
||||
initCxt :: Cxt
|
||||
initCxt = Cxt 0 0
|
||||
|
||||
data Cxt = Cxt
|
||||
{ var_counter :: Int
|
||||
, tvar_counter :: Int
|
||||
}
|
||||
|
||||
-- | Rename monad. State holds the number of renamed names.
|
||||
newtype Rn a = Rn {runRn :: StateT Cxt (ExceptT String Identity) a}
|
||||
deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String)
|
||||
|
||||
-- | Maps old to new name
|
||||
type Names = Map LIdent LIdent
|
||||
|
||||
renameExp :: Names -> Exp -> Rn (Names, Exp)
|
||||
renameExp old_names = \case
|
||||
EVar n -> pure (coerce old_names, EVar . fromMaybe n $ Map.lookup n old_names)
|
||||
EInj n -> pure (old_names, EInj n)
|
||||
ELit lit -> pure (old_names, ELit lit)
|
||||
EApp e1 e2 -> do
|
||||
(env1, e1') <- renameExp old_names e1
|
||||
(env2, e2') <- renameExp old_names e2
|
||||
pure (Map.union env1 env2, EApp e1' e2')
|
||||
EAdd e1 e2 -> do
|
||||
(env1, e1') <- renameExp old_names e1
|
||||
(env2, e2') <- renameExp old_names e2
|
||||
pure (Map.union env1 env2, EAdd e1' e2')
|
||||
|
||||
-- TODO fix shadowing
|
||||
ELet bind e -> do
|
||||
(new_names, bind') <- renameBind old_names bind
|
||||
(new_names', e') <- renameExp new_names e
|
||||
pure (new_names', ELet bind' e')
|
||||
EAbs par e -> do
|
||||
(new_names, par') <- newName old_names (coerce par)
|
||||
(new_names', e') <- renameExp new_names e
|
||||
pure (new_names', EAbs (coerce par') e')
|
||||
EAnn e t -> do
|
||||
(new_names, e') <- renameExp old_names e
|
||||
t' <- renameTVars t
|
||||
pure (new_names, EAnn e' t')
|
||||
ECase e injs -> do
|
||||
(new_names, e') <- renameExp old_names e
|
||||
(new_names', injs') <- renameBranches new_names injs
|
||||
pure (new_names', ECase e' injs')
|
||||
|
||||
renameBranches :: Names -> [Branch] -> Rn (Names, [Branch])
|
||||
renameBranches ns xs = do
|
||||
(new_names, xs') <- unzip <$> mapM (renameBranch ns) xs
|
||||
if null new_names then return (mempty, xs') else return (head new_names, xs')
|
||||
|
||||
renameBranch :: Names -> Branch -> Rn (Names, Branch)
|
||||
renameBranch ns (Branch init e) = do
|
||||
(new_names, init') <- renamePattern ns init
|
||||
(new_names', e') <- renameExp new_names e
|
||||
return (new_names', Branch init' e')
|
||||
|
||||
renamePattern :: Names -> Pattern -> Rn (Names, Pattern)
|
||||
renamePattern ns i = case i of
|
||||
PInj cs ps -> do
|
||||
(ns_new, ps) <- renamePatterns ns ps
|
||||
return (ns_new, PInj cs ps)
|
||||
rest -> return (ns, rest)
|
||||
|
||||
renamePatterns :: Names -> [Pattern] -> Rn (Names, [Pattern])
|
||||
renamePatterns ns xs = do
|
||||
(new_names, xs') <- unzip <$> mapM (renamePattern ns) xs
|
||||
if null new_names then return (mempty, xs') else return (head new_names, xs')
|
||||
|
||||
renameTVars :: Type -> Rn Type
|
||||
renameTVars typ = case typ of
|
||||
TAll tvar t -> do
|
||||
tvar' <- nextNameTVar tvar
|
||||
t' <- renameTVars $ substitute tvar tvar' t
|
||||
pure $ TAll tvar' t'
|
||||
TFun t1 t2 -> liftA2 TFun (renameTVars t1) (renameTVars t2)
|
||||
_ -> pure typ
|
||||
|
||||
substitute ::
|
||||
TVar -> -- α
|
||||
TVar -> -- α_n
|
||||
Type -> -- A
|
||||
Type -- [α_n/α]A
|
||||
substitute tvar1 tvar2 typ = case typ of
|
||||
TLit _ -> typ
|
||||
TVar tvar'
|
||||
| tvar' == tvar1 -> TVar tvar2
|
||||
| otherwise -> typ
|
||||
TFun t1 t2 -> on TFun substitute' t1 t2
|
||||
TAll tvar t -> TAll tvar $ substitute' t
|
||||
TData name typs -> TData name $ map substitute' typs
|
||||
_ -> error "Impossible"
|
||||
where
|
||||
substitute' = substitute tvar1 tvar2
|
||||
|
||||
-- | Create a new name and add it to name environment.
|
||||
newName :: Names -> LIdent -> Rn (Names, LIdent)
|
||||
newName env old_name = do
|
||||
new_name <- makeName old_name
|
||||
pure (Map.insert old_name new_name env, new_name)
|
||||
|
||||
-- | Create multiple names and add them to the name environment
|
||||
newNames :: Names -> [LIdent] -> Rn (Names, [LIdent])
|
||||
newNames = mapAccumM newName
|
||||
|
||||
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
|
||||
makeName :: LIdent -> Rn LIdent
|
||||
makeName (LIdent prefix) = do
|
||||
i <- gets var_counter
|
||||
let name = LIdent $ prefix ++ "_" ++ show i
|
||||
modify $ \cxt -> cxt{var_counter = succ cxt.var_counter}
|
||||
pure name
|
||||
|
||||
nextNameTVar :: TVar -> Rn TVar
|
||||
nextNameTVar (MkTVar (LIdent s)) = do
|
||||
i <- gets tvar_counter
|
||||
let tvar = MkTVar $ coerce $ s ++ "_" ++ show i
|
||||
modify $ \cxt -> cxt{tvar_counter = succ cxt.tvar_counter}
|
||||
pure tvar
|
||||
73
src/TypeChecker/RemoveTEVar.hs
Normal file
73
src/TypeChecker/RemoveTEVar.hs
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
|
||||
module TypeChecker.RemoveTEVar where
|
||||
|
||||
import Control.Applicative (Applicative (liftA2), liftA3)
|
||||
import Control.Arrow (Arrow (second))
|
||||
import Control.Monad.Error (MonadError (throwError))
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Function (on)
|
||||
import Data.Tuple.Extra (secondM)
|
||||
import Grammar.Abs
|
||||
import Grammar.ErrM (Err)
|
||||
import qualified TypeChecker.TypeCheckerIr as T
|
||||
|
||||
class RemoveTEVar a b where
|
||||
rmTEVar :: a -> Err b
|
||||
|
||||
instance RemoveTEVar (T.Program' Type) (T.Program' T.Type) where
|
||||
rmTEVar (T.Program defs) = T.Program <$> rmTEVar defs
|
||||
|
||||
instance RemoveTEVar (T.Def' Type) (T.Def' T.Type) where
|
||||
rmTEVar = \case
|
||||
T.DBind bind -> T.DBind <$> rmTEVar bind
|
||||
T.DData dat -> T.DData <$> rmTEVar dat
|
||||
|
||||
instance RemoveTEVar (T.Bind' Type) (T.Bind' T.Type) where
|
||||
rmTEVar (T.Bind id vars rhs) = liftA3 T.Bind (rmTEVar id) (rmTEVar vars) (rmTEVar rhs)
|
||||
|
||||
instance RemoveTEVar (T.Exp' Type) (T.Exp' T.Type) where
|
||||
rmTEVar exp = case exp of
|
||||
T.EVar name -> pure $ T.EVar name
|
||||
T.EInj name -> pure $ T.EInj name
|
||||
T.ELit lit -> pure $ T.ELit lit
|
||||
T.ELet bind e -> liftA2 T.ELet (rmTEVar bind) (rmTEVar e)
|
||||
T.EApp e1 e2 -> liftA2 T.EApp (rmTEVar e1) (rmTEVar e2)
|
||||
T.EAdd e1 e2 -> liftA2 T.EApp (rmTEVar e1) (rmTEVar e2)
|
||||
T.EAbs name e -> T.EAbs name <$> rmTEVar e
|
||||
T.ECase e branches -> liftA2 T.ECase (rmTEVar e) (rmTEVar branches)
|
||||
|
||||
instance RemoveTEVar (T.Branch' Type) (T.Branch' T.Type) where
|
||||
rmTEVar (T.Branch (patt, t_patt) e) = liftA2 T.Branch (liftA2 (,) (rmTEVar patt) (rmTEVar t_patt)) (rmTEVar e)
|
||||
|
||||
instance RemoveTEVar (T.Pattern' Type) (T.Pattern' T.Type) where
|
||||
rmTEVar = \case
|
||||
T.PVar (name, t) -> T.PVar . (name,) <$> rmTEVar t
|
||||
T.PLit (lit, t) -> T.PLit . (lit,) <$> rmTEVar t
|
||||
T.PCatch -> pure T.PCatch
|
||||
T.PEnum name -> pure $ T.PEnum name
|
||||
T.PInj name ps -> T.PInj name <$> rmTEVar ps
|
||||
|
||||
instance RemoveTEVar (T.Data' Type) (T.Data' T.Type) where
|
||||
rmTEVar (T.Data typ injs) = liftA2 T.Data (rmTEVar typ) (rmTEVar injs)
|
||||
|
||||
instance RemoveTEVar (T.Inj' Type) (T.Inj' T.Type) where
|
||||
rmTEVar (T.Inj name typ) = T.Inj name <$> rmTEVar typ
|
||||
|
||||
instance RemoveTEVar (T.Id' Type) (T.Id' T.Type) where
|
||||
rmTEVar = secondM rmTEVar
|
||||
|
||||
instance RemoveTEVar (T.ExpT' Type) (T.ExpT' T.Type) where
|
||||
rmTEVar (exp, typ) = liftA2 (,) (rmTEVar exp) (rmTEVar typ)
|
||||
|
||||
instance RemoveTEVar a b => RemoveTEVar [a] [b] where
|
||||
rmTEVar = mapM rmTEVar
|
||||
|
||||
instance RemoveTEVar Type T.Type where
|
||||
rmTEVar = \case
|
||||
TLit lit -> pure $ T.TLit (coerce lit)
|
||||
TVar tvar -> pure $ T.TVar tvar
|
||||
TData name typs -> T.TData (coerce name) <$> rmTEVar typs
|
||||
TFun t1 t2 -> liftA2 T.TFun (rmTEVar t1) (rmTEVar t2)
|
||||
TAll tvar t -> T.TAll tvar <$> rmTEVar t
|
||||
TEVar _ -> throwError "NewType TEVar!"
|
||||
858
src/TypeChecker/TypeCheckerBidir.hs
Normal file
858
src/TypeChecker/TypeCheckerBidir.hs
Normal file
|
|
@ -0,0 +1,858 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedRecordDot #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# OPTIONS_GHC -Wno-incomplete-patterns #-}
|
||||
|
||||
module TypeChecker.TypeCheckerBidir (typecheck, getVars) where
|
||||
|
||||
import Auxiliary (maybeToRightM, snoc)
|
||||
import Control.Applicative (Alternative, Applicative (liftA2),
|
||||
(<|>))
|
||||
import Control.Monad.Except (ExceptT, MonadError (throwError),
|
||||
runExceptT, unless, zipWithM,
|
||||
zipWithM_)
|
||||
import Control.Monad.State (MonadState (get, put), State,
|
||||
evalState, gets, modify)
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Function (on)
|
||||
import Data.List (intercalate)
|
||||
import Data.Map (Map)
|
||||
import qualified Data.Map as Map
|
||||
import Data.Maybe (fromMaybe, isNothing)
|
||||
import Data.Sequence (Seq (..))
|
||||
import qualified Data.Sequence as S
|
||||
import Data.Tuple.Extra (second, secondM)
|
||||
import Debug.Trace (trace)
|
||||
import Grammar.Abs
|
||||
import Grammar.ErrM
|
||||
import Grammar.Print (printTree)
|
||||
import Prelude hiding (exp, id)
|
||||
import qualified TypeChecker.TypeCheckerIr as T
|
||||
|
||||
-- Implementation is derived from the paper (Dunfield and Krishnaswami 2013)
|
||||
-- https://doi.org/10.1145/2500365.2500582
|
||||
|
||||
data EnvElem = EnvVar LIdent Type -- ^ Term variable typing. x : A
|
||||
| EnvTVar TVar -- ^ Universal type variable. α
|
||||
| EnvTEVar TEVar -- ^ Existential unsolved type variable. ά
|
||||
| EnvTEVarSolved TEVar Type -- ^ Existential solved type variable. ά = τ
|
||||
| EnvMark TEVar -- ^ Scoping Marker. ▶ ά
|
||||
deriving (Eq, Show)
|
||||
|
||||
type Env = Seq EnvElem
|
||||
|
||||
-- | Ordered context
|
||||
-- Γ ::= ・| Γ, α | Γ, ά | Γ, ▶ ά | Γ, x:A
|
||||
data Cxt = Cxt
|
||||
{ env :: Env -- ^ Local scope context Γ
|
||||
, sig :: Map LIdent Type -- ^ Top-level signatures x : A
|
||||
, binds :: Map LIdent Exp -- ^ Top-level binds x : e
|
||||
, next_tevar :: Int -- ^ Counter to distinguish ά
|
||||
, data_injs :: Map UIdent Type -- ^ Data injections (constructors) K
|
||||
} deriving (Show, Eq)
|
||||
|
||||
newtype Tc a = Tc { runTc :: ExceptT String (State Cxt) a }
|
||||
deriving (Functor, Applicative, Monad, Alternative, MonadState Cxt, MonadError String)
|
||||
|
||||
typecheck :: Program -> Err (T.Program' Type)
|
||||
typecheck (Program defs) = do
|
||||
datatypes <- mapM typecheckDataType [ d | DData d <- defs ]
|
||||
|
||||
|
||||
let initCxt = Cxt
|
||||
{ env = mempty
|
||||
, sig = Map.fromList [ (name, t)
|
||||
| DSig' name t <- defs
|
||||
]
|
||||
, binds = Map.fromList [ (name, foldr EAbs rhs vars)
|
||||
| DBind' name vars rhs <- defs
|
||||
]
|
||||
, next_tevar = 0
|
||||
, data_injs = Map.fromList [ (name, typ)
|
||||
| Data _ injs <- datatypes
|
||||
, Inj name typ <- injs
|
||||
]
|
||||
}
|
||||
|
||||
binds' <- evalState (runExceptT (runTc $ mapM typecheckBind binds)) initCxt;
|
||||
pure . T.Program $ map T.DData (coerceData datatypes) ++ map T.DBind binds'
|
||||
where
|
||||
binds = [ b | DBind b <- defs ]
|
||||
coerceData = map (\(Data t injs) -> T.Data t $ map
|
||||
(\(Inj name typ) -> T.Inj (coerce name) typ) injs)
|
||||
|
||||
|
||||
typecheckBind :: Bind -> Tc (T.Bind' Type)
|
||||
typecheckBind (Bind name vars rhs) = do
|
||||
bind' <- lookupSig name >>= \case
|
||||
-- TODO These Judgment aren't accurate
|
||||
-- (f:A → B) ∈ Γ
|
||||
-- Γ,(xs:A) ⊢ e ↑ Β ⊣ Δ
|
||||
---------------------------
|
||||
-- Γ ⊢ f xs = e ↓ Α → B ⊣ Δ
|
||||
Just t -> do
|
||||
(rhs', _) <- check (foldr EAbs rhs vars) t
|
||||
pure (T.Bind (coerce name, t) (coerce vars') (rhs', t))
|
||||
where
|
||||
vars' = zip vars $ getVars t
|
||||
|
||||
-- Γ ⊢ (λxs. e) ↓ A → B ⊣ Δ
|
||||
-- ------------------------------
|
||||
-- Γ ⊢ f xs = e ↓ [Γ]A → [Γ]B ⊣ Δ
|
||||
Nothing -> do
|
||||
(e, t) <- infer $ foldr EAbs rhs vars
|
||||
t' <- applyEnv t
|
||||
e' <- applyEnvExp e
|
||||
let rhs' = skipLambdas (length vars) e'
|
||||
vars' = zip vars $ getVars t'
|
||||
pure (T.Bind (coerce name, t') (coerce vars') (rhs', t'))
|
||||
env <- gets env
|
||||
unless (isComplete env) err
|
||||
putEnv Empty
|
||||
pure bind'
|
||||
where
|
||||
err = throwError $ unlines
|
||||
[ "Type inference failed: " ++ printTree (Bind name vars rhs)
|
||||
, "Did you forget to add type annotation to a polymorphic function?"
|
||||
]
|
||||
|
||||
typecheckDataType :: Data -> Err Data
|
||||
typecheckDataType (Data typ injs) = do
|
||||
(name, tvars) <- go [] typ
|
||||
injs' <- mapM (\i -> typecheckInj i name tvars) injs
|
||||
pure (Data typ injs')
|
||||
where
|
||||
go tvars = \case
|
||||
TAll tvar t -> go (tvar:tvars) t
|
||||
TData name typs
|
||||
| Right tvars' <- mapM toTVar typs
|
||||
, all (`elem` tvars) tvars'
|
||||
-> pure (name, tvars')
|
||||
_ -> throwError $ unwords ["Bad data type definition: ", ppT typ]
|
||||
|
||||
typecheckInj :: Inj -> UIdent -> [TVar] -> Err Inj
|
||||
typecheckInj (Inj inj_name inj_typ) name tvars
|
||||
| not $ boundTVars tvars inj_typ
|
||||
= throwError "Unbound type variables"
|
||||
| TData name' typs <- getReturn inj_typ
|
||||
, name' == name
|
||||
, Right tvars' <- mapM toTVar typs
|
||||
, tvars' == tvars
|
||||
= pure (Inj inj_name $ foldr TAll inj_typ tvars)
|
||||
| otherwise
|
||||
= throwError $ unwords
|
||||
["Bad type constructor: ", show name
|
||||
, "\nExpected: ", ppT . TData name $ map TVar tvars
|
||||
, "\nActual: ", ppT $ getReturn inj_typ
|
||||
]
|
||||
where
|
||||
boundTVars :: [TVar] -> Type -> Bool
|
||||
boundTVars tvars' = \case
|
||||
TAll tvar t -> boundTVars (tvar:tvars') t
|
||||
TFun t1 t2 -> on (&&) (boundTVars tvars') t1 t2
|
||||
TVar tvar -> elem tvar tvars'
|
||||
TData _ typs -> all (boundTVars tvars) typs
|
||||
TLit _ -> True
|
||||
TEVar _ -> error "TEVar in data type declaration"
|
||||
|
||||
---------------------------------------------------------------------------
|
||||
-- * Subtyping rules
|
||||
---------------------------------------------------------------------------
|
||||
|
||||
-- | Γ ⊢ A <: B ⊣ Δ
|
||||
-- Under input context Γ, type A is a subtype of B, with output context ∆
|
||||
subtype :: Type -> Type -> Tc ()
|
||||
subtype t1 t2 = case (t1, t2) of
|
||||
|
||||
(TLit lit1, TLit lit2) | lit1 == lit2 -> pure ()
|
||||
|
||||
-- -------------------- <:Var
|
||||
-- Γ[α] ⊢ α <: α ⊣ Γ[α]
|
||||
(TVar tvar1, TVar tvar2) | tvar1 == tvar2 -> pure ()
|
||||
|
||||
-- -------------------- <:Exvar
|
||||
-- Γ[ά] ⊢ ά <: ά ⊣ Γ[ά]
|
||||
(TEVar tevar1, TEVar tevar2) | tevar1 == tevar2 -> pure ()
|
||||
|
||||
-- Γ ⊢ B₁ <: A₁ ⊣ Θ Θ ⊢ [Θ]A₂ <: [Θ]B₂ ⊣ Δ
|
||||
-- ----------------------------------------- <:→
|
||||
-- Γ ⊢ A₁ → A₂ <: B₁ → B₂ ⊣ Δ
|
||||
(TFun a1 a2, TFun b1 b2) -> do
|
||||
subtype b1 a1
|
||||
a2' <- applyEnv a2
|
||||
b2' <- applyEnv b2
|
||||
subtype a2' b2'
|
||||
|
||||
-- Γ, α ⊢ A <: B ⊣ Δ,α,Θ
|
||||
-- --------------------- <:∀R
|
||||
-- Γ ⊢ A <: ∀α. B ⊣ Δ
|
||||
(a, TAll tvar b) -> do
|
||||
let env_tvar = EnvTVar tvar
|
||||
insertEnv env_tvar
|
||||
subtype a b
|
||||
dropTrailing env_tvar
|
||||
|
||||
-- Γ,▶ ά,ά ⊢ [ά/α]A <: B ⊣ Δ,▶ ά,Θ
|
||||
-- ------------------------------- <:∀L
|
||||
-- Γ ⊢ ∀α.A <: B ⊣ Δ
|
||||
(TAll tvar a, b) -> do
|
||||
tevar <- fresh
|
||||
let env_marker = EnvMark tevar
|
||||
env_tevar = EnvTEVar tevar
|
||||
insertEnv env_marker
|
||||
insertEnv env_tevar
|
||||
let a' = substitute tvar tevar a
|
||||
subtype a' b
|
||||
dropTrailing env_marker
|
||||
|
||||
-- ά ∉ FV(A) Γ[ά] ⊢ ά :=< A ⊣ Δ
|
||||
-- ------------------------------ <:instantiateL
|
||||
-- Γ[ά] ⊢ ά <: A ⊣ Δ
|
||||
(TEVar tevar, typ) | notElem tevar $ frees typ -> instantiateL tevar typ
|
||||
|
||||
-- ά ∉ FV(A) Γ[ά] ⊢ A =:< ά ⊣ Δ
|
||||
-- ------------------------------ <:instantiateR
|
||||
-- Γ[ά] ⊢ A <: ά ⊣ Δ
|
||||
(typ, TEVar tevar) | notElem tevar $ frees typ -> instantiateR typ tevar
|
||||
|
||||
(TData name1 typs1, TData name2 typs2)
|
||||
|
||||
-- D₁ = D₂
|
||||
-- ----------------
|
||||
-- Γ ⊢ D₁ () <: D₂ ()
|
||||
| name1 == name2
|
||||
, [] <- typs1
|
||||
, [] <- typs2
|
||||
-> pure ()
|
||||
|
||||
-- Γ ⊢ ά₁ <: έ₁ ⊣ Θ₁
|
||||
-- ...
|
||||
-- D₁ = D₂ Θₙ₋₁ ⊢ [Θₙ₋₁]άₙ <: [Θₙ₋₁]έₙ ⊣ Δ
|
||||
-- -------------------------------------------
|
||||
-- Γ ⊢ D (ά₁ ‥ άₙ) <: D (έ₁ ‥ έₙ) ⊣ Δ
|
||||
| name1 == name2
|
||||
, t1:t1s <- typs1
|
||||
, t2:t2s <- typs2
|
||||
-> do
|
||||
subtype t1 t2
|
||||
zipWithM_ go t1s t2s
|
||||
where
|
||||
go t1' t2' = do
|
||||
t1'' <- applyEnv t1'
|
||||
t2'' <- applyEnv t2'
|
||||
subtype t1'' t2''
|
||||
|
||||
_ -> throwError $ unwords ["Types", ppT t1, "and", ppT t2, "doesn't match!"]
|
||||
|
||||
---------------------------------------------------------------------------
|
||||
-- * Instantiation rules
|
||||
---------------------------------------------------------------------------
|
||||
|
||||
-- | Γ ⊢ ά :=< A ⊣ Δ
|
||||
-- Under input context Γ, instantiate ά such that ά <: A, with output context ∆
|
||||
instantiateL :: TEVar -> Type -> Tc ()
|
||||
instantiateL tevar typ = gets env >>= go
|
||||
where
|
||||
go env
|
||||
|
||||
-- Γ ⊢ τ
|
||||
-- ----------------------------- InstLSolve
|
||||
-- Γ,ά,Γ' ⊢ ά :=< τ ⊣ Γ,(ά=τ),Γ'
|
||||
| isMono typ
|
||||
, (env_l, env_r) <- splitOn (EnvTEVar tevar) env
|
||||
, Right _ <- wellFormed env_l typ
|
||||
= putEnv $ (env_l :|> EnvTEVarSolved tevar typ) <> env_r
|
||||
|
||||
| TEVar tevar' <- typ = instReach tevar tevar'
|
||||
|
||||
-- Γ[ά₂ά₁,(ά=ά₁→ά₂)] ⊢ A₁ =:< ά₁ ⊣ Θ Θ ⊢ ά₂ :=< [Θ]A₂ ⊣ Δ
|
||||
-- ------------------------------------------------------- InstLArr
|
||||
-- Γ[ά] ⊢ ά :=< A₁ → A₂ ⊣ Δ
|
||||
| TFun t1 t2 <- typ = do
|
||||
tevar1 <- fresh
|
||||
tevar2 <- fresh
|
||||
insertEnv $ EnvTEVar tevar2
|
||||
insertEnv $ EnvTEVar tevar1
|
||||
insertEnv $ EnvTEVarSolved tevar (on TFun TEVar tevar1 tevar2)
|
||||
instantiateR t1 tevar1
|
||||
instantiateL tevar2 =<< applyEnv t2
|
||||
|
||||
-- Γ[ά],ε ⊢ ά :=< E ⊣ Δ,ε,Δ'
|
||||
-- ------------------------- InstLAIIR
|
||||
-- Γ[ά] ⊢ ά :=< ∀ε.Ε ⊣ Δ
|
||||
| TAll tvar t <- typ = do
|
||||
instantiateL tevar t
|
||||
let (env_l, _) = splitOn (EnvTVar tvar) env
|
||||
putEnv env_l
|
||||
|
||||
| otherwise = error $ "Trying to instantiateL: " ++ ppT (TEVar tevar)
|
||||
++ " <: " ++ ppT typ
|
||||
|
||||
-- | Γ ⊢ A =:< ά ⊣ Δ
|
||||
-- Under input context Γ, instantiate ά such that A <: ά, with output context ∆
|
||||
instantiateR :: Type -> TEVar -> Tc ()
|
||||
instantiateR typ tevar = gets env >>= go
|
||||
where
|
||||
go env
|
||||
|
||||
-- Γ ⊢ τ
|
||||
-- ----------------------------- InstRSolve
|
||||
-- Γ,ά,Γ' ⊢ τ =:< ά ⊣ Γ,(ά=τ),Γ'
|
||||
| isMono typ
|
||||
, (env_l, env_r) <- splitOn (EnvTEVar tevar) env
|
||||
, Right _ <- wellFormed env_l typ
|
||||
= putEnv $ (env_l :|> EnvTEVarSolved tevar typ) <> env_r
|
||||
|
||||
| TEVar tevar' <- typ = instReach tevar tevar'
|
||||
|
||||
-- Γ[ά₂ά₁,(ά=ά₁→ά₂)] ⊢ A₁ :=< ά₁ ⊣ Θ Θ ⊢ ά₂ =:< [Θ]A₂ ⊣ Δ
|
||||
-- ------------------------------------------------------- InstRArr
|
||||
-- Γ[ά] ⊢ ά =:< A₁ → A₂ ⊣ Δ
|
||||
| TFun t1 t2 <- typ = do
|
||||
tevar1 <- fresh
|
||||
tevar2 <- fresh
|
||||
insertEnv $ EnvTEVar tevar2
|
||||
insertEnv $ EnvTEVar tevar1
|
||||
insertEnv $ EnvTEVarSolved tevar (on TFun TEVar tevar1 tevar2)
|
||||
instantiateL tevar1 t1
|
||||
t2' <- applyEnv t2
|
||||
instantiateR t2' tevar2
|
||||
|
||||
-- Γ[ά],▶έ,ε ⊢ [έ/ε]E =:< ά ⊣ Δ,▶έ,Δ'
|
||||
-- ---------------------------------- InstRAIIL
|
||||
-- Γ[ά] ⊢ ∀ε.Ε =:< ά ⊣ Δ
|
||||
| TAll tvar t <- typ = do
|
||||
tevar' <- fresh
|
||||
insertEnv $ EnvMark tevar'
|
||||
insertEnv $ EnvTVar tvar
|
||||
let t' = substitute tvar tevar' t
|
||||
instantiateR t' tevar
|
||||
let (env_l, _) = splitOn (EnvTVar tvar) env
|
||||
putEnv env_l
|
||||
|
||||
|
||||
| otherwise = error $ "Trying to instantiateR: " ++ ppT typ ++ " <: "
|
||||
++ ppT (TEVar tevar)
|
||||
|
||||
|
||||
-- ----------------------------- InstLReach
|
||||
-- Γ[ά][έ] ⊢ ά :=< έ ⊣ Γ[ά][έ=ά]
|
||||
--
|
||||
-- ----------------------------- InstRReach
|
||||
-- Γ[ά][έ] ⊢ έ =:< ά ⊣ Γ[ά][έ=ά]
|
||||
instReach :: TEVar -> TEVar -> Tc ()
|
||||
instReach tevar tevar' = do
|
||||
(env_l, env_r) <- gets (splitOn (EnvTEVar tevar') . env)
|
||||
let env_solved = EnvTEVarSolved tevar' $ TEVar tevar
|
||||
putEnv $ (env_l :|> env_solved) <> env_r
|
||||
|
||||
---------------------------------------------------------------------------
|
||||
-- * Typing rules
|
||||
---------------------------------------------------------------------------
|
||||
|
||||
-- | Γ ⊢ e ↑ A ⊣ Δ
|
||||
-- Under input context Γ, e checks against input type A, with output context ∆
|
||||
check :: Exp -> Type -> Tc (T.ExpT' Type)
|
||||
check exp typ
|
||||
|
||||
-- Γ,α ⊢ e ↑ A ⊣ Δ,α,Θ
|
||||
-- ------------------- ∀I
|
||||
-- Γ ⊢ e ↑ ∀α.A ⊣ Δ
|
||||
| TAll tvar t <- typ = do
|
||||
let env_tvar = EnvTVar tvar
|
||||
insertEnv env_tvar
|
||||
exp' <- check exp t
|
||||
(env_l, _) <- gets (splitOn env_tvar . env)
|
||||
putEnv env_l
|
||||
pure exp'
|
||||
|
||||
-- Γ,(x:A) ⊢ e ↑ B ⊢ Δ,(x:A),Θ
|
||||
-- --------------------------- →I
|
||||
-- Γ ⊢ λx.e ↑ A → B ⊣ Δ
|
||||
| EAbs name e <- exp
|
||||
, TFun t1 t2 <- typ = do
|
||||
let env_id = EnvVar name t1
|
||||
insertEnv env_id
|
||||
e' <- check e t2
|
||||
(env_l, _) <- gets (splitOn env_id . env)
|
||||
putEnv env_l
|
||||
pure (T.EAbs (coerce name) e', typ)
|
||||
|
||||
-- Θ ⊢ Π ∷ [Θ]A ↑ [Θ]C ⊣ Δ
|
||||
-- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO
|
||||
-- ---------------------------------------
|
||||
-- Γ ⊢ case e of Π ↑ C ⊣ Δ
|
||||
| ECase scrut branches <- exp = do
|
||||
(scrut', t_scrut) <- infer scrut
|
||||
t_scrut' <- applyEnv t_scrut
|
||||
typ' <- applyEnv typ
|
||||
branches' <- mapM (\b -> checkBranch b t_scrut' typ') branches
|
||||
pure (T.ECase (scrut', t_scrut') branches', typ')
|
||||
|
||||
| otherwise = subsumption
|
||||
where
|
||||
-- Γ,α ⊢ e ↓ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]B ⊣ Δ
|
||||
-- -------------------------------------- Sub
|
||||
-- Γ ⊢ e ↑ B ⊣ Δ
|
||||
subsumption = do
|
||||
(exp', t) <- infer exp
|
||||
exp'' <- applyEnvExp exp'
|
||||
t' <- applyEnv t
|
||||
typ' <- applyEnv typ
|
||||
subtype t' typ'
|
||||
pure (exp'', t')
|
||||
|
||||
-- | Γ ⊢ e ↓ A ⊣ Δ
|
||||
-- Under input context Γ, e infers output type A, with output context ∆
|
||||
infer :: Exp -> Tc (T.ExpT' Type)
|
||||
infer = \case
|
||||
|
||||
ELit lit -> pure (T.ELit lit, inferLit lit)
|
||||
|
||||
-- (x : A) ∈ Γ
|
||||
-- ------------- Var
|
||||
-- Γ ⊢ x ↓ A ⊣ Γ
|
||||
EVar name -> do
|
||||
t <- liftA2 (<|>) (lookupEnv name) (lookupSig name) >>= \case
|
||||
Just t -> pure t
|
||||
Nothing -> do
|
||||
e <- maybeToRightM
|
||||
("Unbound variable " ++ show name)
|
||||
=<< lookupBind name
|
||||
snd <$> infer e
|
||||
pure (T.EVar (coerce name), t)
|
||||
|
||||
EInj name -> do
|
||||
t <- maybeToRightM ("Unknown constructor: " ++ show name) =<< lookupInj name
|
||||
pure (T.EInj $ coerce name, t)
|
||||
|
||||
-- Γ ⊢ A Γ ⊢ e ↑ A ⊣ Δ
|
||||
-- --------------------- Anno
|
||||
-- Γ ⊢ (e : A) ↓ A ⊣ Δ
|
||||
EAnn e t -> do
|
||||
_ <- gets $ (`wellFormed` t) . env
|
||||
(e', _) <- check e t
|
||||
pure (e', t)
|
||||
|
||||
-- Γ ⊢ e₁ ↓ A ⊣ Θ Γ ⊢ [Θ]A • ⇓ C ⊣ Δ
|
||||
-- ----------------------------------- →E
|
||||
-- Γ ⊢ e₁ e₂ ↓ C ⊣ Δ
|
||||
EApp e1 e2 -> do
|
||||
(e1', t) <- infer e1
|
||||
t' <- applyEnv t
|
||||
e1'' <- applyEnvExp e1'
|
||||
(e2', t'') <- apply t' e2
|
||||
pure (T.EApp (e1'', t) e2', t'')
|
||||
|
||||
-- Γ,ά,έ,(x:ά) ⊢ e ↑ έ ⊣ Δ,(x:ά),Θ
|
||||
-- ------------------------------- →I
|
||||
-- Γ ⊢ λx.e ↓ ά → έ ⊣ Δ
|
||||
EAbs name e -> do
|
||||
tevar1 <- fresh
|
||||
tevar2 <- fresh
|
||||
insertEnv $ EnvTEVar tevar1
|
||||
insertEnv $ EnvTEVar tevar2
|
||||
let env_id = EnvVar name (TEVar tevar1)
|
||||
insertEnv env_id
|
||||
e' <- check e $ TEVar tevar2
|
||||
dropTrailing env_id
|
||||
let t_exp = on TFun TEVar tevar1 tevar2
|
||||
pure (T.EAbs (coerce name) e', t_exp)
|
||||
|
||||
|
||||
-- Γ ⊢ e ↓ A ⊣ Θ Θ,(x:A) ⊢ e' ↑ C ⊣ Δ,(x:A),Θ
|
||||
-- -------------------------------------------- LetI
|
||||
-- Γ ⊢ let x=e in e' ↑ C ⊣ Δ
|
||||
ELet (Bind name [] rhs) e -> do -- TODO vars
|
||||
(rhs', t_rhs) <- infer rhs
|
||||
let env_id = EnvVar name t_rhs
|
||||
insertEnv env_id
|
||||
(e', t) <- infer e
|
||||
(env_l, _) <- gets (splitOn env_id . env)
|
||||
putEnv env_l
|
||||
pure (T.ELet (T.Bind (coerce name, t_rhs) [] (rhs', t_rhs)) (e',t), t)
|
||||
|
||||
-- Γ ⊢ e₁ ↑ Int Γ ⊢ e₁ ↑ Int
|
||||
-- --------------------------- +I
|
||||
-- Γ ⊢ e₁ + e₂ ↓ Int ⊣ Δ
|
||||
EAdd e1 e2 -> do
|
||||
cxt <- get
|
||||
let t = TLit "Int"
|
||||
e1' <- check e1 t
|
||||
put cxt
|
||||
e2' <- check e2 t
|
||||
pure (T.EAdd e1' e2', t)
|
||||
|
||||
-- | Γ ⊢ A • e ⇓ C ⊣ Δ
|
||||
-- Under input context Γ , applying a function of type A to e infers type C, with output context ∆
|
||||
-- Instantiate existential type variables until there is an arrow type.
|
||||
apply :: Type -> Exp -> Tc (T.ExpT' Type, Type)
|
||||
apply typ exp = case typ of
|
||||
|
||||
-- Γ,ά ⊢ [ά/α]A • e ⇓ C ⊣ Δ
|
||||
-- ------------------------ ∀App
|
||||
-- Γ ⊢ ∀α.A • e ⇓ C ⊣ Δ
|
||||
TAll tvar t -> do
|
||||
tevar <- fresh
|
||||
insertEnv $ EnvTEVar tevar
|
||||
let t' = substitute tvar tevar t
|
||||
apply t' exp
|
||||
|
||||
-- Γ[ά₂,ά₁,(ά=ά₁→ά₂)] ⊢ e ↑ ά₁ ⊣ Δ
|
||||
-- ------------------------------- άApp
|
||||
-- Γ[ά] ⊢ ά • e ⇓ ά₂ ⊣ Δ
|
||||
TEVar tevar -> do
|
||||
tevar1 <- fresh
|
||||
tevar2 <- fresh
|
||||
let env_tevar1 = EnvTEVar tevar1
|
||||
env_tevar2 = EnvTEVar tevar2
|
||||
t_fun = on TFun TEVar tevar1 tevar2
|
||||
env_tevar_solved = EnvTEVarSolved tevar t_fun
|
||||
(env_l, env_r) <- gets (splitOn (EnvTEVar tevar) . env)
|
||||
putEnv $
|
||||
(env_l :|> env_tevar2 :|> env_tevar1 :|> env_tevar_solved) <> env_r
|
||||
expT' <- check exp $ TEVar tevar1
|
||||
pure (expT', TEVar tevar2)
|
||||
|
||||
-- Γ ⊢ e ↑ A ⊣ Δ
|
||||
-- --------------------- →App
|
||||
-- Γ ⊢ A → C • e ⇓ C ⊣ Δ
|
||||
TFun t1 t2 -> do
|
||||
expt' <- check exp t1
|
||||
pure (expt', t2)
|
||||
|
||||
_ -> throwError ("Cannot apply type " ++ show typ ++ " with expression " ++ show exp)
|
||||
|
||||
---------------------------------------------------------------------------
|
||||
-- * Pattern matching
|
||||
---------------------------------------------------------------------------
|
||||
|
||||
-- | Γ ⊢ p ⇒ e ∷ A ↑ C
|
||||
-- Under context Γ, check branch p ⇒ e of type A and bodies of type C
|
||||
checkBranch :: Branch -> Type -> Type -> Tc (T.Branch' Type)
|
||||
checkBranch (Branch patt exp) t_patt t_exp = do
|
||||
env_marker <- EnvMark <$> fresh
|
||||
insertEnv env_marker
|
||||
patt' <- checkPattern patt t_patt
|
||||
t_exp' <- applyEnv t_exp
|
||||
(exp, t_exp) <- check exp t_exp'
|
||||
(env_l, _) <- gets (splitOn env_marker . env)
|
||||
putEnv env_l
|
||||
pure (T.Branch patt' (exp, t_exp))
|
||||
|
||||
checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type)
|
||||
checkPattern patt t_patt = case patt of
|
||||
PVar x -> do
|
||||
insertEnv $ EnvVar x t_patt
|
||||
pure (T.PVar (coerce x, dummy), dummy) -- TODO
|
||||
PCatch -> pure (T.PCatch, dummy) -- TODO
|
||||
PLit lit | inferLit lit == t_patt -> let
|
||||
t = inferLit lit
|
||||
in
|
||||
pure (T.PLit (lit, t), t)
|
||||
| otherwise -> throwError "Literal in pattern have wrong type"
|
||||
|
||||
PEnum name -> do
|
||||
t <- maybeToRightM ("Unknown constructor " ++ show name)
|
||||
=<< lookupInj name
|
||||
subtype t t_patt
|
||||
pure (T.PEnum (coerce name), dummy) -- TODO
|
||||
|
||||
PInj name ps -> do
|
||||
t <- maybeToRightM ("Unknown constructor " ++ show name)
|
||||
=<< lookupInj name
|
||||
let (t_ps, t_return) = partitionTypeWithForall t
|
||||
unless (length ps == length t_ps) $
|
||||
throwError "Wrong number of variables"
|
||||
subtype t_return t_patt
|
||||
ps' <- zipWithM (\p t -> checkPattern p =<< applyEnv t) ps t_ps
|
||||
let ps'' = map fst ps' -- TODO
|
||||
pure (T.PInj (coerce name) ps'', dummy)
|
||||
|
||||
---------------------------------------------------------------------------
|
||||
-- * Auxiliary
|
||||
---------------------------------------------------------------------------
|
||||
|
||||
frees :: Type -> [TEVar]
|
||||
frees = \case
|
||||
TLit _ -> []
|
||||
TVar _ -> []
|
||||
TEVar tevar -> [tevar]
|
||||
TFun t1 t2 -> on (++) frees t1 t2
|
||||
TAll _ t -> frees t
|
||||
TData _ typs -> concatMap frees typs
|
||||
|
||||
-- | [ά/α]A
|
||||
substitute :: TVar -- α
|
||||
-> TEVar -- ά
|
||||
-> Type -- A
|
||||
-> Type -- [ά/α]A
|
||||
substitute tvar tevar typ = case typ of
|
||||
TLit _ -> typ
|
||||
TVar tvar' | tvar' == tvar -> TEVar tevar
|
||||
| otherwise -> typ
|
||||
TEVar _ -> typ
|
||||
TFun t1 t2 -> on TFun substitute' t1 t2
|
||||
TAll tvar' t -> TAll tvar' (substitute' t)
|
||||
TData name typs -> TData name $ map substitute' typs
|
||||
where
|
||||
substitute' = substitute tvar tevar
|
||||
|
||||
-- | Γ,x,Γ' → (Γ, Γ')
|
||||
splitOn :: EnvElem -> Env -> (Env, Env)
|
||||
splitOn x env = second (S.drop 1) $ S.breakl (==x) env
|
||||
|
||||
-- | Drop frontmost elements until and including element @x@.
|
||||
dropTrailing :: EnvElem -> Tc ()
|
||||
dropTrailing x = modifyEnv $ S.takeWhileL (/= x)
|
||||
|
||||
applyEnvExp :: T.Exp' Type -> Tc (T.Exp' Type)
|
||||
applyEnvExp exp = case exp of
|
||||
T.ELet (T.Bind id vars rhs) exp -> do
|
||||
id <- applyEnvId id
|
||||
vars' <- mapM applyEnvId vars
|
||||
rhs' <- applyEnvExpT rhs
|
||||
exp' <- applyEnvExpT exp
|
||||
pure $ T.ELet (T.Bind id vars' rhs') exp'
|
||||
T.EApp e1 e2 -> liftA2 T.EApp (applyEnvExpT e1) (applyEnvExpT e2)
|
||||
T.EAdd e1 e2 -> liftA2 T.EAdd (applyEnvExpT e1) (applyEnvExpT e2)
|
||||
T.EAbs name e -> T.EAbs name <$> applyEnvExpT e
|
||||
T.ECase e branches -> liftA2 T.ECase (applyEnvExpT e)
|
||||
(mapM applyEnvBranch branches)
|
||||
_ -> pure exp
|
||||
where
|
||||
applyEnvExpT (e, t) = liftA2 (,) (applyEnvExp e) (applyEnv t)
|
||||
applyEnvId = secondM applyEnv
|
||||
applyEnvBranch (T.Branch (p, t) e) = do
|
||||
pt <- liftA2 (,) (applyEnvPattern p) (applyEnv t)
|
||||
e' <- applyEnvExpT e
|
||||
pure $ T.Branch pt e'
|
||||
applyEnvPattern = \case
|
||||
T.PVar id -> T.PVar <$> applyEnvId id
|
||||
T.PLit (lit, t) -> T.PLit . (lit, ) <$> applyEnv t
|
||||
T.PInj name ps -> T.PInj name <$> mapM applyEnvPattern ps
|
||||
p -> pure p
|
||||
|
||||
applyEnv :: Type -> Tc Type
|
||||
applyEnv t = gets $ (`applyEnv'` t) . env
|
||||
|
||||
-- | [Γ]A. Applies context to type until fully applied.
|
||||
applyEnv' :: Env -> Type -> Type
|
||||
applyEnv' cxt typ | typ == typ' = typ'
|
||||
| otherwise = applyEnv' cxt typ'
|
||||
where
|
||||
typ' = case typ of
|
||||
TLit _ -> typ
|
||||
TData name typs -> TData name $ map (applyEnv' cxt) typs
|
||||
-- [Γ]α = α
|
||||
TVar _ -> typ
|
||||
-- [Γ[ά=τ]]ά = [Γ[ά=τ]]τ
|
||||
-- [Γ[ά]]ά = [Γ[ά]]ά
|
||||
TEVar tevar -> fromMaybe typ $ findSolved tevar cxt
|
||||
-- [Γ](A → B) = [Γ]A → [Γ]B
|
||||
TFun t1 t2 -> on TFun (applyEnv' cxt) t1 t2
|
||||
-- [Γ](∀α. A) = (∀α. [Γ]A)
|
||||
TAll tvar t -> TAll tvar $ applyEnv' cxt t
|
||||
|
||||
findSolved :: TEVar -> Env -> Maybe Type
|
||||
findSolved _ Empty = Nothing
|
||||
findSolved tevar (xs :|> x) = case x of
|
||||
EnvTEVarSolved tevar' t | tevar == tevar' -> Just t
|
||||
_ -> findSolved tevar xs
|
||||
|
||||
-- | Γ ⊢ A
|
||||
-- Under context Γ, type A is well-formed
|
||||
wellFormed :: Env -> Type -> Err ()
|
||||
wellFormed env = \case
|
||||
TLit _ -> pure ()
|
||||
|
||||
-- -------- UvarWF
|
||||
-- Γ[α] ⊢ α
|
||||
TVar tvar -> unless (EnvTVar tvar `elem` env) $
|
||||
throwError ("Unbound type variable: " ++ show tvar)
|
||||
-- Γ ⊢ A Γ ⊢ B
|
||||
-- ------------- ArrowWF
|
||||
-- Γ ⊢ A → B
|
||||
TFun t1 t2 -> do { wellFormed env t1; wellFormed env t2 }
|
||||
|
||||
-- Γ,α ⊢ A
|
||||
-- -------- ForallWF
|
||||
-- Γ ⊢ ∀α.A
|
||||
TAll tvar t -> wellFormed (env :|> EnvTVar tvar) t
|
||||
|
||||
TEVar tevar
|
||||
-- ---------- EvarWF
|
||||
-- Γ[ά] ⊢ ά
|
||||
| EnvTEVar tevar `elem` env -> pure ()
|
||||
|
||||
-- ---------- SolvedEvarWF
|
||||
-- Γ[ά=τ] ⊢ ά
|
||||
| Just _ <- findSolved tevar env -> pure ()
|
||||
| otherwise -> throwError ("Can't find type: " ++ show tevar)
|
||||
|
||||
TData _ typs -> mapM_ (wellFormed env) typs
|
||||
|
||||
isMono :: Type -> Bool
|
||||
isMono = \case
|
||||
TAll{} -> False
|
||||
TFun t1 t2 -> on (&&) isMono t1 t2
|
||||
TData _ typs -> all isMono typs
|
||||
TVar _ -> True
|
||||
TEVar _ -> True
|
||||
TLit _ -> True
|
||||
|
||||
inferLit :: Lit -> Type
|
||||
inferLit = \case
|
||||
LInt _ -> TLit "Int"
|
||||
LChar _ -> TLit "Char"
|
||||
|
||||
fresh :: Tc TEVar
|
||||
fresh = do
|
||||
tevar <- gets (MkTEVar . LIdent . ("a#" ++) . show . next_tevar)
|
||||
modify $ \cxt -> cxt { next_tevar = succ cxt.next_tevar }
|
||||
pure tevar
|
||||
|
||||
getVars :: Type -> [Type]
|
||||
getVars = fst . partitionType
|
||||
|
||||
getReturn :: Type -> Type
|
||||
getReturn = snd . partitionType
|
||||
|
||||
-- | Partion type into variable types and return type.
|
||||
--
|
||||
-- ∀a.∀b. a → (∀c. c → c) → b
|
||||
-- ([a, ∀c. c → c], b)
|
||||
--
|
||||
-- Unsure if foralls should be added to the return type or not.
|
||||
partitionType :: Type -> ([Type], Type)
|
||||
partitionType = go [] . skipForalls'
|
||||
where
|
||||
|
||||
go acc t = case t of
|
||||
TFun t1 t2 -> go (snoc t1 acc) t2
|
||||
_ -> (acc, t)
|
||||
|
||||
skipForalls' :: Type -> Type
|
||||
skipForalls' = snd . skipForalls
|
||||
|
||||
skipForalls :: Type -> ([Type -> Type], Type)
|
||||
skipForalls = go []
|
||||
where
|
||||
go acc typ = case typ of
|
||||
TAll tvar t -> go (snoc (TAll tvar) acc) t
|
||||
_ -> (acc, typ)
|
||||
|
||||
partitionTypeWithForall :: Type -> ([Type], Type)
|
||||
partitionTypeWithForall typ = (t_vars', t_return')
|
||||
where
|
||||
t_vars' = map (\t -> foldr applyForall t foralls) t_vars
|
||||
t_return' = foldr applyForall t_return foralls
|
||||
|
||||
applyForall fa t | usesTVar tvar t = fa t
|
||||
| otherwise = t
|
||||
where TAll tvar _ = fa t
|
||||
|
||||
(t_vars, t_return) = go [] typ'
|
||||
(foralls, typ') = skipForalls typ
|
||||
|
||||
|
||||
go acc t = case t of
|
||||
TFun t1 t2 -> go (snoc t1 acc) t2
|
||||
_ -> (acc, t)
|
||||
|
||||
usesTVar :: TVar -> Type -> Bool
|
||||
usesTVar tvar = \case
|
||||
TLit _ -> False
|
||||
TVar tvar' | tvar' == tvar -> True
|
||||
| otherwise -> False
|
||||
TFun t1 t2 -> on (||) usesTVar' t1 t2
|
||||
TAll tvar' t | tvar' == tvar -> error "Redeclaration of TVar"
|
||||
| otherwise -> usesTVar' t
|
||||
TData _ typs -> any usesTVar' typs
|
||||
_ -> error "Impossible"
|
||||
where
|
||||
usesTVar' = usesTVar tvar
|
||||
|
||||
skipLambdas :: Int -> T.Exp' Type -> T.Exp' Type
|
||||
skipLambdas i exp
|
||||
| i == 0 = exp
|
||||
| T.EAbs _ (e, _) <- exp = skipLambdas (i-1) e
|
||||
| otherwise = error "Number of expected lambdas doesn't match expression"
|
||||
|
||||
isComplete :: Env -> Bool
|
||||
isComplete = isNothing . S.findIndexL unSolvedTEVar
|
||||
where
|
||||
unSolvedTEVar = \case
|
||||
EnvTEVar _ -> True
|
||||
_ -> False
|
||||
|
||||
toTVar :: Type -> Err TVar
|
||||
toTVar = \case
|
||||
TVar tvar -> pure tvar
|
||||
_ -> throwError "Not a type variable"
|
||||
|
||||
insertEnv :: EnvElem -> Tc ()
|
||||
insertEnv x = modifyEnv (:|> x)
|
||||
|
||||
lookupBind :: LIdent -> Tc (Maybe Exp)
|
||||
lookupBind x = gets (Map.lookup x . binds)
|
||||
|
||||
lookupSig :: LIdent -> Tc (Maybe Type)
|
||||
lookupSig x = gets (Map.lookup x . sig)
|
||||
|
||||
lookupEnv :: LIdent -> Tc (Maybe Type)
|
||||
lookupEnv x = gets (findId . env)
|
||||
where
|
||||
findId Empty = Nothing
|
||||
findId (ys :|> y) = case y of
|
||||
EnvVar x' t | x==x' -> Just t
|
||||
_ -> findId ys
|
||||
|
||||
lookupInj :: UIdent -> Tc (Maybe Type)
|
||||
lookupInj x = gets (Map.lookup x . data_injs)
|
||||
|
||||
putEnv :: Env -> Tc ()
|
||||
putEnv = modifyEnv . const
|
||||
|
||||
modifyEnv :: (Env -> Env) -> Tc ()
|
||||
modifyEnv f =
|
||||
modify $ \cxt -> {- trace (ppEnv (f cxt.env)) -} cxt { env = f cxt.env }
|
||||
|
||||
pattern DBind' name vars exp = DBind (Bind name vars exp)
|
||||
pattern DSig' name typ = DSig (Sig name typ)
|
||||
|
||||
dummy = TLit "Int"
|
||||
|
||||
---------------------------------------------------------------------------
|
||||
-- * Debug
|
||||
---------------------------------------------------------------------------
|
||||
|
||||
traceEnv s = do
|
||||
env <- gets env
|
||||
trace (s ++ " " ++ show env) pure ()
|
||||
|
||||
traceD s x = trace (s ++ " " ++ show x) pure ()
|
||||
|
||||
traceT s x = trace (s ++ " " ++ ppT x) pure ()
|
||||
|
||||
traceTs s xs = trace (s ++ " [ " ++ intercalate ", " (map ppT xs) ++ " ]") pure ()
|
||||
|
||||
ppT = \case
|
||||
TLit (UIdent s) -> s
|
||||
TVar (MkTVar (LIdent s)) -> "α_" ++ s
|
||||
TFun t1 t2 -> ppT t1 ++ "→" ++ ppT t2
|
||||
TAll (MkTVar (LIdent s)) t -> "forall " ++ s ++ ". " ++ ppT t
|
||||
TEVar (MkTEVar (LIdent s)) -> "ά_" ++ s
|
||||
TData (UIdent name) typs -> name ++ " (" ++ unwords (map ppT typs)
|
||||
++ " )"
|
||||
ppEnvElem = \case
|
||||
EnvVar (LIdent s) t -> s ++ ":" ++ ppT t
|
||||
EnvTVar (MkTVar (LIdent s)) -> "α_" ++ s
|
||||
EnvTEVar (MkTEVar (LIdent s)) -> "ά_" ++ s
|
||||
EnvTEVarSolved (MkTEVar (LIdent s)) t -> "ά_" ++ s ++ "=" ++ ppT t
|
||||
EnvMark (MkTEVar (LIdent s)) -> "▶" ++ "ά_" ++ s
|
||||
|
||||
ppEnv = \case
|
||||
Empty -> "·"
|
||||
(xs :|> x) -> ppEnv xs ++ " (" ++ ppEnvElem x ++ ")"
|
||||
|
|
@ -1,36 +1,31 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
-- | A module for type checking and inference using algorithm W, Hindley-Milner
|
||||
module TypeChecker.TypeChecker where
|
||||
|
||||
import Auxiliary
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Identity (runIdentity)
|
||||
import Control.Monad.Reader
|
||||
import Control.Monad.State
|
||||
import Data.Bifunctor (second)
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Foldable (traverse_)
|
||||
import Data.Function (on)
|
||||
import Data.List (foldl')
|
||||
import Data.List.Extra (unsnoc)
|
||||
import Data.Map (Map)
|
||||
import Data.Map qualified as M
|
||||
import Data.Maybe (fromJust)
|
||||
import Data.Set (Set)
|
||||
import Data.Set qualified as S
|
||||
import Debug.Trace (trace)
|
||||
import Grammar.Abs
|
||||
import Grammar.Print (printTree)
|
||||
import TypeChecker.TypeCheckerIr (
|
||||
Ctx (..),
|
||||
Env (..),
|
||||
Error,
|
||||
Infer,
|
||||
Subst,
|
||||
)
|
||||
import TypeChecker.TypeCheckerIr qualified as T
|
||||
import Auxiliary
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Identity (runIdentity)
|
||||
import Control.Monad.Reader
|
||||
import Control.Monad.State
|
||||
import Data.Bifunctor (second)
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Foldable (traverse_)
|
||||
import Data.Function (on)
|
||||
import Data.List (foldl')
|
||||
import Data.List.Extra (unsnoc)
|
||||
import Data.Map (Map)
|
||||
import qualified Data.Map as M
|
||||
import Data.Maybe (fromJust)
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
import Debug.Trace (trace)
|
||||
import Grammar.Abs
|
||||
import Grammar.Print (printTree)
|
||||
import qualified TypeChecker.TypeCheckerIr as T
|
||||
import TypeChecker.TypeCheckerIr (Ctx (..), Env (..), Error, Infer,
|
||||
Subst)
|
||||
|
||||
initCtx = Ctx mempty
|
||||
initEnv = Env 0 'a' mempty mempty mempty
|
||||
|
|
@ -78,7 +73,7 @@ checkData d = do
|
|||
|
||||
retType :: Type -> Type
|
||||
retType (TFun _ t2) = retType t2
|
||||
retType a = a
|
||||
retType a = a
|
||||
|
||||
checkPrg :: Program -> Infer T.Program
|
||||
checkPrg (Program bs) = do
|
||||
|
|
@ -105,7 +100,7 @@ preRun (x : xs) = case x of
|
|||
s <- gets sigs
|
||||
case M.lookup (coerce n) s of
|
||||
Nothing -> insertSig (coerce n) Nothing >> preRun xs
|
||||
Just _ -> preRun xs
|
||||
Just _ -> preRun xs
|
||||
DData d@(Data t _) -> collect (collectTypeVars t) >> checkData d >> preRun xs
|
||||
|
||||
checkDef :: [Def] -> Infer [T.Def]
|
||||
|
|
@ -152,9 +147,9 @@ typeEq _ _ = False
|
|||
|
||||
skolem :: T.Type -> T.Type
|
||||
skolem (T.TVar (T.MkTVar a)) = T.TLit a
|
||||
skolem (T.TAll x t) = T.TAll x (skolem t)
|
||||
skolem (T.TFun t1 t2) = (T.TFun `on` skolem) t1 t2
|
||||
skolem t = t
|
||||
skolem (T.TAll x t) = T.TAll x (skolem t)
|
||||
skolem (T.TFun t1 t2) = (T.TFun `on` skolem) t1 t2
|
||||
skolem t = t
|
||||
|
||||
isMoreSpecificOrEq :: T.Type -> T.Type -> Bool
|
||||
isMoreSpecificOrEq t1 (T.TAll _ t2) = isMoreSpecificOrEq t1 t2
|
||||
|
|
@ -169,8 +164,8 @@ isMoreSpecificOrEq a b = a == b
|
|||
|
||||
isPoly :: Type -> Bool
|
||||
isPoly (TAll _ _) = True
|
||||
isPoly (TVar _) = True
|
||||
isPoly _ = False
|
||||
isPoly (TVar _) = True
|
||||
isPoly _ = False
|
||||
|
||||
inferExp :: Exp -> Infer T.ExpT
|
||||
inferExp e = do
|
||||
|
|
@ -183,7 +178,7 @@ class CollectTVars a where
|
|||
|
||||
instance CollectTVars Exp where
|
||||
collectTypeVars (EAnn e t) = collectTypeVars t `S.union` collectTypeVars e
|
||||
collectTypeVars _ = S.empty
|
||||
collectTypeVars _ = S.empty
|
||||
|
||||
instance CollectTVars Type where
|
||||
collectTypeVars (TVar (MkTVar i)) = S.singleton (coerce i)
|
||||
|
|
@ -200,15 +195,15 @@ class NewType a b where
|
|||
|
||||
instance NewType Type T.Type where
|
||||
toNew = \case
|
||||
TLit i -> T.TLit $ coerce i
|
||||
TVar v -> T.TVar $ toNew v
|
||||
TLit i -> T.TLit $ coerce i
|
||||
TVar v -> T.TVar $ toNew v
|
||||
TFun t1 t2 -> (T.TFun `on` toNew) t1 t2
|
||||
TAll b t -> T.TAll (toNew b) (toNew t)
|
||||
TAll b t -> T.TAll (toNew b) (toNew t)
|
||||
TData i ts -> T.TData (coerce i) (map toNew ts)
|
||||
TEVar _ -> error "Should not exist after typechecker"
|
||||
TEVar _ -> error "Should not exist after typechecker"
|
||||
|
||||
instance NewType Lit T.Lit where
|
||||
toNew (LInt i) = T.LInt i
|
||||
toNew (LInt i) = T.LInt i
|
||||
toNew (LChar i) = T.LChar i
|
||||
|
||||
instance NewType Data T.Data where
|
||||
|
|
@ -422,12 +417,12 @@ generalize :: Map T.Ident T.Type -> T.Type -> T.Type
|
|||
generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t)
|
||||
where
|
||||
go :: [T.Ident] -> T.Type -> T.Type
|
||||
go [] t = t
|
||||
go [] t = t
|
||||
go (x : xs) t = T.TAll (T.MkTVar x) (go xs t)
|
||||
removeForalls :: T.Type -> T.Type
|
||||
removeForalls (T.TAll _ t) = removeForalls t
|
||||
removeForalls (T.TAll _ t) = removeForalls t
|
||||
removeForalls (T.TFun t1 t2) = T.TFun (removeForalls t1) (removeForalls t2)
|
||||
removeForalls t = t
|
||||
removeForalls t = t
|
||||
|
||||
{- | Instantiate a polymorphic type. The free type variables are substituted
|
||||
with fresh ones.
|
||||
|
|
@ -477,10 +472,10 @@ instance SubstType T.Type where
|
|||
T.TLit a -> T.TLit a
|
||||
T.TVar (T.MkTVar a) -> case M.lookup a sub of
|
||||
Nothing -> T.TVar (T.MkTVar $ coerce a)
|
||||
Just t -> t
|
||||
Just t -> t
|
||||
T.TAll (T.MkTVar i) t -> case M.lookup i sub of
|
||||
Nothing -> T.TAll (T.MkTVar i) (apply sub t)
|
||||
Just _ -> apply sub t
|
||||
Just _ -> apply sub t
|
||||
T.TFun a b -> T.TFun (apply sub a) (apply sub b)
|
||||
T.TData name a -> T.TData name (map (apply sub) a)
|
||||
instance FreeVars (Map T.Ident T.Type) where
|
||||
|
|
@ -513,10 +508,10 @@ instance SubstType T.Pattern where
|
|||
apply :: Subst -> T.Pattern -> T.Pattern
|
||||
apply s = \case
|
||||
T.PVar (iden, t) -> T.PVar (iden, apply s t)
|
||||
T.PLit (lit, t) -> T.PLit (lit, apply s t)
|
||||
T.PInj i ps -> T.PInj i $ apply s ps
|
||||
T.PCatch -> T.PCatch
|
||||
T.PEnum i -> T.PEnum i
|
||||
T.PLit (lit, t) -> T.PLit (lit, apply s t)
|
||||
T.PInj i ps -> T.PInj i $ apply s ps
|
||||
T.PCatch -> T.PCatch
|
||||
T.PEnum i -> T.PEnum i
|
||||
|
||||
instance SubstType a => SubstType [a] where
|
||||
apply s = map (apply s)
|
||||
|
|
@ -555,7 +550,7 @@ fresh = do
|
|||
|
||||
next :: Char -> Char
|
||||
next 'z' = 'a'
|
||||
next a = succ a
|
||||
next a = succ a
|
||||
|
||||
-- | Run the monadic action with an additional binding
|
||||
withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> T.Type -> m a -> m a
|
||||
|
|
@ -608,10 +603,10 @@ inferBranch (Branch pat expr) = do
|
|||
withPattern :: T.Pattern -> Infer a -> Infer a
|
||||
withPattern p ma = case p of
|
||||
T.PVar (x, t) -> withBinding x t ma
|
||||
T.PInj _ ps -> foldl' (flip withPattern) ma ps
|
||||
T.PLit _ -> ma
|
||||
T.PCatch -> ma
|
||||
T.PEnum _ -> ma
|
||||
T.PInj _ ps -> foldl' (flip withPattern) ma ps
|
||||
T.PLit _ -> ma
|
||||
T.PCatch -> ma
|
||||
T.PEnum _ -> ma
|
||||
|
||||
inferPattern :: Pattern -> Infer (T.Pattern, T.Type)
|
||||
inferPattern = \case
|
||||
|
|
@ -659,14 +654,14 @@ inferPattern = \case
|
|||
|
||||
flattenType :: T.Type -> [T.Type]
|
||||
flattenType (T.TFun a b) = flattenType a <> flattenType b
|
||||
flattenType a = [a]
|
||||
flattenType a = [a]
|
||||
|
||||
typeLength :: T.Type -> Int
|
||||
typeLength (T.TFun a b) = typeLength a + typeLength b
|
||||
typeLength _ = 1
|
||||
typeLength _ = 1
|
||||
|
||||
litType :: Lit -> T.Type
|
||||
litType (LInt _) = int
|
||||
litType (LInt _) = int
|
||||
litType (LChar _) = char
|
||||
|
||||
int = T.TLit "Int"
|
||||
|
|
@ -681,8 +676,8 @@ partitionType = go []
|
|||
go acc 0 t = (acc, t)
|
||||
go acc i t = case t of
|
||||
TAll tvar t' -> second (TAll tvar) $ go acc i t'
|
||||
TFun t1 t2 -> go (acc <> [t1]) (i - 1) t2
|
||||
_ -> error "Number of parameters and type doesn't match"
|
||||
TFun t1 t2 -> go (acc <> [t1]) (i - 1) t2
|
||||
_ -> error "Number of parameters and type doesn't match"
|
||||
|
||||
exprErr :: Infer a -> Exp -> Infer a
|
||||
exprErr ma exp =
|
||||
|
|
@ -695,3 +690,4 @@ unzip4 =
|
|||
(as ++ [a], bs ++ [b], cs ++ [c], ds ++ [d])
|
||||
)
|
||||
([], [], [], [])
|
||||
|
||||
|
|
|
|||
|
|
@ -1,245 +1,135 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
|
||||
module TypeChecker.TypeCheckerIr (
|
||||
module TypeChecker.TypeCheckerIr,
|
||||
) where
|
||||
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Reader
|
||||
import Control.Monad.State
|
||||
import Data.Char (isDigit)
|
||||
import Data.Functor.Identity (Identity)
|
||||
import Data.Map (Map)
|
||||
import Data.Set (Set)
|
||||
import Data.String qualified
|
||||
import Grammar.Print
|
||||
import Prelude
|
||||
import Prelude qualified as C (Eq, Ord, Read, Show)
|
||||
module TypeChecker.TypeCheckerIr
|
||||
( module Grammar.Abs
|
||||
, module TypeChecker.TypeCheckerIr
|
||||
) where
|
||||
|
||||
newtype Ctx = Ctx {vars :: Map Ident Type}
|
||||
deriving (Show)
|
||||
import Data.String (IsString)
|
||||
import Grammar.Abs (Character (..), Lit (..), TVar (..))
|
||||
import Grammar.Print
|
||||
import Prelude
|
||||
import qualified Prelude as C (Eq, Ord, Read, Show)
|
||||
|
||||
data Env = Env
|
||||
{ count :: Int
|
||||
, nextChar :: Char
|
||||
, sigs :: Map Ident (Maybe Type)
|
||||
, constructors :: Map Ident Type
|
||||
, takenTypeVars :: Set Ident
|
||||
}
|
||||
deriving (Show)
|
||||
newtype Program' t = Program [Def' t]
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||
|
||||
type Error = String
|
||||
type Subst = Map Ident Type
|
||||
|
||||
type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))
|
||||
|
||||
newtype Program = Program [Def]
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||
|
||||
data Data = Data Ident [Constructor]
|
||||
deriving (Show, Eq, Ord, Read)
|
||||
|
||||
data Constructor = Constructor Ident Type
|
||||
deriving (Show, Eq, Ord, Read)
|
||||
|
||||
newtype TVar = MkTVar Ident
|
||||
deriving (Show, Eq, Ord, Read)
|
||||
data Def' t = DBind (Bind' t)
|
||||
| DData (Data' t)
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||
|
||||
data Type
|
||||
= TLit Ident
|
||||
| TVar TVar
|
||||
| TData Ident [Type]
|
||||
| TFun Type Type
|
||||
| TAll TVar Type
|
||||
| TData Ident [Type]
|
||||
deriving (Show, Eq, Ord, Read)
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||
|
||||
data Exp
|
||||
= EId Ident
|
||||
| ELit Lit
|
||||
| ELet Bind ExpT
|
||||
| EApp ExpT ExpT
|
||||
| EAdd ExpT ExpT
|
||||
| EAbs Ident ExpT
|
||||
| ECase ExpT [Branch]
|
||||
deriving (C.Eq, C.Ord, C.Read, C.Show)
|
||||
data Data' t = Data t [Inj' t]
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||
|
||||
type ExpT = (Exp, Type)
|
||||
|
||||
data Branch = Branch (Pattern, Type) ExpT
|
||||
deriving (C.Eq, C.Ord, C.Read, C.Show)
|
||||
|
||||
data Pattern = PVar Id | PLit (Lit, Type) | PInj Ident [Pattern] | PCatch | PEnum Ident
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||
|
||||
data Def = DBind Bind | DData Data
|
||||
deriving (C.Eq, C.Ord, C.Read, C.Show)
|
||||
|
||||
type Id = (Ident, Type)
|
||||
data Inj' t = Inj Ident t
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||
|
||||
newtype Ident = Ident String
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read, Data.String.IsString)
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read, IsString)
|
||||
|
||||
data Lit = LInt Integer | LChar Char
|
||||
deriving (Show, Eq, Ord, Read)
|
||||
data Pattern' t
|
||||
= PVar (Id' t) -- TODO should be Ident
|
||||
| PLit (Lit, t) -- TODO should be Lit
|
||||
| PCatch
|
||||
| PEnum Ident
|
||||
| PInj Ident [Pattern' t] -- TODO should be (Pattern' t, t)
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||
|
||||
data Bind = Bind Id [Id] ExpT
|
||||
data Exp' t
|
||||
= EVar Ident
|
||||
| EInj Ident
|
||||
| ELit Lit
|
||||
| ELet (Bind' t) (ExpT' t)
|
||||
| EApp (ExpT' t) (ExpT' t)
|
||||
| EAdd (ExpT' t) (ExpT' t)
|
||||
| EAbs Ident (ExpT' t)
|
||||
| ECase (ExpT' t) [Branch' t]
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||
|
||||
type Id' t = (Ident, t)
|
||||
type ExpT' t = (Exp' t, t)
|
||||
|
||||
data Bind' t = Bind (Id' t) [Id' t] (ExpT' t)
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||
|
||||
data Branch' t = Branch (Pattern' t, t) (ExpT' t)
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||
|
||||
instance Print Ident where
|
||||
prt _ (Ident str) = doc . showString $ str
|
||||
prt i (Ident s) = prt i s
|
||||
|
||||
instance Print [Def] where
|
||||
prt _ [] = concatD []
|
||||
prt _ (x : xs) = concatD [prt 0 x, doc (showString "\n\n"), prt 0 xs]
|
||||
|
||||
instance Print Data where
|
||||
prt i = \case
|
||||
Data type_ constructors -> prPrec i 0 (concatD [doc (showString "data"), prt 0 type_, doc (showString "where"), doc (showString "{"), prt 0 constructors, doc (showString "}")])
|
||||
|
||||
instance Print Constructor where
|
||||
prt i = \case
|
||||
Constructor uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_])
|
||||
|
||||
instance Print [Constructor] where
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ (x : xs) = concatD [prt 0 x, prt 0 xs]
|
||||
|
||||
instance Print Def where
|
||||
prt i (DBind bind) = prt i bind
|
||||
prt i (DData d) = prt i d
|
||||
|
||||
instance Print Program where
|
||||
instance Print t => Print (Program' t) where
|
||||
prt i (Program sc) = prPrec i 0 $ prt 0 sc
|
||||
|
||||
instance Print Bind where
|
||||
prt i (Bind (name, t) args rhs) =
|
||||
prPrec i 0 $
|
||||
concatD
|
||||
[ prt 0 name
|
||||
, doc $ showString ":"
|
||||
, prt 0 t
|
||||
, doc $ showString "\n"
|
||||
, prt 0 name
|
||||
, prtIdPs 0 args
|
||||
, doc $ showString "="
|
||||
, prt 0 rhs
|
||||
]
|
||||
instance Print t => Print (Bind' t) where
|
||||
prt i (Bind sig@(name, _) parms rhs) = prPrec i 0 $ concatD
|
||||
[ prtSig sig
|
||||
, prt 0 name
|
||||
, prtIdPs 0 parms
|
||||
, doc $ showString "="
|
||||
, prt 0 rhs
|
||||
]
|
||||
|
||||
instance Print [Bind] where
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), doc (showString "\n"), prt 0 xs]
|
||||
prtSig :: Print t => Id' t -> Doc
|
||||
prtSig (name, t) = concatD [ prt 0 name
|
||||
, doc $ showString ":"
|
||||
, prt 0 t
|
||||
, doc $ showString ";"
|
||||
]
|
||||
|
||||
prtIdPs :: Int -> [Id] -> Doc
|
||||
prtIdPs i = prPrec i 0 . concatD . map (prtIdP i)
|
||||
instance Print t => Print (ExpT' t) where
|
||||
prt i (e, t) = concatD [ doc $ showString "("
|
||||
, prt i e
|
||||
, doc $ showString ","
|
||||
, prt i t
|
||||
, doc $ showString ")"
|
||||
]
|
||||
|
||||
prtId :: Int -> Id -> Doc
|
||||
prtId i (name, t) =
|
||||
prPrec i 0 $
|
||||
concatD
|
||||
[ doc $ showString "("
|
||||
, prt 0 name
|
||||
, doc $ showString ":"
|
||||
, prt 0 t
|
||||
, doc $ showString ")"
|
||||
]
|
||||
instance Print t => Print [Bind' t] where
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
|
||||
|
||||
prtIdP :: Int -> Id -> Doc
|
||||
prtIdP i (name, t) =
|
||||
prPrec i 0 $
|
||||
concatD
|
||||
[ doc $ showString "("
|
||||
, prt 0 name
|
||||
, doc $ showString ":"
|
||||
, prt 0 t
|
||||
, doc $ showString ")"
|
||||
]
|
||||
prtIdPs :: Print t => Int -> [Id' t] -> Doc
|
||||
prtIdPs i = prPrec i 0 . concatD . map (prt i)
|
||||
|
||||
instance Print Exp where
|
||||
prt i = \case
|
||||
EId n -> prPrec i 3 $ concatD [prt 0 n]
|
||||
ELit lit -> prPrec i 3 $ concatD [prt 0 lit]
|
||||
ELet bs e ->
|
||||
prPrec i 3 $
|
||||
concatD
|
||||
[ doc $ showString "let"
|
||||
, prt 0 bs
|
||||
, doc $ showString "in"
|
||||
, prt 0 e
|
||||
]
|
||||
EApp e1 e2 ->
|
||||
prPrec i 2 $
|
||||
concatD
|
||||
[ prt 2 e1
|
||||
, prt 3 e2
|
||||
]
|
||||
EAdd e1 e2 ->
|
||||
prPrec i 1 $
|
||||
concatD
|
||||
[ doc $ showString "@"
|
||||
, prt 1 e1
|
||||
, doc $ showString "+"
|
||||
, prt 2 e2
|
||||
]
|
||||
EAbs n e ->
|
||||
prPrec i 0 $
|
||||
concatD
|
||||
[ doc $ showString "λ"
|
||||
, prt 0 n
|
||||
, doc $ showString "."
|
||||
, prt 0 e
|
||||
]
|
||||
ECase exp injs ->
|
||||
prPrec
|
||||
i
|
||||
0
|
||||
( concatD
|
||||
[ doc (showString "case")
|
||||
, prt 0 exp
|
||||
, doc (showString "of")
|
||||
, doc (showString "{")
|
||||
, prt 0 injs
|
||||
, doc (showString "}")
|
||||
, doc (showString ":")
|
||||
]
|
||||
)
|
||||
instance Print t => Print (Id' t) where
|
||||
prt i (name, t) = concatD [ doc $ showString "("
|
||||
, prt i name
|
||||
, doc $ showString ","
|
||||
, prt i t
|
||||
, doc $ showString ")"
|
||||
]
|
||||
|
||||
instance Print ExpT where
|
||||
prt i (e, t) = concatD [doc $ showString "(", prt i e, doc (showString ":"), prt i t, doc $ showString ")"]
|
||||
|
||||
instance Print Branch where
|
||||
prt i = \case
|
||||
Branch (init, t) exp -> prPrec i 0 (concatD [prt 0 init, doc (showString ":"), prt 0 t, doc (showString "=>"), prt 0 exp])
|
||||
|
||||
instance Print Pattern where
|
||||
prt i = \case
|
||||
PVar lident -> prPrec i 0 (concatD [prtId 0 lident])
|
||||
PLit (lit, typ) -> prPrec i 0 (concatD [doc $ showString "(", prt 0 lit, doc $ showString ",", prt 0 typ, doc $ showString ")"])
|
||||
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 0 patterns])
|
||||
PCatch -> prPrec i 0 (concatD [doc (showString "_")])
|
||||
PEnum p -> prt i p
|
||||
|
||||
instance Print [Branch] where
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
|
||||
|
||||
instance Print TVar where
|
||||
prt i (MkTVar id) = prt i id
|
||||
|
||||
instance Print Type where
|
||||
prt i = \case
|
||||
TLit uident -> prPrec i 2 (concatD [prt 0 uident])
|
||||
TVar tvar@(MkTVar (Ident iden)) ->
|
||||
if all isDigit iden
|
||||
then prPrec i 2 (concatD [prt 0 $ TVar (MkTVar (Ident ("a" <> iden)))])
|
||||
else prPrec i 2 (concatD [prt 0 tvar])
|
||||
TAll tvar type_ -> prPrec i 1 (concatD [doc (showString "forall"), prt 0 tvar, doc (showString "."), prt 0 type_])
|
||||
TData ident types -> prPrec i 1 (concatD [prt 0 ident, prt 0 types])
|
||||
TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2])
|
||||
|
||||
instance Print Lit where
|
||||
prt i = \case
|
||||
LInt n -> prPrec i 0 (concatD [prt 0 n])
|
||||
LChar c -> prPrec i 0 (concatD [prt 0 c])
|
||||
instance Print t => Print (Exp' t) where
|
||||
prt i = \case
|
||||
EVar name -> prPrec i 3 $ prt 0 name
|
||||
EInj name -> prPrec i 3 $ prt 0 name
|
||||
ELit lit -> prPrec i 3 $ prt 0 lit
|
||||
ELet b e -> prPrec i 3 $ concatD
|
||||
[ doc $ showString "let"
|
||||
, prt 0 b
|
||||
, doc $ showString "in"
|
||||
, prt 0 e
|
||||
]
|
||||
EApp e1 e2 -> prPrec i 2 $ concatD
|
||||
[ prt 2 e1
|
||||
, prt 3 e2
|
||||
]
|
||||
EAdd e1 e2 -> prPrec i 1 $ concatD
|
||||
[ prt 1 e1
|
||||
, doc $ showString "+"
|
||||
, prt 2 e2
|
||||
]
|
||||
EAbs v e -> prPrec i 0 $ concatD
|
||||
[ doc $ showString "\\"
|
||||
|
|
|
|||
232
tests/TestTypeCheckerBidir.hs
Normal file
232
tests/TestTypeCheckerBidir.hs
Normal file
|
|
@ -0,0 +1,232 @@
|
|||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# HLINT ignore "Use camelCase" #-}
|
||||
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
|
||||
module TestTypeCheckerBidir (testTypeCheckerBidir) where
|
||||
|
||||
import Test.Hspec
|
||||
|
||||
import Control.Monad ((<=<))
|
||||
import Grammar.ErrM (Err, pattern Bad, pattern Ok)
|
||||
import Grammar.Par (myLexer, pProgram)
|
||||
import Renamer.Renamer (rename)
|
||||
import TypeChecker.RemoveTEVar (RemoveTEVar (rmTEVar))
|
||||
import TypeChecker.TypeCheckerBidir (typecheck)
|
||||
import qualified TypeChecker.TypeCheckerIr as T
|
||||
|
||||
|
||||
testTypeCheckerBidir = describe "Bidirectional type checker test" $ do
|
||||
tc_id
|
||||
tc_double
|
||||
tc_add_lam
|
||||
tc_const
|
||||
tc_simple_rank2
|
||||
tc_rank2
|
||||
tc_identity
|
||||
tc_pair
|
||||
tc_tree
|
||||
tc_mono_case
|
||||
tc_pol_case
|
||||
|
||||
tc_id = specify "Basic identity function polymorphism" $ run
|
||||
[ "id : forall a. a -> a;"
|
||||
, "id x = x;"
|
||||
, "main = id 4;"
|
||||
] `shouldSatisfy` ok
|
||||
|
||||
tc_double = specify "Addition inference" $ run
|
||||
["double x = x + x;"] `shouldSatisfy` ok
|
||||
|
||||
|
||||
tc_add_lam = specify "Addition lambda inference" $ run
|
||||
["four = (\\x. x + x) 2;"] `shouldSatisfy` ok
|
||||
|
||||
|
||||
tc_const = specify "Basic polymorphism with multiple type variables" $ run
|
||||
[ "const : forall a. forall b. a -> b -> a;"
|
||||
, "const x y = x;"
|
||||
, "main = const 'a' 65;"
|
||||
] `shouldSatisfy` ok
|
||||
|
||||
tc_simple_rank2 = specify "Simple rank two polymorphism" $ run
|
||||
[ "id : forall a. a -> a;"
|
||||
, "id x = x;"
|
||||
|
||||
, "f : forall a. a -> (forall b. b -> b) -> a;"
|
||||
, "f x g = g x;"
|
||||
|
||||
, "main = f 4 id;"
|
||||
] `shouldSatisfy` ok
|
||||
|
||||
tc_rank2 = specify "Rank two polymorphism is ok" $ run
|
||||
[ "const : forall a. forall b. a -> b -> a;"
|
||||
, "const x y = x;"
|
||||
|
||||
, "rank2 : forall a. forall b. a -> (forall c. c -> Int) -> b -> Int;"
|
||||
, "rank2 x f y = f x + f y;"
|
||||
|
||||
, "main = rank2 3 (\\x. const 5 x : forall a. a -> Int) 'h';"
|
||||
] `shouldSatisfy` ok
|
||||
|
||||
tc_identity = describe "(∀b. b → b) should only accept the identity function" $ do
|
||||
specify "identityᵢₙₜ is rejected" $ run (fs ++ id_int) `shouldNotSatisfy` ok
|
||||
specify "identity is accepted" $ run (fs ++ id) `shouldSatisfy` ok
|
||||
where
|
||||
fs =
|
||||
[ "f : forall a. a -> (forall b. b -> b) -> a;"
|
||||
, "f x g = g x;"
|
||||
|
||||
, "id : forall a. a -> a;"
|
||||
, "id x = x;"
|
||||
|
||||
, "id_int : Int -> Int;"
|
||||
, "id_int x = x;"
|
||||
]
|
||||
id =
|
||||
[ "main : Int;"
|
||||
, "main = f 4 id;"
|
||||
]
|
||||
id_int =
|
||||
[ "main : Int;"
|
||||
, "main = f 4 id_int;"
|
||||
]
|
||||
|
||||
tc_pair = describe "Pair. Type variables in Pair a b typechecked" $ do
|
||||
specify "Wrong arguments are rejected" $ run (fs ++ wrong) `shouldNotSatisfy` ok
|
||||
specify "Correct arguments are accepted" $ run (fs ++ correct) `shouldSatisfy` ok
|
||||
where
|
||||
fs =
|
||||
[ "data forall a. forall b. Pair (a b) where {"
|
||||
, " Pair : a -> b -> Pair (a b)"
|
||||
, "};"
|
||||
|
||||
, "main : Pair (Int Char);"
|
||||
]
|
||||
wrong = ["main = Pair 'a' 65;"]
|
||||
correct = ["main = Pair 65 'a';"]
|
||||
|
||||
tc_tree = describe "Tree. Recursive data type" $ do
|
||||
specify "Wrong tree is rejected" $ run (fs ++ wrong) `shouldNotSatisfy` ok
|
||||
specify "Correct tree is accepted" $ run (fs ++ correct) `shouldSatisfy` ok
|
||||
where
|
||||
fs =
|
||||
[ "data forall a. Tree (a) where {"
|
||||
, " Node : a -> Tree (a) -> Tree (a) -> Tree (a)"
|
||||
, " Leaf : a -> Tree (a)"
|
||||
, "};"
|
||||
]
|
||||
wrong = ["tree = Node 1 (Node 2 (Node 4) (Leaf 5)) (Leaf 3);"]
|
||||
correct = ["tree = Node 1 (Node 2 (Leaf 4) (Leaf 5)) (Leaf 3);"]
|
||||
|
||||
tc_mono_case = describe "Monomorphic pattern matching" $ do
|
||||
specify "First wrong case expression rejected"
|
||||
$ run wrong1 `shouldNotSatisfy` ok
|
||||
specify "Second wrong case expression rejected"
|
||||
$ run wrong2 `shouldNotSatisfy` ok
|
||||
specify "Third wrong case expression rejected"
|
||||
$ run wrong3 `shouldNotSatisfy` ok
|
||||
specify "First correct case expression accepted"
|
||||
$ run correct1 `shouldSatisfy` ok
|
||||
specify "Second correct case expression accepted"
|
||||
$ run correct2 `shouldSatisfy` ok
|
||||
|
||||
where
|
||||
wrong1 =
|
||||
[ "simple : Int -> Int;"
|
||||
, "simple c = case c of {"
|
||||
, " 'F' => 0;"
|
||||
, " 'T' => 1;"
|
||||
, "};"
|
||||
]
|
||||
wrong2 =
|
||||
[ "simple : Char -> Int;"
|
||||
, "simple c = case c of {"
|
||||
, " 'F' => 0;"
|
||||
, " 1 => 1;"
|
||||
, "};"
|
||||
]
|
||||
wrong3 =
|
||||
[ "simple : Char -> Int;"
|
||||
, "simple c = case c of {"
|
||||
, " 'F' => 0;"
|
||||
, " 'T' => '1';"
|
||||
, "};"
|
||||
]
|
||||
correct1 =
|
||||
[ "simple : Char -> Int;"
|
||||
, "simple c = case c of {"
|
||||
, " 'F' => 0;"
|
||||
, " 'T' => 1;"
|
||||
, "};"
|
||||
]
|
||||
correct2 =
|
||||
[ "simple : Char -> Int;"
|
||||
, "simple c = case c of {"
|
||||
, " 'F' => 0;"
|
||||
, " _ => 1;"
|
||||
, "};"
|
||||
]
|
||||
|
||||
tc_pol_case = describe "Polymophic pattern matching" $ do
|
||||
specify "First wrong case expression rejected"
|
||||
$ run (fs ++ wrong1) `shouldNotSatisfy` ok
|
||||
specify "Second wrong case expression rejected"
|
||||
$ run (fs ++ wrong2) `shouldNotSatisfy` ok
|
||||
specify "Third wrong case expression rejected"
|
||||
$ run (fs ++ wrong3) `shouldNotSatisfy` ok
|
||||
specify "First correct case expression accepted"
|
||||
$ run (fs ++ correct1) `shouldSatisfy` ok
|
||||
specify "Second correct case expression accepted"
|
||||
$ run (fs ++ correct2) `shouldSatisfy` ok
|
||||
where
|
||||
fs =
|
||||
[ "data forall a. List (a) where {"
|
||||
, " Nil : List (a)"
|
||||
, " Cons : a -> List (a) -> List (a)"
|
||||
, "};"
|
||||
]
|
||||
wrong1 =
|
||||
[ "length : forall c. List (c) -> Int;"
|
||||
, "length = \\list. case list of {"
|
||||
, " Nil => 0;"
|
||||
, " Cons 6 xs => 1 + length xs;"
|
||||
, "};"
|
||||
]
|
||||
wrong2 =
|
||||
[ "length : forall c. List (c) -> Int;"
|
||||
, "length = \\list. case list of {"
|
||||
, " Cons => 0;"
|
||||
, " Cons x xs => 1 + length xs;"
|
||||
, "};"
|
||||
]
|
||||
wrong3 =
|
||||
[ "length : forall c. List (c) -> Int;"
|
||||
, "length = \\list. case list of {"
|
||||
, " 0 => 0;"
|
||||
, " Cons x xs => 1 + length xs;"
|
||||
, "};"
|
||||
]
|
||||
correct1 =
|
||||
[ "length : forall c. List (c) -> Int;"
|
||||
, "length = \\list. case list of {"
|
||||
, " Nil => 0;"
|
||||
, " Cons x xs => 1 + length xs;"
|
||||
, " Cons x (Cons y Nil) => 2;"
|
||||
, "};"
|
||||
]
|
||||
correct2 =
|
||||
[ "length : forall c. List (c) -> Int;"
|
||||
, "length = \\list. case list of {"
|
||||
, " Nil => 0;"
|
||||
, " non_empty => 1;"
|
||||
, "};"
|
||||
]
|
||||
|
||||
run :: [String] -> Err T.Program
|
||||
run = rmTEVar <=< typecheck <=< pProgram . myLexer . unlines
|
||||
|
||||
ok = \case
|
||||
Ok _ -> True
|
||||
Bad _ -> False
|
||||
113
tests/TestTypeCheckerHm.hs
Normal file
113
tests/TestTypeCheckerHm.hs
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
{-# LANGUAGE NoImplicitPrelude #-}
|
||||
{-# LANGUAGE QualifiedDo #-}
|
||||
|
||||
module TestTypeCheckerHm (testTypeCheckerHm) where
|
||||
|
||||
import Control.Monad ((<=<))
|
||||
import qualified DoStrings as D
|
||||
import Grammar.Par (myLexer, pProgram)
|
||||
import Prelude (Bool (..), Either (..), IO, fmap,
|
||||
not, ($), (.))
|
||||
import Test.Hspec
|
||||
|
||||
-- import Test.QuickCheck
|
||||
import TypeChecker.TypeCheckerHm (typecheck)
|
||||
|
||||
|
||||
|
||||
testTypeCheckerHm = describe "Hillner Milner type checker test" $ do
|
||||
ok1
|
||||
ok2
|
||||
bad1
|
||||
bad2
|
||||
-- bad3
|
||||
|
||||
|
||||
ok1 =
|
||||
specify "Basic polymorphism with multiple type variables" $
|
||||
run
|
||||
( D.do
|
||||
const
|
||||
"main = const 'a' 65 ;"
|
||||
)
|
||||
`shouldSatisfy` ok
|
||||
ok2 =
|
||||
specify "Head with a correct signature is accepted" $
|
||||
run
|
||||
( D.do
|
||||
list
|
||||
headSig
|
||||
head
|
||||
)
|
||||
`shouldSatisfy` ok
|
||||
|
||||
bad1 =
|
||||
specify "Infinite type unification should not succeed" $
|
||||
run
|
||||
( D.do
|
||||
"main = \\x. x x ;"
|
||||
)
|
||||
`shouldSatisfy` bad
|
||||
|
||||
bad2 =
|
||||
specify "Pattern matching using different types should not succeed" $
|
||||
run
|
||||
( D.do
|
||||
list
|
||||
"bad xs = case xs of {"
|
||||
" 1 => 0 ;"
|
||||
" Nil => 0 ;"
|
||||
"};"
|
||||
)
|
||||
`shouldSatisfy` bad
|
||||
|
||||
bad3 =
|
||||
specify "Using a concrete function on a skolem variable should not succeed" $
|
||||
run
|
||||
( D.do
|
||||
bool
|
||||
_not
|
||||
"f : a -> Bool () ;"
|
||||
" f x = not x ;"
|
||||
)
|
||||
`shouldSatisfy` bad
|
||||
|
||||
run = typecheck <=< pProgram . myLexer
|
||||
|
||||
ok (Right _) = True
|
||||
ok (Left _) = False
|
||||
|
||||
bad = not . ok
|
||||
|
||||
-- FUNCTIONS
|
||||
|
||||
const = D.do
|
||||
"const : a -> b -> a ;"
|
||||
"const x y = x ;"
|
||||
list = D.do
|
||||
"data List (a) where"
|
||||
" {"
|
||||
" Nil : List (a)"
|
||||
" Cons : a -> List (a) -> List (a)"
|
||||
" };"
|
||||
|
||||
headSig = D.do
|
||||
"head : List (a) -> a ;"
|
||||
head = D.do
|
||||
"head xs = "
|
||||
" case xs of {"
|
||||
" Cons x xs => x ;"
|
||||
" };"
|
||||
|
||||
bool = D.do
|
||||
"data Bool () where {"
|
||||
" True : Bool ()"
|
||||
" False : Bool ()"
|
||||
"};"
|
||||
|
||||
_not = D.do
|
||||
"not : Bool () -> Bool () ;"
|
||||
"not x = case x of {"
|
||||
" True => False ;"
|
||||
" False => True ;"
|
||||
"};"
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
module DoStrings where
|
||||
|
||||
import Prelude hiding ((>>), (>>=))
|
||||
import Prelude hiding ((>>), (>>=))
|
||||
|
||||
(>>) :: String -> String -> String
|
||||
(>>) str1 str2 = str1 ++ "\n" ++ str2
|
||||
|
|
|
|||
10
tests/Tests.hs
Normal file
10
tests/Tests.hs
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
|
||||
module Main where
|
||||
|
||||
import Test.Hspec
|
||||
import TestTypeCheckerBidir (testTypeCheckerBidir)
|
||||
import TestTypeCheckerHm (testTypeCheckerHm)
|
||||
|
||||
main = hspec $ do
|
||||
testTypeCheckerBidir
|
||||
testTypeCheckerHm
|
||||
Loading…
Add table
Add a link
Reference in a new issue