Refactored HM to use TVar correctly, fixed unbound variable tests from

EAdd removal
This commit is contained in:
sebastian 2023-05-15 22:57:37 +02:00
parent 5000b05152
commit c96f3fc593
4 changed files with 165 additions and 148 deletions

View file

@ -47,35 +47,10 @@ typecheck = onLeft msg . run . checkPrg
checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = do
preRun bs
-- sgs <- gets sigs
bs <- checkDef bs
-- return . prettify sgs . T.Program $ bs
return . T.Program $ bs
-- | Send the map of user declared signatures to not rename stuff the user defined
prettify :: Map T.Ident (Maybe Type) -> T.Program' Type -> T.Program' Type
prettify s (T.Program defs) = T.Program $ map (go s) defs
where
go :: Map T.Ident (Maybe Type) -> T.Def' Type -> T.Def' Type
go _ (T.DData d) = T.DData d
go m b@(T.DBind (T.Bind (name, t) args (e, et)))
| Just (Just _) <- M.lookup name m = b
| otherwise =
let fvs = nub $ freeOrdered t
m = M.fromList $ zip fvs letters
in T.DBind $ T.Bind (name, replace m t) args (fmap (replace m) e, replace m et)
replace :: Map T.Ident T.Ident -> Type -> Type
replace m def@(TVar (MkTVar (LIdent a))) = case M.lookup (coerce a) m of
Just t -> TVar . MkTVar . LIdent $ coerce t
Nothing -> def
replace m (TFun t1 t2) = (TFun `on` replace m) t1 t2
replace m (TData name ts) = TData name (map (replace m) ts)
replace m def@(TAll (MkTVar forall_) t) = case M.lookup (coerce forall_) m of
Just found -> TAll (MkTVar $ coerce found) (replace m t)
Nothing -> def
replace _ t = t
preRun :: [Def] -> Infer ()
preRun [] = return ()
preRun (x : xs) = case x of
@ -484,24 +459,21 @@ inferPattern = \case
-- | Unify two types producing a new substitution
unify :: Type -> Type -> Infer Subst
unify t0 t1 =
let fvs = S.toList $ free t0 `S.union` free t1
m = M.fromList $ zip fvs letters
in case (t0, t1) of
unify t0 t1 = case (t0, t1) of
(TFun a b, TFun c d) -> do
s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d)
return $ s2 `compose` s1
(TVar (MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t
(t@(TData _ _), TVar (MkTVar b)) -> return $ M.singleton (coerce b) t
(TVar (MkTVar a), t) -> occurs (coerce a) t
(t, TVar (MkTVar b)) -> occurs (coerce b) t
(TVar a, t@(TData _ _)) -> return $ singleton a t
(t@(TData _ _), TVar b) -> return $ singleton b t
(TVar a, t) -> occurs a t
(t, TVar b) -> occurs b t
-- Forall unification should change
(TAll _ t, b) -> unify t b
(a, TAll _ t) -> unify a t
(TLit a, TLit b) ->
if a == b
then return M.empty
then return nullSubst
else catchableErr $
Aux.do
"Can not unify"
@ -517,13 +489,13 @@ unify t0 t1 =
Aux.do
"Type constructor:"
printTree name
quote $ printTree $ map (replace m) t
quote $ printTree t
"does not match with:"
printTree name'
quote $ printTree $ map (replace m) t'
quote $ printTree t'
(TEVar a, TEVar b) ->
if a == b
then return M.empty
then return nullSubst
else catchableErr $
Aux.do
"Can not unify"
@ -534,30 +506,27 @@ unify t0 t1 =
catchableErr $
Aux.do
"Can not unify"
quote $ printTree $ replace m a
quote $ printTree a
"with"
quote $ printTree $ replace m b
quote $ printTree b
{- | Check if a type is contained in another type.
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
where these are equal
-}
occurs :: T.Ident -> Type -> Infer Subst
occurs i t@(TEVar _) = return (M.singleton i t)
occurs i t@(TVar _) = return (M.singleton i t)
occurs i t =
let fvs = S.toList $ free t
m = M.fromList $ zip fvs letters
in if S.member i (free t)
then
occurs :: TVar -> Type -> Infer Subst
occurs i t@(TEVar _) = return (singleton i t)
occurs i t@(TVar _) = return (singleton i t)
occurs i t
| S.member i (free t) =
catchableErr
( Aux.do
"Occurs check failed, can't unify"
quote $ printTree $ replace m (TVar $ MkTVar (coerce i))
quote $ printTree (TVar i)
"with"
quote $ printTree $ replace m t
quote $ printTree t
)
else return $ M.singleton i t
| otherwise = return $ singleton i t
{- | Generalize a type over all free variables in the substitution set
Used for let bindings to allow expression that do not type check in
@ -568,9 +537,9 @@ occurs i t =
generalize :: Map T.Ident Type -> Type -> Type
generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t)
where
go :: [T.Ident] -> Type -> Type
go :: [TVar] -> Type -> Type
go [] t = t
go (x : xs) t = TAll (MkTVar (coerce x)) (go xs t)
go (x : xs) t = TAll x (go xs t)
removeForalls :: Type -> Type
removeForalls (TAll _ t) = removeForalls t
removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2)
@ -581,9 +550,9 @@ with fresh ones.
-}
inst :: Type -> Infer Type
inst = \case
TAll (MkTVar bound) t -> do
TAll bound t -> do
fr <- fresh
let s = M.singleton (coerce bound) fr
let s = singleton bound fr
apply s <$> inst t
TFun t1 t2 -> TFun <$> inst t1 <*> inst t2
rest -> return rest
@ -598,7 +567,7 @@ freshen :: Type -> Infer Type
freshen t = do
let frees = S.toList (free t)
xs <- mapM (const fresh) frees
let sub = M.fromList $ zip frees xs
let sub = Subst . M.fromList $ zip frees xs
return $ apply sub t
-- | Generate a new fresh variable
@ -633,10 +602,10 @@ fresh = do
go tvars t1 t2 = do
-- probably not necessary
freshies <- mapM (const fresh) tvars
let sub = M.fromList $ zip [coerce x | (MkTVar x) <- tvars] freshies
let sub = Subst . M.fromList $ zip tvars freshies
let t1' = apply sub t1
let t2' = apply sub t2
let alph = execState (alpha t1' t2') mempty
let alph = Subst $ execState (alpha t1' t2') mempty
return $ apply alph t1' == t2'
-- Pre-condition: All TAlls are outermost
@ -646,11 +615,11 @@ fresh = do
-- Alpha rename the first type's type variable to match second.
-- Pre-condition: No TAll are checked
alpha :: Type -> Type -> State (Map T.Ident Type) ()
alpha :: Type -> Type -> State (Map TVar Type) ()
alpha t1 t2 = case (t1, t2) of
(TVar (MkTVar (LIdent i)), t2) -> do
(TVar i, t2) -> do
m <- get
put (M.insert (coerce i) t2 m)
put (M.insert i t2 m)
(TFun t1 t2, TFun t3 t4) -> do
alpha t1 t3
alpha t2 t4
@ -664,16 +633,16 @@ class SubstType t where
class FreeVars t where
-- | Get all free variables from t
free :: t -> Set T.Ident
free :: t -> Set TVar
instance FreeVars (T.Bind' Type) where
free (T.Bind (_, t) _ _) = free t
instance FreeVars Type where
free :: Type -> Set T.Ident
free (TVar (MkTVar a)) = S.singleton (coerce a)
free (TAll (MkTVar bound) t) =
S.singleton (coerce bound) `S.intersection` free t
free :: Type -> Set TVar
free (TVar a) = S.singleton a
free (TAll bound t) =
S.singleton bound `S.intersection` free t
free (TLit _) = mempty
free (TFun a b) = free a `S.union` free b
free (TData _ a) = free a
@ -687,22 +656,26 @@ instance SubstType Type where
apply sub t = do
case t of
TLit _ -> t
TVar (MkTVar a) -> case M.lookup (coerce a) sub of
Nothing -> TVar (MkTVar $ coerce a)
TVar a -> case find a sub of
Nothing -> TVar a
Just t -> t
TAll (MkTVar i) t -> case M.lookup (coerce i) sub of
Nothing -> TAll (MkTVar i) (apply sub t)
TAll i t -> case find i sub of
Nothing -> TAll i (apply sub t)
Just _ -> apply sub t
TFun a b -> TFun (apply sub a) (apply sub b)
TData name a -> TData name (apply sub a)
TEVar (MkTEVar a) -> case M.lookup (coerce a) sub of
TEVar (MkTEVar a) -> case find (MkTVar a) sub of
Nothing -> TEVar (MkTEVar $ coerce a)
Just t -> t
instance FreeVars (Map T.Ident Type) where
free :: Map T.Ident Type -> Set T.Ident
free :: Map T.Ident Type -> Set TVar
free = free . M.elems
instance SubstType (Map TVar Type) where
apply :: Subst -> Map TVar Type -> Map TVar Type
apply = M.map . apply
instance SubstType (Map T.Ident Type) where
apply :: Subst -> Map T.Ident Type -> Map T.Ident Type
apply = M.map . apply
@ -759,7 +732,7 @@ nullSubst = mempty
-- | Compose two substitution sets
compose :: Subst -> Subst -> Subst
compose m1 m2 = M.map (apply m1) m2 `M.union` m1
compose m1@(Subst m1') (Subst m2) = Subst $ M.map (apply m1) m2 `M.union` m1'
-- | Compose a list of substitution sets into one
composeAll :: [Subst] -> Subst
@ -910,7 +883,21 @@ data Env = Env
data Error = Error {msg :: String, catchable :: Bool}
deriving (Show)
type Subst = Map T.Ident Type
newtype Subst = Subst {unSubst :: Map TVar Type}
deriving (Eq, Ord, Show)
singleton :: TVar -> Type -> Subst
singleton a b = Subst (M.singleton a b)
find :: TVar -> Subst -> Maybe Type
find tvar (Subst s) = M.lookup tvar s
instance Semigroup Subst where
(<>) = compose
instance Monoid Subst where
mempty = Subst mempty
newtype Warning = NonExhaustive String
deriving (Show)

View file

@ -76,14 +76,18 @@ rn_sig =
rn_bind1 =
specify "Rename simple bind" $
shouldSatisfyOk
"f x = (\\y. let y2 = y + 1 in y2) (x + 1)"
( unlines
[ ".+ x y = x"
, "f x = (\\y. let y2 = y + 1 in y2) (x + 1)"
]
)
rn_bind2 = specify "Rename bind with case" . shouldSatisfyOk $
D.do
"data forall a. List a where"
" Nil : List a "
" Cons : a -> List a -> List a"
".+ x y = x"
"length : forall a. List a -> Int"
"length list = case list of"
" Nil => 0"

View file

@ -22,7 +22,7 @@ import TypeChecker.RemoveForall (removeForall)
import TypeChecker.ReportTEVar (reportTEVar)
import TypeChecker.TypeChecker (TypeChecker (Bi))
import TypeChecker.TypeCheckerBidir (typecheck)
import qualified TypeChecker.TypeCheckerIr as T
import TypeChecker.TypeCheckerIr qualified as T
test = hspec testTypeCheckerBidir
@ -54,13 +54,19 @@ tc_id =
tc_double =
specify "Addition inference" $
run
["double x = x + x"]
[ ".+ : Int -> Int -> Int"
, ".+ x y = x"
, "double x = x + x"
]
`shouldSatisfy` ok
tc_add_lam =
specify "Addition lambda inference" $
run
["four = (\\x. x + x) 2"]
[ ".+ : Int -> Int -> Int"
, ".+ x y = x"
, "four = (\\x. x + x) 2"
]
`shouldSatisfy` ok
tc_const =
@ -88,6 +94,8 @@ tc_rank2 =
run
[ "const : a -> b -> a"
, "const x y = x"
, ".+ : Int -> Int -> Int"
, ".+ x y = x"
, "rank2 : a -> (forall c. c -> Int) -> b -> Int"
, "rank2 x f y = f x + f y"
, "main = rank2 3 (\\x. const 5 x : a -> Int) 'h'"
@ -195,18 +203,21 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do
-- run (fs ++ correct1) `shouldSatisfy` ok
specify "Second correct case expression accepted" $
run (fs ++ correct2) `shouldSatisfy` ok
where
-- specify "Third correct case expression accepted" $
-- run (fs ++ correct3) `shouldSatisfy` ok
-- specify "Forth correct case expression accepted" $
-- run (fs ++ correct4) `shouldSatisfy` ok
where
fs =
[ "data List a where"
, " Nil : List a"
, " Cons : a -> List a -> List a"
]
wrong1 =
[ "length : List c -> Int"
[ ".+ : Int -> Int -> Int"
, ".+ x y = x"
, "length : List c -> Int"
, "length = \\list. case list of"
, " Nil => 0"
, " Cons 6 xs => 1 + length xs"
@ -254,10 +265,10 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do
correct4 =
[ "elems : List (List c) -> Int"
, "elems = \\list. case list of"
--, " Nil => 0"
, -- , " Nil => 0"
-- , " Cons Nil Nil => 0"
-- , " Cons Nil xs => elems xs"
, " Cons (Cons _ ys) xs => 1 + elems (Cons ys xs)"
" Cons (Cons _ ys) xs => 1 + elems (Cons ys xs)"
]
tc_if = specify "Test if else case expression" $ do
@ -298,12 +309,19 @@ tc_infer_case = describe "Infer case expression" $ do
tc_rec1 =
specify "Infer simple recursive definition" $
run ["test x = 1 + test (x + 1)"] `shouldSatisfy` ok
run
[ ".+ : Int -> Int -> Int"
, ".+ x y = x"
, "test x = 1 + test (x + 1)"
]
`shouldSatisfy` ok
tc_rec2 =
specify "Infer recursive definition with pattern matching" $
run
[ "data Bool where"
[ ".+ : Int -> Int -> Int"
, ".+ x y = x"
, "data Bool where"
, " False : Bool"
, " True : Bool"
, "test = \\x. case x of"

View file

@ -57,6 +57,8 @@ goods =
, testSatisfy
"A basic arithmetic function should be able to be inferred"
( D.do
".+ : Int -> Int -> Int"
".+ x y = x"
"plusOne x = x + 1 ;"
"main x = plusOne x ;"
)
@ -74,6 +76,8 @@ goods =
, testSatisfy
"length function on int list infers correct signature"
( D.do
".+ : Int -> Int -> Int"
".+ x y = x"
"data List where "
" Nil : List"
" Cons : Int -> List -> List"
@ -114,6 +118,8 @@ bads =
, testSatisfy
"Using a concrete function (primitive type) on a skolem variable should not succeed"
( D.do
".+ : Int -> Int -> Int"
".+ x y = x"
"plusOne : Int -> Int ;"
"plusOne x = x + 1 ;"
"f : a -> Int ;"
@ -131,6 +137,8 @@ bads =
"Pattern matching on literal and _List should not succeed"
( D.do
_List
".+ : Int -> Int -> Int"
".+ x y = x"
"length : List c -> Int;"
"length _List = case _List of {"
" 0 => 0;"