Theory Pattern_Match

(*  Title:      HOL/HOLCF/ex/Pattern_Match.thy
    Author:     Brian Huffman
*)

section ‹An experimental pattern-matching notation›

theory Pattern_Match
imports HOLCF
begin

(* Unless stated otherwise, type variables in this theory have sort pcpo. *)
default_sort pcpo

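text ‹FIXME: Find a proper way to un-hide the constants of theory Fixrec; until then, the abbreviations below make them available under unqualified names.›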

(* Unqualified names for the match-monad primitives Fixrec.fail,
   Fixrec.succeed and Fixrec.run, used throughout this theory. *)
abbreviation fail :: "'a match"
where "fail  Fixrec.fail"

abbreviation succeed :: "'a  'a match"
where "succeed  Fixrec.succeed"

abbreviation run :: "'a match  'a"
where "run  Fixrec.run"

subsection ‹Fatbar combinator›

(* fatbar a b x combines the results of the two matchers a and b on the
   same argument x with the match-monad choice operator +++.  The lemmas
   below show the effect: b is consulted only when a fails.
   fatbar_syn provides right-associative infix notation (priority 60). *)
definition
  fatbar :: "('a  'b match)  ('a  'b match)  ('a  'b match)" where
  "fatbar = (Λ a b x. ax +++ bx)"

abbreviation
  fatbar_syn :: "['a  'b match, 'a  'b match]  'a  'b match" (infixr "" 60)  where
  "m1  m2 == fatbarm1m2"

(* Evaluation of fatbar by cases on the first matcher's result on x:
   - fatbar1: an undefined (bottom) result is propagated;
   - fatbar2: fail defers to the second matcher ms;
   - fatbar3: a successful match succeed y is returned unchanged. *)
lemma fatbar1: "mx =   (m  ms)x = "
by (simp add: fatbar_def)

lemma fatbar2: "mx = fail  (m  ms)x = msx"
by (simp add: fatbar_def)

lemma fatbar3: "mx = succeedy  (m  ms)x = succeedy"
by (simp add: fatbar_def)

lemmas fatbar_simps = fatbar1 fatbar2 fatbar3

(* The same three cases as fatbar1-3, stated under run.  run_fatbar3
   shows that run of a successful fatbar yields the matched value y.
   These three rules are collected as [simp] rules. *)
lemma run_fatbar1: "mx =   run((m  ms)x) = "
by (simp add: fatbar_def)

lemma run_fatbar2: "mx = fail  run((m  ms)x) = run(msx)"
by (simp add: fatbar_def)

lemma run_fatbar3: "mx = succeedy  run((m  ms)x) = y"
by (simp add: fatbar_def)

lemmas run_fatbar_simps [simp] = run_fatbar1 run_fatbar2 run_fatbar3

subsection ‹Bind operator for match monad›

(* Monadic bind for the match monad.  It is defined by case analysis
   (sscase) on the representation Rep_match m: the left summand is mapped
   to fail, and the right one is passed to k through fup.
   match_bind_simps characterise it on an undefined (bottom) match, on
   fail (result fail), and on succeed x (result k applied to x). *)
definition match_bind :: "'a match  ('a  'b match)  'b match" where
  "match_bind = (Λ m k. sscase(Λ _. fail)(fupk)(Rep_match m))"

lemma match_bind_simps [simp]:
  "match_bindk = "
  "match_bindfailk = fail"
  "match_bind(succeedx)k = kx"
unfolding match_bind_def fail_def succeed_def
by (simp_all add: cont_Rep_match cont_Abs_match
  Rep_match_strict Abs_match_inverse)

subsection ‹Case branch combinator›

(* branch p r x runs pattern p on x.  If p succeeds with y, the result is
   succeed applied to (r applied to y).  branch_simps state the three
   cases: undefined and fail results of p are propagated, and success is
   mapped through r.  branch_succeed is the special case p = succeed. *)
definition
  branch :: "('a  'b match)  ('b  'c)  ('a  'c match)" where
  "branch p  Λ r x. match_bind(px)(Λ y. succeed(ry))"

lemma branch_simps:
  "px =   branch prx = "
  "px = fail  branch prx = fail"
  "px = succeedy  branch prx = succeed(ry)"
by (simp_all add: branch_def)

lemma branch_succeed [simp]: "branch succeedrx = succeed(rx)"
by (simp add: branch_def)

subsection ‹Cases operator›

(* cases extracts the result of a match.  It is defined as Fixrec.run,
   so that Case expressions print as cases rather than run. *)
definition
  cases :: "'a match  'a::pcpo" where
  "cases = Fixrec.run"

text ‹Rewrite rules for cases›

(* [simp] rules for cases, all following from cases_def:
   - cases_strict: cases is strict;
   - cases_fail: the value of cases on fail;
   - cases_succeed: cases on succeed x returns x. *)
lemma cases_strict [simp]: "cases = "
by (simp add: cases_def)

lemma cases_fail [simp]: "casesfail = "
by (simp add: cases_def)

lemma cases_succeed [simp]: "cases(succeedx) = x"
by (simp add: cases_def)

subsection ‹Case syntax›

(* Concrete syntax "Case x of p1 => r1 | ... | pn => rn" (the ASCII arrow
   is "=>").  Each branch is a _Case1 node.  Alternatives (_Case2) are
   combined with fatbar, and the whole expression becomes cases applied
   to the combined matcher at x.  Patterns are parsed as ordinary terms
   with their position information stripped (_strip_positions). *)
nonterminal Case_pat and Case_syn and Cases_syn

syntax
  "_Case_syntax":: "['a, Cases_syn] => 'b"               ("(Case _ of/ _)" 10)
  "_Case1"      :: "[Case_pat, 'b] => Case_syn"          ("(2_ / _)" 10)
  ""            :: "Case_syn => Cases_syn"               ("_")
  "_Case2"      :: "[Case_syn, Cases_syn] => Cases_syn"  ("_/ | _")
  "_strip_positions" :: "'a => Case_pat"                 ("_")

syntax (ASCII)
  "_Case1"      :: "[Case_pat, 'b] => Case_syn"          ("(2_ =>/ _)" 10)

translations
  "_Case_syntax x ms" == "CONST cases(msx)"
  "_Case2 m ms" == "m  ms"

text ‹Parsing Case expressions›

(* Parsing a branch "p => r": it becomes branch (_pat p) (_variable p r).
   _pat turns the pattern p into a matcher.  _variable abstracts the
   right-hand side r over the variables bound in p: csplit is used for
   two-variable tuples (_args), and unit_when when p binds no variables
   (_noargs).  The parse translation handles the remaining leaf cases, as
   its comments state: a variable pattern becomes succeed, and a variable
   binder becomes a continuous abstraction (Abs_cfun). *)
syntax
  "_pat" :: "'a"
  "_variable" :: "'a"
  "_noargs" :: "'a"

translations
  "_Case1 p r" => "CONST branch (_pat p)(_variable p r)"
  "_variable (_args x y) r" => "CONST csplit(_variable x (_variable y r))"
  "_variable _noargs r" => "CONST unit_whenr"

parse_translation (* rewrite (_pat x) => (succeed) *)
(* rewrite (_variable x t) => (Abs_cfun (%x. t)) *)
 [(syntax_const‹_pat›, fn _ => fn _ => Syntax.const const_syntaxFixrec.succeed),
  Syntax_Trans.mk_binder_tr (syntax_const‹_variable›, const_syntaxAbs_cfun)]

text ‹Printing Case expressions›

(* Printing Case branches.  The print translation is attached to
   Rep_cfun.  It recognises an application of branch to a pattern p and
   a continuous function r, recovers the bound variables of r with
   dest_LAM, and rebuilds a _Case1 node whose pattern is p wrapped in
   _match together with those variables.  The final translation prints
   a match of succeed against a single variable as just that variable. *)
syntax
  "_match" :: "'a"

print_translation let
    (* Split a continuous function into (variable pattern AST, body):
       unit_when binds no variables, csplit binds two (recursively), and
       a continuous abstraction binds one.  Any other shape raises Match,
       which aborts the whole print translation. *)
    fun dest_LAM (Const (const_syntaxRep_cfun,_) $ Const (const_syntaxunit_when,_) $ t) =
          (Syntax.const syntax_const‹_noargs›, t)
    |   dest_LAM (Const (const_syntaxRep_cfun,_) $ Const (const_syntaxcsplit,_) $ t) =
          let
            val (v1, t1) = dest_LAM t;
            val (v2, t2) = dest_LAM t1;
          in (Syntax.const syntax_const‹_args› $ v1 $ v2, t2) end
    |   dest_LAM (Const (const_syntaxAbs_cfun,_) $ t) =
          let
            (* eta-expand a non-lambda argument so it can be printed as a binder *)
            val abs =
              case t of Abs abs => abs
                | _ => ("x", dummyT, incr_boundvars 1 t $ Bound 0);
            val (x, t') = Syntax_Trans.atomic_abs_tr' abs;
          in (Syntax.const syntax_const‹_variable› $ x, t') end
    |   dest_LAM _ = raise Match; (* too few vars: abort translation *)

    (* Turn branch p applied to r back into a "_Case1 (_match p v) t" node.
       Any other argument list raises Match, so the term is left alone. *)
    fun Case1_tr' [Const(const_syntaxbranch,_) $ p, r] =
          let val (v, t) = dest_LAM r in
            Syntax.const syntax_const‹_Case1› $
              (Syntax.const syntax_const‹_match› $ p $ v) $ t
          end;

  in [(const_syntaxRep_cfun, K Case1_tr')] end

translations
  "x" <= "_match (CONST succeed) (_variable x)"


subsection ‹Pattern combinators for data constructors›

(* ('a, 'b) pat is a continuous function from 'a into 'b match.
   cpair_pat p1 p2 matches a pair (x, y): it binds p1 on x and then p2 on
   y, and succeeds with the pair of both results.  spair_pat does the
   same for strict pairs by delegating to cpair_pat on the components. *)
type_synonym ('a, 'b) pat = "'a  'b match"

definition
  cpair_pat :: "('a, 'c) pat  ('b, 'd) pat  ('a × 'b, 'c × 'd) pat" where
  "cpair_pat p1 p2 = (Λ(x, y).
    match_bind(p1x)(Λ a. match_bind(p2y)(Λ b. succeed(a, b))))"

definition
  spair_pat ::
  "('a, 'c) pat  ('b, 'd) pat  ('a::pcpo  'b::pcpo, 'c × 'd) pat" where
  "spair_pat p1 p2 = (Λ(:x, y:). cpair_pat p1 p2(x, y))"

(* Sum and lifting patterns, defined via the case combinators sscase and
   fup.  sinl_pat p runs p on the left summand and fails on the right
   one; sinr_pat p does the opposite; up_pat p runs p under up. *)
definition
  sinl_pat :: "('a, 'c) pat  ('a::pcpo  'b::pcpo, 'c) pat" where
  "sinl_pat p = sscasep(Λ x. fail)"

definition
  sinr_pat :: "('b, 'c) pat  ('a::pcpo  'b::pcpo, 'c) pat" where
  "sinr_pat p = sscase(Λ x. fail)p"

definition
  up_pat :: "('a, 'b) pat  ('a u, 'b) pat" where
  "up_pat p = fupp"

(* Constant patterns that bind no variables, so their result type is
   unit.  TT_pat and FF_pat test a tr value with If and succeed with ()
   on the matching branch.  ONE_pat succeeds with () on ONE. *)
definition
  TT_pat :: "(tr, unit) pat" where
  "TT_pat = (Λ b. If b then succeed() else fail)"

definition
  FF_pat :: "(tr, unit) pat" where
  "FF_pat = (Λ b. If b then fail else succeed())"

definition
  ONE_pat :: "(one, unit) pat" where
  "ONE_pat = (Λ ONE. succeed())"

text ‹Parse translations (patterns)›
(* In pattern position, each constructor application is replaced by the
   corresponding pattern combinator applied to its translated
   sub-patterns.  Nullary constructors map to constant patterns. *)
translations
  "_pat (XCONST Pair x y)" => "CONST cpair_pat (_pat x) (_pat y)"
  "_pat (XCONST spairxy)" => "CONST spair_pat (_pat x) (_pat y)"
  "_pat (XCONST sinlx)" => "CONST sinl_pat (_pat x)"
  "_pat (XCONST sinrx)" => "CONST sinr_pat (_pat x)"
  "_pat (XCONST upx)" => "CONST up_pat (_pat x)"
  "_pat (XCONST TT)" => "CONST TT_pat"
  "_pat (XCONST FF)" => "CONST FF_pat"
  "_pat (XCONST ONE)" => "CONST ONE_pat"

text ‹The CONST version is also needed for constructors with special syntax›
(* The same Pair/spair pattern rules, matched on CONST instead of XCONST. *)
translations
  "_pat (CONST Pair x y)" => "CONST cpair_pat (_pat x) (_pat y)"
  "_pat (CONST spairxy)" => "CONST spair_pat (_pat x) (_pat y)"

text ‹Parse translations (variables)›
(* Collect the variables bound by a constructor pattern.  Two-argument
   constructors give an _args pair.  Unary constructors pass their single
   argument through.  Nullary constructors bind nothing (_noargs). *)
translations
  "_variable (XCONST Pair x y) r" => "_variable (_args x y) r"
  "_variable (XCONST spairxy) r" => "_variable (_args x y) r"
  "_variable (XCONST sinlx) r" => "_variable x r"
  "_variable (XCONST sinrx) r" => "_variable x r"
  "_variable (XCONST upx) r" => "_variable x r"
  "_variable (XCONST TT) r" => "_variable _noargs r"
  "_variable (XCONST FF) r" => "_variable _noargs r"
  "_variable (XCONST ONE) r" => "_variable _noargs r"

(* The same Pair/spair rules, matched on CONST. *)
translations
  "_variable (CONST Pair x y) r" => "_variable (_args x y) r"
  "_variable (CONST spairxy) r" => "_variable (_args x y) r"

text ‹Print translations›
(* Print a pattern combinator wrapped in _match as the corresponding
   constructor, distributing _match over the sub-patterns and their
   variables.  Constant patterns print as the bare constructor. *)
translations
  "CONST Pair (_match p1 v1) (_match p2 v2)"
      <= "_match (CONST cpair_pat p1 p2) (_args v1 v2)"
  "CONST spair(_match p1 v1)(_match p2 v2)"
      <= "_match (CONST spair_pat p1 p2) (_args v1 v2)"
  "CONST sinl(_match p1 v1)" <= "_match (CONST sinl_pat p1) v1"
  "CONST sinr(_match p1 v1)" <= "_match (CONST sinr_pat p1) v1"
  "CONST up(_match p1 v1)" <= "_match (CONST up_pat p1) v1"
  "CONST TT" <= "_match (CONST TT_pat) _noargs"
  "CONST FF" <= "_match (CONST FF_pat) _noargs"
  "CONST ONE" <= "_match (CONST ONE_pat) _noargs"

(* Reduction of branch on cpair_pat, by cases on how the first component
   x fares against p (with continuation r).  An undefined result and fail
   are propagated (cpair_pat1, cpair_pat2).  On success with s, the
   result is branch q s y (cpair_pat3): the second component is matched
   with s as the continuation.  All three are collected as [simp]. *)
lemma cpair_pat1:
  "branch prx =   branch (cpair_pat p q)(csplitr)(x, y) = "
apply (simp add: branch_def cpair_pat_def)
apply (cases "px", simp_all)
done

lemma cpair_pat2:
  "branch prx = fail  branch (cpair_pat p q)(csplitr)(x, y) = fail"
apply (simp add: branch_def cpair_pat_def)
apply (cases "px", simp_all)
done

lemma cpair_pat3:
  "branch prx = succeeds 
   branch (cpair_pat p q)(csplitr)(x, y) = branch qsy"
apply (simp add: branch_def cpair_pat_def)
apply (cases "px", simp_all)
apply (cases "qy", simp_all)
done

lemmas cpair_pat [simp] =
  cpair_pat1 cpair_pat2 cpair_pat3

(* [simp] reduction rules for branch on spair_pat, sinl_pat, sinr_pat and
   up_pat.  Each first states the value on an undefined scrutinee.  For
   spair_pat, sinl_pat and sinr_pat, the constructor rules assume defined
   arguments.  On the matching constructor the rules reduce to the
   sub-pattern (spair_pat via cpair_pat); on the other summand they yield
   fail. *)
lemma spair_pat [simp]:
  "branch (spair_pat p1 p2)r = "
  "x  ; y  
      branch (spair_pat p1 p2)r(:x, y:) =
         branch (cpair_pat p1 p2)r(x, y)"
by (simp_all add: branch_def spair_pat_def)

lemma sinl_pat [simp]:
  "branch (sinl_pat p)r = "
  "x    branch (sinl_pat p)r(sinlx) = branch prx"
  "y    branch (sinl_pat p)r(sinry) = fail"
by (simp_all add: branch_def sinl_pat_def)

lemma sinr_pat [simp]:
  "branch (sinr_pat p)r = "
  "x    branch (sinr_pat p)r(sinlx) = fail"
  "y    branch (sinr_pat p)r(sinry) = branch pry"
by (simp_all add: branch_def sinr_pat_def)

lemma up_pat [simp]:
  "branch (up_pat p)r = "
  "branch (up_pat p)r(upx) = branch prx"
by (simp_all add: branch_def up_pat_def)

(* [simp] reduction rules for the constant patterns.  The continuation is
   unit_when r because these patterns bind nothing.  Matching the expected
   constant gives succeed r, and the other tr constant gives fail. *)
lemma TT_pat [simp]:
  "branch TT_pat(unit_whenr) = "
  "branch TT_pat(unit_whenr)TT = succeedr"
  "branch TT_pat(unit_whenr)FF = fail"
by (simp_all add: branch_def TT_pat_def)

lemma FF_pat [simp]:
  "branch FF_pat(unit_whenr) = "
  "branch FF_pat(unit_whenr)TT = fail"
  "branch FF_pat(unit_whenr)FF = succeedr"
by (simp_all add: branch_def FF_pat_def)

lemma ONE_pat [simp]:
  "branch ONE_pat(unit_whenr) = "
  "branch ONE_pat(unit_whenr)ONE = succeedr"
by (simp_all add: branch_def ONE_pat_def)


subsection ‹Wildcards, as-patterns, and lazy patterns›

(* wild_pat matches any value and succeeds with ().
   as_pat p runs p on x and succeeds with the pair of x itself and p's
   result.
   lazy_pat p always succeeds; its payload is cases applied to p's match
   on x, so the outcome of p is only inspected when that value is used. *)
definition
  wild_pat :: "'a  unit match" where
  "wild_pat = (Λ x. succeed())"

definition
  as_pat :: "('a  'b match)  'a  ('a × 'b) match" where
  "as_pat p = (Λ x. match_bind(px)(Λ a. succeed(x, a)))"

definition
  lazy_pat :: "('a  'b::pcpo match)  ('a  'b match)" where
  "lazy_pat p = (Λ x. succeed(cases(px)))"

text ‹Parse translations (patterns)›
(* The wildcard "_" parses to wild_pat, binds no variables, and prints
   back as "_". *)
translations
  "_pat _" => "CONST wild_pat"

text ‹Parse translations (variables)›
translations
  "_variable _ r" => "_variable _noargs r"

text ‹Print translations›
translations
  "_" <= "_match (CONST wild_pat) _noargs"

(* [simp] reduction rules:
   - wild_pat: always succeeds with r;
   - as_pat: the scrutinee x is passed to r as its first argument, and
     then p is matched on x;
   - lazy_pat: the result is a succeed whatever branch p r x yields; when
     p succeeds with s, it agrees with branch p r x. *)
lemma wild_pat [simp]: "branch wild_pat(unit_whenr)x = succeedr"
by (simp add: branch_def wild_pat_def)

lemma as_pat [simp]:
  "branch (as_pat p)(csplitr)x = branch p(rx)x"
apply (simp add: branch_def as_pat_def)
apply (cases "px", simp_all)
done

lemma lazy_pat [simp]:
  "branch prx =   branch (lazy_pat p)rx = succeed(r)"
  "branch prx = fail  branch (lazy_pat p)rx = succeed(r)"
  "branch prx = succeeds  branch (lazy_pat p)rx = succeeds"
apply (simp_all add: branch_def lazy_pat_def)
apply (cases "px", simp_all)+
done

subsection ‹Examples›

(* Example Case expressions.  Each one is parsed, type-checked and printed
   back by "term", which exercises both the parse and print
   translations above. *)
term "Case t of (:up(sinlx), sinry:)  (x, y)"

term "Λ t. Case t of up(sinla)  a | up(sinrb)  b"

term "Λ t. Case t of (:up(sinl_), sinrx:)  x"

subsection ‹ML code for generating definitions›

ML local open HOLCF_Library in

(* Fixities for the operators used below: ->> builds continuous function
   types (e.g. in mk_pair_pat), and ` is continuous application (e.g. in
   pat_eqn).  Presumably both are provided by the opened HOLCF_Library. *)
infixr 6 ->>;
infix 9 ` ;

(* Rules for beta-reducing continuous applications: beta_cfun, together
   with the continuity lemmas used to discharge its continuity side
   conditions (identity, constants, application, abstraction, and the
   pair operations fst, snd and Pair).  The order is the same as before. *)
val beta_rules =
  @{thms beta_cfun cont_id cont_const cont2cont_APP cont2cont_LAM'
    cont2cont_fst cont2cont_snd cont2cont_Pair};

(* Simpset used by the generated proofs: HOL_basic_ss extended with
   simp_thms and beta_rules, turned into a stored simpset with
   simpset_of so it can be reinstalled with put_simpset later. *)
val beta_ss =
  simpset_of (put_simpset HOL_basic_ss context addsimps (@{thms simp_thms} @ beta_rules));

(* Declare one continuous constant per (binding, rhs, mixfix) spec, with
   the type of rhs.  Then add the definitional axioms "c == rhs" to the
   theory.  Returns the new constants, their definition theorems (in spec
   order), and the extended theory. *)
fun define_consts
    (specs : (binding * term * mixfix) list)
    (thy : theory)
    : (term list * thm list) * theory =
  let
    val decls = map (fn (b, rhs, mx) => (b, fastype_of rhs, mx)) specs;
    val thy = Cont_Consts.add_consts decls thy;
    (* the constants are referred to by their fully qualified names *)
    val consts = map (fn (b, T, _) => Const (Sign.full_name thy b, T)) decls;
    val defs =
      map2 (fn c => fn (b, rhs, _) => (Thm.def_binding b, Logic.mk_equals (c, rhs)))
        consts specs;
    val (def_thms, thy) =
      Global_Theory.add_defs false (map Thm.no_attributes defs) thy;
  in
    ((consts, def_thms), thy)
  end;

(* Prove goal globally in thy.  The tactic first unfolds defs in the goal
   and then runs, in sequence, the tactics that tacs returns.  tacs sees
   the premises with defs already unfolded in them, plus the proof
   context. *)
fun prove
    (thy : theory)
    (defs : thm list)
    (goal : term)
    (tacs : {prems: thm list, context: Proof.context} -> tactic list)
    : thm =
  let
    fun unfold_then_run {prems = hyps, context = ctxt} =
      let
        val unfold_goal = rewrite_goals_tac ctxt defs;
        val unfolded_hyps = map (rewrite_rule ctxt defs) hyps;
        val user_tacs = tacs {prems = unfolded_hyps, context = ctxt};
      in
        unfold_goal THEN EVERY user_tacs
      end;
  in
    Goal.prove_global thy [] [] goal unfold_then_run
  end;

(* Make one fresh free variable per argument, with names derived from the
   argument types and distinct from taken.  Returns all variables, and
   separately those whose argument is not flagged lazy (the bool is the
   lazy flag), in the original order. *)
fun get_vars_avoiding
    (taken : string list)
    (args : (bool * typ) list)
    : (term list * term list) =
  let
    val (lazy_flags, Ts) = split_list args;
    val names = Name.variant_list taken (Old_Datatype_Prop.make_tnames Ts);
    val vars = map2 (fn n => fn T => Free (n, T)) names Ts;
    val strict_vars =
      map_filter (fn (is_lazy, v) => if is_lazy then NONE else SOME v)
        (lazy_flags ~~ vars);
  in
    (vars, strict_vars)
  end;

(******************************************************************************)
(************** definitions and theorems for pattern combinators **************)
(******************************************************************************)

(* Define one pattern combinator per constructor of a datatype, install
   parse/print translations for them, and prove their strictness and
   reduction rules.
     bindings   - one binding per constructor; "_pat" is appended to form
                  the combinator's name
     spec       - each constructor constant with its argument list of
                  (lazy flag, argument type)
     lhsT       - the datatype; its type arguments must be type variables
                  (dest_TFree is applied to them)
     exhaust    - NOTE(review): not referenced in the body
     case_const - given a result type, yields the datatype's case combinator
     case_rews  - extra simp rules for the proofs, presumably the case
                  combinator's reduction rules
   Returns the proved theorems (all strictness rules first, then all
   reduction rules) together with the extended theory. *)
fun add_pattern_combinators
    (bindings : binding list)
    (spec : (term * (bool * typ) list) list)
    (lhsT : typ)
    (exhaust : thm)
    (case_const : typ -> term)
    (case_rews : thm list)
    (thy : theory) =
  let

    (* utility functions *)
    (* Combine two matchers with cpair_pat.  The constant's type is
       computed from the argument and result types of the matchers. *)
    fun mk_pair_pat (p1, p2) =
      let
        val T1 = fastype_of p1;
        val T2 = fastype_of p2;
        val (U1, V1) = apsnd dest_matchT (dest_cfunT T1);
        val (U2, V2) = apsnd dest_matchT (dest_cfunT T2);
        val pat_typ = [T1, T2] --->
            (mk_prodT (U1, U2) ->> mk_matchT (mk_prodT (V1, V2)));
        val pat_const = Const (const_namecpair_pat, pat_typ);
      in
        pat_const $ p1 $ p2
      end;
    (* Right-nested pairing of a list of matchers; the empty list yields
       the succeed constant at type unit. *)
    fun mk_tuple_pat [] = succeed_const Typeunit
      | mk_tuple_pat ps = foldr1 mk_pair_pat ps;

    (* define pattern combinators *)
    local
      val tns = map (fst o dest_TFree) (snd (dest_Type lhsT));

      (* Equation for the i-th constructor.  <con>_pat takes one matcher
         per constructor argument.  Its body is the case combinator whose
         i-th branch matches the arguments, as a tuple, with those
         matchers, and whose other branches return fail.  The result type
         variables 't1, 't2, ... are made fresh with respect to the
         datatype's type variables.
         NOTE(review): both "nonlazy" bindings below are unused. *)
      fun pat_eqn (i, (bind, (con, args))) : binding * term * mixfix =
        let
          val pat_bind = Binding.suffix_name "_pat" bind;
          val Ts = map snd args;
          val Vs =
              (map (K "'t") args)
              |> Old_Datatype_Prop.indexify_names
              |> Name.variant_list tns
              |> map (fn t => TFree (t, sortpcpo));
          val patNs = Old_Datatype_Prop.indexify_names (map (K "pat") args);
          val patTs = map2 (fn T => fn V => T ->> mk_matchT V) Ts Vs;
          val pats = map Free (patNs ~~ patTs);
          val fail = mk_fail (mk_tupleT Vs);
          val (vs, nonlazy) = get_vars_avoiding patNs args;
          val rhs = big_lambdas vs (mk_tuple_pat pats ` mk_tuple vs);
          fun one_fun (j, (_, args')) =
            let
              val (vs', nonlazy) = get_vars_avoiding patNs args';
            in if i = j then rhs else big_lambdas vs' fail end;
          val funs = map_index one_fun spec;
          val body = list_ccomb (case_const (mk_matchT (mk_tupleT Vs)), funs);
        in
          (pat_bind, lambdas pats body, NoSyn)
        end;
    in
      val ((pat_consts, pat_defs), thy) =
          define_consts (map_index pat_eqn (bindings ~~ spec)) thy
    end;

    (* syntax translations for pattern combinators *)
    (* For each constructor, three AST rules are added:
       - in a pattern position, the constructor applied to x1..xn parses
         as its combinator applied to the sub-patterns;
       - in a variable position, it parses as the _args tuple of x1..xn
         (_noargs when n = 0);
       - the combinator wrapped in _match prints back as the constructor
         applied to per-argument _match nodes. *)
    local
      fun syntax c = Lexicon.mark_const (fst (dest_Const c));
      fun app s (l, r) = Ast.mk_appl (Ast.Constant s) [l, r];
      val capp = app const_syntaxRep_cfun;
      val capps = Library.foldl capp

      fun app_var x = Ast.mk_appl (Ast.Constant "_variable") [x, Ast.Variable "rhs"];
      fun app_pat x = Ast.mk_appl (Ast.Constant "_pat") [x];
      fun args_list [] = Ast.Constant "_noargs"
        | args_list xs = foldr1 (app "_args") xs;
      fun one_case_trans (pat, (con, args)) =
        let
          val cname = Ast.Constant (syntax con);
          val pname = Ast.Constant (syntax pat);
          val ns = 1 upto length args;
          val xs = map (fn n => Ast.Variable ("x"^(string_of_int n))) ns;
          val ps = map (fn n => Ast.Variable ("p"^(string_of_int n))) ns;
          val vs = map (fn n => Ast.Variable ("v"^(string_of_int n))) ns;
        in
          [Syntax.Parse_Rule (app_pat (capps (cname, xs)),
            Ast.mk_appl pname (map app_pat xs)),
           Syntax.Parse_Rule (app_var (capps (cname, xs)),
            app_var (args_list xs)),
           Syntax.Print_Rule (capps (cname, ListPair.map (app "_match") (ps,vs)),
            app "_match" (Ast.mk_appl pname ps, args_list vs))]
        end;
      val trans_rules : Ast.ast Syntax.trrule list =
          maps one_case_trans (pat_consts ~~ spec);
    in
      val thy = Sign.add_trrules trans_rules thy;
    end;

    (* prove strictness and reduction rules of pattern combinators *)
    (* For a combinator pat with fresh matchers pats, pat_lhs builds
         fun1 = branch (pat pats) rhs  and  fun2 = branch (tuple of pats) rhs,
       where rhs is a free continuation into a fresh type 'r.
       - pat_strict proves that fun1 is strict.
       - pat_apps proves, for every constructor con', that fun1 applied to
         con' vs equals fun2 applied to the tuple of vs when con' is the
         combinator's own constructor, and fail otherwise.  The non-lazy
         arguments are assumed to be defined.
       All proofs unfold branch_def and the combinator definitions, then
       simplify with beta_ss, match_bind_simps and case_rews.
       NOTE(review): pat_strict does not use fun2, taken or con. *)
    local
      val tns = map (fst o dest_TFree) (snd (dest_Type lhsT));
      val rn = singleton (Name.variant_list tns) "'r";
      val R = TFree (rn, sortpcpo);
      fun pat_lhs (pat, args) =
        let
          val Ts = map snd args;
          val Vs =
              (map (K "'t") args)
              |> Old_Datatype_Prop.indexify_names
              |> Name.variant_list (rn::tns)
              |> map (fn t => TFree (t, sortpcpo));
          val patNs = Old_Datatype_Prop.indexify_names (map (K "pat") args);
          val patTs = map2 (fn T => fn V => T ->> mk_matchT V) Ts Vs;
          val pats = map Free (patNs ~~ patTs);
          val k = Free ("rhs", mk_tupleT Vs ->> R);
          val branch1 = Constbranch lhsT mk_tupleT Vs R;
          val fun1 = (branch1 $ list_comb (pat, pats)) ` k;
          val branch2 = Constbranch mk_tupleT Ts mk_tupleT Vs R;
          val fun2 = (branch2 $ mk_tuple_pat pats) ` k;
          val taken = "rhs" :: patNs;
        in (fun1, fun2, taken) end;
      fun pat_strict (pat, (con, args)) =
        let
          val (fun1, fun2, taken) = pat_lhs (pat, args);
          val defs = @{thm branch_def} :: pat_defs;
          val goal = mk_trp (mk_strict fun1);
          val rules = @{thms match_bind_simps} @ case_rews;
          fun tacs ctxt = [simp_tac (put_simpset beta_ss ctxt addsimps rules) 1];
        in prove thy defs goal (tacs o #context) end;
      fun pat_apps (i, (pat, (con, args))) =
        let
          val (fun1, fun2, taken) = pat_lhs (pat, args);
          fun pat_app (j, (con', args')) =
            let
              val (vs, nonlazy) = get_vars_avoiding taken args';
              val con_app = list_ccomb (con', vs);
              val assms = map (mk_trp o mk_defined) nonlazy;
              val rhs = if i = j then fun2 ` mk_tuple vs else mk_fail R;
              val concl = mk_trp (mk_eq (fun1 ` con_app, rhs));
              val goal = Logic.list_implies (assms, concl);
              val defs = @{thm branch_def} :: pat_defs;
              val rules = @{thms match_bind_simps} @ case_rews;
              fun tacs ctxt = [asm_simp_tac (put_simpset beta_ss ctxt addsimps rules) 1];
            in prove thy defs goal (tacs o #context) end;
        in map_index pat_app spec end;
    in
      val pat_stricts = map pat_strict (pat_consts ~~ spec);
      val pat_apps = flat (map_index pat_apps (pat_consts ~~ spec));
    end;

  in
    (pat_stricts @ pat_apps, thy)
  end

end

(*
Cut from HOLCF/Tools/domain_constructors.ML
in function add_domain_constructors:

    ( * define and prove theorems for pattern combinators * )
    val (pat_thms : thm list, thy : theory) =
      let
        val bindings = map #1 spec;
        fun prep_arg (lazy, sel, T) = (lazy, T);
        fun prep_con c (b, args, mx) = (c, map prep_arg args);
        val pat_spec = map2 prep_con con_consts spec;
      in
        add_pattern_combinators bindings pat_spec lhsT
          exhaust case_const cases thy
      end

*)

end