- 'Train Network with Complex-Valued Data' example - https://www.mathworks.com/help/deeplearning/ug/train-network-with-complex-valued-data.html
- "SplitComplexInputs" argument - https://www.mathworks.com/help/deeplearning/ref/nnet.cnn.layer.sequenceinputlayer.html#mw_1f1ef68c-244a-4374-a846-2e24b71d384f_sep_mw_19bc7780-8e05-482a-b309-d24e230ab466
- function handles - https://www.mathworks.com/help/matlab/function-handles.html
- "trainnet" function - https://www.mathworks.com/help/deeplearning/ref/trainnet.html
netTrained = trainnet(sequences,targets,net,lossFcn,options): this function cannot be used when sequences contains complex numbers
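Question:
When calling netTrained = trainnet(sequences,targets,net,lossFcn,options), how can this function be used when sequences contains complex numbers?
The function documentation indicates that complex-valued input is supported: "This argument supports complex-valued predictors and targets."
Code: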
XTrain = permute(dataTrain(:,1:end-1,:),[1,3,2]);
TTrain = permute(dataTrain(:,2:end,:),[1,3,2]);
numChannels = betalen;
layers = [
    sequenceInputLayer(numChannels)
    lstmLayer(128)
    fullyConnectedLayer(numChannels)];
options = trainingOptions("adam", ...
    MaxEpochs=200, ...
    SequencePaddingDirection="left", ...
    Shuffle="every-epoch", ...
    Plots="training-progress", ...
    Verbose=false);
net = trainnet(XTrain,TTrain,layers,"mse",options);
Error output:
Error using trainnet (line 46)
Execution failed during layer(s) 'lstm'.
Error in HDL (line 66)
net = trainnet(XTrain,TTrain,layers,"mse",options);
Caused by:
Error using dlarray/lstm (line 105)
Invalid argument at position 1. Value must be real.
Answers (1)
Paras Gupta
on 18 Jul 2024
Hi Alexander,
I understand that you are trying to use the "trainnet" function on complex-valued sequences and complex-valued targets.
You are correct in noting that the documentation indicates that the "trainnet" function can support complex-valued predictors and targets. However, the built-in loss functions provided by "trainnet" do not inherently support complex-valued targets. To address this, you will need to define a custom loss function that can handle complex values for targets.
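A quick check on ordinary MATLAB arrays (no Deep Learning Toolbox needed) shows what such a loss has to capture. For a complex error d = Y - T, the squared magnitude is |d|^2 = real(d)^2 + imag(d)^2; taking only real(d)^2 ignores any error in the imaginary part. The values of Y and T below are made up for illustration.

```matlab
% Made-up prediction and target (1-by-3 complex vectors)
Y = [1+2i, 0.5-1i, 3+0i];
T = [1+0i, 0.5-1i, 2+1i];

d = Y - T;                       % complex error: [0+2i, 0+0i, 1-1i]

lossRealOnly = mean(real(d).^2, 'all');               % ignores the imaginary error
lossComplex  = mean(real(d).^2 + imag(d).^2, 'all');  % mean of |d|^2

fprintf('real-part-only loss: %.4f\n', lossRealOnly);   % prints 0.3333
fprintf('squared-magnitude loss: %.4f\n', lossComplex); % prints 2.0000
```

The real-part-only loss misses the 2i error in the first element entirely, which is why a complex loss should use the full squared magnitude.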
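The error itself comes from the LSTM layer: as the message shows, "dlarray/lstm" accepts only real-valued input. To fix this, configure the "sequenceInputLayer" to handle complex-valued inputs by setting its "SplitComplexInputs" argument to true, so that the layer splits the data into real and imaginary components before it reaches the LSTM layer.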
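Conceptually, splitting turns each complex channel into two real channels, so a complex sequence with numChannels channels becomes a real sequence with 2*numChannels channels, which the LSTM can accept. The sketch below does this by hand on one made-up sequence; the channel ordering (all real parts first, then all imaginary parts) is an assumption for illustration and not necessarily the layer's internal layout. With SplitComplexInputs=true, the input layer does this for you.

```matlab
% One made-up complex sequence: numChannels-by-numTimesteps (channels x time)
numChannels  = 2;
numTimesteps = 5;
X = randn(numChannels, numTimesteps) + 1i*randn(numChannels, numTimesteps);

% Stack real parts on top of imaginary parts along the channel dimension
% (channel order is an assumption made for this illustration)
XSplit = [real(X); imag(X)];     % (2*numChannels)-by-numTimesteps, real-valued

disp(size(XSplit))    % 4 5
disp(isreal(XSplit))  % 1
```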
Below is an example that uses a custom loss function for complex-valued targets, passed to "trainnet" as a function handle:
% dummy data: numSamples-by-numTimesteps-by-numChannels complex array
numSamples = 100;
numTimesteps = 10;
numChannels = 2;
realPart = randn(numSamples, numTimesteps, numChannels);
imagPart = randn(numSamples, numTimesteps, numChannels);
dataTrain = realPart + 1i * imagPart;

% predictors: all but the last time step, permuted to samples x channels x time
XTrain = permute(dataTrain(:,1:end-1,:),[1,3,2]);
% complex target: the same sequence shifted by one time step
TTrain = permute(dataTrain(:,2:end,:),[1,3,2]);
% real target (alternative)
% TTrain = rand(numSamples, numChannels, numTimesteps-1);

layers = [
    sequenceInputLayer(numChannels, SplitComplexInputs=true) % split complex inputs into real and imaginary parts
    lstmLayer(128)
    fullyConnectedLayer(numChannels)];
options = trainingOptions("adam", ...
    MaxEpochs=200, ...
    SequencePaddingDirection="left", ...
    Shuffle="every-epoch", ...
    Plots="training-progress", ...
    Verbose=false);

% net = trainnet(XTrain, TTrain, layers, "mse", options);
% custom loss function passed as a function handle
net = trainnet(XTrain, TTrain, layers, @complexLoss, options);

function loss = complexLoss(Y, T)
    % Mean squared magnitude of the error, |Y - T|^2, so that errors in
    % both the real and the imaginary part are penalized
    difference = Y - T;
    squaredMagnitude = real(difference).^2 + imag(difference).^2;
    loss = mean(squaredMagnitude, 'all');
end
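One caveat with this setup: after the input layer splits the data, the LSTM and fully connected layers operate on real-valued channels, so the network output Y is real-valued (assuming the default real-valued learnable parameters). The imaginary part of the target is then never fitted; depending on the loss it is either ignored or only adds a constant term. A possible alternative, sketched below under that assumption, is to split the targets the same way, give fullyConnectedLayer 2*numChannels outputs, train with the built-in "mse" loss, and recombine the two halves of the prediction afterwards. TTrainSplit and YPred are hypothetical names, and the network's predictions are assumed to come back in the same samples-by-channels-by-time layout as TTrainSplit.

```matlab
% Dummy complex data, laid out as in the example above (samples x channels x time)
numSamples = 100; numTimesteps = 10; numChannels = 2;
dataTrain = randn(numSamples, numTimesteps, numChannels) ...
    + 1i*randn(numSamples, numTimesteps, numChannels);
TTrain = permute(dataTrain(:,2:end,:), [1,3,2]);

% Split the complex targets into real channels: real parts first, then imaginary parts
TTrainSplit = cat(2, real(TTrain), imag(TTrain));   % samples x (2*numChannels) x time

% The network would then need twice as many outputs, for example:
%   layers = [
%       sequenceInputLayer(numChannels, SplitComplexInputs=true)
%       lstmLayer(128)
%       fullyConnectedLayer(2*numChannels)];
%   net = trainnet(XTrain, TTrainSplit, layers, "mse", options);

% Recombine a real-valued prediction into complex values.
% Here TTrainSplit stands in for a hypothetical prediction YPred.
YPred = TTrainSplit;
YComplex = complex(YPred(:,1:numChannels,:), YPred(:,numChannels+1:end,:));

disp(isequal(YComplex, TTrain))   % 1
```

This keeps the whole network real-valued and lets the built-in loss penalize the real and imaginary parts equally.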
For more information on the code above, refer to the documentation links listed at the start of this page.
Hope this helps with your work.