unit tests, started on pattern matching

This commit is contained in:
sebastianselander 2023-02-28 17:15:48 +01:00
parent d23d417ff3
commit 05313652f9
9 changed files with 212 additions and 133 deletions

View file

@ -3,7 +3,6 @@
module Main where
-- import Codegen.Codegen (compile)
import GHC.IO.Handle.Text (hPutStrLn)
import Grammar.ErrM (Err)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
@ -12,7 +11,6 @@ import Grammar.Print (printTree)
import Renamer.Renamer (rename)
import System.Environment (getArgs)
import System.Exit (exitFailure, exitSuccess)
import System.IO (stderr)
import TypeChecker.TypeChecker (typecheck)
main :: IO ()
@ -25,32 +23,28 @@ main' :: String -> IO ()
main' s = do
file <- readFile s
printToErr "-- Parse Tree -- "
putStrLn "-- Parse Tree -- "
parsed <- fromSyntaxErr . pProgram $ myLexer file
printToErr $ printTree parsed
putStrLn $ printTree parsed
printToErr "\n-- Renamer --"
putStrLn "\n-- Renamer --"
let renamed = rename parsed
printToErr $ printTree renamed
putStrLn $ printTree renamed
printToErr "\n-- TypeChecker --"
putStrLn "\n-- TypeChecker --"
typechecked <- fromTypeCheckerErr $ typecheck renamed
printToErr $ printTree typechecked
putStrLn $ printTree typechecked
-- printToErr "\n-- Lambda Lifter --"
-- putStrLn "\n-- Lambda Lifter --"
-- let lifted = lambdaLift typechecked
-- printToErr $ printTree lifted
-- putStrLn $ printTree lifted
-- printToErr "\n -- Printing compiler output to stdout --"
-- putStrLn "\n -- Printing compiler output to stdout --"
-- compiled <- fromCompilerErr $ compile lifted
-- putStrLn compiled
-- writeFile "llvm.ll" compiled
exitSuccess
printToErr :: String -> IO ()
printToErr = hPutStrLn stderr
fromCompilerErr :: Err a -> IO a
fromCompilerErr =
either

View file

@ -28,7 +28,6 @@ rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs)
pure . DBind $ Bind name t name parms' rhs'
renameSc _ def = pure def
--
-- | Rename monad. State holds the number of renamed names.
newtype Rn a = Rn { runRn :: State Int a }

View file

@ -0,0 +1,27 @@
{-# OPTIONS_GHC -Wno-unused-imports #-}
module TypeChecker.CheckInj where
import TypeChecker.TypeChecker
import qualified TypeChecker.TypeCheckerIr as T
import TypeChecker.TypeCheckerIr (Infer)
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Functor.Identity (Identity, runIdentity)
import Data.Map (Map)
import qualified Data.Map as M
import Grammar.Abs
import Grammar.Print (printTree)
checkInj :: Inj -> Infer T.Inj
checkInj (Inj it expr) = do
(_, e') <- inferExp expr
t' <- initType it
return $ T.Inj (it, t') e'
initType :: Init -> Infer Type
initType = undefined

View file

@ -1,10 +1,7 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use traverse_" #-}
{-# OPTIONS_GHC -Wno-overlapping-patterns #-}
{-# HLINT ignore "Use zipWithM" #-}
-- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeChecker where
import Control.Monad.Except
@ -21,23 +18,9 @@ import Data.Foldable (traverse_)
import Grammar.Abs
import Grammar.Print (printTree)
import qualified TypeChecker.TypeCheckerIr as T
import TypeChecker.TypeCheckerIr (Ctx (..), Env (..), Error, Infer,
Poly (..), Subst)
-- | A data type representing type variables
data Poly = Forall [Ident] Type
deriving Show
newtype Ctx = Ctx { vars :: Map Ident Poly
}
data Env = Env { count :: Int
, sigs :: Map Ident Type
, dtypes :: Map Ident Type
}
type Error = String
type Subst = Map Ident Type
type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))
initCtx = Ctx mempty
initEnv = Env 0 mempty mempty
@ -98,7 +81,11 @@ checkBind (Bind n t _ args e) = do
(t', e') <- inferExp $ makeLambda e (reverse args)
s <- unify t t'
let t'' = apply s t
unless (t `typeEq` t'') (throwError $ unwords ["Top level signature", printTree t, "does not match body with type:", printTree t''])
unless (t `typeEq` t'') (throwError $ unwords ["Top level signature"
, printTree t
, "does not match body with inferred type:"
, printTree t''
])
return $ T.Bind (n, t) [] e'
where
makeLambda :: Exp -> [Ident] -> Exp
@ -109,12 +96,17 @@ checkBind (Bind n t _ args e) = do
typeEq :: Type -> Type -> Bool
typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r'
typeEq (TMono a) (TMono b) = a == b
typeEq (TConstr name a) (TConstr name' b) = if length a == length b
then name == name' && and (zipWith typeEq a b)
else False
typeEq (TConstr name a) (TConstr name' b) = length a == length b
&& name == name'
&& and (zipWith typeEq a b)
typeEq (TPol _) (TPol _) = True
typeEq _ _ = False
isMoreGeneral :: Type -> Type -> Bool
isMoreGeneral _ (TPol _) = True
isMoreGeneral (TArr a b) (TArr c d) = isMoreGeneral a c && isMoreGeneral b d
isMoreGeneral a b = a == b
inferExp :: Exp -> Infer (Type, T.Exp)
inferExp e = do
(s, t, e') <- algoW e
@ -123,24 +115,30 @@ inferExp e = do
replace :: Type -> T.Exp -> T.Exp
replace t = \case
T.ELit _ e -> T.ELit t e
T.EId (n, _) -> T.EId (n, t)
T.EAbs _ name e -> T.EAbs t name e
T.EApp _ e1 e2 -> T.EApp t e1 e2
T.EAdd _ e1 e2 -> T.EAdd t e1 e2
T.ELit _ e -> T.ELit t e
T.EId (n, _) -> T.EId (n, t)
T.EAbs _ name e -> T.EAbs t name e
T.EApp _ e1 e2 -> T.EApp t e1 e2
T.EAdd _ e1 e2 -> T.EAdd t e1 e2
T.ELet (T.Bind (n, _) args e1) e2 -> T.ELet (T.Bind (n, t) args e1) e2
algoW :: Exp -> Infer (Subst, Type, T.Exp)
algoW = \case
-- | TODO: Reason more about this one. Could be wrong
EAnn e t -> do
(s1, t', e') <- algoW e
unless (t `isMoreGeneral` t') (throwError $ unwords
["Annotated type:"
, printTree t
, "does not match inferred type:"
, printTree t' ])
applySt s1 $ do
s2 <- unify (apply s1 t) t'
s2 <- unify t t'
return (s2 `compose` s1, t, e')
-- | ------------------
-- | Γ ⊢ e₀ : Int, ∅
-- | Γ ⊢ i : Int, ∅
ELit (LInt n) -> return (nullSubst, TMono "Int", T.ELit (TMono "Int") (LInt n))
@ -159,7 +157,7 @@ algoW = \case
case M.lookup i sig of
Just t -> return (nullSubst, t, T.EId (i, t))
Nothing -> do
constr <- gets dtypes
constr <- gets constructors
case M.lookup i constr of
Just t -> return (nullSubst, t, T.EId (i, t))
Nothing -> throwError $ "Unbound variable: " ++ show i
@ -220,9 +218,9 @@ algoW = \case
(s2, t2, e1') <- algoW e1
return (s2 `compose` s1, t2, T.ELet (T.Bind (name,t2) [] e0') e1' )
ECase a b -> error $ "NOT IMPLEMENTED YET: ECase" ++ show a ++ " " ++ show b
ECase e0 injs -> undefined
-- | Unify two types producing a new substitution (constraint)
-- | Unify two types producing a new substitution
unify :: Type -> Type -> Infer Subst
unify t0 t1 = case (t0, t1) of
(TArr a b, TArr c d) -> do
@ -235,9 +233,15 @@ unify t0 t1 = case (t0, t1) of
-- | TODO: Figure out a cleaner way to express the same thing
(TConstr name t, TConstr name' t') -> if name == name' && length t == length t'
then do
xs <- sequence $ zipWith unify t t'
xs <- zipWithM unify t t'
return $ foldr compose nullSubst xs
else throwError $ unwords ["Type constructor:", printTree name, "(" ++ printTree t ++ ")", "does not match with:", printTree name', "(" ++ printTree t' ++ ")"]
else throwError $ unwords
["Type constructor:"
, printTree name
, "(" ++ printTree t ++ ")"
, "does not match with:"
, printTree name'
, "(" ++ printTree t' ++ ")"]
(a, b) -> throwError . unwords $ ["Type:", printTree a, "can't be unified with:", printTree b]
-- | Check if a type is contained in another type.
@ -324,4 +328,4 @@ insertSig i t = modify (\st -> st { sigs = M.insert i t (sigs st) })
-- | Insert a constructor with its data type
insertConstr :: Ident -> Type -> Infer ()
insertConstr i t = modify (\st -> st { dtypes = M.insert i t (dtypes st) })
insertConstr i t = modify (\st -> st { constructors = M.insert i t (constructors st) })

View file

@ -1,14 +1,33 @@
{-# LANGUAGE LambdaCase #-}
module TypeChecker.TypeCheckerIr
( module Grammar.Abs
, module TypeChecker.TypeCheckerIr
) where
module TypeChecker.TypeCheckerIr where
import Grammar.Abs (Data (..), Ident (..), Literal (..), Type (..))
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Functor.Identity (Identity)
import Data.Map (Map)
import Grammar.Abs (Data (..), Ident (..), Init (..),
Literal (..), Type (..))
import Grammar.Print
import Prelude
import qualified Prelude as C (Eq, Ord, Read, Show)
import qualified Prelude as C (Eq, Ord, Read, Show)
-- | A data type representing type variables
data Poly = Forall [Ident] Type
deriving Show
newtype Ctx = Ctx { vars :: Map Ident Poly }
data Env = Env { count :: Int
, sigs :: Map Ident Type
, constructors :: Map Ident Type
}
type Error = String
type Subst = Map Ident Type
type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))
newtype Program = Program [Def]
deriving (C.Eq, C.Ord, C.Show, C.Read)
@ -22,6 +41,9 @@ data Exp
| EAbs Type Id Exp
deriving (C.Eq, C.Ord, C.Read, C.Show)
data Inj = Inj (Init, Type) Exp
deriving (C.Eq, C.Ord, C.Read, C.Show)
data Def = DBind Bind | DData Data
deriving (C.Eq, C.Ord, C.Read, C.Show)
@ -30,6 +52,10 @@ type Id = (Ident, Type)
data Bind = Bind Id [Id] Exp
deriving (C.Eq, C.Ord, C.Show, C.Read)
instance Print [Def] where
prt _ [] = concatD []
prt _ (x:xs) = concatD [prt 0 x, doc (showString "\n"), prt 0 xs]
instance Print Def where
prt i (DBind bind) = prt i bind
prt i (DData d) = prt i d
@ -41,16 +67,16 @@ instance Print Bind where
prt i (Bind (t, name) parms rhs) = prPrec i 0 $ concatD
[ prt 0 name
, doc $ showString ":"
, prt 0 t
, prt 1 t
, prtIdPs 0 parms
, doc $ showString "="
, prt 0 rhs
, prt 2 rhs
]
instance Print [Bind] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), doc (showString "\n"), prt 0 xs]
prtIdPs :: Int -> [Id] -> Doc
prtIdPs i = prPrec i 0 . concatD . map (prtIdP i)