implementation of mini-batch stochastic gradient descent

konoha on 28 Mar 2021
Answered: Mohamed Salem on 22 Dec 2022
I implemented mini-batch stochastic gradient descent but couldn't find the bug in my code.
I used this implementation for a classification problem, but all my final predictions are 0.
% network parameters, uniform in [-1,1]; layer sizes 2-5-5-5-1
W2 = -1+2*rand(5,2); W3 = -1+2*rand(5,5);
W4 = -1+2*rand(5,5); W5 = -1+2*rand(1,5);
b2 = -1+2*rand(5,1); b3 = -1+2*rand(5,1);
b4 = -1+2*rand(5,1); b5 = -1+2*rand(1,1);
eta = 5e-3;        % learning rate
iter = 1000;       % number of iterations (passes over the data)
batch_size = 50;   % must divide num_data evenly for the reshape below
num_data = length(label);
loss_vec = zeros(1,iter);
tloss_vec = zeros(1,iter);
for it = 1:iter
    % mini-batch method: each column of rand_idx is one mini-batch
    rand_idx = reshape(randperm(num_data), batch_size, []);
    for idx = rand_idx
        % forward pass
        xb = [x1(:,idx); x2(:,idx)];
        a2 = activate(xb, W2, b2);
        a3 = activate(a2, W3, b3);
        a4 = activate(a3, W4, b4);
        a5 = activate(a4, W5, b5);
        % backward pass (gradient of the squared error)
        delta5 = a5.*(1-a5).*(a5-label(idx));
        delta4 = a4.*(1-a4).*(W5'*delta5);
        delta3 = a3.*(1-a3).*(W4'*delta4);
        delta2 = a2.*(1-a2).*(W3'*delta3);
        % update weights and biases with the batch-averaged gradient
        W2 = W2 - eta/batch_size*delta2*xb';
        W3 = W3 - eta/batch_size*delta3*a2';
        W4 = W4 - eta/batch_size*delta4*a3';
        W5 = W5 - eta/batch_size*delta5*a4';
        b2 = b2 - eta/batch_size*sum(delta2,2);
        b3 = b3 - eta/batch_size*sum(delta3,2);
        b4 = b4 - eta/batch_size*sum(delta4,2);
        b5 = b5 - eta/batch_size*sum(delta5,2);
    end
    % compute train loss and test loss once per pass
    % (the recorded values are the same as when computed after every batch)
    loss_vec(it) = 1/(2*num_data)*LossFunc(W2,W3,W4,W5,b2,b3,b4,b5,[x1;x2],label);
    tloss_vec(it) = 1/(2*200)*LossFunc(W2,W3,W4,W5,b2,b3,b4,b5,[tx1;tx2],tlabel);  % 200: hard-coded test-set size
end
%% cost function
function loss = LossFunc(W2,W3,W4,W5,b2,b3,b4,b5,x,y)
    % sum of squared errors over all columns of x
    a2 = activate(x, W2, b2);
    a3 = activate(a2, W3, b3);
    a4 = activate(a3, W4, b4);
    a5 = activate(a4, W5, b5);
    loss = norm(a5-y,2)^2;
end
%% prediction
function pred = predict(W2,W3,W4,W5,b2,b3,b4,b5,x)
    % forward pass, then threshold the sigmoid output at 0.5
    a2 = activate(x, W2, b2);
    a3 = activate(a2, W3, b3);
    a4 = activate(a3, W4, b4);
    a5 = activate(a4, W5, b5);
    pred = round(a5);
end
%% activation function
function y = activate(x,W,b)
    % logistic sigmoid applied to the affine map W*x+b
    y = 1./(1+exp(-(W*x+b)));
end
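
One way to look for a bug in the backward pass is a finite-difference gradient check. The sketch below is not part of the original post. It rebuilds the same 2-5-5-5-1 sigmoid network on made-up random data (X, y, N, h, sigm, and lossAt are names introduced here for illustration) and compares the analytic gradient for W2, obtained from the delta formulas above, with a central-difference estimate of the loss 1/(2N)*sum((a5-y).^2).

```matlab
% Gradient check for the 2-5-5-5-1 sigmoid network (illustrative data only).
rng(0);                                   % reproducible random numbers
sigm = @(z) 1./(1+exp(-z));               % same formula as activate()
N = 20;                                   % hypothetical batch size
X = rand(2,N);                            % made-up inputs
y = double(rand(1,N) > 0.5);              % made-up 0/1 labels (row vector)

W2 = -1+2*rand(5,2); W3 = -1+2*rand(5,5);
W4 = -1+2*rand(5,5); W5 = -1+2*rand(1,5);
b2 = -1+2*rand(5,1); b3 = -1+2*rand(5,1);
b4 = -1+2*rand(5,1); b5 = -1+2*rand(1,1);

% forward pass
a2 = sigm(W2*X+b2);  a3 = sigm(W3*a2+b3);
a4 = sigm(W4*a3+b4); a5 = sigm(W5*a4+b5);

% backward pass, same formulas as in the question
delta5 = a5.*(1-a5).*(a5-y);
delta4 = a4.*(1-a4).*(W5'*delta5);
delta3 = a3.*(1-a3).*(W4'*delta4);
delta2 = a2.*(1-a2).*(W3'*delta3);
gW2 = delta2*X'/N;                        % analytic gradient w.r.t. W2

% loss as a function of W2 only: 1/(2N) * sum of squared errors
lossAt = @(W) sum((sigm(W5*sigm(W4*sigm(W3*sigm(W*X+b2)+b3)+b4)+b5) - y).^2)/(2*N);

% central finite differences, one entry of W2 at a time
h = 1e-5;
gNum = zeros(size(W2));
for k = 1:numel(W2)
    E = zeros(size(W2));  E(k) = h;
    gNum(k) = (lossAt(W2+E) - lossAt(W2-E))/(2*h);
end
relErr = norm(gW2(:)-gNum(:))/max(norm(gNum(:)), eps);
fprintf('relative error = %.2e\n', relErr);
```

A small relative error suggests that the delta formulas agree with the squared-error loss being minimized. This is only a sanity check of the gradient, not a diagnosis of why the predictions are all 0.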

Answers (2)

Mahesh Taparia on 2 Apr 2021
Hi
You mentioned that you are implementing a classification network. In your code, you use the squared L2 norm to calculate the loss, and the loss derivative used in backpropagation is also not correct. Moreover, since it is a classification network, use a classification loss such as cross-entropy or focal cross-entropy instead of the norm. Maybe this is the reason you are getting 0 every time.
Also, you can use MATLAB's built-in functions to perform backpropagation. For this, you can refer to the link given below:
Hope it helps!
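
To make the cross-entropy suggestion concrete: with a sigmoid output and binary cross-entropy, the derivative of the loss with respect to the output pre-activation simplifies to a - y, without the a.*(1-a) factor. The snippet below is a sketch with made-up values (z, y, sigm, bce, and h are illustrative names, not from this thread) that checks the simplification numerically and prints the squared-error delta next to it.

```matlab
% Binary cross-entropy with a sigmoid output: dL/dz = a - y (illustrative values).
sigm = @(z) 1./(1+exp(-z));
bce  = @(z,y) -sum(y.*log(sigm(z)) + (1-y).*log(1-sigm(z)));  % summed over samples

z = [-2.0 -0.5 0.3 1.5];      % made-up pre-activations of the output unit
y = [ 0    1   0   1  ];      % made-up 0/1 labels

deltaCE  = sigm(z) - y;                            % cross-entropy delta
deltaMSE = sigm(z).*(1-sigm(z)).*(sigm(z) - y);    % squared-error delta, as in the question

% central finite differences of the cross-entropy loss w.r.t. each z(k)
h = 1e-6;
gNum = zeros(size(z));
for k = 1:numel(z)
    e = zeros(size(z));  e(k) = h;
    gNum(k) = (bce(z+e,y) - bce(z-e,y))/(2*h);
end
disp([deltaCE; gNum; deltaMSE]);   % rows: analytic CE, numeric CE, squared error
```

Under these assumptions, switching the question's training loop to cross-entropy would amount to replacing the delta5 line with delta5 = a5 - label(idx) while leaving the other deltas and updates unchanged. The squared-error delta is smaller by the factor a5.*(1-a5), which is at most 0.25.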
  1 Comment
konoha on 2 Apr 2021
Edited: konoha on 2 Apr 2021
The derivative of the MSE is -(y-f(x))f'(x). I don't follow your suggestions.
Thank you.
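
For reference, the formula in this comment can be written out for a single sample with a sigmoid output. The symbols L, z, and a are notation introduced for this note: z is the pre-activation of the output unit and a = f(x) is its output.

```latex
L = \tfrac{1}{2}\,(a - y)^2, \qquad
a = \sigma(z) = \frac{1}{1 + e^{-z}}, \qquad
\sigma'(z) = a\,(1 - a)

\frac{\partial L}{\partial z}
  = (a - y)\,\sigma'(z)
  = -(y - a)\,a\,(1 - a)
```

This matches the expression used for delta5 in the question, a5.*(1-a5).*(a5-label(idx)).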


Mohamed Salem on 22 Dec 2022
Write MATLAB code that implements the delta learning rule with mini-batches.
Compare (with a graph) your mini-batch algorithm with the SGD and batch algorithms in terms of mean squared error.
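
One possible reading of this exercise is sketched below, under stated assumptions: a single sigmoid neuron trained with the delta rule on made-up, linearly separable toy data, where SGD, batch, and mini-batch differ only in the batch size (1, N, and 20). All names and values (X, d, eta, epochs, batchSizes, mse) are illustrative choices, not taken from this thread.

```matlab
% Delta rule for one sigmoid neuron: SGD vs. batch vs. mini-batch (toy data).
rng(0);                                     % reproducible random numbers
sigm = @(v) 1./(1+exp(-v));
N = 200;
X = [rand(2,N); ones(1,N)];                 % made-up inputs, last row is the bias input
d = double(X(1,:) + X(2,:) > 1);            % made-up 0/1 targets (row vector)
eta = 0.5;  epochs = 200;
w0 = -1+2*rand(1,3);                        % shared starting weights
batchSizes = [1 N 20];                      % SGD, batch, mini-batch
mse = zeros(numel(batchSizes), epochs);

for m = 1:numel(batchSizes)
    w = w0;  bs = batchSizes(m);
    for ep = 1:epochs
        order = reshape(randperm(N), bs, []);      % each column is one batch
        for idx = order
            xb = X(:,idx);                         % inputs of this batch
            yb = sigm(w*xb);                       % neuron outputs
            delta = yb.*(1-yb).*(d(:,idx) - yb);   % delta = phi'(v).*(d - y)
            w = w + eta*delta*xb'/bs;              % batch-averaged delta-rule update
        end
        mse(m,ep) = mean((d - sigm(w*X)).^2);      % training MSE after this epoch
    end
end
fprintf('final MSE  SGD: %.4f  batch: %.4f  mini-batch: %.4f\n', mse(:,end));
```

For the graph, the rows of mse can be plotted against the epoch number, for example with plot(1:epochs, mse') followed by legend('SGD','batch','mini-batch'), xlabel('epoch'), and ylabel('MSE'). The comparison here is per epoch, so SGD performs N weight updates per epoch while the batch variant performs one.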

Version

R2020b
