diff --git a/changes/7177.enhance.md b/changes/7177.enhance.md new file mode 100644 index 00000000000..e86e0f8ba1b --- /dev/null +++ b/changes/7177.enhance.md @@ -0,0 +1 @@ +Explicitly raise error when operation fails in auth repository diff --git a/docs/manager/graphql-reference/schema.graphql b/docs/manager/graphql-reference/schema.graphql index 723f49f077a..0a5fd39fa2f 100644 --- a/docs/manager/graphql-reference/schema.graphql +++ b/docs/manager/graphql-reference/schema.graphql @@ -128,7 +128,7 @@ type Query { """Added in 24.03.1""" id: String reference: String - architecture: String = "aarch64" + architecture: String = "x86_64" ): Image images( """ @@ -2341,7 +2341,7 @@ type Mutation { ): RescanImages preload_image(references: [String]!, target_agents: [String]!): PreloadImage unload_image(references: [String]!, target_agents: [String]!): UnloadImage - modify_image(architecture: String = "aarch64", props: ModifyImageInput!, target: String!): ModifyImage + modify_image(architecture: String = "x86_64", props: ModifyImageInput!, target: String!): ModifyImage """Added in 25.6.0""" clear_image_custom_resource_limit(key: ClearImageCustomResourceLimitKey!): ClearImageCustomResourceLimitPayload @@ -2350,7 +2350,7 @@ type Mutation { forget_image_by_id(image_id: String!): ForgetImageById """Deprecated since 25.4.0. Use `forget_image_by_id` instead.""" - forget_image(architecture: String = "aarch64", reference: String!): ForgetImage @deprecated(reason: "Deprecated since 25.4.0. Use `forget_image_by_id` instead.") + forget_image(architecture: String = "x86_64", reference: String!): ForgetImage @deprecated(reason: "Deprecated since 25.4.0. Use `forget_image_by_id` instead.") """Added in 25.4.0""" purge_image_by_id( @@ -2362,7 +2362,7 @@ type Mutation { """Added in 24.03.1""" untag_image_from_registry(image_id: String!): UntagImageFromRegistry - alias_image(alias: String!, architecture: String = "aarch64", target: String!): AliasImage + alias_image(alias: String!, architecture: String = "x86_64", target: String!): AliasImage dealias_image(alias: String!): DealiasImage clear_images(registry: String): ClearImages @@ -2937,7 +2937,7 @@ type ClearImageCustomResourceLimitPayload { """Added in 25.6.0.""" input ClearImageCustomResourceLimitKey { image_canonical: String! - architecture: String! = "aarch64" + architecture: String! = "x86_64" } """Added in 24.03.0.""" diff --git a/docs/manager/graphql-reference/supergraph.graphql b/docs/manager/graphql-reference/supergraph.graphql index 00214d82e92..c7bc588487a 100644 --- a/docs/manager/graphql-reference/supergraph.graphql +++ b/docs/manager/graphql-reference/supergraph.graphql @@ -1129,7 +1129,7 @@ input ClearImageCustomResourceLimitKey @join__type(graph: GRAPHENE) { image_canonical: String! - architecture: String! = "aarch64" + architecture: String! = "x86_64" } """Added in 25.6.0.""" @@ -4461,7 +4461,7 @@ type Mutation ): RescanImages @join__field(graph: GRAPHENE) preload_image(references: [String]!, target_agents: [String]!): PreloadImage @join__field(graph: GRAPHENE) unload_image(references: [String]!, target_agents: [String]!): UnloadImage @join__field(graph: GRAPHENE) - modify_image(architecture: String = "aarch64", props: ModifyImageInput!, target: String!): ModifyImage @join__field(graph: GRAPHENE) + modify_image(architecture: String = "x86_64", props: ModifyImageInput!, target: String!): ModifyImage @join__field(graph: GRAPHENE) """Added in 25.6.0""" clear_image_custom_resource_limit(key: ClearImageCustomResourceLimitKey!): ClearImageCustomResourceLimitPayload @join__field(graph: GRAPHENE) @@ -4470,7 +4470,7 @@ type Mutation forget_image_by_id(image_id: String!): ForgetImageById @join__field(graph: GRAPHENE) """Deprecated since 25.4.0. Use `forget_image_by_id` instead.""" - forget_image(architecture: String = "aarch64", reference: String!): ForgetImage @join__field(graph: GRAPHENE) @deprecated(reason: "Deprecated since 25.4.0. Use `forget_image_by_id` instead.") + forget_image(architecture: String = "x86_64", reference: String!): ForgetImage @join__field(graph: GRAPHENE) @deprecated(reason: "Deprecated since 25.4.0. Use `forget_image_by_id` instead.") """Added in 25.4.0""" purge_image_by_id( @@ -4482,7 +4482,7 @@ type Mutation """Added in 24.03.1""" untag_image_from_registry(image_id: String!): UntagImageFromRegistry @join__field(graph: GRAPHENE) - alias_image(alias: String!, architecture: String = "aarch64", target: String!): AliasImage @join__field(graph: GRAPHENE) + alias_image(alias: String!, architecture: String = "x86_64", target: String!): AliasImage @join__field(graph: GRAPHENE) dealias_image(alias: String!): DealiasImage @join__field(graph: GRAPHENE) clear_images(registry: String): ClearImages @join__field(graph: GRAPHENE) @@ -5514,7 +5514,7 @@ type Query """Added in 24.03.1""" id: String reference: String - architecture: String = "aarch64" + architecture: String = "x86_64" ): Image @join__field(graph: GRAPHENE) images( """ diff --git a/src/ai/backend/manager/api/auth.py b/src/ai/backend/manager/api/auth.py index b18fea71099..1071e3bdb39 100644 --- a/src/ai/backend/manager/api/auth.py +++ b/src/ai/backend/manager/api/auth.py @@ -942,7 +942,7 @@ async def update_full_name(request: web.Request, params: Any) -> web.Response: domain_name = request["user"]["domain_name"] email = request["user"]["email"] log.info("AUTH.UPDATE_FULL_NAME(d:{}, email:{})", domain_name, email) - result = await root_ctx.processors.auth.update_full_name.wait_for_complete( + await root_ctx.processors.auth.update_full_name.wait_for_complete( UpdateFullNameAction( user_id=request["user"]["uuid"], full_name=params["full_name"], @@ -950,11 +950,6 @@ async def update_full_name(request: web.Request, params: Any) -> web.Response: email=email, ) ) - - if not result.success: - log.info("AUTH.UPDATE_FULL_NAME(d:{}, email:{}): Unknown user", domain_name, email) - return web.json_response({"error_msg": "Unknown user"}, status=HTTPStatus.BAD_REQUEST) - return web.json_response({}, status=HTTPStatus.OK) diff --git a/src/ai/backend/manager/models/user.py b/src/ai/backend/manager/models/user.py index 285b7be2f26..1e8a7502f4f 100644 --- a/src/ai/backend/manager/models/user.py +++ b/src/ai/backend/manager/models/user.py @@ -23,6 +23,7 @@ from ai.backend.manager.data.auth.hash import PasswordHashAlgorithm from ai.backend.manager.data.model_serving.types import UserData as ModelServingUserData from ai.backend.manager.data.user.types import UserCreator, UserData, UserRole, UserStatus +from ai.backend.manager.errors.auth import AuthorizationFailed from ai.backend.manager.models.hasher.types import HashInfo, PasswordInfo from .base import ( @@ -252,23 +253,6 @@ def load_main_keypair(cls) -> Callable: def load_resource_policy(cls) -> Callable: return joinedload(UserRow.resource_policy_row) - @classmethod - async def query_user_by_uuid( - cls, - user_uuid: UUID, - db_session: SASession, - ) -> Optional[Self]: - user_query = ( - sa.select(UserRow) - .where(UserRow.uuid == user_uuid) - .options( - joinedload(UserRow.main_keypair), - selectinload(UserRow.keypairs), - ) - ) - user_row = await db_session.scalar(user_query) - return user_row - @classmethod async def query_by_condition( cls, @@ -440,7 +424,7 @@ async def check_credential_with_migration( domain: str, email: str, target_password_info: PasswordInfo, -) -> Any: +) -> dict: """ Check user credentials and optionally migrate password hash if needed. @@ -451,8 +435,12 @@ async def check_credential_with_migration( target_password_info: Password configuration containing password and target hash settings Returns: - User row if credentials are valid, None otherwise + User row if credentials are valid + + Raises: + AuthorizationFailed: If user not found, password not set, or password mismatch """ + async with db.begin_readonly() as conn: result = await conn.execute( sa.select([users]) @@ -463,16 +451,15 @@ async def check_credential_with_migration( ) row = result.first() if row is None: - return None + raise AuthorizationFailed("User credential mismatch.") if row["password"] is None: - # user password is not set. - return None + raise AuthorizationFailed("User credential mismatch.") try: if not _verify_password(target_password_info.password, row["password"]): - return None + raise AuthorizationFailed("User credential mismatch.") except ValueError: - return None + raise AuthorizationFailed("User credential mismatch.") # Password is valid, check if we need to migrate the hash current_hash_info = HashInfo.from_hash_string(row["password"]) @@ -498,7 +485,7 @@ async def check_credential( domain: str, email: str, password: str, -) -> Any: +) -> dict[str, Any]: """ Check user credentials without migration (for signout, update password, etc.) @@ -509,8 +496,12 @@ async def check_credential( password: Plain text password to verify Returns: - User row if credentials are valid, None otherwise + User row if credentials are valid + + Raises: + AuthorizationFailed: If user not found, password not set, or password mismatch """ + async with db.begin_readonly() as conn: result = await conn.execute( sa.select([users]) @@ -521,15 +512,14 @@ async def check_credential( ) row = result.first() if row is None: - return None + raise AuthorizationFailed("User credential mismatch.") if row["password"] is None: - # user password is not set. - return None + raise AuthorizationFailed("User credential mismatch.") try: if not _verify_password(password, row["password"]): - return None + raise AuthorizationFailed("User credential mismatch.") except ValueError: - return None + raise AuthorizationFailed("User credential mismatch.") return row diff --git a/src/ai/backend/manager/repositories/auth/db_source/db_source.py b/src/ai/backend/manager/repositories/auth/db_source/db_source.py index 14551560570..0cdb30f5584 100644 --- a/src/ai/backend/manager/repositories/auth/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/auth/db_source/db_source.py @@ -3,12 +3,13 @@ from __future__ import annotations from datetime import datetime -from typing import Optional +from typing import Any, Optional from uuid import UUID import sqlalchemy as sa +from sqlalchemy.orm import joinedload, selectinload -from ai.backend.common.exception import BackendAIError +from ai.backend.common.exception import BackendAIError, UserNotFound from ai.backend.common.metrics.metric import DomainType, LayerType from ai.backend.common.resilience.policies.metrics import MetricArgs, MetricPolicy from ai.backend.common.resilience.policies.retry import BackoffStrategy, RetryArgs, RetryPolicy @@ -124,8 +125,8 @@ async def insert_user_with_keypair( return self._user_row_to_data(user_row) @auth_db_source_resilience.apply() - async def modify_user_full_name(self, email: str, domain_name: str, full_name: str) -> bool: - """Modify user's full name in database. Returns True if updated, False if user not found.""" + async def modify_user_full_name(self, email: str, domain_name: str, full_name: str) -> None: + """Modify user's full name in database.""" async with self._db.begin() as conn: query = ( sa.select(users) @@ -135,12 +136,11 @@ async def modify_user_full_name(self, email: str, domain_name: str, full_name: s result = await conn.execute(query) user_row = result.first() if not user_row: - return False + raise UserNotFound(extra_data={"email": email, "domain": domain_name}) data = {"full_name": full_name} update_query = users.update().values(data).where(users.c.email == email) await conn.execute(update_query) - return True @auth_db_source_resilience.apply() async def modify_user_password(self, email: str, password_info: PasswordInfo) -> None: @@ -239,7 +239,7 @@ async def verify_credential_with_migration( domain_name: str, email: str, target_password_info: PasswordInfo, - ) -> Optional[dict]: + ) -> dict[str, Any]: """Verify credentials with password migration support.""" return await check_credential_with_migration( db=self._db, @@ -254,7 +254,7 @@ async def verify_credential_without_migration( domain_name: str, email: str, password: str, - ) -> Optional[dict]: + ) -> dict[str, Any]: """Verify credentials without password migration (for signout, etc.)""" return await check_credential( db=self._db, @@ -264,10 +264,21 @@ async def verify_credential_without_migration( ) @auth_db_source_resilience.apply() - async def fetch_user_row_by_uuid(self, user_uuid: UUID) -> Optional[UserRow]: + async def fetch_user_row_by_uuid(self, user_uuid: UUID) -> UserRow: """Fetch user row by UUID from database.""" async with self._db.begin_session() as db_session: - return await UserRow.query_user_by_uuid(user_uuid, db_session) + user_query = ( + sa.select(UserRow) + .where(UserRow.uuid == user_uuid) + .options( + joinedload(UserRow.main_keypair), + selectinload(UserRow.keypairs), + ) + ) + user_row = await db_session.scalar(user_query) + if user_row is None: + raise UserNotFound(extra_data=user_uuid) + return user_row @auth_db_source_resilience.apply() async def fetch_current_time(self) -> datetime: diff --git a/src/ai/backend/manager/repositories/auth/repository.py b/src/ai/backend/manager/repositories/auth/repository.py index d8dfb5e7799..5dfec2b90cf 100644 --- a/src/ai/backend/manager/repositories/auth/repository.py +++ b/src/ai/backend/manager/repositories/auth/repository.py @@ -45,8 +45,8 @@ async def create_user_with_keypair( ) @auth_repository_resilience.apply() - async def update_user_full_name(self, email: str, domain_name: str, full_name: str) -> bool: - return await self._db_source.modify_user_full_name(email, domain_name, full_name) + async def update_user_full_name(self, email: str, domain_name: str, full_name: str) -> None: + await self._db_source.modify_user_full_name(email, domain_name, full_name) @auth_repository_resilience.apply() async def update_user_password(self, email: str, password_info: PasswordInfo) -> None: @@ -76,7 +76,7 @@ async def check_credential_with_migration( domain_name: str, email: str, target_password_info: PasswordInfo, - ) -> Optional[dict]: + ) -> dict: return await self._db_source.verify_credential_with_migration( domain_name, email, target_password_info ) @@ -87,14 +87,14 @@ async def check_credential_without_migration( domain_name: str, email: str, password: str, - ) -> Optional[dict]: + ) -> dict: """Check credentials without password migration (for signout, etc.)""" return await self._db_source.verify_credential_without_migration( domain_name, email, password ) @auth_repository_resilience.apply() - async def get_user_row_by_uuid(self, user_uuid: UUID) -> Optional[UserRow]: + async def get_user_row_by_uuid(self, user_uuid: UUID) -> UserRow: return await self._db_source.fetch_user_row_by_uuid(user_uuid) @auth_repository_resilience.apply() diff --git a/src/ai/backend/manager/services/auth/service.py b/src/ai/backend/manager/services/auth/service.py index 181e4402fe5..65d791450d7 100644 --- a/src/ai/backend/manager/services/auth/service.py +++ b/src/ai/backend/manager/services/auth/service.py @@ -18,7 +18,6 @@ GroupMembershipNotFoundError, PasswordExpired, UserCreationError, - UserNotFound, ) from ai.backend.manager.errors.common import ( GenericBadRequest, @@ -139,16 +138,12 @@ async def authorize(self, action: AuthorizeAction) -> AuthorizeActionResult: action.email, target_password_info=target_password_info, ) - if user is None: - raise AuthorizationFailed("User credential mismatch.") if user["status"] == UserStatus.BEFORE_VERIFICATION: raise AuthorizationFailed("This account needs email verification.") if user["status"] in INACTIVE_USER_STATUSES: raise AuthorizationFailed("User credential mismatch.") await self._check_password_age(user, auth_config) user_row = await self._auth_repository.get_user_row_by_uuid(user["uuid"]) - if user_row is None: - raise UserNotFound(extra_data=user["uuid"]) main_keypair_row = user_row.get_main_keypair_row() if main_keypair_row is None: raise AuthorizationFailed("No API keypairs found.") @@ -288,22 +283,20 @@ async def signout(self, action: SignoutAction) -> SignoutActionResult: if action.email != action.requester_email: raise GenericForbidden("Not the account owner") email = action.email - result = await self._auth_repository.check_credential_without_migration( + await self._auth_repository.check_credential_without_migration( action.domain_name, email, action.password, ) - if result is None: - raise GenericBadRequest("Invalid email and/or password") await self._auth_repository.deactivate_user_and_keypairs(email) return SignoutActionResult(success=True) async def update_full_name(self, action: UpdateFullNameAction) -> UpdateFullNameActionResult: - success = await self._auth_repository.update_user_full_name( + await self._auth_repository.update_user_full_name( action.email, action.domain_name, action.full_name ) - return UpdateFullNameActionResult(success=success) + return UpdateFullNameActionResult(success=True) async def update_password(self, action: UpdatePasswordAction) -> UpdatePasswordActionResult: domain_name = action.domain_name @@ -316,13 +309,14 @@ async def update_password(self, action: UpdatePasswordAction) -> UpdatePasswordA success=False, message="new password mismatch", ) - user = await self._auth_repository.check_credential_without_migration( - domain_name, - email, - action.old_password, - ) - if user is None: - log.info(log_fmt + ": old password mismtach", *log_args) + try: + await self._auth_repository.check_credential_without_migration( + domain_name, + email, + action.old_password, + ) + except AuthorizationFailed: + log.info(log_fmt + ": old password mismatch", *log_args) raise AuthorizationFailed("Old password mismatch") # [Hooking point for VERIFY_PASSWORD_FORMAT with the ALL_COMPLETED requirement] @@ -364,8 +358,6 @@ async def update_password_no_auth( action.email, password=action.current_password, ) - if checked_user is None: - raise AuthorizationFailed("User credential mismatch.") new_password = action.new_password if compare_to_hashed_password(new_password, checked_user["password"]): raise AuthorizationFailed("Cannot update to the same password as an existing password.") diff --git a/tests/manager/repositories/auth/test_auth_repository.py b/tests/manager/repositories/auth/test_auth_repository.py index 3d36dcf9df0..0b07b0c50f4 100644 --- a/tests/manager/repositories/auth/test_auth_repository.py +++ b/tests/manager/repositories/auth/test_auth_repository.py @@ -11,6 +11,7 @@ import pytest import sqlalchemy as sa +from ai.backend.common.exception import UserNotFound from ai.backend.manager.data.auth.hash import PasswordHashAlgorithm from ai.backend.manager.data.auth.types import UserData from ai.backend.manager.data.group.types import GroupData @@ -359,12 +360,10 @@ async def test_update_user_full_name( ) -> None: """Test updating user full name""" update_name = "Updated Full Name" - result = await auth_repository.update_user_full_name( + await auth_repository.update_user_full_name( sample_user_data.email, sample_user_data.domain_name, update_name ) - assert result is True - # Verify full name was updated async with database_engine.begin_session() as db_sess: user = await db_sess.scalar( @@ -373,6 +372,16 @@ async def test_update_user_full_name( assert user is not None assert user.full_name == update_name + @pytest.mark.asyncio + async def test_update_user_full_name_not_found( + self, auth_repository: AuthRepository, default_domain: DomainTestData + ) -> None: + """Test updating user full name when user doesn't exist""" + with pytest.raises(UserNotFound): + await auth_repository.update_user_full_name( + "nonexistent@example.com", default_domain.name, "Some Name" + ) + @pytest.mark.asyncio async def test_update_user_password( self, @@ -499,6 +508,14 @@ async def test_get_user_row_by_uuid( assert result.uuid == sample_user_data.uuid assert result.email == sample_user_data.email + @pytest.mark.asyncio + async def test_get_user_row_by_uuid_not_found(self, auth_repository: AuthRepository) -> None: + """Test getting user row by UUID when user doesn't exist""" + non_existent_uuid = UUID("99999999-9999-9999-9999-999999999999") + + with pytest.raises(UserNotFound): + await auth_repository.get_user_row_by_uuid(non_existent_uuid) + @pytest.mark.asyncio async def test_get_current_time(self, auth_repository: AuthRepository) -> None: """Test getting current time from database""" diff --git a/tests/manager/services/auth/test_authorize.py b/tests/manager/services/auth/test_authorize.py index 83dd3301fa9..aa7baedef55 100644 --- a/tests/manager/services/auth/test_authorize.py +++ b/tests/manager/services/auth/test_authorize.py @@ -116,7 +116,9 @@ async def test_authorize_invalid_credentials( mock_hook_plugin_ctx.dispatch.return_value = HookResult( status=HookResults.PASSED, result=None, reason=None ) - mock_auth_repository.check_credential_with_migration.return_value = None + mock_auth_repository.check_credential_with_migration.side_effect = AuthorizationFailed( + "User credential mismatch." + ) action = AuthorizeAction( type=AuthTokenType.KEYPAIR, diff --git a/tests/manager/services/auth/test_signout.py b/tests/manager/services/auth/test_signout.py index e75349bb703..0b401ee7cd6 100644 --- a/tests/manager/services/auth/test_signout.py +++ b/tests/manager/services/auth/test_signout.py @@ -3,7 +3,8 @@ import pytest -from ai.backend.manager.errors.common import GenericBadRequest, GenericForbidden +from ai.backend.manager.errors.auth import AuthorizationFailed +from ai.backend.manager.errors.common import GenericForbidden from ai.backend.manager.models.user import UserRole, UserStatus from ai.backend.manager.repositories.auth.repository import AuthRepository from ai.backend.manager.services.auth.actions.signout import SignoutAction @@ -85,10 +86,12 @@ async def test_signout_fails_with_invalid_credentials( requester_email="user@example.com", ) - # Setup invalid credential check - returns None for invalid credentials - mock_auth_repository.check_credential_without_migration.return_value = None + # Setup invalid credential check - raises AuthorizationFailed for invalid credentials + mock_auth_repository.check_credential_without_migration.side_effect = AuthorizationFailed( + "User credential mismatch." + ) - with pytest.raises(GenericBadRequest): + with pytest.raises(AuthorizationFailed): await auth_service.signout(action) mock_auth_repository.check_credential_without_migration.assert_called_once() diff --git a/tests/manager/services/auth/test_update_full_name.py b/tests/manager/services/auth/test_update_full_name.py index a3156d4a9c6..863d088e533 100644 --- a/tests/manager/services/auth/test_update_full_name.py +++ b/tests/manager/services/auth/test_update_full_name.py @@ -2,6 +2,7 @@ import pytest +from ai.backend.common.exception import UserNotFound from ai.backend.manager.repositories.auth.repository import AuthRepository from ai.backend.manager.services.auth.actions.update_full_name import ( UpdateFullNameAction, @@ -60,16 +61,18 @@ async def test_update_full_name_fails_for_nonexistent_user( full_name="Some Name", ) - mock_auth_repository.update_user_full_name.return_value = False + mock_auth_repository.update_user_full_name.side_effect = UserNotFound( + extra_data={"email": action.email, "domain": action.domain_name} + ) - result = await auth_service.update_full_name(action) + with pytest.raises(UserNotFound): + await auth_service.update_full_name(action) mock_auth_repository.update_user_full_name.assert_called_once_with( action.email, action.domain_name, action.full_name, ) - assert result.success is False @pytest.mark.asyncio diff --git a/tests/manager/services/auth/test_update_password.py b/tests/manager/services/auth/test_update_password.py index 8a7f2a10fc4..b5a3d7743f9 100644 --- a/tests/manager/services/auth/test_update_password.py +++ b/tests/manager/services/auth/test_update_password.py @@ -137,8 +137,10 @@ async def test_update_password_fails_with_incorrect_old_password( reason=None, ) - # Invalid old password - mock_auth_repository.check_credential_without_migration.return_value = None + # Invalid old password - raises AuthorizationFailed + mock_auth_repository.check_credential_without_migration.side_effect = AuthorizationFailed( + "User credential mismatch." + ) with pytest.raises(AuthorizationFailed): await auth_service.update_password(action) diff --git a/tests/manager/services/auth/test_update_password_no_auth.py b/tests/manager/services/auth/test_update_password_no_auth.py index a2188b540d1..cb6f4dc4162 100644 --- a/tests/manager/services/auth/test_update_password_no_auth.py +++ b/tests/manager/services/auth/test_update_password_no_auth.py @@ -140,8 +140,10 @@ async def test_update_password_no_auth_fails_with_incorrect_current_password( reason=None, ) - # Invalid current password - mock_auth_repository.check_credential_without_migration.return_value = None + # Invalid current password - raises AuthorizationFailed + mock_auth_repository.check_credential_without_migration.side_effect = AuthorizationFailed( + "User credential mismatch." + ) with pytest.raises(AuthorizationFailed): await auth_service.update_password_no_auth(action)