Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 100 additions & 43 deletions apps/sim/app/workspace/[workspaceId]/home/hooks/use-chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ const DEPLOY_TOOL_NAMES = new Set(['deploy_api', 'deploy_chat', 'deploy_mcp', 'r
const RECONNECT_TAIL_ERROR =
'Live reconnect failed before the stream finished. The latest response may be incomplete.'
const TERMINAL_STREAM_STATUSES = new Set(['complete', 'error', 'cancelled'])
const MAX_RECONNECT_ATTEMPTS = 10
const RECONNECT_BASE_DELAY_MS = 1000
const RECONNECT_MAX_DELAY_MS = 30_000

interface StreamEventEnvelope {
eventId: number
Expand Down Expand Up @@ -1565,6 +1568,7 @@ export function useChat(
(options?: { error?: boolean }) => {
sendingRef.current = false
setIsSending(false)
setIsReconnecting(false)
abortControllerRef.current = null
invalidateChatQueries()

Expand Down Expand Up @@ -1609,6 +1613,47 @@ export function useChat(
)
finalizeRef.current = finalize

const resumeOrFinalize = useCallback(
async (opts: {
streamId: string
assistantId: string
gen: number
fromEventId: number
snapshot?: StreamSnapshot | null
signal?: AbortSignal
}): Promise<void> => {
const { streamId, assistantId, gen, fromEventId, snapshot, signal } = opts

const batch =
snapshot ??
(await (async () => {
const b = await fetchStreamBatch(streamId, fromEventId, signal)
if (streamGenRef.current !== gen) return null
return { events: b.events, status: b.status } as StreamSnapshot
})())

if (!batch || streamGenRef.current !== gen) return

if (isTerminalStreamStatus(batch.status)) {
finalize(batch.status === 'error' ? { error: true } : undefined)
return
}

const reconnectResult = await attachToExistingStream({
streamId,
assistantId,
expectedGen: gen,
snapshot: batch,
initialLastEventId: batch.events[batch.events.length - 1]?.eventId ?? fromEventId,
})

if (streamGenRef.current === gen && !reconnectResult.aborted) {
finalize(reconnectResult.error ? { error: true } : undefined)
}
},
[fetchStreamBatch, attachToExistingStream, finalize]
)

const sendMessage = useCallback(
async (message: string, fileAttachments?: FileAttachmentForApi[], contexts?: ChatContext[]) => {
if (!message.trim() || !workspaceId) return
Expand Down Expand Up @@ -1745,44 +1790,13 @@ export function useChat(
return
}

const batch = await fetchStreamBatch(
userMessageId,
termination.lastEventId,
abortController.signal
)
if (streamGenRef.current !== gen) {
return
}
if (isTerminalStreamStatus(batch.status)) {
finalize(batch.status === 'error' ? { error: true } : undefined)
return
}

logger.warn(
'Primary stream ended without terminal event, attempting in-place reconnect',
{
streamId: userMessageId,
lastEventId: termination.lastEventId,
streamStatus: batch.status,
sawDoneEvent: termination.sawDoneEvent,
}
)

const reconnectResult = await attachToExistingStream({
await resumeOrFinalize({
streamId: userMessageId,
assistantId,
expectedGen: gen,
snapshot: {
events: batch.events,
status: batch.status,
},
initialLastEventId:
batch.events[batch.events.length - 1]?.eventId ?? termination.lastEventId,
gen,
fromEventId: termination.lastEventId,
signal: abortController.signal,
})

if (streamGenRef.current === gen && !reconnectResult.aborted) {
finalize(reconnectResult.error ? { error: true } : undefined)
}
}
} catch (err) {
if (err instanceof Error && err.name === 'AbortError') return
Expand Down Expand Up @@ -1827,17 +1841,13 @@ export function useChat(
.find((m) => m.role === 'assistant')
const recoveryAssistantId = lastAssistantMsg?.id ?? assistantId

const reconnectResult = await attachToExistingStream({
await resumeOrFinalize({
streamId: pendingRecovery.streamId,
assistantId: recoveryAssistantId,
expectedGen: gen,
gen,
fromEventId: lastEventIdRef.current,
snapshot: pendingRecovery.snapshot,
initialLastEventId: lastEventIdRef.current,
})

if (streamGenRef.current === gen && !reconnectResult.aborted) {
finalize(reconnectResult.error ? { error: true } : undefined)
}
return
} catch (recoveryError) {
logger.warn('Failed to recover active stream after conflict', {
Expand All @@ -1847,6 +1857,53 @@ export function useChat(
}
}

const activeStreamId = streamIdRef.current
if (activeStreamId && streamGenRef.current === gen) {
for (let attempt = 0; attempt < MAX_RECONNECT_ATTEMPTS; attempt++) {
if (streamGenRef.current !== gen) return
if (abortControllerRef.current?.signal.aborted) return

const delayMs = Math.min(RECONNECT_BASE_DELAY_MS * 2 ** attempt, RECONNECT_MAX_DELAY_MS)
logger.info('Reconnect attempt after network error', {
streamId: activeStreamId,
attempt: attempt + 1,
maxAttempts: MAX_RECONNECT_ATTEMPTS,
delayMs,
error: errorMessage,
})

setIsReconnecting(true)
await new Promise((resolve) => setTimeout(resolve, delayMs))

if (streamGenRef.current !== gen) return
if (abortControllerRef.current?.signal.aborted) return

try {
await resumeOrFinalize({
streamId: activeStreamId,
assistantId,
gen,
fromEventId: lastEventIdRef.current,
signal: abortController.signal,
})
return
} catch (reconnectErr) {
if (reconnectErr instanceof Error && reconnectErr.name === 'AbortError') return
logger.warn('Reconnect attempt failed', {
streamId: activeStreamId,
attempt: attempt + 1,
error: reconnectErr instanceof Error ? reconnectErr.message : String(reconnectErr),
})
}
}

logger.error('All reconnect attempts exhausted', {
streamId: activeStreamId,
maxAttempts: MAX_RECONNECT_ATTEMPTS,
})
setIsReconnecting(false)
}

setError(errorMessage)
if (streamGenRef.current === gen) {
finalize({ error: true })
Expand All @@ -1859,7 +1916,7 @@ export function useChat(
queryClient,
processSSEStream,
finalize,
attachToExistingStream,
resumeOrFinalize,
preparePendingStreamRecovery,
]
)
Expand Down
21 changes: 14 additions & 7 deletions apps/sim/lib/copilot/chat-streaming.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@ const CHAT_STREAM_LOCK_TTL_SECONDS = 2 * 60 * 60
const STREAM_ABORT_TTL_SECONDS = 10 * 60
const STREAM_ABORT_POLL_MS = 1000

// Registry of in-flight Sim→Go streams so the explicit abort endpoint can
// reach them. Keyed by streamId, cleaned up when the stream completes.
const activeStreams = new Map<string, AbortController>()
interface ActiveStreamEntry {
abortController: AbortController
userStopController: AbortController
}

const activeStreams = new Map<string, ActiveStreamEntry>()

// Tracks in-flight streams by chatId so that a subsequent request for the
// same chat can force-abort the previous stream and wait for it to settle
Expand Down Expand Up @@ -184,9 +187,10 @@ export async function abortActiveStream(streamId: string): Promise<boolean> {
})
}
}
const controller = activeStreams.get(streamId)
if (!controller) return published
controller.abort()
const entry = activeStreams.get(streamId)
if (!entry) return published
entry.userStopController.abort()
entry.abortController.abort()
activeStreams.delete(streamId)
return true
}
Expand Down Expand Up @@ -285,7 +289,8 @@ export function createSSEStream(params: StreamingOrchestrationParams): ReadableS
let eventWriter: ReturnType<typeof createStreamEventWriter> | null = null
let clientDisconnected = false
const abortController = new AbortController()
activeStreams.set(streamId, abortController)
const userStopController = new AbortController()
activeStreams.set(streamId, { abortController, userStopController })

if (chatId && !pendingChatStreamAlreadyRegistered) {
registerPendingChatStream(chatId, streamId)
Expand Down Expand Up @@ -348,6 +353,7 @@ export function createSSEStream(params: StreamingOrchestrationParams): ReadableS
try {
const shouldAbort = await redis.get(getStreamAbortKey(streamId))
if (shouldAbort && !abortController.signal.aborted) {
userStopController.abort()
abortController.abort()
await redis.del(getStreamAbortKey(streamId))
}
Expand Down Expand Up @@ -449,6 +455,7 @@ export function createSSEStream(params: StreamingOrchestrationParams): ReadableS
executionId,
runId,
abortSignal: abortController.signal,
userStopSignal: userStopController.signal,
onEvent: async (event) => {
await pushEvent(event)
},
Expand Down
1 change: 1 addition & 0 deletions apps/sim/lib/copilot/orchestrator/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ export async function orchestrateCopilotStream(
execContext.executionId = executionId
execContext.runId = runId
execContext.abortSignal = options.abortSignal
execContext.userStopSignal = options.userStopSignal

const payloadMsgId = requestPayload?.messageId
const messageId = typeof payloadMsgId === 'string' ? payloadMsgId : crypto.randomUUID()
Expand Down
10 changes: 9 additions & 1 deletion apps/sim/lib/copilot/orchestrator/sse/handlers/handlers.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ describe('sse-handlers tool lifecycle', () => {

it('marks an in-flight tool as cancelled when aborted mid-execution', async () => {
const abortController = new AbortController()
const userStopController = new AbortController()
execContext.abortSignal = abortController.signal
execContext.userStopSignal = userStopController.signal

executeToolServerSide.mockImplementationOnce(
() =>
Expand All @@ -137,9 +139,15 @@ describe('sse-handlers tool lifecycle', () => {
} as any,
context,
execContext,
{ interactive: false, timeout: 1000, abortSignal: abortController.signal }
{
interactive: false,
timeout: 1000,
abortSignal: abortController.signal,
userStopSignal: userStopController.signal,
}
)

userStopController.abort()
abortController.abort()
await new Promise((resolve) => setTimeout(resolve, 10))

Expand Down
10 changes: 7 additions & 3 deletions apps/sim/lib/copilot/orchestrator/sse/handlers/tool-execution.ts
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,13 @@ function abortRequested(
execContext: ExecutionContext,
options?: OrchestratorOptions
): boolean {
return Boolean(
options?.abortSignal?.aborted || execContext.abortSignal?.aborted || context.wasAborted
)
if (options?.userStopSignal?.aborted || execContext.userStopSignal?.aborted) {
return true
}
if (context.wasAborted) {
return true
}
return false
}

function cancelledCompletion(message: string): AsyncToolCompletion {
Expand Down
1 change: 1 addition & 0 deletions apps/sim/lib/copilot/orchestrator/tool-executor/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1269,6 +1269,7 @@ async function executeServerToolDirect(
chatId: context.chatId,
messageId: context.messageId,
abortSignal: context.abortSignal,
userStopSignal: context.userStopSignal,
})

const resultRecord =
Expand Down
4 changes: 4 additions & 0 deletions apps/sim/lib/copilot/orchestrator/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ export interface OrchestratorOptions {
onComplete?: (result: OrchestratorResult) => void | Promise<void>
onError?: (error: Error) => void | Promise<void>
abortSignal?: AbortSignal
/** Fires only on explicit user stop, never on passive transport disconnect. */
userStopSignal?: AbortSignal
interactive?: boolean
}

Expand Down Expand Up @@ -199,6 +201,8 @@ export interface ExecutionContext {
executionId?: string
runId?: string
abortSignal?: AbortSignal
/** Fires only on explicit user stop, never on passive transport disconnect. */
userStopSignal?: AbortSignal
userTimezone?: string
userPermission?: string
decryptedEnvVars?: Record<string, string>
Expand Down
4 changes: 3 additions & 1 deletion apps/sim/lib/copilot/tools/server/base-tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ export interface ServerToolContext {
chatId?: string
messageId?: string
abortSignal?: AbortSignal
/** Fires only on explicit user stop, never on passive transport disconnect. */
userStopSignal?: AbortSignal
}

export function assertServerToolNotAborted(
context?: ServerToolContext,
message = 'Request aborted before tool mutation could be applied.'
): void {
if (context?.abortSignal?.aborted) {
if (context?.userStopSignal?.aborted) {
throw new Error(message)
}
}
Expand Down
Loading