Create Simple Sequence Classification Network

This example shows how to create a simple long short-term memory (LSTM) classification network.

To train a deep neural network to classify sequence data, you can use an LSTM network. An LSTM network is a type of recurrent neural network (RNN) that learns long-term dependencies between time steps of sequence data.
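To make "long-term dependencies" more concrete, here is a minimal plain-MATLAB sketch of the standard LSTM recurrence applied to one 12-by-T sequence. It is an illustration, not how lstmLayer is implemented internally. The random weights, the small hidden size, and the gate ordering used here (input, forget, cell candidate, output) are assumptions. The cell state c carries information from step to step, and the forget gate controls how much of it survives each step.

```matlab
% Illustrative LSTM forward pass over one sequence (no toolbox required).
% Weights are random and the sizes are arbitrary; this is a sketch only.
rng(0);                           % reproducible random weights
numFeatures = 12;                 % input size, as in the Japanese Vowels data
numHiddenUnits = 4;               % small hidden size for illustration
T = 20;                           % number of time steps
X = randn(numFeatures, T);        % toy sequence: features x time steps

H = numHiddenUnits;
W = 0.1*randn(4*H, numFeatures);  % input weights for the four gates
R = 0.1*randn(4*H, H);            % recurrent weights for the four gates
b = zeros(4*H, 1);                % biases for the four gates

sigmoid = @(z) 1 ./ (1 + exp(-z));
h = zeros(H, 1);                  % hidden state
c = zeros(H, 1);                  % cell state (long-term memory)

for t = 1:T
    z = W*X(:,t) + R*h + b;             % pre-activations for all gates at once
    inGate     = sigmoid(z(1:H));       % how much new information to write
    forgetGate = sigmoid(z(H+1:2*H));   % how much old cell state to keep
    candidate  = tanh(z(2*H+1:3*H));    % proposed new cell content
    outGate    = sigmoid(z(3*H+1:4*H)); % how much cell state to expose
    c = forgetGate.*c + inGate.*candidate;
    h = outGate.*tanh(c);
end

hLast = h;   % hidden state after the final time step
```

For sequence-to-label classification, a state like hLast (the output at the last time step) is what gets passed on to the classifier layers.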

The example demonstrates how to:

  • Load sequence data.

  • Define the network architecture.

  • Specify training options.

  • Train the network.

  • Predict the labels of new data and calculate the classification accuracy.

Load Data

Load the Japanese Vowels data set as described in [1] and [2]. The predictors are cell arrays containing sequences of varying length with a feature dimension of 12. The labels are categorical vectors with labels 1, 2, ..., 9.

[XTrain,YTrain] = japaneseVowelsTrainData;
[XValidation,YValidation] = japaneseVowelsTestData;

View the sizes of the first few training sequences. The sequences are matrices with 12 rows (one row for each feature) and a varying number of columns (one column for each time step).

XTrain(1:5)

ans=5×1 cell
    {12x20 double}
    {12x26 double}
    {12x22 double}
    {12x20 double}
    {12x21 double}
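The row and column counts can also be computed programmatically. The sketch below builds a small hypothetical stand-in for XTrain (random values, with the same lengths as the five sequences displayed) so that it runs without the data set; with the real data, apply the same cellfun calls to XTrain.

```matlab
% Hypothetical stand-in for XTrain: a cell array of 12-by-T sequences
seqLengths = [20 26 22 20 21];
X = cell(numel(seqLengths), 1);
for n = 1:numel(seqLengths)
    X{n} = randn(12, seqLengths(n));   % 12 features x T time steps
end

% Feature dimension = number of rows (the same for every sequence)
numFeatures = unique(cellfun(@(x) size(x,1), X));
% Sequence length = number of columns (varies from sequence to sequence)
sequenceLengths = cellfun(@(x) size(x,2), X);
```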

Define Network Architecture

Define the LSTM network architecture. Specify the number of features in the input layer and the number of classes in the fully connected layer.

numFeatures = 12;
numHiddenUnits = 100;
numClasses = 9;

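layers = [ ...
    sequenceInputLayer(numFeatures)                 % 12 features per time step
    lstmLayer(numHiddenUnits,'OutputMode','last')   % output only the last time step
    fullyConnectedLayer(numClasses)                 % one output per class
    softmaxLayer
    classificationLayer];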
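As a sanity check on the architecture, the number of learnable parameters can be worked out by hand. This sketch assumes the standard lstmLayer parameterization: input weights of size 4*NumHiddenUnits-by-InputSize, recurrent weights of size 4*NumHiddenUnits-by-NumHiddenUnits, and a bias of length 4*NumHiddenUnits (one block per gate). In this configuration the input, softmax, and classification layers have no learnable parameters. To compare against the toolbox, analyzeNetwork(layers) lists the learnables of each layer.

```matlab
numFeatures = 12;
numHiddenUnits = 100;
numClasses = 9;

% LSTM layer: four gates, each with input weights, recurrent weights, and a bias
lstmParams = 4*numHiddenUnits*numFeatures ...     % InputWeights (400-by-12)
           + 4*numHiddenUnits*numHiddenUnits ...  % RecurrentWeights (400-by-100)
           + 4*numHiddenUnits;                    % Bias (400-by-1)

% Fully connected layer: weight matrix plus bias
fcParams = numClasses*numHiddenUnits + numClasses;  % Weights (9-by-100), Bias (9-by-1)

totalParams = lstmParams + fcParams;
```

With these sizes, lstmParams is 45200, fcParams is 909, and totalParams is 46109, so the recurrent weights account for most of the model.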

Train Network

Specify the training options and train the network.

Because the mini-batches are small with short sequences, the CPU is better suited for training. Set 'ExecutionEnvironment' to 'cpu'. To train on a GPU, if available, set 'ExecutionEnvironment' to 'auto' (the default value).
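Sequences in a mini-batch generally have different lengths, so trainNetwork and classify pad each mini-batch to a common length. By default this is the length of the longest sequence in that mini-batch, with padding added on the right. The toy calculation below uses hypothetical sequence lengths to count how many padded time steps two different mini-batch sizes produce. Because the grouping changes the padding, this is one reason the test step reuses the training mini-batch size.

```matlab
% Hypothetical sequence lengths (number of time steps per sequence)
lengths = [20 26 22 20 21 9 30 15];
sizes = [2 4];                      % two hypothetical mini-batch sizes
padding = zeros(1, numel(sizes));   % total padded time steps per mini-batch size

for s = 1:numel(sizes)
    mbs = sizes(s);
    for k = 1:mbs:numel(lengths)
        % Sequences that fall into this mini-batch
        batch = lengths(k:min(k+mbs-1, numel(lengths)));
        % Each sequence is padded up to the longest one in the batch
        padding(s) = padding(s) + sum(max(batch) - batch);
    end
end
```

Here padding is [35 61]: grouping the same sequences four at a time adds more padded steps than grouping them two at a time.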

miniBatchSize = 27;

options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'MaxEpochs',100, ...
    'MiniBatchSize',miniBatchSize, ...
    'ValidationData',{XValidation,YValidation}, ...
    'GradientThreshold',2, ...
    'Shuffle','every-epoch', ...
    'Verbose',false);

net = trainNetwork(XTrain,YTrain,layers,options);

For more information about specifying training options, see Set Up Parameters and Train Convolutional Neural Network.
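The 'GradientThreshold',2 option clips gradients during training. With the default 'GradientThresholdMethod' ('l2norm'), if the L2 norm of a learnable parameter's gradient exceeds the threshold, the gradient is rescaled so that its norm equals the threshold. This helps keep recurrent networks from taking very large update steps. A minimal sketch of that rule on a hypothetical gradient vector:

```matlab
gradientThreshold = 2;
g = [3; 4];                        % hypothetical gradient with L2 norm 5

gNorm = norm(g);                   % L2 norm of the gradient
if gNorm > gradientThreshold
    % Rescale so the norm equals the threshold; the direction is unchanged
    g = g * (gradientThreshold / gNorm);
end
```

After clipping, g is [1.2; 1.6], which points in the same direction as before but has norm 2.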

Test Network

Classify the test data and calculate the classification accuracy. Specify the same mini-batch size used for training.

YPred = classify(net,XValidation,'MiniBatchSize',miniBatchSize);
acc = mean(YPred == YValidation)
acc = 0.9405
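The expression mean(YPred == YValidation) compares predicted and true labels element-wise and averages the resulting logical vector, which gives the fraction of correctly classified sequences. A toy check with hypothetical labels:

```matlab
% Hypothetical true and predicted labels sharing the categories 1, 2, 3
YTrue = categorical([1 2 3 1 2]', 1:3);
YPred = categorical([1 2 1 1 2]', 1:3);

isCorrect = (YPred == YTrue);   % logical vector, true where labels match
acc = mean(isCorrect);          % 4 of 5 correct
```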

For next steps, you can try improving the accuracy by using bidirectional LSTM (BiLSTM) layers or by creating a deeper network. For more information, see Long Short-Term Memory Networks.
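As a possible starting point for those experiments, the sketch below defines two alternative layer arrays (Deep Learning Toolbox required): one replaces lstmLayer with bilstmLayer, and one stacks two LSTM layers. The layer sizes simply reuse the values from this example and are not tuned. In a stacked network, every LSTM layer except the last should use 'OutputMode','sequence' so that the next layer receives a full sequence; only the last one uses 'last' for sequence-to-label classification.

```matlab
numFeatures = 12;
numHiddenUnits = 100;
numClasses = 9;

% Bidirectional variant: the BiLSTM layer reads each sequence forward and backward
layersBiLSTM = [ ...
    sequenceInputLayer(numFeatures)
    bilstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

% Deeper variant: two stacked LSTM layers
layersDeep = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,'OutputMode','sequence')  % passes the full sequence on
    lstmLayer(numHiddenUnits,'OutputMode','last')      % outputs the last time step only
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];
```

Either array can be passed to trainNetwork in place of layers, using the same training options.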

For an example showing how to use convolutional networks to classify sequence data, see Speech Command Recognition Using Deep Learning.

References

  1. M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.

  2. UCI Machine Learning Repository: Japanese Vowels Dataset.
