Transformer (machine learning model)
The Transformer is a deep learning model introduced in 2017, used primarily in the field of natural language processing (NLP).[1]
Like recurrent neural networks (RNNs), Transformers are designed to handle sequential data, such as natural language, for tasks such as translation and text summarization. However, unlike RNNs, Transformers do not require that the sequential data be processed in order. For example, if the input data is a natural language sentence, the Transformer does not need to process the beginning of it before the end. Because of this, the Transformer allows much more parallelization than RNNs and therefore shorter training times.[1]
Since their introduction, Transformers have become the model of choice for tackling many problems in NLP, replacing older recurrent neural network models such as the long short-term memory (LSTM). Because the Transformer model facilitates more parallelization during training, it has enabled training on larger datasets than was possible before it was introduced. This has led to the development of pretrained systems such as BERT (Bidirectional Encoder Representations from Transformers) and GPT (Generative Pre-trained Transformer), which have been trained on very large general-language datasets and can be fine-tuned to specific language tasks.[2][3]
Background
Before the introduction of Transformers, most state-of-the-art NLP systems relied on gated recurrent neural networks (RNNs), such as LSTMs and gated recurrent units (GRUs), with added attention mechanisms. The Transformer built upon these attention technologies without using an RNN structure, showing that attention mechanisms alone, without recurrent sequential processing, are powerful enough to achieve the performance of RNNs with attention.
Gated RNNs process tokens sequentially, maintaining a state vector that contains a representation of the data seen so far and is updated after every token. To process the $n$th token, the model combines the state representing the sentence up to token $n-1$ with the information of the new token to create a new state, representing the sentence up to token $n$. Theoretically, the information from one token can propagate arbitrarily far down the sequence, if at every point the state continues to encode information about the token. In practice, however, this mechanism is imperfect: due in part to the vanishing gradient problem, the model's state at the end of a long sentence often does not contain precise, extractable information about early tokens.
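The sequential update can be sketched with a plain (ungated) recurrence $h_n = \tanh(W h_{n-1} + U x_n)$, used here as a simplification of the gated cells described above; `rnn_states` and the weights `W` and `U` are hypothetical toy values. The loop makes the dependency explicit: the state for token $n$ cannot be computed before the state for token $n-1$ exists.

```python
import math

def rnn_states(tokens, W, U):
    """Return the hidden states h_1..h_N for a plain tanh RNN.

    Simplified, ungated recurrence (an assumption for illustration):
        h_n = tanh(W h_{n-1} + U x_n), starting from h_0 = 0.
    """
    hidden = len(W)
    h = [0.0] * hidden  # h_0: empty state before any token
    states = []
    for x in tokens:  # strictly sequential: step n needs h_{n-1}
        h = [
            math.tanh(
                sum(W[r][c] * h[c] for c in range(hidden))
                + sum(U[r][c] * x[c] for c in range(len(x)))
            )
            for r in range(hidden)
        ]
        states.append(h)
    return states

# Toy example: three 2-dimensional token embeddings, 2-dimensional state.
W = [[0.5, 0.0], [0.0, 0.5]]
U = [[1.0, 0.0], [0.0, 1.0]]
tokens = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
states = rnn_states(tokens, W, U)
```

In this toy run the state component driven by the first token shrinks from about 0.76 to about 0.18 by the third step, a small-scale picture of early information fading from the state.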
This problem was addressed by the introduction of attention mechanisms. Attention mechanisms let a model directly look at, and draw from, the state at any earlier point in the sentence. The attention layer can access all previous states and weighs them according to a learned measure of relevance to the current token, providing sharper information about far-away relevant tokens. A clear example of the utility of attention is in translation. In an English-to-French translation system, the first word of the French output most likely depends heavily on the beginning of the English input. However, in a classic encoder-decoder LSTM model, in order to produce the first word of the French output the model is only given the state vector of the last English word. Theoretically, this vector can encode information about the whole English sentence, giving the model all necessary knowledge, but in practice this information is often not well preserved. If an attention mechanism is introduced, the model can instead learn to attend to the states of early English tokens when producing the beginning of the French output, giving it a much better concept of what it is translating.
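Attention at a single decoder step can be sketched as follows, assuming plain dot-product scoring between the decoder's current state and each encoder state. RNN attention models have used several scoring functions, so this choice is an assumption, and `attend` and the toy states are hypothetical. The decoder receives a context vector that mixes all encoder states instead of relying on the last one alone.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(decoder_state, encoder_states):
    """Return (weights, context) for one decoder step.

    Scores are plain dot products (an assumed scoring function); the
    context vector is the weighted sum of all encoder states.
    """
    scores = [sum(d * e for d, e in zip(decoder_state, enc))
              for enc in encoder_states]
    weights = softmax(scores)
    dim = len(encoder_states[0])
    context = [sum(w * enc[k] for w, enc in zip(weights, encoder_states))
               for k in range(dim)]
    return weights, context

# Hypothetical encoder states for a three-word English input.
encoder_states = [[4.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
# A decoder state that "looks for" the first word.
weights, context = attend([1.0, 0.0], encoder_states)
```

Here the first encoder state receives about 96% of the weight, so the decoder draws on the start of the input directly rather than through the final state.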
When added to RNNs, attention mechanisms led to large gains in performance. The introduction of the Transformer brought to light the fact that attention mechanisms were powerful in themselves, and that sequential recurrent processing of data was not necessary for achieving the performance gains of RNNs with attention. The Transformer uses an attention mechanism without being an RNN, processing all tokens at the same time and calculating attention weights between them. The fact that Transformers do not rely on sequential processing, and lend themselves very easily to parallelization, allows Transformers to be trained more efficiently on larger datasets.
Architecture
Like earlier sequence-to-sequence models, the Transformer is an encoder-decoder architecture. The encoder consists of a stack of encoding layers that process the input iteratively, one layer after another, and the decoder consists of a stack of decoding layers that do the same to the output of the encoder.
The function of each encoder layer is to process its input to generate encodings, containing information about which parts of the inputs are relevant to each other. It passes its set of encodings to the next encoder layer as inputs. Each decoder layer does the opposite, taking all the encodings and using their incorporated contextual information to generate an output sequence.[4] To achieve this, each encoder and decoder layer makes use of an attention mechanism, which for each input weighs the relevance of every other input and draws information from them accordingly to produce the output.[5] Each decoder layer also has an additional attention mechanism that draws information from the encodings, applied after the decoder layer's own attention over the outputs of the previous decoder layer. Both the encoder and decoder layers have a feed-forward neural network for additional processing of the outputs, and contain residual connections and layer normalization steps.[5]
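The residual connection and layer normalization around each sublayer can be sketched as computing LayerNorm(x + Sublayer(x)) for each position, the arrangement used in the original paper [1]. This minimal sketch handles a single vector and omits the learned gain and bias of layer normalization; `residual_block`, `layer_norm` and the stand-in sublayer are illustrative names, not a library API.

```python
import math

def layer_norm(v, epsilon=1e-6):
    """Normalize a vector to zero mean and (roughly) unit variance.

    The learned gain and bias parameters are omitted (an assumption).
    """
    mean = sum(v) / len(v)
    var = sum((a - mean) ** 2 for a in v) / len(v)
    return [(a - mean) / math.sqrt(var + epsilon) for a in v]

def residual_block(x, sublayer):
    """Apply one sublayer, add the residual connection, then normalize."""
    y = sublayer(x)
    return layer_norm([a + b for a, b in zip(x, y)])

def double(v):
    """Hypothetical stand-in for an attention or feed-forward sublayer."""
    return [2.0 * a for a in v]

out = residual_block([1.0, 2.0, 3.0], double)
```

Any function that maps a vector to a vector of the same size can take the place of `double`; the residual path means the sublayer only has to learn a correction to its input.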
Scaled dot-product attention
The basic building blocks of the Transformer are scaled dot-product attention units. When a sentence is passed into a Transformer model, attention weights are calculated between every pair of tokens simultaneously. The attention unit produces embeddings for every token in context that contain information not only about the token itself, but also a combination of other relevant tokens weighted by the attention weights.
Concretely, for each attention unit the Transformer model learns three weight matrices: the query weights $W_Q$, the key weights $W_K$, and the value weights $W_V$. For each token $i$, the input word embedding $x_i$ is multiplied with each of the three weight matrices to produce a query vector $q_i = x_i W_Q$, a key vector $k_i = x_i W_K$, and a value vector $v_i = x_i W_V$. Attention weights are calculated using the query and key vectors: the attention weight $a_{ij}$ from token $i$ to token $j$ is the dot product between $q_i$ and $k_j$. The attention weights are divided by the square root of the dimension of the key vectors, $\sqrt{d_k}$, which stabilizes gradients during training, and passed through a softmax which normalizes the weights from each token $i$ to sum to $1$. The fact that $W_Q$ and $W_K$ are different matrices allows attention to be non-symmetric: if token $i$ attends to token $j$ (i.e. $q_i \cdot k_j$ is large), this does not necessarily mean that token $j$ will attend to token $i$ (i.e. $q_j \cdot k_i$ could be small). The output of the attention unit for token $i$ is the weighted sum of the value vectors of all tokens, weighted by $a_{ij}$, the attention from $i$ to each token $j$.
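The token-by-token computation above can be sketched in plain Python. The embeddings and the weight matrices are hypothetical toy values chosen so that token 0 attends strongly to token 1 while token 1 spreads its attention evenly, illustrating the non-symmetry that separate $W_Q$ and $W_K$ allow. `self_attention` is an illustrative name, not a library function.

```python
import math

def matvec(x, W):
    """Row vector x (length d_in) times matrix W (d_in x d_out)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(X, W_Q, W_K, W_V):
    """Scaled dot-product self-attention, computed token by token.

    Returns (A, outputs): A[i][j] is the attention weight from token i
    to token j, and outputs[i] is the A-weighted sum of value vectors.
    """
    Q = [matvec(x, W_Q) for x in X]
    K = [matvec(x, W_K) for x in X]
    V = [matvec(x, W_V) for x in X]
    d_k = len(K[0])
    A = []
    for q in Q:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        A.append(softmax(scores))  # each row sums to 1
    outputs = [
        [sum(A[i][j] * V[j][c] for j in range(len(V))) for c in range(len(V[0]))]
        for i in range(len(X))
    ]
    return A, outputs

# Toy setup: two tokens with one-hot embeddings and hypothetical weights.
X = [[1.0, 0.0], [0.0, 1.0]]
W_Q = [[0.0, 2.0], [0.0, 0.0]]  # token 0's query lines up with token 1's key
W_K = [[0.0, 0.0], [0.0, 2.0]]
W_V = [[1.0, 0.0], [0.0, 1.0]]
A, outputs = self_attention(X, W_Q, W_K, W_V)
```

With these weights, `A[0][1]` is about 0.94 while `A[1][0]` is exactly 0.5: token 0 attends mostly to token 1, but not the other way round.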
The attention calculation for all tokens can be expressed as one large matrix calculation, which is useful for training because optimized matrix-multiplication routines compute it quickly. The matrices $Q$, $K$ and $V$ are defined as the matrices whose $i$th rows are the vectors $q_i$, $k_i$, and $v_i$ respectively.
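In matrix form, as given in the original paper [1], the whole computation becomes a single expression, with the softmax taken over each row so that row $i$ holds the weights $a_{ij}$ for token $i$:

```latex
\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{Q K^{\mathsf{T}}}{\sqrt{d_k}}\right) V
```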
Multi-head attention
One set of $(W_Q, W_K, W_V)$ matrices is called an attention head, and each layer in a Transformer model has multiple attention heads. While one attention head attends to the tokens that are relevant to each token, multiple attention heads let the model do this for different definitions of "relevance". Research has shown that many attention heads in Transformers encode relevance relations that are interpretable by humans. For example, there are attention heads that, for every token, attend mostly to the next word, or attention heads that mainly attend from verbs to their direct objects.[6] Because Transformer models have multiple attention heads, they can capture many levels and types of relevance relations, from surface-level to semantic. The outputs of the attention heads are concatenated (and, in the original design, linearly projected) to pass into the feed-forward neural network layers.
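Multi-head attention can be sketched by running each head's own $(W_Q, W_K, W_V)$ on the same input, concatenating the per-token outputs, and applying an output projection $W_O$ as in the original paper [1]. The two-head setup below is hypothetical: each head reads a different half of a 4-dimensional embedding, and $W_O$ is the identity so the concatenation is easy to inspect.

```python
import math

def matmul(A, B):
    """Multiply an (n x m) matrix by an (m x p) matrix (lists of rows)."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V, with the softmax applied row by row."""
    d_k = len(K[0])
    K_T = [list(col) for col in zip(*K)]
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in matmul(Q, K_T)]
    return matmul(weights, V)

def multi_head_attention(X, heads, W_O):
    """Run every head on X, concatenate head outputs per token, project.

    `heads` is a list of (W_Q, W_K, W_V) tuples; W_O is the output
    projection applied after concatenation.
    """
    per_head = [attention(matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V))
                for W_Q, W_K, W_V in heads]
    concat = [[v for out in per_head for v in out[i]] for i in range(len(X))]
    return matmul(concat, W_O)

# Hypothetical 2-head setup over a 4-dimensional embedding.
first_half = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]
second_half = [[0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
heads = [(first_half,) * 3, (second_half,) * 3]
W_O = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
X = [[1.0, 0.0, 0.0, 2.0], [0.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0]]
Y = multi_head_attention(X, heads, W_O)
```

Because each head has its own projections, the two heads compute separate attention patterns over different features of the same tokens before their results are combined.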
Encoder
Each encoder consists of two major components: a self-attention mechanism and a feed-forward neural network. The self-attention mechanism takes in a set of input encodings from the previous encoder and weighs their relevance to each other to generate a set of output encodings. The feed-forward neural network then further processes each output encoding individually. These output encodings are passed to the next encoder as its input; the output of the last encoder is also passed to the decoders.
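The feed-forward network processes each position separately, using the same weights at every position. A minimal sketch of the form used in the original paper [1], $\mathrm{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$, with hypothetical toy weights (model dimension 2, inner dimension 3); the function names are illustrative.

```python
def linear(x, W, b):
    """x (length d_in) times W (d_in x d_out) plus bias b (length d_out)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) + b[j] for j in range(len(b))]

def feed_forward(x, W1, b1, W2, b2):
    """FFN(x) = max(0, x W1 + b1) W2 + b2 for a single position."""
    hidden = [max(0.0, h) for h in linear(x, W1, b1)]  # ReLU
    return linear(hidden, W2, b2)

def position_wise(X, params):
    """Apply the same FFN to every position independently."""
    return [feed_forward(x, *params) for x in X]

# Hypothetical weights: model dimension 2, inner dimension 3.
W1 = [[1.0, -1.0, 0.5], [0.0, 1.0, 0.5]]
b1 = [0.0, 0.0, -1.0]
W2 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
b2 = [0.0, 0.0]
Y = position_wise([[1.0, 2.0], [3.0, -1.0]], (W1, b1, W2, b2))
```

Because each row is processed on its own, running the second position alone gives the same result as running it together with the first; all mixing between positions happens in the attention sublayer.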
The first encoder takes positional information and embeddings of the input sequence as its input, rather than encodings. The positional information is necessary for the Transformer to make use of the order of the sequence, because no other part of the Transformer is sensitive to token order.[1]
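The original paper [1] uses fixed sinusoidal positional encodings that are added to the input embeddings; learned position embeddings are a common alternative. A minimal sketch of the sinusoidal scheme, assuming an even model dimension; `positional_encoding` and the toy embeddings are illustrative.

```python
import math

def positional_encoding(num_positions, d_model):
    """Sinusoidal encodings:
        PE[pos][2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos][2i+1] = cos(pos / 10000^(2i / d_model))
    """
    if d_model % 2:
        raise ValueError("this sketch assumes an even d_model")
    pe = []
    for pos in range(num_positions):
        row = []
        for i in range(0, d_model, 2):  # i plays the role of 2i
            angle = pos / (10000 ** (i / d_model))
            row.extend([math.sin(angle), math.cos(angle)])
        pe.append(row)
    return pe

pe = positional_encoding(4, 6)
# Hypothetical token embeddings; the encoding is added element-wise.
embeddings = [[0.1] * 6 for _ in range(4)]
x = [[e + p for e, p in zip(emb, pos)] for emb, pos in zip(embeddings, pe)]
```

Each position gets a distinct pattern of sines and cosines at different frequencies, so identical tokens at different positions enter the first encoder with different vectors.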
Decoder
Each decoder consists of three major components: a self-attention mechanism, an attention mechanism over the encodings, and a feed-forward neural network. The decoder functions in a similar fashion to the encoder, but an additional attention mechanism is inserted which instead draws relevant information from the encodings generated by the encoders.[1][5]
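The attention over the encodings can be sketched as scaled dot-product attention in which the queries come from the decoder positions and the keys and values come from the encoder output. This minimal sketch omits the learned projections (treating them as the identity), which is a simplification; `cross_attention` and the toy vectors are hypothetical. Each target position gets one weight per source position, so the weight matrix has shape (target length) by (source length).

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def cross_attention(dec_states, enc_out):
    """Queries from the decoder, keys and values from the encoder output.

    Learned W_Q, W_K, W_V are omitted (identity) as a simplification.
    """
    d_k = len(enc_out[0])
    weights = [
        softmax([sum(q * k for q, k in zip(query, key)) / math.sqrt(d_k)
                 for key in enc_out])
        for query in dec_states
    ]
    outputs = [
        [sum(w * v[c] for w, v in zip(row, enc_out)) for c in range(d_k)]
        for row in weights
    ]
    return weights, outputs

enc_out = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 source positions
dec_states = [[0.5, 0.0], [0.0, 0.5]]           # 2 target positions
weights, outputs = cross_attention(dec_states, enc_out)
```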
Like the first encoder, the first decoder takes positional information and embeddings of the output sequence as its input, rather than encodings. Because the Transformer must not use the current or future output tokens to predict an output, the output sequence is partially masked to prevent this reverse information flow.[1] The last decoder is followed by a final linear transformation and softmax layer, which produce the output probabilities over the vocabulary.
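The masking can be sketched as follows, assuming the common approach of giving masked positions a score of negative infinity before the softmax (implemented here by assigning them zero weight directly). Position $i$ may attend only to positions $j \le i$; in the original paper [1] the decoder input is also shifted one position to the right, so the token being predicted is never visible. `causal_mask` and `masked_softmax` are illustrative helpers, not a specific library's API.

```python
import math

def causal_mask(n):
    """mask[i][j] is True when position i may attend to position j (j <= i)."""
    return [[j <= i for j in range(n)] for i in range(n)]

def masked_softmax(scores, allowed):
    """Softmax that gives zero weight to disallowed positions.

    Equivalent to setting disallowed scores to -infinity first.
    """
    m = max(s for s, ok in zip(scores, allowed) if ok)
    exps = [math.exp(s - m) if ok else 0.0 for s, ok in zip(scores, allowed)]
    total = sum(exps)
    return [e / total for e in exps]

n = 4
mask = causal_mask(n)
scores = [[1.0] * n for _ in range(n)]  # equal raw scores everywhere
weights = [masked_softmax(row, allowed) for row, allowed in zip(scores, mask)]
```

With equal raw scores, position 0 puts all its weight on itself and position 3 spreads it evenly over positions 0 to 3, while every weight on a later position is exactly zero.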
Training
Transformers typically undergo semi-supervised learning involving unsupervised pretraining followed by supervised fine-tuning. Pretraining is typically done on a much larger dataset than fine-tuning, due to the restricted availability of labeled training data. Tasks for pretraining and fine-tuning commonly include:
- next-sentence prediction[2]
- question answering[3]
- reading comprehension
- sentiment analysis[7]
- paraphrasing[7]
Implementations
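The Transformer model has been implemented in major deep learning frameworks such as TensorFlow and PyTorch. The following pseudocode outlines the Transformer variant known as the "vanilla" transformer:

```python
# Pseudocode: embedding, pos_encoding, dropout, multi_head_attention(q, k, v, mask),
# layer_normalization, point_wise_ff, causal_mask, dense and softmax are assumed
# helpers; d_m is the model dimension, n_enc_layers/n_dec_layers the layer counts.
from math import sqrt

def vanilla_transformer(enc_inp, dec_inp):
    """Transformer variant known as the "vanilla" transformer."""
    # Encoder input: scaled embeddings plus positional information
    x = embedding(enc_inp) * sqrt(d_m)
    x = x + pos_encoding(x)
    x = dropout(x)
    for _ in range(n_enc_layers):
        # Self-attention sublayer with residual connection and layer norm
        attn = multi_head_attention(x, x, x, None)
        attn = dropout(attn)
        attn = layer_normalization(x + attn)
        # Position-wise feed-forward sublayer, same residual pattern
        x = point_wise_ff(attn)
        x = dropout(x)
        x = layer_normalization(x + attn)
    # x is at this point the output of the encoder
    enc_out = x

    # Decoder input: scaled embeddings of the output sequence
    x = embedding(dec_inp) * sqrt(d_m)
    x = x + pos_encoding(x)
    x = dropout(x)
    mask = causal_mask(x)
    for _ in range(n_dec_layers):
        # Masked self-attention over the output sequence
        attn1 = multi_head_attention(x, x, x, mask)
        attn1 = dropout(attn1)
        attn1 = layer_normalization(attn1 + x)
        # Attention over the encoder output (queries come from the decoder)
        attn2 = multi_head_attention(attn1, enc_out, enc_out, None)
        attn2 = dropout(attn2)
        attn2 = layer_normalization(attn1 + attn2)
        x = point_wise_ff(attn2)
        x = dropout(x)
        x = layer_normalization(attn2 + x)
    # Final linear transformation and softmax give output probabilities
    return softmax(dense(x))
```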
Applications
The Transformer finds most of its applications in natural language processing (NLP), for example in machine translation; it has also been applied to time series prediction.[8] Many pretrained models such as GPT-3, GPT-2, BERT, XLNet, and RoBERTa demonstrate the ability of Transformers to perform a wide variety of NLP-related tasks, and have the potential to find real-world applications.[2][3][9] These may include:
- machine translation
- document summarization
- document generation
- named entity recognition (NER)[10]
- speech recognition[10]
- biological sequence analysis[11]
References
- Vaswani, Ashish; Shazeer, Noam; Parmar, Niki; Uszkoreit, Jakob; Jones, Llion; Gomez, Aidan N.; Kaiser, Lukasz; Polosukhin, Illia (2017-06-12). "Attention Is All You Need". arXiv:1706.03762 [cs.CL].
- "Open Sourcing BERT: State-of-the-Art Pre-training for Natural Language Processing". Google AI Blog. Retrieved 2019-08-25.
- "Better Language Models and Their Implications". OpenAI. 2019-02-14. Retrieved 2019-08-25.
- "Sequence Modeling with Neural Networks (Part 2): Attention Models". Indico. 2016-04-18. Retrieved 2019-10-15.
- Alammar, Jay. "The Illustrated Transformer". jalammar.github.io. Retrieved 2019-10-15.
- Clark, Kevin; Khandelwal, Urvashi; Levy, Omer; Manning, Christopher D. (August 2019). "What Does BERT Look at? An Analysis of BERT's Attention". Proceedings of the 2019 ACL Workshop BlackboxNLP: Analyzing and Interpreting Neural Networks for NLP. Florence, Italy: Association for Computational Linguistics: 276–286. doi:10.18653/v1/W19-4828.
- Wang, Alex; Singh, Amanpreet; Michael, Julian; Hill, Felix; Levy, Omer; Bowman, Samuel (2018). "GLUE: A Multi-Task Benchmark and Analysis Platform for Natural Language Understanding". Proceedings of the 2018 EMNLP Workshop BlackboxNLP: Analyzing and Interpreting Neural Networks for NLP. Stroudsburg, PA, USA: Association for Computational Linguistics: 353–355. arXiv:1804.07461. Bibcode:2018arXiv180407461W. doi:10.18653/v1/w18-5446.
- Allard, Maxime (2019-07-01). "What is a Transformer?". Medium. Retrieved 2019-10-21.
- Yang, Zhilin; Dai, Zihang; Yang, Yiming; Carbonell, Jaime; Salakhutdinov, Ruslan; Le, Quoc V. (2019-06-19). XLNet: Generalized Autoregressive Pretraining for Language Understanding. OCLC 1106350082.
- Monsters, Data (2017-09-26). "10 Applications of Artificial Neural Networks in Natural Language Processing". Medium. Retrieved 2019-10-21.
- Rives, Alexander; Goyal, Siddharth; Meier, Joshua; Guo, Demi; Ott, Myle; Zitnick, C. Lawrence; Ma, Jerry; Fergus, Rob (2019). "Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences". doi:10.1101/622803.