Skip to content

Commit ab1bc21

Browse files
committed
feat: Apply RBAC validation to VFolder
1 parent dbed927 commit ab1bc21

File tree

10 files changed

+153
-71
lines changed

10 files changed

+153
-71
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from dataclasses import dataclass
2+
3+
from .batch import BatchActionValidator
4+
from .scope import ScopeActionValidator
5+
from .single_entity import SingleEntityActionValidator
6+
7+
8+
@dataclass
9+
class ValidatorArgs:
10+
batch: list[BatchActionValidator]
11+
scope: list[ScopeActionValidator]
12+
single_entity: list[SingleEntityActionValidator]

src/ai/backend/manager/actions/validators/rbac/scope.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from ai.backend.manager.data.permission.id import ScopeId
77
from ai.backend.manager.data.permission.role import ScopePermissionCheckInput
88
from ai.backend.manager.data.permission.types import EntityType, ScopeType
9+
from ai.backend.manager.errors.rbac import RBACForbidden
910
from ai.backend.manager.errors.user import UserNotFound
1011
from ai.backend.manager.repositories.permission_controller.repository import (
1112
PermissionControllerRepository,
@@ -30,7 +31,7 @@ async def validate(self, action: BaseScopeAction, meta: BaseActionTriggerMeta) -
3031
if user is None:
3132
raise UserNotFound("User not found in context")
3233

33-
await self._repository.check_permission_in_scope(
34+
is_valid = await self._repository.check_permission_in_scope(
3435
ScopePermissionCheckInput(
3536
user_id=user.user_id,
3637
operation=action.permission_operation_type(),
@@ -41,3 +42,8 @@ async def validate(self, action: BaseScopeAction, meta: BaseActionTriggerMeta) -
4142
),
4243
)
4344
)
45+
if not is_valid:
46+
raise RBACForbidden(
47+
"User does not have permission to perform this action in the specified scope "
48+
f"({scope_type.value}:{scope_id})"
49+
)

src/ai/backend/manager/actions/validators/rbac/single_entity.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from ai.backend.manager.data.permission.id import ObjectId
77
from ai.backend.manager.data.permission.role import SingleEntityPermissionCheckInput
88
from ai.backend.manager.data.permission.types import EntityType
9+
from ai.backend.manager.errors.rbac import RBACForbidden
910
from ai.backend.manager.errors.user import UserNotFound
1011
from ai.backend.manager.repositories.permission_controller.repository import (
1112
PermissionControllerRepository,
@@ -29,7 +30,7 @@ async def validate(self, action: BaseSingleEntityAction, meta: BaseActionTrigger
2930
if user is None:
3031
raise UserNotFound("User not found in context")
3132

32-
await self._repository.check_permission_of_entity(
33+
is_valid = await self._repository.check_permission_of_entity(
3334
SingleEntityPermissionCheckInput(
3435
user_id=user.user_id,
3536
operation=action.permission_operation_type(),
@@ -39,3 +40,8 @@ async def validate(self, action: BaseSingleEntityAction, meta: BaseActionTrigger
3940
),
4041
)
4142
)
43+
if not is_valid:
44+
raise RBACForbidden(
45+
"User does not have permission to perform this action on the specified entity "
46+
f"({entity_type.value}:{entity_id})"
47+
)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from aiohttp import web
2+
3+
from ai.backend.common.exception import (
4+
BackendAIError,
5+
ErrorCode,
6+
ErrorDetail,
7+
ErrorDomain,
8+
ErrorOperation,
9+
)
10+
11+
12+
class RBACForbidden(BackendAIError, web.HTTPForbidden):
13+
error_type = "https://api.backend.ai/probs/forbidden-operation"
14+
error_title = "The operation is forbidden due to insufficient RBAC permissions."
15+
16+
@classmethod
17+
def error_code(cls) -> ErrorCode:
18+
return ErrorCode(
19+
domain=ErrorDomain.PERMISSION,
20+
operation=ErrorOperation.ACCESS,
21+
error_detail=ErrorDetail.FORBIDDEN,
22+
)

src/ai/backend/manager/repositories/permission_controller/db_source.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -134,17 +134,6 @@ async def get_user_roles(self, user_id: uuid.UUID) -> list[RoleRow]:
134134
result = await db_session.scalars(stmt)
135135
return result.all()
136136

137-
async def get_entity_mapped_scopes(
138-
self, target_object_id: ObjectId
139-
) -> list[AssociationScopesEntitiesRow]:
140-
async with self._db.begin_readonly_session() as db_session:
141-
stmt = sa.select(AssociationScopesEntitiesRow.scope_id).where(
142-
AssociationScopesEntitiesRow.entity_id == target_object_id.entity_id,
143-
AssociationScopesEntitiesRow.entity_type == target_object_id.entity_type.value,
144-
)
145-
result = await db_session.scalars(stmt)
146-
return result.all()
147-
148137
async def check_scope_permission_exist(
149138
self,
150139
user_id: uuid.UUID,

src/ai/backend/manager/repositories/permission_controller/repository.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import uuid
22
from collections.abc import Mapping
3-
from typing import Optional
3+
from typing import Optional, Self
44

55
from ai.backend.common.exception import BackendAIError
66
from ai.backend.common.metrics.metric import DomainType, LayerType
@@ -20,6 +20,7 @@
2020
UserRoleAssignmentInput,
2121
)
2222
from ...models.utils import ExtendedAsyncSAEngine
23+
from ..types import RepositoryArgs
2324
from .db_source import PermissionDBSource
2425

2526
permission_controller_repository_resilience = Resilience(
@@ -47,6 +48,12 @@ class PermissionControllerRepository:
4748
def __init__(self, db: ExtendedAsyncSAEngine) -> None:
4849
self._db_source = PermissionDBSource(db)
4950

51+
@classmethod
52+
def create(cls, args: RepositoryArgs) -> Self:
53+
return cls(
54+
db=args.db,
55+
)
56+
5057
@permission_controller_repository_resilience.apply()
5158
async def create_role(self, data: RoleCreateInput) -> RoleData:
5259
"""
@@ -79,27 +86,20 @@ async def get_role(self, role_id: uuid.UUID) -> Optional[RoleData]:
7986

8087
@permission_controller_repository_resilience.apply()
8188
async def check_permission_of_entity(self, data: SingleEntityPermissionCheckInput) -> bool:
82-
target_object_id = data.target_object_id
83-
roles = await self._db_source.get_user_roles(data.user_id)
84-
associated_scopes = await self._db_source.get_entity_mapped_scopes(target_object_id)
85-
associated_scopes_set = set([row.parsed_scope_id() for row in associated_scopes])
86-
for role in roles:
87-
for object_perm in role.object_permission_rows:
88-
if object_perm.operation != data.operation:
89-
continue
90-
if object_perm.object_id() == target_object_id:
91-
return True
92-
93-
for permission_group in role.permission_group_rows:
94-
if permission_group.parsed_scope_id() not in associated_scopes_set:
95-
continue
96-
for permission in permission_group.permission_rows:
97-
if permission.operation == data.operation:
98-
return True
99-
return False
89+
"""
90+
Check if the user has the requested operation permission on the given entity.
91+
Returns True if the permission exists, False otherwise.
92+
"""
93+
return await self._db_source.check_object_permission_exist(
94+
data.user_id, data.target_object_id, data.operation
95+
)
10096

10197
@permission_controller_repository_resilience.apply()
10298
async def check_permission_in_scope(self, data: ScopePermissionCheckInput) -> bool:
99+
"""
100+
Check if the user has the requested operation permission in the given scope.
101+
Returns True if the permission exists, False otherwise.
102+
"""
103103
return await self._db_source.check_scope_permission_exist(
104104
data.user_id, data.target_scope_id, data.operation
105105
)

src/ai/backend/manager/repositories/repositories.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
from ai.backend.manager.repositories.metric.repositories import MetricRepositories
2424
from ai.backend.manager.repositories.model_serving.repositories import ModelServingRepositories
2525
from ai.backend.manager.repositories.object_storage.repositories import ObjectStorageRepositories
26+
from ai.backend.manager.repositories.permission_controller.repository import (
27+
PermissionControllerRepository,
28+
)
2629
from ai.backend.manager.repositories.project_resource_policy.repositories import (
2730
ProjectResourcePolicyRepositories,
2831
)
@@ -72,6 +75,7 @@ class Repositories:
7275
artifact: ArtifactRepositories
7376
artifact_registry: ArtifactRegistryRepositories
7477
storage_namespace: StorageNamespaceRepositories
78+
permission_controller: PermissionControllerRepository
7579

7680
@classmethod
7781
def create(cls, args: RepositoryArgs) -> Self:
@@ -100,6 +104,7 @@ def create(cls, args: RepositoryArgs) -> Self:
100104
huggingface_registry_repositories = HuggingFaceRegistryRepositories.create(args)
101105
artifact_registries = ArtifactRegistryRepositories.create(args)
102106
storage_namespace_repositories = StorageNamespaceRepositories.create(args)
107+
permission_controller_repository = PermissionControllerRepository.create(args)
103108

104109
return cls(
105110
agent=agent_repositories,
@@ -127,4 +132,5 @@ def create(cls, args: RepositoryArgs) -> Self:
127132
artifact=artifact_repositories,
128133
artifact_registry=artifact_registries,
129134
storage_namespace=storage_namespace_repositories,
135+
permission_controller=permission_controller_repository,
130136
)

src/ai/backend/manager/server.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,10 @@ async def processors_ctx(root_ctx: RootContext) -> AsyncIterator[None]:
654654
from .actions.monitors.audit_log import AuditLogMonitor
655655
from .actions.monitors.prometheus import PrometheusMonitor
656656
from .actions.monitors.reporter import ReporterMonitor
657+
from .actions.validator.args import ValidatorArgs
658+
from .actions.validators.rbac.batch import BatchActionRBACValidator
659+
from .actions.validators.rbac.scope import ScopeActionRBACValidator
660+
from .actions.validators.rbac.single_entity import SingleEntityActionRBACValidator
657661
from .reporters.hub import ReporterHub, ReporterHubArgs
658662
from .services.processors import ProcessorArgs, Processors, ServiceArgs
659663

@@ -667,6 +671,15 @@ async def processors_ctx(root_ctx: RootContext) -> AsyncIterator[None]:
667671
reporter_monitor = ReporterMonitor(reporter_hub)
668672
prometheus_monitor = PrometheusMonitor()
669673
audit_log_monitor = AuditLogMonitor(root_ctx.db)
674+
batch_action_rbac_validator = BatchActionRBACValidator(
675+
root_ctx.repositories.permission_controller
676+
)
677+
single_entity_rbac_validator = SingleEntityActionRBACValidator(
678+
root_ctx.repositories.permission_controller
679+
)
680+
scope_action_rbac_validator = ScopeActionRBACValidator(
681+
root_ctx.repositories.permission_controller
682+
)
670683
root_ctx.processors = Processors.create(
671684
ProcessorArgs(
672685
service_args=ServiceArgs(
@@ -689,9 +702,14 @@ async def processors_ctx(root_ctx: RootContext) -> AsyncIterator[None]:
689702
deployment_controller=root_ctx.deployment_controller,
690703
event_producer=root_ctx.event_producer,
691704
agent_cache=root_ctx.agent_cache,
692-
)
705+
),
706+
action_monitors=[reporter_monitor, prometheus_monitor, audit_log_monitor],
707+
action_validator_args=ValidatorArgs(
708+
batch=[batch_action_rbac_validator],
709+
single_entity=[single_entity_rbac_validator],
710+
scope=[scope_action_rbac_validator],
711+
),
693712
),
694-
[reporter_monitor, prometheus_monitor, audit_log_monitor],
695713
)
696714
yield
697715

src/ai/backend/manager/services/processors.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ai.backend.common.plugin.monitor import ErrorPluginContext
1313
from ai.backend.manager.actions.monitors.monitor import ActionMonitor
1414
from ai.backend.manager.actions.types import AbstractProcessorPackage, ActionSpec
15+
from ai.backend.manager.actions.validator.args import ValidatorArgs
1516
from ai.backend.manager.agent_cache import AgentRPCCache
1617
from ai.backend.manager.config.provider import ManagerConfigProvider
1718
from ai.backend.manager.idle import IdleCheckerHost
@@ -318,6 +319,8 @@ def create(cls, args: ServiceArgs) -> Self:
318319
@dataclass
319320
class ProcessorArgs:
320321
service_args: ServiceArgs
322+
action_monitors: list[ActionMonitor]
323+
action_validator_args: ValidatorArgs
321324

322325

323326
@dataclass
@@ -349,61 +352,65 @@ class Processors(AbstractProcessorPackage):
349352
storage_namespace: StorageNamespaceProcessors
350353

351354
@classmethod
352-
def create(cls, args: ProcessorArgs, action_monitors: list[ActionMonitor]) -> Self:
355+
def create(cls, args: ProcessorArgs) -> Self:
353356
services = Services.create(args.service_args)
354-
agent_processors = AgentProcessors(services.agent, action_monitors)
355-
domain_processors = DomainProcessors(services.domain, action_monitors)
356-
group_processors = GroupProcessors(services.group, action_monitors)
357-
user_processors = UserProcessors(services.user, action_monitors)
358-
image_processors = ImageProcessors(services.image, action_monitors)
357+
agent_processors = AgentProcessors(services.agent, args.action_monitors)
358+
domain_processors = DomainProcessors(services.domain, args.action_monitors)
359+
group_processors = GroupProcessors(services.group, args.action_monitors)
360+
user_processors = UserProcessors(services.user, args.action_monitors)
361+
image_processors = ImageProcessors(services.image, args.action_monitors)
359362
container_registry_processors = ContainerRegistryProcessors(
360-
services.container_registry, action_monitors
363+
services.container_registry, args.action_monitors
361364
)
362-
vfolder_processors = VFolderProcessors(services.vfolder, action_monitors)
363-
vfolder_file_processors = VFolderFileProcessors(services.vfolder_file, action_monitors)
365+
vfolder_processors = VFolderProcessors(
366+
services.vfolder, args.action_monitors, args.action_validator_args
367+
)
368+
vfolder_file_processors = VFolderFileProcessors(services.vfolder_file, args.action_monitors)
364369
vfolder_invite_processors = VFolderInviteProcessors(
365-
services.vfolder_invite, action_monitors
370+
services.vfolder_invite, args.action_monitors
366371
)
367-
session_processors = SessionProcessors(services.session, action_monitors)
372+
session_processors = SessionProcessors(services.session, args.action_monitors)
368373
keypair_resource_policy_processors = KeypairResourcePolicyProcessors(
369-
services.keypair_resource_policy, action_monitors
374+
services.keypair_resource_policy, args.action_monitors
370375
)
371376
user_resource_policy_processors = UserResourcePolicyProcessors(
372-
services.user_resource_policy, action_monitors
377+
services.user_resource_policy, args.action_monitors
373378
)
374379
project_resource_policy_processors = ProjectResourcePolicyProcessors(
375-
services.project_resource_policy, action_monitors
380+
services.project_resource_policy, args.action_monitors
376381
)
377382
resource_preset_processors = ResourcePresetProcessors(
378-
services.resource_preset, action_monitors
383+
services.resource_preset, args.action_monitors
384+
)
385+
model_serving_processors = ModelServingProcessors(
386+
services.model_serving, args.action_monitors
379387
)
380-
model_serving_processors = ModelServingProcessors(services.model_serving, action_monitors)
381388
model_serving_auto_scaling_processors = ModelServingAutoScalingProcessors(
382-
services.model_serving_auto_scaling, action_monitors
389+
services.model_serving_auto_scaling, args.action_monitors
383390
)
384391
utilization_metric_processors = UtilizationMetricProcessors(
385-
services.utilization_metric, action_monitors
392+
services.utilization_metric, args.action_monitors
386393
)
387-
auth = AuthProcessors(services.auth, action_monitors)
394+
auth = AuthProcessors(services.auth, args.action_monitors)
388395
object_storage_processors = ObjectStorageProcessors(
389-
services.object_storage, action_monitors
396+
services.object_storage, args.action_monitors
390397
)
391-
vfs_storage_processors = VFSStorageProcessors(services.vfs_storage, action_monitors)
392-
artifact_processors = ArtifactProcessors(services.artifact, action_monitors)
398+
vfs_storage_processors = VFSStorageProcessors(services.vfs_storage, args.action_monitors)
399+
artifact_processors = ArtifactProcessors(services.artifact, args.action_monitors)
393400
artifact_registry_processors = ArtifactRegistryProcessors(
394-
services.artifact_registry, action_monitors
401+
services.artifact_registry, args.action_monitors
395402
)
396403
artifact_revision_processors = ArtifactRevisionProcessors(
397-
services.artifact_revision, action_monitors
404+
services.artifact_revision, args.action_monitors
398405
)
399406

400407
# Initialize deployment processors if service is available
401408
deployment_processors = None
402409
if services.deployment is not None:
403-
deployment_processors = DeploymentProcessors(services.deployment, action_monitors)
410+
deployment_processors = DeploymentProcessors(services.deployment, args.action_monitors)
404411

405412
storage_namespace_processors = StorageNamespaceProcessors(
406-
services.storage_namespace, action_monitors
413+
services.storage_namespace, args.action_monitors
407414
)
408415

409416
return cls(

0 commit comments

Comments
 (0)