fixed a substitution bug where ap was incorrectly inferred.

also added cleaner fresh variables
This commit is contained in:
sebastian 2023-03-25 22:40:15 +01:00
parent 975dd34063
commit ac43af8110
6 changed files with 287 additions and 194 deletions

View file

@ -1,2 +0,0 @@
ignore-project: False
tests: True

View file

@ -1,2 +0,0 @@
ignore-project: False
tests: False

View file

@ -19,6 +19,7 @@ import Data.Map qualified as M
import Data.Maybe (fromJust) import Data.Maybe (fromJust)
import Data.Set (Set) import Data.Set (Set)
import Data.Set qualified as S import Data.Set qualified as S
import Debug.Trace (trace)
import Grammar.Abs import Grammar.Abs
import Grammar.Print (printTree) import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr ( import TypeChecker.TypeCheckerIr (
@ -31,8 +32,7 @@ import TypeChecker.TypeCheckerIr (
import TypeChecker.TypeCheckerIr qualified as T import TypeChecker.TypeCheckerIr qualified as T
initCtx = Ctx mempty initCtx = Ctx mempty
initEnv = Env 0 'a' mempty mempty mempty
initEnv = Env 0 mempty mempty
runPretty :: Exp -> Either Error String runPretty :: Exp -> Either Error String
runPretty = fmap (printTree . fst) . run . inferExp runPretty = fmap (printTree . fst) . run . inferExp
@ -82,39 +82,39 @@ retType a = a
checkPrg :: Program -> Infer T.Program checkPrg :: Program -> Infer T.Program
checkPrg (Program bs) = do checkPrg (Program bs) = do
preRun bs preRun bs
-- Type check the program twice to produce all top-level types in the first pass through bs' <- checkDef bs
_ <- checkDef bs return $ T.Program bs'
bs'' <- checkDef bs
return $ T.Program bs''
where
preRun :: [Def] -> Infer ()
preRun [] = return ()
preRun (x : xs) = case x of
DSig (Sig n t) -> do
gets (M.member (coerce n) . sigs)
>>= flip
when
( throwError $
"Duplicate signatures for function '"
<> printTree n
<> "'"
)
insertSig (coerce n) (Just $ toNew t) >> preRun xs
DBind (Bind n _ _) -> do
s <- gets sigs
case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs
DData d@(Data _ _) -> checkData d >> preRun xs
checkDef :: [Def] -> Infer [T.Def] preRun :: [Def] -> Infer ()
checkDef [] = return [] preRun [] = return ()
checkDef (x : xs) = case x of preRun (x : xs) = case x of
(DBind b) -> do DSig (Sig n t) -> do
b' <- checkBind b collect (collectTypeVars t)
fmap (T.DBind b' :) (checkDef xs) gets (M.member (coerce n) . sigs)
(DData d) -> fmap (T.DData (toNew d) :) (checkDef xs) >>= flip
(DSig _) -> checkDef xs when
( throwError $
"Duplicate signatures for function '"
<> printTree n
<> "'"
)
insertSig (coerce n) (Just $ toNew t) >> preRun xs
DBind (Bind n _ e) -> do
collect (collectTypeVars e)
s <- gets sigs
case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs
DData d@(Data t _) -> collect (collectTypeVars t) >> checkData d >> preRun xs
checkDef :: [Def] -> Infer [T.Def]
checkDef [] = return []
checkDef (x : xs) = case x of
(DBind b) -> do
b' <- checkBind b
fmap (T.DBind b' :) (checkDef xs)
(DData d) -> fmap (T.DData (toNew d) :) (checkDef xs)
(DSig _) -> checkDef xs
checkBind :: Bind -> Infer T.Bind checkBind :: Bind -> Infer T.Bind
checkBind (Bind name args e) = do checkBind (Bind name args e) = do
@ -171,6 +171,23 @@ inferExp e = do
let subbed = apply s t let subbed = apply s t
return $ second (const subbed) (e', t) return $ second (const subbed) (e', t)
class CollectTVars a where
collectTypeVars :: a -> Set T.Ident
instance CollectTVars Exp where
collectTypeVars (EAnn e t) = collectTypeVars t `S.union` collectTypeVars e
collectTypeVars _ = S.empty
instance CollectTVars Type where
collectTypeVars (TVar (MkTVar i)) = S.singleton (coerce i)
collectTypeVars (TAll _ t) = collectTypeVars t
collectTypeVars (TFun t1 t2) = collectTypeVars t1 `S.union` collectTypeVars t2
collectTypeVars (TData _ ts) = foldl' (\acc x -> acc `S.union` collectTypeVars x) S.empty ts
collectTypeVars _ = S.empty
collect :: Set T.Ident -> Infer ()
collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st})
class NewType a b where class NewType a b where
toNew :: a -> b toNew :: a -> b
@ -321,8 +338,9 @@ algoW = \case
(sub, (e', t)) <- algoW caseExpr (sub, (e', t)) <- algoW caseExpr
(subst, injs, ret_t) <- checkCase t injs (subst, injs, ret_t) <- checkCase t injs
let comp = subst `compose` sub let comp = subst `compose` sub
let t' = apply comp ret_t trace ("EXPR: " ++ show (apply comp t)) pure ()
return (comp, apply comp (T.ECase (e', t) injs, t')) trace ("CASES: " ++ show (apply comp ret_t)) pure ()
return (comp, apply comp (T.ECase (e', t) injs, ret_t))
makeLambda :: Exp -> [T.Ident] -> Exp makeLambda :: Exp -> [T.Ident] -> Exp
makeLambda = foldl (flip (EAbs . coerce)) makeLambda = foldl (flip (EAbs . coerce))
@ -335,7 +353,7 @@ unify t0 t1 = do
s1 <- unify a c s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d) s2 <- unify (apply s1 b) (apply s1 d)
return $ s1 `compose` s2 return $ s1 `compose` s2
----------- TODO: CAREFUL!!!! THIS IS PROBABLY WRONG!!! ----------- ----------- TODO: BE CAREFUL!!!! THIS IS PROBABLY WRONG!!! -----------
(T.TVar (T.MkTVar a), t@(T.TData _ _)) -> return $ M.singleton a t (T.TVar (T.MkTVar a), t@(T.TData _ _)) -> return $ M.singleton a t
(t@(T.TData _ _), T.TVar (T.MkTVar b)) -> return $ M.singleton b t (t@(T.TData _ _), T.TVar (T.MkTVar b)) -> return $ M.singleton b t
------------------------------------------------------------------- -------------------------------------------------------------------
@ -517,9 +535,24 @@ nullSubst = M.empty
-- | Generate a new fresh variable and increment the state counter -- | Generate a new fresh variable and increment the state counter
fresh :: Infer T.Type fresh :: Infer T.Type
fresh = do fresh = do
c <- gets nextChar
n <- gets count n <- gets count
modify (\st -> st{count = n + 1}) taken <- gets takenTypeVars
return . T.TVar . T.MkTVar . T.Ident $ show n if c == 'z'
then do
modify (\st -> st{count = succ (count st), nextChar = 'a'})
else modify (\st -> st{nextChar = next (nextChar st)})
if coerce [c] `S.member` taken
then do
fresh
else
if n == 0
then return . T.TVar . T.MkTVar . T.Ident $ [c]
else return . T.TVar . T.MkTVar . T.Ident $ [c] ++ show n
next :: Char -> Char
next 'z' = 'a'
next a = succ a
-- | Run the monadic action with an additional binding -- | Run the monadic action with an additional binding
withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> T.Type -> m a -> m a withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> T.Type -> m a -> m a
@ -567,7 +600,7 @@ inferBranch :: Branch -> Infer (Subst, T.Type, T.Branch, T.Type)
inferBranch (Branch pat expr) = do inferBranch (Branch pat expr) = do
newPat@(pat, branchT) <- inferPattern pat newPat@(pat, branchT) <- inferPattern pat
(sub, newExp@(_, exprT)) <- withPattern pat (algoW expr) (sub, newExp@(_, exprT)) <- withPattern pat (algoW expr)
return (sub, branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT) return (sub, apply sub branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT)
withPattern :: T.Pattern -> Infer a -> Infer a withPattern :: T.Pattern -> Infer a -> Infer a
withPattern p ma = case p of withPattern p ma = case p of
@ -586,15 +619,36 @@ inferPattern = \case
let numArgs = typeLength t - 1 let numArgs = typeLength t - 1
let (vs, ret) = fromJust (unsnoc $ flattenType t) let (vs, ret) = fromJust (unsnoc $ flattenType t)
patterns <- mapM inferPattern patterns patterns <- mapM inferPattern patterns
unless (length patterns == numArgs) (throwError $ "The constructor '" ++ printTree constr ++ "'" ++ " should have " ++ show numArgs ++ " arguments but has been given " ++ show (length patterns)) unless
(length patterns == numArgs)
( throwError $
"The constructor '"
++ printTree constr
++ "'"
++ " should have "
++ show numArgs
++ " arguments but has been given "
++ show (length patterns)
)
sub <- composeAll <$> zipWithM unify vs (map snd patterns) sub <- composeAll <$> zipWithM unify vs (map snd patterns)
return (T.PInj (coerce constr) (map fst patterns), apply sub ret) return (T.PInj (coerce constr) (map fst patterns), apply sub ret)
PCatch -> (T.PCatch,) <$> fresh PCatch -> (T.PCatch,) <$> fresh
PEnum p -> do PEnum p -> do
t <- gets (M.lookup (coerce p) . constructors) t <- gets (M.lookup (coerce p) . constructors)
t <- maybeToRightM ("Constructor: " <> printTree p <> " does not exist") t t <- maybeToRightM ("Constructor: " <> printTree p <> " does not exist") t
unless (typeLength t == 1) (throwError $ "The constructor '" ++ printTree p ++ "'" ++ " should have " ++ show (typeLength t - 1) ++ " arguments but has been given 0") unless
return (T.PEnum $ coerce p, t) (typeLength t == 1)
( throwError $
"The constructor '"
++ printTree p
++ "'"
++ " should have "
++ show (typeLength t - 1)
++ " arguments but has been given 0"
)
let (T.TData _data _ts) = t -- nasty nasty
frs <- mapM (const fresh) _ts
return (T.PEnum $ coerce p, T.TData _data frs)
PVar x -> do PVar x -> do
fr <- fresh fr <- fresh
let pvar = T.PVar (coerce x, fr) let pvar = T.PVar (coerce x, fr)
@ -632,4 +686,9 @@ exprErr ma exp =
catchError ma (\x -> throwError $ x <> " in expression: \n" <> printTree exp) catchError ma (\x -> throwError $ x <> " in expression: \n" <> printTree exp)
unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d]) unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d])
unzip4 = foldl' (\(as, bs, cs, ds) (a, b, c, d) -> (as ++ [a], bs ++ [b], cs ++ [c], ds ++ [d])) ([], [], [], []) unzip4 =
foldl'
( \(as, bs, cs, ds) (a, b, c, d) ->
(as ++ [a], bs ++ [b], cs ++ [c], ds ++ [d])
)
([], [], [], [])

View file

@ -10,6 +10,7 @@ import Control.Monad.State
import Data.Char (isDigit) import Data.Char (isDigit)
import Data.Functor.Identity (Identity) import Data.Functor.Identity (Identity)
import Data.Map (Map) import Data.Map (Map)
import Data.Set (Set)
import Data.String qualified import Data.String qualified
import Grammar.Print import Grammar.Print
import Prelude import Prelude
@ -20,8 +21,10 @@ newtype Ctx = Ctx {vars :: Map Ident Type}
data Env = Env data Env = Env
{ count :: Int { count :: Int
, nextChar :: Char
, sigs :: Map Ident (Maybe Type) , sigs :: Map Ident (Maybe Type)
, constructors :: Map Ident Type , constructors :: Map Ident Type
, takenTypeVars :: Set Ident
} }
deriving (Show) deriving (Show)

View file

@ -1,9 +1,28 @@
data Bool () where { data Maybe (a) where {
True : Bool () Nothing : Maybe (a)
False : Bool () Just : a -> Maybe (a)
}; };
main = case True of { fmap : (a -> b) -> Maybe (a) -> Maybe (b) ;
True => 1; fmap f ma = case ma of {
False => 0; Nothing => Nothing ;
}; Just a => Just (f a) ;
};
pure : a -> Maybe (a) ;
pure x = Just x ;
ap mf ma = case mf of {
Just f => case ma of {
Nothing => Nothing;
Just a => Just (f a);
};
Nothing => Nothing;
};
return = pure;
bind ma f = case ma of {
Nothing => Nothing ;
Just a => f a ;
};

View file

@ -16,149 +16,153 @@ main :: IO ()
main = do main = do
mapM_ hspec goods mapM_ hspec goods
mapM_ hspec bads mapM_ hspec bads
mapM_ hspec bes
goods = goods =
[ specify "Basic polymorphism with multiple type variables" $ [ testSatisfy
run "Basic polymorphism with multiple type variables"
( D.do ( D.do
_const _const
"main = const 'a' 65 ;" "main = const 'a' 65 ;"
) )
`shouldSatisfy` ok ok
, specify "Head with a correct signature is accepted" $ , testSatisfy
run "Head with a correct signature is accepted"
( D.do ( D.do
_list _List
_headSig _headSig
_head _head
) )
`shouldSatisfy` ok ok
, specify "A basic arithmetic function should be able to be inferred" $ , testSatisfy
run "Most simple inference possible"
( D.do ( D.do
"plusOne x = x + 1 ;" _id
"main x = plusOne x ;" )
) ok
`shouldBe` run , testSatisfy
( D.do "Pattern matching on a nested list"
"plusOne : Int -> Int ;" ( D.do
"plusOne x = x + 1 ;" _List
"main : Int -> Int ;" "main : List (List (a)) -> Int ;"
"main x = plusOne x ;" "main xs = case xs of {"
) " Cons Nil _ => 1 ;"
, specify "A basic arithmetic function should be able to be inferred" $ " _ => 0 ;"
run "};"
( D.do )
"plusOne x = x + 1 ;" ok
)
`shouldBe` run
( D.do
"plusOne : Int -> Int ;"
"plusOne x = x + 1 ;"
)
, specify "Most simple inference possible" $
run
( D.do
_id
)
`shouldSatisfy` ok
, specify "Pattern matching on a nested list" $
run
( D.do
_list
"main : List (List (a)) -> Int ;"
"main xs = case xs of {"
" Cons Nil _ => 1 ;"
" _ => 0 ;"
"};"
)
`shouldSatisfy` ok
, specify "List of function Int -> Int functions should be inferred corretly" $
run
( D.do
_list
"main xs = case xs of {"
" Cons f _ => f 1 ;"
" Nil => 0 ;"
" };"
)
`shouldBe` run
( D.do
_list
"main : List (Int -> Int) -> Int ;"
"main xs = case xs of {"
" Cons f _ => f 1 ;"
" Nil => 0 ;"
" };"
)
] ]
bads = bads =
[ specify "Infinite type unification should not succeed" $ [ testSatisfy
run "Infinite type unification should not succeed"
( D.do ( D.do
"main = \\x. x x ;" "main = \\x. x x ;"
) )
`shouldSatisfy` bad bad
, specify "Pattern matching using different types should not succeed" $ , testSatisfy
run "Pattern matching using different types should not succeed"
( D.do ( D.do
_list _List
"bad xs = case xs of {" "bad xs = case xs of {"
" 1 => 0 ;" " 1 => 0 ;"
" Nil => 0 ;" " Nil => 0 ;"
"};" "};"
) )
`shouldSatisfy` bad bad
, specify "Using a concrete function (data type) on a skolem variable should not succeed" $ , testSatisfy
run "Using a concrete function (data type) on a skolem variable should not succeed"
( D.do ( D.do
_bool _Bool
_not _not
"f : a -> Bool () ;" "f : a -> Bool () ;"
"f x = not x ;" "f x = not x ;"
) )
`shouldSatisfy` bad bad
, specify "Using a concrete function (primitive type) on a skolem variable should not succeed" $ , testSatisfy
run "Using a concrete function (primitive type) on a skolem variable should not succeed"
( D.do ( D.do
"plusOne : Int -> Int ;" "plusOne : Int -> Int ;"
"plusOne x = x + 1 ;" "plusOne x = x + 1 ;"
"f : a -> Int ;" "f : a -> Int ;"
" f x = plusOne x ;" "f x = plusOne x ;"
) )
`shouldSatisfy` bad bad
, specify "A function without signature used in an incompatible context should not succeed" $ , testSatisfy
run "A function without signature used in an incompatible context should not succeed"
( D.do ( D.do
"main = _id 1 2 ;" "main = _id 1 2 ;"
"_id x = x ;" "_id x = x ;"
) )
`shouldSatisfy` bad bad
, specify "Pattern matching on literal and _list should not succeed" $ , testSatisfy
run "Pattern matching on literal and _List should not succeed"
( D.do ( D.do
_list _List
"length : List (c) -> Int;" "length : List (c) -> Int;"
"length _list = case _list of {" "length _List = case _List of {"
" 0 => 0;" " 0 => 0;"
" Cons x xs => 1 + length xs;" " Cons x xs => 1 + length xs;"
"};" "};"
) )
`shouldSatisfy` bad bad
, specify "List of function Int -> Int functions should not be usable on Char" $ , testSatisfy
run "List of function Int -> Int functions should not be usable on Char"
( D.do ( D.do
_list _List
"main : List (Int -> Int) -> Int ;" "main : List (Int -> Int) -> Int ;"
"main xs = case xs of {" "main xs = case xs of {"
" Cons f _ => f 'a' ;" " Cons f _ => f 'a' ;"
" Nil => 0 ;" " Nil => 0 ;"
" };" " };"
) )
`shouldSatisfy` bad bad
] ]
bes =
[ testBe
"A basic arithmetic function should be able to be inferred"
( D.do
"plusOne x = x + 1 ;"
"main x = plusOne x ;"
)
( D.do
"plusOne : Int -> Int ;"
"plusOne x = x + 1 ;"
"main : Int -> Int ;"
"main x = plusOne x ;"
)
, testBe
"A basic arithmetic function should be able to be inferred"
( D.do
"plusOne x = x + 1 ;"
)
( D.do
"plusOne : Int -> Int ;"
"plusOne x = x + 1 ;"
)
, testBe
"List of function Int -> Int functions should be inferred corretly"
( D.do
_List
"main xs = case xs of {"
" Cons f _ => f 1 ;"
" Nil => 0 ;"
" };"
)
( D.do
_List
"main : List (Int -> Int) -> Int ;"
"main xs = case xs of {"
" Cons f _ => f 1 ;"
" Nil => 0 ;"
" };"
)
]
testSatisfy desc test satisfaction = specify desc $ run test `shouldSatisfy` satisfaction
testBe desc test shouldbe = specify desc $ run test `shouldBe` run shouldbe
run = typecheck <=< pProgram . myLexer run = typecheck <=< pProgram . myLexer
ok (Right _) = True ok (Right _) = True
@ -171,7 +175,7 @@ bad = not . ok
_const = D.do _const = D.do
"const : a -> b -> a ;" "const : a -> b -> a ;"
"const x y = x ;" "const x y = x ;"
_list = D.do _List = D.do
"data List (a) where" "data List (a) where"
" {" " {"
" Nil : List (a)" " Nil : List (a)"
@ -187,7 +191,7 @@ _head = D.do
" Cons x xs => x ;" " Cons x xs => x ;"
" };" " };"
_bool = D.do _Bool = D.do
"data Bool () where {" "data Bool () where {"
" True : Bool ()" " True : Bool ()"
" False : Bool ()" " False : Bool ()"
@ -200,3 +204,15 @@ _not = D.do
" False => True ;" " False => True ;"
"};" "};"
_id = "id x = x ;" _id = "id x = x ;"
_Maybe = D.do
"data Maybe (a) where {"
" Nothing : Maybe (a)"
" Just : a -> Maybe (a)"
" };"
_fmap = D.do
"fmap f ma = case ma of {"
" Nothing => Nothing ;"
" Just a => Just (f a) ;"
"};"