From 2f45f39435f207bfb5eb3a922ac33e86792a548e Mon Sep 17 00:00:00 2001 From: sebastianselander Date: Mon, 27 Feb 2023 11:12:05 +0100 Subject: [PATCH] Incorporated most of main, as well as started on quickcheck --- Grammar.cf | 44 +++- language.cabal | 68 +++---- src/Codegen/Codegen.hs | 277 +++++++++++++++++++++++++ src/Codegen/LlvmIr.hs | 204 +++++++++++++++++++ src/Compiler.hs | 0 src/Interpreter.hs | 78 -------- src/LambdaLifter/LambdaLifter.hs | 192 ++++++++++++++++++ src/Main.hs | 109 ++++++---- src/Renamer/Renamer.hs | 154 +++++++------- src/Renamer/RenamerIr.hs | 32 --- src/Renamer/RenamerM.hs | 83 -------- src/TypeChecker/AlgoW.hs | 238 ---------------------- src/TypeChecker/HM.hs | 181 ----------------- src/TypeChecker/HMIr.hs | 110 ---------- src/TypeChecker/TypeChecker.hs | 333 ++++++++++++++++++++----------- src/TypeChecker/TypeCheckerIr.hs | 157 +++++++++------ test_program | 5 +- tests/Main.hs | 21 -- tests/Tests.hs | 56 ++++++ 19 files changed, 1252 insertions(+), 1090 deletions(-) create mode 100644 src/Codegen/Codegen.hs create mode 100644 src/Codegen/LlvmIr.hs delete mode 100644 src/Compiler.hs delete mode 100644 src/Interpreter.hs create mode 100644 src/LambdaLifter/LambdaLifter.hs delete mode 100644 src/Renamer/RenamerIr.hs delete mode 100644 src/Renamer/RenamerM.hs delete mode 100644 src/TypeChecker/AlgoW.hs delete mode 100644 src/TypeChecker/HM.hs delete mode 100644 src/TypeChecker/HMIr.hs delete mode 100644 tests/Main.hs create mode 100644 tests/Tests.hs diff --git a/Grammar.cf b/Grammar.cf index 5406ac8..6870367 100644 --- a/Grammar.cf +++ b/Grammar.cf @@ -1,27 +1,54 @@ -Program. Program ::= [Bind] ; +Program. Program ::= [Def] ; +DBind. Def ::= Bind ; +DData. Def ::= Data ; +terminator Def ";" ; Bind. Bind ::= Ident ":" Type ";" Ident [Ident] "=" Exp ; EAnn. Exp5 ::= "(" Exp ":" Type ")" ; EId. Exp4 ::= Ident ; -EInt. Exp4 ::= Integer ; +ELit. Exp4 ::= Literal ; EApp. Exp3 ::= Exp3 Exp4 ; EAdd. Exp1 ::= Exp1 "+" Exp2 ; -ELet. 
Exp ::= "let" Ident "=" Exp "in" Exp ; +ELet. Exp ::= "let" Ident "=" Exp "in" Exp ; EAbs. Exp ::= "\\" Ident "." Exp ; +ECase. Exp ::= "case" Exp "of" "{" [Inj] "}"; -TMono. Type1 ::= "Mono" Ident ; -TPol. Type1 ::= "Poly" Ident ; +LInt. Literal ::= Integer ; + +Inj. Inj ::= Init "=>" Exp ; +terminator Inj ";" ; + +InitLit. Init ::= Literal ; +InitConstr. Init ::= Ident [Match] ; +InitCatch. Init ::= "_" ; + +LMatch. Match ::= Literal ; +IMatch. Match ::= Ident ; +InitMatch. Match ::= Ident Match ; +separator Match " " ; + +TMono. Type1 ::= "_" Ident ; +TPol. Type1 ::= "'" Ident ; TArr. Type ::= Type1 "->" Type ; +separator Type " " ; + +-- shift/reduce problem here +Data. Data ::= "data" Ident [Type] "where" ";" + [Constructor]; + +terminator Constructor ";" ; + +Constructor. Constructor ::= Ident ":" Type ; -- This doesn't seem to work so we'll have to live with ugly keywords for now --- token Upper (upper (letter | digit | '_')*) ; --- token Lower (lower (letter | digit | '_')*) ; +-- token Poly upper (letter | digit | '_')* ; +-- token Mono lower (letter | digit | '_')* ; -separator Bind ";" ; +terminator Bind ";" ; separator Ident " "; coercions Type 1 ; @@ -29,3 +56,4 @@ coercions Exp 5 ; comment "--" ; comment "{-" "-}" ; + diff --git a/language.cabal b/language.cabal index f803c1b..eb58aa0 100644 --- a/language.cabal +++ b/language.cabal @@ -16,7 +16,7 @@ extra-source-files: Grammar.cf common warnings - ghc-options: -Wdefault + ghc-options: -W executable language import: warnings @@ -31,15 +31,12 @@ executable language Grammar.Skel Grammar.ErrM Auxiliary - -- TypeChecker.TypeChecker - -- TypeChecker.TypeCheckerIr - -- TypeChecker.Unification - TypeChecker.HM - TypeChecker.AlgoW - TypeChecker.HMIr - Renamer.RenamerM - -- Renamer.Renamer - -- Renamer.RenamerIr + TypeChecker.TypeChecker + TypeChecker.TypeCheckerIr + Renamer.Renamer + LambdaLifter.LambdaLifter + Codegen.Codegen + Codegen.LlvmIr hs-source-dirs: src @@ -50,34 +47,35 @@ executable language , 
either , extra , array + , QuickCheck default-language: GHC2021 -test-suite test - hs-source-dirs: tests, src - main-is: Main.hs - type: exitcode-stdio-1.0 +Test-suite language-testsuite + type: exitcode-stdio-1.0 + main-is: Tests.hs - other-modules: - Grammar.Abs - Grammar.Lex - Grammar.Par - Grammar.Print - Grammar.Skel - Grammar.ErrM - Auxiliary - Renamer.RenamerM - TypeChecker.AlgoW - TypeChecker.HM - TypeChecker.HMIr + other-modules: + Grammar.Abs + Grammar.Lex + Grammar.Par + Grammar.Print + Grammar.Skel + Grammar.ErrM + Auxiliary + TypeChecker.TypeChecker + TypeChecker.TypeCheckerIr + Renamer.Renamer - build-depends: - base >=4.16 - , mtl - , containers - , either - , array - , extra - , hspec + hs-source-dirs: src, tests - default-language: GHC2021 + build-depends: + base >=4.16 + , mtl + , containers + , either + , extra + , array + , QuickCheck + + default-language: GHC2021 diff --git a/src/Codegen/Codegen.hs b/src/Codegen/Codegen.hs new file mode 100644 index 0000000..76a1f02 --- /dev/null +++ b/src/Codegen/Codegen.hs @@ -0,0 +1,277 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} + +module Codegen.Codegen (compile) where + +import Auxiliary (snoc) +import Codegen.LlvmIr (LLVMIr (..), LLVMType (..), + LLVMValue (..), Visibility (..), + llvmIrToString) +import Control.Monad.State (StateT, execStateT, gets, modify) +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Tuple.Extra (dupe, first, second) +import Grammar.ErrM (Err) +import TypeChecker.TypeChecker +import TypeChecker.TypeCheckerIr + +-- | The record used as the code generator state +data CodeGenerator = CodeGenerator + { instructions :: [LLVMIr] + , functions :: Map Id FunctionInfo + , variableCount :: Integer + } + +-- | A state type synonym +type CompilerState a = StateT CodeGenerator Err a + +data FunctionInfo = FunctionInfo + { numArgs :: Int + , arguments :: [Id] + } + +-- | Adds a instruction to the CodeGenerator state +emit :: LLVMIr -> CompilerState () 
+emit l = modify $ \t -> t { instructions = snoc l $ instructions t } + +-- | Increases the variable counter in the CodeGenerator state +increaseVarCount :: CompilerState () +increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 } + +-- | Returns the variable count from the CodeGenerator state +getVarCount :: CompilerState Integer +getVarCount = gets variableCount + +-- | Increases the variable count and returns it from the CodeGenerator state +getNewVar :: CompilerState Integer +getNewVar = increaseVarCount >> getVarCount + +-- | Produces a map of functions infos from a list of binds, +-- which contains useful data for code generation. +getFunctions :: [Bind] -> Map Id FunctionInfo +getFunctions bs = Map.fromList $ map go bs + where + go (Bind id args _) = + (id, FunctionInfo { numArgs=length args, arguments=args }) + + + +initCodeGenerator :: [Bind] -> CodeGenerator +initCodeGenerator scs = CodeGenerator { instructions = defaultStart + , functions = getFunctions scs + , variableCount = 0 + } + +-- | Compiles an AST and produces a LLVM Ir string. +-- An easy way to actually "compile" this output is to +-- Simply pipe it to lli +compile :: Program -> Err String +compile (Program scs) = do + let codegen = initCodeGenerator scs + llvmIrToString . instructions <$> execStateT (compileScs scs) codegen + +compileScs :: [Bind] -> CompilerState () +compileScs [] = pure () +compileScs (Bind (name, t) args exp : xs) = do + emit $ UnsafeRaw "\n" + emit . Comment $ show name <> ": " <> show exp + let args' = map (second type2LlvmType) args + emit $ Define (type2LlvmType t_return) name args' + functionBody <- exprToValue exp + if name == "main" + then mapM_ emit $ mainContent functionBody + else emit $ Ret I64 functionBody + emit DefineEnd + modify $ \s -> s { variableCount = 0 } + compileScs xs + where + t_return = snd $ partitionType (length args) t + +mainContent :: LLVMValue -> [LLVMIr] +mainContent var = + [ UnsafeRaw $ + "call i32 (ptr, ...) 
@printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n" + , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2)) + -- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2") + -- , Label (Ident "b_1") + -- , UnsafeRaw + -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n" + -- , Br (Ident "end") + -- , Label (Ident "b_2") + -- , UnsafeRaw + -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n" + -- , Br (Ident "end") + -- , Label (Ident "end") + Ret I64 (VInteger 0) + ] + +defaultStart :: [LLVMIr] +defaultStart = [ UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n" + , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" + ] + +compileExp :: Exp -> CompilerState () +compileExp = \case + ELit _ (LInt i) -> emitInt i + EAdd t e1 e2 -> emitAdd t e1 e2 + EId (name, _) -> emitIdent name + EApp t e1 e2 -> emitApp t e1 e2 + EAbs t ti e -> emitAbs t ti e + ELet bind e -> emitLet bind e + +--- aux functions --- +emitAbs :: Type -> Id -> Exp -> CompilerState () +emitAbs _t tid e = emit . Comment $ "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e + +emitLet :: Bind -> Exp -> CompilerState () +emitLet b e = emit . Comment $ concat [ "ELet (" + , show b + , " = " + , show e + , ") is not implemented!" + ] + +emitApp :: Type -> Exp -> Exp -> CompilerState () +emitApp t e1 e2 = appEmitter t e1 e2 [] + where + appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState () + appEmitter t e1 e2 stack = do + let newStack = e2 : stack + case e1 of + EApp _ e1' e2' -> appEmitter t e1' e2' newStack + EId id@(name, _) -> do + args <- traverse exprToValue newStack + vs <- getNewVar + funcs <- gets functions + let visibility = maybe Local (const Global) $ Map.lookup id funcs + args' = map (first valueGetType . dupe) args + call = Call (type2LlvmType t) visibility name args' + emit $ SetVariable (Ident $ show vs) call + x -> do + emit . 
Comment $ "The unspeakable happened: " + emit . Comment $ show x + +emitIdent :: Ident -> CompilerState () +emitIdent id = do + -- !!this should never happen!! + emit $ Comment "This should not have happened!" + emit $ Variable id + emit $ UnsafeRaw "\n" + +emitInt :: Integer -> CompilerState () +emitInt i = do + -- !!this should never happen!! + varCount <- getNewVar + emit $ Comment "This should not have happened!" + emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0)) + +emitAdd :: Type -> Exp -> Exp -> CompilerState () +emitAdd t e1 e2 = do + v1 <- exprToValue e1 + v2 <- exprToValue e2 + v <- getNewVar + emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2) + +-- emitMul :: Exp -> Exp -> CompilerState () +-- emitMul e1 e2 = do +-- (v1,v2) <- binExprToValues e1 e2 +-- increaseVarCount +-- v <- gets variableCount +-- emit $ SetVariable $ Ident $ show v +-- emit $ Mul I64 v1 v2 + +-- emitMod :: Exp -> Exp -> CompilerState () +-- emitMod e1 e2 = do +-- -- `let m a b = rem (abs $ b + a) b` +-- (v1,v2) <- binExprToValues e1 e2 +-- increaseVarCount +-- vadd <- gets variableCount +-- emit $ SetVariable $ Ident $ show vadd +-- emit $ Add I64 v1 v2 +-- +-- increaseVarCount +-- vabs <- gets variableCount +-- emit $ SetVariable $ Ident $ show vabs +-- emit $ Call I64 (Ident "llvm.abs.i64") +-- [ (I64, VIdent (Ident $ show vadd)) +-- , (I1, VInteger 1) +-- ] +-- increaseVarCount +-- v <- gets variableCount +-- emit $ SetVariable $ Ident $ show v +-- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2 + +-- emitDiv :: Exp -> Exp -> CompilerState () +-- emitDiv e1 e2 = do +-- (v1,v2) <- binExprToValues e1 e2 +-- increaseVarCount +-- v <- gets variableCount +-- emit $ SetVariable $ Ident $ show v +-- emit $ Div I64 v1 v2 + +-- emitSub :: Exp -> Exp -> CompilerState () +-- emitSub e1 e2 = do +-- (v1,v2) <- binExprToValues e1 e2 +-- increaseVarCount +-- v <- gets variableCount +-- emit $ SetVariable $ Ident $ show v +-- emit $ Sub I64 v1 v2 + 
+exprToValue :: Exp -> CompilerState LLVMValue +exprToValue = \case + ELit _ (LInt i) -> pure $ VInteger i + + EId id@(name, t) -> do + funcs <- gets functions + case Map.lookup id funcs of + Just fi -> do + if numArgs fi == 0 + then do + vc <- getNewVar + emit $ SetVariable (Ident $ show vc) + (Call (type2LlvmType t) Global name []) + pure $ VIdent (Ident $ show vc) (type2LlvmType t) + else pure $ VFunction name Global (type2LlvmType t) + Nothing -> pure $ VIdent name (type2LlvmType t) + + e -> do + compileExp e + v <- getVarCount + pure $ VIdent (Ident $ show v) (getType e) + +type2LlvmType :: Type -> LLVMType +type2LlvmType = \case + (TMono "Int") -> I64 + TArr t xs -> do + let (t', xs') = function2LLVMType xs [type2LlvmType t] + Function t' xs' + t -> I64 --CustomType $ Ident ("\"" ++ show t ++ "\"") + where + function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType]) + function2LLVMType (TArr t xs) s = function2LLVMType xs (type2LlvmType t : s) + function2LLVMType x s = (type2LlvmType x, s) + +getType :: Exp -> LLVMType +getType (ELit _ (LInt _)) = I64 +getType (EAdd t _ _) = type2LlvmType t +getType (EId (_, t)) = type2LlvmType t +getType (EApp t _ _) = type2LlvmType t +getType (EAbs t _ _) = type2LlvmType t +getType (ELet _ e) = getType e + +valueGetType :: LLVMValue -> LLVMType +valueGetType (VInteger _) = I64 +valueGetType (VIdent _ t) = t +valueGetType (VConstant s) = Array (length s) I8 +valueGetType (VFunction _ _ t) = t + +-- | Partion type into types of parameters and return type. 
+partitionType :: Int -- Number of parameters to apply + -> Type + -> ([Type], Type) +partitionType = go [] + where + go acc 0 t = (acc, t) + go acc i t = case t of + TArr t1 t2 -> go (snoc t1 acc) (i - 1) t2 + _ -> error "Number of parameters and type doesn't match" diff --git a/src/Codegen/LlvmIr.hs b/src/Codegen/LlvmIr.hs new file mode 100644 index 0000000..aa6de54 --- /dev/null +++ b/src/Codegen/LlvmIr.hs @@ -0,0 +1,204 @@ +{-# LANGUAGE LambdaCase #-} + +module Codegen.LlvmIr ( + LLVMType (..), + LLVMIr (..), + llvmIrToString, + LLVMValue (..), + LLVMComp (..), + Visibility (..), +) where + +import Data.List (intercalate) +import TypeChecker.TypeCheckerIr + +-- | A datatype which represents some basic LLVM types +data LLVMType + = I1 + | I8 + | I32 + | I64 + | Ptr + | Ref LLVMType + | Function LLVMType [LLVMType] + | Array Int LLVMType + | CustomType Ident + +instance Show LLVMType where + show :: LLVMType -> String + show = \case + I1 -> "i1" + I8 -> "i8" + I32 -> "i32" + I64 -> "i64" + Ptr -> "ptr" + Ref ty -> show ty <> "*" + Function t xs -> show t <> " (" <> intercalate ", " (map show xs) <> ")*" + Array n ty -> concat ["[", show n, " x ", show ty, "]"] + CustomType (Ident ty) -> ty + +data LLVMComp + = LLEq + | LLNe + | LLUgt + | LLUge + | LLUlt + | LLUle + | LLSgt + | LLSge + | LLSlt + | LLSle +instance Show LLVMComp where + show :: LLVMComp -> String + show = \case + LLEq -> "eq" + LLNe -> "ne" + LLUgt -> "ugt" + LLUge -> "uge" + LLUlt -> "ult" + LLUle -> "ule" + LLSgt -> "sgt" + LLSge -> "sge" + LLSlt -> "slt" + LLSle -> "sle" + +data Visibility = Local | Global +instance Show Visibility where + show :: Visibility -> String + show Local = "%" + show Global = "@" + +-- | Represents a LLVM "value", as in an integer, a register variable, +-- or a string contstant +data LLVMValue + = VInteger Integer + | VIdent Ident LLVMType + | VConstant String + | VFunction Ident Visibility LLVMType + +instance Show LLVMValue where + show :: LLVMValue -> String + show v 
= case v of + VInteger i -> show i + VIdent (Ident n) _ -> "%" <> n + VFunction (Ident n) vis _ -> show vis <> n + VConstant s -> "c" <> show s + +type Params = [(Ident, LLVMType)] +type Args = [(LLVMType, LLVMValue)] + +-- | A datatype which represents different instructions in LLVM +data LLVMIr + = Define LLVMType Ident Params + | DefineEnd + | Declare LLVMType Ident Params + | SetVariable Ident LLVMIr + | Variable Ident + | Add LLVMType LLVMValue LLVMValue + | Sub LLVMType LLVMValue LLVMValue + | Div LLVMType LLVMValue LLVMValue + | Mul LLVMType LLVMValue LLVMValue + | Srem LLVMType LLVMValue LLVMValue + | Icmp LLVMComp LLVMType LLVMValue LLVMValue + | Br Ident + | BrCond LLVMValue Ident Ident + | Label Ident + | Call LLVMType Visibility Ident Args + | Alloca LLVMType + | Store LLVMType Ident LLVMType Ident + | Bitcast LLVMType Ident LLVMType + | Ret LLVMType LLVMValue + | Comment String + | UnsafeRaw String -- This should generally be avoided, and proper + -- instructions should be used in its place + deriving (Show) + +-- | Converts a list of LLVMIr instructions to a string +llvmIrToString :: [LLVMIr] -> String +llvmIrToString = go 0 + where + go :: Int -> [LLVMIr] -> String + go _ [] = mempty + go i (x : xs) = do + let (i', n) = case x of + Define{} -> (i + 1, 0) + DefineEnd -> (i - 1, 0) + _ -> (i, i) + insToString n x <> go i' xs + +-- | Converts a LLVM inststruction to a String, allowing for printing etc. 
+-- The integer represents the indentation +insToString :: Int -> LLVMIr -> String +insToString i l = + replicate i '\t' <> case l of + (Define t (Ident i) params) -> + concat + [ "define ", show t, " @", i + , "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params) + , ") {\n" + ] + DefineEnd -> "}\n" + (Declare _t (Ident _i) _params) -> undefined + (SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir] + (Add t v1 v2) -> + concat + [ "add ", show t, " ", show v1 + , ", ", show v2, "\n" + ] + (Sub t v1 v2) -> + concat + [ "sub ", show t, " ", show v1, ", " + , show v2, "\n" + ] + (Div t v1 v2) -> + concat + [ "sdiv ", show t, " ", show v1, ", " + , show v2, "\n" + ] + (Mul t v1 v2) -> + concat + [ "mul ", show t, " ", show v1 + , ", ", show v2, "\n" + ] + (Srem t v1 v2) -> + concat + [ "srem ", show t, " ", show v1, ", " + , show v2, "\n" + ] + (Call t vis (Ident i) arg) -> + concat + [ "call ", show t, " ", show vis, i, "(" + , intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg + , ")\n" + ] + (Alloca t) -> unwords ["alloca", show t, "\n"] + (Store t1 (Ident id1) t2 (Ident id2)) -> + concat + [ "store ", show t1, " %", id1 + , ", ", show t2 , " %", id2, "\n" + ] + (Bitcast t1 (Ident i) t2) -> + concat + [ "bitcast ", show t1, " %" + , i, " to ", show t2, "\n" + ] + (Icmp comp t v1 v2) -> + concat + [ "icmp ", show comp, " ", show t + , " ", show v1, ", ", show v2, "\n" + ] + (Ret t v) -> + concat + [ "ret ", show t, " " + , show v, "\n" + ] + (UnsafeRaw s) -> s + (Label (Ident s)) -> "\nlabel_" <> s <> ":\n" + (Br (Ident s)) -> "br label %label_" <> s <> "\n" + (BrCond val (Ident s1) (Ident s2)) -> + concat + [ "br i1 ", show val, ", ", "label %" + , "label_", s1, ", ", "label %", "label_", s2, "\n" + ] + (Comment s) -> "; " <> s <> "\n" + (Variable (Ident id)) -> "%" <> id diff --git a/src/Compiler.hs b/src/Compiler.hs deleted file mode 100644 index e69de29..0000000 diff --git a/src/Interpreter.hs 
b/src/Interpreter.hs deleted file mode 100644 index 378c95b..0000000 --- a/src/Interpreter.hs +++ /dev/null @@ -1,78 +0,0 @@ -{-# LANGUAGE LambdaCase #-} -module Interpreter where - -import Control.Applicative (Applicative) -import Control.Monad.Except (Except, MonadError (throwError), - liftEither) -import Data.Either.Combinators (maybeToRight) -import Data.Map (Map) -import qualified Data.Map as Map -import Grammar.Abs -import Grammar.Print (printTree) - -interpret :: Program -> Except String Integer -interpret (Program e) = - eval mempty e >>= \case - VClosure {} -> throwError "main evaluated to a function" - VInt i -> pure i - - -data Val = VInt Integer - | VClosure Cxt Ident Exp - -type Cxt = Map Ident Val - -eval :: Cxt -> Exp -> Except String Val -eval cxt = \case - - - -- ------------ x ∈ γ - -- γ ⊢ x ⇓ γ(x) - - EId x -> - maybeToRightM - ("Unbound variable:" ++ printTree x) - $ Map.lookup x cxt - - -- --------- - -- γ ⊢ i ⇓ i - - EInt i -> pure $ VInt i - - -- γ ⊢ e ⇓ let δ in λx. f - -- γ ⊢ e₁ ⇓ v - -- δ,x=v ⊢ f ⇓ v₁ - -- ------------------------------ - -- γ ⊢ e e₁ ⇓ v₁ - - EApp e e1 -> - eval cxt e >>= \case - VInt _ -> throwError "Not a function" - VClosure delta x f -> do - v <- eval cxt e1 - eval (Map.insert x v delta) f - - -- - -- ----------------------------- - -- γ ⊢ λx. f ⇓ let γ in λx. f - - EAbs x e -> pure $ VClosure cxt x e - - - -- γ ⊢ e ⇓ v - -- γ ⊢ e₁ ⇓ v₁ - -- ------------------ - -- γ ⊢ e e₁ ⇓ v + v₁ - - EAdd e e1 -> do - v <- eval cxt e - v1 <- eval cxt e1 - case (v, v1) of - (VInt i, VInt i1) -> pure $ VInt (i + i1) - _ -> throwError "Can't add a function" - - - -maybeToRightM :: MonadError l m => l -> Maybe r -> m r -maybeToRightM err = liftEither . 
maybeToRight err - diff --git a/src/LambdaLifter/LambdaLifter.hs b/src/LambdaLifter/LambdaLifter.hs new file mode 100644 index 0000000..a617159 --- /dev/null +++ b/src/LambdaLifter/LambdaLifter.hs @@ -0,0 +1,192 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} + + +module LambdaLifter.LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where + +import Auxiliary (snoc) +import Control.Applicative (Applicative (liftA2)) +import Control.Monad.State (MonadState (get, put), State, + evalState) +import Data.Set (Set) +import qualified Data.Set as Set +import Prelude hiding (exp) +import Renamer.Renamer +import TypeChecker.TypeCheckerIr + + +-- | Lift lambdas and let expression into supercombinators. +-- Three phases: +-- @freeVars@ annotatss all the free variables. +-- @abstract@ converts lambdas into let expressions. +-- @collectScs@ moves every non-constant let expression to a top-level function. +lambdaLift :: Program -> Program +lambdaLift = collectScs . abstract . 
freeVars + + +-- | Annotate free variables +freeVars :: Program -> AnnProgram +freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e) + | Bind n xs e <- ds + ] + +freeVarsExp :: Set Id -> Exp -> AnnExp +freeVarsExp localVars = \case + EId n | Set.member n localVars -> (Set.singleton n, AId n) + | otherwise -> (mempty, AId n) + + ELit _ (LInt i) -> (mempty, AInt i) + + EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2') + where + e1' = freeVarsExp localVars e1 + e2' = freeVarsExp localVars e2 + + EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2') + where + e1' = freeVarsExp localVars e1 + e2' = freeVarsExp localVars e2 + + EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e') + where + e' = freeVarsExp (Set.insert par localVars) e + + -- Sum free variables present in bind and the expression + ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e') + where + binders_frees = Set.delete name $ freeVarsOf rhs' + e_free = Set.delete name $ freeVarsOf e' + + rhs' = freeVarsExp e_localVars rhs + new_bind = ABind name parms rhs' + + e' = freeVarsExp e_localVars e + e_localVars = Set.insert name localVars + + +freeVarsOf :: AnnExp -> Set Id +freeVarsOf = fst + +-- AST annotated with free variables +type AnnProgram = [(Id, [Id], AnnExp)] + +type AnnExp = (Set Id, AnnExp') + +data ABind = ABind Id [Id] AnnExp deriving Show + +data AnnExp' = AId Id + | AInt Integer + | ALet ABind AnnExp + | AApp Type AnnExp AnnExp + | AAdd Type AnnExp AnnExp + | AAbs Type Id AnnExp + deriving Show +-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@. +-- Free variables are @v₁ v₂ .. vₙ@ are bound. 
+abstract :: AnnProgram -> Program +abstract prog = Program $ evalState (mapM go prog) 0 + where + go :: (Id, [Id], AnnExp) -> State Int Bind + go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs' + where + (rhs', parms1) = flattenLambdasAnn rhs + + +-- | Flatten nested lambdas and collect the parameters +-- @\x.\y.\z. ae → (ae, [x,y,z])@ +flattenLambdasAnn :: AnnExp -> (AnnExp, [Id]) +flattenLambdasAnn ae = go (ae, []) + where + go :: (AnnExp, [Id]) -> (AnnExp, [Id]) + go ((free, e), acc) = + case e of + AAbs _ par (free1, e1) -> + go ((Set.delete par free1, e1), snoc par acc) + _ -> ((free, e), acc) + +abstractExp :: AnnExp -> State Int Exp +abstractExp (free, exp) = case exp of + AId n -> pure $ EId n + AInt i -> pure $ ELit (TMono "Int") (LInt i) + AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2) + AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2) + ALet b e -> liftA2 ELet (go b) (abstractExp e) + where + go (ABind name parms rhs) = do + (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs + pure $ Bind name (parms ++ parms1) rhs' + + skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp + skipLambdas f (free, ae) = case ae of + AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1 + _ -> f (free, ae) + + -- Lift lambda into let and bind free variables + AAbs t parm e -> do + i <- nextNumber + rhs <- abstractExp e + + let sc_name = Ident ("sc_" ++ show i) + sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t) + + pure $ foldl (EApp $ TMono "Int") sc $ map EId freeList + where + freeList = Set.toList free + parms = snoc parm freeList + + +nextNumber :: State Int Int +nextNumber = do + i <- get + put $ succ i + pure i + +-- | Collects supercombinators by lifting non-constant let expressions +collectScs :: Program -> Program +collectScs (Program scs) = Program $ concatMap collectFromRhs scs + where + collectFromRhs (Bind name parms rhs) = + let (rhs_scs, rhs') = collectScsExp rhs + 
in Bind name parms rhs' : rhs_scs + + +collectScsExp :: Exp -> ([Bind], Exp) +collectScsExp = \case + EId n -> ([], EId n) + ELit _ (LInt i) -> ([], ELit (TMono "Int") (LInt i)) + + EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2') + where + (scs1, e1') = collectScsExp e1 + (scs2, e2') = collectScsExp e2 + + EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2') + where + (scs1, e1') = collectScsExp e1 + (scs2, e2') = collectScsExp e2 + + EAbs t par e -> (scs, EAbs t par e') + where + (scs, e') = collectScsExp e + + -- Collect supercombinators from bind, the rhss, and the expression. + -- + -- > f = let sc x y = rhs in e + -- + ELet (Bind name parms rhs) e -> if null parms + then ( rhs_scs ++ e_scs, ELet bind e') + else (bind : rhs_scs ++ e_scs, e') + where + bind = Bind name parms rhs' + (rhs_scs, rhs') = collectScsExp rhs + (e_scs, e') = collectScsExp e + + +-- @\x.\y.\z. e → (e, [x,y,z])@ +flattenLambdas :: Exp -> (Exp, [Id]) +flattenLambdas = go . (, []) + where + go (e, acc) = case e of + EAbs _ par e1 -> go (e1, snoc par acc) + _ -> (e, acc) + diff --git a/src/Main.hs b/src/Main.hs index 58811fe..3a7bde4 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -2,42 +2,81 @@ module Main where -import Grammar.Par (myLexer, pProgram) --- import TypeChecker.TypeChecker (typecheck) +import Codegen.Codegen (compile) +import GHC.IO.Handle.Text (hPutStrLn) +import Grammar.ErrM (Err) +import Grammar.Par (myLexer, pProgram) +import Grammar.Print (printTree) -import Grammar.Print (printTree) -import Renamer.RenamerM (rename) -import System.Environment (getArgs) -import System.Exit (exitFailure, exitSuccess) -import TypeChecker.AlgoW (typecheck) +import LambdaLifter.LambdaLifter (lambdaLift) +import Renamer.Renamer (rename) +import System.Environment (getArgs) +import System.Exit (exitFailure, exitSuccess) +import System.IO (stderr) +import TypeChecker.TypeChecker (typecheck) main :: IO () -main = getArgs >>= \case +main = + getArgs >>= \case [] -> print "Required file path missing" - (x : 
_) -> do - file <- readFile x - case pProgram (myLexer file) of - Left err -> do - putStrLn "SYNTAX ERROR" - putStrLn err - exitFailure - Right prg -> do - putStrLn "" - putStrLn " ----- PARSER ----- " - putStrLn "" - putStrLn . printTree $ prg - case typecheck (rename prg) of - Left err -> do - putStrLn "TYPECHECK ERROR" - print err - exitFailure - Right prg -> do - putStrLn "" - putStrLn " ----- RAW ----- " - putStrLn "" - print prg - putStrLn "" - putStrLn " ----- TYPECHECKER ----- " - putStrLn "" - putStrLn $ printTree prg - exitSuccess + (s : _) -> main' s + +main' :: String -> IO () +main' s = do + file <- readFile s + + printToErr "-- Parse Tree -- " + parsed <- fromSyntaxErr . pProgram $ myLexer file + printToErr $ printTree parsed + + printToErr "\n-- Renamer --" + let renamed = rename parsed + printToErr $ printTree renamed + + printToErr "\n-- TypeChecker --" + typechecked <- fromTypeCheckerErr $ typecheck renamed + printToErr $ printTree typechecked + + printToErr "\n-- Lambda Lifter --" + let lifted = lambdaLift typechecked + printToErr $ printTree lifted + + printToErr "\n -- Printing compiler output to stdout --" + compiled <- fromCompilerErr $ compile lifted + putStrLn compiled + writeFile "llvm.ll" compiled + + exitSuccess + +printToErr :: String -> IO () +printToErr = hPutStrLn stderr + +fromCompilerErr :: Err a -> IO a +fromCompilerErr = + either + ( \err -> do + putStrLn "\nCOMPILER ERROR" + putStrLn err + exitFailure + ) + pure + +fromSyntaxErr :: Err a -> IO a +fromSyntaxErr = + either + ( \err -> do + putStrLn "\nSYNTAX ERROR" + putStrLn err + exitFailure + ) + pure + +fromTypeCheckerErr :: Err a -> IO a +fromTypeCheckerErr = + either + ( \err -> do + putStrLn "\nTYPECHECKER ERROR" + putStrLn err + exitFailure + ) + pure diff --git a/src/Renamer/Renamer.hs b/src/Renamer/Renamer.hs index c8b857e..1ea892c 100644 --- a/src/Renamer/Renamer.hs +++ b/src/Renamer/Renamer.hs @@ -1,101 +1,91 @@ -{-# LANGUAGE LambdaCase, OverloadedRecordDot, 
OverloadedStrings #-} +{-# LANGUAGE LambdaCase #-} -module Renamer.Renamer (rename) where +module Renamer.Renamer where -import Renamer.RenamerIr -import Control.Monad.State -import Control.Monad.Except -import Control.Monad.Reader -import Data.Functor.Identity (Identity, runIdentity) -import Data.Set (Set) -import qualified Data.Set as S -import Data.Map (Map) -import qualified Data.Map as M +import Auxiliary (mapAccumM) +import Control.Monad.State (MonadState, State, evalState, gets, + modify) +import Data.List (foldl') +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Maybe (fromMaybe) +import Data.Tuple.Extra (dupe) +import Grammar.Abs -import Renamer.RenamerIr -import qualified Grammar.Abs as Old -type Rename = StateT Ctx (ExceptT Error Identity) +-- | Rename all variables and local binds +rename :: Program -> Program +rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0 + where + -- initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs + initNames = Map.fromList $ foldl' saveIfBind [] bs + saveIfBind acc (DBind (Bind name _ _ _ _)) = dupe name : acc + saveIfBind acc _ = acc + renameSc :: Names -> Def -> Rn Def + renameSc old_names (DBind (Bind name t _ parms rhs)) = do + (new_names, parms') <- newNames old_names parms + rhs' <- snd <$> renameExp new_names rhs + pure . DBind $ Bind name t name parms' rhs' + renameSc _ def = pure def -data Ctx = Ctx { count :: Integer - , sig :: Set Ident - , env :: Map Ident Integer} +-- -run :: Rename a -> Either Error a -run = runIdentity . runExceptT . flip evalStateT initCtx +-- | Rename monad. State holds the number of renamed names. +newtype Rn a = Rn { runRn :: State Int a } + deriving (Functor, Applicative, Monad, MonadState Int) -initCtx :: Ctx -initCtx = Ctx { count = 0 - , sig = mempty - , env = mempty } +-- | Maps old to new name +type Names = Map Ident Ident -rename :: Old.Program -> Either Error RProgram -rename = run . 
renamePrg +renameLocalBind :: Names -> Bind -> Rn (Names, Bind) +renameLocalBind old_names (Bind name t _ parms rhs) = do + (new_names, name') <- newName old_names name + (new_names', parms') <- newNames new_names parms + (new_names'', rhs') <- renameExp new_names' rhs + pure (new_names'', Bind name' t name' parms' rhs') -renamePrg :: Old.Program -> Rename RProgram -renamePrg (Old.Program xs) = do - xs' <- mapM renameBind xs - return $ RProgram xs' +renameExp :: Names -> Exp -> Rn (Names, Exp) +renameExp old_names = \case + EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names) -renameBind :: Old.Bind -> Rename RBind -renameBind (Old.Bind n t i args e) = do - insertSig i - e' <- renameExp (makeLambda (reverse args) e) - return $ RBind i e' - where - makeLambda :: [Ident] -> Old.Exp -> Old.Exp - makeLambda [] e = e - makeLambda (x:xs) e = makeLambda xs (Old.EAbs x e) + ELit (LInt i1) -> pure (old_names, ELit (LInt i1)) -renameExp :: Old.Exp -> Rename RExp -renameExp = \case + EApp e1 e2 -> do + (env1, e1') <- renameExp old_names e1 + (env2, e2') <- renameExp old_names e2 + pure (Map.union env1 env2, EApp e1' e2') - Old.EId i -> do - st <- get - case M.lookup i st.env of - Just n -> return $ RId i - Nothing -> case S.member i st.sig of - True -> return $ RId i - False -> throwError $ UnboundVar (show i) + EAdd e1 e2 -> do + (env1, e1') <- renameExp old_names e1 + (env2, e2') <- renameExp old_names e2 + pure (Map.union env1 env2, EAdd e1' e2') - Old.EInt c -> return $ RInt c + ELet i e1 e2 -> do + (new_names, e1') <- renameExp old_names e1 + (new_names', e2') <- renameExp new_names e2 + pure (new_names', ELet i e1' e2') - Old.EAnn e t -> flip RAnn t <$> renameExp e + EAbs par e -> do + (new_names, par') <- newName old_names par + (new_names', e') <- renameExp new_names e + pure (new_names', EAbs par' e') - Old.EApp e1 e2 -> RApp <$> renameExp e1 <*> renameExp e2 + EAnn e t -> do + (new_names, e') <- renameExp old_names e + pure (new_names, EAnn e' t) - 
Old.EAdd e1 e2 -> RAdd <$> renameExp e1 <*> renameExp e2 + ECase _ _ -> error "ECase NOT IMPLEMENTED YET" - -- Convert let-expressions to lambdas - Old.ELet i e1 e2 -> renameExp (Old.EApp (Old.EAbs i e2) e1) +-- | Create a new name and add it to name environment. +newName :: Names -> Ident -> Rn (Names, Ident) +newName env old_name = do + new_name <- makeName old_name + pure (Map.insert old_name new_name env, new_name) - Old.EAbs i e -> do - n <- cnt - ctx <- get - insertEnv i n - re <- renameExp e - return $ RAbs n i re +-- | Create multiple names and add them to the name environment +newNames :: Names -> [Ident] -> Rn (Names, [Ident]) +newNames = mapAccumM newName --- | Get current count and increase it by one -cnt :: Rename Integer -cnt = do - st <- get - put (Ctx { count = succ st.count - , sig = st.sig - , env = st.env }) - return st.count - -insertEnv :: Ident -> Integer -> Rename () -insertEnv i n = do - c <- get - put ( Ctx { env = M.insert i n c.env , sig = c.sig , count = c.count} ) - -insertSig :: Ident -> Rename () -insertSig i = do - c <- get - put ( Ctx { sig = S.insert i c.sig , env = c.env , count = c.count } ) - -data Error = UnboundVar String - -instance Show Error where - show (UnboundVar str) = "Unbound variable: " <> str +-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@. 
+makeName :: Ident -> Rn Ident +makeName (Ident prefix) = gets (\i -> Ident $ prefix ++ "_" ++ show i) <* modify succ diff --git a/src/Renamer/RenamerIr.hs b/src/Renamer/RenamerIr.hs deleted file mode 100644 index 77e2f1f..0000000 --- a/src/Renamer/RenamerIr.hs +++ /dev/null @@ -1,32 +0,0 @@ -{-# LANGUAGE LambdaCase #-} - -module Renamer.RenamerIr ( - RExp (..), - RBind (..), - RProgram (..), - Ident (..), - Type (..), -) where - -import Grammar.Abs ( - Bind (..), - Ident (..), - Program (..), - Type (..), - ) -import Grammar.Print - -data RProgram = RProgram [RBind] - deriving (Eq, Show, Read, Ord) - -data RBind = RBind Ident RExp - deriving (Eq, Show, Read, Ord) - -data RExp - = RAnn RExp Type - | RId Ident - | RInt Integer - | RApp RExp RExp - | RAdd RExp RExp - | RAbs Integer Ident RExp - deriving (Eq, Ord, Show, Read) diff --git a/src/Renamer/RenamerM.hs b/src/Renamer/RenamerM.hs deleted file mode 100644 index 5fb1fa2..0000000 --- a/src/Renamer/RenamerM.hs +++ /dev/null @@ -1,83 +0,0 @@ -{-# LANGUAGE LambdaCase #-} - -module Renamer.RenamerM where - -import Auxiliary (mapAccumM) -import Control.Monad.State (MonadState, State, evalState, gets, - modify) -import Data.Map (Map) -import qualified Data.Map as Map -import Data.Maybe (fromMaybe) -import Data.Tuple.Extra (dupe) -import Grammar.Abs - - --- | Rename all variables and local binds -rename :: Program -> Program -rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0 - where - initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs - renameSc :: Names -> Bind -> Rn Bind - renameSc old_names (Bind name t _ parms rhs) = do - (new_names, parms') <- newNames old_names parms - rhs' <- snd <$> renameExp new_names rhs - pure $ Bind name t name parms' rhs' - - --- | Rename monad. State holds the number of renamed names. 
-newtype Rn a = Rn { runRn :: State Int a } - deriving (Functor, Applicative, Monad, MonadState Int) - --- | Maps old to new name -type Names = Map Ident Ident - -renameLocalBind :: Names -> Bind -> Rn (Names, Bind) -renameLocalBind old_names (Bind name t _ parms rhs) = do - (new_names, name') <- newName old_names name - (new_names', parms') <- newNames new_names parms - (new_names'', rhs') <- renameExp new_names' rhs - pure (new_names'', Bind name' t name' parms' rhs') - -renameExp :: Names -> Exp -> Rn (Names, Exp) -renameExp old_names = \case - EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names) - - EInt i1 -> pure (old_names, EInt i1) - - EApp e1 e2 -> do - (env1, e1') <- renameExp old_names e1 - (env2, e2') <- renameExp old_names e2 - pure (Map.union env1 env2, EApp e1' e2') - - EAdd e1 e2 -> do - (env1, e1') <- renameExp old_names e1 - (env2, e2') <- renameExp old_names e2 - pure (Map.union env1 env2, EAdd e1' e2') - - ELet i e1 e2 -> do - (new_names, e1') <- renameExp old_names e1 - (new_names', e2') <- renameExp new_names e2 - pure (new_names', ELet i e1' e2') - - EAbs par e -> do - (new_names, par') <- newName old_names par - (new_names', e') <- renameExp new_names e - pure (new_names', EAbs par' e') - - EAnn e t -> do - (new_names, e') <- renameExp old_names e - pure (new_names, EAnn e' t) - --- | Create a new name and add it to name environment. -newName :: Names -> Ident -> Rn (Names, Ident) -newName env old_name = do - new_name <- makeName old_name - pure (Map.insert old_name new_name env, new_name) - --- | Create multiple names and add them to the name environment -newNames :: Names -> [Ident] -> Rn (Names, [Ident]) -newNames = mapAccumM newName - --- | Annotate name with number and increment the number @prefix ⇒ prefix_number@. 
-makeName :: Ident -> Rn Ident -makeName (Ident prefix) = gets (\i -> Ident $ prefix ++ "_" ++ show i) <* modify succ diff --git a/src/TypeChecker/AlgoW.hs b/src/TypeChecker/AlgoW.hs deleted file mode 100644 index de931d1..0000000 --- a/src/TypeChecker/AlgoW.hs +++ /dev/null @@ -1,238 +0,0 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedStrings #-} -{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} -{-# HLINT ignore "Use traverse_" #-} - -module TypeChecker.AlgoW where - -import Control.Monad.Except -import Control.Monad.Reader -import Control.Monad.State -import Data.Bifunctor (bimap, second) -import Data.Functor.Identity (Identity, runIdentity) -import Data.List (foldl', intersect) -import Data.Map (Map) -import qualified Data.Map as M -import Data.Maybe (fromMaybe) -import Data.Set (Set) -import qualified Data.Set as S - -import Grammar.Abs -import Grammar.Print (Print, printTree) -import qualified TypeChecker.HMIr as T - --- | A data type representing type variables -data Poly = Forall [Ident] Type - deriving Show - -newtype Ctx = Ctx { vars :: Map Ident Poly } - -data Env = Env { count :: Int - , sigs :: Map Ident Type - } - -type Error = String -type Subst = Map Ident Type - -type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity)) - -initCtx = Ctx mempty -initEnv = Env 0 mempty - -runPretty :: Print a => Infer a -> Either Error String -runPretty = fmap printTree . run - -run :: Infer a -> Either Error a -run = runC initEnv initCtx - -runC :: Env -> Ctx -> Infer a -> Either Error a -runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e - -typecheck :: Program -> Either Error T.Program -typecheck = run . 
checkPrg - -checkPrg :: Program -> Infer T.Program -checkPrg (Program bs) = do - traverse (\(Bind n t _ _ _) -> insertSig n t) bs - bs' <- mapM checkBind bs - return $ T.Program bs' - -checkBind :: Bind -> Infer T.Bind -checkBind (Bind n t _ args e) = do - (t', e') <- inferExp $ makeLambda e (reverse args) - s <- unify t t' - let t'' = apply s t - return $ T.Bind (t'',n) [] e' - where - makeLambda :: Exp -> [Ident] -> Exp - makeLambda = foldl (flip EAbs) - -inferExp :: Exp -> Infer (Type, T.Exp) -inferExp e = do - (s, t, e') <- w e - let subbed = apply s t - return (subbed, replace subbed e') - -replace :: Type -> T.Exp -> T.Exp -replace t = \case - T.EInt t' e -> T.EInt t e - T.EId t' i -> T.EId t i - T.EAbs t' name e -> T.EAbs t name e - T.EApp t' e1 e2 -> T.EApp t e1 e2 - T.EAdd t' e1 e2 -> T.EAdd t e1 e2 - T.ELet t' name e1 e2 -> T.ELet t name e1 e2 - -w :: Exp -> Infer (Subst, Type, T.Exp) -w = \case - EAnn e t -> do - (s1, t', e') <- w e - applySt s1 $ do - s2 <- unify (apply s1 t) t' - return (s2 `compose` s1, t, e') - EInt n -> return (nullSubst, TMono "Int", T.EInt (TMono "Int") n) - EId i -> do - var <- asks vars - case M.lookup i var of - Nothing -> throwError $ "Unbound variable: " ++ show i - Just t -> inst t >>= \x -> return (nullSubst, x, T.EId x i) - EAbs name e -> do - fr <- fresh - withBinding name (Forall [] fr) $ do - (s1, t', e') <- w e - let newArr = TArr (apply s1 fr) t' - return (s1, newArr, T.EAbs newArr name e') - EAdd e0 e1 -> do - (s1, t0, e0') <- w e0 - applySt s1 $ do - (s2, t1, e1') <- w e1 - applySt s2 $ do - s3 <- unify (subst s2 t0) (TMono "Int") - s4 <- unify (subst s3 t1) (TMono "Int") - return (s4 `compose` s3 `compose` s2 `compose` s1, TMono "Int", T.EAdd (TMono "Int") e0' e1') - EApp e0 e1 -> do - fr <- fresh - (s1, t0, e0') <- w e0 - applySt s1 $ do - (s2, t1, e1') <- w e1 - applySt s2 $ do - s3 <- unify (subst s2 t0) (TArr t1 fr) - let t = apply s3 fr - return (s3 `compose` s2 `compose` s1, t, T.EApp t e0' e1') - ELet name 
e0 e1 -> do - (s1, t1, e0') <- w e0 - env <- asks vars - let t' = generalize (apply s1 env) t1 - withBinding name t' $ do - (s2, t2, e1') <- w e1 - return (s2 `compose` s1, t2, T.ELet t2 name e0' e1' ) - --- | Unify two types producing a new substitution (constraint) -unify :: Type -> Type -> Infer Subst -unify t0 t1 = case (t0, t1) of - (TArr a b, TArr c d) -> do - s1 <- unify a c - s2 <- unify (subst s1 b) (subst s1 c) - return $ s1 `compose` s2 - (TPol a, b) -> occurs a b - (a, TPol b) -> occurs b a - (TMono a, TMono b) -> if a == b then return M.empty else throwError "Types do not unify" - (a, b) -> throwError . unwords $ ["Type:", printTree a, "can't be unified with:", printTree b] - --- | Check if a type is contained in another type. --- I.E. { a = a -> b } is an unsolvable constraint since there is no substitution such that these are equal -occurs :: Ident -> Type -> Infer Subst -occurs i (TPol a) = return nullSubst -occurs i t = if S.member i (free t) - then throwError "Occurs check failed" - else return $ M.singleton i t - --- | Generalize a type over all free variables in the substitution set -generalize :: Map Ident Poly -> Type -> Poly -generalize env t = Forall (S.toList $ free t S.\\ free env) t - --- | Instantiate a polymorphic type. The free type variables are substituted with fresh ones. 
-inst :: Poly -> Infer Type -inst (Forall xs t) = do - xs' <- mapM (const fresh) xs - let s = M.fromList $ zip xs xs' - return $ apply s t - -compose :: Subst -> Subst -> Subst -compose m1 m2 = M.map (subst m1) m2 `M.union` m1 - --- | A class representing free variables functions -class FreeVars t where - -- | Get all free variables from t - free :: t -> Set Ident - -- | Apply a substitution to t - apply :: Subst -> t -> t - -instance FreeVars Type where - free :: Type -> Set Ident - free (TPol a) = S.singleton a - free (TMono _) = mempty - free (TArr a b) = free a `S.union` free b - apply :: Subst -> Type -> Type - apply sub t = do - case t of - TMono a -> TMono a - TPol a -> case M.lookup a sub of - Nothing -> TPol a - Just t -> t - TArr a b -> TArr (apply sub a) (apply sub b) - -instance FreeVars Poly where - free :: Poly -> Set Ident - free (Forall xs t) = free t S.\\ S.fromList xs - apply :: Subst -> Poly -> Poly - apply s (Forall xs t) = Forall xs (apply (foldr M.delete s xs) t) - -instance FreeVars (Map Ident Poly) where - free :: Map Ident Poly -> Set Ident - free m = foldl' S.union S.empty (map free $ M.elems m) - apply :: Subst -> Map Ident Poly -> Map Ident Poly - apply s = M.map (apply s) - -applySt :: Subst -> Infer a -> Infer a -applySt s = local (\st -> st { vars = apply s (vars st) }) - --- | Represents the empty substition set -nullSubst :: Subst -nullSubst = M.empty - --- | Substitute type variables with their mappings from the substitution set. -subst :: Subst -> Type -> Type -subst m t = do - case t of - TPol a -> fromMaybe t (M.lookup a m) - TMono a -> TMono a - TArr a b -> TArr (subst m a) (subst m b) - --- | Generate a new fresh variable and increment the state counter -fresh :: Infer Type -fresh = do - n <- gets count - modify (\st -> st { count = n + 1 }) - return . TPol . 
Ident $ "t" ++ show n - --- | Run the monadic action with an additional binding -withBinding :: (Monad m, MonadReader Ctx m) => Ident -> Poly -> m a -> m a -withBinding i p = local (\st -> st { vars = M.insert i p (vars st) }) - --- | Insert a function signature into the environment -insertSig :: Ident -> Type -> Infer () -insertSig i t = modify (\st -> st { sigs = M.insert i t (sigs st) }) - --- | Lookup a variable in the context -lookupVar :: Ident -> Infer Poly -lookupVar i = do - m <- asks vars - case M.lookup i m of - Just t -> return t - Nothing -> throwError $ "Unbound variable: " ++ show i - -lett = let (Right (t,e)) = run $ inferExp $ ELet "x" (EAdd (EInt 5) (EInt 5)) (EAdd (EId "x") (EId "x")) - in t == TMono "Int" - -letty = let (Right (t,e)) = run $ inferExp $ ELet "f" (EAbs "x" (EId "x")) (EApp (EId "f") (EInt 3)) - in e diff --git a/src/TypeChecker/HM.hs b/src/TypeChecker/HM.hs deleted file mode 100644 index 7b33cbe..0000000 --- a/src/TypeChecker/HM.hs +++ /dev/null @@ -1,181 +0,0 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedStrings #-} -{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} -{-# HLINT ignore "Use traverse_" #-} -{-# LANGUAGE FlexibleInstances #-} - -module TypeChecker.HM where - -import Control.Monad.Except -import Control.Monad.State -import Data.Bifunctor (bimap, second) -import Data.Functor.Identity (Identity, runIdentity) -import Data.Map (Map) -import qualified Data.Map as M - -import Grammar.Abs -import Grammar.Print -import qualified TypeChecker.HMIr as T - -type Infer = StateT Ctx (ExceptT String Identity) -type Error = String - -data Ctx = Ctx { constr :: Map Type Type - , vars :: Map Ident Type - , sigs :: Map Ident Type - , frsh :: Char } - deriving Show - -runC :: Ctx -> Infer a -> Either String (a, Ctx) -runC c = runIdentity . runExceptT . flip runStateT c - -run :: Infer a -> Either String a -run = runIdentity . runExceptT . 
flip evalStateT initC - -initC :: Ctx -initC = Ctx M.empty M.empty M.empty 'a' - -typecheck :: Program -> Either Error T.Program -typecheck = run . inferPrg - -inferPrg :: Program -> Infer T.Program -inferPrg (Program bs) = do - traverse (\(Bind n t _ _ _) -> insertSig n t) bs - bs' <- mapM inferBind bs - return $ T.Program bs' - -inferBind :: Bind -> Infer T.Bind -inferBind (Bind i t _ params rhs) = do - (t',e') <- inferExp (makeLambda rhs (reverse params)) - when (t /= t') (throwError . unwords $ [ "Signature of function" - , show i - , "with type:" - , show t - , "does not match inferred type" - , show t' - , "of expression:" - , show e']) - return $ T.Bind (t,i) [] e' - -makeLambda :: Exp -> [Ident] -> Exp -makeLambda = foldl (flip EAbs) - -inferExp :: Exp -> Infer (Type, T.Exp) -inferExp e = do - (t, e') <- inferExp' e - t'' <- solveConstraints t - return (t'', replaceType t'' e') - - where - inferExp' :: Exp -> Infer (Type, T.Exp) - inferExp' = \case - EAnn e t -> do - (t',e') <- inferExp' e - t'' <- solveConstraints t' - when (t'' /= t) (throwError "Annotated type and inferred type don't match") - return (t', e') - EInt i -> return (int, T.EInt int i) - EId i -> (\t -> (t, T.EId t i)) <$> lookupVar i - EAdd e1 e2 -> do - insertSig "+" (TArr int (TArr int int)) - inferExp' (EApp (EApp (EId "+") e1) e2) - EApp e1 e2 -> do - (t1, e1') <- inferExp' e1 - (t2, e2') <- inferExp' e2 - fr <- fresh - addConstraint t1 (TArr t2 fr) - return (fr, T.EApp fr e1' e2') - EAbs name e -> do - fr <- fresh - insertVar name fr - (ret_t,e') <- inferExp' e - t <- solveConstraints (TArr fr ret_t) - return (t, T.EAbs t name e') - ELet name e1 e2 -> error "Let expression not implemented yet" - -replaceType :: Type -> T.Exp -> T.Exp -replaceType t = \case - T.EInt _ i -> T.EInt t i - T.EId _ i -> T.EId t i - T.EAdd _ e1 e2 -> T.EAdd t e1 e2 - T.EApp _ e1 e2 -> T.EApp t e1 e2 - T.EAbs _ name e -> T.EAbs t name e - T.ELet _ name e1 e2 -> T.ELet t name e1 e2 - -isInt :: Type -> Bool 
-isInt (TMono "Int") = True -isInt _ = False - -lookupVar :: Ident -> Infer Type -lookupVar i = do - st <- get - case M.lookup i (vars st) of - Just t -> return t - Nothing -> case M.lookup i (sigs st) of - Just t -> return t - Nothing -> throwError $ "Unbound variable or function" ++ printTree i - -insertVar :: Ident -> Type -> Infer () -insertVar s t = modify ( \st -> st { vars = M.insert s t (vars st) } ) - -insertSig :: Ident -> Type -> Infer () -insertSig s t = modify ( \st -> st { sigs = M.insert s t (sigs st) } ) - --- | Generate a new fresh variable and increment the state -fresh :: Infer Type -fresh = do - chr <- gets frsh - modify (\st -> st { frsh = succ chr }) - return $ TPol (Ident [chr]) - --- | Adds a constraint to the constraint set. --- i.e: a = int -> b --- b = int --- thus when solving constraints it must be the case that --- a = int -> int -addConstraint :: Type -> Type -> Infer () -addConstraint t1 t2 = do - modify (\st -> st { constr = M.insert t1 t2 (constr st) }) - --- | Given a type, solve the constraints and figure out the type that should be assigned to it. -solveConstraints :: Type -> Infer Type -solveConstraints t = do - c <- gets constr - v <- gets vars - xs <- solveAll (M.toList c) - modify (\st -> st { constr = M.fromList xs }) - return $ subst t xs - --- | Substitute -subst :: Type -> [(Type, Type)] -> Type -subst t [] = t -subst (TArr t1 t2) (x:xs) = subst (TArr (replace x t1) (replace x t2)) xs -subst t (x:xs) = subst (replace x t) xs - --- | Given a set of constraints run the replacement on all of them, producing a new set of --- replacements. 
--- https://youtu.be/trmq3wYcUxU - good video for explanation -solveAll :: [(Type, Type)] -> Infer [(Type, Type)] -solveAll [] = return [] -solveAll (x:xs) = case x of - (TArr t1 t2, TArr t3 t4) -> solveAll $ (t1,t3) : (t2,t4) : xs - (TArr t1 t2, b) -> fmap ((b, TArr t1 t2) :) $ solveAll $ solve (b, TArr t1 t2) xs - (a, TArr t1 t2) -> fmap ((a, TArr t1 t2) :) $ solveAll $ solve (a, TArr t1 t2) xs - (TMono a, TPol b) -> fmap ((TPol b, TMono a) :) $ solveAll $ solve (TPol b, TMono a) xs - (TPol a, TMono b) -> fmap ((TPol a, TMono b) :) $ solveAll $ solve (TPol a, TMono b) xs - (TPol a, TPol b) -> fmap ((TPol a, TPol b) :) $ solveAll $ solve (TPol a, TPol b) xs - (TMono a, TMono b) -> if a == b then solveAll xs else throwError "Can't unify types" - -solve :: (Type, Type) -> [(Type, Type)] -> [(Type, Type)] -solve x = map (both (replace x)) - --- | Given a constraint (type, type) and a type, if the constraint matches the input --- replace with the constrained type -replace :: (Type, Type) -> Type -> Type -replace a (TArr t1 t2) = TArr (replace a t1) (replace a t2) -replace (a,b) c = if a==c then b else c - -both :: (a -> b) -> (a,a) -> (b,b) -both f = bimap f f - -int = TMono "Int" diff --git a/src/TypeChecker/HMIr.hs b/src/TypeChecker/HMIr.hs deleted file mode 100644 index 0a6085c..0000000 --- a/src/TypeChecker/HMIr.hs +++ /dev/null @@ -1,110 +0,0 @@ -{-# LANGUAGE LambdaCase #-} - -module TypeChecker.HMIr - ( module Grammar.Abs - , module TypeChecker.HMIr - ) where - -import Grammar.Abs (Ident (..), Type (..)) -import Grammar.Print -import Prelude -import qualified Prelude as C (Eq, Ord, Read, Show) - -newtype Program = Program [Bind] - deriving (C.Eq, C.Ord, C.Show, C.Read) - -data Exp - = EId Type Ident - | EInt Type Integer - | ELet Type Ident Exp Exp - | EApp Type Exp Exp - | EAdd Type Exp Exp - | EAbs Type Ident Exp - deriving (C.Eq, C.Ord, C.Read) - -instance Show Exp where - show (EId t (Ident i)) = i ++ " : " ++ show t - show (EInt _ i) = show i - show (ELet t 
i e1 e2) = "let " ++ show t ++ " = " ++ show e1 ++ " in " ++ show e2 - show (EApp t e1 e2) = show e1 ++ " " ++ show e2 ++ " : " ++ show t - show (EAdd _ e1 e2) = show e1 ++ " + " ++ show e2 - show (EAbs t (Ident i) e) = "\\ " ++ i ++ ". " ++ show e ++ " : " ++ show t - -type Id = (Type, Ident) - -data Bind = Bind Id [Id] Exp - deriving (C.Eq, C.Ord, C.Show, C.Read) - -instance Print Program where - prt i (Program sc) = prPrec i 0 $ prt 0 sc - -instance Print Bind where - prt i (Bind name@(n, _) parms rhs) = prPrec i 0 $ concatD - [ prtId 0 name - , doc $ showString ";" - , prt 0 n - , prtIdPs 0 parms - , doc $ showString "=" - , prt 0 rhs - ] - -instance Print [Bind] where - prt _ [] = concatD [] - prt _ [x] = concatD [prt 0 x] - prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] - -prtIdPs :: Int -> [Id] -> Doc -prtIdPs i = prPrec i 0 . concatD . map (prtIdP i) - -prtId :: Int -> Id -> Doc -prtId i (name, t) = prPrec i 0 $ concatD - [ prt 0 name - , doc $ showString ":" - , prt 0 t - ] - -prtIdP :: Int -> Id -> Doc -prtIdP i (name, t) = prPrec i 0 $ concatD - [ doc $ showString "(" - , prt 0 name - , doc $ showString ":" - , prt 0 t - , doc $ showString ")" - ] - - -instance Print Exp where - prt i = \case - EId _ n -> prPrec i 3 $ concatD [prt 0 n] - EInt _ i1 -> prPrec i 3 $ concatD [prt 0 i1] - ELet _ name e1 e2 -> prPrec i 3 $ concatD - [ doc $ showString "let" - , prt 0 name - , prt 0 e1 - , doc $ showString "in" - , prt 0 e2 - ] - EApp t e1 e2 -> prPrec i 2 $ concatD - [ doc $ showString "@" - , prt 0 t - , prt 2 e1 - , prt 3 e2 - ] - EAdd t e1 e2 -> prPrec i 1 $ concatD - [ doc $ showString "@" - , prt 0 t - , prt 1 e1 - , doc $ showString "+" - , prt 2 e2 - ] - EAbs t n e -> prPrec i 0 $ concatD - [ doc $ showString "@" - , prt 0 t - , doc $ showString "\\" - , prt 0 n - , doc $ showString "." 
- , prt 0 e - ] - - - diff --git a/src/TypeChecker/TypeChecker.hs b/src/TypeChecker/TypeChecker.hs index 99a1e17..0d9ace9 100644 --- a/src/TypeChecker/TypeChecker.hs +++ b/src/TypeChecker/TypeChecker.hs @@ -1,153 +1,250 @@ --- {-# LANGUAGE LambdaCase #-} --- {-# LANGUAGE OverloadedRecordDot #-} --- {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} +{-# HLINT ignore "Use traverse_" #-} +{-# OPTIONS_GHC -Wno-overlapping-patterns #-} module TypeChecker.TypeChecker where --- import Control.Monad (void) --- import Control.Monad.Except (ExceptT, runExceptT, throwError) --- import Control.Monad.State (StateT) --- import qualified Control.Monad.State as St --- import Data.Functor.Identity (Identity, runIdentity) --- import Data.Map (Map) --- import qualified Data.Map as M +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.State +import Data.Functor.Identity (Identity, runIdentity) +import Data.List (foldl') +import Data.Map (Map) +import qualified Data.Map as M +import Data.Set (Set) +import qualified Data.Set as S --- import TypeChecker.TypeCheckerIr +import Grammar.Abs +import Grammar.Print (printTree) +import qualified TypeChecker.TypeCheckerIr as T --- data Ctx = Ctx --- { vars :: Map Integer Type --- , sigs :: Map Ident Type --- , nextFresh :: Int --- } --- deriving (Show) +-- | A data type representing type variables +data Poly = Forall [Ident] Type + deriving Show --- -- Perhaps swap over to reader monad instead for vars and sigs. --- type Infer = StateT Ctx (ExceptT Error Identity) +newtype Ctx = Ctx { vars :: Map Ident Poly } --- {- +data Env = Env { count :: Int + , sigs :: Map Ident Type + } --- The type checker will assume we first rename all variables to unique name, as to not --- have to care about scoping. It significantly improves the quality of life of the --- programmer. 
+type Error = String +type Subst = Map Ident Type --- TODOs: --- Add skolemization variables. i.e --- { \x. 3 : forall a. a -> a } --- should not type check +type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity)) --- Generalize. Not really sure what that means though +initCtx = Ctx mempty +initEnv = Env 0 mempty --- -} +runPretty :: Exp -> Either Error String +runPretty = fmap (printTree . fst). run . inferExp --- typecheck :: RProgram -> Either Error TProgram --- typecheck = todo +run :: Infer a -> Either Error a +run = runC initEnv initCtx --- run :: Infer a -> Either Error a --- run = runIdentity . runExceptT . flip St.evalStateT (Ctx mempty mempty 0) +runC :: Env -> Ctx -> Infer a -> Either Error a +runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e --- -- Have to figure out a way to coerce polymorphic types to monomorphic ones where necessary --- -- { \x. \y. x + y } will have the type { a -> b -> Int } --- inferExp :: RExp -> Infer Type --- inferExp = \case +typecheck :: Program -> Either Error T.Program +typecheck = run . checkPrg --- RAnn expr typ -> do --- t <- inferExp expr --- void $ t =:= typ --- return t +checkPrg :: Program -> Infer T.Program +checkPrg (Program bs) = do + let bs' = getBinds bs + traverse (\(Bind n t _ _ _) -> insertSig n t) bs' + bs' <- mapM checkBind bs' + return $ T.Program bs' + where + getBinds :: [Def] -> [Bind] + getBinds = map toBind . 
filter isBind
+    isBind :: Def -> Bool
+    isBind (DBind _) = True
+    isBind _ = False
+    toBind :: Def -> Bind
+    toBind (DBind bind) = bind
+    toBind _ = error "Can't convert DData to Bind"

--- RBound num name -> lookupVars num
+checkBind :: Bind -> Infer T.Bind
+checkBind (Bind n t _ args e) = do
+    (t', e') <- inferExp $ makeLambda e (reverse args)
+    s <- unify t t'
+    let t'' = apply s t
+    unless (t `typeEq` t'') (throwError $ unwords ["Top level signature", printTree t, "does not match body with type:", printTree t''])
+    return $ T.Bind (n, t) [] e'
+  where
+    makeLambda :: Exp -> [Ident] -> Exp
+    makeLambda = foldl (flip EAbs)

--- RFree name -> lookupSigs name
+typeEq :: Type -> Type -> Bool
+typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r'
+typeEq (TMono a) (TMono b) = a == b
+typeEq (TPol _) (TPol _) = True
+typeEq _ _ = False

--- RConst (CInt i) -> return $ TMono "Int"
+inferExp :: Exp -> Infer (Type, T.Exp)
+inferExp e = do
+    (s, t, e') <- w e
+    let subbed = apply s t
+    return (subbed, replace subbed e')

--- RConst (CStr str) -> return $ TMono "Str"
+replace :: Type -> T.Exp -> T.Exp
+replace t = \case
+    T.ELit _ e -> T.ELit t e
+    T.EId (n, _) -> T.EId (n, t)
+    T.EAbs _ name e -> T.EAbs t name e
+    T.EApp _ e1 e2 -> T.EApp t e1 e2
+    T.EAdd _ e1 e2 -> T.EAdd t e1 e2
+    T.ELet (T.Bind (n, _) args e1) e2 -> T.ELet (T.Bind (n, t) args e1) e2

--- RAdd expr1 expr2 -> do
--- let int = TMono "Int"
--- typ1 <- check expr1 int
--- typ2 <- check expr2 int
--- return int
+w :: Exp -> Infer (Subst, Type, T.Exp)
+w = \case

--- RApp expr1 expr2 -> do
--- fn_t <- inferExp expr1
--- arg_t <- inferExp expr2
--- res <- fresh
--- new_t <- fn_t =:= TArrow arg_t res
--- return res
+    EAnn e t -> do
+        (s1, t', e') <- w e
+        applySt s1 $ do
+            s2 <- unify (apply s1 t) t'
+            return (s2 `compose` s1, t, e')

--- RAbs num name expr -> do
--- arg <- fresh
--- insertVars num arg
--- typ <- inferExp expr
--- return $ TArrow arg typ
+    ELit (LInt n) -> return (nullSubst, TMono "Int", T.ELit (TMono "Int") (LInt n))

--- check :: RExp -> Type -> Infer ()
--- check e t = do
--- t' <- inferExp e
--- t =:= t'
--- return ()
+    ELit a -> error $ "NOT IMPLEMENTED YET: ELit " ++ show a

--- fresh :: Infer Type
--- fresh = do
--- var <- St.gets nextFresh
--- St.modify (\st -> st {nextFresh = succ var})
--- return (TPoly $ Ident (show var))
+    EId i -> do
+        var <- asks vars
+        case M.lookup i var of
+            Just t -> inst t >>= \x -> return (nullSubst, x, T.EId (i, x))
+            Nothing -> do
+                sig <- gets sigs
+                case M.lookup i sig of
+                    Nothing -> throwError $ "Unbound variable: " ++ show i
+                    Just t -> return (nullSubst, t, T.EId (i, t))

--- -- | Unify two types.
--- (=:=) :: Type -> Type -> Infer Type
--- (=:=) (TPoly _) b = return b
--- (=:=) a (TPoly _) = return a
--- (=:=) (TMono a) (TMono b) | a == b = return (TMono a)
--- (=:=) (TArrow a b) (TArrow c d) = do
--- t1 <- a =:= c
--- t2 <- b =:= d
--- return $ TArrow t1 t2
--- (=:=) a b = throwError (TypeMismatch $ unwords ["Can not unify type", show a, "with", show b])
+    EAbs name e -> do
+        fr <- fresh
+        withBinding name (Forall [] fr) $ do
+            (s1, t', e') <- w e
+            let varType = apply s1 fr
+            let newArr = TArr varType t'
+            return (s1, newArr, T.EAbs newArr (name, varType) e')

--- lookupVars :: Integer -> Infer Type
--- lookupVars i = do
--- st <- St.gets vars
--- case M.lookup i st of
--- Just t -> return t
--- Nothing -> throwError $ UnboundVar "lookupVars"
+    EAdd e0 e1 -> do
+        (s1, t0, e0') <- w e0
+        applySt s1 $ do
+            (s2, t1, e1') <- w e1
+            applySt s2 $ do
+                s3 <- unify (apply s2 t0) (TMono "Int")
+                s4 <- unify (apply s3 t1) (TMono "Int")
+                return (s4 `compose` s3 `compose` s2 `compose` s1, TMono "Int", T.EAdd (TMono "Int") e0' e1')

--- insertVars :: Integer -> Type -> Infer ()
--- insertVars i t = do
--- st <- St.get
--- St.put (st {vars = M.insert i t st.vars})
+    EApp e0 e1 -> do
+        fr <- fresh
+        (s0, t0, e0') <- w e0
+        applySt s0 $ do
+            (s1, t1, e1') <- w e1
+            -- applySt s1 $ do
+            s2 <- unify (apply s1 t0) (TArr t1 fr)
+            let t = apply s2 fr
+            return (s2 `compose` s1 `compose` s0, t, T.EApp t e0' e1')

--- lookupSigs :: Ident -> Infer Type
--- lookupSigs i = do
--- st <- St.gets sigs
--- case M.lookup i st of
--- Just t -> return t
--- Nothing -> throwError $ UnboundVar "lookupSigs"
+    ELet name e0 e1 -> do
+        (s1, t1, e0') <- w e0
+        env <- asks vars
+        let t' = generalize (apply s1 env) t1
+        withBinding name t' $ do
+            (s2, t2, e1') <- w e1
+            return (s2 `compose` s1, t2, T.ELet (T.Bind (name,t2) [] e0') e1' )

--- insertSigs :: Ident -> Type -> Infer ()
--- insertSigs i t = do
--- st <- St.get
--- St.put (st {sigs = M.insert i t st.sigs})
+    ECase a b -> error $ "NOT IMPLEMENTED YET: ECase" ++ show a ++ " " ++ show b

--- {-# WARNING todo "TODO IN CODE" #-}
--- todo :: a
--- todo = error "TODO in code"
+-- | Unify two types producing a new substitution (constraint)
+unify :: Type -> Type -> Infer Subst
+unify t0 t1 = case (t0, t1) of
+    (TArr a b, TArr c d) -> do
+        s1 <- unify a c
+        s2 <- unify (apply s1 b) (apply s1 d)
+        return $ s1 `compose` s2
+    (TPol a, b) -> occurs a b
+    (a, TPol b) -> occurs b a
+    (TMono a, TMono b) -> if a == b then return M.empty else throwError "Types do not unify"
+    (a, b) -> throwError . unwords $ ["Type:", printTree a, "can't be unified with:", printTree b]

--- data Error
--- = TypeMismatch String
--- | NotNumber String
--- | FunctionTypeMismatch String
--- | NotFunction String
--- | UnboundVar String
--- | AnnotatedMismatch String
--- | Default String
--- deriving (Show)
+-- | Check if a type is contained in another type.
+-- I.E. 
{ a = a -> b } is an unsolvable constraint since there is no substitution such that these are equal. Binding a variable to itself yields the empty substitution; any other type that does not contain the variable yields a singleton substitution.
+occurs :: Ident -> Type -> Infer Subst
+occurs i t
+    | t == TPol i = return nullSubst
+    | S.member i (free t) = throwError "Occurs check failed"
+    | otherwise = return $ M.singleton i t
+-- | Generalize a type over its free variables that are not free in the environment
+generalize :: Map Ident Poly -> Type -> Poly
+generalize env t = Forall (S.toList $ free t S.\\ free env) t

--- {-
+-- | Instantiate a polymorphic type. The quantified type variables are substituted with fresh ones.
+inst :: Poly -> Infer Type
+inst (Forall xs t) = do
+    xs' <- mapM (const fresh) xs
+    let s = M.fromList $ zip xs xs'
+    return $ apply s t

--- The procedure inst(σ) specializes the polytype
--- σ by copying the term and replacing the bound type variables
--- consistently by new monotype variables.
+-- | Compose two substitutions: m1 is applied to every type in m2, then the result is unioned with m1 (entries from m2 win on key clashes)
+compose :: Subst -> Subst -> Subst
+compose m1 m2 = M.map (apply m1) m2 `M.union` m1

--- -}
+-- | Things that contain type variables: their free variables can be collected and substitutions applied
+class FreeVars t where
+    -- | Get all free variables from t
+    free :: t -> Set Ident
+    -- | Apply a substitution to t
+    apply :: Subst -> t -> t
+
+instance FreeVars Type where
+    free :: Type -> Set Ident
+    free (TPol a) = S.singleton a
+    free (TMono _) = mempty
+    free (TArr a b) = free a `S.union` free b
+    apply :: Subst -> Type -> Type
+    apply sub t = do
+        case t of
+            TMono a -> TMono a
+            TPol a -> case M.lookup a sub of
+                Nothing -> TPol a
+                Just t -> t
+            TArr a b -> TArr (apply sub a) (apply sub b)
+
+instance FreeVars Poly where
+    free :: Poly -> Set Ident
+    free (Forall xs t) = free t S.\\ S.fromList xs
+    apply :: Subst -> Poly -> Poly
+    apply s (Forall xs t) = Forall xs (apply (foldr M.delete s xs) t)
+
+instance FreeVars (Map Ident Poly) where
+    free :: Map Ident Poly -> Set Ident
+    free m = foldl' S.union S.empty (map free $ M.elems m)
+    apply :: Subst -> Map Ident Poly -> Map Ident Poly
+    apply s = M.map (apply s)
+
+-- | 
Run the action with the substitution applied to the variable environment.
+applySt :: Subst -> Infer a -> Infer a
+applySt s = local (\st -> st { vars = apply s (vars st) })
+
+-- | The empty substitution set
+nullSubst :: Subst
+nullSubst = M.empty
+
+-- | Generate a fresh type variable named @t<n>@ and increment the state counter
+fresh :: Infer Type
+fresh = do
+    n <- gets count
+    modify (\st -> st { count = n + 1 })
+    return . TPol . Ident $ "t" ++ show n
+
+-- | Run the monadic action with an additional variable binding in scope
+withBinding :: (Monad m, MonadReader Ctx m) => Ident -> Poly -> m a -> m a
+withBinding i p = local (\st -> st { vars = M.insert i p (vars st) })
+
+-- | Insert a function signature into the signature map held in the state
+insertSig :: Ident -> Type -> Infer ()
+insertSig i t = modify (\st -> st { sigs = M.insert i t (sigs st) })
diff --git a/src/TypeChecker/TypeCheckerIr.hs b/src/TypeChecker/TypeCheckerIr.hs
index c08d981..c85ebcc 100644
--- a/src/TypeChecker/TypeCheckerIr.hs
+++ b/src/TypeChecker/TypeCheckerIr.hs
@@ -1,74 +1,99 @@
--- {-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE LambdaCase #-}

-module TypeChecker.TypeCheckerIr --(
--- TProgram (..),
--- TBind (..),
--- TExp (..),
--- RProgram (..),
--- RBind (..),
--- RExp (..),
--- Type (..),
--- Const (..),
--- Ident (..),
--- ) where
+module TypeChecker.TypeCheckerIr
+    ( module Grammar.Abs
+    , module TypeChecker.TypeCheckerIr
+    ) where

--- import Grammar.Print
--- import Renamer.RenamerIr
+import Grammar.Abs (Ident (..), Literal (..), Type (..))
+import Grammar.Print
+import Prelude
+import qualified Prelude as C (Eq, Ord, Read, Show)

--- newtype TProgram = TProgram [TBind]
--- deriving (Eq, Show, Read, Ord)
+newtype Program = Program [Bind]
+    deriving (C.Eq, C.Ord, C.Show, C.Read)

--- data TBind = TBind Ident Type TExp
--- deriving (Eq, Show, Read, Ord)
+data Exp
+    = EId Id
+    | ELit Type Literal
+    | ELet Bind Exp
+    | EApp Type Exp Exp
+    | EAdd Type Exp Exp
+    | EAbs Type Id Exp
+    deriving (C.Eq, C.Ord, C.Read, C.Show)

--- data TExp
--- = TAnn TExp Type
--- 
| TBound Integer Ident Type
--- | TFree Ident Type
--- | TConst Const Type
--- | TApp TExp TExp Type
--- | TAdd TExp TExp Type
--- | TAbs Integer Ident TExp Type
--- deriving (Eq, Ord, Show, Read)
+type Id = (Ident, Type)
+
+data Bind = Bind Id [Id] Exp
+    deriving (C.Eq, C.Ord, C.Show, C.Read)
+
+instance Print Program where
+    prt i (Program sc) = prPrec i 0 $ prt 0 sc
+
+instance Print Bind where
+    prt i (Bind (name, t) parms rhs) = prPrec i 0 $ concatD
+        [ prt 0 name
+        , doc $ showString ":"
+        , prt 0 t
+        , prtIdPs 0 parms
+        , doc $ showString "="
+        , prt 0 rhs
+        ]
+
+instance Print [Bind] where
+    prt _ [] = concatD []
+    prt _ [x] = concatD [prt 0 x]
+    prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
+
+prtIdPs :: Int -> [Id] -> Doc
+prtIdPs i = prPrec i 0 . concatD . map (prtIdP i)
+
+prtId :: Int -> Id -> Doc
+prtId i (name, t) = prPrec i 0 $ concatD
+    [ prt 0 name
+    , doc $ showString ":"
+    , prt 0 t
+    ]
+
+prtIdP :: Int -> Id -> Doc
+prtIdP i (name, t) = prPrec i 0 $ concatD
+    [ doc $ showString "("
+    , prt 0 name
+    , doc $ showString ":"
+    , prt 0 t
+    , doc $ showString ")"
+    ]
+
+
+instance Print Exp where
+    prt i = \case
+        EId n -> prPrec i 3 $ concatD [prtId 0 n]
+        ELit _ (LInt i1) -> prPrec i 3 $ concatD [prt 0 i1]
+        ELet bs e -> prPrec i 3 $ concatD
+            [ doc $ showString "let"
+            , prt 0 bs
+            , doc $ showString "in"
+            , prt 0 e
+            ]
+        EApp t e1 e2 -> prPrec i 2 $ concatD
+            [ prt 2 e1
+            , prt 3 e2
+            ]
+        EAdd t e1 e2 -> prPrec i 1 $ concatD
+            [ doc $ showString "@"
+            , prt 0 t
+            , prt 1 e1
+            , doc $ showString "+"
+            , prt 2 e2
+            ]
+        EAbs t n e -> prPrec i 0 $ concatD
+            [ doc $ showString "@"
+            , prt 0 t
+            , doc $ showString "\\"
+            , prtId 0 n
+            , doc $ showString "."
+            , prt 0 e
+            ]

--- instance Print TProgram where
--- prt i = \case
--- TProgram defs -> prPrec i 0 (concatD [prt 0 defs])

--- instance Print TBind where
--- prt i = \case
--- TBind x t e ->
--- prPrec i 0 $
--- concatD
--- [ prt 0 x
--- , doc (showString ":")
--- , prt 0 t
--- , doc (showString "=")
--- , prt 0 e
--- , doc (showString "\n")
--- ]

--- instance Print TExp where
--- prt i = \case
--- TAnn e t ->
--- prPrec i 2 $
--- concatD
--- [ prt 0 e
--- , doc (showString ":")
--- , prt 1 t
--- ]
--- TBound _ u t -> prPrec i 3 $ concatD [prt 0 u]
--- TFree u t -> prPrec i 3 $ concatD [prt 0 u]
--- TConst c _ -> prPrec i 3 (concatD [prt 0 c])
--- TApp e e1 t -> prPrec i 2 $ concatD [prt 2 e, prt 3 e1]
--- TAdd e e1 t -> prPrec i 1 $ concatD [prt 1 e, doc (showString "+"), prt 2 e1]
--- TAbs _ u e t ->
--- prPrec i 0 $
--- concatD
--- [ doc (showString "(")
--- , doc (showString "λ")
--- , prt 0 u
--- , doc (showString ".")
--- , prt 0 e
--- , doc (showString ")")
--- ]
diff --git a/test_program b/test_program
index 3481a0b..69a2c20 100644
--- a/test_program
+++ b/test_program
@@ -1,3 +1,2 @@
-fun : Mono Int -> Mono Int ;
-fun = let f = \x. 
x in f 3 ;
-
+main : _Int ;
+main = 3 + 3 ;
diff --git a/tests/Main.hs b/tests/Main.hs
deleted file mode 100644
index 7432800..0000000
--- a/tests/Main.hs
+++ /dev/null
@@ -1,21 +0,0 @@
-{-# LANGUAGE OverloadedStrings #-}
-
-module Main where
-
-import Grammar.Abs
-import System.Exit (exitFailure)
-import Test.Hspec
-import TypeChecker.AlgoW
-
-main :: IO ()
-main = do
- print "RUNNING TESTS BROTHER"
- exitFailure
- -- hspec $ do
- -- describe "the algorithm W" $ do
- -- it "infers EInt as type Int" $ do
- -- fmap fst (run (inferExp (EInt 1))) `shouldBe` Right (TMono "Int")
- -- it "throws an exception if a variable is inferred with an empty env" $ do
- -- run (inferExp (EId "x")) `shouldBe` Left "Unbound variable: x"
- -- it "throws an exception if the annotated type does not match the inferred type" $ do
- -- fmap fst (run (inferExp (EAnn (EInt 3) (TPol "a")))) `shouldBe` Right (TMono "bad")
diff --git a/tests/Tests.hs b/tests/Tests.hs
new file mode 100644
index 0000000..46a9a3f
--- /dev/null
+++ b/tests/Tests.hs
@@ -0,0 +1,58 @@
+{-# LANGUAGE OverloadedStrings #-}
+{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
+{-# HLINT ignore "Use <$>" #-}
+
+module Main where
+
+import Control.Monad.Except
+import Grammar.Abs
+import Test.QuickCheck
+import TypeChecker.TypeChecker
+import qualified TypeChecker.TypeCheckerIr as T
+
+-- | Run both QuickCheck properties of the type checker.
+main :: IO ()
+main = quickCheck prop_isInt >> quickCheck prop_idAbs_generic
+
+-- | Wrapper for expressions produced by 'genLambda'.
+newtype AbsExp = AE Exp deriving Show
+-- | Wrapper for expressions produced by 'genInt'.
+newtype EIntExp = EI Exp deriving Show
+
+instance Arbitrary EIntExp where
+    arbitrary = genInt
+
+instance Arbitrary AbsExp where
+    arbitrary = genLambda
+
+-- | Run an inference action and keep only the inferred type.
+getType :: Infer (Type, T.Exp) -> Either Error Type
+getType = fmap fst . run
+
+-- | Generate an arbitrary integer literal expression.
+genInt :: Gen EIntExp
+genInt = fmap (EI . ELit . LInt) arbitrary
+
+-- | Generate the identity lambda over an arbitrary variable name.
+genLambda :: Gen AbsExp
+genLambda = do
+    name <- Ident <$> arbitrary
+    pure . AE $ EAbs name (EId name)
+
+-- | The identity lambda must infer to an arrow between the same type variable.
+prop_idAbs_generic :: AbsExp -> Bool
+prop_idAbs_generic (AE e) =
+    either (const False) isGenericArr (getType (inferExp e))
+
+-- | An integer literal must infer to 'int'.
+prop_isInt :: EIntExp -> Bool
+prop_isInt (EI e) = either (const False) (== int) (getType (inferExp e))
+
+-- | The type every integer literal is expected to have.
+int :: Type
+int = TMono "Int"
+
+-- | True when the type is an arrow whose two sides are the same type variable.
+isGenericArr :: Type -> Bool
+isGenericArr (TArr (TPol a) (TPol b)) = a == b
+isGenericArr _ = False