How can I leverage the Bayesian Optimization framework to find the optimal hyperparameters for a non-image training task?

I would like to use the Bayesian Optimization framework described in the documentation to find the optimal hyperparameters for training a network that approximates a nonlinear function 'y = f(x)'.

Accepted Answer

MathWorks Support Team on 30 Oct 2020
The script below uses the Bayesian Optimization framework to find optimal hyperparameters, in this case the number of epochs, the initial learning rate, and the number of neurons in an intermediate layer, for approximating the nonlinear function 'y(x) = x^3':

%% Definition of the function that needs to be approximated
fnc = @(x) x.^3;
%% Definition of the training data
xTrain = linspace(-1, 1, 80)';
yTrain = fnc(xTrain);
%% Definition of the validation data
numRand = 20;
xValidation = sort(2.*rand(numRand, 1) - 1);
yValidation = fnc(xValidation);
%% Definition of the design variables
optimVars = [
    optimizableVariable('epochs', [100 10000], 'Type', 'integer')
    optimizableVariable('InitialLearnRate', [1e-4 1], 'Transform', 'log')
    optimizableVariable('numberOfNeurons', [1 100], 'Type', 'integer')];
%% Objective function for the Bayesian optimization
ObjFcn = makeObjFcn(xTrain, yTrain, xValidation, yValidation);
%% Perform Bayesian optimization to find the optimal parameters
BayesObject = bayesopt(ObjFcn, optimVars, ...
    'MaxTime', 14*60*60, ...
    'IsObjectiveDeterministic', false, ...
    'UseParallel', false);
%% Definition of the objective function
function ObjFcn = makeObjFcn(XTrain, YTrain, XValidation, YValidation)
    %% Assign the output of the objective function
    ObjFcn = @valErrorFun;
    
    %% Definition of the objective function
    function [valError, cons, fileName] = valErrorFun(optVars)
        %% Definition of the layer architecture depending on the design variables
        layers = [ ...
            featureInputLayer(1, "Name", "myFeatureInputLayer", 'Normalization','rescale-symmetric')
            fullyConnectedLayer(optVars.numberOfNeurons, "Name", "myFullyConnectedLayer1")
            tanhLayer("Name", "myTanhLayer")
            fullyConnectedLayer(1, "Name", "myFullyConnectedLayer2")
            regressionLayer("Name", "myRegressionLayer")
        ];
        %% Definition of the training options depending on the design variables
        options = trainingOptions('adam', ...
            'MaxEpochs', optVars.epochs, ... % first design parameter
            'InitialLearnRate', optVars.InitialLearnRate,... % second design parameter
            'Shuffle', 'every-epoch', ...
            'MiniBatchSize', 128, ...
            'Verbose', false); % 'Plots', 'training-progress', ...
        
        %% Train the network for the actual optimization step
        [trainedNet, ~] = trainNetwork(XTrain, YTrain, layers, options);
        close(findall(groot, 'Tag', 'NNET_CNN_TRAININGPLOT_UIFIGURE')) % close the training-progress window if it was enabled
        
        %% Perform prediction on the provided validation data for the current optimization step
        YPredicted = predict(trainedNet, XValidation);
        
        %% Computation of the error between the expected and the predicted solution of the trained network using the validation data
        valError = norm(YPredicted - YValidation);
        
        %% Save the results of the current optimization step
        fileName = num2str(valError) + ".mat";
        save(fileName, 'trainedNet', 'valError', 'options')
        cons = [];
    end
end
This script serves only as a proof of concept and should not be considered the best possible configuration of the corresponding Bayesian optimization task.
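Once bayesopt has finished, the best hyperparameters and the network saved by the best objective evaluation can be recovered from the returned BayesianOptimization object. A minimal sketch, assuming the script above has already run and the saved .mat files are on the MATLAB path:

%% Retrieve the best observed hyperparameter combination
bestVars = bestPoint(BayesObject);
%% Locate the file saved during the evaluation with the lowest validation error
bestIdx = BayesObject.IndexOfMinimumTrace(end);
fileName = BayesObject.UserDataTrace{bestIdx};
savedStruct = load(fileName);
%% Use the recovered network for prediction; for x = 0.5 the result
%% should lie close to the true value 0.5^3 = 0.125
yPred = predict(savedStruct.trainedNet, 0.5);

The fileName returned as the third output of valErrorFun is what populates the UserDataTrace property, which is why the saved network of any evaluation can be reloaded this way.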
