Fixed bug in HM, fixed and reimported tests.

This commit is contained in:
sebastian 2023-05-10 23:54:31 +02:00
parent c5fbd70756
commit 49ef3f9f7c
4 changed files with 186 additions and 227 deletions

View file

@ -35,7 +35,13 @@ bidm FILE:
cabal run language -- -d -t bi -m {{FILE}} cabal run language -- -d -t bi -m {{FILE}}
hmp FILE: hmp FILE:
cabal run language -- -t hm -d -p {{FILE}} cabal run language -- -t hm -p {{FILE}}
bip FILE: bip FILE:
cabal run language -- -t bi -p {{FILE}} cabal run language -- -t bi -p {{FILE}}
hmdp FILE:
cabal run language -- -t hm -d -p {{FILE}}
bidp FILE:
cabal run language -- -t bi -d -p {{FILE}}

View file

@ -45,7 +45,7 @@ desugarType = \case
let (name : tvars) = flatten t1 ++ [t2] let (name : tvars) = flatten t1 ++ [t2]
in case name of in case name of
TIdent ident -> TData ident (map desugarType tvars) TIdent ident -> TData ident (map desugarType tvars)
_ -> error "desugarType is not implemented correctly" _ -> error "desugarType is not implemented correctly, or the user made a mistake"
TLit l -> TLit l TLit l -> TLit l
TVar v -> TVar v TVar v -> TVar v
(TAll i t) -> TAll i (desugarType t) (TAll i t) -> TAll i (desugarType t)

View file

@ -8,7 +8,7 @@
module TypeChecker.TypeCheckerHm where module TypeChecker.TypeCheckerHm where
import Auxiliary (int, maybeToRightM, typeof, unzip4) import Auxiliary (int, maybeToRightM, typeof, unzip4)
import qualified Auxiliary as Aux import Auxiliary qualified as Aux
import Control.Monad.Except import Control.Monad.Except
import Control.Monad.Identity (Identity, runIdentity) import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader import Control.Monad.Reader
@ -19,15 +19,14 @@ import Data.Function (on)
import Data.List (foldl', nub, sortOn) import Data.List (foldl', nub, sortOn)
import Data.List.Extra (unsnoc) import Data.List.Extra (unsnoc)
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as M import Data.Map qualified as M
import Data.Maybe (fromJust) import Data.Maybe (fromJust)
import Data.Set (Set) import Data.Set (Set)
import qualified Data.Set as S import Data.Set qualified as S
import Debug.Trace (trace, traceShow)
import Grammar.Abs import Grammar.Abs
import Grammar.Print (printTree) import Grammar.Print (printTree)
import qualified TypeChecker.TypeCheckerIr as T
import TypeChecker.TypeCheckerIr (T, T') import TypeChecker.TypeCheckerIr (T, T')
import TypeChecker.TypeCheckerIr qualified as T
{- {-
TODO TODO
@ -193,14 +192,11 @@ checkBind (Bind name args e) = do
s <- gets sigs s <- gets sigs
case M.lookup (coerce name) s of case M.lookup (coerce name) s of
Just (Just typSig) -> do Just (Just typSig) -> do
env <- asks vars
let genInfSig = generalize mempty infSig let genInfSig = generalize mempty infSig
trace "\n\n" pure ()
trace ("genInfSig: " ++ printTree genInfSig) pure ()
trace ("typSig: " ++ printTree typSig ++ "\n\n") pure ()
sub <- genInfSig `unify` typSig sub <- genInfSig `unify` typSig
--b <- (genInfSig <<= typSig) b <- genInfSig <<= typSig
unless True unless
b
( throwError $ ( throwError $
Error Error
( Aux.do ( Aux.do
@ -231,7 +227,7 @@ checkData err@(Data typ injs) = do
pure (name, tvars') pure (name, tvars')
_ -> _ ->
uncatchableErr $ uncatchableErr $
unwords ["Bad data type definition: ", printTree typ] unwords ["Bad data type definition: ", show typ]
checkInj :: (MonadError Error m, MonadState Env m, Monad m) => Inj -> UIdent -> [TVar] -> m () checkInj :: (MonadError Error m, MonadState Env m, Monad m) => Inj -> UIdent -> [TVar] -> m ()
checkInj (Inj c inj_typ) name tvars checkInj (Inj c inj_typ) name tvars
@ -296,17 +292,17 @@ algoW :: Exp -> Infer (Subst, T' T.Exp' Type)
algoW = \case algoW = \case
err@(EAnn e t) -> do err@(EAnn e t) -> do
(sub0, (e', t')) <- exprErr (algoW e) err (sub0, (e', t')) <- exprErr (algoW e) err
sub1 <- unify t t' sub1 <- unify t' t
sub2 <- unify t' t b <- t' <<= t
b <- (apply sub1 t <<= apply sub2 t') unless
unless b b
( uncatchableErr $ Aux.do ( uncatchableErr $ Aux.do
"Annotated type" "Annotated type"
quote $ printTree t quote $ printTree t
"does not match inferred type" "does not match inferred type"
quote $ printTree t' quote $ printTree t'
) )
let comp = sub2 `compose` sub1 `compose` sub0 let comp = sub1 `compose` sub0
return (comp, (apply comp e', t)) return (comp, (apply comp e', t))
-- \| ------------------ -- \| ------------------
@ -640,11 +636,10 @@ fresh = do
modify (\st -> st{count = succ (count st)}) modify (\st -> st{count = succ (count st)})
return $ TVar $ MkTVar $ LIdent $ show n return $ TVar $ MkTVar $ LIdent $ show n
-- Is the left a subtype of the right -- Is the left more general than the right
(<<=) :: Type -> Type -> Infer Bool (<<=) :: Type -> Type -> Infer Bool
(<<=) a b = case (a, b) of (<<=) a b = case (a, b) of
(TVar a, TVar b) -> return $ a == b (TVar _, _) -> return True
(TVar a, _) -> return True
(TFun a b, TFun c d) -> do (TFun a b, TFun c d) -> do
bfirst <- a <<= c bfirst <- a <<= c
bsecond <- b <<= d bsecond <- b <<= d
@ -652,37 +647,43 @@ fresh = do
(TData n1 ts1, TData n2 ts2) -> do (TData n1 ts1, TData n2 ts2) -> do
b <- and <$> zipWithM (<<=) ts1 ts2 b <- and <$> zipWithM (<<=) ts1 ts2
return (b && n1 == n2 && length ts1 == length ts2) return (b && n1 == n2 && length ts1 == length ts2)
(t1@(TAll _ _ ), t2) -> let (tvars1, t1') = gatherTVars [] t1 (t1@(TAll _ _), t2) ->
let (tvars1, t1') = gatherTVars [] t1
(tvars2, t2') = gatherTVars [] t2 (tvars2, t2') = gatherTVars [] t2
in go (tvars1 ++ tvars2) t1 t2 in go (tvars1 ++ tvars2) t1' t2'
(t1, t2@(TAll _ _)) -> let (tvars1, t1') = gatherTVars [] t1 (t1, t2@(TAll _ _)) ->
let (tvars1, t1') = gatherTVars [] t1
(tvars2, t2') = gatherTVars [] t2 (tvars2, t2') = gatherTVars [] t2
in go (tvars1 ++ tvars2) t1' t2' in go (tvars1 ++ tvars2) t1' t2'
(t1, t2) -> return $ t1 == t2 (t1, t2) -> return $ t1 == t2
where where
go :: [TVar] -> Type -> Type -> Infer Bool go :: [TVar] -> Type -> Type -> Infer Bool
go tvars t1 t2 = do go tvars t1 t2 = do
-- probably not necessary
freshies <- mapM (const fresh) tvars freshies <- mapM (const fresh) tvars
let sub = M.fromList $ zip [coerce x | (MkTVar x) <- tvars] freshies let sub = M.fromList $ zip [coerce x | (MkTVar x) <- tvars] freshies
let t1' = apply sub t1 let t1' = apply sub t1
let t2' = apply sub t2 let t2' = apply sub t2
trace ("t1': " ++ printTree t1') pure () let alph = execState (alpha t1' t2') mempty
trace ("t2': " ++ printTree t2') pure () return $ apply alph t1' == t2'
t1' <<= t2'
{-
Renaming: a -> b -> a and c -> d -> c
gives 0 -> 1 -> 0 and -> 2 -> 3 -> 2
They have to be given the same name. Alpha-renaming in the subtype check is done incorrectly
-}
-- Pre-condition: All TAlls are outermost -- Pre-condition: All TAlls are outermost
gatherTVars :: [TVar] -> Type -> ([TVar], Type) gatherTVars :: [TVar] -> Type -> ([TVar], Type)
gatherTVars tvars (TAll tvar t) = gatherTVars tvars (TAll tvar t) = gatherTVars (tvar : tvars) t
let (tvars', t') = gatherTVars (tvar : tvars) t
in (tvars', t')
gatherTVars tvars t = (tvars, t) gatherTVars tvars t = (tvars, t)
-- Alpha rename the first type's type variable to match second.
-- Pre-condition: No TAll are checked
alpha :: Type -> Type -> State (Map T.Ident Type) ()
alpha t1 t2 = case (t1, t2) of
(TVar (MkTVar (LIdent i)), t2) -> do
m <- get
put (M.insert (coerce i) t2 m)
(TFun t1 t2, TFun t3 t4) -> do
alpha t1 t3
alpha t2 t4
(TData _ ts1, TData _ ts2) -> zipWithM_ alpha ts1 ts2
_ -> return ()
-- | A class for substitutions -- | A class for substitutions
class SubstType t where class SubstType t where
@ -956,12 +957,3 @@ quote s = "'" ++ s ++ "'"
letters :: [T.Ident] letters :: [T.Ident]
letters = map T.Ident $ [1 ..] >>= flip replicateM ['a' .. 'z'] letters = map T.Ident $ [1 ..] >>= flip replicateM ['a' .. 'z']
{-
first = TAll (MkTVar (LIdent "a")) (TAll (MkTVar (LIdent "b")) (TFun (TVar (MkTVar (LIdent "a"))) (TFun (TVar (MkTVar (LIdent "b"))) (TVar (MkTVar (LIdent "b"))))))
second = TAll (MkTVar (LIdent "a")) (TAll (MkTVar (LIdent "b")) (TFun (TVar (MkTVar (LIdent "a"))) (TFun (TVar (MkTVar (LIdent "b"))) (TVar (MkTVar (LIdent "a"))))))
-}

View file

@ -20,7 +20,6 @@ import TypeChecker.TypeCheckerIr (Program)
testTypeCheckerHm = describe "Hindley-Milner type checker test" $ do testTypeCheckerHm = describe "Hindley-Milner type checker test" $ do
sequence_ goods sequence_ goods
sequence_ bads sequence_ bads
sequence_ bes
goods = goods =
[ testSatisfy [ testSatisfy
@ -55,6 +54,35 @@ goods =
"};" "};"
) )
ok ok
, testSatisfy
"A basic arithmetic function should be able to be inferred"
( D.do
"plusOne x = x + 1 ;"
"main x = plusOne x ;"
)
ok
, testSatisfy
"List of function Int -> Int functions should be inferred corretly"
( D.do
_List
"main xs = case xs of {"
" Cons f _ => f 1 ;"
" Nil => 0 ;"
" };"
)
ok
, testSatisfy
"length function on int list infers correct signature"
( D.do
"data List where "
" Nil : List"
" Cons : Int -> List -> List"
"length xs = case xs of"
" Nil => 0"
" Cons _ xs => 1 + length xs"
)
ok
] ]
bads = bads =
@ -121,97 +149,38 @@ bads =
" };" " };"
) )
bad bad
-- FIXME FAILING TEST , -- FIXME FAILING TEST
-- , testSatisfy testSatisfy
-- "id with incorrect signature" "id with incorrect signature"
-- ( D.do
-- "id : a -> b;"
-- "id x = x;"
-- )
-- bad
-- FIXME FAILING TEST
-- , testSatisfy
-- "incorrect signature on const"
-- ( D.do
-- "const : a -> b -> b;"
-- "const x y = x"
-- )
-- bad
-- FIXME FAILING TEST
-- , testSatisfy
-- "incorrect type signature on id lambda"
-- ( D.do
-- "id = ((\\x. x) : a -> b);"
-- )
-- bad
]
bes =
[ testBe
"A basic arithmetic function should be able to be inferred"
( D.do ( D.do
"plusOne x = x + 1 ;" "id : a -> b;"
"main x = plusOne x ;" "id x = x;"
) )
bad
, -- FIXME FAILING TEST
testSatisfy
"incorrect signature on const"
( D.do ( D.do
"plusOne : Int -> Int ;" "const : a -> b -> b;"
"plusOne x = x + 1 ;" "const x y = x"
"main : Int -> Int ;"
"main x = plusOne x ;"
) )
, testBe bad
"A basic arithmetic function should be able to be inferred" , -- FIXME FAILING TEST
testSatisfy
"incorrect type signature on id lambda"
( D.do ( D.do
"plusOne x = x + 1 ;" "id = ((\\x. x) : a -> b);"
)
( D.do
"plusOne : Int -> Int ;"
"plusOne x = x + 1 ;"
)
, testBe
"List of function Int -> Int functions should be inferred corretly"
( D.do
_List
"main xs = case xs of {"
" Cons f _ => f 1 ;"
" Nil => 0 ;"
" };"
)
( D.do
_List
"main : List (Int -> Int) -> Int ;"
"main xs = case xs of {"
" Cons f _ => f 1 ;"
" Nil => 0 ;"
" };"
)
, testBe
"length function on int list infers correct signature"
( D.do
"data List where "
" Nil : List"
" Cons : Int -> List -> List"
"length xs = case xs of"
" Nil => 0"
" Cons _ xs => 1 + length xs"
)
( D.do
"data List where"
" Nil : List"
" Cons : Int -> List -> List"
"length : List -> Int"
"length xs = case xs of"
" Nil => 0"
" Cons _ xs => 1 + length xs"
) )
bad
] ]
testSatisfy desc test satisfaction = specify desc $ run test `shouldSatisfy` satisfaction testSatisfy desc test satisfaction = specify desc $ run test `shouldSatisfy` satisfaction
testBe desc test shouldbe = specify desc $ run test `shouldBe` run shouldbe testBe desc test shouldbe = specify desc $ run test `shouldBe` run shouldbe
run = fmap (printTree . fst) . typecheck <=< fmap desugar . pProgram . myLexer run s = do
p <- (fmap desugar . pProgram . resolveLayout True . myLexer) s
reportForall Hm p
(printTree . fst) <$> (typecheck <=< rename <=< annotateForall) p
ok (Right _) = True ok (Right _) = True
ok (Left _) = False ok (Left _) = False
@ -221,45 +190,37 @@ bad = not . ok
-- FUNCTIONS -- FUNCTIONS
_const = D.do _const = D.do
"const : a -> b -> a ;" "const : a -> b -> a"
"const x y = x ;" "const x y = x"
_List = D.do _List = D.do
"data List a where {" "data List a where { Nil : List a; Cons : a -> List a -> List a; }"
" Nil : List a;"
" Cons : a -> List a -> List a;"
"};"
_headSig = D.do _headSig = D.do
"head : List a -> a ;" "head : List a -> a"
_head = D.do _head = D.do
"head xs = " "head xs = case xs of"
" case xs of {" " Cons x xs => x"
" Cons x xs => x ;"
" };"
_Bool = D.do _Bool = D.do
"data Bool where {" "data Bool where"
" True : Bool" " True : Bool"
" False : Bool" " False : Bool"
"};"
_not = D.do _not = D.do
"not : Bool -> Bool ;" "not : Bool -> Bool ;"
"not x = case x of {" "not x = case x of"
" True => False ;" " True => False"
" False => True ;" " False => True"
"};"
_id = "id x = x ;" _id = "id x = x ;"
_Maybe = D.do _Maybe = D.do
"data Maybe a where {" "data Maybe a where"
" Nothing : Maybe a" " Nothing : Maybe a"
" Just : a -> Maybe a" " Just : a -> Maybe a"
" };"
_fmap = D.do _fmap = D.do
"fmap f ma = case ma of {" "fmap f ma = case ma of"
" Nothing => Nothing ;" " Nothing => Nothing"
" Just a => Just (f a) ;" " Just a => Just (f a)"
"};"