Skip to content

Commit dbed927

Browse files
authored
fix(BA-2788): Missing session type check (#6354)
1 parent 4ff2c0a commit dbed927

File tree

5 files changed

+194
-3
lines changed

5 files changed

+194
-3
lines changed

changes/6354.fix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Session type validation is now properly enforced when creating sessions within scaling groups

src/ai/backend/manager/sokovan/scheduling_controller/scheduling_controller.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
MountNameValidationRule,
4646
ScalingGroupAccessRule,
4747
ServicePortRule,
48+
SessionTypeRule,
4849
SessionValidator,
4950
)
5051

@@ -101,6 +102,7 @@ def __init__(self, args: SchedulingControllerArgs) -> None:
101102
validator_rules = [
102103
ContainerLimitRule(),
103104
ScalingGroupAccessRule(),
105+
SessionTypeRule(),
104106
ServicePortRule(),
105107
ClusterValidationRule(),
106108
MountNameValidationRule(),

src/ai/backend/manager/sokovan/scheduling_controller/validators/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
ResourceLimitRule,
99
ScalingGroupAccessRule,
1010
ServicePortRule,
11+
SessionTypeRule,
1112
)
1213
from .validator import SessionValidator
1314

@@ -16,6 +17,7 @@
1617
"SessionValidatorRule",
1718
"ContainerLimitRule",
1819
"ScalingGroupAccessRule",
20+
"SessionTypeRule",
1921
"ServicePortRule",
2022
"ResourceLimitRule",
2123
"ClusterValidationRule",

src/ai/backend/manager/sokovan/scheduling_controller/validators/rules.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Validator rules for session creation."""
22

3-
from typing import Mapping
3+
from typing import Mapping, override
44

55
from ai.backend.common.exception import BackendAIError
66
from ai.backend.common.service_ports import parse_service_ports
@@ -20,9 +20,11 @@
2020
class ContainerLimitRule(SessionValidatorRule):
2121
"""Validates cluster size against resource policy limits."""
2222

23+
@override
2324
def name(self) -> str:
2425
return "container_limit"
2526

27+
@override
2628
def validate(
2729
self,
2830
spec: SessionCreationSpec,
@@ -39,9 +41,11 @@ def validate(
3941
class ScalingGroupAccessRule(SessionValidatorRule):
4042
"""Validates that the scaling group is accessible."""
4143

44+
@override
4245
def name(self) -> str:
4346
return "scaling_group_access"
4447

48+
@override
4549
def validate(
4650
self,
4751
spec: SessionCreationSpec,
@@ -66,12 +70,44 @@ def validate(
6670
raise InvalidAPIParameters(f"Scaling group {spec.scaling_group} is not accessible")
6771

6872

73+
class SessionTypeRule(SessionValidatorRule):
74+
"""Validates session type compatibility with scaling group."""
75+
76+
@override
77+
def name(self) -> str:
78+
return "session_type"
79+
80+
@override
81+
def validate(
82+
self,
83+
spec: SessionCreationSpec,
84+
context: SessionCreationContext,
85+
allowed_groups: list[AllowedScalingGroup],
86+
) -> None:
87+
if spec.scaling_group is None:
88+
# Should have been resolved already
89+
return
90+
91+
for sg in allowed_groups:
92+
if sg.name == spec.scaling_group:
93+
allowed_session_types = sg.scheduler_opts.allowed_session_types
94+
if spec.session_type not in allowed_session_types:
95+
raise InvalidAPIParameters(
96+
f"Session type {spec.session_type} is not allowed in scaling group {sg.name}"
97+
)
98+
return
99+
100+
raise InvalidAPIParameters(f"Scaling group {spec.scaling_group} is not accessible")
101+
102+
69103
class ServicePortRule(SessionValidatorRule):
70104
"""Validates preopen ports against service ports."""
71105

106+
@override
72107
def name(self) -> str:
73108
return "service_port"
74109

110+
@override
75111
def validate(
76112
self,
77113
spec: SessionCreationSpec,
@@ -138,12 +174,14 @@ def validate(
138174
class ResourceLimitRule(SessionValidatorRule):
139175
"""Validates requested resources against image limits."""
140176

177+
@override
141178
def name(self) -> str:
142179
return "resource_limit"
143180

144181
def __init__(self, known_slot_types: Mapping[SlotName, SlotTypes] | None = None):
145182
self._known_slot_types = known_slot_types
146183

184+
@override
147185
def validate(
148186
self,
149187
spec: SessionCreationSpec,

tests/manager/sokovan/scheduling_controller/validators/test_rules.py

Lines changed: 150 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
11
"""Tests for validation rules."""
22

3+
import uuid
4+
from collections.abc import Callable
5+
from datetime import datetime, timedelta
6+
from typing import Optional
37
from unittest.mock import MagicMock
48

59
import pytest
6-
7-
from ai.backend.common.types import SessionTypes
10+
import yarl
11+
12+
from ai.backend.common.types import (
13+
AccessKey,
14+
ClusterMode,
15+
KernelEnqueueingConfig,
16+
SessionId,
17+
SessionTypes,
18+
)
819
from ai.backend.manager.errors.api import InvalidAPIParameters
920
from ai.backend.manager.errors.kernel import QuotaExceeded
21+
from ai.backend.manager.models import NetworkRow
1022
from ai.backend.manager.models.scaling_group import ScalingGroupOpts
1123
from ai.backend.manager.repositories.scheduler.types.session_creation import (
1224
AllowedScalingGroup,
@@ -22,7 +34,9 @@
2234
ContainerLimitRule,
2335
ScalingGroupAccessRule,
2436
ServicePortRule,
37+
SessionTypeRule,
2538
)
39+
from ai.backend.manager.types import UserScope
2640

2741

2842
@pytest.fixture
@@ -55,6 +69,67 @@ def basic_context():
5569
)
5670

5771

72+
@pytest.fixture
73+
def session_spec_factory() -> Callable[..., SessionCreationSpec]:
74+
def create_spec(
75+
session_creation_id: str = "test-001",
76+
session_name: str = "test-session",
77+
access_key: AccessKey = AccessKey("test-key"),
78+
user_scope: UserScope = UserScope(
79+
domain_name="default",
80+
group_id=uuid.uuid4(),
81+
user_uuid=uuid.uuid4(),
82+
user_role="user",
83+
),
84+
session_type: SessionTypes = SessionTypes.INTERACTIVE,
85+
cluster_mode: ClusterMode = ClusterMode.SINGLE_NODE,
86+
cluster_size: int = 1,
87+
priority: int = 10,
88+
resource_policy: dict | None = None,
89+
kernel_specs: list[KernelEnqueueingConfig] | None = None,
90+
creation_spec: dict | None = None,
91+
scaling_group: Optional[str] = None,
92+
session_tag: Optional[str] = None,
93+
starts_at: Optional[datetime] = None,
94+
batch_timeout: Optional[timedelta] = None,
95+
dependency_sessions: Optional[list[SessionId]] = None,
96+
callback_url: Optional[yarl.URL] = None,
97+
route_id: Optional[uuid.UUID] = None,
98+
sudo_session_enabled: bool = False,
99+
network: Optional[NetworkRow] = None,
100+
designated_agent_list: Optional[list[str]] = None,
101+
internal_data: Optional[dict] = None,
102+
public_sgroup_only: bool = True,
103+
) -> SessionCreationSpec:
104+
return SessionCreationSpec(
105+
session_creation_id=session_creation_id,
106+
session_name=session_name,
107+
access_key=access_key,
108+
user_scope=user_scope,
109+
session_type=session_type,
110+
cluster_mode=cluster_mode,
111+
cluster_size=cluster_size,
112+
priority=priority,
113+
resource_policy=resource_policy or {},
114+
kernel_specs=kernel_specs or [],
115+
creation_spec=creation_spec or {},
116+
scaling_group=scaling_group,
117+
session_tag=session_tag,
118+
starts_at=starts_at,
119+
batch_timeout=batch_timeout,
120+
dependency_sessions=dependency_sessions,
121+
callback_url=callback_url,
122+
route_id=route_id,
123+
sudo_session_enabled=sudo_session_enabled,
124+
network=network,
125+
designated_agent_list=designated_agent_list,
126+
internal_data=internal_data,
127+
public_sgroup_only=public_sgroup_only,
128+
)
129+
130+
return create_spec
131+
132+
58133
class TestContainerLimitRule:
59134
"""Test cases for ContainerLimitRule."""
60135

@@ -211,6 +286,79 @@ def test_inaccessible_sgroup(self, basic_context):
211286
assert "not accessible" in str(exc_info.value)
212287

213288

289+
class TestSessionTypeRule:
290+
"""Test cases for SessionTypeRule."""
291+
292+
def test_allowed_session_type(
293+
self,
294+
basic_context: SessionCreationContext,
295+
session_spec_factory: Callable[..., SessionCreationSpec],
296+
) -> None:
297+
"""Test session type that is allowed in scaling group."""
298+
rule = SessionTypeRule()
299+
300+
allowed_groups = [
301+
AllowedScalingGroup(
302+
name="test-sg",
303+
is_private=False,
304+
scheduler_opts=ScalingGroupOpts(
305+
allowed_session_types=[SessionTypes.INTERACTIVE, SessionTypes.BATCH]
306+
),
307+
)
308+
]
309+
310+
spec = session_spec_factory(
311+
session_type=SessionTypes.INTERACTIVE,
312+
scaling_group="test-sg",
313+
)
314+
315+
# Should not raise
316+
rule.validate(spec, basic_context, allowed_groups)
317+
318+
def test_disallowed_session_type(
319+
self,
320+
basic_context: SessionCreationContext,
321+
session_spec_factory: Callable[..., SessionCreationSpec],
322+
) -> None:
323+
"""Test session type that is not allowed in scaling group."""
324+
rule = SessionTypeRule()
325+
326+
allowed_groups = [
327+
AllowedScalingGroup(
328+
name="batch-only-sg",
329+
is_private=False,
330+
scheduler_opts=ScalingGroupOpts(allowed_session_types=[SessionTypes.BATCH]),
331+
)
332+
]
333+
334+
spec = session_spec_factory(
335+
session_type=SessionTypes.INTERACTIVE,
336+
scaling_group="batch-only-sg",
337+
)
338+
339+
with pytest.raises(InvalidAPIParameters) as exc_info:
340+
rule.validate(spec, basic_context, allowed_groups)
341+
assert "not allowed in scaling group" in str(exc_info.value)
342+
343+
def test_empty_allowed_groups(
344+
self,
345+
basic_context: SessionCreationContext,
346+
session_spec_factory: Callable[..., SessionCreationSpec],
347+
) -> None:
348+
"""Test with empty allowed groups list."""
349+
rule = SessionTypeRule()
350+
351+
spec = session_spec_factory(
352+
session_type=SessionTypes.INTERACTIVE,
353+
scaling_group="any-sg",
354+
)
355+
356+
# Should raise - no allowed groups available
357+
with pytest.raises(InvalidAPIParameters) as exc_info:
358+
rule.validate(spec, basic_context, [])
359+
assert "not accessible" in str(exc_info.value)
360+
361+
214362
class TestServicePortRule:
215363
"""Test cases for ServicePortRule."""
216364

0 commit comments

Comments
 (0)