diff --git a/.gitignore b/.gitignore index 5112877..8d1bad3 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ dist-newstyle *.x *.bak src/Grammar -/language +language +llvm.ll diff --git a/Grammar.cf b/Grammar.cf index 410d11d..0b4785f 100644 --- a/Grammar.cf +++ b/Grammar.cf @@ -1,20 +1,25 @@ - - Program. Program ::= [Bind]; EId. Exp3 ::= Ident; EInt. Exp3 ::= Integer; -ELet. Exp3 ::= "let" [Bind] "in" Exp; +EAnn. Exp3 ::= "(" Exp ":" Type ")"; +ELet. Exp3 ::= "let" Bind "in" Exp; EApp. Exp2 ::= Exp2 Exp3; EAdd. Exp1 ::= Exp1 "+" Exp2; -EAbs. Exp ::= "\\" Ident "." Exp; +EAbs. Exp ::= "\\" Ident ":" Type "." Exp; + +Bind. Bind ::= Ident ":" Type ";" + Ident [Ident] "=" Exp; -Bind. Bind ::= Ident [Ident] "=" Exp; separator Bind ";"; -separator Ident " "; +separator Ident ""; coercions Exp 3; +TInt. Type1 ::= "Int" ; +TPol. Type1 ::= Ident ; +TFun. Type ::= Type1 "->" Type ; +coercions Type 1 ; + comment "--"; comment "{-" "-}"; - diff --git a/cabal.project.local b/cabal.project.local new file mode 100644 index 0000000..0432756 --- /dev/null +++ b/cabal.project.local @@ -0,0 +1,2 @@ +ignore-project: False +tests: True diff --git a/language.cabal b/language.cabal index 0577abe..8b958a5 100644 --- a/language.cabal +++ b/language.cabal @@ -1,4 +1,4 @@ -cabal-version: 3.0 +cabal-version: 3.4 name: language @@ -12,18 +12,19 @@ build-type: Simple extra-doc-files: CHANGELOG.md + extra-source-files: Grammar.cf common warnings - ghc-options: -Wall + ghc-options: -W executable language import: warnings main-is: Main.hs - + other-modules: Grammar.Abs Grammar.Lex @@ -33,8 +34,12 @@ executable language Grammar.ErrM LambdaLifter Auxiliary - Interpreter - + Renamer + TypeChecker + TypeCheckerIr +-- Interpreter + Compiler + LlvmIr hs-source-dirs: src build-depends: @@ -44,5 +49,4 @@ executable language , either , array , extra - default-language: GHC2021 diff --git a/sample-programs/basic-1 b/sample-programs/basic-1 index f109950..f0cdcc4 100644 --- a/sample-programs/basic-1 +++ b/sample-programs/basic-1 @@ -1,2 +1,21 @@ -f = \x. x+1; +-- tripplemagic : Int -> Int -> Int -> Int; +-- tripplemagic x y z = ((\x:Int. x+x) x) + y + z; +-- main : Int; +-- main = tripplemagic ((\x:Int. x+x+3) ((\x:Int. x) 2)) 5 3 +-- answer: 22 + +-- apply : (Int -> Int) -> Int -> Int; +-- apply f x = f x; +-- main : Int; +-- main = apply (\x : Int . x + 5) 5 +-- answer: 10 + +apply : (Int -> Int -> Int) -> Int -> Int -> Int; +apply f x y = f x y; +krimp: Int -> Int -> Int; +krimp x y = x + y; +main : Int; +main = apply (krimp) 2 3; +-- answer: 5 + diff --git a/sample-programs/basic-2 b/sample-programs/basic-2 deleted file mode 100644 index 4b8ead0..0000000 --- a/sample-programs/basic-2 +++ /dev/null @@ -1,4 +0,0 @@ -add x = \y. x+y; - -main = (\z. z+z) ((add 4) 6); - diff --git a/sample-programs/basic-3 b/sample-programs/basic-3 deleted file mode 100644 index 9443439..0000000 --- a/sample-programs/basic-3 +++ /dev/null @@ -1,2 +0,0 @@ - -main = (\x. x+x+3) ((\x. x) 2) diff --git a/sample-programs/basic-4 b/sample-programs/basic-4 deleted file mode 100644 index 1de7a8c..0000000 --- a/sample-programs/basic-4 +++ /dev/null @@ -1,2 +0,0 @@ - -f x = let g = (\y. y+1) in g (g x) diff --git a/sample-programs/basic-5 b/sample-programs/basic-5 deleted file mode 100644 index 9984ddd..0000000 --- a/sample-programs/basic-5 +++ /dev/null @@ -1,9 +0,0 @@ -id x = x; - -add x y = x + y; - -double n = n + n; - -apply f x = \y. f x y; - -main = apply (id add) ((\x. x + 1) 1) (double 3); diff --git a/sample-programs/basic-6 b/sample-programs/basic-6 deleted file mode 100644 index 511ae10..0000000 --- a/sample-programs/basic-6 +++ /dev/null @@ -1,3 +0,0 @@ - - -f = \x.\y. x+y diff --git a/sample-programs/basic-7 b/sample-programs/basic-7 deleted file mode 100644 index b3769b9..0000000 --- a/sample-programs/basic-7 +++ /dev/null @@ -1,5 +0,0 @@ -add x y = x + y; - -apply f x = f x; - -main = apply (add 4) 6; diff --git a/sample-programs/basic-8 b/sample-programs/basic-8 deleted file mode 100644 index 59abdac..0000000 --- a/sample-programs/basic-8 +++ /dev/null @@ -1,2 +0,0 @@ - -f x = let double = \y. y+y in (\x. x+y) 4; diff --git a/sample-programs/basic-9 b/sample-programs/basic-9 deleted file mode 100644 index ba9ebdc..0000000 --- a/sample-programs/basic-9 +++ /dev/null @@ -1,4 +0,0 @@ - - - -main = (\f.\x.\y. f x + f y) (\x. x+x) ((\x. x+1) ((\x. x+3) 2)) 4 diff --git a/shell.nix b/shell.nix index 84d3c04..0af8c7b 100644 --- a/shell.nix +++ b/shell.nix @@ -1,5 +1,5 @@ let - pkgs = import (fetchTarball https://github.com/NixOS/nixpkgs/archive/8c619a1f3cedd16ea172146e30645e703d21bfc1.tar.gz) { }; # pin the channel to ensure reproducibility! + pkgs = import (fetchTarball "https://github.com/NixOS/nixpkgs/archive/747927516efcb5e31ba03b7ff32f61f6d47e7d87.zip") { }; # pin the channel to ensure reproducibility! in pkgs.haskellPackages.developPackage { root = ./.; diff --git a/src/Auxiliary.hs b/src/Auxiliary.hs index 2de36a7..735d804 100644 --- a/src/Auxiliary.hs +++ b/src/Auxiliary.hs @@ -1,4 +1,4 @@ - +{-# LANGUAGE LambdaCase #-} module Auxiliary (module Auxiliary) where import Control.Monad.Error.Class (liftEither) import Control.Monad.Except (MonadError) @@ -9,3 +9,13 @@ snoc x xs = xs ++ [x] maybeToRightM :: MonadError l m => l -> Maybe r -> m r maybeToRightM err = liftEither . maybeToRight err + +mapAccumM :: Monad m => (s -> a -> m (s, b)) -> s -> [a] -> m (s, [b]) +mapAccumM f = go + where + go acc = \case + [] -> pure (acc, []) + x:xs -> do + (acc', x') <- f acc x + (acc'', xs') <- go acc' xs + pure (acc'', x':xs') diff --git a/src/Compiler.hs b/src/Compiler.hs index e69de29..fd6b6bc 100644 --- a/src/Compiler.hs +++ b/src/Compiler.hs @@ -0,0 +1,266 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} + +module Compiler (compile) where + +import Auxiliary (snoc) +import Control.Monad.State (StateT, execStateT, gets, modify) +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Tuple.Extra (dupe, first, second) +import Grammar.ErrM (Err) +import LlvmIr (LLVMIr (..), LLVMType (..), + LLVMValue (..), Visibility (..), + llvmIrToString) +import TypeChecker (partitionType) +import TypeCheckerIr + +-- | The record used as the code generator state +data CodeGenerator = CodeGenerator + { instructions :: [LLVMIr] + , functions :: Map Id FunctionInfo + , variableCount :: Integer + } + +-- | A state type synonym +type CompilerState a = StateT CodeGenerator Err a + +data FunctionInfo = FunctionInfo + { numArgs :: Int + , arguments :: [Id] + } + +-- | Adds a instruction to the CodeGenerator state +emit :: LLVMIr -> CompilerState () +emit l = modify $ \t -> t { instructions = snoc l $ instructions t } + +-- | Increases the variable counter in the CodeGenerator state +increaseVarCount :: CompilerState () +increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 } + +-- | Returns the variable count from the CodeGenerator state +getVarCount :: CompilerState Integer +getVarCount = gets variableCount + +-- | Increases the variable count and returns it from the CodeGenerator state +getNewVar :: CompilerState Integer +getNewVar = increaseVarCount >> getVarCount + +-- | Produces a map of functions infos from a list of binds, +-- which contains useful data for code generation. +getFunctions :: [Bind] -> Map Id FunctionInfo +getFunctions bs = Map.fromList $ map go bs + where + go (Bind id args _) = + (id, FunctionInfo { numArgs=length args, arguments=args }) + + + +initCodeGenerator :: [Bind] -> CodeGenerator +initCodeGenerator scs = CodeGenerator { instructions = defaultStart + , functions = getFunctions scs + , variableCount = 0 + } + +-- | Compiles an AST and produces a LLVM Ir string. +-- An easy way to actually "compile" this output is to +-- Simply pipe it to lli +compile :: Program -> Err String +compile (Program scs) = do + let codegen = initCodeGenerator scs + llvmIrToString . instructions <$> execStateT (compileScs scs) codegen + +compileScs :: [Bind] -> CompilerState () +compileScs [] = pure () +compileScs (Bind (name, t) args exp : xs) = do + emit $ UnsafeRaw "\n" + emit . Comment $ show name <> ": " <> show exp + let args' = map (second type2LlvmType) args + emit $ Define (type2LlvmType t_return) name args' + functionBody <- exprToValue exp + if name == "main" + then mapM_ emit $ mainContent functionBody + else emit $ Ret I64 functionBody + emit DefineEnd + modify $ \s -> s { variableCount = 0 } + compileScs xs + where + t_return = snd $ partitionType (length args) t + +mainContent :: LLVMValue -> [LLVMIr] +mainContent var = + [ UnsafeRaw $ + "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n" + , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2)) + -- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2") + -- , Label (Ident "b_1") + -- , UnsafeRaw + -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n" + -- , Br (Ident "end") + -- , Label (Ident "b_2") + -- , UnsafeRaw + -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n" + -- , Br (Ident "end") + -- , Label (Ident "end") + Ret I64 (VInteger 0) + ] + +defaultStart :: [LLVMIr] +defaultStart = [ UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n" + , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" + ] + +compileExp :: Exp -> CompilerState () +compileExp = \case + EInt i -> emitInt i + EAdd t e1 e2 -> emitAdd t e1 e2 + EId (name, _) -> emitIdent name + EApp t e1 e2 -> emitApp t e1 e2 + EAbs t ti e -> emitAbs t ti e + ELet bind e -> emitLet bind e + +--- aux functions --- +emitAbs :: Type -> Id -> Exp -> CompilerState () +emitAbs _t tid e = emit . Comment $ "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e + +emitLet :: Bind -> Exp -> CompilerState () +emitLet b e = emit . Comment $ concat [ "ELet (" + , show b + , " = " + , show e + , ") is not implemented!" + ] + +emitApp :: Type -> Exp -> Exp -> CompilerState () +emitApp t e1 e2 = appEmitter t e1 e2 [] + where + appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState () + appEmitter t e1 e2 stack = do + let newStack = e2 : stack + case e1 of + EApp _ e1' e2' -> appEmitter t e1' e2' newStack + EId id@(name, _) -> do + args <- traverse exprToValue newStack + vs <- getNewVar + funcs <- gets functions + let visibility = maybe Local (const Global) $ Map.lookup id funcs + args' = map (first valueGetType . dupe) args + call = Call (type2LlvmType t) visibility name args' + emit $ SetVariable (Ident $ show vs) call + x -> do + emit . Comment $ "The unspeakable happened: " + emit . Comment $ show x + +emitIdent :: Ident -> CompilerState () +emitIdent id = do + -- !!this should never happen!! + emit $ Comment "This should not have happened!" + emit $ Variable id + emit $ UnsafeRaw "\n" + +emitInt :: Integer -> CompilerState () +emitInt i = do + -- !!this should never happen!! + varCount <- getNewVar + emit $ Comment "This should not have happened!" + emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0)) + +emitAdd :: Type -> Exp -> Exp -> CompilerState () +emitAdd t e1 e2 = do + v1 <- exprToValue e1 + v2 <- exprToValue e2 + v <- getNewVar + emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2) + +-- emitMul :: Exp -> Exp -> CompilerState () +-- emitMul e1 e2 = do +-- (v1,v2) <- binExprToValues e1 e2 +-- increaseVarCount +-- v <- gets variableCount +-- emit $ SetVariable $ Ident $ show v +-- emit $ Mul I64 v1 v2 + +-- emitMod :: Exp -> Exp -> CompilerState () +-- emitMod e1 e2 = do +-- -- `let m a b = rem (abs $ b + a) b` +-- (v1,v2) <- binExprToValues e1 e2 +-- increaseVarCount +-- vadd <- gets variableCount +-- emit $ SetVariable $ Ident $ show vadd +-- emit $ Add I64 v1 v2 +-- +-- increaseVarCount +-- vabs <- gets variableCount +-- emit $ SetVariable $ Ident $ show vabs +-- emit $ Call I64 (Ident "llvm.abs.i64") +-- [ (I64, VIdent (Ident $ show vadd)) +-- , (I1, VInteger 1) +-- ] +-- increaseVarCount +-- v <- gets variableCount +-- emit $ SetVariable $ Ident $ show v +-- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2 + +-- emitDiv :: Exp -> Exp -> CompilerState () +-- emitDiv e1 e2 = do +-- (v1,v2) <- binExprToValues e1 e2 +-- increaseVarCount +-- v <- gets variableCount +-- emit $ SetVariable $ Ident $ show v +-- emit $ Div I64 v1 v2 + +-- emitSub :: Exp -> Exp -> CompilerState () +-- emitSub e1 e2 = do +-- (v1,v2) <- binExprToValues e1 e2 +-- increaseVarCount +-- v <- gets variableCount +-- emit $ SetVariable $ Ident $ show v +-- emit $ Sub I64 v1 v2 + +exprToValue :: Exp -> CompilerState LLVMValue +exprToValue = \case + EInt i -> pure $ VInteger i + + EId id@(name, t) -> do + funcs <- gets functions + case Map.lookup id funcs of + Just fi -> do + if numArgs fi == 0 + then do + vc <- getNewVar + emit $ SetVariable (Ident $ show vc) + (Call (type2LlvmType t) Global name []) + pure $ VIdent (Ident $ show vc) (type2LlvmType t) + else pure $ VFunction name Global (type2LlvmType t) + Nothing -> pure $ VIdent name (type2LlvmType t) + + e -> do + compileExp e + v <- getVarCount + pure $ VIdent (Ident $ show v) (getType e) + +type2LlvmType :: Type -> LLVMType +type2LlvmType = \case + TInt -> I64 + TFun t xs -> do + let (t', xs') = function2LLVMType xs [type2LlvmType t] + Function t' xs' + t -> CustomType $ Ident ("\"" ++ show t ++ "\"") + where + function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType]) + function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s) + function2LLVMType x s = (type2LlvmType x, s) + +getType :: Exp -> LLVMType +getType (EInt _) = I64 +getType (EAdd t _ _) = type2LlvmType t +getType (EId (_, t)) = type2LlvmType t +getType (EApp t _ _) = type2LlvmType t +getType (EAbs t _ _) = type2LlvmType t +getType (ELet _ e) = getType e + +valueGetType :: LLVMValue -> LLVMType +valueGetType (VInteger _) = I64 +valueGetType (VIdent _ t) = t +valueGetType (VConstant s) = Array (length s) I8 +valueGetType (VFunction _ _ t) = t diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 3503a7c..37d46a7 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -35,7 +35,6 @@ initCxt scs = expandLambdas :: Bind -> Bind expandLambdas (Bind name parms rhs) = Bind name [] $ foldr EAbs rhs parms - findMain :: [Bind] -> Err Exp findMain [] = throwError "No main!" findMain (sc:scs) = case sc of diff --git a/src/LambdaLifter.hs b/src/LambdaLifter.hs index 3d9595d..015e7f3 100644 --- a/src/LambdaLifter.hs +++ b/src/LambdaLifter.hs @@ -5,21 +5,22 @@ module LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where import Auxiliary (snoc) -import Data.Foldable.Extra (notNull) -import Data.List (mapAccumL, mapAccumR, partition) -import Data.Map (Map) -import qualified Data.Map as Map -import Data.Maybe (fromMaybe, mapMaybe) -import Data.Set (Set, (\\)) +import Control.Applicative (Applicative (liftA2)) +import Control.Monad.State (MonadState (get, put), State, evalState) +import Data.Set (Set) import qualified Data.Set as Set -import Data.Tuple.Extra (uncurry3) -import Grammar.Abs import Prelude hiding (exp) +import Renamer +import TypeCheckerIr -- | Lift lambdas and let expression into supercombinators. +-- Three phases: +-- @freeVars@ annotatss all the free variables. +-- @abstract@ converts lambdas into let expressions. +-- @collectScs@ moves every non-constant let expression to a top-level function. lambdaLift :: Program -> Program -lambdaLift = collectScs . rename . abstract . freeVars +lambdaLift = collectScs . abstract . freeVars -- | Annotate free variables @@ -28,242 +29,162 @@ freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e) | Bind n xs e <- ds ] -freeVarsExp :: Set Ident -> Exp -> AnnExp +freeVarsExp :: Set Id -> Exp -> AnnExp freeVarsExp localVars = \case + EId n | Set.member n localVars -> (Set.singleton n, AId n) + | otherwise -> (mempty, AId n) - EId n | Set.member n localVars -> (Set.singleton n, AId n) - | otherwise -> (mempty, AId n) + EInt i -> (mempty, AInt i) - EInt i -> (mempty, AInt i) + EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2') + where + e1' = freeVarsExp localVars e1 + e2' = freeVarsExp localVars e2 - EApp e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp e1' e2') - where - e1' = freeVarsExp localVars e1 - e2' = freeVarsExp localVars e2 + EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2') + where + e1' = freeVarsExp localVars e1 + e2' = freeVarsExp localVars e2 - EAdd e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd e1' e2') - where - e1' = freeVarsExp localVars e1 - e2' = freeVarsExp localVars e2 + EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e') + where + e' = freeVarsExp (Set.insert par localVars) e - EAbs par e -> (Set.delete par $ freeVarsOf e', AAbs par e') - where - e' = freeVarsExp (Set.insert par localVars) e + -- Sum free variables present in bind and the expression + ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e') + where + binders_frees = Set.delete name $ freeVarsOf rhs' + e_free = Set.delete name $ freeVarsOf e' - -- Sum free variables present in binders and the expression - ELet binders e -> (Set.union binders_frees e_free, ALet binders' e') - where - binders_frees = rhss_frees \\ names_set - e_free = freeVarsOf e' \\ names_set + rhs' = freeVarsExp e_localVars rhs + new_bind = ABind name parms rhs' - rhss_frees = foldr1 Set.union (map freeVarsOf rhss') - names_set = Set.fromList names - - (names, parms, rhss) = fromBinders binders - rhss' = map (freeVarsExp e_localVars) rhss - e_localVars = Set.union localVars names_set - - binders' = zipWith3 ABind names parms rhss' - e' = freeVarsExp e_localVars e + e' = freeVarsExp e_localVars e + e_localVars = Set.insert name localVars -freeVarsOf :: AnnExp -> Set Ident +freeVarsOf :: AnnExp -> Set Id freeVarsOf = fst -fromBinders :: [Bind] -> ([Ident], [[Ident]], [Exp]) -fromBinders bs = unzip3 [ (name, parms, rhs) | Bind name parms rhs <- bs ] - -- AST annotated with free variables -type AnnProgram = [(Ident, [Ident], AnnExp)] +type AnnProgram = [(Id, [Id], AnnExp)] -type AnnExp = (Set Ident, AnnExp') +type AnnExp = (Set Id, AnnExp') -data ABind = ABind Ident [Ident] AnnExp deriving Show +data ABind = ABind Id [Id] AnnExp deriving Show -data AnnExp' = AId Ident +data AnnExp' = AId Id | AInt Integer - | AApp AnnExp AnnExp - | AAdd AnnExp AnnExp - | AAbs Ident AnnExp - | ALet [ABind] AnnExp + | ALet ABind AnnExp + | AApp Type AnnExp AnnExp + | AAdd Type AnnExp AnnExp + | AAbs Type Id AnnExp deriving Show - -- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@. -- Free variables are @v₁ v₂ .. vₙ@ are bound. abstract :: AnnProgram -> Program -abstract prog = Program $ map go prog +abstract prog = Program $ evalState (mapM go prog) 0 where - go :: (Ident, [Ident], AnnExp) -> Bind - go (name, pars, rhs@(_, e)) = - case e of - AAbs par e1 -> Bind name (snoc par pars ++ pars2) $ abstractExp e2 - where - (e2, pars2) = flattenLambdasAnn e1 - _ -> Bind name pars $ abstractExp rhs + go :: (Id, [Id], AnnExp) -> State Int Bind + go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs' + where + (rhs', parms1) = flattenLambdasAnn rhs -- | Flatten nested lambdas and collect the parameters -- @\x.\y.\z. ae → (ae, [x,y,z])@ -flattenLambdasAnn :: AnnExp -> (AnnExp, [Ident]) +flattenLambdasAnn :: AnnExp -> (AnnExp, [Id]) flattenLambdasAnn ae = go (ae, []) where - go :: (AnnExp, [Ident]) -> (AnnExp, [Ident]) + go :: (AnnExp, [Id]) -> (AnnExp, [Id]) go ((free, e), acc) = - case e of - AAbs par (free1, e1) -> go ((Set.delete par free1, e1), snoc par acc) - _ -> ((free, e), acc) + case e of + AAbs _ par (free1, e1) -> + go ((Set.delete par free1, e1), snoc par acc) + _ -> ((free, e), acc) - -abstractExp :: AnnExp -> Exp +abstractExp :: AnnExp -> State Int Exp abstractExp (free, exp) = case exp of - AId n -> EId n - AInt i -> EInt i - AApp e1 e2 -> EApp (abstractExp e1) (abstractExp e2) - AAdd e1 e2 -> EAdd (abstractExp e1) (abstractExp e2) - ALet bs e -> ELet (map go bs) $ abstractExp e - where - go (ABind name parms rhs) = - let - (rhs', parms1) = flattenLambdas $ skipLambdas abstractExp rhs - in - Bind name (parms ++ parms1) rhs' - - skipLambdas :: (AnnExp -> Exp) -> AnnExp -> Exp - skipLambdas f (free, ae) = case ae of - AAbs name ae1 -> EAbs name $ skipLambdas f ae1 - _ -> f (free, ae) - - -- Lift lambda into let and bind free variables - AAbs par e -> foldl EApp sc $ map EId freeList - where - freeList = Set.toList free - sc = ELet [Bind "sc" (snoc par freeList) $ abstractExp e] $ EId "sc" - --- | Rename all supercombinators and variables -rename :: Program -> Program -rename (Program ds) = Program $ map (uncurry3 Bind) tuples - where - tuples = snd (mapAccumL renameSc 0 ds) - renameSc i (Bind n xs e) = (i2, (n, xs', e')) + AId n -> pure $ EId n + AInt i -> pure $ EInt i + AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2) + AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2) + ALet b e -> liftA2 ELet (go b) (abstractExp e) where - (i1, xs', env) = newNames i xs - (i2, e') = renameExp env i1 e + go (ABind name parms rhs) = do + (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs + pure $ Bind name (parms ++ parms1) rhs' -renameExp :: Map Ident Ident -> Int -> Exp -> (Int, Exp) -renameExp env i = \case + skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp + skipLambdas f (free, ae) = case ae of + AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1 + _ -> f (free, ae) - EId n -> (i, EId . fromMaybe n $ Map.lookup n env) + -- Lift lambda into let and bind free variables + AAbs t parm e -> do + i <- nextNumber + rhs <- abstractExp e - EInt i1 -> (i, EInt i1) + let sc_name = Ident ("sc_" ++ show i) + sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t) - EApp e1 e2 -> (i2, EApp e1' e2') - where - (i1, e1') = renameExp env i e1 - (i2, e2') = renameExp env i1 e2 - - EAdd e1 e2 -> (i2, EAdd e1' e2') - where - (i1, e1') = renameExp env i e1 - (i2, e2') = renameExp env i1 e2 - - ELet bs e -> (i3, ELet (zipWith3 Bind ns' pars' es') e') - where - (i1, e') = renameExp e_env i e - (names, pars, rhss) = fromBinders bs - (i2, ns', env') = newNames i1 (names ++ concat pars) - pars' = (map . map) renamePar pars - e_env = Map.union env' env - (i3, es') = mapAccumL (renameExp e_env) i2 rhss - - renamePar p = case Map.lookup p env' of - Just p' -> p' - Nothing -> error ("Can't find name for " ++ show p) + pure $ foldl (EApp TInt) sc $ map EId freeList + where + freeList = Set.toList free + parms = snoc parm freeList - EAbs par e -> (i2, EAbs par' e') - where - (i1, par', env') = newName par - (i2, e') = renameExp (Map.union env' env ) i1 e +nextNumber :: State Int Int +nextNumber = do + i <- get + put $ succ i + pure i - -newName :: Ident -> (Int, Ident, Map Ident Ident) -newName old_name = (i, head names, env) - where (i, names, env) = newNames 1 [old_name] - -newNames :: Int -> [Ident] -> (Int, [Ident], Map Ident Ident) -newNames i old_names = (i', new_names, env) - where - (i', new_names) = getNames i old_names - env = Map.fromList $ zip old_names new_names - -getNames :: Int -> [Ident] -> (Int, [Ident]) -getNames i ns = (i + length ss, zipWith makeName ss [i..]) - where - ss = map (\(Ident s) -> s) ns - -makeName :: String -> Int -> Ident -makeName prefix i = Ident (prefix ++ "_" ++ show i) - - --- | Collects supercombinators by lifting appropriate let expressions +-- | Collects supercombinators by lifting non-constant let expressions collectScs :: Program -> Program collectScs (Program scs) = Program $ concatMap collectFromRhs scs where collectFromRhs (Bind name parms rhs) = - let (rhs_scs, rhs') = collectScsExp rhs - in Bind name parms rhs' : rhs_scs + let (rhs_scs, rhs') = collectScsExp rhs + in Bind name parms rhs' : rhs_scs collectScsExp :: Exp -> ([Bind], Exp) collectScsExp = \case - EId n -> ([], EId n) - EInt i -> ([], EInt i) + EId n -> ([], EId n) + EInt i -> ([], EInt i) - EApp e1 e2 -> (scs1 ++ scs2, EApp e1' e2') - where - (scs1, e1') = collectScsExp e1 - (scs2, e2') = collectScsExp e2 + EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2') + where + (scs1, e1') = collectScsExp e1 + (scs2, e2') = collectScsExp e2 - EAdd e1 e2 -> (scs1 ++ scs2, EAdd e1' e2') - where - (scs1, e1') = collectScsExp e1 - (scs2, e2') = collectScsExp e2 + EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2') + where + (scs1, e1') = collectScsExp e1 + (scs2, e2') = collectScsExp e2 - EAbs x e -> (scs, EAbs x e') - where - (scs, e') = collectScsExp e + EAbs t par e -> (scs, EAbs t par e') + where + (scs, e') = collectScsExp e - -- Collect supercombinators from binds, the rhss, and the expression. - -- - -- > f = let - -- > sc = rhs - -- > sc1 = rhs1 - -- > ... - -- > in e - -- - ELet binds e -> (binds_scs ++ rhss_scs ++ e_scs, mkEAbs non_scs' e') - where - binds_scs = [ let (rhs', parms1) = flattenLambdas rhs in - Bind n (parms ++ parms1) rhs' - | Bind n parms rhs <- scs' - ] - (rhss_scs, binds') = mapAccumL collectScsRhs [] binds - (e_scs, e') = collectScsExp e + -- Collect supercombinators from bind, the rhss, and the expression. + -- + -- > f = let sc x y = rhs in e + -- + ELet (Bind name parms rhs) e -> if null parms + then ( rhs_scs ++ e_scs, ELet bind e') + else (bind : rhs_scs ++ e_scs, e') + where + bind = Bind name parms rhs' + (rhs_scs, rhs') = collectScsExp rhs + (e_scs, e') = collectScsExp e - (scs', non_scs') = partition (\(Bind _ pars _) -> notNull pars) binds' - - collectScsRhs acc (Bind n xs rhs) = (acc ++ rhs_scs, Bind n xs rhs') - where - (rhs_scs, rhs') = collectScsExp rhs -- @\x.\y.\z. e → (e, [x,y,z])@ -flattenLambdas :: Exp -> (Exp, [Ident]) -flattenLambdas e = go (e, []) +flattenLambdas :: Exp -> (Exp, [Id]) +flattenLambdas = go . (, []) where go (e, acc) = case e of - EAbs par e1 -> go (e1, snoc par acc) - _ -> (e, acc) - -mkEAbs :: [Bind] -> Exp -> Exp -mkEAbs [] e = e -mkEAbs bs e = ELet bs e + EAbs _ par e1 -> go (e1, snoc par acc) + _ -> (e, acc) diff --git a/src/LlvmIr.hs b/src/LlvmIr.hs new file mode 100644 index 0000000..d340ddc --- /dev/null +++ b/src/LlvmIr.hs @@ -0,0 +1,204 @@ +{-# LANGUAGE LambdaCase #-} + +module LlvmIr ( + LLVMType (..), + LLVMIr (..), + llvmIrToString, + LLVMValue (..), + LLVMComp (..), + Visibility (..), +) where + +import Data.List (intercalate) +import TypeCheckerIr + +-- | A datatype which represents some basic LLVM types +data LLVMType + = I1 + | I8 + | I32 + | I64 + | Ptr + | Ref LLVMType + | Function LLVMType [LLVMType] + | Array Int LLVMType + | CustomType Ident + +instance Show LLVMType where + show :: LLVMType -> String + show = \case + I1 -> "i1" + I8 -> "i8" + I32 -> "i32" + I64 -> "i64" + Ptr -> "ptr" + Ref ty -> show ty <> "*" + Function t xs -> show t <> " (" <> intercalate ", " (map show xs) <> ")*" + Array n ty -> concat ["[", show n, " x ", show ty, "]"] + CustomType (Ident ty) -> ty + +data LLVMComp + = LLEq + | LLNe + | LLUgt + | LLUge + | LLUlt + | LLUle + | LLSgt + | LLSge + | LLSlt + | LLSle +instance Show LLVMComp where + show :: LLVMComp -> String + show = \case + LLEq -> "eq" + LLNe -> "ne" + LLUgt -> "ugt" + LLUge -> "uge" + LLUlt -> "ult" + LLUle -> "ule" + LLSgt -> "sgt" + LLSge -> "sge" + LLSlt -> "slt" + LLSle -> "sle" + +data Visibility = Local | Global +instance Show Visibility where + show :: Visibility -> String + show Local = "%" + show Global = "@" + +-- | Represents a LLVM "value", as in an integer, a register variable, +-- or a string contstant +data LLVMValue + = VInteger Integer + | VIdent Ident LLVMType + | VConstant String + | VFunction Ident Visibility LLVMType + +instance Show LLVMValue where + show :: LLVMValue -> String + show v = case v of + VInteger i -> show i + VIdent (Ident n) _ -> "%" <> n + VFunction (Ident n) vis _ -> show vis <> n + VConstant s -> "c" <> show s + +type Params = [(Ident, LLVMType)] +type Args = [(LLVMType, LLVMValue)] + +-- | A datatype which represents different instructions in LLVM +data LLVMIr + = Define LLVMType Ident Params + | DefineEnd + | Declare LLVMType Ident Params + | SetVariable Ident LLVMIr + | Variable Ident + | Add LLVMType LLVMValue LLVMValue + | Sub LLVMType LLVMValue LLVMValue + | Div LLVMType LLVMValue LLVMValue + | Mul LLVMType LLVMValue LLVMValue + | Srem LLVMType LLVMValue LLVMValue + | Icmp LLVMComp LLVMType LLVMValue LLVMValue + | Br Ident + | BrCond LLVMValue Ident Ident + | Label Ident + | Call LLVMType Visibility Ident Args + | Alloca LLVMType + | Store LLVMType Ident LLVMType Ident + | Bitcast LLVMType Ident LLVMType + | Ret LLVMType LLVMValue + | Comment String + | UnsafeRaw String -- This should generally be avoided, and proper + -- instructions should be used in its place + deriving (Show) + +-- | Converts a list of LLVMIr instructions to a string +llvmIrToString :: [LLVMIr] -> String +llvmIrToString = go 0 + where + go :: Int -> [LLVMIr] -> String + go _ [] = mempty + go i (x : xs) = do + let (i', n) = case x of + Define{} -> (i + 1, 0) + DefineEnd -> (i - 1, 0) + _ -> (i, i) + insToString n x <> go i' xs + +-- | Converts a LLVM inststruction to a String, allowing for printing etc. +-- The integer represents the indentation +insToString :: Int -> LLVMIr -> String +insToString i l = + replicate i '\t' <> case l of + (Define t (Ident i) params) -> + concat + [ "define ", show t, " @", i + , "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params) + , ") {\n" + ] + DefineEnd -> "}\n" + (Declare _t (Ident _i) _params) -> undefined + (SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir] + (Add t v1 v2) -> + concat + [ "add ", show t, " ", show v1 + , ", ", show v2, "\n" + ] + (Sub t v1 v2) -> + concat + [ "sub ", show t, " ", show v1, ", " + , show v2, "\n" + ] + (Div t v1 v2) -> + concat + [ "sdiv ", show t, " ", show v1, ", " + , show v2, "\n" + ] + (Mul t v1 v2) -> + concat + [ "mul ", show t, " ", show v1 + , ", ", show v2, "\n" + ] + (Srem t v1 v2) -> + concat + [ "srem ", show t, " ", show v1, ", " + , show v2, "\n" + ] + (Call t vis (Ident i) arg) -> + concat + [ "call ", show t, " ", show vis, i, "(" + , intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg + , ")\n" + ] + (Alloca t) -> unwords ["alloca", show t, "\n"] + (Store t1 (Ident id1) t2 (Ident id2)) -> + concat + [ "store ", show t1, " %", id1 + , ", ", show t2 , " %", id2, "\n" + ] + (Bitcast t1 (Ident i) t2) -> + concat + [ "bitcast ", show t1, " %" + , i, " to ", show t2, "\n" + ] + (Icmp comp t v1 v2) -> + concat + [ "icmp ", show comp, " ", show t + , " ", show v1, ", ", show v2, "\n" + ] + (Ret t v) -> + concat + [ "ret ", show t, " " + , show v, "\n" + ] + (UnsafeRaw s) -> s + (Label (Ident s)) -> "\nlabel_" <> s <> ":\n" + (Br (Ident s)) -> "br label %label_" <> s <> "\n" + (BrCond val (Ident s1) (Ident s2)) -> + concat + [ "br i1 ", show val, ", ", "label %" + , "label_", s1, ", ", "label %", "label_", s2, "\n" + ] + (Comment s) -> "; " <> s <> "\n" + (Variable (Ident id)) -> "%" <> id diff --git a/src/Main.hs b/src/Main.hs index 41379fc..1831428 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -1,52 +1,97 @@ {-# LANGUAGE LambdaCase #-} + module Main where -import Grammar.ErrM (Err) -import Grammar.Par (myLexer, pProgram) -import Grammar.Print (printTree) -import Interpreter (interpret) -import LambdaLifter (abstract, freeVars, lambdaLift) -import System.Environment (getArgs) -import System.Exit (exitFailure, exitSuccess) +import Compiler (compile) +import GHC.IO.Handle.Text (hPutStrLn) +import Grammar.ErrM (Err) +import Grammar.Par (myLexer, pProgram) +import Grammar.Print (printTree) + +-- import Interpreter (interpret) +import LambdaLifter (lambdaLift) +import Renamer (rename) +import System.Environment (getArgs) +import System.Exit (exitFailure, exitSuccess) +import System.IO (stderr) +import TypeChecker (typecheck) main :: IO () -main = getArgs >>= \case - [] -> print "Required file path missing" - (s:_) -> main' s +main = + getArgs >>= \case + [] -> print "Required file path missing" + (s : _) -> main' s main' :: String -> IO () main' s = do - file <- readFile s + file <- readFile s - putStrLn "\n-- parse" - parsed <- fromSyntaxErr . pProgram $ myLexer file - putStrLn $ printTree parsed + printToErr "-- Parse Tree -- " + parsed <- fromSyntaxErr . pProgram $ myLexer file + printToErr $ printTree parsed - putStrLn "\n-- Lambda Lifter" - let lifted = lambdaLift parsed - putStrLn $ printTree lifted + printToErr "\n-- Renamer --" + let renamed = rename parsed + printToErr $ printTree renamed - -- interpred <- fromInterpreterErr $ interpret lifted - -- putStrLn "\n-- interpret" - -- print interpred + printToErr "\n-- TypeChecker --" + typechecked <- fromTypeCheckerErr $ typecheck renamed + printToErr $ printTree typechecked - exitSuccess + printToErr "\n-- Lambda Lifter --" + let lifted = lambdaLift typechecked + printToErr $ printTree lifted + printToErr "\n -- Printing compiler output to stdout --" + compiled <- fromCompilerErr $ compile lifted + putStrLn compiled + writeFile "llvm.ll" compiled + + -- interpred <- fromInterpreterErr $ interpret lifted + -- putStrLn "\n-- interpret" + -- print interpred + + exitSuccess + +printToErr :: String -> IO () +printToErr = hPutStrLn stderr + +fromCompilerErr :: Err a -> IO a +fromCompilerErr = + either + ( \err -> do + putStrLn "\nCOMPILER ERROR" + putStrLn err + exitFailure + ) + pure fromSyntaxErr :: Err a -> IO a -fromSyntaxErr = either - (\err -> do - putStrLn "\nSYNTAX ERROR" - putStrLn err - exitFailure) - pure +fromSyntaxErr = + either + ( \err -> do + putStrLn "\nSYNTAX ERROR" + putStrLn err + exitFailure + ) + pure + +fromTypeCheckerErr :: Err a -> IO a +fromTypeCheckerErr = + either + ( \err -> do + putStrLn "\nTYPECHECKER ERROR" + putStrLn err + exitFailure + ) + pure fromInterpreterErr :: Err a -> IO a -fromInterpreterErr = either - (\err -> do - putStrLn "\nINTERPRETER ERROR" - putStrLn err - exitFailure) - pure - - +fromInterpreterErr = + either + ( \err -> do + putStrLn "\nINTERPRETER ERROR" + putStrLn err + exitFailure + ) + pure diff --git a/src/Renamer.hs b/src/Renamer.hs new file mode 100644 index 0000000..b284e92 --- /dev/null +++ b/src/Renamer.hs @@ -0,0 +1,84 @@ +{-# LANGUAGE LambdaCase #-} + +module Renamer (module Renamer) where + +import Auxiliary (mapAccumM) +import Control.Monad.State (MonadState, State, evalState, gets, + modify) +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Maybe (fromMaybe) +import Data.Tuple.Extra (dupe) +import Grammar.Abs + + +-- | Rename all variables and local binds +rename :: Program -> Program +rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0 + where + initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs + renameSc :: Names -> Bind -> Rn Bind + renameSc old_names (Bind name t _ parms rhs) = do + (new_names, parms') <- newNames old_names parms + rhs' <- snd <$> renameExp new_names rhs + pure $ Bind name t name parms' rhs' + + +-- | Rename monad. State holds the number of renamed names. +newtype Rn a = Rn { runRn :: State Int a } + deriving (Functor, Applicative, Monad, MonadState Int) + +-- | Maps old to new name +type Names = Map Ident Ident + +renameLocalBind :: Names -> Bind -> Rn (Names, Bind) +renameLocalBind old_names (Bind name t _ parms rhs) = do + (new_names, name') <- newName old_names name + (new_names', parms') <- newNames new_names parms + (new_names'', rhs') <- renameExp new_names' rhs + pure (new_names'', Bind name' t name' parms' rhs') + +renameExp :: Names -> Exp -> Rn (Names, Exp) +renameExp old_names = \case + EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names) + + EInt i1 -> pure (old_names, EInt i1) + + EApp e1 e2 -> do + (env1, e1') <- renameExp old_names e1 + (env2, e2') <- renameExp old_names e2 + pure (Map.union env1 env2, EApp e1' e2') + + EAdd e1 e2 -> do + (env1, e1') <- renameExp old_names e1 + (env2, e2') <- renameExp old_names e2 + pure (Map.union env1 env2, EAdd e1' e2') + + ELet b e -> do + (new_names, b) <- renameLocalBind old_names b + (new_names', e') <- renameExp new_names e + pure (new_names', ELet b e') + + EAbs par t e -> do + (new_names, par') <- newName old_names par + (new_names', e') <- renameExp new_names e + pure (new_names', EAbs par' t e') + + EAnn e t -> do + (new_names, e') <- renameExp old_names e + pure (new_names, EAnn e' t) + +-- | Create a new name and add it to name environment. +newName :: Names -> Ident -> Rn (Names, Ident) +newName env old_name = do + new_name <- makeName old_name + pure (Map.insert old_name new_name env, new_name) + +-- | Create multiple names and add them to the name environment +newNames :: Names -> [Ident] -> Rn (Names, [Ident]) +newNames = mapAccumM newName + +-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@. +makeName :: Ident -> Rn Ident +makeName (Ident prefix) = gets (\i -> Ident $ prefix ++ "_" ++ show i) <* modify succ + diff --git a/src/TypeChecker.hs b/src/TypeChecker.hs index e69de29..1e44888 100644 --- a/src/TypeChecker.hs +++ b/src/TypeChecker.hs @@ -0,0 +1,178 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedRecordDot #-} + +module TypeChecker (typecheck, partitionType) where + +import Auxiliary (maybeToRightM, snoc) +import Control.Monad.Except (throwError, unless) +import Data.Map (Map) +import qualified Data.Map as Map +import Grammar.Abs +import Grammar.ErrM (Err) +import Grammar.Print (Print (prt), concatD, doc, printTree, + render) +import Prelude hiding (exp, id) +import qualified TypeCheckerIr as T + +-- NOTE: this type checker is poorly tested + +-- TODO +-- Coercion +-- Type inference + +data Cxt = Cxt + { env :: Map Ident Type -- ^ Local scope signature + , sig :: Map Ident Type -- ^ Top-level signatures + } + +initCxt :: [Bind] -> Cxt +initCxt sc = Cxt { env = mempty + , sig = Map.fromList $ map (\(Bind n t _ _ _) -> (n, t)) sc + } + +typecheck :: Program -> Err T.Program +typecheck (Program sc) = T.Program <$> mapM (checkBind $ initCxt sc) sc + +-- | Check if infered rhs type matches type signature. +checkBind :: Cxt -> Bind -> Err T.Bind +checkBind cxt b = + case expandLambdas b of + Bind name t _ parms rhs -> do + (rhs', t_rhs) <- infer cxt rhs + unless (typeEq t_rhs t) . throwError $ typeErr name t t_rhs + pure $ T.Bind (name, t) (zip parms ts_parms) rhs' + where + ts_parms = fst $ partitionType (length parms) t + +-- | @ f x y = rhs ⇒ f = \x.\y. rhs @ +expandLambdas :: Bind -> Bind +expandLambdas (Bind name t _ parms rhs) = Bind name t name [] rhs' + where + rhs' = foldr ($) rhs $ zipWith EAbs parms ts_parms + ts_parms = fst $ partitionType (length parms) t + +-- | Infer type of expression. +infer :: Cxt -> Exp -> Err (T.Exp, Type) +infer cxt = \case + EId x -> + case lookupEnv x cxt of + Nothing -> + case lookupSig x cxt of + Nothing -> throwError ("Unbound variable:" ++ printTree x) + Just t -> pure (T.EId (x, t), t) + Just t -> pure (T.EId (x, t), t) + + EInt i -> pure (T.EInt i, T.TInt) + + EApp e e1 -> do + (e', t) <- infer cxt e + case t of + TFun t1 t2 -> do + e1' <- check cxt e1 t1 + pure (T.EApp t2 e' e1', t2) + _ -> do + throwError ("Not a function: " ++ show e) + + EAdd e e1 -> do + e' <- check cxt e T.TInt + e1' <- check cxt e1 T.TInt + pure (T.EAdd T.TInt e' e1', T.TInt) + + EAbs x t e -> do + (e', t1) <- infer (insertEnv x t cxt) e + let t_abs = TFun t t1 + pure (T.EAbs t_abs (x, t) e', t_abs) + + ELet b e -> do + let cxt' = insertBind b cxt + b' <- checkBind cxt' b + (e', t) <- infer cxt' e + pure (T.ELet b' e', t) + + EAnn e t -> do + (e', t1) <- infer cxt e + unless (typeEq t t1) $ + throwError "Inferred type and type annotation doesn't match" + pure (e', t1) + +-- | Check infered type matches the supplied type. +check :: Cxt -> Exp -> Type -> Err T.Exp +check cxt exp typ = case exp of + + EId x -> do + t <- case lookupEnv x cxt of + Nothing -> maybeToRightM + ("Unbound variable:" ++ printTree x) + (lookupSig x cxt) + Just t -> pure t + unless (typeEq t typ) . throwError $ typeErr x typ t + pure $ T.EId (x, t) + + EInt i -> do + unless (typeEq typ TInt) $ throwError $ typeErr i TInt typ + pure $ T.EInt i + + EApp e e1 -> do + (e', t) <- infer cxt e + case t of + TFun t1 t2 -> do + e1' <- check cxt e1 t1 + pure $ T.EApp t2 e' e1' + _ -> throwError ("Not a function 2: " ++ printTree e) + + EAdd e e1 -> do + e' <- check cxt e T.TInt + e1' <- check cxt e1 T.TInt + pure $ T.EAdd T.TInt e' e1' + + EAbs x t e -> do + (e', t_e) <- infer (insertEnv x t cxt) e + let t1 = TFun t t_e + unless (typeEq t1 typ) $ throwError "Wrong lamda type!" + pure $ T.EAbs t1 (x, t) e' + + ELet b e -> do + let cxt' = insertBind b cxt + b' <- checkBind cxt' b + e' <- check cxt' e typ + pure $ T.ELet b' e' + + EAnn e t -> do + unless (typeEq t typ) $ + throwError "Inferred type and type annotation doesn't match" + check cxt e t + +-- | Check if types are equivalent. Doesn't handle coercion or polymorphism. +typeEq :: Type -> Type -> Bool +typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1 +typeEq t t1 = t == t1 + +-- | Partion type into types of parameters and return type. +partitionType :: Int -- Number of parameters to apply + -> Type + -> ([Type], Type) +partitionType = go [] + where + go acc 0 t = (acc, t) + go acc i t = case t of + TFun t1 t2 -> go (snoc t1 acc) (i - 1) t2 + _ -> error "Number of parameters and type doesn't match" + +insertBind :: Bind -> Cxt -> Cxt +insertBind (Bind n t _ _ _) = insertEnv n t + +lookupEnv :: Ident -> Cxt -> Maybe Type +lookupEnv x = Map.lookup x . env + +insertEnv :: Ident -> Type -> Cxt -> Cxt +insertEnv x t cxt = cxt { env = Map.insert x t cxt.env } + +lookupSig :: Ident -> Cxt -> Maybe Type +lookupSig x = Map.lookup x . sig + +typeErr :: Print a => a -> Type -> Type -> String +typeErr p expected actual = render $ concatD + [ doc $ showString "Wrong type:", prt 0 p , doc $ showString "\n" + , doc $ showString "Expected:" , prt 0 expected, doc $ showString "\n" + , doc $ showString "Actual: " , prt 0 actual + ] diff --git a/src/TypeCheckerIr.hs b/src/TypeCheckerIr.hs new file mode 100644 index 0000000..f6e3ec6 --- /dev/null +++ b/src/TypeCheckerIr.hs @@ -0,0 +1,100 @@ +{-# LANGUAGE LambdaCase #-} + +module TypeCheckerIr + ( module Grammar.Abs + , module TypeCheckerIr + ) where + +import Grammar.Abs (Ident (..), Type (..)) +import Grammar.Print +import Prelude +import qualified Prelude as C (Eq, Ord, Read, Show) + +newtype Program = Program [Bind] + deriving (C.Eq, C.Ord, C.Show, C.Read) + +data Exp + = EId Id + | EInt Integer + | ELet Bind Exp + | EApp Type Exp Exp + | EAdd Type Exp Exp + | EAbs Type Id Exp + deriving (C.Eq, C.Ord, C.Show, C.Read) + +type Id = (Ident, Type) + +data Bind = Bind Id [Id] Exp + deriving (C.Eq, C.Ord, C.Show, C.Read) + +instance Print Program where + prt i (Program sc) = prPrec i 0 $ prt 0 sc + +instance Print Bind where + prt i (Bind name@(n, _) parms rhs) = prPrec i 0 $ concatD + [ prtId 0 name + , doc $ showString ";" + , prt 0 n + , prtIdPs 0 parms + , doc $ showString "=" + , prt 0 rhs + ] + +instance Print [Bind] where + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] + +prtIdPs :: Int -> [Id] -> Doc +prtIdPs i = prPrec i 0 . concatD . map (prtIdP i) + +prtId :: Int -> Id -> Doc +prtId i (name, t) = prPrec i 0 $ concatD + [ prt 0 name + , doc $ showString ":" + , prt 0 t + ] + +prtIdP :: Int -> Id -> Doc +prtIdP i (name, t) = prPrec i 0 $ concatD + [ doc $ showString "(" + , prt 0 name + , doc $ showString ":" + , prt 0 t + , doc $ showString ")" + ] + + +instance Print Exp where + prt i = \case + EId n -> prPrec i 3 $ concatD [prtIdP 0 n] + EInt i1 -> prPrec i 3 $ concatD [prt 0 i1] + ELet bs e -> prPrec i 3 $ concatD + [ doc $ showString "let" + , prt 0 bs + , doc $ showString "in" + , prt 0 e + ] + EApp t e1 e2 -> prPrec i 2 $ concatD + [ doc $ showString "@" + , prt 0 t + , prt 2 e1 + , prt 3 e2 + ] + EAdd t e1 e2 -> prPrec i 1 $ concatD + [ doc $ showString "@" + , prt 0 t + , prt 1 e1 + , doc $ showString "+" + , prt 2 e2 + ] + EAbs t n e -> prPrec i 0 $ concatD + [ doc $ showString "@" + , prt 0 t + , doc $ showString "\\" + , prtIdP 0 n + , doc $ showString "." + , prt 0 e + ] + +