From b0ec5a2333275e50e993cb2efa5902d91f515417 Mon Sep 17 00:00:00 2001 From: Samuel Hammersberg Date: Fri, 31 Mar 2023 18:16:26 +0200 Subject: [PATCH] Started working on a Case Desugar phase. --- src/CaseDesugar/CaseDesugar.hs | 83 ++++++++++++ src/CaseDesugar/CaseDesugarIr.hs | 226 +++++++++++++++++++++++++++++++ 2 files changed, 309 insertions(+) create mode 100644 src/CaseDesugar/CaseDesugar.hs create mode 100644 src/CaseDesugar/CaseDesugarIr.hs diff --git a/src/CaseDesugar/CaseDesugar.hs b/src/CaseDesugar/CaseDesugar.hs new file mode 100644 index 0000000..e1db55e --- /dev/null +++ b/src/CaseDesugar/CaseDesugar.hs @@ -0,0 +1,83 @@ +{-# LANGUAGE LambdaCase #-} + +module CaseDesugar.CaseDesugar (desuga) where + +import CaseDesugar.CaseDesugarIr qualified as CIR +import TypeChecker.TypeCheckerIr qualified as TIR + +desuga :: TIR.Program -> CIR.Program +desuga (TIR.Program x) = CIR.Program $ desugaDef <$> x + +desugaDef :: TIR.Def -> CIR.Def +desugaDef (TIR.DBind bin@TIR.Bind{}) = CIR.DBind $ desugaBind bin +desugaDef (TIR.DData dat@TIR.Data{}) = CIR.DData $ desugaData dat + +desugaData :: TIR.Data -> CIR.Data +desugaData (TIR.Data t injs) = CIR.Data (desugaType t) (desugaInj <$> injs) + +desugaType :: TIR.Type -> CIR.Type +desugaType (TIR.TLit (TIR.Ident s)) = CIR.TLit (CIR.Ident s) +desugaType (TIR.TVar tv) = CIR.TVar (desugaTVar tv) +desugaType (TIR.TData (TIR.Ident s) ts) = CIR.TData (CIR.Ident s) (desugaType <$> ts) +desugaType (TIR.TFun t1 t2) = CIR.TFun (desugaType t1) (desugaType t2) +desugaType (TIR.TAll _ t1) = desugaType t1 + +desugaTVar :: TIR.TVar -> CIR.TVar +desugaTVar (TIR.MkTVar (TIR.Ident s)) = CIR.MkTVar (CIR.Ident s) + +desugaInj :: TIR.Inj -> CIR.Inj +desugaInj (TIR.Inj (TIR.Ident s) t) = CIR.Inj (CIR.Ident s) (desugaType t) + +desugaId :: TIR.Id -> CIR.Id +desugaId (TIR.Ident s, t) = (CIR.Ident s, desugaType t) + +desugaBind :: TIR.Bind -> CIR.Bind +desugaBind (TIR.Bind id args exp) = + CIR.Bind (desugaId id) (desugaId <$> args) (desugaExpT exp) + +desugaExpT :: TIR.ExpT -> CIR.ExpT +desugaExpT (exp, t) = (desugaExp exp, desugaType t) + +desugaExp :: TIR.Exp -> CIR.Exp +desugaExp (TIR.EVar (TIR.Ident s)) = CIR.EVar (CIR.Ident s) +desugaExp (TIR.EInj (TIR.Ident s)) = CIR.EInj (CIR.Ident s) +desugaExp (TIR.ELit lit) = CIR.ELit lit +desugaExp (TIR.ELet b e) = CIR.ELet (desugaBind b) (desugaExpT e) +desugaExp (TIR.EApp e1 e2) = CIR.EApp (desugaExpT e1) (desugaExpT e2) +desugaExp (TIR.EAdd e1 e2) = CIR.EAdd (desugaExpT e1) (desugaExpT e2) +desugaExp (TIR.EAbs (TIR.Ident s) e) = CIR.EAbs (CIR.Ident s) (desugaExpT e) +desugaExp (TIR.ECase e branches) = CIR.ECase (desugaExpT e) (desugaBranches branches) + +desugaBranches :: [TIR.Branch] -> [CIR.Branch] +desugaBranches bs = do + let injections = filter (\case (TIR.Branch (TIR.PInj{}, _) _) -> True; _ -> False) bs + let patterns = filter (\case (TIR.Branch (TIR.PInj{}, _) _) -> True; _ -> False) bs + undefined + +desugaBranch :: TIR.Branch -> CIR.Branch +desugaBranch (TIR.Branch (TIR.PInj (TIR.Ident s) ps, pt) e) = do + undefined +desugaBranch (TIR.Branch (p, pt) e) = do + CIR.Branch + ( case p of + TIR.PVar id -> (CIR.PVar (desugaId id), desugaType pt) + TIR.PLit (lit, t) -> (CIR.PLit (lit, desugaType t), desugaType pt) + TIR.PCatch -> (CIR.PCatch, desugaType pt) + TIR.PEnum (TIR.Ident s) -> (CIR.PEnum (CIR.Ident s), desugaType pt) + ) + (desugaExpT e) + +{- +case (Tupli 5 5) of + Tupli 6 5 => 1 + Tupli _ x => 3 + x => 1 +=== +case (Tupli 5 5) of + Tupli x y => case x of + 6 => case y of + 5 => 1 + x => 3 + _ => case y of + x => 3 +-} \ No newline at end of file diff --git a/src/CaseDesugar/CaseDesugarIr.hs b/src/CaseDesugar/CaseDesugarIr.hs new file mode 100644 index 0000000..dd9864f --- /dev/null +++ b/src/CaseDesugar/CaseDesugarIr.hs @@ -0,0 +1,226 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PatternSynonyms #-} + +module CaseDesugar.CaseDesugarIr ( + module Grammar.Abs, + module CaseDesugar.CaseDesugarIr, +) where + +import Data.String (IsString) +import Grammar.Abs (Lit (..)) +import Grammar.Print +import Prelude +import Prelude qualified as C (Eq, Ord, Read, Show) + +newtype Program' t = Program [Def' t] + deriving (C.Eq, C.Ord, C.Show, C.Read) + +data Def' t + = DBind (Bind' t) + | DData (Data' t) + deriving (C.Eq, C.Ord, C.Show, C.Read) + +data Type + = TLit Ident + | TVar TVar + | TData Ident [Type] + | TFun Type Type + deriving (C.Eq, C.Ord, C.Show, C.Read) + +data Data' t = Data t [Inj' t] + deriving (C.Eq, C.Ord, C.Show, C.Read) + +data Inj' t = Inj Ident t + deriving (C.Eq, C.Ord, C.Show, C.Read) + +newtype Ident = Ident String + deriving (C.Eq, C.Ord, C.Show, C.Read, IsString) + +data Pattern' t + = PVar (Id' t) -- TODO should be Ident + | PLit (Lit, t) -- TODO should be Lit + | PCatch + | PEnum Ident + deriving (C.Eq, C.Ord, C.Show, C.Read) + +data Exp' t + = EVar Ident + | EInj Ident + | ELit Lit + | ELet (Bind' t) (ExpT' t) + | EApp (ExpT' t) (ExpT' t) + | EAdd (ExpT' t) (ExpT' t) + | EAbs Ident (ExpT' t) + | ECase (ExpT' t) [Branch' t] + deriving (C.Eq, C.Ord, C.Show, C.Read) + +newtype TVar = MkTVar Ident + deriving (C.Eq, C.Ord, C.Show, C.Read) + +type Id' t = (Ident, t) +type ExpT' t = (Exp' t, t) + +data Bind' t = Bind (Id' t) [Id' t] (ExpT' t) + deriving (C.Eq, C.Ord, C.Show, C.Read) + +data Branch' t = Branch (Pattern' t, t) (ExpT' t) + deriving (C.Eq, C.Ord, C.Show, C.Read) + +instance Print Ident where + prt _ (Ident s) = doc $ showString s + +instance Print t => Print (Program' t) where + prt i (Program sc) = prPrec i 0 $ prt 0 sc + +instance Print t => Print (Bind' t) where + prt i (Bind sig@(name, _) parms rhs) = + prPrec i 0 $ + concatD + [ prtSig sig + , prt 0 name + , prtIdPs 0 parms + , doc $ showString "=" + , prt 0 rhs + ] + +prtSig :: Print t => Id' t -> Doc +prtSig (name, t) = + concatD + [ prt 0 name + , doc $ showString ":" + , prt 0 t + , doc $ showString ";" + ] + +instance Print t => Print (ExpT' t) where + prt i (e, t) = + concatD + [ doc $ showString "(" + , prt i e + , doc $ showString "," + , prt i t + , doc $ showString ")" + ] + +instance Print t => Print [Bind' t] where + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] + +prtIdPs :: Print t => Int -> [Id' t] -> Doc +prtIdPs i = prPrec i 0 . concatD . map (prt i) + +instance Print t => Print (Id' t) where + prt i (name, t) = + concatD + [ doc $ showString "(" + , prt i name + , doc $ showString "," + , prt i t + , doc $ showString ")" + ] + +instance Print t => Print (Exp' t) where + prt i = \case + EVar name -> prPrec i 3 $ prt 0 name + EInj name -> prPrec i 3 $ prt 0 name + ELit lit -> prPrec i 3 $ prt 0 lit + ELet b e -> + prPrec i 3 $ + concatD + [ doc $ showString "let" + , prt 0 b + , doc $ showString "in" + , prt 0 e + ] + EApp e1 e2 -> + prPrec i 2 $ + concatD + [ prt 2 e1 + , prt 3 e2 + ] + EAdd e1 e2 -> + prPrec i 1 $ + concatD + [ prt 1 e1 + , doc $ showString "+" + , prt 2 e2 + ] + EAbs v e -> + prPrec i 0 $ + concatD + [ doc $ showString "\\" + , prt 0 v + , doc $ showString "." + , prt 0 e + ] + ECase e branches -> + prPrec i 0 $ + concatD + [ doc $ showString "case" + , prt 0 e + , doc $ showString "of" + , doc $ showString "{" + , prt 0 branches + , doc $ showString "}" + ] + +instance Print t => Print (Branch' t) where + prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp]) + +instance Print t => Print [Branch' t] where + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] + +instance Print t => Print (Def' t) where + prt i = \case + DBind bind -> prPrec i 0 (concatD [prt 0 bind]) + DData data_ -> prPrec i 0 (concatD [prt 0 data_]) + +instance Print t => Print (Data' t) where + prt i = \case + Data type_ injs -> prPrec i 0 (concatD [doc (showString "data"), prt 0 type_, doc (showString "where"), doc (showString "{"), prt 0 injs, doc (showString "}")]) + +instance Print t => Print (Inj' t) where + prt i = \case + Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_]) + +instance Print t => Print (Pattern' t) where + prt i = \case + PVar name -> prPrec i 1 (concatD [prt 0 name]) + PLit (lit, _) -> prPrec i 1 (concatD [prt 0 lit]) + PCatch -> prPrec i 1 (concatD [doc (showString "_")]) + PEnum name -> prPrec i 1 (concatD [prt 0 name]) + +instance Print t => Print [Def' t] where + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] + +instance Print [Type] where + prt _ [] = concatD [] + prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs] + +instance Print Type where + prt i = \case + TLit uident -> prPrec i 1 (concatD [prt 0 uident]) + TVar tvar -> prPrec i 1 (concatD [prt 0 tvar]) + TData uident types -> prPrec i 1 (concatD [prt 0 uident, doc (showString "("), prt 0 types, doc (showString ")")]) + TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2]) + +instance Print TVar where + prt i (MkTVar ident) = prt i ident + +type Program = Program' Type +type Def = Def' Type +type Data = Data' Type +type Bind = Bind' Type +type Branch = Branch' Type +type Pattern = Pattern' Type +type Inj = Inj' Type +type Exp = Exp' Type +type ExpT = ExpT' Type +type Id = Id' Type +pattern DBind' id vars expt = DBind (Bind id vars expt) +pattern DData' typ injs = DData (Data typ injs)