unit tests, started on pattern matching
This commit is contained in:
parent
d23d417ff3
commit
05313652f9
9 changed files with 212 additions and 133 deletions
27
src/TypeChecker/CheckInj.hs
Normal file
27
src/TypeChecker/CheckInj.hs
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
{-# OPTIONS_GHC -Wno-unused-imports #-}
|
||||
module TypeChecker.CheckInj where
|
||||
|
||||
import TypeChecker.TypeChecker
|
||||
import qualified TypeChecker.TypeCheckerIr as T
|
||||
import TypeChecker.TypeCheckerIr (Infer)
|
||||
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Reader
|
||||
import Control.Monad.State
|
||||
import Data.Functor.Identity (Identity, runIdentity)
|
||||
import Data.Map (Map)
|
||||
import qualified Data.Map as M
|
||||
|
||||
import Grammar.Abs
|
||||
import Grammar.Print (printTree)
|
||||
|
||||
|
||||
checkInj :: Inj -> Infer T.Inj
|
||||
checkInj (Inj it expr) = do
|
||||
(_, e') <- inferExp expr
|
||||
t' <- initType it
|
||||
return $ T.Inj (it, t') e'
|
||||
|
||||
initType :: Init -> Infer Type
|
||||
initType = undefined
|
||||
|
||||
|
|
@ -1,10 +1,7 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
|
||||
{-# HLINT ignore "Use traverse_" #-}
|
||||
{-# OPTIONS_GHC -Wno-overlapping-patterns #-}
|
||||
{-# HLINT ignore "Use zipWithM" #-}
|
||||
|
||||
-- | A module for type checking and inference using algorithm W, Hindley-Milner
|
||||
module TypeChecker.TypeChecker where
|
||||
|
||||
import Control.Monad.Except
|
||||
|
|
@ -21,23 +18,9 @@ import Data.Foldable (traverse_)
|
|||
import Grammar.Abs
|
||||
import Grammar.Print (printTree)
|
||||
import qualified TypeChecker.TypeCheckerIr as T
|
||||
import TypeChecker.TypeCheckerIr (Ctx (..), Env (..), Error, Infer,
|
||||
Poly (..), Subst)
|
||||
|
||||
-- | A data type representing type variables
|
||||
data Poly = Forall [Ident] Type
|
||||
deriving Show
|
||||
|
||||
newtype Ctx = Ctx { vars :: Map Ident Poly
|
||||
}
|
||||
|
||||
data Env = Env { count :: Int
|
||||
, sigs :: Map Ident Type
|
||||
, dtypes :: Map Ident Type
|
||||
}
|
||||
|
||||
type Error = String
|
||||
type Subst = Map Ident Type
|
||||
|
||||
type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))
|
||||
|
||||
initCtx = Ctx mempty
|
||||
initEnv = Env 0 mempty mempty
|
||||
|
|
@ -98,7 +81,11 @@ checkBind (Bind n t _ args e) = do
|
|||
(t', e') <- inferExp $ makeLambda e (reverse args)
|
||||
s <- unify t t'
|
||||
let t'' = apply s t
|
||||
unless (t `typeEq` t'') (throwError $ unwords ["Top level signature", printTree t, "does not match body with type:", printTree t''])
|
||||
unless (t `typeEq` t'') (throwError $ unwords ["Top level signature"
|
||||
, printTree t
|
||||
, "does not match body with inferred type:"
|
||||
, printTree t''
|
||||
])
|
||||
return $ T.Bind (n, t) [] e'
|
||||
where
|
||||
makeLambda :: Exp -> [Ident] -> Exp
|
||||
|
|
@ -109,12 +96,17 @@ checkBind (Bind n t _ args e) = do
|
|||
typeEq :: Type -> Type -> Bool
|
||||
typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r'
|
||||
typeEq (TMono a) (TMono b) = a == b
|
||||
typeEq (TConstr name a) (TConstr name' b) = if length a == length b
|
||||
then name == name' && and (zipWith typeEq a b)
|
||||
else False
|
||||
typeEq (TConstr name a) (TConstr name' b) = length a == length b
|
||||
&& name == name'
|
||||
&& and (zipWith typeEq a b)
|
||||
typeEq (TPol _) (TPol _) = True
|
||||
typeEq _ _ = False
|
||||
|
||||
isMoreGeneral :: Type -> Type -> Bool
|
||||
isMoreGeneral _ (TPol _) = True
|
||||
isMoreGeneral (TArr a b) (TArr c d) = isMoreGeneral a c && isMoreGeneral b d
|
||||
isMoreGeneral a b = a == b
|
||||
|
||||
inferExp :: Exp -> Infer (Type, T.Exp)
|
||||
inferExp e = do
|
||||
(s, t, e') <- algoW e
|
||||
|
|
@ -123,24 +115,30 @@ inferExp e = do
|
|||
|
||||
replace :: Type -> T.Exp -> T.Exp
|
||||
replace t = \case
|
||||
T.ELit _ e -> T.ELit t e
|
||||
T.EId (n, _) -> T.EId (n, t)
|
||||
T.EAbs _ name e -> T.EAbs t name e
|
||||
T.EApp _ e1 e2 -> T.EApp t e1 e2
|
||||
T.EAdd _ e1 e2 -> T.EAdd t e1 e2
|
||||
T.ELit _ e -> T.ELit t e
|
||||
T.EId (n, _) -> T.EId (n, t)
|
||||
T.EAbs _ name e -> T.EAbs t name e
|
||||
T.EApp _ e1 e2 -> T.EApp t e1 e2
|
||||
T.EAdd _ e1 e2 -> T.EAdd t e1 e2
|
||||
T.ELet (T.Bind (n, _) args e1) e2 -> T.ELet (T.Bind (n, t) args e1) e2
|
||||
|
||||
algoW :: Exp -> Infer (Subst, Type, T.Exp)
|
||||
algoW = \case
|
||||
|
||||
-- | TODO: Reason more about this one. Could be wrong
|
||||
EAnn e t -> do
|
||||
(s1, t', e') <- algoW e
|
||||
unless (t `isMoreGeneral` t') (throwError $ unwords
|
||||
["Annotated type:"
|
||||
, printTree t
|
||||
, "does not match inferred type:"
|
||||
, printTree t' ])
|
||||
applySt s1 $ do
|
||||
s2 <- unify (apply s1 t) t'
|
||||
s2 <- unify t t'
|
||||
return (s2 `compose` s1, t, e')
|
||||
|
||||
-- | ------------------
|
||||
-- | Γ ⊢ e₀ : Int, ∅
|
||||
-- | Γ ⊢ i : Int, ∅
|
||||
|
||||
ELit (LInt n) -> return (nullSubst, TMono "Int", T.ELit (TMono "Int") (LInt n))
|
||||
|
||||
|
|
@ -159,7 +157,7 @@ algoW = \case
|
|||
case M.lookup i sig of
|
||||
Just t -> return (nullSubst, t, T.EId (i, t))
|
||||
Nothing -> do
|
||||
constr <- gets dtypes
|
||||
constr <- gets constructors
|
||||
case M.lookup i constr of
|
||||
Just t -> return (nullSubst, t, T.EId (i, t))
|
||||
Nothing -> throwError $ "Unbound variable: " ++ show i
|
||||
|
|
@ -220,9 +218,9 @@ algoW = \case
|
|||
(s2, t2, e1') <- algoW e1
|
||||
return (s2 `compose` s1, t2, T.ELet (T.Bind (name,t2) [] e0') e1' )
|
||||
|
||||
ECase a b -> error $ "NOT IMPLEMENTED YET: ECase" ++ show a ++ " " ++ show b
|
||||
ECase e0 injs -> undefined
|
||||
|
||||
-- | Unify two types producing a new substitution (constraint)
|
||||
-- | Unify two types producing a new substitution
|
||||
unify :: Type -> Type -> Infer Subst
|
||||
unify t0 t1 = case (t0, t1) of
|
||||
(TArr a b, TArr c d) -> do
|
||||
|
|
@ -235,9 +233,15 @@ unify t0 t1 = case (t0, t1) of
|
|||
-- | TODO: Figure out a cleaner way to express the same thing
|
||||
(TConstr name t, TConstr name' t') -> if name == name' && length t == length t'
|
||||
then do
|
||||
xs <- sequence $ zipWith unify t t'
|
||||
xs <- zipWithM unify t t'
|
||||
return $ foldr compose nullSubst xs
|
||||
else throwError $ unwords ["Type constructor:", printTree name, "(" ++ printTree t ++ ")", "does not match with:", printTree name', "(" ++ printTree t' ++ ")"]
|
||||
else throwError $ unwords
|
||||
["Type constructor:"
|
||||
, printTree name
|
||||
, "(" ++ printTree t ++ ")"
|
||||
, "does not match with:"
|
||||
, printTree name'
|
||||
, "(" ++ printTree t' ++ ")"]
|
||||
(a, b) -> throwError . unwords $ ["Type:", printTree a, "can't be unified with:", printTree b]
|
||||
|
||||
-- | Check if a type is contained in another type.
|
||||
|
|
@ -324,4 +328,4 @@ insertSig i t = modify (\st -> st { sigs = M.insert i t (sigs st) })
|
|||
|
||||
-- | Insert a constructor with its data type
|
||||
insertConstr :: Ident -> Type -> Infer ()
|
||||
insertConstr i t = modify (\st -> st { dtypes = M.insert i t (dtypes st) })
|
||||
insertConstr i t = modify (\st -> st { constructors = M.insert i t (constructors st) })
|
||||
|
|
|
|||
|
|
@ -1,14 +1,33 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
|
||||
module TypeChecker.TypeCheckerIr
|
||||
( module Grammar.Abs
|
||||
, module TypeChecker.TypeCheckerIr
|
||||
) where
|
||||
module TypeChecker.TypeCheckerIr where
|
||||
|
||||
import Grammar.Abs (Data (..), Ident (..), Literal (..), Type (..))
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Reader
|
||||
import Control.Monad.State
|
||||
import Data.Functor.Identity (Identity)
|
||||
import Data.Map (Map)
|
||||
import Grammar.Abs (Data (..), Ident (..), Init (..),
|
||||
Literal (..), Type (..))
|
||||
import Grammar.Print
|
||||
import Prelude
|
||||
import qualified Prelude as C (Eq, Ord, Read, Show)
|
||||
import qualified Prelude as C (Eq, Ord, Read, Show)
|
||||
|
||||
-- | A data type representing type variables
|
||||
data Poly = Forall [Ident] Type
|
||||
deriving Show
|
||||
|
||||
newtype Ctx = Ctx { vars :: Map Ident Poly }
|
||||
|
||||
data Env = Env { count :: Int
|
||||
, sigs :: Map Ident Type
|
||||
, constructors :: Map Ident Type
|
||||
}
|
||||
|
||||
type Error = String
|
||||
type Subst = Map Ident Type
|
||||
|
||||
type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))
|
||||
|
||||
newtype Program = Program [Def]
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||
|
|
@ -22,6 +41,9 @@ data Exp
|
|||
| EAbs Type Id Exp
|
||||
deriving (C.Eq, C.Ord, C.Read, C.Show)
|
||||
|
||||
data Inj = Inj (Init, Type) Exp
|
||||
deriving (C.Eq, C.Ord, C.Read, C.Show)
|
||||
|
||||
data Def = DBind Bind | DData Data
|
||||
deriving (C.Eq, C.Ord, C.Read, C.Show)
|
||||
|
||||
|
|
@ -30,6 +52,10 @@ type Id = (Ident, Type)
|
|||
data Bind = Bind Id [Id] Exp
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||
|
||||
instance Print [Def] where
|
||||
prt _ [] = concatD []
|
||||
prt _ (x:xs) = concatD [prt 0 x, doc (showString "\n"), prt 0 xs]
|
||||
|
||||
instance Print Def where
|
||||
prt i (DBind bind) = prt i bind
|
||||
prt i (DData d) = prt i d
|
||||
|
|
@ -41,16 +67,16 @@ instance Print Bind where
|
|||
prt i (Bind (t, name) parms rhs) = prPrec i 0 $ concatD
|
||||
[ prt 0 name
|
||||
, doc $ showString ":"
|
||||
, prt 0 t
|
||||
, prt 1 t
|
||||
, prtIdPs 0 parms
|
||||
, doc $ showString "="
|
||||
, prt 0 rhs
|
||||
, prt 2 rhs
|
||||
]
|
||||
|
||||
instance Print [Bind] where
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
|
||||
prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), doc (showString "\n"), prt 0 xs]
|
||||
|
||||
prtIdPs :: Int -> [Id] -> Doc
|
||||
prtIdPs i = prPrec i 0 . concatD . map (prtIdP i)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue