Fixed bug in HM, fixed and reimported tests.
This commit is contained in:
parent
c5fbd70756
commit
49ef3f9f7c
4 changed files with 186 additions and 227 deletions
8
Justfile
8
Justfile
|
|
@ -35,7 +35,13 @@ bidm FILE:
|
|||
cabal run language -- -d -t bi -m {{FILE}}
|
||||
|
||||
hmp FILE:
|
||||
cabal run language -- -t hm -d -p {{FILE}}
|
||||
cabal run language -- -t hm -p {{FILE}}
|
||||
|
||||
bip FILE:
|
||||
cabal run language -- -t bi -p {{FILE}}
|
||||
|
||||
hmdp FILE:
|
||||
cabal run language -- -t hm -d -p {{FILE}}
|
||||
|
||||
bidp FILE:
|
||||
cabal run language -- -t bi -d -p {{FILE}}
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ desugarType = \case
|
|||
let (name : tvars) = flatten t1 ++ [t2]
|
||||
in case name of
|
||||
TIdent ident -> TData ident (map desugarType tvars)
|
||||
_ -> error "desugarType is not implemented correctly"
|
||||
_ -> error "desugarType is not implemented correctly, or the user made a mistake"
|
||||
TLit l -> TLit l
|
||||
TVar v -> TVar v
|
||||
(TAll i t) -> TAll i (desugarType t)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@
|
|||
module TypeChecker.TypeCheckerHm where
|
||||
|
||||
import Auxiliary (int, maybeToRightM, typeof, unzip4)
|
||||
import qualified Auxiliary as Aux
|
||||
import Auxiliary qualified as Aux
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Identity (Identity, runIdentity)
|
||||
import Control.Monad.Reader
|
||||
|
|
@ -19,15 +19,14 @@ import Data.Function (on)
|
|||
import Data.List (foldl', nub, sortOn)
|
||||
import Data.List.Extra (unsnoc)
|
||||
import Data.Map (Map)
|
||||
import qualified Data.Map as M
|
||||
import Data.Map qualified as M
|
||||
import Data.Maybe (fromJust)
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
import Debug.Trace (trace, traceShow)
|
||||
import Data.Set qualified as S
|
||||
import Grammar.Abs
|
||||
import Grammar.Print (printTree)
|
||||
import qualified TypeChecker.TypeCheckerIr as T
|
||||
import TypeChecker.TypeCheckerIr (T, T')
|
||||
import TypeChecker.TypeCheckerIr qualified as T
|
||||
|
||||
{-
|
||||
TODO
|
||||
|
|
@ -193,14 +192,11 @@ checkBind (Bind name args e) = do
|
|||
s <- gets sigs
|
||||
case M.lookup (coerce name) s of
|
||||
Just (Just typSig) -> do
|
||||
env <- asks vars
|
||||
let genInfSig = generalize mempty infSig
|
||||
trace "\n\n" pure ()
|
||||
trace ("genInfSig: " ++ printTree genInfSig) pure ()
|
||||
trace ("typSig: " ++ printTree typSig ++ "\n\n") pure ()
|
||||
sub <- genInfSig `unify` typSig
|
||||
--b <- (genInfSig <<= typSig)
|
||||
unless True
|
||||
b <- genInfSig <<= typSig
|
||||
unless
|
||||
b
|
||||
( throwError $
|
||||
Error
|
||||
( Aux.do
|
||||
|
|
@ -231,7 +227,7 @@ checkData err@(Data typ injs) = do
|
|||
pure (name, tvars')
|
||||
_ ->
|
||||
uncatchableErr $
|
||||
unwords ["Bad data type definition: ", printTree typ]
|
||||
unwords ["Bad data type definition: ", show typ]
|
||||
|
||||
checkInj :: (MonadError Error m, MonadState Env m, Monad m) => Inj -> UIdent -> [TVar] -> m ()
|
||||
checkInj (Inj c inj_typ) name tvars
|
||||
|
|
@ -296,17 +292,17 @@ algoW :: Exp -> Infer (Subst, T' T.Exp' Type)
|
|||
algoW = \case
|
||||
err@(EAnn e t) -> do
|
||||
(sub0, (e', t')) <- exprErr (algoW e) err
|
||||
sub1 <- unify t t'
|
||||
sub2 <- unify t' t
|
||||
b <- (apply sub1 t <<= apply sub2 t')
|
||||
unless b
|
||||
sub1 <- unify t' t
|
||||
b <- t' <<= t
|
||||
unless
|
||||
b
|
||||
( uncatchableErr $ Aux.do
|
||||
"Annotated type"
|
||||
quote $ printTree t
|
||||
"does not match inferred type"
|
||||
quote $ printTree t'
|
||||
)
|
||||
let comp = sub2 `compose` sub1 `compose` sub0
|
||||
let comp = sub1 `compose` sub0
|
||||
return (comp, (apply comp e', t))
|
||||
|
||||
-- \| ------------------
|
||||
|
|
@ -640,11 +636,10 @@ fresh = do
|
|||
modify (\st -> st{count = succ (count st)})
|
||||
return $ TVar $ MkTVar $ LIdent $ show n
|
||||
|
||||
-- Is the left a subtype of the right
|
||||
-- Is the left more general than the right
|
||||
(<<=) :: Type -> Type -> Infer Bool
|
||||
(<<=) a b = case (a, b) of
|
||||
(TVar a, TVar b) -> return $ a == b
|
||||
(TVar a, _) -> return True
|
||||
(TVar _, _) -> return True
|
||||
(TFun a b, TFun c d) -> do
|
||||
bfirst <- a <<= c
|
||||
bsecond <- b <<= d
|
||||
|
|
@ -652,37 +647,43 @@ fresh = do
|
|||
(TData n1 ts1, TData n2 ts2) -> do
|
||||
b <- and <$> zipWithM (<<=) ts1 ts2
|
||||
return (b && n1 == n2 && length ts1 == length ts2)
|
||||
(t1@(TAll _ _ ), t2) -> let (tvars1, t1') = gatherTVars [] t1
|
||||
(t1@(TAll _ _), t2) ->
|
||||
let (tvars1, t1') = gatherTVars [] t1
|
||||
(tvars2, t2') = gatherTVars [] t2
|
||||
in go (tvars1 ++ tvars2) t1 t2
|
||||
(t1, t2@(TAll _ _)) -> let (tvars1, t1') = gatherTVars [] t1
|
||||
in go (tvars1 ++ tvars2) t1' t2'
|
||||
(t1, t2@(TAll _ _)) ->
|
||||
let (tvars1, t1') = gatherTVars [] t1
|
||||
(tvars2, t2') = gatherTVars [] t2
|
||||
in go (tvars1 ++ tvars2) t1' t2'
|
||||
(t1, t2) -> return $ t1 == t2
|
||||
where
|
||||
go :: [TVar] -> Type -> Type -> Infer Bool
|
||||
go tvars t1 t2 = do
|
||||
-- probably not necessary
|
||||
freshies <- mapM (const fresh) tvars
|
||||
let sub = M.fromList $ zip [coerce x | (MkTVar x) <- tvars] freshies
|
||||
let t1' = apply sub t1
|
||||
let t2' = apply sub t2
|
||||
trace ("t1': " ++ printTree t1') pure ()
|
||||
trace ("t2': " ++ printTree t2') pure ()
|
||||
t1' <<= t2'
|
||||
|
||||
{-
|
||||
Renaming: a -> b -> a and c -> d -> c
|
||||
gives 0 -> 1 -> 0 and -> 2 -> 3 -> 2
|
||||
They have to be given the same name. Alpha-renaming in the subtype check is done incorrectly
|
||||
-}
|
||||
let alph = execState (alpha t1' t2') mempty
|
||||
return $ apply alph t1' == t2'
|
||||
|
||||
-- Pre-condition: All TAlls are outermost
|
||||
gatherTVars :: [TVar] -> Type -> ([TVar], Type)
|
||||
gatherTVars tvars (TAll tvar t) =
|
||||
let (tvars', t') = gatherTVars (tvar : tvars) t
|
||||
in (tvars', t')
|
||||
gatherTVars tvars (TAll tvar t) = gatherTVars (tvar : tvars) t
|
||||
gatherTVars tvars t = (tvars, t)
|
||||
|
||||
-- Alpha rename the first type's type variable to match second.
|
||||
-- Pre-condition: No TAll are checked
|
||||
alpha :: Type -> Type -> State (Map T.Ident Type) ()
|
||||
alpha t1 t2 = case (t1, t2) of
|
||||
(TVar (MkTVar (LIdent i)), t2) -> do
|
||||
m <- get
|
||||
put (M.insert (coerce i) t2 m)
|
||||
(TFun t1 t2, TFun t3 t4) -> do
|
||||
alpha t1 t3
|
||||
alpha t2 t4
|
||||
(TData _ ts1, TData _ ts2) -> zipWithM_ alpha ts1 ts2
|
||||
_ -> return ()
|
||||
|
||||
-- | A class for substitutions
|
||||
class SubstType t where
|
||||
|
|
@ -956,12 +957,3 @@ quote s = "'" ++ s ++ "'"
|
|||
|
||||
letters :: [T.Ident]
|
||||
letters = map T.Ident $ [1 ..] >>= flip replicateM ['a' .. 'z']
|
||||
|
||||
{-
|
||||
|
||||
|
||||
first = TAll (MkTVar (LIdent "a")) (TAll (MkTVar (LIdent "b")) (TFun (TVar (MkTVar (LIdent "a"))) (TFun (TVar (MkTVar (LIdent "b"))) (TVar (MkTVar (LIdent "b"))))))
|
||||
second = TAll (MkTVar (LIdent "a")) (TAll (MkTVar (LIdent "b")) (TFun (TVar (MkTVar (LIdent "a"))) (TFun (TVar (MkTVar (LIdent "b"))) (TVar (MkTVar (LIdent "a"))))))
|
||||
|
||||
-}
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ import TypeChecker.TypeCheckerIr (Program)
|
|||
testTypeCheckerHm = describe "Hindley-Milner type checker test" $ do
|
||||
sequence_ goods
|
||||
sequence_ bads
|
||||
sequence_ bes
|
||||
|
||||
goods =
|
||||
[ testSatisfy
|
||||
|
|
@ -55,6 +54,35 @@ goods =
|
|||
"};"
|
||||
)
|
||||
ok
|
||||
, testSatisfy
|
||||
"A basic arithmetic function should be able to be inferred"
|
||||
( D.do
|
||||
"plusOne x = x + 1 ;"
|
||||
"main x = plusOne x ;"
|
||||
)
|
||||
ok
|
||||
, testSatisfy
|
||||
"List of function Int -> Int functions should be inferred corretly"
|
||||
( D.do
|
||||
_List
|
||||
"main xs = case xs of {"
|
||||
" Cons f _ => f 1 ;"
|
||||
" Nil => 0 ;"
|
||||
" };"
|
||||
)
|
||||
ok
|
||||
, testSatisfy
|
||||
"length function on int list infers correct signature"
|
||||
( D.do
|
||||
"data List where "
|
||||
" Nil : List"
|
||||
" Cons : Int -> List -> List"
|
||||
|
||||
"length xs = case xs of"
|
||||
" Nil => 0"
|
||||
" Cons _ xs => 1 + length xs"
|
||||
)
|
||||
ok
|
||||
]
|
||||
|
||||
bads =
|
||||
|
|
@ -121,97 +149,38 @@ bads =
|
|||
" };"
|
||||
)
|
||||
bad
|
||||
-- FIXME FAILING TEST
|
||||
-- , testSatisfy
|
||||
-- "id with incorrect signature"
|
||||
-- ( D.do
|
||||
-- "id : a -> b;"
|
||||
-- "id x = x;"
|
||||
-- )
|
||||
-- bad
|
||||
-- FIXME FAILING TEST
|
||||
-- , testSatisfy
|
||||
-- "incorrect signature on const"
|
||||
-- ( D.do
|
||||
-- "const : a -> b -> b;"
|
||||
-- "const x y = x"
|
||||
-- )
|
||||
-- bad
|
||||
-- FIXME FAILING TEST
|
||||
-- , testSatisfy
|
||||
-- "incorrect type signature on id lambda"
|
||||
-- ( D.do
|
||||
-- "id = ((\\x. x) : a -> b);"
|
||||
-- )
|
||||
-- bad
|
||||
]
|
||||
|
||||
bes =
|
||||
[ testBe
|
||||
"A basic arithmetic function should be able to be inferred"
|
||||
, -- FIXME FAILING TEST
|
||||
testSatisfy
|
||||
"id with incorrect signature"
|
||||
( D.do
|
||||
"plusOne x = x + 1 ;"
|
||||
"main x = plusOne x ;"
|
||||
"id : a -> b;"
|
||||
"id x = x;"
|
||||
)
|
||||
bad
|
||||
, -- FIXME FAILING TEST
|
||||
testSatisfy
|
||||
"incorrect signature on const"
|
||||
( D.do
|
||||
"plusOne : Int -> Int ;"
|
||||
"plusOne x = x + 1 ;"
|
||||
"main : Int -> Int ;"
|
||||
"main x = plusOne x ;"
|
||||
"const : a -> b -> b;"
|
||||
"const x y = x"
|
||||
)
|
||||
, testBe
|
||||
"A basic arithmetic function should be able to be inferred"
|
||||
bad
|
||||
, -- FIXME FAILING TEST
|
||||
testSatisfy
|
||||
"incorrect type signature on id lambda"
|
||||
( D.do
|
||||
"plusOne x = x + 1 ;"
|
||||
)
|
||||
( D.do
|
||||
"plusOne : Int -> Int ;"
|
||||
"plusOne x = x + 1 ;"
|
||||
)
|
||||
, testBe
|
||||
"List of function Int -> Int functions should be inferred corretly"
|
||||
( D.do
|
||||
_List
|
||||
"main xs = case xs of {"
|
||||
" Cons f _ => f 1 ;"
|
||||
" Nil => 0 ;"
|
||||
" };"
|
||||
)
|
||||
( D.do
|
||||
_List
|
||||
"main : List (Int -> Int) -> Int ;"
|
||||
"main xs = case xs of {"
|
||||
" Cons f _ => f 1 ;"
|
||||
" Nil => 0 ;"
|
||||
" };"
|
||||
)
|
||||
, testBe
|
||||
"length function on int list infers correct signature"
|
||||
( D.do
|
||||
"data List where "
|
||||
" Nil : List"
|
||||
" Cons : Int -> List -> List"
|
||||
|
||||
"length xs = case xs of"
|
||||
" Nil => 0"
|
||||
" Cons _ xs => 1 + length xs"
|
||||
)
|
||||
( D.do
|
||||
"data List where"
|
||||
" Nil : List"
|
||||
" Cons : Int -> List -> List"
|
||||
|
||||
"length : List -> Int"
|
||||
"length xs = case xs of"
|
||||
" Nil => 0"
|
||||
" Cons _ xs => 1 + length xs"
|
||||
"id = ((\\x. x) : a -> b);"
|
||||
)
|
||||
bad
|
||||
]
|
||||
|
||||
testSatisfy desc test satisfaction = specify desc $ run test `shouldSatisfy` satisfaction
|
||||
testBe desc test shouldbe = specify desc $ run test `shouldBe` run shouldbe
|
||||
|
||||
run = fmap (printTree . fst) . typecheck <=< fmap desugar . pProgram . myLexer
|
||||
run s = do
|
||||
p <- (fmap desugar . pProgram . resolveLayout True . myLexer) s
|
||||
reportForall Hm p
|
||||
(printTree . fst) <$> (typecheck <=< rename <=< annotateForall) p
|
||||
|
||||
ok (Right _) = True
|
||||
ok (Left _) = False
|
||||
|
|
@ -221,45 +190,37 @@ bad = not . ok
|
|||
-- FUNCTIONS
|
||||
|
||||
_const = D.do
|
||||
"const : a -> b -> a ;"
|
||||
"const x y = x ;"
|
||||
"const : a -> b -> a"
|
||||
"const x y = x"
|
||||
_List = D.do
|
||||
"data List a where {"
|
||||
" Nil : List a;"
|
||||
" Cons : a -> List a -> List a;"
|
||||
"};"
|
||||
"data List a where { Nil : List a; Cons : a -> List a -> List a; }"
|
||||
|
||||
_headSig = D.do
|
||||
"head : List a -> a ;"
|
||||
"head : List a -> a"
|
||||
|
||||
_head = D.do
|
||||
"head xs = "
|
||||
" case xs of {"
|
||||
" Cons x xs => x ;"
|
||||
" };"
|
||||
"head xs = case xs of"
|
||||
" Cons x xs => x"
|
||||
|
||||
_Bool = D.do
|
||||
"data Bool where {"
|
||||
"data Bool where"
|
||||
" True : Bool"
|
||||
" False : Bool"
|
||||
"};"
|
||||
|
||||
_not = D.do
|
||||
"not : Bool -> Bool ;"
|
||||
"not x = case x of {"
|
||||
" True => False ;"
|
||||
" False => True ;"
|
||||
"};"
|
||||
"not x = case x of"
|
||||
" True => False"
|
||||
" False => True"
|
||||
|
||||
_id = "id x = x ;"
|
||||
|
||||
_Maybe = D.do
|
||||
"data Maybe a where {"
|
||||
"data Maybe a where"
|
||||
" Nothing : Maybe a"
|
||||
" Just : a -> Maybe a"
|
||||
" };"
|
||||
|
||||
_fmap = D.do
|
||||
"fmap f ma = case ma of {"
|
||||
" Nothing => Nothing ;"
|
||||
" Just a => Just (f a) ;"
|
||||
"};"
|
||||
"fmap f ma = case ma of"
|
||||
" Nothing => Nothing"
|
||||
" Just a => Just (f a)"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue