fixed bug where bound variable didn't exist in case

This commit is contained in:
sebastianselander 2023-03-06 11:27:17 +01:00
parent 778fec3dc4
commit 9c2f52f8bb
3 changed files with 79 additions and 60 deletions

View file

@ -2,16 +2,20 @@
module Renamer.Renamer where module Renamer.Renamer where
import Auxiliary (mapAccumM) import Auxiliary (mapAccumM)
import Control.Monad.State (MonadState, State, evalState, gets, import Control.Monad.State (
modify) MonadState,
import Data.List (foldl') State,
import Data.Map (Map) evalState,
import qualified Data.Map as Map gets,
import Data.Maybe (fromMaybe) modify,
import Data.Tuple.Extra (dupe) )
import Grammar.Abs import Data.List (foldl')
import Data.Map (Map)
import Data.Map qualified as Map
import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe)
import Grammar.Abs
-- | Rename all variables and local binds -- | Rename all variables and local binds
rename :: Program -> Program rename :: Program -> Program
@ -20,62 +24,65 @@ rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs)
-- initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs -- initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs
initNames = Map.fromList $ foldl' saveIfBind [] bs initNames = Map.fromList $ foldl' saveIfBind [] bs
saveIfBind acc (DBind (Bind name _ _ _ _)) = dupe name : acc saveIfBind acc (DBind (Bind name _ _ _ _)) = dupe name : acc
saveIfBind acc _ = acc saveIfBind acc _ = acc
renameSc :: Names -> Def -> Rn Def renameSc :: Names -> Def -> Rn Def
renameSc old_names (DBind (Bind name t _ parms rhs)) = do renameSc old_names (DBind (Bind name t _ parms rhs)) = do
(new_names, parms') <- newNames old_names parms (new_names, parms') <- newNames old_names parms
rhs' <- snd <$> renameExp new_names rhs rhs' <- snd <$> renameExp new_names rhs
pure . DBind $ Bind name t name parms' rhs' pure . DBind $ Bind name t name parms' rhs'
renameSc _ def = pure def renameSc _ def = pure def
-- | Rename monad. State holds the number of renamed names. -- | Rename monad. State holds the number of renamed names.
newtype Rn a = Rn { runRn :: State Int a } newtype Rn a = Rn {runRn :: State Int a}
deriving (Functor, Applicative, Monad, MonadState Int) deriving (Functor, Applicative, Monad, MonadState Int)
-- | Maps old to new name -- | Maps old to new name
type Names = Map Ident Ident type Names = Map Ident Ident
renameLocalBind :: Names -> Bind -> Rn (Names, Bind) renameLocalBind :: Names -> Bind -> Rn (Names, Bind)
renameLocalBind old_names (Bind name t _ parms rhs) = do renameLocalBind old_names (Bind name t _ parms rhs) = do
(new_names, name') <- newName old_names name (new_names, name') <- newName old_names name
(new_names', parms') <- newNames new_names parms (new_names', parms') <- newNames new_names parms
(new_names'', rhs') <- renameExp new_names' rhs (new_names'', rhs') <- renameExp new_names' rhs
pure (new_names'', Bind name' t name' parms' rhs') pure (new_names'', Bind name' t name' parms' rhs')
renameExp :: Names -> Exp -> Rn (Names, Exp) renameExp :: Names -> Exp -> Rn (Names, Exp)
renameExp old_names = \case renameExp old_names = \case
EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names) EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names)
ELit (LInt i1) -> pure (old_names, ELit (LInt i1))
ELit (LInt i1) -> pure (old_names, ELit (LInt i1))
EApp e1 e2 -> do EApp e1 e2 -> do
(env1, e1') <- renameExp old_names e1 (env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2 (env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EApp e1' e2') pure (Map.union env1 env2, EApp e1' e2')
EAdd e1 e2 -> do EAdd e1 e2 -> do
(env1, e1') <- renameExp old_names e1 (env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2 (env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EAdd e1' e2') pure (Map.union env1 env2, EAdd e1' e2')
ELet i e1 e2 -> do
ELet i e1 e2 -> do (new_names, e1') <- renameExp old_names e1
(new_names, e1') <- renameExp old_names e1
(new_names', e2') <- renameExp new_names e2 (new_names', e2') <- renameExp new_names e2
pure (new_names', ELet i e1' e2') pure (new_names', ELet i e1' e2')
EAbs par e -> do
EAbs par e -> do
(new_names, par') <- newName old_names par (new_names, par') <- newName old_names par
(new_names', e') <- renameExp new_names e (new_names', e') <- renameExp new_names e
pure (new_names', EAbs par' e') pure (new_names', EAbs par' e')
EAnn e t -> do EAnn e t -> do
(new_names, e') <- renameExp old_names e (new_names, e') <- renameExp old_names e
pure (new_names, EAnn e' t) pure (new_names, EAnn e' t)
ECase e injs -> do ECase e injs -> do
(new_names, e') <- renameExp old_names e (_, e') <- renameExp old_names e
pure (new_names, ECase e' injs) (new_names, injs') <- renameInjs old_names injs
pure (new_names, ECase e' injs')
renameInjs :: Names -> [Inj] -> Rn (Names, [Inj])
renameInjs ns xs = do
(new_names, xs') <- unzip <$> mapM (renameInj ns) xs
if null new_names then return (mempty, xs') else return (head new_names, xs')
renameInj :: Names -> Inj -> Rn (Names, Inj)
renameInj ns (Inj init e) = do
(new_names, e') <- renameExp ns e
return (new_names, Inj init e')
-- | Create a new name and add it to name environment. -- | Create a new name and add it to name environment.
newName :: Names -> Ident -> Rn (Names, Ident) newName :: Names -> Ident -> Rn (Names, Ident)

View file

@ -1,9 +1,7 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# HLINT ignore "Use mapAndUnzipM" #-} {-# HLINT ignore "Use mapAndUnzipM" #-}
{-# OPTIONS_GHC -Wno-overlapping-patterns #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# OPTIONS_GHC -Wno-unused-matches #-}
-- | A module for type checking and inference using algorithm W, Hindley-Milner -- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeChecker where module TypeChecker.TypeChecker where
@ -73,9 +71,7 @@ freshenData (Data (Constr name ts) constrs) = do
checkData :: Data -> Infer () checkData :: Data -> Infer ()
checkData d = do checkData d = do
trace ("OLD: " ++ show d) return ()
d' <- freshenData d d' <- freshenData d
trace ("NEW: " ++ show d') return ()
case d' of case d' of
(Data typ@(Constr name ts) constrs) -> do (Data typ@(Constr name ts) constrs) -> do
unless unless
@ -249,7 +245,11 @@ algoW = \case
-- applySt s2 $ do -- applySt s2 $ do
s3 <- unify (apply s2 t0) (TMono "Int") s3 <- unify (apply s2 t0) (TMono "Int")
s4 <- unify (apply s3 t1) (TMono "Int") s4 <- unify (apply s3 t1) (TMono "Int")
return (s4 `compose` s3 `compose` s2 `compose` s1, TMono "Int", T.EAdd (TMono "Int") e0' e1') return
( s4 `compose` s3 `compose` s2 `compose` s1
, TMono "Int"
, T.EAdd (TMono "Int") e0' e1'
)
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1 -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1
-- \| τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ') -- \| τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ')
@ -280,7 +280,7 @@ algoW = \case
(s2, t2, e1') <- algoW e1 (s2, t2, e1') <- algoW e1
return (s2 `compose` s1, t2, T.ELet (T.Bind (name, t2) e0') e1') return (s2 `compose` s1, t2, T.ELet (T.Bind (name, t2) e0') e1')
ECase caseExpr injs -> do ECase caseExpr injs -> do
(s0, t0, e0') <- algoW caseExpr (_, t0, e0') <- algoW caseExpr
(injs', ts) <- unzip <$> mapM (checkInj t0) injs (injs', ts) <- unzip <$> mapM (checkInj t0) injs
case ts of case ts of
[] -> throwError "Case expression missing any matches" [] -> throwError "Case expression missing any matches"
@ -292,14 +292,15 @@ algoW = \case
-- | Unify two types producing a new substitution -- | Unify two types producing a new substitution
unify :: Type -> Type -> Infer Subst unify :: Type -> Type -> Infer Subst
unify t0 t1 = case (trace ("LEFT: " ++ show t0) t0, trace ("RIGHT: " ++ show t1) t1) of unify t0 t1 = case (t0, t1) of
(TArr a b, TArr c d) -> do (TArr a b, TArr c d) -> do
s1 <- unify a c s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d) s2 <- unify (apply s1 b) (apply s1 d)
return $ s1 `compose` s2 return $ s1 `compose` s2
(TPol a, b) -> occurs a b (TPol a, b) -> occurs a b
(a, TPol b) -> occurs b a (a, TPol b) -> occurs b a
(TMono a, TMono b) -> if a == b then return M.empty else throwError "Types do not unify" (TMono a, TMono b) ->
if a == b then return M.empty else throwError "Types do not unify"
-- \| TODO: Figure out a cleaner way to express the same thing -- \| TODO: Figure out a cleaner way to express the same thing
(TConstr (Constr name t), TConstr (Constr name' t')) -> (TConstr (Constr name t), TConstr (Constr name' t')) ->
if name == name' && length t == length t' if name == name' && length t == length t'
@ -316,10 +317,17 @@ unify t0 t1 = case (trace ("LEFT: " ++ show t0) t0, trace ("RIGHT: " ++ show t1)
, printTree name' , printTree name'
, "(" ++ printTree t' ++ ")" , "(" ++ printTree t' ++ ")"
] ]
(a, b) -> throwError . unwords $ ["Type:", printTree a, "can't be unified with:", printTree b] (a, b) ->
throwError . unwords $
[ "Type:"
, printTree a
, "can't be unified with:"
, printTree b
]
{- | Check if a type is contained in another type. {- | Check if a type is contained in another type.
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution such that these are equal I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
such that these are equal
-} -}
occurs :: Ident -> Type -> Infer Subst occurs :: Ident -> Type -> Infer Subst
occurs _ (TPol _) = return nullSubst occurs _ (TPol _) = return nullSubst
@ -339,7 +347,9 @@ occurs i t =
generalize :: Map Ident Poly -> Type -> Poly generalize :: Map Ident Poly -> Type -> Poly
generalize env t = Forall (S.toList $ free t S.\\ free env) t generalize env t = Forall (S.toList $ free t S.\\ free env) t
-- | Instantiate a polymorphic type. The free type variables are substituted with fresh ones. {- | Instantiate a polymorphic type. The free type variables are substituted
with fresh ones.
-}
inst :: Poly -> Infer Type inst :: Poly -> Infer Type
inst (Forall xs t) = do inst (Forall xs t) = do
xs' <- mapM (const fresh) xs xs' <- mapM (const fresh) xs
@ -364,7 +374,8 @@ instance FreeVars Type where
free (TMono _) = mempty free (TMono _) = mempty
free (TArr a b) = free a `S.union` free b free (TArr a b) = free a `S.union` free b
-- \| Not guaranteed to be correct -- \| Not guaranteed to be correct
free (TConstr (Constr _ a)) = foldl' (\acc x -> free x `S.union` acc) S.empty a free (TConstr (Constr _ a)) =
foldl' (\acc x -> free x `S.union` acc) S.empty a
apply :: Subst -> Type -> Type apply :: Subst -> Type -> Type
apply sub t = do apply sub t = do
@ -413,7 +424,8 @@ insertSig i t = modify (\st -> st {sigs = M.insert i t (sigs st)})
-- | Insert a constructor with its data type -- | Insert a constructor with its data type
insertConstr :: Ident -> Type -> Infer () insertConstr :: Ident -> Type -> Infer ()
insertConstr i t = modify (\st -> st {constructors = M.insert i t (constructors st)}) insertConstr i t =
modify (\st -> st {constructors = M.insert i t (constructors st)})
-------- PATTERN MATCHING --------- -------- PATTERN MATCHING ---------
@ -421,7 +433,7 @@ insertConstr i t = modify (\st -> st {constructors = M.insert i t (constructors
checkInj :: Type -> Inj -> Infer (T.Inj, Type) checkInj :: Type -> Inj -> Infer (T.Inj, Type)
checkInj caseType (Inj it expr) = do checkInj caseType (Inj it expr) = do
(args, t') <- initType caseType it (args, t') <- initType caseType it
(s, t, e') <- local (\st -> st {vars = args}) (algoW expr) (_, t, e') <- local (\st -> st {vars = args `M.union` vars st}) (algoW expr)
return (T.Inj (it, t') e', t) return (T.Inj (it, t') e', t)
initType :: Type -> Init -> Infer (Map Ident Poly, Type) initType :: Type -> Init -> Infer (Map Ident Poly, Type)
@ -469,4 +481,4 @@ flattenType (TArr a b) = flattenType a ++ flattenType b
flattenType a = [a] flattenType a = [a]
litType :: Literal -> Type litType :: Literal -> Type
litType (LInt i) = TMono "Int" litType (LInt _) = TMono "Int"

View file

@ -13,11 +13,11 @@ data Maybe ('a) where {
Just : 'a -> Maybe ('a) Just : 'a -> Maybe ('a)
}; };
id : 'a -> 'a ; -- id : 'a -> 'a ;
id x = x ; -- id x = x ;
main : Maybe ('a -> 'a) ; -- main : Maybe ('a -> 'a) ;
main = Just id; -- main = Just id;
-- data Either ('a 'b) where { -- data Either ('a 'b) where {
-- Left : 'a -> Either ('a 'b) -- Left : 'a -> Either ('a 'b)
@ -40,11 +40,11 @@ main = Just id;
-- Left y => Nothing ; -- Left y => Nothing ;
-- Right x => Just x -- Right x => Just x
-- }; -- };
--
-- -- Bug. f not included in the case-expression context -- Bug. f not included in the case-expression context
-- fmap : ('a -> 'b) -> Maybe ('a) -> Maybe ('b) ; fmap : ('a -> 'b) -> Maybe ('a) -> Maybe ('b) ;
-- fmap f x = fmap f x =
-- case x of { case x of {
-- Just x => Just (f x) ; Just x => Just (f x) ;
-- Nothing => Nothing Nothing => Nothing
-- } }