Removed internal sorter in HM

This commit is contained in:
sebastian 2023-05-13 17:23:10 +02:00
parent a6ed6e589b
commit 86256066b6
2 changed files with 39 additions and 60 deletions

View file

@ -6,18 +6,21 @@ import Control.Monad.State (State, execState, get, modify, when)
import Data.Function (on) import Data.Function (on)
import Data.List (partition, sortBy) import Data.List (partition, sortBy)
import Data.Set (Set) import Data.Set (Set)
import qualified Data.Set as Set import Data.Set qualified as Set
import Grammar.Abs import Grammar.Abs
import Grammar.Print (printTree)
orderDefs :: Program -> Program orderDefs :: Program -> Program
orderDefs (Program defs) = orderDefs (Program defs) =
Program $ not_binds ++ map DBind (has_sig ++ orderBinds no_sig) Program $ not_binds ++ map DBind (has_sig ++ orderBinds no_sig)
where where
(has_sig, no_sig) = partition (\(Bind n _ _) -> elem n sig_names) (has_sig, no_sig) =
partition
(\(Bind n _ _) -> elem n sig_names)
[b | DBind b <- defs] [b | DBind b <- defs]
sig_names = [n | DSig (Sig n _) <- defs] sig_names = [n | DSig (Sig n _) <- defs]
not_binds = flip filter defs $ \case DBind _ -> False not_binds = flip filter defs $ \case
DBind _ -> False
_ -> True _ -> True
orderBinds :: [Bind] -> [Bind] orderBinds :: [Bind] -> [Bind]
@ -26,12 +29,15 @@ orderBinds binds = sortBy (on compare countUniqueCalls) binds
bind_names = [n | Bind n _ _ <- binds] bind_names = [n | Bind n _ _ <- binds]
countUniqueCalls :: Bind -> Int countUniqueCalls :: Bind -> Int
countUniqueCalls b@(BindS _ _ _) = error $ "Desugar failed to desugar bind correctly: " ++ printTree b
countUniqueCalls (Bind n _ e) = countUniqueCalls (Bind n _ e) =
Set.size $ execState (go e) (Set.singleton n) Set.size $ execState (go e) (Set.singleton n)
where where
go :: Exp -> State (Set LIdent) () go :: Exp -> State (Set LIdent) ()
go exp = get >>= \called -> case exp of go exp =
EVar x -> when (Set.notMember x called && elem x bind_names) $ get >>= \called -> case exp of
EVar x ->
when (Set.notMember x called && elem x bind_names) $
modify (Set.insert x) modify (Set.insert x)
EApp e1 e2 -> on (>>) go e1 e2 EApp e1 e2 -> on (>>) go e1 e2
EAdd e1 e2 -> on (>>) go e1 e2 EAdd e1 e2 -> on (>>) go e1 e2
@ -41,3 +47,4 @@ orderBinds binds = sortBy (on compare countUniqueCalls) binds
EAnn e _ -> go e EAnn e _ -> go e
EInj _ -> pure () EInj _ -> pure ()
ELit _ -> pure () ELit _ -> pure ()
e -> error $ "Desugar failed to desugar expression correctly: " ++ printTree e

View file

@ -16,7 +16,7 @@ import Control.Monad.State
import Control.Monad.Writer import Control.Monad.Writer
import Data.Coerce (coerce) import Data.Coerce (coerce)
import Data.Function (on) import Data.Function (on)
import Data.List (foldl', nub, sortOn) import Data.List (foldl', nub)
import Data.List.Extra (unsnoc) import Data.List.Extra (unsnoc)
import Data.Map (Map) import Data.Map (Map)
import Data.Map qualified as M import Data.Map qualified as M
@ -48,7 +48,6 @@ checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = do checkPrg (Program bs) = do
preRun bs preRun bs
-- sgs <- gets sigs -- sgs <- gets sigs
bs <- map snd . sortOn fst <$> bindCount bs
bs <- checkDef bs bs <- checkDef bs
-- return . prettify sgs . T.Program $ bs -- return . prettify sgs . T.Program $ bs
return . T.Program $ bs return . T.Program $ bs
@ -77,37 +76,6 @@ replace m def@(TAll (MkTVar forall_) t) = case M.lookup (coerce forall_) m of
Nothing -> def Nothing -> def
replace _ t = t replace _ t = t
bindCount :: [Def] -> Infer [(Int, Def)]
bindCount [] = return []
bindCount (x : xs) = do
(o, d) <- go x
b <- bindCount xs
return $ (o, d) : b
where
go :: Def -> Infer (Int, Def)
go b@(DBind (Bind _ _ e)) = do
db <- gets declaredBinds
let n = runIdentity $ evalStateT (countBinds db e) mempty
return (n, b)
go (DSig sig) = pure (0, DSig sig)
go (DData data_) = pure (-1, DData data_)
countBinds :: Set T.Ident -> Exp -> StateT (Set T.Ident) Identity Int
countBinds declared = \case
EVar i -> do
found <- get
if coerce i `S.member` declared && not (coerce i `S.member` found)
then put (S.insert (coerce i) found) >> return 1
else return 0
ELet _ e -> countBinds declared e
EApp e1 e2 -> (+) <$> countBinds declared e1 <*> countBinds declared e2
EAdd e1 e2 -> (+) <$> countBinds declared e1 <*> countBinds declared e2
EAbs _ e -> countBinds declared e
ECase e1 brnchs -> do
let f (Branch _ e2) = countBinds declared e2
(+) . sum <$> mapM f brnchs <*> countBinds declared e1
_ -> return 0
preRun :: [Def] -> Infer () preRun :: [Def] -> Infer ()
preRun [] = return () preRun [] = return ()
preRun (x : xs) = case x of preRun (x : xs) = case x of
@ -190,9 +158,9 @@ checkBind (Bind name args e) = do
let lambda = makeLambda e (reverse (coerce args)) let lambda = makeLambda e (reverse (coerce args))
(e, infSig) <- inferExp lambda (e, infSig) <- inferExp lambda
s <- gets sigs s <- gets sigs
let genInfSig = generalize mempty infSig
case M.lookup (coerce name) s of case M.lookup (coerce name) s of
Just (Just typSig) -> do Just (Just typSig) -> do
let genInfSig = generalize mempty infSig
sub <- genInfSig `unify` typSig sub <- genInfSig `unify` typSig
b <- genInfSig <<= typSig b <- genInfSig <<= typSig
unless unless
@ -211,8 +179,8 @@ checkBind (Bind name args e) = do
-- Unfortunately I do not know a better solution at the moment. -- Unfortunately I do not know a better solution at the moment.
return $ T.Bind (coerce name, apply sub typSig) [] (apply sub e, typSig) return $ T.Bind (coerce name, apply sub typSig) [] (apply sub e, typSig)
_ -> do _ -> do
insertSig (coerce name) (Just infSig) insertSig (coerce name) (Just genInfSig)
return (T.Bind (coerce name, infSig) [] (e, infSig)) return (T.Bind (coerce name, infSig) [] (e, genInfSig))
checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m () checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m ()
checkData err@(Data typ injs) = do checkData err@(Data typ injs) = do
@ -620,6 +588,10 @@ inst = \case
TFun t1 t2 -> TFun <$> inst t1 <*> inst t2 TFun t1 t2 -> TFun <$> inst t1 <*> inst t2
rest -> return rest rest -> return rest
{-
arrint = TFun (TLit "Int") (TLit "Int")
-}
-- Only one of 'freshen' and 'inst' should be needed but something doesn't work -- Only one of 'freshen' and 'inst' should be needed but something doesn't work
-- when I remove either. -- when I remove either.
freshen :: Type -> Infer Type freshen :: Type -> Infer Type