diff --git a/src/TypeChecker/TypeChecker.hs b/src/TypeChecker/TypeChecker.hs index 0c3df12..779867b 100644 --- a/src/TypeChecker/TypeChecker.hs +++ b/src/TypeChecker/TypeChecker.hs @@ -9,13 +9,13 @@ import Control.Monad.Reader import Control.Monad.State import Data.Foldable (traverse_) import Data.Functor.Identity (runIdentity) -import Debug.Trace (trace) import Data.List (foldl') import Data.Map (Map) import Data.Map qualified as M +import Data.Maybe (fromMaybe) import Data.Set (Set) import Data.Set qualified as S -import Data.Maybe (fromMaybe) +import Debug.Trace (trace) import Grammar.Abs import Grammar.Print (printTree) import TypeChecker.TypeCheckerIr ( @@ -53,16 +53,16 @@ freshenData (Data (Constr name ts) constrs) = do frs <- traverse (const fresh) xs let m = M.fromList $ zip xs frs return $ Data (Constr name (map (freshenType m) ts)) (map (\(Constructor ident t) -> Constructor ident (freshenType m t)) constrs) - + {- | Freshen all polymorphic variables, regardless of name | freshenType "d" (a -> b -> c) becomes (d -> d -> d) -} freshenType :: Map Ident Type -> Type -> Type freshenType m t = case t of - TPol poly -> fromMaybe (error "bug in \'free\'") (M.lookup poly m) - TMono mono -> TMono mono - TArr t1 t2 -> TArr (freshenType m t1) (freshenType m t2) - TConstr (Constr ident ts) -> TConstr (Constr ident (map (freshenType m) ts)) + TPol poly -> fromMaybe (error "bug in \'free\'") (M.lookup poly m) + TMono mono -> TMono mono + TArr t1 t2 -> TArr (freshenType m t1) (freshenType m t2) + TConstr (Constr ident ts) -> TConstr (Constr ident (map (freshenType m) ts)) checkData :: Data -> Infer () checkData d = do @@ -115,10 +115,12 @@ checkPrg (Program bs) = do d' <- freshenData d fmap (T.DData d' :) (checkDef xs) +-- TODO: Unify top level types with the types of the expressions beneath +-- PERHAPS DONE checkBind :: Bind -> Infer T.Bind checkBind (Bind n t _ args e) = do - (t', e') <- inferExp $ makeLambda e (reverse args) - s <- unify t t' + (t', e) <- inferExp $ makeLambda e (reverse args) + s <- unify t' t let t'' = apply s t unless (t `typeEq` t'') @@ -130,7 +132,7 @@ checkBind (Bind n t _ args e) = do , printTree t'' ] ) - return $ T.Bind (n, t) e' + return $ T.Bind (n, t) (apply s e) where makeLambda :: Exp -> [Ident] -> Exp makeLambda = foldl (flip EAbs) @@ -287,7 +289,6 @@ algoW = \case (s2, t2, e1') <- algoW e1 let composition = s2 `compose` s1 return (composition, t2, apply composition $ T.ELet (T.Bind (name, t2) e0') e1') - ECase caseExpr injs -> do (_, t0, e0') <- algoW caseExpr (injs', ts) <- mapAndUnzipM (checkInj t0) injs @@ -340,7 +341,7 @@ I.E. { a = a -> b } is an unsolvable constraint since there is no substitution where these are equal -} occurs :: Ident -> Type -> Infer Subst -occurs _ (TPol _) = return nullSubst +occurs i t@(TPol a) = return (M.singleton i t) occurs i t = if S.member i (free t) then @@ -414,14 +415,14 @@ instance FreeVars T.Exp where free = error "free not implemented for T.Exp" apply :: Subst -> T.Exp -> T.Exp apply s = \case - T.EId (ident, t) -> T.EId (ident, apply s t) - T.ELit t lit -> T.ELit (apply s t) lit - T.ELet (T.Bind (ident, t) e1) e2 -> T.ELet (T.Bind (ident, apply s t) (apply s e1)) (apply s e2) - T.EApp t e1 e2 -> T.EApp (apply s t) (apply s e1) (apply s e2) - T.EAdd t e1 e2 -> T.EAdd (apply s t) (apply s e1) (apply s e2) - T.EAbs t1 (ident, t2) e -> T.EAbs (apply s t1) (ident, apply s t2) (apply s e) - T.ECase t e injs -> T.ECase (apply s t) (apply s e) (apply s injs) - + T.EId (ident, t) -> T.EId (ident, apply s t) + T.ELit t lit -> T.ELit (apply s t) lit + T.ELet (T.Bind (ident, t) e1) e2 -> T.ELet (T.Bind (ident, apply s t) (apply s e1)) (apply s e2) + T.EApp t e1 e2 -> T.EApp (apply s t) (apply s e1) (apply s e2) + T.EAdd t e1 e2 -> T.EAdd (apply s t) (apply s e1) (apply s e2) + T.EAbs t1 (ident, t2) e -> T.EAbs (apply s t1) (ident, apply s t2) (apply s e) + T.ECase t e injs -> T.ECase (apply s t) (apply s e) (apply s injs) + instance FreeVars T.Inj where free :: T.Inj -> Set Ident free = undefined @@ -469,14 +470,12 @@ checkInj caseType (Inj it expr) = do (args, t') <- initType caseType it subst <- unify caseType t' applySt subst $ do - (_, t, e') <- local (\st -> st { vars = args `M.union` vars st }) (algoW expr) + (_, t, e') <- local (\st -> st{vars = args `M.union` vars st}) (algoW expr) return (T.Inj (it, t') e', t) initType :: Type -> Init -> Infer (Map Ident Poly, Type) initType expected = \case - InitLit lit -> error "Pattern match on literals not implemented yet" - InitConstr c args -> do st <- gets constructors case M.lookup c st of diff --git a/test_program b/test_program index 0d74a4e..2470637 100644 --- a/test_program +++ b/test_program @@ -1,10 +1,2 @@ -data Maybe ('a) where { - Nothing : Maybe ('a) - Just : 'a -> Maybe ('a) -}; - id : 'a -> 'a ; -id x = x ; - -main : Maybe ('a -> 'a) ; -main = Just id ; +id = \x. x ;