continued work pattern matching

This commit is contained in:
sebastianselander 2023-03-02 16:05:43 +01:00
parent 05313652f9
commit 2401b6437b
6 changed files with 79 additions and 58 deletions

View file

@ -8,18 +8,19 @@ separator Def ";" ;
Bind. Bind ::= Ident ":" Type ";" Bind. Bind ::= Ident ":" Type ";"
Ident [Ident] "=" Exp ; Ident [Ident] "=" Exp ;
Data. Data ::= "data" Type "where" "{" Data. Data ::= "data" Constr "where" "{" [Constructor] "}" ;
[Constructor] "}" ;
separator nonempty Constructor "" ;
Constructor. Constructor ::= Ident ":" Type ; Constructor. Constructor ::= Ident ":" Type ;
separator nonempty Constructor "" ;
TMono. Type1 ::= "_" Ident ; TMono. Type1 ::= "_" Ident ;
TPol. Type1 ::= "'" Ident ; TPol. Type1 ::= "'" Ident ;
TConstr. Type1 ::= Ident "(" [Type] ")" ; TConstr. Type1 ::= Constr ;
TArr. Type ::= Type1 "->" Type ; TArr. Type ::= Type1 "->" Type ;
Constr. Constr ::= Ident "(" [Type] ")" ;
-- TODO: Move literal to its own thing since it's reused in Init as well.
EAnn. Exp5 ::= "(" Exp ":" Type ")" ; EAnn. Exp5 ::= "(" Exp ":" Type ")" ;
EId. Exp4 ::= Ident ; EId. Exp4 ::= Ident ;
ELit. Exp4 ::= Literal ; ELit. Exp4 ::= Literal ;
@ -41,14 +42,9 @@ InitCatch. Init ::= "_" ;
separator Type " " ; separator Type " " ;
coercions Type 2 ; coercions Type 2 ;
-- This doesn't seem to work so we'll have to live with ugly keywords for now
-- token Poly upper (letter | digit | '_')* ;
-- token Mono lower (letter | digit | '_')* ;
separator Ident " "; separator Ident " ";
coercions Exp 5 ; coercions Exp 5 ;
comment "--" ; comment "--" ;
comment "{-" "-}" ; comment "{-" "-}" ;

17
Justfile Normal file
View file

@ -0,0 +1,17 @@
alias b := build
build:
bnfc -o src -d Grammar.cf
# clean the generated directories
clean:
rm -r src/Grammar
rm language
# run all tests
test:
cabal test
# compile a specific file
run FILE:
cabal run language {{FILE}}

View file

@ -1,27 +0,0 @@
{-# OPTIONS_GHC -Wno-unused-imports #-}
module TypeChecker.CheckInj where
import TypeChecker.TypeChecker
import qualified TypeChecker.TypeCheckerIr as T
import TypeChecker.TypeCheckerIr (Infer)
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Functor.Identity (Identity, runIdentity)
import Data.Map (Map)
import qualified Data.Map as M
import Grammar.Abs
import Grammar.Print (printTree)
checkInj :: Inj -> Infer T.Inj
checkInj (Inj it expr) = do
(_, e') <- inferExp expr
t' <- initType it
return $ T.Inj (it, t') e'
initType :: Init -> Infer Type
initType = undefined

View file

@ -7,7 +7,7 @@ module TypeChecker.TypeChecker where
import Control.Monad.Except import Control.Monad.Except
import Control.Monad.Reader import Control.Monad.Reader
import Control.Monad.State import Control.Monad.State
import Data.Functor.Identity (Identity, runIdentity) import Data.Functor.Identity (runIdentity)
import Data.List (foldl') import Data.List (foldl')
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as M import qualified Data.Map as M
@ -21,7 +21,6 @@ import qualified TypeChecker.TypeCheckerIr as T
import TypeChecker.TypeCheckerIr (Ctx (..), Env (..), Error, Infer, import TypeChecker.TypeCheckerIr (Ctx (..), Env (..), Error, Infer,
Poly (..), Subst) Poly (..), Subst)
initCtx = Ctx mempty initCtx = Ctx mempty
initEnv = Env 0 mempty mempty initEnv = Env 0 mempty mempty
@ -39,9 +38,10 @@ typecheck = run . checkPrg
checkData :: Data -> Infer () checkData :: Data -> Infer ()
checkData d = case d of checkData d = case d of
(Data typ@(TConstr name _) constrs) -> do (Data typ@(Constr name ts) constrs) -> do
unless (all isPoly ts) (throwError $ unwords ["Data type incorrectly declared"])
traverse_ (\(Constructor name' t') traverse_ (\(Constructor name' t')
-> if typ == retType t' -> if TConstr typ == retType t'
then insertConstr name' t' else then insertConstr name' t' else
throwError $ throwError $
unwords unwords
@ -51,11 +51,9 @@ checkData d = case d of
, printTree (retType t') , printTree (retType t')
, "does not match data: " , "does not match data: "
, printTree typ]) constrs , printTree typ]) constrs
_ -> throwError "Data type incorrectly declared" retType :: Type -> Type
where retType (TArr _ t2) = retType t2
retType :: Type -> Type retType a = a
retType (TArr _ t2) = retType t2
retType a = a
checkPrg :: Program -> Infer T.Program checkPrg :: Program -> Infer T.Program
checkPrg (Program bs) = do checkPrg (Program bs) = do
@ -86,7 +84,7 @@ checkBind (Bind n t _ args e) = do
, "does not match body with inferred type:" , "does not match body with inferred type:"
, printTree t'' , printTree t''
]) ])
return $ T.Bind (n, t) [] e' return $ T.Bind (n, t) e'
where where
makeLambda :: Exp -> [Ident] -> Exp makeLambda :: Exp -> [Ident] -> Exp
makeLambda = foldl (flip EAbs) makeLambda = foldl (flip EAbs)
@ -96,7 +94,7 @@ checkBind (Bind n t _ args e) = do
typeEq :: Type -> Type -> Bool typeEq :: Type -> Type -> Bool
typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r' typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r'
typeEq (TMono a) (TMono b) = a == b typeEq (TMono a) (TMono b) = a == b
typeEq (TConstr name a) (TConstr name' b) = length a == length b typeEq (TConstr (Constr name a)) (TConstr (Constr name' b)) = length a == length b
&& name == name' && name == name'
&& and (zipWith typeEq a b) && and (zipWith typeEq a b)
typeEq (TPol _) (TPol _) = True typeEq (TPol _) (TPol _) = True
@ -107,6 +105,10 @@ isMoreGeneral _ (TPol _) = True
isMoreGeneral (TArr a b) (TArr c d) = isMoreGeneral a c && isMoreGeneral b d isMoreGeneral (TArr a b) (TArr c d) = isMoreGeneral a c && isMoreGeneral b d
isMoreGeneral a b = a == b isMoreGeneral a b = a == b
isPoly :: Type -> Bool
isPoly (TPol _) = True
isPoly _ = False
inferExp :: Exp -> Infer (Type, T.Exp) inferExp :: Exp -> Infer (Type, T.Exp)
inferExp e = do inferExp e = do
(s, t, e') <- algoW e (s, t, e') <- algoW e
@ -120,7 +122,7 @@ replace t = \case
T.EAbs _ name e -> T.EAbs t name e T.EAbs _ name e -> T.EAbs t name e
T.EApp _ e1 e2 -> T.EApp t e1 e2 T.EApp _ e1 e2 -> T.EApp t e1 e2
T.EAdd _ e1 e2 -> T.EAdd t e1 e2 T.EAdd _ e1 e2 -> T.EAdd t e1 e2
T.ELet (T.Bind (n, _) args e1) e2 -> T.ELet (T.Bind (n, t) args e1) e2 T.ELet (T.Bind (n, _) e1) e2 -> T.ELet (T.Bind (n, t) e1) e2
algoW :: Exp -> Infer (Subst, Type, T.Exp) algoW :: Exp -> Infer (Subst, Type, T.Exp)
algoW = \case algoW = \case
@ -216,9 +218,9 @@ algoW = \case
let t' = generalize (apply s1 env) t1 let t' = generalize (apply s1 env) t1
withBinding name t' $ do withBinding name t' $ do
(s2, t2, e1') <- algoW e1 (s2, t2, e1') <- algoW e1
return (s2 `compose` s1, t2, T.ELet (T.Bind (name,t2) [] e0') e1' ) return (s2 `compose` s1, t2, T.ELet (T.Bind (name,t2) e0') e1' )
ECase e0 injs -> undefined ECase _ _ -> undefined
-- | Unify two types producing a new substitution -- | Unify two types producing a new substitution
unify :: Type -> Type -> Infer Subst unify :: Type -> Type -> Infer Subst
@ -231,7 +233,7 @@ unify t0 t1 = case (t0, t1) of
(a, TPol b) -> occurs b a (a, TPol b) -> occurs b a
(TMono a, TMono b) -> if a == b then return M.empty else throwError "Types do not unify" (TMono a, TMono b) -> if a == b then return M.empty else throwError "Types do not unify"
-- | TODO: Figure out a cleaner way to express the same thing -- | TODO: Figure out a cleaner way to express the same thing
(TConstr name t, TConstr name' t') -> if name == name' && length t == length t' (TConstr (Constr name t), TConstr (Constr name' t')) -> if name == name' && length t == length t'
then do then do
xs <- zipWithM unify t t' xs <- zipWithM unify t t'
return $ foldr compose nullSubst xs return $ foldr compose nullSubst xs
@ -280,7 +282,7 @@ instance FreeVars Type where
free (TMono _) = mempty free (TMono _) = mempty
free (TArr a b) = free a `S.union` free b free (TArr a b) = free a `S.union` free b
-- | Not guaranteed to be correct -- | Not guaranteed to be correct
free (TConstr _ a) = foldl' (\acc x -> free x `S.union` acc) S.empty a free (TConstr (Constr _ a)) = foldl' (\acc x -> free x `S.union` acc) S.empty a
apply :: Subst -> Type -> Type apply :: Subst -> Type -> Type
apply sub t = do apply sub t = do
case t of case t of
@ -289,7 +291,7 @@ instance FreeVars Type where
Nothing -> TPol a Nothing -> TPol a
Just t -> t Just t -> t
TArr a b -> TArr (apply sub a) (apply sub b) TArr a b -> TArr (apply sub a) (apply sub b)
TConstr name a -> TConstr name (map (apply sub) a) TConstr (Constr name a) -> TConstr (Constr name (map (apply sub) a))
instance FreeVars Poly where instance FreeVars Poly where
free :: Poly -> Set Ident free :: Poly -> Set Ident
@ -329,3 +331,37 @@ insertSig i t = modify (\st -> st { sigs = M.insert i t (sigs st) })
-- | Insert a constructor with its data type -- | Insert a constructor with its data type
insertConstr :: Ident -> Type -> Infer () insertConstr :: Ident -> Type -> Infer ()
insertConstr i t = modify (\st -> st { constructors = M.insert i t (constructors st) }) insertConstr i t = modify (\st -> st { constructors = M.insert i t (constructors st) })
-------- PATTERN MATCHING ---------
checkInj :: Inj -> Infer T.Inj
checkInj (Inj it expr) = do
(_, e') <- inferExp expr
t' <- initType it
return $ T.Inj (it, t') e'
initType :: Init -> Infer Type
initType = \case
InitLit lit -> return $ litType lit
InitConstr c args -> do
st <- gets constructors
case M.lookup c st of
Nothing -> throwError $ unwords ["Constructor:", printTree c, "does not exist"]
Just t -> do
let flat = flattenType t
let returnType = last flat
if length (init flat) == length args
then return returnType
else throwError $ "Can't partially match on the constructor: " ++ printTree c
-- Ignoring the variables for now, they can not be used in the expression to the
-- right of '=>'
InitCatch -> return $ TPol "catch"
flattenType :: Type -> [Type]
flattenType (TArr a b) = flattenType a ++ flattenType b
flattenType a = [a]
litType :: Literal -> Type
litType (LInt i) = TMono "Int"

View file

@ -49,7 +49,7 @@ data Def = DBind Bind | DData Data
type Id = (Ident, Type) type Id = (Ident, Type)
data Bind = Bind Id [Id] Exp data Bind = Bind Id Exp
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
instance Print [Def] where instance Print [Def] where
@ -64,11 +64,10 @@ instance Print Program where
prt i (Program sc) = prPrec i 0 $ prt 0 sc prt i (Program sc) = prPrec i 0 $ prt 0 sc
instance Print Bind where instance Print Bind where
prt i (Bind (t, name) parms rhs) = prPrec i 0 $ concatD prt i (Bind (t, name) rhs) = prPrec i 0 $ concatD
[ prt 0 name [ prt 0 name
, doc $ showString ":" , doc $ showString ":"
, prt 1 t , prt 1 t
, prtIdPs 0 parms
, doc $ showString "=" , doc $ showString "="
, prt 2 rhs , prt 2 rhs
] ]

View file

@ -8,5 +8,5 @@ data Bool () where {
False : Bool () False : Bool ()
}; };
main : List ('a) ; main : List (_Int) ;
main = Cons 1 (Cons 0 Nil) ; main = Cons 1 (Cons 0 Nil) ;