diff --git a/src/TypeChecker/HM.hs b/src/TypeChecker/HM.hs index 8671d1b..63072d1 100644 --- a/src/TypeChecker/HM.hs +++ b/src/TypeChecker/HM.hs @@ -2,6 +2,7 @@ {-# LANGUAGE OverloadedStrings #-} {-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} {-# HLINT ignore "Use traverse_" #-} +{-# LANGUAGE FlexibleInstances #-} module TypeChecker.HM where @@ -25,14 +26,17 @@ data Ctx = Ctx { constr :: Map Type Type , frsh :: Char } deriving Show -run :: Infer a -> Either String (a, Ctx) -run = runIdentity . runExceptT . flip runStateT initC +runC :: Ctx -> Infer a -> Either String (a, Ctx) +runC c = runIdentity . runExceptT . flip runStateT c + +run :: Infer a -> Either String a +run = runIdentity . runExceptT . flip evalStateT initC initC :: Ctx initC = Ctx M.empty M.empty M.empty 'a' typecheck :: Program -> Either Error T.Program -typecheck = undefined . run . inferPrg +typecheck = run . inferPrg inferPrg :: Program -> Infer T.Program inferPrg (Program bs) = do @@ -42,39 +46,54 @@ inferPrg (Program bs) = do inferBind :: Bind -> Infer T.Bind inferBind (Bind i t _ params rhs) = do - (t',e') <- inferExp (makeLambda (reverse params) rhs) - addConstraint t t' - -- when (t /= t') (throwError $ "Signature of function" ++ printTree i ++ "does not match inferred type of expression: " ++ printTree e') + (t',e') <- inferExp (makeLambda rhs (reverse params)) + when (t /= t') (throwError $ "Signature of function " ++ show i ++ " with type: " ++ show t ++ " does not match inferred type " ++ show t' ++ " of expression: " ++ show e') return $ T.Bind (t,i) [] e' -makeLambda :: [Ident] -> Exp -> Exp -makeLambda xs e = foldl (flip EAbs) e xs +makeLambda :: Exp -> [Ident] -> Exp +makeLambda = foldl (flip EAbs) inferExp :: Exp -> Infer (Type, T.Exp) -inferExp = \case - EAnn e t -> do - (t',e') <- inferExp e - when (t' /= t) (throwError "Annotated type and inferred type don't match") - return (t', e') - EInt i -> return (int, T.EInt int i) - EId i -> (\t -> (t, T.EId t i)) <$> lookupVar i - EAdd e1 e2 -> do - insertSig "+" (TArr int (TArr int int)) - inferExp (EApp (EApp (EId "+") e1) e2) - EApp e1 e2 -> do - (t1, e1') <- inferExp e1 - (t2, e2') <- inferExp e2 - fr <- fresh - addConstraint t1 (TArr t2 fr) - return (fr, T.EApp fr e1' e2') - EAbs name e -> do - fr <- fresh - insertVar name fr - (ret_t,e') <- inferExp e - t <- solveConstraints (TArr fr ret_t) - return (t, T.EAbs t name e') - ELet name e1 e2 -> error "Let expression not implemented yet" +inferExp e = do + (t, e') <- inferExp' e + t'' <- solveConstraints t + return (t'', replaceType t'' e') + where + inferExp' :: Exp -> Infer (Type, T.Exp) + inferExp' = \case + EAnn e t -> do + (t',e') <- inferExp' e + t'' <- solveConstraints t' + when (t'' /= t) (throwError "Annotated type and inferred type don't match") + return (t', e') + EInt i -> return (int, T.EInt int i) + EId i -> (\t -> (t, T.EId t i)) <$> lookupVar i + EAdd e1 e2 -> do + insertSig "+" (TArr int (TArr int int)) + inferExp' (EApp (EApp (EId "+") e1) e2) + EApp e1 e2 -> do + (t1, e1') <- inferExp' e1 + (t2, e2') <- inferExp' e2 + fr <- fresh + addConstraint t1 (TArr t2 fr) + return (fr, T.EApp fr e1' e2') + EAbs name e -> do + fr <- fresh + insertVar name fr + (ret_t,e') <- inferExp' e + t <- solveConstraints (TArr fr ret_t) + return (t, T.EAbs t name e') + ELet name e1 e2 -> error "Let expression not implemented yet" + +replaceType :: Type -> T.Exp -> T.Exp +replaceType t = \case + T.EInt _ i -> T.EInt t i + T.EId _ i -> T.EId t i + T.EAdd _ e1 e2 -> T.EAdd t e1 e2 + T.EApp _ e1 e2 -> T.EApp t e1 e2 + T.EAbs _ name e -> T.EAbs t name e + T.ELet _ name e1 e2 -> T.ELet t name e1 e2 isInt :: Type -> Bool isInt (TMono "Int") = True @@ -95,25 +114,24 @@ insertVar s t = modify ( \st -> st { vars = M.insert s t (vars st) } ) insertSig :: Ident -> Type -> Infer () insertSig s t = modify ( \st -> st { sigs = M.insert s t (sigs st) } ) - +-- | Generate a new fresh variable and increment the state fresh :: Infer Type fresh = do chr <- gets frsh modify (\st -> st { frsh = succ chr }) return $ TPol (Ident [chr]) --- Constraint solving is wrong. (\x. x) 3 is inferred with the type 'a' - +-- | Adds a constraint to the constraint set. +-- i.e: a = int -> b +-- b = int +-- thus when solving constraints it must be the case that +-- a = int -> int +-- addConstraint :: Type -> Type -> Infer () addConstraint t1 t2 = do - when (t2 `contains` t1) (throwError $ "Can't match type " ++ printTree t1 ++ " with " ++ printTree t2) modify (\st -> st { constr = M.insert t1 t2 (constr st) }) -contains :: Type -> Type -> Bool -contains (TArr t1 t2) b = t1 `contains` b || t2 `contains` b -contains (TMono a) (TMono b) = False -contains a b = a == b - +-- | Given a type, solve the constraints and figure out the type that should be assigned to it. solveConstraints :: Type -> Infer Type solveConstraints t = do c <- gets constr @@ -122,12 +140,15 @@ solveConstraints t = do modify (\st -> st { constr = M.fromList xs }) return $ subst t xs +-- | Substitute subst :: Type -> [(Type, Type)] -> Type subst t [] = t subst (TArr t1 t2) (x:xs) = subst (TArr (replace x t1) (replace x t2)) xs subst t (x:xs) = subst (replace x t) xs --- Annoying fucking bug here +-- | Given a set of constraints run the replacement on all of them, producing a new set of +-- replacements. +-- https://youtu.be/trmq3wYcUxU - good video for explanation solveAll :: [(Type, Type)] -> Infer [(Type, Type)] solveAll [] = return [] solveAll (x:xs) = case x of @@ -136,12 +157,14 @@ solveAll (x:xs) = case x of (a, TArr t1 t2) -> fmap ((a, TArr t1 t2) :) $ solveAll $ solve (a, TArr t1 t2) xs (TMono a, TPol b) -> fmap ((TPol b, TMono a) :) $ solveAll $ solve (TPol b, TMono a) xs (TPol a, TMono b) -> fmap ((TPol a, TMono b) :) $ solveAll $ solve (TPol a, TMono b) xs - (TMono a, TMono b) -> if a == b then solveAll xs else throwError "Can't unify types" (TPol a, TPol b) -> fmap ((TPol a, TPol b) :) $ solveAll $ solve (TPol a, TPol b) xs + (TMono a, TMono b) -> if a == b then solveAll xs else throwError "Can't unify types" solve :: (Type, Type) -> [(Type, Type)] -> [(Type, Type)] solve x = map (both (replace x)) +-- | Given a constraint (type, type) and a type, if the constraint matches the input +-- replace with the constrained type replace :: (Type, Type) -> Type -> Type replace a (TArr t1 t2) = TArr (replace a t1) (replace a t2) replace (a,b) c = if a==c then b else c @@ -150,22 +173,3 @@ both :: (a -> b) -> (a,a) -> (b,b) both f = bimap f f int = TMono "Int" -a = TPol "a" -b = TPol "b" -c = TPol "c" -d = TPol "d" -e = TPol "e" -arr = TArr - -set = [(a, arr d e), (c, arr int d), (arr int (arr int int), arr b c)] - -prg = EAbs "f" (EAbs "x" (EApp (EId "f") (EAdd (EId "x") (EInt 1)))) - -bug = EApp (EAbs "x" (EAdd (EAnn (EId "x") a) (EInt 3))) (EInt 2) - --- (\x. \y. x + y + 1) -prg2 = EApp (EAbs "x" (EId "x")) (EInt 1) - --- --- Known bugs --- (x : a) + 3 type checks diff --git a/test_program b/test_program index e342096..6d38647 100644 --- a/test_program +++ b/test_program @@ -1,5 +1,2 @@ -id : Mono Int -> Mono Int ; -id = \x. x ; - -main : Poly a ; -main = id 3 ; +fun : Mono Int -> Mono Int ; +fun = \x. x ;