Remade lets with bind & improvements
This commit is contained in:
parent
30a79f34af
commit
3371c3a146
3 changed files with 43 additions and 25 deletions
|
|
@ -36,10 +36,7 @@ renameDefs defs = runIdentity $ runExceptT $ evalStateT (runRn $ mapM renameDef
|
|||
renameDef :: Def -> Rn Def
|
||||
renameDef = \case
|
||||
DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ
|
||||
DBind (Bind name vars rhs) -> do
|
||||
(new_names, vars') <- newNames initNames (coerce vars)
|
||||
rhs' <- snd <$> renameExp new_names rhs
|
||||
pure . DBind $ Bind name (coerce vars') rhs'
|
||||
DBind bind -> DBind . snd <$> renameBind initNames bind
|
||||
DData (Data (Indexed cname types) constrs) -> do
|
||||
tvars_ <- tvars
|
||||
tvars' <- mapM nextNameTVar tvars_
|
||||
|
|
@ -61,6 +58,12 @@ renameDefs defs = runIdentity $ runExceptT $ evalStateT (runRn $ mapM renameDef
|
|||
renameConstr new_types (Constructor name typ) =
|
||||
Constructor name $ substituteTVar new_types typ
|
||||
|
||||
renameBind :: Names -> Bind -> Rn (Names, Bind)
|
||||
renameBind old_names (Bind name vars rhs) = do
|
||||
(new_names, vars') <- newNames old_names (coerce vars)
|
||||
(newer_names, rhs') <- renameExp new_names rhs
|
||||
pure (newer_names, Bind name (coerce vars') rhs')
|
||||
|
||||
substituteTVar :: [(TVar, TVar)] -> Type -> Type
|
||||
substituteTVar new_names typ = case typ of
|
||||
TLit _ -> typ
|
||||
|
|
@ -110,11 +113,10 @@ renameExp old_names = \case
|
|||
pure (Map.union env1 env2, EAdd e1' e2')
|
||||
|
||||
-- TODO fix shadowing
|
||||
ELet name rhs e -> do
|
||||
(new_names, name') <- newName old_names (coerce name)
|
||||
(new_names', rhs') <- renameExp new_names rhs
|
||||
(new_names'', e') <- renameExp new_names' e
|
||||
pure (new_names'', ELet (coerce name') rhs' e')
|
||||
ELet bind e -> do
|
||||
(new_names, bind') <- renameBind old_names bind
|
||||
(new_names', e') <- renameExp new_names e
|
||||
pure (new_names', ELet bind' e')
|
||||
EAbs par e -> do
|
||||
(new_names, par') <- newName old_names (coerce par)
|
||||
(new_names', e') <- renameExp new_names e
|
||||
|
|
|
|||
|
|
@ -78,8 +78,7 @@ checkPrg (Program bs) = do
|
|||
preRun bs
|
||||
-- Type check the program twice to produce all top-level types in the first pass through
|
||||
bs' <- checkDef bs
|
||||
trace "\nFIRST ITERATION" return ()
|
||||
trace (printTree bs' ++ "\nSECOND ITERATION\n") return ()
|
||||
trace ("FIRST ITERATION: " ++ printTree bs') pure ()
|
||||
bs'' <- checkDef bs
|
||||
return $ T.Program bs''
|
||||
where
|
||||
|
|
@ -106,23 +105,35 @@ checkPrg (Program bs) = do
|
|||
|
||||
checkBind :: Bind -> Infer T.Bind
|
||||
checkBind (Bind name args e) = do
|
||||
-- let lambda = makeLambda e (reverse $ coerce args)
|
||||
let lambda = makeLambda e (reverse (coerce args))
|
||||
(_, lambdaT) <- inferExp lambda
|
||||
args <- zip args <$> mapM (const fresh) args
|
||||
withBindings (map coerce args) $ do
|
||||
e@(_, t') <- inferExp e
|
||||
e@(_, _) <- inferExp e
|
||||
s <- gets sigs
|
||||
-- let fs = map (second Just) (getFunctionTypes s e)
|
||||
-- mapM_ (uncurry insertSig) fs
|
||||
case M.lookup (coerce name) s of
|
||||
Just (Just t) -> do
|
||||
sub <- unify t t'
|
||||
sub <- unify t lambdaT
|
||||
let newT = apply sub t
|
||||
insertSig (coerce name) (Just newT)
|
||||
return $ T.Bind (coerce name, newT) (map coerce args) e
|
||||
_ -> do
|
||||
insertSig (coerce name) (Just t')
|
||||
return (T.Bind (coerce name, t') (map coerce args) e) -- (apply s e)
|
||||
where
|
||||
makeLambda :: Exp -> [Ident] -> Exp
|
||||
makeLambda = foldl (flip (EAbs . coerce))
|
||||
insertSig (coerce name) (Just lambdaT)
|
||||
return (T.Bind (coerce name, lambdaT) (map coerce args) e) -- (apply s e)
|
||||
-- where
|
||||
-- getFunctionTypes :: Map Ident (Maybe T.Type) -> T.ExpT -> [(Ident, T.Type)]
|
||||
-- getFunctionTypes s = \case
|
||||
-- (T.EId b, t) -> case M.lookup b s of
|
||||
-- Just Nothing -> return (b, t)
|
||||
-- _ -> []
|
||||
-- (T.ELit _, _) -> []
|
||||
-- (T.ELet (T.Bind _ _ e1) e2, _) -> getFunctionTypes s e1 ++ getFunctionTypes s e2
|
||||
-- (T.EApp e1 e2, _) -> getFunctionTypes s e1 ++ getFunctionTypes s e2
|
||||
-- (T.EAdd e1 e2, _) -> getFunctionTypes s e1 ++ getFunctionTypes s e2
|
||||
-- (T.EAbs _ e, _) -> getFunctionTypes s e
|
||||
-- (T.ECase e injs, _) -> getFunctionTypes s e ++ concatMap (getFunctionTypes s . \(T.Inj _ e) -> e) injs
|
||||
|
||||
isMoreSpecificOrEq :: T.Type -> T.Type -> Bool
|
||||
isMoreSpecificOrEq _ (T.TAll _ _) = True
|
||||
|
|
@ -263,14 +274,15 @@ algoW = \case
|
|||
|
||||
-- The bar over S₀ and Γ means "generalize"
|
||||
|
||||
ELet name e0 e1 -> do
|
||||
(s1, (e0', t1)) <- algoW e0
|
||||
ELet b@(Bind name args e) e1 -> do
|
||||
(s1, (_, t0)) <- algoW (makeLambda e (coerce args))
|
||||
bind' <- checkBind b
|
||||
env <- asks vars
|
||||
let t' = generalize (apply s1 env) t1
|
||||
let t' = generalize (apply s1 env) t0
|
||||
withBinding (coerce name) t' $ do
|
||||
(s2, (e1', t2)) <- algoW e1
|
||||
let comp = s2 `compose` s1
|
||||
return (comp, apply comp (T.ELet (T.Bind (coerce name, t2) [] (e0', t1)) (e1', t2), t2))
|
||||
return (comp, apply comp (T.ELet bind' (e1', t2), t2))
|
||||
|
||||
-- \| TODO: Add judgement
|
||||
ECase caseExpr injs -> do
|
||||
|
|
@ -280,8 +292,12 @@ algoW = \case
|
|||
let t' = apply comp ret_t
|
||||
return (comp, (T.ECase (e', t) injs, t'))
|
||||
|
||||
makeLambda :: Exp -> [Ident] -> Exp
|
||||
makeLambda = foldl (flip (EAbs . coerce))
|
||||
|
||||
-- | Unify two types producing a new substitution
|
||||
unify :: T.Type -> T.Type -> Infer Subst
|
||||
unify t0 t1 | trace ("T0: " ++ show t0 ++ "\nT1: " ++ show t1) False = undefined
|
||||
unify t0 t1 = do
|
||||
case (t0, t1) of
|
||||
(T.TFun a b, T.TFun c d) -> do
|
||||
|
|
@ -293,7 +309,7 @@ unify t0 t1 = do
|
|||
(T.TAll _ t, b) -> unify t b
|
||||
(a, T.TAll _ t) -> unify a t
|
||||
(T.TLit a, T.TLit b) ->
|
||||
if a == b then return M.empty else throwError "Types do not unify"
|
||||
if a == b then return M.empty else throwError . unwords $ ["Can not unify", "'" ++ printTree (T.TLit a) ++ "'", "with", "'" ++ printTree (T.TLit b) ++ "'"]
|
||||
(T.TIndexed (T.Indexed name t), T.TIndexed (T.Indexed name' t')) ->
|
||||
if name == name' && length t == length t'
|
||||
then do
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue