mirror of
https://github.com/zadam/trilium.git
synced 2026-05-13 00:55:55 +02:00
feat(llm): implement circuit breaker to prevent going haywire
This commit is contained in:
185
apps/server/src/routes/api/llm_metrics.ts
Normal file
185
apps/server/src/routes/api/llm_metrics.ts
Normal file
@@ -0,0 +1,185 @@
|
||||
/**
|
||||
* LLM Metrics API Endpoint
|
||||
*
|
||||
* Provides metrics export endpoints for monitoring systems
|
||||
*/
|
||||
|
||||
import { Router, Request, Response } from 'express';
|
||||
import { getProviderFactory } from '../../services/llm/providers/provider_factory.js';
|
||||
import log from '../../services/log.js';
|
||||
|
||||
const router = Router();
|
||||
|
||||
/**
|
||||
* GET /api/llm/metrics
|
||||
* Returns metrics in Prometheus format by default
|
||||
*/
|
||||
router.get('/llm/metrics', (req: Request, res: Response) => {
|
||||
try {
|
||||
const format = req.query.format as string || 'prometheus';
|
||||
const factory = getProviderFactory();
|
||||
|
||||
if (!factory) {
|
||||
return res.status(503).json({ error: 'LLM service not initialized' });
|
||||
}
|
||||
|
||||
const metrics = factory.exportMetrics(format as any);
|
||||
|
||||
if (!metrics) {
|
||||
return res.status(503).json({ error: 'Metrics not available' });
|
||||
}
|
||||
|
||||
// Set appropriate content type based on format
|
||||
switch (format) {
|
||||
case 'prometheus':
|
||||
res.set('Content-Type', 'text/plain; version=0.0.4');
|
||||
res.send(metrics);
|
||||
break;
|
||||
case 'json':
|
||||
res.json(metrics);
|
||||
break;
|
||||
case 'opentelemetry':
|
||||
res.json(metrics);
|
||||
break;
|
||||
case 'statsd':
|
||||
res.set('Content-Type', 'text/plain');
|
||||
res.send(Array.isArray(metrics) ? metrics.join('\n') : metrics);
|
||||
break;
|
||||
default:
|
||||
res.status(400).json({ error: `Unknown format: ${format}` });
|
||||
}
|
||||
} catch (error: any) {
|
||||
log.error(`[LLM Metrics API] Error exporting metrics: ${error.message}`);
|
||||
res.status(500).json({ error: 'Internal server error' });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* GET /api/llm/metrics/summary
|
||||
* Returns a summary of metrics in JSON format
|
||||
*/
|
||||
router.get('/llm/metrics/summary', (req: Request, res: Response) => {
|
||||
try {
|
||||
const factory = getProviderFactory();
|
||||
|
||||
if (!factory) {
|
||||
return res.status(503).json({ error: 'LLM service not initialized' });
|
||||
}
|
||||
|
||||
const summary = factory.getMetricsSummary();
|
||||
|
||||
if (!summary) {
|
||||
return res.status(503).json({ error: 'Metrics not available' });
|
||||
}
|
||||
|
||||
res.json(summary);
|
||||
} catch (error: any) {
|
||||
log.error(`[LLM Metrics API] Error getting metrics summary: ${error.message}`);
|
||||
res.status(500).json({ error: 'Internal server error' });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* GET /api/llm/circuit-breaker/status
|
||||
* Returns circuit breaker status for all providers
|
||||
*/
|
||||
router.get('/llm/circuit-breaker/status', (req: Request, res: Response) => {
|
||||
try {
|
||||
const factory = getProviderFactory();
|
||||
|
||||
if (!factory) {
|
||||
return res.status(503).json({ error: 'LLM service not initialized' });
|
||||
}
|
||||
|
||||
const status = factory.getCircuitBreakerStatus();
|
||||
|
||||
if (!status) {
|
||||
return res.status(503).json({ error: 'Circuit breaker not enabled' });
|
||||
}
|
||||
|
||||
res.json(status);
|
||||
} catch (error: any) {
|
||||
log.error(`[LLM Metrics API] Error getting circuit breaker status: ${error.message}`);
|
||||
res.status(500).json({ error: 'Internal server error' });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* POST /api/llm/circuit-breaker/reset/:provider
|
||||
* Reset circuit breaker for a specific provider
|
||||
*/
|
||||
router.post('/llm/circuit-breaker/reset/:provider', (req: Request, res: Response) => {
|
||||
try {
|
||||
const { provider } = req.params;
|
||||
const factory = getProviderFactory();
|
||||
|
||||
if (!factory) {
|
||||
return res.status(503).json({ error: 'LLM service not initialized' });
|
||||
}
|
||||
|
||||
factory.resetCircuitBreaker(provider as any);
|
||||
res.json({ message: `Circuit breaker reset for provider: ${provider}` });
|
||||
} catch (error: any) {
|
||||
log.error(`[LLM Metrics API] Error resetting circuit breaker: ${error.message}`);
|
||||
res.status(500).json({ error: 'Internal server error' });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* GET /api/llm/health
|
||||
* Returns overall health status of LLM service
|
||||
*/
|
||||
router.get('/llm/health', (req: Request, res: Response) => {
|
||||
try {
|
||||
const factory = getProviderFactory();
|
||||
|
||||
if (!factory) {
|
||||
return res.status(503).json({
|
||||
status: 'unhealthy',
|
||||
error: 'LLM service not initialized'
|
||||
});
|
||||
}
|
||||
|
||||
const circuitStatus = factory.getCircuitBreakerStatus();
|
||||
const metrics = factory.getMetricsSummary();
|
||||
const statistics = factory.getStatistics();
|
||||
|
||||
const health = {
|
||||
status: 'healthy',
|
||||
timestamp: new Date().toISOString(),
|
||||
providers: {
|
||||
available: circuitStatus?.summary?.availableProviders || [],
|
||||
unavailable: circuitStatus?.summary?.unavailableProviders || [],
|
||||
cached: statistics?.cachedProviders || 0,
|
||||
healthy: statistics?.healthyProviders || 0,
|
||||
unhealthy: statistics?.unhealthyProviders || 0
|
||||
},
|
||||
metrics: {
|
||||
totalRequests: metrics?.system?.totalRequests || 0,
|
||||
totalFailures: metrics?.system?.totalFailures || 0,
|
||||
uptime: metrics?.system?.uptime || 0
|
||||
},
|
||||
circuitBreakers: circuitStatus?.summary || {}
|
||||
};
|
||||
|
||||
// Determine overall health
|
||||
if (health.providers.available.length === 0) {
|
||||
health.status = 'unhealthy';
|
||||
} else if (health.providers.unavailable.length > 0) {
|
||||
health.status = 'degraded';
|
||||
}
|
||||
|
||||
const statusCode = health.status === 'healthy' ? 200 :
|
||||
health.status === 'degraded' ? 200 : 503;
|
||||
|
||||
res.status(statusCode).json(health);
|
||||
} catch (error: any) {
|
||||
log.error(`[LLM Metrics API] Error getting health status: ${error.message}`);
|
||||
res.status(500).json({
|
||||
status: 'unhealthy',
|
||||
error: 'Internal server error'
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
export default router;
|
||||
@@ -8,6 +8,7 @@ import contextService from './context/services/context_service.js';
|
||||
import log from '../log.js';
|
||||
import { OllamaService } from './providers/ollama_service.js';
|
||||
import { OpenAIService } from './providers/openai_service.js';
|
||||
import { ProviderFactory, ProviderType, getProviderFactory } from './providers/provider_factory.js';
|
||||
|
||||
// Import interfaces
|
||||
import type {
|
||||
@@ -26,7 +27,6 @@ import {
|
||||
clearConfigurationCache,
|
||||
validateConfiguration
|
||||
} from './config/configuration_helpers.js';
|
||||
import type { ProviderType } from './interfaces/configuration_interfaces.js';
|
||||
|
||||
/**
|
||||
* Interface representing relevant note context
|
||||
@@ -59,8 +59,19 @@ export class AIServiceManager implements IAIServiceManager, Disposable {
|
||||
private cleanupTimer: NodeJS.Timeout | null = null;
|
||||
private initialized = false;
|
||||
private disposed = false;
|
||||
private providerFactory: ProviderFactory | null = null;
|
||||
|
||||
constructor() {
|
||||
// Initialize provider factory
|
||||
this.providerFactory = getProviderFactory({
|
||||
enableHealthChecks: true,
|
||||
healthCheckInterval: 60000,
|
||||
enableFallback: true,
|
||||
enableCaching: true,
|
||||
cacheTimeout: this.SERVICE_TTL_MS,
|
||||
enableMetrics: true
|
||||
});
|
||||
|
||||
// Initialize tools immediately
|
||||
this.initializeTools().catch(error => {
|
||||
log.error(`Error initializing LLM tools during AIServiceManager construction: ${error.message || String(error)}`);
|
||||
@@ -456,7 +467,12 @@ export class AIServiceManager implements IAIServiceManager, Disposable {
|
||||
* Clear all cached providers (forces recreation on next access)
|
||||
*/
|
||||
public clearCurrentProvider(): void {
|
||||
// Clear all cached services
|
||||
// Clear provider factory cache
|
||||
if (this.providerFactory) {
|
||||
this.providerFactory.clearCache();
|
||||
}
|
||||
|
||||
// Clear local cache
|
||||
for (const provider of this.serviceCache.keys()) {
|
||||
this.disposeService(provider);
|
||||
}
|
||||
@@ -464,87 +480,66 @@ export class AIServiceManager implements IAIServiceManager, Disposable {
|
||||
}
|
||||
|
||||
/**
|
||||
* Get or create a provider instance with proper caching and TTL
|
||||
* Get or create a provider instance using the provider factory
|
||||
*/
|
||||
private async getOrCreateChatProvider(providerName: ServiceProviders): Promise<AIService | null> {
|
||||
if (this.disposed) {
|
||||
throw new Error('AIServiceManager has been disposed');
|
||||
}
|
||||
|
||||
// Check cache first
|
||||
const cached = this.serviceCache.get(providerName);
|
||||
if (cached && cached.service.isAvailable()) {
|
||||
// Update last used time
|
||||
cached.lastUsed = Date.now();
|
||||
|
||||
// Check if service is still within TTL
|
||||
if (Date.now() - cached.createdAt <= this.SERVICE_TTL_MS) {
|
||||
log.info(`Using cached ${providerName} service (age: ${Math.round((Date.now() - cached.createdAt) / 1000)}s)`);
|
||||
return cached.service;
|
||||
} else {
|
||||
// Service is stale, dispose and recreate
|
||||
log.info(`Cached ${providerName} service is stale, recreating`);
|
||||
this.disposeService(providerName);
|
||||
}
|
||||
if (!this.providerFactory) {
|
||||
throw new Error('Provider factory not initialized');
|
||||
}
|
||||
|
||||
// Create new service for the requested provider
|
||||
try {
|
||||
let service: AIService | null = null;
|
||||
// Map ServiceProviders to ProviderType
|
||||
const providerTypeMap: Record<ServiceProviders, ProviderType> = {
|
||||
'openai': ProviderType.OPENAI,
|
||||
'anthropic': ProviderType.ANTHROPIC,
|
||||
'ollama': ProviderType.OLLAMA
|
||||
};
|
||||
|
||||
const providerType = providerTypeMap[providerName];
|
||||
if (!providerType) {
|
||||
log.error(`Unknown provider name: ${providerName}`);
|
||||
return null;
|
||||
}
|
||||
|
||||
// Check if provider is configured
|
||||
switch (providerName) {
|
||||
case 'openai': {
|
||||
const apiKey = options.getOption('openaiApiKey');
|
||||
const baseUrl = options.getOption('openaiBaseUrl');
|
||||
if (!apiKey && !baseUrl) return null;
|
||||
|
||||
service = new OpenAIService();
|
||||
if (!service.isAvailable()) {
|
||||
throw new Error('OpenAI service not available');
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case 'anthropic': {
|
||||
const apiKey = options.getOption('anthropicApiKey');
|
||||
if (!apiKey) return null;
|
||||
|
||||
service = new AnthropicService();
|
||||
if (!service.isAvailable()) {
|
||||
throw new Error('Anthropic service not available');
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case 'ollama': {
|
||||
const baseUrl = options.getOption('ollamaBaseUrl');
|
||||
if (!baseUrl) return null;
|
||||
|
||||
service = new OllamaService();
|
||||
if (!service.isAvailable()) {
|
||||
throw new Error('Ollama service not available');
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (service) {
|
||||
// Cache the new service with metadata
|
||||
const now = Date.now();
|
||||
this.serviceCache.set(providerName, {
|
||||
service,
|
||||
provider: providerName,
|
||||
createdAt: now,
|
||||
lastUsed: now
|
||||
});
|
||||
log.info(`Created and cached new ${providerName} service`);
|
||||
// Use provider factory to create the service
|
||||
const service = await this.providerFactory.createProvider(providerType);
|
||||
|
||||
if (service && service.isAvailable()) {
|
||||
log.info(`Created ${providerName} service via provider factory`);
|
||||
return service;
|
||||
}
|
||||
|
||||
throw new Error(`${providerName} service not available`);
|
||||
} catch (error: any) {
|
||||
log.error(`Failed to create ${providerName} chat provider: ${error.message || 'Unknown error'}`);
|
||||
|
||||
// Provider factory handles fallback internally if configured
|
||||
return null;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -559,6 +554,12 @@ export class AIServiceManager implements IAIServiceManager, Disposable {
|
||||
// Stop cleanup timer
|
||||
this.stopCleanupTimer();
|
||||
|
||||
// Dispose provider factory
|
||||
if (this.providerFactory) {
|
||||
this.providerFactory.dispose();
|
||||
this.providerFactory = null;
|
||||
}
|
||||
|
||||
// Dispose all cached services
|
||||
for (const provider of this.serviceCache.keys()) {
|
||||
this.disposeService(provider);
|
||||
@@ -766,13 +767,24 @@ export class AIServiceManager implements IAIServiceManager, Disposable {
|
||||
* Check if a specific provider is available
|
||||
*/
|
||||
isProviderAvailable(provider: string): boolean {
|
||||
// Check if we have a cached service for this provider
|
||||
const cachedEntry = this.serviceCache.get(provider as ServiceProviders);
|
||||
if (cachedEntry && !this.isServiceStale(cachedEntry)) {
|
||||
return cachedEntry.service.isAvailable();
|
||||
// Check health status from provider factory
|
||||
if (this.providerFactory) {
|
||||
const providerTypeMap: Record<string, ProviderType> = {
|
||||
'openai': ProviderType.OPENAI,
|
||||
'anthropic': ProviderType.ANTHROPIC,
|
||||
'ollama': ProviderType.OLLAMA
|
||||
};
|
||||
|
||||
const providerType = providerTypeMap[provider];
|
||||
if (providerType) {
|
||||
const healthStatus = this.providerFactory.getHealthStatus(providerType);
|
||||
if (healthStatus) {
|
||||
return healthStatus.healthy;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For other providers, check configuration
|
||||
// Fallback to configuration check
|
||||
try {
|
||||
switch (provider) {
|
||||
case 'openai':
|
||||
@@ -793,22 +805,43 @@ export class AIServiceManager implements IAIServiceManager, Disposable {
|
||||
* Get metadata about a provider
|
||||
*/
|
||||
getProviderMetadata(provider: string): ProviderMetadata | null {
|
||||
// Check if we have a cached service for this provider
|
||||
const cachedEntry = this.serviceCache.get(provider as ServiceProviders);
|
||||
if (cachedEntry && !this.isServiceStale(cachedEntry)) {
|
||||
return {
|
||||
name: provider,
|
||||
capabilities: {
|
||||
chat: true,
|
||||
streaming: true,
|
||||
functionCalling: provider === 'openai' // Only OpenAI has function calling
|
||||
},
|
||||
models: ['default'], // Placeholder, could be populated from the service
|
||||
defaultModel: 'default'
|
||||
// Get capabilities from provider factory
|
||||
if (this.providerFactory) {
|
||||
const providerTypeMap: Record<string, ProviderType> = {
|
||||
'openai': ProviderType.OPENAI,
|
||||
'anthropic': ProviderType.ANTHROPIC,
|
||||
'ollama': ProviderType.OLLAMA
|
||||
};
|
||||
|
||||
const providerType = providerTypeMap[provider];
|
||||
if (providerType) {
|
||||
const capabilities = this.providerFactory.getCapabilities(providerType);
|
||||
if (capabilities) {
|
||||
return {
|
||||
name: provider,
|
||||
capabilities: {
|
||||
chat: true,
|
||||
streaming: capabilities.streaming,
|
||||
functionCalling: capabilities.functionCalling
|
||||
},
|
||||
models: ['default'], // Could be enhanced to get actual models
|
||||
defaultModel: 'default'
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
// Fallback
|
||||
return {
|
||||
name: provider,
|
||||
capabilities: {
|
||||
chat: true,
|
||||
streaming: true,
|
||||
functionCalling: provider === 'openai'
|
||||
},
|
||||
models: ['default'],
|
||||
defaultModel: 'default'
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
|
||||
290
apps/server/src/services/llm/config/llm_options.ts
Normal file
290
apps/server/src/services/llm/config/llm_options.ts
Normal file
@@ -0,0 +1,290 @@
|
||||
/**
|
||||
* LLM Service Configuration Options
|
||||
*
|
||||
* Defines all configurable options for the LLM service that can be
|
||||
* managed through Trilium's options system.
|
||||
*/
|
||||
|
||||
import optionService from '../../options.js';
|
||||
import type { OptionNames, FilterOptionsByType } from '@triliumnext/commons';
|
||||
import { ExportFormat } from '../metrics/metrics_exporter.js';
|
||||
|
||||
/**
 * LLM configuration options
 *
 * Aggregates every tunable setting for the LLM service. All duration-like
 * fields are in milliseconds (see the commented defaults in DEFAULT_OPTIONS).
 */
export interface LLMOptions {
    // Circuit Breaker Configuration
    /** Whether the per-provider circuit breaker is active. */
    circuitBreakerEnabled: boolean;
    /** Number of failures before the breaker opens. */
    circuitBreakerFailureThreshold: number;
    /** Rolling window (ms) within which failures are counted. */
    circuitBreakerFailureWindow: number;
    /** How long (ms) the breaker stays open before retrying. */
    circuitBreakerCooldownPeriod: number;
    /** Consecutive successes required to close the breaker again. */
    circuitBreakerSuccessThreshold: number;

    // Metrics Configuration
    /** Master switch for metrics collection/export. */
    metricsEnabled: boolean;
    /** Export format — see ExportFormat in metrics_exporter. */
    metricsExportFormat: ExportFormat;
    /** Optional push endpoint for exported metrics. */
    metricsExportEndpoint?: string;
    /** Export interval in ms. */
    metricsExportInterval: number;
    /** Whether the Prometheus exposition endpoint is enabled. */
    metricsPrometheusEnabled: boolean;
    /** StatsD target host (only relevant for the statsd format). */
    metricsStatsdHost?: string;
    /** StatsD target port (only relevant for the statsd format). */
    metricsStatsdPort?: number;
    /** Prefix prepended to StatsD metric names. */
    metricsStatsdPrefix: string;

    // Provider Configuration
    /** Whether periodic provider health checks run. */
    providerHealthCheckEnabled: boolean;
    /** Health check interval in ms. */
    providerHealthCheckInterval: number;
    /** Whether created provider services are cached. */
    providerCachingEnabled: boolean;
    /** Cache TTL in ms. */
    providerCacheTimeout: number;
    /** Whether fallback to alternate providers is enabled. */
    providerFallbackEnabled: boolean;
    /** Ordered provider names to fall back to (stored as CSV in options). */
    providerFallbackList: string[];
}
|
||||
|
||||
/**
 * Default LLM options
 *
 * Used by getLLMOptions() when an option is unset or unreadable, and by
 * initializeLLMOptions() to seed missing options in the option store.
 * Optional fields (endpoint, statsd host/port) intentionally have no default.
 */
const DEFAULT_OPTIONS: LLMOptions = {
    // Circuit Breaker Defaults
    circuitBreakerEnabled: true,
    circuitBreakerFailureThreshold: 5,
    circuitBreakerFailureWindow: 60000, // 1 minute
    circuitBreakerCooldownPeriod: 30000, // 30 seconds
    circuitBreakerSuccessThreshold: 2,

    // Metrics Defaults
    metricsEnabled: true,
    metricsExportFormat: 'prometheus' as ExportFormat,
    metricsExportInterval: 60000, // 1 minute
    metricsPrometheusEnabled: true,
    metricsStatsdPrefix: 'trilium.llm',

    // Provider Defaults
    providerHealthCheckEnabled: true,
    providerHealthCheckInterval: 60000, // 1 minute
    providerCachingEnabled: true,
    providerCacheTimeout: 300000, // 5 minutes
    providerFallbackEnabled: true,
    providerFallbackList: ['ollama']
};
|
||||
|
||||
/**
 * Option keys in Trilium's option system
 *
 * Maps CONSTANT_CASE names to the camelCase-prefixed keys stored in the
 * option service. The CONSTANT_CASE name corresponds 1:1 to a property of
 * DEFAULT_OPTIONS (e.g. CIRCUIT_BREAKER_ENABLED -> circuitBreakerEnabled),
 * which initializeLLMOptions relies on to resolve default values.
 */
export const LLM_OPTION_KEYS = {
    // Circuit Breaker
    CIRCUIT_BREAKER_ENABLED: 'llmCircuitBreakerEnabled' as const,
    CIRCUIT_BREAKER_FAILURE_THRESHOLD: 'llmCircuitBreakerFailureThreshold' as const,
    CIRCUIT_BREAKER_FAILURE_WINDOW: 'llmCircuitBreakerFailureWindow' as const,
    CIRCUIT_BREAKER_COOLDOWN_PERIOD: 'llmCircuitBreakerCooldownPeriod' as const,
    CIRCUIT_BREAKER_SUCCESS_THRESHOLD: 'llmCircuitBreakerSuccessThreshold' as const,

    // Metrics
    METRICS_ENABLED: 'llmMetricsEnabled' as const,
    METRICS_EXPORT_FORMAT: 'llmMetricsExportFormat' as const,
    METRICS_EXPORT_ENDPOINT: 'llmMetricsExportEndpoint' as const,
    METRICS_EXPORT_INTERVAL: 'llmMetricsExportInterval' as const,
    METRICS_PROMETHEUS_ENABLED: 'llmMetricsPrometheusEnabled' as const,
    METRICS_STATSD_HOST: 'llmMetricsStatsdHost' as const,
    METRICS_STATSD_PORT: 'llmMetricsStatsdPort' as const,
    METRICS_STATSD_PREFIX: 'llmMetricsStatsdPrefix' as const,

    // Provider
    PROVIDER_HEALTH_CHECK_ENABLED: 'llmProviderHealthCheckEnabled' as const,
    PROVIDER_HEALTH_CHECK_INTERVAL: 'llmProviderHealthCheckInterval' as const,
    PROVIDER_CACHING_ENABLED: 'llmProviderCachingEnabled' as const,
    PROVIDER_CACHE_TIMEOUT: 'llmProviderCacheTimeout' as const,
    PROVIDER_FALLBACK_ENABLED: 'llmProviderFallbackEnabled' as const,
    PROVIDER_FALLBACK_LIST: 'llmProviderFallbackList' as const
} as const;
|
||||
|
||||
/**
 * Get LLM options from Trilium's option service
 *
 * Reads every LLM option, falling back to DEFAULT_OPTIONS when a getter
 * returns null/undefined or throws (e.g. option not yet created).
 * The fallback provider list is stored as a comma-separated string and
 * parsed into a trimmed, non-empty string array here.
 */
export function getLLMOptions(): LLMOptions {
    // Helper function to safely get option with fallback: any thrown error
    // or nullish result from the getter yields the supplied default.
    function getOptionSafe<T>(getter: () => T, defaultValue: T): T {
        try {
            return getter() ?? defaultValue;
        } catch {
            return defaultValue;
        }
    }

    return {
        // Circuit Breaker
        circuitBreakerEnabled: getOptionSafe(
            () => optionService.getOptionBool(LLM_OPTION_KEYS.CIRCUIT_BREAKER_ENABLED),
            DEFAULT_OPTIONS.circuitBreakerEnabled
        ),
        circuitBreakerFailureThreshold: getOptionSafe(
            () => optionService.getOptionInt(LLM_OPTION_KEYS.CIRCUIT_BREAKER_FAILURE_THRESHOLD),
            DEFAULT_OPTIONS.circuitBreakerFailureThreshold
        ),
        circuitBreakerFailureWindow: getOptionSafe(
            () => optionService.getOptionInt(LLM_OPTION_KEYS.CIRCUIT_BREAKER_FAILURE_WINDOW),
            DEFAULT_OPTIONS.circuitBreakerFailureWindow
        ),
        circuitBreakerCooldownPeriod: getOptionSafe(
            () => optionService.getOptionInt(LLM_OPTION_KEYS.CIRCUIT_BREAKER_COOLDOWN_PERIOD),
            DEFAULT_OPTIONS.circuitBreakerCooldownPeriod
        ),
        circuitBreakerSuccessThreshold: getOptionSafe(
            () => optionService.getOptionInt(LLM_OPTION_KEYS.CIRCUIT_BREAKER_SUCCESS_THRESHOLD),
            DEFAULT_OPTIONS.circuitBreakerSuccessThreshold
        ),

        // Metrics
        metricsEnabled: getOptionSafe(
            () => optionService.getOptionBool(LLM_OPTION_KEYS.METRICS_ENABLED),
            DEFAULT_OPTIONS.metricsEnabled
        ),
        metricsExportFormat: getOptionSafe(
            () => optionService.getOption(LLM_OPTION_KEYS.METRICS_EXPORT_FORMAT) as ExportFormat,
            DEFAULT_OPTIONS.metricsExportFormat
        ),
        // Optional settings default to undefined rather than a value.
        metricsExportEndpoint: getOptionSafe(
            () => optionService.getOption(LLM_OPTION_KEYS.METRICS_EXPORT_ENDPOINT),
            undefined
        ),
        metricsExportInterval: getOptionSafe(
            () => optionService.getOptionInt(LLM_OPTION_KEYS.METRICS_EXPORT_INTERVAL),
            DEFAULT_OPTIONS.metricsExportInterval
        ),
        metricsPrometheusEnabled: getOptionSafe(
            () => optionService.getOptionBool(LLM_OPTION_KEYS.METRICS_PROMETHEUS_ENABLED),
            DEFAULT_OPTIONS.metricsPrometheusEnabled
        ),
        metricsStatsdHost: getOptionSafe(
            () => optionService.getOption(LLM_OPTION_KEYS.METRICS_STATSD_HOST),
            undefined
        ),
        metricsStatsdPort: getOptionSafe(
            () => optionService.getOptionInt(LLM_OPTION_KEYS.METRICS_STATSD_PORT),
            undefined
        ),
        metricsStatsdPrefix: getOptionSafe(
            () => optionService.getOption(LLM_OPTION_KEYS.METRICS_STATSD_PREFIX),
            DEFAULT_OPTIONS.metricsStatsdPrefix
        ),

        // Provider
        providerHealthCheckEnabled: getOptionSafe(
            () => optionService.getOptionBool(LLM_OPTION_KEYS.PROVIDER_HEALTH_CHECK_ENABLED),
            DEFAULT_OPTIONS.providerHealthCheckEnabled
        ),
        providerHealthCheckInterval: getOptionSafe(
            () => optionService.getOptionInt(LLM_OPTION_KEYS.PROVIDER_HEALTH_CHECK_INTERVAL),
            DEFAULT_OPTIONS.providerHealthCheckInterval
        ),
        providerCachingEnabled: getOptionSafe(
            () => optionService.getOptionBool(LLM_OPTION_KEYS.PROVIDER_CACHING_ENABLED),
            DEFAULT_OPTIONS.providerCachingEnabled
        ),
        providerCacheTimeout: getOptionSafe(
            () => optionService.getOptionInt(LLM_OPTION_KEYS.PROVIDER_CACHE_TIMEOUT),
            DEFAULT_OPTIONS.providerCacheTimeout
        ),
        providerFallbackEnabled: getOptionSafe(
            () => optionService.getOptionBool(LLM_OPTION_KEYS.PROVIDER_FALLBACK_ENABLED),
            DEFAULT_OPTIONS.providerFallbackEnabled
        ),
        // Stored as CSV; split, trim, and drop empty entries. A thrown error
        // (e.g. option missing, .split on undefined) falls back to the default.
        providerFallbackList: getOptionSafe(
            () => optionService.getOption(LLM_OPTION_KEYS.PROVIDER_FALLBACK_LIST).split(',').map((s: string) => s.trim()).filter(Boolean),
            DEFAULT_OPTIONS.providerFallbackList
        )
    };
}
|
||||
|
||||
/**
 * Set an LLM option
 *
 * Thin async wrapper over optionService.setOption.
 *
 * @param key - one of the keys in LLM_OPTION_KEYS
 * @param value - the value to store; arrays should be serialized by the
 *   caller (initializeLLMOptions joins them with commas)
 */
export async function setLLMOption(key: OptionNames, value: any): Promise<void> {
    await optionService.setOption(key, value);
}
|
||||
|
||||
/**
|
||||
* Initialize LLM options with defaults if not set
|
||||
*/
|
||||
export async function initializeLLMOptions(): Promise<void> {
|
||||
// Set defaults for any unset options
|
||||
const keysToCheck = Object.values(LLM_OPTION_KEYS) as OptionNames[];
|
||||
|
||||
for (const key of keysToCheck) {
|
||||
try {
|
||||
const currentValue = optionService.getOption(key);
|
||||
|
||||
if (currentValue === null || currentValue === undefined) {
|
||||
// Set default based on key
|
||||
const defaultKey = Object.entries(LLM_OPTION_KEYS)
|
||||
.find(([_, v]) => v === key)?.[0];
|
||||
|
||||
if (defaultKey) {
|
||||
const defaultPath = defaultKey
|
||||
.replace(/_([a-z])/g, (_, char) => char.toUpperCase())
|
||||
.replace(/^[A-Z]/, char => char.toLowerCase())
|
||||
.replace(/_/g, '');
|
||||
|
||||
const defaultValue = (DEFAULT_OPTIONS as any)[defaultPath];
|
||||
|
||||
if (defaultValue !== undefined) {
|
||||
await setLLMOption(key,
|
||||
Array.isArray(defaultValue) ? defaultValue.join(',') : defaultValue
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Option doesn't exist yet, create it with default
|
||||
const defaultKey = Object.entries(LLM_OPTION_KEYS)
|
||||
.find(([_, v]) => v === key)?.[0];
|
||||
|
||||
if (defaultKey) {
|
||||
const defaultPath = defaultKey
|
||||
.replace(/_([a-z])/g, (_, char) => char.toUpperCase())
|
||||
.replace(/^[A-Z]/, char => char.toLowerCase())
|
||||
.replace(/_/g, '');
|
||||
|
||||
const defaultValue = (DEFAULT_OPTIONS as any)[defaultPath];
|
||||
|
||||
if (defaultValue !== undefined) {
|
||||
await setLLMOption(key,
|
||||
Array.isArray(defaultValue) ? defaultValue.join(',') : defaultValue
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create provider factory options from LLM options
|
||||
*/
|
||||
export function createProviderFactoryOptions() {
|
||||
const options = getLLMOptions();
|
||||
|
||||
return {
|
||||
enableHealthChecks: options.providerHealthCheckEnabled,
|
||||
healthCheckInterval: options.providerHealthCheckInterval,
|
||||
enableFallback: options.providerFallbackEnabled,
|
||||
fallbackProviders: options.providerFallbackList as any[],
|
||||
enableCaching: options.providerCachingEnabled,
|
||||
cacheTimeout: options.providerCacheTimeout,
|
||||
enableMetrics: options.metricsEnabled,
|
||||
enableCircuitBreaker: options.circuitBreakerEnabled,
|
||||
circuitBreakerConfig: {
|
||||
failureThreshold: options.circuitBreakerFailureThreshold,
|
||||
failureWindow: options.circuitBreakerFailureWindow,
|
||||
cooldownPeriod: options.circuitBreakerCooldownPeriod,
|
||||
successThreshold: options.circuitBreakerSuccessThreshold,
|
||||
enableLogging: true
|
||||
},
|
||||
metricsExporterConfig: {
|
||||
enabled: options.metricsEnabled,
|
||||
format: options.metricsExportFormat,
|
||||
endpoint: options.metricsExportEndpoint,
|
||||
interval: options.metricsExportInterval,
|
||||
statsdHost: options.metricsStatsdHost,
|
||||
statsdPort: options.metricsStatsdPort,
|
||||
prefix: options.metricsStatsdPrefix
|
||||
}
|
||||
};
|
||||
}
|
||||
796
apps/server/src/services/llm/metrics/metrics_exporter.ts
Normal file
796
apps/server/src/services/llm/metrics/metrics_exporter.ts
Normal file
@@ -0,0 +1,796 @@
|
||||
/**
|
||||
* Metrics Export System for LLM Service
|
||||
*
|
||||
* Provides unified metrics collection and export to various monitoring systems:
|
||||
* - Prometheus format endpoint
|
||||
* - StatsD/DataDog format
|
||||
* - OpenTelemetry format
|
||||
*/
|
||||
|
||||
import log from '../../log.js';
|
||||
import { EventEmitter } from 'events';
|
||||
import type { ProviderType } from '../providers/provider_factory.js';
|
||||
|
||||
/**
 * Metric types
 *
 * String enum of the metric kinds the collector records; used as the
 * MetricDataPoint.type discriminator.
 */
export enum MetricType {
    /** Monotonically increasing count (e.g. requests, errors, tokens). */
    COUNTER = 'counter',
    /** Point-in-time value that may rise or fall. */
    GAUGE = 'gauge',
    /** Distribution of observed values (e.g. request latency). */
    HISTOGRAM = 'histogram',
    /** Distribution reported as summary statistics. */
    SUMMARY = 'summary'
}
|
||||
|
||||
/**
 * Metric data point
 *
 * A single recorded observation, as passed to MetricsCollector.record().
 */
export interface MetricDataPoint {
    /** Metric name, e.g. 'llm_request_latency' or 'llm_requests'. */
    name: string;
    /** Kind of metric this observation belongs to. */
    type: MetricType;
    /** Observed value (count increment, latency in the given unit, etc.). */
    value: number;
    /** When the observation was taken. */
    timestamp: Date;
    /** Dimension labels; a 'provider' label routes into provider metrics. */
    labels: Record<string, string>;
    /** Optional unit string, e.g. 'ms' or 'tokens'. */
    unit?: string;
    /** Optional human-readable description of the metric. */
    description?: string;
}
|
||||
|
||||
/**
 * Provider metrics
 *
 * Aggregated per-provider statistics maintained by MetricsCollector.
 */
export interface ProviderMetrics {
    /** Provider name these metrics belong to. */
    provider: string;
    /** Total requests recorded. */
    requests: number;
    /** Total failed requests recorded. */
    failures: number;
    /** Fraction of successful requests. */
    successRate: number;
    /** Mean request latency in ms (recordLatency records unit 'ms'). */
    averageLatency: number;
    /** Median latency (ms). */
    p50Latency: number;
    /** 95th percentile latency (ms). */
    p95Latency: number;
    /** 99th percentile latency (ms). */
    p99Latency: number;
    /** Input + output tokens recorded. */
    totalTokens: number;
    /** Prompt tokens recorded. */
    inputTokens: number;
    /** Completion tokens recorded. */
    outputTokens: number;
    /** Mean tokens per request. */
    averageTokensPerRequest: number;
    /** failures / requests (see recordError). */
    errorRate: number;
    /** Most recent error message, if any failure was recorded. */
    lastError?: string;
    /** Timestamp of the last update to this record. */
    lastUpdated: Date;
}
|
||||
|
||||
/**
 * System metrics
 *
 * Service-wide aggregates across all providers.
 */
export interface SystemMetrics {
    /** Total requests across all providers. */
    totalRequests: number;
    /** Total failed requests across all providers. */
    totalFailures: number;
    /** Mean latency across all providers. */
    averageLatency: number;
    /** Number of currently active pipelines. */
    activePipelines: number;
    /** Cache hit ratio. */
    cacheHitRate: number;
    /** Memory usage — units not established here; TODO confirm (bytes?). */
    memoryUsage: number;
    /** Service uptime since collector start (see startTime in collector). */
    uptime: number;
    /** When this snapshot was taken. */
    timestamp: Date;
}
|
||||
|
||||
/**
 * Export format types
 *
 * String enum of supported metrics export targets; values match the
 * 'format' query parameter accepted by the /llm/metrics endpoint.
 */
export enum ExportFormat {
    /** Prometheus text exposition format. */
    PROMETHEUS = 'prometheus',
    /** StatsD / DataDog line protocol. */
    STATSD = 'statsd',
    /** OpenTelemetry-shaped JSON. */
    OPENTELEMETRY = 'opentelemetry',
    /** Plain JSON dump. */
    JSON = 'json'
}
|
||||
|
||||
/**
 * Exporter configuration
 *
 * Settings controlling how collected metrics are exported.
 */
export interface ExporterConfig {
    /** Master switch for the exporter. */
    enabled: boolean;
    /** Target export format. */
    format: ExportFormat;
    /** Optional push endpoint for exported metrics. */
    endpoint?: string;
    /** Export interval — presumably ms, matching llm_options defaults; confirm. */
    interval?: number;
    /** Prefix prepended to exported metric names. */
    prefix?: string;
    /** Static labels attached to every exported metric. */
    labels?: Record<string, string>;
    /** Whether histogram metrics are included in the export. */
    includeHistograms?: boolean;
    /** Custom histogram bucket boundaries. */
    histogramBuckets?: number[];
    /** StatsD target host (statsd format only). */
    statsdHost?: string;
    /** StatsD target port (statsd format only). */
    statsdPort?: number;
    /** StatsD-specific metric name prefix. */
    statsdPrefix?: string;
}
|
||||
|
||||
/**
|
||||
* Base metrics collector
|
||||
*/
|
||||
export class MetricsCollector extends EventEmitter {
|
||||
private metrics: Map<string, MetricDataPoint[]> = new Map();
|
||||
private providerMetrics: Map<string, ProviderMetrics> = new Map();
|
||||
private systemMetrics: SystemMetrics;
|
||||
private startTime: Date;
|
||||
private latencyHistogram: Map<string, number[]> = new Map();
|
||||
private readonly maxDataPoints = 10000;
|
||||
private readonly maxHistogramSize = 1000;
|
||||
|
||||
constructor() {
|
||||
super();
|
||||
this.startTime = new Date();
|
||||
this.systemMetrics = this.createDefaultSystemMetrics();
|
||||
}
|
||||
|
||||
/**
|
||||
* Record a metric
|
||||
*/
|
||||
public record(metric: MetricDataPoint): void {
|
||||
const key = this.getMetricKey(metric);
|
||||
|
||||
if (!this.metrics.has(key)) {
|
||||
this.metrics.set(key, []);
|
||||
}
|
||||
|
||||
const dataPoints = this.metrics.get(key)!;
|
||||
dataPoints.push(metric);
|
||||
|
||||
// Limit stored data points
|
||||
if (dataPoints.length > this.maxDataPoints) {
|
||||
dataPoints.shift();
|
||||
}
|
||||
|
||||
// Update provider metrics if applicable
|
||||
if (metric.labels.provider) {
|
||||
this.updateProviderMetrics(metric);
|
||||
}
|
||||
|
||||
// Update system metrics
|
||||
this.updateSystemMetrics(metric);
|
||||
|
||||
// Emit metric event
|
||||
this.emit('metric', metric);
|
||||
}
|
||||
|
||||
/**
|
||||
* Record latency
|
||||
*/
|
||||
public recordLatency(provider: string, latency: number): void {
|
||||
this.record({
|
||||
name: 'llm_request_latency',
|
||||
type: MetricType.HISTOGRAM,
|
||||
value: latency,
|
||||
timestamp: new Date(),
|
||||
labels: { provider },
|
||||
unit: 'ms',
|
||||
description: 'LLM request latency'
|
||||
});
|
||||
|
||||
// Update latency histogram
|
||||
if (!this.latencyHistogram.has(provider)) {
|
||||
this.latencyHistogram.set(provider, []);
|
||||
}
|
||||
|
||||
const histogram = this.latencyHistogram.get(provider)!;
|
||||
histogram.push(latency);
|
||||
|
||||
if (histogram.length > this.maxHistogramSize) {
|
||||
histogram.shift();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Record token usage
|
||||
*/
|
||||
public recordTokenUsage(
|
||||
provider: string,
|
||||
inputTokens: number,
|
||||
outputTokens: number
|
||||
): void {
|
||||
this.record({
|
||||
name: 'llm_tokens_used',
|
||||
type: MetricType.COUNTER,
|
||||
value: inputTokens + outputTokens,
|
||||
timestamp: new Date(),
|
||||
labels: { provider, type: 'total' },
|
||||
unit: 'tokens',
|
||||
description: 'Total tokens used'
|
||||
});
|
||||
|
||||
this.record({
|
||||
name: 'llm_input_tokens',
|
||||
type: MetricType.COUNTER,
|
||||
value: inputTokens,
|
||||
timestamp: new Date(),
|
||||
labels: { provider },
|
||||
unit: 'tokens',
|
||||
description: 'Input tokens used'
|
||||
});
|
||||
|
||||
this.record({
|
||||
name: 'llm_output_tokens',
|
||||
type: MetricType.COUNTER,
|
||||
value: outputTokens,
|
||||
timestamp: new Date(),
|
||||
labels: { provider },
|
||||
unit: 'tokens',
|
||||
description: 'Output tokens generated'
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Record error
|
||||
*/
|
||||
public recordError(provider: string, error: string): void {
|
||||
this.record({
|
||||
name: 'llm_errors',
|
||||
type: MetricType.COUNTER,
|
||||
value: 1,
|
||||
timestamp: new Date(),
|
||||
labels: { provider, error_type: this.classifyError(error) },
|
||||
description: 'LLM request errors'
|
||||
});
|
||||
|
||||
// Update provider metrics
|
||||
const metrics = this.getProviderMetrics(provider);
|
||||
metrics.failures++;
|
||||
metrics.lastError = error;
|
||||
metrics.errorRate = metrics.failures / metrics.requests;
|
||||
}
|
||||
|
||||
/**
|
||||
* Record request
|
||||
*/
|
||||
public recordRequest(provider: string, success: boolean): void {
|
||||
this.record({
|
||||
name: 'llm_requests',
|
||||
type: MetricType.COUNTER,
|
||||
value: 1,
|
||||
timestamp: new Date(),
|
||||
labels: { provider, status: success ? 'success' : 'failure' },
|
||||
description: 'LLM requests'
|
||||
});
|
||||
|
||||
const metrics = this.getProviderMetrics(provider);
|
||||
metrics.requests++;
|
||||
if (!success) {
|
||||
metrics.failures++;
|
||||
}
|
||||
metrics.successRate = (metrics.requests - metrics.failures) / metrics.requests;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get or create provider metrics
|
||||
*/
|
||||
private getProviderMetrics(provider: string): ProviderMetrics {
|
||||
if (!this.providerMetrics.has(provider)) {
|
||||
this.providerMetrics.set(provider, {
|
||||
provider,
|
||||
requests: 0,
|
||||
failures: 0,
|
||||
successRate: 1,
|
||||
averageLatency: 0,
|
||||
p50Latency: 0,
|
||||
p95Latency: 0,
|
||||
p99Latency: 0,
|
||||
totalTokens: 0,
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
averageTokensPerRequest: 0,
|
||||
errorRate: 0,
|
||||
lastUpdated: new Date()
|
||||
});
|
||||
}
|
||||
return this.providerMetrics.get(provider)!;
|
||||
}
|
||||
|
||||
/**
|
||||
* Update provider metrics
|
||||
*/
|
||||
private updateProviderMetrics(metric: MetricDataPoint): void {
|
||||
const provider = metric.labels.provider;
|
||||
if (!provider) return;
|
||||
|
||||
const metrics = this.getProviderMetrics(provider);
|
||||
metrics.lastUpdated = new Date();
|
||||
|
||||
// Update token metrics
|
||||
if (metric.name.includes('tokens')) {
|
||||
if (metric.name === 'llm_input_tokens') {
|
||||
metrics.inputTokens += metric.value;
|
||||
} else if (metric.name === 'llm_output_tokens') {
|
||||
metrics.outputTokens += metric.value;
|
||||
}
|
||||
metrics.totalTokens = metrics.inputTokens + metrics.outputTokens;
|
||||
if (metrics.requests > 0) {
|
||||
metrics.averageTokensPerRequest = metrics.totalTokens / metrics.requests;
|
||||
}
|
||||
}
|
||||
|
||||
// Update latency metrics
|
||||
if (metric.name === 'llm_request_latency') {
|
||||
const histogram = this.latencyHistogram.get(provider);
|
||||
if (histogram && histogram.length > 0) {
|
||||
const sorted = [...histogram].sort((a, b) => a - b);
|
||||
metrics.averageLatency = sorted.reduce((a, b) => a + b, 0) / sorted.length;
|
||||
metrics.p50Latency = this.percentile(sorted, 50);
|
||||
metrics.p95Latency = this.percentile(sorted, 95);
|
||||
metrics.p99Latency = this.percentile(sorted, 99);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update system metrics
|
||||
*/
|
||||
private updateSystemMetrics(metric: MetricDataPoint): void {
|
||||
if (metric.name === 'llm_requests') {
|
||||
this.systemMetrics.totalRequests++;
|
||||
if (metric.labels.status === 'failure') {
|
||||
this.systemMetrics.totalFailures++;
|
||||
}
|
||||
}
|
||||
|
||||
this.systemMetrics.uptime = Date.now() - this.startTime.getTime();
|
||||
this.systemMetrics.timestamp = new Date();
|
||||
this.systemMetrics.memoryUsage = process.memoryUsage().heapUsed;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate percentile
|
||||
*/
|
||||
private percentile(sorted: number[], p: number): number {
|
||||
const index = Math.ceil((p / 100) * sorted.length) - 1;
|
||||
return sorted[Math.max(0, index)];
|
||||
}
|
||||
|
||||
/**
|
||||
* Classify error type
|
||||
*/
|
||||
private classifyError(error: string): string {
|
||||
const errorLower = error.toLowerCase();
|
||||
if (errorLower.includes('timeout')) return 'timeout';
|
||||
if (errorLower.includes('rate')) return 'rate_limit';
|
||||
if (errorLower.includes('auth')) return 'authentication';
|
||||
if (errorLower.includes('network')) return 'network';
|
||||
if (errorLower.includes('circuit')) return 'circuit_breaker';
|
||||
return 'unknown';
|
||||
}
|
||||
|
||||
/**
|
||||
* Get metric key
|
||||
*/
|
||||
private getMetricKey(metric: MetricDataPoint): string {
|
||||
const labelStr = Object.entries(metric.labels)
|
||||
.sort(([a], [b]) => a.localeCompare(b))
|
||||
.map(([k, v]) => `${k}=${v}`)
|
||||
.join(',');
|
||||
return `${metric.name}{${labelStr}}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create default system metrics
|
||||
*/
|
||||
private createDefaultSystemMetrics(): SystemMetrics {
|
||||
return {
|
||||
totalRequests: 0,
|
||||
totalFailures: 0,
|
||||
averageLatency: 0,
|
||||
activePipelines: 0,
|
||||
cacheHitRate: 0,
|
||||
memoryUsage: 0,
|
||||
uptime: 0,
|
||||
timestamp: new Date()
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all metrics
|
||||
*/
|
||||
public getAllMetrics(): Map<string, MetricDataPoint[]> {
|
||||
return new Map(this.metrics);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get provider metrics
|
||||
*/
|
||||
public getProviderMetricsMap(): Map<string, ProviderMetrics> {
|
||||
return new Map(this.providerMetrics);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get system metrics
|
||||
*/
|
||||
public getSystemMetrics(): SystemMetrics {
|
||||
return { ...this.systemMetrics };
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear metrics
|
||||
*/
|
||||
public clear(): void {
|
||||
this.metrics.clear();
|
||||
this.providerMetrics.clear();
|
||||
this.latencyHistogram.clear();
|
||||
this.systemMetrics = this.createDefaultSystemMetrics();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Prometheus format exporter
|
||||
*/
|
||||
export class PrometheusExporter {
|
||||
constructor(private collector: MetricsCollector) {}
|
||||
|
||||
/**
|
||||
* Export metrics in Prometheus format
|
||||
*/
|
||||
public export(): string {
|
||||
const lines: string[] = [];
|
||||
const metrics = this.collector.getAllMetrics();
|
||||
|
||||
// Add help and type comments
|
||||
const metricTypes = new Map<string, MetricType>();
|
||||
const metricDescriptions = new Map<string, string>();
|
||||
|
||||
for (const [key, dataPoints] of metrics) {
|
||||
if (dataPoints.length === 0) continue;
|
||||
|
||||
const latest = dataPoints[dataPoints.length - 1];
|
||||
const metricName = latest.name;
|
||||
|
||||
if (!metricTypes.has(metricName)) {
|
||||
metricTypes.set(metricName, latest.type);
|
||||
metricDescriptions.set(metricName, latest.description || '');
|
||||
|
||||
lines.push(`# HELP ${metricName} ${latest.description || ''}`);
|
||||
lines.push(`# TYPE ${metricName} ${this.mapMetricType(latest.type)}`);
|
||||
}
|
||||
|
||||
// Add metric value
|
||||
const labelStr = Object.entries(latest.labels)
|
||||
.map(([k, v]) => `${k}="${v}"`)
|
||||
.join(',');
|
||||
|
||||
const metricLine = labelStr
|
||||
? `${metricName}{${labelStr}} ${latest.value}`
|
||||
: `${metricName} ${latest.value}`;
|
||||
|
||||
lines.push(metricLine);
|
||||
}
|
||||
|
||||
// Add provider-specific metrics
|
||||
for (const [provider, metrics] of this.collector.getProviderMetricsMap()) {
|
||||
lines.push(`# HELP llm_provider_success_rate Success rate by provider`);
|
||||
lines.push(`# TYPE llm_provider_success_rate gauge`);
|
||||
lines.push(`llm_provider_success_rate{provider="${provider}"} ${metrics.successRate}`);
|
||||
|
||||
lines.push(`# HELP llm_provider_avg_latency Average latency by provider`);
|
||||
lines.push(`# TYPE llm_provider_avg_latency gauge`);
|
||||
lines.push(`llm_provider_avg_latency{provider="${provider}"} ${metrics.averageLatency}`);
|
||||
}
|
||||
|
||||
// Add system metrics
|
||||
const systemMetrics = this.collector.getSystemMetrics();
|
||||
lines.push(`# HELP llm_system_uptime System uptime in milliseconds`);
|
||||
lines.push(`# TYPE llm_system_uptime counter`);
|
||||
lines.push(`llm_system_uptime ${systemMetrics.uptime}`);
|
||||
|
||||
lines.push(`# HELP llm_system_memory_usage Memory usage in bytes`);
|
||||
lines.push(`# TYPE llm_system_memory_usage gauge`);
|
||||
lines.push(`llm_system_memory_usage ${systemMetrics.memoryUsage}`);
|
||||
|
||||
return lines.join('\n');
|
||||
}
|
||||
|
||||
/**
|
||||
* Map internal metric type to Prometheus type
|
||||
*/
|
||||
private mapMetricType(type: MetricType): string {
|
||||
switch (type) {
|
||||
case MetricType.COUNTER:
|
||||
return 'counter';
|
||||
case MetricType.GAUGE:
|
||||
return 'gauge';
|
||||
case MetricType.HISTOGRAM:
|
||||
return 'histogram';
|
||||
case MetricType.SUMMARY:
|
||||
return 'summary';
|
||||
default:
|
||||
return 'gauge';
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* StatsD format exporter
|
||||
*/
|
||||
export class StatsDExporter {
|
||||
constructor(
|
||||
private collector: MetricsCollector,
|
||||
private prefix: string = 'llm'
|
||||
) {}
|
||||
|
||||
/**
|
||||
* Export metrics in StatsD format
|
||||
*/
|
||||
public export(): string[] {
|
||||
const lines: string[] = [];
|
||||
const metrics = this.collector.getAllMetrics();
|
||||
|
||||
for (const [_, dataPoints] of metrics) {
|
||||
if (dataPoints.length === 0) continue;
|
||||
|
||||
const latest = dataPoints[dataPoints.length - 1];
|
||||
const metricName = this.formatMetricName(latest.name, latest.labels);
|
||||
|
||||
switch (latest.type) {
|
||||
case MetricType.COUNTER:
|
||||
lines.push(`${metricName}:${latest.value}|c`);
|
||||
break;
|
||||
case MetricType.GAUGE:
|
||||
lines.push(`${metricName}:${latest.value}|g`);
|
||||
break;
|
||||
case MetricType.HISTOGRAM:
|
||||
lines.push(`${metricName}:${latest.value}|h`);
|
||||
break;
|
||||
default:
|
||||
lines.push(`${metricName}:${latest.value}|g`);
|
||||
}
|
||||
}
|
||||
|
||||
return lines;
|
||||
}
|
||||
|
||||
/**
|
||||
* Format metric name for StatsD
|
||||
*/
|
||||
private formatMetricName(name: string, labels: Record<string, string>): string {
|
||||
const parts = [this.prefix, name];
|
||||
|
||||
// Add important labels to the metric name
|
||||
if (labels.provider) {
|
||||
parts.push(labels.provider);
|
||||
}
|
||||
|
||||
return parts.join('.');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* OpenTelemetry format exporter
|
||||
*/
|
||||
export class OpenTelemetryExporter {
|
||||
constructor(private collector: MetricsCollector) {}
|
||||
|
||||
/**
|
||||
* Export metrics in OpenTelemetry format
|
||||
*/
|
||||
public export(): object {
|
||||
const metrics = this.collector.getAllMetrics();
|
||||
const providerMetrics = this.collector.getProviderMetricsMap();
|
||||
const systemMetrics = this.collector.getSystemMetrics();
|
||||
|
||||
const resource = {
|
||||
attributes: {
|
||||
'service.name': 'trilium-llm',
|
||||
'service.version': '1.0.0'
|
||||
}
|
||||
};
|
||||
|
||||
const scopeMetrics = {
|
||||
scope: {
|
||||
name: 'trilium.llm.metrics',
|
||||
version: '1.0.0'
|
||||
},
|
||||
metrics: [] as any[]
|
||||
};
|
||||
|
||||
// Convert internal metrics to OTLP format
|
||||
for (const [key, dataPoints] of metrics) {
|
||||
if (dataPoints.length === 0) continue;
|
||||
|
||||
const latest = dataPoints[dataPoints.length - 1];
|
||||
const metric = {
|
||||
name: latest.name,
|
||||
description: latest.description,
|
||||
unit: latest.unit,
|
||||
data: {
|
||||
dataPoints: dataPoints.map(dp => ({
|
||||
attributes: dp.labels,
|
||||
timeUnixNano: dp.timestamp.getTime() * 1000000,
|
||||
value: dp.value
|
||||
}))
|
||||
}
|
||||
};
|
||||
|
||||
scopeMetrics.metrics.push(metric);
|
||||
}
|
||||
|
||||
return {
|
||||
resourceMetrics: [{
|
||||
resource,
|
||||
scopeMetrics: [scopeMetrics]
|
||||
}]
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Metrics Exporter Manager
|
||||
*/
|
||||
export class MetricsExporter {
|
||||
private static instance: MetricsExporter | null = null;
|
||||
private collector: MetricsCollector;
|
||||
private exporters: Map<ExportFormat, any> = new Map();
|
||||
private exportTimer?: NodeJS.Timeout;
|
||||
private config: ExporterConfig;
|
||||
|
||||
constructor(config?: Partial<ExporterConfig>) {
|
||||
this.collector = new MetricsCollector();
|
||||
this.config = {
|
||||
enabled: config?.enabled ?? false,
|
||||
format: config?.format ?? ExportFormat.PROMETHEUS,
|
||||
interval: config?.interval ?? 60000, // 1 minute
|
||||
prefix: config?.prefix ?? 'llm',
|
||||
includeHistograms: config?.includeHistograms ?? true,
|
||||
histogramBuckets: config?.histogramBuckets ?? [10, 25, 50, 100, 250, 500, 1000, 2500, 5000, 10000],
|
||||
...config
|
||||
};
|
||||
|
||||
this.initializeExporters();
|
||||
|
||||
if (this.config.enabled && this.config.interval) {
|
||||
this.startAutoExport();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get singleton instance
|
||||
*/
|
||||
public static getInstance(config?: Partial<ExporterConfig>): MetricsExporter {
|
||||
if (!MetricsExporter.instance) {
|
||||
MetricsExporter.instance = new MetricsExporter(config);
|
||||
}
|
||||
return MetricsExporter.instance;
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize exporters
|
||||
*/
|
||||
private initializeExporters(): void {
|
||||
this.exporters.set(
|
||||
ExportFormat.PROMETHEUS,
|
||||
new PrometheusExporter(this.collector)
|
||||
);
|
||||
|
||||
this.exporters.set(
|
||||
ExportFormat.STATSD,
|
||||
new StatsDExporter(this.collector, this.config.prefix)
|
||||
);
|
||||
|
||||
this.exporters.set(
|
||||
ExportFormat.OPENTELEMETRY,
|
||||
new OpenTelemetryExporter(this.collector)
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Start auto export
|
||||
*/
|
||||
private startAutoExport(): void {
|
||||
if (this.exportTimer) {
|
||||
clearInterval(this.exportTimer);
|
||||
}
|
||||
|
||||
this.exportTimer = setInterval(() => {
|
||||
this.export();
|
||||
}, this.config.interval);
|
||||
}
|
||||
|
||||
/**
|
||||
* Export metrics
|
||||
*/
|
||||
public export(format?: ExportFormat): any {
|
||||
const exportFormat = format || this.config.format;
|
||||
const exporter = this.exporters.get(exportFormat);
|
||||
|
||||
if (!exporter) {
|
||||
log.error(`[MetricsExporter] Unknown export format: ${exportFormat}`);
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
const data = exporter.export();
|
||||
|
||||
if (this.config.endpoint) {
|
||||
this.sendToEndpoint(data, exportFormat);
|
||||
}
|
||||
|
||||
return data;
|
||||
} catch (error) {
|
||||
log.error(`[MetricsExporter] Export failed: ${error}`);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Send metrics to endpoint
|
||||
*/
|
||||
private async sendToEndpoint(data: any, format: ExportFormat): Promise<void> {
|
||||
if (!this.config.endpoint) return;
|
||||
|
||||
try {
|
||||
const contentType = this.getContentType(format);
|
||||
const body = typeof data === 'string' ? data : JSON.stringify(data);
|
||||
|
||||
// This would be replaced with actual HTTP client
|
||||
log.info(`[MetricsExporter] Would send metrics to ${this.config.endpoint}`);
|
||||
// await fetch(this.config.endpoint, {
|
||||
// method: 'POST',
|
||||
// headers: { 'Content-Type': contentType },
|
||||
// body
|
||||
// });
|
||||
} catch (error) {
|
||||
log.error(`[MetricsExporter] Failed to send metrics: ${error}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get content type for format
|
||||
*/
|
||||
private getContentType(format: ExportFormat): string {
|
||||
switch (format) {
|
||||
case ExportFormat.PROMETHEUS:
|
||||
return 'text/plain; version=0.0.4';
|
||||
case ExportFormat.STATSD:
|
||||
return 'text/plain';
|
||||
case ExportFormat.OPENTELEMETRY:
|
||||
return 'application/json';
|
||||
default:
|
||||
return 'application/json';
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get metrics collector
|
||||
*/
|
||||
public getCollector(): MetricsCollector {
|
||||
return this.collector;
|
||||
}
|
||||
|
||||
/**
|
||||
* Enable/disable exporter
|
||||
*/
|
||||
public setEnabled(enabled: boolean): void {
|
||||
this.config.enabled = enabled;
|
||||
|
||||
if (enabled && this.config.interval && !this.exportTimer) {
|
||||
this.startAutoExport();
|
||||
} else if (!enabled && this.exportTimer) {
|
||||
clearInterval(this.exportTimer);
|
||||
this.exportTimer = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update configuration
|
||||
*/
|
||||
public updateConfig(config: Partial<ExporterConfig>): void {
|
||||
this.config = { ...this.config, ...config };
|
||||
|
||||
if (this.config.enabled && this.config.interval) {
|
||||
this.startAutoExport();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Dispose exporter
|
||||
*/
|
||||
public dispose(): void {
|
||||
if (this.exportTimer) {
|
||||
clearInterval(this.exportTimer);
|
||||
this.exportTimer = undefined;
|
||||
}
|
||||
|
||||
this.collector.clear();
|
||||
MetricsExporter.instance = null;
|
||||
}
|
||||
}
|
||||
|
||||
// Export singleton getter
|
||||
export const getMetricsExporter = (config?: Partial<ExporterConfig>): MetricsExporter => {
|
||||
return MetricsExporter.getInstance(config);
|
||||
};
|
||||
@@ -0,0 +1,319 @@
|
||||
/**
|
||||
* Circuit Breaker Tests
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
|
||||
import {
|
||||
CircuitBreaker,
|
||||
CircuitBreakerManager,
|
||||
CircuitState,
|
||||
CircuitOpenError
|
||||
} from '../circuit_breaker.js';
|
||||
|
||||
describe('CircuitBreaker', () => {
|
||||
let breaker: CircuitBreaker;
|
||||
|
||||
beforeEach(() => {
|
||||
breaker = new CircuitBreaker('test-provider', {
|
||||
failureThreshold: 3,
|
||||
failureWindow: 1000,
|
||||
cooldownPeriod: 500,
|
||||
successThreshold: 2,
|
||||
enableLogging: false
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
breaker.dispose();
|
||||
});
|
||||
|
||||
describe('State Transitions', () => {
|
||||
it('should start in CLOSED state', () => {
|
||||
expect(breaker.getState()).toBe(CircuitState.CLOSED);
|
||||
});
|
||||
|
||||
it('should open after failure threshold', async () => {
|
||||
const failingFn = () => Promise.reject(new Error('Test failure'));
|
||||
|
||||
// First two failures - should remain closed
|
||||
await expect(breaker.execute(failingFn)).rejects.toThrow('Test failure');
|
||||
expect(breaker.getState()).toBe(CircuitState.CLOSED);
|
||||
|
||||
await expect(breaker.execute(failingFn)).rejects.toThrow('Test failure');
|
||||
expect(breaker.getState()).toBe(CircuitState.CLOSED);
|
||||
|
||||
// Third failure - should open
|
||||
await expect(breaker.execute(failingFn)).rejects.toThrow('Test failure');
|
||||
expect(breaker.getState()).toBe(CircuitState.OPEN);
|
||||
});
|
||||
|
||||
it('should reject requests when open', async () => {
|
||||
// Force open
|
||||
breaker.forceOpen('Test');
|
||||
|
||||
const fn = () => Promise.resolve('success');
|
||||
await expect(breaker.execute(fn)).rejects.toThrow(CircuitOpenError);
|
||||
});
|
||||
|
||||
it('should transition to HALF_OPEN after cooldown', async () => {
|
||||
// Force open
|
||||
breaker.forceOpen('Test');
|
||||
expect(breaker.getState()).toBe(CircuitState.OPEN);
|
||||
|
||||
// Wait for cooldown
|
||||
await new Promise(resolve => setTimeout(resolve, 600));
|
||||
expect(breaker.getState()).toBe(CircuitState.HALF_OPEN);
|
||||
});
|
||||
|
||||
it('should close after success threshold in HALF_OPEN', async () => {
|
||||
// Force to half-open
|
||||
breaker.forceOpen('Test');
|
||||
await new Promise(resolve => setTimeout(resolve, 600));
|
||||
expect(breaker.getState()).toBe(CircuitState.HALF_OPEN);
|
||||
|
||||
const successFn = () => Promise.resolve('success');
|
||||
|
||||
// First success
|
||||
await breaker.execute(successFn);
|
||||
expect(breaker.getState()).toBe(CircuitState.HALF_OPEN);
|
||||
|
||||
// Second success - should close
|
||||
await breaker.execute(successFn);
|
||||
expect(breaker.getState()).toBe(CircuitState.CLOSED);
|
||||
});
|
||||
|
||||
it('should reopen on failure in HALF_OPEN', async () => {
|
||||
// Force to half-open
|
||||
breaker.forceOpen('Test');
|
||||
await new Promise(resolve => setTimeout(resolve, 600));
|
||||
expect(breaker.getState()).toBe(CircuitState.HALF_OPEN);
|
||||
|
||||
const failingFn = () => Promise.reject(new Error('Test failure'));
|
||||
|
||||
// Failure in half-open should immediately open
|
||||
await expect(breaker.execute(failingFn)).rejects.toThrow('Test failure');
|
||||
expect(breaker.getState()).toBe(CircuitState.OPEN);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Failure Window', () => {
|
||||
it('should reset failures outside window', async () => {
|
||||
const failingFn = () => Promise.reject(new Error('Test failure'));
|
||||
|
||||
// Two failures
|
||||
await expect(breaker.execute(failingFn)).rejects.toThrow();
|
||||
await expect(breaker.execute(failingFn)).rejects.toThrow();
|
||||
expect(breaker.getState()).toBe(CircuitState.CLOSED);
|
||||
|
||||
// Wait for window to expire
|
||||
await new Promise(resolve => setTimeout(resolve, 1100));
|
||||
|
||||
// Two more failures - should still be closed (window reset)
|
||||
await expect(breaker.execute(failingFn)).rejects.toThrow();
|
||||
await expect(breaker.execute(failingFn)).rejects.toThrow();
|
||||
expect(breaker.getState()).toBe(CircuitState.CLOSED);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Statistics', () => {
|
||||
it('should track statistics', async () => {
|
||||
const successFn = () => Promise.resolve('success');
|
||||
const failingFn = () => Promise.reject(new Error('Test failure'));
|
||||
|
||||
await breaker.execute(successFn);
|
||||
await expect(breaker.execute(failingFn)).rejects.toThrow();
|
||||
|
||||
const stats = breaker.getStats();
|
||||
expect(stats.totalRequests).toBe(2);
|
||||
expect(stats.successes).toBe(1);
|
||||
expect(stats.failures).toBe(1);
|
||||
expect(stats.rejectedRequests).toBe(0);
|
||||
});
|
||||
|
||||
it('should track rejected requests', async () => {
|
||||
breaker.forceOpen('Test');
|
||||
|
||||
const fn = () => Promise.resolve('success');
|
||||
await expect(breaker.execute(fn)).rejects.toThrow(CircuitOpenError);
|
||||
await expect(breaker.execute(fn)).rejects.toThrow(CircuitOpenError);
|
||||
|
||||
const stats = breaker.getStats();
|
||||
expect(stats.rejectedRequests).toBe(2);
|
||||
});
|
||||
|
||||
it('should track state history', async () => {
|
||||
breaker.forceOpen('Test open');
|
||||
breaker.forceClose('Test close');
|
||||
|
||||
const stats = breaker.getStats();
|
||||
expect(stats.stateHistory).toHaveLength(2);
|
||||
expect(stats.stateHistory[0].state).toBe(CircuitState.OPEN);
|
||||
expect(stats.stateHistory[1].state).toBe(CircuitState.CLOSED);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Timeout', () => {
|
||||
it('should apply timeout in HALF_OPEN state', async () => {
|
||||
// Force to half-open
|
||||
breaker.forceOpen('Test');
|
||||
await new Promise(resolve => setTimeout(resolve, 600));
|
||||
|
||||
const slowFn = () => new Promise(resolve =>
|
||||
setTimeout(() => resolve('success'), 10000)
|
||||
);
|
||||
|
||||
// Should timeout (half-open timeout is 5000ms in config)
|
||||
await expect(breaker.execute(slowFn)).rejects.toThrow(/timed out/);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('CircuitBreakerManager', () => {
|
||||
let manager: CircuitBreakerManager;
|
||||
|
||||
beforeEach(() => {
|
||||
manager = new CircuitBreakerManager({
|
||||
failureThreshold: 2,
|
||||
failureWindow: 1000,
|
||||
cooldownPeriod: 500,
|
||||
enableLogging: false
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
manager.dispose();
|
||||
});
|
||||
|
||||
describe('Breaker Management', () => {
|
||||
it('should create breakers on demand', () => {
|
||||
const breaker1 = manager.getBreaker('provider1');
|
||||
const breaker2 = manager.getBreaker('provider2');
|
||||
|
||||
expect(breaker1).toBeDefined();
|
||||
expect(breaker2).toBeDefined();
|
||||
expect(breaker1).not.toBe(breaker2);
|
||||
});
|
||||
|
||||
it('should return same breaker for same provider', () => {
|
||||
const breaker1 = manager.getBreaker('provider1');
|
||||
const breaker2 = manager.getBreaker('provider1');
|
||||
|
||||
expect(breaker1).toBe(breaker2);
|
||||
});
|
||||
|
||||
it('should execute with breaker protection', async () => {
|
||||
const fn = () => Promise.resolve('success');
|
||||
const result = await manager.execute('provider1', fn);
|
||||
|
||||
expect(result).toBe('success');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Health Summary', () => {
|
||||
it('should provide health summary', async () => {
|
||||
const failingFn = () => Promise.reject(new Error('Test'));
|
||||
const successFn = () => Promise.resolve('success');
|
||||
|
||||
// Create some breakers in different states
|
||||
await manager.execute('provider1', successFn);
|
||||
|
||||
// Force provider2 to open
|
||||
const breaker2 = manager.getBreaker('provider2');
|
||||
breaker2.forceOpen('Test');
|
||||
|
||||
const summary = manager.getHealthSummary();
|
||||
expect(summary.total).toBe(2);
|
||||
expect(summary.closed).toBe(1);
|
||||
expect(summary.open).toBe(1);
|
||||
expect(summary.availableProviders).toContain('provider1');
|
||||
expect(summary.unavailableProviders).toContain('provider2');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Global Operations', () => {
|
||||
it('should reset all breakers', () => {
|
||||
const breaker1 = manager.getBreaker('provider1');
|
||||
const breaker2 = manager.getBreaker('provider2');
|
||||
|
||||
breaker1.forceOpen('Test');
|
||||
breaker2.forceOpen('Test');
|
||||
|
||||
manager.resetAll();
|
||||
|
||||
expect(breaker1.getState()).toBe(CircuitState.CLOSED);
|
||||
expect(breaker2.getState()).toBe(CircuitState.CLOSED);
|
||||
});
|
||||
|
||||
it('should get all stats', async () => {
|
||||
await manager.execute('provider1', () => Promise.resolve('success'));
|
||||
await manager.execute('provider2', () => Promise.resolve('success'));
|
||||
|
||||
const allStats = manager.getAllStats();
|
||||
expect(allStats.size).toBe(2);
|
||||
expect(allStats.has('provider1')).toBe(true);
|
||||
expect(allStats.has('provider2')).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Circuit Breaker Integration', () => {
|
||||
it('should handle rapid failures gracefully', async () => {
|
||||
const breaker = new CircuitBreaker('rapid-test', {
|
||||
failureThreshold: 3,
|
||||
failureWindow: 1000,
|
||||
cooldownPeriod: 100,
|
||||
enableLogging: false
|
||||
});
|
||||
|
||||
const failingFn = () => Promise.reject(new Error('Rapid failure'));
|
||||
const promises: Promise<any>[] = [];
|
||||
|
||||
// Fire 10 rapid requests
|
||||
for (let i = 0; i < 10; i++) {
|
||||
promises.push(
|
||||
breaker.execute(failingFn).catch(err => err.message)
|
||||
);
|
||||
}
|
||||
|
||||
const results = await Promise.all(promises);
|
||||
|
||||
// First 3 should be actual failures
|
||||
expect(results.slice(0, 3)).toEqual([
|
||||
'Rapid failure',
|
||||
'Rapid failure',
|
||||
'Rapid failure'
|
||||
]);
|
||||
|
||||
// Rest should be circuit open errors
|
||||
const openErrors = results.slice(3).filter((msg: string) =>
|
||||
msg.includes('Circuit breaker is OPEN')
|
||||
);
|
||||
expect(openErrors.length).toBeGreaterThan(0);
|
||||
|
||||
breaker.dispose();
|
||||
});
|
||||
|
||||
it('should handle concurrent successes correctly', async () => {
|
||||
const breaker = new CircuitBreaker('concurrent-test', {
|
||||
failureThreshold: 3,
|
||||
enableLogging: false
|
||||
});
|
||||
|
||||
let counter = 0;
|
||||
const successFn = () => {
|
||||
counter++;
|
||||
return Promise.resolve(counter);
|
||||
};
|
||||
|
||||
const promises: Promise<number>[] = [];
|
||||
for (let i = 0; i < 5; i++) {
|
||||
promises.push(breaker.execute(successFn));
|
||||
}
|
||||
|
||||
const results = await Promise.all(promises);
|
||||
expect(results).toEqual([1, 2, 3, 4, 5]);
|
||||
expect(breaker.getState()).toBe(CircuitState.CLOSED);
|
||||
|
||||
breaker.dispose();
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,482 @@
|
||||
/**
|
||||
* Mock Providers for Testing
|
||||
*
|
||||
* Provides mock implementations of AI service providers for testing purposes
|
||||
*/
|
||||
|
||||
import type { AIService, ChatCompletionOptions, ChatResponse, Message } from '../../ai_interface.js';
|
||||
import type { UnifiedStreamChunk } from '../unified_stream_handler.js';
|
||||
|
||||
/**
|
||||
* Mock provider configuration
|
||||
*/
|
||||
export interface MockProviderConfig {
|
||||
name: string;
|
||||
available: boolean;
|
||||
responseDelay?: number;
|
||||
errorRate?: number;
|
||||
streamingSupported?: boolean;
|
||||
toolsSupported?: boolean;
|
||||
defaultResponse?: string;
|
||||
throwError?: Error;
|
||||
}
|
||||
|
||||
/**
|
||||
* Base mock provider implementation
|
||||
*/
|
||||
export class MockProvider implements AIService {
|
||||
protected config: MockProviderConfig;
|
||||
private callCount: number = 0;
|
||||
private streamCallCount: number = 0;
|
||||
|
||||
constructor(config: Partial<MockProviderConfig> = {}) {
|
||||
this.config = {
|
||||
name: config.name || 'mock',
|
||||
available: config.available !== false,
|
||||
responseDelay: config.responseDelay || 0,
|
||||
errorRate: config.errorRate || 0,
|
||||
streamingSupported: config.streamingSupported !== false,
|
||||
toolsSupported: config.toolsSupported !== false,
|
||||
defaultResponse: config.defaultResponse || 'Mock response',
|
||||
throwError: config.throwError
|
||||
};
|
||||
}
|
||||
|
||||
isAvailable(): boolean {
|
||||
return this.config.available;
|
||||
}
|
||||
|
||||
getName(): string {
|
||||
return this.config.name;
|
||||
}
|
||||
|
||||
async generateChatCompletion(
|
||||
messages: Message[],
|
||||
options: ChatCompletionOptions = {}
|
||||
): Promise<ChatResponse> {
|
||||
this.callCount++;
|
||||
|
||||
// Simulate delay
|
||||
if (this.config.responseDelay) {
|
||||
await new Promise(resolve => setTimeout(resolve, this.config.responseDelay));
|
||||
}
|
||||
|
||||
// Simulate errors
|
||||
if (this.config.throwError) {
|
||||
throw this.config.throwError;
|
||||
}
|
||||
|
||||
if (this.config.errorRate && Math.random() < this.config.errorRate) {
|
||||
throw new Error(`Mock provider error (${this.config.name})`);
|
||||
}
|
||||
|
||||
// Handle streaming
|
||||
if (options.stream && options.streamCallback) {
|
||||
return this.generateStreamingResponse(messages, options);
|
||||
}
|
||||
|
||||
// Generate response based on options
|
||||
const response: ChatResponse = {
|
||||
text: this.generateContent(messages, options),
|
||||
model: `${this.config.name}-model`,
|
||||
provider: this.config.name,
|
||||
usage: {
|
||||
promptTokens: this.calculateTokens(messages),
|
||||
completionTokens: 10,
|
||||
totalTokens: this.calculateTokens(messages) + 10
|
||||
}
|
||||
};
|
||||
|
||||
// Add tool calls if requested
|
||||
if (options.tools && this.config.toolsSupported) {
|
||||
response.tool_calls = this.generateToolCalls(options.tools);
|
||||
}
|
||||
|
||||
return response;
|
||||
}
|
||||
|
||||
protected async generateStreamingResponse(
|
||||
messages: Message[],
|
||||
options: ChatCompletionOptions
|
||||
): Promise<ChatResponse> {
|
||||
this.streamCallCount++;
|
||||
|
||||
const content = this.generateContent(messages, options);
|
||||
const chunks = this.splitIntoChunks(content, 5);
|
||||
|
||||
let fullContent = '';
|
||||
|
||||
for (const chunk of chunks) {
|
||||
fullContent += chunk;
|
||||
|
||||
// Call stream callback
|
||||
if (options.streamCallback) {
|
||||
await options.streamCallback(chunk, false);
|
||||
|
||||
// Simulate delay between chunks
|
||||
if (this.config.responseDelay) {
|
||||
await new Promise(resolve =>
|
||||
setTimeout(resolve, this.config.responseDelay! / chunks.length)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send final callback
|
||||
if (options.streamCallback) {
|
||||
await options.streamCallback('', true);
|
||||
}
|
||||
|
||||
return {
|
||||
text: fullContent,
|
||||
model: `${this.config.name}-model`,
|
||||
provider: this.config.name,
|
||||
usage: {
|
||||
promptTokens: this.calculateTokens(messages),
|
||||
completionTokens: Math.floor(fullContent.length / 4),
|
||||
totalTokens: this.calculateTokens(messages) + Math.floor(fullContent.length / 4)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
protected generateContent(messages: Message[], options: ChatCompletionOptions): string {
|
||||
// Return JSON if requested
|
||||
if (options.expectsJsonResponse) {
|
||||
return JSON.stringify({
|
||||
type: 'mock_response',
|
||||
provider: this.config.name,
|
||||
messageCount: messages.length
|
||||
});
|
||||
}
|
||||
|
||||
// Use custom response if provided
|
||||
if (this.config.defaultResponse) {
|
||||
return this.config.defaultResponse;
|
||||
}
|
||||
|
||||
// Generate response based on last message
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
return `Mock ${this.config.name} response to: ${lastMessage.content}`;
|
||||
}
|
||||
|
||||
protected generateToolCalls(tools: any[]): any[] {
|
||||
return tools.slice(0, 1).map((tool, index) => ({
|
||||
id: `call_mock_${index}`,
|
||||
type: 'function',
|
||||
function: {
|
||||
name: tool.function?.name || 'mock_tool',
|
||||
arguments: JSON.stringify({ mock: true })
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
protected calculateTokens(messages: Message[]): number {
|
||||
return messages.reduce((sum, msg) => {
|
||||
const content = typeof msg.content === 'string' ? msg.content : '';
|
||||
return sum + Math.floor(content.length / 4);
|
||||
}, 0);
|
||||
}
|
||||
|
||||
protected splitIntoChunks(text: string, chunkCount: number): string[] {
|
||||
const chunkSize = Math.ceil(text.length / chunkCount);
|
||||
const chunks: string[] = [];
|
||||
|
||||
for (let i = 0; i < text.length; i += chunkSize) {
|
||||
chunks.push(text.slice(i, i + chunkSize));
|
||||
}
|
||||
|
||||
return chunks;
|
||||
}
|
||||
|
||||
// Test helper methods
|
||||
getCallCount(): number {
|
||||
return this.callCount;
|
||||
}
|
||||
|
||||
getStreamCallCount(): number {
|
||||
return this.streamCallCount;
|
||||
}
|
||||
|
||||
resetCallCounts(): void {
|
||||
this.callCount = 0;
|
||||
this.streamCallCount = 0;
|
||||
}
|
||||
|
||||
setAvailable(available: boolean): void {
|
||||
this.config.available = available;
|
||||
}
|
||||
|
||||
setErrorRate(rate: number): void {
|
||||
this.config.errorRate = rate;
|
||||
}
|
||||
|
||||
setResponseDelay(delay: number): void {
|
||||
this.config.responseDelay = delay;
|
||||
}
|
||||
|
||||
dispose(): void {
|
||||
// Cleanup mock resources
|
||||
this.resetCallCounts();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Mock OpenAI provider
|
||||
*/
|
||||
export class MockOpenAIProvider extends MockProvider {
|
||||
constructor(config: Partial<MockProviderConfig> = {}) {
|
||||
super({
|
||||
name: 'openai',
|
||||
...config
|
||||
});
|
||||
}
|
||||
|
||||
supportsStreaming(): boolean {
|
||||
return this.config.streamingSupported!;
|
||||
}
|
||||
|
||||
supportsTools(): boolean {
|
||||
return this.config.toolsSupported!;
|
||||
}
|
||||
|
||||
async *streamCompletion(
|
||||
messages: Message[],
|
||||
options: ChatCompletionOptions = {}
|
||||
): AsyncGenerator<any> {
|
||||
const content = this.generateContent(messages, options);
|
||||
const chunks = this.splitIntoChunks(content, 5);
|
||||
|
||||
for (const chunk of chunks) {
|
||||
yield {
|
||||
choices: [{
|
||||
delta: { content: chunk },
|
||||
index: 0
|
||||
}],
|
||||
model: 'gpt-4-mock'
|
||||
};
|
||||
|
||||
if (this.config.responseDelay) {
|
||||
await new Promise(resolve =>
|
||||
setTimeout(resolve, this.config.responseDelay! / chunks.length)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
yield {
|
||||
choices: [{
|
||||
delta: {},
|
||||
finish_reason: 'stop',
|
||||
index: 0
|
||||
}],
|
||||
usage: {
|
||||
prompt_tokens: this.calculateTokens(messages),
|
||||
completion_tokens: Math.floor(content.length / 4),
|
||||
total_tokens: this.calculateTokens(messages) + Math.floor(content.length / 4)
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Mock Anthropic provider
|
||||
*/
|
||||
export class MockAnthropicProvider extends MockProvider {
|
||||
constructor(config: Partial<MockProviderConfig> = {}) {
|
||||
super({
|
||||
name: 'anthropic',
|
||||
...config
|
||||
});
|
||||
}
|
||||
|
||||
async *streamCompletion(
|
||||
messages: Message[],
|
||||
options: ChatCompletionOptions = {}
|
||||
): AsyncGenerator<any> {
|
||||
const content = this.generateContent(messages, options);
|
||||
const chunks = this.splitIntoChunks(content, 5);
|
||||
|
||||
// Message start
|
||||
yield {
|
||||
type: 'message_start',
|
||||
message: { id: 'msg_mock_123' }
|
||||
};
|
||||
|
||||
// Content blocks
|
||||
for (const chunk of chunks) {
|
||||
yield {
|
||||
type: 'content_block_delta',
|
||||
delta: {
|
||||
type: 'text_delta',
|
||||
text: chunk
|
||||
}
|
||||
};
|
||||
|
||||
if (this.config.responseDelay) {
|
||||
await new Promise(resolve =>
|
||||
setTimeout(resolve, this.config.responseDelay! / chunks.length)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Message end
|
||||
yield {
|
||||
type: 'message_delta',
|
||||
delta: { stop_reason: 'end_turn' },
|
||||
usage: {
|
||||
input_tokens: this.calculateTokens(messages),
|
||||
output_tokens: Math.floor(content.length / 4)
|
||||
}
|
||||
};
|
||||
|
||||
yield {
|
||||
type: 'message_stop'
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Mock Ollama provider
|
||||
*/
|
||||
export class MockOllamaProvider extends MockProvider {
|
||||
constructor(config: Partial<MockProviderConfig> = {}) {
|
||||
super({
|
||||
name: 'ollama',
|
||||
...config
|
||||
});
|
||||
}
|
||||
|
||||
async *streamCompletion(
|
||||
messages: Message[],
|
||||
options: ChatCompletionOptions = {}
|
||||
): AsyncGenerator<any> {
|
||||
const content = this.generateContent(messages, options);
|
||||
const chunks = this.splitIntoChunks(content, 5);
|
||||
|
||||
for (let i = 0; i < chunks.length; i++) {
|
||||
yield {
|
||||
message: { content: chunks[i] },
|
||||
model: 'llama2-mock',
|
||||
done: false
|
||||
};
|
||||
|
||||
if (this.config.responseDelay) {
|
||||
await new Promise(resolve =>
|
||||
setTimeout(resolve, this.config.responseDelay! / chunks.length)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Final chunk with usage
|
||||
yield {
|
||||
message: { content: '' },
|
||||
model: 'llama2-mock',
|
||||
done: true,
|
||||
prompt_eval_count: this.calculateTokens(messages),
|
||||
eval_count: Math.floor(content.length / 4)
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Factory for creating mock providers
|
||||
*/
|
||||
export class MockProviderFactory {
|
||||
private providers: Map<string, MockProvider> = new Map();
|
||||
|
||||
createProvider(type: 'openai' | 'anthropic' | 'ollama', config?: Partial<MockProviderConfig>): MockProvider {
|
||||
let provider: MockProvider;
|
||||
|
||||
switch (type) {
|
||||
case 'openai':
|
||||
provider = new MockOpenAIProvider(config);
|
||||
break;
|
||||
case 'anthropic':
|
||||
provider = new MockAnthropicProvider(config);
|
||||
break;
|
||||
case 'ollama':
|
||||
provider = new MockOllamaProvider(config);
|
||||
break;
|
||||
default:
|
||||
provider = new MockProvider({ name: type, ...config });
|
||||
}
|
||||
|
||||
this.providers.set(type, provider);
|
||||
return provider;
|
||||
}
|
||||
|
||||
getProvider(type: string): MockProvider | undefined {
|
||||
return this.providers.get(type);
|
||||
}
|
||||
|
||||
getAllProviders(): MockProvider[] {
|
||||
return Array.from(this.providers.values());
|
||||
}
|
||||
|
||||
resetAll(): void {
|
||||
for (const provider of this.providers.values()) {
|
||||
provider.resetCallCounts();
|
||||
}
|
||||
}
|
||||
|
||||
disposeAll(): void {
|
||||
for (const provider of this.providers.values()) {
|
||||
provider.dispose();
|
||||
}
|
||||
this.providers.clear();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a mock provider with predefined behaviors
|
||||
*/
|
||||
export function createMockProvider(behavior: 'success' | 'error' | 'slow' | 'flaky'): MockProvider {
|
||||
const configs: Record<string, Partial<MockProviderConfig>> = {
|
||||
success: {
|
||||
available: true,
|
||||
responseDelay: 10
|
||||
},
|
||||
error: {
|
||||
available: true,
|
||||
throwError: new Error('Mock provider error')
|
||||
},
|
||||
slow: {
|
||||
available: true,
|
||||
responseDelay: 1000
|
||||
},
|
||||
flaky: {
|
||||
available: true,
|
||||
errorRate: 0.5,
|
||||
responseDelay: 100
|
||||
}
|
||||
};
|
||||
|
||||
return new MockProvider(configs[behavior] || configs.success);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a mock streaming response for testing
|
||||
*/
|
||||
export async function* createMockStream(
|
||||
chunks: string[],
|
||||
delay: number = 10
|
||||
): AsyncGenerator<UnifiedStreamChunk> {
|
||||
for (const chunk of chunks) {
|
||||
yield {
|
||||
type: 'content',
|
||||
content: chunk,
|
||||
metadata: {
|
||||
provider: 'mock'
|
||||
}
|
||||
};
|
||||
|
||||
await new Promise(resolve => setTimeout(resolve, delay));
|
||||
}
|
||||
|
||||
yield {
|
||||
type: 'done',
|
||||
metadata: {
|
||||
provider: 'mock',
|
||||
finishReason: 'stop'
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,502 @@
|
||||
/**
|
||||
* Provider Performance Benchmarks
|
||||
*
|
||||
* Performance benchmark suite for AI service providers
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
|
||||
import { performance } from 'perf_hooks';
|
||||
import { ProviderFactory, ProviderType } from '../provider_factory.js';
|
||||
import {
|
||||
MockProviderFactory,
|
||||
createMockProvider
|
||||
} from './mock_providers.js';
|
||||
import {
|
||||
StreamAggregator,
|
||||
createStreamHandler
|
||||
} from '../unified_stream_handler.js';
|
||||
import type { AIService, Message } from '../../ai_interface.js';
|
||||
|
||||
// Mock providers
|
||||
vi.mock('../openai_service.js');
|
||||
vi.mock('../anthropic_service.js');
|
||||
vi.mock('../ollama_service.js');
|
||||
|
||||
import { OpenAIService } from '../openai_service.js';
|
||||
import { AnthropicService } from '../anthropic_service.js';
|
||||
import { OllamaService } from '../ollama_service.js';
|
||||
|
||||
/**
 * Performance metrics interface
 *
 * One record per benchmark run, produced by BenchmarkRunner.runBenchmark().
 */
interface PerformanceMetrics {
    operation: string;   // human-readable benchmark name, e.g. 'Chat Completion'
    provider: string;    // provider under test ('openai', 'anthropic', ...)
    duration: number;    // total wall-clock time for all iterations, in ms
    throughput?: number; // iterations per second over the whole run
    latency?: number;    // average time per iteration, in ms
    memoryUsed?: number; // heap delta over the run, in bytes (can be negative if GC ran)
}
|
||||
|
||||
/**
|
||||
* Benchmark runner class
|
||||
*/
|
||||
class BenchmarkRunner {
|
||||
private metrics: PerformanceMetrics[] = [];
|
||||
|
||||
async runBenchmark(
|
||||
name: string,
|
||||
provider: string,
|
||||
fn: () => Promise<void>,
|
||||
iterations: number = 100
|
||||
): Promise<PerformanceMetrics> {
|
||||
const startMemory = process.memoryUsage().heapUsed;
|
||||
const startTime = performance.now();
|
||||
|
||||
for (let i = 0; i < iterations; i++) {
|
||||
await fn();
|
||||
}
|
||||
|
||||
const endTime = performance.now();
|
||||
const endMemory = process.memoryUsage().heapUsed;
|
||||
|
||||
const metrics: PerformanceMetrics = {
|
||||
operation: name,
|
||||
provider,
|
||||
duration: endTime - startTime,
|
||||
throughput: iterations / ((endTime - startTime) / 1000),
|
||||
latency: (endTime - startTime) / iterations,
|
||||
memoryUsed: endMemory - startMemory
|
||||
};
|
||||
|
||||
this.metrics.push(metrics);
|
||||
return metrics;
|
||||
}
|
||||
|
||||
getMetrics(): PerformanceMetrics[] {
|
||||
return [...this.metrics];
|
||||
}
|
||||
|
||||
printSummary(): void {
|
||||
console.table(this.metrics.map(m => ({
|
||||
Operation: m.operation,
|
||||
Provider: m.provider,
|
||||
'Avg Latency (ms)': m.latency?.toFixed(2),
|
||||
'Throughput (ops/s)': m.throughput?.toFixed(2),
|
||||
'Memory (MB)': ((m.memoryUsed || 0) / 1024 / 1024).toFixed(2)
|
||||
})));
|
||||
}
|
||||
|
||||
reset(): void {
|
||||
this.metrics = [];
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark suite for the provider factory and the mock providers.
// Each `describe` isolates one performance dimension; thresholds are
// intentionally loose so the suite stays stable across CI machines.
describe('Provider Performance Benchmarks', () => {
    let factory: ProviderFactory;
    let mockFactory: MockProviderFactory;
    let runner: BenchmarkRunner;

    beforeEach(() => {
        // Clear singleton so each test starts from a fresh factory state
        const existing = ProviderFactory.getInstance();
        if (existing) {
            existing.dispose();
        }

        // Health checks/metrics disabled: they would add noise to timings
        factory = new ProviderFactory({
            enableHealthChecks: false,
            enableMetrics: false,
            enableCaching: true,
            cacheTimeout: 60000
        });

        mockFactory = new MockProviderFactory();
        runner = new BenchmarkRunner();
    });

    afterEach(() => {
        factory.dispose();
        mockFactory.disposeAll();
        vi.clearAllMocks();
    });

    describe('Provider Creation Performance', () => {
        it('should benchmark provider creation speed', async () => {
            const providers = ['openai', 'anthropic', 'ollama'] as const;

            for (const providerName of providers) {
                const mock = createMockProvider('success');
                mock.setResponseDelay(0); // No delay for creation benchmarks

                // Route the matching service constructor to the mock
                switch (providerName) {
                    case 'openai':
                        (OpenAIService as any).mockImplementation(() => mock);
                        break;
                    case 'anthropic':
                        (AnthropicService as any).mockImplementation(() => mock);
                        break;
                    case 'ollama':
                        (OllamaService as any).mockImplementation(() => mock);
                        break;
                }

                const metrics = await runner.runBenchmark(
                    'Provider Creation',
                    providerName,
                    async () => {
                        // Result intentionally unused: only creation time is measured
                        const provider = await factory.createProvider(
                            ProviderType[providerName.toUpperCase() as keyof typeof ProviderType]
                        );
                    },
                    100
                );

                expect(metrics.latency).toBeLessThan(10); // Should be fast (< 10ms per creation)
                expect(metrics.throughput).toBeGreaterThan(100); // > 100 ops/sec
            }

            // Opt-in summary table for local runs
            if (process.env.SHOW_BENCHMARKS) {
                runner.printSummary();
            }
        });

        it('should benchmark cached vs uncached provider creation', async () => {
            const mock = createMockProvider('success');
            (OpenAIService as any).mockImplementation(() => mock);

            // Benchmark uncached (first creation)
            const uncachedFactory = new ProviderFactory({
                enableCaching: false,
                enableHealthChecks: false
            });

            const uncachedMetrics = await runner.runBenchmark(
                'Uncached Creation',
                'openai',
                async () => {
                    await uncachedFactory.createProvider(ProviderType.OPENAI);
                },
                50
            );

            // Benchmark cached (hits the shared factory's cache after the first call)
            runner.reset();
            const cachedMetrics = await runner.runBenchmark(
                'Cached Creation',
                'openai',
                async () => {
                    await factory.createProvider(ProviderType.OPENAI);
                },
                50
            );

            // Cached should be significantly faster (at least 2x here)
            expect(cachedMetrics.latency).toBeLessThan(uncachedMetrics.latency! * 0.5);

            uncachedFactory.dispose();
        });
    });

    describe('Chat Completion Performance', () => {
        it('should benchmark chat completion latency', async () => {
            const messages: Message[] = [
                { role: 'user', content: 'Hello, how are you?' }
            ];

            const providers = ['openai', 'anthropic', 'ollama'] as const;

            for (const providerName of providers) {
                const mock = createMockProvider('success');
                mock.setResponseDelay(10); // Simulate 10ms response time

                switch (providerName) {
                    case 'openai':
                        (OpenAIService as any).mockImplementation(() => mock);
                        break;
                    case 'anthropic':
                        (AnthropicService as any).mockImplementation(() => mock);
                        break;
                    case 'ollama':
                        (OllamaService as any).mockImplementation(() => mock);
                        break;
                }

                const provider = await factory.createProvider(
                    ProviderType[providerName.toUpperCase() as keyof typeof ProviderType]
                );

                const metrics = await runner.runBenchmark(
                    'Chat Completion',
                    providerName,
                    async () => {
                        await provider.generateChatCompletion(messages);
                    },
                    20
                );

                // Latency should sit between the injected 10ms delay and a
                // generous 50ms ceiling for overhead
                expect(metrics.latency).toBeGreaterThan(10); // At least the mock delay
                expect(metrics.latency).toBeLessThan(50); // But not too slow
            }

            if (process.env.SHOW_BENCHMARKS) {
                runner.printSummary();
            }
        });

        it('should benchmark streaming vs non-streaming performance', async () => {
            const mock = mockFactory.createProvider('openai');
            mock.setResponseDelay(50);
            (OpenAIService as any).mockImplementation(() => mock);

            const provider = await factory.createProvider(ProviderType.OPENAI);
            const messages: Message[] = [
                { role: 'user', content: 'Tell me a story' }
            ];

            // Benchmark non-streaming
            const nonStreamMetrics = await runner.runBenchmark(
                'Non-Streaming',
                'openai',
                async () => {
                    await provider.generateChatCompletion(messages, {
                        stream: false
                    });
                },
                10
            );

            // Benchmark streaming (chunks collected only to drive the callback)
            runner.reset();
            const streamMetrics = await runner.runBenchmark(
                'Streaming',
                'openai',
                async () => {
                    const chunks: string[] = [];
                    await provider.generateChatCompletion(messages, {
                        stream: true,
                        streamCallback: async (chunk) => {
                            chunks.push(chunk);
                        }
                    });
                },
                10
            );

            // Streaming might have different characteristics; we only assert
            // both paths produced a measurement
            expect(streamMetrics.latency).toBeDefined();
            expect(nonStreamMetrics.latency).toBeDefined();
        });
    });

    describe('Concurrent Operations Performance', () => {
        it('should benchmark concurrent provider operations', async () => {
            const mock = createMockProvider('success');
            mock.setResponseDelay(5);
            (OpenAIService as any).mockImplementation(() => mock);

            const provider = await factory.createProvider(ProviderType.OPENAI);
            const messages: Message[] = [
                { role: 'user', content: 'Test' }
            ];

            // Sequential benchmark: awaits each call before starting the next
            const sequentialStart = performance.now();
            for (let i = 0; i < 10; i++) {
                await provider.generateChatCompletion(messages);
            }
            const sequentialDuration = performance.now() - sequentialStart;

            // Concurrent benchmark: all 10 calls in flight at once
            const concurrentStart = performance.now();
            await Promise.all(
                Array(10).fill(null).map(() =>
                    provider.generateChatCompletion(messages)
                )
            );
            const concurrentDuration = performance.now() - concurrentStart;

            // Concurrent should be faster since the mock delay overlaps
            expect(concurrentDuration).toBeLessThan(sequentialDuration);

            const speedup = sequentialDuration / concurrentDuration;
            expect(speedup).toBeGreaterThan(1.5); // At least 1.5x speedup
        });
    });

    describe('Memory Performance', () => {
        it('should benchmark memory usage with cache management', async () => {
            const mock = createMockProvider('success');
            (OpenAIService as any).mockImplementation(() => mock);
            (AnthropicService as any).mockImplementation(() => mock);
            (OllamaService as any).mockImplementation(() => mock);

            const startMemory = process.memoryUsage().heapUsed;

            // Create many providers (mostly cache hits after the first of each)
            for (let i = 0; i < 100; i++) {
                await factory.createProvider(ProviderType.OPENAI);
                await factory.createProvider(ProviderType.ANTHROPIC);
                await factory.createProvider(ProviderType.OLLAMA);
            }

            const midMemory = process.memoryUsage().heapUsed;
            const memoryGrowth = midMemory - startMemory;

            // Clear cache
            factory.clearCache();

            const endMemory = process.memoryUsage().heapUsed;
            const memoryReclaimed = midMemory - endMemory;

            // Should reclaim some memory
            // NOTE(review): heap readings without forcing GC can be noisy;
            // this assertion may be flaky on some runtimes — confirm in CI
            expect(memoryReclaimed).toBeGreaterThan(0);

            // Memory growth should be reasonable (< 50MB for 300 operations)
            expect(memoryGrowth).toBeLessThan(50 * 1024 * 1024);
        });
    });

    describe('Stream Processing Performance', () => {
        it('should benchmark stream chunk processing speed', async () => {
            const aggregator = new StreamAggregator();
            const handler = createStreamHandler({
                provider: 'openai',
                onChunk: (chunk) => aggregator.addChunk(chunk)
            });

            // 100 OpenAI-shaped delta chunks
            const chunks = Array(100).fill(null).map((_, i) => ({
                choices: [{
                    delta: { content: `Chunk ${i}` },
                    index: 0
                }]
            }));

            const metrics = await runner.runBenchmark(
                'Stream Processing',
                'openai',
                async () => {
                    aggregator.reset();
                    for (const chunk of chunks) {
                        await handler.processChunk(chunk);
                    }
                },
                10
            );

            // Should process chunks quickly: 100 chunks x 10 iterations
            const chunksPerSecond = (chunks.length * 10) / (metrics.duration / 1000);
            expect(chunksPerSecond).toBeGreaterThan(1000); // > 1000 chunks/sec
        });
    });

    describe('Health Check Performance', () => {
        it('should benchmark health check operations', async () => {
            const providers = [ProviderType.OPENAI, ProviderType.ANTHROPIC, ProviderType.OLLAMA];

            for (const providerType of providers) {
                const mock = createMockProvider('success');
                mock.setResponseDelay(20); // Simulate network latency

                switch (providerType) {
                    case ProviderType.OPENAI:
                        (OpenAIService as any).mockImplementation(() => mock);
                        break;
                    case ProviderType.ANTHROPIC:
                        (AnthropicService as any).mockImplementation(() => mock);
                        break;
                    case ProviderType.OLLAMA:
                        (OllamaService as any).mockImplementation(() => mock);
                        break;
                }

                const metrics = await runner.runBenchmark(
                    'Health Check',
                    providerType,
                    async () => {
                        await factory.checkProviderHealth(providerType);
                    },
                    10
                );

                // Health checks should complete reasonably quickly
                expect(metrics.latency).toBeLessThan(100); // < 100ms per check
            }
        });
    });

    describe('Fallback Performance', () => {
        it('should benchmark fallback provider switching', async () => {
            const fallbackFactory = new ProviderFactory({
                enableHealthChecks: false,
                enableFallback: true,
                fallbackProviders: [ProviderType.ANTHROPIC, ProviderType.OLLAMA],
                enableCaching: false
            });

            let attemptCount = 0;

            // OpenAI fails first 2 times, then succeeds
            (OpenAIService as any).mockImplementation(() => {
                attemptCount++;
                if (attemptCount <= 2) {
                    throw new Error('OpenAI unavailable');
                }
                return createMockProvider('success');
            });

            // Anthropic always fails, forcing the chain to continue
            (AnthropicService as any).mockImplementation(() => {
                throw new Error('Anthropic unavailable');
            });

            // Ollama succeeds as the final fallback
            const ollamaMock = createMockProvider('success');
            (OllamaService as any).mockImplementation(() => ollamaMock);

            const metrics = await runner.runBenchmark(
                'Fallback Switch',
                'multi',
                async () => {
                    attemptCount = 0; // each iteration re-exercises the full fallback chain
                    await fallbackFactory.createProvider(ProviderType.OPENAI);
                },
                10
            );

            // Fallback should add some overhead but still be reasonable
            expect(metrics.latency).toBeLessThan(50); // < 50ms including fallback

            fallbackFactory.dispose();
        });
    });

    // Only run this in CI or when explicitly requested (1000 concurrent calls)
    if (process.env.RUN_FULL_BENCHMARKS) {
        describe('Load Testing', () => {
            it('should handle high load scenarios', async () => {
                const mock = createMockProvider('success');
                mock.setResponseDelay(1);
                (OpenAIService as any).mockImplementation(() => mock);

                const provider = await factory.createProvider(ProviderType.OPENAI);
                const messages: Message[] = [{ role: 'user', content: 'Load test' }];

                const loadTestStart = performance.now();
                const promises = Array(1000).fill(null).map(() =>
                    provider.generateChatCompletion(messages)
                );

                await Promise.all(promises);
                const loadTestDuration = performance.now() - loadTestStart;

                const requestsPerSecond = 1000 / (loadTestDuration / 1000);

                // Should handle at least 100 requests per second
                expect(requestsPerSecond).toBeGreaterThan(100);

                console.log(`Load test: ${requestsPerSecond.toFixed(2)} requests/second`);
            });
        });
    }
});
|
||||
@@ -0,0 +1,434 @@
|
||||
/**
|
||||
* Provider Factory Tests
|
||||
*
|
||||
* Comprehensive test suite for the provider factory pattern implementation
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
|
||||
import {
|
||||
ProviderFactory,
|
||||
ProviderType,
|
||||
type ProviderCapabilities,
|
||||
type ProviderHealthStatus,
|
||||
getProviderFactory
|
||||
} from '../provider_factory.js';
|
||||
import { OpenAIService } from '../openai_service.js';
|
||||
import { AnthropicService } from '../anthropic_service.js';
|
||||
import { OllamaService } from '../ollama_service.js';
|
||||
import type { AIService, ChatResponse } from '../../ai_interface.js';
|
||||
|
||||
// Mock the services
|
||||
vi.mock('../openai_service.js');
|
||||
vi.mock('../anthropic_service.js');
|
||||
vi.mock('../ollama_service.js');
|
||||
vi.mock('../../log.js', () => ({
|
||||
default: {
|
||||
info: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn()
|
||||
}
|
||||
}));
|
||||
|
||||
describe('ProviderFactory', () => {
|
||||
let factory: ProviderFactory;
|
||||
|
||||
beforeEach(() => {
|
||||
// Clear any existing singleton
|
||||
const existingFactory = ProviderFactory.getInstance();
|
||||
if (existingFactory) {
|
||||
existingFactory.dispose();
|
||||
}
|
||||
|
||||
// Create new factory instance for testing
|
||||
factory = new ProviderFactory({
|
||||
enableHealthChecks: false, // Disable for tests
|
||||
enableMetrics: false,
|
||||
cacheTimeout: 1000 // Short timeout for tests
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
// Cleanup
|
||||
factory.dispose();
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('Provider Creation', () => {
|
||||
it('should create OpenAI provider', async () => {
|
||||
// Mock OpenAI service
|
||||
const mockService = {
|
||||
isAvailable: vi.fn().mockReturnValue(true),
|
||||
generateChatCompletion: vi.fn().mockResolvedValue({
|
||||
content: 'test response',
|
||||
role: 'assistant'
|
||||
})
|
||||
};
|
||||
|
||||
(OpenAIService as any).mockImplementation(() => mockService);
|
||||
|
||||
const service = await factory.createProvider(ProviderType.OPENAI);
|
||||
|
||||
expect(service).toBeDefined();
|
||||
expect(mockService.isAvailable).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should create Anthropic provider', async () => {
|
||||
// Mock Anthropic service
|
||||
const mockService = {
|
||||
isAvailable: vi.fn().mockReturnValue(true),
|
||||
generateChatCompletion: vi.fn().mockResolvedValue({
|
||||
content: 'test response',
|
||||
role: 'assistant'
|
||||
})
|
||||
};
|
||||
|
||||
(AnthropicService as any).mockImplementation(() => mockService);
|
||||
|
||||
const service = await factory.createProvider(ProviderType.ANTHROPIC);
|
||||
|
||||
expect(service).toBeDefined();
|
||||
expect(mockService.isAvailable).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should create Ollama provider', async () => {
|
||||
// Mock Ollama service
|
||||
const mockService = {
|
||||
isAvailable: vi.fn().mockReturnValue(true),
|
||||
generateChatCompletion: vi.fn().mockResolvedValue({
|
||||
content: 'test response',
|
||||
role: 'assistant'
|
||||
})
|
||||
};
|
||||
|
||||
(OllamaService as any).mockImplementation(() => mockService);
|
||||
|
||||
const service = await factory.createProvider(ProviderType.OLLAMA);
|
||||
|
||||
expect(service).toBeDefined();
|
||||
expect(mockService.isAvailable).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should throw error for unavailable provider', async () => {
|
||||
// Mock service as unavailable
|
||||
const mockService = {
|
||||
isAvailable: vi.fn().mockReturnValue(false)
|
||||
};
|
||||
|
||||
(OpenAIService as any).mockImplementation(() => mockService);
|
||||
|
||||
await expect(factory.createProvider(ProviderType.OPENAI))
|
||||
.rejects.toThrow('OpenAI service is not available');
|
||||
});
|
||||
|
||||
it('should throw error for custom provider (not implemented)', async () => {
|
||||
await expect(factory.createProvider(ProviderType.CUSTOM))
|
||||
.rejects.toThrow('Custom providers not yet implemented');
|
||||
});
|
||||
});
|
||||
|
||||
// Caching behavior of the factory: identical requests share one instance,
// entries expire after the configured cacheTimeout, and distinct configs
// are cached under distinct keys.
describe('Provider Caching', () => {
  it('should cache created providers', async () => {
    const mockService = {
      isAvailable: vi.fn().mockReturnValue(true),
      generateChatCompletion: vi.fn()
    };

    (OpenAIService as any).mockImplementation(() => mockService);

    const service1 = await factory.createProvider(ProviderType.OPENAI);
    const service2 = await factory.createProvider(ProviderType.OPENAI);

    // Should return same instance
    expect(service1).toBe(service2);

    // Constructor should only be called once
    expect(OpenAIService).toHaveBeenCalledTimes(1);
  });

  it('should respect cache timeout', async () => {
    const mockService = {
      isAvailable: vi.fn().mockReturnValue(true),
      generateChatCompletion: vi.fn()
    };

    (OpenAIService as any).mockImplementation(() => mockService);

    const service1 = await factory.createProvider(ProviderType.OPENAI);

    // Wait for cache to expire
    // NOTE(review): assumes the factory built in this suite's beforeEach uses
    // a cacheTimeout below 1100ms — confirm against the setup block.
    await new Promise(resolve => setTimeout(resolve, 1100));

    const service2 = await factory.createProvider(ProviderType.OPENAI);

    // Should create new instance after timeout
    expect(service1).not.toBe(service2);
    expect(OpenAIService).toHaveBeenCalledTimes(2);
  });

  it('should cache providers with different configurations separately', async () => {
    const mockService1 = {
      isAvailable: vi.fn().mockReturnValue(true)
    };
    const mockService2 = {
      isAvailable: vi.fn().mockReturnValue(true)
    };

    // Hand out a different mock per construction so instance identity
    // distinguishes the two cache entries.
    let callCount = 0;
    (OpenAIService as any).mockImplementation(() => {
      callCount++;
      return callCount === 1 ? mockService1 : mockService2;
    });

    const service1 = await factory.createProvider(ProviderType.OPENAI, { baseUrl: 'url1' });
    const service2 = await factory.createProvider(ProviderType.OPENAI, { baseUrl: 'url2' });

    expect(service1).not.toBe(service2);
    expect(OpenAIService).toHaveBeenCalledTimes(2);
  });
});
|
||||
|
||||
describe('Capabilities Detection', () => {
|
||||
it('should return default capabilities for providers', () => {
|
||||
const openAICaps = factory.getCapabilities(ProviderType.OPENAI);
|
||||
|
||||
expect(openAICaps).toBeDefined();
|
||||
expect(openAICaps?.streaming).toBe(true);
|
||||
expect(openAICaps?.functionCalling).toBe(true);
|
||||
expect(openAICaps?.vision).toBe(true);
|
||||
expect(openAICaps?.contextWindow).toBe(128000);
|
||||
});
|
||||
|
||||
it('should allow registering custom capabilities', () => {
|
||||
const customCaps: ProviderCapabilities = {
|
||||
streaming: false,
|
||||
functionCalling: false,
|
||||
vision: false,
|
||||
contextWindow: 2048,
|
||||
maxOutputTokens: 512,
|
||||
supportsSystemPrompt: false,
|
||||
supportsTools: false,
|
||||
supportedModalities: ['text'],
|
||||
customEndpoints: true,
|
||||
batchProcessing: false
|
||||
};
|
||||
|
||||
factory.registerCapabilities(ProviderType.CUSTOM, customCaps);
|
||||
|
||||
const retrieved = factory.getCapabilities(ProviderType.CUSTOM);
|
||||
expect(retrieved).toEqual(customCaps);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Health Checks', () => {
|
||||
it('should perform health check on provider', async () => {
|
||||
const mockService = {
|
||||
isAvailable: vi.fn().mockReturnValue(true),
|
||||
generateChatCompletion: vi.fn().mockResolvedValue({
|
||||
content: 'Hi',
|
||||
role: 'assistant'
|
||||
})
|
||||
};
|
||||
|
||||
(OpenAIService as any).mockImplementation(() => mockService);
|
||||
|
||||
const health = await factory.checkProviderHealth(ProviderType.OPENAI);
|
||||
|
||||
expect(health.provider).toBe(ProviderType.OPENAI);
|
||||
expect(health.healthy).toBe(true);
|
||||
expect(health.lastChecked).toBeInstanceOf(Date);
|
||||
expect(health.latency).toBeDefined();
|
||||
});
|
||||
|
||||
it('should report unhealthy provider on error', async () => {
|
||||
const mockService = {
|
||||
isAvailable: vi.fn().mockReturnValue(true),
|
||||
generateChatCompletion: vi.fn().mockRejectedValue(new Error('API Error'))
|
||||
};
|
||||
|
||||
(OpenAIService as any).mockImplementation(() => mockService);
|
||||
|
||||
const health = await factory.checkProviderHealth(ProviderType.OPENAI);
|
||||
|
||||
expect(health.provider).toBe(ProviderType.OPENAI);
|
||||
expect(health.healthy).toBe(false);
|
||||
expect(health.error).toBe('API Error');
|
||||
});
|
||||
|
||||
it('should store health status', async () => {
|
||||
const mockService = {
|
||||
isAvailable: vi.fn().mockReturnValue(true),
|
||||
generateChatCompletion: vi.fn().mockResolvedValue({
|
||||
content: 'Hi',
|
||||
role: 'assistant'
|
||||
})
|
||||
};
|
||||
|
||||
(OpenAIService as any).mockImplementation(() => mockService);
|
||||
|
||||
await factory.checkProviderHealth(ProviderType.OPENAI);
|
||||
|
||||
const status = factory.getHealthStatus(ProviderType.OPENAI);
|
||||
expect(status).toBeDefined();
|
||||
expect(status?.healthy).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Fallback Mechanism', () => {
|
||||
it('should fallback to alternative provider on failure', async () => {
|
||||
// Create factory with fallback enabled
|
||||
const fallbackFactory = new ProviderFactory({
|
||||
enableHealthChecks: false,
|
||||
enableFallback: true,
|
||||
fallbackProviders: [ProviderType.OLLAMA],
|
||||
enableCaching: false
|
||||
});
|
||||
|
||||
// Mock OpenAI to fail
|
||||
(OpenAIService as any).mockImplementation(() => {
|
||||
throw new Error('OpenAI unavailable');
|
||||
});
|
||||
|
||||
// Mock Ollama to succeed
|
||||
const mockOllamaService = {
|
||||
isAvailable: vi.fn().mockReturnValue(true),
|
||||
generateChatCompletion: vi.fn()
|
||||
};
|
||||
(OllamaService as any).mockImplementation(() => mockOllamaService);
|
||||
|
||||
// Should fallback to Ollama
|
||||
const service = await fallbackFactory.createProvider(ProviderType.OPENAI);
|
||||
|
||||
expect(service).toBeDefined();
|
||||
expect(OllamaService).toHaveBeenCalled();
|
||||
|
||||
fallbackFactory.dispose();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Statistics', () => {
|
||||
it('should track usage statistics', async () => {
|
||||
const mockService = {
|
||||
isAvailable: vi.fn().mockReturnValue(true),
|
||||
generateChatCompletion: vi.fn()
|
||||
};
|
||||
|
||||
(OpenAIService as any).mockImplementation(() => mockService);
|
||||
|
||||
// Create providers
|
||||
await factory.createProvider(ProviderType.OPENAI);
|
||||
await factory.createProvider(ProviderType.OPENAI); // Uses cache
|
||||
|
||||
const stats = factory.getStatistics();
|
||||
|
||||
expect(stats.cachedProviders).toBe(1);
|
||||
expect(stats.totalUsage).toBe(2); // Created once, used twice
|
||||
expect(stats.providerUsage['openai']).toBe(2);
|
||||
});
|
||||
});
|
||||
|
||||
// Cache lifecycle: explicit clearing and expiry-based cleanup, both of which
// must dispose the evicted provider instances.
describe('Cache Management', () => {
  it('should clear all cached providers', async () => {
    // Single shared mock so dispose() call counts can be asserted in one spot.
    const mockService = {
      isAvailable: vi.fn().mockReturnValue(true),
      dispose: vi.fn()
    };

    (OpenAIService as any).mockImplementation(() => mockService);
    (AnthropicService as any).mockImplementation(() => mockService);

    // Create multiple providers
    await factory.createProvider(ProviderType.OPENAI);
    await factory.createProvider(ProviderType.ANTHROPIC);

    const statsBefore = factory.getStatistics();
    expect(statsBefore.cachedProviders).toBe(2);

    factory.clearCache();

    const statsAfter = factory.getStatistics();
    expect(statsAfter.cachedProviders).toBe(0);
    // Both cached entries point at the same mock, so dispose fires twice.
    expect(mockService.dispose).toHaveBeenCalledTimes(2);
  });

  it('should cleanup expired cache entries', async () => {
    const mockService = {
      isAvailable: vi.fn().mockReturnValue(true),
      dispose: vi.fn()
    };

    (OpenAIService as any).mockImplementation(() => mockService);

    await factory.createProvider(ProviderType.OPENAI);

    // Wait for cache to expire
    // NOTE(review): assumes this suite's factory uses a cacheTimeout below
    // 1100ms — confirm against the setup block.
    await new Promise(resolve => setTimeout(resolve, 1100));

    factory.cleanupExpiredCache();

    const stats = factory.getStatistics();
    expect(stats.cachedProviders).toBe(0);
    expect(mockService.dispose).toHaveBeenCalled();
  });
});
|
||||
|
||||
describe('Singleton Pattern', () => {
|
||||
it('should return same instance via getInstance', () => {
|
||||
const instance1 = ProviderFactory.getInstance();
|
||||
const instance2 = ProviderFactory.getInstance();
|
||||
|
||||
expect(instance1).toBe(instance2);
|
||||
|
||||
instance1.dispose();
|
||||
});
|
||||
|
||||
it('should create new instance after disposal', () => {
|
||||
const instance1 = ProviderFactory.getInstance();
|
||||
instance1.dispose();
|
||||
|
||||
const instance2 = ProviderFactory.getInstance();
|
||||
|
||||
expect(instance1).not.toBe(instance2);
|
||||
|
||||
instance2.dispose();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Error Handling', () => {
|
||||
it('should handle provider creation errors gracefully', async () => {
|
||||
(OpenAIService as any).mockImplementation(() => {
|
||||
throw new Error('Constructor error');
|
||||
});
|
||||
|
||||
await expect(factory.createProvider(ProviderType.OPENAI))
|
||||
.rejects.toThrow('Constructor error');
|
||||
});
|
||||
|
||||
it('should throw error when factory is disposed', async () => {
|
||||
factory.dispose();
|
||||
|
||||
await expect(factory.createProvider(ProviderType.OPENAI))
|
||||
.rejects.toThrow('ProviderFactory has been disposed');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('getProviderFactory Helper', () => {
|
||||
it('should return factory instance', () => {
|
||||
const factory = getProviderFactory();
|
||||
|
||||
expect(factory).toBeInstanceOf(ProviderFactory);
|
||||
|
||||
factory.dispose();
|
||||
});
|
||||
|
||||
it('should pass options to factory', () => {
|
||||
const factory = getProviderFactory({
|
||||
enableHealthChecks: false,
|
||||
enableMetrics: false
|
||||
});
|
||||
|
||||
expect(factory).toBeInstanceOf(ProviderFactory);
|
||||
|
||||
factory.dispose();
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,418 @@
|
||||
/**
|
||||
* Provider Integration Tests
|
||||
*
|
||||
* Integration tests for provider factory with AI Service Manager
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
|
||||
import { ProviderFactory, ProviderType } from '../provider_factory.js';
|
||||
import {
|
||||
MockProviderFactory,
|
||||
MockProvider,
|
||||
createMockProvider,
|
||||
createMockStream
|
||||
} from './mock_providers.js';
|
||||
import type { AIService, ChatCompletionOptions } from '../../ai_interface.js';
|
||||
import {
|
||||
UnifiedStreamChunk,
|
||||
StreamAggregator,
|
||||
createStreamHandler
|
||||
} from '../unified_stream_handler.js';
|
||||
|
||||
// Mock the actual provider imports.
// vi.mock calls are hoisted by vitest ahead of the imports below, so each
// provider class resolves to a bare vi.fn(); individual tests install the
// behavior they need via mockImplementation().
vi.mock('../openai_service.js', () => ({
  OpenAIService: vi.fn()
}));
vi.mock('../anthropic_service.js', () => ({
  AnthropicService: vi.fn()
}));
vi.mock('../ollama_service.js', () => ({
  OllamaService: vi.fn()
}));
|
||||
|
||||
// Import mocked modules
|
||||
import { OpenAIService } from '../openai_service.js';
|
||||
import { AnthropicService } from '../anthropic_service.js';
|
||||
import { OllamaService } from '../ollama_service.js';
|
||||
|
||||
describe('Provider Factory Integration', () => {
|
||||
let factory: ProviderFactory;
|
||||
let mockFactory: MockProviderFactory;
|
||||
|
||||
beforeEach(() => {
  // Clear singleton: an instance left over from a previous test would leak
  // its caches/timers into this one, so dispose it before building fresh.
  const existing = ProviderFactory.getInstance();
  if (existing) {
    existing.dispose();
  }

  factory = new ProviderFactory({
    enableHealthChecks: false, // no background health timers during tests
    enableMetrics: true,
    cacheTimeout: 5000
  });

  mockFactory = new MockProviderFactory();
});

afterEach(() => {
  // Tear down both factories and reset every vi.fn() between tests.
  factory.dispose();
  mockFactory.disposeAll();
  vi.clearAllMocks();
});
|
||||
|
||||
describe('Multi-Provider Management', () => {
|
||||
it('should manage multiple providers simultaneously', async () => {
|
||||
// Setup mock providers
|
||||
const openaiMock = mockFactory.createProvider('openai');
|
||||
const anthropicMock = mockFactory.createProvider('anthropic');
|
||||
const ollamaMock = mockFactory.createProvider('ollama');
|
||||
|
||||
(OpenAIService as any).mockImplementation(() => openaiMock);
|
||||
(AnthropicService as any).mockImplementation(() => anthropicMock);
|
||||
(OllamaService as any).mockImplementation(() => ollamaMock);
|
||||
|
||||
// Create providers
|
||||
const openai = await factory.createProvider(ProviderType.OPENAI);
|
||||
const anthropic = await factory.createProvider(ProviderType.ANTHROPIC);
|
||||
const ollama = await factory.createProvider(ProviderType.OLLAMA);
|
||||
|
||||
// Test all are available
|
||||
expect(openai.isAvailable()).toBe(true);
|
||||
expect(anthropic.isAvailable()).toBe(true);
|
||||
expect(ollama.isAvailable()).toBe(true);
|
||||
|
||||
// Test statistics
|
||||
const stats = factory.getStatistics();
|
||||
expect(stats.cachedProviders).toBe(3);
|
||||
});
|
||||
|
||||
it('should handle provider-specific configurations', async () => {
|
||||
const customConfig = {
|
||||
baseUrl: 'https://custom.api.endpoint',
|
||||
timeout: 30000
|
||||
};
|
||||
|
||||
const mock = mockFactory.createProvider('openai');
|
||||
(OpenAIService as any).mockImplementation(() => mock);
|
||||
|
||||
const provider1 = await factory.createProvider(ProviderType.OPENAI, customConfig);
|
||||
const provider2 = await factory.createProvider(ProviderType.OPENAI); // Different config
|
||||
|
||||
// Should create two separate instances
|
||||
const stats = factory.getStatistics();
|
||||
expect(stats.cachedProviders).toBe(2);
|
||||
});
|
||||
});
|
||||
|
||||
// Fallback chains: walk the configured fallbackProviders list on failure, and
// surface an error when every candidate fails.
describe('Fallback Scenarios', () => {
  it('should fallback through provider chain on failures', async () => {
    const failingFactory = new ProviderFactory({
      enableHealthChecks: false,
      enableFallback: true,
      fallbackProviders: [ProviderType.ANTHROPIC, ProviderType.OLLAMA],
      enableCaching: false
    });

    // OpenAI fails
    (OpenAIService as any).mockImplementation(() => {
      throw new Error('OpenAI unavailable');
    });

    // Anthropic fails
    (AnthropicService as any).mockImplementation(() => {
      throw new Error('Anthropic unavailable');
    });

    // Ollama succeeds — the last link of the chain.
    const ollamaMock = mockFactory.createProvider('ollama');
    (OllamaService as any).mockImplementation(() => ollamaMock);

    const provider = await failingFactory.createProvider(ProviderType.OPENAI);

    expect(provider).toBeDefined();
    expect(provider.isAvailable()).toBe(true);
    expect(OllamaService).toHaveBeenCalled();

    failingFactory.dispose();
  });

  it('should handle complete fallback failure', async () => {
    const failingFactory = new ProviderFactory({
      enableHealthChecks: false,
      enableFallback: true,
      fallbackProviders: [ProviderType.ANTHROPIC],
      enableCaching: false
    });

    // All providers fail
    (OpenAIService as any).mockImplementation(() => {
      throw new Error('OpenAI unavailable');
    });
    (AnthropicService as any).mockImplementation(() => {
      throw new Error('Anthropic unavailable');
    });

    // NOTE(review): this expects the ORIGINAL provider's error to propagate,
    // not the last fallback's — presumably intentional; confirm against the
    // factory's fallback implementation.
    await expect(failingFactory.createProvider(ProviderType.OPENAI))
      .rejects.toThrow('OpenAI unavailable');

    failingFactory.dispose();
  });
});
|
||||
|
||||
describe('Health Monitoring', () => {
|
||||
it('should perform health checks across all providers', async () => {
|
||||
// Setup healthy providers
|
||||
const openaiMock = createMockProvider('success');
|
||||
const anthropicMock = createMockProvider('success');
|
||||
const ollamaMock = createMockProvider('success');
|
||||
|
||||
(OpenAIService as any).mockImplementation(() => openaiMock);
|
||||
(AnthropicService as any).mockImplementation(() => anthropicMock);
|
||||
(OllamaService as any).mockImplementation(() => ollamaMock);
|
||||
|
||||
// Perform health checks
|
||||
const openaiHealth = await factory.checkProviderHealth(ProviderType.OPENAI);
|
||||
const anthropicHealth = await factory.checkProviderHealth(ProviderType.ANTHROPIC);
|
||||
const ollamaHealth = await factory.checkProviderHealth(ProviderType.OLLAMA);
|
||||
|
||||
expect(openaiHealth.healthy).toBe(true);
|
||||
expect(anthropicHealth.healthy).toBe(true);
|
||||
expect(ollamaHealth.healthy).toBe(true);
|
||||
|
||||
// Check all statuses
|
||||
const allStatuses = factory.getAllHealthStatuses();
|
||||
expect(allStatuses.size).toBe(3);
|
||||
});
|
||||
|
||||
it('should detect unhealthy providers', async () => {
|
||||
const errorMock = createMockProvider('error');
|
||||
(OpenAIService as any).mockImplementation(() => errorMock);
|
||||
|
||||
const health = await factory.checkProviderHealth(ProviderType.OPENAI);
|
||||
|
||||
expect(health.healthy).toBe(false);
|
||||
expect(health.error).toBeDefined();
|
||||
});
|
||||
|
||||
it('should measure provider latency', async () => {
|
||||
const slowMock = createMockProvider('slow');
|
||||
slowMock.setResponseDelay(100);
|
||||
(OpenAIService as any).mockImplementation(() => slowMock);
|
||||
|
||||
const health = await factory.checkProviderHealth(ProviderType.OPENAI);
|
||||
|
||||
expect(health.latency).toBeGreaterThan(100);
|
||||
});
|
||||
});
|
||||
|
||||
// End-to-end streaming through a factory-created provider, plus the unified
// handler layer normalizing each vendor's wire format.
describe('Streaming Integration', () => {
  it('should handle streaming across providers', async () => {
    const mock = mockFactory.createProvider('openai');
    (OpenAIService as any).mockImplementation(() => mock);

    const provider = await factory.createProvider(ProviderType.OPENAI);

    const messages = [{ role: 'user' as const, content: 'Hello' }];
    const chunks: string[] = [];

    const response = await provider.generateChatCompletion(messages, {
      stream: true,
      streamCallback: async (chunk, isDone) => {
        // Skip the terminal callback — only content chunks are collected.
        if (!isDone) {
          chunks.push(chunk);
        }
      }
    });

    // The final response text must be exactly the concatenated stream.
    expect(chunks.length).toBeGreaterThan(0);
    expect(response.text).toBe(chunks.join(''));
  });

  it('should unify streaming formats', async () => {
    const aggregator = new StreamAggregator();

    // Test OpenAI format (plain delta object)
    const openaiHandler = createStreamHandler({
      provider: 'openai',
      onChunk: (chunk) => aggregator.addChunk(chunk)
    });

    await openaiHandler.processChunk({
      choices: [{
        delta: { content: 'Hello from OpenAI' }
      }]
    });

    // Test Anthropic format (SSE event text)
    aggregator.reset();
    const anthropicHandler = createStreamHandler({
      provider: 'anthropic',
      onChunk: (chunk) => aggregator.addChunk(chunk)
    });

    await anthropicHandler.processChunk(
      'event: content_block_delta\ndata: {"delta":{"type":"text_delta","text":"Hello from Anthropic"}}'
    );

    // Test Ollama format (message object with done flag)
    aggregator.reset();
    const ollamaHandler = createStreamHandler({
      provider: 'ollama',
      onChunk: (chunk) => aggregator.addChunk(chunk)
    });

    await ollamaHandler.processChunk({
      message: { content: 'Hello from Ollama' },
      done: false
    });

    // All should produce similar unified format
    // NOTE(review): reset() discards the OpenAI and Anthropic output before
    // it is asserted — only the Ollama pass is actually checked here. The
    // earlier formats are exercised for crashes only; consider asserting
    // the aggregator content after each pass.
    const response = aggregator.getResponse();
    expect(response.text).toContain('Hello from Ollama');
  });
});
|
||||
|
||||
describe('Performance and Caching', () => {
|
||||
it('should cache providers efficiently', async () => {
|
||||
const mock = mockFactory.createProvider('openai');
|
||||
(OpenAIService as any).mockImplementation(() => mock);
|
||||
|
||||
const startTime = Date.now();
|
||||
|
||||
// First call - creates provider
|
||||
await factory.createProvider(ProviderType.OPENAI);
|
||||
const firstCallTime = Date.now() - startTime;
|
||||
|
||||
// Second call - uses cache
|
||||
const cachedStartTime = Date.now();
|
||||
await factory.createProvider(ProviderType.OPENAI);
|
||||
const cachedCallTime = Date.now() - cachedStartTime;
|
||||
|
||||
// Cached call should be much faster
|
||||
expect(cachedCallTime).toBeLessThan(firstCallTime);
|
||||
expect(OpenAIService).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should track usage statistics', async () => {
|
||||
const mock = mockFactory.createProvider('openai');
|
||||
(OpenAIService as any).mockImplementation(() => mock);
|
||||
|
||||
// Create and use provider multiple times
|
||||
for (let i = 0; i < 5; i++) {
|
||||
await factory.createProvider(ProviderType.OPENAI);
|
||||
}
|
||||
|
||||
const stats = factory.getStatistics();
|
||||
expect(stats.totalUsage).toBe(5);
|
||||
expect(stats.providerUsage['openai']).toBe(5);
|
||||
});
|
||||
|
||||
it('should cleanup expired cache automatically', async () => {
|
||||
const shortCacheFactory = new ProviderFactory({
|
||||
enableHealthChecks: false,
|
||||
cacheTimeout: 100
|
||||
});
|
||||
|
||||
const mock = mockFactory.createProvider('openai');
|
||||
(OpenAIService as any).mockImplementation(() => mock);
|
||||
|
||||
await shortCacheFactory.createProvider(ProviderType.OPENAI);
|
||||
|
||||
let stats = shortCacheFactory.getStatistics();
|
||||
expect(stats.cachedProviders).toBe(1);
|
||||
|
||||
// Wait for cache to expire
|
||||
await new Promise(resolve => setTimeout(resolve, 150));
|
||||
|
||||
shortCacheFactory.cleanupExpiredCache();
|
||||
|
||||
stats = shortCacheFactory.getStatistics();
|
||||
expect(stats.cachedProviders).toBe(0);
|
||||
|
||||
shortCacheFactory.dispose();
|
||||
});
|
||||
});
|
||||
|
||||
// Recovery behavior: intermittent failures don't poison the provider, and
// eviction disposes instances cleanly.
describe('Error Recovery', () => {
  it('should recover from transient errors', async () => {
    // NOTE(review): assumes the 'flaky' mock fails deterministically on some
    // fixed subset of 10 calls (e.g. alternating); if it is random this test
    // can intermittently fail — confirm against mock_providers.ts.
    const flakyMock = createMockProvider('flaky');
    (OpenAIService as any).mockImplementation(() => flakyMock);

    const provider = await factory.createProvider(ProviderType.OPENAI);

    let successCount = 0;
    let errorCount = 0;

    // Try multiple requests; failures must not prevent later successes.
    for (let i = 0; i < 10; i++) {
      try {
        await provider.generateChatCompletion([
          { role: 'user', content: 'Test' }
        ]);
        successCount++;
      } catch (error) {
        errorCount++;
      }
    }

    // Should have some successes and some failures
    expect(successCount).toBeGreaterThan(0);
    expect(errorCount).toBeGreaterThan(0);
  });

  it('should handle provider disposal gracefully', async () => {
    const mock = mockFactory.createProvider('openai');
    // Spy on dispose so eviction can be observed.
    mock.dispose = vi.fn();

    (OpenAIService as any).mockImplementation(() => mock);

    await factory.createProvider(ProviderType.OPENAI);

    factory.clearCache();

    expect(mock.dispose).toHaveBeenCalled();
  });
});
|
||||
|
||||
describe('Concurrent Operations', () => {
|
||||
it('should handle concurrent provider creation', async () => {
|
||||
const mock = mockFactory.createProvider('openai');
|
||||
(OpenAIService as any).mockImplementation(() => mock);
|
||||
|
||||
// Create multiple providers concurrently
|
||||
const promises = Array(10).fill(null).map(() =>
|
||||
factory.createProvider(ProviderType.OPENAI)
|
||||
);
|
||||
|
||||
const providers = await Promise.all(promises);
|
||||
|
||||
// All should get the same cached instance
|
||||
const firstProvider = providers[0];
|
||||
expect(providers.every(p => p === firstProvider)).toBe(true);
|
||||
|
||||
// Constructor should only be called once
|
||||
expect(OpenAIService).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should handle concurrent health checks', async () => {
|
||||
const openaiMock = createMockProvider('success');
|
||||
const anthropicMock = createMockProvider('success');
|
||||
const ollamaMock = createMockProvider('success');
|
||||
|
||||
(OpenAIService as any).mockImplementation(() => openaiMock);
|
||||
(AnthropicService as any).mockImplementation(() => anthropicMock);
|
||||
(OllamaService as any).mockImplementation(() => ollamaMock);
|
||||
|
||||
// Perform health checks concurrently
|
||||
const healthChecks = await Promise.all([
|
||||
factory.checkProviderHealth(ProviderType.OPENAI),
|
||||
factory.checkProviderHealth(ProviderType.ANTHROPIC),
|
||||
factory.checkProviderHealth(ProviderType.OLLAMA)
|
||||
]);
|
||||
|
||||
expect(healthChecks).toHaveLength(3);
|
||||
expect(healthChecks.every(h => h.healthy)).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,577 @@
|
||||
/**
|
||||
* Unified Stream Handler Tests
|
||||
*
|
||||
* Test suite for the unified streaming interface
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, vi } from 'vitest';
|
||||
import {
|
||||
UnifiedStreamChunk,
|
||||
StreamHandlerConfig,
|
||||
OpenAIStreamHandler,
|
||||
AnthropicStreamHandler,
|
||||
OllamaStreamHandler,
|
||||
createStreamHandler,
|
||||
StreamAggregator,
|
||||
unifiedStream
|
||||
} from '../unified_stream_handler.js';
|
||||
import type { ChatResponse } from '../../ai_interface.js';
|
||||
|
||||
// Replace the logger so handler internals don't write to the real log;
// vi.mock is hoisted ahead of the module imports by vitest.
vi.mock('../../log.js', () => ({
  default: {
    info: vi.fn(),
    warn: vi.fn(),
    error: vi.fn()
  }
}));
|
||||
|
||||
// OpenAI stream handler: parses delta chunks (raw JSON or SSE-framed),
// accumulates tool-call arguments across chunks, and folds usage into the
// final ChatResponse.
describe('OpenAIStreamHandler', () => {
  let handler: OpenAIStreamHandler;
  let chunks: UnifiedStreamChunk[];
  let config: StreamHandlerConfig;

  beforeEach(() => {
    // Fresh chunk sink and spy callbacks for every test.
    chunks = [];
    config = {
      provider: 'openai',
      onChunk: (chunk) => { chunks.push(chunk); },
      onError: vi.fn(),
      onComplete: vi.fn()
    };
    handler = new OpenAIStreamHandler(config);
  });

  describe('Content Streaming', () => {
    it('should process content chunks', async () => {
      const chunk = {
        choices: [{
          delta: { content: 'Hello' },
          index: 0
        }],
        model: 'gpt-4'
      };

      await handler.processChunk(JSON.stringify(chunk));

      expect(chunks).toHaveLength(1);
      expect(chunks[0]).toEqual({
        type: 'content',
        content: 'Hello',
        metadata: {
          provider: 'openai',
          model: 'gpt-4'
        }
      });
    });

    it('should handle multiple content chunks', async () => {
      const chunk1 = {
        choices: [{
          delta: { content: 'Hello' }
        }]
      };
      const chunk2 = {
        choices: [{
          delta: { content: ' World' }
        }]
      };

      await handler.processChunk(JSON.stringify(chunk1));
      await handler.processChunk(JSON.stringify(chunk2));

      // Chunks are emitted individually, not merged.
      expect(chunks).toHaveLength(2);
      expect(chunks[0].content).toBe('Hello');
      expect(chunks[1].content).toBe(' World');
    });

    it('should handle SSE format', async () => {
      // Same payload but wrapped in the "data: " server-sent-events frame.
      const sseChunk = 'data: {"choices":[{"delta":{"content":"Test"}}]}';

      await handler.processChunk(sseChunk);

      expect(chunks).toHaveLength(1);
      expect(chunks[0].content).toBe('Test');
    });

    it('should handle [DONE] marker', async () => {
      // OpenAI terminates SSE streams with a literal [DONE] sentinel.
      await handler.processChunk('data: [DONE]');

      expect(chunks).toHaveLength(1);
      expect(chunks[0].type).toBe('done');
    });
  });

  describe('Tool Calls', () => {
    it('should process tool call chunks', async () => {
      const chunk = {
        choices: [{
          delta: {
            tool_calls: [{
              index: 0,
              id: 'call_123',
              function: {
                name: 'get_weather',
                arguments: '{"location":'
              }
            }]
          }
        }]
      };

      await handler.processChunk(JSON.stringify(chunk));

      expect(chunks).toHaveLength(1);
      expect(chunks[0]).toEqual({
        type: 'tool_call',
        toolCall: {
          id: 'call_123',
          name: 'get_weather',
          arguments: '{"location":'
        },
        metadata: {
          provider: 'openai'
        }
      });
    });

    it('should accumulate tool call arguments', async () => {
      // OpenAI streams a tool call's JSON arguments split across chunks,
      // correlated by index; later chunks omit id/name.
      const chunk1 = {
        choices: [{
          delta: {
            tool_calls: [{
              index: 0,
              id: 'call_123',
              function: {
                name: 'get_weather',
                arguments: '{"location":'
              }
            }]
          }
        }]
      };

      const chunk2 = {
        choices: [{
          delta: {
            tool_calls: [{
              index: 0,
              function: {
                arguments: '"New York"}'
              }
            }]
          }
        }]
      };

      await handler.processChunk(JSON.stringify(chunk1));
      await handler.processChunk(JSON.stringify(chunk2));

      // The second emitted chunk carries the fully concatenated arguments.
      expect(chunks).toHaveLength(2);
      expect(chunks[1].toolCall?.arguments).toBe('{"location":"New York"}');
    });
  });

  describe('Completion', () => {
    it('should handle finish reason', async () => {
      const chunk = {
        choices: [{
          delta: { content: 'Done' },
          finish_reason: 'stop'
        }],
        usage: {
          prompt_tokens: 10,
          completion_tokens: 5,
          total_tokens: 15
        }
      };

      await handler.processChunk(JSON.stringify(chunk));

      const response = await handler.complete();

      expect(response.text).toBe('Done');
      // finishReason is not directly on ChatResponse anymore
      // Usage is mapped from snake_case wire names to camelCase.
      expect(response.usage).toEqual({
        promptTokens: 10,
        completionTokens: 5,
        totalTokens: 15
      });
    });

    it('should call onComplete callback', async () => {
      await handler.processChunk('data: [DONE]');

      const response = await handler.complete();

      // complete() must hand the assembled response to the callback.
      expect(config.onComplete).toHaveBeenCalledWith(response);
    });
  });

  describe('Error Handling', () => {
    it('should handle parse errors', async () => {
      await handler.processChunk('invalid json');

      // Malformed input reports via onError AND emits an error-typed chunk.
      expect(config.onError).toHaveBeenCalled();
      expect(chunks.find(c => c.type === 'error')).toBeDefined();
    });

    it('should handle timeout', async () => {
      const timeoutConfig = { ...config, timeout: 100 };
      // Constructed only for its timer side effect; the spread shares the
      // same onError spy as `config`, which is what gets asserted below.
      const timeoutHandler = new OpenAIStreamHandler(timeoutConfig);

      await new Promise(resolve => setTimeout(resolve, 150));

      expect(config.onError).toHaveBeenCalledWith(
        expect.objectContaining({
          message: expect.stringContaining('timeout')
        })
      );
    });
  });
});
|
||||
|
||||
describe('AnthropicStreamHandler', () => {
|
||||
let handler: AnthropicStreamHandler;
|
||||
let chunks: UnifiedStreamChunk[];
|
||||
let config: StreamHandlerConfig;
|
||||
|
||||
beforeEach(() => {
|
||||
chunks = [];
|
||||
config = {
|
||||
provider: 'anthropic',
|
||||
onChunk: (chunk) => { chunks.push(chunk); },
|
||||
onError: vi.fn(),
|
||||
onComplete: vi.fn()
|
||||
};
|
||||
handler = new AnthropicStreamHandler(config);
|
||||
});
|
||||
|
||||
describe('Content Streaming', () => {
|
||||
it('should process text delta events', async () => {
|
||||
const event = 'event: content_block_delta\ndata: {"delta":{"type":"text_delta","text":"Hello"},"model":"claude-3"}';
|
||||
|
||||
await handler.processChunk(event);
|
||||
|
||||
expect(chunks).toHaveLength(1);
|
||||
expect(chunks[0]).toEqual({
|
||||
type: 'content',
|
||||
content: 'Hello',
|
||||
metadata: {
|
||||
provider: 'anthropic',
|
||||
model: 'claude-3'
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle message start event', async () => {
|
||||
const event = 'event: message_start\ndata: {"message":{"id":"msg_123"}}';
|
||||
|
||||
await handler.processChunk(event);
|
||||
|
||||
// Message start doesn't produce chunks
|
||||
expect(chunks).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should handle message stop event', async () => {
|
||||
const event = 'event: message_stop\ndata: {}';
|
||||
|
||||
await handler.processChunk(event);
|
||||
|
||||
expect(chunks).toHaveLength(1);
|
||||
expect(chunks[0].type).toBe('done');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Usage Tracking', () => {
|
||||
it('should track token usage', async () => {
|
||||
const event = 'event: message_delta\ndata: {"usage":{"input_tokens":10,"output_tokens":5}}';
|
||||
|
||||
await handler.processChunk(event);
|
||||
const response = await handler.complete();
|
||||
|
||||
expect(response.usage).toEqual({
|
||||
promptTokens: 10,
|
||||
completionTokens: 5,
|
||||
totalTokens: 15
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle stop reason', async () => {
|
||||
const event = 'event: message_delta\ndata: {"delta":{"stop_reason":"end_turn"}}';
|
||||
|
||||
await handler.processChunk(event);
|
||||
const response = await handler.complete();
|
||||
|
||||
// finishReason is not directly on ChatResponse anymore
|
||||
});
|
||||
});
|
||||
|
||||
describe('Error Handling', () => {
|
||||
it('should handle error events', async () => {
|
||||
const event = 'event: error\ndata: {"error":{"message":"API Error"}}';
|
||||
|
||||
await handler.processChunk(event);
|
||||
|
||||
expect(config.onError).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
message: 'API Error'
|
||||
})
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('OllamaStreamHandler', () => {
|
||||
let handler: OllamaStreamHandler;
|
||||
let chunks: UnifiedStreamChunk[];
|
||||
let config: StreamHandlerConfig;
|
||||
|
||||
beforeEach(() => {
|
||||
chunks = [];
|
||||
config = {
|
||||
provider: 'ollama',
|
||||
onChunk: (chunk) => { chunks.push(chunk); },
|
||||
onError: vi.fn(),
|
||||
onComplete: vi.fn()
|
||||
};
|
||||
handler = new OllamaStreamHandler(config);
|
||||
});
|
||||
|
||||
describe('Content Streaming', () => {
|
||||
it('should process content chunks', async () => {
|
||||
const chunk = {
|
||||
message: { content: 'Hello' },
|
||||
model: 'llama2',
|
||||
done: false
|
||||
};
|
||||
|
||||
await handler.processChunk(chunk);
|
||||
|
||||
expect(chunks).toHaveLength(1);
|
||||
expect(chunks[0]).toEqual({
|
||||
type: 'content',
|
||||
content: 'Hello',
|
||||
metadata: {
|
||||
provider: 'ollama',
|
||||
model: 'llama2'
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle completion', async () => {
|
||||
const chunk = {
|
||||
message: { content: 'Final' },
|
||||
done: true,
|
||||
prompt_eval_count: 10,
|
||||
eval_count: 5
|
||||
};
|
||||
|
||||
await handler.processChunk(chunk);
|
||||
|
||||
expect(chunks).toHaveLength(2);
|
||||
expect(chunks[0].type).toBe('content');
|
||||
expect(chunks[1].type).toBe('done');
|
||||
expect(chunks[1].metadata?.usage).toEqual({
|
||||
promptTokens: 10,
|
||||
completionTokens: 5,
|
||||
totalTokens: 15
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Tool Calls', () => {
|
||||
it('should process tool calls', async () => {
|
||||
const chunk = {
|
||||
message: {
|
||||
tool_calls: [{
|
||||
id: 'tool_1',
|
||||
function: {
|
||||
name: 'search',
|
||||
arguments: { query: 'test' }
|
||||
}
|
||||
}]
|
||||
},
|
||||
done: false
|
||||
};
|
||||
|
||||
await handler.processChunk(chunk);
|
||||
|
||||
expect(chunks).toHaveLength(1);
|
||||
expect(chunks[0]).toEqual({
|
||||
type: 'tool_call',
|
||||
toolCall: {
|
||||
id: 'tool_1',
|
||||
name: 'search',
|
||||
arguments: '{"query":"test"}'
|
||||
},
|
||||
metadata: {
|
||||
provider: 'ollama'
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('createStreamHandler', () => {
|
||||
it('should create OpenAI handler', () => {
|
||||
const handler = createStreamHandler({
|
||||
provider: 'openai',
|
||||
onChunk: vi.fn()
|
||||
});
|
||||
|
||||
expect(handler).toBeInstanceOf(OpenAIStreamHandler);
|
||||
});
|
||||
|
||||
it('should create Anthropic handler', () => {
|
||||
const handler = createStreamHandler({
|
||||
provider: 'anthropic',
|
||||
onChunk: vi.fn()
|
||||
});
|
||||
|
||||
expect(handler).toBeInstanceOf(AnthropicStreamHandler);
|
||||
});
|
||||
|
||||
it('should create Ollama handler', () => {
|
||||
const handler = createStreamHandler({
|
||||
provider: 'ollama',
|
||||
onChunk: vi.fn()
|
||||
});
|
||||
|
||||
expect(handler).toBeInstanceOf(OllamaStreamHandler);
|
||||
});
|
||||
|
||||
it('should throw for unsupported provider', () => {
|
||||
expect(() => createStreamHandler({
|
||||
provider: 'unsupported' as any,
|
||||
onChunk: vi.fn()
|
||||
})).toThrow('Unsupported provider: unsupported');
|
||||
});
|
||||
});
|
||||
|
||||
describe('StreamAggregator', () => {
|
||||
let aggregator: StreamAggregator;
|
||||
|
||||
beforeEach(() => {
|
||||
aggregator = new StreamAggregator();
|
||||
});
|
||||
|
||||
it('should aggregate content chunks', () => {
|
||||
aggregator.addChunk({
|
||||
type: 'content',
|
||||
content: 'Hello'
|
||||
});
|
||||
aggregator.addChunk({
|
||||
type: 'content',
|
||||
content: ' World'
|
||||
});
|
||||
|
||||
const response = aggregator.getResponse();
|
||||
expect(response.text).toBe('Hello World');
|
||||
});
|
||||
|
||||
it('should aggregate tool calls', () => {
|
||||
aggregator.addChunk({
|
||||
type: 'tool_call',
|
||||
toolCall: {
|
||||
id: '1',
|
||||
name: 'search',
|
||||
arguments: '{}'
|
||||
}
|
||||
});
|
||||
|
||||
const response = aggregator.getResponse();
|
||||
expect(response.tool_calls).toHaveLength(1);
|
||||
expect(response.tool_calls?.[0]).toEqual({
|
||||
id: '1',
|
||||
name: 'search',
|
||||
arguments: '{}'
|
||||
});
|
||||
});
|
||||
|
||||
it('should aggregate metadata', () => {
|
||||
aggregator.addChunk({
|
||||
type: 'done',
|
||||
metadata: {
|
||||
provider: 'openai',
|
||||
finishReason: 'stop',
|
||||
usage: {
|
||||
promptTokens: 10,
|
||||
completionTokens: 5,
|
||||
totalTokens: 15
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
const response = aggregator.getResponse();
|
||||
// finishReason is not directly on ChatResponse anymore
|
||||
expect(response.usage).toEqual({
|
||||
promptTokens: 10,
|
||||
completionTokens: 5,
|
||||
totalTokens: 15
|
||||
});
|
||||
});
|
||||
|
||||
it('should return all chunks', () => {
|
||||
const chunk1: UnifiedStreamChunk = { type: 'content', content: 'Test' };
|
||||
const chunk2: UnifiedStreamChunk = { type: 'done' };
|
||||
|
||||
aggregator.addChunk(chunk1);
|
||||
aggregator.addChunk(chunk2);
|
||||
|
||||
const chunks = aggregator.getChunks();
|
||||
expect(chunks).toHaveLength(2);
|
||||
expect(chunks[0]).toEqual(chunk1);
|
||||
expect(chunks[1]).toEqual(chunk2);
|
||||
});
|
||||
|
||||
it('should reset state', () => {
|
||||
aggregator.addChunk({ type: 'content', content: 'Test' });
|
||||
aggregator.reset();
|
||||
|
||||
const response = aggregator.getResponse();
|
||||
expect(response.text).toBe('');
|
||||
expect(aggregator.getChunks()).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('unifiedStream', () => {
|
||||
it('should convert async iterable to unified stream', async () => {
|
||||
async function* mockStream() {
|
||||
yield JSON.stringify({
|
||||
choices: [{
|
||||
delta: { content: 'Hello' }
|
||||
}]
|
||||
});
|
||||
yield JSON.stringify({
|
||||
choices: [{
|
||||
delta: { content: ' World' }
|
||||
}]
|
||||
});
|
||||
yield 'data: [DONE]';
|
||||
}
|
||||
|
||||
const chunks: UnifiedStreamChunk[] = [];
|
||||
|
||||
for await (const chunk of unifiedStream(mockStream(), 'openai')) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
expect(chunks.length).toBeGreaterThan(0);
|
||||
expect(chunks.find(c => c.type === 'content')).toBeDefined();
|
||||
expect(chunks.find(c => c.type === 'done')).toBeDefined();
|
||||
});
|
||||
|
||||
it('should handle errors in stream', async () => {
|
||||
async function* errorStream() {
|
||||
yield 'invalid json that will cause error';
|
||||
}
|
||||
|
||||
const chunks: UnifiedStreamChunk[] = [];
|
||||
|
||||
for await (const chunk of unifiedStream(errorStream(), 'openai')) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
expect(chunks.find(c => c.type === 'error')).toBeDefined();
|
||||
});
|
||||
});
|
||||
457
apps/server/src/services/llm/providers/circuit_breaker.ts
Normal file
457
apps/server/src/services/llm/providers/circuit_breaker.ts
Normal file
@@ -0,0 +1,457 @@
|
||||
/**
|
||||
* Circuit Breaker Pattern Implementation for LLM Providers
|
||||
*
|
||||
* Implements a circuit breaker to prevent hammering failing providers.
|
||||
* States:
|
||||
* - CLOSED: Normal operation, requests pass through
|
||||
* - OPEN: Provider is failing, requests are rejected immediately
|
||||
* - HALF_OPEN: Testing if provider has recovered
|
||||
*/
|
||||
|
||||
import log from '../../log.js';
|
||||
|
||||
/**
|
||||
* Circuit breaker states
|
||||
*/
|
||||
export enum CircuitState {
|
||||
CLOSED = 'CLOSED',
|
||||
OPEN = 'OPEN',
|
||||
HALF_OPEN = 'HALF_OPEN'
|
||||
}
|
||||
|
||||
/**
|
||||
* Circuit breaker configuration
|
||||
*/
|
||||
export interface CircuitBreakerConfig {
|
||||
/** Number of failures before opening circuit */
|
||||
failureThreshold: number;
|
||||
/** Time window for counting failures (ms) */
|
||||
failureWindow: number;
|
||||
/** Cooldown period before attempting half-open (ms) */
|
||||
cooldownPeriod: number;
|
||||
/** Number of successes in half-open to close circuit */
|
||||
successThreshold: number;
|
||||
/** Request timeout for half-open state (ms) */
|
||||
halfOpenTimeout: number;
|
||||
/** Whether to log state transitions */
|
||||
enableLogging: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Circuit breaker statistics
|
||||
*/
|
||||
export interface CircuitBreakerStats {
|
||||
state: CircuitState;
|
||||
failures: number;
|
||||
successes: number;
|
||||
lastFailureTime?: Date;
|
||||
lastSuccessTime?: Date;
|
||||
lastStateChange: Date;
|
||||
totalRequests: number;
|
||||
rejectedRequests: number;
|
||||
stateHistory: Array<{
|
||||
state: CircuitState;
|
||||
timestamp: Date;
|
||||
reason: string;
|
||||
}>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Error type for circuit breaker rejections
|
||||
*/
|
||||
export class CircuitOpenError extends Error {
|
||||
constructor(public readonly providerName: string, public readonly nextRetryTime: Date) {
|
||||
super(`Circuit breaker is OPEN for provider ${providerName}. Will retry after ${nextRetryTime.toISOString()}`);
|
||||
this.name = 'CircuitOpenError';
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Circuit Breaker implementation
|
||||
*/
|
||||
export class CircuitBreaker {
|
||||
private state: CircuitState = CircuitState.CLOSED;
|
||||
private failures: number = 0;
|
||||
private successes: number = 0;
|
||||
private failureTimestamps: Date[] = [];
|
||||
private lastStateChangeTime: Date = new Date();
|
||||
private cooldownTimer?: NodeJS.Timeout;
|
||||
private stats: CircuitBreakerStats;
|
||||
private readonly config: CircuitBreakerConfig;
|
||||
|
||||
constructor(
|
||||
private readonly name: string,
|
||||
config?: Partial<CircuitBreakerConfig>
|
||||
) {
|
||||
this.config = {
|
||||
failureThreshold: config?.failureThreshold ?? 5,
|
||||
failureWindow: config?.failureWindow ?? 60000, // 1 minute
|
||||
cooldownPeriod: config?.cooldownPeriod ?? 30000, // 30 seconds
|
||||
successThreshold: config?.successThreshold ?? 2,
|
||||
halfOpenTimeout: config?.halfOpenTimeout ?? 5000, // 5 seconds
|
||||
enableLogging: config?.enableLogging ?? true
|
||||
};
|
||||
|
||||
this.stats = {
|
||||
state: this.state,
|
||||
failures: 0,
|
||||
successes: 0,
|
||||
lastStateChange: this.lastStateChangeTime,
|
||||
totalRequests: 0,
|
||||
rejectedRequests: 0,
|
||||
stateHistory: []
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a function with circuit breaker protection
|
||||
*/
|
||||
public async execute<T>(
|
||||
fn: () => Promise<T>,
|
||||
timeout?: number
|
||||
): Promise<T> {
|
||||
this.stats.totalRequests++;
|
||||
|
||||
// Check if circuit is open
|
||||
if (this.state === CircuitState.OPEN) {
|
||||
this.stats.rejectedRequests++;
|
||||
const nextRetryTime = new Date(this.lastStateChangeTime.getTime() + this.config.cooldownPeriod);
|
||||
throw new CircuitOpenError(this.name, nextRetryTime);
|
||||
}
|
||||
|
||||
// Apply timeout for half-open state
|
||||
const executionTimeout = this.state === CircuitState.HALF_OPEN
|
||||
? this.config.halfOpenTimeout
|
||||
: timeout;
|
||||
|
||||
try {
|
||||
const result = await this.executeWithTimeout(fn, executionTimeout);
|
||||
this.onSuccess();
|
||||
return result;
|
||||
} catch (error) {
|
||||
this.onFailure(error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute function with timeout
|
||||
*/
|
||||
private async executeWithTimeout<T>(
|
||||
fn: () => Promise<T>,
|
||||
timeout?: number
|
||||
): Promise<T> {
|
||||
if (!timeout) {
|
||||
return fn();
|
||||
}
|
||||
|
||||
return Promise.race([
|
||||
fn(),
|
||||
new Promise<T>((_, reject) =>
|
||||
setTimeout(() => reject(new Error(`Operation timed out after ${timeout}ms`)), timeout)
|
||||
)
|
||||
]);
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle successful execution
|
||||
*/
|
||||
private onSuccess(): void {
|
||||
this.successes++;
|
||||
this.stats.successes++;
|
||||
this.stats.lastSuccessTime = new Date();
|
||||
|
||||
switch (this.state) {
|
||||
case CircuitState.HALF_OPEN:
|
||||
if (this.successes >= this.config.successThreshold) {
|
||||
this.transitionTo(CircuitState.CLOSED, 'Success threshold reached');
|
||||
this.reset();
|
||||
}
|
||||
break;
|
||||
|
||||
case CircuitState.CLOSED:
|
||||
// Clear old failure timestamps
|
||||
this.cleanupFailureTimestamps();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle failed execution
|
||||
*/
|
||||
private onFailure(error: any): void {
|
||||
const now = new Date();
|
||||
this.failures++;
|
||||
this.stats.failures++;
|
||||
this.stats.lastFailureTime = now;
|
||||
this.failureTimestamps.push(now);
|
||||
|
||||
switch (this.state) {
|
||||
case CircuitState.HALF_OPEN:
|
||||
// Immediately open on failure in half-open state
|
||||
this.transitionTo(CircuitState.OPEN, `Failure in HALF_OPEN state: ${error.message}`);
|
||||
this.scheduleCooldown();
|
||||
break;
|
||||
|
||||
case CircuitState.CLOSED:
|
||||
// Check if we've exceeded failure threshold
|
||||
this.cleanupFailureTimestamps();
|
||||
if (this.failureTimestamps.length >= this.config.failureThreshold) {
|
||||
this.transitionTo(CircuitState.OPEN, `Failure threshold exceeded: ${this.failures} failures in ${this.config.failureWindow}ms`);
|
||||
this.scheduleCooldown();
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clean up old failure timestamps outside the window
|
||||
*/
|
||||
private cleanupFailureTimestamps(): void {
|
||||
const now = Date.now();
|
||||
const windowStart = now - this.config.failureWindow;
|
||||
this.failureTimestamps = this.failureTimestamps.filter(
|
||||
timestamp => timestamp.getTime() > windowStart
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Transition to a new state
|
||||
*/
|
||||
private transitionTo(newState: CircuitState, reason: string): void {
|
||||
const oldState = this.state;
|
||||
this.state = newState;
|
||||
this.lastStateChangeTime = new Date();
|
||||
this.stats.state = newState;
|
||||
this.stats.lastStateChange = this.lastStateChangeTime;
|
||||
|
||||
// Add to state history
|
||||
this.stats.stateHistory.push({
|
||||
state: newState,
|
||||
timestamp: this.lastStateChangeTime,
|
||||
reason
|
||||
});
|
||||
|
||||
// Keep only last 100 state transitions
|
||||
if (this.stats.stateHistory.length > 100) {
|
||||
this.stats.stateHistory = this.stats.stateHistory.slice(-100);
|
||||
}
|
||||
|
||||
if (this.config.enableLogging) {
|
||||
log.info(`[CircuitBreaker:${this.name}] State transition: ${oldState} -> ${newState}. Reason: ${reason}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Schedule cooldown period
|
||||
*/
|
||||
private scheduleCooldown(): void {
|
||||
if (this.cooldownTimer) {
|
||||
clearTimeout(this.cooldownTimer);
|
||||
}
|
||||
|
||||
this.cooldownTimer = setTimeout(() => {
|
||||
if (this.state === CircuitState.OPEN) {
|
||||
this.transitionTo(CircuitState.HALF_OPEN, 'Cooldown period expired');
|
||||
this.successes = 0; // Reset success counter for half-open state
|
||||
}
|
||||
}, this.config.cooldownPeriod);
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset counters
|
||||
*/
|
||||
private reset(): void {
|
||||
this.failures = 0;
|
||||
this.successes = 0;
|
||||
this.failureTimestamps = [];
|
||||
|
||||
if (this.cooldownTimer) {
|
||||
clearTimeout(this.cooldownTimer);
|
||||
this.cooldownTimer = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get current state
|
||||
*/
|
||||
public getState(): CircuitState {
|
||||
return this.state;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get statistics
|
||||
*/
|
||||
public getStats(): CircuitBreakerStats {
|
||||
return { ...this.stats };
|
||||
}
|
||||
|
||||
/**
|
||||
* Force open the circuit (for testing or manual intervention)
|
||||
*/
|
||||
public forceOpen(reason: string = 'Manual intervention'): void {
|
||||
this.transitionTo(CircuitState.OPEN, reason);
|
||||
this.scheduleCooldown();
|
||||
}
|
||||
|
||||
/**
|
||||
* Force close the circuit (for testing or manual intervention)
|
||||
*/
|
||||
public forceClose(reason: string = 'Manual intervention'): void {
|
||||
this.transitionTo(CircuitState.CLOSED, reason);
|
||||
this.reset();
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if circuit allows requests
|
||||
*/
|
||||
public isAvailable(): boolean {
|
||||
return this.state !== CircuitState.OPEN;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get time until next retry (if circuit is open)
|
||||
*/
|
||||
public getNextRetryTime(): Date | null {
|
||||
if (this.state !== CircuitState.OPEN) {
|
||||
return null;
|
||||
}
|
||||
return new Date(this.lastStateChangeTime.getTime() + this.config.cooldownPeriod);
|
||||
}
|
||||
|
||||
/**
|
||||
* Cleanup resources
|
||||
*/
|
||||
public dispose(): void {
|
||||
if (this.cooldownTimer) {
|
||||
clearTimeout(this.cooldownTimer);
|
||||
this.cooldownTimer = undefined;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Circuit Breaker Manager for managing multiple circuit breakers
|
||||
*/
|
||||
export class CircuitBreakerManager {
|
||||
private static instance: CircuitBreakerManager | null = null;
|
||||
private breakers: Map<string, CircuitBreaker> = new Map();
|
||||
private defaultConfig: Partial<CircuitBreakerConfig>;
|
||||
|
||||
constructor(defaultConfig?: Partial<CircuitBreakerConfig>) {
|
||||
this.defaultConfig = defaultConfig || {};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get singleton instance
|
||||
*/
|
||||
public static getInstance(defaultConfig?: Partial<CircuitBreakerConfig>): CircuitBreakerManager {
|
||||
if (!CircuitBreakerManager.instance) {
|
||||
CircuitBreakerManager.instance = new CircuitBreakerManager(defaultConfig);
|
||||
}
|
||||
return CircuitBreakerManager.instance;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get or create a circuit breaker for a provider
|
||||
*/
|
||||
public getBreaker(
|
||||
providerName: string,
|
||||
config?: Partial<CircuitBreakerConfig>
|
||||
): CircuitBreaker {
|
||||
if (!this.breakers.has(providerName)) {
|
||||
const breakerConfig = { ...this.defaultConfig, ...config };
|
||||
this.breakers.set(providerName, new CircuitBreaker(providerName, breakerConfig));
|
||||
}
|
||||
return this.breakers.get(providerName)!;
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute with circuit breaker protection
|
||||
*/
|
||||
public async execute<T>(
|
||||
providerName: string,
|
||||
fn: () => Promise<T>,
|
||||
config?: Partial<CircuitBreakerConfig>
|
||||
): Promise<T> {
|
||||
const breaker = this.getBreaker(providerName, config);
|
||||
return breaker.execute(fn);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all circuit breaker stats
|
||||
*/
|
||||
public getAllStats(): Map<string, CircuitBreakerStats> {
|
||||
const stats = new Map<string, CircuitBreakerStats>();
|
||||
for (const [name, breaker] of this.breakers) {
|
||||
stats.set(name, breaker.getStats());
|
||||
}
|
||||
return stats;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get health summary
|
||||
*/
|
||||
public getHealthSummary(): {
|
||||
total: number;
|
||||
closed: number;
|
||||
open: number;
|
||||
halfOpen: number;
|
||||
availableProviders: string[];
|
||||
unavailableProviders: string[];
|
||||
} {
|
||||
const summary = {
|
||||
total: this.breakers.size,
|
||||
closed: 0,
|
||||
open: 0,
|
||||
halfOpen: 0,
|
||||
availableProviders: [] as string[],
|
||||
unavailableProviders: [] as string[]
|
||||
};
|
||||
|
||||
for (const [name, breaker] of this.breakers) {
|
||||
const state = breaker.getState();
|
||||
switch (state) {
|
||||
case CircuitState.CLOSED:
|
||||
summary.closed++;
|
||||
summary.availableProviders.push(name);
|
||||
break;
|
||||
case CircuitState.OPEN:
|
||||
summary.open++;
|
||||
summary.unavailableProviders.push(name);
|
||||
break;
|
||||
case CircuitState.HALF_OPEN:
|
||||
summary.halfOpen++;
|
||||
summary.availableProviders.push(name);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return summary;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset all circuit breakers
|
||||
*/
|
||||
public resetAll(): void {
|
||||
for (const breaker of this.breakers.values()) {
|
||||
breaker.forceClose('Global reset');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Dispose all circuit breakers
|
||||
*/
|
||||
public dispose(): void {
|
||||
for (const breaker of this.breakers.values()) {
|
||||
breaker.dispose();
|
||||
}
|
||||
this.breakers.clear();
|
||||
CircuitBreakerManager.instance = null;
|
||||
}
|
||||
}
|
||||
|
||||
// Export singleton getter
|
||||
export const getCircuitBreakerManager = (config?: Partial<CircuitBreakerConfig>): CircuitBreakerManager => {
|
||||
return CircuitBreakerManager.getInstance(config);
|
||||
};
|
||||
770
apps/server/src/services/llm/providers/provider_configuration.ts
Normal file
770
apps/server/src/services/llm/providers/provider_configuration.ts
Normal file
@@ -0,0 +1,770 @@
|
||||
/**
|
||||
* Enhanced Provider Configuration
|
||||
*
|
||||
* Provides advanced configuration options for AI service providers,
|
||||
* including custom endpoints, model detection, and optimization settings.
|
||||
*/
|
||||
|
||||
import log from '../../log.js';
|
||||
import options from '../../options.js';
|
||||
import type { ModelMetadata } from './provider_options.js';
|
||||
|
||||
/**
|
||||
* Provider configuration with enhanced settings
|
||||
*/
|
||||
/**
 * Provider configuration with enhanced settings.
 *
 * All fields except `provider` are optional so a config can be assembled
 * incrementally (static defaults merged with user-specific values from
 * options).
 */
export interface EnhancedProviderConfig {
    // Basic settings
    provider: 'openai' | 'anthropic' | 'ollama' | 'custom';
    apiKey?: string;
    baseUrl?: string;

    // Advanced settings
    customHeaders?: Record<string, string>;  // extra HTTP headers sent with requests — TODO confirm against the request layer
    timeout?: number;                        // request timeout in ms
    maxRetries?: number;
    retryDelay?: number;                     // delay between retries in ms
    proxy?: string;

    // Model settings
    defaultModel?: string;
    availableModels?: string[];
    modelAliases?: Record<string, string>;   // presumably alias -> canonical model id; verify against callers

    // Performance settings
    maxConcurrentRequests?: number;
    requestQueueSize?: number;
    rateLimitPerMinute?: number;

    // Feature flags
    enableStreaming?: boolean;
    enableTools?: boolean;
    enableVision?: boolean;
    enableCaching?: boolean;

    // Custom endpoints (paths resolved against baseUrl, e.g. '/chat/completions')
    endpoints?: {
        chat?: string;
        completions?: string;
        embeddings?: string;
        models?: string;
        health?: string;
    };

    // Optimization settings
    optimization?: {
        batchSize?: number;
        cacheTimeout?: number;        // presumably ms, matching the other timeouts
        compressionEnabled?: boolean;
        connectionPoolSize?: number;
    };
}
|
||||
|
||||
/**
|
||||
* Model information with detailed capabilities
|
||||
*/
|
||||
/**
 * Model information with detailed capabilities.
 *
 * Entries come either from the static registry seeded at startup or from a
 * provider's model-listing endpoint.
 */
export interface ModelInfo {
    id: string;                    // provider-side model identifier (e.g. 'gpt-4o')
    name: string;                  // human-readable display name
    provider: string;
    contextWindow: number;         // maximum context size, in tokens
    maxOutputTokens: number;
    supportedModalities: string[]; // e.g. 'text', 'image', 'audio'
    // Pricing per million tokens — presumably USD; confirm against provider docs.
    costPerMillion?: {
        input: number;
        output: number;
    };
    capabilities: {
        chat: boolean;
        completion: boolean;
        embedding: boolean;
        functionCalling: boolean;
        vision: boolean;
        audio: boolean;
        streaming: boolean;
    };
    // Optional runtime measurements; not populated by the static registry.
    performance?: {
        averageLatency?: number;
        tokensPerSecond?: number;
    };
}
|
||||
|
||||
/**
|
||||
* Provider configuration manager
|
||||
*/
|
||||
export class ProviderConfigurationManager {
|
||||
private configs: Map<string, EnhancedProviderConfig> = new Map();
|
||||
private modelRegistry: Map<string, ModelInfo> = new Map();
|
||||
private modelCache: Map<string, ModelInfo[]> = new Map();
|
||||
private lastModelFetch: Map<string, number> = new Map();
|
||||
private readonly MODEL_CACHE_TTL = 3600000; // 1 hour
|
||||
|
||||
    constructor() {
        // Seed the built-in provider configs and the static model registry so
        // lookups work before any remote model list has been fetched.
        this.initializeDefaultConfigs();
        this.initializeModelRegistry();
    }
|
||||
|
||||
/**
|
||||
* Initialize default provider configurations
|
||||
*/
|
||||
    /**
     * Initialize built-in configurations for the known providers.
     *
     * Values here are static defaults; user-specific settings (API key,
     * base URL override, default model) are layered on later by
     * buildConfigFromOptions().
     */
    private initializeDefaultConfigs(): void {
        // OpenAI configuration
        this.configs.set('openai', {
            provider: 'openai',
            baseUrl: 'https://api.openai.com/v1',
            timeout: 60000,
            maxRetries: 3,
            retryDelay: 1000,
            enableStreaming: true,
            enableTools: true,
            enableVision: true,
            enableCaching: true,
            endpoints: {
                chat: '/chat/completions',
                completions: '/completions',
                embeddings: '/embeddings',
                models: '/models'
            },
            optimization: {
                batchSize: 10,
                cacheTimeout: 300000,
                compressionEnabled: true,
                connectionPoolSize: 10
            }
        });

        // Anthropic configuration — only a chat (messages) endpoint is defined
        this.configs.set('anthropic', {
            provider: 'anthropic',
            baseUrl: 'https://api.anthropic.com',
            timeout: 60000,
            maxRetries: 3,
            retryDelay: 1000,
            enableStreaming: true,
            enableTools: true,
            enableVision: true,
            enableCaching: true,
            endpoints: {
                chat: '/v1/messages'
            },
            optimization: {
                batchSize: 5,
                cacheTimeout: 300000,
                compressionEnabled: true,
                connectionPoolSize: 5
            }
        });

        // Ollama configuration — local server, hence the relaxed timeouts
        this.configs.set('ollama', {
            provider: 'ollama',
            baseUrl: 'http://localhost:11434',
            timeout: 120000, // Longer timeout for local models
            maxRetries: 2,
            retryDelay: 500,
            enableStreaming: true,
            enableTools: true,
            enableVision: false,
            enableCaching: true,
            endpoints: {
                chat: '/api/chat',
                models: '/api/tags'
            },
            optimization: {
                batchSize: 1, // Local processing, no batching
                cacheTimeout: 600000,
                compressionEnabled: false,
                connectionPoolSize: 2
            }
        });
    }
|
||||
|
||||
/**
|
||||
* Initialize model registry with known models
|
||||
*/
|
||||
private initializeModelRegistry(): void {
|
||||
// OpenAI models
|
||||
this.registerModel({
|
||||
id: 'gpt-4-turbo-preview',
|
||||
name: 'GPT-4 Turbo',
|
||||
provider: 'openai',
|
||||
contextWindow: 128000,
|
||||
maxOutputTokens: 4096,
|
||||
supportedModalities: ['text', 'image'],
|
||||
costPerMillion: {
|
||||
input: 10,
|
||||
output: 30
|
||||
},
|
||||
capabilities: {
|
||||
chat: true,
|
||||
completion: false,
|
||||
embedding: false,
|
||||
functionCalling: true,
|
||||
vision: true,
|
||||
audio: false,
|
||||
streaming: true
|
||||
}
|
||||
});
|
||||
|
||||
this.registerModel({
|
||||
id: 'gpt-4o',
|
||||
name: 'GPT-4 Omni',
|
||||
provider: 'openai',
|
||||
contextWindow: 128000,
|
||||
maxOutputTokens: 4096,
|
||||
supportedModalities: ['text', 'image', 'audio'],
|
||||
costPerMillion: {
|
||||
input: 5,
|
||||
output: 15
|
||||
},
|
||||
capabilities: {
|
||||
chat: true,
|
||||
completion: false,
|
||||
embedding: false,
|
||||
functionCalling: true,
|
||||
vision: true,
|
||||
audio: true,
|
||||
streaming: true
|
||||
}
|
||||
});
|
||||
|
||||
this.registerModel({
|
||||
id: 'gpt-3.5-turbo',
|
||||
name: 'GPT-3.5 Turbo',
|
||||
provider: 'openai',
|
||||
contextWindow: 16385,
|
||||
maxOutputTokens: 4096,
|
||||
supportedModalities: ['text'],
|
||||
costPerMillion: {
|
||||
input: 0.5,
|
||||
output: 1.5
|
||||
},
|
||||
capabilities: {
|
||||
chat: true,
|
||||
completion: false,
|
||||
embedding: false,
|
||||
functionCalling: true,
|
||||
vision: false,
|
||||
audio: false,
|
||||
streaming: true
|
||||
}
|
||||
});
|
||||
|
||||
// Anthropic models
|
||||
this.registerModel({
|
||||
id: 'claude-3-opus-20240229',
|
||||
name: 'Claude 3 Opus',
|
||||
provider: 'anthropic',
|
||||
contextWindow: 200000,
|
||||
maxOutputTokens: 4096,
|
||||
supportedModalities: ['text', 'image'],
|
||||
costPerMillion: {
|
||||
input: 15,
|
||||
output: 75
|
||||
},
|
||||
capabilities: {
|
||||
chat: true,
|
||||
completion: false,
|
||||
embedding: false,
|
||||
functionCalling: true,
|
||||
vision: true,
|
||||
audio: false,
|
||||
streaming: true
|
||||
}
|
||||
});
|
||||
|
||||
this.registerModel({
|
||||
id: 'claude-3-sonnet-20240229',
|
||||
name: 'Claude 3 Sonnet',
|
||||
provider: 'anthropic',
|
||||
contextWindow: 200000,
|
||||
maxOutputTokens: 4096,
|
||||
supportedModalities: ['text', 'image'],
|
||||
costPerMillion: {
|
||||
input: 3,
|
||||
output: 15
|
||||
},
|
||||
capabilities: {
|
||||
chat: true,
|
||||
completion: false,
|
||||
embedding: false,
|
||||
functionCalling: true,
|
||||
vision: true,
|
||||
audio: false,
|
||||
streaming: true
|
||||
}
|
||||
});
|
||||
|
||||
this.registerModel({
|
||||
id: 'claude-3-haiku-20240307',
|
||||
name: 'Claude 3 Haiku',
|
||||
provider: 'anthropic',
|
||||
contextWindow: 200000,
|
||||
maxOutputTokens: 4096,
|
||||
supportedModalities: ['text', 'image'],
|
||||
costPerMillion: {
|
||||
input: 0.25,
|
||||
output: 1.25
|
||||
},
|
||||
capabilities: {
|
||||
chat: true,
|
||||
completion: false,
|
||||
embedding: false,
|
||||
functionCalling: true,
|
||||
vision: true,
|
||||
audio: false,
|
||||
streaming: true
|
||||
}
|
||||
});
|
||||
|
||||
// Common Ollama models (defaults, actual specs depend on local models)
|
||||
this.registerModel({
|
||||
id: 'llama3',
|
||||
name: 'Llama 3',
|
||||
provider: 'ollama',
|
||||
contextWindow: 8192,
|
||||
maxOutputTokens: 2048,
|
||||
supportedModalities: ['text'],
|
||||
capabilities: {
|
||||
chat: true,
|
||||
completion: true,
|
||||
embedding: false,
|
||||
functionCalling: true,
|
||||
vision: false,
|
||||
audio: false,
|
||||
streaming: true
|
||||
}
|
||||
});
|
||||
|
||||
this.registerModel({
|
||||
id: 'mixtral',
|
||||
name: 'Mixtral',
|
||||
provider: 'ollama',
|
||||
contextWindow: 32768,
|
||||
maxOutputTokens: 4096,
|
||||
supportedModalities: ['text'],
|
||||
capabilities: {
|
||||
chat: true,
|
||||
completion: true,
|
||||
embedding: false,
|
||||
functionCalling: true,
|
||||
vision: false,
|
||||
audio: false,
|
||||
streaming: true
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a model in the registry
|
||||
*/
|
||||
public registerModel(model: ModelInfo): void {
|
||||
this.modelRegistry.set(model.id, model);
|
||||
|
||||
// Also register by provider
|
||||
const providerModels = this.modelCache.get(model.provider) || [];
|
||||
if (!providerModels.some(m => m.id === model.id)) {
|
||||
providerModels.push(model);
|
||||
this.modelCache.set(model.provider, providerModels);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get configuration for a provider
|
||||
*/
|
||||
public getProviderConfig(provider: string): EnhancedProviderConfig | undefined {
|
||||
// First check if we have a stored config
|
||||
let config = this.configs.get(provider);
|
||||
|
||||
if (!config) {
|
||||
// Try to build config from options
|
||||
config = this.buildConfigFromOptions(provider);
|
||||
if (config) {
|
||||
this.configs.set(provider, config);
|
||||
}
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
    /**
     * Build configuration from Trilium options.
     *
     * Returns undefined when the options required for the provider are not
     * set. NOTE(review): the non-null assertions on `this.configs.get(...)`
     * assume a default config for each provider was registered earlier
     * (presumably during construction) — confirm, otherwise this can throw.
     */
    private buildConfigFromOptions(provider: string): EnhancedProviderConfig | undefined {
        switch (provider) {
            case 'openai': {
                const apiKey = options.getOption('openaiApiKey');
                const baseUrl = options.getOption('openaiBaseUrl');
                const defaultModel = options.getOption('openaiDefaultModel');

                // A custom base URL alone is accepted (no API key required).
                if (!apiKey && !baseUrl) return undefined;

                return {
                    ...this.configs.get('openai')!,
                    apiKey,
                    baseUrl: baseUrl || this.configs.get('openai')!.baseUrl,
                    defaultModel
                };
            }

            case 'anthropic': {
                const apiKey = options.getOption('anthropicApiKey');
                const baseUrl = options.getOption('anthropicBaseUrl');
                const defaultModel = options.getOption('anthropicDefaultModel');

                // Anthropic always requires an API key.
                if (!apiKey) return undefined;

                return {
                    ...this.configs.get('anthropic')!,
                    apiKey,
                    baseUrl: baseUrl || this.configs.get('anthropic')!.baseUrl,
                    defaultModel
                };
            }

            case 'ollama': {
                const baseUrl = options.getOption('ollamaBaseUrl');
                const defaultModel = options.getOption('ollamaDefaultModel');

                // A base URL is mandatory for Ollama.
                if (!baseUrl) return undefined;

                return {
                    ...this.configs.get('ollama')!,
                    baseUrl,
                    defaultModel
                };
            }

            default:
                return undefined;
        }
    }
|
||||
|
||||
    /**
     * Merge a partial configuration into the stored config for a provider,
     * creating a minimal config when none exists yet.
     */
    public updateProviderConfig(provider: string, config: Partial<EnhancedProviderConfig>): void {
        // NOTE(review): `as any` sidesteps the provider union type here, so
        // arbitrary provider names are accepted — confirm this is intended.
        const existing = this.getProviderConfig(provider) || { provider: provider as any };
        this.configs.set(provider, { ...existing, ...config });
    }
|
||||
|
||||
/**
|
||||
* Get available models for a provider
|
||||
*/
|
||||
public async getAvailableModels(provider: string): Promise<ModelInfo[]> {
|
||||
// Check cache first
|
||||
const cached = this.modelCache.get(provider);
|
||||
const lastFetch = this.lastModelFetch.get(provider) || 0;
|
||||
|
||||
if (cached && Date.now() - lastFetch < this.MODEL_CACHE_TTL) {
|
||||
return cached;
|
||||
}
|
||||
|
||||
// Try to fetch fresh model list
|
||||
try {
|
||||
const models = await this.fetchProviderModels(provider);
|
||||
this.modelCache.set(provider, models);
|
||||
this.lastModelFetch.set(provider, Date.now());
|
||||
return models;
|
||||
} catch (error) {
|
||||
log.info(`Failed to fetch models for ${provider}: ${error}`);
|
||||
|
||||
// Return cached if available, otherwise registry models
|
||||
return cached || Array.from(this.modelRegistry.values())
|
||||
.filter(m => m.provider === provider);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch models from provider API
|
||||
*/
|
||||
private async fetchProviderModels(provider: string): Promise<ModelInfo[]> {
|
||||
const config = this.getProviderConfig(provider);
|
||||
if (!config) {
|
||||
throw new Error(`No configuration for provider: ${provider}`);
|
||||
}
|
||||
|
||||
switch (provider) {
|
||||
case 'openai':
|
||||
return this.fetchOpenAIModels(config);
|
||||
|
||||
case 'ollama':
|
||||
return this.fetchOllamaModels(config);
|
||||
|
||||
case 'anthropic':
|
||||
// Anthropic doesn't have a models endpoint, use registry
|
||||
return Array.from(this.modelRegistry.values())
|
||||
.filter(m => m.provider === 'anthropic');
|
||||
|
||||
default:
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch OpenAI models
|
||||
*/
|
||||
private async fetchOpenAIModels(config: EnhancedProviderConfig): Promise<ModelInfo[]> {
|
||||
try {
|
||||
const url = `${config.baseUrl}${config.endpoints?.models || '/models'}`;
|
||||
const response = await fetch(url, {
|
||||
headers: {
|
||||
'Authorization': `Bearer ${config.apiKey}`,
|
||||
...config.customHeaders
|
||||
}
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP ${response.status}: ${response.statusText}`);
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
return data.data.map((model: any) => {
|
||||
// Check if we have detailed info in registry
|
||||
const registered = this.modelRegistry.get(model.id);
|
||||
if (registered) {
|
||||
return registered;
|
||||
}
|
||||
|
||||
// Create basic model info
|
||||
return {
|
||||
id: model.id,
|
||||
name: model.id,
|
||||
provider: 'openai',
|
||||
contextWindow: 4096, // Default
|
||||
maxOutputTokens: 4096,
|
||||
supportedModalities: ['text'],
|
||||
capabilities: {
|
||||
chat: model.id.includes('gpt'),
|
||||
completion: !model.id.includes('gpt'),
|
||||
embedding: model.id.includes('embedding'),
|
||||
functionCalling: model.id.includes('gpt'),
|
||||
vision: model.id.includes('vision'),
|
||||
audio: model.id.includes('whisper'),
|
||||
streaming: true
|
||||
}
|
||||
} as ModelInfo;
|
||||
});
|
||||
} catch (error) {
|
||||
log.error(`Failed to fetch OpenAI models: ${error}`);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch Ollama models
|
||||
*/
|
||||
private async fetchOllamaModels(config: EnhancedProviderConfig): Promise<ModelInfo[]> {
|
||||
try {
|
||||
const url = `${config.baseUrl}${config.endpoints?.models || '/api/tags'}`;
|
||||
const response = await fetch(url);
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP ${response.status}: ${response.statusText}`);
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
return data.models.map((model: any) => {
|
||||
// Check if we have detailed info in registry
|
||||
const registered = this.modelRegistry.get(model.name);
|
||||
if (registered) {
|
||||
return registered;
|
||||
}
|
||||
|
||||
// Create basic model info from Ollama data
|
||||
return {
|
||||
id: model.name,
|
||||
name: model.name,
|
||||
provider: 'ollama',
|
||||
contextWindow: model.details?.parameter_size || 4096,
|
||||
maxOutputTokens: 2048,
|
||||
supportedModalities: ['text'],
|
||||
capabilities: {
|
||||
chat: true,
|
||||
completion: true,
|
||||
embedding: model.name.includes('embed'),
|
||||
functionCalling: true,
|
||||
vision: model.name.includes('vision') || model.name.includes('llava'),
|
||||
audio: false,
|
||||
streaming: true
|
||||
},
|
||||
performance: {
|
||||
tokensPerSecond: model.details?.tokens_per_second
|
||||
}
|
||||
} as ModelInfo;
|
||||
});
|
||||
} catch (error) {
|
||||
log.error(`Failed to fetch Ollama models: ${error}`);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
    /**
     * Get detailed model information from the registry by model id, if
     * registered.
     */
    public getModelInfo(modelId: string): ModelInfo | undefined {
        return this.modelRegistry.get(modelId);
    }
|
||||
|
||||
/**
|
||||
* Detect best model for a use case
|
||||
*/
|
||||
public detectBestModel(
|
||||
provider: string,
|
||||
requirements: {
|
||||
minContextWindow?: number;
|
||||
needsVision?: boolean;
|
||||
needsTools?: boolean;
|
||||
maxCostPerMillion?: number;
|
||||
preferFast?: boolean;
|
||||
}
|
||||
): ModelInfo | undefined {
|
||||
const models = Array.from(this.modelRegistry.values())
|
||||
.filter(m => m.provider === provider);
|
||||
|
||||
// Filter by requirements
|
||||
let candidates = models.filter(m => {
|
||||
if (requirements.minContextWindow && m.contextWindow < requirements.minContextWindow) {
|
||||
return false;
|
||||
}
|
||||
if (requirements.needsVision && !m.capabilities.vision) {
|
||||
return false;
|
||||
}
|
||||
if (requirements.needsTools && !m.capabilities.functionCalling) {
|
||||
return false;
|
||||
}
|
||||
if (requirements.maxCostPerMillion && m.costPerMillion) {
|
||||
const avgCost = (m.costPerMillion.input + m.costPerMillion.output) / 2;
|
||||
if (avgCost > requirements.maxCostPerMillion) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
if (candidates.length === 0) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Sort by preference
|
||||
if (requirements.preferFast) {
|
||||
// Prefer smaller, faster models
|
||||
candidates.sort((a, b) => {
|
||||
const costA = a.costPerMillion ? (a.costPerMillion.input + a.costPerMillion.output) / 2 : 1000;
|
||||
const costB = b.costPerMillion ? (b.costPerMillion.input + b.costPerMillion.output) / 2 : 1000;
|
||||
return costA - costB;
|
||||
});
|
||||
} else {
|
||||
// Prefer more capable models
|
||||
candidates.sort((a, b) => b.contextWindow - a.contextWindow);
|
||||
}
|
||||
|
||||
return candidates[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate provider configuration
|
||||
*/
|
||||
public validateConfig(config: EnhancedProviderConfig): {
|
||||
valid: boolean;
|
||||
errors: string[];
|
||||
warnings: string[];
|
||||
} {
|
||||
const errors: string[] = [];
|
||||
const warnings: string[] = [];
|
||||
|
||||
// Check required fields
|
||||
if (!config.provider) {
|
||||
errors.push('Provider type is required');
|
||||
}
|
||||
|
||||
// Provider-specific validation
|
||||
switch (config.provider) {
|
||||
case 'openai':
|
||||
case 'anthropic':
|
||||
if (!config.apiKey && !config.baseUrl?.includes('localhost')) {
|
||||
errors.push('API key is required for cloud providers');
|
||||
}
|
||||
break;
|
||||
|
||||
case 'ollama':
|
||||
if (!config.baseUrl) {
|
||||
errors.push('Base URL is required for Ollama');
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
// Validate URLs
|
||||
if (config.baseUrl) {
|
||||
try {
|
||||
new URL(config.baseUrl);
|
||||
} catch {
|
||||
errors.push('Invalid base URL format');
|
||||
}
|
||||
}
|
||||
|
||||
// Validate timeout
|
||||
if (config.timeout && config.timeout < 1000) {
|
||||
warnings.push('Timeout less than 1 second may cause issues');
|
||||
}
|
||||
|
||||
// Validate rate limits
|
||||
if (config.rateLimitPerMinute && config.rateLimitPerMinute < 1) {
|
||||
errors.push('Rate limit must be at least 1 request per minute');
|
||||
}
|
||||
|
||||
return {
|
||||
valid: errors.length === 0,
|
||||
errors,
|
||||
warnings
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Export configuration as JSON
|
||||
*/
|
||||
public exportConfig(provider: string): string {
|
||||
const config = this.getProviderConfig(provider);
|
||||
if (!config) {
|
||||
throw new Error(`No configuration for provider: ${provider}`);
|
||||
}
|
||||
|
||||
// Remove sensitive data
|
||||
const exported = { ...config };
|
||||
if (exported.apiKey) {
|
||||
exported.apiKey = '***REDACTED***';
|
||||
}
|
||||
|
||||
return JSON.stringify(exported, null, 2);
|
||||
}
|
||||
|
||||
/**
|
||||
* Import configuration from JSON
|
||||
*/
|
||||
public importConfig(provider: string, json: string): void {
|
||||
try {
|
||||
const config = JSON.parse(json) as EnhancedProviderConfig;
|
||||
|
||||
// Validate before importing
|
||||
const validation = this.validateConfig(config);
|
||||
if (!validation.valid) {
|
||||
throw new Error(`Invalid configuration: ${validation.errors.join(', ')}`);
|
||||
}
|
||||
|
||||
// Don't import if API key is redacted
|
||||
if (config.apiKey === '***REDACTED***') {
|
||||
delete config.apiKey;
|
||||
}
|
||||
|
||||
this.updateProviderConfig(provider, config);
|
||||
log.info(`Imported configuration for ${provider}`);
|
||||
} catch (error) {
|
||||
log.error(`Failed to import configuration: ${error}`);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Export singleton instance shared across the application
export const providerConfigManager = new ProviderConfigurationManager();
|
||||
888
apps/server/src/services/llm/providers/provider_factory.ts
Normal file
888
apps/server/src/services/llm/providers/provider_factory.ts
Normal file
@@ -0,0 +1,888 @@
|
||||
/**
|
||||
* Provider Factory Pattern Implementation
|
||||
*
|
||||
* This module implements a factory pattern for clean provider instantiation,
|
||||
* unified streaming interfaces, capability detection, and provider-specific
|
||||
* feature support.
|
||||
*/
|
||||
|
||||
import log from '../../log.js';
|
||||
import type { AIService, ChatCompletionOptions } from '../ai_interface.js';
|
||||
import { OpenAIService } from './openai_service.js';
|
||||
import { AnthropicService } from './anthropic_service.js';
|
||||
import { OllamaService } from './ollama_service.js';
|
||||
import type {
|
||||
OpenAIOptions,
|
||||
AnthropicOptions,
|
||||
OllamaOptions,
|
||||
ModelMetadata
|
||||
} from './provider_options.js';
|
||||
import {
|
||||
getOpenAIOptions,
|
||||
getAnthropicOptions,
|
||||
getOllamaOptions
|
||||
} from './providers.js';
|
||||
import {
|
||||
CircuitBreakerManager,
|
||||
CircuitOpenError,
|
||||
type CircuitBreakerConfig
|
||||
} from './circuit_breaker.js';
|
||||
import {
|
||||
MetricsExporter,
|
||||
ExportFormat,
|
||||
type ExporterConfig
|
||||
} from '../metrics/metrics_exporter.js';
|
||||
|
||||
/**
 * Provider type enumeration.
 *
 * String enum so values round-trip cleanly through configuration and logs.
 */
export enum ProviderType {
    OPENAI = 'openai',
    ANTHROPIC = 'anthropic',
    OLLAMA = 'ollama',
    /** Reserved for future extensibility; creation currently throws. */
    CUSTOM = 'custom'
}
|
||||
|
||||
/**
 * Feature/limit profile of a provider. Defaults are registered in
 * initializeCapabilities() and may be refined per service instance.
 */
export interface ProviderCapabilities {
    streaming: boolean;
    functionCalling: boolean;
    vision: boolean;
    /** Maximum prompt context size in tokens. */
    contextWindow: number;
    /** Maximum tokens generated per response. */
    maxOutputTokens: number;
    supportsSystemPrompt: boolean;
    supportsTools: boolean;
    /** e.g. ['text', 'image'] */
    supportedModalities: string[];
    /** Whether custom base URLs/endpoints are supported. */
    customEndpoints: boolean;
    batchProcessing: boolean;
}
|
||||
|
||||
/**
 * Result of the most recent health probe for a provider.
 */
export interface ProviderHealthStatus {
    provider: ProviderType;
    healthy: boolean;
    lastChecked: Date;
    /** Round-trip time of the probe in milliseconds (set on success). */
    latency?: number;
    /** Error message (set on failure). */
    error?: string;
    version?: string;
}
|
||||
|
||||
/**
 * Connection/transport configuration for a single provider.
 */
export interface ProviderConfig {
    type: ProviderType;
    apiKey?: string;
    baseUrl?: string;
    /** Request timeout in milliseconds. */
    timeout?: number;
    maxRetries?: number;
    /** Delay between retries in milliseconds. */
    retryDelay?: number;
    /** Extra HTTP headers merged into every request. */
    customHeaders?: Record<string, string>;
    proxy?: string;
}
|
||||
|
||||
/**
 * Factory creation options. Defaults (applied in the constructor):
 * health checks on (60 s interval), fallback on (to Ollama), caching on
 * (5 min TTL), metrics and circuit breaker on.
 */
export interface ProviderFactoryOptions {
    enableHealthChecks?: boolean;
    /** Health check period in milliseconds. */
    healthCheckInterval?: number;
    enableFallback?: boolean;
    /** Providers tried, in order, when the requested one fails. */
    fallbackProviders?: ProviderType[];
    enableCaching?: boolean;
    /** Cached instance TTL in milliseconds. */
    cacheTimeout?: number;
    enableMetrics?: boolean;
    enableCircuitBreaker?: boolean;
    circuitBreakerConfig?: Partial<CircuitBreakerConfig>;
    metricsExporterConfig?: Partial<ExporterConfig>;
}
|
||||
|
||||
/**
 * A cached provider together with its bookkeeping metadata.
 */
interface ProviderInstance {
    service: AIService;
    type: ProviderType;
    capabilities: ProviderCapabilities;
    config: ProviderConfig;
    createdAt: Date;
    lastUsed: Date;
    /** Number of times this cached instance has been handed out. */
    usageCount: number;
    healthStatus?: ProviderHealthStatus;
}
|
||||
|
||||
/**
 * Provider Factory Class
 *
 * Manages creation, caching, and lifecycle of AI service providers, with
 * optional health checks, circuit breaking, fallback, and metrics export.
 */
export class ProviderFactory {
    // Singleton instance (see getInstance())
    private static instance: ProviderFactory | null = null;
    // Cached provider instances keyed by getCacheKey()
    private providers: Map<string, ProviderInstance> = new Map();
    // Default capability profile per provider type
    private capabilities: Map<ProviderType, ProviderCapabilities> = new Map();
    // Latest health probe result per provider type
    private healthStatuses: Map<ProviderType, ProviderHealthStatus> = new Map();
    // Effective options (defaults applied in the constructor)
    private options: ProviderFactoryOptions;
    private healthCheckTimer?: NodeJS.Timeout;
    // Once true, createProvider() refuses to run
    private disposed: boolean = false;
    private circuitBreakerManager?: CircuitBreakerManager;
    private metricsExporter?: MetricsExporter;
|
||||
|
||||
    // Builds the effective options (filling in defaults), seeds the
    // capability registry, and wires up the optional circuit breaker,
    // metrics exporter, and periodic health checks.
    constructor(options: ProviderFactoryOptions = {}) {
        this.options = {
            enableHealthChecks: options.enableHealthChecks ?? true,
            healthCheckInterval: options.healthCheckInterval ?? 60000, // 1 minute
            enableFallback: options.enableFallback ?? true,
            fallbackProviders: options.fallbackProviders ?? [ProviderType.OLLAMA],
            enableCaching: options.enableCaching ?? true,
            cacheTimeout: options.cacheTimeout ?? 300000, // 5 minutes
            enableMetrics: options.enableMetrics ?? true,
            enableCircuitBreaker: options.enableCircuitBreaker ?? true,
            circuitBreakerConfig: options.circuitBreakerConfig,
            metricsExporterConfig: options.metricsExporterConfig
        };

        this.initializeCapabilities();

        // Initialize circuit breaker if enabled (shared singleton manager)
        if (this.options.enableCircuitBreaker) {
            this.circuitBreakerManager = CircuitBreakerManager.getInstance(
                this.options.circuitBreakerConfig
            );
        }

        // Initialize metrics exporter if enabled (shared singleton)
        if (this.options.enableMetrics) {
            this.metricsExporter = MetricsExporter.getInstance({
                enabled: true,
                ...this.options.metricsExporterConfig
            });
        }

        if (this.options.enableHealthChecks) {
            this.startHealthChecks();
        }
    }
|
||||
|
||||
    /**
     * Get the singleton instance, creating it on first call.
     *
     * NOTE(review): `options` only take effect on the very first call;
     * later calls silently ignore them — confirm callers expect this.
     */
    public static getInstance(options?: ProviderFactoryOptions): ProviderFactory {
        if (!ProviderFactory.instance) {
            ProviderFactory.instance = new ProviderFactory(options);
        }
        return ProviderFactory.instance;
    }
|
||||
|
||||
    /**
     * Seed the capability registry with known defaults per provider type.
     * Entries can later be overridden via registerCapabilities().
     */
    private initializeCapabilities(): void {
        // OpenAI capabilities
        this.capabilities.set(ProviderType.OPENAI, {
            streaming: true,
            functionCalling: true,
            vision: true,
            contextWindow: 128000, // GPT-4 Turbo
            maxOutputTokens: 4096,
            supportsSystemPrompt: true,
            supportsTools: true,
            supportedModalities: ['text', 'image'],
            customEndpoints: true,
            batchProcessing: true
        });

        // Anthropic capabilities
        this.capabilities.set(ProviderType.ANTHROPIC, {
            streaming: true,
            functionCalling: true,
            vision: true,
            contextWindow: 200000, // Claude 3
            maxOutputTokens: 4096,
            supportsSystemPrompt: true,
            supportsTools: true,
            supportedModalities: ['text', 'image'],
            customEndpoints: false,
            batchProcessing: false
        });

        // Ollama capabilities (default, can be overridden per model)
        this.capabilities.set(ProviderType.OLLAMA, {
            streaming: true,
            functionCalling: true,
            vision: false,
            contextWindow: 8192, // Default, varies by model
            maxOutputTokens: 2048,
            supportsSystemPrompt: true,
            supportsTools: true,
            supportedModalities: ['text'],
            customEndpoints: true,
            batchProcessing: false
        });
    }
|
||||
|
||||
/**
|
||||
* Create a provider instance
|
||||
*/
|
||||
public async createProvider(
|
||||
type: ProviderType,
|
||||
config?: Partial<ProviderConfig>,
|
||||
options?: ChatCompletionOptions
|
||||
): Promise<AIService> {
|
||||
if (this.disposed) {
|
||||
throw new Error('ProviderFactory has been disposed');
|
||||
}
|
||||
|
||||
const cacheKey = this.getCacheKey(type, config);
|
||||
|
||||
// Check cache if enabled
|
||||
if (this.options.enableCaching) {
|
||||
const cached = this.providers.get(cacheKey);
|
||||
if (cached && this.isInstanceValid(cached)) {
|
||||
cached.lastUsed = new Date();
|
||||
cached.usageCount++;
|
||||
|
||||
if (this.options.enableMetrics) {
|
||||
log.info(`[ProviderFactory] Using cached ${type} provider (usage: ${cached.usageCount})`);
|
||||
}
|
||||
|
||||
return cached.service;
|
||||
}
|
||||
}
|
||||
|
||||
// Create new provider instance
|
||||
const service = await this.instantiateProvider(type, config, options);
|
||||
|
||||
if (!service) {
|
||||
throw new Error(`Failed to create provider of type: ${type}`);
|
||||
}
|
||||
|
||||
// Get capabilities for this provider
|
||||
const capabilities = await this.detectCapabilities(type, service);
|
||||
|
||||
// Create provider instance
|
||||
const instance: ProviderInstance = {
|
||||
service,
|
||||
type,
|
||||
capabilities,
|
||||
config: { type, ...config },
|
||||
createdAt: new Date(),
|
||||
lastUsed: new Date(),
|
||||
usageCount: 1
|
||||
};
|
||||
|
||||
// Cache the instance
|
||||
if (this.options.enableCaching) {
|
||||
this.providers.set(cacheKey, instance);
|
||||
|
||||
// Schedule cache cleanup
|
||||
setTimeout(() => {
|
||||
this.cleanupCache(cacheKey);
|
||||
}, this.options.cacheTimeout);
|
||||
}
|
||||
|
||||
if (this.options.enableMetrics) {
|
||||
log.info(`[ProviderFactory] Created new ${type} provider`);
|
||||
}
|
||||
|
||||
return service;
|
||||
}
|
||||
|
||||
    /**
     * Instantiate a specific provider, wrapped in circuit-breaker protection
     * and metrics recording when those features are enabled.
     *
     * On failure (other than an already-open circuit, which tries fallback
     * itself) the configured fallback providers are tried before the error
     * is rethrown.
     */
    private async instantiateProvider(
        type: ProviderType,
        config?: Partial<ProviderConfig>,
        options?: ChatCompletionOptions
    ): Promise<AIService | null> {
        const startTime = Date.now();

        try {
            // Use circuit breaker if enabled
            if (this.circuitBreakerManager) {
                const breaker = this.circuitBreakerManager.getBreaker(type);

                // Fail fast while the circuit is open instead of hammering a
                // provider that is known to be failing.
                if (!breaker.isAvailable()) {
                    const nextRetry = breaker.getNextRetryTime();
                    log.info(`[ProviderFactory] Circuit breaker OPEN for ${type}. Next retry: ${nextRetry?.toISOString()}`);

                    // Record metric
                    if (this.metricsExporter) {
                        this.metricsExporter.getCollector().recordError(type, 'Circuit breaker open');
                    }

                    // Try fallback immediately
                    if (this.options.enableFallback && this.options.fallbackProviders?.length) {
                        return this.tryFallbackProvider(options);
                    }

                    throw new CircuitOpenError(type, nextRetry!);
                }

                // Execute with circuit breaker protection; failures inside
                // execute() count toward opening the circuit.
                return await breaker.execute(async () => {
                    const service = await this.createProviderInternal(type, config, options);

                    // Record success metric
                    if (this.metricsExporter && service) {
                        const latency = Date.now() - startTime;
                        this.metricsExporter.getCollector().recordLatency(type, latency);
                        this.metricsExporter.getCollector().recordRequest(type, true);
                    }

                    return service;
                });
            } else {
                // No circuit breaker, create directly
                const service = await this.createProviderInternal(type, config, options);

                // Record metrics
                if (this.metricsExporter && service) {
                    const latency = Date.now() - startTime;
                    this.metricsExporter.getCollector().recordLatency(type, latency);
                    this.metricsExporter.getCollector().recordRequest(type, true);
                }

                return service;
            }
        } catch (error: any) {
            log.error(`[ProviderFactory] Error creating ${type} provider: ${error.message}`);

            // Record failure metric
            if (this.metricsExporter) {
                this.metricsExporter.getCollector().recordRequest(type, false);
                this.metricsExporter.getCollector().recordError(type, error.message);
            }

            // Try fallback if enabled and not a circuit breaker error
            // (an open circuit already attempted fallback above).
            if (!(error instanceof CircuitOpenError) &&
                this.options.enableFallback &&
                this.options.fallbackProviders?.length) {
                return this.tryFallbackProvider(options);
            }

            throw error;
        }
    }
|
||||
|
||||
/**
|
||||
* Internal provider creation logic
|
||||
*/
|
||||
private async createProviderInternal(
|
||||
type: ProviderType,
|
||||
config?: Partial<ProviderConfig>,
|
||||
options?: ChatCompletionOptions
|
||||
): Promise<AIService | null> {
|
||||
switch (type) {
|
||||
case ProviderType.OPENAI:
|
||||
return this.createOpenAIProvider(config, options);
|
||||
|
||||
case ProviderType.ANTHROPIC:
|
||||
return this.createAnthropicProvider(config, options);
|
||||
|
||||
case ProviderType.OLLAMA:
|
||||
return await this.createOllamaProvider(config, options);
|
||||
|
||||
case ProviderType.CUSTOM:
|
||||
return this.createCustomProvider(config, options);
|
||||
|
||||
default:
|
||||
log.error(`[ProviderFactory] Unknown provider type: ${type}`);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create OpenAI provider
|
||||
*/
|
||||
private createOpenAIProvider(
|
||||
config?: Partial<ProviderConfig>,
|
||||
options?: ChatCompletionOptions
|
||||
): AIService {
|
||||
const service = new OpenAIService();
|
||||
|
||||
if (!service.isAvailable()) {
|
||||
throw new Error('OpenAI service is not available');
|
||||
}
|
||||
|
||||
return service;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create Anthropic provider
|
||||
*/
|
||||
private createAnthropicProvider(
|
||||
config?: Partial<ProviderConfig>,
|
||||
options?: ChatCompletionOptions
|
||||
): AIService {
|
||||
const service = new AnthropicService();
|
||||
|
||||
if (!service.isAvailable()) {
|
||||
throw new Error('Anthropic service is not available');
|
||||
}
|
||||
|
||||
return service;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create Ollama provider
|
||||
*/
|
||||
private async createOllamaProvider(
|
||||
config?: Partial<ProviderConfig>,
|
||||
options?: ChatCompletionOptions
|
||||
): Promise<AIService> {
|
||||
const service = new OllamaService();
|
||||
|
||||
if (!service.isAvailable()) {
|
||||
throw new Error('Ollama service is not available');
|
||||
}
|
||||
|
||||
// Ollama might need model pulling or other async setup
|
||||
// This is handled internally by the service
|
||||
|
||||
return service;
|
||||
}
|
||||
|
||||
    /**
     * Create custom provider (for future extensibility).
     * Currently always throws — custom providers are not implemented.
     */
    private createCustomProvider(
        config?: Partial<ProviderConfig>,
        options?: ChatCompletionOptions
    ): AIService {
        throw new Error('Custom providers not yet implemented');
    }
|
||||
|
||||
/**
|
||||
* Try fallback providers
|
||||
*/
|
||||
private async tryFallbackProvider(options?: ChatCompletionOptions): Promise<AIService | null> {
|
||||
if (!this.options.fallbackProviders) {
|
||||
return null;
|
||||
}
|
||||
|
||||
for (const fallbackType of this.options.fallbackProviders) {
|
||||
try {
|
||||
log.info(`[ProviderFactory] Trying fallback provider: ${fallbackType}`);
|
||||
const service = await this.instantiateProvider(fallbackType, undefined, options);
|
||||
|
||||
if (service && service.isAvailable()) {
|
||||
log.info(`[ProviderFactory] Fallback to ${fallbackType} successful`);
|
||||
return service;
|
||||
}
|
||||
} catch (error) {
|
||||
log.error(`[ProviderFactory] Fallback to ${fallbackType} failed: ${error}`);
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Detect capabilities for a provider
|
||||
*/
|
||||
private async detectCapabilities(
|
||||
type: ProviderType,
|
||||
service: AIService
|
||||
): Promise<ProviderCapabilities> {
|
||||
// Start with default capabilities
|
||||
let capabilities = this.capabilities.get(type) || this.getDefaultCapabilities();
|
||||
|
||||
// Try to detect actual capabilities from the service
|
||||
try {
|
||||
// Check for streaming support
|
||||
if ('supportsStreaming' in service && typeof service.supportsStreaming === 'function') {
|
||||
capabilities.streaming = (service as any).supportsStreaming();
|
||||
}
|
||||
|
||||
// Check for tool support
|
||||
if ('supportsTools' in service && typeof service.supportsTools === 'function') {
|
||||
capabilities.supportsTools = (service as any).supportsTools();
|
||||
}
|
||||
|
||||
// For Ollama, try to get model-specific capabilities
|
||||
if (type === ProviderType.OLLAMA) {
|
||||
capabilities = await this.detectOllamaCapabilities(service, capabilities);
|
||||
}
|
||||
} catch (error) {
|
||||
log.info(`[ProviderFactory] Could not detect capabilities for ${type}: ${error}`);
|
||||
}
|
||||
|
||||
return capabilities;
|
||||
}
|
||||
|
||||
    /**
     * Detect Ollama-specific (per-model) capabilities.
     *
     * Placeholder: a real implementation would query the Ollama API for
     * model details; for now the passed-in defaults are returned unchanged.
     */
    private async detectOllamaCapabilities(
        service: AIService,
        defaultCaps: ProviderCapabilities
    ): Promise<ProviderCapabilities> {
        return defaultCaps;
    }
|
||||
|
||||
/**
|
||||
* Get default capabilities
|
||||
*/
|
||||
private getDefaultCapabilities(): ProviderCapabilities {
|
||||
return {
|
||||
streaming: true,
|
||||
functionCalling: false,
|
||||
vision: false,
|
||||
contextWindow: 4096,
|
||||
maxOutputTokens: 1024,
|
||||
supportsSystemPrompt: true,
|
||||
supportsTools: false,
|
||||
supportedModalities: ['text'],
|
||||
customEndpoints: false,
|
||||
batchProcessing: false
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform health check on a provider
|
||||
*/
|
||||
public async checkProviderHealth(type: ProviderType): Promise<ProviderHealthStatus> {
|
||||
const startTime = Date.now();
|
||||
|
||||
try {
|
||||
const service = await this.createProvider(type);
|
||||
|
||||
// Try a simple completion to test the service
|
||||
const testMessages = [{ role: 'user' as const, content: 'Hi' }];
|
||||
const response = await service.generateChatCompletion(testMessages, {
|
||||
maxTokens: 1,
|
||||
temperature: 0
|
||||
});
|
||||
|
||||
const latency = Date.now() - startTime;
|
||||
|
||||
const status: ProviderHealthStatus = {
|
||||
provider: type,
|
||||
healthy: true,
|
||||
lastChecked: new Date(),
|
||||
latency
|
||||
};
|
||||
|
||||
this.healthStatuses.set(type, status);
|
||||
return status;
|
||||
} catch (error: any) {
|
||||
const status: ProviderHealthStatus = {
|
||||
provider: type,
|
||||
healthy: false,
|
||||
lastChecked: new Date(),
|
||||
error: error.message || 'Unknown error'
|
||||
};
|
||||
|
||||
this.healthStatuses.set(type, status);
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Start periodic health checks
|
||||
*/
|
||||
private startHealthChecks(): void {
|
||||
if (this.healthCheckTimer) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.healthCheckTimer = setInterval(async () => {
|
||||
if (this.disposed) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (const type of this.capabilities.keys()) {
|
||||
try {
|
||||
await this.checkProviderHealth(type);
|
||||
} catch (error) {
|
||||
log.error(`[ProviderFactory] Health check failed for ${type}: ${error}`);
|
||||
}
|
||||
}
|
||||
}, this.options.healthCheckInterval);
|
||||
|
||||
// Perform initial health check
|
||||
this.performInitialHealthCheck();
|
||||
}
|
||||
|
||||
    /**
     * Run one health check per known provider type, sequentially, logging
     * (not propagating) individual failures.
     */
    private async performInitialHealthCheck(): Promise<void> {
        for (const type of this.capabilities.keys()) {
            try {
                await this.checkProviderHealth(type);
            } catch (error) {
                log.error(`[ProviderFactory] Initial health check failed for ${type}: ${error}`);
            }
        }
    }
|
||||
|
||||
    /**
     * Get the most recent health status recorded for a provider, if any.
     */
    public getHealthStatus(type: ProviderType): ProviderHealthStatus | undefined {
        return this.healthStatuses.get(type);
    }
|
||||
|
||||
    /**
     * Get a snapshot copy of all health statuses (callers may mutate the
     * returned map without affecting internal state).
     */
    public getAllHealthStatuses(): Map<ProviderType, ProviderHealthStatus> {
        return new Map(this.healthStatuses);
    }
|
||||
|
||||
    /**
     * Get the registered capability profile for a provider type, if any.
     */
    public getCapabilities(type: ProviderType): ProviderCapabilities | undefined {
        return this.capabilities.get(type);
    }
|
||||
|
||||
    /**
     * Register (or override) the capability profile for a provider type.
     */
    public registerCapabilities(type: ProviderType, capabilities: ProviderCapabilities): void {
        this.capabilities.set(type, capabilities);
    }
|
||||
|
||||
/**
|
||||
* Get cache key for provider
|
||||
*/
|
||||
private getCacheKey(type: ProviderType, config?: Partial<ProviderConfig>): string {
|
||||
const baseKey = type;
|
||||
|
||||
if (config?.baseUrl) {
|
||||
return `${baseKey}:${config.baseUrl}`;
|
||||
}
|
||||
|
||||
return baseKey;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if cached instance is still valid
|
||||
*/
|
||||
private isInstanceValid(instance: ProviderInstance): boolean {
|
||||
if (!this.options.cacheTimeout) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const age = Date.now() - instance.createdAt.getTime();
|
||||
return age < this.options.cacheTimeout;
|
||||
}
|
||||
|
||||
/**
|
||||
* Cleanup specific cache entry
|
||||
*/
|
||||
private cleanupCache(key: string): void {
|
||||
const instance = this.providers.get(key);
|
||||
|
||||
if (instance && !this.isInstanceValid(instance)) {
|
||||
this.disposeProvider(instance);
|
||||
this.providers.delete(key);
|
||||
|
||||
if (this.options.enableMetrics) {
|
||||
log.info(`[ProviderFactory] Cleaned up cached provider: ${key}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Cleanup all expired cache entries
|
||||
*/
|
||||
public cleanupExpiredCache(): void {
|
||||
const keys = Array.from(this.providers.keys());
|
||||
|
||||
for (const key of keys) {
|
||||
this.cleanupCache(key);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Dispose a provider instance
|
||||
*/
|
||||
private disposeProvider(instance: ProviderInstance): void {
|
||||
try {
|
||||
if ('dispose' in instance.service && typeof (instance.service as any).dispose === 'function') {
|
||||
(instance.service as any).dispose();
|
||||
}
|
||||
} catch (error) {
|
||||
log.error(`[ProviderFactory] Error disposing provider: ${error}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get provider statistics
|
||||
*/
|
||||
public getStatistics(): {
|
||||
cachedProviders: number;
|
||||
totalUsage: number;
|
||||
providerUsage: Record<string, number>;
|
||||
healthyProviders: number;
|
||||
unhealthyProviders: number;
|
||||
} {
|
||||
const stats = {
|
||||
cachedProviders: this.providers.size,
|
||||
totalUsage: 0,
|
||||
providerUsage: {} as Record<string, number>,
|
||||
healthyProviders: 0,
|
||||
unhealthyProviders: 0
|
||||
};
|
||||
|
||||
// Calculate usage statistics
|
||||
for (const [key, instance] of this.providers) {
|
||||
stats.totalUsage += instance.usageCount;
|
||||
|
||||
const type = instance.type.toString();
|
||||
stats.providerUsage[type] = (stats.providerUsage[type] || 0) + instance.usageCount;
|
||||
}
|
||||
|
||||
// Calculate health statistics
|
||||
for (const status of this.healthStatuses.values()) {
|
||||
if (status.healthy) {
|
||||
stats.healthyProviders++;
|
||||
} else {
|
||||
stats.unhealthyProviders++;
|
||||
}
|
||||
}
|
||||
|
||||
return stats;
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear all cached providers
|
||||
*/
|
||||
public clearCache(): void {
|
||||
for (const instance of this.providers.values()) {
|
||||
this.disposeProvider(instance);
|
||||
}
|
||||
|
||||
this.providers.clear();
|
||||
|
||||
if (this.options.enableMetrics) {
|
||||
log.info('[ProviderFactory] Cleared all cached providers');
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Get circuit breaker status
 *
 * @returns null when circuit breaking is disabled; otherwise an object
 *          with an overall health summary plus per-provider breaker
 *          stats, one entry per registered breaker
 */
public getCircuitBreakerStatus(): any {
    if (!this.circuitBreakerManager) {
        return null;
    }

    return {
        summary: this.circuitBreakerManager.getHealthSummary(),
        // Flatten the name->stats map into an array of { provider, ...stats }
        details: Array.from(this.circuitBreakerManager.getAllStats().entries()).map(([name, stats]) => ({
            provider: name,
            ...stats
        }))
    };
}
|
||||
|
||||
/**
 * Get metrics summary
 *
 * @returns null when metrics export is disabled; otherwise per-provider
 *          metric snapshots plus system-wide metrics from the collector
 */
public getMetricsSummary(): any {
    if (!this.metricsExporter) {
        return null;
    }

    const collector = this.metricsExporter.getCollector();
    return {
        providers: Array.from(collector.getProviderMetricsMap().values()),
        system: collector.getSystemMetrics()
    };
}
|
||||
|
||||
/**
|
||||
* Export metrics in specified format
|
||||
*/
|
||||
public exportMetrics(format?: 'prometheus' | 'statsd' | 'opentelemetry' | 'json'): any {
|
||||
if (!this.metricsExporter) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const exportFormat = format ? {
|
||||
prometheus: ExportFormat.PROMETHEUS,
|
||||
statsd: ExportFormat.STATSD,
|
||||
opentelemetry: ExportFormat.OPENTELEMETRY,
|
||||
json: ExportFormat.JSON
|
||||
}[format] : undefined;
|
||||
|
||||
return this.metricsExporter.export(exportFormat);
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset circuit breaker for a specific provider
|
||||
*/
|
||||
public resetCircuitBreaker(provider: ProviderType): void {
|
||||
if (!this.circuitBreakerManager) {
|
||||
return;
|
||||
}
|
||||
|
||||
const breaker = this.circuitBreakerManager.getBreaker(provider);
|
||||
breaker.forceClose('Manual reset');
|
||||
log.info(`[ProviderFactory] Circuit breaker reset for ${provider}`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Configure metrics export
|
||||
*/
|
||||
public configureMetricsExport(config: Partial<ExporterConfig>): void {
|
||||
if (!this.metricsExporter) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.metricsExporter.updateConfig(config);
|
||||
log.info('[ProviderFactory] Metrics export configuration updated');
|
||||
}
|
||||
|
||||
/**
 * Dispose the factory and cleanup resources
 *
 * Tears everything down in order: stop the periodic health check timer,
 * dispose the circuit breaker manager and metrics exporter, dispose all
 * cached provider instances, and finally release the singleton so a later
 * getProviderFactory() call builds a fresh factory.
 * Safe to call multiple times; subsequent calls are no-ops.
 */
public dispose(): void {
    // Idempotence guard — the flag is set before teardown begins so
    // reentrant calls during disposal also bail out immediately.
    if (this.disposed) {
        return;
    }

    this.disposed = true;

    // Stop health checks
    if (this.healthCheckTimer) {
        clearInterval(this.healthCheckTimer);
        this.healthCheckTimer = undefined;
    }

    // Dispose circuit breaker manager
    if (this.circuitBreakerManager) {
        this.circuitBreakerManager.dispose();
    }

    // Dispose metrics exporter
    if (this.metricsExporter) {
        this.metricsExporter.dispose();
    }

    // Clear cache (disposes each cached provider instance as it goes)
    this.clearCache();

    // Clear singleton instance so the factory can be re-created later
    ProviderFactory.instance = null;

    log.info('[ProviderFactory] Disposed successfully');
}
|
||||
}
|
||||
|
||||
// Export singleton instance getter
/**
 * Module-level accessor for the ProviderFactory singleton.
 *
 * NOTE(review): `options` presumably only takes effect on the call that
 * first creates the instance — confirm against ProviderFactory.getInstance.
 */
export const getProviderFactory = (options?: ProviderFactoryOptions): ProviderFactory => {
    return ProviderFactory.getInstance(options);
};
|
||||
662
apps/server/src/services/llm/providers/unified_stream_handler.ts
Normal file
662
apps/server/src/services/llm/providers/unified_stream_handler.ts
Normal file
@@ -0,0 +1,662 @@
|
||||
/**
|
||||
* Unified Stream Handler
|
||||
*
|
||||
* Provides a consistent streaming interface across all providers,
|
||||
* handling provider-specific stream formats and normalizing them
|
||||
* into a unified format.
|
||||
*/
|
||||
|
||||
import log from '../../log.js';
|
||||
import type { ChatResponse } from '../ai_interface.js';
|
||||
|
||||
/**
 * Unified stream chunk format
 *
 * Discriminated union over `type`:
 * - 'content':   incremental text; `content` carries the delta
 * - 'tool_call': a (possibly partial) tool invocation; `toolCall` is set
 * - 'error':     stream failure; `error` carries the message
 * - 'done':      terminal chunk; `metadata` may carry finishReason/usage
 */
export interface UnifiedStreamChunk {
    /** Discriminator — which of the optional payload fields applies. */
    type: 'content' | 'tool_call' | 'error' | 'done';
    /** Incremental text delta (set when type === 'content'). */
    content?: string;
    /** Tool invocation payload (set when type === 'tool_call'). */
    toolCall?: {
        id: string;
        name: string;
        /** Raw JSON argument string, accumulated across partial deltas. */
        arguments: string;
    };
    /** Error message (set when type === 'error'). */
    error?: string;
    /** Provider/model bookkeeping; most complete on the 'done' chunk. */
    metadata?: {
        provider: string;
        model?: string;
        finishReason?: string;
        usage?: {
            promptTokens?: number;
            completionTokens?: number;
            totalTokens?: number;
        };
    };
}
|
||||
|
||||
/**
 * Stream handler configuration
 */
export interface StreamHandlerConfig {
    /** Which provider's wire format the incoming raw chunks use. */
    provider: 'openai' | 'anthropic' | 'ollama';
    /** Invoked for every normalized chunk; may be async. */
    onChunk: (chunk: UnifiedStreamChunk) => void | Promise<void>;
    /** Invoked on stream failure; when omitted, errors are only logged. */
    onError?: (error: Error) => void;
    /** Invoked once with the fully assembled response when the stream ends. */
    onComplete?: (response: ChatResponse) => void;
    /** NOTE(review): declared but never read by any handler in this file — confirm intended use. */
    bufferSize?: number;
    /** Inactivity timeout in ms; the stream errors out if no chunk arrives in time. */
    timeout?: number;
}
|
||||
|
||||
/**
|
||||
* Abstract base class for provider-specific stream handlers
|
||||
*/
|
||||
export abstract class BaseStreamHandler {
|
||||
protected config: StreamHandlerConfig;
|
||||
protected buffer: string = '';
|
||||
protected response: Partial<ChatResponse> = {};
|
||||
protected finishReason?: string;
|
||||
protected isComplete: boolean = false;
|
||||
protected timeoutTimer?: NodeJS.Timeout;
|
||||
|
||||
constructor(config: StreamHandlerConfig) {
|
||||
this.config = config;
|
||||
|
||||
if (config.timeout) {
|
||||
this.setTimeoutTimer(config.timeout);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Process a stream chunk from the provider
|
||||
*/
|
||||
public abstract processChunk(chunk: any): Promise<void>;
|
||||
|
||||
/**
|
||||
* Complete the stream processing
|
||||
*/
|
||||
public abstract complete(): Promise<ChatResponse>;
|
||||
|
||||
/**
|
||||
* Handle stream error
|
||||
*/
|
||||
public handleError(error: Error): void {
|
||||
this.clearTimeoutTimer();
|
||||
|
||||
if (this.config.onError) {
|
||||
this.config.onError(error);
|
||||
} else {
|
||||
log.error(`[StreamHandler] Stream error: ${error.message}`);
|
||||
}
|
||||
|
||||
// Send error chunk
|
||||
this.sendChunk({
|
||||
type: 'error',
|
||||
error: error.message,
|
||||
metadata: {
|
||||
provider: this.config.provider
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a unified chunk to the consumer
|
||||
*/
|
||||
protected async sendChunk(chunk: UnifiedStreamChunk): Promise<void> {
|
||||
try {
|
||||
await this.config.onChunk(chunk);
|
||||
} catch (error) {
|
||||
log.error(`[StreamHandler] Error in chunk handler: ${error}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Set timeout timer
|
||||
*/
|
||||
protected setTimeoutTimer(timeout: number): void {
|
||||
this.timeoutTimer = setTimeout(() => {
|
||||
this.handleError(new Error(`Stream timeout after ${timeout}ms`));
|
||||
}, timeout);
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear timeout timer
|
||||
*/
|
||||
protected clearTimeoutTimer(): void {
|
||||
if (this.timeoutTimer) {
|
||||
clearTimeout(this.timeoutTimer);
|
||||
this.timeoutTimer = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset timeout timer
|
||||
*/
|
||||
protected resetTimeoutTimer(): void {
|
||||
if (this.config.timeout) {
|
||||
this.clearTimeoutTimer();
|
||||
this.setTimeoutTimer(this.config.timeout);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * OpenAI stream handler
 *
 * Parses OpenAI chat-completion stream chunks (SSE "data: ..." payloads or
 * pre-parsed objects) into unified chunks: content deltas are accumulated
 * into the base buffer, tool-call deltas are merged by index, and usage is
 * captured from the finishing chunk when present.
 */
export class OpenAIStreamHandler extends BaseStreamHandler {
    // Partial tool calls keyed by their stream `index`; argument strings
    // are concatenated across deltas until the stream ends.
    private toolCalls: Map<number, any> = new Map();

    /**
     * Process one raw chunk. Errors during parsing are routed to
     * handleError rather than thrown.
     */
    public async processChunk(chunk: any): Promise<void> {
        this.resetTimeoutTimer();

        try {
            // Parse SSE format if needed
            const data = this.parseSSEChunk(chunk);

            // '[DONE]' is OpenAI's explicit end-of-stream sentinel.
            if (!data || data === '[DONE]') {
                await this.sendComplete();
                return;
            }

            const parsed = typeof data === 'string' ? JSON.parse(data) : data;
            const choice = parsed.choices?.[0];

            if (!choice) {
                return;
            }

            // Handle content delta — accumulate and forward downstream.
            if (choice.delta?.content) {
                this.buffer += choice.delta.content;

                await this.sendChunk({
                    type: 'content',
                    content: choice.delta.content,
                    metadata: {
                        provider: 'openai',
                        model: parsed.model
                    }
                });
            }

            // Handle tool calls (each delta may carry several partials)
            if (choice.delta?.tool_calls) {
                for (const toolCall of choice.delta.tool_calls) {
                    await this.processToolCall(toolCall);
                }
            }

            // Check if stream is done; usage only appears on this chunk.
            if (choice.finish_reason) {
                this.finishReason = choice.finish_reason;

                if (parsed.usage) {
                    this.response.usage = {
                        promptTokens: parsed.usage.prompt_tokens,
                        completionTokens: parsed.usage.completion_tokens,
                        totalTokens: parsed.usage.total_tokens
                    };
                }
            }
        } catch (error) {
            log.error(`[OpenAIStreamHandler] Error processing chunk: ${error}`);
            this.handleError(error as Error);
        }
    }

    /**
     * Extract the payload of an SSE "data: " line from a string chunk;
     * non-string chunks are returned unchanged.
     *
     * NOTE(review): only the FIRST "data: " line is returned — a chunk
     * containing several SSE events would drop all but the first. Confirm
     * the upstream reader delivers one event per chunk.
     */
    private parseSSEChunk(chunk: any): string | null {
        if (typeof chunk === 'string') {
            const lines = chunk.split('\n');
            for (const line of lines) {
                if (line.startsWith('data: ')) {
                    return line.slice(6);
                }
            }
        }
        return chunk;
    }

    /**
     * Merge a partial tool-call delta into the accumulated entry for its
     * index (id/name replace, arguments concatenate), then emit the
     * current accumulated state as a 'tool_call' chunk.
     */
    private async processToolCall(toolCall: any): Promise<void> {
        const index = toolCall.index || 0;

        if (!this.toolCalls.has(index)) {
            this.toolCalls.set(index, {
                id: toolCall.id || '',
                type: 'function',
                function: {
                    name: '',
                    arguments: ''
                }
            });
        }

        const existing = this.toolCalls.get(index)!;

        if (toolCall.id) {
            existing.id = toolCall.id;
        }

        if (toolCall.function?.name) {
            existing.function.name = toolCall.function.name;
        }

        if (toolCall.function?.arguments) {
            existing.function.arguments += toolCall.function.arguments;
        }

        // Send tool call chunk (arguments may still be incomplete here)
        await this.sendChunk({
            type: 'tool_call',
            toolCall: {
                id: existing.id,
                name: existing.function.name,
                arguments: existing.function.arguments
            },
            metadata: {
                provider: 'openai'
            }
        });
    }

    /**
     * Mark the stream finished, cancel the inactivity timer, and emit
     * the terminal 'done' chunk with finish reason and usage.
     */
    private async sendComplete(): Promise<void> {
        this.isComplete = true;
        this.clearTimeoutTimer();

        await this.sendChunk({
            type: 'done',
            metadata: {
                provider: 'openai',
                finishReason: this.finishReason,
                usage: this.response.usage
            }
        });
    }

    /**
     * Assemble the final ChatResponse from accumulated state, invoking
     * onComplete if configured. Finishes the stream first if no '[DONE]'
     * sentinel was seen.
     */
    public async complete(): Promise<ChatResponse> {
        if (!this.isComplete) {
            await this.sendComplete();
        }

        const response: ChatResponse = {
            text: this.buffer,
            // NOTE(review): placeholder — the model name seen in parsed.model
            // during streaming is not retained here; confirm whether callers
            // rely on this value.
            model: 'openai-model',
            provider: 'openai',
            usage: this.response.usage
        };

        if (this.toolCalls.size > 0) {
            response.tool_calls = Array.from(this.toolCalls.values());
        }

        if (this.config.onComplete) {
            this.config.onComplete(response);
        }

        return response;
    }
}
|
||||
|
||||
/**
 * Anthropic stream handler
 *
 * Parses Anthropic's SSE event stream (message_start / content_block_delta /
 * message_delta / message_stop / error events) into unified chunks.
 */
export class AnthropicStreamHandler extends BaseStreamHandler {
    // id from the message_start event (captured but not otherwise used here)
    private messageId?: string;
    // stop_reason reported in message_delta; surfaced as finishReason on 'done'
    private stopReason?: string;

    /**
     * Process one raw chunk: parse it as an Anthropic event and dispatch
     * on the event type. Errors are routed to handleError, not thrown.
     */
    public async processChunk(chunk: any): Promise<void> {
        this.resetTimeoutTimer();

        try {
            const event = this.parseAnthropicEvent(chunk);

            if (!event) {
                return;
            }

            switch (event.type) {
                case 'message_start':
                    this.messageId = event.message?.id;
                    break;

                case 'content_block_start':
                    // Content block started
                    break;

                case 'content_block_delta':
                    // Incremental text — accumulate and forward downstream.
                    if (event.delta?.type === 'text_delta') {
                        const text = event.delta.text || '';
                        this.buffer += text;

                        await this.sendChunk({
                            type: 'content',
                            content: text,
                            metadata: {
                                provider: 'anthropic',
                                model: event.model
                            }
                        });
                    }
                    break;

                case 'content_block_stop':
                    // Content block completed
                    break;

                case 'message_delta':
                    // Carries stop_reason and token usage near end of stream.
                    if (event.delta?.stop_reason) {
                        this.stopReason = event.delta.stop_reason;
                    }

                    if (event.usage) {
                        this.response.usage = {
                            promptTokens: event.usage.input_tokens,
                            completionTokens: event.usage.output_tokens,
                            totalTokens: (event.usage.input_tokens || 0) + (event.usage.output_tokens || 0)
                        };
                    }
                    break;

                case 'message_stop':
                    await this.sendComplete();
                    break;

                case 'error':
                    this.handleError(new Error(event.error?.message || 'Unknown error'));
                    break;
            }
        } catch (error) {
            log.error(`[AnthropicStreamHandler] Error processing chunk: ${error}`);
            this.handleError(error as Error);
        }
    }

    /**
     * Parse a raw chunk into an event object.
     *
     * String chunks are treated as SSE ("event: ..." / "data: ..." lines):
     * the data payload is JSON-parsed and tagged with the event type.
     * Non-string chunks are assumed to already be event objects.
     *
     * NOTE(review): when a string chunk contains multiple event/data pairs,
     * only the LAST pair wins (earlier ones are overwritten in the loop) —
     * confirm upstream delivers one event per chunk.
     */
    private parseAnthropicEvent(chunk: any): any {
        if (typeof chunk === 'string') {
            try {
                // Parse SSE format
                const lines = chunk.split('\n');
                let eventType = '';
                let eventData = '';

                for (const line of lines) {
                    if (line.startsWith('event: ')) {
                        eventType = line.slice(7);
                    } else if (line.startsWith('data: ')) {
                        eventData = line.slice(6);
                    }
                }

                if (eventType && eventData) {
                    const parsed = JSON.parse(eventData);
                    return { ...parsed, type: eventType };
                }
            } catch (error) {
                log.error(`[AnthropicStreamHandler] Error parsing event: ${error}`);
            }
        }

        return chunk;
    }

    /**
     * Mark the stream finished, cancel the inactivity timer, and emit
     * the terminal 'done' chunk with stop reason and usage.
     */
    private async sendComplete(): Promise<void> {
        this.isComplete = true;
        this.clearTimeoutTimer();

        await this.sendChunk({
            type: 'done',
            metadata: {
                provider: 'anthropic',
                finishReason: this.stopReason,
                usage: this.response.usage
            }
        });
    }

    /**
     * Assemble the final ChatResponse from accumulated state, invoking
     * onComplete if configured. Finishes the stream first if no
     * message_stop event was seen.
     */
    public async complete(): Promise<ChatResponse> {
        if (!this.isComplete) {
            await this.sendComplete();
        }

        const response: ChatResponse = {
            text: this.buffer,
            // NOTE(review): placeholder — event.model seen during streaming
            // is not retained here; confirm whether callers rely on this.
            model: 'anthropic-model',
            provider: 'anthropic',
            usage: this.response.usage
        };

        if (this.config.onComplete) {
            this.config.onComplete(response);
        }

        return response;
    }
}
|
||||
|
||||
/**
 * Ollama stream handler
 *
 * Parses Ollama's newline-delimited JSON stream: each chunk is a JSON
 * object that may carry message content, tool calls, the model name,
 * and — on the final `done` chunk — token evaluation counts.
 */
export class OllamaStreamHandler extends BaseStreamHandler {
    // Model name from the most recent chunk that carried one.
    private model?: string;
    // Tool calls as delivered by Ollama (replaced, not merged, per chunk).
    private toolCalls: any[] = [];

    /**
     * Process one raw chunk (JSON string or pre-parsed object). Errors
     * are routed to handleError rather than thrown.
     */
    public async processChunk(chunk: any): Promise<void> {
        this.resetTimeoutTimer();

        try {
            const data = typeof chunk === 'string' ? JSON.parse(chunk) : chunk;

            // Handle content — accumulate and forward downstream.
            if (data.message?.content) {
                const content = data.message.content;
                this.buffer += content;

                await this.sendChunk({
                    type: 'content',
                    content: content,
                    metadata: {
                        provider: 'ollama',
                        model: data.model || this.model
                    }
                });
            }

            // Handle tool calls — note this REPLACES previously stored
            // calls rather than appending to them.
            if (data.message?.tool_calls) {
                this.toolCalls = data.message.tool_calls;

                for (const toolCall of this.toolCalls) {
                    await this.sendChunk({
                        type: 'tool_call',
                        toolCall: {
                            id: toolCall.id || `tool_${Date.now()}`,
                            name: toolCall.function?.name || '',
                            arguments: JSON.stringify(toolCall.function?.arguments || {})
                        },
                        metadata: {
                            provider: 'ollama'
                        }
                    });
                }
            }

            // Store model info for later chunks and the final response
            if (data.model) {
                this.model = data.model;
            }

            // Check if done
            if (data.done) {
                // Calculate token usage if available.
                // NOTE(review): when both counts are 0 this truthiness check
                // skips usage entirely — confirm that is acceptable.
                if (data.prompt_eval_count || data.eval_count) {
                    this.response.usage = {
                        promptTokens: data.prompt_eval_count,
                        completionTokens: data.eval_count,
                        totalTokens: (data.prompt_eval_count || 0) + (data.eval_count || 0)
                    };
                }

                await this.sendComplete();
            }
        } catch (error) {
            log.error(`[OllamaStreamHandler] Error processing chunk: ${error}`);
            this.handleError(error as Error);
        }
    }

    /**
     * Mark the stream finished, cancel the inactivity timer, and emit
     * the terminal 'done' chunk with model and usage.
     */
    private async sendComplete(): Promise<void> {
        this.isComplete = true;
        this.clearTimeoutTimer();

        await this.sendChunk({
            type: 'done',
            metadata: {
                provider: 'ollama',
                model: this.model,
                usage: this.response.usage
            }
        });
    }

    /**
     * Assemble the final ChatResponse from accumulated state, invoking
     * onComplete if configured. Finishes the stream first if no `done`
     * chunk was seen.
     */
    public async complete(): Promise<ChatResponse> {
        if (!this.isComplete) {
            await this.sendComplete();
        }

        const response: ChatResponse = {
            text: this.buffer,
            model: this.model || 'ollama-model',
            provider: 'ollama',
            usage: this.response.usage
        };

        if (this.toolCalls.length > 0) {
            response.tool_calls = this.toolCalls;
        }

        if (this.config.onComplete) {
            this.config.onComplete(response);
        }

        return response;
    }
}
|
||||
|
||||
/**
|
||||
* Factory function to create appropriate stream handler
|
||||
*/
|
||||
export function createStreamHandler(config: StreamHandlerConfig): BaseStreamHandler {
|
||||
switch (config.provider) {
|
||||
case 'openai':
|
||||
return new OpenAIStreamHandler(config);
|
||||
|
||||
case 'anthropic':
|
||||
return new AnthropicStreamHandler(config);
|
||||
|
||||
case 'ollama':
|
||||
return new OllamaStreamHandler(config);
|
||||
|
||||
default:
|
||||
throw new Error(`Unsupported provider: ${config.provider}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Utility to convert async iterable to unified stream
|
||||
*/
|
||||
export async function* unifiedStream(
|
||||
asyncIterable: AsyncIterable<any>,
|
||||
provider: 'openai' | 'anthropic' | 'ollama'
|
||||
): AsyncGenerator<UnifiedStreamChunk> {
|
||||
const chunks: UnifiedStreamChunk[] = [];
|
||||
let handler: BaseStreamHandler | null = null;
|
||||
|
||||
try {
|
||||
handler = createStreamHandler({
|
||||
provider,
|
||||
onChunk: (chunk) => { chunks.push(chunk); }
|
||||
});
|
||||
|
||||
for await (const chunk of asyncIterable) {
|
||||
await handler.processChunk(chunk);
|
||||
|
||||
// Yield accumulated chunks
|
||||
while (chunks.length > 0) {
|
||||
const chunk = chunks.shift()!;
|
||||
yield chunk;
|
||||
}
|
||||
}
|
||||
|
||||
// Complete the stream
|
||||
await handler.complete();
|
||||
|
||||
// Yield any remaining chunks
|
||||
while (chunks.length > 0) {
|
||||
const chunk = chunks.shift()!;
|
||||
yield chunk;
|
||||
}
|
||||
} catch (error) {
|
||||
log.error(`[unifiedStream] Error: ${error}`);
|
||||
yield {
|
||||
type: 'error',
|
||||
error: (error as Error).message,
|
||||
metadata: { provider }
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream aggregator for collecting stream chunks into a complete response
|
||||
*/
|
||||
export class StreamAggregator {
|
||||
private chunks: UnifiedStreamChunk[] = [];
|
||||
private content: string = '';
|
||||
private toolCalls: any[] = [];
|
||||
private metadata: any = {};
|
||||
|
||||
public addChunk(chunk: UnifiedStreamChunk): void {
|
||||
this.chunks.push(chunk);
|
||||
|
||||
switch (chunk.type) {
|
||||
case 'content':
|
||||
if (chunk.content) {
|
||||
this.content += chunk.content;
|
||||
}
|
||||
break;
|
||||
|
||||
case 'tool_call':
|
||||
if (chunk.toolCall) {
|
||||
this.toolCalls.push(chunk.toolCall);
|
||||
}
|
||||
break;
|
||||
|
||||
case 'done':
|
||||
if (chunk.metadata) {
|
||||
this.metadata = { ...this.metadata, ...chunk.metadata };
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
public getResponse(): ChatResponse {
|
||||
const response: ChatResponse = {
|
||||
text: this.content,
|
||||
model: this.metadata.model || 'unknown-model',
|
||||
provider: this.metadata.provider || 'unknown',
|
||||
usage: this.metadata.usage
|
||||
};
|
||||
|
||||
if (this.toolCalls.length > 0) {
|
||||
response.tool_calls = this.toolCalls;
|
||||
}
|
||||
|
||||
return response;
|
||||
}
|
||||
|
||||
public getChunks(): UnifiedStreamChunk[] {
|
||||
return [...this.chunks];
|
||||
}
|
||||
|
||||
public reset(): void {
|
||||
this.chunks = [];
|
||||
this.content = '';
|
||||
this.toolCalls = [];
|
||||
this.metadata = {};
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user