adapted changes to work

This commit is contained in:
sebastianselander 2023-03-28 15:35:48 +02:00
parent 59d9be87cb
commit 7f0dab6dcb
3 changed files with 37 additions and 34 deletions

View file

@ -64,8 +64,8 @@ instance RemoveTEVar a b => RemoveTEVar [a] [b] where
instance RemoveTEVar Type T.Type where instance RemoveTEVar Type T.Type where
rmTEVar = \case rmTEVar = \case
TLit lit -> pure $ T.TLit (coerce lit) TLit lit -> pure $ T.TLit (coerce lit)
TVar tvar -> pure $ T.TVar tvar TVar tvar -> pure $ T.TVar (coerce tvar)
TData name typs -> T.TData (coerce name) <$> rmTEVar typs TData name typs -> T.TData (coerce name) <$> rmTEVar typs
TFun t1 t2 -> liftA2 T.TFun (rmTEVar t1) (rmTEVar t2) TFun t1 t2 -> liftA2 T.TFun (rmTEVar t1) (rmTEVar t2)
TAll tvar t -> T.TAll tvar <$> rmTEVar t TAll tvar t -> T.TAll (coerce tvar) <$> rmTEVar t
TEVar _ -> throwError "NewType TEVar!" TEVar _ -> throwError "NewType TEVar!"

View file

@ -421,10 +421,10 @@ unify t0 t1 =
s1 <- unify a c s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d) s2 <- unify (apply s1 b) (apply s1 d)
return $ s1 `compose` s2 return $ s1 `compose` s2
(TVar (T.MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t (TVar (MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t
(t@(TData _ _), TVar (T.MkTVar b)) -> return $ M.singleton (coerce b) t (t@(TData _ _), TVar (MkTVar b)) -> return $ M.singleton (coerce b) t
(TVar (T.MkTVar a), t) -> occurs (coerce a) t (TVar (MkTVar a), t) -> occurs (coerce a) t
(t, TVar (T.MkTVar b)) -> occurs (coerce b) t (t, TVar (MkTVar b)) -> occurs (coerce b) t
(TAll _ t, b) -> unify t b (TAll _ t, b) -> unify t b
(a, TAll _ t) -> unify a t (a, TAll _ t) -> unify a t
(TLit a, TLit b) -> (TLit a, TLit b) ->
@ -478,7 +478,7 @@ occurs i t =
catchableErr catchableErr
( Aux.do ( Aux.do
"Occurs check failed, can't unify" "Occurs check failed, can't unify"
quote $ printTree (TVar $ T.MkTVar (coerce i)) quote $ printTree (TVar $ MkTVar (coerce i))
"with" "with"
quote $ printTree t quote $ printTree t
) )
@ -495,7 +495,7 @@ generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t)
where where
go :: [T.Ident] -> Type -> Type go :: [T.Ident] -> Type -> Type
go [] t = t go [] t = t
go (x : xs) t = TAll (T.MkTVar (coerce x)) (go xs t) go (x : xs) t = TAll (MkTVar (coerce x)) (go xs t)
removeForalls :: Type -> Type removeForalls :: Type -> Type
removeForalls (TAll _ t) = removeForalls t removeForalls (TAll _ t) = removeForalls t
removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2) removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2)
@ -506,7 +506,7 @@ with fresh ones.
-} -}
inst :: Type -> Infer Type inst :: Type -> Infer Type
inst = \case inst = \case
TAll (T.MkTVar bound) t -> do TAll (MkTVar bound) t -> do
fr <- fresh fr <- fresh
let s = M.singleton (coerce bound) fr let s = M.singleton (coerce bound) fr
apply s <$> inst t apply s <$> inst t
@ -528,8 +528,8 @@ fresh = do
fresh fresh
else else
if n == 0 if n == 0
then return . TVar . T.MkTVar $ LIdent [c] then return . TVar . MkTVar $ LIdent [c]
else return . TVar . T.MkTVar . LIdent $ c : show n else return . TVar . MkTVar . LIdent $ c : show n
where where
next :: Char -> Char next :: Char -> Char
next 'z' = 'a' next 'z' = 'a'
@ -546,8 +546,8 @@ class FreeVars t where
instance FreeVars Type where instance FreeVars Type where
free :: Type -> Set T.Ident free :: Type -> Set T.Ident
free (TVar (T.MkTVar a)) = S.singleton (coerce a) free (TVar (MkTVar a)) = S.singleton (coerce a)
free (TAll (T.MkTVar bound) t) = free (TAll (MkTVar bound) t) =
S.singleton (coerce bound) `S.intersection` free t S.singleton (coerce bound) `S.intersection` free t
free (TLit _) = mempty free (TLit _) = mempty
free (TFun a b) = free a `S.union` free b free (TFun a b) = free a `S.union` free b
@ -562,11 +562,11 @@ instance SubstType Type where
apply sub t = do apply sub t = do
case t of case t of
TLit a -> TLit a TLit a -> TLit a
TVar (T.MkTVar a) -> case M.lookup (coerce a) sub of TVar (MkTVar a) -> case M.lookup (coerce a) sub of
Nothing -> TVar (T.MkTVar $ coerce a) Nothing -> TVar (MkTVar $ coerce a)
Just t -> t Just t -> t
TAll (T.MkTVar i) t -> case M.lookup (coerce i) sub of TAll (MkTVar i) t -> case M.lookup (coerce i) sub of
Nothing -> TAll (T.MkTVar i) (apply sub t) Nothing -> TAll (MkTVar i) (apply sub t)
Just _ -> apply sub t Just _ -> apply sub t
TFun a b -> TFun (apply sub a) (apply sub b) TFun a b -> TFun (apply sub a) (apply sub b)
TData name a -> TData name (apply sub a) TData name a -> TData name (apply sub a)
@ -683,7 +683,7 @@ int = TLit "Int"
char = TLit "Char" char = TLit "Char"
typeEq :: Type -> Type -> StateT Subst (ExceptT Error Identity) () typeEq :: Type -> Type -> StateT Subst (ExceptT Error Identity) ()
typeEq (TVar (T.MkTVar a)) t@(TVar _) = do typeEq (TVar (MkTVar a)) t@(TVar _) = do
st <- get st <- get
case M.lookup (coerce a) st of case M.lookup (coerce a) st of
Nothing -> put $ M.insert (coerce a) t st Nothing -> put $ M.insert (coerce a) t st

View file

@ -1,4 +1,4 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PatternSynonyms #-}
module TypeChecker.TypeCheckerIr ( module TypeChecker.TypeCheckerIr (
@ -6,11 +6,11 @@ module TypeChecker.TypeCheckerIr (
module TypeChecker.TypeCheckerIr, module TypeChecker.TypeCheckerIr,
) where ) where
import Data.String (IsString) import Data.String (IsString)
import Grammar.Abs (Lit (..)) import Grammar.Abs (Lit (..))
import Grammar.Print import Grammar.Print
import Prelude import Prelude
import qualified Prelude as C (Eq, Ord, Read, Show) import Prelude qualified as C (Eq, Ord, Read, Show)
newtype Program' t = Program [Def' t] newtype Program' t = Program [Def' t]
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
@ -56,8 +56,8 @@ data Exp' t
| ECase (ExpT' t) [Branch' t] | ECase (ExpT' t) [Branch' t]
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
data TVar = MkTVar Ident newtype TVar = MkTVar Ident
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
type Id' t = (Ident, t) type Id' t = (Ident, t)
type ExpT' t = (Exp' t, t) type ExpT' t = (Exp' t, t)
@ -105,8 +105,8 @@ instance Print t => Print (ExpT' t) where
] ]
instance Print t => Print [Bind' t] where instance Print t => Print [Bind' t] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x] prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
prtIdPs :: Print t => Int -> [Id' t] -> Doc prtIdPs :: Print t => Int -> [Id' t] -> Doc
@ -171,13 +171,13 @@ instance Print t => Print (Branch' t) where
prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp]) prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp])
instance Print t => Print [Branch' t] where instance Print t => Print [Branch' t] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x] prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
instance Print t => Print (Def' t) where instance Print t => Print (Def' t) where
prt i = \case prt i = \case
DBind bind -> prPrec i 0 (concatD [prt 0 bind]) DBind bind -> prPrec i 0 (concatD [prt 0 bind])
DData data_ -> prPrec i 0 (concatD [prt 0 data_]) DData data_ -> prPrec i 0 (concatD [prt 0 data_])
instance Print t => Print (Data' t) where instance Print t => Print (Data' t) where
@ -197,12 +197,12 @@ instance Print t => Print (Pattern' t) where
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns]) PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])
instance Print t => Print [Def' t] where instance Print t => Print [Def' t] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x] prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
instance Print [Type] where instance Print [Type] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs]
instance Print Type where instance Print Type where
@ -213,6 +213,9 @@ instance Print Type where
TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2]) TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2])
TAll tvar type_ -> prPrec i 0 (concatD [doc (showString "forall"), prt 0 tvar, doc (showString "."), prt 0 type_]) TAll tvar type_ -> prPrec i 0 (concatD [doc (showString "forall"), prt 0 tvar, doc (showString "."), prt 0 type_])
instance Print TVar where
prt i (MkTVar ident) = prt i ident
type Program = Program' Type type Program = Program' Type
type Def = Def' Type type Def = Def' Type
type Data = Data' Type type Data = Data' Type