diff --git a/src/TypeChecker/Bugs.md b/src/TypeChecker/Bugs.md index 8dad339..fb986a5 100644 --- a/src/TypeChecker/Bugs.md +++ b/src/TypeChecker/Bugs.md @@ -27,38 +27,12 @@ Program below should not type check main : a -> b ; main x = x; ``` +## Pattern match on functions + +Program below should not type check -## Bugged error message ```hs -data Maybe () where { - Nothing : Maybe - Just : Int -> Maybe - }; - -fmap : (Int -> Int) -> Maybe -> Maybe ; -fmap f ma = case ma of { - Nothing => Nothing ; - Just a => Just (f a) ; -}; - -pure : Int -> Maybe ; -pure x = Just x ; - -ap mf ma = case mf of { - Just f => case ma of { - Nothing => Nothing; - Just a => Just (f a); - }; - Nothing => Nothing; -}; - -return = pure; - -bind ma f = case ma of { - Nothing => Nothing ; - Just a => f a ; +main = case \x. x of { + _ => 0; }; ``` -``` -TYPECHECKER ERROR -Inferred type '("c" -> "Int") -> "Maybe" -> "Maybe" does not match specified type '("Int" -> "Int") -> "Maybe" -> "Maybe"' diff --git a/src/TypeChecker/TypeCheckerHm.hs b/src/TypeChecker/TypeCheckerHm.hs index 1fc0ee4..2edd1f2 100644 --- a/src/TypeChecker/TypeCheckerHm.hs +++ b/src/TypeChecker/TypeCheckerHm.hs @@ -12,7 +12,6 @@ import Control.Monad.Except import Control.Monad.Identity (Identity, runIdentity) import Control.Monad.Reader import Control.Monad.State -import Data.Bifunctor (second) import Data.Coerce (coerce) import Data.Function (on) import Data.List (foldl') @@ -30,16 +29,17 @@ initCtx = Ctx mempty initEnv = Env 0 'a' mempty mempty mempty run :: Infer a -> Either Error a -run = runC initEnv initCtx +run = run' initEnv initCtx -runC :: Env -> Ctx -> Infer a -> Either Error a -runC e c = +run' :: Env -> Ctx -> Infer a -> Either Error a +run' e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e . runInfer +-- | Type check a program typecheck :: Program -> Either String (T.Program' Type) typecheck = onLeft msg . run . checkPrg where @@ -47,20 +47,87 @@ typecheck = onLeft msg . run . checkPrg onLeft f (Left x) = Left $ f x onLeft _ (Right x) = Right x +checkPrg :: Program -> Infer (T.Program' Type) +checkPrg (Program bs) = do + preRun bs + bs' <- checkDef bs + return $ T.Program bs' + +preRun :: [Def] -> Infer () +preRun [] = return () +preRun (x : xs) = case x of + DSig (Sig n t) -> do + collect (collectTVars t) + gets (M.member (coerce n) . sigs) + >>= flip + when + ( uncatchableErr $ Aux.do + "Duplicate signatures for function" + quote $ printTree n + ) + insertSig (coerce n) (Just t) >> preRun xs + DBind (Bind n _ e) -> do + collect (collectTVars e) + s <- gets sigs + case M.lookup (coerce n) s of + Nothing -> insertSig (coerce n) Nothing >> preRun xs + Just _ -> preRun xs + DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs + +checkDef :: [Def] -> Infer [T.Def' Type] +checkDef [] = return [] +checkDef (x : xs) = case x of + (DBind b) -> do + b' <- checkBind b + xs' <- checkDef xs + return $ T.DBind b' : xs' + (DData d) -> do + xs' <- checkDef xs + return $ T.DData (coerceData d) : xs' + (DSig _) -> checkDef xs + where + coerceData (Data t injs) = + T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs + +checkBind :: Bind -> Infer (T.Bind' Type) +checkBind (Bind name args e) = do + let lambda = makeLambda e (reverse (coerce args)) + (sub0, (e, lambda_t)) <- inferExp lambda + s <- gets sigs + case M.lookup (coerce name) s of + Just (Just t') -> do + let fsig = apply sub0 t' + sub1 <- liftEither $ runIdentity $ runExceptT $ execStateT (typeEq fsig lambda_t) mempty + sub2 <- liftEither $ runIdentity $ runExceptT $ execStateT (typeEq lambda_t fsig) mempty + unless + (lambda_t == apply sub1 fsig && apply sub2 lambda_t == fsig) + ( uncatchableErr $ Aux.do + "Inferred type" + quote $ printTree lambda_t + "does not match specified type" + quote $ printTree t' + ) + return $ T.Bind (coerce name, lambda_t) [] (e, lambda_t) + _ -> do + insertSig (coerce name) (Just lambda_t) + return (T.Bind (coerce name, lambda_t) [] (e, lambda_t)) + checkData :: Data -> Infer () checkData err@(Data typ injs) = do (name, tvars) <- go typ - dataErr (mapM_ (\i -> typecheckInj i name tvars) injs) err + dataErr (mapM_ (\i -> checkInj i name tvars) injs) err where go = \case TData name typs | Right tvars' <- mapM toTVar typs -> pure (name, tvars') TAll _ _ -> uncatchableErr "Explicit foralls not allowed, for now" - _ -> uncatchableErr $ unwords ["Bad data type definition: ", printTree typ] + _ -> + uncatchableErr $ + unwords ["Bad data type definition: ", printTree typ] -typecheckInj :: Inj -> UIdent -> [TVar] -> Infer () -typecheckInj (Inj c inj_typ) name tvars +checkInj :: Inj -> UIdent -> [TVar] -> Infer () +checkInj (Inj c inj_typ) name tvars | Right False <- boundTVars tvars inj_typ = catchableErr "Unbound type variables" | TData name' typs <- returnType inj_typ @@ -108,109 +175,11 @@ returnType :: Type -> Type returnType (TFun _ t2) = returnType t2 returnType a = a -checkPrg :: Program -> Infer (T.Program' Type) -checkPrg (Program bs) = do - preRun bs - bs' <- checkDef bs - return $ T.Program bs' - -preRun :: [Def] -> Infer () -preRun [] = return () -preRun (x : xs) = case x of - DSig (Sig n t) -> do - collect (collectTVars t) - gets (M.member (coerce n) . sigs) - >>= flip - when - ( uncatchableErr $ Aux.do - "Duplicate signatures for function" - quote $ printTree n - ) - insertSig (coerce n) (Just t) >> preRun xs - DBind (Bind n _ e) -> do - collect (collectTVars e) - s <- gets sigs - case M.lookup (coerce n) s of - Nothing -> insertSig (coerce n) Nothing >> preRun xs - Just _ -> preRun xs - DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs - -checkDef :: [Def] -> Infer [T.Def' Type] -checkDef [] = return [] -checkDef (x : xs) = case x of - (DBind b) -> do - b' <- checkBind b - xs' <- checkDef xs - return $ T.DBind b' : xs' - (DData d) -> do - xs' <- checkDef xs - return $ T.DData (coerceData d) : xs' - (DSig _) -> checkDef xs - where - coerceData (Data t injs) = - T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs - -checkBind :: Bind -> Infer (T.Bind' Type) -checkBind (Bind name args e) = do - let lambda = makeLambda e (reverse (coerce args)) - (e, lambda_t) <- inferExp lambda - s <- gets sigs - case M.lookup (coerce name) s of - Just (Just t') -> do - sub1 <- unify lambda_t t' - sub2 <- unify t' lambda_t - unless - (apply sub1 lambda_t == t' && lambda_t == apply sub2 t') - ( uncatchableErr $ Aux.do - "Inferred type" - quote $ printTree lambda_t - "does not match specified type" - quote $ printTree t' - ) - return $ T.Bind (coerce name, t') [] (e, lambda_t) - _ -> do - insertSig (coerce name) (Just lambda_t) - return (T.Bind (coerce name, lambda_t) [] (e, lambda_t)) - -typeEq :: Type -> Type -> Bool -typeEq (TFun l r) (TFun l' r') = typeEq l l' && typeEq r r' -typeEq (TLit a) (TLit b) = a == b -typeEq (TData name a) (TData name' b) = - length a == length b - && name == name' - && and (zipWith typeEq a b) -typeEq (TAll _ t1) t2 = t1 `typeEq` t2 -typeEq t1 (TAll _ t2) = t1 `typeEq` t2 -typeEq (TVar _) (TVar _) = True -typeEq _ _ = False - -skolemize :: Type -> Type -skolemize (TVar (MkTVar a)) = TEVar (MkTEVar $ coerce a) -skolemize (TAll x t) = TAll x (skolemize t) -skolemize (TFun t1 t2) = (TFun `on` skolemize) t1 t2 -skolemize t = t - -isMoreSpecificOrEq :: Type -> Type -> Bool -isMoreSpecificOrEq t1 (TAll _ t2) = isMoreSpecificOrEq t1 t2 -isMoreSpecificOrEq (TFun a b) (TFun c d) = - isMoreSpecificOrEq a c && isMoreSpecificOrEq b d -isMoreSpecificOrEq (TData n1 ts1) (TData n2 ts2) = - n1 == n2 - && length ts1 == length ts2 - && and (zipWith isMoreSpecificOrEq ts1 ts2) -isMoreSpecificOrEq _ (TVar _) = True -isMoreSpecificOrEq a b = a == b - -isPoly :: Type -> Bool -isPoly (TAll _ _) = True -isPoly (TVar _) = True -isPoly _ = False - -inferExp :: Exp -> Infer (T.ExpT' Type) +inferExp :: Exp -> Infer (Subst, T.ExpT' Type) inferExp e = do (s, (e', t)) <- algoW e let subbed = apply s t - return $ second (const subbed) (e', t) + return (s, (e', subbed)) class CollectTVars a where collectTVars :: a -> Set T.Ident @@ -223,7 +192,8 @@ instance CollectTVars Type where collectTVars (TVar (MkTVar i)) = S.singleton (coerce i) collectTVars (TAll _ t) = collectTVars t collectTVars (TFun t1 t2) = (S.union `on` collectTVars) t1 t2 - collectTVars (TData _ ts) = foldl' (\acc x -> acc `S.union` collectTVars x) S.empty ts + collectTVars (TData _ ts) = + foldl' (\acc x -> acc `S.union` collectTVars x) S.empty ts collectTVars _ = S.empty collect :: Set T.Ident -> Infer () @@ -232,7 +202,7 @@ collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st}) algoW :: Exp -> Infer (Subst, T.ExpT' Type) algoW = \case err@(EAnn e t) -> do - (s1, (e', t')) <- exprErr (algoW e) err + (sub0, (e', t')) <- exprErr (algoW e) err sub1 <- unify t t' sub2 <- unify t' t unless @@ -243,8 +213,7 @@ algoW = \case "does not match inferred type" quote $ printTree t' ) - s2 <- exprErr (unify t t') err - let comp = s2 `compose` s1 + let comp = sub2 `compose` sub1 `compose` sub0 return (comp, apply comp (e', t)) -- \| ------------------ @@ -257,7 +226,9 @@ algoW = \case EVar i -> do var <- asks vars case M.lookup (coerce i) var of - Just t -> inst t >>= \x -> return (nullSubst, (T.EVar $ coerce i, x)) + Just t -> + inst t >>= \x -> + return (nullSubst, (T.EVar $ coerce i, x)) Nothing -> do sig <- gets sigs case M.lookup (coerce i) sig of @@ -266,7 +237,10 @@ algoW = \case fr <- fresh insertSig (coerce i) (Just fr) return (nullSubst, (T.EVar $ coerce i, fr)) - Nothing -> uncatchableErr $ "Unbound variable: " <> printTree i + Nothing -> + uncatchableErr $ + "Unbound variable: " + <> printTree i EInj i -> do constr <- gets injections case M.lookup (coerce i) constr of @@ -283,14 +257,11 @@ algoW = \case err@(EAbs name e) -> do fr <- fresh - exprErr - ( withBinding (coerce name) fr $ do - (s1, (e', t')) <- exprErr (algoW e) err - let varType = apply s1 fr - let newArr = TFun varType t' - return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr)) - ) - err + withBinding (coerce name) fr $ do + (s1, (e', t')) <- exprErr (algoW e) err + let varType = apply s1 fr + let newArr = TFun varType t' + return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr)) -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁ -- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int) @@ -338,29 +309,120 @@ algoW = \case (s2, (e1', t2)) <- algoW e1 let comp = s2 `compose` s1 return (comp, apply comp (T.ELet bind' (e1', t2), t2)) - - -- \| TODO: Add judgement ECase caseExpr injs -> do (sub, (e', t)) <- algoW caseExpr (subst, injs, ret_t) <- checkCase t injs let comp = subst `compose` sub return (comp, apply comp (T.ECase (e', t) injs, ret_t)) -makeLambda :: Exp -> [T.Ident] -> Exp -makeLambda = foldl (flip (EAbs . coerce)) +checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type) +checkCase _ [] = catchableErr "Atleast one case required" +checkCase expT brnchs = do + (subs, branchTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs + let sub0 = composeAll subs + (sub1, _) <- + foldM + ( \(sub, acc) x -> + (\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc + ) + (nullSubst, expT) + branchTs + (sub2, returns_type) <- + foldM + ( \(sub, acc) x -> + (\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc + ) + (nullSubst, head returns) + (tail returns) + let comp = sub2 `compose` sub1 `compose` sub0 + return (comp, apply comp injs, apply comp returns_type) + +inferBranch :: Branch -> Infer (Subst, Type, T.Branch' Type, Type) +inferBranch (Branch pat expr) = do + newPat@(pat, branchT) <- inferPattern pat + (sub, newExp@(_, exprT)) <- withPattern pat (algoW expr) + return + ( sub + , apply sub branchT + , T.Branch (apply sub newPat) (apply sub newExp) + , apply sub exprT + ) + +inferPattern :: Pattern -> Infer (T.Pattern' Type, Type) +inferPattern = \case + PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt) + PInj constr patterns -> do + t <- gets (M.lookup (coerce constr) . injections) + t <- + maybeToRightM + ( Error + ( Aux.do + "Constructor:" + quote $ printTree constr + "does not exist" + ) + True + ) + t + let numArgs = typeLength t - 1 + let (vs, ret) = fromJust (unsnoc $ flattenType t) + patterns <- mapM inferPattern patterns + unless + (length patterns == numArgs) + ( catchableErr $ Aux.do + "The constructor" + quote $ printTree constr + " should have " + show numArgs + " arguments but has been given " + show (length patterns) + ) + sub <- composeAll <$> zipWithM unify vs (map snd patterns) + return + ( T.PInj (coerce constr) (apply sub (map fst patterns)) + , apply sub ret + ) + PCatch -> (T.PCatch,) <$> fresh + PEnum p -> do + t <- gets (M.lookup (coerce p) . injections) + t <- + maybeToRightM + ( Error + ( Aux.do + "Constructor:" + quote $ printTree p + "does not exist" + ) + True + ) + t + unless + (typeLength t == 1) + ( catchableErr $ Aux.do + "The constructor" + quote $ printTree p + " should have " + show (typeLength t - 1) + " arguments but has been given 0" + ) + let (TData _data _ts) = t -- nasty nasty + frs <- mapM (const fresh) _ts + return (T.PEnum $ coerce p, TData _data frs) + PVar x -> do + fr <- fresh + let pvar = T.PVar (coerce x, fr) + return (pvar, fr) -- | Unify two types producing a new substitution unify :: Type -> Type -> Infer Subst -unify t0 t1 = do +unify t0 t1 = case (t0, t1) of (TFun a b, TFun c d) -> do s1 <- unify a c s2 <- unify (apply s1 b) (apply s1 d) return $ s1 `compose` s2 - ----------- TODO: BE CAREFUL!!!! THIS IS PROBABLY WRONG!!! ----------- (TVar (T.MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t (t@(TData _ _), TVar (T.MkTVar b)) -> return $ M.singleton (coerce b) t - ------------------------------------------------------------------- (TVar (T.MkTVar a), t) -> occurs (coerce a) t (t, TVar (T.MkTVar b)) -> occurs (coerce b) t (TAll _ t, b) -> unify t b @@ -422,7 +484,12 @@ occurs i t = ) else return $ M.singleton i t --- | Generalize a type over all free variables in the substitution set +{- | Generalize a type over all free variables in the substitution set + Used for let bindings to allow expression that do not type check in + equivalent lambda expressions: + Type checks: let f = \x. x in (f True, f 'a') + Does not type check: (\f. (f True, f 'a')) (\x. x) +-} generalize :: Map T.Ident Type -> Type -> Type generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t) where @@ -446,15 +513,27 @@ inst = \case TFun t1 t2 -> TFun <$> inst t1 <*> inst t2 rest -> return rest --- | Compose two substitution sets -compose :: Subst -> Subst -> Subst -compose m1 m2 = M.map (apply m1) m2 `M.union` m1 - -composeAll :: [Subst] -> Subst -composeAll = foldl' compose nullSubst - --- TODO: Split this class into two separate classes, one for free variables --- and one for applying substitutions +-- | Generate a new fresh variable +fresh :: Infer Type +fresh = do + c <- gets nextChar + n <- gets count + taken <- gets takenTypeVars + if c == 'z' + then do + modify (\st -> st{count = succ (count st), nextChar = 'a'}) + else modify (\st -> st{nextChar = next (nextChar st)}) + if coerce [c] `S.member` taken + then do + fresh + else + if n == 0 + then return . TVar . T.MkTVar $ LIdent [c] + else return . TVar . T.MkTVar . LIdent $ c : show n + where + next :: Char -> Char + next 'z' = 'a' + next a = succ a -- | A class for substitutions class SubstType t where @@ -468,7 +547,8 @@ class FreeVars t where instance FreeVars Type where free :: Type -> Set T.Ident free (TVar (T.MkTVar a)) = S.singleton (coerce a) - free (TAll (T.MkTVar bound) t) = S.singleton (coerce bound) `S.intersection` free t + free (TAll (T.MkTVar bound) t) = + S.singleton (coerce bound) `S.intersection` free t free (TLit _) = mempty free (TFun a b) = free a `S.union` free b free (TData _ a) = free a @@ -540,27 +620,19 @@ instance SubstType (T.Id' Type) where nullSubst :: Subst nullSubst = M.empty --- | Generate a new fresh variable and increment the state counter -fresh :: Infer Type -fresh = do - c <- gets nextChar - n <- gets count - taken <- gets takenTypeVars - if c == 'z' - then do - modify (\st -> st{count = succ (count st), nextChar = 'a'}) - else modify (\st -> st{nextChar = next (nextChar st)}) - if coerce [c] `S.member` taken - then do - fresh - else - if n == 0 - then return . TVar . T.MkTVar $ LIdent [c] - else return . TVar . T.MkTVar . LIdent $ c : show n - where - next :: Char -> Char - next 'z' = 'a' - next a = succ a +-- | Compose two substitution sets +compose :: Subst -> Subst -> Subst +compose m1 m2 = M.map (apply m1) m2 `M.union` m1 + +-- | Compose a list of substitution sets into one +composeAll :: [Subst] -> Subst +composeAll = foldl' compose nullSubst + +{- | Convert a function with arguments to its pointfree version +> makeLambda (add x y = x + y) = add = \x. \y. x + y +-} +makeLambda :: Exp -> [T.Ident] -> Exp +makeLambda = foldl (flip (EAbs . coerce)) -- | Run the monadic action with an additional binding withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> Type -> m a -> m a @@ -571,49 +643,8 @@ withBindings :: (Monad m, MonadReader Ctx m) => [(T.Ident, Type)] -> m a -> m a withBindings xs = local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs}) --- | Insert a function signature into the environment -insertSig :: T.Ident -> Maybe Type -> Infer () -insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)}) - --- | Insert a constructor with its data type -insertInj :: T.Ident -> Type -> Infer () -insertInj i t = - modify (\st -> st{injections = M.insert i t (injections st)}) - -existInj :: T.Ident -> Infer (Maybe Type) -existInj n = gets (M.lookup n . injections) - --------- PATTERN MATCHING --------- - -checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type) -checkCase _ [] = catchableErr "Atleast one case required" -checkCase expT brnchs = do - (subs, injTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs - let sub0 = composeAll subs - (sub1, _) <- - foldM - ( \(sub, acc) x -> - (\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc - ) - (nullSubst, expT) - injTs - (sub2, returns_type) <- - foldM - ( \(sub, acc) x -> - (\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc - ) - (nullSubst, head returns) - (tail returns) - let comp = sub2 `compose` sub1 `compose` sub0 - return (comp, apply comp injs, apply comp returns_type) - -inferBranch :: Branch -> Infer (Subst, Type, T.Branch' Type, Type) -inferBranch (Branch pat expr) = do - newPat@(pat, branchT) <- inferPattern pat - (sub, newExp@(_, exprT)) <- withPattern pat (algoW expr) - return (sub, apply sub branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT) - -withPattern :: T.Pattern' Type -> Infer a -> Infer a +-- | Run the monadic action with a pattern +withPattern :: (Monad m, MonadReader Ctx m) => T.Pattern' Type -> m a -> m a withPattern p ma = case p of T.PVar (x, t) -> withBinding x t ma T.PInj _ ps -> foldl' (flip withPattern) ma ps @@ -621,74 +652,27 @@ withPattern p ma = case p of T.PCatch -> ma T.PEnum _ -> ma -inferPattern :: Pattern -> Infer (T.Pattern' Type, Type) -inferPattern = \case - PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt) - PInj constr patterns -> do - t <- gets (M.lookup (coerce constr) . injections) - t <- - maybeToRightM - ( Error - ( Aux.do - "Constructor:" - quote $ printTree constr - "does not exist" - ) - True - ) - t - let numArgs = typeLength t - 1 - let (vs, ret) = fromJust (unsnoc $ flattenType t) - patterns <- mapM inferPattern patterns - unless - (length patterns == numArgs) - ( catchableErr $ Aux.do - "The constructor" - quote $ printTree constr - " should have " - show numArgs - " arguments but has been given " - show (length patterns) - ) - sub <- composeAll <$> zipWithM unify vs (map snd patterns) - return (T.PInj (coerce constr) (apply sub (map fst patterns)), apply sub ret) - PCatch -> (T.PCatch,) <$> fresh - PEnum p -> do - t <- gets (M.lookup (coerce p) . injections) - t <- - maybeToRightM - ( Error - ( Aux.do - "Constructor:" - quote $ printTree p - "does not exist" - ) - True - ) - t - unless - (typeLength t == 1) - ( catchableErr $ Aux.do - "The constructor" - quote $ printTree p - " should have " - show (typeLength t - 1) - " arguments but has been given 0" - ) - let (TData _data _ts) = t -- nasty nasty - frs <- mapM (const fresh) _ts - return (T.PEnum $ coerce p, TData _data frs) - PVar x -> do - fr <- fresh - let pvar = T.PVar (coerce x, fr) - return (pvar, fr) +-- | Insert a function signature into the environment +insertSig :: T.Ident -> Maybe Type -> Infer () +insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)}) + +-- | Insert a constructor into the start with its type +insertInj :: T.Ident -> Type -> Infer () +insertInj i t = + modify (\st -> st{injections = M.insert i t (injections st)}) + +{- | Check if an injection (constructor of data type) +with an equivalent name has been declared already +-} +existInj :: T.Ident -> Infer (Maybe Type) +existInj n = gets (M.lookup n . injections) flattenType :: Type -> [Type] flattenType (TFun a b) = flattenType a <> flattenType b flattenType a = [a] typeLength :: Type -> Int -typeLength (TFun a b) = typeLength a + typeLength b +typeLength (TFun _ b) = 1 + typeLength b typeLength _ = 1 litType :: Lit -> Type @@ -698,23 +682,63 @@ litType (LChar _) = char int = TLit "Int" char = TLit "Char" -partitionType :: - Int -> -- Number of parameters to apply - Type -> - ([Type], Type) -partitionType = go [] - where - go acc 0 t = (acc, t) - go acc i t = case t of - TAll tvar t' -> second (TAll tvar) $ go acc i t' - TFun t1 t2 -> go (acc <> [t1]) (i - 1) t2 - _ -> error "Number of parameters and type doesn't match" +typeEq :: Type -> Type -> StateT Subst (ExceptT Error Identity) () +typeEq (TVar (T.MkTVar a)) t@(TVar _) = do + st <- get + case M.lookup (coerce a) st of + Nothing -> put $ M.insert (coerce a) t st + Just t' -> unless (t == t') (catchableErr "TYPE MISMATCH") +typeEq (TFun l r) (TFun l' r') = typeEq l l' *> typeEq r r' +typeEq (TAll _ l) (TAll _ r) = typeEq l r +typeEq (TLit a) (TLit b) = unless (a == b) (catchableErr "TYPE MISMATCH") +typeEq (TData nameL tL) (TData nameR tR) = do + unless (nameL == nameR) (catchableErr "TYPE MISMATCH") + zipWithM_ typeEq tL tR +typeEq (TEVar _) (TEVar _) = catchableErr "TYPE MISMATCH" +typeEq _ _ = catchableErr "TYPE MISMATCH" -exprErr :: Infer a -> Exp -> Infer a -exprErr ma exp = catchError ma (\x -> if x.catchable then throwError (x{msg = x.msg <> " in expression: \n" <> printTree exp, catchable = False}) else throwError x) +{- | Catch an error if possible and add the given +expression as addition to the error message +-} +exprErr :: (Monad m, MonadError Error m) => m a -> Exp -> m a +exprErr ma exp = + catchError + ma + ( \x -> + if x.catchable + then + throwError + ( x + { msg = + x.msg + <> " in expression: \n" + <> printTree exp + , catchable = False + } + ) + else throwError x + ) +{- | Catch an error if possible and add the given +data as addition to the error message +-} dataErr :: Infer a -> Data -> Infer a -dataErr ma d = catchError ma (\x -> if x.catchable then throwError (x{msg = x.msg <> " in data: \n" <> printTree d}) else throwError (x{catchable = False})) +dataErr ma d = + catchError + ma + ( \x -> + if x.catchable + then + throwError + ( x + { msg = + x.msg + <> " in data: \n" + <> printTree d + } + ) + else throwError (x{catchable = False}) + ) unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d]) unzip4 = @@ -737,6 +761,7 @@ data Env = Env deriving (Show) data Error = Error {msg :: String, catchable :: Bool} + deriving (Show) type Subst = Map T.Ident Type newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a} diff --git a/tests/TestTypeCheckerHm.hs b/tests/TestTypeCheckerHm.hs index 5f600ed..bf51a29 100644 --- a/tests/TestTypeCheckerHm.hs +++ b/tests/TestTypeCheckerHm.hs @@ -187,6 +187,31 @@ bes = " Nil => 0 ;" " };" ) + , testBe + "length function on int list infers correct signature" + ( D.do + "data List () where {" + " Nil : List ()" + " Cons : Int -> List () -> List ()" + "};" + + "length xs = case xs of {" + " Nil => 0 ;" + " Cons _ xs => 1 + length xs ;" + "};" + ) + ( D.do + "data List () where {" + " Nil : List ()" + " Cons : Int -> List () -> List ()" + "};" + + "length : List () -> Int ;" + "length xs = case xs of {" + " Nil => 0 ;" + " Cons _ xs => 1 + length xs ;" + "};" + ) ] testSatisfy desc test satisfaction = specify desc $ run test `shouldSatisfy` satisfaction