fixed a substitution bug where ap was incorrectly inferred.

also added cleaner fresh variables
This commit is contained in:
sebastian 2023-03-25 22:40:15 +01:00
parent 975dd34063
commit ac43af8110
6 changed files with 287 additions and 194 deletions

View file

@ -19,6 +19,7 @@ import Data.Map qualified as M
import Data.Maybe (fromJust)
import Data.Set (Set)
import Data.Set qualified as S
import Debug.Trace (trace)
import Grammar.Abs
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr (
@ -31,8 +32,7 @@ import TypeChecker.TypeCheckerIr (
import TypeChecker.TypeCheckerIr qualified as T
initCtx = Ctx mempty
initEnv = Env 0 mempty mempty
initEnv = Env 0 'a' mempty mempty mempty
runPretty :: Exp -> Either Error String
runPretty = fmap (printTree . fst) . run . inferExp
@ -82,39 +82,39 @@ retType a = a
checkPrg :: Program -> Infer T.Program
checkPrg (Program bs) = do
preRun bs
-- Type check the program twice to produce all top-level types in the first pass through
_ <- checkDef bs
bs'' <- checkDef bs
return $ T.Program bs''
where
preRun :: [Def] -> Infer ()
preRun [] = return ()
preRun (x : xs) = case x of
DSig (Sig n t) -> do
gets (M.member (coerce n) . sigs)
>>= flip
when
( throwError $
"Duplicate signatures for function '"
<> printTree n
<> "'"
)
insertSig (coerce n) (Just $ toNew t) >> preRun xs
DBind (Bind n _ _) -> do
s <- gets sigs
case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs
DData d@(Data _ _) -> checkData d >> preRun xs
bs' <- checkDef bs
return $ T.Program bs'
checkDef :: [Def] -> Infer [T.Def]
checkDef [] = return []
checkDef (x : xs) = case x of
(DBind b) -> do
b' <- checkBind b
fmap (T.DBind b' :) (checkDef xs)
(DData d) -> fmap (T.DData (toNew d) :) (checkDef xs)
(DSig _) -> checkDef xs
preRun :: [Def] -> Infer ()
preRun [] = return ()
preRun (x : xs) = case x of
DSig (Sig n t) -> do
collect (collectTypeVars t)
gets (M.member (coerce n) . sigs)
>>= flip
when
( throwError $
"Duplicate signatures for function '"
<> printTree n
<> "'"
)
insertSig (coerce n) (Just $ toNew t) >> preRun xs
DBind (Bind n _ e) -> do
collect (collectTypeVars e)
s <- gets sigs
case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs
DData d@(Data t _) -> collect (collectTypeVars t) >> checkData d >> preRun xs
checkDef :: [Def] -> Infer [T.Def]
checkDef [] = return []
checkDef (x : xs) = case x of
(DBind b) -> do
b' <- checkBind b
fmap (T.DBind b' :) (checkDef xs)
(DData d) -> fmap (T.DData (toNew d) :) (checkDef xs)
(DSig _) -> checkDef xs
checkBind :: Bind -> Infer T.Bind
checkBind (Bind name args e) = do
@ -171,6 +171,23 @@ inferExp e = do
let subbed = apply s t
return $ second (const subbed) (e', t)
class CollectTVars a where
collectTypeVars :: a -> Set T.Ident
instance CollectTVars Exp where
collectTypeVars (EAnn e t) = collectTypeVars t `S.union` collectTypeVars e
collectTypeVars _ = S.empty
instance CollectTVars Type where
collectTypeVars (TVar (MkTVar i)) = S.singleton (coerce i)
collectTypeVars (TAll _ t) = collectTypeVars t
collectTypeVars (TFun t1 t2) = collectTypeVars t1 `S.union` collectTypeVars t2
collectTypeVars (TData _ ts) = foldl' (\acc x -> acc `S.union` collectTypeVars x) S.empty ts
collectTypeVars _ = S.empty
collect :: Set T.Ident -> Infer ()
collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st})
class NewType a b where
toNew :: a -> b
@ -321,8 +338,9 @@ algoW = \case
(sub, (e', t)) <- algoW caseExpr
(subst, injs, ret_t) <- checkCase t injs
let comp = subst `compose` sub
let t' = apply comp ret_t
return (comp, apply comp (T.ECase (e', t) injs, t'))
trace ("EXPR: " ++ show (apply comp t)) pure ()
trace ("CASES: " ++ show (apply comp ret_t)) pure ()
return (comp, apply comp (T.ECase (e', t) injs, ret_t))
makeLambda :: Exp -> [T.Ident] -> Exp
makeLambda = foldl (flip (EAbs . coerce))
@ -335,7 +353,7 @@ unify t0 t1 = do
s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d)
return $ s1 `compose` s2
----------- TODO: CAREFUL!!!! THIS IS PROBABLY WRONG!!! -----------
----------- TODO: BE CAREFUL!!!! THIS IS PROBABLY WRONG!!! -----------
(T.TVar (T.MkTVar a), t@(T.TData _ _)) -> return $ M.singleton a t
(t@(T.TData _ _), T.TVar (T.MkTVar b)) -> return $ M.singleton b t
-------------------------------------------------------------------
@ -517,9 +535,24 @@ nullSubst = M.empty
-- | Generate a new fresh variable and increment the state counter
fresh :: Infer T.Type
fresh = do
c <- gets nextChar
n <- gets count
modify (\st -> st{count = n + 1})
return . T.TVar . T.MkTVar . T.Ident $ show n
taken <- gets takenTypeVars
if c == 'z'
then do
modify (\st -> st{count = succ (count st), nextChar = 'a'})
else modify (\st -> st{nextChar = next (nextChar st)})
if coerce [c] `S.member` taken
then do
fresh
else
if n == 0
then return . T.TVar . T.MkTVar . T.Ident $ [c]
else return . T.TVar . T.MkTVar . T.Ident $ [c] ++ show n
next :: Char -> Char
next 'z' = 'a'
next a = succ a
-- | Run the monadic action with an additional binding
withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> T.Type -> m a -> m a
@ -567,7 +600,7 @@ inferBranch :: Branch -> Infer (Subst, T.Type, T.Branch, T.Type)
inferBranch (Branch pat expr) = do
newPat@(pat, branchT) <- inferPattern pat
(sub, newExp@(_, exprT)) <- withPattern pat (algoW expr)
return (sub, branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT)
return (sub, apply sub branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT)
withPattern :: T.Pattern -> Infer a -> Infer a
withPattern p ma = case p of
@ -586,15 +619,36 @@ inferPattern = \case
let numArgs = typeLength t - 1
let (vs, ret) = fromJust (unsnoc $ flattenType t)
patterns <- mapM inferPattern patterns
unless (length patterns == numArgs) (throwError $ "The constructor '" ++ printTree constr ++ "'" ++ " should have " ++ show numArgs ++ " arguments but has been given " ++ show (length patterns))
unless
(length patterns == numArgs)
( throwError $
"The constructor '"
++ printTree constr
++ "'"
++ " should have "
++ show numArgs
++ " arguments but has been given "
++ show (length patterns)
)
sub <- composeAll <$> zipWithM unify vs (map snd patterns)
return (T.PInj (coerce constr) (map fst patterns), apply sub ret)
PCatch -> (T.PCatch,) <$> fresh
PEnum p -> do
t <- gets (M.lookup (coerce p) . constructors)
t <- maybeToRightM ("Constructor: " <> printTree p <> " does not exist") t
unless (typeLength t == 1) (throwError $ "The constructor '" ++ printTree p ++ "'" ++ " should have " ++ show (typeLength t - 1) ++ " arguments but has been given 0")
return (T.PEnum $ coerce p, t)
unless
(typeLength t == 1)
( throwError $
"The constructor '"
++ printTree p
++ "'"
++ " should have "
++ show (typeLength t - 1)
++ " arguments but has been given 0"
)
let (T.TData _data _ts) = t -- nasty nasty
frs <- mapM (const fresh) _ts
return (T.PEnum $ coerce p, T.TData _data frs)
PVar x -> do
fr <- fresh
let pvar = T.PVar (coerce x, fr)
@ -632,4 +686,9 @@ exprErr ma exp =
catchError ma (\x -> throwError $ x <> " in expression: \n" <> printTree exp)
unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d])
unzip4 = foldl' (\(as, bs, cs, ds) (a, b, c, d) -> (as ++ [a], bs ++ [b], cs ++ [c], ds ++ [d])) ([], [], [], [])
unzip4 =
foldl'
( \(as, bs, cs, ds) (a, b, c, d) ->
(as ++ [a], bs ++ [b], cs ++ [c], ds ++ [d])
)
([], [], [], [])