Skip to content

Commit ffdee4e

Browse files
committed
Add PyTorch Exception Catch
1 parent a40fa43 commit ffdee4e

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed
+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
def get_process_prompt_response(request, validated_data):
2+
try:
3+
output = generate_sequences_from_prompt(**validated_data)
4+
except RuntimeError as exc:
5+
if "out of memory" in str(exc):
6+
logger.exception(
7+
f"Ran Out of Memory When Running {validated_data}. Clearing Cache."
8+
)
9+
torch.cuda.empty_cache()
10+
11+
oom_response = get_oom_response(validated_data)
12+
return oom_response
13+
14+
response = serialize_sequences_to_response(
15+
output,
16+
validated_data["prompt"],
17+
validated_data["cache_key"],
18+
WebsocketMessageTypes.COMPLETED_RESPONSE,
19+
completed=validated_data["length"],
20+
length=validated_data["length"],
21+
)
22+
23+
# clear cache on all responses (maybe this is overkill)
24+
torch.cuda.empty_cache()
25+
return response

0 commit comments

Comments
 (0)