diff --git a/sample-programs/mono-1.crf b/sample-programs/mono-1.crf index 9c0a08f..568c674 100644 --- a/sample-programs/mono-1.crf +++ b/sample-programs/mono-1.crf @@ -1,6 +1,6 @@ -const x y = x +const2 x y = x -f x = (const x 'c') +f x = (const2 x 'c') main = f 5 diff --git a/src/Monomorphizer/Monomorphizer.hs b/src/Monomorphizer/Monomorphizer.hs index c50a7cc..6c851f9 100644 --- a/src/Monomorphizer/Monomorphizer.hs +++ b/src/Monomorphizer/Monomorphizer.hs @@ -91,7 +91,10 @@ isBindMarked ident = gets (Map.member ident) -- | Finds main bind. getMain :: EnvM T.Bind -getMain = asks (\env -> fromJust $ Map.lookup (T.Ident "main") (input env)) +getMain = asks (\env -> case Map.lookup (T.Ident "main") (input env) of + Just mainBind -> mainBind + Nothing -> error "main not found in monomorphizer!" + ) -- | Makes a kv pair list of polymorphic to monomorphic mappings, throws runtime -- error when encountering different structures between the two arguments. @@ -219,30 +222,24 @@ morphBranch (T.Branch (p, pt) (e, et)) = do pt' <- getMonoFromPoly pt et' <- getMonoFromPoly et env <- ask - (p', newLocals) <- morphPattern pt' (locals env) (p, pt) - local (const env { locals = newLocals }) $ do + (p', newLocals) <- morphPattern p pt' + local (const env { locals = Set.union (locals env) newLocals }) $ do e' <- morphExp et' e return $ M.Branch (p', pt') (e', et') --- | Morphs pattern (pattern => expression), gives the newly bound local variables. -morphPattern :: M.Type -> Set.Set Ident -> (T.Pattern, T.Type) -> EnvM (M.Pattern, Set.Set Ident) -morphPattern expectedType ls (p, t) = case p of - T.PVar ident -> do t' <- getMonoFromPoly t - return (M.PVar (ident, t'), Set.insert ident ls) - T.PLit lit -> do t' <- getMonoFromPoly t - return (M.PLit (convertLit lit, t'), ls) - T.PCatch -> return (M.PCatch, ls) - -- Constructor ident - T.PEnum ident -> do morphCons expectedType ident - return (M.PEnum ident, ls) - T.PInj ident ps -> do morphCons expectedType ident - let (M.TData tIdent ts) = expectedType - -- TODO: this is wrong! - pairs <- mapM (\(pat, patT) -> morphPattern patT ls pat) (zip ps ts) - if length ts == length ps then - return (M.PCatch, Set.singleton $ Ident "$1y") - else return (M.PInj ident (map fst pairs), Set.unions (map snd pairs)) - +morphPattern :: T.Pattern -> M.Type -> EnvM (M.Pattern, Set.Set Ident) +morphPattern p expectedType = case p of + T.PVar ident -> return (M.PVar (ident, expectedType), Set.singleton ident) + T.PLit lit -> return (M.PLit (convertLit lit, expectedType), Set.empty) + T.PCatch -> return (M.PCatch, Set.empty) + T.PEnum ident -> do morphCons expectedType ident + return (M.PEnum ident, Set.empty) + T.PInj ident pts -> do morphCons expectedType ident + ts' <- mapM (getMonoFromPoly . snd) pts + let pts' = zip (map fst pts) ts' + psSets <- mapM (uncurry morphPattern) pts' + return (M.PInj ident (map fst psSets), Set.unions $ map snd psSets) + -- | Creates a new identifier for a function with an assigned type. newFuncName :: M.Type -> T.Bind -> Ident newFuncName t (T.Bind (ident@(Ident bindName), _) _ _) =