Simplified quite a bit. Made a unify function. Still bugs left

This commit is contained in:
sebastianselander 2023-02-17 11:09:48 +01:00
parent eafe0fea0b
commit a9f54dbca1
4 changed files with 192 additions and 209 deletions

View file

@ -1,26 +1,30 @@
{-# LANGUAGE LambdaCase, OverloadedStrings, OverloadedRecordDot #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use camelCase" #-}
module TypeChecker.TypeChecker where
import Control.Monad (when, void)
import Control.Monad.Except (ExceptT, throwError, runExceptT)
import Control.Monad (void)
import Control.Monad.Except (ExceptT, runExceptT, throwError)
import Control.Monad.State (StateT)
import qualified Control.Monad.State as St
import Control.Monad.State qualified as St
import Data.Functor.Identity (Identity, runIdentity)
import Data.Map (Map)
import qualified Data.Map as M
import Grammar.ErrM (Err)
import Grammar.Print
import Data.List (findIndex)
import Debug.Trace (trace)
import Data.Map qualified as M
import TypeChecker.TypeCheckerIr
data Ctx = Ctx { vars :: Map Integer Type
, sigs :: Map Ident Type
, nextFresh :: Ident
}
deriving Show
data Ctx = Ctx
{ vars :: Map Integer Type
, sigs :: Map Ident Type
, nextFresh :: Int
}
deriving (Show)
-- Perhaps swap over to reader monad instead for vars and sigs.
type Infer = StateT Ctx (ExceptT Error Identity)
{-
@ -28,18 +32,20 @@ The type checker will assume we first rename all variables to unique name, as to
have to care about scoping. It significantly improves the quality of life of the
programmer.
TODOs:
Add skolemization variables, e.g.
{ \x. 3 : forall a. a -> a }
should not type check.
Generalize. Not really sure what that means, though.
-}
type Infer = StateT Ctx (ExceptT Error Identity)
initEnv :: Ctx
initEnv = Ctx mempty mempty "a"
run :: Infer a -> Either Error a
run = runIdentity . runExceptT . flip St.evalStateT initEnv
run = runIdentity . runExceptT . flip St.evalStateT (Ctx mempty mempty 0)
typecheck :: RProgram -> Either Error TProgram
typecheck = run . inferPrg
typecheck = run . inferPrg
inferPrg :: RProgram -> Infer TProgram
inferPrg (RProgram xs) = do
@ -54,122 +60,83 @@ inferBind (RBind name e) = do
inferExp :: RExp -> Infer (Type, TExp)
inferExp = \case
RAnn expr typ -> do
(t,expr') <- inferExp expr
when (not (t == typ || isPoly t)) (throwError $ AnnotatedMismatch "inferExp, RAnn")
return (typ,expr')
-- Name is only here for proper error messages
RBound num name ->
M.lookup num <$> St.gets vars >>= \case
Nothing -> throwError $ UnboundVar "RBound"
Just t -> return (t, TBound num name t)
(t, expr') <- inferExp expr
void $ t =:= typ
return (typ, expr')
RBound num name -> do
t <- lookupVars num
return (t, TBound num name t)
RFree name -> do
M.lookup name <$> St.gets sigs >>= \case
Nothing -> throwError $ UnboundVar "RFree"
Just t -> return (t, TFree name t)
RConst (CInt i) -> return $ (TMono "Int", TConst (CInt i) (TMono "Int"))
RConst (CStr str) -> return $ (TMono "Str", TConst (CStr str) (TMono "Str"))
-- Should do proper unification using union-find. Some nice libs exist
RApp expr1 expr2 -> do
(typ1, expr1') <- inferExp expr1
(typ2, expr2') <- inferExp expr2
fvar <- fresh
case typ1 of
(TPoly (Ident x)) -> do
let newType = (TArrow (TPoly (Ident x)) (TPoly fvar))
specifyType expr1 newType
typ1' <- apply newType typ1
return $ (typ1', TApp expr1' expr2' typ1')
_ -> (\t -> (t, TApp expr1' expr2' t)) <$> apply typ1 typ2
t <- lookupSigs name
return (t, TFree name t)
RConst (CInt i) -> return (TMono "Int", TConst (CInt i) (TMono "Int"))
RConst (CStr str) -> return (TMono "Str", TConst (CStr str) (TMono "Str"))
RAdd expr1 expr2 -> do
(typ1, expr1') <- inferExp expr1
(typ2, expr2') <- inferExp expr2
when (not $ (isInt typ1 || isPoly typ1) && (isInt typ2 || isPoly typ2)) (throwError $ TypeMismatch "inferExp, RAdd")
specifyType expr1 (TMono "Int")
specifyType expr2 (TMono "Int")
return (TMono "Int", TAdd expr1' expr2' (TMono "Int"))
(typ1, expr1') <- check expr1 (TMono "Int")
(_, expr2') <- check expr2 (TMono "Int")
return (typ1, TAdd expr1' expr2' typ1)
RApp expr1 expr2 -> do
    (fn_t, expr1') <- inferExp expr1
    (arg_t, expr2') <- inferExp expr2
    -- Fresh type variable standing for the (yet unknown) result of the call.
    res <- fresh
    -- Unify the function's type with @arg_t -> res@.
    unified <- fn_t =:= TArrow arg_t res
    case unified of
        -- The type of an application is the codomain of the unified
        -- function type, not the whole arrow type.
        TArrow _ ret_t -> return (ret_t, TApp expr1' expr2' ret_t)
        -- Defensive: keeps the match total should '=:=' ever yield a
        -- non-arrow type here.
        other ->
            throwError
                (TypeMismatch $ unwords ["Expected a function type in application but got", show other])
RAbs num name expr -> do
insertVars num (TPoly "a")
arg <- fresh
insertVars num arg
(typ, expr') <- inferExp expr
newTyp <- lookupVars num
return $ (TArrow newTyp typ, TAbs num name expr' typ)
return (TArrow arg typ, TAbs num name expr' typ)
-- Aux
-- | 'True' exactly when the type is the monomorphic type named "Int".
isInt :: Type -> Bool
isInt t = case t of
  TMono "Int" -> True
  _ -> False
-- | Infer the type of @e@ and unify it with the expected type @t@.
--
-- Returns the type produced by '=:=' together with the typed expression
-- produced by 'inferExp'. Fails with whatever error either step throws.
check :: RExp -> Type -> Infer (Type, TExp)
check e t = do
  (inferred, typed) <- inferExp e
  (\unified -> (unified, typed)) <$> (inferred =:= t)
-- | 'True' exactly when the type is built with 'TArrow'.
isArrow :: Type -> Bool
isArrow t = case t of
  TArrow _ _ -> True
  _ -> False
-- | 'True' exactly when the type is built with 'TPoly'.
isPoly :: Type -> Bool
isPoly t = case t of
  TPoly _ -> True
  _ -> False
fresh :: Infer Ident
fresh :: Infer Type
fresh = do
(Ident var) <- St.gets nextFresh
when (length var == 0) (throwError $ Default "fresh")
index <- case findIndex (== (head var)) alphabet of
Nothing -> throwError $ Default "fresh"
Just i -> return i
let nextIndex = (index + 1) `mod` 26
let newVar = Ident $ [alphabet !! nextIndex]
St.modify (\st -> st { nextFresh = newVar })
return newVar
where
alphabet = "abcdefghijklmnopqrstuvwxyz" :: [Char]
var <- St.gets nextFresh
St.modify (\st -> st {nextFresh = succ var})
return (TPoly $ Ident (show var))
-- | Placeholder: not implemented. It is bound directly to 'todo', so it has
-- no unification logic of its own.
-- NOTE(review): the '=:=' operator in this module appears to be the working
-- unifier; confirm whether this stub is still needed.
unify :: Type -> Type -> Infer Type
unify = todo
-- | Record a type for a bound variable.
--
-- When the expression is an 'RBound', the given type is stored under that
-- variable's index via 'insertVars', overwriting any previous entry. Every
-- other expression is left untouched and the action does nothing.
specifyType :: RExp -> Type -> Infer ()
specifyType expr typ = case expr of
  RBound num _ -> insertVars num typ
  _ -> pure ()
-- | Unify two types and return the resulting type.
--
-- * A 'TPoly' on either side unifies with anything; the /other/ type is
--   returned. No substitution is recorded for the type variable.
-- * Two 'TMono's unify only when their names are equal.
-- * Two 'TArrow's unify component-wise, domain first and then codomain.
-- * Any other combination throws 'TypeMismatch'.
(=:=) :: Type -> Type -> Infer Type
lhs =:= rhs = case (lhs, rhs) of
  (TPoly _, other) -> pure other
  (other, TPoly _) -> pure other
  (TMono a, TMono b)
    | a == b -> pure (TMono a)
  (TArrow a b, TArrow c d) -> TArrow <$> (a =:= c) <*> (b =:= d)
  _ -> throwError (TypeMismatch $ unwords ["Can not unify type", show lhs, "with", show rhs])
-- Unused currently
-- | Fetch the type stored for a bound variable's index in the 'vars' map.
-- Throws @UnboundVar "lookupVars"@ when the index has no entry.
lookupVars :: Integer -> Infer Type
lookupVars i =
  St.gets (M.lookup i . vars)
    >>= maybe (throwError $ UnboundVar "lookupVars") pure
-- | Store (or overwrite) the type of a bound variable, keyed by its index,
-- in the 'vars' map of the inference state.
insertVars :: Integer -> Type -> Infer ()
insertVars i t = St.modify $ \st -> st {vars = M.insert i t st.vars}
-- | Fetch the type stored for a top-level name in the 'sigs' map.
-- Throws @UnboundVar "lookupSigs"@ when the name has no entry.
lookupSigs :: Ident -> Infer Type
lookupSigs i =
  St.gets (M.lookup i . sigs)
    >>= maybe (throwError $ UnboundVar "lookupSigs") pure
insertSigs :: Ident -> Type -> Infer ()
insertSigs i t = do
st <- St.get
St.put ( st { sigs = M.insert i t st.sigs } )
-- | Apply a function type to an argument type and return the result type.
--
-- Only exact matches are supported: the first type must be a 'TArrow' whose
-- domain is equal to the argument type. Anything else throws
-- @TypeMismatch "apply"@. Equivalence classes of types still need to be
-- worked out.
apply :: Type -> Type -> Infer Type
apply fn arg = case fn of
  TArrow param result
    | param == arg -> pure result
  _ -> throwError $ TypeMismatch "apply"
St.put (st {sigs = M.insert i t st.sigs})
{-# WARNING todo "TODO IN CODE" #-}
todo :: a
@ -183,21 +150,30 @@ data Error
| UnboundVar String
| AnnotatedMismatch String
| Default String
deriving Show
deriving (Show)
-- Tests
lambda = RAbs 0 "x" (RAdd (RBound 0 "x") (RBound 0 "x"))
lambda2 = RAbs 0 "x" (RAnn (RBound 0 "x") (TArrow (TMono "Int") (TMono "String")))
-- (\x. x + 1) 1
app_lambda :: RExp
app_lambda = app lambda one
fn_on_var = RAbs 0 (Ident "f") (RAbs 1 (Ident "x") (RApp (RBound 0 (Ident "f")) (RBound 1 (Ident "x"))))
lambda :: RExp
lambda = RAbs 0 "x" $ add bound one
add :: RExp -> RExp -> RExp
add = RAdd
--add x = \y. x+y;
add = RAbs 0 "x" (RAbs 1 "y" (RAdd (RBound 0 "x") (RBound 1 "y")))
-- main = (\z. z+z) ((add 4) 6);
main = RApp (RAbs 0 "z" (RAdd (RBound 0 "z") (RBound 0 "z"))) applyAdd
four = RConst (CInt 4)
six = RConst (CInt 6)
applyAdd = (RApp (RApp add four) six)
partialAdd = RApp add four
-- | Reference to the bound variable with index 0 and name "x".
bound :: RExp
bound = RBound 0 "x"

-- | Shorthand for building an application node.
app :: RExp -> RExp -> RExp
app = RApp

-- | The integer literal 1.
one :: RExp
one = RConst (CInt 1)

-- | Sample arrow type from the type variable "0" to the monotype "Int".
fn_t :: Type
fn_t = TArrow (TPoly (Ident "0")) (TMono (Ident "Int"))

-- | Sample arrow type from the monotype "Int" to the type variable "1".
arr_t :: Type
arr_t = TArrow (TMono "Int") (TPoly "1")

-- | @\f. \x. f x@ — applies a lambda-bound function to a lambda-bound argument.
f_x :: RExp
f_x = RAbs 0 "f" (RAbs 1 "x" (RApp (RBound 0 "f") (RBound 1 "x")))