MATLAB implementation of Hebbian learning

Matteo Barisione on 18 Jun 2020
I'm trying to complete a task where I'm asked to implement basic Hebbian learning for a single neuron (linear firing-rate model) with two inputs. I've been given the training set, a 2x100 matrix of input patterns, which is shuffled at each epoch. The main request is to plot the final weight vector against the main eigenvector of Q, the input correlation matrix. And here comes my issue: Q is defined as Q = <u u'>, where the angle brackets denote the average over the input patterns. I'm stuck on correctly computing this matrix because I can't properly understand how this "average over the input patterns" is taken...
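For reference, my current reading of that definition (which may well be wrong) is that the angle brackets denote the empirical mean of the outer products over the N = 100 training patterns, something like this sketch (u is the 2xN data matrix, as in the script further below):
% Possible reading of Q = <u u'>: empirical average of the outer products
% over the N input patterns (u is the 2xN data matrix)
N = size(u,2);              % number of input patterns
Q = zeros(2,2);             % 2x2 input correlation matrix
for k = 1:N
    Q = Q + u(:,k)*u(:,k)'; % accumulate the outer product of pattern k
end
Q = Q/N;                    % average over patterns; same as (u*u')/N
Since rescaling a matrix by 1/N does not change its eigenvectors, u*u' and (u*u')/N should give the same dominant eigenvector direction, only with rescaled eigenvalues.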
I split the code as follows. This is the function that actually performs the learning:
function [w_out,w_t,w_norm] = hebbian(xtrain,eta)
    train_size = size(xtrain,2);   % Number of training patterns
    w = -1 + 2.*rand(2,1);         % Random weight initialization in [-1,1]
    epochs = 1000;                 % Maximum number of epochs
    w_t = zeros(2,epochs);         % Normalized weight vector after each epoch
    w_norm = zeros(1,epochs);      % Weight norm after each epoch
    for i = 1:epochs
        randp = randperm(train_size);      % Random permutation of the pattern indices
        for k = 1:train_size
            xtrain_k = xtrain(:,randp(k)); % k-th pattern of the shuffled training set
            v = w'*xtrain_k;               % Output of the linear firing-rate model
            w = w + eta*v*xtrain_k;        % Hebbian update, accumulated pattern by pattern
        end
        w_norm(i) = norm(w);               % Track the (growing) weight norm
        w_out = w/norm(w);                 % Normalized weight vector
        w_t(:,i) = w_out;                  % Store its trajectory
    end
end
And this is the main .m file. I think the learning algorithm itself works, but I can't say whether I'm comparing the correct objects (see also my note after the script below).
% Basic Hebbian learning rule implementation
T = readtable('lab2_1_data.csv'); % Importing data as a table
u = table2array(T); % Converting the table into the input array
eta = 10e-3; % Learning rate
[w, w_t, w_norm] = hebbian(u,eta);
Q = u*u'; % Input correlation matrix (as written, this is the sum over the patterns rather than their average; this is the part I'm unsure about)
[vec, D] = eig(Q); % Computing eigenvalues and eigenvectors of Q
% Plotting data points and comparison between final weight vector and main
% eigenvector of Q
figure('Name','P1): Dataset, final weight vector and main eigenvector of Q','NumberTitle','off')
scatter(u(1,:),u(2,:))
hold on
plotv(vec(:,end)); % Assuming eig returns the eigenvalues in ascending order, so the last column is the dominant eigenvector
set(findall(gca,'Type', 'Line'),'LineWidth',1.75);
plotv(w)
hold off
legend('Dataset','Dominant eigenvector of Q','Final weight vector','Location','best')
% Weight evolution, first component
figure('Name','P2.1): Weight vector time evolution (1st component)','NumberTitle','off')
plot(w_t(1,:))
% Weight evolution, second component
figure('Name','P2.2): Weight vector time evolution (2nd component)','NumberTitle','off')
plot(w_t(2,:))
% Weight norm evolution
figure('Name','P2.3): Weight vector norm time evolution','NumberTitle','off')
plot(w_norm)
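One more thing I'm unsure about in the comparison: I take vec(:,end) as the dominant eigenvector, which relies on eig returning the eigenvalues in ascending order. Explicitly selecting the eigenvector associated with the largest eigenvalue is probably safer; a small sketch of what I mean (variable names as in the script above):
% Explicitly select the eigenvector with the largest eigenvalue
[vec, D] = eig(Q);        % eigenvectors (columns of vec) and eigenvalues (diagonal of D)
[~, idx] = max(diag(D));  % index of the largest eigenvalue
q_main = vec(:, idx);     % dominant (principal) eigenvector of Q
% Eigenvectors are defined only up to sign, so the learned direction could be
% checked with abs(q_main'*w) being close to 1 for the unit-norm w.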
Thanks in advance for your help and attention!

Answers (0)

Version: R2019b
