Adjust old type checker to new syntax, and refactor lambda lifter to use typed AST

This commit is contained in:
Martin Fredin 2023-02-15 23:55:16 +01:00
parent 514c809b1e
commit 210e55bb15
18 changed files with 554 additions and 145 deletions

View file

@ -7,14 +7,22 @@ EInt. Exp3 ::= Integer;
ELet. Exp3 ::= "let" [Bind] "in" Exp;
EApp. Exp2 ::= Exp2 Exp3;
EAdd. Exp1 ::= Exp1 "+" Exp2;
EAbs. Exp ::= "\\" Ident "." Exp;
EAbs. Exp ::= "\\" Ident ":" Type "." Exp;
EAnn. Exp3 ::= "(" Exp ":" Type ")";
Bind. Bind ::= Ident ":" Type ";"
Ident [Ident] "=" Exp ;
Bind. Bind ::= Ident [Ident] "=" Exp;
separator Bind ";";
separator Ident " ";
separator Ident "";
coercions Exp 3;
TInt. Type1 ::= "Int" ;
TPol. Type1 ::= Ident ;
TFun. Type ::= Type1 "->" Type ;
coercions Type 1 ;
comment "--";
comment "{-" "-}";

View file

@ -33,7 +33,10 @@ executable language
Grammar.ErrM
LambdaLifter
Auxiliary
Interpreter
-- Interpreter
Renamer
TypeChecker
TypeCheckerIr
hs-source-dirs: src

View file

@ -1,2 +1,3 @@
f = \x. x+1;
f : Int -> Int;
f = \x:Int. x+1;

3
sample-programs/basic-10 Normal file
View file

@ -0,0 +1,3 @@
main : Int -> Int -> Int;
main x y = (x : Int) + y;

View file

@ -1,4 +1,7 @@
add x = \y. x+y;
main = (\z. z+z) ((add 4) 6);
add : Int -> Int -> Int;
add x = \y:Int. x+y;
main : Int;
main = (\z:Int. z+z) ((add 4) 6);

View file

@ -1,2 +1,3 @@
main = (\x. x+x+3) ((\x. x) 2)
main : Int;
main = (\x:Int. x+x+3) ((\x:Int. x) 2);

View file

@ -1,2 +1,7 @@
f x = let g = (\y. y+1) in g (g x)
f : Int -> Int;
f x = let
g : Int -> Int;
g = (\y:Int. y+1);
in
g (g x);

View file

@ -1,9 +1,14 @@
id : Int -> Int;
id x = x;
add : Int -> Int -> Int;
add x y = x + y;
double : Int -> Int;
double n = n + n;
apply f x = \y. f x y;
apply : (Int -> Int -> Int) -> Int -> Int -> Int;
apply f x = \y:Int. f x y;
main = apply (id add) ((\x. x + 1) 1) (double 3);
main : Int;
main = apply add ((\x:Int. x + 1) 1) (double (id 3));

View file

@ -1,3 +1,4 @@
f = \x.\y. x+y
f : Int -> Int -> Int;
f = \x:Int.\y:Int. x+y;

View file

@ -1,5 +1,8 @@
add : Int -> Int -> Int;
add x y = x + y;
apply : (Int -> Int) -> Int -> Int;
apply f x = f x;
main : Int;
main = apply (add 4) 6;

View file

@ -1,2 +1,7 @@
f x = let double = \y. y+y in (\x. x+y) 4;
f : Int -> Int;
f x = let
double : Int -> Int;
double = \y:Int. y+y
in
double (x + 4);

View file

@ -1,4 +1,5 @@
main = (\f.\x.\y. f x + f y) (\x. x+x) ((\x. x+1) ((\x. x+3) 2)) 4
main : Int;
main = (\f:Int -> Int.\x:Int.\y:Int. f x + f y) (\x:Int. x+x) ((\x:Int. x+1) ((\x:Int. x+3) 2)) 4

View file

@ -1,4 +1,4 @@
{-# LANGUAGE LambdaCase #-}
module Auxiliary (module Auxiliary) where
import Control.Monad.Error.Class (liftEither)
import Control.Monad.Except (MonadError)
@ -9,3 +9,13 @@ snoc x xs = xs ++ [x]
maybeToRightM :: MonadError l m => l -> Maybe r -> m r
maybeToRightM err = liftEither . maybeToRight err
mapAccumM :: Monad m => (s -> a -> m (s, b)) -> s -> [a] -> m (s, [b])
mapAccumM f = go
where
go acc = \case
[] -> pure (acc, [])
x:xs -> do
(acc', x') <- f acc x
(acc'', xs') <- go acc' xs
pure (acc'', x':xs')

View file

@ -5,21 +5,20 @@
module LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where
import Auxiliary (snoc)
import Control.Applicative (Applicative (liftA2))
import Control.Monad.State (MonadState (get, put), State, evalState)
import Data.Foldable.Extra (notNull)
import Data.List (mapAccumL, mapAccumR, partition)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe, mapMaybe)
import Data.List (mapAccumL, partition)
import Data.Set (Set, (\\))
import qualified Data.Set as Set
import Data.Tuple.Extra (uncurry3)
import Grammar.Abs
import Prelude hiding (exp)
import Renamer hiding (fromBinders)
import TypeCheckerIr
-- | Lift lambdas and let expression into supercombinators.
lambdaLift :: Program -> Program
lambdaLift = collectScs . rename . abstract . freeVars
lambdaLift = collectScs . abstract . freeVars
-- | Annotate free variables
@ -28,25 +27,25 @@ freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e)
| Bind n xs e <- ds
]
freeVarsExp :: Set Ident -> Exp -> AnnExp
freeVarsExp :: Set Id -> Exp -> AnnExp
freeVarsExp localVars = \case
EId n | Set.member n localVars -> (Set.singleton n, AId n)
| otherwise -> (mempty, AId n)
EId n | Set.member n localVars -> (Set.singleton n, AId n)
| otherwise -> (mempty, AId n)
EInt i -> (mempty, AInt i)
EApp e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp e1' e2')
EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2')
where
e1' = freeVarsExp localVars e1
e2' = freeVarsExp localVars e2
EAdd e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd e1' e2')
EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2')
where
e1' = freeVarsExp localVars e1
e2' = freeVarsExp localVars e2
EAbs par e -> (Set.delete par $ freeVarsOf e', AAbs par e')
EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e')
where
e' = freeVarsExp (Set.insert par localVars) e
@ -66,143 +65,111 @@ freeVarsExp localVars = \case
binders' = zipWith3 ABind names parms rhss'
e' = freeVarsExp e_localVars e
EAnn e t -> (freeVarsOf e', AAnn e' t)
where
e' = freeVarsExp localVars e
freeVarsOf :: AnnExp -> Set Ident
freeVarsOf :: AnnExp -> Set Id
freeVarsOf = fst
fromBinders :: [Bind] -> ([Ident], [[Ident]], [Exp])
fromBinders :: [Bind] -> ([Id], [[Id]], [Exp])
fromBinders bs = unzip3 [ (name, parms, rhs) | Bind name parms rhs <- bs ]
-- AST annotated with free variables
type AnnProgram = [(Ident, [Ident], AnnExp)]
type AnnProgram = [(Id, [Id], AnnExp)]
type AnnExp = (Set Ident, AnnExp')
type AnnExp = (Set Id, AnnExp')
data ABind = ABind Ident [Ident] AnnExp deriving Show
data ABind = ABind Id [Id] AnnExp deriving Show
data AnnExp' = AId Ident
data AnnExp' = AId Id
| AInt Integer
| AApp AnnExp AnnExp
| AAdd AnnExp AnnExp
| AAbs Ident AnnExp
| ALet [ABind] AnnExp
| AApp Type AnnExp AnnExp
| AAdd Type AnnExp AnnExp
| AAbs Type Id AnnExp
| AAnn AnnExp Type
deriving Show
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
-- Free variables are @v₁ v₂ .. vₙ@ are bound.
abstract :: AnnProgram -> Program
abstract prog = Program $ map go prog
abstract prog = Program $ evalState (mapM go prog) 0
where
go :: (Ident, [Ident], AnnExp) -> Bind
go (name, pars, rhs@(_, e)) =
go :: (Id, [Id], AnnExp) -> State Int Bind
go (name, parms, rhs@(_, e)) =
case e of
AAbs par e1 -> Bind name (snoc par pars ++ pars2) $ abstractExp e2
AAbs _ parm e1 -> do
e2' <- abstractExp e2
pure $ Bind name (snoc parm parms ++ parms2) e2'
where
(e2, pars2) = flattenLambdasAnn e1
_ -> Bind name pars $ abstractExp rhs
(e2, parms2) = flattenLambdasAnn e1
_ -> Bind name parms <$> abstractExp rhs
-- | Flatten nested lambdas and collect the parameters
-- @\x.\y.\z. ae → (ae, [x,y,z])@
flattenLambdasAnn :: AnnExp -> (AnnExp, [Ident])
flattenLambdasAnn :: AnnExp -> (AnnExp, [Id])
flattenLambdasAnn ae = go (ae, [])
where
go :: (AnnExp, [Ident]) -> (AnnExp, [Ident])
go :: (AnnExp, [Id]) -> (AnnExp, [Id])
go ((free, e), acc) =
case e of
AAbs par (free1, e1) -> go ((Set.delete par free1, e1), snoc par acc)
_ -> ((free, e), acc)
AAbs _ par (free1, e1) ->
go ((Set.delete par free1, e1), snoc par acc)
_ -> ((free, e), acc)
abstractExp :: AnnExp -> Exp
abstractExp :: AnnExp -> State Int Exp
abstractExp (free, exp) = case exp of
AId n -> EId n
AInt i -> EInt i
AApp e1 e2 -> EApp (abstractExp e1) (abstractExp e2)
AAdd e1 e2 -> EAdd (abstractExp e1) (abstractExp e2)
ALet bs e -> ELet (map go bs) $ abstractExp e
AId n -> pure $ EId n
AInt i -> pure $ EInt i
AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2)
AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2)
ALet bs e -> liftA2 ELet (mapM go bs) (abstractExp e)
where
go (ABind name parms rhs) =
let
(rhs', parms1) = flattenLambdas $ skipLambdas abstractExp rhs
in
Bind name (parms ++ parms1) rhs'
go (ABind name parms rhs) = do
(rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
pure $ Bind name (parms ++ parms1) rhs'
skipLambdas :: (AnnExp -> Exp) -> AnnExp -> Exp
skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp
skipLambdas f (free, ae) = case ae of
AAbs name ae1 -> EAbs name $ skipLambdas f ae1
_ -> f (free, ae)
AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1
_ -> f (free, ae)
-- Lift lambda into let and bind free variables
AAbs par e -> foldl EApp sc $ map EId freeList
AAbs t parm e -> do
i <- nextNumber
rhs <- abstractExp e
let sc_name = Ident ("sc_" ++ show i)
sc = ELet [Bind (sc_name, t_bind) parms rhs] $ EId (sc_name, t)
pure $ foldl (EApp TInt) sc $ map EId freeList
where
freeList = Set.toList free
sc = ELet [Bind "sc" (snoc par freeList) $ abstractExp e] $ EId "sc"
t_bind = typeApplyPars (length parm) t
parms = snoc parm freeList
-- | Rename all supercombinators and variables
rename :: Program -> Program
rename (Program ds) = Program $ map (uncurry3 Bind) tuples
where
tuples = snd (mapAccumL renameSc 0 ds)
renameSc i (Bind n xs e) = (i2, (n, xs', e'))
where
(i1, xs', env) = newNames i xs
(i2, e') = renameExp env i1 e
AAnn e t -> abstractExp e >>= \e' -> pure $ EAnn e' t
renameExp :: Map Ident Ident -> Int -> Exp -> (Int, Exp)
renameExp env i = \case
EId n -> (i, EId . fromMaybe n $ Map.lookup n env)
EInt i1 -> (i, EInt i1)
EApp e1 e2 -> (i2, EApp e1' e2')
where
(i1, e1') = renameExp env i e1
(i2, e2') = renameExp env i1 e2
EAdd e1 e2 -> (i2, EAdd e1' e2')
where
(i1, e1') = renameExp env i e1
(i2, e2') = renameExp env i1 e2
ELet bs e -> (i3, ELet (zipWith3 Bind ns' pars' es') e')
where
(i1, e') = renameExp e_env i e
(names, pars, rhss) = fromBinders bs
(i2, ns', env') = newNames i1 (names ++ concat pars)
pars' = (map . map) renamePar pars
e_env = Map.union env' env
(i3, es') = mapAccumL (renameExp e_env) i2 rhss
renamePar p = case Map.lookup p env' of
Just p' -> p'
Nothing -> error ("Can't find name for " ++ show p)
nextNumber :: State Int Int
nextNumber = do
i <- get
put $ succ i
pure i
EAbs par e -> (i2, EAbs par' e')
where
(i1, par', env') = newName par
(i2, e') = renameExp (Map.union env' env ) i1 e
newName :: Ident -> (Int, Ident, Map Ident Ident)
newName old_name = (i, head names, env)
where (i, names, env) = newNames 1 [old_name]
newNames :: Int -> [Ident] -> (Int, [Ident], Map Ident Ident)
newNames i old_names = (i', new_names, env)
where
(i', new_names) = getNames i old_names
env = Map.fromList $ zip old_names new_names
getNames :: Int -> [Ident] -> (Int, [Ident])
getNames i ns = (i + length ss, zipWith makeName ss [i..])
where
ss = map (\(Ident s) -> s) ns
makeName :: String -> Int -> Ident
makeName prefix i = Ident (prefix ++ "_" ++ show i)
typeApplyPars :: Int -> Type -> Type
typeApplyPars 0 t = t
typeApplyPars i t =
case t of
TFun _ t1 -> typeApplyPars (i-1) t1
_ -> error "Number of applied pars and type not matching"
-- | Collects supercombinators by lifting appropriate let expressions
@ -216,20 +183,20 @@ collectScs (Program scs) = Program $ concatMap collectFromRhs scs
collectScsExp :: Exp -> ([Bind], Exp)
collectScsExp = \case
EId n -> ([], EId n)
EInt i -> ([], EInt i)
EId n -> ([], EId n)
EInt i -> ([], EInt i)
EApp e1 e2 -> (scs1 ++ scs2, EApp e1' e2')
EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2')
where
(scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2
EAdd e1 e2 -> (scs1 ++ scs2, EAdd e1' e2')
EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2')
where
(scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2
EAbs x e -> (scs, EAbs x e')
EAbs t par e -> (scs, EAbs t par e')
where
(scs, e') = collectScsExp e
@ -241,28 +208,32 @@ collectScsExp = \case
-- > ...
-- > in e
--
ELet binds e -> (binds_scs ++ rhss_scs ++ e_scs, mkEAbs non_scs' e')
ELet binds e -> (binds_scs ++ rhss_scs ++ e_scs, mkEAbs non_scs' e')
where
binds_scs = [ let (rhs', parms1) = flattenLambdas rhs in
Bind n (parms ++ parms1) rhs'
| Bind n parms rhs <- scs'
]
(rhss_scs, binds') = mapAccumL collectScsRhs [] binds
(e_scs, e') = collectScsExp e
binds_scs = [ let (rhs', parms1) = flattenLambdas rhs in
Bind n (parms ++ parms1) rhs'
| Bind n parms rhs <- scs'
]
(rhss_scs, binds') = mapAccumL collectScsRhs [] binds
(e_scs, e') = collectScsExp e
(scs', non_scs') = partition (\(Bind _ pars _) -> notNull pars) binds'
(scs', non_scs') = partition (\(Bind _ pars _) -> notNull pars) binds'
collectScsRhs acc (Bind n xs rhs) = (acc ++ rhs_scs, Bind n xs rhs')
where
(rhs_scs, rhs') = collectScsExp rhs
EAnn e t -> (scs, EAnn e' t)
where
(scs, e') = collectScsExp e
-- @\x.\y.\z. e → (e, [x,y,z])@
flattenLambdas :: Exp -> (Exp, [Ident])
flattenLambdas e = go (e, [])
flattenLambdas :: Exp -> (Exp, [Id])
flattenLambdas = go . (, [])
where
go (e, acc) = case e of
EAbs par e1 -> go (e1, snoc par acc)
_ -> (e, acc)
EAbs _ par e1 -> go (e1, snoc par acc)
_ -> (e, acc)
mkEAbs :: [Bind] -> Exp -> Exp
mkEAbs [] e = e

View file

@ -4,10 +4,12 @@ module Main where
import Grammar.ErrM (Err)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
import Interpreter (interpret)
--import Interpreter (interpret)
import LambdaLifter (abstract, freeVars, lambdaLift)
import Renamer (rename)
import System.Environment (getArgs)
import System.Exit (exitFailure, exitSuccess)
import TypeChecker (typecheck)
main :: IO ()
main = getArgs >>= \case
@ -18,12 +20,20 @@ main' :: String -> IO ()
main' s = do
file <- readFile s
putStrLn "\n-- parse"
putStrLn "\n-- Parser"
parsed <- fromSyntaxErr . pProgram $ myLexer file
putStrLn $ printTree parsed
putStrLn "\n-- Renamer"
let renamed = rename parsed
putStrLn $ printTree renamed
putStrLn "\n-- TypeChecker"
typechecked <- fromTypeCheckerErr $ typecheck renamed
putStrLn $ printTree typechecked
putStrLn "\n-- Lambda Lifter"
let lifted = lambdaLift parsed
let lifted = lambdaLift typechecked
putStrLn $ printTree lifted
-- interpred <- fromInterpreterErr $ interpret lifted
@ -41,6 +51,14 @@ fromSyntaxErr = either
exitFailure)
pure
fromTypeCheckerErr :: Err a -> IO a
fromTypeCheckerErr = either
(\err -> do
putStrLn "\nTYPECHECKER ERROR"
putStrLn err
exitFailure)
pure
fromInterpreterErr :: Err a -> IO a
fromInterpreterErr = either
(\err -> do

83
src/Renamer.hs Normal file
View file

@ -0,0 +1,83 @@
{-# LANGUAGE LambdaCase #-}
module Renamer (module Renamer) where
import Data.List (mapAccumL, unzip4, zipWith4)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import Grammar.Abs
-- | Rename all supercombinators and variables
rename :: Program -> Program
rename (Program sc) = Program $ map (renameSc 0) sc
where
renameSc i (Bind n t _ xs e) = Bind n t n xs' e'
where
(i1, xs', env) = newNames i xs
e' = snd $ renameExp env i1 e
renameExp :: Map Ident Ident -> Int -> Exp -> (Int, Exp)
renameExp env i = \case
EId n -> (i, EId . fromMaybe n $ Map.lookup n env)
EInt i1 -> (i, EInt i1)
EApp e1 e2 -> (i2, EApp e1' e2')
where
(i1, e1') = renameExp env i e1
(i2, e2') = renameExp env i1 e2
EAdd e1 e2 -> (i2, EAdd e1' e2')
where
(i1, e1') = renameExp env i e1
(i2, e2') = renameExp env i1 e2
ELet bs e -> (i3, ELet (zipWith4 mkBind names' types pars' es') e')
where
mkBind name t = Bind name t name
(i1, e') = renameExp e_env i e
(names, types, pars, rhss) = fromBinders bs
(i2, names', env') = newNames i1 (names ++ concat pars)
pars' = (map . map) renamePar pars
e_env = Map.union env' env
(i3, es') = mapAccumL (renameExp e_env) i2 rhss
renamePar p = case Map.lookup p env' of
Just p' -> p'
Nothing -> error ("Can't find name for " ++ show p)
EAbs par t e -> (i2, EAbs par' t e')
where
(i1, par', env') = newName par
(i2, e') = renameExp (Map.union env' env ) i1 e
EAnn e t -> (i1, EAnn e' t)
where
(i1, e') = renameExp env i e
newName :: Ident -> (Int, Ident, Map Ident Ident)
newName old_name = (i, head names, env)
where (i, names, env) = newNames 1 [old_name]
newNames :: Int -> [Ident] -> (Int, [Ident], Map Ident Ident)
newNames i old_names = (i', new_names, env)
where
(i', new_names) = getNames i old_names
env = Map.fromList $ zip old_names new_names
getNames :: Int -> [Ident] -> (Int, [Ident])
getNames i ns = (i + length ss, zipWith makeName ss [i..])
where
ss = map (\(Ident s) -> s) ns
makeName :: String -> Int -> Ident
makeName prefix i = Ident (prefix ++ "_" ++ show i)
fromBinders :: [Bind] -> ([Ident], [Type], [[Ident]], [Exp])
fromBinders bs = unzip4 [ (name, t, parms, rhs) | Bind name t _ parms rhs <- bs ]

View file

@ -0,0 +1,180 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
module TypeChecker (typecheck) where
import Auxiliary (maybeToRightM, snoc)
import Control.Monad.Except (throwError, unless)
import Data.Map (Map)
import qualified Data.Map as Map
import Grammar.Abs
import Grammar.ErrM (Err)
import Grammar.Print (Print (prt), concatD, doc, printTree,
render)
import Prelude hiding (exp, id)
import qualified TypeCheckerIr as T
-- NOTE: this type checker is poorly tested
-- TODO
-- Coercion
-- Type inference
data Cxt = Cxt
{ env :: Map Ident Type
, sig :: Map Ident Type
}
initCxt :: [Bind] -> Cxt
initCxt sc = Cxt { env = mempty
, sig = Map.fromList $ map (\(Bind n t _ _ _) -> (n, t)) sc
}
typecheck :: Program -> Err T.Program
typecheck (Program sc) = T.Program <$> mapM (checkBind $ initCxt sc) sc
checkBind :: Cxt -> Bind -> Err T.Bind
checkBind cxt b =
case expandLambdas b of
Bind name t _ parms rhs -> do
(rhs', t_rhs) <- infer cxt rhs
unless (typeEq t_rhs t) . throwError $ typeErr name t t_rhs
pure $ T.Bind (name, t) (zip parms ts_parms) rhs'
where
ts_parms = fst $ partitionType (length parms) t
expandLambdas :: Bind -> Bind
expandLambdas (Bind name t _ parms rhs) = Bind name t name [] rhs'
where
rhs' = foldr ($) rhs $ zipWith EAbs parms ts_parms
ts_parms = fst $ partitionType (length parms) t
infer :: Cxt -> Exp -> Err (T.Exp, Type)
infer cxt = \case
EId x ->
case lookupEnv x cxt of
Nothing ->
case lookupSig x cxt of
Nothing -> throwError ("Unbound variable:" ++ printTree x)
Just t -> pure (T.EId (x, t), t)
Just t -> pure (T.EId (x, t), t)
EInt i -> pure (T.EInt i, T.TInt)
EApp e e1 -> do
(e', t) <- infer cxt e
case t of
TFun t1 t2 -> do
e1' <- check cxt e1 t1
pure (T.EApp t2 e' e1', t2)
_ -> do
throwError ("Not a function: " ++ show e)
EAdd e e1 -> do
e' <- check cxt e T.TInt
e1' <- check cxt e1 T.TInt
pure (T.EAdd T.TInt e' e1', T.TInt)
EAbs x t e -> do
(e', t1) <- infer (insertEnv x t cxt) e
let t_abs = TFun t t1
pure (T.EAbs t_abs (x, t) e', t_abs)
ELet bs e -> do
bs'' <- mapM (checkBind cxt') bs'
(e', t) <- infer cxt' e
pure (T.ELet bs'' e', t)
where
bs' = map expandLambdas bs
cxt' = foldr (\(Bind n t _ _ _) -> insertEnv n t) cxt bs'
EAnn e t -> do
e' <- check cxt e t
pure (T.EAnn e' t, t)
check :: Cxt -> Exp -> Type -> Err T.Exp
check cxt exp typ = case exp of
EId x -> do
t <- case lookupEnv x cxt of
Nothing -> maybeToRightM
("Unbound variable:" ++ printTree x)
(lookupSig x cxt)
Just t -> pure t
unless (typeEq t typ) . throwError $ typeErr x typ t
pure $ T.EId (x, t)
EInt i -> do
unless (typeEq typ TInt) $ throwError $ typeErr i TInt typ
pure $ T.EInt i
EApp e e1 -> do
(e', t) <- infer cxt e
case t of
TFun t1 t2 -> do
e1' <- check cxt e1 t1
pure $ T.EApp t2 e' e1'
_ -> throwError ("Not a function 2: " ++ printTree e)
EAdd e e1 -> do
e' <- check cxt e T.TInt
e1' <- check cxt e1 T.TInt
pure $ T.EAdd T.TInt e' e1'
EAbs x t e -> do
(e', t_e) <- infer (insertEnv x t cxt) e
let t1 = TFun t t_e
unless (typeEq t1 typ) $ throwError "Wrong lamda type!"
pure $ T.EAbs t1 (x, t) e'
ELet bs e -> do
bs'' <- mapM (checkBind cxt') bs'
e' <- check cxt' e typ
pure $ T.ELet bs'' e'
where
bs' = map expandLambdas bs
cxt' = foldr (\(Bind n t _ _ _) -> insertEnv n t) cxt bs'
EAnn e t -> do
unless (typeEq t typ) $
throwError "Inferred type and type annotation doesn't match"
e' <- check cxt e t
pure $ T.EAnn e' typ
typeEq :: Type -> Type -> Bool
typeEq (TFun t t1) (TFun q q1) = typeEq t q && typeEq t1 q1
typeEq t t1 = t == t1
partitionType :: Int -> Type -> ([Type], Type)
partitionType = go []
where
go acc 0 t = (acc, t)
go acc i t = case t of
TFun t1 t2 -> go (snoc t1 acc) (i - 1) t2
_ -> error "Number of parameters and type doesn't match"
lookupEnv :: Ident -> Cxt -> Maybe Type
lookupEnv x = Map.lookup x . env
insertEnv :: Ident -> Type -> Cxt -> Cxt
insertEnv x t cxt = cxt { env = Map.insert x t cxt.env }
lookupSig :: Ident -> Cxt -> Maybe Type
lookupSig x = Map.lookup x . sig
typeErr :: Print a => a -> Type -> Type -> String
typeErr p expected actual = render $ concatD
[ doc $ showString "Wrong type:", prt 0 p , doc $ showString "\n"
, doc $ showString "Expected:" , prt 0 expected, doc $ showString "\n"
, doc $ showString "Actual: " , prt 0 actual
]

108
src/TypeCheckerIr.hs Normal file
View file

@ -0,0 +1,108 @@
{-# LANGUAGE LambdaCase #-}
module TypeCheckerIr
( module Grammar.Abs
, module TypeCheckerIr
) where
import Grammar.Abs (Ident (..), Type (..))
import Grammar.Print
import Prelude
import qualified Prelude as C (Eq, Ord, Read, Show)
newtype Program = Program [Bind]
deriving (C.Eq, C.Ord, C.Show, C.Read)
data Exp
= EId Id
| EInt Integer
| ELet [Bind] Exp
| EApp Type Exp Exp
| EAdd Type Exp Exp
| EAbs Type Id Exp
| EAnn Exp Type
deriving (C.Eq, C.Ord, C.Show, C.Read)
type Id = (Ident, Type)
data Bind = Bind Id [Id] Exp
deriving (C.Eq, C.Ord, C.Show, C.Read)
instance Print Program where
prt i (Program sc) = prPrec i 0 $ prt 0 sc
instance Print Bind where
prt i (Bind name@(n, _) parms rhs) = prPrec i 0 $ concatD
[ prtId 0 name
, doc $ showString ";"
, prt 0 n
, prtIdPs 0 parms
, doc $ showString "="
, prt 0 rhs
]
instance Print [Bind] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
prtIdPs :: Int -> [Id] -> Doc
prtIdPs i = prPrec i 0 . concatD . map (prtIdP i)
prtId :: Int -> Id -> Doc
prtId i (name, t) = prPrec i 0 $ concatD
[ prt 0 name
, doc $ showString ":"
, prt 0 t
]
prtIdP :: Int -> Id -> Doc
prtIdP i (name, t) = prPrec i 0 $ concatD
[ doc $ showString "("
, prt 0 name
, doc $ showString ":"
, prt 0 t
, doc $ showString ")"
]
instance Print Exp where
prt i = \case
EId n -> prPrec i 3 $ concatD [prtIdP 0 n]
EInt i1 -> prPrec i 3 $ concatD [prt 0 i1]
ELet bs e -> prPrec i 3 $ concatD
[ doc $ showString "let"
, prt 0 bs
, doc $ showString "in"
, prt 0 e
]
EApp t e1 e2 -> prPrec i 2 $ concatD
[ doc $ showString "@"
, prt 0 t
, prt 2 e1
, prt 3 e2
]
EAdd t e1 e2 -> prPrec i 1 $ concatD
[ doc $ showString "@"
, prt 0 t
, prt 1 e1
, doc $ showString "+"
, prt 2 e2
]
EAbs t n e -> prPrec i 0 $ concatD
[ doc $ showString "@"
, prt 0 t
, doc $ showString "\\"
, prtIdP 0 n
, doc $ showString "."
, prt 0 e
]
EAnn e t -> prPrec i 3 $ concatD
[ doc $ showString "("
, prt 0 e
, doc $ showString ":"
, prt 0 t
, doc $ showString ")"
]