Merge branch 'prep-tc-martin' of github.com:bachelor-group-66-systemf/language into prep-tc-martin

This commit is contained in:
Samuel Hammersberg 2023-02-18 15:03:11 +01:00
commit a4c12ede79
8 changed files with 590 additions and 640 deletions

View file

@ -1,19 +1,15 @@
Program. Program ::= [Bind];
EId. Exp3 ::= Ident;
EInt. Exp3 ::= Integer;
ELet. Exp3 ::= "let" [Bind] "in" Exp;
EAnn. Exp3 ::= "(" Exp ":" Type ")";
ELet. Exp3 ::= "let" Bind "in" Exp;
EApp. Exp2 ::= Exp2 Exp3;
EAdd. Exp1 ::= Exp1 "+" Exp2;
EAbs. Exp ::= "\\" Ident ":" Type "." Exp;
EAnn. Exp3 ::= "(" Exp ":" Type ")";
ECase. Exp ::= "case" Exp "of" "{" [CaseMatch] "}";
--
CaseMatch. CaseMatch ::= Case "=>" Exp ;
separator CaseMatch ",";
--terminator CaseMatch ".";
CInt. Case ::= Integer ;

View file

@ -3,18 +3,19 @@
-- tripplemagic x y z = ((\x:Int. x+x) x) + y + z;
-- main : Int;
-- main = tripplemagic ((\x:Int. x+x+3) ((\x:Int. x) 2)) 5 3
-- answer: 22
-- apply : (Int -> Int) -> Int -> Int;
-- apply f x = f x;
--
-- main : Int;
--main = (\x : Int . x + 5) 5
-- main = apply (\x : Int . x + 5) 5
-- answer: 10
apply : (Int -> Int -> Int) -> Int -> Int -> Int;
apply f x y = f x y;
krimp: Int -> Int -> Int;
krimp x y = x + y;
main : Int;
main = apply (krimp) 2 3;--apply (\y: Int . (\x: Int . x + y + 2)) 5 2;
main = apply (krimp) 2 3;
-- answer: 5

View file

@ -3,22 +3,21 @@
module Compiler (compile) where
import Auxiliary (snoc)
import Control.Monad.State (StateT, execStateT, gets, modify)
import Data.List.Extra (trim)
--import Data.List.Extra (trim)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Tuple.Extra (second)
import Data.Tuple.Extra (dupe, first, second)
import Grammar.ErrM (Err)
import Grammar.Print (printTree)
import LlvmIr (LLVMComp (..), LLVMIr (..),
LLVMType (..), LLVMValue (..),
Visibility (..), llvmIrToString)
import System.IO (stdin)
import System.Process.Extra (CreateProcess (std_in),
StdStream (CreatePipe), createProcess,
readCreateProcess, shell)
import LlvmIr (LLVMComp (..), LLVMIr (..), LLVMType (..),
LLVMValue (..), Visibility (..),
llvmIrToString)
--import System.Process.Extra (readCreateProcess, shell)
import TypeChecker (partitionType)
import TypeCheckerIr
import TypeCheckerIr (Bind (..), CLit (CInt, CatchAll),
Case (..), Exp (..), Id, Ident (..),
Program (..), Type (TFun, TInt))
-- | The record used as the code generator state
data CodeGenerator = CodeGenerator
@ -38,11 +37,11 @@ data FunctionInfo = FunctionInfo
-- | Adds a instruction to the CodeGenerator state
emit :: LLVMIr -> CompilerState ()
emit l = modify (\t -> t{instructions = instructions t ++ [l]})
emit l = modify $ \t -> t { instructions = Auxiliary.snoc l $ instructions t }
-- | Increases the variable counter in the CodeGenerator state
increaseVarCount :: CompilerState ()
increaseVarCount = modify (\t -> t{variableCount = variableCount t + 1})
increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 }
-- | Returns the variable count from the CodeGenerator state
getVarCount :: CompilerState Integer
@ -58,69 +57,80 @@ getNewLabel = do
modify (\t -> t{labelCount = labelCount t + 1})
gets labelCount
{- | Produces a map of functions infos from a list of binds,
which contains useful data for code generation.
-}
-- | Produces a map of functions infos from a list of binds,
-- which contains useful data for code generation.
getFunctions :: [Bind] -> Map Id FunctionInfo
getFunctions xs =
Map.fromList $
map
( \(Bind id args _) ->
( id
, FunctionInfo
{ numArgs = length args
, arguments = args
}
)
)
xs
getFunctions bs = Map.fromList $ map go bs
where
go (Bind id args _) =
(id, FunctionInfo { numArgs=length args, arguments=args })
run :: Err String -> IO ()
run s = do
let s' = case s of
Right s -> s
Left _ -> error "yo"
writeFile "llvm.ll" s'
putStrLn . trim =<< readCreateProcess (shell "lli") s'
test :: Integer -> Program
test v = Program [
Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] (
ECased (EId ("x", TInt)) [
Case (CInt 0) (EInt 0),
Case (CInt 1) (EInt 1),
Case CatchAll (EAdd TInt
(EApp TInt (EId (Ident "fibonacci", TInt)) (
EAdd TInt (EId (Ident "x", TInt))
(EInt (fromIntegral ((maxBound :: Int) * 2)))
))
(EApp TInt (EId (Ident "fibonacci", TInt)) (
EAdd TInt (EId (Ident "x", TInt))
(EInt (fromIntegral ((maxBound :: Int) * 2 + 1)))
))
)
]
),
Bind (Ident "main",TInt) [] (
EApp TInt (EId (Ident "fibonacci", TInt)) (EInt v) -- (EInt 92)
)
]
initCodeGenerator :: [Bind] -> CodeGenerator
initCodeGenerator scs = CodeGenerator { instructions = defaultStart
, functions = getFunctions scs
, variableCount = 0
, labelCount = 0
}
--run :: Err String -> IO ()
--run s = do
-- let s' = case s of
-- Right s -> s
-- Left _ -> error "yo"
-- writeFile "llvm.ll" s'
-- putStrLn . trim =<< readCreateProcess (shell "lli") s'
--
--test :: Integer -> Program
--test v = Program [
-- Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] (
-- ECased (EId ("x", TInt)) [
-- Case (CInt 0) (EInt 0),
-- Case (CInt 1) (EInt 1),
-- Case CatchAll (EAdd TInt
-- (EApp TInt (EId (Ident "fibonacci", TInt)) (
-- EAdd TInt (EId (Ident "x", TInt))
-- (EInt (fromIntegral ((maxBound :: Int) * 2)))
-- ))
-- (EApp TInt (EId (Ident "fibonacci", TInt)) (
-- EAdd TInt (EId (Ident "x", TInt))
-- (EInt (fromIntegral ((maxBound :: Int) * 2 + 1)))
-- ))
-- )
-- ]
-- ),
-- Bind (Ident "main",TInt) [] (
-- EApp TInt (EId (Ident "fibonacci", TInt)) (EInt v) -- (EInt 92)
-- )
-- ]
{- | Compiles an AST and produces a LLVM Ir string.
An easy way to actually "compile" this output is to
Simply pipe it to LLI
-}
compile :: Program -> Err String
compile (Program prg) = do
let s =
CodeGenerator
{ instructions = defaultStart
, functions = getFunctions prg
, variableCount = 0
, labelCount = 0
}
ins <- instructions <$> execStateT (goDef prg) s
pure $ llvmIrToString ins
compile (Program scs) = do
let codegen = initCodeGenerator scs
llvmIrToString . instructions <$> execStateT (compileScs scs) codegen
compileScs :: [Bind] -> CompilerState ()
compileScs [] = pure ()
compileScs (Bind (name, t) args exp : xs) = do
emit $ UnsafeRaw "\n"
emit . Comment $ show name <> ": " <> show exp
let args' = map (second type2LlvmType) args
emit $ Define (type2LlvmType t_return) name args'
functionBody <- exprToValue exp
if name == "main"
then mapM_ emit $ mainContent functionBody
else emit $ Ret I64 functionBody
emit DefineEnd
modify $ \s -> s { variableCount = 0 }
compileScs xs
where
t_return = snd $ partitionType (length args) t
mainContent :: LLVMValue -> [LLVMIr]
mainContent var =
[ UnsafeRaw $
@ -140,37 +150,18 @@ compile (Program prg) = do
]
defaultStart :: [LLVMIr]
defaultStart =
[ Comment (show $ printTree (Program prg))
, UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
defaultStart = [ UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
]
goDef :: [Bind] -> CompilerState ()
goDef [] = return ()
goDef (Bind (name, t) args exp : xs) = do
emit $ UnsafeRaw "\n"
emit $ Comment $ show name <> ": " <> show exp
emit $ Define (I64{-type2LlvmType t_return-}) name (map (second type2LlvmType) args)
functionBody <- exprToValue exp
if name == "main"
then mapM_ emit (mainContent functionBody)
else emit $ Ret I64 functionBody
emit DefineEnd
modify (\s -> s{variableCount = 0})
goDef xs
where
t_return = snd $ partitionType (length args) t
go :: Exp -> CompilerState ()
go (EInt int) = emitInt int
go (EAdd t e1 e2) = emitAdd t e1 e2
go (EId (name, _)) = emitIdent name
go (EApp t e1 e2) = emitApp t e1 e2
go (EAbs t ti e) = emitAbs t ti e
go (ELet binds e) = emitLet binds e
go (EAnn _ _) = emitEAnn
go (ECased e c) = emitECased e c
compileExp :: Exp -> CompilerState ()
compileExp (EInt int) = emitInt int
compileExp (EAdd t e1 e2) = emitAdd t e1 e2
compileExp (EId (name, _)) = emitIdent name
compileExp (EApp t e1 e2) = emitApp t e1 e2
compileExp (EAbs t ti e) = emitAbs t ti e
compileExp (ELet binds e) = emitLet binds e
compileExp (ECased e c) = emitECased e c
-- go (ESub e1 e2) = emitSub e1 e2
-- go (EMul e1 e2) = emitMul e1 e2
-- go (EDiv e1 e2) = emitDiv e1 e2
@ -209,14 +200,11 @@ compile (Program prg) = do
emit $ Br label
emitEAnn :: CompilerState ()
emitEAnn = emit . UnsafeRaw $ "Annotated escaped previous stages"
emitAbs :: Type -> Id -> Exp -> CompilerState ()
emitAbs _t tid e = do
emit . Comment $
"Lambda escaped previous stages: \\" <> show tid <> " . " <> show e
emitLet :: [Bind] -> Exp -> CompilerState ()
emitLet :: Bind -> Exp -> CompilerState ()
emitLet xs e = do
emit $
Comment $
@ -240,10 +228,9 @@ compile (Program prg) = do
args <- traverse exprToValue newStack
vs <- getNewVar
funcs <- gets functions
let vis = case Map.lookup id funcs of
Nothing -> Local
Just _ -> Global
let call = Call (type2LlvmType t) vis name ((\x -> (valueGetType x, x)) <$> args)
let visibility = maybe Local (const Global) $ Map.lookup id funcs
args' = map (first valueGetType . dupe) args
call = Call (type2LlvmType t) visibility name args'
emit $ SetVariable (Ident $ show vs) call
x -> do
emit . Comment $ "The unspeakable happened: "
@ -316,22 +303,26 @@ compile (Program prg) = do
-- emit $ Sub I64 v1 v2
exprToValue :: Exp -> CompilerState LLVMValue
exprToValue (EInt i) = return $ VInteger i
exprToValue (EId id@(name, t)) = do
exprToValue = \case
EInt i -> pure $ VInteger i
EId id@(name, t) -> do
funcs <- gets functions
case Map.lookup id funcs of
Just fi -> do
if numArgs fi == 0
then do
vc <- getNewVar
emit $ SetVariable (Ident $ show vc) (Call (type2LlvmType t) Global name [])
return $ VIdent (Ident $ show vc) (type2LlvmType t)
else return $ VFunction name Global (type2LlvmType t)
Nothing -> return $ VIdent name (type2LlvmType t)
exprToValue e = do
go e
emit $ SetVariable (Ident $ show vc)
(Call (type2LlvmType t) Global name [])
pure $ VIdent (Ident $ show vc) (type2LlvmType t)
else pure $ VFunction name Global (type2LlvmType t)
Nothing -> pure $ VIdent name (type2LlvmType t)
e -> do
compileExp e
v <- getVarCount
return $ VIdent (Ident $ show v) (getType e)
pure $ VIdent (Ident $ show v) (getType e)
type2LlvmType :: Type -> LLVMType
type2LlvmType = \case
@ -352,7 +343,7 @@ getType (EId (_, t)) = type2LlvmType t
getType (EApp t _ _) = type2LlvmType t
getType (EAbs t _ _) = type2LlvmType t
getType (ELet _ e) = getType e
getType (EAnn _ t) = type2LlvmType t
getType (ECased e cs) = undefined
valueGetType :: LLVMValue -> LLVMType
valueGetType (VInteger _) = I64

View file

@ -7,16 +7,18 @@ module LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where
import Auxiliary (snoc)
import Control.Applicative (Applicative (liftA2))
import Control.Monad.State (MonadState (get, put), State, evalState)
import Data.Foldable.Extra (notNull)
import Data.List (mapAccumL, partition)
import Data.Set (Set, (\\))
import Data.Set (Set)
import qualified Data.Set as Set
import Prelude hiding (exp)
import Renamer hiding (fromBinders)
import Renamer
import TypeCheckerIr
-- | Lift lambdas and let expression into supercombinators.
-- Three phases:
-- @freeVars@ annotatss all the free variables.
-- @abstract@ converts lambdas into let expressions.
-- @collectScs@ moves every non-constant let expression to a top-level function.
lambdaLift :: Program -> Program
lambdaLift = collectScs . abstract . freeVars
@ -29,7 +31,6 @@ freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e)
freeVarsExp :: Set Id -> Exp -> AnnExp
freeVarsExp localVars = \case
EId n | Set.member n localVars -> (Set.singleton n, AId n)
| otherwise -> (mempty, AId n)
@ -49,35 +50,22 @@ freeVarsExp localVars = \case
where
e' = freeVarsExp (Set.insert par localVars) e
-- Sum free variables present in binders and the expression
ELet binders e -> (Set.union binders_frees e_free, ALet binders' e')
-- Sum free variables present in bind and the expression
ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e')
where
binders_frees = rhss_frees \\ names_set
e_free = freeVarsOf e' \\ names_set
binders_frees = Set.delete name $ freeVarsOf rhs'
e_free = Set.delete name $ freeVarsOf e'
rhss_frees = foldr1 Set.union (map freeVarsOf rhss')
names_set = Set.fromList names
rhs' = freeVarsExp e_localVars rhs
new_bind = ABind name parms rhs'
(names, parms, rhss) = fromBinders binders
rhss' = map (freeVarsExp e_localVars) rhss
e_localVars = Set.union localVars names_set
binders' = zipWith3 ABind names parms rhss'
e' = freeVarsExp e_localVars e
EAnn e t -> (freeVarsOf e', AAnn e' t)
where
e' = freeVarsExp localVars e
e_localVars = Set.insert name localVars
freeVarsOf :: AnnExp -> Set Id
freeVarsOf = fst
fromBinders :: [Bind] -> ([Id], [[Id]], [Exp])
fromBinders bs = unzip3 [ (name, parms, rhs) | Bind name parms rhs <- bs ]
-- AST annotated with free variables
type AnnProgram = [(Id, [Id], AnnExp)]
@ -87,14 +75,11 @@ data ABind = ABind Id [Id] AnnExp deriving Show
data AnnExp' = AId Id
| AInt Integer
| ALet [ABind] AnnExp
| ALet ABind AnnExp
| AApp Type AnnExp AnnExp
| AAdd Type AnnExp AnnExp
| AAbs Type Id AnnExp
| AAnn AnnExp Type
deriving Show
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
-- Free variables are @v₁ v₂ .. vₙ@ are bound.
abstract :: AnnProgram -> Program
@ -124,7 +109,7 @@ abstractExp (free, exp) = case exp of
AInt i -> pure $ EInt i
AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2)
AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2)
ALet bs e -> liftA2 ELet (mapM go bs) (abstractExp e)
ALet b e -> liftA2 ELet (go b) (abstractExp e)
where
go (ABind name parms rhs) = do
(rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
@ -141,14 +126,13 @@ abstractExp (free, exp) = case exp of
rhs <- abstractExp e
let sc_name = Ident ("sc_" ++ show i)
sc = ELet [Bind (sc_name, t) parms rhs] $ EId (sc_name, t)
sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t)
pure $ foldl (EApp TInt) sc $ map EId freeList
where
freeList = Set.toList free
parms = snoc parm freeList
AAnn e t -> abstractExp e >>= \e' -> pure $ EAnn e' t
nextNumber :: State Int Int
nextNumber = do
@ -156,7 +140,7 @@ nextNumber = do
put $ succ i
pure i
-- | Collects supercombinators by lifting appropriate let expressions
-- | Collects supercombinators by lifting non-constant let expressions
collectScs :: Program -> Program
collectScs (Program scs) = Program $ concatMap collectFromRhs scs
where
@ -184,32 +168,18 @@ collectScsExp = \case
where
(scs, e') = collectScsExp e
-- Collect supercombinators from binds, the rhss, and the expression.
-- Collect supercombinators from bind, the rhss, and the expression.
--
-- > f = let
-- > sc = rhs
-- > sc1 = rhs1
-- > ...
-- > in e
-- > f = let sc x y = rhs in e
--
ELet binds e -> (binds_scs ++ rhss_scs ++ e_scs, mkEAbs non_scs' e')
ELet (Bind name parms rhs) e -> if null parms
then ( rhs_scs ++ e_scs, ELet bind e')
else (bind : rhs_scs ++ e_scs, e')
where
binds_scs = [ let (rhs', parms1) = flattenLambdas rhs in
Bind n (parms ++ parms1) rhs'
| Bind n parms rhs <- scs'
]
(rhss_scs, binds') = mapAccumL collectScsRhs [] binds
bind = Bind name parms rhs'
(rhs_scs, rhs') = collectScsExp rhs
(e_scs, e') = collectScsExp e
(scs', non_scs') = partition (\(Bind _ pars _) -> notNull pars) binds'
collectScsRhs acc (Bind n xs rhs) = (acc ++ rhs_scs, Bind n xs rhs')
where
(rhs_scs, rhs') = collectScsExp rhs
EAnn e t -> (scs, EAnn e' t)
where
(scs, e') = collectScsExp e
-- @\x.\y.\z. e → (e, [x,y,z])@
flattenLambdas :: Exp -> (Exp, [Id])
@ -218,7 +188,3 @@ flattenLambdas = go . (, [])
go (e, acc) = case e of
EAbs _ par e1 -> go (e1, snoc par acc)
_ -> (e, acc)
mkEAbs :: [Bind] -> Exp -> Exp
mkEAbs [] e = e
mkEAbs bs e = ELet bs e

View file

@ -68,9 +68,8 @@ instance Show Visibility where
show Local = "%"
show Global = "@"
{- | Represents a LLVM "value", as in an integer, a register variable,
or a string contstant
-}
-- | Represents a LLVM "value", as in an integer, a register variable,
-- or a string contstant
data LLVMValue
= VInteger Integer
| VIdent Ident LLVMType

View file

@ -2,82 +2,83 @@
module Renamer (module Renamer) where
import Data.List (mapAccumL, unzip4, zipWith4)
import Auxiliary (mapAccumM)
import Control.Monad.State (MonadState, State, evalState, gets,
modify)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe)
import Grammar.Abs
-- | Rename all supercombinators and variables
-- | Rename all variables and local binds
rename :: Program -> Program
rename (Program sc) = Program $ map (renameSc 0) sc
rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0
where
renameSc i (Bind n t _ xs e) = Bind n t n xs' e'
where
(i1, xs', env) = newNames i xs
e' = snd $ renameExp env i1 e
renameExp :: Map Ident Ident -> Int -> Exp -> (Int, Exp)
renameExp env i = \case
EId n -> (i, EId . fromMaybe n $ Map.lookup n env)
EInt i1 -> (i, EInt i1)
EApp e1 e2 -> (i2, EApp e1' e2')
where
(i1, e1') = renameExp env i e1
(i2, e2') = renameExp env i1 e2
EAdd e1 e2 -> (i2, EAdd e1' e2')
where
(i1, e1') = renameExp env i e1
(i2, e2') = renameExp env i1 e2
ELet bs e -> (i3, ELet (zipWith4 mkBind names' types pars' es') e')
where
mkBind name t = Bind name t name
(i1, e') = renameExp e_env i e
(names, types, pars, rhss) = fromBinders bs
(i2, names', env') = newNames i1 (names ++ concat pars)
pars' = (map . map) renamePar pars
e_env = Map.union env' env
(i3, es') = mapAccumL (renameExp e_env) i2 rhss
renamePar p = case Map.lookup p env' of
Just p' -> p'
Nothing -> error ("Can't find name for " ++ show p)
initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs
renameSc :: Names -> Bind -> Rn Bind
renameSc old_names (Bind name t _ parms rhs) = do
(new_names, parms') <- newNames old_names parms
rhs' <- snd <$> renameExp new_names rhs
pure $ Bind name t name parms' rhs'
EAbs par t e -> (i2, EAbs par' t e')
where
(i1, par', env') = newName par
(i2, e') = renameExp (Map.union env' env ) i1 e
-- | Rename monad. State holds the number of renamed names.
newtype Rn a = Rn { runRn :: State Int a }
deriving (Functor, Applicative, Monad, MonadState Int)
EAnn e t -> (i1, EAnn e' t)
where
(i1, e') = renameExp env i e
-- | Maps old to new name
type Names = Map Ident Ident
renameLocalBind :: Names -> Bind -> Rn (Names, Bind)
renameLocalBind old_names (Bind name t _ parms rhs) = do
(new_names, name') <- newName old_names name
(new_names', parms') <- newNames new_names parms
(new_names'', rhs') <- renameExp new_names' rhs
pure (new_names'', Bind name' t name' parms' rhs')
newName :: Ident -> (Int, Ident, Map Ident Ident)
newName old_name = (i, head names, env)
where (i, names, env) = newNames 1 [old_name]
renameExp :: Names -> Exp -> Rn (Names, Exp)
renameExp old_names = \case
EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names)
newNames :: Int -> [Ident] -> (Int, [Ident], Map Ident Ident)
newNames i old_names = (i', new_names, env)
where
(i', new_names) = getNames i old_names
env = Map.fromList $ zip old_names new_names
EInt i1 -> pure (old_names, EInt i1)
getNames :: Int -> [Ident] -> (Int, [Ident])
getNames i ns = (i + length ss, zipWith makeName ss [i..])
where
ss = map (\(Ident s) -> s) ns
EApp e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EApp e1' e2')
makeName :: String -> Int -> Ident
makeName prefix i = Ident (prefix ++ "_" ++ show i)
EAdd e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EAdd e1' e2')
ELet b e -> do
(new_names, b) <- renameLocalBind old_names b
(new_names', e') <- renameExp new_names e
pure (new_names', ELet b e')
EAbs par t e -> do
(new_names, par') <- newName old_names par
(new_names', e') <- renameExp new_names e
pure (new_names', EAbs par' t e')
EAnn e t -> do
(new_names, e') <- renameExp old_names e
pure (new_names, EAnn e' t)
-- | Create a new name and add it to name environment.
newName :: Names -> Ident -> Rn (Names, Ident)
newName env old_name = do
new_name <- makeName old_name
pure (Map.insert old_name new_name env, new_name)
-- | Create multiple names and add them to the name environment
newNames :: Names -> [Ident] -> Rn (Names, [Ident])
newNames = mapAccumM newName
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
makeName :: Ident -> Rn Ident
makeName (Ident prefix) = gets (\i -> Ident $ prefix ++ "_" ++ show i) <* modify succ
fromBinders :: [Bind] -> ([Ident], [Type], [[Ident]], [Exp])
fromBinders bs = unzip4 [ (name, t, parms, rhs) | Bind name t _ parms rhs <- bs ]

View file

@ -14,7 +14,6 @@ import Grammar.Print (Print (prt), concatD, doc, printTree,
import Prelude hiding (exp, id)
import qualified TypeCheckerIr as T
-- NOTE: this type checker is poorly tested
-- TODO
@ -22,8 +21,8 @@ import qualified TypeCheckerIr as T
-- Type inference
data Cxt = Cxt
{ env :: Map Ident Type
, sig :: Map Ident Type
{ env :: Map Ident Type -- ^ Local scope signature
, sig :: Map Ident Type -- ^ Top-level signatures
}
initCxt :: [Bind] -> Cxt
@ -34,30 +33,27 @@ initCxt sc = Cxt { env = mempty
typecheck :: Program -> Err T.Program
typecheck (Program sc) = T.Program <$> mapM (checkBind $ initCxt sc) sc
-- | Check if infered rhs type matches type signature.
checkBind :: Cxt -> Bind -> Err T.Bind
checkBind cxt b =
case expandLambdas b of
Bind name t _ parms rhs -> do
(rhs', t_rhs) <- infer cxt rhs
unless (typeEq t_rhs t) . throwError $ typeErr name t t_rhs
pure $ T.Bind (name, t) (zip parms ts_parms) rhs'
where
ts_parms = fst $ partitionType (length parms) t
-- | @ f x y = rhs ⇒ f = \x.\y. rhs @
expandLambdas :: Bind -> Bind
expandLambdas (Bind name t _ parms rhs) = Bind name t name [] rhs'
where
rhs' = foldr ($) rhs $ zipWith EAbs parms ts_parms
ts_parms = fst $ partitionType (length parms) t
-- | Infer type of expression.
infer :: Cxt -> Exp -> Err (T.Exp, Type)
infer cxt = \case
EId x ->
case lookupEnv x cxt of
Nothing ->
@ -87,19 +83,19 @@ infer cxt = \case
let t_abs = TFun t t1
pure (T.EAbs t_abs (x, t) e', t_abs)
ELet bs e -> do
bs'' <- mapM (checkBind cxt') bs'
ELet b e -> do
let cxt' = insertBind b cxt
b' <- checkBind cxt' b
(e', t) <- infer cxt' e
pure (T.ELet bs'' e', t)
where
bs' = map expandLambdas bs
cxt' = foldr (\(Bind n t _ _ _) -> insertEnv n t) cxt bs'
pure (T.ELet b' e', t)
EAnn e t -> do
e' <- check cxt e t
pure (T.EAnn e' t, t)
(e', t1) <- infer cxt e
unless (typeEq t t1) $
throwError "Inferred type and type annotation doesn't match"
pure (e', t1)
-- | Check infered type matches the supplied type.
check :: Cxt -> Exp -> Type -> Err T.Exp
check cxt exp typ = case exp of
@ -109,9 +105,7 @@ check cxt exp typ = case exp of
("Unbound variable:" ++ printTree x)
(lookupSig x cxt)
Just t -> pure t
unless (typeEq t typ) . throwError $ typeErr x typ t
pure $ T.EId (x, t)
EInt i -> do
@ -137,25 +131,26 @@ check cxt exp typ = case exp of
unless (typeEq t1 typ) $ throwError "Wrong lamda type!"
pure $ T.EAbs t1 (x, t) e'
ELet bs e -> do
bs'' <- mapM (checkBind cxt') bs'
ELet b e -> do
let cxt' = insertBind b cxt
b' <- checkBind cxt' b
e' <- check cxt' e typ
pure $ T.ELet bs'' e'
where
bs' = map expandLambdas bs
cxt' = foldr (\(Bind n t _ _ _) -> insertEnv n t) cxt bs'
pure $ T.ELet b' e'
EAnn e t -> do
unless (typeEq t typ) $
throwError "Inferred type and type annotation doesn't match"
e' <- check cxt e t
pure $ T.EAnn e' typ
check cxt e t
-- | Check if types are equivalent. Doesn't handle coercion or polymorphism.
typeEq :: Type -> Type -> Bool
typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1
typeEq t t1 = t == t1
partitionType :: Int -> Type -> ([Type], Type)
-- | Partion type into types of parameters and return type.
partitionType :: Int -- Number of parameters to apply
-> Type
-> ([Type], Type)
partitionType = go []
where
go acc 0 t = (acc, t)
@ -163,6 +158,9 @@ partitionType = go []
TFun t1 t2 -> go (snoc t1 acc) (i - 1) t2
_ -> error "Number of parameters and type doesn't match"
insertBind :: Bind -> Cxt -> Cxt
insertBind (Bind n t _ _ _) = insertEnv n t
lookupEnv :: Ident -> Cxt -> Maybe Type
lookupEnv x = Map.lookup x . env

View file

@ -16,13 +16,18 @@ newtype Program = Program [Bind]
data Exp
= EId Id
| EInt Integer
| ELet [Bind] Exp
| ELet Bind Exp
| EApp Type Exp Exp
| EAdd Type Exp Exp
| EAbs Type Id Exp
| EAnn Exp Type
| ECased Exp [Case]
deriving (C.Eq, C.Ord, C.Show, C.Read)
data Case = Case CLit Exp
deriving (C.Eq, C.Ord, C.Show, C.Read)
data CLit = CInt Integer | CatchAll
deriving (C.Eq, C.Ord, C.Show, C.Read)
type Id = (Ident, Type)
data Bind = Bind Id [Id] Exp
@ -97,12 +102,5 @@ instance Print Exp where
, doc $ showString "."
, prt 0 e
]
EAnn e t -> prPrec i 3 $ concatD
[ doc $ showString "("
, prt 0 e
, doc $ showString ":"
, prt 0 t
, doc $ showString ")"
]