// Path: extensions/copilot/src/extension/prompts/node/base/promptRenderer.ts
/*---------------------------------------------------------------------------------------------1* Copyright (c) Microsoft Corporation. All rights reserved.2* Licensed under the MIT License. See License.txt in the project root for license information.3*--------------------------------------------------------------------------------------------*/45import { BasePromptElementProps, PromptRenderer as BasePromptRenderer, HTMLTracer, ITokenizer, JSONTree, MetadataMap, OutputMode, QueueItem, Raw, RenderPromptResult } from '@vscode/prompt-tsx';6import type { ChatResponsePart, ChatResponseProgressPart, LanguageModelToolTokenizationOptions, Progress } from 'vscode';7import { ChatLocation } from '../../../../platform/chat/common/commonTypes';8import { toTextPart } from '../../../../platform/chat/common/globalStringUtils';9import { ConfigKey, IConfigurationService } from '../../../../platform/configuration/common/configurationService';10import { IEndpointProvider } from '../../../../platform/endpoint/common/endpointProvider';11import { ILogService } from '../../../../platform/log/common/logService';12import { IChatEndpoint } from '../../../../platform/networking/common/networking';13import { IRequestLogger } from '../../../../platform/requestLogger/common/requestLogger';14import { ITokenizerProvider } from '../../../../platform/tokenizer/node/tokenizer';15import { createServiceIdentifier } from '../../../../util/common/services';16import { isLocation } from '../../../../util/common/types';17import { CancellationToken } from '../../../../util/vs/base/common/cancellation';18import { URI } from '../../../../util/vs/base/common/uri';19import { IInstantiationService } from '../../../../util/vs/platform/instantiation/common/instantiation';20import { ServiceCollection } from '../../../../util/vs/platform/instantiation/common/serviceCollection';21import { ChatResponseReferencePart, Location, Uri } from '../../../../vscodeTypes';22import { RendererVisualizations } from 
'../../../inlineChat/node/rendererVisualization';23import { getUniqueReferences, PromptReference } from '../../../prompt/common/conversation';24import { IBuildPromptContext } from '../../../prompt/common/intents';25import { IIntent } from '../../../prompt/node/intents';26import { PromptElementCtor } from './promptElement';2728/**29* Allows us to use dependency injection to pass the fully fledged IChatEndpoint to the prompt element being rendered.30*/31export type IPromptEndpoint = IChatEndpoint & {32_serviceBrand: undefined;33};34export const IPromptEndpoint = createServiceIdentifier<IPromptEndpoint>('IPromptEndpoint');3536/**37* Convenience intent invocation that uses a renderer for prompt crafting.38*/39export abstract class RendererIntentInvocation {4041constructor(42readonly intent: IIntent,43readonly location: ChatLocation,44readonly endpoint: IChatEndpoint,45) { }4647async buildPrompt(promptParams: IBuildPromptContext, progress: Progress<ChatResponseReferencePart | ChatResponseProgressPart>, token: CancellationToken): Promise<RenderPromptResult<OutputMode.Raw> & { references: PromptReference[] }> {48const renderer = await this.createRenderer(promptParams, this.endpoint, progress, token);49return await renderer.render(progress, token);50}5152abstract createRenderer(promptParams: IBuildPromptContext, endpoint: IChatEndpoint, progress: Progress<ChatResponseReferencePart | ChatResponseProgressPart>, token: CancellationToken): BasePromptRenderer<any, OutputMode.Raw> | Promise<BasePromptRenderer<any, OutputMode.Raw>>;53}5455export class PromptRenderer<P extends BasePromptElementProps> extends BasePromptRenderer<P, OutputMode.Raw> {56private ctorName?: string; // when and iff tracing is enabled5758public static create<P extends BasePromptElementProps>(59instantiationService: IInstantiationService,60endpoint: IChatEndpoint,61ctor: PromptElementCtor<P, any>,62props: P,63) {64// TODO@Alex, TODO@Joh: instantiationService.createInstance doesn't work here65const 
hydratedInstaService = instantiationService.createChild(new ServiceCollection([IPromptEndpoint, endpoint]));66return hydratedInstaService.invokeFunction((accessor) => {67const tokenizerProvider = accessor.get(ITokenizerProvider);68let renderer = new PromptRenderer(hydratedInstaService, endpoint, ctor, props, tokenizerProvider, accessor.get(IRequestLogger), accessor.get(ILogService), accessor.get(IConfigurationService));6970const visualizations = RendererVisualizations.getIfVisualizationTestIsRunning();71if (visualizations) {72renderer = visualizations.decorateAndRegister(renderer, ctor.name);73}7475return renderer;76});77}7879constructor(80private readonly _instantiationService: IInstantiationService,81protected readonly endpoint: IChatEndpoint,82ctor: PromptElementCtor<P, any>,83props: P,84@ITokenizerProvider tokenizerProvider: ITokenizerProvider,85@IRequestLogger private readonly _requestLogger: IRequestLogger,86@ILogService private readonly _logService: ILogService,87@IConfigurationService configurationService: IConfigurationService,88) {89const tokenizer = tokenizerProvider.acquireTokenizer(endpoint);90super(endpoint, ctor, props, tokenizer);9192if (configurationService.getConfig(ConfigKey.TeamInternal.EnablePromptRendererTracing)) {93this.ctorName = ctor.name || '<anonymous>';94this.tracer = new HTMLTracer();95}96}9798override createElement(element: QueueItem<PromptElementCtor<P, any>, P>, ...args: any[]) {99return this._instantiationService.createInstance(element.ctor, element.props, ...args);100}101102override async render(progress?: Progress<ChatResponsePart> | undefined, token?: CancellationToken | undefined, opts?: Partial<{ trace: boolean }>): Promise<RenderPromptResult> {103const result = await super.render(progress, token);104const defaultOptions = { trace: true };105opts = { ...defaultOptions, ...opts };106if (this.tracer && !!opts.trace) {107this._requestLogger.addPromptTrace(this.ctorName!, this.endpoint, result, this.tracer as 
HTMLTracer);108}109110// Collapse consecutive system messages because CAPI currently expects a single111// system message per prompt. Note: this may slightly reduce the actual112// token usage under the `RenderPromptResult.tokenCount`.113for (let i = 1; i < result.messages.length; i++) {114const current = result.messages[i];115const prev = result.messages[i - 1];116if (current.role === Raw.ChatRole.System && prev.role === Raw.ChatRole.System) {117const lastContent = prev.content.at(-1);118const nextContent = current.content.at(0);119if (lastContent && nextContent && lastContent.type === Raw.ChatCompletionContentPartKind.Text && nextContent.type === Raw.ChatCompletionContentPartKind.Text) {120lastContent.text = lastContent.text.trimEnd() + '\n' + nextContent.text;121prev.content = prev.content.concat(current.content.slice(1));122} else {123prev.content.push(toTextPart('\n'));124prev.content = prev.content.concat(current.content);125}126result.messages.splice(i, 1);127i--;128}129}130131const references = result.references.filter(ref => this.validateReference(ref));132this._instantiationService.dispose(); // Dispose the hydrated instantiation service133return { ...result, references: getUniqueReferences(references) };134}135136private validateReference(reference: PromptReference) {137const validateLocation = (value: Uri | Location) => {138const uri = isLocation(value) ? value.uri : value;139if (!URI.isUri(uri)) {140this._logService.warn(`Invalid PromptReference, uri not an instance of URI: ${uri}. 
Try to find the code that is creating this reference and fix it.`);141return false;142}143return true;144};145const refAnchor = reference.anchor;146if ('variableName' in refAnchor) {147return refAnchor.value === undefined || validateLocation(refAnchor.value);148}149return validateLocation(refAnchor);150}151152async countTokens(token?: CancellationToken): Promise<number> {153const result = await super.render(undefined, token);154return result.tokenCount;155}156}157158export async function renderPromptElement<P extends BasePromptElementProps>(159instantiationService: IInstantiationService,160endpoint: IChatEndpoint,161ctor: PromptElementCtor<P, any>,162props: P,163progress?: Progress<ChatResponseProgressPart>,164token?: CancellationToken,165): Promise<{ messages: Raw.ChatMessage[]; tokenCount: number; metadatas: MetadataMap; references: PromptReference[] }> {166const renderer = PromptRenderer.create(instantiationService, endpoint, ctor, props);167const { messages, tokenCount, references, metadata } = await renderer.render(progress, token);168return { messages, tokenCount, metadatas: metadata, references: getUniqueReferences(references) };169}170171// The below all exists to wrap `renderElementJSON` to call our instantiation service172173class PromptRendererForJSON<P extends BasePromptElementProps> extends BasePromptRenderer<P, OutputMode.Raw> {174constructor(175ctor: PromptElementCtor<P, any>,176props: P,177tokenOptions: LanguageModelToolTokenizationOptions | undefined,178chatEndpoint: IChatEndpoint,179private readonly instantiationService: IInstantiationService,180) {181// Copied from prompt-tsx to map the vscode tool tokenOptions to ITokenizer182const tokenizer: ITokenizer<OutputMode.Raw> = {183mode: OutputMode.Raw,184countMessageTokens(message) {185throw new Error('Tools may only return text, not messages.');186},187tokenLength(text, token) {188if (text.type === Raw.ChatCompletionContentPartKind.Text) {189return Promise.resolve(tokenOptions?.countTokens(text.text, 
token) ?? Promise.resolve(1));190} else {191return Promise.resolve(1);192}193},194};195196super({ modelMaxPromptTokens: tokenOptions?.tokenBudget ?? chatEndpoint.modelMaxPromptTokens }, ctor, props, tokenizer);197}198199override createElement(element: QueueItem<PromptElementCtor<P, any>, P>, ...args: any[]) {200return this.instantiationService.createInstance(element.ctor, element.props, ...args);201}202}203204export async function renderPromptElementJSON<P extends BasePromptElementProps>(205instantiationService: IInstantiationService,206ctor: PromptElementCtor<P, any>,207props: P,208tokenOptions?: LanguageModelToolTokenizationOptions,209token?: CancellationToken210): Promise<JSONTree.PromptElementJSON> {211// todo@connor4312: we don't know what model the tool call will use, just assume copilot base212// todo@lramos15: We should pass in endpoint provider rather than doing invoke function, but this was easier213const endpoint = await instantiationService.invokeFunction(async (accessor) => {214const endpointProvider = accessor.get(IEndpointProvider);215return await endpointProvider.getChatEndpoint('copilot-base');216});217const hydratedInstaService = instantiationService.createChild(new ServiceCollection([IPromptEndpoint, endpoint]));218const renderer = new PromptRendererForJSON(ctor as any, props, tokenOptions, endpoint, hydratedInstaService);219return await renderer.renderElementJSON(token);220}221222223