What are the differences between fitrnet and trainnet?

Matthew on 17 Jan 2025
Commented: Matthew on 30 Jan 2025
I am trying to perform regression on a set of data X (size 52048 x 4) with responses Y (size 52048 x 1).
If I perform a linear regression using mvregress(X,Y), I get a model with a coefficient of determination of 0.70.
If I train a neural network using fitrnet(X,Y), I get a model with a coefficient of determination of about 0.75.
mdl2 = fitrnet(X,Y);
I would like more customizability than what fitrnet offers, so I am trying to use trainnet. To make sure that I am using it correctly, I wanted to recreate the result from fitrnet using trainnet, so I set up the trainnet input parameters to match the defaults from fitrnet:
layers = [
featureInputLayer(size(X,2))
fullyConnectedLayer(10)
reluLayer
fullyConnectedLayer(1)];
options = trainingOptions('lbfgs', GradientTolerance=1e-6, StepTolerance=1e-6);
mdl3 = trainnet(X,Y,layers,'mse',options);
When I run this, after a few iterations I get a "Training stopped" message and the output model has a coefficient of determination less than 0.30. Sometimes the stoppage message says "Step tolerance reached" and sometimes it says "Suitable learning rate not found".
However, if I initialize the weights and biases for trainnet using the weights and biases learned by fitrnet, trainnet immediately recognizes this as an optimal solution and outputs "Training stopped: Suitable learning rate not found" without making any modification to the weights or biases.
layers = [
featureInputLayer(size(X,2))
fullyConnectedLayer(10, 'Weights', mdl2.LayerWeights{1}, 'Bias', mdl2.LayerBiases{1})
reluLayer
fullyConnectedLayer(1, 'Weights', mdl2.LayerWeights{2}, 'Bias', mdl2.LayerBiases{2})];
mdl4 = trainnet(X,Y,layers,'mse',options);
How is it that trainnet can recognize the solution from fitrnet as optimal, but cannot recreate a similar result using what seems to be identical input parameters?

Answers (2)

Jaimin on 30 Jan 2025
The situation you are encountering is common when working with neural networks and different training functions. Here are a few reasons why "trainnet" might not be reproducing the results of "fitrnet" despite having seemingly identical configurations:
  1. Neural networks are sensitive to the initial weights and biases. "fitrnet" might be using a different initialization strategy that happens to work well for your data, while "trainnet" might be initializing weights in a way that leads to poor optimization. When you use the weights from "fitrnet", "trainnet" recognizes them as effective because they are already well-tuned.
  2. Even if you set the optimizer to "lbfgs" in both cases, there might be subtle differences in the implementation or default settings (e.g., learning rate schedules, regularization techniques) between "fitrnet" and "trainnet". These differences can significantly affect the convergence behavior.
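One way to take initialization out of the comparison is to build the starting weights explicitly rather than letting each function draw its own. The sketch below is only an illustration: it draws Glorot-uniform weights by hand (as far as I know, the default scheme for fully connected layers in both functions) and passes them to the layers, so every trainnet run starts from the same point. The variable names and the helper `glorot` are mine, not part of either toolbox.

```matlab
rng(0);                                  % fix the seed so the draw is repeatable
numIn = 4; numHidden = 10; numOut = 1;   % sizes taken from the question

% Glorot (Xavier) uniform: U(-b, b) with b = sqrt(6 / (fanIn + fanOut))
glorot = @(rows, cols) sqrt(6 / (rows + cols)) * (2 * rand(rows, cols) - 1);

W1 = glorot(numHidden, numIn);   b1 = zeros(numHidden, 1);
W2 = glorot(numOut, numHidden);  b2 = zeros(numOut, 1);

% Same starting point on every run (requires Deep Learning Toolbox)
layers = [
    featureInputLayer(numIn)
    fullyConnectedLayer(numHidden, Weights=W1, Bias=b1)
    reluLayer
    fullyConnectedLayer(numOut, Weights=W2, Bias=b2)];
```

Repeating the training with several seeds then shows whether the poor fits come from unlucky starting points or from something else in the setup.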
Although I cannot pinpoint the exact reason, I ran an experiment showing that "trainnet" and "fitrnet" yield approximately the same results. Kindly refer to the following code snippet for clarification.
% Generate synthetic data
rng(0); % For reproducibility
nSamples = 52048;
nFeatures = 4;
X = rand(nSamples, nFeatures); % Features
Y = X * [0.5; -0.3; 0.7; 0.2] + 0.1 * randn(nSamples, 1); % Linear combination with noise
% Normalize the data
X = normalize(X);
Y = normalize(Y);
% Train using fitrnet
mdl1 = fitrnet(X, Y, 'LayerSizes', 10, 'Activations', 'relu', 'Verbose', 1);
% Evaluate fitrnet model
YPred1 = predict(mdl1, X);
R2_fitrnet = 1 - sum((Y - YPred1).^2) / sum((Y - mean(Y)).^2);
fprintf('fitrnet R^2: %.2f\n', R2_fitrnet);
% Extract weights and biases from fitrnet
fitrnetWeights1 = mdl1.LayerWeights{1};
fitrnetBiases1 = mdl1.LayerBiases{1};
fitrnetWeights2 = mdl1.LayerWeights{2};
fitrnetBiases2 = mdl1.LayerBiases{2};
% Define the network architecture for trainnet
layers = [
featureInputLayer(nFeatures)
fullyConnectedLayer(10, 'Weights', fitrnetWeights1, 'Bias', fitrnetBiases1)
reluLayer
fullyConnectedLayer(1, 'Weights', fitrnetWeights2, 'Bias', fitrnetBiases2)];
% Define training options for trainnet
options = trainingOptions('adam', ...
'InitialLearnRate', 0.01, ...
'MaxEpochs', 100, ...
'MiniBatchSize', 512, ...
'Verbose', true, ...
'Plots', 'training-progress');
% Train using trainnet
mdl2 = trainnet(X, Y, layers, 'mse', options);
% Evaluate trainnet model
YPred2 = predict(mdl2, X);
R2_trainnet = 1 - sum((Y - YPred2).^2) / sum((Y - mean(Y)).^2);
fprintf('trainnet (fitrnet initial weights) R^2: %.2f\n', R2_trainnet);
% Define the same architecture again, this time with default (random) initialization
layers = [
featureInputLayer(nFeatures)
fullyConnectedLayer(10)
reluLayer
fullyConnectedLayer(1)];
% Use the same training options as in the previous run
options = trainingOptions('adam', ...
'InitialLearnRate', 0.01, ...
'MaxEpochs', 100, ...
'MiniBatchSize', 512, ...
'Verbose', true, ...
'Plots', 'training-progress');
% Train using trainnet from the default initialization
mdl3 = trainnet(X, Y, layers, 'mse', options);
% Evaluate the trainnet model trained from the default initialization
YPred3 = predict(mdl3, X);
R2_trainnet_default = 1 - sum((Y - YPred3).^2) / sum((Y - mean(Y)).^2);
fprintf('trainnet (default initialization) R^2: %.2f\n', R2_trainnet_default);
For more information, kindly refer to the following MathWorks documentation.
Thanks!
  1 Comment
Matthew on 30 Jan 2025
Thanks for providing this nice example!
I ran your script, then substituted in my own data and played around with the parameters a bit. It seems that I can get fitrnet and trainnet to provide similar solutions if I apply the z-score normalization to the data as you do here, but not otherwise.
It is still not clear to me precisely why this is the case, but I can at least continue what I was attempting to do before now that I can obtain consistent behavior between the two algorithms.
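For anyone following along, the z-score step can also be done by hand, which makes it explicit how predictions are mapped back to the original units. This is a minimal sketch on made-up, badly scaled data (not my real data set); the least-squares fit is only a stand-in so the snippet runs without any toolbox, and it marks the spot where the fitrnet or trainnet call would go.

```matlab
% Hypothetical stand-in data with features on very different scales
rng(0);
X = [1e3 1 1e-2 50] .* rand(200, 4) + [500 0 0 -20];
Y = X * [0.002; 0.5; 30; 0.01] + 0.1 * randn(200, 1);

% z-score statistics computed from the training data
muX = mean(X, 1);   sigmaX = std(X, 0, 1);
muY = mean(Y);      sigmaY = std(Y);

Xn = (X - muX) ./ sigmaX;   % each column now has mean 0 and std 1
Yn = (Y - muY) ./ sigmaY;

% Train on (Xn, Yn) here, e.g.
%   net = trainnet(Xn, Yn, layers, "mse", options);
%   YPredN = predict(net, Xn);
% Placeholder model: ordinary least squares on the normalized data
A = [ones(size(Xn, 1), 1) Xn];
beta = A \ Yn;
YPredN = A * beta;

% Map predictions back to the original response units before scoring
YPred = YPredN * sigmaY + muY;
R2 = 1 - sum((Y - YPred).^2) / sum((Y - mean(Y)).^2);
fprintf('R^2 in original units: %.2f\n', R2);
```

If I recall the option correctly, featureInputLayer(4, Normalization="zscore") applies the same input scaling inside the network, although the response would still need to be scaled separately as above.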



Matt J on 17 Jan 2025
Unlike fitrnet, the L-BFGS algorithm used by trainnet does not seem to use a line search. That might have had something to do with it...
  1 Comment
Matthew on 17 Jan 2025
The trainingOptions function for trainnet allows you to explicitly specify the line search method using the LineSearchMethod option. The default is "weak-wolfe".
It does not seem possible to specify the line search method when using fitrnet, and I have not been able to trace what is happening through the source code, but a quick 'grep' of the stats toolbox source code shows that both classreg.learning.fsutils.Solver.m and classreg.learning.fsutils.fminlbfgs.m use "weak-wolfe" as the default line search method.
So I believe both fitrnet and trainnet are using the same line search method for the L-BFGS by default.
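To rule the line search out as a source of difference, the L-BFGS settings can be written out explicitly instead of being left to the defaults. A minimal sketch, assuming a release in which trainingOptions supports the "lbfgs" solver; the tolerances repeat the values from my question, and MaxIterations=1000 is my assumption for mirroring the fitrnet iteration limit.

```matlab
% Spell out the L-BFGS settings instead of relying on defaults
options = trainingOptions("lbfgs", ...
    LineSearchMethod="weak-wolfe", ...   % the default named above, set explicitly
    GradientTolerance=1e-6, ...
    StepTolerance=1e-6, ...
    MaxIterations=1000, ...              % assumed to mirror fitrnet's iteration limit
    Verbose=true);

% Confirm what the solver will actually use
disp(options.LineSearchMethod)
```

As far as I can tell from the documentation, "strong-wolfe" and "backtracking" are the other accepted values, so swapping them in is a quick way to check whether the line search affects where training stops.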
