Commit 4e141df
(feat) Team-level model-specific tpm/rpm limits + key-level validation of tpm/rpm limits when a key is assigned to a team (#15513)

* fix(support-model-specific-tpm/rpm-limits): allow setting per-model tpm/rpm rate limits for a team
* fix(key_management_endpoints.py): enforce guaranteed throughput with key-level model tpm/rpm limits when team-level tpm/rpm limits are set
* test: add unit tests
* fix: fix minor linting errors
* fix: refactor

1 parent 46d754a commit 4e141df

File tree

6 files changed: +299 -37 lines changed

litellm/proxy/_types.py

Lines changed: 4 additions & 0 deletions

@@ -1284,6 +1284,8 @@ class NewTeamRequest(TeamBase):
     prompts: Optional[List[str]] = None
     object_permission: Optional[LiteLLM_ObjectPermissionBase] = None
     allowed_passthrough_routes: Optional[list] = None
+    model_rpm_limit: Optional[Dict[str, int]] = None
+    model_tpm_limit: Optional[Dict[str, int]] = None
     team_member_budget: Optional[float] = (
         None  # allow user to set a budget for all team members
     )
@@ -1340,6 +1342,8 @@ class UpdateTeamRequest(LiteLLMPydanticObjectBase):
     team_member_tpm_limit: Optional[int] = None
     team_member_key_duration: Optional[str] = None
     allowed_passthrough_routes: Optional[list] = None
+    model_rpm_limit: Optional[Dict[str, int]] = None
+    model_tpm_limit: Optional[Dict[str, int]] = None


 class ResetTeamBudgetRequest(LiteLLMPydanticObjectBase):
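
Both request models gain the same optional per-model maps. A minimal sketch of populating them (the alias and model names are illustrative):

from litellm.proxy._types import NewTeamRequest

team_request = NewTeamRequest(
    team_alias="ml-research",  # illustrative alias
    # per-model limits: model name -> requests (or tokens) per minute
    model_rpm_limit={"gpt-4o": 100, "claude-3-5-sonnet": 50},
    model_tpm_limit={"gpt-4o": 100_000},
)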

litellm/proxy/auth/auth_utils.py

Lines changed: 16 additions & 0 deletions

@@ -453,6 +453,22 @@ def get_key_model_tpm_limit(
     return None


+def get_team_model_rpm_limit(
+    user_api_key_dict: UserAPIKeyAuth,
+) -> Optional[Dict[str, int]]:
+    if user_api_key_dict.team_metadata:
+        return user_api_key_dict.team_metadata.get("model_rpm_limit")
+    return None
+
+
+def get_team_model_tpm_limit(
+    user_api_key_dict: UserAPIKeyAuth,
+) -> Optional[Dict[str, int]]:
+    if user_api_key_dict.team_metadata:
+        return user_api_key_dict.team_metadata.get("model_tpm_limit")
+    return None
+
+
 def is_pass_through_provider_route(route: str) -> bool:
     PROVIDER_SPECIFIC_PASS_THROUGH_ROUTES = [
         "vertex-ai",

litellm/proxy/hooks/parallel_request_limiter_v3.py

Lines changed: 117 additions & 31 deletions

@@ -103,6 +103,7 @@
 REDIS_CLUSTER_SLOTS = 16384
 REDIS_NODE_HASHTAG_NAME = "all_keys"
 
+
 class RateLimitDescriptorRateLimitObject(TypedDict, total=False):
     requests_per_unit: Optional[int]
     tokens_per_unit: Optional[int]
@@ -157,15 +158,17 @@ def __init__(self, internal_usage_cache: InternalUsageCache):
     def _is_redis_cluster(self) -> bool:
         """
         Check if the dual cache is using Redis cluster.
-
+
         Returns:
             bool: True if using Redis cluster, False otherwise.
         """
         from litellm.caching.redis_cluster_cache import RedisClusterCache
-
+
         return (
             self.internal_usage_cache.dual_cache.redis_cache is not None
-            and isinstance(self.internal_usage_cache.dual_cache.redis_cache, RedisClusterCache)
+            and isinstance(
+                self.internal_usage_cache.dual_cache.redis_cache, RedisClusterCache
+            )
         )
 
     async def in_memory_cache_sliding_window(
@@ -310,7 +313,7 @@ def is_cache_list_over_limit(
         )
 
         return RateLimitResponse(overall_code=overall_code, statuses=statuses)
-
+
     def keyslot_for_redis_cluster(self, key: str) -> int:
         """
         Compute the Redis Cluster slot for a given key.
@@ -325,34 +328,34 @@ def keyslot_for_redis_cluster(self, key: str) -> int:
         Returns:
             int: The slot number (0-16383).
 
-
+
         """
         # Handle hash tags: use substring between { and }
-        start = key.find('{')
+        start = key.find("{")
         if start != -1:
-            end = key.find('}', start + 1)
+            end = key.find("}", start + 1)
             if end != -1 and end != start + 1:
-                key = key[start + 1:end]
+                key = key[start + 1 : end]
 
         # Compute CRC16 and mod 16384
-        crc = binascii.crc_hqx(key.encode('utf-8'), 0)
+        crc = binascii.crc_hqx(key.encode("utf-8"), 0)
         return crc % REDIS_CLUSTER_SLOTS
 
     def _group_keys_by_hash_tag(self, keys: List[str]) -> Dict[str, List[str]]:
         """
         Group keys by their Redis hash tag to ensure cluster compatibility.
-
+
         For Redis clusters, uses slot calculation to group keys that belong to the same slot.
         For regular Redis, no grouping is needed - all keys can be processed together.
         """
         groups: Dict[str, List[str]] = {}
-
+
         # Use slot calculation for Redis clusters only
         if self._is_redis_cluster():
             for key in keys:
                 slot = self.keyslot_for_redis_cluster(key)
                 slot_key = f"slot_{slot}"
-
+
                 if slot_key not in groups:
                     groups[slot_key] = []
                 groups[slot_key].append(key)
@@ -414,7 +417,7 @@ async def should_rate_limit(
         Check if any of the rate limit descriptors should be rate limited.
         Returns a RateLimitResponse with the overall code and status for each descriptor.
         Uses batch operations for Redis to improve performance.
-
+
         Args:
             descriptors: List of rate limit descriptors to check
             parent_otel_span: Optional OpenTelemetry span for tracing
@@ -499,7 +502,7 @@ async def should_rate_limit(
             parent_otel_span=parent_otel_span,
             local_only=False,  # Check Redis too
         )
-
+
         # For keys that don't exist yet, set them to 0
         if cache_values is None:
             cache_values = []
@@ -546,6 +549,66 @@ async def should_rate_limit(
         )
         return rate_limit_response
 
+    def _add_model_per_key_rate_limit_descriptor(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        requested_model: Optional[str],
+        descriptors: List[RateLimitDescriptor],
+    ) -> None:
+        """
+        Add model-specific rate limit descriptor for API key if applicable.
+
+        Args:
+            user_api_key_dict: User API key authentication dictionary
+            requested_model: The model being requested
+            descriptors: List of rate limit descriptors to append to
+        """
+        from litellm.proxy.auth.auth_utils import (
+            get_key_model_rpm_limit,
+            get_key_model_tpm_limit,
+        )
+
+        if not requested_model:
+            return
+
+        _tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict)
+        _rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict)
+
+        if _tpm_limit_for_key_model is None and _rpm_limit_for_key_model is None:
+            return
+
+        _tpm_limit_for_key_model = _tpm_limit_for_key_model or {}
+        _rpm_limit_for_key_model = _rpm_limit_for_key_model or {}
+
+        # Check if model has any rate limits configured
+        should_check_rate_limit = (
+            requested_model in _tpm_limit_for_key_model
+            or requested_model in _rpm_limit_for_key_model
+        )
+
+        if not should_check_rate_limit:
+            return
+
+        # Get model-specific limits
+        model_specific_tpm_limit: Optional[int] = _tpm_limit_for_key_model.get(
+            requested_model
+        )
+        model_specific_rpm_limit: Optional[int] = _rpm_limit_for_key_model.get(
+            requested_model
+        )
+
+        descriptors.append(
+            RateLimitDescriptor(
+                key="model_per_key",
+                value=f"{user_api_key_dict.api_key}:{requested_model}",
+                rate_limit={
+                    "requests_per_unit": model_specific_rpm_limit,
+                    "tokens_per_unit": model_specific_tpm_limit,
+                    "window_size": self.window_size,
+                },
+            )
+        )
+
     def _should_enforce_rate_limit(
         self,
         limit_type: Optional[str],
@@ -626,8 +689,8 @@ def _create_rate_limit_descriptors(
         Returns list of descriptors for API key, user, team, team member, end user, and model-specific limits.
         """
         from litellm.proxy.auth.auth_utils import (
-            get_key_model_rpm_limit,
-            get_key_model_tpm_limit,
+            get_team_model_rpm_limit,
+            get_team_model_tpm_limit,
         )
 
         descriptors = []
@@ -732,29 +795,43 @@ def _create_rate_limit_descriptors(
 
         # Model rate limits
         requested_model = data.get("model", None)
-        if requested_model and (
-            get_key_model_tpm_limit(user_api_key_dict) is not None
-            or get_key_model_rpm_limit(user_api_key_dict) is not None
+        self._add_model_per_key_rate_limit_descriptor(
+            user_api_key_dict=user_api_key_dict,
+            requested_model=requested_model,
+            descriptors=descriptors,
+        )
+
+        if (
+            get_team_model_rpm_limit(user_api_key_dict) is not None
+            or get_team_model_tpm_limit(user_api_key_dict) is not None
         ):
-            _tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict) or {}
-            _rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict) or {}
+            _tpm_limit_for_team_model = (
+                get_team_model_tpm_limit(user_api_key_dict) or {}
+            )
+            _rpm_limit_for_team_model = (
+                get_team_model_rpm_limit(user_api_key_dict) or {}
+            )
             should_check_rate_limit = False
-            if requested_model in _tpm_limit_for_key_model:
+            if requested_model in _tpm_limit_for_team_model:
                 should_check_rate_limit = True
-            elif requested_model in _rpm_limit_for_key_model:
+            elif requested_model in _rpm_limit_for_team_model:
                 should_check_rate_limit = True
 
             if should_check_rate_limit:
-                model_specific_tpm_limit: Optional[int] = None
-                model_specific_rpm_limit: Optional[int] = None
-                if requested_model in _tpm_limit_for_key_model:
-                    model_specific_tpm_limit = _tpm_limit_for_key_model[requested_model]
-                if requested_model in _rpm_limit_for_key_model:
-                    model_specific_rpm_limit = _rpm_limit_for_key_model[requested_model]
+                model_specific_tpm_limit = None
+                model_specific_rpm_limit = None
+                if requested_model in _tpm_limit_for_team_model:
+                    model_specific_tpm_limit = _tpm_limit_for_team_model[
+                        requested_model
+                    ]
+                if requested_model in _rpm_limit_for_team_model:
+                    model_specific_rpm_limit = _rpm_limit_for_team_model[
+                        requested_model
+                    ]
                 descriptors.append(
                     RateLimitDescriptor(
-                        key="model_per_key",
-                        value=f"{user_api_key_dict.api_key}:{requested_model}",
+                        key="model_per_team",
+                        value=f"{user_api_key_dict.team_id}:{requested_model}",
                         rate_limit={
                             "requests_per_unit": model_specific_rpm_limit,
                             "tokens_per_unit": model_specific_tpm_limit,
@@ -1164,6 +1241,15 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti
                     total_tokens=total_tokens,
                 )
             )
+            if model_group and user_api_key_team_id:
+                pipeline_operations.extend(
+                    self._create_pipeline_operations(
+                        key="model_per_team",
+                        value=f"{user_api_key_team_id}:{model_group}",
+                        rate_limit_type="tokens",
+                        total_tokens=total_tokens,
+                    )
+                )
 
         # Execute all increments in a single pipeline
         if pipeline_operations:
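
The net effect: when the team carries a per-model limit for the requested model, the limiter appends a descriptor keyed per team and model (alongside the existing per-key descriptor), and successful requests increment the matching "model_per_team" counters. A sketch of the descriptor that would be built for team "team-123" with model_rpm_limit={"gpt-4o": 100} (the window size shown is an assumption; the real value comes from self.window_size):

# Illustrative only: the dict shape mirrors RateLimitDescriptor above.
descriptor = {
    "key": "model_per_team",
    "value": "team-123:gpt-4o",  # f"{team_id}:{requested_model}"
    "rate_limit": {
        "requests_per_unit": 100,  # team's model_rpm_limit for gpt-4o
        "tokens_per_unit": None,   # no team-level TPM limit configured
        "window_size": 60,         # assumed 60s sliding window
    },
}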

litellm/proxy/management_endpoints/key_management_endpoints.py

Lines changed: 3 additions & 2 deletions

@@ -667,7 +667,8 @@ def check_team_key_model_specific_limits(
     if data.model_rpm_limit is not None:
         for model, rpm_limit in data.model_rpm_limit.items():
             if (
-                model_specific_rpm_limit.get(model, 0) + rpm_limit
+                team_table.rpm_limit is not None
+                and model_specific_rpm_limit.get(model, 0) + rpm_limit
                 > team_table.rpm_limit
             ):
                 raise HTTPException(
@@ -687,7 +688,7 @@ def check_team_key_model_specific_limits(
             ):
                 raise HTTPException(
                     status_code=400,
-                    detail=f"Allocated RPM limit={model_specific_rpm_limit.get(model, 0)} + Key RPM limit={rpm_limit} is greater than team RPM limit={team_model_specific_rpm_limit.get(model, 0)}",
+                    detail=f"Allocated RPM limit={model_specific_rpm_limit.get(model, 0)} + Key RPM limit={rpm_limit} is greater than team RPM limit={team_model_specific_rpm_limit}",
                 )
     if data.model_tpm_limit is not None:
         for model, tpm_limit in data.model_tpm_limit.items():
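
This validation is what makes key-level limits a guaranteed-throughput reservation: the RPM already allocated to the team's existing keys plus the new key's request must fit under the team's limit, and the check is now skipped when no team limit is set. A simplified, standalone sketch of the logic (the helper name and arguments are illustrative; the real check lives in check_team_key_model_specific_limits):

from typing import Dict, Optional


def validate_key_rpm_allocation(
    requested: Dict[str, int],          # new key's model_rpm_limit
    already_allocated: Dict[str, int],  # summed over the team's existing keys
    team_rpm_limit: Optional[int],      # team-wide rpm_limit, may be unset
) -> None:
    for model, rpm_limit in requested.items():
        if (
            team_rpm_limit is not None
            and already_allocated.get(model, 0) + rpm_limit > team_rpm_limit
        ):
            raise ValueError(
                f"Allocated RPM limit={already_allocated.get(model, 0)} "
                f"+ Key RPM limit={rpm_limit} is greater than "
                f"team RPM limit={team_rpm_limit}"
            )


try:
    # 50 already allocated + 60 requested > 100 -> rejected
    validate_key_rpm_allocation({"gpt-4o": 60}, {"gpt-4o": 50}, team_rpm_limit=100)
except ValueError as e:
    print(e)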

litellm/proxy/management_endpoints/team_endpoints.py

Lines changed: 2 additions & 0 deletions

@@ -300,6 +300,8 @@ async def new_team(  # noqa: PLR0915
     - members_with_roles: List[{"role": "admin" or "user", "user_id": "<user-id>"}] - A list of users and their roles in the team. Get user_id when making a new user via `/user/new`.
     - team_member_permissions: Optional[List[str]] - A list of routes that non-admin team members can access. example: ["/key/generate", "/key/update", "/key/delete"]
     - metadata: Optional[dict] - Metadata for team, store information for team. Example metadata = {"extra_info": "some info"}
+    - model_rpm_limit: Optional[Dict[str, int]] - The RPM (Requests Per Minute) limit for this team - applied across all keys for this team.
+    - model_tpm_limit: Optional[Dict[str, int]] - The TPM (Tokens Per Minute) limit for this team - applied across all keys for this team.
     - tpm_limit: Optional[int] - The TPM (Tokens Per Minute) limit for this team - all keys with this team_id will have at max this TPM limit
     - rpm_limit: Optional[int] - The RPM (Requests Per Minute) limit for this team - all keys associated with this team_id will have at max this RPM limit
     - max_budget: Optional[float] - The maximum budget allocated to the team - all keys for this team_id will have at max this max_budget
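
Taken together, the documented fields are settable at team creation. An end-to-end sketch against a locally running proxy (the base URL and admin key are assumptions):

import requests

resp = requests.post(
    "http://localhost:4000/team/new",
    headers={"Authorization": "Bearer sk-1234"},  # illustrative admin key
    json={
        "team_alias": "ml-research",
        "model_rpm_limit": {"gpt-4o": 100},
        "model_tpm_limit": {"gpt-4o": 100_000},
    },
)
resp.raise_for_status()
print(resp.json()["team_id"])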
