Merge in mutual recursion handling
This commit is contained in:
parent
e2e469d84e
commit
c4f78ca37d
1 changed files with 89 additions and 105 deletions
|
|
@ -6,16 +6,15 @@
|
|||
-- | A module for type checking and inference using algorithm W, Hindley-Milner
|
||||
module TypeChecker.TypeCheckerHm where
|
||||
|
||||
import Auxiliary (int, litType, maybeToRightM, tupSequence, unzip4)
|
||||
import Auxiliary (int, litType, maybeToRightM, unzip4)
|
||||
import Auxiliary qualified as Aux
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Identity (Identity, runIdentity)
|
||||
import Control.Monad.Reader
|
||||
import Control.Monad.State
|
||||
import Data.Bifunctor (first)
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Function (on)
|
||||
import Data.List (foldl')
|
||||
import Data.List (foldl', intercalate)
|
||||
import Data.List.Extra (unsnoc)
|
||||
import Data.Map (Map)
|
||||
import Data.Map qualified as M
|
||||
|
|
@ -40,19 +39,10 @@ typecheck = onLeft msg . run . checkPrg
|
|||
checkPrg :: Program -> Infer (T.Program' Type)
|
||||
checkPrg (Program bs) = do
|
||||
preRun bs
|
||||
bs <- checkDef bs
|
||||
sub0 <- solveUndecidable
|
||||
bs <- mapM (mono sub0) bs
|
||||
(subs, bs) <- checkDef bs
|
||||
ctrace "SUBS" $ unionSubsts subs
|
||||
return $ T.Program bs
|
||||
|
||||
mono :: Subst -> T.Def' Type -> Infer (T.Def' Type)
|
||||
mono s bind@(T.DBind (T.Bind (name, t) args e)) = do
|
||||
b <- gets (S.member name . toDecide)
|
||||
if b
|
||||
then return $ T.DBind $ T.Bind (name, apply s t) (apply s args) (apply s e)
|
||||
else return bind
|
||||
mono _ (T.DData d) = return $ T.DData d
|
||||
|
||||
preRun :: [Def] -> Infer ()
|
||||
preRun [] = return ()
|
||||
preRun (x : xs) = case x of
|
||||
|
|
@ -62,7 +52,8 @@ preRun (x : xs) = case x of
|
|||
duplicateDecl n s $ Aux.do
|
||||
"Multiple signatures of function"
|
||||
quote $ printTree n
|
||||
insertSig (coerce n) (Just t) >> preRun xs
|
||||
insertSig (coerce n) t
|
||||
preRun xs
|
||||
DBind (Bind n _ e) -> do
|
||||
s <- gets (S.toList . declaredBinds)
|
||||
duplicateDecl n s $ Aux.do
|
||||
|
|
@ -70,43 +61,46 @@ preRun (x : xs) = case x of
|
|||
quote $ printTree n
|
||||
collect (collectTVars e)
|
||||
insertBind $ coerce n
|
||||
s <- gets sigs
|
||||
case M.lookup (coerce n) s of
|
||||
Nothing -> insertSig (coerce n) Nothing >> preRun xs
|
||||
sigs <- gets sigs
|
||||
case M.lookup (coerce n) sigs of
|
||||
Nothing -> do
|
||||
fr <- fresh
|
||||
insertSig (coerce n) fr
|
||||
preRun xs
|
||||
Just _ -> preRun xs
|
||||
DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs
|
||||
where
|
||||
-- Check if function body / signature has been declared already
|
||||
duplicateDecl :: (Monad m, MonadError Error m) => LIdent -> [T.Ident] -> String -> m ()
|
||||
duplicateDecl n env msg = when (coerce n `elem` env) (uncatchableErr msg)
|
||||
|
||||
checkDef :: [Def] -> Infer [T.Def' Type]
|
||||
checkDef [] = return []
|
||||
checkDef :: [Def] -> Infer ([Subst], [T.Def' Type])
|
||||
checkDef [] = return ([], [])
|
||||
checkDef (x : xs) = case x of
|
||||
(DBind b) -> do
|
||||
b' <- checkBind b
|
||||
xs' <- checkDef xs
|
||||
return $ T.DBind b' : xs'
|
||||
(sub0, b') <- checkBind b
|
||||
(sub1, xs') <- checkDef xs
|
||||
return (sub1 ++ sub0, T.DBind b' : xs')
|
||||
(DData d) -> do
|
||||
xs' <- checkDef xs
|
||||
return $ T.DData (coerceData d) : xs'
|
||||
(sub, xs') <- checkDef xs
|
||||
return (sub, T.DData (coerceData d) : xs')
|
||||
(DSig _) -> checkDef xs
|
||||
where
|
||||
coerceData (Data t injs) =
|
||||
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
|
||||
|
||||
checkBind :: Bind -> Infer (T.Bind' Type)
|
||||
checkBind :: Bind -> Infer ([Subst], T.Bind' Type)
|
||||
checkBind bind@(Bind name args e) = do
|
||||
setCurrentBind $ coerce name
|
||||
let lambda = makeLambda e (reverse (coerce args))
|
||||
(sub0, (e, lambda_t)) <- inferExp lambda
|
||||
s <- gets sigs
|
||||
case M.lookup (coerce name) s of
|
||||
Just (Just t') -> do
|
||||
sub1 <- bindErr (unify lambda_t (skolemize t')) bind
|
||||
return $ T.Bind (coerce name, apply (sub1 `compose` sub0) t') [] (e, lambda_t)
|
||||
_ -> do
|
||||
insertSig (coerce name) (Just lambda_t)
|
||||
return (T.Bind (coerce name, lambda_t) [] (e, lambda_t))
|
||||
Just t' -> do
|
||||
sub1 <- bindErr (unify t' lambda_t) bind
|
||||
ctrace "SUB0" sub0
|
||||
ctrace "SUB1" sub1
|
||||
return ([sub1, sub0], T.Bind (coerce name, t') [] (e, lambda_t))
|
||||
_ -> error "First pass through failed to add function to env"
|
||||
|
||||
checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m ()
|
||||
checkData err@(Data typ injs) = do
|
||||
|
|
@ -174,8 +168,7 @@ returnType a = a
|
|||
inferExp :: Exp -> Infer (Subst, T.ExpT' Type)
|
||||
inferExp e = do
|
||||
(s, (e', t)) <- algoW e
|
||||
let subbed = apply s t
|
||||
modify (\st -> st{undecidedSigs = apply s st.undecidedSigs})
|
||||
subbed <- apply s t
|
||||
return (s, (e', subbed))
|
||||
|
||||
class CollectTVars a where
|
||||
|
|
@ -202,16 +195,9 @@ algoW = \case
|
|||
(sub0, (e', t')) <- exprErr (algoW e) err
|
||||
sub1 <- unify t t'
|
||||
sub2 <- unify t' t
|
||||
unless
|
||||
(apply sub1 t == t' && apply sub2 t' == t)
|
||||
( uncatchableErr $ Aux.do
|
||||
"Annotated type"
|
||||
quote $ printTree t
|
||||
"does not match inferred type"
|
||||
quote $ printTree t'
|
||||
)
|
||||
let comp = sub2 `compose` sub1 `compose` sub0
|
||||
return (comp, apply comp (e', t))
|
||||
et <- apply comp (e', t)
|
||||
return (comp, et)
|
||||
|
||||
-- \| ------------------
|
||||
-- \| Γ ⊢ i : Int, ∅
|
||||
|
|
@ -228,13 +214,8 @@ algoW = \case
|
|||
return (nullSubst, (T.EVar $ coerce i, x))
|
||||
Nothing -> do
|
||||
sig <- gets sigs
|
||||
cb <- gets currentBind
|
||||
case M.lookup (coerce i) sig of
|
||||
Just (Just t) -> return (nullSubst, (T.EVar $ coerce i, t))
|
||||
Just Nothing -> do
|
||||
fr <- fresh
|
||||
modify (\st -> st{toDecide = S.insert cb st.toDecide, undecidedSigs = M.insert (coerce $ concat [[prefix], i, [delim], coerce cb]) fr st.undecidedSigs})
|
||||
return (nullSubst, (T.EVar $ coerce i, fr))
|
||||
Just t -> return (nullSubst, (T.EVar $ coerce i, t))
|
||||
Nothing ->
|
||||
uncatchableErr $
|
||||
"Unbound variable: "
|
||||
|
|
@ -257,9 +238,10 @@ algoW = \case
|
|||
fr <- fresh
|
||||
withBinding (coerce name) fr $ do
|
||||
(s1, (e', t')) <- exprErr (algoW e) err
|
||||
let varType = apply s1 fr
|
||||
varType <- apply s1 fr
|
||||
let newArr = TFun varType t'
|
||||
return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr))
|
||||
eabs <- apply s1 (T.EAbs (coerce name) (e', t'), newArr)
|
||||
return (s1, eabs)
|
||||
|
||||
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
|
||||
-- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
|
||||
|
|
@ -302,7 +284,7 @@ algoW = \case
|
|||
|
||||
err@(ELet b@(Bind name args e) e1) -> do
|
||||
(s1, (_, t0)) <- algoW (makeLambda e (coerce args))
|
||||
bind' <- exprErr (checkBind b) err
|
||||
(_, bind') <- exprErr (checkBind b) err
|
||||
env <- asks vars
|
||||
let t' = generalize (apply s1 env) t0
|
||||
withBinding (coerce name) t' $ do
|
||||
|
|
@ -422,15 +404,15 @@ unify t0 t1 =
|
|||
s1 <- unify a c
|
||||
s2 <- unify (apply s1 b) (apply s1 d)
|
||||
return $ s2 `compose` s1
|
||||
(TVar (MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t
|
||||
(t@(TData _ _), TVar (MkTVar b)) -> return $ M.singleton (coerce b) t
|
||||
(TVar (MkTVar a), t@(TData _ _)) -> return $ coerce $ M.singleton (coerce a) t
|
||||
(t@(TData _ _), TVar (MkTVar b)) -> return $ coerce $ M.singleton (coerce b) t
|
||||
(TVar (MkTVar a), t) -> occurs (coerce a) t
|
||||
(t, TVar (MkTVar b)) -> occurs (coerce b) t
|
||||
(TAll _ t, b) -> unify t b
|
||||
(a, TAll _ t) -> unify a t
|
||||
(TLit a, TLit b) ->
|
||||
if a == b
|
||||
then return M.empty
|
||||
then return nullSubst
|
||||
else catchableErr $
|
||||
Aux.do
|
||||
"Can not unify"
|
||||
|
|
@ -452,7 +434,7 @@ unify t0 t1 =
|
|||
quote $ printTree t'
|
||||
(TEVar a, TEVar b) ->
|
||||
if a == b
|
||||
then return M.empty
|
||||
then return nullSubst
|
||||
else catchableErr $
|
||||
Aux.do
|
||||
"Can not unify"
|
||||
|
|
@ -472,7 +454,7 @@ I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
|
|||
where these are equal
|
||||
-}
|
||||
occurs :: T.Ident -> Type -> Infer Subst
|
||||
occurs i t@(TVar _) = return (M.singleton i t)
|
||||
occurs i t@(TVar _) = return (coerce $ M.singleton i t)
|
||||
occurs i t =
|
||||
if S.member i (free t)
|
||||
then
|
||||
|
|
@ -483,7 +465,7 @@ occurs i t =
|
|||
"with"
|
||||
quote $ printTree t
|
||||
)
|
||||
else return $ M.singleton i t
|
||||
else return $ coerce $ M.singleton i t
|
||||
|
||||
{- | Generalize a type over all free variables in the substitution set
|
||||
Used for let bindings to allow expression that do not type check in
|
||||
|
|
@ -509,7 +491,7 @@ inst :: Type -> Infer Type
|
|||
inst = \case
|
||||
TAll (MkTVar bound) t -> do
|
||||
fr <- fresh
|
||||
let s = M.singleton (coerce bound) fr
|
||||
let s = coerce $ M.singleton (coerce bound) fr
|
||||
apply s <$> inst t
|
||||
TFun t1 t2 -> TFun <$> inst t1 <*> inst t2
|
||||
rest -> return rest
|
||||
|
|
@ -545,7 +527,8 @@ skolemize t = t
|
|||
-- | A class for substitutions
|
||||
class SubstType t where
|
||||
-- | Apply a substitution to t
|
||||
apply :: Subst -> t -> t
|
||||
-- apply :: MonadError e m => Subst -> t -> m t
|
||||
apply :: Subst -> t -> Infer t
|
||||
|
||||
class FreeVars t where
|
||||
-- | Get all free variables from t
|
||||
|
|
@ -565,32 +548,47 @@ instance FreeVars a => FreeVars [a] where
|
|||
free = let f acc x = acc `S.union` free x in foldl' f S.empty
|
||||
|
||||
instance SubstType Type where
|
||||
apply :: Subst -> Type -> Type
|
||||
apply sub t = do
|
||||
apply sub@(Subst s) t = do
|
||||
case t of
|
||||
TLit a -> TLit a
|
||||
TVar (MkTVar a) -> case M.lookup (coerce a) sub of
|
||||
Nothing -> TVar (MkTVar $ coerce a)
|
||||
Just t -> t
|
||||
TAll (MkTVar i) t -> case M.lookup (coerce i) sub of
|
||||
Nothing -> TAll (MkTVar i) (apply sub t)
|
||||
TLit a -> return $ TLit a
|
||||
TVar (MkTVar a) -> case M.lookup (coerce a) s of
|
||||
Nothing -> return $ TVar (MkTVar $ coerce a)
|
||||
Just t -> return $ t
|
||||
TAll (MkTVar i) t -> case M.lookup (coerce i) s of
|
||||
Nothing -> TAll (MkTVar i) <$> apply sub t
|
||||
Just _ -> apply sub t
|
||||
TFun a b -> TFun (apply sub a) (apply sub b)
|
||||
TData name a -> TData name (apply sub a)
|
||||
TEVar (MkTEVar a) -> case M.lookup (coerce a) sub of
|
||||
Nothing -> TEVar (MkTEVar a)
|
||||
Just t -> t
|
||||
TFun a b -> TFun <$> apply sub a <*> apply sub b
|
||||
TData name a -> TData name <$> apply sub a
|
||||
TEVar (MkTEVar a) -> case M.lookup (coerce a) s of
|
||||
Nothing -> return $ TEVar (MkTEVar a)
|
||||
Just t -> return $ t
|
||||
|
||||
instance FreeVars (Map T.Ident Type) where
|
||||
free :: Map T.Ident Type -> Set T.Ident
|
||||
free = free . M.elems
|
||||
|
||||
instance SubstType (Map T.Ident Type) where
|
||||
apply :: Subst -> Map T.Ident Type -> Map T.Ident Type
|
||||
apply = M.map . apply
|
||||
apply s = undefined -- M.map (apply s)
|
||||
|
||||
instance SubstType (Map T.Ident (Maybe Type)) where
|
||||
apply s = M.map (fmap $ apply s)
|
||||
instance SubstType Subst where
|
||||
apply s@(Subst m1) (Subst m2) = do
|
||||
let both = M.keys $ M.intersection m1 m2
|
||||
case both of
|
||||
[] -> Subst <$> apply s m2
|
||||
xs -> do
|
||||
sub0 <- apply s m2
|
||||
sub1 <- loop xs m1 m2
|
||||
apply sub1 (Subst sub0)
|
||||
where
|
||||
loop [] _ _ = return nullSubst
|
||||
loop (x : xs) m1 m2 = do
|
||||
let k1 = m1 M.! x
|
||||
let k2 = m2 M.! x
|
||||
sub <- unify k1 k2
|
||||
subs <- loop xs m1 m2
|
||||
return $ sub `compose` subs
|
||||
|
||||
-- Subst $ M.map (apply s) m2
|
||||
|
||||
instance SubstType (T.ExpT' Type) where
|
||||
apply s (e, t) = (apply s e, apply s t)
|
||||
|
|
@ -636,16 +634,19 @@ instance SubstType (T.Id' Type) where
|
|||
|
||||
-- | Represents the empty substition set
|
||||
nullSubst :: Subst
|
||||
nullSubst = mempty
|
||||
nullSubst = Subst mempty
|
||||
|
||||
-- | Compose two substitution sets
|
||||
compose :: Subst -> Subst -> Subst
|
||||
compose m1 m2 = M.map (apply m1) m2 `M.union` m1
|
||||
compose m1 m2 = Subst $ M.map (apply $ coerce m1) (coerce m2) `M.union` coerce m1
|
||||
|
||||
-- | Compose a list of substitution sets into one
|
||||
composeAll :: [Subst] -> Subst
|
||||
composeAll = foldl' compose nullSubst
|
||||
|
||||
unionSubsts :: [Subst] -> Subst
|
||||
unionSubsts = Subst . foldl' M.union M.empty . map coerce
|
||||
|
||||
{- | Convert a function with arguments to its pointfree version
|
||||
> makeLambda (add x y = x + y) = add = \x. \y. x + y
|
||||
-}
|
||||
|
|
@ -671,7 +672,7 @@ withPattern p ma = case p of
|
|||
T.PEnum _ -> ma
|
||||
|
||||
-- | Insert a function signature into the environment
|
||||
insertSig :: T.Ident -> Maybe Type -> Infer ()
|
||||
insertSig :: T.Ident -> Type -> Infer ()
|
||||
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
|
||||
|
||||
insertBind :: T.Ident -> Infer ()
|
||||
|
|
@ -691,24 +692,6 @@ with an equivalent name has been declared already
|
|||
existInj :: (Monad m, MonadState Env m) => T.Ident -> m (Maybe Type)
|
||||
existInj n = gets (M.lookup n . injections)
|
||||
|
||||
setCurrentBind :: T.Ident -> Infer ()
|
||||
setCurrentBind i = modify (\st -> st{currentBind = i})
|
||||
|
||||
solveUndecidable :: Infer Subst
|
||||
solveUndecidable = do
|
||||
sigs <- gets sigs
|
||||
undecided <- gets undecidedSigs
|
||||
ys <-
|
||||
maybeToRightM
|
||||
(Error "SIGNATURE MISSING" False)
|
||||
( mapM (tupSequence . first (join . flip M.lookup sigs . getOriginal)) $
|
||||
M.toList undecided
|
||||
)
|
||||
composeAll <$> mapM (uncurry unify) ys
|
||||
|
||||
getOriginal :: T.Ident -> T.Ident
|
||||
getOriginal (T.Ident i) = coerce $ takeWhile (/= delim) $ drop 1 i
|
||||
|
||||
delim :: Char
|
||||
delim = '_'
|
||||
prefix :: Char
|
||||
|
|
@ -785,7 +768,7 @@ dataErr ma d =
|
|||
)
|
||||
|
||||
initCtx = Ctx mempty
|
||||
initEnv = Env 0 'a' mempty mempty mempty "" mempty mempty mempty
|
||||
initEnv = Env 0 'a' mempty mempty mempty mempty
|
||||
|
||||
run :: Infer a -> Either Error a
|
||||
run = run' initEnv initCtx
|
||||
|
|
@ -804,19 +787,20 @@ newtype Ctx = Ctx {vars :: Map T.Ident Type}
|
|||
data Env = Env
|
||||
{ count :: Int
|
||||
, nextChar :: Char
|
||||
, sigs :: Map T.Ident (Maybe Type)
|
||||
, sigs :: Map T.Ident Type
|
||||
, takenTypeVars :: Set T.Ident
|
||||
, injections :: Map T.Ident Type
|
||||
, currentBind :: T.Ident
|
||||
, undecidedSigs :: Map T.Ident Type
|
||||
, toDecide :: Set T.Ident
|
||||
, declaredBinds :: Set T.Ident
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
data Error = Error {msg :: String, catchable :: Bool}
|
||||
deriving (Show)
|
||||
type Subst = Map T.Ident Type
|
||||
|
||||
newtype Subst = Subst (Map T.Ident Type)
|
||||
|
||||
instance Show Subst where
|
||||
show (Subst s) = "[" ++ let xs = (map (\(a, b) -> printTree a ++ " = " ++ printTree b) $ M.toList s) in intercalate " | " xs ++ "]"
|
||||
|
||||
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a}
|
||||
deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue