fixed bug and additional test

This commit is contained in:
sebastian 2023-03-27 23:05:40 +02:00
parent 4b24755b93
commit 0d2fe862e0
2 changed files with 52 additions and 29 deletions

View file

@ -113,7 +113,7 @@ preRun :: [Def] -> Infer ()
preRun [] = return ()
preRun (x : xs) = case x of
DSig (Sig n t) -> do
collect (collectTypeVars t)
collect (collectTVars t)
gets (M.member (coerce n) . sigs)
>>= flip
when
@ -123,20 +123,23 @@ preRun (x : xs) = case x of
)
insertSig (coerce n) (Just t) >> preRun xs
DBind (Bind n _ e) -> do
collect (collectTypeVars e)
collect (collectTVars e)
s <- gets sigs
case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs
DData d@(Data t _) -> collect (collectTypeVars t) >> checkData d >> preRun xs
DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs
checkDef :: [Def] -> Infer [T.Def' Type]
checkDef [] = return []
checkDef (x : xs) = case x of
(DBind b) -> do
b' <- checkBind b
fmap (T.DBind b' :) (checkDef xs)
(DData d) -> fmap (T.DData (coerceData d) :) (checkDef xs)
xs' <- checkDef xs
return $ T.DBind b' : xs'
(DData d) -> do
xs' <- checkDef xs
return $ T.DData (coerceData d) : xs'
(DSig _) -> checkDef xs
where
coerceData (Data t injs) =
@ -145,22 +148,24 @@ checkDef (x : xs) = case x of
checkBind :: Bind -> Infer (T.Bind' Type)
checkBind (Bind name args e) = do
let lambda = makeLambda e (reverse (coerce args))
e@(_, args_t) <- inferExp lambda
(e, lambda_t) <- inferExp lambda
s <- gets sigs
case M.lookup (coerce name) s of
Just (Just t') -> do
sub1 <- unify lambda_t t'
sub2 <- unify t' lambda_t
unless
(args_t `typeEq` t')
(apply sub1 lambda_t == t' && lambda_t == apply sub2 t')
( throwError $ Aux.do
"Inferred type"
quote $ printTree args_t
quote $ printTree lambda_t
"does not match specified type"
quote $ printTree t'
)
return $ T.Bind (coerce name, t') [] e
return $ T.Bind (coerce name, t') [] (e, lambda_t)
_ -> do
insertSig (coerce name) (Just args_t)
return (T.Bind (coerce name, args_t) [] e)
insertSig (coerce name) (Just lambda_t)
return (T.Bind (coerce name, lambda_t) [] (e, lambda_t))
typeEq :: Type -> Type -> Bool
typeEq (TFun l r) (TFun l' r') = typeEq l l' && typeEq r r'
@ -203,18 +208,18 @@ inferExp e = do
return $ second (const subbed) (e', t)
class CollectTVars a where
collectTypeVars :: a -> Set T.Ident
collectTVars :: a -> Set T.Ident
instance CollectTVars Exp where
collectTypeVars (EAnn e t) = collectTypeVars t `S.union` collectTypeVars e
collectTypeVars _ = S.empty
collectTVars (EAnn e t) = collectTVars t `S.union` collectTVars e
collectTVars _ = S.empty
instance CollectTVars Type where
collectTypeVars (TVar (MkTVar i)) = S.singleton (coerce i)
collectTypeVars (TAll _ t) = collectTypeVars t
collectTypeVars (TFun t1 t2) = (S.union `on` collectTypeVars) t1 t2
collectTypeVars (TData _ ts) = foldl' (\acc x -> acc `S.union` collectTypeVars x) S.empty ts
collectTypeVars _ = S.empty
collectTVars (TVar (MkTVar i)) = S.singleton (coerce i)
collectTVars (TAll _ t) = collectTVars t
collectTVars (TFun t1 t2) = (S.union `on` collectTVars) t1 t2
collectTVars (TData _ ts) = foldl' (\acc x -> acc `S.union` collectTVars x) S.empty ts
collectTVars _ = S.empty
collect :: Set T.Ident -> Infer ()
collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st})

View file

@ -1,17 +1,28 @@
{-# LANGUAGE QualifiedDo #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE QualifiedDo #-}
module TestTypeCheckerHm where
import Control.Monad ((<=<))
import qualified DoStrings as D
import Grammar.Par (myLexer, pProgram)
import Prelude (Bool (..), Either (..), IO, foldl1,
mapM_, not, ($), (.), (>>))
import Test.Hspec
import Control.Monad ((<=<))
import DoStrings qualified as D
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
import Test.Hspec
import Prelude (
Bool (..),
Either (..),
IO,
fmap,
foldl1,
mapM_,
not,
($),
(.),
(>>),
)
-- import Test.QuickCheck
import TypeChecker.TypeCheckerHm (typecheck)
import TypeChecker.TypeCheckerHm (typecheck)
testTypeCheckerHm = describe "Hindley-Milner type checker test" $ do
foldl1 (>>) goods
@ -124,6 +135,13 @@ bads =
"id x = x;"
)
bad
, testSatisfy
"incorrect signature on const"
( D.do
"const : a -> b -> b;"
"const x y = x"
)
bad
, testSatisfy
"incorrect type signature on id lambda"
( D.do
@ -176,10 +194,10 @@ bes =
testSatisfy desc test satisfaction = specify desc $ run test `shouldSatisfy` satisfaction
testBe desc test shouldbe = specify desc $ run test `shouldBe` run shouldbe
run = typecheck <=< pProgram . myLexer
run = fmap printTree . typecheck <=< pProgram . myLexer
ok (Right _) = True
ok (Left _) = False
ok (Left _) = False
bad = not . ok