Introduction
In short, an attention-based model “focuses” on the elements of its input (the words of a sentence, the positions of an image, etc.). “Focusing” means assigning different levels of attention: each element of the input receives its own weight and therefore influences the result to a different degree, whereas a non-attention model treats every element “equally”.
Attention Mechanism
In natural language processing (NLP), the attention mechanism significantly improved the performance of encoder-decoder based neural machine translation systems. As the name implies, the attention mechanism is designed to mimic the way humans look at objects. When looking at a picture, for example, people not only grasp the picture as a whole, they also pay more attention to particular parts of it, such as the location of a table or the category of a product. Similarly, when people translate a passage, they usually work sentence by sentence, and while reading a sentence they need to focus on the information carried by the words themselves, on the relationships between neighbouring words, and on the surrounding context. In NLP, sentiment classification likewise hinges on the words that express sentiment in a given sentence, including but not limited to “happy”, “frustrated” and “knackered”. The other words in the sentence are contextual: not useless, but they do not play as large a role as the emotive keywords. Under this description, the attention mechanism actually consists of two parts.
- Deciding which parts of the whole input need more attention.
- Extracting features from those key parts to obtain the important information.
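As a minimal sketch of these two parts, consider the sentiment example above: most of the weight goes to the emotive keyword, and the sentence representation is a weighted sum of the token features. The feature vectors and weights below are made up for illustration, not produced by any trained model.

import torch

# Hypothetical sentiment sentence; vectors and weights are invented for illustration.
tokens = ["i", "am", "completely", "knackered"]
features = torch.randn(4, 3)                   # one 3-d feature vector per token

# Part 1: decide where to look -- most weight on the emotive keyword.
weights = torch.tensor([0.05, 0.05, 0.10, 0.80])

# Part 2: extract information from the important parts -- a weighted sum.
sentence_representation = weights @ features   # shape: (3,)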
Let’s take a sentence as an example. We have the context “you had me at hello”, and we already have the embedding vector for each token. As shown in the figure, the sequence length is 5 (five tokens) and the embedding dimension is 3 (the size of each vector). The embedding vector of each token in the input sequence is fixed and does not contain any contextualised information. The purpose of self-attention is to obtain a new embedding vector for each token by computing the dependencies between representations while taking the context into account (a self-attention model allows the inputs to interact with each other). Assume we want to calculate the contextualised embedding for “hello”, so we use the formula below.
$$y_5 = \sum_{j=1}^{5} \alpha_{5,j}\, x_j, \qquad \alpha_{5,j} = \frac{\exp(x_5 \cdot x_j)}{\sum_{k=1}^{5} \exp(x_5 \cdot x_k)}$$
where $x_j$ is the embedding vector of the $j$-th token, $x_5$ is the embedding of “hello”, $\alpha_{5,j}$ is the attention weight that “hello” assigns to token $j$, and $y_5$ is the new contextualised embedding of “hello”.
After computing all the attention weights for the token “hello”, we take the weighted sum to obtain its contextualised embedding vector. Since the weights come from a softmax, they are non-negative and sum to 1, so the new embedding is a convex combination of all the token embeddings in the sentence.
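To make the formula concrete, here is a minimal sketch that computes the attention weights and the contextualised embedding for “hello” directly, using the same 5×3 embedding matrix as the example in the code below.

import torch

# Embeddings for "you had me at hello": sequence length 5, dimension 3
# (the same example matrix used in the classes below).
x = torch.tensor([[0.6, 0.2, 0.8],
                  [0.2, 0.3, 0.1],
                  [0.9, 0.1, 0.8],
                  [0.4, 0.1, 0.4],
                  [0.4, 0.1, 0.6]])

hello = x[4]                             # embedding of "hello"
scores = x @ hello                       # dot product with every token
weights = torch.softmax(scores, dim=0)   # attention weights, sum to 1
hello_contextualised = weights @ x       # weighted sum -> new embedding
print(weights)
print(hello_contextualised)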
Non-Parametric Version of Self-Attention
import torch
import torch.nn as nn


class NonparametricSelfAttention(nn.Module):
    """
    Examples
    --------
    >>> context = torch.Tensor([
    ...     [
    ...         [0.6, 0.2, 0.8],
    ...         [0.2, 0.3, 0.1],
    ...         [0.9, 0.1, 0.8],
    ...         [0.4, 0.1, 0.4],
    ...         [0.4, 0.1, 0.6]
    ...     ]
    ... ])
    >>> context_, attention_weights = NonparametricSelfAttention(3)(context)
    >>> print("Input: ", context.shape)
    Input:  torch.Size([1, 5, 3])
    >>> print("Output: ", context_.shape)
    Output:  torch.Size([1, 5, 3])
    """

    def __init__(self, dimensions):
        super(NonparametricSelfAttention, self).__init__()
        self.dimensions = dimensions
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, context, return_weights=True):
        """
        context: [batch_size, sequence_length, embedding_dimension]
        """
        # Dot product between every pair of token embeddings -> [batch, seq, seq]
        attention_scores = torch.bmm(context, context.transpose(1, 2))
        # Normalise each row into attention weights that sum to 1
        attention_weights = self.softmax(attention_scores)
        # Weighted sum of the original embeddings -> contextualised embeddings
        context_ = torch.bmm(attention_weights, context)
        if return_weights:
            return context_, attention_weights
        return context_

    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
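A short usage sketch (reusing the example context from the docstring): the non-parametric version has no learnable weights, so count_parameters() returns 0, and every row of the attention weights sums to 1.

# Usage sketch, assuming torch and NonparametricSelfAttention above are in scope.
context = torch.Tensor([[[0.6, 0.2, 0.8],
                         [0.2, 0.3, 0.1],
                         [0.9, 0.1, 0.8],
                         [0.4, 0.1, 0.4],
                         [0.4, 0.1, 0.6]]])

attention = NonparametricSelfAttention(3)
context_, attention_weights = attention(context)
print(attention.count_parameters())    # 0 -- nothing to learn
print(attention_weights.sum(dim=-1))   # every row sums to 1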
Parametric Version of Self-Attention
import math

import torch
import torch.nn as nn


class ParametricSelfAttention(nn.Module):
    """
    Examples
    --------
    >>> context = torch.Tensor([
    ...     [
    ...         [0.6, 0.2, 0.8],
    ...         [0.2, 0.3, 0.1],
    ...         [0.9, 0.1, 0.8],
    ...         [0.4, 0.1, 0.4],
    ...         [0.4, 0.1, 0.6]
    ...     ]
    ... ])
    >>> context_, attention_weights = ParametricSelfAttention(3)(context)
    >>> print("Input: ", context.shape)
    Input:  torch.Size([1, 5, 3])
    >>> print("Output: ", context_.shape)
    Output:  torch.Size([1, 5, 3])
    """

    def __init__(self, dimensions):
        super(ParametricSelfAttention, self).__init__()
        self.dimensions = dimensions
        self.softmax = nn.Softmax(dim=-1)
        self.tanh = nn.Tanh()
        # Learnable projections for queries, keys, values and the output
        self.linear_q_in = nn.Linear(dimensions, dimensions, bias=False)
        self.linear_k_in = nn.Linear(dimensions, dimensions, bias=False)
        self.linear_v_in = nn.Linear(dimensions, dimensions, bias=False)
        self.linear_out = nn.Linear(dimensions, dimensions, bias=False)

    def forward(self, context, return_weights=True):
        """
        context: [batch_size, sequence_length, embedding_dimension]
        """
        # Project the input into query, key and value spaces
        context_q = self.linear_q_in(context)
        context_k = self.linear_k_in(context)
        context_v = self.linear_v_in(context)
        # Scaled dot-product scores -> [batch, seq, seq]
        attention_scores = torch.bmm(context_q, context_k.transpose(1, 2))
        attention_weights = self.softmax(attention_scores / math.sqrt(self.dimensions))
        # Weighted sum of the values, followed by a non-linearity and an output projection
        context_ = torch.bmm(attention_weights, context_v)
        context_ = self.tanh(context_)
        context_ = self.linear_out(context_)
        if return_weights:
            return context_, attention_weights
        return context_

    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
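For comparison, a short usage sketch of the parametric version: with dimensions = 3 it learns four 3×3 linear maps (query, key, value and output projections), i.e. 4 × 3 × 3 = 36 parameters, whereas the non-parametric version above has none.

# Usage sketch, assuming torch and ParametricSelfAttention above are in scope.
context = torch.randn(1, 5, 3)                  # (batch, sequence, dimension)
attention = ParametricSelfAttention(3)
context_, attention_weights = attention(context)
print(attention.count_parameters())             # 36
print(context_.shape)                           # torch.Size([1, 5, 3])
print(attention_weights.shape)                  # torch.Size([1, 5, 5])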
Conclusions
Self-attention was used to replace recurrence for context-aware encoding and for learning long-range dependencies (Vaswani et al. 2017). The length of the paths along which forward and backward signals have to travel through the network affects the ability to learn long-range relationships: the shorter these paths, the easier such dependencies are to learn.