Refactored HM to use TVar correctly, fixed unbound variable tests from

EAdd removal
This commit is contained in:
sebastian 2023-05-15 22:57:37 +02:00
parent 5000b05152
commit c96f3fc593
4 changed files with 165 additions and 148 deletions

View file

@ -47,35 +47,10 @@ typecheck = onLeft msg . run . checkPrg
checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = do
preRun bs
-- sgs <- gets sigs
bs <- checkDef bs
-- return . prettify sgs . T.Program $ bs
return . T.Program $ bs
-- | Send the map of user declared signatures to not rename stuff the user defined
prettify :: Map T.Ident (Maybe Type) -> T.Program' Type -> T.Program' Type
prettify s (T.Program defs) = T.Program $ map (go s) defs
where
go :: Map T.Ident (Maybe Type) -> T.Def' Type -> T.Def' Type
go _ (T.DData d) = T.DData d
go m b@(T.DBind (T.Bind (name, t) args (e, et)))
| Just (Just _) <- M.lookup name m = b
| otherwise =
let fvs = nub $ freeOrdered t
m = M.fromList $ zip fvs letters
in T.DBind $ T.Bind (name, replace m t) args (fmap (replace m) e, replace m et)
replace :: Map T.Ident T.Ident -> Type -> Type
replace m def@(TVar (MkTVar (LIdent a))) = case M.lookup (coerce a) m of
Just t -> TVar . MkTVar . LIdent $ coerce t
Nothing -> def
replace m (TFun t1 t2) = (TFun `on` replace m) t1 t2
replace m (TData name ts) = TData name (map (replace m) ts)
replace m def@(TAll (MkTVar forall_) t) = case M.lookup (coerce forall_) m of
Just found -> TAll (MkTVar $ coerce found) (replace m t)
Nothing -> def
replace _ t = t
preRun :: [Def] -> Infer ()
preRun [] = return ()
preRun (x : xs) = case x of
@ -484,80 +459,74 @@ inferPattern = \case
-- | Unify two types producing a new substitution
unify :: Type -> Type -> Infer Subst
unify t0 t1 =
let fvs = S.toList $ free t0 `S.union` free t1
m = M.fromList $ zip fvs letters
in case (t0, t1) of
(TFun a b, TFun c d) -> do
s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d)
return $ s2 `compose` s1
(TVar (MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t
(t@(TData _ _), TVar (MkTVar b)) -> return $ M.singleton (coerce b) t
(TVar (MkTVar a), t) -> occurs (coerce a) t
(t, TVar (MkTVar b)) -> occurs (coerce b) t
-- Forall unification should change
(TAll _ t, b) -> unify t b
(a, TAll _ t) -> unify a t
(TLit a, TLit b) ->
if a == b
then return M.empty
else catchableErr $
Aux.do
"Can not unify"
quote $ printTree (TLit a)
"with"
quote $ printTree (TLit b)
(TData name t, TData name' t') ->
if name == name' && length t == length t'
then do
xs <- zipWithM unify t t'
return $ foldr compose nullSubst xs
else catchableErr $
Aux.do
"Type constructor:"
printTree name
quote $ printTree $ map (replace m) t
"does not match with:"
printTree name'
quote $ printTree $ map (replace m) t'
(TEVar a, TEVar b) ->
if a == b
then return M.empty
else catchableErr $
Aux.do
"Can not unify"
quote $ printTree (TEVar a)
"with"
quote $ printTree (TEVar b)
(a, b) -> do
catchableErr $
Aux.do
"Can not unify"
quote $ printTree $ replace m a
"with"
quote $ printTree $ replace m b
unify t0 t1 = case (t0, t1) of
(TFun a b, TFun c d) -> do
s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d)
return $ s2 `compose` s1
(TVar a, t@(TData _ _)) -> return $ singleton a t
(t@(TData _ _), TVar b) -> return $ singleton b t
(TVar a, t) -> occurs a t
(t, TVar b) -> occurs b t
-- Forall unification should change
(TAll _ t, b) -> unify t b
(a, TAll _ t) -> unify a t
(TLit a, TLit b) ->
if a == b
then return nullSubst
else catchableErr $
Aux.do
"Can not unify"
quote $ printTree (TLit a)
"with"
quote $ printTree (TLit b)
(TData name t, TData name' t') ->
if name == name' && length t == length t'
then do
xs <- zipWithM unify t t'
return $ foldr compose nullSubst xs
else catchableErr $
Aux.do
"Type constructor:"
printTree name
quote $ printTree t
"does not match with:"
printTree name'
quote $ printTree t'
(TEVar a, TEVar b) ->
if a == b
then return nullSubst
else catchableErr $
Aux.do
"Can not unify"
quote $ printTree (TEVar a)
"with"
quote $ printTree (TEVar b)
(a, b) -> do
catchableErr $
Aux.do
"Can not unify"
quote $ printTree a
"with"
quote $ printTree b
{- | Check if a type is contained in another type.
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
where these are equal
-}
occurs :: T.Ident -> Type -> Infer Subst
occurs i t@(TEVar _) = return (M.singleton i t)
occurs i t@(TVar _) = return (M.singleton i t)
occurs i t =
let fvs = S.toList $ free t
m = M.fromList $ zip fvs letters
in if S.member i (free t)
then
catchableErr
( Aux.do
"Occurs check failed, can't unify"
quote $ printTree $ replace m (TVar $ MkTVar (coerce i))
"with"
quote $ printTree $ replace m t
)
else return $ M.singleton i t
occurs :: TVar -> Type -> Infer Subst
occurs i t@(TEVar _) = return (singleton i t)
occurs i t@(TVar _) = return (singleton i t)
occurs i t
| S.member i (free t) =
catchableErr
( Aux.do
"Occurs check failed, can't unify"
quote $ printTree (TVar i)
"with"
quote $ printTree t
)
| otherwise = return $ singleton i t
{- | Generalize a type over all free variables in the substitution set
Used for let bindings to allow expression that do not type check in
@ -568,9 +537,9 @@ occurs i t =
generalize :: Map T.Ident Type -> Type -> Type
generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t)
where
go :: [T.Ident] -> Type -> Type
go :: [TVar] -> Type -> Type
go [] t = t
go (x : xs) t = TAll (MkTVar (coerce x)) (go xs t)
go (x : xs) t = TAll x (go xs t)
removeForalls :: Type -> Type
removeForalls (TAll _ t) = removeForalls t
removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2)
@ -581,9 +550,9 @@ with fresh ones.
-}
inst :: Type -> Infer Type
inst = \case
TAll (MkTVar bound) t -> do
TAll bound t -> do
fr <- fresh
let s = M.singleton (coerce bound) fr
let s = singleton bound fr
apply s <$> inst t
TFun t1 t2 -> TFun <$> inst t1 <*> inst t2
rest -> return rest
@ -598,7 +567,7 @@ freshen :: Type -> Infer Type
freshen t = do
let frees = S.toList (free t)
xs <- mapM (const fresh) frees
let sub = M.fromList $ zip frees xs
let sub = Subst . M.fromList $ zip frees xs
return $ apply sub t
-- | Generate a new fresh variable
@ -633,10 +602,10 @@ fresh = do
go tvars t1 t2 = do
-- probably not necessary
freshies <- mapM (const fresh) tvars
let sub = M.fromList $ zip [coerce x | (MkTVar x) <- tvars] freshies
let sub = Subst . M.fromList $ zip tvars freshies
let t1' = apply sub t1
let t2' = apply sub t2
let alph = execState (alpha t1' t2') mempty
let alph = Subst $ execState (alpha t1' t2') mempty
return $ apply alph t1' == t2'
-- Pre-condition: All TAlls are outermost
@ -646,11 +615,11 @@ fresh = do
-- Alpha rename the first type's type variable to match second.
-- Pre-condition: No TAll are checked
alpha :: Type -> Type -> State (Map T.Ident Type) ()
alpha :: Type -> Type -> State (Map TVar Type) ()
alpha t1 t2 = case (t1, t2) of
(TVar (MkTVar (LIdent i)), t2) -> do
(TVar i, t2) -> do
m <- get
put (M.insert (coerce i) t2 m)
put (M.insert i t2 m)
(TFun t1 t2, TFun t3 t4) -> do
alpha t1 t3
alpha t2 t4
@ -664,16 +633,16 @@ class SubstType t where
class FreeVars t where
-- | Get all free variables from t
free :: t -> Set T.Ident
free :: t -> Set TVar
instance FreeVars (T.Bind' Type) where
free (T.Bind (_, t) _ _) = free t
instance FreeVars Type where
free :: Type -> Set T.Ident
free (TVar (MkTVar a)) = S.singleton (coerce a)
free (TAll (MkTVar bound) t) =
S.singleton (coerce bound) `S.intersection` free t
free :: Type -> Set TVar
free (TVar a) = S.singleton a
free (TAll bound t) =
S.singleton bound `S.intersection` free t
free (TLit _) = mempty
free (TFun a b) = free a `S.union` free b
free (TData _ a) = free a
@ -687,22 +656,26 @@ instance SubstType Type where
apply sub t = do
case t of
TLit _ -> t
TVar (MkTVar a) -> case M.lookup (coerce a) sub of
Nothing -> TVar (MkTVar $ coerce a)
TVar a -> case find a sub of
Nothing -> TVar a
Just t -> t
TAll (MkTVar i) t -> case M.lookup (coerce i) sub of
Nothing -> TAll (MkTVar i) (apply sub t)
TAll i t -> case find i sub of
Nothing -> TAll i (apply sub t)
Just _ -> apply sub t
TFun a b -> TFun (apply sub a) (apply sub b)
TData name a -> TData name (apply sub a)
TEVar (MkTEVar a) -> case M.lookup (coerce a) sub of
TEVar (MkTEVar a) -> case find (MkTVar a) sub of
Nothing -> TEVar (MkTEVar $ coerce a)
Just t -> t
instance FreeVars (Map T.Ident Type) where
free :: Map T.Ident Type -> Set T.Ident
free :: Map T.Ident Type -> Set TVar
free = free . M.elems
instance SubstType (Map TVar Type) where
apply :: Subst -> Map TVar Type -> Map TVar Type
apply = M.map . apply
instance SubstType (Map T.Ident Type) where
apply :: Subst -> Map T.Ident Type -> Map T.Ident Type
apply = M.map . apply
@ -759,7 +732,7 @@ nullSubst = mempty
-- | Compose two substitution sets
compose :: Subst -> Subst -> Subst
compose m1 m2 = M.map (apply m1) m2 `M.union` m1
compose m1@(Subst m1') (Subst m2) = Subst $ M.map (apply m1) m2 `M.union` m1'
-- | Compose a list of substitution sets into one
composeAll :: [Subst] -> Subst
@ -910,7 +883,21 @@ data Env = Env
data Error = Error {msg :: String, catchable :: Bool}
deriving (Show)
type Subst = Map T.Ident Type
newtype Subst = Subst {unSubst :: Map TVar Type}
deriving (Eq, Ord, Show)
singleton :: TVar -> Type -> Subst
singleton a b = Subst (M.singleton a b)
find :: TVar -> Subst -> Maybe Type
find tvar (Subst s) = M.lookup tvar s
instance Semigroup Subst where
(<>) = compose
instance Monoid Subst where
mempty = Subst mempty
newtype Warning = NonExhaustive String
deriving (Show)