|
1 | 1 | """Tests for validation rules.""" |
2 | 2 |
|
| 3 | +import uuid |
| 4 | +from collections.abc import Callable |
| 5 | +from datetime import datetime, timedelta |
| 6 | +from typing import Optional |
3 | 7 | from unittest.mock import MagicMock |
4 | 8 |
|
5 | 9 | import pytest |
6 | | - |
7 | | -from ai.backend.common.types import SessionTypes |
| 10 | +import yarl |
| 11 | + |
| 12 | +from ai.backend.common.types import ( |
| 13 | + AccessKey, |
| 14 | + ClusterMode, |
| 15 | + KernelEnqueueingConfig, |
| 16 | + SessionId, |
| 17 | + SessionTypes, |
| 18 | +) |
8 | 19 | from ai.backend.manager.errors.api import InvalidAPIParameters |
9 | 20 | from ai.backend.manager.errors.kernel import QuotaExceeded |
| 21 | +from ai.backend.manager.models import NetworkRow |
10 | 22 | from ai.backend.manager.models.scaling_group import ScalingGroupOpts |
11 | 23 | from ai.backend.manager.repositories.scheduler.types.session_creation import ( |
12 | 24 | AllowedScalingGroup, |
|
22 | 34 | ContainerLimitRule, |
23 | 35 | ScalingGroupAccessRule, |
24 | 36 | ServicePortRule, |
| 37 | + SessionTypeRule, |
25 | 38 | ) |
| 39 | +from ai.backend.manager.types import UserScope |
26 | 40 |
|
27 | 41 |
|
28 | 42 | @pytest.fixture |
@@ -55,6 +69,67 @@ def basic_context(): |
55 | 69 | ) |
56 | 70 |
|
57 | 71 |
|
| 72 | +@pytest.fixture |
| 73 | +def session_spec_factory() -> Callable[..., SessionCreationSpec]: |
| 74 | + def create_spec( |
| 75 | + session_creation_id: str = "test-001", |
| 76 | + session_name: str = "test-session", |
| 77 | + access_key: AccessKey = AccessKey("test-key"), |
| 78 | + user_scope: UserScope = UserScope( |
| 79 | + domain_name="default", |
| 80 | + group_id=uuid.uuid4(), |
| 81 | + user_uuid=uuid.uuid4(), |
| 82 | + user_role="user", |
| 83 | + ), |
| 84 | + session_type: SessionTypes = SessionTypes.INTERACTIVE, |
| 85 | + cluster_mode: ClusterMode = ClusterMode.SINGLE_NODE, |
| 86 | + cluster_size: int = 1, |
| 87 | + priority: int = 10, |
| 88 | + resource_policy: dict | None = None, |
| 89 | + kernel_specs: list[KernelEnqueueingConfig] | None = None, |
| 90 | + creation_spec: dict | None = None, |
| 91 | + scaling_group: Optional[str] = None, |
| 92 | + session_tag: Optional[str] = None, |
| 93 | + starts_at: Optional[datetime] = None, |
| 94 | + batch_timeout: Optional[timedelta] = None, |
| 95 | + dependency_sessions: Optional[list[SessionId]] = None, |
| 96 | + callback_url: Optional[yarl.URL] = None, |
| 97 | + route_id: Optional[uuid.UUID] = None, |
| 98 | + sudo_session_enabled: bool = False, |
| 99 | + network: Optional[NetworkRow] = None, |
| 100 | + designated_agent_list: Optional[list[str]] = None, |
| 101 | + internal_data: Optional[dict] = None, |
| 102 | + public_sgroup_only: bool = True, |
| 103 | + ) -> SessionCreationSpec: |
| 104 | + return SessionCreationSpec( |
| 105 | + session_creation_id=session_creation_id, |
| 106 | + session_name=session_name, |
| 107 | + access_key=access_key, |
| 108 | + user_scope=user_scope, |
| 109 | + session_type=session_type, |
| 110 | + cluster_mode=cluster_mode, |
| 111 | + cluster_size=cluster_size, |
| 112 | + priority=priority, |
| 113 | + resource_policy=resource_policy or {}, |
| 114 | + kernel_specs=kernel_specs or [], |
| 115 | + creation_spec=creation_spec or {}, |
| 116 | + scaling_group=scaling_group, |
| 117 | + session_tag=session_tag, |
| 118 | + starts_at=starts_at, |
| 119 | + batch_timeout=batch_timeout, |
| 120 | + dependency_sessions=dependency_sessions, |
| 121 | + callback_url=callback_url, |
| 122 | + route_id=route_id, |
| 123 | + sudo_session_enabled=sudo_session_enabled, |
| 124 | + network=network, |
| 125 | + designated_agent_list=designated_agent_list, |
| 126 | + internal_data=internal_data, |
| 127 | + public_sgroup_only=public_sgroup_only, |
| 128 | + ) |
| 129 | + |
| 130 | + return create_spec |
| 131 | + |
| 132 | + |
58 | 133 | class TestContainerLimitRule: |
59 | 134 | """Test cases for ContainerLimitRule.""" |
60 | 135 |
|
@@ -211,6 +286,79 @@ def test_inaccessible_sgroup(self, basic_context): |
211 | 286 | assert "not accessible" in str(exc_info.value) |
212 | 287 |
|
213 | 288 |
|
| 289 | +class TestSessionTypeRule: |
| 290 | + """Test cases for SessionTypeRule.""" |
| 291 | + |
| 292 | + def test_allowed_session_type( |
| 293 | + self, |
| 294 | + basic_context: SessionCreationContext, |
| 295 | + session_spec_factory: Callable[..., SessionCreationSpec], |
| 296 | + ) -> None: |
| 297 | + """Test session type that is allowed in scaling group.""" |
| 298 | + rule = SessionTypeRule() |
| 299 | + |
| 300 | + allowed_groups = [ |
| 301 | + AllowedScalingGroup( |
| 302 | + name="test-sg", |
| 303 | + is_private=False, |
| 304 | + scheduler_opts=ScalingGroupOpts( |
| 305 | + allowed_session_types=[SessionTypes.INTERACTIVE, SessionTypes.BATCH] |
| 306 | + ), |
| 307 | + ) |
| 308 | + ] |
| 309 | + |
| 310 | + spec = session_spec_factory( |
| 311 | + session_type=SessionTypes.INTERACTIVE, |
| 312 | + scaling_group="test-sg", |
| 313 | + ) |
| 314 | + |
| 315 | + # Should not raise |
| 316 | + rule.validate(spec, basic_context, allowed_groups) |
| 317 | + |
| 318 | + def test_disallowed_session_type( |
| 319 | + self, |
| 320 | + basic_context: SessionCreationContext, |
| 321 | + session_spec_factory: Callable[..., SessionCreationSpec], |
| 322 | + ) -> None: |
| 323 | + """Test session type that is not allowed in scaling group.""" |
| 324 | + rule = SessionTypeRule() |
| 325 | + |
| 326 | + allowed_groups = [ |
| 327 | + AllowedScalingGroup( |
| 328 | + name="batch-only-sg", |
| 329 | + is_private=False, |
| 330 | + scheduler_opts=ScalingGroupOpts(allowed_session_types=[SessionTypes.BATCH]), |
| 331 | + ) |
| 332 | + ] |
| 333 | + |
| 334 | + spec = session_spec_factory( |
| 335 | + session_type=SessionTypes.INTERACTIVE, |
| 336 | + scaling_group="batch-only-sg", |
| 337 | + ) |
| 338 | + |
| 339 | + with pytest.raises(InvalidAPIParameters) as exc_info: |
| 340 | + rule.validate(spec, basic_context, allowed_groups) |
| 341 | + assert "not allowed in scaling group" in str(exc_info.value) |
| 342 | + |
| 343 | + def test_empty_allowed_groups( |
| 344 | + self, |
| 345 | + basic_context: SessionCreationContext, |
| 346 | + session_spec_factory: Callable[..., SessionCreationSpec], |
| 347 | + ) -> None: |
| 348 | + """Test with empty allowed groups list.""" |
| 349 | + rule = SessionTypeRule() |
| 350 | + |
| 351 | + spec = session_spec_factory( |
| 352 | + session_type=SessionTypes.INTERACTIVE, |
| 353 | + scaling_group="any-sg", |
| 354 | + ) |
| 355 | + |
| 356 | + # Should raise - no allowed groups available |
| 357 | + with pytest.raises(InvalidAPIParameters) as exc_info: |
| 358 | + rule.validate(spec, basic_context, []) |
| 359 | + assert "not accessible" in str(exc_info.value) |
| 360 | + |
| 361 | + |
214 | 362 | class TestServicePortRule: |
215 | 363 | """Test cases for ServicePortRule.""" |
216 | 364 |
|
|
0 commit comments