pattern matching works? have to test more

This commit is contained in:
sebastianselander 2023-03-03 11:46:54 +01:00
parent 7656b46e3f
commit 03d7080396
5 changed files with 76 additions and 34 deletions

View file

@ -33,7 +33,7 @@ main' s = do
putStrLn "\n-- TypeChecker --" putStrLn "\n-- TypeChecker --"
typechecked <- fromTypeCheckerErr $ typecheck renamed typechecked <- fromTypeCheckerErr $ typecheck renamed
putStrLn $ printTree typechecked putStrLn $ show typechecked
-- putStrLn "\n-- Lambda Lifter --" -- putStrLn "\n-- Lambda Lifter --"
-- let lifted = lambdaLift typechecked -- let lifted = lambdaLift typechecked

View file

@ -73,7 +73,9 @@ renameExp old_names = \case
(new_names, e') <- renameExp old_names e (new_names, e') <- renameExp old_names e
pure (new_names, EAnn e' t) pure (new_names, EAnn e' t)
ECase _ _ -> error "ECase NOT IMPLEMENTED YET" ECase e injs -> do
(new_names, e') <- renameExp old_names e
pure (new_names, ECase e' injs)
-- | Create a new name and add it to name environment. -- | Create a new name and add it to name environment.
newName :: Names -> Ident -> Rn (Names, Ident) newName :: Names -> Ident -> Rn (Names, Ident)

View file

@ -1,5 +1,8 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -Wno-unused-matches #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use mapAndUnzipM" #-}
-- | A module for type checking and inference using algorithm W, Hindley-Milner -- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeChecker where module TypeChecker.TypeChecker where
@ -100,10 +103,12 @@ typeEq (TConstr (Constr name a)) (TConstr (Constr name' b)) = length a == length
typeEq (TPol _) (TPol _) = True typeEq (TPol _) (TPol _) = True
typeEq _ _ = False typeEq _ _ = False
isMoreGeneral :: Type -> Type -> Bool isMoreSpecificOrEq :: Type -> Type -> Bool
isMoreGeneral _ (TPol _) = True isMoreSpecificOrEq _ (TPol _) = True
isMoreGeneral (TArr a b) (TArr c d) = isMoreGeneral a c && isMoreGeneral b d isMoreSpecificOrEq (TArr a b) (TArr c d) = isMoreSpecificOrEq a c && isMoreSpecificOrEq b d
isMoreGeneral a b = a == b isMoreSpecificOrEq (TConstr (Constr n1 ts1)) (TConstr (Constr n2 ts2))
= n1 == n2 && length ts1 == length ts2 && and (zipWith isMoreSpecificOrEq ts1 ts2)
isMoreSpecificOrEq a b = a == b
isPoly :: Type -> Bool isPoly :: Type -> Bool
isPoly (TPol _) = True isPoly (TPol _) = True
@ -123,6 +128,7 @@ replace t = \case
T.EApp _ e1 e2 -> T.EApp t e1 e2 T.EApp _ e1 e2 -> T.EApp t e1 e2
T.EAdd _ e1 e2 -> T.EAdd t e1 e2 T.EAdd _ e1 e2 -> T.EAdd t e1 e2
T.ELet (T.Bind (n, _) e1) e2 -> T.ELet (T.Bind (n, t) e1) e2 T.ELet (T.Bind (n, _) e1) e2 -> T.ELet (T.Bind (n, t) e1) e2
T.ECase _ expr injs -> T.ECase t expr injs
algoW :: Exp -> Infer (Subst, Type, T.Exp) algoW :: Exp -> Infer (Subst, Type, T.Exp)
algoW = \case algoW = \case
@ -130,7 +136,7 @@ algoW = \case
-- | TODO: Reason more about this one. Could be wrong -- | TODO: Reason more about this one. Could be wrong
EAnn e t -> do EAnn e t -> do
(s1, t', e') <- algoW e (s1, t', e') <- algoW e
unless (t `isMoreGeneral` t') (throwError $ unwords unless (t `isMoreSpecificOrEq` t') (throwError $ unwords
["Annotated type:" ["Annotated type:"
, printTree t , printTree t
, "does not match inferred type:" , "does not match inferred type:"
@ -222,9 +228,14 @@ algoW = \case
ECase caseExpr injs -> do ECase caseExpr injs -> do
(s0, t0, e0') <- algoW caseExpr (s0, t0, e0') <- algoW caseExpr
injs' <- mapM (checkInj t0) injs (injs', ts) <- unzip <$> mapM (checkInj t0) injs
undefined case ts of
[] -> throwError "Case expression missing any matches"
ts -> do
unified <- zipWithM unify ts (tail ts)
let unified' = foldl' compose mempty unified
let typ = apply unified' (head ts)
return (unified', typ, T.ECase typ e0' injs')
-- | Unify two types producing a new substitution -- | Unify two types producing a new substitution
@ -340,18 +351,18 @@ insertConstr i t = modify (\st -> st { constructors = M.insert i t (constructors
-------- PATTERN MATCHING --------- -------- PATTERN MATCHING ---------
-- case expr of, the type of 'expr' is caseType -- "case expr of", the type of 'expr' is caseType
checkInj :: Type -> Inj -> Infer T.Inj checkInj :: Type -> Inj -> Infer (T.Inj, Type)
checkInj caseType (Inj it expr) = do checkInj caseType (Inj it expr) = do
(_, e') <- inferExp expr (args, t') <- initType caseType it
t' <- initType caseType it (s, t, e') <- local (\st -> st { vars = args }) (algoW expr)
return $ T.Inj (it, t') e' return (T.Inj (it, t') e', t)
initType :: Type -> Init -> Infer Type initType :: Type -> Init -> Infer (Map Ident Poly, Type)
initType expected = \case initType expected = \case
InitLit lit -> let returnType = litType lit InitLit lit -> let returnType = litType lit
in if expected == returnType in if expected == returnType
then return expected then return (mempty,expected)
else throwError $ unwords ["Inferred type", printTree returnType, "does not match expected type:", printTree expected] else throwError $ unwords ["Inferred type", printTree returnType, "does not match expected type:", printTree expected]
InitConstr c args -> do InitConstr c args -> do
st <- gets constructors st <- gets constructors
@ -360,14 +371,14 @@ initType expected = \case
Just t -> do Just t -> do
let flat = flattenType t let flat = flattenType t
let returnType = last flat let returnType = last flat
case (length (init flat) == length args, returnType == expected) of case (length (init flat) == length args, returnType `isMoreSpecificOrEq` expected) of
(True, True) -> return returnType (True, True) -> return (M.fromList $ zip args (map (Forall []) flat), expected)
(False, _) -> throwError $ "Can't partially match on the constructor: " ++ printTree c (False, _) -> throwError $ "Can't partially match on the constructor: " ++ printTree c
(_, False) -> throwError $ unwords ["Inferred type", printTree returnType, "does not match expected type:", printTree expected] (_, False) -> throwError $ unwords ["Inferred type", printTree returnType, "does not match expected type:", printTree expected]
-- Ignoring the variables for now, they can not be used in the expression to the -- Ignoring the variables for now, they can not be used in the expression to the
-- right of '=>' -- right of '=>'
InitCatch -> return expected InitCatch -> return (mempty, expected)
flattenType :: Type -> [Type] flattenType :: Type -> [Type]
flattenType (TArr a b) = flattenType a ++ flattenType b flattenType (TArr a b) = flattenType a ++ flattenType b

View file

@ -127,6 +127,17 @@ instance Print Exp where
, doc $ showString "." , doc $ showString "."
, prt 0 e , prt 0 e
] ]
ECase t exp injs -> prPrec i 0 (concatD [doc (showString "case"), prt 0 exp, doc (showString "of"), doc (showString "{"), prt 0 injs, doc (showString "}"), doc (showString ":"), prt 0 t])
instance Print Inj where
prt i = \case
Inj (init,t) exp -> prPrec i 0 (concatD [prt 0 init, doc (showString ":"), prt 0 t, doc (showString "=>"), prt 0 exp])
instance Print [Inj] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]

View file

@ -1,12 +1,30 @@
-- data Bool () where {
-- True : Bool ()
-- False : Bool ()
-- };
--
-- main : _Int ;
-- main = case True of {
-- False => 0 ;
-- True => 1
-- };
data List ('a) where { data List ('a) where {
Nil : List ('a) Nil : List ('a)
Cons : 'a -> List ('a) -> List ('a) Cons : ('a) -> List ('a) -> List ('a)
}; };
data Bool () where { data Maybe ('a) where {
True : Bool () Nothing : Maybe ('a)
False : Bool () Just : 'a -> Maybe ('a)
}; };
main : List (_Int) ; safeHead : List ('a) -> Maybe ('a) ;
main = Cons 1 (Cons 0 Nil) ; safeHead xs =
case xs of {
Nil => Nothing ;
Cons x xs => Just x
};
main : Maybe (_Int) ;
main = safeHead (Cons 0 (Cons 1 Nil)) ;