Simplified quite a bit. Added a unify function. Some bugs remain.

This commit is contained in:
sebastianselander 2023-02-17 11:09:48 +01:00
parent eafe0fea0b
commit a9f54dbca1
4 changed files with 192 additions and 209 deletions

View file

@ -3,12 +3,12 @@ function-arrows: trailing
comma-style: leading
import-export-style: diff-friendly
indent-wheres: false
record-brace-space: false
record-brace-space: true
newlines-between-decls: 1
haddock-style: multi-line
haddock-style-module:
let-style: auto
in-style: right-align
respectful: true
respectful: false
fixities: []
unicode: never

View file

@ -17,7 +17,7 @@ extra-source-files:
common warnings
ghc-options: -Wdefault
ghc-options: -W
executable language
import: warnings

View file

@ -1,26 +1,30 @@
{-# LANGUAGE LambdaCase, OverloadedStrings, OverloadedRecordDot #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use camelCase" #-}
module TypeChecker.TypeChecker where
import Control.Monad (when, void)
import Control.Monad.Except (ExceptT, throwError, runExceptT)
import Control.Monad (void)
import Control.Monad.Except (ExceptT, runExceptT, throwError)
import Control.Monad.State (StateT)
import qualified Control.Monad.State as St
import Control.Monad.State qualified as St
import Data.Functor.Identity (Identity, runIdentity)
import Data.Map (Map)
import qualified Data.Map as M
import Grammar.ErrM (Err)
import Grammar.Print
import Data.List (findIndex)
import Debug.Trace (trace)
import Data.Map qualified as M
import TypeChecker.TypeCheckerIr
data Ctx = Ctx { vars :: Map Integer Type
, sigs :: Map Ident Type
, nextFresh :: Ident
}
deriving Show
-- | State threaded through type inference (the state of the 'Infer' monad).
data Ctx = Ctx
-- Types of lambda-bound variables, keyed by the variable's number
-- (written by 'insertVars', read by 'lookupVars').
{ vars :: Map Integer Type
-- Types of identifiers that are looked up by name (read by 'lookupSigs').
, sigs :: Map Ident Type
-- Counter that 'fresh' reads and then advances to name new type variables.
, nextFresh :: Int
}
deriving (Show)
-- Perhaps swap over to reader monad instead for vars and sigs.
type Infer = StateT Ctx (ExceptT Error Identity)
{-
@ -28,18 +32,20 @@ The type checker will assume we first rename all variables to unique name, as to
have to care about scoping. It significantly improves the quality of life of the
programmer.
TODOs:
Add skolemization variables. i.e
{ \x. 3 : forall a. a -> a }
should not type check
Generalize. Not really sure what that means though
-}
type Infer = StateT Ctx (ExceptT Error Identity)
initEnv :: Ctx
initEnv = Ctx mempty mempty "a"
run :: Infer a -> Either Error a
run = runIdentity . runExceptT . flip St.evalStateT initEnv
run = runIdentity . runExceptT . flip St.evalStateT (Ctx mempty mempty 0)
typecheck :: RProgram -> Either Error TProgram
typecheck = run . inferPrg
typecheck = run . inferPrg
inferPrg :: RProgram -> Infer TProgram
inferPrg (RProgram xs) = do
@ -54,122 +60,83 @@ inferBind (RBind name e) = do
inferExp :: RExp -> Infer (Type, TExp)
inferExp = \case
RAnn expr typ -> do
(t,expr') <- inferExp expr
when (not (t == typ || isPoly t)) (throwError $ AnnotatedMismatch "inferExp, RAnn")
return (typ,expr')
-- Name is only here for proper error messages
RBound num name ->
M.lookup num <$> St.gets vars >>= \case
Nothing -> throwError $ UnboundVar "RBound"
Just t -> return (t, TBound num name t)
(t, expr') <- inferExp expr
void $ t =:= typ
return (typ, expr')
RBound num name -> do
t <- lookupVars num
return (t, TBound num name t)
RFree name -> do
M.lookup name <$> St.gets sigs >>= \case
Nothing -> throwError $ UnboundVar "RFree"
Just t -> return (t, TFree name t)
RConst (CInt i) -> return $ (TMono "Int", TConst (CInt i) (TMono "Int"))
RConst (CStr str) -> return $ (TMono "Str", TConst (CStr str) (TMono "Str"))
-- Should do proper unification using union-find. Some nice libs exist
RApp expr1 expr2 -> do
(typ1, expr1') <- inferExp expr1
(typ2, expr2') <- inferExp expr2
fvar <- fresh
case typ1 of
(TPoly (Ident x)) -> do
let newType = (TArrow (TPoly (Ident x)) (TPoly fvar))
specifyType expr1 newType
typ1' <- apply newType typ1
return $ (typ1', TApp expr1' expr2' typ1')
_ -> (\t -> (t, TApp expr1' expr2' t)) <$> apply typ1 typ2
t <- lookupSigs name
return (t, TFree name t)
RConst (CInt i) -> return (TMono "Int", TConst (CInt i) (TMono "Int"))
RConst (CStr str) -> return (TMono "Str", TConst (CStr str) (TMono "Str"))
RAdd expr1 expr2 -> do
(typ1, expr1') <- inferExp expr1
(typ2, expr2') <- inferExp expr2
when (not $ (isInt typ1 || isPoly typ1) && (isInt typ2 || isPoly typ2)) (throwError $ TypeMismatch "inferExp, RAdd")
specifyType expr1 (TMono "Int")
specifyType expr2 (TMono "Int")
return (TMono "Int", TAdd expr1' expr2' (TMono "Int"))
(typ1, expr1') <- check expr1 (TMono "Int")
(_, expr2') <- check expr2 (TMono "Int")
return (typ1, TAdd expr1' expr2' typ1)
RApp expr1 expr2 -> do
(fn_t, expr1') <- inferExp expr1
(arg_t, expr2') <- inferExp expr2
res <- fresh
-- TODO: Double check if this is correct behavior.
-- It might be the case that we should return res, rather than new_t
new_t <- fn_t =:= TArrow arg_t res
return (new_t, TApp expr1' expr2' new_t)
RAbs num name expr -> do
insertVars num (TPoly "a")
arg <- fresh
insertVars num arg
(typ, expr') <- inferExp expr
newTyp <- lookupVars num
return $ (TArrow newTyp typ, TAbs num name expr' typ)
return (TArrow arg typ, TAbs num name expr' typ)
-- Aux
isInt :: Type -> Bool
isInt (TMono "Int") = True
isInt _ = False
-- | Infer the type of an expression and unify it with an expected type.
--
-- Returns the type produced by unification together with the typed
-- expression that inference produced. Fails if inference or unification fails.
check :: RExp -> Type -> Infer (Type, TExp)
check expr expected =
    inferExp expr >>= \(inferred, typed) ->
        fmap (\unified -> (unified, typed)) (inferred =:= expected)
isArrow :: Type -> Bool
isArrow (TArrow _ _) = True
isArrow _ = False
isPoly :: Type -> Bool
isPoly (TPoly _) = True
isPoly _ = False
fresh :: Infer Ident
fresh :: Infer Type
fresh = do
(Ident var) <- St.gets nextFresh
when (length var == 0) (throwError $ Default "fresh")
index <- case findIndex (== (head var)) alphabet of
Nothing -> throwError $ Default "fresh"
Just i -> return i
let nextIndex = (index + 1) `mod` 26
let newVar = Ident $ [alphabet !! nextIndex]
St.modify (\st -> st { nextFresh = newVar })
return newVar
where
alphabet = "abcdefghijklmnopqrstuvwxyz" :: [Char]
var <- St.gets nextFresh
St.modify (\st -> st {nextFresh = succ var})
return (TPoly $ Ident (show var))
unify :: Type -> Type -> Infer Type
unify = todo
-- | Specify the type of a bound variable
-- Because in lambdas we have to assume a general type and update it
specifyType :: RExp -> Type -> Infer ()
specifyType (RBound num name) typ = do
insertVars num typ
return ()
specifyType _ _ = return ()
-- | Unification operator: reconcile two types into a single type.
--
-- A type variable ('TPoly') on either side yields the other side unchanged;
-- no binding is recorded in the inference state. Two equal monotypes unify to
-- themselves, two arrows unify component-wise (argument first, then result),
-- and every other combination fails with 'TypeMismatch'.
(=:=) :: Type -> Type -> Infer Type
lhs =:= rhs = case (lhs, rhs) of
    (TPoly _, _) -> return rhs
    (_, TPoly _) -> return lhs
    (TMono x, TMono y)
        | x == y -> return (TMono x)
    (TArrow argL resL, TArrow argR resR) ->
        TArrow <$> (argL =:= argR) <*> (resL =:= resR)
    _ -> throwError (TypeMismatch $ unwords ["Can not unify type", show lhs, "with", show rhs])
-- Unused currently
-- | Look up the type recorded for a numbered bound variable in the 'vars'
-- map of the inference state.
--
-- Throws @UnboundVar "lookupVars"@ when the number has no entry.
lookupVars :: Integer -> Infer Type
lookupVars i = do
st <- St.gets vars
case M.lookup i st of
Just t -> return t
Nothing -> throwError $ UnboundVar "lookupVars"
Just t -> return t
Nothing -> throwError $ UnboundVar "lookupVars"
-- | Record (or overwrite) the type of a numbered bound variable in the
-- 'vars' map of the inference state.
insertVars :: Integer -> Type -> Infer ()
insertVars i t = do
st <- St.get
St.put ( st { vars = M.insert i t st.vars } )
St.put (st {vars = M.insert i t st.vars})
-- | Look up the type recorded for an identifier in the 'sigs' map of the
-- inference state.
--
-- Throws @UnboundVar "lookupSigs"@ when the identifier has no entry.
lookupSigs :: Ident -> Infer Type
lookupSigs i = do
st <- St.gets sigs
case M.lookup i st of
Just t -> return t
Nothing -> throwError $ UnboundVar "lookupSigs"
Just t -> return t
Nothing -> throwError $ UnboundVar "lookupSigs"
insertSigs :: Ident -> Type -> Infer ()
insertSigs i t = do
st <- St.get
St.put ( st { sigs = M.insert i t st.sigs } )
-- Have to figure out the equivalence classes for types.
-- Currently this does not support more than exact matches.
apply :: Type -> Type -> Infer Type
apply (TArrow t1 t2) t3
| t1 == t3 = return t2
apply t1 t2 = throwError $ TypeMismatch "apply"
St.put (st {sigs = M.insert i t st.sigs})
{-# WARNING todo "TODO IN CODE" #-}
todo :: a
@ -183,21 +150,30 @@ data Error
| UnboundVar String
| AnnotatedMismatch String
| Default String
deriving Show
deriving (Show)
-- Tests
lambda = RAbs 0 "x" (RAdd (RBound 0 "x") (RBound 0 "x"))
lambda2 = RAbs 0 "x" (RAnn (RBound 0 "x") (TArrow (TMono "Int") (TMono "String")))
-- (\x. x + 1) 1
app_lambda :: RExp
app_lambda = app lambda one
fn_on_var = RAbs 0 (Ident "f") (RAbs 1 (Ident "x") (RApp (RBound 0 (Ident "f")) (RBound 1 (Ident "x"))))
lambda :: RExp
lambda = RAbs 0 "x" $ add bound one
add :: RExp -> RExp -> RExp
add = RAdd
--add x = \y. x+y;
add = RAbs 0 "x" (RAbs 1 "y" (RAdd (RBound 0 "x") (RBound 1 "y")))
-- main = (\z. z+z) ((add 4) 6);
main = RApp (RAbs 0 "z" (RAdd (RBound 0 "z") (RBound 0 "z"))) applyAdd
four = RConst (CInt 4)
six = RConst (CInt 6)
applyAdd = (RApp (RApp add four) six)
partialAdd = RApp add four
bound = RBound 0 "x"
app :: RExp -> RExp -> RExp
app = RApp
one :: RExp
one = RConst (CInt 1)
fn_t = TArrow (TPoly (Ident "0")) (TMono (Ident "Int"))
arr_t = TArrow (TMono "Int") (TPoly "1")
f_x = RAbs 0 "f" (RAbs 1 "x" (RApp (RBound 0 "f") (RBound 1 "x")))

View file

@ -1,28 +1,31 @@
{-# LANGUAGE DeriveAnyClass, PatternSynonyms, GADTs, LambdaCase, OverloadedStrings #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
module TypeChecker.Unification where
import Renamer.Renamer
import Renamer.RenamerIr (Const(..), RExp(..), RBind(..), RProgram(..), Ident(..))
import qualified Renamer.RenamerIr as R
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Except
import Data.Functor.Identity
import Control.Arrow ((>>>))
import Control.Unification hiding ((=:=), applyBindings)
import qualified Control.Unification as U
import Control.Unification.IntVar
import Data.Functor.Fixedpoint
import GHC.Generics (Generic1)
import Data.Foldable (fold)
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe (fromMaybe, fromJust)
import Data.Set (Set, (\\))
import qualified Data.Set as S
import Debug.Trace (trace)
import Control.Arrow ((>>>))
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Control.Unification hiding (applyBindings, (=:=))
import Control.Unification qualified as U
import Control.Unification.IntVar
import Data.Foldable (fold)
import Data.Functor.Fixedpoint
import Data.Functor.Identity
import Data.Map (Map)
import Data.Map qualified as M
import Data.Maybe (fromJust, fromMaybe)
import Data.Set (Set, (\\))
import Data.Set qualified as S
import Debug.Trace (trace)
import GHC.Generics (Generic1)
import Renamer.Renamer
import Renamer.RenamerIr (Const (..), Ident (..), RBind (..), RExp (..), RProgram (..))
import Renamer.RenamerIr qualified as R
type Ctx = Map Ident UPolytype
@ -39,12 +42,13 @@ instance Show a => Show (TypeT a) where
type Infer = StateT (Map Ident UPolytype) (ReaderT Ctx (ExceptT TypeError (IntBindingT TypeT Identity)))
type Type = Fix TypeT
type UType = UTerm TypeT IntVar
data Poly t = Forall [Ident] t
deriving (Eq, Show, Functor)
deriving (Eq, Show, Functor)
type Polytype = Poly Type
type Polytype = Poly Type
type UPolytype = Poly UType
@ -67,23 +71,23 @@ pattern UTPoly :: Ident -> UType
pattern UTPoly v = UTerm (TPolyT v)
data TType = TTPoly Ident | TTMono Ident | TTArrow TType TType
deriving Show
deriving (Show)
data Program = Program [Bind]
deriving Show
newtype Program = Program [Bind]
deriving (Show)
data Bind = Bind Ident Exp Polytype
deriving Show
deriving (Show)
data Exp
= EAnn Exp Polytype
| EBound Ident Polytype
| EFree Ident Polytype
| EFree Ident Polytype
| EConst Const Polytype
| EApp Exp Exp Polytype
| EAdd Exp Exp Polytype
| EAbs Ident Exp Polytype
deriving Show
deriving (Show)
data TExp
= TAnn TExp UType
@ -93,11 +97,11 @@ data TExp
| TApp TExp TExp UType
| TAdd TExp TExp UType
| TAbs Ident TExp UType
deriving Show
deriving (Show)
----------------------------------------------------------
typecheck :: RProgram -> Either TypeError Program
typecheck = run . inferProgram
typecheck = run . inferProgram
inferProgram :: RProgram -> Infer Program
inferProgram (RProgram binds) = do
@ -106,7 +110,7 @@ inferProgram (RProgram binds) = do
inferBind :: RBind -> Infer Bind
inferBind (RBind i e) = do
(t,e') <- infer e
(t, e') <- infer e
e'' <- convert fromUType e'
t' <- fromUType t
insertSigs i (Forall [] t)
@ -114,20 +118,19 @@ inferBind (RBind i e) = do
-- | Turn a unification type into a plain 'Polytype'.
--
-- Applies the current bindings, generalizes the result, and then converts
-- the generalized type with 'fromUPolytype'.
fromUType :: UType -> Infer Polytype
fromUType uty = do
    resolved <- applyBindings uty
    scheme <- generalize resolved
    return (fromUPolytype scheme)
convert :: (UType -> Infer Polytype) -> TExp -> Infer Exp
convert f = \case
(TAnn e t) -> do
e' <- convert f e
t' <- (f t)
return $ EAnn e' t'
EAnn e' <$> f t
(TFree i t) -> do
t' <- f t
return $ EFree i t'
(TBound i t) -> do
t' <- f t
return $ EBound i t'
(TConst c t) -> do
(TConst c t) -> do
t' <- f t
return $ EConst c t'
(TApp e1 e2 t) -> do
@ -135,42 +138,43 @@ convert f = \case
e2' <- convert f e2
t' <- f t
return $ EApp e1' e2' t'
(TAdd e1 e2 t) -> do
(TAdd e1 e2 t) -> do
e1' <- convert f e1
e2' <- convert f e2
t' <- f t
return $ EAdd e1' e2' t'
(TAbs i e t) -> do
(TAbs i e t) -> do
e' <- convert f e
t' <- f t
return $ EAbs i e' t'
run :: Infer a -> Either TypeError a
run = flip evalStateT mempty
>>> flip runReaderT mempty
>>> runExceptT
>>> evalIntBindingT
>>> runIdentity
run =
flip evalStateT mempty
>>> flip runReaderT mempty
>>> runExceptT
>>> evalIntBindingT
>>> runIdentity
infer :: RExp -> Infer (UType, TExp)
infer = \case
(RConst (CInt i)) -> return $ (UTMono "Int", TConst (CInt i) (UTMono "Int"))
(RConst (CStr str)) -> return $ (UTMono "String", TConst (CStr str) (UTMono "String"))
(RConst (CInt i)) -> return (UTMono "Int", TConst (CInt i) (UTMono "Int"))
(RConst (CStr str)) -> return (UTMono "String", TConst (CStr str) (UTMono "String"))
(RAdd e1 e2) -> do
(t1,e1') <- infer e2
(t2,e2') <- infer e1
t1 =:= (UTMono "Int")
t2 =:= (UTMono "Int")
return $ (UTMono "Int", TAdd e1' e2' (UTMono "Int"))
(t1, e1') <- infer e2
(t2, e2') <- infer e1
t1 =:= UTMono "Int"
t2 =:= UTMono "Int"
return (UTMono "Int", TAdd e1' e2' (UTMono "Int"))
(RAnn e t) -> do
(t',e') <- infer e
(t', e') <- infer e
check e t'
return (t', TAnn e' t')
(RApp e1 e2) -> do
(f,e1') <- infer e1
(arg,e2') <- infer e2
(f, e1') <- infer e1
(arg, e2') <- infer e2
res <- fresh
f =:= UTArrow f arg
f =:= UTArrow arg res
return (res, TApp e1' e2' res)
(RAbs _ i e) -> do
arg <- fresh
@ -199,8 +203,8 @@ lookupSigsT :: Ident -> Infer UType
lookupSigsT x@(Ident i) = do
ctx <- ask
case M.lookup x ctx of
Nothing -> trace (show ctx) (throwError $ "Sigs - Unbound variable: " <> i)
Just ut -> return $ fromPolytype ut
Nothing -> trace (show ctx) (throwError $ "Sigs - Unbound variable: " <> i)
Just ut -> return $ fromPolytype ut
-- | Store (or overwrite) the polytype of an identifier in the signature map
-- held in the monad's state.
insertSigs :: MonadState (Map Ident UPolytype) m => Ident -> UPolytype -> m ()
insertSigs name scheme = modify (\known -> M.insert name scheme known)
@ -218,21 +222,23 @@ withBinding x ty = local (M.insert x ty)
deriving instance Ord IntVar
class FreeVars a where
freeVars :: a -> Infer (Set (Either Ident IntVar))
freeVars :: a -> Infer (Set (Either Ident IntVar))
instance FreeVars UType where
freeVars ut = do
fuvs <- fmap (S.fromList . map Right) . lift . lift . lift $ getFreeVars ut
let ftvs = ucata (const S.empty)
(\case {TMonoT x -> S.singleton (Left x); f -> fold f})
ut
return $ fuvs `S.union` ftvs
freeVars ut = do
fuvs <- fmap (S.fromList . map Right) . lift . lift . lift $ getFreeVars ut
let ftvs =
ucata
(const S.empty)
(\case TMonoT x -> S.singleton (Left x); f -> fold f)
ut
return $ fuvs `S.union` ftvs
instance FreeVars UPolytype where
freeVars (Forall xs ut) = (\\ (S.fromList (map Left xs))) <$> freeVars ut
freeVars (Forall xs ut) = (\\ (S.fromList (map Left xs))) <$> freeVars ut
instance FreeVars Ctx where
freeVars = fmap S.unions . mapM freeVars . M.elems
freeVars = fmap S.unions . mapM freeVars . M.elems
-- | Produce a fresh unification variable.
--
-- 'freeVar' is lifted through the three transformer layers that wrap the
-- binding monad in 'Infer', and the resulting variable is wrapped in 'UVar'.
fresh :: Infer UType
fresh = fmap UVar . lift . lift . lift $ freeVar
@ -249,21 +255,22 @@ applyBindings = lift . lift . U.applyBindings
instantiate :: UPolytype -> Infer UType
instantiate (Forall xs uty) = do
xs' <- mapM (const fresh) xs
return $ substU (M.fromList (zip (map Left xs) xs')) uty
xs' <- mapM (const fresh) xs
return $ substU (M.fromList (zip (map Left xs) xs')) uty
substU :: Map (Either Ident IntVar) UType -> UType -> UType
substU m = ucata
(\v -> fromMaybe (UVar v) (M.lookup (Right v) m))
(\case
TPolyT v -> fromMaybe (UTPoly v) (M.lookup (Left v) m)
f -> UTerm f
)
substU m =
ucata
(\v -> fromMaybe (UVar v) (M.lookup (Right v) m))
( \case
TPolyT v -> fromMaybe (UTPoly v) (M.lookup (Left v) m)
f -> UTerm f
)
skolemize :: UPolytype -> Infer UType
skolemize (Forall xs uty) = do
xs' <- mapM (const fresh) xs
return $ substU (M.fromList (zip (map Left xs) (map toSkolem xs'))) uty
xs' <- mapM (const fresh) xs
return $ substU (M.fromList (zip (map Left xs) (map toSkolem xs'))) uty
where
toSkolem (UVar v) = UTPoly (mkVarName "s" v)
@ -272,13 +279,13 @@ mkVarName nm (IntVar v) = Ident $ nm ++ show (v + (maxBound :: Int) + 1)
generalize :: UType -> Infer UPolytype
generalize uty = do
uty' <- applyBindings uty
ctx <- ask
tmfvs <- freeVars uty'
ctxfvs <- freeVars ctx
let fvs = S.toList $ tmfvs \\ ctxfvs
xs = map (either id (mkVarName "a")) fvs
return $ Forall xs (substU (M.fromList (zip fvs (map UTPoly xs))) uty')
uty' <- applyBindings uty
ctx <- ask
tmfvs <- freeVars uty'
ctxfvs <- freeVars ctx
let fvs = S.toList $ tmfvs \\ ctxfvs
xs = map (either id (mkVarName "a")) fvs
return $ Forall xs (substU (M.fromList (zip fvs (map UTPoly xs))) uty')
-- | Convert a unification polytype into a plain 'Polytype' by freezing its
-- body.
--
-- 'freeze' returns 'Nothing' when the term cannot be frozen (per the
-- unification-fd documentation, when it still contains unification
-- variables). That case is a bug in the caller, so it is reported with a
-- descriptive 'error' instead of the anonymous failure of 'fromJust'.
fromUPolytype :: UPolytype -> Polytype
fromUPolytype = fmap (fromMaybe unfrozen . freeze)
  where
    unfrozen =
        error "fromUPolytype: freeze failed (type still contains unification variables)"