diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index fa3e0bb4..1a0fdd7a 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -15,13 +15,20 @@ building durable orchestrations. The repo contains two packages: - Follow PEP 8 conventions. - Use `autopep8` for Python formatting. +## Python Type Checking + +Before linting, check for and fix any Pylance errors in the files you +changed. Use the editor's diagnostics (or the `get_errors` tool) to +identify type errors and resolve them first — type safety takes +priority over style. + ## Python Linting This repository uses [flake8](https://flake8.pycqa.org/) for Python linting. Run it after making changes to verify there are no issues: ```bash -flake8 path/to/changed/file.py +python -m flake8 path/to/changed/file.py ``` ## Markdown Style @@ -57,19 +64,19 @@ repository root. To lint a single file: ```bash -pymarkdown -c .pymarkdown.json scan path/to/file.md +python -m pymarkdown -c .pymarkdown.json scan path/to/file.md ``` To lint all Markdown files in the repository: ```bash -pymarkdown -c .pymarkdown.json scan **/*.md +python -m pymarkdown -c .pymarkdown.json scan **/*.md ``` Install the linter via the dev dependencies: ```bash -pip install -r dev-requirements.txt +python -m pip install -r dev-requirements.txt ``` ## Building and Testing @@ -77,18 +84,25 @@ pip install -r dev-requirements.txt Install the packages locally in editable mode: ```bash -pip install -e . -e ./durabletask-azuremanaged +python -m pip install -e . 
-e ./durabletask-azuremanaged ``` Run tests with pytest: ```bash -pytest +python -m pytest ``` ## Project Structure - `durabletask/` — core SDK source + - `payload/` — public payload externalization API (`PayloadStore` ABC, + `LargePayloadStorageOptions`, helper functions) + - `extensions/azure_blob_payloads/` — Azure Blob Storage payload store + (installed via `pip install durabletask[azure-blob-payloads]`) + - `entities/` — durable entity support + - `testing/` — in-memory backend for testing without a sidecar + - `internal/` — protobuf definitions, gRPC helpers, tracing (not public API) - `durabletask-azuremanaged/` — Azure managed provider source - `examples/` — example orchestrations (see `examples/README.md`) - `tests/` — test suite diff --git a/.github/workflows/durabletask.yml b/.github/workflows/durabletask.yml index 5acd9859..1fd072a0 100644 --- a/.github/workflows/durabletask.yml +++ b/.github/workflows/durabletask.yml @@ -47,16 +47,40 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} + + - name: Set up Node.js (needed for Azurite) + uses: actions/setup-node@v4 + with: + node-version: '20.x' + + - name: Cache npm + uses: actions/cache@v3 + with: + path: ~/.npm + key: ${{ runner.os }}-npm-azurite + + - name: Install Azurite + run: npm install -g azurite + + - name: Start Azurite + shell: bash + run: | + azurite-blob --silent --blobPort 10000 & + sleep 2 + - name: Install durabletask dependencies and the library itself run: | python -m pip install --upgrade pip pip install flake8 pytest pip install -r requirements.txt - pip install . 
+ pip install ".[azure-blob-payloads]" + pip install aiohttp + - name: Pytest unit tests working-directory: tests/durabletask run: | diff --git a/CHANGELOG.md b/CHANGELOG.md index 38c8d2a0..d715e5b4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ADDED +- Added large payload externalization support for automatically + offloading oversized orchestration payloads to Azure Blob Storage. + Install with `pip install durabletask[azure-blob-payloads]`. + Pass a `BlobPayloadStore` to the worker and client via the + `payload_store` parameter. +- Added `durabletask.extensions.azure_blob_payloads` extension + package with `BlobPayloadStore` and `BlobPayloadStoreOptions` +- Added `PayloadStore` abstract base class in + `durabletask.payload` for custom storage backends - Added `durabletask.testing` module with `InMemoryOrchestrationBackend` for testing orchestrations without a sidecar process - Added `AsyncTaskHubGrpcClient` for asyncio-based applications using `grpc.aio` - Added `DefaultAsyncClientInterceptorImpl` for async gRPC metadata interceptors diff --git a/README.md b/README.md index 49d6e0d0..48a2a467 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,20 @@ This repo contains a Python SDK for use with the [Azure Durable Task Scheduler]( - [Development Guide](./docs/development.md) - [Contributing Guide](./CONTRIBUTING.md) +## Optional Features + +### Large Payload Externalization + +Install the `azure-blob-payloads` extra to automatically offload +oversized orchestration payloads to Azure Blob Storage: + +```bash +pip install durabletask[azure-blob-payloads] +``` + +See the [feature documentation](./docs/features.md#large-payload-externalization) +and the [example](./examples/large_payload/) for usage details. + ## Trademarks This project may contain trademarks or logos for projects, products, or services. 
Authorized use of Microsoft trademarks or logos is subject to and must follow diff --git a/docs/features.md b/docs/features.md index 0ccac74b..85db052d 100644 --- a/docs/features.md +++ b/docs/features.md @@ -150,6 +150,151 @@ Orchestrations can be suspended using the `suspend_orchestration` client API and Orchestrations can specify retry policies for activities and sub-orchestrations. These policies control how many times and how frequently an activity or sub-orchestration will be retried in the event of a transient error. +### Large payload externalization + +Orchestration inputs, outputs, and event data are transmitted through +gRPC messages. When these payloads become very large they can exceed +gRPC message size limits or degrade performance. Large payload +externalization solves this by transparently offloading oversized +payloads to an external store (such as Azure Blob Storage) and +replacing them with compact reference tokens in the gRPC messages. + +This feature is **opt-in** and requires installing an optional +dependency: + +```bash +pip install durabletask[azure-blob-payloads] +``` + +#### How it works + +1. When the worker or client sends a payload that exceeds the + configured threshold (default 900 KB), the payload is + compressed (GZip, enabled by default) and uploaded to the + external store. +2. The original payload in the gRPC message is replaced with a + compact token (e.g. `blob:v1:<container>:<blob-name>`). +3. When the worker or client receives a message containing a token, + it downloads and decompresses the original payload automatically. + +This process is fully transparent to orchestrator and activity code — +no changes are needed in your workflow logic. + +#### Configuring the blob payload store + +The built-in `BlobPayloadStore` uses Azure Blob Storage.
Create a +store instance and pass it to both the worker and client: + +```python +from durabletask.extensions.azure_blob_payloads import BlobPayloadStore, BlobPayloadStoreOptions + +store = BlobPayloadStore(BlobPayloadStoreOptions( + connection_string="DefaultEndpointsProtocol=https;...", + container_name="durabletask-payloads", # default + threshold_bytes=900_000, # default (900 KB) + max_stored_payload_bytes=10_485_760, # default (10 MB) + enable_compression=True, # default +)) +``` + +Then pass the store to the worker and client: + +```python +with DurableTaskSchedulerWorker( + host_address=endpoint, + secure_channel=secure_channel, + taskhub=taskhub_name, + token_credential=credential, + payload_store=store, +) as w: + # ... register orchestrators and activities ... + w.start() + + c = DurableTaskSchedulerClient( + host_address=endpoint, + secure_channel=secure_channel, + taskhub=taskhub_name, + token_credential=credential, + payload_store=store, + ) +``` + +You can also authenticate using `account_url` and a +`TokenCredential` instead of a connection string: + +```python +from azure.identity import DefaultAzureCredential + +store = BlobPayloadStore(BlobPayloadStoreOptions( + account_url="https://<account-name>.blob.core.windows.net", + credential=DefaultAzureCredential(), +)) +``` + +#### Configuration options + +| Option | Default | Description | +|---|---|---| +| `threshold_bytes` | 900,000 (900 KB) | Payloads larger than this are externalized | +| `max_stored_payload_bytes` | 10,485,760 (10 MB) | Maximum size for externalized payloads | +| `enable_compression` | `True` | GZip-compress payloads before uploading | +| `container_name` | `"durabletask-payloads"` | Azure Blob container name | +| `connection_string` | `None` | Azure Storage connection string | +| `account_url` | `None` | Azure Storage account URL (use with `credential`) | +| `credential` | `None` | `TokenCredential` for token-based auth | + +#### Cross-SDK compatibility + +The blob token format (`blob:v1:<container>:<blob-name>`)
is +compatible with the .NET Durable Task SDK, enabling +interoperability between Python and .NET workers sharing the same +task hub and storage account. Note that message serialization strategies +may differ for complex objects and custom types. + +#### Custom payload stores + +You can implement a custom payload store by subclassing +`PayloadStore` from `durabletask.payload` and implementing +the `upload`, `upload_async`, `download`, `download_async`, and +`is_known_token` methods: + +```python +from typing import Optional + +from durabletask.payload import PayloadStore, LargePayloadStorageOptions + + +class MyPayloadStore(PayloadStore): + + def __init__(self, options: LargePayloadStorageOptions): + self._options = options + + @property + def options(self) -> LargePayloadStorageOptions: + return self._options + + def upload(self, data: bytes, *, instance_id: Optional[str] = None) -> str: + # Store data and return a unique token string + ... + + async def upload_async(self, data: bytes, *, instance_id: Optional[str] = None) -> str: + ... + + def download(self, token: str) -> bytes: + # Retrieve data by token + ... + + async def download_async(self, token: str) -> bytes: + ... + + def is_known_token(self, value: str) -> bool: + # Return True if the value looks like a token from this store + ... +``` + +See the [large payload example](../examples/large_payload/) for a +complete working sample. + ### Logging configuration Both the TaskHubGrpcWorker and TaskHubGrpcClient (as well as DurableTaskSchedulerWorker and DurableTaskSchedulerClient for durabletask-azuremanaged) accept a log_handler and log_formatter object from `logging`. These can be used to customize verbosity, output location, and format of logs emitted by these sources. 
@@ -164,5 +309,5 @@ with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=secure_cha taskhub=taskhub_name, token_credential=credential, log_handler=log_handler) as w: ``` -**NOTE** -The worker and client output many logs at the `DEBUG` level that will be useful when understanding orchestration flow and diagnosing issues with Durable applications. Before submitting issues, please attempt a repro of the issue with debug logging enabled. +> [!NOTE] +> The worker and client output many logs at the `DEBUG` level that will be useful when understanding orchestration flow and diagnosing issues with Durable applications. Before submitting issues, please attempt a repro of the issue with debug logging enabled. diff --git a/docs/supported-patterns.md b/docs/supported-patterns.md index 612678a1..5f6f0738 100644 --- a/docs/supported-patterns.md +++ b/docs/supported-patterns.md @@ -118,4 +118,57 @@ def my_orchestrator(ctx: task.OrchestrationContext, order: Order): return "Success" ``` -See the full [version-aware orchestrator sample](../examples/version_aware_orchestrator.py) \ No newline at end of file +See the full [version-aware orchestrator sample](../examples/version_aware_orchestrator.py) + +### Large payload externalization + +When orchestrations work with very large inputs, outputs, or event +data, the payloads can exceed gRPC message size limits. The large +payload externalization pattern transparently offloads these payloads +to Azure Blob Storage and replaces them with compact reference tokens +in the gRPC messages. + +No changes are required in orchestrator or activity code. 
Simply +install the optional dependency and configure a payload store on the +worker and client: + +```python +from durabletask.extensions.azure_blob_payloads import BlobPayloadStore, BlobPayloadStoreOptions +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + +# Configure the blob payload store +store = BlobPayloadStore(BlobPayloadStoreOptions( + connection_string="DefaultEndpointsProtocol=https;...", +)) + +# Pass the store to both worker and client +with DurableTaskSchedulerWorker( + host_address=endpoint, secure_channel=secure_channel, + taskhub=taskhub_name, token_credential=credential, + payload_store=store, +) as w: + w.add_orchestrator(my_orchestrator) + w.add_activity(process_large_data) + w.start() + + c = DurableTaskSchedulerClient( + host_address=endpoint, secure_channel=secure_channel, + taskhub=taskhub_name, token_credential=credential, + payload_store=store, + ) + + # This large input is automatically externalized to blob storage + large_input = "x" * 1_000_000 # 1 MB string + instance_id = c.schedule_new_orchestration(my_orchestrator, input=large_input) + state = c.wait_for_orchestration_completion(instance_id, timeout=60) +``` + +In this example, any payload exceeding the threshold (default 900 KB) +is compressed and uploaded to the configured Azure Blob container. +When the worker or client reads the message, it downloads and +decompresses the payload automatically. + +See the full [large payload example](../examples/large_payload/) and +[feature documentation](./features.md#large-payload-externalization) +for configuration options and details. 
diff --git a/durabletask-azuremanaged/CHANGELOG.md b/durabletask-azuremanaged/CHANGELOG.md index 84e6dbba..2d8774c6 100644 --- a/durabletask-azuremanaged/CHANGELOG.md +++ b/durabletask-azuremanaged/CHANGELOG.md @@ -9,6 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `AsyncDurableTaskSchedulerClient` for async/await usage with `grpc.aio` - Added `DTSAsyncDefaultClientInterceptorImpl` async gRPC interceptor for DTS authentication +- Added `payload_store` parameter to `DurableTaskSchedulerWorker`, + `DurableTaskSchedulerClient`, and `AsyncDurableTaskSchedulerClient` + for large payload externalization support +- Added `azure-blob-payloads` optional dependency that installs + `durabletask[azure-blob-payloads]` — install with + `pip install durabletask-azuremanaged[azure-blob-payloads]` ## v1.3.0 diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/client.py b/durabletask-azuremanaged/durabletask/azuremanaged/client.py index 86878dd6..71d6581e 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/client.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/client.py @@ -13,6 +13,7 @@ DTSDefaultClientInterceptorImpl, ) from durabletask.client import AsyncTaskHubGrpcClient, TaskHubGrpcClient +from durabletask.payload.store import PayloadStore # Client class used for Durable Task Scheduler (DTS) @@ -23,6 +24,7 @@ def __init__(self, *, token_credential: Optional[TokenCredential], secure_channel: bool = True, default_version: Optional[str] = None, + payload_store: Optional[PayloadStore] = None, log_handler: Optional[logging.Handler] = None, log_formatter: Optional[logging.Formatter] = None): @@ -40,7 +42,8 @@ def __init__(self, *, log_handler=log_handler, log_formatter=log_formatter, interceptors=interceptors, - default_version=default_version) + default_version=default_version, + payload_store=payload_store) # Async client class used for Durable Task Scheduler (DTS) @@ -60,6 +63,8 @@ class
AsyncDurableTaskSchedulerClient(AsyncTaskHubGrpcClient): secure_channel (bool, optional): Whether to use a secure gRPC channel (TLS). Defaults to True. default_version (Optional[str], optional): Default version string for orchestrations. + payload_store (Optional[PayloadStore], optional): A payload store for + externalizing large payloads. If None, payloads are sent inline. log_handler (Optional[logging.Handler], optional): Custom logging handler for client logs. log_formatter (Optional[logging.Formatter], optional): Custom log formatter for client logs. @@ -85,6 +90,7 @@ def __init__(self, *, token_credential: Optional[AsyncTokenCredential], secure_channel: bool = True, default_version: Optional[str] = None, + payload_store: Optional[PayloadStore] = None, log_handler: Optional[logging.Handler] = None, log_formatter: Optional[logging.Formatter] = None): @@ -102,4 +108,5 @@ def __init__(self, *, log_handler=log_handler, log_formatter=log_formatter, interceptors=interceptors, - default_version=default_version) + default_version=default_version, + payload_store=payload_store) diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py index 48162234..ffe81d44 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py @@ -9,6 +9,7 @@ from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import \ DTSDefaultClientInterceptorImpl +from durabletask.payload.store import PayloadStore from durabletask.worker import ConcurrencyOptions, TaskHubGrpcWorker @@ -30,6 +31,8 @@ class DurableTaskSchedulerWorker(TaskHubGrpcWorker): concurrency_options (Optional[ConcurrencyOptions], optional): Configuration for controlling worker concurrency limits. If None, default concurrency settings will be used. + payload_store (Optional[PayloadStore], optional): A payload store for + externalizing large payloads. 
If None, payloads are sent inline. log_handler (Optional[logging.Handler], optional): Custom logging handler for worker logs. log_formatter (Optional[logging.Formatter], optional): Custom log formatter for worker logs. @@ -63,6 +66,7 @@ def __init__(self, *, token_credential: Optional[TokenCredential], secure_channel: bool = True, concurrency_options: Optional[ConcurrencyOptions] = None, + payload_store: Optional[PayloadStore] = None, log_handler: Optional[logging.Handler] = None, log_formatter: Optional[logging.Formatter] = None): @@ -80,4 +84,5 @@ def __init__(self, *, log_handler=log_handler, log_formatter=log_formatter, interceptors=interceptors, - concurrency_options=concurrency_options) + concurrency_options=concurrency_options, + payload_store=payload_store) diff --git a/durabletask-azuremanaged/pyproject.toml b/durabletask-azuremanaged/pyproject.toml index 7a084ebc..54407d95 100644 --- a/durabletask-azuremanaged/pyproject.toml +++ b/durabletask-azuremanaged/pyproject.toml @@ -30,6 +30,11 @@ dependencies = [ "azure-identity>=1.19.0" ] +[project.optional-dependencies] +azure-blob-payloads = [ + "durabletask[azure-blob-payloads]>=1.3.0" +] + [project.urls] repository = "https://github.com/microsoft/durabletask-python" changelog = "https://github.com/microsoft/durabletask-python/blob/main/CHANGELOG.md" diff --git a/durabletask/__init__.py b/durabletask/__init__.py index e0e73d30..a8ed8b55 100644 --- a/durabletask/__init__.py +++ b/durabletask/__init__.py @@ -3,8 +3,14 @@ """Durable Task SDK for Python""" +from durabletask.payload.store import LargePayloadStorageOptions, PayloadStore from durabletask.worker import ConcurrencyOptions, VersioningOptions -__all__ = ["ConcurrencyOptions", "VersioningOptions"] +__all__ = [ + "ConcurrencyOptions", + "LargePayloadStorageOptions", + "PayloadStore", + "VersioningOptions", +] PACKAGE_NAME = "durabletask" diff --git a/durabletask/client.py b/durabletask/client.py index aa8ab55e..a73fc343 100644 --- a/durabletask/client.py 
+++ b/durabletask/client.py @@ -32,6 +32,8 @@ prepare_async_interceptors, prepare_sync_interceptors, ) +from durabletask.payload import helpers as payload_helpers +from durabletask.payload.store import PayloadStore TInput = TypeVar('TInput') TOutput = TypeVar('TOutput') @@ -152,7 +154,8 @@ def __init__(self, *, log_formatter: Optional[logging.Formatter] = None, secure_channel: bool = False, interceptors: Optional[Sequence[shared.ClientInterceptor]] = None, - default_version: Optional[str] = None): + default_version: Optional[str] = None, + payload_store: Optional[PayloadStore] = None): interceptors = prepare_sync_interceptors(metadata, interceptors) @@ -165,6 +168,7 @@ def __init__(self, *, self._stub = stubs.TaskHubSidecarServiceStub(channel) self._logger = shared.get_logger("client", log_handler, log_formatter) self.default_version = default_version + self._payload_store = payload_store def close(self) -> None: """Closes the underlying gRPC channel.""" @@ -198,12 +202,20 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu req.parentTraceContext.CopyFrom(parent_trace_ctx) self._logger.info(f"Starting new '{req.name}' instance with ID = '{req.instanceId}'.") + # Externalize any large payloads in the request + if self._payload_store is not None: + payload_helpers.externalize_payloads( + req, self._payload_store, instance_id=req.instanceId, + ) res: pb.CreateInstanceResponse = self._stub.StartInstance(req) return res.instanceId def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Optional[OrchestrationState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) res: pb.GetInstanceResponse = self._stub.GetInstance(req) + # De-externalize any large-payload tokens in the response + if self._payload_store is not None and res.exists: + payload_helpers.deexternalize_payloads(res, self._payload_store) return new_orchestration_state(req.instanceId, res) def 
get_all_orchestration_states(self, @@ -220,6 +232,8 @@ def get_all_orchestration_states(self, while True: req = build_query_instances_req(orchestration_query, _continuation_token) resp: pb.QueryInstancesResponse = self._stub.QueryInstances(req) + if self._payload_store is not None: + payload_helpers.deexternalize_payloads(resp, self._payload_store) states += [parse_orchestration_state(res) for res in resp.orchestrationState] if check_continuation_token(resp.continuationToken, _continuation_token, self._logger): _continuation_token = resp.continuationToken @@ -235,6 +249,8 @@ def wait_for_orchestration_start(self, instance_id: str, *, try: self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to start.") res: pb.GetInstanceResponse = self._stub.WaitForInstanceStart(req, timeout=timeout) + if self._payload_store is not None and res.exists: + payload_helpers.deexternalize_payloads(res, self._payload_store) return new_orchestration_state(req.instanceId, res) except grpc.RpcError as rpc_error: if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore @@ -250,6 +266,8 @@ def wait_for_orchestration_completion(self, instance_id: str, *, try: self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.") res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion(req, timeout=timeout) + if self._payload_store is not None and res.exists: + payload_helpers.deexternalize_payloads(res, self._payload_store) state = new_orchestration_state(req.instanceId, res) log_completion_state(self._logger, instance_id, state) return state @@ -264,6 +282,10 @@ def raise_orchestration_event(self, instance_id: str, event_name: str, *, with tracing.start_raise_event_span(event_name, instance_id): req = build_raise_event_req(instance_id, event_name, data) self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.") + if self._payload_store is not None: + payload_helpers.externalize_payloads( + req, 
self._payload_store, instance_id=instance_id, + ) self._stub.RaiseEvent(req) def terminate_orchestration(self, instance_id: str, *, @@ -272,6 +294,10 @@ def terminate_orchestration(self, instance_id: str, *, req = build_terminate_req(instance_id, output, recursive) self._logger.info(f"Terminating instance '{instance_id}'.") + if self._payload_store is not None: + payload_helpers.externalize_payloads( + req, self._payload_store, instance_id=instance_id, + ) self._stub.TerminateInstance(req) def suspend_orchestration(self, instance_id: str) -> None: @@ -330,6 +356,10 @@ def signal_entity(self, input: Optional[Any] = None) -> None: req = build_signal_entity_req(entity_instance_id, operation_name, input) self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.") + if self._payload_store is not None: + payload_helpers.externalize_payloads( + req, self._payload_store, instance_id=str(entity_instance_id), + ) self._stub.SignalEntity(req, None) # TODO: Cancellation timeout? 
def get_entity(self, @@ -341,7 +371,8 @@ def get_entity(self, res: pb.GetEntityResponse = self._stub.GetEntity(req) if not res.exists: return None - + if self._payload_store is not None: + payload_helpers.deexternalize_payloads(res, self._payload_store) return EntityMetadata.from_entity_metadata(res.entity, include_state) def get_all_entities(self, @@ -357,6 +388,8 @@ def get_all_entities(self, while True: query_request = build_query_entities_req(entity_query, _continuation_token) resp: pb.QueryEntitiesResponse = self._stub.QueryEntities(query_request) + if self._payload_store is not None: + payload_helpers.deexternalize_payloads(resp, self._payload_store) entities += [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState) for entity in resp.entities] if check_continuation_token(resp.continuationToken, _continuation_token, self._logger): _continuation_token = resp.continuationToken @@ -402,7 +435,8 @@ def __init__(self, *, log_formatter: Optional[logging.Formatter] = None, secure_channel: bool = False, interceptors: Optional[Sequence[shared.AsyncClientInterceptor]] = None, - default_version: Optional[str] = None): + default_version: Optional[str] = None, + payload_store: Optional[PayloadStore] = None): interceptors = prepare_async_interceptors(metadata, interceptors) @@ -415,6 +449,7 @@ def __init__(self, *, self._stub = stubs.TaskHubSidecarServiceStub(channel) self._logger = shared.get_logger("async_client", log_handler, log_formatter) self.default_version = default_version + self._payload_store = payload_store async def close(self) -> None: """Closes the underlying gRPC channel.""" @@ -451,6 +486,11 @@ async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator req.parentTraceContext.CopyFrom(parent_trace_ctx) self._logger.info(f"Starting new '{req.name}' instance with ID = '{req.instanceId}'.") + # Externalize any large payloads in the request + if self._payload_store is not None: + await 
payload_helpers.externalize_payloads_async( + req, self._payload_store, instance_id=req.instanceId, + ) res: pb.CreateInstanceResponse = await self._stub.StartInstance(req) return res.instanceId @@ -458,6 +498,8 @@ async def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Optional[OrchestrationState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) res: pb.GetInstanceResponse = await self._stub.GetInstance(req) + if self._payload_store is not None and res.exists: + await payload_helpers.deexternalize_payloads_async(res, self._payload_store) return new_orchestration_state(req.instanceId, res) async def get_all_orchestration_states(self, @@ -474,6 +516,8 @@ async def get_all_orchestration_states(self, while True: req = build_query_instances_req(orchestration_query, _continuation_token) resp: pb.QueryInstancesResponse = await self._stub.QueryInstances(req) + if self._payload_store is not None: + await payload_helpers.deexternalize_payloads_async(resp, self._payload_store) states += [parse_orchestration_state(res) for res in resp.orchestrationState] if check_continuation_token(resp.continuationToken, _continuation_token, self._logger): _continuation_token = resp.continuationToken @@ -489,6 +533,8 @@ async def wait_for_orchestration_start(self, instance_id: str, *, try: self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to start.") res: pb.GetInstanceResponse = await self._stub.WaitForInstanceStart(req, timeout=timeout) + if self._payload_store is not None and res.exists: + await payload_helpers.deexternalize_payloads_async(res, self._payload_store) return new_orchestration_state(req.instanceId, res) except grpc.aio.AioRpcError as rpc_error: if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: @@ -503,6 +549,8 @@ async def wait_for_orchestration_completion(self, instance_id: str, *, try: self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.") 
res: pb.GetInstanceResponse = await self._stub.WaitForInstanceCompletion(req, timeout=timeout) + if self._payload_store is not None and res.exists: + await payload_helpers.deexternalize_payloads_async(res, self._payload_store) state = new_orchestration_state(req.instanceId, res) log_completion_state(self._logger, instance_id, state) return state @@ -517,6 +565,10 @@ async def raise_orchestration_event(self, instance_id: str, event_name: str, *, with tracing.start_raise_event_span(event_name, instance_id): req = build_raise_event_req(instance_id, event_name, data) self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.") + if self._payload_store is not None: + await payload_helpers.externalize_payloads_async( + req, self._payload_store, instance_id=instance_id, + ) await self._stub.RaiseEvent(req) async def terminate_orchestration(self, instance_id: str, *, @@ -525,6 +577,10 @@ async def terminate_orchestration(self, instance_id: str, *, req = build_terminate_req(instance_id, output, recursive) self._logger.info(f"Terminating instance '{instance_id}'.") + if self._payload_store is not None: + await payload_helpers.externalize_payloads_async( + req, self._payload_store, instance_id=instance_id, + ) await self._stub.TerminateInstance(req) async def suspend_orchestration(self, instance_id: str) -> None: @@ -583,6 +639,10 @@ async def signal_entity(self, input: Optional[Any] = None) -> None: req = build_signal_entity_req(entity_instance_id, operation_name, input) self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.") + if self._payload_store is not None: + await payload_helpers.externalize_payloads_async( + req, self._payload_store, instance_id=str(entity_instance_id), + ) await self._stub.SignalEntity(req, None) async def get_entity(self, @@ -594,7 +654,8 @@ async def get_entity(self, res: pb.GetEntityResponse = await self._stub.GetEntity(req) if not res.exists: return None - + if self._payload_store is not 
None: + await payload_helpers.deexternalize_payloads_async(res, self._payload_store) return EntityMetadata.from_entity_metadata(res.entity, include_state) async def get_all_entities(self, @@ -610,6 +671,8 @@ async def get_all_entities(self, while True: query_request = build_query_entities_req(entity_query, _continuation_token) resp: pb.QueryEntitiesResponse = await self._stub.QueryEntities(query_request) + if self._payload_store is not None: + await payload_helpers.deexternalize_payloads_async(resp, self._payload_store) entities += [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState) for entity in resp.entities] if check_continuation_token(resp.continuationToken, _continuation_token, self._logger): _continuation_token = resp.continuationToken diff --git a/durabletask/extensions/__init__.py b/durabletask/extensions/__init__.py new file mode 100644 index 00000000..9c373d4c --- /dev/null +++ b/durabletask/extensions/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Durable Task SDK extension packages.""" diff --git a/durabletask/extensions/azure_blob_payloads/__init__.py b/durabletask/extensions/azure_blob_payloads/__init__.py new file mode 100644 index 00000000..83ff0ddf --- /dev/null +++ b/durabletask/extensions/azure_blob_payloads/__init__.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Azure Blob Storage payload externalization for Durable Task. + +This optional extension package provides a :class:`BlobPayloadStore` +that stores large orchestration / activity payloads in Azure Blob +Storage, keeping gRPC message sizes within safe limits. 
+ +Install the required dependency with:: + + pip install durabletask[azure-blob-payloads] + +Usage:: + + from durabletask.extensions.azure_blob_payloads import ( + BlobPayloadStore, + BlobPayloadStoreOptions, + ) + + store = BlobPayloadStore(BlobPayloadStoreOptions( + connection_string="DefaultEndpointsProtocol=https;...", + )) + worker = TaskHubGrpcWorker(payload_store=store) +""" + +try: + from azure.storage.blob import BlobServiceClient # noqa: F401 +except ImportError as exc: + raise ImportError( + "The 'azure-storage-blob' package is required for blob payload " + "support. Install it with: pip install durabletask[azure-blob-payloads]" + ) from exc + +from durabletask.extensions.azure_blob_payloads.blob_payload_store import BlobPayloadStore +from durabletask.extensions.azure_blob_payloads.options import BlobPayloadStoreOptions + +__all__ = ["BlobPayloadStore", "BlobPayloadStoreOptions"] diff --git a/durabletask/extensions/azure_blob_payloads/blob_payload_store.py b/durabletask/extensions/azure_blob_payloads/blob_payload_store.py new file mode 100644 index 00000000..2673b155 --- /dev/null +++ b/durabletask/extensions/azure_blob_payloads/blob_payload_store.py @@ -0,0 +1,225 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Azure Blob Storage implementation of :class:`PayloadStore`.""" + +from __future__ import annotations + +import gzip +import logging +import uuid +from typing import Optional + +from azure.core.exceptions import ResourceExistsError +from azure.storage.blob import BlobServiceClient +from azure.storage.blob.aio import BlobServiceClient as AsyncBlobServiceClient + +from durabletask.extensions.azure_blob_payloads.options import BlobPayloadStoreOptions +from durabletask.payload.store import PayloadStore + +logger = logging.getLogger("durabletask-blobpayloads") + +# Token format matching the .NET SDK: blob:v1:: +_TOKEN_PREFIX = "blob:v1:" + + +class BlobPayloadStore(PayloadStore): + """Stores and retrieves large payloads in Azure Blob Storage. + + This implementation is compatible with the .NET SDK's + ``AzureBlobPayloadsSideCarInterceptor`` – both SDKs use the same + token format (``blob:v1::``) and the same + storage layout, allowing cross-language interoperability. + + Example:: + + store = BlobPayloadStore(BlobPayloadStoreOptions( + connection_string="...", + )) + + Args: + options: A :class:`BlobPayloadStoreOptions` with all settings. + """ + + def __init__(self, options: BlobPayloadStoreOptions): + if not options.connection_string and not options.account_url: + raise ValueError( + "Either 'connection_string' or 'account_url' (with 'credential') must be provided." + ) + + self._options = options + self._container_name = options.container_name + + # Optional kwargs shared by both sync and async clients. 
+ extra_kwargs: dict = {} + if options.api_version: + extra_kwargs["api_version"] = options.api_version + + # Build sync client + if options.connection_string: + self._blob_service_client = BlobServiceClient.from_connection_string( + options.connection_string, **extra_kwargs, + ) + else: + assert options.account_url is not None # guaranteed by validation above + self._blob_service_client = BlobServiceClient( + account_url=options.account_url, + credential=options.credential, + **extra_kwargs, + ) + + # Build async client + if options.connection_string: + self._async_blob_service_client = AsyncBlobServiceClient.from_connection_string( + options.connection_string, **extra_kwargs, + ) + else: + assert options.account_url is not None # guaranteed by validation above + self._async_blob_service_client = AsyncBlobServiceClient( + account_url=options.account_url, + credential=options.credential, + **extra_kwargs, + ) + + self._ensure_container_created = False + + # ------------------------------------------------------------------ + # Lifecycle / resource management + # ------------------------------------------------------------------ + + def close(self) -> None: + """Close the underlying sync blob service client.""" + self._blob_service_client.close() + + async def close_async(self) -> None: + """Close the underlying async blob service client.""" + await self._async_blob_service_client.close() + + def __enter__(self) -> BlobPayloadStore: + return self + + def __exit__(self, *args: object) -> None: + self.close() + + async def __aenter__(self) -> BlobPayloadStore: + return self + + async def __aexit__(self, *args: object) -> None: + await self.close_async() + + @property + def options(self) -> BlobPayloadStoreOptions: + return self._options + + # ------------------------------------------------------------------ + # Sync operations + # ------------------------------------------------------------------ + + def upload(self, data: bytes, *, instance_id: Optional[str] = None) 
-> str: + self._ensure_container_sync() + + if self._options.enable_compression: + data = gzip.compress(data) + + blob_name = self._make_blob_name(instance_id) + container_client = self._blob_service_client.get_container_client(self._container_name) + container_client.upload_blob(name=blob_name, data=data, overwrite=True) + + token = f"{_TOKEN_PREFIX}{self._container_name}:{blob_name}" + logger.debug("Uploaded %d bytes -> %s", len(data), token) + return token + + def download(self, token: str) -> bytes: + container, blob_name = self._parse_token(token) + container_client = self._blob_service_client.get_container_client(container) + blob_data = container_client.download_blob(blob_name).readall() + + if self._options.enable_compression: + blob_data = gzip.decompress(blob_data) + + logger.debug("Downloaded %d bytes <- %s", len(blob_data), token) + return blob_data + + # ------------------------------------------------------------------ + # Async operations + # ------------------------------------------------------------------ + + async def upload_async(self, data: bytes, *, instance_id: Optional[str] = None) -> str: + await self._ensure_container_async() + + if self._options.enable_compression: + data = gzip.compress(data) + + blob_name = self._make_blob_name(instance_id) + container_client = self._async_blob_service_client.get_container_client(self._container_name) + await container_client.upload_blob(name=blob_name, data=data, overwrite=True) + + token = f"{_TOKEN_PREFIX}{self._container_name}:{blob_name}" + logger.debug("Uploaded %d bytes -> %s", len(data), token) + return token + + async def download_async(self, token: str) -> bytes: + container, blob_name = self._parse_token(token) + container_client = self._async_blob_service_client.get_container_client(container) + stream = await container_client.download_blob(blob_name) + blob_data = await stream.readall() + + if self._options.enable_compression: + blob_data = gzip.decompress(blob_data) + + 
logger.debug("Downloaded %d bytes <- %s", len(blob_data), token) + return blob_data + + # ------------------------------------------------------------------ + # Token helpers + # ------------------------------------------------------------------ + + def is_known_token(self, value: str) -> bool: + try: + self._parse_token(value) + return True + except ValueError: + return False + + @staticmethod + def _parse_token(token: str) -> tuple[str, str]: + """Parse ``blob:v1::`` into (container, blobName).""" + if not token.startswith(_TOKEN_PREFIX): + raise ValueError(f"Invalid blob payload token: {token!r}") + rest = token[len(_TOKEN_PREFIX):] + parts = rest.split(":", 1) + if len(parts) != 2 or not parts[0] or not parts[1]: + raise ValueError(f"Invalid blob payload token: {token!r}") + return parts[0], parts[1] + + @staticmethod + def _make_blob_name(instance_id: Optional[str] = None) -> str: + """Generate a blob name, optionally scoped under an instance ID folder.""" + unique = uuid.uuid4().hex + if instance_id: + return f"{instance_id}/{unique}" + return unique + + # ------------------------------------------------------------------ + # Container lifecycle + # ------------------------------------------------------------------ + + def _ensure_container_sync(self) -> None: + if self._ensure_container_created: + return + container_client = self._blob_service_client.get_container_client(self._container_name) + try: + container_client.create_container() + except ResourceExistsError: + pass + self._ensure_container_created = True + + async def _ensure_container_async(self) -> None: + if self._ensure_container_created: + return + container_client = self._async_blob_service_client.get_container_client(self._container_name) + try: + await container_client.create_container() + except ResourceExistsError: + pass + self._ensure_container_created = True diff --git a/durabletask/extensions/azure_blob_payloads/options.py b/durabletask/extensions/azure_blob_payloads/options.py new file 
mode 100644 index 00000000..1c551790 --- /dev/null +++ b/durabletask/extensions/azure_blob_payloads/options.py @@ -0,0 +1,40 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Configuration options for the Azure Blob payload store.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Optional + +from durabletask.payload.store import LargePayloadStorageOptions + + +@dataclass +class BlobPayloadStoreOptions(LargePayloadStorageOptions): + """Configuration specific to the Azure Blob payload store. + + Inherits general threshold / compression settings from + :class:`~durabletask.payload.store.LargePayloadStorageOptions` + and adds Azure Blob-specific fields. + + Attributes: + container_name: Azure Blob container used to store externalized + payloads. Defaults to ``"durabletask-payloads"``. + connection_string: Azure Storage connection string. Mutually + exclusive with *account_url*. + account_url: Azure Storage account URL (e.g. + ``"https://.blob.core.windows.net"``). Use + together with *credential* for token-based auth. + credential: A ``TokenCredential`` instance (e.g. + ``DefaultAzureCredential``) for authenticating to the + storage account when using *account_url*. + api_version: Azure Storage API version override (useful for + Azurite compatibility). + """ + container_name: str = "durabletask-payloads" + connection_string: Optional[str] = None + account_url: Optional[str] = None + credential: Any = field(default=None, repr=False) + api_version: Optional[str] = None diff --git a/durabletask/payload/__init__.py b/durabletask/payload/__init__.py new file mode 100644 index 00000000..95bb32e1 --- /dev/null +++ b/durabletask/payload/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Public payload externalization API for the Durable Task SDK. 
+ +This package exposes the abstract :class:`PayloadStore` interface, +configuration options, and helper functions for externalizing and +de-externalizing large payloads in protobuf messages. +""" + +from durabletask.payload.helpers import ( + deexternalize_payloads, + deexternalize_payloads_async, + externalize_payloads, + externalize_payloads_async, +) +from durabletask.payload.store import ( + LargePayloadStorageOptions, + PayloadStore, +) + +__all__ = [ + "LargePayloadStorageOptions", + "PayloadStore", + "deexternalize_payloads", + "deexternalize_payloads_async", + "externalize_payloads", + "externalize_payloads_async", +] diff --git a/durabletask/payload/helpers.py b/durabletask/payload/helpers.py new file mode 100644 index 00000000..1e8fd040 --- /dev/null +++ b/durabletask/payload/helpers.py @@ -0,0 +1,349 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Helpers for externalizing and de-externalizing large payloads in protobuf messages. + +These functions walk protobuf messages recursively, finding ``StringValue`` +fields whose content exceeds a configured threshold (externalize) or +matches a known payload-store token (de-externalize). The actual upload +/ download is delegated to a :class:`PayloadStore` instance. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Optional + +from google.protobuf import message as proto_message +from google.protobuf import wrappers_pb2 + +if TYPE_CHECKING: + from durabletask.payload.store import PayloadStore + +logger = logging.getLogger("durabletask-payloads") + + +# ------------------------------------------------------------------ +# Synchronous helpers +# ------------------------------------------------------------------ + + +def externalize_payloads( + msg: proto_message.Message, + store: PayloadStore, + *, + instance_id: Optional[str] = None, +) -> None: + """Walk *msg* in-place, uploading large ``StringValue`` fields to *store*. 
+ + Any ``StringValue`` whose UTF-8 byte length exceeds + ``store.options.threshold_bytes`` is uploaded and its value replaced + with the token returned by the store. + """ + threshold = store.options.threshold_bytes + max_bytes = store.options.max_stored_payload_bytes + _walk_and_externalize(msg, store, threshold, max_bytes, instance_id) + + +def deexternalize_payloads( + msg: proto_message.Message, + store: PayloadStore, +) -> None: + """Walk *msg* in-place, downloading any ``StringValue`` fields that + contain a known payload-store token and replacing them with the + original content.""" + _walk_and_deexternalize(msg, store) + + +# ------------------------------------------------------------------ +# Async helpers +# ------------------------------------------------------------------ + + +async def externalize_payloads_async( + msg: proto_message.Message, + store: PayloadStore, + *, + instance_id: Optional[str] = None, +) -> None: + """Async version of :func:`externalize_payloads`.""" + threshold = store.options.threshold_bytes + max_bytes = store.options.max_stored_payload_bytes + await _walk_and_externalize_async(msg, store, threshold, max_bytes, instance_id) + + +async def deexternalize_payloads_async( + msg: proto_message.Message, + store: PayloadStore, +) -> None: + """Async version of :func:`deexternalize_payloads`.""" + await _walk_and_deexternalize_async(msg, store) + + +# ------------------------------------------------------------------ +# Internal recursive walkers – sync +# ------------------------------------------------------------------ + +def _is_map_field(fd) -> bool: + """Return True if the field descriptor represents a protobuf map field.""" + mt = fd.message_type + return mt is not None and fd.is_repeated and mt.GetOptions().map_entry + + +def _walk_and_externalize( + msg: proto_message.Message, + store: PayloadStore, + threshold: int, + max_bytes: int, + instance_id: Optional[str], +) -> None: + for fd in msg.DESCRIPTOR.fields: + if 
fd.message_type is None: + continue + + if _is_map_field(fd): + # Map fields: iterate values. ScalarMap values are not + # messages and will be skipped by the isinstance check. + for map_value in getattr(msg, fd.name).values(): + if isinstance(map_value, proto_message.Message): + if isinstance(map_value, wrappers_pb2.StringValue): + _try_externalize_field( + fd.name, map_value, store, + threshold, max_bytes, instance_id, + ) + else: + _walk_and_externalize( + map_value, store, threshold, max_bytes, instance_id + ) + elif fd.is_repeated: + value = getattr(msg, fd.name) + for item in value: + if isinstance(item, proto_message.Message): + if isinstance(item, wrappers_pb2.StringValue): + _try_externalize_field( + fd.name, item, store, + threshold, max_bytes, instance_id, + ) + else: + _walk_and_externalize( + item, store, threshold, max_bytes, instance_id + ) + else: + # Singular message field — only recurse if actually set + if not msg.HasField(fd.name): + continue + value = getattr(msg, fd.name) + if isinstance(value, wrappers_pb2.StringValue): + _try_externalize_field( + fd.name, value, store, + threshold, max_bytes, instance_id, + ) + else: + _walk_and_externalize( + value, store, threshold, max_bytes, instance_id + ) + + +def _try_externalize_field( + field_name: str, + sv: wrappers_pb2.StringValue, + store: PayloadStore, + threshold: int, + max_bytes: int, + instance_id: Optional[str], +) -> None: + val = sv.value + if not val: + return + # Already a token – skip + if store.is_known_token(val): + return + payload_bytes = val.encode("utf-8") + if len(payload_bytes) <= threshold: + return + if len(payload_bytes) > max_bytes: + raise ValueError( + f"Payload size {len(payload_bytes)} bytes exceeds the maximum " + f"allowed size of {max_bytes} bytes." 
+ ) + token = store.upload(payload_bytes, instance_id=instance_id) + sv.value = token + logger.debug( + "Externalized %d-byte payload in field '%s' -> %s", + len(payload_bytes), field_name, token, + ) + + +def _walk_and_deexternalize( + msg: proto_message.Message, + store: PayloadStore, +) -> None: + for fd in msg.DESCRIPTOR.fields: + if fd.message_type is None: + continue + + if _is_map_field(fd): + for map_value in getattr(msg, fd.name).values(): + if isinstance(map_value, proto_message.Message): + if isinstance(map_value, wrappers_pb2.StringValue): + _try_deexternalize_field(map_value, store) + else: + _walk_and_deexternalize(map_value, store) + elif fd.is_repeated: + value = getattr(msg, fd.name) + for item in value: + if isinstance(item, proto_message.Message): + if isinstance(item, wrappers_pb2.StringValue): + _try_deexternalize_field(item, store) + else: + _walk_and_deexternalize(item, store) + else: + if not msg.HasField(fd.name): + continue + value = getattr(msg, fd.name) + if isinstance(value, wrappers_pb2.StringValue): + _try_deexternalize_field(value, store) + else: + _walk_and_deexternalize(value, store) + + +def _try_deexternalize_field( + sv: wrappers_pb2.StringValue, + store: PayloadStore, +) -> None: + val = sv.value + if not val or not store.is_known_token(val): + return + payload_bytes = store.download(val) + sv.value = payload_bytes.decode("utf-8") + logger.debug("De-externalized token %s -> %d bytes", val, len(payload_bytes)) + + +# ------------------------------------------------------------------ +# Internal recursive walkers – async +# ------------------------------------------------------------------ + +async def _walk_and_externalize_async( + msg: proto_message.Message, + store: PayloadStore, + threshold: int, + max_bytes: int, + instance_id: Optional[str], +) -> None: + for fd in msg.DESCRIPTOR.fields: + if fd.message_type is None: + continue + + if _is_map_field(fd): + for map_value in getattr(msg, fd.name).values(): + if 
isinstance(map_value, proto_message.Message): + if isinstance(map_value, wrappers_pb2.StringValue): + await _try_externalize_field_async( + fd.name, map_value, store, + threshold, max_bytes, instance_id, + ) + else: + await _walk_and_externalize_async( + map_value, store, threshold, max_bytes, instance_id, + ) + elif fd.is_repeated: + value = getattr(msg, fd.name) + for item in value: + if isinstance(item, proto_message.Message): + if isinstance(item, wrappers_pb2.StringValue): + await _try_externalize_field_async( + fd.name, item, store, + threshold, max_bytes, instance_id, + ) + else: + await _walk_and_externalize_async( + item, store, threshold, max_bytes, instance_id, + ) + else: + if not msg.HasField(fd.name): + continue + value = getattr(msg, fd.name) + if isinstance(value, wrappers_pb2.StringValue): + await _try_externalize_field_async( + fd.name, value, store, + threshold, max_bytes, instance_id, + ) + else: + await _walk_and_externalize_async( + value, store, threshold, max_bytes, instance_id, + ) + + +async def _try_externalize_field_async( + field_name: str, + sv: wrappers_pb2.StringValue, + store: PayloadStore, + threshold: int, + max_bytes: int, + instance_id: Optional[str], +) -> None: + val = sv.value + if not val: + return + # Already a token – skip + if store.is_known_token(val): + return + payload_bytes = val.encode("utf-8") + if len(payload_bytes) <= threshold: + return + if len(payload_bytes) > max_bytes: + raise ValueError( + f"Payload size {len(payload_bytes)} bytes exceeds the maximum " + f"allowed size of {max_bytes} bytes." 
+ ) + token = await store.upload_async(payload_bytes, instance_id=instance_id) + sv.value = token + logger.debug( + "Externalized %d-byte payload in field '%s' -> %s", + len(payload_bytes), field_name, token, + ) + + +async def _walk_and_deexternalize_async( + msg: proto_message.Message, + store: PayloadStore, +) -> None: + for fd in msg.DESCRIPTOR.fields: + if fd.message_type is None: + continue + + if _is_map_field(fd): + for map_value in getattr(msg, fd.name).values(): + if isinstance(map_value, proto_message.Message): + if isinstance(map_value, wrappers_pb2.StringValue): + await _try_deexternalize_field_async(map_value, store) + else: + await _walk_and_deexternalize_async(map_value, store) + elif fd.is_repeated: + value = getattr(msg, fd.name) + for item in value: + if isinstance(item, proto_message.Message): + if isinstance(item, wrappers_pb2.StringValue): + await _try_deexternalize_field_async(item, store) + else: + await _walk_and_deexternalize_async(item, store) + else: + if not msg.HasField(fd.name): + continue + value = getattr(msg, fd.name) + if isinstance(value, wrappers_pb2.StringValue): + await _try_deexternalize_field_async(value, store) + else: + await _walk_and_deexternalize_async(value, store) + + +async def _try_deexternalize_field_async( + sv: wrappers_pb2.StringValue, + store: PayloadStore, +) -> None: + val = sv.value + if not val or not store.is_known_token(val): + return + payload_bytes = await store.download_async(val) + sv.value = payload_bytes.decode("utf-8") + logger.debug("De-externalized token %s -> %d bytes", val, len(payload_bytes)) diff --git a/durabletask/payload/store.py b/durabletask/payload/store.py new file mode 100644 index 00000000..a6e81d56 --- /dev/null +++ b/durabletask/payload/store.py @@ -0,0 +1,91 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Abstract base class for external payload storage providers. 
+ +This module defines the interface that payload storage backends must +implement to support externalizing large orchestration payloads. The +default (and currently only) implementation stores payloads in Azure +Blob Storage. +""" + +from __future__ import annotations + +import abc +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class LargePayloadStorageOptions: + """Configuration options for large-payload externalization. + + Attributes: + threshold_bytes: Payloads larger than this value (in bytes) will + be externalized to the payload store. Defaults to 900,000 + (900 KB), matching the .NET SDK default. + max_stored_payload_bytes: Maximum payload size (in bytes) that + can be stored externally. Payloads exceeding this limit + will cause an error. Defaults to 10,485,760 (10 MB). + enable_compression: When ``True`` (the default), payloads are + GZip-compressed before uploading. + """ + threshold_bytes: int = 900_000 + max_stored_payload_bytes: int = 10 * 1024 * 1024 # 10 MB + enable_compression: bool = True + + +class PayloadStore(abc.ABC): + """Abstract base class for external payload storage backends.""" + + @property + @abc.abstractmethod + def options(self) -> LargePayloadStorageOptions: + """Return the storage options for this payload store.""" + ... + + @abc.abstractmethod + def upload(self, data: bytes, *, instance_id: Optional[str] = None) -> str: + """Upload a payload and return a reference token string. + + The returned token is a compact string that can be embedded in + gRPC messages in place of the original payload. The format + must be recognisable by :meth:`is_known_token`. + + Args: + data: The raw payload bytes to store. + instance_id: Optional orchestration instance ID for + organizing stored blobs. + + Returns: + A token string that can be used to retrieve the payload. + """ + ... 
+ + @abc.abstractmethod + async def upload_async(self, data: bytes, *, instance_id: Optional[str] = None) -> str: + """Async version of :meth:`upload`.""" + ... + + @abc.abstractmethod + def download(self, token: str) -> bytes: + """Download an externalized payload identified by *token*. + + Args: + token: The reference token returned by a previous + :meth:`upload` call. + + Returns: + The original payload bytes. + """ + ... + + @abc.abstractmethod + async def download_async(self, token: str) -> bytes: + """Async version of :meth:`download`.""" + ... + + @abc.abstractmethod + def is_known_token(self, value: str) -> bool: + """Return ``True`` if *value* looks like a token produced by this store.""" + ... diff --git a/durabletask/worker.py b/durabletask/worker.py index 9c7f2d46..f820803c 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -34,8 +34,10 @@ import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared import durabletask.internal.tracing as tracing +from durabletask.payload import helpers as payload_helpers from durabletask import task from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl +from durabletask.payload.store import PayloadStore TInput = TypeVar("TInput") TOutput = TypeVar("TOutput") @@ -322,6 +324,7 @@ def __init__( secure_channel: bool = False, interceptors: Optional[Sequence[shared.ClientInterceptor]] = None, concurrency_options: Optional[ConcurrencyOptions] = None, + payload_store: Optional[PayloadStore] = None, ): self._registry = _Registry() self._host_address = ( @@ -331,6 +334,7 @@ def __init__( self._shutdown = Event() self._is_running = False self._secure_channel = secure_channel + self._payload_store = payload_store # Use provided concurrency options or create default ones self._concurrency_options = ( @@ -498,9 +502,13 @@ def should_invalidate_connection(rpc_error): try: assert current_stub is not None stub = current_stub + capabilities = [] + 
if self._payload_store is not None: + capabilities.append(pb.WORKER_CAPABILITY_LARGE_PAYLOADS) get_work_items_request = pb.GetWorkItemsRequest( maxConcurrentOrchestrationWorkItems=self._concurrency_options.maximum_concurrent_orchestration_work_items, maxConcurrentActivityWorkItems=self._concurrency_options.maximum_concurrent_activity_work_items, + capabilities=capabilities, ) self._response_stream = stub.GetWorkItems(get_work_items_request) self._logger.info( @@ -637,6 +645,10 @@ def _execute_orchestrator( ): instance_id = req.instanceId + # De-externalize any large-payload tokens in the incoming request + if self._payload_store is not None: + payload_helpers.deexternalize_payloads(req, self._payload_store) + # Extract parent trace context from executionStarted event parent_trace_ctx = None orchestration_name = "" @@ -744,6 +756,11 @@ def _execute_orchestrator( ) try: + # Externalize any large payloads in the response + if self._payload_store is not None: + payload_helpers.externalize_payloads( + res, self._payload_store, instance_id=instance_id, + ) stub.CompleteOrchestratorTask(res) except Exception as ex: self._logger.exception( @@ -770,6 +787,10 @@ def _execute_activity( completionToken, ): instance_id = req.orchestrationInstance.instanceId + + # De-externalize any large-payload tokens in the incoming request + if self._payload_store is not None: + payload_helpers.deexternalize_payloads(req, self._payload_store) try: executor = _ActivityExecutor(self._registry, self._logger) with tracing.start_span( @@ -805,6 +826,11 @@ def _execute_activity( ) try: + # Externalize any large payloads in the response + if self._payload_store is not None: + payload_helpers.externalize_payloads( + res, self._payload_store, instance_id=instance_id, + ) stub.CompleteActivityTask(res) except Exception as ex: self._logger.exception( @@ -834,6 +860,10 @@ def _execute_entity_batch( if isinstance(req, pb.EntityRequest): req, operation_infos = helpers.convert_to_entity_batch_request(req) 
+ # De-externalize any large-payload tokens in the incoming request + if self._payload_store is not None: + payload_helpers.deexternalize_payloads(req, self._payload_store) + entity_state = StateShim(shared.from_json(req.entityState.value) if req.entityState.value else None) instance_id = req.instanceId @@ -899,6 +929,11 @@ def _execute_entity_batch( ) try: + # Externalize any large payloads in the response + if self._payload_store is not None: + payload_helpers.externalize_payloads( + batch_result, self._payload_store, instance_id=instance_id, + ) stub.CompleteEntityTask(batch_result) except Exception as ex: self._logger.exception( diff --git a/examples/README.md b/examples/README.md index 59fa7fd2..3a4ce41f 100644 --- a/examples/README.md +++ b/examples/README.md @@ -164,6 +164,11 @@ You can now execute any of the examples in this directory using Python: python activity_sequence.py ``` +> [!NOTE] +> The `large_payload/` example requires Azurite or an Azure Storage +> account and an additional install step. See the +> [large payload README](./large_payload/README.md) for details. + ### Review Orchestration History and Status To access the Durable Task Scheduler Dashboard, follow these steps: diff --git a/examples/large_payload/README.md b/examples/large_payload/README.md new file mode 100644 index 00000000..63309efe --- /dev/null +++ b/examples/large_payload/README.md @@ -0,0 +1,118 @@ +# Large Payload Externalization Example + +This example demonstrates how to use the large payload externalization +feature to automatically offload oversized orchestration payloads to +Azure Blob Storage. + +## Overview + +When orchestration inputs, activity outputs, or event data exceed a +configurable size threshold, the SDK can automatically: + +1. Compress the payload with GZip +1. Upload it to Azure Blob Storage +1. 
Replace the payload in the gRPC message with a compact reference + token + +On the receiving side, the SDK detects these tokens and transparently +downloads and decompresses the original data. No changes are needed in +your orchestrator or activity code. + +## Prerequisites + +- Python 3.10+ +- [Docker](https://www.docker.com/) (for the DTS emulator) +- [Azurite](https://learn.microsoft.com/azure/storage/common/storage-use-azurite) + or an Azure Storage account + +## Getting Started + +1. Start the DTS emulator: + + ```bash + docker run --name dtsemulator -d -p 8080:8080 mcr.microsoft.com/dts/dts-emulator:latest + ``` + +1. Start Azurite (blob service only): + + ```bash + azurite-blob --location /tmp/azurite --blobPort 10000 + ``` + + Or use the Azurite Docker image: + + ```bash + docker run -d -p 10000:10000 \ + mcr.microsoft.com/azure-storage/azurite \ + azurite-blob --blobHost 0.0.0.0 + ``` + +1. Create and activate a virtual environment: + + Bash: + + ```bash + python -m venv .venv + source .venv/bin/activate + ``` + + PowerShell: + + ```powershell + python -m venv .venv + .\.venv\Scripts\Activate.ps1 + ``` + +1. Install dependencies (from the repository root): + + ```bash + pip install -e ".[azure-blob-payloads]" -e ./durabletask-azuremanaged + ``` + +## Running the Example + +```bash +python app.py +``` + +The example schedules two orchestrations: + +- **Small payload** — The input and output stay inline in the gRPC + messages (below the 1 KB threshold configured in the example). +- **Large payload** — The activity output (~70 KB) exceeds the + threshold and is automatically externalized to blob storage and + retrieved transparently. + +## Using Azure Storage Instead of Azurite + +Set the `STORAGE_CONNECTION_STRING` environment variable to your Azure +Storage connection string: + +Bash: + +```bash +export STORAGE_CONNECTION_STRING="DefaultEndpointsProtocol=https;..." 
+``` + +PowerShell: + +```powershell +$env:STORAGE_CONNECTION_STRING = "DefaultEndpointsProtocol=https;..." +``` + +## Configuration Options + +The `BlobPayloadStoreOptions` class supports the following settings: + +| Option | Default | Description | +|---|---|---| +| `threshold_bytes` | 900,000 (900 KB) | Payloads larger than this are externalized | +| `max_stored_payload_bytes` | 10,485,760 (10 MB) | Maximum externalized payload size | +| `enable_compression` | `True` | GZip-compress before uploading | +| `container_name` | `"durabletask-payloads"` | Blob container name | +| `connection_string` | `None` | Storage connection string | +| `account_url` | `None` | Storage account URL (with `credential`) | +| `credential` | `None` | `TokenCredential` for token-based auth | + +For more details, see the +[feature documentation](../../docs/features.md#large-payload-externalization). diff --git a/examples/large_payload/app.py b/examples/large_payload/app.py new file mode 100644 index 00000000..afdda044 --- /dev/null +++ b/examples/large_payload/app.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""End-to-end sample demonstrating large payload externalization. + +This example shows how to configure a BlobPayloadStore so that large +orchestration inputs, activity outputs, and event data are automatically +offloaded to Azure Blob Storage and replaced with compact reference +tokens in gRPC messages. + +Prerequisites: + pip install durabletask[azure-blob-payloads] durabletask-azuremanaged azure-identity + +Usage (emulator + Azurite — no Azure resources needed): + # Start the DTS emulator (port 8080) and Azurite (port 10000) + python app.py + +Usage (Azure): + export ENDPOINT=https://.durabletask.io + export TASKHUB= + export STORAGE_CONNECTION_STRING="DefaultEndpointsProtocol=https;..." 
+ python app.py +""" + +import os + +from azure.identity import DefaultAzureCredential + +from durabletask import client, task +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker +from durabletask.extensions.azure_blob_payloads import BlobPayloadStore, BlobPayloadStoreOptions + + +# --------------- Activities --------------- + +def generate_report(ctx: task.ActivityContext, num_records: int) -> str: + """Activity that returns a large payload (simulating a report).""" + return "RECORD|" * num_records + + +def summarize(ctx: task.ActivityContext, report: str) -> str: + """Activity that summarizes a report.""" + record_count = report.count("RECORD|") + return f"Report contains {record_count} records ({len(report)} bytes)" + + +# --------------- Orchestrator --------------- + +def large_payload_orchestrator(ctx: task.OrchestrationContext, num_records: int): + """Orchestrator that generates a large report and then summarizes it. + + Both the report (activity output) and the orchestration input are + transparently externalized to blob storage when they exceed the + configured threshold. 
+ """ + report = yield ctx.call_activity(generate_report, input=num_records) + summary = yield ctx.call_activity(summarize, input=report) + return summary + + +# --------------- Main --------------- + +def main(): + # DTS endpoint and taskhub (defaults to the emulator) + taskhub_name = os.getenv("TASKHUB", "default") + endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + # Azure Storage connection string (defaults to Azurite) + storage_conn_str = os.getenv( + "STORAGE_CONNECTION_STRING", + "UseDevelopmentStorage=true", + ) + + print(f"Using taskhub: {taskhub_name}") + print(f"Using endpoint: {endpoint}") + + # Configure the blob payload store + store = BlobPayloadStore(BlobPayloadStoreOptions( + connection_string=storage_conn_str, + # Use a low threshold so that we can see externalization in action + threshold_bytes=1_024, + )) + + secure_channel = endpoint.startswith("https://") + credential = DefaultAzureCredential() if secure_channel else None + + with DurableTaskSchedulerWorker( + host_address=endpoint, + secure_channel=secure_channel, + taskhub=taskhub_name, + token_credential=credential, + payload_store=store, + ) as w: + w.add_orchestrator(large_payload_orchestrator) + w.add_activity(generate_report) + w.add_activity(summarize) + w.start() + + c = DurableTaskSchedulerClient( + host_address=endpoint, + secure_channel=secure_channel, + taskhub=taskhub_name, + token_credential=credential, + payload_store=store, + ) + + # Schedule an orchestration with a small input (stays inline) + print("\n--- Small payload (stays inline) ---") + instance_id = c.schedule_new_orchestration( + large_payload_orchestrator, input=10) + state = c.wait_for_orchestration_completion(instance_id, timeout=60) + if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: + print(f"Result: {state.serialized_output}") + + # Schedule an orchestration that produces a large activity output + # (the report will be externalized to blob storage automatically) + 
print("\n--- Large payload (externalized to blob storage) ---") + instance_id = c.schedule_new_orchestration( + large_payload_orchestrator, input=10_000) + state = c.wait_for_orchestration_completion(instance_id, timeout=60) + if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: + print(f"Result: {state.serialized_output}") + elif state: + print(f"Orchestration failed: {state.failure_details}") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 04cfcc6a..b5a923da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,9 @@ opentelemetry = [ "opentelemetry-api>=1.0.0", "opentelemetry-sdk>=1.0.0" ] +azure-blob-payloads = [ + "azure-storage-blob[aio]>=12.0.0" +] [project.urls] repository = "https://github.com/microsoft/durabletask-python" @@ -48,3 +51,6 @@ include = ["durabletask", "durabletask.*"] minversion = "6.0" testpaths = ["tests"] asyncio_mode = "auto" +markers = [ + "azurite: tests that require Azurite (local Azure Storage emulator)", +] diff --git a/tests/durabletask-azuremanaged/test_dts_large_payload_e2e.py b/tests/durabletask-azuremanaged/test_dts_large_payload_e2e.py new file mode 100644 index 00000000..9eb7e891 --- /dev/null +++ b/tests/durabletask-azuremanaged/test_dts_large_payload_e2e.py @@ -0,0 +1,322 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""End-to-end tests for large-payload externalization using Azure Durable Task +Scheduler (DTS) and Azurite. + +These tests use the DTS emulator for orchestration and Azurite for blob +storage. They verify that payloads too large for inline gRPC messages are +transparently offloaded to blob storage and recovered on the other side +when using the ``DurableTaskSchedulerWorker`` and +``DurableTaskSchedulerClient`` classes. + +Prerequisites: + - DTS emulator must be running locally. 
+ Start it with: + ``docker run -i -p 8080:8080 -p 8082:8082 -d mcr.microsoft.com/dts/dts-emulator:latest`` + - Azurite must be running locally on the default blob port (10000). + Start it with: ``azurite --silent --blobPort 10000`` + - ``azure-storage-blob`` must be installed. +""" + +import json +import os +import uuid + +import pytest + +from durabletask import client, task +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + +# Skip the entire module if azure-storage-blob is not installed. +azure_blob = pytest.importorskip("azure.storage.blob") + +from durabletask.extensions.azure_blob_payloads import BlobPayloadStore, BlobPayloadStoreOptions # noqa: E402 + +# DTS emulator settings +taskhub_name = os.getenv("TASKHUB", "default") +endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + +# Azurite well-known connection string +AZURITE_CONN_STR = "UseDevelopmentStorage=true" + +# Use a unique container per test run to avoid collisions. +TEST_CONTAINER = f"dts-payloads-{uuid.uuid4().hex[:8]}" + +# A low threshold so we can trigger externalization without massive strings. +# In production the default is 900 KB; here we use 1 KB for fast tests. +THRESHOLD_BYTES = 1_024 + +# Pin API version to one that Azurite supports. 
+AZURITE_API_VERSION = "2024-08-04" + + +# ------------------------------------------------------------------ +# Skip checks +# ------------------------------------------------------------------ + + +def _azurite_is_running() -> bool: + """Return True if Azurite blob service is reachable.""" + try: + svc = azure_blob.BlobServiceClient.from_connection_string( + AZURITE_CONN_STR, api_version=AZURITE_API_VERSION, + ) + next(iter(svc.list_containers(results_per_page=1)), None) + return True + except Exception: + return False + + +pytestmark = [ + pytest.mark.dts, + pytest.mark.azurite, + pytest.mark.skipif( + not _azurite_is_running(), + reason="Azurite blob service is not running on 127.0.0.1:10000", + ), +] + + +# ------------------------------------------------------------------ +# Fixtures +# ------------------------------------------------------------------ + + +@pytest.fixture(scope="module") +def payload_store(): + """Create a BlobPayloadStore pointing at Azurite with a low threshold.""" + store = BlobPayloadStore(BlobPayloadStoreOptions( + connection_string=AZURITE_CONN_STR, + container_name=TEST_CONTAINER, + threshold_bytes=THRESHOLD_BYTES, + enable_compression=True, + api_version=AZURITE_API_VERSION, + )) + yield store + + # Clean up: delete the test container. 
+ try: + svc = azure_blob.BlobServiceClient.from_connection_string( + AZURITE_CONN_STR, api_version=AZURITE_API_VERSION, + ) + svc.delete_container(TEST_CONTAINER) + except Exception: + pass + + +# ------------------------------------------------------------------ +# Helpers +# ------------------------------------------------------------------ + + +def _make_large_string(size_kb: int = 2) -> str: + """Return a JSON-serializable string larger than THRESHOLD_BYTES.""" + return "X" * (size_kb * 1024) + + +# ------------------------------------------------------------------ +# Tests +# ------------------------------------------------------------------ + + +class TestDTSLargeInputOutput: + """Orchestrations whose input/output exceeds the threshold.""" + + def test_large_input_round_trips(self, payload_store): + """A large orchestration input is externalized then recovered.""" + large_input = _make_large_string(5) # 5 KB > 1 KB threshold + + def echo(ctx: task.OrchestrationContext, inp: str): + return inp + + with DurableTaskSchedulerWorker( + host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None, + payload_store=payload_store, + ) as w: + w.add_orchestrator(echo) + w.start() + + c = DurableTaskSchedulerClient( + host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None, + payload_store=payload_store, + ) + inst_id = c.schedule_new_orchestration(echo, input=large_input) + state = c.wait_for_orchestration_completion(inst_id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output is not None + assert json.loads(state.serialized_output) == large_input + + def test_large_activity_result(self, payload_store): + """A large activity return value is externalized then recovered.""" + def produce_large(_: task.ActivityContext, size_kb: int) -> str: + return "Z" * (size_kb * 1024) + + def orchestrator(ctx: 
task.OrchestrationContext, size_kb: int): + result = yield ctx.call_activity(produce_large, input=size_kb) + return result + + with DurableTaskSchedulerWorker( + host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None, + payload_store=payload_store, + ) as w: + w.add_orchestrator(orchestrator) + w.add_activity(produce_large) + w.start() + + c = DurableTaskSchedulerClient( + host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None, + payload_store=payload_store, + ) + inst_id = c.schedule_new_orchestration(orchestrator, input=10) # 10 KB + state = c.wait_for_orchestration_completion(inst_id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output is not None + output = json.loads(state.serialized_output) + assert len(output) == 10 * 1024 + + def test_large_input_and_output(self, payload_store): + """Both input and output are large — both directions externalize.""" + large_input = {"data": "Y" * (5 * 1024)} + + def transform(ctx: task.OrchestrationContext, inp: dict): + return {"echo": inp["data"], "extra": "A" * 3000} + + with DurableTaskSchedulerWorker( + host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None, + payload_store=payload_store, + ) as w: + w.add_orchestrator(transform) + w.start() + + c = DurableTaskSchedulerClient( + host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None, + payload_store=payload_store, + ) + inst_id = c.schedule_new_orchestration(transform, input=large_input) + state = c.wait_for_orchestration_completion(inst_id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output is not None + output = json.loads(state.serialized_output) + assert output["echo"] == large_input["data"] + + +class TestDTSLargeEvents: + """External events 
carrying large payloads.""" + + def test_large_event_data(self, payload_store): + """A large external event payload is externalized and resolved.""" + large_event = _make_large_string(5) + + def wait_for_event(ctx: task.OrchestrationContext, _): + result = yield ctx.wait_for_external_event("big_event") + return result + + with DurableTaskSchedulerWorker( + host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None, + payload_store=payload_store, + ) as w: + w.add_orchestrator(wait_for_event) + w.start() + + c = DurableTaskSchedulerClient( + host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None, + payload_store=payload_store, + ) + inst_id = c.schedule_new_orchestration(wait_for_event) + c.wait_for_orchestration_start(inst_id, timeout=10) + + c.raise_orchestration_event(inst_id, "big_event", data=large_event) + state = c.wait_for_orchestration_completion(inst_id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output is not None + assert json.loads(state.serialized_output) == large_event + + +class TestDTSLargeFanOut: + """Fan-out/fan-in with multiple large activity results.""" + + def test_fan_out_fan_in_large_results(self, payload_store): + """Multiple activities each return large payloads.""" + def make_large(_: task.ActivityContext, idx: int) -> str: + return f"result-{idx}-" + ("D" * 2000) + + def fan_out(ctx: task.OrchestrationContext, count: int): + tasks = [ctx.call_activity(make_large, input=i) for i in range(count)] + results = yield task.when_all(tasks) + return results + + with DurableTaskSchedulerWorker( + host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None, + payload_store=payload_store, + ) as w: + w.add_orchestrator(fan_out) + w.add_activity(make_large) + w.start() + + c = DurableTaskSchedulerClient( + host_address=endpoint, secure_channel=True, + 
taskhub=taskhub_name, token_credential=None, + payload_store=payload_store, + ) + inst_id = c.schedule_new_orchestration(fan_out, input=5) + state = c.wait_for_orchestration_completion(inst_id, timeout=60) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output is not None + results = json.loads(state.serialized_output) + assert len(results) == 5 + for i, r in enumerate(results): + assert r.startswith(f"result-{i}-") + + +class TestDTSLargeTerminate: + """Terminate with a large output payload.""" + + def test_terminate_with_large_output(self, payload_store): + """Terminating with a large output externalizes it.""" + large_output = _make_large_string(3) + + def long_running(ctx: task.OrchestrationContext, _): + yield ctx.wait_for_external_event("never_arrives") + + with DurableTaskSchedulerWorker( + host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None, + payload_store=payload_store, + ) as w: + w.add_orchestrator(long_running) + w.start() + + c = DurableTaskSchedulerClient( + host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None, + payload_store=payload_store, + ) + inst_id = c.schedule_new_orchestration(long_running) + c.wait_for_orchestration_start(inst_id, timeout=10) + + c.terminate_orchestration(inst_id, output=large_output) + state = c.wait_for_orchestration_completion(inst_id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.TERMINATED diff --git a/tests/durabletask/test_large_payload.py b/tests/durabletask/test_large_payload.py new file mode 100644 index 00000000..c32e4d44 --- /dev/null +++ b/tests/durabletask/test_large_payload.py @@ -0,0 +1,553 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Tests for large-payload externalization and de-externalization.""" + +from typing import Optional +from unittest.mock import MagicMock + +import pytest +from google.protobuf import wrappers_pb2 + +import durabletask.internal.orchestrator_service_pb2 as pb +from durabletask.payload.helpers import ( + deexternalize_payloads, + deexternalize_payloads_async, + externalize_payloads, + externalize_payloads_async, +) + +from durabletask.payload.store import ( + LargePayloadStorageOptions, + PayloadStore, +) + + +# ------------------------------------------------------------------ +# Fake in-memory PayloadStore for tests +# ------------------------------------------------------------------ + + +class FakePayloadStore(PayloadStore): + """In-memory payload store for testing.""" + + TOKEN_PREFIX = "blob:v1:test-container:" + + def __init__( + self, + threshold_bytes: int = 100, + max_stored_payload_bytes: int = 10 * 1024 * 1024, + enable_compression: bool = False, + ): + self._options = LargePayloadStorageOptions( + threshold_bytes=threshold_bytes, + max_stored_payload_bytes=max_stored_payload_bytes, + enable_compression=enable_compression, + ) + self._blobs: dict[str, bytes] = {} + self._counter = 0 + + @property + def options(self) -> LargePayloadStorageOptions: + return self._options + + def upload(self, data: bytes, *, instance_id: Optional[str] = None) -> str: + self._counter += 1 + blob_name = f"blob-{self._counter}" + token = f"{self.TOKEN_PREFIX}{blob_name}" + self._blobs[token] = data + return token + + async def upload_async(self, data: bytes, *, instance_id: Optional[str] = None) -> str: + return self.upload(data, instance_id=instance_id) + + def download(self, token: str) -> bytes: + return self._blobs[token] + + async def download_async(self, token: str) -> bytes: + return self.download(token) + + def is_known_token(self, value: str) -> bool: + return value.startswith(self.TOKEN_PREFIX) + + +# 
------------------------------------------------------------------ +# Helper to create StringValue +# ------------------------------------------------------------------ + + +def sv(text: str) -> wrappers_pb2.StringValue: + return wrappers_pb2.StringValue(value=text) + + +# ------------------------------------------------------------------ +# Tests: externalize_payloads +# ------------------------------------------------------------------ + + +class TestExternalizePayloads: + def test_small_payload_not_externalized(self): + """Payloads smaller than the threshold should be left intact.""" + store = FakePayloadStore(threshold_bytes=1000) + req = pb.CreateInstanceRequest( + instanceId="test-1", + name="MyOrch", + input=sv("short"), + ) + externalize_payloads(req, store, instance_id="test-1") + assert req.input.value == "short" + assert len(store._blobs) == 0 + + def test_large_payload_externalized(self): + """Payloads larger than the threshold should be uploaded and replaced with a token.""" + store = FakePayloadStore(threshold_bytes=10) + large_data = "x" * 200 + req = pb.CreateInstanceRequest( + instanceId="test-1", + name="MyOrch", + input=sv(large_data), + ) + externalize_payloads(req, store, instance_id="test-1") + + # The input should now be a token + assert req.input.value.startswith(FakePayloadStore.TOKEN_PREFIX) + # The store should have stored the original data + stored = store._blobs[req.input.value] + assert stored == large_data.encode("utf-8") + + def test_already_token_not_re_uploaded(self): + """If a field already contains a token, it should not be re-uploaded.""" + store = FakePayloadStore(threshold_bytes=10) + token = f"{FakePayloadStore.TOKEN_PREFIX}existing-blob" + store._blobs[token] = b"some data" + + req = pb.CreateInstanceRequest( + instanceId="test-1", + name="MyOrch", + input=sv(token), + ) + externalize_payloads(req, store, instance_id="test-1") + assert req.input.value == token + assert len(store._blobs) == 1 # No new upload + + def 
test_exceeds_max_raises_error(self): + """Payloads exceeding max_stored_payload_bytes should raise ValueError.""" + store = FakePayloadStore(threshold_bytes=10, max_stored_payload_bytes=50) + large_data = "x" * 100 # 100 bytes > max 50 + + req = pb.CreateInstanceRequest( + instanceId="test-1", + name="MyOrch", + input=sv(large_data), + ) + with pytest.raises(ValueError, match="exceeds the maximum"): + externalize_payloads(req, store, instance_id="test-1") + + def test_empty_value_not_externalized(self): + """Empty StringValue should not be externalized.""" + store = FakePayloadStore(threshold_bytes=10) + req = pb.CreateInstanceRequest( + instanceId="test-1", + name="MyOrch", + input=sv(""), + ) + externalize_payloads(req, store, instance_id="test-1") + assert req.input.value == "" + assert len(store._blobs) == 0 + + def test_nested_history_events_externalized(self): + """StringValue fields inside nested history events should be externalized.""" + store = FakePayloadStore(threshold_bytes=10) + large_input = "y" * 200 + + event = pb.HistoryEvent( + eventId=1, + executionStarted=pb.ExecutionStartedEvent( + name="MyOrch", + input=sv(large_input), + ), + ) + req = pb.OrchestratorRequest( + instanceId="test-1", + newEvents=[event], + ) + externalize_payloads(req, store, instance_id="test-1") + + # Verify the nested field was externalized + actual = req.newEvents[0].executionStarted.input.value + assert actual.startswith(FakePayloadStore.TOKEN_PREFIX) + + def test_orchestrator_response_actions_externalized(self): + """Action fields in OrchestratorResponse should be externalized.""" + store = FakePayloadStore(threshold_bytes=10) + large_result = "z" * 200 + + action = pb.OrchestratorAction( + id=1, + completeOrchestration=pb.CompleteOrchestrationAction( + orchestrationStatus=pb.ORCHESTRATION_STATUS_COMPLETED, + result=sv(large_result), + ), + ) + res = pb.OrchestratorResponse( + instanceId="test-1", + actions=[action], + ) + externalize_payloads(res, store, 
instance_id="test-1") + + actual = res.actions[0].completeOrchestration.result.value + assert actual.startswith(FakePayloadStore.TOKEN_PREFIX) + + def test_activity_response_externalized(self): + """ActivityResponse.result should be externalized if large.""" + store = FakePayloadStore(threshold_bytes=10) + large_result = "a" * 200 + + res = pb.ActivityResponse( + instanceId="test-1", + taskId=42, + result=sv(large_result), + ) + externalize_payloads(res, store, instance_id="test-1") + + assert res.result.value.startswith(FakePayloadStore.TOKEN_PREFIX) + + +# ------------------------------------------------------------------ +# Tests: deexternalize_payloads +# ------------------------------------------------------------------ + + +class TestDeexternalizePayloads: + def test_token_replaced_with_original(self): + """Token in a StringValue should be replaced with the original payload.""" + store = FakePayloadStore(threshold_bytes=10) + original = "original data here" + token = store.upload(original.encode("utf-8")) + + req = pb.ActivityRequest( + name="MyActivity", + input=sv(token), + ) + deexternalize_payloads(req, store) + assert req.input.value == original + + def test_non_token_not_modified(self): + """Regular string values should not be modified.""" + store = FakePayloadStore(threshold_bytes=10) + req = pb.ActivityRequest( + name="MyActivity", + input=sv("just a normal string"), + ) + deexternalize_payloads(req, store) + assert req.input.value == "just a normal string" + + def test_nested_history_events_deexternalized(self): + """Tokens in nested history events should be replaced.""" + store = FakePayloadStore(threshold_bytes=10) + original = "nested payload content" + token = store.upload(original.encode("utf-8")) + + event = pb.HistoryEvent( + eventId=1, + taskCompleted=pb.TaskCompletedEvent( + taskScheduledId=5, + result=sv(token), + ), + ) + req = pb.OrchestratorRequest( + instanceId="test-1", + pastEvents=[event], + ) + deexternalize_payloads(req, store) + + 
actual = req.pastEvents[0].taskCompleted.result.value + assert actual == original + + def test_get_instance_response_deexternalized(self): + """GetInstanceResponse state fields should be de-externalized.""" + store = FakePayloadStore(threshold_bytes=10) + original_input = "large input data" + original_output = "large output data" + token_input = store.upload(original_input.encode("utf-8")) + token_output = store.upload(original_output.encode("utf-8")) + + res = pb.GetInstanceResponse( + exists=True, + orchestrationState=pb.OrchestrationState( + instanceId="test-1", + name="MyOrch", + input=sv(token_input), + output=sv(token_output), + ), + ) + deexternalize_payloads(res, store) + + assert res.orchestrationState.input.value == original_input + assert res.orchestrationState.output.value == original_output + + +# ------------------------------------------------------------------ +# Tests: round-trip +# ------------------------------------------------------------------ + + +class TestRoundTrip: + def test_externalize_then_deexternalize(self): + """A payload that is externalized should round-trip correctly.""" + store = FakePayloadStore(threshold_bytes=10) + original = "round trip payload " * 20 + + req = pb.CreateInstanceRequest( + instanceId="rt-1", + name="MyOrch", + input=sv(original), + ) + externalize_payloads(req, store, instance_id="rt-1") + assert req.input.value != original # Should be a token + + deexternalize_payloads(req, store) + assert req.input.value == original + + +# ------------------------------------------------------------------ +# Tests: async +# ------------------------------------------------------------------ + + +class TestAsyncPayloadHelpers: + @pytest.mark.asyncio + async def test_async_externalize_and_deexternalize(self): + """Async versions should work identically to sync.""" + store = FakePayloadStore(threshold_bytes=10) + original = "async round trip " * 20 + + req = pb.CreateInstanceRequest( + instanceId="async-1", + name="MyOrch", + 
input=sv(original), + ) + + await externalize_payloads_async(req, store, instance_id="async-1") + assert req.input.value.startswith(FakePayloadStore.TOKEN_PREFIX) + + await deexternalize_payloads_async(req, store) + assert req.input.value == original + + +# ------------------------------------------------------------------ +# Tests: worker capability flag +# ------------------------------------------------------------------ + + +class TestWorkerCapabilityFlag: + def test_capability_set_when_payload_store_provided(self): + """WORKER_CAPABILITY_LARGE_PAYLOADS should be present in GetWorkItemsRequest.""" + capabilities = [pb.WORKER_CAPABILITY_LARGE_PAYLOADS] + req = pb.GetWorkItemsRequest( + maxConcurrentOrchestrationWorkItems=10, + maxConcurrentActivityWorkItems=10, + capabilities=capabilities, + ) + assert pb.WORKER_CAPABILITY_LARGE_PAYLOADS in req.capabilities + + def test_capability_not_set_when_no_payload_store(self): + """Without a payload store, capabilities should be empty.""" + req = pb.GetWorkItemsRequest( + maxConcurrentOrchestrationWorkItems=10, + maxConcurrentActivityWorkItems=10, + ) + assert pb.WORKER_CAPABILITY_LARGE_PAYLOADS not in req.capabilities + + +# ------------------------------------------------------------------ +# Tests: BlobPayloadStore token parsing +# ------------------------------------------------------------------ + + +class TestBlobPayloadStoreTokenParsing: + def test_valid_token_is_known(self): + """Valid blob:v1:... 
tokens should be recognized.""" + pytest.importorskip("azure.storage.blob") + from durabletask.extensions.azure_blob_payloads.blob_payload_store import BlobPayloadStore + + assert BlobPayloadStore._parse_token( + "blob:v1:my-container:some/blob/name" + ) == ("my-container", "some/blob/name") + + def test_invalid_token_raises(self): + """Invalid tokens should raise ValueError.""" + pytest.importorskip("azure.storage.blob") + from durabletask.extensions.azure_blob_payloads.blob_payload_store import BlobPayloadStore + + with pytest.raises(ValueError, match="Invalid blob payload token"): + BlobPayloadStore._parse_token("not-a-token") + + def test_token_missing_blob_name_raises(self): + """Token with missing blob name should raise ValueError.""" + pytest.importorskip("azure.storage.blob") + from durabletask.extensions.azure_blob_payloads.blob_payload_store import BlobPayloadStore + + with pytest.raises(ValueError, match="Invalid blob payload token"): + BlobPayloadStore._parse_token("blob:v1:container:") + + def test_is_known_token(self): + """is_known_token correctly identifies blob tokens.""" + pytest.importorskip("azure.storage.blob") + from durabletask.extensions.azure_blob_payloads.blob_payload_store import BlobPayloadStore + + store = MagicMock(spec=BlobPayloadStore) + store.is_known_token = BlobPayloadStore.is_known_token.__get__(store) + store._parse_token = BlobPayloadStore._parse_token + + assert store.is_known_token("blob:v1:c:b") is True + assert store.is_known_token("not-a-token") is False + assert store.is_known_token("") is False + assert store.is_known_token("blob:v1:") is False + assert store.is_known_token("blob:v1:container:") is False + + +# ------------------------------------------------------------------ +# Tests: BlobPayloadStore construction and defaults +# ------------------------------------------------------------------ + + +class TestBlobPayloadStoreDefaults: + def test_default_options(self): + """Constructing with connection_string should 
use .NET SDK defaults.""" + pytest.importorskip("azure.storage.blob") + from durabletask.extensions.azure_blob_payloads import BlobPayloadStore, BlobPayloadStoreOptions + + store = BlobPayloadStore(BlobPayloadStoreOptions( + connection_string="UseDevelopmentStorage=true", + )) + opts = store.options + assert opts.threshold_bytes == 900_000 + assert opts.max_stored_payload_bytes == 10 * 1024 * 1024 + assert opts.enable_compression is True + assert opts.container_name == "durabletask-payloads" + assert opts.connection_string == "UseDevelopmentStorage=true" + + def test_custom_options(self): + """Custom constructor params should be reflected in options.""" + pytest.importorskip("azure.storage.blob") + from durabletask.extensions.azure_blob_payloads import BlobPayloadStore, BlobPayloadStoreOptions + + store = BlobPayloadStore(BlobPayloadStoreOptions( + connection_string="UseDevelopmentStorage=true", + threshold_bytes=500_000, + container_name="my-container", + )) + assert store.options.threshold_bytes == 500_000 + assert store.options.container_name == "my-container" + + +# ------------------------------------------------------------------ +# Tests: client method coverage +# ------------------------------------------------------------------ + + +class TestTerminateRequestExternalized: + def test_terminate_output_externalized(self): + """TerminateRequest.output should be externalized when large.""" + store = FakePayloadStore(threshold_bytes=10) + large_output = "t" * 200 + req = pb.TerminateRequest( + instanceId="term-1", + output=sv(large_output), + recursive=True, + ) + externalize_payloads(req, store, instance_id="term-1") + assert req.output.value.startswith(FakePayloadStore.TOKEN_PREFIX) + assert store._blobs[req.output.value] == large_output.encode("utf-8") + + +class TestSignalEntityRequestExternalized: + def test_signal_entity_input_externalized(self): + """SignalEntityRequest.input should be externalized when large.""" + store = 
FakePayloadStore(threshold_bytes=10) + large_input = "e" * 200 + req = pb.SignalEntityRequest( + instanceId="entity-1", + name="MyOp", + input=sv(large_input), + ) + externalize_payloads(req, store, instance_id="entity-1") + assert req.input.value.startswith(FakePayloadStore.TOKEN_PREFIX) + assert store._blobs[req.input.value] == large_input.encode("utf-8") + + +class TestQueryInstancesResponseDeexternalized: + def test_query_instances_response_deexternalized(self): + """OrchestrationState fields inside QueryInstancesResponse should be de-externalized.""" + store = FakePayloadStore(threshold_bytes=10) + original_input = "query result input payload" + original_output = "query result output payload" + token_input = store.upload(original_input.encode("utf-8")) + token_output = store.upload(original_output.encode("utf-8")) + + resp = pb.QueryInstancesResponse( + orchestrationState=[ + pb.OrchestrationState( + instanceId="q-1", + name="Orch", + input=sv(token_input), + output=sv(token_output), + ), + ], + ) + deexternalize_payloads(resp, store) + assert resp.orchestrationState[0].input.value == original_input + assert resp.orchestrationState[0].output.value == original_output + + +class TestGetEntityResponseDeexternalized: + def test_get_entity_response_deexternalized(self): + """GetEntityResponse.entity.serializedState should be de-externalized.""" + store = FakePayloadStore(threshold_bytes=10) + original_state = "large entity state data" + token = store.upload(original_state.encode("utf-8")) + + resp = pb.GetEntityResponse( + exists=True, + entity=pb.EntityMetadata( + instanceId="ent-1", + serializedState=sv(token), + ), + ) + deexternalize_payloads(resp, store) + assert resp.entity.serializedState.value == original_state + + +class TestQueryEntitiesResponseDeexternalized: + def test_query_entities_response_deexternalized(self): + """EntityMetadata inside QueryEntitiesResponse should be de-externalized.""" + store = FakePayloadStore(threshold_bytes=10) + 
original_state = "queried entity state" + token = store.upload(original_state.encode("utf-8")) + + resp = pb.QueryEntitiesResponse( + entities=[ + pb.EntityMetadata( + instanceId="ent-q-1", + serializedState=sv(token), + ), + ], + ) + deexternalize_payloads(resp, store) + assert resp.entities[0].serializedState.value == original_state + + +# ------------------------------------------------------------------ +# Tests: BlobPayloadStore construction validation +# ------------------------------------------------------------------ + + +class TestBlobPayloadStoreConstruction: + def test_no_credentials_raises(self): + """Constructing without connection_string or account_url should raise.""" + pytest.importorskip("azure.storage.blob") + from durabletask.extensions.azure_blob_payloads.blob_payload_store import BlobPayloadStore + + with pytest.raises(TypeError): + BlobPayloadStore() # type: ignore[call-arg] diff --git a/tests/durabletask/test_large_payload_e2e.py b/tests/durabletask/test_large_payload_e2e.py new file mode 100644 index 00000000..3ff7af85 --- /dev/null +++ b/tests/durabletask/test_large_payload_e2e.py @@ -0,0 +1,415 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""End-to-end tests for large-payload externalization using Azurite. + +These tests spin up a real in-memory Durable Task backend *and* a real +``BlobPayloadStore`` backed by Azurite (local Azure Storage emulator). +They verify that payloads too large for inline gRPC messages are +transparently offloaded to blob storage and recovered on the other side. + +Prerequisites: + - Azurite must be running locally on the default blob port (10000). + Start it with: ``azurite --silent --blobPort 10000`` + - ``azure-storage-blob`` must be installed. +""" + +import json +import uuid + +import pytest + +from durabletask import client, task, worker +from durabletask.testing import create_test_backend + +# Skip the entire module if azure-storage-blob is not installed. 
+azure_blob = pytest.importorskip("azure.storage.blob") + +from durabletask.extensions.azure_blob_payloads import BlobPayloadStore, BlobPayloadStoreOptions # noqa: E402 + +# Azurite well-known connection string +AZURITE_CONN_STR = "UseDevelopmentStorage=true" + +HOST = "localhost:50070" +BACKEND_PORT = 50070 + +# Use a unique container per test run to avoid collisions. +TEST_CONTAINER = f"e2e-payloads-{uuid.uuid4().hex[:8]}" + +# A low threshold so we can trigger externalization without massive strings. +# In production the default is 900 KB; here we use 1 KB for fast tests. +THRESHOLD_BYTES = 1_024 + +# Pin API version to one that Azurite supports. +AZURITE_API_VERSION = "2024-08-04" + + +# ------------------------------------------------------------------ +# Fixtures +# ------------------------------------------------------------------ + + +def _azurite_is_running() -> bool: + """Return True if Azurite blob service is reachable.""" + try: + svc = azure_blob.BlobServiceClient.from_connection_string( + AZURITE_CONN_STR, api_version=AZURITE_API_VERSION, + ) + # list_containers is a lightweight call that works with Azurite's + # well-known credentials without any special permissions. + next(iter(svc.list_containers(results_per_page=1)), None) + return True + except Exception: + return False + + +# Skip all tests if Azurite is not reachable. +pytestmark = [ + pytest.mark.azurite, + pytest.mark.skipif( + not _azurite_is_running(), + reason="Azurite blob service is not running on 127.0.0.1:10000", + ), +] + + +@pytest.fixture(scope="module") +def payload_store(): + """Create a BlobPayloadStore pointing at Azurite with a low threshold.""" + store = BlobPayloadStore(BlobPayloadStoreOptions( + connection_string=AZURITE_CONN_STR, + container_name=TEST_CONTAINER, + threshold_bytes=THRESHOLD_BYTES, + enable_compression=True, + api_version=AZURITE_API_VERSION, + )) + yield store + + # Clean up: delete the test container. 
+ try: + svc = azure_blob.BlobServiceClient.from_connection_string( + AZURITE_CONN_STR, api_version=AZURITE_API_VERSION, + ) + svc.delete_container(TEST_CONTAINER) + except Exception: + pass + + +@pytest.fixture(autouse=True) +def backend(): + """Create an in-memory backend for each test.""" + b = create_test_backend(port=BACKEND_PORT) + yield b + b.stop() + b.reset() + + +# ------------------------------------------------------------------ +# Helpers +# ------------------------------------------------------------------ + + +def _make_large_string(size_kb: int = 2) -> str: + """Return a JSON-serializable string larger than THRESHOLD_BYTES.""" + return "X" * (size_kb * 1024) + + +def _make_large_payload(size_kb: int = 100) -> dict: + """Return a dict whose JSON serialization is about *size_kb* KB.""" + return {"data": "Y" * (size_kb * 1024)} + + +# ------------------------------------------------------------------ +# Tests +# ------------------------------------------------------------------ + + +class TestLargeInputOutput: + """Orchestrations whose input/output exceeds the threshold.""" + + def test_large_input_round_trips(self, payload_store): + """A large orchestration input is externalized then recovered.""" + large_input = _make_large_string(5) # 5 KB > 1 KB threshold + + def echo(ctx: task.OrchestrationContext, inp: str): + return inp + + with worker.TaskHubGrpcWorker(host_address=HOST, payload_store=payload_store) as w: + w.add_orchestrator(echo) + w.start() + + c = client.TaskHubGrpcClient(host_address=HOST, payload_store=payload_store) + inst_id = c.schedule_new_orchestration(echo, input=large_input) + state = c.wait_for_orchestration_completion(inst_id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output is not None + assert json.loads(state.serialized_output) == large_input + + def test_large_activity_result(self, payload_store): + """A large activity return value 
is externalized then recovered.""" + def produce_large(_: task.ActivityContext, size_kb: int) -> str: + return "Z" * (size_kb * 1024) + + def orchestrator(ctx: task.OrchestrationContext, size_kb: int): + result = yield ctx.call_activity(produce_large, input=size_kb) + return result + + with worker.TaskHubGrpcWorker(host_address=HOST, payload_store=payload_store) as w: + w.add_orchestrator(orchestrator) + w.add_activity(produce_large) + w.start() + + c = client.TaskHubGrpcClient(host_address=HOST, payload_store=payload_store) + inst_id = c.schedule_new_orchestration(orchestrator, input=10) # 10 KB + state = c.wait_for_orchestration_completion(inst_id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output is not None + output = json.loads(state.serialized_output) + assert len(output) == 10 * 1024 + + def test_large_input_and_output(self, payload_store): + """Both input and output are large — both directions externalize.""" + large_input = _make_large_payload(5) # ~5 KB dict + + def transform(ctx: task.OrchestrationContext, inp: dict): + return {"echo": inp["data"], "extra": "A" * 3000} + + with worker.TaskHubGrpcWorker(host_address=HOST, payload_store=payload_store) as w: + w.add_orchestrator(transform) + w.start() + + c = client.TaskHubGrpcClient(host_address=HOST, payload_store=payload_store) + inst_id = c.schedule_new_orchestration(transform, input=large_input) + state = c.wait_for_orchestration_completion(inst_id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output is not None + output = json.loads(state.serialized_output) + assert output["echo"] == large_input["data"] + + +class TestLargeEvents: + """External events carrying large payloads.""" + + def test_large_event_data(self, payload_store): + """A large external event payload is externalized and resolved.""" + large_event = 
_make_large_string(5) + + def wait_for_event(ctx: task.OrchestrationContext, _): + result = yield ctx.wait_for_external_event("big_event") + return result + + with worker.TaskHubGrpcWorker(host_address=HOST, payload_store=payload_store) as w: + w.add_orchestrator(wait_for_event) + w.start() + + c = client.TaskHubGrpcClient(host_address=HOST, payload_store=payload_store) + inst_id = c.schedule_new_orchestration(wait_for_event) + c.wait_for_orchestration_start(inst_id, timeout=10) + + c.raise_orchestration_event(inst_id, "big_event", data=large_event) + state = c.wait_for_orchestration_completion(inst_id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output is not None + assert json.loads(state.serialized_output) == large_event + + +class TestLargeTerminate: + """Terminate with a large output payload.""" + + def test_terminate_with_large_output(self, payload_store): + """Terminating with a large output externalizes it.""" + large_output = _make_large_string(3) + + def long_running(ctx: task.OrchestrationContext, _): + yield ctx.wait_for_external_event("never_arrives") + + with worker.TaskHubGrpcWorker(host_address=HOST, payload_store=payload_store) as w: + w.add_orchestrator(long_running) + w.start() + + c = client.TaskHubGrpcClient(host_address=HOST, payload_store=payload_store) + inst_id = c.schedule_new_orchestration(long_running) + c.wait_for_orchestration_start(inst_id, timeout=10) + + c.terminate_orchestration(inst_id, output=large_output) + state = c.wait_for_orchestration_completion(inst_id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.TERMINATED + + +class TestMultipleActivitiesLargePayloads: + """Fan-out/fan-in with multiple large activity results.""" + + def test_fan_out_fan_in_large_results(self, payload_store): + """Multiple activities each return large payloads.""" + def make_large(_: task.ActivityContext, 
idx: int) -> str: + return f"result-{idx}-" + ("D" * 2000) + + def fan_out(ctx: task.OrchestrationContext, count: int): + tasks = [ctx.call_activity(make_large, input=i) for i in range(count)] + results = yield task.when_all(tasks) + return results + + with worker.TaskHubGrpcWorker(host_address=HOST, payload_store=payload_store) as w: + w.add_orchestrator(fan_out) + w.add_activity(make_large) + w.start() + + c = client.TaskHubGrpcClient(host_address=HOST, payload_store=payload_store) + inst_id = c.schedule_new_orchestration(fan_out, input=5) + state = c.wait_for_orchestration_completion(inst_id, timeout=60) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output is not None + results = json.loads(state.serialized_output) + assert len(results) == 5 + for i, r in enumerate(results): + assert r.startswith(f"result-{i}-") + + +class TestBlobStorageVerification: + """Verify blobs actually land in Azurite storage.""" + + def test_blobs_created_in_container(self, payload_store): + """After an orchestration with large payloads, blobs exist in the container.""" + large_input = _make_large_string(5) + + def echo(ctx: task.OrchestrationContext, inp: str): + return inp + + with worker.TaskHubGrpcWorker(host_address=HOST, payload_store=payload_store) as w: + w.add_orchestrator(echo) + w.start() + + c = client.TaskHubGrpcClient(host_address=HOST, payload_store=payload_store) + inst_id = c.schedule_new_orchestration(echo, input=large_input) + c.wait_for_orchestration_completion(inst_id, timeout=30) + + # Verify blobs were actually created in the Azurite container + svc = azure_blob.BlobServiceClient.from_connection_string( + AZURITE_CONN_STR, api_version=AZURITE_API_VERSION, + ) + container_client = svc.get_container_client(TEST_CONTAINER) + blobs = list(container_client.list_blobs()) + + assert len(blobs) > 0, "Expected at least one blob in the container" + # All blobs should contain compressed data 
(GZip header: 1f 8b) + for blob in blobs: + data = container_client.download_blob(blob.name).readall() + assert data[:2] == b"\x1f\x8b", f"Blob {blob.name} is not GZip-compressed" + + +class TestSmallPayloadNotExternalized: + """Payloads below the threshold should NOT hit blob storage.""" + + def test_small_payload_stays_inline(self, payload_store): + """A small payload should not create any blobs.""" + small_input = "hello" + + # Use a fresh container to isolate blob count + fresh_container = f"small-test-{uuid.uuid4().hex[:8]}" + store = BlobPayloadStore(BlobPayloadStoreOptions( + connection_string=AZURITE_CONN_STR, + container_name=fresh_container, + threshold_bytes=THRESHOLD_BYTES, + enable_compression=True, + api_version=AZURITE_API_VERSION, + )) + + def echo(ctx: task.OrchestrationContext, inp: str): + return inp + + with worker.TaskHubGrpcWorker(host_address=HOST, payload_store=store) as w: + w.add_orchestrator(echo) + w.start() + + c = client.TaskHubGrpcClient(host_address=HOST, payload_store=store) + inst_id = c.schedule_new_orchestration(echo, input=small_input) + state = c.wait_for_orchestration_completion(inst_id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output is not None + assert json.loads(state.serialized_output) == small_input + + # Verify no blobs were created + svc = azure_blob.BlobServiceClient.from_connection_string( + AZURITE_CONN_STR, api_version=AZURITE_API_VERSION, + ) + container_client = svc.get_container_client(fresh_container) + try: + blobs = list(container_client.list_blobs()) + assert len(blobs) == 0, f"Expected 0 blobs but found {len(blobs)}" + except Exception: + pass # Container may not even exist — that's fine + finally: + try: + svc.delete_container(fresh_container) + except Exception: + pass + + +class TestAsyncBlobClient: + """Test upload_async / download_async against Azurite.""" + + @pytest.fixture() + def async_store(self): + 
"""Per-test BlobPayloadStore with a unique container.""" + container = f"async-test-{uuid.uuid4().hex[:8]}" + store = BlobPayloadStore(BlobPayloadStoreOptions( + connection_string=AZURITE_CONN_STR, + container_name=container, + threshold_bytes=THRESHOLD_BYTES, + enable_compression=True, + api_version=AZURITE_API_VERSION, + )) + yield store + try: + svc = azure_blob.BlobServiceClient.from_connection_string( + AZURITE_CONN_STR, api_version=AZURITE_API_VERSION, + ) + svc.delete_container(container) + except Exception: + pass + + @pytest.mark.asyncio + async def test_async_upload_and_download_round_trip(self, async_store): + """upload_async stores data that download_async can retrieve.""" + payload = b"async round-trip payload " * 200 + token = await async_store.upload_async(payload, instance_id="async-1") + + assert async_store.is_known_token(token) + result = await async_store.download_async(token) + assert result == payload + + @pytest.mark.asyncio + async def test_async_upload_with_compression(self, async_store): + """Compressed upload should still decompress on download.""" + payload = b"Z" * 5000 + token = await async_store.upload_async(payload) + + downloaded = await async_store.download_async(token) + assert downloaded == payload + + @pytest.mark.asyncio + async def test_async_upload_instance_id_scopes_blob(self, async_store): + """Blobs uploaded with instance_id are scoped under that prefix.""" + payload = b"scoped payload" + token = await async_store.upload_async(payload, instance_id="inst-42") + + # Token format: blob:v1::/ + _, blob_name = BlobPayloadStore._parse_token(token) + assert blob_name.startswith("inst-42/") + + downloaded = await async_store.download_async(token) + assert downloaded == payload