functor
  (M : MatrixSig) (A : sig
                         type arr = M.arr
                         type elt = M.elt
                         val empty : int array -> arr
                         val zeros : int array -> arr
                         val uniform : ?scale:elt -> int array -> arr
                         val gaussian : ?sigma:elt -> int array -> arr
                         val bernoulli :
                           ?p:float -> ?seed:int -> int array -> arr
                         val shape : arr -> int array
                         val numel : arr -> int
                         val reset : arr -> unit
                         val reshape : arr -> int array -> arr
                         val sum_slices : ?axis:int -> arr -> arr
                         val print : arr -> unit
                         val abs : arr -> arr
                         val neg : arr -> arr
                         val floor : arr -> arr
                         val ceil : arr -> arr
                         val round : arr -> arr
                         val sqr : arr -> arr
                         val sqrt : arr -> arr
                         val log : arr -> arr
                         val log2 : arr -> arr
                         val log10 : arr -> arr
                         val exp : arr -> arr
                         val sin : arr -> arr
                         val cos : arr -> arr
                         val tan : arr -> arr
                         val sinh : arr -> arr
                         val cosh : arr -> arr
                         val tanh : arr -> arr
                         val asin : arr -> arr
                         val acos : arr -> arr
                         val atan : arr -> arr
                         val asinh : arr -> arr
                         val acosh : arr -> arr
                         val atanh : arr -> arr
                         val sum : arr -> elt
                         val signum : arr -> arr
                         val l1norm : arr -> elt
                         val l2norm : arr -> elt
                         val l2norm_sqr : arr -> elt
                         val sigmoid : arr -> arr
                         val relu : arr -> arr
                         val clip_by_l2norm : elt -> arr -> arr
                         val pow : arr -> arr -> arr
                         val scalar_pow : elt -> arr -> arr
                         val pow_scalar : arr -> elt -> arr
                         val atan2 : arr -> arr -> arr
                         val scalar_atan2 : elt -> arr -> arr
                         val atan2_scalar : arr -> elt -> arr
                         val add : arr -> arr -> arr
                         val sub : arr -> arr -> arr
                         val mul : arr -> arr -> arr
                         val div : arr -> arr -> arr
                         val add_scalar : arr -> elt -> arr
                         val sub_scalar : arr -> elt -> arr
                         val mul_scalar : arr -> elt -> arr
                         val div_scalar : arr -> elt -> arr
                         val scalar_add : elt -> arr -> arr
                         val scalar_sub : elt -> arr -> arr
                         val scalar_mul : elt -> arr -> arr
                         val scalar_div : elt -> arr -> arr
                         type padding
                         val conv2d :
                           ?padding:padding -> arr -> arr -> int array -> arr
                         val conv2d_backward_input :
                           arr -> arr -> int array -> arr -> arr
                         val conv2d_backward_kernel :
                           arr -> arr -> int array -> arr -> arr
                         val conv3d :
                           ?padding:padding -> arr -> arr -> int array -> arr
                         val conv3d_backward_input :
                           arr -> arr -> int array -> arr -> arr
                         val conv3d_backward_kernel :
                           arr -> arr -> int array -> arr -> arr
                         val max_pool2d :
                           ?padding:padding ->
                           arr -> int array -> int array -> arr
                         val max_pool3d :
                           ?padding:padding ->
                           arr -> int array -> int array -> arr
                         val avg_pool2d :
                           ?padding:padding ->
                           arr -> int array -> int array -> arr
                         val avg_pool3d :
                           ?padding:padding ->
                           arr -> int array -> int array -> arr
                         val max_pool2d_backward :
                           padding ->
                           arr -> int array -> int array -> arr -> arr
                         val avg_pool2d_backward :
                           padding ->
                           arr -> int array -> int array -> arr -> arr
                       end->
  sig
    type arr = A.arr
    type mat = M.mat
    type elt = M.elt
    type padding = A.padding
    type trace_op
    type t =
        F of float
      | Arr of Owl_algodiff_generic.Make.arr
      | Mat of Owl_algodiff_generic.Make.mat
      | DF of Owl_algodiff_generic.Make.t * Owl_algodiff_generic.Make.t * int
      | DR of Owl_algodiff_generic.Make.t *
          Owl_algodiff_generic.Make.t Pervasives.ref *
          Owl_algodiff_generic.Make.trace_op * int Pervasives.ref * int
    module Maths :
      sig
        val add :
          Owl_algodiff_generic.Make.t ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val sub :
          Owl_algodiff_generic.Make.t ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val mul :
          Owl_algodiff_generic.Make.t ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val div :
          Owl_algodiff_generic.Make.t ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val dot :
          Owl_algodiff_generic.Make.t ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val pow :
          Owl_algodiff_generic.Make.t ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val atan2 :
          Owl_algodiff_generic.Make.t ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val min2 :
          Owl_algodiff_generic.Make.t ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val max2 :
          Owl_algodiff_generic.Make.t ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val cross_entropy :
          Owl_algodiff_generic.Make.t ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val neg : Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val abs : Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val signum :
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val floor :
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val ceil : Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val round :
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val sqr : Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val sqrt : Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val log : Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val log2 : Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val log10 :
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val exp : Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val sin : Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val cos : Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val tan : Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val sinh : Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val cosh : Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val tanh : Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val asin : Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val acos : Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val atan : Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val asinh :
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val acosh :
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val atanh :
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val sum : Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val average :
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val transpose :
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val l1norm :
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val l2norm :
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val l2norm_sqr :
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val sigmoid :
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val relu : Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val softplus :
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val softsign :
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val softmax :
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val ( + ) :
          Owl_algodiff_generic.Make.t ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val ( - ) :
          Owl_algodiff_generic.Make.t ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val ( * ) :
          Owl_algodiff_generic.Make.t ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val ( / ) :
          Owl_algodiff_generic.Make.t ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val ( *@ ) :
          Owl_algodiff_generic.Make.t ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val ( ** ) :
          Owl_algodiff_generic.Make.t ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val conv2d :
          ?padding:Owl_algodiff_generic.Make.padding ->
          Owl_algodiff_generic.Make.t ->
          Owl_algodiff_generic.Make.t ->
          int array -> Owl_algodiff_generic.Make.t
        val conv3d :
          ?padding:Owl_algodiff_generic.Make.padding ->
          Owl_algodiff_generic.Make.t ->
          Owl_algodiff_generic.Make.t ->
          int array -> Owl_algodiff_generic.Make.t
        val max_pool2d :
          Owl_algodiff_generic.Make.padding ->
          Owl_algodiff_generic.Make.t ->
          int array -> int array -> Owl_algodiff_generic.Make.t
        val avg_pool2d :
          Owl_algodiff_generic.Make.padding ->
          Owl_algodiff_generic.Make.t ->
          int array -> int array -> Owl_algodiff_generic.Make.t
        val reshape :
          Owl_algodiff_generic.Make.t ->
          int array -> Owl_algodiff_generic.Make.t
        val flatten :
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val mat_to_arr :
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val arr_to_mat :
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val dropout :
          ?rate:float ->
          ?seed:int ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
      end
    module Mat :
      sig
        val empty : int -> int -> Owl_algodiff_generic.Make.t
        val zeros : int -> int -> Owl_algodiff_generic.Make.t
        val uniform :
          ?scale:float -> int -> int -> Owl_algodiff_generic.Make.t
        val gaussian :
          ?sigma:float -> int -> int -> Owl_algodiff_generic.Make.t
        val shape : Owl_algodiff_generic.Make.t -> int * int
        val numel : Owl_algodiff_generic.Make.t -> int
        val row_num : Owl_algodiff_generic.Make.t -> int
        val col_num : Owl_algodiff_generic.Make.t -> int
        val reset : Owl_algodiff_generic.Make.t -> unit
        val reshape :
          int ->
          int -> Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val get :
          Owl_algodiff_generic.Make.t ->
          int -> int -> Owl_algodiff_generic.Make.t
        val set :
          Owl_algodiff_generic.Make.t ->
          int ->
          int -> Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val add :
          Owl_algodiff_generic.Make.t ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val sub :
          Owl_algodiff_generic.Make.t ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val mul :
          Owl_algodiff_generic.Make.t ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val div :
          Owl_algodiff_generic.Make.t ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val dot :
          Owl_algodiff_generic.Make.t ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val clip_by_l2norm :
          Owl_algodiff_generic.Make.t ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val mapi :
          (int ->
           int ->
           Owl_algodiff_generic.Make.elt -> Owl_algodiff_generic.Make.elt) ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val iter2_rows :
          (Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t -> unit) ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t -> unit
        val map_by_row :
          (Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t) ->
          Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
        val draw_rows2 :
          ?replacement:bool ->
          Owl_algodiff_generic.Make.t ->
          Owl_algodiff_generic.Make.t ->
          int ->
          Owl_algodiff_generic.Make.t * Owl_algodiff_generic.Make.t *
          int array
      end
    module Arr :
      sig
        val empty : int array -> Owl_algodiff_generic.Make.t
        val zeros : int array -> Owl_algodiff_generic.Make.t
        val uniform :
          ?scale:float -> int array -> Owl_algodiff_generic.Make.t
        val gaussian :
          ?sigma:float -> int array -> Owl_algodiff_generic.Make.t
        val shape : Owl_algodiff_generic.Make.t -> int array
        val numel : Owl_algodiff_generic.Make.t -> int
        val reset : Owl_algodiff_generic.Make.t -> unit
        val reshape :
          Owl_algodiff_generic.Make.t ->
          int array -> Owl_algodiff_generic.Make.t
      end
    val diff :
      (Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t) ->
      Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
    val diff' :
      (Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t) ->
      Owl_algodiff_generic.Make.t ->
      Owl_algodiff_generic.Make.t * Owl_algodiff_generic.Make.t
    val grad :
      (Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t) ->
      Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
    val grad' :
      (Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t) ->
      Owl_algodiff_generic.Make.t ->
      Owl_algodiff_generic.Make.t * Owl_algodiff_generic.Make.t
    val jacobian :
      (Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t) ->
      Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
    val jacobian' :
      (Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t) ->
      Owl_algodiff_generic.Make.t ->
      Owl_algodiff_generic.Make.t * Owl_algodiff_generic.Make.t
    val jacobianv :
      (Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t) ->
      Owl_algodiff_generic.Make.t ->
      Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
    val jacobianv' :
      (Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t) ->
      Owl_algodiff_generic.Make.t ->
      Owl_algodiff_generic.Make.t ->
      Owl_algodiff_generic.Make.t * Owl_algodiff_generic.Make.t
    val jacobianTv :
      (Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t) ->
      Owl_algodiff_generic.Make.t ->
      Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
    val jacobianTv' :
      (Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t) ->
      Owl_algodiff_generic.Make.t ->
      Owl_algodiff_generic.Make.t ->
      Owl_algodiff_generic.Make.t * Owl_algodiff_generic.Make.t
    val hessian :
      (Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t) ->
      Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
    val hessian' :
      (Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t) ->
      Owl_algodiff_generic.Make.t ->
      Owl_algodiff_generic.Make.t * Owl_algodiff_generic.Make.t
    val hessianv :
      (Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t) ->
      Owl_algodiff_generic.Make.t ->
      Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
    val hessianv' :
      (Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t) ->
      Owl_algodiff_generic.Make.t ->
      Owl_algodiff_generic.Make.t ->
      Owl_algodiff_generic.Make.t * Owl_algodiff_generic.Make.t
    val laplacian :
      (Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t) ->
      Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
    val laplacian' :
      (Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t) ->
      Owl_algodiff_generic.Make.t ->
      Owl_algodiff_generic.Make.t * Owl_algodiff_generic.Make.t
    val gradhessian :
      (Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t) ->
      Owl_algodiff_generic.Make.t ->
      Owl_algodiff_generic.Make.t * Owl_algodiff_generic.Make.t
    val gradhessian' :
      (Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t) ->
      Owl_algodiff_generic.Make.t ->
      Owl_algodiff_generic.Make.t * Owl_algodiff_generic.Make.t *
      Owl_algodiff_generic.Make.t
    val gradhessianv :
      (Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t) ->
      Owl_algodiff_generic.Make.t ->
      Owl_algodiff_generic.Make.t ->
      Owl_algodiff_generic.Make.t * Owl_algodiff_generic.Make.t
    val gradhessianv' :
      (Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t) ->
      Owl_algodiff_generic.Make.t ->
      Owl_algodiff_generic.Make.t ->
      Owl_algodiff_generic.Make.t * Owl_algodiff_generic.Make.t *
      Owl_algodiff_generic.Make.t
    val pack_flt :
      Owl_algodiff_generic.Make.elt -> Owl_algodiff_generic.Make.t
    val unpack_flt :
      Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.elt
    val pack_arr :
      Owl_algodiff_generic.Make.arr -> Owl_algodiff_generic.Make.t
    val unpack_arr :
      Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.arr
    val pack_mat :
      Owl_algodiff_generic.Make.mat -> Owl_algodiff_generic.Make.t
    val unpack_mat :
      Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.mat
    val tag : unit -> int
    val primal : Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
    val primal' : Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
    val adjval : Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
    val adjref :
      Owl_algodiff_generic.Make.t ->
      Owl_algodiff_generic.Make.t Pervasives.ref
    val tangent : Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
    val make_forward :
      Owl_algodiff_generic.Make.t ->
      Owl_algodiff_generic.Make.t -> int -> Owl_algodiff_generic.Make.t
    val make_reverse :
      Owl_algodiff_generic.Make.t -> int -> Owl_algodiff_generic.Make.t
    val reverse_prop :
      Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t -> unit
    val type_info : Owl_algodiff_generic.Make.t -> string
    val shape : Owl_algodiff_generic.Make.t -> int array
    val clip_by_l2norm :
      Owl_algodiff_generic.Make.elt ->
      Owl_algodiff_generic.Make.t -> Owl_algodiff_generic.Make.t
  end