181 lines
6.3 KiB
Haskell
181 lines
6.3 KiB
Haskell
{-# LANGUAGE LambdaCase #-}
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
|
|
{-# HLINT ignore "Use traverse_" #-}
|
|
{-# LANGUAGE FlexibleInstances #-}
|
|
|
|
module TypeChecker.HM where
|
|
|
|
import Control.Monad.Except
|
|
import Control.Monad.State
|
|
import Data.Bifunctor (bimap, second)
|
|
import Data.Functor.Identity (Identity, runIdentity)
|
|
import Data.Map (Map)
|
|
import qualified Data.Map as M
|
|
|
|
import Grammar.Abs
|
|
import Grammar.Print
|
|
import qualified TypeChecker.HMIr as T
|
|
|
|
type Infer = StateT Ctx (ExceptT String Identity)
|
|
type Error = String
|
|
|
|
data Ctx = Ctx { constr :: Map Type Type
|
|
, vars :: Map Ident Type
|
|
, sigs :: Map Ident Type
|
|
, frsh :: Char }
|
|
deriving Show
|
|
|
|
runC :: Ctx -> Infer a -> Either String (a, Ctx)
|
|
runC c = runIdentity . runExceptT . flip runStateT c
|
|
|
|
run :: Infer a -> Either String a
|
|
run = runIdentity . runExceptT . flip evalStateT initC
|
|
|
|
initC :: Ctx
|
|
initC = Ctx M.empty M.empty M.empty 'a'
|
|
|
|
typecheck :: Program -> Either Error T.Program
|
|
typecheck = run . inferPrg
|
|
|
|
inferPrg :: Program -> Infer T.Program
|
|
inferPrg (Program bs) = do
|
|
traverse (\(Bind n t _ _ _) -> insertSig n t) bs
|
|
bs' <- mapM inferBind bs
|
|
return $ T.Program bs'
|
|
|
|
inferBind :: Bind -> Infer T.Bind
|
|
inferBind (Bind i t _ params rhs) = do
|
|
(t',e') <- inferExp (makeLambda rhs (reverse params))
|
|
when (t /= t') (throwError . unwords $ [ "Signature of function"
|
|
, show i
|
|
, "with type:"
|
|
, show t
|
|
, "does not match inferred type"
|
|
, show t'
|
|
, "of expression:"
|
|
, show e'])
|
|
return $ T.Bind (t,i) [] e'
|
|
|
|
makeLambda :: Exp -> [Ident] -> Exp
|
|
makeLambda = foldl (flip EAbs)
|
|
|
|
inferExp :: Exp -> Infer (Type, T.Exp)
|
|
inferExp e = do
|
|
(t, e') <- inferExp' e
|
|
t'' <- solveConstraints t
|
|
return (t'', replaceType t'' e')
|
|
|
|
where
|
|
inferExp' :: Exp -> Infer (Type, T.Exp)
|
|
inferExp' = \case
|
|
EAnn e t -> do
|
|
(t',e') <- inferExp' e
|
|
t'' <- solveConstraints t'
|
|
when (t'' /= t) (throwError "Annotated type and inferred type don't match")
|
|
return (t', e')
|
|
EInt i -> return (int, T.EInt int i)
|
|
EId i -> (\t -> (t, T.EId t i)) <$> lookupVar i
|
|
EAdd e1 e2 -> do
|
|
insertSig "+" (TArr int (TArr int int))
|
|
inferExp' (EApp (EApp (EId "+") e1) e2)
|
|
EApp e1 e2 -> do
|
|
(t1, e1') <- inferExp' e1
|
|
(t2, e2') <- inferExp' e2
|
|
fr <- fresh
|
|
addConstraint t1 (TArr t2 fr)
|
|
return (fr, T.EApp fr e1' e2')
|
|
EAbs name e -> do
|
|
fr <- fresh
|
|
insertVar name fr
|
|
(ret_t,e') <- inferExp' e
|
|
t <- solveConstraints (TArr fr ret_t)
|
|
return (t, T.EAbs t name e')
|
|
ELet name e1 e2 -> error "Let expression not implemented yet"
|
|
|
|
replaceType :: Type -> T.Exp -> T.Exp
|
|
replaceType t = \case
|
|
T.EInt _ i -> T.EInt t i
|
|
T.EId _ i -> T.EId t i
|
|
T.EAdd _ e1 e2 -> T.EAdd t e1 e2
|
|
T.EApp _ e1 e2 -> T.EApp t e1 e2
|
|
T.EAbs _ name e -> T.EAbs t name e
|
|
T.ELet _ name e1 e2 -> T.ELet t name e1 e2
|
|
|
|
isInt :: Type -> Bool
|
|
isInt (TMono "Int") = True
|
|
isInt _ = False
|
|
|
|
lookupVar :: Ident -> Infer Type
|
|
lookupVar i = do
|
|
st <- get
|
|
case M.lookup i (vars st) of
|
|
Just t -> return t
|
|
Nothing -> case M.lookup i (sigs st) of
|
|
Just t -> return t
|
|
Nothing -> throwError $ "Unbound variable or function" ++ printTree i
|
|
|
|
insertVar :: Ident -> Type -> Infer ()
|
|
insertVar s t = modify ( \st -> st { vars = M.insert s t (vars st) } )
|
|
|
|
insertSig :: Ident -> Type -> Infer ()
|
|
insertSig s t = modify ( \st -> st { sigs = M.insert s t (sigs st) } )
|
|
|
|
-- | Generate a new fresh variable and increment the state
|
|
fresh :: Infer Type
|
|
fresh = do
|
|
chr <- gets frsh
|
|
modify (\st -> st { frsh = succ chr })
|
|
return $ TPol (Ident [chr])
|
|
|
|
-- | Adds a constraint to the constraint set.
|
|
-- i.e: a = int -> b
|
|
-- b = int
|
|
-- thus when solving constraints it must be the case that
|
|
-- a = int -> int
|
|
addConstraint :: Type -> Type -> Infer ()
|
|
addConstraint t1 t2 = do
|
|
modify (\st -> st { constr = M.insert t1 t2 (constr st) })
|
|
|
|
-- | Given a type, solve the constraints and figure out the type that should be assigned to it.
|
|
solveConstraints :: Type -> Infer Type
|
|
solveConstraints t = do
|
|
c <- gets constr
|
|
v <- gets vars
|
|
xs <- solveAll (M.toList c)
|
|
modify (\st -> st { constr = M.fromList xs })
|
|
return $ subst t xs
|
|
|
|
-- | Substitute
|
|
subst :: Type -> [(Type, Type)] -> Type
|
|
subst t [] = t
|
|
subst (TArr t1 t2) (x:xs) = subst (TArr (replace x t1) (replace x t2)) xs
|
|
subst t (x:xs) = subst (replace x t) xs
|
|
|
|
-- | Given a set of constraints run the replacement on all of them, producing a new set of
|
|
-- replacements.
|
|
-- https://youtu.be/trmq3wYcUxU - good video for explanation
|
|
solveAll :: [(Type, Type)] -> Infer [(Type, Type)]
|
|
solveAll [] = return []
|
|
solveAll (x:xs) = case x of
|
|
(TArr t1 t2, TArr t3 t4) -> solveAll $ (t1,t3) : (t2,t4) : xs
|
|
(TArr t1 t2, b) -> fmap ((b, TArr t1 t2) :) $ solveAll $ solve (b, TArr t1 t2) xs
|
|
(a, TArr t1 t2) -> fmap ((a, TArr t1 t2) :) $ solveAll $ solve (a, TArr t1 t2) xs
|
|
(TMono a, TPol b) -> fmap ((TPol b, TMono a) :) $ solveAll $ solve (TPol b, TMono a) xs
|
|
(TPol a, TMono b) -> fmap ((TPol a, TMono b) :) $ solveAll $ solve (TPol a, TMono b) xs
|
|
(TPol a, TPol b) -> fmap ((TPol a, TPol b) :) $ solveAll $ solve (TPol a, TPol b) xs
|
|
(TMono a, TMono b) -> if a == b then solveAll xs else throwError "Can't unify types"
|
|
|
|
solve :: (Type, Type) -> [(Type, Type)] -> [(Type, Type)]
|
|
solve x = map (both (replace x))
|
|
|
|
-- | Given a constraint (type, type) and a type, if the constraint matches the input
|
|
-- replace with the constrained type
|
|
replace :: (Type, Type) -> Type -> Type
|
|
replace a (TArr t1 t2) = TArr (replace a t1) (replace a t2)
|
|
replace (a,b) c = if a==c then b else c
|
|
|
|
both :: (a -> b) -> (a,a) -> (b,b)
|
|
both f = bimap f f
|
|
|
|
int = TMono "Int"
|