Satellite image classification is crucial for many applications in agriculture, environmental monitoring, urban planning, and more. Governments, companies, and research labs increasingly rely on applications such as crop monitoring and land and forest cover mapping in real-world settings.
In this tutorial, you will learn how to build a satellite image classifier using the TensorFlow framework in Python.
We will be using the EuroSAT dataset, which is based on Sentinel-2 satellite images covering 13 spectral bands. It consists of 27,000 labeled samples in 10 different classes: annual crop, permanent crop, forest, herbaceous vegetation, highway, industrial, pasture, residential, river, and sea/lake.
The EuroSAT dataset comes in two varieties:
rgb (default): contains only the R, G, and B frequency bands, encoded as JPEG images.
all: contains all 13 bands in the original value range.
To get started, let's install TensorFlow and some other helper tools:
We use tensorflow_addons to calculate the F1 score during the training of the model.
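For reference, the F1 score is the harmonic mean of precision and recall, computed per class and then averaged. The plain-Python sketch below computes a macro-averaged F1 from integer class labels. It only illustrates the arithmetic: it is not the tensorflow_addons implementation, macro averaging is an assumption (the tutorial's exact averaging setting is not shown here), and macro_f1 is a hypothetical name.

```python
def macro_f1(y_true, y_pred, num_classes):
    """Macro-averaged F1: compute F1 for each class, then take the plain mean."""
    f1_scores = []
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        # Harmonic mean of precision and recall (0 when both are 0)
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        f1_scores.append(f1)
    return sum(f1_scores) / num_classes


# Toy example with 3 classes (not EuroSAT data)
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
print(round(macro_f1(y_true, y_pred, num_classes=3), 4))  # 0.6556
```

Here the per-class F1 scores are 0.5, 0.8, and about 0.667, so a single weak class pulls the macro average down even when overall accuracy looks fine.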
We will use the EfficientNetV2 model, which at the time of writing was state of the art on many image classification tasks. We use tensorflow_hub to load this pre-trained CNN model for fine-tuning.
Importing the necessary libraries:
Downloading and loading the dataset:
We split our dataset into 60% for training, 20% for validation during training, and 20% for testing. The code below sets some variables we will use later:
We grab the list of classes from the metadata of all_ds, which is available because the dataset was loaded with with_info set to True; we also get the number of samples from it.
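With 27,000 samples, the 60/20/20 split works out to 16,200 training, 5,400 validation, and 5,400 test samples, which matches the 5,400 test images extracted later. The sketch below does this arithmetic and also builds TFDS-style percentage slice strings such as "train[:60%]". That EuroSAT is exposed as a single "train" split and that the tutorial slices it this way are assumptions; split_sizes and tfds_split_strings are hypothetical helpers.

```python
def split_sizes(n_samples, train_frac=0.6, valid_frac=0.2):
    """Return (train, validation, test) sample counts for a fractional split."""
    n_train = round(n_samples * train_frac)
    n_valid = round(n_samples * valid_frac)
    n_test = n_samples - n_train - n_valid  # the remainder goes to the test set
    return n_train, n_valid, n_test


def tfds_split_strings(train_pct=60, valid_pct=20):
    """Build TFDS percentage-slicing strings over a single 'train' split."""
    end_valid = train_pct + valid_pct
    return (
        f"train[:{train_pct}%]",
        f"train[{train_pct}%:{end_valid}%]",
        f"train[{end_valid}%:]",
    )


print(split_sizes(27_000))   # (16200, 5400, 5400)
print(tfds_split_strings())  # ('train[:60%]', 'train[60%:80%]', 'train[80%:]')
```

Strings of this form can be passed (as a list) to the split argument of tfds.load() to get the three subsets in one call.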
Next, I'm going to make a bar plot to see the number of samples in each class:
Output:
Half of the classes have 3,000 samples each, the others have 2,500, and pasture has only 2,000.
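As a quick sanity check, the counts described above add up to the full dataset: five classes with 3,000 samples, four with 2,500, and pasture with 2,000 give 27,000. The sketch below tallies per-class counts from a list of integer labels with collections.Counter (the kind of tally a per-class bar plot is drawn from) and verifies that sum. count_per_class is a hypothetical helper, and the toy labels are not real EuroSAT data.

```python
from collections import Counter


def count_per_class(labels, class_names):
    """Map each class name to how many times its integer label appears."""
    counts = Counter(labels)
    # Counter returns 0 for labels that never occur
    return {name: counts[i] for i, name in enumerate(class_names)}


# Toy labels for three classes (not EuroSAT data)
print(count_per_class([0, 2, 2, 1, 2], ["forest", "river", "sea_lake"]))
# {'forest': 1, 'river': 1, 'sea_lake': 3}

# Arithmetic check of the EuroSAT counts described in the text:
# five of the 10 classes at 3,000, four at 2,500, and pasture at 2,000
total = 5 * 3_000 + 4 * 2_500 + 1 * 2_000
print(total)  # 27000
```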
Now let's take our training and validation sets and prepare them before training:
Here is what this function does:
cache(): This method saves the preprocessed dataset into a local cache file (or in memory, if no file name is given), so preprocessing only happens the very first time (in the first epoch during training).
map(): We map our dataset so each sample becomes a tuple of an image and its corresponding label, one-hot encoded with tf.one_hot().
shuffle(): Shuffles the dataset so the samples are in random order.
repeat(): Makes the dataset repeat indefinitely, so iterating over it keeps generating samples across epochs; this is also why the number of steps per epoch has to be given explicitly during training.
batch(): We group the samples into batches of 64 (or 32) per training step.
prefetch(): Fetches batches in the background while the model is training.
Let's run it for the training and validation sets:
Let's see what our data looks like:
Output:
Fantastic, both the training and validation sets have the same shapes: the batch size is 64, and each image has the shape (64, 64, 3). The targets have the shape (64, 10), as each batch holds 64 samples with 10 one-hot encoded classes.
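To make the target shape concrete: tf.one_hot() turns an integer label into a vector of length num_classes with a single 1, so a batch of 64 labels becomes a 64 × 10 array. The plain-Python sketch below mimics that step and also computes how many full batches fit in the 16,200 training samples, which matters because repeat() makes the dataset endless, so an epoch has to be defined as a fixed number of steps. It is an illustration only, not tf.data code, and the helper names are hypothetical.

```python
def one_hot(label, num_classes):
    """Plain-Python stand-in for tf.one_hot on a single integer label."""
    vector = [0.0] * num_classes
    vector[label] = 1.0
    return vector


def one_hot_batch(labels, num_classes):
    """One-hot encode a batch of labels -> nested list of shape (batch_size, num_classes)."""
    return [one_hot(label, num_classes) for label in labels]


batch_size, num_classes = 64, 10
labels = [i % num_classes for i in range(batch_size)]  # toy labels, not real data
targets = one_hot_batch(labels, num_classes)
print(len(targets), len(targets[0]))  # 64 10
print(targets[3])  # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

# With repeat(), one "epoch" has to be defined as a number of full batches.
n_training_samples = 16_200  # 60% of 27,000
steps_per_epoch = n_training_samples // batch_size
print(steps_per_epoch)  # 253
```

In Keras, a count like this is what goes into the steps_per_epoch argument of model.fit() when the training dataset repeats indefinitely.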
Let's visualize the first batch from the training dataset:
Output:
Right. Now that we have our data prepared for training, let's build our model. First, we download EfficientNetV2 and load it as a hub.KerasLayer:
We pass the model_url to hub.KerasLayer so we get EfficientNetV2 as an image feature extractor. However, we set trainable to True so the pre-trained weights are adjusted slightly for our dataset (i.e., fine-tuning).
Building the model:
We use Sequential(): the first layer is the pre-trained CNN model, and we add a fully connected output layer whose size equals the number of classes.
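The output layer produces one score per class. Assuming a softmax activation on that layer (the usual pairing with the categorical cross-entropy loss the model is compiled with), the scores become probabilities, and the loss for a sample is the negative log of the probability assigned to its true class. A plain-Python sketch of both, with hypothetical function names:

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    max_logit = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - max_logit) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_cross_entropy(one_hot_target, probs, eps=1e-7):
    """-sum(y * log(p)); with a one-hot target only the true class contributes."""
    # Clip tiny probabilities to eps to avoid log(0)
    return -sum(y * math.log(max(p, eps)) for y, p in zip(one_hot_target, probs))


logits = [2.0, 1.0, 0.1]  # toy scores for 3 classes
probs = softmax(logits)
print([round(p, 3) for p in probs])  # [0.659, 0.242, 0.099]
print(round(categorical_cross_entropy([1.0, 0.0, 0.0], probs), 3))  # 0.417
```

The loss shrinks toward 0 as the probability of the true class approaches 1, which is what training pushes the fine-tuned network to do.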
Finally, the model is built and compiled with the categorical cross-entropy loss, the Adam optimizer, and accuracy and F1 score as metrics.
Output:
With the data and model ready, let's begin fine-tuning our model:
The training will take several minutes, depending on your GPU. Here is the output:
As you can see, the model improved to about 97% accuracy on the validation set on epoch 5. You can increase the number of epochs to see whether it can improve further.
So far, we have only evaluated on the validation set during training. This section uses our model to predict satellite images that it has never seen before. Loading the best weights:
Extracting all the testing images and labels individually from test_ds:
Output:
As expected, there are 5,400 images and labels. Let's use the model to predict these images and then compare the predictions with the true labels:
Output:
Output:
That's good accuracy! Let's draw the confusion matrix for all the classes:
Output:
As you can see, the model is accurate on most of the classes, especially forest images, where it achieved 100%. However, it drops to 91% for pasture: the model sometimes predicts pasture as permanent crop, and also as herbaceous vegetation. Most of the confusion is between crop, pasture, and herbaceous vegetation, as they all look similar and, most of the time, green from the satellite.
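Per-class percentages like these are, presumably, row-normalized confusion-matrix values: each row counts the test images of one true class, and the diagonal entry divided by the row total is that class's recall. The plain-Python sketch below walks through the steps: take the argmax of each predicted probability vector, compare with the true labels for overall accuracy, then build the matrix and per-class recall. Function names and toy data are hypothetical; in practice, scikit-learn's confusion_matrix (with normalize="true" for row percentages) does the same counting.

```python
def argmax(values):
    """Index of the largest value (first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def build_confusion_matrix(y_true, y_pred, num_classes):
    """matrix[t][p] counts samples with true class t predicted as class p."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


def per_class_recall(matrix):
    """Diagonal divided by row total: fraction of each true class predicted correctly."""
    return [row[i] / sum(row) if sum(row) else 0.0 for i, row in enumerate(matrix)]


# Toy model outputs for 3 classes (not real predictions)
probabilities = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.6, 0.3, 0.1], [0.1, 0.2, 0.7]]
y_true = [0, 1, 1, 2]
y_pred = [argmax(p) for p in probabilities]  # [0, 1, 0, 2]
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(accuracy)  # 0.75

cm = build_confusion_matrix(y_true, y_pred, num_classes=3)
print(cm)                    # [[1, 0, 0], [1, 1, 0], [0, 0, 1]]
print(per_class_recall(cm))  # [1.0, 0.5, 1.0]
```

Off-diagonal entries show which classes get mixed up; in the toy data, one class-1 sample is predicted as class 0, much like pasture being predicted as permanent crop.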
Let's show some examples that the model predicted:
Output:
Of all 64 images, only one (the red label in the image above) was misclassified: it was predicted as pasture when it should be permanent crop.
Alright! That's it for the tutorial. If you want to improve the results further, I highly advise exploring TensorFlow Hub, where you can find state-of-the-art pre-trained CNN models and feature extractors.
I also suggest trying different optimizers and increasing the number of epochs to see if you can improve the results. You can use TensorBoard to track the accuracy of each change you make; make sure you include the values you changed (optimizer, epochs, and so on) in the model name so runs are easy to tell apart.
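One way to follow that advice is to build the run name from the hyperparameters you vary, so each TensorBoard log directory is self-describing. The helper below is hypothetical and its naming scheme is just one possible choice; the resulting path could serve as the log_dir of a tf.keras.callbacks.TensorBoard callback or as a checkpoint file name.

```python
import os


def run_name(model_name, **hparams):
    """Join the model name with sorted key-value hyperparameters, e.g. 'm_bs-64_opt-adam'."""
    parts = [model_name] + [f"{key}-{value}" for key, value in sorted(hparams.items())]
    return "_".join(parts)


name = run_name("efficientnetv2", opt="adam", epochs=10, bs=64)
print(name)  # efficientnetv2_bs-64_epochs-10_opt-adam
log_dir = os.path.join("logs", name)  # one log directory per experiment
```

Sorting the keys keeps names stable regardless of argument order, so the same configuration always maps to the same directory.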
If you want more in-depth information, I encourage you to check the EuroSAT paper, where the authors achieved 98.57% accuracy with the 13-band version of the dataset (1.93GB). You can also use this version of the dataset by passing "eurosat/all" instead of the standard "eurosat" to the tfds.load() function.
You can get the complete code of this tutorial here.
Happy learning ♥