From 86256066b62da7719a5770e5e03334c34325c08f Mon Sep 17 00:00:00 2001 From: sebastian Date: Sat, 13 May 2023 17:23:10 +0200 Subject: [PATCH] Removed internal sorter in HM --- src/OrderDefs.hs | 55 ++++++++++++++++++-------------- src/TypeChecker/TypeCheckerHm.hs | 44 +++++-------------------- 2 files changed, 39 insertions(+), 60 deletions(-) diff --git a/src/OrderDefs.hs b/src/OrderDefs.hs index 079512b..fed2755 100644 --- a/src/OrderDefs.hs +++ b/src/OrderDefs.hs @@ -2,42 +2,49 @@ module OrderDefs where -import Control.Monad.State (State, execState, get, modify, when) -import Data.Function (on) -import Data.List (partition, sortBy) -import Data.Set (Set) -import qualified Data.Set as Set -import Grammar.Abs +import Control.Monad.State (State, execState, get, modify, when) +import Data.Function (on) +import Data.List (partition, sortBy) +import Data.Set (Set) +import Data.Set qualified as Set +import Grammar.Abs +import Grammar.Print (printTree) orderDefs :: Program -> Program orderDefs (Program defs) = Program $ not_binds ++ map DBind (has_sig ++ orderBinds no_sig) - where - (has_sig, no_sig) = partition (\(Bind n _ _) -> elem n sig_names) - [ b | DBind b <- defs] - sig_names = [ n | DSig (Sig n _) <- defs ] - not_binds = flip filter defs $ \case DBind _ -> False - _ -> True + (has_sig, no_sig) = + partition + (\(Bind n _ _) -> elem n sig_names) + [b | DBind b <- defs] + sig_names = [n | DSig (Sig n _) <- defs] + not_binds = flip filter defs $ \case + DBind _ -> False + _ -> True -orderBinds :: [Bind] -> [Bind] +orderBinds :: [Bind] -> [Bind] orderBinds binds = sortBy (on compare countUniqueCalls) binds where - bind_names = [ n | Bind n _ _ <- binds] + bind_names = [n | Bind n _ _ <- binds] countUniqueCalls :: Bind -> Int + countUniqueCalls b@(BindS _ _ _) = error $ "Desugar failed to desugar bind correctly: " ++ printTree b countUniqueCalls (Bind n _ e) = Set.size $ execState (go e) (Set.singleton n) where go :: Exp -> State (Set LIdent) () - go exp = get >>= \called -> case exp of - EVar x -> when (Set.notMember x called && elem x bind_names) $ + go exp = + get >>= \called -> case exp of + EVar x -> + when (Set.notMember x called && elem x bind_names) $ modify (Set.insert x) - EApp e1 e2 -> on (>>) go e1 e2 - EAdd e1 e2 -> on (>>) go e1 e2 - ELet (Bind _ _ e) e' -> on (>>) go e e' - EAbs _ e -> go e - ECase e bs -> go e >> mapM_ (\(Branch _ e) -> go e) bs - EAnn e _ -> go e - EInj _ -> pure () - ELit _ -> pure () + EApp e1 e2 -> on (>>) go e1 e2 + EAdd e1 e2 -> on (>>) go e1 e2 + ELet (Bind _ _ e) e' -> on (>>) go e e' + EAbs _ e -> go e + ECase e bs -> go e >> mapM_ (\(Branch _ e) -> go e) bs + EAnn e _ -> go e + EInj _ -> pure () + ELit _ -> pure () + e -> error $ "Desugar failed to desugar expression correctly: " ++ printTree e diff --git a/src/TypeChecker/TypeCheckerHm.hs b/src/TypeChecker/TypeCheckerHm.hs index 45725aa..a371977 100644 --- a/src/TypeChecker/TypeCheckerHm.hs +++ b/src/TypeChecker/TypeCheckerHm.hs @@ -16,7 +16,7 @@ import Control.Monad.State import Control.Monad.Writer import Data.Coerce (coerce) import Data.Function (on) -import Data.List (foldl', nub, sortOn) +import Data.List (foldl', nub) import Data.List.Extra (unsnoc) import Data.Map (Map) import Data.Map qualified as M @@ -48,7 +48,6 @@ checkPrg :: Program -> Infer (T.Program' Type) checkPrg (Program bs) = do preRun bs -- sgs <- gets sigs - bs <- map snd . sortOn fst <$> bindCount bs bs <- checkDef bs -- return . prettify sgs . T.Program $ bs return . T.Program $ bs @@ -77,37 +76,6 @@ replace m def@(TAll (MkTVar forall_) t) = case M.lookup (coerce forall_) m of Nothing -> def replace _ t = t -bindCount :: [Def] -> Infer [(Int, Def)] -bindCount [] = return [] -bindCount (x : xs) = do - (o, d) <- go x - b <- bindCount xs - return $ (o, d) : b - where - go :: Def -> Infer (Int, Def) - go b@(DBind (Bind _ _ e)) = do - db <- gets declaredBinds - let n = runIdentity $ evalStateT (countBinds db e) mempty - return (n, b) - go (DSig sig) = pure (0, DSig sig) - go (DData data_) = pure (-1, DData data_) - - countBinds :: Set T.Ident -> Exp -> StateT (Set T.Ident) Identity Int - countBinds declared = \case - EVar i -> do - found <- get - if coerce i `S.member` declared && not (coerce i `S.member` found) - then put (S.insert (coerce i) found) >> return 1 - else return 0 - ELet _ e -> countBinds declared e - EApp e1 e2 -> (+) <$> countBinds declared e1 <*> countBinds declared e2 - EAdd e1 e2 -> (+) <$> countBinds declared e1 <*> countBinds declared e2 - EAbs _ e -> countBinds declared e - ECase e1 brnchs -> do - let f (Branch _ e2) = countBinds declared e2 - (+) . sum <$> mapM f brnchs <*> countBinds declared e1 - _ -> return 0 - preRun :: [Def] -> Infer () preRun [] = return () preRun (x : xs) = case x of @@ -190,9 +158,9 @@ checkBind (Bind name args e) = do let lambda = makeLambda e (reverse (coerce args)) (e, infSig) <- inferExp lambda s <- gets sigs + let genInfSig = generalize mempty infSig case M.lookup (coerce name) s of Just (Just typSig) -> do - let genInfSig = generalize mempty infSig sub <- genInfSig `unify` typSig b <- genInfSig <<= typSig unless @@ -211,8 +179,8 @@ checkBind (Bind name args e) = do -- Unfortunately I do not know a better solution at the moment. return $ T.Bind (coerce name, apply sub typSig) [] (apply sub e, typSig) _ -> do - insertSig (coerce name) (Just infSig) - return (T.Bind (coerce name, infSig) [] (e, infSig)) + insertSig (coerce name) (Just genInfSig) + return (T.Bind (coerce name, infSig) [] (e, genInfSig)) checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m () checkData err@(Data typ injs) = do @@ -620,6 +588,10 @@ inst = \case TFun t1 t2 -> TFun <$> inst t1 <*> inst t2 rest -> return rest +{- +arrint = TFun (TLit "Int") (TLit "Int") +-} + -- Only one of 'freshen' and 'inst' should be needed but something doesn't work -- when I remove either. freshen :: Type -> Infer Type