Generating Human-level Text with Contrastive Search in Transformers 🤗


By Tian Lan

This article is also available in Chinese 简体中文.


Open In Colab

1. Introduction:

Natural language generation (i.e. text generation) is one of the core tasks in natural language processing (NLP). In this blog, we introduce the current state-of-the-art decoding method for neural text generation, Contrastive Search. Contrastive search was originally proposed in "A Contrastive Framework for Neural Text Generation" [1] ([Paper] [Official Implementation]) at NeurIPS 2022. In a follow-up work, "Contrastive Search Is What You Need For Neural Text Generation" [2] ([Paper] [Official Implementation]), the authors further demonstrate that contrastive search can generate human-level text using off-the-shelf language models across 16 languages.

[Remark] For users who are not familiar with text generation, please refer to this blog post for more details.


2. Hugging Face 🤗 Demo of Contrastive Search:

Contrastive Search is now available in 🤗 transformers, for both PyTorch and TensorFlow. You can interact with the examples shown in this blog post, using your framework of choice, in the Colab notebook linked at the top. We have also built this awesome demo, which directly compares contrastive search with other popular decoding methods (e.g. beam search, top-k sampling [3], and nucleus sampling [4]).


3. Environment Installation:

Before running the experiments in the following sections, please install an up-to-date version of transformers; the commands below pin v4.24.0:

pip install torch
pip install "transformers==4.24.0"

4. Problems of Existing Decoding Methods:

Decoding methods can be divided into two categories: (i) deterministic methods and (ii) stochastic methods. Let's discuss both!
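As a minimal illustration of the two categories, the sketch below applies each kind of selection rule to a single next-token distribution. The tokens and probabilities are made up for illustration, and `deterministic_pick` / `stochastic_pick` are hypothetical helpers, not transformers functions: the deterministic rule (greedy search's argmax) returns the same token on every call, while the stochastic rule samples from the distribution.

```python
import random

# Hypothetical next-token distribution for the prompt "DeepMind Company is";
# tokens and probabilities are invented for this example.
next_token_probs = {"a": 0.45, "the": 0.25, "an": 0.15, "known": 0.10, "based": 0.05}


def deterministic_pick(probs):
    """Deterministic step: always return the most likely token (argmax)."""
    return max(probs, key=probs.get)


def stochastic_pick(probs, rng):
    """Stochastic step: draw a token with probability proportional to its weight."""
    tokens = list(probs)
    weights = [probs[t] for t in tokens]
    return rng.choices(tokens, weights=weights, k=1)[0]


rng = random.Random(0)  # seeded so the sampled tokens are reproducible
print([deterministic_pick(next_token_probs) for _ in range(5)])  # same token every time
print([stochastic_pick(next_token_probs, rng) for _ in range(5)])  # tokens can vary
```

A real decoder repeats such a step token by token; beam search is also deterministic, but it keeps several candidate sequences instead of a single one.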

4.1. Deterministic Methods:

Deterministic methods, e.g. greedy search and beam search, generate text by selecting the text continuation with the highest likelihood measured by the language model. However, as widely discussed in previous studies [3][4], deterministic methods often lead to the problem of model degeneration, i.e., the generated text is unnatural and contains undesirable repetitions.
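To see how always choosing the highest-likelihood continuation can end up in a loop, here is a toy sketch built on a hypothetical bigram table (made-up tokens and probabilities, not GPT-2's). Because this toy model only looks at the previous token, once greedy search revisits a token it is forced down exactly the same path again. A real language model conditions on the whole prefix, so the mechanism is less mechanical, but the effect can look similar.

```python
# Hypothetical bigram "language model": previous token -> next-token distribution.
# Simplification: real models condition on the full prefix, not one token.
bigram_probs = {
    "DeepMind": {"research": 0.6, "is": 0.4},
    "research": {"is": 0.7, "lab": 0.3},
    "is": {"also": 0.5, "used": 0.3, "new": 0.2},
    "also": {"used": 0.8, "new": 0.2},
    "used": {"by": 0.9, "for": 0.1},
    "by": {"DeepMind": 0.6, "the": 0.4},
}


def greedy_decode(start, max_steps):
    """Greedy search: repeatedly append the single most likely next token."""
    tokens = [start]
    for _ in range(max_steps):
        dist = bigram_probs.get(tokens[-1])
        if not dist:  # no known continuation for this token: stop early
            break
        tokens.append(max(dist, key=dist.get))
    return tokens


print(" ".join(greedy_decode("DeepMind", 12)))
# DeepMind research is also used by DeepMind research is also used by DeepMind
```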

Below, let's see an example of text generated by greedy search with the GPT-2 (large) model.

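from transformers import AutoTokenizer, GPT2LMHeadModel

# Load the GPT-2 large tokenizer and model (weights are downloaded on first use)
tokenizer = AutoTokenizer.from_pretrained('gpt2-large')
input_ids = tokenizer('DeepMind Company is', return_tensors='pt').input_ids
model = GPT2LMHeadModel.from_pretrained('gpt2-large')

# With GPT-2's default generation settings, generate() performs greedy search:
# at each step it appends the single most likely next token.
# max_length counts the prompt tokens as well as the generated ones.
output = model.generate(input_ids, max_length=128)
print("Output:\n" + 100 * '-')
print(tokenizer.decode(output[0], skip_special_tokens=True))
print(100 * '-')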

Model Output:

Output:
----------------------------------------------------------------------------------------------------
DeepMind Company is a leading AI research company, with a focus on deep learning and deep
learning-based systems.

The company's research is focused on the development of deep learning-based systems that
can learn from large amounts of data, and that can be used to solve real-world problems.

DeepMind's research is also used by the UK government to develop new technologies for the
UK's National Health Service.

DeepMind's research is also used by the UK government to develop new technologies for the
UK's National Health Service.

DeepMind's research is also used by the UK government to develop new technologies
----------------------------------------------------------------------------------------------------

[Remark] From the result generated by greedy search, we can see an obvious pattern of repetition: the same sentence about the UK's National Health Service is generated over and over.
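To put a number on this repetition, here is a small sketch that measures the share of duplicated n-grams, in the spirit of the rep-n style measures used in studies of degeneration. It assumes simple whitespace tokenization, and `repetition_rate` is a hypothetical helper, not a transformers function. It is applied to the last two complete sentences of the greedy output above.

```python
def repetition_rate(text, n=2):
    """Fraction of n-grams in `text` that duplicate another n-gram.

    Computed as 1 - (#unique n-grams / #n-grams) over whitespace tokens;
    0.0 means no n-gram repeats, values near 1.0 mean highly repetitive text.
    """
    words = text.split()
    ngrams = [tuple(words[i:i + n]) for i in range(len(words) - n + 1)]
    if not ngrams:  # text too short to contain a single n-gram
        return 0.0
    return 1.0 - len(set(ngrams)) / len(ngrams)


# The two identical sentences from the greedy search output above
greedy_output = (
    "DeepMind's research is also used by the UK government to develop new technologies for the "
    "UK's National Health Service. "
    "DeepMind's research is also used by the UK government to develop new technologies for the "
    "UK's National Health Service."
)
print(f"rep-2 = {repetition_rate(greedy_output):.2f}")
```

Counting n-grams rather than whole sentences also catches partial loops, such as a phrase that recurs inside otherwise different sentences.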