added new test and found another bug

This commit is contained in:
sebastianselander 2023-03-06 16:25:03 +01:00
parent 6947614fba
commit eef6fa7668
5 changed files with 210 additions and 124 deletions

View file

@ -2,7 +2,7 @@
-- double n = n + n; -- double n = n + n;
apply : ('a -> 'b -> 'c) -> 'a -> 'b -> 'c ; apply : ('a -> 'b -> 'c) -> 'a -> 'b -> 'c ;
apply f x = \y. f x y ; apply f x y = f x y ;
id : 'a -> 'a ; id : 'a -> 'a ;
id x = x ; id x = x ;
@ -11,4 +11,7 @@ add : _Int -> _Int -> _Int ;
add x y = x + y ; add x y = x + y ;
main : _Int -> _Int -> _Int ; main : _Int -> _Int -> _Int ;
main = (id add) 1 2 ; main = apply (id add) ;
idadd : _Int -> _Int -> _Int ;
idadd = id add ;

View file

@ -2,6 +2,44 @@
None known at this moment None known at this moment
main\_bug should not typecheck
```hs
apply : ('a -> 'b -> 'c) -> 'a -> 'b -> 'c ;
apply f x = \y. f x y ;
id : 'a -> 'a ;
id x = x ;
add : _Int -> _Int -> _Int ;
add x y = x + y ;
main_bug : _Int -> _Int -> _Int ;
main_bug= (apply id) add ;
idadd : _Int -> _Int -> _Int ;
idadd = id add ;
```
main\_bug should typecheck
```hs
apply : ('a -> 'b -> 'c) -> 'a -> 'b -> 'c ;
apply f x = \y. f x y ;
id : 'a -> 'a ;
id x = x ;
add : _Int -> _Int -> _Int ;
add x y = x + y ;
main_bug : _Int -> _Int -> _Int ;
main_bug = apply (id add) ;
idadd : _Int -> _Int -> _Int ;
idadd = id add ;
```
## Fixed bugs ## Fixed bugs
* 1 * 1

View file

@ -14,6 +14,7 @@ import Data.Map (Map)
import Data.Map qualified as M import Data.Map qualified as M
import Data.Set (Set) import Data.Set (Set)
import Data.Set qualified as S import Data.Set qualified as S
import Debug.Trace (trace)
import Grammar.Abs import Grammar.Abs
import Grammar.Print (printTree) import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr ( import TypeChecker.TypeCheckerIr (
@ -300,38 +301,41 @@ algoW = \case
-- | Unify two types producing a new substitution -- | Unify two types producing a new substitution
unify :: Type -> Type -> Infer Subst unify :: Type -> Type -> Infer Subst
unify t0 t1 = case (t0, t1) of unify t0 t1 = do
(TArr a b, TArr c d) -> do trace ("t0: " ++ show t0) return ()
s1 <- unify a c trace ("t1: " ++ show t1) return ()
s2 <- unify (apply s1 b) (apply s1 d) case (t0, t1) of
return $ s1 `compose` s2 (TArr a b, TArr c d) -> do
(TPol a, b) -> occurs a b s1 <- unify a c
(a, TPol b) -> occurs b a s2 <- unify (apply s1 b) (apply s1 d)
(TMono a, TMono b) -> return $ s1 `compose` s2
if a == b then return M.empty else throwError "Types do not unify" (TPol a, b) -> occurs a b
-- \| TODO: Figure out a cleaner way to express the same thing (a, TPol b) -> occurs b a
(TConstr (Constr name t), TConstr (Constr name' t')) -> (TMono a, TMono b) ->
if name == name' && length t == length t' if a == b then return M.empty else throwError "Types do not unify"
then do -- \| TODO: Figure out a cleaner way to express the same thing
xs <- zipWithM unify t t' (TConstr (Constr name t), TConstr (Constr name' t')) ->
return $ foldr compose nullSubst xs if name == name' && length t == length t'
else then do
throwError $ xs <- zipWithM unify t t'
unwords return $ foldr compose nullSubst xs
[ "Type constructor:" else
, printTree name throwError $
, "(" ++ printTree t ++ ")" unwords
, "does not match with:" [ "Type constructor:"
, printTree name' , printTree name
, "(" ++ printTree t' ++ ")" , "(" ++ printTree t ++ ")"
] , "does not match with:"
(a, b) -> , printTree name'
throwError . unwords $ , "(" ++ printTree t' ++ ")"
[ "Type:" ]
, printTree a (a, b) ->
, "can't be unified with:" throwError . unwords $
, printTree b [ "Type:"
] , printTree a
, "can't be unified with:"
, printTree b
]
{- | Check if a type is contained in another type. {- | Check if a type is contained in another type.
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
@ -409,7 +413,7 @@ instance FreeVars (Map Ident Poly) where
-- | Apply substitutions to the environment. -- | Apply substitutions to the environment.
applySt :: Subst -> Infer a -> Infer a applySt :: Subst -> Infer a -> Infer a
applySt s = local (\st -> st {vars = apply s (vars st)}) applySt s = local (\st -> st{vars = apply s (vars st)})
-- | Represents the empty substition set -- | Represents the empty substition set
nullSubst :: Subst nullSubst :: Subst
@ -419,21 +423,21 @@ nullSubst = M.empty
fresh :: Infer Type fresh :: Infer Type
fresh = do fresh = do
n <- gets count n <- gets count
modify (\st -> st {count = n + 1}) modify (\st -> st{count = n + 1})
return . TPol . Ident $ show n return . TPol . Ident $ show n
-- | Run the monadic action with an additional binding -- | Run the monadic action with an additional binding
withBinding :: (Monad m, MonadReader Ctx m) => Ident -> Poly -> m a -> m a withBinding :: (Monad m, MonadReader Ctx m) => Ident -> Poly -> m a -> m a
withBinding i p = local (\st -> st {vars = M.insert i p (vars st)}) withBinding i p = local (\st -> st{vars = M.insert i p (vars st)})
-- | Insert a function signature into the environment -- | Insert a function signature into the environment
insertSig :: Ident -> Type -> Infer () insertSig :: Ident -> Type -> Infer ()
insertSig i t = modify (\st -> st {sigs = M.insert i t (sigs st)}) insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
-- | Insert a constructor with its data type -- | Insert a constructor with its data type
insertConstr :: Ident -> Type -> Infer () insertConstr :: Ident -> Type -> Infer ()
insertConstr i t = insertConstr i t =
modify (\st -> st {constructors = M.insert i t (constructors st)}) modify (\st -> st{constructors = M.insert i t (constructors st)})
-------- PATTERN MATCHING --------- -------- PATTERN MATCHING ---------
@ -441,7 +445,7 @@ insertConstr i t =
checkInj :: Type -> Inj -> Infer (T.Inj, Type) checkInj :: Type -> Inj -> Infer (T.Inj, Type)
checkInj caseType (Inj it expr) = do checkInj caseType (Inj it expr) = do
(args, t') <- initType caseType it (args, t') <- initType caseType it
(_, t, e') <- local (\st -> st {vars = args `M.union` vars st}) (algoW expr) (_, t, e') <- local (\st -> st{vars = args `M.union` vars st}) (algoW expr)
return (T.Inj (it, t') e', t) return (T.Inj (it, t') e', t)
initType :: Type -> Init -> Infer (Map Ident Poly, Type) initType :: Type -> Init -> Infer (Map Ident Poly, Type)

View file

@ -2,27 +2,33 @@
module TypeChecker.TypeCheckerIr where module TypeChecker.TypeCheckerIr where
import Control.Monad.Except import Control.Monad.Except
import Control.Monad.Reader import Control.Monad.Reader
import Control.Monad.State import Control.Monad.State
import Data.Functor.Identity (Identity) import Data.Functor.Identity (Identity)
import Data.Map (Map) import Data.Map (Map)
import Grammar.Abs (Data (..), Ident (..), Init (..), import Grammar.Abs (
Literal (..), Type (..)) Data (..),
import Grammar.Print Ident (..),
import Prelude Init (..),
import qualified Prelude as C (Eq, Ord, Read, Show) Literal (..),
Type (..),
)
import Grammar.Print
import Prelude
import Prelude qualified as C (Eq, Ord, Read, Show)
-- | A data type representing type variables -- | A data type representing type variables
data Poly = Forall [Ident] Type data Poly = Forall [Ident] Type
deriving Show deriving (Show)
newtype Ctx = Ctx { vars :: Map Ident Poly } newtype Ctx = Ctx {vars :: Map Ident Poly}
data Env = Env { count :: Int data Env = Env
, sigs :: Map Ident Type { count :: Int
, constructors :: Map Ident Type , sigs :: Map Ident Type
} , constructors :: Map Ident Type
}
type Error = String type Error = String
type Subst = Map Ident Type type Subst = Map Ident Type
@ -30,17 +36,17 @@ type Subst = Map Ident Type
type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity)) type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))
newtype Program = Program [Def] newtype Program = Program [Def]
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
data Exp data Exp
= EId Id = EId Id
| ELit Type Literal | ELit Type Literal
| ELet Bind Exp | ELet Bind Exp
| EApp Type Exp Exp | EApp Type Exp Exp
| EAdd Type Exp Exp | EAdd Type Exp Exp
| EAbs Type Id Exp | EAbs Type Id Exp
| ECase Type Exp [Inj] | ECase Type Exp [Inj]
deriving (C.Eq, C.Ord, C.Read, C.Show) deriving (C.Eq, C.Ord, C.Read, C.Show)
data Inj = Inj (Init, Type) Exp data Inj = Inj (Init, Type) Exp
deriving (C.Eq, C.Ord, C.Read, C.Show) deriving (C.Eq, C.Ord, C.Read, C.Show)
@ -54,90 +60,119 @@ data Bind = Bind Id Exp
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
instance Print [Def] where instance Print [Def] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ (x:xs) = concatD [prt 0 x, doc (showString "\n"), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString "\n"), prt 0 xs]
instance Print Def where instance Print Def where
prt i (DBind bind) = prt i bind prt i (DBind bind) = prt i bind
prt i (DData d) = prt i d prt i (DData d) = prt i d
instance Print Program where instance Print Program where
prt i (Program sc) = prPrec i 0 $ prt 0 sc prt i (Program sc) = prPrec i 0 $ prt 0 sc
instance Print Bind where instance Print Bind where
prt i (Bind (t, name) rhs) = prPrec i 0 $ concatD prt i (Bind (t, name) rhs) =
[ prt 0 name prPrec i 0 $
, doc $ showString ":" concatD
, prt 1 t [ prt 0 name
, doc $ showString "=" , doc $ showString ":"
, prt 2 rhs , prt 0 t
] , doc $ showString "\n"
, prt 0 name
, doc $ showString "="
, prt 0 rhs
]
instance Print [Bind] where instance Print [Bind] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x] prt _ [x] = concatD [prt 0 x]
prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), doc (showString "\n"), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), doc (showString "\n"), prt 0 xs]
prtIdPs :: Int -> [Id] -> Doc prtIdPs :: Int -> [Id] -> Doc
prtIdPs i = prPrec i 0 . concatD . map (prtIdP i) prtIdPs i = prPrec i 0 . concatD . map (prtIdP i)
prtId :: Int -> Id -> Doc prtId :: Int -> Id -> Doc
prtId i (name, t) = prPrec i 0 $ concatD prtId i (name, t) =
[ prt 0 name prPrec i 0 $
, doc $ showString ":" concatD
, prt 0 t [ prt 0 name
] , doc $ showString ":"
, prt 0 t
]
prtIdP :: Int -> Id -> Doc prtIdP :: Int -> Id -> Doc
prtIdP i (name, t) = prPrec i 0 $ concatD prtIdP i (name, t) =
[ doc $ showString "(" prPrec i 0 $
, prt 0 name concatD
, doc $ showString ":" [ doc $ showString "("
, prt 0 t , prt 0 name
, doc $ showString ")" , doc $ showString ":"
] , prt 0 t
, doc $ showString ")"
]
instance Print Exp where instance Print Exp where
prt i = \case prt i = \case
EId n -> prPrec i 3 $ concatD [prtId 0 n] EId n -> prPrec i 3 $ concatD [prtId 0 n, doc $ showString "\n"]
ELit _ (LInt i1) -> prPrec i 3 $ concatD [prt 0 i1] ELit _ (LInt i1) -> prPrec i 3 $ concatD [prt 0 i1, doc $ showString "\n"]
ELet bs e -> prPrec i 3 $ concatD ELet bs e ->
prPrec i 3 $
concatD
[ doc $ showString "let" [ doc $ showString "let"
, prt 0 bs , prt 0 bs
, doc $ showString "in" , doc $ showString "in"
, prt 0 e , prt 0 e
, doc $ showString "\n"
] ]
EApp _ e1 e2 -> prPrec i 2 $ concatD EApp _ e1 e2 ->
[ prt 2 e1 prPrec i 2 $
, prt 3 e2 concatD
] [ prt 2 e1
EAdd t e1 e2 -> prPrec i 1 $ concatD , prt 3 e2
[ doc $ showString "@" ]
, prt 0 t EAdd t e1 e2 ->
, prt 1 e1 prPrec i 1 $
, doc $ showString "+" concatD
, prt 2 e2 [ doc $ showString "@"
] , prt 0 t
EAbs t n e -> prPrec i 0 $ concatD , prt 1 e1
[ doc $ showString "@" , doc $ showString "+"
, prt 0 t , prt 2 e2
, doc $ showString "\\" , doc $ showString "\n"
, prtId 0 n ]
, doc $ showString "." EAbs t n e ->
, prt 0 e prPrec i 0 $
] concatD
ECase t exp injs -> prPrec i 0 (concatD [doc (showString "case"), prt 0 exp, doc (showString "of"), doc (showString "{"), prt 0 injs, doc (showString "}"), doc (showString ":"), prt 0 t]) [ doc $ showString "@"
, prt 0 t
, doc $ showString "\\"
, prtId 0 n
, doc $ showString "."
, prt 0 e
, doc $ showString "\n"
]
ECase t exp injs ->
prPrec
i
0
( concatD
[ doc (showString "case")
, prt 0 exp
, doc (showString "of")
, doc (showString "{")
, prt 0 injs
, doc (showString "}")
, doc (showString ":")
, prt 0 t
, doc $ showString "\n"
]
)
instance Print Inj where instance Print Inj where
prt i = \case prt i = \case
Inj (init,t) exp -> prPrec i 0 (concatD [prt 0 init, doc (showString ":"), prt 0 t, doc (showString "=>"), prt 0 exp]) Inj (init, t) exp -> prPrec i 0 (concatD [prt 0 init, doc (showString ":"), prt 0 t, doc (showString "=>"), prt 0 exp])
instance Print [Inj] where instance Print [Inj] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x] prt _ [x] = concatD [prt 0 x]
prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]

View file

@ -29,6 +29,7 @@ main = hspec $ do
infer_eid infer_eid
infer_eabs infer_eabs
infer_eapp infer_eapp
test_id_function
infer_elit = describe "algoW used on ELit" $ do infer_elit = describe "algoW used on ELit" $ do
it "infers the type mono Int" $ do it "infers the type mono Int" $ do
@ -86,7 +87,7 @@ infer_eapp = describe "algoW used on EApp" $ do
let env = Env 0 mempty mempty let env = Env 0 mempty mempty
let t = Forall [] (TPol "a") let t = Forall [] (TPol "a")
let ctx = Ctx (M.singleton (Ident (x :: String)) t) let ctx = Ctx (M.singleton (Ident (x :: String)) t)
getTypeC env ctx (EApp (EId (Ident x)) (EId (Ident x))) `shouldBe` Left "Occurs check failed" getTypeC env ctx (EApp (EId (Ident x)) (EId (Ident x))) `shouldSatisfy` isLeft
churf_id :: Bind churf_id :: Bind
churf_id = Bind "id" (TArr (TPol "a") (TPol "a")) "id" ["x"] (EId "x") churf_id = Bind "id" (TArr (TPol "a") (TPol "a")) "id" ["x"] (EId "x")
@ -95,10 +96,15 @@ churf_add :: Bind
churf_add = Bind "add" (TArr (TMono "Int") (TArr (TMono "Int") (TMono "Int"))) "add" ["x", "y"] (EAdd (EId "x") (EId "y")) churf_add = Bind "add" (TArr (TMono "Int") (TArr (TMono "Int") (TMono "Int"))) "add" ["x", "y"] (EAdd (EId "x") (EId "y"))
churf_main :: Bind churf_main :: Bind
churf_main = Bind "main" (TArr (TMono "Int") (TArr (TMono "Int") (TMono "Int"))) "main" [] (EApp (EId "id") (EId "add")) churf_main = Bind "main" (TArr (TMono "Int") (TMono "Int")) "main" [] (EApp (EApp (EId "id") (EId "add")) (ELit (LInt 0)))
test_bug :: IO () prg = Program [DBind churf_main, DBind churf_add, DBind churf_id]
test_bug = undefined
test_id_function :: SpecWith ()
test_id_function =
describe "typechecking a program with id, add and main, where id is applied to add in main" $ do
it "should succeed to find the correct type" $ do
typecheck prg `shouldSatisfy` isRight
isArrowPolyToMono :: Either Error Type -> Bool isArrowPolyToMono :: Either Error Type -> Bool
isArrowPolyToMono (Right (TArr (TPol _) (TMono _))) = True isArrowPolyToMono (Right (TArr (TPol _) (TMono _))) = True