Merge in mutual recursion handling

This commit is contained in:
sebastianselander 2023-03-31 18:27:30 +02:00
parent c4f78ca37d
commit b7420b5adb

View file

@ -39,9 +39,8 @@ typecheck = onLeft msg . run . checkPrg
checkPrg :: Program -> Infer (T.Program' Type) checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = do checkPrg (Program bs) = do
preRun bs preRun bs
(subs, bs) <- checkDef bs (sub, bs) <- checkDef bs
ctrace "SUBS" $ unionSubsts subs return $ T.Program $ apply sub bs
return $ T.Program bs
preRun :: [Def] -> Infer () preRun :: [Def] -> Infer ()
preRun [] = return () preRun [] = return ()
@ -74,13 +73,14 @@ preRun (x : xs) = case x of
duplicateDecl :: (Monad m, MonadError Error m) => LIdent -> [T.Ident] -> String -> m () duplicateDecl :: (Monad m, MonadError Error m) => LIdent -> [T.Ident] -> String -> m ()
duplicateDecl n env msg = when (coerce n `elem` env) (uncatchableErr msg) duplicateDecl n env msg = when (coerce n `elem` env) (uncatchableErr msg)
checkDef :: [Def] -> Infer ([Subst], [T.Def' Type]) checkDef :: [Def] -> Infer (Subst, [T.Def' Type])
checkDef [] = return ([], []) checkDef [] = return (nullSubst, [])
checkDef (x : xs) = case x of checkDef (x : xs) = case x of
(DBind b) -> do (DBind b) -> do
(sub0, b') <- checkBind b (sub0, b') <- checkBind b
(sub1, xs') <- checkDef xs (sub1, xs') <- checkDef xs
return (sub1 ++ sub0, T.DBind b' : xs') comp <- sub0 `composey` sub1
return (comp, T.DBind b' : xs')
(DData d) -> do (DData d) -> do
(sub, xs') <- checkDef xs (sub, xs') <- checkDef xs
return (sub, T.DData (coerceData d) : xs') return (sub, T.DData (coerceData d) : xs')
@ -89,17 +89,16 @@ checkDef (x : xs) = case x of
coerceData (Data t injs) = coerceData (Data t injs) =
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
checkBind :: Bind -> Infer ([Subst], T.Bind' Type) checkBind :: Bind -> Infer (Subst, T.Bind' Type)
checkBind bind@(Bind name args e) = do checkBind bind@(Bind name args e) = do
let lambda = makeLambda e (reverse (coerce args)) let lambda = makeLambda e (reverse (coerce args))
(sub0, (e, lambda_t)) <- inferExp lambda (sub0, (e, lambda_t)) <- inferExp lambda
s <- gets sigs s <- gets sigs
case M.lookup (coerce name) s of case M.lookup (coerce name) s of
Just t' -> do Just t -> do
sub1 <- bindErr (unify t' lambda_t) bind sub1 <- bindErr (unify t lambda_t) bind
ctrace "SUB0" sub0 comp <- sub1 `composey` sub0
ctrace "SUB1" sub1 return (comp, T.Bind (coerce name, apply comp t) [] (e, lambda_t))
return ([sub1, sub0], T.Bind (coerce name, t') [] (e, lambda_t))
_ -> error "First pass through failed to add function to env" _ -> error "First pass through failed to add function to env"
checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m () checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m ()
@ -168,7 +167,7 @@ returnType a = a
inferExp :: Exp -> Infer (Subst, T.ExpT' Type) inferExp :: Exp -> Infer (Subst, T.ExpT' Type)
inferExp e = do inferExp e = do
(s, (e', t)) <- algoW e (s, (e', t)) <- algoW e
subbed <- apply s t let subbed = apply s t
return (s, (e', subbed)) return (s, (e', subbed))
class CollectTVars a where class CollectTVars a where
@ -195,9 +194,17 @@ algoW = \case
(sub0, (e', t')) <- exprErr (algoW e) err (sub0, (e', t')) <- exprErr (algoW e) err
sub1 <- unify t t' sub1 <- unify t t'
sub2 <- unify t' t sub2 <- unify t' t
unless
(apply sub1 t == t' && apply sub2 t' == t)
( uncatchableErr $ Aux.do
"Annotated type"
quote $ printTree t
"does not match inferred type"
quote $ printTree t'
)
let comp = sub2 `compose` sub1 `compose` sub0 let comp = sub2 `compose` sub1 `compose` sub0
et <- apply comp (e', t) -- return (comp, apply comp (e', t))
return (comp, et) return (comp, (e', t))
-- \| ------------------ -- \| ------------------
-- \| Γ ⊢ i : Int, ∅ -- \| Γ ⊢ i : Int, ∅
@ -238,10 +245,10 @@ algoW = \case
fr <- fresh fr <- fresh
withBinding (coerce name) fr $ do withBinding (coerce name) fr $ do
(s1, (e', t')) <- exprErr (algoW e) err (s1, (e', t')) <- exprErr (algoW e) err
varType <- apply s1 fr let varType = apply s1 fr
let newArr = TFun varType t' let newArr = TFun varType t'
eabs <- apply s1 (T.EAbs (coerce name) (e', t'), newArr) -- return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr))
return (s1, eabs) return (s1, (T.EAbs (coerce name) (e', t'), newArr))
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁ -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
-- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int) -- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
@ -255,10 +262,8 @@ algoW = \case
s3 <- exprErr (unify (apply s2 t0) int) err s3 <- exprErr (unify (apply s2 t0) int) err
s4 <- exprErr (unify (apply s3 t1) int) err s4 <- exprErr (unify (apply s3 t1) int) err
let comp = s4 `compose` s3 `compose` s2 `compose` s1 let comp = s4 `compose` s3 `compose` s2 `compose` s1
return -- return (comp, apply comp (T.EAdd (e0', t0) (e1', t1), int))
( comp return (comp, (T.EAdd (e0', t0) (e1', t1), int))
, apply comp (T.EAdd (e0', t0) (e1', t1), int)
)
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1 -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1
-- \| τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ') -- \| τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ')
@ -273,8 +278,10 @@ algoW = \case
(s1, (e1', t1)) <- algoW e1 (s1, (e1', t1)) <- algoW e1
s2 <- exprErr (unify (apply s1 t0) (TFun t1 fr)) err s2 <- exprErr (unify (apply s1 t0) (TFun t1 fr)) err
let t = apply s2 fr let t = apply s2 fr
let comp = s2 `compose` s1 `compose` s0 comp <- foldM composey nullSubst [s2, s1, s0]
return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t)) -- let comp = s2 `compose` s1 `compose` s0
-- return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t))
return (comp, (T.EApp (e0', t0) (e1', t1), t))
-- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁ -- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁
-- \| ---------------------------------------------- -- \| ----------------------------------------------
@ -290,12 +297,14 @@ algoW = \case
withBinding (coerce name) t' $ do withBinding (coerce name) t' $ do
(s2, (e1', t2)) <- algoW e1 (s2, (e1', t2)) <- algoW e1
let comp = s2 `compose` s1 let comp = s2 `compose` s1
return (comp, apply comp (T.ELet bind' (e1', t2), t2)) -- return (comp, apply comp (T.ELet bind' (e1', t2), t2))
return (comp, (T.ELet bind' (e1', t2), t2))
ECase caseExpr injs -> do ECase caseExpr injs -> do
(sub, (e', t)) <- algoW caseExpr (sub, (e', t)) <- algoW caseExpr
(subst, injs, ret_t) <- checkCase t injs (subst, injs, ret_t) <- checkCase t injs
let comp = subst `compose` sub let comp = subst `compose` sub
return (comp, apply comp (T.ECase (e', t) injs, ret_t)) -- return (comp, apply comp (T.ECase (e', t) injs, ret_t))
return (comp, (T.ECase (e', t) injs, ret_t))
EAppInf{} -> error "desugar phase failed" EAppInf{} -> error "desugar phase failed"
checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type) checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type)
@ -528,7 +537,7 @@ skolemize t = t
class SubstType t where class SubstType t where
-- | Apply a substitution to t -- | Apply a substitution to t
-- apply :: MonadError e m => Subst -> t -> m t -- apply :: MonadError e m => Subst -> t -> m t
apply :: Subst -> t -> Infer t apply :: Subst -> t -> t
class FreeVars t where class FreeVars t where
-- | Get all free variables from t -- | Get all free variables from t
@ -550,43 +559,28 @@ instance FreeVars a => FreeVars [a] where
instance SubstType Type where instance SubstType Type where
apply sub@(Subst s) t = do apply sub@(Subst s) t = do
case t of case t of
TLit a -> return $ TLit a TLit a -> TLit a
TVar (MkTVar a) -> case M.lookup (coerce a) s of TVar (MkTVar a) -> case M.lookup (coerce a) s of
Nothing -> return $ TVar (MkTVar $ coerce a) Nothing -> TVar (MkTVar $ coerce a)
Just t -> return $ t Just t -> t
TAll (MkTVar i) t -> case M.lookup (coerce i) s of TAll (MkTVar i) t -> case M.lookup (coerce i) s of
Nothing -> TAll (MkTVar i) <$> apply sub t Nothing -> TAll (MkTVar i) (apply sub t)
Just _ -> apply sub t Just _ -> apply sub t
TFun a b -> TFun <$> apply sub a <*> apply sub b TFun a b -> TFun (apply sub a) (apply sub b)
TData name a -> TData name <$> apply sub a TData name a -> TData name (apply sub a)
TEVar (MkTEVar a) -> case M.lookup (coerce a) s of TEVar (MkTEVar a) -> case M.lookup (coerce a) s of
Nothing -> return $ TEVar (MkTEVar a) Nothing -> TEVar (MkTEVar a)
Just t -> return $ t Just t -> t
instance FreeVars (Map T.Ident Type) where instance FreeVars (Map T.Ident Type) where
free :: Map T.Ident Type -> Set T.Ident free :: Map T.Ident Type -> Set T.Ident
free = free . M.elems free = free . M.elems
instance SubstType (Map T.Ident Type) where instance SubstType (Map T.Ident Type) where
apply s = undefined -- M.map (apply s) apply s = M.map (apply s)
instance SubstType Subst where instance SubstType Subst where
apply s@(Subst m1) (Subst m2) = do apply s (Subst m2) = Subst $ apply s m2
let both = M.keys $ M.intersection m1 m2
case both of
[] -> Subst <$> apply s m2
xs -> do
sub0 <- apply s m2
sub1 <- loop xs m1 m2
apply sub1 (Subst sub0)
where
loop [] _ _ = return nullSubst
loop (x : xs) m1 m2 = do
let k1 = m1 M.! x
let k2 = m2 M.! x
sub <- unify k1 k2
subs <- loop xs m1 m2
return $ sub `compose` subs
-- Subst $ M.map (apply s) m2 -- Subst $ M.map (apply s) m2
@ -640,6 +634,30 @@ nullSubst = Subst mempty
compose :: Subst -> Subst -> Subst compose :: Subst -> Subst -> Subst
compose m1 m2 = Subst $ M.map (apply $ coerce m1) (coerce m2) `M.union` coerce m1 compose m1 m2 = Subst $ M.map (apply $ coerce m1) (coerce m2) `M.union` coerce m1
-- Order matters.
{-
sub0 = Subst $ (M.singleton "a" (arr d e)) `M.union` (M.singleton "b" (arr d f)) `M.union` (M.singleton "c" (arr f e))
sub1 = Subst $ (M.singleton "a" (arr g bool)) `M.union` (M.singleton "b" (arr g bool)) `M.union` (M.singleton "c" (arr bool bool)) `M.union` (M.singleton "h" bool) `M.union` (M.singleton "i" bool)
sub0 `composey` sub1 != sub1 `composey` sub0
-}
composey :: Subst -> Subst -> Infer Subst
composey s0@(Subst m1) s1@(Subst m2) = do
let both = M.keys $ M.intersection m1 m2
case both of
[] -> return $ s0 `compose` s1
xs -> do
let m2' = apply s0 m2
sub <- loop xs m1 m2'
return $ sub `compose` Subst m2
where
loop [] _ _ = return nullSubst
loop (x : xs) m1 m2 = do
let k1 = m1 M.! x
let k2 = m2 M.! x
sub <- unify k1 k2
subs <- loop xs m1 m2
return $ sub `compose` subs
-- | Compose a list of substitution sets into one -- | Compose a list of substitution sets into one
composeAll :: [Subst] -> Subst composeAll :: [Subst] -> Subst
composeAll = foldl' compose nullSubst composeAll = foldl' compose nullSubst
@ -816,3 +834,16 @@ quote s = "'" ++ s ++ "'"
ctrace :: (Monad m, Show a) => String -> a -> m () ctrace :: (Monad m, Show a) => String -> a -> m ()
ctrace str a = trace (str ++ ": " ++ show a) pure () ctrace str a = trace (str ++ ": " ++ show a) pure ()
{-
Save each subst mapped to their respective function
Apply composition of all used functions to the function
a = id 0 ;
b = id 'a' ;
id x = x ;
apply_on_a = id_sub `compose` a_sub
apply_on_b = id_sub `compose` b_sub
apply_on_id = id_sub
-}