Path: blob/main/src/vs/workbench/contrib/mcp/common/mcpTaskManager.ts
4780 views
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import { disposableTimeout } from '../../../../base/common/async.js';
import { CancellationToken, CancellationTokenSource } from '../../../../base/common/cancellation.js';
import { CancellationError } from '../../../../base/common/errors.js';
import { Emitter } from '../../../../base/common/event.js';
import { Disposable, DisposableMap, DisposableStore, IDisposable, toDisposable } from '../../../../base/common/lifecycle.js';
import { generateUuid } from '../../../../base/common/uuid.js';
import type { McpServerRequestHandler } from './mcpServerRequestHandler.js';
import { McpError } from './mcpTypes.js';
import { MCP } from './modelContextProtocol.js';

/**
 * A client-side task tracked by {@link McpTaskManager}.
 * `setHandler` is invoked for every tracked task whenever the manager's handler changes;
 * the task is disposed when it is abandoned or when the manager is disposed.
 */
export interface IMcpTaskInternal extends IDisposable {
	readonly id: string;
	// NOTE(review): not called from this file; presumably invoked by the status-notification path.
	onDidUpdateState(task: MCP.Task): void;
	setHandler(handler: McpServerRequestHandler | undefined): void;
}

interface TaskEntry extends IDisposable {
	/** Live task state; callers only ever receive shallow copies of it. */
	task: MCP.Task;
	/** Set once the executor resolves (and the task was not already terminal). */
	result?: MCP.Result;
	/** Set once the executor rejects with a non-cancellation error. */
	error?: MCP.Error;
	/** Cancelled by `cancelTask` and when the entry is disposed. */
	cts: CancellationTokenSource;
	/** Client time (ms since epoch) at which the task was created. Not read within this file. */
	createdAtTime: number;
	/** Settles when the executor has finished and the final status has been recorded. */
	executionPromise: Promise<void>;
}

/**
 * Whether a task status is final. Once a task reaches one of these statuses it never changes again.
 */
function isTerminalTaskStatus(status: MCP.TaskStatus): boolean {
	return status === 'completed' || status === 'failed' || status === 'cancelled';
}

/**
 * Manages in-memory task state for server-side MCP tasks (sampling and elicitation).
 * Also tracks client-side tasks to survive handler reconnections.
 * Lifecycle is tied to the McpServer instance.
 */
export class McpTaskManager extends Disposable {
	private readonly _serverTasks = this._register(new DisposableMap<string, TaskEntry>());
	private readonly _clientTasks = this._register(new DisposableMap<string, IMcpTaskInternal>());
	private readonly _onDidUpdateTask = this._register(new Emitter<MCP.Task>());
	/** Fires a snapshot of a server task each time its status changes. */
	public readonly onDidUpdateTask = this._onDidUpdateTask.event;

	/**
	 * Attach a new handler to this task manager.
	 * Updates all client tasks to use the new handler.
	 */
	setHandler(handler: McpServerRequestHandler | undefined): void {
		for (const task of this._clientTasks.values()) {
			task.setHandler(handler);
		}
	}

	/**
	 * Get a client task by ID for status notification handling.
	 */
	getClientTask(taskId: string): IMcpTaskInternal | undefined {
		return this._clientTasks.get(taskId);
	}

	/**
	 * Track a new client task. The manager takes ownership and disposes it when it is
	 * abandoned or when the manager itself is disposed.
	 */
	adoptClientTask(task: IMcpTaskInternal): void {
		this._clientTasks.set(task.id, task);
	}

	/**
	 * Untrack a client task and dispose it.
	 */
	abandonClientTask(taskId: string): void {
		this._clientTasks.deleteAndDispose(taskId);
	}

	/**
	 * Create a new task and execute it asynchronously.
	 * Returns a snapshot of the task immediately while execution continues in the background.
	 *
	 * @param ttl Retention in milliseconds, measured from creation. When falsy (`null` or `0`),
	 * the task is instead removed 60 seconds after its execution settles.
	 * @param executor Produces the task result; it receives a token that is cancelled when the
	 * task is cancelled or removed.
	 */
	public createTask<TResult extends MCP.Result>(
		ttl: number | null,
		executor: (token: CancellationToken) => Promise<TResult>
	): MCP.CreateTaskResult {
		const taskId = generateUuid();
		const createdAtTime = Date.now();

		const task: MCP.Task = {
			taskId,
			status: 'working',
			createdAt: new Date(createdAtTime).toISOString(),
			ttl,
			pollInterval: 1000, // Suggest 1 second polling interval
		};

		const store = new DisposableStore();
		const cts = new CancellationTokenSource();
		// Disposing the entry cancels an executor that is still running.
		store.add(toDisposable(() => cts.dispose(true)));

		// Register the entry *before* starting the executor. An executor that throws
		// synchronously reports its failure right away, and that update must find the entry.
		const entry: TaskEntry = {
			task,
			cts,
			dispose: () => store.dispose(),
			createdAtTime,
			executionPromise: Promise.resolve(),
		};
		this._serverTasks.set(taskId, entry);
		entry.executionPromise = this._executeTask(taskId, executor, cts.token);

		// Delete the task after its TTL. Or, if no TTL is given, delete it shortly after the task completes.
		if (ttl) {
			store.add(disposableTimeout(() => this._serverTasks.deleteAndDispose(taskId), ttl));
		} else {
			entry.executionPromise.finally(() => {
				// Attach the cleanup timer to the task's own store, so that it is cleared if the task
				// is removed first. Skip it entirely if the task (or the whole manager) is already gone.
				if (this._serverTasks.get(taskId) === entry) {
					store.add(disposableTimeout(() => this._serverTasks.deleteAndDispose(taskId), 60_000));
				}
			});
		}

		// Hand out a copy, like every other accessor, so callers cannot mutate internal state.
		return { task: { ...task } };
	}

	/**
	 * Run the executor and record its outcome as the task's final status.
	 * Never rejects: every executor error is translated into a task status.
	 */
	private async _executeTask<TResult extends MCP.Result>(
		taskId: string,
		executor: (token: CancellationToken) => Promise<TResult>,
		token: CancellationToken
	): Promise<void> {
		try {
			const result = await executor(token);
			this._updateTaskStatus(taskId, 'completed', undefined, result);
		} catch (error) {
			if (error instanceof CancellationError) {
				this._updateTaskStatus(taskId, 'cancelled', 'Task was cancelled by the client');
			} else if (error instanceof McpError) {
				this._updateTaskStatus(taskId, 'failed', error.message, undefined, {
					code: error.code,
					message: error.message,
					data: error.data,
				});
			} else if (error instanceof Error) {
				this._updateTaskStatus(taskId, 'failed', error.message, undefined, {
					code: MCP.INTERNAL_ERROR,
					message: error.message,
				});
			} else {
				this._updateTaskStatus(taskId, 'failed', 'Unknown error', undefined, {
					code: MCP.INTERNAL_ERROR,
					message: 'Unknown error',
				});
			}
		}
	}

	/**
	 * Update task status, optionally store a result or error, and fire `onDidUpdateTask`.
	 * No-op when the task no longer exists or has already reached a terminal status.
	 */
	private _updateTaskStatus(
		taskId: string,
		status: MCP.TaskStatus,
		statusMessage?: string,
		result?: MCP.Result,
		error?: MCP.Error
	): void {
		const entry = this._serverTasks.get(taskId);
		// Terminal statuses are final. For example, an executor that ignores cancellation and
		// later resolves must not flip a cancelled task back to 'completed'.
		if (!entry || isTerminalTaskStatus(entry.task.status)) {
			return;
		}

		entry.task.status = status;
		if (statusMessage !== undefined) {
			entry.task.statusMessage = statusMessage;
		}
		if (result !== undefined) {
			entry.result = result;
		}
		if (error !== undefined) {
			entry.error = error;
		}

		this._onDidUpdateTask.fire({ ...entry.task });
	}

	/**
	 * Get a snapshot of the current state of a task.
	 * @throws McpError (INVALID_PARAMS) if the task doesn't exist or has expired.
	 */
	public getTask(taskId: string): MCP.GetTaskResult {
		const entry = this._serverTasks.get(taskId);
		if (!entry) {
			throw new McpError(MCP.INVALID_PARAMS, `Task not found: ${taskId}`);
		}

		return { ...entry.task };
	}

	/**
	 * Get the result of a completed task.
	 * Waits for the task to finish if it's still in progress.
	 * @throws McpError (INVALID_PARAMS) if the task doesn't exist or was removed while waiting.
	 * @throws McpError carrying the task's recorded error if the task failed.
	 * @throws McpError (INTERNAL_ERROR) if the task has no result (e.g. it was cancelled).
	 */
	public async getTaskResult(taskId: string): Promise<MCP.GetTaskPayloadResult> {
		const entry = this._serverTasks.get(taskId);
		if (!entry) {
			throw new McpError(MCP.INVALID_PARAMS, `Task not found: ${taskId}`);
		}

		if (entry.task.status === 'working' || entry.task.status === 'input_required') {
			await entry.executionPromise;
		}

		// Re-read after waiting: the entry may have expired in the meantime.
		const updatedEntry = this._serverTasks.get(taskId);
		if (!updatedEntry) {
			throw new McpError(MCP.INVALID_PARAMS, `Task not found: ${taskId}`);
		}

		if (updatedEntry.error) {
			throw new McpError(updatedEntry.error.code, updatedEntry.error.message, updatedEntry.error.data);
		}

		if (!updatedEntry.result) {
			throw new McpError(MCP.INTERNAL_ERROR, 'Task completed but no result available');
		}

		return updatedEntry.result;
	}

	/**
	 * Cancel a task. Marks it as cancelled (notifying `onDidUpdateTask` listeners) and signals
	 * the executor's cancellation token.
	 * @throws McpError (INVALID_PARAMS) if the task doesn't exist or is already in a terminal status.
	 */
	public cancelTask(taskId: string): MCP.CancelTaskResult {
		const entry = this._serverTasks.get(taskId);
		if (!entry) {
			throw new McpError(MCP.INVALID_PARAMS, `Task not found: ${taskId}`);
		}

		if (isTerminalTaskStatus(entry.task.status)) {
			throw new McpError(MCP.INVALID_PARAMS, `Cannot cancel task in ${entry.task.status} status`);
		}

		// Routed through _updateTaskStatus so listeners are told about the cancellation. Because
		// 'cancelled' is terminal, whatever the executor later resolves or throws is ignored.
		this._updateTaskStatus(taskId, 'cancelled', 'Task was cancelled by the client');
		entry.cts.cancel();

		return { ...entry.task };
	}

	/**
	 * List snapshots of all server tasks that have not yet expired.
	 */
	public listTasks(): MCP.ListTasksResult {
		const tasks = Array.from(this._serverTasks.values(), entry => ({ ...entry.task }));
		return { tasks };
	}
}