Merge in mutual recursion handling

This commit is contained in:
sebastianselander 2023-03-31 18:27:30 +02:00
parent c4f78ca37d
commit b7420b5adb

View file

@ -39,9 +39,8 @@ typecheck = onLeft msg . run . checkPrg
checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = do
preRun bs
(subs, bs) <- checkDef bs
ctrace "SUBS" $ unionSubsts subs
return $ T.Program bs
(sub, bs) <- checkDef bs
return $ T.Program $ apply sub bs
preRun :: [Def] -> Infer ()
preRun [] = return ()
@ -74,13 +73,14 @@ preRun (x : xs) = case x of
duplicateDecl :: (Monad m, MonadError Error m) => LIdent -> [T.Ident] -> String -> m ()
duplicateDecl n env msg = when (coerce n `elem` env) (uncatchableErr msg)
checkDef :: [Def] -> Infer ([Subst], [T.Def' Type])
checkDef [] = return ([], [])
checkDef :: [Def] -> Infer (Subst, [T.Def' Type])
checkDef [] = return (nullSubst, [])
checkDef (x : xs) = case x of
(DBind b) -> do
(sub0, b') <- checkBind b
(sub1, xs') <- checkDef xs
return (sub1 ++ sub0, T.DBind b' : xs')
comp <- sub0 `composey` sub1
return (comp, T.DBind b' : xs')
(DData d) -> do
(sub, xs') <- checkDef xs
return (sub, T.DData (coerceData d) : xs')
@ -89,17 +89,16 @@ checkDef (x : xs) = case x of
coerceData (Data t injs) =
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
checkBind :: Bind -> Infer ([Subst], T.Bind' Type)
checkBind :: Bind -> Infer (Subst, T.Bind' Type)
checkBind bind@(Bind name args e) = do
let lambda = makeLambda e (reverse (coerce args))
(sub0, (e, lambda_t)) <- inferExp lambda
s <- gets sigs
case M.lookup (coerce name) s of
Just t' -> do
sub1 <- bindErr (unify t' lambda_t) bind
ctrace "SUB0" sub0
ctrace "SUB1" sub1
return ([sub1, sub0], T.Bind (coerce name, t') [] (e, lambda_t))
Just t -> do
sub1 <- bindErr (unify t lambda_t) bind
comp <- sub1 `composey` sub0
return (comp, T.Bind (coerce name, apply comp t) [] (e, lambda_t))
_ -> error "First pass through failed to add function to env"
checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m ()
@ -168,7 +167,7 @@ returnType a = a
inferExp :: Exp -> Infer (Subst, T.ExpT' Type)
inferExp e = do
(s, (e', t)) <- algoW e
subbed <- apply s t
let subbed = apply s t
return (s, (e', subbed))
class CollectTVars a where
@ -195,9 +194,17 @@ algoW = \case
(sub0, (e', t')) <- exprErr (algoW e) err
sub1 <- unify t t'
sub2 <- unify t' t
unless
(apply sub1 t == t' && apply sub2 t' == t)
( uncatchableErr $ Aux.do
"Annotated type"
quote $ printTree t
"does not match inferred type"
quote $ printTree t'
)
let comp = sub2 `compose` sub1 `compose` sub0
et <- apply comp (e', t)
return (comp, et)
-- return (comp, apply comp (e', t))
return (comp, (e', t))
-- \| ------------------
-- \| Γ ⊢ i : Int, ∅
@ -238,10 +245,10 @@ algoW = \case
fr <- fresh
withBinding (coerce name) fr $ do
(s1, (e', t')) <- exprErr (algoW e) err
varType <- apply s1 fr
let varType = apply s1 fr
let newArr = TFun varType t'
eabs <- apply s1 (T.EAbs (coerce name) (e', t'), newArr)
return (s1, eabs)
-- return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr))
return (s1, (T.EAbs (coerce name) (e', t'), newArr))
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
-- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
@ -255,10 +262,8 @@ algoW = \case
s3 <- exprErr (unify (apply s2 t0) int) err
s4 <- exprErr (unify (apply s3 t1) int) err
let comp = s4 `compose` s3 `compose` s2 `compose` s1
return
( comp
, apply comp (T.EAdd (e0', t0) (e1', t1), int)
)
-- return (comp, apply comp (T.EAdd (e0', t0) (e1', t1), int))
return (comp, (T.EAdd (e0', t0) (e1', t1), int))
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1
-- \| τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ')
@ -273,8 +278,10 @@ algoW = \case
(s1, (e1', t1)) <- algoW e1
s2 <- exprErr (unify (apply s1 t0) (TFun t1 fr)) err
let t = apply s2 fr
let comp = s2 `compose` s1 `compose` s0
return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t))
comp <- foldM composey nullSubst [s2, s1, s0]
-- let comp = s2 `compose` s1 `compose` s0
-- return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t))
return (comp, (T.EApp (e0', t0) (e1', t1), t))
-- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁
-- \| ----------------------------------------------
@ -290,12 +297,14 @@ algoW = \case
withBinding (coerce name) t' $ do
(s2, (e1', t2)) <- algoW e1
let comp = s2 `compose` s1
return (comp, apply comp (T.ELet bind' (e1', t2), t2))
-- return (comp, apply comp (T.ELet bind' (e1', t2), t2))
return (comp, (T.ELet bind' (e1', t2), t2))
ECase caseExpr injs -> do
(sub, (e', t)) <- algoW caseExpr
(subst, injs, ret_t) <- checkCase t injs
let comp = subst `compose` sub
return (comp, apply comp (T.ECase (e', t) injs, ret_t))
-- return (comp, apply comp (T.ECase (e', t) injs, ret_t))
return (comp, (T.ECase (e', t) injs, ret_t))
EAppInf{} -> error "desugar phase failed"
checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type)
@ -528,7 +537,7 @@ skolemize t = t
class SubstType t where
-- | Apply a substitution to t
-- apply :: MonadError e m => Subst -> t -> m t
apply :: Subst -> t -> Infer t
apply :: Subst -> t -> t
class FreeVars t where
-- | Get all free variables from t
@ -550,43 +559,28 @@ instance FreeVars a => FreeVars [a] where
instance SubstType Type where
apply sub@(Subst s) t = do
case t of
TLit a -> return $ TLit a
TLit a -> TLit a
TVar (MkTVar a) -> case M.lookup (coerce a) s of
Nothing -> return $ TVar (MkTVar $ coerce a)
Just t -> return $ t
Nothing -> TVar (MkTVar $ coerce a)
Just t -> t
TAll (MkTVar i) t -> case M.lookup (coerce i) s of
Nothing -> TAll (MkTVar i) <$> apply sub t
Nothing -> TAll (MkTVar i) (apply sub t)
Just _ -> apply sub t
TFun a b -> TFun <$> apply sub a <*> apply sub b
TData name a -> TData name <$> apply sub a
TFun a b -> TFun (apply sub a) (apply sub b)
TData name a -> TData name (apply sub a)
TEVar (MkTEVar a) -> case M.lookup (coerce a) s of
Nothing -> return $ TEVar (MkTEVar a)
Just t -> return $ t
Nothing -> TEVar (MkTEVar a)
Just t -> t
instance FreeVars (Map T.Ident Type) where
free :: Map T.Ident Type -> Set T.Ident
free = free . M.elems
instance SubstType (Map T.Ident Type) where
apply s = undefined -- M.map (apply s)
apply s = M.map (apply s)
instance SubstType Subst where
apply s@(Subst m1) (Subst m2) = do
let both = M.keys $ M.intersection m1 m2
case both of
[] -> Subst <$> apply s m2
xs -> do
sub0 <- apply s m2
sub1 <- loop xs m1 m2
apply sub1 (Subst sub0)
where
loop [] _ _ = return nullSubst
loop (x : xs) m1 m2 = do
let k1 = m1 M.! x
let k2 = m2 M.! x
sub <- unify k1 k2
subs <- loop xs m1 m2
return $ sub `compose` subs
apply s (Subst m2) = Subst $ apply s m2
-- Subst $ M.map (apply s) m2
@ -640,6 +634,30 @@ nullSubst = Subst mempty
compose :: Subst -> Subst -> Subst
compose m1 m2 = Subst $ M.map (apply $ coerce m1) (coerce m2) `M.union` coerce m1
-- Order matters.
{-
sub0 = Subst $ (M.singleton "a" (arr d e)) `M.union` (M.singleton "b" (arr d f)) `M.union` (M.singleton "c" (arr f e))
sub1 = Subst $ (M.singleton "a" (arr g bool)) `M.union` (M.singleton "b" (arr g bool)) `M.union` (M.singleton "c" (arr bool bool)) `M.union` (M.singleton "h" bool) `M.union` (M.singleton "i" bool)
sub0 `composey` sub1 != sub1 `composey` sub0
-}
composey :: Subst -> Subst -> Infer Subst
composey s0@(Subst m1) s1@(Subst m2) = do
let both = M.keys $ M.intersection m1 m2
case both of
[] -> return $ s0 `compose` s1
xs -> do
let m2' = apply s0 m2
sub <- loop xs m1 m2'
return $ sub `compose` Subst m2
where
loop [] _ _ = return nullSubst
loop (x : xs) m1 m2 = do
let k1 = m1 M.! x
let k2 = m2 M.! x
sub <- unify k1 k2
subs <- loop xs m1 m2
return $ sub `compose` subs
-- | Compose a list of substitution sets into one
composeAll :: [Subst] -> Subst
composeAll = foldl' compose nullSubst
@ -800,7 +818,7 @@ data Error = Error {msg :: String, catchable :: Bool}
newtype Subst = Subst (Map T.Ident Type)
instance Show Subst where
show (Subst s) = "[" ++ let xs = (map (\(a, b) -> printTree a ++ " = " ++ printTree b) $ M.toList s) in intercalate " | " xs ++ "]"
show (Subst s) = "[ " ++ let xs = (map (\(a, b) -> printTree a ++ " = " ++ printTree b) $ M.toList s) in intercalate " | " xs ++ " ]"
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a}
deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env)
@ -816,3 +834,16 @@ quote s = "'" ++ s ++ "'"
ctrace :: (Monad m, Show a) => String -> a -> m ()
ctrace str a = trace (str ++ ": " ++ show a) pure ()
{-
Save each subst mapped to their respective function
Apply composition of all used functions to the function
a = id 0 ;
b = id 'a' ;
id x = x ;
apply_on_a = id_sub `compose` a_sub
apply_on_b = id_sub `compose` b_sub
apply_on_id = id_sub
-}