Commit d98e87c

fix: client close graceful (#22)

Authored by wellenzheng and zhengweijun
Co-authored-by: zhengweijun <[email protected]>
1 parent 42d6cc3

File tree: 9 files changed, +942 −4 lines changed

examples/example_types.py

Lines changed: 65 additions & 0 deletions
```python
from dataclasses import dataclass, field
from typing import Optional, Dict

Message = dict[str, str]  # keys role, content
MessageList = list[Message]

__all__ = ['Message', 'MessageList', 'Templates', 'SamplerBase', 'EvalResult', 'SingleEvalResult', 'Eval']

Templates = {
    'base': "{task_template}",
    'meta-chat': "[INST] {task_template} [/INST]",
    'vicuna-chat': "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {task_template} ASSISTANT:",
    'lwm-chat': "You are a helpful assistant. USER: {task_template} ASSISTANT: ",
    'command-r-chat': "<BOS_TOKEN><|START_OF_TURN_TOKEN|><|USER_TOKEN|>{task_template}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
    'chatglm-chat': "[gMASK]sop<|user|> \n {task_template}<|assistant|> \n ",
    'glm-4-chat': "[gMASK]<sop><|user|>\n{task_template}<|assistant|>",
    'tgi-glm-4-chat': "<|user|>\n{task_template}<|assistant|>",
    'RWKV': "User: hi\n\nAssistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it\n\nUser: {task_template}\n\nAssistant:",
}


class SamplerBase:
    """
    Base class for defining a sampling model, which can be evaluated,
    or used as part of the grading process.
    """

    def __call__(self, message_list: MessageList) -> str:
        raise NotImplementedError


@dataclass
class EvalResult:
    """
    Result of running an evaluation (usually consisting of many samples)
    """

    score: Optional[float] = None  # top-line metric
    metrics: Optional[Dict[str, float]] = None  # other metrics


@dataclass
class SingleEvalResult:
    """
    Result of evaluating a single sample
    """

    score: Optional[float] = None  # top-line metric
    metrics: Dict[str, float] = field(default_factory=dict)  # other metrics with default empty dict


class Eval:
    """
    Base class for defining an evaluation.
    """

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        raise NotImplementedError
```
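These base classes define the contract used by the example scripts: a sampler wraps a model behind `__call__(message_list) -> str`, and an eval consumes a sampler and returns an `EvalResult`. As a minimal sketch of how they fit together (the `EchoSampler` and `ExactMatchEval` names below are hypothetical, not part of this commit):

```python
from example_types import Eval, EvalResult, MessageList, SamplerBase


class EchoSampler(SamplerBase):
    """Hypothetical sampler that returns the last user message verbatim."""

    def __call__(self, message_list: MessageList) -> str:
        return message_list[-1]["content"]


class ExactMatchEval(Eval):
    """Hypothetical eval that scores a sampler by exact match on one prompt."""

    def __init__(self, prompt: str, expected: str):
        self.prompt = prompt
        self.expected = expected

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        answer = sampler([{"role": "user", "content": self.prompt}])
        score = 1.0 if answer.strip() == self.expected else 0.0
        return EvalResult(score=score, metrics={"exact_match": score})


result = ExactMatchEval("Say hi", "Say hi")(EchoSampler())
print(result.score)  # 1.0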

examples/glm4_5_thinking.py

Lines changed: 113 additions & 0 deletions
```python
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import time
import traceback
from typing import Optional

from example_types import MessageList, SamplerBase
from zai import ZhipuAiClient


class ZaiSampler(SamplerBase):
    """
    Sample from Z.ai's chat completions API
    """

    def __init__(
        self,
        model: str = "glm-4.5",
        api_key: str = '',
        system_message: Optional[str] = None,
        temperature: float = 0.0,
        max_tokens: int = 4096,
        stream: bool = False,
    ):
        self.system_message = system_message
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.model = model
        self.client = ZhipuAiClient(api_key=api_key)
        self.stream = stream

    def get_resp(self, message_list, top_p=-1, temperature=-1):
        # Fall back to instance defaults when no override is supplied
        temperature = temperature if temperature > 0 else self.temperature
        top_p = top_p if top_p > 0 else 0.95
        for _ in range(3):
            try:
                chat_completion = self.client.chat.completions.create(
                    messages=message_list,
                    model=self.model,
                    temperature=temperature,
                    top_p=top_p,
                    max_tokens=self.max_tokens
                )
                return chat_completion.choices[0].message.content
            except Exception as e:
                print(f"Exception: {e}\nTraceback: {traceback.format_exc()}")
                time.sleep(1)
                continue
        print(f"failed after 3 attempts, last exception: {e if 'e' in locals() else ''}")
        return ''

    def get_resp_stream(self, message_list, top_p=-1, temperature=-1):
        temperature = temperature if temperature > 0 else self.temperature
        top_p = top_p if top_p > 0 else 0.95
        final = ''
        for _ in range(200):
            try:
                chat_completion_res = self.client.chat.completions.create(
                    model=self.model,
                    messages=message_list,
                    thinking={
                        "type": "enabled",
                    },
                    stream=True,
                    max_tokens=self.max_tokens,
                    temperature=temperature,
                    top_p=top_p
                )
                for chunk in chat_completion_res:
                    if chunk.choices[0].delta.content:
                        final += chunk.choices[0].delta.content
                break
            except Exception as e:
                final = ""
                print(f"Exception: {e}\nTraceback: {traceback.format_exc()}")
                time.sleep(5)
                continue

        if final == '':
            print(f"failed in get_resp_stream after 200 attempts, last exception: {e if 'e' in locals() else ''}")
            return ''

        # Drop the reasoning segment: keep only the text after the closing </think> tag
        content = ''
        if '</think>' in final:
            content = final.split("</think>")[-1].strip()
            if not content:
                content = final[-512:].strip()
        else:
            content = final[-512:].strip()

        return content

    def __call__(self, message_list: MessageList, top_p=0.95, temperature=0.6) -> str:
        if self.system_message:
            message_list = [
                {
                    "role": "system", "content": self.system_message
                }
            ] + message_list

        if not self.stream:
            return self.get_resp(message_list, top_p, temperature)
        else:
            return self.get_resp_stream(message_list, top_p, temperature)


if __name__ == "__main__":
    client = ZaiSampler(model="glm-4.5", api_key=os.getenv("ZAI_API_KEY"), stream=True)
    messages = [
        {"role": "user", "content": "Hi?"},
    ]
    response = client(messages)
    print(response)
```

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "zai-sdk"
-version = "0.0.1"
+version = "0.0.2"
 description = "A SDK library for accessing big model apis from Z.ai"
 authors = ["Z.ai"]
 readme = "README.md"
```

src/zai/_client.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -204,7 +204,13 @@ def __del__(self) -> None:
         if self._has_custom_http_client:
             return
 
-        self.close()
+        try:
+            # Check if client is still valid before closing
+            if hasattr(self, '_client') and self._client is not None:
+                self.close()
+        except Exception:
+            # Ignore any exceptions during cleanup to avoid masking the original error
+            pass
 
 
 class ZaiClient(BaseClient):
```
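The guard matters because Python invokes `__del__` even on a partially constructed object: if `__init__` raises before the client attribute is assigned, an unguarded `self.close()` fails with `AttributeError` during garbage collection or interpreter shutdown. A standalone sketch of that failure mode, using hypothetical classes not from this repo:

```python
# Hypothetical repro of the failure mode guarded against above:
# __del__ runs even when __init__ raised part-way through, so cleanup
# must tolerate attributes that were never assigned.
class Fragile:
    def __init__(self):
        raise RuntimeError("failed before self._client was assigned")

    def __del__(self):
        self._client.close()  # AttributeError -> "Exception ignored in __del__" noise


class Graceful:
    def __init__(self):
        raise RuntimeError("failed before self._client was assigned")

    def __del__(self):
        try:
            # Check if client is still valid before closing
            if hasattr(self, '_client') and self._client is not None:
                self._client.close()
        except Exception:
            pass  # never let cleanup errors mask the original exception


for cls in (Fragile, Graceful):
    try:
        cls()
    except RuntimeError:
        pass  # Fragile leaves a warning on stderr; Graceful stays silent
```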

src/zai/_version.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,2 +1,2 @@
 __title__ = 'Z.ai'
-__version__ = '0.0.1'
+__version__ = '0.0.2'
```

src/zai/core/_http_client.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -481,7 +481,12 @@ def is_closed(self) -> bool:
         return self._client.is_closed
 
     def close(self):
-        self._client.close()
+        try:
+            if hasattr(self, '_client') and self._client is not None and not self._client.is_closed:
+                self._client.close()
+        except Exception:
+            # Ignore any exceptions during cleanup to avoid masking the original error
+            pass
 
     def __enter__(self):
         return self
```
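The same defensive pattern as in `_client.py`: with the `is_closed` check, `close()` tolerates a missing or already-closed underlying client, so double-closing no longer surfaces errors. A usage sketch, assuming `ZaiClient` takes an `api_key` argument and exposes `close()` and the context-manager protocol via this base class:

```python
from zai import ZaiClient

client = ZaiClient(api_key="your-api-key")  # constructor args assumed for illustration
client.close()
client.close()  # second call is a no-op rather than an error

with ZaiClient(api_key="your-api-key") as c:
    pass
c.close()  # also harmless after the context manager has already closed it
```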
