You don't need training data

I have a two-year-old. Like most two-year-olds she is in a stage of development where she picks up words like crazy. As someone involved with machine learning, it's fascinating to observe how she learns things. You can point at a sponge, an object she's obviously never seen before, and say "this is a sponge". Now ask her again in five minutes: "what is this?". "Spon Spon! Songe! Sponge! Spon! Dada it Sponge! Spooooooooge!! DADAAAA!!" It doesn't stop there: show her a highly artistic illustration of a sponge and she'll just as quickly identify that it is still a sponge. How could this be possible?

Machine learning (ML) practitioners come from a field where we are used to needing thousands of training examples to teach a model how to identify simple objects in a scene. But this is not the case for humans. From a purely machine learning perspective, this ability to learn from one example is incredibly interesting. Our own brains seem to be able to learn complex ontological information from a single example, and even to identify illustrations from just one example of the real thing. This is totally baffling! Clearly something is going on in the brain that we aren't capturing in our ML models.

Zero-shot learning is an area of study in ML that tries to answer a simple question: how much knowledge can we gain from a single example of something? This ability of the human mind ought to be possible in machines as well. Investigations in this area might even enable new insights into artificial general intelligence: machines that can reason and think as well as we do.

From an information-theoretic perspective, zero-shot learning seems fundamentally impossible. Consider a picture of a sports car. A 2D image has no way to fully capture everything a sports car is. How can we tell that a sports car is a sports car from just one picture? Clearly there are many makes and models of sports cars. If the picture is taken from the side, it might not even be possible to tell whether the car has two wheels or four. There is clearly missing information when you only consider one example of a thing. That would seem to make the problem impossible: there is not enough information in one example to derive any meaningful generalization.

But that argument treats the example as a data point in isolation, which is not how things work in the real world. We've seen cars before that were not sports cars. We've seen cars from many angles and learned that cars typically have four wheels. We can tell that a sports car has a distinctive body shape we haven't seen on other cars. Using past knowledge we can imagine features of a sports car that we haven't directly observed, based on all the other cars we've seen in the world. Obviously a sports car will have four wheels; it's a car!

In fact we may generalize features of the object from all other experiences we've had in the world. We aren't truly learning from one example. We are learning from one example plus all other examples that came before it. That one example is conditioned on all the knowledge we have about how the world operates, and we can fill in the blanks with strong enough assumptions. This is what enables us to draw really powerful conclusions from what at first looks like tiny amounts of data. This general idea of learning from past experience is called transfer learning, and it is used today to reduce the amount of data needed to train new models by basing these new models on older models that were used to solve related problems.
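To make this concrete, here is a minimal sketch of transfer learning, assuming PyTorch and torchvision are available: an ImageNet-pretrained ResNet-18 stands in for the "past knowledge", its layers are frozen, and only a small new classification head is trained on the new task. The five-class head and the data loader are placeholders for illustration, not anything specific.

```python
import torch
import torch.nn as nn
from torchvision import models

# "Past knowledge": a model already trained on a related problem (ImageNet).
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Freeze everything learned on the old problem...
for param in model.parameters():
    param.requires_grad = False

# ...and attach a fresh head for the new problem (5 classes is arbitrary here).
model.fc = nn.Linear(model.fc.in_features, 5)

# Only the new head's parameters are updated on the (much smaller) new dataset.
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Hypothetical training step on a small new dataset:
# images, labels = next(iter(new_task_loader))
# optimizer.zero_grad()
# loss = criterion(model(images), labels)
# loss.backward()
# optimizer.step()
```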

Can we improve on transfer learning? The problem is that transfer learning is a one-directional process. The past model was trained to solve a specific problem, and it's only a happy coincidence that it turns out to be helpful for a future one. The original problem doesn't know about the future problem; how could it? Could we predict the future? That sounds like nonsense! Except that machine learning is all about prediction.

It seems to me that the best kind of transfer learning would be a learning algorithm that considers not only how to optimally solve the task at hand, but also how well its solution could be reused for tasks that are still unknown (call this "universal generalization performance"). How to accomplish this is still an open research question. It is related to the problem of catastrophic forgetting: how to maintain a model's performance on past tasks as it learns new ones. By building models that avoid catastrophic forgetting, we can improve universal generalization performance, the model's ability to be adapted to new and different tasks.
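One concrete flavor of this idea, in the spirit of the Kirkpatrick et al. paper referenced at the end of this post, is to add a penalty that discourages weights that mattered for past tasks from drifting while a new task is learned. The sketch below is illustrative only: the function name, the parameter dictionaries, and the importance estimates are assumptions, not a prescribed implementation.

```python
import numpy as np

def penalized_loss(task_loss, params, old_params, importance, lam=0.1):
    """Current-task loss plus a quadratic penalty that keeps weights
    important to past tasks close to their previous values.

    params, old_params, importance: dicts of numpy arrays keyed by layer name.
    importance estimates how much each weight mattered for past tasks
    (e.g. a diagonal Fisher information approximation); lam trades off
    performance on the new task against remembering the old ones.
    """
    penalty = 0.0
    for name in params:
        penalty += np.sum(importance[name] * (params[name] - old_params[name]) ** 2)
    return task_loss + (lam / 2.0) * penalty
```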

How can we solve catastrophic forgetting? The most obvious way is to never forget anything. One problem with neural networks is that they are usually parametric models. What is a parametric model, you ask? Parameters are weights or other numbers learned from examples. These parameters largely define the model, but they retain only part of the information contained in the training examples that produced them. In fact, you might end up with a different set of parameters every time you train a model on the same data. The goal of many parametric models is not to capture all the information in the data, but to learn just enough to solve whatever problem the model is being optimized for.
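As a toy illustration of what "parametric" means, here is a sketch (assuming only NumPy) of a linear model fit by gradient descent: after training, two numbers are all that remain of a hundred training examples, and whatever those two numbers didn't capture is simply gone.

```python
import numpy as np

# A toy parametric model: after training, two numbers (w, b) are all that
# remain of the 100 training examples.
rng = np.random.default_rng(0)
x = rng.uniform(0, 10, size=100)
y = 3.0 * x + 2.0 + rng.normal(0.0, 1.0, size=100)   # noisy line

w, b = 0.0, 0.0
for _ in range(2000):                                 # gradient descent on MSE
    pred = w * x + b
    w -= 0.01 * np.mean(2 * (pred - y) * x)
    b -= 0.01 * np.mean(2 * (pred - y))

# Inference uses only the parameters; the training data can be discarded,
# along with any information in it that the parameters didn't capture.
print(w, b, w * 4.2 + b)
```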

If there is such a thing as a "parametric model", there should be such a thing as a "non-parametric model". And there is. The canonical example of a non-parametric model is k-nearest neighbors, or k-NN for short. k-NN is a learning algorithm that barely qualifies as a learning algorithm: in its naive formulation it has no learned parameters at all, and it only consults the training data itself when making decisions.

Consider a dataset that is just a catalog of flowers of different species. All we know about each flower is the length of its stem and the average length of its petals. These features are just numbers you can plot as points (vectors) on a graph. We take the training data and plot it; we'll call this graph a feature space. That's all there is to the learning procedure for k-NN, if you can even call it that. Now how do you do inference? How do you figure out what species a new flower is based on its features? You plot its features in the same feature space as the training examples and look at which flowers are nearby. That's it. All we are answering is: what are this new example's nearest neighbors in the feature space? The k is how many neighbors get to vote on the answer. The simplest version is 1-NN, which means the single nearest flower's species is declared to be the unknown flower's species. Simple enough?
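Here is a minimal sketch of 1-NN on a made-up version of that flower catalog, assuming NumPy; the species labels and measurements are invented for illustration.

```python
import numpy as np

# Hypothetical training "catalog": (stem length cm, average petal length cm)
flowers = np.array([
    [30.0, 2.5],   # species A
    [32.0, 2.8],   # species A
    [15.0, 5.0],   # species B
    [14.0, 5.5],   # species B
])
species = np.array(["A", "A", "B", "B"])

def predict_1nn(query):
    """1-NN: the label of the single closest training point wins."""
    distances = np.linalg.norm(flowers - query, axis=1)  # Euclidean "nearness"
    return species[np.argmin(distances)]

print(predict_1nn(np.array([16.0, 4.8])))  # -> "B"
```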

The problem is that the naive version of k-NN is a super simple algorithm that makes some pretty massive assumptions. Who is to say that the features-to-species connection has a direct "nearness" relationship we can define beforehand? Maybe the boundary between different species of flowers is more complex than simple nearness of features can capture. More complex parametric models can learn pseudo-similarity measures instead of assuming them.

You also have the curse of dimensionality. Imagine your "features" being the pixels of an image or a video. Each one of those features creates a dimension (axis) in the feature space. As the dimensionality grows, the volume of the space grows exponentially, so finding vectors that are genuinely near each other becomes harder, and exact nearest-neighbor search becomes computationally expensive. You can think of it this way: in high-dimensional space there are so many axes to plot points on, so many degrees of freedom, so many places for vectors to occupy, that everything gets the chance to be far away from everything else. Distances also throw away a ton of information: even though you have a rich high-dimensional feature space, a distance is a single number (a scalar), while the vectors it compares could hold thousands of numbers. It's clear that naive k-NN has problems. But we like the non-parametric nature of the model for zero-shot learning, because we can always refer back to the data it was trained on. So how can we do better?
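A quick NumPy experiment illustrates this: as the number of dimensions grows, the nearest and farthest points from a random query end up nearly the same distance away, so "nearness" carries less and less information.

```python
import numpy as np

# As dimensionality grows, the gap between a query's nearest and farthest
# neighbor shrinks: everything becomes roughly equally far from everything else.
rng = np.random.default_rng(0)
for dim in [2, 10, 100, 1000]:
    points = rng.uniform(size=(1000, dim))
    query = rng.uniform(size=dim)
    d = np.linalg.norm(points - query, axis=1)
    print(f"dim={dim:5d}  nearest/farthest distance ratio = {d.min() / d.max():.2f}")
```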

One approach to zero-shot learning is to take the advantages of non-parametric models like k-NN and parametric models like most neural networks and combine them into something that has the strengths of both and the weaknesses of neither. It might be possible to do good zero-shot learning with a purely parametric model too, but it's not something a conventional neural network or other parametric model prioritizes: the only thing a conventional neural net optimizes for is the problem at hand. Can we preserve more information? Better yet, can we explicitly avoid losing information that could be beneficial later, even though it isn't beneficial right now?


Matching Networks for One Shot Learning by Oriol Vinyals et al. is an example of integrating the advantages of non-parametric models with the advantages of neural networks.

Overcoming catastrophic forgetting in neural networks by James Kirkpatrick et al. makes the point that if you want to avoid catastrophic forgetting, maybe you should optimize for something besides just the current problem at hand.

Attalos by Lab41 is a "joint vector space" between words and images that is able to transfer knowledge contained in either modality. This enables use cases like unconstrained image search.

Binah by me takes the idea of a "joint vector space" further and allows arbitrary image queries (e.g. "surfers on a beach in Hawaii").