Erste Schritte mit Deep Network Designer
Dieses Beispiel zeigt, wie ein einfaches rekurrentes neuronales Netz für die Deep-Learning-Sequenzklassifizierung mit Deep Network Designer erstellt wird.
Zum Trainieren eines tiefen neuronalen Netzes für die Klassifizierung von Sequenzdaten können Sie ein LSTM-Netz verwenden. Mit einem LSTM-Netz können Sie Sequenzdaten in ein Netz eingeben und Vorhersagen auf der Grundlage der einzelnen Zeitschritte der Sequenzdaten treffen.
Laden von Sequenzdaten
Laden Sie die Beispieldaten aus WaveformData
. Um auf diese Daten zuzugreifen, öffnen Sie das Beispiel als Live-Skript. Diese Daten enthalten Wellenformen aus vier Klassen: Sinus, Rechteck, Dreieck und Sägezahn. In diesem Beispiel wird ein neuronales LSTM-Netz trainiert, um die Art der Wellenform anhand von Zeitreihendaten zu erkennen. Jede Sequenz hat drei Kanäle und variiert in der Länge.
load WaveformData
Visualisieren Sie einige Sequenzen in einem Diagramm.
numChannels = size(data{1},2); classNames = categories(labels); figure tiledlayout(2,2) for i = 1:4 nexttile stackedplot(data{i},DisplayLabels="Channel "+string(1:numChannels)) xlabel("Time Step") title("Class: " + string(labels(i))) end
Unterteilen Sie die Daten in einen Trainingssatz, der 80 % der Daten enthält, und einen Validierungs- und einen Testsatz, die jeweils 10 % der Daten enthalten. Um die Daten zu partitionieren, verwenden Sie die trainingPartitions
-Funktion. Um auf diese Funktion zuzugreifen, öffnen Sie das Beispiel als Live-Skript.
numObservations = numel(data); [idxTrain,idxValidation,idxTest] = trainingPartitions(numObservations,[0.8 0.1 0.1]); XTrain = data(idxTrain); TTrain = labels(idxTrain); XValidation = data(idxValidation); TValidation = labels(idxValidation); XTest = data(idxTest); TTest = labels(idxTest);
Definieren der Netzarchitektur
Verwenden Sie zum Erstellen des Netzes die App Deep Network Designer.
deepNetworkDesigner
Um ein Sequenznetz zu erstellen, fahren Sie im Abschnitt Sequenznetze über Sequence to Label (Zu bezeichnende Sequenz) und klicken Sie auf Open (Öffnen). Auf diese Weise wird ein vorgefertigtes Netz geöffnet, das für Sequenz-zu-Label-Klassifizierungsprobleme geeignet ist.
Deep Network Designer zeigt das vorgefertigte Netz an.
Sie können dieses Sequenznetz leicht für den Wellenform-Datensatz anpassen.
Wählen Sie die Sequenz-Eingabeschicht input
und setzen Sie InputSize auf 3, um die Anzahl der Kanäle anzupassen.
Wählen Sie die vollständig verbundene Schicht fc
und setzen Sie OutputSize auf 4, was der Anzahl der Klassen entspricht.
Klicken Sie auf Analyze (Analysieren), um zu prüfen, ob das Netz für das Training bereit ist. Der Deep Learning Network Analyzer meldet keine Fehler oder Warnungen, sodass das Netz für das Training bereit ist. Um das Netz zu exportieren, klicken Sie auf Export (Exportieren). Die App speichert das Netz in der Variablen net_1
.
Festlegen von Trainingsoptionen
Legen Sie die Trainingsoptionen fest. Die Auswahl aus diesen Optionen erfordert eine empirische Analyse.
options = trainingOptions("adam", ... MaxEpochs=500, ... InitialLearnRate=0.0005, ... GradientThreshold=1, ... ValidationData={XValidation,TValidation}, ... Shuffle = "every-epoch", ... Plots="training-progress", ... Metrics="accuracy", ... Verbose=false);
Trainieren von neuronalen Netzen
Trainieren Sie das neuronale Netz mit der Funktion trainnet
. Da das Ziel die Klassifizierung ist, geben Sie den Querentropieverlust an.
net = trainnet(XTrain,TTrain,net_1,"crossentropy",options);
Testen von neuronalen Netzen
Um das neuronale Netz zu testen, klassifizieren Sie die Testdaten und berechnen die Klassifizierungsgenauigkeit.
Machen Sie Vorhersagen mit der Funktion minibatchpredict
und wandeln Sie die Ergebnisse mit der Funktion scores2label
in Bezeichnungen um.
scores = minibatchpredict(net,XTest); YTest = scores2label(scores,classNames);
Berechnen Sie die Klassifizierungsgenauigkeit. Die Genauigkeit ist der Prozentsatz der richtig vorhergesagten Bezeichnungen.
acc = mean(YTest == TTest)
acc = 0.8300
Visualisieren Sie die Vorhersagen in einer Konfusionstabelle.
figure confusionchart(TTest,YTest)