
Retrieving Layer Activations from bertDocumentClassifier (Text Analytics Toolbox)

Hi,
I started using the Text Analytics Toolbox and successfully trained a bertDocumentClassifier network on my dataset.
In the past I've used the 'activations' function successfully to extract layer activations from dlNetworks.
However, I cannot get the activations function to work with a bertDocumentClassifier. Unlike, for example, image deep learning network objects, it has a tokenizer in front of the network.
For example, out=activations(bertTrained,textstring,layername) does not work.
I tried applying the tokenizer first, for example:
[a,b]=encode(mdl.Tokenizer,textDataTrain(1,:))
This returns the token codes and segments in a and b as expected.
But how do I "feed" those to the dlnetwork inside the bertDocumentClassifier object?
For example, this does NOT work:
net=mdl.Network;
activations(net,a,b,'out_fc2')
Variations of this fail as well.
To sum up: I have a trained BERT classifier object and can use it to classify just fine, but I can't get the network's layer activations.
Thanks!

Accepted Answer

Malay Agarwal on 28 Mar 2024
Edited: Malay Agarwal on 29 Mar 2024
Hi Tsvi,
I understand that you want to retrieve layer activations from a trained “bertDocumentClassifier” model.
There are two reasons why the “activations” function is not working as expected.
First, the “InputNames” property of the underlying network for the model shows that the model accepts three inputs instead of two. Namely, it expects the input IDs, an attention mask, and the segment IDs.
In your code, you are calling the function with only two inputs, the input IDs and the segment IDs.
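You can confirm the expected inputs by inspecting the underlying network before calling it. A minimal sketch: it builds an untrained default classifier with bertDocumentClassifier (which assumes the BERT-Base support package is installed) only so the snippet runs on its own; replace mdl with your trained model.

```matlab
% mdl: untrained default classifier, used only so the snippet runs on its
% own (assumes the BERT-Base support package); replace with your model.
mdl = bertDocumentClassifier;
net = mdl.Network;            % underlying dlnetwork

% Input names, in the order predict and forward expect the inputs
disp(net.InputNames)

numInputs = numel(net.InputNames);
fprintf("The network expects %d inputs.\n", numInputs);
```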
Second, the “activations” function only works with networks represented as “DAGNetwork” objects or “SeriesNetwork” objects, as specified in the documentation: https://www.mathworks.com/help/releases/R2023b/deeplearning/ref/seriesnetwork.activations.html#d126e5157.
The underlying network for “bertDocumentClassifier” is a “dlnetwork” object: https://www.mathworks.com/help/releases/R2023b/textanalytics/ref/bertdocumentclassifier.html#mw_2480ef12-2a75-480d-aec3-eefb236d8afe. For such objects, use the “predict” or the “forward” function, depending on whether you want the output for inference or for training. Both functions accept an “Outputs” option that names the layer whose activations you want.
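The value passed to “Outputs” must be the name of a layer in the network (or "layerName/outputName" for a layer with several outputs). A small sketch for listing the candidates and checking a name before calling predict; as above, the untrained default classifier is only a stand-in for your trained model, and "out_fc2" is simply the name used in the question.

```matlab
% mdl: untrained default classifier, used only so the snippet runs on its
% own; replace with your trained bertDocumentClassifier.
mdl = bertDocumentClassifier;
net = mdl.Network;

% Every layer name can be passed to the "Outputs" option of predict/forward
layerNames = string({net.Layers.Name})';
disp(layerNames)

% Check that a requested layer exists before calling predict
requested = "out_fc2";   % layer name taken from the question
fprintf("Layer %s found: %d\n", requested, any(layerNames == requested));
```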
Please try the following code. I am assuming you want the model output for inference, so the code uses the “predict” function. If you want the output for training, change the “predict” call to a “forward” call; no other changes are required:
% Extract the network
net = mdl.Network;
% Extract an example and encode it
example = textDataTrain(1, :);
[tokens, segments] = encode(mdl.Tokenizer, example);
% tokens and segments are cell arrays containing one row vector each,
% so extract the vectors
tokens = tokens{1};
segments = segments{1};
% Number of tokens in the encoded example
numTokens = size(tokens, 2);
% Convert the tokens and segments to formatted dlarray objects
% with channel, time, and batch dimensions ("CTB")
tokens = dlarray(tokens, "CTB");
segments = dlarray(segments, "CTB");
% Create an attention mask of ones (1 = attend to the token); a single
% example has no padding, so every token is attended to
attentionMask = dlarray(ones(1, numTokens), "CTB");
% Use predict function to get the output of layer 'out_fc2'
output = predict(net, tokens, attentionMask, segments, 'Outputs', 'out_fc2');
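To get activations for several documents at once, the encoded sequences have different lengths and must be padded, which is where the attention mask matters (1 for real tokens, 0 for padding). A minimal sketch, assuming the input order reported by InputNames is token IDs, attention mask, segment IDs, as described above. The example strings, the segment padding value, and the placeholder layer choice are illustrative assumptions, and the untrained default classifier stands in for your trained model.

```matlab
% Sketch: activations for a padded mini-batch of documents.
% mdl: untrained default classifier, used only so the snippet runs on
% its own; replace it with your trained bertDocumentClassifier.
mdl = bertDocumentClassifier;
net = mdl.Network;
tokenizer = mdl.Tokenizer;

% Hypothetical example documents; replace with your own text data
str = ["A short example document." ...
       "A second, somewhat longer example document for the batch."];

% encode returns cell arrays with one row vector per document
[tokens, segments] = encode(tokenizer, str);

% Pad along the time dimension (2); documents are stacked along dimension 3.
% The second output of padsequences is 1 for real tokens, 0 for padding.
[tokens, mask] = padsequences(tokens, 2, "PaddingValue", tokenizer.PaddingCode);
% Segment padding value is an assumption; padded positions are masked out
segments = padsequences(segments, 2, "PaddingValue", 1);

% Formatted dlarray inputs: channel, time, batch
X = dlarray(tokens, "CTB");
M = dlarray(double(mask), "CTB");
S = dlarray(segments, "CTB");

% Placeholder: the last layer; substitute an intermediate layer name
% such as "out_fc2" from the question
layerName = net.Layers(end).Name;

% Input order assumed to match net.InputNames: ids, mask, segments
Y = predict(net, X, M, S, "Outputs", layerName);
act = extractdata(Y);   % plain numeric array for further analysis
disp(size(act))
```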
Hope this helps!

More Answers (1)

tsvi lev on 30 Mar 2024
Excellent answer - logical, informed and with working code.
Thank you!
