diff --git a/Grammar.cf b/Grammar.cf index 6870367..96554bb 100644 --- a/Grammar.cf +++ b/Grammar.cf @@ -3,7 +3,7 @@ Program. Program ::= [Def] ; DBind. Def ::= Bind ; DData. Def ::= Data ; -terminator Def ";" ; +separator Def ";" ; Bind. Bind ::= Ident ":" Type ";" Ident [Ident] "=" Exp ; @@ -31,16 +31,19 @@ IMatch. Match ::= Ident ; InitMatch. Match ::= Ident Match ; separator Match " " ; -TMono. Type1 ::= "_" Ident ; -TPol. Type1 ::= "'" Ident ; -TArr. Type ::= Type1 "->" Type ; +TMono. Type1 ::= "_" Ident ; +TPol. Type1 ::= "'" Ident ; +TConstr. Type1 ::= Ident "(" [Type] ")" ; +TArr. Type ::= Type1 "->" Type ; + separator Type " " ; +coercions Type 2 ; -- shift/reduce problem here -Data. Data ::= "data" Ident [Type] "where" ";" +Data. Data ::= "data" Type "where" ";" [Constructor]; -terminator Constructor ";" ; +separator Constructor "," ; Constructor. Constructor ::= Ident ":" Type ; @@ -48,10 +51,9 @@ Constructor. Constructor ::= Ident ":" Type ; -- token Poly upper (letter | digit | '_')* ; -- token Mono lower (letter | digit | '_')* ; -terminator Bind ";" ; +separator Bind ";" ; separator Ident " "; -coercions Type 1 ; coercions Exp 5 ; comment "--" ; diff --git a/language.cabal b/language.cabal index eb58aa0..3556367 100644 --- a/language.cabal +++ b/language.cabal @@ -34,9 +34,9 @@ executable language TypeChecker.TypeChecker TypeChecker.TypeCheckerIr Renamer.Renamer - LambdaLifter.LambdaLifter - Codegen.Codegen - Codegen.LlvmIr + -- LambdaLifter.LambdaLifter + -- Codegen.Codegen + -- Codegen.LlvmIr hs-source-dirs: src diff --git a/src/Codegen/Codegen.hs b/src/Codegen/Codegen.hs index 76a1f02..fe66b43 100644 --- a/src/Codegen/Codegen.hs +++ b/src/Codegen/Codegen.hs @@ -1,277 +1,277 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedStrings #-} +--{-# LANGUAGE LambdaCase #-} +--{-# LANGUAGE OverloadedStrings #-} -module Codegen.Codegen (compile) where +module Codegen.Codegen where -import Auxiliary (snoc) -import Codegen.LlvmIr (LLVMIr (..), LLVMType (..), - LLVMValue (..), Visibility (..), - llvmIrToString) -import Control.Monad.State (StateT, execStateT, gets, modify) -import Data.Map (Map) -import qualified Data.Map as Map -import Data.Tuple.Extra (dupe, first, second) -import Grammar.ErrM (Err) -import TypeChecker.TypeChecker -import TypeChecker.TypeCheckerIr +--import Auxiliary (snoc) +--import Codegen.LlvmIr (LLVMIr (..), LLVMType (..), +-- LLVMValue (..), Visibility (..), +-- llvmIrToString) +--import Control.Monad.State (StateT, execStateT, gets, modify) +--import Data.Map (Map) +--import qualified Data.Map as Map +--import Data.Tuple.Extra (dupe, first, second) +--import Grammar.ErrM (Err) +--import TypeChecker.TypeCheckerIr --- | The record used as the code generator state -data CodeGenerator = CodeGenerator - { instructions :: [LLVMIr] - , functions :: Map Id FunctionInfo - , variableCount :: Integer - } +---- | The record used as the code generator state +--data CodeGenerator = CodeGenerator +-- { instructions :: [LLVMIr] +-- , functions :: Map Id FunctionInfo +-- , variableCount :: Integer +-- } --- | A state type synonym -type CompilerState a = StateT CodeGenerator Err a +---- | A state type synonym +--type CompilerState a = StateT CodeGenerator Err a -data FunctionInfo = FunctionInfo - { numArgs :: Int - , arguments :: [Id] - } +--data FunctionInfo = FunctionInfo +-- { numArgs :: Int +-- , arguments :: [Id] +-- } --- | Adds a instruction to the CodeGenerator state -emit :: LLVMIr -> CompilerState () -emit l = modify $ \t -> t { instructions = snoc l $ instructions t } +---- | Adds a instruction to the CodeGenerator state +--emit :: LLVMIr -> CompilerState () +--emit l = modify $ \t -> t { instructions = snoc l $ instructions t } --- | Increases the variable counter in the CodeGenerator state -increaseVarCount :: CompilerState () -increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 } +---- | Increases the variable counter in the CodeGenerator state +--increaseVarCount :: CompilerState () +--increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 } --- | Returns the variable count from the CodeGenerator state -getVarCount :: CompilerState Integer -getVarCount = gets variableCount +---- | Returns the variable count from the CodeGenerator state +--getVarCount :: CompilerState Integer +--getVarCount = gets variableCount --- | Increases the variable count and returns it from the CodeGenerator state -getNewVar :: CompilerState Integer -getNewVar = increaseVarCount >> getVarCount +---- | Increases the variable count and returns it from the CodeGenerator state +--getNewVar :: CompilerState Integer +--getNewVar = increaseVarCount >> getVarCount --- | Produces a map of functions infos from a list of binds, --- which contains useful data for code generation. -getFunctions :: [Bind] -> Map Id FunctionInfo -getFunctions bs = Map.fromList $ map go bs - where - go (Bind id args _) = - (id, FunctionInfo { numArgs=length args, arguments=args }) +---- | Produces a map of functions infos from a list of binds, +---- which contains useful data for code generation. +--getFunctions :: [Bind] -> Map Id FunctionInfo +--getFunctions bs = Map.fromList $ map go bs +-- where +-- go (Bind id args _) = +-- (id, FunctionInfo { numArgs=length args, arguments=args }) -initCodeGenerator :: [Bind] -> CodeGenerator -initCodeGenerator scs = CodeGenerator { instructions = defaultStart - , functions = getFunctions scs - , variableCount = 0 - } +--initCodeGenerator :: [Bind] -> CodeGenerator +--initCodeGenerator scs = CodeGenerator { instructions = defaultStart +-- , functions = getFunctions scs +-- , variableCount = 0 +-- } --- | Compiles an AST and produces a LLVM Ir string. --- An easy way to actually "compile" this output is to --- Simply pipe it to lli -compile :: Program -> Err String -compile (Program scs) = do - let codegen = initCodeGenerator scs - llvmIrToString . instructions <$> execStateT (compileScs scs) codegen +---- | Compiles an AST and produces a LLVM Ir string. +---- An easy way to actually "compile" this output is to +---- Simply pipe it to lli +--compile :: Program -> Err String +--compile (Program scs) = do +-- let codegen = initCodeGenerator scs +-- llvmIrToString . instructions <$> execStateT (compileScs scs) codegen -compileScs :: [Bind] -> CompilerState () -compileScs [] = pure () -compileScs (Bind (name, t) args exp : xs) = do - emit $ UnsafeRaw "\n" - emit . Comment $ show name <> ": " <> show exp - let args' = map (second type2LlvmType) args - emit $ Define (type2LlvmType t_return) name args' - functionBody <- exprToValue exp - if name == "main" - then mapM_ emit $ mainContent functionBody - else emit $ Ret I64 functionBody - emit DefineEnd - modify $ \s -> s { variableCount = 0 } - compileScs xs - where - t_return = snd $ partitionType (length args) t +--compileScs :: [Bind] -> CompilerState () +--compileScs [] = pure () +--compileScs (Bind (name, t) args exp : xs) = do +-- emit $ UnsafeRaw "\n" +-- emit . Comment $ show name <> ": " <> show exp +-- let args' = map (second type2LlvmType) args +-- emit $ Define (type2LlvmType t_return) name args' +-- functionBody <- exprToValue exp +-- if name == "main" +-- then mapM_ emit $ mainContent functionBody +-- else emit $ Ret I64 functionBody +-- emit DefineEnd +-- modify $ \s -> s { variableCount = 0 } +-- compileScs xs +-- where +-- t_return = snd $ partitionType (length args) t -mainContent :: LLVMValue -> [LLVMIr] -mainContent var = - [ UnsafeRaw $ - "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n" - , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2)) - -- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2") - -- , Label (Ident "b_1") - -- , UnsafeRaw - -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n" - -- , Br (Ident "end") - -- , Label (Ident "b_2") - -- , UnsafeRaw - -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n" - -- , Br (Ident "end") - -- , Label (Ident "end") - Ret I64 (VInteger 0) - ] +--mainContent :: LLVMValue -> [LLVMIr] +--mainContent var = +-- [ UnsafeRaw $ +-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n" +-- , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2)) +-- -- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2") +-- -- , Label (Ident "b_1") +-- -- , UnsafeRaw +-- -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n" +-- -- , Br (Ident "end") +-- -- , Label (Ident "b_2") +-- -- , UnsafeRaw +-- -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n" +-- -- , Br (Ident "end") +-- -- , Label (Ident "end") +-- Ret I64 (VInteger 0) +-- ] -defaultStart :: [LLVMIr] -defaultStart = [ UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n" - , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" - ] +--defaultStart :: [LLVMIr] +--defaultStart = [ UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n" +-- , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" +-- ] -compileExp :: Exp -> CompilerState () -compileExp = \case - ELit _ (LInt i) -> emitInt i - EAdd t e1 e2 -> emitAdd t e1 e2 - EId (name, _) -> emitIdent name - EApp t e1 e2 -> emitApp t e1 e2 - EAbs t ti e -> emitAbs t ti e - ELet bind e -> emitLet bind e +--compileExp :: Exp -> CompilerState () +--compileExp = \case +-- ELit _ (LInt i) -> emitInt i +-- EAdd t e1 e2 -> emitAdd t e1 e2 +-- EId (name, _) -> emitIdent name +-- EApp t e1 e2 -> emitApp t e1 e2 +-- EAbs t ti e -> emitAbs t ti e +-- ELet bind e -> emitLet bind e ---- aux functions --- -emitAbs :: Type -> Id -> Exp -> CompilerState () -emitAbs _t tid e = emit . Comment $ "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e +----- aux functions --- +--emitAbs :: Type -> Id -> Exp -> CompilerState () +--emitAbs _t tid e = emit . Comment $ "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e -emitLet :: Bind -> Exp -> CompilerState () -emitLet b e = emit . Comment $ concat [ "ELet (" - , show b - , " = " - , show e - , ") is not implemented!" - ] +--emitLet :: Bind -> Exp -> CompilerState () +--emitLet b e = emit . Comment $ concat [ "ELet (" +-- , show b +-- , " = " +-- , show e +-- , ") is not implemented!" +-- ] -emitApp :: Type -> Exp -> Exp -> CompilerState () -emitApp t e1 e2 = appEmitter t e1 e2 [] - where - appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState () - appEmitter t e1 e2 stack = do - let newStack = e2 : stack - case e1 of - EApp _ e1' e2' -> appEmitter t e1' e2' newStack - EId id@(name, _) -> do - args <- traverse exprToValue newStack - vs <- getNewVar - funcs <- gets functions - let visibility = maybe Local (const Global) $ Map.lookup id funcs - args' = map (first valueGetType . dupe) args - call = Call (type2LlvmType t) visibility name args' - emit $ SetVariable (Ident $ show vs) call - x -> do - emit . Comment $ "The unspeakable happened: " - emit . Comment $ show x +--emitApp :: Type -> Exp -> Exp -> CompilerState () +--emitApp t e1 e2 = appEmitter t e1 e2 [] +-- where +-- appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState () +-- appEmitter t e1 e2 stack = do +-- let newStack = e2 : stack +-- case e1 of +-- EApp _ e1' e2' -> appEmitter t e1' e2' newStack +-- EId id@(name, _) -> do +-- args <- traverse exprToValue newStack +-- vs <- getNewVar +-- funcs <- gets functions +-- let visibility = maybe Local (const Global) $ Map.lookup id funcs +-- args' = map (first valueGetType . dupe) args +-- call = Call (type2LlvmType t) visibility name args' +-- emit $ SetVariable (Ident $ show vs) call +-- x -> do +-- emit . Comment $ "The unspeakable happened: " +-- emit . Comment $ show x -emitIdent :: Ident -> CompilerState () -emitIdent id = do - -- !!this should never happen!! - emit $ Comment "This should not have happened!" - emit $ Variable id - emit $ UnsafeRaw "\n" +--emitIdent :: Ident -> CompilerState () +--emitIdent id = do +-- -- !!this should never happen!! +-- emit $ Comment "This should not have happened!" +-- emit $ Variable id +-- emit $ UnsafeRaw "\n" -emitInt :: Integer -> CompilerState () -emitInt i = do - -- !!this should never happen!! - varCount <- getNewVar - emit $ Comment "This should not have happened!" - emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0)) +--emitInt :: Integer -> CompilerState () +--emitInt i = do +-- -- !!this should never happen!! +-- varCount <- getNewVar +-- emit $ Comment "This should not have happened!" +-- emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0)) -emitAdd :: Type -> Exp -> Exp -> CompilerState () -emitAdd t e1 e2 = do - v1 <- exprToValue e1 - v2 <- exprToValue e2 - v <- getNewVar - emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2) +--emitAdd :: Type -> Exp -> Exp -> CompilerState () +--emitAdd t e1 e2 = do +-- v1 <- exprToValue e1 +-- v2 <- exprToValue e2 +-- v <- getNewVar +-- emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2) --- emitMul :: Exp -> Exp -> CompilerState () --- emitMul e1 e2 = do --- (v1,v2) <- binExprToValues e1 e2 --- increaseVarCount --- v <- gets variableCount --- emit $ SetVariable $ Ident $ show v --- emit $ Mul I64 v1 v2 +---- emitMul :: Exp -> Exp -> CompilerState () +---- emitMul e1 e2 = do +---- (v1,v2) <- binExprToValues e1 e2 +---- increaseVarCount +---- v <- gets variableCount +---- emit $ SetVariable $ Ident $ show v +---- emit $ Mul I64 v1 v2 --- emitMod :: Exp -> Exp -> CompilerState () --- emitMod e1 e2 = do --- -- `let m a b = rem (abs $ b + a) b` --- (v1,v2) <- binExprToValues e1 e2 --- increaseVarCount --- vadd <- gets variableCount --- emit $ SetVariable $ Ident $ show vadd --- emit $ Add I64 v1 v2 --- --- increaseVarCount --- vabs <- gets variableCount --- emit $ SetVariable $ Ident $ show vabs --- emit $ Call I64 (Ident "llvm.abs.i64") --- [ (I64, VIdent (Ident $ show vadd)) --- , (I1, VInteger 1) --- ] --- increaseVarCount --- v <- gets variableCount --- emit $ SetVariable $ Ident $ show v --- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2 +---- emitMod :: Exp -> Exp -> CompilerState () +---- emitMod e1 e2 = do +---- -- `let m a b = rem (abs $ b + a) b` +---- (v1,v2) <- binExprToValues e1 e2 +---- increaseVarCount +---- vadd <- gets variableCount +---- emit $ SetVariable $ Ident $ show vadd +---- emit $ Add I64 v1 v2 +---- +---- increaseVarCount +---- vabs <- gets variableCount +---- emit $ SetVariable $ Ident $ show vabs +---- emit $ Call I64 (Ident "llvm.abs.i64") +---- [ (I64, VIdent (Ident $ show vadd)) +---- , (I1, VInteger 1) +---- ] +---- increaseVarCount +---- v <- gets variableCount +---- emit $ SetVariable $ Ident $ show v +---- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2 --- emitDiv :: Exp -> Exp -> CompilerState () --- emitDiv e1 e2 = do --- (v1,v2) <- binExprToValues e1 e2 --- increaseVarCount --- v <- gets variableCount --- emit $ SetVariable $ Ident $ show v --- emit $ Div I64 v1 v2 +---- emitDiv :: Exp -> Exp -> CompilerState () +---- emitDiv e1 e2 = do +---- (v1,v2) <- binExprToValues e1 e2 +---- increaseVarCount +---- v <- gets variableCount +---- emit $ SetVariable $ Ident $ show v +---- emit $ Div I64 v1 v2 --- emitSub :: Exp -> Exp -> CompilerState () --- emitSub e1 e2 = do --- (v1,v2) <- binExprToValues e1 e2 --- increaseVarCount --- v <- gets variableCount --- emit $ SetVariable $ Ident $ show v --- emit $ Sub I64 v1 v2 +---- emitSub :: Exp -> Exp -> CompilerState () +---- emitSub e1 e2 = do +---- (v1,v2) <- binExprToValues e1 e2 +---- increaseVarCount +---- v <- gets variableCount +---- emit $ SetVariable $ Ident $ show v +---- emit $ Sub I64 v1 v2 -exprToValue :: Exp -> CompilerState LLVMValue -exprToValue = \case - ELit _ (LInt i) -> pure $ VInteger i +--exprToValue :: Exp -> CompilerState LLVMValue +--exprToValue = \case +-- ELit _ (LInt i) -> pure $ VInteger i - EId id@(name, t) -> do - funcs <- gets functions - case Map.lookup id funcs of - Just fi -> do - if numArgs fi == 0 - then do - vc <- getNewVar - emit $ SetVariable (Ident $ show vc) - (Call (type2LlvmType t) Global name []) - pure $ VIdent (Ident $ show vc) (type2LlvmType t) - else pure $ VFunction name Global (type2LlvmType t) - Nothing -> pure $ VIdent name (type2LlvmType t) +-- EId id@(name, t) -> do +-- funcs <- gets functions +-- case Map.lookup id funcs of +-- Just fi -> do +-- if numArgs fi == 0 +-- then do +-- vc <- getNewVar +-- emit $ SetVariable (Ident $ show vc) +-- (Call (type2LlvmType t) Global name []) +-- pure $ VIdent (Ident $ show vc) (type2LlvmType t) +-- else pure $ VFunction name Global (type2LlvmType t) +-- Nothing -> pure $ VIdent name (type2LlvmType t) - e -> do - compileExp e - v <- getVarCount - pure $ VIdent (Ident $ show v) (getType e) +-- e -> do +-- compileExp e +-- v <- getVarCount +-- pure $ VIdent (Ident $ show v) (getType e) -type2LlvmType :: Type -> LLVMType -type2LlvmType = \case - (TMono "Int") -> I64 - TArr t xs -> do - let (t', xs') = function2LLVMType xs [type2LlvmType t] - Function t' xs' - t -> I64 --CustomType $ Ident ("\"" ++ show t ++ "\"") - where - function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType]) - function2LLVMType (TArr t xs) s = function2LLVMType xs (type2LlvmType t : s) - function2LLVMType x s = (type2LlvmType x, s) +--type2LlvmType :: Type -> LLVMType +--type2LlvmType = \case +-- (TMono "Int") -> I64 +-- TArr t xs -> do +-- let (t', xs') = function2LLVMType xs [type2LlvmType t] +-- Function t' xs' +-- -- This part will not work as we don't have a monomorphization step yet +-- t -> CustomType $ Ident ("\"" ++ show t ++ "\"") +-- where +-- function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType]) +-- function2LLVMType (TArr t xs) s = function2LLVMType xs (type2LlvmType t : s) +-- function2LLVMType x s = (type2LlvmType x, s) -getType :: Exp -> LLVMType -getType (ELit _ (LInt _)) = I64 -getType (EAdd t _ _) = type2LlvmType t -getType (EId (_, t)) = type2LlvmType t -getType (EApp t _ _) = type2LlvmType t -getType (EAbs t _ _) = type2LlvmType t -getType (ELet _ e) = getType e +--getType :: Exp -> LLVMType +--getType (ELit _ (LInt _)) = I64 +--getType (EAdd t _ _) = type2LlvmType t +--getType (EId (_, t)) = type2LlvmType t +--getType (EApp t _ _) = type2LlvmType t +--getType (EAbs t _ _) = type2LlvmType t +--getType (ELet _ e) = getType e -valueGetType :: LLVMValue -> LLVMType -valueGetType (VInteger _) = I64 -valueGetType (VIdent _ t) = t -valueGetType (VConstant s) = Array (length s) I8 -valueGetType (VFunction _ _ t) = t +--valueGetType :: LLVMValue -> LLVMType +--valueGetType (VInteger _) = I64 +--valueGetType (VIdent _ t) = t +--valueGetType (VConstant s) = Array (length s) I8 +--valueGetType (VFunction _ _ t) = t --- | Partion type into types of parameters and return type. -partitionType :: Int -- Number of parameters to apply - -> Type - -> ([Type], Type) -partitionType = go [] - where - go acc 0 t = (acc, t) - go acc i t = case t of - TArr t1 t2 -> go (snoc t1 acc) (i - 1) t2 - _ -> error "Number of parameters and type doesn't match" +---- | Partion type into types of parameters and return type. +--partitionType :: Int -- Number of parameters to apply +-- -> Type +-- -> ([Type], Type) +--partitionType = go [] +-- where +-- go acc 0 t = (acc, t) +-- go acc i t = case t of +-- TArr t1 t2 -> go (snoc t1 acc) (i - 1) t2 +-- _ -> error "Number of parameters and type doesn't match" diff --git a/src/LambdaLifter/LambdaLifter.hs b/src/LambdaLifter/LambdaLifter.hs index a617159..271cc70 100644 --- a/src/LambdaLifter/LambdaLifter.hs +++ b/src/LambdaLifter/LambdaLifter.hs @@ -1,192 +1,192 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedStrings #-} +--{-# LANGUAGE LambdaCase #-} +--{-# LANGUAGE OverloadedStrings #-} -module LambdaLifter.LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where +module LambdaLifter.LambdaLifter where -import Auxiliary (snoc) -import Control.Applicative (Applicative (liftA2)) -import Control.Monad.State (MonadState (get, put), State, - evalState) -import Data.Set (Set) -import qualified Data.Set as Set -import Prelude hiding (exp) -import Renamer.Renamer -import TypeChecker.TypeCheckerIr +--import Auxiliary (snoc) +--import Control.Applicative (Applicative (liftA2)) +--import Control.Monad.State (MonadState (get, put), State, +-- evalState) +--import Data.Set (Set) +--import qualified Data.Set as Set +--import Prelude hiding (exp) +--import Renamer.Renamer +--import TypeChecker.TypeCheckerIr --- | Lift lambdas and let expression into supercombinators. --- Three phases: --- @freeVars@ annotatss all the free variables. --- @abstract@ converts lambdas into let expressions. --- @collectScs@ moves every non-constant let expression to a top-level function. -lambdaLift :: Program -> Program -lambdaLift = collectScs . abstract . freeVars +---- | Lift lambdas and let expression into supercombinators. +---- Three phases: +---- @freeVars@ annotatss all the free variables. +---- @abstract@ converts lambdas into let expressions. +---- @collectScs@ moves every non-constant let expression to a top-level function. +--lambdaLift :: Program -> Program +--lambdaLift = collectScs . abstract . freeVars --- | Annotate free variables -freeVars :: Program -> AnnProgram -freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e) - | Bind n xs e <- ds - ] +---- | Annotate free variables +--freeVars :: Program -> AnnProgram +--freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e) +-- | Bind n xs e <- ds +-- ] -freeVarsExp :: Set Id -> Exp -> AnnExp -freeVarsExp localVars = \case - EId n | Set.member n localVars -> (Set.singleton n, AId n) - | otherwise -> (mempty, AId n) +--freeVarsExp :: Set Id -> Exp -> AnnExp +--freeVarsExp localVars = \case +-- EId n | Set.member n localVars -> (Set.singleton n, AId n) +-- | otherwise -> (mempty, AId n) - ELit _ (LInt i) -> (mempty, AInt i) +-- ELit _ (LInt i) -> (mempty, AInt i) - EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2') - where - e1' = freeVarsExp localVars e1 - e2' = freeVarsExp localVars e2 +-- EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2') +-- where +-- e1' = freeVarsExp localVars e1 +-- e2' = freeVarsExp localVars e2 - EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2') - where - e1' = freeVarsExp localVars e1 - e2' = freeVarsExp localVars e2 +-- EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2') +-- where +-- e1' = freeVarsExp localVars e1 +-- e2' = freeVarsExp localVars e2 - EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e') - where - e' = freeVarsExp (Set.insert par localVars) e +-- EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e') +-- where +-- e' = freeVarsExp (Set.insert par localVars) e - -- Sum free variables present in bind and the expression - ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e') - where - binders_frees = Set.delete name $ freeVarsOf rhs' - e_free = Set.delete name $ freeVarsOf e' +-- -- Sum free variables present in bind and the expression +-- ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e') +-- where +-- binders_frees = Set.delete name $ freeVarsOf rhs' +-- e_free = Set.delete name $ freeVarsOf e' - rhs' = freeVarsExp e_localVars rhs - new_bind = ABind name parms rhs' +-- rhs' = freeVarsExp e_localVars rhs +-- new_bind = ABind name parms rhs' - e' = freeVarsExp e_localVars e - e_localVars = Set.insert name localVars +-- e' = freeVarsExp e_localVars e +-- e_localVars = Set.insert name localVars -freeVarsOf :: AnnExp -> Set Id -freeVarsOf = fst +--freeVarsOf :: AnnExp -> Set Id +--freeVarsOf = fst --- AST annotated with free variables -type AnnProgram = [(Id, [Id], AnnExp)] +---- AST annotated with free variables +--type AnnProgram = [(Id, [Id], AnnExp)] -type AnnExp = (Set Id, AnnExp') +--type AnnExp = (Set Id, AnnExp') -data ABind = ABind Id [Id] AnnExp deriving Show +--data ABind = ABind Id [Id] AnnExp deriving Show -data AnnExp' = AId Id - | AInt Integer - | ALet ABind AnnExp - | AApp Type AnnExp AnnExp - | AAdd Type AnnExp AnnExp - | AAbs Type Id AnnExp - deriving Show --- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@. --- Free variables are @v₁ v₂ .. vₙ@ are bound. -abstract :: AnnProgram -> Program -abstract prog = Program $ evalState (mapM go prog) 0 - where - go :: (Id, [Id], AnnExp) -> State Int Bind - go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs' - where - (rhs', parms1) = flattenLambdasAnn rhs +--data AnnExp' = AId Id +-- | AInt Integer +-- | ALet ABind AnnExp +-- | AApp Type AnnExp AnnExp +-- | AAdd Type AnnExp AnnExp +-- | AAbs Type Id AnnExp +-- deriving Show +---- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@. +---- Free variables are @v₁ v₂ .. vₙ@ are bound. +--abstract :: AnnProgram -> Program +--abstract prog = Program $ evalState (mapM go prog) 0 +-- where +-- go :: (Id, [Id], AnnExp) -> State Int Bind +-- go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs' +-- where +-- (rhs', parms1) = flattenLambdasAnn rhs --- | Flatten nested lambdas and collect the parameters --- @\x.\y.\z. ae → (ae, [x,y,z])@ -flattenLambdasAnn :: AnnExp -> (AnnExp, [Id]) -flattenLambdasAnn ae = go (ae, []) - where - go :: (AnnExp, [Id]) -> (AnnExp, [Id]) - go ((free, e), acc) = - case e of - AAbs _ par (free1, e1) -> - go ((Set.delete par free1, e1), snoc par acc) - _ -> ((free, e), acc) +---- | Flatten nested lambdas and collect the parameters +---- @\x.\y.\z. ae → (ae, [x,y,z])@ +--flattenLambdasAnn :: AnnExp -> (AnnExp, [Id]) +--flattenLambdasAnn ae = go (ae, []) +-- where +-- go :: (AnnExp, [Id]) -> (AnnExp, [Id]) +-- go ((free, e), acc) = +-- case e of +-- AAbs _ par (free1, e1) -> +-- go ((Set.delete par free1, e1), snoc par acc) +-- _ -> ((free, e), acc) -abstractExp :: AnnExp -> State Int Exp -abstractExp (free, exp) = case exp of - AId n -> pure $ EId n - AInt i -> pure $ ELit (TMono "Int") (LInt i) - AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2) - AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2) - ALet b e -> liftA2 ELet (go b) (abstractExp e) - where - go (ABind name parms rhs) = do - (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs - pure $ Bind name (parms ++ parms1) rhs' +--abstractExp :: AnnExp -> State Int Exp +--abstractExp (free, exp) = case exp of +-- AId n -> pure $ EId n +-- AInt i -> pure $ ELit (TMono "Int") (LInt i) +-- AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2) +-- AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2) +-- ALet b e -> liftA2 ELet (go b) (abstractExp e) +-- where +-- go (ABind name parms rhs) = do +-- (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs +-- pure $ Bind name (parms ++ parms1) rhs' - skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp - skipLambdas f (free, ae) = case ae of - AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1 - _ -> f (free, ae) +-- skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp +-- skipLambdas f (free, ae) = case ae of +-- AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1 +-- _ -> f (free, ae) - -- Lift lambda into let and bind free variables - AAbs t parm e -> do - i <- nextNumber - rhs <- abstractExp e +-- -- Lift lambda into let and bind free variables +-- AAbs t parm e -> do +-- i <- nextNumber +-- rhs <- abstractExp e - let sc_name = Ident ("sc_" ++ show i) - sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t) +-- let sc_name = Ident ("sc_" ++ show i) +-- sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t) - pure $ foldl (EApp $ TMono "Int") sc $ map EId freeList - where - freeList = Set.toList free - parms = snoc parm freeList +-- pure $ foldl (EApp $ TMono "Int") sc $ map EId freeList +-- where +-- freeList = Set.toList free +-- parms = snoc parm freeList -nextNumber :: State Int Int -nextNumber = do - i <- get - put $ succ i - pure i +--nextNumber :: State Int Int +--nextNumber = do +-- i <- get +-- put $ succ i +-- pure i --- | Collects supercombinators by lifting non-constant let expressions -collectScs :: Program -> Program -collectScs (Program scs) = Program $ concatMap collectFromRhs scs - where - collectFromRhs (Bind name parms rhs) = - let (rhs_scs, rhs') = collectScsExp rhs - in Bind name parms rhs' : rhs_scs +---- | Collects supercombinators by lifting non-constant let expressions +--collectScs :: Program -> Program +--collectScs (Program scs) = Program $ concatMap collectFromRhs scs +-- where +-- collectFromRhs (Bind name parms rhs) = +-- let (rhs_scs, rhs') = collectScsExp rhs +-- in Bind name parms rhs' : rhs_scs -collectScsExp :: Exp -> ([Bind], Exp) -collectScsExp = \case - EId n -> ([], EId n) - ELit _ (LInt i) -> ([], ELit (TMono "Int") (LInt i)) +--collectScsExp :: Exp -> ([Bind], Exp) +--collectScsExp = \case +-- EId n -> ([], EId n) +-- ELit _ (LInt i) -> ([], ELit (TMono "Int") (LInt i)) - EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2') - where - (scs1, e1') = collectScsExp e1 - (scs2, e2') = collectScsExp e2 +-- EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2') +-- where +-- (scs1, e1') = collectScsExp e1 +-- (scs2, e2') = collectScsExp e2 - EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2') - where - (scs1, e1') = collectScsExp e1 - (scs2, e2') = collectScsExp e2 +-- EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2') +-- where +-- (scs1, e1') = collectScsExp e1 +-- (scs2, e2') = collectScsExp e2 - EAbs t par e -> (scs, EAbs t par e') - where - (scs, e') = collectScsExp e +-- EAbs t par e -> (scs, EAbs t par e') +-- where +-- (scs, e') = collectScsExp e - -- Collect supercombinators from bind, the rhss, and the expression. - -- - -- > f = let sc x y = rhs in e - -- - ELet (Bind name parms rhs) e -> if null parms - then ( rhs_scs ++ e_scs, ELet bind e') - else (bind : rhs_scs ++ e_scs, e') - where - bind = Bind name parms rhs' - (rhs_scs, rhs') = collectScsExp rhs - (e_scs, e') = collectScsExp e +-- -- Collect supercombinators from bind, the rhss, and the expression. +-- -- +-- -- > f = let sc x y = rhs in e +-- -- +-- ELet (Bind name parms rhs) e -> if null parms +-- then ( rhs_scs ++ e_scs, ELet bind e') +-- else (bind : rhs_scs ++ e_scs, e') +-- where +-- bind = Bind name parms rhs' +-- (rhs_scs, rhs') = collectScsExp rhs +-- (e_scs, e') = collectScsExp e --- @\x.\y.\z. e → (e, [x,y,z])@ -flattenLambdas :: Exp -> (Exp, [Id]) -flattenLambdas = go . (, []) - where - go (e, acc) = case e of - EAbs _ par e1 -> go (e1, snoc par acc) - _ -> (e, acc) +---- @\x.\y.\z. e → (e, [x,y,z])@ +--flattenLambdas :: Exp -> (Exp, [Id]) +--flattenLambdas = go . (, []) +-- where +-- go (e, acc) = case e of +-- EAbs _ par e1 -> go (e1, snoc par acc) +-- _ -> (e, acc) diff --git a/src/Main.hs b/src/Main.hs index 3a7bde4..316c599 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -2,18 +2,18 @@ module Main where -import Codegen.Codegen (compile) -import GHC.IO.Handle.Text (hPutStrLn) -import Grammar.ErrM (Err) -import Grammar.Par (myLexer, pProgram) -import Grammar.Print (printTree) +-- import Codegen.Codegen (compile) +import GHC.IO.Handle.Text (hPutStrLn) +import Grammar.ErrM (Err) +import Grammar.Par (myLexer, pProgram) +import Grammar.Print (printTree) -import LambdaLifter.LambdaLifter (lambdaLift) -import Renamer.Renamer (rename) -import System.Environment (getArgs) -import System.Exit (exitFailure, exitSuccess) -import System.IO (stderr) -import TypeChecker.TypeChecker (typecheck) +-- import LambdaLifter.LambdaLifter (lambdaLift) +import Renamer.Renamer (rename) +import System.Environment (getArgs) +import System.Exit (exitFailure, exitSuccess) +import System.IO (stderr) +import TypeChecker.TypeChecker (typecheck) main :: IO () main = @@ -37,14 +37,14 @@ main' s = do typechecked <- fromTypeCheckerErr $ typecheck renamed printToErr $ printTree typechecked - printToErr "\n-- Lambda Lifter --" - let lifted = lambdaLift typechecked - printToErr $ printTree lifted + -- printToErr "\n-- Lambda Lifter --" + -- let lifted = lambdaLift typechecked + -- printToErr $ printTree lifted - printToErr "\n -- Printing compiler output to stdout --" - compiled <- fromCompilerErr $ compile lifted - putStrLn compiled - writeFile "llvm.ll" compiled + -- printToErr "\n -- Printing compiler output to stdout --" + -- compiled <- fromCompilerErr $ compile lifted + -- putStrLn compiled + -- writeFile "llvm.ll" compiled exitSuccess diff --git a/src/TypeChecker/TypeChecker.hs b/src/TypeChecker/TypeChecker.hs index 0d9ace9..d09a002 100644 --- a/src/TypeChecker/TypeChecker.hs +++ b/src/TypeChecker/TypeChecker.hs @@ -3,6 +3,7 @@ {-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} {-# HLINT ignore "Use traverse_" #-} {-# OPTIONS_GHC -Wno-overlapping-patterns #-} +{-# HLINT ignore "Use zipWithM" #-} module TypeChecker.TypeChecker where @@ -16,6 +17,7 @@ import qualified Data.Map as M import Data.Set (Set) import qualified Data.Set as S +import Data.Foldable (traverse_) import Grammar.Abs import Grammar.Print (printTree) import qualified TypeChecker.TypeCheckerIr as T @@ -24,10 +26,12 @@ import qualified TypeChecker.TypeCheckerIr as T data Poly = Forall [Ident] Type deriving Show -newtype Ctx = Ctx { vars :: Map Ident Poly } +newtype Ctx = Ctx { vars :: Map Ident Poly + } -data Env = Env { count :: Int - , sigs :: Map Ident Type +data Env = Env { count :: Int + , sigs :: Map Ident Type + , dtypes :: Map Ident Type } type Error = String @@ -36,7 +40,7 @@ type Subst = Map Ident Type type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity)) initCtx = Ctx mempty -initEnv = Env 0 mempty +initEnv = Env 0 mempty mempty runPretty :: Exp -> Either Error String runPretty = fmap (printTree . fst). run . inferExp @@ -50,21 +54,44 @@ runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e typecheck :: Program -> Either Error T.Program typecheck = run . checkPrg +checkData :: Data -> Infer () +checkData d = case d of + (Data typ@(TConstr name _) constrs) -> do + traverse_ (\(Constructor name' t') + -> if typ == retType t' + then insertConstr name' t' else + throwError $ + unwords + [ "return type of constructor:" + , printTree name + , "with type:" + , printTree (retType t') + , "does not match data: " + , printTree typ]) constrs + _ -> throwError "Data type incorrectly declared" + where + retType :: Type -> Type + retType (TArr _ t2) = retType t2 + retType a = a + checkPrg :: Program -> Infer T.Program checkPrg (Program bs) = do - let bs' = getBinds bs - traverse (\(Bind n t _ _ _) -> insertSig n t) bs' - bs' <- mapM checkBind bs' - return $ T.Program bs' + preRun bs + T.Program <$> checkDef bs where - getBinds :: [Def] -> [Bind] - getBinds = map toBind . filter isBind - isBind :: Def -> Bool - isBind (DBind _) = True - isBind _ = True - toBind :: Def -> Bind - toBind (DBind bind) = bind - toBind _ = error "Can't convert DData to Bind" + preRun :: [Def] -> Infer () + preRun [] = return () + preRun (x:xs) = case x of + DBind (Bind n t _ _ _ ) -> insertSig n t >> preRun xs + DData d@(Data _ _) -> checkData d >> preRun xs + + checkDef :: [Def] -> Infer [T.Def] + checkDef [] = return [] + checkDef (x:xs) = case x of + (DBind b) -> do + b' <- checkBind b + fmap (T.DBind b' :) (checkDef xs) + (DData d) -> fmap (T.DData d :) (checkDef xs) checkBind :: Bind -> Infer T.Bind checkBind (Bind n t _ args e) = do @@ -77,15 +104,18 @@ checkBind (Bind n t _ args e) = do makeLambda :: Exp -> [Ident] -> Exp makeLambda = foldl (flip EAbs) +-- | Check if two types are considered equal +-- For the purpose of the algorithm two polymorphic types are always considered equal typeEq :: Type -> Type -> Bool -typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r' -typeEq (TMono a) (TMono b) = a == b -typeEq (TPol _) (TPol _) = True -typeEq _ _ = False +typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r' +typeEq (TMono a) (TMono b) = a == b +typeEq (TConstr name a) (TConstr name' b) = name == name' && and (zipWith typeEq a b) +typeEq (TPol _) (TPol _) = True +typeEq _ _ = False inferExp :: Exp -> Infer (Type, T.Exp) inferExp e = do - (s, t, e') <- w e + (s, t, e') <- algoW e let subbed = apply s t return (subbed, replace subbed e') @@ -98,19 +128,26 @@ replace t = \case T.EAdd _ e1 e2 -> T.EAdd t e1 e2 T.ELet (T.Bind (n, _) args e1) e2 -> T.ELet (T.Bind (n, t) args e1) e2 -w :: Exp -> Infer (Subst, Type, T.Exp) -w = \case +algoW :: Exp -> Infer (Subst, Type, T.Exp) +algoW = \case EAnn e t -> do - (s1, t', e') <- w e + (s1, t', e') <- algoW e applySt s1 $ do s2 <- unify (apply s1 t) t' return (s2 `compose` s1, t, e') +-- | ------------------ +-- | Γ ⊢ e₀ : Int, ∅ + ELit (LInt n) -> return (nullSubst, TMono "Int", T.ELit (TMono "Int") (LInt n)) ELit a -> error $ "NOT IMPLEMENTED YET: ELit " ++ show a +-- | x : σ ∈ Γ   τ = inst(σ) +-- | ---------------------- +-- | Γ ⊢ x : τ, ∅ + EId i -> do var <- asks vars case M.lookup i var of @@ -118,42 +155,67 @@ w = \case Nothing -> do sig <- gets sigs case M.lookup i sig of - Nothing -> throwError $ "Unbound variable: " ++ show i Just t -> return (nullSubst, t, T.EId (i, t)) + Nothing -> do + constr <- gets dtypes + case M.lookup i constr of + Just t -> return (nullSubst, t, T.EId (i, t)) + Nothing -> throwError $ "Unbound variable: " ++ show i + +-- | τ = newvar Γ, x : τ ⊢ e : τ', S +-- | --------------------------------- +-- | Γ ⊢ w λx. e : Sτ → τ', S EAbs name e -> do fr <- fresh withBinding name (Forall [] fr) $ do - (s1, t', e') <- w e + (s1, t', e') <- algoW e let varType = apply s1 fr let newArr = TArr varType t' return (s1, newArr, T.EAbs newArr (name, varType) e') +-- | Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁ +-- | s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int) +-- | ------------------------------------------ +-- | Γ ⊢ e₀ + e₁ : Int, S₃S₂S₁S₀ +-- This might be wrong + EAdd e0 e1 -> do - (s1, t0, e0') <- w e0 + (s1, t0, e0') <- algoW e0 applySt s1 $ do - (s2, t1, e1') <- w e1 - applySt s2 $ do - s3 <- unify (apply s2 t0) (TMono "Int") - s4 <- unify (apply s3 t1) (TMono "Int") - return (s4 `compose` s3 `compose` s2 `compose` s1, TMono "Int", T.EAdd (TMono "Int") e0' e1') + (s2, t1, e1') <- algoW e1 + -- applySt s2 $ do + s3 <- unify (apply s2 t0) (TMono "Int") + s4 <- unify (apply s3 t1) (TMono "Int") + return (s4 `compose` s3 `compose` s2 `compose` s1, TMono "Int", T.EAdd (TMono "Int") e0' e1') + +-- | Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1 +-- | τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ') +-- | -------------------------------------- +-- | Γ ⊢ e₀ e₁ : S₂τ', S₂S₁S₀ EApp e0 e1 -> do fr <- fresh - (s0, t0, e0') <- w e0 + (s0, t0, e0') <- algoW e0 applySt s0 $ do - (s1, t1, e1') <- w e1 + (s1, t1, e1') <- algoW e1 -- applySt s1 $ do s2 <- unify (apply s1 t0) (TArr t1 fr) let t = apply s2 fr return (s2 `compose` s1 `compose` s0, t, T.EApp t e0' e1') +-- | Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁ +-- | ---------------------------------------------- +-- | Γ ⊢ let x = e₀ in e₁ : τ', S₁S₀ + +-- The bar over S₀ and Γ means "generalize" + ELet name e0 e1 -> do - (s1, t1, e0') <- w e0 + (s1, t1, e0') <- algoW e0 env <- asks vars let t' = generalize (apply s1 env) t1 withBinding name t' $ do - (s2, t2, e1') <- w e1 + (s2, t2, e1') <- algoW e1 return (s2 `compose` s1, t2, T.ELet (T.Bind (name,t2) [] e0') e1' ) ECase a b -> error $ "NOT IMPLEMENTED YET: ECase" ++ show a ++ " " ++ show b @@ -168,6 +230,12 @@ unify t0 t1 = case (t0, t1) of (TPol a, b) -> occurs a b (a, TPol b) -> occurs b a (TMono a, TMono b) -> if a == b then return M.empty else throwError "Types do not unify" + -- | TODO: Figure out a cleaner way to express the same thing + (TConstr name t, TConstr name' t') -> if name == name' && length t == length t' + then do + xs <- sequence $ zipWith unify t t' + return $ foldr compose nullSubst xs + else throwError $ unwords ["Type constructor:", printTree name, "(" ++ printTree t ++ ")", "does not match with:", printTree name', "(" ++ printTree t' ++ ")"] (a, b) -> throwError . unwords $ ["Type:", printTree a, "can't be unified with:", printTree b] -- | Check if a type is contained in another type. @@ -202,9 +270,11 @@ class FreeVars t where instance FreeVars Type where free :: Type -> Set Ident - free (TPol a) = S.singleton a - free (TMono _) = mempty - free (TArr a b) = free a `S.union` free b + free (TPol a) = S.singleton a + free (TMono _) = mempty + free (TArr a b) = free a `S.union` free b + -- | Not guaranteed to be correct + free (TConstr _ a) = foldl' (\acc x -> free x `S.union` acc) S.empty a apply :: Subst -> Type -> Type apply sub t = do case t of @@ -213,6 +283,7 @@ instance FreeVars Type where Nothing -> TPol a Just t -> t TArr a b -> TArr (apply sub a) (apply sub b) + TConstr name a -> TConstr name (map (apply sub) a) instance FreeVars Poly where free :: Poly -> Set Ident @@ -248,3 +319,7 @@ withBinding i p = local (\st -> st { vars = M.insert i p (vars st) }) -- | Insert a function signature into the environment insertSig :: Ident -> Type -> Infer () insertSig i t = modify (\st -> st { sigs = M.insert i t (sigs st) }) + +-- | Insert a constructor with its data type +insertConstr :: Ident -> Type -> Infer () +insertConstr i t = modify (\st -> st { dtypes = M.insert i t (dtypes st) }) diff --git a/src/TypeChecker/TypeCheckerIr.hs b/src/TypeChecker/TypeCheckerIr.hs index c85ebcc..ee02416 100644 --- a/src/TypeChecker/TypeCheckerIr.hs +++ b/src/TypeChecker/TypeCheckerIr.hs @@ -5,12 +5,12 @@ module TypeChecker.TypeCheckerIr , module TypeChecker.TypeCheckerIr ) where -import Grammar.Abs (Ident (..), Literal (..), Type (..)) +import Grammar.Abs (Data (..), Ident (..), Literal (..), Type (..)) import Grammar.Print import Prelude import qualified Prelude as C (Eq, Ord, Read, Show) -newtype Program = Program [Bind] +newtype Program = Program [Def] deriving (C.Eq, C.Ord, C.Show, C.Read) data Exp @@ -22,11 +22,18 @@ data Exp | EAbs Type Id Exp deriving (C.Eq, C.Ord, C.Read, C.Show) +data Def = DBind Bind | DData Data + deriving (C.Eq, C.Ord, C.Read, C.Show) + type Id = (Ident, Type) data Bind = Bind Id [Id] Exp deriving (C.Eq, C.Ord, C.Show, C.Read) +instance Print Def where + prt i (DBind bind) = prt i bind + prt i (DData d) = prt i d + instance Print Program where prt i (Program sc) = prPrec i 0 $ prt 0 sc @@ -75,7 +82,7 @@ instance Print Exp where , doc $ showString "in" , prt 0 e ] - EApp t e1 e2 -> prPrec i 2 $ concatD + EApp _ e1 e2 -> prPrec i 2 $ concatD [ prt 2 e1 , prt 3 e2 ] diff --git a/test_program b/test_program index 69a2c20..5c2a164 100644 --- a/test_program +++ b/test_program @@ -1,2 +1,14 @@ -main : _Int ; -main = 3 + 3 ; +data List ('a) where; + Nil : List ('a), + Cons : 'a -> List ('a) -> List ('a) ; + +main : List (_Int) ; +main = Cons 1 (Cons 0 Nil) ; + +data Bool () where; + True : Bool (), + False : Bool (); + +boolean : Bool (_Int); +boolean = True ; +