|
21 | 21 | from itertools import zip_longest |
22 | 22 | from logging import getLogger |
23 | 23 | from typing import ( |
| 24 | + TYPE_CHECKING, |
24 | 25 | Any, |
25 | 26 | ClassVar, |
26 | 27 | Literal, |
|
44 | 45 | import temporalio.exceptions |
45 | 46 | import temporalio.types |
46 | 47 |
|
| 48 | +if TYPE_CHECKING: |
| 49 | + from temporalio.extstore import StorageOptions, _ExternalStorageMiddleware |
| 50 | + |
47 | 51 | if sys.version_info < (3, 11): |
48 | 52 | # Python's datetime.fromisoformat doesn't support certain formats pre-3.11 |
49 | 53 | from dateutil import parser # type: ignore |
@@ -922,11 +926,42 @@ def to_failure( |
922 | 926 | failure: temporalio.api.failure.v1.Failure, |
923 | 927 | ) -> None: |
924 | 928 | """See base class.""" |
| 929 | + from temporalio.extstore import ( |
| 930 | + DriverError, |
| 931 | + PayloadNotFoundError, |
| 932 | + ) |
| 933 | + |
925 | 934 | # If already a failure error, use that |
926 | 935 | if isinstance(exception, temporalio.exceptions.FailureError): |
927 | 936 | self._error_to_failure(exception, payload_converter, failure) |
928 | 937 | elif isinstance(exception, nexusrpc.HandlerError): |
929 | 938 | self._nexus_handler_error_to_failure(exception, payload_converter, failure) |
| 939 | + elif isinstance(exception, PayloadNotFoundError): |
| 940 | + # Convert to failure error |
| 941 | + failure_error = temporalio.exceptions.ApplicationError( |
| 942 | + str(exception), |
| 943 | + { |
| 944 | + "driver_name": exception.driver_name, |
| 945 | + "driver_claim": exception.driver_claim, |
| 946 | + }, |
| 947 | + type=exception.__class__.__name__, |
| 948 | + non_retryable=True, |
| 949 | + ) |
| 950 | + failure_error.__traceback__ = exception.__traceback__ |
| 951 | + failure_error.__cause__ = exception.__cause__ |
| 952 | + self._error_to_failure(failure_error, payload_converter, failure) |
| 953 | + elif isinstance(exception, DriverError): |
| 954 | + # Convert to failure error |
| 955 | + failure_error = temporalio.exceptions.ApplicationError( |
| 956 | + str(exception), |
| 957 | + { |
| 958 | + "driver_name": exception.driver_name, |
| 959 | + }, |
| 960 | + type=exception.__class__.__name__, |
| 961 | + ) |
| 962 | + failure_error.__traceback__ = exception.__traceback__ |
| 963 | + failure_error.__cause__ = exception.__cause__ |
| 964 | + self._error_to_failure(failure_error, payload_converter, failure) |
930 | 965 | else: |
931 | 966 | # Convert to failure error |
932 | 967 | failure_error = temporalio.exceptions.ApplicationError( |
@@ -1289,15 +1324,27 @@ class DataConverter(WithSerializationContext): |
1289 | 1324 | payload_limits: PayloadLimitsConfig = PayloadLimitsConfig() |
1290 | 1325 | """Settings for payload size limits.""" |
1291 | 1326 |
|
| 1327 | + external_storage: StorageOptions | None = None |
| 1328 | + """Options for external storage. If None, external storage is disabled. |
| 1329 | + |
| 1330 | + .. warning:: |
| 1331 | + This API is experimental. |
| 1332 | + """ |
| 1333 | + |
1292 | 1334 | default: ClassVar[DataConverter] |
1293 | 1335 | """Singleton default data converter.""" |
1294 | 1336 |
|
| 1337 | + _external_storage_middleware: "_ExternalStorageMiddleware" = dataclasses.field( |
| 1338 | + init=False |
| 1339 | + ) |
| 1340 | + |
1295 | 1341 | _payload_error_limits: _ServerPayloadErrorLimits | None = None |
1296 | 1342 | """Server-reported limits for payloads.""" |
1297 | 1343 |
|
1298 | 1344 | def __post_init__(self) -> None: # noqa: D105 |
1299 | 1345 | object.__setattr__(self, "payload_converter", self.payload_converter_class()) |
1300 | 1346 | object.__setattr__(self, "failure_converter", self.failure_converter_class()) |
| 1347 | + self._reset_external_storage_middleware() |
1301 | 1348 |
|
1302 | 1349 | async def encode( |
1303 | 1350 | self, values: Sequence[Any] |
@@ -1375,27 +1422,45 @@ def with_context(self, context: SerializationContext) -> Self: |
1375 | 1422 | payload_converter = self.payload_converter |
1376 | 1423 | payload_codec = self.payload_codec |
1377 | 1424 | failure_converter = self.failure_converter |
| 1425 | + external_storage = self.external_storage |
1378 | 1426 | if isinstance(payload_converter, WithSerializationContext): |
1379 | 1427 | payload_converter = payload_converter.with_context(context) |
1380 | 1428 | if isinstance(payload_codec, WithSerializationContext): |
1381 | 1429 | payload_codec = payload_codec.with_context(context) |
1382 | 1430 | if isinstance(failure_converter, WithSerializationContext): |
1383 | 1431 | failure_converter = failure_converter.with_context(context) |
| 1432 | + if isinstance(external_storage, WithSerializationContext): |
| 1433 | + external_storage = external_storage.with_context(context) |
1384 | 1434 | if all( |
1385 | 1435 | new is orig |
1386 | 1436 | for new, orig in [ |
1387 | 1437 | (payload_converter, self.payload_converter), |
1388 | 1438 | (payload_codec, self.payload_codec), |
1389 | 1439 | (failure_converter, self.failure_converter), |
| 1440 | + (external_storage, self.external_storage), |
1390 | 1441 | ] |
1391 | 1442 | ): |
1392 | 1443 | return self |
1393 | 1444 | cloned = dataclasses.replace(self) |
1394 | 1445 | object.__setattr__(cloned, "payload_converter", payload_converter) |
1395 | 1446 | object.__setattr__(cloned, "payload_codec", payload_codec) |
1396 | 1447 | object.__setattr__(cloned, "failure_converter", failure_converter) |
| 1448 | + object.__setattr__(cloned, "external_storage", external_storage) |
| 1449 | + cloned._reset_external_storage_middleware(context) |
1397 | 1450 | return cloned |
1398 | 1451 |
|
| 1452 | + def _reset_external_storage_middleware( |
| 1453 | + self, context: SerializationContext | None = None |
| 1454 | + ) -> None: |
| 1455 | + # Lazy import to avoid circular dependency |
| 1456 | + from temporalio.extstore import _ExternalStorageMiddleware |
| 1457 | + |
| 1458 | + object.__setattr__( |
| 1459 | + self, |
| 1460 | + "_external_storage_middleware", |
| 1461 | + _ExternalStorageMiddleware(self.external_storage, context), |
| 1462 | + ) |
| 1463 | + |
1399 | 1464 | def _with_payload_error_limits( |
1400 | 1465 | self, limits: _ServerPayloadErrorLimits | None |
1401 | 1466 | ) -> DataConverter: |
@@ -1453,48 +1518,47 @@ async def _encode_memo_existing( |
1453 | 1518 | async def _encode_payload( |
1454 | 1519 | self, payload: temporalio.api.common.v1.Payload |
1455 | 1520 | ) -> temporalio.api.common.v1.Payload: |
| 1521 | + payload = await self._external_storage_middleware.store_payload(payload) |
1456 | 1522 | if self.payload_codec: |
1457 | 1523 | payload = (await self.payload_codec.encode([payload]))[0] |
1458 | 1524 | self._validate_payload_limits([payload]) |
1459 | 1525 | return payload |
1460 | 1526 |
|
1461 | 1527 | async def _encode_payloads(self, payloads: temporalio.api.common.v1.Payloads): |
1462 | | - if self.payload_codec: |
1463 | | - await self.payload_codec.encode_wrapper(payloads) |
1464 | | - self._validate_payload_limits(payloads.payloads) |
| 1528 | + encoded_payloads = await self._encode_payload_sequence(payloads.payloads) |
| 1529 | + del payloads.payloads[:] |
| 1530 | + payloads.payloads.extend(encoded_payloads) |
1465 | 1531 |
|
1466 | 1532 | async def _encode_payload_sequence( |
1467 | 1533 | self, payloads: Sequence[temporalio.api.common.v1.Payload] |
1468 | 1534 | ) -> list[temporalio.api.common.v1.Payload]: |
1469 | | - encoded_payloads = list(payloads) |
| 1535 | + result = await self._external_storage_middleware.store_payloads(payloads) |
1470 | 1536 | if self.payload_codec: |
1471 | | - encoded_payloads = await self.payload_codec.encode(encoded_payloads) |
1472 | | - self._validate_payload_limits(encoded_payloads) |
1473 | | - return encoded_payloads |
| 1537 | + result = await self.payload_codec.encode(result) |
| 1538 | + self._validate_payload_limits(result) |
| 1539 | + return result |
1474 | 1540 |
|
1475 | 1541 | async def _decode_payload( |
1476 | 1542 | self, payload: temporalio.api.common.v1.Payload |
1477 | 1543 | ) -> temporalio.api.common.v1.Payload: |
1478 | 1544 | if self.payload_codec: |
1479 | 1545 | payload = (await self.payload_codec.decode([payload]))[0] |
| 1546 | + payload = await self._external_storage_middleware.retrieve_payload(payload) |
1480 | 1547 | return payload |
1481 | 1548 |
|
1482 | 1549 | async def _decode_payloads(self, payloads: temporalio.api.common.v1.Payloads): |
1483 | | - if self.payload_codec: |
1484 | | - await self.payload_codec.decode_wrapper(payloads) |
| 1550 | + decoded_payloads = await self._decode_payload_sequence(payloads.payloads) |
| 1551 | + del payloads.payloads[:] |
| 1552 | + payloads.payloads.extend(decoded_payloads) |
1485 | 1553 |
|
1486 | 1554 | async def _decode_payload_sequence( |
1487 | 1555 | self, payloads: Sequence[temporalio.api.common.v1.Payload] |
1488 | 1556 | ) -> list[temporalio.api.common.v1.Payload]: |
1489 | | - if not self.payload_codec: |
1490 | | - return list(payloads) |
1491 | | - return await self.payload_codec.decode(payloads) |
1492 | | - |
1493 | | - # Temporary shortcircuit detection while the _decode_* methods may no-op if |
1494 | | - # a payload codec is not configured. Remove once those paths have more to them. |
1495 | | - @property |
1496 | | - def _decode_payload_has_effect(self) -> bool: |
1497 | | - return self.payload_codec is not None |
| 1557 | + result = list(payloads) |
| 1558 | + if self.payload_codec: |
| 1559 | + result = await self.payload_codec.decode(result) |
| 1560 | + result = await self._external_storage_middleware.retrieve_payloads(result) |
| 1561 | + return result |
1498 | 1562 |
|
1499 | 1563 | @staticmethod |
1500 | 1564 | async def _apply_to_failure_payloads( |
|
0 commit comments