churf/src/OrderDefs.hs
2023-04-28 19:45:15 +02:00

43 lines
1.6 KiB
Haskell

{-# LANGUAGE LambdaCase #-}
module OrderDefs where
import Control.Monad.State (State, execState, get, modify, when)
import Data.Function (on)
import Data.Graph (flattenSCCs, stronglyConnComp)
import Data.List (partition, sortBy)
import Data.Set (Set)
import qualified Data.Set as Set
import Grammar.Abs
-- | Reorder the top-level definitions of a program.
--
-- The result contains, in this order:
--
--   1. every definition that is not a binding, in its original order;
--   2. the bindings whose name has a type signature somewhere in the
--      program, in their original order;
--   3. the bindings without a signature, arranged by 'orderBinds'.
orderDefs :: Program -> Program
orderDefs (Program defs) =
  Program (others ++ map DBind (signed ++ orderBinds unsigned))
  where
    -- Names that are given a type signature.
    signatures = [ name | DSig (Sig name _) <- defs ]

    -- All bindings, in source order.
    binds = [ b | DBind b <- defs ]

    (signed, unsigned) = partition hasSignature binds

    hasSignature (Bind name _ _) = name `elem` signatures

    -- Everything that is not a binding keeps its relative order.
    others = filter (not . isBind) defs

    isBind (DBind _) = True
    isBind _ = False
-- | Order bindings so that each binding comes after the bindings (from the
-- same input list) that it references, i.e. dependencies first.
--
-- Mutually recursive bindings form one strongly connected component and are
-- emitted next to each other, in an unspecified order. References to names
-- outside the input list are ignored.
--
-- Why a graph ordering rather than sorting by how many bindings each one
-- calls: a call count is not a dependency order. If @f@ calls only @g@ and
-- @g@ calls @h@, @i@ and @j@, then @f@ has the smaller count and would be
-- placed before @g@, even though @f@ depends on @g@.
orderBinds :: [Bind] -> [Bind]
orderBinds binds =
  -- 'stronglyConnComp' returns components in reverse topological order,
  -- so a binding's callees appear before the binding itself.
  flattenSCCs $
    stronglyConnComp
      [ (b, n, Set.toList (uniqueCalls b)) | b@(Bind n _ _) <- binds ]
  where
    -- Names of the bindings being ordered; only these count as edges.
    bind_names :: Set LIdent
    bind_names = Set.fromList [ n | Bind n _ _ <- binds ]

    -- The distinct names from 'bind_names' referenced in a binding's body.
    -- This includes the binding's own name when it is recursive.
    uniqueCalls :: Bind -> Set LIdent
    uniqueCalls (Bind _ _ body) = execState (go body) Set.empty
      where
        -- NOTE(review): binders introduced by EAbs / ELet / case branches are
        -- not tracked, so a local name that shadows one of 'bind_names' is
        -- still counted as a reference. Confirm whether that can happen.
        go :: Exp -> State (Set LIdent) ()
        go = \case
          EVar x -> when (Set.member x bind_names) $ modify (Set.insert x)
          EApp e1 e2 -> on (>>) go e1 e2
          EAdd e1 e2 -> on (>>) go e1 e2
          ELet (Bind _ _ rhs) e -> on (>>) go rhs e
          EAbs _ e -> go e
          ECase e bs -> go e >> mapM_ (\(Branch _ rhs) -> go rhs) bs
          EAnn e _ -> go e
          EInj _ -> pure ()
          ELit _ -> pure ()