diff --git a/.env b/.env index 9e80c734cd6..71fe3b0c430 100644 --- a/.env +++ b/.env @@ -43,6 +43,9 @@ ASSET_SERVER_URL=http://host.docker.internal/web_asset_store.xml # Make sure to set the `ASSET_SERVER_KEY` to a unique value ASSET_SERVER_KEY=your_asset_server_access_key +# Information to connect to a Redis database +# Specify will use this database as a process broker and storage for temporary +# values REDIS_HOST=redis REDIS_PORT=6379 REDIS_DB_INDEX=0 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a7c74a536f6..8565522320d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -73,6 +73,11 @@ jobs: options: --health-cmd="mariadb-admin ping" --health-interval=5s --health-timeout=2s --health-retries=3 + redis: + image: redis:latest + ports: + - 6379 + steps: - uses: actions/checkout@v4 @@ -128,6 +133,8 @@ jobs: echo "MIGRATOR_PASSWORD = 'MasterPassword'" >> specifyweb/settings/local_specify_settings.py echo "APP_USER_NAME = 'MasterUser'" >> specifyweb/settings/local_specify_settings.py echo "APP_USER_PASSWORD = 'MasterPassword'" >> specifyweb/settings/local_specify_settings.py + echo "REDIS_HOST = '127.0.0.1'" >> specifyweb/settings/local_specify_settings.py + echo "REDIS_PORT = ${{ job.services.redis.ports[6379] }}" >> specifyweb/settings/local_specify_settings.py - name: Need these files to be present run: diff --git a/specifyweb/backend/businessrules/rules/attachment_rules.py b/specifyweb/backend/businessrules/rules/attachment_rules.py index 991fc0ac704..fc52e109979 100644 --- a/specifyweb/backend/businessrules/rules/attachment_rules.py +++ b/specifyweb/backend/businessrules/rules/attachment_rules.py @@ -30,8 +30,8 @@ def attachment_jointable_save(sender, obj): attachee = get_attachee(obj) obj.attachment.tableid = attachee.specify_model.tableId - scopetype, scope = Scoping(attachee)() - obj.attachment.scopetype, obj.attachment.scopeid = scopetype, scope.id + scopetype, scope = Scoping.from_instance(attachee) + 
from redis import Redis
from django.conf import settings


class RedisConnection:
    """
    Thin wrapper around a redis-py client configured from Django settings.

    Connection parameters default to the ``REDIS_HOST`` / ``REDIS_PORT`` /
    ``REDIS_DB_INDEX`` settings; any of them can be overridden per instance.

    :raises ValueError: if host, port, or db index cannot be resolved.
    """

    def __init__(self, host=None, port=None, db_index=None,
                 decode_responses=True):
        # Resolve settings lazily (at construction time) instead of in the
        # argument defaults, which would freeze the values at import time and
        # ignore e.g. test-time settings overrides.
        if host is None:
            host = getattr(settings, "REDIS_HOST", None)
        if port is None:
            port = getattr(settings, "REDIS_PORT", None)
        if db_index is None:
            db_index = getattr(settings, "REDIS_DB_INDEX", 0)
        if None in (host, port, db_index):
            raise ValueError(
                "Redis is not correctly configured", host, port, db_index)
        self.host = host
        self.port = port
        self.db_index = db_index
        self.decode_responses = decode_responses
        self.connection = Redis(
            host=self.host,
            port=self.port,
            db=self.db_index,
            decode_responses=self.decode_responses
        )

    def delete(self, key: str):
        """Delete ``key``; returns the number of keys removed."""
        return self.connection.delete(key)


class RedisDataType:
    """Base class for the typed wrappers below; delegates to a RedisConnection."""

    def __init__(self, established: RedisConnection) -> None:
        self._established = established

    @property
    def connection(self):
        # The underlying redis-py client.
        return self._established.connection

    def delete(self, key: str):
        return self._established.delete(key)


class RedisList(RedisDataType):
    """
    See https://redis.io/docs/latest/develop/data-types/lists/
    """

    def left_push(self, key: str, *values) -> int:
        # Generalized to accept several values per call (LPUSH is variadic);
        # single-value callers are unaffected. Returns the resulting length.
        return self.connection.lpush(key, *values)

    def right_push(self, key: str, *values) -> int:
        # Variadic like RPUSH itself; returns the resulting list length.
        return self.connection.rpush(key, *values)

    def right_pop(self, key: str) -> str | bytes | None:
        return self.connection.rpop(key)

    def left_pop(self, key: str) -> str | bytes | None:
        return self.connection.lpop(key)

    def length(self, key: str) -> int:
        return self.connection.llen(key)

    def range(self, key: str, start_index: int, end_index: int) -> list[str] | list[bytes]:
        return self.connection.lrange(key, start_index, end_index)

    def trim(self, key: str, start_index: int, end_index: int) -> bool:
        # LTRIM returns a status reply, not the trimmed list.
        return self.connection.ltrim(key, start_index, end_index)

    def blocking_left_pop(self, key: str, timeout: int) -> str | bytes | None:
        """BLPOP: wait up to ``timeout`` seconds; None when nothing arrived."""
        response = self.connection.blpop(key, timeout=timeout)
        if response is None:
            return None
        _filled_list_key, item = response
        return item


class RedisSet(RedisDataType):
    """
    See https://redis.io/docs/latest/develop/data-types/sets/
    """

    def add(self, key: str, *values: str) -> int:
        return self.connection.sadd(key, *values)

    def is_member(self, key: str, value: str) -> bool:
        # SISMEMBER replies 0/1; coerce directly instead of int(...) == 1.
        return bool(self.connection.sismember(key, value))

    def remove(self, key: str, value: str):
        return self.connection.srem(key, value)

    def size(self, key: str) -> int:
        return self.connection.scard(key)

    def members(self, key: str) -> set[str]:
        return self.connection.smembers(key)

    def union(self, *keys: str) -> set[str]:
        return self.connection.sunion(*keys)

    def intersection(self, *keys: str) -> set[str]:
        return self.connection.sinter(*keys)

    def difference(self, *keys: str) -> set[str]:
        return self.connection.sdiff(*keys)


class RedisString(RedisDataType):
    """
    See https://redis.io/docs/latest/develop/data-types/strings/
    """

    def set(self, key, value, time_to_live=None, override_existing=True):
        """SET with optional expiry (seconds, ``ex``) and NX semantics."""
        flags = {
            "ex": time_to_live,
            # nx=True makes SET a no-op when the key already exists.
            "nx": not override_existing
        }
        self.connection.set(key, value, **flags)

    def get(self, key, delete_key=False) -> str | bytes | None:
        if delete_key:
            # GETDEL: atomic read-and-remove.
            return self.connection.getdel(key)
        return self.connection.get(key)
specifyweb.backend.redis_cache.connect import RedisConnection, RedisList, RedisSet + +type Serialized = str | bytes | bytearray + +type Serializer[T] = Callable[[T], str] +type Deserializer[T] = Callable[[Serialized], T] + +def default_serializer(obj) -> str: + return str(obj) + +def default_deserializer(serialized: Serialized): + return serialized + +class RedisQueue[T]: + def __init__(self, connection: RedisConnection, key: str, + serializer: Serializer[T] | None = None, + deserializer: Deserializer[T] | None = None): + self.connection = RedisList(connection) + self.key = key + self.serializer = serializer or cast(Serializer[T], default_serializer) + self.deserializer = deserializer or cast(Deserializer[T], default_deserializer) + + def key_name(self, *name_parts: str | None): + key_name = "_".join([self.key, *(part for part in name_parts if part is not None)]) + return key_name + + def push(self, *objs: T, sub_key: str | None = None) -> int: + key_name = self.key_name(sub_key) + return self.connection.right_push(key_name, *self._serialize_objs(*objs)) + + def pop(self, sub_key: str | None = None) -> T | None: + key_name = self.key_name(sub_key) + popped = self.connection.left_pop(key_name) + if popped is None: + return None + return self.deserializer(popped) + + def wait_and_pop(self, timeout: int = 0, sub_key: str | None = None) -> T: + key_name = self.key_name(sub_key) + popped = self.connection.blocking_left_pop(key_name, timeout) + if popped is None: + raise TimeoutError("No items in queue after timeout") + return self.deserializer(popped) + + def peek(self, sub_key: str | None = None) -> T | None: + key_name = self.key_name(sub_key) + top_value = self._deserialize_objs(*self.connection.range(key_name, 0, 0)) + if len(top_value) == 0: + return None + return top_value[0] + + def _serialize_objs(self, *objs: T) -> Generator[str, None, None]: + return (self.serializer(obj) for obj in objs) + + def _deserialize_objs(self, *serialized: Serialized): + return 
tuple(self.deserializer(obj) for obj in serialized) + diff --git a/specifyweb/backend/stored_queries/execution.py b/specifyweb/backend/stored_queries/execution.py index 837cd1b2741..5dbb4228052 100644 --- a/specifyweb/backend/stored_queries/execution.py +++ b/specifyweb/backend/stored_queries/execution.py @@ -51,7 +51,6 @@ class QuerySort: def by_id(sort_id: QUREYFIELD_SORT_T): return QuerySort.SORT_TYPES[sort_id] - def DefaultQueryFormatterProps(): return ObjectFormatterProps( format_agent_type=False, diff --git a/specifyweb/backend/workbench/upload/parse.py b/specifyweb/backend/workbench/upload/parse.py index 42a2d2e1931..d187f81e87d 100644 --- a/specifyweb/backend/workbench/upload/parse.py +++ b/specifyweb/backend/workbench/upload/parse.py @@ -44,7 +44,12 @@ class ParseSucess(NamedTuple): ParseResult = ParseSucess | ParseFailure -def parse_field(table_name: str, field_name: str, raw_value: str, formatter: ScopedFormatter | None = None) -> ParseResult: +def parse_field( + table_name: str, + field_name: str, + raw_value: str, + formatter: ScopedFormatter | None = None + ) -> ParseResult: table = datamodel.get_table_strict(table_name) field = table.get_field_strict(field_name) @@ -170,7 +175,11 @@ def parse_date(table: Table, field_name: str, dateformat: str, value: str) -> Pa return ParseFailure('badDateFormat', {'value': value, 'format': dateformat}) -def parse_formatted(uiformatter: ScopedFormatter, table: Table, field: Field | Relationship, value: str) -> ParseResult: +def parse_formatted( + uiformatter: ScopedFormatter, + table: Table, + field: Field | Relationship, + value: str) -> ParseResult: try: canonicalized = uiformatter(table, value) except FormatMismatch as e: diff --git a/specifyweb/backend/workbench/upload/scoping.py b/specifyweb/backend/workbench/upload/scoping.py index cb91f468c84..24b7e887b11 100644 --- a/specifyweb/backend/workbench/upload/scoping.py +++ b/specifyweb/backend/workbench/upload/scoping.py @@ -1,11 +1,12 @@ from functools import 
reduce -from typing import Any, cast +from typing import Any, Callable, cast from collections.abc import Callable from specifyweb.specify.datamodel import datamodel, Table, is_tree_table from specifyweb.specify.utils.func import CustomRepr +from specifyweb.specify.utils.autonumbering import AutonumberingLockDispatcher from specifyweb.specify.models_utils.load_datamodel import DoesNotExistError from specifyweb.specify import models from specifyweb.backend.trees.utils import get_default_treedef @@ -113,7 +114,8 @@ def extend_columnoptions( fieldname: str, row: Row | None = None, toOne: dict[str, Uploadable] | None = None, - context: ScopeContext | None = None + context: ScopeContext | None = None, + lock_dispatcher: Callable[[], AutonumberingLockDispatcher] | None = None ) -> ExtendedColumnOptions: context = context or ScopeContext() @@ -125,7 +127,7 @@ def extend_columnoptions( ui_formatter = get_or_defer_formatter(collection, tablename, fieldname, row, toOne, context) scoped_formatter = ( - None if ui_formatter is None else ui_formatter.apply_scope(collection) + None if ui_formatter is None else ui_formatter.apply_scope(collection, lock_dispatcher) ) if tablename.lower() == "collectionobjecttype" and fieldname.lower() == "name": @@ -256,7 +258,11 @@ def get_or_defer_formatter( def apply_scoping_to_uploadtable( - ut: UploadTable, collection, context: ScopeContext | None = None, row=None + ut: UploadTable, + collection, + context: ScopeContext | None = None, + row=None, + lock_dispatcher: Callable[[], AutonumberingLockDispatcher] | None = None ) -> ScopedUploadTable: # IMPORTANT: # before this comment, collection is untrusted and unreliable @@ -276,7 +282,7 @@ def apply_scoping_to_uploadtable( apply_scoping = lambda key, value: get_deferred_scoping( key, table.django_name, value, row, ut, context - ).apply_scoping(collection, context, row) + ).apply_scoping(collection, context, row, lock_dispatcher=lock_dispatcher) to_ones = { key: adjuster(apply_scoping(key, value), 
key) @@ -299,7 +305,14 @@ def _backref(key): scoped_table = ScopedUploadTable( name=ut.name, wbcols={ - f: extend_columnoptions(colopts, collection, table.name, f, row, ut.toOne, context) + f: extend_columnoptions(colopts, + collection, + table.name, + f, + row, + ut.toOne, + context, + lock_dispatcher=lock_dispatcher) for f, colopts in ut.wbcols.items() }, static=ut.static, @@ -347,7 +360,10 @@ def set_order_number( return tmr._replace(strong_ignore=[*tmr.strong_ignore, *to_ignore]) -def apply_scoping_to_treerecord(tr: TreeRecord, collection, context: ScopeContext | None = None) -> ScopedTreeRecord: +def apply_scoping_to_treerecord(tr: TreeRecord, + collection, + context: ScopeContext | None = None, + lock_dispatcher: Callable[[], AutonumberingLockDispatcher] | None = None) -> ScopedTreeRecord: table = datamodel.get_table_strict(tr.name) treedef = get_default_treedef(table, collection) @@ -376,7 +392,11 @@ def apply_scoping_to_treerecord(tr: TreeRecord, collection, context: ScopeContex else r._replace(treedef_id=treedef.id) if r.treedef_id is None # Adjust treeid for parsed JSON plans else r ): { - f: extend_columnoptions(colopts, collection, table.name, f) + f: extend_columnoptions(colopts, + collection, + table.name, + f, + lock_dispatcher=lock_dispatcher) for f, colopts in cols.items() } for r, cols in tr.ranks.items() diff --git a/specifyweb/backend/workbench/upload/treerecord.py b/specifyweb/backend/workbench/upload/treerecord.py index 81ea03542ce..ce82b1c3644 100644 --- a/specifyweb/backend/workbench/upload/treerecord.py +++ b/specifyweb/backend/workbench/upload/treerecord.py @@ -3,13 +3,14 @@ """ import logging -from typing import Any, NamedTuple, Union, Optional +from typing import Any, NamedTuple, Union, Optional, Callable from django.db import transaction, IntegrityError from typing_extensions import TypedDict from specifyweb.backend.businessrules.exceptions import BusinessRuleException from specifyweb.specify import models +from 
specifyweb.specify.utils.autonumbering import AutonumberingLockDispatcher from specifyweb.backend.workbench.upload.clone import clone_record from specifyweb.backend.workbench.upload.predicates import ( SPECIAL_TREE_FIELDS_TO_SKIP, @@ -149,11 +150,15 @@ class TreeRecord(NamedTuple): ranks: dict[str | TreeRankRecord, dict[str, ColumnOptions]] def apply_scoping( - self, collection, context: ScopeContext | None = None, row=None + self, + collection, + context: ScopeContext | None = None, + row=None, + lock_dispatcher: Callable[[], AutonumberingLockDispatcher] | None = None ) -> "ScopedTreeRecord": from .scoping import apply_scoping_to_treerecord as apply_scoping - return apply_scoping(self, collection, context) + return apply_scoping(self, collection, context, lock_dispatcher=lock_dispatcher) def get_cols(self) -> set[str]: return { @@ -491,9 +496,13 @@ def bind( class MustMatchTreeRecord(TreeRecord): def apply_scoping( - self, collection, context: ScopeContext | None = None, row=None + self, + collection, + context: ScopeContext | None = None, + row=None, + lock_dispatcher: Callable[[], AutonumberingLockDispatcher] | None = None ) -> "ScopedMustMatchTreeRecord": - s = super().apply_scoping(collection, context, row) + s = super().apply_scoping(collection, context, row, lock_dispatcher=lock_dispatcher) return ScopedMustMatchTreeRecord(*s) diff --git a/specifyweb/backend/workbench/upload/upload.py b/specifyweb/backend/workbench/upload/upload.py index a1998d196e1..6a502c1121f 100644 --- a/specifyweb/backend/workbench/upload/upload.py +++ b/specifyweb/backend/workbench/upload/upload.py @@ -5,13 +5,8 @@ from contextlib import contextmanager from datetime import datetime, timezone from typing import ( - List, - Dict, - Union, Callable, - Optional, Sized, - Tuple, ) from collections.abc import Callable from collections.abc import Sized @@ -34,6 +29,7 @@ BatchEditPrefs, ) from specifyweb.backend.trees.views import ALL_TREES +from specifyweb.specify.utils.autonumbering import 
AutonumberingLockDispatcher from . import disambiguation from .upload_plan_schema import schema, parse_plan_with_basetable @@ -352,7 +348,12 @@ def do_upload( _cache = cache.copy() if cache is not None and allow_partial else cache da = disambiguations[i] if disambiguations else None batch_edit_pack = batch_edit_packs[i] if batch_edit_packs else None - with savepoint("row upload") if allow_partial else no_savepoint(): + with ( + savepoint("row upload") if allow_partial else no_savepoint() as _, + AutonumberingLockDispatcher() as autonum_dispatcher + ): + get_lock_dispatcher = lambda: autonum_dispatcher + # the fact that upload plan is cachable, is invariant across rows. # so, we just apply scoping once. Honestly, see if it causes enough overhead to even warrant caching @@ -379,10 +380,9 @@ def do_upload( cache = _cache raise Rollback("failed row") row, row_upload_plan = add_attachments_to_plan(row, upload_plan) # type: ignore - scoped_table = row_upload_plan.apply_scoping(collection, scope_context, row) - + scoped_table = row_upload_plan.apply_scoping(collection, scope_context, row, lock_dispatcher=get_lock_dispatcher) elif cached_scope_table is None: - scoped_table = upload_plan.apply_scoping(collection, scope_context, row) + scoped_table = upload_plan.apply_scoping(collection, scope_context, row, lock_dispatcher=get_lock_dispatcher) if not scope_context.is_variable: # This forces every row to rescope when not variable cached_scope_table = scoped_table @@ -416,6 +416,8 @@ def do_upload( if result.contains_failure(): cache = _cache raise Rollback("failed row") + + autonum_dispatcher.commit_highest() toc = time.perf_counter() logger.info(f"finished upload of {len(results)} rows in {toc-tic}s") diff --git a/specifyweb/backend/workbench/upload/upload_table.py b/specifyweb/backend/workbench/upload/upload_table.py index 4493015486a..00709cb36ef 100644 --- a/specifyweb/backend/workbench/upload/upload_table.py +++ b/specifyweb/backend/workbench/upload/upload_table.py @@ 
-1,6 +1,6 @@ from decimal import Decimal import logging -from typing import Any, NamedTuple, Literal, Union +from typing import Any, NamedTuple, Literal, Union, Callable from django.db import transaction, IntegrityError @@ -8,6 +8,7 @@ from specifyweb.specify import models from specifyweb.specify.utils.func import Func from specifyweb.specify.utils.field_change_info import FieldChangeInfo +from specifyweb.specify.utils.autonumbering import AutonumberingLockDispatcher from specifyweb.backend.workbench.upload.clone import clone_record from specifyweb.backend.workbench.upload.predicates import ( ContetRef, @@ -17,6 +18,7 @@ resolve_reference_attributes, safe_fetch, ) +from specifyweb.specify.models_utils.lock_tables import LockDispatcher from specifyweb.backend.workbench.upload.scope_context import ScopeContext from .column_options import ColumnOptions, ExtendedColumnOptions from .parsing import parse_many, ParseResult, WorkBenchParseFailure @@ -68,11 +70,15 @@ class UploadTable(NamedTuple): overrideScope: dict[Literal["collection"], int | None] | None = None def apply_scoping( - self, collection, context: ScopeContext | None = None, row=None + self, + collection, + context: ScopeContext | None = None, + row=None, + lock_dispatcher: Callable[[], AutonumberingLockDispatcher] | None = None ) -> "ScopedUploadTable": from .scoping import apply_scoping_to_uploadtable - return apply_scoping_to_uploadtable(self, collection, context, row) + return apply_scoping_to_uploadtable(self, collection, context, row, lock_dispatcher=lock_dispatcher) def get_cols(self) -> set[str]: return ( @@ -270,9 +276,13 @@ def bind( class OneToOneTable(UploadTable): def apply_scoping( - self, collection, context: ScopeContext | None = None, row=None + self, + collection, + context: ScopeContext | None = None, + row=None, + lock_dispatcher: Callable[[], AutonumberingLockDispatcher] | None = None ) -> "ScopedOneToOneTable": - s = super().apply_scoping(collection, context, row) + s = 
super().apply_scoping(collection, context, row, lock_dispatcher=lock_dispatcher) return ScopedOneToOneTable(*s) def to_json(self) -> dict: @@ -293,9 +303,13 @@ def bind( class MustMatchTable(UploadTable): def apply_scoping( - self, collection, context: ScopeContext | None = None, row=None + self, + collection, + context: ScopeContext | None = None, + row=None, + lock_dispatcher: Callable[[], AutonumberingLockDispatcher] | None = None ) -> "ScopedMustMatchTable": - s = super().apply_scoping(collection, context, row) + s = super().apply_scoping(collection, context, row, lock_dispatcher=lock_dispatcher) return ScopedMustMatchTable(*s) def to_json(self) -> dict: diff --git a/specifyweb/backend/workbench/upload/uploadable.py b/specifyweb/backend/workbench/upload/uploadable.py index d1526713a4e..d3444dd08e2 100644 --- a/specifyweb/backend/workbench/upload/uploadable.py +++ b/specifyweb/backend/workbench/upload/uploadable.py @@ -1,8 +1,9 @@ -from typing import Any, TypedDict, Optional, Union +from typing import Any, TypedDict, Optional, Union, Literal from collections.abc import Callable from typing_extensions import Protocol from specifyweb.backend.workbench.upload.predicates import DjangoPredicates, ToRemove +from specifyweb.specify.utils.autonumbering import AutonumberingLockDispatcher from specifyweb.backend.workbench.upload.scope_context import ScopeContext @@ -44,7 +45,11 @@ class Uploadable(Protocol): # we cannot cache. well, we can make this more complicated by recursviely caching # static parts of even a non-entirely-cachable uploadable. def apply_scoping( - self, collection, context: ScopeContext | None = None, row=None + self, + collection, + context: ScopeContext | None = None, + row=None, + lock_dispatcher: Callable[[], AutonumberingLockDispatcher] | None = None ) -> "ScopedUploadable": ... def get_cols(self) -> set[str]: ... 
diff --git a/specifyweb/frontend/js_src/lib/components/InitialContext/systemInfo.ts b/specifyweb/frontend/js_src/lib/components/InitialContext/systemInfo.ts index e7729565eda..da17d142718 100644 --- a/specifyweb/frontend/js_src/lib/components/InitialContext/systemInfo.ts +++ b/specifyweb/frontend/js_src/lib/components/InitialContext/systemInfo.ts @@ -28,10 +28,10 @@ let systemInfo: SystemInfo; export const fetchContext = load( '/context/system_info.json', 'application/json' -).then((data) => { +).then(async (data) => { systemInfo = data; return systemInfo; }); -export const getSystemInfo = (): SystemInfo => systemInfo; +export const getSystemInfo = (): SystemInfo => systemInfo; \ No newline at end of file diff --git a/specifyweb/settings/specify_settings.py b/specifyweb/settings/specify_settings.py index bb110ffe892..afb5e548adc 100644 --- a/specifyweb/settings/specify_settings.py +++ b/specifyweb/settings/specify_settings.py @@ -93,7 +93,8 @@ REPORT_RUNNER_HOST = '' REPORT_RUNNER_PORT = '' -# Information to connect to a Redis database +# Specify will use this Redis as a process broker and storage for temporary +# values REDIS_HOST="redis" REDIS_PORT=6379 REDIS_DB_INDEX=0 diff --git a/specifyweb/specify/api/filter_by_col.py b/specifyweb/specify/api/filter_by_col.py index ef83f347414..ffef18caa9f 100644 --- a/specifyweb/specify/api/filter_by_col.py +++ b/specifyweb/specify/api/filter_by_col.py @@ -5,7 +5,7 @@ from django.core.exceptions import FieldError from django.db.models import Q -from ..utils.scoping import ScopeType +from ..utils.scoping import Scoping, ScopeType from specifyweb.specify.models import ( Geography, Geologictimeperiod, @@ -13,7 +13,8 @@ Taxon, Storage, Attachment, - Tectonicunit + Tectonicunit, + Accession ) CONCRETE_HIERARCHY = ["collection", "discipline", "division", "institution"] @@ -24,20 +25,28 @@ class HierarchyException(Exception): pass +# REFACTOR: Using Scoping here where possible def filter_by_collection(queryset, collection, 
strict=True): if queryset.model is Attachment: return queryset.filter( Q(scopetype=None) - | Q(scopetype=ScopeType.GLOBAL) - | Q(scopetype=ScopeType.COLLECTION, scopeid=collection.id) - | Q(scopetype=ScopeType.DISCIPLINE, scopeid=collection.discipline.id) - | Q(scopetype=ScopeType.DIVISION, scopeid=collection.discipline.division.id) + | Q(scopetype=ScopeType.GLOBAL.value) + | Q(scopetype=ScopeType.COLLECTION.value, scopeid=collection.id) + | Q(scopetype=ScopeType.DISCIPLINE.value, scopeid=collection.discipline.id) + | Q(scopetype=ScopeType.DIVISION.value, scopeid=collection.discipline.division.id) | Q( - scopetype=ScopeType.INSTITUTION, + scopetype=ScopeType.INSTITUTION.value, scopeid=collection.discipline.division.institution.id, ) ) + if queryset.model is Accession: + scope = Scoping.scope_type_from_class(Accession) + filters = ({"division": collection.discipline.division} + if scope == ScopeType.DIVISION + else {"division__institution": collection.discipline.division.institution}) + return queryset.filter(**filters) + if queryset.model in (Geography, Geologictimeperiod, Lithostrat, Tectonicunit): return queryset.filter(definition__disciplines=collection.discipline) diff --git a/specifyweb/specify/models_utils/lock_tables.py b/specifyweb/specify/models_utils/lock_tables.py index e6c2e75fbce..ef6762b0c17 100644 --- a/specifyweb/specify/models_utils/lock_tables.py +++ b/specifyweb/specify/models_utils/lock_tables.py @@ -1,13 +1,18 @@ from django.db import connection +from django.conf import settings from contextlib import contextmanager +from typing import Iterable import logging +import json logger = logging.getLogger(__name__) +LOCK_NAME_SEPARATOR = "_" @contextmanager def lock_tables(*tables): cursor = connection.cursor() + # REFACTOR: Extract this functionality to a decorator or contextmanager if cursor.db.vendor != 'mysql': logger.warning("unable to lock tables") yield @@ -18,3 +23,236 @@ def lock_tables(*tables): yield finally: cursor.execute('unlock tables') 
+ + +@contextmanager +def named_lock(raw_lock_name: str, timeout: int = 5, retry_attempts: int = 0): + """ + Handles acquiring and finally releasing a named user advisory lock. + While the lock is held, no other connection can acquire the same named lock. + + Use this sparingly: these locks do not impose any behavior on the database + like normal locks--agents interacting with the database can opt to not + follow traditional application flow and circumnavigate application behavior + + Raises a TimeoutError if timeout seconds have elapsed without acquiring the + lock (another connection holds the lock), and a ConnectionError if the + database was otherwise unable to acquire the lock. + + Example: + ``` + try: + with named_lock('my_lock') as lock: + ... # do something + except TimeoutError: + ... # handle case when lock is held by other connection + ``` + + :param raw_lock_name: The name of lock to acquire + :type raw_lock_name: str + :param timeout: The time in seconds to wait for lock release if another + connection holds the lock + :type timeout: int + :return: yields True if the lock was obtained successfully and None + otherwise + :rtype: Generator[Literal[True] | None, Any, None] + """ + + # REFACTOR: Extract this functionality to a decorator or contextmanager + if connection.vendor != "mysql": + yield + return + + db_name = getattr(settings, "DATABASE_NAME") + lock_name = f"{db_name}_{raw_lock_name}" + + acquired = acquired_named_lock(lock_name, timeout) + + while retry_attempts > 0 and acquired != True: + acquired = acquired_named_lock(lock_name, timeout) + retry_attempts -= 1 + + if acquired == False: + raise TimeoutError( + f"Unable to acquire named lock: '{lock_name}'. Held by other connection") + if acquired is None: + raise ConnectionError( + f"Unable to acquire named lock: '{lock_name}'. 
The process might have run out of memory") + + try: + yield acquired + finally: + release_named_lock(lock_name) + + +def acquired_named_lock(lock_name: str, timeout: int) -> bool | None: + """ + Attempts to acquire a named lock in the database. Will wait for timeout + seconds for the lock to be released if held by another connection. + + See https://mariadb.com/docs/server/reference/sql-functions/secondary-functions/miscellaneous-functions/get_lock + + :param lock_name: The name of the lock to acquire + :type lock_name: str + :param timeout: The time in seconds to wait for lock release if another + connection holds the lock + :type timeout: int + :return: returns True if the lock was obtained successfully, False if timeout + seconds have elapsed without acquiring the lock, and None otherwise + :rtype: bool | None + """ + with connection.cursor() as cur: + cur.execute("SELECT GET_LOCK(%s, %s)", [lock_name, timeout]) + acquired_row = cur.fetchone() + + if acquired_row is None: + return None + + acquired = acquired_row[0] + + if acquired == 1: + return True + elif acquired == 0: + return False + + return None + + +def release_named_lock(lock_name: str) -> bool | None: + """ + Attempt to release one instance of a held named lock. Note that multiple + instances of the same lock can be held by a single connection, in which + case each instance of the lock needs to be released separately. 
+ + See https://mariadb.com/docs/server/reference/sql-functions/secondary-functions/miscellaneous-functions/release_lock + + :param lock_name: The name of the lock to attempt to release + :type lock_name: str + :return: returns True if one instance of the lock was sucessfully released, False + if the lock is held by another connection, and None otherwise + :rtype: bool | None + """ + with connection.cursor() as cur: + cur.execute("SELECT RELEASE_LOCK(%s)", [lock_name]) + released_row = cur.fetchone() + + if released_row is None: + return None + + released = released_row[0] + if released == 1: + return True + elif released == 0: + return False + return None + + +class Lock: + def __init__(self, name: str, timeout: int): + self.name = name + self.timeout = timeout + + @classmethod + def from_json_str(cls, string: str | bytes | bytearray): + deserialized = json.loads(string) + return cls( + deserialized["name"], + deserialized["timeout"] + ) + + def acquire(self): + acquired = acquired_named_lock(self.name, self.timeout) + + if acquired == False: + raise TimeoutError( + f"Unable to acquire named lock: '{self.name}'. Held by other connection") + if acquired is None: + raise ConnectionError( + f"Unable to acquire named lock: '{self.name}'. 
The process might have run out of memory") + + return acquired + + def release(self): + released = release_named_lock(self.name) + return released + + @staticmethod + def serializer(lock: "Lock") -> str: + return lock.as_json_str() + + @staticmethod + def deserializer(lock_as_string: str | bytes | bytearray) -> "Lock": + return Lock.from_json_str(lock_as_string) + + def as_json(self): + return { + "name": self.name, + "timeout": self.timeout + } + + def as_json_str(self) -> str: + return json.dumps(self.as_json()) + + def __eq__(self, other): + if isinstance(other, Lock): + return self.name == other.name and self.timeout == other.timeout + return False + + +class LockDispatcher: + def __init__(self, lock_prefix: str | None = None, case_sensitive_names=False): + db_name = getattr(settings, "DATABASE_NAME") + self.lock_prefix_parts: list[str] = [db_name] + + if lock_prefix is not None: + self.lock_prefix_parts.append(lock_prefix) + + self.case_sensitive_names = case_sensitive_names + self.locks: dict[str, Lock] = dict() + self.in_context = False + + def close(self): + self.release_all() + + def __enter__(self): + self.in_context = True + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + self.in_context = False + + def lock_name(self, *name_parts: str): + final_name = LOCK_NAME_SEPARATOR.join( + (*self.lock_prefix_parts, *name_parts)) + + return final_name.lower() if not self.case_sensitive_names else final_name + + @contextmanager + def lock_and_release(self, name: str, timeout: int = 5): + try: + yield self.acquire(name, timeout) + finally: + self.release(name) + + def create_lock(self, name: str, timeout: int = 5): + lock_name = self.lock_name(name) + return Lock(lock_name, timeout) + + def acquire(self, name: str, timeout: int = 5): + if self.locks.get(name) is not None: + return + lock = self.create_lock(name, timeout) + self.locks[name] = lock + return lock.acquire() + + def release_all(self): + for lock_name in 
list(self.locks.keys()): + self.release(lock_name) + self.locks = dict() + + def release(self, name: str): + lock = self.locks.pop(name, None) + if lock is None: + return + lock.release() diff --git a/specifyweb/specify/tests/test_api.py b/specifyweb/specify/tests/test_api.py index f6be54f4984..41ee93c34ea 100644 --- a/specifyweb/specify/tests/test_api.py +++ b/specifyweb/specify/tests/test_api.py @@ -1091,7 +1091,7 @@ def test_explicitly_defined_scope(self): accessionnumber="ACC_Test", division=self.division ) - accession_scope = scoping.Scoping(accession).get_scope_model() + accession_scope = scoping.Scoping.model_from_instance(accession) self.assertEqual(accession_scope.id, self.institution.id) loan = Loan.objects.create( @@ -1099,14 +1099,14 @@ def test_explicitly_defined_scope(self): discipline=self.other_discipline ) - loan_scope = scoping.Scoping(loan).get_scope_model() + loan_scope = scoping.Scoping.model_from_instance(loan) self.assertEqual(loan_scope.id, self.other_discipline.id) def test_infered_scope(self): disposal = Disposal.objects.create( disposalnumber = "DISPOSAL_TEST" ) - disposal_scope = scoping.Scoping(disposal).get_scope_model() + disposal_scope = scoping.Scoping.model_from_instance(disposal) self.assertEqual(disposal_scope.id, self.institution.id) loan = Loan.objects.create( @@ -1114,12 +1114,12 @@ def test_infered_scope(self): division=self.other_division, discipline=self.other_discipline, ) - inferred_loan_scope = scoping.Scoping(loan)._infer_scope()[1] - self.assertEqual(inferred_loan_scope.id, self.other_division.id) + inferred_loan_scope = scoping.Scoping.model_from_instance(loan) + self.assertEqual(inferred_loan_scope.id, self.other_discipline.id) - collection_object_scope = scoping.Scoping( + collection_object_scope = scoping.Scoping.model_from_instance( self.collectionobjects[0] - ).get_scope_model() + ) self.assertEqual(collection_object_scope.id, self.collection.id) def test_in_same_scope(self): diff --git 
a/specifyweb/specify/tests/test_autonumbering/test_do_autonumbering.py b/specifyweb/specify/tests/test_autonumbering/test_do_autonumbering.py index 50b3580a937..d9524caaf78 100644 --- a/specifyweb/specify/tests/test_autonumbering/test_do_autonumbering.py +++ b/specifyweb/specify/tests/test_autonumbering/test_do_autonumbering.py @@ -137,7 +137,7 @@ def test_increment_across_collections(self, group_filter: Mock): collection=second_collection, catalognumber="#########" ) - do_autonumbering(self.collection, third_co_different_collection, fields) + do_autonumbering(second_collection, third_co_different_collection, fields) third_co_different_collection.refresh_from_db() # This third CO must be created now. self.assertIsNotNone(third_co_different_collection.id) @@ -158,7 +158,7 @@ def test_increment_across_collections(self, group_filter: Mock): collection=third_collection, catalognumber="#########" ) - do_autonumbering(self.collection, fourth_co_irrelevant_collection, fields) + do_autonumbering(third_collection, fourth_co_irrelevant_collection, fields) fourth_co_irrelevant_collection.refresh_from_db() # This CO must be created now. 
diff --git a/specifyweb/specify/tests/test_filter_by_col/test_filter_by_collection.py b/specifyweb/specify/tests/test_filter_by_col/test_filter_by_collection.py index 7f62d2b8a85..282b534ada9 100644 --- a/specifyweb/specify/tests/test_filter_by_col/test_filter_by_collection.py +++ b/specifyweb/specify/tests/test_filter_by_col/test_filter_by_collection.py @@ -39,17 +39,17 @@ def test_attachment(self): attachment_1 = Attachment.objects.create(scopetype=None) attachment_2 = Attachment.objects.create( - scopetype=ScopeType.COLLECTION, scopeid=self.collection.id + scopetype=ScopeType.COLLECTION.value, scopeid=self.collection.id ) attachment_3 = Attachment.objects.create( - scopetype=ScopeType.COLLECTION, scopeid=collection_2.id + scopetype=ScopeType.COLLECTION.value, scopeid=collection_2.id ) attachment_4 = Attachment.objects.create( - scopetype=ScopeType.DISCIPLINE, scopeid=self.discipline.id + scopetype=ScopeType.DISCIPLINE.value, scopeid=self.discipline.id ) attachment_5 = Attachment.objects.create( - scopetype=ScopeType.DISCIPLINE, scopeid=discipline_2.id + scopetype=ScopeType.DISCIPLINE.value, scopeid=discipline_2.id ) queryset = filter_by_collection(Attachment.objects.all(), self.collection) diff --git a/specifyweb/specify/utils/autonumbering.py b/specifyweb/specify/utils/autonumbering.py index f573451b73b..bdc2696331e 100644 --- a/specifyweb/specify/utils/autonumbering.py +++ b/specifyweb/specify/utils/autonumbering.py @@ -2,15 +2,21 @@ Autonumbering logic """ +import re + +from collections import defaultdict + +from django.db.models import Value, Case, When from .uiformatters import UIFormatter, get_uiformatters -from ..models_utils.lock_tables import lock_tables +from ..models_utils.lock_tables import LockDispatcher import logging -from typing import List, Tuple, Set +from typing import MutableMapping, Callable from collections.abc import Sequence from specifyweb.specify.utils.scoping import Scoping from specifyweb.specify.datamodel import datamodel +from 
specifyweb.backend.redis_cache.connect import RedisConnection, RedisString logger = logging.getLogger(__name__) @@ -35,30 +41,169 @@ def autonumber_and_save(collection, user, obj) -> None: obj.save() -def do_autonumbering(collection, obj, fields: list[tuple[UIFormatter, Sequence[str]]]) -> None: - logger.debug("autonumbering %s fields: %s", obj, fields) +class AutonumberingLockDispatcher(LockDispatcher): + def __init__(self): + lock_prefix = "autonumbering" + super().__init__(lock_prefix=lock_prefix, case_sensitive_names=False) + + # We use Redis for IPC, to maintain the current "highest" autonumbering + # value for each table + field + self.redis = RedisConnection(decode_responses=True) + # Before the records are created within a transaction, and committed to + # Redis, they're stored locally within this dictonary + self.highest_in_flight: dict[str, str] = dict() + + def __exit__(self, exc_type, exc_val, exc_tb): + super().__exit__(exc_type, exc_val, exc_tb) + + def highest_stored_value(self, + table_name: str, + field_name: str, + auto_num_regex: str, + scope_name: str, + scope_id: int) -> str | None: + key_name = self.autonumbering_redis_key( + table_name, field_name, auto_num_regex, scope_name, scope_id) + highest = RedisString(self.redis).get(key_name) + if isinstance(highest, bytes): + return highest.decode() + elif highest is None: + return None + return str(highest) + + def cache_highest(self, + table_name: str, + field_name: str, + auto_num_regex: str, + scope_name: str, + scope_id: int, + value: str): + key_name = self.autonumbering_redis_key(table_name, field_name, auto_num_regex, scope_name, scope_id) + self.highest_in_flight[key_name] = value + + def commit_highest(self): + for key_name, value in self.highest_in_flight.items(): + self.set_highest_value(key_name, value) + self.highest_in_flight.clear() + + def set_highest_value(self, + key_name: str, + value: str, + time_to_live: int = 5): + RedisString(self.redis).set(key_name, value, + time_to_live, 
override_existing=True) + + def autonumbering_redis_key(self, + table_name: str, + field_name: str, + auto_num_regex: str, + scope_name: str, + scope_id: int): + return self.lock_name(table_name, + field_name, + auto_num_regex, + "highest", + scope_name, + str(scope_id)) + + +def highest_autonumbering_value( + collection, + model, + formatter: UIFormatter, + values: Sequence[str], + get_lock_dispatcher: Callable[[], + AutonumberingLockDispatcher] | None = None, + wait_for_lock=10) -> str: + """ + Retrieves the next highest number in the autonumbering sequence for a given + autonumbering field format + """ + + if not formatter.needs_autonumber(values): + raise ValueError( + f"Formatter {formatter.format_name} does not need need autonumbered with {values}") + + if get_lock_dispatcher is None: + lock_dispatcher = None + else: + lock_dispatcher = get_lock_dispatcher() + lock_dispatcher.acquire(model._meta.db_table, timeout=wait_for_lock) + + field_name = formatter.field_name.lower() + + scope_type = Scoping.scope_type_from_class(model) + hierarchy_model = Scoping.get_hierarchy_model(collection, scope_type) + + with_year = formatter.fillin_year(values) + auto_number_regex = formatter.autonumber_regexp(with_year) + + stored_highest_value = (lock_dispatcher.highest_stored_value( + model._meta.db_table, + field_name, + auto_number_regex, + scope_type.name, + hierarchy_model.id) + if lock_dispatcher is not None else None) + + largest_in_database = formatter._autonumber_queryset(collection, model, field_name, with_year).annotate( + greater_than_stored=Case( + When(**{field_name + '__gt': stored_highest_value}, + then=Value(True)), + default=Value(False)) + if stored_highest_value is not None + else Value(False)) + + if not largest_in_database.exists(): + if stored_highest_value is not None: + filled_values = formatter.fill_vals_after(stored_highest_value) + else: + filled_values = formatter.fill_vals_no_prior(with_year) + else: + largest = largest_in_database[0] + 
database_larger = largest.greater_than_stored + value_to_inc = (getattr(largest, field_name) + if database_larger or stored_highest_value is None + else stored_highest_value) + filled_values = formatter.fill_vals_after(value_to_inc) + + highest = ''.join(filled_values) + + if lock_dispatcher is not None and lock_dispatcher.in_context: + lock_dispatcher.cache_highest( + model._meta.db_table, + field_name, + auto_number_regex, + scope_type.name, + hierarchy_model.id, + highest) - # The autonumber action is prepared and thunked outside the locked table - # context since it looks at other tables and that is not allowed by mysql - # if those tables are not also locked. - thunks = [ - formatter.prepare_autonumber_thunk(collection, obj.__class__, vals) - for formatter, vals in fields - ] + return highest - with lock_tables(*get_tables_to_lock(collection, obj, [formatter.field_name for formatter, _ in fields])): - for apply_autonumbering_to in thunks: - apply_autonumbering_to(obj) +def do_autonumbering(collection, obj, fields: list[tuple[UIFormatter, Sequence[str]]]) -> None: + logger.debug("autonumbering %s fields: %s", obj, fields) + + with AutonumberingLockDispatcher() as locks: + for formatter, vals in fields: + new_field_value = highest_autonumbering_value( + collection, + obj.__class__, + formatter, + vals, + get_lock_dispatcher=lambda: locks) + setattr(obj, formatter.field_name.lower(), new_field_value) obj.save() + locks.commit_highest() +# REFACTOR: Remove this funtion as it is no longer used def get_tables_to_lock(collection, obj, field_names) -> set[str]: # TODO: Include the fix for https://github.com/specify/specify7/issues/4148 from specifyweb.backend.businessrules.models import UniquenessRule obj_table = obj._meta.db_table - scope_table = Scoping(obj).get_scope_model() + scope_table = Scoping.model_from_instance(obj) tables = {obj._meta.db_table, 'django_migrations', UniquenessRule._meta.db_table, 'discipline', scope_table._meta.db_table} diff --git 
a/specifyweb/specify/utils/scoping.py b/specifyweb/specify/utils/scoping.py index 285cf6e6858..71eebc5c709 100644 --- a/specifyweb/specify/utils/scoping.py +++ b/specifyweb/specify/utils/scoping.py @@ -1,84 +1,219 @@ +from inspect import isclass -from collections import namedtuple -from typing import Tuple +from enum import Enum from django.db.models import Model from django.core.exceptions import ObjectDoesNotExist from .. import models -class ScopeType: +class ScopeType(Enum): COLLECTION = 0 DISCIPLINE = 1 DIVISION = 2 INSTITUTION = 3 GLOBAL = 10 - -class Scoping(namedtuple('Scoping', 'obj')): - def __call__(self) -> tuple[int, Model]: - """ - Returns the ScopeType and related Model instance of the - hierarchical position the `obj` occupies. - Tries and infers the scope based on the fields/relationships - on the model, and resolves the 'higher' scope before a more - specific scope if applicable for the object - """ - table = self.obj.__class__.__name__.lower() + @staticmethod + def from_model(obj) -> "ScopeType": + app_and_model_name = obj._meta.label_lower + + # We can't directly use `obj.__class__ is SomeScopeModel` here because + # that will break historical fake models during migrations + # Using the app and model name means this will work in both migration + # and normal runtimes + # See https://docs.djangoproject.com/en/6.0/topics/migrations/#historical-models + # for more information about Django's Histroical models + mapping = { + 'specify.institution': ScopeType.INSTITUTION, + 'specify.division': ScopeType.DIVISION, + 'specify.discipline': ScopeType.DISCIPLINE, + 'specify.collection': ScopeType.COLLECTION + } + + scope_type = mapping.get(app_and_model_name, None) + if scope_type is None: + raise TypeError(f"{app_and_model_name} is not a hierarchy table") + return scope_type + + def __gt__(self, other): + if not isinstance(other, ScopeType): + return NotImplemented + return self.value > other.value + + def __ge__(self, other): + if not isinstance(other, 
ScopeType): + return NotImplemented + return self.value >= other.value + + def __lt__(self, other): + if not isinstance(other, ScopeType): + return NotImplemented + return self.value < other.value + + def __le__(self, other): + if not isinstance(other, ScopeType): + return NotImplemented + return self.value <= other.value + + def __eq__(self, other): + if not isinstance(other, ScopeType): + return NotImplemented + return self.value == other.value + + +class ModelClassScope: + def __init__(self, model_class): + if not isclass(model_class): + raise TypeError(f"model_class: {model_class} is not a class!") + self.model_class = model_class + + @property + def scope_type(self) -> ScopeType: + table = self.model_class.__name__.lower() scope = getattr(self, table, lambda: None)() if scope is None: return self._infer_scope() - return scope - def get_scope_model(self) -> Model: - return self.__call__()[1] + def accession(self): + institution = models.Institution.objects.get() + if institution.isaccessionsglobal: + return ScopeType.INSTITUTION + else: + return ScopeType.DIVISION + def conservevent(self): return ModelClassScope( + models.Conservdescription).scope_type -################################################################################ + def fieldnotebookpage(self): return ModelClassScope( + models.Fieldnotebookpageset).scope_type + def fieldnotebookpageset(self): return ModelClassScope( + models.Fieldnotebook).scope_type - def accession(self): + def gift(self): return ScopeType.DISCIPLINE + + def loan(self): return ScopeType.DISCIPLINE + + def permit(self): + return ScopeType.INSTITUTION + + def referencework(self): + return ScopeType.INSTITUTION + + def taxon(self): + return ScopeType.DISCIPLINE + + def geography(self): + return ScopeType.DISCIPLINE + + def geologictimeperiod(self): + return ScopeType.DISCIPLINE + + def lithostrat(self): + return ScopeType.DISCIPLINE + + def tectonicunit(self): + return ScopeType.DISCIPLINE + + def storage(self): + return 
ScopeType.INSTITUTION + + +############################################################################# + + + def _infer_scope(self): + if is_related(self.model_class, "division"): + return ScopeType.DIVISION + if is_related(self.model_class, "discipline"): + return ScopeType.DISCIPLINE + if hasattr(self.model_class, "collectionmemberid") or is_related(self.model_class, "collection"): + return ScopeType.COLLECTION + + return ScopeType.INSTITUTION + + +class ModelInstanceScope: + def __init__(self, model_instance): + if isclass(model_instance): + raise ValueError(f"Expected object instead instead of class") + self.obj = model_instance + + @property + def scope_type(self) -> ScopeType: + return ScopeType.from_model(self.scope_model) + + @property + def scope_model(self) -> Model: + table = self.obj.__class__.__name__.lower() + scope = getattr(self, table, lambda: None)() + if scope is None: + return self._infer_scope_model() + + return scope + + def accession(self) -> Model: institution = models.Institution.objects.get() if institution.isaccessionsglobal: - return ScopeType.INSTITUTION, institution - else: - return self._simple_division_scope() + return institution + return self.obj.division - def conservevent(self): return Scoping(self.obj.conservdescription)() + def conservevent(self) -> Model: + return ModelInstanceScope(self.obj.conservdescription).scope_model - def fieldnotebookpage(self): return Scoping(self.obj.pageset)() + def fieldnotebookpage(self) -> Model: + return ModelInstanceScope(self.obj.pageset).scope_model - def fieldnotebookpageset(self): return Scoping(self.obj.fieldnotebook)() + def fieldnotebookpageset(self) -> Model: + return ModelInstanceScope(self.obj.fieldnotebook).scope_model - def gift(self): - if has_related(self.obj, 'discipline'): - return self._simple_discipline_scope() + def gift(self) -> Model: + if has_related(self.obj, "discipline"): + return self.obj.discipline - def loan(self): + def loan(self) -> Model: if 
has_related(self.obj, 'discipline'): - return self._simple_discipline_scope() + return self.obj.discipline - def permit(self): + def permit(self) -> Model: if has_related(self.obj, 'institution'): - return ScopeType.INSTITUTION, self.obj.institution + return self.obj.institution - def referencework(self): + def referencework(self) -> Model: if has_related(self.obj, 'institution'): - return ScopeType.INSTITUTION, self.obj.institution + return self.obj.institution - def taxon(self): - return ScopeType.DISCIPLINE, self.obj.definition.discipline + def taxon(self) -> Model: + return self.obj.definition.discipline -############################################################################# + def geography(self): + return self.obj.definition.discipline + + def geologictimeperiod(self): + return self.obj.definition.discipline - def _simple_discipline_scope(self) -> tuple[int, Model]: - return ScopeType.DISCIPLINE, self.obj.discipline + def lithostrat(self): + return self.obj.definition.discipline + + def tectonicunit(self): + return self.obj.definition.discipline + + def storage(self): + return self.obj.definition.institution + + def _infer_scope_model(self) -> Model: + if is_related(self.obj.__class__, "division") and has_related(self.obj, "division"): + return self.obj.division + if is_related(self.obj.__class__, "discipline") and has_related(self.obj, "discipline"): + return self.obj.discipline + if has_related(self.obj, "collectionmemberid") or (is_related(self.obj.__class__, "collection") and has_related(self.obj, "collection")): + return self._simple_collection_scope() - def _simple_division_scope(self) -> tuple[int, Model]: - return ScopeType.DIVISION, self.obj.division + return models.Institution.objects.get() - def _simple_collection_scope(self) -> tuple[int, Model]: + def _simple_collection_scope(self) -> Model: if hasattr(self.obj, "collectionmemberid"): try: """ @@ -95,41 +230,130 @@ def _simple_collection_scope(self) -> tuple[int, Model]: else: collection = 
self.obj.collection - return ScopeType.COLLECTION, collection + return collection - def _infer_scope(self): - if has_related(self.obj, "division"): - return self._simple_division_scope() - if has_related(self.obj, "discipline"): - return self._simple_discipline_scope() - if has_related(self.obj, "collectionmemberid") or has_related(self.obj, "collection"): - return self._simple_collection_scope() - return self._default_institution_scope() +class Scoping: - # If the table has no scope, and scope can not be inferred then scope to institution - def _default_institution_scope(self) -> tuple[int, Model]: - institution = models.Institution.objects.get() - return ScopeType.INSTITUTION, institution + @staticmethod + def scope_type_from_class(model_class) -> ScopeType: + """ + Returns the ScopeType that a particular class can be scoped to. + If you have an instantiated instance of the class, prefer using the + other methods like `from_instance` or `model_from_instance` + + Example: + ``` + loan_scope = Scoping.scope_type_from_class(models.Loan) + # ScopeType.Discipline + + accession_scope = Scoping.scope_type_from_class(models.Accession) + # ScopeType.Institution if accessions are global else ScopeType.Division + ``` + + :param model_class: + :return: + :rtype: ScopeType + """ + return ModelClassScope(model_class).scope_type + + @staticmethod + def from_instance(obj: Model) -> tuple[ScopeType, Model]: + instance = ModelInstanceScope(obj) + return instance.scope_type, instance.scope_model + @staticmethod + def model_from_instance(obj: Model) -> Model: + """ + Returns the Model that the provided Model instance can be scoped to. 
+ Usually always one of: Collection, Discipline, Division, or Institution + + Example: + ``` + my_co = Collectionobject.objects.get(some_filters) + scoped = Scoping.model_from_instance(my_co) + isinstance(scoped, models.Collection) #-> True + ``` + + :param obj: + :type obj: Model + :return: + :rtype: Model + """ + instance = ModelInstanceScope(obj) + return instance.scope_model -def has_related(model_instance, field_name: str) -> bool: + @staticmethod + def get_hierarchy_model(collection, scope_type: ScopeType) -> Model: + """ + Given a collection and desired ScopeType, returns the model associated + with the ScopeType. + + Example: + ``` + my_collection = Collection.objects.get(some_filters) + my_div = Scoping.get_hierarchy_model(ny_collection, ScopeType.Division) + my_dis = Scoping.get_hierarchy_model(ny_collection, ScopeType.Discipline) + ``` + + :param collection: Description + :param scope_type: Description + :type scope_type: ScopeType + :return: + :rtype: Model + """ + steps = [ScopeType.COLLECTION, ScopeType.DISCIPLINE, + ScopeType.DIVISION, ScopeType.INSTITUTION] + num_steps = steps.index(scope_type) + model = collection + for _ in range(num_steps): + model = Scoping.model_from_instance(model) + return model + + +def has_related(model_instance: Model, field_name: str) -> bool: + """ + + :param model_instance: Description + :type model_instance: Model + :param field_name: Description + :type field_name: str + :return: Returns true if the model instance contains some non-None value in + the given field name + :rtype: bool + """ return hasattr(model_instance, field_name) and getattr(model_instance, field_name, None) is not None +def is_related(model_class: Model, field_name: str) -> bool: + """ + + :param model_class: Description + :type model_class: Model + :param field_name: Description + :type field_name: str + :return: Returns true if the field name for the model class is a + relationship + :rtype: bool + """ + if not hasattr(model_class, field_name): + 
return False + field_wrapper = getattr(model_class, field_name) + field = getattr(field_wrapper, "field") + return getattr(field, "is_relation", False) def in_same_scope(object1: Model, object2: Model) -> bool: """ Determines whether two Model Objects are in the same scope. Travels up the scoping heirarchy until a matching scope can be resolved """ - scope1_type, scope1 = Scoping(object1)() - scope2_type, scope2 = Scoping(object2)() + scope1_type, scope1 = Scoping.from_instance(object1) + scope2_type, scope2 = Scoping.from_instance(object2) if scope1_type > scope2_type: while scope2_type != scope1_type: - scope2_type, scope2 = Scoping(scope2)() + scope2_type, scope2 = Scoping.from_instance(scope2) elif scope1_type < scope2_type: while scope2_type != scope1_type: - scope1_type, scope1 = Scoping(scope1)() + scope1_type, scope1 = Scoping.from_instance(scope1) return scope1.id == scope2.id diff --git a/specifyweb/specify/utils/uiformatters.py b/specifyweb/specify/utils/uiformatters.py index 4684e1661fa..ab4ec3cb704 100644 --- a/specifyweb/specify/utils/uiformatters.py +++ b/specifyweb/specify/utils/uiformatters.py @@ -142,30 +142,10 @@ def _autonumber_queryset(self, collection, model, fieldname: str, with_year: lis objs = model.objects.filter(**{ fieldname + '__regex': self.autonumber_regexp(with_year) }) return group_filter(objs).order_by('-' + fieldname) - def prepare_autonumber_thunk(self, collection, model, vals: Sequence[str], year: int | None=None): - with_year = self.fillin_year(vals, year) - fieldname = self.field_name.lower() - - filtered_objs = self._autonumber_queryset(collection, model, fieldname, with_year) - # At this point the query for the autonumber is defined but not yet executed. - - # The actual lookup and setting of the autonumbering value - # is thunked so that it can be executed in context of locked tables. 
- def apply_autonumbering_to(obj): - try: - biggest = filtered_objs[0] # actual lookup occurs here - except IndexError: - filled_vals = self.fill_vals_no_prior(with_year) - else: - filled_vals = self.fill_vals_after(getattr(biggest, fieldname)) - - # And here the new value is assigned to the object. It is - # the callers responsibilty to save the object within the - # same locked tables context because there maybe multiple - # autonumber fields. - setattr(obj, self.field_name.lower(), ''.join(filled_vals)) - - return apply_autonumbering_to + def apply_autonumbering(self, collection, obj, vals: Sequence[str]): + field_name = self.field_name.lower() + field_value = self.autonumber_now(collection, obj.__class__, vals) + setattr(obj, field_name, field_value) def fill_vals_after(self, prior: str) -> list[str]: @@ -192,14 +172,17 @@ def fill_vals_no_prior(self, vals: Sequence[str]) -> list[str]: def canonicalize(self, values: Sequence[str]) -> str: return ''.join([field.canonicalize(value) for field, value in zip(self.fields, values)]) - def apply_scope(self, collection): + def apply_scope(self, collection, autonumbering_lock_dispatcher = None) -> ScopedFormatter: + from specifyweb.specify.utils.autonumbering import highest_autonumbering_value def parser(table: Table, value: str) -> str: parsed = self.parse(value) if self.needs_autonumber(parsed): - canonicalized = self.autonumber_now( + canonicalized = highest_autonumbering_value( collection, getattr(models, table.django_name), - parsed + self, + parsed, + get_lock_dispatcher=autonumbering_lock_dispatcher ) else: canonicalized = self.canonicalize(parsed)