
Passing additional minibatchable quantities to a trainnet() loss function

I am calling trainnet() with the syntax:
netTrained = trainnet(cds,net,lossFcn,options)
where lossFcn=f(Y,T) is a handle to a custom loss function. Here Y is the network prediction for the input X, and T is the training target; both Y and T are S1xS2xC images. During training, trainnet() normally fetches mini-batched pairs (X,T) from the CombinedDatastore cds and passes the pairs (Y(X),T) to lossFcn.
I would now like to modify the training to use a loss function of the form lossFcn=f(Y,T,W), where W is an additional mini-batchable data set containing known, constant weights with the same dimensions as Y and T. My question is whether there is a way to combine 3 datastores instead of 2 to make this happen. In other words, is it possible to have cds read mini-batched triplets (X,T,W) and pass (Y(X),T,W) to the loss function?
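A loss of the form f(Y,T,W) could, for example, be a weighted mean squared error. The sketch below assumes that form (the question does not say which loss is used) and checks it on plain numeric arrays shaped S1 x S2 x C x B; inside trainnet() the arguments would be dlarray objects instead.

```matlab
% Assumed example of f(Y,T,W): weighted mean squared error, where W holds
% known, constant per-element weights with the same size as Y and T.
lossFcn = @(Y,T,W) mean(W .* (Y - T).^2, 'all');

% Check on plain arrays shaped S1 x S2 x C x B (here 2 x 3 x 1 x 4).
Y = ones(2,3,1,4);
T = zeros(2,3,1,4);
W = 2*ones(2,3,1,4);      % every element weighted by 2
loss = lossFcn(Y,T,W)     % each term is 2 * 1^2, so the mean is 2
```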
And if so, how do I tell trainnet() that X is the input used for network prediction and that (T,W) are constant data? Does trainnet() always assume that only the first datastore in cds holds the predictors?
1 Comment
Matt J on 15 May 2024
Edited: Matt J on 15 May 2024
My current workaround is to have cds fetch pairs (X,TW), where TW=[T,W] is the concatenation of the actual targets T and the desired weights W. I then decompose TW into T and W inside lossFcn(). However, I am wondering whether there is a more elegant way.
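A minimal sketch of that workaround, assuming a weighted squared error as the loss and assuming TW=[T,W] concatenates along dimension 2 (which is what [T,W] does for arrays of this shape). The name lossFromTW is hypothetical, and the check uses plain arrays rather than dlarray objects.

```matlab
% Hypothetical loss that receives TW = [T,W] (concatenated along dim 2)
% and splits it back into T (first size(Y,2) columns) and W (the rest).
lossFromTW = @(Y,TW) mean(TW(:, size(Y,2)+1:end, :, :) .* ...
    (Y - TW(:, 1:size(Y,2), :, :)).^2, 'all');

Y  = ones(2,3,1,4);
T  = zeros(2,3,1,4);
W  = 3*ones(2,3,1,4);
TW = [T, W];               % size 2 x 6 x 1 x 4
loss = lossFromTW(Y,TW)    % same as mean(W.*(Y-T).^2,'all'), i.e. 3
```

The split relies on Y and T having the same size, so size(Y,2) marks where T ends and W begins inside TW.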


Accepted Answer

Matt J on 29 May 2024
Edited: Matt J on 29 May 2024
Tech support told me that the way to do this is to supply N format strings in the TargetDataFormats training option. trainnet() then treats the last N datastores in the combined datastore as training targets and passes them to the loss function, here as lossFcn(Y,T,W). Example:
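Before training, it can help to check what the combined datastore actually returns. This base-MATLAB sketch (the data values are illustrative) shows that combine() accepts three datastores and that each read returns one observation as a row of three cells:

```matlab
% combine() is not limited to two datastores: each read returns the
% underlying reads concatenated horizontally.
e = (1:30)';
X = arrayDatastore(e);        % predictors (illustrative data)
T = arrayDatastore(10*e);     % targets (illustrative data)
W = arrayDatastore(100*e);    % weights (illustrative data)
cds = combine(X, T, W);
row = read(cds)               % 1x3 cell: {1, 10, 100}
```

With two entries in TargetDataFormats, the last two of these columns (T and W) would be treated as targets and the first (X) as the network input.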
net=dlnetwork( [ imageInputLayer([1,1],Name='In'), fullyConnectedLayer( 1, Name='FCLayer' ) ] );
e=(1:30)'; %6 batches/iterations with MiniBatchSize=5
netTrained = trainnet(combine(X,T,W) ,net,@lossFcn,options);
    Iteration    Epoch    TimeElapsed    LearnRate    TrainingLoss
    _________    _____    ___________    _________    ____________
            1        1       00:00:01        0.001            7943
            6        1       00:00:02        0.001          8141.1
Training stopped: Max epochs completed
function out=lossFcn(Y,T,W)
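A self-contained variant of the example might look like the sketch below. It assumes Deep Learning Toolbox R2023b or later (for trainnet). The datastore contents, the training options other than TargetDataFormats, the "CB" target format (chosen to match the fully connected layer's output), and the weighted loss body are illustrative assumptions, not the answer's own settings, so the loss values will differ from the log above.

```matlab
% Same tiny network as in the answer: 1x1 image input -> one FC output.
net = dlnetwork([imageInputLayer([1,1],Name='In'), fullyConnectedLayer(1,Name='FCLayer')]);
e = (1:30)';                        % 30 observations -> 6 iterations with MiniBatchSize=5
X = arrayDatastore(e);              % predictors
T = arrayDatastore(e);              % targets (assumed values)
W = arrayDatastore(ones(size(e)));  % constant weights (assumed values)

% Two TargetDataFormats entries: the last two datastores (T and W) are targets.
options = trainingOptions('adam', MaxEpochs=1, MiniBatchSize=5, ...
    TargetDataFormats=["CB","CB"], Verbose=false, Plots="none");

% Assumed weighted loss; trainnet calls it as lossFcn(Y,T,W).
lossFcn = @(Y,T,W) mean(W .* (Y - T).^2, 'all');
netTrained = trainnet(combine(X,T,W), net, lossFcn, options);
```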




