Merge in mutual recursion handling

This commit is contained in:
sebastianselander 2023-03-31 18:26:58 +02:00
parent e2e469d84e
commit c4f78ca37d

View file

@ -6,16 +6,15 @@
-- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeCheckerHm where
import Auxiliary (int, litType, maybeToRightM, tupSequence, unzip4)
import Auxiliary (int, litType, maybeToRightM, unzip4)
import Auxiliary qualified as Aux
import Control.Monad.Except
import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (first)
import Data.Coerce (coerce)
import Data.Function (on)
import Data.List (foldl')
import Data.List (foldl', intercalate)
import Data.List.Extra (unsnoc)
import Data.Map (Map)
import Data.Map qualified as M
@ -40,19 +39,10 @@ typecheck = onLeft msg . run . checkPrg
checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = do
preRun bs
bs <- checkDef bs
sub0 <- solveUndecidable
bs <- mapM (mono sub0) bs
(subs, bs) <- checkDef bs
ctrace "SUBS" $ unionSubsts subs
return $ T.Program bs
mono :: Subst -> T.Def' Type -> Infer (T.Def' Type)
mono s bind@(T.DBind (T.Bind (name, t) args e)) = do
b <- gets (S.member name . toDecide)
if b
then return $ T.DBind $ T.Bind (name, apply s t) (apply s args) (apply s e)
else return bind
mono _ (T.DData d) = return $ T.DData d
preRun :: [Def] -> Infer ()
preRun [] = return ()
preRun (x : xs) = case x of
@ -62,7 +52,8 @@ preRun (x : xs) = case x of
duplicateDecl n s $ Aux.do
"Multiple signatures of function"
quote $ printTree n
insertSig (coerce n) (Just t) >> preRun xs
insertSig (coerce n) t
preRun xs
DBind (Bind n _ e) -> do
s <- gets (S.toList . declaredBinds)
duplicateDecl n s $ Aux.do
@ -70,43 +61,46 @@ preRun (x : xs) = case x of
quote $ printTree n
collect (collectTVars e)
insertBind $ coerce n
s <- gets sigs
case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs
sigs <- gets sigs
case M.lookup (coerce n) sigs of
Nothing -> do
fr <- fresh
insertSig (coerce n) fr
preRun xs
Just _ -> preRun xs
DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs
where
-- Check if function body / signature has been declared already
duplicateDecl :: (Monad m, MonadError Error m) => LIdent -> [T.Ident] -> String -> m ()
duplicateDecl n env msg = when (coerce n `elem` env) (uncatchableErr msg)
checkDef :: [Def] -> Infer [T.Def' Type]
checkDef [] = return []
checkDef :: [Def] -> Infer ([Subst], [T.Def' Type])
checkDef [] = return ([], [])
checkDef (x : xs) = case x of
(DBind b) -> do
b' <- checkBind b
xs' <- checkDef xs
return $ T.DBind b' : xs'
(sub0, b') <- checkBind b
(sub1, xs') <- checkDef xs
return (sub1 ++ sub0, T.DBind b' : xs')
(DData d) -> do
xs' <- checkDef xs
return $ T.DData (coerceData d) : xs'
(sub, xs') <- checkDef xs
return (sub, T.DData (coerceData d) : xs')
(DSig _) -> checkDef xs
where
coerceData (Data t injs) =
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
checkBind :: Bind -> Infer (T.Bind' Type)
checkBind :: Bind -> Infer ([Subst], T.Bind' Type)
checkBind bind@(Bind name args e) = do
setCurrentBind $ coerce name
let lambda = makeLambda e (reverse (coerce args))
(sub0, (e, lambda_t)) <- inferExp lambda
s <- gets sigs
case M.lookup (coerce name) s of
Just (Just t') -> do
sub1 <- bindErr (unify lambda_t (skolemize t')) bind
return $ T.Bind (coerce name, apply (sub1 `compose` sub0) t') [] (e, lambda_t)
_ -> do
insertSig (coerce name) (Just lambda_t)
return (T.Bind (coerce name, lambda_t) [] (e, lambda_t))
Just t' -> do
sub1 <- bindErr (unify t' lambda_t) bind
ctrace "SUB0" sub0
ctrace "SUB1" sub1
return ([sub1, sub0], T.Bind (coerce name, t') [] (e, lambda_t))
_ -> error "First pass through failed to add function to env"
checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m ()
checkData err@(Data typ injs) = do
@ -174,8 +168,7 @@ returnType a = a
inferExp :: Exp -> Infer (Subst, T.ExpT' Type)
inferExp e = do
(s, (e', t)) <- algoW e
let subbed = apply s t
modify (\st -> st{undecidedSigs = apply s st.undecidedSigs})
subbed <- apply s t
return (s, (e', subbed))
class CollectTVars a where
@ -202,16 +195,9 @@ algoW = \case
(sub0, (e', t')) <- exprErr (algoW e) err
sub1 <- unify t t'
sub2 <- unify t' t
unless
(apply sub1 t == t' && apply sub2 t' == t)
( uncatchableErr $ Aux.do
"Annotated type"
quote $ printTree t
"does not match inferred type"
quote $ printTree t'
)
let comp = sub2 `compose` sub1 `compose` sub0
return (comp, apply comp (e', t))
et <- apply comp (e', t)
return (comp, et)
-- \| ------------------
-- \| Γ ⊢ i : Int, ∅
@ -228,13 +214,8 @@ algoW = \case
return (nullSubst, (T.EVar $ coerce i, x))
Nothing -> do
sig <- gets sigs
cb <- gets currentBind
case M.lookup (coerce i) sig of
Just (Just t) -> return (nullSubst, (T.EVar $ coerce i, t))
Just Nothing -> do
fr <- fresh
modify (\st -> st{toDecide = S.insert cb st.toDecide, undecidedSigs = M.insert (coerce $ concat [[prefix], i, [delim], coerce cb]) fr st.undecidedSigs})
return (nullSubst, (T.EVar $ coerce i, fr))
Just t -> return (nullSubst, (T.EVar $ coerce i, t))
Nothing ->
uncatchableErr $
"Unbound variable: "
@ -257,9 +238,10 @@ algoW = \case
fr <- fresh
withBinding (coerce name) fr $ do
(s1, (e', t')) <- exprErr (algoW e) err
let varType = apply s1 fr
varType <- apply s1 fr
let newArr = TFun varType t'
return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr))
eabs <- apply s1 (T.EAbs (coerce name) (e', t'), newArr)
return (s1, eabs)
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
-- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
@ -302,7 +284,7 @@ algoW = \case
err@(ELet b@(Bind name args e) e1) -> do
(s1, (_, t0)) <- algoW (makeLambda e (coerce args))
bind' <- exprErr (checkBind b) err
(_, bind') <- exprErr (checkBind b) err
env <- asks vars
let t' = generalize (apply s1 env) t0
withBinding (coerce name) t' $ do
@ -422,15 +404,15 @@ unify t0 t1 =
s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d)
return $ s2 `compose` s1
(TVar (MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t
(t@(TData _ _), TVar (MkTVar b)) -> return $ M.singleton (coerce b) t
(TVar (MkTVar a), t@(TData _ _)) -> return $ coerce $ M.singleton (coerce a) t
(t@(TData _ _), TVar (MkTVar b)) -> return $ coerce $ M.singleton (coerce b) t
(TVar (MkTVar a), t) -> occurs (coerce a) t
(t, TVar (MkTVar b)) -> occurs (coerce b) t
(TAll _ t, b) -> unify t b
(a, TAll _ t) -> unify a t
(TLit a, TLit b) ->
if a == b
then return M.empty
then return nullSubst
else catchableErr $
Aux.do
"Can not unify"
@ -452,7 +434,7 @@ unify t0 t1 =
quote $ printTree t'
(TEVar a, TEVar b) ->
if a == b
then return M.empty
then return nullSubst
else catchableErr $
Aux.do
"Can not unify"
@ -472,7 +454,7 @@ I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
where these are equal
-}
occurs :: T.Ident -> Type -> Infer Subst
occurs i t@(TVar _) = return (M.singleton i t)
occurs i t@(TVar _) = return (coerce $ M.singleton i t)
occurs i t =
if S.member i (free t)
then
@ -483,7 +465,7 @@ occurs i t =
"with"
quote $ printTree t
)
else return $ M.singleton i t
else return $ coerce $ M.singleton i t
{- | Generalize a type over all free variables in the substitution set
Used for let bindings to allow expressions that do not type check in
@ -509,7 +491,7 @@ inst :: Type -> Infer Type
inst = \case
TAll (MkTVar bound) t -> do
fr <- fresh
let s = M.singleton (coerce bound) fr
let s = coerce $ M.singleton (coerce bound) fr
apply s <$> inst t
TFun t1 t2 -> TFun <$> inst t1 <*> inst t2
rest -> return rest
@ -545,7 +527,8 @@ skolemize t = t
-- | A class for substitutions
class SubstType t where
-- | Apply a substitution to t
apply :: Subst -> t -> t
-- apply :: MonadError e m => Subst -> t -> m t
apply :: Subst -> t -> Infer t
class FreeVars t where
-- | Get all free variables from t
@ -565,32 +548,47 @@ instance FreeVars a => FreeVars [a] where
free = let f acc x = acc `S.union` free x in foldl' f S.empty
instance SubstType Type where
apply :: Subst -> Type -> Type
apply sub t = do
apply sub@(Subst s) t = do
case t of
TLit a -> TLit a
TVar (MkTVar a) -> case M.lookup (coerce a) sub of
Nothing -> TVar (MkTVar $ coerce a)
Just t -> t
TAll (MkTVar i) t -> case M.lookup (coerce i) sub of
Nothing -> TAll (MkTVar i) (apply sub t)
TLit a -> return $ TLit a
TVar (MkTVar a) -> case M.lookup (coerce a) s of
Nothing -> return $ TVar (MkTVar $ coerce a)
Just t -> return $ t
TAll (MkTVar i) t -> case M.lookup (coerce i) s of
Nothing -> TAll (MkTVar i) <$> apply sub t
Just _ -> apply sub t
TFun a b -> TFun (apply sub a) (apply sub b)
TData name a -> TData name (apply sub a)
TEVar (MkTEVar a) -> case M.lookup (coerce a) sub of
Nothing -> TEVar (MkTEVar a)
Just t -> t
TFun a b -> TFun <$> apply sub a <*> apply sub b
TData name a -> TData name <$> apply sub a
TEVar (MkTEVar a) -> case M.lookup (coerce a) s of
Nothing -> return $ TEVar (MkTEVar a)
Just t -> return $ t
-- | The free variables of a map of types are the free variables of its
-- values; the keys do not contribute.
instance FreeVars (Map T.Ident Type) where
    free :: Map T.Ident Type -> Set T.Ident
    free typeMap = free (M.elems typeMap)
instance SubstType (Map T.Ident Type) where
apply :: Subst -> Map T.Ident Type -> Map T.Ident Type
apply = M.map . apply
apply s = undefined -- M.map (apply s)
instance SubstType (Map T.Ident (Maybe Type)) where
apply s = M.map (fmap $ apply s)
-- | Apply one substitution to another.
--
-- The substitution @s@ (wrapping @m1@) is applied to the map @m2@ of the
-- target substitution. When the two substitutions bind some of the same
-- keys, the types bound to each shared key are unified and the resulting
-- substitution is applied on top of the already substituted target.
instance SubstType Subst where
apply s@(Subst m1) (Subst m2) = do
-- Keys that are bound by both substitutions.
let both = M.keys $ M.intersection m1 m2
case both of
-- No shared keys: substitute inside the target map and re-wrap it.
[] -> Subst <$> apply s m2
xs -> do
-- Substitute inside the target map first ...
sub0 <- apply s m2
-- ... then unify the two bindings of every shared key ...
sub1 <- loop xs m1 m2
-- ... and apply the unifier to the substituted target. This is a
-- recursive call into this same instance.
apply sub1 (Subst sub0)
where
-- Unify, key by key, the types that m1 and m2 bind to each shared key
-- and compose the resulting substitutions.
loop [] _ _ = return nullSubst
loop (x : xs) m1 m2 = do
-- The lookups with M.! cannot fail here: every x comes from the keys
-- of the intersection of m1 and m2. Note that k1 and k2 are the bound
-- types (map values), not keys.
let k1 = m1 M.! x
let k2 = m2 M.! x
sub <- unify k1 k2
subs <- loop xs m1 m2
return $ sub `compose` subs
-- Subst $ M.map (apply s) m2
-- | Apply a substitution to a typed expression, i.e. to both the expression
-- and the type it is annotated with.
instance SubstType (T.ExpT' Type) where
    -- 'apply' returns its result in 'Infer', so the two components are
    -- substituted as separate actions and paired applicatively; building the
    -- tuple directly would yield a pair of actions instead of an action
    -- producing a pair.
    apply s (e, t) = (,) <$> apply s e <*> apply s t
@ -636,16 +634,19 @@ instance SubstType (T.Id' Type) where
-- | Represents the empty substitution set
nullSubst :: Subst
nullSubst = mempty
nullSubst = Subst mempty
-- | Compose two substitution sets
compose :: Subst -> Subst -> Subst
compose m1 m2 = M.map (apply m1) m2 `M.union` m1
compose m1 m2 = Subst $ M.map (apply $ coerce m1) (coerce m2) `M.union` coerce m1
-- | Collapse a list of substitutions into a single one by folding 'compose'
-- over the list from the left, starting from 'nullSubst'. The accumulated
-- substitution is always the left operand of 'compose'.
composeAll :: [Subst] -> Subst
composeAll substs = foldl' (\acc sub -> acc `compose` sub) nullSubst substs
-- | Merge a list of substitutions into one by taking the left-biased union
-- of their underlying maps: when a key is bound by several substitutions,
-- the binding from the earliest one in the list is kept.
unionSubsts :: [Subst] -> Subst
unionSubsts substs = Subst (M.unions (map coerce substs))
{- | Convert a function with arguments to its pointfree version
> makeLambda (add x y = x + y) = add = \x. \y. x + y
-}
@ -671,7 +672,7 @@ withPattern p ma = case p of
T.PEnum _ -> ma
-- | Insert a function signature into the environment
insertSig :: T.Ident -> Maybe Type -> Infer ()
insertSig :: T.Ident -> Type -> Infer ()
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
insertBind :: T.Ident -> Infer ()
@ -691,24 +692,6 @@ with an equivalent name has been declared already
-- Looks the given name up in the 'injections' map of the environment and
-- returns the type stored for it, if any.
existInj :: (Monad m, MonadState Env m) => T.Ident -> m (Maybe Type)
existInj n = M.lookup n <$> gets injections
setCurrentBind :: T.Ident -> Infer ()
setCurrentBind i = modify (\st -> st{currentBind = i})
solveUndecidable :: Infer Subst
solveUndecidable = do
sigs <- gets sigs
undecided <- gets undecidedSigs
ys <-
maybeToRightM
(Error "SIGNATURE MISSING" False)
( mapM (tupSequence . first (join . flip M.lookup sigs . getOriginal)) $
M.toList undecided
)
composeAll <$> mapM (uncurry unify) ys
getOriginal :: T.Ident -> T.Ident
getOriginal (T.Ident i) = coerce $ takeWhile (/= delim) $ drop 1 i
delim :: Char
delim = '_'
prefix :: Char
@ -785,7 +768,7 @@ dataErr ma d =
)
initCtx = Ctx mempty
initEnv = Env 0 'a' mempty mempty mempty "" mempty mempty mempty
initEnv = Env 0 'a' mempty mempty mempty mempty
-- | Run an inference action starting from the initial environment 'initEnv'
-- and the initial context 'initCtx', yielding either an 'Error' or the
-- action's result.
run :: Infer a -> Either Error a
run action = run' initEnv initCtx action
@ -804,19 +787,20 @@ newtype Ctx = Ctx {vars :: Map T.Ident Type}
data Env = Env
{ count :: Int
, nextChar :: Char
, sigs :: Map T.Ident (Maybe Type)
, sigs :: Map T.Ident Type
, takenTypeVars :: Set T.Ident
, injections :: Map T.Ident Type
, currentBind :: T.Ident
, undecidedSigs :: Map T.Ident Type
, toDecide :: Set T.Ident
, declaredBinds :: Set T.Ident
}
deriving (Show)
-- | An error produced by the type checker.
--
-- 'msg' holds the error message text and 'catchable' is a flag stored with
-- every error.
-- NOTE(review): presumably 'catchable' marks errors that an enclosing
-- handler may recover from (cf. catchableErr / uncatchableErr) - confirm
-- against the handlers.
data Error = Error {msg :: String, catchable :: Bool}
deriving (Show)
type Subst = Map T.Ident Type
newtype Subst = Subst (Map T.Ident Type)
-- | Render a substitution as @[k1 = t1 | k2 = t2 | ...]@. Keys and types are
-- printed with 'printTree', in the order 'M.toList' returns the bindings.
instance Show Subst where
    show (Subst s) = "[" ++ intercalate " | " bindings ++ "]"
      where
        bindings = [printTree k ++ " = " ++ printTree t | (k, t) <- M.toList s]
-- | The monad in which type inference runs: mutable state of type 'Env',
-- a read-only context of type 'Ctx' and errors of type 'Error', layered on
-- top of 'Identity'. The listed instances are derived from the wrapped
-- transformer stack.
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a}
deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env)