Type inference/checking on ADTs mostly complete(?). Still have to test

This commit is contained in:
sebastianselander 2023-02-27 17:22:42 +01:00
parent 2f45f39435
commit bbf6e159c7
8 changed files with 563 additions and 467 deletions

View file

@ -3,7 +3,7 @@ Program. Program ::= [Def] ;
DBind. Def ::= Bind ; DBind. Def ::= Bind ;
DData. Def ::= Data ; DData. Def ::= Data ;
terminator Def ";" ; separator Def ";" ;
Bind. Bind ::= Ident ":" Type ";" Bind. Bind ::= Ident ":" Type ";"
Ident [Ident] "=" Exp ; Ident [Ident] "=" Exp ;
@ -31,16 +31,19 @@ IMatch. Match ::= Ident ;
InitMatch. Match ::= Ident Match ; InitMatch. Match ::= Ident Match ;
separator Match " " ; separator Match " " ;
TMono. Type1 ::= "_" Ident ; TMono. Type1 ::= "_" Ident ;
TPol. Type1 ::= "'" Ident ; TPol. Type1 ::= "'" Ident ;
TArr. Type ::= Type1 "->" Type ; TConstr. Type1 ::= Ident "(" [Type] ")" ;
TArr. Type ::= Type1 "->" Type ;
separator Type " " ; separator Type " " ;
coercions Type 2 ;
-- shift/reduce problem here -- shift/reduce problem here
Data. Data ::= "data" Ident [Type] "where" ";" Data. Data ::= "data" Type "where" ";"
[Constructor]; [Constructor];
terminator Constructor ";" ; separator Constructor "," ;
Constructor. Constructor ::= Ident ":" Type ; Constructor. Constructor ::= Ident ":" Type ;
@ -48,10 +51,9 @@ Constructor. Constructor ::= Ident ":" Type ;
-- token Poly upper (letter | digit | '_')* ; -- token Poly upper (letter | digit | '_')* ;
-- token Mono lower (letter | digit | '_')* ; -- token Mono lower (letter | digit | '_')* ;
terminator Bind ";" ; separator Bind ";" ;
separator Ident " "; separator Ident " ";
coercions Type 1 ;
coercions Exp 5 ; coercions Exp 5 ;
comment "--" ; comment "--" ;

View file

@ -34,9 +34,9 @@ executable language
TypeChecker.TypeChecker TypeChecker.TypeChecker
TypeChecker.TypeCheckerIr TypeChecker.TypeCheckerIr
Renamer.Renamer Renamer.Renamer
LambdaLifter.LambdaLifter -- LambdaLifter.LambdaLifter
Codegen.Codegen -- Codegen.Codegen
Codegen.LlvmIr -- Codegen.LlvmIr
hs-source-dirs: src hs-source-dirs: src

View file

@ -1,277 +1,277 @@
{-# LANGUAGE LambdaCase #-} --{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-} --{-# LANGUAGE OverloadedStrings #-}
module Codegen.Codegen (compile) where module Codegen.Codegen where
import Auxiliary (snoc) --import Auxiliary (snoc)
import Codegen.LlvmIr (LLVMIr (..), LLVMType (..), --import Codegen.LlvmIr (LLVMIr (..), LLVMType (..),
LLVMValue (..), Visibility (..), -- LLVMValue (..), Visibility (..),
llvmIrToString) -- llvmIrToString)
import Control.Monad.State (StateT, execStateT, gets, modify) --import Control.Monad.State (StateT, execStateT, gets, modify)
import Data.Map (Map) --import Data.Map (Map)
import qualified Data.Map as Map --import qualified Data.Map as Map
import Data.Tuple.Extra (dupe, first, second) --import Data.Tuple.Extra (dupe, first, second)
import Grammar.ErrM (Err) --import Grammar.ErrM (Err)
import TypeChecker.TypeChecker --import TypeChecker.TypeCheckerIr
import TypeChecker.TypeCheckerIr
-- | The record used as the code generator state ---- | The record used as the code generator state
data CodeGenerator = CodeGenerator --data CodeGenerator = CodeGenerator
{ instructions :: [LLVMIr] -- { instructions :: [LLVMIr]
, functions :: Map Id FunctionInfo -- , functions :: Map Id FunctionInfo
, variableCount :: Integer -- , variableCount :: Integer
} -- }
-- | A state type synonym ---- | A state type synonym
type CompilerState a = StateT CodeGenerator Err a --type CompilerState a = StateT CodeGenerator Err a
data FunctionInfo = FunctionInfo --data FunctionInfo = FunctionInfo
{ numArgs :: Int -- { numArgs :: Int
, arguments :: [Id] -- , arguments :: [Id]
} -- }
-- | Adds a instruction to the CodeGenerator state ---- | Adds a instruction to the CodeGenerator state
emit :: LLVMIr -> CompilerState () --emit :: LLVMIr -> CompilerState ()
emit l = modify $ \t -> t { instructions = snoc l $ instructions t } --emit l = modify $ \t -> t { instructions = snoc l $ instructions t }
-- | Increases the variable counter in the CodeGenerator state ---- | Increases the variable counter in the CodeGenerator state
increaseVarCount :: CompilerState () --increaseVarCount :: CompilerState ()
increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 } --increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 }
-- | Returns the variable count from the CodeGenerator state ---- | Returns the variable count from the CodeGenerator state
getVarCount :: CompilerState Integer --getVarCount :: CompilerState Integer
getVarCount = gets variableCount --getVarCount = gets variableCount
-- | Increases the variable count and returns it from the CodeGenerator state ---- | Increases the variable count and returns it from the CodeGenerator state
getNewVar :: CompilerState Integer --getNewVar :: CompilerState Integer
getNewVar = increaseVarCount >> getVarCount --getNewVar = increaseVarCount >> getVarCount
-- | Produces a map of functions infos from a list of binds, ---- | Produces a map of functions infos from a list of binds,
-- which contains useful data for code generation. ---- which contains useful data for code generation.
getFunctions :: [Bind] -> Map Id FunctionInfo --getFunctions :: [Bind] -> Map Id FunctionInfo
getFunctions bs = Map.fromList $ map go bs --getFunctions bs = Map.fromList $ map go bs
where -- where
go (Bind id args _) = -- go (Bind id args _) =
(id, FunctionInfo { numArgs=length args, arguments=args }) -- (id, FunctionInfo { numArgs=length args, arguments=args })
initCodeGenerator :: [Bind] -> CodeGenerator --initCodeGenerator :: [Bind] -> CodeGenerator
initCodeGenerator scs = CodeGenerator { instructions = defaultStart --initCodeGenerator scs = CodeGenerator { instructions = defaultStart
, functions = getFunctions scs -- , functions = getFunctions scs
, variableCount = 0 -- , variableCount = 0
} -- }
-- | Compiles an AST and produces a LLVM Ir string. ---- | Compiles an AST and produces a LLVM Ir string.
-- An easy way to actually "compile" this output is to ---- An easy way to actually "compile" this output is to
-- Simply pipe it to lli ---- Simply pipe it to lli
compile :: Program -> Err String --compile :: Program -> Err String
compile (Program scs) = do --compile (Program scs) = do
let codegen = initCodeGenerator scs -- let codegen = initCodeGenerator scs
llvmIrToString . instructions <$> execStateT (compileScs scs) codegen -- llvmIrToString . instructions <$> execStateT (compileScs scs) codegen
compileScs :: [Bind] -> CompilerState () --compileScs :: [Bind] -> CompilerState ()
compileScs [] = pure () --compileScs [] = pure ()
compileScs (Bind (name, t) args exp : xs) = do --compileScs (Bind (name, t) args exp : xs) = do
emit $ UnsafeRaw "\n" -- emit $ UnsafeRaw "\n"
emit . Comment $ show name <> ": " <> show exp -- emit . Comment $ show name <> ": " <> show exp
let args' = map (second type2LlvmType) args -- let args' = map (second type2LlvmType) args
emit $ Define (type2LlvmType t_return) name args' -- emit $ Define (type2LlvmType t_return) name args'
functionBody <- exprToValue exp -- functionBody <- exprToValue exp
if name == "main" -- if name == "main"
then mapM_ emit $ mainContent functionBody -- then mapM_ emit $ mainContent functionBody
else emit $ Ret I64 functionBody -- else emit $ Ret I64 functionBody
emit DefineEnd -- emit DefineEnd
modify $ \s -> s { variableCount = 0 } -- modify $ \s -> s { variableCount = 0 }
compileScs xs -- compileScs xs
where -- where
t_return = snd $ partitionType (length args) t -- t_return = snd $ partitionType (length args) t
mainContent :: LLVMValue -> [LLVMIr] --mainContent :: LLVMValue -> [LLVMIr]
mainContent var = --mainContent var =
[ UnsafeRaw $ -- [ UnsafeRaw $
"call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n" -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n"
, -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2)) -- , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
-- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2") -- -- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2")
-- , Label (Ident "b_1") -- -- , Label (Ident "b_1")
-- , UnsafeRaw -- -- , UnsafeRaw
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n" -- -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n"
-- , Br (Ident "end") -- -- , Br (Ident "end")
-- , Label (Ident "b_2") -- -- , Label (Ident "b_2")
-- , UnsafeRaw -- -- , UnsafeRaw
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n" -- -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n"
-- , Br (Ident "end") -- -- , Br (Ident "end")
-- , Label (Ident "end") -- -- , Label (Ident "end")
Ret I64 (VInteger 0) -- Ret I64 (VInteger 0)
] -- ]
defaultStart :: [LLVMIr] --defaultStart :: [LLVMIr]
defaultStart = [ UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n" --defaultStart = [ UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" -- , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
] -- ]
compileExp :: Exp -> CompilerState () --compileExp :: Exp -> CompilerState ()
compileExp = \case --compileExp = \case
ELit _ (LInt i) -> emitInt i -- ELit _ (LInt i) -> emitInt i
EAdd t e1 e2 -> emitAdd t e1 e2 -- EAdd t e1 e2 -> emitAdd t e1 e2
EId (name, _) -> emitIdent name -- EId (name, _) -> emitIdent name
EApp t e1 e2 -> emitApp t e1 e2 -- EApp t e1 e2 -> emitApp t e1 e2
EAbs t ti e -> emitAbs t ti e -- EAbs t ti e -> emitAbs t ti e
ELet bind e -> emitLet bind e -- ELet bind e -> emitLet bind e
--- aux functions --- ----- aux functions ---
emitAbs :: Type -> Id -> Exp -> CompilerState () --emitAbs :: Type -> Id -> Exp -> CompilerState ()
emitAbs _t tid e = emit . Comment $ "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e --emitAbs _t tid e = emit . Comment $ "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e
emitLet :: Bind -> Exp -> CompilerState () --emitLet :: Bind -> Exp -> CompilerState ()
emitLet b e = emit . Comment $ concat [ "ELet (" --emitLet b e = emit . Comment $ concat [ "ELet ("
, show b -- , show b
, " = " -- , " = "
, show e -- , show e
, ") is not implemented!" -- , ") is not implemented!"
] -- ]
emitApp :: Type -> Exp -> Exp -> CompilerState () --emitApp :: Type -> Exp -> Exp -> CompilerState ()
emitApp t e1 e2 = appEmitter t e1 e2 [] --emitApp t e1 e2 = appEmitter t e1 e2 []
where -- where
appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState () -- appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState ()
appEmitter t e1 e2 stack = do -- appEmitter t e1 e2 stack = do
let newStack = e2 : stack -- let newStack = e2 : stack
case e1 of -- case e1 of
EApp _ e1' e2' -> appEmitter t e1' e2' newStack -- EApp _ e1' e2' -> appEmitter t e1' e2' newStack
EId id@(name, _) -> do -- EId id@(name, _) -> do
args <- traverse exprToValue newStack -- args <- traverse exprToValue newStack
vs <- getNewVar -- vs <- getNewVar
funcs <- gets functions -- funcs <- gets functions
let visibility = maybe Local (const Global) $ Map.lookup id funcs -- let visibility = maybe Local (const Global) $ Map.lookup id funcs
args' = map (first valueGetType . dupe) args -- args' = map (first valueGetType . dupe) args
call = Call (type2LlvmType t) visibility name args' -- call = Call (type2LlvmType t) visibility name args'
emit $ SetVariable (Ident $ show vs) call -- emit $ SetVariable (Ident $ show vs) call
x -> do -- x -> do
emit . Comment $ "The unspeakable happened: " -- emit . Comment $ "The unspeakable happened: "
emit . Comment $ show x -- emit . Comment $ show x
emitIdent :: Ident -> CompilerState () --emitIdent :: Ident -> CompilerState ()
emitIdent id = do --emitIdent id = do
-- !!this should never happen!! -- -- !!this should never happen!!
emit $ Comment "This should not have happened!" -- emit $ Comment "This should not have happened!"
emit $ Variable id -- emit $ Variable id
emit $ UnsafeRaw "\n" -- emit $ UnsafeRaw "\n"
emitInt :: Integer -> CompilerState () --emitInt :: Integer -> CompilerState ()
emitInt i = do --emitInt i = do
-- !!this should never happen!! -- -- !!this should never happen!!
varCount <- getNewVar -- varCount <- getNewVar
emit $ Comment "This should not have happened!" -- emit $ Comment "This should not have happened!"
emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0)) -- emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0))
emitAdd :: Type -> Exp -> Exp -> CompilerState () --emitAdd :: Type -> Exp -> Exp -> CompilerState ()
emitAdd t e1 e2 = do --emitAdd t e1 e2 = do
v1 <- exprToValue e1 -- v1 <- exprToValue e1
v2 <- exprToValue e2 -- v2 <- exprToValue e2
v <- getNewVar -- v <- getNewVar
emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2) -- emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2)
-- emitMul :: Exp -> Exp -> CompilerState () ---- emitMul :: Exp -> Exp -> CompilerState ()
-- emitMul e1 e2 = do ---- emitMul e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2 ---- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount ---- increaseVarCount
-- v <- gets variableCount ---- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v ---- emit $ SetVariable $ Ident $ show v
-- emit $ Mul I64 v1 v2 ---- emit $ Mul I64 v1 v2
-- emitMod :: Exp -> Exp -> CompilerState () ---- emitMod :: Exp -> Exp -> CompilerState ()
-- emitMod e1 e2 = do ---- emitMod e1 e2 = do
-- -- `let m a b = rem (abs $ b + a) b` ---- -- `let m a b = rem (abs $ b + a) b`
-- (v1,v2) <- binExprToValues e1 e2 ---- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount ---- increaseVarCount
-- vadd <- gets variableCount ---- vadd <- gets variableCount
-- emit $ SetVariable $ Ident $ show vadd ---- emit $ SetVariable $ Ident $ show vadd
-- emit $ Add I64 v1 v2 ---- emit $ Add I64 v1 v2
-- ----
-- increaseVarCount ---- increaseVarCount
-- vabs <- gets variableCount ---- vabs <- gets variableCount
-- emit $ SetVariable $ Ident $ show vabs ---- emit $ SetVariable $ Ident $ show vabs
-- emit $ Call I64 (Ident "llvm.abs.i64") ---- emit $ Call I64 (Ident "llvm.abs.i64")
-- [ (I64, VIdent (Ident $ show vadd)) ---- [ (I64, VIdent (Ident $ show vadd))
-- , (I1, VInteger 1) ---- , (I1, VInteger 1)
-- ] ---- ]
-- increaseVarCount ---- increaseVarCount
-- v <- gets variableCount ---- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v ---- emit $ SetVariable $ Ident $ show v
-- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2 ---- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2
-- emitDiv :: Exp -> Exp -> CompilerState () ---- emitDiv :: Exp -> Exp -> CompilerState ()
-- emitDiv e1 e2 = do ---- emitDiv e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2 ---- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount ---- increaseVarCount
-- v <- gets variableCount ---- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v ---- emit $ SetVariable $ Ident $ show v
-- emit $ Div I64 v1 v2 ---- emit $ Div I64 v1 v2
-- emitSub :: Exp -> Exp -> CompilerState () ---- emitSub :: Exp -> Exp -> CompilerState ()
-- emitSub e1 e2 = do ---- emitSub e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2 ---- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount ---- increaseVarCount
-- v <- gets variableCount ---- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v ---- emit $ SetVariable $ Ident $ show v
-- emit $ Sub I64 v1 v2 ---- emit $ Sub I64 v1 v2
exprToValue :: Exp -> CompilerState LLVMValue --exprToValue :: Exp -> CompilerState LLVMValue
exprToValue = \case --exprToValue = \case
ELit _ (LInt i) -> pure $ VInteger i -- ELit _ (LInt i) -> pure $ VInteger i
EId id@(name, t) -> do -- EId id@(name, t) -> do
funcs <- gets functions -- funcs <- gets functions
case Map.lookup id funcs of -- case Map.lookup id funcs of
Just fi -> do -- Just fi -> do
if numArgs fi == 0 -- if numArgs fi == 0
then do -- then do
vc <- getNewVar -- vc <- getNewVar
emit $ SetVariable (Ident $ show vc) -- emit $ SetVariable (Ident $ show vc)
(Call (type2LlvmType t) Global name []) -- (Call (type2LlvmType t) Global name [])
pure $ VIdent (Ident $ show vc) (type2LlvmType t) -- pure $ VIdent (Ident $ show vc) (type2LlvmType t)
else pure $ VFunction name Global (type2LlvmType t) -- else pure $ VFunction name Global (type2LlvmType t)
Nothing -> pure $ VIdent name (type2LlvmType t) -- Nothing -> pure $ VIdent name (type2LlvmType t)
e -> do -- e -> do
compileExp e -- compileExp e
v <- getVarCount -- v <- getVarCount
pure $ VIdent (Ident $ show v) (getType e) -- pure $ VIdent (Ident $ show v) (getType e)
type2LlvmType :: Type -> LLVMType --type2LlvmType :: Type -> LLVMType
type2LlvmType = \case --type2LlvmType = \case
(TMono "Int") -> I64 -- (TMono "Int") -> I64
TArr t xs -> do -- TArr t xs -> do
let (t', xs') = function2LLVMType xs [type2LlvmType t] -- let (t', xs') = function2LLVMType xs [type2LlvmType t]
Function t' xs' -- Function t' xs'
t -> I64 --CustomType $ Ident ("\"" ++ show t ++ "\"") -- -- This part will not work as we don't have a monomorphization step yet
where -- t -> CustomType $ Ident ("\"" ++ show t ++ "\"")
function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType]) -- where
function2LLVMType (TArr t xs) s = function2LLVMType xs (type2LlvmType t : s) -- function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
function2LLVMType x s = (type2LlvmType x, s) -- function2LLVMType (TArr t xs) s = function2LLVMType xs (type2LlvmType t : s)
-- function2LLVMType x s = (type2LlvmType x, s)
getType :: Exp -> LLVMType --getType :: Exp -> LLVMType
getType (ELit _ (LInt _)) = I64 --getType (ELit _ (LInt _)) = I64
getType (EAdd t _ _) = type2LlvmType t --getType (EAdd t _ _) = type2LlvmType t
getType (EId (_, t)) = type2LlvmType t --getType (EId (_, t)) = type2LlvmType t
getType (EApp t _ _) = type2LlvmType t --getType (EApp t _ _) = type2LlvmType t
getType (EAbs t _ _) = type2LlvmType t --getType (EAbs t _ _) = type2LlvmType t
getType (ELet _ e) = getType e --getType (ELet _ e) = getType e
valueGetType :: LLVMValue -> LLVMType --valueGetType :: LLVMValue -> LLVMType
valueGetType (VInteger _) = I64 --valueGetType (VInteger _) = I64
valueGetType (VIdent _ t) = t --valueGetType (VIdent _ t) = t
valueGetType (VConstant s) = Array (length s) I8 --valueGetType (VConstant s) = Array (length s) I8
valueGetType (VFunction _ _ t) = t --valueGetType (VFunction _ _ t) = t
-- | Partion type into types of parameters and return type. ---- | Partion type into types of parameters and return type.
partitionType :: Int -- Number of parameters to apply --partitionType :: Int -- Number of parameters to apply
-> Type -- -> Type
-> ([Type], Type) -- -> ([Type], Type)
partitionType = go [] --partitionType = go []
where -- where
go acc 0 t = (acc, t) -- go acc 0 t = (acc, t)
go acc i t = case t of -- go acc i t = case t of
TArr t1 t2 -> go (snoc t1 acc) (i - 1) t2 -- TArr t1 t2 -> go (snoc t1 acc) (i - 1) t2
_ -> error "Number of parameters and type doesn't match" -- _ -> error "Number of parameters and type doesn't match"

View file

@ -1,192 +1,192 @@
{-# LANGUAGE LambdaCase #-} --{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-} --{-# LANGUAGE OverloadedStrings #-}
module LambdaLifter.LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where module LambdaLifter.LambdaLifter where
import Auxiliary (snoc) --import Auxiliary (snoc)
import Control.Applicative (Applicative (liftA2)) --import Control.Applicative (Applicative (liftA2))
import Control.Monad.State (MonadState (get, put), State, --import Control.Monad.State (MonadState (get, put), State,
evalState) -- evalState)
import Data.Set (Set) --import Data.Set (Set)
import qualified Data.Set as Set --import qualified Data.Set as Set
import Prelude hiding (exp) --import Prelude hiding (exp)
import Renamer.Renamer --import Renamer.Renamer
import TypeChecker.TypeCheckerIr --import TypeChecker.TypeCheckerIr
-- | Lift lambdas and let expression into supercombinators. ---- | Lift lambdas and let expression into supercombinators.
-- Three phases: ---- Three phases:
-- @freeVars@ annotatss all the free variables. ---- @freeVars@ annotatss all the free variables.
-- @abstract@ converts lambdas into let expressions. ---- @abstract@ converts lambdas into let expressions.
-- @collectScs@ moves every non-constant let expression to a top-level function. ---- @collectScs@ moves every non-constant let expression to a top-level function.
lambdaLift :: Program -> Program --lambdaLift :: Program -> Program
lambdaLift = collectScs . abstract . freeVars --lambdaLift = collectScs . abstract . freeVars
-- | Annotate free variables ---- | Annotate free variables
freeVars :: Program -> AnnProgram --freeVars :: Program -> AnnProgram
freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e) --freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e)
| Bind n xs e <- ds -- | Bind n xs e <- ds
] -- ]
freeVarsExp :: Set Id -> Exp -> AnnExp --freeVarsExp :: Set Id -> Exp -> AnnExp
freeVarsExp localVars = \case --freeVarsExp localVars = \case
EId n | Set.member n localVars -> (Set.singleton n, AId n) -- EId n | Set.member n localVars -> (Set.singleton n, AId n)
| otherwise -> (mempty, AId n) -- | otherwise -> (mempty, AId n)
ELit _ (LInt i) -> (mempty, AInt i) -- ELit _ (LInt i) -> (mempty, AInt i)
EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2') -- EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2')
where -- where
e1' = freeVarsExp localVars e1 -- e1' = freeVarsExp localVars e1
e2' = freeVarsExp localVars e2 -- e2' = freeVarsExp localVars e2
EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2') -- EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2')
where -- where
e1' = freeVarsExp localVars e1 -- e1' = freeVarsExp localVars e1
e2' = freeVarsExp localVars e2 -- e2' = freeVarsExp localVars e2
EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e') -- EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e')
where -- where
e' = freeVarsExp (Set.insert par localVars) e -- e' = freeVarsExp (Set.insert par localVars) e
-- Sum free variables present in bind and the expression -- -- Sum free variables present in bind and the expression
ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e') -- ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e')
where -- where
binders_frees = Set.delete name $ freeVarsOf rhs' -- binders_frees = Set.delete name $ freeVarsOf rhs'
e_free = Set.delete name $ freeVarsOf e' -- e_free = Set.delete name $ freeVarsOf e'
rhs' = freeVarsExp e_localVars rhs -- rhs' = freeVarsExp e_localVars rhs
new_bind = ABind name parms rhs' -- new_bind = ABind name parms rhs'
e' = freeVarsExp e_localVars e -- e' = freeVarsExp e_localVars e
e_localVars = Set.insert name localVars -- e_localVars = Set.insert name localVars
freeVarsOf :: AnnExp -> Set Id --freeVarsOf :: AnnExp -> Set Id
freeVarsOf = fst --freeVarsOf = fst
-- AST annotated with free variables ---- AST annotated with free variables
type AnnProgram = [(Id, [Id], AnnExp)] --type AnnProgram = [(Id, [Id], AnnExp)]
type AnnExp = (Set Id, AnnExp') --type AnnExp = (Set Id, AnnExp')
data ABind = ABind Id [Id] AnnExp deriving Show --data ABind = ABind Id [Id] AnnExp deriving Show
data AnnExp' = AId Id --data AnnExp' = AId Id
| AInt Integer -- | AInt Integer
| ALet ABind AnnExp -- | ALet ABind AnnExp
| AApp Type AnnExp AnnExp -- | AApp Type AnnExp AnnExp
| AAdd Type AnnExp AnnExp -- | AAdd Type AnnExp AnnExp
| AAbs Type Id AnnExp -- | AAbs Type Id AnnExp
deriving Show -- deriving Show
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@. ---- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
-- Free variables are @v₁ v₂ .. vₙ@ are bound. ---- Free variables are @v₁ v₂ .. vₙ@ are bound.
abstract :: AnnProgram -> Program --abstract :: AnnProgram -> Program
abstract prog = Program $ evalState (mapM go prog) 0 --abstract prog = Program $ evalState (mapM go prog) 0
where -- where
go :: (Id, [Id], AnnExp) -> State Int Bind -- go :: (Id, [Id], AnnExp) -> State Int Bind
go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs' -- go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs'
where -- where
(rhs', parms1) = flattenLambdasAnn rhs -- (rhs', parms1) = flattenLambdasAnn rhs
-- | Flatten nested lambdas and collect the parameters ---- | Flatten nested lambdas and collect the parameters
-- @\x.\y.\z. ae → (ae, [x,y,z])@ ---- @\x.\y.\z. ae → (ae, [x,y,z])@
flattenLambdasAnn :: AnnExp -> (AnnExp, [Id]) --flattenLambdasAnn :: AnnExp -> (AnnExp, [Id])
flattenLambdasAnn ae = go (ae, []) --flattenLambdasAnn ae = go (ae, [])
where -- where
go :: (AnnExp, [Id]) -> (AnnExp, [Id]) -- go :: (AnnExp, [Id]) -> (AnnExp, [Id])
go ((free, e), acc) = -- go ((free, e), acc) =
case e of -- case e of
AAbs _ par (free1, e1) -> -- AAbs _ par (free1, e1) ->
go ((Set.delete par free1, e1), snoc par acc) -- go ((Set.delete par free1, e1), snoc par acc)
_ -> ((free, e), acc) -- _ -> ((free, e), acc)
abstractExp :: AnnExp -> State Int Exp --abstractExp :: AnnExp -> State Int Exp
abstractExp (free, exp) = case exp of --abstractExp (free, exp) = case exp of
AId n -> pure $ EId n -- AId n -> pure $ EId n
AInt i -> pure $ ELit (TMono "Int") (LInt i) -- AInt i -> pure $ ELit (TMono "Int") (LInt i)
AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2) -- AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2)
AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2) -- AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2)
ALet b e -> liftA2 ELet (go b) (abstractExp e) -- ALet b e -> liftA2 ELet (go b) (abstractExp e)
where -- where
go (ABind name parms rhs) = do -- go (ABind name parms rhs) = do
(rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs -- (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
pure $ Bind name (parms ++ parms1) rhs' -- pure $ Bind name (parms ++ parms1) rhs'
skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp -- skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp
skipLambdas f (free, ae) = case ae of -- skipLambdas f (free, ae) = case ae of
AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1 -- AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1
_ -> f (free, ae) -- _ -> f (free, ae)
-- Lift lambda into let and bind free variables -- -- Lift lambda into let and bind free variables
AAbs t parm e -> do -- AAbs t parm e -> do
i <- nextNumber -- i <- nextNumber
rhs <- abstractExp e -- rhs <- abstractExp e
let sc_name = Ident ("sc_" ++ show i) -- let sc_name = Ident ("sc_" ++ show i)
sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t) -- sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t)
pure $ foldl (EApp $ TMono "Int") sc $ map EId freeList -- pure $ foldl (EApp $ TMono "Int") sc $ map EId freeList
where -- where
freeList = Set.toList free -- freeList = Set.toList free
parms = snoc parm freeList -- parms = snoc parm freeList
nextNumber :: State Int Int --nextNumber :: State Int Int
nextNumber = do --nextNumber = do
i <- get -- i <- get
put $ succ i -- put $ succ i
pure i -- pure i
-- | Collects supercombinators by lifting non-constant let expressions ---- | Collects supercombinators by lifting non-constant let expressions
collectScs :: Program -> Program --collectScs :: Program -> Program
collectScs (Program scs) = Program $ concatMap collectFromRhs scs --collectScs (Program scs) = Program $ concatMap collectFromRhs scs
where -- where
collectFromRhs (Bind name parms rhs) = -- collectFromRhs (Bind name parms rhs) =
let (rhs_scs, rhs') = collectScsExp rhs -- let (rhs_scs, rhs') = collectScsExp rhs
in Bind name parms rhs' : rhs_scs -- in Bind name parms rhs' : rhs_scs
collectScsExp :: Exp -> ([Bind], Exp) --collectScsExp :: Exp -> ([Bind], Exp)
collectScsExp = \case --collectScsExp = \case
EId n -> ([], EId n) -- EId n -> ([], EId n)
ELit _ (LInt i) -> ([], ELit (TMono "Int") (LInt i)) -- ELit _ (LInt i) -> ([], ELit (TMono "Int") (LInt i))
EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2') -- EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2')
where -- where
(scs1, e1') = collectScsExp e1 -- (scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2 -- (scs2, e2') = collectScsExp e2
EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2') -- EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2')
where -- where
(scs1, e1') = collectScsExp e1 -- (scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2 -- (scs2, e2') = collectScsExp e2
EAbs t par e -> (scs, EAbs t par e') -- EAbs t par e -> (scs, EAbs t par e')
where -- where
(scs, e') = collectScsExp e -- (scs, e') = collectScsExp e
-- Collect supercombinators from bind, the rhss, and the expression. -- -- Collect supercombinators from bind, the rhss, and the expression.
-- -- --
-- > f = let sc x y = rhs in e -- -- > f = let sc x y = rhs in e
-- -- --
ELet (Bind name parms rhs) e -> if null parms -- ELet (Bind name parms rhs) e -> if null parms
then ( rhs_scs ++ e_scs, ELet bind e') -- then ( rhs_scs ++ e_scs, ELet bind e')
else (bind : rhs_scs ++ e_scs, e') -- else (bind : rhs_scs ++ e_scs, e')
where -- where
bind = Bind name parms rhs' -- bind = Bind name parms rhs'
(rhs_scs, rhs') = collectScsExp rhs -- (rhs_scs, rhs') = collectScsExp rhs
(e_scs, e') = collectScsExp e -- (e_scs, e') = collectScsExp e
-- @\x.\y.\z. e → (e, [x,y,z])@ ---- @\x.\y.\z. e → (e, [x,y,z])@
flattenLambdas :: Exp -> (Exp, [Id]) --flattenLambdas :: Exp -> (Exp, [Id])
flattenLambdas = go . (, []) --flattenLambdas = go . (, [])
where -- where
go (e, acc) = case e of -- go (e, acc) = case e of
EAbs _ par e1 -> go (e1, snoc par acc) -- EAbs _ par e1 -> go (e1, snoc par acc)
_ -> (e, acc) -- _ -> (e, acc)

View file

@ -2,18 +2,18 @@
module Main where module Main where
import Codegen.Codegen (compile) -- import Codegen.Codegen (compile)
import GHC.IO.Handle.Text (hPutStrLn) import GHC.IO.Handle.Text (hPutStrLn)
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
import Grammar.Par (myLexer, pProgram) import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree) import Grammar.Print (printTree)
import LambdaLifter.LambdaLifter (lambdaLift) -- import LambdaLifter.LambdaLifter (lambdaLift)
import Renamer.Renamer (rename) import Renamer.Renamer (rename)
import System.Environment (getArgs) import System.Environment (getArgs)
import System.Exit (exitFailure, exitSuccess) import System.Exit (exitFailure, exitSuccess)
import System.IO (stderr) import System.IO (stderr)
import TypeChecker.TypeChecker (typecheck) import TypeChecker.TypeChecker (typecheck)
main :: IO () main :: IO ()
main = main =
@ -37,14 +37,14 @@ main' s = do
typechecked <- fromTypeCheckerErr $ typecheck renamed typechecked <- fromTypeCheckerErr $ typecheck renamed
printToErr $ printTree typechecked printToErr $ printTree typechecked
printToErr "\n-- Lambda Lifter --" -- printToErr "\n-- Lambda Lifter --"
let lifted = lambdaLift typechecked -- let lifted = lambdaLift typechecked
printToErr $ printTree lifted -- printToErr $ printTree lifted
printToErr "\n -- Printing compiler output to stdout --" -- printToErr "\n -- Printing compiler output to stdout --"
compiled <- fromCompilerErr $ compile lifted -- compiled <- fromCompilerErr $ compile lifted
putStrLn compiled -- putStrLn compiled
writeFile "llvm.ll" compiled -- writeFile "llvm.ll" compiled
exitSuccess exitSuccess

View file

@ -3,6 +3,7 @@
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} {-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use traverse_" #-} {-# HLINT ignore "Use traverse_" #-}
{-# OPTIONS_GHC -Wno-overlapping-patterns #-} {-# OPTIONS_GHC -Wno-overlapping-patterns #-}
{-# HLINT ignore "Use zipWithM" #-}
module TypeChecker.TypeChecker where module TypeChecker.TypeChecker where
@ -16,6 +17,7 @@ import qualified Data.Map as M
import Data.Set (Set) import Data.Set (Set)
import qualified Data.Set as S import qualified Data.Set as S
import Data.Foldable (traverse_)
import Grammar.Abs import Grammar.Abs
import Grammar.Print (printTree) import Grammar.Print (printTree)
import qualified TypeChecker.TypeCheckerIr as T import qualified TypeChecker.TypeCheckerIr as T
@ -24,10 +26,12 @@ import qualified TypeChecker.TypeCheckerIr as T
data Poly = Forall [Ident] Type data Poly = Forall [Ident] Type
deriving Show deriving Show
newtype Ctx = Ctx { vars :: Map Ident Poly } newtype Ctx = Ctx { vars :: Map Ident Poly
}
data Env = Env { count :: Int data Env = Env { count :: Int
, sigs :: Map Ident Type , sigs :: Map Ident Type
, dtypes :: Map Ident Type
} }
type Error = String type Error = String
@ -36,7 +40,7 @@ type Subst = Map Ident Type
type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity)) type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))
initCtx = Ctx mempty initCtx = Ctx mempty
initEnv = Env 0 mempty initEnv = Env 0 mempty mempty
runPretty :: Exp -> Either Error String runPretty :: Exp -> Either Error String
runPretty = fmap (printTree . fst). run . inferExp runPretty = fmap (printTree . fst). run . inferExp
@ -50,21 +54,44 @@ runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e
typecheck :: Program -> Either Error T.Program typecheck :: Program -> Either Error T.Program
typecheck = run . checkPrg typecheck = run . checkPrg
checkData :: Data -> Infer ()
checkData d = case d of
(Data typ@(TConstr name _) constrs) -> do
traverse_ (\(Constructor name' t')
-> if typ == retType t'
then insertConstr name' t' else
throwError $
unwords
[ "return type of constructor:"
, printTree name
, "with type:"
, printTree (retType t')
, "does not match data: "
, printTree typ]) constrs
_ -> throwError "Data type incorrectly declared"
where
retType :: Type -> Type
retType (TArr _ t2) = retType t2
retType a = a
checkPrg :: Program -> Infer T.Program checkPrg :: Program -> Infer T.Program
checkPrg (Program bs) = do checkPrg (Program bs) = do
let bs' = getBinds bs preRun bs
traverse (\(Bind n t _ _ _) -> insertSig n t) bs' T.Program <$> checkDef bs
bs' <- mapM checkBind bs'
return $ T.Program bs'
where where
getBinds :: [Def] -> [Bind] preRun :: [Def] -> Infer ()
getBinds = map toBind . filter isBind preRun [] = return ()
isBind :: Def -> Bool preRun (x:xs) = case x of
isBind (DBind _) = True DBind (Bind n t _ _ _ ) -> insertSig n t >> preRun xs
isBind _ = True DData d@(Data _ _) -> checkData d >> preRun xs
toBind :: Def -> Bind
toBind (DBind bind) = bind checkDef :: [Def] -> Infer [T.Def]
toBind _ = error "Can't convert DData to Bind" checkDef [] = return []
checkDef (x:xs) = case x of
(DBind b) -> do
b' <- checkBind b
fmap (T.DBind b' :) (checkDef xs)
(DData d) -> fmap (T.DData d :) (checkDef xs)
checkBind :: Bind -> Infer T.Bind checkBind :: Bind -> Infer T.Bind
checkBind (Bind n t _ args e) = do checkBind (Bind n t _ args e) = do
@ -77,15 +104,18 @@ checkBind (Bind n t _ args e) = do
makeLambda :: Exp -> [Ident] -> Exp makeLambda :: Exp -> [Ident] -> Exp
makeLambda = foldl (flip EAbs) makeLambda = foldl (flip EAbs)
-- | Check if two types are considered equal
-- For the purpose of the algorithm two polymorphic types are always considered equal
typeEq :: Type -> Type -> Bool typeEq :: Type -> Type -> Bool
typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r' typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r'
typeEq (TMono a) (TMono b) = a == b typeEq (TMono a) (TMono b) = a == b
typeEq (TPol _) (TPol _) = True typeEq (TConstr name a) (TConstr name' b) = name == name' && and (zipWith typeEq a b)
typeEq _ _ = False typeEq (TPol _) (TPol _) = True
typeEq _ _ = False
inferExp :: Exp -> Infer (Type, T.Exp) inferExp :: Exp -> Infer (Type, T.Exp)
inferExp e = do inferExp e = do
(s, t, e') <- w e (s, t, e') <- algoW e
let subbed = apply s t let subbed = apply s t
return (subbed, replace subbed e') return (subbed, replace subbed e')
@ -98,19 +128,26 @@ replace t = \case
T.EAdd _ e1 e2 -> T.EAdd t e1 e2 T.EAdd _ e1 e2 -> T.EAdd t e1 e2
T.ELet (T.Bind (n, _) args e1) e2 -> T.ELet (T.Bind (n, t) args e1) e2 T.ELet (T.Bind (n, _) args e1) e2 -> T.ELet (T.Bind (n, t) args e1) e2
w :: Exp -> Infer (Subst, Type, T.Exp) algoW :: Exp -> Infer (Subst, Type, T.Exp)
w = \case algoW = \case
EAnn e t -> do EAnn e t -> do
(s1, t', e') <- w e (s1, t', e') <- algoW e
applySt s1 $ do applySt s1 $ do
s2 <- unify (apply s1 t) t' s2 <- unify (apply s1 t) t'
return (s2 `compose` s1, t, e') return (s2 `compose` s1, t, e')
-- | ------------------
-- | Γ ⊢ e₀ : Int, ∅
ELit (LInt n) -> return (nullSubst, TMono "Int", T.ELit (TMono "Int") (LInt n)) ELit (LInt n) -> return (nullSubst, TMono "Int", T.ELit (TMono "Int") (LInt n))
ELit a -> error $ "NOT IMPLEMENTED YET: ELit " ++ show a ELit a -> error $ "NOT IMPLEMENTED YET: ELit " ++ show a
-- | x : σ ∈ Γ τ = inst(σ)
-- | ----------------------
-- | Γ ⊢ x : τ, ∅
EId i -> do EId i -> do
var <- asks vars var <- asks vars
case M.lookup i var of case M.lookup i var of
@ -118,42 +155,67 @@ w = \case
Nothing -> do Nothing -> do
sig <- gets sigs sig <- gets sigs
case M.lookup i sig of case M.lookup i sig of
Nothing -> throwError $ "Unbound variable: " ++ show i
Just t -> return (nullSubst, t, T.EId (i, t)) Just t -> return (nullSubst, t, T.EId (i, t))
Nothing -> do
constr <- gets dtypes
case M.lookup i constr of
Just t -> return (nullSubst, t, T.EId (i, t))
Nothing -> throwError $ "Unbound variable: " ++ show i
-- | τ = newvar Γ, x : τ ⊢ e : τ', S
-- | ---------------------------------
-- | Γ ⊢ w λx. e : Sτ → τ', S
EAbs name e -> do EAbs name e -> do
fr <- fresh fr <- fresh
withBinding name (Forall [] fr) $ do withBinding name (Forall [] fr) $ do
(s1, t', e') <- w e (s1, t', e') <- algoW e
let varType = apply s1 fr let varType = apply s1 fr
let newArr = TArr varType t' let newArr = TArr varType t'
return (s1, newArr, T.EAbs newArr (name, varType) e') return (s1, newArr, T.EAbs newArr (name, varType) e')
-- | Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
-- | s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
-- | ------------------------------------------
-- | Γ ⊢ e₀ + e₁ : Int, S₃S₂S₁S₀
-- This might be wrong
EAdd e0 e1 -> do EAdd e0 e1 -> do
(s1, t0, e0') <- w e0 (s1, t0, e0') <- algoW e0
applySt s1 $ do applySt s1 $ do
(s2, t1, e1') <- w e1 (s2, t1, e1') <- algoW e1
applySt s2 $ do -- applySt s2 $ do
s3 <- unify (apply s2 t0) (TMono "Int") s3 <- unify (apply s2 t0) (TMono "Int")
s4 <- unify (apply s3 t1) (TMono "Int") s4 <- unify (apply s3 t1) (TMono "Int")
return (s4 `compose` s3 `compose` s2 `compose` s1, TMono "Int", T.EAdd (TMono "Int") e0' e1') return (s4 `compose` s3 `compose` s2 `compose` s1, TMono "Int", T.EAdd (TMono "Int") e0' e1')
-- | Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1
-- | τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ')
-- | --------------------------------------
-- | Γ ⊢ e₀ e₁ : S₂τ', S₂S₁S₀
EApp e0 e1 -> do EApp e0 e1 -> do
fr <- fresh fr <- fresh
(s0, t0, e0') <- w e0 (s0, t0, e0') <- algoW e0
applySt s0 $ do applySt s0 $ do
(s1, t1, e1') <- w e1 (s1, t1, e1') <- algoW e1
-- applySt s1 $ do -- applySt s1 $ do
s2 <- unify (apply s1 t0) (TArr t1 fr) s2 <- unify (apply s1 t0) (TArr t1 fr)
let t = apply s2 fr let t = apply s2 fr
return (s2 `compose` s1 `compose` s0, t, T.EApp t e0' e1') return (s2 `compose` s1 `compose` s0, t, T.EApp t e0' e1')
-- | Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁
-- | ----------------------------------------------
-- | Γ ⊢ let x = e₀ in e₁ : τ', S₁S₀
-- The bar over S₀ and Γ means "generalize"
ELet name e0 e1 -> do ELet name e0 e1 -> do
(s1, t1, e0') <- w e0 (s1, t1, e0') <- algoW e0
env <- asks vars env <- asks vars
let t' = generalize (apply s1 env) t1 let t' = generalize (apply s1 env) t1
withBinding name t' $ do withBinding name t' $ do
(s2, t2, e1') <- w e1 (s2, t2, e1') <- algoW e1
return (s2 `compose` s1, t2, T.ELet (T.Bind (name,t2) [] e0') e1' ) return (s2 `compose` s1, t2, T.ELet (T.Bind (name,t2) [] e0') e1' )
ECase a b -> error $ "NOT IMPLEMENTED YET: ECase" ++ show a ++ " " ++ show b ECase a b -> error $ "NOT IMPLEMENTED YET: ECase" ++ show a ++ " " ++ show b
@ -168,6 +230,12 @@ unify t0 t1 = case (t0, t1) of
(TPol a, b) -> occurs a b (TPol a, b) -> occurs a b
(a, TPol b) -> occurs b a (a, TPol b) -> occurs b a
(TMono a, TMono b) -> if a == b then return M.empty else throwError "Types do not unify" (TMono a, TMono b) -> if a == b then return M.empty else throwError "Types do not unify"
-- | TODO: Figure out a cleaner way to express the same thing
(TConstr name t, TConstr name' t') -> if name == name' && length t == length t'
then do
xs <- sequence $ zipWith unify t t'
return $ foldr compose nullSubst xs
else throwError $ unwords ["Type constructor:", printTree name, "(" ++ printTree t ++ ")", "does not match with:", printTree name', "(" ++ printTree t' ++ ")"]
(a, b) -> throwError . unwords $ ["Type:", printTree a, "can't be unified with:", printTree b] (a, b) -> throwError . unwords $ ["Type:", printTree a, "can't be unified with:", printTree b]
-- | Check if a type is contained in another type. -- | Check if a type is contained in another type.
@ -202,9 +270,11 @@ class FreeVars t where
instance FreeVars Type where instance FreeVars Type where
free :: Type -> Set Ident free :: Type -> Set Ident
free (TPol a) = S.singleton a free (TPol a) = S.singleton a
free (TMono _) = mempty free (TMono _) = mempty
free (TArr a b) = free a `S.union` free b free (TArr a b) = free a `S.union` free b
-- | Not guaranteed to be correct
free (TConstr _ a) = foldl' (\acc x -> free x `S.union` acc) S.empty a
apply :: Subst -> Type -> Type apply :: Subst -> Type -> Type
apply sub t = do apply sub t = do
case t of case t of
@ -213,6 +283,7 @@ instance FreeVars Type where
Nothing -> TPol a Nothing -> TPol a
Just t -> t Just t -> t
TArr a b -> TArr (apply sub a) (apply sub b) TArr a b -> TArr (apply sub a) (apply sub b)
TConstr name a -> TConstr name (map (apply sub) a)
instance FreeVars Poly where instance FreeVars Poly where
free :: Poly -> Set Ident free :: Poly -> Set Ident
@ -248,3 +319,7 @@ withBinding i p = local (\st -> st { vars = M.insert i p (vars st) })
-- | Insert a function signature into the environment -- | Insert a function signature into the environment
insertSig :: Ident -> Type -> Infer () insertSig :: Ident -> Type -> Infer ()
insertSig i t = modify (\st -> st { sigs = M.insert i t (sigs st) }) insertSig i t = modify (\st -> st { sigs = M.insert i t (sigs st) })
-- | Insert a constructor with its data type
insertConstr :: Ident -> Type -> Infer ()
insertConstr i t = modify (\st -> st { dtypes = M.insert i t (dtypes st) })

View file

@ -5,12 +5,12 @@ module TypeChecker.TypeCheckerIr
, module TypeChecker.TypeCheckerIr , module TypeChecker.TypeCheckerIr
) where ) where
import Grammar.Abs (Ident (..), Literal (..), Type (..)) import Grammar.Abs (Data (..), Ident (..), Literal (..), Type (..))
import Grammar.Print import Grammar.Print
import Prelude import Prelude
import qualified Prelude as C (Eq, Ord, Read, Show) import qualified Prelude as C (Eq, Ord, Read, Show)
newtype Program = Program [Bind] newtype Program = Program [Def]
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
data Exp data Exp
@ -22,11 +22,18 @@ data Exp
| EAbs Type Id Exp | EAbs Type Id Exp
deriving (C.Eq, C.Ord, C.Read, C.Show) deriving (C.Eq, C.Ord, C.Read, C.Show)
data Def = DBind Bind | DData Data
deriving (C.Eq, C.Ord, C.Read, C.Show)
type Id = (Ident, Type) type Id = (Ident, Type)
data Bind = Bind Id [Id] Exp data Bind = Bind Id [Id] Exp
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
instance Print Def where
prt i (DBind bind) = prt i bind
prt i (DData d) = prt i d
instance Print Program where instance Print Program where
prt i (Program sc) = prPrec i 0 $ prt 0 sc prt i (Program sc) = prPrec i 0 $ prt 0 sc
@ -75,7 +82,7 @@ instance Print Exp where
, doc $ showString "in" , doc $ showString "in"
, prt 0 e , prt 0 e
] ]
EApp t e1 e2 -> prPrec i 2 $ concatD EApp _ e1 e2 -> prPrec i 2 $ concatD
[ prt 2 e1 [ prt 2 e1
, prt 3 e2 , prt 3 e2
] ]

View file

@ -1,2 +1,14 @@
main : _Int ; data List ('a) where;
main = 3 + 3 ; Nil : List ('a),
Cons : 'a -> List ('a) -> List ('a) ;
main : List (_Int) ;
main = Cons 1 (Cons 0 Nil) ;
data Bool () where;
True : Bool (),
False : Bool ();
boolean : Bool (_Int);
boolean = True ;