About GridWorld in RL Toolbox
shoki kobayashi
on 14 Jun 2020
Commented: shoki kobayashi
on 28 Jul 2020
I am having trouble solving a GridWorld with Q-learning.
I wrote a program to solve the GridWorld, but the agent does not learn well.
How can I improve it?
% Create the maze
GW = createGridWorld(8,8);
GW.CurrentState = '[2,1]';
GW.TerminalStates = '[8,8]';
GW.ObstacleStates = ["[3,3]";"[3,4]";"[3,5]";"[3,6]";"[3,7]";"[4,3]";"[7,3]";"[6,3]";"[5,3]"];
updateStateTranstionForObstacles(GW)
GW.T(state2idx(GW,"[2,4]"),:,:) = 0;
GW.T(state2idx(GW,"[2,4]"),state2idx(GW,"[4,4]"),:) = 1;
nS = numel(GW.States);
nA = numel(GW.Actions);
GW.R = -1*ones(nS,nS,nA);
GW.R(state2idx(GW,"[4,2]"),state2idx(GW,"[5,2]"),:) = 5;
GW.R(state2idx(GW,"[8,3]"),state2idx(GW,"[8,4]"),:) = 5;
GW.R(:,state2idx(GW,GW.TerminalStates),:) = 10;
% Load the grid world into an MDP environment
env = rlMDPEnv(GW)
env.ResetFcn = @() 2;
rng(0)
% Q-learning
qTable = rlTable(getObservationInfo(env),getActionInfo(env));
qRepresentation = rlQValueRepresentation(qTable,getObservationInfo(env),getActionInfo(env));
qRepresentation.Options.LearnRate = 1;
agentOpts = rlQAgentOptions;
agentOpts.EpsilonGreedyExploration.Epsilon = .04;
qAgent = rlQAgent(qRepresentation,agentOpts);
trainOpts = rlTrainingOptions;
trainOpts.MaxStepsPerEpisode = 50;
trainOpts.MaxEpisodes= 200;
trainOpts.StopTrainingCriteria = "AverageReward";
trainOpts.StopTrainingValue = 101;
trainOpts.ScoreAveragingWindowLength = 30;
doTraining = false;
if doTraining
% Train the agent.
trainingStats = train(qAgent,env,trainOpts);
end
% Plot the results
plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;
sim(qAgent,env)
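The code above refers to states both by name ("[2,1]", "[8,8]") and by number (env.ResetFcn = @() 2), and GW.R is indexed as R(current state, next state, action). A minimal plain-MATLAB sketch of that mapping, assuming createGridWorld orders GW.States column by column ("[1,1]", "[2,1]", ...); the helper idx is hypothetical and only stands in for state2idx:

```matlab
% Sketch, not toolbox code: mimics state2idx for an 8x8 grid, ASSUMING
% GW.States is ordered column by column ("[1,1]","[2,1]",...,"[8,8]").
gridSize = [8 8];
idx = @(r, c) sub2ind(gridSize, r, c);   % hypothetical stand-in for state2idx

startIdx    = idx(2, 1);   % "[2,1]" -> 2, the value returned by env.ResetFcn
jumpFromIdx = idx(2, 4);   % "[2,4]" -> 26
jumpToIdx   = idx(4, 4);   % "[4,4]" -> 28
goalIdx     = idx(8, 8);   % "[8,8]" -> 64

% R(current state, next state, action) with nA = 4 actions (N, S, E, W).
nS = prod(gridSize);
nA = 4;
R = -1 * ones(nS, nS, nA);
R(idx(4, 2), idx(5, 2), :) = 5;   % bonus only for the move [4,2] -> [5,2]
R(:, goalIdx, :) = 10;            % any transition into the goal

fprintf('start %d, jump %d -> %d, goal %d\n', startIdx, jumpFromIdx, jumpToIdx, goalIdx);
fprintf('R([4,2]->[5,2]) = %d, R([5,2]->[4,2]) = %d\n', ...
        R(idx(4, 2), idx(5, 2), 2), R(idx(5, 2), idx(4, 2), 1));
```

Note that the +5 is attached to one direction of one transition only: stepping back from [5,2] to [4,2] still costs -1.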
Accepted Answer
Kazuaki Yamada
on 28 Jul 2020
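The agent learns after the following changes:
Comment out lines 12-13 of the question's code (the two +5 reward assignments).
Change false to true on line 32 (doTraining); otherwise train is never called.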
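A plausible reason the two +5 lines hurt (my reading; the answer does not say why): each bonus sits on a single transition that can be repeated. From [4,2] the agent can step south to [5,2] (+5) and north back to [4,2] (-1), netting +4 every two steps without ever reaching [8,8]. The sketch below scores two hand-written paths under the question's reward rules only; it does not simulate obstacles or the [2,4] jump, and all variable names are hypothetical.

```matlab
% Sketch: undiscounted episode return under the question's reward rules only.
% Assumptions: states are [row col] pairs; -1 per step, +5 on the two bonus
% transitions, +10 on entering [8,8]. Obstacles and the [2,4] -> [4,4] jump
% are NOT simulated, so both paths are written by hand to be legal moves.
bonus = [4 2 5 2;    % [4,2] -> [5,2]
         8 3 8 4];   % [8,3] -> [8,4]
goal  = [8 8];

% Goal-directed path that also collects both bonuses (13 moves).
direct = [2 1; 3 1; 4 1; 4 2; 5 2; 6 2; 7 2; 8 2; ...
          8 3; 8 4; 8 5; 8 6; 8 7; 8 8];

% Walk to [4,2], then bounce [4,2] <-> [5,2] until the 50-move limit
% (MaxStepsPerEpisode = 50) ends the episode.
bounce = [2 1; 3 1; 4 1; 4 2; repmat([5 2; 4 2], 23, 1); 5 2];

paths   = {direct, bounce};
returns = zeros(1, numel(paths));
for p = 1:numel(paths)
    P = paths{p};
    G = 0;
    for k = 1:size(P, 1) - 1
        s = P(k, :);
        sNext = P(k + 1, :);
        if isequal(sNext, goal)
            r = 10;                              % GW.R(:, goal, :) = 10
        elseif ismember([s sNext], bonus, 'rows')
            r = 5;                               % one of the two +5 lines
        else
            r = -1;                              % default step cost
        end
        G = G + r;
    end
    returns(p) = G;
end

fprintf('direct: %d moves, return %d\n', size(direct, 1) - 1, returns(1));
fprintf('bounce: %d moves, return %d\n', size(bounce, 1) - 1, returns(2));
```

Under these assumptions the goal-directed path returns 10 while bouncing for the whole episode returns 94, so a reward-maximizing agent is pulled toward the bounce. Since 94 appears to be the best any 50-step episode can collect here, the StopTrainingValue of 101 also looks unreachable, meaning training would run all 200 episodes.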
% Create the maze
GW = createGridWorld(8,8);
GW.CurrentState = '[2,1]';
GW.TerminalStates = '[8,8]';
GW.ObstacleStates = ["[3,3]";"[3,4]";"[3,5]";"[3,6]";"[3,7]";"[4,3]";"[7,3]";"[6,3]";"[5,3]"];
updateStateTranstionForObstacles(GW)
GW.T(state2idx(GW,"[2,4]"),:,:) = 0;
GW.T(state2idx(GW,"[2,4]"),state2idx(GW,"[4,4]"),:) = 1;
nS = numel(GW.States);
nA = numel(GW.Actions);
GW.R = -1*ones(nS,nS,nA);
%GW.R(state2idx(GW,"[4,2]"),state2idx(GW,"[5,2]"),:) = 5; %--- ?
%GW.R(state2idx(GW,"[8,3]"),state2idx(GW,"[8,4]"),:) = 5; %--- ?
GW.R(:,state2idx(GW,GW.TerminalStates),:) = 10;
% Load the grid world into an MDP environment
env = rlMDPEnv(GW)
env.ResetFcn = @() 2;
rng(0)
% Q-learning
qTable = rlTable(getObservationInfo(env),getActionInfo(env));
qRepresentation = rlQValueRepresentation(qTable,getObservationInfo(env),getActionInfo(env));
qRepresentation.Options.LearnRate = 1;
agentOpts = rlQAgentOptions;
agentOpts.EpsilonGreedyExploration.Epsilon = .04;
qAgent = rlQAgent(qRepresentation,agentOpts);
trainOpts = rlTrainingOptions;
trainOpts.MaxStepsPerEpisode = 50;
trainOpts.MaxEpisodes= 200;
trainOpts.StopTrainingCriteria = "AverageReward";
trainOpts.StopTrainingValue = 101;
trainOpts.ScoreAveragingWindowLength = 30;
doTraining = true; %--- must be true, otherwise the if block below is skipped
if doTraining
% Train the agent.
trainingStats = train(qAgent,env,trainOpts);
end
% Plot the results
plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;
sim(qAgent,env)
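The effect of the commented-out +5 line for [4,2] -> [5,2] can also be seen in the Q-values. With LearnRate = 1 the tabular update is Q(s,a) <- r + γ max_a' Q(s',a'). Writing A = [4,2], B = [5,2], assuming bouncing is greedy in both states and assuming a discount factor γ = 0.99 (which I believe is the rlQAgentOptions default), the two bounce entries are pushed toward:

```latex
\begin{aligned}
Q(A,\mathrm{S}) &= 5 + \gamma\,Q(B,\mathrm{N}), \\
Q(B,\mathrm{N}) &= -1 + \gamma\,Q(A,\mathrm{S}), \\
\Rightarrow\quad Q(A,\mathrm{S}) &= \frac{5-\gamma}{1-\gamma^{2}}
  = \frac{4.01}{0.0199} \approx 201.5 \qquad (\gamma = 0.99).
\end{aligned}
```

That is far above the +10 available for entering [8,8], so the greedy policy keeps bouncing. With only 200 episodes of 50 steps the table may never actually reach this value; the point is the direction in which the updates push.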