Incorporated most of main, as well as started on quickcheck

This commit is contained in:
sebastianselander 2023-02-27 11:12:05 +01:00
parent 06e65de235
commit 2f45f39435
19 changed files with 1252 additions and 1090 deletions

View file

@ -1,27 +1,54 @@
Program. Program ::= [Bind] ;
Program. Program ::= [Def] ;
DBind. Def ::= Bind ;
DData. Def ::= Data ;
terminator Def ";" ;
Bind. Bind ::= Ident ":" Type ";"
Ident [Ident] "=" Exp ;
EAnn. Exp5 ::= "(" Exp ":" Type ")" ;
EId. Exp4 ::= Ident ;
EInt. Exp4 ::= Integer ;
ELit. Exp4 ::= Literal ;
EApp. Exp3 ::= Exp3 Exp4 ;
EAdd. Exp1 ::= Exp1 "+" Exp2 ;
ELet. Exp ::= "let" Ident "=" Exp "in" Exp ;
ELet. Exp ::= "let" Ident "=" Exp "in" Exp ;
EAbs. Exp ::= "\\" Ident "." Exp ;
ECase. Exp ::= "case" Exp "of" "{" [Inj] "}";
TMono. Type1 ::= "Mono" Ident ;
TPol. Type1 ::= "Poly" Ident ;
LInt. Literal ::= Integer ;
Inj. Inj ::= Init "=>" Exp ;
terminator Inj ";" ;
InitLit. Init ::= Literal ;
InitConstr. Init ::= Ident [Match] ;
InitCatch. Init ::= "_" ;
LMatch. Match ::= Literal ;
IMatch. Match ::= Ident ;
InitMatch. Match ::= Ident Match ;
separator Match " " ;
TMono. Type1 ::= "_" Ident ;
TPol. Type1 ::= "'" Ident ;
TArr. Type ::= Type1 "->" Type ;
separator Type " " ;
-- shift/reduce problem here
Data. Data ::= "data" Ident [Type] "where" ";"
[Constructor];
terminator Constructor ";" ;
Constructor. Constructor ::= Ident ":" Type ;
-- This doesn't seem to work so we'll have to live with ugly keywords for now
-- token Upper (upper (letter | digit | '_')*) ;
-- token Lower (lower (letter | digit | '_')*) ;
-- token Poly upper (letter | digit | '_')* ;
-- token Mono lower (letter | digit | '_')* ;
separator Bind ";" ;
terminator Bind ";" ;
separator Ident " ";
coercions Type 1 ;
@ -29,3 +56,4 @@ coercions Exp 5 ;
comment "--" ;
comment "{-" "-}" ;

View file

@ -16,7 +16,7 @@ extra-source-files:
Grammar.cf
common warnings
ghc-options: -Wdefault
ghc-options: -W
executable language
import: warnings
@ -31,15 +31,12 @@ executable language
Grammar.Skel
Grammar.ErrM
Auxiliary
-- TypeChecker.TypeChecker
-- TypeChecker.TypeCheckerIr
-- TypeChecker.Unification
TypeChecker.HM
TypeChecker.AlgoW
TypeChecker.HMIr
Renamer.RenamerM
-- Renamer.Renamer
-- Renamer.RenamerIr
TypeChecker.TypeChecker
TypeChecker.TypeCheckerIr
Renamer.Renamer
LambdaLifter.LambdaLifter
Codegen.Codegen
Codegen.LlvmIr
hs-source-dirs: src
@ -50,34 +47,35 @@ executable language
, either
, extra
, array
, QuickCheck
default-language: GHC2021
test-suite test
hs-source-dirs: tests, src
main-is: Main.hs
type: exitcode-stdio-1.0
Test-suite language-testsuite
type: exitcode-stdio-1.0
main-is: Tests.hs
other-modules:
Grammar.Abs
Grammar.Lex
Grammar.Par
Grammar.Print
Grammar.Skel
Grammar.ErrM
Auxiliary
Renamer.RenamerM
TypeChecker.AlgoW
TypeChecker.HM
TypeChecker.HMIr
other-modules:
Grammar.Abs
Grammar.Lex
Grammar.Par
Grammar.Print
Grammar.Skel
Grammar.ErrM
Auxiliary
TypeChecker.TypeChecker
TypeChecker.TypeCheckerIr
Renamer.Renamer
build-depends:
base >=4.16
, mtl
, containers
, either
, array
, extra
, hspec
hs-source-dirs: src, tests
default-language: GHC2021
build-depends:
base >=4.16
, mtl
, containers
, either
, extra
, array
, QuickCheck
default-language: GHC2021

277
src/Codegen/Codegen.hs Normal file
View file

@ -0,0 +1,277 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Codegen.Codegen (compile) where
import Auxiliary (snoc)
import Codegen.LlvmIr (LLVMIr (..), LLVMType (..),
LLVMValue (..), Visibility (..),
llvmIrToString)
import Control.Monad.State (StateT, execStateT, gets, modify)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Tuple.Extra (dupe, first, second)
import Grammar.ErrM (Err)
import TypeChecker.TypeChecker
import TypeChecker.TypeCheckerIr
-- | The record used as the code generator state.
-- A value of this record is threaded through 'CompilerState' while compiling.
data CodeGenerator = CodeGenerator
{ instructions :: [LLVMIr]
-- ^ Instructions emitted so far; 'emit' adds new ones with 'snoc'.
, functions :: Map Id FunctionInfo
-- ^ Information about the top-level binds, built by 'getFunctions'.
, variableCount :: Integer
-- ^ Counter used to number fresh variables; 'compileScs' resets it to 0 after each bind.
}
-- | A state type synonym: the code generator monad, i.e. 'CodeGenerator'
-- state layered over the 'Err' monad.
type CompilerState a = StateT CodeGenerator Err a
-- | Data recorded by 'getFunctions' for every top-level bind.
data FunctionInfo = FunctionInfo
{ numArgs :: Int
-- ^ Number of parameters of the bind (the length of 'arguments').
, arguments :: [Id]
-- ^ The parameters of the bind.
}
-- | Adds an instruction to the instruction list of the CodeGenerator state
emit :: LLVMIr -> CompilerState ()
emit ins = modify addIns
  where
    addIns st = st { instructions = snoc ins (instructions st) }
-- | Bumps the variable counter in the CodeGenerator state by one
increaseVarCount :: CompilerState ()
increaseVarCount = modify bump
  where
    bump st = st { variableCount = succ (variableCount st) }
-- | Returns the variable count from the CodeGenerator state.
-- This only reads the counter; it does not change it.
getVarCount :: CompilerState Integer
getVarCount = gets variableCount
-- | Bumps the variable counter first and then returns the updated value
getNewVar :: CompilerState Integer
getNewVar = do
    increaseVarCount
    getVarCount
-- | Builds the table of function infos from a list of binds.
-- Each bind contributes its identifier as key, together with its
-- parameters and how many of them there are.
getFunctions :: [Bind] -> Map Id FunctionInfo
getFunctions binds =
    Map.fromList
        [ (name, FunctionInfo { numArgs = length params, arguments = params })
        | Bind name params _ <- binds
        ]
-- | Initial code generator state: the fixed module prelude as instructions,
-- the function table of the given binds and a variable counter of 0.
initCodeGenerator :: [Bind] -> CodeGenerator
initCodeGenerator binds = CodeGenerator prelude table 0
  where
    prelude = defaultStart
    table = getFunctions binds
-- | Compiles an AST and produces an LLVM IR string.
-- An easy way to actually "compile" this output is to
-- simply pipe it to lli.
compile :: Program -> Err String
compile (Program binds) =
    render <$> execStateT (compileScs binds) (initCodeGenerator binds)
  where
    render = llvmIrToString . instructions
-- | Compiles every top-level bind into an LLVM function definition.
--
-- For each bind this emits a comment, the @define@ header, the instructions
-- computing the body, the final @ret@ and the closing brace. The variable
-- counter is reset to 0 afterwards, so numbering restarts for the next bind.
compileScs :: [Bind] -> CompilerState ()
compileScs [] = pure ()
compileScs (Bind (name, t) args body : rest) = do
    emit $ UnsafeRaw "\n"
    emit . Comment $ show name <> ": " <> show body
    emit $ Define retType name (map (second type2LlvmType) args)
    result <- exprToValue body
    if name == "main"
        then mapM_ emit $ mainContent result
        -- Return with the same type the @define@ header declared. A
        -- hard-coded i64 disagrees with the header whenever the bind
        -- returns a function value.
        else emit $ Ret retType result
    emit DefineEnd
    modify $ \s -> s { variableCount = 0 }
    compileScs rest
  where
    -- LLVM type of what is left of @t@ once all parameters are applied.
    retType :: LLVMType
    retType = type2LlvmType . snd $ partitionType (length args) t

    -- @main@ prints its result through printf and then returns 0.
    mainContent :: LLVMValue -> [LLVMIr]
    mainContent var =
        [ UnsafeRaw $
            "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n"
        , Ret I64 (VInteger 0)
        ]
-- | Instructions every generated module starts with: the printf format
-- string referenced as @\@.str@ by @main@, and the declaration of printf.
defaultStart :: [LLVMIr]
defaultStart =
    [ -- printf needs a NUL-terminated format string, so the constant is
      -- written with explicit LLVM escapes: \0A for the newline and \00 for
      -- the terminator, four bytes in total.
      UnsafeRaw "@.str = private unnamed_addr constant [4 x i8] c\"%i\\0A\\00\", align 1\n"
    , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
    ]
-- | Emits the instructions for an expression by dispatching on its
-- constructor to the matching emitter.
compileExp :: Exp -> CompilerState ()
compileExp (ELit _ (LInt i)) = emitInt i
compileExp (EAdd t e1 e2) = emitAdd t e1 e2
compileExp (EId (name, _)) = emitIdent name
compileExp (EApp t e1 e2) = emitApp t e1 e2
compileExp (EAbs t ti e) = emitAbs t ti e
compileExp (ELet bind e) = emitLet bind e
--- aux functions ---
-- | A lambda reaching code generation is only reported as a comment.
emitAbs :: Type -> Id -> Exp -> CompilerState ()
emitAbs _ param body =
    emit $ Comment (concat ["Lambda escaped previous stages: \\", show param, " . ", show body])
-- | Let expressions are not compiled; a comment describing the
-- expression is emitted instead.
emitLet :: Bind -> Exp -> CompilerState ()
emitLet bind body =
    emit . Comment $
        "ELet (" <> show bind <> " = " <> show body <> ") is not implemented!"
-- | Emits a call for a (possibly nested) application.
-- The application spine is walked down to its head while the arguments are
-- collected in call order. If the head is an identifier, a call to it is
-- emitted and its result stored in a fresh variable; any other head is only
-- reported through comments.
emitApp :: Type -> Exp -> Exp -> CompilerState ()
emitApp t fun arg = go fun [arg]
  where
    go :: Exp -> [Exp] -> CompilerState ()
    go (EApp _ f a) pending = go f (a : pending)
    go (EId fid@(name, _)) pending = do
        args <- traverse exprToValue pending
        vs <- getNewVar
        funcs <- gets functions
        let -- names found in the function table are global, all others local
            visibility = if Map.member fid funcs then Global else Local
            typedArgs = [ (valueGetType v, v) | v <- args ]
            call = Call (type2LlvmType t) visibility name typedArgs
        emit $ SetVariable (Ident $ show vs) call
    go other _ = do
        emit . Comment $ "The unspeakable happened: "
        emit . Comment $ show other
-- | Emits a bare variable reference preceded by a warning comment.
-- !!this should never happen!!
emitIdent :: Ident -> CompilerState ()
emitIdent ident =
    mapM_ emit
        [ Comment "This should not have happened!"
        , Variable ident
        , UnsafeRaw "\n"
        ]
-- | Stores an integer literal in a fresh variable by adding 0 to it,
-- preceded by a warning comment.
-- !!this should never happen!!
emitInt :: Integer -> CompilerState ()
emitInt n = do
    fresh <- getNewVar
    emit $ Comment "This should not have happened!"
    let target = Ident (show fresh)
    emit $ SetVariable target (Add I64 (VInteger n) (VInteger 0))
-- | Evaluates both operands (left first) and stores their sum in a
-- fresh variable.
emitAdd :: Type -> Exp -> Exp -> CompilerState ()
emitAdd t lhs rhs = do
    l <- exprToValue lhs
    r <- exprToValue rhs
    fresh <- getNewVar
    let sumIns = Add (type2LlvmType t) l r
    emit $ SetVariable (Ident (show fresh)) sumIns
-- emitMul :: Exp -> Exp -> CompilerState ()
-- emitMul e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Mul I64 v1 v2
-- emitMod :: Exp -> Exp -> CompilerState ()
-- emitMod e1 e2 = do
-- -- `let m a b = rem (abs $ b + a) b`
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- vadd <- gets variableCount
-- emit $ SetVariable $ Ident $ show vadd
-- emit $ Add I64 v1 v2
--
-- increaseVarCount
-- vabs <- gets variableCount
-- emit $ SetVariable $ Ident $ show vabs
-- emit $ Call I64 (Ident "llvm.abs.i64")
-- [ (I64, VIdent (Ident $ show vadd))
-- , (I1, VInteger 1)
-- ]
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2
-- emitDiv :: Exp -> Exp -> CompilerState ()
-- emitDiv e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Div I64 v1 v2
-- emitSub :: Exp -> Exp -> CompilerState ()
-- emitSub e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Sub I64 v1 v2
-- | Turns an expression into an LLVM value, emitting instructions if needed.
--
-- * Integer literals become immediate values.
-- * An identifier found in the function table is called right away when it
--   takes no arguments, otherwise it is returned as a global function value.
--   An identifier not in the table is returned as a plain variable.
-- * Every other expression is compiled, and the variable numbered by the
--   current counter is returned as its value.
exprToValue :: Exp -> CompilerState LLVMValue
exprToValue (ELit _ (LInt i)) = pure (VInteger i)
exprToValue (EId ident@(name, t)) = do
    known <- gets (Map.lookup ident . functions)
    case known of
        Nothing -> pure $ VIdent name ty
        Just info
            | numArgs info == 0 -> do
                vc <- getNewVar
                let target = Ident (show vc)
                emit $ SetVariable target (Call ty Global name [])
                pure $ VIdent target ty
            | otherwise -> pure $ VFunction name Global ty
  where
    ty = type2LlvmType t
exprToValue e = do
    compileExp e
    v <- getVarCount
    pure $ VIdent (Ident $ show v) (getType e)
-- | Translates a source type into the LLVM type used to represent it.
--
-- A function type @a -> b -> r@ becomes @Function r [a, b]@: the final
-- result type together with the parameter types in declaration order.
-- Every other type is currently represented as @i64@.
type2LlvmType :: Type -> LLVMType
type2LlvmType = \case
    TMono "Int" -> I64
    TArr param rest ->
        let (ret, revParams) = function2LLVMType rest [type2LlvmType param]
        -- The helper accumulates parameters by prepending, so the list is
        -- reversed here to restore declaration order.
        in Function ret (reverse revParams)
    _ -> I64 --CustomType $ Ident ("\"" ++ show t ++ "\"")
  where
    -- Walks down the arrows: returns the final result type and the
    -- parameter types seen so far, most recent first.
    function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
    function2LLVMType (TArr t xs) s = function2LLVMType xs (type2LlvmType t : s)
    function2LLVMType x s = (type2LlvmType x, s)
-- | LLVM type of an expression, read from its type annotation.
-- Integer literals are @i64@; a let has the type of its body.
getType :: Exp -> LLVMType
getType = \case
    ELit _ (LInt _) -> I64
    EAdd t _ _ -> type2LlvmType t
    EId (_, t) -> type2LlvmType t
    EApp t _ _ -> type2LlvmType t
    EAbs t _ _ -> type2LlvmType t
    ELet _ body -> getType body
-- | LLVM type of a value: @i64@ for integers, an @i8@ array as long as the
-- string for constants, and the stored type for identifiers and functions.
valueGetType :: LLVMValue -> LLVMType
valueGetType = \case
    VInteger _ -> I64
    VIdent _ ty -> ty
    VConstant str -> Array (length str) I8
    VFunction _ _ ty -> ty
-- | Partition a type into the types of its parameters and the return type.
-- Fails with 'error' when the type has fewer arrows than requested.
partitionType :: Int -- Number of parameters to apply
              -> Type
              -> ([Type], Type)
partitionType = go []
  where
    go params 0 ty = (params, ty)
    go params n (TArr dom cod) = go (snoc dom params) (n - 1) cod
    go _ _ _ = error "Number of parameters and type doesn't match"

204
src/Codegen/LlvmIr.hs Normal file
View file

@ -0,0 +1,204 @@
{-# LANGUAGE LambdaCase #-}
module Codegen.LlvmIr (
LLVMType (..),
LLVMIr (..),
llvmIrToString,
LLVMValue (..),
LLVMComp (..),
Visibility (..),
) where
import Data.List (intercalate)
import TypeChecker.TypeCheckerIr
-- | A datatype which represents some basic LLVM types
data LLVMType
    = I1
    | I8
    | I32
    | I64
    | Ptr
    | Ref LLVMType
    | Function LLVMType [LLVMType]
    | Array Int LLVMType
    | CustomType Ident

-- | Renders a type the way it is written in LLVM IR text.
instance Show LLVMType where
    show I1 = "i1"
    show I8 = "i8"
    show I32 = "i32"
    show I64 = "i64"
    show Ptr = "ptr"
    show (Ref ty) = show ty <> "*"
    show (Function ret params) =
        concat [show ret, " (", intercalate ", " (map show params), ")*"]
    show (Array n ty) = "[" <> show n <> " x " <> show ty <> "]"
    show (CustomType (Ident ty)) = ty
-- | Comparison kinds accepted by 'Icmp'
data LLVMComp
    = LLEq
    | LLNe
    | LLUgt
    | LLUge
    | LLUlt
    | LLUle
    | LLSgt
    | LLSge
    | LLSlt
    | LLSle

-- | Renders the comparison as its lowercase keyword.
instance Show LLVMComp where
    show LLEq = "eq"
    show LLNe = "ne"
    show LLUgt = "ugt"
    show LLUge = "uge"
    show LLUlt = "ult"
    show LLUle = "ule"
    show LLSgt = "sgt"
    show LLSge = "sge"
    show LLSlt = "slt"
    show LLSle = "sle"
-- | Whether a name is local or global
data Visibility = Local | Global

-- | Renders the sigil put in front of a name: @%@ for local, @\@@ for global.
instance Show Visibility where
    show vis = case vis of
        Local -> "%"
        Global -> "@"
-- | Represents an LLVM "value", as in an integer, a register variable,
-- a string constant or a function
data LLVMValue
    = VInteger Integer
    | VIdent Ident LLVMType
    | VConstant String
    | VFunction Ident Visibility LLVMType

-- | Renders a value as it appears as an operand; the stored type is not printed.
instance Show LLVMValue where
    show (VInteger i) = show i
    show (VIdent (Ident n) _) = '%' : n
    show (VFunction (Ident n) vis _) = show vis <> n
    show (VConstant s) = 'c' : show s
-- | Parameters of a function: each name paired with its type.
type Params = [(Ident, LLVMType)]
-- | Arguments of a call: each value paired with the type it is passed as.
type Args = [(LLVMType, LLVMValue)]
-- | A datatype which represents different instructions in LLVM.
-- 'insToString' defines the textual form of each constructor.
data LLVMIr
= Define LLVMType Ident Params
-- ^ Opens a function definition; closed by 'DefineEnd'.
| DefineEnd
| Declare LLVMType Ident Params
| SetVariable Ident LLVMIr
-- ^ Assigns the result of the nested instruction to a local variable.
| Variable Ident
| Add LLVMType LLVMValue LLVMValue
| Sub LLVMType LLVMValue LLVMValue
| Div LLVMType LLVMValue LLVMValue
| Mul LLVMType LLVMValue LLVMValue
| Srem LLVMType LLVMValue LLVMValue
| Icmp LLVMComp LLVMType LLVMValue LLVMValue
| Br Ident
| BrCond LLVMValue Ident Ident
| Label Ident
| Call LLVMType Visibility Ident Args
| Alloca LLVMType
| Store LLVMType Ident LLVMType Ident
| Bitcast LLVMType Ident LLVMType
| Ret LLVMType LLVMValue
| Comment String
| UnsafeRaw String -- This should generally be avoided, and proper
-- instructions should be used in its place
deriving (Show)
-- | Converts a list of LLVMIr instructions to a string.
-- Instructions between a 'Define' and its 'DefineEnd' are indented one
-- level; the 'Define' and 'DefineEnd' lines themselves are not indented.
llvmIrToString :: [LLVMIr] -> String
llvmIrToString = go 0
  where
    go :: Int -> [LLVMIr] -> String
    go _ [] = mempty
    go depth (ins : rest) = insToString indent ins <> go depth' rest
      where
        -- (nesting depth for what follows, indentation of this line)
        (depth', indent) = case ins of
            Define{} -> (depth + 1, 0)
            DefineEnd -> (depth - 1, 0)
            _ -> (depth, depth)
-- | Converts an LLVM instruction to a String, allowing for printing etc.
-- The integer is the indentation, in tabs, put in front of the instruction.
--
-- Every instruction ends with a newline except 'Variable' and 'UnsafeRaw',
-- which are rendered exactly as given. 'Declare' renders a declaration that
-- lists only the parameter types.
insToString :: Int -> LLVMIr -> String
insToString i l =
    replicate i '\t' <> case l of
        Define t (Ident name) params ->
            concat
                [ "define ", show t, " @", name
                , "(", intercalate ", " (map param params)
                , ") {\n"
                ]
        DefineEnd -> "}\n"
        Declare t (Ident name) params ->
            concat
                [ "declare ", show t, " @", name
                , "(", intercalate ", " (map (show . snd) params)
                , ")\n"
                ]
        -- The nested instruction supplies the trailing newline itself.
        SetVariable (Ident name) ir -> concat ["%", name, " = ", insToString 0 ir]
        Add t v1 v2 -> binOp "add" t v1 v2
        Sub t v1 v2 -> binOp "sub" t v1 v2
        Div t v1 v2 -> binOp "sdiv" t v1 v2
        Mul t v1 v2 -> binOp "mul" t v1 v2
        Srem t v1 v2 -> binOp "srem" t v1 v2
        Call t vis (Ident name) args ->
            concat
                [ "call ", show t, " ", show vis, name, "("
                , intercalate ", " (map (\(ty, v) -> show ty <> " " <> show v) args)
                , ")\n"
                ]
        Alloca t -> unwords ["alloca", show t, "\n"]
        Store t1 (Ident id1) t2 (Ident id2) ->
            concat
                [ "store ", show t1, " %", id1
                , ", ", show t2, " %", id2, "\n"
                ]
        Bitcast t1 (Ident name) t2 ->
            concat
                [ "bitcast ", show t1, " %"
                , name, " to ", show t2, "\n"
                ]
        Icmp comp t v1 v2 ->
            concat
                [ "icmp ", show comp, " ", show t
                , " ", show v1, ", ", show v2, "\n"
                ]
        Ret t v -> concat ["ret ", show t, " ", show v, "\n"]
        UnsafeRaw s -> s
        Label (Ident s) -> "\nlabel_" <> s <> ":\n"
        Br (Ident s) -> "br label %label_" <> s <> "\n"
        BrCond val (Ident s1) (Ident s2) ->
            concat
                [ "br i1 ", show val, ", ", "label %"
                , "label_", s1, ", ", "label %", "label_", s2, "\n"
                ]
        Comment s -> "; " <> s <> "\n"
        Variable (Ident name) -> "%" <> name
  where
    -- One parameter of a definition: its type followed by its name.
    param :: (Ident, LLVMType) -> String
    param (Ident name, ty) = unwords [show ty, "%" <> name]

    -- A two-operand arithmetic instruction: @op type lhs, rhs@.
    binOp :: String -> LLVMType -> LLVMValue -> LLVMValue -> String
    binOp op t v1 v2 = concat [op, " ", show t, " ", show v1, ", ", show v2, "\n"]

View file

View file

@ -1,78 +0,0 @@
{-# LANGUAGE LambdaCase #-}
module Interpreter where
import Control.Applicative (Applicative)
import Control.Monad.Except (Except, MonadError (throwError),
liftEither)
import Data.Either.Combinators (maybeToRight)
import Data.Map (Map)
import qualified Data.Map as Map
import Grammar.Abs
import Grammar.Print (printTree)
interpret :: Program -> Except String Integer
interpret (Program e) =
eval mempty e >>= \case
VClosure {} -> throwError "main evaluated to a function"
VInt i -> pure i
data Val = VInt Integer
| VClosure Cxt Ident Exp
type Cxt = Map Ident Val
-- | Evaluates an expression to a value in the given context.
-- The rule written above each case is the evaluation rule it implements.
-- Evaluation fails with a message on unbound variables, on applying an
-- integer and on adding something that is not an integer.
eval :: Cxt -> Exp -> Except String Val
eval cxt = \case
-- ------------ x ∈ γ
-- γ ⊢ x ⇓ γ(x)
EId x ->
maybeToRightM
("Unbound variable:" ++ printTree x)
$ Map.lookup x cxt
-- ---------
-- γ ⊢ i ⇓ i
EInt i -> pure $ VInt i
-- γ ⊢ e ⇓ let δ in λx. f
-- γ ⊢ e₁ ⇓ v
-- δ,x=v ⊢ f ⇓ v₁
-- ------------------------------
-- γ ⊢ e e₁ ⇓ v₁
EApp e e1 ->
eval cxt e >>= \case
VInt _ -> throwError "Not a function"
VClosure delta x f -> do
v <- eval cxt e1
eval (Map.insert x v delta) f
--
-- -----------------------------
-- γ ⊢ λx. f ⇓ let γ in λx. f
EAbs x e -> pure $ VClosure cxt x e
-- γ ⊢ e ⇓ v
-- γ ⊢ e₁ ⇓ v₁
-- ------------------
-- γ ⊢ e + e₁ ⇓ v + v₁
EAdd e e1 -> do
v <- eval cxt e
v1 <- eval cxt e1
case (v, v1) of
(VInt i, VInt i1) -> pure $ VInt (i + i1)
_ -> throwError "Can't add a function"
-- | Unwraps a 'Maybe' inside an error monad: 'Nothing' throws the given
-- error, 'Just' returns the contained value.
maybeToRightM :: MonadError l m => l -> Maybe r -> m r
maybeToRightM err = maybe (throwError err) pure

View file

@ -0,0 +1,192 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module LambdaLifter.LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where
import Auxiliary (snoc)
import Control.Applicative (Applicative (liftA2))
import Control.Monad.State (MonadState (get, put), State,
evalState)
import Data.Set (Set)
import qualified Data.Set as Set
import Prelude hiding (exp)
import Renamer.Renamer
import TypeChecker.TypeCheckerIr
-- | Lift lambdas and let expressions into supercombinators.
-- Three phases:
-- @freeVars@ annotates all the free variables.
-- @abstract@ converts lambdas into let expressions.
-- @collectScs@ moves every non-constant let expression to a top-level function.
lambdaLift :: Program -> Program
lambdaLift prog = collectScs (abstract (freeVars prog))
-- | Annotate free variables.
-- The parameters of each bind are the local variables its body starts with.
freeVars :: Program -> AnnProgram
freeVars (Program ds) = map annotate ds
  where
    annotate (Bind n xs e) = (n, xs, freeVarsExp (Set.fromList xs) e)
-- | Annotates every node of an expression with its free variables.
-- Only identifiers that are members of the given set of local variables
-- count as free; a lambda adds its parameter and a let adds the bound name
-- to that set for the expressions in their scope.
freeVarsExp :: Set Id -> Exp -> AnnExp
freeVarsExp localVars expr = case expr of
    EId n ->
        (if Set.member n localVars then Set.singleton n else Set.empty, AId n)
    ELit _ (LInt i) -> (Set.empty, AInt i)
    EApp t l r -> binary (AApp t) l r
    EAdd t l r -> binary (AAdd t) l r
    EAbs t par body ->
        let body' = freeVarsExp (Set.insert par localVars) body
        in (Set.delete par (freeVarsOf body'), AAbs t par body')
    -- Sum free variables present in bind and the expression
    ELet (Bind name parms rhs) body ->
        let scope = Set.insert name localVars
            rhs' = freeVarsExp scope rhs
            body' = freeVarsExp scope body
            rhsFrees = Set.delete name (freeVarsOf rhs')
            bodyFrees = Set.delete name (freeVarsOf body')
        in (Set.union rhsFrees bodyFrees, ALet (ABind name parms rhs') body')
  where
    -- Shared by application and addition: union of both operands' frees.
    binary mk l r =
        let l' = freeVarsExp localVars l
            r' = freeVarsExp localVars r
        in (Set.union (freeVarsOf l') (freeVarsOf r'), mk l' r')
-- | The free-variable annotation of an annotated expression
freeVarsOf :: AnnExp -> Set Id
freeVarsOf (frees, _) = frees
-- AST annotated with free variables
-- | One entry per top-level bind: name, parameters and annotated body.
type AnnProgram = [(Id, [Id], AnnExp)]
-- | An expression paired with the set of its free variables.
type AnnExp = (Set Id, AnnExp')
-- | Annotated counterpart of a bind: name, parameters and annotated right-hand side.
data ABind = ABind Id [Id] AnnExp deriving Show
-- | Expression nodes whose subexpressions carry free-variable annotations.
data AnnExp' = AId Id
| AInt Integer
| ALet ABind AnnExp
| AApp Type AnnExp AnnExp
| AAdd Type AnnExp AnnExp
| AAbs Type Id AnnExp
deriving Show
-- | Lift lambdas to let expressions of the form @let sc = \v₁ x₁ -> e₁@,
-- where the free variables @v₁ v₂ .. vₙ@ are bound as parameters.
-- Lambdas at the very top of a bind's body become parameters of the bind.
abstract :: AnnProgram -> Program
abstract prog = Program (evalState (traverse liftBind prog) 0)
  where
    liftBind :: (Id, [Id], AnnExp) -> State Int Bind
    liftBind (name, parms, rhs) = do
        let (body, lambdaParms) = flattenLambdasAnn rhs
        body' <- abstractExp body
        pure $ Bind name (parms ++ lambdaParms) body'
-- | Flatten nested lambdas and collect the parameters
-- @\x.\y.\z. ae → (ae, [x,y,z])@
-- Each collected parameter is removed from the free variables of the body
-- that remains.
flattenLambdasAnn :: AnnExp -> (AnnExp, [Id])
flattenLambdasAnn = go []
  where
    go :: [Id] -> AnnExp -> (AnnExp, [Id])
    go acc (_, AAbs _ par (free1, e1)) = go (snoc par acc) (Set.delete par free1, e1)
    go acc ae = (ae, acc)
-- | Converts an annotated expression back into a plain expression, turning
-- every lambda into a let-bound function named @sc_<n>@ that is applied to
-- the lambda's free variables. The state supplies the numbers for the names.
abstractExp :: AnnExp -> State Int Exp
abstractExp (free, exp) = case exp of
AId n -> pure $ EId n
AInt i -> pure $ ELit (TMono "Int") (LInt i)
AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2)
AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2)
-- The lambdas at the top of a let-bound right-hand side are not lifted:
-- skipLambdas keeps them as lambdas and flattenLambdas then turns them
-- into parameters of the bind.
ALet b e -> liftA2 ELet (go b) (abstractExp e)
where
go (ABind name parms rhs) = do
(rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
pure $ Bind name (parms ++ parms1) rhs'
-- Rebuilds the leading lambdas unchanged and applies the given function
-- to the first expression below them that is not a lambda.
skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp
skipLambdas f (free, ae) = case ae of
AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1
_ -> f (free, ae)
-- Lift lambda into let and bind free variables
AAbs t parm e -> do
i <- nextNumber
rhs <- abstractExp e
let sc_name = Ident ("sc_" ++ show i)
sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t)
-- NOTE(review): every application built here is annotated with
-- TMono "Int" regardless of the lambda's own type t; confirm that
-- later stages do not rely on this annotation.
pure $ foldl (EApp $ TMono "Int") sc $ map EId freeList
where
freeList = Set.toList free
-- parameters of the new function: free variables plus the lambda parameter
parms = snoc parm freeList
-- | Returns the current counter value and advances the counter by one
nextNumber :: State Int Int
nextNumber = get >>= \current -> current <$ put (succ current)
-- | Collects supercombinators by lifting non-constant let expressions.
-- Every bind is followed by the supercombinators lifted out of its body.
collectScs :: Program -> Program
collectScs (Program scs) = Program (concatMap liftBind scs)
  where
    liftBind (Bind name parms rhs) = Bind name parms rhs' : lifted
      where
        (lifted, rhs') = collectScsExp rhs
-- | Splits an expression into the binds lifted out of it and the expression
-- that remains. A let whose bind has parameters is removed and its bind
-- lifted; a let without parameters stays in place.
collectScsExp :: Exp -> ([Bind], Exp)
collectScsExp expr = case expr of
    EId n -> ([], EId n)
    ELit _ (LInt i) -> ([], ELit (TMono "Int") (LInt i))
    EApp t l r -> binary (EApp t) l r
    EAdd t l r -> binary (EAdd t) l r
    EAbs t par body ->
        let (scs, body') = collectScsExp body
        in (scs, EAbs t par body')
    -- Collect supercombinators from bind, the rhss, and the expression.
    --
    -- > f = let sc x y = rhs in e
    --
    ELet (Bind name parms rhs) body ->
        let (rhs_scs, rhs') = collectScsExp rhs
            (body_scs, body') = collectScsExp body
            bind = Bind name parms rhs'
        in if null parms
            then (rhs_scs ++ body_scs, ELet bind body')
            else (bind : rhs_scs ++ body_scs, body')
  where
    -- Shared by application and addition: binds of the left operand first.
    binary mk l r =
        let (scsL, l') = collectScsExp l
            (scsR, r') = collectScsExp r
        in (scsL ++ scsR, mk l' r')
-- | Strips the leading lambdas of an expression and returns the remaining
-- body together with their parameters: @\x.\y.\z. e → (e, [x,y,z])@
flattenLambdas :: Exp -> (Exp, [Id])
flattenLambdas = go []
  where
    go acc (EAbs _ par body) = go (snoc par acc) body
    go acc e = (e, acc)

View file

@ -2,42 +2,81 @@
module Main where
import Grammar.Par (myLexer, pProgram)
-- import TypeChecker.TypeChecker (typecheck)
import Codegen.Codegen (compile)
import GHC.IO.Handle.Text (hPutStrLn)
import Grammar.ErrM (Err)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
import Grammar.Print (printTree)
import Renamer.RenamerM (rename)
import System.Environment (getArgs)
import System.Exit (exitFailure, exitSuccess)
import TypeChecker.AlgoW (typecheck)
import LambdaLifter.LambdaLifter (lambdaLift)
import Renamer.Renamer (rename)
import System.Environment (getArgs)
import System.Exit (exitFailure, exitSuccess)
import System.IO (stderr)
import TypeChecker.TypeChecker (typecheck)
main :: IO ()
main = getArgs >>= \case
main =
getArgs >>= \case
[] -> print "Required file path missing"
(x : _) -> do
file <- readFile x
case pProgram (myLexer file) of
Left err -> do
putStrLn "SYNTAX ERROR"
putStrLn err
exitFailure
Right prg -> do
putStrLn ""
putStrLn " ----- PARSER ----- "
putStrLn ""
putStrLn . printTree $ prg
case typecheck (rename prg) of
Left err -> do
putStrLn "TYPECHECK ERROR"
print err
exitFailure
Right prg -> do
putStrLn ""
putStrLn " ----- RAW ----- "
putStrLn ""
print prg
putStrLn ""
putStrLn " ----- TYPECHECKER ----- "
putStrLn ""
putStrLn $ printTree prg
exitSuccess
(s : _) -> main' s
-- | Runs the whole pipeline on the file at the given path: parse, rename,
-- type check, lambda lift and compile. The tree after each stage is printed
-- to stderr; the compiled output is printed to stdout and also written to
-- the file @llvm.ll@. A failing stage ends the program through the matching
-- @from...Err@ helper.
main' :: String -> IO ()
main' s = do
file <- readFile s
printToErr "-- Parse Tree -- "
parsed <- fromSyntaxErr . pProgram $ myLexer file
printToErr $ printTree parsed
printToErr "\n-- Renamer --"
let renamed = rename parsed
printToErr $ printTree renamed
printToErr "\n-- TypeChecker --"
typechecked <- fromTypeCheckerErr $ typecheck renamed
printToErr $ printTree typechecked
printToErr "\n-- Lambda Lifter --"
let lifted = lambdaLift typechecked
printToErr $ printTree lifted
printToErr "\n -- Printing compiler output to stdout --"
compiled <- fromCompilerErr $ compile lifted
-- Only the compiled program goes to stdout, so it can be piped onwards.
putStrLn compiled
writeFile "llvm.ll" compiled
exitSuccess
-- | Writes a line to stderr
printToErr :: String -> IO ()
printToErr msg = hPutStrLn stderr msg
-- | Unwraps the result of the code generator. On failure the error is
-- reported and the program exits with a failure code.
--
-- The report goes to stderr: stdout is where the compiled program is
-- printed, so diagnostics must not end up in that stream.
fromCompilerErr :: Err a -> IO a
fromCompilerErr = either report pure
  where
    report err = do
        hPutStrLn stderr "\nCOMPILER ERROR"
        hPutStrLn stderr err
        exitFailure
-- | Unwraps the result of the parser. On failure the error is reported and
-- the program exits with a failure code.
--
-- The report goes to stderr: stdout is where the compiled program is
-- printed, so diagnostics must not end up in that stream.
fromSyntaxErr :: Err a -> IO a
fromSyntaxErr = either report pure
  where
    report err = do
        hPutStrLn stderr "\nSYNTAX ERROR"
        hPutStrLn stderr err
        exitFailure
-- | Unwraps the result of the type checker. On failure the error is
-- reported and the program exits with a failure code.
--
-- The report goes to stderr: stdout is where the compiled program is
-- printed, so diagnostics must not end up in that stream.
fromTypeCheckerErr :: Err a -> IO a
fromTypeCheckerErr = either report pure
  where
    report err = do
        hPutStrLn stderr "\nTYPECHECKER ERROR"
        hPutStrLn stderr err
        exitFailure

View file

@ -1,101 +1,91 @@
{-# LANGUAGE LambdaCase, OverloadedRecordDot, OverloadedStrings #-}
{-# LANGUAGE LambdaCase #-}
module Renamer.Renamer (rename) where
module Renamer.Renamer where
import Renamer.RenamerIr
import Control.Monad.State
import Control.Monad.Except
import Control.Monad.Reader
import Data.Functor.Identity (Identity, runIdentity)
import Data.Set (Set)
import qualified Data.Set as S
import Data.Map (Map)
import qualified Data.Map as M
import Auxiliary (mapAccumM)
import Control.Monad.State (MonadState, State, evalState, gets,
modify)
import Data.List (foldl')
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe)
import Grammar.Abs
import Renamer.RenamerIr
import qualified Grammar.Abs as Old
type Rename = StateT Ctx (ExceptT Error Identity)
-- | Rename all variables and local binds
rename :: Program -> Program
rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0
where
-- initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs
initNames = Map.fromList $ foldl' saveIfBind [] bs
saveIfBind acc (DBind (Bind name _ _ _ _)) = dupe name : acc
saveIfBind acc _ = acc
renameSc :: Names -> Def -> Rn Def
renameSc old_names (DBind (Bind name t _ parms rhs)) = do
(new_names, parms') <- newNames old_names parms
rhs' <- snd <$> renameExp new_names rhs
pure . DBind $ Bind name t name parms' rhs'
renameSc _ def = pure def
data Ctx = Ctx { count :: Integer
, sig :: Set Ident
, env :: Map Ident Integer}
--
run :: Rename a -> Either Error a
run = runIdentity . runExceptT . flip evalStateT initCtx
-- | Rename monad. State holds the number of renamed names.
newtype Rn a = Rn { runRn :: State Int a }
deriving (Functor, Applicative, Monad, MonadState Int)
initCtx :: Ctx
initCtx = Ctx { count = 0
, sig = mempty
, env = mempty }
-- | Maps old to new name
type Names = Map Ident Ident
rename :: Old.Program -> Either Error RProgram
rename = run . renamePrg
renameLocalBind :: Names -> Bind -> Rn (Names, Bind)
renameLocalBind old_names (Bind name t _ parms rhs) = do
(new_names, name') <- newName old_names name
(new_names', parms') <- newNames new_names parms
(new_names'', rhs') <- renameExp new_names' rhs
pure (new_names'', Bind name' t name' parms' rhs')
renamePrg :: Old.Program -> Rename RProgram
renamePrg (Old.Program xs) = do
xs' <- mapM renameBind xs
return $ RProgram xs'
renameExp :: Names -> Exp -> Rn (Names, Exp)
renameExp old_names = \case
EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names)
renameBind :: Old.Bind -> Rename RBind
renameBind (Old.Bind n t i args e) = do
insertSig i
e' <- renameExp (makeLambda (reverse args) e)
return $ RBind i e'
where
makeLambda :: [Ident] -> Old.Exp -> Old.Exp
makeLambda [] e = e
makeLambda (x:xs) e = makeLambda xs (Old.EAbs x e)
ELit (LInt i1) -> pure (old_names, ELit (LInt i1))
renameExp :: Old.Exp -> Rename RExp
renameExp = \case
EApp e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EApp e1' e2')
Old.EId i -> do
st <- get
case M.lookup i st.env of
Just n -> return $ RId i
Nothing -> case S.member i st.sig of
True -> return $ RId i
False -> throwError $ UnboundVar (show i)
EAdd e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EAdd e1' e2')
Old.EInt c -> return $ RInt c
ELet i e1 e2 -> do
(new_names, e1') <- renameExp old_names e1
(new_names', e2') <- renameExp new_names e2
pure (new_names', ELet i e1' e2')
Old.EAnn e t -> flip RAnn t <$> renameExp e
EAbs par e -> do
(new_names, par') <- newName old_names par
(new_names', e') <- renameExp new_names e
pure (new_names', EAbs par' e')
Old.EApp e1 e2 -> RApp <$> renameExp e1 <*> renameExp e2
EAnn e t -> do
(new_names, e') <- renameExp old_names e
pure (new_names, EAnn e' t)
Old.EAdd e1 e2 -> RAdd <$> renameExp e1 <*> renameExp e2
ECase _ _ -> error "ECase NOT IMPLEMENTED YET"
-- Convert let-expressions to lambdas
Old.ELet i e1 e2 -> renameExp (Old.EApp (Old.EAbs i e2) e1)
-- | Create a fresh name for an identifier and record the mapping from the
-- old name to the new one in the name environment.
newName :: Names -> Ident -> Rn (Names, Ident)
newName env old_name = extend <$> makeName old_name
  where
    extend new_name = (Map.insert old_name new_name env, new_name)
Old.EAbs i e -> do
n <- cnt
ctx <- get
insertEnv i n
re <- renameExp e
return $ RAbs n i re
-- | Create multiple names and add them to the name environment
newNames :: Names -> [Ident] -> Rn (Names, [Ident])
newNames = mapAccumM newName
-- | Get current count and increase it by one
cnt :: Rename Integer
cnt = do
st <- get
put (Ctx { count = succ st.count
, sig = st.sig
, env = st.env })
return st.count
insertEnv :: Ident -> Integer -> Rename ()
insertEnv i n = do
c <- get
put ( Ctx { env = M.insert i n c.env , sig = c.sig , count = c.count} )
insertSig :: Ident -> Rename ()
insertSig i = do
c <- get
put ( Ctx { sig = S.insert i c.sig , env = c.env , count = c.count } )
data Error = UnboundVar String
instance Show Error where
show (UnboundVar str) = "Unbound variable: " <> str
-- | Annotate name with the current number, @prefix ⇒ prefix_number@, and
-- increment the number afterwards.
makeName :: Ident -> Rn Ident
makeName (Ident prefix) = do
    fresh <- gets (\i -> Ident (prefix ++ "_" ++ show i))
    modify succ
    pure fresh

View file

@ -1,32 +0,0 @@
{-# LANGUAGE LambdaCase #-}
module Renamer.RenamerIr (
RExp (..),
RBind (..),
RProgram (..),
Ident (..),
Type (..),
) where
import Grammar.Abs (
Bind (..),
Ident (..),
Program (..),
Type (..),
)
import Grammar.Print
data RProgram = RProgram [RBind]
deriving (Eq, Show, Read, Ord)
data RBind = RBind Ident RExp
deriving (Eq, Show, Read, Ord)
data RExp
= RAnn RExp Type
| RId Ident
| RInt Integer
| RApp RExp RExp
| RAdd RExp RExp
| RAbs Integer Ident RExp
deriving (Eq, Ord, Show, Read)

View file

@ -1,83 +0,0 @@
{-# LANGUAGE LambdaCase #-}
module Renamer.RenamerM where
import Auxiliary (mapAccumM)
import Control.Monad.State (MonadState, State, evalState, gets,
modify)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe)
import Grammar.Abs
-- | Rename all variables and local binds
rename :: Program -> Program
rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0
where
initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs
renameSc :: Names -> Bind -> Rn Bind
renameSc old_names (Bind name t _ parms rhs) = do
(new_names, parms') <- newNames old_names parms
rhs' <- snd <$> renameExp new_names rhs
pure $ Bind name t name parms' rhs'
-- | Rename monad. State holds the number of renamed names.
newtype Rn a = Rn { runRn :: State Int a }
deriving (Functor, Applicative, Monad, MonadState Int)
-- | Maps an original identifier to the name it has been renamed to.
type Names = Map Ident Ident
-- | Rename a local binding: its own name first, then its parameters, then its
-- body. Returns the name map produced by renaming the body together with the
-- rewritten binding.
renameLocalBind :: Names -> Bind -> Rn (Names, Bind)
renameLocalBind old_names (Bind name t _ parms rhs) = do
    (withName, freshName) <- newName old_names name
    (withParms, freshParms) <- newNames withName parms
    (\(finalNames, body) -> (finalNames, Bind freshName t freshName freshParms body))
        <$> renameExp withParms rhs
-- | Rename the identifiers of an expression under the given name map.
-- Returns the rewritten expression together with a name map.
--
-- NOTE(review): the map returned for 'EAbs' still contains the parameter's
-- entry, and 'ELet' threads the map from @e1@ into @e2@, so names introduced
-- inside @e1@ stay visible in @e2@ -- confirm this is intended.
-- NOTE(review): 'ELet' does not create a new name for the bound identifier
-- @i@ -- confirm this is intended.
renameExp :: Names -> Exp -> Rn (Names, Exp)
renameExp old_names = \case
    -- An identifier without an entry in the map is left as it is.
    EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names)
    EInt i1 -> pure (old_names, EInt i1)
    -- Both operands are renamed under the same incoming map; the resulting
    -- maps are merged with a left-biased union.
    EApp e1 e2 -> do
        (env1, e1') <- renameExp old_names e1
        (env2, e2') <- renameExp old_names e2
        pure (Map.union env1 env2, EApp e1' e2')
    EAdd e1 e2 -> do
        (env1, e1') <- renameExp old_names e1
        (env2, e2') <- renameExp old_names e2
        pure (Map.union env1 env2, EAdd e1' e2')
    ELet i e1 e2 -> do
        (new_names, e1') <- renameExp old_names e1
        (new_names', e2') <- renameExp new_names e2
        pure (new_names', ELet i e1' e2')
    -- The parameter gets a fresh name that is in scope for the body.
    EAbs par e -> do
        (new_names, par') <- newName old_names par
        (new_names', e') <- renameExp new_names e
        pure (new_names', EAbs par' e')
    -- The annotation's type is kept untouched.
    EAnn e t -> do
        (new_names, e') <- renameExp old_names e
        pure (new_names, EAnn e' t)
-- | Generate a fresh name for @old_name@ and return it together with the
-- environment extended by the mapping @old_name -> fresh name@.
newName :: Names -> Ident -> Rn (Names, Ident)
newName env old_name =
    (\new_name -> (Map.insert old_name new_name env, new_name)) <$> makeName old_name
-- | Generate fresh names for a list of identifiers, threading the name
-- environment through from left to right.
newNames :: Names -> [Ident] -> Rn (Names, [Ident])
newNames names idents = mapAccumM newName names idents
-- | Build a name of the form @prefix_number@ from the current counter value,
-- then increment the counter.
makeName :: Ident -> Rn Ident
makeName (Ident prefix) = do
    n <- gets id
    modify succ
    pure (Ident (prefix ++ "_" ++ show n))

View file

@ -1,238 +0,0 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use traverse_" #-}
module TypeChecker.AlgoW where
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (bimap, second)
import Data.Functor.Identity (Identity, runIdentity)
import Data.List (foldl', intersect)
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe (fromMaybe)
import Data.Set (Set)
import qualified Data.Set as S
import Grammar.Abs
import Grammar.Print (Print, printTree)
import qualified TypeChecker.HMIr as T
-- | A type scheme: a type together with the list of type variables that are
-- quantified over it. 'inst' replaces exactly these variables with fresh ones.
data Poly = Forall [Ident] Type
    deriving Show
-- | Reader context of the inference monad: the type scheme of every variable
-- currently in scope.
newtype Ctx = Ctx { vars :: Map Ident Poly }
-- | Mutable state of the inference monad.
data Env = Env { count :: Int             -- ^ counter used by 'fresh' to number type variables
               , sigs  :: Map Ident Type  -- ^ signatures recorded by 'insertSig'
               }
-- | Error messages are plain strings.
type Error = String
-- | A substitution from type-variable names to types.
type Subst = Map Ident Type
-- | The inference monad: 'Env' state over a 'Ctx' reader over 'Error' exceptions.
type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))
-- | Initial reader context: no variables in scope.
initCtx :: Ctx
initCtx = Ctx mempty

-- | Initial state: counter at 0 and no recorded signatures.
initEnv :: Env
initEnv = Env 0 mempty
-- | Run an inference action from the initial state and pretty-print its result.
runPretty :: Print a => Infer a -> Either Error String
runPretty action = printTree <$> run action
-- | Run an inference action starting from 'initEnv' and 'initCtx'.
run :: Infer a -> Either Error a
run action = runC initEnv initCtx action
-- | Run an inference action with an explicit initial state and context,
-- discarding the final state.
runC :: Env -> Ctx -> Infer a -> Either Error a
runC e c action = runIdentity (runExceptT (runReaderT (evalStateT action e) c))
-- | Type check a whole program from the initial state.
typecheck :: Program -> Either Error T.Program
typecheck prg = run (checkPrg prg)
-- | Record every binding's declared signature first, then check the bindings
-- one by one.
checkPrg :: Program -> Infer T.Program
checkPrg (Program bs) = do
    mapM (\(Bind n t _ _ _) -> insertSig n t) bs
    T.Program <$> mapM checkBind bs
-- | Check one binding: the parameters are turned into nested lambdas around
-- the body, the result is inferred, and the declared type is unified with the
-- inferred one.
checkBind :: Bind -> Infer T.Bind
checkBind (Bind n t _ args e) = do
    -- foldr over the parameters builds the same nesting as a left fold with
    -- flipped 'EAbs' over the reversed list.
    (inferred, body) <- inferExp (foldr EAbs e args)
    s <- unify t inferred
    pure (T.Bind (apply s t, n) [] body)
-- | Infer an expression with 'w', apply the resulting substitution to its
-- type, and put that type on the outermost node of the typed expression.
inferExp :: Exp -> Infer (Type, T.Exp)
inferExp e = do
    (s, t, typed) <- w e
    let final = apply s t
    pure (final, replace final typed)
-- | Overwrite the type stored on the outermost node of a typed expression.
replace :: Type -> T.Exp -> T.Exp
replace t expr = case expr of
    T.EInt _ n         -> T.EInt t n
    T.EId _ i          -> T.EId t i
    T.EAbs _ name body -> T.EAbs t name body
    T.EApp _ f x       -> T.EApp t f x
    T.EAdd _ l r       -> T.EAdd t l r
    T.ELet _ name b e  -> T.ELet t name b e
-- | Infer the type of an expression. Returns the accumulated substitution,
-- the inferred type and the expression annotated with types.
--
-- NOTE(review): in the 'ELet' case the substitution @s1@ is used for
-- generalisation, but the body is inferred in the unsubstituted context --
-- confirm this is intended.
w :: Exp -> Infer (Subst, Type, T.Exp)
w = \case
    -- The inferred type is unified with the annotation; the annotation is
    -- returned as the expression's type.
    EAnn e t -> do
        (s1, t', e') <- w e
        applySt s1 $ do
            s2 <- unify (apply s1 t) t'
            return (s2 `compose` s1, t, e')
    EInt n -> return (nullSubst, TMono "Int", T.EInt (TMono "Int") n)
    -- Variables are looked up in the reader context and instantiated.
    EId i -> do
        var <- asks vars
        case M.lookup i var of
            Nothing -> throwError $ "Unbound variable: " ++ show i
            Just t -> inst t >>= \x -> return (nullSubst, x, T.EId x i)
    -- The parameter is bound to a fresh, unquantified type variable.
    EAbs name e -> do
        fr <- fresh
        withBinding name (Forall [] fr) $ do
            (s1, t', e') <- w e
            let newArr = TArr (apply s1 fr) t'
            return (s1, newArr, T.EAbs newArr name e')
    -- Both operands must unify with Int.
    EAdd e0 e1 -> do
        (s1, t0, e0') <- w e0
        applySt s1 $ do
            (s2, t1, e1') <- w e1
            applySt s2 $ do
                s3 <- unify (subst s2 t0) (TMono "Int")
                s4 <- unify (subst s3 t1) (TMono "Int")
                return (s4 `compose` s3 `compose` s2 `compose` s1, TMono "Int", T.EAdd (TMono "Int") e0' e1')
    -- The function type must unify with @argument type -> fresh variable@.
    EApp e0 e1 -> do
        fr <- fresh
        (s1, t0, e0') <- w e0
        applySt s1 $ do
            (s2, t1, e1') <- w e1
            applySt s2 $ do
                s3 <- unify (subst s2 t0) (TArr t1 fr)
                let t = apply s3 fr
                return (s3 `compose` s2 `compose` s1, t, T.EApp t e0' e1')
    -- The bound expression's type is generalised before the body is inferred.
    ELet name e0 e1 -> do
        (s1, t1, e0') <- w e0
        env <- asks vars
        let t' = generalize (apply s1 env) t1
        withBinding name t' $ do
            (s2, t2, e1') <- w e1
            return (s2 `compose` s1, t2, T.ELet t2 name e0' e1' )
-- | Unify two types, producing a substitution that makes them equal.
-- Throws an error when the types cannot be unified.
unify :: Type -> Type -> Infer Subst
unify t0 t1 = case (t0, t1) of
    (TArr a b, TArr c d) -> do
        s1 <- unify a c
        -- The result types are unified under s1: b against d (not c).
        s2 <- unify (subst s1 b) (subst s1 d)
        -- s1 was found first, so s2 has to be applied on top of it.
        return $ s2 `compose` s1
    (TPol a, b) -> occurs a b
    (a, TPol b) -> occurs b a
    (TMono a, TMono b) -> if a == b then return M.empty else throwError "Types do not unify"
    (a, b) -> throwError . unwords $ ["Type:", printTree a, "can't be unified with:", printTree b]
-- | Bind type variable @i@ to a type, after the occurs check.
-- E.g. { a = a -> b } is unsolvable since no substitution makes both sides equal.
--
-- A variable unified with itself needs no binding. A variable unified with a
-- different variable is bound to it, so the equality constraint is not lost.
occurs :: Ident -> Type -> Infer Subst
occurs i (TPol a)
    | i == a = return nullSubst
    | otherwise = return $ M.singleton i (TPol a)
occurs i t
    | S.member i (free t) = throwError "Occurs check failed"
    | otherwise = return $ M.singleton i t
-- | Quantify a type over every type variable that is free in the type but not
-- free in the given environment.
generalize :: Map Ident Poly -> Type -> Poly
generalize env t = Forall quantified t
  where
    quantified = S.toList (S.difference (free t) (free env))
-- | Instantiate a type scheme: each quantified variable is replaced by a
-- fresh type variable.
inst :: Poly -> Infer Type
inst (Forall xs t) = do
    freshVars <- traverse (\_ -> fresh) xs
    pure (apply (M.fromList (zip xs freshVars)) t)
-- | Compose two substitutions: @m1@ is applied to every type in @m2@, and the
-- result is merged with @m1@ (entries coming from @m2@ win on key clashes).
compose :: Subst -> Subst -> Subst
compose m1 m2 = M.union (M.map (subst m1) m2) m1
-- | Things that contain type variables: they can report their free variables
-- and have a substitution applied to them.
class FreeVars t where
    -- | Get all free variables from t
    free :: t -> Set Ident
    -- | Apply a substitution to t
    apply :: Subst -> t -> t
-- | Every polymorphic variable of a type is free; a substitution replaces the
-- variables it has an entry for, in a single pass.
instance FreeVars Type where
    free :: Type -> Set Ident
    free = \case
        TPol a   -> S.singleton a
        TMono _  -> mempty
        TArr a b -> S.union (free a) (free b)
    apply :: Subst -> Type -> Type
    apply sub = \case
        TMono a  -> TMono a
        TPol a   -> maybe (TPol a) id (M.lookup a sub)
        TArr a b -> TArr (apply sub a) (apply sub b)
-- | Quantified variables are not free, and a substitution never touches them.
instance FreeVars Poly where
    free :: Poly -> Set Ident
    free (Forall xs t) = S.difference (free t) (S.fromList xs)
    apply :: Subst -> Poly -> Poly
    apply s (Forall xs t) = Forall xs (apply unbound t)
      where
        -- The substitution without the entries for quantified variables.
        unbound = foldr M.delete s xs
-- | An environment's free variables are those of all its schemes; a
-- substitution is applied to every scheme.
instance FreeVars (Map Ident Poly) where
    free :: Map Ident Poly -> Set Ident
    free m = foldl' (\acc scheme -> S.union acc (free scheme)) S.empty (M.elems m)
    apply :: Subst -> Map Ident Poly -> Map Ident Poly
    apply s env = M.map (apply s) env
-- | Run an action in a context whose variable types have the substitution
-- applied.
applySt :: Subst -> Infer a -> Infer a
applySt s action = local substituted action
  where
    substituted ctx = ctx { vars = apply s (vars ctx) }
-- | The empty substitution.
nullSubst :: Subst
nullSubst = mempty
-- | Replace type variables by their images under the substitution, in a
-- single pass. Variables without an entry are left alone.
subst :: Subst -> Type -> Type
subst m = \case
    TPol a   -> fromMaybe (TPol a) (M.lookup a m)
    TMono a  -> TMono a
    TArr a b -> TArr (subst m a) (subst m b)
-- | Produce the type variable @t<count>@ and increment the counter.
fresh :: Infer Type
fresh = state step
  where
    step st =
        let n = count st
         in (TPol (Ident ("t" ++ show n)), st { count = n + 1 })
-- | Run an action with one extra variable binding in the context.
withBinding :: (Monad m, MonadReader Ctx m) => Ident -> Poly -> m a -> m a
withBinding i p = local extend
  where
    extend ctx = ctx { vars = M.insert i p (vars ctx) }
-- | Record a function's signature in the state.
insertSig :: Ident -> Type -> Infer ()
insertSig i t = modify addSig
  where
    addSig st = st { sigs = M.insert i t (sigs st) }
-- | Look up a variable's type scheme in the context; fails with an
-- "Unbound variable" error when it is absent.
lookupVar :: Ident -> Infer Poly
lookupVar i =
    asks vars >>= maybe (throwError ("Unbound variable: " ++ show i)) return . M.lookup i
-- | Ad-hoc check: @let x = 5 + 5 in x + x@ is inferred as @Int@.
-- Yields 'False' when inference fails, instead of crashing on a partial
-- pattern match.
lett :: Bool
lett = fmap fst (run inferred) == Right (TMono "Int")
  where
    inferred = inferExp $ ELet "x" (EAdd (EInt 5) (EInt 5)) (EAdd (EId "x") (EId "x"))
-- | Ad-hoc example: the typed expression inferred for
-- @let f = \\x. x in f 3@.
-- NOTE(review): the @Right@ pattern is partial; evaluating this crashes if
-- inference fails.
letty = let (Right (t,e)) = run $ inferExp $ ELet "f" (EAbs "x" (EId "x")) (EApp (EId "f") (EInt 3))
        in e

View file

@ -1,181 +0,0 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use traverse_" #-}
{-# LANGUAGE FlexibleInstances #-}
module TypeChecker.HM where
import Control.Monad.Except
import Control.Monad.State
import Data.Bifunctor (bimap, second)
import Data.Functor.Identity (Identity, runIdentity)
import Data.Map (Map)
import qualified Data.Map as M
import Grammar.Abs
import Grammar.Print
import qualified TypeChecker.HMIr as T
-- | The inference monad: 'Ctx' state over 'String' exceptions.
type Infer = StateT Ctx (ExceptT String Identity)
-- | Error messages are plain strings.
type Error = String
-- | State of the constraint-based inference.
data Ctx = Ctx { constr :: Map Type Type   -- ^ constraints recorded by 'addConstraint'
               , vars   :: Map Ident Type  -- ^ types of variables, see 'insertVar'
               , sigs   :: Map Ident Type  -- ^ function signatures, see 'insertSig'
               , frsh   :: Char }          -- ^ next character used by 'fresh' as a variable name
    deriving Show
-- | Run an inference action from a given state, returning the result and the
-- final state.
runC :: Ctx -> Infer a -> Either String (a, Ctx)
runC c action = runIdentity (runExceptT (runStateT action c))
-- | Run an inference action from 'initC', discarding the final state.
run :: Infer a -> Either String a
run action = runIdentity (runExceptT (evalStateT action initC))
-- | Initial state: no constraints, variables or signatures; fresh names
-- start at @'a'@.
initC :: Ctx
initC = Ctx { constr = M.empty, vars = M.empty, sigs = M.empty, frsh = 'a' }
-- | Type check a whole program from the initial state.
typecheck :: Program -> Either Error T.Program
typecheck prg = run (inferPrg prg)
-- | Record every binding's declared signature first, then infer the bindings
-- one by one.
inferPrg :: Program -> Infer T.Program
inferPrg (Program bs) = do
    mapM (\(Bind n t _ _ _) -> insertSig n t) bs
    T.Program <$> mapM inferBind bs
-- | Infer one binding and require the inferred type to be exactly equal to
-- the declared signature.
inferBind :: Bind -> Infer T.Bind
inferBind (Bind i t _ params rhs) = do
    (t', e') <- inferExp (makeLambda rhs (reverse params))
    let mismatch =
            unwords
                [ "Signature of function"
                , show i
                , "with type:"
                , show t
                , "does not match inferred type"
                , show t'
                , "of expression:"
                , show e'
                ]
    when (t /= t') (throwError mismatch)
    return (T.Bind (t, i) [] e')
-- | Wrap an expression in one lambda per identifier. The first identifier of
-- the list ends up innermost.
makeLambda :: Exp -> [Ident] -> Exp
makeLambda body params = foldl (\inner p -> EAbs p inner) body params
-- | Infer the type of an expression: collect constraints with 'inferExp'',
-- solve them, and put the solved type on the outermost node.
inferExp :: Exp -> Infer (Type, T.Exp)
inferExp e = do
    (t, e') <- inferExp' e
    t'' <- solveConstraints t
    return (t'', replaceType t'' e')
  where
    inferExp' :: Exp -> Infer (Type, T.Exp)
    inferExp' = \case
        -- The solved inferred type must be exactly equal to the annotation.
        EAnn e t -> do
            (t',e') <- inferExp' e
            t'' <- solveConstraints t'
            when (t'' /= t) (throwError "Annotated type and inferred type don't match")
            return (t', e')
        EInt i -> return (int, T.EInt int i)
        EId i -> (\t -> (t, T.EId t i)) <$> lookupVar i
        -- Addition is treated as applying a function "+" of type Int -> Int -> Int.
        EAdd e1 e2 -> do
            insertSig "+" (TArr int (TArr int int))
            inferExp' (EApp (EApp (EId "+") e1) e2)
        -- The function type is constrained to @argument type -> fresh variable@.
        EApp e1 e2 -> do
            (t1, e1') <- inferExp' e1
            (t2, e2') <- inferExp' e2
            fr <- fresh
            addConstraint t1 (TArr t2 fr)
            return (fr, T.EApp fr e1' e2')
        -- The parameter is stored in the state with a fresh type variable.
        EAbs name e -> do
            fr <- fresh
            insertVar name fr
            (ret_t,e') <- inferExp' e
            t <- solveConstraints (TArr fr ret_t)
            return (t, T.EAbs t name e')
        -- NOTE(review): this crashes with 'error' instead of reporting
        -- through 'throwError'.
        ELet name e1 e2 -> error "Let expression not implemented yet"
-- | Overwrite the type stored on the outermost node of a typed expression.
replaceType :: Type -> T.Exp -> T.Exp
replaceType t expr = case expr of
    T.EInt _ i         -> T.EInt t i
    T.EId _ i          -> T.EId t i
    T.EAdd _ l r       -> T.EAdd t l r
    T.EApp _ f x       -> T.EApp t f x
    T.EAbs _ name body -> T.EAbs t name body
    T.ELet _ name b e  -> T.ELet t name b e
-- | Is the type exactly the monomorphic type @Int@?
isInt :: Type -> Bool
isInt t = case t of
    TMono "Int" -> True
    _           -> False
-- | Look up an identifier, first among the variables and then among the
-- function signatures. Fails when it is in neither.
lookupVar :: Ident -> Infer Type
lookupVar i = do
    st <- get
    case M.lookup i (vars st) of
        Just t -> return t
        Nothing -> case M.lookup i (sigs st) of
            Just t -> return t
            -- The message used to run straight into the identifier.
            Nothing -> throwError $ "Unbound variable or function: " ++ printTree i
-- | Record a variable's type in the state.
insertVar :: Ident -> Type -> Infer ()
insertVar s t = modify bindVar
  where
    bindVar st = st { vars = M.insert s t (vars st) }
-- | Record a function's signature in the state.
insertSig :: Ident -> Type -> Infer ()
insertSig s t = modify addSig
  where
    addSig st = st { sigs = M.insert s t (sigs st) }
-- | Produce a type variable named after the current fresh character, and
-- advance the state to the next character.
fresh :: Infer Type
fresh = state step
  where
    step st =
        let c = frsh st
         in (TPol (Ident [c]), st { frsh = succ c })
-- | Record the constraint @t1 = t2@ in the constraint set.
-- E.g. with a = int -> b and b = int, solving the constraints must give
-- a = int -> int.
addConstraint :: Type -> Type -> Infer ()
addConstraint t1 t2 = modify record
  where
    record st = st { constr = M.insert t1 t2 (constr st) }
-- | Solve the recorded constraints, store the solved set back in the state,
-- and return the given type with the solution applied to it.
solveConstraints :: Type -> Infer Type
solveConstraints t = do
    c <- gets constr
    xs <- solveAll (M.toList c)
    modify (\st -> st { constr = M.fromList xs })
    return $ subst t xs
-- | Apply a list of replacements to a type, one after the other from left to
-- right.
subst :: Type -> [(Type, Type)] -> Type
subst = foldl (\acc c -> replace c acc)
-- | Solve a list of constraints: each kept constraint is substituted into the
-- remaining ones before they are solved in turn.
-- https://youtu.be/trmq3wYcUxU - good video for explanation
solveAll :: [(Type, Type)] -> Infer [(Type, Type)]
solveAll [] = return []
solveAll (x:xs) = case x of
    -- Two arrows are split into constraints on their components.
    (TArr t1 t2, TArr t3 t4) -> solveAll ((t1, t3) : (t2, t4) : xs)
    -- An arrow on the left moves to the right-hand side.
    (TArr t1 t2, b)          -> keep (b, TArr t1 t2)
    (a, TArr t1 t2)          -> keep (a, TArr t1 t2)
    -- A variable on the right moves to the left-hand side.
    (TMono a, TPol b)        -> keep (TPol b, TMono a)
    (TPol a, TMono b)        -> keep (TPol a, TMono b)
    (TPol a, TPol b)         -> keep (TPol a, TPol b)
    (TMono a, TMono b)
        | a == b    -> solveAll xs
        | otherwise -> throwError "Can't unify types"
  where
    -- Keep a constraint and substitute it into the rest before continuing.
    keep c = (c :) <$> solveAll (solve c xs)
-- | Apply one replacement to both sides of every constraint in the list.
solve :: (Type, Type) -> [(Type, Type)] -> [(Type, Type)]
solve x cs = [both (replace x) c | c <- cs]
-- | Apply the replacement @(a, b)@ to a type: arrows are traversed, and any
-- other type equal to @a@ becomes @b@.
replace :: (Type, Type) -> Type -> Type
replace c@(a, b) = \case
    TArr t1 t2 -> TArr (replace c t1) (replace c t2)
    t | a == t    -> b
      | otherwise -> t
-- | Apply the same function to both components of a pair.
both :: (a -> b) -> (a,a) -> (b,b)
both f pair = (f (fst pair), f (snd pair))
-- | The monomorphic type @Int@.
int :: Type
int = TMono "Int"

View file

@ -1,110 +0,0 @@
{-# LANGUAGE LambdaCase #-}
module TypeChecker.HMIr
( module Grammar.Abs
, module TypeChecker.HMIr
) where
import Grammar.Abs (Ident (..), Type (..))
import Grammar.Print
import Prelude
import qualified Prelude as C (Eq, Ord, Read, Show)
-- | A type-checked program: a list of typed bindings.
newtype Program = Program [Bind]
    deriving (C.Eq, C.Ord, C.Show, C.Read)
-- | Typed expressions: every node carries a 'Type' as its first field.
data Exp
    = EId Type Ident          -- ^ identifier
    | EInt Type Integer       -- ^ integer literal
    | ELet Type Ident Exp Exp -- ^ let binding: name, bound expression, body
    | EApp Type Exp Exp       -- ^ application
    | EAdd Type Exp Exp       -- ^ addition
    | EAbs Type Ident Exp     -- ^ abstraction: parameter and body
    deriving (C.Eq, C.Ord, C.Read)
-- | Debug rendering of typed expressions, with types written after a colon.
instance Show Exp where
    show (EId t (Ident i)) = i ++ " : " ++ show t
    show (EInt _ i) = show i
    -- The bound name is shown with the node's type, in the same style as 'EId';
    -- previously the name was dropped altogether.
    show (ELet t (Ident i) e1 e2) = "let " ++ i ++ " : " ++ show t ++ " = " ++ show e1 ++ " in " ++ show e2
    show (EApp t e1 e2) = show e1 ++ " " ++ show e2 ++ " : " ++ show t
    show (EAdd _ e1 e2) = show e1 ++ " + " ++ show e2
    show (EAbs t (Ident i) e) = "\\ " ++ i ++ ". " ++ show e ++ " : " ++ show t
-- | An identifier paired with its type; the type comes first.
type Id = (Type, Ident)
-- | A typed binding: its typed name, its typed parameters and its body.
data Bind = Bind Id [Id] Exp
    deriving (C.Eq, C.Ord, C.Show, C.Read)
-- | A program is printed as its list of bindings.
instance Print Program where
    prt i (Program binds) = prPrec i 0 (prt 0 binds)
-- | Prints a binding as @name : type ; name params = rhs@.
instance Print Bind where
    -- 'Id' is @(Type, Ident)@, so the type is the first component. The
    -- previous code printed the type where the name belongs.
    prt i (Bind (t, n) parms rhs) = prPrec i 0 $ concatD
        [ prt 0 n
        , doc $ showString ":"
        , prt 0 t
        , doc $ showString ";"
        , prt 0 n
        , prtIdPs 0 parms
        , doc $ showString "="
        , prt 0 rhs
        ]
-- | Bindings are printed separated by semicolons, with none after the last.
instance Print [Bind] where
    prt _ = \case
        []     -> concatD []
        [x]    -> concatD [prt 0 x]
        x : xs -> concatD [prt 0 x, doc (showString ";"), prt 0 xs]
-- | Print a list of typed identifiers, each one in parentheses.
prtIdPs :: Int -> [Id] -> Doc
prtIdPs i ids = prPrec i 0 (concatD [prtIdP i x | x <- ids])
-- | Print a typed identifier as @name : type@.
prtId :: Int -> Id -> Doc
-- 'Id' is @(Type, Ident)@: the type is the first component.
prtId i (t, name) = prPrec i 0 $ concatD
    [ prt 0 name
    , doc $ showString ":"
    , prt 0 t
    ]
-- | Print a typed identifier as @( name : type )@.
prtIdP :: Int -> Id -> Doc
-- 'Id' is @(Type, Ident)@: the type is the first component.
prtIdP i (t, name) = prPrec i 0 $ concatD
    [ doc $ showString "("
    , prt 0 name
    , doc $ showString ":"
    , prt 0 t
    , doc $ showString ")"
    ]
-- | Pretty-printer for typed expressions. Applications, additions and
-- abstractions are prefixed with @\@@ followed by their type.
instance Print Exp where
    prt i expr = case expr of
        EId _ n -> prPrec i 3 (concatD [prt 0 n])
        EInt _ i1 -> prPrec i 3 (concatD [prt 0 i1])
        ELet _ name e1 e2 ->
            prPrec i 3 (concatD [kw "let", prt 0 name, prt 0 e1, kw "in", prt 0 e2])
        EApp t e1 e2 ->
            prPrec i 2 (concatD [kw "@", prt 0 t, prt 2 e1, prt 3 e2])
        EAdd t e1 e2 ->
            prPrec i 1 (concatD [kw "@", prt 0 t, prt 1 e1, kw "+", prt 2 e2])
        EAbs t n e ->
            prPrec i 0 (concatD [kw "@", prt 0 t, kw "\\", prt 0 n, kw ".", prt 0 e])
      where
        -- A literal token of the output.
        kw = doc . showString

View file

@ -1,153 +1,250 @@
-- {-# LANGUAGE LambdaCase #-}
-- {-# LANGUAGE OverloadedRecordDot #-}
-- {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use traverse_" #-}
{-# OPTIONS_GHC -Wno-overlapping-patterns #-}
module TypeChecker.TypeChecker where
-- import Control.Monad (void)
-- import Control.Monad.Except (ExceptT, runExceptT, throwError)
-- import Control.Monad.State (StateT)
-- import qualified Control.Monad.State as St
-- import Data.Functor.Identity (Identity, runIdentity)
-- import Data.Map (Map)
-- import qualified Data.Map as M
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Functor.Identity (Identity, runIdentity)
import Data.List (foldl')
import Data.Map (Map)
import qualified Data.Map as M
import Data.Set (Set)
import qualified Data.Set as S
-- import TypeChecker.TypeCheckerIr
import Grammar.Abs
import Grammar.Print (printTree)
import qualified TypeChecker.TypeCheckerIr as T
-- data Ctx = Ctx
-- { vars :: Map Integer Type
-- , sigs :: Map Ident Type
-- , nextFresh :: Int
-- }
-- deriving (Show)
-- | A type scheme: a type together with the list of type variables that are
-- quantified over it. 'inst' replaces exactly these variables with fresh ones.
data Poly = Forall [Ident] Type
    deriving Show
-- -- Perhaps swap over to reader monad instead for vars and sigs.
-- type Infer = StateT Ctx (ExceptT Error Identity)
-- | Reader context of the inference monad: the type scheme of every variable
-- currently in scope.
newtype Ctx = Ctx { vars :: Map Ident Poly }
-- {-
-- | Mutable state of the inference monad.
data Env = Env { count :: Int             -- ^ counter used by 'fresh' to number type variables
               , sigs  :: Map Ident Type  -- ^ top-level signatures recorded by 'insertSig'
               }
-- The type checker will assume we first rename all variables to unique name, as to not
-- have to care about scoping. It significantly improves the quality of life of the
-- programmer.
-- | Error messages are plain strings.
type Error = String
-- | A substitution from type-variable names to types.
type Subst = Map Ident Type
-- TODOs:
-- Add skolemization variables. i.e
-- { \x. 3 : forall a. a -> a }
-- should not type check
-- | The inference monad: 'Env' state over a 'Ctx' reader over 'Error' exceptions.
type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))
-- Generalize. Not really sure what that means though
-- | Initial reader context: no variables in scope.
initCtx :: Ctx
initCtx = Ctx mempty

-- | Initial state: counter at 0 and no recorded signatures.
initEnv :: Env
initEnv = Env 0 mempty
-- -}
-- | Infer an expression from the initial state and pretty-print its type.
runPretty :: Exp -> Either Error String
runPretty e = printTree . fst <$> run (inferExp e)
-- typecheck :: RProgram -> Either Error TProgram
-- typecheck = todo
-- | Run an inference action starting from 'initEnv' and 'initCtx'.
run :: Infer a -> Either Error a
run action = runC initEnv initCtx action
-- run :: Infer a -> Either Error a
-- run = runIdentity . runExceptT . flip St.evalStateT (Ctx mempty mempty 0)
-- | Run an inference action with an explicit initial state and context,
-- discarding the final state.
runC :: Env -> Ctx -> Infer a -> Either Error a
runC e c action = runIdentity (runExceptT (runReaderT (evalStateT action e) c))
-- -- Have to figure out a way to coerce polymorphic types to monomorphic ones where necessary
-- -- { \x. \y. x + y } will have the type { a -> b -> Int }
-- inferExp :: RExp -> Infer Type
-- inferExp = \case
-- | Type check a whole program from the initial state.
typecheck :: Program -> Either Error T.Program
typecheck prg = run (checkPrg prg)
-- RAnn expr typ -> do
-- t <- inferExp expr
-- void $ t =:= typ
-- return t
-- | Type check a program. Only the function bindings among the definitions
-- are checked; data definitions are skipped.
--
-- All declared signatures are recorded first, so a binding may refer to
-- bindings defined later in the file.
checkPrg :: Program -> Infer T.Program
checkPrg (Program defs) = do
    -- The pattern keeps only 'DBind' definitions. The previous filter let
    -- every definition through, so a data definition crashed the conversion.
    let binds = [bind | DBind bind <- defs]
    mapM_ (\(Bind n t _ _ _) -> insertSig n t) binds
    T.Program <$> mapM checkBind binds
-- RBound num name -> lookupVars num
-- | Check one binding: the parameters are turned into nested lambdas around
-- the body, the result is inferred and unified with the declared signature,
-- and the signature must be 'typeEq' to itself after substitution.
checkBind :: Bind -> Infer T.Bind
checkBind (Bind n t _ args e) = do
    -- foldr over the parameters builds the same nesting as a left fold with
    -- flipped 'EAbs' over the reversed list.
    (inferred, body) <- inferExp (foldr EAbs e args)
    s <- unify t inferred
    let resolved = apply s t
        complaint =
            unwords ["Top level signature", printTree t, "does not match body with type:", printTree resolved]
    unless (typeEq t resolved) (throwError complaint)
    pure (T.Bind (n, t) [] body)
-- RFree name -> lookupSigs name
-- | Structural equality of types in which any two polymorphic variables are
-- considered equal, whatever their names.
typeEq :: Type -> Type -> Bool
typeEq t1 t2 = case (t1, t2) of
    (TArr l r, TArr l' r') -> typeEq l l' && typeEq r r'
    (TMono a, TMono b)     -> a == b
    (TPol _, TPol _)       -> True
    _                      -> False
-- RConst (CInt i) -> return $ TMono "Int"
-- | Infer an expression with 'w', apply the resulting substitution to its
-- type, and put that type on the outermost node of the typed expression.
inferExp :: Exp -> Infer (Type, T.Exp)
inferExp e = do
    (s, t, typed) <- w e
    let final = apply s t
    pure (final, replace final typed)
-- RConst (CStr str) -> return $ TMono "Str"
-- | Overwrite the type stored on the outermost node of a typed expression.
-- For a let, the type is stored in the binding's identifier.
replace :: Type -> T.Exp -> T.Exp
replace t expr = case expr of
    T.ELit _ lit                        -> T.ELit t lit
    T.EId (n, _)                        -> T.EId (n, t)
    T.EAbs _ name body                  -> T.EAbs t name body
    T.EApp _ f x                        -> T.EApp t f x
    T.EAdd _ l r                        -> T.EAdd t l r
    T.ELet (T.Bind (n, _) args rhs) body -> T.ELet (T.Bind (n, t) args rhs) body
-- RAdd expr1 expr2 -> do
-- let int = TMono "Int"
-- typ1 <- check expr1 int
-- typ2 <- check expr2 int
-- return int
-- | Infer the type of an expression in the style of Algorithm W. Returns the
-- accumulated substitution, the inferred type, and the expression annotated
-- with types.
--
-- Constructs that are not supported yet are reported through 'throwError',
-- so callers receive a 'Left' instead of a runtime crash.
w :: Exp -> Infer (Subst, Type, T.Exp)
w = \case

    -- The inferred type is unified with the annotation; the annotation is
    -- returned as the expression's type.
    EAnn e t -> do
        (s1, t', e') <- w e
        applySt s1 $ do
            s2 <- unify (apply s1 t) t'
            return (s2 `compose` s1, t, e')

    ELit (LInt n) -> return (nullSubst, TMono "Int", T.ELit (TMono "Int") (LInt n))

    ELit a -> throwError $ "NOT IMPLEMENTED YET: ELit " ++ show a

    -- Local variables are instantiated from the context. Anything else falls
    -- back to the recorded top-level signatures, which are used as they are.
    EId i -> do
        var <- asks vars
        case M.lookup i var of
            Just t -> inst t >>= \x -> return (nullSubst, x, T.EId (i, x))
            Nothing -> do
                sig <- gets sigs
                case M.lookup i sig of
                    Nothing -> throwError $ "Unbound variable: " ++ show i
                    Just t -> return (nullSubst, t, T.EId (i, t))

    -- The parameter is bound to a fresh, unquantified type variable.
    EAbs name e -> do
        fr <- fresh
        withBinding name (Forall [] fr) $ do
            (s1, t', e') <- w e
            let varType = apply s1 fr
            let newArr = TArr varType t'
            return (s1, newArr, T.EAbs newArr (name, varType) e')

    -- Both operands must unify with Int.
    EAdd e0 e1 -> do
        (s1, t0, e0') <- w e0
        applySt s1 $ do
            (s2, t1, e1') <- w e1
            applySt s2 $ do
                s3 <- unify (apply s2 t0) (TMono "Int")
                s4 <- unify (apply s3 t1) (TMono "Int")
                return (s4 `compose` s3 `compose` s2 `compose` s1, TMono "Int", T.EAdd (TMono "Int") e0' e1')

    -- The function type must unify with @argument type -> fresh variable@.
    EApp e0 e1 -> do
        fr <- fresh
        (s0, t0, e0') <- w e0
        applySt s0 $ do
            (s1, t1, e1') <- w e1
            s2 <- unify (apply s1 t0) (TArr t1 fr)
            let t = apply s2 fr
            return (s2 `compose` s1 `compose` s0, t, T.EApp t e0' e1')

    -- The bound expression's type is generalised over the variables that are
    -- not free in the substituted context.
    ELet name e0 e1 -> do
        (s1, t1, e0') <- w e0
        env <- asks vars
        let t' = generalize (apply s1 env) t1
        -- The body must also see the context with s1 applied. Otherwise
        -- variables already fixed by e0 keep stale types and could be unified
        -- a second time, with a different type.
        applySt s1 $ withBinding name t' $ do
            (s2, t2, e1') <- w e1
            return (s2 `compose` s1, t2, T.ELet (T.Bind (name,t2) [] e0') e1' )

    ECase a b -> throwError $ "NOT IMPLEMENTED YET: ECase" ++ show a ++ " " ++ show b
-- {-# WARNING todo "TODO IN CODE" #-}
-- todo :: a
-- todo = error "TODO in code"
-- | Unify two types, producing a substitution that makes them equal.
-- Throws an error when the types cannot be unified.
unify :: Type -> Type -> Infer Subst
unify t0 t1 = case (t0, t1) of
    (TArr a b, TArr c d) -> do
        s1 <- unify a c
        s2 <- unify (apply s1 b) (apply s1 d)
        -- s1 was found first, so s2 has to be applied on top of it:
        -- 'compose' applies its first argument to the types of its second.
        -- The reverse order leaves variables bound by s2 unresolved inside s1.
        return $ s2 `compose` s1
    (TPol a, b) -> occurs a b
    (a, TPol b) -> occurs b a
    (TMono a, TMono b) -> if a == b then return M.empty else throwError "Types do not unify"
    (a, b) -> throwError . unwords $ ["Type:", printTree a, "can't be unified with:", printTree b]
-- data Error
-- = TypeMismatch String
-- | NotNumber String
-- | FunctionTypeMismatch String
-- | NotFunction String
-- | UnboundVar String
-- | AnnotatedMismatch String
-- | Default String
-- deriving (Show)
-- | Bind type variable @i@ to a type, after the occurs check.
-- E.g. { a = a -> b } is unsolvable since no substitution makes both sides equal.
--
-- A variable unified with itself needs no binding. A variable unified with a
-- different variable is bound to it, so the equality constraint is not lost.
occurs :: Ident -> Type -> Infer Subst
occurs i (TPol a)
    | i == a = return nullSubst
    | otherwise = return $ M.singleton i (TPol a)
occurs i t
    | S.member i (free t) = throwError "Occurs check failed"
    | otherwise = return $ M.singleton i t
-- | Quantify a type over every type variable that is free in the type but not
-- free in the given environment.
generalize :: Map Ident Poly -> Type -> Poly
generalize env t = Forall quantified t
  where
    quantified = S.toList (S.difference (free t) (free env))
-- {-
-- | Instantiate a type scheme: each quantified variable is replaced by a
-- fresh type variable.
inst :: Poly -> Infer Type
inst (Forall xs t) = do
    freshVars <- traverse (\_ -> fresh) xs
    pure (apply (M.fromList (zip xs freshVars)) t)
-- The procedure inst(σ) specializes the polytype
-- σ by copying the term and replacing the bound type variables
-- consistently by new monotype variables.
-- | Compose two substitutions: @m1@ is applied to every type in @m2@, and the
-- result is merged with @m1@ (entries coming from @m2@ win on key clashes).
compose :: Subst -> Subst -> Subst
compose m1 m2 = M.union (M.map (apply m1) m2) m1
-- -}
-- | Things that contain type variables: they can report their free variables
-- and have a substitution applied to them.
class FreeVars t where
    -- | Get all free variables from t
    free :: t -> Set Ident
    -- | Apply a substitution to t
    apply :: Subst -> t -> t
-- | Every polymorphic variable of a type is free; a substitution replaces the
-- variables it has an entry for, in a single pass.
instance FreeVars Type where
    free :: Type -> Set Ident
    free = \case
        TPol a   -> S.singleton a
        TMono _  -> mempty
        TArr a b -> S.union (free a) (free b)
    apply :: Subst -> Type -> Type
    apply sub = \case
        TMono a  -> TMono a
        TPol a   -> maybe (TPol a) id (M.lookup a sub)
        TArr a b -> TArr (apply sub a) (apply sub b)
-- | Quantified variables are not free, and a substitution never touches them.
instance FreeVars Poly where
    free :: Poly -> Set Ident
    free (Forall xs t) = S.difference (free t) (S.fromList xs)
    apply :: Subst -> Poly -> Poly
    apply s (Forall xs t) = Forall xs (apply unbound t)
      where
        -- The substitution without the entries for quantified variables.
        unbound = foldr M.delete s xs
-- | An environment's free variables are those of all its schemes; a
-- substitution is applied to every scheme.
instance FreeVars (Map Ident Poly) where
    free :: Map Ident Poly -> Set Ident
    free m = foldl' (\acc scheme -> S.union acc (free scheme)) S.empty (M.elems m)
    apply :: Subst -> Map Ident Poly -> Map Ident Poly
    apply s env = M.map (apply s) env
-- | Run an action in a context whose variable types have the substitution
-- applied.
applySt :: Subst -> Infer a -> Infer a
applySt s action = local substituted action
  where
    substituted ctx = ctx { vars = apply s (vars ctx) }
-- | The empty substitution.
nullSubst :: Subst
nullSubst = mempty
-- | Produce the type variable @t<count>@ and increment the counter.
fresh :: Infer Type
fresh = state step
  where
    step st =
        let n = count st
         in (TPol (Ident ("t" ++ show n)), st { count = n + 1 })
-- | Run an action with one extra variable binding in the context.
withBinding :: (Monad m, MonadReader Ctx m) => Ident -> Poly -> m a -> m a
withBinding i p = local extend
  where
    extend ctx = ctx { vars = M.insert i p (vars ctx) }
-- | Record a function's signature in the state.
insertSig :: Ident -> Type -> Infer ()
insertSig i t = modify addSig
  where
    addSig st = st { sigs = M.insert i t (sigs st) }

View file

@ -1,74 +1,99 @@
-- {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
module TypeChecker.TypeCheckerIr --(
-- TProgram (..),
-- TBind (..),
-- TExp (..),
-- RProgram (..),
-- RBind (..),
-- RExp (..),
-- Type (..),
-- Const (..),
-- Ident (..),
-- ) where
module TypeChecker.TypeCheckerIr
( module Grammar.Abs
, module TypeChecker.TypeCheckerIr
) where
-- import Grammar.Print
-- import Renamer.RenamerIr
import Grammar.Abs (Ident (..), Literal (..), Type (..))
import Grammar.Print
import Prelude
import qualified Prelude as C (Eq, Ord, Read, Show)
-- newtype TProgram = TProgram [TBind]
-- deriving (Eq, Show, Read, Ord)
-- | A type-checked program: a list of typed bindings.
newtype Program = Program [Bind]
    deriving (C.Eq, C.Ord, C.Show, C.Read)
-- data TBind = TBind Ident Type TExp
-- deriving (Eq, Show, Read, Ord)
-- | Typed expressions.
data Exp
    = EId Id              -- ^ identifier together with its type
    | ELit Type Literal   -- ^ literal
    | ELet Bind Exp       -- ^ let: the local binding and the body
    | EApp Type Exp Exp   -- ^ application
    | EAdd Type Exp Exp   -- ^ addition
    | EAbs Type Id Exp    -- ^ abstraction: typed parameter and body
    deriving (C.Eq, C.Ord, C.Read, C.Show)
-- data TExp
-- = TAnn TExp Type
-- | TBound Integer Ident Type
-- | TFree Ident Type
-- | TConst Const Type
-- | TApp TExp TExp Type
-- | TAdd TExp TExp Type
-- | TAbs Integer Ident TExp Type
-- deriving (Eq, Ord, Show, Read)
-- | An identifier paired with its type; the identifier comes first.
type Id = (Ident, Type)
-- | A typed binding: its typed name, its typed parameters and its body.
data Bind = Bind Id [Id] Exp
    deriving (C.Eq, C.Ord, C.Show, C.Read)
-- | A program is printed as its list of bindings.
instance Print Program where
    prt i (Program binds) = prPrec i 0 (prt 0 binds)
-- | Prints a binding as @name : type params = rhs@.
instance Print Bind where
    -- 'Id' is @(Ident, Type)@: the name is the first component. The previous
    -- pattern bound the two the other way round, so the type was printed
    -- before the colon and the name after it.
    prt i (Bind (name, t) parms rhs) = prPrec i 0 $ concatD
        [ prt 0 name
        , doc $ showString ":"
        , prt 0 t
        , prtIdPs 0 parms
        , doc $ showString "="
        , prt 0 rhs
        ]
-- | Bindings are printed separated by semicolons, with none after the last.
instance Print [Bind] where
    prt _ = \case
        []     -> concatD []
        [x]    -> concatD [prt 0 x]
        x : xs -> concatD [prt 0 x, doc (showString ";"), prt 0 xs]
-- | Print a list of typed identifiers, each one in parentheses.
prtIdPs :: Int -> [Id] -> Doc
prtIdPs i ids = prPrec i 0 (concatD [prtIdP i x | x <- ids])
-- | Print a typed identifier as @name : type@.
prtId :: Int -> Id -> Doc
prtId i (name, t) = prPrec i 0 (concatD parts)
  where
    parts = [prt 0 name, doc (showString ":"), prt 0 t]
-- | Print a typed identifier as @( name : type )@.
prtIdP :: Int -> Id -> Doc
prtIdP i (name, t) = prPrec i 0 (concatD parts)
  where
    tok = doc . showString
    parts = [tok "(", prt 0 name, tok ":", prt 0 t, tok ")"]
-- | Pretty-printer for typed expressions. Additions and abstractions are
-- prefixed with @\@@ followed by their type; applications are printed
-- without a type.
instance Print Exp where
    prt i expr = case expr of
        EId n -> prPrec i 3 (concatD [prtId 0 n])
        ELit _ (LInt i1) -> prPrec i 3 (concatD [prt 0 i1])
        ELet bs e -> prPrec i 3 (concatD [kw "let", prt 0 bs, kw "in", prt 0 e])
        EApp _ e1 e2 -> prPrec i 2 (concatD [prt 2 e1, prt 3 e2])
        EAdd t e1 e2 ->
            prPrec i 1 (concatD [kw "@", prt 0 t, prt 1 e1, kw "+", prt 2 e2])
        EAbs t n e ->
            prPrec i 0 (concatD [kw "@", prt 0 t, kw "\\", prtId 0 n, kw ".", prt 0 e])
      where
        -- A literal token of the output.
        kw = doc . showString
-- instance Print TProgram where
-- prt i = \case
-- TProgram defs -> prPrec i 0 (concatD [prt 0 defs])
-- instance Print TBind where
-- prt i = \case
-- TBind x t e ->
-- prPrec i 0 $
-- concatD
-- [ prt 0 x
-- , doc (showString ":")
-- , prt 0 t
-- , doc (showString "=")
-- , prt 0 e
-- , doc (showString "\n")
-- ]
-- instance Print TExp where
-- prt i = \case
-- TAnn e t ->
-- prPrec i 2 $
-- concatD
-- [ prt 0 e
-- , doc (showString ":")
-- , prt 1 t
-- ]
-- TBound _ u t -> prPrec i 3 $ concatD [prt 0 u]
-- TFree u t -> prPrec i 3 $ concatD [prt 0 u]
-- TConst c _ -> prPrec i 3 (concatD [prt 0 c])
-- TApp e e1 t -> prPrec i 2 $ concatD [prt 2 e, prt 3 e1]
-- TAdd e e1 t -> prPrec i 1 $ concatD [prt 1 e, doc (showString "+"), prt 2 e1]
-- TAbs _ u e t ->
-- prPrec i 0 $
-- concatD
-- [ doc (showString "(")
-- , doc (showString "λ")
-- , prt 0 u
-- , doc (showString ".")
-- , prt 0 e
-- , doc (showString ")")
-- ]

View file

@ -1,3 +1,2 @@
fun : Mono Int -> Mono Int ;
fun = let f = \x. x in f 3 ;
main : _Int ;
main = 3 + 3 ;

View file

@ -1,21 +0,0 @@
{-# LANGUAGE OverloadedStrings #-}
module Main where
import Grammar.Abs
import System.Exit (exitFailure)
import Test.Hspec
import TypeChecker.AlgoW
-- | Placeholder test runner: prints a banner and exits with a failure code.
main :: IO ()
main = do
    -- The annotation fixes the literal's type: with OverloadedStrings a bare
    -- literal passed to 'print' has an ambiguous type.
    print ("RUNNING TESTS BROTHER" :: String)
    exitFailure
-- hspec $ do
-- describe "the algorithm W" $ do
-- it "infers EInt as type Int" $ do
-- fmap fst (run (inferExp (EInt 1))) `shouldBe` Right (TMono "Int")
-- it "throws an exception if a variable is inferred with an empty env" $ do
-- run (inferExp (EId "x")) `shouldBe` Left "Unbound variable: x"
-- it "throws an exception if the annotated type does not match the inferred type" $ do
-- fmap fst (run (inferExp (EAnn (EInt 3) (TPol "a")))) `shouldBe` Right (TMono "bad")

56
tests/Tests.hs Normal file
View file

@ -0,0 +1,56 @@
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use <$>" #-}
module Main where
import Control.Monad.Except
import Grammar.Abs
import Test.QuickCheck
import TypeChecker.TypeChecker
import qualified TypeChecker.TypeCheckerIr as T
-- | Run the QuickCheck properties one after the other.
main :: IO ()
main = quickCheck prop_isInt >> quickCheck prop_idAbs_generic
-- | Wrapper for generated expressions; 'genLambda' produces abstractions of
-- the form @\\x. x@.
newtype AbsExp = AE Exp deriving Show
-- | Wrapper for generated expressions; 'genInt' produces integer literals.
newtype EIntExp = EI Exp deriving Show
-- | Random integer-literal expressions.
instance Arbitrary EIntExp where
    arbitrary = EI . ELit . LInt <$> arbitrary
-- | Random identity abstractions @\\x. x@ over an arbitrary identifier.
instance Arbitrary AbsExp where
    arbitrary = (\name -> AE (EAbs name (EId name))) . Ident <$> arbitrary
-- | Run an inference action and keep only the inferred type.
getType :: Infer (Type, T.Exp) -> Either Error Type
getType ie = fst <$> run ie
-- | Generator for integer-literal expressions.
genInt :: Gen EIntExp
genInt = fmap (\n -> EI (ELit (LInt n))) arbitrary
-- | Generator for the identity abstraction @\\x. x@ over an arbitrary
-- identifier.
genLambda :: Gen AbsExp
genLambda = do
    -- Applying 'Ident' fixes the generated value's type, so no type
    -- application (and no TypeApplications extension) is needed.
    name <- Ident <$> arbitrary
    return $ AE $ EAbs name (EId name)
-- | The inferred type of an identity abstraction is an arrow between the
-- same type variable. An inference failure counts as a failed property.
prop_idAbs_generic :: AbsExp -> Bool
prop_idAbs_generic (AE e) = either (const False) isGenericArr (getType (inferExp e))
-- | The inferred type of an integer literal is 'int'. An inference failure
-- counts as a failed property.
prop_isInt :: EIntExp -> Bool
prop_isInt (EI e) = either (const False) (== int) (getType (inferExp e))
-- | The monomorphic type @Int@, the expected type of integer literals.
int :: Type
int = TMono "Int"
-- | Is the type an arrow from a type variable to that same type variable?
isGenericArr :: Type -> Bool
isGenericArr t = case t of
    TArr (TPol a) (TPol b) -> a == b
    _                      -> False