diff --git a/cvat/apps/dataset_manager/bindings.py b/cvat/apps/dataset_manager/bindings.py index cfd04f9d9de8..6c5f9113486f 100644 --- a/cvat/apps/dataset_manager/bindings.py +++ b/cvat/apps/dataset_manager/bindings.py @@ -45,7 +45,7 @@ ShapeType, Task, ) -from cvat.apps.engine.rq_job_handler import RQJobMetaField +from cvat.apps.engine.rq import ImportRQMeta from ..engine.log import ServerLogManager from .annotation import AnnotationIR, AnnotationManager, TrackManager @@ -2452,9 +2452,10 @@ def load_dataset_data(project_annotation, dataset: dm.Dataset, project_data): raise CvatImportError(f'Target project does not have label with name "{label.name}"') for subset_id, subset in enumerate(dataset.subsets().values()): job = rq.get_current_job() - job.meta[RQJobMetaField.STATUS] = 'Task from dataset is being created...' - job.meta[RQJobMetaField.PROGRESS] = (subset_id + job.meta.get(RQJobMetaField.TASK_PROGRESS, 0.)) / len(dataset.subsets().keys()) - job.save_meta() + job_meta = ImportRQMeta.for_job(job) + job_meta.status = 'Task from dataset is being created...' 
+ job_meta.progress = (subset_id + (job_meta.task_progress or 0.)) / len(dataset.subsets().keys()) + job_meta.save() task_fields = { 'project': project_annotation.db_project, diff --git a/cvat/apps/dataset_manager/project.py b/cvat/apps/dataset_manager/project.py index ae03e480aa25..014fdcc27239 100644 --- a/cvat/apps/dataset_manager/project.py +++ b/cvat/apps/dataset_manager/project.py @@ -18,7 +18,7 @@ from cvat.apps.engine import models from cvat.apps.engine.log import DatasetLogManager from cvat.apps.engine.model_utils import bulk_create -from cvat.apps.engine.rq_job_handler import RQJobMetaField +from cvat.apps.engine.rq import ImportRQMeta from cvat.apps.engine.serializers import DataSerializer, TaskWriteSerializer from cvat.apps.engine.task import _create_thread as create_task @@ -198,9 +198,10 @@ def data(self) -> dict: @transaction.atomic def import_dataset_as_project(src_file, project_id, format_name, conv_mask_to_poly): rq_job = rq.get_current_job() - rq_job.meta[RQJobMetaField.STATUS] = 'Dataset import has been started...' - rq_job.meta[RQJobMetaField.PROGRESS] = 0. - rq_job.save_meta() + rq_job_meta = ImportRQMeta.for_job(rq_job) + rq_job_meta.status = 'Dataset import has been started...' + rq_job_meta.progress = 0. 
+ rq_job_meta.save() project = ProjectAnnotationAndData(project_id) project.init_from_db() diff --git a/cvat/apps/dataset_manager/views.py b/cvat/apps/dataset_manager/views.py index 6061e3054f41..f3a38596905f 100644 --- a/cvat/apps/dataset_manager/views.py +++ b/cvat/apps/dataset_manager/views.py @@ -20,7 +20,7 @@ import cvat.apps.dataset_manager.task as task from cvat.apps.engine.log import ServerLogManager from cvat.apps.engine.models import Job, Project, Task -from cvat.apps.engine.rq_job_handler import RQMeta +from cvat.apps.engine.rq import ExportRQMeta from cvat.apps.engine.utils import get_rq_lock_by_user from .formats.registry import EXPORT_FORMATS, IMPORT_FORMATS @@ -88,7 +88,8 @@ def _patched_retry(*_1, **_2): settings.CVAT_QUEUES.EXPORT_DATA.value ) - user_id = current_rq_job.meta.get('user', {}).get('id') or -1 + rq_job_meta = ExportRQMeta.for_job(current_rq_job) + user_id = rq_job_meta.user.id or -1 with get_rq_lock_by_user(settings.CVAT_QUEUES.EXPORT_DATA.value, user_id): scheduled_rq_job: rq.job.Job = scheduler.enqueue_in( @@ -97,7 +98,7 @@ def _patched_retry(*_1, **_2): *current_rq_job.args, **current_rq_job.kwargs, job_id=current_rq_job.id, - meta=RQMeta.reset_meta_on_retry(current_rq_job.meta), + meta=rq_job_meta.get_meta_on_retry(), job_ttl=current_rq_job.ttl, job_result_ttl=current_rq_job.result_ttl, job_description=current_rq_job.description, diff --git a/cvat/apps/engine/background.py b/cvat/apps/engine/background.py index 63e38e4b1301..994b6d6e0e20 100644 --- a/cvat/apps/engine/background.py +++ b/cvat/apps/engine/background.py @@ -37,14 +37,12 @@ Task, ) from cvat.apps.engine.permissions import get_cloud_storage_for_import_or_export -from cvat.apps.engine.rq_job_handler import RQId, RQJobMetaField +from cvat.apps.engine.rq import ExportRQMeta, RQId, define_dependent_job from cvat.apps.engine.serializers import RqIdSerializer from cvat.apps.engine.types import ExtendedRequest from cvat.apps.engine.utils import ( build_annotations_file_name, 
build_backup_file_name, - define_dependent_job, - get_rq_job_meta, get_rq_lock_by_user, get_rq_lock_for_job, sendfile, @@ -229,7 +227,7 @@ def _handle_rq_job_v1( ) -> Optional[Response]: def is_result_outdated() -> bool: - return rq_job.meta[RQJobMetaField.REQUEST]["timestamp"] < instance_update_time + return ExportRQMeta.for_job(rq_job).request.timestamp < instance_update_time def handle_local_download() -> Response: with dm.util.get_export_cache_lock( @@ -342,7 +340,7 @@ def handle_local_download() -> Response: f"Export to {self.export_args.location} location is not implemented yet" ) elif rq_job_status == RQJobStatus.FAILED: - exc_info = rq_job.meta.get(RQJobMetaField.FORMATTED_EXCEPTION, str(rq_job.exc_info)) + exc_info = ExportRQMeta.for_job(rq_job).formatted_exception or str(rq_job.exc_info) rq_job.delete() return Response(exc_info, status=status.HTTP_500_INTERNAL_SERVER_ERROR) elif ( @@ -478,6 +476,11 @@ def setup_background_job( result_url = self.make_result_url() with get_rq_lock_by_user(queue, user_id): + meta = ExportRQMeta.build_for( + request=self.request, + db_obj=self.db_instance, + result_url=result_url, + ) queue.enqueue_call( func=func, args=func_args, @@ -485,9 +488,7 @@ def setup_background_job( "server_url": server_address, }, job_id=rq_id, - meta=get_rq_job_meta( - request=self.request, db_obj=self.db_instance, result_url=result_url - ), + meta=meta, depends_on=define_dependent_job(queue, user_id, rq_id=rq_id), result_ttl=cache_ttl.total_seconds(), failure_ttl=cache_ttl.total_seconds(), @@ -548,7 +549,7 @@ def _handle_rq_job_v1( ) -> Optional[Response]: def is_result_outdated() -> bool: - return rq_job.meta[RQJobMetaField.REQUEST]["timestamp"] < last_instance_update_time + return ExportRQMeta.for_job(rq_job).request.timestamp < last_instance_update_time last_instance_update_time = timezone.localtime(self.db_instance.updated_date) timestamp = self.get_timestamp(last_instance_update_time) @@ -644,7 +645,7 @@ def is_result_outdated() -> bool: 
f"Export to {self.export_args.location} location is not implemented yet" ) elif rq_job_status == RQJobStatus.FAILED: - exc_info = rq_job.meta.get(RQJobMetaField.FORMATTED_EXCEPTION, str(rq_job.exc_info)) + exc_info = ExportRQMeta.for_job(rq_job).formatted_exception or str(rq_job.exc_info) rq_job.delete() return Response(exc_info, status=status.HTTP_500_INTERNAL_SERVER_ERROR) elif ( @@ -756,13 +757,16 @@ def setup_background_job( user_id = self.request.user.id with get_rq_lock_by_user(queue, user_id): + meta = ExportRQMeta.build_for( + request=self.request, + db_obj=self.db_instance, + result_url=result_url, + ) queue.enqueue_call( func=func, args=func_args, job_id=rq_id, - meta=get_rq_job_meta( - request=self.request, db_obj=self.db_instance, result_url=result_url - ), + meta=meta, depends_on=define_dependent_job(queue, user_id, rq_id=rq_id), result_ttl=cache_ttl.total_seconds(), failure_ttl=cache_ttl.total_seconds(), diff --git a/cvat/apps/engine/backup.py b/cvat/apps/engine/backup.py index f08249b99cc8..f03de29f118e 100644 --- a/cvat/apps/engine/backup.py +++ b/cvat/apps/engine/backup.py @@ -67,7 +67,7 @@ StorageMethodChoice, ) from cvat.apps.engine.permissions import get_cloud_storage_for_import_or_export -from cvat.apps.engine.rq_job_handler import RQId, RQJobMetaField +from cvat.apps.engine.rq import ImportRQMeta, RQId, define_dependent_job from cvat.apps.engine.serializers import ( AnnotationGuideWriteSerializer, AssetWriteSerializer, @@ -89,8 +89,6 @@ from cvat.apps.engine.types import ExtendedRequest from cvat.apps.engine.utils import ( av_scan_paths, - define_dependent_job, - get_rq_job_meta, get_rq_lock_by_user, import_resource_with_clean_up_after, process_failed_job, @@ -1180,6 +1178,7 @@ def create_backup( log_exception(logger) raise + def _import( importer: TaskImporter | ProjectImporter, request: ExtendedRequest, @@ -1192,9 +1191,6 @@ def _import( ): rq_job = queue.fetch_job(rq_id) - if (user_id_from_meta := getattr(rq_job, 'meta', 
{}).get(RQJobMetaField.USER, {}).get('id')) and user_id_from_meta != request.user.id: - return Response(status=status.HTTP_403_FORBIDDEN) - if not rq_job: org_id = getattr(request.iam_context['organization'], 'id', None) location = location_conf.get('location') @@ -1239,19 +1235,25 @@ def _import( user_id = request.user.id with get_rq_lock_by_user(queue, user_id): + meta = ImportRQMeta.build_for( + request=request, + db_obj=None, + tmp_file=filename, + ) rq_job = queue.enqueue_call( func=func, args=func_args, job_id=rq_id, - meta={ - 'tmp_file': filename, - **get_rq_job_meta(request=request, db_obj=None) - }, + meta=meta, depends_on=define_dependent_job(queue, user_id), result_ttl=settings.IMPORT_CACHE_SUCCESS_TTL.total_seconds(), failure_ttl=settings.IMPORT_CACHE_FAILED_TTL.total_seconds() ) else: + rq_job_meta = ImportRQMeta.for_job(rq_job) + if rq_job_meta.user.id != request.user.id: + return Response(status=status.HTTP_403_FORBIDDEN) + if rq_job.is_finished: project_id = rq_job.return_value() rq_job.delete() diff --git a/cvat/apps/engine/cache.py b/cvat/apps/engine/cache.py index c89b5647e0e9..462cc2857e7f 100644 --- a/cvat/apps/engine/cache.py +++ b/cvat/apps/engine/cache.py @@ -53,7 +53,7 @@ ZipCompressedChunkWriter, load_image, ) -from cvat.apps.engine.rq_job_handler import RQJobMetaField +from cvat.apps.engine.rq import RQMetaWithFailureInfo from cvat.apps.engine.utils import ( CvatChunkTimestampMismatchError, format_list, @@ -107,9 +107,10 @@ def wait_for_rq_job(rq_job: rq.job.Job): if job_status in ("finished",): return elif job_status in ("failed",): - job_meta = rq_job.get_meta() - exc_type = job_meta.get(RQJobMetaField.EXCEPTION_TYPE, Exception) - exc_args = job_meta.get(RQJobMetaField.EXCEPTION_ARGS, ("Cannot create chunk",)) + rq_job.get_meta() # refresh from Redis + job_meta = RQMetaWithFailureInfo.for_job(rq_job) + exc_type = job_meta.exc_type or Exception + exc_args = job_meta.exc_args or ("Cannot create chunk",) raise exc_type(*exc_args) 
time.sleep(settings.CVAT_CHUNK_CREATE_CHECK_INTERVAL) diff --git a/cvat/apps/engine/mixins.py b/cvat/apps/engine/mixins.py index 66b087ca05c9..e4481731f2b2 100644 --- a/cvat/apps/engine/mixins.py +++ b/cvat/apps/engine/mixins.py @@ -41,7 +41,7 @@ RequestTarget, Task, ) -from cvat.apps.engine.rq_job_handler import RQId +from cvat.apps.engine.rq import RQId from cvat.apps.engine.serializers import DataSerializer, RqIdSerializer from cvat.apps.engine.types import ExtendedRequest from cvat.apps.engine.utils import is_dataset_export diff --git a/cvat/apps/engine/permissions.py b/cvat/apps/engine/permissions.py index d7a648e22f90..8f0123497643 100644 --- a/cvat/apps/engine/permissions.py +++ b/cvat/apps/engine/permissions.py @@ -13,7 +13,7 @@ from rest_framework.viewsets import ViewSet from rq.job import Job as RQJob -from cvat.apps.engine.rq_job_handler import is_rq_job_owner +from cvat.apps.engine.rq import is_rq_job_owner from cvat.apps.engine.types import ExtendedRequest from cvat.apps.engine.utils import is_dataset_export from cvat.apps.iam.permissions import ( diff --git a/cvat/apps/engine/rq.py b/cvat/apps/engine/rq.py new file mode 100644 index 000000000000..c860fc370032 --- /dev/null +++ b/cvat/apps/engine/rq.py @@ -0,0 +1,527 @@ +# Copyright (C) CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from datetime import datetime +from typing import Any, Callable, Optional, Protocol, Union +from uuid import UUID + +import attrs +from django.conf import settings +from django.db.models import Model +from django.utils import timezone +from django_rq.queues import DjangoRQ +from rq.job import Dependency as RQDependency +from rq.job import Job as RQJob +from rq.registry import BaseRegistry as RQBaseRegistry + +from cvat.apps.engine.types import ExtendedRequest + +from .models import RequestAction, RequestSubresource, RequestTarget + + +class RQJobMetaField: + class UserField: + ID = 
"id" + USERNAME = "username" + EMAIL = "email" + + class RequestField: + UUID = "uuid" + TIMESTAMP = "timestamp" + + # common fields + FORMATTED_EXCEPTION = "formatted_exception" + REQUEST = "request" + USER = "user" + PROJECT_ID = "project_id" + TASK_ID = "task_id" + JOB_ID = "job_id" + LAMBDA = "lambda" + ORG_ID = "org_id" + ORG_SLUG = "org_slug" + STATUS = "status" + PROGRESS = "progress" + TASK_PROGRESS = "task_progress" + # export specific fields + RESULT_URL = "result_url" + RESULT = "result" + FUNCTION_ID = "function_id" + EXCEPTION_TYPE = "exc_type" + EXCEPTION_ARGS = "exc_args" + TMP_FILE = "tmp_file" + + +class WithMeta(Protocol): + meta: dict[str, Any] + + +class _AbstractRQMetaAttribute: + def __init__( + self, key: str, *, optional: bool = False, validator: Callable | None = None + ) -> None: + assert validator is None or callable(validator), "validator must be callable" + self._key = key + self._validator = validator + self._optional = optional + + +class _GettableRQMetaAttribute(_AbstractRQMetaAttribute): + def __get__(self, instance: WithMeta, objtype: type[WithMeta] | None = None): + if self._optional: + return instance.meta.get(self._key) + + return instance.meta[self._key] + + +class _SettableRQMetaAttribute(_AbstractRQMetaAttribute): + def validate(self, value): + if value is None and not self._optional: + raise ValueError(f"{self._key} is required") + if value is not None and self._validator and not self._validator(value): + raise ValueError("Wrong type") + + def __set__(self, instance: WithMeta, value: Any): + self.validate(value) + instance.meta[self._key] = value + + +class ImmutableRQMetaAttribute(_GettableRQMetaAttribute): + pass + + +class MutableRQMetaAttribute(_GettableRQMetaAttribute, _SettableRQMetaAttribute): + pass + + +class UserRQMetaAttribute(ImmutableRQMetaAttribute): + def __init__(self, *, optional: bool = False, validator: Callable | None = None) -> None: + super().__init__(RQJobMetaField.USER, optional=optional, 
validator=validator) + + def __get__(self, instance: WithMeta, objtype: type[WithMeta] | None = None): + assert RQJobMetaField.USER == self._key + return UserMeta(instance.meta[self._key]) + + +class RequestRQMetaAttribute(ImmutableRQMetaAttribute): + def __init__(self, *, optional: bool = False, validator: Callable | None = None) -> None: + super().__init__(RQJobMetaField.REQUEST, optional=optional, validator=validator) + + def __get__(self, instance: WithMeta, objtype: type[WithMeta] | None = None): + assert RQJobMetaField.REQUEST == self._key + return RequestMeta(instance.meta[self._key]) + + +class UserMeta: + id: int = ImmutableRQMetaAttribute( + RQJobMetaField.UserField.ID, validator=lambda x: isinstance(x, int) + ) + username: str = ImmutableRQMetaAttribute( + RQJobMetaField.UserField.USERNAME, validator=lambda x: isinstance(x, str) + ) + email: str = ImmutableRQMetaAttribute( + RQJobMetaField.UserField.EMAIL, validator=lambda x: isinstance(x, str) + ) + + def __init__(self, meta: dict[RQJobMetaField.UserField, Any]) -> None: + self._meta = meta + + @property + def meta(self) -> dict[RQJobMetaField.UserField, Any]: + return self._meta + + def to_dict(self): + return self.meta + + +class RequestMeta: + uuid = ImmutableRQMetaAttribute( + RQJobMetaField.RequestField.UUID, validator=lambda x: isinstance(x, str) + ) + timestamp = ImmutableRQMetaAttribute( + RQJobMetaField.RequestField.TIMESTAMP, validator=lambda x: isinstance(x, datetime) + ) + + def __init__(self, meta: dict[RQJobMetaField.RequestField, Any]) -> None: + self._meta = meta + + @property + def meta(self) -> dict[RQJobMetaField.RequestField, Any]: + return self._meta + + def to_dict(self): + return self.meta + + +class AbstractRQMeta(metaclass=ABCMeta): + def __init__( + self, *, job: RQJob | None = None, meta: dict[RQJobMetaField, Any] | None = None + ) -> None: + assert (job and not meta) or (meta and not job), "Only job or meta can be passed" + self._job = job + self._meta = meta + + @property + 
def meta(self) -> dict[RQJobMetaField, Any]: + return self._job.meta if self._job else self._meta + + def to_dict(self): + return self.meta + + @classmethod + def for_job(cls, job: RQJob): + return cls(job=job) + + @classmethod + def for_meta(cls, meta: dict[RQJobMetaField, Any]): + return cls(meta=meta) + + def save(self) -> None: + assert isinstance(self._job, RQJob), "To save meta, rq job must be set" + self._job.save_meta() + + @staticmethod + @abstractmethod + def _get_resettable_fields() -> list[RQJobMetaField]: + """Return a list of fields that must be reset on retry""" + + def get_meta_on_retry(self) -> dict[RQJobMetaField, Any]: + resettable_fields = self._get_resettable_fields() + + return {k: v for k, v in self._job.meta.items() if k not in resettable_fields} + + +class RQMetaWithFailureInfo(AbstractRQMeta): + formatted_exception = MutableRQMetaAttribute( + RQJobMetaField.FORMATTED_EXCEPTION, + validator=lambda x: isinstance(x, str), + optional=True, + ) + exc_type = MutableRQMetaAttribute( + RQJobMetaField.EXCEPTION_TYPE, + validator=lambda x: issubclass(x, BaseException), + optional=True, + ) + exc_args = MutableRQMetaAttribute( + RQJobMetaField.EXCEPTION_ARGS, + validator=lambda x: isinstance(x, tuple), + optional=True, + ) + + @staticmethod + def _get_resettable_fields() -> list[RQJobMetaField]: + return [ + RQJobMetaField.FORMATTED_EXCEPTION, + RQJobMetaField.EXCEPTION_TYPE, + RQJobMetaField.EXCEPTION_ARGS, + ] + + +class BaseRQMeta(RQMetaWithFailureInfo): + # immutable && required fields + user: UserMeta = UserRQMetaAttribute() + request: RequestMeta = RequestRQMetaAttribute() + + # immutable && optional fields + org_id: int | None = ImmutableRQMetaAttribute( + RQJobMetaField.ORG_ID, validator=lambda x: isinstance(x, int), optional=True + ) + org_slug: str | None = ImmutableRQMetaAttribute( + RQJobMetaField.ORG_SLUG, validator=lambda x: isinstance(x, str), optional=True + ) + project_id: int | None = ImmutableRQMetaAttribute( + 
RQJobMetaField.PROJECT_ID, validator=lambda x: isinstance(x, int), optional=True + ) + task_id: int | None = ImmutableRQMetaAttribute( + RQJobMetaField.TASK_ID, validator=lambda x: isinstance(x, int), optional=True + ) + job_id: int | None = ImmutableRQMetaAttribute( + RQJobMetaField.JOB_ID, validator=lambda x: isinstance(x, int), optional=True + ) + + # mutable && optional fields + progress: float | None = MutableRQMetaAttribute( + RQJobMetaField.PROGRESS, validator=lambda x: isinstance(x, float), optional=True + ) + status: str | None = MutableRQMetaAttribute( + RQJobMetaField.STATUS, validator=lambda x: isinstance(x, str), optional=True + ) + + @staticmethod + def _get_resettable_fields() -> list[RQJobMetaField]: + return RQMetaWithFailureInfo._get_resettable_fields() + [ + RQJobMetaField.PROGRESS, + RQJobMetaField.STATUS, + ] + + @classmethod + def build( + cls, + *, + request: ExtendedRequest, + db_obj: Model | None, + ): + # to prevent circular import + from cvat.apps.events.handlers import job_id, organization_slug, task_id + from cvat.apps.webhooks.signals import organization_id, project_id + + oid = organization_id(db_obj) + oslug = organization_slug(db_obj) + pid = project_id(db_obj) + tid = task_id(db_obj) + jid = job_id(db_obj) + + user = request.user + + return cls.for_meta( + { + RQJobMetaField.USER: UserMeta( + { + RQJobMetaField.UserField.ID: user.id, + RQJobMetaField.UserField.USERNAME: user.username, + RQJobMetaField.UserField.EMAIL: getattr(user, "email", ""), + } + ).to_dict(), + RQJobMetaField.REQUEST: RequestMeta( + { + RQJobMetaField.RequestField.UUID: request.uuid, + RQJobMetaField.RequestField.TIMESTAMP: timezone.localtime(), + } + ).to_dict(), + RQJobMetaField.ORG_ID: oid, + RQJobMetaField.ORG_SLUG: oslug, + RQJobMetaField.PROJECT_ID: pid, + RQJobMetaField.TASK_ID: tid, + RQJobMetaField.JOB_ID: jid, + } + ).to_dict() + + +class ExportRQMeta(BaseRQMeta): + result_url: str | None = ImmutableRQMetaAttribute( + RQJobMetaField.RESULT_URL, 
validator=lambda x: isinstance(x, str), optional=True + ) # will be changed to ExportResultInfo in the next PR + + @staticmethod + def _get_resettable_fields() -> list[RQJobMetaField]: + base_fields = BaseRQMeta._get_resettable_fields() + return base_fields + [RQJobMetaField.RESULT] + + @classmethod + def build_for( + cls, + *, + request: ExtendedRequest, + db_obj: Model | None, + result_url: str | None, + ): + base_meta = BaseRQMeta.build(request=request, db_obj=db_obj) + + return cls.for_meta({**base_meta, RQJobMetaField.RESULT_URL: result_url}).to_dict() + + +class ImportRQMeta(BaseRQMeta): + # immutable && optional fields + tmp_file: str | None = ImmutableRQMetaAttribute( + RQJobMetaField.TMP_FILE, validator=lambda x: isinstance(x, str), optional=True + ) # used only when importing annotations|datasets|backups + + # mutable fields + task_progress: float | None = MutableRQMetaAttribute( + RQJobMetaField.TASK_PROGRESS, validator=lambda x: isinstance(x, float), optional=True + ) # used when importing project dataset + + @staticmethod + def _get_resettable_fields() -> list[RQJobMetaField]: + base_fields = BaseRQMeta._get_resettable_fields() + + return base_fields + [RQJobMetaField.TASK_PROGRESS] + + @classmethod + def build_for( + cls, + *, + request: ExtendedRequest, + db_obj: Model | None, + tmp_file: str | None = None, + ): + base_meta = BaseRQMeta.build(request=request, db_obj=db_obj) + + return cls.for_meta({**base_meta, RQJobMetaField.TMP_FILE: tmp_file}).to_dict() + + +def is_rq_job_owner(rq_job: RQJob, user_id: int) -> bool: + return BaseRQMeta.for_job(rq_job).user.id == user_id + + +@attrs.frozen() +class RQId: + action: RequestAction = attrs.field(validator=attrs.validators.instance_of(RequestAction)) + target: RequestTarget = attrs.field(validator=attrs.validators.instance_of(RequestTarget)) + identifier: Union[int, UUID] = attrs.field(validator=attrs.validators.instance_of((int, UUID))) + subresource: Optional[RequestSubresource] = attrs.field( + 
validator=attrs.validators.optional(attrs.validators.instance_of(RequestSubresource)), + kw_only=True, + default=None, + ) + user_id: Optional[int] = attrs.field( + validator=attrs.validators.optional(attrs.validators.instance_of(int)), + kw_only=True, + default=None, + ) + format: Optional[str] = attrs.field( + validator=attrs.validators.optional(attrs.validators.instance_of(str)), + kw_only=True, + default=None, + ) + + _OPTIONAL_FIELD_REQUIREMENTS = { + RequestAction.AUTOANNOTATE: {"subresource": False, "format": False, "user_id": False}, + RequestAction.CREATE: {"subresource": False, "format": False, "user_id": False}, + RequestAction.EXPORT: {"subresource": True, "user_id": True}, + RequestAction.IMPORT: {"subresource": True, "format": False, "user_id": False}, + } + + def __attrs_post_init__(self) -> None: + for field, req in self._OPTIONAL_FIELD_REQUIREMENTS[self.action].items(): + if req: + if getattr(self, field) is None: + raise ValueError(f"{field} is required for the {self.action} action") + else: + if getattr(self, field) is not None: + raise ValueError(f"{field} is not allowed for the {self.action} action") + + # RQ ID templates: + # autoannotate:task- + # import:-- + # create:task- + # export:---in--format-by- + # export:--backup-by- + + def render( + self, + ) -> str: + common_prefix = f"{self.action}:{self.target}-{self.identifier}" + + if RequestAction.IMPORT == self.action: + return f"{common_prefix}-{self.subresource}" + elif RequestAction.EXPORT == self.action: + if self.format is None: + return f"{common_prefix}-{self.subresource}-by-{self.user_id}" + + format_to_be_used_in_urls = self.format.replace(" ", "_").replace(".", "@") + return f"{common_prefix}-{self.subresource}-in-{format_to_be_used_in_urls}-format-by-{self.user_id}" + elif self.action in {RequestAction.CREATE, RequestAction.AUTOANNOTATE}: + return common_prefix + else: + assert False, f"Unsupported action {self.action!r} was found" + + @staticmethod + def parse(rq_id: str) -> 
RQId: + identifier: Optional[Union[UUID, int]] = None + subresource: Optional[RequestSubresource] = None + user_id: Optional[int] = None + anno_format: Optional[str] = None + + try: + action_and_resource, unparsed = rq_id.split("-", maxsplit=1) + action_str, target_str = action_and_resource.split(":") + action = RequestAction(action_str) + target = RequestTarget(target_str) + + if action in {RequestAction.CREATE, RequestAction.AUTOANNOTATE}: + identifier = unparsed + elif RequestAction.IMPORT == action: + identifier, subresource_str = unparsed.rsplit("-", maxsplit=1) + subresource = RequestSubresource(subresource_str) + else: # action == export + identifier, subresource_str, unparsed = unparsed.split("-", maxsplit=2) + subresource = RequestSubresource(subresource_str) + + if RequestSubresource.BACKUP == subresource: + _, user_id = unparsed.split("-") + else: + unparsed, _, user_id = unparsed.rsplit("-", maxsplit=2) + # remove prefix(in-), suffix(-format) and restore original format name + # by replacing special symbols: "_" -> " ", "@" -> "." 
+ anno_format = unparsed[3:-7].replace("_", " ").replace("@", ".") + + if identifier is not None: + if identifier.isdigit(): + identifier = int(identifier) + else: + identifier = UUID(identifier) + + if user_id is not None: + user_id = int(user_id) + + return RQId( + action=action, + target=target, + identifier=identifier, + subresource=subresource, + user_id=user_id, + format=anno_format, + ) + + except Exception as ex: + raise ValueError(f"The {rq_id!r} RQ ID cannot be parsed: {str(ex)}") from ex + + +def define_dependent_job( + queue: DjangoRQ, + user_id: int, + should_be_dependent: bool = settings.ONE_RUNNING_JOB_IN_QUEUE_PER_USER, + *, + rq_id: Optional[str] = None, +) -> Optional[RQDependency]: + if not should_be_dependent: + return None + + queues: list[RQBaseRegistry | DjangoRQ] = [ + queue.deferred_job_registry, + queue, + queue.started_job_registry, + ] + # Since there is no cleanup implementation in DeferredJobRegistry, + # this registry can contain "outdated" jobs that weren't deleted from it + # but were added to another registry. Probably such situations can occur + # if there are active or deferred jobs when restarting the worker container. + filters = [lambda job: job.is_deferred, lambda _: True, lambda _: True] + all_user_jobs: list[RQJob] = [] + for q, f in zip(queues, filters): + job_ids = q.get_job_ids() + jobs = q.job_class.fetch_many(job_ids, q.connection) + jobs = filter( + lambda job: job and BaseRQMeta.for_job(job).user.id == user_id and f(job), jobs + ) + all_user_jobs.extend(jobs) + + if rq_id: + # Prevent cases where an RQ job depends on itself. + # It isn't possible to have multiple RQ jobs with the same ID in Redis. + # However, a race condition in request processing can lead to self-dependencies + # when 2 parallel requests attempt to enqueue RQ jobs with the same ID. + # This happens if an rq_job is fetched without a lock, + # but a lock is used when defining the dependent job and enqueuing a new one. 
+ if any(rq_id == job.id for job in all_user_jobs): + return None + + # prevent possible cyclic dependencies + all_job_dependency_ids = { + dep_id.decode() for job in all_user_jobs for dep_id in job.dependency_ids or () + } + + if RQJob.redis_job_namespace_prefix + rq_id in all_job_dependency_ids: + return None + + return ( + RQDependency( + jobs=[sorted(all_user_jobs, key=lambda job: job.created_at)[-1]], allow_failure=True + ) + if all_user_jobs + else None + ) diff --git a/cvat/apps/engine/rq_job_handler.py b/cvat/apps/engine/rq_job_handler.py deleted file mode 100644 index d3b9c36e6072..000000000000 --- a/cvat/apps/engine/rq_job_handler.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright (C) CVAT.ai Corporation -# -# SPDX-License-Identifier: MIT - -from __future__ import annotations - -from typing import Any, Optional, Union -from uuid import UUID - -import attrs -from rq.job import Job as RQJob - -from .models import RequestAction, RequestSubresource, RequestTarget - - -class RQMeta: - @staticmethod - def get_resettable_fields() -> list[RQJobMetaField]: - """Return a list of fields that must be reset on retry""" - return [ - RQJobMetaField.FORMATTED_EXCEPTION, - RQJobMetaField.PROGRESS, - RQJobMetaField.TASK_PROGRESS, - RQJobMetaField.STATUS - ] - - @classmethod - def reset_meta_on_retry(cls, meta_to_update: dict[RQJobMetaField, Any]) -> dict[RQJobMetaField, Any]: - resettable_fields = cls.get_resettable_fields() - - return { - k: v for k, v in meta_to_update.items() if k not in resettable_fields - } - -class RQJobMetaField: - # common fields - FORMATTED_EXCEPTION = "formatted_exception" - REQUEST = 'request' - USER = 'user' - PROJECT_ID = 'project_id' - TASK_ID = 'task_id' - JOB_ID = 'job_id' - ORG_ID = 'org_id' - ORG_SLUG = 'org_slug' - STATUS = 'status' - PROGRESS = 'progress' - TASK_PROGRESS = 'task_progress' - # export specific fields - RESULT_URL = 'result_url' - FUNCTION_ID = 'function_id' - EXCEPTION_TYPE = 'exc_type' - EXCEPTION_ARGS = 'exc_args' - -def 
is_rq_job_owner(rq_job: RQJob, user_id: int) -> bool: - return rq_job.meta.get(RQJobMetaField.USER, {}).get('id') == user_id - -@attrs.frozen() -class RQId: - action: RequestAction = attrs.field( - validator=attrs.validators.instance_of(RequestAction) - ) - target: RequestTarget = attrs.field( - validator=attrs.validators.instance_of(RequestTarget) - ) - identifier: Union[int, UUID] = attrs.field( - validator=attrs.validators.instance_of((int, UUID)) - ) - subresource: Optional[RequestSubresource] = attrs.field( - validator=attrs.validators.optional( - attrs.validators.instance_of(RequestSubresource) - ), - kw_only=True, default=None, - ) - user_id: Optional[int] = attrs.field( - validator=attrs.validators.optional(attrs.validators.instance_of(int)), - kw_only=True, default=None, - ) - format: Optional[str] = attrs.field( - validator=attrs.validators.optional(attrs.validators.instance_of(str)), - kw_only=True, default=None, - ) - - _OPTIONAL_FIELD_REQUIREMENTS = { - RequestAction.AUTOANNOTATE: {"subresource": False, "format": False, "user_id": False}, - RequestAction.CREATE: {"subresource": False, "format": False, "user_id": False}, - RequestAction.EXPORT: {"subresource": True, "user_id": True}, - RequestAction.IMPORT: {"subresource": True, "format": False, "user_id": False}, - } - - def __attrs_post_init__(self) -> None: - for field, req in self._OPTIONAL_FIELD_REQUIREMENTS[self.action].items(): - if req: - if getattr(self, field) is None: - raise ValueError(f"{field} is required for the {self.action} action") - else: - if getattr(self, field) is not None: - raise ValueError(f"{field} is not allowed for the {self.action} action") - - # RQ ID templates: - # autoannotate:task- - # import:-- - # create:task- - # export:---in--format-by- - # export:--backup-by- - - def render( - self, - ) -> str: - common_prefix = f"{self.action}:{self.target}-{self.identifier}" - - if RequestAction.IMPORT == self.action: - return f"{common_prefix}-{self.subresource}" - elif 
RequestAction.EXPORT == self.action: - if self.format is None: - return ( - f"{common_prefix}-{self.subresource}-by-{self.user_id}" - ) - - format_to_be_used_in_urls = self.format.replace(" ", "_").replace(".", "@") - return f"{common_prefix}-{self.subresource}-in-{format_to_be_used_in_urls}-format-by-{self.user_id}" - elif self.action in {RequestAction.CREATE, RequestAction.AUTOANNOTATE}: - return common_prefix - else: - assert False, f"Unsupported action {self.action!r} was found" - - @staticmethod - def parse(rq_id: str) -> RQId: - identifier: Optional[Union[UUID, int]] = None - subresource: Optional[RequestSubresource] = None - user_id: Optional[int] = None - anno_format: Optional[str] = None - - try: - action_and_resource, unparsed = rq_id.split("-", maxsplit=1) - action_str, target_str = action_and_resource.split(":") - action = RequestAction(action_str) - target = RequestTarget(target_str) - - if action in {RequestAction.CREATE, RequestAction.AUTOANNOTATE}: - identifier = unparsed - elif RequestAction.IMPORT == action: - identifier, subresource_str = unparsed.rsplit("-", maxsplit=1) - subresource = RequestSubresource(subresource_str) - else: # action == export - identifier, subresource_str, unparsed = unparsed.split("-", maxsplit=2) - subresource = RequestSubresource(subresource_str) - - if RequestSubresource.BACKUP == subresource: - _, user_id = unparsed.split("-") - else: - unparsed, _, user_id = unparsed.rsplit("-", maxsplit=2) - # remove prefix(in-), suffix(-format) and restore original format name - # by replacing special symbols: "_" -> " ", "@" -> "." 
- anno_format = unparsed[3:-7].replace("_", " ").replace("@", ".") - - if identifier is not None: - if identifier.isdigit(): - identifier = int(identifier) - else: - identifier = UUID(identifier) - - if user_id is not None: - user_id = int(user_id) - - return RQId( - action=action, - target=target, - identifier=identifier, - subresource=subresource, - user_id=user_id, - format=anno_format, - ) - - except Exception as ex: - raise ValueError(f"The {rq_id!r} RQ ID cannot be parsed: {str(ex)}") from ex diff --git a/cvat/apps/engine/serializers.py b/cvat/apps/engine/serializers.py index 381d05a7719e..3c9960a384be 100644 --- a/cvat/apps/engine/serializers.py +++ b/cvat/apps/engine/serializers.py @@ -41,7 +41,7 @@ from cvat.apps.engine.log import ServerLogManager from cvat.apps.engine.model_utils import bulk_create from cvat.apps.engine.permissions import TaskPermission -from cvat.apps.engine.rq_job_handler import RQId, RQJobMetaField +from cvat.apps.engine.rq import BaseRQMeta, ExportRQMeta, ImportRQMeta, RequestAction, RQId from cvat.apps.engine.task_validation import HoneypotFrameSelector from cvat.apps.engine.utils import ( CvatChunkTimestampMismatchError, @@ -54,6 +54,7 @@ reverse, take_by, ) +from cvat.apps.lambda_manager.rq import LambdaRQMeta from utils.dataset_manifest import ImageManifestManager slogger = ServerLogManager(__name__) @@ -3508,7 +3509,8 @@ class RequestDataOperationSerializer(serializers.Serializer): def to_representation(self, rq_job: RQJob) -> dict[str, Any]: parsed_rq_id: RQId = rq_job.parsed_rq_id - return { + base_rq_job_meta = BaseRQMeta.for_job(rq_job) + representation = { "type": ":".join( [ parsed_rq_id.action, @@ -3516,12 +3518,16 @@ def to_representation(self, rq_job: RQJob) -> dict[str, Any]: ] ), "target": parsed_rq_id.target, - "project_id": rq_job.meta[RQJobMetaField.PROJECT_ID], - "task_id": rq_job.meta[RQJobMetaField.TASK_ID], - "job_id": rq_job.meta[RQJobMetaField.JOB_ID], - "format": parsed_rq_id.format, - "function_id": 
rq_job.meta.get(RQJobMetaField.FUNCTION_ID), + "project_id": base_rq_job_meta.project_id, + "task_id": base_rq_job_meta.task_id, + "job_id": base_rq_job_meta.job_id, } + if parsed_rq_id.action == RequestAction.AUTOANNOTATE: + representation["function_id"] = LambdaRQMeta.for_job(rq_job).function_id + elif parsed_rq_id.action in (RequestAction.IMPORT, RequestAction.EXPORT): + representation["format"] = parsed_rq_id.format + + return representation class RequestSerializer(serializers.Serializer): # SerializerMethodField is not used here to mark "status" field as required and fix schema generation. @@ -3545,17 +3551,23 @@ class RequestSerializer(serializers.Serializer): result_url = serializers.URLField(required=False, allow_null=True) result_id = serializers.IntegerField(required=False, allow_null=True) + def __init__(self, *args, **kwargs): + self._base_rq_job_meta: BaseRQMeta | None = None + super().__init__(*args, **kwargs) + @extend_schema_field(UserIdentifiersSerializer()) def get_owner(self, rq_job: RQJob) -> dict[str, Any]: - return UserIdentifiersSerializer(rq_job.meta[RQJobMetaField.USER]).data + assert self._base_rq_job_meta + return UserIdentifiersSerializer(self._base_rq_job_meta.user).data @extend_schema_field( serializers.FloatField(min_value=0, max_value=1, required=False, allow_null=True) ) def get_progress(self, rq_job: RQJob) -> Decimal: + rq_job_meta = ImportRQMeta.for_job(rq_job) # progress of task creation is stored in "task_progress" field # progress of project import is stored in "progress" field - return Decimal(rq_job.meta.get(RQJobMetaField.PROGRESS) or rq_job.meta.get(RQJobMetaField.TASK_PROGRESS) or 0.) + return Decimal(rq_job_meta.progress or rq_job_meta.task_progress or 0.) 
@extend_schema_field(serializers.DateTimeField(required=False, allow_null=True)) def get_expiry_date(self, rq_job: RQJob) -> Optional[str]: @@ -3573,20 +3585,19 @@ def get_expiry_date(self, rq_job: RQJob) -> Optional[str]: @extend_schema_field(serializers.CharField(allow_blank=True)) def get_message(self, rq_job: RQJob) -> str: + assert self._base_rq_job_meta rq_job_status = rq_job.get_status() message = '' if RQJobStatus.STARTED == rq_job_status: - message = rq_job.meta.get(RQJobMetaField.STATUS, '') + message = self._base_rq_job_meta.status or message elif RQJobStatus.FAILED == rq_job_status: - message = rq_job.meta.get( - RQJobMetaField.FORMATTED_EXCEPTION, - parse_exception_message(str(rq_job.exc_info or "Unknown error")), - ) + message = self._base_rq_job_meta.formatted_exception or parse_exception_message(str(rq_job.exc_info or "Unknown error")) return message def to_representation(self, rq_job: RQJob) -> dict[str, Any]: + self._base_rq_job_meta = BaseRQMeta.for_job(rq_job) representation = super().to_representation(rq_job) # FUTURE-TODO: support such statuses on UI @@ -3594,8 +3605,8 @@ def to_representation(self, rq_job: RQJob) -> dict[str, Any]: representation["status"] = RQJobStatus.QUEUED if representation["status"] == RQJobStatus.FINISHED: - if result_url := rq_job.meta.get(RQJobMetaField.RESULT_URL): - representation["result_url"] = result_url + if rq_job.parsed_rq_id.action == models.RequestAction.EXPORT: + representation["result_url"] = ExportRQMeta.for_job(rq_job).result_url if ( rq_job.parsed_rq_id.action == models.RequestAction.IMPORT diff --git a/cvat/apps/engine/task.py b/cvat/apps/engine/task.py index a7766274397b..ca6819304626 100644 --- a/cvat/apps/engine/task.py +++ b/cvat/apps/engine/task.py @@ -9,12 +9,12 @@ import os import re import shutil -from collections.abc import Iterator, Sequence +from collections.abc import Iterable, Iterator, Sequence from contextlib import closing from copy import deepcopy from datetime import datetime, 
timezone from pathlib import Path -from typing import Any, NamedTuple, Optional, Union +from typing import Any, Callable, NamedTuple, Optional, Union from urllib import parse as urlparse from urllib import request as urlrequest @@ -47,17 +47,10 @@ ) from cvat.apps.engine.model_utils import bulk_create from cvat.apps.engine.models import RequestAction, RequestTarget -from cvat.apps.engine.rq_job_handler import RQId +from cvat.apps.engine.rq import ImportRQMeta, RQId, define_dependent_job from cvat.apps.engine.task_validation import HoneypotFrameSelector from cvat.apps.engine.types import ExtendedRequest -from cvat.apps.engine.utils import ( - av_scan_paths, - define_dependent_job, - format_list, - get_rq_job_meta, - get_rq_lock_by_user, - take_by, -) +from cvat.apps.engine.utils import av_scan_paths, format_list, get_rq_lock_by_user, take_by from cvat.utils.http import PROXIES_FOR_UNTRUSTED_URLS, make_requests_session from utils.dataset_manifest import ImageManifestManager, VideoManifestManager, is_manifest from utils.dataset_manifest.core import VideoManifestValidator, is_dataset_manifest @@ -84,7 +77,7 @@ def create( func=_create_thread, args=(db_task.pk, data), job_id=rq_id, - meta=get_rq_job_meta(request=request, db_obj=db_task), + meta=ImportRQMeta.build_for(request=request, db_obj=db_task), depends_on=define_dependent_job(q, user_id), failure_ttl=settings.IMPORT_CACHE_FAILED_TTL.total_seconds(), ) @@ -108,13 +101,13 @@ class SegmentsParams(NamedTuple): def _copy_data_from_share_point( server_files: list[str], + *, + update_status_callback: Callable[[str], None], upload_dir: str, - server_dir: Optional[str] = None, - server_files_exclude: Optional[list[str]] = None, + server_dir: str | None = None, + server_files_exclude: list[str] | None = None, ): - job = rq.get_current_job() - job.meta['status'] = 'Data are being copied from source..' 
- job.save_meta() + update_status_callback('Data are being copied from source..') filtered_server_files = server_files.copy() @@ -199,11 +192,10 @@ def _segments(): def _create_segments_and_jobs( db_task: models.Task, *, + update_status_callback: Callable[[str], None], job_file_mapping: Optional[JobFileMapping] = None, ): - rq_job = rq.get_current_job() - rq_job.meta['status'] = 'Task is being saved in database' - rq_job.save_meta() + update_status_callback('Task is being saved in database') segments, segment_size, overlap = _generate_segment_params( db_task=db_task, job_file_mapping=job_file_mapping, @@ -442,8 +434,12 @@ def _validate_scheme(url): if parsed_url.scheme not in ALLOWED_SCHEMES: raise ValueError('Unsupported URL scheme: {}. Only http and https are supported'.format(parsed_url.scheme)) -def _download_data(urls, upload_dir): - job = rq.get_current_job() +def _download_data( + urls: Iterable[str], + upload_dir: str, + *, + update_status_callback: Callable[[str], None], +): local_files = {} with make_requests_session() as session: @@ -453,8 +449,7 @@ def _download_data(urls, upload_dir): raise Exception("filename collision: {}".format(name)) _validate_scheme(url) slogger.glob.info("Downloading: {}".format(url)) - job.meta['status'] = '{} is being downloaded..'.format(url) - job.save_meta() + update_status_callback('{} is being downloaded..'.format(url)) response = session.get(url, stream=True, proxies=PROXIES_FOR_UNTRUSTED_URLS) if response.status_code == 200: @@ -592,10 +587,11 @@ def _create_thread( slogger.glob.info("create task #{}".format(db_task.id)) job = rq.get_current_job() + rq_job_meta = ImportRQMeta.for_job(job) - def _update_status(msg: str) -> None: - job.meta['status'] = msg - job.save_meta() + def update_status(msg: str) -> None: + rq_job_meta.status = msg + rq_job_meta.save() job_file_mapping = _validate_job_file_mapping(db_task, data) @@ -608,7 +604,7 @@ def _update_status(msg: str) -> None: is_data_in_cloud = db_data.storage == 
models.StorageChoice.CLOUD_STORAGE if data['remote_files'] and not is_dataset_import: - data['remote_files'] = _download_data(data['remote_files'], upload_dir) + data['remote_files'] = _download_data(data['remote_files'], upload_dir, update_status_callback=update_status) # find and validate manifest file manifest_files = _find_manifest_files(data) @@ -752,7 +748,7 @@ def _update_status(msg: str) -> None: # Packed media must be downloaded for task creation any(v for k, v in media.items() if k != 'image') ): - _update_status("Downloading input media") + update_status("Downloading input media") filtered_data = [] for files in (i for i in media.values() if i): @@ -786,7 +782,11 @@ def _update_status(msg: str) -> None: # this means that the data has not been downloaded from the storage to the host _copy_data_from_share_point( (data['server_files'] + [manifest_file]) if manifest_file else data['server_files'], - upload_dir, data.get('server_files_path'), data.get('server_files_exclude')) + upload_dir=upload_dir, + server_dir=data.get('server_files_path'), + server_files_exclude=data.get('server_files_exclude'), + update_status_callback=update_status, + ) manifest_root = upload_dir elif is_data_in_cloud: # we should sort media before sorting in the extractor because the manifest structure should match to the sorted media @@ -808,8 +808,7 @@ def _update_status(msg: str) -> None: av_scan_paths(upload_dir) - job.meta['status'] = 'Media files are being extracted...' 
- job.save_meta() + update_status('Media files are being extracted...') # If upload from server_files image and directories # need to update images list by all found images in directories @@ -1033,7 +1032,7 @@ def _update_status(msg: str) -> None: if task_mode == MEDIA_TYPES['video']['mode']: if manifest_file: try: - _update_status('Validating the input manifest file') + update_status('Validating the input manifest file') manifest = VideoManifestValidator( source_path=os.path.join(upload_dir, media_files[0]), @@ -1055,13 +1054,13 @@ def _update_status(msg: str) -> None: base_msg = "Failed to parse the uploaded manifest file" slogger.glob.warning(ex, exc_info=True) - _update_status(base_msg) + update_status(base_msg) else: manifest = None if not manifest: try: - _update_status('Preparing a manifest file') + update_status('Preparing a manifest file') # TODO: maybe generate manifest in a temp directory manifest = VideoManifestManager(db_data.get_manifest_path()) @@ -1073,7 +1072,7 @@ def _update_status(msg: str) -> None: ) manifest.create() - _update_status('A manifest has been created') + update_status('A manifest has been created') except Exception as ex: manifest.remove() @@ -1085,7 +1084,7 @@ def _update_status(msg: str) -> None: base_msg = "" slogger.glob.warning(ex, exc_info=True) - _update_status( + update_status( f"Failed to create manifest for the uploaded video{base_msg}. 
" "A manifest will not be used in this task" ) @@ -1401,7 +1400,7 @@ def _update_status(msg: str) -> None: slogger.glob.info("Found frames {} for Data #{}".format(db_data.size, db_data.id)) - _create_segments_and_jobs(db_task, job_file_mapping=job_file_mapping) + _create_segments_and_jobs(db_task, job_file_mapping=job_file_mapping, update_status_callback=update_status) if validation_params and validation_params['mode'] == models.ValidationMode.GT: # The RNG backend must not change to yield reproducible frame picks, @@ -1551,9 +1550,10 @@ def update_progress(self, progress: float): status_message, progress_animation[self._call_counter] ) - self._rq_job.meta['status'] = status_message - self._rq_job.meta['task_progress'] = progress or 0. - self._rq_job.save_meta() + rq_job_meta = ImportRQMeta.for_job(self._rq_job) + rq_job_meta.status = status_message + rq_job_meta.task_progress = progress or 0. + rq_job_meta.save() self._call_counter = (self._call_counter + 1) % len(progress_animation) diff --git a/cvat/apps/engine/utils.py b/cvat/apps/engine/utils.py index 27044bb37efe..66d691ba7731 100644 --- a/cvat/apps/engine/utils.py +++ b/cvat/apps/engine/utils.py @@ -28,23 +28,18 @@ from datumaro.util.os_util import walk from django.conf import settings from django.core.exceptions import ValidationError -from django.utils import timezone from django.utils.http import urlencode from django_rq.queues import DjangoRQ from django_sendfile import sendfile as _sendfile from PIL import Image from redis.lock import Lock from rest_framework.reverse import reverse as _reverse -from rq.job import Dependency as RQDependency from rq.job import Job as RQJob -from rq.registry import BaseRegistry as RQBaseRegistry from cvat.apps.engine.types import ExtendedRequest Import = namedtuple("Import", ["module", "name", "alias"]) -KEY_TO_EXCLUDE_FROM_DEPENDENCY = 'exclude_from_dependency' - def parse_imports(source_code: str): root = ast.parse(source_code) @@ -161,57 +156,6 @@ def 
process_failed_job(rq_job: RQJob): return msg -def define_dependent_job( - queue: DjangoRQ, - user_id: int, - should_be_dependent: bool = settings.ONE_RUNNING_JOB_IN_QUEUE_PER_USER, - *, - rq_id: str | None = None, -) -> RQDependency | None: - if not should_be_dependent: - return None - - queues: list[RQBaseRegistry | DjangoRQ] = [queue.deferred_job_registry, queue, queue.started_job_registry] - # Since there is no cleanup implementation in DeferredJobRegistry, - # this registry can contain "outdated" jobs that weren't deleted from it - # but were added to another registry. Probably such situations can occur - # if there are active or deferred jobs when restarting the worker container. - filters = [lambda job: job.is_deferred, lambda _: True, lambda _: True] - all_user_jobs: list[RQJob] = [] - for q, f in zip(queues, filters): - job_ids = q.get_job_ids() - jobs = q.job_class.fetch_many(job_ids, q.connection) - jobs = filter(lambda job: job and job.meta.get("user", {}).get("id") == user_id and f(job), jobs) - all_user_jobs.extend(jobs) - - if rq_id: - # Prevent cases where an RQ job depends on itself. - # It isn't possible to have multiple RQ jobs with the same ID in Redis. - # However, a race condition in request processing can lead to self-dependencies - # when 2 parallel requests attempt to enqueue RQ jobs with the same ID. - # This happens if an rq_job is fetched without a lock, - # but a lock is used when defining the dependent job and enqueuing a new one. 
- if any(rq_id == job.id for job in all_user_jobs): - return None - - # prevent possible cyclic dependencies - all_job_dependency_ids = { - dep_id.decode() - for job in all_user_jobs - for dep_id in job.dependency_ids or () - } - - if RQJob.redis_job_namespace_prefix + rq_id in all_job_dependency_ids: - return None - - user_jobs = [ - job for job in all_user_jobs - if not job.meta.get(KEY_TO_EXCLUDE_FROM_DEPENDENCY) - ] - - return RQDependency(jobs=[sorted(user_jobs, key=lambda job: job.created_at)[-1]], allow_failure=True) if user_jobs else None - - def get_rq_lock_by_user(queue: DjangoRQ, user_id: int, *, timeout: Optional[int] = 30, blocking_timeout: Optional[int] = None) -> Union[Lock, nullcontext]: if settings.ONE_RUNNING_JOB_IN_QUEUE_PER_USER: return queue.connection.lock( @@ -232,45 +176,6 @@ def get_rq_lock_for_job(queue: DjangoRQ, rq_id: str, *, timeout: int = 60, block blocking_timeout=blocking_timeout, ) -def get_rq_job_meta( - request: ExtendedRequest, - db_obj: Any, - *, - result_url: Optional[str] = None, -): - # to prevent circular import - from cvat.apps.events.handlers import job_id, organization_slug, task_id - from cvat.apps.webhooks.signals import organization_id, project_id - - oid = organization_id(db_obj) - oslug = organization_slug(db_obj) - pid = project_id(db_obj) - tid = task_id(db_obj) - jid = job_id(db_obj) - - meta = { - 'user': { - 'id': getattr(request.user, "id", None), - 'username': getattr(request.user, "username", None), - 'email': getattr(request.user, "email", None), - }, - 'request': { - "uuid": request.uuid, - "timestamp": timezone.localtime(), - }, - 'org_id': oid, - 'org_slug': oslug, - 'project_id': pid, - 'task_id': tid, - 'job_id': jid, - } - - - if result_url: - meta['result_url'] = result_url - - return meta - def reverse(viewname, *, args=None, kwargs=None, query_params: Optional[dict[str, str]] = None, request: ExtendedRequest | None = None, diff --git a/cvat/apps/engine/views.py b/cvat/apps/engine/views.py index 
7461a4cab61c..10883a277cff 100644 --- a/cvat/apps/engine/views.py +++ b/cvat/apps/engine/views.py @@ -122,7 +122,13 @@ get_cloud_storage_for_import_or_export, get_iam_context, ) -from cvat.apps.engine.rq_job_handler import RQId, RQJobMetaField, is_rq_job_owner +from cvat.apps.engine.rq import ( + ImportRQMeta, + RQId, + RQMetaWithFailureInfo, + define_dependent_job, + is_rq_job_owner, +) from cvat.apps.engine.serializers import ( AboutSerializer, AnnotationFileSerializer, @@ -167,8 +173,6 @@ from cvat.apps.engine.types import ExtendedRequest from cvat.apps.engine.utils import ( av_scan_paths, - define_dependent_job, - get_rq_job_meta, get_rq_lock_by_user, get_rq_lock_for_job, import_resource_with_clean_up_after, @@ -731,6 +735,7 @@ def preview(self, request: ExtendedRequest, pk: int): def _get_rq_response(queue, job_id): queue = django_rq.get_queue(queue) job = queue.fetch_job(job_id) + rq_job_meta = ImportRQMeta.for_job(job) if job else None response = {} if job is None or job.is_finished: response = { "state": "Finished" } @@ -740,8 +745,8 @@ response = { "state": "Failed", "message": job.exc_info } else: response = { "state": "Started" } - response['message'] = job.meta.get('status', '') - response['progress'] = job.meta.get('progress', 0.) + response['message'] = rq_job_meta.status or "" + response['progress'] = rq_job_meta.progress or 0.
return response @@ -1695,6 +1700,7 @@ def status(self, request, pk): def _get_rq_response(queue, job_id): queue = django_rq.get_queue(queue) job = queue.fetch_job(job_id) + rq_job_meta = ImportRQMeta.for_job(job) if job else None response = {} if job is None or job.is_finished: response = { "state": "Finished" } @@ -1708,9 +1714,9 @@ response = { "state": "Failed", "message": parse_exception_message(job.exc_info or "Unknown error") } else: response = { "state": "Started" } - if job.meta.get('status'): - response['message'] = job.meta['status'] - response['progress'] = job.meta.get('task_progress', 0.) + if rq_job_meta.status: + response['message'] = rq_job_meta.status + response['progress'] = rq_job_meta.task_progress or 0. return response @@ -3374,13 +3380,14 @@ def perform_destroy(self, instance): super().perform_destroy(instance) target.touch() -def rq_exception_handler(rq_job, exc_type, exc_value, tb): - rq_job.meta[RQJobMetaField.FORMATTED_EXCEPTION] = "".join( +def rq_exception_handler(rq_job: RQJob, exc_type: type[Exception], exc_value: Exception, tb): + rq_job_meta = RQMetaWithFailureInfo.for_job(rq_job) + rq_job_meta.formatted_exception = "".join( traceback.format_exception_only(exc_type, exc_value)) if rq_job.origin == settings.CVAT_QUEUES.CHUNKS.value: - rq_job.meta[RQJobMetaField.EXCEPTION_TYPE] = exc_type - rq_job.meta[RQJobMetaField.EXCEPTION_ARGS] = exc_value.args - rq_job.save_meta() + rq_job_meta.exc_type = exc_type + rq_job_meta.exc_args = exc_value.args + rq_job_meta.save() return True @@ -3475,15 +3482,13 @@ def _import_annotations( user_id = request.user.id with get_rq_lock_by_user(queue, user_id): + meta = ImportRQMeta.build_for(request=request, db_obj=db_obj, tmp_file=filename) queue.enqueue_call( func=func, args=func_args, job_id=rq_id, depends_on=define_dependent_job(queue, user_id, rq_id=rq_id), - meta={ - 'tmp_file': filename, - **get_rq_job_meta(request=request, db_obj=db_obj), - }, + meta=meta,
result_ttl=settings.IMPORT_CACHE_SUCCESS_TTL.total_seconds(), failure_ttl=settings.IMPORT_CACHE_FAILED_TTL.total_seconds() ) @@ -3602,14 +3607,12 @@ def _import_project_dataset( user_id = request.user.id with get_rq_lock_by_user(queue, user_id): + meta = ImportRQMeta.build_for(request=request, db_obj=db_obj, tmp_file=filename) queue.enqueue_call( func=func, args=func_args, job_id=rq_id, - meta={ - 'tmp_file': filename, - **get_rq_job_meta(request=request, db_obj=db_obj), - }, + meta=meta, depends_on=define_dependent_job(queue, user_id, rq_id=rq_id), result_ttl=settings.IMPORT_CACHE_SUCCESS_TTL.total_seconds(), failure_ttl=settings.IMPORT_CACHE_FAILED_TTL.total_seconds() diff --git a/cvat/apps/events/export.py b/cvat/apps/events/export.py index e3f66a7f5740..b7500571ceb9 100644 --- a/cvat/apps/events/export.py +++ b/cvat/apps/events/export.py @@ -17,7 +17,7 @@ from cvat.apps.dataset_manager.views import log_exception from cvat.apps.engine.log import ServerLogManager -from cvat.apps.engine.rq_job_handler import RQJobMetaField +from cvat.apps.engine.rq import RQMetaWithFailureInfo from cvat.apps.engine.utils import sendfile slogger = ServerLogManager(__name__) @@ -136,7 +136,7 @@ def export(request, filter_query, queue_name): "query_id": query_id, } - queue = django_rq.get_queue(queue_name) + queue: django_rq.queues.DjangoRQ = django_rq.get_queue(queue_name) rq_job = queue.fetch_job(rq_id) if rq_job: @@ -152,7 +152,8 @@ def export(request, filter_query, queue_name): if os.path.exists(file_path): return Response(status=status.HTTP_201_CREATED) elif rq_job.is_failed: - exc_info = rq_job.meta.get(RQJobMetaField.FORMATTED_EXCEPTION, str(rq_job.exc_info)) + rq_job_meta = RQMetaWithFailureInfo.for_job(rq_job) + exc_info = rq_job_meta.formatted_exception or str(rq_job.exc_info) rq_job.delete() return Response(exc_info, status=status.HTTP_500_INTERNAL_SERVER_ERROR) else: diff --git a/cvat/apps/events/handlers.py b/cvat/apps/events/handlers.py index 7205dfdb8185..b5c3a532ada7 
100644 --- a/cvat/apps/events/handlers.py +++ b/cvat/apps/events/handlers.py @@ -22,7 +22,7 @@ Task, User, ) -from cvat.apps.engine.rq_job_handler import RQJobMetaField +from cvat.apps.engine.rq import BaseRQMeta from cvat.apps.engine.serializers import ( BasicUserSerializer, CloudStorageReadSerializer, @@ -97,7 +97,14 @@ def job_id(instance): return None -def get_user(instance=None): +def get_user(instance=None) -> User | dict | None: + def _get_user_from_rq_job(rq_job: rq.job.Job) -> dict | None: + # RQ jobs created in the chunks queue have no user info + try: + return BaseRQMeta.for_job(rq_job).user.to_dict() + except TypeError: + return None + # Try to get current user from request user = get_current_user() if user is not None: @@ -105,11 +112,11 @@ def get_user(instance=None): # Try to get user from rq_job if isinstance(instance, rq.job.Job): - return instance.meta.get(RQJobMetaField.USER, None) + return _get_user_from_rq_job(instance) else: rq_job = rq.get_current_job() if rq_job: - return rq_job.meta.get(RQJobMetaField.USER, None) + return _get_user_from_rq_job(rq_job) if isinstance(instance, User): return instance @@ -118,16 +125,23 @@ def get_user(instance=None): def get_request(instance=None): + def _get_request_from_rq_job(rq_job: rq.job.Job) -> dict | None: + # RQ jobs created in the chunks queue have no request info + try: + return BaseRQMeta.for_job(rq_job).request.to_dict() + except TypeError: + return None + request = get_current_request() if request is not None: return request if isinstance(instance, rq.job.Job): - return instance.meta.get(RQJobMetaField.REQUEST, None) + return _get_request_from_rq_job(instance) else: rq_job = rq.get_current_job() if rq_job: - return rq_job.meta.get(RQJobMetaField.REQUEST, None) + return _get_request_from_rq_job(rq_job) return None @@ -569,11 +583,12 @@ def handle_function_call( def handle_rq_exception(rq_job, exc_type, exc_value, tb): - oid = rq_job.meta.get(RQJobMetaField.ORG_ID, None) - oslug = 
rq_job.meta.get(RQJobMetaField.ORG_SLUG, None) - pid = rq_job.meta.get(RQJobMetaField.PROJECT_ID, None) - tid = rq_job.meta.get(RQJobMetaField.TASK_ID, None) - jid = rq_job.meta.get(RQJobMetaField.JOB_ID, None) + rq_job_meta = BaseRQMeta.for_job(rq_job) + oid = rq_job_meta.org_id + oslug = rq_job_meta.org_slug + pid = rq_job_meta.project_id + tid = rq_job_meta.task_id + jid = rq_job_meta.job_id uid = user_id(rq_job) uname = user_name(rq_job) uemail = user_email(rq_job) diff --git a/cvat/apps/lambda_manager/rq.py b/cvat/apps/lambda_manager/rq.py new file mode 100644 index 000000000000..2a3e19f06f1b --- /dev/null +++ b/cvat/apps/lambda_manager/rq.py @@ -0,0 +1,46 @@ +# Copyright (C) CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +from django.db.models import Model + +from cvat.apps.engine.rq import ( + BaseRQMeta, + ImmutableRQMetaAttribute, + MutableRQMetaAttribute, + RQJobMetaField, +) +from cvat.apps.engine.types import ExtendedRequest + + +class LambdaRQMeta(BaseRQMeta): + # immutable fields + function_id: str = ImmutableRQMetaAttribute( + RQJobMetaField.FUNCTION_ID, validator=lambda x: isinstance(x, str) + ) + lambda_: bool = ImmutableRQMetaAttribute( + RQJobMetaField.LAMBDA, validator=lambda x: isinstance(x, bool) + ) + # FUTURE-FIXME: progress should be in [0, 1] range + progress: int | None = MutableRQMetaAttribute( + RQJobMetaField.PROGRESS, validator=lambda x: isinstance(x, int), optional=True + ) + + @classmethod + def build_for( + cls, + *, + request: ExtendedRequest, + db_obj: Model, + function_id: str, + ): + base_meta = BaseRQMeta.build(request=request, db_obj=db_obj) + return cls.for_meta( + { + **base_meta, + RQJobMetaField.FUNCTION_ID: function_id, + RQJobMetaField.LAMBDA: True, + } + ).to_dict() diff --git a/cvat/apps/lambda_manager/views.py b/cvat/apps/lambda_manager/views.py index c682dc76af80..f7a4ed11a4ae 100644 --- a/cvat/apps/lambda_manager/views.py +++ b/cvat/apps/lambda_manager/views.py @@ 
-45,19 +45,15 @@ SourceType, Task, ) -from cvat.apps.engine.rq_job_handler import RQId, RQJobMetaField +from cvat.apps.engine.rq import RQId, define_dependent_job from cvat.apps.engine.serializers import LabeledDataSerializer from cvat.apps.engine.types import ExtendedRequest -from cvat.apps.engine.utils import ( - define_dependent_job, - get_rq_job_meta, - get_rq_lock_by_user, - get_rq_lock_for_job, -) +from cvat.apps.engine.utils import get_rq_lock_by_user, get_rq_lock_for_job from cvat.apps.events.handlers import handle_function_call from cvat.apps.iam.filters import ORGANIZATION_OPEN_API_PARAMETERS from cvat.apps.lambda_manager.models import FunctionKind from cvat.apps.lambda_manager.permissions import LambdaPermission +from cvat.apps.lambda_manager.rq import LambdaRQMeta from cvat.apps.lambda_manager.serializers import ( FunctionCallRequestSerializer, FunctionCallSerializer, @@ -595,7 +591,7 @@ def get_jobs(self): ) jobs = queue.job_class.fetch_many(job_ids, queue.connection) - return [LambdaJob(job) for job in jobs if job and job.meta.get("lambda")] + return [LambdaJob(job) for job in jobs if job and LambdaRQMeta.for_job(job).lambda_] def enqueue( self, @@ -635,17 +631,15 @@ def enqueue( user_id = request.user.id with get_rq_lock_by_user(queue, user_id): + meta = LambdaRQMeta.build_for( + request=request, + db_obj=Job.objects.get(pk=job) if job else Task.objects.get(pk=task), + function_id=lambda_func.id, + ) rq_job = queue.create_job( LambdaJob(None), job_id=rq_id, - meta={ - **get_rq_job_meta( - request, - db_obj=(Job.objects.get(pk=job) if job else Task.objects.get(pk=task)), - ), - RQJobMetaField.FUNCTION_ID: lambda_func.id, - "lambda": True, - }, + meta=meta, kwargs={ "function": lambda_func, "threshold": threshold, @@ -668,7 +662,7 @@ def enqueue( def fetch_job(self, pk): queue = self._get_queue() rq_job = queue.fetch_job(pk) - if rq_job is None or not rq_job.meta.get("lambda"): + if rq_job is None or not LambdaRQMeta.for_job(rq_job).lambda_: raise 
ValidationError( "{} lambda job is not found".format(pk), code=status.HTTP_404_NOT_FOUND ) @@ -697,7 +691,7 @@ def to_dict(self): ), }, "status": self.job.get_status(), - "progress": self.job.meta.get("progress", 0), + "progress": LambdaRQMeta.for_job(self.job).progress or 0, "enqueued": self.job.enqueued_at, "started": self.job.started_at, "ended": self.job.ended_at, @@ -911,10 +905,11 @@ def _map(sublabel_body): # progress is in [0, 1] range def _update_progress(progress): job = rq.get_current_job() + rq_job_meta = LambdaRQMeta.for_job(job) # If the job has been deleted, get_status will return None. Thus it will # exist the loop. - job.meta["progress"] = int(progress * 100) - job.save_meta() + rq_job_meta.progress = int(progress * 100) + rq_job_meta.save() return job.get_status() diff --git a/cvat/apps/quality_control/quality_reports.py b/cvat/apps/quality_control/quality_reports.py index 8be628a5bedb..126c5758425a 100644 --- a/cvat/apps/quality_control/quality_reports.py +++ b/cvat/apps/quality_control/quality_reports.py @@ -52,13 +52,9 @@ User, ValidationMode, ) +from cvat.apps.engine.rq import BaseRQMeta, define_dependent_job from cvat.apps.engine.types import ExtendedRequest -from cvat.apps.engine.utils import ( - define_dependent_job, - get_rq_job_meta, - get_rq_lock_by_user, - get_rq_lock_for_job, -) +from cvat.apps.engine.utils import get_rq_lock_by_user, get_rq_lock_for_job from cvat.apps.profiler import silk_profile from cvat.apps.quality_control import models from cvat.apps.quality_control.models import (
a/cvat/apps/quality_control/views.py +++ b/cvat/apps/quality_control/views.py @@ -21,7 +21,7 @@ from cvat.apps.engine.mixins import PartialUpdateModelMixin from cvat.apps.engine.models import Task -from cvat.apps.engine.rq_job_handler import RQJobMetaField +from cvat.apps.engine.rq import BaseRQMeta from cvat.apps.engine.serializers import RqIdSerializer from cvat.apps.engine.utils import get_server_url from cvat.apps.quality_control import quality_reports as qc @@ -294,7 +294,7 @@ def create(self, request, *args, **kwargs): if ( not rq_job or not QualityReportPermission.create_scope_check_status( - request, rq_job_owner_id=rq_job.meta[RQJobMetaField.USER]["id"] + request, rq_job_owner_id=BaseRQMeta.for_job(rq_job).user.id ) .check_access() .allow diff --git a/pyproject.toml b/pyproject.toml index 5d6f8b32ae05..75cb40c23f43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,6 @@ extend-exclude = """ |permissions.py |plugins.py |renderers.py - |rq_job_handler.py |schema.py |serializers.py |signals.py