Merge in mutual recursion handling

This commit is contained in:
sebastianselander 2023-03-31 18:28:04 +02:00
parent b7420b5adb
commit 0749ca062d
3 changed files with 81 additions and 98 deletions

View file

@ -8,6 +8,7 @@ module TypeChecker.TypeCheckerHm where
import Auxiliary (int, litType, maybeToRightM, unzip4) import Auxiliary (int, litType, maybeToRightM, unzip4)
import Auxiliary qualified as Aux import Auxiliary qualified as Aux
import Control.Arrow ((&&&))
import Control.Monad.Except import Control.Monad.Except
import Control.Monad.Identity (Identity, runIdentity) import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader import Control.Monad.Reader
@ -18,7 +19,7 @@ import Data.List (foldl', intercalate)
import Data.List.Extra (unsnoc) import Data.List.Extra (unsnoc)
import Data.Map (Map) import Data.Map (Map)
import Data.Map qualified as M import Data.Map qualified as M
import Data.Maybe (fromJust) import Data.Maybe (fromJust, fromMaybe, mapMaybe)
import Data.Set (Set) import Data.Set (Set)
import Data.Set qualified as S import Data.Set qualified as S
import Debug.Trace (trace) import Debug.Trace (trace)
@ -26,8 +27,6 @@ import Grammar.Abs
import Grammar.Print (printTree) import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr qualified as T import TypeChecker.TypeCheckerIr qualified as T
-- TODO: Disallow mutual recursion
-- | Type check a program -- | Type check a program
typecheck :: Program -> Either String (T.Program' Type) typecheck :: Program -> Either String (T.Program' Type)
typecheck = onLeft msg . run . checkPrg typecheck = onLeft msg . run . checkPrg
@ -37,10 +36,16 @@ typecheck = onLeft msg . run . checkPrg
onLeft _ (Right x) = Right x onLeft _ (Right x) = Right x
checkPrg :: Program -> Infer (T.Program' Type) checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = do checkPrg (Program bs) = T.Program <$> (preRun bs >> checkDef bs >>= mapM substPrg)
preRun bs
(sub, bs) <- checkDef bs substPrg :: T.Def' Type -> Infer (T.Def' Type)
return $ T.Program $ apply sub bs substPrg (T.DBind (T.Bind (name, t) args e)) = do
(bu, sub) <- gets (bindUsages &&& bindSubs)
let uses = fromMaybe [] $ M.lookup name bu
let subs = mapMaybe (`M.lookup` sub) (name : uses)
sub <- foldM composey nullSubst (reverse subs)
return . T.DBind $ T.Bind (name, apply sub t) (apply sub args) (apply sub e)
substPrg d = return d
preRun :: [Def] -> Infer () preRun :: [Def] -> Infer ()
preRun [] = return () preRun [] = return ()
@ -51,7 +56,7 @@ preRun (x : xs) = case x of
duplicateDecl n s $ Aux.do duplicateDecl n s $ Aux.do
"Multiple signatures of function" "Multiple signatures of function"
quote $ printTree n quote $ printTree n
insertSig (coerce n) t insertSig (coerce n) (Instantiated t)
preRun xs preRun xs
DBind (Bind n _ e) -> do DBind (Bind n _ e) -> do
s <- gets (S.toList . declaredBinds) s <- gets (S.toList . declaredBinds)
@ -64,7 +69,7 @@ preRun (x : xs) = case x of
case M.lookup (coerce n) sigs of case M.lookup (coerce n) sigs of
Nothing -> do Nothing -> do
fr <- fresh fr <- fresh
insertSig (coerce n) fr insertSig (coerce n) (Generalized fr)
preRun xs preRun xs
Just _ -> preRun xs Just _ -> preRun xs
DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs
@ -73,33 +78,38 @@ preRun (x : xs) = case x of
duplicateDecl :: (Monad m, MonadError Error m) => LIdent -> [T.Ident] -> String -> m () duplicateDecl :: (Monad m, MonadError Error m) => LIdent -> [T.Ident] -> String -> m ()
duplicateDecl n env msg = when (coerce n `elem` env) (uncatchableErr msg) duplicateDecl n env msg = when (coerce n `elem` env) (uncatchableErr msg)
checkDef :: [Def] -> Infer (Subst, [T.Def' Type]) checkDef :: [Def] -> Infer [T.Def' Type]
checkDef [] = return (nullSubst, []) checkDef [] = return []
checkDef (x : xs) = case x of checkDef (x : xs) = case x of
(DBind b) -> do (DBind b) -> do
(sub0, b') <- checkBind b b' <- checkBind b
(sub1, xs') <- checkDef xs xs' <- checkDef xs
comp <- sub0 `composey` sub1 return $ T.DBind b' : xs'
return (comp, T.DBind b' : xs')
(DData d) -> do (DData d) -> do
(sub, xs') <- checkDef xs xs' <- checkDef xs
return (sub, T.DData (coerceData d) : xs') return $ T.DData (coerceData d) : xs'
(DSig _) -> checkDef xs (DSig _) -> checkDef xs
where where
coerceData (Data t injs) = coerceData (Data t injs) =
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
checkBind :: Bind -> Infer (Subst, T.Bind' Type) checkBind :: Bind -> Infer (T.Bind' Type)
checkBind bind@(Bind name args e) = do checkBind bind@(Bind name args e) = do
setCurrentBind $ coerce name
let lambda = makeLambda e (reverse (coerce args)) let lambda = makeLambda e (reverse (coerce args))
(sub0, (e, lambda_t)) <- inferExp lambda (sub0, (e, lambda_t)) <- inferExp lambda
s <- gets sigs s <- gets sigs
case M.lookup (coerce name) s of case M.lookup (coerce name) s of
Just t -> do Just t -> do
sub1 <- bindErr (unify t lambda_t) bind let t' = case t of
Instantiated a -> skolemize a
Generalized a -> a
sub1 <- bindErr (unify t' lambda_t) bind
comp <- sub1 `composey` sub0 comp <- sub1 `composey` sub0
return (comp, T.Bind (coerce name, apply comp t) [] (e, lambda_t)) insertBindSubst (coerce name) comp
_ -> error "First pass through failed to add function to env" return (T.Bind (coerce name, apply comp t') [] (e, lambda_t))
_ -> do
uncatchableErr $ "Undeclared function: " ++ printTree name
checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m () checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m ()
checkData err@(Data typ injs) = do checkData err@(Data typ injs) = do
@ -203,7 +213,6 @@ algoW = \case
quote $ printTree t' quote $ printTree t'
) )
let comp = sub2 `compose` sub1 `compose` sub0 let comp = sub2 `compose` sub1 `compose` sub0
-- return (comp, apply comp (e', t))
return (comp, (e', t)) return (comp, (e', t))
-- \| ------------------ -- \| ------------------
@ -221,8 +230,11 @@ algoW = \case
return (nullSubst, (T.EVar $ coerce i, x)) return (nullSubst, (T.EVar $ coerce i, x))
Nothing -> do Nothing -> do
sig <- gets sigs sig <- gets sigs
cb <- gets currentBind
case M.lookup (coerce i) sig of case M.lookup (coerce i) sig of
Just t -> return (nullSubst, (T.EVar $ coerce i, t)) Just t -> do
insertBindUsage cb (coerce i)
return (nullSubst, (T.EVar $ coerce i, unlevel t))
Nothing -> Nothing ->
uncatchableErr $ uncatchableErr $
"Unbound variable: " "Unbound variable: "
@ -247,7 +259,6 @@ algoW = \case
(s1, (e', t')) <- exprErr (algoW e) err (s1, (e', t')) <- exprErr (algoW e) err
let varType = apply s1 fr let varType = apply s1 fr
let newArr = TFun varType t' let newArr = TFun varType t'
-- return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr))
return (s1, (T.EAbs (coerce name) (e', t'), newArr)) return (s1, (T.EAbs (coerce name) (e', t'), newArr))
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁ -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
@ -262,7 +273,6 @@ algoW = \case
s3 <- exprErr (unify (apply s2 t0) int) err s3 <- exprErr (unify (apply s2 t0) int) err
s4 <- exprErr (unify (apply s3 t1) int) err s4 <- exprErr (unify (apply s3 t1) int) err
let comp = s4 `compose` s3 `compose` s2 `compose` s1 let comp = s4 `compose` s3 `compose` s2 `compose` s1
-- return (comp, apply comp (T.EAdd (e0', t0) (e1', t1), int))
return (comp, (T.EAdd (e0', t0) (e1', t1), int)) return (comp, (T.EAdd (e0', t0) (e1', t1), int))
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1 -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1
@ -274,13 +284,10 @@ algoW = \case
fr <- fresh fr <- fresh
(s0, (e0', t0)) <- algoW e0 (s0, (e0', t0)) <- algoW e0
applySt s0 $ do applySt s0 $ do
modify (\st -> st{sigs = apply s0 st.sigs})
(s1, (e1', t1)) <- algoW e1 (s1, (e1', t1)) <- algoW e1
s2 <- exprErr (unify (apply s1 t0) (TFun t1 fr)) err s2 <- exprErr (unify (apply s1 t0) (TFun t1 fr)) err
let t = apply s2 fr let t = apply s2 fr
comp <- foldM composey nullSubst [s2, s1, s0] comp <- foldM composey nullSubst [s2, s1, s0]
-- let comp = s2 `compose` s1 `compose` s0
-- return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t))
return (comp, (T.EApp (e0', t0) (e1', t1), t)) return (comp, (T.EApp (e0', t0) (e1', t1), t))
-- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁ -- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁
@ -289,16 +296,17 @@ algoW = \case
-- The bar over S₀ and Γ means "generalize" -- The bar over S₀ and Γ means "generalize"
err@(ELet b@(Bind name args e) e1) -> do (ELet (Bind name args e) e1) -> do
(s1, (_, t0)) <- algoW (makeLambda e (coerce args)) (s1, (e, t0)) <- algoW (makeLambda e (coerce args))
(_, bind') <- exprErr (checkBind b) err
env <- asks vars env <- asks vars
let t' = generalize (apply s1 env) t0 let t' = generalize (apply s1 env) t0
withBinding (coerce name) t' $ do withBinding (coerce name) t' $ do
(s2, (e1', t2)) <- algoW e1 (s2, (e1', t2)) <- algoW e1
let comp = s2 `compose` s1 let comp = s2 `compose` s1
-- return (comp, apply comp (T.ELet bind' (e1', t2), t2)) return
return (comp, (T.ELet bind' (e1', t2), t2)) ( comp
, (T.ELet (T.Bind (coerce name, t0) [] (e, t0)) (e1', t2), t2)
)
ECase caseExpr injs -> do ECase caseExpr injs -> do
(sub, (e', t)) <- algoW caseExpr (sub, (e', t)) <- algoW caseExpr
(subst, injs, ret_t) <- checkCase t injs (subst, injs, ret_t) <- checkCase t injs
@ -413,8 +421,10 @@ unify t0 t1 =
s1 <- unify a c s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d) s2 <- unify (apply s1 b) (apply s1 d)
return $ s2 `compose` s1 return $ s2 `compose` s1
(TVar (MkTVar a), t@(TData _ _)) -> return $ coerce $ M.singleton (coerce a) t (TVar (MkTVar a), t@(TData _ _)) ->
(t@(TData _ _), TVar (MkTVar b)) -> return $ coerce $ M.singleton (coerce b) t return $ coerce $ M.singleton (coerce a) t
(t@(TData _ _), TVar (MkTVar b)) ->
return $ coerce $ M.singleton (coerce b) t
(TVar (MkTVar a), t) -> occurs (coerce a) t (TVar (MkTVar a), t) -> occurs (coerce a) t
(t, TVar (MkTVar b)) -> occurs (coerce b) t (t, TVar (MkTVar b)) -> occurs (coerce b) t
(TAll _ t, b) -> unify t b (TAll _ t, b) -> unify t b
@ -603,7 +613,8 @@ instance SubstType (T.Exp' Type) where
instance SubstType (T.Def' Type) where instance SubstType (T.Def' Type) where
apply s = \case apply s = \case
T.DBind (T.Bind name args e) -> T.DBind $ T.Bind (apply s name) (apply s args) (apply s e) T.DBind (T.Bind name args e) ->
T.DBind $ T.Bind (apply s name) (apply s args) (apply s e)
d -> d d -> d
instance SubstType (T.Branch' Type) where instance SubstType (T.Branch' Type) where
@ -636,8 +647,14 @@ compose m1 m2 = Subst $ M.map (apply $ coerce m1) (coerce m2) `M.union` coerce m
-- Order matters. -- Order matters.
{- {-
sub0 = Subst $ (M.singleton "a" (arr d e)) `M.union` (M.singleton "b" (arr d f)) `M.union` (M.singleton "c" (arr f e)) sub0 = Subst $ (M.singleton "a" (arr d e))
sub1 = Subst $ (M.singleton "a" (arr g bool)) `M.union` (M.singleton "b" (arr g bool)) `M.union` (M.singleton "c" (arr bool bool)) `M.union` (M.singleton "h" bool) `M.union` (M.singleton "i" bool) `M.union` (M.singleton "b" (arr d f))
`M.union` (M.singleton "c" (arr f e))
sub1 = Subst $ (M.singleton "a" (arr g bool))
`M.union` (M.singleton "b" (arr g bool))
`M.union` (M.singleton "c" (arr bool bool))
`M.union` (M.singleton "h" bool)
`M.union` (M.singleton "i" bool)
sub0 `composey` sub1 != sub1 `composey` sub0 sub0 `composey` sub1 != sub1 `composey` sub0
-} -}
composey :: Subst -> Subst -> Infer Subst composey :: Subst -> Subst -> Infer Subst
@ -690,12 +707,21 @@ withPattern p ma = case p of
T.PEnum _ -> ma T.PEnum _ -> ma
-- | Insert a function signature into the environment -- | Insert a function signature into the environment
insertSig :: T.Ident -> Type -> Infer () insertSig :: T.Ident -> Level Type -> Infer ()
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)}) insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
insertBind :: T.Ident -> Infer () insertBind :: T.Ident -> Infer ()
insertBind i = modify (\st -> st{declaredBinds = S.insert i st.declaredBinds}) insertBind i = modify (\st -> st{declaredBinds = S.insert i st.declaredBinds})
insertBindSubst :: T.Ident -> Subst -> Infer ()
insertBindSubst name sub = modify (\st -> st{bindSubs = M.insert name sub st.bindSubs})
setCurrentBind :: T.Ident -> Infer ()
setCurrentBind n = modify (\st -> st{currentBind = n, bindUsages = M.insertWith (++) n [] st.bindUsages})
insertBindUsage :: T.Ident -> T.Ident -> Infer ()
insertBindUsage cur use = modify (\st -> st{bindUsages = M.insertWith (++) cur [use] st.bindUsages})
-- | Insert a constructor into the start with its type -- | Insert a constructor into the start with its type
insertInj :: (Monad m, MonadState Env m) => T.Ident -> Type -> m () insertInj :: (Monad m, MonadState Env m) => T.Ident -> Type -> m ()
insertInj i t = insertInj i t =
@ -786,7 +812,7 @@ dataErr ma d =
) )
initCtx = Ctx mempty initCtx = Ctx mempty
initEnv = Env 0 'a' mempty mempty mempty mempty initEnv = Env 0 'a' mempty mempty mempty mempty "" mempty mempty
run :: Infer a -> Either Error a run :: Infer a -> Either Error a
run = run' initEnv initCtx run = run' initEnv initCtx
@ -805,20 +831,28 @@ newtype Ctx = Ctx {vars :: Map T.Ident Type}
data Env = Env data Env = Env
{ count :: Int { count :: Int
, nextChar :: Char , nextChar :: Char
, sigs :: Map T.Ident Type , sigs :: Map T.Ident (Level Type)
, takenTypeVars :: Set T.Ident , takenTypeVars :: Set T.Ident
, injections :: Map T.Ident Type , injections :: Map T.Ident Type
, declaredBinds :: Set T.Ident , declaredBinds :: Set T.Ident
, currentBind :: T.Ident
, bindSubs :: Map T.Ident Subst
, bindUsages :: Map T.Ident [T.Ident]
} }
deriving (Show) deriving (Show)
data Level a = Instantiated {unlevel :: a} | Generalized {unlevel :: a}
deriving (Show)
data Error = Error {msg :: String, catchable :: Bool} data Error = Error {msg :: String, catchable :: Bool}
deriving (Show) deriving (Show)
newtype Subst = Subst (Map T.Ident Type) newtype Subst = Subst (Map T.Ident Type)
instance Show Subst where instance Show Subst where
show (Subst s) = "[ " ++ let xs = (map (\(a, b) -> printTree a ++ " = " ++ printTree b) $ M.toList s) in intercalate " | " xs ++ " ]" show (Subst s) = "[ " ++ intercalate " | " xs ++ " ]"
where
xs = map (\(a, b) -> printTree a ++ " = " ++ printTree b) $ M.toList s
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a} newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a}
deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env) deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env)
@ -834,16 +868,3 @@ quote s = "'" ++ s ++ "'"
ctrace :: (Monad m, Show a) => String -> a -> m () ctrace :: (Monad m, Show a) => String -> a -> m ()
ctrace str a = trace (str ++ ": " ++ show a) pure () ctrace str a = trace (str ++ ": " ++ show a) pure ()
{-
Save each subst mapped to their respective function
Apply composition of all used functions to the function
a = id 0 ;
b = id 'a' ;
id x = x ;
apply_on_a = id_sub `compose` a_sub
apply_on_b = id_sub `compose` b_sub
apply_on_id = id_sub
-}

View file

@ -188,6 +188,11 @@ instance Print t => Print (Inj' t) where
prt i = \case prt i = \case
Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_]) Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_])
instance Print t => Print [Inj' t] where
prt _ [] = concatD []
prt i [x] = prt i x
prt i (x : xs) = prPrec i 0 $ concatD [prt i x, doc $ showString "\n ", prt i xs]
instance Print t => Print (Pattern' t) where instance Print t => Print (Pattern' t) where
prt i = \case prt i = \case
PVar name -> prPrec i 1 (concatD [prt 0 name]) PVar name -> prPrec i 1 (concatD [prt 0 name])

View file

@ -26,46 +26,3 @@ bind : Maybe () -> (Int -> Maybe ()) -> Maybe () ;
bind x f = case x of { bind x f = case x of {
Just x => f x ; Just x => f x ;
Nothing => Nothing ; Nothing => Nothing ;
};
-- represents minus one :)
minusOne : Int ;
minusOne = 9223372036854775807 + 9223372036854775807 + 1;
---- LIST STUFF ----
-- a simple list data type containing ints
data List () where {
Cons : Int -> List () -> List ()
Nil : List ()
};
-- take the length of a list
length : List () -> Int ;
length x = case x of {
Cons _ xs => 1 + length xs ;
Nil => 0 ;
};
-- sum a list
sum : List () -> Int ;
sum x = case x of {
Cons a xs => a + sum xs ;
Nil => 0 ;
};
-- sum + length of a list
sumlength: List () -> Int ;
sumlength x = sum x + length x ;
-- take the head of a list
head : List () -> Int ;
head x = case x of {
Cons h _ => h ;
};
-- repeat an element n times
repeat : Int -> Int -> List () ;
repeat x n = case n of {
0 => Nil ;
n => Cons x (repeat x (n + minusOne)) ;
};