fixed a substitution bug where ap was incorrectly inferred.

also added cleaner fresh variables
This commit is contained in:
sebastian 2023-03-25 22:40:15 +01:00
parent 975dd34063
commit ac43af8110
6 changed files with 287 additions and 194 deletions

View file

@ -1,2 +0,0 @@
ignore-project: False
tests: True

View file

@ -1,2 +0,0 @@
ignore-project: False
tests: False

View file

@ -19,6 +19,7 @@ import Data.Map qualified as M
import Data.Maybe (fromJust) import Data.Maybe (fromJust)
import Data.Set (Set) import Data.Set (Set)
import Data.Set qualified as S import Data.Set qualified as S
import Debug.Trace (trace)
import Grammar.Abs import Grammar.Abs
import Grammar.Print (printTree) import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr ( import TypeChecker.TypeCheckerIr (
@ -31,8 +32,7 @@ import TypeChecker.TypeCheckerIr (
import TypeChecker.TypeCheckerIr qualified as T import TypeChecker.TypeCheckerIr qualified as T
initCtx = Ctx mempty initCtx = Ctx mempty
initEnv = Env 0 'a' mempty mempty mempty
initEnv = Env 0 mempty mempty
runPretty :: Exp -> Either Error String runPretty :: Exp -> Either Error String
runPretty = fmap (printTree . fst) . run . inferExp runPretty = fmap (printTree . fst) . run . inferExp
@ -82,15 +82,14 @@ retType a = a
checkPrg :: Program -> Infer T.Program checkPrg :: Program -> Infer T.Program
checkPrg (Program bs) = do checkPrg (Program bs) = do
preRun bs preRun bs
-- Type check the program twice to produce all top-level types in the first pass through bs' <- checkDef bs
_ <- checkDef bs return $ T.Program bs'
bs'' <- checkDef bs
return $ T.Program bs''
where
preRun :: [Def] -> Infer () preRun :: [Def] -> Infer ()
preRun [] = return () preRun [] = return ()
preRun (x : xs) = case x of preRun (x : xs) = case x of
DSig (Sig n t) -> do DSig (Sig n t) -> do
collect (collectTypeVars t)
gets (M.member (coerce n) . sigs) gets (M.member (coerce n) . sigs)
>>= flip >>= flip
when when
@ -100,12 +99,13 @@ checkPrg (Program bs) = do
<> "'" <> "'"
) )
insertSig (coerce n) (Just $ toNew t) >> preRun xs insertSig (coerce n) (Just $ toNew t) >> preRun xs
DBind (Bind n _ _) -> do DBind (Bind n _ e) -> do
collect (collectTypeVars e)
s <- gets sigs s <- gets sigs
case M.lookup (coerce n) s of case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs Just _ -> preRun xs
DData d@(Data _ _) -> checkData d >> preRun xs DData d@(Data t _) -> collect (collectTypeVars t) >> checkData d >> preRun xs
checkDef :: [Def] -> Infer [T.Def] checkDef :: [Def] -> Infer [T.Def]
checkDef [] = return [] checkDef [] = return []
@ -171,6 +171,23 @@ inferExp e = do
let subbed = apply s t let subbed = apply s t
return $ second (const subbed) (e', t) return $ second (const subbed) (e', t)
class CollectTVars a where
collectTypeVars :: a -> Set T.Ident
instance CollectTVars Exp where
collectTypeVars (EAnn e t) = collectTypeVars t `S.union` collectTypeVars e
collectTypeVars _ = S.empty
instance CollectTVars Type where
collectTypeVars (TVar (MkTVar i)) = S.singleton (coerce i)
collectTypeVars (TAll _ t) = collectTypeVars t
collectTypeVars (TFun t1 t2) = collectTypeVars t1 `S.union` collectTypeVars t2
collectTypeVars (TData _ ts) = foldl' (\acc x -> acc `S.union` collectTypeVars x) S.empty ts
collectTypeVars _ = S.empty
collect :: Set T.Ident -> Infer ()
collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st})
class NewType a b where class NewType a b where
toNew :: a -> b toNew :: a -> b
@ -321,8 +338,9 @@ algoW = \case
(sub, (e', t)) <- algoW caseExpr (sub, (e', t)) <- algoW caseExpr
(subst, injs, ret_t) <- checkCase t injs (subst, injs, ret_t) <- checkCase t injs
let comp = subst `compose` sub let comp = subst `compose` sub
let t' = apply comp ret_t trace ("EXPR: " ++ show (apply comp t)) pure ()
return (comp, apply comp (T.ECase (e', t) injs, t')) trace ("CASES: " ++ show (apply comp ret_t)) pure ()
return (comp, apply comp (T.ECase (e', t) injs, ret_t))
makeLambda :: Exp -> [T.Ident] -> Exp makeLambda :: Exp -> [T.Ident] -> Exp
makeLambda = foldl (flip (EAbs . coerce)) makeLambda = foldl (flip (EAbs . coerce))
@ -335,7 +353,7 @@ unify t0 t1 = do
s1 <- unify a c s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d) s2 <- unify (apply s1 b) (apply s1 d)
return $ s1 `compose` s2 return $ s1 `compose` s2
----------- TODO: CAREFUL!!!! THIS IS PROBABLY WRONG!!! ----------- ----------- TODO: BE CAREFUL!!!! THIS IS PROBABLY WRONG!!! -----------
(T.TVar (T.MkTVar a), t@(T.TData _ _)) -> return $ M.singleton a t (T.TVar (T.MkTVar a), t@(T.TData _ _)) -> return $ M.singleton a t
(t@(T.TData _ _), T.TVar (T.MkTVar b)) -> return $ M.singleton b t (t@(T.TData _ _), T.TVar (T.MkTVar b)) -> return $ M.singleton b t
------------------------------------------------------------------- -------------------------------------------------------------------
@ -517,9 +535,24 @@ nullSubst = M.empty
-- | Generate a new fresh variable and increment the state counter -- | Generate a new fresh variable and increment the state counter
fresh :: Infer T.Type fresh :: Infer T.Type
fresh = do fresh = do
c <- gets nextChar
n <- gets count n <- gets count
modify (\st -> st{count = n + 1}) taken <- gets takenTypeVars
return . T.TVar . T.MkTVar . T.Ident $ show n if c == 'z'
then do
modify (\st -> st{count = succ (count st), nextChar = 'a'})
else modify (\st -> st{nextChar = next (nextChar st)})
if coerce [c] `S.member` taken
then do
fresh
else
if n == 0
then return . T.TVar . T.MkTVar . T.Ident $ [c]
else return . T.TVar . T.MkTVar . T.Ident $ [c] ++ show n
next :: Char -> Char
next 'z' = 'a'
next a = succ a
-- | Run the monadic action with an additional binding -- | Run the monadic action with an additional binding
withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> T.Type -> m a -> m a withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> T.Type -> m a -> m a
@ -567,7 +600,7 @@ inferBranch :: Branch -> Infer (Subst, T.Type, T.Branch, T.Type)
inferBranch (Branch pat expr) = do inferBranch (Branch pat expr) = do
newPat@(pat, branchT) <- inferPattern pat newPat@(pat, branchT) <- inferPattern pat
(sub, newExp@(_, exprT)) <- withPattern pat (algoW expr) (sub, newExp@(_, exprT)) <- withPattern pat (algoW expr)
return (sub, branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT) return (sub, apply sub branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT)
withPattern :: T.Pattern -> Infer a -> Infer a withPattern :: T.Pattern -> Infer a -> Infer a
withPattern p ma = case p of withPattern p ma = case p of
@ -586,15 +619,36 @@ inferPattern = \case
let numArgs = typeLength t - 1 let numArgs = typeLength t - 1
let (vs, ret) = fromJust (unsnoc $ flattenType t) let (vs, ret) = fromJust (unsnoc $ flattenType t)
patterns <- mapM inferPattern patterns patterns <- mapM inferPattern patterns
unless (length patterns == numArgs) (throwError $ "The constructor '" ++ printTree constr ++ "'" ++ " should have " ++ show numArgs ++ " arguments but has been given " ++ show (length patterns)) unless
(length patterns == numArgs)
( throwError $
"The constructor '"
++ printTree constr
++ "'"
++ " should have "
++ show numArgs
++ " arguments but has been given "
++ show (length patterns)
)
sub <- composeAll <$> zipWithM unify vs (map snd patterns) sub <- composeAll <$> zipWithM unify vs (map snd patterns)
return (T.PInj (coerce constr) (map fst patterns), apply sub ret) return (T.PInj (coerce constr) (map fst patterns), apply sub ret)
PCatch -> (T.PCatch,) <$> fresh PCatch -> (T.PCatch,) <$> fresh
PEnum p -> do PEnum p -> do
t <- gets (M.lookup (coerce p) . constructors) t <- gets (M.lookup (coerce p) . constructors)
t <- maybeToRightM ("Constructor: " <> printTree p <> " does not exist") t t <- maybeToRightM ("Constructor: " <> printTree p <> " does not exist") t
unless (typeLength t == 1) (throwError $ "The constructor '" ++ printTree p ++ "'" ++ " should have " ++ show (typeLength t - 1) ++ " arguments but has been given 0") unless
return (T.PEnum $ coerce p, t) (typeLength t == 1)
( throwError $
"The constructor '"
++ printTree p
++ "'"
++ " should have "
++ show (typeLength t - 1)
++ " arguments but has been given 0"
)
let (T.TData _data _ts) = t -- nasty nasty
frs <- mapM (const fresh) _ts
return (T.PEnum $ coerce p, T.TData _data frs)
PVar x -> do PVar x -> do
fr <- fresh fr <- fresh
let pvar = T.PVar (coerce x, fr) let pvar = T.PVar (coerce x, fr)
@ -632,4 +686,9 @@ exprErr ma exp =
catchError ma (\x -> throwError $ x <> " in expression: \n" <> printTree exp) catchError ma (\x -> throwError $ x <> " in expression: \n" <> printTree exp)
unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d]) unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d])
unzip4 = foldl' (\(as, bs, cs, ds) (a, b, c, d) -> (as ++ [a], bs ++ [b], cs ++ [c], ds ++ [d])) ([], [], [], []) unzip4 =
foldl'
( \(as, bs, cs, ds) (a, b, c, d) ->
(as ++ [a], bs ++ [b], cs ++ [c], ds ++ [d])
)
([], [], [], [])

View file

@ -10,6 +10,7 @@ import Control.Monad.State
import Data.Char (isDigit) import Data.Char (isDigit)
import Data.Functor.Identity (Identity) import Data.Functor.Identity (Identity)
import Data.Map (Map) import Data.Map (Map)
import Data.Set (Set)
import Data.String qualified import Data.String qualified
import Grammar.Print import Grammar.Print
import Prelude import Prelude
@ -20,8 +21,10 @@ newtype Ctx = Ctx {vars :: Map Ident Type}
data Env = Env data Env = Env
{ count :: Int { count :: Int
, nextChar :: Char
, sigs :: Map Ident (Maybe Type) , sigs :: Map Ident (Maybe Type)
, constructors :: Map Ident Type , constructors :: Map Ident Type
, takenTypeVars :: Set Ident
} }
deriving (Show) deriving (Show)

View file

@ -1,9 +1,28 @@
data Bool () where { data Maybe (a) where {
True : Bool () Nothing : Maybe (a)
False : Bool () Just : a -> Maybe (a)
}; };
main = case True of { fmap : (a -> b) -> Maybe (a) -> Maybe (b) ;
True => 1; fmap f ma = case ma of {
False => 0; Nothing => Nothing ;
Just a => Just (f a) ;
};
pure : a -> Maybe (a) ;
pure x = Just x ;
ap mf ma = case mf of {
Just f => case ma of {
Nothing => Nothing;
Just a => Just (f a);
};
Nothing => Nothing;
};
return = pure;
bind ma f = case ma of {
Nothing => Nothing ;
Just a => f a ;
}; };

View file

@ -16,75 +16,142 @@ main :: IO ()
main = do main = do
mapM_ hspec goods mapM_ hspec goods
mapM_ hspec bads mapM_ hspec bads
mapM_ hspec bes
goods = goods =
[ specify "Basic polymorphism with multiple type variables" $ [ testSatisfy
run "Basic polymorphism with multiple type variables"
( D.do ( D.do
_const _const
"main = const 'a' 65 ;" "main = const 'a' 65 ;"
) )
`shouldSatisfy` ok ok
, specify "Head with a correct signature is accepted" $ , testSatisfy
run "Head with a correct signature is accepted"
( D.do ( D.do
_list _List
_headSig _headSig
_head _head
) )
`shouldSatisfy` ok ok
, specify "A basic arithmetic function should be able to be inferred" $ , testSatisfy
run "Most simple inference possible"
( D.do
"plusOne x = x + 1 ;"
"main x = plusOne x ;"
)
`shouldBe` run
( D.do
"plusOne : Int -> Int ;"
"plusOne x = x + 1 ;"
"main : Int -> Int ;"
"main x = plusOne x ;"
)
, specify "A basic arithmetic function should be able to be inferred" $
run
( D.do
"plusOne x = x + 1 ;"
)
`shouldBe` run
( D.do
"plusOne : Int -> Int ;"
"plusOne x = x + 1 ;"
)
, specify "Most simple inference possible" $
run
( D.do ( D.do
_id _id
) )
`shouldSatisfy` ok ok
, specify "Pattern matching on a nested list" $ , testSatisfy
run "Pattern matching on a nested list"
( D.do ( D.do
_list _List
"main : List (List (a)) -> Int ;" "main : List (List (a)) -> Int ;"
"main xs = case xs of {" "main xs = case xs of {"
" Cons Nil _ => 1 ;" " Cons Nil _ => 1 ;"
" _ => 0 ;" " _ => 0 ;"
"};" "};"
) )
`shouldSatisfy` ok ok
, specify "List of function Int -> Int functions should be inferred corretly" $ ]
run
bads =
[ testSatisfy
"Infinite type unification should not succeed"
( D.do ( D.do
_list "main = \\x. x x ;"
)
bad
, testSatisfy
"Pattern matching using different types should not succeed"
( D.do
_List
"bad xs = case xs of {"
" 1 => 0 ;"
" Nil => 0 ;"
"};"
)
bad
, testSatisfy
"Using a concrete function (data type) on a skolem variable should not succeed"
( D.do
_Bool
_not
"f : a -> Bool () ;"
"f x = not x ;"
)
bad
, testSatisfy
"Using a concrete function (primitive type) on a skolem variable should not succeed"
( D.do
"plusOne : Int -> Int ;"
"plusOne x = x + 1 ;"
"f : a -> Int ;"
"f x = plusOne x ;"
)
bad
, testSatisfy
"A function without signature used in an incompatible context should not succeed"
( D.do
"main = _id 1 2 ;"
"_id x = x ;"
)
bad
, testSatisfy
"Pattern matching on literal and _List should not succeed"
( D.do
_List
"length : List (c) -> Int;"
"length _List = case _List of {"
" 0 => 0;"
" Cons x xs => 1 + length xs;"
"};"
)
bad
, testSatisfy
"List of function Int -> Int functions should not be usable on Char"
( D.do
_List
"main : List (Int -> Int) -> Int ;"
"main xs = case xs of {"
" Cons f _ => f 'a' ;"
" Nil => 0 ;"
" };"
)
bad
]
bes =
[ testBe
"A basic arithmetic function should be able to be inferred"
( D.do
"plusOne x = x + 1 ;"
"main x = plusOne x ;"
)
( D.do
"plusOne : Int -> Int ;"
"plusOne x = x + 1 ;"
"main : Int -> Int ;"
"main x = plusOne x ;"
)
, testBe
"A basic arithmetic function should be able to be inferred"
( D.do
"plusOne x = x + 1 ;"
)
( D.do
"plusOne : Int -> Int ;"
"plusOne x = x + 1 ;"
)
, testBe
"List of function Int -> Int functions should be inferred corretly"
( D.do
_List
"main xs = case xs of {" "main xs = case xs of {"
" Cons f _ => f 1 ;" " Cons f _ => f 1 ;"
" Nil => 0 ;" " Nil => 0 ;"
" };" " };"
) )
`shouldBe` run
( D.do ( D.do
_list _List
"main : List (Int -> Int) -> Int ;" "main : List (Int -> Int) -> Int ;"
"main xs = case xs of {" "main xs = case xs of {"
" Cons f _ => f 1 ;" " Cons f _ => f 1 ;"
@ -93,71 +160,8 @@ goods =
) )
] ]
bads = testSatisfy desc test satisfaction = specify desc $ run test `shouldSatisfy` satisfaction
[ specify "Infinite type unification should not succeed" $ testBe desc test shouldbe = specify desc $ run test `shouldBe` run shouldbe
run
( D.do
"main = \\x. x x ;"
)
`shouldSatisfy` bad
, specify "Pattern matching using different types should not succeed" $
run
( D.do
_list
"bad xs = case xs of {"
" 1 => 0 ;"
" Nil => 0 ;"
"};"
)
`shouldSatisfy` bad
, specify "Using a concrete function (data type) on a skolem variable should not succeed" $
run
( D.do
_bool
_not
"f : a -> Bool () ;"
"f x = not x ;"
)
`shouldSatisfy` bad
, specify "Using a concrete function (primitive type) on a skolem variable should not succeed" $
run
( D.do
"plusOne : Int -> Int ;"
"plusOne x = x + 1 ;"
"f : a -> Int ;"
" f x = plusOne x ;"
)
`shouldSatisfy` bad
, specify "A function without signature used in an incompatible context should not succeed" $
run
( D.do
"main = _id 1 2 ;"
"_id x = x ;"
)
`shouldSatisfy` bad
, specify "Pattern matching on literal and _list should not succeed" $
run
( D.do
_list
"length : List (c) -> Int;"
"length _list = case _list of {"
" 0 => 0;"
" Cons x xs => 1 + length xs;"
"};"
)
`shouldSatisfy` bad
, specify "List of function Int -> Int functions should not be usable on Char" $
run
( D.do
_list
"main : List (Int -> Int) -> Int ;"
"main xs = case xs of {"
" Cons f _ => f 'a' ;"
" Nil => 0 ;"
" };"
)
`shouldSatisfy` bad
]
run = typecheck <=< pProgram . myLexer run = typecheck <=< pProgram . myLexer
@ -171,7 +175,7 @@ bad = not . ok
_const = D.do _const = D.do
"const : a -> b -> a ;" "const : a -> b -> a ;"
"const x y = x ;" "const x y = x ;"
_list = D.do _List = D.do
"data List (a) where" "data List (a) where"
" {" " {"
" Nil : List (a)" " Nil : List (a)"
@ -187,7 +191,7 @@ _head = D.do
" Cons x xs => x ;" " Cons x xs => x ;"
" };" " };"
_bool = D.do _Bool = D.do
"data Bool () where {" "data Bool () where {"
" True : Bool ()" " True : Bool ()"
" False : Bool ()" " False : Bool ()"
@ -200,3 +204,15 @@ _not = D.do
" False => True ;" " False => True ;"
"};" "};"
_id = "id x = x ;" _id = "id x = x ;"
_Maybe = D.do
"data Maybe (a) where {"
" Nothing : Maybe (a)"
" Just : a -> Maybe (a)"
" };"
_fmap = D.do
"fmap f ma = case ma of {"
" Nothing => Nothing ;"
" Just a => Just (f a) ;"
"};"