(*  Title: 	ZF/ex/arith
    Author: 	Lawrence C Paulson, Cambridge University Computer Laboratory
    Copyright   1991  University of Cambridge

Arithmetic operators and their definitions

Proofs about elementary arithmetic: addition, multiplication, etc.
Tests definitions and simplifier.
*)

writeln"File ZF/ex/arith";

(*Theory Arith: extends Nat.thy with a conditional if(P,a,b) and the
  arithmetic operators #+ (addition), #- (truncated difference),
  #* (multiplication), #// (remainder) and #/ (quotient).*)
structure Arith =
struct
(*Infix syntax; every operator has type [i,i]=>i and is declared Infixl.
  #+ and #- have priority 65; #*, #// and #/ have the higher priority 70,
  presumably so that they bind more tightly, as the main result
  mod_quo_equality at the end of this file relies on.  Within these syntax
  strings the quote character escapes the slash: the definitions below
  write the remainder and quotient operators as #// and #/.*)
val mixfix = 
    [ Infixl("#+",	"[i,i]=>i", 65),
      Infixl("#-",	"[i,i]=>i", 65),
      Infixl("#*",	"[i,i]=>i", 70),
      Infixl("#'/'/",	"[i,i]=>i", 70),
      Infixl("#'/",	"[i,i]=>i", 70) ];

(*Syntax extension carrying the mixfix declarations above; no parse or
  print translations are supplied.*)
val arith_sext = Sext{mixfix=mixfix,
		      parse_translation=[],
		      print_translation=[]};

(*The conditional if(P,a,b): its first argument has type o (the type of
  the condition P in if_def), the other two and the result have type i.*)
val const_decs = 
    [ (["if"],	"[o,i,i]=>i") ];

(*Extend Nat.thy to the theory arith with the constants and syntax above,
  and with the following definitions as axioms.*)
val thy = extend_theory Nat.thy  "arith" 
    ([], [], [], const_decs, Some arith_sext)
  [ 
    (*if(P,a,b) is the unique z that equals a when P holds and b otherwise*)
    ("if_def",   "if(P,a,b) == THE z. P & z=a | ~P & z=b"),  
    (*Recursion on m: 0 #+ n = n and succ(m) #+ n = succ(m #+ n)*)
    ("add_def",   "m#+n == nat_rec(m, n, %u v.succ(v))"),  
    (*Recursion on n: start from m and apply the predecessor function
      nat_rec(v, 0, %x y.x) n times, so the result never goes below 0*)
    ("diff_def",  "m#-n == nat_rec(n, m, %u v. nat_rec(v, 0, %x y.x))"),  
    (*Recursion on m: 0 #* n = 0 and succ(m) #* n = n #+ (m #* n)*)
    ("mult_def",  "m#*n == nat_rec(m, 0, %u v. n #+ v)"),  
    (*Remainder via transrec: j itself if j:n (that is, j<n), otherwise
      the remainder of j#-n*)
    ("mod_def",   
     "m#//n == transrec(nat, m, %j f. if(j:n, j, f`(j#-n)))"),  
    (*Quotient via transrec: 0 if j:n (that is, j<n), otherwise one more
      than the quotient of j#-n*)
    ("quo_def",   
     "m#/n == transrec(nat, m, %j f. if(j:n, 0, succ(f`(j#-n))))") ];
end;

(*"Difference" is truncated subtraction of natural numbers.
  There are no negative numbers; we have
     m #- n = 0  iff  m<=n   and   m #- n = succ(k) for some k  iff  m>n.
  Also, nat_rec(m, 0, %z w.z) is pred(m); diff_def applies it n times to m.   *)

(*The defining axioms of Arith.thy, each fetched by name.*)
val if_def   = get_axiom Arith.thy "if_def";
val add_def  = get_axiom Arith.thy "add_def";
val diff_def = get_axiom Arith.thy "diff_def";
val mult_def = get_axiom Arith.thy "mult_def";
val mod_def  = get_axiom Arith.thy "mod_def";
val quo_def  = get_axiom Arith.thy "quo_def";


(** if -- belongs in main rulefile?? **)

(*if(True,a,b) = a.  The goal is stated with if_def unfolded (goalw); the
  definite description is then dealt with by fast_tac using the_equality.*)
goalw Arith.thy [if_def] "if(True,a,b) = a";
by (fast_tac (ZF_cs addIs [the_equality]) 1);
val if_true_conv = result();

(*if(False,a,b) = b.  Same proof as for if_true_conv: unfold if_def and
  dispatch the definite description with fast_tac and the_equality.*)
goalw Arith.thy [if_def] "if(False,a,b) = b";
by (fast_tac (ZF_cs addIs [the_equality]) 1);
val if_false_conv = result();

(*Congruence rule for if: equivalent conditions and equal branches give
  equal results.  Proved by repeated resolution with the premises, refl,
  the_cong and the first-order congruence rules FOL_congs.*)
val prems = goalw Arith.thy [if_def]
    "[| P<->Q;  a=c;  b=d |] ==> if(P,a,b) = if(Q,c,d)";
by (REPEAT (resolve_tac (prems@[refl,the_cong]@FOL_congs) 1));
val if_cong = result();

(*Simpset for conditionals: ZF_ss plus the congruence rule if_cong and the
  rewrites if(True,a,b) = a and if(False,a,b) = b.  Only conditions that
  are literally True or False are eliminated by these two rewrites.*)
val if_ss = ZF_ss addcongs [if_cong]
	       	  addrews  [if_true_conv,if_false_conv];

(*Typing rule for if: if(P,a,b) : A provided that a : A when P holds and
  b : A when it does not.  Proof: case split on P (excluded_middle with
  disjE), then simplify each case with if_ss and the premises.*)
val prems = goal Arith.thy
    "[| P ==> a: A;  ~P ==> b: A |] ==> if(P,a,b): A";
by (res_inst_tac [("Q","P")] (excluded_middle RS disjE) 1);
by (all_simp_tac if_ss prems);
val if_type = result();


(****are the "conditional" versions needed??
    val prems = goal Arith.thy "P ==> if(P,a,b) = a";
    val prems = goal Arith.thy "~P ==> if(P,a,b) = b";
****)

(*Typing rules for natural-number terms, including if_type.  They are
  passed to typechk_tac in the closure proofs below and are also added
  as rewrites to nat_ss.*)
val nat_typechecks = [nat_rec_type,nat_0_I,nat_succ_I,Ord_nat,if_type];

(*Simpset for natural-number recursion: if_ss plus congruence rules for
  nat_case and nat_rec, the nat_rec equations for 0 and succ, and the typing
  rules above (presumably to discharge membership side-conditions).*)
val nat_ss = if_ss addcongs [nat_case_cong,nat_rec_cong]
	       	   addrews ([nat_rec_0_conv,nat_rec_succ_conv] @ 
			    nat_typechecks);


(** Addition **)

(*Closure: m #+ n : nat for m, n : nat.  Proved by type-checking the
  unfolded definition.*)
val add_type = prove_goalw Arith.thy [add_def]
    "[| m:nat;  n:nat |] ==> m #+ n : nat"
 (fn prems=> [ (typechk_tac (prems@nat_typechecks@ZF_typechecks)) ]);

(*Defining equation 0 #+ n = n, from nat_rec_0_conv; needs no typing
  premise.*)
val add_0_conv = prove_goalw Arith.thy [add_def]
    "0 #+ n = n"
 (fn _ => [ (rtac nat_rec_0_conv 1) ]);

(*Defining equation succ(m) #+ n = succ(m #+ n) for m : nat, from
  nat_rec_succ_conv; the premise discharges its typing condition.*)
val add_succ_conv = prove_goalw Arith.thy [add_def]
    "[| m:nat |] ==> succ(m) #+ n = succ(m #+ n)"
 (fn prems=> [ (rtac nat_rec_succ_conv 1), (resolve_tac prems 1) ]); 

(** Multiplication **)

(*Closure: m #* n : nat for m, n : nat.  mult_def is written in terms of
  #+, so add_type is supplied to the type checker as well.*)
val mult_type = prove_goalw Arith.thy [mult_def]
    "[| m:nat;  n:nat |] ==> m #* n : nat"
 (fn prems=>
  [ (typechk_tac (prems@[add_type]@nat_typechecks@ZF_typechecks)) ]);

(*Defining equation 0 #* n = 0, from nat_rec_0_conv.*)
val mult_0_conv = prove_goalw Arith.thy [mult_def]
    "0 #* n = 0"
 (fn _ => [ (rtac nat_rec_0_conv 1) ]);

(*Defining equation succ(m) #* n = n #+ (m #* n) for m : nat, from
  nat_rec_succ_conv.*)
val mult_succ_conv = prove_goalw Arith.thy [mult_def]
    "[| m:nat |] ==> succ(m) #* n = n #+ (m #* n)"
 (fn prems=> [ (rtac nat_rec_succ_conv 1), (resolve_tac prems 1) ]); 

(** Difference **)

(*Closure: m #- n : nat for m, n : nat, by type-checking the unfolded
  definition.*)
val diff_type = prove_goalw Arith.thy [diff_def]
    "[| m:nat;  n:nat |] ==> m #- n : nat"
 (fn prems=> [ (typechk_tac (prems@nat_typechecks@ZF_typechecks)) ]);

(*m #- 0 = m.  diff_def recurses on the second argument, so this is the
  base case nat_rec_0_conv.*)
val diff_0_conv = prove_goalw Arith.thy [diff_def]
    "m #- 0 = m"
 (fn _ => [ (rtac nat_rec_0_conv 1) ]);

(*0 #- n = 0: subtraction from zero stays at zero.  By induction on n
  (nat_induct instantiated with the premise), then simplification.*)
val diff_0_eq_0 = prove_goalw Arith.thy [diff_def]
    "n:nat ==> 0 #- n = 0"
 (fn [prem]=>
  [ (rtac (prem RS nat_induct) 1),
    (all_simp_tac nat_ss []) ]);

(*succ(m) #- succ(n) = m #- n for m, n : nat.
  Must simplify BEFORE the induction!!  (Else we get a critical pair)
  succ(m) #- succ(n)   rewrites to   pred(succ(m) #- n)  *)
val diff_succ_succ = prove_goalw Arith.thy [diff_def]
    "[| m:nat;  n:nat |] ==> succ(m) #- succ(n) = m #- n"
 (fn prems=>
  [ (ASM_SIMP_TAC (nat_ss addrews prems) 1),
    (nat_ind_tac "n" prems 1),
    (all_simp_tac nat_ss prems) ]);

(*Mathematical induction for "diff": to prove P(m,n) for m, n : nat it
  suffices to prove, for x, y : nat,
     P(x,0),   P(0,succ(y)),   and   P(x,y) ==> P(succ(x),succ(y)).
  Stated over Nat.thy because it does not mention the arithmetic operators.
  Proof: generalise over the first argument (bspec, with m : nat from the
  premises), do induction on n, and in the successor case do a further
  induction on x.*)
val prems = goal Nat.thy
    "[| m: nat;  n: nat;  \
\       !!x. [| x: nat |] ==> P(x,0);  \
\       !!y. [| y: nat |] ==> P(0,succ(y));  \
\       !!x y. [| x: nat;  y: nat;  P(x,y) |] ==> P(succ(x),succ(y))  \
\    |] ==> P(m,n)";
by (res_inst_tac [("x","m")] bspec 1);
by (resolve_tac prems 2);
by (nat_ind_tac "n" prems 1);
by (rtac ballI 2);
by (nat_ind_tac "x" [] 2);
by (REPEAT (ares_tac (prems@[ballI]) 1 ORELSE etac bspec 1));
val diff_induct = result();

(*m #- n : succ(m), that is, m #- n <= m, for m, n : nat.
  Proof by diff_induct.  In the third case (succ/succ) the induction
  hypothesis, a membership in a successor, is eliminated with succE before
  everything is simplified.*)
val prems = goal Arith.thy 
    "[| m:nat;  n:nat |] ==> m #- n : succ(m)";
by (res_inst_tac [("m","m"),("n","n")] diff_induct 1);
by (resolve_tac prems 1);
by (resolve_tac prems 1);
by (etac succE 3);
by (all_simp_tac nat_ss (prems@[diff_0_conv,diff_0_eq_0,diff_succ_succ]));
val diff_leq = result();

(*** Simplification over add, mult, diff ***)

(*Closure rules for #+, #* and #-.*)
val arith_typechecks = [add_type, mult_type, diff_type];
(*Defining equations for #+, #* and #-, used as left-to-right rewrites.*)
val arith_rews = [add_0_conv, add_succ_conv,
		  mult_0_conv, mult_succ_conv,
		  diff_0_conv, diff_0_eq_0, diff_succ_succ];

(*Congruence rules for the three operators, built by mk_congs.*)
val arith_congs = mk_congs Arith.thy ["op #+", "op #-", "op #*"];

(*Simpset for arithmetic: nat_ss extended with the above.  The remainder
  and quotient operators are not included here.*)
val arith_ss = nat_ss addcongs arith_congs
	              addrews  (arith_rews@arith_typechecks);


(*** Addition ***)

(*Associative law for addition: (m #+ n) #+ k = m #+ (n #+ k).
  Only m and n are required to be natural numbers; k is arbitrary.
  By induction on m, then simplification with arith_ss.*)
val add_assoc = prove_goal Arith.thy 
    "[| m:nat;  n:nat |] ==> (m #+ n) #+ k = m #+ (n #+ k)"
 (fn prems=>
  [ (nat_ind_tac "m" prems 1),
    (all_simp_tac arith_ss prems) ]);

(*Commutative law for addition.  Must simplify after first induction!
  Orientation of rewrites is delicate: after the induction on m and
  simplification, the remaining subgoals (presumably two) are each proved
  by induction on n, subgoal 1 only after being reversed with sym.*)  
val add_commute = prove_goal Arith.thy 
    "[| m:nat;  n:nat |] ==> m #+ n = n #+ m"
 (fn prems=>
  [ (nat_ind_tac "m" prems 1),
    (all_simp_tac arith_ss prems),
    (nat_ind_tac "n" prems 2),
    (rtac sym 1),
    (nat_ind_tac "n" prems 1),
    (all_simp_tac arith_ss prems) ]);

(*** Multiplication ***)

(*Right annihilation in product: m #* 0 = 0 for m : nat, by induction on m.
  Only the left law mult_0_conv is a defining equation.*)
val mult_right0 = prove_goal Arith.thy 
    "m:nat ==> m #* 0 = 0"
 (fn prems=>
  [ (nat_ind_tac "m" prems 1),
    (all_simp_tac arith_ss prems)  ]);

(*Right successor law for multiplication: m #* succ(n) = m #+ (m #* n).
  By induction on m, simplifying with add_assoc reversed so that sums are
  reassociated to the left.*)
val mult_right_succ = prove_goal Arith.thy 
    "[| m:nat;  n:nat |] ==> m #* succ(n) = m #+ (m #* n)"
 (fn prems=>
  [ (nat_ind_tac "m" prems 1),
    (all_simp_tac arith_ss ([add_assoc RS sym]@prems)),
       (*The final goal requires the commutative law for addition*)
    (REPEAT (ares_tac (prems@[refl,add_commute]@ZF_congs@arith_congs) 1))  ]);

(*Commutative law for multiplication: m #* n = n #* m.  By induction on m,
  using the right-hand laws mult_right0 and mult_right_succ.*)
val mult_commute = prove_goal Arith.thy 
    "[| m:nat;  n:nat |] ==> m #* n = n #* m"
 (fn prems=>
  [ (nat_ind_tac "m" prems 1),
    (all_simp_tac arith_ss (prems@[mult_right0, mult_right_succ])) ]);

(*Multiplication distributes over addition, on the right:
  (m #+ n) #* k = (m #* k) #+ (n #* k).  By induction on m; add_assoc is
  used reversed so that sums are reassociated to the left.*)
val add_mult_dist = prove_goal Arith.thy 
    "[| m:nat;  n:nat;  k:nat |] ==> (m #+ n) #* k = (m #* k) #+ (n #* k)"
 (fn prems=>
  [ (nat_ind_tac "m" prems 1),
    (all_simp_tac arith_ss ([add_assoc RS sym]@prems)) ]);

(*Associative law for multiplication: (m #* n) #* k = m #* (n #* k).
  By induction on m, using the distributive law add_mult_dist.*)
val mult_assoc = prove_goal Arith.thy 
    "[| m:nat;  n:nat;  k:nat |] ==> (m #* n) #* k = m #* (n #* k)"
 (fn prems=>
  [ (nat_ind_tac "m" prems 1),
    (all_simp_tac arith_ss (prems @ [add_mult_dist])) ]);


(*** Difference ***)

(*m #- m = 0 for m : nat, by induction on m (the step uses diff_succ_succ,
  which is in arith_ss).*)
val diff_self_eq_0 = prove_goal Arith.thy 
    "m:nat ==> m #- m = 0"
 (fn prems=>
  [ (nat_ind_tac "m" prems 1),
    (all_simp_tac arith_ss prems) ]);

(*Addition is the inverse of subtraction: if n<=m then n+(m-n) = m.
  The premise ~m:n (not m<n) is bound separately and first turned into an
  implication in the goal (rev_mp), so that it is carried into every case
  of diff_induct.*)
val notless::prems = goal Arith.thy
    "[| ~m:n;  m:nat;  n:nat |] ==> n #+ (m#-n) = m";
by (rtac (notless RS rev_mp) 1);
by (res_inst_tac [("m","m"),("n","n")] diff_induct 1);
by (resolve_tac prems 1);
by (resolve_tac prems 1);
by (all_simp_tac arith_ss (prems@[succ_mem_succ_iff, Ord_0_mem_succ, 
				  naturals_are_ordinals]));
val add_inverse_diff = result();


(*** Remainder ***)

(*In ordinary notation: if 0<n and n<=m then m-n < m.
  This is the termination fact for the recursive calls f at (j#-n) in
  mod_def and quo_def.  After cut_facts_tac, two of the premises (presumably
  0:n and ~m:n) are moved into the goal as implications with rev_mp, so that
  they survive the diff_induct case split.*)
val prems = goal Arith.thy
    "[| 0:n; ~ m:n;  m:nat;  n:nat |] ==> m #- n : m";
by (cut_facts_tac prems 1);
by (etac rev_mp 1);
by (etac rev_mp 1);
by (res_inst_tac [("m","m"),("n","n")] diff_induct 1);
by (resolve_tac prems 1);
by (resolve_tac prems 1);
by (all_simp_tac nat_ss (prems@[diff_leq,diff_succ_succ]));
val quorem_termination = result();

(*Closure: m#//n : nat for m, n : nat with 0:n (0<n).  Type-checks the
  transrec definition, with quorem_termination available for the
  recursive call.*)
val prems = goalw Arith.thy [mod_def]
    "[| 0:n;  m:nat;  n:nat |] ==> m#//n : nat";
by (REPEAT (ares_tac ([transrec_type,apply_type,quorem_termination]@
		      prems@nat_typechecks) 1));
val mod_type = result();

(*If m:n (m<n) then m#//n = m.  Unfold transrec once (transrec_conv); the
  condition simplifies to True, selecting the first branch.*)
val prems = goalw Arith.thy [mod_def]
    "[| 0:n;  m:n;  m:nat;  n:nat |] ==> m#//n = m";
by (ASM_SIMP_TAC (if_ss addrews (prems@[transrec_conv,Ord_nat])) 1);
val mod_less_conv = result();

(*If ~m:n (n<=m) then m#//n = (m#-n)#//n.  Unfold transrec once; the
  second branch is taken, and quorem_termination is supplied for the
  recursive call.*)
val prems = goalw Arith.thy [mod_def]
    "[| 0:n;  ~m:n;  m:nat;  n:nat |] ==> m#//n = (m#-n)#//n";
by (ASM_SIMP_TAC (if_ss addrews (prems@[transrec_conv,Ord_nat,
					quorem_termination])) 1);
val mod_geq_conv = result();


(*** Quotient ***)

(*Closure: m#/n : nat for m, n : nat with 0:n (0<n).  Same proof as
  mod_type, applied to quo_def.*)
val prems = goalw Arith.thy [quo_def]
    "[| 0:n;  m:nat;  n:nat |] ==> m#/n : nat";
by (REPEAT (ares_tac ([transrec_type,apply_type,quorem_termination]@
		      prems@nat_typechecks) 1));
val quo_type = result();

(*If m:n (m<n) then m#/n = 0.  Unfold transrec once; the first branch is
  taken.*)
val prems = goalw Arith.thy [quo_def]
    "[| 0:n;  m:n;  m:nat;  n:nat |] ==> m#/n = 0";
by (ASM_SIMP_TAC (if_ss addrews (prems@[transrec_conv,Ord_nat])) 1);
val quo_less_conv = result();

(*If ~m:n (n<=m) then m#/n = succ((m#-n)#/n).  Unfold transrec once; the
  second branch is taken, with quorem_termination supplied for the
  recursive call.*)
val prems = goalw Arith.thy [quo_def]
    "[| 0:n;  ~m:n;  m:nat;  n:nat |] ==> m#/n = succ((m#-n)#/n)";
by (ASM_SIMP_TAC (if_ss addrews (prems@[transrec_conv,Ord_nat,
					quorem_termination])) 1);
val quo_geq_conv = result();

(*Main result, the division law: for m, n : nat with 0<n,
     (m#/n)#*n #+ m#//n = m,
  i.e. quotient times divisor plus remainder gives back m.
  Proof by complete induction on m, then a case split on m1:n for the
  induction variable m1, each case closed by the less/geq equations above.*)
val prems = goal Arith.thy
    "[| 0:n;  m:nat;  n:nat |] ==> (m#/n)#*n #+ m#//n = m";
by (Ord_ind_tac "m" "nat" (Ord_nat::prems) 1);  (*complete induction*)
by (res_inst_tac [("Q","m1:n")] (excluded_middle RS disjE) 1);
by (all_simp_tac arith_ss ([mod_type,quo_type] @ prems @
        [mod_less_conv,mod_geq_conv, quo_less_conv, quo_geq_conv,
	 add_assoc, add_inverse_diff, quorem_termination]));
val mod_quo_equality = result();


(*Printed once every proof above has been processed.*)
writeln"Reached end of file.";

