
Automatic Target Recognition (ATR) in SAR Images

This example shows how to train a region-based convolutional neural network (R-CNN) for target recognition in large-scene synthetic aperture radar (SAR) images using Deep Learning Toolbox™ and Parallel Computing Toolbox™.

Deep Learning Toolbox provides a framework for designing and implementing deep neural networks with algorithms, pretrained models, and apps.

Parallel Computing Toolbox lets you solve computationally and data-intensive problems using multicore processors, GPUs, and computer clusters. It enables you to use GPUs directly from MATLAB® and accelerate the computations required by deep learning algorithms.

Neural network based algorithms have achieved remarkable results in diverse areas, ranging from natural scene detection to medical imaging, and have shown large improvements over standard detection algorithms. Inspired by these advances, researchers have applied deep learning based solutions to SAR imaging. This example applies such a solution to the problem of target detection and recognition. The R-CNN used here not only integrates detection and recognition in a single model but also provides an efficient solution that scales to large-scene SAR images.

This example demonstrates how to:

  • Download the dataset and the pretrained model

  • Load and analyze the image data

  • Define the network architecture

  • Specify training options

  • Train the network

  • Evaluate the network

To illustrate this workflow, the example uses the Moving and Stationary Target Acquisition and Recognition (MSTAR) clutter dataset published by the Air Force Research Laboratory. The dataset is available for download here. Alternatively, the example also includes a subset of the data used to showcase the workflow. The goal is to develop a model that can detect and recognize the targets.

Download the Dataset

This example uses a subset of the MSTAR clutter dataset that contains 300 training and 50 testing clutter images with five different targets. The data was collected using an X-band sensor in spotlight mode with a one-foot resolution. The data contains rural and urban clutter. The targets are BTR-60 (armoured car), BRDM-2 (fighting vehicle), ZSU-23/4 (self-propelled anti-aircraft gun), T62 (tank), and SLICY (a static target composed of multiple simple geometric shapes). The images were captured at a depression angle of 15 degrees. The clutter data is stored in PNG format, and the corresponding ground truth data is stored in the groundTruthMSTARClutterDataset.mat file, which contains 2-D bounding box information for the five classes (SLICY, BTR-60, BRDM-2, ZSU-23/4, and T62) for both the training and test data. The size of the dataset is 1.6 GB.

Download the dataset using the helperDownloadMSTARClutterData helper function, defined at the end of this example.

outputFolder = pwd;
dataURL = 'https://ssd.mathworks.com/supportfiles/radar/data/MSTAR_ClutterDataset.tar.gz';
helperDownloadMSTARClutterData(outputFolder,dataURL);

Depending on your Internet connection, the download can take some time. The code suspends MATLAB® execution until the download is complete. Alternatively, download the dataset to a local disk using your web browser and extract the file. If you use this approach, set the outputFolder variable in the example to the folder that contains the downloaded file.
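The ground truth tables store relative image paths such as ./TrainingImages/Img0001.png, which MATLAB resolves against the current folder. If the extracted data lives somewhere else, one option is to cd into that folder; another is to rewrite the paths. The following is a minimal sketch of the second option. It assumes the archive extracts a TrainingImages folder directly under a hypothetical folder named dataFolder; both that folder and its layout are assumptions, not part of the original example.

```matlab
% Hypothetical folder where the MSTAR clutter archive was extracted
dataFolder = '/data/mstar';

% Relative paths in the form stored in the ground truth tables
paths = {'./TrainingImages/Img0001.png'; './TrainingImages/Img0002.png'};

% Replace the leading './' with the extraction folder so that imread
% finds the files regardless of the current folder
% (assumes './' appears only as the path prefix)
absPaths = strrep(paths, './', [dataFolder '/']);
disp(absPaths{1})
```

In the example itself, the same strrep call could be applied to trainingData.imageFilename and testData.imageFilename after loading the MAT file, since strrep also accepts string arrays.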

Download the Pretrained Network

Download the pretrained network using the helperDownloadPretrainedSARDetectorNet helper function, defined at the end of this example. The pretrained model allows you to run the entire example without waiting for training to complete. To train the network yourself, set the doTrain variable to true.

pretrainedNetURL = 'https://ssd.mathworks.com/supportfiles/radar/data/TrainedSARDetectorNet.tar.gz';
doTrain = false;
if ~doTrain
    helperDownloadPretrainedSARDetectorNet(outputFolder,pretrainedNetURL);
end

Load the Dataset

Load the ground truth data (training set and test set). The images were generated by placing target chips at random locations on a background clutter image, which was constructed from the downloaded raw data. The placed targets serve as the ground truth targets for training and testing the network.
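The code that generated these images is not part of this example. Purely as an illustration of the idea, the sketch below pastes a target chip at a random location in a clutter image and records its box in the [x y width height] form used by the ground truth tables, with x as the column and y as the row of the upper-left pixel. The image sizes, the constant-valued chip, and the direct overwrite without blending are all assumptions, not the procedure used to build the MSTAR clutter dataset.

```matlab
% Hypothetical stand-ins: a 512x512 clutter image and a 28x28 target chip
clutter = rand(512, 512);
chip    = ones(28, 28);
[chipH, chipW] = size(chip);

% Random upper-left corner that keeps the whole chip inside the image
x = randi(size(clutter, 2) - chipW + 1);   % column index
y = randi(size(clutter, 1) - chipH + 1);   % row index

% Paste the chip (simple overwrite) and record its ground truth box
scene = clutter;
scene(y:y+chipH-1, x:x+chipW-1) = chip;
bbox = [x y chipW chipH];                  % [x y width height]
```

Because the box is recorded at the moment of placement, the ground truth is exact by construction, which is one practical advantage of synthesizing large scenes from individual target chips.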

load('groundTruthMSTARClutterDataset.mat', "trainingData", "testData");

The ground truth data is stored in a six-column table, where the first column contains the image file paths and the second through sixth columns contain the bounding boxes of the five targets, each in [x y width height] format.

% Display the first few rows of the data set
trainingData(1:4,:)
ans=4×6 table
            imageFilename                   SLICY                 BTR_60                BRDM_2               ZSU_23_4                  T62        
    ______________________________    __________________    __________________    __________________    ___________________    ___________________

    "./TrainingImages/Img0001.png"    {[ 285 468 28 28]}    {[ 135 331 65 65]}    {[ 597 739 65 65]}    {[ 810 1107 80 80]}    {[1228 1089 87 87]}
    "./TrainingImages/Img0002.png"    {[595 1585 28 28]}    {[ 880 162 65 65]}    {[308 1683 65 65]}    {[1275 1098 80 80]}    {[1274 1099 87 87]}
    "./TrainingImages/Img0003.png"    {[200 1140 28 28]}    {[961 1055 65 65]}    {[306 1256 65 65]}    {[ 661 1412 80 80]}    {[  699 886 87 87]}
    "./TrainingImages/Img0004.png"    {[ 623 186 28 28]}    {[ 536 946 65 65]}    {[ 131 245 65 65]}    {[1030 1266 80 80]}    {[  151 924 87 87]}

Display one of the training images and box labels to visualize the data.

img = imread(trainingData.imageFilename(1));
bbox = reshape(cell2mat(trainingData{1,2:end}),[4,5])';
labels = {'SLICY', 'BTR_60', 'BRDM_2',  'ZSU_23_4', 'T62'};
annotatedImage = insertObjectAnnotation(img,'rectangle',bbox,labels,...
    'TextBoxOpacity',0.9,'FontSize',50);
figure
imshow(annotatedImage);
title('Sample Training Image With Bounding Boxes and Labels')

Figure: Sample Training Image With Bounding Boxes and Labels.
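The reshape and transpose in the display code turn one table row of box cells into a 5-by-4 matrix with one box per row. The sketch below reproduces that step on a hypothetical cell array standing in for trainingData{1,2:end}, using the box values from the first training image, and checks it against the more direct vertcat form.

```matlab
% Stand-in for trainingData{1,2:end}: one 1x4 [x y w h] box per class
boxCells = {[285 468 28 28], [135 331 65 65], [597 739 65 65], ...
            [810 1107 80 80], [1228 1089 87 87]};

flat = cell2mat(boxCells);        % 1x20 row vector of all box values
bbox = reshape(flat, [4, 5])';    % column-major reshape, then one box per row

% vertcat stacks the same 1x4 boxes directly into a 5x4 matrix
bboxAlt = vertcat(boxCells{:});
```

Both forms assume exactly one box per class in each image. If a cell held several boxes, cell2mat would fail on the mismatched sizes, and the label list would also need one entry per box.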

Define the Network Architecture

Create an R-CNN object detector for five targets: SLICY, BTR_60, BRDM_2, ZSU_23_4, T62.

objectClasses = {'SLICY', 'BTR_60', 'BRDM_2', 'ZSU_23_4', 'T62'};

To be trained with the trainRCNNObjectDetector function in Computer Vision Toolbox™, the network must be able to classify the five targets plus a background class. The code below adds 1 to the number of object classes to account for the background class.

numClassesPlusBackground = numel(objectClasses) + 1;

The final fully connected layer of the network defines the number of classes that it can classify. Set the final fully connected layer to have an output size equal to numClassesPlusBackground.
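To see why the output size must equal numClassesPlusBackground, consider a single region: the final fully connected layer produces one value per class, and the softmax layer converts those values into class probabilities. In the sketch below, the logit values and the placement of the background class last are hypothetical.

```matlab
% Five target classes plus background (order assumed for illustration)
classNames = {'SLICY', 'BTR_60', 'BRDM_2', 'ZSU_23_4', 'T62', 'Background'};

% Hypothetical outputs of the final fully connected layer for one region
logits = [0.2 3.1 0.5 -1.0 0.3 1.2];

% Softmax: subtract the maximum for numerical stability, then normalize
probs = exp(logits - max(logits));
probs = probs / sum(probs);

% The predicted class is the one with the highest probability
[maxProb, idx] = max(probs);
predictedClass = classNames{idx};
fprintf('%s (p = %.3f)\n', predictedClass, maxProb);
```

If the fully connected layer produced fewer outputs than there are classes, some class, typically the background, would have no probability at all, which is why the background class is counted explicitly.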

% Define input size 
inputSize = [128,128,1];

% Define network
layers = createNetwork(inputSize,numClassesPlusBackground);

Now, these network layers can be used to train an R-CNN based five-class object detector.
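A quick size calculation shows what the layers in createNetwork do to a 128-by-128 input. The sketch uses the standard output-size formula floor((n - k + 2p)/s) + 1 for a k-by-k window with padding p and stride s; the 3-by-3 convolutions use 'same' padding, so only the pooling layers and the final unpadded 6-by-6 convolution change the spatial size.

```matlab
% Spatial size of the feature maps in createNetwork for a 128x128 input
sz = 128;

% Each of the four blocks ends with maxPooling2dLayer(2,'Stride',2)
for block = 1:4
    sz = floor((sz - 2) / 2) + 1;   % 2x2 window, stride 2, no padding
    fprintf('After block %d: %d x %d\n', block, sz, sz);
end

% convolution2dLayer(6,512) has no padding and stride 1
sz = floor((sz - 6) / 1) + 1;
fprintf('After the 6x6 convolution: %d x %d x 512\n', sz, sz);
```

The first fully connected layer therefore receives a 3-by-3-by-512 volume, that is, 4608 values per region.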

Train R-CNN

Use trainingOptions (Deep Learning Toolbox) to specify network training options. trainingOptions by default uses a GPU if one is available (requires Parallel Computing Toolbox™ and a CUDA® enabled GPU with compute capability 3.0 or higher). Otherwise, it uses a CPU. You can also specify the execution environment by using the ExecutionEnvironment name-value argument of trainingOptions. To detect automatically if you have a GPU available, set ExecutionEnvironment to auto. If you do not have a GPU, or do not want to use one for training, set ExecutionEnvironment to cpu. To ensure the use of a GPU for training, set ExecutionEnvironment to gpu.
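If you prefer to choose the execution environment explicitly rather than rely on 'auto', one approach is to test for a usable GPU first. The sketch below assumes the canUseGPU function available in recent MATLAB releases, which returns true only when a supported GPU and Parallel Computing Toolbox are available; the exist check is an extra guard so the sketch also runs where that function is absent.

```matlab
% Pick 'gpu' when a usable GPU is present, otherwise fall back to 'cpu'
if exist('canUseGPU') && canUseGPU()
    executionEnvironment = 'gpu';
else
    executionEnvironment = 'cpu';
end
fprintf('Training will run on: %s\n', executionEnvironment);

% The value can then be passed to trainingOptions, for example:
% options = trainingOptions('sgdm', 'ExecutionEnvironment', executionEnvironment);
```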

% Set training options
options = trainingOptions('sgdm', ...
    'MiniBatchSize', 128, ...
    'InitialLearnRate', 1e-3, ...
    'LearnRateSchedule', 'piecewise', ...
    'LearnRateDropFactor', 0.1, ...
    'LearnRateDropPeriod', 100, ...
    'MaxEpochs', 10, ...
    'Verbose', true, ...
    'CheckpointPath',tempdir,...
    'ExecutionEnvironment','auto');
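With LearnRateDropPeriod set to 100 and MaxEpochs set to 10, the piecewise schedule never reaches its first drop, so the learning rate stays at 1e-3 for the entire run. The sketch below computes the per-epoch rate from the option values above, assuming that a drop takes effect at the start of epoch dropPeriod + 1.

```matlab
% Values copied from the trainingOptions call above
initialLearnRate = 1e-3;
dropFactor = 0.1;    % LearnRateDropFactor
dropPeriod = 100;    % LearnRateDropPeriod
maxEpochs  = 10;     % MaxEpochs

% Multiply by dropFactor once for every completed dropPeriod epochs
epochs = 1:maxEpochs;
learnRate = initialLearnRate * dropFactor .^ floor((epochs - 1) / dropPeriod);
disp(learnRate)
```

Under the same assumption, a LearnRateDropPeriod of 5 would lower the rate to 1e-4 for epochs 6 through 10.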

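If doTrain is true, use trainRCNNObjectDetector to train an R-CNN object detector; otherwise, load the pretrained detector. When training, the PositiveOverlapRange and NegativeOverlapRange arguments control which region proposals are used as positive and negative training samples, based on their overlap ratio (intersection over union) with the ground truth boxes. Adjust these ranges so that positive samples overlap the ground truth tightly.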
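The overlap ranges are easier to reason about with concrete numbers. The sketch below computes intersection over union for a ground truth box and four hypothetical region proposals, then sorts the proposals using the ranges [0.5 1] and [0.1 0.5] passed to trainRCNNObjectDetector. The boxIoU helper is written here for illustration, and the handling of the range endpoints is an assumption of this sketch.

```matlab
% Intersection over union of two boxes in [x y width height] form
interW = @(a, b) max(0, min(a(1) + a(3), b(1) + b(3)) - max(a(1), b(1)));
interH = @(a, b) max(0, min(a(2) + a(4), b(2) + b(4)) - max(a(2), b(2)));
interA = @(a, b) interW(a, b) * interH(a, b);
boxIoU = @(a, b) interA(a, b) / (a(3) * a(4) + b(3) * b(4) - interA(a, b));

% BTR_60 ground truth box from the first training image and four
% hypothetical region proposals around it
gt = [135 331 65 65];
proposals = [135 331 65 65; 145 331 65 65; 180 331 65 65; 300 300 65 65];

iou = arrayfun(@(k) boxIoU(proposals(k, :), gt), 1:size(proposals, 1));

% Assign proposals to the positive and negative overlap ranges
isPositive = iou >= 0.5 & iou <= 1;     % PositiveOverlapRange [0.5 1]
isNegative = iou >= 0.1 & iou < 0.5;    % NegativeOverlapRange [0.1 0.5]
disp([iou; isPositive; isNegative])
```

The first two proposals (IoU 1 and about 0.73) become positive samples and the third (about 0.18) becomes a background sample. The fourth does not overlap the target at all, so it falls outside both ranges and, under this sketch's reading, is not used for training.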

if doTrain
    % Train an R-CNN object detector. This will take several minutes
    detector = trainRCNNObjectDetector(trainingData, layers, options, ...
        'PositiveOverlapRange', [0.5 1], 'NegativeOverlapRange', [0.1 0.5]);
else
    % Load a previously trained detector
    preTrainedMATFile = fullfile(outputFolder,'TrainedSARDetectorNet.mat');
    load(preTrainedMATFile);
end

Evaluate Detector on a Test Image

To get a qualitative idea of the functioning of the detector, pick a random image from the test set and run it through the detector. The detector is expected to return a collection of bounding boxes where it thinks the detected targets are, along with scores indicating confidence in each detection.

% Read test image
imgIdx = randi(height(testData));
testImage = imread(testData.imageFilename(imgIdx));

% Detect SAR targets in the test image
[bboxes,score,label] = detect(detector,testImage,'MiniBatchSize',16);

To understand the results, overlay the detections on the test image. A key parameter is the detection threshold: a detection is reported as a target only if its score exceeds this value. A higher threshold results in fewer false positives but more false negatives.
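The effect of the threshold can be checked directly on the detector scores. The sketch below uses hypothetical scores and boxes for one image, counts how many detections survive several thresholds with the same strict comparison (>) used in the display code, and shows a vectorized way to keep only the confident detections.

```matlab
% Hypothetical scores and boxes returned by detect for one test image
demoScores = [0.95; 0.62; 0.81; 0.40];
demoBoxes  = [100 120 65 65; 400 380 65 65; 700 90 80 80; 50 600 28 28];

% Count the detections that survive a few candidate thresholds
for t = [0.5 0.8 0.9]
    keep = demoScores > t;
    fprintf('Threshold %.1f keeps %d of %d detections\n', ...
        t, nnz(keep), numel(demoScores));
end

% Vectorized alternative to filtering inside a loop
keptBoxes  = demoBoxes(demoScores > 0.8, :);
keptScores = demoScores(demoScores > 0.8);
```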

scoreThreshold = 0.8;

% Display the detection results
outputImage = testImage;
for idx = 1:length(score)
    bbox = bboxes(idx, :);
    thisScore = score(idx);
    
    if thisScore > scoreThreshold
        annotation = sprintf('%s: (Confidence = %0.2f)', char(label(idx)), ...
            round(thisScore,2));
        outputImage = insertObjectAnnotation(outputImage, 'rectangle', bbox,...
            annotation,'TextBoxOpacity',0.9,'FontSize',45,'LineWidth',2);
    end
end
f = figure;
f.Position(3:4) = [860,740];
imshow(outputImage)
title('Predicted Boxes and Labels on Test Image')

Figure: Predicted Boxes and Labels on Test Image.

Evaluate Model

Inspecting individual images gives a qualitative sense of the detector's performance. For a more rigorous analysis, run the detector on the entire test set.

% Create a table to hold the bounding boxes, scores and labels output by the detector
numImages = height(testData);
results = table('Size',[numImages 3],...
    'VariableTypes',{'cell','cell','cell'},...
    'VariableNames',{'Boxes','Scores','Labels'});

% Run detector on each image in the test set and collect results
for i = 1:numImages
    imgFilename = testData.imageFilename{i};
    
    % Read the image
    I = imread(imgFilename);
    
    % Run the detector
    [bboxes, scores, labels] = detect(detector, I,'MiniBatchSize',16);
    
    % Collect the results
    results.Boxes{i} = bboxes;
    results.Scores{i} = scores;
    results.Labels{i} = labels;
end

The detections and bounding boxes collected for all test images can be used to calculate the detector's average precision (AP) for each class. The AP summarizes the detector's precision across different levels of recall, where precision and recall are defined as follows.

  • $\text{Precision} = \frac{tp}{tp + fp}$

  • $\text{Recall} = \frac{tp}{tp + fn}$

where

  • tp - Number of true positives (the detector predicts a target when it is present)

  • fp - Number of false positives (the detector predicts a target when it is not present)

  • fn - Number of false negatives (the detector fails to detect a target when it is present)

A detector with a precision of 1 produces no false detections, while a detector with a recall of 1 finds every target that is present. Precision and recall usually trade off against each other: raising the detection threshold tends to increase precision and decrease recall.
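As a worked example of the two formulas, the sketch below uses hypothetical counts for one class at one score threshold.

```matlab
% Hypothetical counts for one class at one score threshold
tp = 8;   % targets correctly detected
fp = 2;   % detections with no matching target
fn = 4;   % targets the detector missed

precision = tp / (tp + fp);   % fraction of detections that are correct
recall    = tp / (tp + fn);   % fraction of targets that were found
fprintf('Precision = %.3f, Recall = %.3f\n', precision, recall);
```

Here 80% of the reported detections are correct, but only two thirds of the targets are found. Raising the threshold could remove the two false positives, giving a precision of 1, but if it also discarded some true detections, those would turn into false negatives and the recall would fall further.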

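Plot the precision-recall curve for each class. The AP of a class is the area under its curve, that is, its precision averaged over recall. The curves are computed with an overlap threshold of 0.5: a detection counts as a true positive only if its box overlaps a ground truth box of the same class with an intersection over union of at least 0.5.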
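The sketch below shows how a precision-recall curve and its AP can be derived from scored detections for one class. It is a simple, non-interpolated version written for illustration: the scores, the true-positive flags (assumed to be already matched against the ground truth at an overlap of 0.5), and the ground truth count are hypothetical, and evaluateObjectDetection may interpolate the curve, so its values can differ slightly.

```matlab
% Hypothetical detections for one class: scores and true-positive flags
scores = [0.70 0.95 0.60 0.90 0.80];
isTP   = [1    1    0    1    0   ];
numGroundTruth = 4;                  % targets of this class in the test set

% Rank detections from most to least confident
[~, order] = sort(scores, 'descend');
tpSorted = isTP(order);

% Precision and recall after accepting the top k detections, k = 1..5
cumTP = cumsum(tpSorted);
cumFP = cumsum(~tpSorted);
precision = cumTP ./ (cumTP + cumFP);
recall    = cumTP / numGroundTruth;

% Area under the step-shaped curve: precision weighted by each recall step
ap = sum(diff([0 recall]) .* precision);
fprintf('AP = %.4f\n', ap);
```

Only detections that add a new true positive increase the recall, so false positives contribute nothing to the area directly; they lower the precision at which later true positives are counted.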

For more details, see evaluateObjectDetection (Computer Vision Toolbox).

% Format test data as a combined datastore 
imds = imageDatastore(testData.imageFilename);
blds = boxLabelDatastore(testData(:,2:end));
cds  = combine(imds,blds); % CombinedDatastore

% Evaluate the object detector using average precision metric
metrics   = evaluateObjectDetection(results,cds);
ap        = metrics.ClassMetrics.AP;
precision = metrics.ClassMetrics.Precision;
recall    = metrics.ClassMetrics.Recall;

% Plot precision recall curve
f = figure; ax = gca; f.Position(3:4) = [860,740];
xlabel('Recall')
ylabel('Precision')
grid on; hold on; legend('Location', 'southeast');
title('Precision Versus Recall');    
for i = 1:length(ap)
    plot(ax,recall{i},precision{i},'DisplayName',['Average Precision for Class ' trainingData.Properties.VariableNames{i+1} ' is ' num2str(round(ap{i},3))])
end

Figure: Precision Versus Recall. Average precision per class: SLICY 0.719, BTR_60 0.978, BRDM_2 0.855, ZSU_23_4 0.793, T62 0.98.

The AP varies across classes. The detector performs best on the T62 (0.98) and BTR_60 (0.978) targets, followed by BRDM_2 (0.855) and ZSU_23_4 (0.793). The trained model struggles the most with the SLICY targets, although it still achieves an AP of about 0.72 for that class.
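A single summary number across classes is the mean AP. The sketch below computes it from the per-class values reported in the precision-recall figure; the unweighted mean is one common convention, not necessarily the one used elsewhere in the toolbox.

```matlab
% Per-class AP values reported in the precision-recall figure
classNames = {'SLICY', 'BTR_60', 'BRDM_2', 'ZSU_23_4', 'T62'};
apValues   = [0.719 0.978 0.855 0.793 0.98];

meanAP = mean(apValues);             % unweighted mean over the five classes
[minAP, worst] = min(apValues);
fprintf('Mean AP = %.3f; lowest AP = %.3f (%s)\n', ...
    meanAP, minAP, classNames{worst});
```

With the variables from the evaluation step, mean(cell2mat(ap)) should give the same value, assuming each cell of ap holds a single AP value as in the plotting code.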

Helper Functions

The createNetwork function takes the image size inputSize and the number of classes numClassesPlusBackground as inputs and returns the layer array of a CNN whose final fully connected layer has numClassesPlusBackground outputs.

function layers = createNetwork(inputSize,numClassesPlusBackground)
    layers = [
        imageInputLayer(inputSize)                      % Input Layer
        convolution2dLayer(3,32,'Padding','same')       % Convolution Layer
        reluLayer                                       % Relu Layer
        convolution2dLayer(3,32,'Padding','same')
        batchNormalizationLayer                         % Batch normalization Layer
        reluLayer
        maxPooling2dLayer(2,'Stride',2)                 % Max Pooling Layer
        
        convolution2dLayer(3,64,'Padding','same')
        reluLayer
        convolution2dLayer(3,64,'Padding','same')
        batchNormalizationLayer
        reluLayer
        maxPooling2dLayer(2,'Stride',2)
        
        convolution2dLayer(3,128,'Padding','same')
        reluLayer
        convolution2dLayer(3,128,'Padding','same')
        batchNormalizationLayer
        reluLayer
        maxPooling2dLayer(2,'Stride',2)

        convolution2dLayer(3,256,'Padding','same')
        reluLayer
        convolution2dLayer(3,256,'Padding','same')
        batchNormalizationLayer
        reluLayer
        maxPooling2dLayer(2,'Stride',2)
    
        convolution2dLayer(6,512)
        reluLayer
        
        dropoutLayer(0.5)                               % Dropout Layer
        fullyConnectedLayer(512)                        % Fully connected Layer.
        reluLayer
        fullyConnectedLayer(numClassesPlusBackground)
        softmaxLayer                                    % Softmax Layer
        classificationLayer                             % Classification Layer
        ];

end

function helperDownloadMSTARClutterData(outputFolder,DataURL)
% Download the data set from the given URL to the output folder.

    radarDataTarFile = fullfile(outputFolder,'MSTAR_ClutterDataset.tar.gz');
    
    if ~exist(radarDataTarFile,'file')
        
        disp('Downloading MSTAR Clutter data (1.6 GB)...');
        websave(radarDataTarFile,DataURL);
        untar(radarDataTarFile,outputFolder);
    end
end

function helperDownloadPretrainedSARDetectorNet(outputFolder,pretrainedNetURL)
% Download the pretrained network.

    preTrainedMATFile = fullfile(outputFolder,'TrainedSARDetectorNet.mat');
    preTrainedZipFile = fullfile(outputFolder,'TrainedSARDetectorNet.tar.gz');
    
    if ~exist(preTrainedMATFile,'file')
        if ~exist(preTrainedZipFile,'file')
            disp('Downloading pretrained detector (29.4 MB)...');
            websave(preTrainedZipFile,pretrainedNetURL);
        end
        untar(preTrainedZipFile,outputFolder);   
    end       
end

Summary

This example shows how to train an R-CNN for target recognition in SAR images. On the test set, the pretrained network attained per-class average precision values ranging from about 0.72 (SLICY) to 0.98 (T62).

References