Tried solving bug, failed, added todo message, fixed printing

This commit is contained in:
sebastianselander 2023-03-29 18:47:14 +02:00
parent 61f364cd75
commit 343be08a4a
2 changed files with 98 additions and 88 deletions

View file

@ -1,22 +1,31 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-} {-# LANGUAGE OverloadedRecordDot #-}
module Renamer.Renamer (rename) where module Renamer.Renamer (rename) where
import Auxiliary (mapAccumM) import Auxiliary (mapAccumM)
import Control.Applicative (Applicative (liftA2)) import Control.Applicative (Applicative (liftA2))
import Control.Monad.Except (ExceptT, MonadError (throwError), import Control.Monad.Except (
runExceptT) ExceptT,
import Control.Monad.State (MonadState, State, evalState, gets, MonadError (throwError),
mapAndUnzipM, modify) runExceptT,
import Data.Function (on) )
import Data.Map (Map) import Control.Monad.State (
import qualified Data.Map as Map MonadState,
import Data.Maybe (fromMaybe) State,
import Data.Tuple.Extra (dupe, second) evalState,
import Grammar.Abs gets,
import Grammar.ErrM (Err) mapAndUnzipM,
modify,
)
import Data.Function (on)
import Data.Map (Map)
import Data.Map qualified as Map
import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe, second)
import Grammar.Abs
import Grammar.ErrM (Err)
import Grammar.Print (printTree)
-- | Rename all variables and local binds -- | Rename all variables and local binds
rename :: Program -> Err Program rename :: Program -> Err Program
@ -25,14 +34,14 @@ rename (Program defs) = Program <$> renameDefs defs
initCxt :: Cxt initCxt :: Cxt
initCxt = Cxt 0 0 initCxt = Cxt 0 0
data Cxt = Cxt { var_counter :: Int data Cxt = Cxt
, tvar_counter :: Int { var_counter :: Int
} , tvar_counter :: Int
}
-- | Rename monad. State holds the number of renamed names. -- | Rename monad. State holds the number of renamed names.
newtype Rn a = Rn { runRn :: ExceptT String (State Cxt) a } newtype Rn a = Rn {runRn :: ExceptT String (State Cxt) a}
deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String) deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String)
-- | Maps old to new name -- | Maps old to new name
type Names = Map String String type Names = Map String String
@ -40,67 +49,60 @@ type Names = Map String String
renameDefs :: [Def] -> Err [Def] renameDefs :: [Def] -> Err [Def]
renameDefs defs = evalState (runExceptT (runRn $ mapM renameDef defs)) initCxt renameDefs defs = evalState (runExceptT (runRn $ mapM renameDef defs)) initCxt
where where
initNames = Map.fromList [ dupe s | DBind (Bind (LIdent s) _ _) <- defs] initNames = Map.fromList [dupe s | DBind (Bind (LIdent s) _ _) <- defs]
renameDef :: Def -> Rn Def renameDef :: Def -> Rn Def
renameDef = \case renameDef = \case
DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ
DBind (Bind name vars rhs) -> do DBind (Bind name vars rhs) -> do
(new_names, vars') <- newNamesL initNames vars (new_names, vars') <- newNamesL initNames vars
rhs' <- snd <$> renameExp new_names rhs rhs' <- snd <$> renameExp new_names rhs
pure . DBind $ Bind name vars' rhs' pure . DBind $ Bind name vars' rhs'
DData (Data typ injs) -> do DData (Data typ injs) -> do
tvars <- collectTVars [] typ tvars <- collectTVars [] typ
tvars' <- mapM nextNameTVar tvars tvars' <- mapM nextNameTVar tvars
let tvars_lt = zip tvars tvars' let tvars_lt = zip tvars tvars'
typ' = substituteTVar tvars_lt typ typ' = substituteTVar tvars_lt typ
injs' = map (renameInj tvars_lt) injs injs' = map (renameInj tvars_lt) injs
pure . DData $ Data typ' injs' pure . DData $ Data typ' injs'
where where
collectTVars tvars = \case collectTVars tvars = \case
TAll tvar t -> collectTVars (tvar:tvars) t TAll tvar t -> collectTVars (tvar : tvars) t
TData _ _ -> pure tvars TData _ _ -> pure tvars
_ -> throwError ("Bad data type definition: " ++ show typ) _ -> throwError ("Bad data type definition: " ++ printTree typ)
renameInj :: [(TVar, TVar)] -> Inj -> Inj renameInj :: [(TVar, TVar)] -> Inj -> Inj
renameInj new_types (Inj name typ) = renameInj new_types (Inj name typ) =
Inj name $ substituteTVar new_types typ Inj name $ substituteTVar new_types typ
substituteTVar :: [(TVar, TVar)] -> Type -> Type substituteTVar :: [(TVar, TVar)] -> Type -> Type
substituteTVar new_names typ = case typ of substituteTVar new_names typ = case typ of
TLit _ -> typ TLit _ -> typ
TVar tvar
TVar tvar | Just tvar' <- lookup tvar new_names | Just tvar' <- lookup tvar new_names ->
-> TVar tvar' TVar tvar'
| otherwise | otherwise ->
-> typ typ
TFun t1 t2 -> on TFun substitute' t1 t2 TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t
TAll tvar t | Just tvar' <- lookup tvar new_names | Just tvar' <- lookup tvar new_names ->
-> TAll tvar' $ substitute' t TAll tvar' $ substitute' t
| otherwise | otherwise ->
-> TAll tvar $ substitute' t TAll tvar $ substitute' t
TData name typs -> TData name $ map substitute' typs TData name typs -> TData name $ map substitute' typs
_ -> error ("Impossible " ++ show typ) _ -> error ("Impossible " ++ show typ)
where where
substitute' = substituteTVar new_names substitute' = substituteTVar new_names
renameExp :: Names -> Exp -> Rn (Names, Exp) renameExp :: Names -> Exp -> Rn (Names, Exp)
renameExp old_names = \case renameExp old_names = \case
EVar (LIdent n) -> pure (old_names, EVar . LIdent . fromMaybe n $ Map.lookup n old_names) EVar (LIdent n) -> pure (old_names, EVar . LIdent . fromMaybe n $ Map.lookup n old_names)
EInj (UIdent n) -> pure (old_names, EInj . UIdent . fromMaybe n $ Map.lookup n old_names) EInj (UIdent n) -> pure (old_names, EInj . UIdent . fromMaybe n $ Map.lookup n old_names)
ELit lit -> pure (old_names, ELit lit)
ELit lit -> pure (old_names, ELit lit)
EApp e1 e2 -> do EApp e1 e2 -> do
(env1, e1') <- renameExp old_names e1 (env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2 (env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EApp e1' e2') pure (Map.union env1 env2, EApp e1' e2')
EAdd e1 e2 -> do EAdd e1 e2 -> do
(env1, e1') <- renameExp old_names e1 (env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2 (env2, e2') <- renameExp old_names e2
@ -111,14 +113,12 @@ renameExp old_names = \case
(new_names, name') <- newNameL old_names name (new_names, name') <- newNameL old_names name
(new_names', vars') <- newNamesL new_names vars (new_names', vars') <- newNamesL new_names vars
(new_names'', rhs') <- renameExp new_names' rhs (new_names'', rhs') <- renameExp new_names' rhs
(new_names''', e') <- renameExp new_names'' e (new_names''', e') <- renameExp new_names'' e
pure (new_names''', ELet (Bind name' vars' rhs') e') pure (new_names''', ELet (Bind name' vars' rhs') e')
EAbs par e -> do
EAbs par e -> do
(new_names, par') <- newNameL old_names par (new_names, par') <- newNameL old_names par
(new_names', e') <- renameExp new_names e (new_names', e') <- renameExp new_names e
pure (new_names', EAbs par' e') pure (new_names', EAbs par' e')
EAnn e t -> do EAnn e t -> do
(new_names, e') <- renameExp old_names e (new_names, e') <- renameExp old_names e
t' <- renameTVars t t' <- renameTVars t
@ -145,8 +145,7 @@ renamePattern ns p = case p of
(ns_new, ps') <- mapAccumM renamePattern ns ps (ns_new, ps') <- mapAccumM renamePattern ns ps
return (ns_new, PInj cs ps') return (ns_new, PInj cs ps')
PVar name -> second PVar <$> newNameL ns name PVar name -> second PVar <$> newNameL ns name
_ -> return (ns, p) _ -> return (ns, p)
renameTVars :: Type -> Rn Type renameTVars :: Type -> Rn Type
renameTVars typ = case typ of renameTVars typ = case typ of
@ -157,24 +156,25 @@ renameTVars typ = case typ of
TFun t1 t2 -> liftA2 TFun (renameTVars t1) (renameTVars t2) TFun t1 t2 -> liftA2 TFun (renameTVars t1) (renameTVars t2)
_ -> pure typ _ -> pure typ
substitute :: TVar -- α substitute ::
-> TVar -- α_n TVar -> -- α
-> Type -- A TVar -> -- α_n
-> Type -- [α_n/α]A Type -> -- A
Type -- [α_n/α]A
substitute tvar1 tvar2 typ = case typ of substitute tvar1 tvar2 typ = case typ of
TLit _ -> typ TLit _ -> typ
TVar tvar | tvar == tvar1 -> TVar tvar2 TVar tvar
| otherwise -> typ | tvar == tvar1 -> TVar tvar2
TFun t1 t2 -> on TFun substitute' t1 t2 | otherwise -> typ
TAll tvar t | tvar == tvar1 -> TAll tvar2 $ substitute' t TFun t1 t2 -> on TFun substitute' t1 t2
| otherwise -> TAll tvar $ substitute' t TAll tvar t
TData name typs -> TData name $ map substitute' typs | tvar == tvar1 -> TAll tvar2 $ substitute' t
_ -> error "Impossible" | otherwise -> TAll tvar $ substitute' t
TData name typs -> TData name $ map substitute' typs
_ -> error "Impossible"
where where
substitute' = substitute tvar1 tvar2 substitute' = substitute tvar1 tvar2
-- | Create multiple names and add them to the name environment -- | Create multiple names and add them to the name environment
newNamesL :: Names -> [LIdent] -> Rn (Names, [LIdent]) newNamesL :: Names -> [LIdent] -> Rn (Names, [LIdent])
newNamesL = mapAccumM newNameL newNamesL = mapAccumM newNameL
@ -185,7 +185,6 @@ newNameL env (LIdent old_name) = do
new_name <- makeName old_name new_name <- makeName old_name
pure (Map.insert old_name new_name env, LIdent new_name) pure (Map.insert old_name new_name env, LIdent new_name)
-- | Create multiple names and add them to the name environment -- | Create multiple names and add them to the name environment
newNamesU :: Names -> [UIdent] -> Rn (Names, [UIdent]) newNamesU :: Names -> [UIdent] -> Rn (Names, [UIdent])
newNamesU = mapAccumM newNameU newNamesU = mapAccumM newNameU
@ -196,18 +195,17 @@ newNameU env (UIdent old_name) = do
new_name <- makeName old_name new_name <- makeName old_name
pure (Map.insert old_name new_name env, UIdent new_name) pure (Map.insert old_name new_name env, UIdent new_name)
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@. -- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
makeName :: String -> Rn String makeName :: String -> Rn String
makeName prefix = do makeName prefix = do
i <- gets var_counter i <- gets var_counter
let name = prefix ++ "_" ++ show i let name = prefix ++ "_" ++ show i
modify $ \cxt -> cxt { var_counter = succ cxt.var_counter} modify $ \cxt -> cxt{var_counter = succ cxt.var_counter}
pure name pure name
nextNameTVar :: TVar -> Rn TVar nextNameTVar :: TVar -> Rn TVar
nextNameTVar (MkTVar (LIdent s))= do nextNameTVar (MkTVar (LIdent s)) = do
i <- gets tvar_counter i <- gets tvar_counter
let tvar = MkTVar . LIdent $ s ++ "_" ++ show i let tvar = MkTVar . LIdent $ s ++ "_" ++ show i
modify $ \cxt -> cxt { tvar_counter = succ cxt.tvar_counter} modify $ \cxt -> cxt{tvar_counter = succ cxt.tvar_counter}
pure tvar pure tvar

View file

@ -27,8 +27,11 @@ import Grammar.Abs
import Grammar.Print (printTree) import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr qualified as T import TypeChecker.TypeCheckerIr qualified as T
-- TODO: Save all substition sets encountered in the program and apply
-- to all top level functions in the end.
initCtx = Ctx mempty initCtx = Ctx mempty
initEnv = Env 0 'a' mempty mempty mempty "" mempty mempty initEnv = Env 0 'a' mempty mempty mempty "" mempty mempty mempty
run :: Infer a -> Either Error a run :: Infer a -> Either Error a
run = run' initEnv initCtx run = run' initEnv initCtx
@ -53,8 +56,8 @@ checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = do checkPrg (Program bs) = do
preRun bs preRun bs
bs <- checkDef bs bs <- checkDef bs
sub <- solveUndecidable sub0 <- solveUndecidable
bs <- mapM (mono sub) bs bs <- mapM (mono sub0) bs
return $ T.Program bs return $ T.Program bs
mono :: Subst -> T.Def' Type -> Infer (T.Def' Type) mono :: Subst -> T.Def' Type -> Infer (T.Def' Type)
@ -74,11 +77,19 @@ preRun (x : xs) = case x of
>>= flip >>= flip
when when
( uncatchableErr $ Aux.do ( uncatchableErr $ Aux.do
"Duplicate signatures for function" "Duplicate signatures of function"
quote $ printTree n quote $ printTree n
) )
insertSig (coerce n) (Just t) >> preRun xs insertSig (coerce n) (Just t) >> preRun xs
DBind (Bind n _ e) -> do DBind (Bind n _ e) -> do
binds <- gets declaredBinds
when
(coerce n `S.member` binds)
( uncatchableErr $ Aux.do
"Duplicate declarations of function"
quote $ printTree n
)
modify (\st -> st{declaredBinds = S.insert (coerce n) st.declaredBinds})
collect (collectTVars e) collect (collectTVars e)
s <- gets sigs s <- gets sigs
case M.lookup (coerce n) s of case M.lookup (coerce n) s of
@ -105,12 +116,12 @@ checkBind :: Bind -> Infer (T.Bind' Type)
checkBind bind@(Bind name args e) = do checkBind bind@(Bind name args e) = do
setCurrentBind $ coerce name setCurrentBind $ coerce name
let lambda = makeLambda e (reverse (coerce args)) let lambda = makeLambda e (reverse (coerce args))
(e, lambda_t) <- inferExp lambda (sub0, (e, lambda_t)) <- inferExp lambda
s <- gets sigs s <- gets sigs
case M.lookup (coerce name) s of case M.lookup (coerce name) s of
Just (Just t') -> do Just (Just t') -> do
sub1 <- bindErr (unify lambda_t (skolemize t')) bind sub1 <- bindErr (unify lambda_t (skolemize t')) bind
return $ T.Bind (coerce name, apply sub1 t') [] (e, lambda_t) return $ T.Bind (coerce name, apply (sub1 `compose` sub0) t') [] (e, lambda_t)
_ -> do _ -> do
insertSig (coerce name) (Just lambda_t) insertSig (coerce name) (Just lambda_t)
return (T.Bind (coerce name, lambda_t) [] (e, lambda_t)) return (T.Bind (coerce name, lambda_t) [] (e, lambda_t))
@ -178,12 +189,12 @@ returnType :: Type -> Type
returnType (TFun _ t2) = returnType t2 returnType (TFun _ t2) = returnType t2
returnType a = a returnType a = a
inferExp :: Exp -> Infer (T.ExpT' Type) inferExp :: Exp -> Infer (Subst, T.ExpT' Type)
inferExp e = do inferExp e = do
(s, (e', t)) <- algoW e (s, (e', t)) <- algoW e
let subbed = apply s t let subbed = apply s t
modify (\st -> st{undecidedSigs = apply s st.undecidedSigs}) modify (\st -> st{undecidedSigs = apply s st.undecidedSigs})
return (e', subbed) return (s, (e', subbed))
class CollectTVars a where class CollectTVars a where
collectTVars :: a -> Set T.Ident collectTVars :: a -> Set T.Ident
@ -851,6 +862,7 @@ data Env = Env
, currentBind :: T.Ident , currentBind :: T.Ident
, undecidedSigs :: Map T.Ident Type , undecidedSigs :: Map T.Ident Type
, toDecide :: Set T.Ident , toDecide :: Set T.Ident
, declaredBinds :: Set T.Ident
} }
deriving (Show) deriving (Show)