Simplified quite a bit. Made a unify function. Still bugs left

This commit is contained in:
sebastianselander 2023-02-17 11:09:48 +01:00
parent eafe0fea0b
commit a9f54dbca1
4 changed files with 192 additions and 209 deletions

View file

@ -3,12 +3,12 @@ function-arrows: trailing
comma-style: leading
import-export-style: diff-friendly
indent-wheres: false
record-brace-space: false
record-brace-space: true
newlines-between-decls: 1
haddock-style: multi-line
haddock-style-module:
let-style: auto
in-style: right-align
respectful: true
respectful: false
fixities: []
unicode: never

View file

@ -17,7 +17,7 @@ extra-source-files:
common warnings
ghc-options: -Wdefault
ghc-options: -W
executable language
import: warnings

View file

@ -1,26 +1,30 @@
{-# LANGUAGE LambdaCase, OverloadedStrings, OverloadedRecordDot #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use camelCase" #-}
module TypeChecker.TypeChecker where
import Control.Monad (when, void)
import Control.Monad.Except (ExceptT, throwError, runExceptT)
import Control.Monad (void)
import Control.Monad.Except (ExceptT, runExceptT, throwError)
import Control.Monad.State (StateT)
import qualified Control.Monad.State as St
import Control.Monad.State qualified as St
import Data.Functor.Identity (Identity, runIdentity)
import Data.Map (Map)
import qualified Data.Map as M
import Grammar.ErrM (Err)
import Grammar.Print
import Data.List (findIndex)
import Debug.Trace (trace)
import Data.Map qualified as M
import TypeChecker.TypeCheckerIr
data Ctx = Ctx { vars :: Map Integer Type
data Ctx = Ctx
{ vars :: Map Integer Type
, sigs :: Map Ident Type
, nextFresh :: Ident
, nextFresh :: Int
}
deriving Show
deriving (Show)
-- Perhaps swap over to reader monad instead for vars and sigs.
type Infer = StateT Ctx (ExceptT Error Identity)
{-
@ -28,15 +32,17 @@ The type checker will assume we first rename all variables to unique name, as to
have to care about scoping. It significantly improves the quality of life of the
programmer.
TODOs:
Add skolemization variables. i.e
{ \x. 3 : forall a. a -> a }
should not type check
Generalize. Not really sure what that means though
-}
type Infer = StateT Ctx (ExceptT Error Identity)
initEnv :: Ctx
initEnv = Ctx mempty mempty "a"
run :: Infer a -> Either Error a
run = runIdentity . runExceptT . flip St.evalStateT initEnv
run = runIdentity . runExceptT . flip St.evalStateT (Ctx mempty mempty 0)
typecheck :: RProgram -> Either Error TProgram
typecheck = run . inferPrg
@ -54,92 +60,60 @@ inferBind (RBind name e) = do
inferExp :: RExp -> Infer (Type, TExp)
inferExp = \case
RAnn expr typ -> do
(t,expr') <- inferExp expr
when (not (t == typ || isPoly t)) (throwError $ AnnotatedMismatch "inferExp, RAnn")
return (typ,expr')
-- Name is only here for proper error messages
RBound num name ->
M.lookup num <$> St.gets vars >>= \case
Nothing -> throwError $ UnboundVar "RBound"
Just t -> return (t, TBound num name t)
(t, expr') <- inferExp expr
void $ t =:= typ
return (typ, expr')
RBound num name -> do
t <- lookupVars num
return (t, TBound num name t)
RFree name -> do
M.lookup name <$> St.gets sigs >>= \case
Nothing -> throwError $ UnboundVar "RFree"
Just t -> return (t, TFree name t)
RConst (CInt i) -> return $ (TMono "Int", TConst (CInt i) (TMono "Int"))
RConst (CStr str) -> return $ (TMono "Str", TConst (CStr str) (TMono "Str"))
-- Should do proper unification using union-find. Some nice libs exist
RApp expr1 expr2 -> do
(typ1, expr1') <- inferExp expr1
(typ2, expr2') <- inferExp expr2
fvar <- fresh
case typ1 of
(TPoly (Ident x)) -> do
let newType = (TArrow (TPoly (Ident x)) (TPoly fvar))
specifyType expr1 newType
typ1' <- apply newType typ1
return $ (typ1', TApp expr1' expr2' typ1')
_ -> (\t -> (t, TApp expr1' expr2' t)) <$> apply typ1 typ2
t <- lookupSigs name
return (t, TFree name t)
RConst (CInt i) -> return (TMono "Int", TConst (CInt i) (TMono "Int"))
RConst (CStr str) -> return (TMono "Str", TConst (CStr str) (TMono "Str"))
RAdd expr1 expr2 -> do
(typ1, expr1') <- inferExp expr1
(typ2, expr2') <- inferExp expr2
when (not $ (isInt typ1 || isPoly typ1) && (isInt typ2 || isPoly typ2)) (throwError $ TypeMismatch "inferExp, RAdd")
specifyType expr1 (TMono "Int")
specifyType expr2 (TMono "Int")
return (TMono "Int", TAdd expr1' expr2' (TMono "Int"))
(typ1, expr1') <- check expr1 (TMono "Int")
(_, expr2') <- check expr2 (TMono "Int")
return (typ1, TAdd expr1' expr2' typ1)
RApp expr1 expr2 -> do
(fn_t, expr1') <- inferExp expr1
(arg_t, expr2') <- inferExp expr2
res <- fresh
-- TODO: Double check if this is correct behavior.
-- It might be the case that we should return res, rather than new_t
new_t <- fn_t =:= TArrow arg_t res
return (new_t, TApp expr1' expr2' new_t)
RAbs num name expr -> do
insertVars num (TPoly "a")
arg <- fresh
insertVars num arg
(typ, expr') <- inferExp expr
newTyp <- lookupVars num
return $ (TArrow newTyp typ, TAbs num name expr' typ)
return (TArrow arg typ, TAbs num name expr' typ)
-- Aux
isInt :: Type -> Bool
isInt (TMono "Int") = True
isInt _ = False
check :: RExp -> Type -> Infer (Type, TExp)
check e t = do
(t', e') <- inferExp e
t'' <- t' =:= t
return (t'', e')
isArrow :: Type -> Bool
isArrow (TArrow _ _) = True
isArrow _ = False
isPoly :: Type -> Bool
isPoly (TPoly _) = True
isPoly _ = False
fresh :: Infer Ident
fresh :: Infer Type
fresh = do
(Ident var) <- St.gets nextFresh
when (length var == 0) (throwError $ Default "fresh")
index <- case findIndex (== (head var)) alphabet of
Nothing -> throwError $ Default "fresh"
Just i -> return i
let nextIndex = (index + 1) `mod` 26
let newVar = Ident $ [alphabet !! nextIndex]
St.modify (\st -> st { nextFresh = newVar })
return newVar
where
alphabet = "abcdefghijklmnopqrstuvwxyz" :: [Char]
var <- St.gets nextFresh
St.modify (\st -> st {nextFresh = succ var})
return (TPoly $ Ident (show var))
unify :: Type -> Type -> Infer Type
unify = todo
-- | Specify the type of a bound variable
-- Because in lambdas we have to assume a general type and update it
specifyType :: RExp -> Type -> Infer ()
specifyType (RBound num name) typ = do
insertVars num typ
return ()
specifyType _ _ = return ()
-- | Unify two types.
(=:=) :: Type -> Type -> Infer Type
(=:=) (TPoly _) b = return b
(=:=) a (TPoly _) = return a
(=:=) (TMono a) (TMono b) | a == b = return (TMono a)
(=:=) (TArrow a b) (TArrow c d) = do
t1 <- a =:= c
t2 <- b =:= d
return $ TArrow t1 t2
(=:=) a b = throwError (TypeMismatch $ unwords ["Can not unify type", show a, "with", show b])
-- Unused currently
lookupVars :: Integer -> Infer Type
lookupVars i = do
st <- St.gets vars
@ -150,7 +124,7 @@ lookupVars i = do
insertVars :: Integer -> Type -> Infer ()
insertVars i t = do
st <- St.get
St.put ( st { vars = M.insert i t st.vars } )
St.put (st {vars = M.insert i t st.vars})
lookupSigs :: Ident -> Infer Type
lookupSigs i = do
@ -162,14 +136,7 @@ lookupSigs i = do
insertSigs :: Ident -> Type -> Infer ()
insertSigs i t = do
st <- St.get
St.put ( st { sigs = M.insert i t st.sigs } )
-- Have to figure out the equivalence classes for types.
-- Currently this does not support more than exact matches.
apply :: Type -> Type -> Infer Type
apply (TArrow t1 t2) t3
| t1 == t3 = return t2
apply t1 t2 = throwError $ TypeMismatch "apply"
St.put (st {sigs = M.insert i t st.sigs})
{-# WARNING todo "TODO IN CODE" #-}
todo :: a
@ -183,21 +150,30 @@ data Error
| UnboundVar String
| AnnotatedMismatch String
| Default String
deriving Show
deriving (Show)
-- Tests
lambda = RAbs 0 "x" (RAdd (RBound 0 "x") (RBound 0 "x"))
lambda2 = RAbs 0 "x" (RAnn (RBound 0 "x") (TArrow (TMono "Int") (TMono "String")))
-- (\x. x + 1) 1
app_lambda :: RExp
app_lambda = app lambda one
fn_on_var = RAbs 0 (Ident "f") (RAbs 1 (Ident "x") (RApp (RBound 0 (Ident "f")) (RBound 1 (Ident "x"))))
lambda :: RExp
lambda = RAbs 0 "x" $ add bound one
add :: RExp -> RExp -> RExp
add = RAdd
--add x = \y. x+y;
add = RAbs 0 "x" (RAbs 1 "y" (RAdd (RBound 0 "x") (RBound 1 "y")))
-- main = (\z. z+z) ((add 4) 6);
main = RApp (RAbs 0 "z" (RAdd (RBound 0 "z") (RBound 0 "z"))) applyAdd
four = RConst (CInt 4)
six = RConst (CInt 6)
applyAdd = (RApp (RApp add four) six)
partialAdd = RApp add four
bound = RBound 0 "x"
app :: RExp -> RExp -> RExp
app = RApp
one :: RExp
one = RConst (CInt 1)
fn_t = TArrow (TPoly (Ident "0")) (TMono (Ident "Int"))
arr_t = TArrow (TMono "Int") (TPoly "1")
f_x = RAbs 0 "f" (RAbs 1 "x" (RApp (RBound 0 "f") (RBound 1 "x")))

View file

@ -1,28 +1,31 @@
{-# LANGUAGE DeriveAnyClass, PatternSynonyms, GADTs, LambdaCase, OverloadedStrings #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
module TypeChecker.Unification where
import Renamer.Renamer
import Renamer.RenamerIr (Const(..), RExp(..), RBind(..), RProgram(..), Ident(..))
import qualified Renamer.RenamerIr as R
import Control.Arrow ((>>>))
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Except
import Data.Functor.Identity
import Control.Arrow ((>>>))
import Control.Unification hiding ((=:=), applyBindings)
import qualified Control.Unification as U
import Control.Unification hiding (applyBindings, (=:=))
import Control.Unification qualified as U
import Control.Unification.IntVar
import Data.Functor.Fixedpoint
import GHC.Generics (Generic1)
import Data.Foldable (fold)
import Data.Functor.Fixedpoint
import Data.Functor.Identity
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe (fromMaybe, fromJust)
import Data.Map qualified as M
import Data.Maybe (fromJust, fromMaybe)
import Data.Set (Set, (\\))
import qualified Data.Set as S
import Data.Set qualified as S
import Debug.Trace (trace)
import GHC.Generics (Generic1)
import Renamer.Renamer
import Renamer.RenamerIr (Const (..), Ident (..), RBind (..), RExp (..), RProgram (..))
import Renamer.RenamerIr qualified as R
type Ctx = Map Ident UPolytype
@ -39,6 +42,7 @@ instance Show a => Show (TypeT a) where
type Infer = StateT (Map Ident UPolytype) (ReaderT Ctx (ExceptT TypeError (IntBindingT TypeT Identity)))
type Type = Fix TypeT
type UType = UTerm TypeT IntVar
data Poly t = Forall [Ident] t
@ -67,13 +71,13 @@ pattern UTPoly :: Ident -> UType
pattern UTPoly v = UTerm (TPolyT v)
data TType = TTPoly Ident | TTMono Ident | TTArrow TType TType
deriving Show
deriving (Show)
data Program = Program [Bind]
deriving Show
newtype Program = Program [Bind]
deriving (Show)
data Bind = Bind Ident Exp Polytype
deriving Show
deriving (Show)
data Exp
= EAnn Exp Polytype
@ -83,7 +87,7 @@ data Exp
| EApp Exp Exp Polytype
| EAdd Exp Exp Polytype
| EAbs Ident Exp Polytype
deriving Show
deriving (Show)
data TExp
= TAnn TExp UType
@ -93,7 +97,7 @@ data TExp
| TApp TExp TExp UType
| TAdd TExp TExp UType
| TAbs Ident TExp UType
deriving Show
deriving (Show)
----------------------------------------------------------
typecheck :: RProgram -> Either TypeError Program
@ -106,7 +110,7 @@ inferProgram (RProgram binds) = do
inferBind :: RBind -> Infer Bind
inferBind (RBind i e) = do
(t,e') <- infer e
(t, e') <- infer e
e'' <- convert fromUType e'
t' <- fromUType t
insertSigs i (Forall [] t)
@ -119,8 +123,7 @@ convert :: (UType -> Infer Polytype) -> TExp -> Infer Exp
convert f = \case
(TAnn e t) -> do
e' <- convert f e
t' <- (f t)
return $ EAnn e' t'
EAnn e' <$> f t
(TFree i t) -> do
t' <- f t
return $ EFree i t'
@ -146,7 +149,8 @@ convert f = \case
return $ EAbs i e' t'
run :: Infer a -> Either TypeError a
run = flip evalStateT mempty
run =
flip evalStateT mempty
>>> flip runReaderT mempty
>>> runExceptT
>>> evalIntBindingT
@ -154,23 +158,23 @@ run = flip evalStateT mempty
infer :: RExp -> Infer (UType, TExp)
infer = \case
(RConst (CInt i)) -> return $ (UTMono "Int", TConst (CInt i) (UTMono "Int"))
(RConst (CStr str)) -> return $ (UTMono "String", TConst (CStr str) (UTMono "String"))
(RConst (CInt i)) -> return (UTMono "Int", TConst (CInt i) (UTMono "Int"))
(RConst (CStr str)) -> return (UTMono "String", TConst (CStr str) (UTMono "String"))
(RAdd e1 e2) -> do
(t1,e1') <- infer e2
(t2,e2') <- infer e1
t1 =:= (UTMono "Int")
t2 =:= (UTMono "Int")
return $ (UTMono "Int", TAdd e1' e2' (UTMono "Int"))
(t1, e1') <- infer e2
(t2, e2') <- infer e1
t1 =:= UTMono "Int"
t2 =:= UTMono "Int"
return (UTMono "Int", TAdd e1' e2' (UTMono "Int"))
(RAnn e t) -> do
(t',e') <- infer e
(t', e') <- infer e
check e t'
return (t', TAnn e' t')
(RApp e1 e2) -> do
(f,e1') <- infer e1
(arg,e2') <- infer e2
(f, e1') <- infer e1
(arg, e2') <- infer e2
res <- fresh
f =:= UTArrow f arg
f =:= UTArrow arg res
return (res, TApp e1' e2' res)
(RAbs _ i e) -> do
arg <- fresh
@ -223,8 +227,10 @@ class FreeVars a where
instance FreeVars UType where
freeVars ut = do
fuvs <- fmap (S.fromList . map Right) . lift . lift . lift $ getFreeVars ut
let ftvs = ucata (const S.empty)
(\case {TMonoT x -> S.singleton (Left x); f -> fold f})
let ftvs =
ucata
(const S.empty)
(\case TMonoT x -> S.singleton (Left x); f -> fold f)
ut
return $ fuvs `S.union` ftvs
@ -253,9 +259,10 @@ instantiate (Forall xs uty) = do
return $ substU (M.fromList (zip (map Left xs) xs')) uty
substU :: Map (Either Ident IntVar) UType -> UType -> UType
substU m = ucata
substU m =
ucata
(\v -> fromMaybe (UVar v) (M.lookup (Right v) m))
(\case
( \case
TPolyT v -> fromMaybe (UTPoly v) (M.lookup (Left v) m)
f -> UTerm f
)