
Fix CVE issues #447

Merged (4 commits) on Mar 19, 2025
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -68,6 +68,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- fix lint issues on main by @rawwar in ([#374](https://github.com/opensearch-project/opensearch-py-ml/pull/374))
- fix CVE vulnerability by @rawwar in ([#383](https://github.com/opensearch-project/opensearch-py-ml/pull/383))
- refactor: replace 'payload' with 'body' in `create_standalone_connector` by @yerzhaisang ([#424](https://github.com/opensearch-project/opensearch-py-ml/pull/424))
- Fix CVE vulnerability by @nathaliellenaa in ([#447](https://github.com/opensearch-project/opensearch-py-ml/pull/447))

## [1.1.0]

120 changes: 120 additions & 0 deletions opensearch_py_ml/ml_models/sentencetransformermodel.py
@@ -6,6 +6,7 @@
# GitHub history for details.

import json
import math
import os
import pickle
import platform
@@ -22,6 +23,7 @@
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import yaml
from accelerate import Accelerator, notebook_launcher
from mdutils.fileutils import MarkDownFile
@@ -31,6 +33,10 @@
from tqdm import tqdm
from transformers import TrainingArguments, get_linear_schedule_with_warmup
from transformers.convert_graph_to_onnx import convert
from transformers.models.distilbert.modeling_distilbert import (
    DistilBertSdpaAttention,
    MultiHeadSelfAttention,
)

from opensearch_py_ml.ml_commons.ml_common_utils import (
    _generate_model_content_hash_value,

@@ -368,6 +374,119 @@
            train_examples.append([queries[i], passages[i]])
        return train_examples

    def _get_parent_and_attr(self, model, module_name):
        """Retrieve the parent module and the attribute name for a given module."""
        parts = module_name.split(".")
        parent = model
        for part in parts[:-1]:  # Traverse until the second-to-last part
            parent = getattr(parent, part)
        return parent, parts[-1]
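    # Example (illustrative): for module_name "transformer.layer.0.attention",
    # the loop walks model.transformer, then .layer, then getattr(..., "0")
    # (nn.Module resolves numeric string names through its _modules dict), and
    # returns (model.transformer.layer[0], "attention") so the caller can swap
    # the submodule in place with setattr().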

Review comment (Collaborator): let's add more comments to explain what we are doing here.

Reply (Contributor Author): Sure

    def patch_model_weights(self, model):
"""Replace DistilBertSdpaAttention with MultiHeadSelfAttention in the given model."""
# Collect the layers to replace in a separate list to avoid modifying dictionary while iterating
modules_to_replace = []

for name, module in model.named_modules():
if isinstance(module, DistilBertSdpaAttention):
modules_to_replace.append((name, module))

# Now replace the modules
for name, module in modules_to_replace:
# Retrieve the original config
config = getattr(module, "config", None)
if config is None:
raise ValueError(f"Module {name} does not have a 'config' attribute.")

            # Create a new MultiHeadSelfAttention with the same config
            new_module = MultiHeadSelfAttention(config)

            # Copy weights into the new module
            new_module.q_lin.weight.data = module.q_lin.weight.data.clone()
            new_module.q_lin.bias.data = module.q_lin.bias.data.clone()
            new_module.k_lin.weight.data = module.k_lin.weight.data.clone()
            new_module.k_lin.bias.data = module.k_lin.bias.data.clone()
            new_module.v_lin.weight.data = module.v_lin.weight.data.clone()
            new_module.v_lin.bias.data = module.v_lin.bias.data.clone()
            new_module.out_lin.weight.data = module.out_lin.weight.data.clone()
            new_module.out_lin.bias.data = module.out_lin.bias.data.clone()
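            # (DistilBertSdpaAttention subclasses MultiHeadSelfAttention, so
            # the q/k/v/out projections transfer one-to-one.)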

            # Override the forward method to fix the tuple return issue
            def new_forward(
                self, query, key, value, mask, head_mask, output_attentions
            ):
                """New forward function to fix the tuple return issue."""
                batch_size, seq_length, _ = query.shape
                dim_per_head = self.dim // self.n_heads

                # Ensure the mask has the correct shape
                if mask.dim() == 2:  # [batch_size, seq_length]
                    mask = mask[
                        :, None, None, :
                    ]  # Convert to [batch_size, 1, 1, seq_length]
                elif mask.dim() == 3:  # [batch_size, seq_length, seq_length]
                    mask = mask[
                        :, None, :, :
                    ]  # Convert to [batch_size, 1, seq_length, seq_length]
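                # Illustrative shapes: a [2, 5] padding mask becomes
                # [2, 1, 1, 5] above and broadcasts to [2, n_heads, 5, 5]
                # in the expansion below.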

                # Validate the new mask shape before applying expansion
                if mask.shape[-1] != seq_length:
                    raise ValueError(
                        f"Mask shape {mask.shape} does not match sequence length {seq_length}"
                    )

                # Apply mask expansion
                mask = (mask == 0).expand(
                    batch_size, self.n_heads, seq_length, seq_length
                )

                q = (
                    self.q_lin(query)
                    .view(batch_size, seq_length, self.n_heads, dim_per_head)
                    .transpose(1, 2)
                )
                k = (
                    self.k_lin(key)
                    .view(batch_size, seq_length, self.n_heads, dim_per_head)
                    .transpose(1, 2)
                )
                v = (
                    self.v_lin(value)
                    .view(batch_size, seq_length, self.n_heads, dim_per_head)
                    .transpose(1, 2)
                )

                q = q / math.sqrt(dim_per_head)
                scores = torch.matmul(q, k.transpose(-2, -1))

                # Apply the correctly shaped mask
                scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)

                weights = nn.functional.softmax(scores, dim=-1)
                weights = nn.functional.dropout(
                    weights, p=self.dropout.p, training=self.training
                )

                context = torch.matmul(weights, v)
                context = (
                    context.transpose(1, 2)
                    .contiguous()
                    .view(batch_size, seq_length, self.dim)
                )
                output = self.out_lin(context)

                # Ensure the return is always a tuple, as expected by DistilBERT
                return (output, weights) if output_attentions else (output,)

            # Replace the forward method with the new function (bound to the instance)
            new_module.forward = new_forward.__get__(new_module, MultiHeadSelfAttention)

            # Replace the module in the model
            parent_module, attr_name = self._get_parent_and_attr(model, name)
            setattr(parent_module, attr_name, new_module)

        return model

    def train_model(
        self,
        train_examples: List[List[str]],

@@ -616,6 +735,7 @@
            plt.plot(loss[::100])
            plt.show()

        model = self.patch_model_weights(model)
        # the PyTorch model and the tokenizers.json file are saved at this step
        model.save(self.folder_path)
        device = "cpu"
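For context, here is a minimal usage sketch of the patch above. It is illustrative only: the checkpoint name and folder path are placeholder assumptions, not values from this PR. It loads a DistilBERT-based SentenceTransformer, applies patch_model_weights, and checks that no DistilBertSdpaAttention modules remain:

    # Minimal sketch; the checkpoint name and folder path are placeholders.
    from sentence_transformers import SentenceTransformer
    from transformers.models.distilbert.modeling_distilbert import (
        DistilBertSdpaAttention,
    )

    from opensearch_py_ml.ml_models import SentenceTransformerModel

    helper = SentenceTransformerModel(
        model_id="sentence-transformers/msmarco-distilbert-base-tas-b",
        folder_path="/tmp/st-model",
    )
    model = SentenceTransformer("sentence-transformers/msmarco-distilbert-base-tas-b")
    patched = helper.patch_model_weights(model)

    # After patching, every attention block is the eager MultiHeadSelfAttention.
    assert not any(
        isinstance(m, DistilBertSdpaAttention) for _, m in patched.named_modules()
    )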
4 changes: 2 additions & 2 deletions requirements-dev.txt
@@ -5,12 +5,12 @@ pandas>=1.5.2,<2.3,!=2.1.0
matplotlib>=3.6.2,<4
numpy>=1.24.0,<2
opensearch-py>=2.2.0
-torch>=2.0.1,<2.1.0
+torch>=2.5.0,<2.6
onnx>=1.15.0
accelerate>=0.27
sentence_transformers>=2.5.0,<2.6
tqdm>4.66.0,<5
-transformers>=4.36.0,<5
+transformers>=4.47.0,<5
deprecated>=1.2.14,<2
mdutils>=1.6.0,<2
pillow>10.0.0,<11
4 changes: 2 additions & 2 deletions requirements.txt
@@ -5,10 +5,10 @@ pandas>=1.5.2,<2.3,!=2.1.0
matplotlib>=3.6.2,<4
numpy>=1.24.0,<2
opensearch-py>=2.2.0
-torch>=2.0.1,<2.1.0
+torch>=2.5.0,<2.6
onnx>=1.15.0
accelerate>=0.27
sentence_transformers>=2.5.0,<2.6
tqdm>4.66.0,<5
-transformers>=4.36.0,<5
+transformers>=4.47.0,<5
deprecated>=1.2.14,<2
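As a quick sanity check that an environment picks up the bumped pins (an illustrative snippet, not part of this change):

    # Print the installed versions governed by the updated pins; expect
    # torch 2.5.x and transformers >= 4.47.
    from importlib.metadata import version

    for pkg in ("torch", "transformers"):
        print(pkg, version(pkg))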