Better inference & stuff on pattern matches, added more tests for regression
This commit is contained in:
parent
88eaa466e4
commit
975dd34063
3 changed files with 94 additions and 64 deletions
1
Justfile
1
Justfile
|
|
@ -5,6 +5,7 @@ build:
|
|||
clean:
|
||||
rm -r src/Grammar
|
||||
rm language
|
||||
rm -r dist-newstyle/
|
||||
|
||||
# run all tests
|
||||
test:
|
||||
|
|
|
|||
|
|
@ -6,19 +6,19 @@ module TypeChecker.TypeChecker where
|
|||
|
||||
import Auxiliary
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Identity (runIdentity)
|
||||
import Control.Monad.Reader
|
||||
import Control.Monad.State
|
||||
import Data.Bifunctor (second)
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Foldable (traverse_)
|
||||
import Data.Functor.Identity (runIdentity)
|
||||
import Data.List (foldl')
|
||||
import Data.List.Extra (unsnoc)
|
||||
import Data.Map (Map)
|
||||
import Data.Map qualified as M
|
||||
import Data.Maybe (fromJust)
|
||||
import Data.Set (Set)
|
||||
import Data.Set qualified as S
|
||||
import Debug.Trace (trace)
|
||||
import Grammar.Abs
|
||||
import Grammar.Print (printTree)
|
||||
import TypeChecker.TypeCheckerIr (
|
||||
|
|
@ -117,20 +117,12 @@ checkPrg (Program bs) = do
|
|||
(DSig _) -> checkDef xs
|
||||
|
||||
checkBind :: Bind -> Infer T.Bind
|
||||
checkBind err@(Bind name args e) = do
|
||||
checkBind (Bind name args e) = do
|
||||
let lambda = makeLambda e (reverse (coerce args))
|
||||
e@(_, args_t) <- inferExp lambda
|
||||
-- args <- zip args <$> mapM (const fresh) args
|
||||
-- withBindings (coerce args) $ do
|
||||
-- e@(_, t) <- inferExp e
|
||||
-- let args_t = foldl' T.TFun t (reverse (map snd args))
|
||||
s <- gets sigs
|
||||
case M.lookup (coerce name) s of
|
||||
Just (Just t') -> do
|
||||
-- sub <- bindErr (unify args_t t') err
|
||||
-- let newT = apply sub args_t
|
||||
-- insertSig (coerce name) (Just newT)
|
||||
-- return $ T.Bind (apply sub (coerce name, newT)) [] e
|
||||
unless
|
||||
(args_t `typeEq` t')
|
||||
( throwError $
|
||||
|
|
@ -152,7 +144,8 @@ typeEq (T.TData name a) (T.TData name' b) =
|
|||
length a == length b
|
||||
&& name == name'
|
||||
&& and (zipWith typeEq a b)
|
||||
typeEq (T.TAll _ t1) (T.TAll _ t2) = t1 `typeEq` t2
|
||||
typeEq (T.TAll _ t1) t2 = t1 `typeEq` t2
|
||||
typeEq t1 (T.TAll _ t2) = t1 `typeEq` t2
|
||||
typeEq (T.TVar _) (T.TVar _) = True
|
||||
typeEq _ _ = False
|
||||
|
||||
|
|
@ -164,6 +157,7 @@ isMoreSpecificOrEq (T.TData n1 ts1) (T.TData n2 ts2) =
|
|||
n1 == n2
|
||||
&& length ts1 == length ts2
|
||||
&& and (zipWith isMoreSpecificOrEq ts1 ts2)
|
||||
isMoreSpecificOrEq _ (T.TVar _) = True
|
||||
isMoreSpecificOrEq a b = a == b
|
||||
|
||||
isPoly :: Type -> Bool
|
||||
|
|
@ -175,10 +169,7 @@ inferExp :: Exp -> Infer T.ExpT
|
|||
inferExp e = do
|
||||
(s, (e', t)) <- algoW e
|
||||
let subbed = apply s t
|
||||
return $ replace subbed (e', t)
|
||||
|
||||
replace :: T.Type -> T.ExpT -> T.ExpT
|
||||
replace t = second (const t)
|
||||
return $ second (const subbed) (e', t)
|
||||
|
||||
class NewType a b where
|
||||
toNew :: a -> b
|
||||
|
|
@ -200,7 +191,7 @@ instance NewType Data T.Data where
|
|||
toNew (Data t xs) = T.Data (name $ retType t) (toNew xs)
|
||||
where
|
||||
name (TData n _) = coerce n
|
||||
name _ = error "Bug in toNew Data -> T.Data"
|
||||
name _ = error "Bug: Data types should not be able to be typed over non type variables"
|
||||
|
||||
instance NewType Constructor T.Constructor where
|
||||
toNew (Constructor name xs) = T.Constructor (coerce name) (toNew xs)
|
||||
|
|
@ -213,7 +204,6 @@ instance NewType a b => NewType [a] [b] where
|
|||
|
||||
algoW :: Exp -> Infer (Subst, T.ExpT)
|
||||
algoW = \case
|
||||
-- \| TODO: More testing need to be done. Unsure of the correctness of this
|
||||
err@(EAnn e t) -> do
|
||||
(s1, (e', t')) <- exprErr (algoW e) err
|
||||
unless
|
||||
|
|
@ -434,6 +424,9 @@ inst = \case
|
|||
compose :: Subst -> Subst -> Subst
|
||||
compose m1 m2 = M.map (apply m1) m2 `M.union` m1
|
||||
|
||||
composeAll :: [Subst] -> Subst
|
||||
composeAll = foldl' compose nullSubst
|
||||
|
||||
-- TODO: Split this class into two separate classes, one for free variables
|
||||
-- and one for applying substitutions
|
||||
|
||||
|
|
@ -477,21 +470,19 @@ instance SubstType (Map T.Ident T.Type) where
|
|||
apply :: Subst -> Map T.Ident T.Type -> Map T.Ident T.Type
|
||||
apply s = M.map (apply s)
|
||||
|
||||
instance SubstType T.ExpT where
|
||||
apply :: Subst -> T.ExpT -> T.ExpT
|
||||
instance SubstType T.Exp where
|
||||
apply :: Subst -> T.Exp -> T.Exp
|
||||
apply s = \case
|
||||
(T.EId i, outerT) -> (T.EId i, apply s outerT)
|
||||
(T.ELit lit, t) -> (T.ELit lit, apply s t)
|
||||
(T.ELet (T.Bind (ident, t1) args e1) e2, t2) ->
|
||||
( T.ELet
|
||||
T.EId i -> T.EId i
|
||||
T.ELit lit -> T.ELit lit
|
||||
T.ELet (T.Bind (ident, t1) args e1) e2 ->
|
||||
T.ELet
|
||||
(T.Bind (ident, apply s t1) args (apply s e1))
|
||||
(apply s e2)
|
||||
, apply s t2
|
||||
)
|
||||
(T.EApp e1 e2, t) -> (T.EApp (apply s e1) (apply s e2), apply s t)
|
||||
(T.EAdd e1 e2, t) -> (T.EAdd (apply s e1) (apply s e2), apply s t)
|
||||
(T.EAbs ident e, t1) -> (T.EAbs ident (apply s e), apply s t1)
|
||||
(T.ECase e brnch, t) -> (T.ECase (apply s e) (apply s brnch), apply s t)
|
||||
T.EApp e1 e2 -> T.EApp (apply s e1) (apply s e2)
|
||||
T.EAdd e1 e2 -> T.EAdd (apply s e1) (apply s e2)
|
||||
T.EAbs ident e -> T.EAbs ident (apply s e)
|
||||
T.ECase e brnch -> T.ECase (apply s e) (apply s brnch)
|
||||
|
||||
instance SubstType T.Branch where
|
||||
apply :: Subst -> T.Branch -> T.Branch
|
||||
|
|
@ -509,6 +500,9 @@ instance SubstType T.Pattern where
|
|||
instance SubstType a => SubstType [a] where
|
||||
apply s = map (apply s)
|
||||
|
||||
instance (SubstType a, SubstType b) => SubstType (a, b) where
|
||||
apply s (a, b) = (apply s a, apply s b)
|
||||
|
||||
instance SubstType T.Id where
|
||||
apply s (name, t) = (name, apply s t)
|
||||
|
||||
|
|
@ -548,8 +542,10 @@ insertConstr i t =
|
|||
-------- PATTERN MATCHING ---------
|
||||
|
||||
checkCase :: T.Type -> [Branch] -> Infer (Subst, [T.Branch], T.Type)
|
||||
checkCase expT injs = do
|
||||
(injTs, injs, returns) <- unzip3 <$> mapM inferBranch injs
|
||||
checkCase _ [] = throwError "Atleast one case required"
|
||||
checkCase expT brnchs = do
|
||||
(subs, injTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs
|
||||
let sub0 = composeAll subs
|
||||
(sub1, _) <-
|
||||
foldM
|
||||
( \(sub, acc) x ->
|
||||
|
|
@ -564,17 +560,14 @@ checkCase expT injs = do
|
|||
)
|
||||
(nullSubst, head returns)
|
||||
(tail returns)
|
||||
return (sub2 `compose` sub1, injs, returns_type)
|
||||
let comp = sub2 `compose` sub1 `compose` sub0
|
||||
return (comp, apply comp injs, apply comp returns_type)
|
||||
|
||||
{- | fst = type of init
|
||||
| snd = type of expr
|
||||
-}
|
||||
inferBranch :: Branch -> Infer (T.Type, T.Branch, T.Type)
|
||||
inferBranch :: Branch -> Infer (Subst, T.Type, T.Branch, T.Type)
|
||||
inferBranch (Branch pat expr) = do
|
||||
newPat@(pat, branchT) <- inferPattern pat
|
||||
trace ("BRANCH TYPE: " ++ show branchT) pure ()
|
||||
newExp@(_, exprT) <- withPattern pat (inferExp expr)
|
||||
return (branchT, T.Branch newPat newExp, exprT)
|
||||
(sub, newExp@(_, exprT)) <- withPattern pat (algoW expr)
|
||||
return (sub, branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT)
|
||||
|
||||
withPattern :: T.Pattern -> Infer a -> Infer a
|
||||
withPattern p ma = case p of
|
||||
|
|
@ -590,14 +583,17 @@ inferPattern = \case
|
|||
PInj constr patterns -> do
|
||||
t <- gets (M.lookup (coerce constr) . constructors)
|
||||
t <- maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") t
|
||||
(vs, ret) <- maybeToRightM "Partial pattern match not allowed" (unsnoc $ flattenType t)
|
||||
let numArgs = typeLength t - 1
|
||||
let (vs, ret) = fromJust (unsnoc $ flattenType t)
|
||||
patterns <- mapM inferPattern patterns
|
||||
sub <- foldl' compose nullSubst <$> zipWithM unify vs (map snd patterns)
|
||||
unless (length patterns == numArgs) (throwError $ "The constructor '" ++ printTree constr ++ "'" ++ " should have " ++ show numArgs ++ " arguments but has been given " ++ show (length patterns))
|
||||
sub <- composeAll <$> zipWithM unify vs (map snd patterns)
|
||||
return (T.PInj (coerce constr) (map fst patterns), apply sub ret)
|
||||
PCatch -> (T.PCatch,) <$> fresh
|
||||
PEnum p -> do
|
||||
t <- gets (M.lookup (coerce p) . constructors)
|
||||
t <- maybeToRightM ("Constructor: " <> printTree p <> " does not exist") t
|
||||
unless (typeLength t == 1) (throwError $ "The constructor '" ++ printTree p ++ "'" ++ " should have " ++ show (typeLength t - 1) ++ " arguments but has been given 0")
|
||||
return (T.PEnum $ coerce p, t)
|
||||
PVar x -> do
|
||||
fr <- fresh
|
||||
|
|
@ -608,6 +604,10 @@ flattenType :: T.Type -> [T.Type]
|
|||
flattenType (T.TFun a b) = flattenType a <> flattenType b
|
||||
flattenType a = [a]
|
||||
|
||||
typeLength :: T.Type -> Int
|
||||
typeLength (T.TFun a b) = typeLength a + typeLength b
|
||||
typeLength _ = 1
|
||||
|
||||
litType :: Lit -> T.Type
|
||||
litType (LInt _) = int
|
||||
litType (LChar _) = char
|
||||
|
|
@ -629,8 +629,7 @@ partitionType = go []
|
|||
|
||||
exprErr :: Infer a -> Exp -> Infer a
|
||||
exprErr ma exp =
|
||||
catchError ma (\x -> throwError $ x <> " on expression: " <> printTree exp)
|
||||
catchError ma (\x -> throwError $ x <> " in expression: \n" <> printTree exp)
|
||||
|
||||
bindErr :: Infer a -> Bind -> Infer a
|
||||
bindErr ma exp =
|
||||
catchError ma (\x -> throwError $ x <> " on expression: " <> printTree exp)
|
||||
unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d])
|
||||
unzip4 = foldl' (\(as, bs, cs, ds) (a, b, c, d) -> (as ++ [a], bs ++ [b], cs ++ [c], ds ++ [d])) ([], [], [], [])
|
||||
|
|
|
|||
|
|
@ -21,16 +21,16 @@ goods =
|
|||
[ specify "Basic polymorphism with multiple type variables" $
|
||||
run
|
||||
( D.do
|
||||
const
|
||||
_const
|
||||
"main = const 'a' 65 ;"
|
||||
)
|
||||
`shouldSatisfy` ok
|
||||
, specify "Head with a correct signature is accepted" $
|
||||
run
|
||||
( D.do
|
||||
list
|
||||
headSig
|
||||
head
|
||||
_list
|
||||
_headSig
|
||||
_head
|
||||
)
|
||||
`shouldSatisfy` ok
|
||||
, specify "A basic arithmetic function should be able to be inferred" $
|
||||
|
|
@ -59,13 +59,13 @@ goods =
|
|||
, specify "Most simple inference possible" $
|
||||
run
|
||||
( D.do
|
||||
"id x = x ;"
|
||||
_id
|
||||
)
|
||||
`shouldSatisfy` ok
|
||||
, specify "Pattern matching on a nested list" $
|
||||
run
|
||||
( D.do
|
||||
list
|
||||
_list
|
||||
"main : List (List (a)) -> Int ;"
|
||||
"main xs = case xs of {"
|
||||
" Cons Nil _ => 1 ;"
|
||||
|
|
@ -73,6 +73,24 @@ goods =
|
|||
"};"
|
||||
)
|
||||
`shouldSatisfy` ok
|
||||
, specify "List of function Int -> Int functions should be inferred corretly" $
|
||||
run
|
||||
( D.do
|
||||
_list
|
||||
"main xs = case xs of {"
|
||||
" Cons f _ => f 1 ;"
|
||||
" Nil => 0 ;"
|
||||
" };"
|
||||
)
|
||||
`shouldBe` run
|
||||
( D.do
|
||||
_list
|
||||
"main : List (Int -> Int) -> Int ;"
|
||||
"main xs = case xs of {"
|
||||
" Cons f _ => f 1 ;"
|
||||
" Nil => 0 ;"
|
||||
" };"
|
||||
)
|
||||
]
|
||||
|
||||
bads =
|
||||
|
|
@ -85,7 +103,7 @@ bads =
|
|||
, specify "Pattern matching using different types should not succeed" $
|
||||
run
|
||||
( D.do
|
||||
list
|
||||
_list
|
||||
"bad xs = case xs of {"
|
||||
" 1 => 0 ;"
|
||||
" Nil => 0 ;"
|
||||
|
|
@ -95,7 +113,7 @@ bads =
|
|||
, specify "Using a concrete function (data type) on a skolem variable should not succeed" $
|
||||
run
|
||||
( D.do
|
||||
bool
|
||||
_bool
|
||||
_not
|
||||
"f : a -> Bool () ;"
|
||||
"f x = not x ;"
|
||||
|
|
@ -113,21 +131,32 @@ bads =
|
|||
, specify "A function without signature used in an incompatible context should not succeed" $
|
||||
run
|
||||
( D.do
|
||||
"main = id 1 2 ;"
|
||||
"id x = x ;"
|
||||
"main = _id 1 2 ;"
|
||||
"_id x = x ;"
|
||||
)
|
||||
`shouldSatisfy` bad
|
||||
, specify "Pattern matching on literal and list should not succeed" $
|
||||
, specify "Pattern matching on literal and _list should not succeed" $
|
||||
run
|
||||
( D.do
|
||||
list
|
||||
_list
|
||||
"length : List (c) -> Int;"
|
||||
"length list = case list of {"
|
||||
"length _list = case _list of {"
|
||||
" 0 => 0;"
|
||||
" Cons x xs => 1 + length xs;"
|
||||
"};"
|
||||
)
|
||||
`shouldSatisfy` bad
|
||||
, specify "List of function Int -> Int functions should not be usable on Char" $
|
||||
run
|
||||
( D.do
|
||||
_list
|
||||
"main : List (Int -> Int) -> Int ;"
|
||||
"main xs = case xs of {"
|
||||
" Cons f _ => f 'a' ;"
|
||||
" Nil => 0 ;"
|
||||
" };"
|
||||
)
|
||||
`shouldSatisfy` bad
|
||||
]
|
||||
|
||||
run = typecheck <=< pProgram . myLexer
|
||||
|
|
@ -139,26 +168,26 @@ bad = not . ok
|
|||
|
||||
-- FUNCTIONS
|
||||
|
||||
const = D.do
|
||||
_const = D.do
|
||||
"const : a -> b -> a ;"
|
||||
"const x y = x ;"
|
||||
list = D.do
|
||||
_list = D.do
|
||||
"data List (a) where"
|
||||
" {"
|
||||
" Nil : List (a)"
|
||||
" Cons : a -> List (a) -> List (a)"
|
||||
" };"
|
||||
|
||||
headSig = D.do
|
||||
_headSig = D.do
|
||||
"head : List (a) -> a ;"
|
||||
|
||||
head = D.do
|
||||
_head = D.do
|
||||
"head xs = "
|
||||
" case xs of {"
|
||||
" Cons x xs => x ;"
|
||||
" };"
|
||||
|
||||
bool = D.do
|
||||
_bool = D.do
|
||||
"data Bool () where {"
|
||||
" True : Bool ()"
|
||||
" False : Bool ()"
|
||||
|
|
@ -170,3 +199,4 @@ _not = D.do
|
|||
" True => False ;"
|
||||
" False => True ;"
|
||||
"};"
|
||||
_id = "id x = x ;"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue