Focal Loss: An efficient way of handling class imbalance
Recently I participated in a Kaggle competition: SIIM-ISIC Melanoma Classification. In this competition one has to output the probability that a skin lesion shown in an image is melanoma (malignant) rather than benign, so it is a kind of binary image classification task. The evaluation criterion is the AUC (Area Under the ROC Curve) metric. At first I worked on a model with cross entropy as the loss function. Then, after some searching over the internet, I found this paper, in which the team at Facebook AI Research (FAIR) introduced a new loss function: Focal Loss.
I got a good AUC score (0.92+) with this loss function, so I decided to write something about it.
Table of Contents
- Object Detectors
- Focal Loss (an Extension to Cross Entropy Loss)
- Definition of Focal Loss
- Focal Loss (alternate form)
- References
Object Detectors
Before moving on to the discussion of Focal Loss, let me give a short overview of the two types of object detectors: one-stage and two-stage detectors.
Two Stage Detectors
Two stages are required for this class of object detectors to detect an object. The first stage scans the image and generates region proposals, and the second stage classifies those proposals and outputs bounding boxes and class labels. The accuracy is quite good, but the speed is slower than that of one-stage object detectors.
One Stage Detectors
Only a single stage is required for this class of object detectors to detect an object. The image is divided into an (n × n) grid, where n can be any positive integer, and passed through a convolutional neural network which detects objects and outputs the bounding boxes corresponding to them. Note that all the grid cells are classified in a single forward pass of this network. These object detectors are faster than two-stage object detectors but are comparatively less accurate.
Focal Loss (an Extension to Cross Entropy Loss):
Basically, Focal Loss is an extension of cross entropy loss, designed specifically to deal with class imbalance. For a binary classification problem, cross entropy loss is defined as

$$\mathrm{CE}(p, y) = \begin{cases} -\log(p) & \text{if } y = 1 \\ -\log(1 - p) & \text{otherwise} \end{cases}$$
Here y ∈ {−1, 1} is the ground-truth label and p is the model's estimated probability that the example belongs to the positive class (y = 1).
We can also define a variable p_t as

$$p_t = \begin{cases} p & \text{if } y = 1 \\ 1 - p & \text{otherwise} \end{cases}$$
So now cross entropy loss can be re-written as

$$\mathrm{CE}(p_t) = -\log(p_t)$$
This loss function has no way to weight the relative importance of positive and negative examples, and hence a new version of it was introduced under the name balanced cross entropy, defined as

$$\mathrm{CE}(p_t) = -\alpha_t \log(p_t)$$
Here a weighting factor α with range [0, 1] is introduced: the weight is α for the positive class and 1 − α for the negative class. Both cases are merged under the name α_t, which can be defined as

$$\alpha_t = \begin{cases} \alpha & \text{if } y = 1 \\ 1 - \alpha & \text{otherwise} \end{cases}$$
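To make these definitions concrete, here is a minimal NumPy sketch of cross entropy and its balanced variant (the helper names and toy values are my own, not from the paper):

```python
import numpy as np

def p_t(p, y):
    # p_t = p for the positive class (y = 1), 1 - p for the negative class (y = -1)
    return np.where(y == 1, p, 1.0 - p)

def ce_loss(p, y):
    # cross entropy written in terms of p_t: CE(p_t) = -log(p_t)
    return -np.log(p_t(p, y))

def balanced_ce_loss(p, y, alpha=0.25):
    # alpha_t = alpha for y = 1 and (1 - alpha) for y = -1
    alpha_t = np.where(y == 1, alpha, 1.0 - alpha)
    return alpha_t * ce_loss(p, y)

p = np.array([0.9, 0.1, 0.7])  # predicted probabilities of the positive class
y = np.array([1, -1, -1])      # ground-truth labels in {-1, 1}
print(ce_loss(p, y))           # ≈ [0.105 0.105 1.204]
print(balanced_ce_loss(p, y))  # ≈ [0.026 0.079 0.903]
```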
This loss function partially addresses the class imbalance problem, but it is still unable to differentiate between easy and hard examples. To solve this issue, Focal Loss was defined.
Definition of Focal Loss:
Theoretical Definition: Focal loss can be considered as a loss function which down-weights the easily classified examples and gives much more importance to examples which are hard to classify.
Mathematical Definition: Focal loss is the original cross entropy loss multiplied by a modulating factor, (1 − p_t)^γ.
The formula for Focal Loss is:

$$\mathrm{FL}(p_t) = -(1 - p_t)^{\gamma} \log(p_t)$$
Here γ ≥ 0 is known as the focusing parameter.
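Reusing the p_t helper from the sketch above, the formula translates directly into code (note that in the experiments of the paper an α-balanced variant, −α_t (1 − p_t)^γ log(p_t), is what is actually used; the base form is shown here):

```python
def focal_loss(p, y, gamma=2.0):
    # FL(p_t) = -(1 - p_t)^gamma * log(p_t)
    pt = p_t(p, y)
    return -((1.0 - pt) ** gamma) * np.log(pt)

# the two easy examples are crushed, the hard one (p_t = 0.3) keeps most of its loss
print(focal_loss(p, y))  # ≈ [0.001 0.001 0.590]
```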
Two properties of focal loss can be extracted from the above definition —
- When an example is misclassified, p_t is small, so the modulating factor (1 − p_t)^γ is close to 1 and the loss is almost unaffected. On the other hand, if an example is correctly classified, p_t tends towards 1 and the modulating factor tends towards 0, making the loss very near to 0 and thus down-weighting that particular example.
- The focusing parameter (γ) smoothly adjusts the rate at which easily classified examples are down-weighted.
One comparison of FL with CE:
When γ = 2, an example classified with p_t = 0.9 would have a 100× lower loss compared with CE, and one classified with p_t = 0.968 would have a roughly 1000× lower loss.
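These factors are easy to verify, since the ratio FL/CE is just the modulating factor (1 − p_t)^γ:

```python
for pt in (0.9, 0.968):
    ratio = (1.0 - pt) ** 2  # FL / CE with gamma = 2
    print(f"p_t = {pt}: FL is about {1.0 / ratio:.0f}x lower than CE")
# p_t = 0.9:   FL is about 100x lower than CE
# p_t = 0.968: FL is about 977x lower than CE
```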
The figure at the top of this post plots FL for different values of γ; at γ = 0, FL is equal to CE loss. There we can see that for γ = 0 (CE loss) even examples that are easily classified incur a loss of non-trivial magnitude. Summed up, these losses can overwhelm those of the rare class (the examples that are hard to classify).
Focal Loss (alternate form):
For this alternate form of focal loss, we first define a quantity x_t as

$$x_t = yx$$

where x is the raw output of the model (the logit).
Here y ∈ {−1, 1} again specifies the ground-truth label, so x_t > 0 exactly when the example is classified correctly. We can write p_t as

$$p_t = \sigma(x_t)$$

where σ is the sigmoid function.
We can then define focal loss in terms of x_t, with the two parameters described below (a short code sketch follows):

$$\mathrm{FL}^* = -\frac{\log(\sigma(\gamma x_t + \beta))}{\gamma}$$
γ: controls the steepness of the loss curve.
β: controls the shift of the loss curve.
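Here is a minimal sketch of FL* under the definitions above; note that it takes the raw logit x rather than a probability (rewriting −log(σ(z)) as log(1 + exp(−z)) is a standard numerical-stability trick, not something specific to the paper):

```python
def focal_loss_star(x, y, gamma=2.0, beta=0.0):
    # x is the raw model output (logit); x_t = y * x is positive
    # exactly when the example is classified correctly
    xt = y * x
    # FL* = -log(sigmoid(gamma * x_t + beta)) / gamma,
    # computed as log(1 + exp(-z)) / gamma for numerical stability
    z = gamma * xt + beta
    return np.logaddexp(0.0, -z) / gamma

x = np.array([2.5, -1.0, 0.3])  # raw logits
y = np.array([1, -1, -1])       # ground-truth labels in {-1, 1}
print(focal_loss_star(x, y))
```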
Finally, let's end this discussion with a plot of the loss curves for CE, FL, and FL* (with two settings of β and γ).
Although focal loss was designed specifically for one-stage object detectors, it can also perform quite well on image classification tasks.
References:
- https://arxiv.org/pdf/1708.02002.pdf
- https://developers.google.com/machine-learning/crash-course/classification/roc-and-auc
- https://towardsdatascience.com/understanding-auc-roc-curve-68b2303cc9c5
- https://en.wikipedia.org/wiki/Object_detection
- https://medium.com/data-science-bootcamp/understand-cross-entropy-loss-in-minutes-9fb263caee9a