# Theory SPMF

```isabelle
(* Author: Andreas Lochbihler, ETH Zurich *)

section ‹Discrete subprobability distribution›

theory SPMF imports
Probability_Mass_Function
"HOL-Library.Complete_Partial_Order2"
"HOL-Library.Rewrite"
begin

subsection ‹Auxiliary material›

(* The supremum over a singleton index set is the single value. *)
lemma cSUP_singleton [simp]: "(SUP x∈{x}. f x :: _ :: conditionally_complete_lattice) = f x"
by (metis cSup_singleton image_empty image_insert)

(* Clamping a real at 0 does not change its ennreal value: ennreal maps negative reals to 0 (cf. ennreal_lt_0). *)
lemma [simp]:
shows ennreal_max_0: "ennreal (max 0 x) = ennreal x"
and ennreal_max_0': "ennreal (max x 0) = ennreal x"
by(simp_all add: max_def ennreal_eq_0_iff)

(* e2ennreal preserves 0; proved via the round trip e2ennreal (enn2ereal 0) = 0. *)
lemma e2ennreal_0 [simp]: "e2ennreal 0 = 0"
by (metis e2ennreal_enn2ereal enn2ereal_0)

(* On ennreal, ⊥ coincides with 0 (0 is below every value), so its real value is 0. *)
lemma enn2real_bot [simp]: "enn2real ⊥ = 0"
by (metis bot.extremum_uniqueI enn2real_0 zero_le)

(* Continuity is preserved by post-composition with ennreal. *)
lemma continuous_at_ennreal[continuous_intros]: "continuous F f ⟹ continuous F (λx. ennreal (f x))"
unfolding continuous_def by auto

(* ennreal commutes with Sup on a non-empty set of reals whose ennreal-supremum is finite.
   The finite supremum r provides the upper bound for bdd_above A; the remaining side
   conditions of continuous_at_Sup_mono (monotonicity, continuity of ennreal) are closed after qed. *)
lemma ennreal_Sup:
assumes *: "(SUP a∈A. ennreal a) ≠ ⊤"
and "A ≠ {}"
shows "ennreal (Sup A) = (SUP a∈A. ennreal a)"
proof (rule continuous_at_Sup_mono)
obtain r where r: "ennreal r = (SUP a∈A. ennreal a)" "r ≥ 0"
using * by(cases "(SUP a∈A. ennreal a)") simp_all
then show "bdd_above A"
by(auto intro!: SUP_upper bdd_aboveI[of _ r] simp add: ennreal_le_iff[symmetric])
qed (auto simp: mono_def continuous_at_imp_continuous_at_within continuous_at_ennreal ennreal_leI assms)

(* Indexed form of ennreal_Sup, obtained by instantiating A with f ` A. *)
lemma ennreal_SUP:
"⟦ (SUP a∈A. ennreal (f a)) ≠ ⊤; A ≠ {} ⟧ ⟹ ennreal (SUP a∈A. f a) = (SUP a∈A. ennreal (f a))"
using ennreal_Sup[of "f ` A"] by (auto simp: image_comp)

(* ennreal truncates negative reals to 0. *)
lemma ennreal_lt_0: "x < 0 ⟹ ennreal x = 0"
by(simp add: ennreal_eq_0_iff)

(* None lies in the image of map_option f iff it was already in the set. *)
lemma None_in_map_option_image [simp]: "None ∈ map_option f ` A ⟷ None ∈ A"
by auto

(* Some x lies in the image of map_option f iff x is the image of some y with Some y ∈ A. *)
lemma Some_in_map_option_image [simp]: "Some x ∈ map_option f ` A ⟷ (∃y. x = f y ∧ Some y ∈ A)"
by (smt (verit, best) imageE imageI map_option_eq_Some)

(* A case distinction on option whose branches coincide is the constant function. *)
lemma case_option_collapse: "case_option x (λ_. x) = (λ_. x)"
by(simp add: fun_eq_iff split: option.split)

(* Rebuilding an option from its own constructors is the identity. *)
lemma case_option_id: "case_option None Some = id"
by(rule ext)(simp split: option.split)

(* ord_option ord lifts a relation ord to options: None is below every value, and
   Some x is below Some y iff ord x y. There is no rule putting Some x below None. *)
inductive ord_option :: "('a ⇒ 'b ⇒ bool) ⇒ 'a option ⇒ 'b option ⇒ bool"
for ord :: "'a ⇒ 'b ⇒ bool"
where
None: "ord_option ord None x"
| Some: "ord x y ⟹ ord_option ord (Some x) (Some y)"

(* Simplification rules for ord_option on constructor patterns. *)
inductive_simps ord_option_simps [simp]:
"ord_option ord None x"
"ord_option ord x None"
"ord_option ord (Some x) (Some y)"
"ord_option ord (Some x) None"

(* Variants for the equality relation with an arbitrary right-hand argument. *)
inductive_simps ord_option_eq_simps [simp]:
"ord_option (=) None y"
"ord_option (=) (Some x) y"

(* ord_option ord is reflexive at x whenever ord is reflexive on the content of x. *)
lemma ord_option_reflI: "(⋀y. y ∈ set_option x ⟹ ord y y) ⟹ ord_option ord x x"
by(cases x) simp_all

(* Reflexivity lifts from ord to ord_option ord. *)
lemma reflp_ord_option: "reflp ord ⟹ reflp (ord_option ord)"
by(simp add: reflp_def ord_option_reflI)

(* Transitivity of ord_option only requires ord to be transitive on the actual contents of x, y, z. *)
lemma ord_option_trans:
"⟦ ord_option ord x y; ord_option ord y z;
⋀a b c. ⟦ a ∈ set_option x; b ∈ set_option y; c ∈ set_option z; ord a b; ord b c ⟧ ⟹ ord a c ⟧
⟹ ord_option ord x z"
by(auto elim!: ord_option.cases)

(* Transitivity lifts from ord to ord_option ord. *)
lemma transp_ord_option: "transp ord ⟹ transp (ord_option ord)"
unfolding transp_def by(blast intro: ord_option_trans)

(* Antisymmetry lifts from ord to ord_option ord. *)
lemma antisymp_ord_option: "antisymp ord ⟹ antisymp (ord_option ord)"
by(auto intro!: antisympI elim!: ord_option.cases dest: antisympD)

(* The Some-contents of an ord_option chain form an ord chain. *)
lemma ord_option_chainD:
"Complete_Partial_Order.chain (ord_option ord) Y
⟹ Complete_Partial_Order.chain ord {x. Some x ∈ Y}"
by(rule chainI)(auto dest: chainD)

(* Lift a lub operation to sets of options: None if Y contains no Some value,
   otherwise Some of the lub of the Some-contents of Y. *)
definition lub_option :: "('a set ⇒ 'b) ⇒ 'a option set ⇒ 'b option"
where "lub_option lub Y = (if Y ⊆ {None} then None else Some (lub {x. Some x ∈ Y}))"

(* Mapping the result of lub_option can be pushed into the lub operation. *)
lemma map_lub_option: "map_option f (lub_option lub Y) = lub_option (f ∘ lub) Y"
by(simp add: lub_option_def)

(* lub_option lub is an upper bound of ord_option chains, provided lub is an upper bound of ord chains. *)
lemma lub_option_upper:
assumes "Complete_Partial_Order.chain (ord_option ord) Y" "x ∈ Y"
and lub_upper: "⋀Y x. ⟦ Complete_Partial_Order.chain ord Y; x ∈ Y ⟧ ⟹ ord x (lub Y)"
shows "ord_option ord x (lub_option lub Y)"
using assms(1-2)
by(cases x)(auto simp: lub_option_def intro: lub_upper[OF ord_option_chainD])

(* lub_option lub is below every upper bound of an ord_option chain, provided lub is least for ord chains. *)
lemma lub_option_least:
assumes Y: "Complete_Partial_Order.chain (ord_option ord) Y"
and upper: "⋀x. x ∈ Y ⟹ ord_option ord x y"
assumes lub_least: "⋀Y y. ⟦ Complete_Partial_Order.chain ord Y; ⋀x. x ∈ Y ⟹ ord x y ⟧ ⟹ ord (lub Y) y"
shows "ord_option ord (lub_option lub Y) y"
using Y
by(cases y)(auto 4 3 simp add: lub_option_def intro: lub_least[OF ord_option_chainD])

(* Mapping the elements of Y can be moved into the lub operation as an image. *)
lemma lub_map_option: "lub_option lub (map_option f ` Y) = lub_option (lub ∘ (`) f) Y"
proof -
have "⋀u y. ⟦Some u ∈ Y; y ∈ Y⟧ ⟹ {f y |y. Some y ∈ Y} = f ` {x. Some x ∈ Y}"
by blast
then show ?thesis
by (auto simp: lub_option_def)
qed

(* ord_option is monotone in the underlying relation. *)
lemma ord_option_mono: "⟦ ord_option A x y; ⋀x y. A x y ⟹ B x y ⟧ ⟹ ord_option B x y"
by(auto elim: ord_option.cases)

(* Object-logic (⟶) variant of ord_option_mono, declared as a [mono] rule. *)
lemma ord_option_mono' [mono]:
"(⋀x y. A x y ⟶ B x y) ⟹ ord_option A x y ⟶ ord_option B x y"
by(blast intro: ord_option_mono)

(* ord_option commutes with relation composition. *)
lemma ord_option_compp: "ord_option (A OO B) = ord_option A OO ord_option B"
by(auto simp: fun_eq_iff elim!: ord_option.cases intro: ord_option.intros)

(* ord_option commutes with intersection (inf) of relations. *)
lemma ord_option_inf: "inf (ord_option A) (ord_option B) = ord_option (inf A B)" (is "?lhs = ?rhs")
proof(rule antisym)
show "?lhs ≤ ?rhs" by(auto elim!: ord_option.cases)
qed(auto elim: ord_option_mono)

(* Mapping the right argument can be pushed into the underlying relation. *)
lemma ord_option_map2: "ord_option ord x (map_option f y) = ord_option (λx y. ord x (f y)) x y"
by(auto elim: ord_option.cases)

(* Mapping the left argument can be pushed into the underlying relation. *)
lemma ord_option_map1: "ord_option ord (map_option f x) y = ord_option (λx y. ord (f x) y) x y"
by(auto elim: ord_option.cases)

(* In the flat option ordering option_ord, Some x is only below itself. *)
lemma option_ord_Some1_iff: "option_ord (Some x) y ⟷ y = Some x"
by(auto simp: flat_ord_def)

subsubsection ‹A relator for sets that treats sets like predicates›

context includes lifting_syntax
begin

(* rel_pred R A B holds iff, for all R-related x and y, x ∈ A exactly when y ∈ B. *)
definition rel_pred :: "('a ⇒ 'b ⇒ bool) ⇒ 'a set ⇒ 'b set ⇒ bool"
where "rel_pred R A B = (R ===> (=)) (λx. x ∈ A) (λy. y ∈ B)"

lemma rel_predI: "(R ===> (=)) (λx. x ∈ A) (λy. y ∈ B) ⟹ rel_pred R A B"
by(simp add: rel_pred_def)

lemma rel_predD: "⟦ rel_pred R A B; R x y ⟧ ⟹ x ∈ A ⟷ y ∈ B"
unfolding rel_pred_def rel_fun_def by blast

(* Collect maps related predicates to rel_pred-related sets. *)
lemma Collect_parametric: "((A ===> (=)) ===> rel_pred A) Collect Collect"
― ‹Declare this rule as @{attribute transfer_rule} only locally
because it blows up the search space for @{method transfer}
(in combination with @{thm [source] Collect_transfer})›
by(rule rel_funI)(simp add: rel_pred_def)

end

subsubsection ‹Monotonicity rules›

(* Addition on enat is monotone in each argument w.r.t. the reversed order (≥). *)
lemma monotone_gfp_eadd1: "monotone (≥) (≥) (λx. x + y :: enat)"
by(auto intro!: monotoneI)

lemma monotone_gfp_eadd2: "monotone (≥) (≥) (λy. x + y :: enat)"
by(auto intro!: monotoneI)

(* Joint monotonicity of enat addition on pairs; the [THEN gfp.mono2mono2] form is
   registered as a cont_intro/simp rule, the raw statement is named monotone_eadd. *)
lemma mono2mono_gfp_eadd[THEN gfp.mono2mono2, cont_intro, simp]:
shows monotone_eadd: "monotone (rel_prod (≥) (≥)) (≥) (λ(x, y). x + y :: enat)"
by(auto intro!: monotoneI add_mono)

(* The sum of two (≥)-monotone functionals is (≥)-monotone (registered for partial_function). *)
lemma eadd_gfp_partial_function_mono [partial_function_mono]:
"⟦ monotone (fun_ord (≥)) (≥) f; monotone (fun_ord (≥)) (≥) g ⟧
⟹ monotone (fun_ord (≥)) (≥) (λx. f x + g x :: enat)"
by(rule mono2mono_gfp_eadd)

(* The embeddings of the reals into ereal and ennreal are monotone. *)
lemma mono2mono_ereal[THEN lfp.mono2mono]:
shows monotone_ereal: "monotone (≤) (≤) ereal"
by(rule monotoneI) simp

lemma mono2mono_ennreal[THEN lfp.mono2mono]:
shows monotone_ennreal: "monotone (≤) (≤) ennreal"
by(rule monotoneI)(simp add: ennreal_leI)

subsubsection ‹Bijections›

(* A bi-unique relation R relating A and B set-wise yields a bijection f : A → B whose graph lies in R. *)
lemma bi_unique_rel_set_bij_betw:
assumes unique: "bi_unique R"
and rel: "rel_set R A B"
shows "∃f. bij_betw f A B ∧ (∀x∈A. R x (f x))"
proof -
from assms obtain f where f: "⋀x. x ∈ A ⟹ R x (f x)" and B: "⋀x. x ∈ A ⟹ f x ∈ B"
by (metis bi_unique_rel_set_lemma image_eqI)
have "inj_on f A"
by (metis (no_types, lifting) bi_unique_def f inj_on_def unique)
moreover have "f ` A = B" using rel
by (smt (verit) bi_unique_def bi_unique_rel_set_lemma f image_cong unique)
ultimately have "bij_betw f A B" unfolding bij_betw_def ..
thus ?thesis using f by blast
qed

(* Conversely, the graph of a bijection between A and B relates A and B set-wise. *)
lemma bij_betw_rel_setD: "bij_betw f A B ⟹ rel_set (λx y. y = f x) A B"
by(rule rel_setI)(auto dest: bij_betwE bij_betw_imp_surj_on[symmetric])

subsection ‹Subprobability mass function›

(* A subprobability mass function is a pmf on 'a option; None carries the mass not assigned to any element of 'a. *)
type_synonym 'a spmf = "'a option pmf"
translations (type) "'a spmf" ↽ (type) "'a option pmf"

(* The measure on 'a induced by an spmf: restrict to the Some-values and strip the Some via the. *)
definition measure_spmf :: "'a spmf ⇒ 'a measure"
where "measure_spmf p = distr (restrict_space (measure_pmf p) (range Some)) (count_space UNIV) the"

(* Probability that p yields the element x. *)
abbreviation spmf :: "'a spmf ⇒ 'a ⇒ real"
where "spmf p x ≡ pmf p (Some x)"

lemma space_measure_spmf: "space (measure_spmf p) = UNIV"
by(simp add: measure_spmf_def)

lemma sets_measure_spmf [simp, measurable_cong]: "sets (measure_spmf p) = sets (count_space UNIV)"
by(simp add: measure_spmf_def)

lemma measure_spmf_not_bot [simp]: "measure_spmf p ≠ ⊥"
by (metis empty_not_UNIV space_bot space_measure_spmf)

(* Measurability facts: every function out of measure_spmf is measurable, and measurability
   into measure_spmf is the same as into the discrete count_space. *)
lemma measurable_the_measure_pmf_Some [measurable, simp]:
"the ∈ measurable (restrict_space (measure_pmf p) (range Some)) (count_space UNIV)"
by(auto simp: measurable_def sets_restrict_space space_restrict_space integral_restrict_space)

lemma measurable_spmf_measure1[simp]: "measurable (measure_spmf M) N = UNIV → space N"
by(auto simp: measurable_def space_measure_spmf)

lemma measurable_spmf_measure2[simp]: "measurable N (measure_spmf M) = measurable N (count_space UNIV)"
by(intro measurable_cong_sets) simp_all

(* measure_spmf p is a subprobability measure on a non-empty space. *)
lemma subprob_space_measure_spmf [simp, intro!]: "subprob_space (measure_spmf p)"
proof
show "emeasure (measure_spmf p) (space (measure_spmf p)) ≤ 1"
by(simp add: measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space measure_pmf.measure_le_1)
qed(simp add: space_measure_spmf)

interpretation measure_spmf: subprob_space "measure_spmf p" for p
by(rule subprob_space_measure_spmf)

lemma finite_measure_spmf [simp]: "finite_measure (measure_spmf p)"
by unfold_locales

(* The mass of a single element, expressed via measure_spmf. *)
lemma spmf_conv_measure_spmf: "spmf p x = measure (measure_spmf p) {x}"
by(auto simp: measure_spmf_def measure_distr measure_restrict_space pmf.rep_eq space_restrict_space intro: arg_cong2[where f=measure])

(* measure_spmf measures A as the pmf measures the Some-image of A. *)
lemma emeasure_measure_spmf_conv_measure_pmf:
"emeasure (measure_spmf p) A = emeasure (measure_pmf p) (Some ` A)"
by(auto simp: measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space intro: arg_cong2[where f=emeasure])

lemma measure_measure_spmf_conv_measure_pmf:
"measure (measure_spmf p) A = measure (measure_pmf p) (Some ` A)"
using emeasure_measure_spmf_conv_measure_pmf[of p A]
by(simp add: measure_spmf.emeasure_eq_measure measure_pmf.emeasure_eq_measure)

(* Embedding a pmf via map_pmf Some does not change its measure. *)
lemma emeasure_spmf_map_pmf_Some [simp]:
"emeasure (measure_spmf (map_pmf Some p)) A = emeasure (measure_pmf p) A"
by(auto simp: measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space intro: arg_cong2[where f=emeasure])

lemma measure_spmf_map_pmf_Some [simp]:
"measure (measure_spmf (map_pmf Some p)) A = measure (measure_pmf p) A"
using emeasure_spmf_map_pmf_Some[of p A] by(simp add: measure_spmf.emeasure_eq_measure measure_pmf.emeasure_eq_measure)

(* Nonnegative integrals over measure_spmf p are spmf-weighted sums over count_space. *)
lemma nn_integral_measure_spmf: "(∫⇧+ x. f x ∂measure_spmf p) = ∫⇧+ x. ennreal (spmf p x) * f x ∂count_space UNIV"
(is "?lhs = ?rhs")
proof -
have "?lhs = ∫⇧+ x. pmf p x * f (the x) ∂count_space (range Some)"
by(simp add: measure_spmf_def nn_integral_distr nn_integral_restrict_space nn_integral_measure_pmf nn_integral_count_space_indicator ac_simps
flip: times_ereal.simps [symmetric])
also have "… = ∫⇧+ x. ennreal (spmf p (the x)) * f (the x) ∂count_space (range Some)"
by(rule nn_integral_cong) auto
(* Reindex the sum over range Some by the injective function Some. *)
also have "… = ∫⇧+ x. spmf p (the (Some x)) * f (the (Some x)) ∂count_space UNIV"
by(subst nn_integral_count_space_reindex)(simp_all add: inj_on_def)
also have "… = ?rhs" by simp
finally show ?thesis .
qed

(* Bochner integrals over measure_spmf p are spmf-weighted sums over count_space. *)
lemma integral_measure_spmf:
assumes "integrable (measure_spmf p) f"
shows "(∫ x. f x ∂measure_spmf p) = ∫ x. spmf p x * f x ∂count_space UNIV"
proof -
have "integrable (count_space UNIV) (λx. spmf p x * f x)"
using assms by(simp add: integrable_iff_bounded nn_integral_measure_spmf abs_mult ennreal_mult'')
(* Split both integrals into positive and negative parts and compare them with nn_integral_measure_spmf. *)
then show ?thesis using assms
by(simp add: real_lebesgue_integral_def nn_integral_measure_spmf ennreal_mult'[symmetric])
qed

lemma emeasure_spmf_single: "emeasure (measure_spmf p) {x} = spmf p x"
by(simp add: measure_spmf.emeasure_eq_measure spmf_conv_measure_spmf)

lemma measurable_measure_spmf[measurable]:
"(λx. measure_spmf (M x)) ∈ measurable (count_space UNIV) (subprob_algebra (count_space UNIV))"
by (auto simp: space_subprob_algebra)

(* Integrals over measure_spmf p are integrals over the Some-part of p, with the Some stripped. *)
lemma nn_integral_measure_spmf_conv_measure_pmf:
assumes [measurable]: "f ∈ borel_measurable (count_space UNIV)"
shows "nn_integral (measure_spmf p) f = nn_integral (restrict_space (measure_pmf p) (range Some)) (f ∘ the)"
by(simp add: measure_spmf_def nn_integral_distr o_def)

lemma measure_spmf_in_space_subprob_algebra [simp]:
"measure_spmf p ∈ space (subprob_algebra (count_space UNIV))"
by(simp add: space_subprob_algebra)

(* Finiteness facts: total and pointwise spmf masses are bounded by 1. *)
lemma nn_integral_spmf_neq_top: "(∫⇧+ x. spmf p x ∂count_space UNIV) ≠ ⊤"
using nn_integral_measure_spmf[where f="λ_. 1", of p, symmetric]
by simp

lemma SUP_spmf_neq_top': "(SUP p∈Y. ennreal (spmf p x)) ≠ ⊤"
by (metis SUP_least ennreal_le_1 ennreal_one_neq_top neq_top_trans pmf_le_1)

lemma SUP_spmf_neq_top: "(SUP i. ennreal (spmf (Y i) x)) ≠ ⊤"
by (meson SUP_eq_top_iff ennreal_le_1 ennreal_one_less_top linorder_not_le pmf_le_1)

lemma SUP_emeasure_spmf_neq_top: "(SUP p∈Y. emeasure (measure_spmf p) A) ≠ ⊤"
by (metis ennreal_one_less_top less_SUP_iff linorder_not_le measure_spmf.subprob_emeasure_le_1)

subsection ‹Support›

(* Support of an spmf: the contents of the Some-values in the support of the underlying pmf. *)
definition set_spmf :: "'a spmf ⇒ 'a set"
where "set_spmf p = set_pmf p ⤜ set_option"

lemma set_spmf_rep_eq: "set_spmf p = {x. measure (measure_spmf p) {x} ≠ 0}"
proof -
have "⋀x :: 'a. the -` {x} ∩ range Some = {Some x}" by auto
then show ?thesis
unfolding set_spmf_def measure_spmf_def
by(auto simp: set_pmf.rep_eq  measure_distr measure_restrict_space space_restrict_space)
qed

lemma in_set_spmf: "x ∈ set_spmf p ⟷ Some x ∈ set_pmf p"
by(simp add: set_spmf_def bind_UNION)

lemma AE_measure_spmf_iff [simp]: "(AE x in measure_spmf p. P x) ⟷ (∀x∈set_spmf p. P x)"
unfolding set_spmf_def measure_spmf_def
by(force simp: AE_distr_iff AE_restrict_space_iff AE_measure_pmf_iff cong del: AE_cong)

lemma spmf_eq_0_set_spmf: "spmf p x = 0 ⟷ x ∉ set_spmf p"
by(auto simp: pmf_eq_0_set_pmf set_spmf_def)

lemma in_set_spmf_iff_spmf: "x ∈ set_spmf p ⟷ spmf p x ≠ 0"
by(auto simp: set_spmf_def set_pmf_iff)

lemma set_spmf_return_pmf_None [simp]: "set_spmf (return_pmf None) = {}"
by(auto simp: set_spmf_def)

(* The support is countable as a countable union of the (at most singleton) sets set_option y. *)
lemma countable_set_spmf [simp]: "countable (set_spmf p)"
by(simp add: set_spmf_def bind_UNION)

(* Extensionality: two spmfs agreeing on every element are equal. The None-mass is
   determined by the Some-masses as 1 minus the measure of range Some, which is
   computed as a countable sum of point masses over the (equal) supports. *)
lemma spmf_eqI:
assumes "⋀i. spmf p i = spmf q i"
shows "p = q"
proof(rule pmf_eqI)
fix i
show "pmf p i = pmf q i"
proof(cases i)
case (Some i')
then show ?thesis by(simp add: assms)
next
case None
have "ennreal (pmf p i) = measure (measure_pmf p) {i}" by(simp add: pmf_def)
also have "{i} = space (measure_pmf p) - range Some"
by(auto simp: None intro: ccontr)
also have "measure (measure_pmf p) … = ennreal 1 - measure (measure_pmf p) (range Some)"
by(simp add: measure_pmf.prob_compl ennreal_minus[symmetric] del: space_measure_pmf)
also have "range Some = (⋃x∈set_spmf p. {Some x}) ∪ Some ` (- set_spmf p)"
by auto
also have "measure (measure_pmf p) … = measure (measure_pmf p) (⋃x∈set_spmf p. {Some x})"
by(rule measure_pmf.measure_zero_union)(auto simp: measure_pmf.prob_eq_0 AE_measure_pmf_iff in_set_spmf_iff_spmf set_pmf_iff)
also have "ennreal … = ∫⇧+ x. measure (measure_pmf p) {Some x} ∂count_space (set_spmf p)"
unfolding measure_pmf.emeasure_eq_measure[symmetric]
by(simp_all add: emeasure_UN_countable disjoint_family_on_def)
also have "… = ∫⇧+ x. spmf p x ∂count_space (set_spmf p)" by(simp add: pmf_def)
also have "… = ∫⇧+ x. spmf q x ∂count_space (set_spmf p)" by(simp add: assms)
also have "set_spmf p = set_spmf q" by(auto simp: in_set_spmf_iff_spmf assms)
also have "(∫⇧+ x. spmf q x ∂count_space (set_spmf q)) = ∫⇧+ x. measure (measure_pmf q) {Some x} ∂count_space (set_spmf q)"
by(simp add: pmf_def)
also have "… = measure (measure_pmf q) (⋃x∈set_spmf q. {Some x})"
unfolding measure_pmf.emeasure_eq_measure[symmetric]
by(simp_all add: emeasure_UN_countable disjoint_family_on_def)
also have "… = measure (measure_pmf q) ((⋃x∈set_spmf q. {Some x}) ∪ Some ` (- set_spmf q))"
by(rule ennreal_cong measure_pmf.measure_zero_union[symmetric])+(auto simp: measure_pmf.prob_eq_0 AE_measure_pmf_iff in_set_spmf_iff_spmf set_pmf_iff)
also have "((⋃x∈set_spmf q. {Some x}) ∪ Some ` (- set_spmf q)) = range Some" by auto
also have "ennreal 1 - measure (measure_pmf q) … = measure (measure_pmf q) (space (measure_pmf q) - range Some)"
by(simp add: one_ereal_def measure_pmf.prob_compl ennreal_minus[symmetric] del: space_measure_pmf)
also have "space (measure_pmf q) - range Some = {i}"
by(auto simp: None intro: ccontr)
also have "measure (measure_pmf q) … = pmf q i" by(simp add: pmf_def)
finally show ?thesis by simp
qed
qed

(* Integrals over measure_spmf M may be restricted to the support set_spmf M. *)
lemma integral_measure_spmf_restrict:
fixes f ::  "'a ⇒ 'b :: {banach, second_countable_topology}"
shows "(∫ x. f x ∂measure_spmf M) = (∫ x. f x ∂restrict_space (measure_spmf M) (set_spmf M))"
by(auto intro!: integral_cong_AE simp add: integral_restrict_space)

(* Variant of nn_integral_measure_spmf that sums only over the support set_spmf p. *)
lemma nn_integral_measure_spmf':
"(∫⇧+ x. f x ∂measure_spmf p) = ∫⇧+ x. ennreal (spmf p x) * f x ∂count_space (set_spmf p)"
by(auto simp: nn_integral_measure_spmf nn_integral_count_space_indicator in_set_spmf_iff_spmf intro!: nn_integral_cong split: split_indicator)

subsection ‹Functorial structure›

(* Functorial map on spmfs: apply f underneath both pmf and option. *)
abbreviation map_spmf :: "('a ⇒ 'b) ⇒ 'a spmf ⇒ 'b spmf"
where "map_spmf f ≡ map_pmf (map_option f)"

context begin
local_setup ‹Local_Theory.map_background_naming (Name_Space.mandatory_path "spmf")›

(* Functor laws; the mandatory path makes these spmf.map_comp, spmf.map_id0, spmf.map_id, spmf.map_ident. *)
lemma map_comp: "map_spmf f (map_spmf g p) = map_spmf (f ∘ g) p"
by(simp add: pmf.map_comp o_def option.map_comp)

lemma map_id0: "map_spmf id = id"
by(simp add: pmf.map_id0 option.map_id0)

lemma map_id [simp]: "map_spmf id p = p"
by(simp add: pmf.map_id option.map_id0)

lemma map_ident [simp]: "map_spmf (λx. x) p = p"
by(simp add: id_def[symmetric] pmf.map_id option.map_id0)

end

lemma set_map_spmf [simp]: "set_spmf (map_spmf f p) = f ` set_spmf p"
by(simp add: set_spmf_def image_bind bind_image o_def Option.option.set_map)

(* Congruence rule: map_spmf only depends on f on the support. *)
lemma map_spmf_cong:
"⟦ p = q; ⋀x. x ∈ set_spmf q ⟹ f x = g x ⟧ ⟹ map_spmf f p = map_spmf g q"
by(auto intro: pmf.map_cong option.map_cong simp add: in_set_spmf)

lemma map_spmf_cong_simp:
"⟦ p = q; ⋀x. x ∈ set_spmf q =simp=> f x = g x ⟧
⟹ map_spmf f p = map_spmf g q"
unfolding simp_implies_def by(rule map_spmf_cong)

(* Mapping with a function that is the identity on the support does nothing. *)
lemma map_spmf_idI: "(⋀x. x ∈ set_spmf p ⟹ f x = x) ⟹ map_spmf f p = p"
by(rule map_pmf_idI map_option_idI)+(simp add: in_set_spmf)

(* The measure of a mapped spmf is the measure of the preimage. *)
lemma emeasure_map_spmf:
"emeasure (measure_spmf (map_spmf f p)) A = emeasure (measure_spmf p) (f -` A)"
by(auto simp: measure_spmf_def emeasure_distr measurable_restrict_space1 space_restrict_space emeasure_restrict_space intro: arg_cong2[where f=emeasure])

lemma measure_map_spmf: "measure (measure_spmf (map_spmf f p)) A = measure (measure_spmf p) (f -` A)"
using emeasure_map_spmf[of f p A] by(simp add: measure_spmf.emeasure_eq_measure)

lemma measure_map_spmf_conv_distr:
"measure_spmf (map_spmf f p) = distr (measure_spmf p) (count_space UNIV) f"
by(rule measure_eqI)(simp_all add: emeasure_map_spmf emeasure_distr space_measure_spmf)

lemma spmf_map_pmf_Some [simp]: "spmf (map_pmf Some p) i = pmf p i"
by(simp add: pmf_map_inj')

(* Point masses are preserved under maps injective on the support. *)
lemma spmf_map_inj: "⟦ inj_on f (set_spmf M); x ∈ set_spmf M ⟧ ⟹ spmf (map_spmf f M) (f x) = spmf M x"
by (smt (verit) elem_set in_set_spmf inj_on_def option.inj_map_strong option.map(2) pmf_map_inj)

lemma spmf_map_inj': "inj f ⟹ spmf (map_spmf f M) (f x) = spmf M x"
by(subst option.map(2)[symmetric, where f=f])(rule pmf_map_inj'[OF option.inj_map])

lemma spmf_map_outside: "x ∉ f ` set_spmf M ⟹ spmf (map_spmf f M) x = 0"
unfolding spmf_eq_0_set_spmf by simp

(* Point masses of a mapped spmf as the measure of the fibre f -` {x}. *)
lemma ennreal_spmf_map: "ennreal (spmf (map_spmf f p) x) = emeasure (measure_spmf p) (f -` {x})"
by (metis emeasure_map_spmf emeasure_spmf_single)

lemma spmf_map: "spmf (map_spmf f p) x = measure (measure_spmf p) (f -` {x})"
using ennreal_spmf_map[of f p x] by(simp add: measure_spmf.emeasure_eq_measure)

lemma ennreal_spmf_map_conv_nn_integral:
"ennreal (spmf (map_spmf f p) x) = integral⇧N (measure_spmf p) (indicator (f -` {x}))"
by(simp add: ennreal_spmf_map)

subsubsection ‹Return›

(* Dirac spmf that yields x with probability 1. *)
abbreviation return_spmf :: "'a ⇒ 'a spmf"
where "return_spmf x ≡ return_pmf (Some x)"

lemma pmf_return_spmf: "pmf (return_spmf x) y = indicator {y} (Some x)"
by(fact pmf_return)

lemma measure_spmf_return_spmf: "measure_spmf (return_spmf x) = Giry_Monad.return (count_space UNIV) x"
by(rule measure_eqI)(simp_all add: measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space indicator_def)

(* return_pmf None puts all mass on None, so the induced measure on 'a is zero. *)
lemma measure_spmf_return_pmf_None [simp]: "measure_spmf (return_pmf None) = null_measure (count_space UNIV)"
by(rule measure_eqI)(auto simp: measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space indicator_eq_0_iff)

lemma set_return_spmf [simp]: "set_spmf (return_spmf x) = {x}"
by(auto simp: set_spmf_def)

subsubsection ‹Bind›

(* Monadic bind on spmfs: None propagates as return_pmf None, Some a' continues with f a'. *)
definition bind_spmf :: "'a spmf ⇒ ('a ⇒ 'b spmf) ⇒ 'b spmf"
where "bind_spmf x f = bind_pmf x (λa. case a of None ⇒ return_pmf None | Some a' ⇒ f a')"

(* Let the monad syntax ⤜ resolve to bind_spmf on spmf arguments; all spmf uses of ⤜ below rely on this.
   NOTE(review): newer Isabelle releases write `adhoc_overloading Monad_Syntax.bind ⇌ bind_spmf`;
   adjust to the Isabelle version in use. *)
adhoc_overloading Monad_Syntax.bind bind_spmf

lemma return_None_bind_spmf [simp]: "return_pmf None ⤜ (f :: 'a ⇒ _) = return_pmf None"
by(simp add: bind_spmf_def bind_return_pmf)

(* Left unit law. *)
lemma return_bind_spmf [simp]: "return_spmf x ⤜ f = f x"
by(simp add: bind_spmf_def bind_return_pmf)

(* Right unit law. *)
lemma bind_return_spmf [simp]: "x ⤜ return_spmf = x"
proof -
have "⋀a :: 'a option. (case a of None ⇒ return_pmf None | Some a' ⇒ return_spmf a') = return_pmf a"
by(simp split: option.split)
then show ?thesis
by(simp add: bind_spmf_def bind_return_pmf')
qed

(* Associativity law of the spmf monad. *)
lemma bind_spmf_assoc [simp]:
fixes x :: "'a spmf" and f :: "'a ⇒ 'b spmf" and g :: "'b ⇒ 'c spmf"
shows "(x ⤜ f) ⤜ g = x ⤜ (λy. f y ⤜ g)"
unfolding bind_spmf_def
by (smt (verit, best) bind_assoc_pmf bind_pmf_cong bind_return_pmf option.case_eq_if)

(* The None-mass of a bind: the None-mass of p plus the expected None-mass of the continuations. *)
lemma pmf_bind_spmf_None: "pmf (p ⤜ f) None = pmf p None + ∫ x. pmf (f x) None ∂measure_spmf p"
(is "?lhs = ?rhs")
proof -
let ?f = "λx. pmf (case x of None ⇒ return_pmf None | Some x ⇒ f x) None"
have "?lhs = ∫ x. ?f x ∂measure_pmf p"
by(simp add: bind_spmf_def pmf_bind)
also have "… = ∫ x. ?f None * indicator {None} x + ?f x * indicator (range Some) x ∂measure_pmf p"
by(rule Bochner_Integration.integral_cong)(auto simp: indicator_def)
also have "… = (∫ x. ?f None * indicator {None} x ∂measure_pmf p) + (∫ x. ?f x * indicator (range Some) x ∂measure_pmf p)"
by(rule Bochner_Integration.integral_add)(auto 4 3 intro: integrable_real_mult_indicator measure_pmf.integrable_const_bound[where B=1] simp add: AE_measure_pmf_iff pmf_le_1)
also have "… = pmf p None + ∫ x. indicator (range Some) x * pmf (f (the x)) None ∂measure_pmf p"
by(auto simp: measure_measure_pmf_finite indicator_eq_0_iff intro!: Bochner_Integration.integral_cong)
also have "… = ?rhs"
unfolding measure_spmf_def
by(subst integral_distr)(auto simp: integral_restrict_space)
finally show ?thesis .
qed

(* The mass of an element in a bind is the expected mass under the continuations. *)
lemma spmf_bind: "spmf (p ⤜ f) y = ∫ x. spmf (f x) y ∂measure_spmf p"
proof -
have "⋀x. spmf (case x of None ⇒ return_pmf None | Some x ⇒ f x) y =
indicat_real (range Some) x * spmf (f (the x)) y"
by(auto simp: indicator_def split: option.split)
then show ?thesis
by (simp add: measure_spmf_def integral_distr bind_spmf_def pmf_bind integral_restrict_space)
qed

lemma ennreal_spmf_bind: "ennreal (spmf (p ⤜ f) x) = ∫⇧+ y. spmf (f y) x ∂measure_spmf p"
proof -
have "⋀y. ennreal (spmf (case y of None ⇒ return_pmf None | Some x ⇒ f x) x) =
ennreal (spmf (f (the y)) x) * indicator (range Some) y"
by(auto simp: indicator_def split: option.split)
then show ?thesis
by (simp add: bind_spmf_def ennreal_pmf_bind nn_integral_measure_spmf_conv_measure_pmf nn_integral_restrict_space)
qed

(* measure_spmf of a pmf-bind is the Giry-monad bind of the pmf measure. *)
lemma measure_spmf_bind_pmf: "measure_spmf (p ⤜ f) = measure_pmf p ⤜ measure_spmf ∘ f"
(is "?lhs = ?rhs")
proof(rule measure_eqI)
show "sets ?lhs = sets ?rhs"
by(simp add: sets_bind[where N="count_space UNIV"] space_measure_spmf)
next
fix A :: "'a set"
have "emeasure ?lhs A = ∫⇧+ x. emeasure (measure_spmf (f x)) A ∂measure_pmf p"
by(simp add: measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space bind_spmf_def)
also have "… = emeasure ?rhs A"
by(simp add: emeasure_bind[where N="count_space UNIV"] space_measure_spmf space_subprob_algebra)
finally show "emeasure ?lhs A = emeasure ?rhs A" .
qed

(* measure_spmf of an spmf-bind is the Giry-monad bind of the spmf measure. *)
lemma measure_spmf_bind: "measure_spmf (p ⤜ f) = measure_spmf p ⤜ measure_spmf ∘ f"
(is "?lhs = ?rhs")
proof(rule measure_eqI)
show "sets ?lhs = sets ?rhs"
by(simp add: sets_bind[where N="count_space UNIV"] space_measure_spmf)
next
fix A :: "'a set"
let ?A = "the -` A ∩ range Some"
have "emeasure ?lhs A = ∫⇧+ x. emeasure (measure_pmf (case x of None ⇒ return_pmf None | Some x ⇒ f x)) ?A ∂measure_pmf p"
by(simp add: measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space bind_spmf_def)
also have "… =  ∫⇧+ x. emeasure (measure_pmf (f (the x))) ?A * indicator (range Some) x ∂measure_pmf p"
by(rule nn_integral_cong)(auto split: option.split simp add: indicator_def)
also have "… = ∫⇧+ x. emeasure (measure_spmf (f x)) A ∂measure_spmf p"
by(simp add: measure_spmf_def nn_integral_distr nn_integral_restrict_space emeasure_distr space_restrict_space emeasure_restrict_space)
also have "… = emeasure ?rhs A"
by(simp add: emeasure_bind[where N="count_space UNIV"] space_measure_spmf space_subprob_algebra)
finally show "emeasure ?lhs A = emeasure ?rhs A" .
qed

(* Interaction of map_spmf with bind_spmf. *)
lemma map_spmf_bind_spmf: "map_spmf f (bind_spmf p g) = bind_spmf p (map_spmf f ∘ g)"
by(auto simp: bind_spmf_def map_bind_pmf fun_eq_iff split: option.split intro: arg_cong2[where f=bind_pmf])

lemma bind_map_spmf: "map_spmf f p ⤜ g = p ⤜ g ∘ f"
by(simp add: bind_spmf_def bind_map_pmf o_def cong del: option.case_cong_weak)

(* A bound r ≥ 0 on the continuations' masses at x bounds the mass of the bind at x. *)
lemma spmf_bind_leI:
assumes "⋀y. y ∈ set_spmf p ⟹ spmf (f y) x ≤ r"
and "0 ≤ r"
shows "spmf (bind_spmf p f) x ≤ r"
proof -
have "ennreal (spmf (bind_spmf p f) x) = ∫⇧+ y. spmf (f y) x ∂measure_spmf p"
by(rule ennreal_spmf_bind)
also have "… ≤ ∫⇧+ y. r ∂measure_spmf p"
by(rule nn_integral_mono_AE)(simp add: assms)
also have "… ≤ r"
using assms measure_spmf.emeasure_space_le_1
by(auto simp: measure_spmf.emeasure_eq_measure intro!: mult_left_le)
finally show ?thesis using assms(2) by(simp)
qed

lemma map_spmf_conv_bind_spmf: "map_spmf f p = (p ⤜ (λx. return_spmf (f x)))"
by(simp add: map_pmf_def bind_spmf_def)(rule bind_pmf_cong, simp_all split: option.split)

(* Congruence rule: bind_spmf only depends on the continuation on the support. *)
lemma bind_spmf_cong:
"⟦ p = q; ⋀x. x ∈ set_spmf q ⟹ f x = g x ⟧ ⟹ bind_spmf p f = bind_spmf q g"
by(auto simp: bind_spmf_def in_set_spmf intro: bind_pmf_cong option.case_cong)

lemma bind_spmf_cong_simp:
"⟦ p = q; ⋀x. x ∈ set_spmf q =simp=> f x = g x ⟧
⟹ bind_spmf p f = bind_spmf q g"
unfolding simp_implies_def by(rule bind_spmf_cong)

lemma set_bind_spmf: "set_spmf (M ⤜ f) = set_spmf M ⤜ (set_spmf ∘ f)"
by(auto simp: set_spmf_def bind_spmf_def bind_UNION split: option.splits)

(* Binding into the constant continuation return_pmf None yields return_pmf None. *)
lemma bind_spmf_const_return_None [simp]: "bind_spmf p (λ_. return_pmf None) = return_pmf None"
by(simp add: bind_spmf_def case_option_collapse)

(* Independent binds commute. The proof merges both case distinctions into ?f and
   then reduces to bind_commute_pmf on the underlying pmfs. *)
lemma bind_commute_spmf:
"bind_spmf p (λx. bind_spmf q (f x)) = bind_spmf q (λy. bind_spmf p (λx. f x y))"
(is "?lhs = ?rhs")
proof -
let ?f = "λx y. case x of None ⇒ return_pmf None | Some a ⇒ (case y of None ⇒ return_pmf None | Some b ⇒ f a b)"
have "?lhs = p ⤜ (λx. q ⤜ ?f x)"
unfolding bind_spmf_def by(rule bind_pmf_cong[OF refl])(simp split: option.split)
also have "… = q ⤜ (λy. p ⤜ (λx. ?f x y))" by(rule bind_commute_pmf)
also have "… = ?rhs" unfolding bind_spmf_def
by(rule bind_pmf_cong[OF refl])(auto split: option.split, metis bind_spmf_const_return_None bind_spmf_def)
finally show ?thesis .
qed

subsection ‹Relator›

(* Relator on spmfs: the pmf relator applied to the option relator, so None only relates to None. *)
abbreviation rel_spmf :: "('a ⇒ 'b ⇒ bool) ⇒ 'a spmf ⇒ 'b spmf ⇒ bool"
where "rel_spmf R ≡ rel_pmf (rel_option R)"

(* Monotonicity of rel_spmf in the underlying relation. *)
lemma rel_spmf_mono:
"⟦rel_spmf A f g; ⋀x y. A x y ⟹ B x y ⟧ ⟹ rel_spmf B f g"
by (metis option.rel_sel pmf.rel_mono_strong)

(* Variant where the implication only needs to hold on the supports. *)
lemma rel_spmf_mono_strong:
"⟦ rel_spmf A f g; ⋀x y. ⟦ A x y; x ∈ set_spmf f; y ∈ set_spmf g ⟧ ⟹ B x y ⟧ ⟹ rel_spmf B f g"
by (metis elem_set in_set_spmf option.rel_mono_strong pmf.rel_mono_strong)

lemma rel_spmf_reflI: "(⋀x. x ∈ set_spmf p ⟹ P x x) ⟹ rel_spmf P p p"
by (metis (mono_tags, lifting) option.rel_eq pmf.rel_eq rel_spmf_mono_strong)

(* Introduction via a coupling pq: an spmf on pairs supported on P whose marginals are p and q.
   The coupling is turned into a rel_pmf coupling by mapping None to (None, None). *)
lemma rel_spmfI [intro?]:
"⟦ ⋀x y. (x, y) ∈ set_spmf pq ⟹ P x y; map_spmf fst pq = p; map_spmf snd pq = q ⟧
⟹ rel_spmf P p q"
by(rule rel_pmf.intros[where pq="map_pmf (λx. case x of None ⇒ (None, None) | Some (a, b) ⇒ (Some a, Some b)) pq"])
(auto simp: pmf.map_comp o_def in_set_spmf split: option.splits intro: pmf.map_cong)

(* Elimination: rel_spmf P p q yields a coupling pq as in rel_spmfI. The rel_pmf coupling is
   collapsed to an spmf on pairs, mapping every pair containing None to None. *)
lemma rel_spmfE [elim?, consumes 1, case_names rel_spmf]:
assumes "rel_spmf P p q"
obtains pq where
"⋀x y. (x, y) ∈ set_spmf pq ⟹ P x y"
"p = map_spmf fst pq"
"q = map_spmf snd pq"
using assms
proof(cases rule: rel_pmf.cases[consumes 1, case_names rel_pmf])
case (rel_pmf pq)
let ?pq = "map_pmf (λ(a, b). case (a, b) of (Some x, Some y) ⇒ Some (x, y) | _ ⇒ None) pq"
have "⋀x y. (x, y) ∈ set_spmf ?pq ⟹ P x y"
by(auto simp: in_set_spmf split: option.split_asm dest: rel_pmf(1))
moreover
have "⋀x. (x, None) ∈ set_pmf pq ⟹ x = None" by(auto dest!: rel_pmf(1))
then have "p = map_spmf fst ?pq" using rel_pmf(2)
by(auto simp: pmf.map_comp split_beta intro!: pmf.map_cong split: option.split)
moreover
have "⋀y. (None, y) ∈ set_pmf pq ⟹ y = None" by(auto dest!: rel_pmf(1))
then have "q = map_spmf snd ?pq" using rel_pmf(3)
by(auto simp: pmf.map_comp split_beta intro!: pmf.map_cong split: option.split)
ultimately show thesis ..
qed

(* Characterisation of rel_spmf by couplings. *)
lemma rel_spmf_simps:
"rel_spmf R p q ⟷ (∃pq. (∀(x, y)∈set_spmf pq. R x y) ∧ map_spmf fst pq = p ∧ map_spmf snd pq = q)"
by(auto intro: rel_spmfI elim!: rel_spmfE)

(* Mapping either argument of rel_spmf can be pushed into the relation. *)
lemma spmf_rel_map:
shows spmf_rel_map1: "⋀R f x. rel_spmf R (map_spmf f x) = rel_spmf (λx. R (f x)) x"
and spmf_rel_map2: "⋀R x g y. rel_spmf R x (map_spmf g y) = rel_spmf (λx y. R x (g y)) x y"
by(simp_all add: fun_eq_iff pmf.rel_map option.rel_map[abs_def])

(* rel_spmf commutes with converse. *)
lemma spmf_rel_conversep: "rel_spmf R¯¯ = (rel_spmf R)¯¯"
by(simp add: option.rel_conversep pmf.rel_conversep)

(* rel_spmf of equality is equality. *)
lemma spmf_rel_eq: "rel_spmf (=) = (=)"
by(simp add: pmf.rel_eq option.rel_eq)

context includes lifting_syntax
begin

(* Parametricity of the basic spmf operations w.r.t. rel_spmf; bind_spmf_parametric and
   set_spmf_parametric are registered as transfer rules, the others are not. *)
lemma bind_spmf_parametric [transfer_rule]:
"(rel_spmf A ===> (A ===> rel_spmf B) ===> rel_spmf B) bind_spmf bind_spmf"
unfolding bind_spmf_def[abs_def] by transfer_prover

lemma return_spmf_parametric: "(A ===> rel_spmf A) return_spmf return_spmf"
by transfer_prover

lemma map_spmf_parametric: "((A ===> B) ===> rel_spmf A ===> rel_spmf B) map_spmf map_spmf"
by transfer_prover

lemma rel_spmf_parametric:
"((A ===> B ===> (=)) ===> rel_spmf A ===> rel_spmf B ===> (=)) rel_spmf rel_spmf"
by transfer_prover

lemma set_spmf_parametric [transfer_rule]:
"(rel_spmf A ===> rel_set A) set_spmf set_spmf"
unfolding set_spmf_def[abs_def] by transfer_prover

lemma return_spmf_None_parametric:
"(rel_spmf A) (return_pmf None) (return_pmf None)"
by simp

end

(* Binds are related if the spmfs are related by R and R-related inputs give P-related continuations. *)
lemma rel_spmf_bindI:
"⟦ rel_spmf R p q; ⋀x y. R x y ⟹ rel_spmf P (f x) (g y) ⟧
⟹ rel_spmf P (p ⤜ f) (q ⤜ g)"
by(fact bind_spmf_parametric[THEN rel_funD, THEN rel_funD, OF _ rel_funI])

(* Special case of rel_spmf_bindI where both sides bind the same spmf p. *)
lemma rel_spmf_bind_reflI:
"(⋀x. x ∈ set_spmf p ⟹ rel_spmf P (f x) (g x)) ⟹ rel_spmf P (p ⤜ f) (p ⤜ g)"
by(rule rel_spmf_bindI[where R="λx y. x = y ∧ x ∈ set_spmf p"])(auto intro: rel_spmf_reflI)

lemma rel_pmf_return_pmfI: "P x y ⟹ rel_pmf P (return_pmf x) (return_pmf y)"
by simp

context includes lifting_syntax
begin

text ‹We do not yet have a relator for \<^typ>‹'a measure›, so we combine \<^const>‹measure› and \<^const>‹measure_pmf››
(* Related pmfs assign equal measure to rel_pred-related sets; proved via a coupling pq of p and q. *)
lemma measure_pmf_parametric:
"(rel_pmf A ===> rel_pred A ===> (=)) (λp. measure (measure_pmf p)) (λq. measure (measure_pmf q))"
proof(rule rel_funI)+
fix p q X Y
assume "rel_pmf A p q" and "rel_pred A X Y"
from this(1) obtain pq where A: "⋀x y. (x, y) ∈ set_pmf pq ⟹ A x y"
and p: "p = map_pmf fst pq" and q: "q = map_pmf snd pq" by cases auto
show "measure p X = measure q Y" unfolding p q measure_map_pmf
by(rule measure_pmf.finite_measure_eq_AE)(auto simp: AE_measure_pmf_iff dest!: A rel_predD[OF ‹rel_pred _ _ _›])
qed

(* spmf version, reduced to measure_pmf_parametric via the Some-image of the sets. *)
lemma measure_spmf_parametric:
"(rel_spmf A ===> rel_pred A ===> (=)) (λp. measure (measure_spmf p)) (λq. measure (measure_spmf q))"
proof -
have "⋀x y xa ya. rel_pred A xa ya ⟹ rel_pred (rel_option A) (Some ` xa) (Some ` ya)"
by(auto simp: rel_pred_def rel_fun_def elim: option.rel_cases)
then show ?thesis
unfolding measure_measure_spmf_conv_measure_pmf[abs_def]
by (intro rel_funI) (force elim!: measure_pmf_parametric[THEN rel_funD, THEN rel_funD])
qed

end

subsection ‹From \<^typ>‹'a pmf› to \<^typ>‹'a spmf››

(* Embed a pmf as an spmf that never yields None. *)
definition spmf_of_pmf :: "'a pmf ⇒ 'a spmf"
where "spmf_of_pmf = map_pmf Some"

lemma set_spmf_spmf_of_pmf [simp]: "set_spmf (spmf_of_pmf p) = set_pmf p"
by(auto simp: spmf_of_pmf_def set_spmf_def bind_image o_def)

lemma spmf_spmf_of_pmf [simp]: "spmf (spmf_of_pmf p) x = pmf p x"
by(simp add: spmf_of_pmf_def)

lemma pmf_spmf_of_pmf_None [simp]: "pmf (spmf_of_pmf p) None = 0"
using ennreal_pmf_map[of Some p None] by(simp add: spmf_of_pmf_def)

lemma emeasure_spmf_of_pmf [simp]: "emeasure (measure_spmf (spmf_of_pmf p)) A = emeasure (measure_pmf p) A"
by(simp add: spmf_of_pmf_def)

lemma measure_spmf_spmf_of_pmf [simp]: "measure_spmf (spmf_of_pmf p) = measure_pmf p"
by(rule measure_eqI) simp_all

(* spmf_of_pmf commutes with map, relator, return and bind. *)
lemma map_spmf_of_pmf [simp]: "map_spmf f (spmf_of_pmf p) = spmf_of_pmf (map_pmf f p)"
by(simp add: spmf_of_pmf_def pmf.map_comp o_def)

lemma rel_spmf_spmf_of_pmf [simp]: "rel_spmf R (spmf_of_pmf p) (spmf_of_pmf q) = rel_pmf R p q"
by(simp add: spmf_of_pmf_def pmf.rel_map)

lemma spmf_of_pmf_return_pmf [simp]: "spmf_of_pmf (return_pmf x) = return_spmf x"
by(simp add: spmf_of_pmf_def)

lemma bind_spmf_of_pmf [simp]: "bind_spmf (spmf_of_pmf p) f = bind_pmf p f"
by(simp add: spmf_of_pmf_def bind_spmf_def bind_map_pmf)

lemma set_spmf_bind_pmf: "set_spmf (bind_pmf p f) = Set.bind (set_pmf p) (set_spmf ∘ f)"
unfolding bind_spmf_of_pmf[symmetric] by(subst set_bind_spmf) simp

lemma spmf_of_pmf_bind: "spmf_of_pmf (bind_pmf p f) = bind_pmf p (λx. spmf_of_pmf (f x))"
by(simp add: spmf_of_pmf_def map_bind_pmf)

lemma bind_pmf_return_spmf: "p ⤜ (λx. return_spmf (f x)) = spmf_of_pmf (map_pmf f p)"
by(simp add: map_pmf_def spmf_of_pmf_bind)

subsection ‹Weight of a subprobability›

(* Total mass of an spmf; pmf_None_eq_weight_spmf shows the failure probability is 1 minus it. *)
abbreviation weight_spmf :: "'a spmf ⇒ real"
where "weight_spmf p ≡ measure (measure_spmf p) (space (measure_spmf p))"

lemma weight_spmf_def: "weight_spmf p = measure (measure_spmf p) UNIV"
by(simp add: space_measure_spmf)

lemma weight_spmf_le_1: "weight_spmf p ≤ 1"
by(rule measure_spmf.subprob_measure_le_1)

lemma weight_return_spmf [simp]: "weight_spmf (return_spmf x) = 1"
by(simp add: measure_spmf_return_spmf measure_return)

lemma weight_return_pmf_None [simp]: "weight_spmf (return_pmf None) = 0"
by(simp)

(* Mapping does not change the weight: the preimage of UNIV is UNIV. *)
lemma weight_map_spmf [simp]: "weight_spmf (map_spmf f p) = weight_spmf p"
by(simp add: weight_spmf_def measure_map_spmf)

lemma weight_spmf_of_pmf [simp]: "weight_spmf (spmf_of_pmf p) = 1"
by simp

lemma weight_spmf_nonneg: "weight_spmf p ≥ 0"
by(fact measure_nonneg)

(* Weights are bounded by 1, so a measurable weight function is integrable on a finite measure. *)
lemma (in finite_measure) integrable_weight_spmf [simp]:
"(λx. weight_spmf (f x)) ∈ borel_measurable M ⟹ integrable M (λx. weight_spmf (f x))"
by(rule integrable_const_bound[where B=1])(simp_all add: weight_spmf_nonneg weight_spmf_le_1)

(* Weight as the sum (nonnegative integral over count_space) of the point probabilities. *)
lemma weight_spmf_eq_nn_integral_spmf: "weight_spmf p = ∫⇧+ x. spmf p x ∂count_space UNIV"
by (metis NO_MATCH_def measure_spmf.emeasure_eq_measure nn_integral_count_space_indicator nn_integral_indicator nn_integral_measure_spmf sets_UNIV sets_measure_spmf space_measure_spmf)

(* Same, restricted to the support: points outside it carry no mass. *)
lemma weight_spmf_eq_nn_integral_support:
"weight_spmf p = ∫⇧+ x. spmf p x ∂count_space (set_spmf p)"
unfolding weight_spmf_eq_nn_integral_spmf
by(auto simp: nn_integral_count_space_indicator in_set_spmf_iff_spmf intro!: nn_integral_cong split: split_indicator)

(* The failure probability of an spmf is 1 minus its weight. The proof splits the total
   mass 1 of the underlying pmf into the Some part (= weight) and the None part. *)
lemma pmf_None_eq_weight_spmf: "pmf p None = 1 - weight_spmf p"
proof -
have "weight_spmf p = ∫⇧+ x. spmf p x ∂count_space UNIV" by(rule weight_spmf_eq_nn_integral_spmf)
also have "… = ∫⇧+ x. ennreal (pmf p x) * indicator (range Some) x ∂count_space UNIV"
by(simp add: nn_integral_count_space_indicator[symmetric] embed_measure_count_space[symmetric] nn_integral_embed_measure measurable_embed_measure1)
(* Add the None mass as an integral of an indicator on {None}. *)
also have "… + pmf p None = ∫⇧+ x. ennreal (pmf p x) * indicator (range Some) x + ennreal (pmf p None) * indicator {None} x ∂count_space UNIV"
by(subst nn_integral_add)(simp_all add: max_def)
also have "… = ∫⇧+ x. pmf p x ∂count_space UNIV"
by(rule nn_integral_cong)(auto split: split_indicator)
also have "… = 1" by (simp add: nn_integral_pmf)
finally show ?thesis by(simp add: ennreal_plus[symmetric] del: ennreal_plus)
qed

(* Weight expressed through the failure probability. *)
lemma weight_spmf_conv_pmf_None: "weight_spmf p = 1 - pmf p None"
by(simp add: pmf_None_eq_weight_spmf)

lemma weight_spmf_lt_0: "¬ weight_spmf p < 0"
by(simp add: not_less weight_spmf_nonneg)

(* A single point never carries more mass than the whole spmf. *)
lemma spmf_le_weight: "spmf p x ≤ weight_spmf p"
by (metis measure_spmf.bounded_measure spmf_conv_measure_spmf)

(* Only the always-failing spmf has weight 0. *)
lemma weight_spmf_eq_0: "weight_spmf p = 0 ⟷ p = return_pmf None"
by (metis measure_le_0_iff measure_spmf.bounded_measure spmf_conv_measure_spmf spmf_eqI weight_return_pmf_None)

(* The weight of a bind is the expected weight of the continuation. *)
lemma weight_bind_spmf: "weight_spmf (x ⤜ f) = lebesgue_integral (measure_spmf x) (weight_spmf ∘ f)"
unfolding weight_spmf_def
by(simp add: measure_spmf_bind o_def measure_spmf.measure_bind[where N="count_space UNIV"])

(* Related spmfs have equal weight. *)
lemma rel_spmf_weightD: "rel_spmf A p q ⟹ weight_spmf p = weight_spmf q"
by(erule rel_spmfE) simp

(* A bijection f between the supports that preserves point probabilities induces a coupling:
   p and q are related by the graph of f. The proof first shows equal weights, hence equal
   failure probabilities, and then applies rel_pmf_bij_betw to map_option f. *)
lemma rel_spmf_bij_betw:
assumes f: "bij_betw f (set_spmf p) (set_spmf q)"
and eq: "⋀x. x ∈ set_spmf p ⟹ spmf p x = spmf q (f x)"
shows "rel_spmf (λx y. f x = y) p q"
proof -
let ?f = "map_option f"

have weq: "ennreal (weight_spmf p) = ennreal (weight_spmf q)"
unfolding weight_spmf_eq_nn_integral_support
by(subst nn_integral_bij_count_space[OF f, symmetric])(rule nn_integral_cong_AE, simp add: eq AE_count_space)
(* Equal weights mean equal failure probabilities, so None is in both supports or in neither. *)
then have "None ∈ set_pmf p ⟷ None ∈ set_pmf q"
by(simp add: pmf_None_eq_weight_spmf set_pmf_iff)
with f have "bij_betw (map_option f) (set_pmf p) (set_pmf q)"
apply(auto simp: bij_betw_def in_set_spmf inj_on_def intro: option.expand split: option.split)
apply(rename_tac [!] x)
apply(case_tac [!] x)
apply(auto iff: in_set_spmf)
done
then have "rel_pmf (λx y. ?f x = y) p q"
proof (rule rel_pmf_bij_betw)
show "pmf p x = pmf q (map_option f x)" if "x ∈ set_pmf p" for x
proof (cases x)
case None
then show ?thesis
by (metis ennreal_inj measure_nonneg option.map_disc_iff pmf_None_eq_weight_spmf weq)
qed (use eq in_set_spmf that in force)
qed
thus ?thesis
by (smt (verit, ccfv_SIG) None_eq_map_option_iff option.map_sel option.rel_sel pmf.rel_mono_strong)
qed

subsection ‹From density to spmfs›

(* Turn a real-valued density f into an spmf: Some x' receives max 0 (f x') and None receives
   1 - enn2real of the total (nonnegative) integral of f. *)
context fixes f :: "'a ⇒ real" begin

definition embed_spmf :: "'a spmf"
where "embed_spmf = embed_pmf (λx. case x of None ⇒ 1 - enn2real (∫⇧+ x. ennreal (f x) ∂count_space UNIV) | Some x' ⇒ max 0 (f x'))"

(* The lemmas in this inner context assume the total mass of f is at most 1. *)
context
assumes prob: "(∫⇧+ x. ennreal (f x) ∂count_space UNIV) ≤ 1"
begin

(* The candidate mass function of embed_spmf integrates to 1; pmf_embed_spmf_None and
   spmf_embed_spmf below use this fact. *)
lemma nn_integral_embed_spmf_eq_1:
"(∫⇧+ x. ennreal (case x of None ⇒ 1 - enn2real (∫⇧+ x. ennreal (f x) ∂count_space UNIV) | Some x' ⇒ max 0 (f x')) ∂count_space UNIV) = 1"
(is "?lhs = _" is "(∫⇧+ x. ?f x ∂?M) = _")
proof -
have "?lhs = ∫⇧+ x. ?f x * indicator {None} x + ?f x * indicator (range Some) x ∂?M"
by(rule nn_integral_cong)(auto split: split_indicator)
also have "… = (1 - enn2real (∫⇧+ x. ennreal (f x) ∂count_space UNIV)) + ∫⇧+ x. ?f x * indicator (range Some) x ∂?M"
(is "_ = ?None + ?Some")
(* NOTE(review): no justification follows this step in this copy, so the proof does not
   check as is; presumably it splits the integral with nn_integral_add — TODO restore. *)
also have "?Some = ∫⇧+ x. ?f x ∂count_space (range Some)"
(* NOTE(review): justification missing here as well; presumably via
   nn_integral_count_space_indicator — TODO restore. *)
also have "count_space (range Some) = embed_measure (count_space UNIV) Some"
(* NOTE(review): justification missing; presumably via embed_measure_count_space — TODO restore. *)
also have "(∫⇧+ x. ?f x ∂…) = ∫⇧+ x. ennreal (f x) ∂count_space UNIV"
(* NOTE(review): justification missing; presumably via nn_integral_embed_measure — TODO restore. *)
also have "?None + … = 1" using prob
(* NOTE(review): the method closing this step is missing after `using prob` — TODO restore. *)
finally show ?thesis .
qed

(* Failure probability of embed_spmf: 1 minus the total mass of f. *)
lemma pmf_embed_spmf_None: "pmf embed_spmf None = 1 - enn2real (∫⇧+ x. ennreal (f x) ∂count_space UNIV)"
unfolding embed_spmf_def
by (smt (verit, del_insts) enn2real_leI ennreal_1 nn_integral_cong nn_integral_embed_spmf_eq_1
option.case_eq_if pmf_embed_pmf prob)

(* Point probabilities of embed_spmf: the density, truncated below at 0. *)
lemma spmf_embed_spmf [simp]: "spmf embed_spmf x = max 0 (f x)"
unfolding embed_spmf_def
by (smt (verit, best) enn2real_leI ennreal_1 nn_integral_cong nn_integral_embed_spmf_eq_1 option.case_eq_if option.simps(5) pmf_embed_pmf prob)

end

end

(* The zero density yields the always-failing spmf. The precondition of spmf_embed_spmf
   (total mass ≤ 1) is trivial here, so both sides have point probability 0 everywhere. *)
lemma embed_spmf_K_0[simp]: "embed_spmf (λ_. 0) = return_pmf None"
by(rule spmf_eqI) simp

subsection ‹Ordering on spmfs›

text ‹
\<^const>‹rel_pmf› does not preserve a ccpo structure. Counterexample by Saheb-Djahromi:
Take prefix order over ‹bool llist› and
the set ‹range (λn :: nat. uniform (llist_n n))› where ‹llist_n› is the set
of all ‹llist›s of length ‹n› and ‹uniform› returns a uniform distribution over
the given set. The set forms a chain in ‹ord_pmf lprefix›, but it has no upper bound.
Any upper bound may contain only infinite lists in its support because otherwise it is not greater
than the ‹n+1›-st element in the chain where ‹n› is the length of the finite list.
Moreover its support must contain all infinite lists, because otherwise there is a finite list
all of whose finite extensions are not in the support - a contradiction to the upper bound property.
Hence, the support is uncountable, but pmf's only have countable support.

However, if all chains in the ccpo are finite, then it should preserve the ccpo structure.
›

(* Order on spmfs induced by an order on elements: rel_pmf of the flat option order,
   in which None is below everything. *)
abbreviation ord_spmf :: "('a ⇒ 'a ⇒ bool) ⇒ 'a spmf ⇒ 'a spmf ⇒ bool"
where "ord_spmf ord ≡ rel_pmf (ord_option ord)"

(* Infix syntax ⊑ for ord_spmf, with the element order as implicit structure argument. *)
locale ord_spmf_syntax begin
notation ord_spmf (infix "⊑ı" 60)
end

(* Mapping the left argument moves the map into the element order. *)
lemma ord_spmf_map_spmf1: "ord_spmf R (map_spmf f p) = ord_spmf (λx. R (f x)) p"
by(simp add: pmf.rel_map[abs_def] ord_option_map1[abs_def])

(* Mapping the right argument moves the map into the element order. *)
lemma ord_spmf_map_spmf2: "ord_spmf R p (map_spmf f q) = ord_spmf (λx y. R x (f y)) p q"
by(simp add: pmf.rel_map ord_option_map2)

(* Mapping both arguments with the same function. *)
lemma ord_spmf_map_spmf12: "ord_spmf R (map_spmf f p) (map_spmf f q) = ord_spmf (λx y. R (f x) (f y)) p q"
by(simp add: ord_spmf_map_spmf1 ord_spmf_map_spmf2)

lemmas ord_spmf_map_spmf = ord_spmf_map_spmf1 ord_spmf_map_spmf2 ord_spmf_map_spmf12

(* In this context ⊑ stands for ord_spmf ord (via the ord_spmf_syntax interpretation). *)
context fixes ord :: "'a ⇒ 'a ⇒ bool" (structure) begin
interpretation ord_spmf_syntax .

(* Introduce p ⊑ q from a coupling pq of spmfs whose support satisfies ord. The proof lifts
   pq to a pmf on option pairs, sending None to (None, None). *)
lemma ord_spmfI:
"⟦ ⋀x y. (x, y) ∈ set_spmf pq ⟹ ord x y; map_spmf fst pq = p; map_spmf snd pq = q ⟧
⟹ p ⊑ q"
by(rule rel_pmf.intros[where pq="map_pmf (λx. case x of None ⇒ (None, None) | Some (a, b) ⇒ (Some a, Some b)) pq"])
(auto simp: pmf.map_comp o_def in_set_spmf split: option.splits intro: pmf.map_cong)

(* The always-failing spmf is the least element; the coupling pairs None with x. *)
lemma ord_spmf_None [simp]: "return_pmf None ⊑ x"
by(rule rel_pmf.intros[where pq="map_pmf (Pair None) x"])(auto simp: pmf.map_comp o_def)

(* Reflexivity, requiring ord to be reflexive only on the support of p. *)
lemma ord_spmf_reflI: "(⋀x. x ∈ set_spmf p ⟹ ord x x) ⟹ p ⊑ p"
by (metis elem_set in_set_spmf ord_option_reflI pmf.rel_refl_strong)

(* For a preorder ord, mutual ⊑ gives a coupling by the symmetric part of ord. *)
lemma rel_spmf_inf:
assumes "p ⊑ q"
and "q ⊑ p"
and refl: "reflp ord"
and trans: "transp ord"
shows "rel_spmf (inf ord ord¯¯) p q"
proof -
from ‹p ⊑ q› ‹q ⊑ p›
have "rel_pmf (inf (ord_option ord) (ord_option ord)¯¯) p q"
using local.refl local.trans reflp_ord_option rel_pmf_inf transp_ord_option by blast
(* The symmetric part of the flat option order is rel_option of the symmetric part of ord. *)
also have "inf (ord_option ord) (ord_option ord)¯¯ = rel_option (inf ord ord¯¯)"
by(auto simp: fun_eq_iff elim: ord_option.cases option.rel_cases)
finally show ?thesis .
qed

end

(* Being below a Dirac spmf means every point of the support is R-below its point. *)
lemma ord_spmf_return_spmf2: "ord_spmf R p (return_spmf y) ⟷ (∀x∈set_spmf p. R x y)"
by(auto simp: rel_pmf_return_pmf2 in_set_spmf ord_option.simps intro: ccontr)

(* ord_spmf is monotone in the element order. *)
lemma ord_spmf_mono: "⟦ ord_spmf A p q; ⋀x y. A x y ⟹ B x y ⟧ ⟹ ord_spmf B p q"
by(erule pmf.rel_mono_strong)(erule ord_option_mono)

(* ord_spmf distributes over relation composition (via the option and pmf relators). *)
lemma ord_spmf_compp: "ord_spmf (A OO B) = ord_spmf A OO ord_spmf B"
by(simp add: ord_option_compp pmf.rel_compp)

(* Monotonicity of bind: related arguments and pointwise related continuations give
   related results. *)
lemma ord_spmf_bindI:
assumes pq: "ord_spmf R p q"
and fg: "⋀x y. R x y ⟹ ord_spmf P (f x) (g y)"
shows "ord_spmf P (p ⤜ f) (q ⤜ g)"
(* Unfold bind_spmf to bind_pmf with a case split on None / Some. *)
unfolding bind_spmf_def using pq
by(rule rel_pmf_bindI)(auto split: option.split intro: fg)

(* Special case: the same first argument on both sides. *)
lemma ord_spmf_bind_reflI:
"(⋀x. x ∈ set_spmf p ⟹ ord_spmf R (f x) (g x)) ⟹ ord_spmf R (p ⤜ f) (p ⤜ g)"
by(rule ord_spmf_bindI[where R="λx y. x = y ∧ x ∈ set_spmf p"])(auto intro: ord_spmf_reflI)

(* If p is pointwise below q and R is reflexive on the support of p, then ord_spmf R p q.
   The proof builds an explicit coupling pq with embed_pmf: (Some x', Some x') carries
   spmf p x', (None, Some y') carries the surplus spmf q y' - spmf p y', (None, None)
   carries pmf q None, and all other pairs get 0. *)
lemma ord_pmf_increaseI:
assumes le: "⋀x. spmf p x ≤ spmf q x"
and refl: "⋀x. x ∈ set_spmf p ⟹ R x x"
shows "ord_spmf R p q"
proof(rule rel_pmf.intros)
define pq where "pq = embed_pmf
(λ(x, y). case x of Some x' ⇒ (case y of Some y' ⇒ if x' = y' then spmf p x' else 0 | None ⇒ 0)
| None ⇒ (case y of None ⇒ pmf q None | Some y' ⇒ spmf q y' - spmf p y'))"
(is "_ = embed_pmf ?f")
(* The candidate mass function is nonnegative; le is needed for the surplus entries. *)
have nonneg: "⋀xy. ?f xy ≥ 0"
by(clarsimp simp add: le field_simps split: option.split)
(* Total mass 1: split the pairs into (None, None), (None, Some _) and (Some x, Some x). *)
have integral: "(∫⇧+ xy. ?f xy ∂count_space UNIV) = 1" (is "nn_integral ?M _ = _")
proof -
have "(∫⇧+ xy. ?f xy ∂count_space UNIV) =
∫⇧+ xy. ennreal (?f xy) * indicator {(None, None)} xy +
ennreal (?f xy) * indicator (range (λx. (None, Some x))) xy +
ennreal (?f xy) * indicator (range (λx. (Some x, Some x))) xy ∂?M"
by(rule nn_integral_cong)(auto split: split_indicator option.splits if_split_asm)
also have "… = (∫⇧+ xy. ?f xy * indicator {(None, None)} xy ∂?M) +
(∫⇧+ xy. ennreal (?f xy) * indicator (range (λx. (None, Some x))) xy ∂?M) +
(∫⇧+ xy. ennreal (?f xy) * indicator (range (λx. (Some x, Some x))) xy ∂?M)"
(is "_ = ?None + ?Some2 + ?Some")
(* NOTE(review): no justification follows this step in this copy, so the proof does not
   check as is; presumably it splits the integral with nn_integral_add — TODO restore. *)
also have "?None = pmf q None" by simp
also have "?Some2 = ∫⇧+ x. ennreal (spmf q x) - spmf p x ∂count_space UNIV"
by(simp add: nn_integral_count_space_indicator[symmetric] embed_measure_count_space[symmetric] inj_on_def nn_integral_embed_measure measurable_embed_measure1 ennreal_minus)
also have "… = (∫⇧+ x. spmf q x ∂count_space UNIV) - (∫⇧+ x. spmf p x ∂count_space UNIV)"
(is "_ = ?Some2' - ?Some2''")
(* NOTE(review): justification missing; presumably nn_integral_diff as used later in this
   proof — TODO restore. *)
also have "?Some = ∫⇧+ x. spmf p x ∂count_space UNIV"
by(simp add: nn_integral_count_space_indicator[symmetric] embed_measure_count_space[symmetric] inj_on_def nn_integral_embed_measure measurable_embed_measure1)
also have "pmf q None + (?Some2' - ?Some2'') + … = pmf q None + ?Some2'"
by(auto simp: diff_add_self_ennreal le intro!: nn_integral_mono)
also have "… = ∫⇧+ x. ennreal (pmf q x) * indicator {None} x + ennreal (pmf q x) * indicator (range Some) x ∂count_space UNIV"
(* NOTE(review): justification missing; presumably via nn_integral_add — TODO restore. *)
also have "… = ∫⇧+ x. pmf q x ∂count_space UNIV"
by(rule nn_integral_cong)(auto split: split_indicator)
also have "… = 1"
(* NOTE(review): justification missing; presumably nn_integral_pmf — TODO restore. *)
finally show ?thesis .
qed
note f = nonneg integral

(* Support condition: every pair in the support of pq is ordered by ord_option R. *)
{ fix x y
assume "(x, y) ∈ set_pmf pq"
hence "?f (x, y) ≠ 0" unfolding pq_def by(simp add: set_embed_pmf[OF f])
then show "ord_option R x y"
by(simp add: spmf_eq_0_set_spmf refl split: option.split_asm if_split_asm) }

have weight_le: "weight_spmf p ≤ weight_spmf q"
by(subst ennreal_le_iff[symmetric])(auto simp: weight_spmf_eq_nn_integral_spmf intro!: nn_integral_mono le)

(* First marginal of pq is p. *)
show "map_pmf fst pq = p"
proof(rule pmf_eqI)
fix i :: "'a option"
have bi: "bij_betw (Pair i) UNIV (fst -` {i})"
by(auto simp: bij_betw_def inj_on_def)
have "ennreal (pmf (map_pmf fst pq) i) = (∫⇧+ y. pmf pq (i, y) ∂count_space UNIV)"
unfolding pq_def ennreal_pmf_map
apply (simp add: embed_pmf.rep_eq[OF f] o_def emeasure_density flip: nn_integral_count_space_indicator)
by (smt (verit, best) nn_integral_bij_count_space [OF bi] integral nn_integral_cong nonneg pmf_embed_pmf)
also have "… = pmf p i"
proof(cases i)
case (Some x)
have "(∫⇧+ y. pmf pq (Some x, y) ∂count_space UNIV) = ∫⇧+ y. pmf p (Some x) * indicator {Some x} y ∂count_space UNIV"
by(rule nn_integral_cong)(simp add: pq_def pmf_embed_pmf[OF f] split: option.split)
then show ?thesis using Some by simp
next
case None
have "(∫⇧+ y. pmf pq (None, y) ∂count_space UNIV) =
(∫⇧+ y. ennreal (pmf pq (None, Some (the y))) * indicator (range Some) y +
ennreal (pmf pq (None, None)) * indicator {None} y ∂count_space UNIV)"
by(rule nn_integral_cong)(auto split: split_indicator)
also have "… = (∫⇧+ y. ennreal (pmf pq (None, Some (the y))) ∂count_space (range Some)) + pmf pq (None, None)"
(* NOTE(review): justification missing; presumably nn_integral_add together with
   nn_integral_count_space_indicator — TODO restore. *)
also have "… = (∫⇧+ y. ennreal (spmf q y) - ennreal (spmf p y) ∂count_space UNIV) + pmf q None"
by(simp add: pq_def pmf_embed_pmf[OF f] embed_measure_count_space[symmetric] nn_integral_embed_measure measurable_embed_measure1 ennreal_minus)
also have "(∫⇧+ y. ennreal (spmf q y) - ennreal (spmf p y) ∂count_space UNIV) =
(∫⇧+ y. spmf q y ∂count_space UNIV) - (∫⇧+ y. spmf p y ∂count_space UNIV)"
by(subst nn_integral_diff)(simp_all add: AE_count_space le nn_integral_spmf_neq_top split: split_indicator)
also have "… = pmf p None - pmf q None"
(* NOTE(review): justification missing; presumably relates the integrals to the weights via
   weight_spmf_eq_nn_integral_spmf and pmf_None_eq_weight_spmf — TODO restore. *)
also have "… = ennreal (pmf p None) - ennreal (pmf q None)" by(simp add: ennreal_minus)
finally show ?thesis using None weight_le
by(auto simp: diff_add_self_ennreal pmf_None_eq_weight_spmf intro: ennreal_leI)
qed
finally show "pmf (map_pmf fst pq) i = pmf p i" by simp
qed

(* Second marginal of pq is q. *)
show "map_pmf snd pq = q"
proof(rule pmf_eqI)
fix i :: "'a option"
have bi: "bij_betw (λx. (x, i)) UNIV (snd -` {i})"
by (auto simp: bij_betw_def inj_on_def)
have "ennreal (pmf (map_pmf snd pq) i) = (∫⇧+ x. pmf pq (x, i) ∂count_space UNIV)"
unfolding pq_def ennreal_pmf_map
apply(simp add: embed_pmf.rep_eq[OF f] o_def emeasure_density nn_integral_count_space_indicator[symmetric])
by (smt (verit, best) nn_integral_bij_count_space [OF bi] integral nn_integral_cong nonneg pmf_embed_pmf)
also have "… = ennreal (pmf q i)"
proof(cases i)
case None
have "(∫⇧+ x. pmf pq (x, None) ∂count_space UNIV) = ∫⇧+ x. pmf q None * indicator {None :: 'a option} x ∂count_space UNIV"
by(rule nn_integral_cong)(simp add: pq_def pmf_embed_pmf[OF f] split: option.split)
then show ?thesis using None by simp
next
case (Some y)
have "(∫⇧+ x. pmf pq (x, Some y) ∂count_space UNIV) =
(∫⇧+ x. ennreal (pmf pq (x, Some y)) * indicator (range Some) x +
ennreal (pmf pq (None, Some y)) * indicator {None} x ∂count_space UNIV)"
by(rule nn_integral_cong)(auto split: split_indicator)
also have "… = (∫⇧+ x. ennreal (pmf pq (x, Some y)) * indicator (range Some) x ∂count_space UNIV) + pmf pq (None, Some y)"
(* NOTE(review): justification missing; presumably nn_integral_add — TODO restore. *)
also have "… = (∫⇧+ x. ennreal (spmf p y) * indicator {Some y} x ∂count_space UNIV) + (spmf q y - spmf p y)"
by(auto simp: pq_def pmf_embed_pmf[OF f] one_ereal_def[symmetric] simp del: nn_integral_indicator_singleton intro!: arg_cong2[where f="(+)"] nn_integral_cong split: option.split)
also have "… = spmf q y" by(simp add: ennreal_minus[symmetric] le)
finally show ?thesis using Some by simp
qed
finally show "pmf (map_pmf snd pq) i = pmf q i" by simp
qed
qed

(* Under ord_spmf (=), point probabilities can only increase. *)
lemma ord_spmf_eq_leD:
assumes "ord_spmf (=) p q"
shows "spmf p x ≤ spmf q x"
proof(cases "x ∈ set_spmf p")
case False
(* Outside the support spmf p x is 0, and spmf q x is nonnegative. *)
then show ?thesis by(simp add: in_set_spmf_iff_spmf)
next
case True
(* Obtain a coupling pq of p and q supported on ord_option (=). *)
from assms obtain pq
where pq: "⋀x y. (x, y) ∈ set_pmf pq ⟹ ord_option (=) x y"
and p: "p = map_pmf fst pq"
and q: "q = map_pmf snd pq" by cases auto
(* Express spmf p x through the first marginal of pq (symmetric to the q step below). *)
have "ennreal (spmf p x) = integral⇧N pq (indicator (fst -` {Some x}))"
using p by(simp add: ennreal_pmf_map)
also have "… = integral⇧N pq (indicator {(Some x, Some x)})"
by(rule nn_integral_cong_AE)(auto simp: AE_measure_pmf_iff split: split_indicator dest: pq)
also have "… ≤ integral⇧N pq (indicator (snd -` {Some x}))"
by(rule nn_integral_mono) simp
also have "… = ennreal (spmf q x)" using q by(simp add: ennreal_pmf_map)
finally show ?thesis by simp
qed

(* Under ord_spmf (=) the support can only grow. *)
lemma ord_spmf_eqD_set_spmf: "ord_spmf (=) p q ⟹ set_spmf p ⊆ set_spmf q"
by (metis ord_spmf_eq_leD pmf_le_0_iff spmf_eq_0_set_spmf subsetI)

(* ... and so can the measure of every event (pointwise comparison under the integral). *)
lemma ord_spmf_eqD_emeasure:
"ord_spmf (=) p q ⟹ emeasure (measure_spmf p) A ≤ emeasure (measure_spmf q) A"
by(auto intro!: nn_integral_mono split: split_indicator dest: ord_spmf_eq_leD simp add: nn_integral_measure_spmf nn_integral_indicator[symmetric])

(* Hence measure_spmf is monotone w.r.t. ord_spmf (=). *)
lemma ord_spmf_eqD_measure_spmf: "ord_spmf (=) p q ⟹ measure_spmf p ≤ measure_spmf q"
by (subst le_measure) (auto simp: ord_spmf_eqD_emeasure)

subsection ‹CCPO structure for the flat ccpo \<^term>‹ord_option (=)››

context fixes Y :: "'a spmf set" begin

(* Candidate least upper bound of Y: the density is the pointwise supremum of spmf p x over
   p ∈ Y, taken in ennreal and converted back with enn2real, then embedded via embed_spmf. *)
definition lub_spmf :: "'a spmf"
where "lub_spmf = embed_spmf (λx. enn2real (SUP p ∈ Y. ennreal (spmf p x)))"
― ‹We go through \<^typ>‹ennreal› to have a sensible definition even if \<^term>‹Y› is empty.›

(* The lub of the empty set is the always-failing spmf. *)
lemma lub_spmf_empty [simp]: "SPMF.lub_spmf {} = return_pmf None"
(* NOTE(review): no proof follows this statement in this copy, so the theory does not check
   as is; presumably it unfolds the global lub_spmf definition and uses embed_spmf_K_0 —
   TODO restore. *)

context assumes chain: "Complete_Partial_Order.chain (ord_spmf (=)) Y"
begin

(* Along a chain Y in ord_spmf (=), the ennreal-valued point-probability functions form a
   chain under the pointwise order (via ord_spmf_eq_leD). *)
lemma chain_ord_spmf_eqD: "Complete_Partial_Order.chain (≤) ((λp x. ennreal (spmf p x)) ` Y)"
(is "Complete_Partial_Order.chain _ (?f ` _)")
proof(rule chainI)
fix f g
assume "f ∈ ?f ` Y" "g ∈ ?f ` Y"
then obtain p q where f: "f = ?f p" "p ∈ Y" and g: "g = ?f q" "q ∈ Y" by blast
from chain ‹p ∈ Y› ‹q ∈ Y› have "ord_spmf (=) p q ∨ ord_spmf (=) q p" by(rule chainD)
thus "f ≤ g ∨ g ≤ f"
by (metis ennreal_leI f(1) g(1) le_funI ord_spmf_eq_leD)
qed

(* Under ord_spmf (=), equal failure probabilities force equality: the integral of the
   pointwise difference spmf q - spmf p is then 0, so spmf q ≤ spmf p everywhere. *)
lemma ord_spmf_eq_pmf_None_eq:
assumes le: "ord_spmf (=) p q"
and None: "pmf p None = pmf q None"
shows "p = q"
proof(rule spmf_eqI)
fix i
from le have le': "⋀x. spmf p x ≤ spmf q x" by(rule ord_spmf_eq_leD)
have "(∫⇧+ x. ennreal (spmf q x) - spmf p x ∂count_space UNIV) =
(∫⇧+ x. spmf q x ∂count_space UNIV) - (∫⇧+ x. spmf p x ∂count_space UNIV)"
by(subst nn_integral_diff)(simp_all add: AE_count_space le' nn_integral_spmf_neq_top)
also have "… = (1 - pmf q None) - (1 - pmf p None)" unfolding pmf_None_eq_weight_spmf
(* NOTE(review): the method closing this step is missing after the unfolding, so the proof
   does not check as is; presumably a simp call using weight_spmf_eq_nn_integral_spmf and
   ennreal_minus — TODO restore. *)
also have "… = 0" using None by simp
finally have "⋀x. spmf q x ≤ spmf p x"
by(simp add: nn_integral_0_iff_AE AE_count_space ennreal_minus ennreal_eq_0_iff)
with le' show "spmf p i = spmf q i" by(rule antisym)
qed

(* Going up in ord_spmf (=) can only decrease the failure probability. *)
lemma ord_spmf_eqD_pmf_None:
assumes "ord_spmf (=) x y"
shows "pmf x None ≥ pmf y None"
using assms
(* Destruct the coupling, move to ennreal and compare the marginals almost everywhere. *)
apply cases
apply(clarsimp simp only: ennreal_le_iff[symmetric, OF pmf_nonneg] ennreal_pmf_map)
apply(fastforce simp: AE_measure_pmf_iff intro!: nn_integral_mono_AE)
done

text ‹
Chains on \<^typ>‹'a spmf› maintain countable support.
Thanks to Johannes Hölzl for the proof idea.
›
(* The union of the supports of a chain Y is countable. If Y has a greatest element, the union
   lies in its support. Otherwise pick a sequence X in Y whose failure probabilities converge
   to Inf (N ` Y): every x ∈ Y lies below some X i, so the union is contained in the
   countable union of the supports of the X n. *)
lemma spmf_chain_countable: "countable (⋃p∈Y. set_spmf p)"
proof(cases "Y = {}")
case Y: False
show ?thesis
proof(cases "∃x∈Y. ∀y∈Y. ord_spmf (=) y x")
case True
then obtain x where x: "x ∈ Y" and upper: "⋀y. y ∈ Y ⟹ ord_spmf (=) y x" by blast
hence "(⋃x∈Y. set_spmf x) ⊆ set_spmf x" by(auto dest: ord_spmf_eqD_set_spmf)
thus ?thesis by(rule countable_subset) simp
next
case False
(* N p is the failure probability of p. On Y a strictly smaller N means higher in the chain
   (N_less_imp_le_spmf), and equal N means equal elements (N_eq_imp_eq). *)
define N :: "'a option pmf ⇒ real" where "N p = pmf p None" for p

have N_less_imp_le_spmf: "⟦ x ∈ Y; y ∈ Y; N y < N x ⟧ ⟹ ord_spmf (=) x y" for x y
using chainD[OF chain, of x y] ord_spmf_eqD_pmf_None[of x y] ord_spmf_eqD_pmf_None[of y x]
by (auto simp: N_def)
have N_eq_imp_eq: "⟦ x ∈ Y; y ∈ Y; N y = N x ⟧ ⟹ x = y" for x y
using chainD[OF chain, of x y] by(auto simp: N_def dest: ord_spmf_eq_pmf_None_eq)

have NC: "N ` Y ≠ {}" "bdd_below (N ` Y)"
using ‹Y ≠ {}› by(auto intro!: bdd_belowI[of _ 0] simp: N_def)
(* Since Y has no greatest element, the infimum of N on Y is never attained. *)
have NC_less: "Inf (N ` Y) < N x" if "x ∈ Y" for x unfolding cInf_less_iff[OF NC]
proof(rule ccontr)
assume **: "¬ (∃y∈N ` Y. y < N x)"
{ fix y
assume "y ∈ Y"
with ** consider "N x < N y" | "N x = N y" by(auto simp: not_less le_less)
hence "ord_spmf (=) y x" using ‹y ∈ Y› ‹x ∈ Y›
by cases(auto dest: N_less_imp_le_spmf N_eq_imp_eq intro: ord_spmf_reflI) }
with False ‹x ∈ Y› show False by blast
qed

(* Approximate Inf (N ` Y) by a sequence of values N (X n) with X n ∈ Y. *)
from NC have "Inf (N ` Y) ∈ closure (N ` Y)" by (intro closure_contains_Inf)
then obtain X' where "⋀n. X' n ∈ N ` Y" and X': "X' ⇢ Inf (N ` Y)"
unfolding closure_sequential by auto
then obtain X where X: "⋀n. X n ∈ Y" and "X' = (λn. N (X n))" unfolding image_iff Bex_def by metis

with X' have seq: "(λn. N (X n)) ⇢ Inf (N ` Y)" by simp
(* Every element of Y is below some X i, so its support is covered. *)
have "(⋃x ∈ Y. set_spmf x) ⊆ (⋃n. set_spmf (X n))"
proof(rule UN_least)
fix x
assume "x ∈ Y"
from order_tendstoD(2)[OF seq NC_less[OF ‹x ∈ Y›]]
obtain i where "N (X i) < N x" by (auto simp: eventually_sequentially)
thus "set_spmf x ⊆ (⋃n. set_spmf (X n))" using X ‹x ∈ Y›
by(blast dest: N_less_imp_le_spmf ord_spmf_eqD_set_spmf)
qed
thus ?thesis by(rule countable_subset) simp
qed
qed simp

lemma lub_spmf_subprob: "(∫⇧+ x. (SUP p ∈ Y. ennreal (spmf p x)) ∂count_space UNIV) ≤ 1"
proof(cases "Y = {}")
case True