diff --git a/src/LambdaLifter.hs b/src/LambdaLifter.hs index 67af030..dcd715b 100644 --- a/src/LambdaLifter.hs +++ b/src/LambdaLifter.hs @@ -86,8 +86,8 @@ freeVarsBranch localVars (Branch (patt, t) exp) = (frees, AnnBranch (patt, t) ex freeVarsOfPattern = Set.fromList . go [] where go acc = \case - PVar (n,_) -> snoc n acc - PInj _ ps -> foldl go acc ps + PVar n -> snoc n acc + PInj _ ps -> foldl go acc $ map fst ps diff --git a/src/Monomorphizer/Monomorphizer.hs b/src/Monomorphizer/Monomorphizer.hs index 86a05b6..c50a7cc 100644 --- a/src/Monomorphizer/Monomorphizer.hs +++ b/src/Monomorphizer/Monomorphizer.hs @@ -32,9 +32,8 @@ import TypeChecker.TypeCheckerIr (Ident (Ident)) import Control.Monad.Reader (MonadReader (ask, local), Reader, asks, runReader, when) -import Control.Monad.State (MonadState, - StateT (runStateT), gets, - modify) +import Control.Monad.State (MonadState, StateT (runStateT), + gets, modify) import Data.Coerce (coerce) import qualified Data.Map as Map import Data.Maybe (fromJust) @@ -50,7 +49,7 @@ newtype EnvM a = EnvM (StateT Output (Reader Env) a) type Output = Map.Map Ident Outputted --- | Data structure describing outputted top-level information, that is +-- | Data structure describing outputted top-level information, that is -- Binds, Polymorphic Data types (monomorphized in a later step) and -- Marked bind, which means that it is in the process of monomorphization -- and should not be monomorphized again. @@ -220,18 +219,18 @@ morphBranch (T.Branch (p, pt) (e, et)) = do pt' <- getMonoFromPoly pt et' <- getMonoFromPoly et env <- ask - (p', newLocals) <- morphPattern pt' (locals env) p + (p', newLocals) <- morphPattern pt' (locals env) (p, pt) local (const env { locals = newLocals }) $ do e' <- morphExp et' e return $ M.Branch (p', pt') (e', et') -- | Morphs pattern (pattern => expression), gives the newly bound local variables. -morphPattern :: M.Type -> Set.Set Ident -> T.Pattern -> EnvM (M.Pattern, Set.Set Ident) -morphPattern expectedType ls = \case - T.PVar (ident, t) -> do t' <- getMonoFromPoly t - return (M.PVar (ident, t'), Set.insert ident ls) - T.PLit (lit, t) -> do t' <- getMonoFromPoly t - return (M.PLit (convertLit lit, t'), ls) +morphPattern :: M.Type -> Set.Set Ident -> (T.Pattern, T.Type) -> EnvM (M.Pattern, Set.Set Ident) +morphPattern expectedType ls (p, t) = case p of + T.PVar ident -> do t' <- getMonoFromPoly t + return (M.PVar (ident, t'), Set.insert ident ls) + T.PLit lit -> do t' <- getMonoFromPoly t + return (M.PLit (convertLit lit, t'), ls) T.PCatch -> return (M.PCatch, ls) -- Constructor ident T.PEnum ident -> do morphCons expectedType ident diff --git a/src/TypeChecker/RemoveForall.hs b/src/TypeChecker/RemoveForall.hs index d4cdd81..886ecb0 100644 --- a/src/TypeChecker/RemoveForall.hs +++ b/src/TypeChecker/RemoveForall.hs @@ -30,13 +30,14 @@ removeForall (Program defs) = Program $ map (DData . rfData) ds ELit lit -> ELit lit EVar name -> EVar name EInj name -> EInj name - rfBranch (Branch (p, t) e) = Branch (rfPattern p, rfType t) (rfExpT e) + rfBranch (Branch p e) = Branch (rfPatternT p) (rfExpT e) + rfPatternT (p, t) = (rfPattern p, rfType t) rfPattern = \case - PVar id -> PVar (rfId id) - PLit (lit, t) -> PLit (lit, rfType t) - PCatch -> PCatch - PEnum name -> PEnum name - PInj name ps -> PInj name (map rfPattern ps) + PVar name -> PVar name + PLit lit -> PLit lit + PCatch -> PCatch + PEnum name -> PEnum name + PInj name ps -> PInj name (map rfPatternT ps) rfType :: R.Type -> Type rfType = \case diff --git a/src/TypeChecker/ReportTEVar.hs b/src/TypeChecker/ReportTEVar.hs index 61ed688..9676b8e 100644 --- a/src/TypeChecker/ReportTEVar.hs +++ b/src/TypeChecker/ReportTEVar.hs @@ -49,13 +49,16 @@ instance ReportTEVar (Exp' G.Type) (Exp' Type) where instance ReportTEVar (Branch' G.Type) (Branch' Type) where reportTEVar (Branch (patt, t_patt) e) = liftA2 Branch (liftA2 (,) (reportTEVar patt) (reportTEVar t_patt)) (reportTEVar e) +instance ReportTEVar (Pattern' G.Type, G.Type) (Pattern' Type, Type) where + reportTEVar (p, t) = liftA2 (,) (reportTEVar p) (reportTEVar t) + instance ReportTEVar (Pattern' G.Type) (Pattern' Type) where reportTEVar = \case - PVar (name, t) -> PVar . (name,) <$> reportTEVar t - PLit (lit, t) -> PLit . (lit,) <$> reportTEVar t - PCatch -> pure PCatch - PEnum name -> pure $ PEnum name - PInj name ps -> PInj name <$> reportTEVar ps + PVar name -> pure $ PVar name + PLit lit -> pure $ PLit lit + PCatch -> pure PCatch + PEnum name -> pure $ PEnum name + PInj name ps -> PInj name <$> reportTEVar ps instance ReportTEVar (Data' G.Type) (Data' Type) where reportTEVar (Data typ injs) = liftA2 Data (reportTEVar typ) (reportTEVar injs) diff --git a/src/TypeChecker/TypeCheckerBidir.hs b/src/TypeChecker/TypeCheckerBidir.hs index 615169b..714b4c9 100644 --- a/src/TypeChecker/TypeCheckerBidir.hs +++ b/src/TypeChecker/TypeCheckerBidir.hs @@ -209,7 +209,7 @@ checkPattern patt t_patt = case patt of -- Γ ⊢ x ↑ A ⊣ Γ,(x:A) PVar x -> do insertEnv $ EnvVar x t_patt - apply (T.PVar (coerce x, t_patt), t_patt) + apply (T.PVar (coerce x), t_patt) -- ------------- -- Γ ⊢ _ ↑ A ⊣ Γ @@ -220,7 +220,7 @@ checkPattern patt t_patt = case patt of -- Γ ⊢ τ ↑ B ⊣ Δ PLit lit -> do subtype (litType lit) t_patt - apply (T.PLit (lit, t_patt), t_patt) + apply (T.PLit lit, t_patt) -- Γ ∋ (K : A) Γ ⊢ A <: B ⊣ Δ -- --------------------------- @@ -249,7 +249,7 @@ checkPattern patt t_patt = case patt of subtype (sub $ getDataId t_inj) t_patt let check p t = checkPattern p =<< apply (sub t) ps' <- zipWithM check ps ts - apply (T.PInj (coerce name) (map fst ps'), t_patt) + apply (T.PInj (coerce name) ps', t_patt) where substituteTVarsOf = \case TAll tvar t -> do @@ -780,10 +780,9 @@ applyBranch (T.Branch (p, t) e) = do applyPattern :: T.Pattern' Type -> Tc (T.Pattern' Type) applyPattern = \case - T.PVar id -> T.PVar <$> apply id - T.PLit (lit, t) -> T.PLit . (lit, ) <$> apply t - T.PInj name ps -> T.PInj name <$> apply ps - p -> pure p + T.PVar id -> T.PVar <$> apply id + T.PInj name ps -> T.PInj name <$> apply ps + p -> pure p applyPair :: (Apply a, Apply b) => (a, b) -> Tc (a, b) applyPair (x, y) = liftA2 (,) (apply x) (apply y) diff --git a/src/TypeChecker/TypeCheckerHm.hs b/src/TypeChecker/TypeCheckerHm.hs index 1560f0d..5ef3f47 100644 --- a/src/TypeChecker/TypeCheckerHm.hs +++ b/src/TypeChecker/TypeCheckerHm.hs @@ -1,31 +1,31 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedRecordDot #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE QualifiedDo #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QualifiedDo #-} -- | A module for type checking and inference using algorithm W, Hindley-Milner module TypeChecker.TypeCheckerHm where -import Auxiliary (int, litType, maybeToRightM, unzip4) -import Auxiliary qualified as Aux -import Control.Monad.Except -import Control.Monad.Identity (Identity, runIdentity) -import Control.Monad.Reader -import Control.Monad.State -import Control.Monad.Writer -import Data.Coerce (coerce) -import Data.Function (on) -import Data.List (foldl', nub, sortOn) -import Data.List.Extra (unsnoc) -import Data.Map (Map) -import Data.Map qualified as M -import Data.Maybe (fromJust) -import Data.Set (Set) -import Data.Set qualified as S -import Debug.Trace (trace) -import Grammar.Abs -import Grammar.Print (printTree) -import TypeChecker.TypeCheckerIr qualified as T +import Auxiliary (int, litType, maybeToRightM, unzip4) +import qualified Auxiliary as Aux +import Control.Monad.Except +import Control.Monad.Identity (Identity, runIdentity) +import Control.Monad.Reader +import Control.Monad.State +import Control.Monad.Writer +import Data.Coerce (coerce) +import Data.Function (on) +import Data.List (foldl', nub, sortOn) +import Data.List.Extra (unsnoc) +import Data.Map (Map) +import qualified Data.Map as M +import Data.Maybe (fromJust) +import Data.Set (Set) +import qualified Data.Set as S +import Debug.Trace (trace) +import Grammar.Abs +import Grammar.Print (printTree) +import qualified TypeChecker.TypeCheckerIr as T {- TODO @@ -40,7 +40,7 @@ typecheck :: Program -> Either String (T.Program' Type, [Warning]) typecheck = onLeft msg . run . checkPrg where onLeft :: (Error -> String) -> Either Error a -> Either String a - onLeft f (Left x) = Left $ f x + onLeft f (Left x) = Left $ f x onLeft _ (Right x) = Right x checkPrg :: Program -> Infer (T.Program' Type) @@ -67,13 +67,13 @@ prettify s (T.Program defs) = T.Program $ map (go s) defs replace :: Map T.Ident T.Ident -> Type -> Type replace m def@(TVar (MkTVar (LIdent a))) = case M.lookup (coerce a) m of - Just t -> TVar . MkTVar . LIdent $ coerce t + Just t -> TVar . MkTVar . LIdent $ coerce t Nothing -> def replace m (TFun t1 t2) = (TFun `on` replace m) t1 t2 replace m (TData name ts) = TData name (map (replace m) ts) replace m def@(TAll (MkTVar forall_) t) = case M.lookup (coerce forall_) m of Just found -> TAll (MkTVar $ coerce found) (replace m t) - Nothing -> def + Nothing -> def replace _ t = t bindCount :: [Def] -> Infer [(Int, Def)] @@ -127,7 +127,7 @@ preRun (x : xs) = case x of s <- gets sigs case M.lookup (coerce n) s of Nothing -> insertSig (coerce n) Nothing >> preRun xs - Just _ -> preRun xs + Just _ -> preRun xs DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs where -- Check if function body / signature has been declared already @@ -149,11 +149,11 @@ checkDef (x : xs) = case x of T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs freeOrdered :: Type -> [T.Ident] -freeOrdered (TVar (MkTVar a)) = return (coerce a) +freeOrdered (TVar (MkTVar a)) = return (coerce a) freeOrdered (TAll (MkTVar bound) t) = return (coerce bound) ++ freeOrdered t -freeOrdered (TFun a b) = freeOrdered a ++ freeOrdered b -freeOrdered (TData _ a) = concatMap freeOrdered a -freeOrdered _ = mempty +freeOrdered (TFun a b) = freeOrdered a ++ freeOrdered b +freeOrdered (TData _ a) = concatMap freeOrdered a +freeOrdered _ = mempty checkBind :: Bind -> Infer (T.Bind' Type) checkBind (Bind name args e) = do @@ -227,11 +227,11 @@ checkInj (Inj c inj_typ) name tvars toTVar :: Type -> Either Error TVar toTVar = \case TVar tvar -> pure tvar - _ -> uncatchableErr "Not a type variable" + _ -> uncatchableErr "Not a type variable" returnType :: Type -> Type returnType (TFun _ t2) = returnType t2 -returnType a = a +returnType a = a inferExp :: Exp -> Infer (T.ExpT' Type) inferExp e = do @@ -244,7 +244,7 @@ class CollectTVars a where instance CollectTVars Exp where collectTVars (EAnn e t) = collectTVars t `S.union` collectTVars e - collectTVars _ = S.empty + collectTVars _ = S.empty instance CollectTVars Type where collectTVars (TVar (MkTVar i)) = S.singleton (coerce i) @@ -403,22 +403,22 @@ checkCase expT brnchs = do inferBranch :: Branch -> Infer (Subst, Type, T.Branch' Type, Type) inferBranch err@(Branch pat expr) = do - newPat@(pat, branchT) <- inferPattern pat + pat@(_, branchT) <- inferPattern pat (sub, newExp@(_, exprT)) <- catchError (withPattern pat (algoW expr)) (\x -> throwError Error{msg = x.msg <> " in pattern '" <> printTree err <> "'", catchable = False}) return ( sub , apply sub branchT - , T.Branch (apply sub newPat) (apply sub newExp) + , T.Branch (apply sub pat) (apply sub newExp) , apply sub exprT ) inferPattern :: Pattern -> Infer (T.Pattern' Type, Type) inferPattern = \case - PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt) + PLit lit -> let lt = litType lit in return (T.PLit lit, lt) PCatch -> (T.PCatch,) <$> fresh PVar x -> do fr <- fresh - let pvar = T.PVar (coerce x, fr) + let pvar = T.PVar (coerce x) return (pvar, fr) PEnum p -> do t <- gets (M.lookup (coerce p) . injections) @@ -473,7 +473,7 @@ inferPattern = \case ) sub <- composeAll <$> zipWithM unify vs (map snd patterns) return - ( T.PInj (coerce constr) (apply sub (map fst patterns)) + ( T.PInj (coerce constr) (apply sub patterns) , apply sub ret ) @@ -563,12 +563,12 @@ generalize :: Map T.Ident Type -> Type -> Type generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t) where go :: [T.Ident] -> Type -> Type - go [] t = t + go [] t = t go (x : xs) t = TAll (MkTVar (coerce x)) (go xs t) removeForalls :: Type -> Type - removeForalls (TAll _ t) = removeForalls t + removeForalls (TAll _ t) = removeForalls t removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2) - removeForalls t = t + removeForalls t = t {- | Instantiate a polymorphic type. The free type variables are substituted with fresh ones. @@ -617,27 +617,27 @@ currently this is not the case, the TAll pattern match is incorrectly implemente skipForalls :: Type -> Type skipForalls = \case TAll _ t -> skipForalls t - t -> t + t -> t foralls :: Type -> [T.Ident] foralls (TAll (MkTVar a) t) = coerce a : foralls t -foralls _ = [] +foralls _ = [] mkForall :: Type -> Type mkForall t = case map (TAll . MkTVar . coerce) $ S.toList $ free t of [] -> t (x : xs) -> - let f acc [] = acc + let f acc [] = acc f acc (x : xs) = f (x acc) xs (y : ys) = reverse $ x : xs in f (y t) ys skolemize :: Type -> Type skolemize (TVar (MkTVar a)) = TEVar $ MkTEVar a -skolemize (TAll x t) = TAll x (skolemize t) -skolemize (TFun t1 t2) = (TFun `on` skolemize) t1 t2 -skolemize (TData n ts) = TData n (map skolemize ts) -skolemize t = t +skolemize (TAll x t) = TAll x (skolemize t) +skolemize (TFun t1 t2) = (TFun `on` skolemize) t1 t2 +skolemize (TData n ts) = TData n (map skolemize ts) +skolemize t = t -- | A class for substitutions class SubstType t where @@ -671,10 +671,10 @@ instance SubstType Type where TLit _ -> t TVar (MkTVar a) -> case M.lookup (coerce a) sub of Nothing -> TVar (MkTVar $ coerce a) - Just t -> t + Just t -> t TAll (MkTVar i) t -> case M.lookup (coerce i) sub of Nothing -> TAll (MkTVar i) (apply sub t) - Just _ -> apply sub t + Just _ -> apply sub t TFun a b -> TFun (apply sub a) (apply sub b) TData name a -> TData name (apply sub a) TEVar (MkTEVar _) -> t @@ -718,11 +718,11 @@ instance SubstType (T.Branch' Type) where instance SubstType (T.Pattern' Type) where apply s = \case - T.PVar (iden, t) -> T.PVar (iden, apply s t) - T.PLit (lit, t) -> T.PLit (lit, apply s t) + T.PVar iden -> T.PVar iden + T.PLit lit -> T.PLit lit T.PInj i ps -> T.PInj i $ apply s ps - T.PCatch -> T.PCatch - T.PEnum i -> T.PEnum i + T.PCatch -> T.PCatch + T.PEnum i -> T.PEnum i instance SubstType (T.Pattern' Type, Type) where apply s (p, t) = (apply s p, apply s t) @@ -761,13 +761,13 @@ withBindings xs = local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs}) -- | Run the monadic action with a pattern -withPattern :: (Monad m, MonadReader Ctx m) => T.Pattern' Type -> m a -> m a -withPattern p ma = case p of - T.PVar (x, t) -> withBinding x t ma +withPattern :: (Monad m, MonadReader Ctx m) => (T.Pattern' Type, Type) -> m a -> m a +withPattern (p, t) ma = case p of + T.PVar x -> withBinding x t ma T.PInj _ ps -> foldl' (flip withPattern) ma ps - T.PLit _ -> ma - T.PCatch -> ma - T.PEnum _ -> ma + T.PLit _ -> ma + T.PCatch -> ma + T.PEnum _ -> ma -- | Insert a function signature into the environment insertSig :: T.Ident -> Maybe Type -> Infer () @@ -792,11 +792,11 @@ existInj n = gets (M.lookup n . injections) flattenType :: Type -> [Type] flattenType (TFun a b) = flattenType a <> flattenType b -flattenType a = [a] +flattenType a = [a] typeLength :: Type -> Int typeLength (TFun _ b) = 1 + typeLength b -typeLength _ = 1 +typeLength _ = 1 {- | Catch an error if possible and add the given expression as addition to the error message @@ -879,11 +879,11 @@ newtype Ctx = Ctx {vars :: Map T.Ident Type} deriving (Show) data Env = Env - { count :: Int - , nextChar :: Char - , sigs :: Map T.Ident (Maybe Type) + { count :: Int + , nextChar :: Char + , sigs :: Map T.Ident (Maybe Type) , takenTypeVars :: Set T.Ident - , injections :: Map T.Ident Type + , injections :: Map T.Ident Type , declaredBinds :: Set T.Ident } deriving (Show) diff --git a/src/TypeChecker/TypeCheckerIr.hs b/src/TypeChecker/TypeCheckerIr.hs index e898ebe..21f2227 100644 --- a/src/TypeChecker/TypeCheckerIr.hs +++ b/src/TypeChecker/TypeCheckerIr.hs @@ -153,10 +153,13 @@ instance Print t => Print [Inj' t] where prt i [x] = prt i x prt i (x : xs) = prPrec i 0 $ concatD [prt i x, doc $ showString "\n ", prt i xs] +instance Print t => Print (Pattern' t, t) where + prt i (p, t) = prPrec i 1 (concatD [prt i p, prt i t]) + instance Print t => Print (Pattern' t) where prt i = \case PVar name -> prPrec i 1 (concatD [prt 0 name]) - PLit (lit, _) -> prPrec i 1 (concatD [prt 0 lit]) + PLit lit -> prPrec i 1 (concatD [prt 0 lit]) PCatch -> prPrec i 1 (concatD [doc (showString "_")]) PEnum name -> prPrec i 1 (concatD [prt 0 name]) PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])