(**************************************************************************)
(*     ARM/Power Multiprocessor Machine Code Semantics: HOL sources       *)
(*                                                                        *)
(*                                                                        *)
(*  Jade Alglave (2), Anthony Fox (1), Samin Isthiaq (3),                 *)
(*  Magnus Myreen (1), Susmit Sarkar (1), Peter Sewell (1),               *)
(*  Francesco Zappa Nardelli (2)                                          *)
(*                                                                        *)
(*   (1) Computer Laboratory, University of Cambridge                     *)
(*   (2) Moscova project, INRIA Paris-Rocquencourt                        *)
(*   (3) Microsoft Research Cambridge                                     *)
(*                                                                        *)
(*     Copyright 2007-2008                                                *)
(*                                                                        *)
(*  Redistribution and use in source and binary forms, with or without    *)
(*  modification, are permitted provided that the following conditions    *)
(*  are met:                                                              *)
(*                                                                        *)
(*  1. Redistributions of source code must retain the above copyright     *)
(*     notice, this list of conditions and the following disclaimer.      *)
(*  2. Redistributions in binary form must reproduce the above copyright  *)
(*     notice, this list of conditions and the following disclaimer in    *)
(*     the documentation and/or other materials provided with the         *)
(*     distribution.                                                      *)
(*  3. The names of the authors may not be used to endorse or promote     *)
(*     products derived from this software without specific prior         *)
(*     written permission.                                                *)
(*                                                                        *)
(*  THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND ANY EXPRESS    *)
(*  OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED     *)
(*  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE    *)
(*  ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY       *)
(*  DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL    *)
(*  DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE     *)
(*  GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS         *)
(*  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,          *)
(*  WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING             *)
(*  NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS    *)
(*  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.          *)
(*                                                                        *)
(**************************************************************************)

open HolKernel boolLib bossLib Parse;
open wordsLib;

open arm_astTheory;

open HolDoc;
(* Start the theory segment; everything defined or saved below is recorded
   under the theory name arm_seq_monad. *)
val _ = new_theory "arm_seq_monad";


(* The machine state threaded through the sequential monad: a six-component
   tuple.  The comment next to each component gives its role. *)
val _ = type_abbrev("arm_state", (*  state = tuple consisting of: *)
  ``: (ARMreg -> word32) #       (*  - general-purpose registers *)
      (ARMpsr -> ARMstatus) #    (*  - program-status registers *)
      ARMcp_registers #          (*  - co-processor registers *)
      (word32 -> word8 option) # (*  - unsegmented memory *)
      ARMinfo #                  (*  - info on ISA version and extensions *)
      ExclusiveMonitors          (*  - for synchronization & semaphores *) ``);

(* A generic state pattern naming the six components r, p, c, m, v and e.
   It is antiquoted into the accessor definitions below and is also the
   tuple shape handed to tupleCases by Cases_on_arm_state. *)
val arm_state = ``(r,p,c,m,v,e):arm_state``;

(* functions for reading/writing state *)

(* State projections.  Each one takes a state and returns one component:
   AREAD_REG, AREAD_PSR and AREAD_MEM apply the function-valued component
   (registers r, status registers p, memory m) to the index x, while
   AREAD_CP, AREAD_INFO and AREAD_EXCL return the whole component (c, v, e). *)
val AREAD_REG_def  = Define `AREAD_REG x ^arm_state = r x`;
val AREAD_PSR_def  = Define `AREAD_PSR x ^arm_state = p x`;
val AREAD_CP_def   = Define `AREAD_CP    ^arm_state = c`;
val AREAD_MEM_def  = Define `AREAD_MEM x ^arm_state = m x`;
val AREAD_INFO_def = Define `AREAD_INFO  ^arm_state = v`;
val AREAD_EXCL_def = Define `AREAD_EXCL  ^arm_state = e`;

(* State updates.  AWRITE_REG, AWRITE_PSR and AWRITE_MEM override the
   function-valued component at index i with the value x (using =+) and
   leave every other component unchanged.  AWRITE_EXCL replaces the
   exclusive-monitor component outright.  No update function is defined
   here for the co-processor or ISA-info components. *)
val AWRITE_REG_def = Define`
  AWRITE_REG i x ^arm_state = ((i =+ x) r,p,c,m,v,e):arm_state`;

val AWRITE_PSR_def = Define`
  AWRITE_PSR i x ^arm_state = (r,(i =+ x) p,c,m,v,e):arm_state`;

val AWRITE_MEM_def = Define`
  AWRITE_MEM i x ^arm_state = (r,p,c,(i =+ x) m,v,e):arm_state`;

val AWRITE_EXCL_def = Define`
  AWRITE_EXCL x ^arm_state = (r,p,c,m,v,x):arm_state`;

(* tupleCases M: M is expected to be a tuple of variables.  Proves, and
   generalises over the remaining free variable, the statement that an
   arbitrary value x of M's type is equal to M for some choice of the
   component variables, i.e.  |- !x. ?v1 ... vn. x = M.  The name x is
   primed if needed so that it does not clash with a component of M. *)
fun tupleCases M =
  let
    val components = pairSyntax.strip_pair M
    val whole = variant components (mk_var ("x", type_of M))
    val goal = list_mk_exists (components, mk_eq (whole, M))
    val shape_thm = METIS_PROVE [pairTheory.ABS_PAIR_THM] goal
  in
    GEN_ALL shape_thm
  end;

(* Tactic: split the state denoted by the quotation tm into its six named
   components, substituting the tuple form throughout the goal and its
   assumptions. *)
fun Cases_on_arm_state tm =
  let
    val split_thm = Q.SPEC tm (tupleCases arm_state)
  in
    FULL_STRUCT_CASES_TAC split_thm
  end;

(* ------------------------------------------------------------------------ *>

  We define a state and monads for constructing a sequential version of the
  semantics.

<* ------------------------------------------------------------------------ *)

(* val _ = type_abbrev("Astate",``:arm_state -> ('a # arm_state) option``); *)

(* 'a M is the type of a sequential computation: applied to a state it
   either fails (NONE) or yields a result of type 'a paired with the
   successor state. *)
val _ = type_abbrev("M",``:arm_state -> ('a # arm_state) option``);


(* sequential monads for an option state *)

(* constT_seq x: succeed with result x and leave the state untouched. *)
val constT_seq_def = Define `
  (constT_seq: 'a -> 'a M) x = \y. SOME (x,y)`;

(* failureT_seq: the computation that fails (returns NONE) on every state. *)
val failureT_seq_def = Define `
  (failureT_seq: 'a M) = \y. NONE`;

(* bindT_seq s f: run s on the incoming state.  If s fails, the whole
   computation fails; otherwise its result z is passed to f, which is then
   run from the state t that s produced. *)
val bindT_seq_def = Define `
  (bindT_seq: 'a M -> ('a -> 'b M) -> 'b M) s f =
    \y. case s y of NONE -> NONE || SOME (z,t) -> f z t`;

(* seqT_seq s f: run s, ignore its result, then run f from the state that
   s produced.  Fails if either step fails. *)
val seqT_seq_def = Define `
  (seqT_seq: 'a M -> 'b M -> 'b M) s f =
    bindT_seq s (\x. f)`;

(* parT_seq s t: return the pair of the results of s and t.  In this
   sequential model the two are not interleaved: s runs first and t runs
   from the state that s produced. *)
val parT_seq_def = Define `
  (parT_seq: 'a M -> 'b M -> ('a # 'b) M) s t =
    bindT_seq s (\x. bindT_seq t (\y. constT_seq (x,y)))`;

(* lockT_seq s: the identity on computations in the sequential model. *)
val lockT_seq_def = Define `
  (lockT_seq: 'a M -> 'a M) s = s`;

(* condT_seq b s: run s when b holds; otherwise succeed with unit and leave
   the state unchanged. *)
val condT_seq_def = Define`
  (condT_seq: bool -> unit M -> unit M) b s =
    if b then s else constT_seq ()`;

(* discardT_seq s: run s for its effect on the state and replace its result
   by unit. *)
val discardT_seq_def = Define `
  (discardT_seq: 'a M -> unit M) s =
    seqT_seq s (constT_seq ())`;

(* addT_seq x s: run s and return its result paired with the constant x
   (x first). *)
val addT_seq_def = Define `
  (addT_seq: 'a -> 'b M -> ('a # 'b) M) x s =
    bindT_seq s (\z. constT_seq (x,z))`;

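(* general-purpose register reads/writes always succeed; writes to a status
   register and reads of the processor mode fail when the mode field does
   not decode. *)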

(* write_reg_seq ii r x: store x in general-purpose register r.  Never
   fails.  The instruction identifier ii is not used. *)
val write_reg_seq_def = Define `(write_reg_seq ii r x):unit M =
  \s. SOME ((),AWRITE_REG r x s)`;

(* read_reg_seq ii r: return the contents of register r without changing
   the state.  Never fails.  A read of r15 returns the stored value plus 8,
   regardless of the current instruction set.
   NOTE(review): the +8 presumably models the ARM-state program-counter
   read offset; confirm this is intended for the other instruction sets
   reported by read_instr_set_seq. *)
val read_reg_seq_def = Define `(read_reg_seq ii r):Aimm M =
  \s. SOME (let x = AREAD_REG r s in if r = r15 then x + 8w else x,s)`;

(* write_psr_seq ii r x: store the status value x in status register r.
   Fails when word5_to_mode does not decode the mode field x.M, so a status
   value with an undecodable mode is never written.  The decoded mode
   itself is not otherwise used. *)
val write_psr_seq_def = Define`
  (write_psr_seq ii r x):unit M =
    \s. case word5_to_mode x.M of
           NONE -> NONE
        || SOME m -> SOME ((),AWRITE_PSR r x s)`;

(* read_psr_seq ii r: return the contents of status register r without
   changing the state.  Never fails. *)
val read_psr_seq_def = Define `(read_psr_seq ii r):ARMstatus M =
  \s. SOME (AREAD_PSR r s,s)`;

(* write_flags_seq ii (n,z,c,v): overwrite the N, Z, C and V fields of the
   CPSR with the given booleans, keeping every other CPSR field.  Never
   fails (the mode field is not re-validated here). *)
val write_flags_seq_def = Define`
  (write_flags_seq ii (n,z,c,v)):unit M =
  \s. SOME ((),
        let cpsr = AREAD_PSR CPSR s with <| N := n; Z := z; C := c; V := v |> in
          AWRITE_PSR CPSR cpsr s)`;

(* read_flags_seq ii: return the CPSR fields (N,Z,C,V) as a 4-tuple without
   changing the state.  Never fails. *)
val read_flags_seq_def = Define `(read_flags_seq ii):(bool#bool#bool#bool) M =
  \s. SOME (let cpsr = AREAD_PSR CPSR s in (cpsr.N,cpsr.Z,cpsr.C,cpsr.V), s)`;

(* read_endian_seq ii: return the E field of the CPSR without changing the
   state.  Never fails.  The multi-byte memory writes below use this value
   to choose the byte order. *)
val read_endian_seq_def = Define `(read_endian_seq ii):bool M =
  \s. SOME ((AREAD_PSR CPSR s).E,s)`;

(* read_mode_seq ii: decode the mode field of the CPSR with word5_to_mode
   and return the resulting mode; the state is unchanged.  Fails when the
   field does not decode. *)
val read_mode_seq_def = Define `(read_mode_seq ii):ARMmode M =
  \s. case word5_to_mode (AREAD_PSR CPSR s).M of
         NONE -> NONE
      || SOME m -> SOME (m,s)`;

(* read_instr_set_seq ii: determine the current instruction set from the J
   and T fields of the CPSR; the state is unchanged and the read never
   fails.  (J,T) = (F,F) is ARM, (F,T) is Thumb, (T,F) is Jazelle and
   (T,T) is ThumbEE. *)
val read_instr_set_seq_def = Define `(read_instr_set_seq ii):InstrSet M =
  \s. SOME (let cpsr = AREAD_PSR CPSR s in
        (case (cpsr.J, cpsr.T) of
            (F,F) -> InstrSet_ARM
         || (F,T) -> InstrSet_Thumb
         || (T,F) -> InstrSet_Jazelle
         || (T,T) -> InstrSet_ThumbEE), s)`;

(* read_sctlr_seq ii: return the SCTLR field of the co-processor register
   component without changing the state.  Never fails. *)
val read_sctlr_seq_def = Define `(read_sctlr_seq ii):ARMsctlr M =
  \s. SOME ((AREAD_CP s).SCTLR,s)`;

(* read ISA info. *)

(* read_info_seq ii: return the ISA-info component of the state unchanged.
   Never fails. *)
val read_info_seq_def = Define `(read_info_seq ii):ARMinfo M =
  \s. SOME (AREAD_INFO s,s)`;

(* read_version_seq ii: return version_number applied to the version field
   of the ISA-info component, as a natural number.  The state is unchanged
   and the read never fails. *)
val read_version_seq_def = Define `(read_version_seq ii):num M =
  \s. SOME (version_number (AREAD_INFO s).version,s)`;

(* exclusive monitor operations. *)

(* set_exclusive_monitorsT_seq ii (a,size): record an exclusive access by
   ii to the region described by address a and size.

   The address is translated with the monitor's TranslateAddress, passing
   ~(m = usr) for the current mode m (presumably the privileged flag --
   the record type is declared elsewhere) and F.  The local monitor is
   always updated through MarkExclusiveLocal; the global monitor is updated
   through MarkExclusiveGlobal only when the translated address is marked
   shareable, and is otherwise kept as it was.

   Fails when the CPSR mode field does not decode. *)
val set_exclusive_monitorsT_seq_def = Define`
  (set_exclusive_monitorsT_seq ii (a:word32,size:num)) : unit M =
  (\s. case word5_to_mode (AREAD_PSR CPSR s).M of
         NONE -> NONE
      || SOME m -> SOME ((),
           let monitor = AREAD_EXCL s in
           let memaddrdesc = monitor.TranslateAddress(a, ~(m = usr), F)
           in
             AWRITE_EXCL
               (monitor with
                  <| IsExclusiveGlobal :=
                      (if memaddrdesc.memattrs.shareable then
                         monitor.MarkExclusiveGlobal
                           (memaddrdesc.paddress,ii,size)
                           monitor.IsExclusiveGlobal
                       else monitor.IsExclusiveGlobal);
                     IsExclusiveLocal :=
                       monitor.MarkExclusiveLocal (memaddrdesc.paddress,ii,size)
                         monitor.IsExclusiveLocal |>) s))`;

(* exclusive_monitors_passT_seq ii (a,size): decide whether an exclusive
   access by ii to the region described by address a and size may proceed,
   and return that decision as a boolean.

   The address is translated exactly as in set_exclusive_monitorsT_seq.
   The access passes when the local monitor holds the region for ii and,
   if the translated address is shareable, the global monitor holds it as
   well.  Non-shareable addresses are decided by the local monitor alone.

   The shareable case must consult IsExclusiveGlobal: testing the local
   monitor a second time would make the global monitor irrelevant and let
   an access pass after the global reservation had been lost.

   On a pass, ClearExclusiveLocal is applied for ii and both monitor
   components are replaced by what it returns; on a failure the state is
   returned unchanged.

   Fails (NONE) when the CPSR mode field does not decode. *)
val exclusive_monitors_passT_seq_def = Define`
  (exclusive_monitors_passT_seq ii (a:word32,size:num)) : bool M =
  \s. case word5_to_mode (AREAD_PSR CPSR s).M of
         NONE -> NONE
      || SOME m ->
           let monitor = AREAD_EXCL s in
           let memaddrdesc = monitor.TranslateAddress(a, ~(m = usr), F) in
           let local_pass = monitor.IsExclusiveLocal
                              (memaddrdesc.paddress,ii,size) in
           let passed =
                 if memaddrdesc.memattrs.shareable then
                   monitor.IsExclusiveGlobal (memaddrdesc.paddress,ii,size) /\
                   local_pass
                 else local_pass
           in
             SOME (passed,
               if passed then
                 AWRITE_EXCL
                   (let (local,global) =
                          monitor.ClearExclusiveLocal ii
                            (monitor.IsExclusiveLocal,monitor.IsExclusiveGlobal)
                    in
                      monitor with
                      <| IsExclusiveLocal  := local;
                         IsExclusiveGlobal := global |>) s
               else s)`;

(* memory barrier *)

(* dmbT_seq ii x: a no-op in the sequential model; it succeeds with unit
   and leaves the state unchanged whatever the domain/type pair x is. *)
val dmbT_seq_def = Define`
  (dmbT_seq : iiid -> MBReqDomain # MBReqTypes -> unit M) ii x =
    \s. SOME ((), s)`;

(* memory writes are only allowed to modelled memory, i.e. locations
   containing SOME ... *)

(* write_mem8_seq ii a x: store byte x at address a.  Fails when the
   location currently holds NONE (unmodelled memory); the previous contents
   y are otherwise ignored. *)
val write_mem8_seq_def = Define `(write_mem8_seq ii a x):unit M =
  (\s. case AREAD_MEM a s of
          NONE   -> NONE
       || SOME y -> SOME ((),AWRITE_MEM a (SOME x) s))`;

(* write_mem16_seq ii a x: store the 16-bit value x at addresses a and a+1.
   Fails when bit 0 of a is set (misaligned) or when either byte location
   is unmodelled.  The CPSR E field selects the byte order: when it is set,
   element 1 of the byte list goes to the lower address, otherwise
   element 0 does.
   NOTE(review): the byte list is built with word2bytes 4 although x is
   only 16 bits wide and only elements 0 and 1 are used.  This looks like a
   leftover from the 32-bit write; word2bytes is defined elsewhere, so
   confirm whether 2 was intended before changing it. *)
val write_mem16_seq_def = Define `(write_mem16_seq ii a (x:word16)):unit M =
  if a ' 0 then
    failureT_seq
  else
    bindT_seq (read_endian_seq ii)
      (\e. let l = word2bytes 4 x in
        discardT_seq
          (if e then
            parT_seq (write_mem8_seq ii a      (EL 1 l))
                     (write_mem8_seq ii (a+1w) (EL 0 l))
          else
            parT_seq (write_mem8_seq ii a      (EL 0 l))
                     (write_mem8_seq ii (a+1w) (EL 1 l))))`;

(* write_mem32_seq ii a x: store the 32-bit value x in the four bytes
   starting at aa = align(a,4).  Unlike the 16-bit write there is no
   misalignment failure: the address is aligned first.  The CPSR E field
   selects the byte order: when it is set the byte list from word2bytes is
   written in reverse (element 3 at aa), otherwise in order (element 0 at
   aa).  Fails when any of the four byte locations is unmodelled. *)
val write_mem32_seq_def = Define`(write_mem32_seq ii a (x:word32)):unit M =
  bindT_seq (read_endian_seq ii)
    (\e. let aa = align(a,4) and l = word2bytes 4 x in
      discardT_seq
        (if e then
           parT_seq (write_mem8_seq ii aa      (EL 3 l))
          (parT_seq (write_mem8_seq ii (aa+1w) (EL 2 l))
          (parT_seq (write_mem8_seq ii (aa+2w) (EL 1 l))
                    (write_mem8_seq ii (aa+3w) (EL 0 l))))
         else
           parT_seq (write_mem8_seq ii aa      (EL 0 l))
          (parT_seq (write_mem8_seq ii (aa+1w) (EL 1 l))
          (parT_seq (write_mem8_seq ii (aa+2w) (EL 2 l))
                    (write_mem8_seq ii (aa+3w) (EL 3 l))))))`;

(* a memory read to an unmodelled memory location causes a failure *)

(* read_mem8_seq ii a: return the byte stored at address a without changing
   the state.  Fails when the location holds NONE (unmodelled memory). *)
val read_mem8_seq_def = Define `(read_mem8_seq ii a):word8 M =
  (\s. case AREAD_MEM a s of NONE -> NONE || SOME x -> SOME (x,s))`;

(* read_mem16_seq ii a: read the bytes at a and a+1 and return them as a
   16-bit value with the byte from a+1 in the high half.  Fails when bit 0
   of a is set (misaligned) or when either byte location is unmodelled.
   NOTE(review): unlike write_mem16_seq this does not consult the CPSR E
   field, so a value written with E set reads back byte-swapped.  Confirm
   whether callers perform the reversal themselves before changing this. *)
val read_mem16_seq_def = Define `(read_mem16_seq ii a):word16 M =
  if a ' 0 then
    failureT_seq
  else
    bindT_seq (parT_seq (read_mem8_seq ii a) (read_mem8_seq ii (a+1w)))
      (\(b0,b1). constT_seq (b1 @@ b0))`;

(* Enable word-length guessing in the term parser.  Presumably required by
   the nested @@ concatenations in read_mem32_seq below, whose intermediate
   word widths are not annotated. *)
val _ = wordsLib.guess_lengths();

(* read_mem32_seq ii a: read the four bytes starting at aa = align(a,4) and
   return them as a 32-bit value, with the byte from aa in the low position
   and the byte from aa+3 in the high position.  Fails when any of the four
   byte locations is unmodelled; there is no misalignment failure because
   the address is aligned first.
   NOTE(review): unlike write_mem32_seq this does not consult the CPSR E
   field, so a value written with E set reads back byte-reversed.  Confirm
   whether callers perform the reversal themselves before changing this. *)
val read_mem32_seq_def = Define `(read_mem32_seq ii a):word32 M =
  bindT_seq
    (let aa = align(a,4) in
        parT_seq (read_mem8_seq ii aa)
       (parT_seq (read_mem8_seq ii (aa+1w))
       (parT_seq (read_mem8_seq ii (aa+2w))
                 (read_mem8_seq ii (aa+3w)))))
    (\(b0,b1,b2,b3). constT_seq ((b3 @@ b2 @@ b1 @@ b0)))`;

(* export *)

(* Export: bind the generic operation names (constT, bindT, read_reg,
   write_mem32, ...) to their sequential implementations defined above.
   The type annotation on each left-hand side fixes the type under which
   the generic name is introduced; note that lockT and failureT are given
   unit-typed instances of the more general lockT_seq and failureT_seq.
   The theorems returned by these Define calls are not bound to ML names;
   only the last one is the value of the sequence, and it is discarded. *)
val _ = (
 Define `(constT: 'a -> 'a M)                    = constT_seq`;
 Define `(addT: 'a -> 'b M -> ('a # 'b) M)       = addT_seq`;
 Define `(lockT: unit M -> unit M)               = lockT_seq`;
 Define `(failureT: unit M)                      = failureT_seq`;
 Define `(discardT: 'a M -> unit M)              = discardT_seq`;
 Define `(condT: bool -> unit M -> unit M)       = condT_seq`;
 Define `(bindT: 'a M -> (('a -> 'b M) -> 'b M)) = bindT_seq`;
 Define `(seqT: 'a M -> 'b M -> 'b M)            = seqT_seq`;
 Define `(parT: 'a M -> 'b M -> ('a # 'b) M)     = parT_seq`;

 Define `(read_info: iiid -> ARMinfo M)                  = read_info_seq`;
 Define `(read_version: iiid -> num M)                   = read_version_seq`;
 Define `(write_reg: iiid -> ARMreg -> Aimm -> unit M)   = write_reg_seq`;
 Define `(read_reg: iiid -> ARMreg -> Aimm M)            = read_reg_seq`;
 Define `(write_psr:
                  iiid -> ARMpsr -> ARMstatus -> unit M) = write_psr_seq`;
 Define `(read_psr: iiid -> ARMpsr -> ARMstatus M)       = read_psr_seq`;
 Define `(write_flags:
            iiid -> bool # bool # bool # bool -> unit M) = write_flags_seq`;
 Define `(read_flags:
                  iiid -> (bool # bool # bool # bool) M) = read_flags_seq`;
 Define `(read_mode: iiid -> ARMmode M)                  = read_mode_seq`;
 Define `(read_instr_set: iiid -> InstrSet M)            = read_instr_set_seq`;
 Define `(read_sctlr: iiid -> ARMsctlr M)                = read_sctlr_seq`;

 Define `(write_mem8: iiid -> word32 -> word8 -> unit M) = write_mem8_seq`;
 Define `(read_mem8: iiid -> word32 -> word8 M)          = read_mem8_seq`;
 Define `(write_mem16:
            iiid -> word32 -> word16 -> unit M)          = write_mem16_seq`;
 Define `(read_mem16: iiid -> word32 -> word16 M)        = read_mem16_seq`;
 Define `(write_mem32:
            iiid -> word32 -> word32 -> unit M)          = write_mem32_seq`;
 Define `(read_mem32: iiid -> word32 -> Aimm M)          = read_mem32_seq`;

 Define `(dmbT :
            iiid -> MBReqDomain # MBReqTypes -> unit M) = dmbT_seq`;
 Define `(set_exclusive_monitorsT :
            iiid -> (word32 # num) -> unit M) = set_exclusive_monitorsT_seq`;
 Define `(exclusive_monitors_passT :
            iiid -> (word32 # num) -> bool M) = exclusive_monitors_passT_seq`);


(* some rewriter-friendly theorems *)

(* option_apply x f: propagate NONE, otherwise apply f to the contents of x.
   Written with an if and THE rather than a case split; the lemmas below
   restate the monad operations in this form. *)
val option_apply_def = Define `
  option_apply x f = if x = NONE then NONE else f (THE x)`;

(* Evaluation rule for option_apply on a SOME argument.  Not stored under
   its own name; it becomes a conjunct of seq_monad_thm. *)
val option_apply_SOME = prove(
  ``!x f. option_apply (SOME x) f = f x``,SRW_TAC [] [option_apply_def]);

(* Byte-sized memory read and write restated with option_apply in place of
   the case split on the memory contents.  Proved by unfolding the
   definitions and splitting on the value read from memory.  Not stored
   under its own name; it becomes a conjunct of seq_monad_thm. *)
val mem_seq_lemma = prove(
  ``(read_mem8_seq ii a s = option_apply (AREAD_MEM a s) (\x. SOME (x,s))) /\
    (write_mem8_seq ii a y s =
       option_apply (AREAD_MEM a s) (\x. SOME ((),AWRITE_MEM a (SOME y) s)))``,
  SRW_TAC [] [option_apply_def,read_mem8_seq_def,write_mem8_seq_def]
  THEN Cases_on `AREAD_MEM a s` THEN FULL_SIMP_TAC std_ss []);

(* The monad combinators constT_seq, failureT_seq, lockT_seq, addT_seq,
   bindT_seq and parT_seq restated as explicit state functions, with
   failure propagation expressed through option_apply and pair components
   accessed with FST and SND instead of pattern matching.  The proof
   unfolds the definitions and then case-splits on the results of running
   s and t and on the pairs they return.  Not stored under its own name;
   its conjuncts become part of seq_monad_thm. *)
val monad_simp_lemma = prove(
  ``(constT_seq x = \y. SOME (x,y)) /\ (failureT_seq = \y. NONE) /\  (lockT_seq d = d) /\
    (addT_seq q s = \y. option_apply (s y) (\t. SOME ((q,FST t),SND t))) /\
    (bindT_seq s f = \y. option_apply (s y) (\t. f (FST t) (SND t))) /\
    (parT_seq s t = \y. option_apply (s y) (\z.
                    option_apply (t (SND z)) (\x. SOME ((FST z,FST x),SND x))))``,
  SRW_TAC [] [parT_seq_def,bindT_seq_def,failureT_seq_def,lockT_seq_def,
                   addT_seq_def,constT_seq_def,FUN_EQ_THM]
  THEN Cases_on `s y` THEN POP_ASSUM MP_TAC THEN SRW_TAC [] [option_apply_def]
  THEN Cases_on `x` THEN POP_ASSUM MP_TAC THEN SRW_TAC [] [option_apply_def]
  THEN Cases_on `t r` THEN SRW_TAC [] [option_apply_def] THEN FULL_SIMP_TAC std_ss []
  THEN Cases_on `x` THEN SRW_TAC [] [option_apply_def]);

(* seq_monad_thm: the rewriter-friendly facts about the sequential monad
   bundled into one stored theorem.  Its conjuncts are, in order, the
   option_apply evaluation rule, the byte-memory lemma and the individual
   conjuncts of the combinator lemma, each universally quantified over all
   of its free variables. *)
val seq_monad_thm =
  let
    val combinator_facts = CONJUNCTS monad_simp_lemma
    val facts = option_apply_SOME :: mem_seq_lemma :: combinator_facts
    val closed_facts = map GEN_ALL facts
  in
    save_thm ("seq_monad_thm", LIST_CONJ closed_facts)
  end;

(* AREAD_CLAUSES: how the register, memory and status-register projections
   interact with the corresponding updates.  The first six conjuncts say
   that a read of one component is unaffected by a write to a different
   component; the last three give the read-after-write result within the
   same component (the written value at the written index, the old value
   elsewhere).  The co-processor, info and monitor components are not
   covered.  Proved by splitting the state into its components and
   unfolding the accessor definitions and function update. *)
val AREAD_CLAUSES = store_thm("AREAD_CLAUSES",
  ``!s.
    (AREAD_REG r (AWRITE_MEM a x s) = AREAD_REG r s) /\
    (AREAD_REG r (AWRITE_PSR p w s) = AREAD_REG r s) /\
    (AREAD_MEM a (AWRITE_REG r y s) = AREAD_MEM a s) /\
    (AREAD_MEM a (AWRITE_PSR p w s) = AREAD_MEM a s) /\
    (AREAD_PSR p (AWRITE_REG r y s) = AREAD_PSR p s) /\
    (AREAD_PSR p (AWRITE_MEM a x s) = AREAD_PSR p s) /\
    (AREAD_REG r (AWRITE_REG r2 y s) = if r = r2 then y else AREAD_REG r s) /\
    (AREAD_MEM a (AWRITE_MEM a2 x s) = if a = a2 then x else AREAD_MEM a s) /\
    (AREAD_PSR p (AWRITE_PSR p2 w s) = if p = p2 then w else AREAD_PSR p s)``,
  STRIP_TAC THEN Cases_on_arm_state `s`
  THEN SRW_TAC [] [AREAD_REG_def,AREAD_MEM_def,AREAD_PSR_def,AWRITE_MEM_def,
    AWRITE_REG_def,AWRITE_PSR_def, combinTheory.APPLY_UPDATE_THM]);

(* Close the theory segment and write it out. *)
val _ = export_theory ();
