diff --git a/jsonformer/logits_processors.py b/jsonformer/logits_processors.py
index db288d3..7b09ea2 100644
--- a/jsonformer/logits_processors.py
+++ b/jsonformer/logits_processors.py
@@ -48,14 +48,21 @@ def __call__(
 
         if (
             decoded.count(".") == 1
-            and len(decoded.strip().split(".")[1]) > self.precision
+            and len(decoded.replace(" ", "").split(".")[1]) > self.precision
+        ):
+            return True
+
+        if (
+            len(decoded) > 1
+            and "," in decoded
+            and any(c.isdigit() for c in decoded.split(",")[0])
         ):
             return True
 
         if (
             len(decoded) > 1
             and any(c.isdigit() for c in decoded)
-            and decoded[-1] in [" ", "\n"]
+            and ("," in decoded or decoded[-1] in (" ", "\n"))
         ):
             return True
 
@@ -71,9 +78,78 @@ def __init__(self, tokenizer: PreTrainedTokenizer, prompt: str):
         for _, token_id in tokenizer.get_vocab().items():
             token_str = tokenizer.decode(token_id).strip()
 
-            if token_str == "" or (
-                all(c.isdigit() or c == "." for c in token_str)
-                and token_str.count(".") <= 1
+            if (
+                token_str == ""
+                or (
+                    all(c.isdigit() or c == "." for c in token_str)
+                    and token_str.count(".") <= 1
+                ) or (
+                    "," in token_str
+                    and all(c.isdigit() or c == "." for c in token_str.split(",")[0])
+                    and token_str.count(".") <= 1
+                )
+            ):
+                self.allowed_mask[token_id] = True
+
+    def __call__(self, _, scores):
+        mask = self.allowed_mask.expand_as(scores)
+        scores[~mask] = -float("inf")
+
+        return scores
+
+class IntegerStoppingCriteria(StoppingCriteria):
+    def __init__(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        prompt_length: int,
+        max_digits: int = 15,
+    ):
+        self.tokenizer = tokenizer
+        self.prompt_length = prompt_length
+        self.max_digits = max_digits
+
+    def __call__(
+        self,
+        input_ids: torch.LongTensor,
+        scores: torch.FloatTensor,
+    ) -> bool:
+        decoded = self.tokenizer.decode(
+            input_ids[0][self.prompt_length :], skip_special_tokens=True
+        )
+
+        if len(decoded.strip()) > self.max_digits:
+            return True
+
+        if (
+            len(decoded) > 1
+            and "," in decoded
+            and any(c.isdigit() for c in decoded.split(",")[0])
+        ):
+            return True
+
+        if (
+            len(decoded) > 1
+            and any(c.isdigit() for c in decoded)
+            and decoded[-1] in (" ", "\n")
+        ):
+            return True
+
+        return False
+
+class OutputIntegersTokens(LogitsWarper):
+    def __init__(self, tokenizer: PreTrainedTokenizer, prompt: str):
+        self.tokenizer = tokenizer
+        self.tokenized_prompt = tokenizer(prompt, return_tensors="pt")
+        vocab_size = len(tokenizer)
+        self.allowed_mask = torch.zeros(vocab_size, dtype=torch.bool)
+
+        for _, token_id in tokenizer.get_vocab().items():
+            token_str = tokenizer.decode(token_id).strip()
+
+            if (
+                token_str == ""
+                or all(c.isdigit() for c in token_str)
+                or "," in token_str and all(c.isdigit() for c in token_str.split(",")[0])
             ):
                 self.allowed_mask[token_id] = True
 
diff --git a/jsonformer/main.py b/jsonformer/main.py
index dd867d4..4bb2d5d 100644
--- a/jsonformer/main.py
+++ b/jsonformer/main.py
@@ -1,13 +1,16 @@
-from typing import List, Union, Dict, Any
+from typing import List, Set, Union, Dict, Any
 
 from jsonformer.logits_processors import (
     NumberStoppingCriteria,
     OutputNumbersTokens,
+    IntegerStoppingCriteria,
+    OutputIntegersTokens,
     StringStoppingCriteria,
 )
 from termcolor import cprint
 from transformers import PreTrainedModel, PreTrainedTokenizer
 import json
+import torch
 
 GENERATION_MARKER = "|GENERATION|"
 
@@ -34,6 +37,7 @@ def __init__(
         self.prompt = prompt
 
         self.number_logit_processor = OutputNumbersTokens(self.tokenizer, self.prompt)
+        self.integer_logit_processor = OutputIntegersTokens(self.tokenizer, self.prompt)
 
         self.generation_marker = "|GENERATION|"
         self.debug_on = debug
@@ -72,7 +76,9 @@ def generate_number(self, temperature: Union[float, None] = None, iterations=0):
         response = self.tokenizer.decode(response[0], skip_special_tokens=True)
 
         response = response[len(prompt) :]
-        response = response.strip().rstrip(".")
+        if "," in response:
+            response = response.split(",")[0]
+        response = response.replace(" ", "").rstrip(".")
         self.debug("[generate_number]", response)
         try:
             return float(response)
@@ -82,6 +88,39 @@ def generate_number(self, temperature: Union[float, None] = None, iterations=0):
 
             return self.generate_number(temperature=self.temperature * 1.3)
 
+    def generate_integer(self, temperature: Union[float, None] = None, iterations=0):
+        prompt = self.get_prompt()
+        self.debug("[generate_integer]", prompt, is_prompt=True)
+        input_tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(
+            self.model.device
+        )
+        response = self.model.generate(
+            input_tokens,
+            max_new_tokens=self.max_number_tokens,
+            num_return_sequences=1,
+            logits_processor=[self.integer_logit_processor],
+            stopping_criteria=[
+                IntegerStoppingCriteria(self.tokenizer, len(input_tokens[0]))
+            ],
+            temperature=temperature or self.temperature,
+            pad_token_id=self.tokenizer.eos_token_id,
+        )
+        response = self.tokenizer.decode(response[0], skip_special_tokens=True)
+
+        response = response[len(prompt) :]
+        if "," in response:
+            response = response.split(",")[0]
+        response = response.replace(" ", "")
+        self.debug("[generate_integer]", response)
+        try:
+            return int(response)
+        except ValueError:
+            if iterations > 3:
+                raise ValueError("Failed to generate a valid integer")
+
+            # Count the retry; without this the iterations guard above can never fire.
+            return self.generate_integer(temperature=self.temperature * 1.3, iterations=iterations + 1)
+
     def generate_boolean(self) -> bool:
         prompt = self.get_prompt()
         self.debug("[generate_boolean]", prompt, is_prompt=True)
@@ -90,11 +129,9 @@ def generate_boolean(self) -> bool:
         output = self.model.forward(input_tensor.to(self.model.device))
         logits = output.logits[0, -1]
 
-        # todo: this assumes that "true" and "false" are both tokenized to a single token
-        # this is probably not true for all tokenizers
-        # this can be fixed by looking at only the first token of both "true" and "false"
-        true_token_id = self.tokenizer.convert_tokens_to_ids("true")
-        false_token_id = self.tokenizer.convert_tokens_to_ids("false")
+        # Skip special tokens: with a BOS-prepending tokenizer [0, 0] would be BOS for both words.
+        true_token_id = self.tokenizer.encode("true", add_special_tokens=False, return_tensors="pt")[0, 0]
+        false_token_id = self.tokenizer.encode("false", add_special_tokens=False, return_tensors="pt")[0, 0]
 
         result = logits[true_token_id] > logits[false_token_id]
 
@@ -139,6 +176,38 @@ def generate_string(self) -> str:
 
         return response.split('"')[0].strip()
 
+    def generate_enum(self, enum_values: Set[str]) -> str:
+        prompt = self.get_prompt()
+        self.debug("[generate_enum]", prompt, is_prompt=True)
+
+        # These are necessary because we don't know if we're at the end or middle of an object/array
+        terminal_tokens = torch.concat([
+            self.tokenizer.encode(s, add_special_tokens=False, return_tensors="pt")[:, 0]
+            for s in ('", "', '"}', '"]')
+        ])
+
+        highest_probability = 0.0
+        best_option = None
+        for option in enum_values:
+            n_option_tokens = self.tokenizer.encode(f'"{option}', add_special_tokens=False, return_tensors="pt").shape[1]
+            prompt_tokens = self.tokenizer.encode(prompt + f'"{option}', return_tensors="pt")
+            option_tokens = prompt_tokens[0, -n_option_tokens:]
+
+            with torch.no_grad():
+                logits = self.model.forward(prompt_tokens.to(self.model.device)).logits[0, -n_option_tokens-1:]
+            probabilities = torch.softmax(logits, dim=1)
+            option_token_probabilities = probabilities[:-1][torch.arange(probabilities.shape[0]-1), option_tokens]
+            termination_probability = torch.max(probabilities[-1, terminal_tokens])
+            option_probability = torch.prod(option_token_probabilities) * termination_probability
+
+            if option_probability > highest_probability:
+                best_option = option
+                highest_probability = option_probability
+
+        self.debug("[generate_enum]", best_option)
+
+        return best_option
+
     def generate_object(
         self, properties: Dict[str, Any], obj: Dict[str, Any]
     ) -> Dict[str, Any]:
@@ -160,6 +229,12 @@ def generate_value(
             else:
                 obj.append(self.generation_marker)
             return self.generate_number()
+        elif schema_type == "integer":
+            if key:
+                obj[key] = self.generation_marker
+            else:
+                obj.append(self.generation_marker)
+            return self.generate_integer()
         elif schema_type == "boolean":
             if key:
                 obj[key] = self.generation_marker
@@ -172,6 +247,12 @@ def generate_value(
             else:
                 obj.append(self.generation_marker)
             return self.generate_string()
+        elif schema_type == "enum":
+            if key:
+                obj[key] = self.generation_marker
+            else:
+                obj.append(self.generation_marker)
+            return self.generate_enum(set(schema["values"]))
         elif schema_type == "array":
             new_array = []
             obj[key] = new_array