(*  Title: 	HOL/sum
    Author: 	Lawrence C Paulson, Cambridge University Computer Laboratory
    Copyright   1991  University of Cambridge

The disjoint sum of two types
*)

structure Sum =
struct

(*Syntax extension: registers "+" through TInfixl as infix notation for the
  type constructor "sum"; no parse or print translations are supplied*)
val sext =
  Sext {mixfix            = [TInfixl("+", "sum", 10)],
        parse_translation = [],
        print_translation = []};

(*The theory "sum", obtained by extending Prod.thy.  The pieces handed to
  extend_theory are named locally so that the final call reads as a summary.*)
val thy =
  let
    (*declaration of the type constructor "sum" -- presumably its arity,
      two arguments of class term giving a result of class term*)
    val sum_arity = [(["sum"], (["term","term"],"term"))]

    (*constants with their types: the representing functions, the
      representing set, the abstraction/representation maps, the two
      injections and the case operator*)
    val consts =
      [(["Inl_Rep"], "['a,'a,'b,bool] => bool"),
       (["Inr_Rep"], "['b,'a,'b,bool] => bool"),
       (["Sum"],     "(['a,'b,bool] => bool)class"),
       (["Rep_Sum"], "'a + 'b => (['a,'b,bool] => bool)"),
       (["Abs_Sum"], "(['a,'b,bool] => bool) => 'a+'b"),
       (["Inl"],     "'a => 'a+'b"),
       (["Inr"],     "'b => 'a+'b"),
       (["when"],    "['a+'b, 'a=>'c, 'b=>'c] =>'c")]

    val axioms =
      [(*the representing functions and the set they generate*)
       ("Inl_Rep_def", "Inl_Rep == (%a. %x y p. x=a & p)"),
       ("Inr_Rep_def", "Inr_Rep == (%b. %x y p. y=b & ~p)"),
       ("Sum_def", "Sum == {f. (? a. f = Inl_Rep(a)) | (? b. f = Inr_Rep(b))}"),
       (*faking a type definition...*)
       ("Rep_Sum",         "Rep_Sum(s): Sum"),
       ("Rep_Sum_inverse", "Abs_Sum(Rep_Sum(s)) = s"),
       ("Abs_Sum_inverse", "f: Sum ==> Rep_Sum(Abs_Sum(f)) = f"),
       (*defining the abstract constants*)
       ("Inl_def",  "Inl == (%a. Abs_Sum(Inl_Rep(a)))"),
       ("Inr_def",  "Inr == (%b. Abs_Sum(Inr_Rep(b)))"),
       ("when_def", "when == (%p f g. @z.  (!x. p=Inl(x) --> z=f(x))\
                    \                                   & (!y. p=Inr(y) --> z=g(y)))")]
  in
    extend_theory Prod.thy "sum"
      ([], [], sum_arity, consts, Some sext)
      axioms
  end;
end;

(*Bind every axiom of Sum.thy to an ML identifier of the same name, so the
  proofs below can refer to them directly*)
local
  val axiom = get_axiom Sum.thy
in
  (*representation level*)
  val Inl_Rep_def     = axiom "Inl_Rep_def";
  val Inr_Rep_def     = axiom "Inr_Rep_def";
  val Sum_def         = axiom "Sum_def";
  (*the faked type definition*)
  val Rep_Sum         = axiom "Rep_Sum";
  val Rep_Sum_inverse = axiom "Rep_Sum_inverse";
  val Abs_Sum_inverse = axiom "Abs_Sum_inverse";
  (*abstract constants*)
  val Inl_def         = axiom "Inl_def";
  val Inr_def         = axiom "Inr_def";
  val when_def        = axiom "when_def";
end;

(** Inl_Rep and Inr_Rep: Representations of the constructors **)

(*This counts as a non-emptiness result for admitting 'a+'b as a type*)
(*Inl_Rep(a) lies in Sum.  With Sum_def unfolded the goal is a set
  comprehension: enter it with CollectI, choose the left disjunct, and let
  reflexivity supply the existential witness.*)
goalw Sum.thy [Sum_def] "Inl_Rep(a) : Sum";
by (rtac CollectI 1 THEN rtac disjI1 1 THEN rtac exI 1 THEN rtac refl 1);
val Inl_RepI = result();

(*Inr_Rep(b) lies in Sum.  Same shape as the proof of Inl_RepI, except that
  the right disjunct of the unfolded Sum_def is chosen.*)
goalw Sum.thy [Sum_def] "Inr_Rep(b) : Sum";
by (rtac CollectI 1 THEN rtac disjI2 1 THEN rtac exI 1 THEN rtac refl 1);
val Inr_RepI = result();

(*Abs_Sum is one-to-one on the set Sum: One_One_on_inverseI reduces the goal
  to an inverse property, which the axiom Abs_Sum_inverse discharges by
  elimination*)
goal Sum.thy "One_One_on(Abs_Sum,Sum)";
by (rtac One_One_on_inverseI 1);
by (etac Abs_Sum_inverse 1);
val One_One_on_Abs_Sum = result();

(** Distinctness of Inl and Inr **)

(*The two representing functions never coincide.  Both definitions are
  unfolded in the goal; the assumed equality is then turned, via three uses
  of ap_thm and iffE, into implications between the two bodies
  (presumably by applying both sides to the same three arguments), and
  the remaining goals are closed with refl and conjI.*)
goalw Sum.thy [Inl_Rep_def, Inr_Rep_def] "~ (Inl_Rep(a) = Inr_Rep(b))";
by (rtac notI 1
    THEN etac (ap_thm RS ap_thm RS ap_thm RS iffE) 1
    THEN rtac (notE RS ccontr) 1
    THEN etac (mp RS conjunct2) 1
    THEN REPEAT (ares_tac [refl,conjI] 1));
val Inl_Rep_not_Inr_Rep = result();

(*Inl(a) and Inr(b) are distinct.  With Inl_def and Inr_def unfolded, the
  contrapositive of injectivity of Abs_Sum on Sum reduces the claim to
  distinctness of the representations plus two membership goals.*)
goalw Sum.thy [Inl_def,Inr_def] "~ (Inl(a) = Inr(b))";
by (rtac (One_One_on_Abs_Sum RS One_One_on_contraD) 1);
by (rtac Inl_Rep_not_Inr_Rep 1);   (*the representations differ*)
by (rtac Inl_RepI 1);              (*Inl_Rep(a) : Sum*)
by (rtac Inr_RepI 1);              (*Inr_Rep(b) : Sum*)
val Inl_not_Inr = result();

(*Elimination forms of Inl_not_Inr, convenient as elim-rules for classical
  reasoning sets (they are passed to addEs further down).
  Inl_neq_Inr: Inl_not_Inr resolved with notE, then passed through standard.
  Inr_neq_Inl: the same rule with the equation reversed, obtained by
  resolving sym into Inl_neq_Inr.*)
val Inl_neq_Inr = standard (Inl_not_Inr RS notE);
val Inr_neq_Inl = sym RS Inl_neq_Inr;

(** Injectiveness of Inl and Inr **)

(*Inl_Rep is injective.  The premise (bound to major, with Inl_Rep_def
  unfolded) is an equation between functions; three uses of ap_thm followed
  by iffE turn it into facts that fast_tac can finish with HOL_cs.*)
val [major] = goalw Sum.thy [Inl_Rep_def] "Inl_Rep(a) = Inl_Rep(c) ==> a=c";
by (rtac (major RS ap_thm RS ap_thm RS ap_thm RS iffE) 1);
by (fast_tac HOL_cs 1);
val Inl_Rep_inject = result();

(*Inr_Rep is injective.  Mirrors the proof of Inl_Rep_inject: the premise
  (bound to major, with Inr_Rep_def unfolded) is pushed through ap_thm three
  times and iffE, after which fast_tac closes the goal with HOL_cs.*)
val [major] = goalw Sum.thy [Inr_Rep_def] "Inr_Rep(b) = Inr_Rep(d) ==> b=d";
by (rtac (major RS ap_thm RS ap_thm RS ap_thm RS iffE) 1);
by (fast_tac HOL_cs 1);
val Inr_Rep_inject = result();

(*Inl is one-to-one.  After unfolding Inl_def, injectivity of Abs_Sum on Sum
  is chained with Inl_Rep_inject; the two membership side conditions are
  both instances of Inl_RepI.*)
goalw Sum.thy [Inl_def] "One_One(Inl)";
by (rtac One_OneI 1);
by (etac (One_One_on_Abs_Sum RS One_One_onD RS Inl_Rep_inject) 1);
by (rtac Inl_RepI 1);
by (rtac Inl_RepI 1);
val One_One_Inl = result();

(*Destruction form of One_One_Inl, obtained by resolving it with One_OneD*)
val Inl_inject = One_One_Inl RS One_OneD;

(*Inr is one-to-one.  After unfolding Inr_def, injectivity of Abs_Sum on Sum
  is chained with Inr_Rep_inject; the two membership side conditions are
  both instances of Inr_RepI.*)
goalw Sum.thy [Inr_def] "One_One(Inr)";
by (rtac One_OneI 1);
by (etac (One_One_on_Abs_Sum RS One_One_onD RS Inr_Rep_inject) 1);
by (rtac Inr_RepI 1);
by (rtac Inr_RepI 1);
val One_One_Inr = result();

(*Destruction form of One_One_Inr, obtained by resolving it with One_OneD*)
val Inr_inject = One_One_Inr RS One_OneD;

(** when -- the selection operator for sums **)

(*Conversion rule: when applied to Inl(x) yields f(x).
  when_def is unfolded in the goal, leaving a description (@z. ...); the
  whole proof is one fast_tac search over class_cs extended with
    - select_equality as an introduction rule, and
    - Inl_inject (as an elim-rule via make_elim) and Inl_neq_Inr as
      elimination rules, covering the Inl and Inr conjuncts of when_def.*)
goalw Sum.thy [when_def] "when(Inl(x), f, g) = f(x)";
by (fast_tac (class_cs addIs [select_equality] 
		       addEs [make_elim Inl_inject, Inl_neq_Inr]) 1);
val when_Inl_conv = result();

(*Conversion rule: when applied to Inr(x) yields g(x).
  when_def is unfolded in the goal, leaving a description (@z. ...); the
  whole proof is one fast_tac search over class_cs extended with
    - select_equality as an introduction rule, and
    - Inr_inject (as an elim-rule via make_elim) and Inr_neq_Inl as
      elimination rules, covering the Inr and Inl conjuncts of when_def.*)
goalw Sum.thy [when_def] "when(Inr(x), f, g) = g(x)";
by (fast_tac (class_cs addIs [select_equality] 
		       addEs [make_elim Inr_inject, Inr_neq_Inl]) 1);
val when_Inr_conv = result();

(** Exhaustion rule for sums -- a degenerate form of induction **)

(*sumE: to prove P(s) for an arbitrary s of sum type it suffices to prove
  P(Inl(x)) for all x and P(Inr(y)) for all y.  The two premises are bound
  to prems; Inl_def and Inr_def are unfolded in the goal and premises.*)
val prems = goalw Sum.thy [Inl_def,Inr_def]
    "[| !!x::'a. P(Inl(x));  !!y::'b. P(Inr(y)) \
\    |] ==> P(s)";
(*Use Rep_Sum_inverse with subst, instantiating t to s -- presumably this
  replaces s in the goal by Abs_Sum(Rep_Sum(s))*)
by (res_inst_tac [("t","s")] (Rep_Sum_inverse RS subst) 1);
(*Unfold Sum_def inside the axiom Rep_Sum and eliminate the resulting set
  membership with CollectE*)
br (rewrite_rule [Sum_def] Rep_Sum RS CollectE) 1;
(*Repeatedly split the disjunction, open the existentials and substitute the
  equations; whenever none of those applies, close the goal with a premise*)
by (REPEAT (eresolve_tac [disjE,exE,ssubst] 1 ORELSE resolve_tac prems 1));
val sumE = result();

(*Surjectivity-style law for when: analysing s and re-injecting both
  branches under f gives f(s).  The proof does case analysis on s with sumE
  and closes the Inl and Inr cases with the corresponding conversion rules.*)
goal Sum.thy "when(s, %x::'a. f(Inl(x)), %y::'b. f(Inr(y))) = f(s)";
by (res_inst_tac [("s","s")] sumE 1
    THEN rtac when_Inl_conv 1
    THEN rtac when_Inr_conv 1);
val surjective_sum = result();

