Simplified quite a bit. Made a unify function. Still bugs left

This commit is contained in:
sebastianselander 2023-02-17 11:09:48 +01:00
parent eafe0fea0b
commit a9f54dbca1
4 changed files with 192 additions and 209 deletions

View file

@ -3,12 +3,12 @@ function-arrows: trailing
comma-style: leading comma-style: leading
import-export-style: diff-friendly import-export-style: diff-friendly
indent-wheres: false indent-wheres: false
record-brace-space: false record-brace-space: true
newlines-between-decls: 1 newlines-between-decls: 1
haddock-style: multi-line haddock-style: multi-line
haddock-style-module: haddock-style-module:
let-style: auto let-style: auto
in-style: right-align in-style: right-align
respectful: true respectful: false
fixities: [] fixities: []
unicode: never unicode: never

View file

@ -17,7 +17,7 @@ extra-source-files:
common warnings common warnings
ghc-options: -Wdefault ghc-options: -W
executable language executable language
import: warnings import: warnings

View file

@ -1,26 +1,30 @@
{-# LANGUAGE LambdaCase, OverloadedStrings, OverloadedRecordDot #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use camelCase" #-}
module TypeChecker.TypeChecker where module TypeChecker.TypeChecker where
import Control.Monad (when, void) import Control.Monad (void)
import Control.Monad.Except (ExceptT, throwError, runExceptT) import Control.Monad.Except (ExceptT, runExceptT, throwError)
import Control.Monad.State (StateT) import Control.Monad.State (StateT)
import qualified Control.Monad.State as St import Control.Monad.State qualified as St
import Data.Functor.Identity (Identity, runIdentity) import Data.Functor.Identity (Identity, runIdentity)
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as M import Data.Map qualified as M
import Grammar.ErrM (Err)
import Grammar.Print
import Data.List (findIndex)
import Debug.Trace (trace)
import TypeChecker.TypeCheckerIr import TypeChecker.TypeCheckerIr
data Ctx = Ctx { vars :: Map Integer Type data Ctx = Ctx
, sigs :: Map Ident Type { vars :: Map Integer Type
, nextFresh :: Ident , sigs :: Map Ident Type
} , nextFresh :: Int
deriving Show }
deriving (Show)
-- Perhaps swap over to reader monad instead for vars and sigs.
type Infer = StateT Ctx (ExceptT Error Identity)
{- {-
@ -28,18 +32,20 @@ The type checker will assume we first rename all variables to unique name, as to
have to care about scoping. It significantly improves the quality of life of the have to care about scoping. It significantly improves the quality of life of the
programmer. programmer.
TODOs:
Add skolemization variables. i.e
{ \x. 3 : forall a. a -> a }
should not type check
Generalize. Not really sure what that means though
-} -}
type Infer = StateT Ctx (ExceptT Error Identity)
initEnv :: Ctx
initEnv = Ctx mempty mempty "a"
run :: Infer a -> Either Error a run :: Infer a -> Either Error a
run = runIdentity . runExceptT . flip St.evalStateT initEnv run = runIdentity . runExceptT . flip St.evalStateT (Ctx mempty mempty 0)
typecheck :: RProgram -> Either Error TProgram typecheck :: RProgram -> Either Error TProgram
typecheck = run . inferPrg typecheck = run . inferPrg
inferPrg :: RProgram -> Infer TProgram inferPrg :: RProgram -> Infer TProgram
inferPrg (RProgram xs) = do inferPrg (RProgram xs) = do
@ -54,122 +60,83 @@ inferBind (RBind name e) = do
inferExp :: RExp -> Infer (Type, TExp) inferExp :: RExp -> Infer (Type, TExp)
inferExp = \case inferExp = \case
RAnn expr typ -> do RAnn expr typ -> do
(t,expr') <- inferExp expr (t, expr') <- inferExp expr
when (not (t == typ || isPoly t)) (throwError $ AnnotatedMismatch "inferExp, RAnn") void $ t =:= typ
return (typ,expr') return (typ, expr')
RBound num name -> do
-- Name is only here for proper error messages t <- lookupVars num
RBound num name -> return (t, TBound num name t)
M.lookup num <$> St.gets vars >>= \case
Nothing -> throwError $ UnboundVar "RBound"
Just t -> return (t, TBound num name t)
RFree name -> do RFree name -> do
M.lookup name <$> St.gets sigs >>= \case t <- lookupSigs name
Nothing -> throwError $ UnboundVar "RFree" return (t, TFree name t)
Just t -> return (t, TFree name t) RConst (CInt i) -> return (TMono "Int", TConst (CInt i) (TMono "Int"))
RConst (CStr str) -> return (TMono "Str", TConst (CStr str) (TMono "Str"))
RConst (CInt i) -> return $ (TMono "Int", TConst (CInt i) (TMono "Int"))
RConst (CStr str) -> return $ (TMono "Str", TConst (CStr str) (TMono "Str"))
-- Should do proper unification using union-find. Some nice libs exist
RApp expr1 expr2 -> do
(typ1, expr1') <- inferExp expr1
(typ2, expr2') <- inferExp expr2
fvar <- fresh
case typ1 of
(TPoly (Ident x)) -> do
let newType = (TArrow (TPoly (Ident x)) (TPoly fvar))
specifyType expr1 newType
typ1' <- apply newType typ1
return $ (typ1', TApp expr1' expr2' typ1')
_ -> (\t -> (t, TApp expr1' expr2' t)) <$> apply typ1 typ2
RAdd expr1 expr2 -> do RAdd expr1 expr2 -> do
(typ1, expr1') <- inferExp expr1 (typ1, expr1') <- check expr1 (TMono "Int")
(typ2, expr2') <- inferExp expr2 (_, expr2') <- check expr2 (TMono "Int")
when (not $ (isInt typ1 || isPoly typ1) && (isInt typ2 || isPoly typ2)) (throwError $ TypeMismatch "inferExp, RAdd") return (typ1, TAdd expr1' expr2' typ1)
specifyType expr1 (TMono "Int") RApp expr1 expr2 -> do
specifyType expr2 (TMono "Int") (fn_t, expr1') <- inferExp expr1
return (TMono "Int", TAdd expr1' expr2' (TMono "Int")) (arg_t, expr2') <- inferExp expr2
res <- fresh
-- TODO: Double check if this is correct behavior.
-- It might be the case that we should return res, rather than new_t
new_t <- fn_t =:= TArrow arg_t res
return (new_t, TApp expr1' expr2' new_t)
RAbs num name expr -> do RAbs num name expr -> do
insertVars num (TPoly "a") arg <- fresh
insertVars num arg
(typ, expr') <- inferExp expr (typ, expr') <- inferExp expr
newTyp <- lookupVars num return (TArrow arg typ, TAbs num name expr' typ)
return $ (TArrow newTyp typ, TAbs num name expr' typ)
-- Aux check :: RExp -> Type -> Infer (Type, TExp)
isInt :: Type -> Bool check e t = do
isInt (TMono "Int") = True (t', e') <- inferExp e
isInt _ = False t'' <- t' =:= t
return (t'', e')
isArrow :: Type -> Bool fresh :: Infer Type
isArrow (TArrow _ _) = True
isArrow _ = False
isPoly :: Type -> Bool
isPoly (TPoly _) = True
isPoly _ = False
fresh :: Infer Ident
fresh = do fresh = do
(Ident var) <- St.gets nextFresh var <- St.gets nextFresh
when (length var == 0) (throwError $ Default "fresh") St.modify (\st -> st {nextFresh = succ var})
index <- case findIndex (== (head var)) alphabet of return (TPoly $ Ident (show var))
Nothing -> throwError $ Default "fresh"
Just i -> return i
let nextIndex = (index + 1) `mod` 26
let newVar = Ident $ [alphabet !! nextIndex]
St.modify (\st -> st { nextFresh = newVar })
return newVar
where
alphabet = "abcdefghijklmnopqrstuvwxyz" :: [Char]
unify :: Type -> Type -> Infer Type -- | Unify two types.
unify = todo (=:=) :: Type -> Type -> Infer Type
(=:=) (TPoly _) b = return b
-- | Specify the type of a bound variable (=:=) a (TPoly _) = return a
-- Because in lambdas we have to assume a general type and update it (=:=) (TMono a) (TMono b) | a == b = return (TMono a)
specifyType :: RExp -> Type -> Infer () (=:=) (TArrow a b) (TArrow c d) = do
specifyType (RBound num name) typ = do t1 <- a =:= c
insertVars num typ t2 <- b =:= d
return () return $ TArrow t1 t2
specifyType _ _ = return () (=:=) a b = throwError (TypeMismatch $ unwords ["Can not unify type", show a, "with", show b])
-- Unused currently
lookupVars :: Integer -> Infer Type lookupVars :: Integer -> Infer Type
lookupVars i = do lookupVars i = do
st <- St.gets vars st <- St.gets vars
case M.lookup i st of case M.lookup i st of
Just t -> return t Just t -> return t
Nothing -> throwError $ UnboundVar "lookupVars" Nothing -> throwError $ UnboundVar "lookupVars"
insertVars :: Integer -> Type -> Infer () insertVars :: Integer -> Type -> Infer ()
insertVars i t = do insertVars i t = do
st <- St.get st <- St.get
St.put ( st { vars = M.insert i t st.vars } ) St.put (st {vars = M.insert i t st.vars})
lookupSigs :: Ident -> Infer Type lookupSigs :: Ident -> Infer Type
lookupSigs i = do lookupSigs i = do
st <- St.gets sigs st <- St.gets sigs
case M.lookup i st of case M.lookup i st of
Just t -> return t Just t -> return t
Nothing -> throwError $ UnboundVar "lookupSigs" Nothing -> throwError $ UnboundVar "lookupSigs"
insertSigs :: Ident -> Type -> Infer () insertSigs :: Ident -> Type -> Infer ()
insertSigs i t = do insertSigs i t = do
st <- St.get st <- St.get
St.put ( st { sigs = M.insert i t st.sigs } ) St.put (st {sigs = M.insert i t st.sigs})
-- Have to figure out the equivalence classes for types.
-- Currently this does not support more than exact matches.
apply :: Type -> Type -> Infer Type
apply (TArrow t1 t2) t3
| t1 == t3 = return t2
apply t1 t2 = throwError $ TypeMismatch "apply"
{-# WARNING todo "TODO IN CODE" #-} {-# WARNING todo "TODO IN CODE" #-}
todo :: a todo :: a
@ -183,21 +150,30 @@ data Error
| UnboundVar String | UnboundVar String
| AnnotatedMismatch String | AnnotatedMismatch String
| Default String | Default String
deriving Show deriving (Show)
-- Tests -- Tests
lambda = RAbs 0 "x" (RAdd (RBound 0 "x") (RBound 0 "x")) -- (\x. x + 1) 1
lambda2 = RAbs 0 "x" (RAnn (RBound 0 "x") (TArrow (TMono "Int") (TMono "String"))) app_lambda :: RExp
app_lambda = app lambda one
fn_on_var = RAbs 0 (Ident "f") (RAbs 1 (Ident "x") (RApp (RBound 0 (Ident "f")) (RBound 1 (Ident "x")))) lambda :: RExp
lambda = RAbs 0 "x" $ add bound one
add :: RExp -> RExp -> RExp
add = RAdd
--add x = \y. x+y; bound = RBound 0 "x"
add = RAbs 0 "x" (RAbs 1 "y" (RAdd (RBound 0 "x") (RBound 1 "y")))
-- main = (\z. z+z) ((add 4) 6); app :: RExp -> RExp -> RExp
main = RApp (RAbs 0 "z" (RAdd (RBound 0 "z") (RBound 0 "z"))) applyAdd app = RApp
four = RConst (CInt 4)
six = RConst (CInt 6) one :: RExp
applyAdd = (RApp (RApp add four) six) one = RConst (CInt 1)
partialAdd = RApp add four
fn_t = TArrow (TPoly (Ident "0")) (TMono (Ident "Int"))
arr_t = TArrow (TMono "Int") (TPoly "1")
f_x = RAbs 0 "f" (RAbs 1 "x" (RApp (RBound 0 "f") (RBound 1 "x")))

View file

@ -1,28 +1,31 @@
{-# LANGUAGE DeriveAnyClass, PatternSynonyms, GADTs, LambdaCase, OverloadedStrings #-} {-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
module TypeChecker.Unification where module TypeChecker.Unification where
import Renamer.Renamer import Control.Arrow ((>>>))
import Renamer.RenamerIr (Const(..), RExp(..), RBind(..), RProgram(..), Ident(..)) import Control.Monad.Except
import qualified Renamer.RenamerIr as R import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Reader import Control.Unification hiding (applyBindings, (=:=))
import Control.Monad.State import Control.Unification qualified as U
import Control.Monad.Except import Control.Unification.IntVar
import Data.Functor.Identity import Data.Foldable (fold)
import Control.Arrow ((>>>)) import Data.Functor.Fixedpoint
import Control.Unification hiding ((=:=), applyBindings) import Data.Functor.Identity
import qualified Control.Unification as U import Data.Map (Map)
import Control.Unification.IntVar import Data.Map qualified as M
import Data.Functor.Fixedpoint import Data.Maybe (fromJust, fromMaybe)
import GHC.Generics (Generic1) import Data.Set (Set, (\\))
import Data.Foldable (fold) import Data.Set qualified as S
import Data.Map (Map) import Debug.Trace (trace)
import qualified Data.Map as M import GHC.Generics (Generic1)
import Data.Maybe (fromMaybe, fromJust) import Renamer.Renamer
import Data.Set (Set, (\\)) import Renamer.RenamerIr (Const (..), Ident (..), RBind (..), RExp (..), RProgram (..))
import qualified Data.Set as S import Renamer.RenamerIr qualified as R
import Debug.Trace (trace)
type Ctx = Map Ident UPolytype type Ctx = Map Ident UPolytype
@ -39,12 +42,13 @@ instance Show a => Show (TypeT a) where
type Infer = StateT (Map Ident UPolytype) (ReaderT Ctx (ExceptT TypeError (IntBindingT TypeT Identity))) type Infer = StateT (Map Ident UPolytype) (ReaderT Ctx (ExceptT TypeError (IntBindingT TypeT Identity)))
type Type = Fix TypeT type Type = Fix TypeT
type UType = UTerm TypeT IntVar type UType = UTerm TypeT IntVar
data Poly t = Forall [Ident] t data Poly t = Forall [Ident] t
deriving (Eq, Show, Functor) deriving (Eq, Show, Functor)
type Polytype = Poly Type type Polytype = Poly Type
type UPolytype = Poly UType type UPolytype = Poly UType
@ -67,23 +71,23 @@ pattern UTPoly :: Ident -> UType
pattern UTPoly v = UTerm (TPolyT v) pattern UTPoly v = UTerm (TPolyT v)
data TType = TTPoly Ident | TTMono Ident | TTArrow TType TType data TType = TTPoly Ident | TTMono Ident | TTArrow TType TType
deriving Show deriving (Show)
data Program = Program [Bind] newtype Program = Program [Bind]
deriving Show deriving (Show)
data Bind = Bind Ident Exp Polytype data Bind = Bind Ident Exp Polytype
deriving Show deriving (Show)
data Exp data Exp
= EAnn Exp Polytype = EAnn Exp Polytype
| EBound Ident Polytype | EBound Ident Polytype
| EFree Ident Polytype | EFree Ident Polytype
| EConst Const Polytype | EConst Const Polytype
| EApp Exp Exp Polytype | EApp Exp Exp Polytype
| EAdd Exp Exp Polytype | EAdd Exp Exp Polytype
| EAbs Ident Exp Polytype | EAbs Ident Exp Polytype
deriving Show deriving (Show)
data TExp data TExp
= TAnn TExp UType = TAnn TExp UType
@ -93,11 +97,11 @@ data TExp
| TApp TExp TExp UType | TApp TExp TExp UType
| TAdd TExp TExp UType | TAdd TExp TExp UType
| TAbs Ident TExp UType | TAbs Ident TExp UType
deriving Show deriving (Show)
---------------------------------------------------------- ----------------------------------------------------------
typecheck :: RProgram -> Either TypeError Program typecheck :: RProgram -> Either TypeError Program
typecheck = run . inferProgram typecheck = run . inferProgram
inferProgram :: RProgram -> Infer Program inferProgram :: RProgram -> Infer Program
inferProgram (RProgram binds) = do inferProgram (RProgram binds) = do
@ -106,7 +110,7 @@ inferProgram (RProgram binds) = do
inferBind :: RBind -> Infer Bind inferBind :: RBind -> Infer Bind
inferBind (RBind i e) = do inferBind (RBind i e) = do
(t,e') <- infer e (t, e') <- infer e
e'' <- convert fromUType e' e'' <- convert fromUType e'
t' <- fromUType t t' <- fromUType t
insertSigs i (Forall [] t) insertSigs i (Forall [] t)
@ -114,20 +118,19 @@ inferBind (RBind i e) = do
fromUType :: UType -> Infer Polytype fromUType :: UType -> Infer Polytype
fromUType = applyBindings >>> (>>= (generalize >>> fmap fromUPolytype)) fromUType = applyBindings >>> (>>= (generalize >>> fmap fromUPolytype))
convert :: (UType -> Infer Polytype) -> TExp -> Infer Exp convert :: (UType -> Infer Polytype) -> TExp -> Infer Exp
convert f = \case convert f = \case
(TAnn e t) -> do (TAnn e t) -> do
e' <- convert f e e' <- convert f e
t' <- (f t) EAnn e' <$> f t
return $ EAnn e' t'
(TFree i t) -> do (TFree i t) -> do
t' <- f t t' <- f t
return $ EFree i t' return $ EFree i t'
(TBound i t) -> do (TBound i t) -> do
t' <- f t t' <- f t
return $ EBound i t' return $ EBound i t'
(TConst c t) -> do (TConst c t) -> do
t' <- f t t' <- f t
return $ EConst c t' return $ EConst c t'
(TApp e1 e2 t) -> do (TApp e1 e2 t) -> do
@ -135,42 +138,43 @@ convert f = \case
e2' <- convert f e2 e2' <- convert f e2
t' <- f t t' <- f t
return $ EApp e1' e2' t' return $ EApp e1' e2' t'
(TAdd e1 e2 t) -> do (TAdd e1 e2 t) -> do
e1' <- convert f e1 e1' <- convert f e1
e2' <- convert f e2 e2' <- convert f e2
t' <- f t t' <- f t
return $ EAdd e1' e2' t' return $ EAdd e1' e2' t'
(TAbs i e t) -> do (TAbs i e t) -> do
e' <- convert f e e' <- convert f e
t' <- f t t' <- f t
return $ EAbs i e' t' return $ EAbs i e' t'
run :: Infer a -> Either TypeError a run :: Infer a -> Either TypeError a
run = flip evalStateT mempty run =
>>> flip runReaderT mempty flip evalStateT mempty
>>> runExceptT >>> flip runReaderT mempty
>>> evalIntBindingT >>> runExceptT
>>> runIdentity >>> evalIntBindingT
>>> runIdentity
infer :: RExp -> Infer (UType, TExp) infer :: RExp -> Infer (UType, TExp)
infer = \case infer = \case
(RConst (CInt i)) -> return $ (UTMono "Int", TConst (CInt i) (UTMono "Int")) (RConst (CInt i)) -> return (UTMono "Int", TConst (CInt i) (UTMono "Int"))
(RConst (CStr str)) -> return $ (UTMono "String", TConst (CStr str) (UTMono "String")) (RConst (CStr str)) -> return (UTMono "String", TConst (CStr str) (UTMono "String"))
(RAdd e1 e2) -> do (RAdd e1 e2) -> do
(t1,e1') <- infer e2 (t1, e1') <- infer e2
(t2,e2') <- infer e1 (t2, e2') <- infer e1
t1 =:= (UTMono "Int") t1 =:= UTMono "Int"
t2 =:= (UTMono "Int") t2 =:= UTMono "Int"
return $ (UTMono "Int", TAdd e1' e2' (UTMono "Int")) return (UTMono "Int", TAdd e1' e2' (UTMono "Int"))
(RAnn e t) -> do (RAnn e t) -> do
(t',e') <- infer e (t', e') <- infer e
check e t' check e t'
return (t', TAnn e' t') return (t', TAnn e' t')
(RApp e1 e2) -> do (RApp e1 e2) -> do
(f,e1') <- infer e1 (f, e1') <- infer e1
(arg,e2') <- infer e2 (arg, e2') <- infer e2
res <- fresh res <- fresh
f =:= UTArrow f arg f =:= UTArrow arg res
return (res, TApp e1' e2' res) return (res, TApp e1' e2' res)
(RAbs _ i e) -> do (RAbs _ i e) -> do
arg <- fresh arg <- fresh
@ -199,8 +203,8 @@ lookupSigsT :: Ident -> Infer UType
lookupSigsT x@(Ident i) = do lookupSigsT x@(Ident i) = do
ctx <- ask ctx <- ask
case M.lookup x ctx of case M.lookup x ctx of
Nothing -> trace (show ctx) (throwError $ "Sigs - Unbound variable: " <> i) Nothing -> trace (show ctx) (throwError $ "Sigs - Unbound variable: " <> i)
Just ut -> return $ fromPolytype ut Just ut -> return $ fromPolytype ut
insertSigs :: MonadState (Map Ident UPolytype) m => Ident -> UPolytype -> m () insertSigs :: MonadState (Map Ident UPolytype) m => Ident -> UPolytype -> m ()
insertSigs x ty = modify (M.insert x ty) insertSigs x ty = modify (M.insert x ty)
@ -218,21 +222,23 @@ withBinding x ty = local (M.insert x ty)
deriving instance Ord IntVar deriving instance Ord IntVar
class FreeVars a where class FreeVars a where
freeVars :: a -> Infer (Set (Either Ident IntVar)) freeVars :: a -> Infer (Set (Either Ident IntVar))
instance FreeVars UType where instance FreeVars UType where
freeVars ut = do freeVars ut = do
fuvs <- fmap (S.fromList . map Right) . lift . lift . lift $ getFreeVars ut fuvs <- fmap (S.fromList . map Right) . lift . lift . lift $ getFreeVars ut
let ftvs = ucata (const S.empty) let ftvs =
(\case {TMonoT x -> S.singleton (Left x); f -> fold f}) ucata
ut (const S.empty)
return $ fuvs `S.union` ftvs (\case TMonoT x -> S.singleton (Left x); f -> fold f)
ut
return $ fuvs `S.union` ftvs
instance FreeVars UPolytype where instance FreeVars UPolytype where
freeVars (Forall xs ut) = (\\ (S.fromList (map Left xs))) <$> freeVars ut freeVars (Forall xs ut) = (\\ (S.fromList (map Left xs))) <$> freeVars ut
instance FreeVars Ctx where instance FreeVars Ctx where
freeVars = fmap S.unions . mapM freeVars . M.elems freeVars = fmap S.unions . mapM freeVars . M.elems
fresh :: Infer UType fresh :: Infer UType
fresh = UVar <$> lift (lift (lift freeVar)) fresh = UVar <$> lift (lift (lift freeVar))
@ -249,21 +255,22 @@ applyBindings = lift . lift . U.applyBindings
instantiate :: UPolytype -> Infer UType instantiate :: UPolytype -> Infer UType
instantiate (Forall xs uty) = do instantiate (Forall xs uty) = do
xs' <- mapM (const fresh) xs xs' <- mapM (const fresh) xs
return $ substU (M.fromList (zip (map Left xs) xs')) uty return $ substU (M.fromList (zip (map Left xs) xs')) uty
substU :: Map (Either Ident IntVar) UType -> UType -> UType substU :: Map (Either Ident IntVar) UType -> UType -> UType
substU m = ucata substU m =
(\v -> fromMaybe (UVar v) (M.lookup (Right v) m)) ucata
(\case (\v -> fromMaybe (UVar v) (M.lookup (Right v) m))
TPolyT v -> fromMaybe (UTPoly v) (M.lookup (Left v) m) ( \case
f -> UTerm f TPolyT v -> fromMaybe (UTPoly v) (M.lookup (Left v) m)
) f -> UTerm f
)
skolemize :: UPolytype -> Infer UType skolemize :: UPolytype -> Infer UType
skolemize (Forall xs uty) = do skolemize (Forall xs uty) = do
xs' <- mapM (const fresh) xs xs' <- mapM (const fresh) xs
return $ substU (M.fromList (zip (map Left xs) (map toSkolem xs'))) uty return $ substU (M.fromList (zip (map Left xs) (map toSkolem xs'))) uty
where where
toSkolem (UVar v) = UTPoly (mkVarName "s" v) toSkolem (UVar v) = UTPoly (mkVarName "s" v)
@ -272,13 +279,13 @@ mkVarName nm (IntVar v) = Ident $ nm ++ show (v + (maxBound :: Int) + 1)
generalize :: UType -> Infer UPolytype generalize :: UType -> Infer UPolytype
generalize uty = do generalize uty = do
uty' <- applyBindings uty uty' <- applyBindings uty
ctx <- ask ctx <- ask
tmfvs <- freeVars uty' tmfvs <- freeVars uty'
ctxfvs <- freeVars ctx ctxfvs <- freeVars ctx
let fvs = S.toList $ tmfvs \\ ctxfvs let fvs = S.toList $ tmfvs \\ ctxfvs
xs = map (either id (mkVarName "a")) fvs xs = map (either id (mkVarName "a")) fvs
return $ Forall xs (substU (M.fromList (zip fvs (map UTPoly xs))) uty') return $ Forall xs (substU (M.fromList (zip fvs (map UTPoly xs))) uty')
fromUPolytype :: UPolytype -> Polytype fromUPolytype :: UPolytype -> Polytype
fromUPolytype = fmap (fromJust . freeze) fromUPolytype = fmap (fromJust . freeze)