Add implicit foralls for bidir, update and unify pipeline

This commit is contained in:
Martin Fredin 2023-04-03 17:34:33 +02:00
parent 12bca1c32d
commit 9870802371
33 changed files with 1010 additions and 1055 deletions

View file

@ -75,8 +75,7 @@ PInj. Pattern ::= UIdent [Pattern1];
-- * AUX -- * AUX
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
layout "of", "where", "let"; layout "of", "where";
layout stop "in";
layout toplevel; layout toplevel;
separator Def ";"; separator Def ";";

View file

@ -1,219 +0,0 @@
let SessionLoad = 1
let s:so_save = &g:so | let s:siso_save = &g:siso | setg so=0 siso=0 | setl so=-1 siso=-1
let v:this_session=expand("<sfile>:p")
silent only
silent tabonly
cd ~/Documents/bachelor_thesis/language
if expand('%') == '' && !&modified && line('$') <= 1 && getline(1) == ''
let s:wipebuf = bufnr('%')
endif
let s:shortmess_save = &shortmess
if &shortmess =~ 'A'
set shortmess=aoOA
else
set shortmess=aoO
endif
badd +1 ~/Documents/bachelor_thesis/language
badd +298 src/TypeChecker/TypeChecker.hs
badd +7 test_program
badd +46 src/TypeChecker/TypeCheckerIr.hs
badd +6 Grammar.cf
badd +1 src/Grammar/Abs.hs
argglobal
%argdel
$argadd ~/Documents/bachelor_thesis/language
set stal=2
tabnew +setlocal\ bufhidden=wipe
tabnew +setlocal\ bufhidden=wipe
tabnew +setlocal\ bufhidden=wipe
tabrewind
edit src/TypeChecker/TypeChecker.hs
let s:save_splitbelow = &splitbelow
let s:save_splitright = &splitright
set splitbelow splitright
wincmd _ | wincmd |
vsplit
1wincmd h
wincmd w
let &splitbelow = s:save_splitbelow
let &splitright = s:save_splitright
wincmd t
let s:save_winminheight = &winminheight
let s:save_winminwidth = &winminwidth
set winminheight=0
set winheight=1
set winminwidth=0
set winwidth=1
exe 'vert 1resize ' . ((&columns * 99 + 86) / 173)
exe 'vert 2resize ' . ((&columns * 73 + 86) / 173)
argglobal
setlocal fdm=manual
setlocal fde=0
setlocal fmr={{{,}}}
setlocal fdi=#
setlocal fdl=0
setlocal fml=1
setlocal fdn=20
setlocal fen
silent! normal! zE
let &fdl = &fdl
let s:l = 298 - ((18 * winheight(0) + 21) / 42)
if s:l < 1 | let s:l = 1 | endif
keepjumps exe s:l
normal! zt
keepjumps 298
normal! 029|
lcd ~/Documents/bachelor_thesis/language
wincmd w
argglobal
if bufexists(fnamemodify("~/Documents/bachelor_thesis/language/Grammar.cf", ":p")) | buffer ~/Documents/bachelor_thesis/language/Grammar.cf | else | edit ~/Documents/bachelor_thesis/language/Grammar.cf | endif
if &buftype ==# 'terminal'
silent file ~/Documents/bachelor_thesis/language/Grammar.cf
endif
balt ~/Documents/bachelor_thesis/language/src/TypeChecker/TypeChecker.hs
setlocal fdm=manual
setlocal fde=0
setlocal fmr={{{,}}}
setlocal fdi=#
setlocal fdl=0
setlocal fml=1
setlocal fdn=20
setlocal fen
silent! normal! zE
let &fdl = &fdl
let s:l = 7 - ((6 * winheight(0) + 21) / 42)
if s:l < 1 | let s:l = 1 | endif
keepjumps exe s:l
normal! zt
keepjumps 7
normal! 0
lcd ~/Documents/bachelor_thesis/language
wincmd w
exe 'vert 1resize ' . ((&columns * 99 + 86) / 173)
exe 'vert 2resize ' . ((&columns * 73 + 86) / 173)
tabnext
edit ~/Documents/bachelor_thesis/language/src/TypeChecker/TypeCheckerIr.hs
let s:save_splitbelow = &splitbelow
let s:save_splitright = &splitright
set splitbelow splitright
wincmd _ | wincmd |
vsplit
1wincmd h
wincmd w
let &splitbelow = s:save_splitbelow
let &splitright = s:save_splitright
wincmd t
let s:save_winminheight = &winminheight
let s:save_winminwidth = &winminwidth
set winminheight=0
set winheight=1
set winminwidth=0
set winwidth=1
exe 'vert 1resize ' . ((&columns * 86 + 86) / 173)
exe 'vert 2resize ' . ((&columns * 86 + 86) / 173)
argglobal
balt ~/Documents/bachelor_thesis/language/test_program
setlocal fdm=manual
setlocal fde=0
setlocal fmr={{{,}}}
setlocal fdi=#
setlocal fdl=0
setlocal fml=1
setlocal fdn=20
setlocal fen
silent! normal! zE
let &fdl = &fdl
let s:l = 1 - ((0 * winheight(0) + 21) / 42)
if s:l < 1 | let s:l = 1 | endif
keepjumps exe s:l
normal! zt
keepjumps 1
normal! 0
lcd ~/Documents/bachelor_thesis/language
wincmd w
argglobal
if bufexists(fnamemodify("~/Documents/bachelor_thesis/language/src/Grammar/Abs.hs", ":p")) | buffer ~/Documents/bachelor_thesis/language/src/Grammar/Abs.hs | else | edit ~/Documents/bachelor_thesis/language/src/Grammar/Abs.hs | endif
if &buftype ==# 'terminal'
silent file ~/Documents/bachelor_thesis/language/src/Grammar/Abs.hs
endif
balt ~/Documents/bachelor_thesis/language/src/TypeChecker/TypeCheckerIr.hs
setlocal fdm=manual
setlocal fde=0
setlocal fmr={{{,}}}
setlocal fdi=#
setlocal fdl=0
setlocal fml=1
setlocal fdn=20
setlocal fen
silent! normal! zE
let &fdl = &fdl
let s:l = 1 - ((0 * winheight(0) + 21) / 42)
if s:l < 1 | let s:l = 1 | endif
keepjumps exe s:l
normal! zt
keepjumps 1
normal! 0
lcd ~/Documents/bachelor_thesis/language
wincmd w
exe 'vert 1resize ' . ((&columns * 86 + 86) / 173)
exe 'vert 2resize ' . ((&columns * 86 + 86) / 173)
tabnext
edit ~/Documents/bachelor_thesis/language/Grammar.cf
argglobal
balt ~/Documents/bachelor_thesis/language/src/Grammar/Abs.hs
setlocal fdm=manual
setlocal fde=0
setlocal fmr={{{,}}}
setlocal fdi=#
setlocal fdl=0
setlocal fml=1
setlocal fdn=20
setlocal fen
silent! normal! zE
let &fdl = &fdl
let s:l = 40 - ((12 * winheight(0) + 21) / 42)
if s:l < 1 | let s:l = 1 | endif
keepjumps exe s:l
normal! zt
keepjumps 40
normal! 0
lcd ~/Documents/bachelor_thesis/language
tabnext
edit ~/Documents/bachelor_thesis/language/test_program
argglobal
balt ~/Documents/bachelor_thesis/language/src/TypeChecker/TypeChecker.hs
setlocal fdm=manual
setlocal fde=0
setlocal fmr={{{,}}}
setlocal fdi=#
setlocal fdl=0
setlocal fml=1
setlocal fdn=20
setlocal fen
silent! normal! zE
let &fdl = &fdl
let s:l = 7 - ((6 * winheight(0) + 21) / 42)
if s:l < 1 | let s:l = 1 | endif
keepjumps exe s:l
normal! zt
keepjumps 7
normal! 010|
lcd ~/Documents/bachelor_thesis/language
tabnext 1
set stal=1
if exists('s:wipebuf') && len(win_findbuf(s:wipebuf)) == 0 && getbufvar(s:wipebuf, '&buftype') isnot# 'terminal'
silent exe 'bwipe ' . s:wipebuf
endif
unlet! s:wipebuf
set winheight=1 winwidth=20
let &shortmess = s:shortmess_save
let s:sx = expand("<sfile>:p:r")."x.vim"
if filereadable(s:sx)
exe "source " . fnameescape(s:sx)
endif
let &g:so = s:so_save | let &g:siso = s:siso_save
set hlsearch
nohlsearch
doautoall SessionLoadPost
unlet SessionLoad
" vim: set ft=vim :

View file

@ -35,10 +35,12 @@ executable language
Auxiliary Auxiliary
Renamer.Renamer Renamer.Renamer
TypeChecker.TypeChecker TypeChecker.TypeChecker
AnnForall
TypeChecker.TypeCheckerHm TypeChecker.TypeCheckerHm
TypeChecker.TypeCheckerBidir TypeChecker.TypeCheckerBidir
TypeChecker.TypeCheckerIr TypeChecker.TypeCheckerIr
TypeChecker.RemoveTEVar TypeChecker.ReportTEVar
TypeChecker.RemoveForall
LambdaLifter LambdaLifter
Monomorphizer.Monomorphizer Monomorphizer.Monomorphizer
Monomorphizer.MonomorphizerIr Monomorphizer.MonomorphizerIr
@ -72,11 +74,14 @@ executable language
Test-suite language-testsuite Test-suite language-testsuite
type: exitcode-stdio-1.0 type: exitcode-stdio-1.0
main-is: Tests.hs main-is: Main.hs
other-modules: other-modules:
TestTypeCheckerBidir TestTypeCheckerBidir
TestTypeCheckerHm TestTypeCheckerHm
TestAnnForall
TestReportForall
TestRenamer
Grammar.Abs Grammar.Abs
Grammar.Lex Grammar.Lex
@ -90,13 +95,16 @@ Test-suite language-testsuite
Monomorphizer.MonomorphizerIr Monomorphizer.MonomorphizerIr
Renamer.Renamer Renamer.Renamer
TypeChecker.TypeChecker TypeChecker.TypeChecker
AnnForall
ReportForall
TypeChecker.TypeCheckerHm TypeChecker.TypeCheckerHm
TypeChecker.TypeCheckerBidir TypeChecker.TypeCheckerBidir
TypeChecker.RemoveTEVar TypeChecker.ReportTEVar
TypeChecker.RemoveForall
TypeChecker.TypeCheckerIr TypeChecker.TypeCheckerIr
Compiler Compiler
hs-source-dirs: src, tests, tests/TypecheckingHM hs-source-dirs: src, tests
build-depends: build-depends:
base >=4.16 base >=4.16
@ -110,6 +118,7 @@ Test-suite language-testsuite
, process , process
, bytestring , bytestring
, hspec , hspec
, directory
default-language: GHC2021 default-language: GHC2021

27
pipeline.txt Normal file
View file

@ -0,0 +1,27 @@
Parser
|
ReportForall Report unnecessary foralls. Hm: report rank>2 foralls
|
AnnotateForall Annotate all unbound type variables with foralls
|
Renamer Rename type variables and term variables
|
/ \
/ \
TypeCheckHm TypeCheckBi
\ /
\ /
|
ReportTEVar Report type existential variables and change type AST
|
RemoveForall RemoveForall and change type AST
|
Monomorpher
|
Desugar
|
CodeGen

View file

@ -10,6 +10,10 @@ even : Int -> Bool ()
even x = not (odd x) even x = not (odd x)
odd x = not (even x) odd x = not (even x)
main = case even 64 of
True => 1
False => 0

View file

@ -1,9 +1,13 @@
data Bool () where { data Bool () where
True : Bool () True : Bool ()
False : Bool () False : Bool ()
};
toBool = case 0 of { toBool x = case x of
0 => False; 0 => False
_ => True; _ => True
};
fromBool b = case b of
False => 0
True => 1
main = fromBool (toBool 10)

View file

@ -0,0 +1,10 @@
applyId : (forall a. a -> a) -> a -> a
applyId f x = f x
id : a -> a
id x = x
main = applyId id 4

View file

@ -1,10 +1,8 @@
data Bool () where { data Bool () where
True : Bool () True : Bool ()
False : Bool () False : Bool ()
};
main : Bool () -> a -> Int ; main : Bool () -> a -> Int
main b = case b of { main b = case b of
False => (\x. 1); False => (\x. 1)
True => \x. 0; True => (\x. 0)
};

View file

@ -1,10 +1,8 @@
data Bool () where { data Bool () where
True : Bool () True : Bool ()
False : Bool () False : Bool ()
};
ifThenElse : forall a. Bool () -> a -> a -> a; ifThenElse : forall a. Bool () -> a -> a -> a
ifThenElse b if else = case b of { ifThenElse b if else = case b of
True => if; True => if
False => else False => else
}

View file

@ -1,24 +1,20 @@
data Maybe (a) where { data Maybe (a) where
Nothing : Maybe (a) Nothing : Maybe (a)
Just : a -> Maybe (a) Just : a -> Maybe (a)
};
fromJust : Maybe (a) -> a ; fromJust : Maybe (a) -> a
fromJust a = fromJust a =
case a of { case a of
Just a => a Just a => a
};
fromMaybe : a -> Maybe (a) -> a ; fromMaybe : a -> Maybe (a) -> a
fromMaybe a b = fromMaybe a b =
case b of { case b of
Just a => a; Just a => a
Nothing => a Nothing => a
};
maybe : b -> (a -> b) -> Maybe (a) -> b; maybe : b -> (a -> b) -> Maybe (a) -> b
maybe b f ma = maybe b f ma =
case ma of { case ma of
Just a => f a; Just a => f a
Nothing => b Nothing => b
}

View file

@ -1,13 +1,9 @@
data List (a) where { data List (a) where
Nil : List (a) Nil : List (a)
Cons : a -> List (a) -> List (a) Cons : a -> List (a) -> List (a)
};
test xs = case xs of {
Cons Nil _ => 0 ;
};
test xs = case xs of
Cons Nil _ => 0
List a /= List (List a) List a /= List (List a)

View file

@ -11,7 +11,8 @@ pkgs.haskellPackages.developPackage {
ghc ghc
jasmin jasmin
llvmPackages_15.libllvm llvmPackages_15.libllvm
texlive.combined.scheme-full clang
# texlive.combined.scheme-full
]) ])
++ ++
(with pkgs.haskellPackages; [ cabal-install (with pkgs.haskellPackages; [ cabal-install

100
src/AnnForall.hs Normal file
View file

@ -0,0 +1,100 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
module AnnForall (annotateForall) where
import Auxiliary (partitionDefs)
import Control.Applicative (Applicative (liftA2))
import Control.Monad.Except (throwError)
import Data.Function (on)
import Data.Set (Set)
import qualified Data.Set as Set
import Grammar.Abs
import Grammar.ErrM (Err)
annotateForall :: Program -> Err Program
annotateForall (Program defs) = do
ds' <- mapM (fmap DData . annData) ds
bs' <- mapM (fmap DBind . annBind) bs
pure $ Program (ds' ++ ss' ++ bs')
where
ss' = map (DSig . annSig) ss
(ds, ss, bs) = partitionDefs defs
annData :: Data -> Err Data
annData (Data typ injs) = do
(typ', tvars) <- annTyp typ
pure (Data typ' $ map (annInj tvars) injs)
where
annTyp typ = do
(bounded, ts) <- boundedTVars mempty typ
unbounded <- Set.fromList <$> mapM assertTVar ts
let diff = unbounded Set.\\ bounded
typ' = foldr TAll typ diff
(typ', ) . fst <$> boundedTVars mempty typ'
where
boundedTVars tvars typ = case typ of
TAll tvar t -> boundedTVars (Set.insert tvar tvars) t
TData _ ts -> pure (tvars, ts)
_ -> throwError "Misformed data declaration"
assertTVar typ = case typ of
TVar tvar -> pure tvar
_ -> throwError $ unwords [ "Misformed data declaration:"
, "Non type variable argument"
]
annInj tvars (Inj n t) =
Inj n $ foldr TAll t (unboundedTVars t Set.\\ tvars)
annSig :: Sig -> Sig
annSig (Sig name typ) = Sig name $ annType typ
annBind :: Bind -> Err Bind
annBind (Bind name vars exp) = Bind name vars <$> annExp exp
where
annExp = \case
EAnn e t -> flip EAnn (annType t) <$> annExp e
EApp e1 e2 -> liftA2 EApp (annExp e1) (annExp e2)
EAdd e1 e2 -> liftA2 EAdd (annExp e1) (annExp e2)
ELet bind e -> liftA2 ELet (annBind bind) (annExp e)
EAbs x e -> EAbs x <$> annExp e
ECase e bs -> liftA2 ECase (annExp e) (mapM annBranch bs)
e -> pure e
annBranch (Branch p e) = Branch p <$> annExp e
annType :: Type -> Type
annType typ = go $ unboundedTVars typ
where
go us
| null us = typ
| otherwise = foldr TAll typ us
unboundedTVars :: Type -> Set TVar
unboundedTVars = unboundedTVars' mempty
unboundedTVars' :: Set TVar -> Type -> Set TVar
unboundedTVars' bs typ = tvars.unbounded Set.\\ tvars.bounded
where
tvars = gatherTVars typ
gatherTVars = \case
TAll tvar t -> TVars { bounded = Set.singleton tvar
, unbounded = unboundedTVars' (Set.insert tvar bs) t
}
TVar tvar -> uTVars $ Set.singleton tvar
TFun t1 t2 -> uTVars $ on Set.union (unboundedTVars' bs) t1 t2
TData _ typs -> uTVars $ foldr (Set.union . unboundedTVars' bs) mempty typs
_ -> TVars { bounded = mempty, unbounded = mempty }
data TVars = TVars
{ bounded :: Set TVar
, unbounded :: Set TVar
} deriving (Eq, Show, Ord)
uTVars :: Set TVar -> TVars
uTVars us = TVars
{ bounded = mempty
, unbounded = us
}

View file

@ -1,14 +1,16 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
module Auxiliary (module Auxiliary) where module Auxiliary (module Auxiliary) where
import Control.Monad.Error.Class (liftEither) import Control.Applicative (Applicative (liftA2))
import Control.Monad.Except (MonadError) import Control.Monad.Error.Class (liftEither)
import Data.Either.Combinators (maybeToRight) import Control.Monad.Except (MonadError)
import Data.List (foldl') import Data.Either.Combinators (maybeToRight)
import Grammar.Abs import Data.List (foldl')
import Prelude hiding ((>>), (>>=)) import Grammar.Abs
import Prelude hiding ((>>), (>>=))
(>>) a b = a ++ " " ++ b (>>) a b = a ++ " " ++ b
(>>=) a f = f a (>>=) a f = f a
@ -29,6 +31,9 @@ mapAccumM f = go
(acc'', xs') <- go acc' xs (acc'', xs') <- go acc' xs
pure (acc'', x' : xs') pure (acc'', x' : xs')
onM :: Monad m => (b -> b -> c) -> (a -> m b) -> a -> a -> m c
onM f g x y = liftA2 f (g x) (g y)
unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d]) unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d])
unzip4 = unzip4 =
foldl' foldl'
@ -38,7 +43,7 @@ unzip4 =
([], [], [], []) ([], [], [], [])
litType :: Lit -> Type litType :: Lit -> Type
litType (LInt _) = int litType (LInt _) = int
litType (LChar _) = char litType (LChar _) = char
int = TLit "Int" int = TLit "Int"
@ -53,3 +58,10 @@ trd_ :: (a, b, c) -> c
snd_ (_, a, _) = a snd_ (_, a, _) = a
fst_ (a, _, _) = a fst_ (a, _, _) = a
trd_ (_, _, a) = a trd_ (_, _, a) = a
partitionDefs :: [Def] -> ([Data], [Sig], [Bind])
partitionDefs defs = (datas, sigs, binds)
where
datas = [ d | DData d <- defs ]
sigs = [ s | DSig s <- defs ]
binds = [ b | DBind b <- defs ]

View file

@ -178,27 +178,14 @@ abstractExp (free, (exp, typ)) = case exp of
names = snoc parm freeList names = snoc parm freeList
applyVars (e, t) name = (EApp (e, t) (EVar name, t_var), t_return) applyVars (e, t) name = (EApp (e, t) (EVar name, t_var), t_return)
where where
(t_var, t_return) = applyVarType t (t_var, t_return) = case t of
TFun t1 t2 -> (t1, t2)
abstractBranch :: AnnBranch -> State Int Branch abstractBranch :: AnnBranch -> State Int Branch
abstractBranch (_, AnnBranch patt exp) = Branch patt <$> abstractExp exp abstractBranch (_, AnnBranch patt exp) = Branch patt <$> abstractExp exp
applyVarType :: Type -> (Type, Type)
applyVarType typ = (t1, foldr ($) t2 foralls)
where
(t1, t2) = case typ' of
TFun t1 t2 -> (t1, t2)
_ -> error "Not a function!"
(foralls, typ') = skipForalls [] typ
skipForalls acc = \case
TAll tvar t -> skipForalls (snoc (TAll tvar) acc) t
t -> (acc, t)
nextNumber :: State Int Int nextNumber :: State Int Int
nextNumber = do nextNumber = do
i <- get i <- get
@ -270,20 +257,9 @@ getVars :: Type -> [Type]
getVars = fst . partitionType getVars = fst . partitionType
partitionType :: Type -> ([Type], Type) partitionType :: Type -> ([Type], Type)
partitionType = go [] . skipForalls' partitionType = go []
where where
go acc t = case t of go acc t = case t of
TFun t1 t2 -> go (snoc t1 acc) t2 TFun t1 t2 -> go (snoc t1 acc) t2
_ -> (acc, t) _ -> (acc, t)
skipForalls' :: Type -> Type
skipForalls' = snd . skipForalls
skipForalls :: Type -> ([Type -> Type], Type)
skipForalls = go []
where
go acc typ = case typ of
TAll tvar t -> go (snoc (TAll tvar) acc) t
_ -> (acc, typ)

View file

@ -1,11 +1,12 @@
{-# LANGUAGE OverloadedRecordDot #-} {-# LANGUAGE OverloadedRecordDot #-}
module Main where module Main where
import AnnForall (annotateForall)
import Codegen.Codegen (generateCode) import Codegen.Codegen (generateCode)
import Compiler (compile) import Compiler (compile)
import Control.Monad (when) import Control.Monad (when, (<=<))
import Data.Bool (bool)
import Data.List.Extra (isSuffixOf) import Data.List.Extra (isSuffixOf)
import Data.Maybe (fromJust, isNothing) import Data.Maybe (fromJust, isNothing)
import Desugar.Desugar (desugar) import Desugar.Desugar (desugar)
@ -13,10 +14,11 @@ import GHC.IO.Handle.Text (hPutStrLn)
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
import Grammar.Layout (resolveLayout) import Grammar.Layout (resolveLayout)
import Grammar.Par (myLexer, pProgram) import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree) import Grammar.Print (Print, printTree)
import LambdaLifter (lambdaLift) import LambdaLifter (lambdaLift)
import Monomorphizer.Monomorphizer (monomorphize) import Monomorphizer.Monomorphizer (monomorphize)
import Renamer.Renamer (rename) import Renamer.Renamer (rename)
import ReportForall (reportForall)
import System.Console.GetOpt (ArgDescr (NoArg, ReqArg), import System.Console.GetOpt (ArgDescr (NoArg, ReqArg),
ArgOrder (RequireOrder), ArgOrder (RequireOrder),
OptDescr (Option), getOpt, OptDescr (Option), getOpt,
@ -87,35 +89,40 @@ data Options = Options
} }
main' :: Options -> String -> IO () main' :: Options -> String -> IO ()
main' opts s = do main' opts s =
let
log :: (Print a, Show a) => a -> IO ()
log = printToErr . if opts.debug then show else printTree
in do
file <- readFile s file <- readFile s
printToErr "-- Parse Tree -- " printToErr "-- Parse Tree -- "
parsed <- fromSyntaxErr . pProgram . resolveLayout True $ myLexer file parsed <- fromErr . pProgram . resolveLayout True $ myLexer file
bool (printToErr $ printTree parsed) (printToErr $ show parsed) opts.debug log parsed
printToErr "-- Desugar --" printToErr "-- Desugar --"
let desugared = desugar parsed let desugared = desugar parsed
bool (printToErr $ printTree desugared) (printToErr $ show desugared) opts.debug log desugared
printToErr "\n-- Renamer --" printToErr "\n-- Renamer --"
renamed <- fromRenamerErr . rename $ desugared _ <- fromErr $ reportForall (fromJust opts.typechecker) desugared
bool (printToErr $ printTree renamed) (printToErr $ show renamed) opts.debug renamed <- fromErr $ (rename <=< annotateForall) desugared
log renamed
printToErr "\n-- TypeChecker --" printToErr "\n-- TypeChecker --"
typechecked <- fromTypeCheckerErr $ typecheck (fromJust opts.typechecker) renamed typechecked <- fromErr $ typecheck (fromJust opts.typechecker) renamed
bool (printToErr $ printTree typechecked) (printToErr $ show typechecked) opts.debug log typechecked
printToErr "\n-- Lambda Lifter --" printToErr "\n-- Lambda Lifter --"
let lifted = lambdaLift typechecked let lifted = lambdaLift typechecked
bool (printToErr $ printTree lifted) (printToErr $ show lifted) opts.debug log lifted
printToErr "\n -- Monomorphizer --" printToErr "\n -- Monomorphizer --"
let monomorphized = monomorphize lifted let monomorphized = monomorphize lifted
bool (printToErr $ printTree monomorphized) (printToErr $ show monomorphized) opts.debug log lifted
printToErr "\n -- Compiler --" printToErr "\n -- Compiler --"
generatedCode <- fromCompilerErr $ generateCode monomorphized generatedCode <- fromErr $ generateCode monomorphized
check <- doesPathExist "output" check <- doesPathExist "output"
when check (removeDirectoryRecursive "output") when check (removeDirectoryRecursive "output")
@ -143,55 +150,9 @@ debugDotViz = do
spawnWait :: String -> IO ExitCode spawnWait :: String -> IO ExitCode
spawnWait s = spawnCommand s >>= waitForProcess spawnWait s = spawnCommand s >>= waitForProcess
printToErr :: String -> IO () printToErr :: String -> IO ()
printToErr = hPutStrLn stderr printToErr = hPutStrLn stderr
fromCompilerErr :: Err a -> IO a fromErr :: Err a -> IO a
fromCompilerErr = fromErr = either (\s -> printToErr s >> exitFailure) pure
either
( \err -> do
putStrLn "\nCOMPILER ERROR"
putStrLn err
exitFailure
)
pure
fromSyntaxErr :: Err a -> IO a
fromSyntaxErr =
either
( \err -> do
putStrLn "\nSYNTAX ERROR"
putStrLn err
exitFailure
)
pure
fromTypeCheckerErr :: Err a -> IO a
fromTypeCheckerErr =
either
( \err -> do
putStrLn "\nTYPECHECKER ERROR"
putStrLn err
exitFailure
)
pure
fromRenamerErr :: Err a -> IO a
fromRenamerErr =
either
( \err -> do
putStrLn "\nRENAMER ERROR"
putStrLn err
exitFailure
)
pure
fromInterpreterErr :: Err a -> IO a
fromInterpreterErr =
either
( \err -> do
putStrLn "\nINTERPRETER ERROR"
putStrLn err
exitFailure
)
pure

View file

@ -7,37 +7,40 @@
-- monomorphic bindings will be part of this compilation step. -- monomorphic bindings will be part of this compilation step.
-- Apply the following monomorphization function on all monomorphic binds, with -- Apply the following monomorphization function on all monomorphic binds, with
-- their type as an additional argument. -- their type as an additional argument.
-- --
-- The function that transforms Binds operates on both monomorphic and -- The function that transforms Binds operates on both monomorphic and
-- polymorphic functions, creates a context in which all possible polymorphic types -- polymorphic functions, creates a context in which all possible polymorphic types
-- are mapped to concrete types, created using the additional argument. -- are mapped to concrete types, created using the additional argument.
-- Expressions are then recursively processed. The type of these expressions -- Expressions are then recursively processed. The type of these expressions
-- are changed to using the mapped generic types. The expected type provided -- are changed to using the mapped generic types. The expected type provided
-- in the recursion is changed depending on the different nodes. -- in the recursion is changed depending on the different nodes.
-- --
-- When an external bind is encountered (with EId), it is checked whether it -- When an external bind is encountered (with EId), it is checked whether it
-- exists in outputed binds or not. If it does, nothing further is evaluated. -- exists in outputed binds or not. If it does, nothing further is evaluated.
-- If not, the bind transformer function is called on it with the -- If not, the bind transformer function is called on it with the
-- expected type in this context. The result of this computation (a monomorphic -- expected type in this context. The result of this computation (a monomorphic
-- bind) is added to the resulting set of binds. -- bind) is added to the resulting set of binds.
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
module Monomorphizer.Monomorphizer (monomorphize, morphExp, morphBind) where module Monomorphizer.Monomorphizer (monomorphize, morphExp, morphBind) where
import qualified TypeChecker.TypeCheckerIr as T import Monomorphizer.DataTypeRemover (removeDataTypes)
import TypeChecker.TypeCheckerIr (Ident (Ident))
import qualified Monomorphizer.MorbIr as M
import qualified Monomorphizer.MonomorphizerIr as O import qualified Monomorphizer.MonomorphizerIr as O
import Monomorphizer.DataTypeRemover (removeDataTypes) import qualified Monomorphizer.MorbIr as M
import qualified TypeChecker.TypeCheckerIr as T
import TypeChecker.TypeCheckerIr (Ident (Ident))
import Debug.Trace import Control.Monad.Reader (MonadReader (ask, local),
import Control.Monad.State (MonadState (get), gets, modify, StateT (runStateT)) Reader, asks, runReader)
import qualified Data.Map as Map import Control.Monad.State (MonadState (get),
import qualified Data.Set as Set StateT (runStateT), gets,
import Data.Maybe (fromJust) modify)
import Control.Monad.Reader (Reader, MonadReader (local, ask), asks, runReader) import Data.Coerce (coerce)
import Data.Coerce (coerce) import qualified Data.Map as Map
import Grammar.Print (printTree) import Data.Maybe (fromJust)
import qualified Data.Set as Set
import Debug.Trace
import Grammar.Print (printTree)
-- | State Monad wrapper for "Env". -- | State Monad wrapper for "Env".
newtype EnvM a = EnvM (StateT Output (Reader Env) a) newtype EnvM a = EnvM (StateT Output (Reader Env) a)
@ -90,9 +93,9 @@ getMain = asks (\env -> fromJust $ Map.lookup (T.Ident "main") (input env))
mapTypes :: T.Type -> M.Type -> [(Ident, M.Type)] mapTypes :: T.Type -> M.Type -> [(Ident, M.Type)]
mapTypes (T.TLit _) (M.TLit _) = [] mapTypes (T.TLit _) (M.TLit _) = []
mapTypes (T.TVar (T.MkTVar i1)) tm = [(i1, tm)] mapTypes (T.TVar (T.MkTVar i1)) tm = [(i1, tm)]
mapTypes (T.TFun pt1 pt2) (M.TFun mt1 mt2) = mapTypes pt1 mt1 ++ mapTypes (T.TFun pt1 pt2) (M.TFun mt1 mt2) = mapTypes pt1 mt1 ++
mapTypes pt2 mt2 mapTypes pt2 mt2
mapTypes (T.TData tIdent pTs) (M.TData mIdent mTs) = if tIdent /= mIdent mapTypes (T.TData tIdent pTs) (M.TData mIdent mTs) = if tIdent /= mIdent
then error "nuh uh" then error "nuh uh"
else foldl (\xs (p, m) -> mapTypes p m ++ xs) [] (zip pTs mTs) else foldl (\xs (p, m) -> mapTypes p m ++ xs) [] (zip pTs mTs)
mapTypes t1 t2 = error $ "structure of types not the same: '" ++ printTree t1 ++ "', '" ++ printTree t2 ++ "'" mapTypes t1 t2 = error $ "structure of types not the same: '" ++ printTree t1 ++ "', '" ++ printTree t2 ++ "'"
@ -111,8 +114,6 @@ getMonoFromPoly t = do env <- ask
Nothing -> M.TLit (Ident "void") Nothing -> M.TLit (Ident "void")
--error $ "type not found! type: " ++ show ident ++ ", error in previous compilation steps" --error $ "type not found! type: " ++ show ident ++ ", error in previous compilation steps"
(T.TData ident args) -> M.TData ident (map (getMono polys) args) (T.TData ident args) -> M.TData ident (map (getMono polys) args)
-- TODO: TAll should work different/should not exist in this tree
(T.TAll _ t) -> getMono polys t
-- | If ident not already in env's output, morphed bind to output -- | If ident not already in env's output, morphed bind to output
-- (and all referenced binds within this bind). -- (and all referenced binds within this bind).
@ -128,14 +129,14 @@ morphBind expectedType b@(T.Bind (Ident _, btype) args (exp, expt)) =
bindMarked <- isBindMarked (coerce name') bindMarked <- isBindMarked (coerce name')
-- Return with right name if already marked -- Return with right name if already marked
if bindMarked then return name' else do if bindMarked then return name' else do
-- Mark so that this bind will not be processed in recursive or cyclic -- Mark so that this bind will not be processed in recursive or cyclic
-- function calls -- function calls
markBind (coerce name') markBind (coerce name')
expt' <- getMonoFromPoly expt expt' <- getMonoFromPoly expt
exp' <- morphExp expt' exp exp' <- morphExp expt' exp
-- Get monomorphic type sof args -- Get monomorphic type sof args
args' <- mapM convertArg args args' <- mapM convertArg args
addOutputBind $ M.Bind (coerce name', expectedType) addOutputBind $ M.Bind (coerce name', expectedType)
args' (exp', expectedType) args' (exp', expectedType)
return name' return name'
@ -162,7 +163,7 @@ getInputData ident = do env <- ask
-- | Expects polymorphic types in data definition to be mapped -- | Expects polymorphic types in data definition to be mapped
-- in environment. -- in environment.
--morphData :: T.Data -> EnvM () --morphData :: T.Data -> EnvM ()
--morphData (T.Data t cs) = do --morphData (T.Data t cs) = do
-- t' <- getMonoFromPoly t -- t' <- getMonoFromPoly t
-- output <- get -- output <- get
-- cs' <- mapM (\(T.Inj ident t) -> do t' <- getMonoFromPoly t -- cs' <- mapM (\(T.Inj ident t) -> do t' <- getMonoFromPoly t
@ -170,7 +171,7 @@ getInputData ident = do env <- ask
-- addOutputData $ M.Data t' cs' -- addOutputData $ M.Data t' cs'
morphCons :: M.Type -> Ident -> EnvM () morphCons :: M.Type -> Ident -> EnvM ()
morphCons expectedType ident = do morphCons expectedType ident = do
maybeD <- getInputData ident maybeD <- getInputData ident
case maybeD of case maybeD of
Nothing -> error $ "identifier '" ++ show ident ++ "' not found" Nothing -> error $ "identifier '" ++ show ident ++ "' not found"
@ -191,7 +192,7 @@ morphCons expectedType ident = do
-- TODO: Change in tree so that these are the same. -- TODO: Change in tree so that these are the same.
-- Converts Lit -- Converts Lit
convertLit :: T.Lit -> M.Lit convertLit :: T.Lit -> M.Lit
convertLit (T.LInt v) = M.LInt v convertLit (T.LInt v) = M.LInt v
convertLit (T.LChar v) = M.LChar v convertLit (T.LChar v) = M.LChar v
morphExp :: M.Type -> T.Exp -> EnvM M.Exp morphExp :: M.Type -> T.Exp -> EnvM M.Exp
@ -204,7 +205,7 @@ morphExp expectedType exp = case exp of
morphApp M.EApp expectedType e1 e2 morphApp M.EApp expectedType e1 e2
T.EAdd e1 e2 -> do T.EAdd e1 e2 -> do
morphApp M.EAdd expectedType e1 e2 morphApp M.EAdd expectedType e1 e2
T.EAbs ident (exp, t) -> local (\env -> env { locals = Set.insert ident (locals env) }) $ do T.EAbs ident (exp, t) -> local (\env -> env { locals = Set.insert ident (locals env) }) $ do
t' <- getMonoFromPoly t t' <- getMonoFromPoly t
morphExp t' exp morphExp t' exp
T.ECase (exp, t) bs -> do T.ECase (exp, t) bs -> do
@ -256,7 +257,7 @@ morphPattern ls = \case
-- | Creates a new identifier for a function with an assigned type -- | Creates a new identifier for a function with an assigned type
newFuncName :: M.Type -> T.Bind -> Ident newFuncName :: M.Type -> T.Bind -> Ident
newFuncName t (T.Bind (ident@(Ident bindName), _) _ _) = newFuncName t (T.Bind (ident@(Ident bindName), _) _ _) =
if bindName == "main" if bindName == "main"
then Ident bindName then Ident bindName
else newName t ident else newName t ident
@ -286,7 +287,7 @@ runEnvM o env (EnvM stateM) = snd $ runReader (runStateT stateM o) env
-- | Creates the environment based on the input binds. -- | Creates the environment based on the input binds.
createEnv :: [T.Def] -> Env createEnv :: [T.Def] -> Env
createEnv defs = Env { input = Map.fromList bindPairs, createEnv defs = Env { input = Map.fromList bindPairs,
dataDefs = Map.fromList dataPairs, dataDefs = Map.fromList dataPairs,
polys = Map.empty, polys = Map.empty,
locals = Set.empty } locals = Set.empty }
@ -312,7 +313,7 @@ getBindsFromDefs = foldl (\bs -> \case
getDefsFromOutput :: Output -> [M.Def] getDefsFromOutput :: Output -> [M.Def]
getDefsFromOutput o = getDefsFromOutput o =
map M.DBind binds ++ map M.DBind binds ++
(map (M.DData . snd) . Map.toList) (createNewData dataInput Map.empty) (map (M.DData . snd) . Map.toList) (createNewData dataInput Map.empty)
where where
(binds, dataInput) = splitBindsAndData o (binds, dataInput) = splitBindsAndData o
@ -323,7 +324,7 @@ splitBindsAndData output = foldl
(\(oBinds, oData) (ident, o) -> case o of (\(oBinds, oData) (ident, o) -> case o of
Incomplete -> error "internal bug in monomorphizer" Incomplete -> error "internal bug in monomorphizer"
Complete b -> (b:oBinds, oData) Complete b -> (b:oBinds, oData)
Data t d -> (oBinds, (ident, t, d):oData)) Data t d -> (oBinds, (ident, t, d):oData))
([], []) ([], [])
(Map.toList output) (Map.toList output)
@ -339,7 +340,7 @@ createNewData ((consIdent, consType, polyData):input) o =
newDataType = getDataType consType newDataType = getDataType consType
newDataName = newName newDataType polyDataIdent newDataName = newName newDataType polyDataIdent
newCons = M.Inj consIdent consType newCons = M.Inj consIdent consType
getDataType :: M.Type -> M.Type getDataType :: M.Type -> M.Type
getDataType (M.TFun t1 t2) = getDataType t2 getDataType (M.TFun t1 t2) = getDataType t2
getDataType tData@(M.TData _ _) = tData getDataType tData@(M.TData _ _) = tData
@ -356,7 +357,7 @@ getDataType _ = error "???"
-- Nothing -> do -- Nothing -> do
-- createNewData cs $ Map.insert ident (M.Data (M.TLit $ Ident "void") [newCons]) o -- createNewData cs $ Map.insert ident (M.Data (M.TLit $ Ident "void") [newCons]) o
-- Just _ -> do -- Just _ -> do
-- createNewData cs $ Map.adjust (\(M.Data _ pcs') -> -- createNewData cs $ Map.adjust (\(M.Data _ pcs') ->
-- M.Data expectedType (newCons : pcs')) ident o -- M.Data expectedType (newCons : pcs')) ident o
-- _ -> error "internal bug in monomorphizer" -- _ -> error "internal bug in monomorphizer"

View file

@ -1,224 +1,112 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-} {-# LANGUAGE OverloadedRecordDot #-}
module Renamer.Renamer (rename) where module Renamer.Renamer (rename) where
import Auxiliary (mapAccumM) import Auxiliary (maybeToRightM, onM, partitionDefs)
import Control.Applicative (Applicative (liftA2)) import Control.Applicative (liftA2)
import Control.Monad (when) import Control.Monad.Except (ExceptT, MonadError, runExceptT)
import Control.Monad.Except ( import Control.Monad.State (MonadState, State, evalState, gets,
ExceptT, modify)
MonadError (catchError, throwError), import Data.Map (Map)
runExceptT, import qualified Data.Map as Map
) import Data.Tuple.Extra (dupe)
import Control.Monad.State ( import Grammar.Abs
MonadState, import Grammar.ErrM (Err)
State, import Grammar.Print (printTree)
StateT,
evalState,
evalStateT,
get,
gets,
lift,
mapAndUnzipM,
modify,
put,
)
import Data.Function (on)
import Data.Map (Map)
import Data.Map qualified as Map
import Data.Maybe (fromMaybe)
import Data.Set (Set)
import Data.Set qualified as Set
import Data.Tuple.Extra (dupe, second)
import Grammar.Abs
import Grammar.ErrM (Err)
import Grammar.Print (printTree)
-- | Rename all variables and local binds -- | Rename all variables and local binds
rename :: Program -> Err Program rename :: Program -> Err Program
rename (Program defs) = Program <$> renameDefs defs rename (Program defs) = rename' $ do
ds' <- mapM (fmap DData . rnData) ds
ss' <- mapM (fmap DSig . rnSig) ss
bs' <- mapM (fmap DBind . rnTopBind) bs
pure $ Program (ds' ++ ss' ++ bs')
where
(ds, ss, bs) = partitionDefs defs
rename' = flip evalState initCxt
. runExceptT
. runRn
initCxt = Cxt
{ counter = 0
, names = Map.fromList $ [ dupe n | Sig n _ <- ss ]
++ [ dupe n | Bind n _ _ <- bs ]
}
rnData :: Data -> Rn Data
rnData (Data typ injs) = liftA2 Data (rnType typ) (mapM rnInj injs)
where
rnInj (Inj name t) = Inj name <$> rnType t
initCxt :: Cxt rnSig :: Sig -> Rn Sig
initCxt = Cxt 0 0 rnSig (Sig name typ) = liftA2 Sig (getName name) (rnType typ)
rnType :: Type -> Rn Type
rnType = \case
TVar (MkTVar name) -> TVar . MkTVar <$> getName name
TData name ts -> TData name <$> localNames (mapM rnType ts)
TFun t1 t2 -> onM TFun (localNames . rnType) t1 t2
TAll (MkTVar name) t -> liftA2 (TAll . MkTVar) (newName name) (rnType t)
typ -> pure typ
rnTopBind :: Bind -> Rn Bind
rnTopBind = rnBind' False
rnLocalBind :: Bind -> Rn Bind
rnLocalBind = rnBind' True
rnBind' :: Bool -> Bind -> Rn Bind
rnBind' isLocal (Bind name vars rhs) = do
name' <- if isLocal then newName name else getName name
(vars', rhs') <- localNames $ liftA2 (,) (mapM newName vars) (rnExp rhs)
pure (Bind name' vars' rhs')
rnExp :: Exp -> Rn Exp
rnExp = \case
EVar x -> EVar <$> getName x
EInj x -> pure (EInj x)
ELit lit -> pure (ELit lit)
EApp e1 e2 -> onM EApp (localNames . rnExp) e1 e2
EAdd e1 e2 -> onM EAdd (localNames . rnExp) e1 e2
ELet bind e -> liftA2 ELet (rnLocalBind bind) (rnExp e)
EAbs x e -> liftA2 EAbs (newName x) (rnExp e)
EAnn e t -> liftA2 EAnn (rnExp e) (rnType t)
ECase e bs -> liftA2 ECase (rnExp e) (mapM (localNames . rnBranch) bs)
rnBranch :: Branch -> Rn Branch
rnBranch (Branch p e) = liftA2 Branch (rnPattern p) (rnExp e)
rnPattern :: Pattern -> Rn Pattern
rnPattern = \case
PVar x -> PVar <$> newName x
PLit lit -> pure (PLit lit)
PCatch -> pure PCatch
PEnum name -> pure (PEnum name)
PInj name ps -> PInj name <$> mapM rnPattern ps
data Cxt = Cxt data Cxt = Cxt
{ var_counter :: Int { counter :: Int
, tvar_counter :: Int , names :: Map LIdent LIdent
} }
-- | Rename monad. State holds the number of renamed names. -- | Rename monad. State holds the number of renamed names.
newtype Rn a = Rn {runRn :: ExceptT String (State Cxt) a} newtype Rn a = Rn {runRn :: ExceptT String (State Cxt) a}
deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String) deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String)
-- | Maps old to new name getName :: LIdent -> Rn LIdent
type Names = Map String String getName name = maybeToRightM err =<< gets (Map.lookup name . names)
where err = "Can't find new name " ++ printTree name
renameDefs :: [Def] -> Err [Def] newName :: LIdent -> Rn LIdent
renameDefs defs = evalState (runExceptT (runRn $ mapM renameDef defs)) initCxt newName name = do
name' <- gets (mk name . counter)
modify $ \cxt -> cxt { counter = succ cxt.counter
, names = Map.insert name name' cxt.names
}
pure name'
where where
initNames = Map.fromList [dupe s | DBind (Bind (LIdent s) _ _) <- defs] mk (LIdent name) i = LIdent ("#" ++ show i ++ name)
renameDef :: Def -> Rn Def localNames :: MonadState Cxt m => m b -> m b
renameDef = \case localNames m = do
DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ old_names <- gets names
DBind (Bind name vars rhs) -> do m <* modify ( \cxt' -> cxt' { names = old_names })
(new_names, vars') <- newNamesL initNames vars
rhs' <- snd <$> renameExp new_names rhs
pure . DBind $ Bind name vars' rhs'
DData (Data typ injs) -> do
tvars <- collectTVars [] typ
tvars' <- mapM nextNameTVar tvars
let tvars_lt = zip tvars tvars'
typ' = substituteTVar tvars_lt typ
injs' = map (renameInj tvars_lt) injs
pure . DData $ Data typ' injs'
where
collectTVars tvars = \case
TAll tvar t -> collectTVars (tvar : tvars) t
TData _ _ -> pure tvars
_ -> throwError ("Bad data type definition: " ++ printTree typ)
renameInj :: [(TVar, TVar)] -> Inj -> Inj
renameInj new_types (Inj name typ) =
Inj name $ substituteTVar new_types typ
substituteTVar :: [(TVar, TVar)] -> Type -> Type
substituteTVar new_names typ = case typ of
TLit _ -> typ
TVar tvar
| Just tvar' <- lookup tvar new_names ->
TVar tvar'
| otherwise ->
typ
TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t
| Just tvar' <- lookup tvar new_names ->
TAll tvar' $ substitute' t
| otherwise ->
TAll tvar $ substitute' t
TData name typs -> TData name $ map substitute' typs
_ -> error ("Impossible " ++ show typ)
where
substitute' = substituteTVar new_names
renameExp :: Names -> Exp -> Rn (Names, Exp)
renameExp old_names = \case
EVar (LIdent n) -> pure (old_names, EVar . LIdent . fromMaybe n $ Map.lookup n old_names)
EInj (UIdent n) -> pure (old_names, EInj . UIdent . fromMaybe n $ Map.lookup n old_names)
ELit lit -> pure (old_names, ELit lit)
EApp e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EApp e1' e2')
EAdd e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EAdd e1' e2')
-- TODO fix shadowing
ELet (Bind name vars rhs) e -> do
(new_names, name') <- newNameL old_names name
(new_names', vars') <- newNamesL new_names vars
(new_names'', rhs') <- renameExp new_names' rhs
(new_names''', e') <- renameExp new_names'' e
pure (new_names''', ELet (Bind name' vars' rhs') e')
EAbs par e -> do
(new_names, par') <- newNameL old_names par
(new_names', e') <- renameExp new_names e
pure (new_names', EAbs par' e')
EAnn e t -> do
(new_names, e') <- renameExp old_names e
t' <- renameTVars t
pure (new_names, EAnn e' t')
ECase e injs -> do
(new_names, e') <- renameExp old_names e
(new_names', injs') <- renameBranches new_names injs
pure (new_names', ECase e' injs')
renameBranches :: Names -> [Branch] -> Rn (Names, [Branch])
renameBranches ns xs = do
(new_names, xs') <- mapAndUnzipM (renameBranch ns) xs
if null new_names then return (mempty, xs') else return (head new_names, xs')
renameBranch :: Names -> Branch -> Rn (Names, Branch)
renameBranch ns b@(Branch patt e) = do
(new_names, patt') <- catchError (evalStateT (renamePattern ns patt) mempty) (\x -> throwError $ x ++ " in pattern '" ++ printTree b ++ "'")
(new_names', e') <- renameExp new_names e
return (new_names', Branch patt' e')
renamePattern :: Names -> Pattern -> StateT (Set LIdent) Rn (Names, Pattern)
renamePattern ns p = case p of
PInj cs ps -> do
(ns_new, ps') <- mapAccumM renamePattern ns ps
return (ns_new, PInj cs ps')
PVar name -> do
vs <- get
when (name `Set.member` vs) (throwError $ "Conflicting definitions of '" ++ printTree name ++ "'")
put (Set.insert name vs)
nn <- lift $ newNameL ns name
return $ second PVar nn
_ -> return (ns, p)
renameTVars :: Type -> Rn Type
renameTVars typ = case typ of
TAll tvar t -> do
tvar' <- nextNameTVar tvar
t' <- renameTVars $ substitute tvar tvar' t
pure $ TAll tvar' t'
TFun t1 t2 -> liftA2 TFun (renameTVars t1) (renameTVars t2)
_ -> pure typ
substitute ::
TVar -> -- α
TVar -> -- α_n
Type -> -- A
Type -- [α_n/α]A
substitute tvar1 tvar2 typ = case typ of
TLit _ -> typ
TVar tvar
| tvar == tvar1 -> TVar tvar2
| otherwise -> typ
TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t
| tvar == tvar1 -> TAll tvar2 $ substitute' t
| otherwise -> TAll tvar $ substitute' t
TData name typs -> TData name $ map substitute' typs
_ -> error "Impossible"
where
substitute' = substitute tvar1 tvar2
-- | Create multiple names and add them to the name environment
newNamesL :: Names -> [LIdent] -> Rn (Names, [LIdent])
newNamesL = mapAccumM newNameL
-- | Create a new name and add it to name environment.
newNameL :: Names -> LIdent -> Rn (Names, LIdent)
newNameL env (LIdent old_name) = do
new_name <- makeName old_name
pure (Map.insert old_name new_name env, LIdent new_name)
-- | Create multiple names and add them to the name environment
newNamesU :: Names -> [UIdent] -> Rn (Names, [UIdent])
newNamesU = mapAccumM newNameU
-- | Create a new name and add it to name environment.
newNameU :: Names -> UIdent -> Rn (Names, UIdent)
newNameU env (UIdent old_name) = do
new_name <- makeName old_name
pure (Map.insert old_name new_name env, UIdent new_name)
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
makeName :: String -> Rn String
makeName prefix = do
i <- gets var_counter
let name = prefix ++ "_" ++ show i
modify $ \cxt -> cxt{var_counter = succ cxt.var_counter}
pure name
nextNameTVar :: TVar -> Rn TVar
nextNameTVar (MkTVar (LIdent s)) = do
i <- gets tvar_counter
let tvar = MkTVar . LIdent $ s ++ "_" ++ show i
modify $ \cxt -> cxt{tvar_counter = succ cxt.tvar_counter}
pure tvar

View file

@ -1,206 +0,0 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use mapAndUnzipM" #-}
module Renamer.Renamer (rename) where
import Auxiliary (mapAccumM)
import Control.Applicative (Applicative (liftA2))
import Control.Monad (foldM)
import Control.Monad.Except (ExceptT, MonadError, runExceptT,
throwError)
import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.State (MonadState, StateT, evalStateT, gets,
modify)
import Data.Coerce (coerce)
import Data.Function (on)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe)
import Grammar.Abs
-- | Rename all variables and local binds
rename :: Program -> Either String Program
rename (Program defs) = Program <$> renameDefs defs
renameDefs :: [Def] -> Either String [Def]
renameDefs defs = runIdentity $ runExceptT $ evalStateT (runRn $ mapM renameDef defs) initCxt
where
initNames = Map.fromList [dupe (coerce name) | DBind (Bind name _ _) <- defs]
renameDef :: Def -> Rn Def
renameDef = \case
DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ
DBind bind -> DBind . snd <$> renameBind initNames bind
DData (Data (TData cname types) constrs) -> do
tvars_ <- tvars
tvars' <- mapM nextNameTVar tvars_
let tvars_lt = zip tvars_ tvars'
typ' = map (substituteTVar tvars_lt) types
constrs' = map (renameConstr tvars_lt) constrs
pure . DData $ Data (TData cname typ') constrs'
where
tvars = concat <$> mapM (collectTVars []) types
collectTVars :: [TVar] -> Type -> Rn [TVar]
collectTVars tvars = \case
TAll tvar t -> collectTVars (tvar : tvars) t
TData _ _ -> return tvars
-- Should be monad error
TVar v -> return [v]
_ -> throwError ("Bad data type definition: " ++ show types)
DData (Data types _) -> throwError ("Bad data type definition: " ++ show types)
renameConstr :: [(TVar, TVar)] -> Inj -> Inj
renameConstr new_types (Inj name typ) =
Inj name $ substituteTVar new_types typ
renameBind :: Names -> Bind -> Rn (Names, Bind)
renameBind old_names (Bind name vars rhs) = do
(new_names, vars') <- newNames old_names (coerce vars)
(newer_names, rhs') <- renameExp new_names rhs
pure (newer_names, Bind name (coerce vars') rhs')
substituteTVar :: [(TVar, TVar)] -> Type -> Type
substituteTVar new_names typ = case typ of
TLit _ -> typ
TVar tvar
| Just tvar' <- lookup tvar new_names ->
TVar tvar'
| otherwise ->
typ
TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t
| Just tvar' <- lookup tvar new_names ->
TAll tvar' $ substitute' t
| otherwise ->
TAll tvar $ substitute' t
TData name typs -> TData name $ map substitute' typs
_ -> error ("Impossible " ++ show typ)
where
substitute' = substituteTVar new_names
initCxt :: Cxt
initCxt = Cxt 0 0
data Cxt = Cxt
{ var_counter :: Int
, tvar_counter :: Int
}
-- | Rename monad. State holds the number of renamed names.
newtype Rn a = Rn {runRn :: StateT Cxt (ExceptT String Identity) a}
deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String)
-- | Maps old to new name
type Names = Map LIdent LIdent
renameExp :: Names -> Exp -> Rn (Names, Exp)
renameExp old_names = \case
EVar n -> pure (coerce old_names, EVar . fromMaybe n $ Map.lookup n old_names)
EInj n -> pure (old_names, EInj n)
ELit lit -> pure (old_names, ELit lit)
EApp e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EApp e1' e2')
EAdd e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EAdd e1' e2')
-- TODO fix shadowing
ELet bind e -> do
(new_names, bind') <- renameBind old_names bind
(new_names', e') <- renameExp new_names e
pure (new_names', ELet bind' e')
EAbs par e -> do
(new_names, par') <- newName old_names (coerce par)
(new_names', e') <- renameExp new_names e
pure (new_names', EAbs (coerce par') e')
EAnn e t -> do
(new_names, e') <- renameExp old_names e
t' <- renameTVars t
pure (new_names, EAnn e' t')
ECase e injs -> do
(new_names, e') <- renameExp old_names e
(new_names', injs') <- renameBranches new_names injs
pure (new_names', ECase e' injs')
renameBranches :: Names -> [Branch] -> Rn (Names, [Branch])
renameBranches ns xs = do
(new_names, xs') <- unzip <$> mapM (renameBranch ns) xs
if null new_names then return (mempty, xs') else return (head new_names, xs')
renameBranch :: Names -> Branch -> Rn (Names, Branch)
renameBranch ns (Branch init e) = do
(new_names, init') <- renamePattern ns init
(new_names', e') <- renameExp new_names e
return (new_names', Branch init' e')
renamePattern :: Names -> Pattern -> Rn (Names, Pattern)
renamePattern ns i = case i of
PInj cs ps -> do
(ns_new, ps) <- renamePatterns ns ps
return (ns_new, PInj cs ps)
rest -> return (ns, rest)
renamePatterns :: Names -> [Pattern] -> Rn (Names, [Pattern])
renamePatterns ns xs = do
(new_names, xs') <- unzip <$> mapM (renamePattern ns) xs
if null new_names then return (mempty, xs') else return (head new_names, xs')
renameTVars :: Type -> Rn Type
renameTVars typ = case typ of
TAll tvar t -> do
tvar' <- nextNameTVar tvar
t' <- renameTVars $ substitute tvar tvar' t
pure $ TAll tvar' t'
TFun t1 t2 -> liftA2 TFun (renameTVars t1) (renameTVars t2)
_ -> pure typ
substitute ::
TVar -> -- α
TVar -> -- α_n
Type -> -- A
Type -- [α_n/α]A
substitute tvar1 tvar2 typ = case typ of
TLit _ -> typ
TVar tvar'
| tvar' == tvar1 -> TVar tvar2
| otherwise -> typ
TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t -> TAll tvar $ substitute' t
TData name typs -> TData name $ map substitute' typs
_ -> error "Impossible"
where
substitute' = substitute tvar1 tvar2
-- | Create a new name and add it to name environment.
newName :: Names -> LIdent -> Rn (Names, LIdent)
newName env old_name = do
new_name <- makeName old_name
pure (Map.insert old_name new_name env, new_name)
-- | Create multiple names and add them to the name environment
newNames :: Names -> [LIdent] -> Rn (Names, [LIdent])
newNames = mapAccumM newName
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
makeName :: LIdent -> Rn LIdent
makeName (LIdent prefix) = do
i <- gets var_counter
let name = LIdent $ prefix ++ "_" ++ show i
modify $ \cxt -> cxt{var_counter = succ cxt.var_counter}
pure name
nextNameTVar :: TVar -> Rn TVar
nextNameTVar (MkTVar (LIdent s)) = do
i <- gets tvar_counter
let tvar = MkTVar $ coerce $ s ++ "_" ++ show i
modify $ \cxt -> cxt{tvar_counter = succ cxt.tvar_counter}
pure tvar

70
src/ReportForall.hs Normal file
View file

@ -0,0 +1,70 @@
{-# LANGUAGE LambdaCase #-}
module ReportForall (reportForall) where
import Auxiliary (partitionDefs)
import Control.Monad (unless, void, when)
import Control.Monad.Except (MonadError (throwError))
import Data.Either.Combinators (mapRight)
import Data.Foldable (foldlM)
import Data.Function (on)
import Data.List (delete)
import Grammar.Abs
import Grammar.ErrM (Err)
import TypeChecker.TypeChecker (TypeChecker (Bi, Hm))
reportForall :: TypeChecker -> Program -> Err ()
reportForall tc p = do
when (tc == Hm) $ rpProgram rpaType p
rpProgram rpuType p
rpuType :: Type -> Err ()
rpuType typ = do
tvars <- go [] typ
unless (null tvars) $ throwError "Unused forall"
where
go tvars = \case
TAll tvar t
| tvar `elem` tvars -> throwError "Duplicate forall"
| otherwise -> go (tvar : tvars) t
TVar tvar -> pure (delete tvar tvars)
TFun t1 t2 -> go tvars t1 >>= (`go` t2)
TData _ typs -> foldlM go tvars typs
_ -> pure tvars
rpaType :: Type -> Err ()
rpaType = rpForall . skipForall
where
skipForall = \case
TAll _ t -> skipForall t
t -> t
rpForall = \case
TAll {} -> throwError "Higher rank forall not allowed"
TFun t1 t2 -> on (>>) rpForall t1 t2
TData _ typs -> mapM_ rpForall typs
_ -> pure ()
rpProgram :: (Type -> Err ()) -> Program -> Err ()
rpProgram rf (Program defs) = do
mapM_ rpuBind bs
mapM_ rpuData ds
mapM_ rpuSig ss
where
(ds, ss, bs) = partitionDefs defs
rpuSig (Sig _ typ) = rf typ
rpuData (Data typ injs) = rf typ >> mapM rpuInj injs
rpuInj (Inj _ typ) = rf typ
rpuBind (Bind _ _ rhs) = rpuExp rhs
rpuBranch (Branch _ e) = rpuExp e
rpuExp = \case
EAnn e t -> rpuExp e >> rf t
EApp e1 e2 -> on (>>) rpuExp e1 e2
EAdd e1 e2 -> on (>>) rpuExp e1 e2
ELet bind e -> rpuBind bind >> rpuExp e
EAbs _ e -> rpuExp e
ECase e bs -> rpuExp e >> mapM_ rpuBranch bs
_ -> pure ()
reportAnyForall :: Program -> Err ()
reportAnyForall = undefined

View file

@ -0,0 +1,48 @@
{-# LANGUAGE LambdaCase #-}
module TypeChecker.RemoveForall (removeForall) where
import Auxiliary (onM)
import Control.Applicative (Applicative (liftA2))
import Data.Function (on)
import Data.List (partition)
import Data.Tuple.Extra (second)
import Grammar.ErrM (Err)
import qualified TypeChecker.ReportTEVar as R
import TypeChecker.TypeCheckerIr
removeForall :: Program' R.Type -> Program
removeForall (Program defs) = Program $ map (DData . rfData) ds
++ map (DBind . rfBind) bs
where
(ds, bs) = ([d | DData d <- defs ], [ b | DBind b <- defs ])
rfData (Data typ injs) = Data (rfType typ) (map rfInj injs)
rfInj (Inj name typ) = Inj name (rfType typ)
rfBind (Bind name vars rhs) = Bind (rfId name) (map rfId vars) (rfExpT rhs)
rfId = second rfType
rfExpT (e, t) = (rfExp e, rfType t)
rfExp = \case
EApp e1 e2 -> on EApp rfExpT e1 e2
EAdd e1 e2 -> on EAdd rfExpT e1 e2
ELet bind e -> ELet (rfBind bind) (rfExpT e)
EAbs name e -> EAbs name (rfExpT e)
ECase e bs -> ECase (rfExpT e) (map rfBranch bs)
ELit lit -> ELit lit
EVar name -> EVar name
EInj name -> EInj name
rfBranch (Branch (p, t) e) = Branch (rfPattern p, rfType t) (rfExpT e)
rfPattern = \case
PVar id -> PVar (rfId id)
PLit (lit, t) -> PLit (lit, rfType t)
PCatch -> PCatch
PEnum name -> PEnum name
PInj name ps -> PInj name (map rfPattern ps)
rfType :: R.Type -> Type
rfType = \case
R.TAll _ t -> rfType t
R.TFun t1 t2 -> on TFun rfType t1 t2
R.TData name ts -> TData name (map rfType ts)
R.TLit lit -> TLit lit
R.TVar tvar -> TVar tvar

View file

@ -1,71 +0,0 @@
{-# LANGUAGE LambdaCase #-}
module TypeChecker.RemoveTEVar where
import Control.Applicative (Applicative (liftA2), liftA3)
import Control.Monad.Except (MonadError (throwError))
import Data.Coerce (coerce)
import Data.Tuple.Extra (secondM)
import Grammar.Abs
import Grammar.ErrM (Err)
import TypeChecker.TypeCheckerIr qualified as T
class RemoveTEVar a b where
rmTEVar :: a -> Err b
instance RemoveTEVar (T.Program' Type) (T.Program' T.Type) where
rmTEVar (T.Program defs) = T.Program <$> rmTEVar defs
instance RemoveTEVar (T.Def' Type) (T.Def' T.Type) where
rmTEVar = \case
T.DBind bind -> T.DBind <$> rmTEVar bind
T.DData dat -> T.DData <$> rmTEVar dat
instance RemoveTEVar (T.Bind' Type) (T.Bind' T.Type) where
rmTEVar (T.Bind id vars rhs) = liftA3 T.Bind (rmTEVar id) (rmTEVar vars) (rmTEVar rhs)
instance RemoveTEVar (T.Exp' Type) (T.Exp' T.Type) where
rmTEVar exp = case exp of
T.EVar name -> pure $ T.EVar name
T.EInj name -> pure $ T.EInj name
T.ELit lit -> pure $ T.ELit lit
T.ELet bind e -> liftA2 T.ELet (rmTEVar bind) (rmTEVar e)
T.EApp e1 e2 -> liftA2 T.EApp (rmTEVar e1) (rmTEVar e2)
T.EAdd e1 e2 -> liftA2 T.EAdd (rmTEVar e1) (rmTEVar e2)
T.EAbs name e -> T.EAbs name <$> rmTEVar e
T.ECase e branches -> liftA2 T.ECase (rmTEVar e) (rmTEVar branches)
instance RemoveTEVar (T.Branch' Type) (T.Branch' T.Type) where
rmTEVar (T.Branch (patt, t_patt) e) = liftA2 T.Branch (liftA2 (,) (rmTEVar patt) (rmTEVar t_patt)) (rmTEVar e)
instance RemoveTEVar (T.Pattern' Type) (T.Pattern' T.Type) where
rmTEVar = \case
T.PVar (name, t) -> T.PVar . (name,) <$> rmTEVar t
T.PLit (lit, t) -> T.PLit . (lit,) <$> rmTEVar t
T.PCatch -> pure T.PCatch
T.PEnum name -> pure $ T.PEnum name
T.PInj name ps -> T.PInj name <$> rmTEVar ps
instance RemoveTEVar (T.Data' Type) (T.Data' T.Type) where
rmTEVar (T.Data typ injs) = liftA2 T.Data (rmTEVar typ) (rmTEVar injs)
instance RemoveTEVar (T.Inj' Type) (T.Inj' T.Type) where
rmTEVar (T.Inj name typ) = T.Inj name <$> rmTEVar typ
instance RemoveTEVar (T.Id' Type) (T.Id' T.Type) where
rmTEVar = secondM rmTEVar
instance RemoveTEVar (T.ExpT' Type) (T.ExpT' T.Type) where
rmTEVar (exp, typ) = liftA2 (,) (rmTEVar exp) (rmTEVar typ)
instance RemoveTEVar a b => RemoveTEVar [a] [b] where
rmTEVar = mapM rmTEVar
instance RemoveTEVar Type T.Type where
rmTEVar = \case
TLit lit -> pure $ T.TLit (coerce lit)
TVar (MkTVar i) -> pure $ T.TVar (T.MkTVar $ coerce i)
TData name typs -> T.TData (coerce name) <$> rmTEVar typs
TFun t1 t2 -> liftA2 T.TFun (rmTEVar t1) (rmTEVar t2)
TAll (MkTVar i) t -> T.TAll (T.MkTVar $ coerce i) <$> rmTEVar t
TEVar _ -> throwError "NewType TEVar!"

View file

@ -0,0 +1,81 @@
{-# LANGUAGE LambdaCase #-}
module TypeChecker.ReportTEVar where
import Auxiliary (onM)
import Control.Applicative (Applicative (liftA2), liftA3)
import Control.Monad.Except (MonadError (throwError))
import Data.Coerce (coerce)
import Data.Tuple.Extra (secondM)
import qualified Grammar.Abs as G
import Grammar.ErrM (Err)
import TypeChecker.TypeCheckerIr hiding (Type (..))
data Type
= TLit Ident
| TVar TVar
| TData Ident [Type]
| TFun Type Type
| TAll TVar Type
deriving (Eq, Ord, Show, Read)
class ReportTEVar a b where
reportTEVar :: a -> Err b
instance ReportTEVar (Program' G.Type) (Program' Type) where
reportTEVar (Program defs) = Program <$> reportTEVar defs
instance ReportTEVar (Def' G.Type) (Def' Type) where
reportTEVar = \case
DBind bind -> DBind <$> reportTEVar bind
DData dat -> DData <$> reportTEVar dat
instance ReportTEVar (Bind' G.Type) (Bind' Type) where
reportTEVar (Bind id vars rhs) = liftA3 Bind (reportTEVar id) (reportTEVar vars) (reportTEVar rhs)
instance ReportTEVar (Exp' G.Type) (Exp' Type) where
reportTEVar exp = case exp of
EVar name -> pure $ EVar name
EInj name -> pure $ EInj name
ELit lit -> pure $ ELit lit
ELet bind e -> liftA2 ELet (reportTEVar bind) (reportTEVar e)
EApp e1 e2 -> onM EApp reportTEVar e1 e2
EAdd e1 e2 -> onM EAdd reportTEVar e1 e2
EAbs name e -> EAbs name <$> reportTEVar e
ECase e branches -> liftA2 ECase (reportTEVar e) (reportTEVar branches)
instance ReportTEVar (Branch' G.Type) (Branch' Type) where
reportTEVar (Branch (patt, t_patt) e) = liftA2 Branch (liftA2 (,) (reportTEVar patt) (reportTEVar t_patt)) (reportTEVar e)
instance ReportTEVar (Pattern' G.Type) (Pattern' Type) where
reportTEVar = \case
PVar (name, t) -> PVar . (name,) <$> reportTEVar t
PLit (lit, t) -> PLit . (lit,) <$> reportTEVar t
PCatch -> pure PCatch
PEnum name -> pure $ PEnum name
PInj name ps -> PInj name <$> reportTEVar ps
instance ReportTEVar (Data' G.Type) (Data' Type) where
reportTEVar (Data typ injs) = liftA2 Data (reportTEVar typ) (reportTEVar injs)
instance ReportTEVar (Inj' G.Type) (Inj' Type) where
reportTEVar (Inj name typ) = Inj name <$> reportTEVar typ
instance ReportTEVar (Id' G.Type) (Id' Type) where
reportTEVar = secondM reportTEVar
instance ReportTEVar (ExpT' G.Type) (ExpT' Type) where
reportTEVar (exp, typ) = liftA2 (,) (reportTEVar exp) (reportTEVar typ)
instance ReportTEVar a b => ReportTEVar [a] [b] where
reportTEVar = mapM reportTEVar
instance ReportTEVar G.Type Type where
reportTEVar = \case
G.TLit lit -> pure $ TLit (coerce lit)
G.TVar (G.MkTVar i) -> pure $ TVar (MkTVar $ coerce i)
G.TData name typs -> TData (coerce name) <$> reportTEVar typs
G.TFun t1 t2 -> liftA2 TFun (reportTEVar t1) (reportTEVar t2)
G.TAll (G.MkTVar i) t -> TAll (MkTVar $ coerce i) <$> reportTEVar t
G.TEVar _ -> throwError "NewType TEVar!"

View file

@ -1,17 +1,19 @@
module TypeChecker.TypeChecker (typecheck, TypeChecker (..)) where module TypeChecker.TypeChecker (typecheck, TypeChecker (..)) where
import Control.Monad ((<=<)) import Control.Monad ((<=<))
import Grammar.Abs import qualified Grammar.Abs as G
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
import TypeChecker.RemoveTEVar (RemoveTEVar (rmTEVar)) import TypeChecker.RemoveForall (removeForall)
import TypeChecker.TypeCheckerBidir qualified as Bi import qualified TypeChecker.ReportTEVar as R
import TypeChecker.TypeCheckerHm qualified as Hm import TypeChecker.ReportTEVar (reportTEVar)
import TypeChecker.TypeCheckerIr qualified as T import qualified TypeChecker.TypeCheckerBidir as Bi
import qualified TypeChecker.TypeCheckerHm as Hm
import TypeChecker.TypeCheckerIr
data TypeChecker = Bi | Hm data TypeChecker = Bi | Hm deriving Eq
typecheck :: TypeChecker -> Program -> Err T.Program typecheck :: TypeChecker -> G.Program -> Err Program
typecheck tc = rmTEVar <=< f typecheck tc = fmap removeForall . (reportTEVar <=< f)
where where
f = case tc of f = case tc of
Bi -> Bi.typecheck Bi -> Bi.typecheck

View file

@ -121,6 +121,7 @@ typecheckBind (Bind name vars rhs) = do
, "Did you forget to add type annotation to a polymorphic function?" , "Did you forget to add type annotation to a polymorphic function?"
] ]
-- TODO remove some checks
typecheckDataType :: Data -> Err (T.Data' Type) typecheckDataType :: Data -> Err (T.Data' Type)
typecheckDataType (Data typ injs) = do typecheckDataType (Data typ injs) = do
(name, tvars) <- go [] typ (name, tvars) <- go [] typ
@ -135,6 +136,7 @@ typecheckDataType (Data typ injs) = do
-> pure (name, tvars') -> pure (name, tvars')
_ -> throwError $ unwords ["Bad data type definition: ", ppT typ] _ -> throwError $ unwords ["Bad data type definition: ", ppT typ]
-- TODO remove some checks
typecheckInj :: Inj -> UIdent -> [TVar] -> Err (T.Inj' Type) typecheckInj :: Inj -> UIdent -> [TVar] -> Err (T.Inj' Type)
typecheckInj (Inj inj_name inj_typ) name tvars typecheckInj (Inj inj_name inj_typ) name tvars
| not $ boundTVars tvars inj_typ | not $ boundTVars tvars inj_typ
@ -878,18 +880,18 @@ traceTs s xs = trace (s ++ " [ " ++ intercalate ", " (map ppT xs) ++ " ]") pure
ppT = \case ppT = \case
TLit (UIdent s) -> s TLit (UIdent s) -> s
TVar (MkTVar (LIdent s)) -> "α_" ++ s TVar (MkTVar (LIdent s)) -> "a_" ++ s
TFun t1 t2 -> ppT t1 ++ "" ++ ppT t2 TFun t1 t2 -> ppT t1 ++ "->" ++ ppT t2
TAll (MkTVar (LIdent s)) t -> "forall " ++ s ++ ". " ++ ppT t TAll (MkTVar (LIdent s)) t -> "forall " ++ s ++ ". " ++ ppT t
TEVar (MkTEVar (LIdent s)) -> "ά_" ++ s TEVar (MkTEVar (LIdent s)) -> "a^_" ++ s
TData (UIdent name) typs -> name ++ " (" ++ unwords (map ppT typs) TData (UIdent name) typs -> name ++ " (" ++ unwords (map ppT typs)
++ " )" ++ " )"
ppEnvElem = \case ppEnvElem = \case
EnvVar (LIdent s) t -> s ++ ":" ++ ppT t EnvVar (LIdent s) t -> s ++ ":" ++ ppT t
EnvTVar (MkTVar (LIdent s)) -> "α_" ++ s EnvTVar (MkTVar (LIdent s)) -> "a_" ++ s
EnvTEVar (MkTEVar (LIdent s)) -> "ά_" ++ s EnvTEVar (MkTEVar (LIdent s)) -> "a^_" ++ s
EnvTEVarSolved (MkTEVar (LIdent s)) t -> "ά_" ++ s ++ "=" ++ ppT t EnvTEVarSolved (MkTEVar (LIdent s)) t -> "_" ++ s ++ "=" ++ ppT t
EnvMark (MkTEVar (LIdent s)) -> "" ++ "ά_" ++ s EnvMark (MkTEVar (LIdent s)) -> "" ++ "a^_" ++ s
ppEnv = \case ppEnv = \case
Empty -> "·" Empty -> "·"

View file

@ -1,31 +1,31 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-} {-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QualifiedDo #-} {-# LANGUAGE QualifiedDo #-}
-- | A module for type checking and inference using algorithm W, Hindley-Milner -- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeCheckerHm where module TypeChecker.TypeCheckerHm where
import Auxiliary (int, litType, maybeToRightM, unzip4) import Auxiliary (int, litType, maybeToRightM, unzip4)
import Auxiliary qualified as Aux import qualified Auxiliary as Aux
import Control.Monad.Except import Control.Monad.Except
import Control.Monad.Identity (Identity, runIdentity) import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader import Control.Monad.Reader
import Control.Monad.State import Control.Monad.State
import Control.Monad.Writer import Control.Monad.Writer
import Data.Coerce (coerce) import Data.Coerce (coerce)
import Data.Function (on) import Data.Function (on)
import Data.List (foldl', nub, sortOn) import Data.List (foldl', nub, sortOn)
import Data.List.Extra (unsnoc) import Data.List.Extra (unsnoc)
import Data.Map (Map) import Data.Map (Map)
import Data.Map qualified as M import qualified Data.Map as M
import Data.Maybe (fromJust) import Data.Maybe (fromJust)
import Data.Set (Set) import Data.Set (Set)
import Data.Set qualified as S import qualified Data.Set as S
import Debug.Trace (trace) import Debug.Trace (trace)
import Grammar.Abs import Grammar.Abs
import Grammar.Print (printTree) import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr qualified as T import qualified TypeChecker.TypeCheckerIr as T
-- TODO: Disallow mutual recursion -- TODO: Disallow mutual recursion
@ -34,7 +34,7 @@ typecheck :: Program -> Either String (T.Program' Type, [Warning])
typecheck = onLeft msg . run . checkPrg typecheck = onLeft msg . run . checkPrg
where where
onLeft :: (Error -> String) -> Either Error a -> Either String a onLeft :: (Error -> String) -> Either Error a -> Either String a
onLeft f (Left x) = Left $ f x onLeft f (Left x) = Left $ f x
onLeft _ (Right x) = Right x onLeft _ (Right x) = Right x
checkPrg :: Program -> Infer (T.Program' Type) checkPrg :: Program -> Infer (T.Program' Type)
@ -118,7 +118,7 @@ preRun (x : xs) = case x of
s <- gets sigs s <- gets sigs
case M.lookup (coerce n) s of case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs Just _ -> preRun xs
DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs
where where
-- Check if function body / signature has been declared already -- Check if function body / signature has been declared already
@ -140,11 +140,11 @@ checkDef (x : xs) = case x of
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
freeOrdered :: Type -> [T.Ident] freeOrdered :: Type -> [T.Ident]
freeOrdered (TVar (MkTVar a)) = return (coerce a) freeOrdered (TVar (MkTVar a)) = return (coerce a)
freeOrdered (TAll (MkTVar bound) t) = return (coerce bound) ++ freeOrdered t freeOrdered (TAll (MkTVar bound) t) = return (coerce bound) ++ freeOrdered t
freeOrdered (TFun a b) = freeOrdered a ++ freeOrdered b freeOrdered (TFun a b) = freeOrdered a ++ freeOrdered b
freeOrdered (TData _ a) = concatMap freeOrdered a freeOrdered (TData _ a) = concatMap freeOrdered a
freeOrdered _ = mempty freeOrdered _ = mempty
checkBind :: Bind -> Infer (T.Bind' Type) checkBind :: Bind -> Infer (T.Bind' Type)
checkBind (Bind name args e) = do checkBind (Bind name args e) = do
@ -178,22 +178,19 @@ checkBind (Bind name args e) = do
checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m () checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m ()
checkData err@(Data typ injs) = do checkData err@(Data typ injs) = do
(name, tvars) <- go typ (name, tvars) <- go (skipForalls typ)
dataErr (mapM_ (\i -> checkInj i name tvars) injs) err dataErr (mapM_ (\i -> checkInj i name tvars) injs) err
where where
go = \case go = \case
TData name typs TData name typs
| Right tvars' <- mapM toTVar typs -> | Right tvars' <- mapM toTVar typs ->
pure (name, tvars') pure (name, tvars')
TAll _ _ -> uncatchableErr "Explicit forall not allowed, for now"
_ -> _ ->
uncatchableErr $ uncatchableErr $
unwords ["Bad data type definition: ", printTree typ] unwords ["Bad data type definition: ", printTree typ]
checkInj :: (MonadError Error m, MonadState Env m, Monad m) => Inj -> UIdent -> [TVar] -> m () checkInj :: (MonadError Error m, MonadState Env m, Monad m) => Inj -> UIdent -> [TVar] -> m ()
checkInj (Inj c inj_typ) name tvars checkInj (Inj c inj_typ) name tvars
| Right False <- boundTVars tvars inj_typ =
catchableErr "Unbound type variables"
| TData name' typs <- returnType inj_typ | TData name' typs <- returnType inj_typ
, Right tvars' <- mapM toTVar typs , Right tvars' <- mapM toTVar typs
, name' == name , name' == name
@ -217,27 +214,15 @@ checkInj (Inj c inj_typ) name tvars
, "\nActual: " , "\nActual: "
, printTree $ returnType inj_typ , printTree $ returnType inj_typ
] ]
where
boundTVars :: [TVar] -> Type -> Either Error Bool
boundTVars tvars' = \case
TAll{} -> uncatchableErr "Explicit forall not allowed, for now"
TFun t1 t2 -> do
t1' <- boundTVars tvars t1
t2' <- boundTVars tvars t2
return $ t1' && t2'
TVar tvar -> return $ tvar `elem` tvars'
TData _ typs -> and <$> mapM (boundTVars tvars) typs
TLit _ -> return True
TEVar _ -> error "TEVar in data type declaration"
toTVar :: Type -> Either Error TVar toTVar :: Type -> Either Error TVar
toTVar = \case toTVar = \case
TVar tvar -> pure tvar TVar tvar -> pure tvar
_ -> uncatchableErr "Not a type variable" _ -> uncatchableErr "Not a type variable"
returnType :: Type -> Type returnType :: Type -> Type
returnType (TFun _ t2) = returnType t2 returnType (TFun _ t2) = returnType t2
returnType a = a returnType a = a
inferExp :: Exp -> Infer (T.ExpT' Type) inferExp :: Exp -> Infer (T.ExpT' Type)
inferExp e = do inferExp e = do
@ -250,7 +235,7 @@ class CollectTVars a where
instance CollectTVars Exp where instance CollectTVars Exp where
collectTVars (EAnn e t) = collectTVars t `S.union` collectTVars e collectTVars (EAnn e t) = collectTVars t `S.union` collectTVars e
collectTVars _ = S.empty collectTVars _ = S.empty
instance CollectTVars Type where instance CollectTVars Type where
collectTVars (TVar (MkTVar i)) = S.singleton (coerce i) collectTVars (TVar (MkTVar i)) = S.singleton (coerce i)
@ -569,12 +554,12 @@ generalize :: Map T.Ident Type -> Type -> Type
generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t) generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t)
where where
go :: [T.Ident] -> Type -> Type go :: [T.Ident] -> Type -> Type
go [] t = t go [] t = t
go (x : xs) t = TAll (MkTVar (coerce x)) (go xs t) go (x : xs) t = TAll (MkTVar (coerce x)) (go xs t)
removeForalls :: Type -> Type removeForalls :: Type -> Type
removeForalls (TAll _ t) = removeForalls t removeForalls (TAll _ t) = removeForalls t
removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2) removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2)
removeForalls t = t removeForalls t = t
{- | Instantiate a polymorphic type. The free type variables are substituted {- | Instantiate a polymorphic type. The free type variables are substituted
with fresh ones. with fresh ones.
@ -611,42 +596,39 @@ currently this is not the case, the TAll pattern match is incorrectly implemente
-- Is the left a subtype of the right -- Is the left a subtype of the right
(<<=) :: Type -> Type -> Bool (<<=) :: Type -> Type -> Bool
(<<=) (TVar _) _ = True (<<=) (TVar _) _ = True
(<<=) (TAll _ t1) (TAll _ t2) = t1 <<= t2 (<<=) t1@TAll{} t2 = skipForalls t1 <<= t2
(<<=) t1 t2@TAll{} = t1 <<= skipForalls t2
(<<=) (TFun a b) (TFun c d) = a <<= c && b <<= d (<<=) (TFun a b) (TFun c d) = a <<= c && b <<= d
(<<=) (TData n1 ts1) (TData n2 ts2) = (<<=) (TData n1 ts1) (TData n2 ts2) =
n1 == n2 n1 == n2
&& length ts1 == length ts2 && length ts1 == length ts2
&& and (zipWith (<<=) ts1 ts2) && and (zipWith (<<=) ts1 ts2)
(<<=) t0 t@(TAll _ _) = go t0 t
where
go t0 t@(TAll _ t1) = S.toList (free t0) == foralls t && go' t0 t1
go _ _ = undefined
go' (TEVar (MkTEVar a)) (TVar (MkTVar b)) = a == b
go' (TEVar (MkTEVar a)) (TEVar (MkTEVar b)) = a == b
go' (TFun a b) (TFun c d) = a `go'` c && b `go'` d
go' _ _ = False
(<<=) a b = a == b (<<=) a b = a == b
skipForalls :: Type -> Type
skipForalls = \case
TAll _ t -> t
t -> t
foralls :: Type -> [T.Ident] foralls :: Type -> [T.Ident]
foralls (TAll (MkTVar a) t) = coerce a : foralls t foralls (TAll (MkTVar a) t) = coerce a : foralls t
foralls _ = [] foralls _ = []
mkForall :: Type -> Type mkForall :: Type -> Type
mkForall t = case map (TAll . MkTVar . coerce) $ S.toList $ free t of mkForall t = case map (TAll . MkTVar . coerce) $ S.toList $ free t of
[] -> t [] -> t
(x : xs) -> (x : xs) ->
let f acc [] = acc let f acc [] = acc
f acc (x : xs) = f (x acc) xs f acc (x : xs) = f (x acc) xs
(y : ys) = reverse $ x : xs (y : ys) = reverse $ x : xs
in f (y t) ys in f (y t) ys
skolemize :: Type -> Type skolemize :: Type -> Type
skolemize (TVar (MkTVar a)) = TEVar $ MkTEVar a skolemize (TVar (MkTVar a)) = TEVar $ MkTEVar a
skolemize (TAll x t) = TAll x (skolemize t) skolemize (TAll x t) = TAll x (skolemize t)
skolemize (TFun t1 t2) = (TFun `on` skolemize) t1 t2 skolemize (TFun t1 t2) = (TFun `on` skolemize) t1 t2
skolemize (TData n ts) = TData n (map skolemize ts) skolemize (TData n ts) = TData n (map skolemize ts)
skolemize t = t skolemize t = t
-- | A class for substitutions -- | A class for substitutions
class SubstType t where class SubstType t where
@ -680,10 +662,10 @@ instance SubstType Type where
TLit _ -> t TLit _ -> t
TVar (MkTVar a) -> case M.lookup (coerce a) sub of TVar (MkTVar a) -> case M.lookup (coerce a) sub of
Nothing -> TVar (MkTVar $ coerce a) Nothing -> TVar (MkTVar $ coerce a)
Just t -> t Just t -> t
TAll (MkTVar i) t -> case M.lookup (coerce i) sub of TAll (MkTVar i) t -> case M.lookup (coerce i) sub of
Nothing -> TAll (MkTVar i) (apply sub t) Nothing -> TAll (MkTVar i) (apply sub t)
Just _ -> apply sub t Just _ -> apply sub t
TFun a b -> TFun (apply sub a) (apply sub b) TFun a b -> TFun (apply sub a) (apply sub b)
TData name a -> TData name (apply sub a) TData name a -> TData name (apply sub a)
TEVar (MkTEVar _) -> t TEVar (MkTEVar _) -> t
@ -728,10 +710,10 @@ instance SubstType (T.Branch' Type) where
instance SubstType (T.Pattern' Type) where instance SubstType (T.Pattern' Type) where
apply s = \case apply s = \case
T.PVar (iden, t) -> T.PVar (iden, apply s t) T.PVar (iden, t) -> T.PVar (iden, apply s t)
T.PLit (lit, t) -> T.PLit (lit, apply s t) T.PLit (lit, t) -> T.PLit (lit, apply s t)
T.PInj i ps -> T.PInj i $ apply s ps T.PInj i ps -> T.PInj i $ apply s ps
T.PCatch -> T.PCatch T.PCatch -> T.PCatch
T.PEnum i -> T.PEnum i T.PEnum i -> T.PEnum i
instance SubstType (T.Pattern' Type, Type) where instance SubstType (T.Pattern' Type, Type) where
apply s (p, t) = (apply s p, apply s t) apply s (p, t) = (apply s p, apply s t)
@ -773,10 +755,10 @@ withBindings xs =
withPattern :: (Monad m, MonadReader Ctx m) => T.Pattern' Type -> m a -> m a withPattern :: (Monad m, MonadReader Ctx m) => T.Pattern' Type -> m a -> m a
withPattern p ma = case p of withPattern p ma = case p of
T.PVar (x, t) -> withBinding x t ma T.PVar (x, t) -> withBinding x t ma
T.PInj _ ps -> foldl' (flip withPattern) ma ps T.PInj _ ps -> foldl' (flip withPattern) ma ps
T.PLit _ -> ma T.PLit _ -> ma
T.PCatch -> ma T.PCatch -> ma
T.PEnum _ -> ma T.PEnum _ -> ma
-- | Insert a function signature into the environment -- | Insert a function signature into the environment
insertSig :: T.Ident -> Maybe Type -> Infer () insertSig :: T.Ident -> Maybe Type -> Infer ()
@ -801,11 +783,11 @@ existInj n = gets (M.lookup n . injections)
flattenType :: Type -> [Type] flattenType :: Type -> [Type]
flattenType (TFun a b) = flattenType a <> flattenType b flattenType (TFun a b) = flattenType a <> flattenType b
flattenType a = [a] flattenType a = [a]
typeLength :: Type -> Int typeLength :: Type -> Int
typeLength (TFun _ b) = 1 + typeLength b typeLength (TFun _ b) = 1 + typeLength b
typeLength _ = 1 typeLength _ = 1
{- | Catch an error if possible and add the given {- | Catch an error if possible and add the given
expression as addition to the error message expression as addition to the error message
@ -888,11 +870,11 @@ newtype Ctx = Ctx {vars :: Map T.Ident Type}
deriving (Show) deriving (Show)
data Env = Env data Env = Env
{ count :: Int { count :: Int
, nextChar :: Char , nextChar :: Char
, sigs :: Map T.Ident (Maybe Type) , sigs :: Map T.Ident (Maybe Type)
, takenTypeVars :: Set T.Ident , takenTypeVars :: Set T.Ident
, injections :: Map T.Ident Type , injections :: Map T.Ident Type
, declaredBinds :: Set T.Ident , declaredBinds :: Set T.Ident
} }
deriving (Show) deriving (Show)

View file

@ -1,4 +1,4 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PatternSynonyms #-}
module TypeChecker.TypeCheckerIr ( module TypeChecker.TypeCheckerIr (
@ -6,11 +6,11 @@ module TypeChecker.TypeCheckerIr (
module TypeChecker.TypeCheckerIr, module TypeChecker.TypeCheckerIr,
) where ) where
import Data.String (IsString) import Data.String (IsString)
import Grammar.Abs (Lit (..)) import Grammar.Abs (Lit (..))
import Grammar.Print import Grammar.Print
import Prelude import Prelude
import Prelude qualified as C (Eq, Ord, Read, Show) import qualified Prelude as C (Eq, Ord, Read, Show)
newtype Program' t = Program [Def' t] newtype Program' t = Program [Def' t]
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
@ -25,8 +25,7 @@ data Type
| TVar TVar | TVar TVar
| TData Ident [Type] | TData Ident [Type]
| TFun Type Type | TFun Type Type
| TAll TVar Type deriving (Eq, Ord, Show, Read)
deriving (C.Eq, C.Ord, C.Show, C.Read)
data Data' t = Data t [Inj' t] data Data' t = Data t [Inj' t]
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
@ -105,8 +104,8 @@ instance Print t => Print (ExpT' t) where
] ]
instance Print t => Print [Bind' t] where instance Print t => Print [Bind' t] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x] prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
prtIdPs :: Print t => Int -> [Id' t] -> Doc prtIdPs :: Print t => Int -> [Id' t] -> Doc
@ -171,13 +170,13 @@ instance Print t => Print (Branch' t) where
prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp]) prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp])
instance Print t => Print [Branch' t] where instance Print t => Print [Branch' t] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x] prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
instance Print t => Print (Def' t) where instance Print t => Print (Def' t) where
prt i = \case prt i = \case
DBind bind -> prPrec i 0 (concatD [prt 0 bind]) DBind bind -> prPrec i 0 (concatD [prt 0 bind])
DData data_ -> prPrec i 0 (concatD [prt 0 data_]) DData data_ -> prPrec i 0 (concatD [prt 0 data_])
instance Print t => Print (Data' t) where instance Print t => Print (Data' t) where
@ -202,12 +201,12 @@ instance Print t => Print (Pattern' t) where
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns]) PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])
instance Print t => Print [Def' t] where instance Print t => Print [Def' t] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x] prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
instance Print [Type] where instance Print [Type] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs]
instance Print Type where instance Print Type where
@ -216,7 +215,6 @@ instance Print Type where
TVar tvar -> prPrec i 1 (concatD [prt 0 tvar]) TVar tvar -> prPrec i 1 (concatD [prt 0 tvar])
TData uident types -> prPrec i 1 (concatD [prt 0 uident, doc (showString "("), prt 0 types, doc (showString ")")]) TData uident types -> prPrec i 1 (concatD [prt 0 uident, doc (showString "("), prt 0 types, doc (showString ")")])
TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2]) TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2])
TAll tvar type_ -> prPrec i 0 (concatD [doc (showString "forall"), prt 0 tvar, doc (showString "."), prt 0 type_])
instance Print TVar where instance Print TVar where
prt i (MkTVar ident) = prt i ident prt i (MkTVar ident) = prt i ident

View file

@ -1,10 +1,16 @@
module Main where module Main where
import Test.Hspec import Test.Hspec
import TestAnnForall (testAnnForall)
import TestRenamer (testRenamer)
import TestReportForall (testReportForall)
import TestTypeCheckerBidir (testTypeCheckerBidir) import TestTypeCheckerBidir (testTypeCheckerBidir)
import TestTypeCheckerHm (testTypeCheckerHm) import TestTypeCheckerHm (testTypeCheckerHm)
main = hspec $ do main = hspec $ do
testReportForall
testAnnForall
testRenamer
testTypeCheckerBidir testTypeCheckerBidir
testTypeCheckerHm testTypeCheckerHm

113
tests/TestAnnForall.hs Normal file
View file

@ -0,0 +1,113 @@
{-# LANGUAGE PatternSynonyms #-}
{-# HLINT ignore "Use camelCase" #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# LANGUAGE QualifiedDo #-}
module TestAnnForall (testAnnForall, test) where
import AnnForall (annotateForall)
import Control.Monad ((<=<))
import qualified DoStrings as D
import Grammar.ErrM (Err, pattern Bad, pattern Ok)
import Grammar.Layout (resolveLayout)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
import Renamer.Renamer (rename)
import ReportForall (reportForall)
import Test.Hspec (describe, hspec, shouldBe,
shouldNotSatisfy, shouldSatisfy,
shouldThrow, specify)
import TypeChecker.ReportTEVar (reportTEVar)
import TypeChecker.TypeChecker (TypeChecker (Bi, Hm))
import TypeChecker.TypeCheckerBidir (typecheck)
import qualified TypeChecker.TypeCheckerIr as T
test = hspec testAnnForall
testAnnForall = describe "Test AnnForall" $ do
ann_data1
ann_data2
ann_bad_data1
ann_bad_data2
ann_bad_data3
ann_sig1
ann_sig2
ann_bind
ann_data1 = specify "Annotate data type" $
D.do "data Either (a b) where"
" Left : a -> Either (a b)"
" Right : b -> Either (a b)"
`shouldBePrg`
D.do "data forall a. forall b. Either (a b) where"
" Left : a -> Either (a b)"
" Right : b -> Either (a b)"
ann_data2 = specify "Annotate constructor with additional type variable" $
D.do "data forall a. forall b. Either (a b) where"
" Left : c -> a -> Either (a b)"
" Right : b -> Either (a b)"
`shouldBePrg`
D.do "data forall a. forall b. Either (a b) where"
" Left : forall c. c -> a -> Either (a b)"
" Right : b -> Either (a b)"
ann_bad_data1 = specify "Bad data type variables" $
D.do "data Either (Int b) where"
" Left : a -> Either (a b)"
" Right : b -> Either (a b)"
`shouldBeErr`
"Misformed data declaration: Non type variable argument"
ann_bad_data2 = specify "Bad data identifer" $
D.do "data Int -> Either (a b) where"
" Left : a -> Either (a b)"
" Right : b -> Either (a b)"
`shouldBeErr`
"Misformed data declaration"
ann_bad_data3 = specify "Constructor forall duplicate" $
D.do "data Int -> Either (a b) where"
" Left : forall a. a -> Either (a b)"
" Right : b -> Either (a b)"
`shouldBeErr`
"Misformed data declaration"
ann_sig1 = specify "Annotate signature" $
"f : a -> b -> (forall a. a -> a) -> a"
`shouldBePrg`
"f : forall a. forall b. a -> b -> (forall a. a -> a) -> a"
ann_sig2 = specify "Annotate signature 2" $
D.do "const : forall a. forall b. a -> b -> a"
"const x y = x"
"main = const 'a' 65"
`shouldBePrg`
D.do "const : forall a. forall b. a -> b -> a"
"const x y = x"
"main = const 'a' 65"
ann_bind = specify "Annotate bind" $
"f = (\\x.\\y. x : a -> b -> a) 4"
`shouldBePrg`
"f = (\\x.\\y. x : forall a. forall b. a -> b -> a) 4"
shouldBeErr s err = run s `shouldBe` Bad err
shouldBePrg s1 s2
| Ok p2 <- run' s2 = run s1 `shouldBe` Ok p2
| otherwise = error ("Faulty expectation \n" ++ show (run' s2))
run = annotateForall <=< run'
run' s = do
p <- run'' s
reportForall Bi p
pure p
run'' = pProgram . resolveLayout True . myLexer
runPrint = (putStrLn . either show printTree . run) $
D.do "data forall a. forall b. Either (a b) where"
" Left : c -> a -> Either (a b)"
" Right : b -> Either (a b)"

96
tests/TestRenamer.hs Normal file
View file

@ -0,0 +1,96 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}
{-# HLINT ignore "Use camelCase" #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE QualifiedDo #-}
module TestRenamer (testRenamer, test, runPrint) where
import AnnForall (annotateForall)
import Control.Exception (ErrorCall (ErrorCall),
Exception (displayException),
SomeException (SomeException),
evaluate, try)
import Control.Exception.Extra (try_)
import Control.Monad (unless, (<=<))
import Control.Monad.Except (throwError)
import Data.Either.Extra (fromEither)
import qualified DoStrings as D
import GHC.Generics (Generic, Generic1)
import Grammar.Abs (Program (Program))
import Grammar.ErrM (Err, pattern Bad, pattern Ok)
import Grammar.Layout (resolveLayout)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
import Renamer.Renamer (rename)
import System.IO.Error (catchIOError, tryIOError)
import Test.Hspec (anyErrorCall, anyException,
describe, hspec, shouldBe,
shouldNotSatisfy, shouldReturn,
shouldSatisfy, shouldThrow,
specify)
import TypeChecker.ReportTEVar (reportTEVar)
import TypeChecker.TypeCheckerBidir (typecheck)
import qualified TypeChecker.TypeCheckerIr as T
-- FIXME tests sucks
test = hspec testRenamer
testRenamer = describe "Test Renamer" $ do
rn_data1
rn_data2
rn_sig
rn_bind1
rn_bind2
rn_data1 = specify "Rename data type" . shouldSatisfyOk $
D.do "data forall a. forall b. Either (a b) where"
" Left : a -> Either (a b)"
" Right : b -> Either (a b)"
rn_data2 = specify "Rename data type forall in constructor " . shouldSatisfyOk $
D.do "data forall a. forall b. Either (a b) where"
" Left : forall c. c -> a -> Either (a b)"
" Right : b -> Either (a b)"
rn_sig = specify "Rename signature" $ shouldSatisfyOk
"f : forall a. forall b. a -> b -> (forall a. a -> a) -> a"
rn_bind1 = specify "Rename simple bind" $ shouldSatisfyOk
"f x = (\\y. let y2 = y + 1 in y2) (x + 1)"
rn_bind2 = specify "Rename bind with case" . shouldSatisfyOk $
D.do "data forall a. List (a) where"
" Nil : List (a) "
" Cons : a -> List (a) -> List (a)"
"length : forall a. List (a) -> Int"
"length list = case list of"
" Nil => 0"
" Cons x Nil => 1"
" Cons x (Cons y ys) => 2 + length ys"
runPrint = putStrLn . either show printTree . run $
D.do "data forall a. List (a) where"
" Nil : List (a) "
" Cons : a -> List (a) -> List (a)"
"length : forall a. List (a) -> Int"
"length list = case list of"
" Nil => 0"
" Cons x Nil => 1"
" Cons x (Cons y ys) => 2 + length ys"
shouldSatisfyOk s = run s `shouldSatisfy` ok
ok = \case
Ok !_ -> True
Bad !_ -> False
shouldBeErr s err = run s `shouldBe` Bad err
run = rename <=< run'
run' = pProgram . resolveLayout True . myLexer

47
tests/TestReportForall.hs Normal file
View file

@ -0,0 +1,47 @@
{-# LANGUAGE PatternSynonyms #-}
{-# HLINT ignore "Use camelCase" #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
module TestReportForall (testReportForall, test) where
import AnnForall (annotateForall)
import Control.Monad ((<=<))
import qualified DoStrings as D
import Grammar.ErrM (Err, pattern Bad, pattern Ok)
import Grammar.Layout (resolveLayout)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
import Renamer.Renamer (rename)
import ReportForall (reportForall)
import Test.Hspec (describe, hspec, shouldBe,
shouldNotSatisfy, shouldSatisfy,
shouldThrow, specify)
import TypeChecker.TypeChecker (TypeChecker (Bi, Hm))
testReportForall = describe "Test ReportForall" $ do
rp_unused1
rp_unused2
rp_forall
test = hspec testReportForall
rp_unused1 = specify "Unused forall 1" $
"g : forall a. forall a. a -> (forall a. a -> a) -> a"
`shouldBeErrBi`
"Duplicate forall"
rp_unused2 = specify "Unused forall 2" $
"g : forall a. (forall a. a -> a) -> Int"
`shouldBeErrBi`
"Unused forall"
rp_forall = specify "Rank2 forall with Hm" $
"f : a -> b -> (forall a. a -> a) -> a"
`shouldBeErrHm`
"Higher rank forall not allowed"
shouldBeErrBi = shouldBeErr Bi
shouldBeErrHm = shouldBeErr Hm
shouldBeErr tc s err = run tc s `shouldBe` Bad err
run tc = reportForall tc <=< pProgram . resolveLayout True . myLexer

View file

@ -8,19 +8,25 @@ module TestTypeCheckerBidir (test, testTypeCheckerBidir) where
import Test.Hspec import Test.Hspec
import AnnForall (annotateForall)
import Control.Monad ((<=<)) import Control.Monad ((<=<))
import Grammar.Abs (Program)
import Grammar.ErrM (Err, pattern Bad, pattern Ok) import Grammar.ErrM (Err, pattern Bad, pattern Ok)
import Grammar.Layout (resolveLayout) import Grammar.Layout (resolveLayout)
import Grammar.Par (myLexer, pProgram) import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
import Renamer.Renamer (rename) import Renamer.Renamer (rename)
import TypeChecker.RemoveTEVar (RemoveTEVar (rmTEVar)) import ReportForall (reportForall)
import TypeChecker.RemoveForall (removeForall)
import TypeChecker.ReportTEVar (reportTEVar)
import TypeChecker.TypeChecker (TypeChecker (Bi))
import TypeChecker.TypeCheckerBidir (typecheck) import TypeChecker.TypeCheckerBidir (typecheck)
import qualified TypeChecker.TypeCheckerIr as T import qualified TypeChecker.TypeCheckerIr as T
test = hspec testTypeCheckerBidir test = hspec testTypeCheckerBidir
testTypeCheckerBidir = describe "Bidirectional type checker test" $ do testTypeCheckerBidir = describe "Test Bidirectional type checker" $ do
tc_id tc_id
tc_double tc_double
tc_add_lam tc_add_lam
@ -39,7 +45,7 @@ testTypeCheckerBidir = describe "Bidirectional type checker test" $ do
tc_id = tc_id =
specify "Basic identity function polymorphism" $ specify "Basic identity function polymorphism" $
run run
[ "id : forall a. a -> a" [ "id : a -> a"
, "id x = x" , "id x = x"
, "main = id 4" , "main = id 4"
] ]
@ -60,7 +66,7 @@ tc_add_lam =
tc_const = tc_const =
specify "Basic polymorphism with multiple type variables" $ specify "Basic polymorphism with multiple type variables" $
run run
[ "const : forall a. forall b. a -> b -> a" [ "const : a -> b -> a"
, "const x y = x" , "const x y = x"
, "main = const 'a' 65" , "main = const 'a' 65"
] ]
@ -69,9 +75,9 @@ tc_const =
tc_simple_rank2 = tc_simple_rank2 =
specify "Simple rank two polymorphism" $ specify "Simple rank two polymorphism" $
run run
[ "id : forall a. a -> a" [ "id : a -> a"
, "id x = x" , "id x = x"
, "f : forall a. a -> (forall b. b -> b) -> a" , "f : a -> (forall b. b -> b) -> a"
, "f x g = g x" , "f x g = g x"
, "main = f 4 id" , "main = f 4 id"
] ]
@ -80,11 +86,11 @@ tc_simple_rank2 =
tc_rank2 = tc_rank2 =
specify "Rank two polymorphism is ok" $ specify "Rank two polymorphism is ok" $
run run
[ "const : forall a. forall b. a -> b -> a" [ "const : a -> b -> a"
, "const x y = x" , "const x y = x"
, "rank2 : forall a. forall b. a -> (forall c. c -> Int) -> b -> Int" , "rank2 : a -> (forall c. c -> Int) -> b -> Int"
, "rank2 x f y = f x + f y" , "rank2 x f y = f x + f y"
, "main = rank2 3 (\\x. const 5 x : forall a. a -> Int) 'h'" , "main = rank2 3 (\\x. const 5 x : a -> Int) 'h'"
] ]
`shouldSatisfy` ok `shouldSatisfy` ok
@ -93,9 +99,9 @@ tc_identity = describe "(∀b. b → b) should only accept the identity function
specify "identity is accepted" $ run (fs ++ id) `shouldSatisfy` ok specify "identity is accepted" $ run (fs ++ id) `shouldSatisfy` ok
where where
fs = fs =
[ "f : forall a. a -> (forall b. b -> b) -> a" [ "f : a -> (forall b. b -> b) -> a"
, "f x g = g x" , "f x g = g x"
, "id : forall a. a -> a" , "id : a -> a"
, "id x = x" , "id x = x"
, "id_int : Int -> Int" , "id_int : Int -> Int"
, "id_int x = x" , "id_int x = x"
@ -114,7 +120,7 @@ tc_pair = describe "Pair. Type variables in Pair a b typechecked" $ do
specify "Correct arguments are accepted" $ run (fs ++ correct) `shouldSatisfy` ok specify "Correct arguments are accepted" $ run (fs ++ correct) `shouldSatisfy` ok
where where
fs = fs =
[ "data forall a. forall b. Pair (a b) where" [ "data Pair (a b) where"
, " Pair : a -> b -> Pair (a b)" , " Pair : a -> b -> Pair (a b)"
, "main : Pair (Int Char)" , "main : Pair (Int Char)"
] ]
@ -126,7 +132,7 @@ tc_tree = describe "Tree. Recursive data type" $ do
specify "Correct tree is accepted" $ run (fs ++ correct) `shouldSatisfy` ok specify "Correct tree is accepted" $ run (fs ++ correct) `shouldSatisfy` ok
where where
fs = fs =
[ "data forall a. Tree (a) where" [ "data Tree (a) where"
, " Node : a -> Tree (a) -> Tree (a) -> Tree (a)" , " Node : a -> Tree (a) -> Tree (a) -> Tree (a)"
, " Leaf : a -> Tree (a)" , " Leaf : a -> Tree (a)"
] ]
@ -195,30 +201,30 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do
run (fs ++ correct4) `shouldSatisfy` ok run (fs ++ correct4) `shouldSatisfy` ok
where where
fs = fs =
[ "data forall a. List (a) where" [ "data List (a) where"
, " Nil : List (a)" , " Nil : List (a)"
, " Cons : a -> List (a) -> List (a)" , " Cons : a -> List (a) -> List (a)"
] ]
wrong1 = wrong1 =
[ "length : forall c. List (c) -> Int" [ "length : List (c) -> Int"
, "length = \\list. case list of" , "length = \\list. case list of"
, " Nil => 0" , " Nil => 0"
, " Cons 6 xs => 1 + length xs" , " Cons 6 xs => 1 + length xs"
] ]
wrong2 = wrong2 =
[ "length : forall c. List (c) -> Int" [ "length : List (c) -> Int"
, "length = \\list. case list of" , "length = \\list. case list of"
, " Cons => 0" , " Cons => 0"
, " Cons x xs => 1 + length xs" , " Cons x xs => 1 + length xs"
] ]
wrong3 = wrong3 =
[ "length : forall c. List (c) -> Int" [ "length : List (c) -> Int"
, "length = \\list. case list of" , "length = \\list. case list of"
, " 0 => 0" , " 0 => 0"
, " Cons x xs => 1 + length xs" , " Cons x xs => 1 + length xs"
] ]
wrong4 = wrong4 =
[ "elems : forall c. List (List(c)) -> Int" [ "elems : List (List(c)) -> Int"
, "elems = \\list. case list of" , "elems = \\list. case list of"
, " Nil => 0" , " Nil => 0"
, " Cons Nil Nil => 0" , " Cons Nil Nil => 0"
@ -226,14 +232,14 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do
, " Cons (Cons Nil ys) xs => 1 + elems (Cons ys xs)" , " Cons (Cons Nil ys) xs => 1 + elems (Cons ys xs)"
] ]
correct1 = correct1 =
[ "length : forall c. List (c) -> Int" [ "length : List (c) -> Int"
, "length = \\list. case list of" , "length = \\list. case list of"
, " Nil => 0" , " Nil => 0"
, " Cons x xs => 1 + length xs" , " Cons x xs => 1 + length xs"
, " Cons x (Cons y Nil) => 2" , " Cons x (Cons y Nil) => 2"
] ]
correct2 = correct2 =
[ "length : forall c. List (c) -> Int" [ "length : List (c) -> Int"
, "length = \\list. case list of" , "length = \\list. case list of"
, " Nil => 0" , " Nil => 0"
, " non_empty => 1" , " non_empty => 1"
@ -246,7 +252,7 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do
, " Cons x (Cons 2 xs) => 2 + length xs" , " Cons x (Cons 2 xs) => 2 + length xs"
] ]
correct4 = correct4 =
[ "elems : forall c. List (List(c)) -> Int" [ "elems : List (List(c)) -> Int"
, "elems = \\list. case list of" , "elems = \\list. case list of"
, " Nil => 0" , " Nil => 0"
, " Cons Nil Nil => 0" , " Cons Nil Nil => 0"
@ -292,9 +298,19 @@ tc_rec2 = specify "Infer recursive definition with pattern matching" $ run
, " _ => test (x+1)" , " _ => test (x+1)"
] `shouldSatisfy` ok ] `shouldSatisfy` ok
run :: [String] -> Err T.Program run :: [String] -> Err T.Program
run = rmTEVar <=< typecheck <=< pProgram . resolveLayout True . myLexer . unlines run = fmap removeForall
. reportTEVar
<=< typecheck
<=< run'
run' s = do
p <- (pProgram . resolveLayout True . myLexer . unlines) s
reportForall Bi p
(rename <=< annotateForall) p
runPrint = (putStrLn . either show printTree . run')
["double x = x + x"]
ok = \case ok = \case
Ok _ -> True Ok _ -> True

View file

@ -1,23 +1,25 @@
{-# LANGUAGE NoImplicitPrelude #-} {-# LANGUAGE QualifiedDo #-}
{-# LANGUAGE QualifiedDo #-}
module TestTypeCheckerHm where module TestTypeCheckerHm where
import Control.Monad ((<=<)) import Control.Monad (sequence_, (<=<))
import qualified DoStrings as D
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
import Prelude (Bool (..), Either (..), fmap,
foldl1, fst, not, ($), (.), (>>))
import Test.Hspec import Test.Hspec
-- import Test.QuickCheck import AnnForall (annotateForall)
import qualified DoStrings as D
import Grammar.Layout (resolveLayout)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
import Renamer.Renamer (rename)
import ReportForall (reportForall)
import TypeChecker.TypeChecker (TypeChecker (Hm))
import TypeChecker.TypeCheckerHm (typecheck) import TypeChecker.TypeCheckerHm (typecheck)
import TypeChecker.TypeCheckerIr (Program)
testTypeCheckerHm = describe "Hindley-Milner type checker test" $ do testTypeCheckerHm = describe "Hindley-Milner type checker test" $ do
foldl1 (>>) goods sequence_ goods
foldl1 (>>) bads sequence_ bads
foldl1 (>>) bes sequence_ bes
goods = goods =
[ testSatisfy [ testSatisfy
@ -118,26 +120,29 @@ bads =
" };" " };"
) )
bad bad
, testSatisfy -- FIXME FAILING TEST
"id with incorrect signature" -- , testSatisfy
( D.do -- "id with incorrect signature"
"id : a -> b;" -- ( D.do
"id x = x;" -- "id : a -> b;"
) -- "id x = x;"
bad -- )
, testSatisfy -- bad
"incorrect signature on const" -- FIXME FAILING TEST
( D.do -- , testSatisfy
"const : a -> b -> b;" -- "incorrect signature on const"
"const x y = x" -- ( D.do
) -- "const : a -> b -> b;"
bad -- "const x y = x"
, testSatisfy -- )
"incorrect type signature on id lambda" -- bad
( D.do -- FIXME FAILING TEST
"id = ((\\x. x) : a -> b);" -- , testSatisfy
) -- "incorrect type signature on id lambda"
bad -- ( D.do
-- "id = ((\\x. x) : a -> b);"
-- )
-- bad
] ]
bes = bes =
@ -211,6 +216,11 @@ testBe desc test shouldbe = specify desc $ run test `shouldBe` run shouldbe
run = fmap (printTree . fst) . typecheck <=< pProgram . myLexer run = fmap (printTree . fst) . typecheck <=< pProgram . myLexer
run' s = do
p <- (pProgram . resolveLayout True . myLexer) s
reportForall Hm p
(rename <=< annotateForall) p
ok (Right _) = True ok (Right _) = True
ok (Left _) = False ok (Left _) = False