diff --git a/src/TypeChecker/TypeChecker.hs b/src/TypeChecker/TypeChecker.hs index ec7b005..af62451 100644 --- a/src/TypeChecker/TypeChecker.hs +++ b/src/TypeChecker/TypeChecker.hs @@ -118,11 +118,10 @@ checkPrg (Program bs) = do d' <- freshenData d fmap (T.DData d' :) (checkDef xs) --- TODO: Unify top level types with the types of the expressions beneath --- PERHAPS DONE checkBind :: Bind -> Infer T.Bind checkBind (Bind n t _ args e) = do - (t', e) <- inferExp $ makeLambda e (reverse args) + let lambda = makeLambda e (reverse args) + (t', e) <- inferExp lambda s <- unify t' t let t'' = apply s t unless @@ -296,10 +295,11 @@ algoW = \case -- TODO: give caseExpr a concrete type before proceeding -- probably by returning substitutions in the functions used in this body ECase caseExpr injs -> do - (sub, _, e') <- algoW caseExpr - trace ("SUB: " ++ show sub) return () - t <- checkCase caseExpr injs - return (sub, t, T.ECase t e' (map (\(Inj i _) -> T.Inj (i, t) e') injs)) + (sub, t, e') <- algoW caseExpr + (subst, t) <- checkCase t injs + let composition = subst `compose` sub + let t' = apply composition t + return (composition, t', T.ECase t' e' (map (\(Inj i _) -> T.Inj (i, t') e') injs)) -- | Unify two types producing a new substitution unify :: Type -> Type -> Infer Subst @@ -328,12 +328,18 @@ unify t0 t1 = do , printTree name' , "(" ++ printTree t' ++ ")" ] - (a, b) -> + (a, b) -> do + ctx <- ask + env <- get throwError . unwords $ [ "Type:" , printTree a , "can't be unified with:" , printTree b + , "\nCtx:" + , show ctx + , "\nEnv:" + , show env ] {- | Check if a type is contained in another type. @@ -464,23 +470,12 @@ insertConstr i t = -------- PATTERN MATCHING --------- -unifyAll :: [Type] -> Infer [Subst] -unifyAll [] = return [] -unifyAll [_] = return [] -unifyAll (x : y : xs) = do - uni <- unify x y - all <- unifyAll (y : xs) - return $ uni : all - -checkCase :: Exp -> [Inj] -> Infer Type -checkCase e injs = do - expT <- fst <$> inferExp e - (injTs, returns) <- mapAndUnzipM checkInj injs - unifyAll (expT : injTs) - subst <- foldl1 compose <$> zipWithM unify returns (tail returns) - let substed = map (apply subst) returns - unless (allSame substed || null substed) (throwError "Different return types of case, or no cases") - return $ head substed +checkCase :: Type -> [Inj] -> Infer (Subst, Type) +checkCase expT injs = do + (injs, returns) <- mapAndUnzipM checkInj injs + (sub, _) <- foldM (\(sub, acc) x -> (\a -> (a `compose` sub, (a `apply` acc))) <$> unify x acc) (nullSubst, expT) injs + t <- foldM (\acc x -> (`apply` acc) <$> unify x acc) (head returns) (tail returns) + return (sub, t) {- | fst = type of init | snd = type of expr @@ -510,3 +505,5 @@ flattenType a = [a] litType :: Literal -> Type litType (LInt _) = TMono "Int" + +ctrace a = trace (show a) a diff --git a/src/TypeChecker/TypeCheckerIr.hs b/src/TypeChecker/TypeCheckerIr.hs index 475201e..016dd8a 100644 --- a/src/TypeChecker/TypeCheckerIr.hs +++ b/src/TypeChecker/TypeCheckerIr.hs @@ -23,12 +23,13 @@ data Poly = Forall [Ident] Type deriving (Show) newtype Ctx = Ctx {vars :: Map Ident Poly} + deriving Show data Env = Env { count :: Int , sigs :: Map Ident Type , constructors :: Map Ident Type - } + } deriving Show type Error = String type Subst = Map Ident Type diff --git a/test_program b/test_program index a8accca..2d6fed1 100644 --- a/test_program +++ b/test_program @@ -3,7 +3,13 @@ data Bool () where { False : Bool () }; -main : Bool () -> _Int ; +data Maybe ('a) where { + Nothing : Maybe ('a) + Just : 'a -> Maybe ('a) +}; + +main : Bool () -> Maybe (Bool ()) ; main x = case x of { - 1 => 0 + True => Nothing; + False => Just 0 }