Error while creating custom critic function

Harsh on 26 Feb 2024
Answered: Amal Raj on 14 Mar 2024
customCriticNetwork = [
imageInputLayer([1 1 1], 'Normalization', 'none', 'Name', 'observation')
fullyConnectedLayer(400, 'Name', 'CriticFC1', 'WeightLearnRateFactor', 1e-4, 'BiasLearnRateFactor', 1e-4)
reluLayer('Name', 'CriticRelu1')
fullyConnectedLayer(300, 'Name', 'CriticFC2', 'WeightLearnRateFactor', 1e-4, 'BiasLearnRateFactor', 1e-4)
reluLayer('Name', 'CriticRelu2')
fullyConnectedLayer(1, 'Name', 'output', 'WeightLearnRateFactor', 1e-4, 'BiasLearnRateFactor', 1e-4)]; % Change the output size to 1
% Create the custom critic dlnetwork
dlnet = dlnetwork(customCriticNetwork);
% Create the custom critic
customCritic = rlQValueFunction(dlnet, obsInfo, actInfo);
Error using rlQValueFunction
Number of input layers for state-action-value function deep neural network must equal the number of
observation and action specifications.
Error in RL_Agent (line 37)
customCritic = rlQValueFunction(dlnet, obsInfo, actInfo);

Answers (1)

Amal Raj on 14 Mar 2024
Hey Harsh.
The error means that rlQValueFunction expects the critic network to have one input layer for every observation specification plus one for every action specification. Your network has a single input layer ('observation'), but with one observation channel and one action channel it needs two: one layer that receives the observation and one that receives the action, merged into a common path that ends in a single scalar output (the Q-value).
Here's an example of how you can restructure your critic network so that its inputs match the observation and action specifications:
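Before restructuring, you can confirm the mismatch by comparing the number of network inputs with the number of specifications. A minimal sketch, assuming Reinforcement Learning Toolbox and Deep Learning Toolbox are installed; obsInfo and actInfo here are placeholders for your own specifications, and the network mirrors the single-input layout from the question:

```matlab
% Placeholder specifications (replace with the ones from your environment)
obsInfo = rlNumericSpec([1 1 1]);
actInfo = rlFiniteSetSpec([1 2 3]);

% Single-input network, mirroring the question (observation path only)
layers = [
    imageInputLayer([1 1 1], 'Normalization', 'none', 'Name', 'observation')
    fullyConnectedLayer(400, 'Name', 'CriticFC1')
    reluLayer('Name', 'CriticRelu1')
    fullyConnectedLayer(1, 'Name', 'output')];
dlnet = dlnetwork(layers);

% Input layers present in the network vs. input channels rlQValueFunction needs
numNetInputs = numel(dlnet.InputNames);
numSpecs = numel(obsInfo) + numel(actInfo);   % observation channels + action channels
fprintf('Network inputs: %d, required by rlQValueFunction: %d\n', numNetInputs, numSpecs);
% prints: Network inputs: 1, required by rlQValueFunction: 2
```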
% Define observation and action specifications
% (assumption: a scalar observation declared as [1 1] so it can feed a featureInputLayer)
obsInfo = rlNumericSpec([1 1]);
actInfo = rlFiniteSetSpec([1 2 3]);
% Learn-rate factors from the original network, reused for every fully connected layer
lrf = {'WeightLearnRateFactor', 1e-4, 'BiasLearnRateFactor', 1e-4};
% Observation path: one input layer for the observation channel
obsPath = [
    featureInputLayer(prod(obsInfo.Dimension), 'Name', 'observation')
    fullyConnectedLayer(400, 'Name', 'ObsFC1', lrf{:})];
% Action path: one input layer for the action channel
actPath = [
    featureInputLayer(prod(actInfo.Dimension), 'Name', 'action')
    fullyConnectedLayer(400, 'Name', 'ActFC1', lrf{:})];
% Common path: merge both paths and output a single scalar Q-value
commonPath = [
    additionLayer(2, 'Name', 'add')
    reluLayer('Name', 'CriticRelu1')
    fullyConnectedLayer(300, 'Name', 'CriticFC2', lrf{:})
    reluLayer('Name', 'CriticRelu2')
    fullyConnectedLayer(1, 'Name', 'output', lrf{:})];
% Assemble the layer graph and connect both paths to the addition layer
criticGraph = layerGraph(obsPath);
criticGraph = addLayers(criticGraph, actPath);
criticGraph = addLayers(criticGraph, commonPath);
criticGraph = connectLayers(criticGraph, 'ObsFC1', 'add/in1');
criticGraph = connectLayers(criticGraph, 'ActFC1', 'add/in2');
% Create the custom critic dlnetwork (now with two input layers)
dlnet = dlnetwork(criticGraph);
% Create the custom critic, mapping each input layer to its specification
customCritic = rlQValueFunction(dlnet, obsInfo, actInfo, ...
    'ObservationInputNames', {'observation'}, 'ActionInputNames', {'action'});
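This version gives the critic one input layer per specification (an observation input and an action input, merged by an addition layer) and a single scalar output, which is the structure rlQValueFunction expects. Replace obsInfo and actInfo with the specifications from your own environment and size the two featureInputLayers to match them.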
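If the action space is discrete, as with rlFiniteSetSpec([1 2 3]), another option is to keep a single observation input and let the network output one Q-value per action, using rlVectorQValueFunction instead of rlQValueFunction. A minimal sketch, assuming the same placeholder specifications and a toolbox release that provides rlVectorQValueFunction (R2022a or later, to my understanding):

```matlab
% Placeholder specifications (replace with the ones from your environment)
obsInfo = rlNumericSpec([1 1]);
actInfo = rlFiniteSetSpec([1 2 3]);
numActions = numel(actInfo.Elements);   % one Q-value per discrete action

% Observation-only network; the output size equals the number of actions
vectorCriticNetwork = [
    featureInputLayer(prod(obsInfo.Dimension), 'Name', 'observation')
    fullyConnectedLayer(400, 'Name', 'CriticFC1')
    reluLayer('Name', 'CriticRelu1')
    fullyConnectedLayer(300, 'Name', 'CriticFC2')
    reluLayer('Name', 'CriticRelu2')
    fullyConnectedLayer(numActions, 'Name', 'output')];

dlnet = dlnetwork(vectorCriticNetwork);
vectorCritic = rlVectorQValueFunction(dlnet, obsInfo, actInfo);

% Evaluate the Q-values of all actions for one random observation
q = getValue(vectorCritic, {rand(obsInfo.Dimension)});
```

This keeps the single-input architecture from the question and only changes the output size from 1 to the number of actions. It applies only to finite action sets; for continuous actions the two-input rlQValueFunction structure is needed.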
