unit tests, started on pattern matching

This commit is contained in:
sebastianselander 2023-02-28 17:15:48 +01:00
parent d23d417ff3
commit 05313652f9
9 changed files with 212 additions and 133 deletions

View file

@ -1,10 +1,7 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use traverse_" #-}
{-# OPTIONS_GHC -Wno-overlapping-patterns #-}
{-# HLINT ignore "Use zipWithM" #-}
-- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeChecker where
import Control.Monad.Except
@ -21,23 +18,9 @@ import Data.Foldable (traverse_)
import Grammar.Abs
import Grammar.Print (printTree)
import qualified TypeChecker.TypeCheckerIr as T
import TypeChecker.TypeCheckerIr (Ctx (..), Env (..), Error, Infer,
Poly (..), Subst)
-- | A data type representing type variables
data Poly = Forall [Ident] Type
deriving Show
newtype Ctx = Ctx { vars :: Map Ident Poly
}
data Env = Env { count :: Int
, sigs :: Map Ident Type
, dtypes :: Map Ident Type
}
type Error = String
type Subst = Map Ident Type
type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))
initCtx = Ctx mempty
initEnv = Env 0 mempty mempty
@ -98,7 +81,11 @@ checkBind (Bind n t _ args e) = do
(t', e') <- inferExp $ makeLambda e (reverse args)
s <- unify t t'
let t'' = apply s t
unless (t `typeEq` t'') (throwError $ unwords ["Top level signature", printTree t, "does not match body with type:", printTree t''])
unless (t `typeEq` t'') (throwError $ unwords ["Top level signature"
, printTree t
, "does not match body with inferred type:"
, printTree t''
])
return $ T.Bind (n, t) [] e'
where
makeLambda :: Exp -> [Ident] -> Exp
@ -109,12 +96,17 @@ checkBind (Bind n t _ args e) = do
typeEq :: Type -> Type -> Bool
typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r'
typeEq (TMono a) (TMono b) = a == b
typeEq (TConstr name a) (TConstr name' b) = if length a == length b
then name == name' && and (zipWith typeEq a b)
else False
typeEq (TConstr name a) (TConstr name' b) = length a == length b
&& name == name'
&& and (zipWith typeEq a b)
typeEq (TPol _) (TPol _) = True
typeEq _ _ = False
isMoreGeneral :: Type -> Type -> Bool
isMoreGeneral _ (TPol _) = True
isMoreGeneral (TArr a b) (TArr c d) = isMoreGeneral a c && isMoreGeneral b d
isMoreGeneral a b = a == b
inferExp :: Exp -> Infer (Type, T.Exp)
inferExp e = do
(s, t, e') <- algoW e
@ -123,24 +115,30 @@ inferExp e = do
replace :: Type -> T.Exp -> T.Exp
replace t = \case
T.ELit _ e -> T.ELit t e
T.EId (n, _) -> T.EId (n, t)
T.EAbs _ name e -> T.EAbs t name e
T.EApp _ e1 e2 -> T.EApp t e1 e2
T.EAdd _ e1 e2 -> T.EAdd t e1 e2
T.ELit _ e -> T.ELit t e
T.EId (n, _) -> T.EId (n, t)
T.EAbs _ name e -> T.EAbs t name e
T.EApp _ e1 e2 -> T.EApp t e1 e2
T.EAdd _ e1 e2 -> T.EAdd t e1 e2
T.ELet (T.Bind (n, _) args e1) e2 -> T.ELet (T.Bind (n, t) args e1) e2
algoW :: Exp -> Infer (Subst, Type, T.Exp)
algoW = \case
-- | TODO: Reason more about this one. Could be wrong
EAnn e t -> do
(s1, t', e') <- algoW e
unless (t `isMoreGeneral` t') (throwError $ unwords
["Annotated type:"
, printTree t
, "does not match inferred type:"
, printTree t' ])
applySt s1 $ do
s2 <- unify (apply s1 t) t'
s2 <- unify t t'
return (s2 `compose` s1, t, e')
-- | ------------------
-- | Γ ⊢ e₀ : Int, ∅
-- | Γ ⊢ i : Int, ∅
ELit (LInt n) -> return (nullSubst, TMono "Int", T.ELit (TMono "Int") (LInt n))
@ -159,7 +157,7 @@ algoW = \case
case M.lookup i sig of
Just t -> return (nullSubst, t, T.EId (i, t))
Nothing -> do
constr <- gets dtypes
constr <- gets constructors
case M.lookup i constr of
Just t -> return (nullSubst, t, T.EId (i, t))
Nothing -> throwError $ "Unbound variable: " ++ show i
@ -220,9 +218,9 @@ algoW = \case
(s2, t2, e1') <- algoW e1
return (s2 `compose` s1, t2, T.ELet (T.Bind (name,t2) [] e0') e1' )
ECase a b -> error $ "NOT IMPLEMENTED YET: ECase" ++ show a ++ " " ++ show b
ECase e0 injs -> undefined
-- | Unify two types producing a new substitution (constraint)
-- | Unify two types producing a new substitution
unify :: Type -> Type -> Infer Subst
unify t0 t1 = case (t0, t1) of
(TArr a b, TArr c d) -> do
@ -235,9 +233,15 @@ unify t0 t1 = case (t0, t1) of
-- | TODO: Figure out a cleaner way to express the same thing
(TConstr name t, TConstr name' t') -> if name == name' && length t == length t'
then do
xs <- sequence $ zipWith unify t t'
xs <- zipWithM unify t t'
return $ foldr compose nullSubst xs
else throwError $ unwords ["Type constructor:", printTree name, "(" ++ printTree t ++ ")", "does not match with:", printTree name', "(" ++ printTree t' ++ ")"]
else throwError $ unwords
["Type constructor:"
, printTree name
, "(" ++ printTree t ++ ")"
, "does not match with:"
, printTree name'
, "(" ++ printTree t' ++ ")"]
(a, b) -> throwError . unwords $ ["Type:", printTree a, "can't be unified with:", printTree b]
-- | Check if a type is contained in another type.
@ -324,4 +328,4 @@ insertSig i t = modify (\st -> st { sigs = M.insert i t (sigs st) })
-- | Insert a constructor with its data type
insertConstr :: Ident -> Type -> Infer ()
insertConstr i t = modify (\st -> st { dtypes = M.insert i t (dtypes st) })
insertConstr i t = modify (\st -> st { constructors = M.insert i t (constructors st) })