Understanding generalization requires rethinking deep learning?

Yamini Bansal, Gal Kaplun, and Boaz Barak

(See also paper on arXiv, code on GitLab, upcoming talk by Yamini & Boaz, video of past talk)

A central puzzle of deep learning is the question of generalization. In other words, what can we deduce from the training performance of a neural network about its test performance on fresh, unseen examples? An influential paper of Zhang, Bengio, Hardt, Recht, and Vinyals showed that the answer could be “nothing at all.”

Zhang et al. gave examples where modern deep neural networks achieve 100% accuracy on classifying their training data, but their performance on unseen data may be no better than chance. Therefore we cannot give meaningful guarantees for deep learning using traditional “generalization bounds” that bound the difference between test and train performance by some quantity that tends to zero as the number of datapoints n increases. This is why (to quote their title), Zhang et al. claimed that “understanding deep learning requires rethinking generalization”.

But what if the issue isn’t that we’ve been doing generalization bounds wrong, but rather that we’ve been doing deep learning (or more accurately, supervised deep learning) wrong?

Self Supervised + Simple fit (SSS) learning

To explain what we mean, let’s take a small detour to contrast “traditional” or “end-to-end” supervised learning with a different approach to supervised learning, which we’ll call here “Self-Supervised + Simple fit” or “SSS algorithms.” (While the name “SSS algorithms” is new, the approach itself has a long history and has recently been used with great success in practice; our work gives no new methods—only new analysis.)

The classical or “end-to-end” approach for supervised learning can be phrased as “ask and you shall receive”. Given labeled data, you ask (i.e., run an optimizer) for a complex classifier (e.g., a deep neural net) that fits the data (i.e., outputs the given labels on the given data points) and hope that it will be successful on future, unseen data points as well. End-to-end supervised learning achieves state-of-the-art results for many classification problems, particularly on computer vision datasets such as ImageNet and CIFAR-10.

However, end-to-end learning does not directly correspond to the way humans learn to recognize objects (see also this talk of LeCun). A baby may see millions of images in the first year of her life, but most of them do not come with explicit labels. After seeing those images, a baby can make future classifications using very few labeled examples. For example, it might be enough to show her once what is a dog and what is a cat for her to correctly classify future dogs and cats, even if they look quite different from these examples.

Figure 1: Cartoon of end-to-end vs SSS learning

In recent years, practitioners have proposed algorithms that resemble human learning more closely than end-to-end supervised learning does. Such methods separate the process into two stages. In the first stage, we do representation learning: we use unlabeled data to learn a representation, a complex map (e.g., a deep neural net) that sends the inputs into some “representation space.” In the second stage, we fit a simple classifier (e.g., a linear threshold function) to the representations of the datapoints and the given labels. We call such algorithms “Self-Supervised + Simple fit” or SSS algorithms. (Note that, unlike in other representation-learning based classifiers, the complex representation is “frozen” and not “fine-tuned” in the second stage, where only a simple classifier is fit on top of it.)
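The two stages can be sketched in a few lines of NumPy. This is only an illustration under loud assumptions: PCA stands in for a reconstruction-style self-supervised objective (real SSS pipelines train deep networks), the data is synthetic, and the names `learn_representation`, `fit_simple_classifier`, and `make_blobs` are hypothetical helpers, not anything from the paper's code.

```python
import numpy as np

def learn_representation(x_unlabeled, dim):
    """Stage 1: fit a PCA map from the inputs alone (no labels are used)."""
    mean = x_unlabeled.mean(axis=0)
    # Rows of vt are principal directions, sorted by explained variance.
    _, _, vt = np.linalg.svd(x_unlabeled - mean, full_matrices=False)
    components = vt[:dim]
    return lambda x: (x - mean) @ components.T

def fit_simple_classifier(features, labels, num_classes):
    """Stage 2: least-squares linear classifier on the frozen features."""
    def with_bias(z):
        return np.hstack([z, np.ones((len(z), 1))])
    targets = np.eye(num_classes)[labels]  # one-hot encoding of the labels
    weights, *_ = np.linalg.lstsq(with_bias(features), targets, rcond=None)
    return lambda z: np.argmax(with_bias(z) @ weights, axis=1)

def make_blobs(rng, n):
    """Synthetic data (an assumption): two classes separated along coordinate 0."""
    y = rng.integers(0, 2, size=n)
    x = rng.normal(size=(n, 20))
    x[:, 0] += 6.0 * (2 * y - 1)
    return x, y

rng = np.random.default_rng(0)
x_train, y_train = make_blobs(rng, 500)
x_test, y_test = make_blobs(rng, 500)

represent = learn_representation(x_train, dim=2)  # labels never touched here
classify = fit_simple_classifier(represent(x_train), y_train, num_classes=2)

train_acc = float(np.mean(classify(represent(x_train)) == y_train))
test_acc = float(np.mean(classify(represent(x_test)) == y_test))
print(f"train accuracy {train_acc:.3f}, test accuracy {test_acc:.3f}")
```

The representation is computed once and then frozen: only the small weight matrix in the second stage ever sees the labels.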

While we don’t have a formal definition, a “good representation” should make downstream tasks easier, in the sense of allowing for fewer examples or simpler classifiers. We typically learn a representation via self-supervision, whereby one finds a representation minimizing an objective function that intuitively requires some “insight” into the data. Approaches for self-supervision include reconstruction, where the objective involves recovering data points from partial information (e.g., recovering missing words or pixels), and contrastive learning, where the objective is to find a representation that makes similar points close and dissimilar points far apart (e.g., in Euclidean space).

SSS algorithms have traditionally been used in natural language processing, where unlabeled data is plentiful but labeled data for a particular task is often scarce. But recently SSS algorithms have also been used with great success for vision tasks such as ImageNet and CIFAR-10, where all data is labeled! While SSS algorithms do not yet beat the state-of-the-art supervised learning algorithms, they do get pretty close. SSS algorithms also have other practical advantages over “end-to-end supervised learning”: they can make use of unlabeled data, the representation can be useful for non-classification tasks, and they may have improved out-of-distribution performance. There has also been recent theoretical analysis of contrastive and reconstruction learning under certain statistical assumptions (see Arora et al. and Lee et al.).

The generalization gap of SSS algorithms

In a recent paper, we show that SSS algorithms not only work in practice, but work in theory too.

Specifically, we show that such algorithms (1) have a small generalization gap in practice and (2) provably (under reasonable assumptions) have a generalization gap that tends to zero with the number of samples, with bounds that are meaningful for many modern classifiers on the CIFAR-10 and ImageNet datasets. We consider the setting where all data is labeled, and the same dataset is used for both learning the representation and fitting a simple classifier. The resulting classifier includes the overparameterized representation, and so we cannot simply apply “off the shelf” generalization bounds. Indeed, a priori it’s not at all clear that the generalization gap for SSS algorithms should be small.

To get some intuition for the generalization gap of SSS algorithms, consider the experiment where we inject some label noise into our distribution. That is, we corrupt an \eta fraction of the labels in both the train and test set, replacing them with random labels. Already in the noiseless case (\eta=0), the generalization gap of SSS algorithms is noticeably smaller than that of end-to-end supervised learning. As we increase the noise, the difference becomes starker. End-to-end supervised learning algorithms can always achieve 100% training accuracy, even as the test accuracy deteriorates, since they can “memorize” all the training labels they are given. In contrast, for SSS algorithms, both training and testing accuracy decrease together as we increase the noise, with training accuracy correlating with test performance.
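The noise-injection step can be sketched as follows. This is a minimal illustration, assuming that "random labels" means labels drawn uniformly from all classes (so a replaced label can coincide with the original by chance); the function name `inject_label_noise` and its returned mask are hypothetical conveniences.

```python
import numpy as np

def inject_label_noise(labels, eta, num_classes, rng):
    """Replace a random eta fraction of labels with uniformly random labels.

    Returns the noisy labels and a boolean mask marking the resampled
    positions. A resampled label equals the original with probability
    1 / num_classes (assumption: uniform resampling over all classes).
    """
    labels = np.asarray(labels)
    n = len(labels)
    num_resampled = int(round(eta * n))
    idx = rng.choice(n, size=num_resampled, replace=False)
    noisy = labels.copy()
    noisy[idx] = rng.integers(0, num_classes, size=num_resampled)
    mask = np.zeros(n, dtype=bool)
    mask[idx] = True
    return noisy, mask

rng = np.random.default_rng(0)
clean_labels = rng.integers(0, 10, size=10_000)
noisy_labels, resampled = inject_label_noise(clean_labels, eta=0.05,
                                             num_classes=10, rng=rng)
print("resampled:", int(resampled.sum()),
      "actually changed:", int((noisy_labels != clean_labels).sum()))
```

Keeping the mask around is what later allows accuracy to be measured separately on the corrupted samples and against the original labels.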

Figure 2: Generalization gap of end-to-end and SSS algorithms on CIFAR-10 as a function of noise (since there are 10 classes, 90% noisy samples corresponds to the Zhang et al. experiment). See also interactive version.

Our main theoretical result is a formal proof of the above statement. To do so, we consider training with a small amount of label noise (say \eta=5\%) and define the following quantities:

  • The robustness gap is the amount by which training accuracy degrades between the “clean” (\eta=0) experiment and the noisy one. (In this and all other quantities, the training accuracy is measured with respect to the original uncorrupted labels.)
  • The memorization gap considers the noisy experiment (\eta=5\%) and measures the amount by which performance on the corrupted data samples (where we received the wrong label) is worse than performance on the overall training set. If the algorithm can memorize all given labels, it will be perfectly wrong on the corrupted data samples, leading to a large memorization gap.
  • The rationality gap is the difference between the performance on the corrupted data samples and performance on unseen test examples. For example, if x is an image of a dog, then it measures the difference between the probability that f(x)=\text{"dog"} when (x,\text{"cat"}) is in the training set and the probability that f(x)=\text{"dog"} when x is not in the training set at all. Since intuitively, getting the wrong label should be worse than getting no label at all, we typically expect the rationality gap to be around zero or negative. Formally we define the rationality gap to be the maximum between 0 and the difference above, so it is always non-negative. We think of an algorithm with a significant positive rationality gap as “irrational.”
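The three quantities above can be computed from a handful of accuracies. The sketch below is one reading of the definitions, not the paper's code: `rrm_gaps` and its argument names are hypothetical, all accuracies are measured against the original labels, and `test_acc` is taken as an input because the text only describes it as performance on unseen test examples.

```python
import numpy as np

def accuracy(predictions, labels):
    return float(np.mean(np.asarray(predictions) == np.asarray(labels)))

def rrm_gaps(clean_train_acc, noisy_run_predictions, original_labels,
             corrupted_mask, test_acc):
    """Robustness, memorization and rationality gaps.

    clean_train_acc: training accuracy of the clean (eta = 0) experiment.
    noisy_run_predictions: predictions on the training inputs made by the
        classifier trained on the noisy labels.
    corrupted_mask: True where the training label was corrupted.
    """
    noisy_run_predictions = np.asarray(noisy_run_predictions)
    original_labels = np.asarray(original_labels)
    corrupted_mask = np.asarray(corrupted_mask, dtype=bool)
    train_noisy = accuracy(noisy_run_predictions, original_labels)
    on_corrupted = accuracy(noisy_run_predictions[corrupted_mask],
                            original_labels[corrupted_mask])
    return {
        "robustness": clean_train_acc - train_noisy,
        "memorization": train_noisy - on_corrupted,
        "rationality": max(0.0, on_corrupted - test_acc),
    }

# Toy example: a pure memorizer that reproduces every label it was given.
original = np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1])
mask = np.zeros(10, dtype=bool)
mask[[2, 7]] = True
predictions = original.copy()
predictions[mask] = 1 - original[mask]  # wrong on the two corrupted samples

gaps = rrm_gaps(clean_train_acc=1.0, noisy_run_predictions=predictions,
                original_labels=original, corrupted_mask=mask, test_acc=0.6)
print(gaps)
```

In this toy run the memorizer is right on all clean samples and wrong on both corrupted ones, so the memorization gap is about 0.8, the robustness gap about 0.2, and the rationality gap is clipped to 0.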

By summing up the quantities above, we get the following inequality, which we call the RRM bound:

generalization gap \leq robustness gap + rationality gap + memorization gap
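To see why the inequality holds, it helps to write the sum out. The shorthand below is ours and is an assumption, not notation from the post: \mathrm{Train}(0) and \mathrm{Train}(\eta) are the training accuracies of the clean and noisy experiments, \mathrm{NTrain}(\eta) is the accuracy on the corrupted samples, all measured against the original labels, and \mathrm{Test} is the test accuracy, assumed to be the same quantity in the generalization gap and in the rationality gap. The terms telescope, and the only slack comes from clipping the last difference at zero:

```latex
\begin{align*}
\underbrace{\mathrm{Train}(0) - \mathrm{Test}}_{\text{generalization gap}}
  &= \underbrace{\mathrm{Train}(0) - \mathrm{Train}(\eta)}_{\text{robustness gap}}
   + \underbrace{\mathrm{Train}(\eta) - \mathrm{NTrain}(\eta)}_{\text{memorization gap}}
   + \bigl(\mathrm{NTrain}(\eta) - \mathrm{Test}\bigr) \\
  &\leq \text{robustness gap} + \text{memorization gap}
   + \underbrace{\max\bigl(0,\ \mathrm{NTrain}(\eta) - \mathrm{Test}\bigr)}_{\text{rationality gap}}
\end{align*}
```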

In practice, the robustness and rationality gaps are always small, both for end-to-end supervised algorithms (which have a large generalization gap), and for SSS algorithms (which have a small generalization gap). Thus the main contribution to the generalization gap comes from the memorization gap. Roughly speaking, our main result is the following:

If the complexity of the second-stage classifier of an SSS algorithm is smaller than the number of samples then the generalization gap is small.

See the paper for the precise definition of “complexity,” but it is bounded by the number of bits that it takes to describe the simple classifier (no matter how complex the representation used in the first stage is). Our bound yields non-vacuous results in various practical settings; see the figures below or their interactive version.
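For a rough feel of the "number of bits" quantity, one can count the bits needed to write down a linear classifier and compare them with the number of training samples. This is a crude description-length count, not the paper's complexity measure; the feature width of 512 and the bit-widths are hypothetical, while 50,000 is the size of the CIFAR-10 training set.

```python
def linear_classifier_bits(feature_dim, num_classes, bits_per_weight):
    """Bits needed to write down a linear classifier with one bias per class."""
    num_params = feature_dim * num_classes + num_classes
    return num_params * bits_per_weight

n_train = 50_000  # CIFAR-10 training set size

for bits in (8, 32):  # hypothetical weight precisions
    total = linear_classifier_bits(feature_dim=512, num_classes=10,
                                   bits_per_weight=bits)
    print(f"{bits}-bit weights: {total} bits, "
          f"{total / n_train:.2f} bits per training sample")
```

The count depends only on the second-stage classifier: making the first-stage network deeper or wider leaves it unchanged unless the feature width itself grows.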

Figure 3: Empirical study of the generalization gap of a variety of SSS algorithms on CIFAR-10. Each vertical line corresponds to one model, sorted by generalization gap. The RRM bound is typically near-tight, and our complexity upper bound is often non-vacuous. Use this webpage to interact with figures 3 and 4.
Figure 4: Empirical study of gaps for the ImageNet dataset. Because of limited computational resources, we only evaluated the theoretical bound for two models in this dataset.

What’s next

There are still many open questions. Can we prove rigorous bounds on robustness and rationality? We have some preliminary results in the paper, but there is much room for improvement. Similarly, our complexity-based upper bound is far from tight at the moment, though the RRM bound itself is often surprisingly tight. Our work only applies to SSS algorithms, but people have the intuition that even end-to-end supervised learning algorithms implicitly learn a representation. So perhaps these tools can apply to such algorithms as well. As mentioned, we don’t yet have formal definitions for “good representations,” and the choice of the self-supervision task is still somewhat of a “black art” – can we find a more principled approach?