{-# LANGUAGE LambdaCase #-}

-- | Reordering of the top-level definitions of a 'Program'.
module OrderDefs where

import Control.Monad.State (State, execState, get, modify, when)
import Data.Function (on)
import Data.List (partition, sortBy, sortOn)
import Data.Set (Set)
import qualified Data.Set as Set
import Grammar.Abs

-- | Reorder the definitions of a program.
--
-- The resulting program contains, in this order:
--
--   1. every definition that is not a 'DBind', in its original order;
--   2. the binds whose name also appears in some 'DSig', in their original
--      order;
--   3. the remaining binds (those with no signature), ordered by
--      'orderBinds'.
orderDefs :: Program -> Program
orderDefs (Program defs) =
  Program $ notBinds ++ map DBind (hasSig ++ orderBinds noSig)
  where
    -- Names that have a type signature; a Set gives O(log n) membership.
    sigNames :: Set LIdent
    sigNames = Set.fromList [n | DSig (Sig n _) <- defs]

    (hasSig, noSig) =
      partition (\(Bind n _ _) -> n `Set.member` sigNames) [b | DBind b <- defs]

    notBinds = flip filter defs $ \case
      DBind _ -> False
      _ -> True

-- | Sort binds by how many distinct binds of the same list each one
-- references, fewest first.
--
-- A bind's sort key is the size of the set built by @countUniqueCalls@: its
-- own name plus every other name from this list that occurs as an 'EVar' in
-- its body. The sort is stable, so binds with equal keys keep their input
-- order. Each key is computed once per bind ('sortOn' caches it), instead of
-- once per comparison.
--
-- NOTE(review): fewer references does not imply that a bind's callees come
-- before it. For example, with @f = g@ and @g = h + i@, @f@ (one reference)
-- sorts before @g@ (two references) even though @f@ depends on @g@. If later
-- passes need callees first, a dependency sort such as
-- @Data.Graph.stronglyConnComp@ would be required. Confirm the intended
-- contract.
orderBinds :: [Bind] -> [Bind]
orderBinds binds = sortOn countUniqueCalls binds
  where
    -- Only references to names bound in this list count as calls.
    bindNames :: Set LIdent
    bindNames = Set.fromList [n | Bind n _ _ <- binds]

    -- The set is seeded with the bind's own name, so self-recursive
    -- references do not increase its size. The key is therefore
    -- 1 + (number of distinct other binds referenced).
    countUniqueCalls :: Bind -> Int
    countUniqueCalls (Bind self _ body) =
      Set.size $ execState (go body) (Set.singleton self)

    -- Walk an expression, recording every referenced name from 'bindNames'.
    -- NOTE(review): binders introduced by 'EAbs', 'ELet' and case branches
    -- are ignored, so a local name that shadows a bind is counted as a call
    -- to that bind. Verify whether that matters for callers.
    go :: Exp -> State (Set LIdent) ()
    go expr = case expr of
      EVar x -> do
        called <- get
        when (Set.notMember x called && Set.member x bindNames) $
          modify (Set.insert x)
      EApp fun arg -> go fun >> go arg
      EAdd lhs rhs -> go lhs >> go rhs
      ELet (Bind _ _ rhs) body -> go rhs >> go body
      EAbs _ body -> go body
      ECase scrut branches ->
        go scrut >> mapM_ (\(Branch _ rhs) -> go rhs) branches
      EAnn inner _ -> go inner
      EInj _ -> pure ()
      ELit _ -> pure ()