AN EXACT POLY-TIME MEMBERSHIP-QUERIES ALGORITHM FOR EXTRACTING A THREE-LAYER RELU NETWORK

Abstract

We consider the natural problem of learning a ReLU network from queries, which was recently remotivated by model extraction attacks. In this work, we present a polynomial-time algorithm that can learn a depth-two ReLU network from queries under mild general position assumptions. We also present a polynomial-time algorithm that, under mild general position assumptions, can learn a rich class of depth-three ReLU networks from queries. For instance, it can learn most networks where the number of first-layer neurons is smaller than the dimension and the number of second-layer neurons. These two results substantially improve the state of the art: until our work, polynomial-time algorithms were only shown to learn depth-two networks from queries under the assumption that either the underlying distribution is Gaussian (Chen et al. (2021)) or that the rows of the weight matrix are linearly independent (Milli et al. (2019)). For depth three or more, there were no known poly-time results.

1. INTRODUCTION

With the growth of neural-network-based applications, many commercial companies offer machine learning services, allowing public use of trained networks as a black box. Those services allow the user to query the model and, in some cases, return the exact output of the network so that users can reason about the model's output. Yet, the parameters of the model and its architecture are considered the companies' intellectual property, and they often do not wish to reveal them. Moreover, the training phase sometimes uses sensitive data, and as demonstrated in Zhang et al. (2020), inversion attacks can expose that sensitive data to anyone who holds the trained model. Nevertheless, the model is still vulnerable to membership-query attacks even as a black box. A recent line of works (Tramer et al. (2016), Shi et al. (2017), Milli et al. (2019), Rolnick & Körding (2020), Carlini et al. (2020), Fornasier et al. (2021)) showed either empirically or theoretically that, using a specific set of queries, one can reconstruct some hidden models. Theoretical work includes Chen et al. (2021), who proposed a novel algorithm that, under the Gaussian distribution, can approximate a two-layer model with ReLU activation in guaranteed polynomial time and query complexity without any further assumptions on the parameters. Likewise, Milli et al. (2019) have shown how to exactly extract the parameters of depth-two networks, assuming that the weight matrix has independent rows (in particular, the number of neurons is at most the input dimension). Our work extends theirs by showing:

1. A polynomial time and query complexity algorithm for exact reconstruction of a two-layer neural network with any number of hidden neurons, under mild general position assumptions.

2. A polynomial time and query complexity algorithm for exact reconstruction of a three-layer neural network under mild general position assumptions, with the additional assumptions that the number of first-layer neurons is smaller than the input dimension and that the second layer has non-zero partial derivatives. The last assumption is valid for most networks with more second-layer neurons than first-layer neurons.

The mild general position assumptions are further explained in section 3.3. However, we note that the proposed algorithm will work on any two-layer neural network except for a set of zero Lebesgue measure. Furthermore, it will work in polynomial time provided that the input weights are slightly perturbed (for instance, each weight is perturbed by adding a uniform number in [-2^{-d}, 2^{-d}]).

At a very high level, the basis of our approach is to find points at which the linearity of the network breaks and to extract neurons by recovering the affine transformations computed by the network near these points. This approach was taken by the previous theoretical papers Milli et al. (2019); Chen et al. (2021) and also in the empirical works of Carlini et al. (2020); Jagielski et al. (2019). In order to derive our results, we add several ideas to the existing techniques, including the ability to distinguish first- from second-layer neurons, which allows us to deal with three-layer networks, as well as the ability to correctly reconstruct the neurons in general depth-two networks of any finite width in polynomial time, without assuming that the rows are independent.

2. RESULTS

We next describe our results. Our results will assume a general position assumption quantified by a parameter δ ∈ (0, 1); a network that satisfies our assumption with parameter δ will be called δ-regular. This assumption is defined in section 3.3. We note, however, that a slight perturbation of the network weights, say, adding to each weight a uniform number in [-2^{-d}, 2^{-d}], guarantees that with probability 1 - 2^{-d} the network will be δ-regular with δ large enough to guarantee polynomial time complexity. Thus, δ-regularity is argued to be a mild general position assumption. Throughout the paper, we denote by Q the time it takes to make a single query.

2.1. DEPTH TWO NETWORKS

Consider a 2-layer network model given by

M(x) = Σ_{j=1}^{d_1} u_j ϕ(⟨w_j, x⟩ + b_j)    (1)

where ϕ(x) = x_+ = max(x, 0) is the ReLU function, and for any j ∈ [d_1], w_j ∈ R^d, b_j ∈ R, and u_j ∈ R. We assume that the w_j's, the b_j's and the u_j's, along with the width d_1, are unknown to the user, who has only black-box access to M(x) for any x ∈ R^d. We make no further assumptions on the network weights other than δ-regularity.

Theorem 1. There is an algorithm that, given oracle access to a δ-regular network as in equation 1, reconstructs it using O((d_1 log(1/δ) + d_1 d)Q + d^2 d_1) time and O(d_1 log(1/δ) + d_1 d) queries.

We note that by reconstruction we mean that the algorithm will find d'_1 and weights w'_0, ..., w'_{d'_1} ∈ R^d, b'_0, ..., b'_{d'_1} ∈ R, and u'_1, ..., u'_{d'_1} ∈ R such that

∀x ∈ R^d,  M(x) = ⟨w'_0, x⟩ + b'_0 + Σ_{j=1}^{d'_1} u'_j ϕ(⟨w'_j, x⟩ + b'_j).    (2)

We will also prove a similar result for the case where the algorithm is allowed to query the network only on points in R^d_+, while, on the other hand, equation 2 needs to be satisfied only for x ∈ R^d_+. This case is essential for reconstructing depth-three networks, and we will call it the R^d_+-restricted case.

Theorem 2. In the R^d_+-restricted case there is an algorithm that, given oracle access to a δ-regular network as in equation 1, reconstructs it using O((dd_1 log(1/δ) + d_1 d)Q + d^2 d_1^2) time and O(dd_1 log(1/δ) + d_1 d) queries.
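To make the query model concrete, here is a minimal sketch of such a black-box oracle; the specific weights W, b, u below are hypothetical and only for illustration, not taken from the paper.

```python
import numpy as np

def make_oracle(W, b, u):
    """Black-box oracle for the depth-two model M(x) = sum_j u_j * relu(<w_j, x> + b_j).
    The extractor may only call M; the parameters W, b, u stay hidden."""
    def M(x):
        return float(u @ np.maximum(W @ x + b, 0.0))
    return M

# Hypothetical hidden network: d = 3 inputs, d_1 = 2 neurons.
W = np.array([[1.0, -2.0, 0.5],
              [0.3,  1.0, -1.0]])
b = np.array([0.5, -0.2])
u = np.array([1.0, -1.0])
M = make_oracle(W, b, u)
```

The extraction algorithms below only ever interact with `M`, never with `W`, `b`, or `u` directly.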

2.2. DEPTH THREE NETWORKS

Consider a 3-layer network given by

M(x) = ⟨u, ϕ(V ϕ(W x + b) + c)⟩    (3)

where W ∈ R^{d_1×d}, b ∈ R^{d_1}, V ∈ R^{d_2×d_1}, c ∈ R^{d_2}, u ∈ R^{d_2}, and ϕ is the ReLU function applied element-wise. We assume W, V, u, b, c, along with d_1 and d_2, are unknown to the user, who has only black-box access to M(x) for any x ∈ R^d. Besides δ-regularity we will assume that (i) d_1 ≤ d and (ii) the top layer has non-zero partial derivatives: for the second-layer function F : R^{d_1} → R given by F(x) = ⟨u, ϕ(V x + c)⟩ we assume that for any x ∈ R^{d_1}_+ and j ∈ [d_1], the derivative of F in the direction of e^(j) and of -e^(j) is not zero. We note that if d_2 is large compared to d_1 (d_2 ≥ 3.5 d_1 would be enough), this assumption is valid for most choices of u, V and c (see theorem 5).

Theorem 3. There is an algorithm that, given oracle access to a δ-regular network as in equation 3, with d_1 ≤ d and a top layer with non-zero partial derivatives, reconstructs it using poly(d, d_1, d_2, log(1/δ)) time and queries.

By reconstruction we mean that the algorithm will find d'_1, d'_2 ∈ N, weights v'_0, ..., v'_{d'_2} ∈ R^{d'_1}, c'_0, ..., c'_{d'_2} ∈ R, u'_1, ..., u'_{d'_2} ∈ R, as well as a matrix W' ∈ R^{d'_1×d} and a vector b' ∈ R^{d'_1} such that

∀x ∈ R^d,  M(x) = ⟨v'_0, ϕ(W'x + b')⟩ + c'_0 + Σ_{j=1}^{d'_2} u'_j ϕ(⟨v'_j, ϕ(W'x + b')⟩ + c'_j).

2.3. NOVELTY OF THE RECONSTRUCTIONS

Having an exact reconstruction is an essential task for extracting a model. While approximate reconstructions, such as in Chen et al. (2021), may mimic the output of the extracted network, they cannot reveal information about the architecture, like the network's width. Moreover, an approximate reconstruction can be viewed as a regression task; for example, the work of Shi et al. (2017) used Naive Bayes and SVM models to predict the network's output. An exact reconstruction requires building new tools, as we provide in this work. Exploring the non-linear regions of a network can offer information on the relations between the weights of a neuron up to a multiplicative factor. Specifically, the sign of a neuron is missing. Indeed, for the j'th neuron both (w_j, b_j) and (-w_j, -b_j) have the property of breaking the linearity of M(x) at the same values of x. To recover the global signs of all the neurons, one must either restrict the width of the network (as in Milli et al. (2019)) or use brute force over all possible sign combinations (as in Carlini et al. (2020) and Rolnick & Körding (2020)). We bypass this challenge by allowing reconstruction up to an affine transformation and using the fact that for all x ∈ R^d, ⟨w, x⟩ + b = ϕ(⟨w, x⟩ + b) - ϕ(-⟨w, x⟩ - b). This bypass allows the reconstruction of a network of any finite width in polynomial time. Another technical novelty of the paper is an algorithm that can identify whether a neuron belongs to the first or the second layer. This allows us to handle a second hidden layer after peeling off the first layer.
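The identity behind the bypass is easy to check numerically; a small sanity check of our own (not from the paper):

```python
import numpy as np

relu = lambda z: np.maximum(z, 0.0)

# For every z, z = relu(z) - relu(-z). Applied to z = <w,x> + b, a neuron
# recovered only up to sign can therefore be corrected by a global affine term.
rng = np.random.default_rng(0)
w, b = rng.normal(size=5), 0.7
for _ in range(100):
    x = rng.normal(size=5)
    z = w @ x + b
    assert np.isclose(z, relu(z) - relu(-z))
```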

3.1. NOTATIONS AND TERMINOLOGY

We denote by e^(1), ..., e^(d) the standard basis of R^d. The state of a neuron at a point x ∈ R^d is the sign of the input to that neuron (either positive, negative, or zero). The state of the network at a point x ∈ R^d is a description of the states of all neurons at x. Similarly, the state of the first layer at x ∈ R^d is a description of the states of all first-layer neurons at x. The angle between a hyperplane P with normal vector n and a line {tx + y : t ∈ R} (or just a vector x ≠ 0) is defined as |⟨x/∥x∥, n⟩|. Likewise, the distance between two hyperplanes P_1, P_2 with normal vectors n_1, n_2, respectively, is given by D(P_1, P_2) := 1 - ⟨n_1, n_2⟩^2. We say that a hyperplane is δ-general if its angle with all of the d axes is at least δ. A hyperplane is general if it is δ-general for some δ > 0 (equivalently, it is not parallel to any axis).

3.2. PIECEWISE LINEAR FUNCTIONS

Let f : R^d → R be piecewise linear, with finitely many pieces. A general point is a point x ∈ R^d such that there exists a neighborhood of x on which f is affine. Furthermore, we say that a point x ∈ R^d is a δ-general point if f is affine on B(x, δ). Complementarily, a critical point is a point x ∈ R^d such that for every δ > 0, f is not affine on B(x, δ). A critical hyperplane is an affine hyperplane P whose intersection with the set of critical points is of dimension d - 1. For a critical hyperplane P, we say that a point x ∈ R^d is P-critical if it is critical and x ∈ P. Figure 3.2 illustrates the above definitions for the one-dimensional input case. Note that there are finitely many critical hyperplanes for any piecewise linear function, that any critical point belongs to at least one critical hyperplane, and that most critical points belong to exactly one critical hyperplane. We will call such points non-degenerate. Furthermore, we will say that a critical point x is δ-non-degenerate if exactly one critical hyperplane intersects B(x, δ). For the function M computed by a network such as equation 1 or equation 3, we note that for any j ∈ [d_1], the hyperplane P_j = {x : ⟨w_j, x⟩ + b_j = 0} is a critical hyperplane. In this case, we say that P_j corresponds to the j'th neuron, and vice versa. Also, if x is a critical point, then at least one of the neurons is in a critical state (i.e., its input is 0). In this case, we will say that x is a critical point of that neuron. We next describe a few simple algorithms related to piecewise linear functions that we will use frequently. Their correctness is proved in section D of the appendix; here we briefly sketch the idea behind them.

3.2.1. RECONSTRUCTION OF AN AFFINE FUNCTION

We note that if x is an ϵ-general point of a function f , then one can reconstruct the affine function f computes over B(x, ϵ) with d + 1 queries in B(x, ϵ) and O(dQ) time. Algorithm 2 reconstructs the desired affine function.
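A possible Python rendering of this step (our own sketch of the AFFINE primitive, assuming f really is affine on B(x, ϵ)):

```python
import numpy as np

def affine(f, x, eps):
    """Recover (w, b) with f(y) = <w, y> + b on B(x, eps), using the d+1 queries
    f(x) and f(x + eps*e_i): each coordinate of w is a finite difference."""
    d = len(x)
    fx = f(x)
    w = np.array([(f(x + eps * np.eye(d)[i]) - fx) / eps for i in range(d)])
    return w, fx - w @ x
```

For an exactly affine f the finite differences are exact up to floating-point rounding, so no step-size tuning is needed.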

3.2.2. RECONSTRUCTION OF CRITICAL POINTS IN ONE DIMENSION

We say that a piecewise linear one-dimensional function f : R → R is δ-nice if: (1) all its critical points are in [-1/δ, 1/δ] \ (-δ, δ), (2) each piece is of length at least δ, (3) no two pieces share the same affine function, and (4) all the points in the grid 2^{-⌈log_2(2/δ^2)⌉} δ Z are δ^2-general. Given a δ-nice function, algorithm 1 recovers the left-most critical point in the range (a, 1/δ), if such a point exists, using O(log(1/δ)Q) time. In short, the algorithm works similarly to a binary search, where each iteration splits the current range into two halves and keeps the left half if and only if it is not affine.

Algorithm 1 FIND_CP(δ, f, a): single critical point reconstruction
Input: a parameter δ < 1, black-box access to a δ-nice f : R → R, and a left limit a ∈ [-1/δ, 1/δ]
Output: the left-most critical point of f in (a, 1/δ)
1: Set x_L = -1/δ, x_R = 1/δ
2: for j = 1, ..., ⌈log_2(2/δ^2)⌉ + 1 do
3:   If (x_L + x_R)/2 ≤ a or AFFINE_{δ^2}(f, x_L) = AFFINE_{δ^2}(f, (x_L + x_R)/2), set x_L = (x_L + x_R)/2. Else, set x_R = (x_L + x_R)/2.
4: end for
5: Let Λ_L = AFFINE_{δ^2}(f, x_L) and Λ_R = AFFINE_{δ^2}(f, x_R). If Λ_L = Λ_R then return "no critical points in (a, 1/δ)". Else, return the point x for which Λ_L(x) = Λ_R(x).

With algorithm 1 we can reconstruct all the critical points of f in a given range (a, b) ⊂ (-1/δ, 1/δ) in O(k log(1/δ)Q) time, where k is the number of critical points in (a, b). Indeed, we can invoke algorithm 1 to find the left-most critical point x_1 in (a, b), then the one on its right, and so on, until there are no more critical points in (a, b).
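The bisection idea behind FIND_CP can be sketched as follows. This simplified version (our own) assumes exactly one critical point in (lo, hi) and drops the left-limit bookkeeping of the full algorithm:

```python
import numpy as np

def local_affine(f, x, eps=1e-6):
    # Slope and intercept of f near x (assumes f is affine on [x - eps, x + eps]).
    s = (f(x + eps) - f(x - eps)) / (2 * eps)
    return s, f(x) - s * x

def find_cp(f, lo, hi, iters=10):
    """Bisection for the kink of a piecewise linear f with a single critical
    point in (lo, hi): discard the left half iff it is affine, then return the
    intersection of the two affine pieces, as in step 5 of algorithm 1."""
    for _ in range(iters):
        mid = (lo + hi) / 2
        if np.allclose(local_affine(f, lo), local_affine(f, mid)):
            lo = mid          # no kink in [lo, mid]; keep the right half
        else:
            hi = mid          # kink in [lo, mid]; keep the left half
    (aL, bL), (aR, bR) = local_affine(f, lo), local_affine(f, hi)
    return (bR - bL) / (aL - aR)  # point where the two affine pieces meet
```

Returning the intersection of the two local affine maps, rather than the midpoint, recovers the critical point to far better precision than the bisection interval itself.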

3.2.3. RECONSTRUCTION OF A CRITICAL HYPERPLANE

Let f : R^d → R be a piecewise linear function. Assume that x is a δ-non-degenerate P-critical point. If x_1, x_2 ∈ B(x, δ) are two points on opposite sides of P, then P is the zero set of Λ_1 - Λ_2, where Λ_1, Λ_2 are the affine functions computed by f near x_1 and x_2. Algorithm 3 therefore reconstructs P in O(dQ) time.
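A sketch of this step in Python (our own rendering of the FIND_HP primitive; the step sizes are illustrative):

```python
import numpy as np

def affine(f, x, eps):
    # Local affine map of f near a general point x (d+1 queries).
    d = len(x)
    fx = f(x)
    w = np.array([(f(x + eps * np.eye(d)[i]) - fx) / eps for i in range(d)])
    return w, fx - w @ x

def find_hp(f, x, delta, eps=1e-6):
    """Recover (w, b), up to sign and scale, such that the critical hyperplane
    through x is {y : <w, y> + b = 0}: the affine maps on the two sides of the
    hyperplane differ exactly by a multiple of (w, b)."""
    e1 = np.eye(len(x))[0]
    w1, b1 = affine(f, x + (delta / 2) * e1, eps)
    w2, b2 = affine(f, x - (delta / 2) * e1, eps)
    return w1 - w2, b1 - b2
```

Stepping along e^(1) works whenever the hyperplane is not parallel to that axis, which is exactly what δ-generality guarantees.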

3.2.4. CHECKING CONVEXITY/CONCAVITY IN A δ-NON-DEGENERATE CRITICAL POINT

Let f : R^d → R be a piecewise linear function. Assume that x is a δ-non-degenerate P-critical point. As exactly two affine pieces of f meet at P, f is necessarily convex or concave on B(x, δ). Furthermore, for any unit vector e that is not parallel to P, f is convex on B(x, δ) if and only if it is convex on [x - δe, x + δe], in which case the slope of t → f(x + te) on [-δ, 0] is strictly smaller than its slope on [0, δ]. Algorithm 4 therefore determines whether f is convex or concave on B(x, δ) in O(Q) time.

3.2.5. DISTINGUISHING AN ϵ-GENERAL POINT FROM AN ϵ-NON-DEGENERATE CRITICAL POINT

Let f : R^d → R be a piecewise linear function. Assume that x is either an ϵ-non-degenerate P-critical point or an ϵ-general point. Then, by the definitions, for any unit vector e that is not parallel to P, x is critical if and only if the slope of t → f(x + te) differs between the segments [-ϵ, 0] and [0, ϵ]. Algorithm 5 therefore determines whether x is critical in O(Q) time.
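Restricted to a line through x, which is all these tests ever query, both checks reduce to comparing one-sided slopes. A small sketch of our own, with illustrative tolerances:

```python
def is_convex(f, x, delta):
    """IS_CONVEX sketch for f restricted to a line: at a non-degenerate critical
    point, f is locally a max (convex) or min (concave) of two affine maps, so
    it suffices to compare the two one-sided slopes."""
    return f(x + delta) - f(x) > f(x) - f(x - delta)

def is_general(f, x, eps, tol=1e-12):
    # IS_GENERAL sketch: x is general iff the two one-sided slopes agree.
    return abs((f(x + eps) - f(x)) - (f(x) - f(x - eps))) < tol
```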

3.3. GENERAL POSITION ASSUMPTION

We say that a two-layer network as in equation 1 is δ-regular if the conditions for the inputs of algorithms 2-5 are met for the network and for any critical point that lies on the standard axes. For a three-layer network, as in equation 3, we also require that the above apply to the sub-network defined by the top two layers. A two- or three-layer network is called regular if it is δ-regular for some δ > 0.

Algorithm 2 AFFINE_ϵ(f, x): affine map reconstruction from an ϵ-general point
Input: black-box access to a piecewise linear f : R^d → R, a parameter ϵ > 0, and an ϵ-general point x ∈ R^d
Output: a vector w ∈ R^d and b ∈ R such that ∀y ∈ B(x, ϵ), Λ_{w,b}(y) = f(y)
1: Return w_i = (f(x + ϵe^(i)) - f(x))/ϵ and b = f(x) - Σ_{i=1}^d ((f(x + ϵe^(i)) - f(x))/ϵ) x_i

Algorithm 3 FIND_HP(f, δ, x): reconstruction of a critical hyperplane
Input: black-box access to a piecewise linear f : R^d → R, a parameter δ > 0, and a δ-non-degenerate P-critical point x ∈ R^d for a δ-general P
Output: w ∈ R^d and b ∈ R such that P = {x : Λ_{w,b}(x) = 0}
1: Set ϵ = (δ/2)^2
2: Using algorithm 2, obtain (w_1, b_1) = AFFINE_ϵ(f, x + (δ/2)e^(1)) and (w_2, b_2) = AFFINE_ϵ(f, x - (δ/2)e^(1))
3: Return w = w_1 - w_2 and b = b_1 - b_2

Algorithm 4 IS_CONVEX(f, δ, x): checking convexity/concavity
Input: black-box access to a piecewise linear f : R^d → R, a parameter δ > 0, and a δ-non-degenerate P-critical point x ∈ R^d for a general P
Output: is f convex at B(x, δ)?
1: if f(x + δe^(1)) - f(x) > f(x) - f(x - δe^(1)) then
2:   Return "convex"
3: else
4:   Return "concave"
5: end if

Algorithm 5 IS_GENERAL(f, ϵ, x): distinguishing a general point from a critical point
Input: black-box access to a piecewise linear f : R^d → R, a parameter ϵ > 0, and a point x that is either ϵ-general or an ϵ-non-degenerate P-critical point for a general P
Output: is x general?
1: if f(x + ϵe^(1)) - f(x) = f(x) - f(x - ϵe^(1)) then return "general", else return "critical"
2: end if
A network is in general position if it is regular and, for three-layer networks as in equation 3, W is also surjective and the top layer does not have zero partial derivatives. A formal definition of a δ-regular network is given in section A of the appendix. Here we state sufficient conditions that ensure the regularity and general position of a network. The proofs are given in section A of the appendix.

Lemma 1. The set of non-regular neural networks as in equation 1 and equation 3 has zero Lebesgue measure.

Lemma 2. Let M be a neural network as in equation 1 or equation 3. Let q be the number of neurons in the network, and let M > 0 be an upper bound on the absolute value of the weights. To each weight in the network, add a uniform element in [-2^{-d}, 2^{-d}]. Then, the noisy network M' is δ-regular for δ > 0 such that log(1/δ) = poly(d log(qM)) with probability 1 - 2^{-d}.

Lemma 3. For a general three-layer network as in equation 3, if d_1 ≤ d then W is surjective with probability 1.

Lemma 4. For a general three-layer network as in equation 3, if 3.5d_1 ≤ d_2 then the top layer has non-zero partial derivatives with probability 1 - o(1).

We note that the assumptions in section A may seem lengthy. The keen reader may notice overlaps between some of them and might suggest approaches to avoid others, for example, by adding randomization to the queries. Yet, we keep them as is for the fluency of reading, to emphasize the main concepts of the extraction. As training a network in practice begins from a random initialization, it is very likely for the network to be found in a regular position after the learning phase. Therefore, we take the liberty of ignoring unlikely positions instead of combining them under a very restrictive rule.

3.4. RECONSTRUCTION OF DEPTH TWO NETWORK -SKETCH PROOF OF THEOREMS 1 AND 2

Recall that our goal is to recover a depth-two network of the form of equation 1. We will assume without loss of generality that the u_i's are in {±1}, as any neuron x → uϕ(⟨w, x⟩ + b) computes the same function as x → (u/|u|)ϕ(⟨|u|w, x⟩ + |u|b), since ReLU is positively homogeneous. Our algorithm will first find a critical point for each neuron. For a regular network, each critical hyperplane intersects the axis Re^(1) exactly once, so we can reconstruct such a set of critical points by invoking algorithm 1 on the function t → M(te^(1)). We next reconstruct a single neuron corresponding to a given critical point x. For simplicity, assume that x is a δ-critical point of the j'th neuron. Using algorithm 3 we find an affine function Λ such that Λ = Λ_{w_j,b_j} or Λ = -Λ_{w_j,b_j}. Then, to recover u_j, note that if u_j = 1 then M is strictly convex on B(x, δ), as the function u_j ϕ(⟨w_j, x⟩ + b_j) is convex. Similarly, if u_j = -1 then M is strictly concave on B(x, δ). Thus, we recover u_j using algorithm 4. Finally, note that ϕ(Λ(x)) is either ϕ(⟨w_j, x⟩ + b_j) or ϕ(⟨w_j, x⟩ + b_j) - ⟨w_j, x⟩ - b_j. Hence, u_j ϕ(Λ(x)) equals u_j ϕ(Λ_{w_j,b_j}(x)) up to an affine map. The approach is detailed in algorithm 6.

Algorithm 6 Recover depth-two network
Input: a parameter δ and black-box access to a δ-regular network M as in equation 1
Output: weights such that for all x, M(x) = Λ_{w'_0,b'_0}(x) + Σ_{i=1}^m u'_i ϕ(Λ_{w'_i,b'_i}(x))
1: Use FIND_CP(δ, t → M(te^(1)), ·) repeatedly to find all the critical points on the axis {te^(1) : t ∈ R} (see section 3.2). Denote these points by x_1, ..., x_m.
2: for i = 1, ..., m do
3:   Compute (w'_i, b'_i) = FIND_HP(M, δ, x_i).
4:   If IS_CONVEX(M, δ, x_i) = "convex" then set u'_i = 1. Else, set u'_i = -1.
5: end for
6: Compute (w'_0, b'_0) = AFFINE_δ(x → M(x) - Σ_{i=1}^m u'_i ϕ(Λ_{w'_i,b'_i}(x)), x') for a random x' ∈ R^d.
7: Return the function x → Λ_{w'_0,b'_0}(x) + Σ_{i=1}^m u'_i ϕ(Λ_{w'_i,b'_i}(x)).

The following theorem proves the correctness of algorithm 6 and implies theorem 1. The proof is given in section B of the appendix.

Theorem 4. Algorithm 6 reconstructs a δ-regular network in time O((log(1/δ) + d)d_1 Q + d^2 d_1).
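Steps 2-6 of algorithm 6 can be sketched end to end in a few lines. This toy version (our own; helper names, tolerances, and the "generic point" choice are illustrative, not from the paper) assumes step 1 has already produced one critical point per neuron on the e^(1) axis, and recovers the network up to the global affine correction discussed above:

```python
import numpy as np

relu = lambda z: np.maximum(z, 0.0)

def affine_at(M, x, eps=1e-6):
    # Local affine map of M near a general point x (d+1 queries).
    d = len(x)
    fx = M(x)
    w = np.array([(M(x + eps * np.eye(d)[i]) - fx) / eps for i in range(d)])
    return w, fx - w @ x

def recover_depth2(M, crit_pts, delta=1e-3):
    """Given one critical point per neuron (step 1 is assumed done), recover each
    hyperplane up to sign/scale (FIND_HP), read the outer sign from local
    convexity (IS_CONVEX), and fit the residual affine part (step 6)."""
    d = len(crit_pts[0])
    e1 = np.eye(d)[0]
    neurons = []
    for x in crit_pts:
        w1, b1 = affine_at(M, x + delta * e1)
        w2, b2 = affine_at(M, x - delta * e1)
        w, b = w1 - w2, b1 - b2                     # +-(scaled) neuron weights
        convex = M(x + delta * e1) + M(x - delta * e1) > 2 * M(x)
        neurons.append((1.0 if convex else -1.0, w, b))
    partial = lambda x: sum(u * relu(w @ x + b) for u, w, b in neurons)
    # The mismatch caused by the sign/scale ambiguity is globally affine, so a
    # single generic point suffices to fit the correction term.
    x0 = np.full(d, 0.123)
    w0, b0 = affine_at(lambda x: M(x) - partial(x), x0)
    return lambda x: w0 @ x + b0 + partial(x)
```

Note that the recovered per-neuron weights carry an unknown sign and scale; as argued above, the resulting discrepancy is absorbed by the affine term (w'_0, b'_0).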

3.4.1. SKETCH PROOF OF THEOREM 2

Our algorithm for reconstructing depth-two networks can easily be modified to work in the R^d_+-restricted setting, with the difference that, in order to reconstruct a δ-critical point for each neuron (step 1 in algorithm 6), we will need to search in the range [0, 1/δ]e^(i) for all i ∈ [d], as a critical hyperplane of a given neuron might not intersect R_+e^(1). Because of this change, each neuron might be discovered several times (up to d times), and we will need an additional step that combines neurons with the same affine map (up to a sign). In the particular case where a neuron has no critical points in the positive orthant, one can ignore it without affecting equation 2 for all x ∈ R^d_+. These changes result in O(dd_1 log(1/δ)Q) instead of O(d_1 log(1/δ)Q) for step 1, O(d^2 d_1 Q) instead of O(dd_1 Q) for the loop, and O(d_1^2 d^2) for combining similar neurons. The total runtime will therefore be O((log(1/δ) + d)dd_1 Q + d^2 d_1^2). A formal proof is given in section B of the appendix.

3.5. RECONSTRUCTION OF DEPTH THREE NETWORK -SKETCH PROOF OF THEOREM 3

Recall that our goal is to recover a δ-regular network of the form M(x) = ⟨u, ϕ(V ϕ(W x + b) + c)⟩. We denote by w_j the j'th row of W and assume without loss of generality that it is of unit norm, as any neuron of the form x → ϕ(⟨w, x⟩ + b) can be replaced by x → ∥w∥ϕ(⟨w/∥w∥, x⟩ + b/∥w∥). Likewise, and similarly to our algorithm for reconstructing depth-two networks, we will assume that u ∈ {±1}^{d_2}. The algorithm decomposes into four steps, described in the following four subsections. In the first step, we extract a set of critical hyperplanes that contains all the critical hyperplanes corresponding to first-layer neurons. In the second step, we prune this list and are left with a list containing precisely the critical hyperplanes that correspond to first-layer neurons. In the third step, we use this list to recover the first layer. Once the first layer is recovered, in the fourth step we recover the second layer via a reduction to the problem of recovering a depth-two network.

3.5.1. EXTRACTING A SET CONTAINING THE CRITICAL HYPERPLANES OF THE FIRST LAYER

For the first step, we find a list L = ((x_1, P_1), ..., (x_m, P_m)) of pairs such that:

• for each k, P_k is a critical hyperplane of M and x_k is a δ-non-degenerate critical point whose critical hyperplane is P_k;
• the list contains all the critical hyperplanes of first-layer neurons.

We find the points x_1, ..., x_m by searching for critical points along the axis Re^(1) (see section 3.2.2); the corresponding critical hyperplanes P_1, ..., P_m are then found using algorithm 3. Note that m = O(d_1 d_2) (e.g. Telgarsky (2016)). Finally, lemma 8 below, together with δ-regularity, implies that every hyperplane P that corresponds to a first-layer neuron intersects Re^(1) exactly once, and this intersection point is a δ-non-degenerate P-critical point.

3.5.2. IDENTIFYING FIRST LAYER CRITICAL HYPERPLANES

The next step is to take the list L = ((x_1, P_1), ..., (x_m, P_m)) from the previous step, verify all the hyperplanes corresponding to first-layer neurons, and remove all the other hyperplanes. The idea behind this verification is simple: if P corresponds to a neuron in the first layer, then any point in P is a critical point of M (see lemma 8). On the other hand, if P corresponds to a neuron in the second layer, then not all its points are critical for M. Moreover, intersections with hyperplanes from the first layer change the input to the second-layer neurons, hence creating a new piece that replaces P. Thus, in order to verify whether P corresponds to a first-layer neuron, we go over all the hyperplanes P_k ∈ L, and for each of them find a point x' ∈ P that is on the opposite side of P_k (relative to x) and check whether it is critical. If it is not critical for one of the hyperplanes, we know that P does not correspond to a first-layer neuron. If all the points we examined are critical, even for the P_k corresponding to first-layer neurons, then x' is critical, which means that P must correspond to a first-layer neuron. Algorithm 7 implements this idea. There is one caveat that we need to handle: the examined point has to be generic enough in order to test whether it is critical using algorithm 5. To make sure the point is general enough, we slightly perturb it. The correctness of the algorithm follows from lemmas 10 and 11 below. Due to the perturbations, the algorithm has a success probability of at least 1 - 2^{-d}/m over the choice of x' for each hyperplane, and at least 1 - 2^{-d} over all hyperplanes. Each step in the for-loop takes O(dQ) operations. As the list size is O(d_1 d_2), the total running time over all hyperplanes is O(d_1^2 d_2^2 dQ).

3.5.3. IDENTIFYING DIRECTIONS

Since the rows of W are assumed to have unit norm, the list of critical hyperplanes of the first-layer neurons, obtained in the previous step, determines the weights up to sign. In order to recover the correct sign of (ŵ_1, b̂_1), we can simply do the following test: choose a point x such that ⟨ŵ_1, x⟩ + b̂_1 = 0, and query the network at the points x + ϵz and x - ϵz, for small ϵ, where z ∈ R^d is a unit vector that is orthogonal to ŵ_2, ..., ŵ_{d_1} but satisfies ⟨ŵ_1, z⟩ > 0. If we assume that W is right invertible, then such a z exists, as w_1, ..., w_{d_1} are linearly independent. Let out(x) be the output of the first layer given a point x. Then

out(x + ϵz) = (⟨w_1, x + ϵz⟩ + b_1, ..., ⟨w_{d_1}, x + ϵz⟩ + b_{d_1}) = out(x) + ϵ⟨w_1, z⟩e^(1).

Therefore, when moving from x to either x + ϵz or x - ϵz, only the first neuron changes, and after the ReLU activation, only the positive direction returns a different value. Hence, in order to get the correct sign we can do the following: if M(x) ≠ M(x + ϵz), then keep (ŵ_1, b̂_1); else, replace it with (-ŵ_1, -b̂_1). We repeat this method for all the neurons j ∈ [d_1]. The above method fails in the special case where both V ϕ(out(x + ϵz)) + c ≤ 0 and V ϕ(out(x - ϵz)) + c ≤ 0, which occurs if the partial derivatives of the top layer are zero at ϕ(out(x + ϵz)) and ϕ(out(x - ϵz)). As we showed in section 3.3, this is not expected if the second layer is wide enough. The runtime of this step is O(d_1^3 d + d_1 Q), as finding z requires Gram-Schmidt, which takes O(d_1^2 d) per neuron, plus two additional queries per neuron to find the sign.

Algorithm 7 Identifying whether a critical hyperplane corresponds to the first layer
Input: black-box access to a δ-regular network M as in equation 3, a list L = ((x_1, P_1), ..., (x_m, P_m)) of pairs as described in section 3.5.1, and a pair (x, P) ∈ L
Output: does P correspond to a first-layer neuron?
1: Choose δ' small enough such that 2^{2(d_1+d_2)} δ'√2/(δ√π) ≤ 2^{-d-1}/m^2
2: Choose R > 0 large enough such that e^{-(R-δ')^2/2} ≤ 2^{-d-1}/m^2
3: for any k ∈ [m] such that P_k ≠ P do
4:   Choose a point z ∈ P such that z and x are separated by P_k, and d(z, P_k) > R
5:   Choose a standard Gaussian Z in P whose mean is z
6:   If IS_GENERAL(M, δ', Z), return "P is not a first-layer critical hyperplane"
7: end for
8: Return "P is a first-layer critical hyperplane"
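The direction z can be computed with a least-squares projection, which is equivalent to the Gram-Schmidt step; a sketch of our own, assuming the rows of the recovered Ŵ are linearly independent:

```python
import numpy as np

def sign_direction(W_hat, j):
    """Unit vector z with <w_k, z> = 0 for all k != j and <w_j, z> > 0:
    project row j onto the orthogonal complement of the other rows."""
    others = np.delete(W_hat, j, axis=0)
    # The least-squares residual of w_j against span{other rows} is orthogonal
    # to all of them, and its inner product with w_j is its squared norm > 0.
    coef, *_ = np.linalg.lstsq(others.T, W_hat[j], rcond=None)
    z = W_hat[j] - others.T @ coef
    return z / np.linalg.norm(z)
```

Querying M at x ± ϵz and checking which side changes the output then fixes the sign of (ŵ_j, b̂_j), as described above.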

3.5.4. RECONSTRUCTION OF THE TOP TWO LAYERS

Having the weights of the first layer at hand, and since W is assumed to be right invertible, we can directly access the sub-network defined by the top two layers. Namely, given x ∈ R^{d_1}_+, we can find z ∈ R^d such that x = (ϕ(⟨w_1, z⟩ + b_1), ..., ϕ(⟨w_{d_1}, z⟩ + b_{d_1})), e.g., by taking z = W^{-1}(x - b), where W^{-1} is a right inverse of W. Now, M(z) is precisely the value of the top layers on the input x, and the problem boils down to reconstructing a depth-two network in the R^d_+-restricted case, which we already solved. The cost of a query to the second layer is Q plus the cost of computing z, which is O(dd_1). There is also an asymptotically negligible cost of O(d_1^2 d) for computing W^{-1}. The runtime of this step is therefore O((log(1/δ) + d_1)d_1 d_2 (Q + dd_1) + d_1^2 d_2^2).
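This reduction is a one-liner once a right inverse of W is available; a sketch (our own) using the Moore-Penrose pseudo-inverse, which is a right inverse whenever W has full row rank:

```python
import numpy as np

def second_layer_oracle(M, W, b):
    """Turn black-box access to the full network M into black-box access to the
    top sub-network F(x) = <u, relu(Vx + c)> on x >= 0, by querying M at
    z = W^+ (x - b): then W z + b = x, and relu(x) = x for x >= 0."""
    W_pinv = np.linalg.pinv(W)  # right inverse: W @ W_pinv = I for full row rank
    def F(x):
        return M(W_pinv @ (x - b))
    return F
```

The depth-two procedure can then be run against F, restricted to the positive orthant.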

4. DISCUSSION AND SOCIAL IMPACT

This work continues a line of empirical and theoretical results showing that extracting a ReLU network given membership queries is possible. Here we prove that two- and three-layer model extraction can be done in polynomial time. Our assumptions are non-restrictive, making the approach applicable to fully connected networks, convolutional networks, and many other architectures. For practical use, our approach suffers from several limitations. First, two- and three-layer networks are too shallow in practice. Second, exact access to the black-box network may not be feasible in practice: as the number of output bits is bounded, numerical inaccuracies may affect the reconstruction, especially when δ is very small. In that regard, our work is mostly theoretical in nature, showing that reconstruction is provably achievable. Yet, this work raises practical social concerns regarding the potential risks of membership-query attacks. Extracting the exact parameters and architecture would allow attackers to reveal proprietary information and even construct adversarial examples. Therefore, uncovering those risks and creating a conversation on ways to protect against them is essential. As empirical evidence suggests, we believe it is possible to prove similar results with even fewer assumptions for deeper models and more complex architectures. Furthermore, it might be interesting to investigate the methods of this paper when the queries and the outputs are restricted to machine precision. We leave those challenges for future work.

A REGULAR NETWORKS

Definition 1. A neural network is called δ-regular if it satisfies the following requirements:

1. For each i ∈ [d], the piecewise linear function t → M(te^(i)) is δ-nice, as defined in section 3.2.2.
2. Any critical point on the axes {te^(i) : t ∈ R} is δ-non-degenerate.
3. Each critical hyperplane corresponds to a single neuron.
4. The distance between each pair of critical hyperplanes is at least δ.
5. The angle between any critical hyperplane and an axis is at least δ, i.e., all critical hyperplanes are δ-general.
6. Each critical hyperplane P corresponding to a second-layer neuron also corresponds to a single first-layer state. That is, the state of the first layer is the same for any P-critical point.
7. In the case of depth-three networks, the above conditions also apply to the sub-network defined by the top two layers.

While the definition above is lengthy, most of the requirements overlap, and we detailed them separately for ease of analysis. The following lemma shows that regularity is to be expected of a random network. As an untrained network begins from a random initialization, it is very likely to be found in a regular position after the learning phase. However, we note that some post-processing methods, like weight pruning, may affect the general position assumption; such cases should be given specific care and are not in the scope of this paper.

Lemma 5. Let S be the set of networks as in equation 1 and equation 3 that violate at least one of the following:

1. For each i ∈ [d], the piecewise linear function t → M(te^(i)) is nice.
2. Any critical point on the axes {te^(i) : t ∈ R} is non-degenerate.
3. Each critical hyperplane corresponds to a single neuron.
4. The distance between each pair of critical hyperplanes is non-zero.
5. All critical hyperplanes are general.
6. Each non-zero critical hyperplane P corresponding to a second-layer neuron also corresponds to a single first-layer state.
7. In the case of depth-three networks, the above conditions also apply to the sub-network defined by the top two layers.

Then S has zero Lebesgue measure.

Proof. It is enough to show that each of the above violations has zero measure, as a finite union of sets of measure zero has measure zero. The demand that a specific point x ∈ R^d be critical requires some critical hyperplane P such that x ∈ P. This imposes a linear constraint on the set of all such possible hyperplanes and reduces their degrees of freedom. Any subspace of dimension < d has zero Lebesgue measure in R^d, which is also the case for the set of all hyperplanes containing x. As a corollary, the set of hyperplanes that contain a point of 2^{-⌈log_2(2/δ^2)⌉}δZ also has measure zero, as Z is sparse in R. As another corollary, once a plane P is fixed, the set of planes that collide with P exactly on the i'th axis, i ∈ [d], is of measure zero as well; this is the case in which a critical point on one of the axes is degenerate. An even more degenerate case is when two neurons have the same hyperplane, which means both neurons have exactly the same parameters up to a factor. Obviously, this case has measure zero, which implies that with probability 1 a finite set of hyperplanes has non-zero pairwise distances. If we consider the set {x ∈ R^d : ⟨x, e^(i)⟩ = 0} as a hyperplane itself, then it is easy to see that a non-general hyperplane also has measure zero. For a one-dimensional function to be nice, one must require that no two pieces share the same affine function. For the functions t → M(te^(i)), this requires that no two neurons have the same i'th parameters; indeed, the opposite case, where two neurons share the exact same parameters, has measure zero in R. As for depth-three networks, all the above is valid for the sub-network defined by the top two layers.
Moreover, we can consider the first-layer state as an affine transformation applied to the second layer's neurons. Therefore, in order for a second-layer critical hyperplane P to span two first-layer states, there must be two second-layer neurons that have the same parameters up to an affine transformation that is uniquely defined by the parameters of the first layer. As the set of all those affine transformations is finite, this imposes a finite set of possible constraints, and each has zero measure.

The following two lemmas quantify the effect of a small random perturbation in terms of δ.

Lemma 6. Let M be a two-layer neural network as in equation 1. Let q be the number of neurons in the network, and let M be an upper bound on the absolute value of the weights. To each weight in the network, add an independent uniform element of [-2^{-d}, 2^{-d}], and denote the noisy network by M′. Then:

1. For each i ∈ [d], all critical points of the piecewise linear function t → M′(te^(i)) are in [-1/δ, 1/δ] \ (-δ, δ) with probability 1 - p1 = 1 - dqδ(M+1)2^{d+1}.
2. For each i ∈ [d], each piece of the piecewise linear function t → M′(te^(i)) is of length at least δ with probability 1 - p2 = 1 - d2^d q²δ(M+1).
3. For each i ∈ [d], all the points of the grid 2^{-⌈log2(2/δ²)⌉}δZ of the piecewise linear function t → M′(te^(i)) are δ²-general with probability 1 - p3 = 1 - 3dqδ.
4. Any critical point on the axes {te^(i) : t ∈ R} is δ-non-degenerate with probability 1 - p4 = 1 - d^{3/2}q²δ(M+1)2^{d-1}.
5. The distance between each pair of critical hyperplanes is at least δ with probability 1 - p5 = 1 - δq²√d·d(M+1)²2^{d+1}.
6. The angle between any critical hyperplane and any axis is at least δ with probability 1 - p6 = 1 - dqδ2^d.

Proof. Let the i'th axis be {te^(i) : t ∈ R}. Denote by w_{j,i} the i'th element of w_j and by w′_{j,i} = w_{j,i} + r_{j,i} the noisy value of w_{j,i}, where r_{j,i} ~ U([-2^{-d}, 2^{-d}]). Similarly, let b′_j = b_j + s_j be the noisy value of b_j, where s_j ~ U([-2^{-d}, 2^{-d}]). Then the j'th neuron has a critical point on the i'th axis at te^(i) = t_{j,i}e^(i), where t_{j,i} = -b′_j / w′_{j,i}. Note that by lemma 5, almost surely -∞ < t_{j,i} < ∞.

1. For each i ∈ [d] and j ∈ [q], note that |w′_{j,i}| ≤ M + 2^{-d} ≤ M + 1. Given α ∈ (0, 2^{-d}), we have that p(|w′_{j,i}| > α) ≥ p(|r_{j,i}| > α) = 1 - α2^d. Now, with probability at least 1 - α2^{d+1} we have that both |w′_{j,i}| > α and |b′_j| > α, and therefore, by setting α = δ(M+1): δ = α/(M+1) < |t_{j,i}| = |b′_j|/|w′_{j,i}| < (M+1)/α = 1/δ. To make the above valid for every i ∈ [d] and j ∈ [q], we use the union bound and get an overall probability 1 - p1 where p1 ≤ dqα2^{d+1} = dqδ(M+1)2^{d+1}.

2. Assume the weights are perturbed in the following order: first, W′ is perturbed; second, the bias of the first neuron, b1, is set, which determines its critical points on the axes, t_{1,1}, ..., t_{1,d}. For the second neuron, we ask what is the probability that b2 yields a critical point that is δ-close to a critical point of the first neuron. That is, for some i ∈ [d], p(t_{2,i} ∈ B(t_{1,i}, δ)) = p(-(b2 + r2)/w′_{2,i} ∈ (t_{1,i} - δ, t_{1,i} + δ)) = p(r2 ∈ (-w′_{2,i}(t_{1,i} + δ) - b2, -w′_{2,i}(t_{1,i} - δ) - b2)) ≤ 2^d δ|w′_{2,i}| ≤ 2^d δ(M+1), and using the union bound, p(∃i ∈ [d], t_{2,i} ∈ B(t_{1,i}, δ)) ≤ d2^d δ(M+1). Continuing the perturbation, for the j'th neuron the probability to intersect any of the balls of radius δ around t_{k,i}, k < j, i ∈ [d], is at most (j-1)d2^d δ(M+1). Finally, the probability that all the pieces, for all i ∈ [d], are of length at least δ is 1 - p2 with p2 ≤ Σ_{j=2}^{q} (j-1)d2^d δ(M+1) ≤ d2^d q²δ(M+1).

3. Fix W′ and some i ∈ [d]. Note that for the j'th neuron, t_{j,i} is uniform in L = [(-b_j - 2^{-d})/w_{j,i}, (-b_j + 2^{-d})/w_{j,i}]. As L is bounded, it intersects the grid at most k = |L|δ2^{⌈log2(2/δ²)⌉-1} times. Therefore, for all the points of the grid to be δ²-general, a union of segments of total length 2δ²k must not contain the critical point t_{j,i}. As t_{j,i} is uniform, the probability of avoiding those segments is therefore 1 - 2δ²k/|L| = 1 - δ³2^{⌈log2(2/δ²)⌉} ≥ 1 - 3δ, where the last inequality holds for δ ≤ 1. Overall, the points of the grid are δ²-general with probability 1 - p3, where p3 = 3dqδ.

4. Let P_j be the critical hyperplane defined by the j'th neuron. The distance between P_j and a critical point t_{k,i}e^(i), k ≠ j, is D(P_j, t_{k,i}) = |⟨w′_j, t_{k,i}e^(i)⟩ + b′_j| / ∥w′_j∥ ≥ |t_{k,i}w′_{j,i} + b_j + s_j| / (√d(M+1)). As s_j has a symmetric distribution around 0, with probability at least 1/2 we have |t_{k,i}w′_{j,i} + b_j + s_j| ≥ |s_j|, and with probability at least 1/2 - α2^{d-1} we have |t_{k,i}w′_{j,i} + b_j + s_j| ≥ |s_j| ≥ α. Setting α = δ√d(M+1) and using the union bound, we get p(∃i ∈ [d], j ≠ k, s.t. D(P_j, t_{k,i}) ≤ δ) ≤ dq²δ√d(M+1)2^{d-1} = d^{3/2}q²δ(M+1)2^{d-1} = p4. Note that if t_{k,i}e^(i) is at distance at least δ from every other critical hyperplane, then it is δ-non-degenerate. Therefore, all the critical points on the axes are δ-non-degenerate with probability 1 - p4.

5. For any unit vector e, at least one of its coordinates has absolute value at least 1/√d. Thus, p(|⟨w′, e⟩| < δ′2^{-d}/√d) ≤ δ′, and hence ⟨w′, e⟩² ≥ δ′2^{-d}/√d w.p. at least 1 - δ′. It follows that for any unit vector e′ orthogonal to e, ⟨w′, e′⟩² ≤ ∥w′∥² - ⟨w′, e⟩² ≤ ∥w′∥² - δ′2^{-d}/√d, so ⟨w′, e′⟩²/∥w′∥² ≤ 1 - δ′2^{-d}/(√d∥w′∥²) ≤ 1 - δ′2^{-d}/(√d·d(M+1)²) w.p. at least 1 - δ′. Taking roots, |⟨w′/∥w′∥, e′⟩| ≤ 1 - δ′2^{-d}/(2√d·d(M+1)²). Hence, w.p. at least 1 - δ′, the distance between the corresponding hyperplanes is at least δ′2^{-d}/(2√d·d(M+1)²). Applying this with δ′ = δ2^{d+1}√d·d(M+1)² and taking a union bound over all pairs, each pair of critical hyperplanes is at distance at least δ w.p. 1 - p5 = 1 - δq²√d·d(M+1)²2^{d+1}.

6. The angle between the j'th neuron's hyperplane and the axis {te^(i) : t ∈ R} is determined by ⟨w′_j, e^(i)⟩ = w′_{j,i}. The probability for this to be at least δ is p(|w′_{j,i}| > δ) ≥ p(|r_{j,i}| > δ) = 1 - δ2^d.
Using the union bound, the probability that every neuron and every axis have an angle of at least δ is 1 - p6, where p6 = dqδ2^d.

Lemma 7. Let M be a three-layer neural network as in equation 3. Let q be the number of neurons in the network, and let M be an upper bound on the absolute value of the weights. To each weight in the network, add an independent uniform element of [-2^{-d}, 2^{-d}], and denote the noisy network by M′. Then:

1. For each i ∈ [d], all critical points of the piecewise linear function t → M′(te^(i)) are in [-1/δ, 1/δ] \ (-δ, δ) with probability 1 - p1 = 1 - dq²δ(dM² + 2)2^{d+1}.
2. For each i ∈ [d], each piece of the piecewise linear function t → M′(te^(i)) is of length at least δ with probability 1 - p2 = 1 - d2^d q⁴δ(dM² + 2).
3. For each i ∈ [d], all the points of the grid 2^{-⌈log2(2/δ²)⌉}δZ of the piecewise linear function t → M′(te^(i)) are δ²-general with probability 1 - p3 = 1 - 3dq²δ.
4. Any critical point on the axes {te^(i) : t ∈ R} is δ-non-degenerate with probability 1 - p4 = 1 - d^{3/2}q⁴δ(dM² + 2)2^{d-1}.
5. The distance between each pair of critical hyperplanes is at least δ with probability 1 - p5 = 1 - δq⁴√d·d(dM² + 2)²2^{d+1}.
6. The angle between any critical hyperplane and any axis is at least δ with probability 1 - p6 = 1 - dq²δ2^d.
7. The above conditions also apply to the sub-network defined by the top two layers with probability 1 - p7, where p7 is the sum of the probabilities of lemma 6.

Proof. Lemma 5 tells us that each non-zero critical hyperplane P corresponding to a second-layer neuron also corresponds to a single first-layer state almost surely. Therefore, given i ∈ [d], we can consider the q′ critical points that intersect the i'th axis as q′ first-layer neurons, where each second-layer neuron is composed with an affine transformation given by the current state of the first layer. As each first-layer neuron intersects the axis at most once, and each second-layer neuron intersects the axis at most q1 times, where q1 is the number of first-layer neurons, we can bound q′ by q′ ≤ q². Furthermore, given a critical hyperplane P corresponding to a second-layer neuron j, denote by (W′_P, b′_P) the state of the first layer (which is the same as (W′, b′) as defined in the proof of lemma 6, except for some zero rows due to the ReLU). That is, P = {x : ⟨v_j, W′_P x⟩ + ⟨v_j, b′_P⟩ + c_j = 0}, which can be viewed locally as a pseudo-neuron with parameters ((W′_P)^T v_j, ⟨v_j, b′_P⟩ + c_j), each bounded in magnitude by M′ ≤ dM² + 1.

1. Applying the above to lemma 6, we get p1 ≤ dq′δ(M′ + 1)2^{d+1} ≤ dq²δ(dM² + 2)2^{d+1}.
2. Applying the above to lemma 6, we get p2 ≤ d2^d q′²δ(M′ + 1) ≤ d2^d q⁴δ(dM² + 2).
3. Applying the above to lemma 6, we get p3 ≤ 3dq′δ ≤ 3dq²δ.
4. Applying the above to lemma 6, we get p4 ≤ d^{3/2}q′²δ(M′ + 1)2^{d-1} ≤ d^{3/2}q⁴δ(dM² + 2)2^{d-1}.
5. Note that the maximal possible number of critical hyperplanes is at most h = q², as the interaction between each first-layer neuron and each second-layer neuron may create a single hyperplane. Therefore, p5 = δh²√d·d(M′ + 1)²2^{d+1} ≤ δq⁴√d·d(dM² + 2)²2^{d+1}.
6. Applying the above to lemma 6, we get p6 ≤ dq′δ2^d ≤ dq²δ2^d.
7. Let p1, ..., p6 be as defined in lemma 6. As the number of neurons in the second layer is at most q, using the union bound we get p7 = Σ_{i=1}^{6} p_i.

In the rest of the section, we prove the lemmas stated in section 3.3.

Proof (of lemma 1). The proof follows from lemma 5 and the fact that the number of neurons (and hence the number of critical hyperplanes) is finite.

Proof (of lemma 2). From lemma 1, we have that the perturbed network M′ is regular almost surely. This implies that it is δ-regular for some δ > 0, as the number of neurons is finite. Fix δ > 0. For two-layer networks, lemma 6 bounds the probability of violating one of the requirements of a δ-regular network. Let p1, ..., p6 be as in lemma 6; then, using the union bound, the network is δ-regular with probability at least 1 - p, where

p = p1 + p2 + p3 + p4 + p5 + p6 = dqδ(M+1)2^{d+1} + d2^d q²δ(M+1) + 3dqδ + d^{3/2}q²δ(M+1)2^{d-1} + δq²√d·d(M+1)²2^{d+1} + dqδ2^d < 10(M+1)²q²d^{3/2}δ2^d.

Therefore, choosing δ = (10(M+1)²q²d^{3/2}2^{2d})^{-1} gives the requested bound. For three-layer networks, let p1, ..., p7 be as in lemma 7; then, using the union bound, the network is δ-regular with probability at least 1 - p′, where

p′ = p1 + p2 + p3 + p4 + p5 + p6 + p7 ≤ dq²δ(dM²+2)2^{d+1} + d2^d q⁴δ(dM²+2) + 3dq²δ + d^{3/2}q⁴δ(dM²+2)2^{d-1} + δq⁴√d·d(dM²+2)²2^{d+1} + dq²δ2^d + 10(M+1)²q²d^{3/2}δ2^d < 20(dM²+2)²q⁴d^{3/2}δ2^d.
Therefore, if we set δ = (20(dM²+2)²q⁴d^{3/2}2^{2d})^{-1} we get the requested bound.

Proof (of lemma 3). Let W ∈ R^{d1×d} be a random matrix whose elements are drawn independently of each other, and denote by w_j its j'th row, j ∈ [d1]. Also, let r = min{d, d1}. Note that W has full rank with probability 1, where by full rank we mean that rank(W) = min{d, d1} = r. Indeed, consider drawing the j'th row at random, for j ≤ r, after fixing the first j - 1 rows. In order for that row to be linearly dependent on w1, ..., w_{j-1}, it must fall in a subspace of dimension at most j - 1 < r, which has zero Lebesgue measure. Therefore, if d1 ≤ d then r = d1 and W has rank d1 with probability 1. The rank-nullity theorem then implies that the image of W is a d1-dimensional space, and thus W is surjective.

Proof (of lemma 4). The proof follows from theorem 5. If we set 3.5d1 ≤ d2, we get a probability of 1 - (e·d2/(d1+1))^{d1+1}/2^{d2} ≥ 1 - (3.5e)^{d1+1}/2^{3.5d1} = 1 - 3.5e·(3.5e/2^{3.5})^{d1} → 1 as d1 → ∞.
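Lemma 6's first claim is easy to sanity-check numerically: perturb each weight uniformly in [-2^{-d}, 2^{-d}] and verify that the axis crossings t_{j,i} = -b′_j/w′_{j,i} avoid both 0 and infinity at scale δ. The sketch below is illustrative only (the function name and parameter choices are ours, not the paper's):

```python
import numpy as np

def axis_crossings_ok(d=6, q=4, M=1.0, delta=1e-8, trials=1000, seed=0):
    """Fraction of random perturbed two-layer networks whose first-layer
    axis crossings t_{j,i} = -b'_j / w'_{j,i} all lie in [-1/delta, 1/delta]
    but outside (-delta, delta), as in item 1 of lemma 6."""
    rng = np.random.default_rng(seed)
    ok = 0
    for _ in range(trials):
        # Base weights bounded by M, plus uniform noise in [-2^-d, 2^-d].
        W = rng.uniform(-M, M, size=(q, d)) + rng.uniform(-2.0**-d, 2.0**-d, size=(q, d))
        b = rng.uniform(-M, M, size=q) + rng.uniform(-2.0**-d, 2.0**-d, size=q)
        t = -b[:, None] / W                      # all crossings t_{j,i}
        ok += bool(np.all((np.abs(t) > delta) & (np.abs(t) < 1 / delta)))
    return ok / trials
```

With δ this small, the failure probability per draw is tiny, so the empirical frequency should be essentially 1, matching the 1 - p1 bound.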

B PROOF OF THE MAIN THEOREMS

Proof (of theorem 1). The correctness of the theorem follows from the correctness of theorem 4 below.

Proof (of theorem 4). We assume without loss of generality that the u_i's are in {±1}, as any neuron x → uϕ(⟨w, x⟩ + b) computes the same function as x → (u/|u|)ϕ(⟨|u|w, x⟩ + |u|b), since the ReLU is positively homogeneous. Let S = {x1, ..., xm} be the list of points found using FIND CP in algorithm 6. Our general position assumption is that all the critical points on the line Re^(1) are in the range [-1/δ, 1/δ]e^(1). Hence, by the correctness of lemma 13, we are guaranteed that all the critical points on the line Re^(1) are in S. We claim that for each x ∈ S there is exactly one critical hyperplane P with x ∈ P, and that |S ∩ P| = 1 for every critical hyperplane P. Assume by contradiction that one of these is false. If x ∉ P for all the critical hyperplanes, then x is not a critical point, which contradicts lemma 13. If |S ∩ P| = 0, then P does not intersect the line Re^(1), i.e., it is parallel to this axis, which contradicts our general position assumption. Finally, if |S ∩ P| > 1, then P intersects the line Re^(1) at more than one point, which is impossible for a hyperplane not containing that line. Therefore, each neuron is represented by a unique critical point x ∈ S. Let x ∈ S be a critical point of the j'th neuron, and let (w′, b′) = FIND HP(M, δ, x). From lemma 14 we get that either (w′, b′) = (w_j, b_j) or (w′, b′) = (-w_j, -b_j). To recover u_j, note that if u_j = 1 then M is strictly convex in B(x, δ), being the sum of an affine function M′ and the convex function u_jϕ(⟨w_j, x⟩ + b_j). Similarly, if u_j = -1 then M is strictly concave in B(x, δ). Thus, using algorithm 4, we can determine u_j correctly.
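The convexity test used here to recover u_j (algorithm 4) amounts to a second difference at the critical point: the affine part of M cancels, and the sign of what remains is the sign of u_j. A minimal sketch, assuming no other neuron changes state in B(x0, δ) (`recover_sign` is our illustrative name, not the paper's code):

```python
import numpy as np

def relu(t):
    return np.maximum(t, 0.0)

def recover_sign(M, x0, w, delta=1e-3):
    """Recover u_j in {+1, -1} from the local curvature of M at a critical
    point x0 of the neuron with (recovered) weight vector w.  The second
    difference along a direction e with <w, e> != 0 cancels the affine part
    of M near x0 and leaves u_j * eps * |<w, e>|."""
    e = w / np.linalg.norm(w)   # any direction not orthogonal to w works
    eps = delta / 2
    second_diff = M(x0 + eps * e) + M(x0 - eps * e) - 2 * M(x0)
    return 1 if second_diff > 0 else -1
```

A positive second difference means M is locally convex at x0, so u_j = 1; a negative one means it is locally concave, so u_j = -1.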
The proof of theorem 3 proceeds in four steps. The first step is to find a list L = ((x1, P̂1), ..., (xm, P̂m)) such that:

• For each k, P̂k is a critical hyperplane of M, and xk is a δ-non-degenerate critical point whose critical hyperplane is P̂k.
• The list contains all the critical hyperplanes of first-layer neurons.

For that, we repeatedly use FIND CP(δ, t → M(te^(1)), ·) to find all the critical points on the axis {te^(1) : t ∈ R} (see section 3.2), similarly to our algorithm for reconstructing depth-two networks. Denote this set of points by S = {x1, ..., xm}. Lemma 13, together with the general position assumption, guarantees that for each critical hyperplane P corresponding to a first-layer neuron, |P ∩ S| = 1, and that all the points in S are δ-non-degenerate. Then, using algorithm 3, we find the critical hyperplane P̂k for each point xk ∈ S. For the runtime, note that m = O(d1d2) (e.g., Telgarsky (2016)); the critical points can be found in time O(d1d2 log(1/δ)Q) as explained in section 3.2.2, and each hyperplane P̂k can be efficiently found via O(d) queries near xk as explained in section 3.2.3. The total running time of this step is therefore O(d1d2 log(1/δ)Q + d1d2dQ).

The second step is to take the list L = ((x1, P̂1), ..., (xm, P̂m)) and remove all the points that do not correspond to first-layer neurons. After that, the list will contain precisely the critical hyperplanes of the neurons in the first layer. To do so, it is enough to efficiently decide, given the list L, whether a given hyperplane P is a critical hyperplane of a neuron in the first layer. The idea behind this verification is simple: if P corresponds to a neuron in the first layer, then any point in P is a critical point of M (see lemma 8). Indeed, suppose that P is the critical hyperplane of a first-layer neuron h(x) = ϕ(⟨w, x⟩ + b). Then P is the zero set of the affine input to h, and the input to h is the same affine function in the proximity of every point x ∈ R^d. Thus, every x ∈ P is a critical point for h, with P as its critical hyperplane.
On the other hand, if P corresponds to a neuron in the second layer, then not all its points are critical for M. Indeed, suppose that we start from x ∈ P, which is critical for M, and move inside P until one of the neurons in the first layer changes its state. We then reach a point x′ ∈ P which is not critical for M, as, by our general position assumption, P corresponds to a single first-layer state. Thus, to verify whether P corresponds to a first-layer neuron, we go over all the hyperplanes P̂k ∈ L, and for each of them find a point x′ ∈ P on the opposite side of P̂k (relative to x) and check whether it is critical. If it is not critical for one of the hyperplanes, we know that P does not correspond to a first-layer neuron. If all the examined points are critical, even for P̂k corresponding to a first-layer neuron, then P must correspond to a first-layer neuron. Algorithm 7 implements this idea. There is one caveat to handle: the examined point has to be generic enough in order to test whether it is critical using algorithm 5. To make sure that the point is general enough, we slightly perturb it. The correctness of the algorithm follows from lemmas 10 and 11. Indeed, if P corresponds to a first-layer neuron, then lemma 10 implies that each test in the for-loop fails w.p. at least 1 - 2^{-d}/m². Thus, w.p. at least 1 - 2^{-d}/m all the tests fail, and the algorithm reaches step 8 and correctly outputs that "P is a first-layer critical hyperplane." In case P corresponds to a second-layer neuron, lemma 11 implies that once we reach an iteration in which P̂k corresponds to a first-layer neuron, the test in step 6 succeeds w.p. at least 1 - 2^{-d}/m², in which case the algorithm correctly outputs "P is not a first-layer critical hyperplane." All in all, the algorithm outputs the correct answer w.p.
at least 1 - 2^{-d}/m for every hyperplane P. Thus, w.p. at least 1 - 2^{-d} it outputs the correct answer for all hyperplanes. As for the runtime, note that each step in the for-loop takes O(dQ). As the list size is O(d1d2), the total running time over all hyperplanes is O(d1²d2²dQ).

Since the rows of W are assumed to have unit norm, the list of critical hyperplanes of the first-layer neurons, obtained in the previous step, determines the weights up to sign. Namely, we can reconstruct a list L = ((ŵ1, b̂1), ..., (ŵ_{d1}, b̂_{d1})) that defines precisely the neurons of the first layer, up to sign. For the third step, it therefore remains to recover the correct signs (note that this process is only required for inner layers and is avoidable for the top layer, as explained above). To recover the correct sign of (ŵ1, b̂1), we can simply do the following test: choose a point x such that ⟨ŵ1, x⟩ + b̂1 = 0, and query the network at the points x + ϵz and x - ϵz, for small ϵ, where z ∈ R^d is a unit vector that is orthogonal to ŵ2, ..., ŵ_{d1} but satisfies ⟨ŵ1, z⟩ > 0. If we assume that W is right invertible, such a z exists, as w1, ..., w_{d1} are linearly independent. Now, when moving from x to either x + ϵz or x - ϵz, the value of every neuron in the first layer, except possibly the one corresponding to (ŵ1, b̂1), does not change. As for the neuron corresponding to (ŵ1, b̂1): if its real weights are indeed (ŵ1, b̂1), then its value changes when we move from x to x + ϵz but not when we move from x to x - ϵz. On the other hand, if its real weights are (-ŵ1, -b̂1), then its value changes when we move from x to x - ϵz but not when we move from x to x + ϵz. Hence, to determine the correct sign we do the following: if M(x) ≠ M(x + ϵz), keep (ŵ1, b̂1); else, replace it with (-ŵ1, -b̂1).
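This sign test can be sketched with standard linear algebra. Below, `sign_direction` and `fix_sign` are our illustrative names, an SVD null-space computation stands in for the Gram-Schmidt step, and we assume d1 ≤ d with linearly independent recovered rows:

```python
import numpy as np

def sign_direction(W_hat, j):
    """Find a unit z orthogonal to every recovered row except the j-th,
    with <w_hat_j, z> > 0 (possible whenever the rows are independent)."""
    others = np.delete(W_hat, j, axis=0)
    _, s, Vt = np.linalg.svd(others)          # rows Vt[rank:] span the null space
    rank = int((s > 1e-10).sum()) if s.size else 0
    null_basis = Vt[rank:]
    z = null_basis.T @ (null_basis @ W_hat[j])  # project w_hat_j onto the null space
    z /= np.linalg.norm(z)
    if W_hat[j] @ z < 0:
        z = -z
    return z

def fix_sign(M, W_hat, b_hat, j, eps=1e-4):
    """Query-based sign test: pick x on the j-th recovered hyperplane and
    compare M(x + eps*z) with M(x).  A change means the sign was correct."""
    z = sign_direction(W_hat, j)
    x = -b_hat[j] * W_hat[j] / np.dot(W_hat[j], W_hat[j])  # <w_hat_j, x> + b_hat_j = 0
    if not np.isclose(M(x), M(x + eps * z)):
        return W_hat[j], b_hat[j]
    return -W_hat[j], -b_hat[j]
```

As in the text, moving along +z activates the neuron only if the recovered sign is the true one; otherwise the activation happens along -z.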
This test works because of the above discussion, together with the assumption that the second layer has non-zero partial derivatives; therefore, we can guarantee that either x + ϵz or x - ϵz will show a change in the value of M. For more on the non-zero partial derivatives assumption, see section E. The runtime of this step is O(d1³d + d1Q). Indeed, to find z we need to perform Gram-Schmidt, which takes O(d1²d); after that, all that is needed is two queries. We repeat this for each first-layer neuron, so the total runtime is O(d1³d + d1Q).

For the fourth step, we recover the values of the top layer up to an affine transformation. Having the weights of the first layer at hand, and since W is assumed to be right invertible, we can directly access the sub-network defined by the top two layers. Namely, given x ∈ R^{d1}_+, we can find z ∈ R^d such that x = (ϕ(⟨w1, z⟩ + b1), ..., ϕ(⟨w_{d1}, z⟩ + b_{d1})), e.g., by taking z = W^{-1}(x - b), where W^{-1} is a right inverse of W. Now, M(z) is precisely the value of the top layer on the input x. Hence, the problem of reconstructing the top two layers boils down to the problem of reconstructing a depth-two network in the R^d_+-restricted case, whose correctness is given by theorem 2. The cost of a query to the second layer is Q plus the cost of computing z, which is O(dd1). There is also an asymptotically negligible cost of O(d1²d) for computing W^{-1}. The runtime of this step is therefore O((log(1/δ) + d1)d1d2(Q + dd1) + d1²d2²).
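The reduction in the fourth step can be sketched in a few lines; here `top_layer_query` is our illustrative name, and the Moore-Penrose pseudoinverse stands in for "a right inverse of W":

```python
import numpy as np

def top_layer_query(M, W, b, x):
    """Query the depth-three network M on an input chosen so that the first
    layer outputs exactly x (for x in the positive orthant), via a right
    inverse of W.  Assumes d1 <= d and W has full row rank."""
    W_pinv = np.linalg.pinv(W)      # right inverse: W @ W_pinv = I
    z = W_pinv @ (x - b)
    # The first layer now outputs relu(W z + b) = relu(x) = x, since x >= 0.
    return M(z)
```

With this helper, every query to the top two-layer sub-network costs one query to M plus an O(dd1) matrix-vector product, as stated above.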

C PROOFS OF LEMMAS

Lemma 8. Let P be a critical hyperplane corresponding to a first-layer neuron. Then any point in P is critical for M.

Proof. W.l.o.g. P corresponds to the neuron ϕ(⟨w1, x⟩ + b1). Let x0 ∈ P and let e be a unit vector that is orthogonal to w2, ..., w_{d1} and such that ⟨w1, e⟩ > 0. Such an e exists as we assume that w1, ..., w_{d1} are independent. Consider the function f(t) = M(x0 + te). We claim that it is not linear in any neighborhood of 0, which implies that x0 is critical. Indeed, for all i > 1, t → ϕ(⟨w_i, x0 + te⟩ + b_i) is constant, as e is orthogonal to w_i. As for i = 1, t → ϕ(⟨w1, x0 + te⟩ + b1) is the zero function for t ≤ 0, as in this case ⟨w1, x0 + te⟩ + b1 ≤ ⟨w1, x0⟩ + b1 = 0. Hence, the left derivative of f at 0 is 0. On the other hand, for t > 0, ϕ(⟨w1, x0 + te⟩ + b1) = ⟨w1, x0 + te⟩ + b1 = t⟨w1, e⟩. Hence, the right derivative of g(t) = ϕ(W(x0 + te) + b) is ⟨w1, e⟩e^(1). Now, it is assumed that the derivative of F(z) = ⟨u, ϕ(Vz + c)⟩ in the direction of e^(1) is not zero. Hence, the right derivative of f(t) = F(g(t)) is not zero. All in all, the right derivative of f at 0 differs from the left derivative, which implies that f is not linear in any neighborhood of 0.

Lemma 9. Let P1, P2 be hyperplanes such that D(P1, P2) ≥ δ. Let x̄ ∈ P1 and let x be a standard Gaussian in P1 with mean x̄. Then p(d(x, P2) ≤ a) ≤ √2·a/(δ√π).

Proof. W.l.o.g. we can assume that P1 and P2 contain the origin. Let n2 be the normal of P2. We have that d(x, P2) = |⟨x, n2⟩|. Now,

⟨x, n2⟩ = ⟨x, n2 - proj_{P1} n2⟩ + ⟨x, proj_{P1} n2⟩ = ⟨x, proj_{P1} n2⟩ = ⟨x - x̄, proj_{P1} n2⟩ + ⟨x̄, proj_{P1} n2⟩,

where the second equality uses x ∈ P1. Hence ⟨x, n2⟩ is a Gaussian with mean µ := ⟨x̄, proj_{P1} n2⟩ and variance σ² ≥ δ². Hence, p(⟨x, n2⟩ ∈ [-a, a]) = (1/(σ√(2π))) ∫_{-a}^{a} e^{-((t-µ)/σ)²/2} dt ≤ 2a/(δ√(2π)) = √2·a/(δ√π).

Lemma 10. Let P be a hyperplane that corresponds to a first-layer neuron. Let x̄ ∈ P and let x be a standard Gaussian in P with mean x̄. Then x is a δ′-non-degenerate critical point of P w.p. at least 1 - 2^{2(d1+d2)}δ′√2/(δ√π).

Proof. By lemma 8, x is critical w.p. 1. It is therefore enough to show that w.p. at least 1 - 2^{2(d1+d2)}δ′√2/(δ√π), the distance of x from every critical hyperplane other than P is at least δ′. Indeed, by lemma 9 and the fact that there are at most (d1 + d2)2^{d1+d2} ≤ 2^{2(d1+d2)} critical hyperplanes, the probability that the distance from x to one of them is less than δ′ is at most 2^{2(d1+d2)}δ′√2/(δ√π).

Lemma 11. Let P be a hyperplane that corresponds to a second-layer neuron. Let x1 ∈ P be a critical point with P as its critical hyperplane. Let P1 be a hyperplane that corresponds to a first-layer neuron. Let x2 ∈ P be another point, and assume that x1 and x2 are on opposite sides of P1. Let x be a standard Gaussian in P with mean x2. Then x is δ′-general w.p. at least 1 - 2^{2(d1+d2)}δ′√2/(δ√π) - e^{-(d(x2,P1)-δ′)²/2}.

Proof. As in the proof of lemma 10, the probability that the distance from x to one of the critical hyperplanes other than P is less than δ′ is at most 2^{2(d1+d2)}δ′√2/(δ√π). It therefore remains to show that the probability that x is δ′-close to one of P's critical points is at most e^{-(d(x2,P1)-δ′)²/2}. Denote by n1 the normal of P1. We first note that there are no P-critical points on x2's side of P1. Indeed, the state of the first layer there is different from the state at x1, as the neuron corresponding to P1 changes its state; as it is assumed that each second-layer critical hyperplane corresponds to a single neuron and a single first-layer state, it follows that there are no P-critical points on x2's side of P1. It is therefore enough to bound the probability that x is δ′-close to x1's side of P1, which is the same as the probability that ⟨x - x2, n1⟩ ≥ d(x2, P1) - δ′. Finally, ⟨x - x2, n1⟩ is a centered Gaussian with variance at most 1. Hence, p(⟨x - x2, n1⟩ ≥ d(x2, P1) - δ′) ≤ e^{-(d(x2,P1)-δ′)²/2}.

D CORRECTNESS OF THE ALGORITHMS

Lemma 12. Algorithm 2 reconstructs the correct affine transformation at an ϵ-general point x ∈ R^d.

Proof. Note that for any y ∈ R^d with ∥y∥ ≤ ϵ we have f(x + y) = f(x) + Σ_{i=1}^{d} ((f(x + ϵe^(i)) - f(x))/ϵ)·y_i. Hence, for every z ∈ B(x, ϵ) we have

f(z) = f(x + (z - x)) = f(x) + Σ_{i=1}^{d} ((f(x + ϵe^(i)) - f(x))/ϵ)(z_i - x_i) = f(x) - Σ_{i=1}^{d} ((f(x + ϵe^(i)) - f(x))/ϵ)·x_i + Σ_{i=1}^{d} ((f(x + ϵe^(i)) - f(x))/ϵ)·z_i.

As f is affine at x, we therefore get w_i = (f(x + ϵe^(i)) - f(x))/ϵ and b = f(x) - Σ_{i=1}^{d} ((f(x + ϵe^(i)) - f(x))/ϵ)·x_i.

Lemma 13. Algorithm 1 returns the leftmost critical point of a δ-nice one-dimensional function f in the range (a, 1/δ).

Proof. Let x* be the leftmost critical point in (a, 1/δ). Throughout the algorithm's execution, we have x_L < x* < x_R, as in each iteration we choose the left half of the segment unless this half is affine (and therefore cannot contain a critical point). As we start with a segment of size 2/δ and split it into two halves at each iteration, after ⌈log2(2/δ²)⌉ iterations we are left with |x_L - x_R| < δ. Hence, in the final step, x_L is in the leftmost piece while x_R is in the piece adjacent to the leftmost piece; therefore, x* is the point at the intersection of those two affine functions. If there is no critical point in (a, 1/δ), then the segment is affine and we get Λ_L = Λ_R. Finally, note that all the points x_L, x_R and (x_L + x_R)/2 encountered during the execution of the algorithm lie on the grid 2^{-⌈log2(2/δ²)⌉}δZ and are therefore δ²-general.

Lemma 14. Algorithm 3 returns the critical hyperplane of a δ-non-degenerate critical point x ∈ R^d, assuming the hyperplane is δ-general.

Proof. Assume that x is a δ-critical point of the j'th neuron. We reconstruct the j'th neuron in two steps.

1. The first step is to find an affine function Λ such that Λ = Λ_{wj,bj} or Λ = -Λ_{wj,bj}. Let M′(x) := M(x) - u_jϕ(⟨w_j, x⟩ + b_j). Note that M′ is affine in B(x, δ), as no neuron other than the j'th one changes its state in B(x, δ). In B(x, δ), on one side of x's critical hyperplane the network computes M′(x), and on the other side it computes M′(x) + u_j(⟨w_j, x⟩ + b_j). Thus, to extract Λ_{wj,bj} up to sign, we can simply compute the affine functions computed by the network on both sides of x's critical hyperplane, and subtract them.

2. The second step is to recover u_j. To this end, note that if u_j = 1 then M is strictly convex in B(x, δ), being the sum of the affine function M′ and the convex function u_jϕ(⟨w_j, x⟩ + b_j). Similarly, if u_j = -1 then M is strictly concave in B(x, δ). Thus, to recover u_j we simply check the convexity of M in B(x, δ) using algorithm 4.

Finally, note that ϕ(Λ(x)) is either ϕ(⟨w_j, x⟩ + b_j) or ϕ(⟨w_j, x⟩ + b_j) - ⟨w_j, x⟩ - b_j. Hence, u_jϕ(Λ(x)) is either u_jϕ(⟨w_j, x⟩ + b_j) or u_jϕ(⟨w_j, x⟩ + b_j) - u_j⟨w_j, x⟩ - u_jb_j. In particular, u_jϕ(Λ(x)) equals u_jϕ(Λ_{wj,bj}(x)) up to an affine map.
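The bisection behind FIND CP (algorithm 1) can be sketched as follows; `find_leftmost_cp` is our illustrative name, we stop after a fixed iteration budget instead of the paper's ⌈log2(2/δ²)⌉ grid steps, and we assume each affine piece is long enough for the three-point chord test to detect a kink:

```python
def find_leftmost_cp(f, a, b, iters=30, tol=1e-12):
    """Bisection for the leftmost kink of a piecewise linear f on (a, b).
    Invariant (under the niceness assumption): the leftmost critical point
    stays inside (x_left, x_right)."""
    def is_affine(lo, hi):
        mid = (lo + hi) / 2
        # On an affine piece the midpoint lies on the chord.
        return abs(f(mid) - (f(lo) + f(hi)) / 2) < tol

    x_left, x_right = a, b
    for _ in range(iters):
        m = (x_left + x_right) / 2
        if is_affine(x_left, m):
            x_left = m       # no kink in the left half; discard it
        else:
            x_right = m      # leftmost kink is in the left half
    return (x_left + x_right) / 2
```

As in lemma 13, the interval halves at every step, so the returned midpoint is within (b - a)/2^iters of the leftmost critical point.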



By most, we mean all except a set whose dimension is d -2. By parallel we mean that the vector is orthogonal to the hyperplane's normal. Remember that we assign a (non-degenerate) critical point to a neuron if the value at that neuron, before the ReLU function, is 0. A point x corresponds to the ith neuron in the first layer if ⟨wi, x⟩ + bi = 0, and corresponds to the jth neuron in the second layer if ⟨vj, ϕ(W x + b)⟩ + cj = 0. A critical hyperplane corresponds to some neuron if there is a non-empty open set S ⊆ P where each x ∈ S corresponds to that neuron. That is, for all j ∈ [d1], the distributions of (wj, bj) and (-wj, -bj) are the same.



We denote by B(x, δ) the open ball around x ∈ R^d with radius δ > 0. For w ∈ R^d and b ∈ R we denote by Λ_{w,b} the affine function Λ_{w,b}(x) = ⟨w, x⟩ + b. For a point x ∈ R^d and a set A ⊂ R^d we denote by d(x, A) = inf_{y∈A} ∥x - y∥ the distance between x and A. Given a subspace P, a Gaussian in P is a Gaussian vector x in R^d whose density function is supported in P. We say that it is standard if the projection of x on any line in P that passes through E[x] has variance 1.

Figure 1: An illustration of a one-dimensional piecewise linear function M : R → R.

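The reconstruction of lemma 12 (algorithm 2) is plain finite differencing over d + 1 queries; a minimal sketch (the function name is ours):

```python
import numpy as np

def reconstruct_affine(f, x, eps=1e-4):
    """Recover (w, b) of a function f that is affine on B(x, eps) from
    d + 1 queries: w_i = (f(x + eps*e_i) - f(x)) / eps,  b = f(x) - <w, x>."""
    d = len(x)
    fx = f(x)
    w = np.array([(f(x + eps * np.eye(d)[i]) - fx) / eps for i in range(d)])
    b = fx - w @ x
    return w, b
```

Since f is exactly affine in the queried ball, the finite differences are exact up to floating-point error; no step-size tuning is needed.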


ACKNOWLEDGMENTS

This research is supported by ISF grant 2258/19 and by ERC grant 101041711.

ANNEX

Let C ⊂ [d1] be the set of neurons assigned an incorrect sign. Then, for all x ∈ R^d, the network defined by the recovered parameters differs from M(x) by the sum over j ∈ C of the affine maps u_jΛ_{wj,bj}, which is an affine transformation and can be recovered successfully at the last stage of the algorithm.

As for the time and query complexity, step 1 takes O(d1 log(1/δ)Q) (see section 3.2). Since each neuron corresponds to a single critical point, we have m = d1. Thus the loop in step 2 makes d1 iterations. The cost of each iteration is O(dQ); hence, the total cost of the loop is O(d1dQ). Finally, to perform step 6 we need to make d queries to M, which costs O(dQ), and also d evaluations of the reconstructed neurons.

Proof (of theorem 2). Denote the output of the j'th neuron before the activation by N_j(x) := ⟨w_j, x⟩ + b_j. Let x1, x2 ∈ R^d_+ be two points such that exactly one neuron j ∈ [d1] changes its state (i.e., changes from active to inactive, or vice versa) in the segment [x1, x2]. Moreover, assume that no neuron changes its state in neighborhoods of x1 and x2, so that the change of state happens in the interior of [x1, x2]. We note that finding such a pair of points x1, x2 can be done by considering a ray ℓ(ρ) := ρe^(i) and seeking a critical point ρ of the (one-dimensional) function M ∘ ℓ. Under our general position assumptions, for some i ∈ [d] there is such a ρ in R, and it can be found efficiently. Given such a ρ, and again under our general position assumptions, we can take x1 = ℓ(ρ - ϵ) and x2 = ℓ(ρ + ϵ) for small enough ϵ.

We explain next how, given two such points, we can reconstruct the j'th neuron up to an affine function. First, the reconstruction of u_j is simple. Indeed, in the segment [x1, x2], M′(x) := M(x) - u_jϕ(⟨w_j, x⟩ + b_j) is affine, as no neuron except the j'th one changes its state. Hence, M(x) = u_jϕ(⟨w_j, x⟩ + b_j) + M′(x) is the sum of an affine function and the j'th neuron. In particular, it is convex iff the j'th neuron is convex iff u_j = 1.
Hence, to reconstruct u_j we only need to check whether the restriction of M to [x_1, x_2] is convex or concave.

We next explain how to reconstruct an affine map Λ such that ϕ(Λ(x)) - ϕ(M_j(x)) is affine. Let Λ_1, Λ_2 : R^d → R be the affine maps computed by the network in the neighborhoods of x_1 and x_2 respectively. Note that it is straightforward to reconstruct Λ_i from the values M(x_i), M(x_i + ϵe_1), ..., M(x_i + ϵe_d), for small enough ϵ. We have that Λ := Λ_1 - Λ_2 is either M_j or -M_j. Hence, ϕ(Λ(x)) is either ϕ(M_j(x)) or ϕ(-M_j(x)) = ϕ(M_j(x)) - M_j(x). After removing all the neurons, we are left with an affine map that can be reconstructed easily using O(d) queries as explained above, and the full reconstruction of the network is complete.

Proof. (of Theorem 3) Recall that our goal is to recover a δ-regular network of the form given in the theorem. We denote by w_j the j'th row of W and assume without loss of generality that it is of unit norm, as any neuron of the form x ↦ ϕ(⟨w, x⟩ + b) can be replaced by x ↦ ∥w∥ϕ(⟨w/∥w∥, x⟩ + b/∥w∥). Likewise, and similarly to our algorithm for reconstructing depth-two networks, we will assume that u ∈ {±1}^{d_2}. The first step of the algorithm is to find a list of candidate neurons; we assume that for any j, u_j ≠ 0 (otherwise the corresponding neuron can be dropped). At any x where M is differentiable, the partial derivative ∂M/∂x_i(x) is of the form Σ_{j=1}^{d_1} z_j w_j(i) for an activation-pattern vector z ∈ {0,1}^{d_1}. Now, if the weights are random, say the w_j's are independent random variables such that w_j(i) has a continuous distribution, then w.p. 1 we have that for every non-zero vector z ∈ {0,1}^{d_1} it holds that Σ_{j=1}^{d_1} z_j w_j(i) ≠ 0, and hence ∂M/∂x_i(x) ≠ 0, unless the vector Λ(x) := (⟨w_1, x⟩ + b_1, ..., ⟨w_{d_1}, x⟩ + b_{d_1}) lies in the negative orthant R^{d_1}_-. It follows that the non-zero partial derivatives assumption holds if and only if the affine map Λ maps R^d into the complement of the negative orthant R^{d_1}_-. The following lemma shows that if d_1 ≫ d, then this is often the case.

Lemma 15. Assume that the pairs (w_j, b_j) are independent and symmetric (i.e., (w_j, b_j) and -(w_j, b_j) are identically distributed). Then Pr[Λ(R^d) ∩ R^{d_1}_- ≠ ∅] ≤ 2^{-d_1} Σ_{i=0}^{d+1} (d_1 choose i).

Proof.
We first note that the number of orthants that have a non-empty intersection with Λ(R^d) is exactly the number of functions in the class H := {j ↦ sign(⟨w_j, x⟩ + b_j) : x ∈ R^d}. Since Λ(R^d) is an affine space of dimension at most d, H has VC dimension at most d + 1 (e.g. Anthony & Bartlett (1999)). Hence, by the Sauer-Shelah lemma (again, Anthony & Bartlett (1999)), |H| ≤ Σ_{i=0}^{d+1} (d_1 choose i). Finally, since the (w_j, b_j)'s are symmetric, the probability that Λ(R^d) intersects R^{d_1}_- is the same as the probability that it intersects any other orthant. Since there are 2^{d_1} orthants, and Λ(R^d) intersects at most Σ_{i=0}^{d+1} (d_1 choose i) of them, the claimed bound follows.

All in all, we get the following corollary:

Theorem 5. Assume that the pairs (w_j, b_j) are independent, symmetric, and have continuous marginals. Then, w.p. at least 1 - 2^{-d_1} Σ_{i=0}^{d+1} (d_1 choose i), we have that ∂M/∂x_i(x) ≠ 0 for all x ∈ R^d and i ∈ [d].
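To get a quantitative feel for the bound in Lemma 15, it is cheap to evaluate numerically. The following minimal sketch (the function name orthant_bound is ours, not the paper's) simply computes 2^{-d_1} Σ_{i=0}^{d+1} (d_1 choose i):

```python
from math import comb

def orthant_bound(d: int, d1: int) -> float:
    """Sauer-Shelah bound from Lemma 15: an upper bound on the probability
    that the affine image Lambda(R^d) intersects the negative orthant of R^{d1}."""
    return sum(comb(d1, i) for i in range(d + 2)) / 2 ** d1

print(orthant_bound(3, 10))   # → 0.376953125
print(orthant_bound(3, 100))  # ≈ 3.2e-24
```

As Theorem 5 suggests, once d_1 is moderately larger than d the non-zero partial derivatives assumption holds with overwhelming probability.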
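The reconstruction argument in the proof of Theorem 2 can also be exercised numerically. Below is a minimal sketch on a toy randomly drawn network (the sizes d = 3, d_1 = 4 and the helper names slope, breakpoint_on_ray, local_affine are our illustrative choices, not the paper's): it treats M as a black box, bisects along a coordinate ray for a critical point by watching the slope of the piecewise-linear restriction, recovers the two local affine maps from d + 1 queries each, and reads off ±(w_j, b_j) from their difference and u_j from convexity.

```python
import numpy as np

# Toy hidden network (black box from the algorithm's point of view).
rng = np.random.default_rng(0)
d, d1 = 3, 4
W = rng.normal(size=(d1, d))
b = rng.normal(size=d1)
u = rng.choice([-1.0, 1.0], size=d1)

def M(x):
    # Oracle: M(x) = sum_j u_j * phi(<w_j, x> + b_j), phi = ReLU.
    return float(u @ np.maximum(W @ x + b, 0.0))

def slope(rho, i, h=1e-7):
    # One-sided slope of the piecewise-linear map rho -> M(rho * e_i).
    e = np.eye(d)[i]
    return (M((rho + h) * e) - M(rho * e)) / h

def breakpoint_on_ray(i, lo=-50.0, hi=50.0, tol=1e-7):
    # M is affine between critical points, so a slope change between the
    # endpoints certifies a critical point inside the interval; bisect on it.
    s_lo = slope(lo, i)
    assert abs(slope(hi, i) - s_lo) > 1e-4, "no state change on this ray"
    while hi - lo > tol:
        mid = (lo + hi) / 2
        s_mid = slope(mid, i)
        if abs(s_mid - s_lo) > 1e-4:
            hi = mid
        else:
            lo, s_lo = mid, s_mid
    return (lo + hi) / 2

def local_affine(x, eps=1e-6):
    # Recover the affine map computed by M near x from d + 1 queries.
    y0 = M(x)
    g = np.array([(M(x + eps * e) - y0) / eps for e in np.eye(d)])
    return g, y0 - g @ x

rho = breakpoint_on_ray(0)
eps = 1e-4
x1 = (rho - eps) * np.eye(d)[0]
x2 = (rho + eps) * np.eye(d)[0]
g1, c1 = local_affine(x1)
g2, c2 = local_affine(x2)
w_hat, b_hat = g2 - g1, c2 - c1  # equals +/-(w_j, b_j) for some neuron j
# u_j = 1 iff M restricted to [x1, x2] is convex (lies below its chord):
u_hat = 1.0 if M((x1 + x2) / 2) <= (M(x1) + M(x2)) / 2 else -1.0
```

Repeating this over the remaining critical points, subtracting each recovered neuron, and solving for the residual affine map completes the depth-two reconstruction described above.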

