Merge in mutual recursion handling

This commit is contained in:
sebastianselander 2023-03-31 18:26:58 +02:00
parent e2e469d84e
commit c4f78ca37d

View file

@ -6,16 +6,15 @@
-- | A module for type checking and inference using algorithm W, Hindley-Milner -- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeCheckerHm where module TypeChecker.TypeCheckerHm where
import Auxiliary (int, litType, maybeToRightM, tupSequence, unzip4) import Auxiliary (int, litType, maybeToRightM, unzip4)
import Auxiliary qualified as Aux import Auxiliary qualified as Aux
import Control.Monad.Except import Control.Monad.Except
import Control.Monad.Identity (Identity, runIdentity) import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader import Control.Monad.Reader
import Control.Monad.State import Control.Monad.State
import Data.Bifunctor (first)
import Data.Coerce (coerce) import Data.Coerce (coerce)
import Data.Function (on) import Data.Function (on)
import Data.List (foldl') import Data.List (foldl', intercalate)
import Data.List.Extra (unsnoc) import Data.List.Extra (unsnoc)
import Data.Map (Map) import Data.Map (Map)
import Data.Map qualified as M import Data.Map qualified as M
@ -40,19 +39,10 @@ typecheck = onLeft msg . run . checkPrg
checkPrg :: Program -> Infer (T.Program' Type) checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = do checkPrg (Program bs) = do
preRun bs preRun bs
bs <- checkDef bs (subs, bs) <- checkDef bs
sub0 <- solveUndecidable ctrace "SUBS" $ unionSubsts subs
bs <- mapM (mono sub0) bs
return $ T.Program bs return $ T.Program bs
mono :: Subst -> T.Def' Type -> Infer (T.Def' Type)
mono s bind@(T.DBind (T.Bind (name, t) args e)) = do
b <- gets (S.member name . toDecide)
if b
then return $ T.DBind $ T.Bind (name, apply s t) (apply s args) (apply s e)
else return bind
mono _ (T.DData d) = return $ T.DData d
preRun :: [Def] -> Infer () preRun :: [Def] -> Infer ()
preRun [] = return () preRun [] = return ()
preRun (x : xs) = case x of preRun (x : xs) = case x of
@ -62,7 +52,8 @@ preRun (x : xs) = case x of
duplicateDecl n s $ Aux.do duplicateDecl n s $ Aux.do
"Multiple signatures of function" "Multiple signatures of function"
quote $ printTree n quote $ printTree n
insertSig (coerce n) (Just t) >> preRun xs insertSig (coerce n) t
preRun xs
DBind (Bind n _ e) -> do DBind (Bind n _ e) -> do
s <- gets (S.toList . declaredBinds) s <- gets (S.toList . declaredBinds)
duplicateDecl n s $ Aux.do duplicateDecl n s $ Aux.do
@ -70,43 +61,46 @@ preRun (x : xs) = case x of
quote $ printTree n quote $ printTree n
collect (collectTVars e) collect (collectTVars e)
insertBind $ coerce n insertBind $ coerce n
s <- gets sigs sigs <- gets sigs
case M.lookup (coerce n) s of case M.lookup (coerce n) sigs of
Nothing -> insertSig (coerce n) Nothing >> preRun xs Nothing -> do
fr <- fresh
insertSig (coerce n) fr
preRun xs
Just _ -> preRun xs Just _ -> preRun xs
DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs
where where
-- Check if function body / signature has been declared already -- Check if function body / signature has been declared already
duplicateDecl :: (Monad m, MonadError Error m) => LIdent -> [T.Ident] -> String -> m ()
duplicateDecl n env msg = when (coerce n `elem` env) (uncatchableErr msg) duplicateDecl n env msg = when (coerce n `elem` env) (uncatchableErr msg)
checkDef :: [Def] -> Infer [T.Def' Type] checkDef :: [Def] -> Infer ([Subst], [T.Def' Type])
checkDef [] = return [] checkDef [] = return ([], [])
checkDef (x : xs) = case x of checkDef (x : xs) = case x of
(DBind b) -> do (DBind b) -> do
b' <- checkBind b (sub0, b') <- checkBind b
xs' <- checkDef xs (sub1, xs') <- checkDef xs
return $ T.DBind b' : xs' return (sub1 ++ sub0, T.DBind b' : xs')
(DData d) -> do (DData d) -> do
xs' <- checkDef xs (sub, xs') <- checkDef xs
return $ T.DData (coerceData d) : xs' return (sub, T.DData (coerceData d) : xs')
(DSig _) -> checkDef xs (DSig _) -> checkDef xs
where where
coerceData (Data t injs) = coerceData (Data t injs) =
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
checkBind :: Bind -> Infer (T.Bind' Type) checkBind :: Bind -> Infer ([Subst], T.Bind' Type)
checkBind bind@(Bind name args e) = do checkBind bind@(Bind name args e) = do
setCurrentBind $ coerce name
let lambda = makeLambda e (reverse (coerce args)) let lambda = makeLambda e (reverse (coerce args))
(sub0, (e, lambda_t)) <- inferExp lambda (sub0, (e, lambda_t)) <- inferExp lambda
s <- gets sigs s <- gets sigs
case M.lookup (coerce name) s of case M.lookup (coerce name) s of
Just (Just t') -> do Just t' -> do
sub1 <- bindErr (unify lambda_t (skolemize t')) bind sub1 <- bindErr (unify t' lambda_t) bind
return $ T.Bind (coerce name, apply (sub1 `compose` sub0) t') [] (e, lambda_t) ctrace "SUB0" sub0
_ -> do ctrace "SUB1" sub1
insertSig (coerce name) (Just lambda_t) return ([sub1, sub0], T.Bind (coerce name, t') [] (e, lambda_t))
return (T.Bind (coerce name, lambda_t) [] (e, lambda_t)) _ -> error "First pass through failed to add function to env"
checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m () checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m ()
checkData err@(Data typ injs) = do checkData err@(Data typ injs) = do
@ -174,8 +168,7 @@ returnType a = a
inferExp :: Exp -> Infer (Subst, T.ExpT' Type) inferExp :: Exp -> Infer (Subst, T.ExpT' Type)
inferExp e = do inferExp e = do
(s, (e', t)) <- algoW e (s, (e', t)) <- algoW e
let subbed = apply s t subbed <- apply s t
modify (\st -> st{undecidedSigs = apply s st.undecidedSigs})
return (s, (e', subbed)) return (s, (e', subbed))
class CollectTVars a where class CollectTVars a where
@ -202,16 +195,9 @@ algoW = \case
(sub0, (e', t')) <- exprErr (algoW e) err (sub0, (e', t')) <- exprErr (algoW e) err
sub1 <- unify t t' sub1 <- unify t t'
sub2 <- unify t' t sub2 <- unify t' t
unless
(apply sub1 t == t' && apply sub2 t' == t)
( uncatchableErr $ Aux.do
"Annotated type"
quote $ printTree t
"does not match inferred type"
quote $ printTree t'
)
let comp = sub2 `compose` sub1 `compose` sub0 let comp = sub2 `compose` sub1 `compose` sub0
return (comp, apply comp (e', t)) et <- apply comp (e', t)
return (comp, et)
-- \| ------------------ -- \| ------------------
-- \| Γ ⊢ i : Int, ∅ -- \| Γ ⊢ i : Int, ∅
@ -228,13 +214,8 @@ algoW = \case
return (nullSubst, (T.EVar $ coerce i, x)) return (nullSubst, (T.EVar $ coerce i, x))
Nothing -> do Nothing -> do
sig <- gets sigs sig <- gets sigs
cb <- gets currentBind
case M.lookup (coerce i) sig of case M.lookup (coerce i) sig of
Just (Just t) -> return (nullSubst, (T.EVar $ coerce i, t)) Just t -> return (nullSubst, (T.EVar $ coerce i, t))
Just Nothing -> do
fr <- fresh
modify (\st -> st{toDecide = S.insert cb st.toDecide, undecidedSigs = M.insert (coerce $ concat [[prefix], i, [delim], coerce cb]) fr st.undecidedSigs})
return (nullSubst, (T.EVar $ coerce i, fr))
Nothing -> Nothing ->
uncatchableErr $ uncatchableErr $
"Unbound variable: " "Unbound variable: "
@ -257,9 +238,10 @@ algoW = \case
fr <- fresh fr <- fresh
withBinding (coerce name) fr $ do withBinding (coerce name) fr $ do
(s1, (e', t')) <- exprErr (algoW e) err (s1, (e', t')) <- exprErr (algoW e) err
let varType = apply s1 fr varType <- apply s1 fr
let newArr = TFun varType t' let newArr = TFun varType t'
return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr)) eabs <- apply s1 (T.EAbs (coerce name) (e', t'), newArr)
return (s1, eabs)
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁ -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
-- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int) -- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
@ -302,7 +284,7 @@ algoW = \case
err@(ELet b@(Bind name args e) e1) -> do err@(ELet b@(Bind name args e) e1) -> do
(s1, (_, t0)) <- algoW (makeLambda e (coerce args)) (s1, (_, t0)) <- algoW (makeLambda e (coerce args))
bind' <- exprErr (checkBind b) err (_, bind') <- exprErr (checkBind b) err
env <- asks vars env <- asks vars
let t' = generalize (apply s1 env) t0 let t' = generalize (apply s1 env) t0
withBinding (coerce name) t' $ do withBinding (coerce name) t' $ do
@ -422,15 +404,15 @@ unify t0 t1 =
s1 <- unify a c s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d) s2 <- unify (apply s1 b) (apply s1 d)
return $ s2 `compose` s1 return $ s2 `compose` s1
(TVar (MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t (TVar (MkTVar a), t@(TData _ _)) -> return $ coerce $ M.singleton (coerce a) t
(t@(TData _ _), TVar (MkTVar b)) -> return $ M.singleton (coerce b) t (t@(TData _ _), TVar (MkTVar b)) -> return $ coerce $ M.singleton (coerce b) t
(TVar (MkTVar a), t) -> occurs (coerce a) t (TVar (MkTVar a), t) -> occurs (coerce a) t
(t, TVar (MkTVar b)) -> occurs (coerce b) t (t, TVar (MkTVar b)) -> occurs (coerce b) t
(TAll _ t, b) -> unify t b (TAll _ t, b) -> unify t b
(a, TAll _ t) -> unify a t (a, TAll _ t) -> unify a t
(TLit a, TLit b) -> (TLit a, TLit b) ->
if a == b if a == b
then return M.empty then return nullSubst
else catchableErr $ else catchableErr $
Aux.do Aux.do
"Can not unify" "Can not unify"
@ -452,7 +434,7 @@ unify t0 t1 =
quote $ printTree t' quote $ printTree t'
(TEVar a, TEVar b) -> (TEVar a, TEVar b) ->
if a == b if a == b
then return M.empty then return nullSubst
else catchableErr $ else catchableErr $
Aux.do Aux.do
"Can not unify" "Can not unify"
@ -472,7 +454,7 @@ I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
where these are equal where these are equal
-} -}
occurs :: T.Ident -> Type -> Infer Subst occurs :: T.Ident -> Type -> Infer Subst
occurs i t@(TVar _) = return (M.singleton i t) occurs i t@(TVar _) = return (coerce $ M.singleton i t)
occurs i t = occurs i t =
if S.member i (free t) if S.member i (free t)
then then
@ -483,7 +465,7 @@ occurs i t =
"with" "with"
quote $ printTree t quote $ printTree t
) )
else return $ M.singleton i t else return $ coerce $ M.singleton i t
{- | Generalize a type over all free variables in the substitution set {- | Generalize a type over all free variables in the substitution set
Used for let bindings to allow expression that do not type check in Used for let bindings to allow expression that do not type check in
@ -509,7 +491,7 @@ inst :: Type -> Infer Type
inst = \case inst = \case
TAll (MkTVar bound) t -> do TAll (MkTVar bound) t -> do
fr <- fresh fr <- fresh
let s = M.singleton (coerce bound) fr let s = coerce $ M.singleton (coerce bound) fr
apply s <$> inst t apply s <$> inst t
TFun t1 t2 -> TFun <$> inst t1 <*> inst t2 TFun t1 t2 -> TFun <$> inst t1 <*> inst t2
rest -> return rest rest -> return rest
@ -545,7 +527,8 @@ skolemize t = t
-- | A class for substitutions -- | A class for substitutions
class SubstType t where class SubstType t where
-- | Apply a substitution to t -- | Apply a substitution to t
apply :: Subst -> t -> t -- apply :: MonadError e m => Subst -> t -> m t
apply :: Subst -> t -> Infer t
class FreeVars t where class FreeVars t where
-- | Get all free variables from t -- | Get all free variables from t
@ -565,32 +548,47 @@ instance FreeVars a => FreeVars [a] where
free = let f acc x = acc `S.union` free x in foldl' f S.empty free = let f acc x = acc `S.union` free x in foldl' f S.empty
instance SubstType Type where instance SubstType Type where
apply :: Subst -> Type -> Type apply sub@(Subst s) t = do
apply sub t = do
case t of case t of
TLit a -> TLit a TLit a -> return $ TLit a
TVar (MkTVar a) -> case M.lookup (coerce a) sub of TVar (MkTVar a) -> case M.lookup (coerce a) s of
Nothing -> TVar (MkTVar $ coerce a) Nothing -> return $ TVar (MkTVar $ coerce a)
Just t -> t Just t -> return $ t
TAll (MkTVar i) t -> case M.lookup (coerce i) sub of TAll (MkTVar i) t -> case M.lookup (coerce i) s of
Nothing -> TAll (MkTVar i) (apply sub t) Nothing -> TAll (MkTVar i) <$> apply sub t
Just _ -> apply sub t Just _ -> apply sub t
TFun a b -> TFun (apply sub a) (apply sub b) TFun a b -> TFun <$> apply sub a <*> apply sub b
TData name a -> TData name (apply sub a) TData name a -> TData name <$> apply sub a
TEVar (MkTEVar a) -> case M.lookup (coerce a) sub of TEVar (MkTEVar a) -> case M.lookup (coerce a) s of
Nothing -> TEVar (MkTEVar a) Nothing -> return $ TEVar (MkTEVar a)
Just t -> t Just t -> return $ t
instance FreeVars (Map T.Ident Type) where instance FreeVars (Map T.Ident Type) where
free :: Map T.Ident Type -> Set T.Ident free :: Map T.Ident Type -> Set T.Ident
free = free . M.elems free = free . M.elems
instance SubstType (Map T.Ident Type) where instance SubstType (Map T.Ident Type) where
apply :: Subst -> Map T.Ident Type -> Map T.Ident Type apply s = undefined -- M.map (apply s)
apply = M.map . apply
instance SubstType (Map T.Ident (Maybe Type)) where instance SubstType Subst where
apply s = M.map (fmap $ apply s) apply s@(Subst m1) (Subst m2) = do
let both = M.keys $ M.intersection m1 m2
case both of
[] -> Subst <$> apply s m2
xs -> do
sub0 <- apply s m2
sub1 <- loop xs m1 m2
apply sub1 (Subst sub0)
where
loop [] _ _ = return nullSubst
loop (x : xs) m1 m2 = do
let k1 = m1 M.! x
let k2 = m2 M.! x
sub <- unify k1 k2
subs <- loop xs m1 m2
return $ sub `compose` subs
-- Subst $ M.map (apply s) m2
instance SubstType (T.ExpT' Type) where instance SubstType (T.ExpT' Type) where
apply s (e, t) = (apply s e, apply s t) apply s (e, t) = (apply s e, apply s t)
@ -636,16 +634,19 @@ instance SubstType (T.Id' Type) where
-- | Represents the empty substition set -- | Represents the empty substition set
nullSubst :: Subst nullSubst :: Subst
nullSubst = mempty nullSubst = Subst mempty
-- | Compose two substitution sets -- | Compose two substitution sets
compose :: Subst -> Subst -> Subst compose :: Subst -> Subst -> Subst
compose m1 m2 = M.map (apply m1) m2 `M.union` m1 compose m1 m2 = Subst $ M.map (apply $ coerce m1) (coerce m2) `M.union` coerce m1
-- | Compose a list of substitution sets into one -- | Compose a list of substitution sets into one
composeAll :: [Subst] -> Subst composeAll :: [Subst] -> Subst
composeAll = foldl' compose nullSubst composeAll = foldl' compose nullSubst
unionSubsts :: [Subst] -> Subst
unionSubsts = Subst . foldl' M.union M.empty . map coerce
{- | Convert a function with arguments to its pointfree version {- | Convert a function with arguments to its pointfree version
> makeLambda (add x y = x + y) = add = \x. \y. x + y > makeLambda (add x y = x + y) = add = \x. \y. x + y
-} -}
@ -671,7 +672,7 @@ withPattern p ma = case p of
T.PEnum _ -> ma T.PEnum _ -> ma
-- | Insert a function signature into the environment -- | Insert a function signature into the environment
insertSig :: T.Ident -> Maybe Type -> Infer () insertSig :: T.Ident -> Type -> Infer ()
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)}) insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
insertBind :: T.Ident -> Infer () insertBind :: T.Ident -> Infer ()
@ -691,24 +692,6 @@ with an equivalent name has been declared already
existInj :: (Monad m, MonadState Env m) => T.Ident -> m (Maybe Type) existInj :: (Monad m, MonadState Env m) => T.Ident -> m (Maybe Type)
existInj n = gets (M.lookup n . injections) existInj n = gets (M.lookup n . injections)
setCurrentBind :: T.Ident -> Infer ()
setCurrentBind i = modify (\st -> st{currentBind = i})
solveUndecidable :: Infer Subst
solveUndecidable = do
sigs <- gets sigs
undecided <- gets undecidedSigs
ys <-
maybeToRightM
(Error "SIGNATURE MISSING" False)
( mapM (tupSequence . first (join . flip M.lookup sigs . getOriginal)) $
M.toList undecided
)
composeAll <$> mapM (uncurry unify) ys
getOriginal :: T.Ident -> T.Ident
getOriginal (T.Ident i) = coerce $ takeWhile (/= delim) $ drop 1 i
delim :: Char delim :: Char
delim = '_' delim = '_'
prefix :: Char prefix :: Char
@ -785,7 +768,7 @@ dataErr ma d =
) )
initCtx = Ctx mempty initCtx = Ctx mempty
initEnv = Env 0 'a' mempty mempty mempty "" mempty mempty mempty initEnv = Env 0 'a' mempty mempty mempty mempty
run :: Infer a -> Either Error a run :: Infer a -> Either Error a
run = run' initEnv initCtx run = run' initEnv initCtx
@ -804,19 +787,20 @@ newtype Ctx = Ctx {vars :: Map T.Ident Type}
data Env = Env data Env = Env
{ count :: Int { count :: Int
, nextChar :: Char , nextChar :: Char
, sigs :: Map T.Ident (Maybe Type) , sigs :: Map T.Ident Type
, takenTypeVars :: Set T.Ident , takenTypeVars :: Set T.Ident
, injections :: Map T.Ident Type , injections :: Map T.Ident Type
, currentBind :: T.Ident
, undecidedSigs :: Map T.Ident Type
, toDecide :: Set T.Ident
, declaredBinds :: Set T.Ident , declaredBinds :: Set T.Ident
} }
deriving (Show) deriving (Show)
data Error = Error {msg :: String, catchable :: Bool} data Error = Error {msg :: String, catchable :: Bool}
deriving (Show) deriving (Show)
type Subst = Map T.Ident Type
newtype Subst = Subst (Map T.Ident Type)
instance Show Subst where
show (Subst s) = "[" ++ let xs = (map (\(a, b) -> printTree a ++ " = " ++ printTree b) $ M.toList s) in intercalate " | " xs ++ "]"
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a} newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a}
deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env) deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env)