Custom DQN Algorithm Not Learning or Converging

roham farhadi on 9 Jan 2024
Hello MathWorks community,
I am currently implementing a custom Deep Q-Network (DQN) algorithm for a specific problem (the code below uses the predefined CartPole-Discrete environment), but the algorithm does not seem to learn or converge.
I have included my code below. I would appreciate it if someone could take a look and offer insights into why the algorithm might not be performing as expected. I am also open to suggestions for any improvements or modifications that could enhance its learning.
clc, clear, close all
% environment:
rngSeed = 1;
rng(rngSeed);
env = rlPredefinedEnv("CartPole-Discrete");
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
numObservations = obsInfo.Dimension(1);
% build network:
QNetwork = [
    featureInputLayer(numObservations)
    fullyConnectedLayer(20)
    reluLayer
    fullyConnectedLayer(24)
    reluLayer
    fullyConnectedLayer(numel(actInfo.Elements))];
QNetwork = dlnetwork(QNetwork);
% buffer
myBuffer.bufferSize = 1e5;
myBuffer.bufferIndex = 0;
myBuffer.currentBufferSize = 0;
myBuffer.observation = zeros(numObservations, myBuffer.bufferSize);
myBuffer.nextObservation = zeros(numObservations, myBuffer.bufferSize);
myBuffer.action = zeros(1, myBuffer.bufferSize);
myBuffer.reward = zeros(1, myBuffer.bufferSize);
myBuffer.isDone = zeros(1, myBuffer.bufferSize);
% parameters
num_episodes = 100;
max_steps = 500;
batch_size = 256;
discountFactor = 0.99;
epsilon = 1;
epsilonMin = 0.01;
epsilonDecay = 0.005;
totalSteps = 0;
numGradientSteps = 5;
targetUpdateFrequency = 4;
target_QNetwork = QNetwork;
iteration = 0;
% Plot
monitor = trainingProgressMonitor(Metrics="Loss",Info="Episode",XLabel="Iteration");
[trainingPlot,lineReward,lineAveReward, ax] = hBuildFigure;
set(trainingPlot,Visible = "on");
episodeCumulativeRewardVector = [];
aveWindowSize = 10;
% training loop
for episode = 1:num_episodes
    observation = reset(env);
    episodeReward = zeros(max_steps,1);
    for stepCt = 1:max_steps
        totalSteps = totalSteps + 1;
        action = policy(QNetwork, observation', actInfo, epsilon);
        if totalSteps > batch_size
            epsilon = max(epsilon*(1-epsilonDecay), epsilonMin);
        end
        [nextObservation, reward, isDone] = step(env, action);
        myBuffer = storeData(myBuffer, observation, action, nextObservation, reward, isDone);
        episodeReward(stepCt) = reward;
        observation = nextObservation;
        for gradientCt = 1:numGradientSteps
            if myBuffer.currentBufferSize >= batch_size
                iteration = iteration + 1;
                [sampledObservation, sampledAction, sampledNextObservation, sampledReward, sampledIsDone] = ...
                    sampleBatch(myBuffer, batch_size);
                target_Q = zeros(1,batch_size);
                Y = zeros(1,batch_size);
                for i = 1:batch_size
                    Y(i) = target_predict(target_QNetwork, dlarray(sampledNextObservation(:,i), 'CB'), actInfo);
                    % use the SAMPLED reward/done flags so each target matches its own
                    % transition (myBuffer.reward(i)/myBuffer.isDone(i) are just the
                    % first batch_size slots of the buffer)
                    if sampledIsDone(i)
                        target_Q(i) = sampledReward(i);
                    else
                        target_Q(i) = sampledReward(i) + discountFactor * Y(i);
                    end
                end
                lossData.batchSize = batch_size;
                lossData.actInfo = actInfo;
                lossData.actionBatch = sampledAction;
                lossData.targetValues = target_Q;
                % calculating gradient
                [loss, gradients] = dlfeval(@QNetworkLoss, QNetwork, sampledObservation, lossData.targetValues, ...
                    lossData);
                % performing gradient descent (plain SGD, learning rate 1e-3)
                params = QNetwork.Learnables;
                for k = 1:height(params)   % every weight and bias, not a hard-coded 6
                    params(k,3).Value{1} = params(k,3).Value{1} - 1e-3 .* gradients(k,3).Value{1};
                end
                QNetwork.Learnables = params;
                recordMetrics(monitor,iteration,Loss=loss);
            end
        end
        % hard copy of the online network into the target network
        if mod(totalSteps, targetUpdateFrequency) == 0
            target_QNetwork = QNetwork;
        end
        if isDone
            break
        end
    end
    episodeCumulativeReward = sum(episodeReward);
    episodeCumulativeRewardVector = cat(2, episodeCumulativeRewardVector,episodeCumulativeReward);
    % trailing window: mean of the last aveWindowSize episodes
    movingAveReward = movmean(episodeCumulativeRewardVector, [aveWindowSize-1 0], 2);
    addpoints(lineReward,episode,episodeCumulativeReward);
    addpoints(lineAveReward,episode,movingAveReward(end));
    title(ax, "Training Progress - Episode: " + episode + ", Total Step: " + string(totalSteps) + ", epsilon:" + ...
        string(epsilon))
    drawnow;
    updateInfo(monitor,Episode=episode);
end
And here is the code for QNetworkLoss:
function [loss, gradients] = QNetworkLoss(net, X, T, lossData)
    batchSize = lossData.batchSize;
    % mask of the taken action per sample (numActions x batchSize);
    % Elements(:) keeps this shape whether Elements is stored as a row or a column
    Z = repmat(lossData.actInfo.Elements(:), 1, batchSize);
    actionIndicationMatrix = lossData.actionBatch(:,:) == Z;
    Y = forward(net, X);
    Y = Y(actionIndicationMatrix);   % Q(s, a_taken), one value per sample
    T = reshape(T,size(Y));
    loss = mse(Y, T, 'DataFormat', 'CB');
    gradients = dlgradient(loss, net.Learnables);
end
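QNetworkLoss picks out Q(s, a) for the action actually taken by comparing actionBatch with a repmat of actInfo.Elements. The base-MATLAB sketch below reproduces that selection on a small hand-made Q matrix (rows are the two CartPole-Discrete force values, -10 and 10) so the ordering can be checked without a network. The sub2ind variant is an equivalent alternative, not something from the post.

```matlab
% Select Q(s, a_taken) for each sample in a batch (illustrative data only).
elements = [-10 10];                  % action set; orientation does not matter below
sampledAction = [10 -10 10];          % actions taken, 1 x batchSize
Q = [0.1 0.2 0.3;                     % row 1: Q-values for action -10
     0.4 0.5 0.6];                    % row 2: Q-values for action  10
batchSize = numel(sampledAction);

% Logical-mask version, as in QNetworkLoss (elements(:) forces a column)
mask = sampledAction == repmat(elements(:), 1, batchSize);   % 2 x batchSize
qTakenMask = Q(mask);                 % column-major order: one value per sample

% Index version: map each action value to its row, then use linear indices
[~, rowIdx] = ismember(sampledAction, elements);
qTakenIdx = Q(sub2ind(size(Q), rowIdx, 1:batchSize));
```

The mask version returns values in batch order only because column-major traversal visits exactly one selected entry per column, which requires each action value to appear once in elements.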
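The per-sample loop that builds target_Q can be collapsed into one vectorized expression. This sketch assumes target_predict returns the maximum over actions of the target network's Q-values for one observation (its code is not shown in the post), and uses a hand-made matrix qNext in place of the network output so it runs in base MATLAB.

```matlab
% Vectorized DQN target: r + gamma * max_a Q_target(s', a) * (1 - done).
% qNext stands in for the target network output (numActions x batchSize).
% With Deep Learning Toolbox it could come from, for example,
%   qNext = extractdata(predict(target_QNetwork, dlarray(sampledNextObservation, 'CB')));
% (assumes sampledNextObservation is a plain numeric numObservations x batchSize array)
qNext = [1 5 2;
         3 4 7];
sampledReward = [1 1 1];
sampledIsDone = [0 1 0];
discountFactor = 0.99;

maxQNext = max(qNext, [], 1);   % best next-state value for each sample
% (1 - sampledIsDone) drops the bootstrap term for terminal transitions
target_Q = sampledReward + discountFactor .* maxQNext .* (1 - sampledIsDone);
```

This matches the if/else in the loop term by term, and the whole batch goes through the target network in one call instead of batch_size calls.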
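The storeData and sampleBatch helpers are not shown in the post. For comparison, here is a minimal sketch of a circular buffer over the same myBuffer fields, written inline in base MATLAB with a 5-slot buffer and dummy transitions. The wrap-around write pointer and sampling without replacement via randperm are assumptions about what those helpers are meant to do, not the author's code.

```matlab
% Minimal circular replay buffer using the same fields as myBuffer.
% Illustrative only: this is NOT the post's storeData/sampleBatch code.
numObservations = 4;
buf.bufferSize = 5;
buf.bufferIndex = 0;
buf.currentBufferSize = 0;
buf.observation = zeros(numObservations, buf.bufferSize);
buf.nextObservation = zeros(numObservations, buf.bufferSize);
buf.action = zeros(1, buf.bufferSize);
buf.reward = zeros(1, buf.bufferSize);
buf.isDone = zeros(1, buf.bufferSize);

% Store 7 dummy transitions; the 6th and 7th overwrite the oldest slots
for t = 1:7
    buf.bufferIndex = mod(buf.bufferIndex, buf.bufferSize) + 1;   % wraps 5 -> 1
    k = buf.bufferIndex;
    buf.observation(:, k) = t * ones(numObservations, 1);
    buf.nextObservation(:, k) = (t + 1) * ones(numObservations, 1);
    buf.action(k) = 10;
    buf.reward(k) = 1;
    buf.isDone(k) = (t == 7);
    buf.currentBufferSize = min(buf.currentBufferSize + 1, buf.bufferSize);
end

% Sample from the filled slots only, indexing every field with the SAME idx
% so each (s, a, r, s', done) tuple stays aligned
batchSize = 3;
idx = randperm(buf.currentBufferSize, batchSize);
sampledObservation = buf.observation(:, idx);
sampledNextObservation = buf.nextObservation(:, idx);
sampledAction = buf.action(idx);
sampledReward = buf.reward(idx);
sampledIsDone = buf.isDone(idx);
```

Because all fields share idx, the target computation should read sampledReward and sampledIsDone rather than myBuffer.reward(i) and myBuffer.isDone(i), which refer to the first batch_size slots of the buffer instead of the sampled transitions.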
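The exploration schedule in the loop is multiplicative: once totalSteps exceeds batch_size, epsilon = max(epsilon*(1-epsilonDecay), epsilonMin). The check below replays only that update rule with the post's parameter values (no environment) to find the total step at which epsilon first reaches epsilonMin, and compares it with the closed form ceil(log(epsilonMin)/log(1-epsilonDecay)).

```matlab
% Replay only the epsilon update from the training loop (no environment).
epsilon = 1;
epsilonMin = 0.01;
epsilonDecay = 0.005;
batch_size = 256;

stepAtMin = NaN;
for totalSteps = 1:5000
    if totalSteps > batch_size
        epsilon = max(epsilon*(1-epsilonDecay), epsilonMin);
    end
    if isnan(stepAtMin) && epsilon <= epsilonMin
        stepAtMin = totalSteps;   % first step where exploration is at its floor
    end
end

% Closed form: number of decay steps needed to go from 1 down to epsilonMin
numDecaySteps = ceil(log(epsilonMin) / log(1 - epsilonDecay));
fprintf('epsilon reaches %.2f after %d total steps (%d decay steps)\n', ...
    epsilonMin, stepAtMin, numDecaySteps);
% prints: epsilon reaches 0.01 after 1175 total steps (919 decay steps)
```

So exploration is essentially at its minimum after about 1,175 environment steps, out of at most num_episodes*max_steps = 50,000.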
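In the loop, target_QNetwork is replaced by a full copy of QNetwork every targetUpdateFrequency = 4 environment steps; since each step runs numGradientSteps = 5 updates once the buffer holds batch_size transitions, the target changes every 20 gradient updates. A common alternative, not used in the post, is a soft (Polyak) update, theta_target = tau*theta + (1 - tau)*theta_target, applied every step. The sketch applies that rule to cell arrays of plain matrices standing in for the Value columns of the two Learnables tables; tau = 0.01 is an assumed value.

```matlab
% Soft target update on stand-ins for QNetwork.Learnables.Value and
% target_QNetwork.Learnables.Value (cell arrays of weight/bias arrays).
onlineValues = {ones(2, 3), ones(2, 1)};
targetValues = {zeros(2, 3), zeros(2, 1)};
tau = 0.01;   % assumed smoothing factor; tau = 1 reproduces a hard copy

for t = 1:100   % 100 soft updates with the online weights held fixed
    targetValues = cellfun(@(w, wt) tau .* w + (1 - tau) .* wt, ...
        onlineValues, targetValues, 'UniformOutput', false);
end

% With fixed online weights, each entry follows 1 - (1 - tau)^t
fprintf('target weight after 100 soft updates: %.4f\n', targetValues{1}(1));
```

In principle the same cellfun expression can be applied to the Value columns of QNetwork.Learnables and target_QNetwork.Learnables and the result assigned back, the same way the post assigns params to QNetwork.Learnables, since the target network is a copy with identical layer order.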
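For the reward plot, movmean with a scalar even window such as aveWindowSize = 10 uses a centered window (5 elements before and 4 after the current one), so at the latest episode it averages only the last 6 values, because no later episodes exist yet. A two-element window [aveWindowSize-1 0] averages the last 10 episodes instead. This base-MATLAB check uses dummy rewards 1:20.

```matlab
% Compare centered vs trailing moving averages at the last element.
episodeCumulativeRewardVector = 1:20;   % dummy rewards for 20 episodes
aveWindowSize = 10;

centered = movmean(episodeCumulativeRewardVector, aveWindowSize, 2);
trailing = movmean(episodeCumulativeRewardVector, [aveWindowSize-1 0], 2);

fprintf('centered, last value: %.1f (mean of 15:20)\n', centered(end));
fprintf('trailing, last value: %.1f (mean of 11:20)\n', trailing(end));
```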
I have thoroughly reviewed my code and attempted various adjustments, but the desired convergence remains elusive. Any guidance, tips, or insights from experienced members would be highly appreciated.


Version: R2023b
