How to get convergence in a gradient descent algorithm?

Robert101 on 24 Sep 2020
Commented: Robert101 on 24 Sep 2020
I am trying to implement the gradient descent algorithm for linear regression. The X axis represents the year and the Y axis the housing price. Y_prediction is supposed to converge as the number of iterations increases. However, MATLAB reports a matrix dimension error and I do not know how to fix it. I am following the Python code from this link: https://towardsdatascience.com/linear-regression-using-gradient-descent-97a6c8700931
Theta_0 = 0
Theta_1 = 0
learning_rate = 0.001
X = [2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013] % Year
Y = [2.00 2.500 2.900 3.147 4.515 4.903 5.365 5.704 6.853 7.971 8.561 10.00 11.280 12.900] % Price
n = length(X)
for i = 1:100
    Y_prediction = Theta_1.*X + Theta_0 % Y = mx + c
    Derivative_Theta_0 = (1/n)*sum(Y_prediction - Y)
    Derivative_Theta_1 = (1/n)*sum(X.*(Y_prediction - Y))
    Theta_0(i+1) = Theta_0(i) - learning_rate*Derivative_Theta_0
    Theta_1(i+1) = Theta_1(i) - learning_rate*Derivative_Theta_1
end
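The dimension error comes from the lines Theta_0(i+1) = ... and Theta_1(i+1) = ...: they grow Theta_0 and Theta_1 into vectors, so on the second pass Theta_1.*X multiplies a 1-by-2 vector by a 1-by-14 vector. The sketch that follows keeps the parameter history but only uses the current scalar values Theta_1(i) and Theta_0(i). The preallocation and the iters variable are additions for illustration, not part of the original code. It runs without the dimension error, but with years around 2000 and learning_rate = 0.001 the updates overshoot and grow until they overflow, so the parameters end up as Inf or NaN instead of converging. That is the scaling problem identified in the comments on the accepted answer.

```matlab
X = 2000:2013;  % Year (same values as in the question)
Y = [2.00 2.500 2.900 3.147 4.515 4.903 5.365 5.704 6.853 7.971 8.561 10.00 11.280 12.900];  % Price
n = numel(X);
learning_rate = 0.001;
iters = 100;                      % hypothetical name for the iteration count
% Preallocate the parameter history; element i holds the value after i-1 updates
Theta_0 = zeros(1, iters + 1);
Theta_1 = zeros(1, iters + 1);
for i = 1:iters
    % Use the current scalar values, not the whole history vectors
    Y_prediction = Theta_1(i) .* X + Theta_0(i);   % Y = m*x + c
    Derivative_Theta_0 = (1/n) * sum(Y_prediction - Y);
    Derivative_Theta_1 = (1/n) * sum(X .* (Y_prediction - Y));
    Theta_0(i+1) = Theta_0(i) - learning_rate * Derivative_Theta_0;
    Theta_1(i+1) = Theta_1(i) - learning_rate * Derivative_Theta_1;
end
% With raw years this step size is far too large: the values blow up to Inf/NaN
fprintf('Theta_0 = %g, Theta_1 = %g\n', Theta_0(end), Theta_1(end));
```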

Accepted Answer

Michael Croucher on 24 Sep 2020
Edited: Michael Croucher on 24 Sep 2020
Here's my MATLAB port of that code:
data=dlmread('data.csv');
X = data(:,1);
Y = data(:,2);
%Building the model
m = 0;
c = 0;
L = 0.0001; % The learning Rate
epochs = 1000; % The number of iterations to perform gradient descent
n = numel(X); % Number of elements in X
%Performing Gradient Descent
for i=1:epochs
    Y_pred = m.*X + c; % The current predicted value of Y
    D_m = (-2/n) * sum(X .* (Y - Y_pred)); % Derivative wrt m
    D_c = (-2/n) * sum(Y - Y_pred); % Derivative wrt c
    m = m - L * D_m; % Update m
    c = c - L * D_c; % Update c
end
sprintf('m=%f c=%f', m,c)
This gives the following result
ans =
'm=1.477744 c=0.088937'
You might notice that this differs from the result given in the blog post. That's because of a bug in the original Python code: the pandas read_csv call it uses treats the first x,y pair as the header row (the column names) of the data frame, so the first data point is left out of the Python analysis.
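For reference, D_m and D_c in the port are the partial derivatives of the mean squared error of the line Y = m*X + c, and each epoch takes one step of size L against them:

$$E(m,c) = \frac{1}{n}\sum_{i=1}^{n}\bigl(y_i - \hat y_i\bigr)^2, \qquad \hat y_i = m x_i + c$$

$$\frac{\partial E}{\partial m} = -\frac{2}{n}\sum_{i=1}^{n} x_i\bigl(y_i - \hat y_i\bigr), \qquad \frac{\partial E}{\partial c} = -\frac{2}{n}\sum_{i=1}^{n}\bigl(y_i - \hat y_i\bigr)$$

$$m \leftarrow m - L\,\frac{\partial E}{\partial m}, \qquad c \leftarrow c - L\,\frac{\partial E}{\partial c}$$

The question's Derivative_Theta_0 and Derivative_Theta_1 use (1/n)*sum(Y_prediction - Y) without the factor 2, so they point in the same direction and differ only by a constant factor that is absorbed into the learning rate.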
Michael Croucher on 24 Sep 2020
Edited: Michael Croucher on 24 Sep 2020
If you don't want to scale your X values, you can scale the learning rates instead. I think the convergence issue comes from the fact that your X values (around 2000) are roughly 1000 times larger than your Y values, and the gradient with respect to m is multiplied by X. A step that is small for c is therefore a huge step for m, so we use a separate learning rate for each parameter:
m = 0; % Initial slope
c = 0; % Initial intercept
X = [2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013] % Year
Y = [2.00 2.500 2.900 3.147 4.515 4.903 5.365 5.704 6.853 7.971 8.561 10.00 11.280 12.900] % Price
L = 0.1; % Learning rate for c
L2 = 0.0000001; % Learning rate for m (much smaller because X is large)
epochs = 20000000; % The number of iterations to perform gradient descent
n = numel(X); % Number of elements in X
%Performing Gradient Descent
for i=1:epochs
    Y_pred = m.*X + c; % The current predicted value of Y
    D_m = (-2/n) * sum(X .* (Y - Y_pred)); % Derivative wrt m
    D_c = (-2/n) * sum(Y - Y_pred); % Derivative wrt c
    m = m - L2 * D_m; % Update m
    c = c - L * D_c; % Update c
end
sprintf('m=%f c=%f', m, c)
This gives:
'm=0.799002 c=-1596.869956'
The slope m is essentially the same as before, as we might expect, while c is different because it now refers to year 0, 2000 years earlier than the origin of the scaled model. If you play with the numbers, you'll see that the two models are compatible.
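A minimal sketch of the scaled alternative, assuming the scaling is simply a shift of the years by 2000 (the exact scaling from the earlier comment is not shown in this thread, and x0 is a name introduced here). It reuses the update rules from the accepted answer with a single learning rate of 0.01, an illustrative choice, and then converts the intercept back to the original years with c - m*x0. The result should land close to the m=0.799002 c=-1596.869956 reported above.

```matlab
X = 2000:2013;   % Year
Y = [2.00 2.500 2.900 3.147 4.515 4.903 5.365 5.704 6.853 7.971 8.561 10.00 11.280 12.900];  % Price
x0 = 2000;       % assumed shift: measure time in years since 2000
Xs = X - x0;     % shifted years, 0..13
n = numel(Xs);
m = 0;           % Initial slope
c = 0;           % Initial intercept (at year x0)
L = 0.01;        % one learning rate suffices once X is small
epochs = 5000;
for i = 1:epochs
    Y_pred = m .* Xs + c;                     % current prediction
    D_m = (-2/n) * sum(Xs .* (Y - Y_pred));   % derivative wrt m
    D_c = (-2/n) * sum(Y - Y_pred);           % derivative wrt c
    m = m - L * D_m;
    c = c - L * D_c;
end
c_orig = c - m * x0;   % intercept expressed for the unshifted years
fprintf('shifted model:  m=%f c=%f\n', m, c);
fprintf('original scale: m=%f c=%f\n', m, c_orig);
```

With the years shifted to 0..13, the slope and intercept gradients stay on comparable scales, so a few thousand epochs with one learning rate are enough instead of twenty million with two.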
However, if your aim is simply to do linear regression rather than to experiment with gradient descent, you can follow the notes at https://www.mathworks.com/help/matlab/data_analysis/linear-regression.html and solve the least-squares problem directly with the backslash operator:
X = [2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013]';% Year
Y = [2.00 2.500 2.900 3.147 4.515 4.903 5.365 5.704 6.853 7.971 8.561 10.00 11.280 12.900]'; % Price
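Xp = [ones(length(X),1) X]; % Design matrix: a column of ones (intercept) and the years
Xp\Y % Least-squares solution: [intercept; slope]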
This will give
ans =
1.0e+03 *
-1.596873819780291
0.000799004395604
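For a straight-line fit, the backslash result can also be cross-checked against the closed-form least-squares formulas m = Sxy/Sxx and c = mean(Y) - m*mean(X), and against polyfit. A short sketch using the same data (the variable names Xc, b and p are introduced here for illustration):

```matlab
X = (2000:2013)';   % Year (column vector)
Y = [2.00 2.500 2.900 3.147 4.515 4.903 5.365 5.704 6.853 7.971 8.561 10.00 11.280 12.900]';  % Price
% Closed-form least-squares slope and intercept for Y = m*X + c
Xc = X - mean(X);                            % centred years
m = sum(Xc .* (Y - mean(Y))) / sum(Xc.^2);   % slope = Sxy / Sxx
c = mean(Y) - m * mean(X);                   % intercept
% Cross-check against the backslash solution and polyfit
b = [ones(numel(X),1) X] \ Y;                % b(1) = intercept, b(2) = slope
p = polyfit(X, Y, 1);                        % p(1) = slope, p(2) = intercept
fprintf('closed form: m=%.9f c=%.6f\n', m, c);
fprintf('backslash:   m=%.9f c=%.6f\n', b(2), b(1));
fprintf('polyfit:     m=%.9f c=%.6f\n', p(1), p(2));
```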
Robert101 on 24 Sep 2020
Yes, scaling was the problem. You spotted it.
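To put a number on the scaling problem: for the mean squared error used in the accepted answer, the second derivative with respect to m is (2/n)*sum(X.^2) and with respect to c is 2, independent of Y. A plain gradient step on a single parameter is only stable when its learning rate is below 2 divided by that curvature. The sketch that follows computes both for the years in the question. It treats each parameter in isolation and ignores the coupling between m and c, so it is a rough guide rather than an exact convergence condition; curv_m, curv_c, ratio, L_max_m and L_max_c are names introduced here.

```matlab
X = 2000:2013;            % Year, as in the question
n = numel(X);
% Curvature of E(m,c) = (1/n)*sum((Y - (m*X + c)).^2) along each parameter
curv_m = (2/n) * sum(X.^2);   % d2E/dm2
curv_c = 2;                   % d2E/dc2
ratio = curv_m / curv_c;      % how much more sharply E bends along m than along c
% One-parameter gradient descent is stable only for 0 < L < 2/curvature
L_max_m = 2 / curv_m;
L_max_c = 2 / curv_c;
fprintf('curvature ratio m:c = %.3g\n', ratio);
fprintf('largest stable step for m alone: %.3g, for c alone: %.3g\n', L_max_m, L_max_c);
```

For these years the curvature along m is roughly four million times larger than along c, and the bound for m alone is about 2.5e-7. L2 = 0.0000001 in the two-rate comment sits below it, while the question's learning_rate = 0.001 is far above it (the question's derivative omits the factor 2, which only doubles the bound).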
