mirror of
				https://github.com/zadam/trilium.git
				synced 2025-10-31 10:26:08 +01:00 
			
		
		
		
	Tool calling is close to working
Getting closer to calling tools... we definitely need this closer to tool execution... agentic tool calling is... kind of working?
This commit is contained in:
		| @@ -7,6 +7,10 @@ import { MessagePreparationStage } from './stages/message_preparation_stage.js'; | ||||
| import { ModelSelectionStage } from './stages/model_selection_stage.js'; | ||||
| import { LLMCompletionStage } from './stages/llm_completion_stage.js'; | ||||
| import { ResponseProcessingStage } from './stages/response_processing_stage.js'; | ||||
| import { ToolCallingStage } from './stages/tool_calling_stage.js'; | ||||
| import { VectorSearchStage } from './stages/vector_search_stage.js'; | ||||
| import toolRegistry from '../tools/tool_registry.js'; | ||||
| import toolInitializer from '../tools/tool_initializer.js'; | ||||
| import log from '../../log.js'; | ||||
|  | ||||
| /** | ||||
| @@ -22,6 +26,8 @@ export class ChatPipeline { | ||||
|         modelSelection: ModelSelectionStage; | ||||
|         llmCompletion: LLMCompletionStage; | ||||
|         responseProcessing: ResponseProcessingStage; | ||||
|         toolCalling: ToolCallingStage; | ||||
|         vectorSearch: VectorSearchStage; | ||||
|     }; | ||||
|  | ||||
|     config: ChatPipelineConfig; | ||||
| @@ -40,7 +46,9 @@ export class ChatPipeline { | ||||
|             messagePreparation: new MessagePreparationStage(), | ||||
|             modelSelection: new ModelSelectionStage(), | ||||
|             llmCompletion: new LLMCompletionStage(), | ||||
|             responseProcessing: new ResponseProcessingStage() | ||||
|             responseProcessing: new ResponseProcessingStage(), | ||||
|             toolCalling: new ToolCallingStage(), | ||||
|             vectorSearch: new VectorSearchStage() | ||||
|         }; | ||||
|  | ||||
|         // Set default configuration values | ||||
| @@ -87,6 +95,34 @@ export class ChatPipeline { | ||||
|                 contentLength += message.content.length; | ||||
|             } | ||||
|  | ||||
|             // Initialize tools if needed | ||||
|             try { | ||||
|                 const toolCount = toolRegistry.getAllTools().length; | ||||
|  | ||||
|                 // If there are no tools registered, initialize them | ||||
|                 if (toolCount === 0) { | ||||
|                     log.info('No tools found in registry, initializing tools...'); | ||||
|                     await toolInitializer.initializeTools(); | ||||
|                     log.info(`Tools initialized, now have ${toolRegistry.getAllTools().length} tools`); | ||||
|                 } else { | ||||
|                     log.info(`Found ${toolCount} tools already registered`); | ||||
|                 } | ||||
|             } catch (error: any) { | ||||
|                 log.error(`Error checking/initializing tools: ${error.message || String(error)}`); | ||||
|             } | ||||
|  | ||||
|             // First, select the appropriate model based on query complexity and content length | ||||
|             const modelSelectionStartTime = Date.now(); | ||||
|             const modelSelection = await this.stages.modelSelection.execute({ | ||||
|                 options: input.options, | ||||
|                 query: input.query, | ||||
|                 contentLength | ||||
|             }); | ||||
|             this.updateStageMetrics('modelSelection', modelSelectionStartTime); | ||||
|  | ||||
|             // Determine if we should use tools or semantic context | ||||
|             const useTools = modelSelection.options.enableTools === true; | ||||
|  | ||||
|             // Determine which pipeline flow to use | ||||
|             let context: string | undefined; | ||||
|  | ||||
| @@ -102,27 +138,63 @@ export class ChatPipeline { | ||||
|                     }); | ||||
|                     context = agentContext.context; | ||||
|                     this.updateStageMetrics('agentToolsContext', contextStartTime); | ||||
|                 } else { | ||||
|                     // Get semantic context for regular queries | ||||
|                 } else if (!useTools) { | ||||
|                     // Only get semantic context if tools are NOT enabled | ||||
|                     // When tools are enabled, we'll let the LLM request context via tools instead | ||||
|                     log.info('Getting semantic context for note using pipeline stages'); | ||||
|                      | ||||
|                     // First use the vector search stage to find relevant notes | ||||
|                     const vectorSearchStartTime = Date.now(); | ||||
|                     log.info(`Executing vector search stage for query: "${input.query?.substring(0, 50)}..."`); | ||||
|                      | ||||
|                     const vectorSearchResult = await this.stages.vectorSearch.execute({ | ||||
|                         query: input.query || '', | ||||
|                         noteId: input.noteId, | ||||
|                         options: { | ||||
|                             maxResults: 10, | ||||
|                             useEnhancedQueries: true, | ||||
|                             threshold: 0.6 | ||||
|                         } | ||||
|                     }); | ||||
|                      | ||||
|                     this.updateStageMetrics('vectorSearch', vectorSearchStartTime); | ||||
|                      | ||||
|                     log.info(`Vector search found ${vectorSearchResult.searchResults.length} relevant notes`); | ||||
|                      | ||||
|                     // Then pass to the semantic context stage to build the formatted context | ||||
|                     const semanticContext = await this.stages.semanticContextExtraction.execute({ | ||||
|                         noteId: input.noteId, | ||||
|                         query: input.query, | ||||
|                         messages: input.messages | ||||
|                     }); | ||||
|                      | ||||
|                     context = semanticContext.context; | ||||
|                     this.updateStageMetrics('semanticContextExtraction', contextStartTime); | ||||
|                 } else { | ||||
|                     log.info('Tools are enabled - using minimal direct context to avoid race conditions'); | ||||
|                     // Get context from current note directly without semantic search | ||||
|                     if (input.noteId) { | ||||
|                         try { | ||||
|                             const contextExtractor = new (await import('../../llm/context/index.js')).ContextExtractor(); | ||||
|                             // Just get the direct content of the current note | ||||
|                             context = await contextExtractor.extractContext(input.noteId, { | ||||
|                                 includeContent: true, | ||||
|                                 includeParents: true, | ||||
|                                 includeChildren: true, | ||||
|                                 includeLinks: true, | ||||
|                                 includeSimilar: false // Skip semantic search to avoid race conditions | ||||
|                             }); | ||||
|                             log.info(`Direct context extracted (${context.length} chars) without semantic search`); | ||||
|                         } catch (error: any) { | ||||
|                             log.error(`Error extracting direct context: ${error.message}`); | ||||
|                             context = ""; // Fallback to empty context if extraction fails | ||||
|                         } | ||||
|                     } else { | ||||
|                         context = ""; // No note ID, so no context | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             // Select the appropriate model based on query complexity and content length | ||||
|             const modelSelectionStartTime = Date.now(); | ||||
|             const modelSelection = await this.stages.modelSelection.execute({ | ||||
|                 options: input.options, | ||||
|                 query: input.query, | ||||
|                 contentLength | ||||
|             }); | ||||
|             this.updateStageMetrics('modelSelection', modelSelectionStartTime); | ||||
|  | ||||
|             // Prepare messages with context and system prompt | ||||
|             const messagePreparationStartTime = Date.now(); | ||||
|             const preparedMessages = await this.stages.messagePreparation.execute({ | ||||
| @@ -167,17 +239,106 @@ export class ChatPipeline { | ||||
|                 }); | ||||
|             } | ||||
|  | ||||
|             // For non-streaming responses, process the full response | ||||
|             // Process any tool calls in the response | ||||
|             let currentMessages = preparedMessages.messages; | ||||
|             let currentResponse = completion.response; | ||||
|             let needsFollowUp = false; | ||||
|             let toolCallIterations = 0; | ||||
|             const maxToolCallIterations = this.config.maxToolCallIterations; | ||||
|  | ||||
|             // Check if tools were enabled in the options | ||||
|             const toolsEnabled = modelSelection.options.enableTools !== false; | ||||
|              | ||||
|             log.info(`========== TOOL CALL PROCESSING ==========`); | ||||
|             log.info(`Tools enabled: ${toolsEnabled}`); | ||||
|             log.info(`Tool calls in response: ${currentResponse.tool_calls ? currentResponse.tool_calls.length : 0}`); | ||||
|             log.info(`Current response format: ${typeof currentResponse}`); | ||||
|             log.info(`Response keys: ${Object.keys(currentResponse).join(', ')}`); | ||||
|              | ||||
|             // Detailed tool call inspection | ||||
|             if (currentResponse.tool_calls) { | ||||
|                 currentResponse.tool_calls.forEach((tool, idx) => { | ||||
|                     log.info(`Tool call ${idx+1}: ${JSON.stringify(tool)}`); | ||||
|                 }); | ||||
|             } | ||||
|  | ||||
|             // Process tool calls if present and tools are enabled | ||||
|             if (toolsEnabled && currentResponse.tool_calls && currentResponse.tool_calls.length > 0) { | ||||
|                 log.info(`Response contains ${currentResponse.tool_calls.length} tool calls, processing...`); | ||||
|  | ||||
|                 // Start tool calling loop | ||||
|                 log.info(`Starting tool calling loop with max ${maxToolCallIterations} iterations`); | ||||
|  | ||||
|                 do { | ||||
|                     log.info(`Tool calling iteration ${toolCallIterations + 1}`); | ||||
|  | ||||
|                     // Execute tool calling stage | ||||
|                     const toolCallingStartTime = Date.now(); | ||||
|                     const toolCallingResult = await this.stages.toolCalling.execute({ | ||||
|                         response: currentResponse, | ||||
|                         messages: currentMessages, | ||||
|                         options: modelSelection.options | ||||
|                     }); | ||||
|                     this.updateStageMetrics('toolCalling', toolCallingStartTime); | ||||
|  | ||||
|                     // Update state for next iteration | ||||
|                     currentMessages = toolCallingResult.messages; | ||||
|                     needsFollowUp = toolCallingResult.needsFollowUp; | ||||
|  | ||||
|                     // Make another call to the LLM if needed | ||||
|                     if (needsFollowUp) { | ||||
|                         log.info(`Tool execution completed, making follow-up LLM call (iteration ${toolCallIterations + 1})...`); | ||||
|  | ||||
|                         // Generate a new LLM response with the updated messages | ||||
|                         const followUpStartTime = Date.now(); | ||||
|                         log.info(`Sending follow-up request to LLM with ${currentMessages.length} messages (including tool results)`); | ||||
|  | ||||
|                         const followUpCompletion = await this.stages.llmCompletion.execute({ | ||||
|                             messages: currentMessages, | ||||
|                             options: modelSelection.options | ||||
|                         }); | ||||
|                         this.updateStageMetrics('llmCompletion', followUpStartTime); | ||||
|  | ||||
|                         // Update current response for next iteration | ||||
|                         currentResponse = followUpCompletion.response; | ||||
|  | ||||
|                         // Check for more tool calls | ||||
|                         const hasMoreToolCalls = !!(currentResponse.tool_calls && currentResponse.tool_calls.length > 0); | ||||
|  | ||||
|                         if (hasMoreToolCalls) { | ||||
|                             log.info(`Follow-up response contains ${currentResponse.tool_calls?.length || 0} more tool calls`); | ||||
|                         } else { | ||||
|                             log.info(`Follow-up response contains no more tool calls - completing tool loop`); | ||||
|                         } | ||||
|  | ||||
|                         // Continue loop if there are more tool calls | ||||
|                         needsFollowUp = hasMoreToolCalls; | ||||
|                     } | ||||
|  | ||||
|                     // Increment iteration counter | ||||
|                     toolCallIterations++; | ||||
|  | ||||
|                 } while (needsFollowUp && toolCallIterations < maxToolCallIterations); | ||||
|  | ||||
|                 // If we hit max iterations but still have tool calls, log a warning | ||||
|                 if (toolCallIterations >= maxToolCallIterations && needsFollowUp) { | ||||
|                     log.error(`Reached maximum tool call iterations (${maxToolCallIterations}), stopping`); | ||||
|                 } | ||||
|  | ||||
|                 log.info(`Completed ${toolCallIterations} tool call iterations`); | ||||
|             } | ||||
|  | ||||
|             // For non-streaming responses, process the final response | ||||
|             const processStartTime = Date.now(); | ||||
|             const processed = await this.stages.responseProcessing.execute({ | ||||
|                 response: completion.response, | ||||
|                 response: currentResponse, | ||||
|                 options: input.options | ||||
|             }); | ||||
|             this.updateStageMetrics('responseProcessing', processStartTime); | ||||
|  | ||||
|             // Combine response with processed text, using accumulated text if streamed | ||||
|             const finalResponse: ChatResponse = { | ||||
|                 ...completion.response, | ||||
|                 ...currentResponse, | ||||
|                 text: accumulatedText || processed.text | ||||
|             }; | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user