(* Theory JVMDefensive *)

(*  Title:      HOL/MicroJava/JVM/JVMDefensive.thy
    Author:     Gerwin Klein
*)

section ‹A Defensive JVM›

theory JVMDefensive
imports JVMExec
begin

text ‹
  Extend the state space by one additional element, which indicates a type
  error (or some other abnormal termination).›
(* Wrapper type: either the distinguished value TypeError, or an ordinary
   value of type 'a tagged with Normal. *)
datatype 'a type_error =
    TypeError
  | Normal 'a


(* Projection onto the fifth component of a (right-nested) tuple that has
   at least six components. *)
abbreviation
  fifth :: "'a × 'b × 'c × 'd × 'e × 'f ⇒ 'e"
  where "fifth x == fst(snd(snd(snd(snd x))))"

(* isAddr v holds exactly when v is built with the Addr constructor. *)
fun isAddr :: "val ⇒ bool" where
  "isAddr (Addr loc) = True"
| "isAddr v          = False"

(* isIntg v holds exactly when v is built with the Intg constructor. *)
fun isIntg :: "val ⇒ bool" where
  "isIntg (Intg i) = True"
| "isIntg v        = False"

(* A value is a reference if it is Null or an address (see isAddr). *)
definition isRef :: "val ⇒ bool" where
  "isRef v ≡ v = Null ∨ isAddr v"

(* check_instr i G hp stk vars Cl sig pc mxs frs

   Per-instruction applicability check of the defensive machine.
   Arguments, in order: the instruction, the program G, the heap hp, the
   operand stack stk, the local variables vars, the current class name and
   method signature, the program counter pc, the stack size bound mxs, and
   the remaining frames frs.  The result is True iff the side conditions
   listed for the instruction hold in that configuration. *)
primrec check_instr :: "[instr, jvm_prog, aheap, opstack, locvars, 
                  cname, sig, p_count, nat, frame list] ⇒ bool" where
  (* Load: index inside the local variables, stack still below mxs. *)
  "check_instr (Load idx) G hp stk vars C sig pc mxs frs = 
  (idx < length vars ∧ size stk < mxs)"

  (* Store: non-empty stack, index inside the local variables. *)
| "check_instr (Store idx) G hp stk vars Cl sig pc mxs frs = 
  (0 < length stk ∧ idx < length vars)"

  (* LitPush: the literal must not be an address; stack below mxs. *)
| "check_instr (LitPush v) G hp stk vars Cl sig pc mxs frs = 
  (¬isAddr v ∧ size stk < mxs)"

  (* New: C is a class of G; stack below mxs. *)
| "check_instr (New C) G hp stk vars Cl sig pc mxs frs = 
  (is_class G C ∧ size stk < mxs)"

  (* Getfield: non-empty stack, C is a class declaring field F itself
     (C' = C), and the top of stack is a reference.  If that reference is
     not Null it must point to a heap object whose class is a subclass of
     C and whose field (F,C) holds a value conforming to the field type. *)
| "check_instr (Getfield F C) G hp stk vars Cl sig pc mxs frs = 
  (0 < length stk ∧ is_class G C ∧ field (G,C) F ≠ None ∧ 
  (let (C', T) = the (field (G,C) F); ref = hd stk in 
    C' = C ∧ isRef ref ∧ (ref ≠ Null ⟶ 
      hp (the_Addr ref) ≠ None ∧ 
      (let (D,vs) = the (hp (the_Addr ref)) in 
        G ⊢ D ≼C C ∧ vs (F,C) ≠ None ∧ G,hp ⊢ the (vs (F,C)) ::≼ T))))" 

  (* Putfield: at least two stack entries (value on top, reference below),
     C is a class declaring field F itself, and the reference is a
     reference value.  If it is not Null it must point to a heap object
     whose class is a subclass of C, and the value must conform to the
     field type. *)
| "check_instr (Putfield F C) G hp stk vars Cl sig pc mxs frs = 
  (1 < length stk ∧ is_class G C ∧ field (G,C) F ≠ None ∧ 
  (let (C', T) = the (field (G,C) F); v = hd stk; ref = hd (tl stk) in 
    C' = C ∧ isRef ref ∧ (ref ≠ Null ⟶ 
      hp (the_Addr ref) ≠ None ∧ 
      (let (D,vs) = the (hp (the_Addr ref)) in 
        G ⊢ D ≼C C ∧ G,hp ⊢ v ::≼ T))))" 

  (* Checkcast: non-empty stack, C is a class, top of stack is a reference. *)
| "check_instr (Checkcast C) G hp stk vars Cl sig pc mxs frs =
  (0 < length stk ∧ is_class G C ∧ isRef (hd stk))"

  (* Invoke: the stack holds the arguments plus the receiver, which sits at
     position length ps and must be a reference.  If the receiver is not
     Null it must be allocated, its class must have a method (mn,ps), and
     the arguments (taken from the stack in reverse order) must conform to
     the parameter types ps. *)
| "check_instr (Invoke C mn ps) G hp stk vars Cl sig pc mxs frs =
  (length ps < length stk ∧ 
  (let n = length ps; v = stk!n in
  isRef v ∧ (v ≠ Null ⟶ 
    hp (the_Addr v) ≠ None ∧
    method (G,cname_of hp v) (mn,ps) ≠ None ∧
    list_all2 (λv T. G,hp ⊢ v ::≼ T) (rev (take n stk)) ps)))"
  
  (* Return: non-empty stack.  If there is a caller frame, the current
     method must exist, be declared in the current class, and the value on
     top of the stack must conform to its return type. *)
| "check_instr Return G hp stk0 vars Cl sig0 pc mxs frs =
  (0 < length stk0 ∧ (0 < length frs ⟶ 
    method (G,Cl) sig0 ≠ None ∧    
    (let v = hd stk0;  (C, rT, body) = the (method (G,Cl) sig0) in
    Cl = C ∧ G,hp ⊢ v ::≼ rT)))"
 
  (* Pop: non-empty stack. *)
| "check_instr Pop G hp stk vars Cl sig pc mxs frs = 
  (0 < length stk)"

  (* Dup: non-empty stack, stack below mxs. *)
| "check_instr Dup G hp stk vars Cl sig pc mxs frs = 
  (0 < length stk ∧ size stk < mxs)"

  (* Dup_x1: at least two stack entries, stack below mxs. *)
| "check_instr Dup_x1 G hp stk vars Cl sig pc mxs frs = 
  (1 < length stk ∧ size stk < mxs)"

  (* Dup_x2: at least three stack entries, stack below mxs. *)
| "check_instr Dup_x2 G hp stk vars Cl sig pc mxs frs = 
  (2 < length stk ∧ size stk < mxs)"

  (* Swap: at least two stack entries. *)
| "check_instr Swap G hp stk vars Cl sig pc mxs frs =
  (1 < length stk)"

  (* IAdd: the two topmost stack entries exist and are integers. *)
| "check_instr IAdd G hp stk vars Cl sig pc mxs frs =
  (1 < length stk ∧ isIntg (hd stk) ∧ isIntg (hd (tl stk)))"

  (* Ifcmpeq: two stack entries, and the branch target pc+b is not negative. *)
| "check_instr (Ifcmpeq b) G hp stk vars Cl sig pc mxs frs =
  (1 < length stk ∧ 0 ≤ int pc+b)"

  (* Goto: the branch target pc+b is not negative. *)
| "check_instr (Goto b) G hp stk vars Cl sig pc mxs frs =
  (0 ≤ int pc+b)"

  (* Throw: non-empty stack with a reference on top. *)
| "check_instr Throw G hp stk vars Cl sig pc mxs frs =
  (0 < length stk ∧ isRef (hd stk))"

(* check G s: state-level check.  A state with no frames passes
   trivially.  Otherwise the method of the topmost frame is looked up in G,
   the program counter must lie inside its instruction list, and the
   instruction at pc must satisfy check_instr for the frame's stack and
   local variables, the method's stack bound mxs and the remaining
   frames. *)
definition check :: "jvm_prog ⇒ jvm_state ⇒ bool" where
  "check G s ≡ let (xcpt, hp, frs) = s in
               (case frs of [] ⇒ True | (stk,loc,C,sig,pc)#frs' ⇒ 
                (let  (C',rt,mxs,mxl,ins,et) = the (method (G,C) sig); i = ins!pc in
                 pc < size ins ∧ 
                 check_instr i G hp stk loc C sig pc mxs frs'))"


(* exec_d G s: one step of the defensive machine.  TypeError is
   propagated unchanged.  A Normal state is first tested with check: if it
   passes, the result of exec is returned wrapped in Normal, otherwise the
   step yields TypeError. *)
definition exec_d :: "jvm_prog ⇒ jvm_state type_error ⇒ jvm_state option type_error" where
  "exec_d G s ≡ case s of 
      TypeError ⇒ TypeError 
    | Normal s' ⇒ if check G s' then Normal (exec (G, s')) else TypeError"


(* G ⊢ s ─jvmd→ t: multi-step execution of the defensive machine, defined
   as the reflexive transitive closure of the union of two step relations:
   - s steps to TypeError when exec_d G s is TypeError;
   - s steps to Normal t' when exec_d G s is Normal (Some t'). *)
definition
  exec_all_d :: "jvm_prog ⇒ jvm_state type_error ⇒ jvm_state type_error ⇒ bool" 
                   ("_ ⊢ _ ─jvmd→ _" [61,61,61]60) where
  "G ⊢ s ─jvmd→ t ⟷
         (s,t) ∈ ({(s,t). exec_d G s = TypeError ∧ t = TypeError} ∪
                  {(s,t). ∃t'. exec_d G s = Normal (Some t') ∧ t = Normal t'})⇧*"


(* Remove the automatic splitting of quantifiers over pairs from the
   simpset for the proofs that follow. *)
declare
  split_paired_All [simp del]
  split_paired_Ex [simp del]

(* If a conditional differs from its else-branch, the condition must hold.
   Declared as an (eager) destruction rule. *)
lemma [dest!]:
  "(if P then A else B) ≠ B ⟹ P"
  by (cases P, auto)

(* A state that passes check does not produce a type error in one
   defensive step. *)
lemma exec_d_no_errorI [intro]:
  "check G s ⟹ exec_d G (Normal s) ≠ TypeError"
  by (unfold exec_d_def) simp

(* Whenever a defensive step does not end in TypeError, it agrees with
   the step function exec of the underlying machine. *)
theorem no_type_error_commutes:
  "exec_d G (Normal s) ≠ TypeError ⟹ 
  exec_d G (Normal s) = Normal (exec (G, s))"
  by (unfold exec_d_def, auto)


(* Every defensive execution between two Normal states is also an
   execution of the ─jvm→ relation (exec_all).  The auxiliary statement
   generalises over arbitrary start and end states so that induction over
   the reflexive transitive closure goes through. *)
lemma defensive_imp_aggressive:
  "G ⊢ (Normal s) ─jvmd→ (Normal t) ⟹ G ⊢ s ─jvm→ t"
proof -
  have "⋀x y. G ⊢ x ─jvmd→ y ⟹ ∀s t. x = Normal s ⟶ y = Normal t ⟶  G ⊢ s ─jvm→ t"
    apply (unfold exec_all_d_def)
    apply (erule rtrancl_induct)
     apply (simp add: exec_all_def)
    apply (fold exec_all_d_def)
    apply simp
    apply (intro allI impI)
    apply (erule disjE, simp)
    apply (elim exE conjE)
    apply (erule allE, erule impE, assumption)
    apply (simp add: exec_all_def exec_d_def split: type_error.splits if_split_asm)
    apply (rule rtrancl_trans, assumption)
    apply blast
    done
  moreover
  assume "G ⊢ (Normal s) ─jvmd→ (Normal t)" 
  ultimately
  show "G ⊢ s ─jvm→ t" by blast
qed

end