Change grammar: only one bind in let and no EAnn for typed syntax

This commit is contained in:
Martin Fredin 2023-02-18 12:57:23 +01:00
parent 7cedc2e28c
commit a3e57dde7b
7 changed files with 172 additions and 228 deletions

View file

@ -1,24 +1,20 @@
Program. Program ::= [Bind]; Program. Program ::= [Bind];
EId. Exp3 ::= Ident;
EId. Exp3 ::= Ident; EInt. Exp3 ::= Integer;
EInt. Exp3 ::= Integer; EAnn. Exp3 ::= "(" Exp ":" Type ")";
ELet. Exp3 ::= "let" [Bind] "in" Exp; ELet. Exp3 ::= "let" Bind "in" Exp;
EApp. Exp2 ::= Exp2 Exp3; EApp. Exp2 ::= Exp2 Exp3;
EAdd. Exp1 ::= Exp1 "+" Exp2; EAdd. Exp1 ::= Exp1 "+" Exp2;
EAbs. Exp ::= "\\" Ident ":" Type "." Exp; EAbs. Exp ::= "\\" Ident ":" Type "." Exp;
EAnn. Exp3 ::= "(" Exp ":" Type ")";
ECase. Exp ::= "case" Exp "of" "{" [CaseMatch] "}"; ECase. Exp ::= "case" Exp "of" "{" [CaseMatch] "}";
--
CaseMatch. CaseMatch ::= Case "=>" Exp ; CaseMatch. CaseMatch ::= Case "=>" Exp ;
separator CaseMatch ","; separator CaseMatch ",";
--terminator CaseMatch ".";
CInt. Case ::= Integer ; CInt. Case ::= Integer ;
Bind. Bind ::= Ident ":" Type ";" Bind. Bind ::= Ident ":" Type ";"
Ident [Ident] "=" Exp ; Ident [Ident] "=" Exp;
separator Bind ";"; separator Bind ";";
separator Ident ""; separator Ident "";

View file

@ -1,20 +1,21 @@
--tripplemagic : Int -> Int -> Int -> Int; -- tripplemagic : Int -> Int -> Int -> Int;
--tripplemagic x y z = ((\x:Int. x+x) x) + y + z; -- tripplemagic x y z = ((\x:Int. x+x) x) + y + z;
--main : Int; -- main : Int;
--main = tripplemagic ((\x:Int. x+x+3) ((\x:Int. x) 2)) 5 3 -- main = tripplemagic ((\x:Int. x+x+3) ((\x:Int. x) 2)) 5 3
-- answer: 22
--apply : (Int -> Int) -> Int -> Int;
--apply f x = f x;
--
--main : Int;
--main = (\x : Int . x + 5) 5
-- apply : (Int -> Int) -> Int -> Int;
-- apply f x = f x;
-- main : Int;
-- main = apply (\x : Int . x + 5) 5
-- answer: 10
apply : (Int -> Int -> Int) -> Int -> Int -> Int; apply : (Int -> Int -> Int) -> Int -> Int -> Int;
apply f x y = f x y; apply f x y = f x y;
krimp: Int -> Int -> Int; krimp: Int -> Int -> Int;
krimp x y = x + y; krimp x y = x + y;
main : Int; main : Int;
main = apply (krimp) 2 3;--apply (\y: Int . (\x: Int . x + y + 2)) 5 2; main = apply (krimp) 2 3;
-- answer: 5

View file

@ -1,28 +1,24 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
module Compiler (compile) where module Compiler (compile) where
import Control.Monad.State (StateT, execStateT, gets, modify) import Control.Monad.State (StateT, execStateT, gets, modify)
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as Map import qualified Data.Map as Map
import Data.Tuple.Extra (second) import Data.Tuple.Extra (second)
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
import Grammar.Print (printTree) import Grammar.Print (printTree)
import LlvmIr ( import LlvmIr (LLVMIr (..), LLVMType (..),
LLVMIr (..), LLVMValue (..), Visibility (..),
LLVMType (..), llvmIrToString)
LLVMValue (..), import TypeChecker (partitionType)
Visibility (..), import TypeCheckerIr
llvmIrToString,
)
import TypeChecker (partitionType)
import TypeCheckerIr
-- | The record used as the code generator state -- | The record used as the code generator state
data CodeGenerator = CodeGenerator data CodeGenerator = CodeGenerator
{ instructions :: [LLVMIr] { instructions :: [LLVMIr]
, functions :: Map Id FunctionInfo , functions :: Map Id FunctionInfo
, variableCount :: Integer , variableCount :: Integer
} }
@ -30,7 +26,7 @@ data CodeGenerator = CodeGenerator
type CompilerState a = StateT CodeGenerator Err a type CompilerState a = StateT CodeGenerator Err a
data FunctionInfo = FunctionInfo data FunctionInfo = FunctionInfo
{ numArgs :: Int { numArgs :: Int
, arguments :: [Id] , arguments :: [Id]
} }
@ -124,33 +120,29 @@ compile (Program prg) = do
t_return = snd $ partitionType (length args) t t_return = snd $ partitionType (length args) t
go :: Exp -> CompilerState () go :: Exp -> CompilerState ()
go (EInt int) = emitInt int go (EInt int) = emitInt int
go (EAdd t e1 e2) = emitAdd t e1 e2 go (EAdd t e1 e2) = emitAdd t e1 e2
go (EId (name, _)) = emitIdent name go (EId (name, _)) = emitIdent name
go (EApp t e1 e2) = emitApp t e1 e2 go (EApp t e1 e2) = emitApp t e1 e2
go (EAbs t ti e) = emitAbs t ti e go (EAbs t ti e) = emitAbs t ti e
go (ELet binds e) = emitLet binds e go (ELet bind e) = emitLet bind e
go (EAnn _ _) = emitEAnn
-- go (ESub e1 e2) = emitSub e1 e2 -- go (ESub e1 e2) = emitSub e1 e2
-- go (EMul e1 e2) = emitMul e1 e2 -- go (EMul e1 e2) = emitMul e1 e2
-- go (EDiv e1 e2) = emitDiv e1 e2 -- go (EDiv e1 e2) = emitDiv e1 e2
-- go (EMod e1 e2) = emitMod e1 e2 -- go (EMod e1 e2) = emitMod e1 e2
--- aux functions --- --- aux functions ---
emitEAnn :: CompilerState ()
emitEAnn = emit . UnsafeRaw $ "why?"
emitAbs :: Type -> Id -> Exp -> CompilerState () emitAbs :: Type -> Id -> Exp -> CompilerState ()
emitAbs _t tid e = do emitAbs _t tid e = do
emit . Comment $ emit . Comment $
"Lambda escaped previous stages: \\" <> show tid <> " . " <> show e "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e
emitLet :: [Bind] -> Exp -> CompilerState () emitLet :: Bind -> Exp -> CompilerState ()
emitLet xs e = do emitLet b e = do
emit $ emit $
Comment $ Comment $
concat concat
[ "ELet (" [ "ELet ("
, show xs , show b
, " = " , " = "
, show e , show e
, ") is not implemented!" , ") is not implemented!"
@ -170,7 +162,7 @@ compile (Program prg) = do
funcs <- gets functions funcs <- gets functions
let vis = case Map.lookup id funcs of let vis = case Map.lookup id funcs of
Nothing -> Local Nothing -> Local
Just _ -> Global Just _ -> Global
let call = Call (type2LlvmType t) vis name ((\x -> (valueGetType x, x)) <$> args) let call = Call (type2LlvmType t) vis name ((\x -> (valueGetType x, x)) <$> args)
emit $ SetVariable (Ident $ show vs) call emit $ SetVariable (Ident $ show vs) call
x -> do x -> do
@ -271,19 +263,18 @@ type2LlvmType = \case
where where
function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType]) function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s) function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s)
function2LLVMType x s = (type2LlvmType x, s) function2LLVMType x s = (type2LlvmType x, s)
getType :: Exp -> LLVMType getType :: Exp -> LLVMType
getType (EInt _) = I64 getType (EInt _) = I64
getType (EAdd t _ _) = type2LlvmType t getType (EAdd t _ _) = type2LlvmType t
getType (EId (_, t)) = type2LlvmType t getType (EId (_, t)) = type2LlvmType t
getType (EApp t _ _) = type2LlvmType t getType (EApp t _ _) = type2LlvmType t
getType (EAbs t _ _) = type2LlvmType t getType (EAbs t _ _) = type2LlvmType t
getType (ELet _ e) = getType e getType (ELet _ e) = getType e
getType (EAnn _ t) = type2LlvmType t
valueGetType :: LLVMValue -> LLVMType valueGetType :: LLVMValue -> LLVMType
valueGetType (VInteger _) = I64 valueGetType (VInteger _) = I64
valueGetType (VIdent _ t) = t valueGetType (VIdent _ t) = t
valueGetType (VConstant s) = Array (length s) I8 valueGetType (VConstant s) = Array (length s) I8
valueGetType (VFunction _ _ t) = t valueGetType (VFunction _ _ t) = t

View file

@ -7,12 +7,10 @@ module LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where
import Auxiliary (snoc) import Auxiliary (snoc)
import Control.Applicative (Applicative (liftA2)) import Control.Applicative (Applicative (liftA2))
import Control.Monad.State (MonadState (get, put), State, evalState) import Control.Monad.State (MonadState (get, put), State, evalState)
import Data.Foldable.Extra (notNull) import Data.Set (Set)
import Data.List (mapAccumL, partition)
import Data.Set (Set, (\\))
import qualified Data.Set as Set import qualified Data.Set as Set
import Prelude hiding (exp) import Prelude hiding (exp)
import Renamer hiding (fromBinders) import Renamer
import TypeCheckerIr import TypeCheckerIr
@ -49,35 +47,22 @@ freeVarsExp localVars = \case
where where
e' = freeVarsExp (Set.insert par localVars) e e' = freeVarsExp (Set.insert par localVars) e
-- Sum free variables present in binders and the expression -- Sum free variables present in bind and the expression
ELet binders e -> (Set.union binders_frees e_free, ALet binders' e') ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e')
where where
binders_frees = rhss_frees \\ names_set binders_frees = Set.delete name $ freeVarsOf rhs'
e_free = freeVarsOf e' \\ names_set e_free = Set.delete name $ freeVarsOf e'
rhss_frees = foldr1 Set.union (map freeVarsOf rhss') rhs' = freeVarsExp e_localVars rhs
names_set = Set.fromList names new_bind = ABind name parms rhs'
(names, parms, rhss) = fromBinders binders e' = freeVarsExp e_localVars e
rhss' = map (freeVarsExp e_localVars) rhss e_localVars = Set.insert name localVars
e_localVars = Set.union localVars names_set
binders' = zipWith3 ABind names parms rhss'
e' = freeVarsExp e_localVars e
EAnn e t -> (freeVarsOf e', AAnn e' t)
where
e' = freeVarsExp localVars e
freeVarsOf :: AnnExp -> Set Id freeVarsOf :: AnnExp -> Set Id
freeVarsOf = fst freeVarsOf = fst
fromBinders :: [Bind] -> ([Id], [[Id]], [Exp])
fromBinders bs = unzip3 [ (name, parms, rhs) | Bind name parms rhs <- bs ]
-- AST annotated with free variables -- AST annotated with free variables
type AnnProgram = [(Id, [Id], AnnExp)] type AnnProgram = [(Id, [Id], AnnExp)]
@ -87,14 +72,11 @@ data ABind = ABind Id [Id] AnnExp deriving Show
data AnnExp' = AId Id data AnnExp' = AId Id
| AInt Integer | AInt Integer
| ALet [ABind] AnnExp | ALet ABind AnnExp
| AApp Type AnnExp AnnExp | AApp Type AnnExp AnnExp
| AAdd Type AnnExp AnnExp | AAdd Type AnnExp AnnExp
| AAbs Type Id AnnExp | AAbs Type Id AnnExp
| AAnn AnnExp Type
deriving Show deriving Show
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@. -- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
-- Free variables are @v₁ v₂ .. vₙ@ are bound. -- Free variables are @v₁ v₂ .. vₙ@ are bound.
abstract :: AnnProgram -> Program abstract :: AnnProgram -> Program
@ -124,7 +106,7 @@ abstractExp (free, exp) = case exp of
AInt i -> pure $ EInt i AInt i -> pure $ EInt i
AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2) AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2)
AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2) AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2)
ALet bs e -> liftA2 ELet (mapM go bs) (abstractExp e) ALet b e -> liftA2 ELet (go b) (abstractExp e)
where where
go (ABind name parms rhs) = do go (ABind name parms rhs) = do
(rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
@ -141,14 +123,13 @@ abstractExp (free, exp) = case exp of
rhs <- abstractExp e rhs <- abstractExp e
let sc_name = Ident ("sc_" ++ show i) let sc_name = Ident ("sc_" ++ show i)
sc = ELet [Bind (sc_name, t) parms rhs] $ EId (sc_name, t) sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t)
pure $ foldl (EApp TInt) sc $ map EId freeList pure $ foldl (EApp TInt) sc $ map EId freeList
where where
freeList = Set.toList free freeList = Set.toList free
parms = snoc parm freeList parms = snoc parm freeList
AAnn e t -> abstractExp e >>= \e' -> pure $ EAnn e' t
nextNumber :: State Int Int nextNumber :: State Int Int
nextNumber = do nextNumber = do
@ -156,7 +137,7 @@ nextNumber = do
put $ succ i put $ succ i
pure i pure i
-- | Collects supercombinators by lifting appropriate let expressions -- | Collects supercombinators by lifting non-constant let expressions
collectScs :: Program -> Program collectScs :: Program -> Program
collectScs (Program scs) = Program $ concatMap collectFromRhs scs collectScs (Program scs) = Program $ concatMap collectFromRhs scs
where where
@ -167,49 +148,35 @@ collectScs (Program scs) = Program $ concatMap collectFromRhs scs
collectScsExp :: Exp -> ([Bind], Exp) collectScsExp :: Exp -> ([Bind], Exp)
collectScsExp = \case collectScsExp = \case
EId n -> ([], EId n) EId n -> ([], EId n)
EInt i -> ([], EInt i) EInt i -> ([], EInt i)
EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2') EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2')
where where
(scs1, e1') = collectScsExp e1 (scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2 (scs2, e2') = collectScsExp e2
EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2') EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2')
where where
(scs1, e1') = collectScsExp e1 (scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2 (scs2, e2') = collectScsExp e2
EAbs t par e -> (scs, EAbs t par e') EAbs t par e -> (scs, EAbs t par e')
where where
(scs, e') = collectScsExp e (scs, e') = collectScsExp e
-- Collect supercombinators from binds, the rhss, and the expression. -- Collect supercombinators from bind, the rhss, and the expression.
-- --
-- > f = let -- > f = let sc x y = rhs in e
-- > sc = rhs --
-- > sc1 = rhs1 ELet (Bind name parms rhs) e -> if null parms
-- > ... then ( rhs_scs ++ e_scs, ELet bind e')
-- > in e else (bind : rhs_scs ++ e_scs, e')
-- where
ELet binds e -> (binds_scs ++ rhss_scs ++ e_scs, mkEAbs non_scs' e') bind = Bind name parms rhs'
where (rhs_scs, rhs') = collectScsExp rhs
binds_scs = [ let (rhs', parms1) = flattenLambdas rhs in (e_scs, e') = collectScsExp e
Bind n (parms ++ parms1) rhs'
| Bind n parms rhs <- scs'
]
(rhss_scs, binds') = mapAccumL collectScsRhs [] binds
(e_scs, e') = collectScsExp e
(scs', non_scs') = partition (\(Bind _ pars _) -> notNull pars) binds'
collectScsRhs acc (Bind n xs rhs) = (acc ++ rhs_scs, Bind n xs rhs')
where
(rhs_scs, rhs') = collectScsExp rhs
EAnn e t -> (scs, EAnn e' t)
where
(scs, e') = collectScsExp e
-- @\x.\y.\z. e → (e, [x,y,z])@ -- @\x.\y.\z. e → (e, [x,y,z])@
flattenLambdas :: Exp -> (Exp, [Id]) flattenLambdas :: Exp -> (Exp, [Id])
@ -218,7 +185,3 @@ flattenLambdas = go . (, [])
go (e, acc) = case e of go (e, acc) = case e of
EAbs _ par e1 -> go (e1, snoc par acc) EAbs _ par e1 -> go (e1, snoc par acc)
_ -> (e, acc) _ -> (e, acc)
mkEAbs :: [Bind] -> Exp -> Exp
mkEAbs [] e = e
mkEAbs bs e = ELet bs e

View file

@ -2,82 +2,84 @@
module Renamer (module Renamer) where module Renamer (module Renamer) where
import Data.List (mapAccumL, unzip4, zipWith4) import Auxiliary (mapAccumM)
import Data.Map (Map) import Control.Monad.State (MonadState, State, evalState, gets,
import qualified Data.Map as Map modify)
import Data.Maybe (fromMaybe) import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe)
import Grammar.Abs import Grammar.Abs
-- | Rename all supercombinators and variables -- | Rename all variables and local binds
rename :: Program -> Program rename :: Program -> Program
rename (Program sc) = Program $ map (renameSc 0) sc rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0
where where
renameSc i (Bind n t _ xs e) = Bind n t n xs' e' initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs
where renameSc :: Names -> Bind -> Rn Bind
(i1, xs', env) = newNames i xs renameSc old_names (Bind name t _ parms rhs) = do
e' = snd $ renameExp env i1 e (new_names, parms') <- newNames old_names parms
rhs' <- snd <$> renameExp new_names rhs
renameExp :: Map Ident Ident -> Int -> Exp -> (Int, Exp) pure $ Bind name t name parms' rhs'
renameExp env i = \case
EId n -> (i, EId . fromMaybe n $ Map.lookup n env)
EInt i1 -> (i, EInt i1)
EApp e1 e2 -> (i2, EApp e1' e2')
where
(i1, e1') = renameExp env i e1
(i2, e2') = renameExp env i1 e2
EAdd e1 e2 -> (i2, EAdd e1' e2')
where
(i1, e1') = renameExp env i e1
(i2, e2') = renameExp env i1 e2
ELet bs e -> (i3, ELet (zipWith4 mkBind names' types pars' es') e')
where
mkBind name t = Bind name t name
(i1, e') = renameExp e_env i e
(names, types, pars, rhss) = fromBinders bs
(i2, names', env') = newNames i1 (names ++ concat pars)
pars' = (map . map) renamePar pars
e_env = Map.union env' env
(i3, es') = mapAccumL (renameExp e_env) i2 rhss
renamePar p = case Map.lookup p env' of
Just p' -> p'
Nothing -> error ("Can't find name for " ++ show p)
EAbs par t e -> (i2, EAbs par' t e') -- | Rename monad. State holds the number of renamed names.
where newtype Rn a = Rn { runRn :: State Int a }
(i1, par', env') = newName par deriving (Functor, Applicative, Monad, MonadState Int)
(i2, e') = renameExp (Map.union env' env ) i1 e
EAnn e t -> (i1, EAnn e' t) -- | Maps old to new name
where type Names = Map Ident Ident
(i1, e') = renameExp env i e
renameLocalBind :: Names -> Bind -> Rn (Names, Bind)
renameLocalBind old_names (Bind name t _ parms rhs) = do
(new_names, name') <- newName old_names name
(new_names', parms') <- newNames new_names parms
(new_names'', rhs') <- renameExp new_names' rhs
pure (new_names'', Bind name' t name' parms' rhs')
newName :: Ident -> (Int, Ident, Map Ident Ident) renameExp :: Names -> Exp -> Rn (Names, Exp)
newName old_name = (i, head names, env) renameExp old_names = \case
where (i, names, env) = newNames 1 [old_name]
newNames :: Int -> [Ident] -> (Int, [Ident], Map Ident Ident) EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names)
newNames i old_names = (i', new_names, env)
where
(i', new_names) = getNames i old_names
env = Map.fromList $ zip old_names new_names
getNames :: Int -> [Ident] -> (Int, [Ident]) EInt i1 -> pure (old_names, EInt i1)
getNames i ns = (i + length ss, zipWith makeName ss [i..])
where
ss = map (\(Ident s) -> s) ns
makeName :: String -> Int -> Ident EApp e1 e2 -> do
makeName prefix i = Ident (prefix ++ "_" ++ show i) (env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EApp e1' e2')
EAdd e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EAdd e1' e2')
ELet b e -> do
(new_names, b) <- renameLocalBind old_names b
(new_names', e') <- renameExp new_names e
pure (new_names', ELet b e')
EAbs par t e -> do
(new_names, par') <- newName old_names par
(new_names', e') <- renameExp new_names e
pure (new_names', EAbs par' t e')
EAnn e t -> do
(new_names, e') <- renameExp old_names e
pure (new_names, EAnn e' t)
-- | Create a new name and add it to name environment.
newName :: Names -> Ident -> Rn (Names, Ident)
newName env old_name = do
new_name <- makeName old_name
pure (Map.insert old_name new_name env, new_name)
-- | Create multiple names and add them to the name environment
newNames :: Names -> [Ident] -> Rn (Names, [Ident])
newNames = mapAccumM newName
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
makeName :: Ident -> Rn Ident
makeName (Ident prefix) = gets (\i -> Ident $ prefix ++ "_" ++ show i) <* modify succ
fromBinders :: [Bind] -> ([Ident], [Type], [[Ident]], [Exp])
fromBinders bs = unzip4 [ (name, t, parms, rhs) | Bind name t _ parms rhs <- bs ]

View file

@ -87,18 +87,17 @@ infer cxt = \case
let t_abs = TFun t t1 let t_abs = TFun t t1
pure (T.EAbs t_abs (x, t) e', t_abs) pure (T.EAbs t_abs (x, t) e', t_abs)
ELet bs e -> do ELet b e -> do
bs'' <- mapM (checkBind cxt') bs' let cxt' = insertBind b cxt
b' <- checkBind cxt' b
(e', t) <- infer cxt' e (e', t) <- infer cxt' e
pure (T.ELet bs'' e', t) pure (T.ELet b' e', t)
where
bs' = map expandLambdas bs
cxt' = foldr (\(Bind n t _ _ _) -> insertEnv n t) cxt bs'
EAnn e t -> do EAnn e t -> do
e' <- check cxt e t (e', t1) <- infer cxt e
pure (T.EAnn e' t, t) unless (typeEq t t1) $
throwError "Inferred type and type annotation doesn't match"
pure (e', t1)
check :: Cxt -> Exp -> Type -> Err T.Exp check :: Cxt -> Exp -> Type -> Err T.Exp
check cxt exp typ = case exp of check cxt exp typ = case exp of
@ -137,19 +136,19 @@ check cxt exp typ = case exp of
unless (typeEq t1 typ) $ throwError "Wrong lamda type!" unless (typeEq t1 typ) $ throwError "Wrong lamda type!"
pure $ T.EAbs t1 (x, t) e' pure $ T.EAbs t1 (x, t) e'
ELet bs e -> do ELet b e -> do
bs'' <- mapM (checkBind cxt') bs' let cxt' = insertBind b cxt
b' <- checkBind cxt' b
e' <- check cxt' e typ e' <- check cxt' e typ
pure $ T.ELet bs'' e' pure $ T.ELet b' e'
where
bs' = map expandLambdas bs
cxt' = foldr (\(Bind n t _ _ _) -> insertEnv n t) cxt bs'
EAnn e t -> do EAnn e t -> do
unless (typeEq t typ) $ unless (typeEq t typ) $
throwError "Inferred type and type annotation doesn't match" throwError "Inferred type and type annotation doesn't match"
e' <- check cxt e t check cxt e t
pure $ T.EAnn e' typ
insertBind :: Bind -> Cxt -> Cxt
insertBind (Bind n t _ _ _) = insertEnv n t
typeEq :: Type -> Type -> Bool typeEq :: Type -> Type -> Bool
typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1 typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1

View file

@ -16,11 +16,10 @@ newtype Program = Program [Bind]
data Exp data Exp
= EId Id = EId Id
| EInt Integer | EInt Integer
| ELet [Bind] Exp | ELet Bind Exp
| EApp Type Exp Exp | EApp Type Exp Exp
| EAdd Type Exp Exp | EAdd Type Exp Exp
| EAbs Type Id Exp | EAbs Type Id Exp
| EAnn Exp Type
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
type Id = (Ident, Type) type Id = (Ident, Type)
@ -97,12 +96,5 @@ instance Print Exp where
, doc $ showString "." , doc $ showString "."
, prt 0 e , prt 0 e
] ]
EAnn e t -> prPrec i 3 $ concatD
[ doc $ showString "("
, prt 0 e
, doc $ showString ":"
, prt 0 t
, doc $ showString ")"
]