diff --git a/durabletask/client.py b/durabletask/client.py index aa8ab55e..8ba9ed81 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -284,6 +284,26 @@ def resume_orchestration(self, instance_id: str) -> None: self._logger.info(f"Resuming instance '{instance_id}'.") self._stub.ResumeInstance(req) + def rewind_orchestration(self, instance_id: str, *, + reason: Optional[str] = None): + """Rewinds a failed orchestration instance to its last known good state. + + Rewind removes failed task and sub-orchestration results from the + orchestration history and replays the orchestration from the last + successful checkpoint. Activities that previously succeeded are + not re-executed; only failed work is retried. + + Args: + instance_id: The ID of the orchestration instance to rewind. + reason: An optional reason string describing why the orchestration is being rewound. + """ + req = pb.RewindInstanceRequest( + instanceId=instance_id, + reason=helpers.get_string_value(reason)) + + self._logger.info(f"Rewinding instance '{instance_id}'.") + self._stub.RewindInstance(req) + def restart_orchestration(self, instance_id: str, *, restart_with_new_instance_id: bool = False) -> str: """Restarts an existing orchestration instance. 
diff --git a/durabletask/internal/helpers.py b/durabletask/internal/helpers.py index 03da314c..a91d02cd 100644 --- a/durabletask/internal/helpers.py +++ b/durabletask/internal/helpers.py @@ -183,6 +183,21 @@ def new_terminated_event(*, encoded_output: Optional[str] = None) -> pb.HistoryE ) +def new_execution_completed_event( + status: 'pb.OrchestrationStatus', + encoded_result: Optional[str] = None, + failure_details: Optional['pb.TaskFailureDetails'] = None) -> pb.HistoryEvent: + return pb.HistoryEvent( + eventId=-1, + timestamp=timestamp_pb2.Timestamp(), + executionCompleted=pb.ExecutionCompletedEvent( + orchestrationStatus=status, + result=get_string_value(encoded_result), + failureDetails=failure_details, + ) + ) + + def get_string_value(val: Optional[str]) -> Optional[wrappers_pb2.StringValue]: if val is None: return None diff --git a/durabletask/testing/README.md b/durabletask/testing/README.md index 0a6e9855..06ed30ac 100644 --- a/durabletask/testing/README.md +++ b/durabletask/testing/README.md @@ -255,8 +255,7 @@ The in-memory backend is designed for testing and has some limitations compared 1. **No persistence**: All state is lost when the backend is stopped 2. **No distributed execution**: Runs in a single process 3. **No history streaming**: StreamInstanceHistory is not implemented -4. **No rewind**: RewindInstance is not implemented -5. **No recursive termination**: Recursive termination is not supported +4. 
**No recursive termination**: Recursive termination is not supported ### Best Practices diff --git a/durabletask/testing/in_memory_backend.py b/durabletask/testing/in_memory_backend.py index 590688ad..4e9e4518 100644 --- a/durabletask/testing/in_memory_backend.py +++ b/durabletask/testing/in_memory_backend.py @@ -241,14 +241,16 @@ def StartInstance(self, request: pb.CreateInstanceRequest, context): ) self._next_completion_token += 1 - # Add initial events to start the orchestration - orchestrator_started = helpers.new_orchestrator_started_event(start_time) + # Add initial events to start the orchestration. + # orchestratorStarted bookends each replay batch and is + # always the very first event, followed by executionStarted. execution_started = helpers.new_execution_started_event( request.name, instance_id, request.input.value if request.input else None, dict(request.tags) if request.tags else None, version=version, ) + orchestrator_started = helpers.new_orchestrator_started_event(start_time) instance.pending_events.append(orchestrator_started) instance.pending_events.append(execution_started) @@ -612,6 +614,22 @@ def CompleteOrchestratorTask(self, request: pb.OrchestratorResponse, context): instance.completion_token = self._next_completion_token self._next_completion_token += 1 + # Bookend the replay with orchestratorCompleted. + # Skip for continue-as-new (status is PENDING after reset). + if instance.status != pb.ORCHESTRATION_STATUS_PENDING: + instance.history.append(helpers.new_orchestrator_completed_event()) + + # executionCompleted is the very last event when the + # orchestration reaches a terminal state. 
+ if self._is_terminal_status(instance.status): + instance.history.append( + helpers.new_execution_completed_event( + instance.status, + instance.output, + instance.failure_details, + ) + ) + # Remove from in-flight before notifying or re-enqueuing self._orchestration_in_flight.discard(request.instanceId) @@ -981,8 +999,33 @@ def DeleteTaskHub(self, request: pb.DeleteTaskHubRequest, context): return pb.DeleteTaskHubResponse() def RewindInstance(self, request: pb.RewindInstanceRequest, context): - """Rewinds an orchestration instance (not implemented).""" - context.abort(grpc.StatusCode.UNIMPLEMENTED, "RewindInstance not implemented") + """Rewinds a failed orchestration instance. + + The backend validates the instance is in a failed state, appends + an ``ExecutionRewoundEvent`` to the pending events, resets the + instance status to RUNNING, and re-enqueues the orchestration + so the worker can replay it and produce a + ``RewindOrchestrationAction`` with the corrected history. + """ + with self._lock: + instance = self._instances.get(request.instanceId) + if not instance: + context.abort( + grpc.StatusCode.NOT_FOUND, + f"Orchestration instance '{request.instanceId}' not found") + return pb.RewindInstanceResponse() + + if instance.status != pb.ORCHESTRATION_STATUS_FAILED: + context.abort( + grpc.StatusCode.FAILED_PRECONDITION, + f"Orchestration instance '{request.instanceId}' is not in a failed state") + return pb.RewindInstanceResponse() + + reason = request.reason.value if request.HasField("reason") else None + self._prepare_rewind(instance, reason) + + self._logger.info(f"Rewound instance '{request.instanceId}'") + return pb.RewindInstanceResponse() def AbandonTaskActivityWorkItem(self, request: pb.AbandonActivityTaskRequest, context): """Abandons an activity work item.""" @@ -1067,9 +1110,9 @@ def _create_instance_internal(self, instance_id: str, name: str, ) self._next_completion_token += 1 - orchestrator_started = helpers.new_orchestrator_started_event(now) 
execution_started = helpers.new_execution_started_event( name, instance_id, encoded_input, version=version) + orchestrator_started = helpers.new_orchestrator_started_event(now) instance.pending_events.append(orchestrator_started) instance.pending_events.append(execution_started) @@ -1196,6 +1239,8 @@ def _process_action(self, instance: OrchestrationInstance, action: pb.Orchestrat self._process_send_event_action(action.sendEvent) elif action.HasField("sendEntityMessage"): self._process_send_entity_message_action(instance, action) + elif action.HasField("rewindOrchestration"): + self._process_rewind_orchestration_action(instance, action.rewindOrchestration) def _process_complete_orchestration_action(self, instance: OrchestrationInstance, complete_action: pb.CompleteOrchestrationAction): @@ -1230,11 +1275,11 @@ def _process_complete_orchestration_action(self, instance: OrchestrationInstance # Build the new pending events in the correct order: # OrchestratorStarted, ExecutionStarted, carryover, new arrivals now = datetime.now(timezone.utc) - orchestrator_started = helpers.new_orchestrator_started_event(now) execution_started = helpers.new_execution_started_event( instance.name, instance.instance_id, new_input, version=new_version, ) + orchestrator_started = helpers.new_orchestrator_started_event(now) instance.pending_events.append(orchestrator_started) instance.pending_events.append(execution_started) instance.pending_events.extend(carryover_events) @@ -1558,6 +1603,128 @@ def _signal_entity_internal(self, entity_id: str, operation: str, ) self._queue_entity_operation(entity_id, event) + def _prepare_rewind(self, instance: OrchestrationInstance, + reason: Optional[str] = None): + """Prepares an orchestration instance for rewind. + + Appends an ``ExecutionRewoundEvent`` to the pending events, resets + the instance status to RUNNING, and re-enqueues it so the worker + can replay it. 
The actual history rewriting is done by the SDK + worker when it processes the rewind event. + + Args: + instance: The orchestration instance to rewind. + reason: Optional reason string for the rewind. + + Note: + Must be called while holding ``self._lock``. + """ + # Reset instance state so it can be re-processed. + instance.status = pb.ORCHESTRATION_STATUS_RUNNING + instance.output = None + instance.failure_details = None + instance.last_updated_at = datetime.now(timezone.utc) + + # Clear any stale dispatched events. + instance.dispatched_events.clear() + + # Add the ExecutionRewound event as a new pending event. + rewind_event = pb.HistoryEvent( + eventId=-1, + timestamp=timestamp_pb2.Timestamp(), + executionRewound=pb.ExecutionRewoundEvent( + reason=wrappers_pb2.StringValue(value=reason) if reason else None, + ), + ) + instance.pending_events.append(rewind_event) + + # Refresh the completion token and enqueue. + instance.completion_token = self._next_completion_token + self._next_completion_token += 1 + self._orchestration_in_flight.discard(instance.instance_id) + self._enqueue_orchestration(instance.instance_id) + + def _process_rewind_orchestration_action( + self, instance: OrchestrationInstance, + rewind_action: pb.RewindOrchestrationAction): + """Processes a RewindOrchestrationAction returned by the SDK. + + The action contains a ``newHistory`` field with the rewritten + history computed by the SDK (failed tasks and sub-orchestration + failures removed). The backend replaces the instance's history + with this new history, recursively rewinds any failed + sub-orchestrations, and re-enqueues the orchestration. + """ + new_history = list(rewind_action.newHistory) + + # Replace history with the rewritten version. 
+ instance.history = new_history + instance.status = pb.ORCHESTRATION_STATUS_RUNNING + instance.output = None + instance.failure_details = None + instance.last_updated_at = datetime.now(timezone.utc) + + # Identify sub-orchestrations that were created but did not + # complete successfully — they need to be recursively rewound. + completed_sub_orch_task_ids: set[int] = set() + created_sub_orch_events: dict[int, pb.HistoryEvent] = {} + for event in new_history: + if event.HasField("subOrchestrationInstanceCreated"): + created_sub_orch_events[event.eventId] = event + elif event.HasField("subOrchestrationInstanceCompleted"): + completed_sub_orch_task_ids.add( + event.subOrchestrationInstanceCompleted.taskScheduledId) + + # Extract the rewind reason from the last ExecutionRewound event. + reason: Optional[str] = None + for event in reversed(new_history): + if event.HasField("executionRewound"): + if event.executionRewound.HasField("reason"): + reason = event.executionRewound.reason.value + break + + # Recursively rewind failed sub-orchestrations. If the sub was + # purged (no longer in _instances), re-create it from the + # subOrchestrationInstanceCreated event so it runs fresh. + for task_id, event in created_sub_orch_events.items(): + if task_id not in completed_sub_orch_task_ids: + sub_info = event.subOrchestrationInstanceCreated + sub_instance_id = sub_info.instanceId + sub_instance = self._instances.get(sub_instance_id) + if sub_instance is None: + # Sub-orchestration was purged — re-create it. 
+ sub_name = sub_info.name + sub_input = sub_info.input.value if sub_info.HasField("input") else None + sub_version = sub_info.version.value if sub_info.HasField("version") else None + self._create_instance_internal( + sub_instance_id, sub_name, sub_input, version=sub_version) + elif sub_instance.status == pb.ORCHESTRATION_STATUS_FAILED: + self._prepare_rewind(sub_instance, reason) + self._watch_sub_orchestration( + instance.instance_id, sub_instance_id, task_id) + + # Re-enqueue so the orchestration replays with the clean history. + # The executionRewound event is added to pending_events so the + # worker can see it in new_events; the worker uses the presence + # of executionRewound in old_events (history) to distinguish + # this normal post-rewind replay from the initial rewind + # short-circuit. Note: we do NOT add orchestratorStarted here + # because the work-item dispatch loop already inserts one when + # the instance has non-empty history. + rewind_event = None + for event in new_history: + if event.HasField("executionRewound"): + rewind_event = event + break + instance.pending_events.clear() + instance.dispatched_events.clear() + if rewind_event is not None: + instance.pending_events.append(rewind_event) + instance.completion_token = self._next_completion_token + self._next_completion_token += 1 + self._orchestration_in_flight.discard(instance.instance_id) + self._enqueue_orchestration(instance.instance_id) + def _enqueue_entity(self, entity_id: str): """Enqueues an entity for processing.""" if entity_id not in self._entity_queue_set: diff --git a/durabletask/worker.py b/durabletask/worker.py index 9c7f2d46..8c55ffa3 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -1462,6 +1462,26 @@ def execute( "The new history event list must have at least one event in it." ) + # Check for rewind BEFORE replay. A rewind is indicated by an + # executionRewound event in new_events. 
We look for an + # executionCompleted event anywhere in the history (old or new + # events) to decide whether to rewind or replay: + # 1. executionCompleted IS present → the orchestration reached a + # terminal state (e.g. failed). This is a *new* rewind that + # the worker must short-circuit by building clean history. + # 2. executionCompleted is NOT present → the backend already + # processed the RewindOrchestrationAction which removed + # executionCompleted from history. This is a normal + # post-rewind replay. + has_rewind_in_new = any( + e.HasField("executionRewound") for e in new_events + ) + if has_rewind_in_new and any(e.HasField("executionCompleted") for e in old_events): + # The orchestration completed (with failure) and needs + # rewinding — short-circuit to build clean history. + return self._build_rewind_result( + instance_id, orchestration_name, old_events, new_events) + ctx = _RuntimeOrchestrationContext(instance_id, self._registry) try: # Rebuild local state by replaying old history into the orchestrator function @@ -1523,6 +1543,97 @@ def execute( orchestration_trace_context=ctx._orchestration_trace_context, ) + def _build_rewind_result( + self, + instance_id: str, + orchestration_name: str, + old_events: Sequence[pb.HistoryEvent], + new_events: Sequence[pb.HistoryEvent], + ) -> ExecutionResults: + """Build an ``ExecutionResults`` containing a ``RewindOrchestrationAction``. + + When the worker detects an ``executionRewound`` event in the new + events (that does not yet appear in the committed history) it + rewrites the history by removing failed task results + (``taskFailed``) and failed sub-orchestration results + (``subOrchestrationInstanceFailed``). The ``executionRewound`` + event is kept so the backend knows why the rewind happened and + so it remains in the history for audit purposes. 
+ + For failed activities, the corresponding ``taskScheduled`` event + is also removed so that the SDK will re-generate a + ``scheduleTask`` action during the next replay, causing the + backend to re-dispatch the activity. + + For failed sub-orchestrations, the ``subOrchestrationInstanceCreated`` + event is kept so the backend can identify which sub-orchestration + instances to recursively rewind. + + WARNING!!!: + If any changes are made to how this method modifies the orchestration's history, then corresponding changes *must* + be made in the backend implementations that rely on this method for executing a rewind. + """ + self._logger.info( + f"{instance_id}: Orchestration {orchestration_name} is being rewound" + ) + + if len(new_events) != 2 or not new_events[1].HasField("executionRewound"): + raise ValueError( + "When rewinding an orchestration, the new events list must contain exactly two events: orchestratorStarted and the executionRewound event." + ) + + rewind_event: pb.ExecutionRewoundEvent = new_events[1].executionRewound + + all_events = list(old_events) + list(new_events) + # Generate a new execution ID for the rewound execution. + new_execution_id = uuid.uuid4().hex + + # First pass: collect the task-scheduled IDs that correspond to + # failed activities / sub-orchestrations so we can remove the + # matching taskScheduled events in the second pass. + failed_task_ids: set[int] = set() + for event in all_events: + if event.HasField("taskFailed"): + failed_task_ids.add(event.taskFailed.taskScheduledId) + elif event.HasField("subOrchestrationInstanceFailed"): + failed_task_ids.add(event.subOrchestrationInstanceFailed.taskScheduledId) + + # Second pass: build the clean history. 
+ clean_history: list[pb.HistoryEvent] = [] + for event in all_events: + if event.HasField("taskFailed"): + continue + if event.HasField("taskScheduled") and event.eventId in failed_task_ids: + continue + if event.HasField("subOrchestrationInstanceFailed"): + continue + if event.HasField("executionCompleted"): + continue + + # Modify the executionStarted event: assign a fresh + # execution ID and, for sub-orchestrations, update the + # parent's execution ID so it matches the parent's new run. + if event.HasField("executionStarted"): + event_copy = pb.HistoryEvent() + event_copy.CopyFrom(event) + event_copy.executionStarted.orchestrationInstance.executionId.CopyFrom( + ph.get_string_value_or_empty(new_execution_id)) + if rewind_event.HasField("parentExecutionId"): + if rewind_event.parentExecutionId.value: + event_copy.executionStarted.parentInstance.orchestrationInstance.executionId.CopyFrom( + rewind_event.parentExecutionId) + clean_history.append(event_copy) + continue + + clean_history.append(event) + + rewind_action = pb.RewindOrchestrationAction(newHistory=clean_history) + action = pb.OrchestratorAction( + id=-1, + rewindOrchestration=rewind_action, + ) + return ExecutionResults(actions=[action], encoded_custom_status=None) + def process_event( self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEvent ) -> None: @@ -2049,7 +2160,13 @@ def process_event( entity_task.fail(str(failure), failure) ctx.resume() elif event.HasField("orchestratorCompleted"): - # Added in Functions only (for some reason) and does not affect orchestrator flow + # Bookend event for each replay batch — no action needed. + pass + elif event.HasField("executionCompleted"): + # Terminal marker event — in practice, this never appears during replay. + pass + elif event.HasField("executionRewound"): + # Informational event added when an orchestration is rewound. No action needed. 
pass elif event.HasField("eventSent"): # Check if this eventSent corresponds to an entity operation call after being translated to the old diff --git a/tests/durabletask-azuremanaged/test_dts_rewind_e2e.py b/tests/durabletask-azuremanaged/test_dts_rewind_e2e.py new file mode 100644 index 00000000..ee989a78 --- /dev/null +++ b/tests/durabletask-azuremanaged/test_dts_rewind_e2e.py @@ -0,0 +1,358 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import json +import os + +import pytest + +from durabletask import client, task +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + +# NOTE: These tests assume a sidecar process is running. Example command: +# docker run -i -p 8080:8080 -p 8082:8082 -d mcr.microsoft.com/dts/dts-emulator:latest +pytestmark = [ + pytest.mark.dts, + pytest.mark.skip(reason="Rewind support is not yet available in the public DTS emulator"), +] + +# Read the environment variables +taskhub_name = os.getenv("TASKHUB", "default") +endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + +def _get_credential(): + """Returns DefaultAzureCredential if endpoint is https, otherwise None (for emulator).""" + if endpoint.startswith("https://"): + from azure.identity import DefaultAzureCredential + return DefaultAzureCredential() + return None + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_rewind_failed_activity(): + """Rewind a failed orchestration whose single activity failed. + + After rewind the activity succeeds and the orchestration completes. 
+ """ + activity_call_count = 0 + should_fail = True + + def failing_activity(_: task.ActivityContext, input: str) -> str: + nonlocal activity_call_count + activity_call_count += 1 + if should_fail: + raise RuntimeError("Simulated failure") + return f"Hello, {input}!" + + def orchestrator(ctx: task.OrchestrationContext, input: str): + result = yield ctx.call_activity(failing_activity, input=input) + return result + + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(orchestrator) + w.add_activity(failing_activity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + instance_id = c.schedule_new_orchestration(orchestrator, input="World") + state = c.wait_for_orchestration_completion(instance_id, timeout=30) + + # The orchestration should have failed. + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.FAILED + + # Fix the activity so it now succeeds, then rewind. + should_fail = False + c.rewind_orchestration(instance_id, reason="retry after fix") + + state = c.wait_for_orchestration_completion(instance_id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps("Hello, World!") + assert state.failure_details is None + # Activity was called twice (once failed, once succeeded after rewind). 
+ assert activity_call_count == 2 + + +def test_rewind_preserves_successful_results(): + """When an orchestration has a mix of successful and failed activities, + rewind should re-execute only the failed activity while the successful + result is replayed from history.""" + call_tracker: dict[str, int] = {"first": 0, "second": 0} + should_fail_second = True + + def first_activity(_: task.ActivityContext, input: str) -> str: + call_tracker["first"] += 1 + return f"first:{input}" + + def second_activity(_: task.ActivityContext, input: str) -> str: + call_tracker["second"] += 1 + if should_fail_second: + raise RuntimeError("Temporary failure") + return f"second:{input}" + + def orchestrator(ctx: task.OrchestrationContext, input: str): + r1 = yield ctx.call_activity(first_activity, input=input) + r2 = yield ctx.call_activity(second_activity, input=input) + return [r1, r2] + + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(orchestrator) + w.add_activity(first_activity) + w.add_activity(second_activity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + instance_id = c.schedule_new_orchestration(orchestrator, input="test") + state = c.wait_for_orchestration_completion(instance_id, timeout=30) + + # The orchestration should have failed (second_activity fails). + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.FAILED + + # Fix second_activity so it now succeeds, then rewind. 
+ should_fail_second = False + c.rewind_orchestration(instance_id, reason="retry") + state = c.wait_for_orchestration_completion(instance_id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps(["first:test", "second:test"]) + assert state.failure_details is None + # first_activity should NOT be re-executed – its result is replayed. + assert call_tracker["first"] == 1 + # second_activity was called at least twice (once failed, once succeeded). + assert call_tracker["second"] >= 2 + + +def test_rewind_not_found(): + """Rewinding a non-existent instance should raise an RPC error.""" + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + with pytest.raises(Exception): + c.rewind_orchestration("nonexistent-id") + + +def test_rewind_non_failed_instance(): + """Rewinding a completed (non-failed) instance should raise an error.""" + def orchestrator(ctx: task.OrchestrationContext, _): + return "done" + + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(orchestrator) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + instance_id = c.schedule_new_orchestration(orchestrator) + state = c.wait_for_orchestration_completion(instance_id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + + with pytest.raises(Exception): + c.rewind_orchestration(instance_id) + + +def test_rewind_with_sub_orchestration(): + """Rewind should recursively rewind failed sub-orchestrations.""" + sub_call_count = 0 + + def child_activity(_: task.ActivityContext, input: str) -> str: + nonlocal sub_call_count + sub_call_count += 1 + if sub_call_count == 1: + raise RuntimeError("Child 
failure") + return f"child:{input}" + + def child_orchestrator(ctx: task.OrchestrationContext, input: str): + result = yield ctx.call_activity(child_activity, input=input) + return result + + def parent_orchestrator(ctx: task.OrchestrationContext, input: str): + result = yield ctx.call_sub_orchestrator( + child_orchestrator, input=input) + return f"parent:{result}" + + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(parent_orchestrator) + w.add_orchestrator(child_orchestrator) + w.add_activity(child_activity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + instance_id = c.schedule_new_orchestration( + parent_orchestrator, input="data") + state = c.wait_for_orchestration_completion(instance_id, timeout=30) + + # Parent should fail because child failed. + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.FAILED + + # Rewind – child_activity will succeed on retry. + c.rewind_orchestration(instance_id, reason="sub-orch fix") + state = c.wait_for_orchestration_completion(instance_id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps("parent:child:data") + assert sub_call_count == 2 + + +def test_rewind_purged_sub_orchestration(): + """A purged sub-orchestration is re-run when the parent is rewound. + + Flow: parent orchestrator -> calls sub-orchestrator -> sub-orchestrator + fails -> parent fails -> client purges the sub-orchestration -> client + rewinds the parent -> parent re-schedules the sub-orchestration which + now succeeds -> parent completes. 
+ """ + child_call_count = 0 + + def child_activity(_: task.ActivityContext, input: str) -> str: + nonlocal child_call_count + child_call_count += 1 + if child_call_count == 1: + raise RuntimeError("Child failure") + return f"child:{input}" + + def child_orchestrator(ctx: task.OrchestrationContext, input: str): + result = yield ctx.call_activity(child_activity, input=input) + return result + + def parent_orchestrator(ctx: task.OrchestrationContext, input: str): + result = yield ctx.call_sub_orchestrator( + child_orchestrator, input=input, instance_id="sub-orch-to-purge") + return f"parent:{result}" + + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(parent_orchestrator) + w.add_orchestrator(child_orchestrator) + w.add_activity(child_activity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + instance_id = c.schedule_new_orchestration( + parent_orchestrator, input="data") + state = c.wait_for_orchestration_completion(instance_id, timeout=30) + + # Parent should fail because child failed. + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.FAILED + + # Purge the sub-orchestration so it must be completely re-run. + c.purge_orchestration("sub-orch-to-purge") + + # Rewind the parent – child will be re-scheduled and succeed. 
+ c.rewind_orchestration(instance_id, reason="purge and retry") + state = c.wait_for_orchestration_completion(instance_id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps("parent:child:data") + assert child_call_count == 2 + + +def test_rewind_without_reason(): + """Rewind should work when no reason is provided.""" + call_count = 0 + + def flaky_activity(_: task.ActivityContext, _1) -> str: + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RuntimeError("Boom") + return "ok" + + def orchestrator(ctx: task.OrchestrationContext, _): + result = yield ctx.call_activity(flaky_activity) + return result + + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(orchestrator) + w.add_activity(flaky_activity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + instance_id = c.schedule_new_orchestration(orchestrator) + state = c.wait_for_orchestration_completion(instance_id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.FAILED + + # Rewind without a reason + c.rewind_orchestration(instance_id) + state = c.wait_for_orchestration_completion(instance_id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps("ok") + + +def test_rewind_twice(): + """Rewind the same orchestration twice after it fails a second time. + + The first rewind cleans up the initial failure. The activity then + fails again. A second rewind should clean up the new failure and + the orchestration should eventually complete. 
+ """ + call_count = 0 + + def flaky_activity(_: task.ActivityContext, input: str) -> str: + nonlocal call_count + call_count += 1 + # Fail on the 1st and 2nd calls; succeed on the 3rd. + if call_count <= 2: + raise RuntimeError(f"Failure #{call_count}") + return f"Hello, {input}!" + + def orchestrator(ctx: task.OrchestrationContext, input: str): + result = yield ctx.call_activity(flaky_activity, input=input) + return result + + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(orchestrator) + w.add_activity(flaky_activity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + instance_id = c.schedule_new_orchestration(orchestrator, input="World") + state = c.wait_for_orchestration_completion(instance_id, timeout=30) + + # First failure. + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.FAILED + + # First rewind — activity will fail again (call_count == 2). + c.rewind_orchestration(instance_id, reason="first rewind") + state = c.wait_for_orchestration_completion(instance_id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.FAILED + + # Second rewind — activity will succeed (call_count == 3). + c.rewind_orchestration(instance_id, reason="second rewind") + state = c.wait_for_orchestration_completion(instance_id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps("Hello, World!") + assert call_count == 3 diff --git a/tests/durabletask/test_build_rewind_result.py b/tests/durabletask/test_build_rewind_result.py new file mode 100644 index 00000000..68c74285 --- /dev/null +++ b/tests/durabletask/test_build_rewind_result.py @@ -0,0 +1,700 @@ +# Copyright (c) Microsoft Corporation. 
+# Licensed under the MIT License. + +"""Unit tests for _OrchestrationExecutor._build_rewind_result. + +These tests directly invoke the rewind history-rewriting logic and +verify that the clean history produced by the worker matches the +expected semantics: + +* The ``executionStarted`` event gets a new execution ID. +* When a ``parentExecutionId`` is present on the rewind event, + the ``parentInstance.orchestrationInstance.executionId`` on the + ``executionStarted`` copy is updated accordingly. +* ``taskFailed`` events and their corresponding ``taskScheduled`` + events are removed. +* ``subOrchestrationInstanceFailed`` events and their corresponding + ``taskScheduled`` events are removed. +* ``subOrchestrationInstanceCreated`` events for failed sub-orchestrations + are kept so the backend can recursively rewind them. +* ``executionCompleted`` events are removed. +* Successful activity/sub-orchestration results are preserved. +""" + +import json +import logging + +import durabletask.internal.helpers as helpers +import durabletask.internal.orchestrator_service_pb2 as pb +from durabletask import task, worker +from google.protobuf import wrappers_pb2 + +logging.basicConfig( + format='%(asctime)s.%(msecs)03d %(name)s %(levelname)s: %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + level=logging.DEBUG) +TEST_LOGGER = logging.getLogger("tests") + +TEST_INSTANCE_ID = "rewind-test-instance" +ORIGINAL_EXECUTION_ID = "original-exec-id" + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_execution_started( + name: str, + instance_id: str = TEST_INSTANCE_ID, + execution_id: str = ORIGINAL_EXECUTION_ID, + parent_instance: pb.ParentInstanceInfo | None = None, +) -> pb.HistoryEvent: + """Create an executionStarted event with an explicit execution ID.""" + event = pb.HistoryEvent( + eventId=-1, + executionStarted=pb.ExecutionStartedEvent( + name=name, 
+            orchestrationInstance=pb.OrchestrationInstance( +                instanceId=instance_id, +                executionId=wrappers_pb2.StringValue(value=execution_id), +            ), +        ), +    ) +    if parent_instance is not None: +        event.executionStarted.parentInstance.CopyFrom(parent_instance) +    return event + + +def _make_execution_rewound( +    reason: str = "test rewind", +    parent_execution_id: str | None = None, +) -> pb.HistoryEvent: +    """Create an executionRewound history event.""" +    rewound = pb.ExecutionRewoundEvent( +        reason=wrappers_pb2.StringValue(value=reason), +    ) +    if parent_execution_id is not None: +        rewound.parentExecutionId.CopyFrom( +            wrappers_pb2.StringValue(value=parent_execution_id)) +    return pb.HistoryEvent( +        eventId=-1, +        executionRewound=rewound, +    ) + + +def _dummy_orchestrator(ctx: task.OrchestrationContext, _): +    return None + + +def _make_executor() -> worker._OrchestrationExecutor: +    """Create a minimal _OrchestrationExecutor (only a dummy orchestrator is registered).""" +    registry = worker._Registry() +    registry.add_orchestrator(_dummy_orchestrator) +    return worker._OrchestrationExecutor(registry, TEST_LOGGER) + + +def _get_clean_history(result: worker.ExecutionResults) -> list[pb.HistoryEvent]: +    """Extract the clean history from a RewindOrchestrationAction result.""" +    assert len(result.actions) == 1 +    action = result.actions[0] +    assert action.HasField("rewindOrchestration") +    return list(action.rewindOrchestration.newHistory) + + +# --------------------------------------------------------------------------- +# Tests: execution ID changes +# --------------------------------------------------------------------------- + + +def test_rewind_assigns_new_execution_id(): +    """The executionStarted event in the clean history must have a new, +    different execution ID.""" +    executor = _make_executor() + +    old_events = [ +        helpers.new_orchestrator_started_event(), +        _make_execution_started("my_orch"), +        helpers.new_task_scheduled_event(1, "my_activity"), +        helpers.new_task_failed_event(1, 
RuntimeError("boom")), + helpers.new_orchestrator_completed_event(), + helpers.new_execution_completed_event( + pb.ORCHESTRATION_STATUS_FAILED, + failure_details=helpers.new_failure_details(RuntimeError("boom"))), + ] + new_events = [ + helpers.new_orchestrator_started_event(), + _make_execution_rewound("retry"), + ] + + result = executor._build_rewind_result( + TEST_INSTANCE_ID, "my_orch", old_events, new_events) + clean = _get_clean_history(result) + + # Find the executionStarted event + started_events = [e for e in clean if e.HasField("executionStarted")] + assert len(started_events) == 1 + + new_exec_id = started_events[0].executionStarted.orchestrationInstance.executionId.value + assert new_exec_id != ORIGINAL_EXECUTION_ID + assert len(new_exec_id) > 0 # must be a non-empty string + + +def test_rewind_preserves_execution_started_fields(): + """The executionStarted copy should preserve the original fields + (name, instance ID) while changing only the execution ID.""" + executor = _make_executor() + + old_events = [ + helpers.new_orchestrator_started_event(), + _make_execution_started("preserve_me"), + helpers.new_task_scheduled_event(1, "act"), + helpers.new_task_failed_event(1, RuntimeError("fail")), + helpers.new_orchestrator_completed_event(), + helpers.new_execution_completed_event( + pb.ORCHESTRATION_STATUS_FAILED, + failure_details=helpers.new_failure_details(RuntimeError("fail"))), + ] + new_events = [ + helpers.new_orchestrator_started_event(), + _make_execution_rewound("retry"), + ] + + result = executor._build_rewind_result( + TEST_INSTANCE_ID, "preserve_me", old_events, new_events) + clean = _get_clean_history(result) + + started = [e for e in clean if e.HasField("executionStarted")][0] + assert started.executionStarted.name == "preserve_me" + assert started.executionStarted.orchestrationInstance.instanceId == TEST_INSTANCE_ID + + +def test_rewind_updates_parent_execution_id(): + """When the rewind event carries a parentExecutionId, the + 
executionStarted copy must update + parentInstance.orchestrationInstance.executionId to match.""" + executor = _make_executor() + + parent_exec_id = "parent-old-exec-id" + parent_new_exec_id = "parent-new-exec-id" + + parent_info = pb.ParentInstanceInfo( + taskScheduledId=5, + name=wrappers_pb2.StringValue(value="parent_orch"), + orchestrationInstance=pb.OrchestrationInstance( + instanceId="parent-instance", + executionId=wrappers_pb2.StringValue(value=parent_exec_id), + ), + ) + + old_events = [ + helpers.new_orchestrator_started_event(), + _make_execution_started("child_orch", parent_instance=parent_info), + helpers.new_task_scheduled_event(1, "child_act"), + helpers.new_task_failed_event(1, RuntimeError("child fail")), + helpers.new_orchestrator_completed_event(), + helpers.new_execution_completed_event( + pb.ORCHESTRATION_STATUS_FAILED, + failure_details=helpers.new_failure_details(RuntimeError("child fail"))), + ] + new_events = [ + helpers.new_orchestrator_started_event(), + _make_execution_rewound("parent rewind", parent_execution_id=parent_new_exec_id), + ] + + result = executor._build_rewind_result( + TEST_INSTANCE_ID, "child_orch", old_events, new_events) + clean = _get_clean_history(result) + + started = [e for e in clean if e.HasField("executionStarted")][0] + # Parent execution ID should be updated to the new one. 
+ actual_parent_exec_id = ( + started.executionStarted.parentInstance + .orchestrationInstance.executionId.value + ) + assert actual_parent_exec_id == parent_new_exec_id + + +def test_rewind_no_parent_execution_id_leaves_parent_unchanged(): + """When the rewind event has no parentExecutionId, the executionStarted + copy should leave the parentInstance untouched (if any).""" + executor = _make_executor() + + parent_exec_id = "parent-exec-id-unchanged" + parent_info = pb.ParentInstanceInfo( + taskScheduledId=5, + name=wrappers_pb2.StringValue(value="parent_orch"), + orchestrationInstance=pb.OrchestrationInstance( + instanceId="parent-instance", + executionId=wrappers_pb2.StringValue(value=parent_exec_id), + ), + ) + + old_events = [ + helpers.new_orchestrator_started_event(), + _make_execution_started("child_orch", parent_instance=parent_info), + helpers.new_task_scheduled_event(1, "child_act"), + helpers.new_task_failed_event(1, RuntimeError("fail")), + helpers.new_orchestrator_completed_event(), + helpers.new_execution_completed_event( + pb.ORCHESTRATION_STATUS_FAILED, + failure_details=helpers.new_failure_details(RuntimeError("fail"))), + ] + new_events = [ + helpers.new_orchestrator_started_event(), + # No parentExecutionId on the rewind event. + _make_execution_rewound("top-level rewind"), + ] + + result = executor._build_rewind_result( + TEST_INSTANCE_ID, "child_orch", old_events, new_events) + clean = _get_clean_history(result) + + started = [e for e in clean if e.HasField("executionStarted")][0] + actual_parent_exec_id = ( + started.executionStarted.parentInstance + .orchestrationInstance.executionId.value + ) + # Should remain unchanged. 
+ assert actual_parent_exec_id == parent_exec_id + + +# --------------------------------------------------------------------------- +# Tests: failed activity cleanup +# --------------------------------------------------------------------------- + + +def test_rewind_removes_failed_activity_events(): + """taskFailed and its corresponding taskScheduled should be removed.""" + executor = _make_executor() + + old_events = [ + helpers.new_orchestrator_started_event(), + _make_execution_started("orch"), + helpers.new_task_scheduled_event(1, "my_act"), + helpers.new_task_failed_event(1, RuntimeError("boom")), + helpers.new_orchestrator_completed_event(), + helpers.new_execution_completed_event( + pb.ORCHESTRATION_STATUS_FAILED, + failure_details=helpers.new_failure_details(RuntimeError("boom"))), + ] + new_events = [ + helpers.new_orchestrator_started_event(), + _make_execution_rewound("retry"), + ] + + result = executor._build_rewind_result( + TEST_INSTANCE_ID, "orch", old_events, new_events) + clean = _get_clean_history(result) + + # No taskFailed or taskScheduled for the failed activity. 
+ assert not any(e.HasField("taskFailed") for e in clean) + assert not any( + e.HasField("taskScheduled") and e.eventId == 1 for e in clean + ) + + +def test_rewind_preserves_successful_activity(): + """Successful taskScheduled + taskCompleted should remain in clean history.""" + executor = _make_executor() + + old_events = [ + helpers.new_orchestrator_started_event(), + _make_execution_started("orch"), + # Activity 1 succeeds + helpers.new_task_scheduled_event(1, "good_act"), + helpers.new_task_completed_event(1, json.dumps("ok")), + helpers.new_orchestrator_completed_event(), + # Activity 2 fails + helpers.new_orchestrator_started_event(), + helpers.new_task_scheduled_event(2, "bad_act"), + helpers.new_task_failed_event(2, RuntimeError("fail")), + helpers.new_orchestrator_completed_event(), + helpers.new_execution_completed_event( + pb.ORCHESTRATION_STATUS_FAILED, + failure_details=helpers.new_failure_details(RuntimeError("fail"))), + ] + new_events = [ + helpers.new_orchestrator_started_event(), + _make_execution_rewound("retry"), + ] + + result = executor._build_rewind_result( + TEST_INSTANCE_ID, "orch", old_events, new_events) + clean = _get_clean_history(result) + + # Activity 1's taskScheduled and taskCompleted should still be present. + assert any( + e.HasField("taskScheduled") and e.eventId == 1 for e in clean + ) + assert any( + e.HasField("taskCompleted") and e.taskCompleted.taskScheduledId == 1 + for e in clean + ) + # Activity 2's taskScheduled and taskFailed should be removed. 
+ assert not any( + e.HasField("taskScheduled") and e.eventId == 2 for e in clean + ) + assert not any(e.HasField("taskFailed") for e in clean) + + +# --------------------------------------------------------------------------- +# Tests: failed sub-orchestration cleanup +# --------------------------------------------------------------------------- + + +def test_rewind_removes_failed_sub_orch_events(): + """subOrchestrationInstanceFailed and its corresponding taskScheduled + should be removed, but subOrchestrationInstanceCreated is kept.""" + executor = _make_executor() + + old_events = [ + helpers.new_orchestrator_started_event(), + _make_execution_started("parent_orch"), + helpers.new_sub_orchestration_created_event(1, "child_orch", "child-id"), + helpers.new_orchestrator_completed_event(), + helpers.new_orchestrator_started_event(), + helpers.new_sub_orchestration_failed_event(1, RuntimeError("child exploded")), + helpers.new_orchestrator_completed_event(), + helpers.new_execution_completed_event( + pb.ORCHESTRATION_STATUS_FAILED, + failure_details=helpers.new_failure_details(RuntimeError("child exploded"))), + ] + new_events = [ + helpers.new_orchestrator_started_event(), + _make_execution_rewound("rewind child"), + ] + + result = executor._build_rewind_result( + TEST_INSTANCE_ID, "parent_orch", old_events, new_events) + clean = _get_clean_history(result) + + # subOrchestrationInstanceFailed is removed. + assert not any( + e.HasField("subOrchestrationInstanceFailed") for e in clean + ) + # The corresponding taskScheduled (eventId == 1) used by + # subOrchestrationInstanceFailed should also be removed, since + # _build_rewind_result collects the taskScheduledId from + # subOrchestrationInstanceFailed too. + # Note: subOrchestrationInstanceCreated uses a separate event with + # its own eventId and is NOT removed. 
+ assert not any( + e.HasField("taskScheduled") and e.eventId == 1 for e in clean + ) + # The subOrchestrationInstanceCreated event should be preserved + # so the backend can identify which sub-orchestration to rewind. + assert any( + e.HasField("subOrchestrationInstanceCreated") for e in clean + ) + + +def test_rewind_preserves_successful_sub_orchestration(): + """Successful sub-orchestration events should be preserved.""" + executor = _make_executor() + + old_events = [ + helpers.new_orchestrator_started_event(), + _make_execution_started("parent_orch"), + # Sub-orch 1 succeeds + helpers.new_sub_orchestration_created_event(1, "child_ok", "child-ok-id"), + helpers.new_orchestrator_completed_event(), + helpers.new_orchestrator_started_event(), + helpers.new_sub_orchestration_completed_event(1, json.dumps("child result")), + # Sub-orch 2 fails + helpers.new_sub_orchestration_created_event(2, "child_fail", "child-fail-id"), + helpers.new_orchestrator_completed_event(), + helpers.new_orchestrator_started_event(), + helpers.new_sub_orchestration_failed_event(2, RuntimeError("fail")), + helpers.new_orchestrator_completed_event(), + helpers.new_execution_completed_event( + pb.ORCHESTRATION_STATUS_FAILED, + failure_details=helpers.new_failure_details(RuntimeError("fail"))), + ] + new_events = [ + helpers.new_orchestrator_started_event(), + _make_execution_rewound("retry"), + ] + + result = executor._build_rewind_result( + TEST_INSTANCE_ID, "parent_orch", old_events, new_events) + clean = _get_clean_history(result) + + # Sub-orch 1's created + completed should be present. + created_ids = [ + e.subOrchestrationInstanceCreated.instanceId + for e in clean if e.HasField("subOrchestrationInstanceCreated") + ] + assert "child-ok-id" in created_ids + completed_sub_ids = [ + e.subOrchestrationInstanceCompleted.taskScheduledId + for e in clean if e.HasField("subOrchestrationInstanceCompleted") + ] + assert 1 in completed_sub_ids + # Sub-orch 2's failed event should be removed. 
+ assert not any( + e.HasField("subOrchestrationInstanceFailed") for e in clean + ) + # Sub-orch 2's created event should be kept (for backend recursive rewind). + assert "child-fail-id" in created_ids + + +# --------------------------------------------------------------------------- +# Tests: executionCompleted removal +# --------------------------------------------------------------------------- + + +def test_rewind_removes_execution_completed(): + """executionCompleted events should be stripped from the clean history.""" + executor = _make_executor() + + old_events = [ + helpers.new_orchestrator_started_event(), + _make_execution_started("orch"), + helpers.new_task_scheduled_event(1, "act"), + helpers.new_task_failed_event(1, RuntimeError("fail")), + helpers.new_orchestrator_completed_event(), + helpers.new_execution_completed_event( + pb.ORCHESTRATION_STATUS_FAILED, + failure_details=helpers.new_failure_details(RuntimeError("fail"))), + ] + new_events = [ + helpers.new_orchestrator_started_event(), + _make_execution_rewound("retry"), + ] + + result = executor._build_rewind_result( + TEST_INSTANCE_ID, "orch", old_events, new_events) + clean = _get_clean_history(result) + + assert not any(e.HasField("executionCompleted") for e in clean) + + +# --------------------------------------------------------------------------- +# Tests: orchestratorStarted/Completed preservation +# --------------------------------------------------------------------------- + + +def test_rewind_keeps_orchestrator_started_and_completed(): + """orchestratorStarted and orchestratorCompleted bookend events + should be preserved in clean history.""" + executor = _make_executor() + + old_events = [ + helpers.new_orchestrator_started_event(), + _make_execution_started("orch"), + helpers.new_task_scheduled_event(1, "act"), + helpers.new_orchestrator_completed_event(), + helpers.new_orchestrator_started_event(), + helpers.new_task_failed_event(1, RuntimeError("fail")), + 
helpers.new_orchestrator_completed_event(), + helpers.new_execution_completed_event( + pb.ORCHESTRATION_STATUS_FAILED, + failure_details=helpers.new_failure_details(RuntimeError("fail"))), + ] + new_events = [ + helpers.new_orchestrator_started_event(), + _make_execution_rewound("retry"), + ] + + result = executor._build_rewind_result( + TEST_INSTANCE_ID, "orch", old_events, new_events) + clean = _get_clean_history(result) + + orch_started_count = sum( + 1 for e in clean if e.HasField("orchestratorStarted") + ) + orch_completed_count = sum( + 1 for e in clean if e.HasField("orchestratorCompleted") + ) + # old_events has 2 orchestratorStarted + 2 orchestratorCompleted, + # new_events adds 1 orchestratorStarted. All should be kept. + assert orch_started_count >= 2 + assert orch_completed_count >= 2 + + +# --------------------------------------------------------------------------- +# Tests: executionRewound event preserved +# --------------------------------------------------------------------------- + + +def test_rewind_keeps_execution_rewound_event(): + """The executionRewound event itself should remain in the clean + history so it is visible in the audit trail.""" + executor = _make_executor() + + old_events = [ + helpers.new_orchestrator_started_event(), + _make_execution_started("orch"), + helpers.new_task_scheduled_event(1, "act"), + helpers.new_task_failed_event(1, RuntimeError("fail")), + helpers.new_orchestrator_completed_event(), + helpers.new_execution_completed_event( + pb.ORCHESTRATION_STATUS_FAILED, + failure_details=helpers.new_failure_details(RuntimeError("fail"))), + ] + new_events = [ + helpers.new_orchestrator_started_event(), + _make_execution_rewound("rewind reason"), + ] + + result = executor._build_rewind_result( + TEST_INSTANCE_ID, "orch", old_events, new_events) + clean = _get_clean_history(result) + + rewound_events = [e for e in clean if e.HasField("executionRewound")] + assert len(rewound_events) == 1 + assert 
rewound_events[0].executionRewound.reason.value == "rewind reason" + + +# --------------------------------------------------------------------------- +# Tests: mixed scenario +# --------------------------------------------------------------------------- + + +def test_rewind_mixed_activities_and_sub_orchestrations(): + """A complex scenario with successful activities, failed activities, + successful sub-orchestrations, and failed sub-orchestrations. + Verifies that only failed items are cleaned and execution ID is updated.""" + executor = _make_executor() + + old_events = [ + helpers.new_orchestrator_started_event(), + _make_execution_started("complex_orch"), + # Activity 1 succeeds (eventId=1) + helpers.new_task_scheduled_event(1, "good_activity"), + helpers.new_task_completed_event(1, json.dumps("good")), + # Sub-orch A succeeds (eventId=2) + helpers.new_sub_orchestration_created_event(2, "child_ok", "child-ok-id"), + helpers.new_orchestrator_completed_event(), + helpers.new_orchestrator_started_event(), + helpers.new_sub_orchestration_completed_event(2, json.dumps("child ok result")), + # Activity 3 fails (eventId=3) + helpers.new_task_scheduled_event(3, "bad_activity"), + helpers.new_task_failed_event(3, RuntimeError("act fail")), + helpers.new_orchestrator_completed_event(), + # Sub-orch B fails (eventId=4) + helpers.new_orchestrator_started_event(), + helpers.new_sub_orchestration_created_event(4, "child_fail", "child-fail-id"), + helpers.new_orchestrator_completed_event(), + helpers.new_orchestrator_started_event(), + helpers.new_sub_orchestration_failed_event(4, RuntimeError("sub orch fail")), + helpers.new_orchestrator_completed_event(), + helpers.new_execution_completed_event( + pb.ORCHESTRATION_STATUS_FAILED, + failure_details=helpers.new_failure_details(RuntimeError("overall fail"))), + ] + new_events = [ + helpers.new_orchestrator_started_event(), + _make_execution_rewound("fix everything"), + ] + + result = executor._build_rewind_result( + 
TEST_INSTANCE_ID, "complex_orch", old_events, new_events) + clean = _get_clean_history(result) + + # --- Execution ID changed --- + started = [e for e in clean if e.HasField("executionStarted")] + assert len(started) == 1 + assert (started[0].executionStarted.orchestrationInstance + .executionId.value != ORIGINAL_EXECUTION_ID) + + # --- Successful activity 1 preserved --- + assert any( + e.HasField("taskScheduled") and e.eventId == 1 for e in clean + ) + assert any( + e.HasField("taskCompleted") and e.taskCompleted.taskScheduledId == 1 + for e in clean + ) + + # --- Failed activity 3 removed --- + assert not any( + e.HasField("taskScheduled") and e.eventId == 3 for e in clean + ) + assert not any(e.HasField("taskFailed") for e in clean) + + # --- Successful sub-orch A preserved --- + created_ids = [ + e.subOrchestrationInstanceCreated.instanceId + for e in clean if e.HasField("subOrchestrationInstanceCreated") + ] + assert "child-ok-id" in created_ids + completed_sub_ids = [ + e.subOrchestrationInstanceCompleted.taskScheduledId + for e in clean if e.HasField("subOrchestrationInstanceCompleted") + ] + assert 2 in completed_sub_ids + + # --- Failed sub-orch B: failed event removed, created kept --- + assert not any( + e.HasField("subOrchestrationInstanceFailed") for e in clean + ) + assert "child-fail-id" in created_ids + + # --- executionCompleted removed --- + assert not any(e.HasField("executionCompleted") for e in clean) + + # --- executionRewound preserved --- + assert any(e.HasField("executionRewound") for e in clean) + + +def test_rewind_does_not_mutate_original_events(): + """Verify that the original history events are not modified in place.""" + executor = _make_executor() + + es_event = _make_execution_started("orch") + original_exec_id = ( + es_event.executionStarted.orchestrationInstance.executionId.value + ) + + old_events = [ + helpers.new_orchestrator_started_event(), + es_event, + helpers.new_task_scheduled_event(1, "act"), + 
helpers.new_task_failed_event(1, RuntimeError("fail")), + helpers.new_orchestrator_completed_event(), + helpers.new_execution_completed_event( + pb.ORCHESTRATION_STATUS_FAILED, + failure_details=helpers.new_failure_details(RuntimeError("fail"))), + ] + new_events = [ + helpers.new_orchestrator_started_event(), + _make_execution_rewound("retry"), + ] + + executor._build_rewind_result( + TEST_INSTANCE_ID, "orch", old_events, new_events) + + # The original executionStarted event should NOT be mutated. + actual = es_event.executionStarted.orchestrationInstance.executionId.value + assert actual == original_exec_id + + +def test_rewind_result_action_structure(): + """The result should contain exactly one OrchestratorAction with id=-1 + and a rewindOrchestration field.""" + executor = _make_executor() + + old_events = [ + helpers.new_orchestrator_started_event(), + _make_execution_started("orch"), + helpers.new_task_scheduled_event(1, "act"), + helpers.new_task_failed_event(1, RuntimeError("fail")), + helpers.new_orchestrator_completed_event(), + helpers.new_execution_completed_event( + pb.ORCHESTRATION_STATUS_FAILED, + failure_details=helpers.new_failure_details(RuntimeError("fail"))), + ] + new_events = [ + helpers.new_orchestrator_started_event(), + _make_execution_rewound("retry"), + ] + + result = executor._build_rewind_result( + TEST_INSTANCE_ID, "orch", old_events, new_events) + + assert len(result.actions) == 1 + action = result.actions[0] + assert action.id == -1 + assert action.HasField("rewindOrchestration") + assert result.encoded_custom_status is None diff --git a/tests/durabletask/test_rewind_e2e.py b/tests/durabletask/test_rewind_e2e.py new file mode 100644 index 00000000..f10f7690 --- /dev/null +++ b/tests/durabletask/test_rewind_e2e.py @@ -0,0 +1,354 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import json + +import pytest + +from durabletask import client, task, worker +from durabletask.testing import create_test_backend + +HOST = "localhost:50055" + + +@pytest.fixture(autouse=True) +def backend(): +    """Create an in-memory backend for testing.""" +    b = create_test_backend(port=50055) +    yield b +    b.stop() +    b.reset() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +# These counters live at module level so that orchestrators and +# activities can mutate them via ``global``. +_activity_call_count = 0 +_should_fail = True + + +def _reset_counters(): +    global _activity_call_count, _should_fail +    _activity_call_count = 0 +    _should_fail = True + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_rewind_failed_activity(): +    """Rewind a failed orchestration whose single activity failed. + +    After rewind the activity succeeds and the orchestration completes. +    """ +    _reset_counters() + +    def failing_activity(_: task.ActivityContext, input: str) -> str: +        global _activity_call_count +        _activity_call_count += 1 +        if _should_fail: +            raise RuntimeError("Simulated failure") +        return f"Hello, {input}!" + +    def orchestrator(ctx: task.OrchestrationContext, input: str): +        result = yield ctx.call_activity(failing_activity, input=input) +        return result + +    with worker.TaskHubGrpcWorker(host_address=HOST) as w: +        w.add_orchestrator(orchestrator) +        w.add_activity(failing_activity) +        w.start() + +        c = client.TaskHubGrpcClient(host_address=HOST) +        instance_id = c.schedule_new_orchestration(orchestrator, input="World") +        state = c.wait_for_orchestration_completion(instance_id, timeout=30) + +        # The orchestration should have failed. 
+ assert state is not None + assert state.runtime_status == client.OrchestrationStatus.FAILED + + # Fix the activity so it now succeeds, then rewind. + global _should_fail + _should_fail = False + c.rewind_orchestration(instance_id, reason="retry after fix") + + state = c.wait_for_orchestration_completion(instance_id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps("Hello, World!") + assert state.failure_details is None + # Activity was called twice (once failed, once succeeded after rewind). + assert _activity_call_count == 2 + + +def test_rewind_preserves_successful_results(): + """When an orchestration has a mix of successful and failed activities, + rewind should re-execute only the failed activity while the successful + result is replayed from history.""" + _reset_counters() + + call_tracker: dict[str, int] = {"first": 0, "second": 0} + should_fail_second = True + + def first_activity(_: task.ActivityContext, input: str) -> str: + call_tracker["first"] += 1 + return f"first:{input}" + + def second_activity(_: task.ActivityContext, input: str) -> str: + call_tracker["second"] += 1 + if should_fail_second: + raise RuntimeError("Temporary failure") + return f"second:{input}" + + def orchestrator(ctx: task.OrchestrationContext, input: str): + r1 = yield ctx.call_activity(first_activity, input=input) + r2 = yield ctx.call_activity(second_activity, input=input) + return [r1, r2] + + with worker.TaskHubGrpcWorker(host_address=HOST) as w: + w.add_orchestrator(orchestrator) + w.add_activity(first_activity) + w.add_activity(second_activity) + w.start() + + c = client.TaskHubGrpcClient(host_address=HOST) + instance_id = c.schedule_new_orchestration(orchestrator, input="test") + state = c.wait_for_orchestration_completion(instance_id, timeout=30) + + # The orchestration should have failed (second_activity fails). 
        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.FAILED

        # Fix second_activity so it now succeeds, then rewind.
        # NOTE(review): if should_fail_second is a module-level flag read by the
        # activity, this rebinding needs a `global` declaration earlier in this
        # test or it only creates a local — confirm against the top of the test.
        should_fail_second = False
        c.rewind_orchestration(instance_id, reason="retry")
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)

        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert state.serialized_output == json.dumps(["first:test", "second:test"])
        assert state.failure_details is None
        # first_activity should NOT be re-executed – its result is replayed.
        assert call_tracker["first"] == 1
        # second_activity was called at least twice (once failed, once succeeded).
        assert call_tracker["second"] >= 2


def test_rewind_not_found():
    """Rewinding a non-existent instance should raise an RPC error."""
    with worker.TaskHubGrpcWorker(host_address=HOST) as w:
        w.start()
        c = client.TaskHubGrpcClient(host_address=HOST)
        with pytest.raises(Exception):
            c.rewind_orchestration("nonexistent-id")


def test_rewind_non_failed_instance():
    """Rewinding a completed (non-failed) instance should raise an error."""
    def orchestrator(ctx: task.OrchestrationContext, _):
        return "done"

    with worker.TaskHubGrpcWorker(host_address=HOST) as w:
        w.add_orchestrator(orchestrator)
        w.start()

        c = client.TaskHubGrpcClient(host_address=HOST)
        instance_id = c.schedule_new_orchestration(orchestrator)
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)
        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED

        # Rewind is only valid for FAILED instances; a COMPLETED one must be
        # rejected by the backend.
        with pytest.raises(Exception):
            c.rewind_orchestration(instance_id)


def test_rewind_with_sub_orchestration():
    """Rewind should recursively rewind failed sub-orchestrations."""
    sub_call_count = 0

    def child_activity(_: task.ActivityContext, input: str) -> str:
        # Fail only on the first invocation so the post-rewind retry succeeds.
        nonlocal sub_call_count
        sub_call_count += 1
        if sub_call_count == 1:
            raise RuntimeError("Child failure")
        return f"child:{input}"

    def child_orchestrator(ctx: task.OrchestrationContext, input: str):
        result = yield ctx.call_activity(child_activity, input=input)
        return result

    def parent_orchestrator(ctx: task.OrchestrationContext, input: str):
        result = yield ctx.call_sub_orchestrator(
            child_orchestrator, input=input)
        return f"parent:{result}"

    with worker.TaskHubGrpcWorker(host_address=HOST) as w:
        w.add_orchestrator(parent_orchestrator)
        w.add_orchestrator(child_orchestrator)
        w.add_activity(child_activity)
        w.start()

        c = client.TaskHubGrpcClient(host_address=HOST)
        instance_id = c.schedule_new_orchestration(
            parent_orchestrator, input="data")
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)

        # Parent should fail because child failed.
        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.FAILED

        # Rewind – child_activity will succeed on retry.
        c.rewind_orchestration(instance_id, reason="sub-orch fix")
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)

        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert state.serialized_output == json.dumps("parent:child:data")
        # Exactly two calls: the original failure plus the post-rewind success.
        assert sub_call_count == 2


def test_rewind_without_reason():
    """Rewind should work when no reason is provided."""
    # presumably resets module-level counters shared with earlier tests —
    # defined before this chunk; verify against the top of the file.
    _reset_counters()
    call_count = 0

    def flaky_activity(_: task.ActivityContext, _1) -> str:
        # _1 is the (unused) activity input; the orchestrator passes none.
        nonlocal call_count
        call_count += 1
        if call_count == 1:
            raise RuntimeError("Boom")
        return "ok"

    def orchestrator(ctx: task.OrchestrationContext, _):
        result = yield ctx.call_activity(flaky_activity)
        return result

    with worker.TaskHubGrpcWorker(host_address=HOST) as w:
        w.add_orchestrator(orchestrator)
        w.add_activity(flaky_activity)
        w.start()

        c = client.TaskHubGrpcClient(host_address=HOST)
        instance_id = c.schedule_new_orchestration(orchestrator)
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)
        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.FAILED

        # Rewind without a reason: the optional `reason` argument is omitted.
        c.rewind_orchestration(instance_id)
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)

        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert state.serialized_output == json.dumps("ok")


def test_rewind_purged_sub_orchestration():
    """A purged sub-orchestration is re-run when the parent is rewound.

    Flow: parent orchestrator -> calls sub-orchestrator -> sub-orchestrator
    fails -> parent fails -> client purges the sub-orchestration -> client
    rewinds the parent -> parent re-schedules the sub-orchestration which
    now succeeds -> parent completes.
    """
    child_call_count = 0

    def child_activity(_: task.ActivityContext, input: str) -> str:
        # Fail only on the first invocation so the re-run succeeds.
        nonlocal child_call_count
        child_call_count += 1
        if child_call_count == 1:
            raise RuntimeError("Child failure")
        return f"child:{input}"

    def child_orchestrator(ctx: task.OrchestrationContext, input: str):
        result = yield ctx.call_activity(child_activity, input=input)
        return result

    def parent_orchestrator(ctx: task.OrchestrationContext, input: str):
        # Fixed instance ID so the test can purge the child deterministically.
        result = yield ctx.call_sub_orchestrator(
            child_orchestrator, input=input, instance_id="sub-orch-to-purge")
        return f"parent:{result}"

    with worker.TaskHubGrpcWorker(host_address=HOST) as w:
        w.add_orchestrator(parent_orchestrator)
        w.add_orchestrator(child_orchestrator)
        w.add_activity(child_activity)
        w.start()

        c = client.TaskHubGrpcClient(host_address=HOST)
        instance_id = c.schedule_new_orchestration(
            parent_orchestrator, input="data")
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)

        # Parent should fail because child failed.
        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.FAILED

        # Purge the sub-orchestration so it must be completely re-run.
        c.purge_orchestration("sub-orch-to-purge")

        # Rewind the parent – child will be re-scheduled and succeed.
        c.rewind_orchestration(instance_id, reason="purge and retry")
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)

        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert state.serialized_output == json.dumps("parent:child:data")
        assert child_call_count == 2


def test_rewind_twice():
    """Rewind the same orchestration twice after it fails a second time.

    The first rewind cleans up the initial failure. The activity then
    fails again. A second rewind should clean up the new failure and
    the orchestration should eventually complete.
    """
    call_count = 0

    def flaky_activity(_: task.ActivityContext, input: str) -> str:
        nonlocal call_count
        call_count += 1
        # Fail on the 1st and 2nd calls; succeed on the 3rd.
        if call_count <= 2:
            raise RuntimeError(f"Failure #{call_count}")
        return f"Hello, {input}!"

    def orchestrator(ctx: task.OrchestrationContext, input: str):
        result = yield ctx.call_activity(flaky_activity, input=input)
        return result

    with worker.TaskHubGrpcWorker(host_address=HOST) as w:
        w.add_orchestrator(orchestrator)
        w.add_activity(flaky_activity)
        w.start()

        c = client.TaskHubGrpcClient(host_address=HOST)
        instance_id = c.schedule_new_orchestration(orchestrator, input="World")
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)

        # First failure.
        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.FAILED

        # First rewind — activity will fail again (call_count == 2).
        c.rewind_orchestration(instance_id, reason="first rewind")
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)

        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.FAILED

        # Second rewind — activity will succeed (call_count == 3).
        c.rewind_orchestration(instance_id, reason="second rewind")
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)

        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert state.serialized_output == json.dumps("Hello, World!")
        assert call_count == 3