Error defining categorical matrix for U-Net segmentation

DB on 12 Jun 2023
Answered: Garmit Pant on 11 Sep 2023
Hello,
I am trying to train a U-Net segmentation network in MATLAB. I have inputData (3-D MRI images) and targetData (categorical labeled images). When I run the following code, I get this error:
Error using trainNetwork (line 184). Invalid training data. For classification tasks, responses must be a vector of categorical responses. For regression tasks, responses must be a vector, a matrix, or a 4-D array of numeric responses which must not contain NaNs.
I have validated that targetData is categorical, has only two classes ('Background' and 'Object'), and is the same size as inputData.
% Load and preprocess the 3D MRI data
data = load('mri_data.mat'); % Load your MRI data here
inputData = data.inputData; % Input MRI volumes
targetData = data.targetData; % Target segmentation volumes
% Split the data into training and validation sets
splitRatio = 0.8; % Split ratio for training/validation (80% training, 20% validation)
splitIdx = round(splitRatio*size(inputData, 4));
inputDataTrain = inputData(:, :, :, 1:splitIdx);
targetDataTrain = targetData(:, :, :, 1:splitIdx);
inputDataVal = inputData(:, :, :, splitIdx+1:end);
targetDataVal = targetData(:, :, :, splitIdx+1:end);
% Define the U-Net architecture
inputSize = size(inputData, 1:3);
numClasses = 2;
layers = [
image3dInputLayer(inputSize, 'Normalization', 'none')
% Encoder
convolution3dLayer(3, 64, 'Padding', 'same')
batchNormalizationLayer
reluLayer
maxPooling3dLayer(2, 'Stride', 2)
convolution3dLayer(3, 128, 'Padding', 'same')
batchNormalizationLayer
reluLayer
maxPooling3dLayer(2, 'Stride', 2)
% Bridge
convolution3dLayer(3, 256, 'Padding', 'same')
batchNormalizationLayer
reluLayer
% Decoder
transposedConv3dLayer(2, 128, 'Stride', 2)
convolution3dLayer(3, 128, 'Padding', 'same')
batchNormalizationLayer
reluLayer
transposedConv3dLayer(2, 64, 'Stride', 2)
convolution3dLayer(3, 64, 'Padding', 'same')
batchNormalizationLayer
reluLayer
% Output
convolution3dLayer(1, numClasses)
softmaxLayer
pixelClassificationLayer
];
% Set the training options
options = trainingOptions('adam', ...
'MaxEpochs', 20, ...
'MiniBatchSize', 4, ...
'Shuffle', 'every-epoch', ...
'Plots', 'training-progress', ...
'ValidationData', {inputDataVal, targetDataVal});
% Train the U-Net model
net = trainNetwork(inputDataTrain, targetDataTrain, layers, options);
Any help would be greatly appreciated!
  3 Comments
mohd akmal masud on 14 Jun 2023
Your data is 4-D, but it needs to be changed to a 3-D matrix.
I faced the same problem before. Try changing that first.


Answers (1)

Garmit Pant on 11 Sep 2023
Hello DB,
It is my understanding that you are trying to train a U-Net to segment 3-D MRI data and that ‘trainNetwork’ throws an error when training starts.
I ran your code with the data you provided. The training input (inputDataTrain) that you pass to ‘trainNetwork’ is a 4-D numeric array of size 96x57x52x6.
When numeric arrays are passed to ‘trainNetwork’, an h-by-w-by-c-by-N array is interpreted as N 2-D images, where h, w, and c are the height, width, and number of channels of the images, and N is the number of images. 3-D images must instead be passed as an h-by-w-by-d-by-c-by-N array, where d is the depth. Your 96x57x52x6 array is therefore read as six 96x57 images with 52 channels, which does not match the 3-D input layer. In addition, as the error message states, responses passed as an array must be a categorical vector (classification) or a numeric array (regression), so a per-voxel categorical label array is rejected in this form.
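If you want to keep the volumes as a numeric array, they first need the 3-D layout described above. A minimal sketch, assuming single-channel volumes stored as h-by-w-by-d-by-N (the random array below is a hypothetical stand-in for inputDataTrain): reshape inserts a singleton channel dimension without moving any data.

```matlab
% Hypothetical stand-in for the 96x57x52x6 training array described above
inputDataTrain = rand(96, 57, 52, 6, 'single');

% Insert a singleton channel dimension:
% h-by-w-by-d-by-N  ->  h-by-w-by-d-by-1-by-N
volSize    = size(inputDataTrain, 1:3);   % [h w d]
numVolumes = size(inputDataTrain, 4);     % N
inputDataTrain5D = reshape(inputDataTrain, [volSize 1 numVolumes]);
```

This only fixes the layout of the predictors; the categorical label volumes still cannot be passed as an array response, which is why a datastore is the more practical route for segmentation.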
For 3-D image segmentation, you can instead pass a datastore to ‘trainNetwork’, for example an ‘imageDatastore’ for the MRI volumes and a ‘pixelLabelDatastore’ for the label volumes.
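Because inputData and targetData are already in memory, one alternative to writing the volumes to files for imageDatastore and pixelLabelDatastore is to wrap the arrays with arrayDatastore (available in R2020b or later, I believe) and pair them with combine. This is a swapped-in approach, sketched under the assumption that ‘trainNetwork’ accepts a combined datastore whose reads return {volume, categorical labels} for a pixelClassificationLayer; the random data are hypothetical stand-ins.

```matlab
% Hypothetical stand-in data: 6 volumes of 96x57x52 and matching label volumes
inputDataTrain  = rand(96, 57, 52, 6, 'single');
targetDataTrain = categorical(randi(2, 96, 57, 52, 6), [1 2], {'Background', 'Object'});

% Iterate over the 4th dimension so each read returns one 96x57x52 volume
volDs = arrayDatastore(inputDataTrain,  'IterationDimension', 4);
lblDs = arrayDatastore(targetDataTrain, 'IterationDimension', 4);

% Each read of the combined datastore returns a 1x2 cell: {volume, labels}
dsTrain = combine(volDs, lblDs);
```

Build dsVal the same way from inputDataVal and targetDataVal, set 'ValidationData', dsVal in trainingOptions, and call trainNetwork(dsTrain, layers, options).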
See the following example for how to preprocess 4-D volumes and train a 3-D U-Net for image segmentation:
For more information, refer to the following MathWorks documentation:
  1. The "images" argument in the "Input Arguments" section of the MATLAB function ‘trainNetwork’: https://in.mathworks.com/help/deeplearning/ref/trainnetwork.html#bu6sn4c-2
  2. ‘unet3dLayers’, the MATLAB function that creates 3-D U-Net layers for semantic segmentation of volumetric images: https://in.mathworks.com/help/vision/ref/unet3dlayers.html?s_tid=doc_ta
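Whichever input route is used, the spatial sizes also have to survive pooling and upsampling. With no padding, maxPooling3dLayer(2, 'Stride', 2) halves each dimension (rounding down) and transposedConv3dLayer(2, ..., 'Stride', 2) doubles it, so in the posted network 57 → 28 → 14 → 28 → 56 while 96 and 52 come back unchanged; the output would be 96x56x52 against 96x57x52 labels. A minimal sketch, assuming it is acceptable to crop each volume to a multiple of 2^depth (here 4) and using ‘unet3dLayers’ from the second link; the data are hypothetical stand-ins.

```matlab
% Hypothetical stand-in data with the sizes reported above
inputDataTrain  = rand(96, 57, 52, 6, 'single');
targetDataTrain = categorical(randi(2, 96, 57, 52, 6), [1 2], {'Background', 'Object'});

encoderDepth = 2;                 % two pooling stages, as in the posted network
m = 2^encoderDepth;               % each spatial size must be a multiple of m
volSize  = size(inputDataTrain, 1:3);
cropSize = floor(volSize / m) * m;    % largest multiple of m in each dimension

% Crop images and labels identically so the voxels stay aligned
inputDataTrain  = inputDataTrain(1:cropSize(1), 1:cropSize(2), 1:cropSize(3), :);
targetDataTrain = targetDataTrain(1:cropSize(1), 1:cropSize(2), 1:cropSize(3), :);

% 3-D U-Net for single-channel volumes and 2 classes
% (requires Computer Vision Toolbox and Deep Learning Toolbox)
lgraph = unet3dLayers([cropSize 1], 2, 'EncoderDepth', encoderDepth);
```

Cropping drops one slice along the second dimension; padding the volumes and labels up to the next multiple is the alternative if no voxels may be lost.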
I hope this helps!
