progress on fixing bugs

This commit is contained in:
sebastianselander 2023-03-21 17:09:03 +01:00
parent 3026a96eb7
commit 509de4415e
2 changed files with 49 additions and 29 deletions

View file

@ -7,17 +7,16 @@ module TypeChecker.TypeChecker where
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (second)
import Data.Foldable (traverse_)
import Data.Function (on)
import Data.Functor.Identity (runIdentity)
import Data.List (foldl')
import Data.List.Extra (allSame)
import Data.List.Extra (unsnoc)
import Data.Map (Map)
import Data.Map qualified as M
import Data.Maybe (fromMaybe)
import Data.Set (Set)
import Data.Set qualified as S
import Data.Tree (flatten)
import Debug.Trace (trace)
import Grammar.Abs
import Grammar.Print (printTree)
@ -296,10 +295,13 @@ algoW = \case
-- probably by returning substitutions in the functions used in this body
ECase caseExpr injs -> do
(sub, t, e') <- algoW caseExpr
(subst, t) <- checkCase t injs
(subst, inj_t, ret_t) <- checkCase t injs
let composition = subst `compose` sub
let t' = apply composition t
return (composition, t', T.ECase t' e' (map (\(Inj i _) -> T.Inj (i, t') e') injs))
trace ("COMPOSITION: " ++ show composition) return ()
trace ("T: " ++ show t) return ()
trace ("T': " ++ show t') return ()
return (composition, t', T.ECase t' e' (map (\(Inj i _) -> T.Inj (i, inj_t) e') injs))
-- | Unify two types producing a new substitution
unify :: Type -> Type -> Infer Subst
@ -462,6 +464,10 @@ fresh = do
withBinding :: (Monad m, MonadReader Ctx m) => Ident -> Poly -> m a -> m a
withBinding i p = local (\st -> st{vars = M.insert i p (vars st)})
-- | Run the monadic action with several additional bindings
withBindings :: (Monad m, MonadReader Ctx m) => [(Ident, Poly)] -> m a -> m a
withBindings xs = local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs})
-- | Insert a function signature into the environment
insertSig :: Ident -> Type -> Infer ()
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
@ -473,34 +479,39 @@ insertConstr i t =
-------- PATTERN MATCHING ---------
checkCase :: Type -> [Inj] -> Infer (Subst, Type)
checkCase :: Type -> [Inj] -> Infer (Subst, Type, Type)
checkCase expT injs = do
(injs, returns) <- mapAndUnzipM checkInj injs
(sub, _) <- foldM (\(sub, acc) x -> (\a -> (a `compose` sub, (a `apply` acc))) <$> unify x acc) (nullSubst, expT) injs
t <- foldM (\acc x -> (`apply` acc) <$> unify x acc) (head returns) (tail returns)
return (sub, t)
(sub, injs_type) <- foldM (\(sub, acc) x -> (\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc) (nullSubst, expT) injs
(_, returns_type) <- foldM (\(sub, acc) x -> (\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc) (nullSubst, head returns) (tail returns)
return (sub, injs_type, returns_type)
{- | fst = type of init
| snd = type of expr
-}
checkInj :: Inj -> Infer (Type, Type)
checkInj (Inj it expr) = do
initT <- inferInit it
(exprT, _) <- inferExp expr
(initT, vars) <- inferInit it
let converted = map (second (Forall [])) vars
(exprT, _) <- withBindings converted (inferExp expr)
return (initT, exprT)
inferInit :: Init -> Infer Type
inferInit :: Init -> Infer (Type, [T.Id])
inferInit = \case
InitLit lit -> return $ litType lit
InitLit lit -> return (litType lit, mempty)
InitConstr fn vars -> do
gets (M.lookup fn . constructors) >>= \case
Nothing -> throwError $ "Constructor: " ++ printTree fn ++ " does not exist"
Just a -> do
let ft = init $ flattenType a
case compare (length vars) (length ft) of
EQ -> return . last $ flattenType a
case unsnoc $ flattenType a of
Nothing -> throwError "Partial pattern match not allowed"
Just (vs, ret) ->
case length vars `compare` length vs of
EQ -> do
trace ("IDS AND TYPES: " ++ show (zip vars vs)) return ()
return (ret, zip vars vs)
_ -> throwError "Partial pattern match not allowed"
InitCatch -> fresh
InitCatch -> (,mempty) <$> fresh
flattenType :: Type -> [Type]
flattenType (TArr a b) = flattenType a ++ flattenType b

View file

@ -1,15 +1,24 @@
data Bool () where {
True : Bool ()
False : Bool ()
};
-- data Bool () where {
-- True : Bool ()
-- False : Bool ()
-- };
data Maybe ('a) where {
Nothing : Maybe ('a)
Just : 'a -> Maybe ('a)
};
main : Bool () -> Maybe (Bool ()) ;
main x = case x of {
True => Nothing;
False => Just 0
}
-- main : Bool () -> Maybe (Bool ()) ;
-- main x =
-- case x of {
-- True => Nothing;
-- False => Just True
-- };
fun : Maybe (_Int) -> _Int ;
fun a =
case a of {
Just b => b;
Nothing => 0
};