fixed a substitution bug where ap was incorrectly inferred.

also added cleaner fresh variables
This commit is contained in:
sebastian 2023-03-25 22:40:15 +01:00
parent 975dd34063
commit ac43af8110
6 changed files with 287 additions and 194 deletions

View file

@ -1,2 +0,0 @@
ignore-project: False
tests: True

View file

@ -1,2 +0,0 @@
ignore-project: False
tests: False

View file

@ -19,6 +19,7 @@ import Data.Map qualified as M
import Data.Maybe (fromJust)
import Data.Set (Set)
import Data.Set qualified as S
import Debug.Trace (trace)
import Grammar.Abs
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr (
@ -31,8 +32,7 @@ import TypeChecker.TypeCheckerIr (
import TypeChecker.TypeCheckerIr qualified as T
initCtx = Ctx mempty
initEnv = Env 0 mempty mempty
initEnv = Env 0 'a' mempty mempty mempty
runPretty :: Exp -> Either Error String
runPretty = fmap (printTree . fst) . run . inferExp
@ -82,39 +82,39 @@ retType a = a
checkPrg :: Program -> Infer T.Program
checkPrg (Program bs) = do
preRun bs
-- Type check the program twice to produce all top-level types in the first pass through
_ <- checkDef bs
bs'' <- checkDef bs
return $ T.Program bs''
where
preRun :: [Def] -> Infer ()
preRun [] = return ()
preRun (x : xs) = case x of
DSig (Sig n t) -> do
gets (M.member (coerce n) . sigs)
>>= flip
when
( throwError $
"Duplicate signatures for function '"
<> printTree n
<> "'"
)
insertSig (coerce n) (Just $ toNew t) >> preRun xs
DBind (Bind n _ _) -> do
s <- gets sigs
case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs
DData d@(Data _ _) -> checkData d >> preRun xs
bs' <- checkDef bs
return $ T.Program bs'
checkDef :: [Def] -> Infer [T.Def]
checkDef [] = return []
checkDef (x : xs) = case x of
(DBind b) -> do
b' <- checkBind b
fmap (T.DBind b' :) (checkDef xs)
(DData d) -> fmap (T.DData (toNew d) :) (checkDef xs)
(DSig _) -> checkDef xs
preRun :: [Def] -> Infer ()
preRun [] = return ()
preRun (x : xs) = case x of
DSig (Sig n t) -> do
collect (collectTypeVars t)
gets (M.member (coerce n) . sigs)
>>= flip
when
( throwError $
"Duplicate signatures for function '"
<> printTree n
<> "'"
)
insertSig (coerce n) (Just $ toNew t) >> preRun xs
DBind (Bind n _ e) -> do
collect (collectTypeVars e)
s <- gets sigs
case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs
DData d@(Data t _) -> collect (collectTypeVars t) >> checkData d >> preRun xs
checkDef :: [Def] -> Infer [T.Def]
checkDef [] = return []
checkDef (x : xs) = case x of
(DBind b) -> do
b' <- checkBind b
fmap (T.DBind b' :) (checkDef xs)
(DData d) -> fmap (T.DData (toNew d) :) (checkDef xs)
(DSig _) -> checkDef xs
checkBind :: Bind -> Infer T.Bind
checkBind (Bind name args e) = do
@ -171,6 +171,23 @@ inferExp e = do
let subbed = apply s t
return $ second (const subbed) (e', t)
class CollectTVars a where
collectTypeVars :: a -> Set T.Ident
instance CollectTVars Exp where
collectTypeVars (EAnn e t) = collectTypeVars t `S.union` collectTypeVars e
collectTypeVars _ = S.empty
instance CollectTVars Type where
collectTypeVars (TVar (MkTVar i)) = S.singleton (coerce i)
collectTypeVars (TAll _ t) = collectTypeVars t
collectTypeVars (TFun t1 t2) = collectTypeVars t1 `S.union` collectTypeVars t2
collectTypeVars (TData _ ts) = foldl' (\acc x -> acc `S.union` collectTypeVars x) S.empty ts
collectTypeVars _ = S.empty
collect :: Set T.Ident -> Infer ()
collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st})
class NewType a b where
toNew :: a -> b
@ -321,8 +338,9 @@ algoW = \case
(sub, (e', t)) <- algoW caseExpr
(subst, injs, ret_t) <- checkCase t injs
let comp = subst `compose` sub
let t' = apply comp ret_t
return (comp, apply comp (T.ECase (e', t) injs, t'))
trace ("EXPR: " ++ show (apply comp t)) pure ()
trace ("CASES: " ++ show (apply comp ret_t)) pure ()
return (comp, apply comp (T.ECase (e', t) injs, ret_t))
makeLambda :: Exp -> [T.Ident] -> Exp
makeLambda = foldl (flip (EAbs . coerce))
@ -335,7 +353,7 @@ unify t0 t1 = do
s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d)
return $ s1 `compose` s2
----------- TODO: CAREFUL!!!! THIS IS PROBABLY WRONG!!! -----------
----------- TODO: BE CAREFUL!!!! THIS IS PROBABLY WRONG!!! -----------
(T.TVar (T.MkTVar a), t@(T.TData _ _)) -> return $ M.singleton a t
(t@(T.TData _ _), T.TVar (T.MkTVar b)) -> return $ M.singleton b t
-------------------------------------------------------------------
@ -517,9 +535,24 @@ nullSubst = M.empty
-- | Generate a new fresh variable and increment the state counter
fresh :: Infer T.Type
fresh = do
c <- gets nextChar
n <- gets count
modify (\st -> st{count = n + 1})
return . T.TVar . T.MkTVar . T.Ident $ show n
taken <- gets takenTypeVars
if c == 'z'
then do
modify (\st -> st{count = succ (count st), nextChar = 'a'})
else modify (\st -> st{nextChar = next (nextChar st)})
if coerce [c] `S.member` taken
then do
fresh
else
if n == 0
then return . T.TVar . T.MkTVar . T.Ident $ [c]
else return . T.TVar . T.MkTVar . T.Ident $ [c] ++ show n
next :: Char -> Char
next 'z' = 'a'
next a = succ a
-- | Run the monadic action with an additional binding
withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> T.Type -> m a -> m a
@ -567,7 +600,7 @@ inferBranch :: Branch -> Infer (Subst, T.Type, T.Branch, T.Type)
inferBranch (Branch pat expr) = do
newPat@(pat, branchT) <- inferPattern pat
(sub, newExp@(_, exprT)) <- withPattern pat (algoW expr)
return (sub, branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT)
return (sub, apply sub branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT)
withPattern :: T.Pattern -> Infer a -> Infer a
withPattern p ma = case p of
@ -586,15 +619,36 @@ inferPattern = \case
let numArgs = typeLength t - 1
let (vs, ret) = fromJust (unsnoc $ flattenType t)
patterns <- mapM inferPattern patterns
unless (length patterns == numArgs) (throwError $ "The constructor '" ++ printTree constr ++ "'" ++ " should have " ++ show numArgs ++ " arguments but has been given " ++ show (length patterns))
unless
(length patterns == numArgs)
( throwError $
"The constructor '"
++ printTree constr
++ "'"
++ " should have "
++ show numArgs
++ " arguments but has been given "
++ show (length patterns)
)
sub <- composeAll <$> zipWithM unify vs (map snd patterns)
return (T.PInj (coerce constr) (map fst patterns), apply sub ret)
PCatch -> (T.PCatch,) <$> fresh
PEnum p -> do
t <- gets (M.lookup (coerce p) . constructors)
t <- maybeToRightM ("Constructor: " <> printTree p <> " does not exist") t
unless (typeLength t == 1) (throwError $ "The constructor '" ++ printTree p ++ "'" ++ " should have " ++ show (typeLength t - 1) ++ " arguments but has been given 0")
return (T.PEnum $ coerce p, t)
unless
(typeLength t == 1)
( throwError $
"The constructor '"
++ printTree p
++ "'"
++ " should have "
++ show (typeLength t - 1)
++ " arguments but has been given 0"
)
let (T.TData _data _ts) = t -- nasty nasty
frs <- mapM (const fresh) _ts
return (T.PEnum $ coerce p, T.TData _data frs)
PVar x -> do
fr <- fresh
let pvar = T.PVar (coerce x, fr)
@ -632,4 +686,9 @@ exprErr ma exp =
catchError ma (\x -> throwError $ x <> " in expression: \n" <> printTree exp)
unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d])
unzip4 = foldl' (\(as, bs, cs, ds) (a, b, c, d) -> (as ++ [a], bs ++ [b], cs ++ [c], ds ++ [d])) ([], [], [], [])
unzip4 =
foldl'
( \(as, bs, cs, ds) (a, b, c, d) ->
(as ++ [a], bs ++ [b], cs ++ [c], ds ++ [d])
)
([], [], [], [])

View file

@ -10,6 +10,7 @@ import Control.Monad.State
import Data.Char (isDigit)
import Data.Functor.Identity (Identity)
import Data.Map (Map)
import Data.Set (Set)
import Data.String qualified
import Grammar.Print
import Prelude
@ -20,8 +21,10 @@ newtype Ctx = Ctx {vars :: Map Ident Type}
data Env = Env
{ count :: Int
, nextChar :: Char
, sigs :: Map Ident (Maybe Type)
, constructors :: Map Ident Type
, takenTypeVars :: Set Ident
}
deriving (Show)

View file

@ -1,9 +1,28 @@
data Bool () where {
True : Bool ()
False : Bool ()
};
data Maybe (a) where {
Nothing : Maybe (a)
Just : a -> Maybe (a)
};
main = case True of {
True => 1;
False => 0;
};
fmap : (a -> b) -> Maybe (a) -> Maybe (b) ;
fmap f ma = case ma of {
Nothing => Nothing ;
Just a => Just (f a) ;
};
pure : a -> Maybe (a) ;
pure x = Just x ;
ap mf ma = case mf of {
Just f => case ma of {
Nothing => Nothing;
Just a => Just (f a);
};
Nothing => Nothing;
};
return = pure;
bind ma f = case ma of {
Nothing => Nothing ;
Just a => f a ;
};

View file

@ -16,149 +16,153 @@ main :: IO ()
main = do
mapM_ hspec goods
mapM_ hspec bads
mapM_ hspec bes
goods =
[ specify "Basic polymorphism with multiple type variables" $
run
( D.do
_const
"main = const 'a' 65 ;"
)
`shouldSatisfy` ok
, specify "Head with a correct signature is accepted" $
run
( D.do
_list
_headSig
_head
)
`shouldSatisfy` ok
, specify "A basic arithmetic function should be able to be inferred" $
run
( D.do
"plusOne x = x + 1 ;"
"main x = plusOne x ;"
)
`shouldBe` run
( D.do
"plusOne : Int -> Int ;"
"plusOne x = x + 1 ;"
"main : Int -> Int ;"
"main x = plusOne x ;"
)
, specify "A basic arithmetic function should be able to be inferred" $
run
( D.do
"plusOne x = x + 1 ;"
)
`shouldBe` run
( D.do
"plusOne : Int -> Int ;"
"plusOne x = x + 1 ;"
)
, specify "Most simple inference possible" $
run
( D.do
_id
)
`shouldSatisfy` ok
, specify "Pattern matching on a nested list" $
run
( D.do
_list
"main : List (List (a)) -> Int ;"
"main xs = case xs of {"
" Cons Nil _ => 1 ;"
" _ => 0 ;"
"};"
)
`shouldSatisfy` ok
, specify "List of function Int -> Int functions should be inferred corretly" $
run
( D.do
_list
"main xs = case xs of {"
" Cons f _ => f 1 ;"
" Nil => 0 ;"
" };"
)
`shouldBe` run
( D.do
_list
"main : List (Int -> Int) -> Int ;"
"main xs = case xs of {"
" Cons f _ => f 1 ;"
" Nil => 0 ;"
" };"
)
[ testSatisfy
"Basic polymorphism with multiple type variables"
( D.do
_const
"main = const 'a' 65 ;"
)
ok
, testSatisfy
"Head with a correct signature is accepted"
( D.do
_List
_headSig
_head
)
ok
, testSatisfy
"Most simple inference possible"
( D.do
_id
)
ok
, testSatisfy
"Pattern matching on a nested list"
( D.do
_List
"main : List (List (a)) -> Int ;"
"main xs = case xs of {"
" Cons Nil _ => 1 ;"
" _ => 0 ;"
"};"
)
ok
]
bads =
[ specify "Infinite type unification should not succeed" $
run
( D.do
"main = \\x. x x ;"
)
`shouldSatisfy` bad
, specify "Pattern matching using different types should not succeed" $
run
( D.do
_list
"bad xs = case xs of {"
" 1 => 0 ;"
" Nil => 0 ;"
"};"
)
`shouldSatisfy` bad
, specify "Using a concrete function (data type) on a skolem variable should not succeed" $
run
( D.do
_bool
_not
"f : a -> Bool () ;"
"f x = not x ;"
)
`shouldSatisfy` bad
, specify "Using a concrete function (primitive type) on a skolem variable should not succeed" $
run
( D.do
"plusOne : Int -> Int ;"
"plusOne x = x + 1 ;"
"f : a -> Int ;"
" f x = plusOne x ;"
)
`shouldSatisfy` bad
, specify "A function without signature used in an incompatible context should not succeed" $
run
( D.do
"main = _id 1 2 ;"
"_id x = x ;"
)
`shouldSatisfy` bad
, specify "Pattern matching on literal and _list should not succeed" $
run
( D.do
_list
"length : List (c) -> Int;"
"length _list = case _list of {"
" 0 => 0;"
" Cons x xs => 1 + length xs;"
"};"
)
`shouldSatisfy` bad
, specify "List of function Int -> Int functions should not be usable on Char" $
run
( D.do
_list
"main : List (Int -> Int) -> Int ;"
"main xs = case xs of {"
" Cons f _ => f 'a' ;"
" Nil => 0 ;"
" };"
)
`shouldSatisfy` bad
[ testSatisfy
"Infinite type unification should not succeed"
( D.do
"main = \\x. x x ;"
)
bad
, testSatisfy
"Pattern matching using different types should not succeed"
( D.do
_List
"bad xs = case xs of {"
" 1 => 0 ;"
" Nil => 0 ;"
"};"
)
bad
, testSatisfy
"Using a concrete function (data type) on a skolem variable should not succeed"
( D.do
_Bool
_not
"f : a -> Bool () ;"
"f x = not x ;"
)
bad
, testSatisfy
"Using a concrete function (primitive type) on a skolem variable should not succeed"
( D.do
"plusOne : Int -> Int ;"
"plusOne x = x + 1 ;"
"f : a -> Int ;"
"f x = plusOne x ;"
)
bad
, testSatisfy
"A function without signature used in an incompatible context should not succeed"
( D.do
"main = _id 1 2 ;"
"_id x = x ;"
)
bad
, testSatisfy
"Pattern matching on literal and _List should not succeed"
( D.do
_List
"length : List (c) -> Int;"
"length _List = case _List of {"
" 0 => 0;"
" Cons x xs => 1 + length xs;"
"};"
)
bad
, testSatisfy
"List of function Int -> Int functions should not be usable on Char"
( D.do
_List
"main : List (Int -> Int) -> Int ;"
"main xs = case xs of {"
" Cons f _ => f 'a' ;"
" Nil => 0 ;"
" };"
)
bad
]
bes =
[ testBe
"A basic arithmetic function should be able to be inferred"
( D.do
"plusOne x = x + 1 ;"
"main x = plusOne x ;"
)
( D.do
"plusOne : Int -> Int ;"
"plusOne x = x + 1 ;"
"main : Int -> Int ;"
"main x = plusOne x ;"
)
, testBe
"A basic arithmetic function should be able to be inferred"
( D.do
"plusOne x = x + 1 ;"
)
( D.do
"plusOne : Int -> Int ;"
"plusOne x = x + 1 ;"
)
, testBe
"List of function Int -> Int functions should be inferred corretly"
( D.do
_List
"main xs = case xs of {"
" Cons f _ => f 1 ;"
" Nil => 0 ;"
" };"
)
( D.do
_List
"main : List (Int -> Int) -> Int ;"
"main xs = case xs of {"
" Cons f _ => f 1 ;"
" Nil => 0 ;"
" };"
)
]
testSatisfy desc test satisfaction = specify desc $ run test `shouldSatisfy` satisfaction
testBe desc test shouldbe = specify desc $ run test `shouldBe` run shouldbe
run = typecheck <=< pProgram . myLexer
ok (Right _) = True
@ -171,7 +175,7 @@ bad = not . ok
_const = D.do
"const : a -> b -> a ;"
"const x y = x ;"
_list = D.do
_List = D.do
"data List (a) where"
" {"
" Nil : List (a)"
@ -187,7 +191,7 @@ _head = D.do
" Cons x xs => x ;"
" };"
_bool = D.do
_Bool = D.do
"data Bool () where {"
" True : Bool ()"
" False : Bool ()"
@ -200,3 +204,15 @@ _not = D.do
" False => True ;"
"};"
_id = "id x = x ;"
_Maybe = D.do
"data Maybe (a) where {"
" Nothing : Maybe (a)"
" Just : a -> Maybe (a)"
" };"
_fmap = D.do
"fmap f ma = case ma of {"
" Nothing => Nothing ;"
" Just a => Just (f a) ;"
"};"