From 3371c3a146b2626ffd803d757f9a39ba2af5a018 Mon Sep 17 00:00:00 2001 From: sebastianselander Date: Fri, 24 Mar 2023 11:21:25 +0100 Subject: [PATCH] Remade lets with bind & improvements --- Grammar.cf | 2 +- src/Renamer/Renamer.hs | 20 ++++++++------- src/TypeChecker/TypeChecker.hs | 46 +++++++++++++++++++++++----------- 3 files changed, 43 insertions(+), 25 deletions(-) diff --git a/Grammar.cf b/Grammar.cf index 3bb15bd..65d5782 100644 --- a/Grammar.cf +++ b/Grammar.cf @@ -51,7 +51,7 @@ ECons. Exp4 ::= UIdent ; ELit. Exp4 ::= Lit ; EApp. Exp3 ::= Exp3 Exp4 ; EAdd. Exp1 ::= Exp1 "+" Exp2 ; -ELet. Exp ::= "let" LIdent "=" Exp "in" Exp ; +ELet. Exp ::= "let" Bind "in" Exp ; EAbs. Exp ::= "\\" LIdent "." Exp ; ECase. Exp ::= "case" Exp "of" "{" [Inj] "}"; diff --git a/src/Renamer/Renamer.hs b/src/Renamer/Renamer.hs index 3fa1afc..e60310e 100644 --- a/src/Renamer/Renamer.hs +++ b/src/Renamer/Renamer.hs @@ -36,10 +36,7 @@ renameDefs defs = runIdentity $ runExceptT $ evalStateT (runRn $ mapM renameDef renameDef :: Def -> Rn Def renameDef = \case DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ - DBind (Bind name vars rhs) -> do - (new_names, vars') <- newNames initNames (coerce vars) - rhs' <- snd <$> renameExp new_names rhs - pure . DBind $ Bind name (coerce vars') rhs' + DBind bind -> DBind . snd <$> renameBind initNames bind DData (Data (Indexed cname types) constrs) -> do tvars_ <- tvars tvars' <- mapM nextNameTVar tvars_ @@ -61,6 +58,12 @@ renameDefs defs = runIdentity $ runExceptT $ evalStateT (runRn $ mapM renameDef renameConstr new_types (Constructor name typ) = Constructor name $ substituteTVar new_types typ +renameBind :: Names -> Bind -> Rn (Names, Bind) +renameBind old_names (Bind name vars rhs) = do + (new_names, vars') <- newNames old_names (coerce vars) + (newer_names, rhs') <- renameExp new_names rhs + pure (newer_names, Bind name (coerce vars') rhs') + substituteTVar :: [(TVar, TVar)] -> Type -> Type substituteTVar new_names typ = case typ of TLit _ -> typ @@ -110,11 +113,10 @@ renameExp old_names = \case pure (Map.union env1 env2, EAdd e1' e2') -- TODO fix shadowing - ELet name rhs e -> do - (new_names, name') <- newName old_names (coerce name) - (new_names', rhs') <- renameExp new_names rhs - (new_names'', e') <- renameExp new_names' e - pure (new_names'', ELet (coerce name') rhs' e') + ELet bind e -> do + (new_names, bind') <- renameBind old_names bind + (new_names', e') <- renameExp new_names e + pure (new_names', ELet bind' e') EAbs par e -> do (new_names, par') <- newName old_names (coerce par) (new_names', e') <- renameExp new_names e diff --git a/src/TypeChecker/TypeChecker.hs b/src/TypeChecker/TypeChecker.hs index a2b4308..712c1cd 100644 --- a/src/TypeChecker/TypeChecker.hs +++ b/src/TypeChecker/TypeChecker.hs @@ -78,8 +78,7 @@ checkPrg (Program bs) = do preRun bs -- Type check the program twice to produce all top-level types in the first pass through bs' <- checkDef bs - trace "\nFIRST ITERATION" return () - trace (printTree bs' ++ "\nSECOND ITERATION\n") return () + trace ("FIRST ITERATION: " ++ printTree bs') pure () bs'' <- checkDef bs return $ T.Program bs'' where @@ -106,23 +105,35 @@ checkPrg (Program bs) = do checkBind :: Bind -> Infer T.Bind checkBind (Bind name args e) = do - -- let lambda = makeLambda e (reverse $ coerce args) + let lambda = makeLambda e (reverse (coerce args)) + (_, lambdaT) <- inferExp lambda args <- zip args <$> mapM (const fresh) args withBindings (map coerce args) $ do - e@(_, t') <- inferExp e + e@(_, _) <- inferExp e s <- gets sigs + -- let fs = map (second Just) (getFunctionTypes s e) + -- mapM_ (uncurry insertSig) fs case M.lookup (coerce name) s of Just (Just t) -> do - sub <- unify t t' + sub <- unify t lambdaT let newT = apply sub t insertSig (coerce name) (Just newT) return $ T.Bind (coerce name, newT) (map coerce args) e _ -> do - insertSig (coerce name) (Just t') - return (T.Bind (coerce name, t') (map coerce args) e) -- (apply s e) - where - makeLambda :: Exp -> [Ident] -> Exp - makeLambda = foldl (flip (EAbs . coerce)) + insertSig (coerce name) (Just lambdaT) + return (T.Bind (coerce name, lambdaT) (map coerce args) e) -- (apply s e) + -- where + -- getFunctionTypes :: Map Ident (Maybe T.Type) -> T.ExpT -> [(Ident, T.Type)] + -- getFunctionTypes s = \case + -- (T.EId b, t) -> case M.lookup b s of + -- Just Nothing -> return (b, t) + -- _ -> [] + -- (T.ELit _, _) -> [] + -- (T.ELet (T.Bind _ _ e1) e2, _) -> getFunctionTypes s e1 ++ getFunctionTypes s e2 + -- (T.EApp e1 e2, _) -> getFunctionTypes s e1 ++ getFunctionTypes s e2 + -- (T.EAdd e1 e2, _) -> getFunctionTypes s e1 ++ getFunctionTypes s e2 + -- (T.EAbs _ e, _) -> getFunctionTypes s e + -- (T.ECase e injs, _) -> getFunctionTypes s e ++ concatMap (getFunctionTypes s . \(T.Inj _ e) -> e) injs isMoreSpecificOrEq :: T.Type -> T.Type -> Bool isMoreSpecificOrEq _ (T.TAll _ _) = True @@ -263,14 +274,15 @@ algoW = \case -- The bar over S₀ and Γ means "generalize" - ELet name e0 e1 -> do - (s1, (e0', t1)) <- algoW e0 + ELet b@(Bind name args e) e1 -> do + (s1, (_, t0)) <- algoW (makeLambda e (coerce args)) + bind' <- checkBind b env <- asks vars - let t' = generalize (apply s1 env) t1 + let t' = generalize (apply s1 env) t0 withBinding (coerce name) t' $ do (s2, (e1', t2)) <- algoW e1 let comp = s2 `compose` s1 - return (comp, apply comp (T.ELet (T.Bind (coerce name, t2) [] (e0', t1)) (e1', t2), t2)) + return (comp, apply comp (T.ELet bind' (e1', t2), t2)) -- \| TODO: Add judgement ECase caseExpr injs -> do @@ -280,8 +292,12 @@ algoW = \case let t' = apply comp ret_t return (comp, (T.ECase (e', t) injs, t')) +makeLambda :: Exp -> [Ident] -> Exp +makeLambda = foldl (flip (EAbs . coerce)) + -- | Unify two types producing a new substitution unify :: T.Type -> T.Type -> Infer Subst +unify t0 t1 | trace ("T0: " ++ show t0 ++ "\nT1: " ++ show t1) False = undefined unify t0 t1 = do case (t0, t1) of (T.TFun a b, T.TFun c d) -> do @@ -293,7 +309,7 @@ unify t0 t1 = do (T.TAll _ t, b) -> unify t b (a, T.TAll _ t) -> unify a t (T.TLit a, T.TLit b) -> - if a == b then return M.empty else throwError "Types do not unify" + if a == b then return M.empty else throwError . unwords $ ["Can not unify", "'" ++ printTree (T.TLit a) ++ "'", "with", "'" ++ printTree (T.TLit b) ++ "'"] (T.TIndexed (T.Indexed name t), T.TIndexed (T.Indexed name' t')) -> if name == name' && length t == length t' then do