mirror of
https://github.com/zadam/trilium.git
synced 2025-10-26 07:46:30 +01:00
feat(llm): implement error recovery stage and implement better tool calling
This commit is contained in:
@@ -648,6 +648,90 @@ async function handleStreamingProcess(
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * @swagger
 * /api/llm/interactions/{interactionId}/respond:
 *   post:
 *     summary: Respond to a user interaction request (confirm/cancel tool execution)
 *     operationId: llm-interaction-respond
 *     parameters:
 *       - name: interactionId
 *         in: path
 *         required: true
 *         schema:
 *           type: string
 *         description: The ID of the interaction to respond to
 *     requestBody:
 *       required: true
 *       content:
 *         application/json:
 *           schema:
 *             type: object
 *             required: [response]
 *             properties:
 *               response:
 *                 type: string
 *                 enum: [confirm, cancel]
 *                 description: User's response to the interaction
 *     responses:
 *       '200':
 *         description: Response processed successfully
 *       '400':
 *         description: Invalid response
 *       '500':
 *         description: Internal server error
 *     security:
 *       - session: []
 *     tags: ["llm"]
 */
async function respondToInteraction(req: Request, res: Response): Promise<void> {
    // Validates the request, then broadcasts the user's decision over WebSocket.
    // This handler does not look the interaction up itself: it cannot tell whether
    // `interactionId` refers to a pending interaction, so it never answers 404.
    try {
        const interactionId = req.params.interactionId;

        // When no body was parsed (missing or non-JSON payload) `req.body` can be
        // undefined; fall back to an empty object so the client gets a 400 below
        // instead of a destructuring TypeError surfacing as a 500.
        const { response } = (req.body ?? {}) as { response?: unknown };

        if (!interactionId || !response) {
            res.status(400).json({
                success: false,
                error: 'Missing interactionId or response'
            });
            return;
        }

        if (response !== 'confirm' && response !== 'cancel') {
            res.status(400).json({
                success: false,
                error: 'Response must be either "confirm" or "cancel"'
            });
            return;
        }

        // Note: In a real implementation, you'd maintain a registry of active pipelines.
        // For now, the response is sent via WebSocket to be handled by the active pipeline.
        // The WebSocket service is imported lazily, at the point of use.
        const wsService = (await import('../../services/ws.js')).default;

        wsService.sendMessageToAllClients({
            type: 'user-interaction-response',
            interactionId,
            response,
            timestamp: Date.now()
        });

        res.status(200).json({
            success: true,
            message: `User response "${response}" recorded for interaction ${interactionId}`
        });
    } catch (error) {
        log.error(`Error handling user interaction response: ${error}`);
        res.status(500).json({
            success: false,
            error: 'Internal server error'
        });
    }
}
|
||||
|
||||
/**
|
||||
* Debug endpoint to check tool recognition and registry status
|
||||
*/
|
||||
@@ -748,6 +832,9 @@ export default {
|
||||
sendMessage,
|
||||
streamMessage,
|
||||
|
||||
// User interaction
|
||||
respondToInteraction,
|
||||
|
||||
// Debug endpoints
|
||||
debugTools
|
||||
};
|
||||
|
||||
@@ -31,12 +31,24 @@ export interface ToolData {
|
||||
}
|
||||
|
||||
/**
 * Describes one tool-execution event that is streamed to the client.
 */
export interface ToolExecutionInfo {
    /**
     * Kind of event. `'progress'` and `'retry'` were added alongside the
     * original `'start' | 'update' | 'complete' | 'error'` values; the member is
     * declared exactly once so the wider union is the effective type.
     */
    type: 'start' | 'update' | 'complete' | 'error' | 'progress' | 'retry';
    /** Optional action label accompanying the event (e.g. `'retry'`). */
    action?: string;
    /** The tool the event refers to, with the arguments it was called with. */
    tool: {
        name: string;
        arguments: Record<string, unknown>;
    };
    /** Tool output, either as text or as a structured object. */
    result?: string | Record<string, unknown>;
    /** Optional progress details for the event. */
    progress?: {
        /** Position of the current step (e.g. tool index or retry number). */
        current: number;
        /** Total number of steps `current` is counted against. */
        total: number;
        /** Short machine-readable status such as `'initializing'` or `'retrying'`. */
        status: string;
        /** Human-readable progress message. */
        message: string;
        startTime?: number;
        executionTime?: number;
        resultSummary?: string;
        errorType?: string;
        estimatedDuration?: number;
    };
}
|
||||
|
||||
/**
|
||||
@@ -80,6 +92,12 @@ export interface StreamChunk {
|
||||
* Includes tool name, args, and execution status
|
||||
*/
|
||||
toolExecution?: ToolExecutionInfo;
|
||||
|
||||
/**
|
||||
* User interaction data (for confirmation/cancellation requests)
|
||||
* Contains interaction ID, tool info, and response options
|
||||
*/
|
||||
userInteraction?: Record<string, unknown>;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -211,6 +229,21 @@ export interface ChatResponse {
|
||||
|
||||
/** Tool calls from the LLM (if tools were used and the model supports them) */
|
||||
tool_calls?: ToolCall[] | null;
|
||||
|
||||
/** Recovery metadata for advanced error recovery */
|
||||
recovery_metadata?: {
|
||||
total_attempts: number;
|
||||
successful_recoveries: number;
|
||||
failed_permanently: number;
|
||||
};
|
||||
|
||||
/** User interaction metadata for confirmation/cancellation features */
|
||||
interaction_metadata?: {
|
||||
total_interactions: number;
|
||||
confirmed: number;
|
||||
cancelled: number;
|
||||
timedout: number;
|
||||
};
|
||||
}
|
||||
|
||||
export interface AIService {
|
||||
|
||||
@@ -8,6 +8,7 @@ import { ModelSelectionStage } from './stages/model_selection_stage.js';
|
||||
import { LLMCompletionStage } from './stages/llm_completion_stage.js';
|
||||
import { ResponseProcessingStage } from './stages/response_processing_stage.js';
|
||||
import { ToolCallingStage } from './stages/tool_calling_stage.js';
|
||||
import { ErrorRecoveryStage } from './stages/error_recovery_stage.js';
|
||||
// Traditional search is used instead of vector search
|
||||
import toolRegistry from '../tools/tool_registry.js';
|
||||
import toolInitializer from '../tools/tool_initializer.js';
|
||||
@@ -29,6 +30,7 @@ export class ChatPipeline {
|
||||
llmCompletion: LLMCompletionStage;
|
||||
responseProcessing: ResponseProcessingStage;
|
||||
toolCalling: ToolCallingStage;
|
||||
errorRecovery: ErrorRecoveryStage;
|
||||
// traditional search is used instead of vector search
|
||||
};
|
||||
|
||||
@@ -50,6 +52,7 @@ export class ChatPipeline {
|
||||
llmCompletion: new LLMCompletionStage(),
|
||||
responseProcessing: new ResponseProcessingStage(),
|
||||
toolCalling: new ToolCallingStage(),
|
||||
errorRecovery: new ErrorRecoveryStage(),
|
||||
// traditional search is used instead of vector search
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,561 @@
|
||||
import { BasePipelineStage } from '../pipeline_stage.js';
|
||||
import type { ToolExecutionInput, StreamCallback } from '../interfaces.js';
|
||||
import type { ChatResponse, Message } from '../../ai_interface.js';
|
||||
import toolRegistry from '../../tools/tool_registry.js';
|
||||
import log from '../../../log.js';
|
||||
|
||||
/**
 * Backoff configuration that controls how a failing tool is retried.
 */
interface RetryStrategy {
    /** Number of retries allowed after the first attempt. */
    maxRetries: number;
    /** Delay before the first retry, in milliseconds. */
    baseDelay: number;
    /** Cap applied to the exponentially growing delay, in milliseconds. */
    maxDelay: number;
    /** Factor the delay is multiplied by for each further retry. */
    backoffMultiplier: number;
    /** Whether a random variation is added to the computed delay. */
    jitter: boolean;
}
|
||||
|
||||
/**
 * Mutable bookkeeping for the recovery of a single tool call.
 */
interface ToolRetryContext {
    /** Name of the tool being recovered. */
    toolName: string;
    /** 1-based number of the primary attempt currently running (0 before the first). */
    attempt: number;
    /** Message of the most recent failure, or an empty string if none occurred yet. */
    lastError: string;
    /** Fallback approaches that may be tried once the primary attempts are exhausted. */
    alternativeApproaches: string[];
    /** Fallback approaches that have already been tried. */
    usedApproaches: string[];
}
|
||||
|
||||
/**
 * Advanced Error Recovery Pipeline Stage.
 *
 * Executes every tool call of an LLM response and tries to recover from failures:
 *  1. the tool is retried with exponential backoff, using a per-tool retry strategy;
 *  2. when all primary attempts fail, alternative approaches (other tools or a
 *     rewritten query) are tried in order;
 *  3. when nothing works, a `RECOVERY_FAILED` tool message containing guidance is
 *     produced, so every tool call still gets exactly one `tool` message.
 */
export class ErrorRecoveryStage extends BasePipelineStage<ToolExecutionInput, { response: ChatResponse, needsFollowUp: boolean, messages: Message[] }> {

    /** Retry configuration keyed by tool name; the `default` entry is the fallback. */
    private retryStrategies: Map<string, RetryStrategy> = new Map();

    /** NOTE(review): this map is neither read nor written anywhere in this class. */
    private activeRetries: Map<string, ToolRetryContext> = new Map();

    constructor() {
        super('ErrorRecovery');
        this.initializeRetryStrategies();
    }

    /**
     * Initialize retry strategies for different tool types.
     */
    private initializeRetryStrategies(): void {
        // Search tools - more aggressive retries since they're critical
        this.retryStrategies.set('search_notes', {
            maxRetries: 3,
            baseDelay: 1000,
            maxDelay: 8000,
            backoffMultiplier: 2,
            jitter: true
        });

        this.retryStrategies.set('keyword_search', {
            maxRetries: 3,
            baseDelay: 800,
            maxDelay: 6000,
            backoffMultiplier: 2,
            jitter: true
        });

        // Read operations - moderate retries
        this.retryStrategies.set('read_note', {
            maxRetries: 2,
            baseDelay: 500,
            maxDelay: 3000,
            backoffMultiplier: 2,
            jitter: false
        });

        // Attribute operations - conservative retries
        this.retryStrategies.set('attribute_search', {
            maxRetries: 2,
            baseDelay: 1200,
            maxDelay: 5000,
            backoffMultiplier: 1.8,
            jitter: true
        });

        // Default strategy for unknown tools
        this.retryStrategies.set('default', {
            maxRetries: 2,
            baseDelay: 1000,
            maxDelay: 4000,
            backoffMultiplier: 2,
            jitter: true
        });
    }

    /**
     * Process tool execution with advanced error recovery.
     *
     * Tool calls are executed one after another, in the order given by the response.
     * The tool message produced for each call is appended to a copy of
     * `input.messages`.
     *
     * @param input stage input carrying the LLM response, the message history and
     *              an optional stream callback
     * @returns the response extended with `recovery_metadata`, whether a follow-up
     *          LLM call is needed (true when at least one tool call was processed)
     *          and the extended message list
     */
    protected async process(input: ToolExecutionInput): Promise<{ response: ChatResponse, needsFollowUp: boolean, messages: Message[] }> {
        const { response } = input;

        // If no tool calls, pass through
        if (!response.tool_calls || response.tool_calls.length === 0) {
            return { response, needsFollowUp: false, messages: input.messages };
        }

        log.info(`========== ERROR RECOVERY STAGE PROCESSING ==========`);
        log.info(`Processing ${response.tool_calls.length} tool calls with advanced error recovery`);

        const recoveredToolCalls: Array<{ toolCall: any, message: Message, attempts: number, recovered: boolean }> = [];
        const updatedMessages = [...input.messages];

        // Process each tool call with recovery
        for (let i = 0; i < response.tool_calls.length; i++) {
            const recoveredResult = await this.executeToolWithRecovery(response.tool_calls[i], input, i);

            if (recoveredResult) {
                recoveredToolCalls.push(recoveredResult);
                updatedMessages.push(recoveredResult.message);
            }
        }

        // `recovered` is true for every call that eventually produced a result,
        // including calls that succeeded on their very first attempt.
        const recoveredCount = recoveredToolCalls.filter(r => r.recovered).length;

        // Create enhanced response with recovery information
        const enhancedResponse: ChatResponse = {
            ...response,
            tool_calls: recoveredToolCalls.map(r => r.toolCall),
            recovery_metadata: {
                total_attempts: recoveredToolCalls.reduce((sum, r) => sum + r.attempts, 0),
                successful_recoveries: recoveredCount,
                failed_permanently: recoveredToolCalls.length - recoveredCount
            }
        };

        const needsFollowUp = recoveredToolCalls.length > 0;

        log.info(`Recovery complete: ${recoveredCount}/${recoveredToolCalls.length} tools recovered`);

        return {
            response: enhancedResponse,
            needsFollowUp,
            messages: updatedMessages
        };
    }

    /**
     * Execute a tool call with comprehensive error recovery.
     *
     * @param toolCall the tool call as produced by the LLM (`id`, `function.name`,
     *                 `function.arguments`)
     * @param input    stage input; only `streamCallback` is used here
     * @param index    position of the tool call in the response (currently unused)
     * @returns the original tool call, the `tool` message to add to the conversation,
     *          the number of attempts made and whether a result was obtained
     */
    private async executeToolWithRecovery(
        toolCall: any,
        input: ToolExecutionInput,
        index: number
    ): Promise<{ toolCall: any, message: Message, attempts: number, recovered: boolean } | null> {
        const toolName = toolCall.function.name;
        const strategy = this.retryStrategies.get(toolName) || this.retryStrategies.get('default')!;

        let lastError = '';
        let attempts = 0;

        // Initialize retry context
        const retryContext: ToolRetryContext = {
            toolName,
            attempt: 0,
            lastError: '',
            alternativeApproaches: this.getAlternativeApproaches(toolName),
            usedApproaches: []
        };

        log.info(`Starting error recovery for tool: ${toolName} (max retries: ${strategy.maxRetries})`);

        const tool = toolRegistry.getTool(toolName);

        if (!tool) {
            // An unregistered tool fails the same way on every attempt, so waiting
            // through the backoff schedule only delays the alternatives below.
            attempts = 1;
            retryContext.attempt = attempts;
            lastError = `Tool not found: ${toolName}`;
            retryContext.lastError = lastError;
            log.info(`✗ Tool ${toolName} failed on attempt ${attempts}: ${lastError}`);
        } else {
            // The raw arguments never change between attempts; parse them once.
            const args = this.parseToolArguments(toolCall.function.arguments);

            // Primary execution attempts
            for (attempts = 1; attempts <= strategy.maxRetries + 1; attempts++) {
                retryContext.attempt = attempts;

                // Add delay for retry attempts (not first attempt)
                if (attempts > 1) {
                    const delay = this.calculateDelay(strategy, attempts - 1);
                    log.info(`Retry attempt ${attempts - 1} for ${toolName} after ${delay}ms delay`);
                    await this.sleep(delay);

                    // Send retry notification if streaming. The notification is
                    // best-effort and must not count as a tool failure.
                    if (input.streamCallback) {
                        this.sendRetryNotification(input.streamCallback, toolName, attempts - 1, strategy.maxRetries);
                    }
                }

                try {
                    // Modify arguments for retry if needed
                    const modifiedArgs = this.modifyArgsForRetry(args, retryContext);

                    log.info(`Executing ${toolName} (attempt ${attempts}) with args: ${JSON.stringify(modifiedArgs)}`);

                    const result = await tool.execute(modifiedArgs);

                    log.info(`✓ Tool ${toolName} succeeded on attempt ${attempts}`);

                    return {
                        toolCall,
                        message: this.buildToolMessage(toolCall, toolName, this.stringifyResult(result, 2)),
                        attempts,
                        recovered: true
                    };
                } catch (error) {
                    lastError = error instanceof Error ? error.message : String(error);
                    retryContext.lastError = lastError;

                    log.info(`✗ Tool ${toolName} failed on attempt ${attempts}: ${lastError}`);

                    // If this was the last allowed attempt, stop without
                    // incrementing so `attempts` reports the real count.
                    if (attempts > strategy.maxRetries) {
                        break;
                    }
                }
            }
        }

        // Primary attempts failed, try alternative approaches
        log.info(`Primary attempts failed for ${toolName}, trying alternative approaches`);

        for (const alternative of retryContext.alternativeApproaches) {
            if (retryContext.usedApproaches.includes(alternative)) {
                continue; // Skip already used approaches
            }

            try {
                log.info(`Trying alternative approach: ${alternative} for ${toolName}`);
                retryContext.usedApproaches.push(alternative);

                const alternativeResult = await this.executeAlternativeApproach(alternative, toolCall, retryContext);

                if (alternativeResult) {
                    log.info(`✓ Alternative approach ${alternative} succeeded for ${toolName}`);

                    return {
                        toolCall,
                        message: this.buildToolMessage(
                            toolCall,
                            toolName,
                            `ALTERNATIVE_SUCCESS: ${alternative} succeeded where ${toolName} failed. Result: ${alternativeResult}`
                        ),
                        attempts: attempts + 1,
                        recovered: true
                    };
                }
            } catch (error) {
                const altError = error instanceof Error ? error.message : String(error);
                log.info(`✗ Alternative approach ${alternative} failed: ${altError}`);
            }
        }

        // All attempts failed
        log.error(`All recovery attempts failed for ${toolName} after ${attempts} attempts and ${retryContext.usedApproaches.length} alternatives`);

        // Return failure message with guidance
        const failureGuidance = this.generateFailureGuidance(toolName, lastError, retryContext);

        return {
            toolCall,
            message: this.buildToolMessage(
                toolCall,
                toolName,
                `RECOVERY_FAILED: Tool ${toolName} failed after ${attempts} attempts and ${retryContext.usedApproaches.length} alternative approaches. Last error: ${lastError}\n\n${failureGuidance}`
            ),
            attempts,
            recovered: false
        };
    }

    /**
     * Build the `tool` message that answers the given tool call.
     */
    private buildToolMessage(toolCall: any, toolName: string, content: string): Message {
        return {
            role: 'tool',
            content,
            name: toolName,
            tool_call_id: toolCall.id
        };
    }

    /**
     * Convert a tool result to text: strings are passed through, anything else is
     * serialized as JSON (optionally pretty-printed with `indent` spaces).
     */
    private stringifyResult(result: unknown, indent?: number): string {
        return typeof result === 'string' ? result : JSON.stringify(result, null, indent);
    }

    /**
     * Calculate retry delay with exponential backoff and optional jitter.
     *
     * @param strategy    backoff configuration
     * @param retryNumber 1-based retry number (1 = first retry, which waits `baseDelay`)
     * @returns delay in whole milliseconds: at least 100 and, for strategies whose
     *          `maxDelay` is 100 or more, never above `maxDelay`
     */
    private calculateDelay(strategy: RetryStrategy, retryNumber: number): number {
        let delay = strategy.baseDelay * Math.pow(strategy.backoffMultiplier, retryNumber - 1);

        // Apply maximum delay limit
        delay = Math.min(delay, strategy.maxDelay);

        // Add jitter if enabled (±25% random variation)
        if (strategy.jitter) {
            const jitterRange = delay * 0.25;
            delay += (Math.random() - 0.5) * 2 * jitterRange;

            // Positive jitter on an already capped delay must not exceed the cap.
            delay = Math.min(delay, strategy.maxDelay);
        }

        return Math.round(Math.max(delay, 100)); // Minimum 100ms delay
    }

    /**
     * Sleep for specified milliseconds.
     */
    private sleep(ms: number): Promise<void> {
        return new Promise(resolve => setTimeout(resolve, ms));
    }

    /**
     * Get alternative approaches for a tool.
     *
     * Entries are either names looked up in the tool registry or one of the
     * built-in strategies handled by `executeAlternativeApproach`
     * (`broader_search_terms`, `simplified_query`, `different_attribute_type`,
     * `search_and_read`).
     */
    private getAlternativeApproaches(toolName: string): string[] {
        const alternatives = new Map<string, string[]>([
            ['search_notes', ['keyword_search', 'broader_search_terms', 'attribute_search']],
            ['keyword_search', ['search_notes', 'simplified_query', 'attribute_search']],
            ['attribute_search', ['search_notes', 'keyword_search', 'different_attribute_type']],
            ['read_note', ['note_by_path', 'search_and_read', 'template_search']],
            ['note_by_path', ['read_note', 'search_notes', 'keyword_search']]
        ]);

        return alternatives.get(toolName) ?? ['search_notes', 'keyword_search'];
    }

    /**
     * Modify arguments for retry attempts.
     *
     * The first attempt always uses the arguments unchanged. On later attempts,
     * tools whose name contains `search` get quotes and limiting words removed
     * from `query`, and `attribute_search` gets `attributeType` switched between
     * `label` and `relation`. The input object is never mutated.
     */
    private modifyArgsForRetry(args: Record<string, unknown>, context: ToolRetryContext): Record<string, unknown> {
        const modified = { ...args };

        if (context.attempt <= 1) {
            return modified;
        }

        // For search tools, broaden the query on retries
        if (context.toolName.includes('search') && typeof modified.query === 'string' && modified.query) {
            const broadened = modified.query
                .replace(/['"]/g, '') // Remove quotes
                .replace(/\b(exactly|specific|precise)\b/gi, '') // Remove limiting words
                .replace(/\s+/g, ' ') // Collapse gaps left by removed words
                .trim();

            // Keep the original query when stripping would leave nothing to search for.
            if (broadened) {
                modified.query = broadened;
                log.info(`Modified query for retry: "${modified.query}"`);
            }
        }

        // For attribute search, try different attribute types
        if (context.toolName === 'attribute_search') {
            const swapped = modified.attributeType === 'label'
                ? 'relation'
                : modified.attributeType === 'relation'
                    ? 'label'
                    : undefined;

            if (swapped) {
                modified.attributeType = swapped;
                log.info(`Modified attributeType for retry: ${modified.attributeType}`);
            }
        }

        return modified;
    }

    /**
     * Execute alternative approach.
     *
     * @returns the textual result, or `null` when the approach is not applicable
     *          or produced nothing
     */
    private async executeAlternativeApproach(
        approach: string,
        originalToolCall: any,
        context: ToolRetryContext
    ): Promise<string | null> {
        switch (approach) {
            case 'broader_search_terms':
                return await this.executeBroaderSearch(originalToolCall);

            case 'simplified_query':
                return await this.executeSimplifiedSearch(originalToolCall);

            case 'different_attribute_type':
                return await this.executeDifferentAttributeSearch(originalToolCall);

            case 'search_and_read':
                return await this.executeSearchAndRead(originalToolCall);

            default:
                // Try to execute the alternative tool directly
                return await this.executeAlternativeTool(approach, originalToolCall);
        }
    }

    /**
     * Execute broader search approach: run `search_notes` with at most the first
     * three words of the query that are longer than three characters.
     */
    private async executeBroaderSearch(toolCall: any): Promise<string | null> {
        const args = this.parseToolArguments(toolCall.function.arguments);

        if (typeof args.query !== 'string' || !args.query) {
            return null;
        }

        // Extract the main keywords and search more broadly
        const keywords = args.query
            .split(/\s+/)
            .filter(word => word.length > 3)
            .slice(0, 3) // Take only first 3 main keywords
            .join(' ');

        // A query made only of short words leaves nothing to search for.
        if (!keywords) {
            return null;
        }

        const tool = toolRegistry.getTool('search_notes');
        if (!tool) {
            return null;
        }

        const result = await tool.execute({ ...args, query: keywords });
        return this.stringifyResult(result);
    }

    /**
     * Execute simplified search approach: run `keyword_search` with only the first
     * word of the query.
     */
    private async executeSimplifiedSearch(toolCall: any): Promise<string | null> {
        const args = this.parseToolArguments(toolCall.function.arguments);

        if (typeof args.query !== 'string') {
            return null;
        }

        // Use only the first word as a very simple search; trimming first keeps a
        // leading space from producing an empty "first word".
        const firstWord = args.query.trim().split(/\s+/)[0];
        if (!firstWord) {
            return null;
        }

        const tool = toolRegistry.getTool('keyword_search');
        if (!tool) {
            return null;
        }

        const result = await tool.execute({ ...args, query: firstWord });
        return this.stringifyResult(result);
    }

    /**
     * Execute different attribute search: run `attribute_search` with
     * `attributeType` switched (`label` becomes `relation`, anything else becomes
     * `label`).
     */
    private async executeDifferentAttributeSearch(toolCall: any): Promise<string | null> {
        const args = this.parseToolArguments(toolCall.function.arguments);

        if (!args.attributeType) {
            return null;
        }

        const tool = toolRegistry.getTool('attribute_search');
        if (!tool) {
            return null;
        }

        const newType = args.attributeType === 'label' ? 'relation' : 'label';
        const result = await tool.execute({ ...args, attributeType: newType });
        return this.stringifyResult(result);
    }

    /**
     * Execute search and read approach: search with `search_notes`, then read the
     * first note ID that can be extracted from the search output.
     *
     * @returns `SEARCH_AND_READ: ...` when a note was read, `SEARCH_ONLY: ...` when
     *          only the search worked, `null` when the approach is not applicable
     *          or failed
     */
    private async executeSearchAndRead(toolCall: any): Promise<string | null> {
        const args = this.parseToolArguments(toolCall.function.arguments);

        // First search for notes
        const searchTool = toolRegistry.getTool('search_notes');
        if (!searchTool || !args.query) {
            return null;
        }

        try {
            const searchResult = await searchTool.execute({ query: args.query });

            // Try to extract note IDs and read the first one
            const searchText = this.stringifyResult(searchResult);
            const noteId = searchText.match(/note[:\s]+([a-zA-Z0-9]+)/i)?.[1];

            if (noteId) {
                const readTool = toolRegistry.getTool('read_note');
                if (readTool) {
                    const readResult = await readTool.execute({ noteId });
                    // Serialize explicitly: interpolating an object result would
                    // yield "[object Object]".
                    return `SEARCH_AND_READ: Found and read note ${noteId}. Content: ${this.stringifyResult(readResult)}`;
                }
            }

            return `SEARCH_ONLY: ${searchText}`;
        } catch (error) {
            const reason = error instanceof Error ? error.message : String(error);
            log.info(`Search-and-read alternative failed: ${reason}`);
            return null;
        }
    }

    /**
     * Execute alternative tool: run the registered tool `toolName` with the
     * original call's arguments.
     *
     * @returns the textual result, or `null` when the tool is not registered or
     *          its execution failed
     */
    private async executeAlternativeTool(toolName: string, originalToolCall: any): Promise<string | null> {
        const tool = toolRegistry.getTool(toolName);
        if (!tool) {
            return null;
        }

        const args = this.parseToolArguments(originalToolCall.function.arguments);

        try {
            const result = await tool.execute(args);
            return this.stringifyResult(result);
        } catch (error) {
            const reason = error instanceof Error ? error.message : String(error);
            log.info(`Alternative tool ${toolName} failed: ${reason}`);
            return null;
        }
    }

    /**
     * Parse tool arguments safely.
     *
     * A string is parsed as JSON. When it is not valid JSON, or the JSON is not a
     * plain object (`null`, a number, an array, ...), the raw string is wrapped as
     * `{ query: <string> }`, so callers can always access properties on the result.
     */
    private parseToolArguments(args: string | Record<string, unknown>): Record<string, unknown> {
        if (typeof args !== 'string') {
            // Tool calls are untyped at runtime; tolerate missing arguments.
            return args ?? {};
        }

        try {
            const parsed: unknown = JSON.parse(args);
            if (parsed !== null && typeof parsed === 'object' && !Array.isArray(parsed)) {
                return parsed as Record<string, unknown>;
            }
        } catch {
            // Not JSON: treated as a bare query string below.
        }

        return { query: args };
    }

    /**
     * Send retry notification via streaming.
     *
     * Best-effort: a callback that throws or returns a rejected promise is logged
     * and otherwise ignored, so that a notification problem is never mistaken for
     * a failure of the tool itself.
     */
    private sendRetryNotification(
        streamCallback: StreamCallback,
        toolName: string,
        retryNumber: number,
        maxRetries: number
    ): void {
        try {
            const pending: unknown = streamCallback('', false, {
                text: '',
                done: false,
                toolExecution: {
                    type: 'retry',
                    action: 'retry',
                    tool: { name: toolName, arguments: {} },
                    progress: {
                        current: retryNumber,
                        total: maxRetries,
                        status: 'retrying',
                        message: `Retrying ${toolName} (attempt ${retryNumber}/${maxRetries})...`
                    }
                }
            });

            // The callback may be asynchronous; do not wait for it, but make sure a
            // rejection is handled.
            void Promise.resolve(pending).catch((error: unknown) => {
                log.error(`Error sending retry notification for ${toolName}: ${error}`);
            });
        } catch (error) {
            log.error(`Error sending retry notification for ${toolName}: ${error}`);
        }
    }

    /**
     * Generate failure guidance: a multi-line text summarizing what was tried and
     * suggesting next steps, appended to the `RECOVERY_FAILED` tool message.
     */
    private generateFailureGuidance(toolName: string, lastError: string, context: ToolRetryContext): string {
        const guidance = [
            `RECOVERY ANALYSIS for ${toolName}:`,
            `- Primary attempts: ${context.attempt}`,
            `- Alternative approaches tried: ${context.usedApproaches.join(', ') || 'none'}`,
            `- Last error: ${lastError}`,
            '',
            'SUGGESTED NEXT STEPS:',
            '- Try manual search with broader terms',
            '- Check if the requested information exists',
            '- Use discover_tools to find alternative tools',
            '- Reformulate the query with different keywords'
        ];

        return guidance.join('\n');
    }
}
|
||||
@@ -32,6 +32,11 @@ export class MessagePreparationStage extends BasePipelineStage<MessagePreparatio
|
||||
const toolsEnabled = options?.enableTools === true;
|
||||
|
||||
log.info(`Preparing messages for provider: ${provider}, context: ${!!context}, system prompt: ${!!systemPrompt}, tools: ${toolsEnabled}`);
|
||||
log.info(`Input message count: ${messages.length}`);
|
||||
|
||||
// Apply intelligent context management for long conversations
|
||||
const managedMessages = await this.applyContextManagement(messages, provider, options);
|
||||
log.info(`After context management: ${managedMessages.length} messages (reduced by ${messages.length - managedMessages.length})`);
|
||||
|
||||
// Get appropriate formatter for this provider
|
||||
const formatter = MessageFormatterFactory.getFormatter(provider);
|
||||
@@ -69,13 +74,171 @@ Remember: Tool usage should be continuous and iterative until you have thoroughl
|
||||
|
||||
// Format messages using provider-specific approach
|
||||
const formattedMessages = formatter.formatMessages(
|
||||
messages,
|
||||
managedMessages,
|
||||
finalSystemPrompt,
|
||||
context
|
||||
);
|
||||
|
||||
log.info(`Formatted ${messages.length} messages into ${formattedMessages.length} messages for provider: ${provider}`);
|
||||
log.info(`Formatted ${managedMessages.length} messages into ${formattedMessages.length} messages for provider: ${provider}`);
|
||||
|
||||
return { messages: formattedMessages };
|
||||
}
|
||||
|
||||
/**
 * Reduce a long conversation to the message budget of the given provider.
 *
 * Conversations within the budget are returned untouched. Otherwise the sliding
 * window is applied first, and summarization of older messages is used only if
 * the windowed list is still over the budget.
 *
 * @param messages full conversation history
 * @param provider provider name used to look up the message budget
 * @param options  currently unused
 * @returns the (possibly reduced) message list
 */
private async applyContextManagement(messages: Message[], provider: string, options?: any): Promise<Message[]> {
    const limit = this.getMaxMessagesForProvider(provider);

    if (messages.length > limit) {
        log.info(`Message count (${messages.length}) exceeds limit (${limit}), applying context management`);

        // Strategy 1: Preserve recent messages and important system/tool messages
        const windowed = await this.applySlidingWindowWithImportanceFiltering(messages, limit);

        // Strategy 2: If still too many, apply summarization to older messages
        return windowed.length > limit
            ? await this.applySummarizationToOlderMessages(windowed, limit)
            : windowed;
    }

    log.info(`Message count (${messages.length}) within limit (${limit}), no context management needed`);
    return messages;
}
|
||||
|
||||
/**
 * Get the maximum number of messages to send to a provider.
 *
 * The lookup uses a `Map`, so only the names listed here can match. With a plain
 * object, a provider string such as `constructor` or `toString` would resolve to
 * an inherited member instead of falling back to the default.
 *
 * @param provider provider name (`anthropic`, `openai`, `ollama`, ...)
 * @returns the message budget for the provider, or the default for unknown names
 */
private getMaxMessagesForProvider(provider: string): number {
    const defaultLimit = 35; // Safe default
    const limits = new Map<string, number>([
        ['anthropic', 50], // Conservative for Claude's context window management
        ['openai', 40],    // Conservative for GPT models
        ['ollama', 30],    // More conservative for local models
        ['default', defaultLimit]
    ]);

    return limits.get(provider) ?? defaultLimit;
}
|
||||
|
||||
/**
 * Apply a sliding window: keep the first system message plus the most recent
 * messages, up to `maxMessages` in total.
 *
 * Tool results at the very start of the window are dropped. Their initiating
 * assistant `tool_calls` message lies outside the window, and a tool result
 * without its tool call is not a valid conversation for providers that pair them.
 * No other importance-based selection of tool messages is performed.
 *
 * @param messages    full conversation history
 * @param maxMessages maximum number of messages to return
 * @returns `messages` itself when it is within the limit, otherwise a new, shorter list
 */
private async applySlidingWindowWithImportanceFiltering(messages: Message[], maxMessages: number): Promise<Message[]> {
    if (messages.length <= maxMessages) {
        return messages;
    }

    // Always preserve the first system message if it exists
    const firstSystemMessage = messages.find(msg => msg.role === 'system');
    const preserved: Message[] = firstSystemMessage ? [firstSystemMessage] : [];

    // Calculate how many recent messages we can keep. `slice(-0)` would return the
    // whole array, so a zero budget is handled explicitly.
    const recentBudget = Math.max(maxMessages - preserved.length, 0);
    const recentWindow = recentBudget > 0 ? messages.slice(-recentBudget) : [];

    // Drop tool results that were separated from their tool call by the cut.
    const firstNonTool = recentWindow.findIndex(msg => msg.role !== 'tool');
    const orphanedCount = firstNonTool === -1 ? recentWindow.length : firstNonTool;
    const recentMessages = recentWindow.slice(orphanedCount);

    if (orphanedCount > 0) {
        log.info(`Sliding window filtering: dropped ${orphanedCount} tool result(s) whose tool call fell outside the window`);
    }

    // Combine system message + recent messages, avoiding duplicates
    const result: Message[] = [
        ...preserved,
        ...recentMessages.filter(msg => msg !== firstSystemMessage)
    ];

    log.info(`Sliding window filtering: preserved ${preserved.length} system messages, kept ${recentMessages.length} recent messages`);

    return result.slice(0, maxMessages); // Ensure we don't exceed the limit
}
|
||||
|
||||
/**
 * Replace the older part of a conversation with a single summary message.
 *
 * The most recent messages (60% of `maxMessages`, rounded down) are kept as they
 * are; everything before them is condensed into one `system` message placed at
 * the front of the result.
 *
 * @param messages    conversation history
 * @param maxMessages message budget
 * @returns `messages` itself when it is within the budget, otherwise the summary
 *          message followed by the recent messages
 */
private async applySummarizationToOlderMessages(messages: Message[], maxMessages: number): Promise<Message[]> {
    if (messages.length <= maxMessages) {
        return messages;
    }

    // Split into the part to summarize and the recent part that is kept verbatim
    const keepCount = Math.floor(maxMessages * 0.6);
    const tail = messages.slice(-keepCount);
    const head = messages.slice(0, messages.length - keepCount);

    const keyPoints = this.createConversationSummary(head);
    const summaryMessage: Message = {
        role: 'system',
        content: `CONVERSATION SUMMARY: Previous conversation included ${head.length} messages. Key points: ${keyPoints}`
    };

    log.info(`Applied summarization: summarized ${head.length} older messages, kept ${tail.length} recent messages`);

    return [summaryMessage, ...tail];
}
|
||||
|
||||
/**
 * Create a concise summary of conversation messages.
 *
 * The summary lists the first 100 characters of up to three user messages and up
 * to five distinct tool names called by the assistant. Other message content is
 * not included.
 *
 * @param messages messages to summarize
 * @returns the summary text, or a generic sentence when nothing could be extracted
 */
private createConversationSummary(messages: Message[]): string {
    const userQueries: string[] = [];
    const toolUsage: string[] = [];

    for (const msg of messages) {
        if (msg.role === 'user') {
            // Extract key topics from user messages
            const content = (msg.content?.substring(0, 100) || '').trim();
            if (content) {
                userQueries.push(content);
            }
        } else if (msg.role === 'assistant' && msg.tool_calls && msg.tool_calls.length > 0) {
            // Track tool usage
            for (const tool of msg.tool_calls) {
                if (tool.function?.name) {
                    toolUsage.push(tool.function.name);
                }
            }
        }
    }

    const summary: string[] = [];

    if (userQueries.length > 0) {
        summary.push(`User asked about: ${userQueries.slice(0, 3).join(', ')}`);
    }

    if (toolUsage.length > 0) {
        const uniqueTools = [...new Set(toolUsage)];
        summary.push(`Tools used: ${uniqueTools.slice(0, 5).join(', ')}`);
    }

    return summary.join('. ') || 'General conversation about notes and information retrieval';
}
|
||||
}
|
||||
|
||||
@@ -256,7 +256,14 @@ export class ToolCallingStage extends BasePipelineStage<ToolExecutionInput, { re
|
||||
name: toolCall.function.name,
|
||||
arguments: args
|
||||
},
|
||||
type: 'start' as const
|
||||
type: 'start' as const,
|
||||
progress: {
|
||||
current: index + 1,
|
||||
total: response.tool_calls?.length || 1,
|
||||
status: 'initializing',
|
||||
message: `Starting ${toolCall.function.name} execution...`,
|
||||
estimatedDuration: this.getEstimatedDuration(toolCall.function.name)
|
||||
}
|
||||
};
|
||||
|
||||
// Don't wait for this to complete, but log any errors
|
||||
@@ -274,6 +281,35 @@ export class ToolCallingStage extends BasePipelineStage<ToolExecutionInput, { re
|
||||
let result;
|
||||
try {
|
||||
log.info(`Starting tool execution for ${toolCall.function.name}...`);
|
||||
|
||||
// Send progress update during execution
|
||||
if (streamCallback) {
|
||||
const progressData = {
|
||||
action: 'progress',
|
||||
tool: {
|
||||
name: toolCall.function.name,
|
||||
arguments: args
|
||||
},
|
||||
type: 'progress' as const,
|
||||
progress: {
|
||||
current: index + 1,
|
||||
total: response.tool_calls?.length || 1,
|
||||
status: 'executing',
|
||||
message: `Executing ${toolCall.function.name}...`,
|
||||
startTime: executionStart
|
||||
}
|
||||
};
|
||||
|
||||
const progressResult = streamCallback('', false, {
|
||||
text: '',
|
||||
done: false,
|
||||
toolExecution: progressData
|
||||
});
|
||||
if (progressResult instanceof Promise) {
|
||||
progressResult.catch((e: Error) => log.error(`Error sending tool execution progress event: ${e.message}`));
|
||||
}
|
||||
}
|
||||
|
||||
result = await tool.execute(args);
|
||||
const executionTime = Date.now() - executionStart;
|
||||
log.info(`================ TOOL EXECUTION COMPLETED in ${executionTime}ms ================`);
|
||||
@@ -296,6 +332,10 @@ export class ToolCallingStage extends BasePipelineStage<ToolExecutionInput, { re
|
||||
|
||||
// Emit tool completion event if streaming is enabled
|
||||
if (streamCallback) {
|
||||
const resultSummary = typeof result === 'string'
|
||||
? result.substring(0, 200) + (result.length > 200 ? '...' : '')
|
||||
: `Object with ${Object.keys(result).length} properties`;
|
||||
|
||||
const toolExecutionData = {
|
||||
action: 'complete',
|
||||
tool: {
|
||||
@@ -303,7 +343,15 @@ export class ToolCallingStage extends BasePipelineStage<ToolExecutionInput, { re
|
||||
arguments: {} as Record<string, unknown>
|
||||
},
|
||||
result: typeof result === 'string' ? result : result as Record<string, unknown>,
|
||||
type: 'complete' as const
|
||||
type: 'complete' as const,
|
||||
progress: {
|
||||
current: index + 1,
|
||||
total: response.tool_calls?.length || 1,
|
||||
status: 'completed',
|
||||
message: `${toolCall.function.name} completed successfully`,
|
||||
executionTime: executionTime,
|
||||
resultSummary: resultSummary
|
||||
}
|
||||
};
|
||||
|
||||
// Don't wait for this to complete, but log any errors
|
||||
@@ -352,7 +400,15 @@ export class ToolCallingStage extends BasePipelineStage<ToolExecutionInput, { re
|
||||
arguments: {} as Record<string, unknown>
|
||||
},
|
||||
error: enhancedErrorMessage, // Include guidance in the error message
|
||||
type: 'error' as const
|
||||
type: 'error' as const,
|
||||
progress: {
|
||||
current: index + 1,
|
||||
total: response.tool_calls?.length || 1,
|
||||
status: 'failed',
|
||||
message: `${toolCall.function.name} failed: ${errorMessage.substring(0, 100)}...`,
|
||||
executionTime: executionTime,
|
||||
errorType: execError instanceof Error ? execError.constructor.name : 'UnknownError'
|
||||
}
|
||||
};
|
||||
|
||||
// Don't wait for this to complete, but log any errors
|
||||
@@ -631,6 +687,26 @@ Continue your systematic investigation now.`;
|
||||
return guidance;
|
||||
}
|
||||
|
||||
/**
 * Get estimated duration for a tool execution (in milliseconds)
 *
 * Known tools map to a fixed estimate; any other name gets the default of
 * 1.5 seconds.
 *
 * @param toolName The name of the tool
 * @returns Estimated duration in milliseconds
 */
private getEstimatedDuration(toolName: string): number {
    const defaultDurationMs = 1500; // Default 1.5 seconds

    // Tool-specific duration estimates based on typical execution times.
    // A Map is used instead of a plain object so that names which collide
    // with Object.prototype members ("constructor", "toString", ...) fall
    // through to the default rather than returning an inherited function.
    const estimations = new Map<string, number>([
        ['search_notes', 2000],
        ['read_note', 1000],
        ['keyword_search', 1500],
        ['attribute_search', 1200],
        ['discover_tools', 500],
        ['note_by_path', 800],
        ['template_search', 1000]
    ]);

    return estimations.get(toolName) ?? defaultDurationMs;
}
|
||||
|
||||
/**
|
||||
* Determines if a tool result is effectively empty or unhelpful
|
||||
* @param result The result from the tool execution
|
||||
|
||||
@@ -0,0 +1,486 @@
|
||||
import { BasePipelineStage } from '../pipeline_stage.js';
|
||||
import type { ToolExecutionInput, StreamCallback } from '../interfaces.js';
|
||||
import type { ChatResponse, Message } from '../../ai_interface.js';
|
||||
import log from '../../../log.js';
|
||||
|
||||
/**
 * Settings that control when the user interaction stage asks for
 * confirmation before a tool call is executed.
 */
interface UserInteractionConfig {
    /** When false, the stage returns its input untouched and performs no interaction handling. */
    enableConfirmation: boolean;
    /** NOTE(review): not read by any code visible in this file — confirm whether it is consumed elsewhere. */
    enableCancellation: boolean;
    /** How long to wait for a user response before the request is treated as timed out. */
    confirmationTimeout: number; // milliseconds
    /** When true, tools assessed as low-risk are executed without asking the user. */
    autoConfirmLowRisk: boolean;
    /** Names of tools that always require confirmation. */
    requiredConfirmationTools: string[];
}
|
||||
|
||||
/**
 * Bookkeeping record for a confirmation request that has been issued but
 * not yet answered or timed out.
 */
interface PendingInteraction {
    /** Identifier the responder must echo back to resolve this request. */
    id: string;
    /** The tool call awaiting confirmation; the stage reads `toolCall.function.name` from it. */
    toolCall: any;
    /** `Date.now()` captured when the request was created; used to compute response duration. */
    timestamp: number;
    /** Timer that fires the timeout path; cleared when a response arrives first. */
    timeoutHandle?: NodeJS.Timeout;
    /** Set once the request has been settled, so response and timeout cannot both resolve it. */
    resolved: boolean;
}
|
||||
|
||||
// Outcome of a confirmation request: 'confirm' and 'cancel' come from the
// user, 'timeout' is produced when the confirmation timer fires first.
type InteractionResponse = 'confirm' | 'cancel' | 'timeout';
|
||||
|
||||
/**
 * Enhanced User Interaction Pipeline Stage
 *
 * Provides real-time confirmation/cancellation capabilities for tool execution.
 * For every tool call in the model response the stage decides, from the tool's
 * risk level and the stage configuration, whether to run it immediately or to
 * ask the user first. A confirmation request is emitted through the stream
 * callback and settled either by {@link handleUserResponse} or by a timeout.
 */
export class UserInteractionStage extends BasePipelineStage<ToolExecutionInput, { response: ChatResponse, needsFollowUp: boolean, messages: Message[], userInteractions?: any[] }> {

    private config: UserInteractionConfig;
    /** Confirmation requests that have been issued and not yet settled, keyed by interaction id. */
    private pendingInteractions: Map<string, PendingInteraction> = new Map();
    /** Settle functions for the pending requests, keyed by interaction id. */
    private interactionCallbacks: Map<string, (response: InteractionResponse) => void> = new Map();

    /**
     * @param config Optional overrides merged on top of the defaults below.
     */
    constructor(config?: Partial<UserInteractionConfig>) {
        super('UserInteraction');

        this.config = {
            enableConfirmation: true,
            enableCancellation: true,
            confirmationTimeout: 15000, // 15 seconds
            autoConfirmLowRisk: true,
            requiredConfirmationTools: ['attribute_search', 'read_note'],
            ...config
        };
    }

    /**
     * Process tool execution with user interaction capabilities
     *
     * @param input Stage input carrying the model response, the conversation
     *              so far and (optionally) a stream callback.
     * @returns The response annotated with interaction metadata, the
     *          conversation extended with one tool message per processed
     *          call, and the list of interactions that took place.
     */
    protected async process(input: ToolExecutionInput): Promise<{ response: ChatResponse, needsFollowUp: boolean, messages: Message[], userInteractions?: any[] }> {
        const { response } = input;
        const toolCalls = response.tool_calls;

        // If no tool calls or interactions disabled, pass through
        if (!toolCalls || toolCalls.length === 0 || !this.config.enableConfirmation) {
            return {
                response,
                needsFollowUp: false,
                messages: input.messages,
                userInteractions: []
            };
        }

        log.info(`========== USER INTERACTION STAGE PROCESSING ==========`);
        log.info(`Processing ${toolCalls.length} tool calls with user interaction controls`);

        const processedToolCalls: any[] = [];
        const userInteractions: any[] = [];
        const updatedMessages = [...input.messages];

        // Process each tool call with interaction controls (sequentially, so
        // the user is asked about one tool at a time)
        for (let i = 0; i < toolCalls.length; i++) {
            const interactionResult = await this.processToolCallWithInteraction(toolCalls[i], input, i);
            if (!interactionResult) {
                continue;
            }

            processedToolCalls.push(interactionResult.toolCall);
            updatedMessages.push(interactionResult.message);

            if (interactionResult.interaction) {
                userInteractions.push(interactionResult.interaction);
            }
        }

        const countResponses = (...accepted: string[]): number =>
            userInteractions.filter((entry: any) => accepted.includes(entry.response)).length;

        // Create enhanced response with interaction metadata
        const enhancedResponse: ChatResponse = {
            ...response,
            tool_calls: processedToolCalls,
            interaction_metadata: {
                total_interactions: userInteractions.length,
                confirmed: countResponses('confirm'),
                cancelled: countResponses('cancel'),
                // Timed-out interactions are recorded as 'timeout_executed'
                // once the tool has run, so both spellings must be counted.
                timedout: countResponses('timeout', 'timeout_executed')
            }
        };

        const needsFollowUp = processedToolCalls.length > 0;

        log.info(`User interaction complete: ${userInteractions.length} interactions processed`);

        return {
            response: enhancedResponse,
            needsFollowUp,
            messages: updatedMessages,
            userInteractions
        };
    }

    /**
     * Process a tool call with user interaction controls
     *
     * @param toolCall The tool call emitted by the model.
     * @param input The stage input (provides the stream callback).
     * @param index Position of the tool call within the response.
     * @returns The tool call, the resulting tool message and, when the user
     *          was asked, the interaction record.
     */
    private async processToolCallWithInteraction(
        toolCall: any,
        input: ToolExecutionInput,
        index: number
    ): Promise<{ toolCall: any, message: Message, interaction?: any } | null> {
        const toolName = toolCall.function.name;
        const riskLevel = this.assessToolRiskLevel(toolName);

        // Determine if confirmation is required
        if (!this.shouldRequireConfirmation(toolName, riskLevel)) {
            // Execute immediately for low-risk tools
            log.info(`Tool ${toolName} is low-risk, executing immediately`);
            return await this.executeToolDirectly(toolCall, input);
        }

        // Request user confirmation
        log.info(`Tool ${toolName} requires user confirmation (risk level: ${riskLevel})`);

        const interactionId = this.generateInteractionId();
        const interaction = await this.requestUserConfirmation(toolCall, interactionId, input.streamCallback);

        if (interaction.response === 'cancel') {
            log.info(`User cancelled execution of ${toolName}`);
            return {
                toolCall,
                message: {
                    role: 'tool',
                    content: `USER_CANCELLED: Execution of ${toolName} was cancelled by user request.`,
                    name: toolName,
                    tool_call_id: toolCall.id
                },
                interaction
            };
        }

        if (interaction.response === 'confirm') {
            log.info(`User confirmed execution of ${toolName}`);
            const result = await this.executeToolDirectly(toolCall, input);
            return { ...result, interaction };
        }

        // Timeout
        // NOTE(review): a tool that required confirmation is executed anyway
        // when nobody answers — confirm this default is intended.
        log.info(`User confirmation timeout for ${toolName}, executing with default action`);
        const result = await this.executeToolDirectly(toolCall, input);
        return {
            ...result,
            interaction: { ...interaction, response: 'timeout_executed' }
        };
    }

    /**
     * Assess the risk level of a tool
     *
     * @param toolName Name of the tool being called.
     * @returns The configured level, or `'medium'` for tools not listed.
     */
    private assessToolRiskLevel(toolName: string): 'low' | 'medium' | 'high' {
        // A Map is used so names colliding with Object.prototype members
        // ("constructor", "toString", ...) cannot resolve to inherited values.
        const riskLevels = new Map<string, 'low' | 'medium' | 'high'>([
            // Low risk - read-only operations
            ['search_notes', 'low'],
            ['keyword_search', 'low'],
            ['discover_tools', 'low'],
            ['template_search', 'low'],

            // Medium risk - specific data access
            ['read_note', 'medium'],
            ['note_by_path', 'medium'],

            // High risk - complex queries or potential data modification
            ['attribute_search', 'high']
        ]);

        return riskLevels.get(toolName) ?? 'medium';
    }

    /**
     * Determine if a tool requires user confirmation
     *
     * @param toolName Name of the tool being called.
     * @param riskLevel Risk level as returned by `assessToolRiskLevel`.
     * @returns `true` when the user must be asked before execution.
     */
    private shouldRequireConfirmation(toolName: string, riskLevel: string): boolean {
        // Always require confirmation for high-risk tools
        if (riskLevel === 'high') {
            return true;
        }

        // Check if tool is in the required confirmation list
        if (this.config.requiredConfirmationTools.includes(toolName)) {
            return true;
        }

        // Low-risk tools skip confirmation only while auto-confirm is
        // enabled; with it disabled they must ask like everything else.
        if (riskLevel === 'low') {
            return !this.config.autoConfirmLowRisk;
        }

        // Default to requiring confirmation for medium-risk tools
        return riskLevel === 'medium';
    }

    /**
     * Request user confirmation for tool execution
     *
     * Registers the pending interaction, emits the confirmation request via
     * the stream callback (if any) and waits until either a response arrives
     * through `handleUserResponse` or the configured timeout elapses.
     *
     * @returns An interaction record `{ id, toolName, response, timestamp, duration }`.
     */
    private async requestUserConfirmation(
        toolCall: any,
        interactionId: string,
        streamCallback?: StreamCallback
    ): Promise<any> {
        const toolName = toolCall.function.name;
        const args = this.parseToolArguments(toolCall.function.arguments);

        // Create pending interaction
        const pendingInteraction: PendingInteraction = {
            id: interactionId,
            toolCall,
            timestamp: Date.now(),
            resolved: false
        };

        this.pendingInteractions.set(interactionId, pendingInteraction);

        // Wait for user response or timeout
        return new Promise<any>((resolve) => {
            // Single settle path shared by the response callback and the
            // timeout, guarded so only the first one to fire takes effect.
            const settle = (response: InteractionResponse): void => {
                if (pendingInteraction.resolved) {
                    return;
                }
                pendingInteraction.resolved = true;

                clearTimeout(timeoutHandle);
                this.pendingInteractions.delete(interactionId);
                this.interactionCallbacks.delete(interactionId);

                resolve({
                    id: interactionId,
                    toolName,
                    response,
                    timestamp: Date.now(),
                    duration: Date.now() - pendingInteraction.timestamp
                });
            };

            // Set up timeout
            const timeoutHandle = setTimeout(() => settle('timeout'), this.config.confirmationTimeout);
            pendingInteraction.timeoutHandle = timeoutHandle;

            // Set up response callback BEFORE emitting the request so that a
            // response can never arrive while no callback is registered.
            this.interactionCallbacks.set(interactionId, settle);

            // Send confirmation request via streaming
            if (streamCallback) {
                this.sendConfirmationRequest(streamCallback, toolCall, interactionId, args);
            }
        });
    }

    /**
     * Send confirmation request via streaming
     *
     * Delivery is fire-and-forget: failures are logged, never thrown, so a
     * broken stream cannot leave the pending interaction unsettled (the
     * timeout still resolves it).
     */
    private sendConfirmationRequest(
        streamCallback: StreamCallback,
        toolCall: any,
        interactionId: string,
        args: Record<string, unknown>
    ): void {
        const toolName = toolCall.function.name;
        const riskLevel = this.assessToolRiskLevel(toolName);

        // Create user-friendly description of the tool action
        const actionDescription = this.createActionDescription(toolName, args);

        const confirmationData = {
            type: 'user_confirmation',
            action: 'request',
            interactionId,
            tool: {
                name: toolName,
                description: actionDescription,
                arguments: args,
                riskLevel
            },
            options: {
                confirm: {
                    label: 'Execute',
                    description: `Proceed with ${toolName}`,
                    style: riskLevel === 'high' ? 'warning' : 'primary'
                },
                cancel: {
                    label: 'Cancel',
                    description: 'Skip this tool execution',
                    style: 'secondary'
                }
            },
            timeout: this.config.confirmationTimeout,
            message: `Do you want to execute ${toolName}? ${actionDescription}`
        };

        try {
            const sendResult = streamCallback('', false, {
                text: '',
                done: false,
                userInteraction: confirmationData
            });

            // Don't wait for this to complete, but log any errors
            if (sendResult instanceof Promise) {
                sendResult.catch((e: Error) => log.error(`Error sending user confirmation request: ${e.message}`));
            }
        } catch (e) {
            const reason = e instanceof Error ? e.message : String(e);
            log.error(`Error sending user confirmation request: ${reason}`);
        }
    }

    /**
     * Create user-friendly action description
     *
     * @param toolName Name of the tool being called.
     * @param args Parsed tool arguments.
     * @returns A short sentence describing what the tool call will do.
     */
    private createActionDescription(toolName: string, args: Record<string, unknown>): string {
        switch (toolName) {
            case 'search_notes':
                return `Search your notes for: "${args.query || 'unknown query'}"`;

            case 'read_note':
                return `Read note with ID: ${args.noteId || 'unknown'}`;

            case 'keyword_search':
                return `Search for keyword: "${args.query || 'unknown query'}"`;

            case 'attribute_search':
                return `Search for ${args.attributeType || 'attribute'}: "${args.attributeName || 'unknown'}"`;

            case 'note_by_path':
                return `Find note at path: "${args.path || 'unknown path'}"`;

            case 'discover_tools':
                return `Discover available tools`;

            default:
                return `Execute ${toolName} with provided parameters`;
        }
    }

    /**
     * Execute tool directly without confirmation
     *
     * Never throws: lookup and execution failures are logged and returned as
     * a tool message whose content starts with `Error:`.
     */
    private async executeToolDirectly(
        toolCall: any,
        input: ToolExecutionInput
    ): Promise<{ toolCall: any, message: Message }> {
        const toolName = toolCall.function.name;

        let content: string;
        try {
            // Import and use tool registry
            const toolRegistry = (await import('../../tools/tool_registry.js')).default;
            const tool = toolRegistry.getTool(toolName);

            if (!tool) {
                throw new Error(`Tool not found: ${toolName}`);
            }

            const args = this.parseToolArguments(toolCall.function.arguments);
            const result = await tool.execute(args);

            content = typeof result === 'string' ? result : JSON.stringify(result, null, 2);
        } catch (error) {
            const errorMessage = error instanceof Error ? error.message : String(error);
            log.error(`Error executing tool ${toolName}: ${errorMessage}`);
            content = `Error: ${errorMessage}`;
        }

        return {
            toolCall,
            message: {
                role: 'tool',
                content,
                name: toolName,
                tool_call_id: toolCall.id
            }
        };
    }

    /**
     * Parse tool arguments safely
     *
     * @param args Either an already-parsed argument object or its JSON text.
     * @returns The argument object. Text that is not valid JSON, or that is
     *          valid JSON but not a plain object (`null`, a number, an
     *          array, ...), is wrapped as `{ query: <text> }`.
     */
    private parseToolArguments(args: string | Record<string, unknown>): Record<string, unknown> {
        if (args == null) {
            // Defensive: a tool call without arguments yields an empty record
            return {};
        }

        if (typeof args !== 'string') {
            return args;
        }

        try {
            const parsed: unknown = JSON.parse(args);
            if (typeof parsed === 'object' && parsed !== null && !Array.isArray(parsed)) {
                return parsed as Record<string, unknown>;
            }
        } catch {
            // Not JSON — fall through and treat the raw text as a query
        }

        return { query: args };
    }

    /**
     * Generate unique interaction ID
     */
    private generateInteractionId(): string {
        return `interaction_${Date.now()}_${Math.random().toString(36).substring(2, 9)}`;
    }

    /**
     * Handle user response to confirmation request
     * This method would be called by the frontend/WebSocket handler
     *
     * @param interactionId Id from the confirmation request.
     * @param response The user's decision.
     * @returns `true` if a pending interaction was found and settled,
     *          `false` if the id is unknown or already settled.
     */
    public handleUserResponse(interactionId: string, response: 'confirm' | 'cancel'): boolean {
        const callback = this.interactionCallbacks.get(interactionId);

        if (!callback) {
            log.error(`No callback found for interaction ${interactionId}`);
            return false;
        }

        log.info(`Received user response for interaction ${interactionId}: ${response}`);
        callback(response);
        return true;
    }

    /**
     * Cancel all pending interactions
     *
     * Every unresolved interaction is settled with `'cancel'` and its timer
     * is cleared.
     */
    public cancelAllPendingInteractions(): void {
        log.info(`Cancelling ${this.pendingInteractions.size} pending interactions`);

        // Iterate over a snapshot: settling an interaction removes it from
        // both maps.
        for (const [id, interaction] of Array.from(this.pendingInteractions.entries())) {
            if (interaction.timeoutHandle) {
                clearTimeout(interaction.timeoutHandle);
            }

            const callback = this.interactionCallbacks.get(id);
            if (callback && !interaction.resolved) {
                callback('cancel');
            }
        }

        this.pendingInteractions.clear();
        this.interactionCallbacks.clear();
    }

    /**
     * Get pending interactions (for status monitoring)
     */
    public getPendingInteractions(): Array<{ id: string, toolName: string, timestamp: number }> {
        return Array.from(this.pendingInteractions.values()).map(interaction => ({
            id: interaction.id,
            toolName: interaction.toolCall.function.name,
            timestamp: interaction.timestamp
        }));
    }

    /**
     * Update configuration
     *
     * @param newConfig Fields to override; unspecified fields keep their value.
     */
    public updateConfig(newConfig: Partial<UserInteractionConfig>): void {
        this.config = { ...this.config, ...newConfig };
        log.info(`User interaction configuration updated: ${JSON.stringify(newConfig)}`);
    }
}
|
||||
Reference in New Issue
Block a user