Simplified quite a bit. Made a unify function. Still bugs left
This commit is contained in:
parent
eafe0fea0b
commit
a9f54dbca1
4 changed files with 192 additions and 209 deletions
|
|
@ -1,26 +1,30 @@
|
|||
{-# LANGUAGE LambdaCase, OverloadedStrings, OverloadedRecordDot #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedRecordDot #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
|
||||
|
||||
{-# HLINT ignore "Use camelCase" #-}
|
||||
|
||||
module TypeChecker.TypeChecker where
|
||||
|
||||
import Control.Monad (when, void)
|
||||
import Control.Monad.Except (ExceptT, throwError, runExceptT)
|
||||
import Control.Monad (void)
|
||||
import Control.Monad.Except (ExceptT, runExceptT, throwError)
|
||||
import Control.Monad.State (StateT)
|
||||
import qualified Control.Monad.State as St
|
||||
import Control.Monad.State qualified as St
|
||||
import Data.Functor.Identity (Identity, runIdentity)
|
||||
import Data.Map (Map)
|
||||
import qualified Data.Map as M
|
||||
import Grammar.ErrM (Err)
|
||||
import Grammar.Print
|
||||
import Data.List (findIndex)
|
||||
|
||||
import Debug.Trace (trace)
|
||||
import Data.Map qualified as M
|
||||
import TypeChecker.TypeCheckerIr
|
||||
|
||||
-- | State threaded through type inference.
data Ctx = Ctx
    { vars :: Map Integer Type
    -- ^ Types recorded for bound variables, keyed by their numeric id.
    , sigs :: Map Ident Type
    -- ^ Types recorded for the names looked up by 'RFree'.
    , nextFresh :: Int
    -- ^ Counter that 'fresh' uses to name the next type variable.
    }
    deriving (Show)
|
||||
|
||||
-- Perhaps swap over to reader monad instead for vars and sigs.
|
||||
type Infer = StateT Ctx (ExceptT Error Identity)
|
||||
|
||||
{-
|
||||
|
||||
|
|
@ -28,18 +32,20 @@ The type checker will assume we first rename all variables to unique name, as to
|
|||
have to care about scoping. It significantly improves the quality of life of the
|
||||
programmer.
|
||||
|
||||
TODOs:
|
||||
Add skolemization variables, i.e.
|
||||
{ \x. 3 : forall a. a -> a }
|
||||
should not type check
|
||||
|
||||
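Generalize inferred types, i.e. quantify over the type variables that are still free once a binding has been inferred. The exact design is not worked out yet.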
|
||||
|
||||
-}
|
||||
|
||||
type Infer = StateT Ctx (ExceptT Error Identity)
|
||||
|
||||
-- | Initial inference state: no known variables, no known signatures, and
-- the fresh-variable counter at 0.
initEnv :: Ctx
initEnv = Ctx mempty mempty 0

-- | Run an inference action starting from 'initEnv'.
--
-- Returns the thrown 'Error', if any, or the action's result. The final
-- state is discarded.
run :: Infer a -> Either Error a
run = runIdentity . runExceptT . flip St.evalStateT initEnv
|
||||
|
||||
-- | Type check a whole program, yielding the typed program or the first
-- error raised during inference.
typecheck :: RProgram -> Either Error TProgram
typecheck = run . inferPrg
|
||||
|
||||
inferPrg :: RProgram -> Infer TProgram
|
||||
inferPrg (RProgram xs) = do
|
||||
|
|
@ -54,122 +60,83 @@ inferBind (RBind name e) = do
|
|||
|
||||
-- | Infer the type of an expression.
--
-- Returns the inferred type together with the expression rebuilt as a
-- 'TExp'. Throws an 'Error' when two types cannot be unified or when a
-- variable has no recorded type.
inferExp :: RExp -> Infer (Type, TExp)
inferExp = \case
    -- The inferred type must unify with the annotation; the annotation is
    -- the type that gets reported.
    RAnn expr typ -> do
        (t, expr') <- inferExp expr
        void $ t =:= typ
        return (typ, expr')
    -- The name is only carried along for proper error messages.
    RBound num name -> do
        t <- lookupVars num
        return (t, TBound num name t)
    RFree name -> do
        t <- lookupSigs name
        return (t, TFree name t)
    RConst (CInt i) -> return (TMono "Int", TConst (CInt i) (TMono "Int"))
    RConst (CStr str) -> return (TMono "Str", TConst (CStr str) (TMono "Str"))
    -- Both operands are checked against Int.
    RAdd expr1 expr2 -> do
        (typ1, expr1') <- check expr1 (TMono "Int")
        (_, expr2') <- check expr2 (TMono "Int")
        return (typ1, TAdd expr1' expr2' typ1)
    RApp expr1 expr2 -> do
        (fn_t, expr1') <- inferExp expr1
        (arg_t, expr2') <- inferExp expr2
        res <- fresh
        -- Unify the function's type with @arg_t -> res@. The application
        -- itself has the result type of the unified arrow, not the whole
        -- arrow type.
        fn_t =:= TArrow arg_t res >>= \case
            TArrow _ ret_t -> return (ret_t, TApp expr1' expr2' ret_t)
            -- Unifying with an arrow should always yield an arrow; this
            -- branch only keeps the match total.
            other ->
                throwError
                    (TypeMismatch $ unwords ["Expected a function type but got", show other])
    -- The parameter starts out as a fresh type variable.
    RAbs num name expr -> do
        arg <- fresh
        insertVars num arg
        (typ, expr') <- inferExp expr
        return (TArrow arg typ, TAbs num name expr' typ)
|
||||
|
||||
-- Aux
|
||||
-- | True exactly for the monomorphic type named @Int@.
isInt :: Type -> Bool
isInt = \case
    TMono "Int" -> True
    _ -> False
|
||||
-- | Infer the type of an expression and unify it with an expected type.
--
-- Returns the unified type together with the typed expression; fails with
-- whatever error inference or unification throws.
check :: RExp -> Type -> Infer (Type, TExp)
check expr expected =
    inferExp expr >>= \(inferred, typed) ->
        (\unified -> (unified, typed)) <$> (inferred =:= expected)
|
||||
|
||||
-- | True exactly for arrow (function) types.
isArrow :: Type -> Bool
isArrow = \case
    TArrow{} -> True
    _ -> False
|
||||
|
||||
-- | True exactly for type variables.
isPoly :: Type -> Bool
isPoly = \case
    TPoly{} -> True
    _ -> False
|
||||
|
||||
-- | Produce a type variable that has not been handed out before.
--
-- The variable is named after the current value of the 'nextFresh'
-- counter, and the counter is incremented on every call, so names never
-- repeat within one run.
fresh :: Infer Type
fresh = do
    n <- St.gets nextFresh
    St.modify (\st -> st {nextFresh = succ n})
    return (TPoly $ Ident (show n))
|
||||
|
||||
-- | Unify two types. This is a named alias for the '=:=' operator.
unify :: Type -> Type -> Infer Type
unify = (=:=)
|
||||
|
||||
-- | Record a new type for a bound variable.
--
-- In lambdas a general type has to be assumed first and updated later.
-- Expressions other than 'RBound' are left untouched.
specifyType :: RExp -> Type -> Infer ()
specifyType expr typ = case expr of
    RBound num _ -> insertVars num typ
    _ -> pure ()
|
||||
-- | Unify two types, returning the unified type.
--
-- A type variable unifies with any type and yields that type, except that
-- it may not be unified with an arrow type that mentions the same variable
-- (occurs check), because that would describe an infinite type. Equal
-- monomorphic types unify with themselves and arrows unify component-wise.
-- Every other combination throws a 'TypeMismatch'.
(=:=) :: Type -> Type -> Infer Type
(=:=) (TPoly v) b = bindVar v b
(=:=) a (TPoly v) = bindVar v a
(=:=) (TMono a) (TMono b) | a == b = return (TMono a)
(=:=) (TArrow a b) (TArrow c d) = do
    t1 <- a =:= c
    t2 <- b =:= d
    return $ TArrow t1 t2
(=:=) a b = throwError (TypeMismatch $ unwords ["Can not unify type", show a, "with", show b])

-- | Unify the type variable @v@ with @t@, rejecting infinite types.
bindVar :: Ident -> Type -> Infer Type
bindVar v t = case t of
    TArrow _ _
        | v `occursIn` t ->
            throwError
                ( TypeMismatch $
                    unwords ["Can not unify type", show (TPoly v), "with", show t, "(infinite type)"]
                )
    _ -> return t

-- | Does the type variable occur anywhere inside the given type?
occursIn :: Ident -> Type -> Bool
occursIn v = \case
    TPoly w -> v == w
    TArrow l r -> occursIn v l || occursIn v r
    _ -> False
|
||||
|
||||
-- Helpers for reading and updating the vars and sigs maps of the context
|
||||
-- | Look up the type recorded for a bound variable.
--
-- Throws 'UnboundVar' when the variable has no entry in 'vars'.
lookupVars :: Integer -> Infer Type
lookupVars i =
    St.gets (M.lookup i . vars)
        >>= maybe (throwError $ UnboundVar "lookupVars") return
|
||||
|
||||
-- | Record the type of a bound variable, replacing any previous entry.
insertVars :: Integer -> Type -> Infer ()
insertVars i t = St.modify (\st -> st {vars = M.insert i t st.vars})
|
||||
|
||||
-- | Look up the type recorded for a name in 'sigs'.
--
-- Throws 'UnboundVar' when the name has no entry.
lookupSigs :: Ident -> Infer Type
lookupSigs i =
    St.gets (M.lookup i . sigs)
        >>= maybe (throwError $ UnboundVar "lookupSigs") return
|
||||
|
||||
-- | Record the type of a name in 'sigs', replacing any previous entry.
insertSigs :: Ident -> Type -> Infer ()
insertSigs i t = St.modify (\st -> st {sigs = M.insert i t st.sigs})

-- | Result type of applying a function type to an argument type.
--
-- Only exact matches are supported: the argument type must be equal to the
-- arrow's parameter type. Equivalence classes of types are not handled
-- yet. Anything else throws a 'TypeMismatch'.
apply :: Type -> Type -> Infer Type
apply (TArrow t1 t2) t3
    | t1 == t3 = return t2
apply _ _ = throwError $ TypeMismatch "apply"
|
||||
|
||||
{-# WARNING todo "TODO IN CODE" #-}
|
||||
todo :: a
|
||||
|
|
@ -183,21 +150,30 @@ data Error
|
|||
| UnboundVar String
|
||||
| AnnotatedMismatch String
|
||||
| Default String
|
||||
deriving Show
|
||||
deriving (Show)
|
||||
|
||||
-- Tests
|
||||
|
||||
-- (\x. x + 1) 1
app_lambda :: RExp
app_lambda = app lambda one

-- \x. x + 1
lambda :: RExp
lambda = RAbs 0 "x" $ add bound one

-- Builds an addition node.
add :: RExp -> RExp -> RExp
add = RAdd

-- The variable bound by 'lambda'.
bound :: RExp
bound = RBound 0 "x"

-- Builds an application node.
app :: RExp -> RExp -> RExp
app = RApp

-- The integer literal 1.
one :: RExp
one = RConst (CInt 1)

-- Sample arrow types for trying out unification by hand.
fn_t :: Type
fn_t = TArrow (TPoly (Ident "0")) (TMono (Ident "Int"))

arr_t :: Type
arr_t = TArrow (TMono "Int") (TPoly "1")

-- \f. \x. f x
f_x :: RExp
f_x = RAbs 0 "f" (RAbs 1 "x" (RApp (RBound 0 "f") (RBound 1 "x")))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue