Reverted the Hindley-Milner type checker to its state before the mutual recursion merge

This commit is contained in:
sebastian 2023-04-01 17:10:26 +02:00
parent ec57712eec
commit 4b14cbdebf

View file

@ -6,20 +6,20 @@
-- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeCheckerHm where
import Auxiliary (int, litType, maybeToRightM, unzip4)
import Auxiliary (int, litType, maybeToRightM, tupSequence, unzip4)
import Auxiliary qualified as Aux
import Control.Arrow ((&&&))
import Control.Monad.Except
import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (first)
import Data.Coerce (coerce)
import Data.Function (on)
import Data.List (foldl', intercalate)
import Data.List (foldl')
import Data.List.Extra (unsnoc)
import Data.Map (Map)
import Data.Map qualified as M
import Data.Maybe (fromJust, fromMaybe, mapMaybe)
import Data.Maybe (fromJust)
import Data.Set (Set)
import Data.Set qualified as S
import Debug.Trace (trace)
@ -27,6 +27,8 @@ import Grammar.Abs
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr qualified as T
-- TODO: Disallow mutual recursion
-- | Type check a program
typecheck :: Program -> Either String (T.Program' Type)
typecheck = onLeft msg . run . checkPrg
@ -36,16 +38,20 @@ typecheck = onLeft msg . run . checkPrg
onLeft _ (Right x) = Right x
-- | Type check all definitions of a program and wrap them in 'T.Program'.
-- NOTE(review): two definitions of 'checkPrg' follow because this text is a
-- rendered diff; presumably the one-liner is the version being replaced and
-- the do-block is its successor — confirm against the commit.
checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = T.Program <$> (preRun bs >> checkDef bs >>= mapM substPrg)
checkPrg (Program bs) = do
-- Register the definitions in the state before checking any of them.
preRun bs
bs <- checkDef bs
-- Unify the type variables handed out for unsigned binds with the signatures
-- stored in 'sigs', then apply the result to the affected binds via 'mono'.
sub0 <- solveUndecidable
bs <- mapM (mono sub0) bs
return $ T.Program bs
-- | Apply to a bind the substitutions recorded for it and for the binds it
-- uses. Definitions other than 'T.DBind' are returned unchanged.
substPrg :: T.Def' Type -> Infer (T.Def' Type)
substPrg (T.DBind (T.Bind (name, t) args e)) = do
(bu, sub) <- gets (bindUsages &&& bindSubs)
-- Names recorded as used by this bind; none if the bind has no entry.
let uses = fromMaybe [] $ M.lookup name bu
-- Substitutions stored for the bind itself and for each used name, skipping
-- names that have no stored substitution.
let subs = mapMaybe (`M.lookup` sub) (name : uses)
-- Merge them with 'composey', folding over the list in reverse order.
sub <- foldM composey nullSubst (reverse subs)
return . T.DBind $ T.Bind (name, apply sub t) (apply sub args) (apply sub e)
substPrg d = return d
-- | Apply a substitution to a checked definition, but only when the
-- definition is a bind whose name is a member of 'toDecide'. Every other
-- definition (including data declarations) is returned unchanged.
mono :: Subst -> T.Def' Type -> Infer (T.Def' Type)
mono sub = \case
  def@(T.DBind (T.Bind (name, t) args e)) -> do
    undecided <- gets (S.member name . toDecide)
    pure $
      if undecided
        then T.DBind (T.Bind (name, apply sub t) (apply sub args) (apply sub e))
        else def
  T.DData d -> pure (T.DData d)
preRun :: [Def] -> Infer ()
preRun [] = return ()
@ -56,8 +62,7 @@ preRun (x : xs) = case x of
duplicateDecl n s $ Aux.do
"Multiple signatures of function"
quote $ printTree n
insertSig (coerce n) (Instantiated t)
preRun xs
insertSig (coerce n) (Just t) >> preRun xs
DBind (Bind n _ e) -> do
s <- gets (S.toList . declaredBinds)
duplicateDecl n s $ Aux.do
@ -65,17 +70,13 @@ preRun (x : xs) = case x of
quote $ printTree n
collect (collectTVars e)
insertBind $ coerce n
sigs <- gets sigs
case M.lookup (coerce n) sigs of
Nothing -> do
fr <- fresh
insertSig (coerce n) (Generalized fr)
preRun xs
s <- gets sigs
case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs
DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs
where
-- Check if function body / signature has been declared already
duplicateDecl :: (Monad m, MonadError Error m) => LIdent -> [T.Ident] -> String -> m ()
duplicateDecl n env msg = when (coerce n `elem` env) (uncatchableErr msg)
checkDef :: [Def] -> Infer [T.Def' Type]
@ -100,16 +101,12 @@ checkBind bind@(Bind name args e) = do
(sub0, (e, lambda_t)) <- inferExp lambda
s <- gets sigs
case M.lookup (coerce name) s of
Just t -> do
let t' = case t of
Instantiated a -> skolemize a
Generalized a -> a
sub1 <- bindErr (unify t' lambda_t) bind
comp <- sub1 `composey` sub0
insertBindSubst (coerce name) comp
return (T.Bind (coerce name, apply comp t') [] (e, lambda_t))
Just (Just t') -> do
sub1 <- bindErr (unify lambda_t (skolemize t')) bind
return $ T.Bind (coerce name, apply (sub1 `compose` sub0) t') [] (e, lambda_t)
_ -> do
uncatchableErr $ "Undeclared function: " ++ printTree name
insertSig (coerce name) (Just lambda_t)
return (T.Bind (coerce name, lambda_t) [] (e, lambda_t))
checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m ()
checkData err@(Data typ injs) = do
@ -178,6 +175,7 @@ inferExp :: Exp -> Infer (Subst, T.ExpT' Type)
-- Run 'algoW' on an expression, apply the resulting substitution to the
-- inferred type and to every entry of 'undecidedSigs', and return the
-- substitution together with the annotated expression.
inferExp expr = do
  (sub, (typed, ty)) <- algoW expr
  modify $ \env -> env{undecidedSigs = apply sub env.undecidedSigs}
  pure (sub, (typed, apply sub ty))
class CollectTVars a where
@ -213,7 +211,7 @@ algoW = \case
quote $ printTree t'
)
let comp = sub2 `compose` sub1 `compose` sub0
return (comp, (e', t))
return (comp, apply comp (e', t))
-- \| ------------------
-- \| Γ ⊢ i : Int, ∅
@ -232,9 +230,11 @@ algoW = \case
sig <- gets sigs
cb <- gets currentBind
case M.lookup (coerce i) sig of
Just t -> do
insertBindUsage cb (coerce i)
return (nullSubst, (T.EVar $ coerce i, unlevel t))
Just (Just t) -> return (nullSubst, (T.EVar $ coerce i, t))
Just Nothing -> do
fr <- fresh
modify (\st -> st{toDecide = S.insert cb st.toDecide, undecidedSigs = M.insert (coerce $ concat [[prefix], i, [delim], coerce cb]) fr st.undecidedSigs})
return (nullSubst, (T.EVar $ coerce i, fr))
Nothing ->
uncatchableErr $
"Unbound variable: "
@ -259,7 +259,7 @@ algoW = \case
(s1, (e', t')) <- exprErr (algoW e) err
let varType = apply s1 fr
let newArr = TFun varType t'
return (s1, (T.EAbs (coerce name) (e', t'), newArr))
return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr))
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
-- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
@ -273,7 +273,10 @@ algoW = \case
s3 <- exprErr (unify (apply s2 t0) int) err
s4 <- exprErr (unify (apply s3 t1) int) err
let comp = s4 `compose` s3 `compose` s2 `compose` s1
return (comp, (T.EAdd (e0', t0) (e1', t1), int))
return
( comp
, apply comp (T.EAdd (e0', t0) (e1', t1), int)
)
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
-- \| τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ')
@ -284,11 +287,12 @@ algoW = \case
fr <- fresh
(s0, (e0', t0)) <- algoW e0
applySt s0 $ do
modify (\st -> st{sigs = apply s0 st.sigs})
(s1, (e1', t1)) <- algoW e1
s2 <- exprErr (unify (apply s1 t0) (TFun t1 fr)) err
let t = apply s2 fr
comp <- foldM composey nullSubst [s2, s1, s0]
return (comp, (T.EApp (e0', t0) (e1', t1), t))
let comp = s2 `compose` s1 `compose` s0
return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t))
-- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁
-- \| ----------------------------------------------
@ -296,23 +300,20 @@ algoW = \case
-- The bar over S₀ and Γ means "generalize"
(ELet (Bind name args e) e1) -> do
(s1, (e, t0)) <- algoW (makeLambda e (coerce args))
err@(ELet b@(Bind name args e) e1) -> do
(s1, (_, t0)) <- algoW (makeLambda e (coerce args))
bind' <- exprErr (checkBind b) err
env <- asks vars
let t' = generalize (apply s1 env) t0
withBinding (coerce name) t' $ do
(s2, (e1', t2)) <- algoW e1
let comp = s2 `compose` s1
return
( comp
, (T.ELet (T.Bind (coerce name, t0) [] (e, t0)) (e1', t2), t2)
)
return (comp, apply comp (T.ELet bind' (e1', t2), t2))
ECase caseExpr injs -> do
(sub, (e', t)) <- algoW caseExpr
(subst, injs, ret_t) <- checkCase t injs
let comp = subst `compose` sub
-- return (comp, apply comp (T.ECase (e', t) injs, ret_t))
return (comp, (T.ECase (e', t) injs, ret_t))
return (comp, apply comp (T.ECase (e', t) injs, ret_t))
EAppInf{} -> error "desugar phase failed"
checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type)
@ -421,17 +422,15 @@ unify t0 t1 =
s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d)
return $ s2 `compose` s1
(TVar (MkTVar a), t@(TData _ _)) ->
return $ coerce $ M.singleton (coerce a) t
(t@(TData _ _), TVar (MkTVar b)) ->
return $ coerce $ M.singleton (coerce b) t
(TVar (MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t
(t@(TData _ _), TVar (MkTVar b)) -> return $ M.singleton (coerce b) t
(TVar (MkTVar a), t) -> occurs (coerce a) t
(t, TVar (MkTVar b)) -> occurs (coerce b) t
(TAll _ t, b) -> unify t b
(a, TAll _ t) -> unify a t
(TLit a, TLit b) ->
if a == b
then return nullSubst
then return M.empty
else catchableErr $
Aux.do
"Can not unify"
@ -453,7 +452,7 @@ unify t0 t1 =
quote $ printTree t'
(TEVar a, TEVar b) ->
if a == b
then return nullSubst
then return M.empty
else catchableErr $
Aux.do
"Can not unify"
@ -473,7 +472,7 @@ I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
where these are equal
-}
occurs :: T.Ident -> Type -> Infer Subst
occurs i t@(TVar _) = return (coerce $ M.singleton i t)
occurs i t@(TVar _) = return (M.singleton i t)
occurs i t =
if S.member i (free t)
then
@ -484,7 +483,7 @@ occurs i t =
"with"
quote $ printTree t
)
else return $ coerce $ M.singleton i t
else return $ M.singleton i t
{- | Generalize a type over all free variables in the substitution set
Used for let bindings to allow expressions that do not type check in
@ -510,7 +509,7 @@ inst :: Type -> Infer Type
inst = \case
TAll (MkTVar bound) t -> do
fr <- fresh
let s = coerce $ M.singleton (coerce bound) fr
let s = M.singleton (coerce bound) fr
apply s <$> inst t
TFun t1 t2 -> TFun <$> inst t1 <*> inst t2
rest -> return rest
@ -546,7 +545,6 @@ skolemize t = t
-- | A class for substitutions
class SubstType t where
-- | Apply a substitution to t
-- apply :: MonadError e m => Subst -> t -> m t
apply :: Subst -> t -> t
class FreeVars t where
@ -567,18 +565,19 @@ instance FreeVars a => FreeVars [a] where
free = let f acc x = acc `S.union` free x in foldl' f S.empty
instance SubstType Type where
apply sub@(Subst s) t = do
apply :: Subst -> Type -> Type
apply sub t = do
case t of
TLit a -> TLit a
TVar (MkTVar a) -> case M.lookup (coerce a) s of
TVar (MkTVar a) -> case M.lookup (coerce a) sub of
Nothing -> TVar (MkTVar $ coerce a)
Just t -> t
TAll (MkTVar i) t -> case M.lookup (coerce i) s of
TAll (MkTVar i) t -> case M.lookup (coerce i) sub of
Nothing -> TAll (MkTVar i) (apply sub t)
Just _ -> apply sub t
TFun a b -> TFun (apply sub a) (apply sub b)
TData name a -> TData name (apply sub a)
TEVar (MkTEVar a) -> case M.lookup (coerce a) s of
TEVar (MkTEVar a) -> case M.lookup (coerce a) sub of
Nothing -> TEVar (MkTEVar a)
Just t -> t
@ -587,12 +586,11 @@ instance FreeVars (Map T.Ident Type) where
free = free . M.elems
instance SubstType (Map T.Ident Type) where
apply s = M.map (apply s)
apply :: Subst -> Map T.Ident Type -> Map T.Ident Type
apply = M.map . apply
instance SubstType Subst where
apply s (Subst m2) = Subst $ apply s m2
-- Subst $ M.map (apply s) m2
instance SubstType (Map T.Ident (Maybe Type)) where
apply s = M.map (fmap $ apply s)
instance SubstType (T.ExpT' Type) where
  -- Substitute in both components of the pair: the expression and its type.
  apply sub (expr, ty) = (apply sub expr, apply sub ty)
@ -613,8 +611,7 @@ instance SubstType (T.Exp' Type) where
instance SubstType (T.Def' Type) where
apply s = \case
T.DBind (T.Bind name args e) ->
T.DBind $ T.Bind (apply s name) (apply s args) (apply s e)
T.DBind (T.Bind name args e) -> T.DBind $ T.Bind (apply s name) (apply s args) (apply s e)
d -> d
instance SubstType (T.Branch' Type) where
@ -639,49 +636,16 @@ instance SubstType (T.Id' Type) where
-- | Represents the empty substitution set
nullSubst :: Subst
nullSubst = Subst mempty
nullSubst = mempty
-- | Compose two substitution sets
compose :: Subst -> Subst -> Subst
compose m1 m2 = Subst $ M.map (apply $ coerce m1) (coerce m2) `M.union` coerce m1
-- Order matters.
{-
sub0 = Subst $ (M.singleton "a" (arr d e))
`M.union` (M.singleton "b" (arr d f))
`M.union` (M.singleton "c" (arr f e))
sub1 = Subst $ (M.singleton "a" (arr g bool))
`M.union` (M.singleton "b" (arr g bool))
`M.union` (M.singleton "c" (arr bool bool))
`M.union` (M.singleton "h" bool)
`M.union` (M.singleton "i" bool)
sub0 `composey` sub1 != sub1 `composey` sub0
-}
composey :: Subst -> Subst -> Infer Subst
composey s0@(Subst m1) s1@(Subst m2) = do
let both = M.keys $ M.intersection m1 m2
case both of
[] -> return $ s0 `compose` s1
xs -> do
let m2' = apply s0 m2
sub <- loop xs m1 m2'
return $ sub `compose` Subst m2
where
loop [] _ _ = return nullSubst
loop (x : xs) m1 m2 = do
let k1 = m1 M.! x
let k2 = m2 M.! x
sub <- unify k1 k2
subs <- loop xs m1 m2
return $ sub `compose` subs
compose m1 m2 = M.map (apply m1) m2 `M.union` m1
-- | Fold a list of substitutions into a single one with 'compose', working
-- from the left and starting from 'nullSubst'.
composeAll :: [Subst] -> Subst
composeAll substs = foldl' compose nullSubst substs
unionSubsts :: [Subst] -> Subst
unionSubsts = Subst . foldl' M.union M.empty . map coerce
{- | Convert a function with arguments to its pointfree version
> makeLambda (add x y = x + y) = add = \x. \y. x + y
-}
@ -707,21 +671,12 @@ withPattern p ma = case p of
T.PEnum _ -> ma
-- | Insert a function signature into the environment
insertSig :: T.Ident -> Level Type -> Infer ()
insertSig :: T.Ident -> Maybe Type -> Infer ()
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
-- | Add a name to the 'declaredBinds' set held in the state.
insertBind :: T.Ident -> Infer ()
insertBind name = modify $ \env -> env{declaredBinds = S.insert name env.declaredBinds}
insertBindSubst :: T.Ident -> Subst -> Infer ()
insertBindSubst name sub = modify (\st -> st{bindSubs = M.insert name sub st.bindSubs})
setCurrentBind :: T.Ident -> Infer ()
setCurrentBind n = modify (\st -> st{currentBind = n, bindUsages = M.insertWith (++) n [] st.bindUsages})
insertBindUsage :: T.Ident -> T.Ident -> Infer ()
insertBindUsage cur use = modify (\st -> st{bindUsages = M.insertWith (++) cur [use] st.bindUsages})
-- | Insert a constructor into the state with its type
insertInj :: (Monad m, MonadState Env m) => T.Ident -> Type -> m ()
insertInj i t =
@ -736,6 +691,24 @@ with an equivalent name has been declared already
-- Look a name up in the 'injections' map of the state.
existInj :: (Monad m, MonadState Env m) => T.Ident -> m (Maybe Type)
existInj name = gets (\env -> M.lookup name (injections env))
-- | Store the given name as 'currentBind' in the state.
setCurrentBind :: T.Ident -> Infer ()
setCurrentBind bindName = modify $ \env -> env{currentBind = bindName}
-- | Build one substitution from all entries of 'undecidedSigs'.
--
-- For every entry, the name recovered by 'getOriginal' is looked up in
-- 'sigs'; the type found there is unified with the entry's type and all
-- resulting substitutions are merged with 'composeAll'.
--
-- NOTE(review): presumably the non-catchable "SIGNATURE MISSING" error is
-- raised when any of the lookups yields no type; this depends on
-- 'maybeToRightM' and 'tupSequence' from Auxiliary — confirm there.
solveUndecidable :: Infer Subst
solveUndecidable = do
  declared <- gets sigs
  pending <- gets undecidedSigs
  let originalSig key = join (M.lookup (getOriginal key) declared)
      resolve entry = tupSequence (first originalSig entry)
  pairs <-
    maybeToRightM
      (Error "SIGNATURE MISSING" False)
      (mapM resolve (M.toList pending))
  substs <- mapM (uncurry unify) pairs
  pure (composeAll substs)
-- | Recover a name from a key of 'undecidedSigs': drop the first character
-- and keep everything up to (not including) the first 'delim'.
--
-- NOTE(review): if an identifier can itself contain 'delim', the result is
-- cut short at that character — confirm what the grammar allows.
getOriginal :: T.Ident -> T.Ident
getOriginal (T.Ident mangled) = coerce (takeWhile (/= delim) (drop 1 mangled))
-- | Separator placed between the referenced name and the current bind name
-- when 'algoW' builds the keys of 'undecidedSigs'; 'getOriginal' cuts at it.
delim :: Char
delim = '_'
prefix :: Char
@ -812,7 +785,7 @@ dataErr ma d =
)
initCtx = Ctx mempty
initEnv = Env 0 'a' mempty mempty mempty mempty "" mempty mempty
initEnv = Env 0 'a' mempty mempty mempty "" mempty mempty mempty
-- | Run an inference computation via run', starting from 'initEnv' and
-- 'initCtx'.
run :: Infer a -> Either Error a
run action = run' initEnv initCtx action
@ -831,28 +804,19 @@ newtype Ctx = Ctx {vars :: Map T.Ident Type}
-- | State threaded through 'Infer'.
-- NOTE(review): this span comes from a rendered diff, so fields of the old and
-- the new record layout appear side by side (e.g. 'sigs' and 'declaredBinds'
-- are listed twice) — confirm which lines belong to the current version.
data Env = Env
{ count :: Int
, nextChar :: Char
, sigs :: Map T.Ident (Level Type)
-- With the 'Maybe' layout, 'preRun' stores 'Nothing' for a bind that has no
-- signature and 'Just' the declared type otherwise.
, sigs :: Map T.Ident (Maybe Type)
, takenTypeVars :: Set T.Ident
, injections :: Map T.Ident Type
, declaredBinds :: Set T.Ident
-- Name most recently passed to 'setCurrentBind'.
, currentBind :: T.Ident
, bindSubs :: Map T.Ident Subst
, bindUsages :: Map T.Ident [T.Ident]
-- Fresh type variables handed out by 'algoW' for references to binds without
-- a signature, keyed by a name built from 'prefix', the referenced name,
-- 'delim' and the current bind.
, undecidedSigs :: Map T.Ident Type
-- Binds during whose checking such a reference occurred; 'mono' only
-- rewrites these.
, toDecide :: Set T.Ident
, declaredBinds :: Set T.Ident
}
deriving (Show)
data Level a = Instantiated {unlevel :: a} | Generalized {unlevel :: a}
deriving (Show)
data Error = Error {msg :: String, catchable :: Bool}
deriving (Show)
newtype Subst = Subst (Map T.Ident Type)
instance Show Subst where
show (Subst s) = "[ " ++ intercalate " | " xs ++ " ]"
where
xs = map (\(a, b) -> printTree a ++ " = " ++ printTree b) $ M.toList s
type Subst = Map T.Ident Type
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a}
deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env)
@ -868,3 +832,4 @@ quote s = "'" ++ s ++ "'"
-- | Emit @label ++ ": " ++ show value@ through 'trace' and return unit in
-- any monad.
ctrace :: (Monad m, Show a) => String -> a -> m ()
ctrace label value = trace message (pure ())
  where
    message = label ++ ": " ++ show value