Refactored HM to use TVar correctly, fixed unbound variable tests from
EAdd removal
This commit is contained in:
parent
5000b05152
commit
c96f3fc593
4 changed files with 165 additions and 148 deletions
|
|
@ -47,35 +47,10 @@ typecheck = onLeft msg . run . checkPrg
|
|||
checkPrg :: Program -> Infer (T.Program' Type)
|
||||
checkPrg (Program bs) = do
|
||||
preRun bs
|
||||
-- sgs <- gets sigs
|
||||
bs <- checkDef bs
|
||||
-- return . prettify sgs . T.Program $ bs
|
||||
return . T.Program $ bs
|
||||
|
||||
-- | Send the map of user declared signatures to not rename stuff the user defined
|
||||
prettify :: Map T.Ident (Maybe Type) -> T.Program' Type -> T.Program' Type
|
||||
prettify s (T.Program defs) = T.Program $ map (go s) defs
|
||||
where
|
||||
go :: Map T.Ident (Maybe Type) -> T.Def' Type -> T.Def' Type
|
||||
go _ (T.DData d) = T.DData d
|
||||
go m b@(T.DBind (T.Bind (name, t) args (e, et)))
|
||||
| Just (Just _) <- M.lookup name m = b
|
||||
| otherwise =
|
||||
let fvs = nub $ freeOrdered t
|
||||
m = M.fromList $ zip fvs letters
|
||||
in T.DBind $ T.Bind (name, replace m t) args (fmap (replace m) e, replace m et)
|
||||
|
||||
replace :: Map T.Ident T.Ident -> Type -> Type
|
||||
replace m def@(TVar (MkTVar (LIdent a))) = case M.lookup (coerce a) m of
|
||||
Just t -> TVar . MkTVar . LIdent $ coerce t
|
||||
Nothing -> def
|
||||
replace m (TFun t1 t2) = (TFun `on` replace m) t1 t2
|
||||
replace m (TData name ts) = TData name (map (replace m) ts)
|
||||
replace m def@(TAll (MkTVar forall_) t) = case M.lookup (coerce forall_) m of
|
||||
Just found -> TAll (MkTVar $ coerce found) (replace m t)
|
||||
Nothing -> def
|
||||
replace _ t = t
|
||||
|
||||
preRun :: [Def] -> Infer ()
|
||||
preRun [] = return ()
|
||||
preRun (x : xs) = case x of
|
||||
|
|
@ -484,24 +459,21 @@ inferPattern = \case
|
|||
|
||||
-- | Unify two types producing a new substitution
|
||||
unify :: Type -> Type -> Infer Subst
|
||||
unify t0 t1 =
|
||||
let fvs = S.toList $ free t0 `S.union` free t1
|
||||
m = M.fromList $ zip fvs letters
|
||||
in case (t0, t1) of
|
||||
unify t0 t1 = case (t0, t1) of
|
||||
(TFun a b, TFun c d) -> do
|
||||
s1 <- unify a c
|
||||
s2 <- unify (apply s1 b) (apply s1 d)
|
||||
return $ s2 `compose` s1
|
||||
(TVar (MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t
|
||||
(t@(TData _ _), TVar (MkTVar b)) -> return $ M.singleton (coerce b) t
|
||||
(TVar (MkTVar a), t) -> occurs (coerce a) t
|
||||
(t, TVar (MkTVar b)) -> occurs (coerce b) t
|
||||
(TVar a, t@(TData _ _)) -> return $ singleton a t
|
||||
(t@(TData _ _), TVar b) -> return $ singleton b t
|
||||
(TVar a, t) -> occurs a t
|
||||
(t, TVar b) -> occurs b t
|
||||
-- Forall unification should change
|
||||
(TAll _ t, b) -> unify t b
|
||||
(a, TAll _ t) -> unify a t
|
||||
(TLit a, TLit b) ->
|
||||
if a == b
|
||||
then return M.empty
|
||||
then return nullSubst
|
||||
else catchableErr $
|
||||
Aux.do
|
||||
"Can not unify"
|
||||
|
|
@ -517,13 +489,13 @@ unify t0 t1 =
|
|||
Aux.do
|
||||
"Type constructor:"
|
||||
printTree name
|
||||
quote $ printTree $ map (replace m) t
|
||||
quote $ printTree t
|
||||
"does not match with:"
|
||||
printTree name'
|
||||
quote $ printTree $ map (replace m) t'
|
||||
quote $ printTree t'
|
||||
(TEVar a, TEVar b) ->
|
||||
if a == b
|
||||
then return M.empty
|
||||
then return nullSubst
|
||||
else catchableErr $
|
||||
Aux.do
|
||||
"Can not unify"
|
||||
|
|
@ -534,30 +506,27 @@ unify t0 t1 =
|
|||
catchableErr $
|
||||
Aux.do
|
||||
"Can not unify"
|
||||
quote $ printTree $ replace m a
|
||||
quote $ printTree a
|
||||
"with"
|
||||
quote $ printTree $ replace m b
|
||||
quote $ printTree b
|
||||
|
||||
{- | Check if a type is contained in another type.
|
||||
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
|
||||
where these are equal
|
||||
-}
|
||||
occurs :: T.Ident -> Type -> Infer Subst
|
||||
occurs i t@(TEVar _) = return (M.singleton i t)
|
||||
occurs i t@(TVar _) = return (M.singleton i t)
|
||||
occurs i t =
|
||||
let fvs = S.toList $ free t
|
||||
m = M.fromList $ zip fvs letters
|
||||
in if S.member i (free t)
|
||||
then
|
||||
occurs :: TVar -> Type -> Infer Subst
|
||||
occurs i t@(TEVar _) = return (singleton i t)
|
||||
occurs i t@(TVar _) = return (singleton i t)
|
||||
occurs i t
|
||||
| S.member i (free t) =
|
||||
catchableErr
|
||||
( Aux.do
|
||||
"Occurs check failed, can't unify"
|
||||
quote $ printTree $ replace m (TVar $ MkTVar (coerce i))
|
||||
quote $ printTree (TVar i)
|
||||
"with"
|
||||
quote $ printTree $ replace m t
|
||||
quote $ printTree t
|
||||
)
|
||||
else return $ M.singleton i t
|
||||
| otherwise = return $ singleton i t
|
||||
|
||||
{- | Generalize a type over all free variables in the substitution set
|
||||
Used for let bindings to allow expression that do not type check in
|
||||
|
|
@ -568,9 +537,9 @@ occurs i t =
|
|||
generalize :: Map T.Ident Type -> Type -> Type
|
||||
generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t)
|
||||
where
|
||||
go :: [T.Ident] -> Type -> Type
|
||||
go :: [TVar] -> Type -> Type
|
||||
go [] t = t
|
||||
go (x : xs) t = TAll (MkTVar (coerce x)) (go xs t)
|
||||
go (x : xs) t = TAll x (go xs t)
|
||||
removeForalls :: Type -> Type
|
||||
removeForalls (TAll _ t) = removeForalls t
|
||||
removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2)
|
||||
|
|
@ -581,9 +550,9 @@ with fresh ones.
|
|||
-}
|
||||
inst :: Type -> Infer Type
|
||||
inst = \case
|
||||
TAll (MkTVar bound) t -> do
|
||||
TAll bound t -> do
|
||||
fr <- fresh
|
||||
let s = M.singleton (coerce bound) fr
|
||||
let s = singleton bound fr
|
||||
apply s <$> inst t
|
||||
TFun t1 t2 -> TFun <$> inst t1 <*> inst t2
|
||||
rest -> return rest
|
||||
|
|
@ -598,7 +567,7 @@ freshen :: Type -> Infer Type
|
|||
freshen t = do
|
||||
let frees = S.toList (free t)
|
||||
xs <- mapM (const fresh) frees
|
||||
let sub = M.fromList $ zip frees xs
|
||||
let sub = Subst . M.fromList $ zip frees xs
|
||||
return $ apply sub t
|
||||
|
||||
-- | Generate a new fresh variable
|
||||
|
|
@ -633,10 +602,10 @@ fresh = do
|
|||
go tvars t1 t2 = do
|
||||
-- probably not necessary
|
||||
freshies <- mapM (const fresh) tvars
|
||||
let sub = M.fromList $ zip [coerce x | (MkTVar x) <- tvars] freshies
|
||||
let sub = Subst . M.fromList $ zip tvars freshies
|
||||
let t1' = apply sub t1
|
||||
let t2' = apply sub t2
|
||||
let alph = execState (alpha t1' t2') mempty
|
||||
let alph = Subst $ execState (alpha t1' t2') mempty
|
||||
return $ apply alph t1' == t2'
|
||||
|
||||
-- Pre-condition: All TAlls are outermost
|
||||
|
|
@ -646,11 +615,11 @@ fresh = do
|
|||
|
||||
-- Alpha rename the first type's type variable to match second.
|
||||
-- Pre-condition: No TAll are checked
|
||||
alpha :: Type -> Type -> State (Map T.Ident Type) ()
|
||||
alpha :: Type -> Type -> State (Map TVar Type) ()
|
||||
alpha t1 t2 = case (t1, t2) of
|
||||
(TVar (MkTVar (LIdent i)), t2) -> do
|
||||
(TVar i, t2) -> do
|
||||
m <- get
|
||||
put (M.insert (coerce i) t2 m)
|
||||
put (M.insert i t2 m)
|
||||
(TFun t1 t2, TFun t3 t4) -> do
|
||||
alpha t1 t3
|
||||
alpha t2 t4
|
||||
|
|
@ -664,16 +633,16 @@ class SubstType t where
|
|||
|
||||
class FreeVars t where
|
||||
-- | Get all free variables from t
|
||||
free :: t -> Set T.Ident
|
||||
free :: t -> Set TVar
|
||||
|
||||
instance FreeVars (T.Bind' Type) where
|
||||
free (T.Bind (_, t) _ _) = free t
|
||||
|
||||
instance FreeVars Type where
|
||||
free :: Type -> Set T.Ident
|
||||
free (TVar (MkTVar a)) = S.singleton (coerce a)
|
||||
free (TAll (MkTVar bound) t) =
|
||||
S.singleton (coerce bound) `S.intersection` free t
|
||||
free :: Type -> Set TVar
|
||||
free (TVar a) = S.singleton a
|
||||
free (TAll bound t) =
|
||||
S.singleton bound `S.intersection` free t
|
||||
free (TLit _) = mempty
|
||||
free (TFun a b) = free a `S.union` free b
|
||||
free (TData _ a) = free a
|
||||
|
|
@ -687,22 +656,26 @@ instance SubstType Type where
|
|||
apply sub t = do
|
||||
case t of
|
||||
TLit _ -> t
|
||||
TVar (MkTVar a) -> case M.lookup (coerce a) sub of
|
||||
Nothing -> TVar (MkTVar $ coerce a)
|
||||
TVar a -> case find a sub of
|
||||
Nothing -> TVar a
|
||||
Just t -> t
|
||||
TAll (MkTVar i) t -> case M.lookup (coerce i) sub of
|
||||
Nothing -> TAll (MkTVar i) (apply sub t)
|
||||
TAll i t -> case find i sub of
|
||||
Nothing -> TAll i (apply sub t)
|
||||
Just _ -> apply sub t
|
||||
TFun a b -> TFun (apply sub a) (apply sub b)
|
||||
TData name a -> TData name (apply sub a)
|
||||
TEVar (MkTEVar a) -> case M.lookup (coerce a) sub of
|
||||
TEVar (MkTEVar a) -> case find (MkTVar a) sub of
|
||||
Nothing -> TEVar (MkTEVar $ coerce a)
|
||||
Just t -> t
|
||||
|
||||
instance FreeVars (Map T.Ident Type) where
|
||||
free :: Map T.Ident Type -> Set T.Ident
|
||||
free :: Map T.Ident Type -> Set TVar
|
||||
free = free . M.elems
|
||||
|
||||
instance SubstType (Map TVar Type) where
|
||||
apply :: Subst -> Map TVar Type -> Map TVar Type
|
||||
apply = M.map . apply
|
||||
|
||||
instance SubstType (Map T.Ident Type) where
|
||||
apply :: Subst -> Map T.Ident Type -> Map T.Ident Type
|
||||
apply = M.map . apply
|
||||
|
|
@ -759,7 +732,7 @@ nullSubst = mempty
|
|||
|
||||
-- | Compose two substitution sets
|
||||
compose :: Subst -> Subst -> Subst
|
||||
compose m1 m2 = M.map (apply m1) m2 `M.union` m1
|
||||
compose m1@(Subst m1') (Subst m2) = Subst $ M.map (apply m1) m2 `M.union` m1'
|
||||
|
||||
-- | Compose a list of substitution sets into one
|
||||
composeAll :: [Subst] -> Subst
|
||||
|
|
@ -910,7 +883,21 @@ data Env = Env
|
|||
|
||||
data Error = Error {msg :: String, catchable :: Bool}
|
||||
deriving (Show)
|
||||
type Subst = Map T.Ident Type
|
||||
|
||||
newtype Subst = Subst {unSubst :: Map TVar Type}
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
singleton :: TVar -> Type -> Subst
|
||||
singleton a b = Subst (M.singleton a b)
|
||||
|
||||
find :: TVar -> Subst -> Maybe Type
|
||||
find tvar (Subst s) = M.lookup tvar s
|
||||
|
||||
instance Semigroup Subst where
|
||||
(<>) = compose
|
||||
|
||||
instance Monoid Subst where
|
||||
mempty = Subst mempty
|
||||
|
||||
newtype Warning = NonExhaustive String
|
||||
deriving (Show)
|
||||
|
|
|
|||
|
|
@ -76,14 +76,18 @@ rn_sig =
|
|||
rn_bind1 =
|
||||
specify "Rename simple bind" $
|
||||
shouldSatisfyOk
|
||||
"f x = (\\y. let y2 = y + 1 in y2) (x + 1)"
|
||||
( unlines
|
||||
[ ".+ x y = x"
|
||||
, "f x = (\\y. let y2 = y + 1 in y2) (x + 1)"
|
||||
]
|
||||
)
|
||||
|
||||
rn_bind2 = specify "Rename bind with case" . shouldSatisfyOk $
|
||||
D.do
|
||||
"data forall a. List a where"
|
||||
" Nil : List a "
|
||||
" Cons : a -> List a -> List a"
|
||||
|
||||
".+ x y = x"
|
||||
"length : forall a. List a -> Int"
|
||||
"length list = case list of"
|
||||
" Nil => 0"
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ import TypeChecker.RemoveForall (removeForall)
|
|||
import TypeChecker.ReportTEVar (reportTEVar)
|
||||
import TypeChecker.TypeChecker (TypeChecker (Bi))
|
||||
import TypeChecker.TypeCheckerBidir (typecheck)
|
||||
import qualified TypeChecker.TypeCheckerIr as T
|
||||
import TypeChecker.TypeCheckerIr qualified as T
|
||||
|
||||
test = hspec testTypeCheckerBidir
|
||||
|
||||
|
|
@ -54,13 +54,19 @@ tc_id =
|
|||
tc_double =
|
||||
specify "Addition inference" $
|
||||
run
|
||||
["double x = x + x"]
|
||||
[ ".+ : Int -> Int -> Int"
|
||||
, ".+ x y = x"
|
||||
, "double x = x + x"
|
||||
]
|
||||
`shouldSatisfy` ok
|
||||
|
||||
tc_add_lam =
|
||||
specify "Addition lambda inference" $
|
||||
run
|
||||
["four = (\\x. x + x) 2"]
|
||||
[ ".+ : Int -> Int -> Int"
|
||||
, ".+ x y = x"
|
||||
, "four = (\\x. x + x) 2"
|
||||
]
|
||||
`shouldSatisfy` ok
|
||||
|
||||
tc_const =
|
||||
|
|
@ -88,6 +94,8 @@ tc_rank2 =
|
|||
run
|
||||
[ "const : a -> b -> a"
|
||||
, "const x y = x"
|
||||
, ".+ : Int -> Int -> Int"
|
||||
, ".+ x y = x"
|
||||
, "rank2 : a -> (forall c. c -> Int) -> b -> Int"
|
||||
, "rank2 x f y = f x + f y"
|
||||
, "main = rank2 3 (\\x. const 5 x : a -> Int) 'h'"
|
||||
|
|
@ -195,18 +203,21 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do
|
|||
-- run (fs ++ correct1) `shouldSatisfy` ok
|
||||
specify "Second correct case expression accepted" $
|
||||
run (fs ++ correct2) `shouldSatisfy` ok
|
||||
where
|
||||
-- specify "Third correct case expression accepted" $
|
||||
-- run (fs ++ correct3) `shouldSatisfy` ok
|
||||
-- specify "Forth correct case expression accepted" $
|
||||
-- run (fs ++ correct4) `shouldSatisfy` ok
|
||||
where
|
||||
|
||||
fs =
|
||||
[ "data List a where"
|
||||
, " Nil : List a"
|
||||
, " Cons : a -> List a -> List a"
|
||||
]
|
||||
wrong1 =
|
||||
[ "length : List c -> Int"
|
||||
[ ".+ : Int -> Int -> Int"
|
||||
, ".+ x y = x"
|
||||
, "length : List c -> Int"
|
||||
, "length = \\list. case list of"
|
||||
, " Nil => 0"
|
||||
, " Cons 6 xs => 1 + length xs"
|
||||
|
|
@ -254,10 +265,10 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do
|
|||
correct4 =
|
||||
[ "elems : List (List c) -> Int"
|
||||
, "elems = \\list. case list of"
|
||||
--, " Nil => 0"
|
||||
, -- , " Nil => 0"
|
||||
-- , " Cons Nil Nil => 0"
|
||||
-- , " Cons Nil xs => elems xs"
|
||||
, " Cons (Cons _ ys) xs => 1 + elems (Cons ys xs)"
|
||||
" Cons (Cons _ ys) xs => 1 + elems (Cons ys xs)"
|
||||
]
|
||||
|
||||
tc_if = specify "Test if else case expression" $ do
|
||||
|
|
@ -298,12 +309,19 @@ tc_infer_case = describe "Infer case expression" $ do
|
|||
|
||||
tc_rec1 =
|
||||
specify "Infer simple recursive definition" $
|
||||
run ["test x = 1 + test (x + 1)"] `shouldSatisfy` ok
|
||||
run
|
||||
[ ".+ : Int -> Int -> Int"
|
||||
, ".+ x y = x"
|
||||
, "test x = 1 + test (x + 1)"
|
||||
]
|
||||
`shouldSatisfy` ok
|
||||
|
||||
tc_rec2 =
|
||||
specify "Infer recursive definition with pattern matching" $
|
||||
run
|
||||
[ "data Bool where"
|
||||
[ ".+ : Int -> Int -> Int"
|
||||
, ".+ x y = x"
|
||||
, "data Bool where"
|
||||
, " False : Bool"
|
||||
, " True : Bool"
|
||||
, "test = \\x. case x of"
|
||||
|
|
|
|||
|
|
@ -57,6 +57,8 @@ goods =
|
|||
, testSatisfy
|
||||
"A basic arithmetic function should be able to be inferred"
|
||||
( D.do
|
||||
".+ : Int -> Int -> Int"
|
||||
".+ x y = x"
|
||||
"plusOne x = x + 1 ;"
|
||||
"main x = plusOne x ;"
|
||||
)
|
||||
|
|
@ -74,6 +76,8 @@ goods =
|
|||
, testSatisfy
|
||||
"length function on int list infers correct signature"
|
||||
( D.do
|
||||
".+ : Int -> Int -> Int"
|
||||
".+ x y = x"
|
||||
"data List where "
|
||||
" Nil : List"
|
||||
" Cons : Int -> List -> List"
|
||||
|
|
@ -114,6 +118,8 @@ bads =
|
|||
, testSatisfy
|
||||
"Using a concrete function (primitive type) on a skolem variable should not succeed"
|
||||
( D.do
|
||||
".+ : Int -> Int -> Int"
|
||||
".+ x y = x"
|
||||
"plusOne : Int -> Int ;"
|
||||
"plusOne x = x + 1 ;"
|
||||
"f : a -> Int ;"
|
||||
|
|
@ -131,6 +137,8 @@ bads =
|
|||
"Pattern matching on literal and _List should not succeed"
|
||||
( D.do
|
||||
_List
|
||||
".+ : Int -> Int -> Int"
|
||||
".+ x y = x"
|
||||
"length : List c -> Int;"
|
||||
"length _List = case _List of {"
|
||||
" 0 => 0;"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue