Data Augmentation: Part 1
One of the key components of a successful machine learning product is having sufficient, good-quality data to train the classifier. The data samples should be representative of the entire population distribution. Increasing the number of samples reduces the risk of your model overfitting the data, that is, fitting the noise in the training set because the model is too complex for the amount of data available. The best way to get more samples is simply to go out and collect them. This might mean expensive and time-consuming experimental data collection, along with manually labelling thousands or even millions of samples with the correct class label.
However, in many instances the cost and/or time required to collect the additional samples is prohibitive. In this article, I will outline two methods to synthetically increase the number of samples available for your machine learning task.
How much is enough?
The first question to ask is: do you have enough samples? The simplest way to answer this is to divide your entire data set into two groups: a training set and a test set. (A better approach is to create three groups: a training, a validation, and a test set.) You then train the system using only part of the training data, but test the model using the complete test set. Then incrementally increase the amount of training data used and retrain the system. By plotting the error rate (of both the training and test sets) against the number of training samples, you will be able to evaluate whether you have sufficient data. As the number of training samples increases, the plot should show the gap in error rate between the training set and the test set decreasing. A large gap between the training error and the test error indicates that the model is overfitting, which can usually be remedied by training with more data.
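The learning-curve procedure above can be sketched in a few lines of Python. This is a minimal sketch under stated assumptions: the nearest-centroid classifier (`fit_nearest_centroid`) and the two-class toy data (`make_blobs`) are hypothetical stand-ins, not the CNN, CSVM, or CELM classifiers and not MNIST. The point is the loop: retrain on a growing slice of the training data, but always score on the full test set.

```python
import random


def error_rate(predict, samples):
    """Percentage of (features, label) pairs that `predict` gets wrong."""
    wrong = sum(1 for x, y in samples if predict(x) != y)
    return 100.0 * wrong / len(samples)


def learning_curve(train, test, fit, sizes):
    """Retrain on growing prefixes of `train`; always score on the full `test` set.

    `fit(samples)` must return a `predict(features) -> label` function.
    Returns a list of (n_train, train_error_pct, test_error_pct) tuples.
    """
    curve = []
    for n in sizes:
        subset = train[:n]
        predict = fit(subset)
        curve.append((n, error_rate(predict, subset), error_rate(predict, test)))
    return curve


def fit_nearest_centroid(samples):
    """Hypothetical toy classifier: assign each point to the class with the closest mean."""
    sums, counts = {}, {}
    for x, y in samples:
        acc = sums.setdefault(y, [0.0] * len(x))
        for i, v in enumerate(x):
            acc[i] += v
        counts[y] = counts.get(y, 0) + 1
    centroids = {y: [v / counts[y] for v in acc] for y, acc in sums.items()}

    def predict(x):
        return min(centroids, key=lambda y: sum((a - b) ** 2 for a, b in zip(centroids[y], x)))
    return predict


def make_blobs(n, rng):
    """Hypothetical toy data: two overlapping 2-D Gaussian classes, labels 0 and 1."""
    data = []
    for i in range(n):
        label = i % 2
        centre = 0.0 if label == 0 else 1.5
        data.append(((rng.gauss(centre, 1.0), rng.gauss(centre, 1.0)), label))
    return data


if __name__ == "__main__":
    rng = random.Random(0)
    train, test = make_blobs(2000, rng), make_blobs(500, rng)
    for n, train_err, test_err in learning_curve(train, test, fit_nearest_centroid, [20, 200, 2000]):
        print(f"{n:5d} samples: train {train_err:5.1f}%  test {test_err:5.1f}%  gap {test_err - train_err:5.1f}")
```

To produce the numbers behind a real learning-curve plot, pass a `fit` function that trains your actual classifier and returns its prediction function.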
Let's do this for the Modified National Institute of Standards and Technology (MNIST) handwritten digit dataset, which has 10 classes: the digits 0 to 9. The data is split into a 50,000-sample training set (i.e., roughly 5,000 per class) and a 10,000-sample test set. We will use three different classifiers: (i) a convolutional neural network (CNN), (ii) a convolutional support vector machine (CSVM), and (iii) a convolutional extreme learning machine (CELM). As we increase the number of samples, the training error percentage will generally increase, but the test error percentage will decrease. A rough rule of thumb to prevent overfitting the model is to ensure that the gap between the training error and the test error is within 0.5%. It is also important to confirm that the test error percentage is good enough for your application!
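The rule of thumb can be written as a small check. This sketch assumes that "within 0.5%" means a gap of at most 0.5 percentage points between the two error rates; the function name `diagnose` and the optional `target_test_error_pct` threshold are hypothetical, and the numbers in the usage lines are illustrative, not measured results.

```python
def diagnose(train_error_pct, test_error_pct, max_gap_pct=0.5, target_test_error_pct=None):
    """Return a list of warnings based on the train/test error rule of thumb.

    Errors are percentages; the gap is measured in percentage points (an assumption).
    """
    warnings = []
    gap = test_error_pct - train_error_pct
    if gap > max_gap_pct:
        warnings.append(f"gap of {gap:.2f} points exceeds {max_gap_pct} (likely overfitting: more data needed)")
    if target_test_error_pct is not None and test_error_pct > target_test_error_pct:
        warnings.append(f"test error {test_error_pct:.2f}% is above the target of {target_test_error_pct}%")
    return warnings


# Illustrative numbers only.
print(diagnose(0.2, 1.6))                             # one warning: gap of 1.40 points
print(diagnose(0.9, 1.2))                             # []: a 0.30-point gap is acceptable
print(diagnose(0.9, 1.2, target_test_error_pct=1.0))  # one warning: test error above target
```

Keeping the two checks separate reflects the two conditions in the text: a small gap alone is not enough if the test error itself is too high for the application.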
Great! So 50,000 samples (roughly 5,000 per class) provide enough data to prevent our three classifiers from overfitting.
But what if the gap was bigger?
Let's say we only had a total of 5,000 handwritten digits in the training set (500 per class). Here (at the far left of the plot), the gap between the training and test error is over 1%. We could reduce the gap by making our classifiers simpler (e.g., fewer neurons), but this would also increase the overall test error.
Instead, we are going to artificially increase the number of samples through:
- Data Warping
- Synthetic Minority Over-sampling Technique (SMOTE). This will be covered in a subsequent blog post.
Data Warping with Elastic Deformations
The basic idea behind data warping is that we transform the images of the handwritten digits while still preserving the label information. In other words: warp each digit a bit, but make sure it still looks like the original number!
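To do this, we create a random displacement field: a field of vectors that moves each pixel value in the digit (a little) to a new location. We start with two 2D matrices, one for the horizontal and one for the vertical displacement, filled with uniformly distributed random numbers between -1 and 1. But we also want this movement to be smooth (neighbouring pixels should move in similar directions), so we convolve each matrix with a Gaussian kernel. Finally, each displacement vector is rescaled to a length of alpha pixels.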
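The displacement-field step just described can be sketched in pure Python, with no third-party libraries. It is an illustration under assumptions, not a port of the author's code: the function names are hypothetical, the blur is done as two 1-D passes with zero padding (valid because a Gaussian kernel is separable), and the defaults (a 28-pixel kernel with a standard deviation of 20 pixels) are taken from the MATLAB listing that follows.

```python
import math
import random


def gaussian_1d(size, sigma):
    """1-D Gaussian weights centred at (size - 1) / 2, normalised to sum to 1."""
    centre = (size - 1) / 2.0
    weights = [math.exp(-((i - centre) ** 2) / (2.0 * sigma ** 2)) for i in range(size)]
    total = sum(weights)
    return [w / total for w in weights]


def blur_line(line, weights):
    """1-D zero-padded filter; the output has the same length as the input."""
    off = len(weights) // 2
    n = len(line)
    return [sum(w * line[i + k - off] for k, w in enumerate(weights) if 0 <= i + k - off < n)
            for i in range(n)]


def smooth(field, weights):
    """Separable 2-D Gaussian blur: filter every row, then every column (zero padding)."""
    rows = [blur_line(row, weights) for row in field]
    cols = [blur_line(list(col), weights) for col in zip(*rows)]
    return [list(row) for row in zip(*cols)]  # transpose back to row-major


def random_displacement_field(height, width, alpha, size=28, sigma=20.0, rng=random):
    """Smooth random (dx, dy) field in which every vector is (almost exactly) alpha pixels long."""
    weights = gaussian_1d(size, sigma)
    # Uniform random numbers in [-1, 1], one matrix per displacement direction.
    dx = smooth([[rng.uniform(-1.0, 1.0) for _ in range(width)] for _ in range(height)], weights)
    dy = smooth([[rng.uniform(-1.0, 1.0) for _ in range(width)] for _ in range(height)], weights)
    for r in range(height):
        for c in range(width):
            # Normalise to unit length (1e-7 guards against division by zero), then scale.
            norm = math.hypot(dx[r][c], dy[r][c]) + 1e-7
            dx[r][c] = alpha * dx[r][c] / norm
            dy[r][c] = alpha * dy[r][c] / norm
    return dx, dy


if __name__ == "__main__":
    dx, dy = random_displacement_field(28, 28, alpha=1.2, rng=random.Random(0))
    print(round(math.hypot(dx[0][0], dy[0][0]), 3))  # approximately 1.2
```

Because of the final normalisation, the smoothing only controls the *direction* of each displacement; alpha alone sets how far every pixel moves.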
The MATLAB code to do this looks like:
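function [X_warped, morelabels] = DataWarpingDiffusion(X, labels, K, N, alpha)
% [X_warped, morelabels] = DataWarpingDiffusion(X, labels, K, N, alpha)
%
% A function to increase the number of training data vectors,
% by creating N warped duplicates of each of the K vectors.
% Uses pseudo-elastic warping, see
% (Simard, 2003) Best Practices for Convolutional Neural Networks
% Applied to Visual Document Analysis
%
% X      - the data in row vector form (K x 784)
% labels - the labels for each of the vectors (K x 1)
% K      - the number of vectors
% N      - the number of duplicates
% alpha  - warp strength [in pixels]
%
% Sebastien Wong, 5 Jan 2014

Xim = reshape(X', 28, 28, K);   % assuming 28 by 28 input images
L = K*N;
Y = zeros(28, 28, L);

% 28 x 28 Gaussian smoothing kernel with a standard deviation of 20 pixels.
% d2gauss is not built into MATLAB (it is a user-supplied 2-D Gaussian helper);
% fspecial('gaussian', 28, 20) should be an equivalent substitute, because any
% overall scale is cancelled by the normalisation below.
blur = d2gauss(28, 20, 28, 20, 0);

l = 1;
for n = 1:N
    for k = 1:K
        I = Xim(:,:,k);
        % random displacement field, uniform in [-1, 1]: plane 1 = x, plane 2 = y
        C = rand(28, 28, 2)*2 - 1;
        C(:,:,1) = conv2(C(:,:,1), blur, 'same');
        C(:,:,2) = conv2(C(:,:,2), blur, 'same');
        % normalise every displacement vector to unit length, so that alpha is
        % the displacement in pixels (1e-7 avoids division by zero)
        E = sqrt(C(:,:,1).^2 + C(:,:,2).^2) + 1e-7;
        C(:,:,1) = C(:,:,1) ./ E;
        C(:,:,2) = C(:,:,2) ./ E;
        Y(:,:,l) = imwarp(I, C * alpha);   % Image Processing Toolbox
        l = l + 1;
    end
end
X_warped = reshape(Y, 28*28, L)';      % back to row-vector form (L x 784)
morelabels = repmat(labels, [N, 1]);   % matches the loop order: copy n of all K images
end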
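The remaining steps of `DataWarpingDiffusion` (resampling each digit along the displacement field and repeating the labels) can also be sketched in pure Python. This is a hedged approximation, not an exact replica of `imwarp`: it assumes a backward mapping (output pixel (r, c) is read from input location (r + dy, c + dx)), bilinear interpolation, and a fill value of 0 outside the image. These choices are intended to resemble `imwarp`'s default behaviour but have not been checked against it. The names `warp_image`, `augment`, and the `field_fn` argument are hypothetical; `field_fn(height, width)` should return a `(dx, dy)` pair of nested lists, for example a smoothed random field like the one built in the MATLAB code above.

```python
import math


def bilinear(img, y, x):
    """Sample img at fractional (y, x); locations outside the image read as 0 (assumed fill)."""
    h, w = len(img), len(img[0])
    y0, x0 = math.floor(y), math.floor(x)
    fy, fx = y - y0, x - x0

    def px(r, c):
        return img[r][c] if 0 <= r < h and 0 <= c < w else 0.0

    return ((1 - fy) * (1 - fx) * px(y0, x0) + (1 - fy) * fx * px(y0, x0 + 1)
            + fy * (1 - fx) * px(y0 + 1, x0) + fy * fx * px(y0 + 1, x0 + 1))


def warp_image(img, dx, dy):
    """Backward mapping: output pixel (r, c) is read from input location (r + dy, c + dx)."""
    return [[bilinear(img, r + dy[r][c], c + dx[r][c]) for c in range(len(img[0]))]
            for r in range(len(img))]


def augment(images, labels, n_copies, field_fn):
    """Make n_copies warped duplicates of every image; each copy keeps its original label."""
    warped, more_labels = [], []
    for _ in range(n_copies):                 # outer loop over copies, as in the MATLAB code
        for img, label in zip(images, labels):
            dx, dy = field_fn(len(img), len(img[0]))
            warped.append(warp_image(img, dx, dy))
            more_labels.append(label)
    return warped, more_labels


def half_pixel_right(h, w):
    """Hypothetical constant field: sample every pixel from half a pixel to its right."""
    return [[0.5] * w for _ in range(h)], [[0.0] * w for _ in range(h)]


if __name__ == "__main__":
    digit = [[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]]  # a toy 3 x 3 "1"
    images, labels = augment([digit], [1], 2, half_pixel_right)
    print(labels)        # [1, 1]
    print(images[0][0])  # [0.5, 0.5, 0.0]
```

Because `augment` loops over copies on the outside and images on the inside, the labels come out in the same order as `repmat(labels, [N, 1])` in the MATLAB version.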
So what do these warped digits look like (see below)? More importantly, how strong should the displacement alpha be? I found that alpha = 1.2 pixels worked well for this data set: I was fairly confident that all of the warped digits still looked like digits. If we make alpha large, say alpha = 8, we can cause some of the digits to look like other numbers!