Removed internal sorter in HM

This commit is contained in:
sebastian 2023-05-13 17:23:10 +02:00
parent a6ed6e589b
commit 86256066b6
2 changed files with 39 additions and 60 deletions

View file

@ -2,42 +2,49 @@
module OrderDefs where
import Control.Monad.State (State, execState, get, modify, when)
import Data.Function (on)
import Data.List (partition, sortBy)
import Data.Set (Set)
import qualified Data.Set as Set
import Grammar.Abs
import Control.Monad.State (State, execState, get, modify, when)
import Data.Function (on)
import Data.List (partition, sortBy)
import Data.Set (Set)
import Data.Set qualified as Set
import Grammar.Abs
import Grammar.Print (printTree)
orderDefs :: Program -> Program
orderDefs (Program defs) =
Program $ not_binds ++ map DBind (has_sig ++ orderBinds no_sig)
where
(has_sig, no_sig) = partition (\(Bind n _ _) -> elem n sig_names)
[ b | DBind b <- defs]
sig_names = [ n | DSig (Sig n _) <- defs ]
not_binds = flip filter defs $ \case DBind _ -> False
_ -> True
(has_sig, no_sig) =
partition
(\(Bind n _ _) -> elem n sig_names)
[b | DBind b <- defs]
sig_names = [n | DSig (Sig n _) <- defs]
not_binds = flip filter defs $ \case
DBind _ -> False
_ -> True
orderBinds :: [Bind] -> [Bind]
orderBinds :: [Bind] -> [Bind]
orderBinds binds = sortBy (on compare countUniqueCalls) binds
where
bind_names = [ n | Bind n _ _ <- binds]
bind_names = [n | Bind n _ _ <- binds]
countUniqueCalls :: Bind -> Int
countUniqueCalls b@(BindS _ _ _) = error $ "Desugar failed to desugar bind correctly: " ++ printTree b
countUniqueCalls (Bind n _ e) =
Set.size $ execState (go e) (Set.singleton n)
where
go :: Exp -> State (Set LIdent) ()
go exp = get >>= \called -> case exp of
EVar x -> when (Set.notMember x called && elem x bind_names) $
go exp =
get >>= \called -> case exp of
EVar x ->
when (Set.notMember x called && elem x bind_names) $
modify (Set.insert x)
EApp e1 e2 -> on (>>) go e1 e2
EAdd e1 e2 -> on (>>) go e1 e2
ELet (Bind _ _ e) e' -> on (>>) go e e'
EAbs _ e -> go e
ECase e bs -> go e >> mapM_ (\(Branch _ e) -> go e) bs
EAnn e _ -> go e
EInj _ -> pure ()
ELit _ -> pure ()
EApp e1 e2 -> on (>>) go e1 e2
EAdd e1 e2 -> on (>>) go e1 e2
ELet (Bind _ _ e) e' -> on (>>) go e e'
EAbs _ e -> go e
ECase e bs -> go e >> mapM_ (\(Branch _ e) -> go e) bs
EAnn e _ -> go e
EInj _ -> pure ()
ELit _ -> pure ()
e -> error $ "Desugar failed to desugar expression correctly: " ++ printTree e

View file

@ -16,7 +16,7 @@ import Control.Monad.State
import Control.Monad.Writer
import Data.Coerce (coerce)
import Data.Function (on)
import Data.List (foldl', nub, sortOn)
import Data.List (foldl', nub)
import Data.List.Extra (unsnoc)
import Data.Map (Map)
import Data.Map qualified as M
@ -48,7 +48,6 @@ checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = do
preRun bs
-- sgs <- gets sigs
bs <- map snd . sortOn fst <$> bindCount bs
bs <- checkDef bs
-- return . prettify sgs . T.Program $ bs
return . T.Program $ bs
@ -77,37 +76,6 @@ replace m def@(TAll (MkTVar forall_) t) = case M.lookup (coerce forall_) m of
Nothing -> def
replace _ t = t
bindCount :: [Def] -> Infer [(Int, Def)]
bindCount [] = return []
bindCount (x : xs) = do
(o, d) <- go x
b <- bindCount xs
return $ (o, d) : b
where
go :: Def -> Infer (Int, Def)
go b@(DBind (Bind _ _ e)) = do
db <- gets declaredBinds
let n = runIdentity $ evalStateT (countBinds db e) mempty
return (n, b)
go (DSig sig) = pure (0, DSig sig)
go (DData data_) = pure (-1, DData data_)
countBinds :: Set T.Ident -> Exp -> StateT (Set T.Ident) Identity Int
countBinds declared = \case
EVar i -> do
found <- get
if coerce i `S.member` declared && not (coerce i `S.member` found)
then put (S.insert (coerce i) found) >> return 1
else return 0
ELet _ e -> countBinds declared e
EApp e1 e2 -> (+) <$> countBinds declared e1 <*> countBinds declared e2
EAdd e1 e2 -> (+) <$> countBinds declared e1 <*> countBinds declared e2
EAbs _ e -> countBinds declared e
ECase e1 brnchs -> do
let f (Branch _ e2) = countBinds declared e2
(+) . sum <$> mapM f brnchs <*> countBinds declared e1
_ -> return 0
preRun :: [Def] -> Infer ()
preRun [] = return ()
preRun (x : xs) = case x of
@ -190,9 +158,9 @@ checkBind (Bind name args e) = do
let lambda = makeLambda e (reverse (coerce args))
(e, infSig) <- inferExp lambda
s <- gets sigs
let genInfSig = generalize mempty infSig
case M.lookup (coerce name) s of
Just (Just typSig) -> do
let genInfSig = generalize mempty infSig
sub <- genInfSig `unify` typSig
b <- genInfSig <<= typSig
unless
@ -211,8 +179,8 @@ checkBind (Bind name args e) = do
-- Unfortunately I do not know a better solution at the moment.
return $ T.Bind (coerce name, apply sub typSig) [] (apply sub e, typSig)
_ -> do
insertSig (coerce name) (Just infSig)
return (T.Bind (coerce name, infSig) [] (e, infSig))
insertSig (coerce name) (Just genInfSig)
return (T.Bind (coerce name, infSig) [] (e, genInfSig))
checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m ()
checkData err@(Data typ injs) = do
@ -620,6 +588,10 @@ inst = \case
TFun t1 t2 -> TFun <$> inst t1 <*> inst t2
rest -> return rest
{-
arrint = TFun (TLit "Int") (TLit "Int")
-}
-- Only one of 'freshen' and 'inst' should be needed but something doesn't work
-- when I remove either.
freshen :: Type -> Infer Type