fixed bug where bound variable didn't exist in case

This commit is contained in:
sebastianselander 2023-03-06 11:27:17 +01:00
parent 778fec3dc4
commit 9c2f52f8bb
3 changed files with 79 additions and 60 deletions

View file

@ -1,9 +1,7 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# HLINT ignore "Use mapAndUnzipM" #-}
{-# OPTIONS_GHC -Wno-overlapping-patterns #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# OPTIONS_GHC -Wno-unused-matches #-}
-- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeChecker where
@ -73,9 +71,7 @@ freshenData (Data (Constr name ts) constrs) = do
checkData :: Data -> Infer ()
checkData d = do
trace ("OLD: " ++ show d) return ()
d' <- freshenData d
trace ("NEW: " ++ show d') return ()
case d' of
(Data typ@(Constr name ts) constrs) -> do
unless
@ -249,7 +245,11 @@ algoW = \case
-- applySt s2 $ do
s3 <- unify (apply s2 t0) (TMono "Int")
s4 <- unify (apply s3 t1) (TMono "Int")
return (s4 `compose` s3 `compose` s2 `compose` s1, TMono "Int", T.EAdd (TMono "Int") e0' e1')
return
( s4 `compose` s3 `compose` s2 `compose` s1
, TMono "Int"
, T.EAdd (TMono "Int") e0' e1'
)
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1
-- \| τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ')
@ -280,7 +280,7 @@ algoW = \case
(s2, t2, e1') <- algoW e1
return (s2 `compose` s1, t2, T.ELet (T.Bind (name, t2) e0') e1')
ECase caseExpr injs -> do
(s0, t0, e0') <- algoW caseExpr
(_, t0, e0') <- algoW caseExpr
(injs', ts) <- unzip <$> mapM (checkInj t0) injs
case ts of
[] -> throwError "Case expression missing any matches"
@ -292,14 +292,15 @@ algoW = \case
-- | Unify two types producing a new substitution
unify :: Type -> Type -> Infer Subst
unify t0 t1 = case (trace ("LEFT: " ++ show t0) t0, trace ("RIGHT: " ++ show t1) t1) of
unify t0 t1 = case (t0, t1) of
(TArr a b, TArr c d) -> do
s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d)
return $ s1 `compose` s2
(TPol a, b) -> occurs a b
(a, TPol b) -> occurs b a
(TMono a, TMono b) -> if a == b then return M.empty else throwError "Types do not unify"
(TMono a, TMono b) ->
if a == b then return M.empty else throwError "Types do not unify"
-- \| TODO: Figure out a cleaner way to express the same thing
(TConstr (Constr name t), TConstr (Constr name' t')) ->
if name == name' && length t == length t'
@ -316,10 +317,17 @@ unify t0 t1 = case (trace ("LEFT: " ++ show t0) t0, trace ("RIGHT: " ++ show t1)
, printTree name'
, "(" ++ printTree t' ++ ")"
]
(a, b) -> throwError . unwords $ ["Type:", printTree a, "can't be unified with:", printTree b]
(a, b) ->
throwError . unwords $
[ "Type:"
, printTree a
, "can't be unified with:"
, printTree b
]
{- | Check if a type is contained in another type.
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution such that these are equal
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
such that these are equal
-}
occurs :: Ident -> Type -> Infer Subst
occurs _ (TPol _) = return nullSubst
@ -339,7 +347,9 @@ occurs i t =
generalize :: Map Ident Poly -> Type -> Poly
generalize env t = Forall (S.toList $ free t S.\\ free env) t
-- | Instantiate a polymorphic type. The free type variables are substituted with fresh ones.
{- | Instantiate a polymorphic type. The free type variables are substituted
with fresh ones.
-}
inst :: Poly -> Infer Type
inst (Forall xs t) = do
xs' <- mapM (const fresh) xs
@ -364,7 +374,8 @@ instance FreeVars Type where
free (TMono _) = mempty
free (TArr a b) = free a `S.union` free b
-- \| Not guaranteed to be correct
free (TConstr (Constr _ a)) = foldl' (\acc x -> free x `S.union` acc) S.empty a
free (TConstr (Constr _ a)) =
foldl' (\acc x -> free x `S.union` acc) S.empty a
apply :: Subst -> Type -> Type
apply sub t = do
@ -413,7 +424,8 @@ insertSig i t = modify (\st -> st {sigs = M.insert i t (sigs st)})
-- | Insert a constructor with its data type
insertConstr :: Ident -> Type -> Infer ()
insertConstr i t = modify (\st -> st {constructors = M.insert i t (constructors st)})
insertConstr i t =
modify (\st -> st {constructors = M.insert i t (constructors st)})
-------- PATTERN MATCHING ---------
@ -421,7 +433,7 @@ insertConstr i t = modify (\st -> st {constructors = M.insert i t (constructors
checkInj :: Type -> Inj -> Infer (T.Inj, Type)
checkInj caseType (Inj it expr) = do
(args, t') <- initType caseType it
(s, t, e') <- local (\st -> st {vars = args}) (algoW expr)
(_, t, e') <- local (\st -> st {vars = args `M.union` vars st}) (algoW expr)
return (T.Inj (it, t') e', t)
initType :: Type -> Init -> Infer (Map Ident Poly, Type)
@ -469,4 +481,4 @@ flattenType (TArr a b) = flattenType a ++ flattenType b
flattenType a = [a]
litType :: Literal -> Type
litType (LInt i) = TMono "Int"
litType (LInt _) = TMono "Int"