Merge remote-tracking branch 'origin/typechecking-merge' into pattern-matching-with-typechecking

This commit is contained in:
Samuel Hammersberg 2023-03-23 16:33:05 +01:00
commit d3d173eb59
21 changed files with 1052 additions and 476 deletions

4
.gitignore vendored
View file

@ -4,5 +4,5 @@ dist-newstyle
*.bak *.bak
src/Grammar src/Grammar
language language
llvm.ll test_program_result
output output/

View file

@ -1,51 +1,94 @@
-------------------------------------------------------------------------------
-- * PROGRAM
-------------------------------------------------------------------------------
Program. Program ::= [Def] ; Program. Program ::= [Def] ;
-------------------------------------------------------------------------------
-- * TOP-LEVEL
-------------------------------------------------------------------------------
DBind. Def ::= Bind ; DBind. Def ::= Bind ;
DSig. Def ::= Sig ;
DData. Def ::= Data ; DData. Def ::= Data ;
separator Def ";" ;
Bind. Bind ::= Ident ":" Type ";" Sig. Sig ::= LIdent ":" Type ;
Ident [Ident] "=" Exp ;
Data. Data ::= "data" Constr "where" "{" [Constructor] "}" ; Bind. Bind ::= LIdent [LIdent] "=" Exp ;
Constructor. Constructor ::= Ident ":" Type ; -------------------------------------------------------------------------------
separator nonempty Constructor "" ; -- * TYPES
-------------------------------------------------------------------------------
TMono. Type1 ::= Ident ; TLit. Type2 ::= UIdent ;
TPol. Type1 ::= "'" Ident ; TVar. Type2 ::= TVar ;
TConstr. Type1 ::= Constr ; TAll. Type1 ::= "forall" TVar "." Type ;
TArr. Type ::= Type1 "->" Type ; TIndexed. Type1 ::= Indexed ;
internal TEVar. Type1 ::= TEVar ;
TFun. Type ::= Type1 "->" Type ;
Constr. Constr ::= Ident "(" [Type] ")" ; MkTVar. TVar ::= LIdent ;
internal MkTEVar. TEVar ::= LIdent ;
-------------------------------------------------------------------------------
-- * DATA TYPES
-------------------------------------------------------------------------------
Constructor. Constructor ::= UIdent ":" Type ;
Indexed. Indexed ::= UIdent "(" [Type] ")" ;
Data. Data ::= "data" Indexed "where" "{" [Constructor] "}" ;
-------------------------------------------------------------------------------
-- * EXPRESSIONS
-------------------------------------------------------------------------------
-- TODO: Move literal to its own thing since it's reused in Init as well.
EAnn. Exp5 ::= "(" Exp ":" Type ")" ; EAnn. Exp5 ::= "(" Exp ":" Type ")" ;
EId. Exp4 ::= Ident ; EVar. Exp4 ::= LIdent ;
ELit. Exp4 ::= Literal ; ECons. Exp4 ::= UIdent ;
ELit. Exp4 ::= Lit ;
EApp. Exp3 ::= Exp3 Exp4 ; EApp. Exp3 ::= Exp3 Exp4 ;
EAdd. Exp1 ::= Exp1 "+" Exp2 ; EAdd. Exp1 ::= Exp1 "+" Exp2 ;
ESub. Exp1 ::= Exp1 "-" Exp2 ; ELet. Exp ::= "let" LIdent "=" Exp "in" Exp ;
ELet. Exp ::= "let" Ident "=" Exp "in" Exp ; EAbs. Exp ::= "\\" LIdent "." Exp ;
EAbs. Exp ::= "\\" Ident "." Exp ;
ECase. Exp ::= "case" Exp "of" "{" [Inj] "}"; ECase. Exp ::= "case" Exp "of" "{" [Inj] "}";
LInt. Literal ::= Integer ; -------------------------------------------------------------------------------
-- * LITERALS
-------------------------------------------------------------------------------
LInt. Lit ::= Integer ;
LChar. Lit ::= Char ;
-------------------------------------------------------------------------------
-- * CASE
-------------------------------------------------------------------------------
Inj. Inj ::= Init "=>" Exp ; Inj. Inj ::= Init "=>" Exp ;
separator nonempty Inj ";" ;
InitLit. Init ::= Literal ; InitLit. Init ::= Lit ;
InitConstr. Init ::= Ident [Ident] ; InitConstructor. Init ::= UIdent [LIdent] ;
InitCatch. Init ::= "_" ; InitCatch. Init ::= "_" ;
separator Type " " ; -------------------------------------------------------------------------------
coercions Type 2 ; -- * AUX
-------------------------------------------------------------------------------
separator Def ";" ;
separator nonempty Constructor "" ;
separator Type " " ;
separator nonempty Inj ";" ;
separator Ident " "; separator Ident " ";
separator LIdent " ";
separator TVar " " ;
coercions Exp 5 ; coercions Exp 5 ;
coercions Type 2 ;
token UIdent (upper (letter | digit | '_')*) ;
token LIdent (lower (letter | digit | '_')*) ;
comment "--" ; comment "--" ;
comment "{-" "-}" ; comment "{-" "-}" ;

22
Justfile Normal file
View file

@ -0,0 +1,22 @@
build:
bnfc -o src -d Grammar.cf
# clean the generated directories
clean:
rm -r src/Grammar
rm language
# run all tests
test:
cabal test
ctest:
cabal run language sample-programs/basic-1
cabal run language sample-programs/basic-2
cabal run language sample-programs/basic-3
cabal run language sample-programs/basic-4
cabal run language sample-programs/basic-5
# compile a specific file
run FILE:
cabal run language {{FILE}}

View file

@ -28,10 +28,11 @@ test :
./language ./sample-programs/basic-3 ./language ./sample-programs/basic-3
./language ./sample-programs/basic-4 ./language ./sample-programs/basic-4
./language ./sample-programs/basic-5 ./language ./sample-programs/basic-5
./language ./sample-programs/basic-5
./language ./sample-programs/basic-6 ./language ./sample-programs/basic-6
./language ./sample-programs/basic-7 ./language ./sample-programs/basic-7
./language ./sample-programs/basic-8 ./language ./sample-programs/basic-8
./language ./sample-programs/basic-9
run :
cabal -v0 new-run language -- "test_program"
# EOF # EOF

219
Session.vim Normal file
View file

@ -0,0 +1,219 @@
let SessionLoad = 1
let s:so_save = &g:so | let s:siso_save = &g:siso | setg so=0 siso=0 | setl so=-1 siso=-1
let v:this_session=expand("<sfile>:p")
silent only
silent tabonly
cd ~/Documents/bachelor_thesis/language
if expand('%') == '' && !&modified && line('$') <= 1 && getline(1) == ''
let s:wipebuf = bufnr('%')
endif
let s:shortmess_save = &shortmess
if &shortmess =~ 'A'
set shortmess=aoOA
else
set shortmess=aoO
endif
badd +1 ~/Documents/bachelor_thesis/language
badd +298 src/TypeChecker/TypeChecker.hs
badd +7 test_program
badd +46 src/TypeChecker/TypeCheckerIr.hs
badd +6 Grammar.cf
badd +1 src/Grammar/Abs.hs
argglobal
%argdel
$argadd ~/Documents/bachelor_thesis/language
set stal=2
tabnew +setlocal\ bufhidden=wipe
tabnew +setlocal\ bufhidden=wipe
tabnew +setlocal\ bufhidden=wipe
tabrewind
edit src/TypeChecker/TypeChecker.hs
let s:save_splitbelow = &splitbelow
let s:save_splitright = &splitright
set splitbelow splitright
wincmd _ | wincmd |
vsplit
1wincmd h
wincmd w
let &splitbelow = s:save_splitbelow
let &splitright = s:save_splitright
wincmd t
let s:save_winminheight = &winminheight
let s:save_winminwidth = &winminwidth
set winminheight=0
set winheight=1
set winminwidth=0
set winwidth=1
exe 'vert 1resize ' . ((&columns * 99 + 86) / 173)
exe 'vert 2resize ' . ((&columns * 73 + 86) / 173)
argglobal
setlocal fdm=manual
setlocal fde=0
setlocal fmr={{{,}}}
setlocal fdi=#
setlocal fdl=0
setlocal fml=1
setlocal fdn=20
setlocal fen
silent! normal! zE
let &fdl = &fdl
let s:l = 298 - ((18 * winheight(0) + 21) / 42)
if s:l < 1 | let s:l = 1 | endif
keepjumps exe s:l
normal! zt
keepjumps 298
normal! 029|
lcd ~/Documents/bachelor_thesis/language
wincmd w
argglobal
if bufexists(fnamemodify("~/Documents/bachelor_thesis/language/Grammar.cf", ":p")) | buffer ~/Documents/bachelor_thesis/language/Grammar.cf | else | edit ~/Documents/bachelor_thesis/language/Grammar.cf | endif
if &buftype ==# 'terminal'
silent file ~/Documents/bachelor_thesis/language/Grammar.cf
endif
balt ~/Documents/bachelor_thesis/language/src/TypeChecker/TypeChecker.hs
setlocal fdm=manual
setlocal fde=0
setlocal fmr={{{,}}}
setlocal fdi=#
setlocal fdl=0
setlocal fml=1
setlocal fdn=20
setlocal fen
silent! normal! zE
let &fdl = &fdl
let s:l = 7 - ((6 * winheight(0) + 21) / 42)
if s:l < 1 | let s:l = 1 | endif
keepjumps exe s:l
normal! zt
keepjumps 7
normal! 0
lcd ~/Documents/bachelor_thesis/language
wincmd w
exe 'vert 1resize ' . ((&columns * 99 + 86) / 173)
exe 'vert 2resize ' . ((&columns * 73 + 86) / 173)
tabnext
edit ~/Documents/bachelor_thesis/language/src/TypeChecker/TypeCheckerIr.hs
let s:save_splitbelow = &splitbelow
let s:save_splitright = &splitright
set splitbelow splitright
wincmd _ | wincmd |
vsplit
1wincmd h
wincmd w
let &splitbelow = s:save_splitbelow
let &splitright = s:save_splitright
wincmd t
let s:save_winminheight = &winminheight
let s:save_winminwidth = &winminwidth
set winminheight=0
set winheight=1
set winminwidth=0
set winwidth=1
exe 'vert 1resize ' . ((&columns * 86 + 86) / 173)
exe 'vert 2resize ' . ((&columns * 86 + 86) / 173)
argglobal
balt ~/Documents/bachelor_thesis/language/test_program
setlocal fdm=manual
setlocal fde=0
setlocal fmr={{{,}}}
setlocal fdi=#
setlocal fdl=0
setlocal fml=1
setlocal fdn=20
setlocal fen
silent! normal! zE
let &fdl = &fdl
let s:l = 1 - ((0 * winheight(0) + 21) / 42)
if s:l < 1 | let s:l = 1 | endif
keepjumps exe s:l
normal! zt
keepjumps 1
normal! 0
lcd ~/Documents/bachelor_thesis/language
wincmd w
argglobal
if bufexists(fnamemodify("~/Documents/bachelor_thesis/language/src/Grammar/Abs.hs", ":p")) | buffer ~/Documents/bachelor_thesis/language/src/Grammar/Abs.hs | else | edit ~/Documents/bachelor_thesis/language/src/Grammar/Abs.hs | endif
if &buftype ==# 'terminal'
silent file ~/Documents/bachelor_thesis/language/src/Grammar/Abs.hs
endif
balt ~/Documents/bachelor_thesis/language/src/TypeChecker/TypeCheckerIr.hs
setlocal fdm=manual
setlocal fde=0
setlocal fmr={{{,}}}
setlocal fdi=#
setlocal fdl=0
setlocal fml=1
setlocal fdn=20
setlocal fen
silent! normal! zE
let &fdl = &fdl
let s:l = 1 - ((0 * winheight(0) + 21) / 42)
if s:l < 1 | let s:l = 1 | endif
keepjumps exe s:l
normal! zt
keepjumps 1
normal! 0
lcd ~/Documents/bachelor_thesis/language
wincmd w
exe 'vert 1resize ' . ((&columns * 86 + 86) / 173)
exe 'vert 2resize ' . ((&columns * 86 + 86) / 173)
tabnext
edit ~/Documents/bachelor_thesis/language/Grammar.cf
argglobal
balt ~/Documents/bachelor_thesis/language/src/Grammar/Abs.hs
setlocal fdm=manual
setlocal fde=0
setlocal fmr={{{,}}}
setlocal fdi=#
setlocal fdl=0
setlocal fml=1
setlocal fdn=20
setlocal fen
silent! normal! zE
let &fdl = &fdl
let s:l = 40 - ((12 * winheight(0) + 21) / 42)
if s:l < 1 | let s:l = 1 | endif
keepjumps exe s:l
normal! zt
keepjumps 40
normal! 0
lcd ~/Documents/bachelor_thesis/language
tabnext
edit ~/Documents/bachelor_thesis/language/test_program
argglobal
balt ~/Documents/bachelor_thesis/language/src/TypeChecker/TypeChecker.hs
setlocal fdm=manual
setlocal fde=0
setlocal fmr={{{,}}}
setlocal fdi=#
setlocal fdl=0
setlocal fml=1
setlocal fdn=20
setlocal fen
silent! normal! zE
let &fdl = &fdl
let s:l = 7 - ((6 * winheight(0) + 21) / 42)
if s:l < 1 | let s:l = 1 | endif
keepjumps exe s:l
normal! zt
keepjumps 7
normal! 010|
lcd ~/Documents/bachelor_thesis/language
tabnext 1
set stal=1
if exists('s:wipebuf') && len(win_findbuf(s:wipebuf)) == 0 && getbufvar(s:wipebuf, '&buftype') isnot# 'terminal'
silent exe 'bwipe ' . s:wipebuf
endif
unlet! s:wipebuf
set winheight=1 winwidth=20
let &shortmess = s:shortmess_save
let s:sx = expand("<sfile>:p:r")."x.vim"
if filereadable(s:sx)
exe "source " . fnameescape(s:sx)
endif
let &g:so = s:so_save | let &g:siso = s:siso_save
set hlsearch
nohlsearch
doautoall SessionLoadPost
unlet SessionLoad
" vim: set ft=vim :

2
cabal.project.local~ Normal file
View file

@ -0,0 +1,2 @@
ignore-project: False
tests: False

View file

@ -1,14 +1 @@
indentation: 4
function-arrows: trailing
comma-style: leading
import-export-style: diff-friendly
indent-wheres: false indent-wheres: false
record-brace-space: false
newlines-between-decls: 1
haddock-style: multi-line
haddock-style-module:
let-style: auto
in-style: right-align
respectful: true
fixities: []
unicode: never

View file

@ -12,11 +12,9 @@ build-type: Simple
extra-doc-files: CHANGELOG.md extra-doc-files: CHANGELOG.md
extra-source-files: extra-source-files:
Grammar.cf Grammar.cf
common warnings common warnings
ghc-options: -W ghc-options: -W
@ -32,16 +30,14 @@ executable language
Grammar.Print Grammar.Print
Grammar.Skel Grammar.Skel
Grammar.ErrM Grammar.ErrM
LambdaLifter.LambdaLifter
Auxiliary Auxiliary
Renamer.Renamer
TypeChecker.TypeChecker TypeChecker.TypeChecker
TypeChecker.TypeCheckerIr TypeChecker.TypeCheckerIr
Monomorphizer.Monomorphizer Renamer.Renamer
Monomorphizer.MonomorphizerIr -- LambdaLifter.LambdaLifter
-- Interpreter -- Codegen.Codegen
Codegen.Codegen -- Codegen.LlvmIr
Codegen.LlvmIr
hs-source-dirs: src hs-source-dirs: src
build-depends: build-depends:
@ -49,7 +45,39 @@ executable language
, mtl , mtl
, containers , containers
, either , either
, array
, extra , extra
, directory , array
, hspec
, QuickCheck
default-language: GHC2021
Test-suite language-testsuite
type: exitcode-stdio-1.0
main-is: Tests.hs
other-modules:
Grammar.Abs
Grammar.Lex
Grammar.Par
Grammar.Print
Grammar.Skel
Grammar.ErrM
Auxiliary
TypeChecker.TypeChecker
TypeChecker.TypeCheckerIr
Renamer.Renamer
hs-source-dirs: src, tests
build-depends:
base >=4.16
, mtl
, containers
, either
, extra
, array
, hspec
, QuickCheck
default-language: GHC2021 default-language: GHC2021

View file

@ -1,29 +1,8 @@
posMul: _Int - > _Int - > _Int; posMul: _Int - > _Int - > _Int;
posMul a b = a + b; {-case b of { posMul a b = a + b; {
-
case b of {
0 => 0; 0 => 0;
_ => a + posMul a(b - 1) _ => a + posMul a(b - 1)
};-} }; -
}
main : _Int;
main = posMul 5 10;
--
-- facc : _Int -> _Int;
-- facc a = case a of {
-- 1 => 1;
-- _ => posMul a (facc (a - 1))
-- };
--
-- minimization : (_Int -> _Int) -> _Int -> _Int;
-- minimization p x = case p x of {
-- 1 => x;
-- _ => minimization p (x + 1)
-- };
--
-- checkFac : _Int -> _Int;
-- checkFac x = case facc x of {
-- 0 => 1;
-- _ => 0
-- };
--
-- main : _Int;
-- main = minimization checkFac 1

5
sample-programs/basic-2 Normal file
View file

@ -0,0 +1,5 @@
add : Int -> Int -> Int ;
add x = \y. x+y;
main : Int ;
main = (\z. z+z) ((add 4) 6) ;

2
sample-programs/basic-3 Normal file
View file

@ -0,0 +1,2 @@
main : Int ;
main = (\x. x+x+3) ((\x. x) 2) ;

2
sample-programs/basic-4 Normal file
View file

@ -0,0 +1,2 @@
f : Int -> Int ;
f x = let g = (\y. y+1) in g (g x)

8
sample-programs/basic-5 Normal file
View file

@ -0,0 +1,8 @@
double : Int -> Int ;
double n = n + n;
id : forall a. a -> a ;
id x = x ;
main : Int ;
main = id double 5;

10
sample-programs/basic-6 Normal file
View file

@ -0,0 +1,10 @@
data Bool () where {
True : Bool ()
False : Bool ()
};
main : Bool () -> Int ;
main b = case b of {
False => 0;
True => 0
}

10
sample-programs/basic-7 Normal file
View file

@ -0,0 +1,10 @@
data Bool () where {
True : Bool ()
False : Bool ()
};
ifThenElse : forall a. Bool () -> a -> a -> a;
ifThenElse b if else = case b of {
True => if;
False => else
}

24
sample-programs/basic-8 Normal file
View file

@ -0,0 +1,24 @@
data Maybe (a) where {
Nothing : Maybe (a)
Just : a -> Maybe (a)
};
fromJust : Maybe (a) -> a ;
fromJust a =
case a of {
Just a => a
};
fromMaybe : a -> Maybe (a) -> a ;
fromMaybe a b =
case b of {
Just a => a;
Nothing => a
};
maybe : b -> (a -> b) -> Maybe (a) -> b;
maybe b f ma =
case ma of {
Just a => f a;
Nothing => b
}

View file

@ -1,50 +1,105 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
module Renamer.Renamer where module Renamer.Renamer (rename) where
import Auxiliary (mapAccumM) import Auxiliary (mapAccumM)
import Control.Monad.State (MonadState, State, evalState, gets, import Control.Applicative (Applicative (liftA2))
modify) import Control.Monad.Except (ExceptT, MonadError, runExceptT, throwError)
import Data.List (foldl') import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.State (
MonadState,
StateT,
evalStateT,
gets,
modify,
)
import Data.Coerce (coerce)
import Data.Function (on)
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as Map import Data.Map qualified as Map
import Data.Maybe (fromMaybe) import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe) import Data.Tuple.Extra (dupe)
import Grammar.Abs import Grammar.Abs
-- | Rename all variables and local binds -- | Rename all variables and local binds
rename :: Program -> Program rename :: Program -> Either String Program
rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0 rename (Program defs) = Program <$> renameDefs defs
renameDefs :: [Def] -> Either String [Def]
renameDefs defs = runIdentity $ runExceptT $ evalStateT (runRn $ mapM renameDef defs) initCxt
where where
-- initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs initNames = Map.fromList [dupe (coerce name) | DBind (Bind name _ _) <- defs]
initNames = Map.fromList $ foldl' saveIfBind [] bs
saveIfBind acc (DBind (Bind name _ _ _ _)) = dupe name : acc renameDef :: Def -> Rn Def
saveIfBind acc _ = acc renameDef = \case
renameSc :: Names -> Def -> Rn Def DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ
renameSc old_names (DBind (Bind name t _ parms rhs)) = do DBind (Bind name vars rhs) -> do
(new_names, parms') <- newNames old_names parms (new_names, vars') <- newNames initNames (coerce vars)
rhs' <- snd <$> renameExp new_names rhs rhs' <- snd <$> renameExp new_names rhs
pure . DBind $ Bind name t name parms' rhs' pure . DBind $ Bind name (coerce vars') rhs'
renameSc _ def = pure def DData (Data (Indexed cname types) constrs) -> do
tvars_ <- tvars
tvars' <- mapM nextNameTVar tvars_
let tvars_lt = zip tvars_ tvars'
typ' = map (substituteTVar tvars_lt) types
constrs' = map (renameConstr tvars_lt) constrs
pure . DData $ Data (Indexed cname typ') constrs'
where
tvars = concat <$> mapM (collectTVars []) types
collectTVars :: [TVar] -> Type -> Rn [TVar]
collectTVars tvars = \case
TAll tvar t -> collectTVars (tvar : tvars) t
TIndexed _ -> return tvars
-- Should be monad error
TVar v -> return [v]
_ -> throwError ("Bad data type definition: " ++ show types)
renameConstr :: [(TVar, TVar)] -> Constructor -> Constructor
renameConstr new_types (Constructor name typ) =
Constructor name $ substituteTVar new_types typ
substituteTVar :: [(TVar, TVar)] -> Type -> Type
substituteTVar new_names typ = case typ of
TLit _ -> typ
TVar tvar
| Just tvar' <- lookup tvar new_names ->
TVar tvar'
| otherwise ->
typ
TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t
| Just tvar' <- lookup tvar new_names ->
TAll tvar' $ substitute' t
| otherwise ->
TAll tvar $ substitute' t
TIndexed (Indexed name typs) -> TIndexed . Indexed name $ map substitute' typs
_ -> error ("Impossible " ++ show typ)
where
substitute' = substituteTVar new_names
initCxt :: Cxt
initCxt = Cxt 0 0
data Cxt = Cxt
{ var_counter :: Int
, tvar_counter :: Int
}
-- | Rename monad. State holds the number of renamed names. -- | Rename monad. State holds the number of renamed names.
newtype Rn a = Rn {runRn :: State Int a} newtype Rn a = Rn {runRn :: StateT Cxt (ExceptT String Identity) a}
deriving (Functor, Applicative, Monad, MonadState Int) deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String)
-- | Maps old to new name -- | Maps old to new name
type Names = Map Ident Ident type Names = Map LIdent LIdent
renameLocalBind :: Names -> Bind -> Rn (Names, Bind)
renameLocalBind old_names (Bind name t _ parms rhs) = do
(new_names, name') <- newName old_names name
(new_names', parms') <- newNames new_names parms
(new_names'', rhs') <- renameExp new_names' rhs
pure (new_names'', Bind name' t name' parms' rhs')
renameExp :: Names -> Exp -> Rn (Names, Exp) renameExp :: Names -> Exp -> Rn (Names, Exp)
renameExp old_names = \case renameExp old_names = \case
EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names) EVar n -> pure (coerce old_names, EVar . fromMaybe n $ Map.lookup n old_names)
ELit (LInt i1) -> pure (old_names, ELit (LInt i1)) ECons n -> pure (old_names, ECons n)
ELit lit -> pure (old_names, ELit lit)
EApp e1 e2 -> do EApp e1 e2 -> do
(env1, e1') <- renameExp old_names e1 (env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2 (env2, e2') <- renameExp old_names e2
@ -53,25 +108,25 @@ renameExp old_names = \case
(env1, e1') <- renameExp old_names e1 (env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2 (env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EAdd e1' e2') pure (Map.union env1 env2, EAdd e1' e2')
ESub e1 e2 -> do
(env1, e1') <- renameExp old_names e1 -- TODO fix shadowing
(env2, e2') <- renameExp old_names e2 ELet name rhs e -> do
pure (Map.union env1 env2, ESub e1' e2') (new_names, name') <- newName old_names (coerce name)
ELet i e1 e2 -> do (new_names', rhs') <- renameExp new_names rhs
(new_names, e1') <- renameExp old_names e1 (new_names'', e') <- renameExp new_names' e
(new_names', e2') <- renameExp new_names e2 pure (new_names'', ELet (coerce name') rhs' e')
pure (new_names', ELet i e1' e2')
EAbs par e -> do EAbs par e -> do
(new_names, par') <- newName old_names par (new_names, par') <- newName old_names (coerce par)
(new_names', e') <- renameExp new_names e (new_names', e') <- renameExp new_names e
pure (new_names', EAbs par' e') pure (new_names', EAbs (coerce par') e')
EAnn e t -> do EAnn e t -> do
(new_names, e') <- renameExp old_names e (new_names, e') <- renameExp old_names e
pure (new_names, EAnn e' t) t' <- renameTVars t
pure (new_names, EAnn e' t')
ECase e injs -> do ECase e injs -> do
(_, e') <- renameExp old_names e (new_names, e') <- renameExp old_names e
(new_names, injs') <- renameInjs old_names injs (new_names', injs') <- renameInjs new_names injs
pure (new_names, ECase e' injs') pure (new_names', ECase e' injs')
renameInjs :: Names -> [Inj] -> Rn (Names, [Inj]) renameInjs :: Names -> [Inj] -> Rn (Names, [Inj])
renameInjs ns xs = do renameInjs ns xs = do
@ -80,19 +135,64 @@ renameInjs ns xs = do
renameInj :: Names -> Inj -> Rn (Names, Inj) renameInj :: Names -> Inj -> Rn (Names, Inj)
renameInj ns (Inj init e) = do renameInj ns (Inj init e) = do
(new_names, e') <- renameExp ns e (new_names, init') <- renameInit ns init
return (new_names, Inj init e') (new_names', e') <- renameExp new_names e
return (new_names', Inj init' e')
renameInit :: Names -> Init -> Rn (Names, Init)
renameInit ns i = case i of
InitConstructor cs vars -> do
(ns_new, vars') <- newNames ns (coerce vars)
return (ns_new, InitConstructor cs (coerce vars'))
rest -> return (ns, rest)
renameTVars :: Type -> Rn Type
renameTVars typ = case typ of
TAll tvar t -> do
tvar' <- nextNameTVar tvar
t' <- renameTVars $ substitute tvar tvar' t
pure $ TAll tvar' t'
TFun t1 t2 -> liftA2 TFun (renameTVars t1) (renameTVars t2)
_ -> pure typ
substitute ::
TVar -> -- α
TVar -> -- α_n
Type -> -- A
Type -- [α_n/α]A
substitute tvar1 tvar2 typ = case typ of
TLit _ -> typ
TVar tvar'
| tvar' == tvar1 -> TVar tvar2
| otherwise -> typ
TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t -> TAll tvar $ substitute' t
TIndexed (Indexed name typs) -> TIndexed . Indexed name $ map substitute' typs
_ -> error "Impossible"
where
substitute' = substitute tvar1 tvar2
-- | Create a new name and add it to name environment. -- | Create a new name and add it to name environment.
newName :: Names -> Ident -> Rn (Names, Ident) newName :: Names -> LIdent -> Rn (Names, LIdent)
newName env old_name = do newName env old_name = do
new_name <- makeName old_name new_name <- makeName old_name
pure (Map.insert old_name new_name env, new_name) pure (Map.insert old_name new_name env, new_name)
-- | Create multiple names and add them to the name environment -- | Create multiple names and add them to the name environment
newNames :: Names -> [Ident] -> Rn (Names, [Ident]) newNames :: Names -> [LIdent] -> Rn (Names, [LIdent])
newNames = mapAccumM newName newNames = mapAccumM newName
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@. -- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
makeName :: Ident -> Rn Ident makeName :: LIdent -> Rn LIdent
makeName (Ident prefix) = gets (\i -> Ident $ prefix ++ "_" ++ show i) <* modify succ makeName (LIdent prefix) = do
i <- gets var_counter
let name = LIdent $ prefix ++ "_" ++ show i
modify $ \cxt -> cxt{var_counter = succ cxt.var_counter}
pure name
nextNameTVar :: TVar -> Rn TVar
nextNameTVar (MkTVar (LIdent s)) = do
i <- gets tvar_counter
let tvar = MkTVar $ coerce $ s ++ "_" ++ show i
modify $ \cxt -> cxt{tvar_counter = succ cxt.tvar_counter}
pure tvar

View file

@ -7,19 +7,27 @@ module TypeChecker.TypeChecker where
import Control.Monad.Except import Control.Monad.Except
import Control.Monad.Reader import Control.Monad.Reader
import Control.Monad.State import Control.Monad.State
import Data.Bifunctor (second)
import Data.Coerce (coerce)
import Data.Foldable (traverse_) import Data.Foldable (traverse_)
import Data.Functor.Identity (runIdentity) import Data.Functor.Identity (runIdentity)
import Data.List (foldl') import Data.List (foldl')
import Data.List.Extra (unsnoc)
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as M import Data.Map qualified as M
import Data.Set (Set) import Data.Set (Set)
import qualified Data.Set as S import Data.Set qualified as S
import Debug.Trace (trace) import Debug.Trace (trace)
import Grammar.Abs import Grammar.Abs
import Grammar.Print (printTree) import Grammar.Print (printTree)
import qualified TypeChecker.TypeCheckerIr as T import TypeChecker.TypeCheckerIr (
import TypeChecker.TypeCheckerIr (Ctx (..), Env (..), Error, Infer, Ctx (..),
Poly (..), Subst) Env (..),
Error,
Infer,
Subst,
)
import TypeChecker.TypeCheckerIr qualified as T
initCtx = Ctx mempty initCtx = Ctx mempty
@ -37,51 +45,17 @@ runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e
typecheck :: Program -> Either Error T.Program typecheck :: Program -> Either Error T.Program
typecheck = run . checkPrg typecheck = run . checkPrg
{- | Start by freshening the type variable of data types to avoid clash with
other user defined polymorphic types
This might be wrong for type constructors that work over several variables
-}
freshenData :: Data -> Infer Data
freshenData (Data (Constr name ts) constrs) = do
fr <- fresh
let fr' = case fr of
TPol a -> a
-- Meh, this part assumes fresh generates a polymorphic type
_ ->
error
"Bug: implementation of \
\ fresh and freshenData are not compatible"
let new_ts = map (freshenType fr') ts
let new_constrs = map (freshenConstr fr') constrs
return $ Data (Constr name new_ts) new_constrs
{- | Freshen all polymorphic variables, regardless of name
| freshenType "d" (a -> b -> c) becomes (d -> d -> d)
-}
freshenType :: Ident -> Type -> Type
freshenType iden = \case
(TPol _) -> TPol iden
(TArr a b) -> TArr (freshenType iden a) (freshenType iden b)
(TConstr (Constr a ts)) ->
TConstr (Constr a (map (freshenType iden) ts))
rest -> rest
freshenConstr :: Ident -> Constructor -> Constructor
freshenConstr iden (Constructor name t) =
Constructor name (freshenType iden t)
checkData :: Data -> Infer () checkData :: Data -> Infer ()
checkData d = do checkData d = do
d' <- freshenData d case d of
case d' of (Data typ@(Indexed name ts) constrs) -> do
(Data typ@(Constr name ts) constrs) -> do
unless unless
(all isPoly ts) (all isPoly ts)
(throwError $ unwords ["Data type incorrectly declared"]) (throwError $ unwords ["Data type incorrectly declared"])
traverse_ traverse_
( \(Constructor name' t') -> ( \(Constructor name' t') ->
if TConstr typ == retType t' if TIndexed typ == retType t'
then insertConstr name' t' then insertConstr (coerce name') (toNew t')
else else
throwError $ throwError $
unwords unwords
@ -96,18 +70,29 @@ checkData d = do
constrs constrs
retType :: Type -> Type retType :: Type -> Type
retType (TArr _ t2) = retType t2 retType (TFun _ t2) = retType t2
retType a = a retType a = a
checkPrg :: Program -> Infer T.Program checkPrg :: Program -> Infer T.Program
checkPrg (Program bs) = do checkPrg (Program bs) = do
preRun bs preRun bs
T.Program <$> checkDef bs -- Type check the program twice to produce all top-level types in the first pass through
bs' <- checkDef bs
trace "\nFIRST ITERATION" return ()
trace (printTree bs' ++ "\nSECOND ITERATION\n") return ()
bs'' <- checkDef bs
return $ T.Program bs''
where where
preRun :: [Def] -> Infer () preRun :: [Def] -> Infer ()
preRun [] = return () preRun [] = return ()
preRun (x : xs) = case x of preRun (x : xs) = case x of
DBind (Bind n t _ _ _) -> insertSig n t >> preRun xs -- TODO: Check for no overlapping signature definitions
DSig (Sig n t) -> insertSig (coerce n) (Just $ toNew t) >> preRun xs
DBind (Bind n _ _) -> do
s <- gets sigs
case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs
DData d@(Data _ _) -> checkData d >> preRun xs DData d@(Data _ _) -> checkData d >> preRun xs
checkDef :: [Def] -> Infer [T.Def] checkDef :: [Def] -> Infer [T.Def]
@ -117,79 +102,75 @@ checkPrg (Program bs) = do
b' <- checkBind b b' <- checkBind b
fmap (T.DBind b' :) (checkDef xs) fmap (T.DBind b' :) (checkDef xs)
(DData d) -> fmap (T.DData d :) (checkDef xs) (DData d) -> fmap (T.DData d :) (checkDef xs)
(DSig _) -> checkDef xs
checkBind :: Bind -> Infer T.Bind checkBind :: Bind -> Infer T.Bind
checkBind (Bind n t _ args e) = do checkBind (Bind name args e) = do
(t', e') <- inferExp $ makeLambda e (reverse args) let lambda = makeLambda e (reverse $ coerce args)
s <- unify t t' e@(_, t') <- inferExp lambda
let t'' = apply s t s <- gets sigs
unless case M.lookup (coerce name) s of
(t `typeEq` t'') Just (Just t) -> do
( throwError $ sub <- unify t t'
unwords let newT = apply sub t
[ "Top level signature" insertSig (coerce name) (Just newT)
, printTree t return $ T.Bind (coerce name, newT) [] e
, "does not match body with inferred type:" _ -> do
, printTree t'' insertSig (coerce name) (Just t')
] return (T.Bind (coerce name, t') [] e) -- (apply s e)
)
return $ T.Bind (n, t) e'
where where
makeLambda :: Exp -> [Ident] -> Exp makeLambda :: Exp -> [Ident] -> Exp
makeLambda = foldl (flip EAbs) makeLambda = foldl (flip (EAbs . coerce))
{- | Check if two types are considered equal isMoreSpecificOrEq :: T.Type -> T.Type -> Bool
For the purpose of the algorithm two polymorphic types are always considered isMoreSpecificOrEq _ (T.TAll _ _) = True
equal isMoreSpecificOrEq (T.TFun a b) (T.TFun c d) =
-}
typeEq :: Type -> Type -> Bool
typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r'
typeEq (TMono a) (TMono b) = a == b
typeEq (TConstr (Constr name a)) (TConstr (Constr name' b)) =
length a == length b
&& name == name'
&& and (zipWith typeEq a b)
typeEq (TPol _) (TPol _) = True
typeEq _ _ = False
isMoreSpecificOrEq :: Type -> Type -> Bool
isMoreSpecificOrEq _ (TPol _) = True
isMoreSpecificOrEq (TArr a b) (TArr c d) =
isMoreSpecificOrEq a c && isMoreSpecificOrEq b d isMoreSpecificOrEq a c && isMoreSpecificOrEq b d
isMoreSpecificOrEq (TConstr (Constr n1 ts1)) (TConstr (Constr n2 ts2)) = isMoreSpecificOrEq (T.TIndexed (T.Indexed n1 ts1)) (T.TIndexed (T.Indexed n2 ts2)) =
n1 == n2 n1 == n2
&& length ts1 == length ts2 && length ts1 == length ts2
&& and (zipWith isMoreSpecificOrEq ts1 ts2) && and (zipWith isMoreSpecificOrEq ts1 ts2)
isMoreSpecificOrEq a b = a == b isMoreSpecificOrEq a b = a == b
isPoly :: Type -> Bool isPoly :: Type -> Bool
isPoly (TPol _) = True isPoly (TAll _ _) = True
isPoly (TVar _) = True
isPoly _ = False isPoly _ = False
inferExp :: Exp -> Infer (Type, T.Exp) inferExp :: Exp -> Infer T.ExpT
inferExp e = do inferExp e = do
(s, t, e') <- algoW e (s, (e', t)) <- algoW e
let subbed = apply s t let subbed = apply s t
return (subbed, replace subbed e') return $ replace subbed (e', t)
replace :: Type -> T.Exp -> T.Exp replace :: T.Type -> T.ExpT -> T.ExpT
replace t = \case replace t = second (const t)
T.ELit _ e -> T.ELit t e
T.EId (n, _) -> T.EId (n, t)
T.EAbs _ name e -> T.EAbs t name e
T.EApp _ e1 e2 -> T.EApp t e1 e2
T.EAdd _ e1 e2 -> T.EAdd t e1 e2
T.ESub _ e1 e2 -> T.ESub t e1 e2
T.ELet (T.Bind (n, _) e1) e2 -> T.ELet (T.Bind (n, t) e1) e2
T.ECase _ expr injs -> T.ECase t expr injs
algoW :: Exp -> Infer (Subst, Type, T.Exp) class NewType a b where
toNew :: a -> b
instance NewType Type T.Type where
toNew = \case
TLit i -> T.TLit $ coerce i
TVar v -> T.TVar $ toNew v
TFun t1 t2 -> T.TFun (toNew t1) (toNew t2)
TAll b t -> T.TAll (toNew b) (toNew t)
TIndexed i -> T.TIndexed (toNew i)
TEVar _ -> error "Should not exist after typechecker"
instance NewType Indexed T.Indexed where
toNew (Indexed name vars) = T.Indexed (coerce name) (map toNew vars)
instance NewType TVar T.TVar where
toNew (MkTVar i) = T.MkTVar $ coerce i
algoW :: Exp -> Infer (Subst, T.ExpT)
algoW = \case algoW = \case
-- \| TODO: More testing need to be done. Unsure of the correctness of this -- \| TODO: More testing need to be done. Unsure of the correctness of this
EAnn e t -> do EAnn e t -> do
(s1, t', e') <- algoW e (s1, (e', t')) <- algoW e
unless unless
(t `isMoreSpecificOrEq` t') (toNew t `isMoreSpecificOrEq` t')
( throwError $ ( throwError $
unwords unwords
[ "Annotated type:" [ "Annotated type:"
@ -199,34 +180,34 @@ algoW = \case
] ]
) )
applySt s1 $ do applySt s1 $ do
s2 <- unify t t' s2 <- unify (toNew t) t'
return (s2 `compose` s1, t, e') let comp = s2 `compose` s1
return (comp, apply comp (e', toNew t))
-- \| ------------------ -- \| ------------------
-- \| Γ ⊢ i : Int, ∅ -- \| Γ ⊢ i : Int, ∅
ELit (LInt n) -> ELit lit ->
return (nullSubst, TMono "Int", T.ELit (TMono "Int") (LInt n)) let lt = litType lit
ELit a -> error $ "NOT IMPLEMENTED YET: ELit " ++ show a in return (nullSubst, (T.ELit lit, lt))
-- \| x : σ ∈ Γ τ = inst(σ) -- \| x : σ ∈ Γ τ = inst(σ)
-- \| ---------------------- -- \| ----------------------
-- \| Γ ⊢ x : τ, ∅ -- \| Γ ⊢ x : τ, ∅
EVar i -> do
EId i -> do
var <- asks vars var <- asks vars
case M.lookup i var of case M.lookup (coerce i) var of
Just t -> inst t >>= \x -> return (nullSubst, x, T.EId (i, x)) Just t -> inst t >>= \x -> return (nullSubst, (T.EId $ coerce i, x))
Nothing -> do Nothing -> do
sig <- gets sigs sig <- gets sigs
case M.lookup i sig of case M.lookup (coerce i) sig of
Just t -> return (nullSubst, t, T.EId (i, t)) Just (Just t) -> return (nullSubst, (T.EId $ coerce i, t))
Nothing -> do Just Nothing -> (\x -> (nullSubst, (T.EId $ coerce i, x))) <$> fresh
Nothing -> throwError $ "Unbound variable: " ++ printTree i
ECons i -> do
constr <- gets constructors constr <- gets constructors
case M.lookup i constr of case M.lookup (coerce i) constr of
Just t -> return (nullSubst, t, T.EId (i, t)) Just t -> return (nullSubst, (T.EId $ coerce i, t))
Nothing -> Nothing -> throwError $ "Constructor: '" ++ printTree i ++ "' is not defined"
throwError $
"Unbound variable: " ++ show i
-- \| τ = newvar Γ, x : τ ⊢ e : τ', S -- \| τ = newvar Γ, x : τ ⊢ e : τ', S
-- \| --------------------------------- -- \| ---------------------------------
@ -234,11 +215,11 @@ algoW = \case
EAbs name e -> do EAbs name e -> do
fr <- fresh fr <- fresh
withBinding name (Forall [] fr) $ do withBinding (coerce name) fr $ do
(s1, t', e') <- algoW e (s1, (e', t')) <- algoW e
let varType = apply s1 fr let varType = apply s1 fr
let newArr = TArr varType t' let newArr = T.TFun varType t'
return (s1, newArr, T.EAbs newArr (name, varType) e') return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr))
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁ -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
-- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int) -- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
@ -247,29 +228,16 @@ algoW = \case
-- This might be wrong -- This might be wrong
EAdd e0 e1 -> do EAdd e0 e1 -> do
(s1, t0, e0') <- algoW e0 (s1, (e0', t0)) <- algoW e0
applySt s1 $ do applySt s1 $ do
(s2, t1, e1') <- algoW e1 (s2, (e1', t1)) <- algoW e1
-- applySt s2 $ do -- applySt s2 $ do
s3 <- unify (apply s2 t0) (TMono "Int") s3 <- unify (apply s2 t0) int
s4 <- unify (apply s3 t1) (TMono "Int") s4 <- unify (apply s3 t1) int
let comp = s4 `compose` s3 `compose` s2 `compose` s1
return return
( s4 `compose` s3 `compose` s2 `compose` s1 ( comp
, TMono "Int" , apply comp (T.EAdd (e0', t0) (e1', t1), int)
, T.EAdd (TMono "Int") e0' e1'
)
ESub e0 e1 -> do
(s1, t0, e0') <- algoW e0
applySt s1 $ do
(s2, t1, e1') <- algoW e1
-- applySt s2 $ do
s3 <- unify (apply s2 t0) (TMono "Int")
s4 <- unify (apply s3 t1) (TMono "Int")
return
( s4 `compose` s3 `compose` s2 `compose` s1
, TMono "Int"
, T.ESub (TMono "Int") e0' e1'
) )
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1 -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1
@ -279,13 +247,13 @@ algoW = \case
EApp e0 e1 -> do EApp e0 e1 -> do
fr <- fresh fr <- fresh
(s0, t0, e0') <- algoW e0 (s0, (e0', t0)) <- algoW e0
applySt s0 $ do applySt s0 $ do
(s1, t1, e1') <- algoW e1 (s1, (e1', t1)) <- algoW e1
-- applySt s1 $ do s2 <- unify (apply s1 t0) (T.TFun t1 fr)
s2 <- unify (apply s1 t0) (TArr t1 fr)
let t = apply s2 fr let t = apply s2 fr
return (s2 `compose` s1 `compose` s0, t, T.EApp t e0' e1') let comp = s2 `compose` s1 `compose` s0
return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t))
-- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁ -- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁
-- \| ---------------------------------------------- -- \| ----------------------------------------------
@ -294,39 +262,37 @@ algoW = \case
-- The bar over S₀ and Γ means "generalize" -- The bar over S₀ and Γ means "generalize"
ELet name e0 e1 -> do ELet name e0 e1 -> do
(s1, t1, e0') <- algoW e0 (s1, (e0', t1)) <- algoW e0
env <- asks vars env <- asks vars
let t' = generalize (apply s1 env) t1 let t' = generalize (apply s1 env) t1
withBinding name t' $ do withBinding (coerce name) t' $ do
(s2, t2, e1') <- algoW e1 (s2, (e1', t2)) <- algoW e1
return (s2 `compose` s1, t2, T.ELet (T.Bind (name, t2) e0') e1') let comp = s2 `compose` s1
return (comp, apply comp (T.ELet (T.Bind (coerce name, t2) [] (e0', t1)) (e1', t2), t2))
-- \| TODO: Add judgement
ECase caseExpr injs -> do ECase caseExpr injs -> do
(_, t0, e0') <- algoW caseExpr (sub, (e', t)) <- algoW caseExpr
(injs', ts) <- mapAndUnzipM (checkInj t0) injs (subst, injs, ret_t) <- checkCase t injs
case ts of let comp = subst `compose` sub
[] -> throwError "Case expression missing any matches" let t' = apply comp ret_t
ts -> do return (comp, (T.ECase (e', t) injs, t'))
unified <- zipWithM unify ts (tail ts)
let unified' = foldl' compose mempty unified
let typ = apply unified' (head ts)
return (unified', typ, T.ECase typ e0' injs')
-- | Unify two types producing a new substitution -- | Unify two types producing a new substitution
unify :: Type -> Type -> Infer Subst unify :: T.Type -> T.Type -> Infer Subst
unify t0 t1 = do unify t0 t1 = do
trace ("t0: " ++ show t0) return ()
trace ("t1: " ++ show t1) return ()
case (t0, t1) of case (t0, t1) of
(TArr a b, TArr c d) -> do (T.TFun a b, T.TFun c d) -> do
s1 <- unify a c s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d) s2 <- unify (apply s1 b) (apply s1 d)
return $ s1 `compose` s2 return $ s1 `compose` s2
(TPol a, b) -> occurs a b (T.TVar (T.MkTVar a), t) -> occurs a t
(a, TPol b) -> occurs b a (t, T.TVar (T.MkTVar b)) -> occurs b t
(TMono a, TMono b) -> (T.TAll _ t, b) -> unify t b
(a, T.TAll _ t) -> unify a t
(T.TLit a, T.TLit b) ->
if a == b then return M.empty else throwError "Types do not unify" if a == b then return M.empty else throwError "Types do not unify"
-- \| TODO: Figure out a cleaner way to express the same thing (T.TIndexed (T.Indexed name t), T.TIndexed (T.Indexed name' t')) ->
(TConstr (Constr name t), TConstr (Constr name' t')) ->
if name == name' && length t == length t' if name == name' && length t == length t'
then do then do
xs <- zipWithM unify t t' xs <- zipWithM unify t t'
@ -334,56 +300,71 @@ unify t0 t1 = do
else else
throwError $ throwError $
unwords unwords
[ "Type constructor:" [ "T.Type constructor:"
, printTree name , printTree name
, "(" ++ printTree t ++ ")" , "(" ++ printTree t ++ ")"
, "does not match with:" , "does not match with:"
, printTree name' , printTree name'
, "(" ++ printTree t' ++ ")" , "(" ++ printTree t' ++ ")"
] ]
(a, b) -> (a, b) -> do
throwError . unwords $ throwError . unwords $
[ "Type:" [ "'" ++ printTree a ++ "'"
, printTree a , "can't be unified with"
, "can't be unified with:" , "'" ++ printTree b ++ "'"
, printTree b
] ]
{- | Check if a type is contained in another type. {- | Check if a type is contained in another type.
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
such that these are equal where these are equal
-} -}
occurs :: Ident -> Type -> Infer Subst occurs :: Ident -> T.Type -> Infer Subst
occurs _ (TPol _) = return nullSubst occurs i t@(T.TVar _) = return (M.singleton i t)
occurs i t = occurs i t =
if S.member i (free t) if S.member i (free t)
then then
throwError $ throwError $
unwords unwords
[ "Occurs check failed, can't unify" [ "Occurs check failed, can't unify"
, printTree (TPol i) , printTree (T.TVar $ T.MkTVar i)
, "with" , "with"
, printTree t , printTree t
] ]
else return $ M.singleton i t else return $ M.singleton i t
-- | Generalize a type over all free variables in the substitution set -- | Generalize a type over all free variables in the substitution set
generalize :: Map Ident Poly -> Type -> Poly generalize :: Map Ident T.Type -> T.Type -> T.Type
generalize env t = Forall (S.toList $ free t S.\\ free env) t generalize env t = go freeVars $ removeForalls t
where
freeVars :: [Ident]
freeVars = S.toList $ free t S.\\ free env
go :: [Ident] -> T.Type -> T.Type
go [] t = t
go (x : xs) t = T.TAll (T.MkTVar x) (go xs t)
removeForalls :: T.Type -> T.Type
removeForalls (T.TAll _ t) = removeForalls t
removeForalls (T.TFun t1 t2) = T.TFun (removeForalls t1) (removeForalls t2)
removeForalls t = t
{- | Instantiate a polymorphic type. The free type variables are substituted {- | Instantiate a polymorphic type. The free type variables are substituted
with fresh ones. with fresh ones.
-} -}
inst :: Poly -> Infer Type inst :: T.Type -> Infer T.Type
inst (Forall xs t) = do inst = \case
xs' <- mapM (const fresh) xs T.TAll (T.MkTVar bound) t -> do
let s = M.fromList $ zip xs xs' fr <- fresh
return $ apply s t let s = M.singleton bound fr
apply s <$> inst t
T.TFun t1 t2 -> T.TFun <$> inst t1 <*> inst t2
rest -> return rest
-- | Compose two substitution sets -- | Compose two substitution sets
compose :: Subst -> Subst -> Subst compose :: Subst -> Subst -> Subst
compose m1 m2 = M.map (apply m1) m2 `M.union` m1 compose m1 m2 = M.map (apply m1) m2 `M.union` m1
-- TODO: Split this class into two separate classes, one for free variables
-- and one for applying substitutions
-- | A class representing free variables functions -- | A class representing free variables functions
class FreeVars t where class FreeVars t where
-- | Get all free variables from t -- | Get all free variables from t
@ -392,37 +373,59 @@ class FreeVars t where
-- | Apply a substitution to t -- | Apply a substitution to t
apply :: Subst -> t -> t apply :: Subst -> t -> t
instance FreeVars Type where instance FreeVars T.Type where
free :: Type -> Set Ident free :: T.Type -> Set Ident
free (TPol a) = S.singleton a free (T.TVar (T.MkTVar a)) = S.singleton a
free (TMono _) = mempty free (T.TAll (T.MkTVar bound) t) = S.singleton bound `S.intersection` free t
free (TArr a b) = free a `S.union` free b free (T.TLit _) = mempty
free (T.TFun a b) = free a `S.union` free b
-- \| Not guaranteed to be correct -- \| Not guaranteed to be correct
free (TConstr (Constr _ a)) = free (T.TIndexed (T.Indexed _ a)) =
foldl' (\acc x -> free x `S.union` acc) S.empty a foldl' (\acc x -> free x `S.union` acc) S.empty a
apply :: Subst -> Type -> Type apply :: Subst -> T.Type -> T.Type
apply sub t = do apply sub t = do
case t of case t of
TMono a -> TMono a T.TLit a -> T.TLit a
TPol a -> case M.lookup a sub of T.TVar (T.MkTVar a) -> case M.lookup a sub of
Nothing -> TPol a Nothing -> T.TVar (T.MkTVar $ coerce a)
Just t -> t Just t -> t
TArr a b -> TArr (apply sub a) (apply sub b) T.TAll (T.MkTVar i) t -> case M.lookup i sub of
TConstr (Constr name a) -> TConstr (Constr name (map (apply sub) a)) Nothing -> T.TAll (T.MkTVar i) (apply sub t)
Just _ -> apply sub t
T.TFun a b -> T.TFun (apply sub a) (apply sub b)
T.TIndexed (T.Indexed name a) -> T.TIndexed (T.Indexed name (map (apply sub) a))
instance FreeVars Poly where instance FreeVars (Map Ident T.Type) where
free :: Poly -> Set Ident free :: Map Ident T.Type -> Set Ident
free (Forall xs t) = free t S.\\ S.fromList xs
apply :: Subst -> Poly -> Poly
apply s (Forall xs t) = Forall xs (apply (foldr M.delete s xs) t)
instance FreeVars (Map Ident Poly) where
free :: Map Ident Poly -> Set Ident
free m = foldl' S.union S.empty (map free $ M.elems m) free m = foldl' S.union S.empty (map free $ M.elems m)
apply :: Subst -> Map Ident Poly -> Map Ident Poly apply :: Subst -> Map Ident T.Type -> Map Ident T.Type
apply s = M.map (apply s) apply s = M.map (apply s)
instance FreeVars T.ExpT where
free :: T.ExpT -> Set Ident
free = error "free not implemented for T.Exp"
apply :: Subst -> T.ExpT -> T.ExpT
apply s = \case
(T.EId i, outerT) -> (T.EId i, apply s outerT)
(T.ELit lit, t) -> (T.ELit lit, apply s t)
(T.ELet (T.Bind (ident, t1) args e1) e2, t2) -> (T.ELet (T.Bind (ident, apply s t1) args (apply s e1)) (apply s e2), apply s t2)
(T.EApp e1 e2, t) -> (T.EApp (apply s e1) (apply s e2), apply s t)
(T.EAdd e1 e2, t) -> (T.EAdd (apply s e1) (apply s e2), apply s t)
(T.EAbs ident e, t1) -> (T.EAbs ident (apply s e), apply s t1)
(T.ECase e injs, t) -> (T.ECase (apply s e) (apply s injs), apply s t)
instance FreeVars T.Inj where
free :: T.Inj -> Set Ident
free = undefined
apply :: Subst -> T.Inj -> T.Inj
apply s (T.Inj (i, t) e) = T.Inj (i, apply s t) (apply s e)
instance FreeVars [T.Inj] where
free :: [T.Inj] -> Set Ident
free = foldl' (\acc x -> free x `S.union` acc) mempty
apply s = map (apply s)
-- | Apply substitutions to the environment. -- | Apply substitutions to the environment.
applySt :: Subst -> Infer a -> Infer a applySt :: Subst -> Infer a -> Infer a
applySt s = local (\st -> st{vars = apply s (vars st)}) applySt s = local (\st -> st{vars = apply s (vars st)})
@ -432,86 +435,85 @@ nullSubst :: Subst
nullSubst = M.empty nullSubst = M.empty
-- | Generate a new fresh variable and increment the state counter -- | Generate a new fresh variable and increment the state counter
fresh :: Infer Type fresh :: Infer T.Type
fresh = do fresh = do
n <- gets count n <- gets count
modify (\st -> st{count = n + 1}) modify (\st -> st{count = n + 1})
return . TPol . Ident $ show n return . T.TVar . T.MkTVar . Ident $ show n
-- | Run the monadic action with an additional binding -- | Run the monadic action with an additional binding
withBinding :: (Monad m, MonadReader Ctx m) => Ident -> Poly -> m a -> m a withBinding :: (Monad m, MonadReader Ctx m) => Ident -> T.Type -> m a -> m a
withBinding i p = local (\st -> st{vars = M.insert i p (vars st)}) withBinding i p = local (\st -> st{vars = M.insert i p (vars st)})
-- | Run the monadic action with several additional bindings
withBindings :: (Monad m, MonadReader Ctx m) => [(Ident, T.Type)] -> m a -> m a
withBindings xs =
local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs})
-- | Insert a function signature into the environment -- | Insert a function signature into the environment
insertSig :: Ident -> Type -> Infer () insertSig :: Ident -> Maybe T.Type -> Infer ()
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)}) insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
-- | Insert a constructor with its data type -- | Insert a constructor with its data type
insertConstr :: Ident -> Type -> Infer () insertConstr :: Ident -> T.Type -> Infer ()
insertConstr i t = insertConstr i t =
modify (\st -> st{constructors = M.insert i t (constructors st)}) modify (\st -> st{constructors = M.insert i t (constructors st)})
-------- PATTERN MATCHING --------- -------- PATTERN MATCHING ---------
-- "case expr of", the type of 'expr' is caseType checkCase :: T.Type -> [Inj] -> Infer (Subst, [T.Inj], T.Type)
checkInj :: Type -> Inj -> Infer (T.Inj, Type) checkCase expT injs = do
checkInj caseType (Inj it expr) = do (injTs, injs, returns) <- unzip3 <$> mapM checkInj injs
(args, t') <- initType caseType it (sub1, _) <-
(_, t, e') <- local (\st -> st{vars = args `M.union` vars st}) (algoW expr) foldM
return (T.Inj (it, t') e', t) ( \(sub, acc) x ->
(\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc
)
(nullSubst, expT)
injTs
(sub2, returns_type) <-
foldM
( \(sub, acc) x ->
(\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc
)
(nullSubst, head returns)
(tail returns)
return (sub2 `compose` sub1, injs, returns_type)
initType :: Type -> Init -> Infer (Map Ident Poly, Type) {- | fst = type of init
initType expected = \case | snd = type of expr
InitLit lit -> -}
let returnType = litType lit checkInj :: Inj -> Infer (T.Type, T.Inj, T.Type)
in if expected == returnType checkInj (Inj it expr) = do
then return (mempty, expected) (initT, vars) <- inferInit it
else (e, exprT) <- withBindings vars (inferExp expr)
throwError $ return (initT, T.Inj (it, initT) (e, exprT), exprT)
unwords
[ "Inferred type" inferInit :: Init -> Infer (T.Type, [T.Id])
, printTree returnType inferInit = \case
, "does not match expected type:" InitLit lit -> return (litType lit, mempty)
, printTree expected InitConstructor fn vars -> do
] gets (M.lookup (coerce fn) . constructors) >>= \case
InitConstr c args -> do
st <- gets constructors
case M.lookup c st of
Nothing -> Nothing ->
throwError $ throwError $
unwords "Constructor: " ++ printTree fn ++ " does not exist"
[ "Constructor:" Just a -> do
, printTree c case unsnoc $ flattenType a of
, "does not exist" Nothing -> throwError "Partial pattern match not allowed"
] Just (vs, ret) ->
Just t -> do case length vars `compare` length vs of
let flat = flattenType t EQ -> do
let returnType = last flat return (ret, zip (coerce vars) vs)
case ( length (init flat) == length args _ -> throwError "Partial pattern match not allowed"
, returnType `isMoreSpecificOrEq` expected InitCatch -> (,mempty) <$> fresh
) of
(True, True) ->
return
( M.fromList $ zip args (map (Forall []) flat)
, expected
)
(False, _) ->
throwError $
"Can't partially match on the constructor: "
++ printTree c
(_, False) ->
throwError $
unwords
[ "Inferred type"
, printTree returnType
, "does not match expected type:"
, printTree expected
]
InitCatch -> return (mempty, expected)
flattenType :: Type -> [Type] flattenType :: T.Type -> [T.Type]
flattenType (TArr a b) = flattenType a ++ flattenType b flattenType (T.TFun a b) = flattenType a ++ flattenType b
flattenType a = [a] flattenType a = [a]
litType :: Literal -> Type litType :: Lit -> T.Type
litType (LInt _) = TMono "Int" litType (LInt _) = int
litType (LChar _) = char
int = T.TLit "Int"
char = T.TLit "Char"

View file

@ -7,23 +7,25 @@ import Control.Monad.Reader
import Control.Monad.State import Control.Monad.State
import Data.Functor.Identity (Identity) import Data.Functor.Identity (Identity)
import Data.Map (Map) import Data.Map (Map)
import Grammar.Abs (Data (..), Ident (..), Init (..), import Grammar.Abs (
Literal (..), Type (..)) Data (..),
Ident (..),
Init (..),
Lit (..),
)
import Grammar.Print import Grammar.Print
import Prelude import Prelude
import qualified Prelude as C (Eq, Ord, Read, Show) import Prelude qualified as C (Eq, Ord, Read, Show)
-- | A data type representing type variables newtype Ctx = Ctx {vars :: Map Ident Type}
data Poly = Forall [Ident] Type
deriving (Show) deriving (Show)
newtype Ctx = Ctx {vars :: Map Ident Poly}
data Env = Env data Env = Env
{ count :: Int { count :: Int
, sigs :: Map Ident Type , sigs :: Map Ident (Maybe Type)
, constructors :: Map Ident Type , constructors :: Map Ident Type
} }
deriving (Show)
type Error = String type Error = String
type Subst = Map Ident Type type Subst = Map Ident Type
@ -33,18 +35,33 @@ type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))
newtype Program = Program [Def] newtype Program = Program [Def]
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
newtype TVar = MkTVar Ident
deriving (Show, Eq, Ord, Read)
data Type
= TLit Ident
| TVar TVar
| TFun Type Type
| TAll TVar Type
| TIndexed Indexed
deriving (Show, Eq, Ord, Read)
data Exp data Exp
= EId Id = EId Ident
| ELit Type Literal | ELit Lit
| ELet Bind Exp | ELet Bind ExpT
| EApp Type Exp Exp | EApp ExpT ExpT
| EAdd Type Exp Exp | EAdd ExpT ExpT
| ESub Type Exp Exp | EAbs Ident ExpT
| EAbs Type Id Exp | ECase ExpT [Inj]
| ECase Type Exp [Inj]
deriving (C.Eq, C.Ord, C.Read, C.Show) deriving (C.Eq, C.Ord, C.Read, C.Show)
data Inj = Inj (Init, Type) Exp type ExpT = (Exp, Type)
data Indexed = Indexed Ident [Type]
deriving (Show, Read, Ord, Eq)
data Inj = Inj (Init, Type) ExpT
deriving (C.Eq, C.Ord, C.Read, C.Show) deriving (C.Eq, C.Ord, C.Read, C.Show)
data Def = DBind Bind | DData Data data Def = DBind Bind | DData Data
@ -52,12 +69,12 @@ data Def = DBind Bind | DData Data
type Id = (Ident, Type) type Id = (Ident, Type)
data Bind = Bind Id Exp data Bind = Bind Id [Id] ExpT
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
instance Print [Def] where instance Print [Def] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ (x : xs) = concatD [prt 0 x, doc (showString "\n"), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString "\n\n"), prt 0 xs]
instance Print Def where instance Print Def where
prt i (DBind bind) = prt i bind prt i (DBind bind) = prt i bind
@ -67,7 +84,7 @@ instance Print Program where
prt i (Program sc) = prPrec i 0 $ prt 0 sc prt i (Program sc) = prPrec i 0 $ prt 0 sc
instance Print Bind where instance Print Bind where
prt i (Bind (t, name) rhs) = prt i (Bind (name, t) _ rhs) =
prPrec i 0 $ prPrec i 0 $
concatD concatD
[ prt 0 name [ prt 0 name
@ -91,9 +108,11 @@ prtId :: Int -> Id -> Doc
prtId i (name, t) = prtId i (name, t) =
prPrec i 0 $ prPrec i 0 $
concatD concatD
[ prt 0 name [ doc $ showString "("
, prt 0 name
, doc $ showString ":" , doc $ showString ":"
, prt 0 t , prt 0 t
, doc $ showString ")"
] ]
prtIdP :: Int -> Id -> Doc prtIdP :: Int -> Id -> Doc
@ -109,8 +128,8 @@ prtIdP i (name, t) =
instance Print Exp where instance Print Exp where
prt i = \case prt i = \case
EId n -> prPrec i 3 $ concatD [prtId 0 n, doc $ showString "\n"] EId n -> prPrec i 3 $ concatD [prt 0 n]
ELit _ (LInt i1) -> prPrec i 3 $ concatD [prt 0 i1, doc $ showString "\n"] ELit lit -> prPrec i 3 $ concatD [prt 0 lit]
ELet bs e -> ELet bs e ->
prPrec i 3 $ prPrec i 3 $
concatD concatD
@ -118,46 +137,30 @@ instance Print Exp where
, prt 0 bs , prt 0 bs
, doc $ showString "in" , doc $ showString "in"
, prt 0 e , prt 0 e
, doc $ showString "\n"
] ]
EApp _ e1 e2 -> EApp e1 e2 ->
prPrec i 2 $ prPrec i 2 $
concatD concatD
[ prt 2 e1 [ prt 2 e1
, prt 3 e2 , prt 3 e2
] ]
EAdd t e1 e2 -> EAdd e1 e2 ->
prPrec i 1 $ prPrec i 1 $
concatD concatD
[ doc $ showString "@" [ doc $ showString "@"
, prt 0 t
, prt 1 e1 , prt 1 e1
, doc $ showString "+" , doc $ showString "+"
, prt 2 e2 , prt 2 e2
, doc $ showString "\n"
] ]
ESub t e1 e2 -> EAbs n e ->
prPrec i 1 $
concatD
[ doc $ showString "@"
, prt 0 t
, prt 1 e1
, doc $ showString "-"
, prt 2 e2
, doc $ showString "\n"
]
EAbs t n e ->
prPrec i 0 $ prPrec i 0 $
concatD concatD
[ doc $ showString "@" [ doc $ showString "λ"
, prt 0 t , prt 0 n
, doc $ showString "\\"
, prtId 0 n
, doc $ showString "." , doc $ showString "."
, prt 0 e , prt 0 e
, doc $ showString "\n"
] ]
ECase t exp injs -> ECase exp injs ->
prPrec prPrec
i i
0 0
@ -169,11 +172,12 @@ instance Print Exp where
, prt 0 injs , prt 0 injs
, doc (showString "}") , doc (showString "}")
, doc (showString ":") , doc (showString ":")
, prt 0 t
, doc $ showString "\n"
] ]
) )
instance Print ExpT where
prt i (e, t) = concatD [doc $ showString "(", prt i e, doc (showString ":"), prt i t, doc $ showString ")"]
instance Print Inj where instance Print Inj where
prt i = \case prt i = \case
Inj (init, t) exp -> prPrec i 0 (concatD [prt 0 init, doc (showString ":"), prt 0 t, doc (showString "=>"), prt 0 exp]) Inj (init, t) exp -> prPrec i 0 (concatD [prt 0 init, doc (showString ":"), prt 0 t, doc (showString "=>"), prt 0 exp])
@ -182,3 +186,17 @@ instance Print [Inj] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x] prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
instance Print TVar where
prt i (MkTVar id) = prt i id
instance Print Type where
prt i = \case
TLit uident -> prPrec i 2 (concatD [prt 0 uident])
TVar tvar -> prPrec i 2 (concatD [prt 0 tvar])
TAll tvar type_ -> prPrec i 1 (concatD [doc (showString "forall"), prt 0 tvar, doc (showString "."), prt 0 type_])
TIndexed indexed -> prPrec i 1 (concatD [prt 0 indexed])
TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2])
instance Print Indexed where
prt i (Indexed u ts) = concatD [prt i u, prt i ts]

4
test_program Normal file
View file

@ -0,0 +1,4 @@
data Maybe (a) where {
Nothing : Maybe (a)
Just : a -> Maybe (a)
}

110
tests/Tests.hs Normal file
View file

@ -0,0 +1,110 @@
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use <$>" #-}
{-# HLINT ignore "Use camelCase" #-}
module Main where
import Data.Either (isLeft, isRight)
import Data.Map (Map)
import Data.Map qualified as M
import Grammar.Abs
import Test.Hspec
import Test.QuickCheck
import TypeChecker.TypeChecker
import TypeChecker.TypeCheckerIr (
Ctx (..),
Env (..),
Error,
Infer,
Poly (..),
)
import TypeChecker.TypeCheckerIr qualified as T
main :: IO ()
main = hspec $ do
infer_elit
infer_eann
infer_eid
infer_eabs
test_id_function
infer_elit = describe "algoW used on ELit" $ do
it "infers the type mono Int" $ do
getType (ELit (LInt 0)) `shouldBe` Right (T.TLit "Int")
it "infers the type mono Int" $ do
getType (ELit (LInt 9999)) `shouldBe` Right (T.TLit "Int")
infer_eann = describe "algoW used on EAnn" $ do
it "infers the type and checks if the annotated type matches" $ do
getType (EAnn (ELit $ LInt 0) (TLit "Int")) `shouldBe` Right (T.TLit "Int")
it "fails if the annotated type does not match with the inferred type" $ do
getType (EAnn (ELit $ LInt 0) (TVar $ MkTVar "a")) `shouldSatisfy` isLeft
it "should be possible to annotate with a more specific type" $ do
let annotated_lambda = EAnn (EAbs "x" (EVar "x")) (TFun (TLit "Int") (TLit "Int"))
in getType annotated_lambda `shouldBe` Right (T.TFun (T.TLit "Int") (T.TLit "Int"))
it "should fail if the annotated type is more general than the inferred type" $ do
getType (EAnn (ELit (LInt 0)) (TVar $ MkTVar "a")) `shouldSatisfy` isLeft
it "should fail if the annotated type is an arrow but the annotated type is not" $ do
getType (EAnn (EAbs "x" (EVar "x")) (TVar $ MkTVar "a")) `shouldSatisfy` isLeft
infer_eid = describe "algoW used on EVar" $ do
it "should fail if the variable is not added to the environment" $ do
property $ \x -> getType (EVar (LIdent (x :: String))) `shouldSatisfy` isLeft
it "should succeed if the type exist in the environment" $ do
property $ \x -> do
let env = Env 0 mempty mempty
let t = T.TVar $ T.MkTVar "a"
let ctx = Ctx (M.singleton (Ident (x :: String)) t)
getTypeC env ctx (EVar (LIdent x)) `shouldBe` Right (T.TVar $ T.MkTVar "a")
infer_eabs = describe "algoW used on EAbs" $ do
it "should infer the argument type as int if the variable is used as an int" $ do
let lambda = EAbs "x" (EAdd (EVar "x") (ELit (LInt 0)))
getType lambda `shouldBe` Right (T.TFun (T.TLit "Int") (T.TLit "Int"))
it "should infer the argument type as polymorphic if it is not used in the lambda" $ do
let lambda = EAbs "x" (ELit (LInt 0))
getType lambda `shouldSatisfy` isArrowPolyToMono
it "should infer a variable as function if used as one" $ do
let lambda = EAbs "f" (EAbs "x" (EApp (EVar "f") (EVar "x")))
let isOk (Right (T.TFun (T.TFun (T.TVar _) (T.TVar _)) (T.TFun (T.TVar _) (T.TVar _)))) = True
isOk _ = False
getType lambda `shouldSatisfy` isOk
churf_id :: Bind
churf_id = Bind "id" ["x"] (EVar "x")
churf_add :: Bind
churf_add = Bind "add" ["x", "y"] (EAdd (EVar "x") (EVar "y"))
churf_main :: Bind
churf_main = Bind "main" [] (EApp (EApp (EVar "id") (EVar "add")) (ELit (LInt 0)))
prg = Program [DBind churf_main, DBind churf_add, DBind churf_id]
test_id_function :: SpecWith ()
test_id_function =
describe "typechecking a program with id, add and main, where id is applied to add in main" $ do
it "should succeed to find the correct type" $ do
typecheck prg `shouldSatisfy` isRight
isArrowPolyToMono :: Either Error T.Type -> Bool
isArrowPolyToMono (Right (T.TFun (T.TVar _) (T.TLit _))) = True
isArrowPolyToMono _ = False
-- | Empty environment
getType :: Exp -> Either Error T.Type
getType e = pure snd <*> run (inferExp e)
-- | Custom environment
getTypeC :: Env -> Ctx -> Exp -> Either Error T.Type
getTypeC env ctx e = pure snd <*> runC env ctx (inferExp e)