Overfitting In Machine Learning

Do you know what overfitting means in machine learning? If you don't, you had better learn about it if you want to use or leverage machine learning. Why? Because overfitting can ruin the effectiveness of machine learning. I wrote this blog entry because I found existing explanations of overfitting to be too technical. I hope this one is more accessible to non-specialists.

The Machine Learning Workflow

Machine learning involves a fairly complex workflow; see Machine Learning Algorithm != Learning Machine for a detailed discussion. Overfitting can occur in one specific part of the workflow: the part where machine learning algorithms are used to create models. This part can be summarized with the picture below.

It starts with a set of labelled examples (yes, this is supervised learning; see What Is Machine Learning? for other types of machine learning). For each example we have data and a target value we want to predict. The input examples are called training examples. The target value is sometimes called the label of the example. In this entry I will use data taken from intr
Training data table.

We have two other examples, namely Batgirl and Riddler, that we keep for testing purposes. These two examples are called test examples to differentiate them from the training examples. Test examples must not be looked at until we have completed the construction of our machine learning model, for reasons that will become clearer later.

The goal of machine learning algorithms is to find how to compute the target of each training example from the data of that example. The target can be a class name (in which case we have a classification problem), a numerical value (a regression problem), or more complex data such as a sentence (a data generation problem). Our superheroes use case is a classification problem because Target can take only two values.

Whatever the machine learning problem, the overall process is the same. We use machine learning algorithms to find a way to compute the target from the rest of the data associated with each training example. This 'way' of computing labels is called a machine learning model, or simply a model when there is no risk of confusion. In our superheroes use case, the model will compute a value for Target.

Once we have obtained a model from the training examples, we move it to an operational system where it is used to compute labels for new, unforeseen examples. For instance, we will use our model to classify Batgirl and Riddler as Good or Bad. Computing labels for new, unforeseen examples is often called classification, prediction, or scoring.

The Assumption Behind Machine Learning

It should then be clear that the value of a machine learning model comes from the labels it computes on new examples. The more accurate these labels, the better. The goal of machine learning is therefore to find the model that will yield the best possible prediction accuracy on new examples. This is harder than it looks. Do you see the catch?
The issue is this: how can we evaluate the accuracy of predictions on new examples without looking at these examples in the first place? It seems like a self-contradictory task, hence an impossible one. To make it clearer: in our superheroes use case, we want to find a model that yields the best possible predictions for Batgirl and Riddler without looking at the Batgirl and Riddler data. This seems silly at best.

The solution to this paradox is an assumption at the core of machine learning: we assume that the training examples are representative of (that is, close to) the new examples that will be encountered in the operational system. Under that assumption, a model that yields good predictions on the training examples should also yield good predictions on the new examples. We can then focus on finding models that yield good predictions on the training examples, which is doable because we have access to these training examples and their labels. In our superheroes use case, we assume that if a model classifies all training examples correctly then it is likely to classify Batgirl and Riddler correctly too.

I cannot stress enough the importance of this assumption. It has profound implications. If your training examples bear no resemblance whatsoever to the new instances that will be found in the operational system, then machine learning is not going to help. For instance, suppose our learning problem is to recognize objects in pictures. If all our training examples are images of animals, then we may find a model that yields accurate labels for new images of animals. But the same model will produce little more than random guesses when presented with images of cars. It may look like I am stating the obvious, but this is something that newcomers to machine learning do not necessarily understand right away. Machine learning can only be as good as the training examples.
If your training examples do not contain a given piece of information, then machine learning will not be able to discover that information.

Selecting A Model

Let's assume therefore that we have training examples that are representative of the new examples we will get in the operational system. Are we safe, and can we blindly trust machine learning? As a matter of fact we can't, unless some precautions are taken. Let me show you why there can still be an issue.

We will use one class of machine learning models, namely decision trees. A decision tree is a set of yes/no questions that lead to a conclusion. I won't describe here how to construct decision trees from training examples; I will just show possible results. A first model can be the following tree. Note that the tree is upside down, with the trunk at the top and the leaves at the bottom.

Model 1.

Let us see how to use that model to make a prediction. For each example, we start at the top of the tree. Take the Batman example. The first question is: does Batman have a tie? The answer is no, as we can see from the training data table. We therefore go down the left branch of the tree. The second question is: does Batman have a cape? The answer is yes, therefore we go down the right branch. We then reach a leaf labelled Good. This is the prediction for Batman. It so happens that the prediction is the same as the true label, which is Good. Let us do the same computation for all our training examples. The results are summarized below.
Training data table with predictions for model 1.
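The walk through Model 1 for Batman can be sketched in a few lines of Python. This is only a minimal sketch, not the way the tree was actually built: the tuple encoding, the `predict` helper and the `UNKNOWN` placeholder are assumptions made for illustration, and only the path described in the text (no tie, then cape, then Good) is filled in.

```python
# Sketch of walking a decision tree from the top question down to a leaf.
# A leaf is a label string; an internal node is a tuple
# (question, branch_if_no, branch_if_yes). This encoding is an assumption.

# Placeholder for the parts of Model 1 that the text does not describe.
UNKNOWN = "not described in the text"

# Only the path followed by Batman is known: no tie -> cape -> Good.
model_1 = ("tie", ("cape", UNKNOWN, "Good"), UNKNOWN)


def predict(tree, example):
    """Answer each yes/no question in turn until a leaf label is reached."""
    while not isinstance(tree, str):
        question, branch_if_no, branch_if_yes = tree
        tree = branch_if_yes if example[question] else branch_if_no
    return tree


# Batman's answers as given in the text: no tie, has a cape.
batman = {"tie": False, "cape": True}
print(predict(model_1, batman))  # Good
```

The branches marked `UNKNOWN` would have to be filled in from the Model 1 picture before the sketch could classify the other characters.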
We see that the predictions are the same as the true labels for all examples! This looks like a great model. It is not the only possible model. We can construct trees that have more nodes, for instance model 2 below.

Model 2.

We can also construct trees that have fewer nodes, like model 3 below.

Model 3.

Given that we can construct many trees, how do we select one? As suggested above, we can look at which tree yields the best predictions on the training examples. The table below provides the predictions for the three models.
Training data table with predictions from 3 models.
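Comparing trees by their predictions on the training examples boils down to counting, for each model, the predictions that differ from the true label. The sketch below is a hedged illustration restricted to the three characters whose labels and predictions are spelled out in this post (Batman, Penguin and Joker); the dictionary layout and the `count_misses` helper are assumptions of mine, and the rest of the table is not reproduced.

```python
# True labels for the three characters named in the text.
truth = {"Batman": "Good", "Penguin": "Bad", "Joker": "Bad"}

# Predictions as described in the text: models 1 and 2 are always right,
# model 3 predicts Good for Penguin and Joker.
predictions = {
    "model 1": {"Batman": "Good", "Penguin": "Bad", "Joker": "Bad"},
    "model 2": {"Batman": "Good", "Penguin": "Bad", "Joker": "Bad"},
    "model 3": {"Batman": "Good", "Penguin": "Good", "Joker": "Good"},
}


def count_misses(predicted, truth):
    """Number of examples whose predicted label differs from the true label."""
    return sum(1 for name, label in truth.items() if predicted[name] != label)


misses = {model: count_misses(p, truth) for model, p in predictions.items()}
print(misses)  # {'model 1': 0, 'model 2': 0, 'model 3': 2}
```

Counting misses rather than computing a percentage keeps the numbers valid even though only part of the table is used here.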
We see that model 1 and model 2 have perfect predictions for the training examples. Model 3 misses two predictions: it predicts that Penguin and Joker are Good while their true label is Bad. Model 3 therefore isn't a good model, and we can probably discard it. What about model 1 versus model 2? How can we choose? The general rule of thumb is to select the simplest model. For decision trees, this means choosing the tree with the smallest number of leaves, i.e. model 1 in our case.

Overfitting And Underfitting

One way to evaluate the performance of our model on new examples is to use the test examples. Here is the data of the test examples we have.
Test data table.

It is important to treat these as new examples, exactly like the examples we will encounter in the operational system. In particular, test examples must not be used for selecting the model. All they can be used for is to give a hint as to how our models will behave with new examples. Indeed, if we use them for selecting the model, then they become part of the training data, and they are no longer usable to assess performance on new, unforeseen examples. Setting some examples aside to select a model is still a good idea, and it is the core of the cross validation technique. We will cover that technique in another post. For now, we use the test data to evaluate model performance after we have selected one of the models as the best one. We have selected model 1, but let's use all three models to make predictions on the test data. The results are included in the table below.
Test data table with predictions from 3 models.

We see that model 3's predictions are completely wrong, which is in line with its predictions on the training data. This model is not a good model. We also see that model 2 isn't great here, with only one correct prediction out of two, while model 1 makes perfect predictions.

The reason model 2 does not perform well is that it is too complex. In a way, it captures all the training examples as they are; it performs rote learning. Model 1 is different: it has abstracted from the training examples to capture the essence of what makes a good superhero and what makes an evil one. Model 3 abstracted too much: it is overly simplistic and fails to capture what distinguishes good from evil superheroes. We say that model 2 is overfitting the training examples, while model 3 is underfitting them.

Underfitting is quite easy to spot: predictions on the training data aren't great. Overfitting is much harder to spot on training data, because an overfitted model yields great predictions on training data. In our use case, it is not even possible to distinguish between model 1 and model 2 using training data alone. It is only when we try the models on new, unforeseen data such as the test data that we can spot overfitting. However, as we said before, we cannot access the test data when building the model. This is why overfitting is a fundamental issue in machine learning: it is hard to spot with training data alone.

How To Avoid Overfitting

A number of techniques have been developed to avoid overfitting, such as cross validation. But the one technique that seems most powerful is to favor simpler models over more complex ones. For instance, we favored model 1 over model 2 because model 1 is simpler. Wait a minute, model 3 is simpler than model 1: shouldn't we prefer model 3 over model 1? We don't want to do that because model 3 yields less accurate predictions on training data than model 1. As a matter of fact we must balance two conflicting goals: making predictions on the training data as accurate as possible, and keeping the model as simple as possible.
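Most state-of-the-art machine learning algorithms embody these two goals as follows: a loss function measures how inaccurate the model's predictions are on the training data, and a regularization function measures how complex the model is.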
The algorithm then looks for a model that minimizes the sum of the loss function and the regularization function. This way, the resulting model is both predictive and simple.

Using a regularization function helps, but it may not be sufficient. Covering all the techniques that can help combat overfitting is beyond the scope of this post (I'll discuss cross validation in another post). But I hope I managed to explain why overfitting can be a problem.

Let me conclude with a simple trick. I wrote above that overfitting is hard to spot, but there is a case where it is easy to spot: when the predictions on the test data are too good to be true. If you are working on a machine learning problem and you find a model that has 98% accuracy when all the other techniques you tried had 60% accuracy, then you are most probably overfitting: the improvement in accuracy is just too good to be true. In that case, try to look for simpler models.
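The idea of minimizing the sum of a loss function and a regularization function, described above, can be illustrated with our three trees. This is a rough sketch under stated assumptions, not what any particular algorithm does: the number of training misses per model comes from the text, but the leaf counts and the penalty values are hypothetical, chosen only so that model 3 is simpler than model 1, which is simpler than model 2.

```python
# Training misses (0, 0, 2) come from the text.
# Leaf counts are hypothetical; only their ordering reflects the text.
candidates = [
    {"name": "model 1", "training_misses": 0, "leaves": 3},
    {"name": "model 2", "training_misses": 0, "leaves": 6},
    {"name": "model 3", "training_misses": 2, "leaves": 2},
]


def objective(model, penalty_per_leaf):
    """Loss (training misses) plus regularization (penalty on tree size)."""
    loss = model["training_misses"]
    regularization = penalty_per_leaf * model["leaves"]
    return loss + regularization


def select_model(candidates, penalty_per_leaf):
    """Return the candidate with the smallest loss + regularization."""
    return min(candidates, key=lambda m: objective(m, penalty_per_leaf))


print(select_model(candidates, 0.5)["name"])  # model 1
print(select_model(candidates, 3.0)["name"])  # model 3
```

With a moderate penalty the sketch picks model 1, which is both accurate and simple. With a penalty that is too heavy, simplicity wins over accuracy and the underfitting model 3 is selected, which shows why the two goals have to be balanced.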
