Unification part works (probably). Have a hard time understanding it.

This commit is contained in:
sebastianselander 2023-02-17 18:42:50 +01:00
parent 764faa582b
commit f188cffb8d
7 changed files with 167 additions and 197 deletions

View file

@ -17,7 +17,7 @@ extra-source-files:
common warnings common warnings
ghc-options: -W ghc-options: -Wdefault
executable language executable language
import: warnings import: warnings

View file

@ -46,5 +46,5 @@ main =
putStrLn "" putStrLn ""
putStrLn " ----- TYPECHECKER ----- " putStrLn " ----- TYPECHECKER ----- "
putStrLn "" putStrLn ""
putStrLn . printTree $ prg print prg
exitSuccess exitSuccess

View file

@ -1,12 +1,13 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
module Renamer.RenamerIr ( RExp (..) module Renamer.RenamerIr (
, RBind (..) RExp (..),
, RProgram (..) RBind (..),
, Const (..) RProgram (..),
, Ident (..) Const (..),
, Type (..) Ident (..),
) where Type (..),
) where
import Grammar.Abs ( import Grammar.Abs (
Bind (..), Bind (..),
@ -51,9 +52,9 @@ instance Print RBind where
instance Print RExp where instance Print RExp where
prt i = \case prt i = \case
RAnn e t -> prPrec i 2 (concatD [prt 0 e, doc (showString ":"), prt 1 t]) RAnn e t -> prPrec i 2 (concatD [prt 0 e, doc (showString ":"), prt 1 t])
RBound n _ -> prPrec i 3 (concatD [prt 0 ("var" ++ show n)]) RBound n _ -> prPrec i 3 (concatD [prt 0 n])
RFree id -> prPrec i 3 (concatD [prt 0 id]) RFree id -> prPrec i 3 (concatD [prt 0 id])
RConst n -> prPrec i 3 (concatD [prt 0 n]) RConst n -> prPrec i 3 (concatD [prt 0 n])
RApp e e1 -> prPrec i 2 (concatD [prt 2 e, prt 3 e1]) RApp e e1 -> prPrec i 2 (concatD [prt 2 e, prt 3 e1])
RAdd e e1 -> prPrec i 1 (concatD [prt 1 e, doc (showString "+"), prt 2 e1]) RAdd e e1 -> prPrec i 1 (concatD [prt 1 e, doc (showString "+"), prt 2 e1])
RAbs u id e -> prPrec i 0 (concatD [doc (showString "λ"), prt 0 ("var" ++ show u), doc (showString "."), prt 0 e]) RAbs u _ e -> prPrec i 0 (concatD [doc (showString "λ"), prt 0 u, doc (showString "."), prt 0 e])

View file

@ -7,15 +7,16 @@ module TypeChecker.TypeChecker where
import Control.Monad (void) import Control.Monad (void)
import Control.Monad.Except (ExceptT, runExceptT, throwError) import Control.Monad.Except (ExceptT, runExceptT, throwError)
import Control.Monad.State (StateT) import Control.Monad.State (StateT)
import Control.Monad.State qualified as St import qualified Control.Monad.State as St
import Data.Functor.Identity (Identity, runIdentity) import Data.Functor.Identity (Identity, runIdentity)
import Data.Map (Map) import Data.Map (Map)
import Data.Map qualified as M import qualified Data.Map as M
import TypeChecker.TypeCheckerIr import TypeChecker.TypeCheckerIr
data Ctx = Ctx data Ctx = Ctx
{ vars :: Map Integer Type { vars :: Map Integer Type
, sigs :: Map Ident (RBind, Maybe Type) , sigs :: Map Ident Type
, nextFresh :: Int , nextFresh :: Int
} }
deriving (Show) deriving (Show)
@ -38,70 +39,54 @@ TODOs:
-} -}
typecheck :: RProgram -> Either Error TProgram
typecheck = todo
run :: Infer a -> Either Error a run :: Infer a -> Either Error a
run = runIdentity . runExceptT . flip St.evalStateT (Ctx mempty mempty 0) run = runIdentity . runExceptT . flip St.evalStateT (Ctx mempty mempty 0)
typecheck :: RProgram -> Either Error TProgram
typecheck = run . inferPrg
inferPrg :: RProgram -> Infer TProgram
inferPrg (RProgram xs) = do
xs' <- mapM inferBind xs
return $ TProgram xs'
-- Binds are not correctly added to the context.
-- Can't type check programs with more than one function currently
inferBind :: RBind -> Infer TBind
inferBind b@(RBind name e) = do
insertSigs name b Nothing
(t, e') <- inferExp e
return $ TBind name t e'
-- Have to figure out a way to coerce polymorphic types to monomorphic ones where necessary -- Have to figure out a way to coerce polymorphic types to monomorphic ones where necessary
-- { \x. \y. x + y } will have the type { a -> b -> Int } -- { \x. \y. x + y } will have the type { a -> b -> Int }
inferExp :: RExp -> Infer (Type, TExp) inferExp :: RExp -> Infer Type
inferExp = \case inferExp = \case
RAnn expr typ -> do RAnn expr typ -> do
(t, expr') <- inferExp expr t <- inferExp expr
void $ t =:= typ void $ t =:= typ
return (typ, expr') return t
RBound num name -> do
t <- lookupVars num RBound num name -> lookupVars num
return (t, TBound num name t)
RFree name -> do RFree name -> lookupSigs name
(b@(RBind name _), t) <- lookupSigs name
t' <- case t of RConst (CInt i) -> return $ TMono "Int"
Nothing -> do
(TBind _ a _) <- inferBind b RConst (CStr str) -> return $ TMono "Str"
insertSigs name b (Just a)
return a
Just a -> return a
return (t', TFree name t')
RConst (CInt i) -> return (TMono "Int", TConst (CInt i) (TMono "Int"))
RConst (CStr str) -> return (TMono "Str", TConst (CStr str) (TMono "Str"))
RAdd expr1 expr2 -> do RAdd expr1 expr2 -> do
(typ1, expr1') <- check expr1 (TMono "Int") let int = TMono "Int"
(_, expr2') <- check expr2 (TMono "Int") typ1 <- check expr1 int
return (typ1, TAdd expr1' expr2' typ1) typ2 <- check expr2 int
return int
RApp expr1 expr2 -> do RApp expr1 expr2 -> do
(fn_t, expr1') <- inferExp expr1 fn_t <- inferExp expr1
(arg_t, expr2') <- inferExp expr2 arg_t <- inferExp expr2
res <- fresh res <- fresh
-- TODO: Double check if this is correct behavior.
-- It might be the case that we should return res, rather than new_t
new_t <- fn_t =:= TArrow arg_t res new_t <- fn_t =:= TArrow arg_t res
return (new_t, TApp expr1' expr2' new_t) return res
RAbs num name expr -> do RAbs num name expr -> do
arg <- fresh arg <- fresh
insertVars num arg insertVars num arg
(typ, expr') <- inferExp expr typ <- inferExp expr
return (TArrow arg typ, TAbs num name expr' typ) return $ TArrow arg typ
check :: RExp -> Type -> Infer (Type, TExp) check :: RExp -> Type -> Infer ()
check e t = do check e t = do
(t', e') <- inferExp e t' <- inferExp e
t'' <- t' =:= t t =:= t'
return (t'', e') return ()
fresh :: Infer Type fresh :: Infer Type
fresh = do fresh = do
@ -120,7 +105,6 @@ fresh = do
return $ TArrow t1 t2 return $ TArrow t1 t2
(=:=) a b = throwError (TypeMismatch $ unwords ["Can not unify type", show a, "with", show b]) (=:=) a b = throwError (TypeMismatch $ unwords ["Can not unify type", show a, "with", show b])
-- Unused currently
lookupVars :: Integer -> Infer Type lookupVars :: Integer -> Infer Type
lookupVars i = do lookupVars i = do
st <- St.gets vars st <- St.gets vars
@ -133,17 +117,17 @@ insertVars i t = do
st <- St.get st <- St.get
St.put (st {vars = M.insert i t st.vars}) St.put (st {vars = M.insert i t st.vars})
lookupSigs :: Ident -> Infer (RBind, Maybe Type) lookupSigs :: Ident -> Infer Type
lookupSigs i = do lookupSigs i = do
st <- St.gets sigs st <- St.gets sigs
case M.lookup i st of case M.lookup i st of
Just t -> return t Just t -> return t
Nothing -> throwError $ UnboundVar "lookupSigs" Nothing -> throwError $ UnboundVar "lookupSigs"
insertSigs :: Ident -> RBind -> Maybe Type -> Infer () insertSigs :: Ident -> Type -> Infer ()
insertSigs i b t = do insertSigs i t = do
st <- St.get st <- St.get
St.put (st {sigs = M.insert i (b, t) st.sigs}) St.put (st {sigs = M.insert i t st.sigs})
{-# WARNING todo "TODO IN CODE" #-} {-# WARNING todo "TODO IN CODE" #-}
todo :: a todo :: a
@ -158,3 +142,12 @@ data Error
| AnnotatedMismatch String | AnnotatedMismatch String
| Default String | Default String
deriving (Show) deriving (Show)
{-
The procedure inst(σ) specializes the polytype
σ by copying the term and replacing the bound type variables
consistently by new monotype variables.
-}

View file

@ -1,21 +1,21 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
module TypeChecker.TypeCheckerIr ( TProgram(..) module TypeChecker.TypeCheckerIr (
, TBind(..) TProgram (..),
, TExp(..) TBind (..),
, RProgram(..) TExp (..),
, RBind(..) RProgram (..),
, RExp(..) RBind (..),
, Type(..) RExp (..),
, Const(..) Type (..),
, Ident(..) Const (..),
) Ident (..),
where ) where
import Renamer.RenamerIr
import Grammar.Print import Grammar.Print
import Renamer.RenamerIr
data TProgram = TProgram [TBind] newtype TProgram = TProgram [TBind]
deriving (Eq, Show, Read, Ord) deriving (Eq, Show, Read, Ord)
data TBind = TBind Ident Type TExp data TBind = TBind Ident Type TExp
@ -50,17 +50,21 @@ instance Print TBind where
instance Print TExp where instance Print TExp where
prt i = \case prt i = \case
TAnn e t -> prPrec i 2 $ concatD TAnn e t ->
prPrec i 2 $
concatD
[ prt 0 e [ prt 0 e
, doc (showString ":") , doc (showString ":")
, prt 1 t , prt 1 t
] ]
TBound _ u t -> prPrec i 3 $ concatD [ prt 0 u ] TBound _ u t -> prPrec i 3 $ concatD [prt 0 u]
TFree u t -> prPrec i 3 $ concatD [ prt 0 u ] TFree u t -> prPrec i 3 $ concatD [prt 0 u]
TConst c _ -> prPrec i 3 (concatD [prt 0 c]) TConst c _ -> prPrec i 3 (concatD [prt 0 c])
TApp e e1 t -> prPrec i 2 $ concatD [ prt 2 e , prt 3 e1 ] TApp e e1 t -> prPrec i 2 $ concatD [prt 2 e, prt 3 e1]
TAdd e e1 t -> prPrec i 1 $ concatD [ prt 1 e , doc (showString "+") , prt 2 e1 ] TAdd e e1 t -> prPrec i 1 $ concatD [prt 1 e, doc (showString "+"), prt 2 e1]
TAbs _ u e t -> prPrec i 0 $ concatD TAbs _ u e t ->
prPrec i 0 $
concatD
[ doc (showString "(") [ doc (showString "(")
, doc (showString "λ") , doc (showString "λ")
, prt 0 u , prt 0 u

View file

@ -1,5 +1,4 @@
{-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PatternSynonyms #-}
@ -11,21 +10,22 @@ import Control.Monad.Except
import Control.Monad.Reader import Control.Monad.Reader
import Control.Monad.State import Control.Monad.State
import Control.Unification hiding (applyBindings, (=:=)) import Control.Unification hiding (applyBindings, (=:=))
import Control.Unification qualified as U import qualified Control.Unification as U
import Control.Unification.IntVar import Control.Unification.IntVar
import Data.Foldable (fold) import Data.Foldable (fold)
import Data.Functor.Fixedpoint import Data.Functor.Fixedpoint
import Data.Functor.Identity import Data.Functor.Identity
import Data.Map (Map) import Data.Map (Map)
import Data.Map qualified as M import qualified Data.Map as M
import Data.Maybe (fromJust, fromMaybe) import Data.Maybe (fromJust, fromMaybe)
import Data.Set (Set, (\\)) import Data.Set (Set, (\\))
import Data.Set qualified as S import qualified Data.Set as S
import Debug.Trace (trace) import Debug.Trace (trace)
import GHC.Generics (Generic1) import GHC.Generics (Generic1)
import Renamer.Renamer import Renamer.Renamer
import Renamer.RenamerIr (Const (..), Ident (..), RBind (..), RExp (..), RProgram (..)) import qualified Renamer.RenamerIr as R
import Renamer.RenamerIr qualified as R import Renamer.RenamerIr (Const (..), Ident (..), RBind (..),
RExp (..), RProgram (..))
type Ctx = Map Ident UPolytype type Ctx = Map Ident UPolytype
@ -46,7 +46,10 @@ type Type = Fix TypeT
type UType = UTerm TypeT IntVar type UType = UTerm TypeT IntVar
data Poly t = Forall [Ident] t data Poly t = Forall [Ident] t
deriving (Eq, Show, Functor) deriving (Eq, Functor)
instance Show t => Show (Poly t) where
show (Forall is t) = unwords (map (\(Ident x) -> "forall " ++ x ++ ".") is) ++ " " ++ show t
type Polytype = Poly Type type Polytype = Poly Type
@ -101,56 +104,13 @@ data TExp
---------------------------------------------------------- ----------------------------------------------------------
typecheck :: RProgram -> Either TypeError Program typecheck :: RProgram -> Either TypeError Program
typecheck = run . inferProgram typecheck = undefined
inferProgram :: RProgram -> Infer Program run :: Infer (UType, TExp) -> Either TypeError Polytype
inferProgram (RProgram binds) = do run = fmap fst
binds' <- mapM inferBind binds >>> (>>= applyBindings)
return $ Program binds' >>> (>>= (generalize >>> fmap fromUPolytype))
>>> flip evalStateT mempty
inferBind :: RBind -> Infer Bind
inferBind (RBind i e) = do
(t, e') <- infer e
e'' <- convert fromUType e'
t' <- fromUType t
insertSigs i (Forall [] t)
return $ Bind i e'' t'
fromUType :: UType -> Infer Polytype
fromUType = applyBindings >>> (>>= (generalize >>> fmap fromUPolytype))
convert :: (UType -> Infer Polytype) -> TExp -> Infer Exp
convert f = \case
(TAnn e t) -> do
e' <- convert f e
EAnn e' <$> f t
(TFree i t) -> do
t' <- f t
return $ EFree i t'
(TBound i t) -> do
t' <- f t
return $ EBound i t'
(TConst c t) -> do
t' <- f t
return $ EConst c t'
(TApp e1 e2 t) -> do
e1' <- convert f e1
e2' <- convert f e2
t' <- f t
return $ EApp e1' e2' t'
(TAdd e1 e2 t) -> do
e1' <- convert f e1
e2' <- convert f e2
t' <- f t
return $ EAdd e1' e2' t'
(TAbs i e t) -> do
e' <- convert f e
t' <- f t
return $ EAbs i e' t'
run :: Infer a -> Either TypeError a
run =
flip evalStateT mempty
>>> flip runReaderT mempty >>> flip runReaderT mempty
>>> runExceptT >>> runExceptT
>>> evalIntBindingT >>> evalIntBindingT
@ -166,6 +126,7 @@ infer = \case
t1 =:= UTMono "Int" t1 =:= UTMono "Int"
t2 =:= UTMono "Int" t2 =:= UTMono "Int"
return (UTMono "Int", TAdd e1' e2' (UTMono "Int")) return (UTMono "Int", TAdd e1' e2' (UTMono "Int"))
-- type is not used, probably wrong
(RAnn e t) -> do (RAnn e t) -> do
(t', e') <- infer e (t', e') <- infer e
check e t' check e t'
@ -180,7 +141,7 @@ infer = \case
arg <- fresh arg <- fresh
withBinding i (Forall [] arg) $ do withBinding i (Forall [] arg) $ do
(res, e') <- infer e (res, e') <- infer e
return $ (UTArrow arg res, TAbs i e' (UTArrow arg res)) return (UTArrow arg res, TAbs i e' (UTArrow arg res))
(RFree i) -> do (RFree i) -> do
t <- lookupSigsT i t <- lookupSigsT i
return (t, TFree i t) return (t, TFree i t)
@ -277,6 +238,7 @@ skolemize (Forall xs uty) = do
mkVarName :: String -> IntVar -> Ident mkVarName :: String -> IntVar -> Ident
mkVarName nm (IntVar v) = Ident $ nm ++ show (v + (maxBound :: Int) + 1) mkVarName nm (IntVar v) = Ident $ nm ++ show (v + (maxBound :: Int) + 1)
-- | Used in let bindings to generalize functions declared there
generalize :: UType -> Infer UPolytype generalize :: UType -> Infer UPolytype
generalize uty = do generalize uty = do
uty' <- applyBindings uty uty' <- applyBindings uty
@ -289,3 +251,11 @@ generalize uty = do
fromUPolytype :: UPolytype -> Polytype fromUPolytype :: UPolytype -> Polytype
fromUPolytype = fmap (fromJust . freeze) fromUPolytype = fmap (fromJust . freeze)
inf = RAbs 0 "x" (RApp (RBound 0 "x") (RBound 0 "x"))
one = RConst (CInt 1)
lambda = RAbs 0 "f" (RAbs 1 "x" (RApp (RBound 0 "f") (RBound 1 "x")))
fn = RAbs 0 "x" (RBound 0 "x")

View file

@ -1 +1,3 @@
apply = \x. \y. (x : Mono Int) test = \x. (x : Mono String) ;
apply x y = x + y ;