Unification part works (probably). Have a hard time understanding it.

This commit is contained in:
sebastianselander 2023-02-17 18:42:50 +01:00
parent 764faa582b
commit f188cffb8d
7 changed files with 167 additions and 197 deletions

View file

@ -1,21 +1,22 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE OverloadedStrings #-}
module TypeChecker.TypeChecker where
import Control.Monad (void)
import Control.Monad.Except (ExceptT, runExceptT, throwError)
import Control.Monad.State (StateT)
import Control.Monad.State qualified as St
import Data.Functor.Identity (Identity, runIdentity)
import Data.Map (Map)
import Data.Map qualified as M
import TypeChecker.TypeCheckerIr
import Control.Monad (void)
import Control.Monad.Except (ExceptT, runExceptT, throwError)
import Control.Monad.State (StateT)
import qualified Control.Monad.State as St
import Data.Functor.Identity (Identity, runIdentity)
import Data.Map (Map)
import qualified Data.Map as M
import TypeChecker.TypeCheckerIr
data Ctx = Ctx
{ vars :: Map Integer Type
, sigs :: Map Ident (RBind, Maybe Type)
{ vars :: Map Integer Type
, sigs :: Map Ident Type
, nextFresh :: Int
}
deriving (Show)
@ -38,70 +39,54 @@ TODOs:
-}
typecheck :: RProgram -> Either Error TProgram
typecheck = todo
run :: Infer a -> Either Error a
run = runIdentity . runExceptT . flip St.evalStateT (Ctx mempty mempty 0)
typecheck :: RProgram -> Either Error TProgram
typecheck = run . inferPrg
inferPrg :: RProgram -> Infer TProgram
inferPrg (RProgram xs) = do
xs' <- mapM inferBind xs
return $ TProgram xs'
-- Binds are not correctly added to the context.
-- Can't type check programs with more than one function currently
inferBind :: RBind -> Infer TBind
inferBind b@(RBind name e) = do
insertSigs name b Nothing
(t, e') <- inferExp e
return $ TBind name t e'
-- Have to figure out a way to coerce polymorphic types to monomorphic ones where necessary
-- { \x. \y. x + y } will have the type { a -> b -> Int }
inferExp :: RExp -> Infer (Type, TExp)
inferExp :: RExp -> Infer Type
inferExp = \case
RAnn expr typ -> do
(t, expr') <- inferExp expr
t <- inferExp expr
void $ t =:= typ
return (typ, expr')
RBound num name -> do
t <- lookupVars num
return (t, TBound num name t)
RFree name -> do
(b@(RBind name _), t) <- lookupSigs name
t' <- case t of
Nothing -> do
(TBind _ a _) <- inferBind b
insertSigs name b (Just a)
return a
Just a -> return a
return (t', TFree name t')
RConst (CInt i) -> return (TMono "Int", TConst (CInt i) (TMono "Int"))
RConst (CStr str) -> return (TMono "Str", TConst (CStr str) (TMono "Str"))
return t
RBound num name -> lookupVars num
RFree name -> lookupSigs name
RConst (CInt i) -> return $ TMono "Int"
RConst (CStr str) -> return $ TMono "Str"
RAdd expr1 expr2 -> do
(typ1, expr1') <- check expr1 (TMono "Int")
(_, expr2') <- check expr2 (TMono "Int")
return (typ1, TAdd expr1' expr2' typ1)
let int = TMono "Int"
typ1 <- check expr1 int
typ2 <- check expr2 int
return int
RApp expr1 expr2 -> do
(fn_t, expr1') <- inferExp expr1
(arg_t, expr2') <- inferExp expr2
fn_t <- inferExp expr1
arg_t <- inferExp expr2
res <- fresh
-- TODO: Double check if this is correct behavior.
-- It might be the case that we should return res, rather than new_t
new_t <- fn_t =:= TArrow arg_t res
return (new_t, TApp expr1' expr2' new_t)
return res
RAbs num name expr -> do
arg <- fresh
insertVars num arg
(typ, expr') <- inferExp expr
return (TArrow arg typ, TAbs num name expr' typ)
typ <- inferExp expr
return $ TArrow arg typ
check :: RExp -> Type -> Infer (Type, TExp)
check :: RExp -> Type -> Infer ()
check e t = do
(t', e') <- inferExp e
t'' <- t' =:= t
return (t'', e')
t' <- inferExp e
t =:= t'
return ()
fresh :: Infer Type
fresh = do
@ -120,30 +105,29 @@ fresh = do
return $ TArrow t1 t2
(=:=) a b = throwError (TypeMismatch $ unwords ["Can not unify type", show a, "with", show b])
-- Unused currently
lookupVars :: Integer -> Infer Type
lookupVars i = do
st <- St.gets vars
case M.lookup i st of
Just t -> return t
Nothing -> throwError $ UnboundVar "lookupVars"
Just t -> return t
Nothing -> throwError $ UnboundVar "lookupVars"
insertVars :: Integer -> Type -> Infer ()
insertVars i t = do
st <- St.get
St.put (st {vars = M.insert i t st.vars})
lookupSigs :: Ident -> Infer (RBind, Maybe Type)
lookupSigs :: Ident -> Infer Type
lookupSigs i = do
st <- St.gets sigs
case M.lookup i st of
Just t -> return t
Just t -> return t
Nothing -> throwError $ UnboundVar "lookupSigs"
insertSigs :: Ident -> RBind -> Maybe Type -> Infer ()
insertSigs i b t = do
insertSigs :: Ident -> Type -> Infer ()
insertSigs i t = do
st <- St.get
St.put (st {sigs = M.insert i (b, t) st.sigs})
St.put (st {sigs = M.insert i t st.sigs})
{-# WARNING todo "TODO IN CODE" #-}
todo :: a
@ -158,3 +142,12 @@ data Error
| AnnotatedMismatch String
| Default String
deriving (Show)
{-
The procedure inst(σ) specializes the polytype
σ by copying the term and replacing the bound type variables
consistently by new monotype variables.
-}

View file

@ -1,21 +1,21 @@
{-# LANGUAGE LambdaCase #-}
module TypeChecker.TypeCheckerIr ( TProgram(..)
, TBind(..)
, TExp(..)
, RProgram(..)
, RBind(..)
, RExp(..)
, Type(..)
, Const(..)
, Ident(..)
)
where
module TypeChecker.TypeCheckerIr (
TProgram (..),
TBind (..),
TExp (..),
RProgram (..),
RBind (..),
RExp (..),
Type (..),
Const (..),
Ident (..),
) where
import Renamer.RenamerIr
import Grammar.Print
import Grammar.Print
import Renamer.RenamerIr
data TProgram = TProgram [TBind]
newtype TProgram = TProgram [TBind]
deriving (Eq, Show, Read, Ord)
data TBind = TBind Ident Type TExp
@ -50,21 +50,25 @@ instance Print TBind where
instance Print TExp where
prt i = \case
TAnn e t -> prPrec i 2 $ concatD
[ prt 0 e
, doc (showString ":")
, prt 1 t
]
TBound _ u t -> prPrec i 3 $ concatD [ prt 0 u ]
TFree u t -> prPrec i 3 $ concatD [ prt 0 u ]
TAnn e t ->
prPrec i 2 $
concatD
[ prt 0 e
, doc (showString ":")
, prt 1 t
]
TBound _ u t -> prPrec i 3 $ concatD [prt 0 u]
TFree u t -> prPrec i 3 $ concatD [prt 0 u]
TConst c _ -> prPrec i 3 (concatD [prt 0 c])
TApp e e1 t -> prPrec i 2 $ concatD [ prt 2 e , prt 3 e1 ]
TAdd e e1 t -> prPrec i 1 $ concatD [ prt 1 e , doc (showString "+") , prt 2 e1 ]
TAbs _ u e t -> prPrec i 0 $ concatD
[ doc (showString "(")
, doc (showString "λ")
, prt 0 u
, doc (showString ".")
, prt 0 e
, doc (showString ")")
]
TApp e e1 t -> prPrec i 2 $ concatD [prt 2 e, prt 3 e1]
TAdd e e1 t -> prPrec i 1 $ concatD [prt 1 e, doc (showString "+"), prt 2 e1]
TAbs _ u e t ->
prPrec i 0 $
concatD
[ doc (showString "(")
, doc (showString "λ")
, prt 0 u
, doc (showString ".")
, prt 0 e
, doc (showString ")")
]

View file

@ -1,31 +1,31 @@
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PatternSynonyms #-}
module TypeChecker.Unification where
import Control.Arrow ((>>>))
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Control.Unification hiding (applyBindings, (=:=))
import Control.Unification qualified as U
import Control.Unification.IntVar
import Data.Foldable (fold)
import Data.Functor.Fixedpoint
import Data.Functor.Identity
import Data.Map (Map)
import Data.Map qualified as M
import Data.Maybe (fromJust, fromMaybe)
import Data.Set (Set, (\\))
import Data.Set qualified as S
import Debug.Trace (trace)
import GHC.Generics (Generic1)
import Renamer.Renamer
import Renamer.RenamerIr (Const (..), Ident (..), RBind (..), RExp (..), RProgram (..))
import Renamer.RenamerIr qualified as R
import Control.Arrow ((>>>))
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Control.Unification hiding (applyBindings, (=:=))
import qualified Control.Unification as U
import Control.Unification.IntVar
import Data.Foldable (fold)
import Data.Functor.Fixedpoint
import Data.Functor.Identity
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe (fromJust, fromMaybe)
import Data.Set (Set, (\\))
import qualified Data.Set as S
import Debug.Trace (trace)
import GHC.Generics (Generic1)
import Renamer.Renamer
import qualified Renamer.RenamerIr as R
import Renamer.RenamerIr (Const (..), Ident (..), RBind (..),
RExp (..), RProgram (..))
type Ctx = Map Ident UPolytype
@ -37,7 +37,7 @@ data TypeT a = TPolyT Ident | TMonoT Ident | TArrowT a a
instance Show a => Show (TypeT a) where
show (TPolyT (Ident i)) = i
show (TMonoT (Ident i)) = i
show (TArrowT a b) = show a ++ " -> " ++ show b
show (TArrowT a b) = show a ++ " -> " ++ show b
type Infer = StateT (Map Ident UPolytype) (ReaderT Ctx (ExceptT TypeError (IntBindingT TypeT Identity)))
@ -46,7 +46,10 @@ type Type = Fix TypeT
type UType = UTerm TypeT IntVar
data Poly t = Forall [Ident] t
deriving (Eq, Show, Functor)
deriving (Eq, Functor)
instance Show t => Show (Poly t) where
show (Forall is t) = unwords (map (\(Ident x) -> "forall " ++ x ++ ".") is) ++ " " ++ show t
type Polytype = Poly Type
@ -101,60 +104,17 @@ data TExp
----------------------------------------------------------
typecheck :: RProgram -> Either TypeError Program
typecheck = run . inferProgram
typecheck = undefined
inferProgram :: RProgram -> Infer Program
inferProgram (RProgram binds) = do
binds' <- mapM inferBind binds
return $ Program binds'
inferBind :: RBind -> Infer Bind
inferBind (RBind i e) = do
(t, e') <- infer e
e'' <- convert fromUType e'
t' <- fromUType t
insertSigs i (Forall [] t)
return $ Bind i e'' t'
fromUType :: UType -> Infer Polytype
fromUType = applyBindings >>> (>>= (generalize >>> fmap fromUPolytype))
convert :: (UType -> Infer Polytype) -> TExp -> Infer Exp
convert f = \case
(TAnn e t) -> do
e' <- convert f e
EAnn e' <$> f t
(TFree i t) -> do
t' <- f t
return $ EFree i t'
(TBound i t) -> do
t' <- f t
return $ EBound i t'
(TConst c t) -> do
t' <- f t
return $ EConst c t'
(TApp e1 e2 t) -> do
e1' <- convert f e1
e2' <- convert f e2
t' <- f t
return $ EApp e1' e2' t'
(TAdd e1 e2 t) -> do
e1' <- convert f e1
e2' <- convert f e2
t' <- f t
return $ EAdd e1' e2' t'
(TAbs i e t) -> do
e' <- convert f e
t' <- f t
return $ EAbs i e' t'
run :: Infer a -> Either TypeError a
run =
flip evalStateT mempty
>>> flip runReaderT mempty
>>> runExceptT
>>> evalIntBindingT
>>> runIdentity
run :: Infer (UType, TExp) -> Either TypeError Polytype
run = fmap fst
>>> (>>= applyBindings)
>>> (>>= (generalize >>> fmap fromUPolytype))
>>> flip evalStateT mempty
>>> flip runReaderT mempty
>>> runExceptT
>>> evalIntBindingT
>>> runIdentity
infer :: RExp -> Infer (UType, TExp)
infer = \case
@ -166,6 +126,7 @@ infer = \case
t1 =:= UTMono "Int"
t2 =:= UTMono "Int"
return (UTMono "Int", TAdd e1' e2' (UTMono "Int"))
-- type is not used, probably wrong
(RAnn e t) -> do
(t', e') <- infer e
check e t'
@ -180,7 +141,7 @@ infer = \case
arg <- fresh
withBinding i (Forall [] arg) $ do
(res, e') <- infer e
return $ (UTArrow arg res, TAbs i e' (UTArrow arg res))
return (UTArrow arg res, TAbs i e' (UTArrow arg res))
(RFree i) -> do
t <- lookupSigsT i
return (t, TFree i t)
@ -213,7 +174,7 @@ fromPolytype :: UPolytype -> UType
fromPolytype (Forall ids ut) = ut
ucata :: Functor t => (v -> a) -> (t a -> a) -> UTerm t v -> a
ucata f _ (UVar v) = f v
ucata f _ (UVar v) = f v
ucata f g (UTerm t) = g (fmap (ucata f g) t)
withBinding :: MonadReader Ctx m => Ident -> UPolytype -> m a -> m a
@ -277,6 +238,7 @@ skolemize (Forall xs uty) = do
mkVarName :: String -> IntVar -> Ident
mkVarName nm (IntVar v) = Ident $ nm ++ show (v + (maxBound :: Int) + 1)
-- | Used in let bindings to generalize functions declared there
generalize :: UType -> Infer UPolytype
generalize uty = do
uty' <- applyBindings uty
@ -289,3 +251,11 @@ generalize uty = do
fromUPolytype :: UPolytype -> Polytype
fromUPolytype = fmap (fromJust . freeze)
inf = RAbs 0 "x" (RApp (RBound 0 "x") (RBound 0 "x"))
one = RConst (CInt 1)
lambda = RAbs 0 "f" (RAbs 1 "x" (RApp (RBound 0 "f") (RBound 1 "x")))
fn = RAbs 0 "x" (RBound 0 "x")