What do deep networks learn and when do they learn it

Scribe notes by Manos Theodosis


Lecture video | Slides (pdf) | Slides (PowerPoint with ink and animation)

In this lecture, we talk about what neural networks end up learning (in terms of their weights) and when, during training, they learn it.

In particular, we’re going to discuss:

  • Simplicity bias: how networks favor “simple” features first.
  • Learning dynamics: what is learned early in training.
  • Different layers: do the different layers learn the same features?

The types of results we will discuss are:

  • Gradient-based deep learning algorithms have a bias toward learning simple classifiers. In particular, this often holds when the optimization problem they are trying to solve is “underconstrained/overparameterized”, in the sense that there are exponentially many different models that fit the data.
  • Simplicity also affects the timing of learning. Deep learning algorithms tend to learn simple (but still predictive!) features first.
  • Such “simple predictive features” tend to be in lower (closer to input) levels of the network. Hence deep learning also tends to learn lower levels earlier.
  • On the other hand, the above means that distributions that do not have “simple predictive features” pose significant challenges for deep learning. Even if there is a small neural network that works very well for the distribution, gradient-based algorithms will not “get off the ground” in such cases. We will see a lower bound for learning parities that makes this intuition formal.

What do neural networks learn, and when do they learn it?

As a first example to showcase what is learned by neural networks, we’ll consider the following data distribution where we sample points (X, Y), with Y \in \{1, -1\} (Y = 1 corresponding to orange points and Y = -1 corresponding to blue points).

If we train a neural network to fit this distribution, we can see below that the neurons closest to the input end up learning features that are highly correlated with the input (mostly linear boundaries at a 45-degree angle, each corresponding to one of the stripes). In subsequent layers, the learned features become more sophisticated and increase in complexity.

Neural networks have simpler but useful features in lower layers

Some people have spent a lot of time trying to understand what is learned by different layers. In a recent work, Olah et al. dig deep into a particular architecture for computer vision, trying to interpret the features learned by neurons at different layers.

They found that earlier layers learn features that resemble edge detectors.

However, as we go deeper, the neurons at those layers start learning more convoluted features (for example, these features from layer 3b resemble heads).

SGD learns simple (but still predictive) features earlier.

There is evidence that SGD learns simpler classifiers first. The following figure tracks how much of a learned classifier’s performance can be accounted for by a linear classifier. We see that up to a certain point in training, all of the performance of the neural network learned by SGD (measured as mutual information with the label or as accuracy) can be ascribed to the linear classifier. The two diverge only very near the point where the linear classifier “saturates,” in the sense that it reaches the best accuracy achievable by linear models. (We use the quantity I(f(x);y \mid L(x)), the mutual information of f(x) and y conditioned on the prediction of the linear classifier L(x), to measure how much of f’s performance cannot be accounted for by L.)
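The conditional mutual information above can be estimated from predictions on held-out data. Below is a minimal plug-in sketch, assuming discrete predictions (e.g., class labels) from f and L; the helper name `conditional_mutual_info` is hypothetical, and this is not necessarily the estimator used to produce the figure.

```python
from collections import Counter
import math

def conditional_mutual_info(f_pred, y, l_pred):
    """Plug-in estimate, in bits, of I(f(x); y | L(x)) from paired discrete samples.

    Uses I(F;Y|L) = sum p(f,y,l) * log2[ p(f,y,l) p(l) / (p(f,l) p(y,l)) ].
    """
    n = len(y)
    c_fyl = Counter(zip(f_pred, y, l_pred))
    c_fl = Counter(zip(f_pred, l_pred))
    c_yl = Counter(zip(y, l_pred))
    c_l = Counter(l_pred)
    total = 0.0
    for (f, label, l), c in c_fyl.items():
        # The 1/n normalizations cancel inside the log, so raw counts suffice there.
        total += (c / n) * math.log2(c * c_l[l] / (c_fl[(f, l)] * c_yl[(label, l)]))
    return total

y = [0, 1, 0, 1]
print(conditional_mutual_info(y, y, [0, 0, 0, 0]))  # L is uninformative: f carries all of H(y)
print(conditional_mutual_info(y, y, y))             # L already explains f: nothing left
```

If f's predictions are a function of L's predictions the estimate is 0; plug-in estimates can be noticeably biased on small samples, so many evaluation points are needed.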

The benefits and pitfalls of simplicity bias

In general, simplicity bias is a very good thing. For example, the most “complex” function is a random function. If, given some observed data \{(x_i,y_i)\}_{i\in [n]}, SGD were to find a random function f that perfectly fits it, then it would never generalize (since for every fresh x, the value of f(x) would be random).

At the same time, simplicity bias means that our algorithms might focus too much on simple solutions and miss more complex ones. Sometimes the complex solutions actually do perform better. In the following cartoon a person could go to the low-hanging fruit tree on the right-hand side and miss the bigger rewards on the left-hand side.

This can actually happen in neural networks. We also saw a simple example in class:

The two datasets are equally easy to represent, but on the right-hand side, there is a very strong “simple classifier” (the 45-degree halfspace) that SGD will “latch onto.” Once it gets stuck with that classifier, it is hard for SGD to get “unstuck.” As a result, SGD has a much harder time learning the right-hand dataset than the left-hand dataset.

Analyzing SGD for over-parameterized linear regression

So, what can we prove about the dynamics of gradient descent? Often we can gain insights by studying linear regression.

Formally, given (x_i, y_i)_{i=1}^n \in \mathbb{R}^{d+1} with d\gg n we would like to find a vector w\in\mathbb{R}^d such that \langle w, x_i\rangle\approx y_i.

In this setting, we can prove that running SGD from zero initialization (or, approximately, from tiny initialization) on the loss \mathcal{L}(w) =\lVert Xw -y \rVert^2 converges to the solution w of minimum norm. To see why, note that SGD performs updates of the form
w_{t+1} = w_t - \eta x_i^\top(\langle x_i, w_t\rangle - y_i)
(absorbing the factor 2 into \eta). Note that \eta(\langle x_i, w_t\rangle - y_i) is a scalar, so every update adds a multiple of some x_i^\top. Therefore, starting from w_0 = 0, all of the updates keep w_{t+1} within \mathrm{span}(x_1^\top, \ldots, x_n^\top). This implies that the limit w_{\infty} also lies in \mathrm{span}(x_1^\top, \ldots, x_n^\top).

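Geometrically, the interpolating solutions \{w : Xw = y\} form an affine subspace parallel to the null space of X, and \mathrm{span}(x_1^\top, \ldots, x_n^\top) is the orthogonal complement of that null space. The unique interpolating solution in this span is therefore the orthogonal projection of the origin onto the solution set, i.e., the least-norm solution w_{\infty} = X^\top(XX^\top)^{-1}y (assuming XX^\top is invertible).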
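A quick numerical check of this argument, as a sketch under assumptions of our own: Gaussian data with d \gg n, and a per-sample step size 1/\lVert x_i\rVert^2 (which makes this SGD variant the randomized Kaczmarz method), chosen only so that it converges quickly. Starting from zero, the iterate should match the closed-form least-norm solution X^\top(XX^\top)^{-1}y.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 20, 200                          # overparameterized: d >> n
X = rng.standard_normal((n, d))         # rows are the samples x_i
y = rng.standard_normal(n)

# SGD from zero initialization: each step uses one random sample i and
# moves w along x_i, so w always stays in span(x_1, ..., x_n).
w = np.zeros(d)
for _ in range(5000):
    i = rng.integers(n)
    eta = 1.0 / (X[i] @ X[i])           # per-sample step size (randomized Kaczmarz)
    w -= eta * X[i] * (X[i] @ w - y[i])

# Closed-form least-norm interpolating solution X^T (X X^T)^{-1} y.
w_min_norm = X.T @ np.linalg.solve(X @ X.T, y)

print(np.linalg.norm(X @ w - y))        # close to 0: w interpolates the data
print(np.linalg.norm(w - w_min_norm))   # close to 0: w is the least-norm solution
```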

Analyzing the dynamics of full-batch gradient descent (on \tfrac{1}{2}\lVert Xw - y\rVert^2, with Xw_{\infty} = y), we can write the difference between the current iterate and the limit as
w_{t+1} - w_{\infty} = (I - \eta X^\top X)(w_t - w_{\infty}).
We see that we are applying the linear operator (I - \eta X^\top X) at every step we take. As long as this operator is contractive, we will continue to make progress and converge to w_{\infty}. Formally, writing \lambda_1 for the largest and \lambda_d for the smallest (nonzero) eigenvalue of X^\top X, it suffices that
0 \preceq I -\eta X^\top X \prec I.
This directly translates into \eta \leq \frac{1}{\lambda_1}; with \eta = \frac{1}{\lambda_1}, the error along the slowest direction shrinks by a factor of 1 - \frac{\lambda_d}{\lambda_1} = 1 - \frac{1}{\kappa} per step, where \kappa = \lambda_1/\lambda_d is the condition number of X^\top X (the square of the condition number of X).
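This rate can be checked numerically. The sketch below assumes Gaussian data with d < n and noiseless targets, so that w_{\infty} equals a planted vector `w_star` (our name); it runs full-batch gradient descent with \eta = 1/\lambda_1 and compares the relative error after T steps with the worst-case factor (1 - 1/\kappa)^T.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d = 200, 50                      # well-posed case: d < n
X = rng.standard_normal((n, d))
w_star = rng.standard_normal(d)
y = X @ w_star                      # noiseless targets, so w_inf = w_star

lam = np.linalg.eigvalsh(X.T @ X)   # eigenvalues in ascending order
lam_1, lam_d = lam[-1], lam[0]
kappa = lam_1 / lam_d
eta = 1.0 / lam_1

# Per-step contraction factor of I - eta X^T X along each eigenvector.
factors = 1 - eta * lam

# Full-batch gradient descent on (1/2)||Xw - y||^2 from zero.
T = 100
w = np.zeros(d)
for _ in range(T):
    w -= eta * X.T @ (X @ w - y)

error_ratio = np.linalg.norm(w - w_star) / np.linalg.norm(w_star)
bound = (1 - 1 / kappa) ** T        # worst-case shrinkage along the slowest eigenvector
print(kappa, error_ratio, bound)
```

The slowest direction dominates the error, which is why the number of steps needed scales with \kappa.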

What happens now if the matrix X is random (for example, with i.i.d. Gaussian entries)? Then, results from random matrix theory (specifically the Marchenko-Pastur distribution) state that

  • if d < n, then the matrix X^\top X has \mathrm{rank}(X^\top X)=d and its eigenvalues are bounded away from 0. This means that the matrix is well conditioned.
  • if d \approx n, then the spectrum of X^\top X extends down toward 0, with the smallest eigenvalues very close to zero, resulting in an ill-conditioned matrix.
  • if d > n, then X^\top X has d - n zero eigenvalues, but its nonzero eigenvalues are bounded away from zero. If we restrict to the subspace of positive eigenvalues (the span of the data, where gradient descent from zero initialization stays), we again achieve a good condition number.
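A small simulation illustrates the three regimes. This sketch assumes Gaussian entries with variance 1/n; with that normalization, Marchenko-Pastur predicts that for large n the smallest nonzero eigenvalue of X^\top X is close to (1 - \sqrt{d/n})^2, which is about 0.25 for d = n/4, near 0 for d = n, and about 1 for d = 4n.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 400

def smallest_nonzero_eig(d):
    # Gaussian entries with variance 1/n keep the spectrum of X^T X of order 1.
    X = rng.standard_normal((n, d)) / np.sqrt(n)
    s = np.linalg.svd(X, compute_uv=False)   # the min(n, d) singular values of X
    return s.min() ** 2                       # smallest nonzero eigenvalue of X^T X (almost surely)

results = {d: smallest_nonzero_eig(d) for d in (100, 400, 1600)}
for d, lam_min in results.items():
    print(f"d/n = {d / n:4.2f}  smallest nonzero eigenvalue = {lam_min:.2e}  "
          f"MP edge (1 - sqrt(d/n))^2 = {(1 - np.sqrt(d / n)) ** 2:.2f}")
```

Taking the singular values of X directly avoids having to threshold the d - n exact zeros of X^\top X when d > n.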

Deep linear networks

We now want to go beyond linear regression and talk about deep networks. As deep networks are very hard to understand, we will start by analyzing a depth-2 network. We will also consider a linear network and omit the nonlinearity, i.e., the model f(x) = A_2A_1x, which we compare with the linear model f(x) = Bx. This might seem strange, as the corresponding linear model has exactly the same expressiveness. However, note that these two models have different parameter spaces. This means that gradient-based algorithms will travel on different paths when optimizing these two models.

Specifically, we can see that the minimum loss attained by the two models will coincide, i.e., \min \mathcal{L}(A_1, A_2) = \min \mathcal{L}(B), but the SGD path and the solution will be different.

We will analyze the gradient flow on these two networks (which is gradient descent with the learning rate \eta \rightarrow 0). We will make the simplifying assumption that A_1 = A_2 = A, with A symmetric and positive semidefinite. Then B = A^2, and A = \sqrt{B}. We will compare the gradient flow of two different loss functions: \mathcal{L}(B) (gradient flow on a linear model) and \tilde{\mathcal{L}}(A) = \mathcal{L}(A^2) (gradient flow on a depth-2 linear model).

Gradient flow on the linear model simply gives \frac{dB(t)}{dt}=-\nabla\mathcal{L}(B(t)). For the deep linear network, let us denote \nabla\mathcal{L}(B) = \nabla\mathcal{L}(A^2) = \nabla and \nabla\tilde{\mathcal{L}}(A) = \tilde{\nabla}, and assume for simplicity that \nabla is symmetric. Since d(A^2) = (dA)A + A(dA), the chain rule gives
\frac{dA(t)}{dt}=-\tilde{\nabla} = -(\nabla A + A\nabla),
since A is symmetric.

We then have
\frac{dA^2(t)}{dt}=\frac{dA(t)}{dt}A + A\frac{dA(t)}{dt} = -(\nabla A^2 + 2A\nabla A + A^2\nabla).

Another way to view the comparison between the models of interest, \frac{dA^2(t)}{dt} and \frac{dB(t)}{dt}, is as follows: let B = A^2 (so A = \sqrt{B}); then for the deep network \frac{dB(t)}{dt} = -\left(\nabla\mathcal{L}(B(t))B + 2\sqrt{B}\,\nabla\mathcal{L}(B(t))\sqrt{B} + B\,\nabla\mathcal{L}(B(t))\right), whereas for the linear model \frac{dB(t)}{dt} = -\nabla\mathcal{L}(B(t)).
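The chain-rule step can be sanity-checked numerically. This sketch assumes the toy loss \mathcal{L}(B) = \tfrac{1}{2}\lVert B - M\rVert_F^2 with a symmetric target M (our choice, so that \nabla\mathcal{L}(B) = B - M is symmetric) and a random positive semidefinite A. For symmetric A, the gradient of \mathcal{L}(A^2) with respect to A is \nabla A + A\nabla, so the velocity of B = A^2 contains, besides the middle term 2\sqrt{B}\nabla\sqrt{B}, the terms \nabla B and B\nabla; the code compares both claims against finite differences and direct computation.

```python
import numpy as np

rng = np.random.default_rng(3)
k = 4
M = rng.standard_normal((k, k))
M = (M + M.T) / 2                 # symmetric target
C = rng.standard_normal((k, k))
A = C @ C.T / k                   # symmetric positive semidefinite factor

def loss(B):                      # L(B) = 1/2 ||B - M||_F^2, so grad L(B) = B - M
    return 0.5 * np.sum((B - M) ** 2)

def loss_tilde(A):                # depth-2 parametrization: L~(A) = L(A^2)
    return loss(A @ A)

G = A @ A - M                     # grad L at B = A^2 (symmetric here)
grad_formula = G @ A + A @ G      # claimed gradient of L~ at symmetric A

# Central finite differences over all k*k entries of A.
eps = 1e-6
grad_fd = np.zeros_like(A)
for i in range(k):
    for j in range(k):
        E = np.zeros_like(A)
        E[i, j] = eps
        grad_fd[i, j] = (loss_tilde(A + E) - loss_tilde(A - E)) / (2 * eps)

# Velocity of B = A^2 when A follows gradient flow on L~.
A_dot = -grad_formula
B_dot = A_dot @ A + A @ A_dot

# Same velocity written in terms of B and its PSD square root.
B = A @ A
vals, vecs = np.linalg.eigh(B)
sqrtB = vecs @ np.diag(np.sqrt(np.clip(vals, 0, None))) @ vecs.T
B_dot_formula = -(G @ B + 2 * sqrtB @ G @ sqrtB + B @ G)

print(np.max(np.abs(grad_fd - grad_formula)), np.max(np.abs(B_dot - B_dot_formula)))
```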
We can view this as follows: when we multiply the gradient by \sqrt{B}, we end up making “the big bigger and the small smaller.” Basically, this accentuates the differences between the eigenvalues and biases B toward becoming a low-rank matrix.

To see why, you can think of a low-rank matrix as one that has a few large eigenvalues and the others small. If B is already close to low rank, then replacing a gradient by a term such as \sqrt{B}\nabla\mathcal{L}(B(t))\sqrt{B} encourages the gradient steps to mostly happen in the top eigenspace of B. This result generalizes to networks of greater depth, where B evolves as \frac{dB(t)}{dt} = -\psi_{B(t)}(\nabla\mathcal{L}(B(t))), with \psi_{B}(\nabla) a sum of terms of the form B^{\alpha}\nabla B^{\beta}, where \alpha, \beta \geq 0 and \alpha + \beta = 2(N-1)/N for depth N (so \beta = 1 - \alpha for depth 2).
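A toy simulation illustrates this “big bigger, small smaller” effect. As a sketch under simplifying assumptions of our own (not from the lecture): the target M is diagonal with eigenvalues 1 and 0.1, the factor A is diagonal with small initialization, and we run plain gradient descent with a small step size on \mathcal{L}(B) = \tfrac{1}{2}\lVert B - M\rVert_F^2. In this commuting case the gradient with respect to A is a multiple of A\nabla (namely 2A\nabla). The linear model makes the same relative progress in every eigendirection, while the depth-2 model learns the top eigendirection long before the small one.

```python
import numpy as np

m = np.array([1.0, 0.1])   # eigenvalues of a diagonal PSD target M
eta, T = 0.05, 150

# Linear model: gradient descent on L(B) = 1/2 ||B - M||_F^2 from B = 0.
b_lin = np.zeros(2)
# Depth-2 model B = A^2 with A = diag(a) and small initialization.
a = np.full(2, 1e-3)
for _ in range(T):
    b_lin -= eta * (b_lin - m)
    grad = a ** 2 - m               # grad L(B) at B = A^2 (diagonal)
    a -= eta * 2 * a * grad         # gradient of L(A^2) w.r.t. A: 2 A grad when everything commutes
b_deep = a ** 2

print("linear model, fraction of each eigenvalue learned:", b_lin / m)
print("depth-2 model, fraction of each eigenvalue learned:", b_deep / m)
```

Each depth-2 coordinate grows roughly like a logistic curve with rate proportional to its target eigenvalue, so larger eigenvalues are fitted first; this is a toy illustration, not a general proof.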

This means that we end up doing gradient flow on a Riemannian manifold. An interesting result is that the flow induced by the operator \psi_{B} is provably not equivalent to a regularized minimization problem \min\mathcal{L} + \lambda R(B) for any R(\cdot).

What is learned at different layers?

Finally, let’s discuss what is learned by the different layers in a neural network. Some intuition people have is that learning proceeds roughly like the following cartoon:

We can think of our data as being “built up” as a sequence of choices from higher level to lower level features. For example, the data is generated by first deciding that it would be a photo of a dog, then that it would be on the beach, and finally low-level details such as the type of fur and light. This is also how a human would describe this photo. In contrast, a neural network builds up the features in the opposite direction. It starts from the simplest (lowest-level) features in the image (edges, textures, etc.) and gradually builds up complexity until it finally classifies the image.

How do neural networks learn features?

To build a bit of intuition, consider an example of combining different simple features. We can see that if we try to combine two good edge detectors with different orientations, the end result will hardly be an edge detector.

So the intuition is that there is competitive/evolutionary pressure on neurons to “specialize” and recognize useful features. Initially, all the neurons are random features, which can be thought of as random linear combinations of the various detectors. However, after training, the symmetry between the neurons breaks, and they specialize (in this simple example, each becomes either a vertical or a horizontal edge detector).

Raghu, Gilmer, Yosinski, and Sohl-Dickstein tracked the speed at which the features learned by different layers reach their final learned state. In the figure below, the diagonal elements denote the similarity of the current state of a layer to its final one, where a lighter color means that the state is more similar. We can see that earlier layers (more to the left) reach their final state earlier (with the exception of the 2 layers closest to the output, which also converge very early).

The “symmetry breaking” intuition is explored by a recent work of Frankle, Dziugaite, Roy, and Carbin. Intuitively, because the average of two good features is generally not a good feature, averaging the weights of two neural networks with small loss will likely result in a network with large loss. That is, if we start from two random initializations w_0, w'_0 and train two networks until we reach weights w_\infty and w'_\infty with small loss, then we expect the average of w_\infty and w'_\infty to result in a network with poor loss:

In contrast, Frankle et al. showed that sometimes, when we start from the same initialization (especially after pruning) and use different SGD noise (obtained by randomly shuffling the training set), we reach a “linear plateau” of the loss function in which averaging two networks yields a network with similar loss:
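One simple way to quantify whether two trained weight vectors are “linearly connected” is to evaluate the loss along the segment between them. The sketch below uses a hypothetical helper `interpolation_barrier` and a toy one-dimensional double-well loss (our own choices, not the exact measure or setup of Frankle et al.): the barrier is large between the two different minima and essentially zero between two points in the same basin.

```python
import numpy as np

def interpolation_barrier(loss_fn, w_a, w_b, num=101):
    """Hypothetical helper: largest excess of the loss along the segment
    (1 - t) w_a + t w_b over the straight line between the endpoint losses."""
    loss_a, loss_b = loss_fn(w_a), loss_fn(w_b)
    excess = [
        loss_fn((1 - t) * w_a + t * w_b) - ((1 - t) * loss_a + t * loss_b)
        for t in np.linspace(0.0, 1.0, num)
    ]
    return max(excess)

def double_well(w):
    # Toy loss with two equally good minima at w = -1 and w = +1.
    return float(np.sum((w ** 2 - 1) ** 2))

print(interpolation_barrier(double_well, np.array([-1.0]), np.array([1.0])))  # different basins
print(interpolation_barrier(double_well, np.array([1.0]), np.array([1.1])))   # same basin
```

For real networks, `loss_fn` would evaluate the training or test loss at flattened weights; averaging two networks corresponds to t = 1/2.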

The contrapositive of simplicity: lower bounds for learning parities

If we believe that networks learn simple features first, and learn them in the early layers, then this has an interesting consequence. If the data has the form that simple features (e.g. linear or low degree) are completely uninformative (have no correlation with the label) then we may expect that learning cannot “get off the ground”. That is, even if there exists a small neural network that can learn the class, gradient based algorithms such as SGD will never find it. (In fact, it is possible that no efficient algorithm could find it.) There are some settings where we can prove such conjectures. (For gradient-based algorithms that is; proving this for all efficient algorithms would require settling the P vs NP question.)

We discuss one of the canonical “hard” examples for neural networks: parities. Formally, for I\subseteq [d], D_I is the distribution over (x,y) \in \{\pm 1\}^{d+1} defined as follows: x is uniform over \{\pm 1\}^d and y = \prod_{i\in I}x_i. The “learning parity” problem is as follows: given n samples \{(x_i,y_i)\}_{i=1}^n drawn from D_I, either recover I or do the weaker task of finding a predictor f such that f(x)=y with high probability over future samples (x,y) \sim D_I.

It turns out that if we don’t restrict ourselves to deep learning, then given 2d samples we can recover I. Consider the transformations Z_{i,j} = (1 - x_{i,j})/2 and b_i = (1 - y_i)/2, where x_{i,j} is the j-th coordinate of the i-th sample. If we let s_j=1 if j\in I and 0 otherwise, we can write \sum_j Z_{i, j}s_j = b_i \pmod 2. Basically, we transformed the parity problem into the problem of counting whether there is an odd or an even number of -1’s among the coordinates in I. In this setting, we can think of every sample (x,y) \sim D_I as providing a linear equation modulo 2 over the d unknown variables s_1,\ldots,s_d. When n is sufficiently larger than d (e.g., n = 2d), these linear equations are very likely to have full rank d, and hence we can use Gaussian elimination to find s_1,\ldots,s_d and hence I.
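The Gaussian-elimination approach can be sketched in a few lines. This is an illustrative implementation (the function names `sample_parity` and `recover_parity_set` are ours) that forms Z and b as above and row-reduces the augmented system modulo 2, where adding two rows is a bitwise XOR; it raises an error in the unlikely event that the equations do not have full rank.

```python
import numpy as np

def sample_parity(I, d, n, rng):
    """Draw n samples from D_I: x uniform on {±1}^d, y = prod_{i in I} x_i."""
    x = rng.choice([-1, 1], size=(n, d))
    y = np.prod(x[:, I], axis=1)
    return x, y

def recover_parity_set(x, y):
    """Solve Z s = b (mod 2) with Z = (1 - x)/2, b = (1 - y)/2; return I = {j : s_j = 1}."""
    Z = ((1 - x) // 2).astype(np.uint8)
    b = ((1 - y) // 2).astype(np.uint8)
    M = np.concatenate([Z, b[:, None]], axis=1)   # augmented matrix over GF(2)
    n, d = Z.shape
    pivot_cols, row = [], 0
    for col in range(d):
        if row == n:
            break
        candidates = np.nonzero(M[row:, col])[0]
        if candidates.size == 0:
            continue                               # no pivot in this column
        p = row + candidates[0]
        M[[row, p]] = M[[p, row]]                  # swap the pivot row into place
        mask = M[:, col].astype(bool)
        mask[row] = False
        M[mask] ^= M[row]                          # clear col in every other row (XOR = addition mod 2)
        pivot_cols.append(col)
        row += 1
    if len(pivot_cols) < d:
        raise ValueError("equations are not of full rank; draw more samples")
    s = np.zeros(d, dtype=np.uint8)
    for r, col in enumerate(pivot_cols):
        s[col] = M[r, d]                           # reduced row r reads: s_col = rhs
    return np.nonzero(s)[0].tolist()

rng = np.random.default_rng(0)
d = 30
I = [2, 7, 11, 19]
x, y = sample_parity(I, d, 2 * d, rng)
print(recover_parity_set(x, y))
```

With n = 2d random equations, the probability that they fail to have rank d is at most 2^{d-n}, so the recovery almost always succeeds.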

Switching to neural networks, we can express parities using few ReLUs. In particular, we have seen that we can create a bump function using 4 ReLUs. Therefore, for every k \in \{0,1,\ldots, d\}, there is a combination of four ReLUs that computes a function f_k:\mathbb{R} \rightarrow \mathbb{R} such that f_k(s) = 1 for s=k, and f_k(s) = 0 if |s-k|>0.5. Since \sum_{k \text{ odd}} f_k(s) equals 1 exactly when the integer s is odd, we can then write the parity function (for example for I=[d]) as \chi_{[d]}(x) = 1 - 2\sum_{k \text{ odd},\, k \leq d} f_k\left(\sum_{i=1}^d (1-x_i)/2\right). This will be a linear combination of at most 4d ReLUs (plus a constant).
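The construction can be checked by brute force. The sketch below (function names ours) builds each f_k from four ReLUs as a trapezoid that equals 1 on [k - 1/4, k + 1/4] and 0 outside (k - 1/2, k + 1/2), one standard way to get a 4-ReLU bump and not necessarily the exact one from lecture. Since the sum over odd k is 1 exactly when the number of -1 coordinates is odd, the \pm 1 parity is 1 minus twice that sum, and the code verifies this on all of \{\pm 1\}^5.

```python
import itertools
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def bump(s, k):
    """f_k(s) from 4 ReLUs: 1 for |s - k| <= 1/4, 0 for |s - k| >= 1/2, linear in between."""
    return 4 * (relu(s - k + 0.5) - relu(s - k + 0.25) - relu(s - k - 0.25) + relu(s - k - 0.5))

def parity_relu_net(x):
    """Parity of all d coordinates using 4 * ceil(d/2) ReLUs plus a constant."""
    d = x.shape[-1]
    s = np.sum((1 - x) / 2, axis=-1)                   # number of -1 coordinates
    odd = sum(bump(s, k) for k in range(1, d + 1, 2))  # 1 iff s is odd, else 0
    return 1 - 2 * odd                                 # map {0, 1} to {+1, -1}

X = np.array(list(itertools.product([-1.0, 1.0], repeat=5)))  # all 32 inputs for d = 5
print(np.allclose(parity_relu_net(X), np.prod(X, axis=1)))
```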

Parities are an example of a case where simple features are uninformative. For example, if |I|>1 then for every linear function L:\mathbb{R}^d \rightarrow \mathbb{R},

\mathbb{E}_{(x,y) \sim D_I}[ L(x)y] = 0

in other words, there is no correlation between the linear function and the label.
To see why this is true, write L(x) = \sum_i L_i x_i. By linearity of expectation, it suffices to show that \mathbb{E}_{(x,y) \sim D_I}[ L_ix_i y] = L_i \mathbb{E}_{(x,y) \sim D_I}[ x_i y] = 0. Both x_i and y are just values in \{\pm 1\}. To evaluate the expectation \mathbb{E}[x_i y] we simply need to know the marginal distribution that D_I induces on \{\pm 1\}^2 when we restrict it to these two coordinates. This distribution is just the uniform distribution. To see why this is the case, note that since |I|>1 we can pick a coordinate j\in I \setminus \{i\}, and let’s condition on the values of all coordinates other than i and j. After conditioning on these values, y = \sigma x_i x_j (if i \in I) or y = \sigma x_j (if i \notin I) for some \sigma \in \{\pm 1\}, and x_i,x_j are chosen uniformly and independently from \{\pm 1\}. For every choice of x_i, if we flip x_j then that flips the value of y, and hence the marginal distribution on x_i and y is uniform.
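A brute-force check of this claim for small d, as a sketch (the choices of d, I, and the random linear function are arbitrary): averaging over all 2^d inputs gives \mathbb{E}[x_i y] exactly. It is 0 for every i when |I| > 1, while for |I| = 1 the single coordinate in I is perfectly correlated with the label.

```python
import itertools
import numpy as np

d = 6
X = np.array(list(itertools.product([-1, 1], repeat=d)))   # all 2^d inputs, uniformly weighted

def linear_correlations(I):
    """Exact E[x_i * y] for every coordinate i, where y = prod_{j in I} x_j."""
    y = np.prod(X[:, I], axis=1)
    return X.T @ y / len(X)

print(linear_correlations([0, 2, 5]))   # |I| = 3: every correlation is 0
print(linear_correlations([3]))         # |I| = 1: coordinate 3 has correlation 1

# Consequently E[L(x) y] = sum_i L_i E[x_i y] = 0 for any linear L when |I| > 1.
L = np.random.default_rng(0).standard_normal(d)
y = np.prod(X[:, [0, 2, 5]], axis=1)
print(np.mean((X @ L) * y))             # 0 up to floating-point rounding
```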

This lack of correlation turns out to be a real obstacle for gradient-based algorithms. While small neural networks for parities exist, and Gaussian elimination can recover I, it turns out that gradient-based algorithms such as SGD will fail to find such a network. Parities are hard to learn: even if the capacity of the network is large enough to memorize the training set, it will still perform poorly on a test set. Indeed, we can prove that for every neural network architecture f_w(x), when I is chosen at random, running SGD on \min\lVert f_w(x) -\prod_{i\in I} x_i\rVert^2 will require e^{\Omega(d)} steps. (Note that if we add noise to parities, then Gaussian elimination will fail, and it is believed that no efficient algorithm can learn the distribution in this case. This is known as the learning parity with noise problem, which is also related to the learning with errors problem that is the foundation of modern lattice-based cryptography.)

We now sketch the proof that gradient-based algorithms require exponentially many steps to learn parities, following Theorem 1 of Shalev-Shwartz, Shamir, and Shammah. We think of an idealized setting where we have an unlimited number of samples and use each sample (x,y) \sim D_I only once (this should only make learning easier). We will show that we make very little progress in learning D_I by showing that, for any given w, the part of the expected gradient (over x) that depends on I is exponentially small for a typical I, and hence we make very little progress toward learning I. Specifically, using the notation \chi_I(x)=\prod_{i\in I}x_i, for any w, x, I and every coordinate i of w,

\frac{\partial}{\partial w_i} \left( f_w(x) - \chi_I(x) \right)^2 = 2\left[ f_w(x) \tfrac{\partial}{\partial w_i}f_{w}(x)- \chi_I(x)\tfrac{\partial}{\partial w_i}f_{w}(x) \right]

The term f_w(x) \tfrac{\partial}{\partial w_i}f_{w}(x) is independent of I and so does not contribute toward learning D_I. Hence, intuitively, to show we make exponentially small progress, it suffices to show that for a typical I and every coordinate i, \left(\mathbb{E}_x\left[ \chi_I(x)\tfrac{\partial}{\partial w_i}f_{w}(x) \right] \right)^2 is exponentially small. (That is, even if for a fixed x we make a large step, these steps all cancel out and give us exponentially small progress toward actually learning I.)

Formally, we will prove the following lemma:

Lemma: For every w and every coordinate i of w, if I is chosen uniformly at random among the subsets of [d], then

\mathbb{E}_I \left(\mathbb{E}_x\left[ \chi_I(x)\tfrac{\partial}{\partial w_i}f_{w}(x) \right] \right)^2 = \frac{\mathbb{E}_x \left(\tfrac{\partial}{\partial w_i}f_{w}(x)\right)^2}{2^d}

Proof: Let us fix i and define g(x) = \tfrac{\partial}{\partial w_i}f_{w}(x), viewed as a function of x \in \{\pm 1\}^d. The quantity \mathbb{E}_x[ \chi_I(x)g(x)] can be written as \langle \chi_I,g \rangle with respect to the inner product \langle f,g \rangle = \mathbb{E}_x f(x)g(x). However, \{ \chi_I \}_{I\subseteq [d]} is an orthonormal basis with respect to this inner product. To see this, note that since \chi_I(x) \in \{ \pm 1 \}, \mathbb{E}_x \chi_I(x)^2 = 1 for every I, and for I \neq J, \chi_I(x)\chi_J(x) = (\prod_{i \in I} x_i)(\prod_{j \in J} x_j) = \prod_{k \in I \oplus J} x_k, where I\oplus J is the symmetric difference of I and J. The reason is that x_k^2 = 1 for all k, and so elements that appear in both I and J “cancel out”. Since the coordinates of x are distributed independently and uniformly, the expectation of the product is the product of expectations. This means that as long as I \oplus J is nonempty (i.e., I \neq J), this will be a product of one or more terms of the form \mathbb{E} x_k. Since x_k is uniform over \{ \pm 1 \}, \mathbb{E} x_k = 0, and so we get that if I \neq J, \langle \chi_I,\chi_J \rangle = 0. Finally, these 2^d orthonormal functions span the 2^d-dimensional space of real functions on \{\pm 1\}^d, so they form a basis.

Given the above, by Parseval’s identity,

\mathbb{E}_x g(x)^2 = \langle g,g\rangle = \sum_{I \subseteq [d]} \langle g , \chi_I \rangle^2,

which means that (since there are 2^d subsets of [d]) on average over I, \langle g, \chi_I \rangle^2 = \lVert g \rVert^2 / 2^d, where \lVert g \rVert^2 = \langle g, g \rangle. In other words, \langle g,\chi_I \rangle is typically exponentially small, which is what we wanted to prove.
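The orthonormality and Parseval steps can be verified numerically for small d. In this sketch, a random vector g stands in for the function g from the proof (an assumption purely for illustration; any real function on \{\pm 1\}^d works), and we compute all 2^d correlations \langle g, \chi_I \rangle.

```python
import itertools
import numpy as np

d = 8
rng = np.random.default_rng(0)
X = np.array(list(itertools.product([-1, 1], repeat=d)))                 # all 2^d points of {±1}^d
masks = np.array(list(itertools.product([0, 1], repeat=d)), dtype=bool)  # indicator vector of each I

# chi[I, x] = prod_{i in I} x_i; coordinates outside I are replaced by 1.
chi = np.array([np.prod(np.where(mask, X, 1), axis=1) for mask in masks])

g = rng.standard_normal(len(X))          # an arbitrary function g on {±1}^d
coeffs = chi @ g / len(X)                # <g, chi_I> = E_x[g(x) chi_I(x)] for every I
gram = chi @ chi.T / len(X)              # <chi_I, chi_J> for every pair (I, J)

print(np.allclose(gram, np.eye(2 ** d)))                   # orthonormal
print(np.sum(coeffs ** 2), np.mean(g ** 2))                # Parseval: the two agree
print(np.mean(coeffs ** 2), np.mean(g ** 2) / 2 ** d)      # average over I: ||g||^2 / 2^d
```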
