inference does not depend on order. mutual recursion still not working correctly
This commit is contained in:
parent
29fcddf44c
commit
4efe7cf9a2
3 changed files with 97 additions and 282 deletions
|
|
@ -12,6 +12,7 @@ import Control.Monad.Except
|
|||
import Control.Monad.Identity (Identity, runIdentity)
|
||||
import Control.Monad.Reader
|
||||
import Control.Monad.State
|
||||
import Data.Bifunctor (first)
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Function (on)
|
||||
import Data.List (foldl')
|
||||
|
|
@ -27,7 +28,7 @@ import Grammar.Print (printTree)
|
|||
import TypeChecker.TypeCheckerIr qualified as T
|
||||
|
||||
initCtx = Ctx mempty
|
||||
initEnv = Env 0 'a' mempty mempty mempty
|
||||
initEnv = Env 0 'a' mempty mempty mempty "" mempty mempty
|
||||
|
||||
run :: Infer a -> Either Error a
|
||||
run = run' initEnv initCtx
|
||||
|
|
@ -51,8 +52,20 @@ typecheck = onLeft msg . run . checkPrg
|
|||
checkPrg :: Program -> Infer (T.Program' Type)
|
||||
checkPrg (Program bs) = do
|
||||
preRun bs
|
||||
bs' <- checkDef bs
|
||||
return $ T.Program bs'
|
||||
bs <- checkDef bs
|
||||
sub <- solveUndecidable
|
||||
dec <- gets toDecide
|
||||
trace (printTree bs) pure ()
|
||||
bs <- mapM (mono sub) bs
|
||||
return $ T.Program bs
|
||||
|
||||
mono :: Subst -> T.Def' Type -> Infer (T.Def' Type)
|
||||
mono s bind@(T.DBind (T.Bind (name, t) args e)) = do
|
||||
b <- gets (S.member name . toDecide)
|
||||
if b
|
||||
then return $ T.DBind $ T.Bind (name, apply s t) (apply s args) (apply s e)
|
||||
else return bind
|
||||
mono _ (T.DData d) = return $ T.DData d
|
||||
|
||||
preRun :: [Def] -> Infer ()
|
||||
preRun [] = return ()
|
||||
|
|
@ -66,7 +79,7 @@ preRun (x : xs) = case x of
|
|||
"Duplicate signatures for function"
|
||||
quote $ printTree n
|
||||
)
|
||||
insertSig (coerce n) (Just $ skolemize t) >> preRun xs
|
||||
insertSig (coerce n) (Just t) >> preRun xs
|
||||
DBind (Bind n _ e) -> do
|
||||
collect (collectTVars e)
|
||||
s <- gets sigs
|
||||
|
|
@ -91,25 +104,15 @@ checkDef (x : xs) = case x of
|
|||
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
|
||||
|
||||
checkBind :: Bind -> Infer (T.Bind' Type)
|
||||
checkBind (Bind name args e) = do
|
||||
checkBind bind@(Bind name args e) = do
|
||||
setCurrentBind $ coerce name
|
||||
let lambda = makeLambda e (reverse (coerce args))
|
||||
(sub0, (e, lambda_t)) <- inferExp lambda
|
||||
(e, lambda_t) <- inferExp lambda
|
||||
s <- gets sigs
|
||||
case M.lookup (coerce name) s of
|
||||
Just (Just t') -> do
|
||||
-- \| TODO: Fix, this is not correct
|
||||
let fsig = apply sub0 t'
|
||||
sub1 <- liftEither $ runIdentity $ runExceptT $ execStateT (typeEq fsig lambda_t) mempty
|
||||
sub2 <- liftEither $ runIdentity $ runExceptT $ execStateT (typeEq lambda_t fsig) mempty
|
||||
unless
|
||||
(lambda_t == apply sub1 fsig && apply sub2 lambda_t == fsig)
|
||||
( uncatchableErr $ Aux.do
|
||||
"Inferred type"
|
||||
quote $ printTree lambda_t
|
||||
"does not match specified type"
|
||||
quote $ printTree t'
|
||||
)
|
||||
return $ T.Bind (coerce name, lambda_t) [] (e, lambda_t)
|
||||
sub1 <- bindErr (unify lambda_t (skolemize t')) bind
|
||||
return $ T.Bind (coerce name, apply sub1 t') [] (e, lambda_t)
|
||||
_ -> do
|
||||
insertSig (coerce name) (Just lambda_t)
|
||||
return (T.Bind (coerce name, lambda_t) [] (e, lambda_t))
|
||||
|
|
@ -123,7 +126,7 @@ checkData err@(Data typ injs) = do
|
|||
TData name typs
|
||||
| Right tvars' <- mapM toTVar typs ->
|
||||
pure (name, tvars')
|
||||
TAll _ _ -> uncatchableErr "Explicit foralls not allowed, for now"
|
||||
TAll _ _ -> uncatchableErr "Explicit forall not allowed, for now"
|
||||
_ ->
|
||||
uncatchableErr $
|
||||
unwords ["Bad data type definition: ", printTree typ]
|
||||
|
|
@ -158,7 +161,7 @@ checkInj (Inj c inj_typ) name tvars
|
|||
where
|
||||
boundTVars :: [TVar] -> Type -> Either Error Bool
|
||||
boundTVars tvars' = \case
|
||||
TAll{} -> uncatchableErr "Explicit foralls not allowed, for now"
|
||||
TAll{} -> uncatchableErr "Explicit forall not allowed, for now"
|
||||
TFun t1 t2 -> do
|
||||
t1' <- boundTVars tvars t1
|
||||
t2' <- boundTVars tvars t2
|
||||
|
|
@ -177,11 +180,12 @@ returnType :: Type -> Type
|
|||
returnType (TFun _ t2) = returnType t2
|
||||
returnType a = a
|
||||
|
||||
inferExp :: Exp -> Infer (Subst, T.ExpT' Type)
|
||||
inferExp :: Exp -> Infer (T.ExpT' Type)
|
||||
inferExp e = do
|
||||
(s, (e', t)) <- algoW e
|
||||
let subbed = apply s t
|
||||
return (s, (e', subbed))
|
||||
modify (\st -> st{undecidedSigs = apply s st.undecidedSigs})
|
||||
return (e', subbed)
|
||||
|
||||
class CollectTVars a where
|
||||
collectTVars :: a -> Set T.Ident
|
||||
|
|
@ -225,7 +229,7 @@ algoW = \case
|
|||
-- \| x : σ ∈ Γ τ = inst(σ)
|
||||
-- \| ----------------------
|
||||
-- \| Γ ⊢ x : τ, ∅
|
||||
EVar i -> do
|
||||
EVar (LIdent i) -> do
|
||||
var <- asks vars
|
||||
case M.lookup (coerce i) var of
|
||||
Just t ->
|
||||
|
|
@ -237,7 +241,8 @@ algoW = \case
|
|||
Just (Just t) -> return (nullSubst, (T.EVar $ coerce i, t))
|
||||
Just Nothing -> do
|
||||
fr <- fresh
|
||||
insertSig (coerce i) (Just fr)
|
||||
cb <- gets currentBind
|
||||
modify (\st -> st{toDecide = S.insert cb st.toDecide, undecidedSigs = M.insert (coerce $ concat [[prefix], i, [delim], coerce cb]) fr st.undecidedSigs})
|
||||
return (nullSubst, (T.EVar $ coerce i, fr))
|
||||
Nothing ->
|
||||
uncatchableErr $
|
||||
|
|
@ -591,6 +596,9 @@ instance SubstType (Map T.Ident Type) where
|
|||
apply :: Subst -> Map T.Ident Type -> Map T.Ident Type
|
||||
apply = M.map . apply
|
||||
|
||||
instance SubstType (T.ExpT' Type) where
|
||||
apply s (e, t) = (apply s e, apply s t)
|
||||
|
||||
instance SubstType (T.Exp' Type) where
|
||||
apply s = \case
|
||||
T.EVar i -> T.EVar i
|
||||
|
|
@ -605,6 +613,11 @@ instance SubstType (T.Exp' Type) where
|
|||
T.ECase e brnch -> T.ECase (apply s e) (apply s brnch)
|
||||
T.EInj i -> T.EInj i
|
||||
|
||||
instance SubstType (T.Def' Type) where
|
||||
apply s = \case
|
||||
T.DBind (T.Bind name args e) -> T.DBind $ T.Bind (apply s name) (apply s args) (apply s e)
|
||||
d -> d
|
||||
|
||||
instance SubstType (T.Branch' Type) where
|
||||
apply s (T.Branch (i, t) e) = T.Branch (apply s i, apply s t) (apply s e)
|
||||
|
||||
|
|
@ -616,18 +629,18 @@ instance SubstType (T.Pattern' Type) where
|
|||
T.PCatch -> T.PCatch
|
||||
T.PEnum i -> T.PEnum i
|
||||
|
||||
instance SubstType (T.Pattern' Type, Type) where
|
||||
apply s (p, t) = (apply s p, apply s t)
|
||||
|
||||
instance SubstType a => SubstType [a] where
|
||||
apply s = map (apply s)
|
||||
|
||||
instance (SubstType a, SubstType b) => SubstType (a, b) where
|
||||
apply s (a, b) = (apply s a, apply s b)
|
||||
|
||||
instance SubstType (T.Id' Type) where
|
||||
apply s (name, t) = (name, apply s t)
|
||||
|
||||
-- | Represents the empty substition set
|
||||
nullSubst :: Subst
|
||||
nullSubst = M.empty
|
||||
nullSubst = mempty
|
||||
|
||||
-- | Compose two substitution sets
|
||||
compose :: Subst -> Subst -> Subst
|
||||
|
|
@ -676,6 +689,31 @@ with an equivalent name has been declared already
|
|||
existInj :: T.Ident -> Infer (Maybe Type)
|
||||
existInj n = gets (M.lookup n . injections)
|
||||
|
||||
setCurrentBind :: T.Ident -> Infer ()
|
||||
setCurrentBind i = modify (\st -> st{currentBind = i})
|
||||
|
||||
solveUndecidable :: Infer Subst
|
||||
solveUndecidable = do
|
||||
sigs <- gets sigs
|
||||
undecided <- gets undecidedSigs
|
||||
let xs = M.toList undecided
|
||||
ys <-
|
||||
maybeToRightM
|
||||
(Error "SIGNATURE MISSING" False)
|
||||
(mapM (tupSequence . first (join . flip M.lookup sigs . getOriginal)) xs)
|
||||
composeAll <$> mapM (uncurry unify) ys
|
||||
|
||||
tupSequence :: Monad m => (m a, b) -> m (a, b)
|
||||
tupSequence (ma, b) = (,b) <$> ma
|
||||
|
||||
getOriginal :: T.Ident -> T.Ident
|
||||
getOriginal (T.Ident i) = coerce $ takeWhile (/= delim) $ drop 1 i
|
||||
|
||||
delim :: Char
|
||||
delim = '_'
|
||||
prefix :: Char
|
||||
prefix = '$'
|
||||
|
||||
flattenType :: Type -> [Type]
|
||||
flattenType (TFun a b) = flattenType a <> flattenType b
|
||||
flattenType a = [a]
|
||||
|
|
@ -740,19 +778,30 @@ exprErr :: (Monad m, MonadError Error m) => m a -> Exp -> m a
|
|||
exprErr ma exp =
|
||||
catchError
|
||||
ma
|
||||
( \x ->
|
||||
if x.catchable
|
||||
then
|
||||
throwError
|
||||
( x
|
||||
{ msg =
|
||||
x.msg
|
||||
( \err -> if err.catchable
|
||||
then throwError
|
||||
( err { msg = err.msg
|
||||
<> " in expression: \n"
|
||||
<> printTree exp
|
||||
, catchable = False
|
||||
}
|
||||
)
|
||||
else throwError x
|
||||
else throwError err
|
||||
)
|
||||
|
||||
bindErr :: (Monad m, MonadError Error m) => m a -> Bind -> m a
|
||||
bindErr ma bind =
|
||||
catchError
|
||||
ma
|
||||
( \err -> if err.catchable
|
||||
then throwError
|
||||
( err { msg = err.msg
|
||||
<> " in function: \n"
|
||||
<> printTree bind
|
||||
, catchable = False
|
||||
}
|
||||
)
|
||||
else throwError err
|
||||
)
|
||||
|
||||
{- | Catch an error if possible and add the given
|
||||
|
|
@ -762,18 +811,18 @@ dataErr :: Infer a -> Data -> Infer a
|
|||
dataErr ma d =
|
||||
catchError
|
||||
ma
|
||||
( \x ->
|
||||
if x.catchable
|
||||
( \err ->
|
||||
if err.catchable
|
||||
then
|
||||
throwError
|
||||
( x
|
||||
( err
|
||||
{ msg =
|
||||
x.msg
|
||||
err.msg
|
||||
<> " in data: \n"
|
||||
<> printTree d
|
||||
}
|
||||
)
|
||||
else throwError (x{catchable = False})
|
||||
else throwError (err{catchable = False})
|
||||
)
|
||||
|
||||
unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d])
|
||||
|
|
@ -793,6 +842,9 @@ data Env = Env
|
|||
, sigs :: Map T.Ident (Maybe Type)
|
||||
, injections :: Map T.Ident Type
|
||||
, takenTypeVars :: Set T.Ident
|
||||
, currentBind :: T.Ident
|
||||
, undecidedSigs :: Map T.Ident Type
|
||||
, toDecide :: Set T.Ident
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
|
|
@ -811,3 +863,6 @@ uncatchableErr msg = throwError $ Error msg False
|
|||
|
||||
quote :: String -> String
|
||||
quote s = "'" ++ s ++ "'"
|
||||
|
||||
ctrace :: (Monad m, Show a) => String -> a -> m ()
|
||||
ctrace str a = trace (str ++ ": " ++ show a) pure ()
|
||||
|
|
|
|||
|
|
@ -1,9 +0,0 @@
|
|||
module DoStrings where
|
||||
|
||||
import Prelude hiding ((>>), (>>=))
|
||||
|
||||
(>>) :: String -> String -> String
|
||||
(>>) str1 str2 = str1 ++ "\n" ++ str2
|
||||
|
||||
(>>=) :: String -> (String -> String) -> String
|
||||
(>>=) str f = f str
|
||||
|
|
@ -1,231 +0,0 @@
|
|||
{-# LANGUAGE QualifiedDo #-}
|
||||
{-# LANGUAGE NoImplicitPrelude #-}
|
||||
|
||||
module Main where
|
||||
|
||||
import Control.Monad ((<=<))
|
||||
import DoStrings qualified as D
|
||||
import Grammar.Par (myLexer, pProgram)
|
||||
import Test.Hspec
|
||||
import Prelude (Bool (..), Either (..), IO, mapM_, not, ($), (.))
|
||||
|
||||
-- import Test.QuickCheck
|
||||
import TypeChecker.TypeChecker (typecheck)
|
||||
|
||||
main :: IO ()
|
||||
main = do
|
||||
mapM_ hspec goods
|
||||
mapM_ hspec bads
|
||||
mapM_ hspec bes
|
||||
|
||||
goods =
|
||||
[ testSatisfy
|
||||
"Basic polymorphism with multiple type variables"
|
||||
( D.do
|
||||
_const
|
||||
"main = const 'a' 65 ;"
|
||||
)
|
||||
ok
|
||||
, testSatisfy
|
||||
"Head with a correct signature is accepted"
|
||||
( D.do
|
||||
_List
|
||||
_headSig
|
||||
_head
|
||||
)
|
||||
ok
|
||||
, testSatisfy
|
||||
"Most simple inference possible"
|
||||
( D.do
|
||||
_id
|
||||
)
|
||||
ok
|
||||
, testSatisfy
|
||||
"Pattern matching on a nested list"
|
||||
( D.do
|
||||
_List
|
||||
"main : List (List (a)) -> Int ;"
|
||||
"main xs = case xs of {"
|
||||
" Cons Nil _ => 1 ;"
|
||||
" _ => 0 ;"
|
||||
"};"
|
||||
)
|
||||
ok
|
||||
]
|
||||
|
||||
bads =
|
||||
[ testSatisfy
|
||||
"Infinite type unification should not succeed"
|
||||
( D.do
|
||||
"main = \\x. x x ;"
|
||||
)
|
||||
bad
|
||||
, testSatisfy
|
||||
"Pattern matching using different types should not succeed"
|
||||
( D.do
|
||||
_List
|
||||
"bad xs = case xs of {"
|
||||
" 1 => 0 ;"
|
||||
" Nil => 0 ;"
|
||||
"};"
|
||||
)
|
||||
bad
|
||||
, testSatisfy
|
||||
"Using a concrete function (data type) on a skolem variable should not succeed"
|
||||
( D.do
|
||||
_Bool
|
||||
_not
|
||||
"f : a -> Bool () ;"
|
||||
"f x = not x ;"
|
||||
)
|
||||
bad
|
||||
, testSatisfy
|
||||
"Using a concrete function (primitive type) on a skolem variable should not succeed"
|
||||
( D.do
|
||||
"plusOne : Int -> Int ;"
|
||||
"plusOne x = x + 1 ;"
|
||||
"f : a -> Int ;"
|
||||
"f x = plusOne x ;"
|
||||
)
|
||||
bad
|
||||
, testSatisfy
|
||||
"A function without signature used in an incompatible context should not succeed"
|
||||
( D.do
|
||||
"main = _id 1 2 ;"
|
||||
"_id x = x ;"
|
||||
)
|
||||
bad
|
||||
, testSatisfy
|
||||
"Pattern matching on literal and _List should not succeed"
|
||||
( D.do
|
||||
_List
|
||||
"length : List (c) -> Int;"
|
||||
"length _List = case _List of {"
|
||||
" 0 => 0;"
|
||||
" Cons x xs => 1 + length xs;"
|
||||
"};"
|
||||
)
|
||||
bad
|
||||
, testSatisfy
|
||||
"List of function Int -> Int functions should not be usable on Char"
|
||||
( D.do
|
||||
_List
|
||||
"main : List (Int -> Int) -> Int ;"
|
||||
"main xs = case xs of {"
|
||||
" Cons f _ => f 'a' ;"
|
||||
" Nil => 0 ;"
|
||||
" };"
|
||||
)
|
||||
bad
|
||||
, testSatisfy
|
||||
"id with incorrect signature"
|
||||
( D.do
|
||||
"id : a -> b;"
|
||||
"id x = x;"
|
||||
)
|
||||
bad
|
||||
, testSatisfy
|
||||
"incorrect type signature on id lambda"
|
||||
( D.do
|
||||
"id = ((\\x. x) : a -> b);"
|
||||
)
|
||||
bad
|
||||
]
|
||||
|
||||
bes =
|
||||
[ testBe
|
||||
"A basic arithmetic function should be able to be inferred"
|
||||
( D.do
|
||||
"plusOne x = x + 1 ;"
|
||||
"main x = plusOne x ;"
|
||||
)
|
||||
( D.do
|
||||
"plusOne : Int -> Int ;"
|
||||
"plusOne x = x + 1 ;"
|
||||
"main : Int -> Int ;"
|
||||
"main x = plusOne x ;"
|
||||
)
|
||||
, testBe
|
||||
"A basic arithmetic function should be able to be inferred"
|
||||
( D.do
|
||||
"plusOne x = x + 1 ;"
|
||||
)
|
||||
( D.do
|
||||
"plusOne : Int -> Int ;"
|
||||
"plusOne x = x + 1 ;"
|
||||
)
|
||||
, testBe
|
||||
"List of function Int -> Int functions should be inferred corretly"
|
||||
( D.do
|
||||
_List
|
||||
"main xs = case xs of {"
|
||||
" Cons f _ => f 1 ;"
|
||||
" Nil => 0 ;"
|
||||
" };"
|
||||
)
|
||||
( D.do
|
||||
_List
|
||||
"main : List (Int -> Int) -> Int ;"
|
||||
"main xs = case xs of {"
|
||||
" Cons f _ => f 1 ;"
|
||||
" Nil => 0 ;"
|
||||
" };"
|
||||
)
|
||||
]
|
||||
|
||||
testSatisfy desc test satisfaction = specify desc $ run test `shouldSatisfy` satisfaction
|
||||
testBe desc test shouldbe = specify desc $ run test `shouldBe` run shouldbe
|
||||
|
||||
run = typecheck <=< pProgram . myLexer
|
||||
|
||||
ok (Right _) = True
|
||||
ok (Left _) = False
|
||||
|
||||
bad = not . ok
|
||||
|
||||
-- FUNCTIONS
|
||||
|
||||
_const = D.do
|
||||
"const : a -> b -> a ;"
|
||||
"const x y = x ;"
|
||||
_List = D.do
|
||||
"data List (a) where"
|
||||
" {"
|
||||
" Nil : List (a)"
|
||||
" Cons : a -> List (a) -> List (a)"
|
||||
" };"
|
||||
|
||||
_headSig = D.do
|
||||
"head : List (a) -> a ;"
|
||||
|
||||
_head = D.do
|
||||
"head xs = "
|
||||
" case xs of {"
|
||||
" Cons x xs => x ;"
|
||||
" };"
|
||||
|
||||
_Bool = D.do
|
||||
"data Bool () where {"
|
||||
" True : Bool ()"
|
||||
" False : Bool ()"
|
||||
"};"
|
||||
|
||||
_not = D.do
|
||||
"not : Bool () -> Bool () ;"
|
||||
"not x = case x of {"
|
||||
" True => False ;"
|
||||
" False => True ;"
|
||||
"};"
|
||||
_id = "id x = x ;"
|
||||
|
||||
_Maybe = D.do
|
||||
"data Maybe (a) where {"
|
||||
" Nothing : Maybe (a)"
|
||||
" Just : a -> Maybe (a)"
|
||||
" };"
|
||||
|
||||
_fmap = D.do
|
||||
"fmap f ma = case ma of {"
|
||||
" Nothing => Nothing ;"
|
||||
" Just a => Just (f a) ;"
|
||||
"};"
|
||||
Loading…
Add table
Add a link
Reference in a new issue