lstm
Long short-term memory
Syntax
Description
The long short-term memory (LSTM) operation allows a network to learn long-term dependencies between time steps in time series and sequence data.
Y = lstm(X,H0,C0,weights,recurrentWeights,bias) applies a long short-term memory (LSTM) calculation to input X using the initial hidden state H0, initial cell state C0, and parameters weights, recurrentWeights, and bias. The input X must be a formatted dlarray. The output Y is a formatted dlarray with the same dimension format as X, except for any "S" dimensions.
By default, the lstm function updates the cell and hidden states using the hyperbolic tangent function (tanh) as the state activation function, and uses the sigmoid function, given by $\sigma(x) = (1 + e^{-x})^{-1}$, as the gate activation function.
[Y,hiddenState,cellState] = lstm(X,H0,C0,weights,recurrentWeights,bias) also returns the hidden state and cell state after the LSTM operation.
___ = lstm(___,Name=Value) specifies additional options using one or more name-value arguments.
Examples
Apply LSTM Operation to Sequence Data
Perform an LSTM operation using three hidden units.
Create the input sequence data as 32 observations with 10 channels and a sequence length of 64.
numFeatures = 10;
numObservations = 32;
sequenceLength = 64;
X = randn(numFeatures,numObservations,sequenceLength);
X = dlarray(X,"CBT");
Create the initial hidden and cell states with three hidden units. Use the same initial hidden state and cell state for all observations.
numHiddenUnits = 3;
H0 = zeros(numHiddenUnits,1);
C0 = zeros(numHiddenUnits,1);
Create the learnable parameters for the LSTM operation.
weights = dlarray(randn(4*numHiddenUnits,numFeatures),"CU");
recurrentWeights = dlarray(randn(4*numHiddenUnits,numHiddenUnits),"CU");
bias = dlarray(randn(4*numHiddenUnits,1),"C");
Perform the LSTM calculation.
[Y,hiddenState,cellState] = lstm(X,H0,C0,weights,recurrentWeights,bias);
View the size and dimensions of the output.
size(Y)
ans = 1×3
3 32 64
dims(Y)
ans = 'CBT'
View the size of the hidden and cell states.
size(hiddenState)
ans = 1×2
3 32
size(cellState)
ans = 1×2
3 32
Input Arguments
X — Input data
dlarray | numeric array
Input data, specified as a formatted dlarray, an unformatted dlarray, or a numeric array. When X is not a formatted dlarray, you must specify the dimension label format using the DataFormat option. If X is a numeric array, at least one of H0, C0, weights, recurrentWeights, or bias must be a dlarray.
X must contain a sequence dimension labeled "T". If X has any spatial dimensions labeled "S", they are flattened into the "C" channel dimension. If X does not have a channel dimension, then one is added. If X has any unspecified dimensions labeled "U", they must be singleton.
H0 — Initial hidden state vector
dlarray | numeric array
Initial hidden state vector, specified as a formatted dlarray, an unformatted dlarray, or a numeric array.
If H0 is a formatted dlarray, it must contain a channel dimension labeled "C" and optionally a batch dimension labeled "B" with the same size as the "B" dimension of X. If H0 does not have a "B" dimension, the function uses the same hidden state vector for each observation in X.
The size of the "C" dimension determines the number of hidden units. The size of the "C" dimension of H0 must be equal to the size of the "C" dimension of C0.
If H0 is not a formatted dlarray, the size of the first dimension determines the number of hidden units and must be the same as the size of the first dimension or the "C" dimension of C0.
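As a sketch of how the number of hidden units and the batch expansion work, the hypothetical helpers below operate on plain Python lists rather than dlarray objects. A state without a "B" dimension is a flat list that is copied for every observation; a state with a "B" dimension is an H-by-B nested list that must already match the batch size of X. The same rules apply to C0.

```python
def num_hidden_units(h0, c0):
    """Number of hidden units implied by H0 and C0 (size of the first or "C" dimension)."""
    if len(h0) != len(c0):
        raise ValueError("H0 and C0 must have the same number of hidden units")
    return len(h0)


def expand_initial_state(state, batch_size):
    """Return an H-by-B nested list of initial states.

    state is either a flat list of H values (no "B" dimension), which is
    reused for every observation, or an H-by-B nested list whose second
    dimension must already match batch_size.
    """
    if state and isinstance(state[0], list):
        if any(len(row) != batch_size for row in state):
            raise ValueError('"B" dimension of the state must match that of X')
        return [list(row) for row in state]
    return [[value] * batch_size for value in state]


h0 = [0.0, 0.0, 0.0]  # three hidden units, no "B" dimension
c0 = [0.0, 0.0, 0.0]
print(num_hidden_units(h0, c0))  # 3
print(expand_initial_state([0.1, 0.2], 3))  # [[0.1, 0.1, 0.1], [0.2, 0.2, 0.2]]
```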
C0 — Initial cell state vector
dlarray | numeric array
Initial cell state vector, specified as a formatted dlarray, an unformatted dlarray, or a numeric array.
If C0 is a formatted dlarray, it must contain a channel dimension labeled "C" and optionally a batch dimension labeled "B" with the same size as the "B" dimension of X. If C0 does not have a "B" dimension, the function uses the same cell state vector for each observation in X.
The size of the "C" dimension determines the number of hidden units. The size of the "C" dimension of C0 must be equal to the size of the "C" dimension of H0.
If C0 is not a formatted dlarray, the size of the first dimension determines the number of hidden units and must be the same as the size of the first dimension or the "C" dimension of H0.
weights — Weights
dlarray | numeric array
Weights, specified as a formatted dlarray, an unformatted dlarray, or a numeric array.
Specify weights as a matrix of size 4*NumHiddenUnits-by-InputSize, where NumHiddenUnits is the size of the "C" dimension of both C0 and H0, and InputSize is the size of the "C" dimension of X multiplied by the size of each "S" dimension of X, where present.
If weights is a formatted dlarray, it must contain a "C" dimension of size 4*NumHiddenUnits and a "U" dimension of size InputSize.
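Each of the 4*NumHiddenUnits rows of weights belongs to one of the four LSTM components. The sketch below splits such a stacked matrix (or a stacked bias vector) into its four blocks. The block order shown (input gate, forget gate, cell candidate, output gate) is an assumption based on how the lstmLayer documentation describes its stacked weights; `split_gate_blocks` and `GATE_ORDER` are hypothetical names.

```python
# Assumed order of the four stacked blocks (as described for lstmLayer weights).
GATE_ORDER = ("input", "forget", "cell_candidate", "output")


def split_gate_blocks(rows, num_hidden_units):
    """Split a 4*NumHiddenUnits-row matrix (list of rows) or vector into gate blocks."""
    if len(rows) != 4 * num_hidden_units:
        raise ValueError("expected 4*NumHiddenUnits rows")
    h = num_hidden_units
    return {name: rows[k * h:(k + 1) * h] for k, name in enumerate(GATE_ORDER)}


# A (4*2)-by-3 weights matrix in which row r is filled with the value r.
weights = [[float(r)] * 3 for r in range(8)]
blocks = split_gate_blocks(weights, num_hidden_units=2)
print(blocks["forget"])  # [[2.0, 2.0, 2.0], [3.0, 3.0, 3.0]]
```

The same split applies to recurrentWeights and bias, because both stack their rows the same way.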
recurrentWeights — Recurrent weights
dlarray | numeric array
Recurrent weights, specified as a formatted dlarray, an unformatted dlarray, or a numeric array.
Specify recurrentWeights as a matrix of size 4*NumHiddenUnits-by-NumHiddenUnits, where NumHiddenUnits is the size of the "C" dimension of both C0 and H0.
If recurrentWeights is a formatted dlarray, it must contain a "C" dimension of size 4*NumHiddenUnits and a "U" dimension of size NumHiddenUnits.
bias — Bias
dlarray vector | numeric vector
Bias, specified as a formatted dlarray, an unformatted dlarray, or a numeric array.
Specify bias as a vector of length 4*NumHiddenUnits, where NumHiddenUnits is the size of the "C" dimension of both C0 and H0.
If bias is a formatted dlarray, the nonsingleton dimension must be labeled with "C".
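The size requirements for weights, recurrentWeights, and bias can be checked together before calling lstm. The following is a hypothetical plain-Python checker (`check_lstm_parameter_shapes` is not a MATLAB function) that compares shapes only and does not inspect dimension labels. The example values match the Apply LSTM Operation to Sequence Data example (10 input channels, 3 hidden units).

```python
def check_lstm_parameter_shapes(weights_shape, recurrent_shape, bias_length,
                                input_size, num_hidden_units):
    """Return a list of shape problems (empty when the shapes are consistent)."""
    gates = 4 * num_hidden_units
    problems = []
    if tuple(weights_shape) != (gates, input_size):
        problems.append(f"weights must be {gates}-by-{input_size}")
    if tuple(recurrent_shape) != (gates, num_hidden_units):
        problems.append(f"recurrentWeights must be {gates}-by-{num_hidden_units}")
    if bias_length != gates:
        problems.append(f"bias must have {gates} elements")
    return problems


# Shapes from the example: 10 features and 3 hidden units.
print(check_lstm_parameter_shapes((12, 10), (12, 3), 12,
                                  input_size=10, num_hidden_units=3))  # []
# Transposed recurrent weights are reported.
print(check_lstm_parameter_shapes((12, 10), (3, 12), 12,
                                  input_size=10, num_hidden_units=3))
# ['recurrentWeights must be 12-by-3']
```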
Name-Value Arguments
Specify optional pairs of arguments as Name1=Value1,...,NameN=ValueN, where Name is the argument name and Value is the corresponding value. Name-value arguments must appear after other arguments, but the order of the pairs does not matter.
Before R2021a, use commas to separate each name and value, and enclose Name in quotes.
Example: Y = lstm(X,H0,C0,weights,recurrentWeights,bias,DataFormat="CTB") applies the LSTM operation and specifies that the data has format "CTB" (channel, time, batch).
DataFormat — Description of data dimensions
character vector | string scalar
Description of the data dimensions, specified as a character vector or string scalar.
A data format is a string of characters, where each character describes the type of the corresponding data dimension.
The characters are:
"S" — Spatial
"C" — Channel
"B" — Batch
"T" — Time
"U" — Unspecified
For example, consider an array containing a batch of sequences where the first, second, and third dimensions correspond to channels, observations, and time steps, respectively. You can specify that this array has the format "CBT" (channel, batch, time).
You can specify multiple dimensions labeled "S" or "U". You can use the labels "C", "B", and "T" at most once each. The software ignores singleton trailing "U" dimensions after the second dimension.
If the input data is not a formatted dlarray object, then you must specify the DataFormat option.
For more information, see Deep Learning Data Formats.
Data Types: char | string
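The labeling rules above can be expressed as a small validator. This is a hypothetical plain-Python sketch, not the check that MATLAB performs; it covers only the allowed characters and the at-most-once rule for "C", "B", and "T", and it does not model the handling of trailing singleton "U" dimensions.

```python
VALID_LABELS = set("SCBTU")


def validate_data_format(fmt):
    """Raise ValueError if fmt breaks the labeling rules; otherwise return fmt."""
    for label in fmt:
        if label not in VALID_LABELS:
            raise ValueError(f"invalid dimension label {label!r}")
    for label in "CBT":
        if fmt.count(label) > 1:
            raise ValueError(f'"{label}" can appear at most once')
    return fmt


validate_data_format("CBT")    # channel, batch, time
validate_data_format("SSCBT")  # "S" may repeat
try:
    validate_data_format("CCT")
except ValueError as err:
    print(err)  # "C" can appear at most once
```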
StateActivationFunction — State activation function
"tanh" (default) | "softsign" | "relu"
Since R2024a
Activation function to update the cell and hidden state, specified as one of these values:
"tanh" — Use the hyperbolic tangent function (tanh).
"softsign" — Use the softsign function $\operatorname{softsign}(x) = \frac{x}{1 + |x|}$.
"relu" — Use the rectified linear unit (ReLU) function $f(x) = \max(0, x)$.
The software uses this option as the function in the calculations to update the cell and hidden state.
For more information, see the definition of Long Short-Term Memory Layer on the lstmLayer reference page.
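For reference, the three state activation options can be written as scalar functions: tanh, softsign $x/(1+|x|)$, and ReLU $\max(0, x)$. This is a plain-Python illustration; the dictionary keys mirror the option values, but `STATE_ACTIVATIONS` itself is a hypothetical name.

```python
import math


def softsign(x):
    """Softsign: x / (1 + |x|), bounded in (-1, 1)."""
    return x / (1.0 + abs(x))


def relu(x):
    """Rectified linear unit: max(0, x), unbounded above."""
    return max(0.0, x)


STATE_ACTIVATIONS = {"tanh": math.tanh, "softsign": softsign, "relu": relu}

for name, fn in STATE_ACTIVATIONS.items():
    print(name, [round(fn(v), 4) for v in (-2.0, 0.0, 2.0)])
# tanh [-0.964, 0.0, 0.964]
# softsign [-0.6667, 0.0, 0.6667]
# relu [0.0, 0.0, 2.0]
```

Unlike tanh and softsign, ReLU does not bound the cell and hidden states, which is one reason tanh remains the default.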
GateActivationFunction — Gate activation function
"sigmoid" (default) | "hard-sigmoid"
Since R2024a
Activation function to apply to the gates, specified as one of these values:
"sigmoid" — Use the sigmoid function, $\sigma(x) = (1 + e^{-x})^{-1}$.
"hard-sigmoid" — Use the hard sigmoid function, $\sigma(x) = \begin{cases} 0 & \text{if } x < -2.5 \\ 0.2x + 0.5 & \text{if } -2.5 \le x \le 2.5 \\ 1 & \text{if } x > 2.5 \end{cases}$
The software uses this option as the function in the calculations for the LSTM gates.
For more information, see the definition of Long Short-Term Memory Layer on the lstmLayer reference page.
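Likewise, the two gate activation options can be written as scalar functions. The hard sigmoid below is 0 for x < -2.5, 1 for x > 2.5, and 0.2x + 0.5 in between, which is our reading of the lstmLayer definition; this is a plain-Python illustration, and `GATE_ACTIVATIONS` is a hypothetical name.

```python
import math


def sigmoid(x):
    """Logistic sigmoid: 1 / (1 + exp(-x))."""
    return 1.0 / (1.0 + math.exp(-x))


def hard_sigmoid(x):
    """Piecewise-linear approximation of the sigmoid, clipped to [0, 1]."""
    if x < -2.5:
        return 0.0
    if x > 2.5:
        return 1.0
    return 0.2 * x + 0.5


GATE_ACTIVATIONS = {"sigmoid": sigmoid, "hard-sigmoid": hard_sigmoid}

for name, fn in GATE_ACTIVATIONS.items():
    print(name, [round(fn(v), 4) for v in (-3.0, 0.0, 1.0, 3.0)])
# sigmoid [0.0474, 0.5, 0.7311, 0.9526]
# hard-sigmoid [0.0, 0.5, 0.7, 1.0]
```

Both functions map into [0, 1], so each gate value acts as a fraction of the signal to let through; the hard sigmoid avoids the exponential at the cost of a coarser approximation.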
Output Arguments
Y — LSTM output
dlarray
LSTM output, returned as a dlarray. The output Y has the same underlying data type as the input X.
If the input data X is a formatted dlarray, Y has the same dimension format as X, except for any "S" dimensions. If the input data is not a formatted dlarray, Y is an unformatted dlarray with the same dimension order as the input data.
The size of the "C" dimension of Y is the same as the number of hidden units, specified by the size of the "C" dimension of H0 or C0.
hiddenState — Hidden state vector
dlarray | numeric array
Hidden state vector for each observation, returned as a dlarray or a numeric array with the same data type as H0.
If the input H0 is a formatted dlarray, then the output hiddenState is a formatted dlarray with the format "CB".
cellState — Cell state vector
dlarray | numeric array
Cell state vector for each observation, returned as a dlarray or a numeric array with the same data type as C0.
If the input C0 is a formatted dlarray, then the output cellState is a formatted dlarray with the format "CB".
Algorithms
Long Short-Term Memory
The LSTM operation allows a network to learn long-term dependencies between time steps in time series and sequence data. For more information, see the definition of Long Short-Term Memory Layer on the lstmLayer reference page.
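The computation at each time step can be sketched for a single observation in plain Python. The sketch assumes the standard LSTM formulation described for lstmLayer, with the four stacked blocks ordered as input gate $i$, forget gate $f$, cell candidate $g$, and output gate $o$ (an assumption about the row order). Writing $z_t = W x_t + R h_{t-1} + b$, the gates $i$, $f$, $o$ apply the gate activation (sigmoid by default) and $g$ applies the state activation (tanh by default) to their blocks of $z_t$; then $c_t = f_t \odot c_{t-1} + i_t \odot g_t$ and $h_t = o_t \odot \tanh(c_t)$. The output at each time step is $h_t$, and the final $h$ and $c$ play the roles of hiddenState and cellState. All function names are hypothetical; this is not the toolbox implementation.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector (list)."""
    return [sum(a * b for a, b in zip(row, vector)) for row in matrix]


def lstm_step(x, h, c, weights, recurrent_weights, bias,
              state_act=math.tanh, gate_act=sigmoid):
    """One time step for one observation; returns the new (h, c)."""
    n = len(h)  # number of hidden units
    # Pre-activations for all four blocks at once: W*x + R*h + b (length 4n).
    z = [wx + rh + b for wx, rh, b in
         zip(matvec(weights, x), matvec(recurrent_weights, h), bias)]
    i = [gate_act(v) for v in z[0:n]]           # input gate
    f = [gate_act(v) for v in z[n:2 * n]]       # forget gate
    g = [state_act(v) for v in z[2 * n:3 * n]]  # cell candidate
    o = [gate_act(v) for v in z[3 * n:4 * n]]   # output gate
    c_new = [fk * ck + ik * gk for fk, ck, ik, gk in zip(f, c, i, g)]
    h_new = [ok * state_act(ck) for ok, ck in zip(o, c_new)]
    return h_new, c_new


def lstm_sequence(xs, h0, c0, weights, recurrent_weights, bias):
    """Run over the time steps xs; returns (outputs, final h, final c)."""
    h, c = list(h0), list(c0)
    outputs = []
    for x in xs:
        h, c = lstm_step(x, h, c, weights, recurrent_weights, bias)
        outputs.append(h)  # the output at each step is the hidden state
    return outputs, h, c


# One hidden unit, one input channel, all-zero parameters.
zeros = [[0.0], [0.0], [0.0], [0.0]]
ys, h_final, c_final = lstm_sequence([[1.0], [1.0]], [0.0], [1.0],
                                     zeros, zeros, [0.0] * 4)
print(c_final)  # [0.25]
```

With all-zero parameters every gate evaluates to 0.5 and the cell candidate to tanh(0) = 0, so the cell state halves at each step, which makes the result easy to check by hand.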
Deep Learning Array Formats
Most deep learning networks and functions operate on different dimensions of the input data in different ways.
For example, an LSTM operation iterates over the time dimension of the input data, and a batch normalization operation normalizes over the batch dimension of the input data.
To provide input data with labeled dimensions or input data with additional layout information, you can use data formats.
A data format is a string of characters, where each character describes the type of the corresponding data dimension.
The characters are:
"S" — Spatial
"C" — Channel
"B" — Batch
"T" — Time
"U" — Unspecified
For example, consider an array containing a batch of sequences where the first, second, and third dimensions correspond to channels, observations, and time steps, respectively. You can specify that this array has the format "CBT" (channel, batch, time).
To create formatted input data, create a dlarray object and specify the format using the second argument.
To provide additional layout information with unformatted data, specify the format using the DataFormat argument.
For more information, see Deep Learning Data Formats.
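The dlarray documentation describes formatted arrays as keeping their dimensions in "S", "C", "B", "T", "U" order. The hypothetical helper below computes the permutation that rearranges an arbitrary format into that order. For example, it shows how data described as "CTB" (as in the DataFormat example) relates to the "CBT" layout. It is a plain-Python sketch of the bookkeeping only.

```python
CANONICAL_ORDER = "SCBTU"


def to_canonical_order(fmt):
    """Return (permutation, new_fmt) that sorts fmt into "S","C","B","T","U" order.

    permutation[k] is the index of the source dimension that ends up in
    position k; repeated "S" or "U" labels keep their relative order
    because Python's sort is stable.
    """
    perm = sorted(range(len(fmt)), key=lambda k: CANONICAL_ORDER.index(fmt[k]))
    return tuple(perm), "".join(fmt[k] for k in perm)


print(to_canonical_order("CTB"))    # ((0, 2, 1), 'CBT')
print(to_canonical_order("TCSSB"))  # ((2, 3, 1, 4, 0), 'SSCBT')
```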
Extended Capabilities
GPU Arrays
Accelerate code by running on a graphics processing unit (GPU) using Parallel Computing Toolbox™.
The lstm function supports GPU array input with these usage notes and limitations:
When at least one of the following input arguments is a gpuArray or a dlarray with underlying data of type gpuArray, this function runs on the GPU:
X
H0
C0
weights
recurrentWeights
bias
For more information, see Run MATLAB Functions on a GPU (Parallel Computing Toolbox).
Version History
Introduced in R2019b
R2024a: Specify state and gate activation functions
Specify the state and gate activation functions using the StateActivationFunction and GateActivationFunction name-value arguments, respectively.
See Also
dlarray | fullyconnect | softmax | dlgradient | dlfeval | gru | attention