Order binds with signatures same as binds without signatures

This commit is contained in:
Martin Fredin 2023-05-15 00:30:37 +02:00
parent 814ebc1ac0
commit 46d4ef3923

View file

@ -2,26 +2,29 @@
module OrderDefs where module OrderDefs where
import Control.Monad.State (State, execState, get, modify, when) import Control.Monad.State (State, execState, get, modify, when)
import Data.Function (on) import Data.Function (on)
import Data.List (partition, sortBy) import Data.List (find, partition, sortBy)
import Data.Set (Set) import Data.Set (Set)
import Data.Set qualified as Set import qualified Data.Set as Set
import Grammar.Abs import Grammar.Abs
import Grammar.Print (printTree) import Grammar.Print (printTree)
orderDefs :: Program -> Program orderDefs :: Program -> Program
orderDefs (Program defs) = orderDefs (Program defs) =
Program $ not_binds ++ map DBind (has_sig ++ orderBinds no_sig) Program $ ds ++ ss' ++ concatMap addSig (orderBinds bs)
where where
(has_sig, no_sig) = addSig b
partition | Just sig <- hasSig b = [sig, DBind b]
(\(Bind n _ _) -> elem n sig_names) | otherwise = [DBind b]
[b | DBind b <- defs]
sig_names = [n | DSig (Sig n _) <- defs] hasSig (Bind n _ _) = find (\(DSig (Sig n' _)) -> n' == n) ss
not_binds = flip filter defs $ \case
DBind _ -> False (ss, ss') = partition hasBind [DSig s | DSig s <- defs]
_ -> True hasBind (DSig (Sig n _)) = any (\(Bind n' _ _) -> n' == n) bs
bs = [ b | DBind b <- defs]
ds = [ DData d | DData d <- defs]
orderBinds :: [Bind] -> [Bind] orderBinds :: [Bind] -> [Bind]
orderBinds binds = sortBy (on compare countUniqueCalls) binds orderBinds binds = sortBy (on compare countUniqueCalls) binds
@ -29,7 +32,7 @@ orderBinds binds = sortBy (on compare countUniqueCalls) binds
bind_names = [n | Bind n _ _ <- binds] bind_names = [n | Bind n _ _ <- binds]
countUniqueCalls :: Bind -> Int countUniqueCalls :: Bind -> Int
countUniqueCalls b@(BindS _ _ _) = error $ "Desugar failed to desugar bind correctly: " ++ printTree b countUniqueCalls b@BindS{} = error $ "Desugar failed to desugar bind correctly: " ++ printTree b
countUniqueCalls (Bind n _ e) = countUniqueCalls (Bind n _ e) =
Set.size $ execState (go e) (Set.singleton n) Set.size $ execState (go e) (Set.singleton n)
where where