Classify error: requires 3 arguments
Hello all, I have trained an LSTM model to classify EMG signals (one-dimensional time series) and produce a class prediction. Now, when I try to test the trained LSTM on a test signal, classify throws an error saying it requires 3 arguments. No matter how I change the shape of the test signal, nothing helps. predict also produces errors. Could you please help?
% Training code:
% LSTM-1D classification using raw EMG signal
%
% Data path (named dataPath so it does not shadow the built-in path function)
dataPath = '/home/ubuntu/Desktop/EMG data analysis/EMG signal Matlab';
% Parameters
numHiddenUnits = 120;
numClasses = 8;
numChannels = 1;
% Now prepare the training data and labels for LSTM training
% Assuming sorted_emg_data is your sorted array with data and labels
% Extract the data for training
XTrain = cellfun(@(c) c.signal, sorted_emg_data(:, 1), 'UniformOutput', false);
% Extract the labels for training
TTrain = sorted_emg_data(:, 2);
% Convert the labels to a categorical array
TTrain = categorical(TTrain);
% Now XTrain contains all the EMG signals and TTrain contains the corresponding labels
% Now train the LSTM model
% Now define your layers with the correct number of output classes
layers = [ ...
sequenceInputLayer(numChannels)
bilstmLayer(numHiddenUnits, 'OutputMode', 'last')
fullyConnectedLayer(numClasses) % Make sure this matches the number of unique classes in TTrain
softmaxLayer
classificationLayer];
% Define your training options (make sure MiniBatchSize is appropriate for your dataset)
options = trainingOptions('adam', ...
'MaxEpochs', 50, ...
'MiniBatchSize', 3, ... % Adjust based on your hardware capabilities
'InitialLearnRate', 0.01, ...
'GradientThreshold', 1, ...
'Verbose', 0, ...
'Plots', 'training-progress');
% Train the network
net = trainNetwork(XTrain, TTrain, layers, options);
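With `sequenceInputLayer(numChannels)` and `numChannels = 1`, each cell of `XTrain` has to be a numChannels-by-numTimeSteps array, that is, a 1-by-N row vector, and a test signal passed to the trained network must have the same orientation. The sketch below assumes, as the comments in the training code suggest, that each row of `sorted_emg_data` holds a struct with a `signal` field plus a label; the three-row stand-in, its random data, and the label names `'rest'` and `'fist'` are hypothetical. It forces every signal into row form with `reshape(..., 1, [])` and checks the result.

```matlab
% Hypothetical stand-in for sorted_emg_data: column 1 holds structs with a
% "signal" field, column 2 holds labels. Data are random, labels are made up.
sorted_emg_data = { ...
    struct('signal', rand(1, 200)), 'rest'; ...
    struct('signal', rand(300, 1)), 'fist'; ...   % deliberately a column vector
    struct('signal', rand(1, 250)), 'rest'};

% Force every signal to 1-by-numTimeSteps (numChannels = 1)
XTrain = cellfun(@(c) reshape(c.signal, 1, []), sorted_emg_data(:, 1), ...
    'UniformOutput', false);

% Labels as a categorical column, as in the training code
TTrain = categorical(sorted_emg_data(:, 2));

% Each sequence should now have exactly one row (one channel)
numRows = cellfun(@(x) size(x, 1), XTrain);
disp(numRows')
```

Because `reshape(x, 1, [])` leaves a row vector unchanged, applying it to signals that are already rows is harmless.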
% Testing code
%
% Loading the network
net = load ("lstm_trained_model.mat")
% Loading data
test = load ('emg_signal_3.mat')
net.layers
length (test)
length (signal) % signal directly loaded
% Classification
pred = classify(net, test);
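A likely cause of the error, offered as a hypothesis: when `load` is called with an output argument, it returns a struct whose fields are the variables stored in the MAT-file, not the variables themselves. In the testing code, `net` and `test` are therefore both structs. Since neither is a network object, MATLAB most likely does not dispatch to the Deep Learning Toolbox `classify` method and instead calls the Statistics and Machine Learning Toolbox function `classify(sample, training, group)` (discriminant analysis), which needs at least three inputs. `predict` likewise has no method for a struct input, which would also explain its errors. The sketch below demonstrates the struct behavior with a temporary file and random stand-in data, then extracts the array and reshapes it to 1-by-numTimeSteps; the variable name `signal` is an assumption based on the comment in the testing code.

```matlab
% Demonstrates that S = load(file) returns a struct of saved variables.
% "signal" is an assumed variable name; the data are random stand-ins.
signal = rand(500, 1);                 % stand-in EMG recording (column vector)
tmpFile = [tempname, '.mat'];
save(tmpFile, 'signal');

test = load(tmpFile);                  % struct with one field per saved variable
fprintf('class of test: %s\n', class(test));
disp(fieldnames(test));                % lists the stored variable names

X = test.signal;                       % extract the numeric array from the struct
X = reshape(X, 1, []);                 % 1-by-numTimeSteps to match sequenceInputLayer(1)
fprintf('size of X: %d x %d\n', size(X));

delete(tmpFile);                       % clean up the temporary file
```

Applied to the testing code, and assuming the network was saved with `save("lstm_trained_model.mat", "net")`, the corresponding lines would be `S = load("lstm_trained_model.mat"); net = S.net;` followed by `pred = classify(net, {X});`, where `{X}` is a one-element cell array holding the reshaped test sequence. `whos -file lstm_trained_model.mat` lists the variable names actually stored in the file, and the layer array of a loaded network is accessed as `net.Layers` (capital L).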
Answers (1)
Cris LaPierre
on 2 Jan 2024
I cannot duplicate your error. I used this example to create a sample data set, trained on that data using your code, and then tested the model using the code in the PDFs. My conclusion is that there is nothing wrong with your code. Without more details, I don't know what more we can do to help.
Here are the results I obtained when running the model on test data using the code from your PDFs.
