I am trying to implement curriculum learning in a matlab class
    7 Ansichten (letzte 30 Tage)
  
       Ältere Kommentare anzeigen
    
I am trying to implement curriculum learning in MATLAB. 
So, I have varying levels of difficulty in my resetEnvironmentForLevel() function. 
The issue is that the agent explores the level 1 and once it reaches the goal within the desired condition (threshold=0.6 and history size=3), it should ideally move to the next level. 
However, I believe that, mainly due to the resetImpl function, the environment resets to level 1. 
What I ideally want is that the agent moves to the next level only when it reaches the goal within the desired conditions.
P.S. I also tried using a global variable, but that didn't work either. 
classdef GridWorld < matlab.System
    properties (Nontunable)
        GridSize = [8, 8]          % Size of the grid
        Actions = {'up', 'down', 'left', 'right'} % Possible actions (no diagonals)
        MaxSteps = 100             % Maximum steps per episode
        LevelThreshold = 0.6        % Success rate threshold to advance to next level
        HistorySize = 3           % Number of episodes to consider for level advancement
    end
    properties (Access = protected)
        CurrentPosition            % Agent's current position
        StartPosition              % Agent's starting position
        GoalPosition               % Agent's goal position
        Obstacles                  % Obstacle positions
        Explored                   % Explored cells
        TotalReward                % Accumulated reward
        Steps                      % Counter for the number of steps
        CurrentStep                % Current step within an episode
        PreviousActions            % History of recent actions
        StartLevel = 1             % Start at the easiest level
        GoalReached                % Flag to indicate if goal is reached
        ShortestPathLength         % Store the shortest path length
        %
        SuccessHistory             % Array to store recent episode results
        %
        EpisodesPerLevel           % Track episodes completed at each level
        %
        CurrLvl
        TotalEpisodes = 0
    end
    methods (Access = protected)
        function setupImpl(obj)
            % setupImpl  One-time initialisation performed when the System
            % object is set up.
            %
            % Initialises the curriculum level, the action history, the
            % success history and the episode counter, then configures the
            % grid for the current level and places the agent on the start
            % cell.
            %
            % The level is only defaulted (to obj.StartLevel) when it has not
            % been set yet. Unconditionally writing 1 here would throw away
            % any level that stepImpl stored in obj.CurrLvl every time setup
            % runs again on the same object. A persistent variable declared
            % in this function cannot be used to carry the level either:
            % persistent variables are private to the function that declares
            % them, so one declared here never sees updates made in stepImpl.
            if isempty(obj.CurrLvl)
                obj.CurrLvl = obj.StartLevel;
            end

            obj.Obstacles = obj.setObstacles();
            obj.PreviousActions = zeros(1, 5);
            obj.SuccessHistory = zeros(1, obj.HistorySize);
            obj.TotalEpisodes = 0;

            % Configure obstacles/start/goal for the level kept on the object.
            obj.resetEnvironmentForLevel(obj.CurrLvl);
            obj.CurrentPosition = obj.StartPosition;
            disp('SetupImpl');
        end
        function initializeEpisodeTracking(obj)
            % initializeEpisodeTracking  Clear the per-level episode
            % statistics: one counter per level and a success history of
            % obj.HistorySize entries, all set to zero.
            numLevels = 5;  % Assuming 5 levels
            obj.EpisodesPerLevel = zeros(1, numLevels);
            obj.SuccessHistory = zeros(1, obj.HistorySize);
        end
        % function globalvar = GridWorld
        %     Global Curr;
        %     globalvar.CurrLvl = Curr;
        % end
        %%
        function resetImpl(obj)
            % resetImpl  Return the environment to its start-of-episode state.
            %
            % Re-applies the layout of the level stored in obj.CurrLvl, moves
            % the agent to the start cell and clears all per-episode
            % bookkeeping (explored map, reward, step counters, action
            % history, goal flag).
            lvl = obj.CurrLvl;
            disp(lvl);

            % Rebuild obstacles/start/goal for this level, then place the agent.
            obj.resetEnvironmentForLevel(lvl);
            obj.setCurrentPosition(obj.StartPosition);

            % Only the cell the agent is standing on counts as explored.
            agentCell = obj.CurrentPosition;
            visited = false(obj.GridSize);
            visited(agentCell(1), agentCell(2)) = true;
            obj.Explored = visited;

            % Per-episode counters and flags.
            obj.GoalReached = false;
            obj.PreviousActions = zeros(1, 5);
            obj.CurrentStep = 0;
            obj.Steps = 0;
            obj.TotalReward = 0;

            disp(['resetImpl - Current Level: ', num2str(lvl)]);
        end
        function [observation, reward, isDone] = stepImpl(obj, action)
            % stepImpl  Advance the environment by one agent action.
            %
            % Inputs:
            %   action - action chosen by the agent; converted to double and
            %            forwarded to takeAction.
            % Outputs:
            %   observation - GridSize(1)-by-GridSize(2)-by-11 double array:
            %                 1 obstacles, 2 explored cells, 3 agent cell,
            %                 4 goal cell, 5/6 row/column offset to the goal,
            %                 7 distanceToNearestObstacle(), 8 atan2 of the
            %                 goal offset, 9-11 the three latest actions.
            %   reward      - reward returned by takeAction for this step.
            %   isDone      - true when takeAction reports termination or
            %                 obj.MaxSteps has been reached.
            %
            % When an episode ends, the success history is updated and the
            % level is advanced if the last obj.HistorySize episodes all
            % reached the goal.
            %
            % NOTE(review): the episode counter, success history and level
            % are persistent variables. They are private to this function,
            % shared by every instance of the class and are not cleared by
            % setup/reset (only by "clear"). resetImpl only sees the level
            % through obj.CurrLvl, which is written below on level advance.
            persistent totalEpisodes successHistory currentLevel
            maxLevel = 5;  % highest level handled by resetEnvironmentForLevel
            if isempty(totalEpisodes)
                totalEpisodes = 0;
                successHistory = zeros(1, obj.HistorySize);
                % Start from the level the object is configured for so the
                % counter and the environment agree from the first step.
                if isempty(obj.CurrLvl)
                    currentLevel = 1;
                else
                    currentLevel = obj.CurrLvl;
                end
            end

            action = double(action);
            obj.CurrentStep = obj.CurrentStep + 1;
            [newPos, reward, isDone] = obj.takeAction(action, obj.Steps);
            obj.setCurrentPosition(newPos);
            obj.TotalReward = obj.TotalReward + reward;
            obj.Steps = obj.Steps + 1;
            % Newest action first; the oldest one drops off the end.
            obj.PreviousActions = [action, obj.PreviousActions(1:end-1)];

            observation = zeros(obj.GridSize(1), obj.GridSize(2), 11, 'double');
            observation(:,:,1) = double(obj.Obstacles);
            observation(:,:,2) = double(obj.Explored);
            observation(obj.CurrentPosition(1), obj.CurrentPosition(2), 3) = 1;
            observation(obj.GoalPosition(1), obj.GoalPosition(2), 4) = 1;
            relativeGoalPos = obj.GoalPosition - obj.CurrentPosition;
            observation(:,:,5) = relativeGoalPos(1);
            observation(:,:,6) = relativeGoalPos(2);
            observation(:,:,7) = obj.distanceToNearestObstacle();
            observation(:,:,8) = atan2(relativeGoalPos(2), relativeGoalPos(1));
            observation(:,:,9:11) = repmat(reshape(obj.PreviousActions(1:3), 1, 1, []), obj.GridSize(1), obj.GridSize(2));

            % Running out of steps also ends the episode.
            if obj.Steps >= obj.MaxSteps
                isDone = true;
            end

            if isDone
                totalEpisodes = totalEpisodes + 1;
                successHistory = [isequal(newPos, obj.GoalPosition), successHistory(1:end-1)];
                successRate = mean(successHistory);
                disp(['Total Episodes: ', num2str(totalEpisodes)]);
                disp(['Current Level: ', num2str(currentLevel)]);
                disp(['Success Rate: ', num2str(successRate)]);
                disp(['Success History: ', num2str(successHistory)]);

                enoughEpisodes = totalEpisodes >= obj.HistorySize;
                rateOk = successRate >= obj.LevelThreshold;
                allSucceeded = all(successHistory == 1);
                belowMaxLevel = currentLevel < maxLevel;

                % Only advance if the whole history is successes and the
                % success rate is >= the threshold.
                if enoughEpisodes && rateOk && allSucceeded && belowMaxLevel
                    currentLevel = currentLevel + 1;
                    obj.CurrLvl = currentLevel;
                    disp(['objCurrLvl in if:',num2str(obj.CurrLvl)]);
                    successHistory = zeros(1, obj.HistorySize);
                    disp(['Advanced to level ', num2str(currentLevel)]);
                    obj.resetEnvironmentForLevel(currentLevel);
                else
                    % Report every condition that blocked the advance so the
                    % "Reasons:" header is never followed by nothing.
                    disp('Did not advance level. Reasons:');
                    if ~enoughEpisodes
                        disp([' - Not enough episodes. Current: ', num2str(totalEpisodes), ', Required: ', num2str(obj.HistorySize)]);
                    end
                    if ~rateOk
                        disp([' - Success rate too low. Current: ', num2str(successRate), ', Required: ', num2str(obj.LevelThreshold)]);
                    end
                    if ~allSucceeded
                        disp([' - Not all of the last ', num2str(obj.HistorySize), ' episodes reached the goal.']);
                    end
                    if ~belowMaxLevel
                        disp([' - Already at the maximum level (', num2str(maxLevel), ').']);
                    end
                end
                obj.plot(obj.Steps)
            end
        end
        methods (Access = public)
            function setCurrentPosition(obj, newPosition)
                % setCurrentPosition  Store newPosition as the agent's
                % current position. No validation is performed here.
                obj.CurrentPosition = newPosition;
            end
            function resetEnvironmentForLevel(obj, currentLevel)
                % resetEnvironmentForLevel  Configure obstacles, start and
                % goal for the requested curriculum level and place the agent
                % on the start cell.
                %
                % Input:
                %   currentLevel - level to configure, an integer from 1 to 5.
                %
                % Levels 1-3 use fixed layouts; levels 4 and 5 take their
                % obstacles from setObstacles() and use setRandomGoal()
                % (level 5 also setRandomStart()).
                %
                % Raises GridWorld:invalidLevel for any other value instead
                % of silently keeping the previous level's layout.
                switch currentLevel
                    case {1, 2, 3}
                        % The fixed levels all start from an empty grid with
                        % a single obstacle at (3, 3).
                        obj.Obstacles = false(obj.GridSize);
                        obj.Obstacles(3, 3) = true;
                        if currentLevel == 1
                            obj.StartPosition = [5, 5];
                            obj.GoalPosition = [1, 5];
                        elseif currentLevel == 2
                            obj.StartPosition = [2, 5];
                            obj.GoalPosition = [8, 1];
                        else
                            % Level 3 widens the obstacle into a short wall.
                            obj.Obstacles(3, 3:5) = true;
                            obj.StartPosition = [4, 5];
                            obj.GoalPosition = [7, 7];
                        end
                    case 4
                        obj.Obstacles = obj.setObstacles();  % Use existing obstacle setup
                        obj.StartPosition = [4, 5];
                        obj.setRandomGoal();
                    case 5
                        obj.Obstacles = obj.setObstacles();
                        obj.setRandomStart();
                        obj.setRandomGoal();
                    otherwise
                        error('GridWorld:invalidLevel', ...
                            'Unknown curriculum level %s; expected an integer from 1 to 5.', ...
                            mat2str(currentLevel));
                end
                obj.CurrentPosition = obj.StartPosition;
                obj.calculateShortestPath();  % Ensure shortest path is calculated for each level
            end
0 Kommentare
Antworten (1)
  Anagha Mittal
 am 17 Okt. 2024
        The encountered issue is due to the "resetImpl" function. With the implemented logic, the environment is reset to level 1 at each call instead of to the correct level ("CurrLvl").
I have made a few modifications to the code with the changes mentioned as comments:
classdef MLAns < matlab.System
    % MLAns  Grid-world environment with curriculum (level) progression.
    %
    % The curriculum level lives in the CurrLvl property, so stepImpl (which
    % advances it) and resetImpl (which applies it) read the same value.
    %
    % The class calls several methods that are not defined in this listing
    % and must be supplied alongside it: takeAction, plot,
    % calculateShortestPath, setObstacles, setRandomStart and setRandomGoal.

    properties (Nontunable)
        GridSize = [8, 8]          % Size of the grid
        Actions = {'up', 'down', 'left', 'right'} % Possible actions
        MaxSteps = 100             % Maximum steps per episode
        LevelThreshold = 0.6       % Success rate threshold to advance to next level
        HistorySize = 3            % Number of episodes to consider for level advancement
    end

    properties (Access = protected)
        CurrentPosition            % Agent's current position
        StartPosition              % Agent's starting position
        GoalPosition               % Agent's goal position
        Obstacles                  % Obstacle positions
        Explored                   % Explored cells
        TotalReward                % Accumulated reward
        Steps                      % Counter for the number of steps
        CurrentStep                % Current step within an episode
        PreviousActions            % History of recent actions
        CurrLvl = 1                % Start at the easiest level (default: 1)
        GoalReached                % Flag to indicate if goal is reached
        ShortestPathLength         % Store the shortest path length
        SuccessHistory             % Array to store recent episode results
        TotalEpisodes = 0          % Track total episodes across all levels
    end

    methods (Access = protected)
        function setupImpl(obj)
            % One-time initialisation: configure the current level and clear
            % the success history.
            obj.initializeEnvironment();  % Initialize once during setup
            obj.initializeEpisodeTracking();
            disp('Environment setup complete.');
        end

        function resetImpl(obj)
            % Return the environment to its start-of-episode state for the
            % level stored in obj.CurrLvl.
            obj.resetEnvironmentForLevel(obj.CurrLvl);  % Ensure correct level is set
            obj.setCurrentPosition(obj.StartPosition);
            obj.Explored = false(obj.GridSize);
            obj.Explored(obj.CurrentPosition(1), obj.CurrentPosition(2)) = true;
            obj.TotalReward = 0;
            obj.Steps = 0;
            obj.CurrentStep = 0;
            obj.PreviousActions = zeros(1, 5);
            obj.GoalReached = false;
            disp(['Environment reset to Level: ', num2str(obj.CurrLvl)]);
        end

        function [observation, reward, isDone] = stepImpl(obj, action)
            % Advance the environment by one agent action.
            %
            % Outputs:
            %   observation - see generateObservation
            %   reward      - reward returned by takeAction for this step
            %   isDone      - true when takeAction reports termination or
            %                 obj.MaxSteps has been reached
            action = double(action);  % Convert action to double if necessary
            obj.CurrentStep = obj.CurrentStep + 1;
            [newPos, reward, isDone] = obj.takeAction(action, obj.Steps);
            obj.setCurrentPosition(newPos);
            obj.TotalReward = obj.TotalReward + reward;
            obj.Steps = obj.Steps + 1;
            % Newest action first; the oldest one drops off the end.
            obj.PreviousActions = [action, obj.PreviousActions(1:end-1)];

            observation = obj.generateObservation();

            % A timeout must be reported through isDone as well. Otherwise
            % the caller keeps stepping the same episode and the
            % end-of-episode bookkeeping below runs again on every step.
            if obj.Steps >= obj.MaxSteps
                isDone = true;
            end

            if isDone
                obj.TotalEpisodes = obj.TotalEpisodes + 1;
                obj.updateSuccessHistory(newPos);  % Update success based on goal reaching
                obj.handleLevelProgression();      % Check for level advancement
                obj.plot(obj.Steps);               % Plot environment after episode ends
            end
        end

        function handleLevelProgression(obj)
            % Advance to the next level when enough episodes have been
            % played, the success rate meets LevelThreshold, every entry of
            % the success history is a success, and the top level (5) has
            % not been reached yet.
            successRate = mean(obj.SuccessHistory);
            disp(['Success Rate: ', num2str(successRate)]);
            if obj.TotalEpisodes >= obj.HistorySize && ...
                    successRate >= obj.LevelThreshold && ...
                    all(obj.SuccessHistory == 1) && obj.CurrLvl < 5
                obj.CurrLvl = obj.CurrLvl + 1;
                disp(['Advancing to Level ', num2str(obj.CurrLvl)]);
                obj.resetEnvironmentForLevel(obj.CurrLvl);
                obj.SuccessHistory = zeros(1, obj.HistorySize);  % Reset history for new level
            else
                disp('Did not advance level.');
                if obj.CurrLvl >= 5
                    disp('Max level reached.');
                end
            end
        end

        function initializeEnvironment(obj)
            % Configure the grid for the current level and put the agent on
            % that level's start cell (rather than a fixed [5, 5]).
            obj.resetEnvironmentForLevel(obj.CurrLvl);
            obj.CurrentPosition = obj.StartPosition;
        end

        function initializeEpisodeTracking(obj)
            % Start with an all-failure success history.
            obj.SuccessHistory = zeros(1, obj.HistorySize);
        end

        function updateSuccessHistory(obj, newPos)
            % Push the result of the finished episode (1 when newPos equals
            % the goal position, otherwise 0) onto the front of the history.
            obj.SuccessHistory = [isequal(newPos, obj.GoalPosition), obj.SuccessHistory(1:end-1)];
            disp(['Success History: ', num2str(obj.SuccessHistory)]);
        end

        function resetEnvironmentForLevel(obj, currentLevel)
            % Configure obstacles, start and goal for the given level (1-5).
            % Raises MLAns:invalidLevel for any other value instead of
            % silently keeping the previous layout.
            switch currentLevel
                case 1
                    obj.setSimpleLevel();
                case 2
                    obj.setMediumLevel();
                case 3
                    obj.setHardLevel();
                case 4
                    obj.setVeryHardLevel();
                case 5
                    obj.setExtremeLevel();
                otherwise
                    error('MLAns:invalidLevel', ...
                        'Unknown curriculum level %s; expected an integer from 1 to 5.', ...
                        mat2str(currentLevel));
            end
            obj.calculateShortestPath();  % Calculate the optimal path for each level
        end

        function observation = generateObservation(obj)
            % Build the GridSize(1)-by-GridSize(2)-by-11 observation:
            %   1 obstacles, 2 explored cells, 3 agent cell, 4 goal cell,
            %   5/6 row/column offset from agent to goal, 8 atan2 of that
            %   offset, 9-11 the three most recent actions.
            % Channel 7 (distance to the nearest obstacle in the original
            % question's code) stays zero here because that helper is not
            % part of this class; fill it in if the helper is available.
            observation = zeros(obj.GridSize(1), obj.GridSize(2), 11, 'double');
            observation(:,:,1) = double(obj.Obstacles);
            observation(:,:,2) = double(obj.Explored);
            observation(obj.CurrentPosition(1), obj.CurrentPosition(2), 3) = 1;
            observation(obj.GoalPosition(1), obj.GoalPosition(2), 4) = 1;
            relativeGoalPos = obj.GoalPosition - obj.CurrentPosition;
            observation(:,:,5) = relativeGoalPos(1);
            observation(:,:,6) = relativeGoalPos(2);
            observation(:,:,8) = atan2(relativeGoalPos(2), relativeGoalPos(1));
            observation(:,:,9:11) = repmat(reshape(obj.PreviousActions(1:3), 1, 1, []), ...
                obj.GridSize(1), obj.GridSize(2));
        end

        % Level definitions used by resetEnvironmentForLevel.
        function setSimpleLevel(obj)
            % Level 1: single obstacle, goal straight above the start.
            obj.Obstacles = false(obj.GridSize);
            obj.Obstacles(3, 3) = true;
            obj.StartPosition = [5, 5];
            obj.GoalPosition = [1, 5];
        end

        function setMediumLevel(obj)
            % Level 2: single obstacle, goal in the opposite corner region.
            obj.Obstacles = false(obj.GridSize);
            obj.Obstacles(3, 3) = true;
            obj.StartPosition = [2, 5];
            obj.GoalPosition = [8, 1];
        end

        function setHardLevel(obj)
            % Level 3: a three-cell wall in row 3.
            obj.Obstacles = false(obj.GridSize);
            obj.Obstacles(3, 3:5) = true;
            obj.StartPosition = [4, 5];
            obj.GoalPosition = [7, 7];
        end

        function setVeryHardLevel(obj)
            % Level 4: obstacles from setObstacles(), fixed start, random goal.
            obj.Obstacles = obj.setObstacles();
            obj.StartPosition = [4, 5];
            obj.setRandomGoal();
        end

        function setExtremeLevel(obj)
            % Level 5: obstacles from setObstacles(), random start and goal.
            obj.Obstacles = obj.setObstacles();
            obj.setRandomStart();
            obj.setRandomGoal();
        end
    end

    methods (Access = public)
        function setCurrentPosition(obj, newPosition)
            % Store newPosition as the agent's current position.
            obj.CurrentPosition = newPosition;
        end
    end
end
Hope this helps!
0 Kommentare
Siehe auch
Kategorien
				Mehr zu Training and Simulation finden Sie in Help Center und File Exchange
			
	Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!

