Rewrote the type checker using unification-fd. Heavily inspired by (i.e. largely copied from):
https://byorgey.wordpress.com/2021/09/08/implementing-hindley-milner-with-the-unification-fd-library/
This commit is contained in:
parent
f1b77a7efa
commit
eafe0fea0b
5 changed files with 314 additions and 21 deletions
|
|
@ -33,6 +33,7 @@ executable language
|
||||||
Grammar.ErrM
|
Grammar.ErrM
|
||||||
TypeChecker.TypeChecker
|
TypeChecker.TypeChecker
|
||||||
TypeChecker.TypeCheckerIr
|
TypeChecker.TypeCheckerIr
|
||||||
|
TypeChecker.Unification
|
||||||
Renamer.Renamer
|
Renamer.Renamer
|
||||||
Renamer.RenamerIr
|
Renamer.RenamerIr
|
||||||
|
|
||||||
|
|
@ -45,6 +46,6 @@ executable language
|
||||||
, either
|
, either
|
||||||
, extra
|
, extra
|
||||||
, array
|
, array
|
||||||
, equivalence
|
, unification-fd
|
||||||
|
|
||||||
default-language: GHC2021
|
default-language: GHC2021
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,8 @@ import Grammar.Par (myLexer, pProgram)
|
||||||
import Grammar.Print (printTree)
|
import Grammar.Print (printTree)
|
||||||
import System.Environment (getArgs)
|
import System.Environment (getArgs)
|
||||||
import System.Exit (exitFailure, exitSuccess)
|
import System.Exit (exitFailure, exitSuccess)
|
||||||
import TypeChecker.TypeChecker (typecheck)
|
-- import TypeChecker.TypeChecker (typecheck)
|
||||||
|
import TypeChecker.Unification (typecheck)
|
||||||
import Renamer.Renamer (rename)
|
import Renamer.Renamer (rename)
|
||||||
import Grammar.Print (prt)
|
import Grammar.Print (prt)
|
||||||
|
|
||||||
|
|
@ -43,4 +44,4 @@ main = getArgs >>= \case
|
||||||
putStrLn ""
|
putStrLn ""
|
||||||
putStrLn " ----- TYPECHECKER ----- "
|
putStrLn " ----- TYPECHECKER ----- "
|
||||||
putStrLn ""
|
putStrLn ""
|
||||||
putStrLn . printTree $ prg
|
putStrLn . show $ prg
|
||||||
|
|
|
||||||
|
|
@ -11,13 +11,14 @@ import Data.Map (Map)
|
||||||
import qualified Data.Map as M
|
import qualified Data.Map as M
|
||||||
import Grammar.ErrM (Err)
|
import Grammar.ErrM (Err)
|
||||||
import Grammar.Print
|
import Grammar.Print
|
||||||
|
import Data.List (findIndex)
|
||||||
|
|
||||||
import Debug.Trace (trace)
|
import Debug.Trace (trace)
|
||||||
import TypeChecker.TypeCheckerIr
|
import TypeChecker.TypeCheckerIr
|
||||||
|
|
||||||
data Ctx = Ctx { vars :: Map Integer Type
|
data Ctx = Ctx { vars :: Map Integer Type
|
||||||
, sigs :: Map Ident Type
|
, sigs :: Map Ident Type
|
||||||
, count :: Int
|
, nextFresh :: Ident
|
||||||
}
|
}
|
||||||
deriving Show
|
deriving Show
|
||||||
|
|
||||||
|
|
@ -32,7 +33,7 @@ programmer.
|
||||||
type Infer = StateT Ctx (ExceptT Error Identity)
|
type Infer = StateT Ctx (ExceptT Error Identity)
|
||||||
|
|
||||||
initEnv :: Ctx
|
initEnv :: Ctx
|
||||||
initEnv = Ctx mempty mempty 0
|
initEnv = Ctx mempty mempty "a"
|
||||||
|
|
||||||
run :: Infer a -> Either Error a
|
run :: Infer a -> Either Error a
|
||||||
run = runIdentity . runExceptT . flip St.evalStateT initEnv
|
run = runIdentity . runExceptT . flip St.evalStateT initEnv
|
||||||
|
|
@ -51,7 +52,6 @@ inferBind (RBind name e) = do
|
||||||
insertSigs name t
|
insertSigs name t
|
||||||
return $ TBind name t e'
|
return $ TBind name t e'
|
||||||
|
|
||||||
|
|
||||||
inferExp :: RExp -> Infer (Type, TExp)
|
inferExp :: RExp -> Infer (Type, TExp)
|
||||||
inferExp = \case
|
inferExp = \case
|
||||||
|
|
||||||
|
|
@ -79,14 +79,14 @@ inferExp = \case
|
||||||
RApp expr1 expr2 -> do
|
RApp expr1 expr2 -> do
|
||||||
(typ1, expr1') <- inferExp expr1
|
(typ1, expr1') <- inferExp expr1
|
||||||
(typ2, expr2') <- inferExp expr2
|
(typ2, expr2') <- inferExp expr2
|
||||||
cnt <- incCount
|
fvar <- fresh
|
||||||
case typ1 of
|
case typ1 of
|
||||||
(TPoly (Ident x)) -> do
|
(TPoly (Ident x)) -> do
|
||||||
let newType = (TArrow (TPoly (Ident x)) (TPoly . Ident $ x ++ (show cnt)))
|
let newType = (TArrow (TPoly (Ident x)) (TPoly fvar))
|
||||||
specifyType expr1 newType
|
specifyType expr1 newType
|
||||||
typ1' <- apply newType typ1
|
typ1' <- apply newType typ1
|
||||||
return $ (typ1', TApp expr1' expr2' typ1')
|
return $ (typ1', TApp expr1' expr2' typ1')
|
||||||
_ -> (\t -> (t, TApp expr1' expr2' t)) <$> apply typ2 typ1
|
_ -> (\t -> (t, TApp expr1' expr2' t)) <$> apply typ1 typ2
|
||||||
|
|
||||||
RAdd expr1 expr2 -> do
|
RAdd expr1 expr2 -> do
|
||||||
(typ1, expr1') <- inferExp expr1
|
(typ1, expr1') <- inferExp expr1
|
||||||
|
|
@ -115,11 +115,22 @@ isPoly :: Type -> Bool
|
||||||
isPoly (TPoly _) = True
|
isPoly (TPoly _) = True
|
||||||
isPoly _ = False
|
isPoly _ = False
|
||||||
|
|
||||||
incCount :: Infer Int
|
fresh :: Infer Ident
|
||||||
incCount = do
|
fresh = do
|
||||||
st <- St.get
|
(Ident var) <- St.gets nextFresh
|
||||||
St.put ( st { count = succ st.count } )
|
when (length var == 0) (throwError $ Default "fresh")
|
||||||
return st.count
|
index <- case findIndex (== (head var)) alphabet of
|
||||||
|
Nothing -> throwError $ Default "fresh"
|
||||||
|
Just i -> return i
|
||||||
|
let nextIndex = (index + 1) `mod` 26
|
||||||
|
let newVar = Ident $ [alphabet !! nextIndex]
|
||||||
|
St.modify (\st -> st { nextFresh = newVar })
|
||||||
|
return newVar
|
||||||
|
where
|
||||||
|
alphabet = "abcdefghijklmnopqrstuvwxyz" :: [Char]
|
||||||
|
|
||||||
|
unify :: Type -> Type -> Infer Type
|
||||||
|
unify = todo
|
||||||
|
|
||||||
-- | Specify the type of a bound variable
|
-- | Specify the type of a bound variable
|
||||||
-- Because in lambdas we have to assume a general type and update it
|
-- Because in lambdas we have to assume a general type and update it
|
||||||
|
|
@ -153,12 +164,6 @@ insertSigs i t = do
|
||||||
st <- St.get
|
st <- St.get
|
||||||
St.put ( st { sigs = M.insert i t st.sigs } )
|
St.put ( st { sigs = M.insert i t st.sigs } )
|
||||||
|
|
||||||
union :: Type -> Type -> Infer ()
|
|
||||||
union = todo
|
|
||||||
|
|
||||||
find :: Type -> Type
|
|
||||||
find = todo
|
|
||||||
|
|
||||||
-- Have to figure out the equivalence classes for types.
|
-- Have to figure out the equivalence classes for types.
|
||||||
-- Currently this does not support more than exact matches.
|
-- Currently this does not support more than exact matches.
|
||||||
apply :: Type -> Type -> Infer Type
|
apply :: Type -> Type -> Infer Type
|
||||||
|
|
|
||||||
284
src/TypeChecker/Unification.hs
Normal file
284
src/TypeChecker/Unification.hs
Normal file
|
|
@ -0,0 +1,284 @@
|
||||||
|
{-# LANGUAGE DeriveAnyClass, PatternSynonyms, GADTs, LambdaCase, OverloadedStrings #-}
|
||||||
|
|
||||||
|
module TypeChecker.Unification where
|
||||||
|
|
||||||
|
import Renamer.Renamer
|
||||||
|
import Renamer.RenamerIr (Const(..), RExp(..), RBind(..), RProgram(..), Ident(..))
|
||||||
|
import qualified Renamer.RenamerIr as R
|
||||||
|
|
||||||
|
import Control.Monad.Reader
|
||||||
|
import Control.Monad.State
|
||||||
|
import Control.Monad.Except
|
||||||
|
import Data.Functor.Identity
|
||||||
|
import Control.Arrow ((>>>))
|
||||||
|
import Control.Unification hiding ((=:=), applyBindings)
|
||||||
|
import qualified Control.Unification as U
|
||||||
|
import Control.Unification.IntVar
|
||||||
|
import Data.Functor.Fixedpoint
|
||||||
|
import GHC.Generics (Generic1)
|
||||||
|
import Data.Foldable (fold)
|
||||||
|
import Data.Map (Map)
|
||||||
|
import qualified Data.Map as M
|
||||||
|
import Data.Maybe (fromMaybe, fromJust)
|
||||||
|
import Data.Set (Set, (\\))
|
||||||
|
import qualified Data.Set as S
|
||||||
|
import Debug.Trace (trace)
|
||||||
|
|
||||||
|
-- | Typing context: maps an identifier to its (possibly quantified) type.
-- This is the environment carried by the Reader layer of 'Infer' and
-- extended locally by 'withBinding'.
type Ctx = Map Ident UPolytype
|
||||||
|
|
||||||
|
-- | Type errors are plain strings, thrown with 'throwError' and produced by
-- the 'Fallible' instance for unification failures.
type TypeError = String
|
||||||
|
|
||||||
|
-- | One layer of a type; the recursive positions are the parameter @a@, so
-- the same functor can be tied with 'Fix' ('Type') or 'UTerm' ('UType').
data TypeT a
    = TPolyT Ident   -- ^ a named type variable
    | TMonoT Ident   -- ^ a named concrete type, e.g. @"Int"@
    | TArrowT a a    -- ^ a function type: argument, then result
    deriving (Functor, Foldable, Traversable, Generic1, Unifiable)
|
||||||
|
|
||||||
|
-- | Renders named types by their bare name and arrows infix.
-- NOTE(review): nested arrows on the left are not parenthesised here.
instance Show a => Show (TypeT a) where
    show = \case
        TPolyT (Ident name) -> name
        TMonoT (Ident name) -> name
        TArrowT from to     -> show from ++ " -> " ++ show to
|
||||||
|
|
||||||
|
-- | The inference monad, outermost layer first:
--
--   * 'StateT' over a @Map Ident UPolytype@ (written to by 'insertSigs'),
--   * 'ReaderT' over the local typing context 'Ctx' (see 'withBinding'),
--   * 'ExceptT' for 'TypeError's,
--   * 'IntBindingT' from unification-fd, holding the unification variables.
type Infer = StateT (Map Ident UPolytype) (ReaderT Ctx (ExceptT TypeError (IntBindingT TypeT Identity)))
|
||||||
|
|
||||||
|
-- | A fully resolved type: 'TypeT' tied with 'Fix', so it contains no
-- unification variables.
type Type = Fix TypeT
|
||||||
|
-- | A type under inference: 'TypeT' layers that may contain 'IntVar'
-- unification variables.
type UType = UTerm TypeT IntVar
|
||||||
|
|
||||||
|
-- | A type scheme: a list of quantified variable names together with a body
-- of type @t@.
data Poly t
    = Forall [Ident] t
    deriving (Eq, Show, Functor)
|
||||||
|
|
||||||
|
-- | A type scheme over fully resolved types.
type Polytype = Poly Type
|
||||||
|
|
||||||
|
-- | A type scheme whose body may still contain unification variables.
type UPolytype = Poly UType
|
||||||
|
|
||||||
|
-- Pattern synonyms hiding the 'Fix' wrapper of resolved types ('Type').

-- | Named type variable.
pattern TPoly :: Ident -> Type
pattern TPoly name = Fix (TPolyT name)

-- | Named concrete type.
pattern TMono :: Ident -> Type
pattern TMono name = Fix (TMonoT name)

-- | Function type.
pattern TArrow :: Type -> Type -> Type
pattern TArrow from to = Fix (TArrowT from to)

-- Pattern synonyms hiding the 'UTerm' wrapper of types under inference
-- ('UType').

-- | Named type variable.
pattern UTPoly :: Ident -> UType
pattern UTPoly name = UTerm (TPolyT name)

-- | Named concrete type.
pattern UTMono :: Ident -> UType
pattern UTMono name = UTerm (TMonoT name)

-- | Function type.
pattern UTArrow :: UType -> UType -> UType
pattern UTArrow from to = UTerm (TArrowT from to)
|
||||||
|
|
||||||
|
-- | A plain, directly recursive mirror of 'TypeT'.
-- NOTE(review): nothing else in this module appears to use it — confirm
-- whether it is still needed.
data TType
    = TTPoly Ident
    | TTMono Ident
    | TTArrow TType TType
    deriving Show
|
||||||
|
|
||||||
|
-- | A type checked program: the list of its top-level bindings.
data Program
    = Program [Bind]
    deriving Show
|
||||||
|
|
||||||
|
-- | A type checked top-level binding: its name, its annotated body and the
-- type scheme produced for it by 'inferBind'.
data Bind
    = Bind Ident Exp Polytype
    deriving Show
|
||||||
|
|
||||||
|
-- | Final expression tree. Every node carries a 'Polytype' as its last field;
-- values are produced from 'TExp' by 'convert'.
data Exp
    = EAnn   Exp   Polytype      -- ^ annotated expression
    | EBound Ident Polytype      -- ^ bound variable
    | EFree  Ident Polytype      -- ^ free variable
    | EConst Const Polytype      -- ^ literal constant
    | EApp   Exp Exp Polytype    -- ^ application
    | EAdd   Exp Exp Polytype    -- ^ addition
    | EAbs   Ident Exp Polytype  -- ^ lambda abstraction
    deriving Show
|
||||||
|
|
||||||
|
-- | Intermediate expression tree built by 'infer'. Every node carries a
-- 'UType' that may still contain unification variables.
data TExp
    = TAnn   TExp  UType       -- ^ annotated expression
    | TFree  Ident UType       -- ^ free variable
    | TBound Ident UType       -- ^ bound variable
    | TConst Const UType       -- ^ literal constant
    | TApp   TExp TExp UType   -- ^ application
    | TAdd   TExp TExp UType   -- ^ addition
    | TAbs   Ident TExp UType  -- ^ lambda abstraction
    deriving Show
|
||||||
|
|
||||||
|
----------------------------------------------------------
|
||||||
|
-- | Entry point: type check a renamed program, yielding either the annotated
-- program or a 'TypeError'.
typecheck :: RProgram -> Either TypeError Program
typecheck program = run (inferProgram program)
|
||||||
|
|
||||||
|
-- | Infer every top-level binding, in source order.
inferProgram :: RProgram -> Infer Program
inferProgram (RProgram binds) = Program <$> mapM inferBind binds
|
||||||
|
|
||||||
|
-- | Infer a single top-level binding.
--
-- The body is inferred, its annotations and its own type are resolved with
-- 'fromUType', and the binding's (unquantified) inferred type is recorded
-- via 'insertSigs' before the finished 'Bind' is returned.
inferBind :: RBind -> Infer Bind
inferBind (RBind name body) = do
    (bodyTy, typedBody) <- infer body
    finalBody           <- convert fromUType typedBody
    scheme              <- fromUType bodyTy
    insertSigs name (Forall [] bodyTy)
    pure (Bind name finalBody scheme)
|
||||||
|
|
||||||
|
-- | Resolve a type under inference into a final 'Polytype': apply the current
-- bindings, generalize, then freeze the result.
fromUType :: UType -> Infer Polytype
fromUType ut = do
    resolved <- applyBindings ut
    fromUPolytype <$> generalize resolved
|
||||||
|
|
||||||
|
-- | Translate a 'TExp' into an 'Exp', converting every type annotation with
-- the supplied action. Sub-expressions are converted before the node's own
-- type, left to right.
convert :: (UType -> Infer Polytype) -> TExp -> Infer Exp
convert f = go
  where
    go = \case
        TAnn e t     -> EAnn <$> go e <*> f t
        TFree i t    -> EFree i <$> f t
        TBound i t   -> EBound i <$> f t
        TConst c t   -> EConst c <$> f t
        TApp e1 e2 t -> EApp <$> go e1 <*> go e2 <*> f t
        TAdd e1 e2 t -> EAdd <$> go e1 <*> go e2 <*> f t
        TAbs i e t   -> EAbs i <$> go e <*> f t
|
||||||
|
|
||||||
|
-- | Run an inference action from an empty state and an empty context,
-- peeling the transformer stack from the outside in.
run :: Infer a -> Either TypeError a
run =
    runIdentity
        . evalIntBindingT
        . runExceptT
        . flip runReaderT mempty
        . flip evalStateT mempty
|
||||||
|
|
||||||
|
-- | Infer the type of an expression, returning that type together with the
-- expression annotated with (possibly still unresolved) types.
--
-- Unification failures and unbound variables are reported through
-- 'throwError' as a 'TypeError'.
infer :: RExp -> Infer (UType, TExp)
infer = \case
    RConst (CInt i) -> return (UTMono "Int", TConst (CInt i) (UTMono "Int"))
    RConst (CStr str) -> return (UTMono "String", TConst (CStr str) (UTMono "String"))
    RAdd e1 e2 -> do
        -- Infer the operands in source order so that e1'/e2' really are the
        -- left and right operand in the resulting 'TAdd'.
        (t1, e1') <- infer e1
        (t2, e2') <- infer e2
        _ <- t1 =:= UTMono "Int"
        _ <- t2 =:= UTMono "Int"
        return (UTMono "Int", TAdd e1' e2' (UTMono "Int"))
    RAnn e _ann -> do
        -- NOTE(review): the annotation itself is not consulted; the
        -- expression is only checked against its own inferred type.
        (t', e') <- infer e
        check e t'
        return (t', TAnn e' t')
    RApp e1 e2 -> do
        (f, e1') <- infer e1
        (arg, e2') <- infer e2
        res <- fresh
        -- The function must map the argument type to the fresh result type.
        -- (Unifying f with an arrow that contains f itself can never
        -- succeed: it trips the occurs check.)
        _ <- f =:= UTArrow arg res
        return (res, TApp e1' e2' res)
    RAbs _ i e -> do
        arg <- fresh
        -- The parameter is monomorphic inside the lambda body.
        withBinding i (Forall [] arg) $ do
            (res, e') <- infer e
            return (UTArrow arg res, TAbs i e' (UTArrow arg res))
    RFree i -> do
        t <- lookupSigsT i
        return (t, TFree i t)
    RBound _ i -> do
        t <- lookupVarT i
        return (t, TBound i t)
|
||||||
|
|
||||||
|
-- | Infer the type of an expression and unify it with the expected type.
-- Fails with a 'TypeError' when the two do not unify.
check :: RExp -> UType -> Infer ()
check expr expected = do
    (actual, _) <- infer expr
    () <$ (expected =:= actual)
|
||||||
|
|
||||||
|
-- | Look a bound variable up in the Reader context and instantiate its
-- scheme. Throws when the variable is not in scope.
lookupVarT :: Ident -> Infer UType
lookupVarT x@(Ident i) =
    asks (M.lookup x) >>= \case
        Just scheme -> instantiate scheme
        Nothing     -> throwError $ "Var - Unbound variable: " <> i
|
||||||
|
|
||||||
|
-- | Look up the type recorded for a top-level binding.
--
-- Signatures are stored in the State layer by 'insertSigs', so they must be
-- read back with 'get'. The Reader context only ever holds lambda-bound
-- variables added by 'withBinding'.
-- Throws when no signature has been recorded for the identifier.
lookupSigsT :: Ident -> Infer UType
lookupSigsT x@(Ident i) = do
    sigs <- get
    maybe
        (throwError $ "Sigs - Unbound variable: " <> i)
        (return . fromPolytype)
        (M.lookup x sigs)
|
||||||
|
|
||||||
|
-- | Record (or overwrite) the type scheme of a top-level binding in the
-- state.
insertSigs :: MonadState (Map Ident UPolytype) m => Ident -> UPolytype -> m ()
insertSigs name scheme = modify (\sigs -> M.insert name scheme sigs)
|
||||||
|
|
||||||
|
-- | Drop the quantifier of a scheme and return its body unchanged (no
-- instantiation takes place).
fromPolytype :: UPolytype -> UType
fromPolytype (Forall _ body) = body
|
||||||
|
|
||||||
|
-- | Fold over a 'UTerm': the first function handles unification variables,
-- the second combines one layer whose children have already been folded.
ucata :: Functor t => (v -> a) -> (t a -> a) -> UTerm t v -> a
ucata onVar onLayer = go
  where
    go = \case
        UVar v      -> onVar v
        UTerm layer -> onLayer (fmap go layer)
|
||||||
|
|
||||||
|
-- | Run an action with one extra binding added to the typing context; the
-- binding is only visible inside that action.
withBinding :: MonadReader Ctx m => Ident -> UPolytype -> m a -> m a
withBinding name scheme action = local extend action
  where
    extend = M.insert name scheme
|
||||||
|
|
||||||
|
-- Orphan instance: lets 'IntVar' be stored in the @Set (Either Ident IntVar)@
-- produced by 'freeVars' and used as a key in the substitution maps.
deriving instance Ord IntVar
|
||||||
|
|
||||||
|
-- | Things whose free variables can be collected. A free variable is either
-- a named type variable ('Left') or a unification variable ('Right').
class FreeVars a where
    freeVars :: a -> Infer (Set (Either Ident IntVar))
|
||||||
|
|
||||||
|
-- | Free variables of a type: its unbound unification variables plus every
-- named type variable ('TPolyT') occurring in it.
--
-- Concrete types ('TMonoT', e.g. @Int@) are deliberately not collected:
-- 'generalize', 'instantiate' and 'substU' all treat 'Left' names as
-- 'TPolyT' variables, so counting concrete types here would quantify over
-- them.
instance FreeVars UType where
    freeVars ut = do
        fuvs <- fmap (S.fromList . map Right) . lift . lift . lift $ getFreeVars ut
        let ftvs = ucata (const S.empty)
                         (\case {TPolyT x -> S.singleton (Left x); f -> fold f})
                         ut
        return $ fuvs `S.union` ftvs
|
||||||
|
|
||||||
|
-- | Free variables of a scheme: those of its body minus the quantified names.
instance FreeVars UPolytype where
    freeVars (Forall bound body) = do
        free <- freeVars body
        pure (free \\ S.fromList (map Left bound))
|
||||||
|
|
||||||
|
-- | Free variables of a context: the union over all schemes stored in it.
instance FreeVars Ctx where
    freeVars ctx = S.unions <$> mapM freeVars (M.elems ctx)
|
||||||
|
|
||||||
|
-- | Allocate a fresh unification variable in the underlying binding monad
-- and wrap it as a type.
fresh :: Infer UType
fresh = lift . lift . lift $ UVar <$> freeVar
|
||||||
|
|
||||||
|
-- | Unification failures are reported as fixed strings; the offending
-- variable and terms are discarded.
instance Fallible TypeT IntVar TypeError where
    occursFailure _ _   = "Infinite"
    mismatchFailure _ _ = "Mismatch"
|
||||||
|
|
||||||
|
-- | Unify two types, lifting unification-fd's '(U.=:=)' into 'Infer'.
-- Returns the unified type; fails with a 'TypeError' otherwise.
(=:=) :: UType -> UType -> Infer UType
lhs =:= rhs = lift (lift (lhs U.=:= rhs))
|
||||||
|
|
||||||
|
-- | Apply the current variable bindings to a type, lifting unification-fd's
-- 'U.applyBindings' into 'Infer'.
applyBindings :: UType -> Infer UType
applyBindings ut = lift (lift (U.applyBindings ut))
|
||||||
|
|
||||||
|
-- | Instantiate a scheme: replace every quantified name in the body with a
-- fresh unification variable.
instantiate :: UPolytype -> Infer UType
instantiate (Forall xs uty) = do
    freshVars <- mapM (const fresh) xs
    let subst = M.fromList (zip (map Left xs) freshVars)
    pure (substU subst uty)
|
||||||
|
|
||||||
|
-- | Apply a substitution to a type. 'Right' keys replace unification
-- variables, 'Left' keys replace named type variables ('TPolyT'); anything
-- without an entry is rebuilt unchanged.
substU :: Map (Either Ident IntVar) UType -> UType -> UType
substU m = ucata replaceVar replaceLayer
  where
    replaceVar v = M.findWithDefault (UVar v) (Right v) m

    replaceLayer = \case
        TPolyT name -> M.findWithDefault (UTPoly name) (Left name) m
        other       -> UTerm other
|
||||||
|
|
||||||
|
-- | Skolemize a scheme: replace every quantified name in the body with a
-- freshly named rigid type variable (@s<n>@).
--
-- The fresh 'IntVar' is taken straight from 'freeVar' rather than by
-- pattern matching on the result of 'fresh', so no partial pattern is
-- needed.
skolemize :: UPolytype -> Infer UType
skolemize (Forall xs uty) = do
    skolems <- mapM (const freshSkolem) xs
    return $ substU (M.fromList (zip (map Left xs) skolems)) uty
  where
    -- A fresh variable id is only used to build a unique skolem name.
    freshSkolem :: Infer UType
    freshSkolem = UTPoly . mkVarName "s" <$> lift (lift (lift freeVar))
|
||||||
|
|
||||||
|
-- | Build an identifier from a prefix and a unification variable. The
-- variable's id is shifted by @maxBound + 1@ before being shown.
-- NOTE(review): presumably this maps ids that start at 'minBound' onto
-- 0, 1, 2, ... — confirm against unification-fd's 'IntVar' numbering.
mkVarName :: String -> IntVar -> Ident
mkVarName nm (IntVar v) = Ident (nm ++ show shifted)
  where
    shifted = v + (maxBound :: Int) + 1
|
||||||
|
|
||||||
|
-- | Generalize a type: quantify over every variable that is free in the
-- type but not free in the current context. Unification variables are
-- given names of the form @a<n>@.
generalize :: UType -> Infer UPolytype
generalize uty = do
    resolved <- applyBindings uty
    ctx      <- ask
    tyFree   <- freeVars resolved
    ctxFree  <- freeVars ctx
    let quantified = S.toList (tyFree \\ ctxFree)
        names      = map (either id (mkVarName "a")) quantified
        subst      = M.fromList (zip quantified (map UTPoly names))
    pure (Forall names (substU subst resolved))
|
||||||
|
|
||||||
|
-- | Freeze the body of a scheme into a variable-free 'Type'.
-- Uses 'fromJust', so the body must not contain any unification variables
-- (as is the case for the output of 'generalize').
fromUPolytype :: UPolytype -> Polytype
fromUPolytype (Forall names body) = Forall names (fromJust (freeze body))
|
||||||
|
|
@ -1 +1,3 @@
|
||||||
test f x = f x
|
apply w x = \y. \z. w + x + y + z ;
|
||||||
|
|
||||||
|
main = apply 1 2 3 4 ;
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue