Remade lets with bind & improvements

This commit is contained in:
sebastianselander 2023-03-24 11:21:25 +01:00
parent 30a79f34af
commit 3371c3a146
3 changed files with 43 additions and 25 deletions

View file

@ -51,7 +51,7 @@ ECons. Exp4 ::= UIdent ;
ELit. Exp4 ::= Lit ; ELit. Exp4 ::= Lit ;
EApp. Exp3 ::= Exp3 Exp4 ; EApp. Exp3 ::= Exp3 Exp4 ;
EAdd. Exp1 ::= Exp1 "+" Exp2 ; EAdd. Exp1 ::= Exp1 "+" Exp2 ;
ELet. Exp ::= "let" LIdent "=" Exp "in" Exp ; ELet. Exp ::= "let" Bind "in" Exp ;
EAbs. Exp ::= "\\" LIdent "." Exp ; EAbs. Exp ::= "\\" LIdent "." Exp ;
ECase. Exp ::= "case" Exp "of" "{" [Inj] "}"; ECase. Exp ::= "case" Exp "of" "{" [Inj] "}";

View file

@ -36,10 +36,7 @@ renameDefs defs = runIdentity $ runExceptT $ evalStateT (runRn $ mapM renameDef
renameDef :: Def -> Rn Def renameDef :: Def -> Rn Def
renameDef = \case renameDef = \case
DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ
DBind (Bind name vars rhs) -> do DBind bind -> DBind . snd <$> renameBind initNames bind
(new_names, vars') <- newNames initNames (coerce vars)
rhs' <- snd <$> renameExp new_names rhs
pure . DBind $ Bind name (coerce vars') rhs'
DData (Data (Indexed cname types) constrs) -> do DData (Data (Indexed cname types) constrs) -> do
tvars_ <- tvars tvars_ <- tvars
tvars' <- mapM nextNameTVar tvars_ tvars' <- mapM nextNameTVar tvars_
@ -61,6 +58,12 @@ renameDefs defs = runIdentity $ runExceptT $ evalStateT (runRn $ mapM renameDef
renameConstr new_types (Constructor name typ) = renameConstr new_types (Constructor name typ) =
Constructor name $ substituteTVar new_types typ Constructor name $ substituteTVar new_types typ
renameBind :: Names -> Bind -> Rn (Names, Bind)
renameBind old_names (Bind name vars rhs) = do
(new_names, vars') <- newNames old_names (coerce vars)
(newer_names, rhs') <- renameExp new_names rhs
pure (newer_names, Bind name (coerce vars') rhs')
substituteTVar :: [(TVar, TVar)] -> Type -> Type substituteTVar :: [(TVar, TVar)] -> Type -> Type
substituteTVar new_names typ = case typ of substituteTVar new_names typ = case typ of
TLit _ -> typ TLit _ -> typ
@ -110,11 +113,10 @@ renameExp old_names = \case
pure (Map.union env1 env2, EAdd e1' e2') pure (Map.union env1 env2, EAdd e1' e2')
-- TODO fix shadowing -- TODO fix shadowing
ELet name rhs e -> do ELet bind e -> do
(new_names, name') <- newName old_names (coerce name) (new_names, bind') <- renameBind old_names bind
(new_names', rhs') <- renameExp new_names rhs (new_names', e') <- renameExp new_names e
(new_names'', e') <- renameExp new_names' e pure (new_names', ELet bind' e')
pure (new_names'', ELet (coerce name') rhs' e')
EAbs par e -> do EAbs par e -> do
(new_names, par') <- newName old_names (coerce par) (new_names, par') <- newName old_names (coerce par)
(new_names', e') <- renameExp new_names e (new_names', e') <- renameExp new_names e

View file

@ -78,8 +78,7 @@ checkPrg (Program bs) = do
preRun bs preRun bs
-- Type check the program twice to produce all top-level types in the first pass through -- Type check the program twice to produce all top-level types in the first pass through
bs' <- checkDef bs bs' <- checkDef bs
trace "\nFIRST ITERATION" return () trace ("FIRST ITERATION: " ++ printTree bs') pure ()
trace (printTree bs' ++ "\nSECOND ITERATION\n") return ()
bs'' <- checkDef bs bs'' <- checkDef bs
return $ T.Program bs'' return $ T.Program bs''
where where
@ -106,23 +105,35 @@ checkPrg (Program bs) = do
checkBind :: Bind -> Infer T.Bind checkBind :: Bind -> Infer T.Bind
checkBind (Bind name args e) = do checkBind (Bind name args e) = do
-- let lambda = makeLambda e (reverse $ coerce args) let lambda = makeLambda e (reverse (coerce args))
(_, lambdaT) <- inferExp lambda
args <- zip args <$> mapM (const fresh) args args <- zip args <$> mapM (const fresh) args
withBindings (map coerce args) $ do withBindings (map coerce args) $ do
e@(_, t') <- inferExp e e@(_, _) <- inferExp e
s <- gets sigs s <- gets sigs
-- let fs = map (second Just) (getFunctionTypes s e)
-- mapM_ (uncurry insertSig) fs
case M.lookup (coerce name) s of case M.lookup (coerce name) s of
Just (Just t) -> do Just (Just t) -> do
sub <- unify t t' sub <- unify t lambdaT
let newT = apply sub t let newT = apply sub t
insertSig (coerce name) (Just newT) insertSig (coerce name) (Just newT)
return $ T.Bind (coerce name, newT) (map coerce args) e return $ T.Bind (coerce name, newT) (map coerce args) e
_ -> do _ -> do
insertSig (coerce name) (Just t') insertSig (coerce name) (Just lambdaT)
return (T.Bind (coerce name, t') (map coerce args) e) -- (apply s e) return (T.Bind (coerce name, lambdaT) (map coerce args) e) -- (apply s e)
where -- where
makeLambda :: Exp -> [Ident] -> Exp -- getFunctionTypes :: Map Ident (Maybe T.Type) -> T.ExpT -> [(Ident, T.Type)]
makeLambda = foldl (flip (EAbs . coerce)) -- getFunctionTypes s = \case
-- (T.EId b, t) -> case M.lookup b s of
-- Just Nothing -> return (b, t)
-- _ -> []
-- (T.ELit _, _) -> []
-- (T.ELet (T.Bind _ _ e1) e2, _) -> getFunctionTypes s e1 ++ getFunctionTypes s e2
-- (T.EApp e1 e2, _) -> getFunctionTypes s e1 ++ getFunctionTypes s e2
-- (T.EAdd e1 e2, _) -> getFunctionTypes s e1 ++ getFunctionTypes s e2
-- (T.EAbs _ e, _) -> getFunctionTypes s e
-- (T.ECase e injs, _) -> getFunctionTypes s e ++ concatMap (getFunctionTypes s . \(T.Inj _ e) -> e) injs
isMoreSpecificOrEq :: T.Type -> T.Type -> Bool isMoreSpecificOrEq :: T.Type -> T.Type -> Bool
isMoreSpecificOrEq _ (T.TAll _ _) = True isMoreSpecificOrEq _ (T.TAll _ _) = True
@ -263,14 +274,15 @@ algoW = \case
-- The bar over S₀ and Γ means "generalize" -- The bar over S₀ and Γ means "generalize"
ELet name e0 e1 -> do ELet b@(Bind name args e) e1 -> do
(s1, (e0', t1)) <- algoW e0 (s1, (_, t0)) <- algoW (makeLambda e (coerce args))
bind' <- checkBind b
env <- asks vars env <- asks vars
let t' = generalize (apply s1 env) t1 let t' = generalize (apply s1 env) t0
withBinding (coerce name) t' $ do withBinding (coerce name) t' $ do
(s2, (e1', t2)) <- algoW e1 (s2, (e1', t2)) <- algoW e1
let comp = s2 `compose` s1 let comp = s2 `compose` s1
return (comp, apply comp (T.ELet (T.Bind (coerce name, t2) [] (e0', t1)) (e1', t2), t2)) return (comp, apply comp (T.ELet bind' (e1', t2), t2))
-- \| TODO: Add judgement -- \| TODO: Add judgement
ECase caseExpr injs -> do ECase caseExpr injs -> do
@ -280,8 +292,12 @@ algoW = \case
let t' = apply comp ret_t let t' = apply comp ret_t
return (comp, (T.ECase (e', t) injs, t')) return (comp, (T.ECase (e', t) injs, t'))
makeLambda :: Exp -> [Ident] -> Exp
makeLambda = foldl (flip (EAbs . coerce))
-- | Unify two types producing a new substitution -- | Unify two types producing a new substitution
unify :: T.Type -> T.Type -> Infer Subst unify :: T.Type -> T.Type -> Infer Subst
unify t0 t1 | trace ("T0: " ++ show t0 ++ "\nT1: " ++ show t1) False = undefined
unify t0 t1 = do unify t0 t1 = do
case (t0, t1) of case (t0, t1) of
(T.TFun a b, T.TFun c d) -> do (T.TFun a b, T.TFun c d) -> do
@ -293,7 +309,7 @@ unify t0 t1 = do
(T.TAll _ t, b) -> unify t b (T.TAll _ t, b) -> unify t b
(a, T.TAll _ t) -> unify a t (a, T.TAll _ t) -> unify a t
(T.TLit a, T.TLit b) -> (T.TLit a, T.TLit b) ->
if a == b then return M.empty else throwError "Types do not unify" if a == b then return M.empty else throwError . unwords $ ["Can not unify", "'" ++ printTree (T.TLit a) ++ "'", "with", "'" ++ printTree (T.TLit b) ++ "'"]
(T.TIndexed (T.Indexed name t), T.TIndexed (T.Indexed name' t')) -> (T.TIndexed (T.Indexed name t), T.TIndexed (T.Indexed name' t')) ->
if name == name' && length t == length t' if name == name' && length t == length t'
then do then do