
Transformer decoding methods


By default, transformer models do not output words. They produce a probability distribution over the set of all possible tokens, and it is up to the decoding method to turn that distribution into text.

The code samples given here are not for production use! The model.generate function has parameters that do all of this, so you don't need to write it yourself. For practical use, see How do you produce text from a text generation model.
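
To make this concrete, here is what a single forward pass returns. The snippets in this post assume a model and tokenizer have already been loaded; in this sketch, gpt2 is used purely as a stand-in model.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# gpt2 is only a stand-in here; any causal language model behaves the same way
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("I have no special", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# logits has shape (batch, sequence_length, vocabulary_size):
# one score per vocabulary token, for every position in the input
print(outputs.logits.shape)

# Probability distribution for the token that follows the prompt
probs = torch.nn.functional.softmax(outputs.logits[0, -1], dim=-1)
top_probs, top_ids = torch.topk(probs, 5)
for p, token_id in zip(top_probs, top_ids):
    print(f"{tokenizer.decode(token_id)!r}: {p.item():.3f}")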

The greedy approach

The easiest approach is to always pick the token with the highest probability as the next token. This tends to produce highly repetitive text, as shown below.

import torch  # the model and tokenizer are assumed to be loaded already

def generate_text(prompt):
    """A naive implementation of greedy text generation"""
    tmp_text = prompt
    print(tmp_text, end="")
    for i in range(50):  # we generate 50 tokens
        with torch.no_grad():  # no gradient needed for inference
            inputs = tokenizer(tmp_text, return_tensors="pt").to("mps")
            outputs = model(**inputs)
        # Apply softmax to turn the output into a probability
        # distribution. Softmax is a form of normalization that converts
        # logits, which are in log-space, to probabilities.
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        # argmax = pick the index with the highest probability
        predicted_token = torch.argmax(input=probs, dim=-1)
        new_tok = tokenizer.decode(predicted_token[0][-1], skip_special_tokens=True)
        print(new_tok, end="")
        tmp_text += new_tok
    print("")
>>> generate_text("I have no special")
I have no special talent. I have no special skill. I have no special talent. I have no special skill.

Beam search
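
Greedy decoding commits to a single token at every step and can miss better continuations. Beam search instead keeps the n most probable partial sequences (beams), extends each of them by one token, and keeps only the best candidates overall. A minimal sketch of the idea, assuming the same model and tokenizer as the snippets above (like them, it naively re-encodes the full text at every step):

def generate_beam(prompt, beam_width=3, token_count=20, device="mps"):
    """A naive beam search sketch: keep the beam_width most probable sequences"""
    # Each beam is a (text, cumulative log-probability) pair
    beams = [(prompt, 0.0)]
    for _ in range(token_count):
        candidates = []
        for text, score in beams:
            with torch.no_grad():
                inputs = tokenizer(text, return_tensors="pt").to(device)
                outputs = model(**inputs)
            # Log-probabilities of the next token, given this beam's text
            log_probs = torch.nn.functional.log_softmax(outputs.logits[0, -1], dim=-1)
            top_log_probs, top_ids = torch.topk(log_probs, beam_width)
            for log_p, token_id in zip(top_log_probs, top_ids):
                new_text = text + tokenizer.decode(token_id, skip_special_tokens=True)
                candidates.append((new_text, score + log_p.item()))
        # Keep only the beam_width best candidates overall
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
    return beams[0][0]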

Sampling based on the distribution

If you do not need deterministic output, you can sample from the probability distribution instead of always picking the most likely token.

You can implement sampling using PyTorch's torch.multinomial function:

def generate_text(prompt, token_count=50, temperature=1.0, device="mps"):
    """A naive implementation of token sampling"""
    tmp_text = prompt
    print(tmp_text, end="")
    for i in range(token_count):
        with torch.no_grad():
            inputs = tokenizer(tmp_text, return_tensors="pt").to(device)
            outputs = model(**inputs)

        # Scale the logits by the temperature before the softmax
        logits = outputs.logits / temperature
        probs = torch.nn.functional.softmax(logits, dim=-1)

        # Sample the next token from the distribution at the last position
        predicted_token = torch.multinomial(input=probs[0][-1], num_samples=1, replacement=False)[0]
        new_tok = tokenizer.decode(predicted_token, skip_special_tokens=True)
        print(new_tok, end="")
        tmp_text += new_tok
    print("")
>>> generate_text("I have no special")
I have no special talent."

**Maya Angelou** —This American poet and born in Maryland has been a

The temperature parameter scales the logits before the softmax and lets you control how much probability goes to less likely tokens. Lower temperatures give unlikely tokens even lower probabilities. In the limit, as the temperature approaches 0, sampling behaves exactly like the greedy approach; as it approaches infinity, it becomes the same as picking the next token uniformly at random.

>>> generate_text("I have no special", temperature=0.7)
I have no special knowledge or skill. I just want to be able to do as much of it as possible.

>>> generate_text("I have no special", temperature=0.1)
I have no special talent. I have no special skill. I have no special talent. I have no special skill.

>>> generate_text("I have no special", temperature=10)
I have no special Power judiciary mes artificialcmp categor Metro horsepower

You can also restrict the sampling to only the k most likely tokens. This is called top_k sampling.
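
A sketch of how the sampling step in the loop above could do this; top_k here is a hypothetical extra parameter, not part of the original function:

# Restrict the draw to the top_k most likely tokens
# (top_k is a hypothetical extra parameter of generate_text in this sketch)
top_probs, top_ids = torch.topk(probs[0][-1], k=top_k)
# torch.multinomial treats the values as weights, so no renormalization is needed
predicted_token = top_ids[torch.multinomial(top_probs, num_samples=1)[0]]
# decode and append new_tok exactly as before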

Sampling with a temperature and top_k filtering leads to natural-sounding results and is what most text generation models use in production.
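
For comparison, the built-in model.generate call exposes the same knobs directly; the parameter names below (do_sample, temperature, top_k, max_new_tokens) come from the Hugging Face transformers generate API. This is only a sketch; see the page linked above for practical guidance.

inputs = tokenizer("I have no special", return_tensors="pt").to("mps")
output_ids = model.generate(
    **inputs,
    do_sample=True,     # sample from the distribution instead of greedy decoding
    temperature=0.7,    # scale the logits before the softmax
    top_k=50,           # only sample among the 50 most likely tokens
    max_new_tokens=50,  # how many tokens to generate
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))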