Better inference & stuff on pattern matches, added more tests for regression

This commit is contained in:
sebastian 2023-03-25 20:43:19 +01:00
parent 88eaa466e4
commit 975dd34063
3 changed files with 94 additions and 64 deletions

View file

@ -5,6 +5,7 @@ build:
clean:
rm -r src/Grammar
rm language
rm -r dist-newstyle/
# run all tests
test:

View file

@ -6,19 +6,19 @@ module TypeChecker.TypeChecker where
import Auxiliary
import Control.Monad.Except
import Control.Monad.Identity (runIdentity)
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (second)
import Data.Coerce (coerce)
import Data.Foldable (traverse_)
import Data.Functor.Identity (runIdentity)
import Data.List (foldl')
import Data.List.Extra (unsnoc)
import Data.Map (Map)
import Data.Map qualified as M
import Data.Maybe (fromJust)
import Data.Set (Set)
import Data.Set qualified as S
import Debug.Trace (trace)
import Grammar.Abs
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr (
@ -117,20 +117,12 @@ checkPrg (Program bs) = do
(DSig _) -> checkDef xs
checkBind :: Bind -> Infer T.Bind
checkBind err@(Bind name args e) = do
checkBind (Bind name args e) = do
let lambda = makeLambda e (reverse (coerce args))
e@(_, args_t) <- inferExp lambda
-- args <- zip args <$> mapM (const fresh) args
-- withBindings (coerce args) $ do
-- e@(_, t) <- inferExp e
-- let args_t = foldl' T.TFun t (reverse (map snd args))
s <- gets sigs
case M.lookup (coerce name) s of
Just (Just t') -> do
-- sub <- bindErr (unify args_t t') err
-- let newT = apply sub args_t
-- insertSig (coerce name) (Just newT)
-- return $ T.Bind (apply sub (coerce name, newT)) [] e
unless
(args_t `typeEq` t')
( throwError $
@ -152,7 +144,8 @@ typeEq (T.TData name a) (T.TData name' b) =
length a == length b
&& name == name'
&& and (zipWith typeEq a b)
typeEq (T.TAll _ t1) (T.TAll _ t2) = t1 `typeEq` t2
typeEq (T.TAll _ t1) t2 = t1 `typeEq` t2
typeEq t1 (T.TAll _ t2) = t1 `typeEq` t2
typeEq (T.TVar _) (T.TVar _) = True
typeEq _ _ = False
@ -164,6 +157,7 @@ isMoreSpecificOrEq (T.TData n1 ts1) (T.TData n2 ts2) =
n1 == n2
&& length ts1 == length ts2
&& and (zipWith isMoreSpecificOrEq ts1 ts2)
isMoreSpecificOrEq _ (T.TVar _) = True
isMoreSpecificOrEq a b = a == b
isPoly :: Type -> Bool
@ -175,10 +169,7 @@ inferExp :: Exp -> Infer T.ExpT
inferExp e = do
(s, (e', t)) <- algoW e
let subbed = apply s t
return $ replace subbed (e', t)
replace :: T.Type -> T.ExpT -> T.ExpT
replace t = second (const t)
return $ second (const subbed) (e', t)
class NewType a b where
toNew :: a -> b
@ -200,7 +191,7 @@ instance NewType Data T.Data where
toNew (Data t xs) = T.Data (name $ retType t) (toNew xs)
where
name (TData n _) = coerce n
name _ = error "Bug in toNew Data -> T.Data"
name _ = error "Bug: Data types should not be able to be typed over non type variables"
instance NewType Constructor T.Constructor where
toNew (Constructor name xs) = T.Constructor (coerce name) (toNew xs)
@ -213,7 +204,6 @@ instance NewType a b => NewType [a] [b] where
algoW :: Exp -> Infer (Subst, T.ExpT)
algoW = \case
-- \| TODO: More testing need to be done. Unsure of the correctness of this
err@(EAnn e t) -> do
(s1, (e', t')) <- exprErr (algoW e) err
unless
@ -434,6 +424,9 @@ inst = \case
compose :: Subst -> Subst -> Subst
compose m1 m2 = M.map (apply m1) m2 `M.union` m1
composeAll :: [Subst] -> Subst
composeAll = foldl' compose nullSubst
-- TODO: Split this class into two separate classes, one for free variables
-- and one for applying substitutions
@ -477,21 +470,19 @@ instance SubstType (Map T.Ident T.Type) where
apply :: Subst -> Map T.Ident T.Type -> Map T.Ident T.Type
apply s = M.map (apply s)
instance SubstType T.ExpT where
apply :: Subst -> T.ExpT -> T.ExpT
instance SubstType T.Exp where
apply :: Subst -> T.Exp -> T.Exp
apply s = \case
(T.EId i, outerT) -> (T.EId i, apply s outerT)
(T.ELit lit, t) -> (T.ELit lit, apply s t)
(T.ELet (T.Bind (ident, t1) args e1) e2, t2) ->
( T.ELet
T.EId i -> T.EId i
T.ELit lit -> T.ELit lit
T.ELet (T.Bind (ident, t1) args e1) e2 ->
T.ELet
(T.Bind (ident, apply s t1) args (apply s e1))
(apply s e2)
, apply s t2
)
(T.EApp e1 e2, t) -> (T.EApp (apply s e1) (apply s e2), apply s t)
(T.EAdd e1 e2, t) -> (T.EAdd (apply s e1) (apply s e2), apply s t)
(T.EAbs ident e, t1) -> (T.EAbs ident (apply s e), apply s t1)
(T.ECase e brnch, t) -> (T.ECase (apply s e) (apply s brnch), apply s t)
T.EApp e1 e2 -> T.EApp (apply s e1) (apply s e2)
T.EAdd e1 e2 -> T.EAdd (apply s e1) (apply s e2)
T.EAbs ident e -> T.EAbs ident (apply s e)
T.ECase e brnch -> T.ECase (apply s e) (apply s brnch)
instance SubstType T.Branch where
apply :: Subst -> T.Branch -> T.Branch
@ -509,6 +500,9 @@ instance SubstType T.Pattern where
instance SubstType a => SubstType [a] where
apply s = map (apply s)
instance (SubstType a, SubstType b) => SubstType (a, b) where
apply s (a, b) = (apply s a, apply s b)
instance SubstType T.Id where
apply s (name, t) = (name, apply s t)
@ -548,8 +542,10 @@ insertConstr i t =
-------- PATTERN MATCHING ---------
checkCase :: T.Type -> [Branch] -> Infer (Subst, [T.Branch], T.Type)
checkCase expT injs = do
(injTs, injs, returns) <- unzip3 <$> mapM inferBranch injs
checkCase _ [] = throwError "Atleast one case required"
checkCase expT brnchs = do
(subs, injTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs
let sub0 = composeAll subs
(sub1, _) <-
foldM
( \(sub, acc) x ->
@ -564,17 +560,14 @@ checkCase expT injs = do
)
(nullSubst, head returns)
(tail returns)
return (sub2 `compose` sub1, injs, returns_type)
let comp = sub2 `compose` sub1 `compose` sub0
return (comp, apply comp injs, apply comp returns_type)
{- | fst = type of init
| snd = type of expr
-}
inferBranch :: Branch -> Infer (T.Type, T.Branch, T.Type)
inferBranch :: Branch -> Infer (Subst, T.Type, T.Branch, T.Type)
inferBranch (Branch pat expr) = do
newPat@(pat, branchT) <- inferPattern pat
trace ("BRANCH TYPE: " ++ show branchT) pure ()
newExp@(_, exprT) <- withPattern pat (inferExp expr)
return (branchT, T.Branch newPat newExp, exprT)
(sub, newExp@(_, exprT)) <- withPattern pat (algoW expr)
return (sub, branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT)
withPattern :: T.Pattern -> Infer a -> Infer a
withPattern p ma = case p of
@ -590,14 +583,17 @@ inferPattern = \case
PInj constr patterns -> do
t <- gets (M.lookup (coerce constr) . constructors)
t <- maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") t
(vs, ret) <- maybeToRightM "Partial pattern match not allowed" (unsnoc $ flattenType t)
let numArgs = typeLength t - 1
let (vs, ret) = fromJust (unsnoc $ flattenType t)
patterns <- mapM inferPattern patterns
sub <- foldl' compose nullSubst <$> zipWithM unify vs (map snd patterns)
unless (length patterns == numArgs) (throwError $ "The constructor '" ++ printTree constr ++ "'" ++ " should have " ++ show numArgs ++ " arguments but has been given " ++ show (length patterns))
sub <- composeAll <$> zipWithM unify vs (map snd patterns)
return (T.PInj (coerce constr) (map fst patterns), apply sub ret)
PCatch -> (T.PCatch,) <$> fresh
PEnum p -> do
t <- gets (M.lookup (coerce p) . constructors)
t <- maybeToRightM ("Constructor: " <> printTree p <> " does not exist") t
unless (typeLength t == 1) (throwError $ "The constructor '" ++ printTree p ++ "'" ++ " should have " ++ show (typeLength t - 1) ++ " arguments but has been given 0")
return (T.PEnum $ coerce p, t)
PVar x -> do
fr <- fresh
@ -608,6 +604,10 @@ flattenType :: T.Type -> [T.Type]
flattenType (T.TFun a b) = flattenType a <> flattenType b
flattenType a = [a]
typeLength :: T.Type -> Int
typeLength (T.TFun a b) = typeLength a + typeLength b
typeLength _ = 1
litType :: Lit -> T.Type
litType (LInt _) = int
litType (LChar _) = char
@ -629,8 +629,7 @@ partitionType = go []
exprErr :: Infer a -> Exp -> Infer a
exprErr ma exp =
catchError ma (\x -> throwError $ x <> " on expression: " <> printTree exp)
catchError ma (\x -> throwError $ x <> " in expression: \n" <> printTree exp)
bindErr :: Infer a -> Bind -> Infer a
bindErr ma exp =
catchError ma (\x -> throwError $ x <> " on expression: " <> printTree exp)
unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d])
unzip4 = foldl' (\(as, bs, cs, ds) (a, b, c, d) -> (as ++ [a], bs ++ [b], cs ++ [c], ds ++ [d])) ([], [], [], [])

View file

@ -21,16 +21,16 @@ goods =
[ specify "Basic polymorphism with multiple type variables" $
run
( D.do
const
_const
"main = const 'a' 65 ;"
)
`shouldSatisfy` ok
, specify "Head with a correct signature is accepted" $
run
( D.do
list
headSig
head
_list
_headSig
_head
)
`shouldSatisfy` ok
, specify "A basic arithmetic function should be able to be inferred" $
@ -59,13 +59,13 @@ goods =
, specify "Most simple inference possible" $
run
( D.do
"id x = x ;"
_id
)
`shouldSatisfy` ok
, specify "Pattern matching on a nested list" $
run
( D.do
list
_list
"main : List (List (a)) -> Int ;"
"main xs = case xs of {"
" Cons Nil _ => 1 ;"
@ -73,6 +73,24 @@ goods =
"};"
)
`shouldSatisfy` ok
, specify "List of function Int -> Int functions should be inferred corretly" $
run
( D.do
_list
"main xs = case xs of {"
" Cons f _ => f 1 ;"
" Nil => 0 ;"
" };"
)
`shouldBe` run
( D.do
_list
"main : List (Int -> Int) -> Int ;"
"main xs = case xs of {"
" Cons f _ => f 1 ;"
" Nil => 0 ;"
" };"
)
]
bads =
@ -85,7 +103,7 @@ bads =
, specify "Pattern matching using different types should not succeed" $
run
( D.do
list
_list
"bad xs = case xs of {"
" 1 => 0 ;"
" Nil => 0 ;"
@ -95,7 +113,7 @@ bads =
, specify "Using a concrete function (data type) on a skolem variable should not succeed" $
run
( D.do
bool
_bool
_not
"f : a -> Bool () ;"
"f x = not x ;"
@ -113,21 +131,32 @@ bads =
, specify "A function without signature used in an incompatible context should not succeed" $
run
( D.do
"main = id 1 2 ;"
"id x = x ;"
"main = _id 1 2 ;"
"_id x = x ;"
)
`shouldSatisfy` bad
, specify "Pattern matching on literal and list should not succeed" $
, specify "Pattern matching on literal and _list should not succeed" $
run
( D.do
list
_list
"length : List (c) -> Int;"
"length list = case list of {"
"length _list = case _list of {"
" 0 => 0;"
" Cons x xs => 1 + length xs;"
"};"
)
`shouldSatisfy` bad
, specify "List of function Int -> Int functions should not be usable on Char" $
run
( D.do
_list
"main : List (Int -> Int) -> Int ;"
"main xs = case xs of {"
" Cons f _ => f 'a' ;"
" Nil => 0 ;"
" };"
)
`shouldSatisfy` bad
]
run = typecheck <=< pProgram . myLexer
@ -139,26 +168,26 @@ bad = not . ok
-- FUNCTIONS
const = D.do
_const = D.do
"const : a -> b -> a ;"
"const x y = x ;"
list = D.do
_list = D.do
"data List (a) where"
" {"
" Nil : List (a)"
" Cons : a -> List (a) -> List (a)"
" };"
headSig = D.do
_headSig = D.do
"head : List (a) -> a ;"
head = D.do
_head = D.do
"head xs = "
" case xs of {"
" Cons x xs => x ;"
" };"
bool = D.do
_bool = D.do
"data Bool () where {"
" True : Bool ()"
" False : Bool ()"
@ -170,3 +199,4 @@ _not = D.do
" True => False ;"
" False => True ;"
"};"
_id = "id x = x ;"