Fixed errors in tc hm

This commit is contained in:
sebastianselander 2023-03-27 16:48:23 +02:00
parent 847ec37117
commit 6e54378327
4 changed files with 346 additions and 345 deletions

View file

@ -8,19 +8,18 @@ module Codegen.LlvmIr (
LLVMComp (..),
Visibility (..),
CallingConvention (..),
ToIr(..)
ToIr (..),
) where
import Data.List (intercalate)
import Grammar.Abs (Character)
import TypeChecker.TypeCheckerIr (Ident (..))
import Data.List (intercalate)
import TypeChecker.TypeCheckerIr (Ident (..))
data CallingConvention = TailCC | FastCC | CCC | ColdCC deriving Show
data CallingConvention = TailCC | FastCC | CCC | ColdCC deriving (Show)
instance ToIr CallingConvention where
toIr :: CallingConvention -> String
toIr TailCC = "tailcc"
toIr FastCC = "fastcc"
toIr CCC = "ccc"
toIr CCC = "ccc"
toIr ColdCC = "coldcc"
-- | A datatype which represents some basic LLVM types
@ -34,7 +33,7 @@ data LLVMType
| Function LLVMType [LLVMType]
| Array Integer LLVMType
| CustomType Ident
deriving Show
deriving (Show)
class ToIr a where
toIr :: a -> String
@ -63,12 +62,12 @@ data LLVMComp
| LLSge
| LLSlt
| LLSle
deriving Show
deriving (Show)
instance ToIr LLVMComp where
toIr :: LLVMComp -> String
toIr = \case
LLEq -> "eq"
LLNe -> "ne"
LLEq -> "eq"
LLNe -> "ne"
LLUgt -> "ugt"
LLUge -> "uge"
LLUlt -> "ult"
@ -78,30 +77,31 @@ instance ToIr LLVMComp where
LLSlt -> "slt"
LLSle -> "sle"
data Visibility = Local | Global deriving Show
data Visibility = Local | Global deriving (Show)
instance ToIr Visibility where
toIr :: Visibility -> String
toIr Local = "%"
toIr Local = "%"
toIr Global = "@"
-- | Represents a LLVM "value", as in an integer, a register variable,
-- or a string contstant
{- | Represents a LLVM "value", as in an integer, a register variable,
or a string contstant
-}
data LLVMValue
= VInteger Integer
| VChar Character
| VChar Char
| VIdent Ident LLVMType
| VConstant String
| VFunction Ident Visibility LLVMType
deriving Show
deriving (Show)
instance ToIr LLVMValue where
toIr :: LLVMValue -> String
toIr v = case v of
VInteger i -> show i
VChar i -> show i
VIdent (Ident n) _ -> "%" <> n
VInteger i -> show i
VChar i -> show i
VIdent (Ident n) _ -> "%" <> n
VFunction (Ident n) vis _ -> toIr vis <> n
VConstant s -> "c" <> show s
VConstant s -> "c" <> show s
type Params = [(Ident, LLVMType)]
type Args = [(LLVMType, LLVMValue)]
@ -114,8 +114,8 @@ data LLVMIr
| Declare LLVMType Ident Params
| SetVariable Ident LLVMIr
| Variable Ident
-- extractvalue <aggregate type> <val>, <idx>{, <idx>}*
| ExtractValue LLVMType LLVMValue Integer
| -- extractvalue <aggregate type> <val>, <idx>{, <idx>}*
ExtractValue LLVMType LLVMValue Integer
| GetElementPtr LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue
| GetElementPtrInbounds LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue
| Add LLVMType LLVMValue LLVMValue
@ -136,7 +136,7 @@ data LLVMIr
| Comment String
| UnsafeRaw String -- This should generally be avoided, and proper
-- instructions should be used in its place
deriving Show
deriving (Show)
-- | Converts a list of LLVMIr instructions to a string
llvmIrToString :: [LLVMIr] -> String
@ -146,14 +146,15 @@ llvmIrToString = go 0
go _ [] = mempty
go i (x : xs) = do
let (i', n) = case x of
Define{} -> (i + 1, 0)
Define{} -> (i + 1, 0)
DefineEnd -> (i - 1, 0)
_ -> (i, i)
_ -> (i, i)
insToString n x <> go i' xs
{- | Converts a LLVM inststruction to a String, allowing for printing etc.
The integer represents the indentation
-}
{- FOURMOLU_DISABLE -}
-- \| Converts a LLVM inststruction to a String, allowing for printing etc.
-- The integer represents the indentation
--
{- FOURMOLU_DISABLE -}
insToString :: Int -> LLVMIr -> String
insToString i l =
replicate i '\t' <> case l of
@ -261,4 +262,3 @@ llvmIrToString = go 0
lblPfx :: String
lblPfx = "lbl_"

View file

@ -1,7 +1,6 @@
module Monomorphizer.MonomorphizerIr (module Monomorphizer.MonomorphizerIr, module GA) where
module Monomorphizer.MonomorphizerIr (module Monomorphizer.MonomorphizerIr) where
import qualified Grammar.Abs as GA (Ident (..))
import qualified TypeChecker.TypeCheckerIr as TIR (Ident (..))
import TypeChecker.TypeCheckerIr qualified as TIR (Ident (..))
type Id = (TIR.Ident, Type)
@ -26,8 +25,12 @@ data Exp
| ECase ExpT [Branch]
deriving (Show, Ord, Eq)
data Pattern = PVar Id | PLit (Lit, Type) | PInj TIR.Ident [Pattern]
| PCatch | PEnum TIR.Ident
data Pattern
= PVar Id
| PLit (Lit, Type)
| PInj TIR.Ident [Pattern]
| PCatch
| PEnum TIR.Ident
deriving (Eq, Ord, Show)
data Branch = Branch (Pattern, Type) ExpT
@ -48,4 +51,4 @@ data Type = TLit TIR.Ident | TFun Type Type
flattenType :: Type -> [Type]
flattenType (TFun t1 t2) = t1 : flattenType t2
flattenType x = [x]
flattenType x = [x]

View file

@ -1,31 +1,29 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
-- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeCheckerHm where
import Auxiliary
import Control.Monad.Except
import Control.Monad.Identity (runIdentity)
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (second)
import Data.Coerce (coerce)
import Data.Foldable (traverse_)
import Data.Function (on)
import Data.List (foldl')
import Data.List.Extra (unsnoc)
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe (fromJust)
import Data.Set (Set)
import qualified Data.Set as S
import Debug.Trace (trace)
import Grammar.Abs
import Grammar.Print (printTree)
import qualified TypeChecker.TypeCheckerIr as T
import TypeChecker.TypeCheckerIr (Ctx (..), Env (..), Error, Infer,
Subst)
import Auxiliary
import Control.Monad.Except
import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (second)
import Data.Coerce (coerce)
import Data.Foldable (traverse_)
import Data.Function (on)
import Data.List (foldl')
import Data.List.Extra (unsnoc)
import Data.Map (Map)
import Data.Map qualified as M
import Data.Maybe (fromJust)
import Data.Set (Set)
import Data.Set qualified as S
import Data.String
import Grammar.Abs
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr qualified as T
initCtx = Ctx mempty
initEnv = Env 0 'a' mempty mempty mempty
@ -39,7 +37,7 @@ run = runC initEnv initCtx
runC :: Env -> Ctx -> Infer a -> Either Error a
runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e
typecheck :: Program -> Either Error T.Program
typecheck :: Program -> Either Error (T.Program' Type)
typecheck = run . checkPrg
checkData :: Data -> Infer ()
@ -50,9 +48,9 @@ checkData d = do
(all isPoly ts)
(throwError $ unwords ["Data type incorrectly declared"])
traverse_
( \(Constructor name' t') ->
( \(Inj name' t') ->
if typ == retType t'
then insertConstr (coerce name') (toNew t')
then insertConstr (coerce name') (t')
else
throwError $
unwords
@ -73,9 +71,9 @@ checkData d = do
retType :: Type -> Type
retType (TFun _ t2) = retType t2
retType a = a
retType a = a
checkPrg :: Program -> Infer T.Program
checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = do
preRun bs
bs' <- checkDef bs
@ -94,25 +92,27 @@ preRun (x : xs) = case x of
<> printTree n
<> "'"
)
insertSig (coerce n) (Just $ toNew t) >> preRun xs
insertSig (coerce n) (Just $ t) >> preRun xs
DBind (Bind n _ e) -> do
collect (collectTypeVars e)
s <- gets sigs
case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs
Just _ -> preRun xs
DData d@(Data t _) -> collect (collectTypeVars t) >> checkData d >> preRun xs
checkDef :: [Def] -> Infer [T.Def]
checkDef :: [Def] -> Infer [T.Def' Type]
checkDef [] = return []
checkDef (x : xs) = case x of
(DBind b) -> do
b' <- checkBind b
fmap (T.DBind b' :) (checkDef xs)
(DData d) -> fmap (T.DData (toNew d) :) (checkDef xs)
(DData d) -> fmap ((T.DData (coerceData d)) :) (checkDef xs)
(DSig _) -> checkDef xs
where
coerceData (Data t injs) = T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
checkBind :: Bind -> Infer T.Bind
checkBind :: Bind -> Infer (T.Bind' Type)
checkBind (Bind name args e) = do
let lambda = makeLambda e (reverse (coerce args))
e@(_, args_t) <- inferExp lambda
@ -133,41 +133,41 @@ checkBind (Bind name args e) = do
insertSig (coerce name) (Just args_t)
return (T.Bind (coerce name, args_t) [] e)
typeEq :: T.Type -> T.Type -> Bool
typeEq (T.TFun l r) (T.TFun l' r') = typeEq l l' && typeEq r r'
typeEq (T.TLit a) (T.TLit b) = a == b
typeEq (T.TData name a) (T.TData name' b) =
typeEq :: Type -> Type -> Bool
typeEq (TFun l r) (TFun l' r') = typeEq l l' && typeEq r r'
typeEq (TLit a) (TLit b) = a == b
typeEq (TData name a) (TData name' b) =
length a == length b
&& name == name'
&& and (zipWith typeEq a b)
typeEq (T.TAll _ t1) t2 = t1 `typeEq` t2
typeEq t1 (T.TAll _ t2) = t1 `typeEq` t2
typeEq (T.TVar _) (T.TVar _) = True
typeEq (TAll _ t1) t2 = t1 `typeEq` t2
typeEq t1 (TAll _ t2) = t1 `typeEq` t2
typeEq (TVar _) (TVar _) = True
typeEq _ _ = False
skolem :: T.Type -> T.Type
skolem (T.TVar (T.MkTVar a)) = T.TLit a
skolem (T.TAll x t) = T.TAll x (skolem t)
skolem (T.TFun t1 t2) = (T.TFun `on` skolem) t1 t2
skolem t = t
skolem :: Type -> Type
skolem (TVar (T.MkTVar a)) = TLit (coerce a)
skolem (TAll x t) = TAll x (skolem t)
skolem (TFun t1 t2) = (TFun `on` skolem) t1 t2
skolem t = t
isMoreSpecificOrEq :: T.Type -> T.Type -> Bool
isMoreSpecificOrEq t1 (T.TAll _ t2) = isMoreSpecificOrEq t1 t2
isMoreSpecificOrEq (T.TFun a b) (T.TFun c d) =
isMoreSpecificOrEq :: Type -> Type -> Bool
isMoreSpecificOrEq t1 (TAll _ t2) = isMoreSpecificOrEq t1 t2
isMoreSpecificOrEq (TFun a b) (TFun c d) =
isMoreSpecificOrEq a c && isMoreSpecificOrEq b d
isMoreSpecificOrEq (T.TData n1 ts1) (T.TData n2 ts2) =
isMoreSpecificOrEq (TData n1 ts1) (TData n2 ts2) =
n1 == n2
&& length ts1 == length ts2
&& and (zipWith isMoreSpecificOrEq ts1 ts2)
isMoreSpecificOrEq _ (T.TVar _) = True
isMoreSpecificOrEq _ (TVar _) = True
isMoreSpecificOrEq a b = a == b
isPoly :: Type -> Bool
isPoly (TAll _ _) = True
isPoly (TVar _) = True
isPoly _ = False
isPoly (TVar _) = True
isPoly _ = False
inferExp :: Exp -> Infer T.ExpT
inferExp :: Exp -> Infer (T.ExpT' Type)
inferExp e = do
(s, (e', t)) <- algoW e
let subbed = apply s t
@ -178,7 +178,7 @@ class CollectTVars a where
instance CollectTVars Exp where
collectTypeVars (EAnn e t) = collectTypeVars t `S.union` collectTypeVars e
collectTypeVars _ = S.empty
collectTypeVars _ = S.empty
instance CollectTVars Type where
collectTypeVars (TVar (MkTVar i)) = S.singleton (coerce i)
@ -190,43 +190,12 @@ instance CollectTVars Type where
collect :: Set T.Ident -> Infer ()
collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st})
class NewType a b where
toNew :: a -> b
instance NewType Type T.Type where
toNew = \case
TLit i -> T.TLit $ coerce i
TVar v -> T.TVar $ toNew v
TFun t1 t2 -> (T.TFun `on` toNew) t1 t2
TAll b t -> T.TAll (toNew b) (toNew t)
TData i ts -> T.TData (coerce i) (map toNew ts)
TEVar _ -> error "Should not exist after typechecker"
instance NewType Lit T.Lit where
toNew (LInt i) = T.LInt i
toNew (LChar i) = T.LChar i
instance NewType Data T.Data where
toNew (Data t xs) = T.Data (name $ retType t) (toNew xs)
where
name (TData n _) = coerce n
name _ = error "Bug: Data types should not be able to be typed over non type variables"
instance NewType Constructor T.Constructor where
toNew (Constructor name xs) = T.Constructor (coerce name) (toNew xs)
instance NewType TVar T.TVar where
toNew (MkTVar i) = T.MkTVar $ coerce i
instance NewType a b => NewType [a] [b] where
toNew = map toNew
algoW :: Exp -> Infer (Subst, T.ExpT)
algoW :: Exp -> Infer (Subst, (T.ExpT' Type))
algoW = \case
err@(EAnn e t) -> do
(s1, (e', t')) <- exprErr (algoW e) err
unless
(toNew t `isMoreSpecificOrEq` t')
(t `isMoreSpecificOrEq` t')
( throwError $
unwords
[ "Annotated type:"
@ -236,34 +205,34 @@ algoW = \case
]
)
applySt s1 $ do
s2 <- exprErr (unify (toNew t) t') err
s2 <- exprErr (unify (t) t') err
let comp = s2 `compose` s1
return (comp, apply comp (e', toNew t))
return (comp, apply comp (e', t))
-- \| ------------------
-- \| Γ ⊢ i : Int, ∅
ELit lit -> return (nullSubst, (T.ELit $ toNew lit, litType lit))
ELit lit -> return (nullSubst, (T.ELit $ lit, litType lit))
-- \| x : σ ∈ Γ τ = inst(σ)
-- \| ----------------------
-- \| Γ ⊢ x : τ, ∅
EVar i -> do
var <- asks vars
case M.lookup (coerce i) var of
Just t -> inst t >>= \x -> return (nullSubst, (T.EId $ coerce i, x))
Just t -> inst t >>= \x -> return (nullSubst, (T.EVar $ coerce i, x))
Nothing -> do
sig <- gets sigs
case M.lookup (coerce i) sig of
Just (Just t) -> return (nullSubst, (T.EId $ coerce i, t))
Just (Just t) -> return (nullSubst, (T.EVar $ coerce i, t))
Just Nothing -> do
fr <- fresh
insertSig (coerce i) (Just fr)
return (nullSubst, (T.EId $ coerce i, fr))
return (nullSubst, (T.EVar $ coerce i, fr))
Nothing -> throwError $ "Unbound variable: " <> printTree i
EInj i -> do
constr <- gets constructors
case M.lookup (coerce i) constr of
Just t -> return (nullSubst, (T.EId $ coerce i, t))
Just t -> return (nullSubst, (T.EVar $ coerce i, t))
Nothing ->
throwError $
"Constructor: '"
@ -280,7 +249,7 @@ algoW = \case
( withBinding (coerce name) fr $ do
(s1, (e', t')) <- exprErr (algoW e) err
let varType = apply s1 fr
let newArr = T.TFun varType t'
let newArr = TFun varType t'
return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr))
)
err
@ -314,7 +283,7 @@ algoW = \case
(s0, (e0', t0)) <- algoW e0
applySt s0 $ do
(s1, (e1', t1)) <- algoW e1
s2 <- exprErr (unify (apply s1 t0) (T.TFun t1 fr)) err
s2 <- exprErr (unify (apply s1 t0) (TFun t1 fr)) err
let t = apply s2 fr
let comp = s2 `compose` s1 `compose` s0
return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t))
@ -346,33 +315,33 @@ makeLambda :: Exp -> [T.Ident] -> Exp
makeLambda = foldl (flip (EAbs . coerce))
-- | Unify two types producing a new substitution
unify :: T.Type -> T.Type -> Infer Subst
unify :: Type -> Type -> Infer Subst
unify t0 t1 = do
case (t0, t1) of
(T.TFun a b, T.TFun c d) -> do
(TFun a b, TFun c d) -> do
s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d)
return $ s1 `compose` s2
----------- TODO: BE CAREFUL!!!! THIS IS PROBABLY WRONG!!! -----------
(T.TVar (T.MkTVar a), t@(T.TData _ _)) -> return $ M.singleton a t
(t@(T.TData _ _), T.TVar (T.MkTVar b)) -> return $ M.singleton b t
(TVar (T.MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t
(t@(TData _ _), TVar (T.MkTVar b)) -> return $ M.singleton (coerce b) t
-------------------------------------------------------------------
(T.TVar (T.MkTVar a), t) -> occurs a t
(t, T.TVar (T.MkTVar b)) -> occurs b t
(T.TAll _ t, b) -> unify t b
(a, T.TAll _ t) -> unify a t
(T.TLit a, T.TLit b) ->
(TVar (T.MkTVar a), t) -> occurs (coerce a) t
(t, TVar (T.MkTVar b)) -> occurs (coerce b) t
(TAll _ t, b) -> unify t b
(a, TAll _ t) -> unify a t
(TLit a, TLit b) ->
if a == b
then return M.empty
else
throwError
. unwords
$ [ "Can not unify"
, "'" <> printTree (T.TLit a) <> "'"
, "'" <> printTree (TLit a) <> "'"
, "with"
, "'" <> printTree (T.TLit b) <> "'"
, "'" <> printTree (TLit b) <> "'"
]
(T.TData name t, T.TData name' t') ->
(TData name t, TData name' t') ->
if name == name' && length t == length t'
then do
xs <- zipWithM unify t t'
@ -380,7 +349,7 @@ unify t0 t1 = do
else
throwError $
unwords
[ "T.Type constructor:"
[ "Type constructor:"
, printTree name
, "(" <> printTree t <> ")"
, "does not match with:"
@ -398,42 +367,42 @@ unify t0 t1 = do
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
where these are equal
-}
occurs :: T.Ident -> T.Type -> Infer Subst
occurs i t@(T.TVar _) = return (M.singleton i t)
occurs :: T.Ident -> Type -> Infer Subst
occurs i t@(TVar _) = return (M.singleton i t)
occurs i t =
if S.member i (free t)
then
throwError $
unwords
[ "Occurs check failed, can't unify"
, printTree (T.TVar $ T.MkTVar i)
, printTree (TVar $ T.MkTVar (coerce i))
, "with"
, printTree t
]
else return $ M.singleton i t
-- | Generalize a type over all free variables in the substitution set
generalize :: Map T.Ident T.Type -> T.Type -> T.Type
generalize :: Map T.Ident Type -> Type -> Type
generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t)
where
go :: [T.Ident] -> T.Type -> T.Type
go [] t = t
go (x : xs) t = T.TAll (T.MkTVar x) (go xs t)
removeForalls :: T.Type -> T.Type
removeForalls (T.TAll _ t) = removeForalls t
removeForalls (T.TFun t1 t2) = T.TFun (removeForalls t1) (removeForalls t2)
removeForalls t = t
go :: [T.Ident] -> Type -> Type
go [] t = t
go (x : xs) t = TAll (T.MkTVar (coerce x)) (go xs t)
removeForalls :: Type -> Type
removeForalls (TAll _ t) = removeForalls t
removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2)
removeForalls t = t
{- | Instantiate a polymorphic type. The free type variables are substituted
with fresh ones.
-}
inst :: T.Type -> Infer T.Type
inst :: Type -> Infer Type
inst = \case
T.TAll (T.MkTVar bound) t -> do
TAll (T.MkTVar bound) t -> do
fr <- fresh
let s = M.singleton bound fr
let s = M.singleton (coerce bound) fr
apply s <$> inst t
T.TFun t1 t2 -> T.TFun <$> inst t1 <*> inst t2
TFun t1 t2 -> TFun <$> inst t1 <*> inst t2
rest -> return rest
-- | Compose two substitution sets
@ -455,41 +424,40 @@ class FreeVars t where
-- | Get all free variables from t
free :: t -> Set T.Ident
instance FreeVars T.Type where
free :: T.Type -> Set T.Ident
free (T.TVar (T.MkTVar a)) = S.singleton a
free (T.TAll (T.MkTVar bound) t) = S.singleton bound `S.intersection` free t
free (T.TLit _) = mempty
free (T.TFun a b) = free a `S.union` free b
instance FreeVars Type where
free :: Type -> Set T.Ident
free (TVar (T.MkTVar a)) = S.singleton (coerce a)
free (TAll (T.MkTVar bound) t) = S.singleton (coerce bound) `S.intersection` free t
free (TLit _) = mempty
free (TFun a b) = free a `S.union` free b
-- \| Not guaranteed to be correct
free (T.TData _ a) =
free (TData _ a) =
foldl' (\acc x -> free x `S.union` acc) S.empty a
instance SubstType T.Type where
apply :: Subst -> T.Type -> T.Type
instance SubstType Type where
apply :: Subst -> Type -> Type
apply sub t = do
case t of
T.TLit a -> T.TLit a
T.TVar (T.MkTVar a) -> case M.lookup a sub of
Nothing -> T.TVar (T.MkTVar $ coerce a)
Just t -> t
T.TAll (T.MkTVar i) t -> case M.lookup i sub of
Nothing -> T.TAll (T.MkTVar i) (apply sub t)
Just _ -> apply sub t
T.TFun a b -> T.TFun (apply sub a) (apply sub b)
T.TData name a -> T.TData name (map (apply sub) a)
instance FreeVars (Map T.Ident T.Type) where
free :: Map T.Ident T.Type -> Set T.Ident
TLit a -> TLit a
TVar (T.MkTVar a) -> case M.lookup (coerce a) sub of
Nothing -> TVar (T.MkTVar $ coerce a)
Just t -> t
TAll (T.MkTVar i) t -> case M.lookup (coerce i) sub of
Nothing -> TAll (T.MkTVar i) (apply sub t)
Just _ -> apply sub t
TFun a b -> TFun (apply sub a) (apply sub b)
TData name a -> TData name (map (apply sub) a)
instance FreeVars (Map T.Ident Type) where
free :: Map T.Ident Type -> Set T.Ident
free m = foldl' S.union S.empty (map free $ M.elems m)
instance SubstType (Map T.Ident T.Type) where
apply :: Subst -> Map T.Ident T.Type -> Map T.Ident T.Type
instance SubstType (Map T.Ident Type) where
apply :: Subst -> Map T.Ident Type -> Map T.Ident Type
apply s = M.map (apply s)
instance SubstType T.Exp where
apply :: Subst -> T.Exp -> T.Exp
instance SubstType (T.Exp' Type) where
apply s = \case
T.EId i -> T.EId i
T.EVar i -> T.EVar i
T.ELit lit -> T.ELit lit
T.ELet (T.Bind (ident, t1) args e1) e2 ->
T.ELet
@ -499,19 +467,18 @@ instance SubstType T.Exp where
T.EAdd e1 e2 -> T.EAdd (apply s e1) (apply s e2)
T.EAbs ident e -> T.EAbs ident (apply s e)
T.ECase e brnch -> T.ECase (apply s e) (apply s brnch)
T.EInj{} -> error "implement"
instance SubstType T.Branch where
apply :: Subst -> T.Branch -> T.Branch
instance SubstType (T.Branch' Type) where
apply s (T.Branch (i, t) e) = T.Branch (apply s i, apply s t) (apply s e)
instance SubstType T.Pattern where
apply :: Subst -> T.Pattern -> T.Pattern
instance SubstType (T.Pattern' Type) where
apply s = \case
T.PVar (iden, t) -> T.PVar (iden, apply s t)
T.PLit (lit, t) -> T.PLit (lit, apply s t)
T.PInj i ps -> T.PInj i $ apply s ps
T.PCatch -> T.PCatch
T.PEnum i -> T.PEnum i
T.PLit (lit, t) -> T.PLit (lit, apply s t)
T.PInj i ps -> T.PInj i $ apply s ps
T.PCatch -> T.PCatch
T.PEnum i -> T.PEnum i
instance SubstType a => SubstType [a] where
apply s = map (apply s)
@ -519,7 +486,7 @@ instance SubstType a => SubstType [a] where
instance (SubstType a, SubstType b) => SubstType (a, b) where
apply s (a, b) = (apply s a, apply s b)
instance SubstType T.Id where
instance SubstType (T.Id' Type) where
apply s (name, t) = (name, apply s t)
-- | Apply substitutions to the environment.
@ -531,7 +498,7 @@ nullSubst :: Subst
nullSubst = M.empty
-- | Generate a new fresh variable and increment the state counter
fresh :: Infer T.Type
fresh :: Infer Type
fresh = do
c <- gets nextChar
n <- gets count
@ -545,34 +512,34 @@ fresh = do
fresh
else
if n == 0
then return . T.TVar . T.MkTVar . T.Ident $ [c]
else return . T.TVar . T.MkTVar . T.Ident $ [c] ++ show n
then return . TVar . T.MkTVar $ LIdent [c]
else return . TVar . T.MkTVar . LIdent $ [c] ++ show n
next :: Char -> Char
next 'z' = 'a'
next a = succ a
next a = succ a
-- | Run the monadic action with an additional binding
withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> T.Type -> m a -> m a
withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> Type -> m a -> m a
withBinding i p = local (\st -> st{vars = M.insert i p (vars st)})
-- | Run the monadic action with several additional bindings
withBindings :: (Monad m, MonadReader Ctx m) => [(T.Ident, T.Type)] -> m a -> m a
withBindings :: (Monad m, MonadReader Ctx m) => [(T.Ident, Type)] -> m a -> m a
withBindings xs =
local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs})
-- | Insert a function signature into the environment
insertSig :: T.Ident -> Maybe T.Type -> Infer ()
insertSig :: T.Ident -> Maybe Type -> Infer ()
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
-- | Insert a constructor with its data type
insertConstr :: T.Ident -> T.Type -> Infer ()
insertConstr :: T.Ident -> Type -> Infer ()
insertConstr i t =
modify (\st -> st{constructors = M.insert i t (constructors st)})
-------- PATTERN MATCHING ---------
checkCase :: T.Type -> [Branch] -> Infer (Subst, [T.Branch], T.Type)
checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type)
checkCase _ [] = throwError "Atleast one case required"
checkCase expT brnchs = do
(subs, injTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs
@ -594,23 +561,23 @@ checkCase expT brnchs = do
let comp = sub2 `compose` sub1 `compose` sub0
return (comp, apply comp injs, apply comp returns_type)
inferBranch :: Branch -> Infer (Subst, T.Type, T.Branch, T.Type)
inferBranch :: Branch -> Infer (Subst, Type, T.Branch' Type, Type)
inferBranch (Branch pat expr) = do
newPat@(pat, branchT) <- inferPattern pat
(sub, newExp@(_, exprT)) <- withPattern pat (algoW expr)
return (sub, apply sub branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT)
withPattern :: T.Pattern -> Infer a -> Infer a
withPattern :: T.Pattern' Type -> Infer a -> Infer a
withPattern p ma = case p of
T.PVar (x, t) -> withBinding x t ma
T.PInj _ ps -> foldl' (flip withPattern) ma ps
T.PLit _ -> ma
T.PCatch -> ma
T.PEnum _ -> ma
T.PInj _ ps -> foldl' (flip withPattern) ma ps
T.PLit _ -> ma
T.PCatch -> ma
T.PEnum _ -> ma
inferPattern :: Pattern -> Infer (T.Pattern, T.Type)
inferPattern :: Pattern -> Infer (T.Pattern' Type, Type)
inferPattern = \case
PLit lit -> let lt = litType lit in return (T.PLit (toNew lit, lt), lt)
PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt)
PInj constr patterns -> do
t <- gets (M.lookup (coerce constr) . constructors)
t <- maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") t
@ -644,28 +611,28 @@ inferPattern = \case
++ show (typeLength t - 1)
++ " arguments but has been given 0"
)
let (T.TData _data _ts) = t -- nasty nasty
let (TData _data _ts) = t -- nasty nasty
frs <- mapM (const fresh) _ts
return (T.PEnum $ coerce p, T.TData _data frs)
return (T.PEnum $ coerce p, TData _data frs)
PVar x -> do
fr <- fresh
let pvar = T.PVar (coerce x, fr)
return (pvar, fr)
flattenType :: T.Type -> [T.Type]
flattenType (T.TFun a b) = flattenType a <> flattenType b
flattenType a = [a]
flattenType :: Type -> [Type]
flattenType (TFun a b) = flattenType a <> flattenType b
flattenType a = [a]
typeLength :: T.Type -> Int
typeLength (T.TFun a b) = typeLength a + typeLength b
typeLength _ = 1
typeLength :: Type -> Int
typeLength (TFun a b) = typeLength a + typeLength b
typeLength _ = 1
litType :: Lit -> T.Type
litType (LInt _) = int
litType :: Lit -> Type
litType (LInt _) = int
litType (LChar _) = char
int = T.TLit "Int"
char = T.TLit "Char"
int = TLit "Int"
char = TLit "Char"
partitionType ::
Int -> -- Number of parameters to apply
@ -676,8 +643,8 @@ partitionType = go []
go acc 0 t = (acc, t)
go acc i t = case t of
TAll tvar t' -> second (TAll tvar) $ go acc i t'
TFun t1 t2 -> go (acc <> [t1]) (i - 1) t2
_ -> error "Number of parameters and type doesn't match"
TFun t1 t2 -> go (acc <> [t1]) (i - 1) t2
_ -> error "Number of parameters and type doesn't match"
exprErr :: Infer a -> Exp -> Infer a
exprErr ma exp =
@ -691,3 +658,19 @@ unzip4 =
)
([], [], [], [])
newtype Ctx = Ctx {vars :: Map T.Ident Type}
deriving (Show)
data Env = Env
{ count :: Int
, nextChar :: Char
, sigs :: Map T.Ident (Maybe Type)
, constructors :: Map T.Ident Type
, takenTypeVars :: Set T.Ident
}
deriving (Show)
type Error = String
type Subst = Map T.Ident Type
type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))

View file

@ -1,24 +1,24 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}
module TypeChecker.TypeCheckerIr (
module Grammar.Abs,
module TypeChecker.TypeCheckerIr,
) where
module TypeChecker.TypeCheckerIr
( module Grammar.Abs
, module TypeChecker.TypeCheckerIr
) where
import Data.String (IsString)
import Grammar.Abs (Lit (..), TVar (..))
import Grammar.Print
import Prelude
import qualified Prelude as C (Eq, Ord, Read, Show)
import Data.String (IsString)
import Grammar.Abs (Lit (..), TVar (..))
import Grammar.Print
import Prelude
import Prelude qualified as C (Eq, Ord, Read, Show)
newtype Program' t = Program [Def' t]
deriving (C.Eq, C.Ord, C.Show, C.Read)
deriving (C.Eq, C.Ord, C.Show, C.Read)
data Def' t = DBind (Bind' t)
| DData (Data' t)
deriving (C.Eq, C.Ord, C.Show, C.Read)
data Def' t
= DBind (Bind' t)
| DData (Data' t)
deriving (C.Eq, C.Ord, C.Show, C.Read)
data Type
= TLit Ident
@ -26,24 +26,24 @@ data Type
| TData Ident [Type]
| TFun Type Type
| TAll TVar Type
deriving (C.Eq, C.Ord, C.Show, C.Read)
deriving (C.Eq, C.Ord, C.Show, C.Read)
data Data' t = Data t [Inj' t]
deriving (C.Eq, C.Ord, C.Show, C.Read)
deriving (C.Eq, C.Ord, C.Show, C.Read)
data Inj' t = Inj Ident t
deriving (C.Eq, C.Ord, C.Show, C.Read)
deriving (C.Eq, C.Ord, C.Show, C.Read)
newtype Ident = Ident String
deriving (C.Eq, C.Ord, C.Show, C.Read, IsString)
deriving (C.Eq, C.Ord, C.Show, C.Read, IsString)
data Pattern' t
= PVar (Id' t) -- TODO should be Ident
| PLit (Lit, t) -- TODO should be Lit
= PVar (Id' t) -- TODO should be Ident
| PLit (Lit, t) -- TODO should be Lit
| PCatch
| PEnum Ident
| PInj Ident [Pattern' t] -- TODO should be (Pattern' t, t)
deriving (C.Eq, C.Ord, C.Show, C.Read)
deriving (C.Eq, C.Ord, C.Show, C.Read)
data Exp' t
= EVar Ident
@ -52,18 +52,18 @@ data Exp' t
| ELet (Bind' t) (ExpT' t)
| EApp (ExpT' t) (ExpT' t)
| EAdd (ExpT' t) (ExpT' t)
| EAbs Ident (ExpT' t)
| EAbs Ident (ExpT' t)
| ECase (ExpT' t) [Branch' t]
deriving (C.Eq, C.Ord, C.Show, C.Read)
deriving (C.Eq, C.Ord, C.Show, C.Read)
type Id' t = (Ident, t)
type Id' t = (Ident, t)
type ExpT' t = (Exp' t, t)
data Bind' t = Bind (Id' t) [Id' t] (ExpT' t)
deriving (C.Eq, C.Ord, C.Show, C.Read)
data Branch' t = Branch (Pattern' t, t) (ExpT' t)
deriving (C.Eq, C.Ord, C.Show, C.Read)
deriving (C.Eq, C.Ord, C.Show, C.Read)
instance Print Ident where
prt i (Ident s) = prt i s
@ -72,127 +72,143 @@ instance Print t => Print (Program' t) where
prt i (Program sc) = prPrec i 0 $ prt 0 sc
instance Print t => Print (Bind' t) where
prt i (Bind sig@(name, _) parms rhs) = prPrec i 0 $ concatD
[ prtSig sig
, prt 0 name
, prtIdPs 0 parms
, doc $ showString "="
, prt 0 rhs
]
prt i (Bind sig@(name, _) parms rhs) =
prPrec i 0 $
concatD
[ prtSig sig
, prt 0 name
, prtIdPs 0 parms
, doc $ showString "="
, prt 0 rhs
]
prtSig :: Print t => Id' t -> Doc
prtSig (name, t) = concatD [ prt 0 name
, doc $ showString ":"
, prt 0 t
, doc $ showString ";"
]
prtSig (name, t) =
concatD
[ prt 0 name
, doc $ showString ":"
, prt 0 t
, doc $ showString ";"
]
instance Print t => Print (ExpT' t) where
prt i (e, t) = concatD [ doc $ showString "("
, prt i e
, doc $ showString ","
, prt i t
, doc $ showString ")"
]
prt i (e, t) =
concatD
[ doc $ showString "("
, prt i e
, doc $ showString ","
, prt i t
, doc $ showString ")"
]
instance Print t => Print [Bind' t] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
prtIdPs :: Print t => Int -> [Id' t] -> Doc
prtIdPs i = prPrec i 0 . concatD . map (prt i)
instance Print t => Print (Id' t) where
prt i (name, t) = concatD [ doc $ showString "("
, prt i name
, doc $ showString ","
, prt i t
, doc $ showString ")"
]
prt i (name, t) =
concatD
[ doc $ showString "("
, prt i name
, doc $ showString ","
, prt i t
, doc $ showString ")"
]
instance Print t => Print (Exp' t) where
prt i = \case
EVar name -> prPrec i 3 $ prt 0 name
EInj name -> prPrec i 3 $ prt 0 name
ELit lit -> prPrec i 3 $ prt 0 lit
ELet b e -> prPrec i 3 $ concatD
[ doc $ showString "let"
, prt 0 b
, doc $ showString "in"
, prt 0 e
]
EApp e1 e2 -> prPrec i 2 $ concatD
[ prt 2 e1
, prt 3 e2
]
EAdd e1 e2 -> prPrec i 1 $ concatD
[ prt 1 e1
, doc $ showString "+"
, prt 2 e2
]
EAbs v e -> prPrec i 0 $ concatD
[ doc $ showString "\\"
, prt 0 v
, doc $ showString "."
, prt 0 e
]
ECase e branches -> prPrec i 0 $ concatD
[ doc $ showString "case"
, prt 0 e
, doc $ showString "of"
, doc $ showString "{"
, prt 0 branches
, doc $ showString "}"
]
prt i = \case
EVar name -> prPrec i 3 $ prt 0 name
EInj name -> prPrec i 3 $ prt 0 name
ELit lit -> prPrec i 3 $ prt 0 lit
ELet b e ->
prPrec i 3 $
concatD
[ doc $ showString "let"
, prt 0 b
, doc $ showString "in"
, prt 0 e
]
EApp e1 e2 ->
prPrec i 2 $
concatD
[ prt 2 e1
, prt 3 e2
]
EAdd e1 e2 ->
prPrec i 1 $
concatD
[ prt 1 e1
, doc $ showString "+"
, prt 2 e2
]
EAbs v e ->
prPrec i 0 $
concatD
[ doc $ showString "\\"
, prt 0 v
, doc $ showString "."
, prt 0 e
]
ECase e branches ->
prPrec i 0 $
concatD
[ doc $ showString "case"
, prt 0 e
, doc $ showString "of"
, doc $ showString "{"
, prt 0 branches
, doc $ showString "}"
]
instance Print t => Print (Branch' t) where
prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp])
prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp])
instance Print t => Print [Branch' t] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
instance Print t => Print (Def' t) where
prt i = \case
DBind bind -> prPrec i 0 (concatD [prt 0 bind])
DData data_ -> prPrec i 0 (concatD [prt 0 data_])
prt i = \case
DBind bind -> prPrec i 0 (concatD [prt 0 bind])
DData data_ -> prPrec i 0 (concatD [prt 0 data_])
instance Print t => Print (Data' t) where
prt i = \case
Data type_ injs -> prPrec i 0 (concatD [doc (showString "data"), prt 0 type_, doc (showString "where"), doc (showString "{"), prt 0 injs, doc (showString "}")])
prt i = \case
Data type_ injs -> prPrec i 0 (concatD [doc (showString "data"), prt 0 type_, doc (showString "where"), doc (showString "{"), prt 0 injs, doc (showString "}")])
instance Print t => Print (Inj' t) where
prt i = \case
Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_])
prt i = \case
Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_])
instance Print t => Print (Pattern' t) where
prt i = \case
PVar name -> prPrec i 1 (concatD [prt 0 name])
PLit (lit, _) -> prPrec i 1 (concatD [prt 0 lit])
PCatch -> prPrec i 1 (concatD [doc (showString "_")])
PEnum name -> prPrec i 1 (concatD [prt 0 name])
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])
prt i = \case
PVar name -> prPrec i 1 (concatD [prt 0 name])
PLit (lit, _) -> prPrec i 1 (concatD [prt 0 lit])
PCatch -> prPrec i 1 (concatD [doc (showString "_")])
PEnum name -> prPrec i 1 (concatD [prt 0 name])
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])
instance Print t => Print [Def' t] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
instance Print [Type] where
prt _ [] = concatD []
prt _ (x:xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs]
prt _ [] = concatD []
prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs]
instance Print Type where
prt i = \case
TLit uident -> prPrec i 1 (concatD [prt 0 uident])
TVar tvar -> prPrec i 1 (concatD [prt 0 tvar])
TData uident types -> prPrec i 1 (concatD [prt 0 uident, doc (showString "("), prt 0 types, doc (showString ")")])
TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2])
TAll tvar type_ -> prPrec i 0 (concatD [doc (showString "forall"), prt 0 tvar, doc (showString "."), prt 0 type_])
prt i = \case
TLit uident -> prPrec i 1 (concatD [prt 0 uident])
TVar tvar -> prPrec i 1 (concatD [prt 0 tvar])
TData uident types -> prPrec i 1 (concatD [prt 0 uident, doc (showString "("), prt 0 types, doc (showString ")")])
TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2])
TAll tvar type_ -> prPrec i 0 (concatD [doc (showString "forall"), prt 0 tvar, doc (showString "."), prt 0 type_])
type Program = Program' Type
type Def = Def' Type
@ -201,9 +217,8 @@ type Bind = Bind' Type
type Branch = Branch' Type
type Pattern = Pattern' Type
type Inj = Inj' Type
type Exp = Exp' Type
type Exp = Exp' Type
type ExpT = ExpT' Type
type Id = Id' Type
type Id = Id' Type
pattern DBind' id vars expt = DBind (Bind id vars expt)
pattern DData' typ injs = DData (Data typ injs)
pattern DData' typ injs = DData (Data typ injs)