Refactored HM to use TVar correctly, fixed unbound variable tests from

EAdd removal
This commit is contained in:
sebastian 2023-05-15 22:57:37 +02:00
parent 5000b05152
commit c96f3fc593
4 changed files with 165 additions and 148 deletions

View file

@ -47,35 +47,10 @@ typecheck = onLeft msg . run . checkPrg
checkPrg :: Program -> Infer (T.Program' Type) checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = do checkPrg (Program bs) = do
preRun bs preRun bs
-- sgs <- gets sigs
bs <- checkDef bs bs <- checkDef bs
-- return . prettify sgs . T.Program $ bs
return . T.Program $ bs return . T.Program $ bs
-- | Send the map of user declared signatures to not rename stuff the user defined -- | Send the map of user declared signatures to not rename stuff the user defined
prettify :: Map T.Ident (Maybe Type) -> T.Program' Type -> T.Program' Type
prettify s (T.Program defs) = T.Program $ map (go s) defs
where
go :: Map T.Ident (Maybe Type) -> T.Def' Type -> T.Def' Type
go _ (T.DData d) = T.DData d
go m b@(T.DBind (T.Bind (name, t) args (e, et)))
| Just (Just _) <- M.lookup name m = b
| otherwise =
let fvs = nub $ freeOrdered t
m = M.fromList $ zip fvs letters
in T.DBind $ T.Bind (name, replace m t) args (fmap (replace m) e, replace m et)
replace :: Map T.Ident T.Ident -> Type -> Type
replace m def@(TVar (MkTVar (LIdent a))) = case M.lookup (coerce a) m of
Just t -> TVar . MkTVar . LIdent $ coerce t
Nothing -> def
replace m (TFun t1 t2) = (TFun `on` replace m) t1 t2
replace m (TData name ts) = TData name (map (replace m) ts)
replace m def@(TAll (MkTVar forall_) t) = case M.lookup (coerce forall_) m of
Just found -> TAll (MkTVar $ coerce found) (replace m t)
Nothing -> def
replace _ t = t
preRun :: [Def] -> Infer () preRun :: [Def] -> Infer ()
preRun [] = return () preRun [] = return ()
preRun (x : xs) = case x of preRun (x : xs) = case x of
@ -484,80 +459,74 @@ inferPattern = \case
-- | Unify two types producing a new substitution -- | Unify two types producing a new substitution
unify :: Type -> Type -> Infer Subst unify :: Type -> Type -> Infer Subst
unify t0 t1 = unify t0 t1 = case (t0, t1) of
let fvs = S.toList $ free t0 `S.union` free t1 (TFun a b, TFun c d) -> do
m = M.fromList $ zip fvs letters s1 <- unify a c
in case (t0, t1) of s2 <- unify (apply s1 b) (apply s1 d)
(TFun a b, TFun c d) -> do return $ s2 `compose` s1
s1 <- unify a c (TVar a, t@(TData _ _)) -> return $ singleton a t
s2 <- unify (apply s1 b) (apply s1 d) (t@(TData _ _), TVar b) -> return $ singleton b t
return $ s2 `compose` s1 (TVar a, t) -> occurs a t
(TVar (MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t (t, TVar b) -> occurs b t
(t@(TData _ _), TVar (MkTVar b)) -> return $ M.singleton (coerce b) t -- Forall unification should change
(TVar (MkTVar a), t) -> occurs (coerce a) t (TAll _ t, b) -> unify t b
(t, TVar (MkTVar b)) -> occurs (coerce b) t (a, TAll _ t) -> unify a t
-- Forall unification should change (TLit a, TLit b) ->
(TAll _ t, b) -> unify t b if a == b
(a, TAll _ t) -> unify a t then return nullSubst
(TLit a, TLit b) -> else catchableErr $
if a == b Aux.do
then return M.empty "Can not unify"
else catchableErr $ quote $ printTree (TLit a)
Aux.do "with"
"Can not unify" quote $ printTree (TLit b)
quote $ printTree (TLit a) (TData name t, TData name' t') ->
"with" if name == name' && length t == length t'
quote $ printTree (TLit b) then do
(TData name t, TData name' t') -> xs <- zipWithM unify t t'
if name == name' && length t == length t' return $ foldr compose nullSubst xs
then do else catchableErr $
xs <- zipWithM unify t t' Aux.do
return $ foldr compose nullSubst xs "Type constructor:"
else catchableErr $ printTree name
Aux.do quote $ printTree t
"Type constructor:" "does not match with:"
printTree name printTree name'
quote $ printTree $ map (replace m) t quote $ printTree t'
"does not match with:" (TEVar a, TEVar b) ->
printTree name' if a == b
quote $ printTree $ map (replace m) t' then return nullSubst
(TEVar a, TEVar b) -> else catchableErr $
if a == b Aux.do
then return M.empty "Can not unify"
else catchableErr $ quote $ printTree (TEVar a)
Aux.do "with"
"Can not unify" quote $ printTree (TEVar b)
quote $ printTree (TEVar a) (a, b) -> do
"with" catchableErr $
quote $ printTree (TEVar b) Aux.do
(a, b) -> do "Can not unify"
catchableErr $ quote $ printTree a
Aux.do "with"
"Can not unify" quote $ printTree b
quote $ printTree $ replace m a
"with"
quote $ printTree $ replace m b
{- | Check if a type is contained in another type. {- | Check if a type is contained in another type.
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
where these are equal where these are equal
-} -}
occurs :: T.Ident -> Type -> Infer Subst occurs :: TVar -> Type -> Infer Subst
occurs i t@(TEVar _) = return (M.singleton i t) occurs i t@(TEVar _) = return (singleton i t)
occurs i t@(TVar _) = return (M.singleton i t) occurs i t@(TVar _) = return (singleton i t)
occurs i t = occurs i t
let fvs = S.toList $ free t | S.member i (free t) =
m = M.fromList $ zip fvs letters catchableErr
in if S.member i (free t) ( Aux.do
then "Occurs check failed, can't unify"
catchableErr quote $ printTree (TVar i)
( Aux.do "with"
"Occurs check failed, can't unify" quote $ printTree t
quote $ printTree $ replace m (TVar $ MkTVar (coerce i)) )
"with" | otherwise = return $ singleton i t
quote $ printTree $ replace m t
)
else return $ M.singleton i t
{- | Generalize a type over all free variables in the substitution set {- | Generalize a type over all free variables in the substitution set
Used for let bindings to allow expression that do not type check in Used for let bindings to allow expression that do not type check in
@ -568,9 +537,9 @@ occurs i t =
generalize :: Map T.Ident Type -> Type -> Type generalize :: Map T.Ident Type -> Type -> Type
generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t) generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t)
where where
go :: [T.Ident] -> Type -> Type go :: [TVar] -> Type -> Type
go [] t = t go [] t = t
go (x : xs) t = TAll (MkTVar (coerce x)) (go xs t) go (x : xs) t = TAll x (go xs t)
removeForalls :: Type -> Type removeForalls :: Type -> Type
removeForalls (TAll _ t) = removeForalls t removeForalls (TAll _ t) = removeForalls t
removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2) removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2)
@ -581,9 +550,9 @@ with fresh ones.
-} -}
inst :: Type -> Infer Type inst :: Type -> Infer Type
inst = \case inst = \case
TAll (MkTVar bound) t -> do TAll bound t -> do
fr <- fresh fr <- fresh
let s = M.singleton (coerce bound) fr let s = singleton bound fr
apply s <$> inst t apply s <$> inst t
TFun t1 t2 -> TFun <$> inst t1 <*> inst t2 TFun t1 t2 -> TFun <$> inst t1 <*> inst t2
rest -> return rest rest -> return rest
@ -598,7 +567,7 @@ freshen :: Type -> Infer Type
freshen t = do freshen t = do
let frees = S.toList (free t) let frees = S.toList (free t)
xs <- mapM (const fresh) frees xs <- mapM (const fresh) frees
let sub = M.fromList $ zip frees xs let sub = Subst . M.fromList $ zip frees xs
return $ apply sub t return $ apply sub t
-- | Generate a new fresh variable -- | Generate a new fresh variable
@ -633,10 +602,10 @@ fresh = do
go tvars t1 t2 = do go tvars t1 t2 = do
-- probably not necessary -- probably not necessary
freshies <- mapM (const fresh) tvars freshies <- mapM (const fresh) tvars
let sub = M.fromList $ zip [coerce x | (MkTVar x) <- tvars] freshies let sub = Subst . M.fromList $ zip tvars freshies
let t1' = apply sub t1 let t1' = apply sub t1
let t2' = apply sub t2 let t2' = apply sub t2
let alph = execState (alpha t1' t2') mempty let alph = Subst $ execState (alpha t1' t2') mempty
return $ apply alph t1' == t2' return $ apply alph t1' == t2'
-- Pre-condition: All TAlls are outermost -- Pre-condition: All TAlls are outermost
@ -646,11 +615,11 @@ fresh = do
-- Alpha rename the first type's type variable to match second. -- Alpha rename the first type's type variable to match second.
-- Pre-condition: No TAll are checked -- Pre-condition: No TAll are checked
alpha :: Type -> Type -> State (Map T.Ident Type) () alpha :: Type -> Type -> State (Map TVar Type) ()
alpha t1 t2 = case (t1, t2) of alpha t1 t2 = case (t1, t2) of
(TVar (MkTVar (LIdent i)), t2) -> do (TVar i, t2) -> do
m <- get m <- get
put (M.insert (coerce i) t2 m) put (M.insert i t2 m)
(TFun t1 t2, TFun t3 t4) -> do (TFun t1 t2, TFun t3 t4) -> do
alpha t1 t3 alpha t1 t3
alpha t2 t4 alpha t2 t4
@ -664,16 +633,16 @@ class SubstType t where
class FreeVars t where class FreeVars t where
-- | Get all free variables from t -- | Get all free variables from t
free :: t -> Set T.Ident free :: t -> Set TVar
instance FreeVars (T.Bind' Type) where instance FreeVars (T.Bind' Type) where
free (T.Bind (_, t) _ _) = free t free (T.Bind (_, t) _ _) = free t
instance FreeVars Type where instance FreeVars Type where
free :: Type -> Set T.Ident free :: Type -> Set TVar
free (TVar (MkTVar a)) = S.singleton (coerce a) free (TVar a) = S.singleton a
free (TAll (MkTVar bound) t) = free (TAll bound t) =
S.singleton (coerce bound) `S.intersection` free t S.singleton bound `S.intersection` free t
free (TLit _) = mempty free (TLit _) = mempty
free (TFun a b) = free a `S.union` free b free (TFun a b) = free a `S.union` free b
free (TData _ a) = free a free (TData _ a) = free a
@ -687,22 +656,26 @@ instance SubstType Type where
apply sub t = do apply sub t = do
case t of case t of
TLit _ -> t TLit _ -> t
TVar (MkTVar a) -> case M.lookup (coerce a) sub of TVar a -> case find a sub of
Nothing -> TVar (MkTVar $ coerce a) Nothing -> TVar a
Just t -> t Just t -> t
TAll (MkTVar i) t -> case M.lookup (coerce i) sub of TAll i t -> case find i sub of
Nothing -> TAll (MkTVar i) (apply sub t) Nothing -> TAll i (apply sub t)
Just _ -> apply sub t Just _ -> apply sub t
TFun a b -> TFun (apply sub a) (apply sub b) TFun a b -> TFun (apply sub a) (apply sub b)
TData name a -> TData name (apply sub a) TData name a -> TData name (apply sub a)
TEVar (MkTEVar a) -> case M.lookup (coerce a) sub of TEVar (MkTEVar a) -> case find (MkTVar a) sub of
Nothing -> TEVar (MkTEVar $ coerce a) Nothing -> TEVar (MkTEVar $ coerce a)
Just t -> t Just t -> t
instance FreeVars (Map T.Ident Type) where instance FreeVars (Map T.Ident Type) where
free :: Map T.Ident Type -> Set T.Ident free :: Map T.Ident Type -> Set TVar
free = free . M.elems free = free . M.elems
instance SubstType (Map TVar Type) where
apply :: Subst -> Map TVar Type -> Map TVar Type
apply = M.map . apply
instance SubstType (Map T.Ident Type) where instance SubstType (Map T.Ident Type) where
apply :: Subst -> Map T.Ident Type -> Map T.Ident Type apply :: Subst -> Map T.Ident Type -> Map T.Ident Type
apply = M.map . apply apply = M.map . apply
@ -759,7 +732,7 @@ nullSubst = mempty
-- | Compose two substitution sets -- | Compose two substitution sets
compose :: Subst -> Subst -> Subst compose :: Subst -> Subst -> Subst
compose m1 m2 = M.map (apply m1) m2 `M.union` m1 compose m1@(Subst m1') (Subst m2) = Subst $ M.map (apply m1) m2 `M.union` m1'
-- | Compose a list of substitution sets into one -- | Compose a list of substitution sets into one
composeAll :: [Subst] -> Subst composeAll :: [Subst] -> Subst
@ -910,7 +883,21 @@ data Env = Env
data Error = Error {msg :: String, catchable :: Bool} data Error = Error {msg :: String, catchable :: Bool}
deriving (Show) deriving (Show)
type Subst = Map T.Ident Type
newtype Subst = Subst {unSubst :: Map TVar Type}
deriving (Eq, Ord, Show)
singleton :: TVar -> Type -> Subst
singleton a b = Subst (M.singleton a b)
find :: TVar -> Subst -> Maybe Type
find tvar (Subst s) = M.lookup tvar s
instance Semigroup Subst where
(<>) = compose
instance Monoid Subst where
mempty = Subst mempty
newtype Warning = NonExhaustive String newtype Warning = NonExhaustive String
deriving (Show) deriving (Show)

View file

@ -76,14 +76,18 @@ rn_sig =
rn_bind1 = rn_bind1 =
specify "Rename simple bind" $ specify "Rename simple bind" $
shouldSatisfyOk shouldSatisfyOk
"f x = (\\y. let y2 = y + 1 in y2) (x + 1)" ( unlines
[ ".+ x y = x"
, "f x = (\\y. let y2 = y + 1 in y2) (x + 1)"
]
)
rn_bind2 = specify "Rename bind with case" . shouldSatisfyOk $ rn_bind2 = specify "Rename bind with case" . shouldSatisfyOk $
D.do D.do
"data forall a. List a where" "data forall a. List a where"
" Nil : List a " " Nil : List a "
" Cons : a -> List a -> List a" " Cons : a -> List a -> List a"
".+ x y = x"
"length : forall a. List a -> Int" "length : forall a. List a -> Int"
"length list = case list of" "length list = case list of"
" Nil => 0" " Nil => 0"

View file

@ -1,28 +1,28 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PatternSynonyms #-}
{-# HLINT ignore "Use camelCase" #-} {-# HLINT ignore "Use camelCase" #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} {-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
module TestTypeCheckerBidir (test, testTypeCheckerBidir) where module TestTypeCheckerBidir (test, testTypeCheckerBidir) where
import Test.Hspec import Test.Hspec
import AnnForall (annotateForall) import AnnForall (annotateForall)
import Control.Monad ((<=<)) import Control.Monad ((<=<))
import Desugar.Desugar (desugar) import Desugar.Desugar (desugar)
import Grammar.Abs (Program) import Grammar.Abs (Program)
import Grammar.ErrM (Err, pattern Bad, pattern Ok) import Grammar.ErrM (Err, pattern Bad, pattern Ok)
import Grammar.Layout (resolveLayout) import Grammar.Layout (resolveLayout)
import Grammar.Par (myLexer, pProgram) import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree) import Grammar.Print (printTree)
import Renamer.Renamer (rename) import Renamer.Renamer (rename)
import ReportForall (reportForall) import ReportForall (reportForall)
import TypeChecker.RemoveForall (removeForall) import TypeChecker.RemoveForall (removeForall)
import TypeChecker.ReportTEVar (reportTEVar) import TypeChecker.ReportTEVar (reportTEVar)
import TypeChecker.TypeChecker (TypeChecker (Bi)) import TypeChecker.TypeChecker (TypeChecker (Bi))
import TypeChecker.TypeCheckerBidir (typecheck) import TypeChecker.TypeCheckerBidir (typecheck)
import qualified TypeChecker.TypeCheckerIr as T import TypeChecker.TypeCheckerIr qualified as T
test = hspec testTypeCheckerBidir test = hspec testTypeCheckerBidir
@ -54,13 +54,19 @@ tc_id =
tc_double = tc_double =
specify "Addition inference" $ specify "Addition inference" $
run run
["double x = x + x"] [ ".+ : Int -> Int -> Int"
, ".+ x y = x"
, "double x = x + x"
]
`shouldSatisfy` ok `shouldSatisfy` ok
tc_add_lam = tc_add_lam =
specify "Addition lambda inference" $ specify "Addition lambda inference" $
run run
["four = (\\x. x + x) 2"] [ ".+ : Int -> Int -> Int"
, ".+ x y = x"
, "four = (\\x. x + x) 2"
]
`shouldSatisfy` ok `shouldSatisfy` ok
tc_const = tc_const =
@ -88,6 +94,8 @@ tc_rank2 =
run run
[ "const : a -> b -> a" [ "const : a -> b -> a"
, "const x y = x" , "const x y = x"
, ".+ : Int -> Int -> Int"
, ".+ x y = x"
, "rank2 : a -> (forall c. c -> Int) -> b -> Int" , "rank2 : a -> (forall c. c -> Int) -> b -> Int"
, "rank2 x f y = f x + f y" , "rank2 x f y = f x + f y"
, "main = rank2 3 (\\x. const 5 x : a -> Int) 'h'" , "main = rank2 3 (\\x. const 5 x : a -> Int) 'h'"
@ -195,18 +203,21 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do
-- run (fs ++ correct1) `shouldSatisfy` ok -- run (fs ++ correct1) `shouldSatisfy` ok
specify "Second correct case expression accepted" $ specify "Second correct case expression accepted" $
run (fs ++ correct2) `shouldSatisfy` ok run (fs ++ correct2) `shouldSatisfy` ok
where
-- specify "Third correct case expression accepted" $ -- specify "Third correct case expression accepted" $
-- run (fs ++ correct3) `shouldSatisfy` ok -- run (fs ++ correct3) `shouldSatisfy` ok
-- specify "Forth correct case expression accepted" $ -- specify "Forth correct case expression accepted" $
-- run (fs ++ correct4) `shouldSatisfy` ok -- run (fs ++ correct4) `shouldSatisfy` ok
where
fs = fs =
[ "data List a where" [ "data List a where"
, " Nil : List a" , " Nil : List a"
, " Cons : a -> List a -> List a" , " Cons : a -> List a -> List a"
] ]
wrong1 = wrong1 =
[ "length : List c -> Int" [ ".+ : Int -> Int -> Int"
, ".+ x y = x"
, "length : List c -> Int"
, "length = \\list. case list of" , "length = \\list. case list of"
, " Nil => 0" , " Nil => 0"
, " Cons 6 xs => 1 + length xs" , " Cons 6 xs => 1 + length xs"
@ -254,10 +265,10 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do
correct4 = correct4 =
[ "elems : List (List c) -> Int" [ "elems : List (List c) -> Int"
, "elems = \\list. case list of" , "elems = \\list. case list of"
--, " Nil => 0" , -- , " Nil => 0"
--, " Cons Nil Nil => 0" -- , " Cons Nil Nil => 0"
--, " Cons Nil xs => elems xs" -- , " Cons Nil xs => elems xs"
, " Cons (Cons _ ys) xs => 1 + elems (Cons ys xs)" " Cons (Cons _ ys) xs => 1 + elems (Cons ys xs)"
] ]
tc_if = specify "Test if else case expression" $ do tc_if = specify "Test if else case expression" $ do
@ -298,12 +309,19 @@ tc_infer_case = describe "Infer case expression" $ do
tc_rec1 = tc_rec1 =
specify "Infer simple recursive definition" $ specify "Infer simple recursive definition" $
run ["test x = 1 + test (x + 1)"] `shouldSatisfy` ok run
[ ".+ : Int -> Int -> Int"
, ".+ x y = x"
, "test x = 1 + test (x + 1)"
]
`shouldSatisfy` ok
tc_rec2 = tc_rec2 =
specify "Infer recursive definition with pattern matching" $ specify "Infer recursive definition with pattern matching" $
run run
[ "data Bool where" [ ".+ : Int -> Int -> Int"
, ".+ x y = x"
, "data Bool where"
, " False : Bool" , " False : Bool"
, " True : Bool" , " True : Bool"
, "test = \\x. case x of" , "test = \\x. case x of"
@ -329,5 +347,5 @@ runPrint =
["double x = x + x"] ["double x = x + x"]
ok = \case ok = \case
Ok _ -> True Ok _ -> True
Bad _ -> False Bad _ -> False

View file

@ -57,6 +57,8 @@ goods =
, testSatisfy , testSatisfy
"A basic arithmetic function should be able to be inferred" "A basic arithmetic function should be able to be inferred"
( D.do ( D.do
".+ : Int -> Int -> Int"
".+ x y = x"
"plusOne x = x + 1 ;" "plusOne x = x + 1 ;"
"main x = plusOne x ;" "main x = plusOne x ;"
) )
@ -74,6 +76,8 @@ goods =
, testSatisfy , testSatisfy
"length function on int list infers correct signature" "length function on int list infers correct signature"
( D.do ( D.do
".+ : Int -> Int -> Int"
".+ x y = x"
"data List where " "data List where "
" Nil : List" " Nil : List"
" Cons : Int -> List -> List" " Cons : Int -> List -> List"
@ -114,6 +118,8 @@ bads =
, testSatisfy , testSatisfy
"Using a concrete function (primitive type) on a skolem variable should not succeed" "Using a concrete function (primitive type) on a skolem variable should not succeed"
( D.do ( D.do
".+ : Int -> Int -> Int"
".+ x y = x"
"plusOne : Int -> Int ;" "plusOne : Int -> Int ;"
"plusOne x = x + 1 ;" "plusOne x = x + 1 ;"
"f : a -> Int ;" "f : a -> Int ;"
@ -131,6 +137,8 @@ bads =
"Pattern matching on literal and _List should not succeed" "Pattern matching on literal and _List should not succeed"
( D.do ( D.do
_List _List
".+ : Int -> Int -> Int"
".+ x y = x"
"length : List c -> Int;" "length : List c -> Int;"
"length _List = case _List of {" "length _List = case _List of {"
" 0 => 0;" " 0 => 0;"