
How to define a custom loss function and minimize it like tf.train.AdamOptimizer().minimize(loss)?

I feed X = [4,100] into an LSTM and map its last hidden state to a [4,1] vector, which I use to estimate Q. I want to minimize a custom loss function with the Adam optimizer, but the options offered by the built-in tools cannot express this loss. So I want to define my own loss function and minimize it with Adam. How can I do this, as in Python: train = tf.train.AdamOptimizer().minimize(loss)?
z=[0.01,0.05,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7,0.75,0.8,0.85,0.9,0.95,0.99];  % quantile levels
z_normal=norminv(z);  % standard-normal quantile of each level
numFeatures = 4;
A=4;
X = reshape(X_test.',4,100,[]);
Y = reshape(Y_test.',1,[]);
dlX = dlarray(X,'CBT');
dlY = dlarray(Y,'BT');
numHiddenUnits = 16;
outputdim=4;
H0 = zeros(numHiddenUnits,1);
C0 = zeros(numHiddenUnits,1);
weights = dlarray(randn(4*numHiddenUnits,numFeatures),'CU');
recurrentWeights = dlarray(randn(4*numHiddenUnits,numHiddenUnits),'CU');
bias = dlarray(randn(4*numHiddenUnits,1),'C');
[outputs,hiddenState,cellState] = lstm(dlX,H0,C0,weights,recurrentWeights,bias);
state_h=hiddenState(:,end);
w_out=normrnd(0,1,[outputdim,numHiddenUnits]);
b_out=normrnd(0,1,[outputdim,1]);
out=w_out*state_h+b_out;
params=out+[0;1;1;1];
mu=params(1,:);
sig=params(2,:);
utail=params(3,:);
vtail=params(4,:);
factor1=exp(z_normal.'*utail)/A+1;    % right-tail factor, controlled by utail
factor2=exp(-z_normal.'*vtail)/A+1;   % left-tail factor, controlled by vtail
factor=factor1.*factor2;
Q=factor.*z_normal.'.*sig+mu;
err=dlY-Q;                     % residual for each quantile level
err1=z.'.*err;
err2=(z.'-1).*err;
loss=mean(max(err1,err2));     % quantile (pinball) loss, averaged over the 21 levels
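Written out, the forward pass above computes a heavy-tailed quantile function of the standard-normal quantiles Z_tau and scores it with the quantile (pinball) loss. The formulas below are one reading of the code, assuming the second factor is meant to depend on vtail (v) with a negative exponent, so that u and v control the right and left tails respectively. Here (mu, sigma, u, v) are the four rows of params, A = 4, and y is the target.

```latex
Z_\tau = \Phi^{-1}(\tau), \qquad \tau \in \{0.01, 0.05, 0.10, \dots, 0.95, 0.99\}

Q(\tau) = \mu + \sigma\, Z_\tau \left(\frac{e^{u Z_\tau}}{A} + 1\right)\left(\frac{e^{-v Z_\tau}}{A} + 1\right)

\mathcal{L} = \frac{1}{|\mathcal{T}|} \sum_{\tau \in \mathcal{T}} \max\bigl(\tau\,(y - Q(\tau)),\; (\tau - 1)\,(y - Q(\tau))\bigr)
```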

Answers (1)

Sai Bhargav Avula on 18 Feb 2020
Hi,
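You can create a custom loss function as a function of the form loss = myLoss(Y,T), where Y are the network predictions and T are the targets. Call it inside your modelGradients function, compute the gradients of the loss with respect to the learnable parameters with dlgradient (evaluated through dlfeval), and then update the parameters with adamupdate.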
This link explains how to create custom layers and loss functions in MATLAB.
Hope this helps!
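A minimal sketch of that workflow, as a MATLAB counterpart of tf.train.AdamOptimizer().minimize(loss): dlfeval traces a modelGradients function, dlgradient differentiates the custom quantile loss with respect to every learnable parameter, and adamupdate applies one Adam step per iteration. It assumes Deep Learning Toolbox R2019b or later, random synthetic data in place of X_test/Y_test, and illustrative sizes, learning rate, and iteration count; the names p, quantileLoss, and modelGradients are hypothetical. The output-layer weights are dlarray objects so they are trained along with the LSTM, the loss is averaged to a scalar (dlgradient requires a scalar), erfcinv stands in for norminv, and the second tail factor is assumed to use v with a negative exponent.

```matlab
% Sketch (hypothetical names, synthetic data): minimize a custom quantile
% loss with Adam via dlfeval + dlgradient + adamupdate.
rng(0);
numFeatures = 4; numTimeSteps = 100; batchSize = 8;
numHiddenUnits = 16; outputDim = 4; A = 4;
z = [0.01, 0.05:0.05:0.95, 0.99];      % 21 quantile levels, as in the question
zNormal = -sqrt(2) * erfcinv(2 * z);    % equals norminv(z), base MATLAB only

% Synthetic sequences: channels x batch x time, one target per sequence
dlX = dlarray(randn(numFeatures, batchSize, numTimeSteps), 'CBT');
Y = randn(1, batchSize);

% All learnable parameters in one struct, like a set of TF trainable variables
p.W    = dlarray(0.1 * randn(4*numHiddenUnits, numFeatures), 'CU');
p.R    = dlarray(0.1 * randn(4*numHiddenUnits, numHiddenUnits), 'CU');
p.b    = dlarray(zeros(4*numHiddenUnits, 1), 'C');
p.Wout = dlarray(0.1 * randn(outputDim, numHiddenUnits));
p.bOut = dlarray(zeros(outputDim, 1));

avgGrad = []; avgSqGrad = [];           % Adam moment estimates, empty at start
learnRate = 1e-3;
for iter = 1:200
    % dlfeval traces the forward pass so dlgradient can differentiate it
    [grad, loss] = dlfeval(@modelGradients, p, dlX, Y, zNormal, z, A);
    % One Adam step on every field of p
    [p, avgGrad, avgSqGrad] = adamupdate(p, grad, avgGrad, avgSqGrad, iter, learnRate);
    if mod(iter, 50) == 0
        fprintf('iteration %d, loss %.4f\n', iter, extractdata(loss));
    end
end

function [grad, loss] = modelGradients(p, dlX, Y, zNormal, z, A)
    loss = quantileLoss(p, dlX, Y, zNormal, z, A);
    grad = dlgradient(loss, p);          % struct of gradients, same fields as p
end

function loss = quantileLoss(p, dlX, Y, zNormal, z, A)
    numHiddenUnits = size(p.R, 2);
    H0 = zeros(numHiddenUnits, 1);
    C0 = zeros(numHiddenUnits, 1);
    [~, h] = lstm(dlX, H0, C0, p.W, p.R, p.b);        % last hidden state, C x B
    out = p.Wout * stripdims(h) + p.bOut + [0; 1; 1; 1];
    mu = out(1,:); sig = out(2,:); u = out(3,:); v = out(4,:);
    zc = zNormal.';                                    % 21 x 1
    Q = mu + sig .* zc .* (exp(zc*u)/A + 1) .* (exp(-zc*v)/A + 1);  % 21 x B
    err = Y - Q;                                       % implicit expansion
    pinball = max(z.' .* err, (z.' - 1) .* err);      % 21 x B
    loss = mean(pinball(:));                           % scalar for dlgradient
end
```

In TensorFlow 1.x terms, the struct p plays the role of the trainable variables, and each loop iteration corresponds to one run of the train op.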

Categories: Deep Learning Toolbox

Version: R2019b
