Inference works better now. Still work to do. Should use proper library

This commit is contained in:
sebastianselander 2023-02-15 17:40:18 +01:00
parent ad3f6b7011
commit 7619e36c60
8 changed files with 66 additions and 79 deletions

1
.gitignore vendored
View file

@ -4,3 +4,4 @@ dist-newstyle
*.bak
src/Grammar
language
test_program_result

View file

@ -14,8 +14,8 @@ EAbs. Exp ::= "\\" Ident "." Exp ;
CInt. Const ::= Integer ;
CStr. Const ::= String ;
TMono. Type ::= "Mono" Ident ;
TPoly. Type ::= "Poly" Ident ;
TMono. Type1 ::= "Mono" Ident ;
TPoly. Type1 ::= "Poly" Ident ;
TArrow. Type ::= Type1 "->" Type ;
-- This doesn't seem to work so we'll have to live with ugly keywords for now
@ -30,3 +30,7 @@ coercions Exp 5 ;
comment "--" ;
comment "{-" "-}" ;
-- Adt. Adt ::= "data" UIdent "=" [Constructor] ;
-- Sum. Constructor ::= UIdent ;
-- separator Constructor "|" ;

View file

@ -29,4 +29,7 @@ test :
./language ./sample-programs/basic-4
./language ./sample-programs/basic-5
run :
cabal -v0 new-run language -- "test_program"
# EOF

View file

@ -45,5 +45,6 @@ executable language
, either
, extra
, array
, equivalence
default-language: GHC2021

View file

@ -30,4 +30,9 @@ main = getArgs >>= \case
putStrLn . show $ err
exitFailure
Right prg -> do
putStrLn ""
putStrLn . printTree $ prg
putStrLn ""
putStrLn " ----- ADT ----- "
putStrLn ""
putStrLn $ show prg

View file

@ -4,10 +4,6 @@ module TypeChecker.TypeChecker where
import Control.Monad (when, void)
import Control.Monad.Except (ExceptT, throwError, runExceptT)
import Control.Monad.Reader (ReaderT)
import qualified Control.Monad.Reader as R
import Control.Monad.Writer (WriterT)
import qualified Control.Monad.Writer as W
import Control.Monad.State (StateT)
import qualified Control.Monad.State as St
import Data.Functor.Identity (Identity, runIdentity)
@ -52,6 +48,7 @@ inferBind :: RBind -> Infer TBind
inferBind (RBind name e) = do
t <- inferExp e
e' <- toTExpr e
insertSigs name t
return $ TBind name t e'
toTExpr :: RExp -> Infer TExp
@ -97,33 +94,40 @@ inferExp = \case
RAnn expr typ -> do
exprT <- inferExp expr
when (not (exprT == typ || isPoly exprT)) (throwError AnnotatedMismatch)
when (not (exprT == typ || isPoly exprT)) (throwError $ AnnotatedMismatch "inferExp, RAnn")
return typ
-- Name is only here for proper error messages
RBound num name ->
M.lookup num <$> St.gets vars >>= \case
Nothing -> throwError UnboundVar
Nothing -> throwError $ UnboundVar "RBound"
Just t -> return t
RFree name -> do
M.lookup name <$> St.gets sigs >>= \case
Nothing -> throwError UnboundVar
Nothing -> throwError $ UnboundVar "RFree"
Just t -> return t
RConst (CInt _) -> return $ TMono "Int"
RConst (CStr _) -> return $ TMono "Str"
-- Currently does not accept using a polymorphic type as the function.
-- Should do proper unification using union-find. Some nice libs exist
RApp expr1 expr2 -> do
typ1 <- inferExp expr1
typ2 <- inferExp expr2
fit typ2 typ1
cnt <- incCount
case typ1 of
(TPoly (Ident x)) -> do
let newType = (TArrow (TPoly (Ident x)) (TPoly . Ident $ x ++ (show cnt)))
specifyType expr1 newType
apply newType typ1
_ -> apply typ2 typ1
RAdd expr1 expr2 -> do
typ1 <- inferExp expr1
typ2 <- inferExp expr2
when (not $ (isInt typ1 || isPoly typ1) && (isInt typ2 || isPoly typ2)) (throwError TypeMismatch)
when (not $ (isInt typ1 || isPoly typ1) && (isInt typ2 || isPoly typ2)) (throwError $ TypeMismatch "inferExp, RAdd")
specifyType expr1 (TMono "Int")
specifyType expr2 (TMono "Int")
return (TMono "Int")
@ -147,30 +151,12 @@ isPoly :: Type -> Bool
isPoly (TPoly _) = True
isPoly _ = False
fit :: Type -> Type -> Infer Type
fit (TArrow t1 (TArrow t2 t3)) t4
| t1 `match` t4 = return $ TArrow t2 t3
| otherwise = fit (TArrow (TArrow t1 t2) t3) t4
fit (TArrow t1 t2) t3
| t1 `match` t3 = return t2
| otherwise = throwError TypeMismatch
fit _ _ = throwError TypeMismatch
match :: Type -> Type -> Bool
match (TPoly _) (TMono _) = True
match (TMono _) (TPoly _) = True
match (TMono _) (TMono _) = True
match (TPoly _) (TPoly _) = True
match (TArrow t1 t2) (TArrow t3 t4) = match t1 t3 && match t2 t4
incCount :: Infer Int
incCount = do
st <- St.get
St.put (Ctx { vars = st.vars, sigs = st.sigs, count = succ st.count })
St.put ( st { count = succ st.count } )
return st.count
-- | Specify the type of a bound variable
-- Because in lambdas we have to assume a general type and update it
specifyType :: RExp -> Type -> Infer ()
@ -184,33 +170,48 @@ lookupVars i = do
st <- St.gets vars
case M.lookup i st of
Just t -> return t
Nothing -> throwError UnboundVar
Nothing -> throwError $ UnboundVar "lookupVars"
insertVars :: Integer -> Type -> Infer ()
insertVars i t = do
st <- St.get
St.put ( st { vars = M.insert i t st.vars } )
lookupSigs :: Ident -> Infer Type
lookupSigs i = do
st <- St.gets sigs
case M.lookup i st of
Just t -> return t
Nothing -> throwError UnboundVar
Nothing -> throwError $ UnboundVar "lookupSigs"
insertVars :: Integer -> Type -> Infer ()
insertVars i t = do
insertSigs :: Ident -> Type -> Infer ()
insertSigs i t = do
st <- St.get
St.put ( Ctx { vars = M.insert i t st.vars, sigs = st.sigs } )
St.put ( st { sigs = M.insert i t st.sigs } )
union :: Type -> Type -> Infer ()
union = todo
find :: Type -> Type
find = todo
apply :: Type -> Type -> Infer Type
apply (TArrow t1 t2) t3
| t1 == t3 = return t2
| otherwise = throwError $ TypeMismatch "apply"
{-# WARNING todo "TODO IN CODE" #-}
todo :: a
todo = error "TODO in code"
data Error
= TypeMismatch
| NotNumber
| FunctionTypeMismatch
| NotFunction
| UnboundVar
| AnnotatedMismatch
| Default
= TypeMismatch String
| NotNumber String
| FunctionTypeMismatch String
| NotFunction String
| UnboundVar String
| AnnotatedMismatch String
| Default String
deriving Show
-- Tests
@ -218,4 +219,4 @@ data Error
lambda = RAbs 0 "x" (RAdd (RBound 0 "x") (RBound 0 "x"))
lambda2 = RAbs 0 "x" (RAnn (RBound 0 "x") (TArrow (TMono "Int") (TMono "String")))
fn_on_var = RAbs 0 "x" (RAbs 1 "y" (RApp (RBound 0 "x") (RBound 1 "y")))
fn_on_var = RAbs 0 "f" (RAbs 1 "x" (RApp (RBound 0 "f") (RBound 1 "x")))

View file

@ -45,6 +45,7 @@ instance Print TBind where
, prt 0 t
, doc (showString "=")
, prt 0 e
, doc (showString "\n")
]
instance Print TExp where
@ -54,38 +55,11 @@ instance Print TExp where
, doc (showString ":")
, prt 1 t
]
TBound _ u t -> prPrec i 3 $ concatD
[ doc (showString "(")
, prt 0 u
, doc (showString ":")
, prt 0 t
, doc (showString ")")
]
TFree u t -> prPrec i 3 $ concatD
[ doc (showString "(")
, prt 0 u
, doc (showString ":")
, prt 0 t
, doc (showString ")")
]
TBound _ u t -> prPrec i 3 $ concatD [ prt 0 u ]
TFree u t -> prPrec i 3 $ concatD [ prt 0 u ]
TConst c _ -> prPrec i 3 (concatD [prt 0 c])
TApp e e1 t -> prPrec i 2 $ concatD
[ doc (showString "(")
, prt 2 e
, prt 3 e1
, doc (showString ")")
, doc (showString ":")
, prt 0 t
]
TAdd e e1 t -> prPrec i 1 $ concatD
[ doc (showString "(")
, prt 1 e
, doc (showString "+")
, prt 2 e1
, doc (showString ")")
, doc (showString ":")
, prt 0 t
]
TApp e e1 t -> prPrec i 2 $ concatD [ prt 2 e , prt 3 e1 ]
TAdd e e1 t -> prPrec i 1 $ concatD [ prt 1 e , doc (showString "+") , prt 2 e1 ]
TAbs _ u e t -> prPrec i 0 $ concatD
[ doc (showString "(")
, doc (showString "\\")
@ -93,6 +67,4 @@ instance Print TExp where
, doc (showString ".")
, prt 0 e
, doc (showString ")")
, doc (showString ":")
, prt 0 t, doc (showString ".")
]

View file

@ -1 +1 @@
testType f x = f x
test f x = f x