predict
Syntax
Description
Some deep learning layers behave differently during training and inference (prediction). For example, during training, dropout layers randomly set input elements to zero to help prevent overfitting, but during inference, dropout layers do not change the input.
To compute network outputs for inference, use the predict function. To compute network outputs for training, use the forward function.
Tip
For prediction with SeriesNetwork and DAGNetwork objects, see predict.
[Y1,...,YN] = predict(___) returns the N outputs Y1, …, YN during inference for networks that have N outputs using any of the previous syntaxes.
[Y1,...,YK] = predict(___,'Outputs',layerNames) returns the outputs Y1, …, YK during inference for the specified layers using any of the previous syntaxes.
[___] = predict(___,'Acceleration',acceleration) also specifies the performance optimization to use during inference, in addition to the input arguments in previous syntaxes.
[___,state] = predict(___) also returns the updated network state.
Examples
Make Predictions Using dlnetwork Object
This example shows how to make predictions using a dlnetwork
object by splitting data into mini-batches.
For large data sets, or when predicting on hardware with limited memory, make predictions by splitting the data into mini-batches. When making predictions with SeriesNetwork
or DAGNetwork
objects, the predict
function automatically splits the input data into mini-batches. For dlnetwork
objects, you must split the data into mini-batches manually.
Load dlnetwork Object
Load a trained dlnetwork
object and the corresponding classes.
s = load("digitsCustom.mat");
dlnet = s.dlnet;
classes = s.classes;
Load Data for Prediction
Load the digits data for prediction.
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true);
Make Predictions
Loop over the mini-batches of the test data and make predictions using a custom prediction loop.
Use minibatchqueue
to process and manage the mini-batches of images. Specify a mini-batch size of 128. Set the read size property of the image datastore to the mini-batch size.
For each mini-batch:
- Use the custom mini-batch preprocessing function preprocessMiniBatch (defined at the end of this example) to concatenate the data into a batch and normalize the images.
- Format the images with the dimensions 'SSCB' (spatial, spatial, channel, batch). By default, the minibatchqueue object converts the data to dlarray objects with underlying type single.
- Make predictions on a GPU if one is available. By default, the minibatchqueue object converts the output to a gpuArray if a GPU is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).
miniBatchSize = 128;
imds.ReadSize = miniBatchSize;

mbq = minibatchqueue(imds, ...
    "MiniBatchSize",miniBatchSize, ...
    "MiniBatchFcn",@preprocessMiniBatch, ...
    "MiniBatchFormat","SSCB");
Loop over the mini-batches of data and make predictions using the predict
function. Use the onehotdecode
function to determine the class labels. Store the predicted class labels.
numObservations = numel(imds.Files);
YPred = strings(1,numObservations);

predictions = [];

% Loop over mini-batches.
while hasdata(mbq)
    % Read mini-batch of data.
    dlX = next(mbq);

    % Make predictions using the predict function.
    dlYPred = predict(dlnet,dlX);

    % Determine corresponding classes.
    predBatch = onehotdecode(dlYPred,classes,1);
    predictions = [predictions predBatch];
end
Visualize some of the predictions.
idx = randperm(numObservations,9);
figure
for i = 1:9
    subplot(3,3,i)
    I = imread(imds.Files{idx(i)});
    label = predictions(idx(i));
    imshow(I)
    title("Label: " + string(label))
end
Mini-Batch Preprocessing Function
The preprocessMiniBatch
function preprocesses the data using the following steps:
- Extract the data from the incoming cell array and concatenate into a numeric array. Concatenating over the fourth dimension adds a third dimension to each image, to be used as a singleton channel dimension.
- Normalize the pixel values between 0 and 1.
function X = preprocessMiniBatch(data)
    % Extract image data from cell array and concatenate.
    X = cat(4,data{:});

    % Normalize the images.
    X = X/255;
end
Input Arguments
net
— Network for custom training loops or custom pruning loops
dlnetwork object | TaylorPrunableNetwork object
This argument can represent either of these:
- Network for custom training loops, specified as a dlnetwork object.
- Network for custom pruning loops, specified as a TaylorPrunableNetwork object.
To prune a deep neural network, you require the Deep Learning Toolbox™ Model Quantization Library support package. This support package is a free add-on that you can download using the Add-On Explorer. Alternatively, see Deep Learning Toolbox Model Quantization Library.
layerNames
— Layers to extract outputs from
string array | cell array of character vectors
Layers to extract outputs from, specified as a string array or a cell array of character vectors containing the layer names.
- If layerNames(i) corresponds to a layer with a single output, then layerNames(i) is the name of the layer.
- If layerNames(i) corresponds to a layer with multiple outputs, then layerNames(i) is the layer name followed by the character "/" and the name of the layer output: 'layerName/outputName'.
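As a sketch of both forms, assume a network with a single-output layer named 'conv_1' and a multiple-output layer named 'lstm' whose outputs include one named 'out' (the layer names here are hypothetical; substitute the names from your own network):

```matlab
% Single-output layer: use the layer name directly.
Y = predict(net,dlX,'Outputs','conv_1');

% Multiple-output layer: append "/" and the output name.
[Y1,Y2] = predict(net,dlX,'Outputs',["conv_1","lstm/out"]);
```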
acceleration
— Performance optimization
'auto' (default) | 'mex' | 'none'
Performance optimization, specified as the comma-separated pair consisting of
'Acceleration'
and one of the following:
- 'auto' — Automatically apply a number of optimizations suitable for the input network and hardware resources.
- 'mex' — Compile and execute a MEX function. This option is available only when using a GPU. The input data or the network learnable parameters must be stored as gpuArray objects. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). If Parallel Computing Toolbox or a suitable GPU is not available, then the software returns an error.
- 'none' — Disable all acceleration.
The default option is 'auto'. If you specify 'auto', MATLAB® applies a number of compatible optimizations and does not generate a MEX function.
Using the 'Acceleration'
options 'auto'
and
'mex'
can offer performance benefits, but at the expense of an
increased initial run time. Subsequent calls with compatible parameters are faster. Use
performance optimization when you plan to call the function multiple times using new
input data.
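As a sketch of this pattern, assuming a dlnetwork named net and two gpuArray-backed mini-batches dlX1 and dlX2 (all hypothetical names), the first call pays the compilation cost and subsequent compatible calls reuse the result:

```matlab
% First call compiles the MEX function, so it is slower.
Y1 = predict(net,dlX1,'Acceleration','mex');

% Subsequent calls with compatible parameters reuse the compiled MEX function.
Y2 = predict(net,dlX2,'Acceleration','mex');
```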
The 'mex'
option generates and executes a MEX function based on the
network and parameters used in the function call. You can have several MEX functions
associated with a single network at one time. Clearing the network variable also clears
any MEX functions associated with that network.
The 'mex'
option is only available when you are using a GPU. You
must have a C/C++ compiler installed and the GPU Coder™ Interface for Deep Learning support package. Install the support package
using the Add-On Explorer in MATLAB. For setup instructions, see MEX Setup (GPU Coder). GPU Coder is not required.
The 'mex'
option has the following limitations:
- The state output argument is not supported.
- Only single precision is supported. The input data or the network learnable parameters must have underlying type single.
- Networks with inputs that are not connected to an input layer are not supported.
- Traced dlarray objects are not supported. This means that the 'mex' option is not supported inside a call to dlfeval.
- Not all layers are supported. For a list of supported layers, see Supported Layers (GPU Coder).
- You cannot use MATLAB Compiler™ to deploy your network when using the 'mex' option.
For quantized networks, the 'mex'
option requires a CUDA® enabled NVIDIA® GPU with compute capability 6.1, 6.3, or higher.
Example: 'Acceleration','mex'
Output Arguments
state
— Updated network state
table
Updated network state, returned as a table.
The network state is a table with three columns:
- Layer – Layer name, specified as a string scalar.
- Parameter – State parameter name, specified as a string scalar.
- Value – Value of state parameter, specified as a dlarray object.
Layer states contain information calculated during the layer operation to be retained for use in subsequent forward passes of the layer. Examples include the cell state and hidden state of LSTM layers, and the running statistics of batch normalization layers.
For recurrent layers, such as LSTM layers, with the HasStateInputs
property set to 1
(true), the state table does not contain
entries for the states of that layer.
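A common pattern, sketched here with a hypothetical dlnetwork net that contains stateful layers (such as batch normalization) and a formatted dlarray input dlX, is to capture the returned state and write it back to the network so the next call continues from it:

```matlab
% Request the updated state alongside the predictions.
[Y,state] = predict(net,dlX);

% Write the state back so subsequent calls use the updated values.
net.State = state;
```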
Algorithms
Reproducibility
To provide the best performance, deep learning using a GPU in MATLAB is not guaranteed to be deterministic. Depending on your network architecture, under some conditions you might get different results when using a GPU to train two identical networks or make two predictions using the same network and data.
Extended Capabilities
C/C++ Code Generation
Generate C and C++ code using MATLAB® Coder™.
Usage notes and limitations:
C++ code generation supports the following syntaxes:
- Y = predict(net,X)
- Y = predict(net,X1,...,XM)
- [Y1,...,YN] = predict(__)
- [Y1,...,YK] = predict(__,'Outputs',layerNames)
The input data X must not have variable size. The size must be fixed at code generation time.
The dlarray input to the predict method must be a single datatype.
GPU Code Generation
Generate CUDA® code for NVIDIA® GPUs using GPU Coder™.
Usage notes and limitations:
GPU code generation supports the following syntaxes:
- Y = predict(net,X)
- Y = predict(net,X1,...,XM)
- [Y1,...,YN] = predict(__)
- [Y1,...,YK] = predict(__,'Outputs',layerNames)
The input data X must not have variable size. The size must be fixed at code generation time.
Code generation for the TensorRT library does not support marking an input layer as an output by using the [Y1,...,YK] = predict(__,'Outputs',layerNames) syntax.
The dlarray input to the predict method must be a single datatype.
GPU Arrays
Accelerate code by running on a graphics processing unit (GPU) using Parallel Computing Toolbox™.
Usage notes and limitations:
This function runs on the GPU if either or both of the following conditions are met:
- Any of the values of the network learnable parameters inside net.Learnables.Value are dlarray objects with underlying data of type gpuArray.
- The input argument X is a dlarray with underlying data of type gpuArray.
For more information, see Run MATLAB Functions on a GPU (Parallel Computing Toolbox).
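For example, to satisfy the second condition, you can move the input data to the GPU before calling predict. This sketch assumes a numeric array X of images already arranged in 'SSCB' layout:

```matlab
% Convert to single, move to the GPU, and wrap in a formatted dlarray.
dlX = dlarray(gpuArray(single(X)),'SSCB');

Y = predict(net,dlX);
```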
Version History
Introduced in R2019b
R2021a: predict returns state values as dlarray objects
For dlnetwork
objects, the state
output argument returned by the predict
function is
a table containing the state parameter names and values for each layer in the network.
Starting in R2021a, the state values are dlarray
objects.
This change enables better support when using AcceleratedFunction
objects. To accelerate deep learning functions that have frequently changing input values,
for example, an input containing the network state, the frequently changing values must be
specified as dlarray
objects.
In previous versions, the state values were numeric arrays.
In most cases, you will not need to update your code. If you have code that requires the
state values to be numeric arrays, then to reproduce the previous behavior, extract the data
from the state values manually using the extractdata
function with the dlupdate
function.
state = dlupdate(@extractdata,net.State);