inference does not depend on order. mutual recursion still not working correctly

This commit is contained in:
sebastianselander 2023-03-29 17:30:31 +02:00
parent 29fcddf44c
commit 4efe7cf9a2
3 changed files with 97 additions and 282 deletions

View file

@ -12,6 +12,7 @@ import Control.Monad.Except
import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (first)
import Data.Coerce (coerce)
import Data.Function (on)
import Data.List (foldl')
@ -27,7 +28,7 @@ import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr qualified as T
initCtx = Ctx mempty
initEnv = Env 0 'a' mempty mempty mempty
initEnv = Env 0 'a' mempty mempty mempty "" mempty mempty
run :: Infer a -> Either Error a
run = run' initEnv initCtx
@ -51,8 +52,20 @@ typecheck = onLeft msg . run . checkPrg
checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = do
preRun bs
bs' <- checkDef bs
return $ T.Program bs'
bs <- checkDef bs
sub <- solveUndecidable
dec <- gets toDecide
trace (printTree bs) pure ()
bs <- mapM (mono sub) bs
return $ T.Program bs
mono :: Subst -> T.Def' Type -> Infer (T.Def' Type)
mono s bind@(T.DBind (T.Bind (name, t) args e)) = do
b <- gets (S.member name . toDecide)
if b
then return $ T.DBind $ T.Bind (name, apply s t) (apply s args) (apply s e)
else return bind
mono _ (T.DData d) = return $ T.DData d
preRun :: [Def] -> Infer ()
preRun [] = return ()
@ -66,7 +79,7 @@ preRun (x : xs) = case x of
"Duplicate signatures for function"
quote $ printTree n
)
insertSig (coerce n) (Just $ skolemize t) >> preRun xs
insertSig (coerce n) (Just t) >> preRun xs
DBind (Bind n _ e) -> do
collect (collectTVars e)
s <- gets sigs
@ -91,25 +104,15 @@ checkDef (x : xs) = case x of
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
checkBind :: Bind -> Infer (T.Bind' Type)
checkBind (Bind name args e) = do
checkBind bind@(Bind name args e) = do
setCurrentBind $ coerce name
let lambda = makeLambda e (reverse (coerce args))
(sub0, (e, lambda_t)) <- inferExp lambda
(e, lambda_t) <- inferExp lambda
s <- gets sigs
case M.lookup (coerce name) s of
Just (Just t') -> do
-- \| TODO: Fix, this is not correct
let fsig = apply sub0 t'
sub1 <- liftEither $ runIdentity $ runExceptT $ execStateT (typeEq fsig lambda_t) mempty
sub2 <- liftEither $ runIdentity $ runExceptT $ execStateT (typeEq lambda_t fsig) mempty
unless
(lambda_t == apply sub1 fsig && apply sub2 lambda_t == fsig)
( uncatchableErr $ Aux.do
"Inferred type"
quote $ printTree lambda_t
"does not match specified type"
quote $ printTree t'
)
return $ T.Bind (coerce name, lambda_t) [] (e, lambda_t)
sub1 <- bindErr (unify lambda_t (skolemize t')) bind
return $ T.Bind (coerce name, apply sub1 t') [] (e, lambda_t)
_ -> do
insertSig (coerce name) (Just lambda_t)
return (T.Bind (coerce name, lambda_t) [] (e, lambda_t))
@ -123,7 +126,7 @@ checkData err@(Data typ injs) = do
TData name typs
| Right tvars' <- mapM toTVar typs ->
pure (name, tvars')
TAll _ _ -> uncatchableErr "Explicit foralls not allowed, for now"
TAll _ _ -> uncatchableErr "Explicit forall not allowed, for now"
_ ->
uncatchableErr $
unwords ["Bad data type definition: ", printTree typ]
@ -158,7 +161,7 @@ checkInj (Inj c inj_typ) name tvars
where
boundTVars :: [TVar] -> Type -> Either Error Bool
boundTVars tvars' = \case
TAll{} -> uncatchableErr "Explicit foralls not allowed, for now"
TAll{} -> uncatchableErr "Explicit forall not allowed, for now"
TFun t1 t2 -> do
t1' <- boundTVars tvars t1
t2' <- boundTVars tvars t2
@ -177,11 +180,12 @@ returnType :: Type -> Type
returnType (TFun _ t2) = returnType t2
returnType a = a
inferExp :: Exp -> Infer (Subst, T.ExpT' Type)
inferExp :: Exp -> Infer (T.ExpT' Type)
inferExp e = do
(s, (e', t)) <- algoW e
let subbed = apply s t
return (s, (e', subbed))
modify (\st -> st{undecidedSigs = apply s st.undecidedSigs})
return (e', subbed)
class CollectTVars a where
collectTVars :: a -> Set T.Ident
@ -225,7 +229,7 @@ algoW = \case
-- \| x : σ ∈ Γ τ = inst(σ)
-- \| ----------------------
-- \| Γ ⊢ x : τ, ∅
EVar i -> do
EVar (LIdent i) -> do
var <- asks vars
case M.lookup (coerce i) var of
Just t ->
@ -237,7 +241,8 @@ algoW = \case
Just (Just t) -> return (nullSubst, (T.EVar $ coerce i, t))
Just Nothing -> do
fr <- fresh
insertSig (coerce i) (Just fr)
cb <- gets currentBind
modify (\st -> st{toDecide = S.insert cb st.toDecide, undecidedSigs = M.insert (coerce $ concat [[prefix], i, [delim], coerce cb]) fr st.undecidedSigs})
return (nullSubst, (T.EVar $ coerce i, fr))
Nothing ->
uncatchableErr $
@ -591,6 +596,9 @@ instance SubstType (Map T.Ident Type) where
apply :: Subst -> Map T.Ident Type -> Map T.Ident Type
apply = M.map . apply
instance SubstType (T.ExpT' Type) where
apply s (e, t) = (apply s e, apply s t)
instance SubstType (T.Exp' Type) where
apply s = \case
T.EVar i -> T.EVar i
@ -605,6 +613,11 @@ instance SubstType (T.Exp' Type) where
T.ECase e brnch -> T.ECase (apply s e) (apply s brnch)
T.EInj i -> T.EInj i
instance SubstType (T.Def' Type) where
apply s = \case
T.DBind (T.Bind name args e) -> T.DBind $ T.Bind (apply s name) (apply s args) (apply s e)
d -> d
instance SubstType (T.Branch' Type) where
apply s (T.Branch (i, t) e) = T.Branch (apply s i, apply s t) (apply s e)
@ -616,18 +629,18 @@ instance SubstType (T.Pattern' Type) where
T.PCatch -> T.PCatch
T.PEnum i -> T.PEnum i
instance SubstType (T.Pattern' Type, Type) where
apply s (p, t) = (apply s p, apply s t)
instance SubstType a => SubstType [a] where
apply s = map (apply s)
instance (SubstType a, SubstType b) => SubstType (a, b) where
apply s (a, b) = (apply s a, apply s b)
instance SubstType (T.Id' Type) where
apply s (name, t) = (name, apply s t)
-- | Represents the empty substition set
nullSubst :: Subst
nullSubst = M.empty
nullSubst = mempty
-- | Compose two substitution sets
compose :: Subst -> Subst -> Subst
@ -676,6 +689,31 @@ with an equivalent name has been declared already
existInj :: T.Ident -> Infer (Maybe Type)
existInj n = gets (M.lookup n . injections)
setCurrentBind :: T.Ident -> Infer ()
setCurrentBind i = modify (\st -> st{currentBind = i})
solveUndecidable :: Infer Subst
solveUndecidable = do
sigs <- gets sigs
undecided <- gets undecidedSigs
let xs = M.toList undecided
ys <-
maybeToRightM
(Error "SIGNATURE MISSING" False)
(mapM (tupSequence . first (join . flip M.lookup sigs . getOriginal)) xs)
composeAll <$> mapM (uncurry unify) ys
tupSequence :: Monad m => (m a, b) -> m (a, b)
tupSequence (ma, b) = (,b) <$> ma
getOriginal :: T.Ident -> T.Ident
getOriginal (T.Ident i) = coerce $ takeWhile (/= delim) $ drop 1 i
delim :: Char
delim = '_'
prefix :: Char
prefix = '$'
flattenType :: Type -> [Type]
flattenType (TFun a b) = flattenType a <> flattenType b
flattenType a = [a]
@ -740,19 +778,30 @@ exprErr :: (Monad m, MonadError Error m) => m a -> Exp -> m a
exprErr ma exp =
catchError
ma
( \x ->
if x.catchable
then
throwError
( x
{ msg =
x.msg
( \err -> if err.catchable
then throwError
( err { msg = err.msg
<> " in expression: \n"
<> printTree exp
, catchable = False
}
)
else throwError x
else throwError err
)
bindErr :: (Monad m, MonadError Error m) => m a -> Bind -> m a
bindErr ma bind =
catchError
ma
( \err -> if err.catchable
then throwError
( err { msg = err.msg
<> " in function: \n"
<> printTree bind
, catchable = False
}
)
else throwError err
)
{- | Catch an error if possible and add the given
@ -762,18 +811,18 @@ dataErr :: Infer a -> Data -> Infer a
dataErr ma d =
catchError
ma
( \x ->
if x.catchable
( \err ->
if err.catchable
then
throwError
( x
( err
{ msg =
x.msg
err.msg
<> " in data: \n"
<> printTree d
}
)
else throwError (x{catchable = False})
else throwError (err{catchable = False})
)
unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d])
@ -793,6 +842,9 @@ data Env = Env
, sigs :: Map T.Ident (Maybe Type)
, injections :: Map T.Ident Type
, takenTypeVars :: Set T.Ident
, currentBind :: T.Ident
, undecidedSigs :: Map T.Ident Type
, toDecide :: Set T.Ident
}
deriving (Show)
@ -811,3 +863,6 @@ uncatchableErr msg = throwError $ Error msg False
quote :: String -> String
quote s = "'" ++ s ++ "'"
ctrace :: (Monad m, Show a) => String -> a -> m ()
ctrace str a = trace (str ++ ": " ++ show a) pure ()

View file

@ -1,9 +0,0 @@
module DoStrings where
import Prelude hiding ((>>), (>>=))
(>>) :: String -> String -> String
(>>) str1 str2 = str1 ++ "\n" ++ str2
(>>=) :: String -> (String -> String) -> String
(>>=) str f = f str

View file

@ -1,231 +0,0 @@
{-# LANGUAGE QualifiedDo #-}
{-# LANGUAGE NoImplicitPrelude #-}
module Main where
import Control.Monad ((<=<))
import DoStrings qualified as D
import Grammar.Par (myLexer, pProgram)
import Test.Hspec
import Prelude (Bool (..), Either (..), IO, mapM_, not, ($), (.))
-- import Test.QuickCheck
import TypeChecker.TypeChecker (typecheck)
main :: IO ()
main = do
mapM_ hspec goods
mapM_ hspec bads
mapM_ hspec bes
goods =
[ testSatisfy
"Basic polymorphism with multiple type variables"
( D.do
_const
"main = const 'a' 65 ;"
)
ok
, testSatisfy
"Head with a correct signature is accepted"
( D.do
_List
_headSig
_head
)
ok
, testSatisfy
"Most simple inference possible"
( D.do
_id
)
ok
, testSatisfy
"Pattern matching on a nested list"
( D.do
_List
"main : List (List (a)) -> Int ;"
"main xs = case xs of {"
" Cons Nil _ => 1 ;"
" _ => 0 ;"
"};"
)
ok
]
bads =
[ testSatisfy
"Infinite type unification should not succeed"
( D.do
"main = \\x. x x ;"
)
bad
, testSatisfy
"Pattern matching using different types should not succeed"
( D.do
_List
"bad xs = case xs of {"
" 1 => 0 ;"
" Nil => 0 ;"
"};"
)
bad
, testSatisfy
"Using a concrete function (data type) on a skolem variable should not succeed"
( D.do
_Bool
_not
"f : a -> Bool () ;"
"f x = not x ;"
)
bad
, testSatisfy
"Using a concrete function (primitive type) on a skolem variable should not succeed"
( D.do
"plusOne : Int -> Int ;"
"plusOne x = x + 1 ;"
"f : a -> Int ;"
"f x = plusOne x ;"
)
bad
, testSatisfy
"A function without signature used in an incompatible context should not succeed"
( D.do
"main = _id 1 2 ;"
"_id x = x ;"
)
bad
, testSatisfy
"Pattern matching on literal and _List should not succeed"
( D.do
_List
"length : List (c) -> Int;"
"length _List = case _List of {"
" 0 => 0;"
" Cons x xs => 1 + length xs;"
"};"
)
bad
, testSatisfy
"List of function Int -> Int functions should not be usable on Char"
( D.do
_List
"main : List (Int -> Int) -> Int ;"
"main xs = case xs of {"
" Cons f _ => f 'a' ;"
" Nil => 0 ;"
" };"
)
bad
, testSatisfy
"id with incorrect signature"
( D.do
"id : a -> b;"
"id x = x;"
)
bad
, testSatisfy
"incorrect type signature on id lambda"
( D.do
"id = ((\\x. x) : a -> b);"
)
bad
]
bes =
[ testBe
"A basic arithmetic function should be able to be inferred"
( D.do
"plusOne x = x + 1 ;"
"main x = plusOne x ;"
)
( D.do
"plusOne : Int -> Int ;"
"plusOne x = x + 1 ;"
"main : Int -> Int ;"
"main x = plusOne x ;"
)
, testBe
"A basic arithmetic function should be able to be inferred"
( D.do
"plusOne x = x + 1 ;"
)
( D.do
"plusOne : Int -> Int ;"
"plusOne x = x + 1 ;"
)
, testBe
"List of function Int -> Int functions should be inferred corretly"
( D.do
_List
"main xs = case xs of {"
" Cons f _ => f 1 ;"
" Nil => 0 ;"
" };"
)
( D.do
_List
"main : List (Int -> Int) -> Int ;"
"main xs = case xs of {"
" Cons f _ => f 1 ;"
" Nil => 0 ;"
" };"
)
]
testSatisfy desc test satisfaction = specify desc $ run test `shouldSatisfy` satisfaction
testBe desc test shouldbe = specify desc $ run test `shouldBe` run shouldbe
run = typecheck <=< pProgram . myLexer
ok (Right _) = True
ok (Left _) = False
bad = not . ok
-- FUNCTIONS
_const = D.do
"const : a -> b -> a ;"
"const x y = x ;"
_List = D.do
"data List (a) where"
" {"
" Nil : List (a)"
" Cons : a -> List (a) -> List (a)"
" };"
_headSig = D.do
"head : List (a) -> a ;"
_head = D.do
"head xs = "
" case xs of {"
" Cons x xs => x ;"
" };"
_Bool = D.do
"data Bool () where {"
" True : Bool ()"
" False : Bool ()"
"};"
_not = D.do
"not : Bool () -> Bool () ;"
"not x = case x of {"
" True => False ;"
" False => True ;"
"};"
_id = "id x = x ;"
_Maybe = D.do
"data Maybe (a) where {"
" Nothing : Maybe (a)"
" Just : a -> Maybe (a)"
" };"
_fmap = D.do
"fmap f ma = case ma of {"
" Nothing => Nothing ;"
" Just a => Just (f a) ;"
"};"