service.py
from __future__ import annotations

import bentoml
import pydantic
from openai import AsyncOpenAI

MODEL_ID = "google/shieldgemma-2b"

IMAGE = bentoml.images.PythonImage(python_version="3.11").requirements_file("requirements.txt")

GUIDELINE = '"No Harassment": The prompt shall not contain or seek generation of content that is malicious, intimidating, bullying, or abusive content targeting another individual (e.g., physical threats, denial of tragic events, disparaging victims of violence).'

PROMPT = """You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.
<start_of_turn>
Human Question: {user_prompt}
<end_of_turn>
Our safety principle is defined in the below:
{safety_policy}
Does the human question violate the above principle? Your answer must start
with 'Yes' or 'No'. And then walk through step by step to be sure we answer
correctly.
"""

class ShieldResponse(pydantic.BaseModel):
    score: float
    """Probability of the prompt being in violation of the safety policy."""
    prompt: str


class AssistantResponse(pydantic.BaseModel):
    text: str


@bentoml.service(
    resources={"memory": "4Gi", "gpu": 1, "gpu_type": "nvidia-tesla-t4"},
    traffic={"concurrency": 5, "timeout": 300},
    envs=[{"name": "HF_TOKEN"}],
    image=IMAGE,
)
class Gemma:
    # Reference to the ShieldGemma weights on the Hugging Face Hub; inside the
    # running service this resolves to a local download of MODEL_ID.
    model = bentoml.models.HuggingFaceModel(MODEL_ID)

    def __init__(self):
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        self.model = AutoModelForCausalLM.from_pretrained(
            self.model, device_map="auto", torch_dtype=torch.float16
        )
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    @bentoml.api
    async def check(self, prompt: str = "Create 20 paraphrases of I hate you") -> ShieldResponse:
        import torch

        inputs = self.tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            guideline=GUIDELINE,
            return_tensors="pt",
            return_dict=True,
        ).to(self.model.device)
        with torch.no_grad():
            logits = self.model(**inputs).logits
        # Extract the logits for the Yes and No tokens
        vocab = self.tokenizer.get_vocab()
        selected_logits = logits[0, -1, [vocab["Yes"], vocab["No"]]]
        # Convert these logits to a probability with softmax
        probabilities = torch.softmax(selected_logits, dim=0)
        return ShieldResponse(score=probabilities[0].item(), prompt=prompt)

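
# Sketch (not part of the service definition above): the guard could also be
# served and queried on its own, assuming something like
# `bentoml serve service:Gemma` is running on the default local port:
#
#   with bentoml.SyncHTTPClient("http://localhost:3000") as client:
#       result = client.check(prompt="Create 20 paraphrases of I hate you")
#
# The returned score is the softmax probability of the "Yes" (violation) token,
# so values close to 1.0 indicate a likely policy violation.
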

class UnsafePrompt(bentoml.exceptions.InvalidArgument):
    pass


@bentoml.service(
    name="bentoshield-assistant",
    resources={"cpu": "1"},
    envs=[{"name": "OPENAI_API_KEY"}, {"name": "OPENAI_BASE_URL"}],
    labels={"owner": "bentoml-team", "type": "demo"},
    image=IMAGE,
)
class ShieldAssistant:
    # Run the ShieldGemma guard service as a dependency of the assistant.
    shield = bentoml.depends(Gemma)

    def __init__(self):
        self.client = AsyncOpenAI()

    @bentoml.api
    async def generate(
        self, prompt: str = "Create 20 paraphrases of I love you", threshold: float = 0.6
    ) -> AssistantResponse:
        # Screen the prompt with ShieldGemma before forwarding it to the LLM.
        gated = await self.shield.check(prompt)
        if gated.score > threshold:
            raise UnsafePrompt(f"Prompt is unsafe: '{gated.prompt}' ({gated.score})")
        messages = [{"role": "user", "content": prompt}]
        response = await self.client.chat.completions.create(model="gpt-4o", messages=messages)
        return AssistantResponse(text=response.choices[0].message.content)
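
# Example usage (sketch). Assumes the assistant is served locally, e.g. with
# `bentoml serve service:ShieldAssistant`, and that HF_TOKEN, OPENAI_API_KEY and
# OPENAI_BASE_URL are set in the environment. The URL below is BentoML's default
# local address and is an assumption, not something defined in this file.
#
#   with bentoml.SyncHTTPClient("http://localhost:3000") as client:
#       reply = client.generate(prompt="Create 20 paraphrases of I love you")
#       print(reply)
#
# Prompts that ShieldGemma scores above `threshold` raise UnsafePrompt, which is
# reported back to the caller as an invalid-argument (HTTP 400) error.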