renamed stuff

This commit is contained in:
sebastianselander 2023-03-24 12:21:54 +01:00
parent 3f618e77f9
commit ce3971cf75
9 changed files with 414 additions and 409 deletions

View file

@ -24,7 +24,7 @@ Bind. Bind ::= LIdent [LIdent] "=" Exp ;
TLit. Type2 ::= UIdent ; TLit. Type2 ::= UIdent ;
TVar. Type2 ::= TVar ; TVar. Type2 ::= TVar ;
TAll. Type1 ::= "forall" TVar "." Type ; TAll. Type1 ::= "forall" TVar "." Type ;
TIndexed. Type1 ::= Indexed ; TData. Type1 ::= UIdent "(" [Type] ")" ;
internal TEVar. Type1 ::= TEVar ; internal TEVar. Type1 ::= TEVar ;
TFun. Type ::= Type1 "->" Type ; TFun. Type ::= Type1 "->" Type ;
@ -37,9 +37,7 @@ internal MkTEVar. TEVar ::= LIdent ;
Constructor. Constructor ::= UIdent ":" Type ; Constructor. Constructor ::= UIdent ":" Type ;
Indexed. Indexed ::= UIdent "(" [Type] ")" ; Data. Data ::= "data" Type "where" "{" [Constructor] "}" ;
Data. Data ::= "data" Indexed "where" "{" [Constructor] "}" ;
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
-- * EXPRESSIONS -- * EXPRESSIONS

View file

@ -1,22 +1,9 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
module Codegen.Codegen (generateCode) where module Codegen.Codegen where
import Auxiliary (snoc) -- module Codegen.Codegen (generateCode) where
import Codegen.LlvmIr as LIR
import Control.Applicative ((<|>))
import Control.Monad.State (StateT, execStateT, foldM_,
gets, modify)
import qualified Data.Bifunctor as BI
import Data.Coerce (coerce)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromJust, fromMaybe)
import Data.Tuple.Extra (dupe, first, second)
import qualified Grammar.Abs as GA
import Grammar.ErrM (Err)
import Monomorphizer.MonomorphizerIr as MIR
-- | The record used as the code generator state -- | The record used as the code generator state
data CodeGenerator = CodeGenerator data CodeGenerator = CodeGenerator
@ -27,42 +14,45 @@ data CodeGenerator = CodeGenerator
, labelCount :: Integer , labelCount :: Integer
} }
-- | A state type synonym ---- | The record used as the code generator state
type CompilerState a = StateT CodeGenerator Err a -- data CodeGenerator = CodeGenerator
-- { instructions :: [LLVMIr]
-- , functions :: Map MIR.Id FunctionInfo
-- , constructors :: Map Ident ConstructorInfo
-- , variableCount :: Integer
-- , labelCount :: Integer
-- }
data FunctionInfo = FunctionInfo ---- | A state type synonym
{ numArgs :: Int -- type CompilerState a = StateT CodeGenerator Err a
, arguments :: [Id]
}
deriving (Show)
data ConstructorInfo = ConstructorInfo
{ numArgsCI :: Int
, argumentsCI :: [Id]
, numCI :: Integer
}
deriving (Show)
-- | Adds a instruction to the CodeGenerator state -- data FunctionInfo = FunctionInfo
emit :: LLVMIr -> CompilerState () -- { numArgs :: Int
emit l = modify $ \t -> t{instructions = Auxiliary.snoc l $ instructions t} -- , arguments :: [Id]
-- }
-- deriving (Show)
-- data ConstructorInfo = ConstructorInfo
-- { numArgsCI :: Int
-- , argumentsCI :: [Id]
-- , numCI :: Integer
-- }
-- deriving (Show)
-- | Increases the variable counter in the CodeGenerator state ---- | Adds a instruction to the CodeGenerator state
increaseVarCount :: CompilerState () -- emit :: LLVMIr -> CompilerState ()
increaseVarCount = modify $ \t -> t{variableCount = variableCount t + 1} -- emit l = modify $ \t -> t{instructions = Auxiliary.snoc l $ instructions t}
-- | Returns the variable count from the CodeGenerator state ---- | Increases the variable counter in the CodeGenerator state
getVarCount :: CompilerState Integer -- increaseVarCount :: CompilerState ()
getVarCount = gets variableCount -- increaseVarCount = modify $ \t -> t{variableCount = variableCount t + 1}
-- | Increases the variable count and returns it from the CodeGenerator state ---- | Returns the variable count from the CodeGenerator state
getNewVar :: CompilerState GA.Ident -- getVarCount :: CompilerState Integer
getNewVar = (GA.Ident . show) <$> (increaseVarCount >> getVarCount) -- getVarCount = gets variableCount
-- | Increses the label count and returns a label from the CodeGenerator state ---- | Increases the variable count and returns it from the CodeGenerator state
getNewLabel :: CompilerState Integer -- getNewVar :: CompilerState GA.Ident
getNewLabel = do -- getNewVar = (GA.Ident . show) <$> (increaseVarCount >> getVarCount)
modify (\t -> t{labelCount = labelCount t + 1})
gets labelCount
{- | Produces a map of functions infos from a list of binds, {- | Produces a map of functions infos from a list of binds,
which contains useful data for code generation. which contains useful data for code generation.
@ -87,8 +77,28 @@ getFunctions bs = Map.fromList $ go bs
cons cons
<> go xs <> go xs
createArgs :: [MIR.Type] -> [Id] -- {- | Produces a map of functions infos from a list of binds,
createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(GA.Ident ("arg_" <> show l), t)], l + 1)) ([], 0) xs -- which contains useful data for code generation.
---}
-- getFunctions :: [MIR.Def] -> Map Id FunctionInfo
-- getFunctions bs = Map.fromList $ go bs
-- where
-- go [] = []
-- go (MIR.DBind (MIR.Bind id args _) : xs) =
-- (id, FunctionInfo{numArgs = length args, arguments = args})
-- : go xs
-- go (MIR.DData (MIR.Constructor n cons) : xs) = undefined
-- {-do map
-- ( \(Constructor id xs) ->
-- ( (id, MIR.TLit n)
-- , FunctionInfo
-- { numArgs = length xs
-- , arguments = createArgs xs
-- }
-- )
-- )
-- cons
-- <> go xs-}
{- | Produces a map of functions infos from a list of binds, {- | Produces a map of functions infos from a list of binds,
which contains useful data for code generation. which contains useful data for code generation.
@ -119,66 +129,53 @@ getConstructors bs = Map.fromList $ go bs
<> go xs <> go xs
go (_ : xs) = go xs go (_ : xs) = go xs
initCodeGenerator :: [MIR.Def] -> CodeGenerator -- {- | Produces a map of functions infos from a list of binds,
initCodeGenerator scs = -- which contains useful data for code generation.
CodeGenerator ---}
{ instructions = defaultStart -- getConstructors :: [MIR.Def] -> Map Ident ConstructorInfo
, functions = getFunctions scs -- getConstructors bs = Map.fromList $ go bs
, constructors = getConstructors scs -- where
, variableCount = 0 -- go [] = []
, labelCount = 0 -- go (MIR.DData (MIR.Constructor n cons) : xs) = undefined
} -- {-do
-- fst
-- ( foldl
-- ( \(acc, i) (GA.Constructor (GA.Ident id) xs) ->
-- ( ( (GA.Ident (n <> "_" <> id), MIR.TLit (GA.Ident n))
-- , ConstructorInfo
-- { numArgsCI = length xs
-- , argumentsCI = createArgs xs
-- , numCI = i
-- }
-- )
-- : acc
-- , i + 1
-- )
-- )
-- ([], 0)
-- cons
-- )
-- <> go xs-}
-- go (_ : xs) = go xs
{- -- initCodeGenerator :: [MIR.Def] -> CodeGenerator
run :: Err String -> IO () -- initCodeGenerator scs =
run s = do -- CodeGenerator
let s' = case s of -- { instructions = defaultStart
Right s -> s -- , functions = getFunctions scs
Left _ -> error "yo" -- , constructors = getConstructors scs
writeFile "output/llvm.ll" s' -- , variableCount = 0
putStrLn . trim =<< readCreateProcess (shell "lli") s' -- , labelCount = 0
-- }
test :: Integer -> Program -- {-
test v = -- run :: Err String -> IO ()
Program -- run s = do
[ DataType -- let s' = case s of
(GA.Ident "Craig") -- Right s -> s
[ Constructor (GA.Ident "Bob") [MIR.Type (GA.Ident "_Int")] -- Left _ -> error "yo"
, Constructor (GA.Ident "Betty") [MIR.Type (GA.Ident "_Int")] -- writeFile "output/llvm.ll" s'
] -- putStrLn . trim =<< readCreateProcess (shell "lli") s'
, DataType
(GA.Ident "Alice")
[ Constructor (GA.Ident "Eve") [MIR.Type (GA.Ident "_Int")] -- ,
-- (GA.Ident "Alice", [TInt, TInt])
]
, Bind (GA.Ident "fibonacci", MIR.Type (GA.Ident "_Int")) [(GA.Ident "x", MIR.Type (GA.Ident "_Int"))] (EId ("x", MIR.Type (GA.Ident "Craig")), MIR.Type (GA.Ident "Craig"))
, Bind (GA.Ident "main", MIR.Type (GA.Ident "_Int")) []
-- (EApp (MIR.Type (GA.Ident "Craig")) (EId (GA.Ident "Craig_Bob", MIR.Type (GA.Ident "Craig")), MIR.Type (GA.Ident "Craig")) (ELit (LInt v), MIR.Type (GA.Ident "_Int")), MIR.Type (GA.Ident "Craig"))-- (EInt 92)
$
eCaseInt
(EApp (MIR.TLit (GA.Ident "Craig")) (EId (GA.Ident "Craig_Bob", MIR.TLit (GA.Ident "Craig")), MIR.TLit (GA.Ident "Craig")) (ELit (LInt v), MIR.Type (GA.Ident "_Int")), MIR.Type (GA.Ident "Craig"))
[ injectionCons "Craig_Bob" "Craig" [CIdent (GA.Ident "x")] (EId (GA.Ident "x", MIR.Type (GA.Ident "_Int")), MIR.Type (GA.Ident "_Int"))
, injectionCons "Craig_Betty" "Craig" [CLit (LInt 5)] (int 2)
, Injection (CIdent (GA.Ident "z")) (int 3)
, -- , injectionInt 5 (int 6)
injectionCatchAll (int 10)
]
]
where
injectionCons x y xs = Injection (CCons (GA.Ident x, MIR.Type (GA.Ident y)) xs)
injectionInt x = Injection (CLit (LInt x))
injectionCatchAll = Injection CatchAll
eCaseInt x xs = (ECase (MIR.TLit (MIR.Ident "_Int")) x xs, MIR.TLit (MIR.Ident "_Int"))
int x = (ELit (LInt x), MIR.TLit (MIR.Ident "_Int"))
-}
{- | Compiles an AST and produces a LLVM Ir string.
An easy way to actually "compile" this output is to
Simply pipe it to LLI
-}
generateCode :: MIR.Program -> Err String
generateCode (MIR.Program scs) = do
let codegen = initCodeGenerator scs
llvmIrToString . instructions <$> execStateT (compileScs scs) codegen
compileScs :: [MIR.Def] -> CompilerState () compileScs :: [MIR.Def] -> CompilerState ()
compileScs [] = do compileScs [] = do
@ -270,50 +267,50 @@ compileScs (MIR.DData (MIR.Constructor (GA.UIdent outer_id) ts) : xs) = do
types types
compileScs xs compileScs xs
mainContent :: LLVMValue -> [LLVMIr] -- mainContent :: LLVMValue -> [LLVMIr]
mainContent var = -- mainContent var =
[ UnsafeRaw $ -- [ UnsafeRaw $
-- "%2 = alloca %Craig\n" <> -- -- "%2 = alloca %Craig\n" <>
-- " store %Craig %1, ptr %2\n" <> -- -- " store %Craig %1, ptr %2\n" <>
-- " %3 = bitcast %Craig* %2 to i72*\n" <> -- -- " %3 = bitcast %Craig* %2 to i72*\n" <>
-- " %4 = load i72, ptr %3\n" <> -- -- " %4 = load i72, ptr %3\n" <>
-- " call i32 (ptr, ...) @printf(ptr noundef @.str, i72 noundef %4)\n" -- -- " call i32 (ptr, ...) @printf(ptr noundef @.str, i72 noundef %4)\n"
"call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> toIr var <> ")\n" -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> toIr var <> ")\n"
, -- , SetVariable (GA.Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2)) -- , -- , SetVariable (GA.Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
-- , BrCond (VIdent (GA.Ident "p")) (GA.Ident "b_1") (GA.Ident "b_2") -- -- , BrCond (VIdent (GA.Ident "p")) (GA.Ident "b_1") (GA.Ident "b_2")
-- , Label (GA.Ident "b_1") -- -- , Label (GA.Ident "b_1")
-- , UnsafeRaw -- -- , UnsafeRaw
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n" -- -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n"
-- , Br (GA.Ident "end") -- -- , Br (GA.Ident "end")
-- , Label (GA.Ident "b_2") -- -- , Label (GA.Ident "b_2")
-- , UnsafeRaw -- -- , UnsafeRaw
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n" -- -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n"
-- , Br (GA.Ident "end") -- -- , Br (GA.Ident "end")
-- , Label (GA.Ident "end") -- -- , Label (GA.Ident "end")
Ret I64 (VInteger 0) -- Ret I64 (VInteger 0)
] -- ]
defaultStart :: [LLVMIr] -- defaultStart :: [LLVMIr]
defaultStart = -- defaultStart =
[ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n" -- [ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n"
, UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n" -- , UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
, UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%x\n\", align 1\n" -- , UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%x\n\", align 1\n"
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" -- , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
] -- ]
compileExp :: ExpT -> CompilerState () -- compileExp :: ExpT -> CompilerState ()
compileExp (MIR.ELit lit,t) = emitLit lit -- compileExp (MIR.ELit lit,t) = emitLit lit
compileExp (MIR.EAdd e1 e2,t) = emitAdd t e1 e2 -- compileExp (MIR.EAdd e1 e2,t) = emitAdd t e1 e2
-- compileExp (ESub t e1 e2) = emitSub t e1 e2 ---- compileExp (ESub t e1 e2) = emitSub t e1 e2
compileExp (MIR.EId name,t) = emitIdent name -- compileExp (MIR.EId name,t) = emitIdent name
compileExp (MIR.EApp e1 e2,t) = emitApp t e1 e2 -- compileExp (MIR.EApp e1 e2,t) = emitApp t e1 e2
-- compileExp (EAbs t ti e) = emitAbs t ti e ---- compileExp (EAbs t ti e) = emitAbs t ti e
compileExp (MIR.ELet binds e,t) = undefined -- emitLet binds (fst e) -- compileExp (MIR.ELet binds e,t) = undefined -- emitLet binds (fst e)
compileExp (MIR.ECase e cs,t) = emitECased t e (map (t,) cs) -- compileExp (MIR.ECase e cs,t) = emitECased t e (map (t,) cs)
-- go (EMul e1 e2) = emitMul e1 e2 ---- go (EMul e1 e2) = emitMul e1 e2
-- go (EDiv e1 e2) = emitDiv e1 e2 ---- go (EDiv e1 e2) = emitDiv e1 e2
-- go (EMod e1 e2) = emitMod e1 e2 ---- go (EMod e1 e2) = emitMod e1 e2
--- aux functions --- --- aux functions ---
emitECased :: MIR.Type -> ExpT -> [(MIR.Type, Injection)] -> CompilerState () emitECased :: MIR.Type -> ExpT -> [(MIR.Type, Injection)] -> CompilerState ()
@ -336,89 +333,89 @@ emitECased t e cases = do
cons <- gets constructors cons <- gets constructors
let r = fromJust $ Map.lookup (coerce consId, t) cons let r = fromJust $ Map.lookup (coerce consId, t) cons
lbl_failPos <- (\x -> GA.Ident $ "failed_" <> show x) <$> getNewLabel -- lbl_failPos <- (\x -> GA.Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> GA.Ident $ "success_" <> show x) <$> getNewLabel -- lbl_succPos <- (\x -> GA.Ident $ "success_" <> show x) <$> getNewLabel
consVal <- getNewVar -- consVal <- getNewVar
emit $ SetVariable consVal (ExtractValue rt vs 0) -- emit $ SetVariable consVal (ExtractValue rt vs 0)
consCheck <- getNewVar -- consCheck <- getNewVar
emit $ SetVariable consCheck (Icmp LLEq I8 (VIdent consVal I8) (VInteger $ numCI r)) -- emit $ SetVariable consCheck (Icmp LLEq I8 (VIdent consVal I8) (VInteger $ numCI r))
emit $ BrCond (VIdent consCheck ty) lbl_succPos lbl_failPos -- emit $ BrCond (VIdent consCheck ty) lbl_succPos lbl_failPos
emit $ Label lbl_succPos -- emit $ Label lbl_succPos
castPtr <- getNewVar -- castPtr <- getNewVar
castedPtr <- getNewVar -- castedPtr <- getNewVar
casted <- getNewVar -- casted <- getNewVar
emit $ SetVariable castPtr (Alloca rt) -- emit $ SetVariable castPtr (Alloca rt)
emit $ Store rt vs Ptr castPtr -- emit $ Store rt vs Ptr castPtr
emit $ SetVariable castedPtr (Bitcast Ptr (VIdent castPtr Ptr) Ptr) -- emit $ SetVariable castedPtr (Bitcast Ptr (VIdent castPtr Ptr) Ptr)
emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castedPtr) -- emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castedPtr)
val <- exprToValue exp -- val <- exprToValue exp
-- enumerateOneM_ -- -- enumerateOneM_
-- (\i c -> do -- -- (\i c -> do
-- case c of -- -- case c of
-- CIdent x -> do -- -- CIdent x -> do
-- emit . Comment $ "ident " <> show x -- -- emit . Comment $ "ident " <> show x
-- emit $ SetVariable x (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i) -- -- emit $ SetVariable x (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i)
-- emit $ Store ty val Ptr stackPtr -- -- emit $ Store ty val Ptr stackPtr
-- CCons x cs -> error "nested constructor" -- -- CCons x cs -> error "nested constructor"
-- CLit l -> do -- -- CLit l -> do
-- testVar <- getNewVar -- -- testVar <- getNewVar
-- emit $ SetVariable testVar (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i) -- -- emit $ SetVariable testVar (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i)
-- case l of -- -- case l of
-- LInt l -> emit $ Icmp LLEq I64 (VIdent testVar Ptr) (VInteger l) -- -- LInt l -> emit $ Icmp LLEq I64 (VIdent testVar Ptr) (VInteger l)
-- LChar c -> emit $ Icmp LLEq I8 (VIdent testVar Ptr) (VChar c) -- -- LChar c -> emit $ Icmp LLEq I8 (VIdent testVar Ptr) (VChar c)
-- CCatch -> emit . Comment $ "Catch all" -- -- CCatch -> emit . Comment $ "Catch all"
-- emit . Comment $ "return this " <> toIr val -- -- emit . Comment $ "return this " <> toIr val
-- emit . Comment . show $ c -- -- emit . Comment . show $ c
-- emit . Comment . show $ i -- -- emit . Comment . show $ i
-- ) -- -- )
-- cs -- -- cs
-- emit $ Store ty val Ptr stackPtr -- -- emit $ Store ty val Ptr stackPtr
emit $ Br label -- emit $ Br label
emit $ Label lbl_failPos -- emit $ Label lbl_failPos
emitCases rt ty label stackPtr vs (Injection (MIR.InitLit i, _) exp) = do -- emitCases rt ty label stackPtr vs (Injection (MIR.InitLit i, _) exp) = do
let i' = case i of -- let i' = case i of
GA.LInt i -> VInteger i -- GA.LInt i -> VInteger i
GA.LChar i -> VChar i -- GA.LChar i -> VChar i
ns <- getNewVar -- ns <- getNewVar
lbl_failPos <- (\x -> GA.Ident $ "failed_" <> show x) <$> getNewLabel -- lbl_failPos <- (\x -> GA.Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> GA.Ident $ "success_" <> show x) <$> getNewLabel -- lbl_succPos <- (\x -> GA.Ident $ "success_" <> show x) <$> getNewLabel
emit $ SetVariable ns (Icmp LLEq ty vs i') -- emit $ SetVariable ns (Icmp LLEq ty vs i')
emit $ BrCond (VIdent ns ty) lbl_succPos lbl_failPos -- emit $ BrCond (VIdent ns ty) lbl_succPos lbl_failPos
emit $ Label lbl_succPos -- emit $ Label lbl_succPos
val <- exprToValue exp -- val <- exprToValue exp
emit $ Store ty val Ptr stackPtr -- emit $ Store ty val Ptr stackPtr
emit $ Br label -- emit $ Br label
emit $ Label lbl_failPos -- emit $ Label lbl_failPos
-- emitCases rt ty label stackPtr vs (Injection (MIR.CIdent id) exp) = do ---- emitCases rt ty label stackPtr vs (Injection (MIR.CIdent id) exp) = do
-- -- //TODO this is pretty disgusting and would heavily benefit from a rewrite ---- -- //TODO this is pretty disgusting and would heavily benefit from a rewrite
-- valPtr <- getNewVar ---- valPtr <- getNewVar
-- emit $ SetVariable valPtr (Alloca rt) ---- emit $ SetVariable valPtr (Alloca rt)
-- emit $ Store rt vs Ptr valPtr ---- emit $ Store rt vs Ptr valPtr
-- emit $ SetVariable id (Load rt Ptr valPtr) ---- emit $ SetVariable id (Load rt Ptr valPtr)
-- increaseVarCount ---- increaseVarCount
-- val <- exprToValue (fst exp) ---- val <- exprToValue (fst exp)
---- emit $ Store ty val Ptr stackPtr
---- emit $ Br label
-- emitCases _ ty label stackPtr _ (Injection (MIR.InitCatch, _) exp) = do
-- val <- exprToValue exp
-- emit $ Store ty val Ptr stackPtr -- emit $ Store ty val Ptr stackPtr
-- emit $ Br label -- emit $ Br label
emitCases _ ty label stackPtr _ (Injection (MIR.InitCatch, _) exp) = do
val <- exprToValue exp
emit $ Store ty val Ptr stackPtr
emit $ Br label
--emitLet :: Bind -> Exp -> CompilerState () ----emitLet :: Bind -> Exp -> CompilerState ()
emitLet xs e = do -- emitLet xs e = do
emit $ -- emit $
Comment $ -- Comment $
concat -- concat
[ "ELet (" -- [ "ELet ("
, show xs -- , show xs
, " = " -- , " = "
, show e -- , show e
, ") is not implemented!" -- , ") is not implemented!"
] -- ]
emitApp :: MIR.Type -> ExpT -> ExpT -> CompilerState () emitApp :: MIR.Type -> ExpT -> ExpT -> CompilerState ()
emitApp t e1 e2 = appEmitter e1 e2 [] emitApp t e1 e2 = appEmitter e1 e2 []
@ -443,60 +440,60 @@ emitApp t e1 e2 = appEmitter e1 e2 []
emit $ SetVariable vs call emit $ SetVariable vs call
x -> error $ "The unspeakable happened: " <> show x x -> error $ "The unspeakable happened: " <> show x
emitIdent :: GA.Ident -> CompilerState () -- emitIdent :: GA.Ident -> CompilerState ()
emitIdent id = do -- emitIdent id = do
-- !!this should never happen!! -- -- !!this should never happen!!
emit $ Comment "This should not have happened!" -- emit $ Comment "This should not have happened!"
emit $ Variable id -- emit $ Variable id
emit $ UnsafeRaw "\n" -- emit $ UnsafeRaw "\n"
emitLit :: MIR.Lit -> CompilerState () -- emitLit :: MIR.Lit -> CompilerState ()
emitLit i = do -- emitLit i = do
-- !!this should never happen!! -- -- !!this should never happen!!
let (i', t) = case i of -- let (i', t) = case i of
(MIR.LInt i'') -> (VInteger i'', I64) -- (MIR.LInt i'') -> (VInteger i'', I64)
(MIR.LChar i'') -> (VChar i'', I8) -- (MIR.LChar i'') -> (VChar i'', I8)
varCount <- getNewVar -- varCount <- getNewVar
emit $ Comment "This should not have happened!" -- emit $ Comment "This should not have happened!"
emit $ SetVariable (GA.Ident (show varCount)) (Add t i' (VInteger 0)) -- emit $ SetVariable (GA.Ident (show varCount)) (Add t i' (VInteger 0))
emitAdd :: MIR.Type -> ExpT -> ExpT -> CompilerState () -- emitAdd :: MIR.Type -> ExpT -> ExpT -> CompilerState ()
emitAdd t e1 e2 = do -- emitAdd t e1 e2 = do
v1 <- exprToValue e1 -- v1 <- exprToValue e1
v2 <- exprToValue e2 -- v2 <- exprToValue e2
v <- getNewVar -- v <- getNewVar
emit $ SetVariable (GA.Ident $ show v) (Add (type2LlvmType t) v1 v2) -- emit $ SetVariable (GA.Ident $ show v) (Add (type2LlvmType t) v1 v2)
emitSub :: MIR.Type -> ExpT -> ExpT -> CompilerState () -- emitSub :: MIR.Type -> ExpT -> ExpT -> CompilerState ()
emitSub t e1 e2 = do -- emitSub t e1 e2 = do
v1 <- exprToValue e1 -- v1 <- exprToValue e1
v2 <- exprToValue e2 -- v2 <- exprToValue e2
v <- getNewVar -- v <- getNewVar
emit $ SetVariable v (Sub (type2LlvmType t) v1 v2) -- emit $ SetVariable v (Sub (type2LlvmType t) v1 v2)
exprToValue :: ExpT -> CompilerState LLVMValue -- exprToValue :: ExpT -> CompilerState LLVMValue
exprToValue = \case -- exprToValue = \case
(MIR.ELit i, t) -> pure $ case i of -- (MIR.ELit i, t) -> pure $ case i of
(MIR.LInt i) -> VInteger i -- (MIR.LInt i) -> VInteger i
(MIR.LChar i) -> VChar i -- (MIR.LChar i) -> VChar i
(MIR.EId name, t) -> do -- (MIR.EId name, t) -> do
funcs <- gets functions -- funcs <- gets functions
case Map.lookup (name, t) funcs of -- case Map.lookup (name, t) funcs of
Just fi -> do -- Just fi -> do
if numArgs fi == 0 -- if numArgs fi == 0
then do -- then do
vc <- getNewVar -- vc <- getNewVar
emit $ -- emit $
SetVariable -- SetVariable
vc -- vc
(Call FastCC (type2LlvmType t) Global name []) -- (Call FastCC (type2LlvmType t) Global name [])
pure $ VIdent vc (type2LlvmType t) -- pure $ VIdent vc (type2LlvmType t)
else pure $ VFunction name Global (type2LlvmType t) -- else pure $ VFunction name Global (type2LlvmType t)
Nothing -> pure $ VIdent name (type2LlvmType t) -- Nothing -> pure $ VIdent name (type2LlvmType t)
e -> do -- e -> do
compileExp e -- compileExp e
v <- getVarCount -- v <- getVarCount
pure $ VIdent (GA.Ident $ show v) (getType e) -- pure $ VIdent (GA.Ident $ show v) (getType e)
type2LlvmType :: MIR.Type -> LLVMType type2LlvmType :: MIR.Type -> LLVMType
type2LlvmType (MIR.TLit id@(Ident name)) = case name of type2LlvmType (MIR.TLit id@(Ident name)) = case name of
@ -510,26 +507,26 @@ type2LlvmType (MIR.TFun t xs) = do
function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s) function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s)
function2LLVMType x s = (type2LlvmType x, s) function2LLVMType x s = (type2LlvmType x, s)
getType :: ExpT -> LLVMType -- getType :: ExpT -> LLVMType
getType (_, t) = type2LlvmType t -- getType (_, t) = type2LlvmType t
valueGetType :: LLVMValue -> LLVMType -- valueGetType :: LLVMValue -> LLVMType
valueGetType (VInteger _) = I64 -- valueGetType (VInteger _) = I64
valueGetType (VChar _) = I8 -- valueGetType (VChar _) = I8
valueGetType (VIdent _ t) = t -- valueGetType (VIdent _ t) = t
valueGetType (VConstant s) = Array (fromIntegral $ length s) I8 -- valueGetType (VConstant s) = Array (fromIntegral $ length s) I8
valueGetType (VFunction _ _ t) = t -- valueGetType (VFunction _ _ t) = t
typeByteSize :: LLVMType -> Integer -- typeByteSize :: LLVMType -> Integer
typeByteSize I1 = 1 -- typeByteSize I1 = 1
typeByteSize I8 = 1 -- typeByteSize I8 = 1
typeByteSize I32 = 4 -- typeByteSize I32 = 4
typeByteSize I64 = 8 -- typeByteSize I64 = 8
typeByteSize Ptr = 8 -- typeByteSize Ptr = 8
typeByteSize (Ref _) = 8 -- typeByteSize (Ref _) = 8
typeByteSize (Function _ _) = 8 -- typeByteSize (Function _ _) = 8
typeByteSize (Array n t) = n * typeByteSize t -- typeByteSize (Array n t) = n * typeByteSize t
typeByteSize (CustomType _) = 8 -- typeByteSize (CustomType _) = 8
enumerateOneM_ :: Monad m => (Integer -> a -> m b) -> [a] -> m () -- enumerateOneM_ :: Monad m => (Integer -> a -> m b) -> [a] -> m ()
enumerateOneM_ f = foldM_ (\i a -> f i a >> pure (i + 1)) 1 -- enumerateOneM_ f = foldM_ (\i a -> f i a >> pure (i + 1)) 1

View file

@ -2,7 +2,7 @@
module Main where module Main where
import Codegen.Codegen (generateCode) -- import Codegen.Codegen (generateCode)
import GHC.IO.Handle.Text (hPutStrLn) import GHC.IO.Handle.Text (hPutStrLn)
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
import Grammar.Par (myLexer, pProgram) import Grammar.Par (myLexer, pProgram)
@ -13,10 +13,13 @@ import Control.Monad (when)
import Data.List.Extra (isSuffixOf) import Data.List.Extra (isSuffixOf)
import Renamer.Renamer (rename) import Renamer.Renamer (rename)
import System.Directory (createDirectory, doesPathExist, import System.Directory (
createDirectory,
doesPathExist,
getDirectoryContents, getDirectoryContents,
removeDirectoryRecursive, removeDirectoryRecursive,
setCurrentDirectory) setCurrentDirectory,
)
import System.Environment (getArgs) import System.Environment (getArgs)
import System.Exit (exitFailure, exitSuccess) import System.Exit (exitFailure, exitSuccess)
import System.IO (stderr) import System.IO (stderr)
@ -50,9 +53,9 @@ main' debug s = do
-- let lifted = lambdaLift typechecked -- let lifted = lambdaLift typechecked
-- printToErr $ printTree lifted -- printToErr $ printTree lifted
-- --
printToErr "\n -- Printing compiler output to stdout --" -- printToErr "\n -- Printing compiler output to stdout --"
compiled <- fromCompilerErr $ generateCode (monomorphize typechecked) -- compiled <- fromCompilerErr $ generateCode (monomorphize typechecked)
putStrLn compiled -- putStrLn compiled
-- check <- doesPathExist "output" -- check <- doesPathExist "output"
-- when check (removeDirectoryRecursive "output") -- when check (removeDirectoryRecursive "output")

View file

@ -3,11 +3,12 @@
module Monomorphizer.Monomorphizer (monomorphize) where module Monomorphizer.Monomorphizer (monomorphize) where
import Data.Coerce (coerce) import Data.Coerce (coerce)
import Grammar.Abs (Constructor (..), Ident (..), import Grammar.Abs (Constructor (..), Ident (..))
Indexed (..)) import Unsafe.Coerce (unsafeCoerce)
import qualified Grammar.Abs as GA
import qualified Monomorphizer.MonomorphizerIr as M import Grammar.Abs qualified as GA
import qualified TypeChecker.TypeCheckerIr as T import Monomorphizer.MonomorphizerIr qualified as M
import TypeChecker.TypeCheckerIr qualified as T
monomorphize :: T.Program -> M.Program monomorphize :: T.Program -> M.Program
monomorphize (T.Program ds) = M.Program $ monoDefs ds monomorphize (T.Program ds) = M.Program $ monoDefs ds
@ -17,14 +18,11 @@ monoDefs = map monoDef
monoDef :: T.Def -> M.Def monoDef :: T.Def -> M.Def
monoDef (T.DBind bind) = M.DBind $ monoBind bind monoDef (T.DBind bind) = M.DBind $ monoBind bind
monoDef (T.DData d) = M.DData $ monoData d monoDef (T.DData d) = M.DData $ unsafeCoerce d
monoBind :: T.Bind -> M.Bind monoBind :: T.Bind -> M.Bind
monoBind (T.Bind name args (e, t)) = M.Bind (monoId name) (map monoId args) (monoExpr e, monoType t) monoBind (T.Bind name args (e, t)) = M.Bind (monoId name) (map monoId args) (monoExpr e, monoType t)
monoData :: T.Data -> M.Constructor
monoData (T.Data (Indexed n _) cons) = M.Constructor n (map (\(Constructor n t) -> (n, monoAbsType t)) cons)
monoExpr :: T.Exp -> M.Exp monoExpr :: T.Exp -> M.Exp
monoExpr = \case monoExpr = \case
T.EId (Ident i) -> M.EId (Ident i) T.EId (Ident i) -> M.EId (Ident i)
@ -39,23 +37,22 @@ monoAbsType :: GA.Type -> M.Type
monoAbsType (GA.TLit u) = M.TLit (coerce u) monoAbsType (GA.TLit u) = M.TLit (coerce u)
monoAbsType (GA.TVar _v) = error "NOT POLYMORHPIC TYPES" monoAbsType (GA.TVar _v) = error "NOT POLYMORHPIC TYPES"
monoAbsType (GA.TAll _v _t) = error "NOT ALL TYPES" monoAbsType (GA.TAll _v _t) = error "NOT ALL TYPES"
monoAbsType (GA.TIndexed _i) = error "NOT INDEXED TYPES" monoAbsType (GA.TData _ i) = error "NOT INDEXED TYPES"
monoAbsType (GA.TEVar _v) = error "I DONT KNOW WHAT THIS IS" monoAbsType (GA.TEVar _v) = error "I DONT KNOW WHAT THIS IS"
monoAbsType (GA.TFun t1 t2) = M.TFun (monoAbsType t1) (monoAbsType t2) monoAbsType (GA.TFun t1 t2) = M.TFun (monoAbsType t1) (monoAbsType t2)
monoType :: T.Type -> M.Type monoType :: T.Type -> M.Type
monoType (T.TAll _ t) = monoType t monoType (T.TAll _ t) = monoType t
monoType (T.TVar (T.MkTVar i)) = error "NOT POLYMORPHIC TYPES" monoType (T.TVar (T.MkTVar i)) = error "NOT POLYMORPHIC TYPES"
monoType (T.TLit i) = M.TLit i monoType (T.TLit i) = M.TLit i
monoType (T.TFun t1 t2) = M.TFun (monoType t1) (monoType t2) monoType (T.TFun t1 t2) = M.TFun (monoType t1) (monoType t2)
monoType (T.TIndexed _) = error "Not sure what this is" monoType (T.TData _ _) = error "Not sure what this is"
monoexpt :: T.ExpT -> M.ExpT monoexpt :: T.ExpT -> M.ExpT
monoexpt (e, t) = (monoExpr e, monoType t) monoexpt (e, t) = (monoExpr e, monoType t)
monoId :: T.Id -> M.Id monoId :: T.Id -> M.Id
monoId (n,t) = (n, monoType t) monoId (n, t) = (n, monoType t)
monoLit :: T.Lit -> M.Lit monoLit :: T.Lit -> M.Lit
monoLit (T.LInt i) = M.LInt i monoLit (T.LInt i) = M.LInt i
@ -69,4 +66,3 @@ monoInj (T.Inj (init, t) expt) = M.Injection (monoInit init, monoType t) (monoex
monoInit :: T.Init -> M.Init monoInit :: T.Init -> M.Init
monoInit = id monoInit = id

View file

@ -1,16 +1,18 @@
module Monomorphizer.MonomorphizerIr (module Monomorphizer.MonomorphizerIr, module RE, module GA) where module Monomorphizer.MonomorphizerIr (module Monomorphizer.MonomorphizerIr, module RE, module GA) where
import Grammar.Abs (Ident (..), Init (..), UIdent) import Grammar.Abs (Ident (..), Init (..), UIdent)
import qualified Grammar.Abs as GA (Ident (..), Init (..)) import Grammar.Abs qualified as GA (Ident (..), Init (..))
import qualified TypeChecker.TypeCheckerIr as RE (Indexed) import TypeChecker.TypeCheckerIr qualified as RE
import TypeChecker.TypeCheckerIr (Indexed)
type Id = (Ident, Type) type Id = (Ident, Type)
newtype Program = Program [Def] newtype Program = Program [Def]
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
data Def = DBind Bind | DData Constructor data Def = DBind Bind | DData Data
deriving (Show, Ord, Eq)
data Data = Data Type Constructor
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
data Bind = Bind Id [Id] ExpT data Bind = Bind Id [Id] ExpT

View file

@ -37,22 +37,23 @@ renameDefs defs = runIdentity $ runExceptT $ evalStateT (runRn $ mapM renameDef
renameDef = \case renameDef = \case
DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ
DBind bind -> DBind . snd <$> renameBind initNames bind DBind bind -> DBind . snd <$> renameBind initNames bind
DData (Data (Indexed cname types) constrs) -> do DData (Data (TData cname types) constrs) -> do
tvars_ <- tvars tvars_ <- tvars
tvars' <- mapM nextNameTVar tvars_ tvars' <- mapM nextNameTVar tvars_
let tvars_lt = zip tvars_ tvars' let tvars_lt = zip tvars_ tvars'
typ' = map (substituteTVar tvars_lt) types typ' = map (substituteTVar tvars_lt) types
constrs' = map (renameConstr tvars_lt) constrs constrs' = map (renameConstr tvars_lt) constrs
pure . DData $ Data (Indexed cname typ') constrs' pure . DData $ Data (TData cname typ') constrs'
where where
tvars = concat <$> mapM (collectTVars []) types tvars = concat <$> mapM (collectTVars []) types
collectTVars :: [TVar] -> Type -> Rn [TVar] collectTVars :: [TVar] -> Type -> Rn [TVar]
collectTVars tvars = \case collectTVars tvars = \case
TAll tvar t -> collectTVars (tvar : tvars) t TAll tvar t -> collectTVars (tvar : tvars) t
TIndexed _ -> return tvars TData _ _ -> return tvars
-- Should be monad error -- Should be monad error
TVar v -> return [v] TVar v -> return [v]
_ -> throwError ("Bad data type definition: " ++ show types) _ -> throwError ("Bad data type definition: " ++ show types)
DData (Data types _) -> throwError ("Bad data type definition: " ++ show types)
renameConstr :: [(TVar, TVar)] -> Constructor -> Constructor renameConstr :: [(TVar, TVar)] -> Constructor -> Constructor
renameConstr new_types (Constructor name typ) = renameConstr new_types (Constructor name typ) =
@ -78,7 +79,7 @@ substituteTVar new_names typ = case typ of
TAll tvar' $ substitute' t TAll tvar' $ substitute' t
| otherwise -> | otherwise ->
TAll tvar $ substitute' t TAll tvar $ substitute' t
TIndexed (Indexed name typs) -> TIndexed . Indexed name $ map substitute' typs TData name typs -> TData name $ map substitute' typs
_ -> error ("Impossible " ++ show typ) _ -> error ("Impossible " ++ show typ)
where where
substitute' = substituteTVar new_names substitute' = substituteTVar new_names
@ -169,7 +170,7 @@ substitute tvar1 tvar2 typ = case typ of
| otherwise -> typ | otherwise -> typ
TFun t1 t2 -> on TFun substitute' t1 t2 TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t -> TAll tvar $ substitute' t TAll tvar t -> TAll tvar $ substitute' t
TIndexed (Indexed name typs) -> TIndexed . Indexed name $ map substitute' typs TData name typs -> TData name $ map substitute' typs
_ -> error "Impossible" _ -> error "Impossible"
where where
substitute' = substitute tvar1 tvar2 substitute' = substitute tvar1 tvar2

View file

@ -48,13 +48,13 @@ typecheck = run . checkPrg
checkData :: Data -> Infer () checkData :: Data -> Infer ()
checkData d = do checkData d = do
case d of case d of
(Data typ@(Indexed name ts) constrs) -> do (Data typ@(TData name ts) constrs) -> do
unless unless
(all isPoly ts) (all isPoly ts)
(throwError $ unwords ["Data type incorrectly declared"]) (throwError $ unwords ["Data type incorrectly declared"])
traverse_ traverse_
( \(Constructor name' t') -> ( \(Constructor name' t') ->
if TIndexed typ == retType t' if typ == retType t'
then insertConstr (coerce name') (toNew t') then insertConstr (coerce name') (toNew t')
else else
throwError $ throwError $
@ -68,6 +68,7 @@ checkData d = do
] ]
) )
constrs constrs
_ -> throwError $ "incorrectly declared data type '" ++ printTree d ++ "'"
retType :: Type -> Type retType :: Type -> Type
retType (TFun _ t2) = retType t2 retType (TFun _ t2) = retType t2
@ -86,7 +87,14 @@ checkPrg (Program bs) = do
preRun [] = return () preRun [] = return ()
preRun (x : xs) = case x of preRun (x : xs) = case x of
DSig (Sig n t) -> do DSig (Sig n t) -> do
gets (M.member (coerce n) . sigs) >>= flip when (throwError $ "Duplicate signatures for function '" ++ printTree n ++ "'") gets (M.member (coerce n) . sigs)
>>= flip
when
( throwError $
"Duplicate signatures for function '"
++ printTree n
++ "'"
)
insertSig (coerce n) (Just $ toNew t) >> preRun xs insertSig (coerce n) (Just $ toNew t) >> preRun xs
DBind (Bind n _ _) -> do DBind (Bind n _ _) -> do
s <- gets sigs s <- gets sigs
@ -140,7 +148,7 @@ isMoreSpecificOrEq :: T.Type -> T.Type -> Bool
isMoreSpecificOrEq _ (T.TAll _ _) = True isMoreSpecificOrEq _ (T.TAll _ _) = True
isMoreSpecificOrEq (T.TFun a b) (T.TFun c d) = isMoreSpecificOrEq (T.TFun a b) (T.TFun c d) =
isMoreSpecificOrEq a c && isMoreSpecificOrEq b d isMoreSpecificOrEq a c && isMoreSpecificOrEq b d
isMoreSpecificOrEq (T.TIndexed (T.Indexed n1 ts1)) (T.TIndexed (T.Indexed n2 ts2)) = isMoreSpecificOrEq (T.TData n1 ts1) (T.TData n2 ts2) =
n1 == n2 n1 == n2
&& length ts1 == length ts2 && length ts1 == length ts2
&& and (zipWith isMoreSpecificOrEq ts1 ts2) && and (zipWith isMoreSpecificOrEq ts1 ts2)
@ -169,11 +177,11 @@ instance NewType Type T.Type where
TVar v -> T.TVar $ toNew v TVar v -> T.TVar $ toNew v
TFun t1 t2 -> T.TFun (toNew t1) (toNew t2) TFun t1 t2 -> T.TFun (toNew t1) (toNew t2)
TAll b t -> T.TAll (toNew b) (toNew t) TAll b t -> T.TAll (toNew b) (toNew t)
TIndexed i -> T.TIndexed (toNew i) TData i ts -> T.TData (coerce i) (map toNew ts)
TEVar _ -> error "Should not exist after typechecker" TEVar _ -> error "Should not exist after typechecker"
instance NewType Indexed T.Indexed where -- instance NewType Indexed T.TData where
toNew (Indexed name vars) = T.Indexed (coerce name) (map toNew vars) -- toNew (Indexed name vars) = T.TData (coerce name) (map toNew vars)
instance NewType TVar T.TVar where instance NewType TVar T.TVar where
toNew (MkTVar i) = T.MkTVar $ coerce i toNew (MkTVar i) = T.MkTVar $ coerce i
@ -181,8 +189,8 @@ instance NewType TVar T.TVar where
algoW :: Exp -> Infer (Subst, T.ExpT) algoW :: Exp -> Infer (Subst, T.ExpT)
algoW = \case algoW = \case
-- \| TODO: More testing need to be done. Unsure of the correctness of this -- \| TODO: More testing need to be done. Unsure of the correctness of this
EAnn e t -> do err@(EAnn e t) -> do
(s1, (e', t')) <- algoW e (s1, (e', t')) <- exprErr (algoW e) err
unless unless
(toNew t `isMoreSpecificOrEq` t') (toNew t `isMoreSpecificOrEq` t')
( throwError $ ( throwError $
@ -194,16 +202,14 @@ algoW = \case
] ]
) )
applySt s1 $ do applySt s1 $ do
s2 <- unify (toNew t) t' s2 <- exprErr (unify (toNew t) t') err
let comp = s2 `compose` s1 let comp = s2 `compose` s1
return (comp, apply comp (e', toNew t)) return (comp, apply comp (e', toNew t))
-- \| ------------------ -- \| ------------------
-- \| Γ ⊢ i : Int, ∅ -- \| Γ ⊢ i : Int, ∅
ELit lit -> ELit lit -> return (nullSubst, (T.ELit lit, litType lit))
let lt = litType lit
in return (nullSubst, (T.ELit lit, lt))
-- \| x : σ ∈ Γ τ = inst(σ) -- \| x : σ ∈ Γ τ = inst(σ)
-- \| ---------------------- -- \| ----------------------
-- \| Γ ⊢ x : τ, ∅ -- \| Γ ⊢ x : τ, ∅
@ -227,13 +233,16 @@ algoW = \case
-- \| --------------------------------- -- \| ---------------------------------
-- \| Γ ⊢ w λx. e : Sτ → τ', S -- \| Γ ⊢ w λx. e : Sτ → τ', S
EAbs name e -> do err@(EAbs name e) -> do
fr <- fresh fr <- fresh
withBinding (coerce name) fr $ do exprErr
(s1, (e', t')) <- algoW e ( withBinding (coerce name) fr $ do
(s1, (e', t')) <- exprErr (algoW e) err
let varType = apply s1 fr let varType = apply s1 fr
let newArr = T.TFun varType t' let newArr = T.TFun varType t'
return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr)) return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr))
)
err
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁ -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
-- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int) -- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
@ -241,13 +250,13 @@ algoW = \case
-- \| Γ ⊢ e₀ + e₁ : Int, S₃S₂S₁S₀ -- \| Γ ⊢ e₀ + e₁ : Int, S₃S₂S₁S₀
-- This might be wrong -- This might be wrong
EAdd e0 e1 -> do err@(EAdd e0 e1) -> do
(s1, (e0', t0)) <- algoW e0 (s1, (e0', t0)) <- algoW e0
applySt s1 $ do applySt s1 $ do
(s2, (e1', t1)) <- algoW e1 (s2, (e1', t1)) <- algoW e1
-- applySt s2 $ do -- applySt s2 $ do
s3 <- unify (apply s2 t0) int s3 <- exprErr (unify (apply s2 t0) int) err
s4 <- unify (apply s3 t1) int s4 <- exprErr (unify (apply s3 t1) int) err
let comp = s4 `compose` s3 `compose` s2 `compose` s1 let comp = s4 `compose` s3 `compose` s2 `compose` s1
return return
( comp ( comp
@ -259,12 +268,12 @@ algoW = \case
-- \| -------------------------------------- -- \| --------------------------------------
-- \| Γ ⊢ e₀ e₁ : S₂τ', S₂S₁S₀ -- \| Γ ⊢ e₀ e₁ : S₂τ', S₂S₁S₀
EApp e0 e1 -> do err@(EApp e0 e1) -> do
fr <- fresh fr <- fresh
(s0, (e0', t0)) <- algoW e0 (s0, (e0', t0)) <- exprErr (algoW e0) err
applySt s0 $ do applySt s0 $ do
(s1, (e1', t1)) <- algoW e1 (s1, (e1', t1)) <- exprErr (algoW e1) err
s2 <- unify (apply s1 t0) (T.TFun t1 fr) s2 <- exprErr (unify (apply s1 t0) (T.TFun t1 fr)) err
let t = apply s2 fr let t = apply s2 fr
let comp = s2 `compose` s1 `compose` s0 let comp = s2 `compose` s1 `compose` s0
return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t)) return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t))
@ -275,9 +284,9 @@ algoW = \case
-- The bar over S₀ and Γ means "generalize" -- The bar over S₀ and Γ means "generalize"
ELet b@(Bind name args e) e1 -> do err@(ELet b@(Bind name args e) e1) -> do
(s1, (_, t0)) <- algoW (makeLambda e (coerce args)) (s1, (_, t0)) <- exprErr (algoW (makeLambda e (coerce args))) err
bind' <- checkBind b bind' <- exprErr (checkBind b) err
env <- asks vars env <- asks vars
let t' = generalize (apply s1 env) t0 let t' = generalize (apply s1 env) t0
withBinding (coerce name) t' $ do withBinding (coerce name) t' $ do
@ -311,7 +320,7 @@ unify t0 t1 = do
(a, T.TAll _ t) -> unify a t (a, T.TAll _ t) -> unify a t
(T.TLit a, T.TLit b) -> (T.TLit a, T.TLit b) ->
if a == b then return M.empty else throwError . unwords $ ["Can not unify", "'" ++ printTree (T.TLit a) ++ "'", "with", "'" ++ printTree (T.TLit b) ++ "'"] if a == b then return M.empty else throwError . unwords $ ["Can not unify", "'" ++ printTree (T.TLit a) ++ "'", "with", "'" ++ printTree (T.TLit b) ++ "'"]
(T.TIndexed (T.Indexed name t), T.TIndexed (T.Indexed name' t')) -> (T.TData name t, T.TData name' t') ->
if name == name' && length t == length t' if name == name' && length t == length t'
then do then do
xs <- zipWithM unify t t' xs <- zipWithM unify t t'
@ -399,7 +408,7 @@ instance FreeVars T.Type where
free (T.TLit _) = mempty free (T.TLit _) = mempty
free (T.TFun a b) = free a `S.union` free b free (T.TFun a b) = free a `S.union` free b
-- \| Not guaranteed to be correct -- \| Not guaranteed to be correct
free (T.TIndexed (T.Indexed _ a)) = free (T.TData _ a) =
foldl' (\acc x -> free x `S.union` acc) S.empty a foldl' (\acc x -> free x `S.union` acc) S.empty a
apply :: Subst -> T.Type -> T.Type apply :: Subst -> T.Type -> T.Type
@ -413,7 +422,7 @@ instance FreeVars T.Type where
Nothing -> T.TAll (T.MkTVar i) (apply sub t) Nothing -> T.TAll (T.MkTVar i) (apply sub t)
Just _ -> apply sub t Just _ -> apply sub t
T.TFun a b -> T.TFun (apply sub a) (apply sub b) T.TFun a b -> T.TFun (apply sub a) (apply sub b)
T.TIndexed (T.Indexed name a) -> T.TIndexed (T.Indexed name (map (apply sub) a)) T.TData name a -> T.TData name (map (apply sub) a)
instance FreeVars (Map Ident T.Type) where instance FreeVars (Map Ident T.Type) where
free :: Map Ident T.Type -> Set Ident free :: Map Ident T.Type -> Set Ident
@ -548,3 +557,6 @@ partitionType = go []
TAll tvar t' -> second (TAll tvar) $ go acc i t' TAll tvar t' -> second (TAll tvar) $ go acc i t'
TFun t1 t2 -> go (acc ++ [t1]) (i - 1) t2 TFun t1 t2 -> go (acc ++ [t1]) (i - 1) t2
_ -> error "Number of parameters and type doesn't match" _ -> error "Number of parameters and type doesn't match"
exprErr :: Infer a -> Exp -> Infer a
exprErr ma exp = catchError ma (\x -> throwError $ x ++ " on expression: " ++ printTree exp)

View file

@ -52,7 +52,7 @@ data Type
| TVar TVar | TVar TVar
| TFun Type Type | TFun Type Type
| TAll TVar Type | TAll TVar Type
| TIndexed Indexed | TData Ident [Type]
deriving (Show, Eq, Ord, Read) deriving (Show, Eq, Ord, Read)
data Exp data Exp
@ -67,9 +67,6 @@ data Exp
type ExpT = (Exp, Type) type ExpT = (Exp, Type)
data Indexed = Indexed Ident [Type]
deriving (Show, Read, Ord, Eq)
data Inj = Inj (Init, Type) ExpT data Inj = Inj (Init, Type) ExpT
deriving (C.Eq, C.Ord, C.Read, C.Show) deriving (C.Eq, C.Ord, C.Read, C.Show)
@ -205,8 +202,5 @@ instance Print Type where
TLit uident -> prPrec i 2 (concatD [prt 0 uident]) TLit uident -> prPrec i 2 (concatD [prt 0 uident])
TVar tvar -> prPrec i 2 (concatD [prt 0 tvar]) TVar tvar -> prPrec i 2 (concatD [prt 0 tvar])
TAll tvar type_ -> prPrec i 1 (concatD [doc (showString "forall"), prt 0 tvar, doc (showString "."), prt 0 type_]) TAll tvar type_ -> prPrec i 1 (concatD [doc (showString "forall"), prt 0 tvar, doc (showString "."), prt 0 type_])
TIndexed indexed -> prPrec i 1 (concatD [prt 0 indexed]) TData ident types -> prPrec i 1 (concatD [prt 0 ident, prt 0 types])
TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2]) TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2])
instance Print Indexed where
prt i (Indexed u ts) = concatD [prt i u, prt i ts]

View file

@ -8,18 +8,18 @@ data Bool () where {
False : Bool () False : Bool ()
}; };
-- hello_world = Cons 'h' (Cons 'e' (Cons 'l' (Cons 'l' (Cons 'o' (Cons ' ' (Cons 'w' (Cons 'o' (Cons 'r' (Cons 'l' (Cons 'd' Nil)))))))))) ; hello_world = Cons 'h' (Cons 'e' (Cons 'l' (Cons 'l' (Cons 'o' (Cons ' ' (Cons 'w' (Cons 'o' (Cons 'r' (Cons 'l' (Cons 'd' Nil)))))))))) ;
-- length : List (a) -> Int ; length : List (a) -> Int ;
-- length xs = case xs of { length xs = case xs of {
-- Nil => 0 ; Nil => 0 ;
-- Cons x xs => length xs Cons x xs => length xs
-- }; };
-- head : List (a) -> a ; head : List (a) -> a ;
-- head xs = case xs of { head xs = case xs of {
-- Cons x xs => x Cons x xs => x
-- }; };
firstIsOne : List (Int) -> Bool () ; firstIsOne : List (Int) -> Bool () ;
firstIsOne : List (Int) -> Bool () ; firstIsOne : List (Int) -> Bool () ;
@ -34,9 +34,11 @@ firstIsOne xs = case xs of {
_ => False _ => False
}; };
-- firstIsOne :: [Int] -> Bool firstIsOne :: [Int] -> Bool
-- firstIsOne xs = case xs of firstIsOne xs = case xs of
-- (1 : xs) -> True (1 : xs) -> True
-- _ -> False _ -> False
main = firstIsOne (Cons 'a' Nil) main = firstIsOne (Cons 'a' Nil)
data a -> b where