Hauptinhalt

Klassifizieren einer Zeitreihe mithilfe von Wavelet-Analyse und Deep Learning

Dieses Beispiel veranschaulicht, wie die Signale eines menschlichen Elektrokardiogramms (EKG) mithilfe der Continuous-Wavelet-Transformation (CWT) und einem tiefen Convolutional Neural Network (CNN) klassifiziert werden können.

Das Training eines tiefen CNNs von Grund auf ist rechnerisch aufwendig und benötigt eine große Menge Trainingsdaten. In zahlreichen Anwendungen ist keine ausreichende Menge an Trainingsdaten verfügbar und eine Synthese neuer, realistischer Trainingsbeispiele nicht machbar. In diesen Fällen ist es hilfreich, bestehende neuronale Netze zu nutzen, die anhand von großen Datensätzen für konzeptionell ähnliche Aufgaben trainiert wurden. Diese Nutzung bestehender neuronaler Netze wird als Transfer Learning bezeichnet. In diesem Beispiel adaptieren wir zwei für Bilderkennung vortrainierte tiefe CNNs, GoogLeNet und SqueezeNet, um EKG-Wellenformen basierend auf einer Zeit-Frequenz-Darstellung zu klassifizieren.

Bei GoogLeNet und SqueezeNet handelt es sich um tiefe CNNs, die ursprünglich für die Klassifizierung von Bildern in 1000 Kategorien entwickelt wurden. Wir verwenden die Netzarchitektur des CNN wieder, um EKG-Signale auf Basis von Bildern aus der CWT der Zeitreihendaten zu klassifizieren. Die in diesem Beispiel verwendeten Daten sind öffentlich in PhysioNet verfügbar.

Beschreibung der Daten

In diesem Beispiel verwenden Sie EKG-Daten, die von drei Personengruppen erhoben wurden: Personen mit Herzrhythmusstörungen (ARR), Personen mit Stauungsinsuffizienz (CHF) und Personen mit normalem Sinusrhythmus (NSR). Insgesamt verwenden Sie 162 EKG-Aufzeichnungen aus drei PhysioNet-Datenbanken: MIT-BIH Arrhythmia Database [3][7], MIT-BIH Normal Sinus Rhythm Database [3], und The BIDMC Congestive Heart Failure Database [1][3]. Die Daten umfassen 96 Aufzeichnungen von Personen mit Herzrhythmusstörungen, 30 Aufzeichnungen von Personen mit Stauungsinsuffizienz und 36 Aufzeichnungen von Personen mit normalem Sinusrhythmus. Das Ziel besteht darin, einen Classifier zu trainieren, um zwischen ARR, CHF und NSR zu unterscheiden.

Herunterladen der Daten

Der erste Schritt besteht darin, die Daten aus dem GitHub®-Repository herunterzuladen. Um die Daten von der Webseite herunterzuladen, klicken Sie auf Code und wählen Sie Download ZIP aus. Speichern Sie die Datei physionet_ECG_data-main.zip in einem Ordner, auf den Sie Schreibzugriff haben. Bei der Anleitung für dieses Beispiel wird davon ausgegangen, dass Sie die Datei in Ihr temporäres Verzeichnis tempdir in MATLAB® heruntergeladen haben. Wenn Sie die Daten in einen anderen Ordner als tempdir herunterladen möchten, ändern Sie die folgende Anleitung zum Entpacken und Laden der Daten entsprechend.

Haben Sie die Daten von GitHub heruntergeladen, entpacken Sie die Datei in Ihrem temporären Verzeichnis.

unzip(fullfile(tempdir,"physionet_ECG_data-main.zip"),tempdir)

Durch das Entpacken wird der Ordner physionet-ECG_data-main in Ihrem temporären Verzeichnis erstellt. Dieser Ordner enthält die Textdatei README.md und ECGData.zip. Die Datei ECGData.zip enthält

  • ECGData.mat

  • Modified_physionet_data.txt

  • License.txt

ECGData.mat enthält die in diesem Beispiel verwendeten Daten. Die Textdatei Modified_physionet_data.txt ist gemäß der Kopierrichtlinien von PhysioNet erforderlich und umfasst die Quellenzuordnungen der Daten sowie eine Beschreibung der Vorverarbeitungsschritte, die bei den EKG-Aufzeichnungen durchgeführt wurden.

Entpacken Sie ECGData.zip in physionet-ECG_data-main. Laden Sie die Datendatei in Ihren MATLAB-Workspace.

unzip(fullfile(tempdir,"physionet_ECG_data-main","ECGData.zip"), ...
    fullfile(tempdir,"physionet_ECG_data-main"))
load(fullfile(tempdir,"physionet_ECG_data-main","ECGData.mat"))

ECGData ist ein Strukturarray mit zwei Feldern: Data und Labels. Das Data-Feld ist eine 162-mal-65536-Matrix, bei der jeder Zeile eine mit 128 Hertz abgetastete EKG-Aufzeichnung darstellt. Labels ist ein 162-mal-1-Zellenarray mit diagnostischen Kennzeichnungen, eine pro Zeile von Data. Die drei Diagnosekategorien sind: 'ARR', 'CHF' und 'NSR'.

Um die vorverarbeiteten Daten jeder Kategorie zu speichern, erstellen Sie zunächst in tempdir ein EKG-Datenverzeichnis dataDir. Erstellen Sie daraufhin in 'data' drei Unterverzeichnisse, die nach den EKG-Kategorien benannt sind. Hierfür können Sie die Hilfsfunktion helperCreateECGDirectories verwenden. helperCreateECGDirectories akzeptiert als Eingangsargumente ECGData, den Namen eines EKG-Datenverzeichnis und den Namen eines übergeordneten Verzeichnisses. Sie können tempdir durch ein anderes Verzeichnis ersetzen, auf das Sie Schreibzugriff haben. Den Quellcode für diese Hilfsfunktion finden Sie im Abschnitt Unterstützungsfunktionen am Ende dieses Beispiels.

parentDir = tempdir;
dataDir = "data";
helperCreateECGDirectories(ECGData,parentDir,dataDir)

Stellen Sie ein Beispiel jeder EKG-Kategorie grafisch dar. Hierfür können Sie die Hilfsfunktion helperPlotReps verwenden. helperPlotReps akzeptiert ECGData als Eingabe. Den Quellcode für diese Hilfsfunktion finden Sie im Abschnitt Unterstützungsfunktionen am Ende dieses Beispiels.

helperPlotReps(ECGData)

Erstellen von Zeit-Frequenz-Darstellungen

Haben Sie die Ordner angelegt, erstellen Sie Zeit-Frequenz-Darstellungen der EKG-Signale. Diese Darstellungen werden als Skalogramme bezeichnet. Ein Skalogramm ist der Absolutwert der CWT-Koeffizienten eines Signals.

Um die Skalogramme zu erstellen, berechnen Sie eine CWT-Filterbank vor. Die Vorberechnung der CWT-Filterbank ist die bevorzugte Methode, wenn die CWT vieler Signale mit denselben Parametern errechnet werden soll.

Untersuchen Sie eines dieser Elemente, bevor Sie die Skalogramme generieren. Erstellen Sie eine CWT-Filterbank mithilfe von cwtfilterbank (Wavelet Toolbox) für ein Signal mit 1000 Samples. Verwenden Sie die Filterbank, um die CWT der ersten 1000 Samples des Signals zu errechnen und das Skalogramm aus den Koeffizienten zu erstellen.

Fs = 128;
fb = cwtfilterbank(SignalLength=1000, ...
    SamplingFrequency=Fs, ...
    VoicesPerOctave=12);
sig = ECGData.Data(1,1:1000);
[cfs,frq] = wt(fb,sig);
t = (0:999)/Fs;
figure
pcolor(t,frq,abs(cfs))
set(gca,"yscale","log")
shading interp
axis tight
title("Scalogram")
xlabel("Time (s)")
ylabel("Frequency (Hz)")

Verwenden Sie die Hilfsfunktion helperCreateRGBfromTF, um die Skalogramme als RGB-Bilder zu erstellen und im entsprechenden Unterverzeichnis in dataDir zu speichern. Den Quellcode für diese Hilfsfunktion finden Sie im Abschnitt Unterstützungsfunktionen am Ende dieses Beispiels. Zur Kompatibilität mit der GoogLeNet-Architektur ist jedes RGB-Bild ein Array mit der Größe 224-mal-224-mal-3.

helperCreateRGBfromTF(ECGData,parentDir,dataDir)

Unterteilen in Trainings- und Validierungsdaten

Laden Sie die Skalogrammbilder als Bild-Datastore. Die Funktion imageDatastore benennt die Bilder automatisch basierend auf den Ordnernamen und speichert die Daten als ImageDatastore-Objekt. Mit einem Bild-Datastore können Sie große Sammlungen von Bilddaten speichern, einschließlich Daten, die nicht in den Speicher passen, und beim Training eines CNN effizient Bildstapel lesen.

allImages = imageDatastore(fullfile(parentDir,dataDir), ...
    "IncludeSubfolders",true, ...
    "LabelSource","foldernames");

Teilen Sie die Bilder zufällig auf zwei Gruppen auf, eine für Training und eine für Validierung. Verwenden Sie 80 % der Bilder für das Training, den Rest für die Validierung.

[imgsTrain,imgsValidation] = splitEachLabel(allImages,0.8,"randomized");
disp("Number of training images: "+num2str(numel(imgsTrain.Files)))
Number of training images: 130
disp("Number of validation images: "+num2str(numel(imgsValidation.Files)))
Number of validation images: 32

GoogLeNet

Laden

Laden Sie das vortrainierte neuronale Netz GoogLeNet. Wenn das Supportpaket Deep Learning Toolbox™ Model for GoogLeNet Network nicht installiert ist, finden Sie in der Software einen Link zum erforderlichen Supportpaket im Add-On Explorer. Um das Supportpaket zu installieren, klicken Sie auf den Link und daraufhin auf Install (Installieren).

net = imagePretrainedNetwork("googlenet");

Extrahieren Sie das Schichtdiagramm aus dem Netz und zeigen Sie es an.

numberOfLayers = numel(net.Layers);
figure("Units","normalized","Position",[0.1 0.1 0.8 0.8])
plot(net)
title("GoogLeNet Layer Graph: "+num2str(numberOfLayers)+" Layers")

Inspizieren Sie das erste Elemen der Eigenschaft Layers des Netzes. Bestätigen Sie, dass GoogLeNet RGB-Bilder der Größe 224-mal-224-mal-3 benötigt.

net.Layers(1)
ans = 
  ImageInputLayer with properties:

                      Name: 'data'
                 InputSize: [224 224 3]
        SplitComplexInputs: 0

   Hyperparameters
          DataAugmentation: 'none'
             Normalization: 'zerocenter'
    NormalizationDimension: 'auto'
                      Mean: [224×224×3 single]

Modifizieren der Netzparameter von GoogLeNet

Jede Schicht in der Netzarchitektur kann als Filter angesehen werden. Die früheren Schichten identifizieren häufigere Merkmale der Bilder, wie Flächen, Ränder und Farben. Die darauffolgenden Schichten konzentrieren sich auf spezifischere Merkmale, um Kategorien zu differenzieren. GoogLeNet ist auf das Klassifizieren von Bildern in 1000 Objektkategorien vortrainiert. Für unser EKG-Klassifizierungsproblem müssen Sie GoogLeNet neu trainieren.

Inspizieren Sie die letzten vier Schichten des Netzes.

net.Layers(end-3:end)
ans = 
  4×1 Layer array with layers:

     1   'pool5-7x7_s1'        2-D Global Average Pooling   2-D global average pooling
     2   'pool5-drop_7x7_s1'   Dropout                      40% dropout
     3   'loss3-classifier'    Fully Connected              1000 fully connected layer
     4   'prob'                Softmax                      softmax

Um eine Überanpassung zu vermeiden, wird eine Dropout-Schicht verwendet. Eine Dropout-Schicht setzt mit einer bestimmten Wahrscheinlichkeit zufällige Eingangselemente auf null. Weitere Informationen finden Sie unter dropoutLayer. Die Standardwahrscheinlichkeit ist 0,5. Ersetzen Sie die letzte Dropout-Schicht im Netz, pool5-drop_7x7_s1, durch eine Dropout-Schicht mit der Wahrscheinlichkeit 0,6.

newDropoutLayer = dropoutLayer(0.6,"Name","new_Dropout");
net = replaceLayer(net,"pool5-drop_7x7_s1",newDropoutLayer);

Die Faltungsschichten des Netzes extrahieren Bildmerkmale. Die letzte lernbare Schicht loss3-classifier in GoogLeNet enthält Informationen darüber, wie die vom Netz extrahierten Merkmale zu Klassenwahrscheinlichkeiten kombiniert werden. Um GoogLeNet zur Klassifizierung der RGB-Bilder neu zu trainieren, ersetzen Sie diese Schicht durch eine neue, an die Daten angepasste Schicht.

Ersetzen Sie die vollständig verknüpfte Schicht loss3-classifier durch eine neue vollständig verknüpfte Schicht, deren Anzahl Filter der Anzahl Klassen entspricht. Um bei den neuen Schichten schneller zu lernen als bei den transferierten Schichten, erhöhen Sie die Lerngeschwindigkeit-Faktoren der vollständig verknüpften Schicht.

numClasses = numel(categories(imgsTrain.Labels));
newConnectedLayer = fullyConnectedLayer(numClasses,"Name","new_fc", ...
    "WeightLearnRateFactor",5,"BiasLearnRateFactor",5);
net = replaceLayer(net,"loss3-classifier",newConnectedLayer);

Inspizieren Sie die letzten fünf Schichten. Überprüfen Sie, ob die Dropout- und Faltungsschichten und die vollständig verknüpfte Schicht korrekt ersetzt wurden.

net.Layers(end-3:end)
ans = 
  4×1 Layer array with layers:

     1   'pool5-7x7_s1'   2-D Global Average Pooling   2-D global average pooling
     2   'new_Dropout'    Dropout                      60% dropout
     3   'new_fc'         Fully Connected              3 fully connected layer
     4   'prob'           Softmax                      softmax

Einstellen der Trainingsoptionen und Training von GoogLeNet

Bei dem Training eines neuronalen Netzes handelt es sich um einen iterativen Prozess, bei dem eine Verlustfunktion minimiert wird. Um die Verlustfunktion zu minimieren, wird ein Gradientenabstieg-Algorithmus verwendet. Bei jeder Iteration wird der Gradient der Verlustfunktion ausgewertet und die Gewichtungen des Abstiegsalgorithmus werden aktualisiert.

Das Training kann mithilfe verschiedener Optionen abgestimmt werden. InitialLearnRate gibt die anfängliche Schrittgröße des negativen Gradienten der Verlustfunktion an. MiniBatchSize gibt die Größe der bei jeder Iteration verwendeten Untermenge des Trainingsdatensatzes an. Eine Epoche ist ein vollständiger Durchlauf des Trainingsalgorithmus über den gesamten Trainingsdatensatz hinweg. MaxEpochs gibt die Höchstanzahl der für das Training zu verwendenden Epochen an. Die Auswahl der korrekten Anzahl Epochen ist keine triviale Aufgabe. Wenn Sie die Anzahl Epochen verringern, kommt es zu einer Unteranpassung des Modells, wenn Sie die Anzahl Epochen erhöhen, zu einer Überanpassung.

Verwenden Sie die Funktion trainingOptions, um die Trainingsoptionen anzugeben. Setzen Sie MiniBatchSize auf 15, MaxEpochs auf 20 und InitialLearnRate auf 0,0001. Visualisieren Sie den Trainingsfortschritt, indem Sie Plots auf training-progress setzen. Verwenden Sie den stochastischen Gradientenabstieg mit Momentum-Optimizer. Standardmäßig verwendet das Training eine Grafikkarte, sofern vorhanden. Die Verwendung einer Grafikkarte erfordert die Parallel Computing Toolbox™. Unter GPU Computing Requirements (Parallel Computing Toolbox) können Sie herausfinden, welche Grafikkarten unterstützt werden.

options = trainingOptions("sgdm", ...
    MiniBatchSize=15, ...
    MaxEpochs=20, ...
    InitialLearnRate=1e-4, ...
    ValidationData=imgsValidation, ...
    ValidationFrequency=10, ...
    Verbose=true, ...
    Plots="training-progress", ...
    Metrics="accuracy");

Trainieren Sie das Netz. Der Trainingsprozess nimmt bei einer Desktop-CPU üblicherweise 1–5 Minuten in Anspruch. Kann eine Grafikkarte verwendet werden, sind die Laufzeiten schneller. Im Befehlsfenster werden während dem Lauf Trainingsdaten angezeigt. Die Ergebnisse umfassen die Epochenzahl, die Iterationszahl, die vergangene Zeit, die Minibatch-Genauigkeit, die Validierungsgenauigkeit und den Verlustfunktionswert für die Validierungsdaten.

trainedGN = trainnet(imgsTrain,net,"crossentropy",options);
    Iteration    Epoch    TimeElapsed    LearnRate    TrainingLoss    ValidationLoss    TrainingAccuracy    ValidationAccuracy
    _________    _____    ___________    _________    ____________    ______________    ________________    __________________
            0        0       00:00:07       0.0001                            1.3444                                    46.875
            1        1       00:00:07       0.0001          1.7438                                    40                      
           10        2       00:00:37       0.0001          1.7555            1.1047                  40                  62.5
           20        3       00:01:03       0.0001         0.75169           0.68252              66.667                 68.75
           30        4       00:01:27       0.0001         0.74739           0.52126              73.333                78.125
           40        5       00:01:50       0.0001         0.49647           0.43025                  80                84.375
           50        7       00:02:13       0.0001         0.27949           0.36374              93.333                  87.5
           60        8       00:02:33       0.0001         0.15129           0.36825              93.333                84.375
           70        9       00:02:50       0.0001         0.15792           0.29109                 100                  87.5
           80       10       00:03:07       0.0001          0.3697           0.30388              93.333                90.625
           90       12       00:03:28       0.0001           0.159           0.25558                 100                90.625
          100       13       00:03:47       0.0001         0.02107           0.25558                 100                90.625
          110       14       00:04:06       0.0001         0.17743            0.2531              93.333                90.625
          120       15       00:04:27       0.0001        0.086914           0.23932                 100                90.625
          130       17       00:04:48       0.0001         0.13208           0.24259              93.333                90.625
          140       18       00:05:12       0.0001        0.025648           0.20339                 100                 93.75
          150       19       00:05:36       0.0001         0.17878           0.19556              93.333                 93.75
          160       20       00:06:01       0.0001        0.050998           0.21189                 100                 93.75
Training stopped: Max epochs completed

Beurteilen der Genauigkeit von GoogLeNet

Beurteilen Sie das Netz mithilfe der Validierungsdaten.

classNames = categories(imgsTrain.Labels);
scores = minibatchpredict(trainedGN,imgsValidation);
YPred = scores2label(scores,classNames);
accuracy = mean(YPred==imgsValidation.Labels);
disp("GoogLeNet Accuracy: "+num2str(100*accuracy)+"%")
GoogLeNet Accuracy: 93.75%

Die Genauigkeit entspricht der im Trainings-Visualisierungsbild gemeldeten Validierungsgenauigkeit. Die Skalogramme wurden in Trainings- und Validierungsdatensätze aufgeteilt. Beide Datensätze wurden zum Training von GoogLeNet verwendet. Die ideale Methode zur Beurteilung des Trainingsergebnisses ist es, das Netz unbekannte Daten klassifizieren zu lassen. Da nicht genügend Daten vorliegen, um diese auf Training, Validierung und Test aufzuteilen, behandeln wir die berechnete Validierungsgenauigkeit als Netzgenauigkeit.

Erkunden der GoogLeNet-Aktivierungen

Jede Schicht eines CNN erzeugt für ein Eingangsbild eine Reaktion oder Aktivierung. Nur wenige Schichten in einem CNN eignen sich jedoch für die Extraktion von Bildmerkmalen. Inspizieren Sie die ersten fünf Schichten des trainierten Netzes.

trainedGN.Layers(1:5)
ans = 
  5×1 Layer array with layers:

     1   'data'             Image Input                   224×224×3 images with 'zerocenter' normalization
     2   'conv1-7x7_s2'     2-D Convolution               64 7×7×3 convolutions with stride [2  2] and padding [3  3  3  3]
     3   'conv1-relu_7x7'   ReLU                          ReLU
     4   'pool1-3x3_s2'     2-D Max Pooling               3×3 max pooling with stride [2  2] and padding [0  1  0  1]
     5   'pool1-norm1'      Cross Channel Normalization   cross channel normalization with 5 channels per element

Die Schichten am Anfang des Netzes erfassen grundlegende Bildmerkmale wie Ränder und Flächen. Um dies zu betrachten, visualisieren Sie die Netzfilter-Gewichtungen der ersten Faltungsschicht. Die erste Schicht umfasst 64 Sätze Gewichtungen.

wghts = trainedGN.Layers(2).Weights;
wghts = rescale(wghts);
wghts = imresize(wghts,8);
figure
I = imtile(wghts,GridSize=[8 8]);
imshow(I)
title("First Convolutional Layer Weights")

Sie könnend die Aktivierungen untersuchen und erkunden, welche Merkmale GoogLeNet lernt, indem Sie die Aktivierungsbereiche mit dem Ursprungsbild vergleichen. Weitere Informationen finden Sie unter Visualize Activations of a Convolutional Neural Network und Visualize Features of a Convolutional Neural Network.

Untersuchen Sie anhand eines Bildes aus der Klasse ARR, welche Bereiche in den Faltungsschichten aktiviert werden. Vergleichen Sie dies mit dem entsprechenden Bereichen im Ursprungsbild. Jede Schicht eines Convolutional Neural Network besteht aus vielen 2D-Arrays, die als Kanäle bezeichnet werden. Lassen Sie das Bild vom Netz verarbeiten und untersuchen Sie die Ausgangsaktivierungen der ersten Faltungsschicht conv1-7x7_s2.

convLayer = "conv1-7x7_s2";

imgClass = "ARR";
imgName = "ARR_10.jpg";
imarr = imread(fullfile(parentDir,dataDir,imgClass,imgName));

trainingFeaturesARR = predict(trainedGN,single(imarr),Outputs=convLayer);
sz = size(trainingFeaturesARR);
trainingFeaturesARR = reshape(trainingFeaturesARR,[sz(1) sz(2) 1 sz(3)]);
figure
I = imtile(rescale(trainingFeaturesARR),GridSize=[8 8]);
imshow(I)
title(imgClass+" Activations")

Finden Sie den stärksten Kanal für dieses Bild. Vergleichen Sie den stärksten Kanal mit dem Ursprungsbild.

imgSize = size(imarr);
imgSize = imgSize(1:2);
[~,maxValueIndex] = max(max(max(trainingFeaturesARR)));
arrMax = trainingFeaturesARR(:,:,:,maxValueIndex);
arrMax = rescale(arrMax);
arrMax = imresize(arrMax,imgSize);
figure
I = imtile({imarr,arrMax});
imshow(I)
title("Strongest "+imgClass+" Channel: "+num2str(maxValueIndex))

SqueezeNet

SqueezeNet ist ein tiefes CNN, dessen Architektur Bilder mit der Größe 227-mal-227-mal-3 unterstützt. Obwohl die Abmessungen des Bildes sich bei GoogLeNet unterscheiden, müssen Sie keine neuen RGB-Bilder mit den SqueezeNet-Abmessungen generieren. Sie können die ursprünglichen RGB-Bilder verwenden.

Laden

Laden Sie das vortrainierte neuronale Netz SqueezeNet.

netsqz = imagePretrainedNetwork("squeezenet");

Extrahieren Sie das Schichtdiagramm aus dem Netz. Bestätigen Sie, dass SqueezeNet weniger Schichten als GoogLeNet aufweist. Überprüfen Sie zudem, ob SqueezeNet für Bilder der Größe 227-mal-227-mal-3 konfiguriert ist.

disp("Number of Layers: "+num2str(numel(netsqz.Layers)))
Number of Layers: 68
netsqz.Layers(1)
ans = 
  ImageInputLayer with properties:

                      Name: 'data'
                 InputSize: [227 227 3]
        SplitComplexInputs: 0

   Hyperparameters
          DataAugmentation: 'none'
             Normalization: 'zerocenter'
    NormalizationDimension: 'auto'
                      Mean: [1×1×3 single]

Modifizieren der Netzparameter von SqueezeNet

Um SqueezeNet für die Klassifizierung neuer Bilder neu zu trainieren, nehmen Sie ähnliche Veränderungen wie bei GoogLeNet vor.

Inspizieren Sie die letzten fünf Netzschichten.

netsqz.Layers(end-4:end)
ans = 
  5×1 Layer array with layers:

     1   'conv10'         2-D Convolution              1000 1×1×512 convolutions with stride [1  1] and padding [0  0  0  0]
     2   'relu_conv10'    ReLU                         ReLU
     3   'pool10'         2-D Global Average Pooling   2-D global average pooling
     4   'prob'           Softmax                      softmax
     5   'prob_flatten'   Flatten                      Flatten

Ersetzen Sie die letzte Dropout-Schicht im Netz durch eine Dropout-Schicht mit der Wahrscheinlichkeit 0,6.

tmpLayer = netsqz.Layers(end-5);
newDropoutLayer = dropoutLayer(0.6,"Name","new_dropout");
netsqz = replaceLayer(netsqz,tmpLayer.Name,newDropoutLayer);

Im Gegensatz zu GoogLeNet ist die letzte lernbare Schicht in SqueezeNet eine 1-mal-1-Faltungsschicht conv10 und keine vollständig verknüpfte Schicht. Ersetzen Sie die Schicht durch eine neue Faltungsschicht, deren Anzahl Filter der Anzahl Klassen entspricht. Erhöhen Sie wie bei GoogLeNet die Lerngeschwindigkeits-Faktoren der neuen Schicht.

numClasses = numel(categories(imgsTrain.Labels));
tmpLayer = netsqz.Layers(end-4);
newLearnableLayer = convolution2dLayer(1,numClasses, ...
        "Name","new_conv", ...
        "WeightLearnRateFactor",10, ...
        "BiasLearnRateFactor",10);
netsqz = replaceLayer(netsqz,tmpLayer.Name,newLearnableLayer);

Inspizieren Sie die letzten fünf Schichten des Netzes. Überprüfen Sie, ob sich die Dropout- und Faltungsschichten geändert haben.

netsqz.Layers(end-4:end)
ans = 
  5×1 Layer array with layers:

     1   'new_conv'       2-D Convolution              3 1×1 convolutions with stride [1  1] and padding [0  0  0  0]
     2   'relu_conv10'    ReLU                         ReLU
     3   'pool10'         2-D Global Average Pooling   2-D global average pooling
     4   'prob'           Softmax                      softmax
     5   'prob_flatten'   Flatten                      Flatten

Vorbereiten von RGB-Daten für SqueezeNet

Die RGB-Bilder weisen Abmessungen auf, die für die GoogLeNet-Architektur geeignet sind. Erstellen Sie erweiterte Bild-Datastores, die die Größe der bestehenden RGB-Bilder automatisch an die SqueezeNet-Architektur anpassen. Weitere Informationen finden Sie unter augmentedImageDatastore.

augimgsTrain = augmentedImageDatastore([227 227],imgsTrain);
augimgsValidation = augmentedImageDatastore([227 227],imgsValidation);

Einstellen der Trainingsoptionen und Training von SqueezeNet

Erstellen Sie einen neuen Satz Trainingsoptionen zur Verwendung mit SqueezeNet und trainieren Sie das Netz.

ilr = 3e-4;
miniBatchSize = 10;
maxEpochs = 15;
valFreq = floor(numel(augimgsTrain.Files)/miniBatchSize);
opts = trainingOptions("sgdm", ...
    MiniBatchSize=miniBatchSize, ...
    MaxEpochs=maxEpochs, ...
    InitialLearnRate=ilr, ...
    ValidationData=augimgsValidation, ...
    ValidationFrequency=valFreq, ...
    Verbose=1, ...
    Plots="training-progress", ...
    Metrics="accuracy");

trainedSN = trainnet(augimgsTrain,netsqz,"crossentropy",opts);
    Iteration    Epoch    TimeElapsed    LearnRate    TrainingLoss    ValidationLoss    TrainingAccuracy    ValidationAccuracy
    _________    _____    ___________    _________    ____________    ______________    ________________    __________________
            0        0       00:00:01       0.0003                            2.7267                                        25
            1        1       00:00:01       0.0003          3.0502                                    30                      
           13        1       00:00:05       0.0003         0.93269           0.81717                  60                78.125
           26        2       00:00:10       0.0003          0.6929           0.62475                  70                 81.25
           39        3       00:00:15       0.0003         0.55664           0.54038                  70                84.375
           50        4       00:00:19       0.0003        0.075004                                   100                      
           52        4       00:00:20       0.0003         0.27402           0.51236                  90                 81.25
           65        5       00:00:25       0.0003         0.15558           0.72845                  90                 81.25
           78        6       00:00:27       0.0003         0.29531           0.58038                  90                 81.25
           91        7       00:00:30       0.0003        0.053372           0.53191                 100                 81.25
          100        8       00:00:32       0.0003        0.019003                                   100                      
          104        8       00:00:33       0.0003         0.23475           0.22768                  80                 93.75
          117        9       00:00:37       0.0003        0.059982           0.15849                 100                96.875
          130       10       00:00:43       0.0003        0.038729           0.20219                 100                90.625
          143       11       00:00:46       0.0003       0.0059834           0.26095                 100                90.625
          150       12       00:00:47       0.0003        0.002025                                   100                      
          156       12       00:00:48       0.0003       0.0067973           0.16036                 100                96.875
          169       13       00:00:50       0.0003       0.0086382           0.17935                 100                96.875
          182       14       00:00:52       0.0003       0.0020118           0.21593                 100                 93.75
          195       15       00:00:54       0.0003       0.0061499           0.22566                 100                 93.75
Training stopped: Max epochs completed

Beurteilen der Genauigkeit von SqueezeNet

Beurteilen Sie das Netz mithilfe der Validierungsdaten.

scores = minibatchpredict(trainedSN,augimgsValidation);
YPred = scores2label(scores,classNames);
accuracy = mean(YPred==imgsValidation.Labels);
disp("SqueezeNet Accuracy: "+num2str(100*accuracy)+"%")
SqueezeNet Accuracy: 96.875%

Abschluss

Dieses Beispiel demonstriert die Verwendung von Transfer Learning und Continuous-Wavelet-Analyse zur Klassifizierung von drei Klassen von EKG-Signalen durch Verwendung der vortrainierten CNNs GoogLeNet und SqueezeNet. Wavelet-basierte Zeit-Frequenz-Darstellungen von EKG-Signalen werden zur Erstellung von Skalogrammen verwendet. Es werden RGB-Bilder der Skalogramme generiert. Die Bilder werden zur Feinabstimmung beider tiefer CNNs verwendet. Ebenfalls wurde die Aktivierung unterschiedlicher Netzschichten erkundet.

Bei diesem Beispiel wird ein möglicher Workflow veranschaulicht, den Sie zur Klassifizierung von Signalen mithilfe vortrainierter CNN-Modelle verwenden können. Andere Workflows sind ebenfalls möglich. Deploy Signal Classifier on NVIDIA Jetson Using Wavelet Analysis and Deep Learning (Wavelet Toolbox) und Deploy Signal Classifier Using Wavelets and Deep Learning on Raspberry Pi (Wavelet Toolbox) zeigen, wie Sie Code zur Signalklassifizierung auf Hardware bereitstellen können. GoogLeNet und SqueezeNet sind Modelle, die anhand einer Untermenge der ImageNet-Datenbank [10] vortrainiert wurden. Diese Datenbank wird für die ImageNet Large-Scale Visual Recognition Challenge (ILSVRC) [8] verwendet. Die ImageNet-Datenbank umfasst Bilder realer Objekte wie Fische, Vögel, Haushaltsgeräte und Pilze. Skalogramme gehören nicht zur Klasse realer Objekte. Die Skalogramme wurden zudem zur Kompatibilität mi der GoogLeNet- und SqueezeNet-Architektur einer Datenreduktion unterzogen. Statt vortrainierte CNNs abzustimmen, um unterschiedliche Klassen von Skalogrammen zu differenzieren, ist es auch möglich, ein CNN von Grund auf für die ursprünglichen Skalogramm-Abmessungen zu trainieren.

Referenzen

  1. Baim, D. S., W. S. Colucci, E. S. Monrad, H. S. Smith, R. F. Wright, A. Lanoue, D. F. Gauthier, B. J. Ransil, W. Grossman, and E. Braunwald. "Survival of patients with severe congestive heart failure treated with oral milrinone." Journal of the American College of Cardiology. Vol. 7, Number 3, 1986, S. 661–670.

  2. Engin, M. "ECG beat classification using neuro-fuzzy network." Pattern Recognition Letters. Vol. 25, Number 15, 2004, S.1715–1722.

  3. Goldberger A. L., L. A. N. Amaral, L. Glass, J. M. Hausdorff, P. Ch. Ivanov, R. G. Mark, J. E. Mietus, G. B. Moody, C.-K. Peng, and H. E. Stanley. "PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals." Circulation. Vol. 101, Number 23: e215–e220. [Circulation Electronic Pages; http://circ.ahajournals.org/content/101/23/e215.full]; 2000 (June 13). doi: 10.1161/01.CIR.101.23.e215.

  4. Leonarduzzi, R. F., G. Schlotthauer, and M. E. Torres. "Wavelet leader based multifractal analysis of heart rate variability during myocardial ischaemia." In Engineering in Medicine and Biology Society (EMBC), Annual International Conference of the IEEE, 110–113. Buenos Aires, Argentina: IEEE, 2010.

  5. Li, T., and M. Zhou. "ECG classification using wavelet packet entropy and random forests." Entropy. Vol. 18, Number 8, 2016, S.285.

  6. Maharaj, E. A., and A. M. Alonso. "Discriminant analysis of multivariate time series: Application to diagnosis based on ECG signals." Computational Statistics and Data Analysis. Vol. 70, 2014, S. 67–87.

  7. Moody, G. B., and R. G. Mark. "The impact of the MIT-BIH Arrhythmia Database." IEEE Engineering in Medicine and Biology Magazine. Vol. 20. Number 3, May-June 2001, S. 45–50. (PMID: 11446209)

  8. Russakovsky, O., J. Deng, and H. Su et al. "ImageNet Large Scale Visual Recognition Challenge." International Journal of Computer Vision. Vol. 115, Number 3, 2015, S. 211–252.

  9. Zhao, Q., and L. Zhang. "ECG feature extraction and classification using wavelet transform and support vector machines." In IEEE International Conference on Neural Networks and Brain, 1089–1092. Beijing, China: IEEE, 2005.

  10. ImageNet. http://www.image-net.org

Unterstützungsfunktionen

helperCreateECGDataDirectories erstellt ein Datenverzeichnis in einem übergeordneten Verzeichnis und daraufhin drei Unterverzeichnisse im Datenverzeichnis. Die Unterverzeichnisse sind nach den EKG-Signalklassen in ECGData benannt.

function helperCreateECGDirectories(ECGData,parentFolder,dataFolder)
% This function is only intended to support the ECGAndDeepLearningExample.
% It may change or be removed in a future release.

rootFolder = parentFolder;
localFolder = dataFolder;
mkdir(fullfile(rootFolder,localFolder))

folderLabels = unique(ECGData.Labels);
for i = 1:numel(folderLabels)
    mkdir(fullfile(rootFolder,localFolder,char(folderLabels(i))));
end
end

helperPlotReps stellt die ersten tausend Samples eines Beispiels jeder EKG-Signalklasse in ECGData grafisch dar.

function helperPlotReps(ECGData)
% This function is only intended to support the ECGAndDeepLearningExample.
% It may change or be removed in a future release.

folderLabels = unique(ECGData.Labels);

for k=1:3
    ecgType = folderLabels{k};
    ind = find(ismember(ECGData.Labels,ecgType));
    subplot(3,1,k)
    plot(ECGData.Data(ind(1),1:1000));
    grid on
    title(ecgType)
end
end

helperCreateRGBfromTF verwendet cwtfilterbank (Wavelet Toolbox), um die Continuous-Wavelet-Transformation der EKG-Signale zu berechnen und generiert die Skalogramme aus den Wavelet-Koeffizienten. Die Hilfsfunktion passt die Größe der Skalogramme an und speichert sie als JPEG-Bilder auf der Festplatte.

function helperCreateRGBfromTF(ECGData,parentFolder,childFolder)
% This function is only intended to support the ECGAndDeepLearningExample.
% It may change or be removed in a future release.

imageRoot = fullfile(parentFolder,childFolder);

data = ECGData.Data;
labels = ECGData.Labels;

[~,signalLength] = size(data);

fb = cwtfilterbank(SignalLength=signalLength,VoicesPerOctave=12);
r = size(data,1);

for ii = 1:r
    cfs = abs(fb.wt(data(ii,:)));
    im = ind2rgb(round(rescale(cfs,0,255)),jet(128));
    
    imgLoc = fullfile(imageRoot,char(labels(ii)));
    imFileName = char(labels(ii))+"_"+num2str(ii)+".jpg";
    imwrite(imresize(im,[224 224]),fullfile(imgLoc,imFileName));
end
end

Siehe auch

(Wavelet Toolbox) | | | | | |

Themen