Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/6413.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement RBAC Creator Pattern to ensure that RBAC records are created whenever any RBAC-related entities are created
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypeVar

import sqlalchemy as sa
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession as SASession

from ai.backend.logging import BraceStyleAdapter
from ai.backend.manager.data.permission.association_scopes_entities import (
AssociationScopesEntitiesCreateInput,
)
from ai.backend.manager.data.permission.id import ObjectId, ScopeId
from ai.backend.manager.models.rbac_models.association_scopes_entities import (
AssociationScopesEntitiesRow,
)

log = BraceStyleAdapter(logging.getLogger(__name__))


@dataclass
class RBACEntityCreateInput:
scope_id: ScopeId
object_id: ObjectId


TEntityCreateInput = TypeVar("TEntityCreateInput")
TCreatedEntity = TypeVar("TCreatedEntity")


class RBACEntityCreator(Generic[TEntityCreateInput, TCreatedEntity], ABC):
async def create_entity(
self,
db_session: SASession,
input: TEntityCreateInput,
rbac_input: RBACEntityCreateInput,
) -> TCreatedEntity:
result = await self._create_entity(db_session, input)
await self._create_rbac_entity(db_session, rbac_input)
return result

@abstractmethod
async def _create_entity(
self,
db_session: SASession,
input: TEntityCreateInput,
) -> TCreatedEntity:
raise NotImplementedError

async def _create_rbac_entity(
self,
db_session: SASession,
rbac_input: RBACEntityCreateInput,
) -> None:
scope_id = rbac_input.scope_id
entity_id = rbac_input.object_id
creator = AssociationScopesEntitiesCreateInput(
scope_id=scope_id,
object_id=entity_id,
)
try:
await db_session.execute(
sa.insert(AssociationScopesEntitiesRow).values(creator.fields_to_store())
)
except IntegrityError:
log.exception(
"entity and scope mapping already exists: {}, {}. Skipping.",
entity_id.to_str(),
scope_id.to_str(),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import logging
from abc import ABC, abstractmethod
from typing import Generic, TypeVar, final

import sqlalchemy as sa
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession as SASession

from ai.backend.logging import BraceStyleAdapter
from ai.backend.manager.data.permission.id import ObjectId, ScopeId
from ai.backend.manager.models.rbac_models.association_scopes_entities import (
AssociationScopesEntitiesRow,
)

log = BraceStyleAdapter(logging.getLogger(__name__))


class RBACEntityDeletor(ABC):
async def delete_entity(self, db_session: SASession) -> None:
scope_id = self.scope_id()
entity_id = self.object_id()
try:
await db_session.execute(
sa.delete(AssociationScopesEntitiesRow).where(
sa.and_(
AssociationScopesEntitiesRow.scope_id == scope_id.scope_id,
AssociationScopesEntitiesRow.scope_type == scope_id.scope_type,
AssociationScopesEntitiesRow.entity_id == entity_id,
AssociationScopesEntitiesRow.entity_type == entity_id.entity_type,
)
)
)
except IntegrityError:
log.exception(
"failed to delete entity and scope mapping: {}, {}.",
entity_id.to_str(),
scope_id.to_str(),
)

@abstractmethod
def scope_id(self) -> ScopeId:
raise NotImplementedError

@abstractmethod
def object_id(self) -> ObjectId:
raise NotImplementedError


TDeletedEntity = TypeVar("TDeletedEntity")


class RBACDeletor(Generic[TDeletedEntity], ABC):
def __init__(self, rbac_entity_deletor: RBACEntityDeletor) -> None:
self._rbac_entity_deletor = rbac_entity_deletor

@final
async def delete(self, db_session: SASession) -> TDeletedEntity:
entity = await self._delete(db_session)
await self._rbac_entity_deletor.delete_entity(db_session)
return entity

@abstractmethod
async def _delete(self, db_session: SASession) -> TDeletedEntity:
raise NotImplementedError
6 changes: 6 additions & 0 deletions tests/manager/repositories/permission_controller/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
python_tests(
name="tests",
dependencies=[
"src/ai/backend/manager:src",
],
)
Empty file.
23 changes: 23 additions & 0 deletions tests/manager/repositories/permission_controller/test_creator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
Tests for RBACEntityCreator and RBACCreator functionality.
Tests the creator classes with real database operations.
"""

# from __future__ import annotations

# import uuid

# import pytest
# import sqlalchemy as sa
# from sqlalchemy.ext.asyncio import AsyncSession as SASession

# from ai.backend.manager.data.permission.id import ObjectId, ScopeId
# from ai.backend.manager.data.permission.types import EntityType, ScopeType
# from ai.backend.manager.models.rbac_models.association_scopes_entities import (
# AssociationScopesEntitiesRow,
# )
# from ai.backend.manager.models.utils import ExtendedAsyncSAEngine
# from ai.backend.manager.repositories.permission_controller.creator import (
# RBACEntityCreateInput,
# RBACEntityCreator,
# )
Loading