Commit e8c3c40

Change torch version to 2.5.1; replace DistilBertSdpaAttention with MultiHeadSelfAttention
Signed-off-by: Nathalie Jonathan <[email protected]>
1 parent 134806f commit e8c3c40

File tree

3 files changed: +122 -2 lines changed
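For context, a minimal sketch (not part of this commit) of how one might confirm that a loaded DistilBERT model actually instantiates the SDPA attention class this commit swaps out; the checkpoint name below is an assumption for illustration:

# Minimal sketch: detect whether a DistilBERT model uses the SDPA attention
# implementation that this commit replaces.
from transformers import AutoModel
from transformers.models.distilbert.modeling_distilbert import DistilBertSdpaAttention

model = AutoModel.from_pretrained("distilbert-base-uncased")  # assumed checkpoint
uses_sdpa = any(isinstance(m, DistilBertSdpaAttention) for m in model.modules())
print(uses_sdpa)  # True on transformers builds that default DistilBERT to SDPA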

opensearch_py_ml/ml_models/sentencetransformermodel.py

+120 -0

@@ -6,6 +6,7 @@
 # GitHub history for details.
 
 import json
+import math
 import os
 import pickle
 import platform
@@ -22,6 +23,7 @@
 import numpy as np
 import pandas as pd
 import torch
+import torch.nn as nn
 import yaml
 from accelerate import Accelerator, notebook_launcher
 from mdutils.fileutils import MarkDownFile
@@ -31,6 +33,10 @@
 from tqdm import tqdm
 from transformers import TrainingArguments, get_linear_schedule_with_warmup
 from transformers.convert_graph_to_onnx import convert
+from transformers.models.distilbert.modeling_distilbert import (
+    DistilBertSdpaAttention,
+    MultiHeadSelfAttention,
+)
 
 from opensearch_py_ml.ml_commons.ml_common_utils import (
     _generate_model_content_hash_value,
@@ -368,6 +374,119 @@ def load_training_data(self, query_df) -> List[List[str]]:
             train_examples.append([queries[i], passages[i]])
         return train_examples
 
+    def _get_parent_and_attr(self, model, module_name):
+        """Retrieve the parent module and the attribute name for a given module."""
+        parts = module_name.split(".")
+        parent = model
+        for part in parts[:-1]:  # Traverse until the second-to-last part
+            parent = getattr(parent, part)
+        return parent, parts[-1]
+
+    def patch_model_weights(self, model):
+        """Replace DistilBertSdpaAttention with MultiHeadSelfAttention in the given model."""
+        # Collect the layers to replace in a separate list to avoid modifying
+        # the module registry while iterating over it
+        modules_to_replace = []
+
+        for name, module in model.named_modules():
+            if isinstance(module, DistilBertSdpaAttention):
+                modules_to_replace.append((name, module))
+
+        # Now replace the modules
+        for name, module in modules_to_replace:
+            # Retrieve the original config
+            config = getattr(module, "config", None)
+            if config is None:
+                raise ValueError(f"Module {name} does not have a 'config' attribute.")
+
+            # Create a new MultiHeadSelfAttention with the same config
+            new_module = MultiHeadSelfAttention(config)
+
+            # Copy weights into the new module
+            new_module.q_lin.weight.data = module.q_lin.weight.data.clone()
+            new_module.q_lin.bias.data = module.q_lin.bias.data.clone()
+            new_module.k_lin.weight.data = module.k_lin.weight.data.clone()
+            new_module.k_lin.bias.data = module.k_lin.bias.data.clone()
+            new_module.v_lin.weight.data = module.v_lin.weight.data.clone()
+            new_module.v_lin.bias.data = module.v_lin.bias.data.clone()
+            new_module.out_lin.weight.data = module.out_lin.weight.data.clone()
+            new_module.out_lin.bias.data = module.out_lin.bias.data.clone()
+
+            # Override the forward method so the module always returns a tuple
+            def new_forward(
+                self, query, key, value, mask, head_mask, output_attentions
+            ):
+                """Forward pass that always returns a tuple, as DistilBERT expects."""
+                batch_size, seq_length, _ = query.shape
+                dim_per_head = self.dim // self.n_heads
+
+                # Ensure the mask is the correct shape
+                if mask.dim() == 2:  # [batch_size, seq_length]
+                    mask = mask[
+                        :, None, None, :
+                    ]  # Convert to [batch_size, 1, 1, seq_length]
+                elif mask.dim() == 3:  # [batch_size, seq_length, seq_length]
+                    mask = mask[
+                        :, None, :, :
+                    ]  # Convert to [batch_size, 1, seq_length, seq_length]
+
+                # Validate the mask shape before expanding it
+                if mask.shape[-1] != seq_length:
+                    raise ValueError(
+                        f"Mask shape {mask.shape} does not match sequence length {seq_length}"
+                    )
+
+                # Expand the mask to [batch_size, n_heads, seq_length, seq_length]
+                mask = (mask == 0).expand(
+                    batch_size, self.n_heads, seq_length, seq_length
+                )
+
+                q = (
+                    self.q_lin(query)
+                    .view(batch_size, seq_length, self.n_heads, dim_per_head)
+                    .transpose(1, 2)
+                )
+                k = (
+                    self.k_lin(key)
+                    .view(batch_size, seq_length, self.n_heads, dim_per_head)
+                    .transpose(1, 2)
+                )
+                v = (
+                    self.v_lin(value)
+                    .view(batch_size, seq_length, self.n_heads, dim_per_head)
+                    .transpose(1, 2)
+                )
+
+                q = q / math.sqrt(dim_per_head)
+                scores = torch.matmul(q, k.transpose(-2, -1))
+
+                # Apply the correctly shaped mask
+                scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)
+
+                weights = nn.functional.softmax(scores, dim=-1)
+                weights = nn.functional.dropout(
+                    weights, p=self.dropout.p, training=self.training
+                )
+
+                context = torch.matmul(weights, v)
+                context = (
+                    context.transpose(1, 2)
+                    .contiguous()
+                    .view(batch_size, seq_length, self.dim)
+                )
+                output = self.out_lin(context)
+
+                # Ensure the return value is always a tuple, as DistilBERT expects
+                return (output, weights) if output_attentions else (output,)
+
+            # Bind the new forward method to the new module
+            new_module.forward = new_forward.__get__(new_module, MultiHeadSelfAttention)
+
+            # Swap the new module into the model
+            parent_module, attr_name = self._get_parent_and_attr(model, name)
+            setattr(parent_module, attr_name, new_module)
+
+        return model
+
     def train_model(
         self,
         train_examples: List[List[str]],
@@ -616,6 +735,7 @@ def train_model(
             plt.plot(loss[::100])
            plt.show()
 
+        model = self.patch_model_weights(model)
         # the pytorch model and the tokenizers.json file are saved at this step
         model.save(self.folder_path)
         device = "cpu"

requirements-dev.txt

+1 -1

@@ -5,7 +5,7 @@ pandas>=1.5.2,<2.3,!=2.1.0
 matplotlib>=3.6.2,<4
 numpy>=1.24.0,<2
 opensearch-py>=2.2.0
-torch>=2.5.0
+torch>=2.5.0,<2.6
 onnx>=1.15.0
 accelerate>=0.27
 sentence_transformers>=2.5.0,<2.6

requirements.txt

+1 -1

@@ -5,7 +5,7 @@ pandas>=1.5.2,<2.3,!=2.1.0
 matplotlib>=3.6.2,<4
 numpy>=1.24.0,<2
 opensearch-py>=2.2.0
-torch>=2.5.0
+torch>=2.5.0,<2.6
 onnx>=1.15.0
 accelerate>=0.27
 sentence_transformers>=2.5.0,<2.6
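Both requirements files now cap torch below 2.6. A small sketch (not part of the commit; the helper name is hypothetical) of how a runtime check against this pin could look:

# Hypothetical helper: verify the installed torch satisfies the new pin.
from importlib.metadata import version
from packaging.specifiers import SpecifierSet

def torch_version_ok() -> bool:
    # Same range as requirements.txt: torch>=2.5.0,<2.6
    return version("torch") in SpecifierSet(">=2.5.0,<2.6")

print(torch_version_ok())  # True for, e.g., torch 2.5.1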
