(* Theory Binomial_Heap *)

(* Author: Peter Lammich
           Tobias Nipkow (tuning)
*)

section ‹Binomial Heap›

theory Binomial_Heap
imports
  "HOL-Library.Pattern_Aliases"
  Complex_Main
  Priority_Queue_Specs
begin

text ‹
  We formalize the binomial heap presentation from Okasaki's book.
  We show the functional correctness and complexity of all operations.

  The presentation is engineered for simplicity, and most
  proofs are straightforward and automatic.
›

subsection ‹Binomial Tree and Heap Datatype›

(* A binomial tree node stores its rank, the element at its root, and the
   list of its children. The rank is stored explicitly; its relation to the
   children is only enforced by the btree invariant. *)
datatype 'a tree = Node (rank: nat) (root: 'a) (children: "'a tree list")

(* A binomial heap is represented as a list of binomial trees. *)
type_synonym 'a trees = "'a tree list"

subsubsection ‹Multiset of elements›

(* The multiset of elements stored in a tree: its root plus, recursively,
   the elements of all of its children. *)
fun mset_tree :: "'a::linorder tree ⇒ 'a multiset" where
  "mset_tree (Node _ a ts) = {#a#} + (∑t∈#mset ts. mset_tree t)"

(* The multiset of elements stored in a list of trees. *)
definition mset_trees :: "'a::linorder trees ⇒ 'a multiset" where
  "mset_trees ts = (∑t∈#mset ts. mset_tree t)"

(* Simp rule expressing mset_tree via mset_trees; the original defining
   equation is removed from the simpset in favour of this one. *)
lemma mset_tree_simp_alt[simp]:
  "mset_tree (Node r a ts) = {#a#} + mset_trees ts"
  unfolding mset_trees_def by auto
declare mset_tree.simps[simp del]

(* Every tree contains at least its root. *)
lemma mset_tree_nonempty[simp]: "mset_tree t ≠ {#}"
by (cases t) auto

lemma mset_trees_Nil[simp]:
  "mset_trees [] = {#}"
by (auto simp: mset_trees_def)

lemma mset_trees_Cons[simp]: "mset_trees (t#ts) = mset_tree t + mset_trees ts"
by (auto simp: mset_trees_def)

(* A list of trees stores no elements iff it is the empty list
   (each tree contains at least its root). *)
lemma mset_trees_empty_iff[simp]: "mset_trees ts = {#} ⟷ ts=[]"
by (auto simp: mset_trees_def)

lemma root_in_mset[simp]: "root t ∈# mset_tree t"
by (cases t) auto

(* The order of the trees does not affect the stored elements. *)
lemma mset_trees_rev_eq[simp]: "mset_trees (rev ts) = mset_trees ts"
by (auto simp: mset_trees_def)

subsubsection ‹Invariants›

text ‹Binomial tree›
(* A node of rank r has exactly r children, whose ranks are r-1, ..., 1, 0
   in this order, and all children are binomial trees themselves. *)
fun btree :: "'a::linorder tree ⇒ bool" where
"btree (Node r x ts) ⟷
   (∀t∈set ts. btree t) ∧ map rank ts = rev [0..<r]"

text ‹Heap invariant›
(* Each node's element is ≤ the roots of its children, recursively. *)
fun heap :: "'a::linorder tree ⇒ bool" where
"heap (Node _ x ts) ⟷ (∀t∈set ts. heap t ∧ x ≤ root t)"

(* A tree of a binomial heap satisfies both the shape and the heap invariant. *)
definition "bheap t ⟷ btree t ∧ heap t"

text ‹Binomial Heap invariant›
(* All trees are bheaps, and ranks are strictly increasing along the list. *)
definition "invar ts ⟷ (∀t∈set ts. bheap t) ∧ (sorted_wrt (<) (map rank ts))"


text ‹The reversed list of children of a node is a valid heap›
lemma invar_children:
  "bheap (Node r v ts) ⟹ invar (rev ts)"
  by (auto simp: bheap_def invar_def rev_map[symmetric])


subsection ‹Operations and Their Functional Correctness›

subsubsection ‹‹link››

context
includes pattern_aliases
begin

(* Combine two trees into one whose rank is one more than the rank of the
   first argument. If x1 ≤ x2, t2 is prepended to the children of t1;
   otherwise t1 is prepended to the children of t2. *)
fun link :: "('a::linorder) tree ⇒ 'a tree ⇒ 'a tree" where
  "link (Node r x1 ts1 =: t1) (Node r' x2 ts2 =: t2) =
    (if x1≤x2 then Node (r+1) x1 (t2#ts1) else Node (r+1) x2 (t1#ts2))"

end

(* Linking two bheaps of equal rank yields a bheap. *)
lemma invar_link:
  assumes "bheap t1"
  assumes "bheap t2"
  assumes "rank t1 = rank t2"
  shows "bheap (link t1 t2)"
using assms unfolding bheap_def
by (cases "(t1, t2)" rule: link.cases) auto

lemma rank_link[simp]: "rank (link t1 t2) = rank t1 + 1"
by (cases "(t1, t2)" rule: link.cases) simp

(* Linking neither loses nor duplicates elements. *)
lemma mset_link[simp]: "mset_tree (link t1 t2) = mset_tree t1 + mset_tree t2"
by (cases "(t1, t2)" rule: link.cases) simp

subsubsection ‹‹ins_tree››

(* Insert a tree into a list of trees. If its rank is strictly below that of
   the first tree, it is simply consed on. Otherwise the two are linked
   (a carry) and the linked tree is inserted into the rest. *)
fun ins_tree :: "'a::linorder tree ⇒ 'a trees ⇒ 'a trees" where
  "ins_tree t [] = [t]"
| "ins_tree t1 (t2#ts) =
  (if rank t1 < rank t2 then t1#t2#ts else ins_tree (link t1 t2) ts)"

lemma bheap0[simp]: "bheap (Node 0 x [])"
unfolding bheap_def by auto

(* Unfolding of invar for a non-empty list. *)
lemma invar_Cons[simp]:
  "invar (t#ts)
  ⟷ bheap t ∧ invar ts ∧ (∀t'∈set ts. rank t < rank t')"
by (auto simp: invar_def)

(* Inserting a bheap whose rank is ≤ every rank in a valid heap preserves invar. *)
lemma invar_ins_tree:
  assumes "bheap t"
  assumes "invar ts"
  assumes "∀t'∈set ts. rank t ≤ rank t'"
  shows "invar (ins_tree t ts)"
using assms
by (induction t ts rule: ins_tree.induct) (auto simp: invar_link less_eq_Suc_le[symmetric])

lemma mset_trees_ins_tree[simp]:
  "mset_trees (ins_tree t ts) = mset_tree t + mset_trees ts"
by (induction t ts rule: ins_tree.induct) auto

(* A strict lower bound rank t0 on t and on all of ts also bounds every
   tree in the result of ins_tree. *)
lemma ins_tree_rank_bound:
  assumes "t' ∈ set (ins_tree t ts)"
  assumes "∀t'∈set ts. rank t0 < rank t'"
  assumes "rank t0 < rank t"
  shows "rank t0 < rank t'"
using assms
by (induction t ts rule: ins_tree.induct) (auto split: if_splits)

subsubsection ‹‹insert››

(* Make the unqualified name insert available for the heap operation;
   the set operation remains accessible as Set.insert. *)
hide_const (open) insert

(* Insert an element by inserting a rank-0 singleton tree. *)
definition insert :: "'a::linorder ⇒ 'a trees ⇒ 'a trees" where
"insert x ts = ins_tree (Node 0 x []) ts"

lemma invar_insert[simp]: "invar t ⟹ invar (insert x t)"
by (auto intro!: invar_ins_tree simp: insert_def)

lemma mset_trees_insert[simp]: "mset_trees (insert x t) = {#x#} + mset_trees t"
by(auto simp: insert_def)

subsubsection ‹‹merge››

context
includes pattern_aliases
begin

(* Merge two heaps. The tree of strictly smaller rank is emitted first. Two
   trees of equal rank are linked, and the result is inserted (as a carry)
   into the merge of the remaining lists. *)
fun merge :: "'a::linorder trees ⇒ 'a trees ⇒ 'a trees" where
  "merge ts1 [] = ts1"
| "merge [] ts2 = ts2"
| "merge (t1#ts1 =: h1) (t2#ts2 =: h2) = (
    if rank t1 < rank t2 then t1 # merge ts1 h2 else
    if rank t2 < rank t1 then t2 # merge h1 ts2
    else ins_tree (link t1 t2) (merge ts1 ts2)
  )"

end

(* The function package makes the defining equations non-overlapping, so the
   generated second equation only covers non-empty ts2. This lemma states it
   for all ts2. *)
lemma merge_simp2[simp]: "merge [] ts2 = ts2"
by (cases ts2) auto

(* A strict lower bound on all ranks of both inputs bounds every rank of the result. *)
lemma merge_rank_bound:
  assumes "t' ∈ set (merge ts1 ts2)"
  assumes "∀t1∈set ts1. rank t < rank t1"
  assumes "∀t2∈set ts2. rank t < rank t2"
  shows "rank t < rank t'"
using assms
by (induction ts1 ts2 arbitrary: t' rule: merge.induct)
   (auto split: if_splits simp: ins_tree_rank_bound)

(* Merging preserves the binomial heap invariant. *)
lemma invar_merge[simp]:
  assumes "invar ts1"
  assumes "invar ts2"
  shows "invar (merge ts1 ts2)"
using assms
by (induction ts1 ts2 rule: merge.induct)
   (auto 0 3 simp: Suc_le_eq intro!: invar_ins_tree invar_link elim!: merge_rank_bound)


text ‹Longer, more explicit proof of @{thm [source] invar_merge}, 
      to illustrate the application of the @{thm [source] merge_rank_bound} lemma.›
lemma 
  assumes "invar ts1"
  assumes "invar ts2"
  shows "invar (merge ts1 ts2)"
  using assms
proof (induction ts1 ts2 rule: merge.induct)
  case (3 t1 ts1 t2 ts2)
  ― ‹Invariants of the parts can be shown automatically›
  from "3.prems" have [simp]: 
    "bheap t1" "bheap t2"
    (*"invar (merge (t1#ts1) ts2)" 
    "invar (merge ts1 (t2#ts2))"
    "invar (merge ts1 ts2)"*)
    by auto

  ― ‹These are the three cases of the @{const merge} function›
  consider (LT) "rank t1 < rank t2"
         | (GT) "rank t1 > rank t2"
         | (EQ) "rank t1 = rank t2"
    using antisym_conv3 by blast
  then show ?case proof cases
    case LT 
    ― ‹@{const merge} takes the first tree from the left heap›
    then have "merge (t1 # ts1) (t2 # ts2) = t1 # merge ts1 (t2 # ts2)" by simp
    also have "invar …" proof (simp, intro conjI)
      ― ‹Invariant follows from induction hypothesis›
      show "invar (merge ts1 (t2 # ts2))"
        using LT "3.IH" "3.prems" by simp

      ― ‹It remains to show that ‹t1› has smallest rank.›
      show "∀t'∈set (merge ts1 (t2 # ts2)). rank t1 < rank t'"
        ― ‹Which is done by auxiliary lemma @{thm [source] merge_rank_bound}›
        using LT "3.prems" by (force elim!: merge_rank_bound)
    qed
    finally show ?thesis .
  next
    ― ‹@{const merge} takes the first tree from the right heap›
    case GT 
    ― ‹The proof is analogous to the ‹LT› case›
    then show ?thesis using "3.prems" "3.IH" by (force elim!: merge_rank_bound)
  next
    case [simp]: EQ
    ― ‹@{const merge} links both first trees, and inserts them into the merged remaining heaps›
    have "merge (t1 # ts1) (t2 # ts2) = ins_tree (link t1 t2) (merge ts1 ts2)" by simp
    also have "invar …" proof (intro invar_ins_tree invar_link) 
      ― ‹Invariant of merged remaining heaps follows by IH›
      show "invar (merge ts1 ts2)"
        using EQ "3.prems" "3.IH" by auto

      ― ‹For insertion, we have to show that the rank of the linked tree is ‹≤› the 
          ranks in the merged remaining heaps›
      show "∀t'∈set (merge ts1 ts2). rank (link t1 t2) ≤ rank t'"
      proof -
        ― ‹Which is, again, done with the help of @{thm [source] merge_rank_bound}›
        have "rank (link t1 t2) = Suc (rank t2)" by simp
        thus ?thesis using "3.prems" by (auto simp: Suc_le_eq elim!: merge_rank_bound)
      qed
    qed simp_all
    finally show ?thesis .
  qed
qed auto


(* Merging two heaps yields the multiset union of their elements. *)
lemma mset_trees_merge[simp]:
  "mset_trees (merge ts1 ts2) = mset_trees ts1 + mset_trees ts2"
by (induction ts1 ts2 rule: merge.induct) auto

subsubsection ‹‹get_min››

(* The minimum of the roots of the trees. No equation is given for []. *)
fun get_min :: "'a::linorder trees ⇒ 'a" where
  "get_min [t] = root t"
| "get_min (t#ts) = min (root t) (get_min ts)"

(* In a bheap, the root is ≤ every element of the tree. *)
lemma bheap_root_min:
  assumes "bheap t"
  assumes "x ∈# mset_tree t"
  shows "root t ≤ x"
using assms unfolding bheap_def
by (induction t arbitrary: x rule: mset_tree.induct) (fastforce simp: mset_trees_def)

(* get_min is a lower bound of all elements of a non-empty valid heap. *)
lemma get_min_mset:
  assumes "ts≠[]"
  assumes "invar ts"
  assumes "x ∈# mset_trees ts"
  shows "get_min ts ≤ x"
  using assms
apply (induction ts arbitrary: x rule: get_min.induct)
apply (auto
      simp: bheap_root_min min_def intro: order_trans;
      meson linear order_trans bheap_root_min
      )+
done

(* get_min returns an element actually stored in the heap. *)
lemma get_min_member:
  "ts≠[] ⟹ get_min ts ∈# mset_trees ts"
by (induction ts rule: get_min.induct) (auto simp: min_def)

(* Together: get_min computes the minimum of the stored elements. *)
lemma get_min:
  assumes "mset_trees ts ≠ {#}"
  assumes "invar ts"
  shows "get_min ts = Min_mset (mset_trees ts)"
using assms get_min_member get_min_mset
by (auto simp: eq_Min_iff)

subsubsection ‹‹get_min_rest››

(* Split off a tree with minimal root. Returns that tree together with the
   remaining trees in their original order. No equation is given for []. *)
fun get_min_rest :: "'a::linorder trees ⇒ 'a tree × 'a trees" where
  "get_min_rest [t] = (t,[])"
| "get_min_rest (t#ts) = (let (t',ts') = get_min_rest ts
                     in if root t ≤ root t' then (t,ts) else (t',t#ts'))"

(* The root of the split-off tree is get_min. *)
lemma get_min_rest_get_min_same_root:
  assumes "ts≠[]"
  assumes "get_min_rest ts = (t',ts')"
  shows "root t' = get_min ts"
using assms
by (induction ts arbitrary: t' ts' rule: get_min.induct) (auto simp: min_def split: prod.splits)

(* get_min_rest only rearranges the trees: as a multiset, nothing is lost. *)
lemma mset_get_min_rest:
  assumes "get_min_rest ts = (t',ts')"
  assumes "ts≠[]"
  shows "mset ts = {#t'#} + mset ts'"
using assms
by (induction ts arbitrary: t' ts' rule: get_min.induct) (auto split: prod.splits if_splits)

lemma set_get_min_rest:
  assumes "get_min_rest ts = (t', ts')"
  assumes "ts≠[]"
  shows "set ts = Set.insert t' (set ts')"
using mset_get_min_rest[OF assms, THEN arg_cong[where f=set_mset]]
by auto

(* Both results of get_min_rest on a non-empty valid heap are valid. *)
lemma invar_get_min_rest:
  assumes "get_min_rest ts = (t',ts')"
  assumes "ts≠[]"
  assumes "invar ts"
  shows "bheap t'" and "invar ts'"
proof -
  have "bheap t' ∧ invar ts'"
    using assms
    proof (induction ts arbitrary: t' ts' rule: get_min.induct)
      case (2 t v va)
      then show ?case
        apply (clarsimp split: prod.splits if_splits)
        apply (drule set_get_min_rest; fastforce)
        done
    qed auto
  thus "bheap t'" and "invar ts'" by auto
qed

subsubsection ‹‹del_min››

(* Delete the minimum. Split off a tree with minimal root, then merge its
   children with the remaining trees. The children are reversed first, so
   that their ranks (r-1, ..., 0 by btree) become increasing. *)
definition del_min :: "'a::linorder trees ⇒ 'a::linorder trees" where
"del_min ts = (case get_min_rest ts of
   (Node r x ts1, ts2) ⇒ merge (rev ts1) ts2)"

lemma invar_del_min[simp]:
  assumes "ts ≠ []"
  assumes "invar ts"
  shows "invar (del_min ts)"
using assms
unfolding del_min_def
by (auto
      split: prod.split tree.split
      intro!: invar_merge invar_children 
      dest: invar_get_min_rest
    )

(* del_min removes exactly one occurrence of get_min. *)
lemma mset_trees_del_min:
  assumes "ts ≠ []"
  shows "mset_trees ts = mset_trees (del_min ts) + {# get_min ts #}"
using assms
unfolding del_min_def
apply (clarsimp split: tree.split prod.split)
apply (frule (1) get_min_rest_get_min_same_root)
apply (frule (1) mset_get_min_rest)
apply (auto simp: mset_trees_def)
done


subsubsection ‹Instantiating the Priority Queue Locale›

text ‹Last step of functional correctness proof: combine all the above lemmas
to show that binomial heaps satisfy the specification of priority queues with merge.›

(* goal_cases numbers the proof obligations of the Priority_Queue_Merge
   locale (from Priority_Queue_Specs) in order; see that theory for what
   each numbered case states. *)
interpretation bheaps: Priority_Queue_Merge
  where empty = "[]" and is_empty = "(=) []" and insert = insert
  and get_min = get_min and del_min = del_min and merge = merge
  and invar = invar and mset = mset_trees
proof (unfold_locales, goal_cases)
  case 1 thus ?case by simp
next
  case 2 thus ?case by auto
next
  case 3 thus ?case by auto
next
  case (4 q)
  thus ?case using mset_trees_del_min[of q] get_min[OF _ ‹invar q›]
    by (auto simp: union_single_eq_diff)
next
  case (5 q) thus ?case using get_min[of q] by auto
next
  case 6 thus ?case by (auto simp add: invar_def)
next
  case 7 thus ?case by simp
next
  case 8 thus ?case by simp
next
  case 9 thus ?case by simp
next
  case 10 thus ?case by simp
qed


subsection ‹Complexity›

text ‹The size of a binomial tree is determined by its rank›
(* A binomial tree of rank r has exactly 2^r elements. *)
lemma size_mset_btree:
  assumes "btree t"
  shows "size (mset_tree t) = 2^rank t"
  using assms
proof (induction t)
  case (Node r v ts)
  hence IH: "size (mset_tree t) = 2^rank t" if "t ∈ set ts" for t
    using that by auto

  from Node have COMPL: "map rank ts = rev [0..<r]" by auto

  have "size (mset_trees ts) = (∑t←ts. size (mset_tree t))"
    by (induction ts) auto
  also have "… = (∑t←ts. 2^rank t)" using IH
    by (auto cong: map_cong)
  also have "… = (∑r←map rank ts. 2^r)"
    by (induction ts) auto
  also have "… = (∑i∈{0..<r}. 2^i)"
    unfolding COMPL
    by (auto simp: rev_map[symmetric] interv_sum_list_conv_sum_set_nat)
  also have "… = 2^r - 1"
    by (induction r) auto
  finally show ?case
    by (simp)
qed

(* The same size statement for trees of a binomial heap. *)
lemma size_mset_tree:
  assumes "bheap t"
  shows "size (mset_tree t) = 2^rank t"
using assms unfolding bheap_def
by (simp add: size_mset_btree)

text ‹The length of a binomial heap is bounded logarithmically by the number of its elements›
(* Because the ranks are strictly increasing, 2^length ts - 1 is at most the
   sum of 2^rank t over the trees, which is the number of elements. *)
lemma size_mset_trees:
  assumes "invar ts"
  shows "length ts ≤ log 2 (size (mset_trees ts) + 1)"
proof -
  from ‹invar ts› have
    ASC: "sorted_wrt (<) (map rank ts)" and
    TINV: "∀t∈set ts. bheap t"
    unfolding invar_def by auto

  have "(2::nat)^length ts = (∑i∈{0..<length ts}. 2^i) + 1"
    by (simp add: sum_power2)
  also have "… = (∑i←[0..<length ts]. 2^i) + 1" (is "_ = ?S + 1")
    by (simp add: interv_sum_list_conv_sum_set_nat)
  also have "?S ≤ (∑t←ts. 2^rank t)" (is "_ ≤ ?T")
    using sorted_wrt_less_idx[OF ASC] by(simp add: sum_list_mono2)
  also have "?T + 1 ≤ (∑t←ts. size (mset_tree t)) + 1" using TINV
    by (auto cong: map_cong simp: size_mset_tree)
  also have "… = size (mset_trees ts) + 1"
    unfolding mset_trees_def by (induction ts) auto
  finally have "2^length ts ≤ size (mset_trees ts) + 1" by simp
  then show ?thesis using le_log2_of_power by blast
qed

subsubsection ‹Timing Functions›

text ‹
  We define timing functions for each operation, and provide
  estimations of their complexity.
›
(* Linking two trees costs one unit. *)
definition T_link :: "'a::linorder tree ⇒ 'a tree ⇒ nat" where
[simp]: "T_link _ _ = 1"

text ‹This function is non-canonical: we omitted a ‹+1› in the ‹else›-part,
  to keep the following analysis simpler and more to the point.
›
fun T_ins_tree :: "'a::linorder tree ⇒ 'a trees ⇒ nat" where
  "T_ins_tree t [] = 1"
| "T_ins_tree t1 (t2 # ts) = (
    (if rank t1 < rank t2 then 1
     else T_link t1 t2 + T_ins_tree (link t1 t2) ts)
  )"

(* Cost of insert: inserting a rank-0 singleton tree, plus one. *)
definition T_insert :: "'a::linorder ⇒ 'a trees ⇒ nat" where
"T_insert x ts = T_ins_tree (Node 0 x []) ts + 1"

(* A crude bound: each step of ins_tree consumes one tree of the list. *)
lemma T_ins_tree_simple_bound: "T_ins_tree t ts ≤ length ts + 1"
by (induction t ts rule: T_ins_tree.induct) auto

subsubsection ‹‹T_insert››

(* Insertion is logarithmic in the number of stored elements. *)
lemma T_insert_bound:
  assumes "invar ts"
  shows "T_insert x ts ≤ log 2 (size (mset_trees ts) + 1) + 2"
proof -
  have "real (T_insert x ts) ≤ real (length ts) + 2"
    unfolding T_insert_def using T_ins_tree_simple_bound 
    using of_nat_mono by fastforce
  also note size_mset_trees[OF ‹invar ts›]
  finally show ?thesis by simp
qed

subsubsection ‹‹T_merge››

context
includes pattern_aliases
begin

(* Cost of merge: one unit per call. In the equal-rank case, the cost of
   inserting the linked tree into the merged rest is added. *)
fun T_merge :: "'a::linorder trees ⇒ 'a trees ⇒ nat" where
  "T_merge ts1 [] = 1"
| "T_merge [] ts2 = 1"
| "T_merge (t1#ts1 =: h1) (t2#ts2 =: h2) = 1 + (
    if rank t1 < rank t2 then T_merge ts1 h2
    else if rank t2 < rank t1 then T_merge h1 ts2
    else T_ins_tree (link t1 t2) (merge ts1 ts2) + T_merge ts1 ts2
  )"

end

text ‹A crucial idea is to estimate the time in correlation with the
  result length, as each carry reduces the length of the result.›

lemma T_ins_tree_length:
  "T_ins_tree t ts + length (ins_tree t ts) = 2 + length ts"
by (induction t ts rule: ins_tree.induct) auto

lemma T_merge_length:
  "T_merge ts1 ts2 + length (merge ts1 ts2) ≤ 2 * (length ts1 + length ts2) + 1"
by (induction ts1 ts2 rule: T_merge.induct)
   (auto simp: T_ins_tree_length algebra_simps)

text ‹Finally, we get the desired logarithmic bound›
(* Merging is logarithmic in the total number of elements of both heaps. *)
lemma T_merge_bound:
  fixes ts1 ts2
  defines "n1 ≡ size (mset_trees ts1)"
  defines "n2 ≡ size (mset_trees ts2)"
  assumes "invar ts1" "invar ts2"
  shows "T_merge ts1 ts2 ≤ 4*log 2 (n1 + n2 + 1) + 1"
proof -
  note n_defs = assms(1,2)

  have "T_merge ts1 ts2 ≤ 2 * real (length ts1) + 2 * real (length ts2) + 1"
    using T_merge_length[of ts1 ts2] by simp
  also note size_mset_trees[OF ‹invar ts1›]
  also note size_mset_trees[OF ‹invar ts2›]
  finally have "T_merge ts1 ts2 ≤ 2 * log 2 (n1 + 1) + 2 * log 2 (n2 + 1) + 1"
    unfolding n_defs by (simp add: algebra_simps)
  also have "log 2 (n1 + 1) ≤ log 2 (n1 + n2 + 1)" 
    unfolding n_defs by (simp add: algebra_simps)
  also have "log 2 (n2 + 1) ≤ log 2 (n1 + n2 + 1)" 
    unfolding n_defs by (simp add: algebra_simps)
  finally show ?thesis by (simp add: algebra_simps)
qed

subsubsection ‹‹T_get_min››

(* Cost of get_min: one unit per tree visited. *)
fun T_get_min :: "'a::linorder trees ⇒ nat" where
  "T_get_min [t] = 1"
| "T_get_min (t#ts) = 1 + T_get_min ts"

lemma T_get_min_estimate: "ts≠[] ⟹ T_get_min ts = length ts"
by (induction ts rule: T_get_min.induct) auto

(* get_min is logarithmic in the number of stored elements. *)
lemma T_get_min_bound:
  assumes "invar ts"
  assumes "ts≠[]"
  shows "T_get_min ts ≤ log 2 (size (mset_trees ts) + 1)"
proof -
  have 1: "T_get_min ts = length ts" using assms T_get_min_estimate by auto
  also note size_mset_trees[OF ‹invar ts›]
  finally show ?thesis .
qed

subsubsection ‹‹T_del_min››

(* Cost of get_min_rest: one unit per tree visited. *)
fun T_get_min_rest :: "'a::linorder trees ⇒ nat" where
  "T_get_min_rest [t] = 1"
| "T_get_min_rest (t#ts) = 1 + T_get_min_rest ts"

lemma T_get_min_rest_estimate: "ts≠[] ⟹ T_get_min_rest ts = length ts"
  by (induction ts rule: T_get_min_rest.induct) auto

(* get_min_rest is logarithmic in the number of stored elements. *)
lemma T_get_min_rest_bound:
  assumes "invar ts"
  assumes "ts≠[]"
  shows "T_get_min_rest ts ≤ log 2 (size (mset_trees ts) + 1)"
proof -
  have 1: "T_get_min_rest ts = length ts" using assms T_get_min_rest_estimate by auto
  also note size_mset_trees[OF ‹invar ts›]
  finally show ?thesis .
qed

text‹Note that although the definition of function @{const rev} has quadratic complexity,
it can be, and is, implemented (via suitable code lemmas) as a linear time function.
Thus the following definition is justified:›

(* Cost of reversing a list: its length plus one. *)
definition "T_rev xs = length xs + 1"

(* Cost of del_min. It adds the cost of splitting off the minimal tree, the
   cost of reversing its children, the cost of merging them with the
   remaining trees, and one unit. *)
definition T_del_min :: "'a::linorder trees ⇒ nat" where
  "T_del_min ts = T_get_min_rest ts + (case get_min_rest ts of (Node _ x ts1, ts2)
                    ⇒ T_rev ts1 + T_merge (rev ts1) ts2
  ) + 1"

(* Deleting the minimum is logarithmic in the number of stored elements. *)
lemma T_del_min_bound:
  fixes ts
  defines "n ≡ size (mset_trees ts)"
  assumes "invar ts" and "ts≠[]"
  shows "T_del_min ts ≤ 6 * log 2 (n+1) + 3"
proof -
  obtain r x ts1 ts2 where GM: "get_min_rest ts = (Node r x ts1, ts2)"
    by (metis surj_pair tree.exhaust_sel)

  have I1: "invar (rev ts1)" and I2: "invar ts2"
    using invar_get_min_rest[OF GM ‹ts≠[]› ‹invar ts›] invar_children
    by auto

  define n1 where "n1 = size (mset_trees ts1)"
  define n2 where "n2 = size (mset_trees ts2)"

  have "n1 ≤ n" "n1 + n2 ≤ n" unfolding n_def n1_def n2_def
    using mset_get_min_rest[OF GM ‹ts≠[]›]
    by (auto simp: mset_trees_def)

  have "T_del_min ts = real (T_get_min_rest ts) + real (T_rev ts1) + real (T_merge (rev ts1) ts2) + 1"
    unfolding T_del_min_def GM
    by simp
  also have "T_get_min_rest ts ≤ log 2 (n+1)" 
    using T_get_min_rest_bound[OF ‹invar ts› ‹ts≠[]›] unfolding n_def by simp
  also have "T_rev ts1 ≤ 1 + log 2 (n1 + 1)"
    unfolding T_rev_def n1_def using size_mset_trees[OF I1] by simp
  also have "T_merge (rev ts1) ts2 ≤ 4*log 2 (n1 + n2 + 1) + 1"
    unfolding n1_def n2_def using T_merge_bound[OF I1 I2] by (simp add: algebra_simps)
  finally have "T_del_min ts ≤ log 2 (n+1) + log 2 (n1 + 1) + 4*log 2 (real (n1 + n2) + 1) + 3"
    by (simp add: algebra_simps)
  also note ‹n1 + n2 ≤ n›
  also note ‹n1 ≤ n›
  finally show ?thesis by (simp add: algebra_simps)
qed

end