dynamically adjust context window sizes based on conversation context
@@ -69,7 +69,7 @@ export interface SemanticContextService {
     /**
      * Retrieve semantic context based on relevance to user query
      */
-    getSemanticContext(noteId: string, userQuery: string, maxResults?: number): Promise<string>;
+    getSemanticContext(noteId: string, userQuery: string, maxResults?: number, messages?: Message[]): Promise<string>;

     /**
      * Get progressive context based on depth

@@ -3,6 +3,9 @@ import log from '../../../log.js';
 import { CONTEXT_PROMPTS, FORMATTING_PROMPTS } from '../../constants/llm_prompt_constants.js';
 import { LLM_CONSTANTS } from '../../constants/provider_constants.js';
 import type { IContextFormatter, NoteSearchResult } from '../../interfaces/context_interfaces.js';
+import modelCapabilitiesService from '../../model_capabilities_service.js';
+import { calculateAvailableContextSize } from '../../interfaces/model_capabilities.js';
+import type { Message } from '../../ai_interface.js';

 // Use constants from the centralized file
 // const CONTEXT_WINDOW = {
@@ -20,26 +23,46 @@ import type { IContextFormatter, NoteSearchResult } from '../../interfaces/conte
  */
 export class ContextFormatter implements IContextFormatter {
     /**
-     * Build a structured context string from note sources
+     * Build formatted context from a list of note search results
      *
      * @param sources Array of note data with content and metadata
      * @param query The user's query for context
      * @param providerId Optional provider ID to customize formatting
+     * @param messages Optional conversation messages to adjust context size
      * @returns Formatted context string
      */
-    async buildContextFromNotes(sources: NoteSearchResult[], query: string, providerId: string = 'default'): Promise<string> {
+    async buildContextFromNotes(
+        sources: NoteSearchResult[],
+        query: string,
+        providerId: string = 'default',
+        messages: Message[] = []
+    ): Promise<string> {
         if (!sources || sources.length === 0) {
             log.info('No sources provided to context formatter');
             return CONTEXT_PROMPTS.NO_NOTES_CONTEXT;
         }

         try {
-            // Get appropriate context size based on provider
-            const maxTotalLength =
+            // Get model name from provider
+            let modelName = providerId;
+
+            // Look up model capabilities
+            const modelCapabilities = await modelCapabilitiesService.getModelCapabilities(modelName);
+
+            // Calculate available context size for this conversation
+            const availableContextSize = calculateAvailableContextSize(
+                modelCapabilities,
+                messages,
+                3 // Expected additional turns
+            );
+
+            // Use the calculated size or fall back to constants
+            const maxTotalLength = availableContextSize || (
                 providerId === 'openai' ? LLM_CONSTANTS.CONTEXT_WINDOW.OPENAI :
                 providerId === 'anthropic' ? LLM_CONSTANTS.CONTEXT_WINDOW.ANTHROPIC :
                 providerId === 'ollama' ? LLM_CONSTANTS.CONTEXT_WINDOW.OLLAMA :
-                LLM_CONSTANTS.CONTEXT_WINDOW.DEFAULT;
+                LLM_CONSTANTS.CONTEXT_WINDOW.DEFAULT
+            );

             // DEBUG: Log context window size
             log.info(`Context window for provider ${providerId}: ${maxTotalLength} chars`);

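For orientation, a minimal sketch of how the reworked formatter entry point is meant to be called after this change (the query text and variable names are illustrative, not taken from the diff; it assumes `contextFormatter` is an instance of ContextFormatter and `conversation` is the Message[] history of the current chat):

    const formatted = await contextFormatter.buildContextFromNotes(
        searchResults,            // NoteSearchResult[] with noteId, title, content, similarity
        'What did I write about sleep tracking?',
        'openai',                 // providerId; also used here as the model-capabilities lookup key
        conversation              // lets the formatter shrink note context for long conversations
    );
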
@@ -10,6 +10,7 @@ import { CONTEXT_PROMPTS } from '../../constants/llm_prompt_constants.js';
 import becca from '../../../../becca/becca.js';
 import type { NoteSearchResult } from '../../interfaces/context_interfaces.js';
 import type { LLMServiceInterface } from '../../interfaces/agent_tool_interfaces.js';
+import type { Message } from '../../ai_interface.js';

 /**
  * Main context service that integrates all context-related functionality
@@ -635,14 +636,20 @@ export class ContextService {
     }

     /**
-     * Get semantic context for a note and query
+     * Get semantic context based on query
      *
-     * @param noteId - The base note ID
-     * @param userQuery - The user's query
-     * @param maxResults - Maximum number of results to include
-     * @returns Formatted context string
+     * @param noteId - Note ID to start from
+     * @param userQuery - User query for context
+     * @param maxResults - Maximum number of results
+     * @param messages - Optional conversation messages to adjust context size
+     * @returns Formatted context
      */
-    async getSemanticContext(noteId: string, userQuery: string, maxResults: number = 5): Promise<string> {
+    async getSemanticContext(
+        noteId: string,
+        userQuery: string,
+        maxResults: number = 5,
+        messages: Message[] = []
+    ): Promise<string> {
         if (!this.initialized) {
             await this.initialize();
         }
@@ -712,24 +719,39 @@

             // Get content for the top N most relevant notes
             const mostRelevantNotes = rankedNotes.slice(0, maxResults);
-            const relevantContent = await Promise.all(
+
+            // Get relevant search results to pass to context formatter
+            const searchResults = await Promise.all(
                 mostRelevantNotes.map(async note => {
                     const content = await this.contextExtractor.getNoteContent(note.noteId);
                     if (!content) return null;

-                    // Format with relevance score and title
-                    return `### ${note.title} (Relevance: ${Math.round(note.relevance * 100)}%)\n\n${content}`;
+                    // Create a properly typed NoteSearchResult object
+                    return {
+                        noteId: note.noteId,
+                        title: note.title,
+                        content,
+                        similarity: note.relevance
+                    };
                 })
             );

+            // Filter out nulls and empty content
+            const validResults: NoteSearchResult[] = searchResults
+                .filter(result => result !== null && result.content && result.content.trim().length > 0)
+                .map(result => result as NoteSearchResult);
+
             // If no content retrieved, return empty string
-            if (!relevantContent.filter(Boolean).length) {
+            if (validResults.length === 0) {
                 return '';
             }

-            return `# Relevant Context\n\nThe following notes are most relevant to your query:\n\n${
-                relevantContent.filter(Boolean).join('\n\n---\n\n')
-            }`;
+            // Get the provider information for formatting
+            const provider = await providerManager.getPreferredEmbeddingProvider();
+            const providerId = provider?.name || 'default';
+
+            // Format the context with the context formatter (which handles adjusting for conversation size)
+            return contextFormatter.buildContextFromNotes(validResults, userQuery, providerId, messages);
         } catch (error) {
             log.error(`Error getting semantic context: ${error}`);
             return '';

@@ -154,10 +154,11 @@ class TriliumContextService {
      * @param noteId - The note ID
      * @param userQuery - The user's query
      * @param maxResults - Maximum results to include
+     * @param messages - Optional conversation messages to adjust context size
      * @returns Formatted context string
      */
-    async getSemanticContext(noteId: string, userQuery: string, maxResults = 5): Promise<string> {
-        return contextService.getSemanticContext(noteId, userQuery, maxResults);
+    async getSemanticContext(noteId: string, userQuery: string, maxResults = 5, messages: Message[] = []): Promise<string> {
+        return contextService.getSemanticContext(noteId, userQuery, maxResults, messages);
     }

     /**

@@ -46,7 +46,12 @@ export interface NoteSearchResult {
  * Interface for context formatter
  */
 export interface IContextFormatter {
-  buildContextFromNotes(sources: NoteSearchResult[], query: string, providerId?: string): Promise<string>;
+  buildContextFromNotes(
+    sources: NoteSearchResult[],
+    query: string,
+    providerId?: string,
+    messages?: Array<{role: string, content: string}>
+  ): Promise<string>;
 }

 /**

src/services/llm/interfaces/model_capabilities.ts (new file, 138 lines)
@@ -0,0 +1,138 @@
import type { Message } from '../ai_interface.js';

/**
 * Interface for model capabilities information
 */
export interface ModelCapabilities {
    contextWindowTokens: number;  // Context window size in tokens
    contextWindowChars: number;   // Estimated context window size in characters (for planning)
    maxCompletionTokens: number;  // Maximum completion length
    hasFunctionCalling: boolean;  // Whether the model supports function calling
    hasVision: boolean;           // Whether the model supports image input
    costPerInputToken: number;    // Cost per input token (if applicable)
    costPerOutputToken: number;   // Cost per output token (if applicable)
}

/**
 * Default model capabilities for unknown models
 */
export const DEFAULT_MODEL_CAPABILITIES: ModelCapabilities = {
    contextWindowTokens: 4096,
    contextWindowChars: 16000,  // ~4 chars per token estimate
    maxCompletionTokens: 1024,
    hasFunctionCalling: false,
    hasVision: false,
    costPerInputToken: 0,
    costPerOutputToken: 0
};

/**
 * Model capabilities for common models
 */
export const MODEL_CAPABILITIES: Record<string, Partial<ModelCapabilities>> = {
    // OpenAI models
    'gpt-3.5-turbo': {
        contextWindowTokens: 4096,
        contextWindowChars: 16000,
        hasFunctionCalling: true
    },
    'gpt-3.5-turbo-16k': {
        contextWindowTokens: 16384,
        contextWindowChars: 65000,
        hasFunctionCalling: true
    },
    'gpt-4': {
        contextWindowTokens: 8192,
        contextWindowChars: 32000,
        hasFunctionCalling: true
    },
    'gpt-4-32k': {
        contextWindowTokens: 32768,
        contextWindowChars: 130000,
        hasFunctionCalling: true
    },
    'gpt-4-turbo': {
        contextWindowTokens: 128000,
        contextWindowChars: 512000,
        hasFunctionCalling: true,
        hasVision: true
    },
    'gpt-4o': {
        contextWindowTokens: 128000,
        contextWindowChars: 512000,
        hasFunctionCalling: true,
        hasVision: true
    },

    // Anthropic models
    'claude-3-haiku': {
        contextWindowTokens: 200000,
        contextWindowChars: 800000,
        hasVision: true
    },
    'claude-3-sonnet': {
        contextWindowTokens: 200000,
        contextWindowChars: 800000,
        hasVision: true
    },
    'claude-3-opus': {
        contextWindowTokens: 200000,
        contextWindowChars: 800000,
        hasVision: true
    },
    'claude-2': {
        contextWindowTokens: 100000,
        contextWindowChars: 400000
    },

    // Ollama models (defaults, will be updated dynamically)
    'llama3': {
        contextWindowTokens: 8192,
        contextWindowChars: 32000
    },
    'mistral': {
        contextWindowTokens: 8192,
        contextWindowChars: 32000
    },
    'llama2': {
        contextWindowTokens: 4096,
        contextWindowChars: 16000
    }
};

/**
 * Calculate available context window size for context generation
 * This takes into account expected message sizes and other overhead
 *
 * @param model Model name
 * @param messages Current conversation messages
 * @param expectedTurns Number of expected additional conversation turns
 * @returns Available context size in characters
 */
export function calculateAvailableContextSize(
    modelCapabilities: ModelCapabilities,
    messages: Message[],
    expectedTurns: number = 3
): number {
    // Calculate current message token usage (rough estimate)
    let currentMessageChars = 0;
    for (const message of messages) {
        currentMessageChars += message.content.length;
    }

    // Reserve space for system prompt and overhead
    const systemPromptReserve = 1000;

    // Reserve space for expected conversation turns
    const turnReserve = expectedTurns * 2000; // Average 2000 chars per turn (including both user and assistant)

    // Calculate available space
    const totalReserved = currentMessageChars + systemPromptReserve + turnReserve;
    const availableContextSize = Math.max(0, modelCapabilities.contextWindowChars - totalReserved);

    // Use at most 70% of total context window size to be safe
    const maxSafeContextSize = Math.floor(modelCapabilities.contextWindowChars * 0.7);

    // Return the smaller of available size or max safe size
    return Math.min(availableContextSize, maxSafeContextSize);
}
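
As a worked example of the sizing math above (a sketch: the conversation content is hypothetical, character counts are approximate, and import paths assume the caller sits next to this file):

    import { MODEL_CAPABILITIES, DEFAULT_MODEL_CAPABILITIES, calculateAvailableContextSize } from './model_capabilities.js';
    import type { Message } from '../ai_interface.js';

    // A short hypothetical conversation, roughly 120 characters of content in total.
    const messages: Message[] = [
        { role: 'user', content: 'Summarize my notes on quarterly planning.' },
        { role: 'assistant', content: 'Sure - which quarter are you interested in?' },
        { role: 'user', content: 'Q3, please, with a focus on hiring.' }
    ];

    // gpt-4 is listed above with contextWindowChars = 32000.
    const gpt4 = { ...DEFAULT_MODEL_CAPABILITIES, ...MODEL_CAPABILITIES['gpt-4'] };

    // available = 32000 - (~120 message chars + 1000 system reserve + 3 * 2000 turn reserve) ≈ 24880
    // safe cap  = floor(32000 * 0.7) = 22400
    // result    = min(≈24880, 22400) = 22400 characters available for note context
    console.log(calculateAvailableContextSize(gpt4, messages, 3));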
							
								
								
									
src/services/llm/model_capabilities_service.ts (new file, 159 lines)
@@ -0,0 +1,159 @@
import log from '../log.js';
import type { ModelCapabilities } from './interfaces/model_capabilities.js';
import { MODEL_CAPABILITIES, DEFAULT_MODEL_CAPABILITIES } from './interfaces/model_capabilities.js';
import aiServiceManager from './ai_service_manager.js';
import { getEmbeddingProvider } from './providers/providers.js';
import type { BaseEmbeddingProvider } from './embeddings/base_embeddings.js';
import type { EmbeddingModelInfo } from './interfaces/embedding_interfaces.js';

// Define a type for embedding providers that might have the getModelInfo method
interface EmbeddingProviderWithModelInfo {
    getModelInfo?: (modelName: string) => Promise<EmbeddingModelInfo>;
}

/**
 * Service for fetching and caching model capabilities
 */
export class ModelCapabilitiesService {
    // Cache model capabilities
    private capabilitiesCache: Map<string, ModelCapabilities> = new Map();

    constructor() {
        // Initialize cache with known models
        this.initializeCache();
    }

    /**
     * Initialize the cache with known model capabilities
     */
    private initializeCache() {
        // Add all predefined model capabilities to cache
        for (const [model, capabilities] of Object.entries(MODEL_CAPABILITIES)) {
            this.capabilitiesCache.set(model, {
                ...DEFAULT_MODEL_CAPABILITIES,
                ...capabilities
            });
        }
    }

    /**
     * Get model capabilities, fetching from provider if needed
     *
     * @param modelName Full model name (with or without provider prefix)
     * @returns Model capabilities
     */
    async getModelCapabilities(modelName: string): Promise<ModelCapabilities> {
        // Handle provider-prefixed model names (e.g., "openai:gpt-4")
        let provider = 'default';
        let baseModelName = modelName;

        if (modelName.includes(':')) {
            const parts = modelName.split(':');
            provider = parts[0];
            baseModelName = parts[1];
        }

        // Check cache first
        const cacheKey = baseModelName;
        if (this.capabilitiesCache.has(cacheKey)) {
            return this.capabilitiesCache.get(cacheKey)!;
        }

        // Fetch from provider if possible
        try {
            // Get provider service
            const providerService = aiServiceManager.getService(provider);

            if (providerService && typeof (providerService as any).getModelCapabilities === 'function') {
                // If provider supports direct capability fetching, use it
                const capabilities = await (providerService as any).getModelCapabilities(baseModelName);

                if (capabilities) {
                    // Merge with defaults and cache
                    const fullCapabilities = {
                        ...DEFAULT_MODEL_CAPABILITIES,
                        ...capabilities
                    };

                    this.capabilitiesCache.set(cacheKey, fullCapabilities);
                    log.info(`Fetched capabilities for ${modelName}: context window ${fullCapabilities.contextWindowTokens} tokens`);

                    return fullCapabilities;
                }
            }

            // Try to fetch from embedding provider if available
            const embeddingProvider = getEmbeddingProvider(provider);

            if (embeddingProvider) {
                try {
                    // Cast to a type that might have getModelInfo method
                    const providerWithModelInfo = embeddingProvider as unknown as EmbeddingProviderWithModelInfo;

                    if (providerWithModelInfo.getModelInfo) {
                        const modelInfo = await providerWithModelInfo.getModelInfo(baseModelName);

                        if (modelInfo && modelInfo.contextWidth) {
                            // Convert to our capabilities format
                            const capabilities: ModelCapabilities = {
                                ...DEFAULT_MODEL_CAPABILITIES,
                                contextWindowTokens: modelInfo.contextWidth,
                                contextWindowChars: modelInfo.contextWidth * 4 // Rough estimate: 4 chars per token
                            };

                            this.capabilitiesCache.set(cacheKey, capabilities);
                            log.info(`Derived capabilities for ${modelName} from embedding provider: context window ${capabilities.contextWindowTokens} tokens`);

                            return capabilities;
                        }
                    }
                } catch (error) {
                    log.info(`Could not get model info from embedding provider for ${modelName}: ${error}`);
                }
            }
        } catch (error) {
            log.error(`Error fetching model capabilities for ${modelName}: ${error}`);
        }

        // If we get here, try to find a similar model in our predefined list
        for (const knownModel of Object.keys(MODEL_CAPABILITIES)) {
            // Check if the model name contains this known model (e.g., "gpt-4-1106-preview" contains "gpt-4")
            if (baseModelName.includes(knownModel)) {
                const capabilities = {
                    ...DEFAULT_MODEL_CAPABILITIES,
                    ...MODEL_CAPABILITIES[knownModel]
                };

                this.capabilitiesCache.set(cacheKey, capabilities);
                log.info(`Using similar model (${knownModel}) capabilities for ${modelName}`);

                return capabilities;
            }
        }

        // Fall back to defaults if nothing else works
        log.info(`Using default capabilities for unknown model ${modelName}`);
        this.capabilitiesCache.set(cacheKey, DEFAULT_MODEL_CAPABILITIES);

        return DEFAULT_MODEL_CAPABILITIES;
    }

    /**
     * Update model capabilities in the cache
     *
     * @param modelName Model name
     * @param capabilities Capabilities to update
     */
    updateModelCapabilities(modelName: string, capabilities: Partial<ModelCapabilities>) {
        const currentCapabilities = this.capabilitiesCache.get(modelName) || DEFAULT_MODEL_CAPABILITIES;

        this.capabilitiesCache.set(modelName, {
            ...currentCapabilities,
            ...capabilities
        });
    }
}

// Create and export singleton instance
const modelCapabilitiesService = new ModelCapabilitiesService();
export default modelCapabilitiesService;
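
A rough sketch of how a caller might combine the two new modules (the function name and the provider-prefixed model id are illustrative; import paths assume a caller in src/services/llm/):

    import modelCapabilitiesService from './model_capabilities_service.js';
    import { calculateAvailableContextSize } from './interfaces/model_capabilities.js';
    import type { Message } from './ai_interface.js';

    async function planContextBudget(modelName: string, messages: Message[]): Promise<number> {
        // Cache-first lookup; falls back to provider queries, similar-model matching, then defaults.
        const capabilities = await modelCapabilitiesService.getModelCapabilities(modelName);

        // Leave headroom for the system prompt and ~3 further conversation turns.
        return calculateAvailableContextSize(capabilities, messages, 3);
    }

    // e.g. planContextBudget('openai:gpt-4o', history) resolves to the number of characters
    // that can safely be spent on note context for this conversation.
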
@@ -106,7 +106,8 @@ export class ChatPipeline {
                     // Get semantic context for regular queries
                     const semanticContext = await this.stages.semanticContextExtraction.execute({
                         noteId: input.noteId,
-                        query: input.query
+                        query: input.query,
+                        messages: input.messages
                     });
                     context = semanticContext.context;
                     this.updateStageMetrics('semanticContextExtraction', contextStartTime);
@@ -136,10 +137,10 @@
             const llmStartTime = Date.now();

             // Setup streaming handler if streaming is enabled and callback provided
-            const enableStreaming = this.config.enableStreaming && 
+            const enableStreaming = this.config.enableStreaming &&
                                   modelSelection.options.stream !== false &&
                                   typeof streamCallback === 'function';
-            
+
             if (enableStreaming) {
                 // Make sure stream is enabled in options
                 modelSelection.options.stream = true;
@@ -157,10 +158,10 @@
                 await completion.response.stream(async (chunk: StreamChunk) => {
                     // Process the chunk text
                     const processedChunk = await this.processStreamChunk(chunk, input.options);
-                    
+
                     // Accumulate text for final response
                     accumulatedText += processedChunk.text;
-                    
+
                     // Forward to callback
                     await streamCallback!(processedChunk.text, processedChunk.done);
                 });
@@ -182,12 +183,12 @@

             const endTime = Date.now();
             const executionTime = endTime - startTime;
-            
+
             // Update overall average execution time
-            this.metrics.averageExecutionTime = 
+            this.metrics.averageExecutionTime =
                 (this.metrics.averageExecutionTime * (this.metrics.totalExecutions - 1) + executionTime) /
                 this.metrics.totalExecutions;
-                
+
             log.info(`Chat pipeline completed in ${executionTime}ms`);

             return finalResponse;
@@ -235,12 +236,12 @@
      */
     private updateStageMetrics(stageName: string, startTime: number) {
         if (!this.config.enableMetrics) return;
-        
+
         const executionTime = Date.now() - startTime;
         const metrics = this.metrics.stageMetrics[stageName];
-        
+
         metrics.totalExecutions++;
-        metrics.averageExecutionTime = 
+        metrics.averageExecutionTime =
             (metrics.averageExecutionTime * (metrics.totalExecutions - 1) + executionTime) /
             metrics.totalExecutions;
     }
@@ -258,7 +259,7 @@
     resetMetrics(): void {
         this.metrics.totalExecutions = 0;
         this.metrics.averageExecutionTime = 0;
-        
+
         Object.keys(this.metrics.stageMetrics).forEach(stageName => {
             this.metrics.stageMetrics[stageName] = {
                 totalExecutions: 0,

@@ -15,12 +15,12 @@ export interface ChatPipelineConfig {
      * Whether to enable streaming support
      */
     enableStreaming: boolean;
-    
+
     /**
      * Whether to enable performance metrics
      */
     enableMetrics: boolean;
-    
+
     /**
      * Maximum number of tool call iterations
      */
@@ -84,6 +84,7 @@ export interface SemanticContextExtractionInput extends PipelineInput {
     noteId: string;
     query: string;
     maxResults?: number;
+    messages?: Message[];
 }

 /**

@@ -15,11 +15,11 @@ export class SemanticContextExtractionStage extends BasePipelineStage<SemanticCo
      * Extract semantic context based on a query
      */
     protected async process(input: SemanticContextExtractionInput): Promise<{ context: string }> {
-        const { noteId, query, maxResults = 5 } = input;
+        const { noteId, query, maxResults = 5, messages = [] } = input;
         log.info(`Extracting semantic context from note ${noteId}, query: ${query?.substring(0, 50)}...`);

         const contextService = aiServiceManager.getContextService();
-        const context = await contextService.getSemanticContext(noteId, query, maxResults);
+        const context = await contextService.getSemanticContext(noteId, query, maxResults, messages);

         return { context };
     }

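Taken together, the change threads the conversation history from the pipeline input down to the formatter. A minimal sketch of invoking the stage with the new field (stage wiring and variable names are assumed, not shown in the diff):

    const { context } = await stages.semanticContextExtraction.execute({
        noteId,
        query,
        maxResults: 5,
        messages: priorMessages   // optional; used to shrink the note context as the conversation grows
    });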