Life-Long Learning — Part c
This is part three of a series on life-long learning. The main resource is the “Stanford CS330 Deep Multi-Task & Meta-Learning” course, specifically lecture 15. The content of this post comes mostly from that lecture, with some parts from me and some possibly from an AI writing assistant.
Let’s continue where we left off in the previous blog post.
We saw at the end of the previous post that negative backward transfer can happen. This raises a question: can we design an algorithm that uses only a very small amount of memory and still avoids that sort of negative backward transfer?
One approach is called Gradient Episodic Memory (GEM), which works as follows:
The idea is to keep an episodic memory, under the assumption that we can only store a very small amount of data per task. Then, when making updates for new tasks, we try to ensure that those updates don’t unlearn the previous tasks. The first part is easy, but the second is a little more difficult. We assume we’re learning a predictor that takes as input an example and a task ID and predicts the label, and that we keep a small memory for each task, perhaps as few as five examples. At each time step, we minimize the loss of the predictor on the data from that time step for the current task. The key part is the added constraint: the new predictor we’re optimizing on task k should not perform worse on previous tasks than the predictor from the previous time step.
We apply this constraint for all of the previous tasks, not just the current one. Basically, we want to find a new predictor that does well on the new task, or the new data point, such that the loss on the previous tasks doesn’t get worse. The constraint is evaluated on the memory we stored for the previous tasks, which assumes that the memory is at least somewhat representative. We could imagine overfitting to that memory and just memorizing those examples, although in practice, if a small update to the model doesn’t make the predictions on some stored data points worse, it probably shouldn’t make much else a lot worse either.
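Concretely, one way to write this down (roughly in the notation of the GEM paper, where \( f_\theta \) is the predictor, \( t \) is the current task, \( \mathcal{M}_k \) is the episodic memory for task \( k \), and \( f_\theta^{\,t-1} \) is the predictor before the current update) is:

$$
\begin{aligned}
\min_{\theta}\ \ & \ell\big(f_\theta(x, t),\, y\big) \\
\text{s.t.}\ \ & \ell\big(f_\theta, \mathcal{M}_k\big) \le \ell\big(f_\theta^{\,t-1}, \mathcal{M}_k\big) \quad \text{for all } k < t,
\end{aligned}
$$

where \( \ell(f_\theta, \mathcal{M}_k) \) is just the average loss of the predictor over the examples stored in memory for task \( k \).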
Now, there’s the question of how we actually implement this constraint in practice. Since we don’t want the losses on previous tasks to get worse, we can add a local linearity assumption: instead of placing a constraint on the loss function itself, we place a constraint on the gradients. We try to ensure that the gradient we apply to the predictor points in the same direction as, or is orthogonal to, the gradient for each previous task.
If the gradients point in the same direction or are orthogonal, then, assuming local linearity, the loss on the previous task won’t increase. If they’re orthogonal, we expect the previous task’s loss to stay roughly the same; if they point in the same direction, or at least somewhat in the same direction, we may even get positive backward transfer, where we improve on the previous tasks. It’s worth noting that this constraint can lead to an infeasible constrained optimization: if the gradients for the current task and a previous task point in opposite directions, it may not be possible to improve on the current task without moving against past tasks. This becomes increasingly likely as the number of past tasks grows, since each one adds another constraint to the optimization.
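To make the gradient version of the constraint concrete, here is a minimal NumPy sketch. GEM itself enforces \( \langle g, g_k \rangle \ge 0 \) for every past task \( k \) by solving a small quadratic program in the dual; as a simplification, this sketch falls back to an A-GEM-style projection onto the average memory gradient whenever any constraint is violated. The function and variable names are placeholders, not code from the paper.

```python
import numpy as np

def constrain_gradient(g, memory_grads):
    """Keep the update direction from conflicting with gradients of past tasks.

    g:            proposed gradient on the current task's data, shape (d,)
    memory_grads: list of gradients, each computed on one past task's episodic memory

    GEM proper solves a quadratic program so that <g_tilde, g_k> >= 0 for every
    past task k; here we only do the simpler A-GEM-style projection onto the
    average memory gradient when some constraint is violated.
    """
    if all(np.dot(g, g_k) >= 0 for g_k in memory_grads):
        return g                                   # no interference, use g as-is
    g_ref = np.mean(memory_grads, axis=0)          # average past-task gradient
    if np.dot(g, g_ref) >= 0:
        return g
    # project g onto the half-space {v : <v, g_ref> >= 0}
    return g - (np.dot(g, g_ref) / np.dot(g_ref, g_ref)) * g_ref
```

In a real training loop, g would come from backpropagation on the current batch, each memory gradient from a forward and backward pass over that task’s stored examples, and the returned gradient would replace g in the optimizer step.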
They evaluated this approach in a few different experiments. They looked at three different lifelong learning problems.
- The first is a sequence of MNIST digit-classification tasks, where each task uses a different permutation of the pixels in the image (see the small sketch after this list).
- The second task is different rotations of MNIST digits.
- The third task is a CIFAR-100 image classification task where each task introduces five new image classes.
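As a side note, the permuted-MNIST stream from the first problem is easy to construct yourself. Below is a small sketch of how such a task sequence is typically built; loading MNIST itself is left out, and the function name is just a placeholder.

```python
import numpy as np

def make_permuted_mnist_tasks(images, labels, num_tasks, seed=0):
    """Build a sequence of permuted-MNIST tasks from one base dataset.

    images: array of shape (n, 784) with flattened MNIST digits
    labels: array of shape (n,) with the digit classes
    Each task applies its own fixed random permutation to the pixel order,
    so all tasks share the same labels but look unrelated at the pixel level.
    """
    rng = np.random.default_rng(seed)
    tasks = []
    for task_id in range(num_tasks):
        perm = rng.permutation(images.shape[1])    # one fixed permutation per task
        tasks.append((images[:, perm], labels, task_id))
    return tasks
```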
The total memory size they assumed is 5,120 examples.
In the top-left plot, we can see that the GEM approach has an average accuracy similar to or higher than training a single model, training independent neural networks for each task from scratch, and some of the other prior methods. We can also see that some of these prior methods have negative backward transfer, meaning they forget the previous tasks, whereas GEM is actually able to get a small amount of positive backward transfer. We also see that there’s very little forward transfer for any of them.
Then the top right plot actually evaluates the accuracy on task one after you train on each additional task. We see that the GEM approach is able to maintain a pretty high performance on task one, whereas, with these other methods, the performance starts to drop as you see more and more tasks.
We can similarly look at the plots for the second problem, the middle row, which is the MNIST rotation setting. Again, we see a somewhat similar trend. One thing that’s different here is that we see a lot more forward transfer, which makes sense: there isn’t a lot shared between different permutations of the pixels, whereas rotated versions of the same images share a lot more across tasks.
Then for CIFAR-100, it’s a similar story. The performance is a bit noisier, but GEM is still able to get less negative backward transfer and higher accuracy than some of the other methods.
When working on lifelong learning, it is important to ensure that the experimental domains being studied reflect real-world problems that are of concern.
You could also use meta-learning to acquire an online learning procedure that avoids forgetting. A couple of works have looked into this and fairly successfully developed update rules that avoid negative backward transfer. If you’re interested in learning more, take a look at the references in the lecture.
Let’s talk about a slightly different variation on the online learning formulation. So far, we looked at a formulation where we’re evaluated on a sequence of data points as we receive them. This setting makes a lot of sense in certain scenarios, especially when we have a stream of data. But if you actually have distinct tasks, it may not make full sense from the standpoint of evaluation, because when you see a new task, you’re evaluating zero-shot performance on it, and it may be very difficult to perform well zero-shot on a completely new task. In many cases, more realistically, you’re given a small amount of data for each new task.

So the picture looks more like learning each task first and then being evaluated on it, rather than being evaluated on it right from the start. What we might hope for in a setting like this is that the first task is learned pretty slowly, and as we see more and more tasks, we’re able to learn more and more quickly. The thing that differs is really the evaluation: instead of measuring performance on every single data point you see, you evaluate performance only after seeing a small amount of data for each new task. This is primarily a difference in the evaluation rather than in the stream of data.
In particular, what this looks like is: for each task you see over time, you observe a small training set for that task; you use some update procedure, like gradient descent or something else, to produce parameters for that task; and then you observe a data point, you’re asked to make a prediction on it, and you observe the label. These last three steps are identical to the standard online learning setting. The difference is that you’re given this initial period to actually try to learn the task from a small amount of data. In this setting, you can define the analog of regret from the online learning setting for this online meta-learning scenario. It’s exactly the same as before, except that the predictor first gets to apply the update procedure, which could be one step of gradient descent or something else, to the task’s training set before it’s evaluated. Again, the goal is to get sublinear rather than linear regret.
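One way to write this regret, roughly following the online meta-learning formulation, with \( U_t \) denoting the update procedure applied to task \( t \)’s small training set, \( f_t \) the loss on task \( t \)’s evaluation data, and \( \mathbf{w}_t \) the meta-parameters chosen at round \( t \), is:

$$
\text{Regret}_T \;=\; \sum_{t=1}^{T} f_t\big(U_t(\mathbf{w}_t)\big) \;-\; \min_{\mathbf{w}} \sum_{t=1}^{T} f_t\big(U_t(\mathbf{w})\big).
$$

Sublinear growth of this quantity in \( T \) means the learner’s post-update performance approaches that of the best fixed meta-parameters chosen in hindsight.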
We can apply meta-learning algorithms to this kind of setting. You can basically take the same follow-the-leader algorithm and turn it into something like follow the meta-leader: you store all the data that you’ve seen so far, you meta-train on it, you apply the meta-learned update procedure to the current task, and you repeat the process.
Also, similar to follow the leader, you can warm-start your meta-parameters with the meta-parameters from the previous time step.
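Here is a rough, self-contained sketch of that loop on toy linear-regression tasks. The names and the task distribution are made up for illustration, and a Reptile-style meta-step stands in for the full meta-training used in the actual follow-the-meta-leader (FTML) algorithm, which meta-trains with MAML-style meta-gradients over the stored tasks.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_task(task_id):
    """Hypothetical toy task: linear regression with a task-specific weight vector."""
    w_true = rng.normal(size=3) + 0.1 * task_id        # tasks drift slowly over time
    X = rng.normal(size=(30, 3))
    y = X @ w_true + 0.1 * rng.normal(size=30)
    return (X[:20], y[:20]), (X[20:], y[20:])          # (train split, eval split)

def inner_update(w, X, y, lr=0.05, steps=10):
    """Update procedure U: a few gradient steps on the task's small training set."""
    for _ in range(steps):
        grad = 2 * X.T @ (X @ w - y) / len(y)
        w = w - lr * grad
    return w

# Follow the meta-leader (sketch): keep all tasks seen so far, meta-train on
# them each round (a Reptile-style step stands in for full meta-training),
# warm-starting from the previous round's meta-parameters, then adapt to the
# current task and evaluate only after adaptation.
meta_w = np.zeros(3)
buffer = []
for t in range(50):
    (X_tr, y_tr), (X_ev, y_ev) = sample_task(t)
    buffer.append((X_tr, y_tr))
    for Xb, yb in buffer:                              # meta-train on stored tasks
        adapted = inner_update(meta_w, Xb, yb)
        meta_w = meta_w + 0.1 * (adapted - meta_w)     # Reptile-style meta-step
    w_t = inner_update(meta_w, X_tr, y_tr)             # adapt to the current task
    print(t, round(float(np.mean((X_ev @ w_t - y_ev) ** 2)), 4))
```

The structural points are the ones from the text: a buffer of all tasks seen so far, meta-training on that buffer each round while warm-starting from the previous meta-parameters, and evaluating each task only after the update procedure has been applied to its small training set.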
If the tasks you’re seeing in sequence are non-stationary, it can be useful to use optimization-based meta-learners here, because you would still expect the learned update procedure to do reasonably well on tasks that are somewhat out of distribution.
If you measure learning efficiency and learning proficiency (how quickly a method learns, and what error or performance it reaches, as the task index increases), you see that these algorithms are, in general, all able to decrease the number of examples they need and decrease their error as they see more and more tasks. But the meta-learning algorithm (the green curve) learns more efficiently over time than the other algorithms and also reaches better and better performance.
So the takeaways from the lecture are, first, that there are lots of different flavors of lifelong learning. Unfortunately, a lot of the work out there puts them all under the same name, which means that if you look up a paper on lifelong learning, it might study a very different problem setting than another paper that also calls itself lifelong learning.
Defining the problem statement is often one of the hardest parts of this. And so hopefully, the exercise at the beginning got you thinking a little bit about how you might define problem statements for different problems.
And lastly, you can sort of view meta-learning as one slice of the lifelong learning problem where you have some previous experience, and your goal is to very quickly learn something new at the current time step. And it’s also a pretty open area of research.
That’s it for this series. I think meta-learning is a concept that might still be a bit unclear to some people, so I will try to write about that too.
Thank you for taking the time to read my post. If you found it helpful or enjoyable, please consider giving it a like and sharing it with your friends. Your support means the world to me and helps me to continue creating valuable content for you.