Neural networks change their predictions as they train, and there has been a recent effort to use those changes to better understand the underlying training data. One of the more notable attempts is an ICLR ’19 paper from Microsoft Research Montreal titled “An Empirical Study of Example Forgetting During Deep Neural Network Learning.” We recommend reading the whole thing, but we want to use it as a jumping-off point for our research, so let’s summarize a few of its key takeaways:
- A forgetting event is when a model that previously classified an example correctly later misclassifies that example.
- A model’s total forgetfulness is the number of times that event occurs.
- The paper demonstrates that examples are very rarely forgotten by chance and that the vast majority of forgetting events are attributable to the content of the data itself.
- When you remove so-called “unforgettable examples” (examples where the model’s prediction does not change) and then retrain the model, its accuracy suffers.
- They also posit that the quality of unforgettable examples is such that they transfer well between architectures, meaning that computing forgetfulness using a smaller network would be a good approximation to computing the same using a larger network.
What we’re most interested in here is thinking of ways we can leverage forgetfulness to improve the training process. That leads us to a tricky question, though: How can we use forgetfulness if it’s computed based on training set predictions?
This is based on real-world concerns. Training data, after all, is completely labeled, and we’re interested in saving time and money on labeling costs here. While computing forgetfulness on training data is interesting, computing it on unseen, unlabeled data has more dramatic implications.
Here’s how we’re going to compute that:
Say that in the course of our training process, we perform 5 different validation steps on an unseen data point. Let’s say the predictions for that particular data point are: class 3, class 5, class 5, class 5, class 4. We’ll be computing that as a “forgetfulness” of 2, as the prediction changed twice: once from class 3 to 5, then again from class 5 to 4.
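To make the counting rule concrete, here is a minimal sketch in Python. The function name `forgetfulness` is our own choice for illustration; the logic simply counts how many times consecutive predictions differ, which needs no labels at all.

```python
from typing import Sequence


def forgetfulness(predictions: Sequence[int]) -> int:
    """Count how often the predicted class changes between consecutive checks."""
    # Compare each prediction with the one that follows it.
    return sum(prev != curr for prev, curr in zip(predictions, predictions[1:]))


# The example from the text: class 3, 5, 5, 5, 4 changes twice.
print(forgetfulness([3, 5, 5, 5, 4]))  # 2
```

Note that a data point predicted only once (or always predicted the same way) scores 0 under this rule.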
Our experiment is as follows:
- We’ll be using CIFAR-10
- We’ll be dividing the training dataset in two: a 40K set and a 10K set
- We’ll be using a ResNet-20 to train on the second set (the smaller, 10K set) for 30 epochs, and we’ll estimate “forgetfulness” on the 40K examples after each of those epochs
- Then, we’ll analyze the impact of removing the most and the least forgotten examples.
- For our purposes, we will not know the labels of the 40K record set while assigning them a forgetfulness score.
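The bookkeeping for a setup like this could be sketched as follows. This is an illustration under assumptions, not our training code: the class name `ForgetfulnessTracker`, the random seed, and the toy predictions are all hypothetical, and the ResNet-20 training loop is left out entirely. The idea is to split the 50K training indices once, then feed the tracker the model’s predictions on the 40K set after every epoch.

```python
import numpy as np


class ForgetfulnessTracker:
    """Accumulates label-free forgetfulness scores across evaluation passes."""

    def __init__(self, num_examples: int) -> None:
        self.scores = np.zeros(num_examples, dtype=np.int64)
        self._last = None  # predictions from the previous pass, if any

    def update(self, predictions) -> None:
        predictions = np.asarray(predictions)
        if self._last is not None:
            # Add 1 wherever the prediction differs from the previous pass.
            self.scores += (predictions != self._last).astype(np.int64)
        self._last = predictions.copy()


# Split 50K training indices into a 40K scoring set and a 10K training set.
rng = np.random.default_rng(0)  # seed chosen arbitrarily for the sketch
indices = rng.permutation(50_000)
scoring_idx, training_idx = indices[:40_000], indices[40_000:]

# Toy stand-in for per-epoch predictions on three unlabeled examples.
demo = ForgetfulnessTracker(num_examples=3)
for epoch_predictions in ([3, 1, 7], [5, 1, 7], [5, 1, 2], [5, 1, 7]):
    demo.update(epoch_predictions)
print(demo.scores)  # [1 0 2]
```

In a real run, `update` would be called once per epoch with the predicted classes for every example in `scoring_idx`, in a fixed order.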
Let’s dig in:
First, we needed to figure out the forgetfulness of our 40K dataset. On the graph below, you’ll see the forgetfulness score on the X-axis and the number of examples with that score on the Y-axis. Essentially, examples closer to the left are less forgettable, while examples on the right have high scores, meaning our model changed its prediction on them more frequently:
Here’s a quick look at some examples that were least forgotten:
And now the opposite:
So now we have a forgetfulness score for each piece of data in our 40K set. What we want to do next is train a model. Here, we’re interested in the model’s accuracy when it’s trained on the most forgotten examples versus the least forgotten examples. We’ll also look at what happens when we pull samples from more of a middle ground.
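One way to carve out these training subsets from a vector of scores is sketched below. The helper `select_by_forgetfulness` and its `mode` names are hypothetical, and the “middle” rule (a window centered in the sorted order) is only our assumption about what a middle-ground sample could look like.

```python
import numpy as np


def select_by_forgetfulness(scores, k: int, mode: str) -> np.ndarray:
    """Return indices of k examples chosen by forgetfulness score."""
    scores = np.asarray(scores)
    if not 0 < k <= len(scores):
        raise ValueError("k must be between 1 and the number of examples")
    # Ascending order: least forgotten first; stable sort keeps ties in input order.
    order = np.argsort(scores, kind="stable")
    if mode == "least":
        return order[:k]
    if mode == "most":
        return order[-k:]
    if mode == "middle":
        start = (len(order) - k) // 2
        return order[start:start + k]
    raise ValueError(f"unknown mode: {mode}")


toy_scores = [0, 4, 1, 7, 2, 0]
print(select_by_forgetfulness(toy_scores, 2, "least"))   # [0 5]
print(select_by_forgetfulness(toy_scores, 2, "most"))    # [1 3]
print(select_by_forgetfulness(toy_scores, 2, "middle"))  # [2 4]
```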
Below, you’ll see those results. The red line shows what happens when we train on the least forgotten data whereas the green shows what happens when we train on the most forgotten examples. Our purple line shows a less extreme approach:
Here, you can see that leaning on the least forgotten examples is much more successful than choosing the most forgotten examples as our training set. This suggests that, given limited datasets or limited training budgets, focusing on the least forgotten examples is likely a smart strategy. In fact, you can see that our “only remembered” strategy excels in the early epochs, until about epoch 10. The most forgotten examples took a long time simply to match the worst accuracy of the more successful approaches.
For a little extra information, let’s look at training loss and test accuracy on those last 10K examples. First, training loss:
As expected, training on forgotten examples takes a very long time to converge. In fact, it only does so when the learning rate is reduced.
Now, let’s look at the test accuracy on those last 10K images:
This leads us to the following: when dealing with a small number of examples, it’s best to select the most representative ones and/or the ones with medium complexity. The most forgotten end of the spectrum contains the most complicated, noisy examples and relying on them solely doesn’t work well.
Next, we want to look at what happens if, instead of training on the most and least forgotten examples, we remove them from training.
Here, the difference is a bit less stark but interesting nonetheless. Removing the most remembered examples (about 11 thousand in all) gave us a training accuracy of 82.4%. Removing the most forgotten (5,000 in total) gave us a lower accuracy of about 79.7%.
In other words, removing 11,000 of the most remembered examples doesn’t have as much of an effect on accuracy as removing 5,000 of the most forgotten examples.
With that baseline, let’s look at how we can apply what we’ve discovered thus far to active learning. We want to start by asking the following: what exactly is the difference between the forgetfulness metric above and an uncertainty-based querying strategy for active learning? (We’ve written a lot about querying strategies and AL before, but here’s a brief refresher if you aren’t up to date: active learning samples subsets of the training set during training. Which samples it ingests is determined by its querying strategy. A “least confidence” strategy, then, means the model will be looking for the examples it’s least confident about predicting.)
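For reference, the three uncertainty scores we compare against below can be written in a few lines of NumPy. This is a sketch of the standard textbook definitions, not our experiment code; the function names are our own, and `probs` is assumed to be an array of softmax outputs with one row per example. Each function returns a score where higher means more uncertain.

```python
import numpy as np


def least_confidence(probs: np.ndarray) -> np.ndarray:
    """1 minus the probability of the top class."""
    return 1.0 - probs.max(axis=1)


def least_margin(probs: np.ndarray) -> np.ndarray:
    """Negative gap between the two most likely classes (small gap = uncertain)."""
    ordered = np.sort(probs, axis=1)
    return -(ordered[:, -1] - ordered[:, -2])


def entropy(probs: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Shannon entropy of the predicted distribution; eps avoids log(0)."""
    return -(probs * np.log(probs + eps)).sum(axis=1)


toy_probs = np.array([
    [0.90, 0.05, 0.05],  # confident
    [0.40, 0.35, 0.25],  # spread out over all classes
    [0.50, 0.50, 0.00],  # tied between two classes
])
print(least_confidence(toy_probs).argmax())  # 1
print(least_margin(toy_probs).argmax())      # 2
print(entropy(toy_probs).argmax())           # 1
```

Even on this toy input the strategies disagree about which example is the most uncertain, which is why it is worth checking how much their rankings actually overlap.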
Interestingly, there was far less overlap than you might expect. Here we’re using the same setup and the same dataset splits as above. After training on 10K images, the chart below shows the overlap between the top 1,000 most forgotten examples and the top 1,000 least confidence, least margin, and highest entropy examples.
But when we get a Venn diagram of where those strategies overlap, we get this:
In other words, there’s a great deal more overlap here than there was with our forgetfulness metric. This means we likely can’t use these querying strategies as proxies for forgetfulness.
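Measuring this kind of overlap comes down to intersecting top-k index sets. Here is a small sketch; `top_k_overlap` is a hypothetical helper, the toy scores are made up purely to show the mechanics, and both score arrays are assumed to follow the convention that higher means “more interesting” to the strategy.

```python
import numpy as np


def top_k_overlap(scores_a, scores_b, k: int) -> int:
    """Number of examples that appear in the top k of both rankings."""
    # Negate so argsort puts the highest scores first; stable sort keeps ties in order.
    top_a = np.argsort(-np.asarray(scores_a, dtype=float), kind="stable")[:k]
    top_b = np.argsort(-np.asarray(scores_b, dtype=float), kind="stable")[:k]
    return len(set(top_a.tolist()) & set(top_b.tolist()))


toy_forgetfulness = [5, 0, 3, 1, 4, 0]
toy_entropy = [0.1, 0.9, 0.8, 0.2, 0.7, 0.3]
print(top_k_overlap(toy_forgetfulness, toy_entropy, k=3))  # 2
```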
So, what happens if we use forgetfulness as our querying strategy? About what you’d expect:
Training on only the “hardest” examples gives us lower accuracy than other, more commonly used querying strategies. So let’s take it in a different direction. If it’s not the best querying strategy, can forgetfulness help us uncover outliers?
To approximate this, we’re taking 10% of that larger portion of CIFAR (read: 4,000 records) and corrupting them. Namely, we’re going to erase random parts of the images or flip them vertically. We get stuff like this:
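For readers who want to try this kind of corruption themselves, a minimal NumPy sketch might look like the following. Several details are assumptions made for illustration: the function names, filling the erased rectangle with zeros, capping the rectangle at half the image size, and picking between erasing and flipping with a 50/50 coin toss.

```python
import numpy as np


def random_erase(image: np.ndarray, rng, max_fraction: float = 0.5) -> np.ndarray:
    """Zero out one random rectangle of an HxWxC image."""
    h, w = image.shape[:2]
    erase_h = rng.integers(1, max(2, int(h * max_fraction) + 1))
    erase_w = rng.integers(1, max(2, int(w * max_fraction) + 1))
    top = rng.integers(0, h - erase_h + 1)
    left = rng.integers(0, w - erase_w + 1)
    out = image.copy()
    out[top:top + erase_h, left:left + erase_w] = 0
    return out


def vertical_flip(image: np.ndarray) -> np.ndarray:
    """Flip an image upside down (reverse the row order)."""
    return image[::-1].copy()


def corrupt_subset(images: np.ndarray, fraction: float, rng):
    """Corrupt a random fraction of images; return the new array and a boolean mask."""
    n = len(images)
    chosen = rng.choice(n, size=int(round(n * fraction)), replace=False)
    out = images.copy()
    for i in chosen:
        # Assumed 50/50 choice between the two corruption types.
        out[i] = random_erase(out[i], rng) if rng.random() < 0.5 else vertical_flip(out[i])
    mask = np.zeros(n, dtype=bool)
    mask[chosen] = True
    return out, mask


rng = np.random.default_rng(0)
images = rng.integers(1, 256, size=(100, 32, 32, 3), dtype=np.uint8)  # stand-in data
corrupted, is_corrupted = corrupt_subset(images, fraction=0.1, rng=rng)
print(is_corrupted.sum())  # 10
```

Keeping the boolean mask around is what lets us later check how many corrupted examples each querying strategy surfaces.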
Now, what are the top 1,000 examples for each querying strategy, and how likely are they to find our corrupted outliers? Traditional, simple uncertainty sampling finds modest amounts: 85 of the 1,000 highest entropy examples are noisy, while least confidence and least margin sampling surface 133 and 134 noisy examples out of 1,000, respectively. For forgetfulness?
The results strongly support our hypothesis that forgetfulness is a great measure for finding outliers and potentially filtering them out before training. There is one problem, however: one must carefully tune a forgetfulness hyper-parameter on their dataset to identify the point at which examples start to become less useful.
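Both ideas here, counting corrupted examples among the top-ranked ones and filtering with a forgetfulness cutoff, can be sketched briefly. The helper names and toy values below are hypothetical, and the cutoff is exactly the hyper-parameter that would need tuning per dataset; nothing in this sketch suggests a good value for it.

```python
import numpy as np


def corrupted_in_top_k(scores, is_corrupted, k: int) -> int:
    """How many of the k highest-scoring examples are known to be corrupted."""
    top = np.argsort(-np.asarray(scores, dtype=float), kind="stable")[:k]
    return int(np.asarray(is_corrupted, dtype=bool)[top].sum())


def keep_below_threshold(scores, threshold: int) -> np.ndarray:
    """Indices of examples whose forgetfulness is at or below the cutoff."""
    return np.flatnonzero(np.asarray(scores) <= threshold)


toy_scores = [0, 6, 1, 5, 2, 0]
toy_corrupted = [False, True, False, True, False, False]
print(corrupted_in_top_k(toy_scores, toy_corrupted, k=2))  # 2
print(keep_below_threshold(toy_scores, threshold=2))       # [0 2 4 5]
```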
Lastly, we want to look at how some of this research might inform non-classification tasks. We’re looking to test whether forgetfulness beats basic querying strategies on an instance segmentation task: detecting pedestrians in the Penn-Fudan dataset.
Now, Penn-Fudan is a bit small at only 170 images, so we did a little pre-training here. Namely, we used the model and hyper-parameter setup given by this informative PyTorch tutorial. They use a pre-trained Faster-RCNN backbone and replace the head to do segmentation instead of detection. Here’s how we broke things down:
- 50 examples for training the network and calculating forgetfulness
- 30 examples as the final test set
- Use forgetfulness metrics to choose and add 30 images (out of a possible 90) to the training dataset and measure performance against random.
Our most forgotten examples often had multiple people, occlusions, or people far in the distance.
Results: Here are the metrics generated by the original PyTorch code in two cases. First, when the examples were selected using our strategy and second, when they were chosen randomly.
As you can see, we pretty much beat random on all fronts. The most important metric to me is AP (area=medium), since it best illustrates the effectiveness of our technique. The strategy picked examples where people were occluded or a little far away (hence, medium-sized bounding boxes), and it was rewarded for it on the test set: 0.361 with random vs. 0.497 with forgetfulness.
In the end, we’re really intrigued by forgetfulness and its ability to detect outliers, as well as its potential to be an additional querying strategy (though there’s plenty of research left to do on that front). We’ll keep you updated on what we find in the coming months. Thanks for reading!