This commit is contained in:
Rakarake 2023-04-27 14:02:10 +02:00
commit 579153b679
9 changed files with 339 additions and 261 deletions

9
benchmark.txt Normal file
View file

@ -0,0 +1,9 @@
# Full optimization Churf
File: output/hello_world, 100 runs gave average: 0.025261127948760988s
# O2 Haskell
File: ./Bench, 100 runs gave average: 0.05629507303237915s
# 03 Haskell
File: ./Bench, 100 runs gave average: 0.05490849256515503s
File: ./Bench, 100 runs gave average: 0.05323728561401367s

View file

@ -0,0 +1,3 @@
main = case (lt 3 5) of
True => 1
False => 0

View file

@ -9,6 +9,7 @@ type2LlvmType :: MIR.Type -> LLVMType
type2LlvmType (MIR.TLit id@(TIR.Ident name)) = case name of type2LlvmType (MIR.TLit id@(TIR.Ident name)) = case name of
"Int" -> I64 "Int" -> I64
"Char" -> I8 "Char" -> I8
"Bool" -> I1
_ -> CustomType id _ -> CustomType id
type2LlvmType (MIR.TFun t xs) = do type2LlvmType (MIR.TFun t xs) = do
let (t', xs') = function2LLVMType xs [type2LlvmType t] let (t', xs') = function2LLVMType xs [type2LlvmType t]

View file

@ -11,7 +11,8 @@ import Control.Monad.State (
) )
import Data.List (sortBy) import Data.List (sortBy)
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
import Monomorphizer.MonomorphizerIr as MIR (Def (DBind, DData), Program (..)) import Monomorphizer.MonomorphizerIr as MIR (Bind (..), Data (..), Def (DBind, DData), Program (..), Type (TLit))
import TypeChecker.TypeCheckerIr (Ident (..))
{- | Compiles an AST and produces a LLVM Ir string. {- | Compiles an AST and produces a LLVM Ir string.
An easy way to actually "compile" this output is to An easy way to actually "compile" this output is to
@ -19,8 +20,14 @@ import Monomorphizer.MonomorphizerIr as MIR (Def (DBind, DData), Program (..))
-} -}
generateCode :: MIR.Program -> Err String generateCode :: MIR.Program -> Err String
generateCode (MIR.Program scs) = do generateCode (MIR.Program scs) = do
let codegen = initCodeGenerator scs let tree = filter (not . detectPrelude) (sortBy lowData scs)
llvmIrToString . instructions <$> execStateT (compileScs (sortBy lowData scs)) codegen let codegen = initCodeGenerator tree
llvmIrToString . instructions <$> execStateT (compileScs tree) codegen
detectPrelude :: Def -> Bool
detectPrelude (DData (Data (TLit (Ident "Bool")) _)) = True
detectPrelude (DBind (Bind (Ident ('l' : 't' : '$' : _), _) _ _)) = True
detectPrelude _ = False
lowData :: Def -> Def -> Ordering lowData :: Def -> Def -> Ordering
lowData (DData _) (DBind _) = LT lowData (DData _) (DBind _) = LT

View file

@ -228,15 +228,15 @@ emitECased t e cases = do
emit $ Store ty val Ptr stackPtr emit $ Store ty val Ptr stackPtr
emit $ Br label emit $ Br label
emit $ Label lbl_failPos emit $ Label lbl_failPos
emitCases _rt ty label stackPtr vs (Branch (MIR.PLit i, t) exp) = do emitCases _rt ty label stackPtr vs (Branch (MIR.PLit (i, ct), t) exp) = do
emit $ Comment "Plit" emit $ Comment "Plit"
let i' = case i of let i' = case i of
(MIR.LInt i, _) -> VInteger i MIR.LInt i -> VInteger i
(MIR.LChar i, _) -> VChar (ord i) MIR.LChar i -> VChar (ord i)
ns <- getNewVar ns <- getNewVar
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel
emit $ SetVariable ns (Icmp LLEq (type2LlvmType t) vs i') emit $ SetVariable ns (Icmp LLEq (type2LlvmType ct) vs i')
emit $ BrCond (VIdent ns ty) lbl_succPos lbl_failPos emit $ BrCond (VIdent ns ty) lbl_succPos lbl_failPos
emit $ Label lbl_succPos emit $ Label lbl_succPos
val <- exprToValue exp val <- exprToValue exp
@ -255,9 +255,13 @@ emitECased t e cases = do
emit $ Br label emit $ Br label
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
emit $ Label lbl_failPos emit $ Label lbl_failPos
emitCases rt ty label stackPtr vs (Branch (MIR.PEnum (TIR.Ident "True"), t) exp) = do
emitCases rt ty label stackPtr vs (Branch (MIR.PLit (MIR.LInt 1, TLit "Bool"), t) exp)
emitCases rt ty label stackPtr vs (Branch (MIR.PEnum (TIR.Ident "False"), _) exp) = do
emitCases rt ty label stackPtr vs (Branch (MIR.PLit (MIR.LInt 0, TLit "Bool"), t) exp)
emitCases _rt ty label stackPtr _vs (Branch (MIR.PEnum _id, _) exp) = do emitCases _rt ty label stackPtr _vs (Branch (MIR.PEnum _id, _) exp) = do
-- //TODO Penum wrong, acts as a catch all -- //TODO Penum wrong, acts as a catch all
emit $ Comment "Penum" emit $ Comment $ "Penum " <> show _id
val <- exprToValue exp val <- exprToValue exp
emit $ Store ty val Ptr stackPtr emit $ Store ty val Ptr stackPtr
emit $ Br label emit $ Br label
@ -290,7 +294,10 @@ emitApp rt e1 e2 = appEmitter e1 e2 []
<|> Global <$ Map.lookup (name, t) funcs <|> Global <$ Map.lookup (name, t) funcs
-- this piece of code could probably be improved, i.e remove the double `const Global` -- this piece of code could probably be improved, i.e remove the double `const Global`
args' = map (first valueGetType . dupe) args args' = map (first valueGetType . dupe) args
call = Call FastCC (type2LlvmType rt) visibility name args' let call =
case name of
TIR.Ident ('l' : 't' : '$' : _) -> Icmp LLSlt I64 (snd (head args')) (snd (args' !! 1))
_ -> Call FastCC (type2LlvmType rt) visibility name args'
emit $ Comment $ show rt emit $ Comment $ show rt
emit $ SetVariable vs call emit $ SetVariable vs call
x -> error $ "The unspeakable happened: " <> show x x -> error $ "The unspeakable happened: " <> show x

View file

@ -166,4 +166,4 @@ printToErr = hPutStrLn stderr
fromErr :: Err a -> IO a fromErr :: Err a -> IO a
fromErr = either (\s -> printToErr s >> exitFailure) pure fromErr = either (\s -> printToErr s >> exitFailure) pure
prelude = "const x y = x\n\ndata Bool () where\n True : Bool ()\n False : Bool ()\n\nlt : Int -> Int -> Bool ()\nlt = \\x. \\y. const True (x + y)" prelude = "\n\nconst x y = x\n\ndata Bool () where\n True : Bool ()\n False : Bool ()\n\nlt : Int -> Int -> Bool ()\nlt = \\x. \\y. const True (x + y)"

View file

@ -1,6 +1,7 @@
module Monomorphizer.DataTypeRemover (removeDataTypes) where module Monomorphizer.DataTypeRemover (removeDataTypes) where
import qualified Monomorphizer.MorbIr as M1
import qualified Monomorphizer.MonomorphizerIr as M2 import Monomorphizer.MonomorphizerIr qualified as M2
import Monomorphizer.MorbIr qualified as M1
import TypeChecker.TypeCheckerIr (Ident (Ident)) import TypeChecker.TypeCheckerIr (Ident (Ident))
removeDataTypes :: M1.Program -> M2.Program removeDataTypes :: M1.Program -> M2.Program
@ -17,9 +18,10 @@ pCons :: M1.Inj -> M2.Inj
pCons (M1.Inj ident t) = M2.Inj ident (pType t) pCons (M1.Inj ident t) = M2.Inj ident (pType t)
pType :: M1.Type -> M2.Type pType :: M1.Type -> M2.Type
pType (M1.TLit ident) = M2.TLit ident pType (M1.TLit ident) = M2.TLit ident
pType (M1.TFun t1 t2) = M2.TFun (pType t1) (pType t2) pType (M1.TFun t1 t2) = M2.TFun (pType t1) (pType t2)
pType d = M2.TLit (Ident (newName d)) -- This is the step pType (M1.TData (Ident "Bool") _) = M2.TLit (Ident "Bool")
pType d = M2.TLit (Ident (newName d)) -- This is the step
newName :: M1.Type -> String newName :: M1.Type -> String
newName (M1.TLit (Ident str)) = str newName (M1.TLit (Ident str)) = str
@ -36,24 +38,23 @@ pExpT :: M1.ExpT -> M2.ExpT
pExpT (exp, t) = (pExp exp, pType t) pExpT (exp, t) = (pExp exp, pType t)
pExp :: M1.Exp -> M2.Exp pExp :: M1.Exp -> M2.Exp
pExp (M1.EVar ident) = M2.EVar ident pExp (M1.EVar ident) = M2.EVar ident
pExp (M1.ELit lit) = M2.ELit (pLit lit) pExp (M1.ELit lit) = M2.ELit (pLit lit)
pExp (M1.ELet bind expt) = M2.ELet (pBind bind) (pExpT expt) pExp (M1.ELet bind expt) = M2.ELet (pBind bind) (pExpT expt)
pExp (M1.EApp e1 e2) = M2.EApp (pExpT e1) (pExpT e2) pExp (M1.EApp e1 e2) = M2.EApp (pExpT e1) (pExpT e2)
pExp (M1.EAdd e1 e2) = M2.EAdd (pExpT e1) (pExpT e2) pExp (M1.EAdd e1 e2) = M2.EAdd (pExpT e1) (pExpT e2)
pExp (M1.ECase expT branches) = M2.ECase (pExpT expT) (map pBranch branches) pExp (M1.ECase expT branches) = M2.ECase (pExpT expT) (map pBranch branches)
pBranch :: M1.Branch -> M2.Branch pBranch :: M1.Branch -> M2.Branch
pBranch (M1.Branch (patt, t) expt) = M2.Branch (pPattern patt, pType t) (pExpT expt) pBranch (M1.Branch (patt, t) expt) = M2.Branch (pPattern patt, pType t) (pExpT expt)
pPattern :: M1.Pattern -> M2.Pattern pPattern :: M1.Pattern -> M2.Pattern
pPattern (M1.PVar id) = M2.PVar (pId id) pPattern (M1.PVar id) = M2.PVar (pId id)
pPattern (M1.PLit (lit, t)) = M2.PLit (pLit lit, pType t) pPattern (M1.PLit (lit, t)) = M2.PLit (pLit lit, pType t)
pPattern (M1.PInj ident patts) = M2.PInj ident (map pPattern patts) pPattern (M1.PInj ident patts) = M2.PInj ident (map pPattern patts)
pPattern M1.PCatch = M2.PCatch pPattern M1.PCatch = M2.PCatch
pPattern (M1.PEnum ident) = M2.PEnum ident pPattern (M1.PEnum ident) = M2.PEnum ident
pLit :: M1.Lit -> M2.Lit pLit :: M1.Lit -> M2.Lit
pLit (M1.LInt v) = M2.LInt v pLit (M1.LInt v) = M2.LInt v
pLit (M1.LChar c) = M2.LChar c pLit (M1.LChar c) = M2.LChar c

View file

@ -1,72 +1,84 @@
-- | For now, converts polymorphic functions to concrete ones based on usage.
-- Assumes lambdas are lifted.
--
-- This step of compilation is as follows:
--
-- Split all function bindings into monomorphic and polymorphic binds. The
-- monomorphic bindings will be part of this compilation step.
-- Apply the following monomorphization function on all monomorphic binds, with
-- their type as an additional argument.
--
-- The function that transforms Binds operates on both monomorphic and
-- polymorphic functions, creates a context in which all possible polymorphic types
-- are mapped to concrete types, created using the additional argument.
-- Expressions are then recursively processed. The type of these expressions
-- are changed to using the mapped generic types. The expected type provided
-- in the recursion is changed depending on the different nodes.
--
-- When an external bind is encountered (with EId), it is checked whether it
-- exists in outputed binds or not. If it does, nothing further is evaluated.
-- If not, the bind transformer function is called on it with the
-- expected type in this context. The result of this computation (a monomorphic
-- bind) is added to the resulting set of binds.
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{- | For now, converts polymorphic functions to concrete ones based on usage.
Assumes lambdas are lifted.
This step of compilation is as follows:
Split all function bindings into monomorphic and polymorphic binds. The
monomorphic bindings will be part of this compilation step.
Apply the following monomorphization function on all monomorphic binds, with
their type as an additional argument.
The function that transforms Binds operates on both monomorphic and
polymorphic functions, creates a context in which all possible polymorphic types
are mapped to concrete types, created using the additional argument.
Expressions are then recursively processed. The type of these expressions
are changed to using the mapped generic types. The expected type provided
in the recursion is changed depending on the different nodes.
When an external bind is encountered (with EId), it is checked whether it
exists in outputed binds or not. If it does, nothing further is evaluated.
If not, the bind transformer function is called on it with the
expected type in this context. The result of this computation (a monomorphic
bind) is added to the resulting set of binds.
-}
module Monomorphizer.Monomorphizer (monomorphize, morphExp, morphBind) where module Monomorphizer.Monomorphizer (monomorphize, morphExp, morphBind) where
import Monomorphizer.DataTypeRemover (removeDataTypes) import Monomorphizer.DataTypeRemover (removeDataTypes)
import qualified Monomorphizer.MonomorphizerIr as O import Monomorphizer.MonomorphizerIr qualified as O
import qualified Monomorphizer.MorbIr as M import Monomorphizer.MorbIr qualified as M
import qualified TypeChecker.TypeCheckerIr as T import TypeChecker.TypeCheckerIr (Ident (Ident))
import TypeChecker.TypeCheckerIr (Ident (Ident)) import TypeChecker.TypeCheckerIr qualified as T
import Control.Monad.Reader (MonadReader (ask, local), import Control.Monad.Reader (
Reader, asks, runReader, when) MonadReader (ask, local),
import Control.Monad.State (MonadState, StateT (runStateT), Reader,
gets, modify) asks,
import Data.Coerce (coerce) runReader,
import qualified Data.Map as Map when,
import Data.Maybe (fromJust) )
import qualified Data.Set as Set import Control.Monad.State (
import Debug.Trace MonadState,
import Grammar.Print (printTree) StateT (runStateT),
gets,
modify,
)
import Data.Coerce (coerce)
import Data.Map qualified as Map
import Data.Maybe (fromJust)
import Data.Set qualified as Set
import Debug.Trace
import Grammar.Print (printTree)
-- | EnvM is the monad containing the read-only state as well as the {- | EnvM is the monad containing the read-only state as well as the
-- output state containing monomorphized functions and to-be monomorphized output state containing monomorphized functions and to-be monomorphized
-- data type declarations. data type declarations.
-}
newtype EnvM a = EnvM (StateT Output (Reader Env) a) newtype EnvM a = EnvM (StateT Output (Reader Env) a)
deriving (Functor, Applicative, Monad, MonadState Output, MonadReader Env) deriving (Functor, Applicative, Monad, MonadState Output, MonadReader Env)
type Output = Map.Map Ident Outputted type Output = Map.Map Ident Outputted
-- | Data structure describing outputted top-level information, that is {- | Data structure describing outputted top-level information, that is
-- Binds, Polymorphic Data types (monomorphized in a later step) and Binds, Polymorphic Data types (monomorphized in a later step) and
-- Marked bind, which means that it is in the process of monomorphization Marked bind, which means that it is in the process of monomorphization
-- and should not be monomorphized again. and should not be monomorphized again.
-}
data Outputted = Marked | Complete M.Bind | Data M.Type T.Data data Outputted = Marked | Complete M.Bind | Data M.Type T.Data
-- | Static environment. -- | Static environment.
data Env = Env { data Env = Env
-- | All binds in the program. { input :: Map.Map Ident T.Bind
input :: Map.Map Ident T.Bind, -- ^ All binds in the program.
-- | All constructors mapped to their respective polymorphic data def , dataDefs :: Map.Map Ident T.Data
-- which includes all other constructors. -- ^ All constructors mapped to their respective polymorphic data def
dataDefs :: Map.Map Ident T.Data, -- which includes all other constructors.
-- | Maps polymorphic identifiers with concrete types. , polys :: Map.Map Ident M.Type
polys :: Map.Map Ident M.Type, -- ^ Maps polymorphic identifiers with concrete types.
-- | Local variables. , locals :: Set.Set Ident
locals :: Set.Set Ident -- ^ Local variables.
} }
-- | Determines if the identifier describes a local variable in the given context. -- | Determines if the identifier describes a local variable in the given context.
localExists :: Ident -> EnvM Bool localExists :: Ident -> EnvM Bool
@ -80,8 +92,9 @@ getInputBind ident = asks (Map.lookup ident . input)
addOutputBind :: M.Bind -> EnvM () addOutputBind :: M.Bind -> EnvM ()
addOutputBind b@(M.Bind (ident, _) _ _) = modify (Map.insert ident (Complete b)) addOutputBind b@(M.Bind (ident, _) _ _) = modify (Map.insert ident (Complete b))
-- | Marks a global bind as being processed, meaning that when encountered again, {- | Marks a global bind as being processed, meaning that when encountered again,
-- it should not be recursively processed. it should not be recursively processed.
-}
markBind :: Ident -> EnvM () markBind :: Ident -> EnvM ()
markBind ident = modify (Map.insert ident Marked) markBind ident = modify (Map.insert ident Marked)
@ -96,125 +109,143 @@ getMain = asks (\env -> case Map.lookup (T.Ident "main") (input env) of
Nothing -> error "main not found in monomorphizer!" Nothing -> error "main not found in monomorphizer!"
) )
-- | Makes a kv pair list of polymorphic to monomorphic mappings, throws runtime {- | Makes a kv pair list of polymorphic to monomorphic mappings, throws runtime
-- error when encountering different structures between the two arguments. error when encountering different structures between the two arguments.
-}
mapTypes :: T.Type -> M.Type -> [(Ident, M.Type)] mapTypes :: T.Type -> M.Type -> [(Ident, M.Type)]
mapTypes (T.TLit _) (M.TLit _) = [] mapTypes (T.TLit _) (M.TLit _) = []
mapTypes (T.TVar (T.MkTVar i1)) tm = [(i1, tm)] mapTypes (T.TVar (T.MkTVar i1)) tm = [(i1, tm)]
mapTypes (T.TFun pt1 pt2) (M.TFun mt1 mt2) = mapTypes pt1 mt1 ++ mapTypes (T.TFun pt1 pt2) (M.TFun mt1 mt2) =
mapTypes pt2 mt2 mapTypes pt1 mt1
mapTypes (T.TData tIdent pTs) (M.TData mIdent mTs) = if tIdent /= mIdent ++ mapTypes pt2 mt2
then error "the data type names of monomorphic and polymorphic data types does not match" mapTypes (T.TData tIdent pTs) (M.TData mIdent mTs) =
else foldl (\xs (p, m) -> mapTypes p m ++ xs) [] (zip pTs mTs) if tIdent /= mIdent
then error "the data type names of monomorphic and polymorphic data types does not match"
else foldl (\xs (p, m) -> mapTypes p m ++ xs) [] (zip pTs mTs)
mapTypes t1 t2 = error $ "structure of types not the same: '" ++ printTree t1 ++ "', '" ++ printTree t2 ++ "'" mapTypes t1 t2 = error $ "structure of types not the same: '" ++ printTree t1 ++ "', '" ++ printTree t2 ++ "'"
-- | Gets the mapped monomorphic type of a polymorphic type in the current context. -- | Gets the mapped monomorphic type of a polymorphic type in the current context.
getMonoFromPoly :: T.Type -> EnvM M.Type getMonoFromPoly :: T.Type -> EnvM M.Type
getMonoFromPoly t = do env <- ask getMonoFromPoly t = do
return $ getMono (polys env) t env <- ask
where return $ getMono (polys env) t
getMono :: Map.Map Ident M.Type -> T.Type -> M.Type where
getMono polys t = case t of getMono :: Map.Map Ident M.Type -> T.Type -> M.Type
(T.TLit ident) -> M.TLit (coerce ident) getMono polys t = case t of
(T.TFun t1 t2) -> M.TFun (getMono polys t1) (getMono polys t2) (T.TLit ident) -> M.TLit (coerce ident)
(T.TVar (T.MkTVar ident)) -> case Map.lookup ident polys of (T.TFun t1 t2) -> M.TFun (getMono polys t1) (getMono polys t2)
Just concrete -> concrete (T.TVar (T.MkTVar ident)) -> case Map.lookup ident polys of
Nothing -> M.TLit (Ident "void") Just concrete -> concrete
--error $ "type not found! type: " ++ show ident ++ ", error in previous compilation steps" Nothing -> M.TLit (Ident "void")
(T.TData ident args) -> M.TData ident (map (getMono polys) args) -- error $ "type not found! type: " ++ show ident ++ ", error in previous compilation steps"
(T.TData ident args) -> M.TData ident (map (getMono polys) args)
-- | If ident not already in env's output, morphed bind to output {- | If ident not already in env's output, morphed bind to output
-- (and all referenced binds within this bind). (and all referenced binds within this bind).
-- Returns the annotated bind name. Returns the annotated bind name.
-}
morphBind :: M.Type -> T.Bind -> EnvM Ident morphBind :: M.Type -> T.Bind -> EnvM Ident
morphBind expectedType b@(T.Bind (Ident str, btype) args (exp, expt)) = morphBind expectedType b@(T.Bind (Ident str, btype) args (exp, expt)) =
local (\env -> env { locals = Set.fromList (map fst args), local
polys = Map.fromList (mapTypes btype expectedType) ( \env ->
}) $ do env
-- The "new name" is used to find out if it is already marked or not. { locals = Set.fromList (map fst args)
let name' = newFuncName expectedType b , polys = Map.fromList (mapTypes btype expectedType)
bindMarked <- isBindMarked (coerce name') }
-- Return with right name if already marked )
if bindMarked then return name' else do $ do
-- Mark so that this bind will not be processed in recursive or cyclic -- The "new name" is used to find out if it is already marked or not.
-- function calls let name' = newFuncName expectedType b
markBind (coerce name') bindMarked <- isBindMarked (coerce name')
expt' <- getMonoFromPoly expt -- Return with right name if already marked
exp' <- morphExp expt' exp if bindMarked
-- Get monomorphic type sof args then return name'
args' <- mapM morphArg args else do
addOutputBind $ M.Bind (coerce name', expectedType) -- Mark so that this bind will not be processed in recursive or cyclic
args' (exp', expt') -- function calls
return name' markBind (coerce name')
expt' <- getMonoFromPoly expt
exp' <- morphExp expt' exp
-- Get monomorphic type sof args
args' <- mapM morphArg args
addOutputBind $
M.Bind
(coerce name', expectedType)
args'
(exp', expt')
return name'
-- | Monomorphizes arguments of a bind. -- | Monomorphizes arguments of a bind.
morphArg :: (Ident, T.Type) -> EnvM (Ident, M.Type) morphArg :: (Ident, T.Type) -> EnvM (Ident, M.Type)
morphArg (ident, t) = do t' <- getMonoFromPoly t morphArg (ident, t) = do
return (ident, t') t' <- getMonoFromPoly t
return (ident, t')
-- | Gets the data bind from the name of a constructor. -- | Gets the data bind from the name of a constructor.
getInputData :: Ident -> EnvM (Maybe T.Data) getInputData :: Ident -> EnvM (Maybe T.Data)
getInputData ident = do env <- ask getInputData ident = do
return $ Map.lookup ident (dataDefs env) env <- ask
return $ Map.lookup ident (dataDefs env)
-- | Monomorphize a constructor using it's global name. Constructors may {- | Monomorphize a constructor using it's global name. Constructors may
-- appear as expressions in the tree, or as patterns in case-expressions. appear as expressions in the tree, or as patterns in case-expressions.
-}
morphCons :: M.Type -> Ident -> EnvM () morphCons :: M.Type -> Ident -> EnvM ()
morphCons expectedType ident = do morphCons expectedType ident = do
maybeD <- getInputData ident maybeD <- getInputData ident
case maybeD of case maybeD of
Nothing -> error $ "identifier '" ++ show ident ++ "' not found" Nothing -> error $ "identifier '" ++ show ident ++ "' not found"
Just d -> do Just d -> do
modify (\output -> Map.insert ident (Data expectedType d) output ) modify (\output -> Map.insert ident (Data expectedType d) output)
-- | Converts literals from input to output tree. -- | Converts literals from input to output tree.
convertLit :: T.Lit -> M.Lit convertLit :: T.Lit -> M.Lit
convertLit (T.LInt v) = M.LInt v convertLit (T.LInt v) = M.LInt v
convertLit (T.LChar v) = M.LChar v convertLit (T.LChar v) = M.LChar v
-- | Monomorphizes an expression, given an expected type. -- | Monomorphizes an expression, given an expected type.
morphExp :: M.Type -> T.Exp -> EnvM M.Exp morphExp :: M.Type -> T.Exp -> EnvM M.Exp
morphExp expectedType exp = case exp of morphExp expectedType exp = case exp of
T.ELit lit -> return $ M.ELit (convertLit lit) T.ELit lit -> return $ M.ELit (convertLit lit)
-- Constructor -- Constructor
T.EInj ident -> do T.EInj ident -> do
return $ M.EVar ident return $ M.EVar ident
T.EApp (e1, _t1) (e2, t2) -> do T.EApp (e1, _t1) (e2, t2) -> do
t2' <- getMonoFromPoly t2 t2' <- getMonoFromPoly t2
e2' <- morphExp t2' e2 e2' <- morphExp t2' e2
e1' <- morphExp (M.TFun t2' expectedType) e1 e1' <- morphExp (M.TFun t2' expectedType) e1
return $ M.EApp (e1', M.TFun t2' expectedType) (e2', t2') return $ M.EApp (e1', M.TFun t2' expectedType) (e2', t2')
T.EAdd (e1, t1) (e2, t2) -> do T.EAdd (e1, t1) (e2, t2) -> do
t1' <- getMonoFromPoly t1 t1' <- getMonoFromPoly t1
t2' <- getMonoFromPoly t2 t2' <- getMonoFromPoly t2
e1' <- morphExp t1' e1 e1' <- morphExp t1' e1
e2' <- morphExp t2' e2 e2' <- morphExp t2' e2
return $ M.EAdd (e1', expectedType) (e2', expectedType) return $ M.EAdd (e1', expectedType) (e2', expectedType)
T.EAbs ident (exp, t) -> local (\env -> env { locals = Set.insert ident (locals env) }) $ do T.EAbs ident (exp, t) -> local (\env -> env{locals = Set.insert ident (locals env)}) $ do
t' <- getMonoFromPoly t t' <- getMonoFromPoly t
morphExp t' exp morphExp t' exp
T.ECase (exp, t) bs -> do T.ECase (exp, t) bs -> do
t' <- getMonoFromPoly t t' <- getMonoFromPoly t
bs' <- mapM morphBranch bs bs' <- mapM morphBranch bs
exp' <- morphExp t' exp exp' <- morphExp t' exp
return $ M.ECase (exp', t') bs' return $ M.ECase (exp', t') bs'
T.EVar ident -> do T.EVar ident -> do
isLocal <- localExists ident isLocal <- localExists ident
if isLocal then do if isLocal
return $ M.EVar (coerce ident) then do
else do return $ M.EVar (coerce ident)
bind <- getInputBind ident else do
case bind of bind <- getInputBind ident
Nothing -> do case bind of
-- This is a constructor Nothing -> do
morphCons expectedType ident -- This is a constructor
return $ M.EVar ident morphCons expectedType ident
Just bind' -> do return $ M.EVar ident
-- New bind to process Just bind' -> do
newBindName <- morphBind expectedType bind' -- New bind to process
return $ M.EVar (coerce newBindName) newBindName <- morphBind expectedType bind'
return $ M.EVar (coerce newBindName)
T.ELet (T.Bind {}) _ -> error "lets not possible yet" T.ELet (T.Bind{}) _ -> error "lets not possible yet"
-- | Monomorphizes case-of branches. -- | Monomorphizes case-of branches.
morphBranch :: T.Branch -> EnvM M.Branch morphBranch :: T.Branch -> EnvM M.Branch
@ -243,28 +274,32 @@ morphPattern p expectedType = case p of
-- | Creates a new identifier for a function with an assigned type. -- | Creates a new identifier for a function with an assigned type.
newFuncName :: M.Type -> T.Bind -> Ident newFuncName :: M.Type -> T.Bind -> Ident
newFuncName t (T.Bind (ident@(Ident bindName), _) _ _) = newFuncName t (T.Bind (ident@(Ident bindName), _) _ _) =
if bindName == "main" if bindName == "main"
then Ident bindName then Ident bindName
else newName t ident else newName t ident
newName :: M.Type -> Ident -> Ident newName :: M.Type -> Ident -> Ident
newName t (Ident str) = Ident $ str ++ "$" ++ newName' t newName t (Ident str) = Ident $ str ++ "$" ++ newName' t
where where
newName' :: M.Type -> String newName' :: M.Type -> String
newName' (M.TLit (Ident str)) = str newName' (M.TLit (Ident str)) = str
newName' (M.TFun t1 t2) = newName' t1 ++ "_" ++ newName' t2 newName' (M.TFun t1 t2) = newName' t1 ++ "_" ++ newName' t2
newName' (M.TData (Ident str) ts) = str ++ foldl (\s t -> s ++ "." ++ newName' t) "" ts newName' (M.TData (Ident str) ts) = str ++ foldl (\s t -> s ++ "." ++ newName' t) "" ts
-- | Monomorphization step. -- | Monomorphization step.
monomorphize :: T.Program -> O.Program monomorphize :: T.Program -> O.Program
monomorphize (T.Program defs) = removeDataTypes $ M.Program (getDefsFromOutput monomorphize (T.Program defs) =
(runEnvM Map.empty (createEnv defs) monomorphize')) removeDataTypes $
where M.Program
monomorphize' :: EnvM () ( getDefsFromOutput
monomorphize' = do (runEnvM Map.empty (createEnv defs) monomorphize')
main <- getMain )
morphBind (M.TLit $ Ident "Int") main where
return () monomorphize' :: EnvM ()
monomorphize' = do
main <- getMain
morphBind (M.TLit $ Ident "Int") main
return ()
-- | Runs and gives the output binds. -- | Runs and gives the output binds.
runEnvM :: Output -> Env -> EnvM () -> Output runEnvM :: Output -> Env -> EnvM () -> Output
@ -272,14 +307,17 @@ runEnvM o env (EnvM stateM) = snd $ runReader (runStateT stateM o) env
-- | Creates the environment based on the input binds. -- | Creates the environment based on the input binds.
createEnv :: [T.Def] -> Env createEnv :: [T.Def] -> Env
createEnv defs = Env { input = Map.fromList bindPairs, createEnv defs =
dataDefs = Map.fromList dataPairs, Env
polys = Map.empty, { input = Map.fromList bindPairs
locals = Set.empty } , dataDefs = Map.fromList dataPairs
where , polys = Map.empty
bindPairs = (map (\b -> (getBindName b, b)) . getBindsFromDefs) defs , locals = Set.empty
dataPairs :: [(Ident, T.Data)] }
dataPairs = (foldl (\acc d@(T.Data _ cs) -> map ((,d) . getConsName) cs ++ acc) [] . getDataFromDefs) defs where
bindPairs = (map (\b -> (getBindName b, b)) . getBindsFromDefs) defs
dataPairs :: [(Ident, T.Data)]
dataPairs = (foldl (\acc d@(T.Data _ cs) -> map ((,d) . getConsName) cs ++ acc) [] . getDataFromDefs) defs
-- | Gets a top-lefel function name. -- | Gets a top-lefel function name.
getBindName :: T.Bind -> Ident getBindName :: T.Bind -> Ident
@ -288,51 +326,64 @@ getBindName (T.Bind (ident, _) _ _) = ident
-- Helper functions -- Helper functions
-- Gets custom data declarations form defs. -- Gets custom data declarations form defs.
getDataFromDefs :: [T.Def] -> [T.Data] getDataFromDefs :: [T.Def] -> [T.Data]
getDataFromDefs = foldl (\bs -> \case getDataFromDefs =
T.DBind _ -> bs foldl
T.DData d -> d:bs) [] ( \bs -> \case
T.DBind _ -> bs
T.DData d -> d : bs
)
[]
getConsName :: T.Inj -> Ident getConsName :: T.Inj -> Ident
getConsName (T.Inj ident _) = ident getConsName (T.Inj ident _) = ident
getBindsFromDefs :: [T.Def] -> [T.Bind] getBindsFromDefs :: [T.Def] -> [T.Bind]
getBindsFromDefs = foldl (\bs -> \case getBindsFromDefs =
T.DBind b -> b:bs foldl
T.DData _ -> bs) [] ( \bs -> \case
T.DBind b -> b : bs
T.DData _ -> bs
)
[]
getDefsFromOutput :: Output -> [M.Def] getDefsFromOutput :: Output -> [M.Def]
getDefsFromOutput o = getDefsFromOutput o =
map M.DBind binds ++ map M.DBind binds
(map (M.DData . snd) . Map.toList) (createNewData dataInput Map.empty) ++ (map (M.DData . snd) . Map.toList) (createNewData dataInput Map.empty)
where where
(binds, dataInput) = splitBindsAndData o (binds, dataInput) = splitBindsAndData o
-- | Splits the output into binds and data declaration components (used in createNewData) -- | Splits the output into binds and data declaration components (used in createNewData)
splitBindsAndData :: Output -> ([M.Bind], [(Ident, M.Type, T.Data)]) splitBindsAndData :: Output -> ([M.Bind], [(Ident, M.Type, T.Data)])
splitBindsAndData output = foldl splitBindsAndData output =
(\(oBinds, oData) (ident, o) -> case o of foldl
Marked -> error "internal bug in monomorphizer" ( \(oBinds, oData) (ident, o) -> case o of
Complete b -> (b:oBinds, oData) Marked -> error "internal bug in monomorphizer"
Data t d -> (oBinds, (ident, t, d):oData)) Complete b -> (b : oBinds, oData)
([], []) Data t d -> (oBinds, (ident, t, d) : oData)
(Map.toList output) )
([], [])
(Map.toList output)
-- | Converts all found constructors to monomorphic data declarations. -- | Converts all found constructors to monomorphic data declarations.
createNewData :: [(Ident, M.Type, T.Data)] -> Map.Map Ident M.Data -> Map.Map Ident M.Data createNewData :: [(Ident, M.Type, T.Data)] -> Map.Map Ident M.Data -> Map.Map Ident M.Data
createNewData [] o = o createNewData [] o = o
createNewData ((consIdent, consType, polyData):input) o = createNewData ((consIdent, consType, polyData) : input) o =
createNewData input $ createNewData input $
Map.insertWith (\_ (M.Data _ cs) -> M.Data newDataType (newCons:cs)) Map.insertWith
newDataName (M.Data newDataType [newCons]) o (\_ (M.Data _ cs) -> M.Data newDataType (newCons : cs))
where newDataName
T.Data (T.TData polyDataIdent _) _ = polyData (M.Data newDataType [newCons])
newDataType = getDataType consType o
newDataName = newName newDataType polyDataIdent where
newCons = M.Inj consIdent consType T.Data (T.TData polyDataIdent _) _ = polyData
newDataType = getDataType consType
newDataName = newName newDataType polyDataIdent
newCons = M.Inj consIdent consType
-- | Gets the Data Type of a constructor type (a -> Just a becomes Just a). -- | Gets the Data Type of a constructor type (a -> Just a becomes Just a).
getDataType :: M.Type -> M.Type getDataType :: M.Type -> M.Type
getDataType (M.TFun t1 t2) = getDataType t2 getDataType (M.TFun t1 t2) = getDataType t2
getDataType tData@(M.TData _ _) = tData getDataType tData@(M.TData _ _) = tData
getDataType _ = error "???" getDataType _ = error "???"

View file

@ -2,16 +2,15 @@
module TypeChecker.ReportTEVar where module TypeChecker.ReportTEVar where
import Auxiliary (onM) import Auxiliary (onM)
import Control.Applicative (Applicative (liftA2), liftA3) import Control.Applicative (Applicative (liftA2), liftA3)
import Control.Monad.Except (MonadError (throwError)) import Control.Monad.Except (MonadError (throwError))
import Data.Coerce (coerce) import Data.Coerce (coerce)
import Data.Tuple.Extra (secondM) import Data.Tuple.Extra (secondM)
import qualified Grammar.Abs as G import Grammar.Abs qualified as G
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
import Grammar.Print (printTree) import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr hiding (Type (..)) import TypeChecker.TypeCheckerIr hiding (Type (..))
data Type data Type
= TLit Ident = TLit Ident
@ -30,20 +29,20 @@ instance ReportTEVar (Program' G.Type) (Program' Type) where
instance ReportTEVar (Def' G.Type) (Def' Type) where instance ReportTEVar (Def' G.Type) (Def' Type) where
reportTEVar = \case reportTEVar = \case
DBind bind -> DBind <$> reportTEVar bind DBind bind -> DBind <$> reportTEVar bind
DData dat -> DData <$> reportTEVar dat DData dat -> DData <$> reportTEVar dat
instance ReportTEVar (Bind' G.Type) (Bind' Type) where instance ReportTEVar (Bind' G.Type) (Bind' Type) where
reportTEVar (Bind id vars rhs) = liftA3 Bind (reportTEVar id) (reportTEVar vars) (reportTEVar rhs) reportTEVar (Bind id vars rhs) = liftA3 Bind (reportTEVar id) (reportTEVar vars) (reportTEVar rhs)
instance ReportTEVar (Exp' G.Type) (Exp' Type) where instance ReportTEVar (Exp' G.Type) (Exp' Type) where
reportTEVar exp = case exp of reportTEVar exp = case exp of
EVar name -> pure $ EVar name EVar name -> pure $ EVar name
EInj name -> pure $ EInj name EInj name -> pure $ EInj name
ELit lit -> pure $ ELit lit ELit lit -> pure $ ELit lit
ELet bind e -> liftA2 ELet (reportTEVar bind) (reportTEVar e) ELet bind e -> liftA2 ELet (reportTEVar bind) (reportTEVar e)
EApp e1 e2 -> onM EApp reportTEVar e1 e2 EApp e1 e2 -> onM EApp reportTEVar e1 e2
EAdd e1 e2 -> onM EAdd reportTEVar e1 e2 EAdd e1 e2 -> onM EAdd reportTEVar e1 e2
EAbs name e -> EAbs name <$> reportTEVar e EAbs name e -> EAbs name <$> reportTEVar e
ECase e branches -> liftA2 ECase (reportTEVar e) (reportTEVar branches) ECase e branches -> liftA2 ECase (reportTEVar e) (reportTEVar branches)
instance ReportTEVar (Branch' G.Type) (Branch' Type) where instance ReportTEVar (Branch' G.Type) (Branch' Type) where
@ -54,10 +53,10 @@ instance ReportTEVar (Pattern' G.Type, G.Type) (Pattern' Type, Type) where
instance ReportTEVar (Pattern' G.Type) (Pattern' Type) where instance ReportTEVar (Pattern' G.Type) (Pattern' Type) where
reportTEVar = \case reportTEVar = \case
PVar name -> pure $ PVar name PVar name -> pure $ PVar name
PLit lit -> pure $ PLit lit PLit lit -> pure $ PLit lit
PCatch -> pure PCatch PCatch -> pure PCatch
PEnum name -> pure $ PEnum name PEnum name -> pure $ PEnum name
PInj name ps -> PInj name <$> reportTEVar ps PInj name ps -> PInj name <$> reportTEVar ps
instance ReportTEVar (Data' G.Type) (Data' Type) where instance ReportTEVar (Data' G.Type) (Data' Type) where
@ -77,9 +76,9 @@ instance ReportTEVar a b => ReportTEVar [a] [b] where
instance ReportTEVar G.Type Type where instance ReportTEVar G.Type Type where
reportTEVar = \case reportTEVar = \case
G.TLit lit -> pure $ TLit (coerce lit) G.TLit lit -> pure $ TLit (coerce lit)
G.TVar (G.MkTVar i) -> pure $ TVar (MkTVar $ coerce i) G.TVar (G.MkTVar i) -> pure $ TVar (MkTVar $ coerce i)
G.TData name typs -> TData (coerce name) <$> reportTEVar typs G.TData name typs -> TData (coerce name) <$> reportTEVar typs
G.TFun t1 t2 -> liftA2 TFun (reportTEVar t1) (reportTEVar t2) G.TFun t1 t2 -> liftA2 TFun (reportTEVar t1) (reportTEVar t2)
G.TAll (G.MkTVar i) t -> TAll (MkTVar $ coerce i) <$> reportTEVar t G.TAll (G.MkTVar i) t -> TAll (MkTVar $ coerce i) <$> reportTEVar t
G.TEVar tevar -> throwError ("Found TEVar: " ++ printTree tevar) G.TEVar tevar -> throwError ("Found TEVar: " ++ printTree tevar)