|
21 | 21 | from itertools import zip_longest |
22 | 22 | from logging import getLogger |
23 | 23 | from typing import ( |
| 24 | + TYPE_CHECKING, |
24 | 25 | Any, |
25 | 26 | ClassVar, |
26 | 27 | Literal, |
|
44 | 45 | import temporalio.exceptions |
45 | 46 | import temporalio.types |
46 | 47 |
|
| 48 | +if TYPE_CHECKING: |
| 49 | + from temporalio.extstore import StorageOptions, _ExternalStorageMiddleware |
| 50 | + |
47 | 51 | if sys.version_info < (3, 11): |
48 | 52 | # Python's datetime.fromisoformat doesn't support certain formats pre-3.11 |
49 | 53 | from dateutil import parser # type: ignore |
@@ -924,11 +928,42 @@ def to_failure( |
924 | 928 | failure: temporalio.api.failure.v1.Failure, |
925 | 929 | ) -> None: |
926 | 930 | """See base class.""" |
| 931 | + from temporalio.extstore import ( |
| 932 | + DriverError, |
| 933 | + PayloadNotFoundError, |
| 934 | + ) |
| 935 | + |
927 | 936 | # If already a failure error, use that |
928 | 937 | if isinstance(exception, temporalio.exceptions.FailureError): |
929 | 938 | self._error_to_failure(exception, payload_converter, failure) |
930 | 939 | elif isinstance(exception, nexusrpc.HandlerError): |
931 | 940 | self._nexus_handler_error_to_failure(exception, payload_converter, failure) |
| 941 | + elif isinstance(exception, PayloadNotFoundError): |
| 942 | + # Convert to failure error |
| 943 | + failure_error = temporalio.exceptions.ApplicationError( |
| 944 | + str(exception), |
| 945 | + { |
| 946 | + "driver_name": exception.driver_name, |
| 947 | + "driver_claim": exception.driver_claim, |
| 948 | + }, |
| 949 | + type=exception.__class__.__name__, |
| 950 | + non_retryable=True, |
| 951 | + ) |
| 952 | + failure_error.__traceback__ = exception.__traceback__ |
| 953 | + failure_error.__cause__ = exception.__cause__ |
| 954 | + self._error_to_failure(failure_error, payload_converter, failure) |
| 955 | + elif isinstance(exception, DriverError): |
| 956 | + # Convert to failure error |
| 957 | + failure_error = temporalio.exceptions.ApplicationError( |
| 958 | + str(exception), |
| 959 | + { |
| 960 | + "driver_name": exception.driver_name, |
| 961 | + }, |
| 962 | + type=exception.__class__.__name__, |
| 963 | + ) |
| 964 | + failure_error.__traceback__ = exception.__traceback__ |
| 965 | + failure_error.__cause__ = exception.__cause__ |
| 966 | + self._error_to_failure(failure_error, payload_converter, failure) |
932 | 967 | else: |
933 | 968 | # Convert to failure error |
934 | 969 | failure_error = temporalio.exceptions.ApplicationError( |
@@ -1359,15 +1394,27 @@ class DataConverter(WithSerializationContext): |
1359 | 1394 | payload_limits: PayloadLimitsConfig = PayloadLimitsConfig() |
1360 | 1395 | """Settings for payload size limits.""" |
1361 | 1396 |
|
| 1397 | + external_storage: StorageOptions | None = None |
| 1398 | + """Options for external storage. If None, external storage is disabled. |
| 1399 | + |
| 1400 | + .. warning:: |
| 1401 | + This API is experimental. |
| 1402 | + """ |
| 1403 | + |
1362 | 1404 | default: ClassVar[DataConverter] |
1363 | 1405 | """Singleton default data converter.""" |
1364 | 1406 |
|
| 1407 | + _external_storage_middleware: "_ExternalStorageMiddleware" = dataclasses.field( |
| 1408 | + init=False |
| 1409 | + ) |
| 1410 | + |
1365 | 1411 | _payload_error_limits: _ServerPayloadErrorLimits | None = None |
1366 | 1412 | """Server-reported limits for payloads.""" |
1367 | 1413 |
|
1368 | 1414 | def __post_init__(self) -> None: # noqa: D105 |
1369 | 1415 | object.__setattr__(self, "payload_converter", self.payload_converter_class()) |
1370 | 1416 | object.__setattr__(self, "failure_converter", self.failure_converter_class()) |
| 1417 | + self._reset_external_storage_middleware() |
1371 | 1418 |
|
1372 | 1419 | async def encode( |
1373 | 1420 | self, values: Sequence[Any] |
@@ -1445,27 +1492,45 @@ def with_context(self, context: SerializationContext) -> Self: |
1445 | 1492 | payload_converter = self.payload_converter |
1446 | 1493 | payload_codec = self.payload_codec |
1447 | 1494 | failure_converter = self.failure_converter |
| 1495 | + external_storage = self.external_storage |
1448 | 1496 | if isinstance(payload_converter, WithSerializationContext): |
1449 | 1497 | payload_converter = payload_converter.with_context(context) |
1450 | 1498 | if isinstance(payload_codec, WithSerializationContext): |
1451 | 1499 | payload_codec = payload_codec.with_context(context) |
1452 | 1500 | if isinstance(failure_converter, WithSerializationContext): |
1453 | 1501 | failure_converter = failure_converter.with_context(context) |
| 1502 | + if isinstance(external_storage, WithSerializationContext): |
| 1503 | + external_storage = external_storage.with_context(context) |
1454 | 1504 | if all( |
1455 | 1505 | new is orig |
1456 | 1506 | for new, orig in [ |
1457 | 1507 | (payload_converter, self.payload_converter), |
1458 | 1508 | (payload_codec, self.payload_codec), |
1459 | 1509 | (failure_converter, self.failure_converter), |
| 1510 | + (external_storage, self.external_storage), |
1460 | 1511 | ] |
1461 | 1512 | ): |
1462 | 1513 | return self |
1463 | 1514 | cloned = dataclasses.replace(self) |
1464 | 1515 | object.__setattr__(cloned, "payload_converter", payload_converter) |
1465 | 1516 | object.__setattr__(cloned, "payload_codec", payload_codec) |
1466 | 1517 | object.__setattr__(cloned, "failure_converter", failure_converter) |
| 1518 | + object.__setattr__(cloned, "external_storage", external_storage) |
| 1519 | + cloned._reset_external_storage_middleware(context) |
1467 | 1520 | return cloned |
1468 | 1521 |
|
| 1522 | + def _reset_external_storage_middleware( |
| 1523 | + self, context: SerializationContext | None = None |
| 1524 | + ) -> None: |
| 1525 | + # Lazy import to avoid circular dependency |
| 1526 | + from temporalio.extstore import _ExternalStorageMiddleware |
| 1527 | + |
| 1528 | + object.__setattr__( |
| 1529 | + self, |
| 1530 | + "_external_storage_middleware", |
| 1531 | + _ExternalStorageMiddleware(self.external_storage, context), |
| 1532 | + ) |
| 1533 | + |
1469 | 1534 | def _with_payload_error_limits( |
1470 | 1535 | self, limits: _ServerPayloadErrorLimits | None |
1471 | 1536 | ) -> DataConverter: |
@@ -1523,48 +1588,47 @@ async def _encode_memo_existing( |
1523 | 1588 | async def _encode_payload( |
1524 | 1589 | self, payload: temporalio.api.common.v1.Payload |
1525 | 1590 | ) -> temporalio.api.common.v1.Payload: |
| 1591 | + payload = await self._external_storage_middleware.store_payload(payload) |
1526 | 1592 | if self.payload_codec: |
1527 | 1593 | payload = (await self.payload_codec.encode([payload]))[0] |
1528 | 1594 | self._validate_payload_limits([payload]) |
1529 | 1595 | return payload |
1530 | 1596 |
|
1531 | 1597 | async def _encode_payloads(self, payloads: temporalio.api.common.v1.Payloads): |
1532 | | - if self.payload_codec: |
1533 | | - await self.payload_codec.encode_wrapper(payloads) |
1534 | | - self._validate_payload_limits(payloads.payloads) |
| 1598 | + encoded_payloads = await self._encode_payload_sequence(payloads.payloads) |
| 1599 | + del payloads.payloads[:] |
| 1600 | + payloads.payloads.extend(encoded_payloads) |
1535 | 1601 |
|
1536 | 1602 | async def _encode_payload_sequence( |
1537 | 1603 | self, payloads: Sequence[temporalio.api.common.v1.Payload] |
1538 | 1604 | ) -> list[temporalio.api.common.v1.Payload]: |
1539 | | - encoded_payloads = list(payloads) |
| 1605 | + result = await self._external_storage_middleware.store_payloads(payloads) |
1540 | 1606 | if self.payload_codec: |
1541 | | - encoded_payloads = await self.payload_codec.encode(encoded_payloads) |
1542 | | - self._validate_payload_limits(encoded_payloads) |
1543 | | - return encoded_payloads |
| 1607 | + result = await self.payload_codec.encode(result) |
| 1608 | + self._validate_payload_limits(result) |
| 1609 | + return result |
1544 | 1610 |
|
1545 | 1611 | async def _decode_payload( |
1546 | 1612 | self, payload: temporalio.api.common.v1.Payload |
1547 | 1613 | ) -> temporalio.api.common.v1.Payload: |
1548 | 1614 | if self.payload_codec: |
1549 | 1615 | payload = (await self.payload_codec.decode([payload]))[0] |
| 1616 | + payload = await self._external_storage_middleware.retrieve_payload(payload) |
1550 | 1617 | return payload |
1551 | 1618 |
|
1552 | 1619 | async def _decode_payloads(self, payloads: temporalio.api.common.v1.Payloads): |
1553 | | - if self.payload_codec: |
1554 | | - await self.payload_codec.decode_wrapper(payloads) |
| 1620 | + decoded_payloads = await self._decode_payload_sequence(payloads.payloads) |
| 1621 | + del payloads.payloads[:] |
| 1622 | + payloads.payloads.extend(decoded_payloads) |
1555 | 1623 |
|
1556 | 1624 | async def _decode_payload_sequence( |
1557 | 1625 | self, payloads: Sequence[temporalio.api.common.v1.Payload] |
1558 | 1626 | ) -> list[temporalio.api.common.v1.Payload]: |
1559 | | - if not self.payload_codec: |
1560 | | - return list(payloads) |
1561 | | - return await self.payload_codec.decode(payloads) |
1562 | | - |
1563 | | - # Temporary shortcircuit detection while the _decode_* methods may no-op if |
1564 | | - # a payload codec is not configured. Remove once those paths have more to them. |
1565 | | - @property |
1566 | | - def _decode_payload_has_effect(self) -> bool: |
1567 | | - return self.payload_codec is not None |
| 1627 | + result = list(payloads) |
| 1628 | + if self.payload_codec: |
| 1629 | + result = await self.payload_codec.decode(result) |
| 1630 | + result = await self._external_storage_middleware.retrieve_payloads(result) |
| 1631 | + return result |
1568 | 1632 |
|
1569 | 1633 | @staticmethod |
1570 | 1634 | async def _apply_to_failure_payloads( |
|
0 commit comments