(* Theory Infinite_Product_Measure *)

(*  Title:      HOL/Probability/Infinite_Product_Measure.thy
    Author:     Johannes Hölzl, TU München
*)

section ‹Infinite Product Measure›

theory Infinite_Product_Measure
  imports Probability_Measure Projective_Family
begin

(* Restricting a point of the product space over I to a finite sub-index-set J
   pushes the product measure PiM I M forward to the product measure PiM J M. *)
lemma (in product_prob_space) distr_PiM_restrict_finite:
  assumes "finite J" "J ⊆ I"
  shows "distr (PiM I M) (PiM J M) (λx. restrict x J) = PiM J M"
proof (rule PiM_eqI)
  fix X assume X: "⋀i. i ∈ J ⟹ X i ∈ sets (M i)"
  (* Cylinder formula under the side condition "J ≠ {} ∨ I = {}"; the remaining
     case (J empty, I non-empty) is reduced to this one in the case split below. *)
  { fix J X assume J: "J ≠ {} ∨ I = {}" "finite J" "J ⊆ I" and X: "⋀i. i ∈ J ⟹ X i ∈ sets (M i)"
    have "emeasure (PiM I M) (emb I J (PiE J X)) = (∏i∈J. M i (X i))"
    proof (subst emeasure_extend_measure_Pair[OF PiM_def, where μ'=lim], goal_cases)
      case 1 then show ?case
        by (simp add: M.emeasure_space_1 emeasure_PiM Pi_iff sets_PiM_I_finite emeasure_lim_emb)
    next
      case (2 J X)
      then have "emb I J (PiE J X) ∈ sets (PiM I M)"
        by (intro measurable_prod_emb sets_PiM_I_finite) auto
      from this[THEN sets.sets_into_space] show ?case
        by (simp add: space_PiM)
    qed (insert assms J X, simp_all del: sets_lim
      add: M.emeasure_space_1 sets_lim[symmetric] emeasure_countably_additive emeasure_positive) }
  note * = this

  have "emeasure (PiM I M) (emb I J (PiE J X)) = (∏i∈J. M i (X i))"
  proof (cases "J ≠ {} ∨ I = {}")
    case False
    (* J is empty but I is not: rewrite the empty-index cylinder as a cylinder
       over a single index i carrying the whole space of M i. *)
    then obtain i where i: "J = {}" "i ∈ I" by auto
    then have "emb I {} {λx. undefined} = emb I {i} (Π⇩E i∈{i}. space (M i))"
      by (auto simp: space_PiM prod_emb_def)
    with i show ?thesis
      by (simp add: * M.emeasure_space_1)
  next
    case True
    then show ?thesis
      by (simp add: *[OF _ assms X])
  qed
  with assms show "emeasure (distr (PiM I M) (PiM J M) (λx. restrict x J)) (PiE J X) = (∏i∈J. emeasure (M i) (X i))"
    by (subst emeasure_distr_restrict[OF _ refl]) (auto intro!: sets_PiM_I_finite X)
qed (insert assms, auto)

(* The measure of the embedding of a measurable set X of the finite sub-product
   over J into the product over I equals the measure of X under PiM J M. *)
lemma (in product_prob_space) emeasure_PiM_emb':
  "J ⊆ I ⟹ finite J ⟹ X ∈ sets (PiM J M) ⟹ emeasure (PiM I M) (emb I J X) = PiM J M X"
  by (subst distr_PiM_restrict_finite[symmetric, of J])
     (auto intro!: emeasure_distr_restrict[symmetric])

(* Product formula for cylinder sets: the measure of the embedded finite box
   PiE J X is the product over J of the component measures of X i. *)
lemma (in product_prob_space) emeasure_PiM_emb:
  "J ⊆ I ⟹ finite J ⟹ (⋀i. i ∈ J ⟹ X i ∈ sets (M i)) ⟹
    emeasure (PiM I M) (emb I J (PiE J X)) = (∏ i∈J. emeasure (M i) (X i))"
  by (subst emeasure_PiM_emb') (auto intro!: emeasure_PiM)

(* The product measure PiM I M is itself a probability space; its facts become
   available under the optional qualifier P. *)
sublocale product_prob_space ⊆ P?: prob_space "PiM I M"
proof
  (* The whole space is the cylinder over the empty index set. *)
  have *: "emb I {} {λx. undefined} = space (PiM I M)"
    by (auto simp: prod_emb_def space_PiM)
  show "emeasure (PiM I M) (space (PiM I M)) = 1"
    using emeasure_PiM_emb[of "{}" "λ_. {}"] by (simp add: *)
qed

(* Locale-free version: if every component M i with i in I is a probability
   space, then so is PiM I M. *)
lemma prob_space_PiM:
  assumes M: "⋀i. i ∈ I ⟹ prob_space (M i)" shows "prob_space (PiM I M)"
proof -
  (* Outside I the family is replaced by a one-point counting measure so that
     every component is a probability space and the locale can be interpreted. *)
  let ?M = "λi. if i ∈ I then M i else count_space {undefined}"
  interpret M': prob_space "?M i" for i
    using M by (cases "i ∈ I") (auto intro!: prob_spaceI)
  interpret product_prob_space ?M I
    by unfold_locales
  have "prob_space (Π⇩M i∈I. ?M i)"
    by unfold_locales
  also have "(Π⇩M i∈I. ?M i) = (Π⇩M i∈I. M i)"
    by (intro PiM_cong) auto
  finally show ?thesis .
qed

(* Set-comprehension form of the cylinder formula: the set of points whose
   coordinates in J lie in the given X i has measure the product of the
   component measures. *)
lemma (in product_prob_space) emeasure_PiM_Collect:
  assumes X: "J ⊆ I" "finite J" "⋀i. i ∈ J ⟹ X i ∈ sets (M i)"
  shows "emeasure (PiM I M) {x∈space (PiM I M). ∀i∈J. x i ∈ X i} = (∏ i∈J. emeasure (M i) (X i))"
proof -
  have "{x∈space (PiM I M). ∀i∈J. x i ∈ X i} = emb I J (PiE J X)"
    unfolding prod_emb_def using assms by (auto simp: space_PiM Pi_iff)
  with emeasure_PiM_emb[OF assms] show ?thesis by simp
qed

(* Single-coordinate instance of emeasure_PiM_Collect (index set {i}). *)
lemma (in product_prob_space) emeasure_PiM_Collect_single:
  assumes X: "i ∈ I" "A ∈ sets (M i)"
  shows "emeasure (PiM I M) {x∈space (PiM I M). x i ∈ A} = emeasure (M i) A"
  using emeasure_PiM_Collect[of "{i}" "λi. A"] assms
  by simp

(* Real-valued (measure) version of the cylinder formula emeasure_PiM_emb. *)
lemma (in product_prob_space) measure_PiM_emb:
  assumes "J ⊆ I" "finite J" "⋀i. i ∈ J ⟹ X i ∈ sets (M i)"
  shows "measure (PiM I M) (emb I J (PiE J X)) = (∏ i∈J. measure (M i) (X i))"
  using emeasure_PiM_emb[OF assms]
  unfolding emeasure_eq_measure M.emeasure_eq_measure
  by (simp add: prod_ennreal measure_nonneg prod_nonneg)

(* Predicate form of sets_Collect_single: a measurable predicate on component i
   yields a measurable set of the product when applied to coordinate i. *)
lemma sets_Collect_single':
  "i ∈ I ⟹ {x∈space (M i). P x} ∈ sets (M i) ⟹ {x∈space (PiM I M). P (x i)} ∈ sets (PiM I M)"
  using sets_Collect_single[of i I "{x∈space (M i). P x}" M]
  by (simp add: space_PiM PiE_iff cong: conj_cong)

(* For a finite index set the measure of a box PiE I A is the product of the
   component measures; obtained from measure_PiM_emb with J = I. *)
lemma (in finite_product_prob_space) finite_measure_PiM_emb:
  "(⋀i. i ∈ I ⟹ A i ∈ sets (M i)) ⟹ measure (PiM I M) (PiE I A) = (∏i∈I. measure (M i) (A i))"
  using measure_PiM_emb[of I A] finite_index prod_emb_PiE_same_index[OF sets.sets_into_space, of I A M]
  by auto

(* The distribution of the i-th coordinate projection under PiM I M is M i. *)
lemma (in product_prob_space) PiM_component:
  assumes "i ∈ I"
  shows "distr (PiM I M) (M i) (λω. ω i) = M i"
proof (rule measure_eqI[symmetric])
  fix A assume "A ∈ sets (M i)"
  (* Preimage of A under the projection, written as a set comprehension so that
     emeasure_PiM_Collect_single applies. *)
  moreover have "((λω. ω i) -` A ∩ space (PiM I M)) = {x∈space (PiM I M). x i ∈ A}"
    by auto
  ultimately show "emeasure (M i) A = emeasure (distr (PiM I M) (M i) (λω. ω i)) A"
    by (auto simp: ‹i∈I› emeasure_distr measurable_component_singleton emeasure_PiM_Collect_single)
qed simp

(* Uniqueness: a measure M' with the same sets as PiM I M that satisfies the
   product formula on all finite cylinders is equal to PiM I M. *)
lemma (in product_prob_space) PiM_eq:
  assumes M': "sets M' = sets (PiM I M)"
  assumes eq: "⋀J F. finite J ⟹ J ⊆ I ⟹ (⋀j. j ∈ J ⟹ F j ∈ sets (M j)) ⟹
    emeasure M' (prod_emb I M J (Π⇩E j∈J. F j)) = (∏j∈J. emeasure (M j) (F j))"
  shows "M' = (PiM I M)"
proof (rule measure_eqI_PiM_infinite[symmetric, OF refl M'])
  show "finite_measure (PiM I M)"
    by standard (simp add: P.emeasure_space_1)
qed (simp add: eq emeasure_PiM_emb)

(* An almost-everywhere property of component M i lifts to the i-th coordinate
   of the product measure. *)
lemma (in product_prob_space) AE_component: "i ∈ I ⟹ AE x in M i. P x ⟹ AE x in PiM I M. P (x i)"
  apply (rule AE_distrD[of "λω. ω i" "PiM I M" "M i" P])
  apply simp
  apply (subst PiM_component)
  apply simp_all
  done

(* Locale-free cylinder formula: only the components indexed by I need to be
   probability spaces. *)
lemma emeasure_PiM_emb:
  assumes M: "⋀i. i ∈ I ⟹ prob_space (M i)"
  assumes J: "J ⊆ I" "finite J" and A: "⋀i. i ∈ J ⟹ A i ∈ sets (M i)"
  shows "emeasure (PiM I M) (prod_emb I M J (PiE J A)) = (∏i∈J. emeasure (M i) (A i))"
proof -
  (* Same padding trick as in prob_space_PiM: a one-point counting measure
     outside I lets the product_prob_space locale be interpreted. *)
  let ?M = "λi. if i ∈ I then M i else count_space {undefined}"
  interpret M': prob_space "?M i" for i
    using M by (cases "i ∈ I") (auto intro!: prob_spaceI)
  interpret P: product_prob_space ?M I
    by unfold_locales
  have "emeasure (PiM I M) (prod_emb I M J (PiE J A)) = emeasure (PiM I ?M) (P.emb I J (PiE J A))"
    by (auto simp: prod_emb_def PiE_iff intro!: arg_cong2[where f=emeasure] PiM_cong)
  also have "… = (∏i∈J. emeasure (M i) (A i))"
    using J A by (subst P.emeasure_PiM_emb[OF J]) (auto intro!: prod.cong)
  finally show ?thesis .
qed

(* Adding one coordinate: the image of the pair measure of M i' and the product
   over I under the map (x, X) ↦ X(i' := x) is the product over insert i' I. *)
lemma distr_pair_PiM_eq_PiM:
  fixes i' :: "'i" and I :: "'i set" and M :: "'i ⇒ 'a measure"
  assumes M: "⋀i. i ∈ I ⟹ prob_space (M i)" "prob_space (M i')"
  shows "distr (M i' ⨂⇩M (Π⇩M i∈I. M i)) (Π⇩M i∈insert i' I. M i) (λ(x, X). X(i' := x)) =
    (Π⇩M i∈insert i' I. M i)" (is "?L = _")
proof (rule measure_eqI_PiM_infinite[symmetric, OF refl])
  interpret M': prob_space "M i'" by fact
  interpret I: prob_space "(Π⇩M i∈I. M i)"
    using M by (intro prob_space_PiM) auto
  interpret I': prob_space "(Π⇩M i∈insert i' I. M i)"
    using M by (intro prob_space_PiM) auto
  show "finite_measure (Π⇩M i∈insert i' I. M i)"
    by unfold_locales
  fix J A assume J: "finite J" "J ⊆ insert i' I" and A: "⋀i. i ∈ J ⟹ A i ∈ sets (M i)"
  let ?X = "prod_emb (insert i' I) M J (PiE J A)"
  have "PiM (insert i' I) M ?X = (∏i∈J. M i (A i))"
    using M J A by (intro emeasure_PiM_emb) auto
  (* Split off the factor for i'; when i' is not in J it is the measure of the
     whole space of M i', i.e. 1. *)
  also have "… = M i' (if i' ∈ J then (A i') else space (M i')) * (∏i∈J-{i'}. M i (A i))"
    using prod.insert_remove[of J "λi. M i (A i)" i'] J M'.emeasure_space_1
    by (cases "i' ∈ J") (auto simp: insert_absorb)
  also have "(∏i∈J-{i'}. M i (A i)) = PiM I M (prod_emb I M (J - {i'}) (PiE (J - {i'}) A))"
    using M J A by (intro emeasure_PiM_emb[symmetric]) auto
  also have "M i' (if i' ∈ J then (A i') else space (M i')) * … =
    (M i' ⨂⇩M PiM I M) ((if i' ∈ J then (A i') else space (M i')) × prod_emb I M (J - {i'}) (PiE (J - {i'}) A))"
    using J A by (intro I.emeasure_pair_measure_Times[symmetric] sets_PiM_I) auto
  (* The rectangle is exactly the preimage of the cylinder ?X under the update map. *)
  also have "((if i' ∈ J then (A i') else space (M i')) × prod_emb I M (J - {i'}) (PiE (J - {i'}) A)) =
    (λ(x, X). X(i' := x)) -` ?X ∩ space (M i' ⨂⇩M PiM I M)"
    using A[of i', THEN sets.sets_into_space] unfolding set_eq_iff
    by (simp add: prod_emb_def space_pair_measure space_PiM PiE_fun_upd ac_simps cong: conj_cong)
       (auto simp add: Pi_iff Ball_def all_conj_distrib)
  finally show "PiM (insert i' I) M ?X = ?L ?X"
    using J A by (simp add: emeasure_distr)
qed simp

(* Reindexing along an injective map f from I into K: the image of PiM K M
   under ω ↦ (λn∈I. ω (f n)) is the product over I of the measures M (f i). *)
lemma distr_PiM_reindex:
  assumes M: "⋀i. i ∈ K ⟹ prob_space (M i)"
  assumes f: "inj_on f I" "f ∈ I → K"
  shows "distr (PiM K M) (Π⇩M i∈I. M (f i)) (λω. λn∈I. ω (f n)) = (Π⇩M i∈I. M (f i))"
    (is "distr ?K ?I ?t = ?I")
proof (rule measure_eqI_PiM_infinite[symmetric, OF refl])
  interpret prob_space ?I
    using f M by (intro prob_space_PiM) auto
  show "finite_measure ?I"
    by unfold_locales
  fix A J assume J: "finite J" "J ⊆ I" and A: "⋀i. i ∈ J ⟹ A i ∈ sets (M (f i))"
  have [simp]: "i ∈ J ⟹ the_inv_into I f (f i) = i" for i
    using J f by (intro the_inv_into_f_f) auto
  have "?I (prod_emb I (λi. M (f i)) J (PiE J A)) = (∏j∈J. M (f j) (A j))"
    using f J A by (intro emeasure_PiM_emb M) auto
  (* Move the product from J to its image f`J, using the inverse of f on I. *)
  also have "… = (∏j∈f`J. M j (A (the_inv_into I f j)))"
    using f J by (subst prod.reindex) (auto intro!: prod.cong intro: inj_on_subset simp: the_inv_into_f_f)
  also have "… = ?K (prod_emb K M (f`J) (Π⇩E j∈f`J. A (the_inv_into I f j)))"
    using f J A by (intro emeasure_PiM_emb[symmetric] M) (auto simp: the_inv_into_f_f)
  also have "prod_emb K M (f`J) (Π⇩E j∈f`J. A (the_inv_into I f j)) = ?t -` prod_emb I (λi. M (f i)) J (PiE J A) ∩ space ?K"
    using f J A by (auto simp: prod_emb_def space_PiM Pi_iff PiE_iff Int_absorb1)
  also have "?K … = distr ?K ?I ?t (prod_emb I (λi. M (f i)) J (PiE J A))"
    using f J A by (intro emeasure_distr[symmetric] sets_PiM_I) (auto simp: Pi_iff)
  finally show "?I (prod_emb I (λi. M (f i)) J (PiE J A)) = distr ?K ?I ?t (prod_emb I (λi. M (f i)) J (PiE J A))" .
qed simp

(* Locale-free version of PiM_component: the i-th coordinate projection has
   distribution M i under PiM I M. *)
lemma distr_PiM_component:
  assumes M: "⋀i. i ∈ I ⟹ prob_space (M i)"
  assumes "i ∈ I"
  shows "distr (PiM I M) (M i) (λω. ω i) = M i"
proof -
  (* The preimage of A under the projection is a cylinder over the index set {i}. *)
  have *: "(λω. ω i) -` A ∩ space (PiM I M) = prod_emb I M {i} (Π⇩E i'∈{i}. A)" for A
    by (auto simp: prod_emb_def space_PiM)
  show ?thesis
    apply (intro measure_eqI)
    apply (auto simp add: emeasure_distr ‹i∈I› * emeasure_PiM_emb M)
    apply (subst emeasure_PiM_emb)
    apply (simp_all add: M ‹i∈I›)
    done
qed

(* Locale-free version of AE_component: an almost-everywhere property of M i
   holds almost everywhere for the i-th coordinate under PiM I M. *)
lemma AE_PiM_component:
  "(⋀i. i ∈ I ⟹ prob_space (M i)) ⟹ i ∈ I ⟹ AE x in M i. P x ⟹ AE x in PiM I M. P (x i)"
  using AE_distrD[of "λx. x i" "PiM I M" "M i"]
  by (subst (asm) distr_PiM_component[of I _ i]) (auto intro: AE_distrD[of "λx. x i" _ _ P])

(* Cylinders over an increasing sequence of index sets form a decreasing
   sequence of sets. *)
lemma decseq_emb_PiE:
  "incseq J ⟹ decseq (λi. prod_emb I M (J i) (Π⇩E j∈J i. X j))"
  by (fastforce simp: decseq_def prod_emb_def incseq_def Pi_iff)

subsection ‹Sequence space›

(* comb_seq i ω ω' is the sequence consisting of the first i entries of ω
   followed by ω' (shifted by i). *)
definition comb_seq :: "nat ⇒ (nat ⇒ 'a) ⇒ (nat ⇒ 'a) ⇒ (nat ⇒ 'a)" where
  "comb_seq i ω ω' j = (if j < i then ω j else ω' (j - i))"

(* Split rule for comb_seq in a goal: case distinction on whether the index
   falls into the prefix (j < i) or the suffix (j = i + k). *)
lemma split_comb_seq: "P (comb_seq i ω ω' j) ⟷ (j < i ⟶ P (ω j)) ∧ (∀k. j = i + k ⟶ P (ω' k))"
  by (auto simp: comb_seq_def not_less)

(* Split rule for comb_seq in an assumption (negated-disjunction form). *)
lemma split_comb_seq_asm: "P (comb_seq i ω ω' j) ⟷ ¬ ((j < i ∧ ¬ P (ω j)) ∨ (∃k. j = i + k ∧ ¬ P (ω' k)))"
  by (auto simp: comb_seq_def)

(* comb_seq i is measurable as a map from the pair of two sequence spaces into
   the sequence space. *)
lemma measurable_comb_seq:
  "(λ(ω, ω'). comb_seq i ω ω') ∈ measurable ((Π⇩M i∈UNIV. M) ⨂⇩M (Π⇩M i∈UNIV. M)) (Π⇩M i∈UNIV. M)"
proof (rule measurable_PiM_single)
  show "(λ(ω, ω'). comb_seq i ω ω') ∈ space ((Π⇩M i∈UNIV. M) ⨂⇩M (Π⇩M i∈UNIV. M)) → (UNIV →⇩E space M)"
    by (auto simp: space_pair_measure space_PiM PiE_iff split: split_comb_seq)
  fix j :: nat and A assume A: "A ∈ sets M"
  (* The preimage of a single-coordinate constraint is a rectangle: the
     constraint lands on the first factor if j < i, on the second otherwise. *)
  then have *: "{ω ∈ space ((Π⇩M i∈UNIV. M) ⨂⇩M (Π⇩M i∈UNIV. M)). case_prod (comb_seq i) ω j ∈ A} =
    (if j < i then {ω ∈ space (Π⇩M i∈UNIV. M). ω j ∈ A} × space (Π⇩M i∈UNIV. M)
              else space (Π⇩M i∈UNIV. M) × {ω ∈ space (Π⇩M i∈UNIV. M). ω (j - i) ∈ A})"
    by (auto simp: space_PiM space_pair_measure comb_seq_def dest: sets.sets_into_space)
  show "{ω ∈ space ((Π⇩M i∈UNIV. M) ⨂⇩M (Π⇩M i∈UNIV. M)). case_prod (comb_seq i) ω j ∈ A} ∈ sets ((Π⇩M i∈UNIV. M) ⨂⇩M (Π⇩M i∈UNIV. M))"
    unfolding * by (auto simp: A intro!: sets_Collect_single)
qed

(* Composition form of measurable_comb_seq, registered as a raw rule for the
   measurability prover. *)
lemma measurable_comb_seq'[measurable (raw)]:
  assumes f: "f ∈ measurable N (Π⇩M i∈UNIV. M)" and g: "g ∈ measurable N (Π⇩M i∈UNIV. M)"
  shows "(λx. comb_seq i (f x) (g x)) ∈ measurable N (Π⇩M i∈UNIV. M)"
  using measurable_compose[OF measurable_Pair[OF f g] measurable_comb_seq] by simp

(* With an empty prefix (i = 0), comb_seq returns the second sequence unchanged. *)
lemma comb_seq_0: "comb_seq 0 ω ω' = ω'"
  by (auto simp add: comb_seq_def)

(* Taking one more element from ω is the same as taking n elements and
   prepending ω n to the second sequence via case_nat. *)
lemma comb_seq_Suc: "comb_seq (Suc n) ω ω' = comb_seq n ω (case_nat (ω n) ω')"
  by (auto simp add: comb_seq_def not_less less_Suc_eq le_imp_diff_is_add intro!: ext split: nat.split)

(* A prefix of length one is just case_nat applied to the head ω 0
   (simp rule; derived from comb_seq_Suc and comb_seq_0). *)
lemma comb_seq_Suc_0[simp]: "comb_seq (Suc 0) ω = case_nat (ω 0)"
  by (intro ext) (simp add: comb_seq_Suc comb_seq_0)

(* Indices below the cut point n read from the first sequence. *)
lemma comb_seq_less: "i < n ⟹ comb_seq n ω ω' i = ω i"
  by (auto split: split_comb_seq)

(* Indices of the form i + n read entry i of the second sequence. *)
lemma comb_seq_add: "comb_seq n ω ω' (i + n) = ω' i"
  by (auto split: nat.split split_comb_seq)

(* Pointwise interaction of case_nat with comb_seq at an index i + n. *)
lemma case_nat_comb_seq: "case_nat s' (comb_seq n ω ω') (i + n) = case_nat (case_nat s' ω n) ω' i"
  by (auto split: nat.split split_comb_seq)

(* Prepending s to a combined sequence equals combining with a prefix that is
   one longer, where s has been prepended to ω. *)
lemma case_nat_comb_seq':
  "case_nat s (comb_seq i ω ω') = comb_seq (Suc i) (case_nat s ω) ω'"
  by (auto split: split_comb_seq nat.split)

(* Product probability space of countably many copies of one measure M,
   indexed by the natural numbers. *)
locale sequence_space = product_prob_space "λi. M" "UNIV :: nat set" for M
begin

(* S abbreviates the product measure on sequences. *)
abbreviation "S ≡ Π⇩M i∈UNIV::nat set. M"

(* An infinite box Pi UNIV E with measurable sides is measurable in S: it is a
   countable intersection of finite cylinders. *)
lemma infprod_in_sets[intro]:
  fixes E :: "nat ⇒ 'a set" assumes E: "⋀i. E i ∈ sets M"
  shows "Pi UNIV E ∈ sets S"
proof -
  have "Pi UNIV E = (⋂i. emb UNIV {..i} (Π⇩E j∈{..i}. E j))"
    using E E[THEN sets.sets_into_space]
    by (auto simp: prod_emb_def Pi_iff extensional_def)
  with E show ?thesis by auto
qed

(* The measure of an infinite box is the limit of the finite partial products
   of the measures of its sides. *)
lemma measure_PiM_countable:
  fixes E :: "nat ⇒ 'a set" assumes E: "⋀i. E i ∈ sets M"
  shows "(λn. ∏i≤n. measure M (E i)) ⇢ measure S (Pi UNIV E)"
proof -
  let ?E = "λn. emb UNIV {..n} (PiE {.. n} E)"
  have "⋀n. (∏i≤n. measure M (E i)) = measure S (?E n)"
    using E by (simp add: measure_PiM_emb)
  moreover have "Pi UNIV E = (⋂n. ?E n)"
    using E E[THEN sets.sets_into_space]
    by (auto simp: prod_emb_def extensional_def Pi_iff)
  moreover have "range ?E ⊆ sets S"
    using E by auto
  moreover have "decseq ?E"
    by (auto simp: prod_emb_def Pi_iff decseq_def)
  ultimately show ?thesis
    by (simp add: finite_Lim_measure_decseq)
qed

(* Auxiliary arithmetic fact on naturals used for reindexing products below. *)
lemma nat_eq_diff_eq:
  fixes a b c :: nat
  shows "c ≤ b ⟹ a = b - c ⟷ a + c = b"
  by auto

(* Combining two independent S-distributed sequences with comb_seq i yields an
   S-distributed sequence. *)
lemma PiM_comb_seq:
  "distr (S ⨂⇩M S) S (λ(ω, ω'). comb_seq i ω ω') = S" (is "?D = _")
proof (rule PiM_eq)
  let ?I = "UNIV::nat set" and ?M = "λn. M"
  let "distr _ _ ?f" = "?D"

  fix J E assume J: "finite J" "J ⊆ ?I" "⋀j. j ∈ J ⟹ E j ∈ sets M"
  let ?X = "prod_emb ?I ?M J (Π⇩E j∈J. E j)"
  have "⋀j x. j ∈ J ⟹ x ∈ E j ⟹ x ∈ space M"
    using J(3)[THEN sets.sets_into_space] by (auto simp: space_PiM Pi_iff subset_eq)
  (* The preimage of a cylinder is a rectangle of two cylinders: constraints on
     indices below i go to the first factor, the rest (shifted by i) to the second. *)
  with J have "?f -` ?X ∩ space (S ⨂⇩M S) =
    (prod_emb ?I ?M (J ∩ {..<i}) (Π⇩E j∈J ∩ {..<i}. E j)) ×
    (prod_emb ?I ?M (((+) i) -` J) (Π⇩E j∈((+) i) -` J. E (i + j)))" (is "_ = ?E × ?F")
   by (auto simp: space_pair_measure space_PiM prod_emb_def all_conj_distrib PiE_iff
               split: split_comb_seq split_comb_seq_asm)
  then have "emeasure ?D ?X = emeasure (S ⨂⇩M S) (?E × ?F)"
    by (subst emeasure_distr[OF measurable_comb_seq])
       (auto intro!: sets_PiM_I simp: split_beta' J)
  also have "… = emeasure S ?E * emeasure S ?F"
    using J by (intro P.emeasure_pair_measure_Times)  (auto intro!: sets_PiM_I finite_vimageI simp: inj_on_def)
  also have "emeasure S ?F = (∏j∈((+) i) -` J. emeasure M (E (i + j)))"
    using J by (intro emeasure_PiM_emb) (simp_all add: finite_vimageI inj_on_def)
  also have "… = (∏j∈J - (J ∩ {..<i}). emeasure M (E j))"
    by (rule prod.reindex_cong [of "λx. x - i"])
      (auto simp: image_iff ac_simps nat_eq_diff_eq cong: conj_cong intro!: inj_onI)
  also have "emeasure S ?E = (∏j∈J ∩ {..<i}. emeasure M (E j))"
    using J by (intro emeasure_PiM_emb) simp_all
  also have "(∏j∈J ∩ {..<i}. emeasure M (E j)) * (∏j∈J - (J ∩ {..<i}). emeasure M (E j)) = (∏j∈J. emeasure M (E j))"
    by (subst mult.commute) (auto simp: J prod.subset_diff[symmetric])
  finally show "emeasure ?D ?X = (∏j∈J. emeasure M (E j))" .
qed simp_all

(* Prepending an independent M-distributed element to an S-distributed
   sequence (via case_nat) yields an S-distributed sequence. *)
lemma PiM_iter:
  "distr (M ⨂⇩M S) S (λ(s, ω). case_nat s ω) = S" (is "?D = _")
proof (rule PiM_eq)
  let ?I = "UNIV::nat set" and ?M = "λn. M"
  let "distr _ _ ?f" = "?D"

  fix J E assume J: "finite J" "J ⊆ ?I" "⋀j. j ∈ J ⟹ E j ∈ sets M"
  let ?X = "prod_emb ?I ?M J (Π⇩E j∈J. E j)"
  have "⋀j x. j ∈ J ⟹ x ∈ E j ⟹ x ∈ space M"
    using J(3)[THEN sets.sets_into_space] by (auto simp: space_PiM Pi_iff subset_eq)
  (* The preimage of a cylinder is a rectangle: the constraint at index 0 (if
     any) applies to the head, the remaining ones (shifted by one) to the tail. *)
  with J have "?f -` ?X ∩ space (M ⨂⇩M S) = (if 0 ∈ J then E 0 else space M) ×
    (prod_emb ?I ?M (Suc -` J) (Π⇩E j∈Suc -` J. E (Suc j)))" (is "_ = ?E × ?F")
   by (auto simp: space_pair_measure space_PiM PiE_iff prod_emb_def all_conj_distrib
      split: nat.split nat.split_asm)
  then have "emeasure ?D ?X = emeasure (M ⨂⇩M S) (?E × ?F)"
    by (subst emeasure_distr)
       (auto intro!: sets_PiM_I simp: split_beta' J)
  also have "… = emeasure M ?E * emeasure S ?F"
    using J by (intro P.emeasure_pair_measure_Times) (auto intro!: sets_PiM_I finite_vimageI)
  also have "emeasure S ?F = (∏j∈Suc -` J. emeasure M (E (Suc j)))"
    using J by (intro emeasure_PiM_emb) (simp_all add: finite_vimageI)
  also have "… = (∏j∈J - {0}. emeasure M (E j))"
    by (rule prod.reindex_cong [of "λx. x - 1"])
      (auto simp: image_iff nat_eq_diff_eq ac_simps cong: conj_cong intro!: inj_onI)
  also have "emeasure M ?E * (∏j∈J - {0}. emeasure M (E j)) = (∏j∈J. emeasure M (E j))"
    by (auto simp: M.emeasure_space_1 prod.remove J)
  finally show "emeasure ?D ?X = (∏j∈J. emeasure M (E j))" .
qed simp_all

end

(* A finite product of point (return) measures at a i is the point measure at
   the restricted function restrict a I, provided each singleton {a i} is
   measurable. *)
lemma PiM_return:
  assumes "finite I"
  assumes [measurable]: "⋀i. i ∈ I ⟹ {a i} ∈ sets (M i)"
  shows "PiM I (λi. return (M i) (a i)) = return (PiM I M) (restrict a I)"
proof -
  have [simp]: "a i ∈ space (M i)" if "i ∈ I" for i
    using assms(2)[OF that] by (meson insert_subset sets.sets_into_space)
  interpret prob_space "PiM I (λi. return (M i) (a i))"
    by (intro prob_space_PiM prob_space_return) auto
  (* Almost every point agrees with a on each coordinate in I ... *)
  have "AE x in PiM I (λi. return (M i) (a i)). ∀i∈I. x i = restrict a I i"
    by (intro eventually_ball_finite ballI AE_PiM_component prob_space_return assms)
       (auto simp: AE_return)
  moreover have "AE x in PiM I (λi. return (M i) (a i)). x ∈ space (PiM I (λi. return (M i) (a i)))"
    by simp
  (* ... and lies in the space, hence is almost surely equal to restrict a I. *)
  ultimately have "AE x in PiM I (λi. return (M i) (a i)). x = restrict a I"
    by eventually_elim (auto simp: fun_eq_iff space_PiM)
  hence "PiM I (λi. return (M i) (a i)) = return (PiM I (λi. return (M i) (a i))) (restrict a I)"
    by (rule AE_eq_constD)
  also have "… = return (PiM I M) (restrict a I)"
    by (intro return_cong sets_PiM_cong) auto
  finally show ?thesis .
qed

(* Variant of distr_PiM_finite_prob_space that only requires the components
   indexed by I to be probability spaces: mapping every coordinate of a finite
   product through f gives the product of the image measures. *)
lemma distr_PiM_finite_prob_space':
  assumes fin: "finite I"
  assumes "⋀i. i ∈ I ⟹ prob_space (M i)"
  assumes "⋀i. i ∈ I ⟹ prob_space (M' i)"
  assumes [measurable]: "⋀i. i ∈ I ⟹ f ∈ measurable (M i) (M' i)"
  shows   "distr (PiM I M) (PiM I M') (compose I f) = PiM I (λi. distr (M i) (M' i) f)"
proof -
  (* Pad both families with a point measure outside I so that they become
     product probability spaces on the whole index type. *)
  define N where "N = (λi. if i ∈ I then M i else return (count_space UNIV) undefined)"
  define N' where "N' = (λi. if i ∈ I then M' i else return (count_space UNIV) undefined)"
  have [simp]: "PiM I N = PiM I M" "PiM I N' = PiM I M'"
    by (intro PiM_cong; simp add: N_def N'_def)+

  have "distr (PiM I N) (PiM I N') (compose I f) = PiM I (λi. distr (N i) (N' i) f)"
  proof (rule distr_PiM_finite_prob_space)
    show "product_prob_space N"
      by (rule product_prob_spaceI) (auto simp: N_def intro!: prob_space_return assms)
    show "product_prob_space N'"
      by (rule product_prob_spaceI) (auto simp: N'_def intro!: prob_space_return assms)
  qed (auto simp: N_def N'_def fin)
  also have "PiM I (λi. distr (N i) (N' i) f) = PiM I (λi. distr (M i) (M' i) f)"
    by (intro PiM_cong) (simp_all add: N_def N'_def)
  finally show ?thesis by simp
qed

end