Silent failure in generation parameters #33690

Open
Manalelaidouni opened this issue Sep 25, 2024 · 7 comments · May be fixed by #34726

Comments

@Manalelaidouni
Contributor

Manalelaidouni commented Sep 25, 2024

System Info

  • transformers version: 4.44.2
  • Platform: Linux-6.1.85+-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.24.7
  • Safetensors version: 0.4.5
  • Accelerate version: 0.34.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.1+cu121 (False)
  • Tensorflow version (GPU?): 2.17.0 (False)
  • Flax version (CPU?/GPU?/TPU?): 0.8.5 (cpu)
  • Jax version: 0.4.33
  • JaxLib version: 0.4.33
  • Using distributed or parallel set-up in script?: No

Who can help?

@zucchini-nlp @gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Hey!

I noticed that top_p was silently failing, so I tested the rest of the generation parameters and found that no_repeat_ngram_size also fails silently for the same reason: the condition checks inside the _get_logits_processor() method prevent the respective warper/processor classes from ever being instantiated, which is where the ValueError would be raised.

For instance, the raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}") in the processor's constructor is never reached when we set no_repeat_ngram_size <= 0.
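
For reference, the guard inside _get_logits_processor follows roughly this pattern (a simplified sketch, not a verbatim copy of the transformers source):

# simplified sketch of the guard pattern in _get_logits_processor (not verbatim)
from transformers.generation.logits_process import NoRepeatNGramLogitsProcessor

def build_processors(no_repeat_ngram_size):
    processors = []
    if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
        # NoRepeatNGramLogitsProcessor.__init__ -- where the ValueError lives -- only runs
        # inside this branch, so an invalid value like -1 skips it and is silently ignored
        processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
    return processors

build_processors(-1)  # returns [] -- the invalid value never reaches the constructor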

Here is a simple example with invalid values where generation proceeds without notifying the user; ideally, these should raise errors or issue warnings.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "EleutherAI/pythia-14m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt = "hey there!"
inputs = tokenizer(prompt, return_tensors="pt")

# both values below are outside their valid ranges, yet generation proceeds
# with no error or warning
generation_config = dict(do_sample=True, top_p=5, no_repeat_ngram_size=-1)

outputs = model.generate(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], **generation_config)

response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)

Expected behavior

  • Don't let things fail silently and proceed with the default behavior; instead, raise a ValueError or issue a warning to the user.

  • It would be great if the generate method could fail early when invalid values are passed, perhaps by checking for them upfront in _get_logits_processor before applying the generation parameters one by one and going through the entire process; this would help avoid wasting compute resources (see the sketch below).
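
A minimal sketch of what such an upfront check could look like (a hypothetical helper, not the current transformers implementation; the exact valid ranges are illustrative):

# hypothetical upfront check, run once before the logits processors are built;
# parameter names match the generate kwargs, the valid ranges shown are illustrative
def _validate_generation_params(top_p=None, no_repeat_ngram_size=None):
    if top_p is not None and not 0.0 <= top_p <= 1.0:
        raise ValueError(f"`top_p` has to be a float between 0 and 1, but is {top_p}")
    if no_repeat_ngram_size is not None and (
        not isinstance(no_repeat_ngram_size, int) or no_repeat_ngram_size < 0
    ):
        raise ValueError(
            f"`no_repeat_ngram_size` has to be a non-negative integer, but is {no_repeat_ngram_size}"
        )

_validate_generation_params(top_p=5, no_repeat_ngram_size=-1)  # raises immediately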

I would be happy to open a PR to help address this issue if that’s possible, thank you for all your work!

@LysandreJik
Member

cc @gante

@zucchini-nlp
Member

Hey! Actually these are validated when preparing generation config in
https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/generation/utils.py#L1355C28-L1355C45

But it seems we don't check all generation parameters there; we only perform a general check for clashes between kwargs. @gante will say if we need to perform a check on everything or not :)

@Manalelaidouni
Contributor Author

Thank you @zucchini-nlp for the clarification. From what I can see, the _prepare_generation_config method you pointed out doesn't actually validate the parameters; it just updates the config with the provided values.

It would be great to hear @gante's thoughts on whether we should validate all parameters upfront; perhaps this could happen in _prepare_generation_config.

I'm happy to assist with a PR if needed!

@zucchini-nlp
Member

@Manalelaidouni the config validates that incoming params are correct when updating, but the validation is done at a general level, so that we don't end up with clashing arguments, such as generation without sampling combined with a top-k param.
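
To illustrate the difference (a quick sketch; the exact warning text and the behavior may vary across transformers versions):

# the clash check warns, the out-of-range value does not -- illustrative, version-dependent
from transformers import GenerationConfig

GenerationConfig(do_sample=False, top_k=10)  # UserWarning: top_k is only used in sample-based generation modes
GenerationConfig(do_sample=True, top_p=5)    # no warning, no error -- the gap reported in this issue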


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@gante
Member

gante commented Oct 29, 2024

Hi @Manalelaidouni 👋

Yes, more validation is super welcome! :D GenerationConfig.validate would be the right place to add more validation -- this function is run both when we define a GenerationConfig object and when we update it in generate.

We should, in fact, create bulk validation structures, such as:

  • define which variables are integers, check that they are indeed integers
  • same, but for floats
  • same, but for positive variables
  • same, but for [0, 1] floats

(note: we don't want to add pydantic as a dependency :) )
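
A rough, dependency-free sketch of what such bulk checks could look like (the field groupings, names, and ranges here are illustrative, not the actual GenerationConfig.validate code):

# illustrative bulk checks without pydantic; groupings and exact ranges are hypothetical
# and would need to match each parameter's documented constraints
_NON_NEGATIVE_INT_FIELDS = ("no_repeat_ngram_size", "top_k")
_UNIT_INTERVAL_FLOAT_FIELDS = ("top_p", "typical_p")

def _validate_numeric_fields(config):
    for name in _NON_NEGATIVE_INT_FIELDS:
        value = getattr(config, name, None)
        if value is not None and (not isinstance(value, int) or value < 0):
            raise ValueError(f"`{name}` has to be a non-negative integer, but is {value!r}")
    for name in _UNIT_INTERVAL_FLOAT_FIELDS:
        value = getattr(config, name, None)
        if value is not None and not (isinstance(value, (int, float)) and 0.0 <= value <= 1.0):
            raise ValueError(f"`{name}` has to be a float between 0 and 1, but is {value!r}")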

@Manalelaidouni linked a pull request (#34726) on Nov 14, 2024 that will close this issue
@Manalelaidouni
Contributor Author

Got it, I've opened a PR with your suggestions @gante. If there's any advanced validation testing you have in mind, I'm happy to implement it! 😊
