Code Generation for Deep Learning Networks

This example shows how to generate CUDA code for an image classification application that uses deep learning. It uses the codegen command to generate a MEX function that runs prediction by using the ResNet-50 image classification network.

Third-Party Prerequisites

This example generates CUDA® MEX and has the following third-party requirements.

  • CUDA-enabled NVIDIA® GPU and compatible driver.

For non-MEX builds, such as static libraries, dynamic libraries, or executables, this example has the following additional requirements.

Verify GPU Environment

Use the coder.checkGpuInstall function to verify that the compilers and libraries necessary for running this example are set up correctly.

envCfg = coder.gpuEnvConfig('host');
envCfg.DeepLibTarget = 'none';
envCfg.DeepCodegen = 1;
envCfg.Quiet = 1;
coder.checkGpuInstall(envCfg);
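
Because Quiet is set to 1, the check prints nothing to the Command Window. As a minimal sketch, assuming the envCfg object configured above, you can capture the value that coder.checkGpuInstall returns and display it to see which checks ran and whether they passed. The variable name results is chosen here for illustration.

```matlab
% Assumes envCfg was created and configured as shown above.
% Capture the structure of check results instead of discarding it.
results = coder.checkGpuInstall(envCfg);

% Display the result fields so you can confirm the environment
% before running code generation.
disp(results)
```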

Classify Images by Using the ResNet-50 Network

ResNet-50 is a convolutional neural network that is 50 layers deep and can classify images into 1000 object categories. A pretrained ResNet-50 model for MATLAB® is available in the Deep Learning Toolbox™ model for ResNet-50 Network support package. Use the Add-On Explorer to download and install the support package.

[net, classNames] = imagePretrainedNetwork('resnet50');
disp(net)
  dlnetwork with properties:

         Layers: [176×1 nnet.cnn.layer.Layer]
    Connections: [191×2 table]
     Learnables: [214×3 table]
          State: [106×3 table]
     InputNames: {'input_1'}
    OutputNames: {'fc1000_softmax'}
    Initialized: 1

  View summary with summary.
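
Before you generate code, it can help to confirm the input size that the network expects and the number of classes it returns. The following is a minimal sketch, assuming that net and classNames come from the imagePretrainedNetwork call above and that the first layer of the network is its image input layer.

```matlab
% Assumes net and classNames exist in the workspace.
% Assumption: net.Layers(1) is the image input layer of the network.
inputSize = net.Layers(1).InputSize;   % size expected by the input layer
numClasses = numel(classNames);        % number of object categories

fprintf('Input size: %s\n', mat2str(inputSize));
fprintf('Number of classes: %d\n', numClasses);
```

The reported input size is the size to pass to the -args option of codegen and the size to which you resize test images.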

resnet_predict Entry-Point Function

The resnet_predict.m entry-point function takes an image as input and runs prediction on it by using the pretrained ResNet-50 deep learning network. On the first call, the function uses imagePretrainedNetwork (Deep Learning Toolbox) to load the dlnetwork object into the persistent variable dlnet. On subsequent calls, it reuses the persistent object for prediction. The function creates the dlarray object inside the entry point, so the input and output of the entry-point function are of primitive data types. For more information, see Code Generation for dlarray.

function out = resnet_predict(in) %#codegen
% Copyright 2020-2024 The MathWorks, Inc.

persistent dlnet;

dlIn = dlarray(in, 'SSC');
if isempty(dlnet)
    % Load the pretrained ResNet-50 dlnetwork object on the first call.
    % Later calls reuse the persistent object.
    dlnet = imagePretrainedNetwork('resnet50');
end

dlOut = predict(dlnet, dlIn);
out = extractdata(dlOut);

end


Run MEX Code Generation

To generate CUDA code for the resnet_predict.m entry-point function, create a GPU code configuration object for a MEX target. Use the coder.DeepLearningConfig function to create a deep learning code configuration object and assign it to the DeepLearningConfig property of the GPU code configuration object. Run the codegen command and specify an input size of 224-by-224-by-3, which corresponds to the input layer size of the network.

cfg = coder.gpuConfig('mex');
dlcfg = coder.DeepLearningConfig(TargetLibrary = "none");
cfg.DeepLearningConfig = dlcfg;
codegen -config cfg resnet_predict -args {ones(224,224,3,'single')} -report
Code generation successful: View report

Run Generated MEX

Call resnet_predict_mex on the input image.

im = imread('peppers.png');
im = imresize(im, [224,224]);
predict_scores = resnet_predict_mex(single(im));
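
As an optional sanity check, you can compare the MEX output with scores computed directly in MATLAB. The following is a minimal sketch, assuming that net, im, and predict_scores already exist in the workspace. The variable names refScores and maxDiff are chosen here for illustration. Small differences are expected because the two paths can evaluate floating-point operations in a different order.

```matlab
% Assumes net, im, and predict_scores exist in the workspace.
dlIn = dlarray(single(im), 'SSC');            % same format as in resnet_predict
refScores = extractdata(predict(net, dlIn));  % reference scores from MATLAB

% Largest absolute difference between the two score vectors
maxDiff = max(abs(refScores(:) - predict_scores(:)));
fprintf('Max absolute difference: %g\n', maxDiff);
```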

Map the Prediction Scores to Labels and Display Output

Get the top five prediction scores and their labels.

[scores,indx] = sort(predict_scores, 'descend');
classNamesTop = classNames(indx(1:5));

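h = figure;
h.Position(3) = 2*h.Position(3);
ax1 = subplot(1,2,1);
ax2 = subplot(1,2,2);

% Show the input image on the left and the top five scores on the right.
image(ax1, im);
barh(ax2, scores(5:-1:1))
xlabel(ax2, 'Probability')
yticklabels(ax2, classNamesTop(5:-1:1))
ax2.YAxisLocation = 'right';
sgtitle('Top Five Predictions That Use ResNet-50')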

[Figure: two axes under the title "Top Five Predictions That Use ResNet-50". The left axes shows the input image. The right axes shows a bar chart of the top five scores, with the x-axis labeled Probability.]
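
If you only need the result as text, for example when you run without a display, you can print the same information. The following is a minimal sketch, assuming that scores and classNamesTop were computed as shown above.

```matlab
% Assumes scores and classNamesTop exist in the workspace.
% Print each of the top five labels with its prediction score.
for k = 1:5
    fprintf('%-25s %6.4f\n', classNamesTop(k), scores(k));
end
```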

Clear the static network object that was loaded in memory.

clear resnet_predict_mex;
