Add module to sort definitions

This commit is contained in:
Martin Fredin 2023-04-28 19:45:15 +02:00
parent de03a2cc34
commit df1a5de04a
4 changed files with 98 additions and 56 deletions

View file

@ -36,6 +36,7 @@ executable language
Renamer.Renamer Renamer.Renamer
TypeChecker.TypeChecker TypeChecker.TypeChecker
AnnForall AnnForall
OrderDefs
TypeChecker.TypeCheckerHm TypeChecker.TypeCheckerHm
TypeChecker.TypeCheckerBidir TypeChecker.TypeCheckerBidir
TypeChecker.TypeCheckerIr TypeChecker.TypeCheckerIr
@ -90,6 +91,7 @@ Test-suite language-testsuite
Grammar.Skel Grammar.Skel
Grammar.ErrM Grammar.ErrM
Grammar.Layout Grammar.Layout
OrderDefs
Auxiliary Auxiliary
Monomorphizer.Monomorphizer Monomorphizer.Monomorphizer
Monomorphizer.MonomorphizerIr Monomorphizer.MonomorphizerIr

View file

@ -16,29 +16,21 @@ import Grammar.Par (myLexer, pProgram)
import Grammar.Print (Print, printTree) import Grammar.Print (Print, printTree)
import LambdaLifter (lambdaLift) import LambdaLifter (lambdaLift)
import Monomorphizer.Monomorphizer (monomorphize) import Monomorphizer.Monomorphizer (monomorphize)
import OrderDefs (orderDefs)
import Renamer.Renamer (rename) import Renamer.Renamer (rename)
import ReportForall (reportForall) import ReportForall (reportForall)
import System.Console.GetOpt ( import System.Console.GetOpt (ArgDescr (NoArg, ReqArg),
ArgDescr (NoArg, ReqArg),
ArgOrder (RequireOrder), ArgOrder (RequireOrder),
OptDescr (Option), OptDescr (Option), getOpt,
getOpt, usageInfo)
usageInfo, import System.Directory (createDirectory, doesPathExist,
)
import System.Directory (
createDirectory,
doesPathExist,
getDirectoryContents, getDirectoryContents,
removeDirectoryRecursive, removeDirectoryRecursive,
setCurrentDirectory, setCurrentDirectory)
)
import System.Environment (getArgs) import System.Environment (getArgs)
import System.Exit ( import System.Exit (ExitCode (ExitFailure),
ExitCode (ExitFailure), exitFailure, exitSuccess,
exitFailure, exitWith)
exitSuccess,
exitWith,
)
import System.IO (stderr) import System.IO (stderr)
import System.Process (spawnCommand, waitForProcess) import System.Process (spawnCommand, waitForProcess)
import TypeChecker.TypeChecker (TypeChecker (Bi, Hm), typecheck) import TypeChecker.TypeChecker (TypeChecker (Bi, Hm), typecheck)
@ -112,7 +104,7 @@ main' opts s =
file <- readFile s file <- readFile s
printToErr "-- Parse Tree -- " printToErr "-- Parse Tree -- "
parsed <- fromErr . pProgram . resolveLayout True $ myLexer (file ++ prelude) parsed <- fromErr . pProgram . resolveLayout True $ myLexer file -- (file ++ prelude)
log parsed log parsed
printToErr "-- Desugar --" printToErr "-- Desugar --"
@ -125,7 +117,7 @@ main' opts s =
log renamed log renamed
printToErr "\n-- TypeChecker --" printToErr "\n-- TypeChecker --"
typechecked <- fromErr $ typecheck (fromJust opts.typechecker) renamed typechecked <- fromErr $ typecheck (fromJust opts.typechecker) (orderDefs renamed)
log typechecked log typechecked
printToErr "\n-- Lambda Lifter --" printToErr "\n-- Lambda Lifter --"

43
src/OrderDefs.hs Normal file
View file

@ -0,0 +1,43 @@
{-# LANGUAGE LambdaCase #-}
module OrderDefs where
import Control.Monad.State (State, execState, get, modify, when)
import Data.Function (on)
import Data.List (partition, sortBy)
import Data.Set (Set)
import qualified Data.Set as Set
import Grammar.Abs
orderDefs :: Program -> Program
orderDefs (Program defs) =
Program $ not_binds ++ map DBind (has_sig ++ orderBinds no_sig)
where
(has_sig, no_sig) = partition (\(Bind n _ _) -> elem n sig_names)
[ b | DBind b <- defs]
sig_names = [ n | DSig (Sig n _) <- defs ]
not_binds = flip filter defs $ \case DBind _ -> False
_ -> True
orderBinds :: [Bind] -> [Bind]
orderBinds binds = sortBy (on compare countUniqueCalls) binds
where
bind_names = [ n | Bind n _ _ <- binds]
countUniqueCalls :: Bind -> Int
countUniqueCalls (Bind n _ e) =
Set.size $ execState (go e) (Set.singleton n)
where
go :: Exp -> State (Set LIdent) ()
go exp = get >>= \called -> case exp of
EVar x -> when (Set.notMember x called && elem x bind_names) $
modify (Set.insert x)
EApp e1 e2 -> on (>>) go e1 e2
EAdd e1 e2 -> on (>>) go e1 e2
ELet (Bind _ _ e) e' -> on (>>) go e e'
EAbs _ e -> go e
ECase e bs -> go e >> mapM_ (\(Branch _ e) -> go e) bs
EAnn e _ -> go e
EInj _ -> pure ()
ELit _ -> pure ()

View file

@ -11,7 +11,7 @@ import Control.Applicative (Applicative (liftA2), (<|>))
import Control.Monad.Except (ExceptT, MonadError (throwError), import Control.Monad.Except (ExceptT, MonadError (throwError),
forM, runExceptT, unless, zipWithM, forM, runExceptT, unless, zipWithM,
zipWithM_) zipWithM_)
import Control.Monad.Extra (fromMaybeM) import Control.Monad.Extra (fromMaybeM, ifM)
import Control.Monad.State (MonadState, State, evalState, gets, import Control.Monad.State (MonadState, State, evalState, gets,
modify) modify)
import Data.Coerce (coerce) import Data.Coerce (coerce)
@ -57,6 +57,7 @@ data Cxt = Cxt
, binds :: Map LIdent Exp -- ^ Top-level binds x : e , binds :: Map LIdent Exp -- ^ Top-level binds x : e
, next_tevar :: Int -- ^ Counter to distinguish ά , next_tevar :: Int -- ^ Counter to distinguish ά
, data_injs :: Map UIdent Type -- ^ Data injections (constructors) K/inj : A , data_injs :: Map UIdent Type -- ^ Data injections (constructors) K/inj : A
, currentBind :: LIdent -- ^ Used for recursive functions
} deriving (Show, Eq) } deriving (Show, Eq)
newtype Tc a = Tc { runTc :: ExceptT String (State Cxt) a } newtype Tc a = Tc { runTc :: ExceptT String (State Cxt) a }
@ -77,6 +78,7 @@ initCxt defs = Cxt
| DData (Data _ injs) <- defs | DData (Data _ injs) <- defs
, Inj name t <- injs , Inj name t <- injs
] ]
, currentBind = ""
} }
where where
unboundedTVars = uncurry (Set.\\) . go (mempty, mempty) unboundedTVars = uncurry (Set.\\) . go (mempty, mempty)
@ -102,6 +104,7 @@ typecheckBinds cxt = flip evalState cxt
typecheckBind :: Bind -> Tc (T.Bind' Type) typecheckBind :: Bind -> Tc (T.Bind' Type)
typecheckBind (Bind name vars rhs) = do typecheckBind (Bind name vars rhs) = do
modify $ \cxt -> cxt { currentBind = name }
bind'@(T.Bind (name, typ) _ _) <- lookupSig name >>= \case bind'@(T.Bind (name, typ) _ _) <- lookupSig name >>= \case
Just t -> do Just t -> do
(rhs', _) <- check (foldr EAbs rhs vars) t (rhs', _) <- check (foldr EAbs rhs vars) t
@ -297,14 +300,16 @@ checkPattern patt t_patt = case patt of
infer :: Exp -> Tc (T.ExpT' Type) infer :: Exp -> Tc (T.ExpT' Type)
infer (ELit lit) = apply (T.ELit lit, litType lit) infer (ELit lit) = apply (T.ELit lit, litType lit)
-- Γ ∋ (x : A) Γ ∌ (x : A) -- Γ ∋ (x : A) Γ ⊢ rec(x)
-- ------------- Var --------------------- Var' -- ------------- Var --------------------- VarRec
-- Γ ⊢ x ↓ A ⊣ Γ Γ ⊢ x ↓ ά ⊣ Γ,(x : ά) -- Γ ⊢ x ↓ A ⊣ Γ Γ ⊢ x ↓ ά ⊣ Γ,(x : ά)
infer (EVar x) = do infer (EVar x) = do
a <- fromMaybeM extend $ liftA2 (<|>) (lookupEnv x) (lookupSig x) a <- ifM (gets $ (x==) . currentBind) varRec var
apply (T.EVar (coerce x), a) apply (T.EVar (coerce x), a)
where where
extend = do var = maybeToRightM "Can't infer" =<<
liftA2 (<|>) (lookupEnv x) (lookupSig x)
varRec = do
alpha <- TEVar <$> fresh alpha <- TEVar <$> fresh
insertEnv (EnvVar x alpha) insertEnv (EnvVar x alpha)
pure alpha pure alpha