diff --git a/backend/btrixcloud/crawlconfigs.py b/backend/btrixcloud/crawlconfigs.py
index 1dcfabc615..37e6695743 100644
--- a/backend/btrixcloud/crawlconfigs.py
+++ b/backend/btrixcloud/crawlconfigs.py
@@ -27,6 +27,11 @@
 import aiohttp
 
 from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response
+from motor.motor_asyncio import (
+    AsyncIOMotorClient,
+    AsyncIOMotorClientSession,
+    AsyncIOMotorDatabase,
+)
 import pymongo
 
 from .pagination import DEFAULT_PAGE_SIZE, paginated_format
@@ -126,8 +131,8 @@ class CrawlConfigOps:
 
     def __init__(
         self,
-        dbclient,
-        mdb,
+        dbclient: AsyncIOMotorClient,
+        mdb: AsyncIOMotorDatabase,
         user_manager,
         org_ops,
         crawl_manager,
@@ -190,7 +195,7 @@ def __init__(
            raise TypeError("The channel list must include a 'default' channel")
 
        self._crawler_proxies_last_updated = None
-       self._crawler_proxies_map = None
+       self._crawler_proxies_map: dict[str, CrawlerProxy] | None = None
 
        if DEFAULT_PROXY_ID and DEFAULT_PROXY_ID not in self.get_crawler_proxies_map():
            raise ValueError(
@@ -648,7 +653,10 @@ async def update_crawl_config(
         orig_dict["id"] = uuid4()
 
         last_rev = ConfigRevision(**orig_dict)
-        last_rev = await self.config_revs.insert_one(last_rev.to_dict())
+        # [TODO] 2025-12-04 emma: I don't think this needs to be assigned to
+        # a variable at all, but I'm not 100% sure so for now I'm ignoring
+        # the type mismatch here.
+        last_rev = await self.config_revs.insert_one(last_rev.to_dict())  # type: ignore
 
         # set update query
         query = update.dict(exclude_unset=True)
@@ -914,10 +922,14 @@ async def mark_profiles_in_use(self, profiles: List[Profile], org: Organization)
 
         return profiles
 
-    async def get_running_crawl(self, cid: UUID) -> Optional[CrawlOut]:
+    async def get_running_crawl(
+        self, cid: UUID, session: AsyncIOMotorClientSession | None = None
+    ) -> Optional[CrawlOut]:
         """Return the id of currently running crawl for this config, if any"""
         # crawls = await self.crawl_manager.list_running_crawls(cid=crawlconfig.id)
-        crawls, _ = await self.crawl_ops.list_crawls(cid=cid, running_only=True)
+        crawls, _ = await self.crawl_ops.list_crawls(
+            cid=cid, running_only=True, session=session
+        )
 
         if len(crawls) == 1:
             return crawls[0]
@@ -1090,13 +1102,18 @@ async def get_crawl_config_revs(
         return revisions, total
 
     async def make_inactive_or_delete(
-        self, crawlconfig: CrawlConfig, org: Organization
+        self,
+        crawlconfig: CrawlConfig,
+        org: Organization,
+        session: AsyncIOMotorClientSession | None = None,
     ):
         """Make config inactive if crawls exist, otherwise move to inactive list"""
         query = {"inactive": True}
 
-        is_running = await self.get_running_crawl(crawlconfig.id) is not None
+        is_running = (
+            await self.get_running_crawl(crawlconfig.id, session=session) is not None
+        )
         if is_running:
             raise HTTPException(
                 status_code=400, detail="crawl_running_cant_deactivate"
             )
@@ -1107,7 +1124,8 @@ async def make_inactive_or_delete(
         # if no crawls have been run, actually delete
         if not crawlconfig.crawlAttemptCount:
             result = await self.crawl_configs.delete_one(
-                {"_id": crawlconfig.id, "oid": crawlconfig.oid}
+                {"_id": crawlconfig.id, "oid": crawlconfig.oid},
+                session=session,
             )
 
             if result.deleted_count != 1:
@@ -1116,7 +1134,7 @@ async def make_inactive_or_delete(
             if crawlconfig and crawlconfig.config.seedFileId:
                 try:
                     await self.file_ops.delete_seed_file(
-                        crawlconfig.config.seedFileId, org
+                        crawlconfig.config.seedFileId, org, session=session
                     )
                 except HTTPException:
                     pass
@@ -1127,6 +1145,7 @@ async def make_inactive_or_delete(
 
         if not await self.crawl_configs.find_one_and_update(
             {"_id": crawlconfig.id, "inactive": {"$ne": True}},
             {"$set": query},
+            session=session,
         ):
             raise HTTPException(status_code=404, detail="failed_to_deactivate")
@@ -1142,7 +1161,9 @@ async def do_make_inactive(self, crawlconfig: CrawlConfig, org: Organization):
 
         async with await self.dbclient.start_session() as sesh:
             async with sesh.start_transaction():
-                status = await self.make_inactive_or_delete(crawlconfig, org)
+                status = await self.make_inactive_or_delete(
+                    crawlconfig, org, session=sesh
+                )
 
         return {"success": True, "status": status}
 
@@ -1249,10 +1270,12 @@ async def run_now_internal(
         await self.check_if_too_many_waiting_crawls(org)
 
         if crawlconfig.profileid:
-            profile_filename, crawlconfig.proxyId, _ = (
-                await self.profiles.get_profile_filename_proxy_channel(
-                    crawlconfig.profileid, org
-                )
+            (
+                profile_filename,
+                crawlconfig.proxyId,
+                _,
+            ) = await self.profiles.get_profile_filename_proxy_channel(
+                crawlconfig.profileid, org
             )
             if not profile_filename:
                 raise HTTPException(status_code=400, detail="invalid_profile_id")
@@ -1612,8 +1635,8 @@ async def stats_recompute_all(crawl_configs, crawls, cid: UUID):
 # pylint: disable=redefined-builtin,invalid-name,too-many-locals,too-many-arguments
 def init_crawl_config_api(
     app,
-    dbclient,
-    mdb,
+    dbclient: AsyncIOMotorClient,
+    mdb: AsyncIOMotorDatabase,
     user_dep,
     user_manager,
     org_ops,
diff --git a/backend/btrixcloud/crawls.py b/backend/btrixcloud/crawls.py
index eb273b5c7e..1ad6705f94 100644
--- a/backend/btrixcloud/crawls.py
+++ b/backend/btrixcloud/crawls.py
@@ -25,6 +25,7 @@
 from fastapi import Depends, HTTPException, Query, Request
 from fastapi.responses import StreamingResponse
 
+from motor.motor_asyncio import AsyncIOMotorClientSession
 from redis import asyncio as exceptions
 from redis.asyncio.client import Redis
 import pymongo
@@ -183,6 +184,7 @@ async def list_crawls(
         sort_by: Optional[str] = None,
         sort_direction: int = -1,
         resources: bool = False,
+        session: AsyncIOMotorClientSession | None = None,
     ):
         """List all finished crawls from the db"""
         # pylint: disable=too-many-locals,too-many-branches,too-many-statements
@@ -330,7 +332,7 @@ async def list_crawls(
         )
 
         # Get total
-        cursor = self.crawls.aggregate(aggregate)
+        cursor = self.crawls.aggregate(aggregate, session=session)
         results = await cursor.to_list(length=1)
         result = results[0]
         items = result["items"]
diff --git a/backend/btrixcloud/file_uploads.py b/backend/btrixcloud/file_uploads.py
index 7d322f9f40..8cff9ff24a 100644
--- a/backend/btrixcloud/file_uploads.py
+++ b/backend/btrixcloud/file_uploads.py
@@ -9,6 +9,7 @@
 import aiohttp
 
 from fastapi import APIRouter, Depends, HTTPException, Request
+from motor.motor_asyncio import AsyncIOMotorClientSession, AsyncIOMotorDatabase
 import pymongo
 
 from .models import (
@@ -49,7 +50,7 @@ class FileUploadOps:
 
     # pylint: disable=too-many-locals, too-many-arguments, invalid-name
-    def __init__(self, mdb, org_ops, storage_ops):
+    def __init__(self, mdb: AsyncIOMotorDatabase, org_ops, storage_ops):
         self.files = mdb["file_uploads"]
         self.crawl_configs = mdb["crawl_configs"]
         self.crawls = mdb["crawls"]
@@ -72,6 +73,7 @@ async def get_file_raw(
         file_id: UUID,
         org: Optional[Organization] = None,
         type_: Optional[str] = None,
+        session: AsyncIOMotorClientSession | None = None,
     ) -> Dict[str, Any]:
         """Get raw file from db"""
         query: dict[str, object] = {"_id": file_id}
@@ -81,7 +83,7 @@ async def get_file_raw(
         if type_:
             query["type"] = type_
 
-        res = await self.files.find_one(query)
+        res = await self.files.find_one(query, session=session)
 
         if not res:
             raise HTTPException(status_code=404, detail="file_not_found")
@@ -93,9 +95,10 @@ async def get_seed_file(
         file_id: UUID,
         org: Optional[Organization] = None,
         type_: Optional[str] = None,
+        session: AsyncIOMotorClientSession | None = None,
     ) -> SeedFile:
         """Get file by UUID"""
-        file_raw = await self.get_file_raw(file_id, org, type_)
+        file_raw = await self.get_file_raw(file_id, org, type_, session=session)
         return SeedFile.from_dict(file_raw)
 
     async def get_seed_file_out(
@@ -316,7 +319,10 @@ async def _parse_seed_info_from_file(
         return first_seed, seed_count
 
     async def delete_seed_file(
-        self, file_id: UUID, org: Organization
+        self,
+        file_id: UUID,
+        org: Organization,
+        session: AsyncIOMotorClientSession | None = None,
     ) -> Dict[str, bool]:
         """Delete user-uploaded file from storage and db"""
         file = await self.get_seed_file(file_id, org)
@@ -337,7 +343,7 @@ async def delete_seed_file(
         await self.files.delete_one({"_id": file_id, "oid": org.id})
         if file.type == "seedFile":
             await self.org_ops.inc_org_bytes_stored_field(
-                org.id, "bytesStoredSeedFiles", -file.size
+                org.id, "bytesStoredSeedFiles", -file.size, session=session
             )
 
         return {"success": True}
diff --git a/backend/btrixcloud/orgs.py b/backend/btrixcloud/orgs.py
index 41dab0cec1..cfb8b9b03c 100644
--- a/backend/btrixcloud/orgs.py
+++ b/backend/btrixcloud/orgs.py
@@ -14,6 +14,7 @@
 
 from typing import Optional, TYPE_CHECKING, Dict, Callable, List, AsyncGenerator, Any
 
+from motor.motor_asyncio import AsyncIOMotorClientSession
 from pydantic import ValidationError
 from pymongo import ReturnDocument
 from pymongo.errors import AutoReconnect, DuplicateKeyError
@@ -1439,10 +1440,12 @@ async def delete_org_and_data(
     async def recalculate_storage(self, org: Organization) -> dict[str, bool]:
         """Recalculate org storage use"""
         try:
-            total_crawl_size, crawl_size, upload_size = (
-                await self.base_crawl_ops.calculate_org_crawl_file_storage(
-                    org.id,
-                )
+            (
+                total_crawl_size,
+                crawl_size,
+                upload_size,
+            ) = await self.base_crawl_ops.calculate_org_crawl_file_storage(
+                org.id,
             )
             profile_size = await self.profile_ops.calculate_org_profile_file_storage(
                 org.id
@@ -1485,11 +1488,19 @@ async def set_last_crawl_finished(self, oid: UUID):
             {"$set": {"lastCrawlFinished": last_crawl_finished}},
         )
 
-    async def inc_org_bytes_stored_field(self, oid: UUID, field: str, size: int):
+    async def inc_org_bytes_stored_field(
+        self,
+        oid: UUID,
+        field: str,
+        size: int,
+        session: AsyncIOMotorClientSession | None = None,
+    ):
         """Increment specific org bytesStored* field"""
         try:
             await self.orgs.find_one_and_update(
-                {"_id": oid}, {"$inc": {field: size, "bytesStored": size}}
+                {"_id": oid},
+                {"$inc": {field: size, "bytesStored": size}},
+                session=session,
             )
         # pylint: disable=broad-exception-caught
         except Exception as err:
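
Below is a minimal standalone sketch (not part of the patch) of the pattern this diff applies throughout: helper methods grow an optional AsyncIOMotorClientSession parameter and forward it to every collection call, so that a caller such as do_make_inactive can wrap updates across several collections in one transaction. The collection names, the delete_item helper, and the connection URL here are illustrative assumptions, not code from the repo:

import asyncio

from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorClientSession


async def delete_item(
    db, item_id: str, session: AsyncIOMotorClientSession | None = None
) -> None:
    """Delete an item and its children, joining the caller's transaction if given."""
    # session=None means the writes run outside any transaction, so
    # existing callers keep working unchanged.
    await db["items"].delete_one({"_id": item_id}, session=session)
    await db["children"].delete_many({"parent": item_id}, session=session)


async def main() -> None:
    client = AsyncIOMotorClient("mongodb://localhost:27017")  # assumed URL
    db = client["example"]

    # Note: MongoDB transactions require a replica set; a standalone
    # mongod will reject start_transaction().
    async with await client.start_session() as sesh:
        async with sesh.start_transaction():
            # Both deletes commit or abort together.
            await delete_item(db, "item-1", session=sesh)


asyncio.run(main())

Because session defaults to None, and PyMongo/Motor treat session=None as "no explicit session", every existing call site keeps its current non-transactional behavior without modification.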