Some machine learning problems involve classifying an object into one of N classes. These are called multiclass classification problems, as opposed to binary classification, where there is just a positive and a negative class. Handwritten digit recognition and image classification are two well-known instances of multiclass classification problems.
Image source: https://towardsdatascience.com/multi-class-classification-one-vs-all-one-vs-one-94daed32a87b
In this post we will explain how a neural network can be used to solve this kind of problem. This is the second in a series of posts on neural networks I've been writing. If you are new to the topic, you may want to have a look at the first one before reading on.
Problem statement
Let's say we are working on an application for a factory that processes fruit. Trucks bring oranges, lemons and limes mixed together. As a first step, a robot guided by a camera separates the fruits. Our task is to develop a model that classifies a fruit into one of the three groups given an input image. This is a multiclass classification problem with $N = 3$ classes.
Each image can be translated into a set of input features $x = (x_1, x_2, \dots, x_n)$, such as the values of its pixels.
Given an image's features $x$, the model must output a prediction $\hat{y}$ telling which of the three classes the fruit belongs to.
To train our model, we will have a collection of $m$ labeled examples $(x^{(i)}, y^{(i)})$: images for which the correct class is already known.
In the previous post we presented the following neural network, suitable for binary classification:
This network needs a couple of changes before it can be used for multiclass classification. We will go over them in the next sections.
Label representation: one-hot encoding
The label representation we used for binary classification, a single value $y$ that is either 0 or 1, no longer fits now that there are three classes. We could label each class with an integer (say 0 for orange, 1 for lemon and 2 for lime), but that would suggest an ordering and a distance between classes that don't really exist.

Instead, we will use a one-hot representation. With this scheme, our labels become vectors with one element per class: the element corresponding to the true class is set to one and all the others are set to zero.
Using this representation, our training labels become:

$$y_{\text{orange}} = \begin{bmatrix} 1 \\ 0 \\ 0 \end{bmatrix} \qquad y_{\text{lemon}} = \begin{bmatrix} 0 \\ 1 \\ 0 \end{bmatrix} \qquad y_{\text{lime}} = \begin{bmatrix} 0 \\ 0 \\ 1 \end{bmatrix}$$
Note that classes are mutually exclusive: a fruit may be an orange or a lemon, but not both at the same time! This implies that exactly one element in the label vector must be one, with the other ones being zero. That’s why the encoding scheme is called one-hot!
In the general case, we will use label vectors with $N$ elements, one per class, for a problem with $N$ classes.
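If it helps to see it in code, here is a minimal NumPy sketch of this encoding (the class-to-index mapping is my own choice for the example):

```python
import numpy as np

# Assumed mapping: orange -> position 0, lemon -> 1, lime -> 2
CLASSES = ["orange", "lemon", "lime"]

def one_hot(label, classes=CLASSES):
    """Return a vector with a 1 at the position of `label` and 0 elsewhere."""
    vector = np.zeros(len(classes))
    vector[classes.index(label)] = 1.0
    return vector

print(one_hot("lemon"))  # [0. 1. 0.]
```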
The output layer
In binary classification, the output layer produces a single value $\hat{y}$ between zero and one, which we interpret as the probability that the input belongs to the positive class.
In multiclass classification we are using vectors of $N$ elements as labels, so we also want the network to output a vector $\hat{y}$ of $N$ elements, where each element is the predicted probability of the corresponding class. For our fruit example, an output could look like this:

$$\hat{y} = \begin{bmatrix} 0.7 \\ 0.2 \\ 0.1 \end{bmatrix}$$
This means that our network thinks there is a 70% chance that the input picture is an orange, a 20% chance that it's a lemon, and a 10% chance that it's a lime. The final prediction would be orange, the class with the highest probability. Note that all the probabilities in the output vector sum to 1. This is required because the output labels are mutually exclusive: one picture can't be an orange and a lemon at the same time.
So what’s the actual change? We will just increase the number of units in the output layer to three!
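To make the prediction step concrete, here is a tiny sketch of how the final class could be read off the output vector (the probabilities are the example values above):

```python
import numpy as np

CLASSES = ["orange", "lemon", "lime"]
y_hat = np.array([0.7, 0.2, 0.1])  # example network output

# The predicted class is the one with the highest probability
print(CLASSES[int(np.argmax(y_hat))])  # orange
```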
The activation function
In binary classification we used the sigmoid activation function for the output layer, as it guarantees that the output is between zero and one. Now that we have increased the size of the output layer, the sigmoid is no longer the best choice. We could apply it independently to each output unit, but we would have no guarantee that the output probabilities add up to one, which would break our interpretation of the output as a probability distribution.
Instead, we will use the softmax activation function for the output layer. Given a vector $z$ of $N$ real numbers, softmax returns a vector of the same size whose $i$-th element is:

$$\text{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{N} e^{z_j}}$$

Where $z_i$ is the $i$-th element of the input vector. Every element of the result is between zero and one, and all of them add up to one, so the output can be interpreted as a probability distribution over the $N$ classes.
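Here is a minimal NumPy sketch of softmax so you can check those properties yourself (the input values are arbitrary):

```python
import numpy as np

def softmax(z):
    """Softmax, subtracting the max for numerical stability."""
    exp_z = np.exp(z - np.max(z))
    return exp_z / np.sum(exp_z)

z = np.array([2.0, 0.8, 0.1])  # arbitrary pre-activation values
y_hat = softmax(z)
print(y_hat)        # roughly [0.69 0.21 0.10]
print(y_hat.sum())  # 1.0
```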
With these considerations, the output layer will do the following:

$$z = W a + b \qquad \hat{y} = \text{softmax}(z)$$

Where $a$ denotes the activations coming from the previous layer. The matrix $W$ and the bias vector $b$ now have three rows, one per output unit, so that $z$ and $\hat{y}$ have three elements.
The loss function
For binary classification we can use the log loss (also called the cross-entropy loss), with the following formulation:

$$L(y, \hat{y}) = -\left[\, y \log(\hat{y}) + (1 - y) \log(1 - \hat{y}) \,\right]$$

Where $y$ is the true label (0 or 1) and $\hat{y}$ is the probability predicted by the network.
The form shown above is a particularization of the cross-entropy loss for two classes. For our example with three classes, we can use this generalization:

$$L(y, \hat{y}) = -\sum_{i=1}^{3} y_i \log(\hat{y}_i)$$

Where $y_i$ and $\hat{y}_i$ are the $i$-th elements of the label and prediction vectors.
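In code, the loss could look like this minimal NumPy sketch (the function name is mine, not taken from any library):

```python
import numpy as np

def cross_entropy(y, y_hat):
    """Cross-entropy between a one-hot label y and a predicted distribution y_hat."""
    # A production implementation would clip y_hat to avoid log(0)
    return -np.sum(y * np.log(y_hat))
```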
An example is worth a thousand words, so let's go through one. We will examine how the loss function behaves for three different predictions for the same image. Let's say that we are given a picture of an orange. With our one-hot encoding strategy, the label would be:

$$y = \begin{bmatrix} 1 \\ 0 \\ 0 \end{bmatrix}$$
Case 1. The network works well for the example: it is 70% sure that what it saw was an orange.
Case 2. The network is unsure and thinks every fruit is equally likely.
Case 3. The network makes a mistake: it predicts a lime.
The loss function’s behaviour seems coherent:
- Case 1. If the network is confident about a prediction and it’s right, the cost is low.
- Case 2. If the network is uncertain, the cost is higher.
- Case 3. If the network is confident but the prediction ends up being wrong, the cost is the highest.
Also notice that the loss function only pays attention to the predicted probability of the actual class. In this example, the actual class was the first one, so the loss function only cares about the first element of $\hat{y}$: the other terms in the sum vanish because their corresponding label elements $y_i$ are zero.
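To put numbers on the three cases, here is a small sketch that evaluates the loss for each of them using natural logarithms. Case 1 reuses the example output from above; the vectors for cases 2 and 3 are my own illustrative values (case 2 splits the probability evenly, case 3 assumes the network strongly favours lime), and only the first element matters for the loss anyway:

```python
import numpy as np

def cross_entropy(y, y_hat):
    return -np.sum(y * np.log(y_hat))

y = np.array([1.0, 0.0, 0.0])           # true label: orange

case_1 = np.array([0.70, 0.20, 0.10])   # confident and right: 70% orange
case_2 = np.array([1/3, 1/3, 1/3])      # unsure: every class equally likely
case_3 = np.array([0.10, 0.20, 0.70])   # confident but wrong: predicts lime

for name, y_hat in [("case 1", case_1), ("case 2", case_2), ("case 3", case_3)]:
    print(name, round(cross_entropy(y, y_hat), 2))
# case 1 0.36
# case 2 1.1
# case 3 2.3
```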
For the general case with $N$ classes, the loss keeps the same form:

$$L(y, \hat{y}) = -\sum_{i=1}^{N} y_i \log(\hat{y}_i)$$

This is usually called the categorical cross-entropy loss.
If you want to dig deeper, this post explores the cross-entropy loss with multiple classes in depth.
Conclusion
That’s it! We now have all the elements in place to build a neural network that performs multiclass classification. We’ve covered the data format, the network architecture and the loss function. These are all the elements that we need to specify when creating a network using frameworks like Keras.
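As a preview, here is roughly how those pieces could map to Keras; a minimal sketch, assuming the input features have already been flattened into a vector (the input size and hidden layer size are illustrative, not from this series):

```python
from tensorflow import keras
from tensorflow.keras import layers

NUM_FEATURES = 1024  # illustrative input size

model = keras.Sequential([
    layers.Input(shape=(NUM_FEATURES,)),
    layers.Dense(32, activation="relu"),    # hidden layer, size chosen arbitrarily
    layers.Dense(3, activation="softmax"),  # one unit per class, softmax output
])

# Categorical cross-entropy matches the one-hot labels described above
model.compile(optimizer="adam",
              loss="categorical_crossentropy",
              metrics=["accuracy"])
```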
In the next post we will apply these concepts by using Keras to create and train a network to solve a classical handwritten digit recognition problem. See you soon!
References
- Deep Learning Specialization, Coursera courses by Andrew Ng: https://www.coursera.org/specializations/deep-learning
- Cross-entropy for classification, by Vlastimil Martinek: https://towardsdatascience.com/cross-entropy-for-classification-d98e7f974451
- Choosing the right Encoding method-Label vs OneHot Encoder, by Raheel Shaikh: https://towardsdatascience.com/choosing-the-right-encoding-method-label-vs-onehot-encoder-a4434493149b
- Multi-class Classification — One-vs-All & One-vs-One, by Amey Band: https://towardsdatascience.com/multi-class-classification-one-vs-all-one-vs-one-94daed32a87b