Over-fitting

Overview

Over-fitting is a problem in machine learning in general, not just in neural networks. The problem is inherent in the way machine learning models are developed: a set of "training data" is used to "train" the model, but the goal is a model that also works on data it has never seen before. Over-fitting is what happens when the model fits the training data so closely that its performance on unseen data suffers. There are a number of techniques to mitigate or prevent over-fitting.

How to recognise over-fitting

Over-fitting can be recognised by comparing the accuracy the model achieves on the training data with the accuracy it achieves on the validation data set. If the training accuracy is noticeably higher, the model has probably been over-fit to the training data. A small amount of over-fitting is not a problem.
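
As a minimal sketch of this check, assuming a PyTorch classification model and two DataLoaders (model, train_loader and val_loader are hypothetical names that would exist in your training script):

  import torch

  def accuracy(model, loader):
      # Fraction of correctly classified examples over one data loader.
      model.eval()
      correct, total = 0, 0
      with torch.no_grad():
          for x, y in loader:
              preds = model(x).argmax(dim=1)
              correct += (preds == y).sum().item()
              total += y.size(0)
      return correct / total

  train_acc = accuracy(model, train_loader)
  val_acc = accuracy(model, val_loader)
  # A large positive gap is the signature of over-fitting.
  print(f"train={train_acc:.3f}  val={val_acc:.3f}  gap={train_acc - val_acc:.3f}")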

Over-fitting is useful

When we are designing a DNN, we want to start by over-fitting: that way we know our model is complex enough to solve the problem. From there we can apply techniques that reduce the over-fitting.

What to do about over-fitting

There are a variety of ways to reduce over-fitting. These should be applied to the training set but not to the validation set (no messing with the validation set!). Some of these techniques are almost compulsory, and are considered good practice from the get-go.

  1. add more data
  2. use data augmentation
  3. use batch normalisation
  4. use architectures that generalise well
  5. add regularisation (L1, L2, dropout)
  6. reduce architecture complexity

Add more data

No point leaving any data out. If you have plenty of data in your validation set, move some of it into your training set.

Data augmentation

Data augmentation is the process of taking the data you already have and modifying it in realistic but randomised ways, to increase the variety of data seen during training. For images, for example, slightly rotating, zooming, and/or translating the image yields the same content with a different framing. This mirrors the variation found in real-world data, so it improves training. It's worth double-checking that the output of the data augmentation is still realistic.

To determine what types of augmentation to use, and how much of it, do some trial and error. Try each augmentation type on a sample set, with a variety of settings (e.g. 1% translation, 5% translation, 10% translation) and see what performs best on the sample set. Once you know the best setting for each augmentation type, try adding them all at the same time.
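
For images, a minimal sketch using torchvision transforms (the specific augmentation types and ranges here are illustrative assumptions, not recommended settings):

  from torchvision import transforms

  train_transforms = transforms.Compose([
      transforms.RandomRotation(degrees=5),                        # slight rotation
      transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),  # 5% translation
      transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),         # slight zoom
      transforms.ToTensor(),
  ])
  # Attach to the training dataset only, never the validation dataset.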

Batch normalisation

When feeding data into a machine learning model, the data should usually be "normalised": scaled so that it has a mean and standard deviation within "reasonable" limits. This ensures the objective function works as expected and does not give undue weight to one input feature simply because of its numeric range. Without normalising inputs the model may be extremely fragile.
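
A minimal sketch of input normalisation with NumPy (the data here is a random placeholder; the key point is that the statistics come from the training set only):

  import numpy as np

  X_train = np.random.rand(1000, 20) * 50.0   # placeholder training data
  mean = X_train.mean(axis=0)
  std = X_train.std(axis=0) + 1e-8            # guard against division by zero

  X_train_norm = (X_train - mean) / std       # zero mean, unit standard deviation
  # Reuse the same training-set mean/std on validation and test data.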

Batch normalisation is an extension of this concept. Instead of normalising the data only at the input to the neural network, batch normalisation adds layers so that normalisation also happens at the input to each convolutional layer.
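
As a minimal sketch in PyTorch (layer sizes are arbitrary assumptions), nn.BatchNorm2d layers are interleaved with the convolutions:

  import torch.nn as nn

  model = nn.Sequential(
      nn.Conv2d(3, 16, kernel_size=3, padding=1),
      nn.BatchNorm2d(16),   # normalises the 16 feature maps over the batch
      nn.ReLU(),
      nn.Conv2d(16, 32, kernel_size=3, padding=1),
      nn.BatchNorm2d(32),
      nn.ReLU(),
      nn.AdaptiveAvgPool2d(1),
      nn.Flatten(),
      nn.Linear(32, 10),
  )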

The benefits of using batch normalisation are:

  • Improves gradient flow through the network
  • Allows higher learning rates
  • Reduces the strong dependence on initialisation
  • Acts as a form of regularisation

Batch normalisation has two elements:

  1. Normalise the inputs to the layer. This is the same as regular feature scaling or input normalisation.
  2. Add two more trainable parameters: one for a scale and one for an offset, applied to each of the activations. By adding these parameters, the normalisation can effectively be completely undone, so the back-propagation process can learn to ignore the batch normalisation layer if it wants to (a minimal sketch follows this list).
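
A minimal sketch of the training-time computation for one fully connected layer, written with plain PyTorch tensor operations (gamma and beta are the two trainable parameters described above):

  import torch

  def batch_norm_train(x, gamma, beta, eps=1e-5):
      # x has shape (batch, features); normalise each feature over the batch.
      mean = x.mean(dim=0)
      var = x.var(dim=0, unbiased=False)
      x_hat = (x - mean) / torch.sqrt(var + eps)
      # Trainable scale (gamma) and offset (beta): if gamma learns the
      # standard deviation and beta the mean, the normalisation is undone.
      return gamma * x_hat + beta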

Training vs testing

Batch normalisation must be handled differently at training and test time. During training, the mean and standard deviation are calculated for each batch. During testing (i.e. once the model has been trained and is in use), the batch normalisation layers must behave deterministically: running estimates of the mean and standard deviation accumulated during training are used instead.
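
In PyTorch this switch is made through the module mode; a minimal sketch:

  import torch

  bn = torch.nn.BatchNorm1d(8)

  bn.train()                      # training mode: per-batch statistics,
  out = bn(torch.randn(32, 8))    # running_mean / running_var get updated

  bn.eval()                       # test mode: the stored running statistics
  out = bn(torch.randn(4, 8))     # are used; the batch contents don't matter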

L1/L2 Regularisation
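
L1 and L2 regularisation add a penalty on the size of the weights to the loss, discouraging the model from fitting noise in the training data. As a minimal sketch in PyTorch (the penalty strengths are arbitrary assumptions; model, criterion, output and target are hypothetical names from a training loop):

  import torch

  # L2 regularisation is commonly applied via the optimiser's weight_decay.
  optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-4)

  # L1 regularisation can be added to the loss by hand.
  l1_penalty = sum(p.abs().sum() for p in model.parameters())
  loss = criterion(output, target) + 1e-5 * l1_penalty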

Dropout
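
Dropout randomly zeroes a fraction of activations during training, so the network cannot rely too heavily on any single unit; at test time all units are active. A minimal sketch in PyTorch (the layer sizes and dropout probability are arbitrary assumptions):

  import torch.nn as nn

  classifier = nn.Sequential(
      nn.Linear(512, 256),
      nn.ReLU(),
      nn.Dropout(p=0.5),   # active in train() mode; a no-op in eval() mode
      nn.Linear(256, 10),
  )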

References

  • Ioffe, S. and Szegedy, C. (2015). "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift"
  • "Why does batch normalization help?" - Quora