Tried solving bug, failed, added todo message, fixed printing

This commit is contained in:
sebastianselander 2023-03-29 18:47:14 +02:00
parent 61f364cd75
commit 343be08a4a
2 changed files with 98 additions and 88 deletions

View file

@ -5,18 +5,27 @@ module Renamer.Renamer (rename) where
import Auxiliary (mapAccumM)
import Control.Applicative (Applicative (liftA2))
import Control.Monad.Except (ExceptT, MonadError (throwError),
runExceptT)
import Control.Monad.State (MonadState, State, evalState, gets,
mapAndUnzipM, modify)
import Control.Monad.Except (
ExceptT,
MonadError (throwError),
runExceptT,
)
import Control.Monad.State (
MonadState,
State,
evalState,
gets,
mapAndUnzipM,
modify,
)
import Data.Function (on)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Map qualified as Map
import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe, second)
import Grammar.Abs
import Grammar.ErrM (Err)
import Grammar.Print (printTree)
-- | Rename all variables and local binds
rename :: Program -> Err Program
@ -25,11 +34,11 @@ rename (Program defs) = Program <$> renameDefs defs
initCxt :: Cxt
initCxt = Cxt 0 0
data Cxt = Cxt { var_counter :: Int
data Cxt = Cxt
{ var_counter :: Int
, tvar_counter :: Int
}
-- | Rename monad. State holds the number of renamed names.
newtype Rn a = Rn {runRn :: ExceptT String (State Cxt) a}
deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String)
@ -60,47 +69,40 @@ renameDefs defs = evalState (runExceptT (runRn $ mapM renameDef defs)) initCxt
collectTVars tvars = \case
TAll tvar t -> collectTVars (tvar : tvars) t
TData _ _ -> pure tvars
_ -> throwError ("Bad data type definition: " ++ show typ)
_ -> throwError ("Bad data type definition: " ++ printTree typ)
renameInj :: [(TVar, TVar)] -> Inj -> Inj
renameInj new_types (Inj name typ) =
Inj name $ substituteTVar new_types typ
substituteTVar :: [(TVar, TVar)] -> Type -> Type
substituteTVar new_names typ = case typ of
TLit _ -> typ
TVar tvar | Just tvar' <- lookup tvar new_names
-> TVar tvar'
| otherwise
-> typ
TVar tvar
| Just tvar' <- lookup tvar new_names ->
TVar tvar'
| otherwise ->
typ
TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t | Just tvar' <- lookup tvar new_names
-> TAll tvar' $ substitute' t
| otherwise
-> TAll tvar $ substitute' t
TAll tvar t
| Just tvar' <- lookup tvar new_names ->
TAll tvar' $ substitute' t
| otherwise ->
TAll tvar $ substitute' t
TData name typs -> TData name $ map substitute' typs
_ -> error ("Impossible " ++ show typ)
where
substitute' = substituteTVar new_names
renameExp :: Names -> Exp -> Rn (Names, Exp)
renameExp old_names = \case
EVar (LIdent n) -> pure (old_names, EVar . LIdent . fromMaybe n $ Map.lookup n old_names)
EInj (UIdent n) -> pure (old_names, EInj . UIdent . fromMaybe n $ Map.lookup n old_names)
ELit lit -> pure (old_names, ELit lit)
EApp e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EApp e1' e2')
EAdd e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
@ -113,12 +115,10 @@ renameExp old_names = \case
(new_names'', rhs') <- renameExp new_names' rhs
(new_names''', e') <- renameExp new_names'' e
pure (new_names''', ELet (Bind name' vars' rhs') e')
EAbs par e -> do
(new_names, par') <- newNameL old_names par
(new_names', e') <- renameExp new_names e
pure (new_names', EAbs par' e')
EAnn e t -> do
(new_names, e') <- renameExp old_names e
t' <- renameTVars t
@ -147,7 +147,6 @@ renamePattern ns p = case p of
PVar name -> second PVar <$> newNameL ns name
_ -> return (ns, p)
renameTVars :: Type -> Rn Type
renameTVars typ = case typ of
TAll tvar t -> do
@ -157,24 +156,25 @@ renameTVars typ = case typ of
TFun t1 t2 -> liftA2 TFun (renameTVars t1) (renameTVars t2)
_ -> pure typ
substitute :: TVar -- α
-> TVar -- α_n
-> Type -- A
-> Type -- [α_n/α]A
substitute ::
TVar -> -- α
TVar -> -- α_n
Type -> -- A
Type -- [α_n/α]A
substitute tvar1 tvar2 typ = case typ of
TLit _ -> typ
TVar tvar | tvar == tvar1 -> TVar tvar2
TVar tvar
| tvar == tvar1 -> TVar tvar2
| otherwise -> typ
TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t | tvar == tvar1 -> TAll tvar2 $ substitute' t
TAll tvar t
| tvar == tvar1 -> TAll tvar2 $ substitute' t
| otherwise -> TAll tvar $ substitute' t
TData name typs -> TData name $ map substitute' typs
_ -> error "Impossible"
where
substitute' = substitute tvar1 tvar2
-- | Create multiple names and add them to the name environment
newNamesL :: Names -> [LIdent] -> Rn (Names, [LIdent])
newNamesL = mapAccumM newNameL
@ -185,7 +185,6 @@ newNameL env (LIdent old_name) = do
new_name <- makeName old_name
pure (Map.insert old_name new_name env, LIdent new_name)
-- | Create multiple names and add them to the name environment
newNamesU :: Names -> [UIdent] -> Rn (Names, [UIdent])
newNamesU = mapAccumM newNameU
@ -196,7 +195,6 @@ newNameU env (UIdent old_name) = do
new_name <- makeName old_name
pure (Map.insert old_name new_name env, UIdent new_name)
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
makeName :: String -> Rn String
makeName prefix = do

View file

@ -27,8 +27,11 @@ import Grammar.Abs
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr qualified as T
-- TODO: Save all substition sets encountered in the program and apply
-- to all top level functions in the end.
initCtx = Ctx mempty
initEnv = Env 0 'a' mempty mempty mempty "" mempty mempty
initEnv = Env 0 'a' mempty mempty mempty "" mempty mempty mempty
run :: Infer a -> Either Error a
run = run' initEnv initCtx
@ -53,8 +56,8 @@ checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = do
preRun bs
bs <- checkDef bs
sub <- solveUndecidable
bs <- mapM (mono sub) bs
sub0 <- solveUndecidable
bs <- mapM (mono sub0) bs
return $ T.Program bs
mono :: Subst -> T.Def' Type -> Infer (T.Def' Type)
@ -74,11 +77,19 @@ preRun (x : xs) = case x of
>>= flip
when
( uncatchableErr $ Aux.do
"Duplicate signatures for function"
"Duplicate signatures of function"
quote $ printTree n
)
insertSig (coerce n) (Just t) >> preRun xs
DBind (Bind n _ e) -> do
binds <- gets declaredBinds
when
(coerce n `S.member` binds)
( uncatchableErr $ Aux.do
"Duplicate declarations of function"
quote $ printTree n
)
modify (\st -> st{declaredBinds = S.insert (coerce n) st.declaredBinds})
collect (collectTVars e)
s <- gets sigs
case M.lookup (coerce n) s of
@ -105,12 +116,12 @@ checkBind :: Bind -> Infer (T.Bind' Type)
checkBind bind@(Bind name args e) = do
setCurrentBind $ coerce name
let lambda = makeLambda e (reverse (coerce args))
(e, lambda_t) <- inferExp lambda
(sub0, (e, lambda_t)) <- inferExp lambda
s <- gets sigs
case M.lookup (coerce name) s of
Just (Just t') -> do
sub1 <- bindErr (unify lambda_t (skolemize t')) bind
return $ T.Bind (coerce name, apply sub1 t') [] (e, lambda_t)
return $ T.Bind (coerce name, apply (sub1 `compose` sub0) t') [] (e, lambda_t)
_ -> do
insertSig (coerce name) (Just lambda_t)
return (T.Bind (coerce name, lambda_t) [] (e, lambda_t))
@ -178,12 +189,12 @@ returnType :: Type -> Type
returnType (TFun _ t2) = returnType t2
returnType a = a
inferExp :: Exp -> Infer (T.ExpT' Type)
inferExp :: Exp -> Infer (Subst, T.ExpT' Type)
inferExp e = do
(s, (e', t)) <- algoW e
let subbed = apply s t
modify (\st -> st{undecidedSigs = apply s st.undecidedSigs})
return (e', subbed)
return (s, (e', subbed))
class CollectTVars a where
collectTVars :: a -> Set T.Ident
@ -851,6 +862,7 @@ data Env = Env
, currentBind :: T.Ident
, undecidedSigs :: Map T.Ident Type
, toDecide :: Set T.Ident
, declaredBinds :: Set T.Ident
}
deriving (Show)