Merge remote-tracking branch 'origin/typechecking-merge' into pattern-matching-with-typechecking

This commit is contained in:
Samuel Hammersberg 2023-03-23 16:33:05 +01:00
commit d3d173eb59
21 changed files with 1052 additions and 476 deletions

View file

@ -1,50 +1,105 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
module Renamer.Renamer where
module Renamer.Renamer (rename) where
import Auxiliary (mapAccumM)
import Control.Monad.State (MonadState, State, evalState, gets,
modify)
import Data.List (foldl')
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe)
import Grammar.Abs
import Auxiliary (mapAccumM)
import Control.Applicative (Applicative (liftA2))
import Control.Monad.Except (ExceptT, MonadError, runExceptT, throwError)
import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.State (
MonadState,
StateT,
evalStateT,
gets,
modify,
)
import Data.Coerce (coerce)
import Data.Function (on)
import Data.Map (Map)
import Data.Map qualified as Map
import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe)
import Grammar.Abs
-- | Rename all variables and local binds
rename :: Program -> Program
rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0
rename :: Program -> Either String Program
rename (Program defs) = Program <$> renameDefs defs
renameDefs :: [Def] -> Either String [Def]
renameDefs defs = runIdentity $ runExceptT $ evalStateT (runRn $ mapM renameDef defs) initCxt
where
-- initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs
initNames = Map.fromList $ foldl' saveIfBind [] bs
saveIfBind acc (DBind (Bind name _ _ _ _)) = dupe name : acc
saveIfBind acc _ = acc
renameSc :: Names -> Def -> Rn Def
renameSc old_names (DBind (Bind name t _ parms rhs)) = do
(new_names, parms') <- newNames old_names parms
rhs' <- snd <$> renameExp new_names rhs
pure . DBind $ Bind name t name parms' rhs'
renameSc _ def = pure def
initNames = Map.fromList [dupe (coerce name) | DBind (Bind name _ _) <- defs]
renameDef :: Def -> Rn Def
renameDef = \case
DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ
DBind (Bind name vars rhs) -> do
(new_names, vars') <- newNames initNames (coerce vars)
rhs' <- snd <$> renameExp new_names rhs
pure . DBind $ Bind name (coerce vars') rhs'
DData (Data (Indexed cname types) constrs) -> do
tvars_ <- tvars
tvars' <- mapM nextNameTVar tvars_
let tvars_lt = zip tvars_ tvars'
typ' = map (substituteTVar tvars_lt) types
constrs' = map (renameConstr tvars_lt) constrs
pure . DData $ Data (Indexed cname typ') constrs'
where
tvars = concat <$> mapM (collectTVars []) types
collectTVars :: [TVar] -> Type -> Rn [TVar]
collectTVars tvars = \case
TAll tvar t -> collectTVars (tvar : tvars) t
TIndexed _ -> return tvars
-- Should be monad error
TVar v -> return [v]
_ -> throwError ("Bad data type definition: " ++ show types)
renameConstr :: [(TVar, TVar)] -> Constructor -> Constructor
renameConstr new_types (Constructor name typ) =
Constructor name $ substituteTVar new_types typ
{- | Rename type variables inside a type using an association list of
@(old, new)@ pairs. Both variable occurrences and @forall@ binders are
renamed; variables without an entry in the list are left untouched.
Any type constructor not handled below aborts with an error.
-}
substituteTVar :: [(TVar, TVar)] -> Type -> Type
substituteTVar renamings = go
  where
    go :: Type -> Type
    go = \case
        lit@(TLit _) -> lit
        TVar v -> TVar (renamed v)
        TFun dom cod -> TFun (go dom) (go cod)
        TAll v body -> TAll (renamed v) (go body)
        TIndexed (Indexed name args) -> TIndexed (Indexed name (map go args))
        other -> error ("Impossible " ++ show other)

    -- Look a variable up in the renaming list, defaulting to itself.
    renamed :: TVar -> TVar
    renamed v = fromMaybe v (lookup v renamings)
initCxt :: Cxt
initCxt = Cxt 0 0
-- | State threaded through the renamer: two independent counters used to
-- build unique suffixes.
data Cxt = Cxt
{ var_counter :: Int
-- ^ read and incremented by 'makeName' for term-level names
, tvar_counter :: Int
-- ^ read and incremented by 'nextNameTVar' for type variables
}
-- | Rename monad. State holds the number of renamed names.
newtype Rn a = Rn {runRn :: State Int a}
deriving (Functor, Applicative, Monad, MonadState Int)
newtype Rn a = Rn {runRn :: StateT Cxt (ExceptT String Identity) a}
deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String)
-- | Maps old to new name
type Names = Map Ident Ident
renameLocalBind :: Names -> Bind -> Rn (Names, Bind)
renameLocalBind old_names (Bind name t _ parms rhs) = do
(new_names, name') <- newName old_names name
(new_names', parms') <- newNames new_names parms
(new_names'', rhs') <- renameExp new_names' rhs
pure (new_names'', Bind name' t name' parms' rhs')
type Names = Map LIdent LIdent
renameExp :: Names -> Exp -> Rn (Names, Exp)
renameExp old_names = \case
EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names)
ELit (LInt i1) -> pure (old_names, ELit (LInt i1))
EVar n -> pure (coerce old_names, EVar . fromMaybe n $ Map.lookup n old_names)
ECons n -> pure (old_names, ECons n)
ELit lit -> pure (old_names, ELit lit)
EApp e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
@ -53,25 +108,25 @@ renameExp old_names = \case
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EAdd e1' e2')
ESub e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, ESub e1' e2')
ELet i e1 e2 -> do
(new_names, e1') <- renameExp old_names e1
(new_names', e2') <- renameExp new_names e2
pure (new_names', ELet i e1' e2')
-- TODO fix shadowing
ELet name rhs e -> do
(new_names, name') <- newName old_names (coerce name)
(new_names', rhs') <- renameExp new_names rhs
(new_names'', e') <- renameExp new_names' e
pure (new_names'', ELet (coerce name') rhs' e')
EAbs par e -> do
(new_names, par') <- newName old_names par
(new_names, par') <- newName old_names (coerce par)
(new_names', e') <- renameExp new_names e
pure (new_names', EAbs par' e')
pure (new_names', EAbs (coerce par') e')
EAnn e t -> do
(new_names, e') <- renameExp old_names e
pure (new_names, EAnn e' t)
t' <- renameTVars t
pure (new_names, EAnn e' t')
ECase e injs -> do
(_, e') <- renameExp old_names e
(new_names, injs') <- renameInjs old_names injs
pure (new_names, ECase e' injs')
(new_names, e') <- renameExp old_names e
(new_names', injs') <- renameInjs new_names injs
pure (new_names', ECase e' injs')
renameInjs :: Names -> [Inj] -> Rn (Names, [Inj])
renameInjs ns xs = do
@ -80,19 +135,64 @@ renameInjs ns xs = do
renameInj :: Names -> Inj -> Rn (Names, Inj)
renameInj ns (Inj init e) = do
(new_names, e') <- renameExp ns e
return (new_names, Inj init e')
(new_names, init') <- renameInit ns init
(new_names', e') <- renameExp new_names e
return (new_names', Inj init' e')
{- | Rename the variables bound by a case pattern. Constructor patterns get
fresh names for each bound variable (added to the name environment); every
other pattern is returned unchanged together with the unchanged environment.
-}
renameInit :: Names -> Init -> Rn (Names, Init)
renameInit names = \case
    InitConstructor constr vars -> do
        (names', freshVars) <- newNames names (coerce vars)
        pure (names', InitConstructor constr (coerce freshVars))
    other -> pure (names, other)
{- | Give every @forall@-bound type variable a fresh name. The traversal only
descends through @forall@ and function types; any other type is returned
as it is.
-}
renameTVars :: Type -> Rn Type
renameTVars = \case
    TAll tvar body -> do
        freshTVar <- nextNameTVar tvar
        -- Rename the occurrences first, then continue with nested binders.
        TAll freshTVar <$> renameTVars (substitute tvar freshTVar body)
    TFun dom cod -> liftA2 TFun (renameTVars dom) (renameTVars cod)
    other -> pure other
{- | Capture-avoiding renaming of a type variable: @substitute α α_n A@
computes @[α_n/α]A@, replacing the free occurrences of @α@ in @A@.

A nested @forall@ that binds the same variable @α@ shadows the outer one, so
nothing below it refers to the variable being replaced and that subtree is
returned unchanged. Type constructors not handled below abort with an error.
-}
substitute ::
    TVar -> -- α
    TVar -> -- α_n
    Type -> -- A
    Type -- [α_n/α]A
substitute tvar1 tvar2 typ = case typ of
    TLit _ -> typ
    TVar tvar'
        | tvar' == tvar1 -> TVar tvar2
        | otherwise -> typ
    TFun t1 t2 -> on TFun substitute' t1 t2
    TAll tvar t
        -- The inner binder shadows α: its body has no free α to rename.
        | tvar == tvar1 -> typ
        | otherwise -> TAll tvar $ substitute' t
    TIndexed (Indexed name typs) -> TIndexed . Indexed name $ map substitute' typs
    _ -> error "Impossible"
  where
    substitute' = substitute tvar1 tvar2
-- | Create a new name and add it to name environment.
newName :: Names -> Ident -> Rn (Names, Ident)
newName :: Names -> LIdent -> Rn (Names, LIdent)
newName env old_name = do
new_name <- makeName old_name
pure (Map.insert old_name new_name env, new_name)
-- | Create multiple names and add them to the name environment
newNames :: Names -> [Ident] -> Rn (Names, [Ident])
newNames :: Names -> [LIdent] -> Rn (Names, [LIdent])
newNames = mapAccumM newName
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
makeName :: Ident -> Rn Ident
makeName (Ident prefix) = gets (\i -> Ident $ prefix ++ "_" ++ show i) <* modify succ
-- | Build a unique name @prefix_n@ from the current variable counter and
-- advance the counter by one.
makeName :: LIdent -> Rn LIdent
makeName (LIdent prefix) = do
    n <- gets var_counter
    modify $ \cxt -> cxt{var_counter = succ n}
    pure . LIdent $ prefix ++ "_" ++ show n
-- | Build a unique type variable @name_n@ from the current type-variable
-- counter and advance that counter by one.
nextNameTVar :: TVar -> Rn TVar
nextNameTVar (MkTVar (LIdent base)) = do
    n <- gets tvar_counter
    modify $ \cxt -> cxt{tvar_counter = succ n}
    pure . MkTVar . LIdent $ base ++ "_" ++ show n

View file

@ -1,25 +1,33 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
-- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeChecker where
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Foldable (traverse_)
import Data.Functor.Identity (runIdentity)
import Data.List (foldl')
import Data.Map (Map)
import qualified Data.Map as M
import Data.Set (Set)
import qualified Data.Set as S
import Debug.Trace (trace)
import Grammar.Abs
import Grammar.Print (printTree)
import qualified TypeChecker.TypeCheckerIr as T
import TypeChecker.TypeCheckerIr (Ctx (..), Env (..), Error, Infer,
Poly (..), Subst)
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (second)
import Data.Coerce (coerce)
import Data.Foldable (traverse_)
import Data.Functor.Identity (runIdentity)
import Data.List (foldl')
import Data.List.Extra (unsnoc)
import Data.Map (Map)
import Data.Map qualified as M
import Data.Set (Set)
import Data.Set qualified as S
import Debug.Trace (trace)
import Grammar.Abs
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr (
Ctx (..),
Env (..),
Error,
Infer,
Subst,
)
import TypeChecker.TypeCheckerIr qualified as T
initCtx = Ctx mempty
@ -37,51 +45,17 @@ runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e
typecheck :: Program -> Either Error T.Program
typecheck = run . checkPrg
{- | Start by freshening the type variable of data types to avoid clash with
other user defined polymorphic types
This might be wrong for type constructors that work over several variables
-}
freshenData :: Data -> Infer Data
freshenData (Data (Constr name ts) constrs) = do
fr <- fresh
let fr' = case fr of
TPol a -> a
-- Meh, this part assumes fresh generates a polymorphic type
_ ->
error
"Bug: implementation of \
\ fresh and freshenData are not compatible"
let new_ts = map (freshenType fr') ts
let new_constrs = map (freshenConstr fr') constrs
return $ Data (Constr name new_ts) new_constrs
{- | Freshen all polymorphic variables, regardless of name
| freshenType "d" (a -> b -> c) becomes (d -> d -> d)
-}
freshenType :: Ident -> Type -> Type
freshenType iden = \case
(TPol _) -> TPol iden
(TArr a b) -> TArr (freshenType iden a) (freshenType iden b)
(TConstr (Constr a ts)) ->
TConstr (Constr a (map (freshenType iden) ts))
rest -> rest
freshenConstr :: Ident -> Constructor -> Constructor
freshenConstr iden (Constructor name t) =
Constructor name (freshenType iden t)
checkData :: Data -> Infer ()
checkData d = do
d' <- freshenData d
case d' of
(Data typ@(Constr name ts) constrs) -> do
case d of
(Data typ@(Indexed name ts) constrs) -> do
unless
(all isPoly ts)
(throwError $ unwords ["Data type incorrectly declared"])
traverse_
( \(Constructor name' t') ->
if TConstr typ == retType t'
then insertConstr name' t'
if TIndexed typ == retType t'
then insertConstr (coerce name') (toNew t')
else
throwError $
unwords
@ -96,19 +70,30 @@ checkData d = do
constrs
retType :: Type -> Type
retType (TArr _ t2) = retType t2
retType a = a
retType (TFun _ t2) = retType t2
retType a = a
checkPrg :: Program -> Infer T.Program
checkPrg (Program bs) = do
preRun bs
T.Program <$> checkDef bs
-- Type check the program twice to produce all top-level types in the first pass through
bs' <- checkDef bs
trace "\nFIRST ITERATION" return ()
trace (printTree bs' ++ "\nSECOND ITERATION\n") return ()
bs'' <- checkDef bs
return $ T.Program bs''
where
preRun :: [Def] -> Infer ()
preRun [] = return ()
preRun (x : xs) = case x of
DBind (Bind n t _ _ _) -> insertSig n t >> preRun xs
DData d@(Data _ _) -> checkData d >> preRun xs
-- TODO: Check for no overlapping signature definitions
DSig (Sig n t) -> insertSig (coerce n) (Just $ toNew t) >> preRun xs
DBind (Bind n _ _) -> do
s <- gets sigs
case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs
DData d@(Data _ _) -> checkData d >> preRun xs
checkDef :: [Def] -> Infer [T.Def]
checkDef [] = return []
@ -117,79 +102,75 @@ checkPrg (Program bs) = do
b' <- checkBind b
fmap (T.DBind b' :) (checkDef xs)
(DData d) -> fmap (T.DData d :) (checkDef xs)
(DSig _) -> checkDef xs
checkBind :: Bind -> Infer T.Bind
checkBind (Bind n t _ args e) = do
(t', e') <- inferExp $ makeLambda e (reverse args)
s <- unify t t'
let t'' = apply s t
unless
(t `typeEq` t'')
( throwError $
unwords
[ "Top level signature"
, printTree t
, "does not match body with inferred type:"
, printTree t''
]
)
return $ T.Bind (n, t) e'
checkBind (Bind name args e) = do
let lambda = makeLambda e (reverse $ coerce args)
e@(_, t') <- inferExp lambda
s <- gets sigs
case M.lookup (coerce name) s of
Just (Just t) -> do
sub <- unify t t'
let newT = apply sub t
insertSig (coerce name) (Just newT)
return $ T.Bind (coerce name, newT) [] e
_ -> do
insertSig (coerce name) (Just t')
return (T.Bind (coerce name, t') [] e) -- (apply s e)
where
makeLambda :: Exp -> [Ident] -> Exp
makeLambda = foldl (flip EAbs)
makeLambda = foldl (flip (EAbs . coerce))
{- | Check if two types are considered equal
For the purpose of the algorithm two polymorphic types are always considered
equal
-}
typeEq :: Type -> Type -> Bool
typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r'
typeEq (TMono a) (TMono b) = a == b
typeEq (TConstr (Constr name a)) (TConstr (Constr name' b)) =
length a == length b
&& name == name'
&& and (zipWith typeEq a b)
typeEq (TPol _) (TPol _) = True
typeEq _ _ = False
isMoreSpecificOrEq :: Type -> Type -> Bool
isMoreSpecificOrEq _ (TPol _) = True
isMoreSpecificOrEq (TArr a b) (TArr c d) =
isMoreSpecificOrEq :: T.Type -> T.Type -> Bool
isMoreSpecificOrEq _ (T.TAll _ _) = True
isMoreSpecificOrEq (T.TFun a b) (T.TFun c d) =
isMoreSpecificOrEq a c && isMoreSpecificOrEq b d
isMoreSpecificOrEq (TConstr (Constr n1 ts1)) (TConstr (Constr n2 ts2)) =
isMoreSpecificOrEq (T.TIndexed (T.Indexed n1 ts1)) (T.TIndexed (T.Indexed n2 ts2)) =
n1 == n2
&& length ts1 == length ts2
&& and (zipWith isMoreSpecificOrEq ts1 ts2)
isMoreSpecificOrEq a b = a == b
isPoly :: Type -> Bool
isPoly (TPol _) = True
isPoly _ = False
isPoly (TAll _ _) = True
isPoly (TVar _) = True
isPoly _ = False
inferExp :: Exp -> Infer (Type, T.Exp)
inferExp :: Exp -> Infer T.ExpT
inferExp e = do
(s, t, e') <- algoW e
(s, (e', t)) <- algoW e
let subbed = apply s t
return (subbed, replace subbed e')
return $ replace subbed (e', t)
replace :: Type -> T.Exp -> T.Exp
replace t = \case
T.ELit _ e -> T.ELit t e
T.EId (n, _) -> T.EId (n, t)
T.EAbs _ name e -> T.EAbs t name e
T.EApp _ e1 e2 -> T.EApp t e1 e2
T.EAdd _ e1 e2 -> T.EAdd t e1 e2
T.ESub _ e1 e2 -> T.ESub t e1 e2
T.ELet (T.Bind (n, _) e1) e2 -> T.ELet (T.Bind (n, t) e1) e2
T.ECase _ expr injs -> T.ECase t expr injs
replace :: T.Type -> T.ExpT -> T.ExpT
replace t = second (const t)
algoW :: Exp -> Infer (Subst, Type, T.Exp)
-- | Conversion from a parser-level syntax type @a@ to its type-checker
-- counterpart @b@.
class NewType a b where
    toNew :: a -> b

-- | Structural translation of types; existential variables are rejected.
instance NewType Type T.Type where
    toNew (TLit i) = T.TLit (coerce i)
    toNew (TVar v) = T.TVar (toNew v)
    toNew (TFun t1 t2) = T.TFun (toNew t1) (toNew t2)
    toNew (TAll b t) = T.TAll (toNew b) (toNew t)
    toNew (TIndexed i) = T.TIndexed (toNew i)
    toNew (TEVar _) = error "Should not exist after typechecker"

instance NewType Indexed T.Indexed where
    toNew (Indexed name vars) = T.Indexed (coerce name) (map toNew vars)

instance NewType TVar T.TVar where
    toNew (MkTVar i) = T.MkTVar (coerce i)
algoW :: Exp -> Infer (Subst, T.ExpT)
algoW = \case
-- \| TODO: More testing need to be done. Unsure of the correctness of this
EAnn e t -> do
(s1, t', e') <- algoW e
(s1, (e', t')) <- algoW e
unless
(t `isMoreSpecificOrEq` t')
(toNew t `isMoreSpecificOrEq` t')
( throwError $
unwords
[ "Annotated type:"
@ -199,34 +180,34 @@ algoW = \case
]
)
applySt s1 $ do
s2 <- unify t t'
return (s2 `compose` s1, t, e')
s2 <- unify (toNew t) t'
let comp = s2 `compose` s1
return (comp, apply comp (e', toNew t))
-- \| ------------------
-- \| Γ ⊢ i : Int, ∅
ELit (LInt n) ->
return (nullSubst, TMono "Int", T.ELit (TMono "Int") (LInt n))
ELit a -> error $ "NOT IMPLEMENTED YET: ELit " ++ show a
ELit lit ->
let lt = litType lit
in return (nullSubst, (T.ELit lit, lt))
-- \| x : σ ∈ Γ τ = inst(σ)
-- \| ----------------------
-- \| Γ ⊢ x : τ, ∅
EId i -> do
EVar i -> do
var <- asks vars
case M.lookup i var of
Just t -> inst t >>= \x -> return (nullSubst, x, T.EId (i, x))
case M.lookup (coerce i) var of
Just t -> inst t >>= \x -> return (nullSubst, (T.EId $ coerce i, x))
Nothing -> do
sig <- gets sigs
case M.lookup i sig of
Just t -> return (nullSubst, t, T.EId (i, t))
Nothing -> do
constr <- gets constructors
case M.lookup i constr of
Just t -> return (nullSubst, t, T.EId (i, t))
Nothing ->
throwError $
"Unbound variable: " ++ show i
case M.lookup (coerce i) sig of
Just (Just t) -> return (nullSubst, (T.EId $ coerce i, t))
Just Nothing -> (\x -> (nullSubst, (T.EId $ coerce i, x))) <$> fresh
Nothing -> throwError $ "Unbound variable: " ++ printTree i
ECons i -> do
constr <- gets constructors
case M.lookup (coerce i) constr of
Just t -> return (nullSubst, (T.EId $ coerce i, t))
Nothing -> throwError $ "Constructor: '" ++ printTree i ++ "' is not defined"
-- \| τ = newvar Γ, x : τ ⊢ e : τ', S
-- \| ---------------------------------
@ -234,11 +215,11 @@ algoW = \case
EAbs name e -> do
fr <- fresh
withBinding name (Forall [] fr) $ do
(s1, t', e') <- algoW e
withBinding (coerce name) fr $ do
(s1, (e', t')) <- algoW e
let varType = apply s1 fr
let newArr = TArr varType t'
return (s1, newArr, T.EAbs newArr (name, varType) e')
let newArr = T.TFun varType t'
return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr))
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
-- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
@ -247,29 +228,16 @@ algoW = \case
-- This might be wrong
EAdd e0 e1 -> do
(s1, t0, e0') <- algoW e0
(s1, (e0', t0)) <- algoW e0
applySt s1 $ do
(s2, t1, e1') <- algoW e1
(s2, (e1', t1)) <- algoW e1
-- applySt s2 $ do
s3 <- unify (apply s2 t0) (TMono "Int")
s4 <- unify (apply s3 t1) (TMono "Int")
s3 <- unify (apply s2 t0) int
s4 <- unify (apply s3 t1) int
let comp = s4 `compose` s3 `compose` s2 `compose` s1
return
( s4 `compose` s3 `compose` s2 `compose` s1
, TMono "Int"
, T.EAdd (TMono "Int") e0' e1'
)
ESub e0 e1 -> do
(s1, t0, e0') <- algoW e0
applySt s1 $ do
(s2, t1, e1') <- algoW e1
-- applySt s2 $ do
s3 <- unify (apply s2 t0) (TMono "Int")
s4 <- unify (apply s3 t1) (TMono "Int")
return
( s4 `compose` s3 `compose` s2 `compose` s1
, TMono "Int"
, T.ESub (TMono "Int") e0' e1'
( comp
, apply comp (T.EAdd (e0', t0) (e1', t1), int)
)
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1
@ -279,13 +247,13 @@ algoW = \case
EApp e0 e1 -> do
fr <- fresh
(s0, t0, e0') <- algoW e0
(s0, (e0', t0)) <- algoW e0
applySt s0 $ do
(s1, t1, e1') <- algoW e1
-- applySt s1 $ do
s2 <- unify (apply s1 t0) (TArr t1 fr)
(s1, (e1', t1)) <- algoW e1
s2 <- unify (apply s1 t0) (T.TFun t1 fr)
let t = apply s2 fr
return (s2 `compose` s1 `compose` s0, t, T.EApp t e0' e1')
let comp = s2 `compose` s1 `compose` s0
return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t))
-- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁
-- \| ----------------------------------------------
@ -294,39 +262,37 @@ algoW = \case
-- The bar over S₀ and Γ means "generalize"
ELet name e0 e1 -> do
(s1, t1, e0') <- algoW e0
(s1, (e0', t1)) <- algoW e0
env <- asks vars
let t' = generalize (apply s1 env) t1
withBinding name t' $ do
(s2, t2, e1') <- algoW e1
return (s2 `compose` s1, t2, T.ELet (T.Bind (name, t2) e0') e1')
withBinding (coerce name) t' $ do
(s2, (e1', t2)) <- algoW e1
let comp = s2 `compose` s1
return (comp, apply comp (T.ELet (T.Bind (coerce name, t2) [] (e0', t1)) (e1', t2), t2))
-- \| TODO: Add judgement
ECase caseExpr injs -> do
(_, t0, e0') <- algoW caseExpr
(injs', ts) <- mapAndUnzipM (checkInj t0) injs
case ts of
[] -> throwError "Case expression missing any matches"
ts -> do
unified <- zipWithM unify ts (tail ts)
let unified' = foldl' compose mempty unified
let typ = apply unified' (head ts)
return (unified', typ, T.ECase typ e0' injs')
(sub, (e', t)) <- algoW caseExpr
(subst, injs, ret_t) <- checkCase t injs
let comp = subst `compose` sub
let t' = apply comp ret_t
return (comp, (T.ECase (e', t) injs, t'))
-- | Unify two types producing a new substitution
unify :: Type -> Type -> Infer Subst
unify :: T.Type -> T.Type -> Infer Subst
unify t0 t1 = do
trace ("t0: " ++ show t0) return ()
trace ("t1: " ++ show t1) return ()
case (t0, t1) of
(TArr a b, TArr c d) -> do
(T.TFun a b, T.TFun c d) -> do
s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d)
return $ s1 `compose` s2
(TPol a, b) -> occurs a b
(a, TPol b) -> occurs b a
(TMono a, TMono b) ->
(T.TVar (T.MkTVar a), t) -> occurs a t
(t, T.TVar (T.MkTVar b)) -> occurs b t
(T.TAll _ t, b) -> unify t b
(a, T.TAll _ t) -> unify a t
(T.TLit a, T.TLit b) ->
if a == b then return M.empty else throwError "Types do not unify"
-- \| TODO: Figure out a cleaner way to express the same thing
(TConstr (Constr name t), TConstr (Constr name' t')) ->
(T.TIndexed (T.Indexed name t), T.TIndexed (T.Indexed name' t')) ->
if name == name' && length t == length t'
then do
xs <- zipWithM unify t t'
@ -334,56 +300,71 @@ unify t0 t1 = do
else
throwError $
unwords
[ "Type constructor:"
[ "T.Type constructor:"
, printTree name
, "(" ++ printTree t ++ ")"
, "does not match with:"
, printTree name'
, "(" ++ printTree t' ++ ")"
]
(a, b) ->
(a, b) -> do
throwError . unwords $
[ "Type:"
, printTree a
, "can't be unified with:"
, printTree b
[ "'" ++ printTree a ++ "'"
, "can't be unified with"
, "'" ++ printTree b ++ "'"
]
{- | Check if a type is contained in another type.
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
such that these are equal
where these are equal
-}
occurs :: Ident -> Type -> Infer Subst
occurs _ (TPol _) = return nullSubst
occurs :: Ident -> T.Type -> Infer Subst
occurs i t@(T.TVar _) = return (M.singleton i t)
occurs i t =
if S.member i (free t)
then
throwError $
unwords
[ "Occurs check failed, can't unify"
, printTree (TPol i)
, printTree (T.TVar $ T.MkTVar i)
, "with"
, printTree t
]
else return $ M.singleton i t
-- | Generalize a type over all free variables in the substitution set
generalize :: Map Ident Poly -> Type -> Poly
generalize env t = Forall (S.toList $ free t S.\\ free env) t
{- | Quantify a type over the variables reported free in it but not in the
environment. Existing @forall@s at the top level and along function arrows are
stripped first, then one @forall@ per such variable is wrapped around the
result.
-}
generalize :: Map Ident T.Type -> T.Type -> T.Type
generalize env t = foldr (T.TAll . T.MkTVar) (stripForalls t) quantified
  where
    -- Variables free in the type (as given) and not free in the environment.
    quantified :: [Ident]
    quantified = S.toList (free t `S.difference` free env)

    stripForalls :: T.Type -> T.Type
    stripForalls = \case
        T.TAll _ body -> stripForalls body
        T.TFun dom cod -> T.TFun (stripForalls dom) (stripForalls cod)
        other -> other
{- | Instantiate a polymorphic type. The free type variables are substituted
with fresh ones.
-}
inst :: Poly -> Infer Type
inst (Forall xs t) = do
xs' <- mapM (const fresh) xs
let s = M.fromList $ zip xs xs'
return $ apply s t
-- | Instantiate a type: each @forall@ met at the top level or along function
-- arrows is removed and its bound variable replaced by a fresh one.
inst :: T.Type -> Infer T.Type
inst (T.TAll (T.MkTVar bound) body) = do
    fr <- fresh
    body' <- inst body
    return (apply (M.singleton bound fr) body')
inst (T.TFun dom cod) = do
    dom' <- inst dom
    cod' <- inst cod
    return (T.TFun dom' cod')
inst other = return other
-- | Compose two substitutions: the first is applied to every type in the
-- second, and the rewritten entries are then merged with the first. On a key
-- clash the rewritten entry from the second substitution wins.
compose :: Subst -> Subst -> Subst
compose outer inner = M.union (M.map (apply outer) inner) outer
-- TODO: Split this class into two separate classes, one for free variables
-- and one for applying substitutions
-- | A class representing free variables functions
class FreeVars t where
-- | Get all free variables from t
@ -392,37 +373,59 @@ class FreeVars t where
-- | Apply a substitution to t
apply :: Subst -> t -> t
instance FreeVars Type where
free :: Type -> Set Ident
free (TPol a) = S.singleton a
free (TMono _) = mempty
free (TArr a b) = free a `S.union` free b
-- | Free type variables of a type-checker type.
instance FreeVars T.Type where
free :: T.Type -> Set Ident
-- A variable is free in itself.
free (T.TVar (T.MkTVar a)) = S.singleton a
-- NOTE(review): for a binder this yields the bound variable when it occurs in
-- the body, i.e. it does not remove it from the body's variables and drops
-- every other variable of the body. That looks inverted for a "free
-- variables" function, but 'generalize' computes its variable set with this
-- clause before stripping quantifiers, so presumably any change here has to
-- be made together with 'generalize' -- TODO confirm.
free (T.TAll (T.MkTVar bound) t) = S.singleton bound `S.intersection` free t
-- Literal types contain no variables.
free (T.TLit _) = mempty
free (T.TFun a b) = free a `S.union` free b
-- \| Not guaranteed to be correct
free (TConstr (Constr _ a)) =
free (T.TIndexed (T.Indexed _ a)) =
foldl' (\acc x -> free x `S.union` acc) S.empty a
apply :: Subst -> Type -> Type
apply :: Subst -> T.Type -> T.Type
apply sub t = do
case t of
TMono a -> TMono a
TPol a -> case M.lookup a sub of
Nothing -> TPol a
Just t -> t
TArr a b -> TArr (apply sub a) (apply sub b)
TConstr (Constr name a) -> TConstr (Constr name (map (apply sub) a))
T.TLit a -> T.TLit a
T.TVar (T.MkTVar a) -> case M.lookup a sub of
Nothing -> T.TVar (T.MkTVar $ coerce a)
Just t -> t
T.TAll (T.MkTVar i) t -> case M.lookup i sub of
Nothing -> T.TAll (T.MkTVar i) (apply sub t)
Just _ -> apply sub t
T.TFun a b -> T.TFun (apply sub a) (apply sub b)
T.TIndexed (T.Indexed name a) -> T.TIndexed (T.Indexed name (map (apply sub) a))
instance FreeVars Poly where
free :: Poly -> Set Ident
free (Forall xs t) = free t S.\\ S.fromList xs
apply :: Subst -> Poly -> Poly
apply s (Forall xs t) = Forall xs (apply (foldr M.delete s xs) t)
instance FreeVars (Map Ident Poly) where
free :: Map Ident Poly -> Set Ident
instance FreeVars (Map Ident T.Type) where
free :: Map Ident T.Type -> Set Ident
free m = foldl' S.union S.empty (map free $ M.elems m)
apply :: Subst -> Map Ident Poly -> Map Ident Poly
apply :: Subst -> Map Ident T.Type -> Map Ident T.Type
apply s = M.map (apply s)
instance FreeVars T.ExpT where
free :: T.ExpT -> Set Ident
free = error "free not implemented for T.Exp"
apply :: Subst -> T.ExpT -> T.ExpT
apply s = \case
(T.EId i, outerT) -> (T.EId i, apply s outerT)
(T.ELit lit, t) -> (T.ELit lit, apply s t)
(T.ELet (T.Bind (ident, t1) args e1) e2, t2) -> (T.ELet (T.Bind (ident, apply s t1) args (apply s e1)) (apply s e2), apply s t2)
(T.EApp e1 e2, t) -> (T.EApp (apply s e1) (apply s e2), apply s t)
(T.EAdd e1 e2, t) -> (T.EAdd (apply s e1) (apply s e2), apply s t)
(T.EAbs ident e, t1) -> (T.EAbs ident (apply s e), apply s t1)
(T.ECase e injs, t) -> (T.ECase (apply s e) (apply s injs), apply s t)
-- | Substitution over a typed case branch. Only 'apply' is supported:
-- 'free' is not implemented and aborts with a descriptive message, in line
-- with the instance for typed expressions.
instance FreeVars T.Inj where
    free :: T.Inj -> Set Ident
    free = error "free not implemented for T.Inj"
    -- Apply the substitution to both the pattern's type and the branch body.
    apply :: Subst -> T.Inj -> T.Inj
    apply s (T.Inj (i, t) e) = T.Inj (i, apply s t) (apply s e)
instance FreeVars [T.Inj] where
free :: [T.Inj] -> Set Ident
free = foldl' (\acc x -> free x `S.union` acc) mempty
apply s = map (apply s)
-- | Apply substitutions to the environment.
applySt :: Subst -> Infer a -> Infer a
applySt s = local (\st -> st{vars = apply s (vars st)})
@ -432,86 +435,85 @@ nullSubst :: Subst
nullSubst = M.empty
-- | Generate a new fresh variable and increment the state counter
fresh :: Infer Type
fresh :: Infer T.Type
fresh = do
n <- gets count
modify (\st -> st{count = n + 1})
return . TPol . Ident $ show n
return . T.TVar . T.MkTVar . Ident $ show n
-- | Run the monadic action with an additional binding
withBinding :: (Monad m, MonadReader Ctx m) => Ident -> Poly -> m a -> m a
withBinding :: (Monad m, MonadReader Ctx m) => Ident -> T.Type -> m a -> m a
withBinding i p = local (\st -> st{vars = M.insert i p (vars st)})
-- | Run the monadic action with a list of extra variable bindings in scope.
-- Bindings are inserted left to right, so a later entry for the same
-- identifier replaces an earlier one.
withBindings :: (Monad m, MonadReader Ctx m) => [(Ident, T.Type)] -> m a -> m a
withBindings bindings = local $ \ctx -> ctx{vars = foldl' insertBinding (vars ctx) bindings}
  where
    insertBinding acc (ident, typ) = M.insert ident typ acc
-- | Insert a function signature into the environment
insertSig :: Ident -> Type -> Infer ()
insertSig :: Ident -> Maybe T.Type -> Infer ()
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
-- | Insert a constructor with its data type
insertConstr :: Ident -> Type -> Infer ()
insertConstr :: Ident -> T.Type -> Infer ()
insertConstr i t =
modify (\st -> st{constructors = M.insert i t (constructors st)})
-------- PATTERN MATCHING ---------
-- "case expr of", the type of 'expr' is caseType
checkInj :: Type -> Inj -> Infer (T.Inj, Type)
checkInj caseType (Inj it expr) = do
(args, t') <- initType caseType it
(_, t, e') <- local (\st -> st{vars = args `M.union` vars st}) (algoW expr)
return (T.Inj (it, t') e', t)
{- | Type check the branches of a case expression.

The first argument is the type of the scrutinee. Every branch pattern type is
unified with it in turn, and all branch result types are unified with each
other. Returns the combined substitution, the typed branches, and the common
result type.

Throws an error when the case expression has no branches, instead of
crashing on the empty list of branch results.
-}
checkCase :: T.Type -> [Inj] -> Infer (Subst, [T.Inj], T.Type)
checkCase expT injs = do
    (injTs, injs', returns) <- unzip3 <$> mapM checkInj injs
    case returns of
        [] -> throwError "Case expression missing any matches"
        firstRet : restRets -> do
            (sub1, _) <- foldM unifyStep (nullSubst, expT) injTs
            (sub2, returnsType) <- foldM unifyStep (nullSubst, firstRet) restRets
            return (sub2 `compose` sub1, injs', returnsType)
  where
    -- Unify the next type with the accumulated one, extending the running
    -- substitution and refining the accumulated type.
    unifyStep :: (Subst, T.Type) -> T.Type -> Infer (Subst, T.Type)
    unifyStep (sub, acc) x = do
        s <- unify x acc
        return (s `compose` sub, apply s acc)
initType :: Type -> Init -> Infer (Map Ident Poly, Type)
initType expected = \case
InitLit lit ->
let returnType = litType lit
in if expected == returnType
then return (mempty, expected)
else
throwError $
unwords
[ "Inferred type"
, printTree returnType
, "does not match expected type:"
, printTree expected
]
InitConstr c args -> do
st <- gets constructors
case M.lookup c st of
{- | Type check a single case branch. The result is a triple of
the type of the pattern, the typed branch, and the type of the branch body.
The variables bound by the pattern are in scope while the body is inferred.
-}
checkInj :: Inj -> Infer (T.Type, T.Inj, T.Type)
checkInj (Inj pat body) = do
    (patT, bound) <- inferInit pat
    typedBody <- withBindings bound (inferExp body)
    let bodyT = snd typedBody
    return (patT, T.Inj (pat, patT) typedBody, bodyT)
inferInit :: Init -> Infer (T.Type, [T.Id])
inferInit = \case
InitLit lit -> return (litType lit, mempty)
InitConstructor fn vars -> do
gets (M.lookup (coerce fn) . constructors) >>= \case
Nothing ->
throwError $
unwords
[ "Constructor:"
, printTree c
, "does not exist"
]
Just t -> do
let flat = flattenType t
let returnType = last flat
case ( length (init flat) == length args
, returnType `isMoreSpecificOrEq` expected
) of
(True, True) ->
return
( M.fromList $ zip args (map (Forall []) flat)
, expected
)
(False, _) ->
throwError $
"Can't partially match on the constructor: "
++ printTree c
(_, False) ->
throwError $
unwords
[ "Inferred type"
, printTree returnType
, "does not match expected type:"
, printTree expected
]
InitCatch -> return (mempty, expected)
"Constructor: " ++ printTree fn ++ " does not exist"
Just a -> do
case unsnoc $ flattenType a of
Nothing -> throwError "Partial pattern match not allowed"
Just (vs, ret) ->
case length vars `compare` length vs of
EQ -> do
return (ret, zip (coerce vars) vs)
_ -> throwError "Partial pattern match not allowed"
InitCatch -> (,mempty) <$> fresh
flattenType :: Type -> [Type]
flattenType (TArr a b) = flattenType a ++ flattenType b
flattenType a = [a]
-- | Flatten a function type into the list of its components, recursing into
-- both sides of every arrow; a non-function type yields a singleton list.
flattenType :: T.Type -> [T.Type]
flattenType = \case
    T.TFun dom cod -> flattenType dom ++ flattenType cod
    other -> [other]
litType :: Literal -> Type
litType (LInt _) = TMono "Int"
-- | The type assigned to a literal.
litType :: Lit -> T.Type
litType (LInt _) = int
litType (LChar _) = char

-- | The built-in literal type @Int@.
int :: T.Type
int = T.TLit "Int"

-- | The built-in literal type @Char@.
char :: T.Type
char = T.TLit "Char"

View file

@ -2,28 +2,30 @@
module TypeChecker.TypeCheckerIr where
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Functor.Identity (Identity)
import Data.Map (Map)
import Grammar.Abs (Data (..), Ident (..), Init (..),
Literal (..), Type (..))
import Grammar.Print
import Prelude
import qualified Prelude as C (Eq, Ord, Read, Show)
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Functor.Identity (Identity)
import Data.Map (Map)
import Grammar.Abs (
Data (..),
Ident (..),
Init (..),
Lit (..),
)
import Grammar.Print
import Prelude
import Prelude qualified as C (Eq, Ord, Read, Show)
-- | A 'Type' paired with a list of identifiers.
-- NOTE(review): judging by the constructor name, the identifiers are
-- presumably the type variables the type is quantified over -- confirm
-- against the inference code.
data Poly = Forall [Ident] Type
newtype Ctx = Ctx {vars :: Map Ident Type}
deriving (Show)
newtype Ctx = Ctx {vars :: Map Ident Poly}
data Env = Env
{ count :: Int
, sigs :: Map Ident Type
{ count :: Int
, sigs :: Map Ident (Maybe Type)
, constructors :: Map Ident Type
}
deriving (Show)
-- | Errors are reported as plain strings.
type Error = String
-- | A map from identifiers to types.
-- NOTE(review): presumably a type-variable substitution, judging by the
-- name -- confirm against its users.
type Subst = Map Ident Type
@ -33,18 +35,33 @@ type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))
-- | A whole program: the list of its top-level definitions.
newtype Program
    = Program [Def]
    deriving (C.Eq, C.Ord, C.Show, C.Read)
-- | A type variable, identified by an 'Ident'.
newtype TVar
    = MkTVar Ident
    deriving (Show, Eq, Ord, Read)
-- | Types of the type-checked intermediate representation.
data Type
    = -- | A named literal type, such as @Int@.
      TLit Ident
    | -- | A type variable.
      TVar TVar
    | -- | A function type, printed as @a -> b@.
      TFun Type Type
    | -- | A type quantified over one type variable, printed as @forall a . t@.
      TAll TVar Type
    | -- | A named type applied to a list of types; see 'Indexed'.
      TIndexed Indexed
    deriving (Show, Eq, Ord, Read)
data Exp
= EId Id
| ELit Type Literal
| ELet Bind Exp
| EApp Type Exp Exp
| EAdd Type Exp Exp
| ESub Type Exp Exp
| EAbs Type Id Exp
| ECase Type Exp [Inj]
= EId Ident
| ELit Lit
| ELet Bind ExpT
| EApp ExpT ExpT
| EAdd ExpT ExpT
| EAbs Ident ExpT
| ECase ExpT [Inj]
deriving (C.Eq, C.Ord, C.Read, C.Show)
data Inj = Inj (Init, Type) Exp
-- | An expression paired with a type.
type ExpT = (Exp, Type)
-- | An identifier together with a list of types.
data Indexed
    = Indexed Ident [Type]
    deriving (Show, Read, Ord, Eq)
-- | One alternative of a case expression: an 'Init' paired with a type,
-- followed by the typed expression of that alternative.
data Inj
    = Inj (Init, Type) ExpT
    deriving (C.Eq, C.Ord, C.Read, C.Show)
data Def = DBind Bind | DData Data
@ -52,22 +69,22 @@ data Def = DBind Bind | DData Data
-- | An identifier paired with a type.
type Id = (Ident, Type)
data Bind = Bind Id Exp
data Bind = Bind Id [Id] ExpT
deriving (C.Eq, C.Ord, C.Show, C.Read)
instance Print [Def] where
prt _ [] = concatD []
prt _ (x : xs) = concatD [prt 0 x, doc (showString "\n"), prt 0 xs]
prt _ [] = concatD []
prt _ (x : xs) = concatD [prt 0 x, doc (showString "\n\n"), prt 0 xs]
-- | A definition is printed by delegating to the printer of the binding or
-- data declaration it wraps, at the same precedence.
-- The repeated (unreachable) 'DData' equation has been dropped.
instance Print Def where
    prt i (DBind bind) = prt i bind
    prt i (DData d) = prt i d
-- | A program is printed as its list of definitions.
instance Print Program where
    prt i (Program defs) = prPrec i 0 (prt 0 defs)
instance Print Bind where
prt i (Bind (t, name) rhs) =
prt i (Bind (name, t) _ rhs) =
prPrec i 0 $
concatD
[ prt 0 name
@ -91,9 +108,11 @@ prtId :: Int -> Id -> Doc
prtId i (name, t) =
prPrec i 0 $
concatD
[ prt 0 name
[ doc $ showString "("
, prt 0 name
, doc $ showString ":"
, prt 0 t
, doc $ showString ")"
]
prtIdP :: Int -> Id -> Doc
@ -109,8 +128,8 @@ prtIdP i (name, t) =
instance Print Exp where
prt i = \case
EId n -> prPrec i 3 $ concatD [prtId 0 n, doc $ showString "\n"]
ELit _ (LInt i1) -> prPrec i 3 $ concatD [prt 0 i1, doc $ showString "\n"]
EId n -> prPrec i 3 $ concatD [prt 0 n]
ELit lit -> prPrec i 3 $ concatD [prt 0 lit]
ELet bs e ->
prPrec i 3 $
concatD
@ -118,46 +137,30 @@ instance Print Exp where
, prt 0 bs
, doc $ showString "in"
, prt 0 e
, doc $ showString "\n"
]
EApp _ e1 e2 ->
EApp e1 e2 ->
prPrec i 2 $
concatD
[ prt 2 e1
, prt 3 e2
]
EAdd t e1 e2 ->
EAdd e1 e2 ->
prPrec i 1 $
concatD
[ doc $ showString "@"
, prt 0 t
, prt 1 e1
, doc $ showString "+"
, prt 2 e2
, doc $ showString "\n"
]
ESub t e1 e2 ->
prPrec i 1 $
concatD
[ doc $ showString "@"
, prt 0 t
, prt 1 e1
, doc $ showString "-"
, prt 2 e2
, doc $ showString "\n"
]
EAbs t n e ->
EAbs n e ->
prPrec i 0 $
concatD
[ doc $ showString "@"
, prt 0 t
, doc $ showString "\\"
, prtId 0 n
[ doc $ showString "λ"
, prt 0 n
, doc $ showString "."
, prt 0 e
, doc $ showString "\n"
]
ECase t exp injs ->
ECase exp injs ->
prPrec
i
0
@ -169,16 +172,31 @@ instance Print Exp where
, prt 0 injs
, doc (showString "}")
, doc (showString ":")
, prt 0 t
, doc $ showString "\n"
]
)
-- | A typed expression is printed as the expression, a colon and the type,
-- all wrapped in parentheses.
instance Print ExpT where
    prt i (e, t) =
        concatD
            [ doc (showString "(")
            , prt i e
            , doc (showString ":")
            , prt i t
            , doc (showString ")")
            ]
-- | A case alternative is printed as its 'Init', a colon, the type, @=>@ and
-- the expression of the alternative.
instance Print Inj where
    prt i (Inj (pat, t) body) =
        prPrec i 0 $
            concatD
                [ prt 0 pat
                , doc (showString ":")
                , prt 0 t
                , doc (showString "=>")
                , prt 0 body
                ]
-- | A list of case alternatives is printed with @;@ between consecutive
-- elements; the empty list prints nothing.
-- The repeated (unreachable) equations for @[]@ and @[x]@ have been dropped.
instance Print [Inj] where
    prt _ [] = concatD []
    prt _ [x] = concatD [prt 0 x]
    prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
-- | A type variable is printed as the identifier it wraps.
instance Print TVar where
    prt i (MkTVar name) = prt i name
-- | Printer for types. Literal types and type variables are printed at
-- precedence 2, quantified and indexed types at precedence 1, and function
-- types at precedence 0.
instance Print Type where
    prt i (TLit name) =
        prPrec i 2 (concatD [prt 0 name])
    prt i (TVar tv) =
        prPrec i 2 (concatD [prt 0 tv])
    prt i (TFun arg res) =
        prPrec i 0 $
            concatD
                [ prt 1 arg
                , doc (showString "->")
                , prt 0 res
                ]
    prt i (TAll tv body) =
        prPrec i 1 $
            concatD
                [ doc (showString "forall")
                , prt 0 tv
                , doc (showString ".")
                , prt 0 body
                ]
    prt i (TIndexed ix) =
        prPrec i 1 (concatD [prt 0 ix])
-- | An indexed type is printed as its identifier followed by its list of
-- types.
instance Print Indexed where
    prt i (Indexed name args) =
        concatD
            [ prt i name
            , prt i args
            ]