Implemented potential fix for one of the bugs

This commit is contained in:
sebastianselander 2023-03-05 14:34:39 +01:00
parent fe63fa6215
commit 778fec3dc4
3 changed files with 262 additions and 195 deletions

View file

@ -29,6 +29,7 @@ id x = x ;
main : Maybe ('a -> 'a) ; main : Maybe ('a -> 'a) ;
main = Just id; main = Just id;
``` ```
UPDATE: Might have found a fix. Need to be tested.
### The function f is not carried into the case-expression ### The function f is not carried into the case-expression

View file

@ -1,40 +1,42 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -Wno-unused-matches #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use mapAndUnzipM" #-} {-# HLINT ignore "Use mapAndUnzipM" #-}
{-# OPTIONS_GHC -Wno-overlapping-patterns #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# OPTIONS_GHC -Wno-unused-matches #-}
-- | A module for type checking and inference using algorithm W, Hindley-Milner -- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeChecker where module TypeChecker.TypeChecker where
import Control.Monad.Except import Control.Monad.Except
import Control.Monad.Reader import Control.Monad.Reader
import Control.Monad.State import Control.Monad.State
import Data.Functor.Identity (runIdentity) import Data.Foldable (traverse_)
import Data.List (foldl') import Data.Functor.Identity (runIdentity)
import Data.Map (Map) import Data.List (foldl')
import qualified Data.Map as M import Data.Map (Map)
import Data.Set (Set) import Data.Map qualified as M
import qualified Data.Set as S import Data.Set (Set)
import Data.Set qualified as S
import Data.Foldable (traverse_) import Debug.Trace (trace)
import Debug.Trace (trace) import Grammar.Abs
import Grammar.Abs import Grammar.Print (printTree)
import Grammar.Print (printTree) import TypeChecker.TypeCheckerIr (
import qualified TypeChecker.TypeCheckerIr as T Ctx (..),
import TypeChecker.TypeCheckerIr (Ctx (..), Env (..), Error, Infer, Env (..),
Poly (..), Subst) Error,
Infer,
{- BUGS TODO: Poly (..),
Occurs fails on data types, e.g declared Maybe a, used in fn as Maybe (a -> a) Subst,
-} )
import TypeChecker.TypeCheckerIr qualified as T
initCtx = Ctx mempty initCtx = Ctx mempty
initEnv = Env 0 mempty mempty initEnv = Env 0 mempty mempty
runPretty :: Exp -> Either Error String runPretty :: Exp -> Either Error String
runPretty = fmap (printTree . fst). run . inferExp runPretty = fmap (printTree . fst) . run . inferExp
run :: Infer a -> Either Error a run :: Infer a -> Either Error a
run = runC initEnv initCtx run = runC initEnv initCtx
@ -45,24 +47,60 @@ runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e
typecheck :: Program -> Either Error T.Program typecheck :: Program -> Either Error T.Program
typecheck = run . checkPrg typecheck = run . checkPrg
{- | Start by freshening the type variable of data types to avoid clash with
other user defined polymorphic types
-}
freshenData :: Data -> Infer Data
freshenData (Data (Constr name ts) constrs) = do
fr <- fresh
let fr' = case fr of
TPol a -> a
-- Meh, this part assumes fresh generates a polymorphic type
_ -> error "Bug: implementation of fresh and freshenData are not compatible"
let new_ts = map (freshenType fr') ts
let new_constrs = map (freshenConstr fr') constrs
return $ Data (Constr name new_ts) new_constrs
where
freshenType :: Ident -> Type -> Type
freshenType iden = \case
(TPol a) -> TPol iden
(TArr a b) -> TArr (freshenType iden a) (freshenType iden b)
(TConstr (Constr a ts)) -> TConstr (Constr a (map (freshenType iden) ts))
rest -> rest
freshenConstr :: Ident -> Constructor -> Constructor
freshenConstr iden (Constructor name t) = Constructor name (freshenType iden t)
checkData :: Data -> Infer () checkData :: Data -> Infer ()
checkData d = case d of checkData d = do
(Data typ@(Constr name ts) constrs) -> do trace ("OLD: " ++ show d) return ()
unless (all isPoly ts) (throwError $ unwords ["Data type incorrectly declared"]) d' <- freshenData d
traverse_ (\(Constructor name' t') trace ("NEW: " ++ show d') return ()
-> if TConstr typ == retType t' case d' of
then insertConstr name' t' else (Data typ@(Constr name ts) constrs) -> do
throwError $ unless
unwords (all isPoly ts)
[ "return type of constructor:" (throwError $ unwords ["Data type incorrectly declared"])
, printTree name traverse_
, "with type:" ( \(Constructor name' t') ->
, printTree (retType t') if TConstr typ == retType t'
, "does not match data: " then insertConstr name' t'
, printTree typ]) constrs else
throwError $
unwords
[ "return type of constructor:"
, printTree name
, "with type:"
, printTree (retType t')
, "does not match data: "
, printTree typ
]
)
constrs
retType :: Type -> Type retType :: Type -> Type
retType (TArr _ t2) = retType t2 retType (TArr _ t2) = retType t2
retType a = a retType a = a
checkPrg :: Program -> Infer T.Program checkPrg :: Program -> Infer T.Program
checkPrg (Program bs) = do checkPrg (Program bs) = do
@ -71,54 +109,62 @@ checkPrg (Program bs) = do
where where
preRun :: [Def] -> Infer () preRun :: [Def] -> Infer ()
preRun [] = return () preRun [] = return ()
preRun (x:xs) = case x of preRun (x : xs) = case x of
DBind (Bind n t _ _ _ ) -> insertSig n t >> preRun xs DBind (Bind n t _ _ _) -> insertSig n t >> preRun xs
DData d@(Data _ _) -> checkData d >> preRun xs DData d@(Data _ _) -> checkData d >> preRun xs
checkDef :: [Def] -> Infer [T.Def] checkDef :: [Def] -> Infer [T.Def]
checkDef [] = return [] checkDef [] = return []
checkDef (x:xs) = case x of checkDef (x : xs) = case x of
(DBind b) -> do (DBind b) -> do
b' <- checkBind b b' <- checkBind b
fmap (T.DBind b' :) (checkDef xs) fmap (T.DBind b' :) (checkDef xs)
(DData d) -> fmap (T.DData d :) (checkDef xs) (DData d) -> fmap (T.DData d :) (checkDef xs)
checkBind :: Bind -> Infer T.Bind checkBind :: Bind -> Infer T.Bind
checkBind (Bind n t _ args e) = do checkBind (Bind n t _ args e) = do
(t', e') <- inferExp $ makeLambda e (reverse args) (t', e') <- inferExp $ makeLambda e (reverse args)
s <- unify t t' s <- unify t t'
let t'' = apply s t let t'' = apply s t
unless (t `typeEq` t'') (throwError $ unwords ["Top level signature" unless
, printTree t (t `typeEq` t'')
, "does not match body with inferred type:" ( throwError $
, printTree t'' unwords
]) [ "Top level signature"
, printTree t
, "does not match body with inferred type:"
, printTree t''
]
)
return $ T.Bind (n, t) e' return $ T.Bind (n, t) e'
where where
makeLambda :: Exp -> [Ident] -> Exp makeLambda :: Exp -> [Ident] -> Exp
makeLambda = foldl (flip EAbs) makeLambda = foldl (flip EAbs)
-- | Check if two types are considered equal {- | Check if two types are considered equal
-- For the purpose of the algorithm two polymorphic types are always considered equal For the purpose of the algorithm two polymorphic types are always considered
equal
-}
typeEq :: Type -> Type -> Bool typeEq :: Type -> Type -> Bool
typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r' typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r'
typeEq (TMono a) (TMono b) = a == b typeEq (TMono a) (TMono b) = a == b
typeEq (TConstr (Constr name a)) (TConstr (Constr name' b)) = length a == length b typeEq (TConstr (Constr name a)) (TConstr (Constr name' b)) =
&& name == name' length a == length b
&& and (zipWith typeEq a b) && name == name'
typeEq (TPol _) (TPol _) = True && and (zipWith typeEq a b)
typeEq _ _ = False typeEq (TPol _) (TPol _) = True
typeEq _ _ = False
isMoreSpecificOrEq :: Type -> Type -> Bool isMoreSpecificOrEq :: Type -> Type -> Bool
isMoreSpecificOrEq _ (TPol _) = True isMoreSpecificOrEq _ (TPol _) = True
isMoreSpecificOrEq (TArr a b) (TArr c d) = isMoreSpecificOrEq a c && isMoreSpecificOrEq b d isMoreSpecificOrEq (TArr a b) (TArr c d) = isMoreSpecificOrEq a c && isMoreSpecificOrEq b d
isMoreSpecificOrEq (TConstr (Constr n1 ts1)) (TConstr (Constr n2 ts2)) isMoreSpecificOrEq (TConstr (Constr n1 ts1)) (TConstr (Constr n2 ts2)) =
= n1 == n2 && length ts1 == length ts2 && and (zipWith isMoreSpecificOrEq ts1 ts2) n1 == n2 && length ts1 == length ts2 && and (zipWith isMoreSpecificOrEq ts1 ts2)
isMoreSpecificOrEq a b = a == b isMoreSpecificOrEq a b = a == b
isPoly :: Type -> Bool isPoly :: Type -> Bool
isPoly (TPol _) = True isPoly (TPol _) = True
isPoly _ = False isPoly _ = False
inferExp :: Exp -> Infer (Type, T.Exp) inferExp :: Exp -> Infer (Type, T.Exp)
inferExp e = do inferExp e = do
@ -128,57 +174,59 @@ inferExp e = do
replace :: Type -> T.Exp -> T.Exp replace :: Type -> T.Exp -> T.Exp
replace t = \case replace t = \case
T.ELit _ e -> T.ELit t e T.ELit _ e -> T.ELit t e
T.EId (n, _) -> T.EId (n, t) T.EId (n, _) -> T.EId (n, t)
T.EAbs _ name e -> T.EAbs t name e T.EAbs _ name e -> T.EAbs t name e
T.EApp _ e1 e2 -> T.EApp t e1 e2 T.EApp _ e1 e2 -> T.EApp t e1 e2
T.EAdd _ e1 e2 -> T.EAdd t e1 e2 T.EAdd _ e1 e2 -> T.EAdd t e1 e2
T.ELet (T.Bind (n, _) e1) e2 -> T.ELet (T.Bind (n, t) e1) e2 T.ELet (T.Bind (n, _) e1) e2 -> T.ELet (T.Bind (n, t) e1) e2
T.ECase _ expr injs -> T.ECase t expr injs T.ECase _ expr injs -> T.ECase t expr injs
algoW :: Exp -> Infer (Subst, Type, T.Exp) algoW :: Exp -> Infer (Subst, Type, T.Exp)
algoW = \case algoW = \case
-- \| TODO: More testing need to be done. Unsure of the correctness of this
-- | TODO: Reason more about this one. Could be wrong
EAnn e t -> do EAnn e t -> do
(s1, t', e') <- algoW e (s1, t', e') <- algoW e
unless (t `isMoreSpecificOrEq` t') (throwError $ unwords unless
["Annotated type:" (t `isMoreSpecificOrEq` t')
, printTree t ( throwError $
, "does not match inferred type:" unwords
, printTree t' ]) [ "Annotated type:"
, printTree t
, "does not match inferred type:"
, printTree t'
]
)
applySt s1 $ do applySt s1 $ do
s2 <- unify t t' s2 <- unify t t'
return (s2 `compose` s1, t, e') return (s2 `compose` s1, t, e')
-- | ------------------ -- \| ------------------
-- | Γ ⊢ i : Int, ∅ -- \| Γ ⊢ i : Int, ∅
ELit (LInt n) -> return (nullSubst, TMono "Int", T.ELit (TMono "Int") (LInt n))
ELit (LInt n) -> return (nullSubst, TMono "Int", T.ELit (TMono "Int") (LInt n))
ELit a -> error $ "NOT IMPLEMENTED YET: ELit " ++ show a ELit a -> error $ "NOT IMPLEMENTED YET: ELit " ++ show a
-- \| x : σ ∈ Γ τ = inst(σ)
-- | x : σ ∈ Γ τ = inst(σ) -- \| ----------------------
-- | ---------------------- -- \| Γ ⊢ x : τ, ∅
-- | Γ ⊢ x : τ, ∅
EId i -> do EId i -> do
var <- asks vars var <- asks vars
case M.lookup i var of case M.lookup i var of
Just t -> inst t >>= \x -> return (nullSubst, x, T.EId (i, x)) Just t -> inst t >>= \x -> return (nullSubst, x, T.EId (i, x))
Nothing -> do Nothing -> do
sig <- gets sigs sig <- gets sigs
case M.lookup i sig of case M.lookup i sig of
Just t -> return (nullSubst, t, T.EId (i, t)) Just t -> return (nullSubst, t, T.EId (i, t))
Nothing -> do Nothing -> do
constr <- gets constructors constr <- gets constructors
case M.lookup i constr of case M.lookup i constr of
Just t -> return (nullSubst, t, T.EId (i, t)) Just t -> return (nullSubst, t, T.EId (i, t))
Nothing -> throwError $ "Unbound variable: " ++ show i Nothing -> throwError $ "Unbound variable: " ++ show i
-- | τ = newvar Γ, x : τ ⊢ e : τ', S -- \| τ = newvar Γ, x : τ ⊢ e : τ', S
-- | --------------------------------- -- \| ---------------------------------
-- | Γ ⊢ w λx. e : Sτ → τ', S -- \| Γ ⊢ w λx. e : Sτ → τ', S
EAbs name e -> do EAbs name e -> do
fr <- fresh fr <- fresh
@ -188,11 +236,11 @@ algoW = \case
let newArr = TArr varType t' let newArr = TArr varType t'
return (s1, newArr, T.EAbs newArr (name, varType) e') return (s1, newArr, T.EAbs newArr (name, varType) e')
-- | Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁ -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
-- | s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int) -- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
-- | ------------------------------------------ -- \| ------------------------------------------
-- | Γ ⊢ e₀ + e₁ : Int, S₃S₂S₁S₀ -- \| Γ ⊢ e₀ + e₁ : Int, S₃S₂S₁S₀
-- This might be wrong -- This might be wrong
EAdd e0 e1 -> do EAdd e0 e1 -> do
(s1, t0, e0') <- algoW e0 (s1, t0, e0') <- algoW e0
@ -203,10 +251,10 @@ algoW = \case
s4 <- unify (apply s3 t1) (TMono "Int") s4 <- unify (apply s3 t1) (TMono "Int")
return (s4 `compose` s3 `compose` s2 `compose` s1, TMono "Int", T.EAdd (TMono "Int") e0' e1') return (s4 `compose` s3 `compose` s2 `compose` s1, TMono "Int", T.EAdd (TMono "Int") e0' e1')
-- | Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1 -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1
-- | τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ') -- \| τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ')
-- | -------------------------------------- -- \| --------------------------------------
-- | Γ ⊢ e₀ e₁ : S₂τ', S₂S₁S₀ -- \| Γ ⊢ e₀ e₁ : S₂τ', S₂S₁S₀
EApp e0 e1 -> do EApp e0 e1 -> do
fr <- fresh fr <- fresh
@ -218,11 +266,11 @@ algoW = \case
let t = apply s2 fr let t = apply s2 fr
return (s2 `compose` s1 `compose` s0, t, T.EApp t e0' e1') return (s2 `compose` s1 `compose` s0, t, T.EApp t e0' e1')
-- | Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁ -- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁
-- | ---------------------------------------------- -- \| ----------------------------------------------
-- | Γ ⊢ let x = e₀ in e₁ : τ', S₁S₀ -- \| Γ ⊢ let x = e₀ in e₁ : τ', S₁S₀
-- The bar over S₀ and Γ means "generalize" -- The bar over S₀ and Γ means "generalize"
ELet name e0 e1 -> do ELet name e0 e1 -> do
(s1, t1, e0') <- algoW e0 (s1, t1, e0') <- algoW e0
@ -230,18 +278,17 @@ algoW = \case
let t' = generalize (apply s1 env) t1 let t' = generalize (apply s1 env) t1
withBinding name t' $ do withBinding name t' $ do
(s2, t2, e1') <- algoW e1 (s2, t2, e1') <- algoW e1
return (s2 `compose` s1, t2, T.ELet (T.Bind (name,t2) e0') e1') return (s2 `compose` s1, t2, T.ELet (T.Bind (name, t2) e0') e1')
ECase caseExpr injs -> do ECase caseExpr injs -> do
(s0, t0, e0') <- algoW caseExpr (s0, t0, e0') <- algoW caseExpr
(injs', ts) <- unzip <$> mapM (checkInj t0) injs (injs', ts) <- unzip <$> mapM (checkInj t0) injs
case ts of case ts of
[] -> throwError "Case expression missing any matches" [] -> throwError "Case expression missing any matches"
ts -> do ts -> do
unified <- zipWithM unify ts (tail ts) unified <- zipWithM unify ts (tail ts)
let unified' = foldl' compose mempty unified let unified' = foldl' compose mempty unified
let typ = apply unified' (head ts) let typ = apply unified' (head ts)
return (unified', typ, T.ECase typ e0' injs') return (unified', typ, T.ECase typ e0' injs')
-- | Unify two types producing a new substitution -- | Unify two types producing a new substitution
unify :: Type -> Type -> Infer Subst unify :: Type -> Type -> Infer Subst
@ -253,27 +300,40 @@ unify t0 t1 = case (trace ("LEFT: " ++ show t0) t0, trace ("RIGHT: " ++ show t1)
(TPol a, b) -> occurs a b (TPol a, b) -> occurs a b
(a, TPol b) -> occurs b a (a, TPol b) -> occurs b a
(TMono a, TMono b) -> if a == b then return M.empty else throwError "Types do not unify" (TMono a, TMono b) -> if a == b then return M.empty else throwError "Types do not unify"
-- | TODO: Figure out a cleaner way to express the same thing -- \| TODO: Figure out a cleaner way to express the same thing
(TConstr (Constr name t), TConstr (Constr name' t')) -> if name == name' && length t == length t' (TConstr (Constr name t), TConstr (Constr name' t')) ->
then do if name == name' && length t == length t'
xs <- zipWithM unify t t' then do
return $ foldr compose nullSubst xs xs <- zipWithM unify t t'
else throwError $ unwords return $ foldr compose nullSubst xs
["Type constructor:" else
, printTree name throwError $
, "(" ++ printTree t ++ ")" unwords
, "does not match with:" [ "Type constructor:"
, printTree name' , printTree name
, "(" ++ printTree t' ++ ")"] , "(" ++ printTree t ++ ")"
, "does not match with:"
, printTree name'
, "(" ++ printTree t' ++ ")"
]
(a, b) -> throwError . unwords $ ["Type:", printTree a, "can't be unified with:", printTree b] (a, b) -> throwError . unwords $ ["Type:", printTree a, "can't be unified with:", printTree b]
-- | Check if a type is contained in another type. {- | Check if a type is contained in another type.
-- I.E. { a = a -> b } is an unsolvable constraint since there is no substitution such that these are equal I.E. { a = a -> b } is an unsolvable constraint since there is no substitution such that these are equal
-}
occurs :: Ident -> Type -> Infer Subst occurs :: Ident -> Type -> Infer Subst
occurs _ (TPol _) = return nullSubst occurs _ (TPol _) = return nullSubst
occurs i t = if S.member i (free t) occurs i t =
then throwError $ unwords ["Occurs check failed, can't unify", printTree (TPol i), "with", printTree t] if S.member i (free t)
else return $ M.singleton i t then
throwError $
unwords
[ "Occurs check failed, can't unify"
, printTree (TPol i)
, "with"
, printTree t
]
else return $ M.singleton i t
-- | Generalize a type over all free variables in the substitution set -- | Generalize a type over all free variables in the substitution set
generalize :: Map Ident Poly -> Type -> Poly generalize :: Map Ident Poly -> Type -> Poly
@ -292,44 +352,45 @@ compose m1 m2 = M.map (apply m1) m2 `M.union` m1
-- | A class representing free variables functions -- | A class representing free variables functions
class FreeVars t where class FreeVars t where
-- | Get all free variables from t -- | Get all free variables from t
free :: t -> Set Ident free :: t -> Set Ident
-- | Apply a substitution to t
apply :: Subst -> t -> t -- | Apply a substitution to t
apply :: Subst -> t -> t
instance FreeVars Type where instance FreeVars Type where
free :: Type -> Set Ident free :: Type -> Set Ident
free (TPol a) = S.singleton a free (TPol a) = S.singleton a
free (TMono _) = mempty free (TMono _) = mempty
free (TArr a b) = free a `S.union` free b free (TArr a b) = free a `S.union` free b
-- | Not guaranteed to be correct -- \| Not guaranteed to be correct
free (TConstr (Constr _ a)) = foldl' (\acc x -> free x `S.union` acc) S.empty a free (TConstr (Constr _ a)) = foldl' (\acc x -> free x `S.union` acc) S.empty a
apply :: Subst -> Type -> Type apply :: Subst -> Type -> Type
apply sub t = do apply sub t = do
case t of case t of
TMono a -> TMono a TMono a -> TMono a
TPol a -> case M.lookup a sub of TPol a -> case M.lookup a sub of
Nothing -> TPol a Nothing -> TPol a
Just t -> t Just t -> t
TArr a b -> TArr (apply sub a) (apply sub b) TArr a b -> TArr (apply sub a) (apply sub b)
TConstr (Constr name a) -> TConstr (Constr name (map (apply sub) a)) TConstr (Constr name a) -> TConstr (Constr name (map (apply sub) a))
instance FreeVars Poly where instance FreeVars Poly where
free :: Poly -> Set Ident free :: Poly -> Set Ident
free (Forall xs t) = free t S.\\ S.fromList xs free (Forall xs t) = free t S.\\ S.fromList xs
apply :: Subst -> Poly -> Poly apply :: Subst -> Poly -> Poly
apply s (Forall xs t) = Forall xs (apply (foldr M.delete s xs) t) apply s (Forall xs t) = Forall xs (apply (foldr M.delete s xs) t)
instance FreeVars (Map Ident Poly) where instance FreeVars (Map Ident Poly) where
free :: Map Ident Poly -> Set Ident free :: Map Ident Poly -> Set Ident
free m = foldl' S.union S.empty (map free $ M.elems m) free m = foldl' S.union S.empty (map free $ M.elems m)
apply :: Subst -> Map Ident Poly -> Map Ident Poly apply :: Subst -> Map Ident Poly -> Map Ident Poly
apply s = M.map (apply s) apply s = M.map (apply s)
-- | Apply substitutions to the environment. -- | Apply substitutions to the environment.
applySt :: Subst -> Infer a -> Infer a applySt :: Subst -> Infer a -> Infer a
applySt s = local (\st -> st { vars = apply s (vars st) }) applySt s = local (\st -> st {vars = apply s (vars st)})
-- | Represents the empty substition set -- | Represents the empty substition set
nullSubst :: Subst nullSubst :: Subst
@ -339,68 +400,73 @@ nullSubst = M.empty
fresh :: Infer Type fresh :: Infer Type
fresh = do fresh = do
n <- gets count n <- gets count
modify (\st -> st { count = n + 1 }) modify (\st -> st {count = n + 1})
return . TPol . Ident $ show n return . TPol . Ident $ show n
-- | Run the monadic action with an additional binding -- | Run the monadic action with an additional binding
withBinding :: (Monad m, MonadReader Ctx m) => Ident -> Poly -> m a -> m a withBinding :: (Monad m, MonadReader Ctx m) => Ident -> Poly -> m a -> m a
withBinding i p = local (\st -> st { vars = M.insert i p (vars st) }) withBinding i p = local (\st -> st {vars = M.insert i p (vars st)})
-- | Insert a function signature into the environment -- | Insert a function signature into the environment
insertSig :: Ident -> Type -> Infer () insertSig :: Ident -> Type -> Infer ()
insertSig i t = modify (\st -> st { sigs = M.insert i t (sigs st) }) insertSig i t = modify (\st -> st {sigs = M.insert i t (sigs st)})
-- | Insert a constructor with its data type -- | Insert a constructor with its data type
insertConstr :: Ident -> Type -> Infer () insertConstr :: Ident -> Type -> Infer ()
insertConstr i t = modify (\st -> st { constructors = M.insert i t (constructors st) }) insertConstr i t = modify (\st -> st {constructors = M.insert i t (constructors st)})
-------- PATTERN MATCHING --------- -------- PATTERN MATCHING ---------
-- "case expr of", the type of 'expr' is caseType -- "case expr of", the type of 'expr' is caseType
checkInj :: Type -> Inj -> Infer (T.Inj, Type); checkInj :: Type -> Inj -> Infer (T.Inj, Type)
checkInj caseType (Inj it expr) = do checkInj caseType (Inj it expr) = do
(args, t') <- initType caseType it (args, t') <- initType caseType it
(s, t, e') <- local (\st -> st { vars = args }) (algoW expr) (s, t, e') <- local (\st -> st {vars = args}) (algoW expr)
return (T.Inj (it, t') e', t) return (T.Inj (it, t') e', t)
initType :: Type -> Init -> Infer (Map Ident Poly, Type) initType :: Type -> Init -> Infer (Map Ident Poly, Type)
initType expected = \case initType expected = \case
InitLit lit ->
InitLit lit -> let returnType = litType lit let returnType = litType lit
in if expected == returnType in if expected == returnType
then return (mempty,expected) then return (mempty, expected)
else throwError $ unwords [ "Inferred type" else
, printTree returnType throwError $
, "does not match expected type:" unwords
, printTree expected [ "Inferred type"
] , printTree returnType
, "does not match expected type:"
, printTree expected
]
InitConstr c args -> do InitConstr c args -> do
st <- gets constructors st <- gets constructors
case M.lookup c st of case M.lookup c st of
Nothing -> throwError $ unwords ["Constructor:" Nothing ->
, printTree c throwError $
, "does not exist" unwords
] [ "Constructor:"
, printTree c
, "does not exist"
]
Just t -> do Just t -> do
let flat = flattenType t let flat = flattenType t
let returnType = last flat let returnType = last flat
case (length (init flat) == length args, returnType `isMoreSpecificOrEq` expected) of case (length (init flat) == length args, returnType `isMoreSpecificOrEq` expected) of
(True, True) -> return (M.fromList $ zip args (map (Forall []) flat), expected) (True, True) -> return (M.fromList $ zip args (map (Forall []) flat), expected)
(False, _) -> throwError $ "Can't partially match on the constructor: " ++ printTree c (False, _) -> throwError $ "Can't partially match on the constructor: " ++ printTree c
(_, False) -> throwError $ unwords [ "Inferred type" (_, False) ->
, printTree returnType throwError $
, "does not match expected type:" unwords
, printTree expected [ "Inferred type"
] , printTree returnType
, "does not match expected type:"
, printTree expected
]
InitCatch -> return (mempty, expected) InitCatch -> return (mempty, expected)
flattenType :: Type -> [Type] flattenType :: Type -> [Type]
flattenType (TArr a b) = flattenType a ++ flattenType b flattenType (TArr a b) = flattenType a ++ flattenType b
flattenType a = [a] flattenType a = [a]
litType :: Literal -> Type litType :: Literal -> Type
litType (LInt i) = TMono "Int" litType (LInt i) = TMono "Int"

View file

@ -16,7 +16,7 @@ data Maybe ('a) where {
id : 'a -> 'a ; id : 'a -> 'a ;
id x = x ; id x = x ;
main : Maybe ('a -> 'a) ; main : Maybe ('a -> 'a) ;
main = Just id; main = Just id;
-- data Either ('a 'b) where { -- data Either ('a 'b) where {