Unification part works (probably). Have a hard time understanding it.

This commit is contained in:
sebastianselander 2023-02-17 18:42:50 +01:00
parent 764faa582b
commit f188cffb8d
7 changed files with 167 additions and 197 deletions

View file

@ -17,7 +17,7 @@ extra-source-files:
common warnings
ghc-options: -W
ghc-options: -Wdefault
executable language
import: warnings

View file

@ -2,14 +2,14 @@
module Main where
import Grammar.Par (myLexer, pProgram)
import Grammar.Par (myLexer, pProgram)
-- import TypeChecker.TypeChecker (typecheck)
import Grammar.Print (printTree)
import Renamer.Renamer (rename)
import System.Environment (getArgs)
import System.Exit (exitFailure, exitSuccess)
import TypeChecker.TypeChecker (typecheck)
import Grammar.Print (printTree)
import Renamer.Renamer (rename)
import System.Environment (getArgs)
import System.Exit (exitFailure, exitSuccess)
import TypeChecker.TypeChecker (typecheck)
main :: IO ()
main =
@ -46,5 +46,5 @@ main =
putStrLn ""
putStrLn " ----- TYPECHECKER ----- "
putStrLn ""
putStrLn . printTree $ prg
print prg
exitSuccess

View file

@ -1,12 +1,13 @@
{-# LANGUAGE LambdaCase #-}
module Renamer.RenamerIr ( RExp (..)
, RBind (..)
, RProgram (..)
, Const (..)
, Ident (..)
, Type (..)
) where
module Renamer.RenamerIr (
RExp (..),
RBind (..),
RProgram (..),
Const (..),
Ident (..),
Type (..),
) where
import Grammar.Abs (
Bind (..),
@ -51,9 +52,9 @@ instance Print RBind where
instance Print RExp where
prt i = \case
RAnn e t -> prPrec i 2 (concatD [prt 0 e, doc (showString ":"), prt 1 t])
RBound n _ -> prPrec i 3 (concatD [prt 0 ("var" ++ show n)])
RBound n _ -> prPrec i 3 (concatD [prt 0 n])
RFree id -> prPrec i 3 (concatD [prt 0 id])
RConst n -> prPrec i 3 (concatD [prt 0 n])
RApp e e1 -> prPrec i 2 (concatD [prt 2 e, prt 3 e1])
RAdd e e1 -> prPrec i 1 (concatD [prt 1 e, doc (showString "+"), prt 2 e1])
RAbs u id e -> prPrec i 0 (concatD [doc (showString "λ"), prt 0 ("var" ++ show u), doc (showString "."), prt 0 e])
RAbs u _ e -> prPrec i 0 (concatD [doc (showString "λ"), prt 0 u, doc (showString "."), prt 0 e])

View file

@ -1,21 +1,22 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE OverloadedStrings #-}
module TypeChecker.TypeChecker where
import Control.Monad (void)
import Control.Monad.Except (ExceptT, runExceptT, throwError)
import Control.Monad.State (StateT)
import Control.Monad.State qualified as St
import Data.Functor.Identity (Identity, runIdentity)
import Data.Map (Map)
import Data.Map qualified as M
import TypeChecker.TypeCheckerIr
import Control.Monad (void)
import Control.Monad.Except (ExceptT, runExceptT, throwError)
import Control.Monad.State (StateT)
import qualified Control.Monad.State as St
import Data.Functor.Identity (Identity, runIdentity)
import Data.Map (Map)
import qualified Data.Map as M
import TypeChecker.TypeCheckerIr
data Ctx = Ctx
{ vars :: Map Integer Type
, sigs :: Map Ident (RBind, Maybe Type)
{ vars :: Map Integer Type
, sigs :: Map Ident Type
, nextFresh :: Int
}
deriving (Show)
@ -38,70 +39,54 @@ TODOs:
-}
typecheck :: RProgram -> Either Error TProgram
typecheck = todo
run :: Infer a -> Either Error a
run = runIdentity . runExceptT . flip St.evalStateT (Ctx mempty mempty 0)
typecheck :: RProgram -> Either Error TProgram
typecheck = run . inferPrg
inferPrg :: RProgram -> Infer TProgram
inferPrg (RProgram xs) = do
xs' <- mapM inferBind xs
return $ TProgram xs'
-- Binds are not correctly added to the context.
-- Can't type check programs with more than one function currently
inferBind :: RBind -> Infer TBind
inferBind b@(RBind name e) = do
insertSigs name b Nothing
(t, e') <- inferExp e
return $ TBind name t e'
-- Have to figure out a way to coerce polymorphic types to monomorphic ones where necessary
-- { \x. \y. x + y } will have the type { a -> b -> Int }
inferExp :: RExp -> Infer (Type, TExp)
inferExp :: RExp -> Infer Type
inferExp = \case
RAnn expr typ -> do
(t, expr') <- inferExp expr
t <- inferExp expr
void $ t =:= typ
return (typ, expr')
RBound num name -> do
t <- lookupVars num
return (t, TBound num name t)
RFree name -> do
(b@(RBind name _), t) <- lookupSigs name
t' <- case t of
Nothing -> do
(TBind _ a _) <- inferBind b
insertSigs name b (Just a)
return a
Just a -> return a
return (t', TFree name t')
RConst (CInt i) -> return (TMono "Int", TConst (CInt i) (TMono "Int"))
RConst (CStr str) -> return (TMono "Str", TConst (CStr str) (TMono "Str"))
return t
RBound num name -> lookupVars num
RFree name -> lookupSigs name
RConst (CInt i) -> return $ TMono "Int"
RConst (CStr str) -> return $ TMono "Str"
RAdd expr1 expr2 -> do
(typ1, expr1') <- check expr1 (TMono "Int")
(_, expr2') <- check expr2 (TMono "Int")
return (typ1, TAdd expr1' expr2' typ1)
let int = TMono "Int"
typ1 <- check expr1 int
typ2 <- check expr2 int
return int
RApp expr1 expr2 -> do
(fn_t, expr1') <- inferExp expr1
(arg_t, expr2') <- inferExp expr2
fn_t <- inferExp expr1
arg_t <- inferExp expr2
res <- fresh
-- TODO: Double check if this is correct behavior.
-- It might be the case that we should return res, rather than new_t
new_t <- fn_t =:= TArrow arg_t res
return (new_t, TApp expr1' expr2' new_t)
return res
RAbs num name expr -> do
arg <- fresh
insertVars num arg
(typ, expr') <- inferExp expr
return (TArrow arg typ, TAbs num name expr' typ)
typ <- inferExp expr
return $ TArrow arg typ
check :: RExp -> Type -> Infer (Type, TExp)
check :: RExp -> Type -> Infer ()
check e t = do
(t', e') <- inferExp e
t'' <- t' =:= t
return (t'', e')
t' <- inferExp e
t =:= t'
return ()
fresh :: Infer Type
fresh = do
@ -120,30 +105,29 @@ fresh = do
return $ TArrow t1 t2
(=:=) a b = throwError (TypeMismatch $ unwords ["Can not unify type", show a, "with", show b])
-- Unused currently
lookupVars :: Integer -> Infer Type
lookupVars i = do
st <- St.gets vars
case M.lookup i st of
Just t -> return t
Nothing -> throwError $ UnboundVar "lookupVars"
Just t -> return t
Nothing -> throwError $ UnboundVar "lookupVars"
insertVars :: Integer -> Type -> Infer ()
insertVars i t = do
st <- St.get
St.put (st {vars = M.insert i t st.vars})
lookupSigs :: Ident -> Infer (RBind, Maybe Type)
lookupSigs :: Ident -> Infer Type
lookupSigs i = do
st <- St.gets sigs
case M.lookup i st of
Just t -> return t
Just t -> return t
Nothing -> throwError $ UnboundVar "lookupSigs"
insertSigs :: Ident -> RBind -> Maybe Type -> Infer ()
insertSigs i b t = do
insertSigs :: Ident -> Type -> Infer ()
insertSigs i t = do
st <- St.get
St.put (st {sigs = M.insert i (b, t) st.sigs})
St.put (st {sigs = M.insert i t st.sigs})
{-# WARNING todo "TODO IN CODE" #-}
todo :: a
@ -158,3 +142,12 @@ data Error
| AnnotatedMismatch String
| Default String
deriving (Show)
{-
The procedure inst(σ) specializes the polytype
σ by copying the term and replacing the bound type variables
consistently by new monotype variables.
-}

View file

@ -1,21 +1,21 @@
{-# LANGUAGE LambdaCase #-}
module TypeChecker.TypeCheckerIr ( TProgram(..)
, TBind(..)
, TExp(..)
, RProgram(..)
, RBind(..)
, RExp(..)
, Type(..)
, Const(..)
, Ident(..)
)
where
module TypeChecker.TypeCheckerIr (
TProgram (..),
TBind (..),
TExp (..),
RProgram (..),
RBind (..),
RExp (..),
Type (..),
Const (..),
Ident (..),
) where
import Renamer.RenamerIr
import Grammar.Print
import Grammar.Print
import Renamer.RenamerIr
data TProgram = TProgram [TBind]
newtype TProgram = TProgram [TBind]
deriving (Eq, Show, Read, Ord)
data TBind = TBind Ident Type TExp
@ -50,21 +50,25 @@ instance Print TBind where
instance Print TExp where
prt i = \case
TAnn e t -> prPrec i 2 $ concatD
[ prt 0 e
, doc (showString ":")
, prt 1 t
]
TBound _ u t -> prPrec i 3 $ concatD [ prt 0 u ]
TFree u t -> prPrec i 3 $ concatD [ prt 0 u ]
TAnn e t ->
prPrec i 2 $
concatD
[ prt 0 e
, doc (showString ":")
, prt 1 t
]
TBound _ u t -> prPrec i 3 $ concatD [prt 0 u]
TFree u t -> prPrec i 3 $ concatD [prt 0 u]
TConst c _ -> prPrec i 3 (concatD [prt 0 c])
TApp e e1 t -> prPrec i 2 $ concatD [ prt 2 e , prt 3 e1 ]
TAdd e e1 t -> prPrec i 1 $ concatD [ prt 1 e , doc (showString "+") , prt 2 e1 ]
TAbs _ u e t -> prPrec i 0 $ concatD
[ doc (showString "(")
, doc (showString "λ")
, prt 0 u
, doc (showString ".")
, prt 0 e
, doc (showString ")")
]
TApp e e1 t -> prPrec i 2 $ concatD [prt 2 e, prt 3 e1]
TAdd e e1 t -> prPrec i 1 $ concatD [prt 1 e, doc (showString "+"), prt 2 e1]
TAbs _ u e t ->
prPrec i 0 $
concatD
[ doc (showString "(")
, doc (showString "λ")
, prt 0 u
, doc (showString ".")
, prt 0 e
, doc (showString ")")
]

View file

@ -1,31 +1,31 @@
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PatternSynonyms #-}
module TypeChecker.Unification where
import Control.Arrow ((>>>))
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Control.Unification hiding (applyBindings, (=:=))
import Control.Unification qualified as U
import Control.Unification.IntVar
import Data.Foldable (fold)
import Data.Functor.Fixedpoint
import Data.Functor.Identity
import Data.Map (Map)
import Data.Map qualified as M
import Data.Maybe (fromJust, fromMaybe)
import Data.Set (Set, (\\))
import Data.Set qualified as S
import Debug.Trace (trace)
import GHC.Generics (Generic1)
import Renamer.Renamer
import Renamer.RenamerIr (Const (..), Ident (..), RBind (..), RExp (..), RProgram (..))
import Renamer.RenamerIr qualified as R
import Control.Arrow ((>>>))
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Control.Unification hiding (applyBindings, (=:=))
import qualified Control.Unification as U
import Control.Unification.IntVar
import Data.Foldable (fold)
import Data.Functor.Fixedpoint
import Data.Functor.Identity
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe (fromJust, fromMaybe)
import Data.Set (Set, (\\))
import qualified Data.Set as S
import Debug.Trace (trace)
import GHC.Generics (Generic1)
import Renamer.Renamer
import qualified Renamer.RenamerIr as R
import Renamer.RenamerIr (Const (..), Ident (..), RBind (..),
RExp (..), RProgram (..))
type Ctx = Map Ident UPolytype
@ -37,7 +37,7 @@ data TypeT a = TPolyT Ident | TMonoT Ident | TArrowT a a
instance Show a => Show (TypeT a) where
show (TPolyT (Ident i)) = i
show (TMonoT (Ident i)) = i
show (TArrowT a b) = show a ++ " -> " ++ show b
show (TArrowT a b) = show a ++ " -> " ++ show b
type Infer = StateT (Map Ident UPolytype) (ReaderT Ctx (ExceptT TypeError (IntBindingT TypeT Identity)))
@ -46,7 +46,10 @@ type Type = Fix TypeT
type UType = UTerm TypeT IntVar
data Poly t = Forall [Ident] t
deriving (Eq, Show, Functor)
deriving (Eq, Functor)
instance Show t => Show (Poly t) where
show (Forall is t) = unwords (map (\(Ident x) -> "forall " ++ x ++ ".") is) ++ " " ++ show t
type Polytype = Poly Type
@ -101,60 +104,17 @@ data TExp
----------------------------------------------------------
typecheck :: RProgram -> Either TypeError Program
typecheck = run . inferProgram
typecheck = undefined
inferProgram :: RProgram -> Infer Program
inferProgram (RProgram binds) = do
binds' <- mapM inferBind binds
return $ Program binds'
inferBind :: RBind -> Infer Bind
inferBind (RBind i e) = do
(t, e') <- infer e
e'' <- convert fromUType e'
t' <- fromUType t
insertSigs i (Forall [] t)
return $ Bind i e'' t'
fromUType :: UType -> Infer Polytype
fromUType = applyBindings >>> (>>= (generalize >>> fmap fromUPolytype))
convert :: (UType -> Infer Polytype) -> TExp -> Infer Exp
convert f = \case
(TAnn e t) -> do
e' <- convert f e
EAnn e' <$> f t
(TFree i t) -> do
t' <- f t
return $ EFree i t'
(TBound i t) -> do
t' <- f t
return $ EBound i t'
(TConst c t) -> do
t' <- f t
return $ EConst c t'
(TApp e1 e2 t) -> do
e1' <- convert f e1
e2' <- convert f e2
t' <- f t
return $ EApp e1' e2' t'
(TAdd e1 e2 t) -> do
e1' <- convert f e1
e2' <- convert f e2
t' <- f t
return $ EAdd e1' e2' t'
(TAbs i e t) -> do
e' <- convert f e
t' <- f t
return $ EAbs i e' t'
run :: Infer a -> Either TypeError a
run =
flip evalStateT mempty
>>> flip runReaderT mempty
>>> runExceptT
>>> evalIntBindingT
>>> runIdentity
run :: Infer (UType, TExp) -> Either TypeError Polytype
run = fmap fst
>>> (>>= applyBindings)
>>> (>>= (generalize >>> fmap fromUPolytype))
>>> flip evalStateT mempty
>>> flip runReaderT mempty
>>> runExceptT
>>> evalIntBindingT
>>> runIdentity
infer :: RExp -> Infer (UType, TExp)
infer = \case
@ -166,6 +126,7 @@ infer = \case
t1 =:= UTMono "Int"
t2 =:= UTMono "Int"
return (UTMono "Int", TAdd e1' e2' (UTMono "Int"))
-- type is not used, probably wrong
(RAnn e t) -> do
(t', e') <- infer e
check e t'
@ -180,7 +141,7 @@ infer = \case
arg <- fresh
withBinding i (Forall [] arg) $ do
(res, e') <- infer e
return $ (UTArrow arg res, TAbs i e' (UTArrow arg res))
return (UTArrow arg res, TAbs i e' (UTArrow arg res))
(RFree i) -> do
t <- lookupSigsT i
return (t, TFree i t)
@ -213,7 +174,7 @@ fromPolytype :: UPolytype -> UType
fromPolytype (Forall ids ut) = ut
ucata :: Functor t => (v -> a) -> (t a -> a) -> UTerm t v -> a
ucata f _ (UVar v) = f v
ucata f _ (UVar v) = f v
ucata f g (UTerm t) = g (fmap (ucata f g) t)
withBinding :: MonadReader Ctx m => Ident -> UPolytype -> m a -> m a
@ -277,6 +238,7 @@ skolemize (Forall xs uty) = do
mkVarName :: String -> IntVar -> Ident
mkVarName nm (IntVar v) = Ident $ nm ++ show (v + (maxBound :: Int) + 1)
-- | Used in let bindings to generalize functions declared there
generalize :: UType -> Infer UPolytype
generalize uty = do
uty' <- applyBindings uty
@ -289,3 +251,11 @@ generalize uty = do
fromUPolytype :: UPolytype -> Polytype
fromUPolytype = fmap (fromJust . freeze)
inf = RAbs 0 "x" (RApp (RBound 0 "x") (RBound 0 "x"))
one = RConst (CInt 1)
lambda = RAbs 0 "f" (RAbs 1 "x" (RApp (RBound 0 "f") (RBound 1 "x")))
fn = RAbs 0 "x" (RBound 0 "x")

View file

@ -1 +1,3 @@
apply = \x. \y. (x : Mono Int)
test = \x. (x : Mono String) ;
apply x y = x + y ;