Fixed bug in HM, fixed and reimported tests.

This commit is contained in:
sebastian 2023-05-10 23:54:31 +02:00
parent c5fbd70756
commit 49ef3f9f7c
4 changed files with 186 additions and 227 deletions

View file

@ -35,7 +35,13 @@ bidm FILE:
cabal run language -- -d -t bi -m {{FILE}} cabal run language -- -d -t bi -m {{FILE}}
hmp FILE: hmp FILE:
cabal run language -- -t hm -d -p {{FILE}} cabal run language -- -t hm -p {{FILE}}
bip FILE: bip FILE:
cabal run language -- -t bi -p {{FILE}} cabal run language -- -t bi -p {{FILE}}
hmdp FILE:
cabal run language -- -t hm -d -p {{FILE}}
bidp FILE:
cabal run language -- -t bi -d -p {{FILE}}

View file

@ -1,9 +1,9 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
module Desugar.Desugar (desugar) where module Desugar.Desugar (desugar) where
import Grammar.Abs import Grammar.Abs
{- {-
@ -17,21 +17,21 @@ desugar (Program defs) = Program (map desugarDef defs)
desugarVarName :: VarName -> LIdent desugarVarName :: VarName -> LIdent
desugarVarName (VSymbol (Symbol i)) = LIdent $ fixName i desugarVarName (VSymbol (Symbol i)) = LIdent $ fixName i
desugarVarName (VIdent i) = i desugarVarName (VIdent i) = i
desugarDef :: Def -> Def desugarDef :: Def -> Def
desugarDef = \case desugarDef = \case
DBind b -> DBind (desugarBind b) DBind b -> DBind (desugarBind b)
DSig sig -> DSig (desugarSig sig) DSig sig -> DSig (desugarSig sig)
DData d -> DData (desugarData d) DData d -> DData (desugarData d)
desugarBind :: Bind -> Bind desugarBind :: Bind -> Bind
desugarBind (BindS name args e) = Bind (desugarVarName name) args (desugarExp e) desugarBind (BindS name args e) = Bind (desugarVarName name) args (desugarExp e)
desugarBind (Bind name args e) = Bind name args (desugarExp e) desugarBind (Bind name args e) = Bind name args (desugarExp e)
desugarSig :: Sig -> Sig desugarSig :: Sig -> Sig
desugarSig (SigS ident typ) = Sig (desugarVarName ident) (desugarType typ) desugarSig (SigS ident typ) = Sig (desugarVarName ident) (desugarType typ)
desugarSig (Sig ident typ) = Sig ident (desugarType typ) desugarSig (Sig ident typ) = Sig ident (desugarType typ)
desugarData :: Data -> Data desugarData :: Data -> Data
desugarData (Data typ injs) = Data (desugarType typ) (map desugarInj injs) desugarData (Data typ injs) = Data (desugarType typ) (map desugarInj injs)
@ -45,7 +45,7 @@ desugarType = \case
let (name : tvars) = flatten t1 ++ [t2] let (name : tvars) = flatten t1 ++ [t2]
in case name of in case name of
TIdent ident -> TData ident (map desugarType tvars) TIdent ident -> TData ident (map desugarType tvars)
_ -> error "desugarType is not implemented correctly" _ -> error "desugarType is not implemented correctly, or the user made a mistake"
TLit l -> TLit l TLit l -> TLit l
TVar v -> TVar v TVar v -> TVar v
(TAll i t) -> TAll i (desugarType t) (TAll i t) -> TAll i (desugarType t)
@ -55,26 +55,26 @@ desugarType = \case
where where
flatten :: Type -> [Type] flatten :: Type -> [Type]
flatten (TApp a b) = flatten a <> flatten b flatten (TApp a b) = flatten a <> flatten b
flatten a = [a] flatten a = [a]
desugarInj :: Inj -> Inj desugarInj :: Inj -> Inj
desugarInj (Inj ident typ) = Inj ident (desugarType typ) desugarInj (Inj ident typ) = Inj ident (desugarType typ)
desugarExp :: Exp -> Exp desugarExp :: Exp -> Exp
desugarExp = \case desugarExp = \case
EApp e1 e2 -> EApp (desugarExp e1) (desugarExp e2) EApp e1 e2 -> EApp (desugarExp e1) (desugarExp e2)
EAdd e1 e2 -> EAdd (desugarExp e1) (desugarExp e2) EAdd e1 e2 -> EAdd (desugarExp e1) (desugarExp e2)
EAbs i e -> EAbs i (desugarExp e) EAbs i e -> EAbs i (desugarExp e)
-- EAbsS pat e -> EAbs (LIdent "$zz$") (ECase (EVar "$zz$") [Branch (desugarPattern pat) (desugarExp e)]) -- EAbsS pat e -> EAbs (LIdent "$zz$") (ECase (EVar "$zz$") [Branch (desugarPattern pat) (desugarExp e)])
ELet b e -> ELet (desugarBind b) (desugarExp e) ELet b e -> ELet (desugarBind b) (desugarExp e)
ECase e br -> ECase (desugarExp e) (map desugarBranch br) ECase e br -> ECase (desugarExp e) (map desugarBranch br)
EAnn e t -> EAnn (desugarExp e) (desugarType t) EAnn e t -> EAnn (desugarExp e) (desugarType t)
EVarS (VSymbol (Symbol symb)) -> EVar (LIdent $ fixName symb) EVarS (VSymbol (Symbol symb)) -> EVar (LIdent $ fixName symb)
EVarS (VIdent (LIdent ident)) -> EVar $ LIdent $ fixName ident EVarS (VIdent (LIdent ident)) -> EVar $ LIdent $ fixName ident
EVar (LIdent i) -> EVar (LIdent $ fixName i) EVar (LIdent i) -> EVar (LIdent $ fixName i)
ELit (LString str) -> toList str ELit (LString str) -> toList str
ELit l -> ELit l ELit l -> ELit l
EInj i -> EInj i EInj i -> EInj i
toList :: String -> Exp toList :: String -> Exp
toList = foldr (EApp . EApp (EInj (UIdent "Cons")) . ELit . LChar) (EInj (UIdent "Nil")) toList = foldr (EApp . EApp (EInj (UIdent "Cons")) . ELit . LChar) (EInj (UIdent "Nil"))
@ -84,14 +84,14 @@ desugarBranch (Branch p e) = Branch (desugarPattern p) (desugarExp e)
desugarPattern :: Pattern -> Pattern desugarPattern :: Pattern -> Pattern
desugarPattern = \case desugarPattern = \case
PVar ident -> PVar ident PVar ident -> PVar ident
PLit lit -> PLit (desugarLit lit) PLit lit -> PLit (desugarLit lit)
PCatch -> PCatch PCatch -> PCatch
PEnum ident -> PEnum ident PEnum ident -> PEnum ident
PInj ident patterns -> PInj ident (map desugarPattern patterns) PInj ident patterns -> PInj ident (map desugarPattern patterns)
desugarLit :: Lit -> Lit desugarLit :: Lit -> Lit
desugarLit (LInt i) = LInt i desugarLit (LInt i) = LInt i
desugarLit (LChar c) = LChar c desugarLit (LChar c) = LChar c
desugarLit (LString c) = LString c desugarLit (LString c) = LString c
@ -120,4 +120,4 @@ fixName = concatMap mapSymbols
':' -> "$semicolon$" ':' -> "$semicolon$"
'[' -> "$lbracket$" '[' -> "$lbracket$"
']' -> "$rbracket$" ']' -> "$rbracket$"
c -> c : "" c -> c : ""

View file

@ -1,33 +1,32 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-} {-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QualifiedDo #-} {-# LANGUAGE QualifiedDo #-}
{-# OPTIONS_GHC -Wno-incomplete-patterns #-} {-# OPTIONS_GHC -Wno-incomplete-patterns #-}
-- | A module for type checking and inference using algorithm W, Hindley-Milner -- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeCheckerHm where module TypeChecker.TypeCheckerHm where
import Auxiliary (int, maybeToRightM, typeof, unzip4) import Auxiliary (int, maybeToRightM, typeof, unzip4)
import qualified Auxiliary as Aux import Auxiliary qualified as Aux
import Control.Monad.Except import Control.Monad.Except
import Control.Monad.Identity (Identity, runIdentity) import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader import Control.Monad.Reader
import Control.Monad.State import Control.Monad.State
import Control.Monad.Writer import Control.Monad.Writer
import Data.Coerce (coerce) import Data.Coerce (coerce)
import Data.Function (on) import Data.Function (on)
import Data.List (foldl', nub, sortOn) import Data.List (foldl', nub, sortOn)
import Data.List.Extra (unsnoc) import Data.List.Extra (unsnoc)
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as M import Data.Map qualified as M
import Data.Maybe (fromJust) import Data.Maybe (fromJust)
import Data.Set (Set) import Data.Set (Set)
import qualified Data.Set as S import Data.Set qualified as S
import Debug.Trace (trace, traceShow) import Grammar.Abs
import Grammar.Abs import Grammar.Print (printTree)
import Grammar.Print (printTree) import TypeChecker.TypeCheckerIr (T, T')
import qualified TypeChecker.TypeCheckerIr as T import TypeChecker.TypeCheckerIr qualified as T
import TypeChecker.TypeCheckerIr (T, T')
{- {-
TODO TODO
@ -42,7 +41,7 @@ typecheck :: Program -> Either String (T.Program' Type, [Warning])
typecheck = onLeft msg . run . checkPrg typecheck = onLeft msg . run . checkPrg
where where
onLeft :: (Error -> String) -> Either Error a -> Either String a onLeft :: (Error -> String) -> Either Error a -> Either String a
onLeft f (Left x) = Left $ f x onLeft f (Left x) = Left $ f x
onLeft _ (Right x) = Right x onLeft _ (Right x) = Right x
checkPrg :: Program -> Infer (T.Program' Type) checkPrg :: Program -> Infer (T.Program' Type)
@ -69,13 +68,13 @@ prettify s (T.Program defs) = T.Program $ map (go s) defs
replace :: Map T.Ident T.Ident -> Type -> Type replace :: Map T.Ident T.Ident -> Type -> Type
replace m def@(TVar (MkTVar (LIdent a))) = case M.lookup (coerce a) m of replace m def@(TVar (MkTVar (LIdent a))) = case M.lookup (coerce a) m of
Just t -> TVar . MkTVar . LIdent $ coerce t Just t -> TVar . MkTVar . LIdent $ coerce t
Nothing -> def Nothing -> def
replace m (TFun t1 t2) = (TFun `on` replace m) t1 t2 replace m (TFun t1 t2) = (TFun `on` replace m) t1 t2
replace m (TData name ts) = TData name (map (replace m) ts) replace m (TData name ts) = TData name (map (replace m) ts)
replace m def@(TAll (MkTVar forall_) t) = case M.lookup (coerce forall_) m of replace m def@(TAll (MkTVar forall_) t) = case M.lookup (coerce forall_) m of
Just found -> TAll (MkTVar $ coerce found) (replace m t) Just found -> TAll (MkTVar $ coerce found) (replace m t)
Nothing -> def Nothing -> def
replace _ t = t replace _ t = t
bindCount :: [Def] -> Infer [(Int, Def)] bindCount :: [Def] -> Infer [(Int, Def)]
@ -129,7 +128,7 @@ preRun (x : xs) = case x of
s <- gets sigs s <- gets sigs
case M.lookup (coerce n) s of case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs Just _ -> preRun xs
DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs
where where
-- Check if function body / signature has been declared already -- Check if function body / signature has been declared already
@ -151,11 +150,11 @@ checkDef (x : xs) = case x of
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
freeOrdered :: Type -> [T.Ident] freeOrdered :: Type -> [T.Ident]
freeOrdered (TVar (MkTVar a)) = return (coerce a) freeOrdered (TVar (MkTVar a)) = return (coerce a)
freeOrdered (TAll (MkTVar bound) t) = return (coerce bound) ++ freeOrdered t freeOrdered (TAll (MkTVar bound) t) = return (coerce bound) ++ freeOrdered t
freeOrdered (TFun a b) = freeOrdered a ++ freeOrdered b freeOrdered (TFun a b) = freeOrdered a ++ freeOrdered b
freeOrdered (TData _ a) = concatMap freeOrdered a freeOrdered (TData _ a) = concatMap freeOrdered a
freeOrdered _ = mempty freeOrdered _ = mempty
-- Much cleaner implementation, unfortunately one minor bug -- Much cleaner implementation, unfortunately one minor bug
-- checkBind :: Bind -> Infer (T.Bind' Type) -- checkBind :: Bind -> Infer (T.Bind' Type)
@ -193,14 +192,11 @@ checkBind (Bind name args e) = do
s <- gets sigs s <- gets sigs
case M.lookup (coerce name) s of case M.lookup (coerce name) s of
Just (Just typSig) -> do Just (Just typSig) -> do
env <- asks vars
let genInfSig = generalize mempty infSig let genInfSig = generalize mempty infSig
trace "\n\n" pure ()
trace ("genInfSig: " ++ printTree genInfSig) pure ()
trace ("typSig: " ++ printTree typSig ++ "\n\n") pure ()
sub <- genInfSig `unify` typSig sub <- genInfSig `unify` typSig
--b <- (genInfSig <<= typSig) b <- genInfSig <<= typSig
unless True unless
b
( throwError $ ( throwError $
Error Error
( Aux.do ( Aux.do
@ -231,7 +227,7 @@ checkData err@(Data typ injs) = do
pure (name, tvars') pure (name, tvars')
_ -> _ ->
uncatchableErr $ uncatchableErr $
unwords ["Bad data type definition: ", printTree typ] unwords ["Bad data type definition: ", show typ]
checkInj :: (MonadError Error m, MonadState Env m, Monad m) => Inj -> UIdent -> [TVar] -> m () checkInj :: (MonadError Error m, MonadState Env m, Monad m) => Inj -> UIdent -> [TVar] -> m ()
checkInj (Inj c inj_typ) name tvars checkInj (Inj c inj_typ) name tvars
@ -262,11 +258,11 @@ checkInj (Inj c inj_typ) name tvars
toTVar :: Type -> Either Error TVar toTVar :: Type -> Either Error TVar
toTVar = \case toTVar = \case
TVar tvar -> pure tvar TVar tvar -> pure tvar
_ -> uncatchableErr "Not a type variable" _ -> uncatchableErr "Not a type variable"
returnType :: Type -> Type returnType :: Type -> Type
returnType (TFun _ t2) = returnType t2 returnType (TFun _ t2) = returnType t2
returnType a = a returnType a = a
inferExp :: Exp -> Infer (T' T.Exp' Type) inferExp :: Exp -> Infer (T' T.Exp' Type)
inferExp e = do inferExp e = do
@ -279,7 +275,7 @@ class CollectTVars a where
instance CollectTVars Exp where instance CollectTVars Exp where
collectTVars (EAnn e t) = collectTVars t `S.union` collectTVars e collectTVars (EAnn e t) = collectTVars t `S.union` collectTVars e
collectTVars _ = S.empty collectTVars _ = S.empty
instance CollectTVars Type where instance CollectTVars Type where
collectTVars (TVar (MkTVar i)) = S.singleton (coerce i) collectTVars (TVar (MkTVar i)) = S.singleton (coerce i)
@ -296,17 +292,17 @@ algoW :: Exp -> Infer (Subst, T' T.Exp' Type)
algoW = \case algoW = \case
err@(EAnn e t) -> do err@(EAnn e t) -> do
(sub0, (e', t')) <- exprErr (algoW e) err (sub0, (e', t')) <- exprErr (algoW e) err
sub1 <- unify t t' sub1 <- unify t' t
sub2 <- unify t' t b <- t' <<= t
b <- (apply sub1 t <<= apply sub2 t') unless
unless b b
( uncatchableErr $ Aux.do ( uncatchableErr $ Aux.do
"Annotated type" "Annotated type"
quote $ printTree t quote $ printTree t
"does not match inferred type" "does not match inferred type"
quote $ printTree t' quote $ printTree t'
) )
let comp = sub2 `compose` sub1 `compose` sub0 let comp = sub1 `compose` sub0
return (comp, (apply comp e', t)) return (comp, (apply comp e', t))
-- \| ------------------ -- \| ------------------
@ -605,12 +601,12 @@ generalize :: Map T.Ident Type -> Type -> Type
generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t) generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t)
where where
go :: [T.Ident] -> Type -> Type go :: [T.Ident] -> Type -> Type
go [] t = t go [] t = t
go (x : xs) t = TAll (MkTVar (coerce x)) (go xs t) go (x : xs) t = TAll (MkTVar (coerce x)) (go xs t)
removeForalls :: Type -> Type removeForalls :: Type -> Type
removeForalls (TAll _ t) = removeForalls t removeForalls (TAll _ t) = removeForalls t
removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2) removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2)
removeForalls t = t removeForalls t = t
{- | Instantiate a polymorphic type. The free type variables are substituted {- | Instantiate a polymorphic type. The free type variables are substituted
with fresh ones. with fresh ones.
@ -640,11 +636,10 @@ fresh = do
modify (\st -> st{count = succ (count st)}) modify (\st -> st{count = succ (count st)})
return $ TVar $ MkTVar $ LIdent $ show n return $ TVar $ MkTVar $ LIdent $ show n
-- Is the left a subtype of the right -- Is the left more general than the right
(<<=) :: Type -> Type -> Infer Bool (<<=) :: Type -> Type -> Infer Bool
(<<=) a b = case (a, b) of (<<=) a b = case (a, b) of
(TVar a, TVar b) -> return $ a == b (TVar _, _) -> return True
(TVar a, _) -> return True
(TFun a b, TFun c d) -> do (TFun a b, TFun c d) -> do
bfirst <- a <<= c bfirst <- a <<= c
bsecond <- b <<= d bsecond <- b <<= d
@ -652,37 +647,43 @@ fresh = do
(TData n1 ts1, TData n2 ts2) -> do (TData n1 ts1, TData n2 ts2) -> do
b <- and <$> zipWithM (<<=) ts1 ts2 b <- and <$> zipWithM (<<=) ts1 ts2
return (b && n1 == n2 && length ts1 == length ts2) return (b && n1 == n2 && length ts1 == length ts2)
(t1@(TAll _ _ ), t2) -> let (tvars1, t1') = gatherTVars [] t1 (t1@(TAll _ _), t2) ->
(tvars2, t2') = gatherTVars [] t2 let (tvars1, t1') = gatherTVars [] t1
in go (tvars1 ++ tvars2) t1 t2 (tvars2, t2') = gatherTVars [] t2
(t1, t2@(TAll _ _)) -> let (tvars1, t1') = gatherTVars [] t1 in go (tvars1 ++ tvars2) t1' t2'
(tvars2, t2') = gatherTVars [] t2 (t1, t2@(TAll _ _)) ->
in go (tvars1 ++ tvars2) t1' t2' let (tvars1, t1') = gatherTVars [] t1
(tvars2, t2') = gatherTVars [] t2
in go (tvars1 ++ tvars2) t1' t2'
(t1, t2) -> return $ t1 == t2 (t1, t2) -> return $ t1 == t2
where where
go :: [TVar] -> Type -> Type -> Infer Bool go :: [TVar] -> Type -> Type -> Infer Bool
go tvars t1 t2 = do go tvars t1 t2 = do
-- probably not necessary
freshies <- mapM (const fresh) tvars freshies <- mapM (const fresh) tvars
let sub = M.fromList $ zip [coerce x | (MkTVar x) <- tvars] freshies let sub = M.fromList $ zip [coerce x | (MkTVar x) <- tvars] freshies
let t1' = apply sub t1 let t1' = apply sub t1
let t2' = apply sub t2 let t2' = apply sub t2
trace ("t1': " ++ printTree t1') pure () let alph = execState (alpha t1' t2') mempty
trace ("t2': " ++ printTree t2') pure () return $ apply alph t1' == t2'
t1' <<= t2'
{-
Renaming: a -> b -> a and c -> d -> c
gives 0 -> 1 -> 0 and -> 2 -> 3 -> 2
They have to be given the same name. Alpha-renaming in the subtype check is done incorrectly
-}
-- Pre-condition: All TAlls are outermost -- Pre-condition: All TAlls are outermost
gatherTVars :: [TVar] -> Type -> ([TVar], Type) gatherTVars :: [TVar] -> Type -> ([TVar], Type)
gatherTVars tvars (TAll tvar t) = gatherTVars tvars (TAll tvar t) = gatherTVars (tvar : tvars) t
let (tvars', t') = gatherTVars (tvar : tvars) t
in (tvars', t')
gatherTVars tvars t = (tvars, t) gatherTVars tvars t = (tvars, t)
-- Alpha rename the first type's type variable to match second.
-- Pre-condition: No TAll are checked
alpha :: Type -> Type -> State (Map T.Ident Type) ()
alpha t1 t2 = case (t1, t2) of
(TVar (MkTVar (LIdent i)), t2) -> do
m <- get
put (M.insert (coerce i) t2 m)
(TFun t1 t2, TFun t3 t4) -> do
alpha t1 t3
alpha t2 t4
(TData _ ts1, TData _ ts2) -> zipWithM_ alpha ts1 ts2
_ -> return ()
-- | A class for substitutions -- | A class for substitutions
class SubstType t where class SubstType t where
@ -716,15 +717,15 @@ instance SubstType Type where
TLit _ -> t TLit _ -> t
TVar (MkTVar a) -> case M.lookup (coerce a) sub of TVar (MkTVar a) -> case M.lookup (coerce a) sub of
Nothing -> TVar (MkTVar $ coerce a) Nothing -> TVar (MkTVar $ coerce a)
Just t -> t Just t -> t
TAll (MkTVar i) t -> case M.lookup (coerce i) sub of TAll (MkTVar i) t -> case M.lookup (coerce i) sub of
Nothing -> TAll (MkTVar i) (apply sub t) Nothing -> TAll (MkTVar i) (apply sub t)
Just _ -> apply sub t Just _ -> apply sub t
TFun a b -> TFun (apply sub a) (apply sub b) TFun a b -> TFun (apply sub a) (apply sub b)
TData name a -> TData name (apply sub a) TData name a -> TData name (apply sub a)
TEVar (MkTEVar a) -> case M.lookup (coerce a) sub of TEVar (MkTEVar a) -> case M.lookup (coerce a) sub of
Nothing -> TEVar (MkTEVar $ coerce a) Nothing -> TEVar (MkTEVar $ coerce a)
Just t -> t Just t -> t
instance FreeVars (Map T.Ident Type) where instance FreeVars (Map T.Ident Type) where
free :: Map T.Ident Type -> Set T.Ident free :: Map T.Ident Type -> Set T.Ident
@ -766,10 +767,10 @@ instance SubstType (T.Branch' Type) where
instance SubstType (T.Pattern' Type) where instance SubstType (T.Pattern' Type) where
apply s = \case apply s = \case
T.PVar iden -> T.PVar iden T.PVar iden -> T.PVar iden
T.PLit lit -> T.PLit lit T.PLit lit -> T.PLit lit
T.PInj i ps -> T.PInj i $ apply s ps T.PInj i ps -> T.PInj i $ apply s ps
T.PCatch -> T.PCatch T.PCatch -> T.PCatch
T.PEnum i -> T.PEnum i T.PEnum i -> T.PEnum i
instance SubstType (T.Pattern' Type, Type) where instance SubstType (T.Pattern' Type, Type) where
apply s (p, t) = (apply s p, apply s t) apply s (p, t) = (apply s p, apply s t)
@ -810,11 +811,11 @@ withBindings xs =
-- | Run the monadic action with a pattern -- | Run the monadic action with a pattern
withPattern :: (Monad m, MonadReader Ctx m) => (T.Pattern' Type, Type) -> m a -> m a withPattern :: (Monad m, MonadReader Ctx m) => (T.Pattern' Type, Type) -> m a -> m a
withPattern (p, t) ma = case p of withPattern (p, t) ma = case p of
T.PVar x -> withBinding x t ma T.PVar x -> withBinding x t ma
T.PInj _ ps -> foldl' (flip withPattern) ma ps T.PInj _ ps -> foldl' (flip withPattern) ma ps
T.PLit _ -> ma T.PLit _ -> ma
T.PCatch -> ma T.PCatch -> ma
T.PEnum _ -> ma T.PEnum _ -> ma
-- | Insert a function signature into the environment -- | Insert a function signature into the environment
insertSig :: T.Ident -> Maybe Type -> Infer () insertSig :: T.Ident -> Maybe Type -> Infer ()
@ -839,11 +840,11 @@ existInj n = gets (M.lookup n . injections)
flattenType :: Type -> [Type] flattenType :: Type -> [Type]
flattenType (TFun a b) = flattenType a <> flattenType b flattenType (TFun a b) = flattenType a <> flattenType b
flattenType a = [a] flattenType a = [a]
typeLength :: Type -> Int typeLength :: Type -> Int
typeLength (TFun _ b) = 1 + typeLength b typeLength (TFun _ b) = 1 + typeLength b
typeLength _ = 1 typeLength _ = 1
{- | Catch an error if possible and add the given {- | Catch an error if possible and add the given
expression as addition to the error message expression as addition to the error message
@ -926,11 +927,11 @@ newtype Ctx = Ctx {vars :: Map T.Ident Type}
deriving (Show) deriving (Show)
data Env = Env data Env = Env
{ count :: Int { count :: Int
, nextChar :: Char , nextChar :: Char
, sigs :: Map T.Ident (Maybe Type) , sigs :: Map T.Ident (Maybe Type)
, takenTypeVars :: Set T.Ident , takenTypeVars :: Set T.Ident
, injections :: Map T.Ident Type , injections :: Map T.Ident Type
, declaredBinds :: Set T.Ident , declaredBinds :: Set T.Ident
} }
deriving (Show) deriving (Show)
@ -956,12 +957,3 @@ quote s = "'" ++ s ++ "'"
letters :: [T.Ident] letters :: [T.Ident]
letters = map T.Ident $ [1 ..] >>= flip replicateM ['a' .. 'z'] letters = map T.Ident $ [1 ..] >>= flip replicateM ['a' .. 'z']
{-
first = TAll (MkTVar (LIdent "a")) (TAll (MkTVar (LIdent "b")) (TFun (TVar (MkTVar (LIdent "a"))) (TFun (TVar (MkTVar (LIdent "b"))) (TVar (MkTVar (LIdent "b"))))))
second = TAll (MkTVar (LIdent "a")) (TAll (MkTVar (LIdent "b")) (TFun (TVar (MkTVar (LIdent "a"))) (TFun (TVar (MkTVar (LIdent "b"))) (TVar (MkTVar (LIdent "a"))))))
-}

View file

@ -20,7 +20,6 @@ import TypeChecker.TypeCheckerIr (Program)
testTypeCheckerHm = describe "Hindley-Milner type checker test" $ do testTypeCheckerHm = describe "Hindley-Milner type checker test" $ do
sequence_ goods sequence_ goods
sequence_ bads sequence_ bads
sequence_ bes
goods = goods =
[ testSatisfy [ testSatisfy
@ -55,6 +54,35 @@ goods =
"};" "};"
) )
ok ok
, testSatisfy
"A basic arithmetic function should be able to be inferred"
( D.do
"plusOne x = x + 1 ;"
"main x = plusOne x ;"
)
ok
, testSatisfy
"List of function Int -> Int functions should be inferred corretly"
( D.do
_List
"main xs = case xs of {"
" Cons f _ => f 1 ;"
" Nil => 0 ;"
" };"
)
ok
, testSatisfy
"length function on int list infers correct signature"
( D.do
"data List where "
" Nil : List"
" Cons : Int -> List -> List"
"length xs = case xs of"
" Nil => 0"
" Cons _ xs => 1 + length xs"
)
ok
] ]
bads = bads =
@ -121,97 +149,38 @@ bads =
" };" " };"
) )
bad bad
-- FIXME FAILING TEST , -- FIXME FAILING TEST
-- , testSatisfy testSatisfy
-- "id with incorrect signature" "id with incorrect signature"
-- ( D.do
-- "id : a -> b;"
-- "id x = x;"
-- )
-- bad
-- FIXME FAILING TEST
-- , testSatisfy
-- "incorrect signature on const"
-- ( D.do
-- "const : a -> b -> b;"
-- "const x y = x"
-- )
-- bad
-- FIXME FAILING TEST
-- , testSatisfy
-- "incorrect type signature on id lambda"
-- ( D.do
-- "id = ((\\x. x) : a -> b);"
-- )
-- bad
]
bes =
[ testBe
"A basic arithmetic function should be able to be inferred"
( D.do ( D.do
"plusOne x = x + 1 ;" "id : a -> b;"
"main x = plusOne x ;" "id x = x;"
) )
bad
, -- FIXME FAILING TEST
testSatisfy
"incorrect signature on const"
( D.do ( D.do
"plusOne : Int -> Int ;" "const : a -> b -> b;"
"plusOne x = x + 1 ;" "const x y = x"
"main : Int -> Int ;"
"main x = plusOne x ;"
) )
, testBe bad
"A basic arithmetic function should be able to be inferred" , -- FIXME FAILING TEST
testSatisfy
"incorrect type signature on id lambda"
( D.do ( D.do
"plusOne x = x + 1 ;" "id = ((\\x. x) : a -> b);"
)
( D.do
"plusOne : Int -> Int ;"
"plusOne x = x + 1 ;"
)
, testBe
"List of function Int -> Int functions should be inferred corretly"
( D.do
_List
"main xs = case xs of {"
" Cons f _ => f 1 ;"
" Nil => 0 ;"
" };"
)
( D.do
_List
"main : List (Int -> Int) -> Int ;"
"main xs = case xs of {"
" Cons f _ => f 1 ;"
" Nil => 0 ;"
" };"
)
, testBe
"length function on int list infers correct signature"
( D.do
"data List where "
" Nil : List"
" Cons : Int -> List -> List"
"length xs = case xs of"
" Nil => 0"
" Cons _ xs => 1 + length xs"
)
( D.do
"data List where"
" Nil : List"
" Cons : Int -> List -> List"
"length : List -> Int"
"length xs = case xs of"
" Nil => 0"
" Cons _ xs => 1 + length xs"
) )
bad
] ]
testSatisfy desc test satisfaction = specify desc $ run test `shouldSatisfy` satisfaction testSatisfy desc test satisfaction = specify desc $ run test `shouldSatisfy` satisfaction
testBe desc test shouldbe = specify desc $ run test `shouldBe` run shouldbe testBe desc test shouldbe = specify desc $ run test `shouldBe` run shouldbe
run = fmap (printTree . fst) . typecheck <=< fmap desugar . pProgram . myLexer run s = do
p <- (fmap desugar . pProgram . resolveLayout True . myLexer) s
reportForall Hm p
(printTree . fst) <$> (typecheck <=< rename <=< annotateForall) p
ok (Right _) = True ok (Right _) = True
ok (Left _) = False ok (Left _) = False
@ -221,45 +190,37 @@ bad = not . ok
-- FUNCTIONS -- FUNCTIONS
_const = D.do _const = D.do
"const : a -> b -> a ;" "const : a -> b -> a"
"const x y = x ;" "const x y = x"
_List = D.do _List = D.do
"data List a where {" "data List a where { Nil : List a; Cons : a -> List a -> List a; }"
" Nil : List a;"
" Cons : a -> List a -> List a;"
"};"
_headSig = D.do _headSig = D.do
"head : List a -> a ;" "head : List a -> a"
_head = D.do _head = D.do
"head xs = " "head xs = case xs of"
" case xs of {" " Cons x xs => x"
" Cons x xs => x ;"
" };"
_Bool = D.do _Bool = D.do
"data Bool where {" "data Bool where"
" True : Bool" " True : Bool"
" False : Bool" " False : Bool"
"};"
_not = D.do _not = D.do
"not : Bool -> Bool ;" "not : Bool -> Bool ;"
"not x = case x of {" "not x = case x of"
" True => False ;" " True => False"
" False => True ;" " False => True"
"};"
_id = "id x = x ;" _id = "id x = x ;"
_Maybe = D.do _Maybe = D.do
"data Maybe a where {" "data Maybe a where"
" Nothing : Maybe a" " Nothing : Maybe a"
" Just : a -> Maybe a" " Just : a -> Maybe a"
" };"
_fmap = D.do _fmap = D.do
"fmap f ma = case ma of {" "fmap f ma = case ma of"
" Nothing => Nothing ;" " Nothing => Nothing"
" Just a => Just (f a) ;" " Just a => Just (f a)"
"};"