Remade the algorithm myself. Still some bugs.

This commit is contained in:
sebastianselander 2023-02-18 23:08:27 +01:00
parent f188cffb8d
commit 8b5cd3cf9a
12 changed files with 584 additions and 257 deletions

155
src/TypeChecker/HM.hs Normal file
View file

@ -0,0 +1,155 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use traverse_" #-}
module TypeChecker.HM (typecheck) where
import Control.Monad.Except
import Control.Monad.State
import Data.Bifunctor (second)
import Data.Functor.Identity (Identity, runIdentity)
import Data.Map (Map)
import qualified Data.Map as M
import Grammar.Abs
import Grammar.Print
import qualified TypeChecker.HMIr as T
-- | Error messages produced by the type checker.
type Error = String

-- | The inference monad: a 'Ctx' state layered over 'Error'-valued failure.
type Infer = StateT Ctx (ExceptT Error Identity)
-- | State threaded through type inference.
data Ctx = Ctx
  { -- | Type constraints recorded by 'addConstraint' and read back by
    -- 'solveConstraints'.
    constr :: Map Type Type
  , -- | Types of variables registered through 'insertVar'.
    vars :: Map Ident Type
  , -- | Declared signatures registered through 'insertSig'; 'lookupVar'
    -- falls back to these when a name is not in 'vars'.
    sigs :: Map Ident Type
  , -- | Character used by 'fresh' to name the next type variable.
    frsh :: Char
  }
  deriving Show
-- | Run an inference action starting from the initial context 'initC',
-- discarding the final state and returning either an error message or the
-- action's result.
run :: Infer a -> Either String a
run action = runIdentity (runExceptT (evalStateT action initC))
-- | The monomorphic @Int@ type, used for integer literals and addition.
int :: Type
int = TMono "Int"
-- | The initial context: no constraints, variables or signatures, with
-- fresh type variables starting at @a@.
initC :: Ctx
initC =
  Ctx
    { constr = M.empty
    , vars = M.empty
    , sigs = M.empty
    , frsh = 'a'
    }
-- | Type check a whole program, producing the type-annotated program or an
-- error message.
typecheck :: Program -> Either Error T.Program
typecheck prog = run (inferPrg prog)
-- | Infer every binding of a program.
--
-- All declared signatures are registered first, so a binding can refer to
-- any other top-level binding regardless of declaration order.
inferPrg :: Program -> Infer T.Program
inferPrg (Program binds) = do
  mapM_ registerSig binds
  T.Program <$> mapM inferBind binds
  where
    -- Record the declared signature of a single binding.
    registerSig (Bind name sig _ _ _) = insertSig name sig
-- | Infer a single top-level binding.
--
-- The parameters are folded into nested lambdas around the right-hand side,
-- the resulting expression is inferred, and the declared signature is
-- recorded as a constraint against the inferred type.
inferBind :: Bind -> Infer T.Bind
inferBind (Bind name sig _ params rhs) = do
  (inferred, body) <- inferExp (makeLambda (reverse params) rhs)
  addConstraint sig inferred
  -- NOTE: a strict equality check between the signature and the inferred
  -- type is currently disabled:
  -- when (t /= t') (throwError $ "Signature of function" ++ printTree i ++ "does not match inferred type of expression: " ++ printTree e')
  pure (T.Bind (sig, name) [] body)
-- | Wrap an expression in one lambda per identifier.
--
-- The first identifier of the list becomes the innermost lambda, so callers
-- pass the parameters in reverse order.
makeLambda :: [Ident] -> Exp -> Exp
makeLambda xs e = foldr EAbs e (reverse xs)
-- | Infer the type of an expression and build its type-annotated counterpart.
--
-- Returns the inferred type paired with the annotated expression. Fails via
-- 'throwError' when an annotation differs from the inferred type, when a
-- non-@Int@ operand is added, when an identifier is unbound, or when a
-- constraint cannot be recorded or solved.
inferExp :: Exp -> Infer (Type, T.Exp)
inferExp = \case
  EAnn e t -> do
    (t', e') <- inferExp e
    -- The annotation is compared structurally with the inferred type.
    when (t' /= t) (throwError "Annotated type and inferred type don't match")
    return (t', e')
  EInt i -> return (int, T.EInt int i)
  EId i -> (\t -> (t, T.EId t i)) <$> lookupVar i
  EAdd e1 e2 -> do
    (t1, e1') <- inferExp e1
    (t2, e2') <- inferExp e2
    unless (isInt t1 && isInt t2) (throwError "Can not add non-ints")
    return (int, T.EAdd int e1' e2')
  EApp e1 e2 -> do
    (t1, e1') <- inferExp e1
    (t2, e2') <- inferExp e2
    -- The function must accept the argument type and return a fresh result.
    fr <- fresh
    addConstraint t1 (TArr t2 fr)
    return (fr, T.EApp fr e1' e2')
  EAbs name e -> do
    fr <- fresh
    insertVar name fr
    (ret_t, e') <- inferExp e
    t <- solveConstraints (TArr fr ret_t)
    return (t, T.EAbs t name e')
  ELet name e1 e2 -> do
    fr <- fresh
    insertVar name fr
    (_, e1') <- inferExp e1
    (t2, e2') <- inferExp e2
    -- A let expression has the type of its body, not of the bound expression.
    -- NOTE(review): the binder's fresh variable is not unified with the type
    -- of e1 here.
    ret_t <- solveConstraints t2
    return (ret_t, T.ELet ret_t name e1' e2')
-- | Whether a type is exactly the monomorphic @Int@ type.
isInt :: Type -> Bool
isInt = \case
  TMono "Int" -> True
  _ -> False
-- | Look up the type of an identifier.
--
-- Variables registered with 'insertVar' take precedence; declared signatures
-- are consulted as a fallback. Fails with an error message when the
-- identifier is found in neither map.
lookupVar :: Ident -> Infer Type
lookupVar i = do
  st <- get
  let unbound = throwError ("Unbound variable or function: " ++ printTree i)
      fromSigs = maybe unbound return (M.lookup i (sigs st))
  maybe fromSigs return (M.lookup i (vars st))
-- | Record the type of a variable, replacing any previous entry for it.
insertVar :: Ident -> Type -> Infer ()
insertVar s t = modify extend
  where
    extend st = st { vars = M.insert s t (vars st) }
-- | Record a declared signature, replacing any previous entry for the name.
insertSig :: Ident -> Type -> Infer ()
insertSig s t = modify extend
  where
    extend st = st { sigs = M.insert s t (sigs st) }
-- | Produce a new polymorphic type variable named after the current 'frsh'
-- character, and advance 'frsh' to its successor.
fresh :: Infer Type
fresh = do
  st <- get
  let c = frsh st
  put st { frsh = succ c }
  return (TPol (Ident [c]))
-- | Record the constraint that @t1@ must equal @t2@.
--
-- Fails when @t2@ 'contains' @t1@; otherwise the constraint is stored under
-- the key @t1@, replacing any constraint previously stored for that key.
addConstraint :: Type -> Type -> Infer ()
addConstraint t1 t2
  | t2 `contains` t1 =
      throwError $ "Can't match type " ++ printTree t1 ++ " with " ++ printTree t2
  | otherwise =
      modify (\st -> st { constr = M.insert t1 t2 (constr st) })
-- | Whether the second type occurs inside the first.
--
-- Arrow types are searched component-wise; two monomorphic types never
-- contain each other; anything else is compared for equality.
contains :: Type -> Type -> Bool
contains haystack needle = case haystack of
  TArr l r -> l `contains` needle || r `contains` needle
  TMono _ | TMono _ <- needle -> False
  _ -> haystack == needle
-- | Solve all constraints recorded so far and apply the resulting
-- substitution to the given type.
--
-- Fails when 'solveAll' cannot unify the recorded constraints.
solveConstraints :: Type -> Infer Type
solveConstraints t = do
  c <- gets constr
  subst t <$> solveAll (M.toList c)
-- | Apply a list of substitutions to a type, one after another, from left to
-- right. 'replace' already descends into arrow types, so no special case is
-- needed for them.
subst :: Type -> [(Type, Type)] -> Type
subst = foldl (flip replace)
-- | Turn a list of constraints into a list of substitutions.
--
-- Arrow-against-arrow constraints are decomposed into constraints on their
-- components. Every other solvable constraint is emitted as a substitution
-- and applied (via 'solve') to the remaining constraints. Two distinct
-- monomorphic types fail with an error.
solveAll :: [(Type, Type)] -> Infer [(Type, Type)]
solveAll [] = return []
solveAll (x : xs) = case x of
  (TArr t1 t2, TArr t3 t4) -> solveAll $ (t1, t3) : (t2, t4) : xs
  (TArr t1 t2, b) -> bind (b, TArr t1 t2)
  (a, TArr t1 t2) -> bind (a, TArr t1 t2)
  (TMono a, TPol b) -> bind (TPol b, TMono a)
  -- The variable is bound to the mono type it was constrained against
  -- (previously it was bound to a mono type named after the variable itself).
  (TPol a, TMono b) -> bind (TPol a, TMono b)
  (TMono a, TMono b) -> if a == b then solveAll xs else throwError "Can't unify types"
  (TPol a, TPol b) -> bind (TPol a, TPol b)
  where
    -- Emit a substitution and propagate it through the remaining constraints.
    bind s = (s :) <$> solveAll (solve s xs)
-- | Apply one substitution to the right-hand side of every constraint,
-- leaving the left-hand sides untouched.
solve :: (Type, Type) -> [(Type, Type)] -> [(Type, Type)]
solve x cs = [(lhs, replace x rhs) | ~(lhs, rhs) <- cs]
-- | Apply a single substitution @(from, to)@ to a type.
--
-- Arrow types are traversed component-wise; any other type is swapped for
-- @to@ when it equals @from@ and returned unchanged otherwise.
replace :: (Type, Type) -> Type -> Type
replace a ty = case ty of
  TArr dom cod -> TArr (replace a dom) (replace a cod)
  _
    | fst a == ty -> snd a
    | otherwise -> ty
-- Known bugs:
--   * The expression `(x : a) + 3` type checks, but it should be rejected.

102
src/TypeChecker/HMIr.hs Normal file
View file

@ -0,0 +1,102 @@
{-# LANGUAGE LambdaCase #-}
module TypeChecker.HMIr
( module Grammar.Abs
, module TypeChecker.HMIr
) where
import Grammar.Abs (Ident (..), Type (..))
import Grammar.Print
import Prelude
import qualified Prelude as C (Eq, Ord, Read, Show)
-- | A type-annotated program: a list of top-level bindings.
newtype Program
  = Program [Bind]
  deriving (C.Eq, C.Ord, C.Show, C.Read)
-- | Expressions in which every node carries a 'Type' as its first field.
data Exp
  = -- | Identifier reference.
    EId Type Ident
  | -- | Integer literal.
    EInt Type Integer
  | -- | @let name = e1 in e2@.
    ELet Type Ident Exp Exp
  | -- | Function application.
    EApp Type Exp Exp
  | -- | Addition.
    EAdd Type Exp Exp
  | -- | Lambda abstraction binding a single identifier.
    EAbs Type Ident Exp
  deriving (C.Eq, C.Ord, C.Show, C.Read)
-- | An identifier paired with its type; the type comes first.
type Id = (Type, Ident)

-- | A top-level binding: the typed binder, its typed parameters and the
-- right-hand side.
data Bind
  = Bind Id [Id] Exp
  deriving (C.Eq, C.Ord, C.Show, C.Read)
-- | A program is printed as its list of bindings.
instance Print Program where
  prt prec (Program binds) = prPrec prec 0 (prt 0 binds)
-- | A binding is printed as its typed binder, a semicolon, then the binder
-- name followed by the typed parameters, an equals sign and the right-hand
-- side.
instance Print Bind where
  -- 'Id' is @(Type, Ident)@, so the binder's name is the second component.
  prt i (Bind name@(_, n) parms rhs) = prPrec i 0 $ concatD
    [ prtId 0 name
    , doc $ showString ";"
    , prt 0 n
    , prtIdPs 0 parms
    , doc $ showString "="
    , prt 0 rhs
    ]
-- | Bindings are printed in order, separated by semicolons.
instance Print [Bind] where
  prt _ = \case
    [] -> concatD []
    [b] -> concatD [prt 0 b]
    b : rest -> concatD [prt 0 b, doc (showString ";"), prt 0 rest]
-- | Print a list of typed parameters, each rendered by 'prtIdP'.
prtIdPs :: Int -> [Id] -> Doc
prtIdPs i ids = prPrec i 0 (concatD [prtIdP i p | p <- ids])
-- | Print a typed identifier as @name : type@.
--
-- 'Id' is @(Type, Ident)@, so the type is the first component of the pair.
prtId :: Int -> Id -> Doc
prtId i (t, name) = prPrec i 0 $ concatD
  [ prt 0 name
  , doc $ showString ":"
  , prt 0 t
  ]
-- | Print a typed identifier in parentheses, as @(name : type)@.
--
-- 'Id' is @(Type, Ident)@, so the type is the first component of the pair.
prtIdP :: Int -> Id -> Doc
prtIdP i (t, name) = prPrec i 0 $ concatD
  [ doc $ showString "("
  , prt 0 name
  , doc $ showString ":"
  , prt 0 t
  , doc $ showString ")"
  ]
-- | Pretty-printer for annotated expressions. Applications, additions and
-- abstractions are prefixed with @\@@ followed by their type; identifiers,
-- literals and let expressions are printed without a type.
instance Print Exp where
  prt prec expr = case expr of
    EId _ ident -> prPrec prec 3 (concatD [prt 0 ident])
    EInt _ n -> prPrec prec 3 (concatD [prt 0 n])
    ELet _ ident bound body ->
      prPrec prec 3 (concatD [lit "let", prt 0 ident, prt 0 bound, lit "in", prt 0 body])
    EApp ty fun arg ->
      prPrec prec 2 (concatD [lit "@", prt 0 ty, prt 2 fun, prt 3 arg])
    EAdd ty lhs rhs ->
      prPrec prec 1 (concatD [lit "@", prt 0 ty, prt 1 lhs, lit "+", prt 2 rhs])
    EAbs ty ident body ->
      prPrec prec 0 (concatD [lit "@", prt 0 ty, lit "\\", prt 0 ident, lit ".", prt 0 body])
    where
      -- Render a fixed piece of concrete syntax.
      lit = doc . showString

View file

@ -1,153 +1,153 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
-- {-# LANGUAGE LambdaCase #-}
-- {-# LANGUAGE OverloadedRecordDot #-}
-- {-# LANGUAGE OverloadedStrings #-}
module TypeChecker.TypeChecker where
import Control.Monad (void)
import Control.Monad.Except (ExceptT, runExceptT, throwError)
import Control.Monad.State (StateT)
import qualified Control.Monad.State as St
import Data.Functor.Identity (Identity, runIdentity)
import Data.Map (Map)
import qualified Data.Map as M
-- import Control.Monad (void)
-- import Control.Monad.Except (ExceptT, runExceptT, throwError)
-- import Control.Monad.State (StateT)
-- import qualified Control.Monad.State as St
-- import Data.Functor.Identity (Identity, runIdentity)
-- import Data.Map (Map)
-- import qualified Data.Map as M
import TypeChecker.TypeCheckerIr
-- import TypeChecker.TypeCheckerIr
data Ctx = Ctx
{ vars :: Map Integer Type
, sigs :: Map Ident Type
, nextFresh :: Int
}
deriving (Show)
-- data Ctx = Ctx
-- { vars :: Map Integer Type
-- , sigs :: Map Ident Type
-- , nextFresh :: Int
-- }
-- deriving (Show)
-- Perhaps swap over to reader monad instead for vars and sigs.
type Infer = StateT Ctx (ExceptT Error Identity)
-- -- Perhaps swap over to reader monad instead for vars and sigs.
-- type Infer = StateT Ctx (ExceptT Error Identity)
{-
-- {-
The type checker will assume we first rename all variables to unique name, as to not
have to care about scoping. It significantly improves the quality of life of the
programmer.
-- The type checker will assume we first rename all variables to unique name, as to not
-- have to care about scoping. It significantly improves the quality of life of the
-- programmer.
TODOs:
Add skolemization variables. i.e
{ \x. 3 : forall a. a -> a }
should not type check
-- TODOs:
-- Add skolemization variables. i.e
-- { \x. 3 : forall a. a -> a }
-- should not type check
Generalize. Not really sure what that means though
-- Generalize. Not really sure what that means though
-}
-- -}
typecheck :: RProgram -> Either Error TProgram
typecheck = todo
-- typecheck :: RProgram -> Either Error TProgram
-- typecheck = todo
run :: Infer a -> Either Error a
run = runIdentity . runExceptT . flip St.evalStateT (Ctx mempty mempty 0)
-- run :: Infer a -> Either Error a
-- run = runIdentity . runExceptT . flip St.evalStateT (Ctx mempty mempty 0)
-- Have to figure out a way to coerce polymorphic types to monomorphic ones where necessary
-- { \x. \y. x + y } will have the type { a -> b -> Int }
inferExp :: RExp -> Infer Type
inferExp = \case
-- -- Have to figure out a way to coerce polymorphic types to monomorphic ones where necessary
-- -- { \x. \y. x + y } will have the type { a -> b -> Int }
-- inferExp :: RExp -> Infer Type
-- inferExp = \case
RAnn expr typ -> do
t <- inferExp expr
void $ t =:= typ
return t
-- RAnn expr typ -> do
-- t <- inferExp expr
-- void $ t =:= typ
-- return t
RBound num name -> lookupVars num
-- RBound num name -> lookupVars num
RFree name -> lookupSigs name
-- RFree name -> lookupSigs name
RConst (CInt i) -> return $ TMono "Int"
-- RConst (CInt i) -> return $ TMono "Int"
RConst (CStr str) -> return $ TMono "Str"
-- RConst (CStr str) -> return $ TMono "Str"
RAdd expr1 expr2 -> do
let int = TMono "Int"
typ1 <- check expr1 int
typ2 <- check expr2 int
return int
-- RAdd expr1 expr2 -> do
-- let int = TMono "Int"
-- typ1 <- check expr1 int
-- typ2 <- check expr2 int
-- return int
RApp expr1 expr2 -> do
fn_t <- inferExp expr1
arg_t <- inferExp expr2
res <- fresh
new_t <- fn_t =:= TArrow arg_t res
return res
-- RApp expr1 expr2 -> do
-- fn_t <- inferExp expr1
-- arg_t <- inferExp expr2
-- res <- fresh
-- new_t <- fn_t =:= TArrow arg_t res
-- return res
RAbs num name expr -> do
arg <- fresh
insertVars num arg
typ <- inferExp expr
return $ TArrow arg typ
-- RAbs num name expr -> do
-- arg <- fresh
-- insertVars num arg
-- typ <- inferExp expr
-- return $ TArrow arg typ
check :: RExp -> Type -> Infer ()
check e t = do
t' <- inferExp e
t =:= t'
return ()
-- check :: RExp -> Type -> Infer ()
-- check e t = do
-- t' <- inferExp e
-- t =:= t'
-- return ()
fresh :: Infer Type
fresh = do
var <- St.gets nextFresh
St.modify (\st -> st {nextFresh = succ var})
return (TPoly $ Ident (show var))
-- fresh :: Infer Type
-- fresh = do
-- var <- St.gets nextFresh
-- St.modify (\st -> st {nextFresh = succ var})
-- return (TPoly $ Ident (show var))
-- | Unify two types.
(=:=) :: Type -> Type -> Infer Type
(=:=) (TPoly _) b = return b
(=:=) a (TPoly _) = return a
(=:=) (TMono a) (TMono b) | a == b = return (TMono a)
(=:=) (TArrow a b) (TArrow c d) = do
t1 <- a =:= c
t2 <- b =:= d
return $ TArrow t1 t2
(=:=) a b = throwError (TypeMismatch $ unwords ["Can not unify type", show a, "with", show b])
-- -- | Unify two types.
-- (=:=) :: Type -> Type -> Infer Type
-- (=:=) (TPoly _) b = return b
-- (=:=) a (TPoly _) = return a
-- (=:=) (TMono a) (TMono b) | a == b = return (TMono a)
-- (=:=) (TArrow a b) (TArrow c d) = do
-- t1 <- a =:= c
-- t2 <- b =:= d
-- return $ TArrow t1 t2
-- (=:=) a b = throwError (TypeMismatch $ unwords ["Can not unify type", show a, "with", show b])
lookupVars :: Integer -> Infer Type
lookupVars i = do
st <- St.gets vars
case M.lookup i st of
Just t -> return t
Nothing -> throwError $ UnboundVar "lookupVars"
-- lookupVars :: Integer -> Infer Type
-- lookupVars i = do
-- st <- St.gets vars
-- case M.lookup i st of
-- Just t -> return t
-- Nothing -> throwError $ UnboundVar "lookupVars"
insertVars :: Integer -> Type -> Infer ()
insertVars i t = do
st <- St.get
St.put (st {vars = M.insert i t st.vars})
-- insertVars :: Integer -> Type -> Infer ()
-- insertVars i t = do
-- st <- St.get
-- St.put (st {vars = M.insert i t st.vars})
lookupSigs :: Ident -> Infer Type
lookupSigs i = do
st <- St.gets sigs
case M.lookup i st of
Just t -> return t
Nothing -> throwError $ UnboundVar "lookupSigs"
-- lookupSigs :: Ident -> Infer Type
-- lookupSigs i = do
-- st <- St.gets sigs
-- case M.lookup i st of
-- Just t -> return t
-- Nothing -> throwError $ UnboundVar "lookupSigs"
insertSigs :: Ident -> Type -> Infer ()
insertSigs i t = do
st <- St.get
St.put (st {sigs = M.insert i t st.sigs})
-- insertSigs :: Ident -> Type -> Infer ()
-- insertSigs i t = do
-- st <- St.get
-- St.put (st {sigs = M.insert i t st.sigs})
{-# WARNING todo "TODO IN CODE" #-}
todo :: a
todo = error "TODO in code"
-- {-# WARNING todo "TODO IN CODE" #-}
-- todo :: a
-- todo = error "TODO in code"
data Error
= TypeMismatch String
| NotNumber String
| FunctionTypeMismatch String
| NotFunction String
| UnboundVar String
| AnnotatedMismatch String
| Default String
deriving (Show)
-- data Error
-- = TypeMismatch String
-- | NotNumber String
-- | FunctionTypeMismatch String
-- | NotFunction String
-- | UnboundVar String
-- | AnnotatedMismatch String
-- | Default String
-- deriving (Show)
{-
-- {-
The procedure inst(σ) specializes the polytype
σ by copying the term and replacing the bound type variables
consistently by new monotype variables.
-- The procedure inst(σ) specializes the polytype
-- σ by copying the term and replacing the bound type variables
-- consistently by new monotype variables.
-}
-- -}

View file

@ -1,74 +1,74 @@
{-# LANGUAGE LambdaCase #-}
-- {-# LANGUAGE LambdaCase #-}
module TypeChecker.TypeCheckerIr (
TProgram (..),
TBind (..),
TExp (..),
RProgram (..),
RBind (..),
RExp (..),
Type (..),
Const (..),
Ident (..),
) where
module TypeChecker.TypeCheckerIr --(
-- TProgram (..),
-- TBind (..),
-- TExp (..),
-- RProgram (..),
-- RBind (..),
-- RExp (..),
-- Type (..),
-- Const (..),
-- Ident (..),
-- ) where
import Grammar.Print
import Renamer.RenamerIr
-- import Grammar.Print
-- import Renamer.RenamerIr
newtype TProgram = TProgram [TBind]
deriving (Eq, Show, Read, Ord)
-- newtype TProgram = TProgram [TBind]
-- deriving (Eq, Show, Read, Ord)
data TBind = TBind Ident Type TExp
deriving (Eq, Show, Read, Ord)
-- data TBind = TBind Ident Type TExp
-- deriving (Eq, Show, Read, Ord)
data TExp
= TAnn TExp Type
| TBound Integer Ident Type
| TFree Ident Type
| TConst Const Type
| TApp TExp TExp Type
| TAdd TExp TExp Type
| TAbs Integer Ident TExp Type
deriving (Eq, Ord, Show, Read)
-- data TExp
-- = TAnn TExp Type
-- | TBound Integer Ident Type
-- | TFree Ident Type
-- | TConst Const Type
-- | TApp TExp TExp Type
-- | TAdd TExp TExp Type
-- | TAbs Integer Ident TExp Type
-- deriving (Eq, Ord, Show, Read)
instance Print TProgram where
prt i = \case
TProgram defs -> prPrec i 0 (concatD [prt 0 defs])
-- instance Print TProgram where
-- prt i = \case
-- TProgram defs -> prPrec i 0 (concatD [prt 0 defs])
instance Print TBind where
prt i = \case
TBind x t e ->
prPrec i 0 $
concatD
[ prt 0 x
, doc (showString ":")
, prt 0 t
, doc (showString "=")
, prt 0 e
, doc (showString "\n")
]
-- instance Print TBind where
-- prt i = \case
-- TBind x t e ->
-- prPrec i 0 $
-- concatD
-- [ prt 0 x
-- , doc (showString ":")
-- , prt 0 t
-- , doc (showString "=")
-- , prt 0 e
-- , doc (showString "\n")
-- ]
instance Print TExp where
prt i = \case
TAnn e t ->
prPrec i 2 $
concatD
[ prt 0 e
, doc (showString ":")
, prt 1 t
]
TBound _ u t -> prPrec i 3 $ concatD [prt 0 u]
TFree u t -> prPrec i 3 $ concatD [prt 0 u]
TConst c _ -> prPrec i 3 (concatD [prt 0 c])
TApp e e1 t -> prPrec i 2 $ concatD [prt 2 e, prt 3 e1]
TAdd e e1 t -> prPrec i 1 $ concatD [prt 1 e, doc (showString "+"), prt 2 e1]
TAbs _ u e t ->
prPrec i 0 $
concatD
[ doc (showString "(")
, doc (showString "λ")
, prt 0 u
, doc (showString ".")
, prt 0 e
, doc (showString ")")
]
-- instance Print TExp where
-- prt i = \case
-- TAnn e t ->
-- prPrec i 2 $
-- concatD
-- [ prt 0 e
-- , doc (showString ":")
-- , prt 1 t
-- ]
-- TBound _ u t -> prPrec i 3 $ concatD [prt 0 u]
-- TFree u t -> prPrec i 3 $ concatD [prt 0 u]
-- TConst c _ -> prPrec i 3 (concatD [prt 0 c])
-- TApp e e1 t -> prPrec i 2 $ concatD [prt 2 e, prt 3 e1]
-- TAdd e e1 t -> prPrec i 1 $ concatD [prt 1 e, doc (showString "+"), prt 2 e1]
-- TAbs _ u e t ->
-- prPrec i 0 $
-- concatD
-- [ doc (showString "(")
-- , doc (showString "λ")
-- , prt 0 u
-- , doc (showString ".")
-- , prt 0 e
-- , doc (showString ")")
-- ]