
Commit 6784e1f

Merge pull request #326 from urchade/fix/training
Fix training issues and compatibility with transformers v4.x and v5

2 parents 972b002 + 8aff893

File tree

6 files changed: +235 -270 lines


gliner/config.py

Lines changed: 2 additions & 2 deletions
@@ -324,7 +324,7 @@ def __init__(
     def model_type(self):
         """Auto-detect model type based on configuration."""
         if self.labels_decoder:
-            if self.span_mode == 'token-level':
+            if self.span_mode == "token-level":
                 return "gliner_uni_encoder_token_decoder"
             else:
                 return "gliner_uni_encoder_span_decoder"
@@ -357,4 +357,4 @@ def model_type(self):
         "gliner_bi_encoder_span": BiEncoderSpanConfig,
         "gliner_bi_encoder_token": BiEncoderTokenConfig,
     }
-)
+)

gliner/data_processing/processor.py

Lines changed: 2 additions & 0 deletions
@@ -807,6 +807,8 @@ def create_labels(self, batch):

         for i, sentence_entities in enumerate(batch["entities"]):
             for st, ed, sp_label in sentence_entities:
+                if sp_label not in batch["classes_to_id"][i]:
+                    continue
                 lbl = batch["classes_to_id"][i][sp_label]
                 class_idx = lbl - 1  # Convert to 0-indexed
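
The added guard skips entities whose label was not sampled into the batch's label set instead of crashing mid-epoch. A minimal sketch of the failure mode it prevents; the batch structure below is illustrative, not the processor's exact internal format:

    batch = {
        "entities": [[(0, 2, "person"), (3, 5, "planet")]],
        "classes_to_id": [{"person": 1, "location": 2}],  # "planet" was never sampled
    }

    for i, sentence_entities in enumerate(batch["entities"]):
        for st, ed, sp_label in sentence_entities:
            if sp_label not in batch["classes_to_id"][i]:
                continue  # without this line, the lookup below raises KeyError: 'planet'
            lbl = batch["classes_to_id"][i][sp_label]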

gliner/model.py

Lines changed: 39 additions & 15 deletions
@@ -9,8 +9,10 @@

 import torch
 import onnxruntime as ort
+import transformers
 from tqdm import tqdm
 from torch import nn
+from packaging import version
 from safetensors import safe_open
 from transformers import AutoTokenizer
 from huggingface_hub import PyTorchModelHubMixin, snapshot_download
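
Both new imports exist only to support the version gate added to train_model below. packaging is a declared dependency of transformers itself, so this should not introduce a new install requirement in practice. A quick sanity check, assuming both packages are importable:

    import transformers
    from packaging import version

    print(transformers.__version__)                       # e.g. "4.44.2"
    print(version.parse(transformers.__version__).major)  # 4 or 5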
@@ -338,6 +340,14 @@ def _load_config(cls, config_file: Path, **config_overrides) -> object:

         return config

+    @staticmethod
+    def _set_tokenizer_spec_tokens(tokenizer):
+        if hasattr(tokenizer, "add_bos_token"):
+            tokenizer.add_bos_token = tokenizer.bos_token_id is not None
+        if hasattr(tokenizer, "add_eos_token"):
+            tokenizer.add_eos_token = tokenizer.eos_token_id is not None
+        return tokenizer
+
     @classmethod
     def _load_tokenizer(cls, config: GLiNERConfig, model_dir: Path, cache_dir: Optional[Path] = None):
         """
@@ -351,11 +361,14 @@ def _load_tokenizer(cls, config: GLiNERConfig, model_dir: Path, cache_dir: Optional[Path] = None):
         Returns:
             Tokenizer instance or None
         """
-        if os.path.exists(model_dir / "tokenizer_config.json"):
-            return AutoTokenizer.from_pretrained(model_dir, cache_dir=cache_dir)
+        tokenizer_config_path = model_dir / "tokenizer_config.json"
+
+        if tokenizer_config_path.is_file():
+            tokenizer = AutoTokenizer.from_pretrained(model_dir, cache_dir=cache_dir)
         else:
-            return AutoTokenizer.from_pretrained(config.model_name, cache_dir=cache_dir)
-        return None
+            tokenizer = AutoTokenizer.from_pretrained(config.model_name, cache_dir=cache_dir)
+
+        return cls._set_tokenizer_spec_tokens(tokenizer)

     @classmethod
     def _load_state_dict(cls, model_file: Path, map_location: str = "cpu"):
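
The helper threaded through these loaders targets tokenizers (LLaMA-style ones, for instance) that expose add_bos_token / add_eos_token flags: it forces each flag to agree with whether the corresponding special-token ID is actually defined, and the hasattr checks keep it a no-op for tokenizers without those flags. A self-contained sketch of the effect, using a stand-in object since the real code operates on Hugging Face tokenizers:

    class FakeTokenizer:
        add_bos_token = True   # flag claims a BOS token should be prepended...
        bos_token_id = None    # ...but no BOS token is actually defined

    tok = FakeTokenizer()
    if hasattr(tok, "add_bos_token"):
        tok.add_bos_token = tok.bos_token_id is not None
    print(tok.add_bos_token)  # False: flag and token definition now agree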
@@ -514,7 +527,7 @@ def load_from_config(
         tokenizer = None
         if load_tokenizer:
             tokenizer = AutoTokenizer.from_pretrained(config_instance.model_name, cache_dir=cache_dir)
-
+            cls._set_tokenizer_spec_tokens(tokenizer)
         # Create model instance from scratch
         instance = cls(
             config_instance,
@@ -1110,15 +1123,22 @@ def train_model(
         # Create data collator
         data_collator = self._create_data_collator()

-        # Create trainer
-        trainer = Trainer(
-            model=self,
-            args=training_args,
-            train_dataset=train_dataset,
-            eval_dataset=eval_dataset,
-            tokenizer=self.data_processor.transformer_tokenizer,
-            data_collator=data_collator,
-        )
+        # Create trainer with version-conditional tokenizer argument
+        # transformers < 5.0 requires tokenizer, >= 5.0 does not
+        trainer_kwargs = {
+            "model": self,
+            "args": training_args,
+            "train_dataset": train_dataset,
+            "eval_dataset": eval_dataset,
+            "data_collator": data_collator,
+        }
+
+        if version.parse(transformers.__version__) < version.parse("5.0.0"):
+            trainer_kwargs["tokenizer"] = self.data_processor.transformer_tokenizer
+        else:
+            trainer_kwargs["processing_class"] = self.data_processor.transformer_tokenizer
+
+        trainer = Trainer(**trainer_kwargs)

         # Train
         trainer.train()
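
This is the core of the v5 fix: Trainer's tokenizer keyword was deprecated in favor of processing_class during the transformers 4.x series and is gone in v5, so the branch picks the keyword the installed release accepts. One detail worth calling out is that versions are compared with packaging.version rather than as raw strings, which matters once version components reach double digits. A quick illustration:

    from packaging import version

    assert version.parse("4.10.0") > version.parse("4.9.0")      # correct ordering
    assert "4.10.0" < "4.9.0"                                    # string comparison gets it backwards
    assert version.parse("5.0.0.dev0") < version.parse("5.0.0")  # pre-releases sort first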
@@ -1134,6 +1154,7 @@ def _create_model(self, config, backbone_from_pretrained, cache_dir, **kwargs):
     def _create_data_processor(self, config, cache_dir, tokenizer=None, words_splitter=None, **kwargs):
         if tokenizer is None:
             tokenizer = AutoTokenizer.from_pretrained(config.model_name, cache_dir=cache_dir)
+            self._set_tokenizer_spec_tokens(tokenizer)
         self.data_processor = self.data_processor_class(config, tokenizer, words_splitter)
         return self.data_processor

@@ -1520,6 +1541,7 @@ def _create_data_processor(self, config, cache_dir, tokenizer=None, words_splitter=None, **kwargs):
         labels_tokenizer = AutoTokenizer.from_pretrained(config.labels_encoder, cache_dir=cache_dir)
         if tokenizer is None:
             tokenizer = AutoTokenizer.from_pretrained(config.model_name, cache_dir=cache_dir)
+            self._set_tokenizer_spec_tokens(tokenizer)

         self.data_processor = self.data_processor_class(
             config, tokenizer, words_splitter, labels_tokenizer=labels_tokenizer
@@ -1981,6 +2003,7 @@ def _create_data_processor(self, config, cache_dir, tokenizer=None, words_splitter=None, **kwargs):
         """Create data processor with decoder tokenizer."""
         if tokenizer is None:
             tokenizer = AutoTokenizer.from_pretrained(config.model_name, cache_dir=cache_dir)
+            self._set_tokenizer_spec_tokens(tokenizer)

         if words_splitter is None:
             words_splitter = WordsSplitter(config.words_splitter_type)
@@ -2242,6 +2265,7 @@ def _create_data_processor(self, config, cache_dir, tokenizer=None, words_splitter=None, **kwargs):
         """Create relation extraction data processor."""
         if tokenizer is None:
             tokenizer = AutoTokenizer.from_pretrained(config.model_name, cache_dir=cache_dir)
+            self._set_tokenizer_spec_tokens(tokenizer)

         if words_splitter is None:
             words_splitter = WordsSplitter(config.words_splitter_type)
@@ -2271,7 +2295,7 @@ def inference(
         self,
         texts: Union[str, List[str]],
         labels: List[str],
-        relations: List[str],
+        relations: List[str] = [],
         flat_ner: bool = True,
         threshold: float = 0.5,
         adjacency_threshold: Optional[float] = None,
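
With the default in place, relation-extraction models can be called for entities alone without passing an empty relations list. A hedged usage sketch; the class name and model id below are illustrative, not taken from the diff:

    # Hypothetical names for illustration only.
    model = GLiNERRelationModel.from_pretrained("org/relation-model")
    spans = model.inference(
        "Marie Curie won the Nobel Prize in 1903.",
        labels=["person", "award", "date"],
    )  # relations now defaults to []

One caveat on the design: [] as a default is a single mutable object shared across calls, which is safe only as long as inference never mutates it; Optional[List[str]] = None would sidestep that entirely.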
