Refactored HM to use TVar correctly, fixed unbound variable tests from

EAdd removal
This commit is contained in:
sebastian 2023-05-15 22:57:37 +02:00
parent 5000b05152
commit c96f3fc593
4 changed files with 165 additions and 148 deletions

View file

@ -47,35 +47,10 @@ typecheck = onLeft msg . run . checkPrg
checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = do
preRun bs
-- sgs <- gets sigs
bs <- checkDef bs
-- return . prettify sgs . T.Program $ bs
return . T.Program $ bs
-- | Send the map of user declared signatures to not rename stuff the user defined
prettify :: Map T.Ident (Maybe Type) -> T.Program' Type -> T.Program' Type
prettify s (T.Program defs) = T.Program $ map (go s) defs
where
go :: Map T.Ident (Maybe Type) -> T.Def' Type -> T.Def' Type
go _ (T.DData d) = T.DData d
go m b@(T.DBind (T.Bind (name, t) args (e, et)))
| Just (Just _) <- M.lookup name m = b
| otherwise =
let fvs = nub $ freeOrdered t
m = M.fromList $ zip fvs letters
in T.DBind $ T.Bind (name, replace m t) args (fmap (replace m) e, replace m et)
replace :: Map T.Ident T.Ident -> Type -> Type
replace m def@(TVar (MkTVar (LIdent a))) = case M.lookup (coerce a) m of
Just t -> TVar . MkTVar . LIdent $ coerce t
Nothing -> def
replace m (TFun t1 t2) = (TFun `on` replace m) t1 t2
replace m (TData name ts) = TData name (map (replace m) ts)
replace m def@(TAll (MkTVar forall_) t) = case M.lookup (coerce forall_) m of
Just found -> TAll (MkTVar $ coerce found) (replace m t)
Nothing -> def
replace _ t = t
preRun :: [Def] -> Infer ()
preRun [] = return ()
preRun (x : xs) = case x of
@ -484,80 +459,74 @@ inferPattern = \case
-- | Unify two types producing a new substitution
unify :: Type -> Type -> Infer Subst
unify t0 t1 =
let fvs = S.toList $ free t0 `S.union` free t1
m = M.fromList $ zip fvs letters
in case (t0, t1) of
(TFun a b, TFun c d) -> do
s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d)
return $ s2 `compose` s1
(TVar (MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t
(t@(TData _ _), TVar (MkTVar b)) -> return $ M.singleton (coerce b) t
(TVar (MkTVar a), t) -> occurs (coerce a) t
(t, TVar (MkTVar b)) -> occurs (coerce b) t
-- Forall unification should change
(TAll _ t, b) -> unify t b
(a, TAll _ t) -> unify a t
(TLit a, TLit b) ->
if a == b
then return M.empty
else catchableErr $
Aux.do
"Can not unify"
quote $ printTree (TLit a)
"with"
quote $ printTree (TLit b)
(TData name t, TData name' t') ->
if name == name' && length t == length t'
then do
xs <- zipWithM unify t t'
return $ foldr compose nullSubst xs
else catchableErr $
Aux.do
"Type constructor:"
printTree name
quote $ printTree $ map (replace m) t
"does not match with:"
printTree name'
quote $ printTree $ map (replace m) t'
(TEVar a, TEVar b) ->
if a == b
then return M.empty
else catchableErr $
Aux.do
"Can not unify"
quote $ printTree (TEVar a)
"with"
quote $ printTree (TEVar b)
(a, b) -> do
catchableErr $
Aux.do
"Can not unify"
quote $ printTree $ replace m a
"with"
quote $ printTree $ replace m b
unify t0 t1 = case (t0, t1) of
(TFun a b, TFun c d) -> do
s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d)
return $ s2 `compose` s1
(TVar a, t@(TData _ _)) -> return $ singleton a t
(t@(TData _ _), TVar b) -> return $ singleton b t
(TVar a, t) -> occurs a t
(t, TVar b) -> occurs b t
-- Forall unification should change
(TAll _ t, b) -> unify t b
(a, TAll _ t) -> unify a t
(TLit a, TLit b) ->
if a == b
then return nullSubst
else catchableErr $
Aux.do
"Can not unify"
quote $ printTree (TLit a)
"with"
quote $ printTree (TLit b)
(TData name t, TData name' t') ->
if name == name' && length t == length t'
then do
xs <- zipWithM unify t t'
return $ foldr compose nullSubst xs
else catchableErr $
Aux.do
"Type constructor:"
printTree name
quote $ printTree t
"does not match with:"
printTree name'
quote $ printTree t'
(TEVar a, TEVar b) ->
if a == b
then return nullSubst
else catchableErr $
Aux.do
"Can not unify"
quote $ printTree (TEVar a)
"with"
quote $ printTree (TEVar b)
(a, b) -> do
catchableErr $
Aux.do
"Can not unify"
quote $ printTree a
"with"
quote $ printTree b
{- | Check if a type is contained in another type.
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
where these are equal
-}
occurs :: T.Ident -> Type -> Infer Subst
occurs i t@(TEVar _) = return (M.singleton i t)
occurs i t@(TVar _) = return (M.singleton i t)
occurs i t =
let fvs = S.toList $ free t
m = M.fromList $ zip fvs letters
in if S.member i (free t)
then
catchableErr
( Aux.do
"Occurs check failed, can't unify"
quote $ printTree $ replace m (TVar $ MkTVar (coerce i))
"with"
quote $ printTree $ replace m t
)
else return $ M.singleton i t
occurs :: TVar -> Type -> Infer Subst
occurs i t@(TEVar _) = return (singleton i t)
occurs i t@(TVar _) = return (singleton i t)
occurs i t
| S.member i (free t) =
catchableErr
( Aux.do
"Occurs check failed, can't unify"
quote $ printTree (TVar i)
"with"
quote $ printTree t
)
| otherwise = return $ singleton i t
{- | Generalize a type over all free variables in the substitution set
Used for let bindings to allow expression that do not type check in
@ -568,9 +537,9 @@ occurs i t =
generalize :: Map T.Ident Type -> Type -> Type
generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t)
where
go :: [T.Ident] -> Type -> Type
go :: [TVar] -> Type -> Type
go [] t = t
go (x : xs) t = TAll (MkTVar (coerce x)) (go xs t)
go (x : xs) t = TAll x (go xs t)
removeForalls :: Type -> Type
removeForalls (TAll _ t) = removeForalls t
removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2)
@ -581,9 +550,9 @@ with fresh ones.
-}
inst :: Type -> Infer Type
inst = \case
TAll (MkTVar bound) t -> do
TAll bound t -> do
fr <- fresh
let s = M.singleton (coerce bound) fr
let s = singleton bound fr
apply s <$> inst t
TFun t1 t2 -> TFun <$> inst t1 <*> inst t2
rest -> return rest
@ -598,7 +567,7 @@ freshen :: Type -> Infer Type
freshen t = do
let frees = S.toList (free t)
xs <- mapM (const fresh) frees
let sub = M.fromList $ zip frees xs
let sub = Subst . M.fromList $ zip frees xs
return $ apply sub t
-- | Generate a new fresh variable
@ -633,10 +602,10 @@ fresh = do
go tvars t1 t2 = do
-- probably not necessary
freshies <- mapM (const fresh) tvars
let sub = M.fromList $ zip [coerce x | (MkTVar x) <- tvars] freshies
let sub = Subst . M.fromList $ zip tvars freshies
let t1' = apply sub t1
let t2' = apply sub t2
let alph = execState (alpha t1' t2') mempty
let alph = Subst $ execState (alpha t1' t2') mempty
return $ apply alph t1' == t2'
-- Pre-condition: All TAlls are outermost
@ -646,11 +615,11 @@ fresh = do
-- Alpha rename the first type's type variable to match second.
-- Pre-condition: No TAll are checked
alpha :: Type -> Type -> State (Map T.Ident Type) ()
alpha :: Type -> Type -> State (Map TVar Type) ()
alpha t1 t2 = case (t1, t2) of
(TVar (MkTVar (LIdent i)), t2) -> do
(TVar i, t2) -> do
m <- get
put (M.insert (coerce i) t2 m)
put (M.insert i t2 m)
(TFun t1 t2, TFun t3 t4) -> do
alpha t1 t3
alpha t2 t4
@ -664,16 +633,16 @@ class SubstType t where
class FreeVars t where
-- | Get all free variables from t
free :: t -> Set T.Ident
free :: t -> Set TVar
instance FreeVars (T.Bind' Type) where
free (T.Bind (_, t) _ _) = free t
instance FreeVars Type where
free :: Type -> Set T.Ident
free (TVar (MkTVar a)) = S.singleton (coerce a)
free (TAll (MkTVar bound) t) =
S.singleton (coerce bound) `S.intersection` free t
free :: Type -> Set TVar
free (TVar a) = S.singleton a
free (TAll bound t) =
S.singleton bound `S.intersection` free t
free (TLit _) = mempty
free (TFun a b) = free a `S.union` free b
free (TData _ a) = free a
@ -687,22 +656,26 @@ instance SubstType Type where
apply sub t = do
case t of
TLit _ -> t
TVar (MkTVar a) -> case M.lookup (coerce a) sub of
Nothing -> TVar (MkTVar $ coerce a)
TVar a -> case find a sub of
Nothing -> TVar a
Just t -> t
TAll (MkTVar i) t -> case M.lookup (coerce i) sub of
Nothing -> TAll (MkTVar i) (apply sub t)
TAll i t -> case find i sub of
Nothing -> TAll i (apply sub t)
Just _ -> apply sub t
TFun a b -> TFun (apply sub a) (apply sub b)
TData name a -> TData name (apply sub a)
TEVar (MkTEVar a) -> case M.lookup (coerce a) sub of
TEVar (MkTEVar a) -> case find (MkTVar a) sub of
Nothing -> TEVar (MkTEVar $ coerce a)
Just t -> t
instance FreeVars (Map T.Ident Type) where
free :: Map T.Ident Type -> Set T.Ident
free :: Map T.Ident Type -> Set TVar
free = free . M.elems
instance SubstType (Map TVar Type) where
apply :: Subst -> Map TVar Type -> Map TVar Type
apply = M.map . apply
instance SubstType (Map T.Ident Type) where
apply :: Subst -> Map T.Ident Type -> Map T.Ident Type
apply = M.map . apply
@ -759,7 +732,7 @@ nullSubst = mempty
-- | Compose two substitution sets
compose :: Subst -> Subst -> Subst
compose m1 m2 = M.map (apply m1) m2 `M.union` m1
compose m1@(Subst m1') (Subst m2) = Subst $ M.map (apply m1) m2 `M.union` m1'
-- | Compose a list of substitution sets into one
composeAll :: [Subst] -> Subst
@ -910,7 +883,21 @@ data Env = Env
data Error = Error {msg :: String, catchable :: Bool}
deriving (Show)
type Subst = Map T.Ident Type
newtype Subst = Subst {unSubst :: Map TVar Type}
deriving (Eq, Ord, Show)
singleton :: TVar -> Type -> Subst
singleton a b = Subst (M.singleton a b)
find :: TVar -> Subst -> Maybe Type
find tvar (Subst s) = M.lookup tvar s
instance Semigroup Subst where
(<>) = compose
instance Monoid Subst where
mempty = Subst mempty
newtype Warning = NonExhaustive String
deriving (Show)

View file

@ -76,14 +76,18 @@ rn_sig =
rn_bind1 =
specify "Rename simple bind" $
shouldSatisfyOk
"f x = (\\y. let y2 = y + 1 in y2) (x + 1)"
( unlines
[ ".+ x y = x"
, "f x = (\\y. let y2 = y + 1 in y2) (x + 1)"
]
)
rn_bind2 = specify "Rename bind with case" . shouldSatisfyOk $
D.do
"data forall a. List a where"
" Nil : List a "
" Cons : a -> List a -> List a"
".+ x y = x"
"length : forall a. List a -> Int"
"length list = case list of"
" Nil => 0"

View file

@ -1,28 +1,28 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PatternSynonyms #-}
{-# HLINT ignore "Use camelCase" #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
module TestTypeCheckerBidir (test, testTypeCheckerBidir) where
import Test.Hspec
import Test.Hspec
import AnnForall (annotateForall)
import Control.Monad ((<=<))
import Desugar.Desugar (desugar)
import Grammar.Abs (Program)
import Grammar.ErrM (Err, pattern Bad, pattern Ok)
import Grammar.Layout (resolveLayout)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
import Renamer.Renamer (rename)
import ReportForall (reportForall)
import TypeChecker.RemoveForall (removeForall)
import TypeChecker.ReportTEVar (reportTEVar)
import TypeChecker.TypeChecker (TypeChecker (Bi))
import TypeChecker.TypeCheckerBidir (typecheck)
import qualified TypeChecker.TypeCheckerIr as T
import AnnForall (annotateForall)
import Control.Monad ((<=<))
import Desugar.Desugar (desugar)
import Grammar.Abs (Program)
import Grammar.ErrM (Err, pattern Bad, pattern Ok)
import Grammar.Layout (resolveLayout)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
import Renamer.Renamer (rename)
import ReportForall (reportForall)
import TypeChecker.RemoveForall (removeForall)
import TypeChecker.ReportTEVar (reportTEVar)
import TypeChecker.TypeChecker (TypeChecker (Bi))
import TypeChecker.TypeCheckerBidir (typecheck)
import TypeChecker.TypeCheckerIr qualified as T
test = hspec testTypeCheckerBidir
@ -54,13 +54,19 @@ tc_id =
tc_double =
specify "Addition inference" $
run
["double x = x + x"]
[ ".+ : Int -> Int -> Int"
, ".+ x y = x"
, "double x = x + x"
]
`shouldSatisfy` ok
tc_add_lam =
specify "Addition lambda inference" $
run
["four = (\\x. x + x) 2"]
[ ".+ : Int -> Int -> Int"
, ".+ x y = x"
, "four = (\\x. x + x) 2"
]
`shouldSatisfy` ok
tc_const =
@ -88,6 +94,8 @@ tc_rank2 =
run
[ "const : a -> b -> a"
, "const x y = x"
, ".+ : Int -> Int -> Int"
, ".+ x y = x"
, "rank2 : a -> (forall c. c -> Int) -> b -> Int"
, "rank2 x f y = f x + f y"
, "main = rank2 3 (\\x. const 5 x : a -> Int) 'h'"
@ -195,18 +203,21 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do
-- run (fs ++ correct1) `shouldSatisfy` ok
specify "Second correct case expression accepted" $
run (fs ++ correct2) `shouldSatisfy` ok
where
-- specify "Third correct case expression accepted" $
-- run (fs ++ correct3) `shouldSatisfy` ok
-- specify "Forth correct case expression accepted" $
-- run (fs ++ correct4) `shouldSatisfy` ok
where
fs =
[ "data List a where"
, " Nil : List a"
, " Cons : a -> List a -> List a"
]
wrong1 =
[ "length : List c -> Int"
[ ".+ : Int -> Int -> Int"
, ".+ x y = x"
, "length : List c -> Int"
, "length = \\list. case list of"
, " Nil => 0"
, " Cons 6 xs => 1 + length xs"
@ -254,10 +265,10 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do
correct4 =
[ "elems : List (List c) -> Int"
, "elems = \\list. case list of"
--, " Nil => 0"
--, " Cons Nil Nil => 0"
--, " Cons Nil xs => elems xs"
, " Cons (Cons _ ys) xs => 1 + elems (Cons ys xs)"
, -- , " Nil => 0"
-- , " Cons Nil Nil => 0"
-- , " Cons Nil xs => elems xs"
" Cons (Cons _ ys) xs => 1 + elems (Cons ys xs)"
]
tc_if = specify "Test if else case expression" $ do
@ -298,12 +309,19 @@ tc_infer_case = describe "Infer case expression" $ do
tc_rec1 =
specify "Infer simple recursive definition" $
run ["test x = 1 + test (x + 1)"] `shouldSatisfy` ok
run
[ ".+ : Int -> Int -> Int"
, ".+ x y = x"
, "test x = 1 + test (x + 1)"
]
`shouldSatisfy` ok
tc_rec2 =
specify "Infer recursive definition with pattern matching" $
run
[ "data Bool where"
[ ".+ : Int -> Int -> Int"
, ".+ x y = x"
, "data Bool where"
, " False : Bool"
, " True : Bool"
, "test = \\x. case x of"
@ -329,5 +347,5 @@ runPrint =
["double x = x + x"]
ok = \case
Ok _ -> True
Ok _ -> True
Bad _ -> False

View file

@ -57,6 +57,8 @@ goods =
, testSatisfy
"A basic arithmetic function should be able to be inferred"
( D.do
".+ : Int -> Int -> Int"
".+ x y = x"
"plusOne x = x + 1 ;"
"main x = plusOne x ;"
)
@ -74,6 +76,8 @@ goods =
, testSatisfy
"length function on int list infers correct signature"
( D.do
".+ : Int -> Int -> Int"
".+ x y = x"
"data List where "
" Nil : List"
" Cons : Int -> List -> List"
@ -114,6 +118,8 @@ bads =
, testSatisfy
"Using a concrete function (primitive type) on a skolem variable should not succeed"
( D.do
".+ : Int -> Int -> Int"
".+ x y = x"
"plusOne : Int -> Int ;"
"plusOne x = x + 1 ;"
"f : a -> Int ;"
@ -131,6 +137,8 @@ bads =
"Pattern matching on literal and _List should not succeed"
( D.do
_List
".+ : Int -> Int -> Int"
".+ x y = x"
"length : List c -> Int;"
"length _List = case _List of {"
" 0 => 0;"