inference does not depend on order. mutual recursion still not working correctly

This commit is contained in:
sebastianselander 2023-03-29 17:30:31 +02:00
parent 29fcddf44c
commit 4efe7cf9a2
3 changed files with 97 additions and 282 deletions

View file

@ -12,6 +12,7 @@ import Control.Monad.Except
import Control.Monad.Identity (Identity, runIdentity) import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader import Control.Monad.Reader
import Control.Monad.State import Control.Monad.State
import Data.Bifunctor (first)
import Data.Coerce (coerce) import Data.Coerce (coerce)
import Data.Function (on) import Data.Function (on)
import Data.List (foldl') import Data.List (foldl')
@ -27,7 +28,7 @@ import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr qualified as T import TypeChecker.TypeCheckerIr qualified as T
initCtx = Ctx mempty initCtx = Ctx mempty
initEnv = Env 0 'a' mempty mempty mempty initEnv = Env 0 'a' mempty mempty mempty "" mempty mempty
run :: Infer a -> Either Error a run :: Infer a -> Either Error a
run = run' initEnv initCtx run = run' initEnv initCtx
@ -51,8 +52,20 @@ typecheck = onLeft msg . run . checkPrg
checkPrg :: Program -> Infer (T.Program' Type) checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = do checkPrg (Program bs) = do
preRun bs preRun bs
bs' <- checkDef bs bs <- checkDef bs
return $ T.Program bs' sub <- solveUndecidable
dec <- gets toDecide
trace (printTree bs) pure ()
bs <- mapM (mono sub) bs
return $ T.Program bs
mono :: Subst -> T.Def' Type -> Infer (T.Def' Type)
mono s bind@(T.DBind (T.Bind (name, t) args e)) = do
b <- gets (S.member name . toDecide)
if b
then return $ T.DBind $ T.Bind (name, apply s t) (apply s args) (apply s e)
else return bind
mono _ (T.DData d) = return $ T.DData d
preRun :: [Def] -> Infer () preRun :: [Def] -> Infer ()
preRun [] = return () preRun [] = return ()
@ -66,7 +79,7 @@ preRun (x : xs) = case x of
"Duplicate signatures for function" "Duplicate signatures for function"
quote $ printTree n quote $ printTree n
) )
insertSig (coerce n) (Just $ skolemize t) >> preRun xs insertSig (coerce n) (Just t) >> preRun xs
DBind (Bind n _ e) -> do DBind (Bind n _ e) -> do
collect (collectTVars e) collect (collectTVars e)
s <- gets sigs s <- gets sigs
@ -91,25 +104,15 @@ checkDef (x : xs) = case x of
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
checkBind :: Bind -> Infer (T.Bind' Type) checkBind :: Bind -> Infer (T.Bind' Type)
checkBind (Bind name args e) = do checkBind bind@(Bind name args e) = do
setCurrentBind $ coerce name
let lambda = makeLambda e (reverse (coerce args)) let lambda = makeLambda e (reverse (coerce args))
(sub0, (e, lambda_t)) <- inferExp lambda (e, lambda_t) <- inferExp lambda
s <- gets sigs s <- gets sigs
case M.lookup (coerce name) s of case M.lookup (coerce name) s of
Just (Just t') -> do Just (Just t') -> do
-- \| TODO: Fix, this is not correct sub1 <- bindErr (unify lambda_t (skolemize t')) bind
let fsig = apply sub0 t' return $ T.Bind (coerce name, apply sub1 t') [] (e, lambda_t)
sub1 <- liftEither $ runIdentity $ runExceptT $ execStateT (typeEq fsig lambda_t) mempty
sub2 <- liftEither $ runIdentity $ runExceptT $ execStateT (typeEq lambda_t fsig) mempty
unless
(lambda_t == apply sub1 fsig && apply sub2 lambda_t == fsig)
( uncatchableErr $ Aux.do
"Inferred type"
quote $ printTree lambda_t
"does not match specified type"
quote $ printTree t'
)
return $ T.Bind (coerce name, lambda_t) [] (e, lambda_t)
_ -> do _ -> do
insertSig (coerce name) (Just lambda_t) insertSig (coerce name) (Just lambda_t)
return (T.Bind (coerce name, lambda_t) [] (e, lambda_t)) return (T.Bind (coerce name, lambda_t) [] (e, lambda_t))
@ -123,7 +126,7 @@ checkData err@(Data typ injs) = do
TData name typs TData name typs
| Right tvars' <- mapM toTVar typs -> | Right tvars' <- mapM toTVar typs ->
pure (name, tvars') pure (name, tvars')
TAll _ _ -> uncatchableErr "Explicit foralls not allowed, for now" TAll _ _ -> uncatchableErr "Explicit forall not allowed, for now"
_ -> _ ->
uncatchableErr $ uncatchableErr $
unwords ["Bad data type definition: ", printTree typ] unwords ["Bad data type definition: ", printTree typ]
@ -158,7 +161,7 @@ checkInj (Inj c inj_typ) name tvars
where where
boundTVars :: [TVar] -> Type -> Either Error Bool boundTVars :: [TVar] -> Type -> Either Error Bool
boundTVars tvars' = \case boundTVars tvars' = \case
TAll{} -> uncatchableErr "Explicit foralls not allowed, for now" TAll{} -> uncatchableErr "Explicit forall not allowed, for now"
TFun t1 t2 -> do TFun t1 t2 -> do
t1' <- boundTVars tvars t1 t1' <- boundTVars tvars t1
t2' <- boundTVars tvars t2 t2' <- boundTVars tvars t2
@ -177,11 +180,12 @@ returnType :: Type -> Type
returnType (TFun _ t2) = returnType t2 returnType (TFun _ t2) = returnType t2
returnType a = a returnType a = a
inferExp :: Exp -> Infer (Subst, T.ExpT' Type) inferExp :: Exp -> Infer (T.ExpT' Type)
inferExp e = do inferExp e = do
(s, (e', t)) <- algoW e (s, (e', t)) <- algoW e
let subbed = apply s t let subbed = apply s t
return (s, (e', subbed)) modify (\st -> st{undecidedSigs = apply s st.undecidedSigs})
return (e', subbed)
class CollectTVars a where class CollectTVars a where
collectTVars :: a -> Set T.Ident collectTVars :: a -> Set T.Ident
@ -225,7 +229,7 @@ algoW = \case
-- \| x : σ ∈ Γ τ = inst(σ) -- \| x : σ ∈ Γ τ = inst(σ)
-- \| ---------------------- -- \| ----------------------
-- \| Γ ⊢ x : τ, ∅ -- \| Γ ⊢ x : τ, ∅
EVar i -> do EVar (LIdent i) -> do
var <- asks vars var <- asks vars
case M.lookup (coerce i) var of case M.lookup (coerce i) var of
Just t -> Just t ->
@ -237,7 +241,8 @@ algoW = \case
Just (Just t) -> return (nullSubst, (T.EVar $ coerce i, t)) Just (Just t) -> return (nullSubst, (T.EVar $ coerce i, t))
Just Nothing -> do Just Nothing -> do
fr <- fresh fr <- fresh
insertSig (coerce i) (Just fr) cb <- gets currentBind
modify (\st -> st{toDecide = S.insert cb st.toDecide, undecidedSigs = M.insert (coerce $ concat [[prefix], i, [delim], coerce cb]) fr st.undecidedSigs})
return (nullSubst, (T.EVar $ coerce i, fr)) return (nullSubst, (T.EVar $ coerce i, fr))
Nothing -> Nothing ->
uncatchableErr $ uncatchableErr $
@ -591,6 +596,9 @@ instance SubstType (Map T.Ident Type) where
apply :: Subst -> Map T.Ident Type -> Map T.Ident Type apply :: Subst -> Map T.Ident Type -> Map T.Ident Type
apply = M.map . apply apply = M.map . apply
instance SubstType (T.ExpT' Type) where
apply s (e, t) = (apply s e, apply s t)
instance SubstType (T.Exp' Type) where instance SubstType (T.Exp' Type) where
apply s = \case apply s = \case
T.EVar i -> T.EVar i T.EVar i -> T.EVar i
@ -605,6 +613,11 @@ instance SubstType (T.Exp' Type) where
T.ECase e brnch -> T.ECase (apply s e) (apply s brnch) T.ECase e brnch -> T.ECase (apply s e) (apply s brnch)
T.EInj i -> T.EInj i T.EInj i -> T.EInj i
instance SubstType (T.Def' Type) where
apply s = \case
T.DBind (T.Bind name args e) -> T.DBind $ T.Bind (apply s name) (apply s args) (apply s e)
d -> d
instance SubstType (T.Branch' Type) where instance SubstType (T.Branch' Type) where
apply s (T.Branch (i, t) e) = T.Branch (apply s i, apply s t) (apply s e) apply s (T.Branch (i, t) e) = T.Branch (apply s i, apply s t) (apply s e)
@ -616,18 +629,18 @@ instance SubstType (T.Pattern' Type) where
T.PCatch -> T.PCatch T.PCatch -> T.PCatch
T.PEnum i -> T.PEnum i T.PEnum i -> T.PEnum i
instance SubstType (T.Pattern' Type, Type) where
apply s (p, t) = (apply s p, apply s t)
instance SubstType a => SubstType [a] where instance SubstType a => SubstType [a] where
apply s = map (apply s) apply s = map (apply s)
instance (SubstType a, SubstType b) => SubstType (a, b) where
apply s (a, b) = (apply s a, apply s b)
instance SubstType (T.Id' Type) where instance SubstType (T.Id' Type) where
apply s (name, t) = (name, apply s t) apply s (name, t) = (name, apply s t)
-- | Represents the empty substition set -- | Represents the empty substition set
nullSubst :: Subst nullSubst :: Subst
nullSubst = M.empty nullSubst = mempty
-- | Compose two substitution sets -- | Compose two substitution sets
compose :: Subst -> Subst -> Subst compose :: Subst -> Subst -> Subst
@ -676,6 +689,31 @@ with an equivalent name has been declared already
existInj :: T.Ident -> Infer (Maybe Type) existInj :: T.Ident -> Infer (Maybe Type)
existInj n = gets (M.lookup n . injections) existInj n = gets (M.lookup n . injections)
setCurrentBind :: T.Ident -> Infer ()
setCurrentBind i = modify (\st -> st{currentBind = i})
solveUndecidable :: Infer Subst
solveUndecidable = do
sigs <- gets sigs
undecided <- gets undecidedSigs
let xs = M.toList undecided
ys <-
maybeToRightM
(Error "SIGNATURE MISSING" False)
(mapM (tupSequence . first (join . flip M.lookup sigs . getOriginal)) xs)
composeAll <$> mapM (uncurry unify) ys
tupSequence :: Monad m => (m a, b) -> m (a, b)
tupSequence (ma, b) = (,b) <$> ma
getOriginal :: T.Ident -> T.Ident
getOriginal (T.Ident i) = coerce $ takeWhile (/= delim) $ drop 1 i
delim :: Char
delim = '_'
prefix :: Char
prefix = '$'
flattenType :: Type -> [Type] flattenType :: Type -> [Type]
flattenType (TFun a b) = flattenType a <> flattenType b flattenType (TFun a b) = flattenType a <> flattenType b
flattenType a = [a] flattenType a = [a]
@ -740,19 +778,30 @@ exprErr :: (Monad m, MonadError Error m) => m a -> Exp -> m a
exprErr ma exp = exprErr ma exp =
catchError catchError
ma ma
( \x -> ( \err -> if err.catchable
if x.catchable then throwError
then ( err { msg = err.msg
throwError
( x
{ msg =
x.msg
<> " in expression: \n" <> " in expression: \n"
<> printTree exp <> printTree exp
, catchable = False , catchable = False
} }
) )
else throwError x else throwError err
)
bindErr :: (Monad m, MonadError Error m) => m a -> Bind -> m a
bindErr ma bind =
catchError
ma
( \err -> if err.catchable
then throwError
( err { msg = err.msg
<> " in function: \n"
<> printTree bind
, catchable = False
}
)
else throwError err
) )
{- | Catch an error if possible and add the given {- | Catch an error if possible and add the given
@ -762,18 +811,18 @@ dataErr :: Infer a -> Data -> Infer a
dataErr ma d = dataErr ma d =
catchError catchError
ma ma
( \x -> ( \err ->
if x.catchable if err.catchable
then then
throwError throwError
( x ( err
{ msg = { msg =
x.msg err.msg
<> " in data: \n" <> " in data: \n"
<> printTree d <> printTree d
} }
) )
else throwError (x{catchable = False}) else throwError (err{catchable = False})
) )
unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d]) unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d])
@ -793,6 +842,9 @@ data Env = Env
, sigs :: Map T.Ident (Maybe Type) , sigs :: Map T.Ident (Maybe Type)
, injections :: Map T.Ident Type , injections :: Map T.Ident Type
, takenTypeVars :: Set T.Ident , takenTypeVars :: Set T.Ident
, currentBind :: T.Ident
, undecidedSigs :: Map T.Ident Type
, toDecide :: Set T.Ident
} }
deriving (Show) deriving (Show)
@ -811,3 +863,6 @@ uncatchableErr msg = throwError $ Error msg False
quote :: String -> String quote :: String -> String
quote s = "'" ++ s ++ "'" quote s = "'" ++ s ++ "'"
ctrace :: (Monad m, Show a) => String -> a -> m ()
ctrace str a = trace (str ++ ": " ++ show a) pure ()

View file

@ -1,9 +0,0 @@
module DoStrings where
import Prelude hiding ((>>), (>>=))
(>>) :: String -> String -> String
(>>) str1 str2 = str1 ++ "\n" ++ str2
(>>=) :: String -> (String -> String) -> String
(>>=) str f = f str

View file

@ -1,231 +0,0 @@
{-# LANGUAGE QualifiedDo #-}
{-# LANGUAGE NoImplicitPrelude #-}
module Main where
import Control.Monad ((<=<))
import DoStrings qualified as D
import Grammar.Par (myLexer, pProgram)
import Test.Hspec
import Prelude (Bool (..), Either (..), IO, mapM_, not, ($), (.))
-- import Test.QuickCheck
import TypeChecker.TypeChecker (typecheck)
main :: IO ()
main = do
mapM_ hspec goods
mapM_ hspec bads
mapM_ hspec bes
goods =
[ testSatisfy
"Basic polymorphism with multiple type variables"
( D.do
_const
"main = const 'a' 65 ;"
)
ok
, testSatisfy
"Head with a correct signature is accepted"
( D.do
_List
_headSig
_head
)
ok
, testSatisfy
"Most simple inference possible"
( D.do
_id
)
ok
, testSatisfy
"Pattern matching on a nested list"
( D.do
_List
"main : List (List (a)) -> Int ;"
"main xs = case xs of {"
" Cons Nil _ => 1 ;"
" _ => 0 ;"
"};"
)
ok
]
bads =
[ testSatisfy
"Infinite type unification should not succeed"
( D.do
"main = \\x. x x ;"
)
bad
, testSatisfy
"Pattern matching using different types should not succeed"
( D.do
_List
"bad xs = case xs of {"
" 1 => 0 ;"
" Nil => 0 ;"
"};"
)
bad
, testSatisfy
"Using a concrete function (data type) on a skolem variable should not succeed"
( D.do
_Bool
_not
"f : a -> Bool () ;"
"f x = not x ;"
)
bad
, testSatisfy
"Using a concrete function (primitive type) on a skolem variable should not succeed"
( D.do
"plusOne : Int -> Int ;"
"plusOne x = x + 1 ;"
"f : a -> Int ;"
"f x = plusOne x ;"
)
bad
, testSatisfy
"A function without signature used in an incompatible context should not succeed"
( D.do
"main = _id 1 2 ;"
"_id x = x ;"
)
bad
, testSatisfy
"Pattern matching on literal and _List should not succeed"
( D.do
_List
"length : List (c) -> Int;"
"length _List = case _List of {"
" 0 => 0;"
" Cons x xs => 1 + length xs;"
"};"
)
bad
, testSatisfy
"List of function Int -> Int functions should not be usable on Char"
( D.do
_List
"main : List (Int -> Int) -> Int ;"
"main xs = case xs of {"
" Cons f _ => f 'a' ;"
" Nil => 0 ;"
" };"
)
bad
, testSatisfy
"id with incorrect signature"
( D.do
"id : a -> b;"
"id x = x;"
)
bad
, testSatisfy
"incorrect type signature on id lambda"
( D.do
"id = ((\\x. x) : a -> b);"
)
bad
]
bes =
[ testBe
"A basic arithmetic function should be able to be inferred"
( D.do
"plusOne x = x + 1 ;"
"main x = plusOne x ;"
)
( D.do
"plusOne : Int -> Int ;"
"plusOne x = x + 1 ;"
"main : Int -> Int ;"
"main x = plusOne x ;"
)
, testBe
"A basic arithmetic function should be able to be inferred"
( D.do
"plusOne x = x + 1 ;"
)
( D.do
"plusOne : Int -> Int ;"
"plusOne x = x + 1 ;"
)
, testBe
"List of function Int -> Int functions should be inferred corretly"
( D.do
_List
"main xs = case xs of {"
" Cons f _ => f 1 ;"
" Nil => 0 ;"
" };"
)
( D.do
_List
"main : List (Int -> Int) -> Int ;"
"main xs = case xs of {"
" Cons f _ => f 1 ;"
" Nil => 0 ;"
" };"
)
]
testSatisfy desc test satisfaction = specify desc $ run test `shouldSatisfy` satisfaction
testBe desc test shouldbe = specify desc $ run test `shouldBe` run shouldbe
run = typecheck <=< pProgram . myLexer
ok (Right _) = True
ok (Left _) = False
bad = not . ok
-- FUNCTIONS
_const = D.do
"const : a -> b -> a ;"
"const x y = x ;"
_List = D.do
"data List (a) where"
" {"
" Nil : List (a)"
" Cons : a -> List (a) -> List (a)"
" };"
_headSig = D.do
"head : List (a) -> a ;"
_head = D.do
"head xs = "
" case xs of {"
" Cons x xs => x ;"
" };"
_Bool = D.do
"data Bool () where {"
" True : Bool ()"
" False : Bool ()"
"};"
_not = D.do
"not : Bool () -> Bool () ;"
"not x = case x of {"
" True => False ;"
" False => True ;"
"};"
_id = "id x = x ;"
_Maybe = D.do
"data Maybe (a) where {"
" Nothing : Maybe (a)"
" Just : a -> Maybe (a)"
" };"
_fmap = D.do
"fmap f ma = case ma of {"
" Nothing => Nothing ;"
" Just a => Just (f a) ;"
"};"