Require Import Metarouting.Signatures.DecSetoid.
Require Import Metarouting.Language.Syntax.
Require Import Metarouting.Language.Semantics.
Require Import Metarouting.Logic.Logic.
Require Import Coq.Lists.List.
Require Import Coq.Bool.Bool.

(******************************************************************)
(*             Reflecting carriers for Ocaml code                 *)
(******************************************************************)

Section OcamlTypes.
   (* Codes for the OCaml types that carriers are reflected into:
    * unit, int, bool, binary sums, binary products and lists. *)
   Inductive ocamlTypes :=
      | Ocm_unit   : ocamlTypes
      | Ocm_int    : ocamlTypes
      | Ocm_bool   : ocamlTypes
      | Ocm_sum    : ocamlTypes -> ocamlTypes -> ocamlTypes
      | Ocm_prod   : ocamlTypes -> ocamlTypes -> ocamlTypes
      | Ocm_list   : ocamlTypes -> ocamlTypes.

   (* 
    *    ocamlTypes   ------->   Type
    *
    * Interprets a code as a Coq type; OCaml [int] is modelled by [nat].
    * This is defined before the notations below, so [unit], [bool] and
    * [list] in its body are the Coq types, not [ocamlTypes] codes.
    *)
   Fixpoint ocmType (x : ocamlTypes) : Type :=
      match x with
         | Ocm_unit     => unit
         | Ocm_int      => nat
         | Ocm_bool     => bool
         | Ocm_sum x y  => ((ocmType x) + (ocmType y))%type
         | Ocm_prod x y => ((ocmType x) * (ocmType y))%type
         | Ocm_list x   => list (ocmType x)
      end.

   (* Notations mirroring OCaml type syntax.  While [ocamlTypes_scope] is
    * open (for the rest of this section), [+], [*], [unit], [int], [bool]
    * and [list] build [ocamlTypes] codes rather than Coq types. *)
   Infix "+" := Ocm_sum (at level 50, left associativity) : ocamlTypes_scope.
   Infix "*" := Ocm_prod (at level 40, left associativity) : ocamlTypes_scope.
   Notation "'unit'" := (Ocm_unit) : ocamlTypes_scope.
   Notation "'int'" := (Ocm_int) : ocamlTypes_scope.
   Notation "'bool'" := (Ocm_bool) : ocamlTypes_scope.
   Notation "'list' A" := (Ocm_list A) (at level 30) : ocamlTypes_scope.
   
   Open Scope ocamlTypes_scope.

   (*
    *    Lang    ---------->    ocamlTypes
    *
    * Mutually recursive maps from each syntactic category of the language
    * (DS, SG, PO, OS, BS) to the code of the OCaml type of its carrier.
    * Finite sets, finite min-sets, sequences and multisets are all
    * represented as OCaml lists; ranges such as [dRange n] as [int].
    *)

   Fixpoint otDS (d : DS) : ocamlTypes :=
      match d with
         | dAddConstant x => (otDS x) + unit
         | dBool          => bool
         | dNat           => int
         | dProduct x y   => (otDS x) * (otDS y)
         | dRange n       => int
         | dUnion x y     => (otDS x) + (otDS y)
         | dUnit          => unit
         | dFSets x       => list (otDS x)
         | dFMinSets x    => list (otPO x)
         | dSeq x         => list (otDS x)
         | dSimpleSeq x   => list (otDS x)
         | dMultiSets x   => list (otDS x)
      end
   with otSG (s : SG) : ocamlTypes :=
      match s with
         | sBoolAnd           => bool
         | sBoolOr            => bool
         | sNatMax            => int
         | sNatMin            => int
         | sNatPlus           => int
         | sLex x y           => (otSG x) * (otSG y)
         | sProduct x y       => (otSG x) * (otSG y)
         | sRangeMax n        => int
         | sRangeMin n        => int
         | sRangePlus n       => int
         (* the sum of both carriers plus one extra [unit] element *)
         | sTopUnion x y      => ((otSG x) + (otSG y)) + unit
         | sUnion x y         => (otSG x) + (otSG y)
         (* note: the summands are swapped relative to [sUnion] *)
         | sUnionSwap x y     => (otSG y) + (otSG x)
         | sUnit              => unit
         | sFSetsIntersect x  => list (otDS x)
         | sFSetsUnion x      => list (otDS x)
         | sFSetsOp x         => list (otSG x)
         | sFMinSetsUnion x   => list (otPO x)
         | sFMinSetsOp x      => list (otOS x)
         | sLeft x            => otDS x
	 | sRight x            => otDS x
         | sSelLex x y        => (otSG x) * (otSG y)
         | sSeq x             => list (otDS x)
         | sPrefix x          => list (otDS x)
         | sPostfix x         => list (otDS x)
         | sSimpleSeq x       => list (otDS x)
	 | sRevOp x           => (otSG x)
         | sMultiSetsUnion x  => list (otDS x)
         | sMultiSetsIntersect x  => list (otDS x)
      end
   with otPO (p : PO) : ocamlTypes :=
      match p with
         | pDual x              => otPO x
         | pLeftNaturalOrder x  => otSG x
         | pRightNaturalOrder x => otSG x
         | pLex x y             => (otPO x) * (otPO y)
         | pNatLe               => int
         | pAnnTop x            => otSG x
      end
   with otOS (o : OS) : ocamlTypes :=
      match o with
         | oDual x               => otOS x
         | oLeftNaturalOrder x   => otSG x
         | oRightNaturalOrder x  => otSG x
         | oLex x y              => (otOS x) * (otOS y)
         | oBsLeftNaturalOrder x => otBS x
         | oSimpleSeq x          => list (otDS x)
      end
   with otBS (b : BS) : ocamlTypes :=
      match b with
         | bUnit           => unit
         | bBoolOrAnd      => bool
         | bNatMaxPlus     => int
         | bNatMinPlus     => int
         | bNatMaxMin      => int
         (* [int] extended with one extra [unit] element on the left;
          * NOTE(review): presumably that element is infinity - confirm
          * against Semantics. *)
         | bNatIMaxPlus    => unit + int
         | bNatIMinPlus    => unit + int
         | bNatIMaxMin     => unit + int
         | bRangeMaxPlus n => int
         | bRangeMinPlus n => int
         | bRangeMaxMin n  => int
         | bSwap x         => otBS x
         | bFMinSets x     => list (otOS x)
         | bFMinSetsOpUnion x => list (otOS x)
         | bFSets x        => list (otDS x)
         | bFSetsOp x      => list (otSG x)
         | bLex x y        => (otBS x) * (otBS y)
         | bProduct x y    => (otBS x) * (otBS y)
         | bLeft x         => otSG x
         (* the added zero sits on the right, the added one on the left *)
         | bAddZero x      => (otBS x) + unit
         | bAddOne x       => unit + (otBS x)
         | bSelLex x y     => (otBS x) * (otBS y)
         | bRevTimes x     => otBS x
         | bPrefixSeq x    => list (otDS x)
         | bPostfixSeq x   => list (otDS x)
         | bMultiSets x    => list (otDS x)
      end.
   
   (* Carrier codes for the transform (TF) category. *)
   Fixpoint otTF (x : TF) : ocamlTypes :=
      match x with
         | tId x        => otDS x
         | tReplace x   => otDS x
         | tProduct x y => (otTF x) * (otTF y)
         (* only [x]'s carrier is used; presumably the semantics requires
          * [x] and [y] to share a carrier (see the unionTransform case of
          * [otLang_correct]) *)
         | tUnion x y   => otTF x
         | tCayley x    => otSG x
      end.
   
   (* Carrier codes for the ST category. *)
   Fixpoint otST (x : ST) : ocamlTypes :=
      match x with
         | stLeft x     => otSG x
         | stRight x    => otSG x
         | stLex x y    => (otST x) * (otST y)
         | stSelLex x y => (otST x) * (otST y)
         (* as for [tUnion], only [x]'s carrier is used *)
         | stUnion x y  => otST x
         | stCayley x   => otBS x
      end.

   (* Dispatches on the category of a [Lang] term. *)
   Definition otLang (x : Lang) : ocamlTypes :=
      match x with
         | dsInc x => otDS x
         | sgInc x => otSG x
         | poInc x => otPO x
         | osInc x => otOS x
         | bsInc x => otBS x
         | tfInc x => otTF x
         | stInc x => otST x
      end.

   (*  We want to check that the following commutes if semantics exists
    * 
    *   Lang   ---------->  ocamlTypes
    *
    *     |                     |
    *     |                     |
    *     |                     |
    *     | SemRel              |
    *     |                     |
    *     v                     v
    *
    *   Sem    ---------->    Type
    *
    * i.e. for each category, whenever [x] is related to the semantic
    * object [a], the carrier of [a] is equal to [ocmType] of the code
    * computed for [x].
    *)

   Definition otDS_spec : Prop := forall x a, DsSemRel x a -> carrier a = ocmType (otDS x).
   Definition otSG_spec : Prop := forall x a, SgSemRel x a -> carrier a = ocmType (otSG x).
   Definition otPO_spec : Prop := forall x a, PoSemRel x a -> carrier a = ocmType (otPO x).
   Definition otOS_spec : Prop := forall x a, OsSemRel x a -> carrier a = ocmType (otOS x).
   Definition otBS_spec : Prop := forall x a, BsSemRel x a -> carrier a = ocmType (otBS x).
   Definition otTF_spec : Prop := forall x a, TfSemRel x a -> carrier a = ocmType (otTF x).
   Definition otST_spec : Prop := forall x a, StSemRel x a -> carrier a = ocmType (otST x).

   (* All seven specs hold simultaneously.  The proof uses the mutual
    * induction principle [SemRel_ind]; most cases are closed by rewriting
    * with one or two induction hypotheses, and the two remaining union
    * cases are closed by destructing the semantic records. *)
   Lemma otLang_correct :
      otDS_spec * otSG_spec * otPO_spec * otOS_spec * otBS_spec * otTF_spec * otST_spec.
   Proof. apply SemRel_ind; auto; simpl;
      try (intros x a r e; rewrite <- e; auto; fail);
      try (intros x a y b r e r' e'; rewrite <- e, <- e'; auto; fail).
      
      (* unionTransform *)
      intros x a y b r e r' e' iso;
      destruct a; destruct setoid; destruct b; destruct setoid; simpl in *; auto.
      
      (* unionSemigroupTransform *)
      intros x a y b r e r' e' iso;
      destruct a; destruct setoid; destruct b; destruct setoid; simpl in *; auto.      
   Qed.

End OcamlTypes.

(*********************************************************************)
(*            Build isomorphic setoids from split monos              *)
(*********************************************************************)

Section SplitMono.
   Set Implicit Arguments.

   (* A split monomorphism from [T] to [T']: an embedding [mono] paired
    * with a left inverse [mono_inv].  The law [mono_spec] says that
    * [mono_inv] undoes [mono] pointwise, so [mono] is injective. *)
   Record splitMono (T T' : Type) := Build_splitMono {
      mono      : T -> T';                              (* the embedding *)
      mono_inv  : T' -> T;                              (* its retraction *)
      mono_spec : forall x : T, mono_inv (mono x) = x   (* retraction law *)
   }.

End SplitMono.

Section OcamlAst.
   Set Implicit Arguments.

   (* Untyped abstract syntax for values of the types coded by [ocamlTypes].
    * Nothing ties an [ast] to a type; use [ast_typecheck] (boolean check)
    * or [ast_wfAst] (checked conversion) to relate the two. *)
   Inductive ast :=
      | Ast_unit : ast
      | Ast_bool : bool -> ast
      | Ast_int  : nat -> ast
      | Ast_inl  : ast -> ast
      | Ast_inr  : ast -> ast
      | Ast_prod : ast -> ast -> ast
      | Ast_list : list ast -> ast.

   (* Well-typed abstract syntax: [wfAst t] contains exactly the ASTs of
    * type [t].  Each constructor mirrors one [ast] constructor. *)
   Inductive wfAst : ocamlTypes -> Set :=
      | WfAst_unit : wfAst Ocm_unit
      | WfAst_bool : bool -> wfAst Ocm_bool
      | WfAst_int  : nat -> wfAst Ocm_int
      | WfAst_inl  : forall t1 t2, wfAst t1 -> wfAst (Ocm_sum t1 t2)
      | WfAst_inr  : forall t1 t2, wfAst t2 -> wfAst (Ocm_sum t1 t2)
      | WfAst_prod : forall t1 t2, wfAst t1 -> wfAst t2 -> wfAst (Ocm_prod t1 t2)
      | WfAst_list : forall t, list (wfAst t) -> wfAst (Ocm_list t).

   (* Erases the type index of a well-typed AST. *)
   Fixpoint ast_forget {t} (x : wfAst t) : ast :=
      match x with
         | WfAst_unit => Ast_unit
         | WfAst_bool b => Ast_bool b
         | WfAst_int n  => Ast_int n
         | WfAst_inl _ _ x => Ast_inl (ast_forget x)
         | WfAst_inr _ _ x => Ast_inr (ast_forget x)
         | WfAst_prod _ _ x y => Ast_prod (ast_forget x) (ast_forget y)
         | WfAst_list _ l => Ast_list (map ast_forget l)
      end.

   (* Boolean type check: does the untyped [x] have type [t]?
    * For a product, the first component is checked against [t1] and the
    * second component against [t2]; for a list, every element is checked
    * against the element type [t1].  Every recursive call is on a strict
    * subterm of [t], so the recursion is structural on [t]. *)
   Fixpoint ast_typecheck (t : ocamlTypes) (x : ast) : bool :=
      match t, x with
         | Ocm_unit, Ast_unit => true
         | Ocm_bool, Ast_bool _ => true
         | Ocm_int,  Ast_int _ => true
         | Ocm_sum t1 t2, Ast_inl x => ast_typecheck t1 x
         | Ocm_sum t1 t2, Ast_inr x => ast_typecheck t2 x
         | Ocm_prod t1 t2, Ast_prod x y => ast_typecheck t1 x && ast_typecheck t2 y
         | Ocm_list t1, Ast_list l => forallb (ast_typecheck t1) l
         | _, _ => false
      end.

   (* Regression checks: each product component is checked against its own
    * type, and list elements are checked against the element type. *)
   Example ast_typecheck_prod_ok :
      ast_typecheck (Ocm_prod Ocm_int Ocm_bool) (Ast_prod (Ast_int 1) (Ast_bool true)) = true.
   Proof. reflexivity. Qed.

   Example ast_typecheck_prod_bad_second :
      ast_typecheck (Ocm_prod Ocm_int Ocm_bool) (Ast_prod (Ast_int 1) (Ast_int 2)) = false.
   Proof. reflexivity. Qed.

   Example ast_typecheck_list_ok :
      ast_typecheck (Ocm_list Ocm_int) (Ast_list (Ast_int 1 :: Ast_int 2 :: nil)) = true.
   Proof. reflexivity. Qed.

   (* Typing error: the offending sub-AST, paired with the type it was
    * expected to have. *)
   Inductive notWT :=
      | NotWT : ast -> ocamlTypes -> notWT.

   (* Checked conversion of an untyped AST into a well-typed AST of type [t].
    * On failure it returns the offending sub-AST and its expected type.
    * If both components of a product are ill-typed, the first component's
    * error is reported.  List elements are folded from the right, so for a
    * list the right-most ill-typed element is reported. *)
   Fixpoint ast_wfAst (t : ocamlTypes) (x : ast) : wfAst t + notWT :=
      match t as t0, x return wfAst t0 + notWT with
         | Ocm_unit, Ast_unit       => inl notWT (WfAst_unit)
         | Ocm_bool, Ast_bool b     => inl notWT (WfAst_bool b)
         | Ocm_int,  Ast_int i      => inl notWT (WfAst_int i)
         | Ocm_sum t1 t2, Ast_inl x => match ast_wfAst t1 x with
                                         | inl x => inl notWT (WfAst_inl t2 x)
                                         | inr e => inr (wfAst (Ocm_sum t1 t2)) e
                                       end
         | Ocm_sum t1 t2, Ast_inr x => match ast_wfAst t2 x with
                                         | inl x => inl notWT (WfAst_inr t1 x)
                                         | inr e => inr (wfAst (Ocm_sum t1 t2)) e
                                       end
         | Ocm_prod t1 t2, Ast_prod x y => match ast_wfAst t1 x, ast_wfAst t2 y with
                                             | inl x, inl y => inl notWT (WfAst_prod x y)
                                             | inl _, inr t => inr (wfAst (Ocm_prod t1 t2)) t
                                             | inr t, _ => inr (wfAst (Ocm_prod t1 t2)) t
                                           end
         | Ocm_list t1, Ast_list l => let l := fold_right (fun x l => 
                                                    match l with
                                                      | inl l => 
                                                         match ast_wfAst t1 x with
                                                           | inl x => inl notWT (x :: l)
                                                           | inr e => inr (list (wfAst t1)) e
                                                         end
                                                      | inr e => inr (list (wfAst t1)) e
                                                    end
                                                 ) (inl notWT nil) l
                                      in
                                      match l with
                                         | inl l => inl notWT (WfAst_list l)
                                         | inr e => inr (wfAst (Ocm_list t1)) e
                                      end
         (* constructor does not match the expected type *)
         | t, x => inr (wfAst t) (NotWT x t)
      end.

   (* Embeds a Coq value of type [ocmType t] as a well-typed AST, by
    * recursion on [t].  Closed with [Defined] so that it computes. *)
   Definition toAst {t} : ocmType t -> wfAst t.
      refine (fix toAst {t} : ocmType t -> wfAst t := _).
      destruct t; simpl.
      intros []; apply WfAst_unit.                                      (* unit *)
      intros n; apply (WfAst_int n).                                    (* int  *)
      intros b; apply (WfAst_bool b).                                   (* bool *)
      intros [x|x]; [ apply WfAst_inl | apply WfAst_inr ]; auto.        (* sum  *)
      intros [x y]; apply WfAst_prod; auto.                             (* prod *)
      intros l; apply WfAst_list. apply (map (toAst t) l).              (* list *)
   Defined.

   (* Reads a Coq value back from a well-typed AST, by structural recursion
    * on the AST.  Closed with [Defined] so that it computes. *)
   Definition fromAst {t} : wfAst t -> ocmType t.
      refine (fix fromAst {t} (x : wfAst t) {struct x} : ocmType t := _).
      destruct x; simpl.
      apply tt.                                                         (* unit *)
      apply b.                                                          (* bool *)
      apply n.                                                          (* int  *)
      apply (inl _ (fromAst _ x)).                                      (* inl  *)
      apply (inr _ (fromAst _ x)).                                      (* inr  *)
      apply (fromAst _ x1, fromAst _ x2).                               (* prod *)
      apply (map (fromAst t) l).                                        (* list *)
   Defined.

   (* Parses an untyped AST at type [t] into a Coq value, or returns the
    * typing error reported by [ast_wfAst]. *)
   Definition ast_to_ocmType (t : ocamlTypes) (x : ast) : ocmType t + notWT :=
      match ast_wfAst t x with
         | inl x => inl _ (fromAst x)
         | inr e => inr _ e
      end.

   (* Serialises a Coq value of type [ocmType t] to an untyped AST. *)
   Definition ocmType_to_ast (t : ocamlTypes) (x : ocmType t) : ast :=
      ast_forget (toAst x).
   
   (**
     * [fromAst] is a left inverse of [toAst], so [toAst] is a split mono.
     *)
   Lemma wfWfAst_inv : forall {t} (x : ocmType t), fromAst (toAst x) = x.
   Proof. induction t; simpl.
      intros []; auto.
      intros n; auto.
      intros b; auto.
      intros [x | x]; simpl; rewrite ?IHt1, ?IHt2; auto.
      intros [x y]; simpl; rewrite IHt1, IHt2; auto.
      intros l; simpl.
      induction l; auto; simpl; rewrite IHt, IHl; auto.
   Qed.

   (* Packages [toAst], [fromAst] and [wfWfAst_inv] as a [splitMono] from
    * [ocmType t] to [wfAst t]. *)
   Definition astSplitMono {t} :=
      Build_splitMono
         (@toAst t)
         (@fromAst t)
         (@wfWfAst_inv t).

End OcamlAst.





