diff --git a/src/TypeChecker/TypeCheckerHm.hs b/src/TypeChecker/TypeCheckerHm.hs index 9fe62a4..1fc0ee4 100644 --- a/src/TypeChecker/TypeCheckerHm.hs +++ b/src/TypeChecker/TypeCheckerHm.hs @@ -47,6 +47,67 @@ typecheck = onLeft msg . run . checkPrg onLeft f (Left x) = Left $ f x onLeft _ (Right x) = Right x +checkData :: Data -> Infer () +checkData err@(Data typ injs) = do + (name, tvars) <- go typ + dataErr (mapM_ (\i -> typecheckInj i name tvars) injs) err + where + go = \case + TData name typs + | Right tvars' <- mapM toTVar typs -> + pure (name, tvars') + TAll _ _ -> uncatchableErr "Explicit foralls not allowed, for now" + _ -> uncatchableErr $ unwords ["Bad data type definition: ", printTree typ] + +typecheckInj :: Inj -> UIdent -> [TVar] -> Infer () +typecheckInj (Inj c inj_typ) name tvars + | Right False <- boundTVars tvars inj_typ = + catchableErr "Unbound type variables" + | TData name' typs <- returnType inj_typ + , Right tvars' <- mapM toTVar typs + , name' == name + , tvars' == tvars = do + exist <- existInj (coerce c) + case exist of + Just t -> uncatchableErr $ Aux.do + "Constructor" + quote $ coerce name + "with type" + quote $ printTree t + "already exist" + Nothing -> insertInj (coerce c) inj_typ + | otherwise = + uncatchableErr $ + unwords + [ "Bad type constructor: " + , show name + , "\nExpected: " + , printTree . TData name $ map TVar tvars + , "\nActual: " + , printTree $ returnType inj_typ + ] + where + boundTVars :: [TVar] -> Type -> Either Error Bool + boundTVars tvars' = \case + TAll{} -> uncatchableErr "Explicit foralls not allowed, for now" + TFun t1 t2 -> do + t1' <- boundTVars tvars t1 + t2' <- boundTVars tvars t2 + return $ t1' && t2' + TVar tvar -> return $ tvar `elem` tvars' + TData _ typs -> and <$> mapM (boundTVars tvars) typs + TLit _ -> return True + TEVar _ -> error "TEVar in data type declaration" + +toTVar :: Type -> Either Error TVar +toTVar = \case + TVar tvar -> pure tvar + _ -> uncatchableErr "Not a type variable" + +returnType :: Type -> Type +returnType (TFun _ t2) = returnType t2 +returnType a = a + checkPrg :: Program -> Infer (T.Program' Type) checkPrg (Program bs) = do preRun bs @@ -111,66 +172,39 @@ checkBind (Bind name args e) = do insertSig (coerce name) (Just lambda_t) return (T.Bind (coerce name, lambda_t) [] (e, lambda_t)) -checkData :: Data -> Infer () -checkData err@(Data typ injs) = do - (name, tvars) <- go typ - dataErr (mapM_ (\i -> checkInj i name tvars) injs) err - where - go = \case - TData name typs - | Right tvars' <- mapM toTVar typs -> - pure (name, tvars') - TAll _ _ -> uncatchableErr "Explicit foralls not allowed, for now" - _ -> uncatchableErr $ unwords ["Bad data type definition: ", printTree typ] +typeEq :: Type -> Type -> Bool +typeEq (TFun l r) (TFun l' r') = typeEq l l' && typeEq r r' +typeEq (TLit a) (TLit b) = a == b +typeEq (TData name a) (TData name' b) = + length a == length b + && name == name' + && and (zipWith typeEq a b) +typeEq (TAll _ t1) t2 = t1 `typeEq` t2 +typeEq t1 (TAll _ t2) = t1 `typeEq` t2 +typeEq (TVar _) (TVar _) = True +typeEq _ _ = False -checkInj :: Inj -> UIdent -> [TVar] -> Infer () -checkInj (Inj c inj_typ) name tvars - | Right False <- boundTVars tvars inj_typ = - catchableErr "Unbound type variables" - | TData name' typs <- returnType inj_typ - , Right tvars' <- mapM toTVar typs - , name' == name - , tvars' == tvars = do - exist <- existInj (coerce c) - case exist of - Just t -> uncatchableErr $ Aux.do - "Constructor" - quote $ coerce name - "with type" - quote $ printTree t - "already exist" - Nothing -> insertInj (coerce c) inj_typ - | otherwise = - uncatchableErr $ - unwords - [ "Bad type constructor: " - , show name - , "\nExpected: " - , printTree . TData name $ map TVar tvars - , "\nActual: " - , printTree $ returnType inj_typ - ] - where - boundTVars :: [TVar] -> Type -> Either Error Bool - boundTVars tvars' = \case - TAll{} -> uncatchableErr "Explicit foralls not allowed, for now" - TFun t1 t2 -> do - t1' <- boundTVars tvars t1 - t2' <- boundTVars tvars t2 - return $ t1' && t2' - TVar tvar -> return $ tvar `elem` tvars' - TData _ typs -> and <$> mapM (boundTVars tvars) typs - TLit _ -> return True - TEVar _ -> error "TEVar in data type declaration" +skolemize :: Type -> Type +skolemize (TVar (MkTVar a)) = TEVar (MkTEVar $ coerce a) +skolemize (TAll x t) = TAll x (skolemize t) +skolemize (TFun t1 t2) = (TFun `on` skolemize) t1 t2 +skolemize t = t -toTVar :: Type -> Either Error TVar -toTVar = \case - TVar tvar -> pure tvar - _ -> uncatchableErr "Not a type variable" +isMoreSpecificOrEq :: Type -> Type -> Bool +isMoreSpecificOrEq t1 (TAll _ t2) = isMoreSpecificOrEq t1 t2 +isMoreSpecificOrEq (TFun a b) (TFun c d) = + isMoreSpecificOrEq a c && isMoreSpecificOrEq b d +isMoreSpecificOrEq (TData n1 ts1) (TData n2 ts2) = + n1 == n2 + && length ts1 == length ts2 + && and (zipWith isMoreSpecificOrEq ts1 ts2) +isMoreSpecificOrEq _ (TVar _) = True +isMoreSpecificOrEq a b = a == b -returnType :: Type -> Type -returnType (TFun _ t2) = returnType t2 -returnType a = a +isPoly :: Type -> Bool +isPoly (TAll _ _) = True +isPoly (TVar _) = True +isPoly _ = False inferExp :: Exp -> Infer (T.ExpT' Type) inferExp e = do @@ -690,40 +724,6 @@ unzip4 = ) ([], [], [], []) --- typeEq :: Type -> Type -> Bool --- typeEq (TFun l r) (TFun l' r') = typeEq l l' && typeEq r r' --- typeEq (TLit a) (TLit b) = a == b --- typeEq (TData name a) (TData name' b) = --- length a == length b --- && name == name' --- && and (zipWith typeEq a b) --- typeEq (TAll _ t1) t2 = t1 `typeEq` t2 --- typeEq t1 (TAll _ t2) = t1 `typeEq` t2 --- typeEq (TVar _) (TVar _) = True --- typeEq _ _ = False - --- skolemize :: Type -> Type --- skolemize (TVar (MkTVar a)) = TEVar (MkTEVar $ coerce a) --- skolemize (TAll x t) = TAll x (skolemize t) --- skolemize (TFun t1 t2) = (TFun `on` skolemize) t1 t2 --- skolemize t = t - --- isMoreSpecificOrEq :: Type -> Type -> Bool --- isMoreSpecificOrEq t1 (TAll _ t2) = isMoreSpecificOrEq t1 t2 --- isMoreSpecificOrEq (TFun a b) (TFun c d) = --- isMoreSpecificOrEq a c && isMoreSpecificOrEq b d --- isMoreSpecificOrEq (TData n1 ts1) (TData n2 ts2) = --- n1 == n2 --- && length ts1 == length ts2 --- && and (zipWith isMoreSpecificOrEq ts1 ts2) --- isMoreSpecificOrEq _ (TVar _) = True --- isMoreSpecificOrEq a b = a == b - --- isPoly :: Type -> Bool --- isPoly (TAll _ _) = True --- isPoly (TVar _) = True --- isPoly _ = False - newtype Ctx = Ctx {vars :: Map T.Ident Type} deriving (Show)