Propagate type application, temporary remove nested pattern matching, fix void output

This commit is contained in:
Martin Fredin 2023-05-12 11:40:24 +02:00
parent 6260dc2c41
commit c3bcdfa81b
4 changed files with 175 additions and 140 deletions

View file

@ -1,4 +1,5 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Codegen.LlvmIr (
LLVMType (..),
@ -57,6 +58,7 @@ instance ToIr LLVMType where
Ref ty -> toIr ty <> "*"
Function t xs -> toIr t <> " (" <> intercalate ", " (map toIr xs) <> ")*"
Array n ty -> concat ["[", show n, " x ", toIr ty, "]"]
CustomType "void" -> "void"
CustomType (Ident ty) -> "%" <> ty
data LLVMComp

View file

@ -1,20 +1,21 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NondecreasingIndentation #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
-- {-# OPTIONS_GHC -Wno-incomplete-patterns #-}
{-# OPTIONS_GHC -Wno-incomplete-patterns #-}
module TypeChecker.TypeCheckerBidir (typecheck) where
import Auxiliary (int, maybeToRightM, onM, snoc,
typeof)
import Control.Applicative (Applicative (liftA2), (<|>))
import Control.Applicative (Applicative (liftA2), liftA3, (<|>))
import Control.Monad.Except (ExceptT, MonadError (throwError),
forM, runExceptT, unless, zipWithM,
zipWithM_)
MonadTrans (lift), forM, runExceptT,
unless, zipWithM, zipWithM_)
import Control.Monad.Extra (fromMaybeM, ifM)
import Control.Monad.State (MonadState, State, evalState, gets,
modify)
import Control.Monad.State (MonadState (get), State, StateT,
evalState, evalStateT, gets, modify)
import Data.Coerce (coerce)
import Data.Foldable (foldlM)
import Data.Function (on)
@ -215,10 +216,7 @@ check (ECase scrut branches) c = do
subtype a c
apply (T.ECase (scrut', a) [], a)
_ -> do
branches' <- forM branches $ \(Branch p e) -> do
p' <- checkPattern p =<< apply a
e' <- check e c
pure (T.Branch p' e')
branches' <- checkBranches branches a c
apply (T.ECase (scrut', a) branches', c)
@ -232,56 +230,84 @@ check e b = do
apply (e', b)
checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type)
-- Γ ⊢ p₁⇒e₁ ‥ pₙ⇒eₙ :: A ↑ C ⊣ Δ
checkBranches :: [Branch] -> Type -> Type -> Tc [T.Branch' Type]
checkBranches branches a c = evalStateT go mempty
where
go = forM branches $ \(Branch p e) -> do
p' <- patternMatch p =<< lift (apply a)
lift $ do
e' <- check e c
apply (T.Branch p' e')
substituteAll :: Type -> StateT (Map TVar TEVar) Tc Type
substituteAll t = get >>= \subs -> case t of
TAll tvar t
| Just tevar <- Map.lookup tvar subs ->
lift . pure $ substitute tvar tevar t
| otherwise -> do
tevar <- lift fresh
modify (Map.insert tvar tevar)
substituteAll (substitute tvar tevar t)
TFun t1 t2 -> onM TFun substituteAll t1 t2
t -> pure t
patternMatch :: Pattern -> Type -> StateT (Map TVar TEVar) Tc (T' T.Pattern' Type)
-- ------------------- PVar
-- Γ ⊢ x ↑ A ⊣ Γ,(x:A)
checkPattern (PVar x) a = do
-- Γ ⊢ x :: A ⊣ Γ,(x:A)
patternMatch (PVar x) a = lift $ do
insertEnv $ EnvVar x a
apply (T.PVar (coerce x), a)
-- ------------- PCatch
-- Γ ⊢ _ ↑ A ⊣ Γ
checkPattern PCatch a = apply (T.PCatch, a)
-- Γ ⊢ _ :: A ⊣ Γ
patternMatch PCatch a = lift $ apply (T.PCatch, a)
-- A = typeof(lit)
-- Γ ⊢ typeof(lit) <: A ⊣ Δ
-- ------------------------- PLit
-- Γ ⊢ lit ↑ A ⊣ Γ
checkPattern (PLit lit) a | a == typeof lit = apply (T.PLit lit, a)
checkPattern (PLit lit) a = error $ "\n -- MARTIN HJÄLP!! --\nUnimplemented match for: '" ++ printTree a ++ "' == '" ++ printTree (typeof lit) ++ "'"
-- Γ ⊢ lit :: A ⊣ Δ
patternMatch (PLit lit) a = lift $ do
subtype (typeof lit) a
apply (T.PLit lit, typeof lit)
-- Γ ∋ (K : T) Γ ⊢ A <: B ⊣ Δ
-- Γ ∋ (K : A) Γ ⊢ A <: C ⊣ Δ
-- ---------------------------
-- Γ ⊢ K ↑ T ⊣ Δ
checkPattern (PEnum k) b = do
a <- maybeToRightM ("Unknown constructor " ++ show k) =<< lookupInj k
-- Γ ⊢ K :: C ⊣ Δ
patternMatch (PEnum k) b = do
a <- lift (maybeToRightM ("Unknown constructor " ++ show k) =<< lookupInj k)
a <- substituteAll a
lift $ do
subtype a b
apply (T.PEnum (coerce k), a)
-- β α Γ Γ Δ
--
-- Γ ∋ (K : A) Θ₂ ⊢ p₁ ↑ [Θ₁]A₁ ⊣ Θ₂
-- Γ ∋ (K : A) Θ₂ ⊢ p₁ :: [Θ₁]A₁ ⊣ Θ₂
-- Γ ⊢ ∀ά₁‥άₘ A₁ → ‥ → Aₙ₊₁ = substituteAll(A) ⊣ Θ₁ ...
-- Θ₁ ⊢ Aₙ₊₁ <: B ⊣ Θ₂ Θₙ₊₁ ⊢ pₙ ↑ [Θₙ₊₁]Aₙ ⊣ Δ
-- -----------------------------------------------------------------------
-- Θ₁ ⊢ Aₙ₊₁ <: B ⊣ Θ₂ Β Θₙ₊₁ ⊢ pₙ :: [Θₙ₊₁]Aₙ ⊣ Δ
-- ----------------------------------------------------------------------------- PInj
-- Γ ⊢ K p₁‥pₙ ↑ B ⊣ Δ
{- checkPattern (PInj k ps) b = do
a <- maybeToRightM ("Unknown constructor " ++ show k) =<< lookupInj k
patternMatch (PInj k ps) b = do
a <- maybeToRightM ("Unknown constructor " ++ show k) =<< lift (lookupInj k)
a <- substituteAll a
lift $ subtype (getDataId a) b
let as = getArgs a
unless (length as == length ps) $ throwError "Wrong number of arguments!"
ps' <- zipWithM (\p a -> checkPattern p =<< apply a) ps as
apply (T.PInj (coerce k) ps', a)
ps' <- zipWithM (\p a -> patternMatch p =<< lift (apply a)) ps as
lift $ apply (T.PInj (coerce k) ps', a)
where
substituteAll t = case t of
TAll tvar t -> do
tevar <- fresh
substituteAll (substitute tvar tevar t)
TFun t1 t2 -> onM TFun substituteAll t1 t2
t -> pure t
getArgs = \case
TAll _ t -> getArgs t
@ -289,12 +315,7 @@ checkPattern (PEnum k) b = do
where
go acc = \case
TFun t1 t2 -> go (snoc t1 acc) t2
_ -> acc -}
_ -> acc
-- Example
@ -304,32 +325,37 @@ checkPattern (PEnum k) b = do
-- Θ₂ ⊢ p₂ ↑ [Θ₂][ά/α]A₂ ⊣ Δ
-- ---------------------------
-- Γ ⊢ K p₁ p₂ ↑ B ⊣ Δ
checkPattern (PInj name ps) t_patt = do
t_inj <- maybeToRightM "unknown constructor" =<< lookupInj name
let ts = getArgs t_inj
unless (length ts == length ps)
$ throwError "Wrong number of arguments!"
-- patternMatch (PInj name ps) t_patt = do
-- t_inj <- maybeToRightM "unknown constructor" =<< lookupInj name
-- let ts = getArgs t_inj
-- unless (length ts == length ps)
-- $ throwError "Wrong number of arguments!"
--
-- -- [ά/α]
-- sub <- substituteTVarsOf t_inj
-- subtype (sub $ getDataId t_inj) t_patt
-- let check p t = checkPattern p =<< apply (sub t)
-- ps' <- zipWithM check ps ts
-- apply (T.PInj (coerce name) ps', t_patt)
-- where
-- substituteTVarsOf = \case
-- TAll tvar t -> do
-- tevar <- fresh
-- (substitute tvar tevar .) <$> substituteTVarsOf t
-- _ -> pure id
--
-- getArgs = \case
-- TAll _ t -> getArgs t
-- t -> go [] t
-- where
-- go acc = \case
-- TFun t1 t2 -> go (snoc t1 acc) t2
-- _ -> acc
-- [ά/α]
sub <- substituteTVarsOf t_inj
subtype (sub $ getDataId t_inj) t_patt
let check p t = checkPattern p =<< apply (sub t)
ps' <- zipWithM check ps ts
apply (T.PInj (coerce name) ps', t_patt)
where
substituteTVarsOf = \case
TAll tvar t -> do
tevar <- fresh
(substitute tvar tevar .) <$> substituteTVarsOf t
_ -> pure id
getArgs = \case
TAll _ t -> getArgs t
t -> go [] t
where
go acc = \case
TFun t1 t2 -> go (snoc t1 acc) t2
_ -> acc
-- | Γ ⊢ e ↓ A ⊣ Δ
-- Under input context Γ, e infers output type A, with output context ∆
@ -370,9 +396,9 @@ infer (EAnn e a) = do
-- ----------------------------------- →E
-- Γ ⊢ e₁ e₂ ↓ C ⊣ Δ
infer (EApp e1 e2) = do
e1'@(_, a) <- infer e1
(e2', c) <- applyInfer a e2
apply (T.EApp e1' e2', c)
(e1', a) <- infer e1
(e2', a', c) <- applyInfer a e2
apply (T.EApp (e1', a') e2', c)
-- Γ,ά,έ,(x:ά) ⊢ e ↑ έ ⊣ Δ,(x:ά),Θ
-- ------------------------------- →I
@ -418,25 +444,25 @@ infer (EAdd e1 e2) = do
-- Γ ⊢ e ↑ A ⊣ Θ₁ Θₙ ⊢ pₙ⇒eₙ ↓ [Θₙ]C ⊣ Δ
-- --------------------------------------- Case↓
-- Γ ⊢ case e of {p₁⇒e₁ ‥ pₙ⇒eₙ} ↓ C ⊣ Δ
infer (ECase scrut branches) = do
(scrut', a) <- infer scrut
case branches of
[] -> apply (T.ECase (scrut', a) [], a)
(Branch _ e):_ -> do
(_, b)<- infer e
(branches', b') <- foldlM go ([], b) branches
apply (T.ECase (scrut', a) branches', b')
where
go (pi, b) (Branch p e) = do
p' <- checkPattern p =<< apply a
e'@(_, b') <- infer e
subtype b' b
apply (T.Branch p' e' : pi, b')
-- infer (ECase scrut branches) = do
-- (scrut', a) <- infer scrut
-- case branches of
-- [] -> apply (T.ECase (scrut', a) [], a)
-- (Branch _ e):_ -> do
-- (_, b)<- infer e
-- (branches', b') <- foldlM go ([], b) branches
-- apply (T.ECase (scrut', a) branches', b')
-- where
-- go (pi, b) (Branch p e) = do
-- p' <- checkPattern p =<< apply a
-- e'@(_, b') <- infer e
-- subtype b' b
-- apply (T.Branch p' e' : pi, b')
-- | Γ ⊢ A • e ⇓ C ⊣ Δ
-- Under input context Γ , applying a function of type A to e infers type C, with output context ∆
-- Instantiate existential type variables until there is an arrow type.
applyInfer :: Type -> Exp -> Tc (T' T.Exp' Type, Type)
applyInfer :: Type -> Exp -> Tc (T' T.Exp' Type, Type, Type)
-- Γ,ά ⊢ [ά/α]A • e ⇓ C ⊣ Δ
-- ------------------------ ∀App
@ -453,20 +479,21 @@ applyInfer (TEVar alpha) e = do
alpha1 <- fresh
alpha2 <- fresh
(env_l, env_r) <- gets (splitOn (EnvTEVar alpha) . env)
let alpha_solution = on TFun TEVar alpha1 alpha2
putEnv $ (env_l
:|> EnvTEVar alpha2
:|> EnvTEVar alpha1
:|> EnvTEVarSolved alpha (on TFun TEVar alpha1 alpha2)
:|> EnvTEVarSolved alpha alpha_solution
) <> env_r
e' <- check e $ TEVar alpha1
apply (e', TEVar alpha2)
apply (e', alpha_solution, TEVar alpha2)
-- Γ ⊢ e ↑ A ⊣ Δ
-- --------------------- →App
-- Γ ⊢ A → C • e ⇓ C ⊣ Δ
applyInfer (TFun a c) e = do
exp' <- check e a
apply (exp', c)
e'@(_, a') <- check e a
apply (e',TFun a' c, c)
applyInfer a e = throwError ("Cannot apply type " ++ show a ++ " with expression " ++ show e)
@ -812,6 +839,7 @@ instance Apply (T.Branch' Type) where apply = applyBranch
instance Apply (T.Pattern' Type) where apply = applyPattern
instance Apply a => Apply [a] where apply = mapM apply
instance (Apply a, Apply b) => Apply (a, b) where apply = applyPair
instance (Apply a, Apply b, Apply c) => Apply (a, b, c) where apply = applyTuple
instance Apply T.Ident where apply = pure
applyType :: Type -> Tc Type
@ -834,7 +862,7 @@ applyType' cxt typ | typ == typ' = typ'
TFun t1 t2 -> on TFun (applyType' cxt) t1 t2
-- [Γ](∀α. A) = (∀α. [Γ]A)
TAll tvar t -> TAll tvar $ applyType' cxt t
TIdent t -> typ
TIdent _ -> typ
applyExp :: T.Exp' Type -> Tc (T.Exp' Type)
applyExp exp = case exp of
@ -866,6 +894,9 @@ applyPattern = \case
applyPair :: (Apply a, Apply b) => (a, b) -> Tc (a, b)
applyPair (x, y) = liftA2 (,) (apply x) (apply y)
applyTuple :: (Apply a, Apply b, Apply c) => (a, b, c) -> Tc (a, b, c)
applyTuple (x, y, z) = liftA3 (,,) (apply x) (apply y) (apply z)
---------------------------------------------------------------------------
-- * Debug
---------------------------------------------------------------------------

View file

@ -90,13 +90,15 @@ prtSig (x, t) =
]
instance (Print a, Print t) => Print (T a t) where
prt i (x, t) =
concatD
[ -- doc $ showString "("
{- , -} prt i x
-- , doc $ showString ":"
-- , prt 0 t
-- , doc $ showString ")"
prt i (x, t) = withT
where
noT = prt i x
withT = concatD
[ doc $ showString "("
, prt i x
, doc $ showString ":"
, prt 0 t
, doc $ showString ")"
]
instance Print t => Print [Bind' t] where

View file

@ -22,7 +22,7 @@ import TypeChecker.RemoveForall (removeForall)
import TypeChecker.ReportTEVar (reportTEVar)
import TypeChecker.TypeChecker (TypeChecker (Bi))
import TypeChecker.TypeCheckerBidir (typecheck)
import TypeChecker.TypeCheckerIr qualified as T
import qualified TypeChecker.TypeCheckerIr as T
test = hspec testTypeCheckerBidir
@ -189,16 +189,16 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do
run (fs ++ wrong2) `shouldNotSatisfy` ok
specify "Third wrong case expression rejected" $
run (fs ++ wrong3) `shouldNotSatisfy` ok
specify "Forth wrong case expression rejected" $
run (fs ++ wrong4) `shouldNotSatisfy` ok
specify "First correct case expression accepted" $
run (fs ++ correct1) `shouldSatisfy` ok
-- specify "Forth wrong case expression rejected" $
-- run (fs ++ wrong4) `shouldNotSatisfy` ok
-- specify "First correct case expression accepted" $
-- run (fs ++ correct1) `shouldSatisfy` ok
specify "Second correct case expression accepted" $
run (fs ++ correct2) `shouldSatisfy` ok
specify "Third correct case expression accepted" $
run (fs ++ correct3) `shouldSatisfy` ok
specify "Forth correct case expression accepted" $
run (fs ++ correct4) `shouldSatisfy` ok
-- specify "Third correct case expression accepted" $
-- run (fs ++ correct3) `shouldSatisfy` ok
-- specify "Forth correct case expression accepted" $
-- run (fs ++ correct4) `shouldSatisfy` ok
where
fs =
[ "data List a where"
@ -254,9 +254,9 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do
correct4 =
[ "elems : List (List c) -> Int"
, "elems = \\list. case list of"
, " Nil => 0"
, " Cons Nil Nil => 0"
, " Cons Nil xs => elems xs"
--, " Nil => 0"
--, " Cons Nil Nil => 0"
--, " Cons Nil xs => elems xs"
, " Cons (Cons _ ys) xs => 1 + elems (Cons ys xs)"
]