It is often said that "we don’t understand deep learning" but it is less often clarified what exactly it is that we don’t understand. In this post I try to list some of the "puzzles" of modern machine learning, from a theoretical perspective. This list is neither comprehensive nor authoritative. Indeed, I only started looking at these issues last year, and am very much in the position of not yet fully understanding the questions, let alone potential answers. On the other hand, at the rate ML research is going, a calendar year corresponds to about 10 "ML years"…
Machine learning offers many opportunities for theorists; there are many more questions than answers, and it is clear that a better theoretical understanding of what makes certain training procedures work or fail is desperately needed. Moreover, recent advances in software frameworks have made it much easier to test out intuitions and conjectures. While in the past running training procedures might have required a Ph.D. in machine learning, the "barrier to entry" was recently lowered first to undergraduates, then to high school students, and these days it’s so easy that even theoretical computer scientists can do it 🙂
To set the context for this discussion, I focus on the task of supervised learning. In this setting we are given a training set $S$ of $n$ examples of the form $(x_i,y_i)$ where $x_i \in \mathbb{R}^d$ is some vector (think of it as the pixels of an image) and $y_i \in \{\pm 1\}$ is some label (think of $y_i$ as equaling $+1$ if $x_i$ is the image of a dog and $-1$ if $x_i$ is the image of a cat). The goal in supervised learning is to find a classifier $f$ such that $f(x)=y$ will hold for many future samples $(x,y)$.
The standard approach is to consider some parameterized family of classifiers, where for every vector $w$ of parameters, we associate a classifier $f_w$. For example, we can fix a certain neural network architecture (depth, connections, activation functions, etc.) and let $w$ be the vector of weights that characterizes every network in this architecture. People then run some optimization algorithm such as stochastic gradient descent with the objective of finding the vector $w$ that minimizes a loss function $\mathcal{L}_S(w)$ defined on the training set $S$. This loss function can be the fraction of labels that $f_w$ gets wrong on $S$, or a more continuous loss that also takes into account the confidence level or other parameters of $f_w$. By now this general approach has been successfully applied to many classification tasks, in many cases achieving near-human to super-human performance. In the rest of this post I want to discuss some of the questions that arise when trying to obtain a theoretical understanding of both the powers and the limitations of the above approach. I focus on deep learning, though there are still some open questions even for over-parameterized linear regression.
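As a concrete, deliberately simplified sketch of this recipe, the snippet below assumes a linear family $f_w(x) = \mathrm{sign}(\langle w, x\rangle)$ in place of a deep network, uses the logistic loss as the "more continuous" surrogate, and runs plain stochastic gradient descent on hypothetical synthetic data. The function names and hyperparameters are illustrative choices, not a prescribed method.

```python
import numpy as np

def zero_one_loss(w, X, y):
    """Fraction of labels y_i in {-1, +1} that f_w(x) = sign(<w, x>) gets wrong."""
    return float(np.mean(np.sign(X @ w) != y))

def sgd_logistic(X, y, lr=0.1, epochs=20, seed=0):
    """Plain SGD on the logistic loss log(1 + exp(-y <w, x>)),
    a continuous surrogate for the 0-1 loss."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    w = np.zeros(d)
    for _ in range(epochs):
        for i in rng.permutation(n):          # one pass over S in random order
            margin = y[i] * (X[i] @ w)
            # gradient of log(1 + exp(-margin)) with respect to w
            grad = -y[i] * X[i] / (1.0 + np.exp(margin))
            w -= lr * grad
    return w

# Hypothetical synthetic training set S: labels come from a hidden linear rule.
rng = np.random.default_rng(1)
n, d = 200, 5
X = rng.normal(size=(n, d))
w_true = rng.normal(size=d)
y = np.sign(X @ w_true)

w = sgd_logistic(X, y)
print("training 0-1 loss:", zero_one_loss(w, X, y))
```

Replacing `X @ w` with the output of a neural network (and the hand-written gradient with automatic differentiation) gives the deep learning version of the same loop.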
The generalization puzzle
The approach outlined above has been known and analyzed for many decades in the statistical learning literature. There are many cases where we can prove that a classifier obtained in this way has a small generalization gap, in the sense that if the training set $S$ was obtained by sampling $n$ independent and identically distributed samples from a distribution $D$, then the performance of the classifier on new samples from $D$ will be close to its performance on the training set.
Ultimately, these results all boil down to the Chernoff bound. Think of the random variables $X_1,\ldots,X_n$ where $X_i = 1$ if the classifier makes an error on the $i$-th training example (and $X_i = 0$ otherwise). The Chernoff bound tells us that the probability that $\sum_i X_i$ deviates by more than $\epsilon n$ from its expectation is something like $\exp(-\epsilon^2 n)$, and so as long as the total number of classifiers is less than $2^k$ for $k < \epsilon^2 n$, we can use a union bound over all possible classifiers to argue that (with high probability over the choice of $S$) if we make a $p$ fraction of errors on the training set, the probability we make an error on a new example is at most $p+\epsilon$. We can of course "bunch together" classifiers that behave similarly on our distribution, and so it is enough if there are at most $2^k$ of these equivalence classes. Another approach is to add a "regularizing term" $R(w)$ to the objective function, which amounts to restricting attention to the set of all classifiers $f_w$ such that $R(w) \leq \mu$ for some parameter $\mu$. Again, as long as the number of equivalence classes in this set is less than $2^k$, we can use this bound.
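To make the counting concrete, here is a small calculation that assumes Hoeffding's inequality, $\Pr[|\frac{1}{n}\sum_i X_i - \mathbb{E}X_1| > \epsilon] \leq 2\exp(-2\epsilon^2 n)$ with $X_i$ indicating an error on the $i$-th example, as the specific form of the Chernoff bound (the text only says "something like" $\exp(-\epsilon^2 n)$). It computes the largest $k$ for which a union bound over $2^k$ equivalence classes still succeeds with probability at least $1-\delta$; the values of $n$, $\epsilon$ and $\delta$ are illustrative.

```python
import math

def union_bound_failure(n, eps, k):
    """Upper bound on the probability that ANY of 2**k classifiers has
    |train error - test error| > eps, via Hoeffding plus a union bound."""
    per_classifier = 2.0 * math.exp(-2.0 * eps**2 * n)
    return min(1.0, (2.0 ** k) * per_classifier)

def max_log2_classes(n, eps, delta):
    """Largest integer k such that 2**k classifiers give failure probability <= delta."""
    # Solve 2**k * 2 * exp(-2 eps^2 n) <= delta for k.
    return math.floor((2.0 * eps**2 * n + math.log(delta / 2.0)) / math.log(2.0))

n = 50_000    # tens of thousands of training examples
eps = 0.05    # tolerated generalization gap
print(max_log2_classes(n, eps, delta=0.01))
```

With these numbers the bound only covers classes describable by a few hundred bits, which is one way to see why counting parameters gives vacuous guarantees for networks with millions of weights.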
To a first approximation, the number of classifiers (even after "bunching together") is roughly exponential in the number of parameters, and so these results tell us that as long as the number of parameters is smaller than the number of examples, we can expect to have a small generalization gap and can infer future performance (known as "test performance") from the performance on the training set (known as "train performance"). Once the number of parameters $m$ becomes close to or even bigger than the number of samples $n$, we are in danger of "overfitting", where we could have excellent train performance but terrible test performance. Thus, according to classical statistical learning theory, the ideal number of parameters would be some number below the number of samples $n$, with the precise value governed by the so-called "bias-variance tradeoff".
This is a beautiful theory, but unfortunately the classical theorems yield vacuous results in the realm of modern machine learning, where we often train networks with millions of parameters on mere tens of thousands of examples. Moreover, Zhang et al showed that this is not just a question of counting parameters better: modern deep networks can in fact "overfit" and achieve 100% accuracy on the training set even when given random or arbitrary labels.
The results above in particular show that we can find classifiers that perform great on the training set but terribly on future tests, as well as classifiers that perform terribly on the training set but pretty well on future tests. Specifically, consider an architecture that has the capacity to fit arbitrary labels, and suppose that we train it on a set $S$ of $n$ examples. Then we can find a setting of parameters $w$ that both fits the training set exactly (i.e., satisfies $f_w(x)=y$ for all $(x,y) \in S$) and also satisfies the additional constraint that $f_w(x) = -y$ (i.e., the negation of the label $y$) for every $(x,y)$ in some additional set $T$ of $19n$ pairs drawn from the same distribution. (The set $T$ is not part of the actual training set, but rather an "auxiliary set" that we simply use for the sake of constructing this counterexample; note that we can use $T$ as a means to generate the initial network, which can then be fed into standard stochastic gradient descent on the set $S$.) The network $f_w$ fits its training set $S$ perfectly, but since it effectively corresponds to training with 95% label noise ($19n$ of the $20n$ labels in $S \cup T$ are flipped), it will perform worse than even a coin toss.
In an analogous way, we can find parameters $w'$ that completely fail on the training set $S$ (i.e., $f_{w'}(x) = -y$ for all $(x,y) \in S$), but correctly fit the additional "auxiliary set" $T$. This corresponds to standard training with 5% label noise, which typically yields about 95% of the performance on the noiseless distribution.
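Both constructions can be simulated in a toy setting. The sketch below assumes 1-nearest-neighbour classification as a hypothetical stand-in for an architecture that can fit arbitrary labels (it interpolates any labeling of distinct points), two-dimensional Gaussian inputs, and a linear "true" labeling rule; none of these choices come from the experiments of Zhang et al.

```python
import numpy as np

def one_nn_predict(X_train, y_train, X_query):
    """1-nearest-neighbour prediction: copy the label of the closest training point."""
    # Squared Euclidean distances, shape (num_query, num_train).
    d2 = ((X_query[:, None, :] - X_train[None, :, :]) ** 2).sum(axis=-1)
    return y_train[np.argmin(d2, axis=1)]

def error(X_train, y_train, X, y):
    """Fraction of (x, y) pairs that the 1-NN rule fitted on (X_train, y_train) gets wrong."""
    return float(np.mean(one_nn_predict(X_train, y_train, X) != y))

rng = np.random.default_rng(0)
d, n = 2, 200
w_true = np.array([1.0, -1.0])        # hypothetical "true" labeling rule y = sign(<w_true, x>)

def sample(m):
    X = rng.normal(size=(m, d))
    return X, np.sign(X @ w_true)

X_S, y_S = sample(n)                  # actual training set S
X_T, y_T = sample(19 * n)             # auxiliary set T with |T| = 19n
X_test, y_test = sample(500)          # fresh samples from the same distribution

X_all = np.vstack([X_S, X_T])
# Case 1: fit S exactly and the NEGATED labels on T (19n of 20n labels flipped: 95% noise).
y_case1 = np.concatenate([y_S, -y_T])
# Case 2: negate every label of S and fit T correctly (n of 20n labels flipped: 5% noise).
y_case2 = np.concatenate([-y_S, y_T])

for name, y_all in [("case 1", y_case1), ("case 2", y_case2)]:
    print(name, "train error on S:", error(X_all, y_all, X_S, y_S),
          "test error:", error(X_all, y_all, X_test, y_test))
```

In the first case the training error on $S$ is zero while the test error is well above 50%; in the second the training error on $S$ is 100% while the test error stays small. Nothing in the fitting procedure marks which labels are the "real" ones.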
The above insights break the separation of concerns, or separation of computational problems from algorithms, which we theorists like so much. Ideally, we would like to phrase the "machine learning problem" as a well-defined optimization objective, such as finding, given a set $S$, the vector $w$ that minimizes $\mathcal{L}_S(w)$. Once phrased in this way, we can try to find $w$ with an algorithm that achieves this goal as efficiently as possible.
Unfortunately, modern machine learning does not currently lend itself to such a clean partition. In particular, since not all optima are equally good, we don’t actually want to solve the task of minimizing the loss function in a "black box" way. In fact, many of the ideas that make optimization faster, such as acceleration, lower learning rates, second-order methods and others, yield worse generalization performance. Thus, while minimizing the objective function is somewhat correlated with generalization performance, it is neither necessary nor sufficient for it. This is a clear sign that we don’t really understand what makes machine learning work, and there is still much left to discover. I don’t know what machine learning textbooks in the 2030’s will contain, but my guess is that they will not prescribe running stochastic gradient descent on one of these loss functions. (Moritz Hardt counters that what we teach in ML today is not that far from the 1973 book of Duda and Hart, and that by some measures ML has moved more slowly than other areas of CS.)
The generalization puzzle of machine learning can be phrased as the question of understanding which properties of procedures that map a training set into a classifier lead to good generalization performance with respect to certain distributions. In particular, we would like to understand which properties of natural distributions and of stochastic gradient descent make the latter such a map.
The computational puzzle
Yet another puzzle in modern machine learning arises from the fact that we are able to find the minimum of $\mathcal{L}_S(w)$ in the first place. A priori this is surprising since, apart from very special cases (e.g., linear regression with a square loss), the function $w \mapsto \mathcal{L}_S(w)$ is in general non-convex. Indeed, for almost any natural loss function, the problem of finding $w$ that minimizes $\mathcal{L}_S(w)$ is NP-hard. However, if we look at the computational question in the context of the generalization puzzle above, it might not be as mysterious. As we have seen, the fact that the $w$ we output is a global minimizer (or close to a minimizer) of $\mathcal{L}_S$ is in some sense accidental and far from the most important property of $w$. There are many minima of the loss function that generalize badly, and many non-minima that generalize well.
So perhaps the right way to phrase the computational puzzle is as
"How come we are able to use stochastic gradient descent to find the vector $w$ that is output by stochastic gradient descent?"
which, when phrased like that, doesn’t seem like much of a puzzle after all.
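One standard way to see that the loss of even a tiny network is non-convex, as noted earlier in this section, is permutation symmetry: swapping two hidden units gives a different parameter vector computing the same function, so both are global minimizers, yet their midpoint typically computes a different function and has strictly larger loss, which a convex function cannot do. The sketch below checks this for a hypothetical two-hidden-unit tanh network with made-up weights; it illustrates non-convexity only, not NP-hardness.

```python
import numpy as np

def net(params, X):
    """Two-hidden-unit network f(x) = a_1 tanh(<b_1, x>) + a_2 tanh(<b_2, x>)."""
    a, B = params                     # a has shape (2,), B has shape (2, d)
    return np.tanh(X @ B.T) @ a

def loss(params, X, y):
    """Mean squared error of the network on the data (X, y)."""
    return float(np.mean((net(params, X) - y) ** 2))

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
a_star = np.array([1.0, -2.0])                       # made-up "true" weights
B_star = np.array([[1.0, 0.0, 0.5], [0.0, 1.0, -1.0]])
y = net((a_star, B_star), X)                         # targets that can be fit with zero loss

w1 = (a_star, B_star)                                # a global minimizer
w2 = (a_star[::-1], B_star[::-1])                    # hidden units swapped: same function
mid = ((w1[0] + w2[0]) / 2, (w1[1] + w2[1]) / 2)     # midpoint of the two minimizers

print("loss at w1:", loss(w1, X, y))
print("loss at w2:", loss(w2, X, y))
# Convexity would force the next value to be at most the average of the two above.
print("loss at midpoint:", loss(mid, X, y))
```

At the midpoint both hidden units become identical, so the network collapses to a single tanh unit and can no longer represent the target.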
The off-distribution performance puzzle
In the supervised learning problem, the training samples are drawn from the same distribution as the final test sample. But in many applications of machine learning, classifiers are expected to perform on samples that arise from very different settings. The image that the camera of a self-driving car observes is not drawn from ImageNet, and yet it still needs to (and often can) detect whether or not it is seeing a dog or a cat (at which point it will brake or accelerate, depending on whether the programmer was a dog or cat lover). Another insight into this question comes from a recent work of Recht et al. They generated a new set of images that is very similar to the original ImageNet test set, but not identical to it. One can think of it as generated from a distribution $D'$ that is close to, but not the same as, the original ImageNet distribution $D$. They then checked how well neural networks that were trained on the original ImageNet distribution $D$ perform on $D'$. They saw that while these networks performed significantly worse on $D'$ than they did on $D$, their performance on $D'$ was highly correlated with their performance on $D$. Hence doing better on $D$ did correspond to being better in a way that carried over to the (very closely related) $D'$. (However, the networks did perform worse on $D'$, so off-distribution performance is by no means a full success story.)
Coming up with a theory that can supply some predictions for learning in a way that is not as tied to the particular distribution is still very much open. I see it as somewhat akin to finding a theory for the performance of algorithms that is somewhere between average-case complexity (which is highly dependent on the distribution) and worst-case complexity (which does not depend on the distribution at all, but is not always achievable).
The robustness puzzle
If the previous puzzles were about understanding why deep networks are surprisingly good, the next one is about understanding why they are surprisingly bad. Images of physical objects have the property that if we modify them in some ways, such as changing a small number of pixels, shifting the colors by a few shades, or rotating by a small angle, they still correspond to the same object. Deep neural networks do not seem to "pick up" on this property. Indeed, there are many examples of how tiny perturbations can cause a neural net to think that one image is another, and people have even printed a 3D turtle that most modern systems recognize as a rifle. (See this excellent tutorial, though note that an "ML decade" has already passed since it was published.) This "brittleness" of neural networks can be a significant concern when we deploy them in the wild. (Though perhaps mixing up turtles and rifles is not so bad: I can imagine some people who would normally resist regulations to protect the environment but would support them if they confused turtles with guns.) Perhaps one reason for this brittleness is that neural networks can be thought of as a way of embedding a set of examples in dimension $d$ into dimension $\ell$ (where $\ell$ is the number of neurons in the penultimate layer) in a way that makes the positive examples linearly separable from the negative examples. Amplifying small differences can help in achieving such a separation, even if it hurts robustness.
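A well-known back-of-the-envelope illustration of how small per-pixel changes can add up (the reasoning behind the "fast gradient sign" attack) already works for a linear score $\langle w, x\rangle$: changing every coordinate by at most $\epsilon$ in the direction $-y\,\mathrm{sign}(w)$ lowers the margin by $\epsilon\|w\|_1$, which in this toy setup grows linearly with the dimension, while a typical margin grows only like its square root. The sketch assumes random weights and uniformly random "pixels" as hypothetical stand-ins for a trained model and a real image, and it does not clip the perturbed input back to $[0,1]$.

```python
import numpy as np

def worst_case_linf_perturbation(w, y, eps):
    """For the linear score <w, x>, the perturbation with L-infinity norm eps that most
    decreases the margin y*<w, x> is delta = -eps * y * sign(w)."""
    return -eps * y * np.sign(w)

rng = np.random.default_rng(0)
d = 100_000                          # e.g. number of pixels
w = rng.normal(size=d)               # hypothetical weights of a linear classifier
x = rng.uniform(0.0, 1.0, size=d)    # hypothetical "image" with pixel values in [0, 1]
y = np.sign(w @ x)                   # the label the classifier currently assigns

eps = 0.01                           # move each pixel by at most 1% of its range
x_adv = x + worst_case_linf_perturbation(w, y, eps)   # not clipped back to [0, 1]

print("margin before:", y * (w @ x))
print("margin after: ", y * (w @ x_adv))   # lower by eps * ||w||_1
print("label flipped:", np.sign(w @ x_adv) != y)
```

Each pixel moves by at most 1% of its range, yet the predicted label flips: a linear-model analogue of the brittleness described above.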
Recent works have attempted to rectify this by using variants of the loss function in which the loss on each training example is the maximum error over all allowed perturbations of that example. A priori you would think that while robust training might come at a computational cost, statistically it would be a "win-win", with the resulting classifiers not only being more robust but also better at classifying overall. After all, we are providing the training procedure with the additional information (i.e., "updating its prior") that the label should be unchanged by certain transformations, which should be equivalent to supplying it with more data. Surprisingly, robust classifiers currently perform worse than standard classifiers on unperturbed data. Ilyas et al argued that this may be because even if humans ignore information encoded in, for example, whether the intensity level of a pixel is odd or even, this information may still be predictive of the label. Suppose that (with no basis whatsoever, just as an example) cat owners are wealthier than dog owners and hence cat pictures tend to be taken with higher-quality lenses. One could imagine that a neural network would pick up on that and use some of the fine-grained information in the pixels to help in classification. When we force such a network to be robust, it would perform worse. The Distill journal published six discussion pieces on the Ilyas et al paper. I like the idea of such "paper discussions" very much and hope it catches on in machine learning and beyond.
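For intuition about the robust loss, consider a linear model with perturbations bounded in $\ell_\infty$ norm by $\epsilon$ (both are assumptions made for this sketch, not the setting of the works above). Then the inner maximum of the logistic loss has a closed form, obtained by lowering each margin $y\langle w,x\rangle$ by $\epsilon\|w\|_1$, which the code checks against brute-force enumeration of the corners of the perturbation box. For deep networks no such closed form is available, and the inner maximization is typically approximated, for instance with a few steps of projected gradient ascent.

```python
import itertools
import numpy as np

def logistic(margin):
    """Logistic loss as a function of the margin y*<w, x>."""
    return np.log1p(np.exp(-margin))

def robust_logistic_loss(w, X, y, eps):
    """Mean over examples of max_{||delta||_inf <= eps} logistic(y*<w, x + delta>).
    For a linear model the worst case lowers every margin by eps * ||w||_1."""
    margins = y * (X @ w)
    return float(np.mean(logistic(margins - eps * np.abs(w).sum())))

def brute_force_robust_loss(w, X, y, eps):
    """Same quantity by enumerating the 2^d corners of the L-infinity ball (tiny d only).
    Corners suffice: the margin is linear in delta and the loss decreases in the margin."""
    d = X.shape[1]
    corners = np.array(list(itertools.product([-eps, eps], repeat=d)))
    worst = [max(logistic(yi * ((xi + c) @ w)) for c in corners) for xi, yi in zip(X, y)]
    return float(np.mean(worst))

# Hypothetical data and weights.
rng = np.random.default_rng(0)
X = rng.normal(size=(20, 4))
w = rng.normal(size=4)
y = np.sign(X @ w)

standard = float(np.mean(logistic(y * (X @ w))))
robust = robust_logistic_loss(w, X, y, eps=0.1)
print(standard, robust, brute_force_robust_loss(w, X, y, eps=0.1))
```

For the same weights the robust loss is never smaller than the standard loss, so training against it is a more demanding objective than fitting the clean data alone.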
The interpretability puzzle
Deep neural networks are inspired by our brain, and it is tempting to try to understand their internal structure just like we try to understand the brain and see if it has a "grandmother neuron". For example, we could try to see if there is a certain neuron (i.e., gate) in a neural network that "fires" only when it is fed images with certain high-level features (or, more generally, find vectors that have large correlation with the state at a certain layer only when the image has some features). This is also of practical importance, as we increasingly use classifiers to make decisions such as whether to approve or deny bail, whether to prescribe treatment A or B to a patient, or whether a car should steer left or right, and we would like to understand the basis for such decisions. There are beautiful visualizations of neural networks’ decisions and internal structures, but given the robustness puzzle above, it is unclear if these really capture the decision process. After all, if we could change the classification from a cat to a dog by perturbing a tiny number of pixels, in what sense can we explain why the network made one decision rather than the other?
The natural distributions puzzle
Yet another puzzle (pointed out to me by Ilya Sutskever) is to understand what it is about "natural" distributions such as images, texts, etc. that makes them so amenable to learning via neural networks, even though such networks can have a very hard time learning even simple concepts such as parities. Perhaps this is related to the "noise robustness" of natural concepts, which is related to being correlated with low-degree polynomials. Another suggestion could be that, at least for text, human languages are implicitly designed to fit neural networks. Perhaps on some other planets there are languages where the meaning of a sentence completely changes depending on whether it has an odd or an even number of letters…
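The contrast between parities and noise-robust concepts can be quantified with noise sensitivity: the probability that $f(x) \neq f(x')$ when $x'$ is obtained from a uniformly random $x \in \{\pm 1\}^n$ by flipping each bit independently with probability $\delta$. The sketch computes it exactly by enumeration for parity and, as a hypothetical example of a noise-robust concept, majority; for parity the result matches the closed form $(1-(1-2\delta)^n)/2$, since parity changes exactly when an odd number of bits flip.

```python
import itertools
import numpy as np

def parity(P):
    """Product of all coordinates of each row of P (entries in {-1, +1})."""
    return np.prod(P, axis=1)

def majority(P):
    """Sign of the sum of each row; n is odd here, so there are no ties."""
    return np.sign(P.sum(axis=1))

def noise_sensitivity(f, n, delta):
    """Exact Pr[f(x) != f(x')] for uniform x in {-1,+1}^n, where x' flips each bit of x
    independently with probability delta. Enumerates everything, so keep n small."""
    points = np.array(list(itertools.product([-1, 1], repeat=n)))   # all 2^n inputs
    values = f(points)
    total = 0.0
    for pattern in itertools.product([0, 1], repeat=n):
        flips = np.array(pattern)
        k = int(flips.sum())
        weight = delta**k * (1 - delta) ** (n - k)     # probability of this flip pattern
        noisy = points * (1 - 2 * flips)               # negate coordinates where flips == 1
        total += weight * np.mean(f(noisy) != values)
    return float(total)

n, delta = 9, 0.1
print("parity:  ", noise_sensitivity(parity, n, delta))
print("majority:", noise_sensitivity(majority, n, delta))
```

Under mild noise, parity is close to a coin flip while majority changes much less often; majority also has most of its Fourier weight on degree-one terms, whereas parity is a single degree-$n$ monomial.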
The above are just a few puzzles that modern machine learning offers us. Not all of those might have answers in the form of mathematical theorems, or even well-stated conjectures, but it is clear that there is still much to be discovered, and plenty of research opportunities for theoretical computer scientists. In this post I focused on supervised learning, where at least the problem is well defined, but there are other areas of machine learning, such as transfer learning and generative modeling, where we don’t even yet know how to phrase the computational task, let alone prove that any particular procedure solves it. In several ways, the state of machine learning today seems to me similar to the state of cryptography in the late 1970’s. After the discovery of public key cryptography, researchers had highly promising techniques and great intuitions, but still did not really understand even what security means, let alone how to achieve it. In the decades since, cryptography has turned from an art into a science, and I hope and believe the same will happen to machine learning.
Acknowledgements: Thanks to Preetum Nakkiran, Aleksander Mądry, Ilya Sutskever and Moritz Hardt for helpful comments. (In particular, I dropped an interpretability experiment suggested in an earlier version of this post since Moritz informed me that several similar experiments have been done.) Needless to say, none of them is responsible for any of the speculations and/or errors above.