Fixed errors in tc hm

This commit is contained in:
sebastianselander 2023-03-27 16:48:23 +02:00
parent 847ec37117
commit 6e54378327
4 changed files with 346 additions and 345 deletions

View file

@ -8,14 +8,13 @@ module Codegen.LlvmIr (
LLVMComp (..), LLVMComp (..),
Visibility (..), Visibility (..),
CallingConvention (..), CallingConvention (..),
ToIr(..) ToIr (..),
) where ) where
import Data.List (intercalate) import Data.List (intercalate)
import Grammar.Abs (Character)
import TypeChecker.TypeCheckerIr (Ident (..)) import TypeChecker.TypeCheckerIr (Ident (..))
data CallingConvention = TailCC | FastCC | CCC | ColdCC deriving Show data CallingConvention = TailCC | FastCC | CCC | ColdCC deriving (Show)
instance ToIr CallingConvention where instance ToIr CallingConvention where
toIr :: CallingConvention -> String toIr :: CallingConvention -> String
toIr TailCC = "tailcc" toIr TailCC = "tailcc"
@ -34,7 +33,7 @@ data LLVMType
| Function LLVMType [LLVMType] | Function LLVMType [LLVMType]
| Array Integer LLVMType | Array Integer LLVMType
| CustomType Ident | CustomType Ident
deriving Show deriving (Show)
class ToIr a where class ToIr a where
toIr :: a -> String toIr :: a -> String
@ -63,7 +62,7 @@ data LLVMComp
| LLSge | LLSge
| LLSlt | LLSlt
| LLSle | LLSle
deriving Show deriving (Show)
instance ToIr LLVMComp where instance ToIr LLVMComp where
toIr :: LLVMComp -> String toIr :: LLVMComp -> String
toIr = \case toIr = \case
@ -78,21 +77,22 @@ instance ToIr LLVMComp where
LLSlt -> "slt" LLSlt -> "slt"
LLSle -> "sle" LLSle -> "sle"
data Visibility = Local | Global deriving Show data Visibility = Local | Global deriving (Show)
instance ToIr Visibility where instance ToIr Visibility where
toIr :: Visibility -> String toIr :: Visibility -> String
toIr Local = "%" toIr Local = "%"
toIr Global = "@" toIr Global = "@"
-- | Represents a LLVM "value", as in an integer, a register variable, {- | Represents a LLVM "value", as in an integer, a register variable,
-- or a string contstant or a string contstant
-}
data LLVMValue data LLVMValue
= VInteger Integer = VInteger Integer
| VChar Character | VChar Char
| VIdent Ident LLVMType | VIdent Ident LLVMType
| VConstant String | VConstant String
| VFunction Ident Visibility LLVMType | VFunction Ident Visibility LLVMType
deriving Show deriving (Show)
instance ToIr LLVMValue where instance ToIr LLVMValue where
toIr :: LLVMValue -> String toIr :: LLVMValue -> String
@ -114,8 +114,8 @@ data LLVMIr
| Declare LLVMType Ident Params | Declare LLVMType Ident Params
| SetVariable Ident LLVMIr | SetVariable Ident LLVMIr
| Variable Ident | Variable Ident
-- extractvalue <aggregate type> <val>, <idx>{, <idx>}* | -- extractvalue <aggregate type> <val>, <idx>{, <idx>}*
| ExtractValue LLVMType LLVMValue Integer ExtractValue LLVMType LLVMValue Integer
| GetElementPtr LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue | GetElementPtr LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue
| GetElementPtrInbounds LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue | GetElementPtrInbounds LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue
| Add LLVMType LLVMValue LLVMValue | Add LLVMType LLVMValue LLVMValue
@ -136,7 +136,7 @@ data LLVMIr
| Comment String | Comment String
| UnsafeRaw String -- This should generally be avoided, and proper | UnsafeRaw String -- This should generally be avoided, and proper
-- instructions should be used in its place -- instructions should be used in its place
deriving Show deriving (Show)
-- | Converts a list of LLVMIr instructions to a string -- | Converts a list of LLVMIr instructions to a string
llvmIrToString :: [LLVMIr] -> String llvmIrToString :: [LLVMIr] -> String
@ -150,9 +150,10 @@ llvmIrToString = go 0
DefineEnd -> (i - 1, 0) DefineEnd -> (i - 1, 0)
_ -> (i, i) _ -> (i, i)
insToString n x <> go i' xs insToString n x <> go i' xs
{- | Converts a LLVM inststruction to a String, allowing for printing etc.
The integer represents the indentation -- \| Converts a LLVM inststruction to a String, allowing for printing etc.
-} -- The integer represents the indentation
--
{- FOURMOLU_DISABLE -} {- FOURMOLU_DISABLE -}
insToString :: Int -> LLVMIr -> String insToString :: Int -> LLVMIr -> String
insToString i l = insToString i l =
@ -261,4 +262,3 @@ llvmIrToString = go 0
lblPfx :: String lblPfx :: String
lblPfx = "lbl_" lblPfx = "lbl_"

View file

@ -1,7 +1,6 @@
module Monomorphizer.MonomorphizerIr (module Monomorphizer.MonomorphizerIr, module GA) where module Monomorphizer.MonomorphizerIr (module Monomorphizer.MonomorphizerIr) where
import qualified Grammar.Abs as GA (Ident (..)) import TypeChecker.TypeCheckerIr qualified as TIR (Ident (..))
import qualified TypeChecker.TypeCheckerIr as TIR (Ident (..))
type Id = (TIR.Ident, Type) type Id = (TIR.Ident, Type)
@ -26,8 +25,12 @@ data Exp
| ECase ExpT [Branch] | ECase ExpT [Branch]
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
data Pattern = PVar Id | PLit (Lit, Type) | PInj TIR.Ident [Pattern] data Pattern
| PCatch | PEnum TIR.Ident = PVar Id
| PLit (Lit, Type)
| PInj TIR.Ident [Pattern]
| PCatch
| PEnum TIR.Ident
deriving (Eq, Ord, Show) deriving (Eq, Ord, Show)
data Branch = Branch (Pattern, Type) ExpT data Branch = Branch (Pattern, Type) ExpT

View file

@ -6,7 +6,7 @@ module TypeChecker.TypeCheckerHm where
import Auxiliary import Auxiliary
import Control.Monad.Except import Control.Monad.Except
import Control.Monad.Identity (runIdentity) import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader import Control.Monad.Reader
import Control.Monad.State import Control.Monad.State
import Data.Bifunctor (second) import Data.Bifunctor (second)
@ -16,16 +16,14 @@ import Data.Function (on)
import Data.List (foldl') import Data.List (foldl')
import Data.List.Extra (unsnoc) import Data.List.Extra (unsnoc)
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as M import Data.Map qualified as M
import Data.Maybe (fromJust) import Data.Maybe (fromJust)
import Data.Set (Set) import Data.Set (Set)
import qualified Data.Set as S import Data.Set qualified as S
import Debug.Trace (trace) import Data.String
import Grammar.Abs import Grammar.Abs
import Grammar.Print (printTree) import Grammar.Print (printTree)
import qualified TypeChecker.TypeCheckerIr as T import TypeChecker.TypeCheckerIr qualified as T
import TypeChecker.TypeCheckerIr (Ctx (..), Env (..), Error, Infer,
Subst)
initCtx = Ctx mempty initCtx = Ctx mempty
initEnv = Env 0 'a' mempty mempty mempty initEnv = Env 0 'a' mempty mempty mempty
@ -39,7 +37,7 @@ run = runC initEnv initCtx
runC :: Env -> Ctx -> Infer a -> Either Error a runC :: Env -> Ctx -> Infer a -> Either Error a
runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e
typecheck :: Program -> Either Error T.Program typecheck :: Program -> Either Error (T.Program' Type)
typecheck = run . checkPrg typecheck = run . checkPrg
checkData :: Data -> Infer () checkData :: Data -> Infer ()
@ -50,9 +48,9 @@ checkData d = do
(all isPoly ts) (all isPoly ts)
(throwError $ unwords ["Data type incorrectly declared"]) (throwError $ unwords ["Data type incorrectly declared"])
traverse_ traverse_
( \(Constructor name' t') -> ( \(Inj name' t') ->
if typ == retType t' if typ == retType t'
then insertConstr (coerce name') (toNew t') then insertConstr (coerce name') (t')
else else
throwError $ throwError $
unwords unwords
@ -75,7 +73,7 @@ retType :: Type -> Type
retType (TFun _ t2) = retType t2 retType (TFun _ t2) = retType t2
retType a = a retType a = a
checkPrg :: Program -> Infer T.Program checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = do checkPrg (Program bs) = do
preRun bs preRun bs
bs' <- checkDef bs bs' <- checkDef bs
@ -94,7 +92,7 @@ preRun (x : xs) = case x of
<> printTree n <> printTree n
<> "'" <> "'"
) )
insertSig (coerce n) (Just $ toNew t) >> preRun xs insertSig (coerce n) (Just $ t) >> preRun xs
DBind (Bind n _ e) -> do DBind (Bind n _ e) -> do
collect (collectTypeVars e) collect (collectTypeVars e)
s <- gets sigs s <- gets sigs
@ -103,16 +101,18 @@ preRun (x : xs) = case x of
Just _ -> preRun xs Just _ -> preRun xs
DData d@(Data t _) -> collect (collectTypeVars t) >> checkData d >> preRun xs DData d@(Data t _) -> collect (collectTypeVars t) >> checkData d >> preRun xs
checkDef :: [Def] -> Infer [T.Def] checkDef :: [Def] -> Infer [T.Def' Type]
checkDef [] = return [] checkDef [] = return []
checkDef (x : xs) = case x of checkDef (x : xs) = case x of
(DBind b) -> do (DBind b) -> do
b' <- checkBind b b' <- checkBind b
fmap (T.DBind b' :) (checkDef xs) fmap (T.DBind b' :) (checkDef xs)
(DData d) -> fmap (T.DData (toNew d) :) (checkDef xs) (DData d) -> fmap ((T.DData (coerceData d)) :) (checkDef xs)
(DSig _) -> checkDef xs (DSig _) -> checkDef xs
where
coerceData (Data t injs) = T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
checkBind :: Bind -> Infer T.Bind checkBind :: Bind -> Infer (T.Bind' Type)
checkBind (Bind name args e) = do checkBind (Bind name args e) = do
let lambda = makeLambda e (reverse (coerce args)) let lambda = makeLambda e (reverse (coerce args))
e@(_, args_t) <- inferExp lambda e@(_, args_t) <- inferExp lambda
@ -133,33 +133,33 @@ checkBind (Bind name args e) = do
insertSig (coerce name) (Just args_t) insertSig (coerce name) (Just args_t)
return (T.Bind (coerce name, args_t) [] e) return (T.Bind (coerce name, args_t) [] e)
typeEq :: T.Type -> T.Type -> Bool typeEq :: Type -> Type -> Bool
typeEq (T.TFun l r) (T.TFun l' r') = typeEq l l' && typeEq r r' typeEq (TFun l r) (TFun l' r') = typeEq l l' && typeEq r r'
typeEq (T.TLit a) (T.TLit b) = a == b typeEq (TLit a) (TLit b) = a == b
typeEq (T.TData name a) (T.TData name' b) = typeEq (TData name a) (TData name' b) =
length a == length b length a == length b
&& name == name' && name == name'
&& and (zipWith typeEq a b) && and (zipWith typeEq a b)
typeEq (T.TAll _ t1) t2 = t1 `typeEq` t2 typeEq (TAll _ t1) t2 = t1 `typeEq` t2
typeEq t1 (T.TAll _ t2) = t1 `typeEq` t2 typeEq t1 (TAll _ t2) = t1 `typeEq` t2
typeEq (T.TVar _) (T.TVar _) = True typeEq (TVar _) (TVar _) = True
typeEq _ _ = False typeEq _ _ = False
skolem :: T.Type -> T.Type skolem :: Type -> Type
skolem (T.TVar (T.MkTVar a)) = T.TLit a skolem (TVar (T.MkTVar a)) = TLit (coerce a)
skolem (T.TAll x t) = T.TAll x (skolem t) skolem (TAll x t) = TAll x (skolem t)
skolem (T.TFun t1 t2) = (T.TFun `on` skolem) t1 t2 skolem (TFun t1 t2) = (TFun `on` skolem) t1 t2
skolem t = t skolem t = t
isMoreSpecificOrEq :: T.Type -> T.Type -> Bool isMoreSpecificOrEq :: Type -> Type -> Bool
isMoreSpecificOrEq t1 (T.TAll _ t2) = isMoreSpecificOrEq t1 t2 isMoreSpecificOrEq t1 (TAll _ t2) = isMoreSpecificOrEq t1 t2
isMoreSpecificOrEq (T.TFun a b) (T.TFun c d) = isMoreSpecificOrEq (TFun a b) (TFun c d) =
isMoreSpecificOrEq a c && isMoreSpecificOrEq b d isMoreSpecificOrEq a c && isMoreSpecificOrEq b d
isMoreSpecificOrEq (T.TData n1 ts1) (T.TData n2 ts2) = isMoreSpecificOrEq (TData n1 ts1) (TData n2 ts2) =
n1 == n2 n1 == n2
&& length ts1 == length ts2 && length ts1 == length ts2
&& and (zipWith isMoreSpecificOrEq ts1 ts2) && and (zipWith isMoreSpecificOrEq ts1 ts2)
isMoreSpecificOrEq _ (T.TVar _) = True isMoreSpecificOrEq _ (TVar _) = True
isMoreSpecificOrEq a b = a == b isMoreSpecificOrEq a b = a == b
isPoly :: Type -> Bool isPoly :: Type -> Bool
@ -167,7 +167,7 @@ isPoly (TAll _ _) = True
isPoly (TVar _) = True isPoly (TVar _) = True
isPoly _ = False isPoly _ = False
inferExp :: Exp -> Infer T.ExpT inferExp :: Exp -> Infer (T.ExpT' Type)
inferExp e = do inferExp e = do
(s, (e', t)) <- algoW e (s, (e', t)) <- algoW e
let subbed = apply s t let subbed = apply s t
@ -190,43 +190,12 @@ instance CollectTVars Type where
collect :: Set T.Ident -> Infer () collect :: Set T.Ident -> Infer ()
collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st}) collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st})
class NewType a b where algoW :: Exp -> Infer (Subst, (T.ExpT' Type))
toNew :: a -> b
instance NewType Type T.Type where
toNew = \case
TLit i -> T.TLit $ coerce i
TVar v -> T.TVar $ toNew v
TFun t1 t2 -> (T.TFun `on` toNew) t1 t2
TAll b t -> T.TAll (toNew b) (toNew t)
TData i ts -> T.TData (coerce i) (map toNew ts)
TEVar _ -> error "Should not exist after typechecker"
instance NewType Lit T.Lit where
toNew (LInt i) = T.LInt i
toNew (LChar i) = T.LChar i
instance NewType Data T.Data where
toNew (Data t xs) = T.Data (name $ retType t) (toNew xs)
where
name (TData n _) = coerce n
name _ = error "Bug: Data types should not be able to be typed over non type variables"
instance NewType Constructor T.Constructor where
toNew (Constructor name xs) = T.Constructor (coerce name) (toNew xs)
instance NewType TVar T.TVar where
toNew (MkTVar i) = T.MkTVar $ coerce i
instance NewType a b => NewType [a] [b] where
toNew = map toNew
algoW :: Exp -> Infer (Subst, T.ExpT)
algoW = \case algoW = \case
err@(EAnn e t) -> do err@(EAnn e t) -> do
(s1, (e', t')) <- exprErr (algoW e) err (s1, (e', t')) <- exprErr (algoW e) err
unless unless
(toNew t `isMoreSpecificOrEq` t') (t `isMoreSpecificOrEq` t')
( throwError $ ( throwError $
unwords unwords
[ "Annotated type:" [ "Annotated type:"
@ -236,34 +205,34 @@ algoW = \case
] ]
) )
applySt s1 $ do applySt s1 $ do
s2 <- exprErr (unify (toNew t) t') err s2 <- exprErr (unify (t) t') err
let comp = s2 `compose` s1 let comp = s2 `compose` s1
return (comp, apply comp (e', toNew t)) return (comp, apply comp (e', t))
-- \| ------------------ -- \| ------------------
-- \| Γ ⊢ i : Int, ∅ -- \| Γ ⊢ i : Int, ∅
ELit lit -> return (nullSubst, (T.ELit $ toNew lit, litType lit)) ELit lit -> return (nullSubst, (T.ELit $ lit, litType lit))
-- \| x : σ ∈ Γ τ = inst(σ) -- \| x : σ ∈ Γ τ = inst(σ)
-- \| ---------------------- -- \| ----------------------
-- \| Γ ⊢ x : τ, ∅ -- \| Γ ⊢ x : τ, ∅
EVar i -> do EVar i -> do
var <- asks vars var <- asks vars
case M.lookup (coerce i) var of case M.lookup (coerce i) var of
Just t -> inst t >>= \x -> return (nullSubst, (T.EId $ coerce i, x)) Just t -> inst t >>= \x -> return (nullSubst, (T.EVar $ coerce i, x))
Nothing -> do Nothing -> do
sig <- gets sigs sig <- gets sigs
case M.lookup (coerce i) sig of case M.lookup (coerce i) sig of
Just (Just t) -> return (nullSubst, (T.EId $ coerce i, t)) Just (Just t) -> return (nullSubst, (T.EVar $ coerce i, t))
Just Nothing -> do Just Nothing -> do
fr <- fresh fr <- fresh
insertSig (coerce i) (Just fr) insertSig (coerce i) (Just fr)
return (nullSubst, (T.EId $ coerce i, fr)) return (nullSubst, (T.EVar $ coerce i, fr))
Nothing -> throwError $ "Unbound variable: " <> printTree i Nothing -> throwError $ "Unbound variable: " <> printTree i
EInj i -> do EInj i -> do
constr <- gets constructors constr <- gets constructors
case M.lookup (coerce i) constr of case M.lookup (coerce i) constr of
Just t -> return (nullSubst, (T.EId $ coerce i, t)) Just t -> return (nullSubst, (T.EVar $ coerce i, t))
Nothing -> Nothing ->
throwError $ throwError $
"Constructor: '" "Constructor: '"
@ -280,7 +249,7 @@ algoW = \case
( withBinding (coerce name) fr $ do ( withBinding (coerce name) fr $ do
(s1, (e', t')) <- exprErr (algoW e) err (s1, (e', t')) <- exprErr (algoW e) err
let varType = apply s1 fr let varType = apply s1 fr
let newArr = T.TFun varType t' let newArr = TFun varType t'
return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr)) return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr))
) )
err err
@ -314,7 +283,7 @@ algoW = \case
(s0, (e0', t0)) <- algoW e0 (s0, (e0', t0)) <- algoW e0
applySt s0 $ do applySt s0 $ do
(s1, (e1', t1)) <- algoW e1 (s1, (e1', t1)) <- algoW e1
s2 <- exprErr (unify (apply s1 t0) (T.TFun t1 fr)) err s2 <- exprErr (unify (apply s1 t0) (TFun t1 fr)) err
let t = apply s2 fr let t = apply s2 fr
let comp = s2 `compose` s1 `compose` s0 let comp = s2 `compose` s1 `compose` s0
return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t)) return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t))
@ -346,33 +315,33 @@ makeLambda :: Exp -> [T.Ident] -> Exp
makeLambda = foldl (flip (EAbs . coerce)) makeLambda = foldl (flip (EAbs . coerce))
-- | Unify two types producing a new substitution -- | Unify two types producing a new substitution
unify :: T.Type -> T.Type -> Infer Subst unify :: Type -> Type -> Infer Subst
unify t0 t1 = do unify t0 t1 = do
case (t0, t1) of case (t0, t1) of
(T.TFun a b, T.TFun c d) -> do (TFun a b, TFun c d) -> do
s1 <- unify a c s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d) s2 <- unify (apply s1 b) (apply s1 d)
return $ s1 `compose` s2 return $ s1 `compose` s2
----------- TODO: BE CAREFUL!!!! THIS IS PROBABLY WRONG!!! ----------- ----------- TODO: BE CAREFUL!!!! THIS IS PROBABLY WRONG!!! -----------
(T.TVar (T.MkTVar a), t@(T.TData _ _)) -> return $ M.singleton a t (TVar (T.MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t
(t@(T.TData _ _), T.TVar (T.MkTVar b)) -> return $ M.singleton b t (t@(TData _ _), TVar (T.MkTVar b)) -> return $ M.singleton (coerce b) t
------------------------------------------------------------------- -------------------------------------------------------------------
(T.TVar (T.MkTVar a), t) -> occurs a t (TVar (T.MkTVar a), t) -> occurs (coerce a) t
(t, T.TVar (T.MkTVar b)) -> occurs b t (t, TVar (T.MkTVar b)) -> occurs (coerce b) t
(T.TAll _ t, b) -> unify t b (TAll _ t, b) -> unify t b
(a, T.TAll _ t) -> unify a t (a, TAll _ t) -> unify a t
(T.TLit a, T.TLit b) -> (TLit a, TLit b) ->
if a == b if a == b
then return M.empty then return M.empty
else else
throwError throwError
. unwords . unwords
$ [ "Can not unify" $ [ "Can not unify"
, "'" <> printTree (T.TLit a) <> "'" , "'" <> printTree (TLit a) <> "'"
, "with" , "with"
, "'" <> printTree (T.TLit b) <> "'" , "'" <> printTree (TLit b) <> "'"
] ]
(T.TData name t, T.TData name' t') -> (TData name t, TData name' t') ->
if name == name' && length t == length t' if name == name' && length t == length t'
then do then do
xs <- zipWithM unify t t' xs <- zipWithM unify t t'
@ -380,7 +349,7 @@ unify t0 t1 = do
else else
throwError $ throwError $
unwords unwords
[ "T.Type constructor:" [ "Type constructor:"
, printTree name , printTree name
, "(" <> printTree t <> ")" , "(" <> printTree t <> ")"
, "does not match with:" , "does not match with:"
@ -398,42 +367,42 @@ unify t0 t1 = do
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
where these are equal where these are equal
-} -}
occurs :: T.Ident -> T.Type -> Infer Subst occurs :: T.Ident -> Type -> Infer Subst
occurs i t@(T.TVar _) = return (M.singleton i t) occurs i t@(TVar _) = return (M.singleton i t)
occurs i t = occurs i t =
if S.member i (free t) if S.member i (free t)
then then
throwError $ throwError $
unwords unwords
[ "Occurs check failed, can't unify" [ "Occurs check failed, can't unify"
, printTree (T.TVar $ T.MkTVar i) , printTree (TVar $ T.MkTVar (coerce i))
, "with" , "with"
, printTree t , printTree t
] ]
else return $ M.singleton i t else return $ M.singleton i t
-- | Generalize a type over all free variables in the substitution set -- | Generalize a type over all free variables in the substitution set
generalize :: Map T.Ident T.Type -> T.Type -> T.Type generalize :: Map T.Ident Type -> Type -> Type
generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t) generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t)
where where
go :: [T.Ident] -> T.Type -> T.Type go :: [T.Ident] -> Type -> Type
go [] t = t go [] t = t
go (x : xs) t = T.TAll (T.MkTVar x) (go xs t) go (x : xs) t = TAll (T.MkTVar (coerce x)) (go xs t)
removeForalls :: T.Type -> T.Type removeForalls :: Type -> Type
removeForalls (T.TAll _ t) = removeForalls t removeForalls (TAll _ t) = removeForalls t
removeForalls (T.TFun t1 t2) = T.TFun (removeForalls t1) (removeForalls t2) removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2)
removeForalls t = t removeForalls t = t
{- | Instantiate a polymorphic type. The free type variables are substituted {- | Instantiate a polymorphic type. The free type variables are substituted
with fresh ones. with fresh ones.
-} -}
inst :: T.Type -> Infer T.Type inst :: Type -> Infer Type
inst = \case inst = \case
T.TAll (T.MkTVar bound) t -> do TAll (T.MkTVar bound) t -> do
fr <- fresh fr <- fresh
let s = M.singleton bound fr let s = M.singleton (coerce bound) fr
apply s <$> inst t apply s <$> inst t
T.TFun t1 t2 -> T.TFun <$> inst t1 <*> inst t2 TFun t1 t2 -> TFun <$> inst t1 <*> inst t2
rest -> return rest rest -> return rest
-- | Compose two substitution sets -- | Compose two substitution sets
@ -455,41 +424,40 @@ class FreeVars t where
-- | Get all free variables from t -- | Get all free variables from t
free :: t -> Set T.Ident free :: t -> Set T.Ident
instance FreeVars T.Type where instance FreeVars Type where
free :: T.Type -> Set T.Ident free :: Type -> Set T.Ident
free (T.TVar (T.MkTVar a)) = S.singleton a free (TVar (T.MkTVar a)) = S.singleton (coerce a)
free (T.TAll (T.MkTVar bound) t) = S.singleton bound `S.intersection` free t free (TAll (T.MkTVar bound) t) = S.singleton (coerce bound) `S.intersection` free t
free (T.TLit _) = mempty free (TLit _) = mempty
free (T.TFun a b) = free a `S.union` free b free (TFun a b) = free a `S.union` free b
-- \| Not guaranteed to be correct -- \| Not guaranteed to be correct
free (T.TData _ a) = free (TData _ a) =
foldl' (\acc x -> free x `S.union` acc) S.empty a foldl' (\acc x -> free x `S.union` acc) S.empty a
instance SubstType T.Type where instance SubstType Type where
apply :: Subst -> T.Type -> T.Type apply :: Subst -> Type -> Type
apply sub t = do apply sub t = do
case t of case t of
T.TLit a -> T.TLit a TLit a -> TLit a
T.TVar (T.MkTVar a) -> case M.lookup a sub of TVar (T.MkTVar a) -> case M.lookup (coerce a) sub of
Nothing -> T.TVar (T.MkTVar $ coerce a) Nothing -> TVar (T.MkTVar $ coerce a)
Just t -> t Just t -> t
T.TAll (T.MkTVar i) t -> case M.lookup i sub of TAll (T.MkTVar i) t -> case M.lookup (coerce i) sub of
Nothing -> T.TAll (T.MkTVar i) (apply sub t) Nothing -> TAll (T.MkTVar i) (apply sub t)
Just _ -> apply sub t Just _ -> apply sub t
T.TFun a b -> T.TFun (apply sub a) (apply sub b) TFun a b -> TFun (apply sub a) (apply sub b)
T.TData name a -> T.TData name (map (apply sub) a) TData name a -> TData name (map (apply sub) a)
instance FreeVars (Map T.Ident T.Type) where instance FreeVars (Map T.Ident Type) where
free :: Map T.Ident T.Type -> Set T.Ident free :: Map T.Ident Type -> Set T.Ident
free m = foldl' S.union S.empty (map free $ M.elems m) free m = foldl' S.union S.empty (map free $ M.elems m)
instance SubstType (Map T.Ident T.Type) where instance SubstType (Map T.Ident Type) where
apply :: Subst -> Map T.Ident T.Type -> Map T.Ident T.Type apply :: Subst -> Map T.Ident Type -> Map T.Ident Type
apply s = M.map (apply s) apply s = M.map (apply s)
instance SubstType T.Exp where instance SubstType (T.Exp' Type) where
apply :: Subst -> T.Exp -> T.Exp
apply s = \case apply s = \case
T.EId i -> T.EId i T.EVar i -> T.EVar i
T.ELit lit -> T.ELit lit T.ELit lit -> T.ELit lit
T.ELet (T.Bind (ident, t1) args e1) e2 -> T.ELet (T.Bind (ident, t1) args e1) e2 ->
T.ELet T.ELet
@ -499,13 +467,12 @@ instance SubstType T.Exp where
T.EAdd e1 e2 -> T.EAdd (apply s e1) (apply s e2) T.EAdd e1 e2 -> T.EAdd (apply s e1) (apply s e2)
T.EAbs ident e -> T.EAbs ident (apply s e) T.EAbs ident e -> T.EAbs ident (apply s e)
T.ECase e brnch -> T.ECase (apply s e) (apply s brnch) T.ECase e brnch -> T.ECase (apply s e) (apply s brnch)
T.EInj{} -> error "implement"
instance SubstType T.Branch where instance SubstType (T.Branch' Type) where
apply :: Subst -> T.Branch -> T.Branch
apply s (T.Branch (i, t) e) = T.Branch (apply s i, apply s t) (apply s e) apply s (T.Branch (i, t) e) = T.Branch (apply s i, apply s t) (apply s e)
instance SubstType T.Pattern where instance SubstType (T.Pattern' Type) where
apply :: Subst -> T.Pattern -> T.Pattern
apply s = \case apply s = \case
T.PVar (iden, t) -> T.PVar (iden, apply s t) T.PVar (iden, t) -> T.PVar (iden, apply s t)
T.PLit (lit, t) -> T.PLit (lit, apply s t) T.PLit (lit, t) -> T.PLit (lit, apply s t)
@ -519,7 +486,7 @@ instance SubstType a => SubstType [a] where
instance (SubstType a, SubstType b) => SubstType (a, b) where instance (SubstType a, SubstType b) => SubstType (a, b) where
apply s (a, b) = (apply s a, apply s b) apply s (a, b) = (apply s a, apply s b)
instance SubstType T.Id where instance SubstType (T.Id' Type) where
apply s (name, t) = (name, apply s t) apply s (name, t) = (name, apply s t)
-- | Apply substitutions to the environment. -- | Apply substitutions to the environment.
@ -531,7 +498,7 @@ nullSubst :: Subst
nullSubst = M.empty nullSubst = M.empty
-- | Generate a new fresh variable and increment the state counter -- | Generate a new fresh variable and increment the state counter
fresh :: Infer T.Type fresh :: Infer Type
fresh = do fresh = do
c <- gets nextChar c <- gets nextChar
n <- gets count n <- gets count
@ -545,34 +512,34 @@ fresh = do
fresh fresh
else else
if n == 0 if n == 0
then return . T.TVar . T.MkTVar . T.Ident $ [c] then return . TVar . T.MkTVar $ LIdent [c]
else return . T.TVar . T.MkTVar . T.Ident $ [c] ++ show n else return . TVar . T.MkTVar . LIdent $ [c] ++ show n
next :: Char -> Char next :: Char -> Char
next 'z' = 'a' next 'z' = 'a'
next a = succ a next a = succ a
-- | Run the monadic action with an additional binding -- | Run the monadic action with an additional binding
withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> T.Type -> m a -> m a withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> Type -> m a -> m a
withBinding i p = local (\st -> st{vars = M.insert i p (vars st)}) withBinding i p = local (\st -> st{vars = M.insert i p (vars st)})
-- | Run the monadic action with several additional bindings -- | Run the monadic action with several additional bindings
withBindings :: (Monad m, MonadReader Ctx m) => [(T.Ident, T.Type)] -> m a -> m a withBindings :: (Monad m, MonadReader Ctx m) => [(T.Ident, Type)] -> m a -> m a
withBindings xs = withBindings xs =
local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs}) local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs})
-- | Insert a function signature into the environment -- | Insert a function signature into the environment
insertSig :: T.Ident -> Maybe T.Type -> Infer () insertSig :: T.Ident -> Maybe Type -> Infer ()
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)}) insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
-- | Insert a constructor with its data type -- | Insert a constructor with its data type
insertConstr :: T.Ident -> T.Type -> Infer () insertConstr :: T.Ident -> Type -> Infer ()
insertConstr i t = insertConstr i t =
modify (\st -> st{constructors = M.insert i t (constructors st)}) modify (\st -> st{constructors = M.insert i t (constructors st)})
-------- PATTERN MATCHING --------- -------- PATTERN MATCHING ---------
checkCase :: T.Type -> [Branch] -> Infer (Subst, [T.Branch], T.Type) checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type)
checkCase _ [] = throwError "Atleast one case required" checkCase _ [] = throwError "Atleast one case required"
checkCase expT brnchs = do checkCase expT brnchs = do
(subs, injTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs (subs, injTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs
@ -594,13 +561,13 @@ checkCase expT brnchs = do
let comp = sub2 `compose` sub1 `compose` sub0 let comp = sub2 `compose` sub1 `compose` sub0
return (comp, apply comp injs, apply comp returns_type) return (comp, apply comp injs, apply comp returns_type)
inferBranch :: Branch -> Infer (Subst, T.Type, T.Branch, T.Type) inferBranch :: Branch -> Infer (Subst, Type, T.Branch' Type, Type)
inferBranch (Branch pat expr) = do inferBranch (Branch pat expr) = do
newPat@(pat, branchT) <- inferPattern pat newPat@(pat, branchT) <- inferPattern pat
(sub, newExp@(_, exprT)) <- withPattern pat (algoW expr) (sub, newExp@(_, exprT)) <- withPattern pat (algoW expr)
return (sub, apply sub branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT) return (sub, apply sub branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT)
withPattern :: T.Pattern -> Infer a -> Infer a withPattern :: T.Pattern' Type -> Infer a -> Infer a
withPattern p ma = case p of withPattern p ma = case p of
T.PVar (x, t) -> withBinding x t ma T.PVar (x, t) -> withBinding x t ma
T.PInj _ ps -> foldl' (flip withPattern) ma ps T.PInj _ ps -> foldl' (flip withPattern) ma ps
@ -608,9 +575,9 @@ withPattern p ma = case p of
T.PCatch -> ma T.PCatch -> ma
T.PEnum _ -> ma T.PEnum _ -> ma
inferPattern :: Pattern -> Infer (T.Pattern, T.Type) inferPattern :: Pattern -> Infer (T.Pattern' Type, Type)
inferPattern = \case inferPattern = \case
PLit lit -> let lt = litType lit in return (T.PLit (toNew lit, lt), lt) PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt)
PInj constr patterns -> do PInj constr patterns -> do
t <- gets (M.lookup (coerce constr) . constructors) t <- gets (M.lookup (coerce constr) . constructors)
t <- maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") t t <- maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") t
@ -644,28 +611,28 @@ inferPattern = \case
++ show (typeLength t - 1) ++ show (typeLength t - 1)
++ " arguments but has been given 0" ++ " arguments but has been given 0"
) )
let (T.TData _data _ts) = t -- nasty nasty let (TData _data _ts) = t -- nasty nasty
frs <- mapM (const fresh) _ts frs <- mapM (const fresh) _ts
return (T.PEnum $ coerce p, T.TData _data frs) return (T.PEnum $ coerce p, TData _data frs)
PVar x -> do PVar x -> do
fr <- fresh fr <- fresh
let pvar = T.PVar (coerce x, fr) let pvar = T.PVar (coerce x, fr)
return (pvar, fr) return (pvar, fr)
flattenType :: T.Type -> [T.Type] flattenType :: Type -> [Type]
flattenType (T.TFun a b) = flattenType a <> flattenType b flattenType (TFun a b) = flattenType a <> flattenType b
flattenType a = [a] flattenType a = [a]
typeLength :: T.Type -> Int typeLength :: Type -> Int
typeLength (T.TFun a b) = typeLength a + typeLength b typeLength (TFun a b) = typeLength a + typeLength b
typeLength _ = 1 typeLength _ = 1
litType :: Lit -> T.Type litType :: Lit -> Type
litType (LInt _) = int litType (LInt _) = int
litType (LChar _) = char litType (LChar _) = char
int = T.TLit "Int" int = TLit "Int"
char = T.TLit "Char" char = TLit "Char"
partitionType :: partitionType ::
Int -> -- Number of parameters to apply Int -> -- Number of parameters to apply
@ -691,3 +658,19 @@ unzip4 =
) )
([], [], [], []) ([], [], [], [])
newtype Ctx = Ctx {vars :: Map T.Ident Type}
deriving (Show)
data Env = Env
{ count :: Int
, nextChar :: Char
, sigs :: Map T.Ident (Maybe Type)
, constructors :: Map T.Ident Type
, takenTypeVars :: Set T.Ident
}
deriving (Show)
type Error = String
type Subst = Map T.Ident Type
type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))

View file

@ -1,22 +1,22 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PatternSynonyms #-}
module TypeChecker.TypeCheckerIr (
module TypeChecker.TypeCheckerIr module Grammar.Abs,
( module Grammar.Abs module TypeChecker.TypeCheckerIr,
, module TypeChecker.TypeCheckerIr
) where ) where
import Data.String (IsString) import Data.String (IsString)
import Grammar.Abs (Lit (..), TVar (..)) import Grammar.Abs (Lit (..), TVar (..))
import Grammar.Print import Grammar.Print
import Prelude import Prelude
import qualified Prelude as C (Eq, Ord, Read, Show) import Prelude qualified as C (Eq, Ord, Read, Show)
newtype Program' t = Program [Def' t] newtype Program' t = Program [Def' t]
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
data Def' t = DBind (Bind' t) data Def' t
= DBind (Bind' t)
| DData (Data' t) | DData (Data' t)
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
@ -72,7 +72,9 @@ instance Print t => Print (Program' t) where
prt i (Program sc) = prPrec i 0 $ prt 0 sc prt i (Program sc) = prPrec i 0 $ prt 0 sc
instance Print t => Print (Bind' t) where instance Print t => Print (Bind' t) where
prt i (Bind sig@(name, _) parms rhs) = prPrec i 0 $ concatD prt i (Bind sig@(name, _) parms rhs) =
prPrec i 0 $
concatD
[ prtSig sig [ prtSig sig
, prt 0 name , prt 0 name
, prtIdPs 0 parms , prtIdPs 0 parms
@ -81,14 +83,18 @@ instance Print t => Print (Bind' t) where
] ]
prtSig :: Print t => Id' t -> Doc prtSig :: Print t => Id' t -> Doc
prtSig (name, t) = concatD [ prt 0 name prtSig (name, t) =
concatD
[ prt 0 name
, doc $ showString ":" , doc $ showString ":"
, prt 0 t , prt 0 t
, doc $ showString ";" , doc $ showString ";"
] ]
instance Print t => Print (ExpT' t) where instance Print t => Print (ExpT' t) where
prt i (e, t) = concatD [ doc $ showString "(" prt i (e, t) =
concatD
[ doc $ showString "("
, prt i e , prt i e
, doc $ showString "," , doc $ showString ","
, prt i t , prt i t
@ -104,7 +110,9 @@ prtIdPs :: Print t => Int -> [Id' t] -> Doc
prtIdPs i = prPrec i 0 . concatD . map (prt i) prtIdPs i = prPrec i 0 . concatD . map (prt i)
instance Print t => Print (Id' t) where instance Print t => Print (Id' t) where
prt i (name, t) = concatD [ doc $ showString "(" prt i (name, t) =
concatD
[ doc $ showString "("
, prt i name , prt i name
, doc $ showString "," , doc $ showString ","
, prt i t , prt i t
@ -116,29 +124,38 @@ instance Print t => Print (Exp' t) where
EVar name -> prPrec i 3 $ prt 0 name EVar name -> prPrec i 3 $ prt 0 name
EInj name -> prPrec i 3 $ prt 0 name EInj name -> prPrec i 3 $ prt 0 name
ELit lit -> prPrec i 3 $ prt 0 lit ELit lit -> prPrec i 3 $ prt 0 lit
ELet b e -> prPrec i 3 $ concatD ELet b e ->
prPrec i 3 $
concatD
[ doc $ showString "let" [ doc $ showString "let"
, prt 0 b , prt 0 b
, doc $ showString "in" , doc $ showString "in"
, prt 0 e , prt 0 e
] ]
EApp e1 e2 -> prPrec i 2 $ concatD EApp e1 e2 ->
prPrec i 2 $
concatD
[ prt 2 e1 [ prt 2 e1
, prt 3 e2 , prt 3 e2
] ]
EAdd e1 e2 -> prPrec i 1 $ concatD EAdd e1 e2 ->
prPrec i 1 $
concatD
[ prt 1 e1 [ prt 1 e1
, doc $ showString "+" , doc $ showString "+"
, prt 2 e2 , prt 2 e2
] ]
EAbs v e -> prPrec i 0 $ concatD EAbs v e ->
prPrec i 0 $
concatD
[ doc $ showString "\\" [ doc $ showString "\\"
, prt 0 v , prt 0 v
, doc $ showString "." , doc $ showString "."
, prt 0 e , prt 0 e
] ]
ECase e branches ->
ECase e branches -> prPrec i 0 $ concatD prPrec i 0 $
concatD
[ doc $ showString "case" [ doc $ showString "case"
, prt 0 e , prt 0 e
, doc $ showString "of" , doc $ showString "of"
@ -147,7 +164,6 @@ instance Print t => Print (Exp' t) where
, doc $ showString "}" , doc $ showString "}"
] ]
instance Print t => Print (Branch' t) where instance Print t => Print (Branch' t) where
prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp]) prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp])
@ -206,4 +222,3 @@ type ExpT = ExpT' Type
type Id = Id' Type type Id = Id' Type
pattern DBind' id vars expt = DBind (Bind id vars expt) pattern DBind' id vars expt = DBind (Bind id vars expt)
pattern DData' typ injs = DData (Data typ injs) pattern DData' typ injs = DData (Data typ injs)