Type inference/checking on ADTs mostly complete(?). Still have to test

This commit is contained in:
sebastianselander 2023-02-27 17:22:42 +01:00
parent 2f45f39435
commit bbf6e159c7
8 changed files with 563 additions and 467 deletions

View file

@ -3,7 +3,7 @@ Program. Program ::= [Def] ;
DBind. Def ::= Bind ;
DData. Def ::= Data ;
terminator Def ";" ;
separator Def ";" ;
Bind. Bind ::= Ident ":" Type ";"
Ident [Ident] "=" Exp ;
@ -31,16 +31,19 @@ IMatch. Match ::= Ident ;
InitMatch. Match ::= Ident Match ;
separator Match " " ;
TMono. Type1 ::= "_" Ident ;
TPol. Type1 ::= "'" Ident ;
TArr. Type ::= Type1 "->" Type ;
TMono. Type1 ::= "_" Ident ;
TPol. Type1 ::= "'" Ident ;
TConstr. Type1 ::= Ident "(" [Type] ")" ;
TArr. Type ::= Type1 "->" Type ;
separator Type " " ;
coercions Type 2 ;
-- shift/reduce problem here
Data. Data ::= "data" Ident [Type] "where" ";"
Data. Data ::= "data" Type "where" ";"
[Constructor];
terminator Constructor ";" ;
separator Constructor "," ;
Constructor. Constructor ::= Ident ":" Type ;
@ -48,10 +51,9 @@ Constructor. Constructor ::= Ident ":" Type ;
-- token Poly upper (letter | digit | '_')* ;
-- token Mono lower (letter | digit | '_')* ;
terminator Bind ";" ;
separator Bind ";" ;
separator Ident " ";
coercions Type 1 ;
coercions Exp 5 ;
comment "--" ;

View file

@ -34,9 +34,9 @@ executable language
TypeChecker.TypeChecker
TypeChecker.TypeCheckerIr
Renamer.Renamer
LambdaLifter.LambdaLifter
Codegen.Codegen
Codegen.LlvmIr
-- LambdaLifter.LambdaLifter
-- Codegen.Codegen
-- Codegen.LlvmIr
hs-source-dirs: src

View file

@ -1,277 +1,277 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
--{-# LANGUAGE LambdaCase #-}
--{-# LANGUAGE OverloadedStrings #-}
module Codegen.Codegen (compile) where
module Codegen.Codegen where
import Auxiliary (snoc)
import Codegen.LlvmIr (LLVMIr (..), LLVMType (..),
LLVMValue (..), Visibility (..),
llvmIrToString)
import Control.Monad.State (StateT, execStateT, gets, modify)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Tuple.Extra (dupe, first, second)
import Grammar.ErrM (Err)
import TypeChecker.TypeChecker
import TypeChecker.TypeCheckerIr
--import Auxiliary (snoc)
--import Codegen.LlvmIr (LLVMIr (..), LLVMType (..),
-- LLVMValue (..), Visibility (..),
-- llvmIrToString)
--import Control.Monad.State (StateT, execStateT, gets, modify)
--import Data.Map (Map)
--import qualified Data.Map as Map
--import Data.Tuple.Extra (dupe, first, second)
--import Grammar.ErrM (Err)
--import TypeChecker.TypeCheckerIr
-- | The record used as the code generator state
data CodeGenerator = CodeGenerator
{ instructions :: [LLVMIr]
, functions :: Map Id FunctionInfo
, variableCount :: Integer
}
---- | The record used as the code generator state
--data CodeGenerator = CodeGenerator
-- { instructions :: [LLVMIr]
-- , functions :: Map Id FunctionInfo
-- , variableCount :: Integer
-- }
-- | A state type synonym
type CompilerState a = StateT CodeGenerator Err a
---- | A state type synonym
--type CompilerState a = StateT CodeGenerator Err a
data FunctionInfo = FunctionInfo
{ numArgs :: Int
, arguments :: [Id]
}
--data FunctionInfo = FunctionInfo
-- { numArgs :: Int
-- , arguments :: [Id]
-- }
-- | Adds a instruction to the CodeGenerator state
emit :: LLVMIr -> CompilerState ()
emit l = modify $ \t -> t { instructions = snoc l $ instructions t }
---- | Adds a instruction to the CodeGenerator state
--emit :: LLVMIr -> CompilerState ()
--emit l = modify $ \t -> t { instructions = snoc l $ instructions t }
-- | Increases the variable counter in the CodeGenerator state
increaseVarCount :: CompilerState ()
increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 }
---- | Increases the variable counter in the CodeGenerator state
--increaseVarCount :: CompilerState ()
--increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 }
-- | Returns the variable count from the CodeGenerator state
getVarCount :: CompilerState Integer
getVarCount = gets variableCount
---- | Returns the variable count from the CodeGenerator state
--getVarCount :: CompilerState Integer
--getVarCount = gets variableCount
-- | Increases the variable count and returns it from the CodeGenerator state
getNewVar :: CompilerState Integer
getNewVar = increaseVarCount >> getVarCount
---- | Increases the variable count and returns it from the CodeGenerator state
--getNewVar :: CompilerState Integer
--getNewVar = increaseVarCount >> getVarCount
-- | Produces a map of functions infos from a list of binds,
-- which contains useful data for code generation.
getFunctions :: [Bind] -> Map Id FunctionInfo
getFunctions bs = Map.fromList $ map go bs
where
go (Bind id args _) =
(id, FunctionInfo { numArgs=length args, arguments=args })
---- | Produces a map of functions infos from a list of binds,
---- which contains useful data for code generation.
--getFunctions :: [Bind] -> Map Id FunctionInfo
--getFunctions bs = Map.fromList $ map go bs
-- where
-- go (Bind id args _) =
-- (id, FunctionInfo { numArgs=length args, arguments=args })
initCodeGenerator :: [Bind] -> CodeGenerator
initCodeGenerator scs = CodeGenerator { instructions = defaultStart
, functions = getFunctions scs
, variableCount = 0
}
--initCodeGenerator :: [Bind] -> CodeGenerator
--initCodeGenerator scs = CodeGenerator { instructions = defaultStart
-- , functions = getFunctions scs
-- , variableCount = 0
-- }
-- | Compiles an AST and produces a LLVM Ir string.
-- An easy way to actually "compile" this output is to
-- Simply pipe it to lli
compile :: Program -> Err String
compile (Program scs) = do
let codegen = initCodeGenerator scs
llvmIrToString . instructions <$> execStateT (compileScs scs) codegen
---- | Compiles an AST and produces a LLVM Ir string.
---- An easy way to actually "compile" this output is to
---- Simply pipe it to lli
--compile :: Program -> Err String
--compile (Program scs) = do
-- let codegen = initCodeGenerator scs
-- llvmIrToString . instructions <$> execStateT (compileScs scs) codegen
compileScs :: [Bind] -> CompilerState ()
compileScs [] = pure ()
compileScs (Bind (name, t) args exp : xs) = do
emit $ UnsafeRaw "\n"
emit . Comment $ show name <> ": " <> show exp
let args' = map (second type2LlvmType) args
emit $ Define (type2LlvmType t_return) name args'
functionBody <- exprToValue exp
if name == "main"
then mapM_ emit $ mainContent functionBody
else emit $ Ret I64 functionBody
emit DefineEnd
modify $ \s -> s { variableCount = 0 }
compileScs xs
where
t_return = snd $ partitionType (length args) t
--compileScs :: [Bind] -> CompilerState ()
--compileScs [] = pure ()
--compileScs (Bind (name, t) args exp : xs) = do
-- emit $ UnsafeRaw "\n"
-- emit . Comment $ show name <> ": " <> show exp
-- let args' = map (second type2LlvmType) args
-- emit $ Define (type2LlvmType t_return) name args'
-- functionBody <- exprToValue exp
-- if name == "main"
-- then mapM_ emit $ mainContent functionBody
-- else emit $ Ret I64 functionBody
-- emit DefineEnd
-- modify $ \s -> s { variableCount = 0 }
-- compileScs xs
-- where
-- t_return = snd $ partitionType (length args) t
mainContent :: LLVMValue -> [LLVMIr]
mainContent var =
[ UnsafeRaw $
"call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n"
, -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
-- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2")
-- , Label (Ident "b_1")
-- , UnsafeRaw
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n"
-- , Br (Ident "end")
-- , Label (Ident "b_2")
-- , UnsafeRaw
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n"
-- , Br (Ident "end")
-- , Label (Ident "end")
Ret I64 (VInteger 0)
]
--mainContent :: LLVMValue -> [LLVMIr]
--mainContent var =
-- [ UnsafeRaw $
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n"
-- , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
-- -- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2")
-- -- , Label (Ident "b_1")
-- -- , UnsafeRaw
-- -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n"
-- -- , Br (Ident "end")
-- -- , Label (Ident "b_2")
-- -- , UnsafeRaw
-- -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n"
-- -- , Br (Ident "end")
-- -- , Label (Ident "end")
-- Ret I64 (VInteger 0)
-- ]
defaultStart :: [LLVMIr]
defaultStart = [ UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
]
--defaultStart :: [LLVMIr]
--defaultStart = [ UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
-- , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
-- ]
compileExp :: Exp -> CompilerState ()
compileExp = \case
ELit _ (LInt i) -> emitInt i
EAdd t e1 e2 -> emitAdd t e1 e2
EId (name, _) -> emitIdent name
EApp t e1 e2 -> emitApp t e1 e2
EAbs t ti e -> emitAbs t ti e
ELet bind e -> emitLet bind e
--compileExp :: Exp -> CompilerState ()
--compileExp = \case
-- ELit _ (LInt i) -> emitInt i
-- EAdd t e1 e2 -> emitAdd t e1 e2
-- EId (name, _) -> emitIdent name
-- EApp t e1 e2 -> emitApp t e1 e2
-- EAbs t ti e -> emitAbs t ti e
-- ELet bind e -> emitLet bind e
--- aux functions ---
emitAbs :: Type -> Id -> Exp -> CompilerState ()
emitAbs _t tid e = emit . Comment $ "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e
----- aux functions ---
--emitAbs :: Type -> Id -> Exp -> CompilerState ()
--emitAbs _t tid e = emit . Comment $ "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e
emitLet :: Bind -> Exp -> CompilerState ()
emitLet b e = emit . Comment $ concat [ "ELet ("
, show b
, " = "
, show e
, ") is not implemented!"
]
--emitLet :: Bind -> Exp -> CompilerState ()
--emitLet b e = emit . Comment $ concat [ "ELet ("
-- , show b
-- , " = "
-- , show e
-- , ") is not implemented!"
-- ]
emitApp :: Type -> Exp -> Exp -> CompilerState ()
emitApp t e1 e2 = appEmitter t e1 e2 []
where
appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState ()
appEmitter t e1 e2 stack = do
let newStack = e2 : stack
case e1 of
EApp _ e1' e2' -> appEmitter t e1' e2' newStack
EId id@(name, _) -> do
args <- traverse exprToValue newStack
vs <- getNewVar
funcs <- gets functions
let visibility = maybe Local (const Global) $ Map.lookup id funcs
args' = map (first valueGetType . dupe) args
call = Call (type2LlvmType t) visibility name args'
emit $ SetVariable (Ident $ show vs) call
x -> do
emit . Comment $ "The unspeakable happened: "
emit . Comment $ show x
--emitApp :: Type -> Exp -> Exp -> CompilerState ()
--emitApp t e1 e2 = appEmitter t e1 e2 []
-- where
-- appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState ()
-- appEmitter t e1 e2 stack = do
-- let newStack = e2 : stack
-- case e1 of
-- EApp _ e1' e2' -> appEmitter t e1' e2' newStack
-- EId id@(name, _) -> do
-- args <- traverse exprToValue newStack
-- vs <- getNewVar
-- funcs <- gets functions
-- let visibility = maybe Local (const Global) $ Map.lookup id funcs
-- args' = map (first valueGetType . dupe) args
-- call = Call (type2LlvmType t) visibility name args'
-- emit $ SetVariable (Ident $ show vs) call
-- x -> do
-- emit . Comment $ "The unspeakable happened: "
-- emit . Comment $ show x
emitIdent :: Ident -> CompilerState ()
emitIdent id = do
-- !!this should never happen!!
emit $ Comment "This should not have happened!"
emit $ Variable id
emit $ UnsafeRaw "\n"
--emitIdent :: Ident -> CompilerState ()
--emitIdent id = do
-- -- !!this should never happen!!
-- emit $ Comment "This should not have happened!"
-- emit $ Variable id
-- emit $ UnsafeRaw "\n"
emitInt :: Integer -> CompilerState ()
emitInt i = do
-- !!this should never happen!!
varCount <- getNewVar
emit $ Comment "This should not have happened!"
emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0))
--emitInt :: Integer -> CompilerState ()
--emitInt i = do
-- -- !!this should never happen!!
-- varCount <- getNewVar
-- emit $ Comment "This should not have happened!"
-- emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0))
emitAdd :: Type -> Exp -> Exp -> CompilerState ()
emitAdd t e1 e2 = do
v1 <- exprToValue e1
v2 <- exprToValue e2
v <- getNewVar
emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2)
--emitAdd :: Type -> Exp -> Exp -> CompilerState ()
--emitAdd t e1 e2 = do
-- v1 <- exprToValue e1
-- v2 <- exprToValue e2
-- v <- getNewVar
-- emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2)
-- emitMul :: Exp -> Exp -> CompilerState ()
-- emitMul e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Mul I64 v1 v2
---- emitMul :: Exp -> Exp -> CompilerState ()
---- emitMul e1 e2 = do
---- (v1,v2) <- binExprToValues e1 e2
---- increaseVarCount
---- v <- gets variableCount
---- emit $ SetVariable $ Ident $ show v
---- emit $ Mul I64 v1 v2
-- emitMod :: Exp -> Exp -> CompilerState ()
-- emitMod e1 e2 = do
-- -- `let m a b = rem (abs $ b + a) b`
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- vadd <- gets variableCount
-- emit $ SetVariable $ Ident $ show vadd
-- emit $ Add I64 v1 v2
--
-- increaseVarCount
-- vabs <- gets variableCount
-- emit $ SetVariable $ Ident $ show vabs
-- emit $ Call I64 (Ident "llvm.abs.i64")
-- [ (I64, VIdent (Ident $ show vadd))
-- , (I1, VInteger 1)
-- ]
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2
---- emitMod :: Exp -> Exp -> CompilerState ()
---- emitMod e1 e2 = do
---- -- `let m a b = rem (abs $ b + a) b`
---- (v1,v2) <- binExprToValues e1 e2
---- increaseVarCount
---- vadd <- gets variableCount
---- emit $ SetVariable $ Ident $ show vadd
---- emit $ Add I64 v1 v2
----
---- increaseVarCount
---- vabs <- gets variableCount
---- emit $ SetVariable $ Ident $ show vabs
---- emit $ Call I64 (Ident "llvm.abs.i64")
---- [ (I64, VIdent (Ident $ show vadd))
---- , (I1, VInteger 1)
---- ]
---- increaseVarCount
---- v <- gets variableCount
---- emit $ SetVariable $ Ident $ show v
---- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2
-- emitDiv :: Exp -> Exp -> CompilerState ()
-- emitDiv e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Div I64 v1 v2
---- emitDiv :: Exp -> Exp -> CompilerState ()
---- emitDiv e1 e2 = do
---- (v1,v2) <- binExprToValues e1 e2
---- increaseVarCount
---- v <- gets variableCount
---- emit $ SetVariable $ Ident $ show v
---- emit $ Div I64 v1 v2
-- emitSub :: Exp -> Exp -> CompilerState ()
-- emitSub e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Sub I64 v1 v2
---- emitSub :: Exp -> Exp -> CompilerState ()
---- emitSub e1 e2 = do
---- (v1,v2) <- binExprToValues e1 e2
---- increaseVarCount
---- v <- gets variableCount
---- emit $ SetVariable $ Ident $ show v
---- emit $ Sub I64 v1 v2
exprToValue :: Exp -> CompilerState LLVMValue
exprToValue = \case
ELit _ (LInt i) -> pure $ VInteger i
--exprToValue :: Exp -> CompilerState LLVMValue
--exprToValue = \case
-- ELit _ (LInt i) -> pure $ VInteger i
EId id@(name, t) -> do
funcs <- gets functions
case Map.lookup id funcs of
Just fi -> do
if numArgs fi == 0
then do
vc <- getNewVar
emit $ SetVariable (Ident $ show vc)
(Call (type2LlvmType t) Global name [])
pure $ VIdent (Ident $ show vc) (type2LlvmType t)
else pure $ VFunction name Global (type2LlvmType t)
Nothing -> pure $ VIdent name (type2LlvmType t)
-- EId id@(name, t) -> do
-- funcs <- gets functions
-- case Map.lookup id funcs of
-- Just fi -> do
-- if numArgs fi == 0
-- then do
-- vc <- getNewVar
-- emit $ SetVariable (Ident $ show vc)
-- (Call (type2LlvmType t) Global name [])
-- pure $ VIdent (Ident $ show vc) (type2LlvmType t)
-- else pure $ VFunction name Global (type2LlvmType t)
-- Nothing -> pure $ VIdent name (type2LlvmType t)
e -> do
compileExp e
v <- getVarCount
pure $ VIdent (Ident $ show v) (getType e)
-- e -> do
-- compileExp e
-- v <- getVarCount
-- pure $ VIdent (Ident $ show v) (getType e)
type2LlvmType :: Type -> LLVMType
type2LlvmType = \case
(TMono "Int") -> I64
TArr t xs -> do
let (t', xs') = function2LLVMType xs [type2LlvmType t]
Function t' xs'
t -> I64 --CustomType $ Ident ("\"" ++ show t ++ "\"")
where
function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
function2LLVMType (TArr t xs) s = function2LLVMType xs (type2LlvmType t : s)
function2LLVMType x s = (type2LlvmType x, s)
--type2LlvmType :: Type -> LLVMType
--type2LlvmType = \case
-- (TMono "Int") -> I64
-- TArr t xs -> do
-- let (t', xs') = function2LLVMType xs [type2LlvmType t]
-- Function t' xs'
-- -- This part will not work as we don't have a monomorphization step yet
-- t -> CustomType $ Ident ("\"" ++ show t ++ "\"")
-- where
-- function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
-- function2LLVMType (TArr t xs) s = function2LLVMType xs (type2LlvmType t : s)
-- function2LLVMType x s = (type2LlvmType x, s)
getType :: Exp -> LLVMType
getType (ELit _ (LInt _)) = I64
getType (EAdd t _ _) = type2LlvmType t
getType (EId (_, t)) = type2LlvmType t
getType (EApp t _ _) = type2LlvmType t
getType (EAbs t _ _) = type2LlvmType t
getType (ELet _ e) = getType e
--getType :: Exp -> LLVMType
--getType (ELit _ (LInt _)) = I64
--getType (EAdd t _ _) = type2LlvmType t
--getType (EId (_, t)) = type2LlvmType t
--getType (EApp t _ _) = type2LlvmType t
--getType (EAbs t _ _) = type2LlvmType t
--getType (ELet _ e) = getType e
valueGetType :: LLVMValue -> LLVMType
valueGetType (VInteger _) = I64
valueGetType (VIdent _ t) = t
valueGetType (VConstant s) = Array (length s) I8
valueGetType (VFunction _ _ t) = t
--valueGetType :: LLVMValue -> LLVMType
--valueGetType (VInteger _) = I64
--valueGetType (VIdent _ t) = t
--valueGetType (VConstant s) = Array (length s) I8
--valueGetType (VFunction _ _ t) = t
-- | Partion type into types of parameters and return type.
partitionType :: Int -- Number of parameters to apply
-> Type
-> ([Type], Type)
partitionType = go []
where
go acc 0 t = (acc, t)
go acc i t = case t of
TArr t1 t2 -> go (snoc t1 acc) (i - 1) t2
_ -> error "Number of parameters and type doesn't match"
---- | Partion type into types of parameters and return type.
--partitionType :: Int -- Number of parameters to apply
-- -> Type
-- -> ([Type], Type)
--partitionType = go []
-- where
-- go acc 0 t = (acc, t)
-- go acc i t = case t of
-- TArr t1 t2 -> go (snoc t1 acc) (i - 1) t2
-- _ -> error "Number of parameters and type doesn't match"

View file

@ -1,192 +1,192 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
--{-# LANGUAGE LambdaCase #-}
--{-# LANGUAGE OverloadedStrings #-}
module LambdaLifter.LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where
module LambdaLifter.LambdaLifter where
import Auxiliary (snoc)
import Control.Applicative (Applicative (liftA2))
import Control.Monad.State (MonadState (get, put), State,
evalState)
import Data.Set (Set)
import qualified Data.Set as Set
import Prelude hiding (exp)
import Renamer.Renamer
import TypeChecker.TypeCheckerIr
--import Auxiliary (snoc)
--import Control.Applicative (Applicative (liftA2))
--import Control.Monad.State (MonadState (get, put), State,
-- evalState)
--import Data.Set (Set)
--import qualified Data.Set as Set
--import Prelude hiding (exp)
--import Renamer.Renamer
--import TypeChecker.TypeCheckerIr
-- | Lift lambdas and let expression into supercombinators.
-- Three phases:
-- @freeVars@ annotatss all the free variables.
-- @abstract@ converts lambdas into let expressions.
-- @collectScs@ moves every non-constant let expression to a top-level function.
lambdaLift :: Program -> Program
lambdaLift = collectScs . abstract . freeVars
---- | Lift lambdas and let expression into supercombinators.
---- Three phases:
---- @freeVars@ annotatss all the free variables.
---- @abstract@ converts lambdas into let expressions.
---- @collectScs@ moves every non-constant let expression to a top-level function.
--lambdaLift :: Program -> Program
--lambdaLift = collectScs . abstract . freeVars
-- | Annotate free variables
freeVars :: Program -> AnnProgram
freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e)
| Bind n xs e <- ds
]
---- | Annotate free variables
--freeVars :: Program -> AnnProgram
--freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e)
-- | Bind n xs e <- ds
-- ]
freeVarsExp :: Set Id -> Exp -> AnnExp
freeVarsExp localVars = \case
EId n | Set.member n localVars -> (Set.singleton n, AId n)
| otherwise -> (mempty, AId n)
--freeVarsExp :: Set Id -> Exp -> AnnExp
--freeVarsExp localVars = \case
-- EId n | Set.member n localVars -> (Set.singleton n, AId n)
-- | otherwise -> (mempty, AId n)
ELit _ (LInt i) -> (mempty, AInt i)
-- ELit _ (LInt i) -> (mempty, AInt i)
EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2')
where
e1' = freeVarsExp localVars e1
e2' = freeVarsExp localVars e2
-- EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2')
-- where
-- e1' = freeVarsExp localVars e1
-- e2' = freeVarsExp localVars e2
EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2')
where
e1' = freeVarsExp localVars e1
e2' = freeVarsExp localVars e2
-- EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2')
-- where
-- e1' = freeVarsExp localVars e1
-- e2' = freeVarsExp localVars e2
EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e')
where
e' = freeVarsExp (Set.insert par localVars) e
-- EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e')
-- where
-- e' = freeVarsExp (Set.insert par localVars) e
-- Sum free variables present in bind and the expression
ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e')
where
binders_frees = Set.delete name $ freeVarsOf rhs'
e_free = Set.delete name $ freeVarsOf e'
-- -- Sum free variables present in bind and the expression
-- ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e')
-- where
-- binders_frees = Set.delete name $ freeVarsOf rhs'
-- e_free = Set.delete name $ freeVarsOf e'
rhs' = freeVarsExp e_localVars rhs
new_bind = ABind name parms rhs'
-- rhs' = freeVarsExp e_localVars rhs
-- new_bind = ABind name parms rhs'
e' = freeVarsExp e_localVars e
e_localVars = Set.insert name localVars
-- e' = freeVarsExp e_localVars e
-- e_localVars = Set.insert name localVars
freeVarsOf :: AnnExp -> Set Id
freeVarsOf = fst
--freeVarsOf :: AnnExp -> Set Id
--freeVarsOf = fst
-- AST annotated with free variables
type AnnProgram = [(Id, [Id], AnnExp)]
---- AST annotated with free variables
--type AnnProgram = [(Id, [Id], AnnExp)]
type AnnExp = (Set Id, AnnExp')
--type AnnExp = (Set Id, AnnExp')
data ABind = ABind Id [Id] AnnExp deriving Show
--data ABind = ABind Id [Id] AnnExp deriving Show
data AnnExp' = AId Id
| AInt Integer
| ALet ABind AnnExp
| AApp Type AnnExp AnnExp
| AAdd Type AnnExp AnnExp
| AAbs Type Id AnnExp
deriving Show
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
-- Free variables are @v₁ v₂ .. vₙ@ are bound.
abstract :: AnnProgram -> Program
abstract prog = Program $ evalState (mapM go prog) 0
where
go :: (Id, [Id], AnnExp) -> State Int Bind
go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs'
where
(rhs', parms1) = flattenLambdasAnn rhs
--data AnnExp' = AId Id
-- | AInt Integer
-- | ALet ABind AnnExp
-- | AApp Type AnnExp AnnExp
-- | AAdd Type AnnExp AnnExp
-- | AAbs Type Id AnnExp
-- deriving Show
---- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
---- Free variables are @v₁ v₂ .. vₙ@ are bound.
--abstract :: AnnProgram -> Program
--abstract prog = Program $ evalState (mapM go prog) 0
-- where
-- go :: (Id, [Id], AnnExp) -> State Int Bind
-- go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs'
-- where
-- (rhs', parms1) = flattenLambdasAnn rhs
-- | Flatten nested lambdas and collect the parameters
-- @\x.\y.\z. ae → (ae, [x,y,z])@
flattenLambdasAnn :: AnnExp -> (AnnExp, [Id])
flattenLambdasAnn ae = go (ae, [])
where
go :: (AnnExp, [Id]) -> (AnnExp, [Id])
go ((free, e), acc) =
case e of
AAbs _ par (free1, e1) ->
go ((Set.delete par free1, e1), snoc par acc)
_ -> ((free, e), acc)
---- | Flatten nested lambdas and collect the parameters
---- @\x.\y.\z. ae → (ae, [x,y,z])@
--flattenLambdasAnn :: AnnExp -> (AnnExp, [Id])
--flattenLambdasAnn ae = go (ae, [])
-- where
-- go :: (AnnExp, [Id]) -> (AnnExp, [Id])
-- go ((free, e), acc) =
-- case e of
-- AAbs _ par (free1, e1) ->
-- go ((Set.delete par free1, e1), snoc par acc)
-- _ -> ((free, e), acc)
abstractExp :: AnnExp -> State Int Exp
abstractExp (free, exp) = case exp of
AId n -> pure $ EId n
AInt i -> pure $ ELit (TMono "Int") (LInt i)
AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2)
AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2)
ALet b e -> liftA2 ELet (go b) (abstractExp e)
where
go (ABind name parms rhs) = do
(rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
pure $ Bind name (parms ++ parms1) rhs'
--abstractExp :: AnnExp -> State Int Exp
--abstractExp (free, exp) = case exp of
-- AId n -> pure $ EId n
-- AInt i -> pure $ ELit (TMono "Int") (LInt i)
-- AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2)
-- AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2)
-- ALet b e -> liftA2 ELet (go b) (abstractExp e)
-- where
-- go (ABind name parms rhs) = do
-- (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
-- pure $ Bind name (parms ++ parms1) rhs'
skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp
skipLambdas f (free, ae) = case ae of
AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1
_ -> f (free, ae)
-- skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp
-- skipLambdas f (free, ae) = case ae of
-- AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1
-- _ -> f (free, ae)
-- Lift lambda into let and bind free variables
AAbs t parm e -> do
i <- nextNumber
rhs <- abstractExp e
-- -- Lift lambda into let and bind free variables
-- AAbs t parm e -> do
-- i <- nextNumber
-- rhs <- abstractExp e
let sc_name = Ident ("sc_" ++ show i)
sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t)
-- let sc_name = Ident ("sc_" ++ show i)
-- sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t)
pure $ foldl (EApp $ TMono "Int") sc $ map EId freeList
where
freeList = Set.toList free
parms = snoc parm freeList
-- pure $ foldl (EApp $ TMono "Int") sc $ map EId freeList
-- where
-- freeList = Set.toList free
-- parms = snoc parm freeList
nextNumber :: State Int Int
nextNumber = do
i <- get
put $ succ i
pure i
--nextNumber :: State Int Int
--nextNumber = do
-- i <- get
-- put $ succ i
-- pure i
-- | Collects supercombinators by lifting non-constant let expressions
collectScs :: Program -> Program
collectScs (Program scs) = Program $ concatMap collectFromRhs scs
where
collectFromRhs (Bind name parms rhs) =
let (rhs_scs, rhs') = collectScsExp rhs
in Bind name parms rhs' : rhs_scs
---- | Collects supercombinators by lifting non-constant let expressions
--collectScs :: Program -> Program
--collectScs (Program scs) = Program $ concatMap collectFromRhs scs
-- where
-- collectFromRhs (Bind name parms rhs) =
-- let (rhs_scs, rhs') = collectScsExp rhs
-- in Bind name parms rhs' : rhs_scs
collectScsExp :: Exp -> ([Bind], Exp)
collectScsExp = \case
EId n -> ([], EId n)
ELit _ (LInt i) -> ([], ELit (TMono "Int") (LInt i))
--collectScsExp :: Exp -> ([Bind], Exp)
--collectScsExp = \case
-- EId n -> ([], EId n)
-- ELit _ (LInt i) -> ([], ELit (TMono "Int") (LInt i))
EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2')
where
(scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2
-- EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2')
-- where
-- (scs1, e1') = collectScsExp e1
-- (scs2, e2') = collectScsExp e2
EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2')
where
(scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2
-- EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2')
-- where
-- (scs1, e1') = collectScsExp e1
-- (scs2, e2') = collectScsExp e2
EAbs t par e -> (scs, EAbs t par e')
where
(scs, e') = collectScsExp e
-- EAbs t par e -> (scs, EAbs t par e')
-- where
-- (scs, e') = collectScsExp e
-- Collect supercombinators from bind, the rhss, and the expression.
--
-- > f = let sc x y = rhs in e
--
ELet (Bind name parms rhs) e -> if null parms
then ( rhs_scs ++ e_scs, ELet bind e')
else (bind : rhs_scs ++ e_scs, e')
where
bind = Bind name parms rhs'
(rhs_scs, rhs') = collectScsExp rhs
(e_scs, e') = collectScsExp e
-- -- Collect supercombinators from bind, the rhss, and the expression.
-- --
-- -- > f = let sc x y = rhs in e
-- --
-- ELet (Bind name parms rhs) e -> if null parms
-- then ( rhs_scs ++ e_scs, ELet bind e')
-- else (bind : rhs_scs ++ e_scs, e')
-- where
-- bind = Bind name parms rhs'
-- (rhs_scs, rhs') = collectScsExp rhs
-- (e_scs, e') = collectScsExp e
-- @\x.\y.\z. e → (e, [x,y,z])@
flattenLambdas :: Exp -> (Exp, [Id])
flattenLambdas = go . (, [])
where
go (e, acc) = case e of
EAbs _ par e1 -> go (e1, snoc par acc)
_ -> (e, acc)
---- @\x.\y.\z. e → (e, [x,y,z])@
--flattenLambdas :: Exp -> (Exp, [Id])
--flattenLambdas = go . (, [])
-- where
-- go (e, acc) = case e of
-- EAbs _ par e1 -> go (e1, snoc par acc)
-- _ -> (e, acc)

View file

@ -2,18 +2,18 @@
module Main where
import Codegen.Codegen (compile)
import GHC.IO.Handle.Text (hPutStrLn)
import Grammar.ErrM (Err)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
-- import Codegen.Codegen (compile)
import GHC.IO.Handle.Text (hPutStrLn)
import Grammar.ErrM (Err)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
import LambdaLifter.LambdaLifter (lambdaLift)
import Renamer.Renamer (rename)
import System.Environment (getArgs)
import System.Exit (exitFailure, exitSuccess)
import System.IO (stderr)
import TypeChecker.TypeChecker (typecheck)
-- import LambdaLifter.LambdaLifter (lambdaLift)
import Renamer.Renamer (rename)
import System.Environment (getArgs)
import System.Exit (exitFailure, exitSuccess)
import System.IO (stderr)
import TypeChecker.TypeChecker (typecheck)
main :: IO ()
main =
@ -37,14 +37,14 @@ main' s = do
typechecked <- fromTypeCheckerErr $ typecheck renamed
printToErr $ printTree typechecked
printToErr "\n-- Lambda Lifter --"
let lifted = lambdaLift typechecked
printToErr $ printTree lifted
-- printToErr "\n-- Lambda Lifter --"
-- let lifted = lambdaLift typechecked
-- printToErr $ printTree lifted
printToErr "\n -- Printing compiler output to stdout --"
compiled <- fromCompilerErr $ compile lifted
putStrLn compiled
writeFile "llvm.ll" compiled
-- printToErr "\n -- Printing compiler output to stdout --"
-- compiled <- fromCompilerErr $ compile lifted
-- putStrLn compiled
-- writeFile "llvm.ll" compiled
exitSuccess

View file

@ -3,6 +3,7 @@
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use traverse_" #-}
{-# OPTIONS_GHC -Wno-overlapping-patterns #-}
{-# HLINT ignore "Use zipWithM" #-}
module TypeChecker.TypeChecker where
@ -16,6 +17,7 @@ import qualified Data.Map as M
import Data.Set (Set)
import qualified Data.Set as S
import Data.Foldable (traverse_)
import Grammar.Abs
import Grammar.Print (printTree)
import qualified TypeChecker.TypeCheckerIr as T
@ -24,10 +26,12 @@ import qualified TypeChecker.TypeCheckerIr as T
data Poly = Forall [Ident] Type
deriving Show
newtype Ctx = Ctx { vars :: Map Ident Poly }
newtype Ctx = Ctx { vars :: Map Ident Poly
}
data Env = Env { count :: Int
, sigs :: Map Ident Type
data Env = Env { count :: Int
, sigs :: Map Ident Type
, dtypes :: Map Ident Type
}
type Error = String
@ -36,7 +40,7 @@ type Subst = Map Ident Type
type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))
initCtx = Ctx mempty
initEnv = Env 0 mempty
initEnv = Env 0 mempty mempty
runPretty :: Exp -> Either Error String
runPretty = fmap (printTree . fst). run . inferExp
@ -50,21 +54,44 @@ runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e
typecheck :: Program -> Either Error T.Program
typecheck = run . checkPrg
checkData :: Data -> Infer ()
checkData d = case d of
(Data typ@(TConstr name _) constrs) -> do
traverse_ (\(Constructor name' t')
-> if typ == retType t'
then insertConstr name' t' else
throwError $
unwords
[ "return type of constructor:"
, printTree name
, "with type:"
, printTree (retType t')
, "does not match data: "
, printTree typ]) constrs
_ -> throwError "Data type incorrectly declared"
where
retType :: Type -> Type
retType (TArr _ t2) = retType t2
retType a = a
checkPrg :: Program -> Infer T.Program
checkPrg (Program bs) = do
let bs' = getBinds bs
traverse (\(Bind n t _ _ _) -> insertSig n t) bs'
bs' <- mapM checkBind bs'
return $ T.Program bs'
preRun bs
T.Program <$> checkDef bs
where
getBinds :: [Def] -> [Bind]
getBinds = map toBind . filter isBind
isBind :: Def -> Bool
isBind (DBind _) = True
isBind _ = True
toBind :: Def -> Bind
toBind (DBind bind) = bind
toBind _ = error "Can't convert DData to Bind"
preRun :: [Def] -> Infer ()
preRun [] = return ()
preRun (x:xs) = case x of
DBind (Bind n t _ _ _ ) -> insertSig n t >> preRun xs
DData d@(Data _ _) -> checkData d >> preRun xs
checkDef :: [Def] -> Infer [T.Def]
checkDef [] = return []
checkDef (x:xs) = case x of
(DBind b) -> do
b' <- checkBind b
fmap (T.DBind b' :) (checkDef xs)
(DData d) -> fmap (T.DData d :) (checkDef xs)
checkBind :: Bind -> Infer T.Bind
checkBind (Bind n t _ args e) = do
@ -77,15 +104,18 @@ checkBind (Bind n t _ args e) = do
makeLambda :: Exp -> [Ident] -> Exp
makeLambda = foldl (flip EAbs)
-- | Check if two types are considered equal
-- For the purpose of the algorithm two polymorphic types are always considered equal
typeEq :: Type -> Type -> Bool
typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r'
typeEq (TMono a) (TMono b) = a == b
typeEq (TPol _) (TPol _) = True
typeEq _ _ = False
typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r'
typeEq (TMono a) (TMono b) = a == b
typeEq (TConstr name a) (TConstr name' b) = name == name' && and (zipWith typeEq a b)
typeEq (TPol _) (TPol _) = True
typeEq _ _ = False
inferExp :: Exp -> Infer (Type, T.Exp)
inferExp e = do
(s, t, e') <- w e
(s, t, e') <- algoW e
let subbed = apply s t
return (subbed, replace subbed e')
@ -98,19 +128,26 @@ replace t = \case
T.EAdd _ e1 e2 -> T.EAdd t e1 e2
T.ELet (T.Bind (n, _) args e1) e2 -> T.ELet (T.Bind (n, t) args e1) e2
w :: Exp -> Infer (Subst, Type, T.Exp)
w = \case
algoW :: Exp -> Infer (Subst, Type, T.Exp)
algoW = \case
EAnn e t -> do
(s1, t', e') <- w e
(s1, t', e') <- algoW e
applySt s1 $ do
s2 <- unify (apply s1 t) t'
return (s2 `compose` s1, t, e')
-- | ------------------
-- | Γ ⊢ e₀ : Int, ∅
ELit (LInt n) -> return (nullSubst, TMono "Int", T.ELit (TMono "Int") (LInt n))
ELit a -> error $ "NOT IMPLEMENTED YET: ELit " ++ show a
-- | x : σ ∈ Γ τ = inst(σ)
-- | ----------------------
-- | Γ ⊢ x : τ, ∅
EId i -> do
var <- asks vars
case M.lookup i var of
@ -118,42 +155,67 @@ w = \case
Nothing -> do
sig <- gets sigs
case M.lookup i sig of
Nothing -> throwError $ "Unbound variable: " ++ show i
Just t -> return (nullSubst, t, T.EId (i, t))
Nothing -> do
constr <- gets dtypes
case M.lookup i constr of
Just t -> return (nullSubst, t, T.EId (i, t))
Nothing -> throwError $ "Unbound variable: " ++ show i
-- | τ = newvar Γ, x : τ ⊢ e : τ', S
-- | ---------------------------------
-- | Γ ⊢ w λx. e : Sτ → τ', S
EAbs name e -> do
fr <- fresh
withBinding name (Forall [] fr) $ do
(s1, t', e') <- w e
(s1, t', e') <- algoW e
let varType = apply s1 fr
let newArr = TArr varType t'
return (s1, newArr, T.EAbs newArr (name, varType) e')
-- | Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
-- | s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
-- | ------------------------------------------
-- | Γ ⊢ e₀ + e₁ : Int, S₃S₂S₁S₀
-- This might be wrong
EAdd e0 e1 -> do
(s1, t0, e0') <- w e0
(s1, t0, e0') <- algoW e0
applySt s1 $ do
(s2, t1, e1') <- w e1
applySt s2 $ do
s3 <- unify (apply s2 t0) (TMono "Int")
s4 <- unify (apply s3 t1) (TMono "Int")
return (s4 `compose` s3 `compose` s2 `compose` s1, TMono "Int", T.EAdd (TMono "Int") e0' e1')
(s2, t1, e1') <- algoW e1
-- applySt s2 $ do
s3 <- unify (apply s2 t0) (TMono "Int")
s4 <- unify (apply s3 t1) (TMono "Int")
return (s4 `compose` s3 `compose` s2 `compose` s1, TMono "Int", T.EAdd (TMono "Int") e0' e1')
-- | Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1
-- | τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ')
-- | --------------------------------------
-- | Γ ⊢ e₀ e₁ : S₂τ', S₂S₁S₀
EApp e0 e1 -> do
fr <- fresh
(s0, t0, e0') <- w e0
(s0, t0, e0') <- algoW e0
applySt s0 $ do
(s1, t1, e1') <- w e1
(s1, t1, e1') <- algoW e1
-- applySt s1 $ do
s2 <- unify (apply s1 t0) (TArr t1 fr)
let t = apply s2 fr
return (s2 `compose` s1 `compose` s0, t, T.EApp t e0' e1')
-- | Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁
-- | ----------------------------------------------
-- | Γ ⊢ let x = e₀ in e₁ : τ', S₁S₀
-- The bar over S₀ and Γ means "generalize"
ELet name e0 e1 -> do
(s1, t1, e0') <- w e0
(s1, t1, e0') <- algoW e0
env <- asks vars
let t' = generalize (apply s1 env) t1
withBinding name t' $ do
(s2, t2, e1') <- w e1
(s2, t2, e1') <- algoW e1
return (s2 `compose` s1, t2, T.ELet (T.Bind (name,t2) [] e0') e1' )
ECase a b -> error $ "NOT IMPLEMENTED YET: ECase" ++ show a ++ " " ++ show b
@ -168,6 +230,12 @@ unify t0 t1 = case (t0, t1) of
(TPol a, b) -> occurs a b
(a, TPol b) -> occurs b a
(TMono a, TMono b) -> if a == b then return M.empty else throwError "Types do not unify"
-- | TODO: Figure out a cleaner way to express the same thing
(TConstr name t, TConstr name' t') -> if name == name' && length t == length t'
then do
xs <- sequence $ zipWith unify t t'
return $ foldr compose nullSubst xs
else throwError $ unwords ["Type constructor:", printTree name, "(" ++ printTree t ++ ")", "does not match with:", printTree name', "(" ++ printTree t' ++ ")"]
(a, b) -> throwError . unwords $ ["Type:", printTree a, "can't be unified with:", printTree b]
-- | Check if a type is contained in another type.
@ -202,9 +270,11 @@ class FreeVars t where
instance FreeVars Type where
free :: Type -> Set Ident
free (TPol a) = S.singleton a
free (TMono _) = mempty
free (TArr a b) = free a `S.union` free b
free (TPol a) = S.singleton a
free (TMono _) = mempty
free (TArr a b) = free a `S.union` free b
-- | Not guaranteed to be correct
free (TConstr _ a) = foldl' (\acc x -> free x `S.union` acc) S.empty a
apply :: Subst -> Type -> Type
apply sub t = do
case t of
@ -213,6 +283,7 @@ instance FreeVars Type where
Nothing -> TPol a
Just t -> t
TArr a b -> TArr (apply sub a) (apply sub b)
TConstr name a -> TConstr name (map (apply sub) a)
instance FreeVars Poly where
free :: Poly -> Set Ident
@ -248,3 +319,7 @@ withBinding i p = local (\st -> st { vars = M.insert i p (vars st) })
-- | Insert a function signature into the environment
insertSig :: Ident -> Type -> Infer ()
insertSig i t = modify (\st -> st { sigs = M.insert i t (sigs st) })
-- | Insert a constructor with its data type
insertConstr :: Ident -> Type -> Infer ()
insertConstr i t = modify (\st -> st { dtypes = M.insert i t (dtypes st) })

View file

@ -5,12 +5,12 @@ module TypeChecker.TypeCheckerIr
, module TypeChecker.TypeCheckerIr
) where
import Grammar.Abs (Ident (..), Literal (..), Type (..))
import Grammar.Abs (Data (..), Ident (..), Literal (..), Type (..))
import Grammar.Print
import Prelude
import qualified Prelude as C (Eq, Ord, Read, Show)
newtype Program = Program [Bind]
newtype Program = Program [Def]
deriving (C.Eq, C.Ord, C.Show, C.Read)
data Exp
@ -22,11 +22,18 @@ data Exp
| EAbs Type Id Exp
deriving (C.Eq, C.Ord, C.Read, C.Show)
data Def = DBind Bind | DData Data
deriving (C.Eq, C.Ord, C.Read, C.Show)
type Id = (Ident, Type)
data Bind = Bind Id [Id] Exp
deriving (C.Eq, C.Ord, C.Show, C.Read)
instance Print Def where
prt i (DBind bind) = prt i bind
prt i (DData d) = prt i d
instance Print Program where
prt i (Program sc) = prPrec i 0 $ prt 0 sc
@ -75,7 +82,7 @@ instance Print Exp where
, doc $ showString "in"
, prt 0 e
]
EApp t e1 e2 -> prPrec i 2 $ concatD
EApp _ e1 e2 -> prPrec i 2 $ concatD
[ prt 2 e1
, prt 3 e2
]

View file

@ -1,2 +1,14 @@
main : _Int ;
main = 3 + 3 ;
data List ('a) where;
Nil : List ('a),
Cons : 'a -> List ('a) -> List ('a) ;
main : List (_Int) ;
main = Cons 1 (Cons 0 Nil) ;
data Bool () where;
True : Bool (),
False : Bool ();
boolean : Bool (_Int);
boolean = True ;