Propagate type application, temporary remove nested pattern matching, fix void output

This commit is contained in:
Martin Fredin 2023-05-12 11:40:24 +02:00
parent 6260dc2c41
commit c3bcdfa81b
4 changed files with 175 additions and 140 deletions

View file

@ -1,4 +1,5 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Codegen.LlvmIr (
LLVMType (..),
@ -57,6 +58,7 @@ instance ToIr LLVMType where
Ref ty -> toIr ty <> "*"
Function t xs -> toIr t <> " (" <> intercalate ", " (map toIr xs) <> ")*"
Array n ty -> concat ["[", show n, " x ", toIr ty, "]"]
CustomType "void" -> "void"
CustomType (Ident ty) -> "%" <> ty
data LLVMComp

View file

@ -1,20 +1,21 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
-- {-# OPTIONS_GHC -Wno-incomplete-patterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NondecreasingIndentation #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# OPTIONS_GHC -Wno-incomplete-patterns #-}
module TypeChecker.TypeCheckerBidir (typecheck) where
import Auxiliary (int, maybeToRightM, onM, snoc,
typeof)
import Control.Applicative (Applicative (liftA2), (<|>))
import Control.Applicative (Applicative (liftA2), liftA3, (<|>))
import Control.Monad.Except (ExceptT, MonadError (throwError),
forM, runExceptT, unless, zipWithM,
zipWithM_)
MonadTrans (lift), forM, runExceptT,
unless, zipWithM, zipWithM_)
import Control.Monad.Extra (fromMaybeM, ifM)
import Control.Monad.State (MonadState, State, evalState, gets,
modify)
import Control.Monad.State (MonadState (get), State, StateT,
evalState, evalStateT, gets, modify)
import Data.Coerce (coerce)
import Data.Foldable (foldlM)
import Data.Function (on)
@ -215,10 +216,7 @@ check (ECase scrut branches) c = do
subtype a c
apply (T.ECase (scrut', a) [], a)
_ -> do
branches' <- forM branches $ \(Branch p e) -> do
p' <- checkPattern p =<< apply a
e' <- check e c
pure (T.Branch p' e')
branches' <- checkBranches branches a c
apply (T.ECase (scrut', a) branches', c)
@ -232,56 +230,84 @@ check e b = do
apply (e', b)
checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type)
-- Γ ⊢ p₁⇒e₁ ‥ pₙ⇒eₙ :: A ↑ C ⊣ Δ
checkBranches :: [Branch] -> Type -> Type -> Tc [T.Branch' Type]
checkBranches branches a c = evalStateT go mempty
where
go = forM branches $ \(Branch p e) -> do
p' <- patternMatch p =<< lift (apply a)
lift $ do
e' <- check e c
apply (T.Branch p' e')
substituteAll :: Type -> StateT (Map TVar TEVar) Tc Type
substituteAll t = get >>= \subs -> case t of
TAll tvar t
| Just tevar <- Map.lookup tvar subs ->
lift . pure $ substitute tvar tevar t
| otherwise -> do
tevar <- lift fresh
modify (Map.insert tvar tevar)
substituteAll (substitute tvar tevar t)
TFun t1 t2 -> onM TFun substituteAll t1 t2
t -> pure t
patternMatch :: Pattern -> Type -> StateT (Map TVar TEVar) Tc (T' T.Pattern' Type)
-- ------------------- PVar
-- Γ ⊢ x ↑ A ⊣ Γ,(x:A)
checkPattern (PVar x) a = do
-- Γ ⊢ x :: A ⊣ Γ,(x:A)
patternMatch (PVar x) a = lift $ do
insertEnv $ EnvVar x a
apply (T.PVar (coerce x), a)
-- ------------- PCatch
-- Γ ⊢ _ ↑ A ⊣ Γ
checkPattern PCatch a = apply (T.PCatch, a)
-- Γ ⊢ _ :: A ⊣ Γ
patternMatch PCatch a = lift $ apply (T.PCatch, a)
-- A = typeof(lit)
-- Γ ⊢ typeof(lit) <: A ⊣ Δ
-- ------------------------- PLit
-- Γ ⊢ lit ↑ A ⊣ Γ
checkPattern (PLit lit) a | a == typeof lit = apply (T.PLit lit, a)
checkPattern (PLit lit) a = error $ "\n -- MARTIN HJÄLP!! --\nUnimplemented match for: '" ++ printTree a ++ "' == '" ++ printTree (typeof lit) ++ "'"
-- Γ ⊢ lit :: A ⊣ Δ
patternMatch (PLit lit) a = lift $ do
subtype (typeof lit) a
apply (T.PLit lit, typeof lit)
-- Γ ∋ (K : T) Γ ⊢ A <: B ⊣ Δ
-- Γ ∋ (K : A) Γ ⊢ A <: C ⊣ Δ
-- ---------------------------
-- Γ ⊢ K ↑ T ⊣ Δ
checkPattern (PEnum k) b = do
a <- maybeToRightM ("Unknown constructor " ++ show k) =<< lookupInj k
-- Γ ⊢ K :: C ⊣ Δ
patternMatch (PEnum k) b = do
a <- lift (maybeToRightM ("Unknown constructor " ++ show k) =<< lookupInj k)
a <- substituteAll a
lift $ do
subtype a b
apply (T.PEnum (coerce k), a)
-- β α Γ Γ Δ
--
-- Γ ∋ (K : A) Θ₂ ⊢ p₁ ↑ [Θ₁]A₁ ⊣ Θ₂
-- Γ ∋ (K : A) Θ₂ ⊢ p₁ :: [Θ₁]A₁ ⊣ Θ₂
-- Γ ⊢ ∀ά₁‥άₘ A₁ → ‥ → Aₙ₊₁ = substituteAll(A) ⊣ Θ₁ ...
-- Θ₁ ⊢ Aₙ₊₁ <: B ⊣ Θ₂ Θₙ₊₁ ⊢ pₙ ↑ [Θₙ₊₁]Aₙ ⊣ Δ
-- -----------------------------------------------------------------------
-- Θ₁ ⊢ Aₙ₊₁ <: B ⊣ Θ₂ Β Θₙ₊₁ ⊢ pₙ :: [Θₙ₊₁]Aₙ ⊣ Δ
-- ----------------------------------------------------------------------------- PInj
-- Γ ⊢ K p₁‥pₙ ↑ B ⊣ Δ
{- checkPattern (PInj k ps) b = do
a <- maybeToRightM ("Unknown constructor " ++ show k) =<< lookupInj k
patternMatch (PInj k ps) b = do
a <- maybeToRightM ("Unknown constructor " ++ show k) =<< lift (lookupInj k)
a <- substituteAll a
lift $ subtype (getDataId a) b
let as = getArgs a
unless (length as == length ps) $ throwError "Wrong number of arguments!"
ps' <- zipWithM (\p a -> checkPattern p =<< apply a) ps as
apply (T.PInj (coerce k) ps', a)
ps' <- zipWithM (\p a -> patternMatch p =<< lift (apply a)) ps as
lift $ apply (T.PInj (coerce k) ps', a)
where
substituteAll t = case t of
TAll tvar t -> do
tevar <- fresh
substituteAll (substitute tvar tevar t)
TFun t1 t2 -> onM TFun substituteAll t1 t2
t -> pure t
getArgs = \case
TAll _ t -> getArgs t
@ -289,12 +315,7 @@ checkPattern (PEnum k) b = do
where
go acc = \case
TFun t1 t2 -> go (snoc t1 acc) t2
_ -> acc -}
_ -> acc
-- Example
@ -304,32 +325,37 @@ checkPattern (PEnum k) b = do
-- Θ₂ ⊢ p₂ ↑ [Θ₂][ά/α]A₂ ⊣ Δ
-- ---------------------------
-- Γ ⊢ K p₁ p₂ ↑ B ⊣ Δ
checkPattern (PInj name ps) t_patt = do
t_inj <- maybeToRightM "unknown constructor" =<< lookupInj name
let ts = getArgs t_inj
unless (length ts == length ps)
$ throwError "Wrong number of arguments!"
-- patternMatch (PInj name ps) t_patt = do
-- t_inj <- maybeToRightM "unknown constructor" =<< lookupInj name
-- let ts = getArgs t_inj
-- unless (length ts == length ps)
-- $ throwError "Wrong number of arguments!"
--
-- -- [ά/α]
-- sub <- substituteTVarsOf t_inj
-- subtype (sub $ getDataId t_inj) t_patt
-- let check p t = checkPattern p =<< apply (sub t)
-- ps' <- zipWithM check ps ts
-- apply (T.PInj (coerce name) ps', t_patt)
-- where
-- substituteTVarsOf = \case
-- TAll tvar t -> do
-- tevar <- fresh
-- (substitute tvar tevar .) <$> substituteTVarsOf t
-- _ -> pure id
--
-- getArgs = \case
-- TAll _ t -> getArgs t
-- t -> go [] t
-- where
-- go acc = \case
-- TFun t1 t2 -> go (snoc t1 acc) t2
-- _ -> acc
-- [ά/α]
sub <- substituteTVarsOf t_inj
subtype (sub $ getDataId t_inj) t_patt
let check p t = checkPattern p =<< apply (sub t)
ps' <- zipWithM check ps ts
apply (T.PInj (coerce name) ps', t_patt)
where
substituteTVarsOf = \case
TAll tvar t -> do
tevar <- fresh
(substitute tvar tevar .) <$> substituteTVarsOf t
_ -> pure id
getArgs = \case
TAll _ t -> getArgs t
t -> go [] t
where
go acc = \case
TFun t1 t2 -> go (snoc t1 acc) t2
_ -> acc
-- | Γ ⊢ e ↓ A ⊣ Δ
-- Under input context Γ, e infers output type A, with output context ∆
@ -370,9 +396,9 @@ infer (EAnn e a) = do
-- ----------------------------------- →E
-- Γ ⊢ e₁ e₂ ↓ C ⊣ Δ
infer (EApp e1 e2) = do
e1'@(_, a) <- infer e1
(e2', c) <- applyInfer a e2
apply (T.EApp e1' e2', c)
(e1', a) <- infer e1
(e2', a', c) <- applyInfer a e2
apply (T.EApp (e1', a') e2', c)
-- Γ,ά,έ,(x:ά) ⊢ e ↑ έ ⊣ Δ,(x:ά),Θ
-- ------------------------------- →I
@ -418,25 +444,25 @@ infer (EAdd e1 e2) = do
-- Γ ⊢ e ↑ A ⊣ Θ₁ Θₙ ⊢ pₙ⇒eₙ ↓ [Θₙ]C ⊣ Δ
-- --------------------------------------- Case↓
-- Γ ⊢ case e of {p₁⇒e₁ ‥ pₙ⇒eₙ} ↓ C ⊣ Δ
infer (ECase scrut branches) = do
(scrut', a) <- infer scrut
case branches of
[] -> apply (T.ECase (scrut', a) [], a)
(Branch _ e):_ -> do
(_, b)<- infer e
(branches', b') <- foldlM go ([], b) branches
apply (T.ECase (scrut', a) branches', b')
where
go (pi, b) (Branch p e) = do
p' <- checkPattern p =<< apply a
e'@(_, b') <- infer e
subtype b' b
apply (T.Branch p' e' : pi, b')
-- infer (ECase scrut branches) = do
-- (scrut', a) <- infer scrut
-- case branches of
-- [] -> apply (T.ECase (scrut', a) [], a)
-- (Branch _ e):_ -> do
-- (_, b)<- infer e
-- (branches', b') <- foldlM go ([], b) branches
-- apply (T.ECase (scrut', a) branches', b')
-- where
-- go (pi, b) (Branch p e) = do
-- p' <- checkPattern p =<< apply a
-- e'@(_, b') <- infer e
-- subtype b' b
-- apply (T.Branch p' e' : pi, b')
-- | Γ ⊢ A • e ⇓ C ⊣ Δ
-- Under input context Γ , applying a function of type A to e infers type C, with output context ∆
-- Instantiate existential type variables until there is an arrow type.
applyInfer :: Type -> Exp -> Tc (T' T.Exp' Type, Type)
applyInfer :: Type -> Exp -> Tc (T' T.Exp' Type, Type, Type)
-- Γ,ά ⊢ [ά/α]A • e ⇓ C ⊣ Δ
-- ------------------------ ∀App
@ -453,20 +479,21 @@ applyInfer (TEVar alpha) e = do
alpha1 <- fresh
alpha2 <- fresh
(env_l, env_r) <- gets (splitOn (EnvTEVar alpha) . env)
let alpha_solution = on TFun TEVar alpha1 alpha2
putEnv $ (env_l
:|> EnvTEVar alpha2
:|> EnvTEVar alpha1
:|> EnvTEVarSolved alpha (on TFun TEVar alpha1 alpha2)
:|> EnvTEVarSolved alpha alpha_solution
) <> env_r
e' <- check e $ TEVar alpha1
apply (e', TEVar alpha2)
apply (e', alpha_solution, TEVar alpha2)
-- Γ ⊢ e ↑ A ⊣ Δ
-- --------------------- →App
-- Γ ⊢ A → C • e ⇓ C ⊣ Δ
applyInfer (TFun a c) e = do
exp' <- check e a
apply (exp', c)
e'@(_, a') <- check e a
apply (e',TFun a' c, c)
applyInfer a e = throwError ("Cannot apply type " ++ show a ++ " with expression " ++ show e)
@ -806,13 +833,14 @@ pattern DSig' name typ = DSig (Sig name typ)
class Apply a where
apply :: a -> Tc a
instance Apply Type where apply = applyType
instance Apply (T.Exp' Type) where apply = applyExp
instance Apply (T.Branch' Type) where apply = applyBranch
instance Apply (T.Pattern' Type) where apply = applyPattern
instance Apply a => Apply [a] where apply = mapM apply
instance Apply Type where apply = applyType
instance Apply (T.Exp' Type) where apply = applyExp
instance Apply (T.Branch' Type) where apply = applyBranch
instance Apply (T.Pattern' Type) where apply = applyPattern
instance Apply a => Apply [a] where apply = mapM apply
instance (Apply a, Apply b) => Apply (a, b) where apply = applyPair
instance Apply T.Ident where apply = pure
instance (Apply a, Apply b, Apply c) => Apply (a, b, c) where apply = applyTuple
instance Apply T.Ident where apply = pure
applyType :: Type -> Tc Type
applyType t = gets $ (`applyType'` t) . env
@ -834,7 +862,7 @@ applyType' cxt typ | typ == typ' = typ'
TFun t1 t2 -> on TFun (applyType' cxt) t1 t2
-- [Γ](∀α. A) = (∀α. [Γ]A)
TAll tvar t -> TAll tvar $ applyType' cxt t
TIdent t -> typ
TIdent _ -> typ
applyExp :: T.Exp' Type -> Tc (T.Exp' Type)
applyExp exp = case exp of
@ -866,6 +894,9 @@ applyPattern = \case
applyPair :: (Apply a, Apply b) => (a, b) -> Tc (a, b)
applyPair (x, y) = liftA2 (,) (apply x) (apply y)
applyTuple :: (Apply a, Apply b, Apply c) => (a, b, c) -> Tc (a, b, c)
applyTuple (x, y, z) = liftA3 (,,) (apply x) (apply y) (apply z)
---------------------------------------------------------------------------
-- * Debug
---------------------------------------------------------------------------

View file

@ -90,13 +90,15 @@ prtSig (x, t) =
]
instance (Print a, Print t) => Print (T a t) where
prt i (x, t) =
concatD
[ -- doc $ showString "("
{- , -} prt i x
-- , doc $ showString ":"
-- , prt 0 t
-- , doc $ showString ")"
prt i (x, t) = withT
where
noT = prt i x
withT = concatD
[ doc $ showString "("
, prt i x
, doc $ showString ":"
, prt 0 t
, doc $ showString ")"
]
instance Print t => Print [Bind' t] where

View file

@ -1,28 +1,28 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PatternSynonyms #-}
{-# HLINT ignore "Use camelCase" #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
module TestTypeCheckerBidir (test, testTypeCheckerBidir) where
import Test.Hspec
import Test.Hspec
import AnnForall (annotateForall)
import Control.Monad ((<=<))
import Desugar.Desugar (desugar)
import Grammar.Abs (Program)
import Grammar.ErrM (Err, pattern Bad, pattern Ok)
import Grammar.Layout (resolveLayout)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
import Renamer.Renamer (rename)
import ReportForall (reportForall)
import TypeChecker.RemoveForall (removeForall)
import TypeChecker.ReportTEVar (reportTEVar)
import TypeChecker.TypeChecker (TypeChecker (Bi))
import TypeChecker.TypeCheckerBidir (typecheck)
import TypeChecker.TypeCheckerIr qualified as T
import AnnForall (annotateForall)
import Control.Monad ((<=<))
import Desugar.Desugar (desugar)
import Grammar.Abs (Program)
import Grammar.ErrM (Err, pattern Bad, pattern Ok)
import Grammar.Layout (resolveLayout)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
import Renamer.Renamer (rename)
import ReportForall (reportForall)
import TypeChecker.RemoveForall (removeForall)
import TypeChecker.ReportTEVar (reportTEVar)
import TypeChecker.TypeChecker (TypeChecker (Bi))
import TypeChecker.TypeCheckerBidir (typecheck)
import qualified TypeChecker.TypeCheckerIr as T
test = hspec testTypeCheckerBidir
@ -189,16 +189,16 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do
run (fs ++ wrong2) `shouldNotSatisfy` ok
specify "Third wrong case expression rejected" $
run (fs ++ wrong3) `shouldNotSatisfy` ok
specify "Forth wrong case expression rejected" $
run (fs ++ wrong4) `shouldNotSatisfy` ok
specify "First correct case expression accepted" $
run (fs ++ correct1) `shouldSatisfy` ok
-- specify "Forth wrong case expression rejected" $
-- run (fs ++ wrong4) `shouldNotSatisfy` ok
-- specify "First correct case expression accepted" $
-- run (fs ++ correct1) `shouldSatisfy` ok
specify "Second correct case expression accepted" $
run (fs ++ correct2) `shouldSatisfy` ok
specify "Third correct case expression accepted" $
run (fs ++ correct3) `shouldSatisfy` ok
specify "Forth correct case expression accepted" $
run (fs ++ correct4) `shouldSatisfy` ok
-- specify "Third correct case expression accepted" $
-- run (fs ++ correct3) `shouldSatisfy` ok
-- specify "Forth correct case expression accepted" $
-- run (fs ++ correct4) `shouldSatisfy` ok
where
fs =
[ "data List a where"
@ -254,9 +254,9 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do
correct4 =
[ "elems : List (List c) -> Int"
, "elems = \\list. case list of"
, " Nil => 0"
, " Cons Nil Nil => 0"
, " Cons Nil xs => elems xs"
--, " Nil => 0"
--, " Cons Nil Nil => 0"
--, " Cons Nil xs => elems xs"
, " Cons (Cons _ ys) xs => 1 + elems (Cons ys xs)"
]
@ -329,5 +329,5 @@ runPrint =
["double x = x + x"]
ok = \case
Ok _ -> True
Ok _ -> True
Bad _ -> False