
Can you critique my approach to this neural network fitting and suggest improvements?

I am hoping to get some constructive suggestions on how to improve my code, on whether there is a better way to approach what I am trying to do, or on any other area I should investigate. Much appreciated in advance.
What I am doing is creating a neural network to fit one variable as a function of 7 other variables. I have a very large table of those 8 variables:
ANN=fitnet([50],'trainrp');
ANN=train(ANN,Input,Output);
Here Input is the tabulation of the 7 input variables and Output is the tabulation of the target variable. The end goal is to give the ANN any combination of the 7 inputs and get an accurate estimate of the output (essentially interpolating within the table). However, I am doing this process iteratively: after training the neural network on the data, I generate a new, different table and use the network to check whether it predicts the output values within 10% error. If it does not, I take the rows of data where it did poorly, add them to the original table, and repeat the command:
ANN=train(ANN,Input,Output);
But now Input and Output are the original table plus the data from the new table where the neural network didn't do very well. And I keep repeating this process over and over and over (automated process, not manual).
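The data handling in this loop can be sketched in a few lines of base MATLAB. Two points are worth making explicit: `fitnet`/`train` treat each column (not each row) as one sample, so a row-per-sample table must be transposed; and "rows where it did poorly" reduces to a relative-error mask. This is a minimal sketch under stated assumptions: the table is a numeric N-by-8 matrix called `data` with the output in the last column, and `newInput`, `newOutput`, and `pred` are hypothetical names, with made-up values standing in for a freshly generated table and for `ANN(newInput)`.

```matlab
% Hypothetical tabulation: one sample per ROW, 7 input columns + 1 output column.
% Random numbers stand in for the real table.
data = rand(6, 8);

% fitnet/train treat each COLUMN as a sample, so transpose the table.
Input  = data(:, 1:7).';   % 7-by-N predictors
Output = data(:, 8).';     % 1-by-N targets

% Hypothetical freshly generated table: made-up targets, plus made-up
% predictions that stand in for ANN(newInput).
newInput  = rand(7, 4);
newOutput = [10, 20, 30, 40];
pred      = [10.5, 25, 29, 50];

% Relative error per sample (assumes no target is exactly zero).
relErr = abs(pred - newOutput) ./ abs(newOutput);

% Flag the samples outside the 10% tolerance...
bad = relErr > 0.10;

% ...and append only those before the next call to train(ANN, Input, Output).
Input  = [Input,  newInput(:, bad)];
Output = [Output, newOutput(bad)];
```

If the tabulation is a MATLAB `table` object rather than a numeric matrix, `table2array` converts it first.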

Accepted Answer

Mrutyunjaya Hiremath on 17 Aug 2023
Try this:
% Load your data
% Example: load('your_data.mat');
% Assumes Input (7-by-N) and Output (1-by-N) hold one sample per column
% Normalize inputs (assumes no input row is constant). This is optional:
% fitnet already applies mapminmax to inputs and targets by default.
% Reuse meanInput/stdInput for any new input passed to the network later.
meanInput = mean(Input,2);
stdInput = std(Input,0,2);
Input = (Input - meanInput) ./ stdInput;
% Split data into training (70%) and validation (30%) sets
[trainInd,valInd] = dividerand(size(Input,2),0.7,0.3,0);
trainInput = Input(:,trainInd);
trainOutput = Output(:,trainInd);
valInput = Input(:,valInd);
valOutput = Output(:,valInd);
% Initialize the neural network with multiple hidden layers,
% for instance one with 100 neurons and another with 50
hiddenLayers = [100, 50];
ANN = fitnet(hiddenLayers,'trainrp');
% Early stopping: train() splits trainInput again internally,
% holding out 30% of it for its own validation checks
ANN.divideParam.trainRatio = 0.7;
ANN.divideParam.valRatio = 0.3;
ANN.divideParam.testRatio = 0;
% Regularization (to prevent overfitting)
ANN.performParam.regularization = 0.1;
% Convergence criteria
threshold = 0.02;      % stop if mean relative error is below this
max_iterations = 20;   % maximum number of training iterations
prevError = inf;       % initialize with a high value
tolerance = 0.001;     % minimal change to consider convergence
for i = 1:max_iterations
    % Train the neural network
    [ANN, tr] = train(ANN, trainInput, trainOutput);
    % Validate; relErr is used instead of "error" so the built-in
    % error() function is not shadowed
    predictions = ANN(valInput);
    relErr = abs(predictions - valOutput) ./ abs(valOutput); % relative error
    meanError = mean(relErr);
    % Check convergence
    if meanError < threshold || abs(prevError - meanError) < tolerance
        break; % convergence achieved
    end
    % Add validation samples with more than 10% error to the training set.
    % They also stay in valInput, so later meanError values are partly
    % measured on training data; a freshly generated validation set each
    % iteration (as in the question) avoids this.
    highError = relErr > 0.1;
    trainInput = [trainInput, valInput(:,highError)];
    trainOutput = [trainOutput, valOutput(:,highError)];
    prevError = meanError;
end
Adjust this to your own input and output data.
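Because the answer z-scores the inputs by hand, every input given to the trained network later has to be scaled with the same `meanInput` and `stdInput` computed on the training data, not with statistics of the new batch. A minimal sketch with random stand-in data; the names `newInput`/`newInputN` are hypothetical, and the guard against a zero standard deviation is an addition that is not part of the answer.

```matlab
% Training inputs (7-by-N, one sample per column); random stand-in data.
Input = rand(7, 10);

% Statistics computed once, on the training data, across samples (dimension 2).
meanInput = mean(Input, 2);
stdInput  = std(Input, 0, 2);
stdInput(stdInput == 0) = 1;   % avoid dividing by zero for a constant input row

% Normalize the training inputs.
InputN = (Input - meanInput) ./ stdInput;

% Later: a new batch reuses the SAME statistics before calling the network,
% e.g. predictions = ANN(newInputN);
newInput  = rand(7, 3);
newInputN = (newInput - meanInput) ./ stdInput;
```

Saving `meanInput` and `stdInput` together with the trained network (for example in the same .mat file) keeps the two from drifting apart.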
1 Comment
Ali Almakhmari on 17 Aug 2023 (edited 17 Aug 2023)
I am already doing something similar, except for a few things. First, I never normalized my data, which is a good idea. But I am not sure I should, because the transferFcn of my first hidden layer is "tansig", which I believe already does the normalization (although I am not 100% sure). Second, I never considered multiple layers, so that's a good suggestion to investigate. Third, I never considered "ANN.performParam.regularization"; I will read more about it. Also, why did you not scale the output?
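On the output-scaling question: `tansig` is the hidden layer's activation function and does not rescale the data itself. However, `fitnet` by default lists `mapminmax` in the input and output `processFcns` (see `ANN.inputs{1}.processFcns` and `ANN.outputs{end}.processFcns`), so inputs and targets are already scaled during training and the network's outputs are mapped back to the original units. Scaling the output by hand mainly matters if those defaults are removed. A minimal sketch of doing it manually, with made-up target values and a stand-in for the network's scaled predictions:

```matlab
% Targets as a 1-by-N row vector (made-up values).
Output = [100, 250, 400, 550];

% Z-score the targets with statistics saved from the training data.
muOut   = mean(Output);
sigOut  = std(Output);
OutputN = (Output - muOut) / sigOut;

% A network trained on OutputN predicts in the scaled space, so its
% predictions must be mapped back before computing relative errors.
predN = OutputN;                 % stands in for ANN(valInput)
pred  = predN * sigOut + muOut;  % undo the scaling
```

Relative errors should be computed on `pred` in physical units; computed on the scaled values they would be distorted, since z-scored targets can be close to zero.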



Categories

Find more on Deep Learning Toolbox in Help Center and File Exchange

Version

R2022b

