fixed bug where bound variable didn't exist in case

This commit is contained in:
sebastianselander 2023-03-06 11:27:17 +01:00
parent 778fec3dc4
commit 9c2f52f8bb
3 changed files with 79 additions and 60 deletions

View file

@ -3,16 +3,20 @@
module Renamer.Renamer where module Renamer.Renamer where
import Auxiliary (mapAccumM) import Auxiliary (mapAccumM)
import Control.Monad.State (MonadState, State, evalState, gets, import Control.Monad.State (
modify) MonadState,
State,
evalState,
gets,
modify,
)
import Data.List (foldl') import Data.List (foldl')
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as Map import Data.Map qualified as Map
import Data.Maybe (fromMaybe) import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe) import Data.Tuple.Extra (dupe)
import Grammar.Abs import Grammar.Abs
-- | Rename all variables and local binds -- | Rename all variables and local binds
rename :: Program -> Program rename :: Program -> Program
rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0 rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0
@ -28,9 +32,8 @@ rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs)
pure . DBind $ Bind name t name parms' rhs' pure . DBind $ Bind name t name parms' rhs'
renameSc _ def = pure def renameSc _ def = pure def
-- | Rename monad. State holds the number of renamed names. -- | Rename monad. State holds the number of renamed names.
newtype Rn a = Rn { runRn :: State Int a } newtype Rn a = Rn {runRn :: State Int a}
deriving (Functor, Applicative, Monad, MonadState Int) deriving (Functor, Applicative, Monad, MonadState Int)
-- | Maps old to new name -- | Maps old to new name
@ -46,36 +49,40 @@ renameLocalBind old_names (Bind name t _ parms rhs) = do
renameExp :: Names -> Exp -> Rn (Names, Exp) renameExp :: Names -> Exp -> Rn (Names, Exp)
renameExp old_names = \case renameExp old_names = \case
EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names) EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names)
ELit (LInt i1) -> pure (old_names, ELit (LInt i1)) ELit (LInt i1) -> pure (old_names, ELit (LInt i1))
EApp e1 e2 -> do EApp e1 e2 -> do
(env1, e1') <- renameExp old_names e1 (env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2 (env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EApp e1' e2') pure (Map.union env1 env2, EApp e1' e2')
EAdd e1 e2 -> do EAdd e1 e2 -> do
(env1, e1') <- renameExp old_names e1 (env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2 (env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EAdd e1' e2') pure (Map.union env1 env2, EAdd e1' e2')
ELet i e1 e2 -> do ELet i e1 e2 -> do
(new_names, e1') <- renameExp old_names e1 (new_names, e1') <- renameExp old_names e1
(new_names', e2') <- renameExp new_names e2 (new_names', e2') <- renameExp new_names e2
pure (new_names', ELet i e1' e2') pure (new_names', ELet i e1' e2')
EAbs par e -> do EAbs par e -> do
(new_names, par') <- newName old_names par (new_names, par') <- newName old_names par
(new_names', e') <- renameExp new_names e (new_names', e') <- renameExp new_names e
pure (new_names', EAbs par' e') pure (new_names', EAbs par' e')
EAnn e t -> do EAnn e t -> do
(new_names, e') <- renameExp old_names e (new_names, e') <- renameExp old_names e
pure (new_names, EAnn e' t) pure (new_names, EAnn e' t)
ECase e injs -> do ECase e injs -> do
(new_names, e') <- renameExp old_names e (_, e') <- renameExp old_names e
pure (new_names, ECase e' injs) (new_names, injs') <- renameInjs old_names injs
pure (new_names, ECase e' injs')
renameInjs :: Names -> [Inj] -> Rn (Names, [Inj])
renameInjs ns xs = do
(new_names, xs') <- unzip <$> mapM (renameInj ns) xs
if null new_names then return (mempty, xs') else return (head new_names, xs')
renameInj :: Names -> Inj -> Rn (Names, Inj)
renameInj ns (Inj init e) = do
(new_names, e') <- renameExp ns e
return (new_names, Inj init e')
-- | Create a new name and add it to name environment. -- | Create a new name and add it to name environment.
newName :: Names -> Ident -> Rn (Names, Ident) newName :: Names -> Ident -> Rn (Names, Ident)

View file

@ -1,9 +1,7 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# HLINT ignore "Use mapAndUnzipM" #-} {-# HLINT ignore "Use mapAndUnzipM" #-}
{-# OPTIONS_GHC -Wno-overlapping-patterns #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# OPTIONS_GHC -Wno-unused-matches #-}
-- | A module for type checking and inference using algorithm W, Hindley-Milner -- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeChecker where module TypeChecker.TypeChecker where
@ -73,9 +71,7 @@ freshenData (Data (Constr name ts) constrs) = do
checkData :: Data -> Infer () checkData :: Data -> Infer ()
checkData d = do checkData d = do
trace ("OLD: " ++ show d) return ()
d' <- freshenData d d' <- freshenData d
trace ("NEW: " ++ show d') return ()
case d' of case d' of
(Data typ@(Constr name ts) constrs) -> do (Data typ@(Constr name ts) constrs) -> do
unless unless
@ -249,7 +245,11 @@ algoW = \case
-- applySt s2 $ do -- applySt s2 $ do
s3 <- unify (apply s2 t0) (TMono "Int") s3 <- unify (apply s2 t0) (TMono "Int")
s4 <- unify (apply s3 t1) (TMono "Int") s4 <- unify (apply s3 t1) (TMono "Int")
return (s4 `compose` s3 `compose` s2 `compose` s1, TMono "Int", T.EAdd (TMono "Int") e0' e1') return
( s4 `compose` s3 `compose` s2 `compose` s1
, TMono "Int"
, T.EAdd (TMono "Int") e0' e1'
)
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1 -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1
-- \| τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ') -- \| τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ')
@ -280,7 +280,7 @@ algoW = \case
(s2, t2, e1') <- algoW e1 (s2, t2, e1') <- algoW e1
return (s2 `compose` s1, t2, T.ELet (T.Bind (name, t2) e0') e1') return (s2 `compose` s1, t2, T.ELet (T.Bind (name, t2) e0') e1')
ECase caseExpr injs -> do ECase caseExpr injs -> do
(s0, t0, e0') <- algoW caseExpr (_, t0, e0') <- algoW caseExpr
(injs', ts) <- unzip <$> mapM (checkInj t0) injs (injs', ts) <- unzip <$> mapM (checkInj t0) injs
case ts of case ts of
[] -> throwError "Case expression missing any matches" [] -> throwError "Case expression missing any matches"
@ -292,14 +292,15 @@ algoW = \case
-- | Unify two types producing a new substitution -- | Unify two types producing a new substitution
unify :: Type -> Type -> Infer Subst unify :: Type -> Type -> Infer Subst
unify t0 t1 = case (trace ("LEFT: " ++ show t0) t0, trace ("RIGHT: " ++ show t1) t1) of unify t0 t1 = case (t0, t1) of
(TArr a b, TArr c d) -> do (TArr a b, TArr c d) -> do
s1 <- unify a c s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d) s2 <- unify (apply s1 b) (apply s1 d)
return $ s1 `compose` s2 return $ s1 `compose` s2
(TPol a, b) -> occurs a b (TPol a, b) -> occurs a b
(a, TPol b) -> occurs b a (a, TPol b) -> occurs b a
(TMono a, TMono b) -> if a == b then return M.empty else throwError "Types do not unify" (TMono a, TMono b) ->
if a == b then return M.empty else throwError "Types do not unify"
-- \| TODO: Figure out a cleaner way to express the same thing -- \| TODO: Figure out a cleaner way to express the same thing
(TConstr (Constr name t), TConstr (Constr name' t')) -> (TConstr (Constr name t), TConstr (Constr name' t')) ->
if name == name' && length t == length t' if name == name' && length t == length t'
@ -316,10 +317,17 @@ unify t0 t1 = case (trace ("LEFT: " ++ show t0) t0, trace ("RIGHT: " ++ show t1)
, printTree name' , printTree name'
, "(" ++ printTree t' ++ ")" , "(" ++ printTree t' ++ ")"
] ]
(a, b) -> throwError . unwords $ ["Type:", printTree a, "can't be unified with:", printTree b] (a, b) ->
throwError . unwords $
[ "Type:"
, printTree a
, "can't be unified with:"
, printTree b
]
{- | Check if a type is contained in another type. {- | Check if a type is contained in another type.
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution such that these are equal I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
such that these are equal
-} -}
occurs :: Ident -> Type -> Infer Subst occurs :: Ident -> Type -> Infer Subst
occurs _ (TPol _) = return nullSubst occurs _ (TPol _) = return nullSubst
@ -339,7 +347,9 @@ occurs i t =
generalize :: Map Ident Poly -> Type -> Poly generalize :: Map Ident Poly -> Type -> Poly
generalize env t = Forall (S.toList $ free t S.\\ free env) t generalize env t = Forall (S.toList $ free t S.\\ free env) t
-- | Instantiate a polymorphic type. The free type variables are substituted with fresh ones. {- | Instantiate a polymorphic type. The free type variables are substituted
with fresh ones.
-}
inst :: Poly -> Infer Type inst :: Poly -> Infer Type
inst (Forall xs t) = do inst (Forall xs t) = do
xs' <- mapM (const fresh) xs xs' <- mapM (const fresh) xs
@ -364,7 +374,8 @@ instance FreeVars Type where
free (TMono _) = mempty free (TMono _) = mempty
free (TArr a b) = free a `S.union` free b free (TArr a b) = free a `S.union` free b
-- \| Not guaranteed to be correct -- \| Not guaranteed to be correct
free (TConstr (Constr _ a)) = foldl' (\acc x -> free x `S.union` acc) S.empty a free (TConstr (Constr _ a)) =
foldl' (\acc x -> free x `S.union` acc) S.empty a
apply :: Subst -> Type -> Type apply :: Subst -> Type -> Type
apply sub t = do apply sub t = do
@ -413,7 +424,8 @@ insertSig i t = modify (\st -> st {sigs = M.insert i t (sigs st)})
-- | Insert a constructor with its data type -- | Insert a constructor with its data type
insertConstr :: Ident -> Type -> Infer () insertConstr :: Ident -> Type -> Infer ()
insertConstr i t = modify (\st -> st {constructors = M.insert i t (constructors st)}) insertConstr i t =
modify (\st -> st {constructors = M.insert i t (constructors st)})
-------- PATTERN MATCHING --------- -------- PATTERN MATCHING ---------
@ -421,7 +433,7 @@ insertConstr i t = modify (\st -> st {constructors = M.insert i t (constructors
checkInj :: Type -> Inj -> Infer (T.Inj, Type) checkInj :: Type -> Inj -> Infer (T.Inj, Type)
checkInj caseType (Inj it expr) = do checkInj caseType (Inj it expr) = do
(args, t') <- initType caseType it (args, t') <- initType caseType it
(s, t, e') <- local (\st -> st {vars = args}) (algoW expr) (_, t, e') <- local (\st -> st {vars = args `M.union` vars st}) (algoW expr)
return (T.Inj (it, t') e', t) return (T.Inj (it, t') e', t)
initType :: Type -> Init -> Infer (Map Ident Poly, Type) initType :: Type -> Init -> Infer (Map Ident Poly, Type)
@ -469,4 +481,4 @@ flattenType (TArr a b) = flattenType a ++ flattenType b
flattenType a = [a] flattenType a = [a]
litType :: Literal -> Type litType :: Literal -> Type
litType (LInt i) = TMono "Int" litType (LInt _) = TMono "Int"

View file

@ -13,11 +13,11 @@ data Maybe ('a) where {
Just : 'a -> Maybe ('a) Just : 'a -> Maybe ('a)
}; };
id : 'a -> 'a ; -- id : 'a -> 'a ;
id x = x ; -- id x = x ;
main : Maybe ('a -> 'a) ; -- main : Maybe ('a -> 'a) ;
main = Just id; -- main = Just id;
-- data Either ('a 'b) where { -- data Either ('a 'b) where {
-- Left : 'a -> Either ('a 'b) -- Left : 'a -> Either ('a 'b)
@ -40,11 +40,11 @@ main = Just id;
-- Left y => Nothing ; -- Left y => Nothing ;
-- Right x => Just x -- Right x => Just x
-- }; -- };
--
-- -- Bug. f not included in the case-expression context -- Bug. f not included in the case-expression context
-- fmap : ('a -> 'b) -> Maybe ('a) -> Maybe ('b) ; fmap : ('a -> 'b) -> Maybe ('a) -> Maybe ('b) ;
-- fmap f x = fmap f x =
-- case x of { case x of {
-- Just x => Just (f x) ; Just x => Just (f x) ;
-- Nothing => Nothing Nothing => Nothing
-- } }