diff --git a/Grammar.cf b/Grammar.cf index 78dfa65..09d0f2e 100644 --- a/Grammar.cf +++ b/Grammar.cf @@ -3,94 +3,94 @@ -- * PROGRAM ------------------------------------------------------------------------------- -Program. Program ::= [Def] ; +Program. Program ::= [Def]; ------------------------------------------------------------------------------- -- * TOP-LEVEL ------------------------------------------------------------------------------- -DBind. Def ::= Bind ; -DSig. Def ::= Sig ; -DData. Def ::= Data ; +DBind. Def ::= Bind; +DSig. Def ::= Sig; +DData. Def ::= Data; -Sig. Sig ::= LIdent ":" Type ; - -Bind. Bind ::= LIdent [LIdent] "=" Exp ; +Sig. Sig ::= LIdent ":" Type; +Bind. Bind ::= LIdent [LIdent] "=" Exp; ------------------------------------------------------------------------------- --- * TYPES +-- * Types ------------------------------------------------------------------------------- - TLit. Type2 ::= UIdent ; - TVar. Type2 ::= TVar ; - TAll. Type1 ::= "forall" TVar "." Type ; - TData. Type1 ::= UIdent "(" [Type] ")" ; -internal TEVar. Type1 ::= TEVar ; - TFun. Type ::= Type1 "->" Type ; + TLit. Type1 ::= UIdent; -- τ + TVar. Type1 ::= TVar; -- α +internal TEVar. Type1 ::= TEVar; -- ά + TData. Type1 ::= UIdent "(" [Type] ")"; -- D () + TFun. Type ::= Type1 "->" Type; -- A → A + TAll. Type ::= "forall" TVar "." Type; -- ∀α. A - MkTVar. TVar ::= LIdent ; -internal MkTEVar. TEVar ::= LIdent ; + MkTVar. TVar ::= LIdent; +internal MkTEVar. TEVar ::= LIdent; ------------------------------------------------------------------------------- -- * DATA TYPES ------------------------------------------------------------------------------- -Constructor. Constructor ::= UIdent ":" Type ; +Data. Data ::= "data" Type "where" "{" [Inj] "}" ; -Data. Data ::= "data" Type "where" "{" [Constructor] "}" ; +Inj. 
Inj ::= UIdent ":" Type ; +separator nonempty Inj " " ; ------------------------------------------------------------------------------- --- * EXPRESSIONS +-- * Expressions ------------------------------------------------------------------------------- -EAnn. Exp4 ::= "(" Exp ":" Type ")" ; -EVar. Exp3 ::= LIdent ; -EInj. Exp3 ::= UIdent ; -ELit. Exp3 ::= Lit ; -EApp. Exp2 ::= Exp2 Exp3 ; -EAdd. Exp1 ::= Exp1 "+" Exp2 ; -ELet. Exp ::= "let" Bind "in" Exp ; -EAbs. Exp ::= "\\" LIdent "." Exp ; -ECase. Exp ::= "case" Exp "of" "{" [Branch] "}"; +EAnn. Exp4 ::= "(" Exp ":" Type ")"; +EVar. Exp3 ::= LIdent; +EInj. Exp3 ::= UIdent; +ELit. Exp3 ::= Lit; +EApp. Exp2 ::= Exp2 Exp3; +EAdd. Exp1 ::= Exp1 "+" Exp2; +ELet. Exp ::= "let" Bind "in" Exp; +EAbs. Exp ::= "\\" LIdent "." Exp; +ECase. Exp ::= "case" Exp "of" "{" [Branch] "}"; ------------------------------------------------------------------------------- -- * LITERALS ------------------------------------------------------------------------------- -LInt. Lit ::= Integer ; -LChar. Lit ::= Char ; +LInt. Lit ::= Integer; +LChar. Lit ::= Character; ------------------------------------------------------------------------------- --- * CASE +-- * PATTERN MATCHING ------------------------------------------------------------------------------- Branch. Branch ::= Pattern "=>" Exp ; -PVar. Pattern1 ::= LIdent ; -PLit. Pattern1 ::= Lit ; -PCatch. Pattern1 ::= "_" ; -PEnum. Pattern1 ::= UIdent ; -PInj. Pattern ::= UIdent [Pattern1] ; +PVar. Pattern1 ::= LIdent; +PLit. Pattern1 ::= Lit; +PCatch. Pattern1 ::= "_"; +PEnum. Pattern1 ::= UIdent; +PInj. 
Pattern ::= UIdent [Pattern1]; ------------------------------------------------------------------------------- -- * AUX ------------------------------------------------------------------------------- -terminator Def ";" ; -separator nonempty Constructor "" ; -separator Type " " ; -separator nonempty Pattern1 " " ; +terminator Def ";"; terminator Branch ";" ; -separator Ident " "; -separator LIdent " "; -separator TVar " " ; -coercions Exp 4 ; -coercions Type 2 ; -coercions Pattern 1 ; +separator LIdent ""; +separator Type " "; +separator TVar " "; +separator nonempty Pattern1 " "; +coercions Pattern 1; +coercions Exp 4; +coercions Type 1 ; + +token Character '\''(char)'\'' ; token UIdent (upper (letter | digit | '_')*) ; token LIdent (lower (letter | digit | '_')*) ; -comment "--" ; -comment "{-" "-}" ; +comment "--"; +comment "{-" "-}"; diff --git a/language.cabal b/language.cabal index 9783156..61724ee 100644 --- a/language.cabal +++ b/language.cabal @@ -31,13 +31,18 @@ executable language Grammar.Skel Grammar.ErrM Auxiliary + Renamer.Renamer TypeChecker.TypeChecker + TypeChecker.TypeCheckerHm + TypeChecker.TypeCheckerBidir TypeChecker.TypeCheckerIr + TypeChecker.RemoveTEVar + LambdaLifter Monomorphizer.Monomorphizer Monomorphizer.MonomorphizerIr - Renamer.Renamer Codegen.Codegen Codegen.LlvmIr + Compiler hs-source-dirs: src @@ -60,6 +65,9 @@ Test-suite language-testsuite main-is: Tests.hs other-modules: + TestTypeCheckerBidir + TestTypeCheckerHm + Grammar.Abs Grammar.Lex Grammar.Par @@ -67,9 +75,11 @@ Test-suite language-testsuite Grammar.Skel Grammar.ErrM Auxiliary - TypeChecker.TypeChecker - TypeChecker.TypeCheckerIr Renamer.Renamer + TypeChecker.TypeCheckerHm + TypeChecker.TypeCheckerBidir + TypeChecker.RemoveTEVar + TypeChecker.TypeCheckerIr Compiler hs-source-dirs: src, tests, tests/TypecheckingHM @@ -87,3 +97,4 @@ Test-suite language-testsuite , bytestring default-language: GHC2021 + diff --git a/sample-programs/basic-0 b/sample-programs/basic-0 new file 
mode 100644 index 0000000..4738fb6 --- /dev/null +++ b/sample-programs/basic-0 @@ -0,0 +1,11 @@ +data forall a. List (a) where { + Nil : List (a) + Cons : a -> List (a) -> List (a) +}; + +length : forall c. List (c) -> Int; +length = \list. case list of { + Nil => 0; + Cons x xs => 1 + length xs; + Cons x (Cons y Nil) => 2; +}; diff --git a/sample-programs/basic-2.crf b/sample-programs/basic-2.crf index 2db6128..5ce4da5 100644 --- a/sample-programs/basic-2.crf +++ b/sample-programs/basic-2.crf @@ -3,3 +3,4 @@ add x = \y. x+y; main : Int ; main = (\z. z+z) ((add 4) 6) ; + diff --git a/spec.txt b/spec.txt new file mode 100644 index 0000000..2273846 --- /dev/null +++ b/spec.txt @@ -0,0 +1,121 @@ +--------------------------------------------------------------------------- +-- * Parser +--------------------------------------------------------------------------- + +data Program = Program [Def] + +data Def = DSig Ident Type | DBind Bind + +data Bind = Bind Ident [Ident] Exp + +data Exp + = EId Ident + | ELit Lit + | EAnn Exp Type + | ELet Ident Exp Exp + | EApp Exp Exp + | EAdd Exp Exp + | EAbs Ident Exp + +data Lit = LInt Integer + | LChar Character + +data Type + = TLit Ident -- τ + | TVar TVar -- α + | TFun Type Type -- A → A + | TAll TVar Type -- ∀α. 
A + | TEVar TEVar -- ά (internal) + +data TVar = MkTVar Ident +data TEVar = MkTEVar Ident + +--------------------------------------------------------------------------- +-- * Type checker +--------------------------------------------------------------------------- + +-- • Def and DSig are removed in favor of just Bind +-- • Typed expressions +-- • TEVar is removed (NOT IMPLEMENTED) + +newtype Program = Program [Bind] + +data Bind = Bind Id [Id] ExpT + +data Exp + = EId Ident + | ELit Lit + | ELet Bind ExpT + | EApp ExpT ExpT + | EAdd ExpT ExpT + | EAbs Ident ExpT + +type Id = (Ident, Type) +type ExpT = (Exp, Type) + + +data Lit = LInt Integer + | LChar Character + +data Type + = TLit Ident -- τ + | TVar TVar -- α + | TFun Type Type -- A → A + | TAll TVar Type -- ∀α. 
A + +data TVar = MkTVar Ident + +--------------------------------------------------------------------------- +-- * Monomorpher +--------------------------------------------------------------------------- +-- • Polymorphic types are removed (NOT IMPLEMENTED) + +newtype Program = Program [Bind] + +data Bind = Bind Id [Id] ExpT + +data Exp + = EId Ident + | ELit Lit + | ELet Ident ExpT ExpT + | EApp ExpT ExpT + | EAdd ExpT ExpT + +type Id = (Ident, Type) +type ExpT = (Exp, Type) + +data Lit = LInt Integer + | LChar Character + +data Type = Type Ident diff --git a/src/Auxiliary.hs b/src/Auxiliary.hs index 735d804..d27ac24 100644 --- a/src/Auxiliary.hs +++ b/src/Auxiliary.hs @@ -3,6 +3,7 @@ module Auxiliary (module Auxiliary) where import Control.Monad.Error.Class (liftEither) import Control.Monad.Except (MonadError) import Data.Either.Combinators (maybeToRight) +import TypeChecker.TypeCheckerIr (Type (TFun)) snoc :: a -> [a] -> [a] snoc x xs = xs ++ [x] @@ -19,3 +20,4 @@ mapAccumM f = go (acc', x') <- f acc x (acc'', xs') <- go acc' xs pure (acc'', x':xs') + diff --git a/src/Codegen/Codegen.hs b/src/Codegen/Codegen.hs index 041671d..5e7e37d 100644 --- a/src/Codegen/Codegen.hs +++ b/src/Codegen/Codegen.hs @@ -17,6 +17,7 @@ import Data.Tuple.Extra (dupe, first, second) import Debug.Trace (trace) import qualified Grammar.Abs as GA import Grammar.ErrM (Err) +import Monomorphizer.MonomorphizerIr (Ident (..)) import Monomorphizer.MonomorphizerIr as MIR -- | The record used as the code generator state @@ -57,8 +58,13 @@ getVarCount :: CompilerState Integer getVarCount = gets variableCount -- | Increases the variable count and returns it from the CodeGenerator state +<<<<<<< HEAD getNewVar :: CompilerState GA.Ident getNewVar = GA.Ident . show <$> (increaseVarCount >> getVarCount) +======= +getNewVar :: CompilerState Ident +getNewVar = (Ident . show) <$> (increaseVarCount >> getVarCount) +>>>>>>> da28c6d (Add bidirectional type checker, lambda lifter.) 
-- | Increses the label count and returns a label from the CodeGenerator state getNewLabel :: CompilerState Integer @@ -76,10 +82,25 @@ getFunctions bs = Map.fromList $ go bs go (MIR.DBind (MIR.Bind id args _) : xs) = (id, FunctionInfo{numArgs = length args, arguments = args}) : go xs +<<<<<<< HEAD go (_ : xs) = go xs +======= + go (MIR.DData (MIR.Data n cons) : xs) = + do map + ( \(Inj id xs) -> + ( (coerce id, MIR.TLit (extractTypeName n)) + , FunctionInfo + { numArgs = undefined -- TODO + , arguments = createArgs (snd <$> undefined) -- TODO + } + ) + ) + cons + <> go xs +>>>>>>> da28c6d (Add bidirectional type checker, lambda lifter.) createArgs :: [MIR.Type] -> [Id] -createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(GA.Ident ("arg_" <> show l), t)], l + 1)) ([], 0) xs +createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(Ident ("arg_" <> show l), t)], l + 1)) ([], 0) xs {- | Produces a map of functions infos from a list of binds, which contains useful data for code generation. @@ -89,6 +110,7 @@ getConstructors bs = Map.fromList $ go bs where go [] = [] go (MIR.DData (MIR.Data t cons) : xs) = +<<<<<<< HEAD fst ( foldl ( \(acc, i) (Constructor id xs) -> @@ -96,6 +118,17 @@ getConstructors bs = Map.fromList $ go bs , ConstructorInfo { numArgsCI = length (init . flattenType $ xs) , argumentsCI = createArgs (init . flattenType $ xs) +======= + do + let (Ident n) = extractTypeName t + fst + ( foldl + ( \(acc, i) (Inj (Ident id) xs) -> + ( ( (Ident (n <> "_" <> id), MIR.TLit (coerce n)) + , ConstructorInfo + { numArgsCI = undefined -- TODO + , argumentsCI = createArgs (snd <$> undefined) -- TODO +>>>>>>> da28c6d (Add bidirectional type checker, lambda lifter.) , numCI = i , returnTypeCI = t --last . 
flattenType $ xs } @@ -133,30 +166,30 @@ test :: Integer -> Program test v = Program [ DataType - (GA.Ident "Craig") - [ Constructor (GA.Ident "Bob") [MIR.Type (GA.Ident "_Int")] - , Constructor (GA.Ident "Betty") [MIR.Type (GA.Ident "_Int")] + (Ident "Craig") + [ Constructor (Ident "Bob") [MIR.Type (Ident "_Int")] + , Constructor (Ident "Betty") [MIR.Type (Ident "_Int")] ] , DataType - (GA.Ident "Alice") - [ Constructor (GA.Ident "Eve") [MIR.Type (GA.Ident "_Int")] -- , - -- (GA.Ident "Alice", [TInt, TInt]) + (Ident "Alice") + [ Constructor (Ident "Eve") [MIR.Type (Ident "_Int")] -- , + -- (Ident "Alice", [TInt, TInt]) ] - , Bind (GA.Ident "fibonacci", MIR.Type (GA.Ident "_Int")) [(GA.Ident "x", MIR.Type (GA.Ident "_Int"))] (EId ("x", MIR.Type (GA.Ident "Craig")), MIR.Type (GA.Ident "Craig")) - , Bind (GA.Ident "main", MIR.Type (GA.Ident "_Int")) [] - -- (EApp (MIR.Type (GA.Ident "Craig")) (EId (GA.Ident "Craig_Bob", MIR.Type (GA.Ident "Craig")), MIR.Type (GA.Ident "Craig")) (ELit (LInt v), MIR.Type (GA.Ident "_Int")), MIR.Type (GA.Ident "Craig"))-- (EInt 92) + , Bind (Ident "fibonacci", MIR.Type (Ident "_Int")) [(Ident "x", MIR.Type (Ident "_Int"))] (EVar ("x", MIR.Type (Ident "Craig")), MIR.Type (Ident "Craig")) + , Bind (Ident "main", MIR.Type (Ident "_Int")) [] + -- (EApp (MIR.Type (Ident "Craig")) (EVar (Ident "Craig_Bob", MIR.Type (Ident "Craig")), MIR.Type (Ident "Craig")) (ELit (LInt v), MIR.Type (Ident "_Int")), MIR.Type (Ident "Craig"))-- (EInt 92) $ eCaseInt - (EApp (MIR.TLit (GA.Ident "Craig")) (EId (GA.Ident "Craig_Bob", MIR.TLit (GA.Ident "Craig")), MIR.TLit (GA.Ident "Craig")) (ELit (LInt v), MIR.Type (GA.Ident "_Int")), MIR.Type (GA.Ident "Craig")) - [ injectionCons "Craig_Bob" "Craig" [CIdent (GA.Ident "x")] (EId (GA.Ident "x", MIR.Type (GA.Ident "_Int")), MIR.Type (GA.Ident "_Int")) + (EApp (MIR.TLit (Ident "Craig")) (EVar (Ident "Craig_Bob", MIR.TLit (Ident "Craig")), MIR.TLit (Ident "Craig")) (ELit (LInt v), MIR.Type (Ident "_Int")), MIR.Type 
(Ident "Craig")) + [ injectionCons "Craig_Bob" "Craig" [CIdent (Ident "x")] (EVar (Ident "x", MIR.Type (Ident "_Int")), MIR.Type (Ident "_Int")) , injectionCons "Craig_Betty" "Craig" [CLit (LInt 5)] (int 2) - , Injection (CIdent (GA.Ident "z")) (int 3) + , Injection (CIdent (Ident "z")) (int 3) , -- , injectionInt 5 (int 6) injectionCatchAll (int 10) ] ] where - injectionCons x y xs = Injection (CCons (GA.Ident x, MIR.Type (GA.Ident y)) xs) + injectionCons x y xs = Injection (CCons (Ident x, MIR.Type (Ident y)) xs) injectionInt x = Injection (CLit (LInt x)) injectionCatchAll = Injection CatchAll eCaseInt x xs = (ECase (MIR.TLit (MIR.Ident "_Int")) x xs, MIR.TLit (MIR.Ident "_Int")) @@ -206,7 +239,7 @@ compileScs [] = do emit $ SetVariable ptr' (Bitcast (Ref t') (VIdent top Ptr) (Ref $ CustomType id)) enumerateOneM_ - ( \i (GA.Ident arg_n, arg_t) -> do + ( \i (Ident arg_n, arg_t) -> do let arg_t' = type2LlvmType arg_t emit $ Comment (toIr arg_t' <> " " <> arg_n <> " " <> show i) elemPtr <- getNewVar @@ -222,7 +255,7 @@ compileScs [] = do I32 (VInteger i) ) - emit $ Store arg_t' (VIdent (GA.Ident arg_n) arg_t') Ptr elemPtr + emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr elemPtr ) (argumentsCI ci) @@ -255,8 +288,13 @@ compileScs (MIR.DData (MIR.Data typ ts) : xs) = do let biggestVariant = 7 + maximum (sum . (\(Constructor _ fi) -> typeByteSize <$> variantTypes fi) <$> ts) emit $ LIR.Type (Ident outer_id) [I8, Array biggestVariant I8] mapM_ +<<<<<<< HEAD ( \(Constructor inner_id fi) -> do emit $ LIR.Type inner_id (I8 : variantTypes fi) +======= + ( \(Inj (Ident inner_id) fi) -> do + emit $ LIR.Type (Ident $ outer_id <> "_" <> inner_id) (I8 : map type2LlvmType (snd <$> undefined)) -- TODO +>>>>>>> da28c6d (Add bidirectional type checker, lambda lifter.) ) ts compileScs xs @@ -282,17 +320,17 @@ mainContent var = -- " %4 = load i72, ptr %3\n" <> -- " call i32 (ptr, ...) @printf(ptr noundef @.str, i72 noundef %4)\n" "call i32 (ptr, ...) 
@printf(ptr noundef @.str, i64 noundef " <> toIr var <> ")\n" - , -- , SetVariable (GA.Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2)) - -- , BrCond (VIdent (GA.Ident "p")) (GA.Ident "b_1") (GA.Ident "b_2") - -- , Label (GA.Ident "b_1") + , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2)) + -- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2") + -- , Label (Ident "b_1") -- , UnsafeRaw -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n" - -- , Br (GA.Ident "end") - -- , Label (GA.Ident "b_2") + -- , Br (Ident "end") + -- , Label (Ident "b_2") -- , UnsafeRaw -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n" - -- , Br (GA.Ident "end") - -- , Label (GA.Ident "end") + -- , Br (Ident "end") + -- , Label (Ident "end") Ret I64 (VInteger 0) ] @@ -310,7 +348,7 @@ compileExp :: ExpT -> CompilerState () compileExp (MIR.ELit lit,t) = emitLit lit compileExp (MIR.EAdd e1 e2,t) = emitAdd t e1 e2 -- compileExp (ESub t e1 e2) = emitSub t e1 e2 -compileExp (MIR.EId name,t) = emitIdent name +compileExp (MIR.EVar name,t) = emitIdent name compileExp (MIR.EApp e1 e2,t) = emitApp t e1 e2 -- compileExp (EAbs t ti e) = emitAbs t ti e compileExp (MIR.ELet binds e,t) = undefined -- emitLet binds (fst e) @@ -328,7 +366,7 @@ emitECased t e cases = do let rt = type2LlvmType (snd e) vs <- exprToValue e lbl <- getNewLabel - let label = GA.Ident $ "escape_" <> show lbl + let label = Ident $ "escape_" <> show lbl stackPtr <- getNewVar emit $ SetVariable stackPtr (Alloca ty) mapM_ (emitCases rt ty label stackPtr vs) cs @@ -341,13 +379,13 @@ emitECased t e cases = do res <- getNewVar emit $ SetVariable res (Load ty Ptr stackPtr) where - emitCases :: LLVMType -> LLVMType -> GA.Ident -> GA.Ident -> LLVMValue -> Branch -> CompilerState () + emitCases :: LLVMType -> LLVMType -> Ident -> Ident -> LLVMValue -> Branch -> CompilerState () emitCases rt ty label stackPtr vs (Branch (MIR.PInj consId cs, t) exp) = do cons <- gets constructors 
let r = fromJust $ Map.lookup consId cons - lbl_failPos <- (\x -> GA.Ident $ "failed_" <> show x) <$> getNewLabel - lbl_succPos <- (\x -> GA.Ident $ "success_" <> show x) <$> getNewLabel + lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel + lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel consVal <- getNewVar emit $ SetVariable consVal (ExtractValue rt vs 0) @@ -397,8 +435,8 @@ emitECased t e cases = do (MIR.LInt i, _) -> VInteger i (MIR.LChar i, _) -> VChar i ns <- getNewVar - lbl_failPos <- (\x -> GA.Ident $ "failed_" <> show x) <$> getNewLabel - lbl_succPos <- (\x -> GA.Ident $ "success_" <> show x) <$> getNewLabel + lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel + lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel emit $ SetVariable ns (Icmp LLEq ty vs i') emit $ BrCond (VIdent ns ty) lbl_succPos lbl_failPos emit $ Label lbl_succPos @@ -444,8 +482,13 @@ emitApp rt e1 e2 = appEmitter e1 e2 [] appEmitter e1 e2 stack = do let newStack = e2 : stack case e1 of +<<<<<<< HEAD (MIR.EApp e1' e2', _) -> appEmitter e1' e2' newStack (MIR.EId name, t) -> do +======= + (MIR.EApp e1' e2', t) -> appEmitter e1' e2' newStack + (MIR.EVar name, t) -> do +>>>>>>> da28c6d (Add bidirectional type checker, lambda lifter.) args <- traverse exprToValue newStack vs <- getNewVar funcs <- gets functions @@ -462,7 +505,7 @@ emitApp rt e1 e2 = appEmitter e1 e2 [] emit $ SetVariable vs call x -> error $ "The unspeakable happened: " <> show x -emitIdent :: GA.Ident -> CompilerState () +emitIdent :: Ident -> CompilerState () emitIdent id = do -- !!this should never happen!! emit $ Comment "This should not have happened!" @@ -477,14 +520,14 @@ emitLit i = do (MIR.LChar i'') -> (VChar i'', I8) varCount <- getNewVar emit $ Comment "This should not have happened!" 
- emit $ SetVariable (GA.Ident (show varCount)) (Add t i' (VInteger 0)) + emit $ SetVariable (Ident (show varCount)) (Add t i' (VInteger 0)) emitAdd :: MIR.Type -> ExpT -> ExpT -> CompilerState () emitAdd t e1 e2 = do v1 <- exprToValue e1 v2 <- exprToValue e2 v <- getNewVar - emit $ SetVariable (GA.Ident $ show v) (Add (type2LlvmType t) v1 v2) + emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2) emitSub :: MIR.Type -> ExpT -> ExpT -> CompilerState () emitSub t e1 e2 = do @@ -498,7 +541,7 @@ exprToValue = \case (MIR.ELit i, t) -> pure $ case i of (MIR.LInt i) -> VInteger i (MIR.LChar i) -> VChar i - (MIR.EId name, t) -> do + (MIR.EVar name, t) -> do funcs <- gets functions case Map.lookup (name, t) funcs of Just fi -> do @@ -515,7 +558,7 @@ exprToValue = \case e -> do compileExp e v <- getVarCount - pure $ VIdent (GA.Ident $ show v) (getType e) + pure $ VIdent (Ident $ show v) (getType e) type2LlvmType :: MIR.Type -> LLVMType type2LlvmType (MIR.TLit id@(Ident name)) = case name of @@ -558,3 +601,4 @@ typeByteSize (CustomType _) = 8 enumerateOneM_ :: Monad m => (Integer -> a -> m b) -> [a] -> m () enumerateOneM_ f = foldM_ (\i a -> f i a >> pure (i + 1)) 1 + diff --git a/src/Codegen/LlvmIr.hs b/src/Codegen/LlvmIr.hs index 3c11ae1..0baf35a 100644 --- a/src/Codegen/LlvmIr.hs +++ b/src/Codegen/LlvmIr.hs @@ -11,8 +11,9 @@ module Codegen.LlvmIr ( ToIr(..) 
) where -import Data.List (intercalate) -import Grammar.Abs (Ident (..)) +import Data.List (intercalate) +import Grammar.Abs (Character) +import TypeChecker.TypeCheckerIr (Ident (..)) data CallingConvention = TailCC | FastCC | CCC | ColdCC deriving Show instance ToIr CallingConvention where @@ -87,7 +88,7 @@ instance ToIr Visibility where -- or a string contstant data LLVMValue = VInteger Integer - | VChar Char + | VChar Character | VIdent Ident LLVMType | VConstant String | VFunction Ident Visibility LLVMType diff --git a/src/LambdaLifter.hs b/src/LambdaLifter.hs new file mode 100644 index 0000000..b85dd8b --- /dev/null +++ b/src/LambdaLifter.hs @@ -0,0 +1,242 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} + + +module LambdaLifter (lambdaLift, freeVars, abstract, collectScs) where + +import Auxiliary (snoc) +import Control.Applicative (Applicative (liftA2)) +import Control.Monad.State (MonadState (get, put), State, + evalState) +import Data.List (partition) +import Data.Set (Set) +import qualified Data.Set as Set +import Prelude hiding (exp) +import TypeChecker.TypeCheckerIr + + +-- | Lift lambdas and let expression into supercombinators. +-- Three phases: +-- @freeVars@ annotates all the free variables. +-- @abstract@ converts lambdas into let expressions. +-- @collectScs@ moves every non-constant let expression to a top-level function. +lambdaLift :: Program -> Program +lambdaLift (Program defs) = Program $ datatypes ++ ll binds + where + ll = map DBind . collectScs . abstract . freeVars . 
map (\(DBind b) -> b) + (binds, datatypes) = partition isBind defs + isBind = \case + DBind _ -> True + _ -> False + +-- | Annotate free variables +freeVars :: [Bind] -> AnnBinds +freeVars binds = [ (n, xs, freeVarsExp (Set.fromList $ map fst xs) e) + | Bind n xs e <- binds + ] + +freeVarsExp :: Set Ident -> ExpT -> AnnExpT +freeVarsExp localVars (exp, t) = case exp of + EVar n | Set.member n localVars -> (Set.singleton n, (AVar n, t)) + | otherwise -> (mempty, (AVar n, t)) + + ELit lit -> (mempty, (ALit lit, t)) + + EApp e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), (AApp e1' e2', t)) + where + e1' = freeVarsExp localVars e1 + e2' = freeVarsExp localVars e2 + + EAdd e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), (AAdd e1' e2', t)) + where + e1' = freeVarsExp localVars e1 + e2' = freeVarsExp localVars e2 + + EAbs par e -> (Set.delete par $ freeVarsOf e', (AAbs par e', t)) + where + e' = freeVarsExp (Set.insert par localVars) e + + -- Sum free variables present in bind and the expression + ELet (Bind (name, t_bind) parms rhs) e -> (Set.union binders_frees e_free, (ALet new_bind e', t)) + where + binders_frees = Set.delete name $ freeVarsOf rhs' + e_free = Set.delete name $ freeVarsOf e' + + rhs' = freeVarsExp e_localVars rhs + new_bind = ABind (name, t_bind) parms rhs' + + e' = freeVarsExp e_localVars e + e_localVars = Set.insert name localVars + + +freeVarsOf :: AnnExpT -> Set Ident +freeVarsOf = fst + +-- AST annotated with free variables +type AnnBinds = [(Id, [Id], AnnExpT)] + +type AnnExpT = (Set Ident, AnnExpT') + +data ABind = ABind Id [Id] AnnExpT deriving Show + +type AnnExpT' = (AnnExp, Type) + +data AnnExp = AVar Ident + | AInj Ident + | ALit Lit + | ALet ABind AnnExpT + | AApp AnnExpT AnnExpT + | AAdd AnnExpT AnnExpT + | AAbs Ident AnnExpT + deriving Show + +-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@. +-- The free variables @v₁ v₂ .. vₙ@ are bound as parameters. 
+abstract :: AnnBinds -> [Bind] +abstract prog = evalState (mapM go prog) 0 + where + go :: (Id, [Id], AnnExpT) -> State Int Bind + go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs' + where + (rhs', parms1) = flattenLambdasAnn rhs + + +-- | Flatten nested lambdas and collect the parameters +-- @\x.\y.\z. ae → (ae, [x,y,z])@ +flattenLambdasAnn :: AnnExpT -> (AnnExpT, [Id]) +flattenLambdasAnn ae = go (ae, []) + where + go :: (AnnExpT, [Id]) -> (AnnExpT, [Id]) + go ((free, (e, t)), acc) + | AAbs par (free1, e1) <- e + , TFun t_par _ <- t + = go ((Set.delete par free1, e1), snoc (par, t_par) acc) + | otherwise = ((free, (e, t)), acc) + +abstractExp :: AnnExpT -> State Int ExpT +abstractExp (free, (exp, typ)) = case exp of + AVar n -> pure (EVar n, typ) + ALit lit -> pure (ELit lit, typ) + AApp e1 e2 -> (, typ) <$> liftA2 EApp (abstractExp e1) (abstractExp e2) + AAdd e1 e2 -> (, typ) <$> liftA2 EAdd (abstractExp e1) (abstractExp e2) + ALet b e -> (, typ) <$> liftA2 ELet (go b) (abstractExp e) + where + go (ABind name parms rhs) = do + (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs + pure $ Bind name (parms ++ parms1) rhs' + + skipLambdas :: (AnnExpT -> State Int ExpT) -> AnnExpT -> State Int ExpT + skipLambdas f (free, (ae, t)) = case ae of + AAbs par ae1 -> do + ae1' <- skipLambdas f ae1 + pure (EAbs par ae1', t) + _ -> f (free, (ae, t)) + + -- Lift lambda into let and bind free variables + AAbs parm e -> do + i <- nextNumber + rhs <- abstractExp e + + let sc_name = Ident ("sc_" ++ show i) + sc = (ELet (Bind (sc_name, typ) vars rhs) (EVar sc_name, typ), typ) + pure $ foldl applyVars sc freeList + + where + freeList = Set.toList free + vars = zip names $ getVars typ + names = snoc parm freeList + applyVars (e, t) name = (EApp (e, t) (EVar name, t_var), t_return) + where + (t_var, t_return) = applyVarType t + +applyVarType :: Type -> (Type, Type) +applyVarType typ = (t1, foldr ($) t2 foralls) + + where + (t1, t2) = case typ' of + 
TFun t1 t2 -> (t1, t2) + _ -> error "Not a function!" + + (foralls, typ') = skipForalls [] typ + + + skipForalls acc = \case + TAll tvar t -> skipForalls (snoc (TAll tvar) acc) t + t -> (acc, t) + +nextNumber :: State Int Int +nextNumber = do + i <- get + put $ succ i + pure i + +-- | Collects supercombinators by lifting non-constant let expressions +collectScs :: [Bind] -> [Bind] +collectScs = concatMap collectFromRhs + where + collectFromRhs (Bind name parms rhs) = + let (rhs_scs, rhs') = collectScsExp rhs + in Bind name parms rhs' : rhs_scs + + +collectScsExp :: ExpT -> ([Bind], ExpT) +collectScsExp expT@(exp, typ) = case exp of + EVar _ -> ([], expT) + ELit _ -> ([], expT) + + EApp e1 e2 -> (scs1 ++ scs2, (EApp e1' e2', typ)) + where + (scs1, e1') = collectScsExp e1 + (scs2, e2') = collectScsExp e2 + + EAdd e1 e2 -> (scs1 ++ scs2, (EAdd e1' e2', typ)) + where + (scs1, e1') = collectScsExp e1 + (scs2, e2') = collectScsExp e2 + + EAbs par e -> (scs, (EAbs par e', typ)) + where + (scs, e') = collectScsExp e + + -- Collect supercombinators from bind, the rhss, and the expression. + -- + -- > f = let sc x y = rhs in e + -- + ELet (Bind name parms rhs) e -> if null parms + then ( rhs_scs ++ et_scs, (ELet bind et', snd et')) + else (bind : rhs_scs ++ et_scs, et') + where + bind = Bind name parms rhs' + (rhs_scs, rhs') = collectScsExp rhs + (et_scs, et') = collectScsExp e + + +-- @\x.\y.\z. e → (e, [x,y,z])@ +flattenLambdas :: ExpT -> (ExpT, [Id]) +flattenLambdas = go . (, []) + where + go ((e, t), acc) = case e of + EAbs name e1 -> go (e1, snoc (name, t_var) acc) + where t_var = head $ getVars t + _ -> ((e, t), acc) + +getVars :: Type -> [Type] +getVars = fst . partitionType + +partitionType :: Type -> ([Type], Type) +partitionType = go [] . skipForalls' + where + + go acc t = case t of + TFun t1 t2 -> go (snoc t1 acc) t2 + _ -> (acc, t) + +skipForalls' :: Type -> Type +skipForalls' = snd . 
skipForalls + +skipForalls :: Type -> ([Type -> Type], Type) +skipForalls = go [] + where + go acc typ = case typ of + TAll tvar t -> go (snoc (TAll tvar) acc) t + _ -> (acc, typ) diff --git a/src/Main.hs b/src/Main.hs index 16f1442..210916d 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -1,66 +1,114 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedRecordDot #-} module Main where -import Codegen.Codegen (generateCode) -import Data.Bool (bool) -import GHC.IO.Handle.Text (hPutStrLn) -import Grammar.ErrM (Err) -import Grammar.Par (myLexer, pProgram) -import Grammar.Print (printTree) - -import Monomorphizer.Monomorphizer (monomorphize) - import Control.Monad (when) +import Data.Bool (bool) import Data.List.Extra (isSuffixOf) - -import Compiler (compile) -import Renamer.Renamer (rename) +import Data.Maybe (fromJust, isNothing) +import GHC.IO.Handle.Text (hPutStrLn) +import System.Console.GetOpt (ArgDescr (NoArg, ReqArg), + ArgOrder (RequireOrder), + OptDescr (Option), getOpt, + usageInfo) import System.Directory (createDirectory, doesPathExist, getDirectoryContents, removeDirectoryRecursive, setCurrentDirectory) import System.Environment (getArgs) -import System.Exit (ExitCode, exitFailure, - exitSuccess) +import System.Exit (ExitCode (ExitFailure), + exitFailure, exitSuccess, + exitWith) import System.IO (stderr) -import System.Process.Extra (readCreateProcess, shell, - spawnCommand, waitForProcess) -import TypeChecker.TypeChecker (typecheck) + + +import Codegen.Codegen (generateCode) +import Compiler (compile) +import Grammar.ErrM (Err) +import Grammar.Par (myLexer, pProgram) +import Grammar.Print (printTree) +import LambdaLifter (lambdaLift) +import Monomorphizer.Monomorphizer (monomorphize) +import Renamer.Renamer (rename) +import System.Process (spawnCommand, waitForProcess) +import TypeChecker.TypeChecker (TypeChecker (Bi, Hm), typecheck) main :: IO () -main = - getArgs >>= \case - [] -> putStrLn "Required file path missing" - ["-d", s] -> do - when (".crf" 
`isSuffixOf` s) (main' True s) - putStrLn $ "File '" ++ s ++ "' is not a churf file" - [s] -> do - when (".crf" `isSuffixOf` s) (main' False s) - putStrLn $ "File '" ++ s ++ "' is not a churf file" - xs -> putStrLn $ "Can't process: " ++ unwords xs +main = getArgs >>= parseArgs >>= uncurry main' -main' :: Bool -> String -> IO () -main' debug s = do +parseArgs :: [String] -> IO (Options, String) +parseArgs argv = case getOpt RequireOrder flags argv of + (os, f:_, []) + | opts.help || isNothing opts.typechecker -> do + hPutStrLn stderr (usageInfo header flags) + exitSuccess + | otherwise -> pure (opts, f) + where + opts = foldr ($) initOpts os + (_, _, errs) -> do + hPutStrLn stderr (concat errs ++ usageInfo header flags) + exitWith (ExitFailure 1) + where + header = "Usage: language [--help] [-d|--debug] [-t|--type-checker bi/hm] FILE \n" + +flags :: [OptDescr (Options -> Options)] +flags = + [ Option ['d'] ["debug"] (NoArg enableDebug) "Print debug messages." + , Option ['t'] ["type-checker"] (ReqArg chooseTypechecker "bi/hm") "Choose type checker. Possible options are bi and hm" + , Option [] ["help"] (NoArg enableHelp) "Print this help message" + ] + +initOpts :: Options +initOpts = Options { help = False + , debug = False + , typechecker = Nothing + } + +enableHelp :: Options -> Options +enableHelp opts = opts { help = True } + +enableDebug :: Options -> Options +enableDebug opts = opts { debug = True } + +chooseTypechecker :: String -> Options -> Options +chooseTypechecker s options = options { typechecker = tc } + where + tc = case s of + "hm" -> pure Hm + "bi" -> pure Bi + _ -> Nothing + +data Options = Options + { help :: Bool + , debug :: Bool + , typechecker :: Maybe TypeChecker + } + +main' :: Options -> String -> IO () +main' opts s = do file <- readFile s printToErr "-- Parse Tree -- " parsed <- fromSyntaxErr . 
pProgram $ myLexer file - bool (printToErr $ printTree parsed) (printToErr $ show parsed) debug + bool (printToErr $ printTree parsed) (printToErr $ show parsed) opts.debug printToErr "\n-- Renamer --" renamed <- fromRenamerErr . rename $ parsed - bool (printToErr $ printTree renamed) (printToErr $ show renamed) debug + bool (printToErr $ printTree renamed) (printToErr $ show renamed) opts.debug printToErr "\n-- TypeChecker --" - typechecked <- fromTypeCheckerErr $ typecheck renamed - bool (printToErr $ printTree typechecked) (printToErr $ show typechecked) debug + typechecked <- fromTypeCheckerErr $ typecheck (fromJust opts.typechecker) renamed + bool (printToErr $ printTree typechecked) (printToErr $ show typechecked) opts.debug + + printToErr "\n-- Lambda Lifter --" + let lifted = lambdaLift typechecked + printToErr $ printTree lifted -- printToErr "\n-- Lambda Lifter --" -- let lifted = lambdaLift typechecked -- printToErr $ printTree lifted -- - --printToErr "\n -- Compiler --" + printToErr "\n -- Compiler --" generatedCode <- fromCompilerErr $ generateCode (monomorphize typechecked) --putStrLn generatedCode diff --git a/src/Monomorphizer/Monomorphizer.hs b/src/Monomorphizer/Monomorphizer.hs index 7062b79..5440bab 100644 --- a/src/Monomorphizer/Monomorphizer.hs +++ b/src/Monomorphizer/Monomorphizer.hs @@ -1,12 +1,13 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} module Monomorphizer.Monomorphizer (monomorphize) where -import Data.Coerce (coerce) +import Data.Coerce (coerce) -import Monomorphizer.MonomorphizerIr qualified as M -import TypeChecker.TypeCheckerIr qualified as T +import qualified Monomorphizer.MonomorphizerIr as M +import qualified TypeChecker.TypeCheckerIr as T +import TypeChecker.TypeCheckerIr (Ident (..)) monomorphize :: T.Program -> M.Program monomorphize (T.Program ds) = M.Program $ monoDefs ds @@ -16,40 +17,40 @@ monoDefs = map monoDef monoDef :: T.Def -> M.Def monoDef (T.DBind bind) = M.DBind 
$ monoBind bind -monoDef (T.DData d) = M.DData $ monoData d +--monoDef (T.DData d) = M.DData $ monoData d monoBind :: T.Bind -> M.Bind monoBind (T.Bind name args (e, t)) = M.Bind (monoId name) (map monoId args) (monoExpr e, monoType t) -monoData :: T.Data -> M.Data -monoData (T.Data (T.Ident id) cs) = M.Data (M.TLit (M.Ident id)) (map monoConstructor cs) +--monoData :: T.Data -> M.Data +--monoData (T.Data (Ident id) cs) = M.Data (M.TLit (M.Ident id)) (map monoConstructor cs) -monoConstructor :: T.Constructor -> M.Constructor -monoConstructor (T.Constructor (T.Ident i) t) = M.Constructor (M.Ident i) (monoType t) +monoConstructor :: T.Inj -> M.Inj +monoConstructor (T.Inj (Ident i) t) = M.Inj (M.Ident i) (monoType t) monoExpr :: T.Exp -> M.Exp monoExpr = \case - T.EId (T.Ident i) -> M.EId (M.Ident i) - T.ELit lit -> M.ELit $ monoLit lit - T.ELet bind expt -> M.ELet (monoBind bind) (monoexpt expt) + T.EVar (Ident i) -> M.EVar (M.Ident i) + T.ELit lit -> M.ELit $ monoLit lit + T.ELet bind expt -> M.ELet (monoBind bind) (monoexpt expt) T.EApp expt1 expt2 -> M.EApp (monoexpt expt1) (monoexpt expt2) T.EAdd expt1 expt2 -> M.EAdd (monoexpt expt1) (monoexpt expt2) - T.EAbs _i _expt -> error "BUG" - T.ECase expt injs -> M.ECase (monoexpt expt) (monoInjs injs) + T.EAbs _i _expt -> error "BUG" + T.ECase expt injs -> M.ECase (monoexpt expt) (monoInjs injs) monoAbsType :: T.Type -> M.Type -monoAbsType (T.TLit u) = M.TLit (coerce u) -monoAbsType (T.TVar _v) = M.TLit "Int" +monoAbsType (T.TLit u) = M.TLit (coerce u) +monoAbsType (T.TVar _v) = M.TLit "Int" monoAbsType (T.TAll _v _t) = error "NOT ALL TYPES" monoAbsType (T.TFun t1 t2) = M.TFun (monoAbsType t1) (monoAbsType t2) -monoAbsType (T.TData _ _) = error "NOT INDEXED TYPES" +monoAbsType (T.TData _ _) = error "NOT INDEXED TYPES" monoType :: T.Type -> M.Type -monoType (T.TAll _ t) = monoType t +monoType (T.TAll _ t) = monoType t monoType (T.TVar (T.MkTVar i)) = M.TLit "Int" -monoType (T.TLit (T.Ident i)) = M.TLit (M.Ident i) 
-monoType (T.TFun t1 t2) = M.TFun (monoType t1) (monoType t2) -monoType (T.TData (T.Ident n) t) = M.TLit (M.Ident (n ++ concatMap show t)) +monoType (T.TLit (Ident i)) = M.TLit (M.Ident i) +monoType (T.TFun t1 t2) = M.TFun (monoType t1) (monoType t2) +monoType (T.TData (Ident n) t) = M.TLit (M.Ident (n ++ concatMap show t)) monoexpt :: T.ExpT -> M.ExpT monoexpt (e, t) = (monoExpr e, monoType t) @@ -58,19 +59,19 @@ monoId :: T.Id -> M.Id monoId (n, t) = (coerce n, monoType t) monoLit :: T.Lit -> M.Lit -monoLit (T.LInt i) = M.LInt i +monoLit (T.LInt i) = M.LInt i monoLit (T.LChar c) = M.LChar c monoInjs :: [T.Branch] -> [M.Branch] monoInjs = map monoInj monoInj :: T.Branch -> M.Branch -monoInj (T.Branch (init, t) expt) = M.Branch (monoInit init, monoType t) (monoexpt expt) +monoInj (T.Branch (patt, t) expt) = M.Branch (monoPattern patt, monoType t) (monoexpt expt) -monoInit :: T.Pattern -> M.Pattern -monoInit (T.PVar (id, t)) = M.PVar (coerce id, monoType t) -monoInit (T.PLit (lit, t)) = M.PLit (monoLit lit, monoType t) -monoInit (T.PInj id ps) = M.PInj (coerce id) (monoInit <$> ps) +monoPattern :: T.Pattern -> M.Pattern +monoPattern (T.PVar (id, t)) = M.PVar (id, monoType t) +monoPattern (T.PLit (lit, t)) = M.PLit (monoLit lit, monoType t) +monoPattern (T.PInj id ps) = M.PInj (coerce id) (map monoPattern ps) -- DO NOT DO THIS FOR REAL THOUGH -monoInit (T.PEnum (T.Ident i)) = M.PInj (M.Ident i) [] -monoInit T.PCatch = M.PCatch +monoPattern (T.PEnum (Ident i)) = M.PInj (M.Ident i) [] +monoPattern T.PCatch = M.PCatch diff --git a/src/Monomorphizer/MonomorphizerIr.hs b/src/Monomorphizer/MonomorphizerIr.hs index e0e7383..c80ad65 100644 --- a/src/Monomorphizer/MonomorphizerIr.hs +++ b/src/Monomorphizer/MonomorphizerIr.hs @@ -11,14 +11,14 @@ newtype Program = Program [Def] data Def = DBind Bind | DData Data deriving (Show, Ord, Eq) -data Data = Data Type [Constructor] +data Data = Data Type [Inj] deriving (Show, Ord, Eq) data Bind = Bind Id [Id] ExpT deriving (Show, Ord, 
Eq) data Exp - = EId Ident + = EVar Ident | ELit Lit | ELet Bind ExpT | EApp ExpT ExpT @@ -35,12 +35,12 @@ data Branch = Branch (Pattern, Type) ExpT type ExpT = (Exp, Type) -data Constructor = Constructor Ident Type +data Inj = Inj Ident Type deriving (Show, Ord, Eq) data Lit = LInt Integer - | LChar Char + | LChar Character deriving (Show, Ord, Eq) data Type = TLit Ident | TFun Type Type diff --git a/src/Renamer/Renamer.hs b/src/Renamer/Renamer.hs index 5576793..0a67e22 100644 --- a/src/Renamer/Renamer.hs +++ b/src/Renamer/Renamer.hs @@ -1,131 +1,124 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedRecordDot #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE PatternSynonyms #-} -{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} - -{-# HLINT ignore "Use mapAndUnzipM" #-} module Renamer.Renamer (rename) where -import Auxiliary (mapAccumM) -import Control.Applicative (Applicative (liftA2)) -import Control.Monad (foldM) -import Control.Monad.Except (ExceptT, MonadError, runExceptT, throwError) -import Control.Monad.Identity (Identity, runIdentity) -import Control.Monad.State ( - MonadState, - StateT, - evalStateT, - gets, - modify, - ) -import Data.Coerce (coerce) -import Data.Function (on) -import Data.Map (Map) -import Data.Map qualified as Map -import Data.Maybe (fromMaybe) -import Data.Tuple.Extra (dupe) -import Grammar.Abs +import Auxiliary (mapAccumM) +import Control.Applicative (Applicative (liftA2)) +import Control.Monad.Except (ExceptT, MonadError (throwError), + runExceptT) +import Control.Monad.State (MonadState, State, evalState, gets, + mapAndUnzipM, modify) +import Data.Function (on) +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Maybe (fromMaybe) +import Data.Tuple.Extra (dupe, second) +import Grammar.Abs +import Grammar.ErrM (Err) + -- | Rename all variables and local binds -rename :: Program -> Either String Program +rename :: Program -> Err Program rename (Program defs) = Program <$> renameDefs 
defs -renameDefs :: [Def] -> Either String [Def] -renameDefs defs = runIdentity $ runExceptT $ evalStateT (runRn $ mapM renameDef defs) initCxt - where - initNames = Map.fromList [dupe (coerce name) | DBind (Bind name _ _) <- defs] - - renameDef :: Def -> Rn Def - renameDef = \case - DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ - DBind bind -> DBind . snd <$> renameBind initNames bind - DData (Data (TData cname types) constrs) -> do - tvars_ <- tvars - tvars' <- mapM nextNameTVar tvars_ - let tvars_lt = zip tvars_ tvars' - typ' = map (substituteTVar tvars_lt) types - constrs' = map (renameConstr tvars_lt) constrs - pure . DData $ Data (TData cname typ') constrs' - where - tvars = concat <$> mapM (collectTVars []) types - collectTVars :: [TVar] -> Type -> Rn [TVar] - collectTVars tvars = \case - TAll tvar t -> collectTVars (tvar : tvars) t - TData _ _ -> return tvars - -- Should be monad error - TVar v -> return [v] - _ -> throwError ("Bad data type definition: " ++ show types) - DData (Data types _) -> throwError ("Bad data type definition: " ++ show types) - - renameConstr :: [(TVar, TVar)] -> Constructor -> Constructor - renameConstr new_types (Constructor name typ) = - Constructor name $ substituteTVar new_types typ - -renameBind :: Names -> Bind -> Rn (Names, Bind) -renameBind old_names (Bind name vars rhs) = do - (new_names, vars') <- newNames old_names (coerce vars) - (newer_names, rhs') <- renameExp new_names rhs - pure (newer_names, Bind name (coerce vars') rhs') - -substituteTVar :: [(TVar, TVar)] -> Type -> Type -substituteTVar new_names typ = case typ of - TLit _ -> typ - TVar tvar - | Just tvar' <- lookup tvar new_names -> - TVar tvar' - | otherwise -> - typ - TFun t1 t2 -> on TFun substitute' t1 t2 - TAll tvar t - | Just tvar' <- lookup tvar new_names -> - TAll tvar' $ substitute' t - | otherwise -> - TAll tvar $ substitute' t - TData name typs -> TData name $ map substitute' typs - _ -> error ("Impossible " ++ show typ) - where - 
substitute' = substituteTVar new_names - initCxt :: Cxt initCxt = Cxt 0 0 -data Cxt = Cxt - { var_counter :: Int - , tvar_counter :: Int - } +data Cxt = Cxt { var_counter :: Int + , tvar_counter :: Int + } + -- | Rename monad. State holds the number of renamed names. -newtype Rn a = Rn {runRn :: StateT Cxt (ExceptT String Identity) a} - deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String) +newtype Rn a = Rn { runRn :: ExceptT String (State Cxt) a } + deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String) -- | Maps old to new name -type Names = Map LIdent LIdent +type Names = Map String String + +renameDefs :: [Def] -> Err [Def] +renameDefs defs = evalState (runExceptT (runRn $ mapM renameDef defs)) initCxt + where + initNames = Map.fromList [ dupe s | DBind (Bind (LIdent s) _ _) <- defs] + + renameDef :: Def -> Rn Def + renameDef = \case + DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ + DBind (Bind name vars rhs) -> do + (new_names, vars') <- newNamesL initNames vars + rhs' <- snd <$> renameExp new_names rhs + pure . DBind $ Bind name vars' rhs' + DData (Data typ injs) -> do + tvars <- collectTVars [] typ + tvars' <- mapM nextNameTVar tvars + let tvars_lt = zip tvars tvars' + typ' = substituteTVar tvars_lt typ + injs' = map (renameInj tvars_lt) injs + pure . 
DData $ Data typ' injs' + where + collectTVars tvars = \case + TAll tvar t -> collectTVars (tvar:tvars) t + TData _ _ -> pure tvars + _ -> throwError ("Bad data type definition: " ++ show typ) + + renameInj :: [(TVar, TVar)] -> Inj -> Inj + renameInj new_types (Inj name typ) = + Inj name $ substituteTVar new_types typ + + +substituteTVar :: [(TVar, TVar)] -> Type -> Type +substituteTVar new_names typ = case typ of + TLit _ -> typ + + TVar tvar | Just tvar' <- lookup tvar new_names + -> TVar tvar' + | otherwise + -> typ + + TFun t1 t2 -> on TFun substitute' t1 t2 + + TAll tvar t | Just tvar' <- lookup tvar new_names + -> TAll tvar' $ substitute' t + | otherwise + -> TAll tvar $ substitute' t + + TData name typs -> TData name $ map substitute' typs + _ -> error ("Impossible " ++ show typ) + where + substitute' = substituteTVar new_names + renameExp :: Names -> Exp -> Rn (Names, Exp) renameExp old_names = \case - EVar n -> pure (coerce old_names, EVar . fromMaybe n $ Map.lookup n old_names) - EInj n -> pure (old_names, EInj n) - ELit lit -> pure (old_names, ELit lit) + EVar (LIdent n) -> pure (old_names, EVar . LIdent . fromMaybe n $ Map.lookup n old_names) + EInj (UIdent n) -> pure (old_names, EInj . UIdent . 
fromMaybe n $ Map.lookup n old_names) + + ELit lit -> pure (old_names, ELit lit) + EApp e1 e2 -> do (env1, e1') <- renameExp old_names e1 (env2, e2') <- renameExp old_names e2 pure (Map.union env1 env2, EApp e1' e2') + EAdd e1 e2 -> do (env1, e1') <- renameExp old_names e1 (env2, e2') <- renameExp old_names e2 pure (Map.union env1 env2, EAdd e1' e2') -- TODO fix shadowing - ELet bind e -> do - (new_names, bind') <- renameBind old_names bind - (new_names', e') <- renameExp new_names e - pure (new_names', ELet bind' e') - EAbs par e -> do - (new_names, par') <- newName old_names (coerce par) - (new_names', e') <- renameExp new_names e - pure (new_names', EAbs (coerce par') e') + ELet (Bind name vars rhs) e -> do + (new_names, name') <- newNameL old_names name + (new_names', vars') <- newNamesL new_names vars + (new_names'', rhs') <- renameExp new_names' rhs + (new_names''', e') <- renameExp new_names'' e + pure (new_names''', ELet (Bind name' vars' rhs') e') + + EAbs par e -> do + (new_names, par') <- newNameL old_names par + (new_names', e') <- renameExp new_names e + pure (new_names', EAbs par' e') + EAnn e t -> do (new_names, e') <- renameExp old_names e t' <- renameTVars t @@ -137,26 +130,23 @@ renameExp old_names = \case renameBranches :: Names -> [Branch] -> Rn (Names, [Branch]) renameBranches ns xs = do - (new_names, xs') <- unzip <$> mapM (renameBranch ns) xs + (new_names, xs') <- mapAndUnzipM (renameBranch ns) xs if null new_names then return (mempty, xs') else return (head new_names, xs') renameBranch :: Names -> Branch -> Rn (Names, Branch) -renameBranch ns (Branch init e) = do - (new_names, init') <- renamePattern ns init +renameBranch ns (Branch patt e) = do + (new_names, patt') <- renamePattern ns patt (new_names', e') <- renameExp new_names e - return (new_names', Branch init' e') + return (new_names', Branch patt' e') renamePattern :: Names -> Pattern -> Rn (Names, Pattern) -renamePattern ns i = case i of +renamePattern ns p = case p of PInj cs ps -> 
do - (ns_new, ps) <- renamePatterns ns ps - return (ns_new, PInj cs ps) - rest -> return (ns, rest) + (ns_new, ps') <- mapAccumM renamePattern ns ps + return (ns_new, PInj cs ps') + PVar name -> second PVar <$> newNameL ns name + _ -> return (ns, p) -renamePatterns :: Names -> [Pattern] -> Rn (Names, [Pattern]) -renamePatterns ns xs = do - (new_names, xs') <- unzip <$> mapM (renamePattern ns) xs - if null new_names then return (mempty, xs') else return (head new_names, xs') renameTVars :: Type -> Rn Type renameTVars typ = case typ of @@ -167,44 +157,57 @@ renameTVars typ = case typ of TFun t1 t2 -> liftA2 TFun (renameTVars t1) (renameTVars t2) _ -> pure typ -substitute :: - TVar -> -- α - TVar -> -- α_n - Type -> -- A - Type -- [α_n/α]A +substitute :: TVar -- α + -> TVar -- α_n + -> Type -- A + -> Type -- [α_n/α]A substitute tvar1 tvar2 typ = case typ of - TLit _ -> typ - TVar tvar' - | tvar' == tvar1 -> TVar tvar2 - | otherwise -> typ - TFun t1 t2 -> on TFun substitute' t1 t2 - TAll tvar t -> TAll tvar $ substitute' t - TData name typs -> TData name $ map substitute' typs - _ -> error "Impossible" + TLit _ -> typ + TVar tvar | tvar == tvar1 -> TVar tvar2 + | otherwise -> typ + TFun t1 t2 -> on TFun substitute' t1 t2 + TAll tvar t | tvar == tvar1 -> TAll tvar2 $ substitute' t + | otherwise -> TAll tvar $ substitute' t + TData name typs -> TData name $ map substitute' typs + _ -> error "Impossible" where substitute' = substitute tvar1 tvar2 --- | Create a new name and add it to name environment. -newName :: Names -> LIdent -> Rn (Names, LIdent) -newName env old_name = do - new_name <- makeName old_name - pure (Map.insert old_name new_name env, new_name) + -- | Create multiple names and add them to the name environment -newNames :: Names -> [LIdent] -> Rn (Names, [LIdent]) -newNames = mapAccumM newName +newNamesL :: Names -> [LIdent] -> Rn (Names, [LIdent]) +newNamesL = mapAccumM newNameL + +-- | Create a new name and add it to name environment. 
+newNameL :: Names -> LIdent -> Rn (Names, LIdent) +newNameL env (LIdent old_name) = do + new_name <- makeName old_name + pure (Map.insert old_name new_name env, LIdent new_name) + + +-- | Create multiple names and add them to the name environment +newNamesU :: Names -> [UIdent] -> Rn (Names, [UIdent]) +newNamesU = mapAccumM newNameU + +-- | Create a new name and add it to name environment. +newNameU :: Names -> UIdent -> Rn (Names, UIdent) +newNameU env (UIdent old_name) = do + new_name <- makeName old_name + pure (Map.insert old_name new_name env, UIdent new_name) + -- | Annotate name with number and increment the number @prefix ⇒ prefix_number@. -makeName :: LIdent -> Rn LIdent -makeName (LIdent prefix) = do - i <- gets var_counter - let name = LIdent $ prefix ++ "_" ++ show i - modify $ \cxt -> cxt{var_counter = succ cxt.var_counter} - pure name +makeName :: String -> Rn String +makeName prefix = do + i <- gets var_counter + let name = prefix ++ "_" ++ show i + modify $ \cxt -> cxt { var_counter = succ cxt.var_counter} + pure name nextNameTVar :: TVar -> Rn TVar -nextNameTVar (MkTVar (LIdent s)) = do - i <- gets tvar_counter - let tvar = MkTVar $ coerce $ s ++ "_" ++ show i - modify $ \cxt -> cxt{tvar_counter = succ cxt.tvar_counter} - pure tvar +nextNameTVar (MkTVar (LIdent s))= do + i <- gets tvar_counter + let tvar = MkTVar . 
LIdent $ s ++ "_" ++ show i + modify $ \cxt -> cxt { tvar_counter = succ cxt.tvar_counter} + pure tvar diff --git a/src/Renamer/RenamerOld.hs b/src/Renamer/RenamerOld.hs new file mode 100644 index 0000000..bf21c9f --- /dev/null +++ b/src/Renamer/RenamerOld.hs @@ -0,0 +1,206 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedRecordDot #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} +{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} + +{-# HLINT ignore "Use mapAndUnzipM" #-} + +module Renamer.Renamer (rename) where + +import Auxiliary (mapAccumM) +import Control.Applicative (Applicative (liftA2)) +import Control.Monad (foldM) +import Control.Monad.Except (ExceptT, MonadError, runExceptT, + throwError) +import Control.Monad.Identity (Identity, runIdentity) +import Control.Monad.State (MonadState, StateT, evalStateT, gets, + modify) +import Data.Coerce (coerce) +import Data.Function (on) +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Maybe (fromMaybe) +import Data.Tuple.Extra (dupe) +import Grammar.Abs + +-- | Rename all variables and local binds +rename :: Program -> Either String Program +rename (Program defs) = Program <$> renameDefs defs + +renameDefs :: [Def] -> Either String [Def] +renameDefs defs = runIdentity $ runExceptT $ evalStateT (runRn $ mapM renameDef defs) initCxt + where + initNames = Map.fromList [dupe (coerce name) | DBind (Bind name _ _) <- defs] + + renameDef :: Def -> Rn Def + renameDef = \case + DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ + DBind bind -> DBind . snd <$> renameBind initNames bind + DData (Data (TData cname types) constrs) -> do + tvars_ <- tvars + tvars' <- mapM nextNameTVar tvars_ + let tvars_lt = zip tvars_ tvars' + typ' = map (substituteTVar tvars_lt) types + constrs' = map (renameConstr tvars_lt) constrs + pure . 
DData $ Data (TData cname typ') constrs' + where + tvars = concat <$> mapM (collectTVars []) types + collectTVars :: [TVar] -> Type -> Rn [TVar] + collectTVars tvars = \case + TAll tvar t -> collectTVars (tvar : tvars) t + TData _ _ -> return tvars + -- Should be monad error + TVar v -> return [v] + _ -> throwError ("Bad data type definition: " ++ show types) + DData (Data types _) -> throwError ("Bad data type definition: " ++ show types) + + renameConstr :: [(TVar, TVar)] -> Inj -> Inj + renameConstr new_types (Inj name typ) = + Inj name $ substituteTVar new_types typ + +renameBind :: Names -> Bind -> Rn (Names, Bind) +renameBind old_names (Bind name vars rhs) = do + (new_names, vars') <- newNames old_names (coerce vars) + (newer_names, rhs') <- renameExp new_names rhs + pure (newer_names, Bind name (coerce vars') rhs') + +substituteTVar :: [(TVar, TVar)] -> Type -> Type +substituteTVar new_names typ = case typ of + TLit _ -> typ + TVar tvar + | Just tvar' <- lookup tvar new_names -> + TVar tvar' + | otherwise -> + typ + TFun t1 t2 -> on TFun substitute' t1 t2 + TAll tvar t + | Just tvar' <- lookup tvar new_names -> + TAll tvar' $ substitute' t + | otherwise -> + TAll tvar $ substitute' t + TData name typs -> TData name $ map substitute' typs + _ -> error ("Impossible " ++ show typ) + where + substitute' = substituteTVar new_names + +initCxt :: Cxt +initCxt = Cxt 0 0 + +data Cxt = Cxt + { var_counter :: Int + , tvar_counter :: Int + } + +-- | Rename monad. State holds the number of renamed names. +newtype Rn a = Rn {runRn :: StateT Cxt (ExceptT String Identity) a} + deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String) + +-- | Maps old to new name +type Names = Map LIdent LIdent + +renameExp :: Names -> Exp -> Rn (Names, Exp) +renameExp old_names = \case + EVar n -> pure (coerce old_names, EVar . 
fromMaybe n $ Map.lookup n old_names) + EInj n -> pure (old_names, EInj n) + ELit lit -> pure (old_names, ELit lit) + EApp e1 e2 -> do + (env1, e1') <- renameExp old_names e1 + (env2, e2') <- renameExp old_names e2 + pure (Map.union env1 env2, EApp e1' e2') + EAdd e1 e2 -> do + (env1, e1') <- renameExp old_names e1 + (env2, e2') <- renameExp old_names e2 + pure (Map.union env1 env2, EAdd e1' e2') + + -- TODO fix shadowing + ELet bind e -> do + (new_names, bind') <- renameBind old_names bind + (new_names', e') <- renameExp new_names e + pure (new_names', ELet bind' e') + EAbs par e -> do + (new_names, par') <- newName old_names (coerce par) + (new_names', e') <- renameExp new_names e + pure (new_names', EAbs (coerce par') e') + EAnn e t -> do + (new_names, e') <- renameExp old_names e + t' <- renameTVars t + pure (new_names, EAnn e' t') + ECase e injs -> do + (new_names, e') <- renameExp old_names e + (new_names', injs') <- renameBranches new_names injs + pure (new_names', ECase e' injs') + +renameBranches :: Names -> [Branch] -> Rn (Names, [Branch]) +renameBranches ns xs = do + (new_names, xs') <- unzip <$> mapM (renameBranch ns) xs + if null new_names then return (mempty, xs') else return (head new_names, xs') + +renameBranch :: Names -> Branch -> Rn (Names, Branch) +renameBranch ns (Branch init e) = do + (new_names, init') <- renamePattern ns init + (new_names', e') <- renameExp new_names e + return (new_names', Branch init' e') + +renamePattern :: Names -> Pattern -> Rn (Names, Pattern) +renamePattern ns i = case i of + PInj cs ps -> do + (ns_new, ps) <- renamePatterns ns ps + return (ns_new, PInj cs ps) + rest -> return (ns, rest) + +renamePatterns :: Names -> [Pattern] -> Rn (Names, [Pattern]) +renamePatterns ns xs = do + (new_names, xs') <- unzip <$> mapM (renamePattern ns) xs + if null new_names then return (mempty, xs') else return (head new_names, xs') + +renameTVars :: Type -> Rn Type +renameTVars typ = case typ of + TAll tvar t -> do + tvar' <- 
nextNameTVar tvar + t' <- renameTVars $ substitute tvar tvar' t + pure $ TAll tvar' t' + TFun t1 t2 -> liftA2 TFun (renameTVars t1) (renameTVars t2) + _ -> pure typ + +substitute :: + TVar -> -- α + TVar -> -- α_n + Type -> -- A + Type -- [α_n/α]A +substitute tvar1 tvar2 typ = case typ of + TLit _ -> typ + TVar tvar' + | tvar' == tvar1 -> TVar tvar2 + | otherwise -> typ + TFun t1 t2 -> on TFun substitute' t1 t2 + TAll tvar t -> TAll tvar $ substitute' t + TData name typs -> TData name $ map substitute' typs + _ -> error "Impossible" + where + substitute' = substitute tvar1 tvar2 + +-- | Create a new name and add it to name environment. +newName :: Names -> LIdent -> Rn (Names, LIdent) +newName env old_name = do + new_name <- makeName old_name + pure (Map.insert old_name new_name env, new_name) + +-- | Create multiple names and add them to the name environment +newNames :: Names -> [LIdent] -> Rn (Names, [LIdent]) +newNames = mapAccumM newName + +-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@. 
+makeName :: LIdent -> Rn LIdent
+makeName (LIdent prefix) = do
+    i <- gets var_counter
+    let name = LIdent $ prefix ++ "_" ++ show i
+    modify $ \cxt -> cxt{var_counter = succ cxt.var_counter}
+    pure name
+
+nextNameTVar :: TVar -> Rn TVar
+nextNameTVar (MkTVar (LIdent s)) = do
+    i <- gets tvar_counter
+    let tvar = MkTVar $ coerce $ s ++ "_" ++ show i
+    modify $ \cxt -> cxt{tvar_counter = succ cxt.tvar_counter}
+    pure tvar
diff --git a/src/TypeChecker/RemoveTEVar.hs b/src/TypeChecker/RemoveTEVar.hs
new file mode 100644
index 0000000..b83a134
--- /dev/null
+++ b/src/TypeChecker/RemoveTEVar.hs
@@ -0,0 +1,76 @@
+{-# LANGUAGE LambdaCase #-}
+
+module TypeChecker.RemoveTEVar where
+
+import Control.Applicative (Applicative (liftA2), liftA3)
+import Control.Arrow (Arrow (second))
+import Control.Monad.Except (MonadError (throwError))
+import Data.Coerce (coerce)
+import Data.Function (on)
+import Data.Tuple.Extra (secondM)
+import Grammar.Abs
+import Grammar.ErrM (Err)
+import qualified TypeChecker.TypeCheckerIr as T
+
+-- | Rebuild a tree annotated with 'Type' as the same tree annotated with
+-- 'T.Type'. The conversion fails in 'Err' if a 'TEVar' is still present.
+class RemoveTEVar a b where
+    rmTEVar :: a -> Err b
+
+instance RemoveTEVar (T.Program' Type) (T.Program' T.Type) where
+    rmTEVar (T.Program defs) = T.Program <$> rmTEVar defs
+
+instance RemoveTEVar (T.Def' Type) (T.Def' T.Type) where
+    rmTEVar = \case
+        T.DBind bind -> T.DBind <$> rmTEVar bind
+        T.DData dat  -> T.DData <$> rmTEVar dat
+
+instance RemoveTEVar (T.Bind' Type) (T.Bind' T.Type) where
+    rmTEVar (T.Bind id vars rhs) = liftA3 T.Bind (rmTEVar id) (rmTEVar vars) (rmTEVar rhs)
+
+instance RemoveTEVar (T.Exp' Type) (T.Exp' T.Type) where
+    rmTEVar exp = case exp of
+        T.EVar name        -> pure $ T.EVar name
+        T.EInj name        -> pure $ T.EInj name
+        T.ELit lit         -> pure $ T.ELit lit
+        T.ELet bind e      -> liftA2 T.ELet (rmTEVar bind) (rmTEVar e)
+        T.EApp e1 e2       -> liftA2 T.EApp (rmTEVar e1) (rmTEVar e2)
+        T.EAdd e1 e2       -> liftA2 T.EAdd (rmTEVar e1) (rmTEVar e2) -- keep the node an addition, not an application
+        T.EAbs name e      -> T.EAbs name <$> rmTEVar e
+        T.ECase e branches -> liftA2 T.ECase (rmTEVar e) (rmTEVar branches)
+
+instance RemoveTEVar (T.Branch' Type) (T.Branch' T.Type) where
+    rmTEVar (T.Branch (patt, t_patt) e) = liftA2 T.Branch (liftA2 (,) (rmTEVar patt) (rmTEVar t_patt)) (rmTEVar e)
+
+instance RemoveTEVar (T.Pattern' Type) (T.Pattern' T.Type) where
+    rmTEVar = \case
+        T.PVar (name, t) -> T.PVar . (name,) <$> rmTEVar t
+        T.PLit (lit, t)  -> T.PLit . (lit,) <$> rmTEVar t
+        T.PCatch         -> pure T.PCatch
+        T.PEnum name     -> pure $ T.PEnum name
+        T.PInj name ps   -> T.PInj name <$> rmTEVar ps
+
+instance RemoveTEVar (T.Data' Type) (T.Data' T.Type) where
+    rmTEVar (T.Data typ injs) = liftA2 T.Data (rmTEVar typ) (rmTEVar injs)
+
+instance RemoveTEVar (T.Inj' Type) (T.Inj' T.Type) where
+    rmTEVar (T.Inj name typ) = T.Inj name <$> rmTEVar typ
+
+instance RemoveTEVar (T.Id' Type) (T.Id' T.Type) where
+    rmTEVar = secondM rmTEVar
+
+instance RemoveTEVar (T.ExpT' Type) (T.ExpT' T.Type) where
+    rmTEVar (exp, typ) = liftA2 (,) (rmTEVar exp) (rmTEVar typ)
+
+instance RemoveTEVar a b => RemoveTEVar [a] [b] where
+    rmTEVar = mapM rmTEVar
+
+-- Type-level conversion: every constructor maps one-to-one except 'TEVar', which is an error.
+instance RemoveTEVar Type T.Type where
+    rmTEVar = \case
+        TLit lit        -> pure $ T.TLit (coerce lit)
+        TVar tvar       -> pure $ T.TVar tvar
+        TData name typs -> T.TData (coerce name) <$> rmTEVar typs
+        TFun t1 t2      -> liftA2 T.TFun (rmTEVar t1) (rmTEVar t2)
+        TAll tvar t     -> T.TAll tvar <$> rmTEVar t
+        TEVar _         -> throwError "NewType TEVar!"
diff --git a/src/TypeChecker/TypeCheckerBidir.hs b/src/TypeChecker/TypeCheckerBidir.hs new file mode 100644 index 0000000..7cb0081 --- /dev/null +++ b/src/TypeChecker/TypeCheckerBidir.hs @@ -0,0 +1,858 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedRecordDot #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} +{-# OPTIONS_GHC -Wno-incomplete-patterns #-} + +module TypeChecker.TypeCheckerBidir (typecheck, getVars) where + +import Auxiliary (maybeToRightM, snoc) +import Control.Applicative (Alternative, Applicative (liftA2), + (<|>)) +import Control.Monad.Except (ExceptT, MonadError (throwError), + runExceptT, unless, zipWithM, + zipWithM_) +import Control.Monad.State (MonadState (get, put), State, + evalState, gets, modify) +import Data.Coerce (coerce) +import Data.Function (on) +import Data.List (intercalate) +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Maybe (fromMaybe, isNothing) +import Data.Sequence (Seq (..)) +import qualified Data.Sequence as S +import Data.Tuple.Extra (second, secondM) +import Debug.Trace (trace) +import Grammar.Abs +import Grammar.ErrM +import Grammar.Print (printTree) +import Prelude hiding (exp, id) +import qualified TypeChecker.TypeCheckerIr as T + +-- Implementation is derived from the paper (Dunfield and Krishnaswami 2013) +-- https://doi.org/10.1145/2500365.2500582 + +data EnvElem = EnvVar LIdent Type -- ^ Term variable typing. x : A + | EnvTVar TVar -- ^ Universal type variable. α + | EnvTEVar TEVar -- ^ Existential unsolved type variable. ά + | EnvTEVarSolved TEVar Type -- ^ Existential solved type variable. ά = τ + | EnvMark TEVar -- ^ Scoping Marker. 
▶ ά + deriving (Eq, Show) + +type Env = Seq EnvElem + +-- | Ordered context +-- Γ ::= ・| Γ, α | Γ, ά | Γ, ▶ ά | Γ, x:A +data Cxt = Cxt + { env :: Env -- ^ Local scope context Γ + , sig :: Map LIdent Type -- ^ Top-level signatures x : A + , binds :: Map LIdent Exp -- ^ Top-level binds x : e + , next_tevar :: Int -- ^ Counter to distinguish ά + , data_injs :: Map UIdent Type -- ^ Data injections (constructors) K + } deriving (Show, Eq) + +newtype Tc a = Tc { runTc :: ExceptT String (State Cxt) a } + deriving (Functor, Applicative, Monad, Alternative, MonadState Cxt, MonadError String) + +typecheck :: Program -> Err (T.Program' Type) +typecheck (Program defs) = do + datatypes <- mapM typecheckDataType [ d | DData d <- defs ] + + + let initCxt = Cxt + { env = mempty + , sig = Map.fromList [ (name, t) + | DSig' name t <- defs + ] + , binds = Map.fromList [ (name, foldr EAbs rhs vars) + | DBind' name vars rhs <- defs + ] + , next_tevar = 0 + , data_injs = Map.fromList [ (name, typ) + | Data _ injs <- datatypes + , Inj name typ <- injs + ] + } + + binds' <- evalState (runExceptT (runTc $ mapM typecheckBind binds)) initCxt; + pure . T.Program $ map T.DData (coerceData datatypes) ++ map T.DBind binds' + where + binds = [ b | DBind b <- defs ] + coerceData = map (\(Data t injs) -> T.Data t $ map + (\(Inj name typ) -> T.Inj (coerce name) typ) injs) + + +typecheckBind :: Bind -> Tc (T.Bind' Type) +typecheckBind (Bind name vars rhs) = do + bind' <- lookupSig name >>= \case + -- TODO These Judgment aren't accurate + -- (f:A → B) ∈ Γ + -- Γ,(xs:A) ⊢ e ↑ Β ⊣ Δ + --------------------------- + -- Γ ⊢ f xs = e ↓ Α → B ⊣ Δ + Just t -> do + (rhs', _) <- check (foldr EAbs rhs vars) t + pure (T.Bind (coerce name, t) (coerce vars') (rhs', t)) + where + vars' = zip vars $ getVars t + + -- Γ ⊢ (λxs. 
e) ↓ A → B ⊣ Δ
+        -- ------------------------------
+        -- Γ ⊢ f xs = e ↓ [Γ]A → [Γ]B ⊣ Δ
+        Nothing -> do
+            (e, t) <- infer $ foldr EAbs rhs vars
+            t' <- applyEnv t
+            e' <- applyEnvExp e
+            let rhs' = skipLambdas (length vars) e'
+                vars' = zip vars $ getVars t'
+            pure (T.Bind (coerce name, t') (coerce vars') (rhs', t'))
+    env <- gets env
+    unless (isComplete env) err
+    putEnv Empty
+    pure bind'
+  where
+    err = throwError $ unlines
+        [ "Type inference failed: " ++ printTree (Bind name vars rhs)
+        , "Did you forget to add type annotation to a polymorphic function?"
+        ]
+
+typecheckDataType :: Data -> Err Data
+typecheckDataType (Data typ injs) = do
+    (name, tvars) <- go [] typ
+    injs' <- mapM (\i -> typecheckInj i name tvars) injs
+    pure (Data typ injs')
+  where
+    go tvars = \case
+        TAll tvar t -> go (tvar:tvars) t
+        TData name typs
+            | Right tvars' <- mapM toTVar typs
+            , all (`elem` tvars) tvars'
+            -> pure (name, tvars')
+        _ -> throwError $ unwords ["Bad data type definition: ", ppT typ]
+
+typecheckInj :: Inj -> UIdent -> [TVar] -> Err Inj
+typecheckInj (Inj inj_name inj_typ) name tvars
+    | not $ boundTVars tvars inj_typ -- every type variable must be a data parameter or forall-bound
+    = throwError "Unbound type variables"
+    | TData name' typs <- getReturn inj_typ -- result type must be exactly the declared data type
+    , name' == name
+    , Right tvars' <- mapM toTVar typs
+    , tvars' == tvars
+    = pure (Inj inj_name $ foldr TAll inj_typ tvars) -- quantify the injection over the data parameters
+    | otherwise
+    = throwError $ unwords
+        ["Bad type constructor: ", show name
+        , "\nExpected: ", ppT . TData name $ map TVar tvars
+        , "\nActual: ", ppT $ getReturn inj_typ
+        ]
+  where
+    boundTVars :: [TVar] -> Type -> Bool
+    boundTVars tvars' = \case
+        TAll tvar t -> boundTVars (tvar:tvars') t
+        TFun t1 t2 -> on (&&) (boundTVars tvars') t1 t2
+        TVar tvar -> elem tvar tvars'
+        TData _ typs -> all (boundTVars tvars') typs -- recurse with the extended scope so forall-bound variables stay bound
+        TLit _ -> True
+        TEVar _ -> error "TEVar in data type declaration"
+
+---------------------------------------------------------------------------
+-- * Subtyping rules
+---------------------------------------------------------------------------
+
+-- | Γ ⊢ A <: B ⊣ Δ
+-- Under input context Γ, type A is a subtype of B, with output context ∆
+subtype :: Type -> Type -> Tc ()
+subtype t1 t2 = case (t1, t2) of
+
+    (TLit lit1, TLit lit2) | lit1 == lit2 -> pure ()
+
+    -- -------------------- <:Var
+    -- Γ[α] ⊢ α <: α ⊣ Γ[α]
+    (TVar tvar1, TVar tvar2) | tvar1 == tvar2 -> pure ()
+
+    -- -------------------- <:Exvar
+    -- Γ[ά] ⊢ ά <: ά ⊣ Γ[ά]
+    (TEVar tevar1, TEVar tevar2) | tevar1 == tevar2 -> pure ()
+
+    -- Γ ⊢ B₁ <: A₁ ⊣ Θ Θ ⊢ [Θ]A₂ <: [Θ]B₂ ⊣ Δ
+    -- ----------------------------------------- <:→
+    -- Γ ⊢ A₁ → A₂ <: B₁ → B₂ ⊣ Δ
+    (TFun a1 a2, TFun b1 b2) -> do
+        subtype b1 a1
+        a2' <- applyEnv a2
+        b2' <- applyEnv b2
+        subtype a2' b2'
+
+    -- Γ, α ⊢ A <: B ⊣ Δ,α,Θ
+    -- --------------------- <:∀R
+    -- Γ ⊢ A <: ∀α. 
B ⊣ Δ + (a, TAll tvar b) -> do + let env_tvar = EnvTVar tvar + insertEnv env_tvar + subtype a b + dropTrailing env_tvar + + -- Γ,▶ ά,ά ⊢ [ά/α]A <: B ⊣ Δ,▶ ά,Θ + -- ------------------------------- <:∀L + -- Γ ⊢ ∀α.A <: B ⊣ Δ + (TAll tvar a, b) -> do + tevar <- fresh + let env_marker = EnvMark tevar + env_tevar = EnvTEVar tevar + insertEnv env_marker + insertEnv env_tevar + let a' = substitute tvar tevar a + subtype a' b + dropTrailing env_marker + + -- ά ∉ FV(A) Γ[ά] ⊢ ά :=< A ⊣ Δ + -- ------------------------------ <:instantiateL + -- Γ[ά] ⊢ ά <: A ⊣ Δ + (TEVar tevar, typ) | notElem tevar $ frees typ -> instantiateL tevar typ + + -- ά ∉ FV(A) Γ[ά] ⊢ A =:< ά ⊣ Δ + -- ------------------------------ <:instantiateR + -- Γ[ά] ⊢ A <: ά ⊣ Δ + (typ, TEVar tevar) | notElem tevar $ frees typ -> instantiateR typ tevar + + (TData name1 typs1, TData name2 typs2) + + -- D₁ = D₂ + -- ---------------- + -- Γ ⊢ D₁ () <: D₂ () + | name1 == name2 + , [] <- typs1 + , [] <- typs2 + -> pure () + + -- Γ ⊢ ά₁ <: έ₁ ⊣ Θ₁ + -- ... 
+ -- D₁ = D₂ Θₙ₋₁ ⊢ [Θₙ₋₁]άₙ <: [Θₙ₋₁]έₙ ⊣ Δ + -- ------------------------------------------- + -- Γ ⊢ D (ά₁ ‥ άₙ) <: D (έ₁ ‥ έₙ) ⊣ Δ + | name1 == name2 + , t1:t1s <- typs1 + , t2:t2s <- typs2 + -> do + subtype t1 t2 + zipWithM_ go t1s t2s + where + go t1' t2' = do + t1'' <- applyEnv t1' + t2'' <- applyEnv t2' + subtype t1'' t2'' + + _ -> throwError $ unwords ["Types", ppT t1, "and", ppT t2, "doesn't match!"] + +--------------------------------------------------------------------------- +-- * Instantiation rules +--------------------------------------------------------------------------- + +-- | Γ ⊢ ά :=< A ⊣ Δ +-- Under input context Γ, instantiate ά such that ά <: A, with output context ∆ +instantiateL :: TEVar -> Type -> Tc () +instantiateL tevar typ = gets env >>= go + where + go env + + -- Γ ⊢ τ + -- ----------------------------- InstLSolve + -- Γ,ά,Γ' ⊢ ά :=< τ ⊣ Γ,(ά=τ),Γ' + | isMono typ + , (env_l, env_r) <- splitOn (EnvTEVar tevar) env + , Right _ <- wellFormed env_l typ + = putEnv $ (env_l :|> EnvTEVarSolved tevar typ) <> env_r + + | TEVar tevar' <- typ = instReach tevar tevar' + + -- Γ[ά₂ά₁,(ά=ά₁→ά₂)] ⊢ A₁ =:< ά₁ ⊣ Θ Θ ⊢ ά₂ :=< [Θ]A₂ ⊣ Δ + -- ------------------------------------------------------- InstLArr + -- Γ[ά] ⊢ ά :=< A₁ → A₂ ⊣ Δ + | TFun t1 t2 <- typ = do + tevar1 <- fresh + tevar2 <- fresh + insertEnv $ EnvTEVar tevar2 + insertEnv $ EnvTEVar tevar1 + insertEnv $ EnvTEVarSolved tevar (on TFun TEVar tevar1 tevar2) + instantiateR t1 tevar1 + instantiateL tevar2 =<< applyEnv t2 + + -- Γ[ά],ε ⊢ ά :=< E ⊣ Δ,ε,Δ' + -- ------------------------- InstLAIIR + -- Γ[ά] ⊢ ά :=< ∀ε.Ε ⊣ Δ + | TAll tvar t <- typ = do + instantiateL tevar t + let (env_l, _) = splitOn (EnvTVar tvar) env + putEnv env_l + + | otherwise = error $ "Trying to instantiateL: " ++ ppT (TEVar tevar) + ++ " <: " ++ ppT typ + +-- | Γ ⊢ A =:< ά ⊣ Δ +-- Under input context Γ, instantiate ά such that A <: ά, with output context ∆ +instantiateR :: Type -> TEVar -> Tc () +instantiateR typ tevar = 
gets env >>= go + where + go env + + -- Γ ⊢ τ + -- ----------------------------- InstRSolve + -- Γ,ά,Γ' ⊢ τ =:< ά ⊣ Γ,(ά=τ),Γ' + | isMono typ + , (env_l, env_r) <- splitOn (EnvTEVar tevar) env + , Right _ <- wellFormed env_l typ + = putEnv $ (env_l :|> EnvTEVarSolved tevar typ) <> env_r + + | TEVar tevar' <- typ = instReach tevar tevar' + + -- Γ[ά₂ά₁,(ά=ά₁→ά₂)] ⊢ A₁ :=< ά₁ ⊣ Θ Θ ⊢ ά₂ =:< [Θ]A₂ ⊣ Δ + -- ------------------------------------------------------- InstRArr + -- Γ[ά] ⊢ ά =:< A₁ → A₂ ⊣ Δ + | TFun t1 t2 <- typ = do + tevar1 <- fresh + tevar2 <- fresh + insertEnv $ EnvTEVar tevar2 + insertEnv $ EnvTEVar tevar1 + insertEnv $ EnvTEVarSolved tevar (on TFun TEVar tevar1 tevar2) + instantiateL tevar1 t1 + t2' <- applyEnv t2 + instantiateR t2' tevar2 + + -- Γ[ά],▶έ,ε ⊢ [έ/ε]E =:< ά ⊣ Δ,▶έ,Δ' + -- ---------------------------------- InstRAIIL + -- Γ[ά] ⊢ ∀ε.Ε =:< ά ⊣ Δ + | TAll tvar t <- typ = do + tevar' <- fresh + insertEnv $ EnvMark tevar' + insertEnv $ EnvTVar tvar + let t' = substitute tvar tevar' t + instantiateR t' tevar + let (env_l, _) = splitOn (EnvTVar tvar) env + putEnv env_l + + + | otherwise = error $ "Trying to instantiateR: " ++ ppT typ ++ " <: " + ++ ppT (TEVar tevar) + + +-- ----------------------------- InstLReach +-- Γ[ά][έ] ⊢ ά :=< έ ⊣ Γ[ά][έ=ά] +-- +-- ----------------------------- InstRReach +-- Γ[ά][έ] ⊢ έ =:< ά ⊣ Γ[ά][έ=ά] +instReach :: TEVar -> TEVar -> Tc () +instReach tevar tevar' = do + (env_l, env_r) <- gets (splitOn (EnvTEVar tevar') . 
env) + let env_solved = EnvTEVarSolved tevar' $ TEVar tevar + putEnv $ (env_l :|> env_solved) <> env_r + +--------------------------------------------------------------------------- +-- * Typing rules +--------------------------------------------------------------------------- + +-- | Γ ⊢ e ↑ A ⊣ Δ +-- Under input context Γ, e checks against input type A, with output context ∆ +check :: Exp -> Type -> Tc (T.ExpT' Type) +check exp typ + + -- Γ,α ⊢ e ↑ A ⊣ Δ,α,Θ + -- ------------------- ∀I + -- Γ ⊢ e ↑ ∀α.A ⊣ Δ + | TAll tvar t <- typ = do + let env_tvar = EnvTVar tvar + insertEnv env_tvar + exp' <- check exp t + (env_l, _) <- gets (splitOn env_tvar . env) + putEnv env_l + pure exp' + + -- Γ,(x:A) ⊢ e ↑ B ⊢ Δ,(x:A),Θ + -- --------------------------- →I + -- Γ ⊢ λx.e ↑ A → B ⊣ Δ + | EAbs name e <- exp + , TFun t1 t2 <- typ = do + let env_id = EnvVar name t1 + insertEnv env_id + e' <- check e t2 + (env_l, _) <- gets (splitOn env_id . env) + putEnv env_l + pure (T.EAbs (coerce name) e', typ) + + -- Θ ⊢ Π ∷ [Θ]A ↑ [Θ]C ⊣ Δ + -- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO + -- --------------------------------------- + -- Γ ⊢ case e of Π ↑ C ⊣ Δ + | ECase scrut branches <- exp = do + (scrut', t_scrut) <- infer scrut + t_scrut' <- applyEnv t_scrut + typ' <- applyEnv typ + branches' <- mapM (\b -> checkBranch b t_scrut' typ') branches + pure (T.ECase (scrut', t_scrut') branches', typ') + + | otherwise = subsumption + where + -- Γ,α ⊢ e ↓ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]B ⊣ Δ + -- -------------------------------------- Sub + -- Γ ⊢ e ↑ B ⊣ Δ + subsumption = do + (exp', t) <- infer exp + exp'' <- applyEnvExp exp' + t' <- applyEnv t + typ' <- applyEnv typ + subtype t' typ' + pure (exp'', t') + +-- | Γ ⊢ e ↓ A ⊣ Δ +-- Under input context Γ, e infers output type A, with output context ∆ +infer :: Exp -> Tc (T.ExpT' Type) +infer = \case + + ELit lit -> pure (T.ELit lit, inferLit lit) + + -- (x : A) ∈ Γ + -- ------------- Var + -- Γ ⊢ x ↓ A ⊣ Γ + EVar name -> do + t <- liftA2 (<|>) (lookupEnv name) 
(lookupSig name) >>= \case + Just t -> pure t + Nothing -> do + e <- maybeToRightM + ("Unbound variable " ++ show name) + =<< lookupBind name + snd <$> infer e + pure (T.EVar (coerce name), t) + + EInj name -> do + t <- maybeToRightM ("Unknown constructor: " ++ show name) =<< lookupInj name + pure (T.EInj $ coerce name, t) + + -- Γ ⊢ A Γ ⊢ e ↑ A ⊣ Δ + -- --------------------- Anno + -- Γ ⊢ (e : A) ↓ A ⊣ Δ + EAnn e t -> do + _ <- gets $ (`wellFormed` t) . env + (e', _) <- check e t + pure (e', t) + + -- Γ ⊢ e₁ ↓ A ⊣ Θ Γ ⊢ [Θ]A • ⇓ C ⊣ Δ + -- ----------------------------------- →E + -- Γ ⊢ e₁ e₂ ↓ C ⊣ Δ + EApp e1 e2 -> do + (e1', t) <- infer e1 + t' <- applyEnv t + e1'' <- applyEnvExp e1' + (e2', t'') <- apply t' e2 + pure (T.EApp (e1'', t) e2', t'') + + -- Γ,ά,έ,(x:ά) ⊢ e ↑ έ ⊣ Δ,(x:ά),Θ + -- ------------------------------- →I + -- Γ ⊢ λx.e ↓ ά → έ ⊣ Δ + EAbs name e -> do + tevar1 <- fresh + tevar2 <- fresh + insertEnv $ EnvTEVar tevar1 + insertEnv $ EnvTEVar tevar2 + let env_id = EnvVar name (TEVar tevar1) + insertEnv env_id + e' <- check e $ TEVar tevar2 + dropTrailing env_id + let t_exp = on TFun TEVar tevar1 tevar2 + pure (T.EAbs (coerce name) e', t_exp) + + + -- Γ ⊢ e ↓ A ⊣ Θ Θ,(x:A) ⊢ e' ↑ C ⊣ Δ,(x:A),Θ + -- -------------------------------------------- LetI + -- Γ ⊢ let x=e in e' ↑ C ⊣ Δ + ELet (Bind name [] rhs) e -> do -- TODO vars + (rhs', t_rhs) <- infer rhs + let env_id = EnvVar name t_rhs + insertEnv env_id + (e', t) <- infer e + (env_l, _) <- gets (splitOn env_id . 
env) + putEnv env_l + pure (T.ELet (T.Bind (coerce name, t_rhs) [] (rhs', t_rhs)) (e',t), t) + + -- Γ ⊢ e₁ ↑ Int Γ ⊢ e₁ ↑ Int + -- --------------------------- +I + -- Γ ⊢ e₁ + e₂ ↓ Int ⊣ Δ + EAdd e1 e2 -> do + cxt <- get + let t = TLit "Int" + e1' <- check e1 t + put cxt + e2' <- check e2 t + pure (T.EAdd e1' e2', t) + +-- | Γ ⊢ A • e ⇓ C ⊣ Δ +-- Under input context Γ , applying a function of type A to e infers type C, with output context ∆ +-- Instantiate existential type variables until there is an arrow type. +apply :: Type -> Exp -> Tc (T.ExpT' Type, Type) +apply typ exp = case typ of + + -- Γ,ά ⊢ [ά/α]A • e ⇓ C ⊣ Δ + -- ------------------------ ∀App + -- Γ ⊢ ∀α.A • e ⇓ C ⊣ Δ + TAll tvar t -> do + tevar <- fresh + insertEnv $ EnvTEVar tevar + let t' = substitute tvar tevar t + apply t' exp + + -- Γ[ά₂,ά₁,(ά=ά₁→ά₂)] ⊢ e ↑ ά₁ ⊣ Δ + -- ------------------------------- άApp + -- Γ[ά] ⊢ ά • e ⇓ ά₂ ⊣ Δ + TEVar tevar -> do + tevar1 <- fresh + tevar2 <- fresh + let env_tevar1 = EnvTEVar tevar1 + env_tevar2 = EnvTEVar tevar2 + t_fun = on TFun TEVar tevar1 tevar2 + env_tevar_solved = EnvTEVarSolved tevar t_fun + (env_l, env_r) <- gets (splitOn (EnvTEVar tevar) . 
env) + putEnv $ + (env_l :|> env_tevar2 :|> env_tevar1 :|> env_tevar_solved) <> env_r + expT' <- check exp $ TEVar tevar1 + pure (expT', TEVar tevar2) + + -- Γ ⊢ e ↑ A ⊣ Δ + -- --------------------- →App + -- Γ ⊢ A → C • e ⇓ C ⊣ Δ + TFun t1 t2 -> do + expt' <- check exp t1 + pure (expt', t2) + + _ -> throwError ("Cannot apply type " ++ show typ ++ " with expression " ++ show exp) + +--------------------------------------------------------------------------- +-- * Pattern matching +--------------------------------------------------------------------------- + +-- | Γ ⊢ p ⇒ e ∷ A ↑ C +-- Under context Γ, check branch p ⇒ e of type A and bodies of type C +checkBranch :: Branch -> Type -> Type -> Tc (T.Branch' Type) +checkBranch (Branch patt exp) t_patt t_exp = do + env_marker <- EnvMark <$> fresh + insertEnv env_marker + patt' <- checkPattern patt t_patt + t_exp' <- applyEnv t_exp + (exp, t_exp) <- check exp t_exp' + (env_l, _) <- gets (splitOn env_marker . env) + putEnv env_l + pure (T.Branch patt' (exp, t_exp)) + +checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type) +checkPattern patt t_patt = case patt of + PVar x -> do + insertEnv $ EnvVar x t_patt + pure (T.PVar (coerce x, dummy), dummy) -- TODO + PCatch -> pure (T.PCatch, dummy) -- TODO + PLit lit | inferLit lit == t_patt -> let + t = inferLit lit + in + pure (T.PLit (lit, t), t) + | otherwise -> throwError "Literal in pattern have wrong type" + + PEnum name -> do + t <- maybeToRightM ("Unknown constructor " ++ show name) + =<< lookupInj name + subtype t t_patt + pure (T.PEnum (coerce name), dummy) -- TODO + + PInj name ps -> do + t <- maybeToRightM ("Unknown constructor " ++ show name) + =<< lookupInj name + let (t_ps, t_return) = partitionTypeWithForall t + unless (length ps == length t_ps) $ + throwError "Wrong number of variables" + subtype t_return t_patt + ps' <- zipWithM (\p t -> checkPattern p =<< applyEnv t) ps t_ps + let ps'' = map fst ps' -- TODO + pure (T.PInj (coerce name) ps'', dummy) + 
+--------------------------------------------------------------------------- +-- * Auxiliary +--------------------------------------------------------------------------- + +frees :: Type -> [TEVar] +frees = \case + TLit _ -> [] + TVar _ -> [] + TEVar tevar -> [tevar] + TFun t1 t2 -> on (++) frees t1 t2 + TAll _ t -> frees t + TData _ typs -> concatMap frees typs + +-- | [ά/α]A +substitute :: TVar -- α + -> TEVar -- ά + -> Type -- A + -> Type -- [ά/α]A +substitute tvar tevar typ = case typ of + TLit _ -> typ + TVar tvar' | tvar' == tvar -> TEVar tevar + | otherwise -> typ + TEVar _ -> typ + TFun t1 t2 -> on TFun substitute' t1 t2 + TAll tvar' t -> TAll tvar' (substitute' t) + TData name typs -> TData name $ map substitute' typs + where + substitute' = substitute tvar tevar + +-- | Γ,x,Γ' → (Γ, Γ') +splitOn :: EnvElem -> Env -> (Env, Env) +splitOn x env = second (S.drop 1) $ S.breakl (==x) env + +-- | Drop frontmost elements until and including element @x@. +dropTrailing :: EnvElem -> Tc () +dropTrailing x = modifyEnv $ S.takeWhileL (/= x) + +applyEnvExp :: T.Exp' Type -> Tc (T.Exp' Type) +applyEnvExp exp = case exp of + T.ELet (T.Bind id vars rhs) exp -> do + id <- applyEnvId id + vars' <- mapM applyEnvId vars + rhs' <- applyEnvExpT rhs + exp' <- applyEnvExpT exp + pure $ T.ELet (T.Bind id vars' rhs') exp' + T.EApp e1 e2 -> liftA2 T.EApp (applyEnvExpT e1) (applyEnvExpT e2) + T.EAdd e1 e2 -> liftA2 T.EAdd (applyEnvExpT e1) (applyEnvExpT e2) + T.EAbs name e -> T.EAbs name <$> applyEnvExpT e + T.ECase e branches -> liftA2 T.ECase (applyEnvExpT e) + (mapM applyEnvBranch branches) + _ -> pure exp + where + applyEnvExpT (e, t) = liftA2 (,) (applyEnvExp e) (applyEnv t) + applyEnvId = secondM applyEnv + applyEnvBranch (T.Branch (p, t) e) = do + pt <- liftA2 (,) (applyEnvPattern p) (applyEnv t) + e' <- applyEnvExpT e + pure $ T.Branch pt e' + applyEnvPattern = \case + T.PVar id -> T.PVar <$> applyEnvId id + T.PLit (lit, t) -> T.PLit . 
(lit, ) <$> applyEnv t + T.PInj name ps -> T.PInj name <$> mapM applyEnvPattern ps + p -> pure p + +applyEnv :: Type -> Tc Type +applyEnv t = gets $ (`applyEnv'` t) . env + +-- | [Γ]A. Applies context to type until fully applied. +applyEnv' :: Env -> Type -> Type +applyEnv' cxt typ | typ == typ' = typ' + | otherwise = applyEnv' cxt typ' + where + typ' = case typ of + TLit _ -> typ + TData name typs -> TData name $ map (applyEnv' cxt) typs + -- [Γ]α = α + TVar _ -> typ + -- [Γ[ά=τ]]ά = [Γ[ά=τ]]τ + -- [Γ[ά]]ά = [Γ[ά]]ά + TEVar tevar -> fromMaybe typ $ findSolved tevar cxt + -- [Γ](A → B) = [Γ]A → [Γ]B + TFun t1 t2 -> on TFun (applyEnv' cxt) t1 t2 + -- [Γ](∀α. A) = (∀α. [Γ]A) + TAll tvar t -> TAll tvar $ applyEnv' cxt t + +findSolved :: TEVar -> Env -> Maybe Type +findSolved _ Empty = Nothing +findSolved tevar (xs :|> x) = case x of + EnvTEVarSolved tevar' t | tevar == tevar' -> Just t + _ -> findSolved tevar xs + +-- | Γ ⊢ A +-- Under context Γ, type A is well-formed +wellFormed :: Env -> Type -> Err () +wellFormed env = \case + TLit _ -> pure () + + -- -------- UvarWF + -- Γ[α] ⊢ α + TVar tvar -> unless (EnvTVar tvar `elem` env) $ + throwError ("Unbound type variable: " ++ show tvar) + -- Γ ⊢ A Γ ⊢ B + -- ------------- ArrowWF + -- Γ ⊢ A → B + TFun t1 t2 -> do { wellFormed env t1; wellFormed env t2 } + + -- Γ,α ⊢ A + -- -------- ForallWF + -- Γ ⊢ ∀α.A + TAll tvar t -> wellFormed (env :|> EnvTVar tvar) t + + TEVar tevar + -- ---------- EvarWF + -- Γ[ά] ⊢ ά + | EnvTEVar tevar `elem` env -> pure () + + -- ---------- SolvedEvarWF + -- Γ[ά=τ] ⊢ ά + | Just _ <- findSolved tevar env -> pure () + | otherwise -> throwError ("Can't find type: " ++ show tevar) + + TData _ typs -> mapM_ (wellFormed env) typs + +isMono :: Type -> Bool +isMono = \case + TAll{} -> False + TFun t1 t2 -> on (&&) isMono t1 t2 + TData _ typs -> all isMono typs + TVar _ -> True + TEVar _ -> True + TLit _ -> True + +inferLit :: Lit -> Type +inferLit = \case + LInt _ -> TLit "Int" + LChar _ -> TLit 
"Char" + +fresh :: Tc TEVar +fresh = do + tevar <- gets (MkTEVar . LIdent . ("a#" ++) . show . next_tevar) + modify $ \cxt -> cxt { next_tevar = succ cxt.next_tevar } + pure tevar + +getVars :: Type -> [Type] +getVars = fst . partitionType + +getReturn :: Type -> Type +getReturn = snd . partitionType + +-- | Partion type into variable types and return type. +-- +-- ∀a.∀b. a → (∀c. c → c) → b +-- ([a, ∀c. c → c], b) +-- +-- Unsure if foralls should be added to the return type or not. +partitionType :: Type -> ([Type], Type) +partitionType = go [] . skipForalls' + where + + go acc t = case t of + TFun t1 t2 -> go (snoc t1 acc) t2 + _ -> (acc, t) + +skipForalls' :: Type -> Type +skipForalls' = snd . skipForalls + +skipForalls :: Type -> ([Type -> Type], Type) +skipForalls = go [] + where + go acc typ = case typ of + TAll tvar t -> go (snoc (TAll tvar) acc) t + _ -> (acc, typ) + +partitionTypeWithForall :: Type -> ([Type], Type) +partitionTypeWithForall typ = (t_vars', t_return') + where + t_vars' = map (\t -> foldr applyForall t foralls) t_vars + t_return' = foldr applyForall t_return foralls + + applyForall fa t | usesTVar tvar t = fa t + | otherwise = t + where TAll tvar _ = fa t + + (t_vars, t_return) = go [] typ' + (foralls, typ') = skipForalls typ + + + go acc t = case t of + TFun t1 t2 -> go (snoc t1 acc) t2 + _ -> (acc, t) + +usesTVar :: TVar -> Type -> Bool +usesTVar tvar = \case + TLit _ -> False + TVar tvar' | tvar' == tvar -> True + | otherwise -> False + TFun t1 t2 -> on (||) usesTVar' t1 t2 + TAll tvar' t | tvar' == tvar -> error "Redeclaration of TVar" + | otherwise -> usesTVar' t + TData _ typs -> any usesTVar' typs + _ -> error "Impossible" + where + usesTVar' = usesTVar tvar + +skipLambdas :: Int -> T.Exp' Type -> T.Exp' Type +skipLambdas i exp + | i == 0 = exp + | T.EAbs _ (e, _) <- exp = skipLambdas (i-1) e + | otherwise = error "Number of expected lambdas doesn't match expression" + +isComplete :: Env -> Bool +isComplete = isNothing . 
S.findIndexL unSolvedTEVar + where + unSolvedTEVar = \case + EnvTEVar _ -> True + _ -> False + +toTVar :: Type -> Err TVar +toTVar = \case + TVar tvar -> pure tvar + _ -> throwError "Not a type variable" + +insertEnv :: EnvElem -> Tc () +insertEnv x = modifyEnv (:|> x) + +lookupBind :: LIdent -> Tc (Maybe Exp) +lookupBind x = gets (Map.lookup x . binds) + +lookupSig :: LIdent -> Tc (Maybe Type) +lookupSig x = gets (Map.lookup x . sig) + +lookupEnv :: LIdent -> Tc (Maybe Type) +lookupEnv x = gets (findId . env) + where + findId Empty = Nothing + findId (ys :|> y) = case y of + EnvVar x' t | x==x' -> Just t + _ -> findId ys + +lookupInj :: UIdent -> Tc (Maybe Type) +lookupInj x = gets (Map.lookup x . data_injs) + +putEnv :: Env -> Tc () +putEnv = modifyEnv . const + +modifyEnv :: (Env -> Env) -> Tc () +modifyEnv f = + modify $ \cxt -> {- trace (ppEnv (f cxt.env)) -} cxt { env = f cxt.env } + +pattern DBind' name vars exp = DBind (Bind name vars exp) +pattern DSig' name typ = DSig (Sig name typ) + +dummy = TLit "Int" + +--------------------------------------------------------------------------- +-- * Debug +--------------------------------------------------------------------------- + +traceEnv s = do + env <- gets env + trace (s ++ " " ++ show env) pure () + +traceD s x = trace (s ++ " " ++ show x) pure () + +traceT s x = trace (s ++ " " ++ ppT x) pure () + +traceTs s xs = trace (s ++ " [ " ++ intercalate ", " (map ppT xs) ++ " ]") pure () + +ppT = \case + TLit (UIdent s) -> s + TVar (MkTVar (LIdent s)) -> "α_" ++ s + TFun t1 t2 -> ppT t1 ++ "→" ++ ppT t2 + TAll (MkTVar (LIdent s)) t -> "forall " ++ s ++ ". 
" ++ ppT t + TEVar (MkTEVar (LIdent s)) -> "ά_" ++ s + TData (UIdent name) typs -> name ++ " (" ++ unwords (map ppT typs) + ++ " )" +ppEnvElem = \case + EnvVar (LIdent s) t -> s ++ ":" ++ ppT t + EnvTVar (MkTVar (LIdent s)) -> "α_" ++ s + EnvTEVar (MkTEVar (LIdent s)) -> "ά_" ++ s + EnvTEVarSolved (MkTEVar (LIdent s)) t -> "ά_" ++ s ++ "=" ++ ppT t + EnvMark (MkTEVar (LIdent s)) -> "▶" ++ "ά_" ++ s + +ppEnv = \case + Empty -> "·" + (xs :|> x) -> ppEnv xs ++ " (" ++ ppEnvElem x ++ ")" diff --git a/src/TypeChecker/TypeCheckerHm.hs b/src/TypeChecker/TypeCheckerHm.hs index ba07616..adcf033 100644 --- a/src/TypeChecker/TypeCheckerHm.hs +++ b/src/TypeChecker/TypeCheckerHm.hs @@ -1,36 +1,31 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} -- | A module for type checking and inference using algorithm W, Hindley-Milner module TypeChecker.TypeChecker where -import Auxiliary -import Control.Monad.Except -import Control.Monad.Identity (runIdentity) -import Control.Monad.Reader -import Control.Monad.State -import Data.Bifunctor (second) -import Data.Coerce (coerce) -import Data.Foldable (traverse_) -import Data.Function (on) -import Data.List (foldl') -import Data.List.Extra (unsnoc) -import Data.Map (Map) -import Data.Map qualified as M -import Data.Maybe (fromJust) -import Data.Set (Set) -import Data.Set qualified as S -import Debug.Trace (trace) -import Grammar.Abs -import Grammar.Print (printTree) -import TypeChecker.TypeCheckerIr ( - Ctx (..), - Env (..), - Error, - Infer, - Subst, - ) -import TypeChecker.TypeCheckerIr qualified as T +import Auxiliary +import Control.Monad.Except +import Control.Monad.Identity (runIdentity) +import Control.Monad.Reader +import Control.Monad.State +import Data.Bifunctor (second) +import Data.Coerce (coerce) +import Data.Foldable (traverse_) +import Data.Function (on) +import Data.List (foldl') +import Data.List.Extra (unsnoc) +import Data.Map (Map) +import qualified Data.Map as M +import 
Data.Maybe (fromJust) +import Data.Set (Set) +import qualified Data.Set as S +import Debug.Trace (trace) +import Grammar.Abs +import Grammar.Print (printTree) +import qualified TypeChecker.TypeCheckerIr as T +import TypeChecker.TypeCheckerIr (Ctx (..), Env (..), Error, Infer, + Subst) initCtx = Ctx mempty initEnv = Env 0 'a' mempty mempty mempty @@ -78,7 +73,7 @@ checkData d = do retType :: Type -> Type retType (TFun _ t2) = retType t2 -retType a = a +retType a = a checkPrg :: Program -> Infer T.Program checkPrg (Program bs) = do @@ -105,7 +100,7 @@ preRun (x : xs) = case x of s <- gets sigs case M.lookup (coerce n) s of Nothing -> insertSig (coerce n) Nothing >> preRun xs - Just _ -> preRun xs + Just _ -> preRun xs DData d@(Data t _) -> collect (collectTypeVars t) >> checkData d >> preRun xs checkDef :: [Def] -> Infer [T.Def] @@ -152,9 +147,9 @@ typeEq _ _ = False skolem :: T.Type -> T.Type skolem (T.TVar (T.MkTVar a)) = T.TLit a -skolem (T.TAll x t) = T.TAll x (skolem t) -skolem (T.TFun t1 t2) = (T.TFun `on` skolem) t1 t2 -skolem t = t +skolem (T.TAll x t) = T.TAll x (skolem t) +skolem (T.TFun t1 t2) = (T.TFun `on` skolem) t1 t2 +skolem t = t isMoreSpecificOrEq :: T.Type -> T.Type -> Bool isMoreSpecificOrEq t1 (T.TAll _ t2) = isMoreSpecificOrEq t1 t2 @@ -169,8 +164,8 @@ isMoreSpecificOrEq a b = a == b isPoly :: Type -> Bool isPoly (TAll _ _) = True -isPoly (TVar _) = True -isPoly _ = False +isPoly (TVar _) = True +isPoly _ = False inferExp :: Exp -> Infer T.ExpT inferExp e = do @@ -183,7 +178,7 @@ class CollectTVars a where instance CollectTVars Exp where collectTypeVars (EAnn e t) = collectTypeVars t `S.union` collectTypeVars e - collectTypeVars _ = S.empty + collectTypeVars _ = S.empty instance CollectTVars Type where collectTypeVars (TVar (MkTVar i)) = S.singleton (coerce i) @@ -200,15 +195,15 @@ class NewType a b where instance NewType Type T.Type where toNew = \case - TLit i -> T.TLit $ coerce i - TVar v -> T.TVar $ toNew v + TLit i -> T.TLit $ coerce i + 
TVar v -> T.TVar $ toNew v TFun t1 t2 -> (T.TFun `on` toNew) t1 t2 - TAll b t -> T.TAll (toNew b) (toNew t) + TAll b t -> T.TAll (toNew b) (toNew t) TData i ts -> T.TData (coerce i) (map toNew ts) - TEVar _ -> error "Should not exist after typechecker" + TEVar _ -> error "Should not exist after typechecker" instance NewType Lit T.Lit where - toNew (LInt i) = T.LInt i + toNew (LInt i) = T.LInt i toNew (LChar i) = T.LChar i instance NewType Data T.Data where @@ -422,12 +417,12 @@ generalize :: Map T.Ident T.Type -> T.Type -> T.Type generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t) where go :: [T.Ident] -> T.Type -> T.Type - go [] t = t + go [] t = t go (x : xs) t = T.TAll (T.MkTVar x) (go xs t) removeForalls :: T.Type -> T.Type - removeForalls (T.TAll _ t) = removeForalls t + removeForalls (T.TAll _ t) = removeForalls t removeForalls (T.TFun t1 t2) = T.TFun (removeForalls t1) (removeForalls t2) - removeForalls t = t + removeForalls t = t {- | Instantiate a polymorphic type. The free type variables are substituted with fresh ones. 
@@ -477,10 +472,10 @@ instance SubstType T.Type where T.TLit a -> T.TLit a T.TVar (T.MkTVar a) -> case M.lookup a sub of Nothing -> T.TVar (T.MkTVar $ coerce a) - Just t -> t + Just t -> t T.TAll (T.MkTVar i) t -> case M.lookup i sub of Nothing -> T.TAll (T.MkTVar i) (apply sub t) - Just _ -> apply sub t + Just _ -> apply sub t T.TFun a b -> T.TFun (apply sub a) (apply sub b) T.TData name a -> T.TData name (map (apply sub) a) instance FreeVars (Map T.Ident T.Type) where @@ -513,10 +508,10 @@ instance SubstType T.Pattern where apply :: Subst -> T.Pattern -> T.Pattern apply s = \case T.PVar (iden, t) -> T.PVar (iden, apply s t) - T.PLit (lit, t) -> T.PLit (lit, apply s t) - T.PInj i ps -> T.PInj i $ apply s ps - T.PCatch -> T.PCatch - T.PEnum i -> T.PEnum i + T.PLit (lit, t) -> T.PLit (lit, apply s t) + T.PInj i ps -> T.PInj i $ apply s ps + T.PCatch -> T.PCatch + T.PEnum i -> T.PEnum i instance SubstType a => SubstType [a] where apply s = map (apply s) @@ -555,7 +550,7 @@ fresh = do next :: Char -> Char next 'z' = 'a' -next a = succ a +next a = succ a -- | Run the monadic action with an additional binding withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> T.Type -> m a -> m a @@ -608,10 +603,10 @@ inferBranch (Branch pat expr) = do withPattern :: T.Pattern -> Infer a -> Infer a withPattern p ma = case p of T.PVar (x, t) -> withBinding x t ma - T.PInj _ ps -> foldl' (flip withPattern) ma ps - T.PLit _ -> ma - T.PCatch -> ma - T.PEnum _ -> ma + T.PInj _ ps -> foldl' (flip withPattern) ma ps + T.PLit _ -> ma + T.PCatch -> ma + T.PEnum _ -> ma inferPattern :: Pattern -> Infer (T.Pattern, T.Type) inferPattern = \case @@ -659,14 +654,14 @@ inferPattern = \case flattenType :: T.Type -> [T.Type] flattenType (T.TFun a b) = flattenType a <> flattenType b -flattenType a = [a] +flattenType a = [a] typeLength :: T.Type -> Int typeLength (T.TFun a b) = typeLength a + typeLength b -typeLength _ = 1 +typeLength _ = 1 litType :: Lit -> T.Type -litType (LInt _) = int +litType 
(LInt _) = int litType (LChar _) = char int = T.TLit "Int" @@ -681,8 +676,8 @@ partitionType = go [] go acc 0 t = (acc, t) go acc i t = case t of TAll tvar t' -> second (TAll tvar) $ go acc i t' - TFun t1 t2 -> go (acc <> [t1]) (i - 1) t2 - _ -> error "Number of parameters and type doesn't match" + TFun t1 t2 -> go (acc <> [t1]) (i - 1) t2 + _ -> error "Number of parameters and type doesn't match" exprErr :: Infer a -> Exp -> Infer a exprErr ma exp = @@ -695,3 +690,4 @@ unzip4 = (as ++ [a], bs ++ [b], cs ++ [c], ds ++ [d]) ) ([], [], [], []) + diff --git a/src/TypeChecker/TypeCheckerIr.hs b/src/TypeChecker/TypeCheckerIr.hs index 74dc649..d56c14c 100644 --- a/src/TypeChecker/TypeCheckerIr.hs +++ b/src/TypeChecker/TypeCheckerIr.hs @@ -1,245 +1,135 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PatternSynonyms #-} -module TypeChecker.TypeCheckerIr ( - module TypeChecker.TypeCheckerIr, -) where -import Control.Monad.Except -import Control.Monad.Reader -import Control.Monad.State -import Data.Char (isDigit) -import Data.Functor.Identity (Identity) -import Data.Map (Map) -import Data.Set (Set) -import Data.String qualified -import Grammar.Print -import Prelude -import Prelude qualified as C (Eq, Ord, Read, Show) +module TypeChecker.TypeCheckerIr + ( module Grammar.Abs + , module TypeChecker.TypeCheckerIr + ) where -newtype Ctx = Ctx {vars :: Map Ident Type} - deriving (Show) +import Data.String (IsString) +import Grammar.Abs (Character (..), Lit (..), TVar (..)) +import Grammar.Print +import Prelude +import qualified Prelude as C (Eq, Ord, Read, Show) -data Env = Env - { count :: Int - , nextChar :: Char - , sigs :: Map Ident (Maybe Type) - , constructors :: Map Ident Type - , takenTypeVars :: Set Ident - } - deriving (Show) +newtype Program' t = Program [Def' t] + deriving (C.Eq, C.Ord, C.Show, C.Read) -type Error = String -type Subst = Map Ident Type - -type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity)) - -newtype Program = 
Program [Def] - deriving (C.Eq, C.Ord, C.Show, C.Read) - -data Data = Data Ident [Constructor] - deriving (Show, Eq, Ord, Read) - -data Constructor = Constructor Ident Type - deriving (Show, Eq, Ord, Read) - -newtype TVar = MkTVar Ident - deriving (Show, Eq, Ord, Read) +data Def' t = DBind (Bind' t) + | DData (Data' t) + deriving (C.Eq, C.Ord, C.Show, C.Read) data Type = TLit Ident | TVar TVar + | TData Ident [Type] | TFun Type Type | TAll TVar Type - | TData Ident [Type] - deriving (Show, Eq, Ord, Read) + deriving (C.Eq, C.Ord, C.Show, C.Read) -data Exp - = EId Ident - | ELit Lit - | ELet Bind ExpT - | EApp ExpT ExpT - | EAdd ExpT ExpT - | EAbs Ident ExpT - | ECase ExpT [Branch] - deriving (C.Eq, C.Ord, C.Read, C.Show) +data Data' t = Data t [Inj' t] + deriving (C.Eq, C.Ord, C.Show, C.Read) -type ExpT = (Exp, Type) - -data Branch = Branch (Pattern, Type) ExpT - deriving (C.Eq, C.Ord, C.Read, C.Show) - -data Pattern = PVar Id | PLit (Lit, Type) | PInj Ident [Pattern] | PCatch | PEnum Ident - deriving (C.Eq, C.Ord, C.Show, C.Read) - -data Def = DBind Bind | DData Data - deriving (C.Eq, C.Ord, C.Read, C.Show) - -type Id = (Ident, Type) +data Inj' t = Inj Ident t + deriving (C.Eq, C.Ord, C.Show, C.Read) newtype Ident = Ident String - deriving (C.Eq, C.Ord, C.Show, C.Read, Data.String.IsString) + deriving (C.Eq, C.Ord, C.Show, C.Read, IsString) -data Lit = LInt Integer | LChar Char - deriving (Show, Eq, Ord, Read) +data Pattern' t + = PVar (Id' t) -- TODO should be Ident + | PLit (Lit, t) -- TODO should be Lit + | PCatch + | PEnum Ident + | PInj Ident [Pattern' t] -- TODO should be (Pattern' t, t) + deriving (C.Eq, C.Ord, C.Show, C.Read) -data Bind = Bind Id [Id] ExpT +data Exp' t + = EVar Ident + | EInj Ident + | ELit Lit + | ELet (Bind' t) (ExpT' t) + | EApp (ExpT' t) (ExpT' t) + | EAdd (ExpT' t) (ExpT' t) + | EAbs Ident (ExpT' t) + | ECase (ExpT' t) [Branch' t] + deriving (C.Eq, C.Ord, C.Show, C.Read) + +type Id' t = (Ident, t) +type ExpT' t = (Exp' t, t) + +data 
Bind' t = Bind (Id' t) [Id' t] (ExpT' t) deriving (C.Eq, C.Ord, C.Show, C.Read) +data Branch' t = Branch (Pattern' t, t) (ExpT' t) + deriving (C.Eq, C.Ord, C.Show, C.Read) + instance Print Ident where - prt _ (Ident str) = doc . showString $ str + prt i (Ident s) = prt i s -instance Print [Def] where - prt _ [] = concatD [] - prt _ (x : xs) = concatD [prt 0 x, doc (showString "\n\n"), prt 0 xs] - -instance Print Data where - prt i = \case - Data type_ constructors -> prPrec i 0 (concatD [doc (showString "data"), prt 0 type_, doc (showString "where"), doc (showString "{"), prt 0 constructors, doc (showString "}")]) - -instance Print Constructor where - prt i = \case - Constructor uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_]) - -instance Print [Constructor] where - prt _ [] = concatD [] - prt _ [x] = concatD [prt 0 x] - prt _ (x : xs) = concatD [prt 0 x, prt 0 xs] - -instance Print Def where - prt i (DBind bind) = prt i bind - prt i (DData d) = prt i d - -instance Print Program where +instance Print t => Print (Program' t) where prt i (Program sc) = prPrec i 0 $ prt 0 sc -instance Print Bind where - prt i (Bind (name, t) args rhs) = - prPrec i 0 $ - concatD - [ prt 0 name - , doc $ showString ":" - , prt 0 t - , doc $ showString "\n" - , prt 0 name - , prtIdPs 0 args - , doc $ showString "=" - , prt 0 rhs - ] +instance Print t => Print (Bind' t) where + prt i (Bind sig@(name, _) parms rhs) = prPrec i 0 $ concatD + [ prtSig sig + , prt 0 name + , prtIdPs 0 parms + , doc $ showString "=" + , prt 0 rhs + ] -instance Print [Bind] where - prt _ [] = concatD [] - prt _ [x] = concatD [prt 0 x] - prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), doc (showString "\n"), prt 0 xs] +prtSig :: Print t => Id' t -> Doc +prtSig (name, t) = concatD [ prt 0 name + , doc $ showString ":" + , prt 0 t + , doc $ showString ";" + ] -prtIdPs :: Int -> [Id] -> Doc -prtIdPs i = prPrec i 0 . concatD . 
map (prtIdP i) +instance Print t => Print (ExpT' t) where + prt i (e, t) = concatD [ doc $ showString "(" + , prt i e + , doc $ showString "," + , prt i t + , doc $ showString ")" + ] -prtId :: Int -> Id -> Doc -prtId i (name, t) = - prPrec i 0 $ - concatD - [ doc $ showString "(" - , prt 0 name - , doc $ showString ":" - , prt 0 t - , doc $ showString ")" - ] +instance Print t => Print [Bind' t] where + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] -prtIdP :: Int -> Id -> Doc -prtIdP i (name, t) = - prPrec i 0 $ - concatD - [ doc $ showString "(" - , prt 0 name - , doc $ showString ":" - , prt 0 t - , doc $ showString ")" - ] +prtIdPs :: Print t => Int -> [Id' t] -> Doc +prtIdPs i = prPrec i 0 . concatD . map (prt i) -instance Print Exp where - prt i = \case - EId n -> prPrec i 3 $ concatD [prt 0 n] - ELit lit -> prPrec i 3 $ concatD [prt 0 lit] - ELet bs e -> - prPrec i 3 $ - concatD - [ doc $ showString "let" - , prt 0 bs - , doc $ showString "in" - , prt 0 e - ] - EApp e1 e2 -> - prPrec i 2 $ - concatD - [ prt 2 e1 - , prt 3 e2 - ] - EAdd e1 e2 -> - prPrec i 1 $ - concatD - [ doc $ showString "@" - , prt 1 e1 - , doc $ showString "+" - , prt 2 e2 - ] - EAbs n e -> - prPrec i 0 $ - concatD - [ doc $ showString "λ" - , prt 0 n - , doc $ showString "." 
- , prt 0 e - ] - ECase exp injs -> - prPrec - i - 0 - ( concatD - [ doc (showString "case") - , prt 0 exp - , doc (showString "of") - , doc (showString "{") - , prt 0 injs - , doc (showString "}") - , doc (showString ":") - ] - ) +instance Print t => Print (Id' t) where + prt i (name, t) = concatD [ doc $ showString "(" + , prt i name + , doc $ showString "," + , prt i t + , doc $ showString ")" + ] -instance Print ExpT where - prt i (e, t) = concatD [doc $ showString "(", prt i e, doc (showString ":"), prt i t, doc $ showString ")"] - -instance Print Branch where - prt i = \case - Branch (init, t) exp -> prPrec i 0 (concatD [prt 0 init, doc (showString ":"), prt 0 t, doc (showString "=>"), prt 0 exp]) - -instance Print Pattern where - prt i = \case - PVar lident -> prPrec i 0 (concatD [prtId 0 lident]) - PLit (lit, typ) -> prPrec i 0 (concatD [doc $ showString "(", prt 0 lit, doc $ showString ",", prt 0 typ, doc $ showString ")"]) - PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 0 patterns]) - PCatch -> prPrec i 0 (concatD [doc (showString "_")]) - PEnum p -> prt i p - -instance Print [Branch] where - prt _ [] = concatD [] - prt _ [x] = concatD [prt 0 x] - prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] - -instance Print TVar where - prt i (MkTVar id) = prt i id - -instance Print Type where - prt i = \case - TLit uident -> prPrec i 2 (concatD [prt 0 uident]) - TVar tvar@(MkTVar (Ident iden)) -> - if all isDigit iden - then prPrec i 2 (concatD [prt 0 $ TVar (MkTVar (Ident ("a" <> iden)))]) - else prPrec i 2 (concatD [prt 0 tvar]) - TAll tvar type_ -> prPrec i 1 (concatD [doc (showString "forall"), prt 0 tvar, doc (showString "."), prt 0 type_]) - TData ident types -> prPrec i 1 (concatD [prt 0 ident, prt 0 types]) - TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2]) - -instance Print Lit where - prt i = \case - LInt n -> prPrec i 0 (concatD [prt 0 n]) - LChar c -> prPrec i 0 (concatD 
[prt 0 c]) +instance Print t => Print (Exp' t) where + prt i = \case + EVar name -> prPrec i 3 $ prt 0 name + EInj name -> prPrec i 3 $ prt 0 name + ELit lit -> prPrec i 3 $ prt 0 lit + ELet b e -> prPrec i 3 $ concatD + [ doc $ showString "let" + , prt 0 b + , doc $ showString "in" + , prt 0 e + ] + EApp e1 e2 -> prPrec i 2 $ concatD + [ prt 2 e1 + , prt 3 e2 + ] + EAdd e1 e2 -> prPrec i 1 $ concatD + [ prt 1 e1 + , doc $ showString "+" + , prt 2 e2 + ] + EAbs v e -> prPrec i 0 $ concatD + [ doc $ showString "\\" diff --git a/tests/TestTypeCheckerBidir.hs b/tests/TestTypeCheckerBidir.hs new file mode 100644 index 0000000..3a20ca6 --- /dev/null +++ b/tests/TestTypeCheckerBidir.hs @@ -0,0 +1,232 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# HLINT ignore "Use camelCase" #-} +{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PatternSynonyms #-} + +module TestTypeCheckerBidir (testTypeCheckerBidir) where + +import Test.Hspec + +import Control.Monad ((<=<)) +import Grammar.ErrM (Err, pattern Bad, pattern Ok) +import Grammar.Par (myLexer, pProgram) +import Renamer.Renamer (rename) +import TypeChecker.RemoveTEVar (RemoveTEVar (rmTEVar)) +import TypeChecker.TypeCheckerBidir (typecheck) +import qualified TypeChecker.TypeCheckerIr as T + + +testTypeCheckerBidir = describe "Bidirectional type checker test" $ do + tc_id + tc_double + tc_add_lam + tc_const + tc_simple_rank2 + tc_rank2 + tc_identity + tc_pair + tc_tree + tc_mono_case + tc_pol_case + +tc_id = specify "Basic identity function polymorphism" $ run + [ "id : forall a. a -> a;" + , "id x = x;" + , "main = id 4;" + ] `shouldSatisfy` ok + +tc_double = specify "Addition inference" $ run + ["double x = x + x;"] `shouldSatisfy` ok + + +tc_add_lam = specify "Addition lambda inference" $ run + ["four = (\\x. x + x) 2;"] `shouldSatisfy` ok + + +tc_const = specify "Basic polymorphism with multiple type variables" $ run + [ "const : forall a. forall b. 
a -> b -> a;" + , "const x y = x;" + , "main = const 'a' 65;" + ] `shouldSatisfy` ok + +tc_simple_rank2 = specify "Simple rank two polymorphism" $ run + [ "id : forall a. a -> a;" + , "id x = x;" + + , "f : forall a. a -> (forall b. b -> b) -> a;" + , "f x g = g x;" + + , "main = f 4 id;" + ] `shouldSatisfy` ok + +tc_rank2 = specify "Rank two polymorphism is ok" $ run + [ "const : forall a. forall b. a -> b -> a;" + , "const x y = x;" + + , "rank2 : forall a. forall b. a -> (forall c. c -> Int) -> b -> Int;" + , "rank2 x f y = f x + f y;" + + , "main = rank2 3 (\\x. const 5 x : forall a. a -> Int) 'h';" + ] `shouldSatisfy` ok + +tc_identity = describe "(∀b. b → b) should only accept the identity function" $ do + specify "identityᵢₙₜ is rejected" $ run (fs ++ id_int) `shouldNotSatisfy` ok + specify "identity is accepted" $ run (fs ++ id) `shouldSatisfy` ok + where + fs = + [ "f : forall a. a -> (forall b. b -> b) -> a;" + , "f x g = g x;" + + , "id : forall a. a -> a;" + , "id x = x;" + + , "id_int : Int -> Int;" + , "id_int x = x;" + ] + id = + [ "main : Int;" + , "main = f 4 id;" + ] + id_int = + [ "main : Int;" + , "main = f 4 id_int;" + ] + +tc_pair = describe "Pair. Type variables in Pair a b typechecked" $ do + specify "Wrong arguments are rejected" $ run (fs ++ wrong) `shouldNotSatisfy` ok + specify "Correct arguments are accepted" $ run (fs ++ correct) `shouldSatisfy` ok + where + fs = + [ "data forall a. forall b. Pair (a b) where {" + , " Pair : a -> b -> Pair (a b)" + , "};" + + , "main : Pair (Int Char);" + ] + wrong = ["main = Pair 'a' 65;"] + correct = ["main = Pair 65 'a';"] + +tc_tree = describe "Tree. Recursive data type" $ do + specify "Wrong tree is rejected" $ run (fs ++ wrong) `shouldNotSatisfy` ok + specify "Correct tree is accepted" $ run (fs ++ correct) `shouldSatisfy` ok + where + fs = + [ "data forall a. 
Tree (a) where {" + , " Node : a -> Tree (a) -> Tree (a) -> Tree (a)" + , " Leaf : a -> Tree (a)" + , "};" + ] + wrong = ["tree = Node 1 (Node 2 (Node 4) (Leaf 5)) (Leaf 3);"] + correct = ["tree = Node 1 (Node 2 (Leaf 4) (Leaf 5)) (Leaf 3);"] + +tc_mono_case = describe "Monomorphic pattern matching" $ do + specify "First wrong case expression rejected" + $ run wrong1 `shouldNotSatisfy` ok + specify "Second wrong case expression rejected" + $ run wrong2 `shouldNotSatisfy` ok + specify "Third wrong case expression rejected" + $ run wrong3 `shouldNotSatisfy` ok + specify "First correct case expression accepted" + $ run correct1 `shouldSatisfy` ok + specify "Second correct case expression accepted" + $ run correct2 `shouldSatisfy` ok + + where + wrong1 = + [ "simple : Int -> Int;" + , "simple c = case c of {" + , " 'F' => 0;" + , " 'T' => 1;" + , "};" + ] + wrong2 = + [ "simple : Char -> Int;" + , "simple c = case c of {" + , " 'F' => 0;" + , " 1 => 1;" + , "};" + ] + wrong3 = + [ "simple : Char -> Int;" + , "simple c = case c of {" + , " 'F' => 0;" + , " 'T' => '1';" + , "};" + ] + correct1 = + [ "simple : Char -> Int;" + , "simple c = case c of {" + , " 'F' => 0;" + , " 'T' => 1;" + , "};" + ] + correct2 = + [ "simple : Char -> Int;" + , "simple c = case c of {" + , " 'F' => 0;" + , " _ => 1;" + , "};" + ] + +tc_pol_case = describe "Polymophic pattern matching" $ do + specify "First wrong case expression rejected" + $ run (fs ++ wrong1) `shouldNotSatisfy` ok + specify "Second wrong case expression rejected" + $ run (fs ++ wrong2) `shouldNotSatisfy` ok + specify "Third wrong case expression rejected" + $ run (fs ++ wrong3) `shouldNotSatisfy` ok + specify "First correct case expression accepted" + $ run (fs ++ correct1) `shouldSatisfy` ok + specify "Second correct case expression accepted" + $ run (fs ++ correct2) `shouldSatisfy` ok + where + fs = + [ "data forall a. 
List (a) where {" + , " Nil : List (a)" + , " Cons : a -> List (a) -> List (a)" + , "};" + ] + wrong1 = + [ "length : forall c. List (c) -> Int;" + , "length = \\list. case list of {" + , " Nil => 0;" + , " Cons 6 xs => 1 + length xs;" + , "};" + ] + wrong2 = + [ "length : forall c. List (c) -> Int;" + , "length = \\list. case list of {" + , " Cons => 0;" + , " Cons x xs => 1 + length xs;" + , "};" + ] + wrong3 = + [ "length : forall c. List (c) -> Int;" + , "length = \\list. case list of {" + , " 0 => 0;" + , " Cons x xs => 1 + length xs;" + , "};" + ] + correct1 = + [ "length : forall c. List (c) -> Int;" + , "length = \\list. case list of {" + , " Nil => 0;" + , " Cons x xs => 1 + length xs;" + , " Cons x (Cons y Nil) => 2;" + , "};" + ] + correct2 = + [ "length : forall c. List (c) -> Int;" + , "length = \\list. case list of {" + , " Nil => 0;" + , " non_empty => 1;" + , "};" + ] + +run :: [String] -> Err T.Program +run = rmTEVar <=< typecheck <=< pProgram . myLexer . unlines + +ok = \case + Ok _ -> True + Bad _ -> False diff --git a/tests/TestTypeCheckerHm.hs b/tests/TestTypeCheckerHm.hs new file mode 100644 index 0000000..b666701 --- /dev/null +++ b/tests/TestTypeCheckerHm.hs @@ -0,0 +1,113 @@ +{-# LANGUAGE NoImplicitPrelude #-} +{-# LANGUAGE QualifiedDo #-} + +module TestTypeCheckerHm (testTypeCheckerHm) where + +import Control.Monad ((<=<)) +import qualified DoStrings as D +import Grammar.Par (myLexer, pProgram) +import Prelude (Bool (..), Either (..), IO, fmap, + not, ($), (.)) +import Test.Hspec + +-- import Test.QuickCheck +import TypeChecker.TypeCheckerHm (typecheck) + + + +testTypeCheckerHm = describe "Hillner Milner type checker test" $ do + ok1 + ok2 + bad1 + bad2 + -- bad3 + + +ok1 = + specify "Basic polymorphism with multiple type variables" $ + run + ( D.do + const + "main = const 'a' 65 ;" + ) + `shouldSatisfy` ok +ok2 = + specify "Head with a correct signature is accepted" $ + run + ( D.do + list + headSig + head + ) + `shouldSatisfy` ok + +bad1 
= + specify "Infinite type unification should not succeed" $ + run + ( D.do + "main = \\x. x x ;" + ) + `shouldSatisfy` bad + +bad2 = + specify "Pattern matching using different types should not succeed" $ + run + ( D.do + list + "bad xs = case xs of {" + " 1 => 0 ;" + " Nil => 0 ;" + "};" + ) + `shouldSatisfy` bad + +bad3 = + specify "Using a concrete function on a skolem variable should not succeed" $ + run + ( D.do + bool + _not + "f : a -> Bool () ;" + " f x = not x ;" + ) + `shouldSatisfy` bad + +run = typecheck <=< pProgram . myLexer + +ok (Right _) = True +ok (Left _) = False + +bad = not . ok + +-- FUNCTIONS + +const = D.do + "const : a -> b -> a ;" + "const x y = x ;" +list = D.do + "data List (a) where" + " {" + " Nil : List (a)" + " Cons : a -> List (a) -> List (a)" + " };" + +headSig = D.do + "head : List (a) -> a ;" +head = D.do + "head xs = " + " case xs of {" + " Cons x xs => x ;" + " };" + +bool = D.do + "data Bool () where {" + " True : Bool ()" + " False : Bool ()" + "};" + +_not = D.do + "not : Bool () -> Bool () ;" + "not x = case x of {" + " True => False ;" + " False => True ;" + "};" diff --git a/tests/TestTypeChekerHm.hs/DoStrings.hs b/tests/TestTypeChekerHm.hs/DoStrings.hs index 9c1ec16..dabf5d6 100644 --- a/tests/TestTypeChekerHm.hs/DoStrings.hs +++ b/tests/TestTypeChekerHm.hs/DoStrings.hs @@ -1,6 +1,6 @@ module DoStrings where -import Prelude hiding ((>>), (>>=)) +import Prelude hiding ((>>), (>>=)) (>>) :: String -> String -> String (>>) str1 str2 = str1 ++ "\n" ++ str2 diff --git a/tests/Tests.hs b/tests/Tests.hs new file mode 100644 index 0000000..7bcb0af --- /dev/null +++ b/tests/Tests.hs @@ -0,0 +1,10 @@ + +module Main where + +import Test.Hspec +import TestTypeCheckerBidir (testTypeCheckerBidir) +import TestTypeCheckerHm (testTypeCheckerHm) + +main = hspec $ do + testTypeCheckerBidir + testTypeCheckerHm