Merge remote-tracking branch 'origin/typechecking-merge' into pattern-matching-with-typechecking

This commit is contained in:
Samuel Hammersberg 2023-03-23 16:33:05 +01:00
commit d3d173eb59
21 changed files with 1052 additions and 476 deletions

4
.gitignore vendored
View file

@ -4,5 +4,5 @@ dist-newstyle
*.bak
src/Grammar
language
llvm.ll
output
test_program_result
output/

View file

@ -1,51 +1,94 @@
-------------------------------------------------------------------------------
-- * PROGRAM
-------------------------------------------------------------------------------
Program. Program ::= [Def] ;
-------------------------------------------------------------------------------
-- * TOP-LEVEL
-------------------------------------------------------------------------------
DBind. Def ::= Bind ;
DSig. Def ::= Sig ;
DData. Def ::= Data ;
separator Def ";" ;
Bind. Bind ::= Ident ":" Type ";"
Ident [Ident] "=" Exp ;
Sig. Sig ::= LIdent ":" Type ;
Data. Data ::= "data" Constr "where" "{" [Constructor] "}" ;
Bind. Bind ::= LIdent [LIdent] "=" Exp ;
Constructor. Constructor ::= Ident ":" Type ;
separator nonempty Constructor "" ;
-------------------------------------------------------------------------------
-- * TYPES
-------------------------------------------------------------------------------
TMono. Type1 ::= Ident ;
TPol. Type1 ::= "'" Ident ;
TConstr. Type1 ::= Constr ;
TArr. Type ::= Type1 "->" Type ;
TLit. Type2 ::= UIdent ;
TVar. Type2 ::= TVar ;
TAll. Type1 ::= "forall" TVar "." Type ;
TIndexed. Type1 ::= Indexed ;
internal TEVar. Type1 ::= TEVar ;
TFun. Type ::= Type1 "->" Type ;
Constr. Constr ::= Ident "(" [Type] ")" ;
MkTVar. TVar ::= LIdent ;
internal MkTEVar. TEVar ::= LIdent ;
-------------------------------------------------------------------------------
-- * DATA TYPES
-------------------------------------------------------------------------------
Constructor. Constructor ::= UIdent ":" Type ;
Indexed. Indexed ::= UIdent "(" [Type] ")" ;
Data. Data ::= "data" Indexed "where" "{" [Constructor] "}" ;
-------------------------------------------------------------------------------
-- * EXPRESSIONS
-------------------------------------------------------------------------------
-- TODO: Move literal to its own thing since it's reused in Init as well.
EAnn. Exp5 ::= "(" Exp ":" Type ")" ;
EId. Exp4 ::= Ident ;
ELit. Exp4 ::= Literal ;
EVar. Exp4 ::= LIdent ;
ECons. Exp4 ::= UIdent ;
ELit. Exp4 ::= Lit ;
EApp. Exp3 ::= Exp3 Exp4 ;
EAdd. Exp1 ::= Exp1 "+" Exp2 ;
ESub. Exp1 ::= Exp1 "-" Exp2 ;
ELet. Exp ::= "let" Ident "=" Exp "in" Exp ;
EAbs. Exp ::= "\\" Ident "." Exp ;
ELet. Exp ::= "let" LIdent "=" Exp "in" Exp ;
EAbs. Exp ::= "\\" LIdent "." Exp ;
ECase. Exp ::= "case" Exp "of" "{" [Inj] "}";
LInt. Literal ::= Integer ;
-------------------------------------------------------------------------------
-- * LITERALS
-------------------------------------------------------------------------------
LInt. Lit ::= Integer ;
LChar. Lit ::= Char ;
-------------------------------------------------------------------------------
-- * CASE
-------------------------------------------------------------------------------
Inj. Inj ::= Init "=>" Exp ;
separator nonempty Inj ";" ;
InitLit. Init ::= Literal ;
InitConstr. Init ::= Ident [Ident] ;
InitCatch. Init ::= "_" ;
InitLit. Init ::= Lit ;
InitConstructor. Init ::= UIdent [LIdent] ;
InitCatch. Init ::= "_" ;
-------------------------------------------------------------------------------
-- * AUX
-------------------------------------------------------------------------------
separator Def ";" ;
separator nonempty Constructor "" ;
separator Type " " ;
coercions Type 2 ;
separator nonempty Inj ";" ;
separator Ident " ";
separator LIdent " ";
separator TVar " " ;
coercions Exp 5 ;
coercions Type 2 ;
token UIdent (upper (letter | digit | '_')*) ;
token LIdent (lower (letter | digit | '_')*) ;
comment "--" ;
comment "{-" "-}" ;

22
Justfile Normal file
View file

@ -0,0 +1,22 @@
# run BNFC on Grammar.cf, writing the generated sources under src/
build:
    bnfc -o src -d Grammar.cf

# clean the generated directories
# (-f keeps the recipe from aborting when an artefact is already gone, so the
# second removal always runs)
clean:
    rm -rf src/Grammar
    rm -f language

# run all tests
test:
    cabal test

# run the compiler on sample programs basic-1 through basic-5
ctest:
    cabal run language sample-programs/basic-1
    cabal run language sample-programs/basic-2
    cabal run language sample-programs/basic-3
    cabal run language sample-programs/basic-4
    cabal run language sample-programs/basic-5

# compile a specific file
run FILE:
    cabal run language {{FILE}}

View file

@ -28,10 +28,11 @@ test :
./language ./sample-programs/basic-3
./language ./sample-programs/basic-4
./language ./sample-programs/basic-5
./language ./sample-programs/basic-5
./language ./sample-programs/basic-6
./language ./sample-programs/basic-7
./language ./sample-programs/basic-8
./language ./sample-programs/basic-9
run :
cabal -v0 new-run language -- "test_program"
# EOF

219
Session.vim Normal file
View file

@ -0,0 +1,219 @@
let SessionLoad = 1
let s:so_save = &g:so | let s:siso_save = &g:siso | setg so=0 siso=0 | setl so=-1 siso=-1
let v:this_session=expand("<sfile>:p")
silent only
silent tabonly
cd ~/Documents/bachelor_thesis/language
if expand('%') == '' && !&modified && line('$') <= 1 && getline(1) == ''
let s:wipebuf = bufnr('%')
endif
let s:shortmess_save = &shortmess
if &shortmess =~ 'A'
set shortmess=aoOA
else
set shortmess=aoO
endif
badd +1 ~/Documents/bachelor_thesis/language
badd +298 src/TypeChecker/TypeChecker.hs
badd +7 test_program
badd +46 src/TypeChecker/TypeCheckerIr.hs
badd +6 Grammar.cf
badd +1 src/Grammar/Abs.hs
argglobal
%argdel
$argadd ~/Documents/bachelor_thesis/language
set stal=2
tabnew +setlocal\ bufhidden=wipe
tabnew +setlocal\ bufhidden=wipe
tabnew +setlocal\ bufhidden=wipe
tabrewind
edit src/TypeChecker/TypeChecker.hs
let s:save_splitbelow = &splitbelow
let s:save_splitright = &splitright
set splitbelow splitright
wincmd _ | wincmd |
vsplit
1wincmd h
wincmd w
let &splitbelow = s:save_splitbelow
let &splitright = s:save_splitright
wincmd t
let s:save_winminheight = &winminheight
let s:save_winminwidth = &winminwidth
set winminheight=0
set winheight=1
set winminwidth=0
set winwidth=1
exe 'vert 1resize ' . ((&columns * 99 + 86) / 173)
exe 'vert 2resize ' . ((&columns * 73 + 86) / 173)
argglobal
setlocal fdm=manual
setlocal fde=0
setlocal fmr={{{,}}}
setlocal fdi=#
setlocal fdl=0
setlocal fml=1
setlocal fdn=20
setlocal fen
silent! normal! zE
let &fdl = &fdl
let s:l = 298 - ((18 * winheight(0) + 21) / 42)
if s:l < 1 | let s:l = 1 | endif
keepjumps exe s:l
normal! zt
keepjumps 298
normal! 029|
lcd ~/Documents/bachelor_thesis/language
wincmd w
argglobal
if bufexists(fnamemodify("~/Documents/bachelor_thesis/language/Grammar.cf", ":p")) | buffer ~/Documents/bachelor_thesis/language/Grammar.cf | else | edit ~/Documents/bachelor_thesis/language/Grammar.cf | endif
if &buftype ==# 'terminal'
silent file ~/Documents/bachelor_thesis/language/Grammar.cf
endif
balt ~/Documents/bachelor_thesis/language/src/TypeChecker/TypeChecker.hs
setlocal fdm=manual
setlocal fde=0
setlocal fmr={{{,}}}
setlocal fdi=#
setlocal fdl=0
setlocal fml=1
setlocal fdn=20
setlocal fen
silent! normal! zE
let &fdl = &fdl
let s:l = 7 - ((6 * winheight(0) + 21) / 42)
if s:l < 1 | let s:l = 1 | endif
keepjumps exe s:l
normal! zt
keepjumps 7
normal! 0
lcd ~/Documents/bachelor_thesis/language
wincmd w
exe 'vert 1resize ' . ((&columns * 99 + 86) / 173)
exe 'vert 2resize ' . ((&columns * 73 + 86) / 173)
tabnext
edit ~/Documents/bachelor_thesis/language/src/TypeChecker/TypeCheckerIr.hs
let s:save_splitbelow = &splitbelow
let s:save_splitright = &splitright
set splitbelow splitright
wincmd _ | wincmd |
vsplit
1wincmd h
wincmd w
let &splitbelow = s:save_splitbelow
let &splitright = s:save_splitright
wincmd t
let s:save_winminheight = &winminheight
let s:save_winminwidth = &winminwidth
set winminheight=0
set winheight=1
set winminwidth=0
set winwidth=1
exe 'vert 1resize ' . ((&columns * 86 + 86) / 173)
exe 'vert 2resize ' . ((&columns * 86 + 86) / 173)
argglobal
balt ~/Documents/bachelor_thesis/language/test_program
setlocal fdm=manual
setlocal fde=0
setlocal fmr={{{,}}}
setlocal fdi=#
setlocal fdl=0
setlocal fml=1
setlocal fdn=20
setlocal fen
silent! normal! zE
let &fdl = &fdl
let s:l = 1 - ((0 * winheight(0) + 21) / 42)
if s:l < 1 | let s:l = 1 | endif
keepjumps exe s:l
normal! zt
keepjumps 1
normal! 0
lcd ~/Documents/bachelor_thesis/language
wincmd w
argglobal
if bufexists(fnamemodify("~/Documents/bachelor_thesis/language/src/Grammar/Abs.hs", ":p")) | buffer ~/Documents/bachelor_thesis/language/src/Grammar/Abs.hs | else | edit ~/Documents/bachelor_thesis/language/src/Grammar/Abs.hs | endif
if &buftype ==# 'terminal'
silent file ~/Documents/bachelor_thesis/language/src/Grammar/Abs.hs
endif
balt ~/Documents/bachelor_thesis/language/src/TypeChecker/TypeCheckerIr.hs
setlocal fdm=manual
setlocal fde=0
setlocal fmr={{{,}}}
setlocal fdi=#
setlocal fdl=0
setlocal fml=1
setlocal fdn=20
setlocal fen
silent! normal! zE
let &fdl = &fdl
let s:l = 1 - ((0 * winheight(0) + 21) / 42)
if s:l < 1 | let s:l = 1 | endif
keepjumps exe s:l
normal! zt
keepjumps 1
normal! 0
lcd ~/Documents/bachelor_thesis/language
wincmd w
exe 'vert 1resize ' . ((&columns * 86 + 86) / 173)
exe 'vert 2resize ' . ((&columns * 86 + 86) / 173)
tabnext
edit ~/Documents/bachelor_thesis/language/Grammar.cf
argglobal
balt ~/Documents/bachelor_thesis/language/src/Grammar/Abs.hs
setlocal fdm=manual
setlocal fde=0
setlocal fmr={{{,}}}
setlocal fdi=#
setlocal fdl=0
setlocal fml=1
setlocal fdn=20
setlocal fen
silent! normal! zE
let &fdl = &fdl
let s:l = 40 - ((12 * winheight(0) + 21) / 42)
if s:l < 1 | let s:l = 1 | endif
keepjumps exe s:l
normal! zt
keepjumps 40
normal! 0
lcd ~/Documents/bachelor_thesis/language
tabnext
edit ~/Documents/bachelor_thesis/language/test_program
argglobal
balt ~/Documents/bachelor_thesis/language/src/TypeChecker/TypeChecker.hs
setlocal fdm=manual
setlocal fde=0
setlocal fmr={{{,}}}
setlocal fdi=#
setlocal fdl=0
setlocal fml=1
setlocal fdn=20
setlocal fen
silent! normal! zE
let &fdl = &fdl
let s:l = 7 - ((6 * winheight(0) + 21) / 42)
if s:l < 1 | let s:l = 1 | endif
keepjumps exe s:l
normal! zt
keepjumps 7
normal! 010|
lcd ~/Documents/bachelor_thesis/language
tabnext 1
set stal=1
if exists('s:wipebuf') && len(win_findbuf(s:wipebuf)) == 0 && getbufvar(s:wipebuf, '&buftype') isnot# 'terminal'
silent exe 'bwipe ' . s:wipebuf
endif
unlet! s:wipebuf
set winheight=1 winwidth=20
let &shortmess = s:shortmess_save
let s:sx = expand("<sfile>:p:r")."x.vim"
if filereadable(s:sx)
exe "source " . fnameescape(s:sx)
endif
let &g:so = s:so_save | let &g:siso = s:siso_save
set hlsearch
nohlsearch
doautoall SessionLoadPost
unlet SessionLoad
" vim: set ft=vim :

2
cabal.project.local~ Normal file
View file

@ -0,0 +1,2 @@
ignore-project: False
tests: False

View file

@ -1,14 +1 @@
indentation: 4
function-arrows: trailing
comma-style: leading
import-export-style: diff-friendly
indent-wheres: false
record-brace-space: false
newlines-between-decls: 1
haddock-style: multi-line
haddock-style-module:
let-style: auto
in-style: right-align
respectful: true
fixities: []
unicode: never

View file

@ -12,11 +12,9 @@ build-type: Simple
extra-doc-files: CHANGELOG.md
extra-source-files:
Grammar.cf
common warnings
ghc-options: -W
@ -32,16 +30,14 @@ executable language
Grammar.Print
Grammar.Skel
Grammar.ErrM
LambdaLifter.LambdaLifter
Auxiliary
Renamer.Renamer
TypeChecker.TypeChecker
TypeChecker.TypeCheckerIr
Monomorphizer.Monomorphizer
Monomorphizer.MonomorphizerIr
-- Interpreter
Codegen.Codegen
Codegen.LlvmIr
Renamer.Renamer
-- LambdaLifter.LambdaLifter
-- Codegen.Codegen
-- Codegen.LlvmIr
hs-source-dirs: src
build-depends:
@ -49,7 +45,39 @@ executable language
, mtl
, containers
, either
, array
, extra
, directory
, array
, hspec
, QuickCheck
default-language: GHC2021
Test-suite language-testsuite
type: exitcode-stdio-1.0
main-is: Tests.hs
other-modules:
Grammar.Abs
Grammar.Lex
Grammar.Par
Grammar.Print
Grammar.Skel
Grammar.ErrM
Auxiliary
TypeChecker.TypeChecker
TypeChecker.TypeCheckerIr
Renamer.Renamer
hs-source-dirs: src, tests
build-depends:
base >=4.16
, mtl
, containers
, either
, extra
, array
, hspec
, QuickCheck
default-language: GHC2021

View file

@ -1,29 +1,8 @@
posMul : _Int -> _Int -> _Int;
posMul a b = a + b; {-case b of {
0 => 0;
_ => a + posMul a (b - 1)
};-}
main : _Int;
main = posMul 5 10;
--
-- facc : _Int -> _Int;
-- facc a = case a of {
-- 1 => 1;
-- _ => posMul a (facc (a - 1))
-- };
--
-- minimization : (_Int -> _Int) -> _Int -> _Int;
-- minimization p x = case p x of {
-- 1 => x;
-- _ => minimization p (x + 1)
-- };
--
-- checkFac : _Int -> _Int;
-- checkFac x = case facc x of {
-- 0 => 1;
-- _ => 0
-- };
--
-- main : _Int;
-- main = minimization checkFac 1
posMul: _Int - > _Int - > _Int;
posMul a b = a + b; {
-
case b of {
0 => 0;
_ => a + posMul a(b - 1)
}; -
}

5
sample-programs/basic-2 Normal file
View file

@ -0,0 +1,5 @@
add : Int -> Int -> Int ;
add x = \y. x+y;
main : Int ;
main = (\z. z+z) ((add 4) 6) ;

2
sample-programs/basic-3 Normal file
View file

@ -0,0 +1,2 @@
main : Int ;
main = (\x. x+x+3) ((\x. x) 2) ;

2
sample-programs/basic-4 Normal file
View file

@ -0,0 +1,2 @@
f : Int -> Int ;
f x = let g = (\y. y+1) in g (g x)

8
sample-programs/basic-5 Normal file
View file

@ -0,0 +1,8 @@
double : Int -> Int ;
double n = n + n;
id : forall a. a -> a ;
id x = x ;
main : Int ;
main = id double 5;

10
sample-programs/basic-6 Normal file
View file

@ -0,0 +1,10 @@
data Bool () where {
True : Bool ()
False : Bool ()
};
main : Bool () -> Int ;
main b = case b of {
False => 0;
True => 0
}

10
sample-programs/basic-7 Normal file
View file

@ -0,0 +1,10 @@
data Bool () where {
True : Bool ()
False : Bool ()
};
ifThenElse : forall a. Bool () -> a -> a -> a;
ifThenElse b if else = case b of {
True => if;
False => else
}

24
sample-programs/basic-8 Normal file
View file

@ -0,0 +1,24 @@
data Maybe (a) where {
Nothing : Maybe (a)
Just : a -> Maybe (a)
};
fromJust : Maybe (a) -> a ;
fromJust a =
case a of {
Just a => a
};
fromMaybe : a -> Maybe (a) -> a ;
fromMaybe a b =
case b of {
Just a => a;
Nothing => a
};
maybe : b -> (a -> b) -> Maybe (a) -> b;
maybe b f ma =
case ma of {
Just a => f a;
Nothing => b
}

View file

@ -1,50 +1,105 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
module Renamer.Renamer where
module Renamer.Renamer (rename) where
import Auxiliary (mapAccumM)
import Control.Monad.State (MonadState, State, evalState, gets,
modify)
import Data.List (foldl')
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe)
import Grammar.Abs
import Auxiliary (mapAccumM)
import Control.Applicative (Applicative (liftA2))
import Control.Monad.Except (ExceptT, MonadError, runExceptT, throwError)
import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.State (
MonadState,
StateT,
evalStateT,
gets,
modify,
)
import Data.Coerce (coerce)
import Data.Function (on)
import Data.Map (Map)
import Data.Map qualified as Map
import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe)
import Grammar.Abs
-- | Rename all variables and local binds
rename :: Program -> Program
rename (Program bs) = Program $ evalState (runRn $ mapM (renameSc initNames) bs) 0
rename :: Program -> Either String Program
rename (Program defs) = Program <$> renameDefs defs
renameDefs :: [Def] -> Either String [Def]
renameDefs defs = runIdentity $ runExceptT $ evalStateT (runRn $ mapM renameDef defs) initCxt
where
-- initNames = Map.fromList $ map (\(Bind name _ _ _ _) -> dupe name) bs
initNames = Map.fromList $ foldl' saveIfBind [] bs
saveIfBind acc (DBind (Bind name _ _ _ _)) = dupe name : acc
saveIfBind acc _ = acc
renameSc :: Names -> Def -> Rn Def
renameSc old_names (DBind (Bind name t _ parms rhs)) = do
(new_names, parms') <- newNames old_names parms
rhs' <- snd <$> renameExp new_names rhs
pure . DBind $ Bind name t name parms' rhs'
renameSc _ def = pure def
initNames = Map.fromList [dupe (coerce name) | DBind (Bind name _ _) <- defs]
renameDef :: Def -> Rn Def
renameDef = \case
DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ
DBind (Bind name vars rhs) -> do
(new_names, vars') <- newNames initNames (coerce vars)
rhs' <- snd <$> renameExp new_names rhs
pure . DBind $ Bind name (coerce vars') rhs'
DData (Data (Indexed cname types) constrs) -> do
tvars_ <- tvars
tvars' <- mapM nextNameTVar tvars_
let tvars_lt = zip tvars_ tvars'
typ' = map (substituteTVar tvars_lt) types
constrs' = map (renameConstr tvars_lt) constrs
pure . DData $ Data (Indexed cname typ') constrs'
where
tvars = concat <$> mapM (collectTVars []) types
-- Gather the type variables mentioned by one index of a data declaration.
-- The first argument accumulates the binders of the 'TAll's walked through so far.
-- Any type form other than 'TAll', 'TIndexed' or 'TVar' is rejected through
-- 'throwError'; the message prints the enclosing declaration's @types@.
-- NOTE(review): the 'TVar' case returns only @[v]@ and so discards the
-- accumulated binders, while the 'TIndexed' case returns the accumulator
-- without looking at its arguments -- confirm that both are intended.
collectTVars :: [TVar] -> Type -> Rn [TVar]
collectTVars tvars = \case
TAll tvar t -> collectTVars (tvar : tvars) t
TIndexed _ -> return tvars
-- Should be monad error
TVar v -> return [v]
_ -> throwError ("Bad data type definition: " ++ show types)
-- | Apply a type-variable renaming to the type of a single constructor; the
-- constructor's own name is kept.
renameConstr :: [(TVar, TVar)] -> Constructor -> Constructor
renameConstr new_types (Constructor name typ) =
    Constructor name (substituteTVar new_types typ)
-- | Rename type variables in a type according to an association list of
-- @(old, new)@ pairs. Variables without an entry are left alone, and a
-- @forall@ binder is renamed together with its body.
substituteTVar :: [(TVar, TVar)] -> Type -> Type
substituteTVar new_names typ = case typ of
    TLit _ -> typ
    TVar tvar -> maybe typ TVar (lookup tvar new_names)
    TFun t1 t2 -> TFun (go t1) (go t2)
    TAll tvar t -> TAll (fromMaybe tvar (lookup tvar new_names)) (go t)
    TIndexed (Indexed name typs) -> TIndexed (Indexed name (map go typs))
    _ -> error ("Impossible " ++ show typ)
  where
    -- Same renaming, applied to a sub-term.
    go = substituteTVar new_names
-- | State the renamer starts from: both counters begin at zero.
initCxt :: Cxt
initCxt = Cxt{var_counter = 0, tvar_counter = 0}

-- | Renamer state: the next numeric suffix to hand out for term variables
-- and, separately, for type variables.
data Cxt = Cxt
    { var_counter :: Int
    , tvar_counter :: Int
    }
-- | Rename monad. State holds the number of renamed names.
newtype Rn a = Rn {runRn :: State Int a}
deriving (Functor, Applicative, Monad, MonadState Int)
newtype Rn a = Rn {runRn :: StateT Cxt (ExceptT String Identity) a}
deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String)
-- | Maps old to new name
type Names = Map Ident Ident
renameLocalBind :: Names -> Bind -> Rn (Names, Bind)
renameLocalBind old_names (Bind name t _ parms rhs) = do
(new_names, name') <- newName old_names name
(new_names', parms') <- newNames new_names parms
(new_names'', rhs') <- renameExp new_names' rhs
pure (new_names'', Bind name' t name' parms' rhs')
type Names = Map LIdent LIdent
renameExp :: Names -> Exp -> Rn (Names, Exp)
renameExp old_names = \case
EId n -> pure (old_names, EId . fromMaybe n $ Map.lookup n old_names)
ELit (LInt i1) -> pure (old_names, ELit (LInt i1))
EVar n -> pure (coerce old_names, EVar . fromMaybe n $ Map.lookup n old_names)
ECons n -> pure (old_names, ECons n)
ELit lit -> pure (old_names, ELit lit)
EApp e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
@ -53,25 +108,25 @@ renameExp old_names = \case
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EAdd e1' e2')
ESub e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, ESub e1' e2')
ELet i e1 e2 -> do
(new_names, e1') <- renameExp old_names e1
(new_names', e2') <- renameExp new_names e2
pure (new_names', ELet i e1' e2')
-- TODO fix shadowing
ELet name rhs e -> do
(new_names, name') <- newName old_names (coerce name)
(new_names', rhs') <- renameExp new_names rhs
(new_names'', e') <- renameExp new_names' e
pure (new_names'', ELet (coerce name') rhs' e')
EAbs par e -> do
(new_names, par') <- newName old_names par
(new_names, par') <- newName old_names (coerce par)
(new_names', e') <- renameExp new_names e
pure (new_names', EAbs par' e')
pure (new_names', EAbs (coerce par') e')
EAnn e t -> do
(new_names, e') <- renameExp old_names e
pure (new_names, EAnn e' t)
t' <- renameTVars t
pure (new_names, EAnn e' t')
ECase e injs -> do
(_, e') <- renameExp old_names e
(new_names, injs') <- renameInjs old_names injs
pure (new_names, ECase e' injs')
(new_names, e') <- renameExp old_names e
(new_names', injs') <- renameInjs new_names injs
pure (new_names', ECase e' injs')
renameInjs :: Names -> [Inj] -> Rn (Names, [Inj])
renameInjs ns xs = do
@ -80,19 +135,64 @@ renameInjs ns xs = do
renameInj :: Names -> Inj -> Rn (Names, Inj)
renameInj ns (Inj init e) = do
(new_names, e') <- renameExp ns e
return (new_names, Inj init e')
(new_names, init') <- renameInit ns init
(new_names', e') <- renameExp new_names e
return (new_names', Inj init' e')
-- | Rename the variables bound by a case pattern. Only constructor patterns
-- bind variables; every other pattern is returned unchanged together with
-- the name environment it was given.
renameInit :: Names -> Init -> Rn (Names, Init)
renameInit ns (InitConstructor cs vars) = do
    (ns_new, vars') <- newNames ns (coerce vars)
    return (ns_new, InitConstructor cs (coerce vars'))
renameInit ns rest = return (ns, rest)
-- | Give every @forall@-bound type variable a fresh name. The walk goes
-- through @forall@ bodies and both sides of function arrows; any other type
-- is returned as it is.
renameTVars :: Type -> Rn Type
renameTVars = \case
    TAll tvar t -> do
        tvar' <- nextNameTVar tvar
        -- Rewrite the body to the fresh name first, then keep renaming inside it.
        TAll tvar' <$> renameTVars (substitute tvar tvar' t)
    TFun t1 t2 -> TFun <$> renameTVars t1 <*> renameTVars t2
    typ -> pure typ
{- | Rename a type variable: @substitute α α_n A@ computes @[α_n/α]A@,
replacing the free occurrences of @α@ in @A@ with @α_n@.

A nested @forall@ that rebinds @α@ shadows the variable being replaced, so
that sub-term is returned untouched. Rewriting its body would attach the
inner binder's occurrences to the outer, renamed binder.
-}
substitute ::
    TVar -> -- α
    TVar -> -- α_n
    Type -> -- A
    Type -- [α_n/α]A
substitute tvar1 tvar2 typ = case typ of
    TLit _ -> typ
    TVar tvar'
        | tvar' == tvar1 -> TVar tvar2
        | otherwise -> typ
    TFun t1 t2 -> on TFun substitute' t1 t2
    TAll tvar t
        -- α is rebound here, so it has no free occurrences below this binder.
        | tvar == tvar1 -> typ
        | otherwise -> TAll tvar $ substitute' t
    TIndexed (Indexed name typs) -> TIndexed . Indexed name $ map substitute' typs
    -- Remaining constructor: the grammar's internal TEVar form.
    _ -> error "Impossible"
  where
    substitute' = substitute tvar1 tvar2
-- | Create a new name and add it to name environment.
newName :: Names -> Ident -> Rn (Names, Ident)
newName :: Names -> LIdent -> Rn (Names, LIdent)
newName env old_name = do
new_name <- makeName old_name
pure (Map.insert old_name new_name env, new_name)
-- | Create multiple names and add them to the name environment
newNames :: Names -> [Ident] -> Rn (Names, [Ident])
newNames :: Names -> [LIdent] -> Rn (Names, [LIdent])
newNames = mapAccumM newName
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
makeName :: Ident -> Rn Ident
makeName (Ident prefix) = gets (\i -> Ident $ prefix ++ "_" ++ show i) <* modify succ
-- | Build @prefix_n@ from the current term-variable counter @n@, then advance
-- that counter by one.
makeName :: LIdent -> Rn LIdent
makeName (LIdent prefix) = do
    n <- gets var_counter
    modify (\cxt -> cxt{var_counter = succ n})
    pure (LIdent (prefix ++ "_" ++ show n))
-- | Build a type variable named @s_n@ from the current type-variable counter
-- @n@, then advance that counter by one.
nextNameTVar :: TVar -> Rn TVar
nextNameTVar (MkTVar (LIdent s)) = do
    n <- gets tvar_counter
    modify (\cxt -> cxt{tvar_counter = succ n})
    pure (MkTVar (LIdent (s ++ "_" ++ show n)))

View file

@ -1,25 +1,33 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
-- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeChecker where
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Foldable (traverse_)
import Data.Functor.Identity (runIdentity)
import Data.List (foldl')
import Data.Map (Map)
import qualified Data.Map as M
import Data.Set (Set)
import qualified Data.Set as S
import Debug.Trace (trace)
import Grammar.Abs
import Grammar.Print (printTree)
import qualified TypeChecker.TypeCheckerIr as T
import TypeChecker.TypeCheckerIr (Ctx (..), Env (..), Error, Infer,
Poly (..), Subst)
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (second)
import Data.Coerce (coerce)
import Data.Foldable (traverse_)
import Data.Functor.Identity (runIdentity)
import Data.List (foldl')
import Data.List.Extra (unsnoc)
import Data.Map (Map)
import Data.Map qualified as M
import Data.Set (Set)
import Data.Set qualified as S
import Debug.Trace (trace)
import Grammar.Abs
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr (
Ctx (..),
Env (..),
Error,
Infer,
Subst,
)
import TypeChecker.TypeCheckerIr qualified as T
initCtx = Ctx mempty
@ -37,51 +45,17 @@ runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e
-- | Type check a whole program, yielding either an error or the typed program.
typecheck :: Program -> Either Error T.Program
typecheck prg = run (checkPrg prg)
{- | Start by freshening the type variable of data types to avoid clash with
other user defined polymorphic types
This might be wrong for type constructors that work over several variables
-}
freshenData :: Data -> Infer Data
freshenData (Data (Constr name ts) constrs) = do
fr <- fresh
let fr' = case fr of
TPol a -> a
-- Meh, this part assumes fresh generates a polymorphic type
_ ->
error
"Bug: implementation of \
\ fresh and freshenData are not compatible"
let new_ts = map (freshenType fr') ts
let new_constrs = map (freshenConstr fr') constrs
return $ Data (Constr name new_ts) new_constrs
{- | Freshen all polymorphic variables, regardless of name
| freshenType "d" (a -> b -> c) becomes (d -> d -> d)
-}
freshenType :: Ident -> Type -> Type
freshenType iden = \case
(TPol _) -> TPol iden
(TArr a b) -> TArr (freshenType iden a) (freshenType iden b)
(TConstr (Constr a ts)) ->
TConstr (Constr a (map (freshenType iden) ts))
rest -> rest
freshenConstr :: Ident -> Constructor -> Constructor
freshenConstr iden (Constructor name t) =
Constructor name (freshenType iden t)
checkData :: Data -> Infer ()
checkData d = do
d' <- freshenData d
case d' of
(Data typ@(Constr name ts) constrs) -> do
case d of
(Data typ@(Indexed name ts) constrs) -> do
unless
(all isPoly ts)
(throwError $ unwords ["Data type incorrectly declared"])
traverse_
( \(Constructor name' t') ->
if TConstr typ == retType t'
then insertConstr name' t'
if TIndexed typ == retType t'
then insertConstr (coerce name') (toNew t')
else
throwError $
unwords
@ -96,19 +70,30 @@ checkData d = do
constrs
retType :: Type -> Type
retType (TArr _ t2) = retType t2
retType a = a
retType (TFun _ t2) = retType t2
retType a = a
checkPrg :: Program -> Infer T.Program
checkPrg (Program bs) = do
preRun bs
T.Program <$> checkDef bs
-- Type check the program twice to produce all top-level types in the first pass through
bs' <- checkDef bs
trace "\nFIRST ITERATION" return ()
trace (printTree bs' ++ "\nSECOND ITERATION\n") return ()
bs'' <- checkDef bs
return $ T.Program bs''
where
preRun :: [Def] -> Infer ()
preRun [] = return ()
preRun (x : xs) = case x of
DBind (Bind n t _ _ _) -> insertSig n t >> preRun xs
DData d@(Data _ _) -> checkData d >> preRun xs
-- TODO: Check for no overlapping signature definitions
DSig (Sig n t) -> insertSig (coerce n) (Just $ toNew t) >> preRun xs
DBind (Bind n _ _) -> do
s <- gets sigs
case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs
DData d@(Data _ _) -> checkData d >> preRun xs
checkDef :: [Def] -> Infer [T.Def]
checkDef [] = return []
@ -117,79 +102,75 @@ checkPrg (Program bs) = do
b' <- checkBind b
fmap (T.DBind b' :) (checkDef xs)
(DData d) -> fmap (T.DData d :) (checkDef xs)
(DSig _) -> checkDef xs
checkBind :: Bind -> Infer T.Bind
checkBind (Bind n t _ args e) = do
(t', e') <- inferExp $ makeLambda e (reverse args)
s <- unify t t'
let t'' = apply s t
unless
(t `typeEq` t'')
( throwError $
unwords
[ "Top level signature"
, printTree t
, "does not match body with inferred type:"
, printTree t''
]
)
return $ T.Bind (n, t) e'
checkBind (Bind name args e) = do
let lambda = makeLambda e (reverse $ coerce args)
e@(_, t') <- inferExp lambda
s <- gets sigs
case M.lookup (coerce name) s of
Just (Just t) -> do
sub <- unify t t'
let newT = apply sub t
insertSig (coerce name) (Just newT)
return $ T.Bind (coerce name, newT) [] e
_ -> do
insertSig (coerce name) (Just t')
return (T.Bind (coerce name, t') [] e) -- (apply s e)
where
makeLambda :: Exp -> [Ident] -> Exp
makeLambda = foldl (flip EAbs)
makeLambda = foldl (flip (EAbs . coerce))
{- | Check if two types are considered equal
For the purpose of the algorithm two polymorphic types are always considered
equal
-}
typeEq :: Type -> Type -> Bool
typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r'
typeEq (TMono a) (TMono b) = a == b
typeEq (TConstr (Constr name a)) (TConstr (Constr name' b)) =
length a == length b
&& name == name'
&& and (zipWith typeEq a b)
typeEq (TPol _) (TPol _) = True
typeEq _ _ = False
isMoreSpecificOrEq :: Type -> Type -> Bool
isMoreSpecificOrEq _ (TPol _) = True
isMoreSpecificOrEq (TArr a b) (TArr c d) =
isMoreSpecificOrEq :: T.Type -> T.Type -> Bool
isMoreSpecificOrEq _ (T.TAll _ _) = True
isMoreSpecificOrEq (T.TFun a b) (T.TFun c d) =
isMoreSpecificOrEq a c && isMoreSpecificOrEq b d
isMoreSpecificOrEq (TConstr (Constr n1 ts1)) (TConstr (Constr n2 ts2)) =
isMoreSpecificOrEq (T.TIndexed (T.Indexed n1 ts1)) (T.TIndexed (T.Indexed n2 ts2)) =
n1 == n2
&& length ts1 == length ts2
&& and (zipWith isMoreSpecificOrEq ts1 ts2)
isMoreSpecificOrEq a b = a == b
isPoly :: Type -> Bool
isPoly (TPol _) = True
isPoly _ = False
isPoly (TAll _ _) = True
isPoly (TVar _) = True
isPoly _ = False
inferExp :: Exp -> Infer (Type, T.Exp)
inferExp :: Exp -> Infer T.ExpT
inferExp e = do
(s, t, e') <- algoW e
(s, (e', t)) <- algoW e
let subbed = apply s t
return (subbed, replace subbed e')
return $ replace subbed (e', t)
replace :: Type -> T.Exp -> T.Exp
replace t = \case
T.ELit _ e -> T.ELit t e
T.EId (n, _) -> T.EId (n, t)
T.EAbs _ name e -> T.EAbs t name e
T.EApp _ e1 e2 -> T.EApp t e1 e2
T.EAdd _ e1 e2 -> T.EAdd t e1 e2
T.ESub _ e1 e2 -> T.ESub t e1 e2
T.ELet (T.Bind (n, _) e1) e2 -> T.ELet (T.Bind (n, t) e1) e2
T.ECase _ expr injs -> T.ECase t expr injs
-- | Overwrite the type attached to a typed expression, keeping the expression.
replace :: T.Type -> T.ExpT -> T.ExpT
-- The lazy pattern keeps the pair from being forced before a result is produced.
replace t ~(e, _) = (e, t)
algoW :: Exp -> Infer (Subst, Type, T.Exp)
-- | Conversion from the parser's syntax tree ("Grammar.Abs") to the type
-- checker's own representation ("TypeChecker.TypeCheckerIr").
class NewType a b where
    toNew :: a -> b

-- | Types are translated constructor by constructor; the internal 'TEVar'
-- form is rejected with a runtime error.
instance NewType Type T.Type where
    toNew (TLit i) = T.TLit (coerce i)
    toNew (TVar v) = T.TVar (toNew v)
    toNew (TFun t1 t2) = T.TFun (toNew t1) (toNew t2)
    toNew (TAll b t) = T.TAll (toNew b) (toNew t)
    toNew (TIndexed i) = T.TIndexed (toNew i)
    toNew (TEVar _) = error "Should not exist after typechecker"

-- | An indexed type keeps its name and has each of its arguments converted.
instance NewType Indexed T.Indexed where
    toNew (Indexed name vars) = T.Indexed (coerce name) (map toNew vars)

-- | Type variables only change their wrapper.
instance NewType TVar T.TVar where
    toNew (MkTVar i) = T.MkTVar (coerce i)
algoW :: Exp -> Infer (Subst, T.ExpT)
algoW = \case
-- \| TODO: More testing need to be done. Unsure of the correctness of this
EAnn e t -> do
(s1, t', e') <- algoW e
(s1, (e', t')) <- algoW e
unless
(t `isMoreSpecificOrEq` t')
(toNew t `isMoreSpecificOrEq` t')
( throwError $
unwords
[ "Annotated type:"
@ -199,34 +180,34 @@ algoW = \case
]
)
applySt s1 $ do
s2 <- unify t t'
return (s2 `compose` s1, t, e')
s2 <- unify (toNew t) t'
let comp = s2 `compose` s1
return (comp, apply comp (e', toNew t))
-- \| ------------------
-- \| Γ ⊢ i : Int, ∅
ELit (LInt n) ->
return (nullSubst, TMono "Int", T.ELit (TMono "Int") (LInt n))
ELit a -> error $ "NOT IMPLEMENTED YET: ELit " ++ show a
ELit lit ->
let lt = litType lit
in return (nullSubst, (T.ELit lit, lt))
-- \| x : σ ∈ Γ τ = inst(σ)
-- \| ----------------------
-- \| Γ ⊢ x : τ, ∅
EId i -> do
EVar i -> do
var <- asks vars
case M.lookup i var of
Just t -> inst t >>= \x -> return (nullSubst, x, T.EId (i, x))
case M.lookup (coerce i) var of
Just t -> inst t >>= \x -> return (nullSubst, (T.EId $ coerce i, x))
Nothing -> do
sig <- gets sigs
case M.lookup i sig of
Just t -> return (nullSubst, t, T.EId (i, t))
Nothing -> do
constr <- gets constructors
case M.lookup i constr of
Just t -> return (nullSubst, t, T.EId (i, t))
Nothing ->
throwError $
"Unbound variable: " ++ show i
case M.lookup (coerce i) sig of
Just (Just t) -> return (nullSubst, (T.EId $ coerce i, t))
Just Nothing -> (\x -> (nullSubst, (T.EId $ coerce i, x))) <$> fresh
Nothing -> throwError $ "Unbound variable: " ++ printTree i
ECons i -> do
constr <- gets constructors
case M.lookup (coerce i) constr of
Just t -> return (nullSubst, (T.EId $ coerce i, t))
Nothing -> throwError $ "Constructor: '" ++ printTree i ++ "' is not defined"
-- \| τ = newvar Γ, x : τ ⊢ e : τ', S
-- \| ---------------------------------
@ -234,11 +215,11 @@ algoW = \case
EAbs name e -> do
fr <- fresh
withBinding name (Forall [] fr) $ do
(s1, t', e') <- algoW e
withBinding (coerce name) fr $ do
(s1, (e', t')) <- algoW e
let varType = apply s1 fr
let newArr = TArr varType t'
return (s1, newArr, T.EAbs newArr (name, varType) e')
let newArr = T.TFun varType t'
return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr))
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
-- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
@ -247,29 +228,16 @@ algoW = \case
-- This might be wrong
EAdd e0 e1 -> do
(s1, t0, e0') <- algoW e0
(s1, (e0', t0)) <- algoW e0
applySt s1 $ do
(s2, t1, e1') <- algoW e1
(s2, (e1', t1)) <- algoW e1
-- applySt s2 $ do
s3 <- unify (apply s2 t0) (TMono "Int")
s4 <- unify (apply s3 t1) (TMono "Int")
s3 <- unify (apply s2 t0) int
s4 <- unify (apply s3 t1) int
let comp = s4 `compose` s3 `compose` s2 `compose` s1
return
( s4 `compose` s3 `compose` s2 `compose` s1
, TMono "Int"
, T.EAdd (TMono "Int") e0' e1'
)
ESub e0 e1 -> do
(s1, t0, e0') <- algoW e0
applySt s1 $ do
(s2, t1, e1') <- algoW e1
-- applySt s2 $ do
s3 <- unify (apply s2 t0) (TMono "Int")
s4 <- unify (apply s3 t1) (TMono "Int")
return
( s4 `compose` s3 `compose` s2 `compose` s1
, TMono "Int"
, T.ESub (TMono "Int") e0' e1'
( comp
, apply comp (T.EAdd (e0', t0) (e1', t1), int)
)
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1
@ -279,13 +247,13 @@ algoW = \case
EApp e0 e1 -> do
fr <- fresh
(s0, t0, e0') <- algoW e0
(s0, (e0', t0)) <- algoW e0
applySt s0 $ do
(s1, t1, e1') <- algoW e1
-- applySt s1 $ do
s2 <- unify (apply s1 t0) (TArr t1 fr)
(s1, (e1', t1)) <- algoW e1
s2 <- unify (apply s1 t0) (T.TFun t1 fr)
let t = apply s2 fr
return (s2 `compose` s1 `compose` s0, t, T.EApp t e0' e1')
let comp = s2 `compose` s1 `compose` s0
return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t))
-- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁
-- \| ----------------------------------------------
@ -294,39 +262,37 @@ algoW = \case
-- The bar over S₀ and Γ means "generalize"
ELet name e0 e1 -> do
(s1, t1, e0') <- algoW e0
(s1, (e0', t1)) <- algoW e0
env <- asks vars
let t' = generalize (apply s1 env) t1
withBinding name t' $ do
(s2, t2, e1') <- algoW e1
return (s2 `compose` s1, t2, T.ELet (T.Bind (name, t2) e0') e1')
withBinding (coerce name) t' $ do
(s2, (e1', t2)) <- algoW e1
let comp = s2 `compose` s1
return (comp, apply comp (T.ELet (T.Bind (coerce name, t2) [] (e0', t1)) (e1', t2), t2))
-- \| TODO: Add judgement
ECase caseExpr injs -> do
(_, t0, e0') <- algoW caseExpr
(injs', ts) <- mapAndUnzipM (checkInj t0) injs
case ts of
[] -> throwError "Case expression missing any matches"
ts -> do
unified <- zipWithM unify ts (tail ts)
let unified' = foldl' compose mempty unified
let typ = apply unified' (head ts)
return (unified', typ, T.ECase typ e0' injs')
(sub, (e', t)) <- algoW caseExpr
(subst, injs, ret_t) <- checkCase t injs
let comp = subst `compose` sub
let t' = apply comp ret_t
return (comp, (T.ECase (e', t) injs, t'))
-- | Unify two types producing a new substitution
unify :: Type -> Type -> Infer Subst
unify :: T.Type -> T.Type -> Infer Subst
unify t0 t1 = do
trace ("t0: " ++ show t0) return ()
trace ("t1: " ++ show t1) return ()
case (t0, t1) of
(TArr a b, TArr c d) -> do
(T.TFun a b, T.TFun c d) -> do
s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d)
return $ s1 `compose` s2
(TPol a, b) -> occurs a b
(a, TPol b) -> occurs b a
(TMono a, TMono b) ->
(T.TVar (T.MkTVar a), t) -> occurs a t
(t, T.TVar (T.MkTVar b)) -> occurs b t
(T.TAll _ t, b) -> unify t b
(a, T.TAll _ t) -> unify a t
(T.TLit a, T.TLit b) ->
if a == b then return M.empty else throwError "Types do not unify"
-- \| TODO: Figure out a cleaner way to express the same thing
(TConstr (Constr name t), TConstr (Constr name' t')) ->
(T.TIndexed (T.Indexed name t), T.TIndexed (T.Indexed name' t')) ->
if name == name' && length t == length t'
then do
xs <- zipWithM unify t t'
@ -334,56 +300,71 @@ unify t0 t1 = do
else
throwError $
unwords
[ "Type constructor:"
[ "T.Type constructor:"
, printTree name
, "(" ++ printTree t ++ ")"
, "does not match with:"
, printTree name'
, "(" ++ printTree t' ++ ")"
]
(a, b) ->
(a, b) -> do
throwError . unwords $
[ "Type:"
, printTree a
, "can't be unified with:"
, printTree b
[ "'" ++ printTree a ++ "'"
, "can't be unified with"
, "'" ++ printTree b ++ "'"
]
{- | Check if a type is contained in another type.
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
such that these are equal
where these are equal
-}
occurs :: Ident -> Type -> Infer Subst
occurs _ (TPol _) = return nullSubst
occurs :: Ident -> T.Type -> Infer Subst
occurs i t@(T.TVar _) = return (M.singleton i t)
occurs i t =
if S.member i (free t)
then
throwError $
unwords
[ "Occurs check failed, can't unify"
, printTree (TPol i)
, printTree (T.TVar $ T.MkTVar i)
, "with"
, printTree t
]
else return $ M.singleton i t
-- | Generalize a type over all free variables in the substitution set
generalize :: Map Ident Poly -> Type -> Poly
generalize env t = Forall (S.toList $ free t S.\\ free env) t
generalize :: Map Ident T.Type -> T.Type -> T.Type
generalize env t = go freeVars $ removeForalls t
where
freeVars :: [Ident]
freeVars = S.toList $ free t S.\\ free env
go :: [Ident] -> T.Type -> T.Type
go [] t = t
go (x : xs) t = T.TAll (T.MkTVar x) (go xs t)
removeForalls :: T.Type -> T.Type
removeForalls (T.TAll _ t) = removeForalls t
removeForalls (T.TFun t1 t2) = T.TFun (removeForalls t1) (removeForalls t2)
removeForalls t = t
{- | Instantiate a polymorphic type. The quantified (bound) type variables are
substituted with fresh ones.
-}
inst :: Poly -> Infer Type
inst (Forall xs t) = do
xs' <- mapM (const fresh) xs
let s = M.fromList $ zip xs xs'
return $ apply s t
inst :: T.Type -> Infer T.Type
inst = \case
T.TAll (T.MkTVar bound) t -> do
fr <- fresh
let s = M.singleton bound fr
apply s <$> inst t
T.TFun t1 t2 -> T.TFun <$> inst t1 <*> inst t2
rest -> return rest
{- | Compose two substitutions. Every type bound in @inner@ first has @outer@
applied to it; the result is then merged with @outer@. 'M.union' is
left-biased, so for a variable bound in both maps the rewritten @inner@
binding is the one that is kept.
-}
compose :: Subst -> Subst -> Subst
compose outer inner = M.union rewritten outer
  where
    rewritten = M.map (apply outer) inner
-- TODO: Split this class into two separate classes, one for free variables
-- and one for applying substitutions
-- | A class representing free variables functions
class FreeVars t where
-- | Get all free variables from t
@ -392,37 +373,59 @@ class FreeVars t where
-- | Apply a substitution to t
apply :: Subst -> t -> t
instance FreeVars Type where
free :: Type -> Set Ident
free (TPol a) = S.singleton a
free (TMono _) = mempty
free (TArr a b) = free a `S.union` free b
instance FreeVars T.Type where
free :: T.Type -> Set Ident
free (T.TVar (T.MkTVar a)) = S.singleton a
free (T.TAll (T.MkTVar bound) t) = S.singleton bound `S.intersection` free t
free (T.TLit _) = mempty
free (T.TFun a b) = free a `S.union` free b
-- \| Not guaranteed to be correct
free (TConstr (Constr _ a)) =
free (T.TIndexed (T.Indexed _ a)) =
foldl' (\acc x -> free x `S.union` acc) S.empty a
apply :: Subst -> Type -> Type
apply :: Subst -> T.Type -> T.Type
apply sub t = do
case t of
TMono a -> TMono a
TPol a -> case M.lookup a sub of
Nothing -> TPol a
Just t -> t
TArr a b -> TArr (apply sub a) (apply sub b)
TConstr (Constr name a) -> TConstr (Constr name (map (apply sub) a))
T.TLit a -> T.TLit a
T.TVar (T.MkTVar a) -> case M.lookup a sub of
Nothing -> T.TVar (T.MkTVar $ coerce a)
Just t -> t
T.TAll (T.MkTVar i) t -> case M.lookup i sub of
Nothing -> T.TAll (T.MkTVar i) (apply sub t)
Just _ -> apply sub t
T.TFun a b -> T.TFun (apply sub a) (apply sub b)
T.TIndexed (T.Indexed name a) -> T.TIndexed (T.Indexed name (map (apply sub) a))
instance FreeVars Poly where
free :: Poly -> Set Ident
free (Forall xs t) = free t S.\\ S.fromList xs
apply :: Subst -> Poly -> Poly
apply s (Forall xs t) = Forall xs (apply (foldr M.delete s xs) t)
instance FreeVars (Map Ident Poly) where
free :: Map Ident Poly -> Set Ident
instance FreeVars (Map Ident T.Type) where
free :: Map Ident T.Type -> Set Ident
free m = foldl' S.union S.empty (map free $ M.elems m)
apply :: Subst -> Map Ident Poly -> Map Ident Poly
apply :: Subst -> Map Ident T.Type -> Map Ident T.Type
apply s = M.map (apply s)
{- | Substitutions over typed expressions. 'apply' rewrites every type
annotation stored in the tree; 'free' is not implemented and calls 'error'.
-}
instance FreeVars T.ExpT where
    free :: T.ExpT -> Set Ident
    free = error "free not implemented for T.Exp"
    apply :: Subst -> T.ExpT -> T.ExpT
    apply s = \case
        (T.EId i, outerT) -> (T.EId i, apply s outerT)
        (T.ELit lit, t) -> (T.ELit lit, apply s t)
        (T.ELet (T.Bind (ident, t1) args e1) e2, t2) ->
            -- The argument annotations of the binding carry types as well,
            -- so they are rewritten together with the binder type and body.
            let args' = [(name, apply s argT) | (name, argT) <- args]
                bind' = T.Bind (ident, apply s t1) args' (apply s e1)
             in (T.ELet bind' (apply s e2), apply s t2)
        (T.EApp e1 e2, t) -> (T.EApp (apply s e1) (apply s e2), apply s t)
        (T.EAdd e1 e2, t) -> (T.EAdd (apply s e1) (apply s e2), apply s t)
        (T.EAbs ident e, t1) -> (T.EAbs ident (apply s e), apply s t1)
        (T.ECase e injs, t) -> (T.ECase (apply s e) (apply s injs), apply s t)
{- | Substitutions over a single case branch. 'apply' rewrites the type
attached to the pattern and the typed branch body; 'free' is not implemented.
-}
instance FreeVars T.Inj where
    free :: T.Inj -> Set Ident
    -- Fail with a message that names the missing piece instead of the
    -- anonymous 'undefined', mirroring the T.ExpT instance.
    free = error "free not implemented for T.Inj"
    apply :: Subst -> T.Inj -> T.Inj
    apply s (T.Inj (i, t) e) = T.Inj (i, apply s t) (apply s e)
-- | Lifts the single-branch instance to a list of case branches.
instance FreeVars [T.Inj] where
    free :: [T.Inj] -> Set Ident
    -- Union of the free variables of every branch.
    free = S.unions . map free
    -- Apply the substitution to each branch independently.
    apply s injs = [apply s inj | inj <- injs]
{- | Run an inference action in a context whose variable bindings have had
the given substitution applied to them.
-}
applySt :: Subst -> Infer a -> Infer a
applySt s action = local substituteVars action
  where
    substituteVars ctx = ctx{vars = apply s (vars ctx)}
@ -432,86 +435,85 @@ nullSubst :: Subst
nullSubst = M.empty
-- | Generate a new fresh variable and increment the state counter
fresh :: Infer Type
fresh :: Infer T.Type
fresh = do
n <- gets count
modify (\st -> st{count = n + 1})
return . TPol . Ident $ show n
return . T.TVar . T.MkTVar . Ident $ show n
-- | Run the monadic action with an additional binding
withBinding :: (Monad m, MonadReader Ctx m) => Ident -> Poly -> m a -> m a
withBinding :: (Monad m, MonadReader Ctx m) => Ident -> T.Type -> m a -> m a
withBinding i p = local (\st -> st{vars = M.insert i p (vars st)})
{- | Run the monadic action with every given binding added to the context.
Bindings are inserted left to right, so a later entry for the same
identifier replaces an earlier one.
-}
withBindings :: (Monad m, MonadReader Ctx m) => [(Ident, T.Type)] -> m a -> m a
withBindings xs = local extend
  where
    extend ctx = ctx{vars = foldl' insertBinding (vars ctx) xs}
    insertBinding acc (name, typ) = M.insert name typ acc
-- | Insert a function signature into the environment
insertSig :: Ident -> Type -> Infer ()
insertSig :: Ident -> Maybe T.Type -> Infer ()
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
-- | Insert a constructor with its data type
insertConstr :: Ident -> Type -> Infer ()
insertConstr :: Ident -> T.Type -> Infer ()
insertConstr i t =
modify (\st -> st{constructors = M.insert i t (constructors st)})
-------- PATTERN MATCHING ---------
-- "case expr of", the type of 'expr' is caseType
checkInj :: Type -> Inj -> Infer (T.Inj, Type)
checkInj caseType (Inj it expr) = do
(args, t') <- initType caseType it
(_, t, e') <- local (\st -> st{vars = args `M.union` vars st}) (algoW expr)
return (T.Inj (it, t') e', t)
{- | Type check the branches of a case expression.

The first argument is the type of the scrutinee. Every branch is checked with
'checkInj'; the pattern types are then unified one by one with the scrutinee
type, and the branch result types are unified with each other.

Returns the composed substitution, the typed branches, and the common result
type of the branches. A case expression without any branch is rejected with
an error instead of crashing on an empty list.
-}
checkCase :: T.Type -> [Inj] -> Infer (Subst, [T.Inj], T.Type)
checkCase expT injs = do
    (injTs, typedInjs, returns) <- unzip3 <$> mapM checkInj injs
    case returns of
        [] -> throwError "Case expression missing any matches"
        (firstReturn : otherReturns) -> do
            -- Unify each pattern type with the (progressively refined)
            -- scrutinee type.
            (sub1, _) <- foldM unifyStep (nullSubst, expT) injTs
            -- Unify the result types of all branches with each other.
            (sub2, returnsType) <- foldM unifyStep (nullSubst, firstReturn) otherReturns
            return (sub2 `compose` sub1, typedInjs, returnsType)
  where
    -- One fold step: unify the next type with the accumulated one, extend the
    -- substitution, and refine the accumulated type with the new unifier.
    unifyStep (sub, acc) x = do
        a <- unify x acc
        return (a `compose` sub, a `apply` acc)
initType :: Type -> Init -> Infer (Map Ident Poly, Type)
initType expected = \case
InitLit lit ->
let returnType = litType lit
in if expected == returnType
then return (mempty, expected)
else
throwError $
unwords
[ "Inferred type"
, printTree returnType
, "does not match expected type:"
, printTree expected
]
InitConstr c args -> do
st <- gets constructors
case M.lookup c st of
{- | Infer a single case branch.

The result is a triple: the type of the branch pattern, the typed branch, and
the type of the branch body. The variables bound by the pattern are in scope
while the body is inferred.
-}
checkInj :: Inj -> Infer (T.Type, T.Inj, T.Type)
checkInj (Inj pat body) = do
    (patT, bound) <- inferInit pat
    typedBody@(_, bodyT) <- withBindings bound (inferExp body)
    return (patT, T.Inj (pat, patT) typedBody, bodyT)
inferInit :: Init -> Infer (T.Type, [T.Id])
inferInit = \case
InitLit lit -> return (litType lit, mempty)
InitConstructor fn vars -> do
gets (M.lookup (coerce fn) . constructors) >>= \case
Nothing ->
throwError $
unwords
[ "Constructor:"
, printTree c
, "does not exist"
]
Just t -> do
let flat = flattenType t
let returnType = last flat
case ( length (init flat) == length args
, returnType `isMoreSpecificOrEq` expected
) of
(True, True) ->
return
( M.fromList $ zip args (map (Forall []) flat)
, expected
)
(False, _) ->
throwError $
"Can't partially match on the constructor: "
++ printTree c
(_, False) ->
throwError $
unwords
[ "Inferred type"
, printTree returnType
, "does not match expected type:"
, printTree expected
]
InitCatch -> return (mempty, expected)
"Constructor: " ++ printTree fn ++ " does not exist"
Just a -> do
case unsnoc $ flattenType a of
Nothing -> throwError "Partial pattern match not allowed"
Just (vs, ret) ->
case length vars `compare` length vs of
EQ -> do
return (ret, zip (coerce vars) vs)
_ -> throwError "Partial pattern match not allowed"
InitCatch -> (,mempty) <$> fresh
flattenType :: Type -> [Type]
flattenType (TArr a b) = flattenType a ++ flattenType b
flattenType a = [a]
{- | Split a curried function type into its argument types followed by its
result type, e.g. @a -> b -> c@ becomes @[a, b, c]@.

Only the right-hand spine is unfolded: an argument that is itself a function
stays a single element, so @(a -> b) -> c@ becomes @[a -> b, c]@. Flattening
the left-hand side too would miscount the arguments of a constructor that
takes a function.
-}
flattenType :: T.Type -> [T.Type]
flattenType (T.TFun a b) = a : flattenType b
flattenType a = [a]
litType :: Literal -> Type
litType (LInt _) = TMono "Int"
-- | The type of a literal: integers have type 'int', characters type 'char'.
litType :: Lit -> T.Type
litType (LInt _) = int
litType (LChar _) = char

-- | The literal type named @Int@.
int :: T.Type
int = T.TLit "Int"

-- | The literal type named @Char@.
char :: T.Type
char = T.TLit "Char"

View file

@ -2,28 +2,30 @@
module TypeChecker.TypeCheckerIr where
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Functor.Identity (Identity)
import Data.Map (Map)
import Grammar.Abs (Data (..), Ident (..), Init (..),
Literal (..), Type (..))
import Grammar.Print
import Prelude
import qualified Prelude as C (Eq, Ord, Read, Show)
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Functor.Identity (Identity)
import Data.Map (Map)
import Grammar.Abs (
Data (..),
Ident (..),
Init (..),
Lit (..),
)
import Grammar.Print
import Prelude
import Prelude qualified as C (Eq, Ord, Read, Show)
-- | A data type representing type variables
data Poly = Forall [Ident] Type
newtype Ctx = Ctx {vars :: Map Ident Type}
deriving (Show)
newtype Ctx = Ctx {vars :: Map Ident Poly}
data Env = Env
{ count :: Int
, sigs :: Map Ident Type
{ count :: Int
, sigs :: Map Ident (Maybe Type)
, constructors :: Map Ident Type
}
deriving (Show)
type Error = String
type Subst = Map Ident Type
@ -33,18 +35,33 @@ type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))
newtype Program = Program [Def]
deriving (C.Eq, C.Ord, C.Show, C.Read)
-- | A type variable, identified by its name.
newtype TVar = MkTVar Ident
deriving (Show, Eq, Ord, Read)
-- | The types manipulated by the type checker.
data Type
= TLit Ident -- ^ A named literal type, e.g. @TLit "Int"@
| TVar TVar -- ^ A type variable
| TFun Type Type -- ^ A function type: argument type, then result type
| TAll TVar Type -- ^ A type quantified over one type variable (printed with @forall@)
| TIndexed Indexed -- ^ A named type together with a list of type arguments
deriving (Show, Eq, Ord, Read)
data Exp
= EId Id
| ELit Type Literal
| ELet Bind Exp
| EApp Type Exp Exp
| EAdd Type Exp Exp
| ESub Type Exp Exp
| EAbs Type Id Exp
| ECase Type Exp [Inj]
= EId Ident
| ELit Lit
| ELet Bind ExpT
| EApp ExpT ExpT
| EAdd ExpT ExpT
| EAbs Ident ExpT
| ECase ExpT [Inj]
deriving (C.Eq, C.Ord, C.Read, C.Show)
data Inj = Inj (Init, Type) Exp
type ExpT = (Exp, Type)
data Indexed = Indexed Ident [Type]
deriving (Show, Read, Ord, Eq)
data Inj = Inj (Init, Type) ExpT
deriving (C.Eq, C.Ord, C.Read, C.Show)
data Def = DBind Bind | DData Data
@ -52,22 +69,22 @@ data Def = DBind Bind | DData Data
type Id = (Ident, Type)
data Bind = Bind Id Exp
data Bind = Bind Id [Id] ExpT
deriving (C.Eq, C.Ord, C.Show, C.Read)
instance Print [Def] where
prt _ [] = concatD []
prt _ (x : xs) = concatD [prt 0 x, doc (showString "\n"), prt 0 xs]
prt _ [] = concatD []
prt _ (x : xs) = concatD [prt 0 x, doc (showString "\n\n"), prt 0 xs]
instance Print Def where
prt i (DBind bind) = prt i bind
prt i (DData d) = prt i d
prt i (DData d) = prt i d
instance Print Program where
prt i (Program sc) = prPrec i 0 $ prt 0 sc
instance Print Bind where
prt i (Bind (t, name) rhs) =
prt i (Bind (name, t) _ rhs) =
prPrec i 0 $
concatD
[ prt 0 name
@ -91,9 +108,11 @@ prtId :: Int -> Id -> Doc
prtId i (name, t) =
prPrec i 0 $
concatD
[ prt 0 name
[ doc $ showString "("
, prt 0 name
, doc $ showString ":"
, prt 0 t
, doc $ showString ")"
]
prtIdP :: Int -> Id -> Doc
@ -109,8 +128,8 @@ prtIdP i (name, t) =
instance Print Exp where
prt i = \case
EId n -> prPrec i 3 $ concatD [prtId 0 n, doc $ showString "\n"]
ELit _ (LInt i1) -> prPrec i 3 $ concatD [prt 0 i1, doc $ showString "\n"]
EId n -> prPrec i 3 $ concatD [prt 0 n]
ELit lit -> prPrec i 3 $ concatD [prt 0 lit]
ELet bs e ->
prPrec i 3 $
concatD
@ -118,46 +137,30 @@ instance Print Exp where
, prt 0 bs
, doc $ showString "in"
, prt 0 e
, doc $ showString "\n"
]
EApp _ e1 e2 ->
EApp e1 e2 ->
prPrec i 2 $
concatD
[ prt 2 e1
, prt 3 e2
]
EAdd t e1 e2 ->
EAdd e1 e2 ->
prPrec i 1 $
concatD
[ doc $ showString "@"
, prt 0 t
, prt 1 e1
, doc $ showString "+"
, prt 2 e2
, doc $ showString "\n"
]
ESub t e1 e2 ->
prPrec i 1 $
concatD
[ doc $ showString "@"
, prt 0 t
, prt 1 e1
, doc $ showString "-"
, prt 2 e2
, doc $ showString "\n"
]
EAbs t n e ->
EAbs n e ->
prPrec i 0 $
concatD
[ doc $ showString "@"
, prt 0 t
, doc $ showString "\\"
, prtId 0 n
[ doc $ showString "λ"
, prt 0 n
, doc $ showString "."
, prt 0 e
, doc $ showString "\n"
]
ECase t exp injs ->
ECase exp injs ->
prPrec
i
0
@ -169,16 +172,31 @@ instance Print Exp where
, prt 0 injs
, doc (showString "}")
, doc (showString ":")
, prt 0 t
, doc $ showString "\n"
]
)
{- | A typed expression is printed inside parentheses, with a colon between
the expression and its type.
-}
instance Print ExpT where
    prt i (e, t) =
        concatD
            [ symbol "("
            , prt i e
            , symbol ":"
            , prt i t
            , symbol ")"
            ]
      where
        symbol = doc . showString
{- | A case branch is printed as its pattern, a colon, the pattern type,
@=>@ and the typed branch body.
-}
instance Print Inj where
    prt i (Inj (pat, patT) body) =
        prPrec i 0 $
            concatD
                [ prt 0 pat
                , doc (showString ":")
                , prt 0 patT
                , doc (showString "=>")
                , prt 0 body
                ]
instance Print [Inj] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
-- | A type variable is printed as the identifier it wraps.
instance Print TVar where
    prt i (MkTVar name) = prt i name
{- | Pretty printer for types. Literal types and type variables are printed
at precedence 2, quantified and indexed types at precedence 1, and function
types at precedence 0.
-}
instance Print Type where
    prt i (TLit uident) = prPrec i 2 $ concatD [prt 0 uident]
    prt i (TVar tvar) = prPrec i 2 $ concatD [prt 0 tvar]
    prt i (TAll tvar body) =
        prPrec i 1 $
            concatD
                [ doc (showString "forall")
                , prt 0 tvar
                , doc (showString ".")
                , prt 0 body
                ]
    prt i (TIndexed indexed) = prPrec i 1 $ concatD [prt 0 indexed]
    prt i (TFun argT resT) =
        prPrec i 0 $
            concatD
                [ prt 1 argT
                , doc (showString "->")
                , prt 0 resT
                ]
{- | An indexed type is printed as its name followed by its type arguments.
No parentheses are emitted around the argument list here.
-}
instance Print Indexed where
    prt i (Indexed name args) = concatD (prt i name : [prt i args])

4
test_program Normal file
View file

@ -0,0 +1,4 @@
data Maybe (a) where {
Nothing : Maybe (a)
Just : a -> Maybe (a)
}

110
tests/Tests.hs Normal file
View file

@ -0,0 +1,110 @@
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use <$>" #-}
{-# HLINT ignore "Use camelCase" #-}
module Main where
import Data.Either (isLeft, isRight)
import Data.Map (Map)
import Data.Map qualified as M
import Grammar.Abs
import Test.Hspec
import Test.QuickCheck
import TypeChecker.TypeChecker
import TypeChecker.TypeCheckerIr (
Ctx (..),
Env (..),
Error,
Infer,
Poly (..),
)
import TypeChecker.TypeCheckerIr qualified as T
-- | Entry point: runs every spec group of the type checker test suite in order.
main :: IO ()
main =
    hspec $
        sequence_
            [ infer_elit
            , infer_eann
            , infer_eid
            , infer_eabs
            , test_id_function
            ]
-- | Integer literals must be inferred to have the literal type @Int@.
infer_elit :: SpecWith ()
infer_elit =
    describe "algoW used on ELit" $
        mapM_ infersInt [0, 9999]
  where
    -- One example per literal value; both share the same description.
    infersInt n =
        it "infers the type mono Int" $
            getType (ELit (LInt n)) `shouldBe` Right (T.TLit "Int")
{- | Type annotations: an annotation is accepted when it equals, or is more
specific than, the inferred type, and rejected otherwise.
-}
infer_eann :: SpecWith ()
infer_eann = describe "algoW used on EAnn" $ do
    it "infers the type and checks if the annotated type matches" $ do
        getType (EAnn (ELit $ LInt 0) (TLit "Int")) `shouldBe` Right (T.TLit "Int")
    it "fails if the annotated type does not match with the inferred type" $ do
        getType (EAnn (ELit $ LInt 0) (TVar $ MkTVar "a")) `shouldSatisfy` isLeft
    it "should be possible to annotate with a more specific type" $ do
        let annotatedLambda = EAnn (EAbs "x" (EVar "x")) (TFun (TLit "Int") (TLit "Int"))
        getType annotatedLambda `shouldBe` Right (T.TFun (T.TLit "Int") (T.TLit "Int"))
    it "should fail if the annotated type is more general than the inferred type" $ do
        getType (EAnn (ELit (LInt 0)) (TVar $ MkTVar "a")) `shouldSatisfy` isLeft
    -- The lambda is inferred as an arrow while the annotation is a bare type
    -- variable, so the description contrasts inferred with annotated.
    it "should fail if the inferred type is an arrow but the annotated type is not" $ do
        getType (EAnn (EAbs "x" (EVar "x")) (TVar $ MkTVar "a")) `shouldSatisfy` isLeft
{- | Variable lookup: an unbound variable is an error, a variable bound in
the context gets the type stored there.
-}
infer_eid :: SpecWith ()
infer_eid = describe "algoW used on EVar" $ do
    it "should fail if the variable is not added to the environment" $
        property unboundFails
    it "should succeed if the type exist in the environment" $
        property boundSucceeds
  where
    unboundFails :: String -> Expectation
    unboundFails x = getType (EVar (LIdent x)) `shouldSatisfy` isLeft

    boundSucceeds :: String -> Expectation
    boundSucceeds x =
        getTypeC emptyEnv (Ctx (M.singleton (Ident x) tyA)) (EVar (LIdent x))
            `shouldBe` Right tyA

    -- Counter at zero, no signatures and no constructors.
    emptyEnv = Env 0 mempty mempty
    tyA = T.TVar $ T.MkTVar "a"
-- | Lambda abstractions: the argument type follows from how the body uses it.
infer_eabs :: SpecWith ()
infer_eabs = describe "algoW used on EAbs" $ do
    it "should infer the argument type as int if the variable is used as an int" $
        getType (EAbs "x" (EAdd (EVar "x") (ELit (LInt 0))))
            `shouldBe` Right (T.TFun (T.TLit "Int") (T.TLit "Int"))
    it "should infer the argument type as polymorphic if it is not used in the lambda" $
        getType (EAbs "x" (ELit (LInt 0))) `shouldSatisfy` isArrowPolyToMono
    it "should infer a variable as function if used as one" $
        getType (EAbs "f" (EAbs "x" (EApp (EVar "f") (EVar "x"))))
            `shouldSatisfy` isFunctionToFunction
  where
    -- Accepts exactly the shape (v -> v) -> (v -> v), for any type variables.
    isFunctionToFunction (Right (T.TFun (T.TFun (T.TVar _) (T.TVar _)) (T.TFun (T.TVar _) (T.TVar _)))) = True
    isFunctionToFunction _ = False
-- | Fixture: @id x = x@.
churf_id :: Bind
churf_id = Bind "id" ["x"] $ EVar "x"

-- | Fixture: @add x y = x + y@.
churf_add :: Bind
churf_add = Bind "add" ["x", "y"] $ EAdd (EVar "x") (EVar "y")

-- | Fixture: @main = id add 0@.
churf_main :: Bind
churf_main = Bind "main" [] $ EApp idOfAdd (ELit (LInt 0))
  where
    idOfAdd = EApp (EVar "id") (EVar "add")

-- | Fixture program containing @main@, @add@ and @id@, in that order.
prg :: Program
prg = Program $ map DBind [churf_main, churf_add, churf_id]
-- | Whole-program check: the fixture program 'prg' must type check.
test_id_function :: SpecWith ()
test_id_function =
    describe "typechecking a program with id, add and main, where id is applied to add in main" $
        it "should succeed to find the correct type" $
            typecheck prg `shouldSatisfy` isRight
{- | True only for a successful result whose type is a function from a type
variable to a literal type.
-}
isArrowPolyToMono :: Either Error T.Type -> Bool
isArrowPolyToMono result = case result of
    Right (T.TFun (T.TVar _) (T.TLit _)) -> True
    _ -> False
{- | Infer an expression with 'run' (empty environment) and keep only the
inferred type, dropping the typed expression.
-}
getType :: Exp -> Either Error T.Type
getType = fmap snd . run . inferExp
{- | Infer an expression with 'runC' under the given environment and context,
keeping only the inferred type.
-}
getTypeC :: Env -> Ctx -> Exp -> Either Error T.Type
getTypeC env ctx = fmap snd . runC env ctx . inferExp