diff --git a/src/TypeChecker/TypeChecker.hs b/src/TypeChecker/TypeChecker.hs index d909e49..7e59793 100644 --- a/src/TypeChecker/TypeChecker.hs +++ b/src/TypeChecker/TypeChecker.hs @@ -7,17 +7,16 @@ module TypeChecker.TypeChecker where import Control.Monad.Except import Control.Monad.Reader import Control.Monad.State +import Data.Bifunctor (second) import Data.Foldable (traverse_) -import Data.Function (on) import Data.Functor.Identity (runIdentity) import Data.List (foldl') -import Data.List.Extra (allSame) +import Data.List.Extra (unsnoc) import Data.Map (Map) import Data.Map qualified as M import Data.Maybe (fromMaybe) import Data.Set (Set) import Data.Set qualified as S -import Data.Tree (flatten) import Debug.Trace (trace) import Grammar.Abs import Grammar.Print (printTree) @@ -296,10 +295,13 @@ algoW = \case -- probably by returning substitutions in the functions used in this body ECase caseExpr injs -> do (sub, t, e') <- algoW caseExpr - (subst, t) <- checkCase t injs + (subst, inj_t, ret_t) <- checkCase t injs let composition = subst `compose` sub let t' = apply composition t - return (composition, t', T.ECase t' e' (map (\(Inj i _) -> T.Inj (i, t') e') injs)) + trace ("COMPOSITION: " ++ show composition) return () + trace ("T: " ++ show t) return () + trace ("T': " ++ show t') return () + return (composition, t', T.ECase t' e' (map (\(Inj i _) -> T.Inj (i, inj_t) e') injs)) -- | Unify two types producing a new substitution unify :: Type -> Type -> Infer Subst @@ -337,7 +339,7 @@ unify t0 t1 = do , "can't be unified with:" , printTree b , "\nCtx:" - , show ctx + , show ctx , "\nEnv:" , show env ] @@ -462,6 +464,10 @@ fresh = do withBinding :: (Monad m, MonadReader Ctx m) => Ident -> Poly -> m a -> m a withBinding i p = local (\st -> st{vars = M.insert i p (vars st)}) +-- | Run the monadic action with several additional bindings +withBindings :: (Monad m, MonadReader Ctx m) => [(Ident, Poly)] -> m a -> m a +withBindings xs = local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs}) + -- | Insert a function signature into the environment insertSig :: Ident -> Type -> Infer () insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)}) @@ -473,34 +479,39 @@ insertConstr i t = -------- PATTERN MATCHING --------- -checkCase :: Type -> [Inj] -> Infer (Subst, Type) +checkCase :: Type -> [Inj] -> Infer (Subst, Type, Type) checkCase expT injs = do (injs, returns) <- mapAndUnzipM checkInj injs - (sub, _) <- foldM (\(sub, acc) x -> (\a -> (a `compose` sub, (a `apply` acc))) <$> unify x acc) (nullSubst, expT) injs - t <- foldM (\acc x -> (`apply` acc) <$> unify x acc) (head returns) (tail returns) - return (sub, t) + (sub, injs_type) <- foldM (\(sub, acc) x -> (\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc) (nullSubst, expT) injs + (_, returns_type) <- foldM (\(sub, acc) x -> (\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc) (nullSubst, head returns) (tail returns) + return (sub, injs_type, returns_type) {- | fst = type of init -| snd = type of expr + | snd = type of expr -} checkInj :: Inj -> Infer (Type, Type) checkInj (Inj it expr) = do - initT <- inferInit it - (exprT, _) <- inferExp expr + (initT, vars) <- inferInit it + let converted = map (second (Forall [])) vars + (exprT, _) <- withBindings converted (inferExp expr) return (initT, exprT) -inferInit :: Init -> Infer Type +inferInit :: Init -> Infer (Type, [T.Id]) inferInit = \case - InitLit lit -> return $ litType lit + InitLit lit -> return (litType lit, mempty) InitConstr fn vars -> do gets (M.lookup fn . constructors) >>= \case Nothing -> throwError $ "Constructor: " ++ printTree fn ++ " does not exist" Just a -> do - let ft = init $ flattenType a - case compare (length vars) (length ft) of - EQ -> return . last $ flattenType a - _ -> throwError "Partial pattern match not allowed" - InitCatch -> fresh + case unsnoc $ flattenType a of + Nothing -> throwError "Partial pattern match not allowed" + Just (vs, ret) -> + case length vars `compare` length vs of + EQ -> do + trace ("IDS AND TYPES: " ++ show (zip vars vs)) return () + return (ret, zip vars vs) + _ -> throwError "Partial pattern match not allowed" + InitCatch -> (,mempty) <$> fresh flattenType :: Type -> [Type] flattenType (TArr a b) = flattenType a ++ flattenType b diff --git a/test_program b/test_program index 2d6fed1..e420e37 100644 --- a/test_program +++ b/test_program @@ -1,15 +1,24 @@ -data Bool () where { - True : Bool () - False : Bool () -}; +-- data Bool () where { +-- True : Bool () +-- False : Bool () +-- }; data Maybe ('a) where { Nothing : Maybe ('a) Just : 'a -> Maybe ('a) }; -main : Bool () -> Maybe (Bool ()) ; -main x = case x of { - True => Nothing; - False => Just 0 -} +-- main : Bool () -> Maybe (Bool ()) ; +-- main x = +-- case x of { +-- True => Nothing; +-- False => Just True +-- }; + +fun : Maybe (_Int) -> _Int ; +fun a = + case a of { + Just b => b; + Nothing => 0 + }; +