Tried solving bug, failed, added todo message, fixed printing

This commit is contained in:
sebastianselander 2023-03-29 18:47:14 +02:00
parent 61f364cd75
commit 343be08a4a
2 changed files with 98 additions and 88 deletions

View file

@ -1,22 +1,31 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
module Renamer.Renamer (rename) where
import Auxiliary (mapAccumM)
import Control.Applicative (Applicative (liftA2))
import Control.Monad.Except (ExceptT, MonadError (throwError),
runExceptT)
import Control.Monad.State (MonadState, State, evalState, gets,
mapAndUnzipM, modify)
import Data.Function (on)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe, second)
import Grammar.Abs
import Grammar.ErrM (Err)
import Auxiliary (mapAccumM)
import Control.Applicative (Applicative (liftA2))
import Control.Monad.Except (
ExceptT,
MonadError (throwError),
runExceptT,
)
import Control.Monad.State (
MonadState,
State,
evalState,
gets,
mapAndUnzipM,
modify,
)
import Data.Function (on)
import Data.Map (Map)
import Data.Map qualified as Map
import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe, second)
import Grammar.Abs
import Grammar.ErrM (Err)
import Grammar.Print (printTree)
-- | Rename all variables and local binds
rename :: Program -> Err Program
@ -25,14 +34,14 @@ rename (Program defs) = Program <$> renameDefs defs
initCxt :: Cxt
initCxt = Cxt 0 0
data Cxt = Cxt { var_counter :: Int
, tvar_counter :: Int
}
data Cxt = Cxt
{ var_counter :: Int
, tvar_counter :: Int
}
-- | Rename monad. State holds the number of renamed names.
newtype Rn a = Rn { runRn :: ExceptT String (State Cxt) a }
deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String)
newtype Rn a = Rn {runRn :: ExceptT String (State Cxt) a}
deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String)
-- | Maps old to new name
type Names = Map String String
@ -40,67 +49,60 @@ type Names = Map String String
renameDefs :: [Def] -> Err [Def]
renameDefs defs = evalState (runExceptT (runRn $ mapM renameDef defs)) initCxt
where
initNames = Map.fromList [ dupe s | DBind (Bind (LIdent s) _ _) <- defs]
initNames = Map.fromList [dupe s | DBind (Bind (LIdent s) _ _) <- defs]
renameDef :: Def -> Rn Def
renameDef = \case
DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ
DBind (Bind name vars rhs) -> do
(new_names, vars') <- newNamesL initNames vars
rhs' <- snd <$> renameExp new_names rhs
rhs' <- snd <$> renameExp new_names rhs
pure . DBind $ Bind name vars' rhs'
DData (Data typ injs) -> do
tvars <- collectTVars [] typ
tvars' <- mapM nextNameTVar tvars
let tvars_lt = zip tvars tvars'
typ' = substituteTVar tvars_lt typ
typ' = substituteTVar tvars_lt typ
injs' = map (renameInj tvars_lt) injs
pure . DData $ Data typ' injs'
where
collectTVars tvars = \case
TAll tvar t -> collectTVars (tvar:tvars) t
TData _ _ -> pure tvars
_ -> throwError ("Bad data type definition: " ++ show typ)
TAll tvar t -> collectTVars (tvar : tvars) t
TData _ _ -> pure tvars
_ -> throwError ("Bad data type definition: " ++ printTree typ)
renameInj :: [(TVar, TVar)] -> Inj -> Inj
renameInj new_types (Inj name typ) =
Inj name $ substituteTVar new_types typ
substituteTVar :: [(TVar, TVar)] -> Type -> Type
substituteTVar new_names typ = case typ of
TLit _ -> typ
TVar tvar | Just tvar' <- lookup tvar new_names
-> TVar tvar'
| otherwise
-> typ
TVar tvar
| Just tvar' <- lookup tvar new_names ->
TVar tvar'
| otherwise ->
typ
TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t | Just tvar' <- lookup tvar new_names
-> TAll tvar' $ substitute' t
| otherwise
-> TAll tvar $ substitute' t
TAll tvar t
| Just tvar' <- lookup tvar new_names ->
TAll tvar' $ substitute' t
| otherwise ->
TAll tvar $ substitute' t
TData name typs -> TData name $ map substitute' typs
_ -> error ("Impossible " ++ show typ)
_ -> error ("Impossible " ++ show typ)
where
substitute' = substituteTVar new_names
renameExp :: Names -> Exp -> Rn (Names, Exp)
renameExp old_names = \case
EVar (LIdent n) -> pure (old_names, EVar . LIdent . fromMaybe n $ Map.lookup n old_names)
EInj (UIdent n) -> pure (old_names, EInj . UIdent . fromMaybe n $ Map.lookup n old_names)
ELit lit -> pure (old_names, ELit lit)
EVar (LIdent n) -> pure (old_names, EVar . LIdent . fromMaybe n $ Map.lookup n old_names)
EInj (UIdent n) -> pure (old_names, EInj . UIdent . fromMaybe n $ Map.lookup n old_names)
ELit lit -> pure (old_names, ELit lit)
EApp e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EApp e1' e2')
EAdd e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
@ -111,14 +113,12 @@ renameExp old_names = \case
(new_names, name') <- newNameL old_names name
(new_names', vars') <- newNamesL new_names vars
(new_names'', rhs') <- renameExp new_names' rhs
(new_names''', e') <- renameExp new_names'' e
(new_names''', e') <- renameExp new_names'' e
pure (new_names''', ELet (Bind name' vars' rhs') e')
EAbs par e -> do
EAbs par e -> do
(new_names, par') <- newNameL old_names par
(new_names', e') <- renameExp new_names e
(new_names', e') <- renameExp new_names e
pure (new_names', EAbs par' e')
EAnn e t -> do
(new_names, e') <- renameExp old_names e
t' <- renameTVars t
@ -145,8 +145,7 @@ renamePattern ns p = case p of
(ns_new, ps') <- mapAccumM renamePattern ns ps
return (ns_new, PInj cs ps')
PVar name -> second PVar <$> newNameL ns name
_ -> return (ns, p)
_ -> return (ns, p)
renameTVars :: Type -> Rn Type
renameTVars typ = case typ of
@ -157,24 +156,25 @@ renameTVars typ = case typ of
TFun t1 t2 -> liftA2 TFun (renameTVars t1) (renameTVars t2)
_ -> pure typ
substitute :: TVar -- α
-> TVar -- α_n
-> Type -- A
-> Type -- [α_n/α]A
substitute ::
TVar -> -- α
TVar -> -- α_n
Type -> -- A
Type -- [α_n/α]A
substitute tvar1 tvar2 typ = case typ of
TLit _ -> typ
TVar tvar | tvar == tvar1 -> TVar tvar2
| otherwise -> typ
TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t | tvar == tvar1 -> TAll tvar2 $ substitute' t
| otherwise -> TAll tvar $ substitute' t
TData name typs -> TData name $ map substitute' typs
_ -> error "Impossible"
TLit _ -> typ
TVar tvar
| tvar == tvar1 -> TVar tvar2
| otherwise -> typ
TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t
| tvar == tvar1 -> TAll tvar2 $ substitute' t
| otherwise -> TAll tvar $ substitute' t
TData name typs -> TData name $ map substitute' typs
_ -> error "Impossible"
where
substitute' = substitute tvar1 tvar2
-- | Create multiple names and add them to the name environment
newNamesL :: Names -> [LIdent] -> Rn (Names, [LIdent])
newNamesL = mapAccumM newNameL
@ -185,7 +185,6 @@ newNameL env (LIdent old_name) = do
new_name <- makeName old_name
pure (Map.insert old_name new_name env, LIdent new_name)
-- | Create multiple names and add them to the name environment
newNamesU :: Names -> [UIdent] -> Rn (Names, [UIdent])
newNamesU = mapAccumM newNameU
@ -196,18 +195,17 @@ newNameU env (UIdent old_name) = do
new_name <- makeName old_name
pure (Map.insert old_name new_name env, UIdent new_name)
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
makeName :: String -> Rn String
makeName prefix = do
i <- gets var_counter
let name = prefix ++ "_" ++ show i
modify $ \cxt -> cxt { var_counter = succ cxt.var_counter}
pure name
i <- gets var_counter
let name = prefix ++ "_" ++ show i
modify $ \cxt -> cxt{var_counter = succ cxt.var_counter}
pure name
nextNameTVar :: TVar -> Rn TVar
nextNameTVar (MkTVar (LIdent s))= do
i <- gets tvar_counter
let tvar = MkTVar . LIdent $ s ++ "_" ++ show i
modify $ \cxt -> cxt { tvar_counter = succ cxt.tvar_counter}
pure tvar
nextNameTVar (MkTVar (LIdent s)) = do
i <- gets tvar_counter
let tvar = MkTVar . LIdent $ s ++ "_" ++ show i
modify $ \cxt -> cxt{tvar_counter = succ cxt.tvar_counter}
pure tvar

View file

@ -27,8 +27,11 @@ import Grammar.Abs
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr qualified as T
-- TODO: Save all substition sets encountered in the program and apply
-- to all top level functions in the end.
initCtx = Ctx mempty
initEnv = Env 0 'a' mempty mempty mempty "" mempty mempty
initEnv = Env 0 'a' mempty mempty mempty "" mempty mempty mempty
run :: Infer a -> Either Error a
run = run' initEnv initCtx
@ -53,8 +56,8 @@ checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = do
preRun bs
bs <- checkDef bs
sub <- solveUndecidable
bs <- mapM (mono sub) bs
sub0 <- solveUndecidable
bs <- mapM (mono sub0) bs
return $ T.Program bs
mono :: Subst -> T.Def' Type -> Infer (T.Def' Type)
@ -74,11 +77,19 @@ preRun (x : xs) = case x of
>>= flip
when
( uncatchableErr $ Aux.do
"Duplicate signatures for function"
"Duplicate signatures of function"
quote $ printTree n
)
insertSig (coerce n) (Just t) >> preRun xs
DBind (Bind n _ e) -> do
binds <- gets declaredBinds
when
(coerce n `S.member` binds)
( uncatchableErr $ Aux.do
"Duplicate declarations of function"
quote $ printTree n
)
modify (\st -> st{declaredBinds = S.insert (coerce n) st.declaredBinds})
collect (collectTVars e)
s <- gets sigs
case M.lookup (coerce n) s of
@ -105,12 +116,12 @@ checkBind :: Bind -> Infer (T.Bind' Type)
checkBind bind@(Bind name args e) = do
setCurrentBind $ coerce name
let lambda = makeLambda e (reverse (coerce args))
(e, lambda_t) <- inferExp lambda
(sub0, (e, lambda_t)) <- inferExp lambda
s <- gets sigs
case M.lookup (coerce name) s of
Just (Just t') -> do
sub1 <- bindErr (unify lambda_t (skolemize t')) bind
return $ T.Bind (coerce name, apply sub1 t') [] (e, lambda_t)
return $ T.Bind (coerce name, apply (sub1 `compose` sub0) t') [] (e, lambda_t)
_ -> do
insertSig (coerce name) (Just lambda_t)
return (T.Bind (coerce name, lambda_t) [] (e, lambda_t))
@ -178,12 +189,12 @@ returnType :: Type -> Type
returnType (TFun _ t2) = returnType t2
returnType a = a
inferExp :: Exp -> Infer (T.ExpT' Type)
inferExp :: Exp -> Infer (Subst, T.ExpT' Type)
inferExp e = do
(s, (e', t)) <- algoW e
let subbed = apply s t
modify (\st -> st{undecidedSigs = apply s st.undecidedSigs})
return (e', subbed)
return (s, (e', subbed))
class CollectTVars a where
collectTVars :: a -> Set T.Ident
@ -851,6 +862,7 @@ data Env = Env
, currentBind :: T.Ident
, undecidedSigs :: Map T.Ident Type
, toDecide :: Set T.Ident
, declaredBinds :: Set T.Ident
}
deriving (Show)