Path: blob/main/extensions/copilot/src/platform/nesFetch/common/responseStream.ts
13401 views
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import { ErrorUtils } from '../../../util/common/errors';
import { Result } from '../../../util/common/result';
import { DeferredPromise } from '../../../util/vs/base/common/async';
import { assertType } from '../../../util/vs/base/common/types';
import { RequestId } from '../../networking/common/fetch';
import { IHeaders, Response } from '../../networking/common/fetcherService';
import { Completion } from './completionsAPI';

/**
 * Wraps a streamed completions response and exposes it in three forms: the live
 * {@link ResponseStream.stream}, the list of all chunks once streaming has ended
 * ({@link ResponseStream.aggregatedStream}), and a single completion merged from
 * those chunks ({@link ResponseStream.response}).
 *
 * Both promises are driven by iteration of {@link ResponseStream.stream}; nothing
 * settles them until that stream is consumed.
 */
export class ResponseStream {
	/**
	 * A promise that resolves to the array of completions that were emitted by the stream,
	 * or to an error result if the underlying stream threw.
	 *
	 * (it's expected to not throw)
	 */
	public readonly aggregatedStream: Promise<Result<Completion[], Error>>;

	/**
	 * A single completion that aggregates all completions emitted by the stream.
	 * Resolves to an error result if the stream failed or if aggregation failed
	 * (for example, because the stream was empty).
	 *
	 * (it's expected to not throw)
	 */
	public readonly response: Promise<Result<Completion, Error>>;

	/**
	 * The stream of completions that were emitted by the response.
	 *
	 * @remarks This stream is single-use — it can only be iterated once.
	 *
	 * @throws {Error} if the response stream throws an error.
	 */
	public readonly stream: AsyncIterable<Completion>;

	/**
	 * @param fetcherResponse The underlying response; its body is destroyed by {@link ResponseStream.destroy}.
	 * @param stream The raw completions stream to wrap.
	 * @param requestId Identifier of the request that produced this response.
	 * @param headers Headers associated with the response.
	 */
	constructor(private readonly fetcherResponse: Response, stream: AsyncIterable<Completion>, public readonly requestId: RequestId, public readonly headers: IHeaders) {
		const tokensDeferredPromise = new DeferredPromise<Result<Completion[], Error>>();
		this.aggregatedStream = tokensDeferredPromise.p;
		this.response = this.aggregatedStream.then((completions) => {
			if (completions.isError()) {
				return completions;
			}
			try {
				return Result.ok(ResponseStream.aggregateCompletionsStream(completions.val));
			} catch (err: unknown) {
				// Normalise the caught value so the result always carries a real `Error`,
				// matching how `streamWithAggregation` handles stream failures.
				return Result.error(ErrorUtils.fromUnknown(err));
			}
		});

		this.stream = streamWithAggregation(stream, tokensDeferredPromise);
	}

	/**
	 * Destroys the body of the underlying fetcher response.
	 *
	 * @throws client of the method should handle the error
	 */
	public async destroy(): Promise<void> {
		await this.fetcherResponse.body.destroy();
	}

	/**
	 * Merges the streamed completion chunks into a single completion.
	 *
	 * Text and log-probabilities of the first choice of every chunk are concatenated in
	 * order, usage counters are summed across the chunks that carry usage, and
	 * `system_fingerprint` / `object` are taken from the first chunk.
	 *
	 * Chunks that carry no choice (e.g. a usage-only chunk) contribute only their usage.
	 *
	 * @param stream The completion chunks in the order they were received.
	 * @returns The aggregated completion, holding a single choice with `index: 0`.
	 * @throws {Error} if `stream` is empty, or if more than one chunk reports a finish reason.
	 */
	private static aggregateCompletionsStream(stream: Completion[]): Completion {
		if (stream.length === 0) {
			throw new Error(`Response is empty!`);
		}

		let text = '';
		let finishReason: Completion.FinishReason | null = null;
		let aggregatedLogsProbs: Completion.LogProbs | null = null;
		let aggregatedUsage: Completion.Usage | undefined = undefined;

		for (const completion of stream) {
			// TODO@ulugbekna: we only support choice.index=0
			const choice: Completion.Choice | undefined = completion.choices[0];

			// A chunk may legitimately have no choice; in that case only its usage
			// (handled below) is relevant.
			if (choice) {
				text += choice.text ?? '';

				if (choice.logprobs) {
					if (aggregatedLogsProbs === null) {
						// Copy the arrays so that later pushes never mutate the first chunk.
						aggregatedLogsProbs = {
							tokens: [...choice.logprobs.tokens],
							token_logprobs: [...choice.logprobs.token_logprobs],
							text_offset: [...choice.logprobs.text_offset],
							top_logprobs: [...choice.logprobs.top_logprobs],
						};
					} else {
						aggregatedLogsProbs.tokens.push(...choice.logprobs.tokens);
						aggregatedLogsProbs.token_logprobs.push(...choice.logprobs.token_logprobs);
						aggregatedLogsProbs.text_offset.push(...choice.logprobs.text_offset);
						aggregatedLogsProbs.top_logprobs.push(...choice.logprobs.top_logprobs);
					}
				}

				if (choice.finish_reason) {
					assertType(
						finishReason === null,
						'cannot already have finishReason if just seeing choice.finish_reason'
					);
					finishReason = choice.finish_reason;
				}
			}

			const usage = completion.usage;
			if (usage) {
				// NOTE(review): assumes both `*_tokens_details` objects are always present
				// whenever `usage` is — confirm against the `Completion.Usage` type.
				if (aggregatedUsage === undefined) {
					// Build a fresh object so that summing never mutates the chunk's own usage.
					aggregatedUsage = {
						completion_tokens: usage.completion_tokens,
						prompt_tokens: usage.prompt_tokens,
						total_tokens: usage.total_tokens,
						completion_tokens_details: {
							audio_tokens: usage.completion_tokens_details.audio_tokens,
							reasoning_tokens: usage.completion_tokens_details.reasoning_tokens,
						},
						prompt_tokens_details: {
							audio_tokens: usage.prompt_tokens_details.audio_tokens,
							reasoning_tokens: usage.prompt_tokens_details.reasoning_tokens,
						}
					};
				} else {
					aggregatedUsage.completion_tokens += usage.completion_tokens;
					aggregatedUsage.prompt_tokens += usage.prompt_tokens;
					aggregatedUsage.total_tokens += usage.total_tokens;
					aggregatedUsage.completion_tokens_details.audio_tokens += usage.completion_tokens_details.audio_tokens;
					aggregatedUsage.completion_tokens_details.reasoning_tokens += usage.completion_tokens_details.reasoning_tokens;
					aggregatedUsage.prompt_tokens_details.audio_tokens += usage.prompt_tokens_details.audio_tokens;
					aggregatedUsage.prompt_tokens_details.reasoning_tokens += usage.prompt_tokens_details.reasoning_tokens;
				}
			}
		}

		const firstCompletion = stream[0];

		const aggregatedChoice: Completion.Choice = {
			index: 0,
			finish_reason: finishReason,
			logprobs: aggregatedLogsProbs,
			text,
		};

		const aggregatedCompletion: Completion = {
			choices: [aggregatedChoice],
			system_fingerprint: firstCompletion.system_fingerprint,
			object: firstCompletion.object,
			usage: aggregatedUsage,
		};

		return aggregatedCompletion;
	}
}

/**
 * Wraps an async iterable stream into an async generator that also collects completions
 * for aggregation and completes the deferred promise when done.
 *
 * The deferred promise is completed from a `finally` block, so it is settled when the
 * source stream ends, when it throws (with an error result), and when the consumer
 * stops iterating early (with the completions collected so far). Nothing happens until
 * the returned generator is iterated.
 *
 * @param stream The source stream of completions.
 * @param deferredPromise Completed with all yielded completions, or with the error
 * thrown by the source stream.
 * @throws {Error} the error raised by `stream`, normalised via `ErrorUtils.fromUnknown`.
 */
async function* streamWithAggregation(
	stream: AsyncIterable<Completion>,
	deferredPromise: DeferredPromise<Result<Completion[], Error>>
): AsyncGenerator<Completion> {
	const completions: Completion[] = [];
	let error: Error | undefined;
	try {
		for await (const completion of stream) {
			completions.push(completion);
			yield completion;
		}
	} catch (e: unknown) {
		error = ErrorUtils.fromUnknown(e);
		throw error;
	} finally {
		deferredPromise.complete(
			error ? Result.error(error) : Result.ok(completions)
		);
	}
}