Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
5cebf53
fix: improve session tracker error logging
barapa Oct 15, 2025
e0ff3fb
Add session tracker delete and rollback tests
barapa Oct 16, 2025
a0153e5
Expand FileObjectSessionTracker tests: override semantics, state rete…
barapa Oct 16, 2025
af8898e
Add missing session tracker tests: rollback ignore FileNotFound, asyn…
barapa Oct 16, 2025
a4224d8
Parametrize session tracker tests for override semantics and async de…
barapa Oct 16, 2025
441def8
Rename parametrized tests to match repo naming conventions
barapa Oct 16, 2025
2b6af3b
Parametrize additional session tracker cases and fix log capture cleanup
barapa Oct 16, 2025
3249e99
Dedupe session tracker tests by removing cases covered by parametrize…
barapa Oct 16, 2025
5385e6d
Document session tracker commit/rollback semantics, error propagation…
barapa Oct 16, 2025
1c2dc65
Fix lint issues: ruff ERA001, add type annotations in tests; run make…
barapa Oct 16, 2025
e424c55
feat: use `ExceptionGroup` and add feature gate
cofin Oct 18, 2025
86693bf
fix: update ExceptionGroup import handling and adjust session config …
cofin Oct 18, 2025
7572157
chore: 1 more linting
cofin Oct 18, 2025
545cc6a
fix: update default behavior of file_object_raise_on_error to True
cofin Oct 26, 2025
5210a3f
fix: set default to true
cofin Oct 26, 2025
45d88b4
fix: revert uvlock
cofin Oct 26, 2025
d4d76d0
Merge branch 'main' into feature/session-tracker-error-handling
cofin Oct 26, 2025
76019be
Update advanced_alchemy/types/file_object/session_tracker.py
cofin Oct 26, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions advanced_alchemy/_listeners.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,17 @@ def is_async_context() -> bool:
return _is_async_context.get()


def _get_session_tracker(create: bool = True) -> Optional["FileObjectSessionTracker"]:
def _get_session_tracker(
create: bool = True, session: Optional["Session"] = None
) -> Optional["FileObjectSessionTracker"]:
from advanced_alchemy.types.file_object import FileObjectSessionTracker

tracker = _current_session_tracker.get()
if tracker is None and create:
tracker = FileObjectSessionTracker()
raise_on_error = True
if session is not None:
raise_on_error = session.info.get("file_object_raise_on_error", True)
tracker = FileObjectSessionTracker(raise_on_error=raise_on_error)
_current_session_tracker.set(tracker)
return tracker

Expand Down Expand Up @@ -348,7 +353,7 @@ def before_flush(cls, session: "Session", flush_context: "UOWTransaction", insta
if not cls._is_listener_enabled(session):
return

tracker = _get_session_tracker(create=True)
tracker = _get_session_tracker(create=True, session=session)
if not tracker:
return

Expand Down
13 changes: 13 additions & 0 deletions advanced_alchemy/config/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,12 @@ class GenericSQLAlchemyConfig(Generic[EngineT, SessionT, SessionMakerT]):
This is a listener that will automatically save and delete :class:`FileObject <advanced_alchemy.types.file_object.FileObject>` instances when they are saved or deleted.

Disable if you plan to bring your own save/delete mechanism for these columns"""
file_object_raise_on_error: bool = True
"""Control FileObject error handling behavior.

- ``False``: Log warnings on file operation failures, don't raise exceptions
- ``True`` (default): Raise exceptions on file operation failures
"""
_SESSION_SCOPE_KEY_REGISTRY: "ClassVar[set[str]]" = field(init=False, default=cast("set[str]", set()))
"""Internal counter for ensuring unique identification of session scope keys in the class."""
_ENGINE_APP_STATE_KEY_REGISTRY: "ClassVar[set[str]]" = field(init=False, default=cast("set[str]", set()))
Expand Down Expand Up @@ -208,6 +214,13 @@ def __post_init__(self) -> None:

setup_file_object_listeners()

# Store file_object_raise_on_error in session_config.info
# Ensure session_config.info is a dict (convert from Empty if needed)
if self.session_config.info is Empty:
self.session_config.info = {}
if isinstance(self.session_config.info, dict):
self.session_config.info["file_object_raise_on_error"] = self.file_object_raise_on_error

def __hash__(self) -> int: # pragma: no cover
return hash(
(
Expand Down
175 changes: 129 additions & 46 deletions advanced_alchemy/types/file_object/session_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@

import asyncio
import logging
import sys
from typing import TYPE_CHECKING, Any, Union

if sys.version_info >= (3, 11):
from builtins import ExceptionGroup
else:
from exceptiongroup import ExceptionGroup # type: ignore[import-not-found,unused-ignore]

if TYPE_CHECKING:
from collections.abc import Awaitable
from pathlib import Path

from advanced_alchemy.types.file_object import FileObject
Expand All @@ -17,8 +22,20 @@
class FileObjectSessionTracker:
"""Tracks FileObject changes within a single session transaction."""

def __init__(self) -> None:
"""Initialize the tracker."""
def __init__(self, raise_on_error: bool = False) -> None:
"""Initialize empty tracking state.

Args:
raise_on_error: If True, raise exceptions on file operation failures.
If False, log warnings and continue.

Internal structures:
- ``pending_saves``: ``FileObject -> data`` to be saved on commit
- ``pending_deletes``: ``FileObject`` instances to delete on commit
- ``_saved_in_transaction``: successfully saved objects used for
selective cleanup on rollback
"""
self.raise_on_error = raise_on_error
# Stores objects that have pending data to be saved on commit.
# Maps FileObject -> data source (bytes or Path)
self.pending_saves: "dict[FileObject, Union[bytes, Path]]" = {}
Expand Down Expand Up @@ -47,43 +64,94 @@ def commit(self) -> None:
for obj, data in self.pending_saves.items():
try:
obj.save(data)
except Exception as e: # noqa: BLE001
logger.warning("Error saving file for object %s: %s", obj, e.__cause__)
self._saved_in_transaction.add(obj)
except Exception:
if self.raise_on_error:
logger.exception("error saving file for object %s", obj)
raise
logger.warning("error saving file for object %s", obj, exc_info=True)

for obj in self.pending_deletes:
try:
obj.delete()
except FileNotFoundError:
# Ignore if the file is already gone (shouldn't happen often here)
pass
except Exception as e: # noqa: BLE001
logger.warning("Error deleting file for object %s: %s", obj, e.__cause__)
except Exception:
if self.raise_on_error:
logger.exception("error deleting file for object %s", obj)
raise
logger.warning("error deleting file for object %s", obj, exc_info=True)

self.clear()

async def commit_async(self) -> None:
"""Process pending saves and deletes after a successful commit."""
save_tasks: list[Awaitable[Any]] = []
for obj, data in self.pending_saves.items():
save_tasks.append(obj.save_async(data))
self._saved_in_transaction.add(obj)

delete_tasks: list[Awaitable[Any]] = [obj.delete_async() for obj in self.pending_deletes]

# Run save and delete tasks concurrently
save_results = await asyncio.gather(*save_tasks, return_exceptions=True)
delete_results = await asyncio.gather(*delete_tasks, return_exceptions=True)

# Process save results (log errors)
for result, (obj, _data) in zip(save_results, self.pending_saves.items()):
if isinstance(result, Exception):
logger.warning("Error saving file for object %s: %s", obj, result.__cause__)
# Process delete results (log errors, ignore FileNotFoundError)
for result, obj_to_delete in zip(delete_results, self.pending_deletes):
save_items: "list[tuple[FileObject, Union[bytes, Path]]]" = list(self.pending_saves.items())
delete_items: "list[FileObject]" = list(self.pending_deletes)

save_results: "list[Any]" = await asyncio.gather(
*(obj.save_async(data) for obj, data in save_items),
return_exceptions=True,
)
delete_results: "list[Any]" = await asyncio.gather(
*(obj.delete_async() for obj in delete_items),
return_exceptions=True,
)

errors: list[Exception] = []

for (obj, _data), result in zip(save_items, save_results):
if isinstance(result, BaseException):
if isinstance(result, Exception):
if self.raise_on_error:
logger.error(
"error saving file for object %s",
obj,
exc_info=(type(result), result, result.__traceback__),
)
else:
# Legacy behavior: warning level
logger.warning(
"error saving file for object %s",
obj,
exc_info=(type(result), result, result.__traceback__),
)
errors.append(result)
else:
# BaseException (e.g., CancelledError) - always raise
raise result
else:
self._saved_in_transaction.add(obj)

for obj_to_delete, result in zip(delete_items, delete_results):
if isinstance(result, FileNotFoundError):
continue
if isinstance(result, Exception):
logger.warning("Error deleting file %s: %s", obj_to_delete.path, result.__cause__)

self.clear()
if isinstance(result, BaseException):
if isinstance(result, Exception):
if self.raise_on_error:
logger.error(
"error deleting file %s",
obj_to_delete.path or obj_to_delete,
exc_info=(type(result), result, result.__traceback__),
)
else:
logger.warning(
"error deleting file %s",
obj_to_delete.path or obj_to_delete,
exc_info=(type(result), result, result.__traceback__),
)
errors.append(result)
else:
raise result

if errors and self.raise_on_error:
if len(errors) == 1:
raise errors[0]
msg = "multiple FileObject operation failures"
raise ExceptionGroup(msg, errors)
if not errors:
self.clear()

def rollback(self) -> None:
"""Clean up files saved during a transaction that is being rolled back."""
Expand All @@ -94,30 +162,45 @@ def rollback(self) -> None:
except FileNotFoundError:
# Ignore if the file is already gone (shouldn't happen often here)
pass
except Exception as e: # noqa: BLE001
logger.warning("Error deleting file during rollback %s: %s", obj.path, e.__cause__)
except Exception:
logger.exception("error deleting file during rollback %s", obj.path or obj)
raise
self.clear()

async def rollback_async(self) -> None:
"""Clean up files saved during a transaction that is being rolled back."""
rollback_delete_tasks: list[Awaitable[Any]] = []
objects_to_delete_on_rollback: list[FileObject] = []
# Only delete files that were actually saved *during this transaction*
for obj in self._saved_in_transaction:
if obj.path:
rollback_delete_tasks.append(obj.delete_async())
objects_to_delete_on_rollback.append(obj)

for task, obj_to_delete in zip(rollback_delete_tasks, objects_to_delete_on_rollback):
try:
await task
except FileNotFoundError:
# Ignore if the file is already gone (shouldn't happen often here)
pass
except Exception as e: # noqa: BLE001
logger.warning("Error deleting file during rollback %s: %s", obj_to_delete.path, e.__cause__)
objects_to_delete = [obj for obj in self._saved_in_transaction if obj.path]
if not objects_to_delete:
self.clear()
return

delete_results = await asyncio.gather(
*(obj.delete_async() for obj in objects_to_delete),
return_exceptions=True,
)

errors: list[Exception] = []
for obj, result in zip(objects_to_delete, delete_results):
if isinstance(result, FileNotFoundError):
continue
if isinstance(result, BaseException):
if isinstance(result, Exception):
logger.error(
"error deleting file during rollback %s",
obj.path or obj,
exc_info=(type(result), result, result.__traceback__),
)
errors.append(result)
else:
# Propagate BaseExceptions like CancelledError
raise result

self.clear()
if errors:
if len(errors) == 1:
raise errors[0]
msg = "multiple FileObject rollback failures"
raise ExceptionGroup(msg, errors)

def clear(self) -> None:
"""Clear the tracker's state."""
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"typing-extensions>=4.0.0",
"greenlet",
"eval-type-backport ; python_full_version < '3.10'",
"exceptiongroup ; python_full_version < '3.11'",
]
description = "Ready-to-go SQLAlchemy concoctions."
keywords = ["sqlalchemy", "alembic", "litestar", "sanic", "fastapi", "flask"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def test_session_config_dict_with_no_provided_config(
) -> None:
"""Test session_config_dict with no provided config."""
config = config_cls()
assert config.session_config_dict == {}
# Config now includes file_object_raise_on_error in session info by default
assert config.session_config_dict == {"info": {"file_object_raise_on_error": True}}


def test_config_create_engine_if_engine_instance_provided(
Expand Down
Loading