@@ -6,6 +6,7 @@
 # GitHub history for details.
 
 import json
+import math
 import os
 import pickle
 import platform
@@ -22,6 +23,7 @@
 import numpy as np
 import pandas as pd
 import torch
+import torch.nn as nn
 import yaml
 from accelerate import Accelerator, notebook_launcher
 from mdutils.fileutils import MarkDownFile
@@ -31,6 +33,10 @@
 from tqdm import tqdm
 from transformers import TrainingArguments, get_linear_schedule_with_warmup
 from transformers.convert_graph_to_onnx import convert
+from transformers.models.distilbert.modeling_distilbert import (
+    DistilBertSdpaAttention,
+    MultiHeadSelfAttention,
+)
 
 from opensearch_py_ml.ml_commons.ml_common_utils import (
     _generate_model_content_hash_value,
@@ -368,6 +374,119 @@ def load_training_data(self, query_df) -> List[List[str]]:
             train_examples.append([queries[i], passages[i]])
         return train_examples
 
+    def _get_parent_and_attr(self, model, module_name):
+        """Retrieve the parent module and the attribute name for a given module."""
+        parts = module_name.split(".")
+        parent = model
+        for part in parts[:-1]:  # Traverse until the second-to-last part
+            parent = getattr(parent, part)
+        return parent, parts[-1]
+
+    def patch_model_weights(self, model):
+        """Replace DistilBertSdpaAttention with MultiHeadSelfAttention in the given model."""
+        # Collect the layers to replace in a separate list to avoid modifying dictionary while iterating
+        modules_to_replace = []
+
+        for name, module in model.named_modules():
+            if isinstance(module, DistilBertSdpaAttention):
+                modules_to_replace.append((name, module))
+
+        # Now replace the modules
+        for name, module in modules_to_replace:
+            # Retrieve the original config
+            config = getattr(module, "config", None)
+            if config is None:
+                raise ValueError(f"Module {name} does not have a 'config' attribute.")
+
+            # Create new MultiHeadSelfAttention with same config
+            new_module = MultiHeadSelfAttention(config)
+
+            # Copy weights into new module
+            new_module.q_lin.weight.data = module.q_lin.weight.data.clone()
+            new_module.q_lin.bias.data = module.q_lin.bias.data.clone()
+            new_module.k_lin.weight.data = module.k_lin.weight.data.clone()
+            new_module.k_lin.bias.data = module.k_lin.bias.data.clone()
+            new_module.v_lin.weight.data = module.v_lin.weight.data.clone()
+            new_module.v_lin.bias.data = module.v_lin.bias.data.clone()
+            new_module.out_lin.weight.data = module.out_lin.weight.data.clone()
+            new_module.out_lin.bias.data = module.out_lin.bias.data.clone()
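+            # clone() copies the tensors so the new module does not share storage with the old SDPA module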
+
+            # Modify the forward method to fix tuple return issue
+            def new_forward(
+                self, query, key, value, mask, head_mask, output_attentions
+            ):
+                """New forward function to fix tuple return issue"""
+                batch_size, seq_length, _ = query.shape
+                dim_per_head = self.dim // self.n_heads
+
+                # Ensure the mask is the correct shape
+                if mask.dim() == 2:  # [batch_size, seq_length]
+                    mask = mask[
+                        :, None, None, :
+                    ]  # Convert to [batch_size, 1, 1, seq_length]
+                elif mask.dim() == 3:  # [batch_size, seq_length, seq_length]
+                    mask = mask[
+                        :, None, :, :
+                    ]  # Convert to [batch_size, 1, seq_length, seq_length]
+
+                # Validate the new mask shape before applying expansion
+                if mask.shape[-1] != seq_length:
+                    raise ValueError(
+                        f"Mask shape {mask.shape} does not match sequence length {seq_length}"
+                    )
+
+                # Apply mask expansion
+                mask = (mask == 0).expand(
+                    batch_size, self.n_heads, seq_length, seq_length
+                )
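+                # True entries mark padding positions (attention_mask == 0); they are filled with the dtype minimum below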
+
+                q = (
+                    self.q_lin(query)
+                    .view(batch_size, seq_length, self.n_heads, dim_per_head)
+                    .transpose(1, 2)
+                )
+                k = (
+                    self.k_lin(key)
+                    .view(batch_size, seq_length, self.n_heads, dim_per_head)
+                    .transpose(1, 2)
+                )
+                v = (
+                    self.v_lin(value)
+                    .view(batch_size, seq_length, self.n_heads, dim_per_head)
+                    .transpose(1, 2)
+                )
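+                # Compute attention eagerly (scale, matmul, mask, softmax) instead of calling F.scaled_dot_product_attention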
+
+                q = q / math.sqrt(dim_per_head)
+                scores = torch.matmul(q, k.transpose(-2, -1))
+
+                # Apply the correctly shaped mask
+                scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)
+
+                weights = nn.functional.softmax(scores, dim=-1)
+                weights = nn.functional.dropout(
+                    weights, p=self.dropout.p, training=self.training
+                )
+
+                context = torch.matmul(weights, v)
+                context = (
+                    context.transpose(1, 2)
+                    .contiguous()
+                    .view(batch_size, seq_length, self.dim)
+                )
+                output = self.out_lin(context)
+
+                # ✅ Ensure return is always a tuple, as expected by DistilBERT
+                return (output, weights) if output_attentions else (output,)
+
+            # Replace forward method with the new function
+            new_module.forward = new_forward.__get__(new_module, MultiHeadSelfAttention)
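+            # __get__ binds new_forward to this instance so it runs as a bound method (receives self)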
+
+            # Replace module in the model
+            parent_module, attr_name = self._get_parent_and_attr(model, name)
+            setattr(parent_module, attr_name, new_module)
+
+        return model
+
     def train_model(
         self,
         train_examples: List[List[str]],
@@ -616,6 +735,7 @@ def train_model(
         plt.plot(loss[::100])
         plt.show()
 
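+        # Swap the SDPA attention modules for MultiHeadSelfAttention before saving the model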
+        model = self.patch_model_weights(model)
         # saving the pytorch model and the tokenizers.json file is saving at this step
         model.save(self.folder_path)
         device = "cpu"
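
Usage sketch (not part of this commit): the snippet below shows one way the new patch_model_weights method could be exercised on a plain DistilBERT backbone to confirm that every DistilBertSdpaAttention module is replaced. It assumes a transformers release that still ships DistilBertSdpaAttention, and it assumes SentenceTransformerModel accepts model_id, folder_path, and overwrite as constructor arguments; adjust these hypothetical names and paths to your setup.

from transformers import DistilBertModel
from transformers.models.distilbert.modeling_distilbert import (
    DistilBertSdpaAttention,
    MultiHeadSelfAttention,
)

from opensearch_py_ml.ml_models import SentenceTransformerModel

# Helper instance is only used here to call patch_model_weights (constructor args assumed)
helper = SentenceTransformerModel(
    model_id="sentence-transformers/msmarco-distilbert-base-tas-b",
    folder_path="/tmp/patch_demo",
    overwrite=True,
)

# Load a DistilBERT backbone with the SDPA attention implementation
hf_model = DistilBertModel.from_pretrained(
    "distilbert-base-uncased", attn_implementation="sdpa"
)
print(any(isinstance(m, DistilBertSdpaAttention) for m in hf_model.modules()))  # True

# Replace every SDPA attention block with the eager MultiHeadSelfAttention
patched = helper.patch_model_weights(hf_model)
print(any(isinstance(m, DistilBertSdpaAttention) for m in patched.modules()))  # False
print(any(isinstance(m, MultiHeadSelfAttention) for m in patched.modules()))  # True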