feat(llm): implement error recovery stage and implement better tool calling

This commit is contained in:
perf3ct
2025-07-04 23:16:26 +00:00
parent 5562559b0b
commit 6fbc5b2b14
7 changed files with 1415 additions and 6 deletions

View File

@@ -648,6 +648,90 @@ async function handleStreamingProcess(
}
}
/**
* @swagger
* /api/llm/interactions/{interactionId}/respond:
* post:
* summary: Respond to a user interaction request (confirm/cancel tool execution)
* operationId: llm-interaction-respond
* parameters:
* - name: interactionId
* in: path
* required: true
* schema:
* type: string
* description: The ID of the interaction to respond to
* requestBody:
* required: true
* content:
* application/json:
* schema:
* type: object
* properties:
* response:
* type: string
* enum: [confirm, cancel]
* description: User's response to the interaction
* responses:
* '200':
* description: Response processed successfully
* '404':
* description: Interaction not found
* '400':
* description: Invalid response
* security:
* - session: []
* tags: ["llm"]
*/
/**
 * Handle a user's confirm/cancel answer to a pending tool-execution interaction.
 *
 * Validates the `interactionId` path parameter and the `response` body field,
 * then broadcasts a `user-interaction-response` message over WebSocket.
 *
 * NOTE(review): the message is sent to every connected WebSocket client; no
 * registry of pending interactions is consulted here, so the 404 documented in
 * the Swagger block above is never actually returned by this handler.
 *
 * @param req Express request; expects `params.interactionId` and `body.response`
 * @param res Express response; receives a JSON `{ success, ... }` payload
 */
async function respondToInteraction(req: Request, res: Response): Promise<void> {
    try {
        const interactionId = req.params.interactionId;
        // req.body is undefined when the request carries no parsed body. Destructuring
        // it directly would throw and turn a client mistake into a 500, so fall back
        // to an empty object and let the validation below answer with 400.
        const { response } = req.body ?? {};

        if (!interactionId || !response) {
            res.status(400).json({
                success: false,
                error: 'Missing interactionId or response'
            });
            return;
        }

        if (response !== 'confirm' && response !== 'cancel') {
            res.status(400).json({
                success: false,
                error: 'Response must be either "confirm" or "cancel"'
            });
            return;
        }

        // Note: In a real implementation, you'd maintain a registry of active pipelines.
        // For now the answer is broadcast via WebSocket to be handled by the active pipeline.
        const wsService = (await import('../../services/ws.js')).default;

        wsService.sendMessageToAllClients({
            type: 'user-interaction-response',
            interactionId,
            response,
            timestamp: Date.now()
        });

        res.status(200).json({
            success: true,
            message: `User response "${response}" recorded for interaction ${interactionId}`
        });
    } catch (error) {
        log.error(`Error handling user interaction response: ${error}`);
        res.status(500).json({
            success: false,
            error: 'Internal server error'
        });
    }
}
/**
* Debug endpoint to check tool recognition and registry status
*/
@@ -748,6 +832,9 @@ export default {
sendMessage,
streamMessage,
// User interaction
respondToInteraction,
// Debug endpoints
debugTools
};

View File

@@ -31,12 +31,24 @@ export interface ToolData {
}
/**
 * Describes one tool-execution event attached to a stream chunk.
 */
export interface ToolExecutionInfo {
    /** Lifecycle phase of the event ('progress' and 'retry' were added with progress reporting). */
    type: 'start' | 'update' | 'complete' | 'error' | 'progress' | 'retry';
    /** Free-form action label; emitters in this change mirror `type` here (e.g. 'progress', 'retry'). */
    action?: string;
    /** The tool the event refers to. */
    tool: {
        name: string;
        arguments: Record<string, unknown>;
    };
    /** Tool output, present on completion events. */
    result?: string | Record<string, unknown>;
    /** Optional progress details for UI display. */
    progress?: {
        /** Position of this step, e.g. 1-based tool index or retry number. */
        current: number;
        /** Total number of steps, e.g. tool-call count or maximum retries. */
        total: number;
        /** Short machine-readable state such as 'executing', 'completed', 'failed' or 'retrying'. */
        status: string;
        /** Human-readable description of the current state. */
        message: string;
        /** Epoch milliseconds at which execution started. */
        startTime?: number;
        /** Elapsed execution time in milliseconds. */
        executionTime?: number;
        /** Truncated preview of the result. */
        resultSummary?: string;
        /** Name of the error class on failure. */
        errorType?: string;
        /** Expected duration in milliseconds. */
        estimatedDuration?: number;
    };
}
/**
@@ -80,6 +92,12 @@ export interface StreamChunk {
* Includes tool name, args, and execution status
*/
toolExecution?: ToolExecutionInfo;
/**
* User interaction data (for confirmation/cancellation requests)
* Contains interaction ID, tool info, and response options
*/
userInteraction?: Record<string, unknown>;
}
/**
@@ -211,6 +229,21 @@ export interface ChatResponse {
/** Tool calls from the LLM (if tools were used and the model supports them) */
tool_calls?: ToolCall[] | null;
/** Recovery metadata for advanced error recovery */
recovery_metadata?: {
total_attempts: number;
successful_recoveries: number;
failed_permanently: number;
};
/** User interaction metadata for confirmation/cancellation features */
interaction_metadata?: {
total_interactions: number;
confirmed: number;
cancelled: number;
timedout: number;
};
}
export interface AIService {

View File

@@ -8,6 +8,7 @@ import { ModelSelectionStage } from './stages/model_selection_stage.js';
import { LLMCompletionStage } from './stages/llm_completion_stage.js';
import { ResponseProcessingStage } from './stages/response_processing_stage.js';
import { ToolCallingStage } from './stages/tool_calling_stage.js';
import { ErrorRecoveryStage } from './stages/error_recovery_stage.js';
// Traditional search is used instead of vector search
import toolRegistry from '../tools/tool_registry.js';
import toolInitializer from '../tools/tool_initializer.js';
@@ -29,6 +30,7 @@ export class ChatPipeline {
llmCompletion: LLMCompletionStage;
responseProcessing: ResponseProcessingStage;
toolCalling: ToolCallingStage;
errorRecovery: ErrorRecoveryStage;
// traditional search is used instead of vector search
};
@@ -50,6 +52,7 @@ export class ChatPipeline {
llmCompletion: new LLMCompletionStage(),
responseProcessing: new ResponseProcessingStage(),
toolCalling: new ToolCallingStage(),
errorRecovery: new ErrorRecoveryStage(),
// traditional search is used instead of vector search
};

View File

@@ -0,0 +1,561 @@
import { BasePipelineStage } from '../pipeline_stage.js';
import type { ToolExecutionInput, StreamCallback } from '../interfaces.js';
import type { ChatResponse, Message } from '../../ai_interface.js';
import toolRegistry from '../../tools/tool_registry.js';
import log from '../../../log.js';
/**
 * Retry/backoff parameters applied to the primary attempts of one tool.
 */
interface RetryStrategy {
// Number of retries allowed after the first attempt (total attempts = maxRetries + 1).
maxRetries: number;
// Delay in milliseconds before the first retry.
baseDelay: number;
// Upper bound in milliseconds applied to the exponential delay (before jitter is added).
maxDelay: number;
// Factor by which the delay grows on each subsequent retry.
backoffMultiplier: number;
// When true, a random variation of up to ±25% is added to each delay.
jitter: boolean;
}
/**
 * Mutable bookkeeping for one tool call while recovery is in progress.
 */
interface ToolRetryContext {
// Name of the tool originally requested.
toolName: string;
// 1-based number of the primary attempt currently being made.
attempt: number;
// Message of the most recent failure ('' until something fails).
lastError: string;
// Ordered list of fallback approaches to try once primary attempts are exhausted.
alternativeApproaches: string[];
// Fallback approaches that have already been attempted.
usedApproaches: string[];
}
/**
* Advanced Error Recovery Pipeline Stage
* Implements sophisticated retry strategies with exponential backoff,
* alternative tool selection, and intelligent fallback mechanisms
*/
export class ErrorRecoveryStage extends BasePipelineStage<ToolExecutionInput, { response: ChatResponse, needsFollowUp: boolean, messages: Message[] }> {
private retryStrategies: Map<string, RetryStrategy> = new Map();
private activeRetries: Map<string, ToolRetryContext> = new Map();
/**
 * Create the stage, passing the name 'ErrorRecovery' to the base class,
 * and fill the per-tool retry strategy table.
 */
constructor() {
super('ErrorRecovery');
this.initializeRetryStrategies();
}
/**
 * Fill `retryStrategies` with the built-in per-tool settings.
 * The 'default' entry is the fallback for tools without an entry of their own.
 */
private initializeRetryStrategies(): void {
    const builtIn: Array<[string, RetryStrategy]> = [
        // Search tools are critical, so they retry more aggressively
        ['search_notes', { maxRetries: 3, baseDelay: 1000, maxDelay: 8000, backoffMultiplier: 2, jitter: true }],
        ['keyword_search', { maxRetries: 3, baseDelay: 800, maxDelay: 6000, backoffMultiplier: 2, jitter: true }],
        // Read operations retry moderately
        ['read_note', { maxRetries: 2, baseDelay: 500, maxDelay: 3000, backoffMultiplier: 2, jitter: false }],
        // Attribute operations retry conservatively
        ['attribute_search', { maxRetries: 2, baseDelay: 1200, maxDelay: 5000, backoffMultiplier: 1.8, jitter: true }],
        // Fallback for tools not listed above
        ['default', { maxRetries: 2, baseDelay: 1000, maxDelay: 4000, backoffMultiplier: 2, jitter: true }]
    ];

    for (const [toolName, strategy] of builtIn) {
        this.retryStrategies.set(toolName, strategy);
    }
}
/**
 * Run every tool call of the response through `executeToolWithRecovery`.
 *
 * Each outcome's tool message is appended to a copy of the input messages, and
 * the returned response carries `recovery_metadata` with attempt/recovery counts.
 * A response without tool calls is passed through unchanged.
 */
protected async process(input: ToolExecutionInput): Promise<{ response: ChatResponse, needsFollowUp: boolean, messages: Message[] }> {
    const { response } = input;
    const toolCalls = response.tool_calls;

    // Nothing to execute: hand the input back untouched
    if (!toolCalls || toolCalls.length === 0) {
        return { response, needsFollowUp: false, messages: input.messages };
    }

    log.info(`========== ERROR RECOVERY STAGE PROCESSING ==========`);
    log.info(`Processing ${toolCalls.length} tool calls with advanced error recovery`);

    const outcomes: Array<{ toolCall: any, message: Message, attempts: number, recovered: boolean }> = [];
    const updatedMessages = [...input.messages];

    // Tool calls are handled strictly one after another
    for (const [position, toolCall] of toolCalls.entries()) {
        const outcome = await this.executeToolWithRecovery(toolCall, input, position);
        if (!outcome) {
            continue;
        }
        outcomes.push(outcome);
        updatedMessages.push(outcome.message);
    }

    // Single pass to collect the statistics reported below
    let totalAttempts = 0;
    let recoveredCount = 0;
    for (const outcome of outcomes) {
        totalAttempts += outcome.attempts;
        if (outcome.recovered) {
            recoveredCount++;
        }
    }

    const enhancedResponse: ChatResponse = {
        ...response,
        tool_calls: outcomes.map(outcome => outcome.toolCall),
        recovery_metadata: {
            total_attempts: totalAttempts,
            successful_recoveries: recoveredCount,
            failed_permanently: outcomes.length - recoveredCount
        }
    };

    log.info(`Recovery complete: ${recoveredCount}/${outcomes.length} tools recovered`);

    return {
        response: enhancedResponse,
        needsFollowUp: outcomes.length > 0,
        messages: updatedMessages
    };
}
/**
 * Execute one tool call, retrying with backoff and then trying alternative approaches.
 *
 * Phases:
 *  1. Up to `maxRetries + 1` primary attempts with the tool's retry strategy.
 *     A tool missing from the registry ends this phase immediately, because
 *     waiting and retrying cannot make it appear.
 *  2. Each alternative approach from `getAlternativeApproaches`, in order.
 *  3. A `RECOVERY_FAILED` tool message containing the last error and guidance.
 *
 * @param toolCall The tool call from the LLM response (`function.name`, `function.arguments`, `id`)
 * @param input Stage input; only `streamCallback` is read, for retry notifications
 * @param index Position of the tool call in the response (currently unused)
 * @returns The tool call, the tool message to append, the number of attempts made
 *          (primary attempts plus alternatives tried) and whether any attempt succeeded.
 *          The current implementation never returns null.
 */
private async executeToolWithRecovery(
    toolCall: any,
    input: ToolExecutionInput,
    index: number
): Promise<{ toolCall: any, message: Message, attempts: number, recovered: boolean } | null> {
    const toolName = toolCall.function.name;
    const strategy = this.retryStrategies.get(toolName) || this.retryStrategies.get('default')!;
    const maxAttempts = strategy.maxRetries + 1;

    let lastError = '';
    let attempts = 0;
    // Set when the registry has no such tool; further primary attempts are pointless then
    let toolMissing = false;

    const retryContext: ToolRetryContext = {
        toolName,
        attempt: 0,
        lastError: '',
        alternativeApproaches: this.getAlternativeApproaches(toolName),
        usedApproaches: []
    };

    log.info(`Starting error recovery for tool: ${toolName} (max retries: ${strategy.maxRetries})`);

    // Primary execution attempts
    for (let attempt = 1; attempt <= maxAttempts; attempt++) {
        attempts = attempt;
        retryContext.attempt = attempt;

        try {
            // Back off before every attempt except the first
            if (attempt > 1) {
                const delay = this.calculateDelay(strategy, attempt - 1);
                log.info(`Retry attempt ${attempt - 1} for ${toolName} after ${delay}ms delay`);
                await this.sleep(delay);

                if (input.streamCallback) {
                    this.sendRetryNotification(input.streamCallback, toolName, attempt - 1, strategy.maxRetries);
                }
            }

            const tool = toolRegistry.getTool(toolName);
            if (!tool) {
                toolMissing = true;
                throw new Error(`Tool not found: ${toolName}`);
            }

            const args = this.parseToolArguments(toolCall.function.arguments);
            // Retries may loosen the arguments (e.g. broaden a search query)
            const modifiedArgs = this.modifyArgsForRetry(args, retryContext);

            log.info(`Executing ${toolName} (attempt ${attempt}) with args: ${JSON.stringify(modifiedArgs)}`);
            const result = await tool.execute(modifiedArgs);

            log.info(`✓ Tool ${toolName} succeeded on attempt ${attempt}`);
            return {
                toolCall,
                message: {
                    role: 'tool',
                    content: typeof result === 'string' ? result : JSON.stringify(result, null, 2),
                    name: toolName,
                    tool_call_id: toolCall.id
                },
                attempts,
                recovered: true
            };
        } catch (error) {
            lastError = error instanceof Error ? error.message : String(error);
            retryContext.lastError = lastError;
            log.info(`✗ Tool ${toolName} failed on attempt ${attempt}: ${lastError}`);

            // Stop when retries are exhausted or when retrying cannot help
            if (toolMissing || attempt >= maxAttempts) {
                break;
            }
        }
    }

    // Primary attempts failed, try alternative approaches
    log.info(`Primary attempts failed for ${toolName}, trying alternative approaches`);

    for (const alternative of retryContext.alternativeApproaches) {
        if (retryContext.usedApproaches.includes(alternative)) {
            continue; // Skip already used approaches
        }

        try {
            log.info(`Trying alternative approach: ${alternative} for ${toolName}`);
            retryContext.usedApproaches.push(alternative);

            const alternativeResult = await this.executeAlternativeApproach(alternative, toolCall, retryContext);
            if (alternativeResult) {
                log.info(`✓ Alternative approach ${alternative} succeeded for ${toolName}`);
                return {
                    toolCall,
                    message: {
                        role: 'tool',
                        content: `ALTERNATIVE_SUCCESS: ${alternative} succeeded where ${toolName} failed. Result: ${alternativeResult}`,
                        name: toolName,
                        tool_call_id: toolCall.id
                    },
                    // Count every alternative that was actually run, not just the last one
                    attempts: attempts + retryContext.usedApproaches.length,
                    recovered: true
                };
            }
        } catch (error) {
            const altError = error instanceof Error ? error.message : String(error);
            log.info(`✗ Alternative approach ${alternative} failed: ${altError}`);
        }
    }

    // All attempts failed
    log.error(`All recovery attempts failed for ${toolName} after ${attempts} attempts and ${retryContext.usedApproaches.length} alternatives`);

    // The failure is reported as a tool message so the model can react to it
    const failureGuidance = this.generateFailureGuidance(toolName, lastError, retryContext);
    return {
        toolCall,
        message: {
            role: 'tool',
            content: `RECOVERY_FAILED: Tool ${toolName} failed after ${attempts} attempts and ${retryContext.usedApproaches.length} alternative approaches. Last error: ${lastError}\n\n${failureGuidance}`,
            name: toolName,
            tool_call_id: toolCall.id
        },
        attempts,
        recovered: false
    };
}
/**
 * Compute the wait before a retry using exponential backoff with optional jitter.
 *
 * The delay is `baseDelay * backoffMultiplier^(retryNumber - 1)`, capped at
 * `maxDelay`. With jitter enabled a random variation of up to ±25% is added,
 * after which the cap is applied again so `maxDelay` stays a real upper bound.
 *
 * @param strategy Retry parameters of the tool
 * @param retryNumber 1-based retry counter (1 = first retry)
 * @returns Delay in whole milliseconds, never below 100
 */
private calculateDelay(strategy: RetryStrategy, retryNumber: number): number {
    // Guard against a retryNumber below 1 shrinking the delay under baseDelay
    const exponent = Math.max(retryNumber - 1, 0);
    let delay = Math.min(
        strategy.baseDelay * Math.pow(strategy.backoffMultiplier, exponent),
        strategy.maxDelay
    );

    if (strategy.jitter) {
        const jitterRange = delay * 0.25;
        delay += (Math.random() - 0.5) * 2 * jitterRange;
        // Upward jitter must not push the delay past the configured maximum
        delay = Math.min(delay, strategy.maxDelay);
    }

    return Math.round(Math.max(delay, 100)); // Minimum 100ms delay
}
/**
 * Return a promise that resolves after the given number of milliseconds.
 */
private sleep(ms: number): Promise<void> {
    return new Promise<void>((resolve) => {
        setTimeout(resolve, ms);
    });
}
/**
 * List the fallback approaches to try when a tool keeps failing.
 *
 * Entries are either names of other tools or named strategies handled by
 * `executeAlternativeApproach` (e.g. 'broader_search_terms').
 *
 * @param toolName Name of the failing tool (comes from the model, so it is
 *                 looked up in a Map rather than on a plain object, where
 *                 names like 'constructor' would hit inherited properties)
 * @returns Ordered list of approaches; a generic search fallback for unknown tools
 */
private getAlternativeApproaches(toolName: string): string[] {
    const alternatives = new Map<string, string[]>([
        ['search_notes', ['keyword_search', 'broader_search_terms', 'attribute_search']],
        ['keyword_search', ['search_notes', 'simplified_query', 'attribute_search']],
        ['attribute_search', ['search_notes', 'keyword_search', 'different_attribute_type']],
        ['read_note', ['note_by_path', 'search_and_read', 'template_search']],
        ['note_by_path', ['read_note', 'search_notes', 'keyword_search']]
    ]);

    return alternatives.get(toolName) ?? ['search_notes', 'keyword_search'];
}
/**
 * Produce a loosened copy of the tool arguments for retry attempts.
 *
 * The first attempt always gets an unmodified copy. On later attempts:
 *  - tools whose name contains 'search' get quotes and the limiting words
 *    "exactly", "specific", "precise" stripped from a string `query`;
 *  - 'attribute_search' gets `attributeType` flipped between 'label' and 'relation'.
 *
 * @param args Parsed tool arguments (never mutated)
 * @param context Retry bookkeeping; `toolName` and `attempt` are read
 * @returns A shallow copy of `args`, possibly with `query`/`attributeType` replaced
 */
private modifyArgsForRetry(args: Record<string, unknown>, context: ToolRetryContext): Record<string, unknown> {
    const modified = { ...args };

    if (context.attempt <= 1) {
        return modified;
    }

    // For search tools, broaden the query on retries
    const query = modified.query;
    if (context.toolName.includes('search') && typeof query === 'string' && query) {
        const broadened = query
            .replace(/['"]/g, '') // Remove quotes
            .replace(/\b(exactly|specific|precise)\b/gi, '') // Remove limiting words
            .replace(/\s+/g, ' ') // Collapse the gaps left by removed words
            .trim();

        // Keep the original query if stripping would leave nothing to search for
        if (broadened && broadened !== query) {
            modified.query = broadened;
            log.info(`Modified query for retry: "${broadened}"`);
        }
    }

    // For attribute search, try the other attribute type
    if (context.toolName === 'attribute_search') {
        const attributeType = modified.attributeType;
        if (attributeType === 'label' || attributeType === 'relation') {
            modified.attributeType = attributeType === 'label' ? 'relation' : 'label';
            // Only logged when the value really changed
            log.info(`Modified attributeType for retry: ${modified.attributeType}`);
        }
    }

    return modified;
}
/**
 * Dispatch a named alternative approach to its implementation.
 *
 * Four names map to dedicated strategies; any other name is treated as the
 * name of a tool and executed directly with the original arguments.
 *
 * @param approach Name of the approach or alternative tool
 * @param originalToolCall The tool call that failed
 * @param context Retry bookkeeping (not read here)
 * @returns The approach's result as text, or null if it produced nothing
 */
private async executeAlternativeApproach(
    approach: string,
    originalToolCall: any,
    context: ToolRetryContext
): Promise<string | null> {
    if (approach === 'broader_search_terms') {
        return await this.executeBroaderSearch(originalToolCall);
    }
    if (approach === 'simplified_query') {
        return await this.executeSimplifiedSearch(originalToolCall);
    }
    if (approach === 'different_attribute_type') {
        return await this.executeDifferentAttributeSearch(originalToolCall);
    }
    if (approach === 'search_and_read') {
        return await this.executeSearchAndRead(originalToolCall);
    }

    // Anything else is assumed to be the name of another tool
    return await this.executeAlternativeTool(approach, originalToolCall);
}
/**
 * Retry the request as a 'search_notes' call using only the main keywords.
 *
 * Keywords are the first three whitespace-separated words of the original
 * `query` that are longer than three characters.
 *
 * @param toolCall The tool call that failed
 * @returns The search result as text, or null when there is no string query,
 *          no usable keyword, or the 'search_notes' tool is unavailable
 */
private async executeBroaderSearch(toolCall: any): Promise<string | null> {
    const args = this.parseToolArguments(toolCall.function.arguments);
    const query = args.query;
    if (typeof query !== 'string' || !query) {
        return null;
    }

    const keywords = query
        .split(/\s+/) // any whitespace, not just single spaces
        .filter(word => word.length > 3)
        .slice(0, 3) // Take only first 3 main keywords
        .join(' ');

    // An empty query would not be a broader search, just a meaningless one
    if (!keywords) {
        return null;
    }

    const tool = toolRegistry.getTool('search_notes');
    if (!tool) {
        return null;
    }

    const result = await tool.execute({ ...args, query: keywords });
    return typeof result === 'string' ? result : JSON.stringify(result);
}
/**
 * Retry the request as a 'keyword_search' call using only the first word of the query.
 *
 * @param toolCall The tool call that failed
 * @returns The search result as text, or null when there is no string query,
 *          the query contains no word, or 'keyword_search' is unavailable
 */
private async executeSimplifiedSearch(toolCall: any): Promise<string | null> {
    const args = this.parseToolArguments(toolCall.function.arguments);
    const query = args.query;
    if (typeof query !== 'string' || !query) {
        return null;
    }

    // Trim first so leading whitespace does not yield an empty "first word"
    const [firstWord] = query.trim().split(/\s+/);
    if (!firstWord) {
        return null;
    }

    const tool = toolRegistry.getTool('keyword_search');
    if (!tool) {
        return null;
    }

    const result = await tool.execute({ ...args, query: firstWord });
    return typeof result === 'string' ? result : JSON.stringify(result);
}
/**
 * Re-run 'attribute_search' with the attribute type switched:
 * 'label' becomes 'relation', anything else becomes 'label'.
 *
 * @param toolCall The tool call that failed
 * @returns The search result as text, or null when no `attributeType` was
 *          given or the 'attribute_search' tool is unavailable
 */
private async executeDifferentAttributeSearch(toolCall: any): Promise<string | null> {
    const originalArgs = this.parseToolArguments(toolCall.function.arguments);
    if (!originalArgs.attributeType) {
        return null;
    }

    const swappedType = originalArgs.attributeType === 'label' ? 'relation' : 'label';
    const attributeTool = toolRegistry.getTool('attribute_search');
    if (!attributeTool) {
        return null;
    }

    const output = await attributeTool.execute({ ...originalArgs, attributeType: swappedType });
    if (typeof output === 'string') {
        return output;
    }
    return JSON.stringify(output);
}
/**
 * Search for notes with the original query and read the first note found.
 *
 * The note ID is taken from the first match of `note` followed by a colon or
 * whitespace and an alphanumeric token in the search output.
 * NOTE(review): this pattern is a heuristic; confirm it matches the actual
 * output format of 'search_notes'.
 *
 * @param toolCall The tool call that failed
 * @returns `SEARCH_AND_READ: ...` with the note content, `SEARCH_ONLY: ...`
 *          when no note could be read, or null when the search is not
 *          possible or any step throws
 */
private async executeSearchAndRead(toolCall: any): Promise<string | null> {
    const args = this.parseToolArguments(toolCall.function.arguments);

    const searchTool = toolRegistry.getTool('search_notes');
    if (!searchTool || !args.query) {
        return null;
    }

    // Tool results may be strings or objects; objects must be serialized
    // explicitly or they would be interpolated as "[object Object]"
    const asText = (value: unknown): string =>
        typeof value === 'string' ? value : JSON.stringify(value);

    try {
        const searchText = asText(await searchTool.execute({ query: args.query }));

        const noteId = searchText.match(/note[:\s]+([a-zA-Z0-9]+)/i)?.[1];
        if (noteId) {
            const readTool = toolRegistry.getTool('read_note');
            if (readTool) {
                const readResult = await readTool.execute({ noteId });
                return `SEARCH_AND_READ: Found and read note ${noteId}. Content: ${asText(readResult)}`;
            }
        }

        return `SEARCH_ONLY: ${searchText}`;
    } catch (error) {
        // Still a best-effort fallback, but leave a trace of why it failed
        const reason = error instanceof Error ? error.message : String(error);
        log.info(`Search-and-read fallback failed: ${reason}`);
        return null;
    }
}
/**
 * Execute another tool with the arguments of the original tool call.
 *
 * NOTE(review): the arguments are passed through unchanged, so this only
 * helps when the alternative tool accepts the same argument shape.
 *
 * @param toolName Name of the alternative tool
 * @param originalToolCall The tool call that failed
 * @returns The tool's result as text, or null when the tool does not exist or throws
 */
private async executeAlternativeTool(toolName: string, originalToolCall: any): Promise<string | null> {
    const tool = toolRegistry.getTool(toolName);
    if (!tool) {
        return null;
    }

    const args = this.parseToolArguments(originalToolCall.function.arguments);

    try {
        const result = await tool.execute(args);
        return typeof result === 'string' ? result : JSON.stringify(result);
    } catch (error) {
        // Best-effort fallback: report the failure instead of hiding it completely
        const reason = error instanceof Error ? error.message : String(error);
        log.info(`Alternative tool ${toolName} failed: ${reason}`);
        return null;
    }
}
/**
 * Normalize tool-call arguments into a plain record.
 *
 * - Objects are returned as they are; a missing value becomes `{}`.
 * - Strings are parsed as JSON. A JSON object is returned as the record.
 *   A JSON string becomes `{ query: <that string> }`.
 * - Anything else (invalid JSON, arrays, numbers, booleans, null) becomes
 *   `{ query: <raw input string> }`.
 *
 * This guarantees callers can safely read properties such as `args.query`.
 *
 * @param args Raw arguments from the tool call
 * @returns A record that is never null and never an array
 */
private parseToolArguments(args: string | Record<string, unknown>): Record<string, unknown> {
    if (typeof args !== 'string') {
        return args ?? {};
    }

    try {
        const parsed: unknown = JSON.parse(args);
        if (parsed !== null && typeof parsed === 'object' && !Array.isArray(parsed)) {
            return parsed as Record<string, unknown>;
        }
        if (typeof parsed === 'string') {
            return { query: parsed };
        }
    } catch {
        // Not valid JSON: treat the raw text as a query below
    }

    return { query: args };
}
/**
 * Tell the streaming client that a tool is being retried.
 *
 * Notifications are best-effort: a callback that throws or rejects is logged
 * and otherwise ignored, so it can neither surface as an unhandled rejection
 * nor be mistaken for a failure of the tool itself.
 *
 * @param streamCallback Callback that delivers stream chunks to the client
 * @param toolName Name of the tool being retried
 * @param retryNumber 1-based number of the retry
 * @param maxRetries Maximum number of retries for this tool
 */
private sendRetryNotification(
    streamCallback: StreamCallback,
    toolName: string,
    retryNumber: number,
    maxRetries: number
): void {
    try {
        const pending = streamCallback('', false, {
            text: '',
            done: false,
            toolExecution: {
                type: 'retry',
                action: 'retry',
                tool: { name: toolName, arguments: {} },
                progress: {
                    current: retryNumber,
                    total: maxRetries,
                    status: 'retrying',
                    message: `Retrying ${toolName} (attempt ${retryNumber}/${maxRetries})...`
                }
            }
        });

        // Don't wait for this to complete, but log any errors
        if (pending instanceof Promise) {
            pending.catch((e: Error) => log.error(`Error sending tool retry event: ${e.message}`));
        }
    } catch (e) {
        const reason = e instanceof Error ? e.message : String(e);
        log.error(`Error sending tool retry event: ${reason}`);
    }
}
/**
 * Build the multi-line guidance text appended to a RECOVERY_FAILED message.
 *
 * @param toolName Name of the tool that could not be recovered
 * @param lastError Message of the last error seen
 * @param context Retry bookkeeping; `attempt` and `usedApproaches` are reported
 * @returns Analysis lines, a blank line and suggested next steps joined by newlines
 */
private generateFailureGuidance(toolName: string, lastError: string, context: ToolRetryContext): string {
    const triedAlternatives = context.usedApproaches.join(', ') || 'none';

    const analysis = [
        `RECOVERY ANALYSIS for ${toolName}:`,
        `- Primary attempts: ${context.attempt}`,
        `- Alternative approaches tried: ${triedAlternatives}`,
        `- Last error: ${lastError}`
    ];

    const nextSteps = [
        'SUGGESTED NEXT STEPS:',
        '- Try manual search with broader terms',
        '- Check if the requested information exists',
        '- Use discover_tools to find alternative tools',
        '- Reformulate the query with different keywords'
    ];

    return [...analysis, '', ...nextSteps].join('\n');
}
}

View File

@@ -32,6 +32,11 @@ export class MessagePreparationStage extends BasePipelineStage<MessagePreparatio
const toolsEnabled = options?.enableTools === true;
log.info(`Preparing messages for provider: ${provider}, context: ${!!context}, system prompt: ${!!systemPrompt}, tools: ${toolsEnabled}`);
log.info(`Input message count: ${messages.length}`);
// Apply intelligent context management for long conversations
const managedMessages = await this.applyContextManagement(messages, provider, options);
log.info(`After context management: ${managedMessages.length} messages (reduced by ${messages.length - managedMessages.length})`);
// Get appropriate formatter for this provider
const formatter = MessageFormatterFactory.getFormatter(provider);
@@ -69,13 +74,171 @@ Remember: Tool usage should be continuous and iterative until you have thoroughl
// Format messages using provider-specific approach
const formattedMessages = formatter.formatMessages(
messages,
managedMessages,
finalSystemPrompt,
context
);
log.info(`Formatted ${messages.length} messages into ${formattedMessages.length} messages for provider: ${provider}`);
log.info(`Formatted ${managedMessages.length} messages into ${formattedMessages.length} messages for provider: ${provider}`);
return { messages: formattedMessages };
}
/**
 * Reduce a long conversation to the provider's message limit.
 *
 * Conversations within the limit are returned unchanged. Longer ones go
 * through the sliding-window filter, and through summarization if the
 * filtered list is still above the limit.
 *
 * @param messages Conversation messages in chronological order
 * @param provider Provider name used to look up the message limit
 * @param options Currently not read by this method
 * @returns The messages to pass on to formatting
 */
private async applyContextManagement(messages: Message[], provider: string, options?: any): Promise<Message[]> {
    const limit = this.getMaxMessagesForProvider(provider);

    if (messages.length > limit) {
        log.info(`Message count (${messages.length}) exceeds limit (${limit}), applying context management`);

        // Strategy 1: keep the system message plus the most recent messages
        const windowed = await this.applySlidingWindowWithImportanceFiltering(messages, limit);

        // Strategy 2: summarize older messages if the window is still too large
        return windowed.length > limit
            ? await this.applySummarizationToOlderMessages(windowed, limit)
            : windowed;
    }

    log.info(`Message count (${messages.length}) within limit (${limit}), no context management needed`);
    return messages;
}
/**
 * Look up the maximum number of messages to send to a provider.
 *
 * A Map is used so that provider names matching inherited object properties
 * (e.g. 'constructor') cannot return a non-numeric value.
 *
 * @param provider Provider name, e.g. 'anthropic', 'openai' or 'ollama'
 * @returns The provider's limit, or 35 for unknown providers
 */
private getMaxMessagesForProvider(provider: string): number {
    const defaultLimit = 35; // Safe default
    const limits = new Map<string, number>([
        ['anthropic', 50], // Conservative for Claude's context window management
        ['openai', 40], // Conservative for GPT models
        ['ollama', 30], // More conservative for local models
        ['default', defaultLimit]
    ]);

    return limits.get(provider) ?? defaultLimit;
}
/**
 * Keep the first system message plus the most recent messages.
 *
 * Tool-result messages at the very start of the recent window are dropped:
 * the assistant message that requested them lies outside the window, and a
 * tool result without its originating tool call is rejected by chat APIs
 * that validate tool-call pairing.
 *
 * @param messages Conversation messages in chronological order
 * @param maxMessages Maximum number of messages to return
 * @returns At most `maxMessages` messages; the input itself if it is already short enough
 */
private async applySlidingWindowWithImportanceFiltering(messages: Message[], maxMessages: number): Promise<Message[]> {
    if (messages.length <= maxMessages) {
        return messages;
    }

    // Always preserve the first system message if it exists
    const firstSystemMessage = messages.find(msg => msg.role === 'system');
    const preservedCount = firstSystemMessage ? 1 : 0;

    // slice(-0) would return the whole array, so an empty window is handled explicitly
    const recentMessageCount = Math.min(maxMessages - preservedCount, messages.length);
    const recentMessages = recentMessageCount > 0 ? messages.slice(-recentMessageCount) : [];

    // Leading tool results have lost their assistant tool-call message
    const firstUsable = recentMessages.findIndex(msg => msg.role !== 'tool');
    const window = firstUsable === -1 ? [] : recentMessages.slice(firstUsable);
    const orphanedToolResults = recentMessages.length - window.length;

    // System message first, then the window, skipping objects already added
    const seen = new Set<Message>();
    const result: Message[] = [];
    const candidates = firstSystemMessage ? [firstSystemMessage, ...window] : window;
    for (const msg of candidates) {
        if (!seen.has(msg)) {
            seen.add(msg);
            result.push(msg);
        }
    }

    log.info(`Sliding window filtering: preserved ${preservedCount} system messages, kept ${window.length} recent messages, dropped ${orphanedToolResults} orphaned tool results`);

    return result.slice(0, maxMessages); // Ensure we don't exceed the limit
}
/**
 * Replace older messages with a single summary message.
 *
 * The most recent 60% of `maxMessages` are kept verbatim. Everything before
 * them is condensed by `createConversationSummary` into one system message.
 * The first system message among the older ones is kept in front of the
 * summary when the limit leaves room, matching the sliding-window strategy.
 *
 * @param messages Conversation messages in chronological order
 * @param maxMessages Maximum number of messages to return
 * @returns The (optional) system message, the summary message and the recent messages
 */
private async applySummarizationToOlderMessages(messages: Message[], maxMessages: number): Promise<Message[]> {
    if (messages.length <= maxMessages) {
        return messages;
    }

    // Keep recent messages (last 60% of limit)
    const recentCount = Math.floor(maxMessages * 0.6);

    // Split by index: slice(-recentCount) would return everything when recentCount is 0
    const splitIndex = Math.max(messages.length - recentCount, 0);
    const olderMessages = messages.slice(0, splitIndex);
    const recentMessages = messages.slice(splitIndex);

    const summary = this.createConversationSummary(olderMessages);
    const summaryMessage: Message = {
        role: 'system',
        content: `CONVERSATION SUMMARY: Previous conversation included ${olderMessages.length} messages. Key points: ${summary}`
    };

    // The summary ignores system messages, so keep the original one when it fits
    const leadingSystemMessage = olderMessages.find(msg => msg.role === 'system');
    const keepSystemMessage = leadingSystemMessage !== undefined && recentMessages.length + 2 <= maxMessages;

    log.info(`Applied summarization: summarized ${olderMessages.length} older messages, kept ${recentMessages.length} recent messages`);

    return keepSystemMessage
        ? [leadingSystemMessage, summaryMessage, ...recentMessages]
        : [summaryMessage, ...recentMessages];
}
/**
 * Build a short textual summary of a list of messages.
 *
 * The summary mentions up to three user messages (first 100 characters each)
 * and up to five distinct tool names called by the assistant. Messages with
 * other roles, and assistant text, are not included.
 *
 * @param messages Messages to summarize
 * @returns The summary sentence(s), or a generic fallback when nothing was collected
 */
private createConversationSummary(messages: Message[]): string {
    const userQueries: string[] = [];
    const toolUsage: string[] = [];

    for (const msg of messages) {
        if (msg.role === 'user') {
            // Only the beginning of each user message is used as its topic
            const content = msg.content?.substring(0, 100) || '';
            if (content.trim()) {
                userQueries.push(content.trim());
            }
        } else if (msg.role === 'assistant' && msg.tool_calls && msg.tool_calls.length > 0) {
            for (const tool of msg.tool_calls) {
                if (tool.function?.name) {
                    toolUsage.push(tool.function.name);
                }
            }
        }
    }

    const summary: string[] = [];

    if (userQueries.length > 0) {
        summary.push(`User asked about: ${userQueries.slice(0, 3).join(', ')}`);
    }

    if (toolUsage.length > 0) {
        const uniqueTools = [...new Set(toolUsage)];
        summary.push(`Tools used: ${uniqueTools.slice(0, 5).join(', ')}`);
    }

    return summary.join('. ') || 'General conversation about notes and information retrieval';
}
}

View File

@@ -256,7 +256,14 @@ export class ToolCallingStage extends BasePipelineStage<ToolExecutionInput, { re
name: toolCall.function.name,
arguments: args
},
type: 'start' as const
type: 'start' as const,
progress: {
current: index + 1,
total: response.tool_calls?.length || 1,
status: 'initializing',
message: `Starting ${toolCall.function.name} execution...`,
estimatedDuration: this.getEstimatedDuration(toolCall.function.name)
}
};
// Don't wait for this to complete, but log any errors
@@ -274,6 +281,35 @@ export class ToolCallingStage extends BasePipelineStage<ToolExecutionInput, { re
let result;
try {
log.info(`Starting tool execution for ${toolCall.function.name}...`);
// Send progress update during execution
if (streamCallback) {
const progressData = {
action: 'progress',
tool: {
name: toolCall.function.name,
arguments: args
},
type: 'progress' as const,
progress: {
current: index + 1,
total: response.tool_calls?.length || 1,
status: 'executing',
message: `Executing ${toolCall.function.name}...`,
startTime: executionStart
}
};
const progressResult = streamCallback('', false, {
text: '',
done: false,
toolExecution: progressData
});
if (progressResult instanceof Promise) {
progressResult.catch((e: Error) => log.error(`Error sending tool execution progress event: ${e.message}`));
}
}
result = await tool.execute(args);
const executionTime = Date.now() - executionStart;
log.info(`================ TOOL EXECUTION COMPLETED in ${executionTime}ms ================`);
@@ -296,6 +332,10 @@ export class ToolCallingStage extends BasePipelineStage<ToolExecutionInput, { re
// Emit tool completion event if streaming is enabled
if (streamCallback) {
const resultSummary = typeof result === 'string'
? result.substring(0, 200) + (result.length > 200 ? '...' : '')
: `Object with ${Object.keys(result).length} properties`;
const toolExecutionData = {
action: 'complete',
tool: {
@@ -303,7 +343,15 @@ export class ToolCallingStage extends BasePipelineStage<ToolExecutionInput, { re
arguments: {} as Record<string, unknown>
},
result: typeof result === 'string' ? result : result as Record<string, unknown>,
type: 'complete' as const
type: 'complete' as const,
progress: {
current: index + 1,
total: response.tool_calls?.length || 1,
status: 'completed',
message: `${toolCall.function.name} completed successfully`,
executionTime: executionTime,
resultSummary: resultSummary
}
};
// Don't wait for this to complete, but log any errors
@@ -352,7 +400,15 @@ export class ToolCallingStage extends BasePipelineStage<ToolExecutionInput, { re
arguments: {} as Record<string, unknown>
},
error: enhancedErrorMessage, // Include guidance in the error message
type: 'error' as const
type: 'error' as const,
progress: {
current: index + 1,
total: response.tool_calls?.length || 1,
status: 'failed',
message: `${toolCall.function.name} failed: ${errorMessage.substring(0, 100)}...`,
executionTime: executionTime,
errorType: execError instanceof Error ? execError.constructor.name : 'UnknownError'
}
};
// Don't wait for this to complete, but log any errors
@@ -631,6 +687,26 @@ Continue your systematic investigation now.`;
return guidance;
}
/**
 * Get estimated duration for a tool execution (in milliseconds).
 *
 * The tool name comes from the model's tool call, so the table is a Map:
 * a plain-object lookup would return inherited properties for names such
 * as 'constructor' instead of a number.
 *
 * @param toolName The name of the tool
 * @returns Estimated duration in milliseconds; 1500 for tools without an entry
 */
private getEstimatedDuration(toolName: string): number {
    // Tool-specific duration estimates based on typical execution times
    const estimations = new Map<string, number>([
        ['search_notes', 2000],
        ['read_note', 1000],
        ['keyword_search', 1500],
        ['attribute_search', 1200],
        ['discover_tools', 500],
        ['note_by_path', 800],
        ['template_search', 1000]
    ]);

    return estimations.get(toolName) ?? 1500; // Default 1.5 seconds
}
/**
* Determines if a tool result is effectively empty or unhelpful
* @param result The result from the tool execution

View File

@@ -0,0 +1,486 @@
import { BasePipelineStage } from '../pipeline_stage.js';
import type { ToolExecutionInput, StreamCallback } from '../interfaces.js';
import type { ChatResponse, Message } from '../../ai_interface.js';
import log from '../../../log.js';
/**
 * Settings of the user interaction stage.
 * Defaults are supplied by the UserInteractionStage constructor.
 */
interface UserInteractionConfig {
// When false, the stage passes its input through without any interaction.
enableConfirmation: boolean;
// NOTE(review): presumably allows the user to cancel a tool call; usage is not visible here.
enableCancellation: boolean;
confirmationTimeout: number; // milliseconds
// NOTE(review): presumably skips confirmation for low-risk tools; confirm against the rest of the class.
autoConfirmLowRisk: boolean;
// Tool names for which confirmation is required (default: 'attribute_search', 'read_note').
requiredConfirmationTools: string[];
}
/**
 * Record of one interaction stored in the stage's `pendingInteractions` map.
 * NOTE(review): the code that creates and resolves these records is not
 * visible in this part of the file; field notes below are assumptions to confirm.
 */
interface PendingInteraction {
// Identifier of the interaction (presumably the map key as well).
id: string;
// The tool call the interaction refers to.
toolCall: any;
// Presumably the creation time in epoch milliseconds.
timestamp: number;
// Timer handle, presumably for the confirmation timeout.
timeoutHandle?: NodeJS.Timeout;
// Presumably set once a response or timeout has been recorded.
resolved: boolean;
}
type InteractionResponse = 'confirm' | 'cancel' | 'timeout';
/**
* Enhanced User Interaction Pipeline Stage
* Provides real-time confirmation/cancellation capabilities for tool execution
*/
export class UserInteractionStage extends BasePipelineStage<ToolExecutionInput, { response: ChatResponse, needsFollowUp: boolean, messages: Message[], userInteractions?: any[] }> {
private config: UserInteractionConfig;
private pendingInteractions: Map<string, PendingInteraction> = new Map();
private interactionCallbacks: Map<string, (response: InteractionResponse) => void> = new Map();
/**
 * Create the stage under the name 'UserInteraction'.
 *
 * @param config Optional overrides; every property given here replaces the
 *               corresponding default
 */
constructor(config?: Partial<UserInteractionConfig>) {
    super('UserInteraction');

    const defaults: UserInteractionConfig = {
        enableConfirmation: true,
        enableCancellation: true,
        confirmationTimeout: 15000, // 15 seconds
        autoConfirmLowRisk: true,
        requiredConfirmationTools: ['attribute_search', 'read_note']
    };

    this.config = { ...defaults, ...config };
}
/**
 * Run every tool call of the response through `processToolCallWithInteraction`.
 *
 * The input is passed through unchanged (with an empty interaction list) when
 * there are no tool calls or confirmation is disabled. Otherwise each result's
 * tool call and message are collected, and the returned response carries
 * `interaction_metadata` with counts per response kind.
 */
protected async process(input: ToolExecutionInput): Promise<{ response: ChatResponse, needsFollowUp: boolean, messages: Message[], userInteractions?: any[] }> {
    const { response } = input;
    const toolCalls = response.tool_calls;

    // If no tool calls or interactions disabled, pass through
    if (!toolCalls || toolCalls.length === 0 || !this.config.enableConfirmation) {
        return {
            response,
            needsFollowUp: false,
            messages: input.messages,
            userInteractions: []
        };
    }

    log.info(`========== USER INTERACTION STAGE PROCESSING ==========`);
    log.info(`Processing ${toolCalls.length} tool calls with user interaction controls`);

    const processedToolCalls: any[] = [];
    const userInteractions: any[] = [];
    const updatedMessages = [...input.messages];

    // Tool calls are handled strictly one after another
    for (const [position, toolCall] of toolCalls.entries()) {
        const outcome = await this.processToolCallWithInteraction(toolCall, input, position);
        if (!outcome) {
            continue;
        }

        processedToolCalls.push(outcome.toolCall);
        updatedMessages.push(outcome.message);
        if (outcome.interaction) {
            userInteractions.push(outcome.interaction);
        }
    }

    // Count the interactions per response kind in one pass
    let confirmed = 0;
    let cancelled = 0;
    let timedout = 0;
    for (const interaction of userInteractions) {
        switch (interaction.response) {
            case 'confirm':
                confirmed++;
                break;
            case 'cancel':
                cancelled++;
                break;
            case 'timeout':
                timedout++;
                break;
        }
    }

    const enhancedResponse: ChatResponse = {
        ...response,
        tool_calls: processedToolCalls,
        interaction_metadata: {
            total_interactions: userInteractions.length,
            confirmed,
            cancelled,
            timedout
        }
    };

    log.info(`User interaction complete: ${userInteractions.length} interactions processed`);

    return {
        response: enhancedResponse,
        needsFollowUp: processedToolCalls.length > 0,
        messages: updatedMessages,
        userInteractions
    };
}
/**
* Process a tool call with user interaction controls
*/
private async processToolCallWithInteraction(
toolCall: any,
input: ToolExecutionInput,
index: number
): Promise<{ toolCall: any, message: Message, interaction?: any } | null> {
const toolName = toolCall.function.name;
const riskLevel = this.assessToolRiskLevel(toolName);
// Determine if confirmation is required
const requiresConfirmation = this.shouldRequireConfirmation(toolName, riskLevel);
if (!requiresConfirmation) {
// Execute immediately for low-risk tools
log.info(`Tool ${toolName} is low-risk, executing immediately`);
return await this.executeToolDirectly(toolCall, input);
}
// Request user confirmation
log.info(`Tool ${toolName} requires user confirmation (risk level: ${riskLevel})`);
const interactionId = this.generateInteractionId();
const interaction = await this.requestUserConfirmation(toolCall, interactionId, input.streamCallback);
if (interaction.response === 'confirm') {
log.info(`User confirmed execution of ${toolName}`);
const result = await this.executeToolDirectly(toolCall, input);
return {
...result!,
interaction
};
} else if (interaction.response === 'cancel') {
log.info(`User cancelled execution of ${toolName}`);
return {
toolCall,
message: {
role: 'tool',
content: `USER_CANCELLED: Execution of ${toolName} was cancelled by user request.`,
name: toolName,
tool_call_id: toolCall.id
},
interaction
};
} else {
// Timeout
log.info(`User confirmation timeout for ${toolName}, executing with default action`);
const result = await this.executeToolDirectly(toolCall, input);
return {
...result!,
interaction: { ...interaction, response: 'timeout_executed' }
};
}
}
/**
* Assess the risk level of a tool
*/
private assessToolRiskLevel(toolName: string): 'low' | 'medium' | 'high' {
const riskLevels = {
// Low risk - read-only operations
'search_notes': 'low',
'keyword_search': 'low',
'discover_tools': 'low',
'template_search': 'low',
// Medium risk - specific data access
'read_note': 'medium',
'note_by_path': 'medium',
// High risk - complex queries or potential data modification
'attribute_search': 'high'
};
return (riskLevels as any)[toolName] || 'medium';
}
/**
* Determine if a tool requires user confirmation
*/
private shouldRequireConfirmation(toolName: string, riskLevel: string): boolean {
// Always require confirmation for high-risk tools
if (riskLevel === 'high') {
return true;
}
// Check if tool is in the required confirmation list
if (this.config.requiredConfirmationTools.includes(toolName)) {
return true;
}
// Auto-confirm low-risk tools if enabled
if (riskLevel === 'low' && this.config.autoConfirmLowRisk) {
return false;
}
// Default to requiring confirmation for medium-risk tools
return riskLevel === 'medium';
}
/**
* Request user confirmation for tool execution
*/
private async requestUserConfirmation(
toolCall: any,
interactionId: string,
streamCallback?: StreamCallback
): Promise<any> {
const toolName = toolCall.function.name;
const args = this.parseToolArguments(toolCall.function.arguments);
// Create pending interaction
const pendingInteraction: PendingInteraction = {
id: interactionId,
toolCall,
timestamp: Date.now(),
resolved: false
};
this.pendingInteractions.set(interactionId, pendingInteraction);
// Send confirmation request via streaming
if (streamCallback) {
this.sendConfirmationRequest(streamCallback, toolCall, interactionId, args);
}
// Wait for user response or timeout
return new Promise<any>((resolve) => {
// Set up timeout
const timeoutHandle = setTimeout(() => {
if (!pendingInteraction.resolved) {
pendingInteraction.resolved = true;
this.pendingInteractions.delete(interactionId);
this.interactionCallbacks.delete(interactionId);
resolve({
id: interactionId,
toolName,
response: 'timeout',
timestamp: Date.now(),
duration: Date.now() - pendingInteraction.timestamp
});
}
}, this.config.confirmationTimeout);
pendingInteraction.timeoutHandle = timeoutHandle;
// Set up response callback
this.interactionCallbacks.set(interactionId, (response: InteractionResponse) => {
if (!pendingInteraction.resolved) {
pendingInteraction.resolved = true;
if (timeoutHandle) {
clearTimeout(timeoutHandle);
}
this.pendingInteractions.delete(interactionId);
this.interactionCallbacks.delete(interactionId);
resolve({
id: interactionId,
toolName,
response,
timestamp: Date.now(),
duration: Date.now() - pendingInteraction.timestamp
});
}
});
});
}
/**
* Send confirmation request via streaming
*/
private sendConfirmationRequest(
streamCallback: StreamCallback,
toolCall: any,
interactionId: string,
args: Record<string, unknown>
): void {
const toolName = toolCall.function.name;
const riskLevel = this.assessToolRiskLevel(toolName);
// Create user-friendly description of the tool action
const actionDescription = this.createActionDescription(toolName, args);
const confirmationData = {
type: 'user_confirmation',
action: 'request',
interactionId,
tool: {
name: toolName,
description: actionDescription,
arguments: args,
riskLevel
},
options: {
confirm: {
label: 'Execute',
description: `Proceed with ${toolName}`,
style: riskLevel === 'high' ? 'warning' : 'primary'
},
cancel: {
label: 'Cancel',
description: 'Skip this tool execution',
style: 'secondary'
}
},
timeout: this.config.confirmationTimeout,
message: `Do you want to execute ${toolName}? ${actionDescription}`
};
streamCallback('', false, {
text: '',
done: false,
userInteraction: confirmationData
});
}
/**
* Create user-friendly action description
*/
private createActionDescription(toolName: string, args: Record<string, unknown>): string {
switch (toolName) {
case 'search_notes':
return `Search your notes for: "${args.query || 'unknown query'}"`;
case 'read_note':
return `Read note with ID: ${args.noteId || 'unknown'}`;
case 'keyword_search':
return `Search for keyword: "${args.query || 'unknown query'}"`;
case 'attribute_search':
return `Search for ${args.attributeType || 'attribute'}: "${args.attributeName || 'unknown'}"`;
case 'note_by_path':
return `Find note at path: "${args.path || 'unknown path'}"`;
case 'discover_tools':
return `Discover available tools`;
default:
return `Execute ${toolName} with provided parameters`;
}
}
/**
* Execute tool directly without confirmation
*/
private async executeToolDirectly(
toolCall: any,
input: ToolExecutionInput
): Promise<{ toolCall: any, message: Message }> {
const toolName = toolCall.function.name;
try {
// Import and use tool registry
const toolRegistry = (await import('../../tools/tool_registry.js')).default;
const tool = toolRegistry.getTool(toolName);
if (!tool) {
throw new Error(`Tool not found: ${toolName}`);
}
const args = this.parseToolArguments(toolCall.function.arguments);
const result = await tool.execute(args);
return {
toolCall,
message: {
role: 'tool',
content: typeof result === 'string' ? result : JSON.stringify(result, null, 2),
name: toolName,
tool_call_id: toolCall.id
}
};
} catch (error) {
const errorMessage = error instanceof Error ? error.message : String(error);
log.error(`Error executing tool ${toolName}: ${errorMessage}`);
return {
toolCall,
message: {
role: 'tool',
content: `Error: ${errorMessage}`,
name: toolName,
tool_call_id: toolCall.id
}
};
}
}
/**
* Parse tool arguments safely
*/
private parseToolArguments(args: string | Record<string, unknown>): Record<string, unknown> {
if (typeof args === 'string') {
try {
return JSON.parse(args);
} catch {
return { query: args };
}
}
return args;
}
/**
* Generate unique interaction ID
*/
private generateInteractionId(): string {
return `interaction_${Date.now()}_${Math.random().toString(36).substring(2, 9)}`;
}
/**
* Handle user response to confirmation request
* This method would be called by the frontend/WebSocket handler
*/
public handleUserResponse(interactionId: string, response: 'confirm' | 'cancel'): boolean {
const callback = this.interactionCallbacks.get(interactionId);
if (callback) {
log.info(`Received user response for interaction ${interactionId}: ${response}`);
callback(response);
return true;
}
log.error(`No callback found for interaction ${interactionId}`);
return false;
}
/**
* Cancel all pending interactions
*/
public cancelAllPendingInteractions(): void {
log.info(`Cancelling ${this.pendingInteractions.size} pending interactions`);
for (const [id, interaction] of this.pendingInteractions.entries()) {
if (interaction.timeoutHandle) {
clearTimeout(interaction.timeoutHandle);
}
const callback = this.interactionCallbacks.get(id);
if (callback && !interaction.resolved) {
callback('cancel');
}
}
this.pendingInteractions.clear();
this.interactionCallbacks.clear();
}
/**
* Get pending interactions (for status monitoring)
*/
public getPendingInteractions(): Array<{ id: string, toolName: string, timestamp: number }> {
return Array.from(this.pendingInteractions.values()).map(interaction => ({
id: interaction.id,
toolName: interaction.toolCall.function.name,
timestamp: interaction.timestamp
}));
}
/**
* Update configuration
*/
public updateConfig(newConfig: Partial<UserInteractionConfig>): void {
this.config = { ...this.config, ...newConfig };
log.info(`User interaction configuration updated: ${JSON.stringify(newConfig)}`);
}
}