Tried solving bug, failed, added todo message, fixed printing

This commit is contained in:
sebastianselander 2023-03-29 18:47:14 +02:00
parent 61f364cd75
commit 343be08a4a
2 changed files with 98 additions and 88 deletions

View file

@ -5,18 +5,27 @@ module Renamer.Renamer (rename) where
import Auxiliary (mapAccumM) import Auxiliary (mapAccumM)
import Control.Applicative (Applicative (liftA2)) import Control.Applicative (Applicative (liftA2))
import Control.Monad.Except (ExceptT, MonadError (throwError), import Control.Monad.Except (
runExceptT) ExceptT,
import Control.Monad.State (MonadState, State, evalState, gets, MonadError (throwError),
mapAndUnzipM, modify) runExceptT,
)
import Control.Monad.State (
MonadState,
State,
evalState,
gets,
mapAndUnzipM,
modify,
)
import Data.Function (on) import Data.Function (on)
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as Map import Data.Map qualified as Map
import Data.Maybe (fromMaybe) import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe, second) import Data.Tuple.Extra (dupe, second)
import Grammar.Abs import Grammar.Abs
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
import Grammar.Print (printTree)
-- | Rename all variables and local binds -- | Rename all variables and local binds
rename :: Program -> Err Program rename :: Program -> Err Program
@ -25,11 +34,11 @@ rename (Program defs) = Program <$> renameDefs defs
initCxt :: Cxt initCxt :: Cxt
initCxt = Cxt 0 0 initCxt = Cxt 0 0
data Cxt = Cxt { var_counter :: Int data Cxt = Cxt
{ var_counter :: Int
, tvar_counter :: Int , tvar_counter :: Int
} }
-- | Rename monad. State holds the number of renamed names. -- | Rename monad. State holds the number of renamed names.
newtype Rn a = Rn {runRn :: ExceptT String (State Cxt) a} newtype Rn a = Rn {runRn :: ExceptT String (State Cxt) a}
deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String) deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String)
@ -60,47 +69,40 @@ renameDefs defs = evalState (runExceptT (runRn $ mapM renameDef defs)) initCxt
collectTVars tvars = \case collectTVars tvars = \case
TAll tvar t -> collectTVars (tvar : tvars) t TAll tvar t -> collectTVars (tvar : tvars) t
TData _ _ -> pure tvars TData _ _ -> pure tvars
_ -> throwError ("Bad data type definition: " ++ show typ) _ -> throwError ("Bad data type definition: " ++ printTree typ)
renameInj :: [(TVar, TVar)] -> Inj -> Inj renameInj :: [(TVar, TVar)] -> Inj -> Inj
renameInj new_types (Inj name typ) = renameInj new_types (Inj name typ) =
Inj name $ substituteTVar new_types typ Inj name $ substituteTVar new_types typ
substituteTVar :: [(TVar, TVar)] -> Type -> Type substituteTVar :: [(TVar, TVar)] -> Type -> Type
substituteTVar new_names typ = case typ of substituteTVar new_names typ = case typ of
TLit _ -> typ TLit _ -> typ
TVar tvar
TVar tvar | Just tvar' <- lookup tvar new_names | Just tvar' <- lookup tvar new_names ->
-> TVar tvar' TVar tvar'
| otherwise | otherwise ->
-> typ typ
TFun t1 t2 -> on TFun substitute' t1 t2 TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t
TAll tvar t | Just tvar' <- lookup tvar new_names | Just tvar' <- lookup tvar new_names ->
-> TAll tvar' $ substitute' t TAll tvar' $ substitute' t
| otherwise | otherwise ->
-> TAll tvar $ substitute' t TAll tvar $ substitute' t
TData name typs -> TData name $ map substitute' typs TData name typs -> TData name $ map substitute' typs
_ -> error ("Impossible " ++ show typ) _ -> error ("Impossible " ++ show typ)
where where
substitute' = substituteTVar new_names substitute' = substituteTVar new_names
renameExp :: Names -> Exp -> Rn (Names, Exp) renameExp :: Names -> Exp -> Rn (Names, Exp)
renameExp old_names = \case renameExp old_names = \case
EVar (LIdent n) -> pure (old_names, EVar . LIdent . fromMaybe n $ Map.lookup n old_names) EVar (LIdent n) -> pure (old_names, EVar . LIdent . fromMaybe n $ Map.lookup n old_names)
EInj (UIdent n) -> pure (old_names, EInj . UIdent . fromMaybe n $ Map.lookup n old_names) EInj (UIdent n) -> pure (old_names, EInj . UIdent . fromMaybe n $ Map.lookup n old_names)
ELit lit -> pure (old_names, ELit lit) ELit lit -> pure (old_names, ELit lit)
EApp e1 e2 -> do EApp e1 e2 -> do
(env1, e1') <- renameExp old_names e1 (env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2 (env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EApp e1' e2') pure (Map.union env1 env2, EApp e1' e2')
EAdd e1 e2 -> do EAdd e1 e2 -> do
(env1, e1') <- renameExp old_names e1 (env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2 (env2, e2') <- renameExp old_names e2
@ -113,12 +115,10 @@ renameExp old_names = \case
(new_names'', rhs') <- renameExp new_names' rhs (new_names'', rhs') <- renameExp new_names' rhs
(new_names''', e') <- renameExp new_names'' e (new_names''', e') <- renameExp new_names'' e
pure (new_names''', ELet (Bind name' vars' rhs') e') pure (new_names''', ELet (Bind name' vars' rhs') e')
EAbs par e -> do EAbs par e -> do
(new_names, par') <- newNameL old_names par (new_names, par') <- newNameL old_names par
(new_names', e') <- renameExp new_names e (new_names', e') <- renameExp new_names e
pure (new_names', EAbs par' e') pure (new_names', EAbs par' e')
EAnn e t -> do EAnn e t -> do
(new_names, e') <- renameExp old_names e (new_names, e') <- renameExp old_names e
t' <- renameTVars t t' <- renameTVars t
@ -147,7 +147,6 @@ renamePattern ns p = case p of
PVar name -> second PVar <$> newNameL ns name PVar name -> second PVar <$> newNameL ns name
_ -> return (ns, p) _ -> return (ns, p)
renameTVars :: Type -> Rn Type renameTVars :: Type -> Rn Type
renameTVars typ = case typ of renameTVars typ = case typ of
TAll tvar t -> do TAll tvar t -> do
@ -157,24 +156,25 @@ renameTVars typ = case typ of
TFun t1 t2 -> liftA2 TFun (renameTVars t1) (renameTVars t2) TFun t1 t2 -> liftA2 TFun (renameTVars t1) (renameTVars t2)
_ -> pure typ _ -> pure typ
substitute :: TVar -- α substitute ::
-> TVar -- α_n TVar -> -- α
-> Type -- A TVar -> -- α_n
-> Type -- [α_n/α]A Type -> -- A
Type -- [α_n/α]A
substitute tvar1 tvar2 typ = case typ of substitute tvar1 tvar2 typ = case typ of
TLit _ -> typ TLit _ -> typ
TVar tvar | tvar == tvar1 -> TVar tvar2 TVar tvar
| tvar == tvar1 -> TVar tvar2
| otherwise -> typ | otherwise -> typ
TFun t1 t2 -> on TFun substitute' t1 t2 TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t | tvar == tvar1 -> TAll tvar2 $ substitute' t TAll tvar t
| tvar == tvar1 -> TAll tvar2 $ substitute' t
| otherwise -> TAll tvar $ substitute' t | otherwise -> TAll tvar $ substitute' t
TData name typs -> TData name $ map substitute' typs TData name typs -> TData name $ map substitute' typs
_ -> error "Impossible" _ -> error "Impossible"
where where
substitute' = substitute tvar1 tvar2 substitute' = substitute tvar1 tvar2
-- | Create multiple names and add them to the name environment -- | Create multiple names and add them to the name environment
newNamesL :: Names -> [LIdent] -> Rn (Names, [LIdent]) newNamesL :: Names -> [LIdent] -> Rn (Names, [LIdent])
newNamesL = mapAccumM newNameL newNamesL = mapAccumM newNameL
@ -185,7 +185,6 @@ newNameL env (LIdent old_name) = do
new_name <- makeName old_name new_name <- makeName old_name
pure (Map.insert old_name new_name env, LIdent new_name) pure (Map.insert old_name new_name env, LIdent new_name)
-- | Create multiple names and add them to the name environment -- | Create multiple names and add them to the name environment
newNamesU :: Names -> [UIdent] -> Rn (Names, [UIdent]) newNamesU :: Names -> [UIdent] -> Rn (Names, [UIdent])
newNamesU = mapAccumM newNameU newNamesU = mapAccumM newNameU
@ -196,7 +195,6 @@ newNameU env (UIdent old_name) = do
new_name <- makeName old_name new_name <- makeName old_name
pure (Map.insert old_name new_name env, UIdent new_name) pure (Map.insert old_name new_name env, UIdent new_name)
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@. -- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
makeName :: String -> Rn String makeName :: String -> Rn String
makeName prefix = do makeName prefix = do

View file

@ -27,8 +27,11 @@ import Grammar.Abs
import Grammar.Print (printTree) import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr qualified as T import TypeChecker.TypeCheckerIr qualified as T
-- TODO: Save all substition sets encountered in the program and apply
-- to all top level functions in the end.
initCtx = Ctx mempty initCtx = Ctx mempty
initEnv = Env 0 'a' mempty mempty mempty "" mempty mempty initEnv = Env 0 'a' mempty mempty mempty "" mempty mempty mempty
run :: Infer a -> Either Error a run :: Infer a -> Either Error a
run = run' initEnv initCtx run = run' initEnv initCtx
@ -53,8 +56,8 @@ checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = do checkPrg (Program bs) = do
preRun bs preRun bs
bs <- checkDef bs bs <- checkDef bs
sub <- solveUndecidable sub0 <- solveUndecidable
bs <- mapM (mono sub) bs bs <- mapM (mono sub0) bs
return $ T.Program bs return $ T.Program bs
mono :: Subst -> T.Def' Type -> Infer (T.Def' Type) mono :: Subst -> T.Def' Type -> Infer (T.Def' Type)
@ -74,11 +77,19 @@ preRun (x : xs) = case x of
>>= flip >>= flip
when when
( uncatchableErr $ Aux.do ( uncatchableErr $ Aux.do
"Duplicate signatures for function" "Duplicate signatures of function"
quote $ printTree n quote $ printTree n
) )
insertSig (coerce n) (Just t) >> preRun xs insertSig (coerce n) (Just t) >> preRun xs
DBind (Bind n _ e) -> do DBind (Bind n _ e) -> do
binds <- gets declaredBinds
when
(coerce n `S.member` binds)
( uncatchableErr $ Aux.do
"Duplicate declarations of function"
quote $ printTree n
)
modify (\st -> st{declaredBinds = S.insert (coerce n) st.declaredBinds})
collect (collectTVars e) collect (collectTVars e)
s <- gets sigs s <- gets sigs
case M.lookup (coerce n) s of case M.lookup (coerce n) s of
@ -105,12 +116,12 @@ checkBind :: Bind -> Infer (T.Bind' Type)
checkBind bind@(Bind name args e) = do checkBind bind@(Bind name args e) = do
setCurrentBind $ coerce name setCurrentBind $ coerce name
let lambda = makeLambda e (reverse (coerce args)) let lambda = makeLambda e (reverse (coerce args))
(e, lambda_t) <- inferExp lambda (sub0, (e, lambda_t)) <- inferExp lambda
s <- gets sigs s <- gets sigs
case M.lookup (coerce name) s of case M.lookup (coerce name) s of
Just (Just t') -> do Just (Just t') -> do
sub1 <- bindErr (unify lambda_t (skolemize t')) bind sub1 <- bindErr (unify lambda_t (skolemize t')) bind
return $ T.Bind (coerce name, apply sub1 t') [] (e, lambda_t) return $ T.Bind (coerce name, apply (sub1 `compose` sub0) t') [] (e, lambda_t)
_ -> do _ -> do
insertSig (coerce name) (Just lambda_t) insertSig (coerce name) (Just lambda_t)
return (T.Bind (coerce name, lambda_t) [] (e, lambda_t)) return (T.Bind (coerce name, lambda_t) [] (e, lambda_t))
@ -178,12 +189,12 @@ returnType :: Type -> Type
returnType (TFun _ t2) = returnType t2 returnType (TFun _ t2) = returnType t2
returnType a = a returnType a = a
inferExp :: Exp -> Infer (T.ExpT' Type) inferExp :: Exp -> Infer (Subst, T.ExpT' Type)
inferExp e = do inferExp e = do
(s, (e', t)) <- algoW e (s, (e', t)) <- algoW e
let subbed = apply s t let subbed = apply s t
modify (\st -> st{undecidedSigs = apply s st.undecidedSigs}) modify (\st -> st{undecidedSigs = apply s st.undecidedSigs})
return (e', subbed) return (s, (e', subbed))
class CollectTVars a where class CollectTVars a where
collectTVars :: a -> Set T.Ident collectTVars :: a -> Set T.Ident
@ -851,6 +862,7 @@ data Env = Env
, currentBind :: T.Ident , currentBind :: T.Ident
, undecidedSigs :: Map T.Ident Type , undecidedSigs :: Map T.Ident Type
, toDecide :: Set T.Ident , toDecide :: Set T.Ident
, declaredBinds :: Set T.Ident
} }
deriving (Show) deriving (Show)