# attention

Dot-product attention

## Syntax

``Y = attention(queries,keys,values,numHeads)``
``[Y,weights] = attention(queries,keys,values,numHeads)``
``[Y,weights] = attention(queries,keys,values,numHeads,DataFormat=FMT)``
``[Y,weights] = attention(queries,keys,values,numHeads,Name=Value)``

## Description

The attention operation focuses on parts of the input using weighted multiplication operations.


`Y = attention(queries,keys,values,numHeads)` applies the dot-product attention operation to the specified queries, keys, and values using the number of attention heads `numHeads`. The input argument `queries` must be a formatted `dlarray` object.


`[Y,weights] = attention(queries,keys,values,numHeads)` applies the dot-product attention operation and also returns the attention weights.


`[Y,weights] = attention(queries,keys,values,numHeads,DataFormat=FMT)` applies the dot-product attention operation to the unformatted `dlarray` object `queries` with the format specified by `FMT`. For example, `DataFormat="CBT"` specifies data with format `"CBT"` (channel, batch, time).


`[Y,weights] = attention(queries,keys,values,numHeads,Name=Value)` specifies additional options using one or more name-value arguments. For example, `DropoutProbability=0.01` specifies a dropout probability of 0.01.

## Examples


### Apply Attention Operation

Specify the sizes of the queries, keys, and values.

```
querySize = 100;
valueSize = 120;
numQueries = 64;
numValues = 80;
numObservations = 32;
```

Create random arrays containing the queries, keys, and values. For the queries, specify the `dlarray` format `"CBT"` (channel, batch, time).

```
queries = dlarray(rand(querySize,numObservations,numQueries),"CBT");
keys = dlarray(rand(querySize,numObservations,numValues));
values = dlarray(rand(valueSize,numObservations,numValues));
```

Specify the number of attention heads.

`numHeads = 5;`

Apply the attention operation.

`[Y,weights] = attention(queries,keys,values,numHeads);`

View the sizes and format of the output.

`size(Y)`
```
ans = 1×3

   120    32    64
```
`dims(Y)`
```
ans = 'CBT'
```

View the sizes and format of the weights.

`size(weights)`
```
ans = 1×4

    80    64     5    32
```
`dims(weights)`
```
ans =

  0×0 empty char array
```

### Multihead Self Attention

You can use the `attention` function to implement the multihead self-attention operation [1], which focuses on parts of the input.

Create the `multiheadSelfAttention` function, listed in the Multihead Self Attention Function section of the example. The `multiheadSelfAttention` function takes as input the input data `X`, the number of heads, and the learnable weights for the queries, keys, values, and output data, and returns the multihead attention values.

The input `X` must be an unformatted `dlarray` object, where the first dimension corresponds to the input channels, the second dimension corresponds to the time or spatial dimension, and the third dimension corresponds to the batch dimension.

Create an array of sequence data.

```
numChannels = 10;
numObservations = 128;
numTimeSteps = 100;

X = rand(numChannels,numTimeSteps,numObservations);
X = dlarray(X);
size(X)
```
```
ans = 1×3

    10   100   128
```

Specify the number of attention heads.

`numHeads = 8;`

Initialize the learnable parameters for multihead attention.

• The learnable query, key, and value weights must be `(numChannels*numHeads)`-by-`numChannels` arrays.

• The learnable output weights must be a `(numChannels*numHeads)`-by-`(numChannels*numHeads)` array.

```
outputSize = numChannels*numHeads;
WQ = rand(outputSize,numChannels);
WK = rand(outputSize,numChannels);
WV = rand(outputSize,numChannels);
WO = rand(outputSize,outputSize);
```

Apply the multihead self-attention operation.

`Y = multiheadSelfAttention(X,numHeads,WQ,WK,WV,WO);`

View the size of the output. The output has size `(numChannels*numHeads)`-by-`numTimeSteps`-by-`numObservations`.

`size(Y)`
```
ans = 1×3

    80   100   128
```

#### Multihead Self Attention Function

The `multiheadSelfAttention` function takes as input the input data `X`, the number of heads, and the learnable weights for the queries, keys, values, and output data, and returns the multihead attention values.

• The input `X` must be an unformatted `dlarray` object, where the first dimension corresponds to the input channels, the second dimension corresponds to the time or spatial dimension, and the third dimension corresponds to the batch dimension.

• The learnable query, key, and value weight matrices are `(numChannels*numHeads)`-by-`numChannels` matrices.

• The learnable output weights matrix is a `(numChannels*numHeads)`-by-`(numChannels*numHeads)` matrix.

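```
function Y = multiheadSelfAttention(X,numHeads,WQ,WK,WV,WO)

% Project the input to the queries, keys, and values.
queries = pagemtimes(WQ,X);
keys = pagemtimes(WK,X);
values = pagemtimes(WV,X);

% X is unformatted, so specify the format (channel, time, batch).
A = attention(queries,keys,values,numHeads,DataFormat="CTB");

% Apply the output projection.
Y = pagemtimes(WO,A);

end
```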

### Luong Attention

You can use the `attention` function to create a function that applies the Luong attention operation to its input. Create the `luongAttention` function, listed at the end of the example.

Specify the array sizes.

```
numHiddenUnits = 100;
latentSize = 16;
```

Create random arrays containing the input data.

```
hiddenState = dlarray(rand(numHiddenUnits,1));
Z = dlarray(rand(latentSize,1));
weights = dlarray(rand(numHiddenUnits,latentSize));
```

Apply the `luongAttention` function.

`[context,attentionScores] = luongAttention(hiddenState,Z,weights);`

View the sizes of the outputs.

`size(context)`
```
ans = 1×2

    16     1
```
`size(attentionScores)`
```
ans = 1×2

     1     1
```

#### Luong Attention Function

The `luongAttention` function returns the context vector and attention scores according to the Luong "general" scoring [2]. This is equivalent to dot-product attention with queries, keys, and values specified as the hidden state, the weighted latent representation, and the latent representation, respectively.

```
function [context,attentionScores] = luongAttention(hiddenState,Z,weights)

% Luong "general" attention as single-head dot-product attention.
numHeads = 1;
queries = hiddenState;
keys = pagemtimes(weights,Z);
values = Z;

% Scale=1 leaves the dot products unscaled.
[context,attentionScores] = attention(queries,keys,values,numHeads,Scale=1,DataFormat="CBT");

end
```
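To see what the "general" score computes when there is more than one source state, the following plain-MATLAB sketch (it does not call `attention`) scores several latent vectors against one hidden state with $h^\top W z_s$, normalizes the scores with a softmax, and forms the context vector as the weighted sum of the latent vectors. The sizes and the number of source states, `numSource`, are made up for illustration.

```matlab
% Made-up sizes: hidden size 100, latent size 16, 5 source states.
numHiddenUnits = 100; latentSize = 16; numSource = 5;
h = rand(numHiddenUnits,1);              % hidden state (the query)
Z = rand(latentSize,numSource);          % one latent vector per column
W = rand(numHiddenUnits,latentSize);     % "general" score weights

% Luong "general" score for each source state: h' * W * z_s.
scores = h.'*(W*Z);                      % 1-by-numSource

% Softmax over the source states (no scaling, as with Scale=1).
a = exp(scores - max(scores));
a = a/sum(a);

% Context vector: attention-weighted sum of the latent vectors.
context = Z*a.';                         % latentSize-by-1
```

With a single latent vector, as in the example above, the softmax of one score is exactly 1, so the context vector should equal `Z`. This is consistent with the 1-by-1 `attentionScores` output shown above.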

## Input Arguments


### `queries`

Queries, specified as a `dlarray` object.

`queries` can have at most one `"S"` (spatial) or `"T"` (time) dimension. Any dimensions in `queries` labeled `"U"` (unspecified) must be singleton. If `queries` is an unformatted `dlarray` object, then specify the data format using the `DataFormat` option.

The size of the `"C"` (channel) dimension in `keys` must match the size of the corresponding dimension in `queries`.

The size of the `"B"` (batch) dimension in `queries`, `keys`, and `values` must match.

### `keys`

Keys, specified as a `dlarray` object or a numeric array.

If `keys` is a formatted `dlarray` object, then its format must match the format of `queries`. If `keys` is not a formatted `dlarray`, then the function uses the same format as `queries`.

The size of any `"S"` (spatial) or `"T"` (time) dimensions in `keys` must match the size of the corresponding dimension in `values`.

The size of the `"C"` (channel) dimension in `keys` must match the size of the corresponding dimension in `queries`.

The size of the `"B"` (batch) dimension in `queries`, `keys`, and `values` must match.

### `values`

Values, specified as a `dlarray` object or a numeric array.

If `values` is a formatted `dlarray` object, then its format must match the format of `queries`. Otherwise, the function uses the same format as `queries`.

The size of any `"S"` (spatial) or `"T"` (time) dimensions in `keys` must match the size of the corresponding dimension in `values`.

The size of the `"B"` (batch) dimension in `queries`, `keys`, and `values` must match.

### `numHeads`

Number of heads, specified as a positive integer. The value of `numHeads` must evenly divide the size of the `"C"` (channel) dimension of `queries`, `keys`, and `values`.
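Because `numHeads` must evenly divide each channel size, every head works on the same number of channels. This sketch checks the rule for the channel sizes used in the first example (100 channels for the queries and keys, 120 for the values, 5 heads) and computes the per-head channel counts. It is plain arithmetic and assumes nothing about how `attention` splits the channels internally.

```matlab
% Channel sizes from the first example.
numHeads = 5;
querySize = 100;   % "C" size of queries and keys
valueSize = 120;   % "C" size of values

% numHeads must evenly divide each "C" dimension.
assert(mod(querySize,numHeads) == 0 && mod(valueSize,numHeads) == 0, ...
    "numHeads must evenly divide the channel sizes.")

% Channels handled by each head.
queryChannelsPerHead = querySize/numHeads;   % 20
valueChannelsPerHead = valueSize/numHeads;   % 24
```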

### Name-Value Arguments

Specify optional pairs of arguments as `Name1=Value1,...,NameN=ValueN`, where `Name` is the argument name and `Value` is the corresponding value. Name-value arguments must appear after other arguments, but the order of the pairs does not matter.

Before R2021a, use commas to separate each name and value, and enclose `Name` in quotes.

Example: `attention(queries,keys,values,numHeads,DataFormat="CBT")` applies the attention operation for unformatted data and specifies the data format `"CBT"` (channel, batch, time).

#### `DataFormat`

Dimension order of unformatted input data, specified as a character vector or string scalar `FMT` that provides a label for each dimension of the data.

When you specify the format of a `dlarray` object, each character provides a label for each dimension of the data and must be one of the following:

• `"S"` — Spatial

• `"C"` — Channel

• `"B"` — Batch (for example, samples and observations)

• `"T"` — Time (for example, time steps of sequences)

• `"U"` — Unspecified

You can use each of the labels `"C"` and `"B"` at most once, and you can label at most one dimension as either `"S"` or `"T"`.

You must specify `DataFormat` when the input data is not a formatted `dlarray`.

Data Types: `char` | `string`
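The labeling rules above can be checked before calling `attention`. The following sketch defines a hypothetical helper, `isValidAttentionFormat`, which is not part of Deep Learning Toolbox. It encodes only the constraints listed above, and the function itself may apply further checks.

```matlab
% Hypothetical helper (not a toolbox function). It checks only the labeling
% rules listed above: every label is one of S, C, B, T, U; "C" and "B" each
% appear at most once; and at most one dimension is labeled "S" or "T".
isValidAttentionFormat = @(fmt) all(ismember(char(fmt),'SCBTU')) ...
    && sum(char(fmt) == 'C') <= 1 ...
    && sum(char(fmt) == 'B') <= 1 ...
    && sum(ismember(char(fmt),'ST')) <= 1;

validFormat = isValidAttentionFormat("CBT");      % true
invalidFormat = isValidAttentionFormat("SCBT");   % false: both "S" and "T"
```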

#### `Scale`

Multiplicative factor for scaled dot-product attention [1], specified as one of these values:

• `"auto"` — Multiply the dot product by $\lambda = \frac{1}{\sqrt{d_k}}$, where $d_k$ denotes the number of channels in the keys divided by the number of heads.

• Numeric scalar — Multiply the dot-product by the specified scale factor.

Data Types: `single` | `double` | `char` | `string`
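For `"auto"`, the scale depends on the per-head key size $d_k$, not on the total number of key channels. A minimal sketch with the sizes from the first example (100 key channels, 5 heads), applying the factor to a single query-key dot product; the vectors `q` and `k` are random placeholders.

```matlab
% Sizes from the first example: 100 key channels split over 5 heads.
numKeyChannels = 100;
numHeads = 5;

dk = numKeyChannels/numHeads;   % channels per head in the keys: 20
lambda = 1/sqrt(dk);            % scale factor for "auto"

% Scale one dot product between a query and a key (each dk-by-1).
q = rand(dk,1);
k = rand(dk,1);
scaledScore = lambda*(q.'*k);
```

The Luong attention example passes `Scale=1` instead, which leaves the dot products unscaled.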

#### `PaddingMask`

Mask indicating which elements of the input correspond to padding values, specified as a `dlarray` object, a logical array, or a numeric array consisting of 0 and 1 values.

The function prevents attention to a key-value pair when the corresponding element of `PaddingMask` is 0 and allows attention when it is 1.

If `PaddingMask` is a formatted `dlarray`, then its format must match that of `keys`. If `PaddingMask` is not a formatted `dlarray`, then the function uses the same format as `keys`. The size of the `"S"` (spatial), `"T"` (time), and `"B"` (batch) dimensions in `PaddingMask` must match the size of the corresponding dimensions in `keys` and `values`.

The default value is a logical array of ones with the same size as `keys`.
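In practice, a padding mask usually comes from the true length of each sequence in a padded mini-batch. The sketch below builds such a mask in plain MATLAB for hypothetical keys in `"CBT"` format (4 channels, 3 observations, 6 time steps) and hypothetical lengths `sequenceLengths`. It makes the mask the same size as the keys, matching the default described above; you would then pass it as `PaddingMask=paddingMask`.

```matlab
% Hypothetical padded mini-batch: 3 sequences padded to 6 time steps,
% keys in "CBT" format with 4 channels.
numChannels = 4;
sequenceLengths = [6 4 2];                    % true length of each observation
numObservations = numel(sequenceLengths);
numTimeSteps = max(sequenceLengths);

% B-by-T logical matrix: true for real time steps, false for padding.
timeIdx = 1:numTimeSteps;                     % 1-by-T
validSteps = timeIdx <= sequenceLengths(:);   % B-by-T by implicit expansion

% Reshape to 1-by-B-by-T and repeat across channels so that the mask has
% the same size as the keys (C-by-B-by-T).
paddingMask = repmat(reshape(validSteps,1,numObservations,numTimeSteps), ...
    numChannels,1,1);
```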

#### `AttentionMask`

Attention mask indicating which elements to include when applying the attention operation, specified as one of these values:

• `"none"` — Do not prevent attention to elements with respect to their positions. If `AttentionMask` is `"none"`, then the software prevents attention using `PaddingMask` only.

• `"causal"` — Prevent the element in position `m` in the `"S"` (spatial) or `"T"` (time) dimension of `queries` from attending to elements in positions `n > m` in the corresponding dimension of `keys` and `values`. Use this option for autoregressive models.

• Logical or numeric array — Prevent attention to elements of `keys` and `values` when the corresponding element in the array is `0`. Specify `AttentionMask` as an `Nk`-by-`Nq` matrix or an `Nk`-by-`Nq`-by-`numObservations` array, where `Nk` is the size of the `"S"` (spatial) or `"T"` (time) dimension of `keys`, `Nq` is the size of the corresponding dimension in `queries`, and `numObservations` is the size of the `"B"` (batch) dimension in `queries`.

Data Types: `single` | `double` | `int8` | `int16` | `int32` | `int64` | `uint8` | `uint16` | `uint32` | `uint64` | `logical` | `char` | `string`
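When the keys and queries have the same number of time steps, you can write the `"causal"` rule as an explicit logical array. This sketch assumes `Nk` equals `Nq` and builds an `Nk`-by-`Nq` mask whose element (n,m) is `true` only when n <= m, then repeats it for each observation. It follows the rule stated above, but it is a hand-built illustration, not a copy of the mask the function creates internally.

```matlab
% Self-attention case: keys and queries have the same number of time steps.
N = 5;
Nk = N;    % size of the "T" dimension of keys
Nq = N;    % size of the "T" dimension of queries

% Row n indexes key positions and column m indexes query positions.
% triu keeps element (n,m) only when n <= m, so query m can attend to key
% positions up to and including its own position.
causalMask = triu(true(Nk,Nq));

% The same mask for every observation: Nk-by-Nq-by-numObservations.
numObservations = 3;
causalMaskBatch = repmat(causalMask,1,1,numObservations);
```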

#### `DropoutProbability`

Dropout probability for the attention weights, specified as a nonnegative scalar less than 1.

Data Types: `single` | `double`

## Output Arguments


### `Y`

Output data, returned as a `dlarray` object.

If `queries` is a formatted `dlarray` object, then `Y` is a formatted `dlarray` object with the same dimension labels as `queries`. The size of the `"C"` (channel) dimension of `Y` is the same as the size of the corresponding dimension in `values`. The size of the `"S"` (spatial) or `"T"` (time) dimension of `Y` is the same as the size of the corresponding dimension in `queries`.

If `queries` is not a formatted `dlarray` object, then `Y` is an unformatted `dlarray` object.

### `weights`

Attention weights, returned as an unformatted `dlarray` object.

`weights` is an `Nk`-by-`Nq`-by-`numHeads`-by-`numObservations` array, where `Nk` is the size of the `"S"` (spatial) or `"T"` (time) dimension of `keys`, `Nq` is the size of the corresponding dimension in `queries`, and `numObservations` is the size of the `"B"` (batch) dimension in `queries`.
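Applying the size rules for `Y` and `weights` by hand to the first example reproduces the sizes that `size(Y)` and `size(weights)` display there. This is plain arithmetic, not a call to `attention`.

```matlab
% Sizes from the first example (format "CBT").
valueSize = 120;        % "C" size of values
numQueries = 64;        % "T" size of queries
numValues = 80;         % "T" size of keys and values
numObservations = 32;   % "B" size of queries, keys, and values
numHeads = 5;

% Y keeps the labels of queries ("CBT"): "C" from values, "B" and "T" from queries.
sizeY = [valueSize numObservations numQueries];

% weights is Nk-by-Nq-by-numHeads-by-numObservations.
Nk = numValues;
Nq = numQueries;
sizeWeights = [Nk Nq numHeads numObservations];
```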

## Algorithms


### Dot-Product Attention

The attention operation focuses on parts of the input using weighted multiplication operations.

The single-head dot-product attention operation is given by

$$\text{attention}(Q,K,V)=\text{dropout}\left(\text{softmax}\left(\text{mask}\left(\lambda Q K^{\top}, M\right)\right), p\right) V,$$

where $Q$, $K$, and $V$ correspond to the queries, keys, and values, respectively, $\lambda$ denotes the scaling factor, $M$ is a mask array of ones and zeros, and $p$ is the dropout probability. The mask operation sets each element of the matrix product to $-\infty$ where the corresponding element of $M$ is zero and leaves the other elements unchanged, so the softmax assigns zero weight to the excluded elements. $M$ combines the padding and attention masks: an element is excluded if either mask excludes it. The dropout operation sets elements to zero with probability $p$.
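The following plain-MATLAB sketch implements the single-head formula above without calling `attention`, so that each step is visible. It follows the row convention of the formula (one query, key, or value per row), so its mask is `Nq`-by-`Nk`, whereas the `AttentionMask` argument and the `weights` output use an `Nk`-by-`Nq` layout. It assumes the softmax is taken over the keys, as in [1], omits dropout ($p = 0$), and uses made-up sizes.

```matlab
% Made-up sizes: 4 queries, 6 key-value pairs, dk = 8, dv = 3.
Nq = 4; Nk = 6; dk = 8; dv = 3;
Q = rand(Nq,dk);        % one query per row
K = rand(Nk,dk);        % one key per row
V = rand(Nk,dv);        % one value per row
lambda = 1/sqrt(dk);    % "auto" scale for a single head

% Padding mask: treat the last key-value pair as padding.
paddingMask = true(Nq,Nk);
paddingMask(:,end) = false;

% Attention mask: no position-based restriction (like "none").
attentionMask = true(Nq,Nk);

% Combined mask M: an element is excluded if either mask excludes it.
M = paddingMask & attentionMask;

% mask(lambda*Q*K', M): excluded scores become -Inf.
S = lambda*(Q*K.');
S(~M) = -Inf;

% Softmax over the keys (each row), shifted by the row maximum for stability.
E = exp(S - max(S,[],2));
A = E./sum(E,2);        % Nq-by-Nk weights, each row sums to 1

% With p = 0 the dropout step is the identity.
Y = A*V;                % Nq-by-dv output
```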

The multihead self attention operation for the input X is given by

$$\text{multiheadSelfAttention}\left(X,h,W^{Q},W^{K},W^{V},W^{O}\right)=\text{concatenate}\left(\text{head}_{1},\dots,\text{head}_{h}\right)W^{O},$$

where $h$ is the number of heads and $W^Q$, $W^K$, $W^V$, and $W^O$ are the learnable projection matrices for the queries, keys, values, and output, respectively. Each of $W^Q$, $W^K$, and $W^V$ is composed of the concatenated per-head weight matrices $W_i^Q$, $W_i^K$, and $W_i^V$. Each $\text{head}_i$ denotes the output of the head operation given by

$$\text{head}_{i}=\text{attention}\left(XW_{i}^{Q},XW_{i}^{K},XW_{i}^{V}\right).$$
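A minimal plain-MATLAB sketch of the multihead formulas, in the same row convention (one position per row of `X`). It assumes that $W_i^Q$, $W_i^K$, and $W_i^V$ are consecutive column blocks of $W^Q$, $W^K$, and $W^V$, uses the `"auto"` scale $1/\sqrt{d_k}$, and omits masking and dropout. The sizes and random weights are placeholders, and the `multiheadSelfAttention` function in the examples uses a channels-first layout instead.

```matlab
% Made-up sizes: N = 5 positions, model width d = 6, h = 2 heads.
N = 5; d = 6; h = 2;
dHead = d/h;                 % channels per head
X = rand(N,d);               % one position per row
WQ = rand(d,d); WK = rand(d,d); WV = rand(d,d); WO = rand(d,d);

% Single-head scaled dot-product attention with a softmax over the keys
% (no masking or dropout in this sketch).
softmaxRows = @(S) exp(S - max(S,[],2)) ./ sum(exp(S - max(S,[],2)),2);
attn = @(Q,K,V) softmaxRows((Q*K.')/sqrt(size(K,2)))*V;

heads = cell(1,h);
for i = 1:h
    % Columns of WQ, WK, and WV that belong to head i (W_i in the text).
    cols = (i-1)*dHead + (1:dHead);
    heads{i} = attn(X*WQ(:,cols),X*WK(:,cols),X*WV(:,cols));
end

% Concatenate the heads along the channels and apply the output projection.
Y = [heads{:}]*WO;           % N-by-d
```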

## References

[1] Vaswani, Ashish, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, and Illia Polosukhin. "Attention is all you need." Advances in neural information processing systems 30 (2017).

[2] Luong, Minh-Thang, Hieu Pham, and Christopher D. Manning. "Effective approaches to attention-based neural machine translation." arXiv preprint arXiv:1508.04025 (2015).

## Version History

Introduced in R2022b