Audio Transfer Learning Using Experiment Manager
This example shows how to configure an experiment that compares the performance of multiple pretrained networks when applied to a speech command recognition task using transfer learning. It highlights Experiment Manager's capability to tune hyperparameters and easily compare results between the different pretrained networks using both built-in and user-defined metrics.
Audio Toolbox™ provides a variety of pretrained networks for audio processing, and each consists of a different architecture that requires different data pre-processing. These differences result in tradeoffs between the accuracy, speed, and size of the various networks. Experiment Manager organizes the results of training experiments to highlight the strengths and weaknesses of each individual network so you can select the network that best fits your constraints.
The example compares the performance of the YAMNet (Audio Toolbox) and VGGish (Audio Toolbox) pretrained networks, as well as a custom-designed network that is trained from scratch. See Deep Network Designer to explore other pretrained network options supported by Audio Toolbox™.
In this example, you download the Google Speech Commands Dataset [1] and the pretrained networks and store them in your temporary directory if they are not already present. The dataset takes up 1.96 GB of disk space, and the networks take up 470 MB in total.
Open Experiment Manager
Load the example by clicking the Open Example button. This opens the project in Experiment Manager in your MATLAB editor.
Built-in training experiments consist of a description, a table of hyperparameters, a setup function, and a collection of metric functions to evaluate the results of the experiment. For more information, see Configure Built-In Training Experiment.
The Description field contains a textual description of the experiment.
The Hyperparameters section specifies the strategy (Exhaustive Sweep) and hyperparameter values to use for the experiment. When you run the experiment, Experiment Manager trains the network using every combination of hyperparameter values specified in the hyperparameter table. This example demonstrates how to test the different network types. Define one hyperparameter, Network, to represent the network names stored as strings.
The Setup Function field contains the name of the main function that configures the training data, network architecture, and training options for the experiment. The input to the setup function is a structure with fields from the hyperparameter table. The setup function returns the training data, network architecture, and training parameters as outputs. This has already been implemented for you.
The Metrics list enables you to define your own custom metrics to compare across different trials of the training experiment. A couple of example custom metric functions are defined for you later in this example. Experiment Manager runs each of the listed metrics against the networks trained in each trial. The metrics defined for you in this example are listed here. Any additional custom metric you intend to use must be listed in this section.
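To make the Exhaustive Sweep strategy described above concrete, the following is a minimal sketch of how every combination of hyperparameter values maps to one trial. It is not Experiment Manager's implementation, and it makes no claim about the order in which the app runs trials. The variable names, the "Custom" network name, and the LearnRate values are assumptions chosen for illustration; LearnRate is included only because the setup code reads params.LearnRate.

```matlab
% Hypothetical hyperparameter table: one field per hyperparameter.
hyperparams.Network = ["Yamnet","VGGish","Custom"];  % "Custom" is an assumed name
hyperparams.LearnRate = [1e-3 1e-4];                 % illustrative values only

names = fieldnames(hyperparams);
counts = cellfun(@(n) numel(hyperparams.(n)),names);
numTrials = prod(counts);

trials = cell(numTrials,1);
for k = 1:numTrials
    % Convert the linear trial index into one subscript per hyperparameter.
    % The trailing 1 keeps the size vector valid when there is only one field.
    idx = cell(1,numel(names));
    [idx{:}] = ind2sub([counts(:)' 1],k);

    % Build the structure that a setup function would receive as params.
    params = struct();
    for j = 1:numel(names)
        params.(names{j}) = hyperparams.(names{j})(idx{j});
    end
    trials{k} = params;
end

fprintf("Exhaustive sweep produces %d trials\n",numTrials);
```

Each element of trials has the same shape as the params structure that the setup function receives, so trials{2}.Network here is "VGGish".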
Define Setup Function
In this example, the Setup Function downloads the dataset, selects the desired network, performs the requisite data preprocessing, and sets the network training options. The input to this function is a structure with fields for each of the hyperparameters defined in the Experiment Manager interface. In the Setup Function for this example, the input variable is named params and the output variables are named trainingData, layers, and options, representing the training data, network structure, and training parameters, respectively. The key steps of the Setup Function for this example are explained below. Open the example in MATLAB to see the full definition of compareNetSetup, the name of the Setup Function used in this example.
Download and Extract Data
To speed up the example, open compareNetSetup and set the speedUp flag to true. This reduces the size of the dataset so that you can quickly test the basic functionality of the experiment.
speedUp = false;
The helper function setupDatastores downloads the Google Speech Commands Dataset [1], selects the commands for the networks to recognize, and randomly partitions the data into training and validation datastores.
[adsTrain,adsValidation] = setupDatastores(speedUp);
Select the Desired Network and Preprocess Data
First, transform the datastores based on the preprocessing required by the network type defined in the hyperparameter table, which is accessed as params.Network. The helper function extractSpectrogram converts the input data to the format expected by each network type. The helper function getLayers returns a layerGraph object that represents the architecture of the desired network.
tdsTrain = transform(adsTrain,@(x)extractSpectrogram(x,params.Network));
tdsValidation = transform(adsValidation,@(x)extractSpectrogram(x,params.Network));
layers = getLayers(classes,classWeights,numClasses,netName);
Now that the datastores are properly set up, read the data into the trainingData and validationData variables.
trainingData = readall(tdsTrain,UseParallel=canUseParallelPool);
validationData = readall(tdsValidation,UseParallel=canUseParallelPool);
validationData = table(validationData(:,1),adsValidation.Labels);
trainingData = table(trainingData(:,1),adsTrain.Labels);
Set the Training Options
Set the training parameters by assigning a trainingOptions object to the options output variable. Train the networks using the Adam optimizer for a maximum of 30 epochs, with a validation patience of 10 as set in the code that follows. Set the ExecutionEnvironment option to "auto" to use a GPU if one is available. Without a GPU, training can be very time consuming.
maxEpochs = 30;
miniBatchSize = 256;
validationFrequency = floor(numel(TTrain)/miniBatchSize);
options = trainingOptions("adam", ...
    GradientDecayFactor=0.7, ...
    InitialLearnRate=params.LearnRate, ...
    MaxEpochs=maxEpochs, ...
    MiniBatchSize=miniBatchSize, ...
    Shuffle="every-epoch", ...
    Plots="training-progress", ...
    Verbose=false, ...
    ValidationData=validationData, ...
    ValidationFrequency=validationFrequency, ...
    ValidationPatience=10, ...
    LearnRateSchedule="piecewise", ...
    LearnRateDropFactor=0.2, ...
    LearnRateDropPeriod=round(maxEpochs/3), ...
    ExecutionEnvironment="auto");
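The arithmetic behind these options can be worked through with a short sketch. It assumes the piecewise schedule multiplies the learning rate by LearnRateDropFactor once each LearnRateDropPeriod has elapsed. The initial learning rate and the number of training observations are placeholder values standing in for params.LearnRate and numel(TTrain), which are not given in this section.

```matlab
% Values taken from the training options above.
maxEpochs = 30;
miniBatchSize = 256;
dropFactor = 0.2;                    % LearnRateDropFactor
dropPeriod = round(maxEpochs/3);     % LearnRateDropPeriod, 10 epochs

% Assumptions for illustration only (not from the example).
initialLearnRate = 1e-3;             % stands in for params.LearnRate
numTrainObservations = 25600;        % stands in for numel(TTrain)

% Iterations between validation passes, roughly one pass per epoch.
validationFrequency = floor(numTrainObservations/miniBatchSize);

% Piecewise schedule: scale the rate by dropFactor after every dropPeriod epochs.
epochs = 1:maxEpochs;
learnRate = initialLearnRate * dropFactor.^floor((epochs-1)/dropPeriod);

fprintf("Validate every %d iterations\n",validationFrequency);
fprintf("Epoch %2d: learning rate %.2e\n", ...
    [epochs([1 11 21]); learnRate([1 11 21])]);
```

Under these assumptions the learning rate takes three values over the 30 epochs, and because validation runs about once per epoch, a ValidationPatience of 10 corresponds to roughly 10 epochs without improvement in validation loss.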
Define Custom Metrics
Experiment Manager enables you to define custom metric functions to evaluate the performance of the networks trained in each trial. Basic metrics like accuracy and loss are computed by default. In this example you compare the size of each of the models as memory usage is an important metric when deploying deep neural networks to real-world applications.
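Custom metric functions must take one input argument, trialInfo, which is a structure containing the fields trainedNetwork, trainingInfo, and parameters:
trainedNetwork is the SeriesNetwork object or DAGNetwork object returned by the trainNetwork function
trainingInfo is a struct containing the training information returned by the trainNetwork function
parameters is a struct with fields from the hyperparameter table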
Each metric function must return a scalar number, a logical value, or a string, which is displayed in the results table. The custom metrics defined for you in this experiment are listed below:
sizeMB computes the memory allocated to store the network, in megabytes
numLearnableParams counts the number of learnable parameters in each model
numIters computes the number of mini-batches each network trained on before either reaching MaxEpochs or violating the ValidationPatience parameter in the trainingOptions object
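As a rough illustration of the metric function contract described above, the sketch below shows one plausible way to write metrics similar to sizeMB and numIters. The function names sizeMBSketch and numItersSketch are hypothetical, and the bodies are assumptions rather than the definitions shipped with the example. The trialInfo structure is a stand-in built from placeholder values so that the script runs without a trained network.

```matlab
% Stand-in for the structure Experiment Manager passes to a metric function.
% The field values are placeholders, not a real network or training record.
trialInfo.trainedNetwork = zeros(1024,128);          % 1 MiB of doubles
trialInfo.trainingInfo.TrainingLoss = rand(1,250);   % one entry per iteration
trialInfo.parameters.Network = "Yamnet";

mb = sizeMBSketch(trialInfo);
iters = numItersSketch(trialInfo);
fprintf("Size: %.2f MB, iterations: %d\n",mb,iters);

% Local functions are placed at the end of the script.
function metricOutput = sizeMBSketch(trialInfo)
    % Report the memory used by the stored network variable, in megabytes.
    net = trialInfo.trainedNetwork;
    info = whos("net");
    metricOutput = info.bytes/(1024^2);
end

function metricOutput = numItersSketch(trialInfo)
    % The training record holds one loss value per mini-batch iteration.
    metricOutput = numel(trialInfo.trainingInfo.TrainingLoss);
end
```

In an Experiment Manager project, each metric would typically live in its own function file and be added to the Metrics list by name. Both functions return a scalar, which satisfies the requirement on metric outputs.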
Press 'Run' in the top pane of the Experiment Manager app to run the experiment. You can run the trials sequentially, simultaneously, or in batches by changing the mode option. For this experiment, the trials were run sequentially.
When the experiment finishes, the results for each trial appear and the metrics are displayed in tabular format. The progress bar shows how many epochs each network trained for before violating the patience parameter, as a percentage of MaxEpochs.
To sort the table by the entries in a column, hover over the right side of the column name cell and click the arrow that appears. Click the table icon on the top right to select which columns to show or hide. To first compare the networks by accuracy, sort the table by the Validation Accuracy column in descending order.
In terms of accuracy, the Yamnet network performs the best, followed by VGGish, and lastly the custom network. However, the Elapsed Time column shows that Yamnet takes the longest to train. To compare the size of these networks, sort the table by the sizeMB column.
The custom network is the smallest, Yamnet is a few orders of magnitude larger, and VGGish is the largest.
These results highlight the tradeoffs between the different network designs. The Yamnet network performs the best at the classification task at the cost of more training time and moderately large memory consumption. The VGGish network performs slightly worse in terms of accuracy but requires over 20 times more memory than Yamnet. Lastly, the custom network has the worst accuracy by a small margin but also uses the least memory.
Notice that even though Yamnet and VGGish are pretrained networks, the custom network finishes training the fastest. The numIters column shows that the custom network takes the most batch iterations to converge because it is learning from scratch. However, because the custom network is much smaller and shallower than the deep pretrained models, each batch update is processed much faster, so the overall training time is lower.
To save the trained network from any of the trials, right-click the corresponding row in the results table and select Export Trained Network.
To further analyze an individual trial, click the corresponding row and, under the Review Results tab in the top pane, choose to display a plot of the training progress or a confusion matrix of the trained model. The confusion matrix below is for the Yamnet model from trial 2 of the experiment.
The model struggles most to differentiate between the commands "off" and "up" and between the commands "no" and "go", although the accuracy is generally uniform across all classes. The model is also very reliable when it predicts the "yes" command, as the false positive rate for that class is only 0.4%.
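For readers who want to reproduce a per-class false positive rate from a confusion matrix, the following sketch shows the calculation. It assumes rows hold the true classes and columns hold the predicted classes, and it uses the definition FP/(FP+TN). The class names and counts are made up for illustration and are not results from this experiment.

```matlab
% Hypothetical 3-class confusion matrix (rows: true class, columns: predicted).
classNames = ["yes","no","go"];
C = [50  2  3
      4 45  1
      1  3 46];

tp = diag(C);                    % correct predictions per class
fp = sum(C,1)' - tp;             % predicted as the class but truly another
fn = sum(C,2) - tp;              % truly the class but predicted as another
tn = sum(C(:)) - tp - fp - fn;   % everything else

falsePositiveRate = fp ./ (fp + tn);

for k = 1:numel(classNames)
    fprintf("%s: false positive rate %.1f%%\n", ...
        classNames(k),100*falsePositiveRate(k));
end
```

A low value for a class means that the model rarely assigns that label to utterances of other commands.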
[1] Warden P. "Speech Commands: A public dataset for single-word speech recognition", 2017. Available from https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz. Copyright Google 2017. The Speech Commands Dataset is licensed under the Creative Commons Attribution 4.0 license, available here: https://creativecommons.org/licenses/by/4.0/legalcode.