diff --git a/src/Renamer/Renamer.hs b/src/Renamer/Renamer.hs index 0a67e22..48ec228 100644 --- a/src/Renamer/Renamer.hs +++ b/src/Renamer/Renamer.hs @@ -1,22 +1,31 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedRecordDot #-} module Renamer.Renamer (rename) where -import Auxiliary (mapAccumM) -import Control.Applicative (Applicative (liftA2)) -import Control.Monad.Except (ExceptT, MonadError (throwError), - runExceptT) -import Control.Monad.State (MonadState, State, evalState, gets, - mapAndUnzipM, modify) -import Data.Function (on) -import Data.Map (Map) -import qualified Data.Map as Map -import Data.Maybe (fromMaybe) -import Data.Tuple.Extra (dupe, second) -import Grammar.Abs -import Grammar.ErrM (Err) - +import Auxiliary (mapAccumM) +import Control.Applicative (Applicative (liftA2)) +import Control.Monad.Except ( + ExceptT, + MonadError (throwError), + runExceptT, + ) +import Control.Monad.State ( + MonadState, + State, + evalState, + gets, + mapAndUnzipM, + modify, + ) +import Data.Function (on) +import Data.Map (Map) +import Data.Map qualified as Map +import Data.Maybe (fromMaybe) +import Data.Tuple.Extra (dupe, second) +import Grammar.Abs +import Grammar.ErrM (Err) +import Grammar.Print (printTree) -- | Rename all variables and local binds rename :: Program -> Err Program @@ -25,14 +34,14 @@ rename (Program defs) = Program <$> renameDefs defs initCxt :: Cxt initCxt = Cxt 0 0 -data Cxt = Cxt { var_counter :: Int - , tvar_counter :: Int - } - +data Cxt = Cxt + { var_counter :: Int + , tvar_counter :: Int + } -- | Rename monad. State holds the number of renamed names. -newtype Rn a = Rn { runRn :: ExceptT String (State Cxt) a } - deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String) +newtype Rn a = Rn {runRn :: ExceptT String (State Cxt) a} + deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String) -- | Maps old to new name type Names = Map String String @@ -40,67 +49,60 @@ type Names = Map String String renameDefs :: [Def] -> Err [Def] renameDefs defs = evalState (runExceptT (runRn $ mapM renameDef defs)) initCxt where - initNames = Map.fromList [ dupe s | DBind (Bind (LIdent s) _ _) <- defs] + initNames = Map.fromList [dupe s | DBind (Bind (LIdent s) _ _) <- defs] renameDef :: Def -> Rn Def renameDef = \case DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ DBind (Bind name vars rhs) -> do (new_names, vars') <- newNamesL initNames vars - rhs' <- snd <$> renameExp new_names rhs + rhs' <- snd <$> renameExp new_names rhs pure . DBind $ Bind name vars' rhs' DData (Data typ injs) -> do tvars <- collectTVars [] typ tvars' <- mapM nextNameTVar tvars let tvars_lt = zip tvars tvars' - typ' = substituteTVar tvars_lt typ + typ' = substituteTVar tvars_lt typ injs' = map (renameInj tvars_lt) injs pure . DData $ Data typ' injs' where collectTVars tvars = \case - TAll tvar t -> collectTVars (tvar:tvars) t - TData _ _ -> pure tvars - _ -> throwError ("Bad data type definition: " ++ show typ) + TAll tvar t -> collectTVars (tvar : tvars) t + TData _ _ -> pure tvars + _ -> throwError ("Bad data type definition: " ++ printTree typ) renameInj :: [(TVar, TVar)] -> Inj -> Inj renameInj new_types (Inj name typ) = Inj name $ substituteTVar new_types typ - substituteTVar :: [(TVar, TVar)] -> Type -> Type substituteTVar new_names typ = case typ of TLit _ -> typ - - TVar tvar | Just tvar' <- lookup tvar new_names - -> TVar tvar' - | otherwise - -> typ - + TVar tvar + | Just tvar' <- lookup tvar new_names -> + TVar tvar' + | otherwise -> + typ TFun t1 t2 -> on TFun substitute' t1 t2 - - TAll tvar t | Just tvar' <- lookup tvar new_names - -> TAll tvar' $ substitute' t - | otherwise - -> TAll tvar $ substitute' t - + TAll tvar t + | Just tvar' <- lookup tvar new_names -> + TAll tvar' $ substitute' t + | otherwise -> + TAll tvar $ substitute' t TData name typs -> TData name $ map substitute' typs - _ -> error ("Impossible " ++ show typ) + _ -> error ("Impossible " ++ show typ) where substitute' = substituteTVar new_names - renameExp :: Names -> Exp -> Rn (Names, Exp) renameExp old_names = \case - EVar (LIdent n) -> pure (old_names, EVar . LIdent . fromMaybe n $ Map.lookup n old_names) - EInj (UIdent n) -> pure (old_names, EInj . UIdent . fromMaybe n $ Map.lookup n old_names) - - ELit lit -> pure (old_names, ELit lit) - + EVar (LIdent n) -> pure (old_names, EVar . LIdent . fromMaybe n $ Map.lookup n old_names) + EInj (UIdent n) -> pure (old_names, EInj . UIdent . fromMaybe n $ Map.lookup n old_names) + ELit lit -> pure (old_names, ELit lit) EApp e1 e2 -> do (env1, e1') <- renameExp old_names e1 (env2, e2') <- renameExp old_names e2 pure (Map.union env1 env2, EApp e1' e2') - EAdd e1 e2 -> do (env1, e1') <- renameExp old_names e1 (env2, e2') <- renameExp old_names e2 @@ -111,14 +113,12 @@ renameExp old_names = \case (new_names, name') <- newNameL old_names name (new_names', vars') <- newNamesL new_names vars (new_names'', rhs') <- renameExp new_names' rhs - (new_names''', e') <- renameExp new_names'' e + (new_names''', e') <- renameExp new_names'' e pure (new_names''', ELet (Bind name' vars' rhs') e') - - EAbs par e -> do + EAbs par e -> do (new_names, par') <- newNameL old_names par - (new_names', e') <- renameExp new_names e + (new_names', e') <- renameExp new_names e pure (new_names', EAbs par' e') - EAnn e t -> do (new_names, e') <- renameExp old_names e t' <- renameTVars t @@ -145,8 +145,7 @@ renamePattern ns p = case p of (ns_new, ps') <- mapAccumM renamePattern ns ps return (ns_new, PInj cs ps') PVar name -> second PVar <$> newNameL ns name - _ -> return (ns, p) - + _ -> return (ns, p) renameTVars :: Type -> Rn Type renameTVars typ = case typ of @@ -157,24 +156,25 @@ renameTVars typ = case typ of TFun t1 t2 -> liftA2 TFun (renameTVars t1) (renameTVars t2) _ -> pure typ -substitute :: TVar -- α - -> TVar -- α_n - -> Type -- A - -> Type -- [α_n/α]A +substitute :: + TVar -> -- α + TVar -> -- α_n + Type -> -- A + Type -- [α_n/α]A substitute tvar1 tvar2 typ = case typ of - TLit _ -> typ - TVar tvar | tvar == tvar1 -> TVar tvar2 - | otherwise -> typ - TFun t1 t2 -> on TFun substitute' t1 t2 - TAll tvar t | tvar == tvar1 -> TAll tvar2 $ substitute' t - | otherwise -> TAll tvar $ substitute' t - TData name typs -> TData name $ map substitute' typs - _ -> error "Impossible" + TLit _ -> typ + TVar tvar + | tvar == tvar1 -> TVar tvar2 + | otherwise -> typ + TFun t1 t2 -> on TFun substitute' t1 t2 + TAll tvar t + | tvar == tvar1 -> TAll tvar2 $ substitute' t + | otherwise -> TAll tvar $ substitute' t + TData name typs -> TData name $ map substitute' typs + _ -> error "Impossible" where substitute' = substitute tvar1 tvar2 - - -- | Create multiple names and add them to the name environment newNamesL :: Names -> [LIdent] -> Rn (Names, [LIdent]) newNamesL = mapAccumM newNameL @@ -185,7 +185,6 @@ newNameL env (LIdent old_name) = do new_name <- makeName old_name pure (Map.insert old_name new_name env, LIdent new_name) - -- | Create multiple names and add them to the name environment newNamesU :: Names -> [UIdent] -> Rn (Names, [UIdent]) newNamesU = mapAccumM newNameU @@ -196,18 +195,17 @@ newNameU env (UIdent old_name) = do new_name <- makeName old_name pure (Map.insert old_name new_name env, UIdent new_name) - -- | Annotate name with number and increment the number @prefix ⇒ prefix_number@. makeName :: String -> Rn String makeName prefix = do - i <- gets var_counter - let name = prefix ++ "_" ++ show i - modify $ \cxt -> cxt { var_counter = succ cxt.var_counter} - pure name + i <- gets var_counter + let name = prefix ++ "_" ++ show i + modify $ \cxt -> cxt{var_counter = succ cxt.var_counter} + pure name nextNameTVar :: TVar -> Rn TVar -nextNameTVar (MkTVar (LIdent s))= do - i <- gets tvar_counter - let tvar = MkTVar . LIdent $ s ++ "_" ++ show i - modify $ \cxt -> cxt { tvar_counter = succ cxt.tvar_counter} - pure tvar +nextNameTVar (MkTVar (LIdent s)) = do + i <- gets tvar_counter + let tvar = MkTVar . LIdent $ s ++ "_" ++ show i + modify $ \cxt -> cxt{tvar_counter = succ cxt.tvar_counter} + pure tvar diff --git a/src/TypeChecker/TypeCheckerHm.hs b/src/TypeChecker/TypeCheckerHm.hs index 3ae6df2..166e680 100644 --- a/src/TypeChecker/TypeCheckerHm.hs +++ b/src/TypeChecker/TypeCheckerHm.hs @@ -27,8 +27,11 @@ import Grammar.Abs import Grammar.Print (printTree) import TypeChecker.TypeCheckerIr qualified as T +-- TODO: Save all substition sets encountered in the program and apply +-- to all top level functions in the end. + initCtx = Ctx mempty -initEnv = Env 0 'a' mempty mempty mempty "" mempty mempty +initEnv = Env 0 'a' mempty mempty mempty "" mempty mempty mempty run :: Infer a -> Either Error a run = run' initEnv initCtx @@ -53,8 +56,8 @@ checkPrg :: Program -> Infer (T.Program' Type) checkPrg (Program bs) = do preRun bs bs <- checkDef bs - sub <- solveUndecidable - bs <- mapM (mono sub) bs + sub0 <- solveUndecidable + bs <- mapM (mono sub0) bs return $ T.Program bs mono :: Subst -> T.Def' Type -> Infer (T.Def' Type) @@ -74,11 +77,19 @@ preRun (x : xs) = case x of >>= flip when ( uncatchableErr $ Aux.do - "Duplicate signatures for function" + "Duplicate signatures of function" quote $ printTree n ) insertSig (coerce n) (Just t) >> preRun xs DBind (Bind n _ e) -> do + binds <- gets declaredBinds + when + (coerce n `S.member` binds) + ( uncatchableErr $ Aux.do + "Duplicate declarations of function" + quote $ printTree n + ) + modify (\st -> st{declaredBinds = S.insert (coerce n) st.declaredBinds}) collect (collectTVars e) s <- gets sigs case M.lookup (coerce n) s of @@ -105,12 +116,12 @@ checkBind :: Bind -> Infer (T.Bind' Type) checkBind bind@(Bind name args e) = do setCurrentBind $ coerce name let lambda = makeLambda e (reverse (coerce args)) - (e, lambda_t) <- inferExp lambda + (sub0, (e, lambda_t)) <- inferExp lambda s <- gets sigs case M.lookup (coerce name) s of Just (Just t') -> do sub1 <- bindErr (unify lambda_t (skolemize t')) bind - return $ T.Bind (coerce name, apply sub1 t') [] (e, lambda_t) + return $ T.Bind (coerce name, apply (sub1 `compose` sub0) t') [] (e, lambda_t) _ -> do insertSig (coerce name) (Just lambda_t) return (T.Bind (coerce name, lambda_t) [] (e, lambda_t)) @@ -178,12 +189,12 @@ returnType :: Type -> Type returnType (TFun _ t2) = returnType t2 returnType a = a -inferExp :: Exp -> Infer (T.ExpT' Type) +inferExp :: Exp -> Infer (Subst, T.ExpT' Type) inferExp e = do (s, (e', t)) <- algoW e let subbed = apply s t modify (\st -> st{undecidedSigs = apply s st.undecidedSigs}) - return (e', subbed) + return (s, (e', subbed)) class CollectTVars a where collectTVars :: a -> Set T.Ident @@ -851,6 +862,7 @@ data Env = Env , currentBind :: T.Ident , undecidedSigs :: Map T.Ident Type , toDecide :: Set T.Ident + , declaredBinds :: Set T.Ident } deriving (Show)