
Commit 171df6c

shubham committed

feat: Add source

1 parent 5a13e9a

File tree

1 file changed: +35 -25 lines

azure/azure_openai_manifold_pipeline.py (+35 -25)
@@ -39,9 +39,11 @@ async def on_valves_updated(self):
         self.set_pipelines()
 
     async def on_startup(self):
+        # This function is called when the server is started.
         print(f"on_startup:{__name__}")
 
     async def on_shutdown(self):
+        # This function is called when the server is stopped.
         print(f"on_shutdown:{__name__}")
 
     def pipe(
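
In the next hunk, the manual URL builder (which dropped the source parameter
when no email was available) is collapsed into a single f-string. One thing to
watch: body.get("user", {}).get("email") returns None when the email is
absent, so the new URL then contains the literal text "&source=None". A
minimal sketch of keeping the old guard; endpoint, api_version, model_id, and
user_email are placeholders for the valve settings and body lookup used in the
diff:

    # Hedged sketch, not part of the commit: omit source when there is no email.
    endpoint = "https://example.openai.azure.com"  # placeholder
    api_version = "2024-02-01"                     # placeholder
    model_id = "gpt-4o"                            # placeholder
    user_email = None  # e.g. body.get("user", {}).get("email")

    base = (
        f"{endpoint}/openai/deployments/{model_id}/chat/completions"
        f"?api-version={api_version}"
    )
    url = f"{base}&source={user_email}" if user_email else base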
@@ -55,32 +57,21 @@ def pipe(
         print(messages)
         print(user_message)
 
+        user_email = body.get("user", {}).get("email")
+
         headers = {
             "api-key": self.valves.AZURE_OPENAI_API_KEY,
             "Content-Type": "application/json",
         }
 
-        # 1) Read the 'user' email from body
-        # user_id = body.get("user", {})
-        # user_name = user_id.get("email", "").split("@")[0]
-        user_name = body.get("name", {})
-
-        # 2) Build the base URL *manually* to preserve `@` in the `source`
-        #    This ensures the server sees `you@company.com` literally
-        #    instead of `source=you%40company.com`
-        if user_name:
-            full_url = (
-                f"{self.valves.AZURE_OPENAI_ENDPOINT}/openai/deployments/{model_id}/chat/completions"
-                f"?api-version={self.valves.AZURE_OPENAI_API_VERSION}&source={user_name}"
-            )
-        else:
-            # If we have no email, just omit the source from the query string
-            full_url = (
-                f"{self.valves.AZURE_OPENAI_ENDPOINT}/openai/deployments/{model_id}/chat/completions"
-                f"?api-version={self.valves.AZURE_OPENAI_API_VERSION}"
-            )
+        # URL for Chat Completions in Azure OpenAI
+        url = (
+            f"{self.valves.AZURE_OPENAI_ENDPOINT}/openai/deployments/"
+            f"{model_id}/chat/completions?api-version={self.valves.AZURE_OPENAI_API_VERSION}&source={user_email}"
+        )
 
         # --- Define the allowed parameter sets ---
+        # (1) Default allowed params (non-o1)
         allowed_params_default = {
             "messages",
             "temperature",
@@ -97,7 +88,7 @@ def pipe(
             "presence_penalty",
             "frequency_penalty",
             "logit_bias",
-            "user",
+            "user",
             "function_call",
             "funcions",
             "tools",
@@ -109,6 +100,7 @@ def pipe(
             "seed",
         }
 
+        # (2) o1 models allowed params
         allowed_params_o1 = {
             "model",
             "messages",
@@ -118,31 +110,45 @@ def pipe(
             "presence_penalty",
             "frequency_penalty",
             "logit_bias",
-            "user",  # <--- still here too
+            "user",
         }
 
+        # Simple helper to detect if it's an o1 model
         def is_o1_model(m: str) -> bool:
-            return "o1" in m or m.endswith("o")
+            # Adjust this check to your naming pattern for o1 models
+            return "o1" in m or m.startswith("o1")
+
+        # Ensure user is a string
+        if "user" in body and not isinstance(body["user"], str):
+            body["user"] = body["user"].get("id", str(body["user"]))
 
         # If it's an o1 model, do a "fake streaming" approach
         if is_o1_model(model_id):
-            body.pop("stream", None)  # only remove 'stream' if present
+            # We'll remove "stream" from the body if present (since we'll do manual streaming),
+            # then filter to the allowed params for o1 models.
+            body.pop("stream", None)
             filtered_body = {k: v for k, v in body.items() if k in allowed_params_o1}
 
+            # Log which fields were dropped
             if len(body) != len(filtered_body):
                 dropped_keys = set(body.keys()) - set(filtered_body.keys())
                 print(f"Dropped params: {', '.join(dropped_keys)}")
 
             try:
+                # We make a normal request (non-streaming)
                 r = requests.post(
-                    url=full_url,
+                    url=url,
                     json=filtered_body,
                     headers=headers,
                     stream=False,
                 )
                 r.raise_for_status()
 
+                # Parse the full JSON response
                 data = r.json()
+
+                # Typically, the text content is in data["choices"][0]["message"]["content"]
+                # This may vary depending on your actual response shape.
                 content = ""
                 if (
                     isinstance(data, dict)
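
Beyond the rename from full_url to url, this hunk introduces the allow-list
filtering and a user-normalization step; note also that "o1" in m already
covers m.startswith("o1"), so the updated check reduces to the substring test.
A self-contained toy run of the filtering and normalization patterns, with
made-up values:

    # Hedged illustration mirroring the diff's logic with toy data.
    allowed_params_o1 = {"model", "messages", "user"}
    body = {"messages": [], "stream": True, "metadata": {"x": 1}, "user": {"id": "u-123"}}

    # A dict user collapses to its "id" (or its str() form); strings pass through.
    if "user" in body and not isinstance(body["user"], str):
        body["user"] = body["user"].get("id", str(body["user"]))

    body.pop("stream", None)
    filtered_body = {k: v for k, v in body.items() if k in allowed_params_o1}
    dropped_keys = set(body.keys()) - set(filtered_body.keys())
    print(f"Dropped params: {', '.join(dropped_keys)}")  # Dropped params: metadata
    print(filtered_body)  # {'messages': [], 'user': 'u-123'}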
@@ -157,6 +163,7 @@ def is_o1_model(m: str) -> bool:
                 content = str(data)
 
                 def chunk_text(text: str, chunk_size: int = 30) -> Generator[str, None, None]:
+                    """Yield text in fixed-size chunks."""
                     for i in range(0, len(text), chunk_size):
                         yield text[i : i + chunk_size]
 
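A quick usage example for the chunk_text generator above, reproduced
standalone here (in the diff it is nested inside pipe and annotated with
typing.Generator):

    def chunk_text(text: str, chunk_size: int = 30):
        """Yield text in fixed-size chunks."""
        for i in range(0, len(text), chunk_size):
            yield text[i : i + chunk_size]

    print(list(chunk_text("abcdefgh", 3)))  # ['abc', 'def', 'gh']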
@@ -173,23 +180,26 @@ def fake_stream() -> Generator[str, None, None]:
                 return f"Error: {e}"
 
         else:
+            # Normal pipeline for non-o1 models:
             filtered_body = {k: v for k, v in body.items() if k in allowed_params_default}
             if len(body) != len(filtered_body):
                 dropped_keys = set(body.keys()) - set(filtered_body.keys())
                 print(f"Dropped params: {', '.join(dropped_keys)}")
 
             try:
                 r = requests.post(
-                    url=full_url,
+                    url=url,
                     json=filtered_body,
                     headers=headers,
                     stream=True,
                 )
                 r.raise_for_status()
 
                 if filtered_body.get("stream"):
+                    # Real streaming
                     return r.iter_lines()
                 else:
+                    # Just return the JSON
                     return r.json()
 
             except Exception as e:
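
When the caller requested streaming, the non-o1 branch returns r.iter_lines()
unchanged. A hedged sketch of consuming that return value, assuming the
service emits SSE-style "data: {...}" lines terminated by "data: [DONE]" (the
framing Azure OpenAI chat streaming uses):

    import json

    def iter_deltas(lines):
        # Yield incremental text from each SSE data line; skip keep-alive blanks.
        for raw in lines:
            if not raw:
                continue
            line = raw.decode("utf-8") if isinstance(raw, bytes) else raw
            if line.startswith("data: ") and line != "data: [DONE]":
                chunk = json.loads(line[len("data: "):])
                choices = chunk.get("choices") or [{}]
                yield (choices[0].get("delta") or {}).get("content") or ""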
