diff --git a/sample-programs/basic-1 b/sample-programs/basic-1 index ac5556b..5acb832 100644 --- a/sample-programs/basic-1 +++ b/sample-programs/basic-1 @@ -1,5 +1,8 @@ add : Int ; -add = 5; +add = 4; main : Int ; -main = add ; +main = case add of { + 5 => 0; + _ => 1; +}; diff --git a/src/Main.hs b/src/Main.hs index 02c49d0..d0f544c 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -2,27 +2,30 @@ module Main where -import Codegen.Codegen (generateCode) -import GHC.IO.Handle.Text (hPutStrLn) -import Grammar.ErrM (Err) -import Grammar.Par (myLexer, pProgram) -import Grammar.Print (printTree) +import Codegen.Codegen (generateCode) +import GHC.IO.Handle.Text (hPutStrLn) +import Grammar.ErrM (Err) +import Grammar.Par (myLexer, pProgram) +import Grammar.Print (printTree) -import Monomorphizer.Monomorphizer (monomorphize) +import Monomorphizer.Monomorphizer (monomorphize) -import Control.Monad (when) -import Data.List.Extra (isSuffixOf) +import Control.Monad (when) +import Data.List.Extra (isSuffixOf) -import Renamer.Renamer (rename) -import System.Directory (createDirectory, doesPathExist, - getDirectoryContents, - removeDirectoryRecursive, - setCurrentDirectory) -import System.Environment (getArgs) -import System.Exit (exitFailure, exitSuccess) -import System.IO (stderr) -import System.Process.Extra (spawnCommand, waitForProcess) -import TypeChecker.TypeChecker (typecheck) +import Renamer.Renamer (rename) +import System.Directory ( + createDirectory, + doesPathExist, + getDirectoryContents, + removeDirectoryRecursive, + setCurrentDirectory, + ) +import System.Environment (getArgs) +import System.Exit (exitFailure, exitSuccess) +import System.IO (stderr) +import System.Process.Extra (spawnCommand, waitForProcess) +import TypeChecker.TypeChecker (typecheck) main :: IO () main = @@ -59,7 +62,7 @@ main' debug s = do when check (removeDirectoryRecursive "output") createDirectory "output" writeFile "output/llvm.ll" compiled - if debug then debugDotViz else putStrLn compiled + -- if debug then debugDotViz else putStrLn compiled -- interpred <- fromInterpreterErr $ interpret lifted -- putStrLn "\n-- interpret" diff --git a/src/Monomorphizer/Monomorphizer.hs b/src/Monomorphizer/Monomorphizer.hs index 1b1d2f6..051641e 100644 --- a/src/Monomorphizer/Monomorphizer.hs +++ b/src/Monomorphizer/Monomorphizer.hs @@ -1,14 +1,15 @@ {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} module Monomorphizer.Monomorphizer (monomorphize) where -import Data.Coerce (coerce) -import Grammar.Abs (Constructor (..), Ident (..)) -import Unsafe.Coerce (unsafeCoerce) +import Data.Coerce (coerce) +import Grammar.Abs (Constructor (..), Ident (..)) +import Unsafe.Coerce (unsafeCoerce) -import qualified Grammar.Abs as GA -import qualified Monomorphizer.MonomorphizerIr as M -import qualified TypeChecker.TypeCheckerIr as T +import Grammar.Abs qualified as GA +import Monomorphizer.MonomorphizerIr qualified as M +import TypeChecker.TypeCheckerIr qualified as T monomorphize :: T.Program -> M.Program monomorphize (T.Program ds) = M.Program $ monoDefs ds @@ -18,7 +19,7 @@ monoDefs = map monoDef monoDef :: T.Def -> M.Def monoDef (T.DBind bind) = M.DBind $ monoBind bind -monoDef (T.DData d) = M.DData $ unsafeCoerce d +monoDef (T.DData d) = M.DData $ unsafeCoerce d monoBind :: T.Bind -> M.Bind monoBind (T.Bind name args (e, t)) = M.Bind (monoId name) (map monoId args) (monoExpr e, monoType t) @@ -34,19 +35,19 @@ monoExpr = \case T.ECase expt injs -> M.ECase (monoexpt expt) (monoInjs injs) monoAbsType :: GA.Type -> M.Type -monoAbsType (GA.TLit u) = M.TLit (coerce u) -monoAbsType (GA.TVar _v) = error "NOT POLYMORHPIC TYPES" +monoAbsType (GA.TLit u) = M.TLit (coerce u) +monoAbsType (GA.TVar _v) = M.TLit "Int" monoAbsType (GA.TAll _v _t) = error "NOT ALL TYPES" -monoAbsType (GA.TEVar _v) = error "I DONT KNOW WHAT THIS IS" +monoAbsType (GA.TEVar _v) = error "I DONT KNOW WHAT THIS IS" monoAbsType (GA.TFun t1 t2) = M.TFun (monoAbsType t1) (monoAbsType t2) -monoAbsType (GA.TData _ _) = error "NOT INDEXED TYPES" +monoAbsType (GA.TData _ _) = error "NOT INDEXED TYPES" monoType :: T.Type -> M.Type -monoType (T.TAll _ t) = monoType t -monoType (T.TVar (T.MkTVar i)) = error "NOT POLYMORPHIC TYPES" -monoType (T.TLit (T.Ident i)) = M.TLit (Ident i) -monoType (T.TFun t1 t2) = M.TFun (monoType t1) (monoType t2) -monoType (T.TData _ _) = error "Not sure what this is" +monoType (T.TAll _ t) = monoType t +monoType (T.TVar (T.MkTVar i)) = M.TLit "Int" +monoType (T.TLit (T.Ident i)) = M.TLit (Ident i) +monoType (T.TFun t1 t2) = M.TFun (monoType t1) (monoType t2) +monoType (T.TData _ _) = error "Not sure what this is" monoexpt :: T.ExpT -> M.ExpT monoexpt (e, t) = (monoExpr e, monoType t) @@ -55,7 +56,7 @@ monoId :: T.Id -> M.Id monoId (n, t) = (coerce n, monoType t) monoLit :: T.Lit -> M.Lit -monoLit (T.LInt i) = M.LInt i +monoLit (T.LInt i) = M.LInt i monoLit (T.LChar c) = M.LChar c monoInjs :: [T.Branch] -> [M.Branch] @@ -65,7 +66,7 @@ monoInj :: T.Branch -> M.Branch monoInj (T.Branch (init, t) expt) = M.Branch (monoInit init, monoType t) (monoexpt expt) monoInit :: T.Pattern -> M.Pattern -monoInit (T.PVar (id, t)) = M.PVar (coerce id, monoType t) +monoInit (T.PVar (id, t)) = M.PVar (coerce id, monoType t) monoInit (T.PLit (lit, t)) = M.PLit (monoLit lit, monoType t) -monoInit (T.PInj id ps) = M.PInj (coerce id) (monoInit <$> ps) -monoInit T.PCatch = M.PCatch +monoInit (T.PInj id ps) = M.PInj (coerce id) (monoInit <$> ps) +monoInit T.PCatch = M.PCatch diff --git a/src/TypeChecker/TypeChecker.hs b/src/TypeChecker/TypeChecker.hs index 03126e7..9bcb67b 100644 --- a/src/TypeChecker/TypeChecker.hs +++ b/src/TypeChecker/TypeChecker.hs @@ -117,36 +117,22 @@ checkPrg (Program bs) = do (DSig _) -> checkDef xs checkBind :: Bind -> Infer T.Bind -checkBind (Bind name args e) = do +checkBind err@(Bind name args e) = do let lambda = makeLambda e (reverse (coerce args)) (_, lambdaT) <- inferExp lambda args <- zip args <$> mapM (const fresh) args withBindings (map coerce args) $ do e@(_, _) <- inferExp e s <- gets sigs - -- let fs = map (second Just) (getFunctionTypes s e) - -- mapM_ (uncurry insertSig) fs case M.lookup (coerce name) s of Just (Just t) -> do - sub <- unify t lambdaT + sub <- bindErr (unify t lambdaT) err let newT = apply sub t insertSig (coerce name) (Just newT) return $ T.Bind (coerce name, newT) (map coerce args) e _ -> do insertSig (coerce name) (Just lambdaT) return (T.Bind (coerce name, lambdaT) (map coerce args) e) -- (apply s e) - -- where - -- getFunctionTypes :: Map T.Ident (Maybe T.Type) -> T.ExpT -> [(T.Ident, T.Type)] - -- getFunctionTypes s = \case - -- (T.EId b, t) -> case M.lookup b s of - -- Just Nothing -> return (b, t) - -- _ -> [] - -- (T.ELit _, _) -> [] - -- (T.ELet (T.Bind _ _ e1) e2, _) -> getFunctionTypes s e1 <> getFunctionTypes s e2 - -- (T.EApp e1 e2, _) -> getFunctionTypes s e1 <> getFunctionTypes s e2 - -- (T.EAdd e1 e2, _) -> getFunctionTypes s e1 <> getFunctionTypes s e2 - -- (T.EAbs _ e, _) -> getFunctionTypes s e - -- (T.ECase e injs, _) -> getFunctionTypes s e <> concatMap (getFunctionTypes s . \(T.Inj _ e) -> e) injs isMoreSpecificOrEq :: T.Type -> T.Type -> Bool isMoreSpecificOrEq _ (T.TAll _ _) = True @@ -292,9 +278,9 @@ algoW = \case err@(EApp e0 e1) -> do fr <- fresh - (s0, (e0', t0)) <- exprErr (algoW e0) err + (s0, (e0', t0)) <- algoW e0 applySt s0 $ do - (s1, (e1', t1)) <- exprErr (algoW e1) err + (s1, (e1', t1)) <- algoW e1 s2 <- exprErr (unify (apply s1 t0) (T.TFun t1 fr)) err let t = apply s2 fr let comp = s2 `compose` s1 `compose` s0 @@ -307,7 +293,7 @@ algoW = \case -- The bar over S₀ and Γ means "generalize" err@(ELet b@(Bind name args e) e1) -> do - (s1, (_, t0)) <- exprErr (algoW (makeLambda e (coerce args))) err + (s1, (_, t0)) <- algoW (makeLambda e (coerce args)) bind' <- exprErr (checkBind b) err env <- asks vars let t' = generalize (apply s1 env) t0 @@ -322,7 +308,7 @@ algoW = \case (subst, injs, ret_t) <- checkCase t injs let comp = subst `compose` sub let t' = apply comp ret_t - return (comp, (T.ECase (e', t) injs, t')) + return (comp, apply comp (T.ECase (e', t) injs, t')) makeLambda :: Exp -> [T.Ident] -> Exp makeLambda = foldl (flip (EAbs . coerce)) @@ -424,13 +410,14 @@ compose m1 m2 = M.map (apply m1) m2 `M.union` m1 -- and one for applying substitutions -- | A class representing free variables functions +class SubstType t where + -- | Apply a substitution to t + apply :: Subst -> t -> t + class FreeVars t where -- | Get all free variables from t free :: t -> Set T.Ident - -- | Apply a substitution to t - apply :: Subst -> t -> t - instance FreeVars T.Type where free :: T.Type -> Set T.Ident free (T.TVar (T.MkTVar a)) = S.singleton a @@ -441,6 +428,7 @@ instance FreeVars T.Type where free (T.TData _ a) = foldl' (\acc x -> free x `S.union` acc) S.empty a +instance SubstType T.Type where apply :: Subst -> T.Type -> T.Type apply sub t = do case t of @@ -453,16 +441,15 @@ instance FreeVars T.Type where Just _ -> apply sub t T.TFun a b -> T.TFun (apply sub a) (apply sub b) T.TData name a -> T.TData name (map (apply sub) a) - instance FreeVars (Map T.Ident T.Type) where free :: Map T.Ident T.Type -> Set T.Ident free m = foldl' S.union S.empty (map free $ M.elems m) + +instance SubstType (Map T.Ident T.Type) where apply :: Subst -> Map T.Ident T.Type -> Map T.Ident T.Type apply s = M.map (apply s) -instance FreeVars T.ExpT where - free :: T.ExpT -> Set T.Ident - free = error "free not implemented for T.Exp" +instance SubstType T.ExpT where apply :: Subst -> T.ExpT -> T.ExpT apply s = \case (T.EId i, outerT) -> (T.EId i, apply s outerT) @@ -476,17 +463,22 @@ instance FreeVars T.ExpT where (T.EApp e1 e2, t) -> (T.EApp (apply s e1) (apply s e2), apply s t) (T.EAdd e1 e2, t) -> (T.EAdd (apply s e1) (apply s e2), apply s t) (T.EAbs ident e, t1) -> (T.EAbs ident (apply s e), apply s t1) - (T.ECase e injs, t) -> (T.ECase (apply s e) (apply s injs), apply s t) + (T.ECase e brnch, t) -> (T.ECase (apply s e) (apply s brnch), apply s t) -instance FreeVars T.Branch where - free :: T.Branch -> Set T.Ident - free = undefined +instance SubstType T.Branch where apply :: Subst -> T.Branch -> T.Branch - apply s (T.Branch (i, t) e) = T.Branch (i, apply s t) (apply s e) + apply s (T.Branch (i, t) e) = T.Branch (apply s i, apply s t) (apply s e) -instance FreeVars [T.Branch] where - free :: [T.Branch] -> Set T.Ident - free = foldl' (\acc x -> free x `S.union` acc) mempty +instance SubstType T.Pattern where + apply :: Subst -> T.Pattern -> T.Pattern + apply s = \case + T.PVar (iden, t) -> T.PVar (iden, apply s t) + T.PLit (lit, t) -> T.PLit (lit, apply s t) + T.PInj i ps -> T.PInj i $ apply s ps + T.PCatch -> T.PCatch + T.PEnum i -> T.PEnum i + +instance SubstType a => SubstType [a] where apply s = map (apply s) -- | Apply substitutions to the environment. @@ -552,8 +544,6 @@ inferBranch (Branch pat expr) = do newExp@(_, exprT) <- withPattern pat (inferExp expr) return (branchT, T.Branch newPat newExp, exprT) --- return (initT, T.Branch (it, initT) (e, exprT), exprT) - withPattern :: T.Pattern -> Infer a -> Infer a withPattern p ma = case p of T.PVar (x, t) -> withBinding x t ma @@ -608,3 +598,7 @@ partitionType = go [] exprErr :: Infer a -> Exp -> Infer a exprErr ma exp = catchError ma (\x -> throwError $ x <> " on expression: " <> printTree exp) + +bindErr :: Infer a -> Bind -> Infer a +bindErr ma exp = + catchError ma (\x -> throwError $ x <> " on expression: " <> printTree exp) diff --git a/src/TypeChecker/TypeCheckerIr.hs b/src/TypeChecker/TypeCheckerIr.hs index be54d35..d14c736 100644 --- a/src/TypeChecker/TypeCheckerIr.hs +++ b/src/TypeChecker/TypeCheckerIr.hs @@ -214,6 +214,7 @@ instance Print Pattern where PLit (lit, typ) -> prPrec i 0 (concatD [doc $ showString "(", prt 0 lit, doc $ showString ",", prt 0 typ, doc $ showString ")"]) PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 0 patterns]) PCatch -> prPrec i 0 (concatD [doc (showString "_")]) + PEnum p -> prt i p instance Print [Branch] where prt _ [] = concatD [] diff --git a/test_program b/test_program index c5d39f6..ee74589 100644 --- a/test_program +++ b/test_program @@ -3,41 +3,39 @@ data List (a) where { Cons : a -> List (a) -> List (a) }; -data Bool () where { - True : Bool () - False : Bool () - }; +-- data Bool () where { +-- True : Bool () +-- False : Bool () +-- }; -hello_world = Cons 'h' (Cons 'e' (Cons 'l' (Cons 'l' (Cons 'o' (Cons ' ' (Cons 'w' (Cons 'o' (Cons 'r' (Cons 'l' (Cons 'd' Nil)))))))))) ; +-- hello_world = Cons 'h' (Cons 'e' (Cons 'l' (Cons 'l' (Cons 'o' (Cons ' ' (Cons 'w' (Cons 'o' (Cons 'r' (Cons 'l' (Cons 'd' Nil)))))))))) ; -length : List (a) -> Int ; -length xs = case xs of { - Nil => 0; - Cons x xs => length xs; -}; +-- length : List (a) -> Int ; +-- length xs = case xs of { +-- Nil => 0; +-- Cons x xs => length xs; +-- }; -head : List (a) -> a ; -head xs = case xs of { - Cons x xs => x; -}; +-- head : List (a) -> a ; +-- head xs = case xs of { +-- Cons x xs => x; +-- }; -firstIsOne : List (Int) -> Bool () ; -firstIsOne xs = case xs of { - Cons x xs => case x of { - 0 => True; - _ => case xs of { - Cons x xs => False; - _ => False; - }; - }; - _ => False; - }; - -main = firstIsOne (Cons 1 Nil); - -deepPat xs = case xs of { - Cons (Nil) _ => True; - _ => False; - }; +-- firstIsOne : List (Int) -> Bool () ; +-- firstIsOne xs = case xs of { +-- Cons x xs => case x of { +-- 0 => True; +-- _ => case xs of { +-- Cons x xs => False; +-- _ => False; +-- }; +-- }; +-- _ => False; +-- }; +-- main = firstIsOne (Cons 1 Nil); +test xs = case xs of { + 1 => 0; + lol => 1; + };