Commit: dynamically adjust context window sizes based on conversation context

@@ -69,7 +69,7 @@ export interface SemanticContextService {
     /**
      * Retrieve semantic context based on relevance to user query
      */
-    getSemanticContext(noteId: string, userQuery: string, maxResults?: number): Promise<string>;
+    getSemanticContext(noteId: string, userQuery: string, maxResults?: number, messages?: Message[]): Promise<string>;
 
     /**
      * Get progressive context based on depth

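The new `messages` parameter is optional, so existing callers keep compiling, while callers that hold the running conversation can forward it and let the downstream formatter budget around it. A minimal caller sketch (the `semanticContextService` variable, the note id, and the message contents are illustrative assumptions, not code from this commit; the import path depends on the caller's location):

```typescript
import type { Message } from '../../ai_interface.js';

// Hypothetical conversation so far
const messages: Message[] = [
    { role: 'user', content: 'Summarise my migration notes' },
    { role: 'assistant', content: 'Sure - which migration do you mean?' },
    { role: 'user', content: 'The Trilium sync migration' }
];

// Passing the history lets the formatter shrink the note context
// when the conversation is already long.
const context = await semanticContextService.getSemanticContext(
    'someNoteId',             // illustrative note id
    'Trilium sync migration', // user query
    5,                        // maxResults
    messages                  // new optional argument from this commit
);
```
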
@@ -3,6 +3,9 @@ import log from '../../../log.js';
 import { CONTEXT_PROMPTS, FORMATTING_PROMPTS } from '../../constants/llm_prompt_constants.js';
 import { LLM_CONSTANTS } from '../../constants/provider_constants.js';
 import type { IContextFormatter, NoteSearchResult } from '../../interfaces/context_interfaces.js';
+import modelCapabilitiesService from '../../model_capabilities_service.js';
+import { calculateAvailableContextSize } from '../../interfaces/model_capabilities.js';
+import type { Message } from '../../ai_interface.js';
 
 // Use constants from the centralized file
 // const CONTEXT_WINDOW = {
@@ -20,26 +23,46 @@ import type { IContextFormatter, NoteSearchResult } from '../../interfaces/conte
  */
 export class ContextFormatter implements IContextFormatter {
     /**
-     * Build a structured context string from note sources
+     * Build formatted context from a list of note search results
      *
      * @param sources Array of note data with content and metadata
      * @param query The user's query for context
      * @param providerId Optional provider ID to customize formatting
+     * @param messages Optional conversation messages to adjust context size
      * @returns Formatted context string
      */
-    async buildContextFromNotes(sources: NoteSearchResult[], query: string, providerId: string = 'default'): Promise<string> {
+    async buildContextFromNotes(
+        sources: NoteSearchResult[],
+        query: string,
+        providerId: string = 'default',
+        messages: Message[] = []
+    ): Promise<string> {
         if (!sources || sources.length === 0) {
             log.info('No sources provided to context formatter');
             return CONTEXT_PROMPTS.NO_NOTES_CONTEXT;
         }
 
         try {
-            // Get appropriate context size based on provider
-            const maxTotalLength =
+            // Get model name from provider
+            let modelName = providerId;
+
+            // Look up model capabilities
+            const modelCapabilities = await modelCapabilitiesService.getModelCapabilities(modelName);
+
+            // Calculate available context size for this conversation
+            const availableContextSize = calculateAvailableContextSize(
+                modelCapabilities,
+                messages,
+                3 // Expected additional turns
+            );
+
+            // Use the calculated size or fall back to constants
+            const maxTotalLength = availableContextSize || (
                 providerId === 'openai' ? LLM_CONSTANTS.CONTEXT_WINDOW.OPENAI :
                 providerId === 'anthropic' ? LLM_CONSTANTS.CONTEXT_WINDOW.ANTHROPIC :
                 providerId === 'ollama' ? LLM_CONSTANTS.CONTEXT_WINDOW.OLLAMA :
-                LLM_CONSTANTS.CONTEXT_WINDOW.DEFAULT;
+                LLM_CONSTANTS.CONTEXT_WINDOW.DEFAULT
+            );
 
             // DEBUG: Log context window size
             log.info(`Context window for provider ${providerId}: ${maxTotalLength} chars`);

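Note the `||` fallback above: `calculateAvailableContextSize` clamps at zero, and zero is falsy, so a conversation that already fills the model's window silently reverts to the static per-provider constants. Rough numbers under an assumed 16,000-character window (hypothetical values, using the defaults this commit defines further down):

```typescript
import { DEFAULT_MODEL_CAPABILITIES, calculateAvailableContextSize } from '../../interfaces/model_capabilities.js';

const capabilities = { ...DEFAULT_MODEL_CAPABILITIES, contextWindowChars: 16000 };

// Short conversation: 16000 - (2 + 1000 + 3 * 2000) = 8998 chars available,
// which is below the 70% safety cap (11200), so 8998 is returned.
calculateAvailableContextSize(capabilities, [{ role: 'user', content: 'hi' }]);

// Very long conversation (e.g. 20000 chars of history): the reserve exceeds
// the window, the result clamps to 0, and buildContextFromNotes falls back
// to LLM_CONSTANTS.CONTEXT_WINDOW.* for the provider.
```
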
@@ -10,6 +10,7 @@ import { CONTEXT_PROMPTS } from '../../constants/llm_prompt_constants.js';
 import becca from '../../../../becca/becca.js';
 import type { NoteSearchResult } from '../../interfaces/context_interfaces.js';
 import type { LLMServiceInterface } from '../../interfaces/agent_tool_interfaces.js';
+import type { Message } from '../../ai_interface.js';
 
 /**
  * Main context service that integrates all context-related functionality
@@ -635,14 +636,20 @@ export class ContextService {
     }
 
     /**
-     * Get semantic context for a note and query
+     * Get semantic context based on query
      *
-     * @param noteId - The base note ID
-     * @param userQuery - The user's query
-     * @param maxResults - Maximum number of results to include
-     * @returns Formatted context string
+     * @param noteId - Note ID to start from
+     * @param userQuery - User query for context
+     * @param maxResults - Maximum number of results
+     * @param messages - Optional conversation messages to adjust context size
+     * @returns Formatted context
      */
-    async getSemanticContext(noteId: string, userQuery: string, maxResults: number = 5): Promise<string> {
+    async getSemanticContext(
+        noteId: string,
+        userQuery: string,
+        maxResults: number = 5,
+        messages: Message[] = []
+    ): Promise<string> {
         if (!this.initialized) {
             await this.initialize();
         }
@@ -712,24 +719,39 @@ export class ContextService {
 
             // Get content for the top N most relevant notes
             const mostRelevantNotes = rankedNotes.slice(0, maxResults);
-            const relevantContent = await Promise.all(
+
+            // Get relevant search results to pass to context formatter
+            const searchResults = await Promise.all(
                 mostRelevantNotes.map(async note => {
                     const content = await this.contextExtractor.getNoteContent(note.noteId);
                     if (!content) return null;
 
-                    // Format with relevance score and title
-                    return `### ${note.title} (Relevance: ${Math.round(note.relevance * 100)}%)\n\n${content}`;
+                    // Create a properly typed NoteSearchResult object
+                    return {
+                        noteId: note.noteId,
+                        title: note.title,
+                        content,
+                        similarity: note.relevance
+                    };
                 })
             );
 
+            // Filter out nulls and empty content
+            const validResults: NoteSearchResult[] = searchResults
+                .filter(result => result !== null && result.content && result.content.trim().length > 0)
+                .map(result => result as NoteSearchResult);
+
             // If no content retrieved, return empty string
-            if (!relevantContent.filter(Boolean).length) {
+            if (validResults.length === 0) {
                 return '';
             }
 
-            return `# Relevant Context\n\nThe following notes are most relevant to your query:\n\n${
-                relevantContent.filter(Boolean).join('\n\n---\n\n')
-            }`;
+            // Get the provider information for formatting
+            const provider = await providerManager.getPreferredEmbeddingProvider();
+            const providerId = provider?.name || 'default';
+
+            // Format the context with the context formatter (which handles adjusting for conversation size)
+            return contextFormatter.buildContextFromNotes(validResults, userQuery, providerId, messages);
         } catch (error) {
             log.error(`Error getting semantic context: ${error}`);
             return '';

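One small typing note on the filter above: it narrows away `null` with a cast in the `.map` step. An equivalent sketch using a user-defined type guard would let the compiler do the narrowing itself (shown only as an alternative, and assuming `NoteSearchResult`'s required fields are exactly the four set in the map above):

```typescript
const validResults: NoteSearchResult[] = searchResults.filter(
    (result): result is NoteSearchResult =>
        result !== null && result.content.trim().length > 0
);
```
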
@@ -154,10 +154,11 @@ class TriliumContextService {
      * @param noteId - The note ID
      * @param userQuery - The user's query
      * @param maxResults - Maximum results to include
+     * @param messages - Optional conversation messages to adjust context size
      * @returns Formatted context string
      */
-    async getSemanticContext(noteId: string, userQuery: string, maxResults = 5): Promise<string> {
-        return contextService.getSemanticContext(noteId, userQuery, maxResults);
+    async getSemanticContext(noteId: string, userQuery: string, maxResults = 5, messages: Message[] = []): Promise<string> {
+        return contextService.getSemanticContext(noteId, userQuery, maxResults, messages);
     }
 
     /**

@@ -46,7 +46,12 @@ export interface NoteSearchResult {
  * Interface for context formatter
  */
 export interface IContextFormatter {
-  buildContextFromNotes(sources: NoteSearchResult[], query: string, providerId?: string): Promise<string>;
+  buildContextFromNotes(
+    sources: NoteSearchResult[],
+    query: string,
+    providerId?: string,
+    messages?: Array<{role: string, content: string}>
+  ): Promise<string>;
 }
 
 /**

							
								
								
									
src/services/llm/interfaces/model_capabilities.ts (new file, 138 lines)
@@ -0,0 +1,138 @@
import type { Message } from '../ai_interface.js';

/**
 * Interface for model capabilities information
 */
export interface ModelCapabilities {
    contextWindowTokens: number;  // Context window size in tokens
    contextWindowChars: number;   // Estimated context window size in characters (for planning)
    maxCompletionTokens: number;  // Maximum completion length
    hasFunctionCalling: boolean;  // Whether the model supports function calling
    hasVision: boolean;           // Whether the model supports image input
    costPerInputToken: number;    // Cost per input token (if applicable)
    costPerOutputToken: number;   // Cost per output token (if applicable)
}

/**
 * Default model capabilities for unknown models
 */
export const DEFAULT_MODEL_CAPABILITIES: ModelCapabilities = {
    contextWindowTokens: 4096,
    contextWindowChars: 16000,  // ~4 chars per token estimate
    maxCompletionTokens: 1024,
    hasFunctionCalling: false,
    hasVision: false,
    costPerInputToken: 0,
    costPerOutputToken: 0
};

/**
 * Model capabilities for common models
 */
export const MODEL_CAPABILITIES: Record<string, Partial<ModelCapabilities>> = {
    // OpenAI models
    'gpt-3.5-turbo': {
        contextWindowTokens: 4096,
        contextWindowChars: 16000,
        hasFunctionCalling: true
    },
    'gpt-3.5-turbo-16k': {
        contextWindowTokens: 16384,
        contextWindowChars: 65000,
        hasFunctionCalling: true
    },
    'gpt-4': {
        contextWindowTokens: 8192,
        contextWindowChars: 32000,
        hasFunctionCalling: true
    },
    'gpt-4-32k': {
        contextWindowTokens: 32768,
        contextWindowChars: 130000,
        hasFunctionCalling: true
    },
    'gpt-4-turbo': {
        contextWindowTokens: 128000,
        contextWindowChars: 512000,
        hasFunctionCalling: true,
        hasVision: true
    },
    'gpt-4o': {
        contextWindowTokens: 128000,
        contextWindowChars: 512000,
        hasFunctionCalling: true,
        hasVision: true
    },

    // Anthropic models
    'claude-3-haiku': {
        contextWindowTokens: 200000,
        contextWindowChars: 800000,
        hasVision: true
    },
    'claude-3-sonnet': {
        contextWindowTokens: 200000,
        contextWindowChars: 800000,
        hasVision: true
    },
    'claude-3-opus': {
        contextWindowTokens: 200000,
        contextWindowChars: 800000,
        hasVision: true
    },
    'claude-2': {
        contextWindowTokens: 100000,
        contextWindowChars: 400000
    },

    // Ollama models (defaults, will be updated dynamically)
    'llama3': {
        contextWindowTokens: 8192,
        contextWindowChars: 32000
    },
    'mistral': {
        contextWindowTokens: 8192,
        contextWindowChars: 32000
    },
    'llama2': {
        contextWindowTokens: 4096,
        contextWindowChars: 16000
    }
};

/**
 * Calculate available context window size for context generation
 * This takes into account expected message sizes and other overhead
 *
 * @param modelCapabilities Capabilities of the target model
 * @param messages Current conversation messages
 * @param expectedTurns Number of expected additional conversation turns
 * @returns Available context size in characters
 */
export function calculateAvailableContextSize(
    modelCapabilities: ModelCapabilities,
    messages: Message[],
    expectedTurns: number = 3
): number {
    // Calculate current message character usage (rough estimate)
    let currentMessageChars = 0;
    for (const message of messages) {
        currentMessageChars += message.content.length;
    }

    // Reserve space for system prompt and overhead
    const systemPromptReserve = 1000;

    // Reserve space for expected conversation turns
    const turnReserve = expectedTurns * 2000; // Average 2000 chars per turn (including both user and assistant)

    // Calculate available space
    const totalReserved = currentMessageChars + systemPromptReserve + turnReserve;
    const availableContextSize = Math.max(0, modelCapabilities.contextWindowChars - totalReserved);

    // Use at most 70% of total context window size to be safe
    const maxSafeContextSize = Math.floor(modelCapabilities.contextWindowChars * 0.7);

    // Return the smaller of available size or max safe size
    return Math.min(availableContextSize, maxSafeContextSize);
}
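Entries in `MODEL_CAPABILITIES` are deliberately partial; consumers merge them over `DEFAULT_MODEL_CAPABILITIES` (the capabilities service below does exactly this when it seeds its cache). A quick sketch of that merge (import path assumed relative to the interfaces directory):

```typescript
import { DEFAULT_MODEL_CAPABILITIES, MODEL_CAPABILITIES, type ModelCapabilities } from './model_capabilities.js';

const merged: ModelCapabilities = {
    ...DEFAULT_MODEL_CAPABILITIES,
    ...MODEL_CAPABILITIES['gpt-4']
};
// merged.contextWindowTokens === 8192, merged.hasFunctionCalling === true,
// merged.hasVision stays at the default (false), maxCompletionTokens at 1024.
```
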
							
								
								
									
src/services/llm/model_capabilities_service.ts (new file, 159 lines)
@@ -0,0 +1,159 @@
import log from '../log.js';
import type { ModelCapabilities } from './interfaces/model_capabilities.js';
import { MODEL_CAPABILITIES, DEFAULT_MODEL_CAPABILITIES } from './interfaces/model_capabilities.js';
import aiServiceManager from './ai_service_manager.js';
import { getEmbeddingProvider } from './providers/providers.js';
import type { BaseEmbeddingProvider } from './embeddings/base_embeddings.js';
import type { EmbeddingModelInfo } from './interfaces/embedding_interfaces.js';

// Define a type for embedding providers that might have the getModelInfo method
interface EmbeddingProviderWithModelInfo {
    getModelInfo?: (modelName: string) => Promise<EmbeddingModelInfo>;
}

/**
 * Service for fetching and caching model capabilities
 */
export class ModelCapabilitiesService {
    // Cache model capabilities
    private capabilitiesCache: Map<string, ModelCapabilities> = new Map();

    constructor() {
        // Initialize cache with known models
        this.initializeCache();
    }

    /**
     * Initialize the cache with known model capabilities
     */
    private initializeCache() {
        // Add all predefined model capabilities to cache
        for (const [model, capabilities] of Object.entries(MODEL_CAPABILITIES)) {
            this.capabilitiesCache.set(model, {
                ...DEFAULT_MODEL_CAPABILITIES,
                ...capabilities
            });
        }
    }

    /**
     * Get model capabilities, fetching from provider if needed
     *
     * @param modelName Full model name (with or without provider prefix)
     * @returns Model capabilities
     */
    async getModelCapabilities(modelName: string): Promise<ModelCapabilities> {
        // Handle provider-prefixed model names (e.g., "openai:gpt-4")
        let provider = 'default';
        let baseModelName = modelName;

        if (modelName.includes(':')) {
            const parts = modelName.split(':');
            provider = parts[0];
            baseModelName = parts[1];
        }

        // Check cache first
        const cacheKey = baseModelName;
        if (this.capabilitiesCache.has(cacheKey)) {
            return this.capabilitiesCache.get(cacheKey)!;
        }

        // Fetch from provider if possible
        try {
            // Get provider service
            const providerService = aiServiceManager.getService(provider);

            if (providerService && typeof (providerService as any).getModelCapabilities === 'function') {
                // If provider supports direct capability fetching, use it
                const capabilities = await (providerService as any).getModelCapabilities(baseModelName);

                if (capabilities) {
                    // Merge with defaults and cache
                    const fullCapabilities = {
                        ...DEFAULT_MODEL_CAPABILITIES,
                        ...capabilities
                    };

                    this.capabilitiesCache.set(cacheKey, fullCapabilities);
                    log.info(`Fetched capabilities for ${modelName}: context window ${fullCapabilities.contextWindowTokens} tokens`);

                    return fullCapabilities;
                }
            }

            // Try to fetch from embedding provider if available
            const embeddingProvider = getEmbeddingProvider(provider);

            if (embeddingProvider) {
                try {
                    // Cast to a type that might have getModelInfo method
                    const providerWithModelInfo = embeddingProvider as unknown as EmbeddingProviderWithModelInfo;

                    if (providerWithModelInfo.getModelInfo) {
                        const modelInfo = await providerWithModelInfo.getModelInfo(baseModelName);

                        if (modelInfo && modelInfo.contextWidth) {
                            // Convert to our capabilities format
                            const capabilities: ModelCapabilities = {
                                ...DEFAULT_MODEL_CAPABILITIES,
                                contextWindowTokens: modelInfo.contextWidth,
                                contextWindowChars: modelInfo.contextWidth * 4 // Rough estimate: 4 chars per token
                            };

                            this.capabilitiesCache.set(cacheKey, capabilities);
                            log.info(`Derived capabilities for ${modelName} from embedding provider: context window ${capabilities.contextWindowTokens} tokens`);

                            return capabilities;
                        }
                    }
                } catch (error) {
                    log.info(`Could not get model info from embedding provider for ${modelName}: ${error}`);
                }
            }
        } catch (error) {
            log.error(`Error fetching model capabilities for ${modelName}: ${error}`);
        }

        // If we get here, try to find a similar model in our predefined list
        for (const knownModel of Object.keys(MODEL_CAPABILITIES)) {
            // Check if the model name contains this known model (e.g., "gpt-4-1106-preview" contains "gpt-4")
            if (baseModelName.includes(knownModel)) {
                const capabilities = {
                    ...DEFAULT_MODEL_CAPABILITIES,
                    ...MODEL_CAPABILITIES[knownModel]
                };

                this.capabilitiesCache.set(cacheKey, capabilities);
                log.info(`Using similar model (${knownModel}) capabilities for ${modelName}`);

                return capabilities;
            }
        }

        // Fall back to defaults if nothing else works
        log.info(`Using default capabilities for unknown model ${modelName}`);
        this.capabilitiesCache.set(cacheKey, DEFAULT_MODEL_CAPABILITIES);

        return DEFAULT_MODEL_CAPABILITIES;
    }

    /**
     * Update model capabilities in the cache
     *
     * @param modelName Model name
     * @param capabilities Capabilities to update
     */
    updateModelCapabilities(modelName: string, capabilities: Partial<ModelCapabilities>) {
        const currentCapabilities = this.capabilitiesCache.get(modelName) || DEFAULT_MODEL_CAPABILITIES;

        this.capabilitiesCache.set(modelName, {
            ...currentCapabilities,
            ...capabilities
        });
    }
}

// Create and export singleton instance
const modelCapabilitiesService = new ModelCapabilitiesService();
export default modelCapabilitiesService;
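Typical use of the singleton from other LLM services, under the lookup order defined above (exact cache hit, then provider query, then substring match, then defaults); the model names are illustrative and the snippet is a sketch, not code from this commit:

```typescript
import modelCapabilitiesService from './model_capabilities_service.js';

// "gpt-4o" is predefined, so this is an immediate cache hit.
const gpt4o = await modelCapabilitiesService.getModelCapabilities('openai:gpt-4o');

// "gpt-4-1106-preview" is not predefined; assuming no provider reports
// capabilities for it, the substring search matches "gpt-4" and reuses that entry.
const preview = await modelCapabilitiesService.getModelCapabilities('gpt-4-1106-preview');

// A provider can push better numbers later, e.g. after querying Ollama:
modelCapabilitiesService.updateModelCapabilities('llama3', {
    contextWindowTokens: 32768,
    contextWindowChars: 128000
});
```
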
@@ -106,7 +106,8 @@
                     // Get semantic context for regular queries
                     const semanticContext = await this.stages.semanticContextExtraction.execute({
                         noteId: input.noteId,
-                        query: input.query
+                        query: input.query,
+                        messages: input.messages
                     });
                     context = semanticContext.context;
                     this.updateStageMetrics('semanticContextExtraction', contextStartTime);
@@ -136,10 +137,10 @@
             const llmStartTime = Date.now();
 
             // Setup streaming handler if streaming is enabled and callback provided
-            const enableStreaming = this.config.enableStreaming && 
+            const enableStreaming = this.config.enableStreaming &&
                                   modelSelection.options.stream !== false &&
                                   typeof streamCallback === 'function';
-            
+
             if (enableStreaming) {
                 // Make sure stream is enabled in options
                 modelSelection.options.stream = true;
@@ -157,10 +158,10 @@
                 await completion.response.stream(async (chunk: StreamChunk) => {
                     // Process the chunk text
                     const processedChunk = await this.processStreamChunk(chunk, input.options);
-                    
+
                     // Accumulate text for final response
                     accumulatedText += processedChunk.text;
-                    
+
                     // Forward to callback
                     await streamCallback!(processedChunk.text, processedChunk.done);
                 });
@@ -182,12 +183,12 @@
 
             const endTime = Date.now();
             const executionTime = endTime - startTime;
-            
+
             // Update overall average execution time
-            this.metrics.averageExecutionTime = 
+            this.metrics.averageExecutionTime =
                 (this.metrics.averageExecutionTime * (this.metrics.totalExecutions - 1) + executionTime) /
                 this.metrics.totalExecutions;
-                
+
             log.info(`Chat pipeline completed in ${executionTime}ms`);
 
             return finalResponse;
@@ -235,12 +236,12 @@
      */
     private updateStageMetrics(stageName: string, startTime: number) {
         if (!this.config.enableMetrics) return;
-        
+
         const executionTime = Date.now() - startTime;
         const metrics = this.metrics.stageMetrics[stageName];
-        
+
         metrics.totalExecutions++;
-        metrics.averageExecutionTime = 
+        metrics.averageExecutionTime =
             (metrics.averageExecutionTime * (metrics.totalExecutions - 1) + executionTime) /
             metrics.totalExecutions;
     }
@@ -258,7 +259,7 @@
     resetMetrics(): void {
         this.metrics.totalExecutions = 0;
         this.metrics.averageExecutionTime = 0;
-        
+
         Object.keys(this.metrics.stageMetrics).forEach(stageName => {
             this.metrics.stageMetrics[stageName] = {
                 totalExecutions: 0,

@@ -15,12 +15,12 @@ export interface ChatPipelineConfig {
      * Whether to enable streaming support
      */
     enableStreaming: boolean;
-    
+
     /**
      * Whether to enable performance metrics
      */
     enableMetrics: boolean;
-    
+
     /**
      * Maximum number of tool call iterations
      */
@@ -84,6 +84,7 @@ export interface SemanticContextExtractionInput extends PipelineInput {
     noteId: string;
     query: string;
     maxResults?: number;
+    messages?: Message[];
 }
 
 /**

@@ -15,11 +15,11 @@ export class SemanticContextExtractionStage extends BasePipelineStage<SemanticCo
      * Extract semantic context based on a query
      */
     protected async process(input: SemanticContextExtractionInput): Promise<{ context: string }> {
-        const { noteId, query, maxResults = 5 } = input;
+        const { noteId, query, maxResults = 5, messages = [] } = input;
         log.info(`Extracting semantic context from note ${noteId}, query: ${query?.substring(0, 50)}...`);
 
         const contextService = aiServiceManager.getContextService();
-        const context = await contextService.getSemanticContext(noteId, query, maxResults);
+        const context = await contextService.getSemanticContext(noteId, query, maxResults, messages);
 
         return { context };
     }

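Taken together, the conversation now flows ChatPipeline → SemanticContextExtractionStage → ContextService.getSemanticContext → ContextFormatter.buildContextFromNotes → calculateAvailableContextSize, so the note context shrinks as the chat grows. A condensed sketch of that path (simplified, not verbatim from the commit):

```typescript
// Inside the pipeline: forward the running chat history to the stage.
const { context } = await this.stages.semanticContextExtraction.execute({
    noteId: input.noteId,
    query: input.query,
    messages: input.messages
});

// Deep inside the formatter, the history ultimately drives the budget:
// const capabilities = await modelCapabilitiesService.getModelCapabilities(providerId);
// const budget = calculateAvailableContextSize(capabilities, messages, 3);
```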