Was versteht man unter einer Überanpassung?
Der Begriff „Überanpassung“ bezeichnet ein Verhalten beim Machine Learning, das auftritt, wenn das Modell so eng an die Trainingsdaten angepasst ist, dass es nicht weiß, wie es auf neue Daten reagieren soll. Eine Überanpassung kann aus folgenden Gründen auftreten:
- Das Machine-Learning-Modell ist zu komplex. Es merkt sich sehr subtile Muster in den Trainingsdaten, die sich nicht ohne Weiteres verallgemeinern lassen.
- Der Umfang der Trainingsdaten ist zu klein für die Komplexität des Modells und/oder enthält eine große Menge irrelevanter Informationen.
Eine Überanpassung lässt sich verhindern, indem die Modellkomplexität gesteuert und der Trainingsdatensatz verbessert wird.
Überanpassung vs. Unteranpassung
Die Unteranpassung ist das Gegenteil der Überanpassung. Das entsprechende Modell passt sich dabei nicht optimal an die Trainingsdaten an oder lässt sich unzureichend auf neue Daten übertragen. Sowohl bei Klassifikations- als auch bei Regressionsmodellen kann es zu einer Über- oder Unteranpassung kommen. Die folgende Abbildung veranschaulicht, wie die Klassifizierungsentscheidungsgrenze und die Regressionslinie den Trainingsdaten bei einem überangepassten Modell zu eng und bei einem unterangepassten Modell nicht eng genug folgen.
Wenn man nur den berechneten Fehler eines Machine-Learning-Modells für die Trainingsdaten betrachtet, ist eine Überanpassung schwieriger zu erkennen als eine Unteranpassung. Um also eine Überanpassung zu vermeiden, ist es wichtig, ein Machine-Learning-Modell zu validieren, bevor man es auf Testdaten anwendet.
Fehler |
Überanpassung |
Idealanpassung |
Unteranpassung |
Training |
Gering |
Gering |
Hoch |
Test |
Hoch |
Gering |
Hoch |
Mithilfe von MATLAB® zusammen mit der Statistics and Machine Learning Toolbox™ und der Deep Learning Toolbox™ können Sie eine Überanpassung von Machine-Learning- und Deep-Learning-Modellen verhindern. So bietet MATLAB verschiedene Funktionen und Methoden, die speziell zur Vermeidung einer Überanpassung von Modellen entwickelt wurden. Diese Tools können Sie verwenden, wenn Sie Ihr Modell trainieren oder abstimmen, um es vor Überanpassung zu schützen.
Vermeiden von Überanpassung durch Verringerung der Modellkomplexität
Mit MATLAB können Sie Machine-Learning-Modelle und Deep-Learning-Modelle (wie z. B. CNNs) von Grund auf trainieren oder die Vorteile von vortrainierten Deep-Learning-Modellen nutzen. Zur Vermeidung einer Überanpassung sollten Sie eine Modellvalidierung durchführen, mit der Sie sicherstellen, dass Sie ein Modell mit dem für Ihre Daten geeigneten Komplexitätsgrad wählen. Oder Sie reduzieren die Komplexität des Modells durch Regularisierung.
Modellvalidierung
Der Fehler eines überangepassten Modells ist gering, wenn er für die Trainingsdaten berechnet wird. Es empfiehlt sich, vor der Einführung neuer Daten eine Validierung des Modells anhand eines separaten Datensatzes (d. h. eines Validierungsdatensatzes) vorzunehmen. Bei MATLAB-Modellen zum Machine Learning können Sie die Funktion cvpartition
verwenden, um einen Datensatz nach dem Zufallsprinzip in Trainings- und Validierungssätze zu unterteilen. Für Deep-Learning-Modelle können Sie die Validierungsgenauigkeit während des Trainings überwachen. Die Verbesserung der sorgfältig validierten Genauigkeitsmessung für Ihre Modelle durch die Modellauswahl und die Abstimmung der Hyperparameter sollte sich in einer höheren Genauigkeit niederschlagen, wenn das Modell neue Daten erhält.
Die Kreuzvalidierung ist ein Verfahren zur Bewertung von Modellen, mit dem die Leistung eines Algorithmus für das Machine Learning bei Vorhersagen für Datensätze bewertet wird, für die er nicht trainiert wurde. Dabei ermöglicht die Kreuzvalidierung die Auswahl eines nicht zu komplexen Algorithmus, der zu einer Überanpassung führen würde. Verwenden Sie die Funktion crossval
, um die Schätzung des Kreuzvalidierungsfehlers für Machine-Learning-Modelle zu berechnen, indem Sie gängige Kreuzvalidierungstechniken verwenden, wie z. B. k-fold (unterteilt Daten in k zufällig gewählte Teilmengen von ungefähr gleicher Größe) und Holdout (unterteilt Daten zufällig in genau zwei Teilmengen mit einem bestimmten Verhältnis).
Regularisierung
Die Regularisierung ist eine Technik, die dazu dient, eine statistische Überanpassung in einem Machine-Learning-Modell zu verhindern. Regularisierungsalgorithmen arbeiten in der Regel mit einem Malus für Komplexität oder Unregelmäßigkeit. Durch die Einführung zusätzlicher Informationen in das Modell können Regularisierungsalgorithmen die Multikollinearität und redundante Prädiktoren bewältigen, indem sie das Modell schlanker und genauer machen.
Für das Machine Learning können Sie zwischen drei weit verbreiteten Regularisierungstechniken wählen: Lasso (L1-Norm), Ridge (L2-Norm) und Elastic Net, bei denen es mehrere Arten von linearen Machine-Learning-Modellen gibt. Zur Vermeidung von Überanpassungen beim Deep Learning können Sie den L2-Regularisierungsfaktor in den angegebenen Trainingsoptionen erhöhen oder Dropout-Schichten in Ihrem Netzwerk verwenden.
Beispiele und Erläuterungen
Vermeiden einer Überanpassung durch Optimierung des Trainingsdatensatzes
Indem die Modellkomplexität gesteuert wird, verhindern die Kreuzvalidierung und die Regularisierung eine Überanpassung. Ein weiterer Ansatz ist die Verbesserung des Datensatzes. Insbesondere Deep-Learning-Modelle benötigen große Datenmengen, um eine Überanpassung zu vermeiden.
Datenaugmentierung
Wenn die Datenverfügbarkeit begrenzt ist, ist die Datenaugmentierung eine Methode, um die Datenpunkte des Trainingsdatensatzes künstlich zu erweitern, indem zufällige Versionen der vorhandenen Daten zum Datensatz hinzugefügt werden. Mit MATLAB lassen sich hierfür Bild-, Audio- und andere Datentypen augmentieren. Sie können beispielsweise Bilddaten erweitern, indem Sie den Maßstab und die Ausrichtung vorhandener Bilder zufällig verändern.
Datengenerierung
Die Erzeugung synthetischer Daten ist eine ebenfalls mögliche Methode zur Erweiterung eines Datensatzes. Mit MATLAB können Sie synthetische Daten mithilfe von Generative Adversarial Networks (GANs) oder digitalen Zwillingen (Datengenerierung durch Simulation) erzeugen.
Bereinigung von Daten
Zur Überanpassung trägt darüber hinaus auch das Datenrauschen bei. Ein gängiger Ansatz zur Reduzierung unerwünschter Datenpunkte ist die Entfernung von Datenausreißern mithilfe der Funktion rmoutliers
.