
Mixtral past_key_values and output_router_logits incompatible #30731

Open
2 of 4 tasks
sorgfresser opened this issue May 9, 2024 · 4 comments · May be fixed by #34707
Labels
Good Second Issue Issues that are more difficult to do than "Good First" issues - give it a try if you want!

Comments

@sorgfresser
Contributor

System Info

transformers==4.40.2
Python 3.11.8

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import MixtralConfig, MixtralForCausalLM, AutoTokenizer
import torch
# Initializing a smaller version of Mixtral for faster execution
configuration = MixtralConfig(
    hidden_size=256,
    intermediate_size=896,
    num_hidden_layers=8,
    num_attention_heads=8,
    num_key_value_heads=8,
    num_local_experts=4,
    num_experts_per_tok=1,
)

model = MixtralForCausalLM(configuration)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
prompt = "This is a test"
tokenized = tokenizer(prompt, return_tensors="pt")
output = model(**tokenized, output_router_logits=True)
key_values = output.past_key_values
logits = output.logits
next_token_logits = logits[..., -1, :]
# Softmax
softmaxed = torch.nn.functional.softmax(next_token_logits, dim=-1)
# Sample
sampled = torch.multinomial(softmaxed.squeeze(), num_samples=1)
ids = sampled.item()

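# Second step: extend the mask by one position and feed only the sampled token
# together with the cache. With output_router_logits=True this call errors
# inside load_balancing_loss_func.
attention_mask = torch.cat([tokenized["attention_mask"], torch.tensor([[1]])], dim=-1)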
next_output = model(
    torch.tensor([[ids]]),
    attention_mask=attention_mask,
    past_key_values=key_values,
    output_router_logits=True
)

Expected behavior

This seems to be the same underlying issue as #29087. I would expect past_key_values to work together with output_router_logits.
So what happens?

  1. Without past key values (and with multiple input ids), all_router_logits has the full sequence length, so in load_balancing_loss_func the line num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) correctly evaluates to the number of hidden layers.
  2. With past key values, all_router_logits has a sequence length of only 1, but the attention mask still spans the whole sequence (and sequence_length is inferred from it), so num_hidden_layers evaluates to a small value or 0. This leads to the same error as in "Mixtral inference breaks when output_router_logits=True" #29087.
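
The arithmetic in points 1 and 2 can be reproduced with plain integers. A minimal sketch: inferred_num_layers is a hypothetical helper that mirrors only the num_hidden_layers line, and the 5-token prompt length is an assumption, not the actual token count of the reproduction prompt.

```python
def inferred_num_layers(rows_per_layer: int, num_layers: int, batch_size: int, mask_len: int) -> int:
    # load_balancing_loss_func sees every layer's gate logits concatenated along dim 0
    concatenated_rows = num_layers * rows_per_layer
    # ...and divides by batch_size * sequence_length, where sequence_length
    # is taken from the attention mask
    return concatenated_rows // (batch_size * mask_len)


num_layers = 8   # num_hidden_layers in the reproduction config
batch_size = 1
prompt_len = 5   # assumed number of prompt tokens

# Full forward pass: router logits cover every position in the mask
full = inferred_num_layers(batch_size * prompt_len, num_layers, batch_size, prompt_len)

# Cached step: router logits cover only the new token, the mask covers prompt + 1
cached = inferred_num_layers(batch_size * 1, num_layers, batch_size, prompt_len + 1)

print(full, cached)  # 8 1
```

With a prompt of 8 or more tokens the cached value drops to 0, which is the other failure mode described in point 2.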

Instead, I would like the load_balancing_loss_func to be able to deal with a case where the gate_logits passed are of shape [batch_size X 1, num_experts] instead of [batch_size X sequence_length, num_experts].
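
One way to get this behaviour is to infer how many positions the router logits cover from gate_logits itself and slice the attention mask to match. The sketch below is a simplified, hypothetical re-implementation (masked_load_balancing_loss is not a transformers function); it assumes the cached positions are the last columns of attention_mask and that each layer's rows are ordered batch-major, as the original function's reshape assumes.

```python
import torch
import torch.nn.functional as F


def masked_load_balancing_loss(gate_logits, num_experts, top_k, attention_mask):
    """Hypothetical variant of a Mixtral-style load-balancing loss.

    gate_logits: tuple with one [batch_size * k, num_experts] tensor per layer,
    where k may be smaller than the mask length (e.g. k == 1 with a cache).
    attention_mask: [batch_size, full_sequence_length] tensor of 0/1.
    """
    num_layers = len(gate_logits)
    batch_size = attention_mask.shape[0]
    # Infer how many positions the router logits cover from their shape,
    # instead of from the attention mask
    k = gate_logits[0].shape[0] // batch_size
    # Assumption: the logits cover the last k positions of the mask
    mask = attention_mask[:, -k:].float()

    concatenated = torch.cat(gate_logits, dim=0)
    routing_weights = F.softmax(concatenated, dim=-1)
    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    expert_mask = F.one_hot(selected_experts, num_experts).float()  # [rows, top_k, E]

    # Rows are assumed to be ordered (layer, batch, position)
    expert_attention_mask = (
        mask[None, :, :, None, None]
        .expand(num_layers, batch_size, k, top_k, num_experts)
        .reshape(-1, top_k, num_experts)
    )
    tokens_per_expert = (expert_mask * expert_attention_mask).sum(0) / expert_attention_mask.sum(0)

    router_mask = (
        mask[None, :, :, None]
        .expand(num_layers, batch_size, k, num_experts)
        .reshape(-1, num_experts)
    )
    router_prob_per_expert = (routing_weights * router_mask).sum(0) / router_mask.sum(0)

    return (tokens_per_expert * router_prob_per_expert.unsqueeze(0)).sum() * num_experts


# Usage: 8 layers, batch of 2, 4 experts, top-1 routing, mask of length 6
attention_mask = torch.ones(2, 6, dtype=torch.long)
prefill = tuple(torch.randn(2 * 6, 4) for _ in range(8))  # full prompt
step = tuple(torch.randn(2 * 1, 4) for _ in range(8))     # one cached step
print(masked_load_balancing_loss(prefill, 4, 1, attention_mask))
print(masked_load_balancing_loss(step, 4, 1, attention_mask))
```

Deriving k from the logits keeps a single code path for both the full-prompt and the cached case, rather than branching on a special single-token shape.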

@ArthurZucker
Collaborator

Hey! The generate function is not supposed to work for training. That is why we don't test past key values together with output router logits. That said, the two are not really incompatible (you might want to look at the distribution of the router logits during generation).
Do you want to open a PR for a fix?
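
Looking at the router distribution during generation could be as simple as counting expert selections per layer. A minimal sketch, assuming router_logits is a tuple with one [num_tokens, num_experts] tensor per layer (the shape of the router_logits field of the model output when output_router_logits=True); expert_usage is a hypothetical helper, and random tensors stand in for real model output.

```python
import torch


def expert_usage(router_logits, top_k):
    """Count how often each expert is selected, per layer.

    Returns a [num_layers, num_experts] tensor of selection counts.
    """
    counts = []
    for layer_logits in router_logits:
        num_experts = layer_logits.shape[-1]
        # top-k on raw logits picks the same experts as top-k on the softmax
        selected = torch.topk(layer_logits, top_k, dim=-1).indices  # [num_tokens, top_k]
        counts.append(torch.bincount(selected.flatten(), minlength=num_experts))
    return torch.stack(counts)


# Stand-in for one cached generation step: 8 layers, batch of 2,
# one new token per sequence, 4 experts
step_logits = tuple(torch.randn(2 * 1, 4) for _ in range(8))
usage = expert_usage(step_logits, top_k=1)
print(usage.shape)  # torch.Size([8, 4])
```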

@ArthurZucker added the Good Second Issue label May 15, 2024
@csking101

Hi, could I take this up? @ArthurZucker

@ArthurZucker
Collaborator

Sure, feel free to open a PR!

@csking101 linked a pull request Nov 12, 2024 that will close this issue
@csking101

I am having trouble figuring out how to handle the [batch_size X 1, num_experts] case. I am thinking of adding a condition for it in load_balancing_loss_func, such as is_single_token = gate_logits[0].shape[0] == batch_size (each layer then has exactly one router row per sequence). I would then modify the code as follows so that the shapes in the multiplication are valid:

if not is_single_token:
    expert_attention_mask = (
        attention_mask[None, :, :, None, None]
        .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
        .reshape(-1, top_k, num_experts)
        .to(compute_device)
    )
else:
    # Keep only the mask entry of the newly added token; the batch
    # dimension must stay batch_size, otherwise expand fails for batch_size > 1
    expert_attention_mask = (
        attention_mask[None, :, -1:, None, None]
        .expand((len(gate_logits), batch_size, 1, top_k, num_experts))
        .reshape(-1, top_k, num_experts)
        .to(compute_device)
    )

As you can see, one approach is to keep only the attention-mask entry of the token being added and drop the entries for the rest of the sequence. If the other entries were kept, the shape of expert_attention_mask would not match that of expert_mask. As a consequence, this branch does not need num_hidden_layers (which evaluates to a tiny value, as @sorgfresser pointed out), because the number of layers comes directly from len(gate_logits). router_per_expert_attention_mask needs the same modification:

if not is_single_token:
    router_per_expert_attention_mask = (
        attention_mask[None, :, :, None]
        .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
        .reshape(-1, num_experts)
        .to(compute_device)
    )
else:
    # Same idea: one mask entry per (layer, batch element) for the new token
    router_per_expert_attention_mask = (
        attention_mask[None, :, -1:, None]
        .expand((len(gate_logits), batch_size, 1, num_experts))
        .reshape(-1, num_experts)
        .to(compute_device)
    )
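
Why the single-token branch has to keep the batch dimension can be checked in isolation. A minimal sketch with illustrative sizes (8 layers, batch of 2, top-1 of 4 experts): expanding the last mask column with a 1 in the batch slot only works when batch_size == 1, while expanding to batch_size gives a mask with the same shape as expert_mask. The expert_mask here is a dummy one-hot tensor standing in for the real routing decisions.

```python
import torch

num_layers, batch_size, top_k, num_experts = 8, 2, 1, 4
attention_mask = torch.ones(batch_size, 6, dtype=torch.long)  # prompt + 1 new token

# Single-token branch: keep only the newest column of the mask
last = attention_mask[None, :, -1:, None, None]  # [1, batch_size, 1, 1, 1]

# A 1 in the batch slot cannot shrink a non-singleton dimension of size 2
try:
    last.expand((num_layers, 1, 1, top_k, num_experts))
    expanded_ok = True
except RuntimeError:
    expanded_ok = False

expert_attention_mask = (
    last.expand((num_layers, batch_size, 1, top_k, num_experts))
    .reshape(-1, top_k, num_experts)
)

# Dummy expert_mask: one row per (layer, batch element, new token)
expert_mask = torch.nn.functional.one_hot(
    torch.zeros(num_layers * batch_size, top_k, dtype=torch.long), num_experts
)
print(expanded_ok, expert_attention_mask.shape == expert_mask.shape)  # False True
```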

This is the approach I have come up with. I would appreciate it if you could verify that it is correct and point out any mistakes. The PR is here.
