feat(llm): implement circuit breaker to prevent cascading provider failures

perfectra1n
2025-08-09 13:24:53 -07:00
parent f89c202fcc
commit 16622f43e3
14 changed files with 6880 additions and 67 deletions

View File

@@ -0,0 +1,185 @@
/**
* LLM Metrics API Endpoint
*
* Provides metrics export endpoints for monitoring systems
*/
import { Router, Request, Response } from 'express';
import { getProviderFactory } from '../../services/llm/providers/provider_factory.js';
import log from '../../services/log.js';
const router = Router();
/**
* GET /api/llm/metrics
* Returns metrics in Prometheus format by default
*/
router.get('/llm/metrics', (req: Request, res: Response) => {
try {
const format = req.query.format as string || 'prometheus';
const factory = getProviderFactory();
if (!factory) {
return res.status(503).json({ error: 'LLM service not initialized' });
}
const metrics = factory.exportMetrics(format as any);
if (!metrics) {
return res.status(503).json({ error: 'Metrics not available' });
}
// Set appropriate content type based on format
switch (format) {
case 'prometheus':
res.set('Content-Type', 'text/plain; version=0.0.4');
res.send(metrics);
break;
case 'json':
res.json(metrics);
break;
case 'opentelemetry':
res.json(metrics);
break;
case 'statsd':
res.set('Content-Type', 'text/plain');
res.send(Array.isArray(metrics) ? metrics.join('\n') : metrics);
break;
default:
res.status(400).json({ error: `Unknown format: ${format}` });
}
} catch (error: any) {
log.error(`[LLM Metrics API] Error exporting metrics: ${error.message}`);
res.status(500).json({ error: 'Internal server error' });
}
});
/**
* GET /api/llm/metrics/summary
* Returns a summary of metrics in JSON format
*/
router.get('/llm/metrics/summary', (req: Request, res: Response) => {
try {
const factory = getProviderFactory();
if (!factory) {
return res.status(503).json({ error: 'LLM service not initialized' });
}
const summary = factory.getMetricsSummary();
if (!summary) {
return res.status(503).json({ error: 'Metrics not available' });
}
res.json(summary);
} catch (error: any) {
log.error(`[LLM Metrics API] Error getting metrics summary: ${error.message}`);
res.status(500).json({ error: 'Internal server error' });
}
});
/**
* GET /api/llm/circuit-breaker/status
* Returns circuit breaker status for all providers
*/
router.get('/llm/circuit-breaker/status', (req: Request, res: Response) => {
try {
const factory = getProviderFactory();
if (!factory) {
return res.status(503).json({ error: 'LLM service not initialized' });
}
const status = factory.getCircuitBreakerStatus();
if (!status) {
return res.status(503).json({ error: 'Circuit breaker not enabled' });
}
res.json(status);
} catch (error: any) {
log.error(`[LLM Metrics API] Error getting circuit breaker status: ${error.message}`);
res.status(500).json({ error: 'Internal server error' });
}
});
/**
* POST /api/llm/circuit-breaker/reset/:provider
* Reset circuit breaker for a specific provider
*/
router.post('/llm/circuit-breaker/reset/:provider', (req: Request, res: Response) => {
try {
const { provider } = req.params;
const factory = getProviderFactory();
if (!factory) {
return res.status(503).json({ error: 'LLM service not initialized' });
}
factory.resetCircuitBreaker(provider as any);
res.json({ message: `Circuit breaker reset for provider: ${provider}` });
} catch (error: any) {
log.error(`[LLM Metrics API] Error resetting circuit breaker: ${error.message}`);
res.status(500).json({ error: 'Internal server error' });
}
});
/**
* GET /api/llm/health
* Returns overall health status of LLM service
*/
router.get('/llm/health', (req: Request, res: Response) => {
try {
const factory = getProviderFactory();
if (!factory) {
return res.status(503).json({
status: 'unhealthy',
error: 'LLM service not initialized'
});
}
const circuitStatus = factory.getCircuitBreakerStatus();
const metrics = factory.getMetricsSummary();
const statistics = factory.getStatistics();
const health = {
status: 'healthy',
timestamp: new Date().toISOString(),
providers: {
available: circuitStatus?.summary?.availableProviders || [],
unavailable: circuitStatus?.summary?.unavailableProviders || [],
cached: statistics?.cachedProviders || 0,
healthy: statistics?.healthyProviders || 0,
unhealthy: statistics?.unhealthyProviders || 0
},
metrics: {
totalRequests: metrics?.system?.totalRequests || 0,
totalFailures: metrics?.system?.totalFailures || 0,
uptime: metrics?.system?.uptime || 0
},
circuitBreakers: circuitStatus?.summary || {}
};
// Determine overall health
if (health.providers.available.length === 0) {
health.status = 'unhealthy';
} else if (health.providers.unavailable.length > 0) {
health.status = 'degraded';
}
const statusCode = health.status === 'healthy' ? 200 :
health.status === 'degraded' ? 200 : 503;
res.status(statusCode).json(health);
} catch (error: any) {
log.error(`[LLM Metrics API] Error getting health status: ${error.message}`);
res.status(500).json({
status: 'unhealthy',
error: 'Internal server error'
});
}
});
export default router;
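A quick usage sketch for these endpoints (the /api prefix comes from the doc comments above; host and port are assumptions):
// Hypothetical client polling the metrics and health endpoints.
async function pollLlmStatus(baseUrl: string): Promise<void> {
const health = await fetch(`${baseUrl}/api/llm/health`);
console.log('health:', health.status, await health.json());
const metrics = await fetch(`${baseUrl}/api/llm/metrics?format=json`);
if (metrics.ok) {
console.log('metrics:', await metrics.json());
}
}
pollLlmStatus('http://localhost:8080').catch(err => console.error(err));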

View File

@@ -8,6 +8,7 @@ import contextService from './context/services/context_service.js';
import log from '../log.js';
import { OllamaService } from './providers/ollama_service.js';
import { OpenAIService } from './providers/openai_service.js';
import { ProviderFactory, ProviderType, getProviderFactory } from './providers/provider_factory.js';
// Import interfaces
import type {
@@ -26,7 +27,6 @@ import {
clearConfigurationCache,
validateConfiguration
} from './config/configuration_helpers.js';
import type { ProviderType } from './interfaces/configuration_interfaces.js';
/**
* Interface representing relevant note context
@@ -59,8 +59,19 @@ export class AIServiceManager implements IAIServiceManager, Disposable {
private cleanupTimer: NodeJS.Timeout | null = null;
private initialized = false;
private disposed = false;
private providerFactory: ProviderFactory | null = null;
constructor() {
// Initialize provider factory
this.providerFactory = getProviderFactory({
enableHealthChecks: true,
healthCheckInterval: 60000,
enableFallback: true,
enableCaching: true,
cacheTimeout: this.SERVICE_TTL_MS,
enableMetrics: true
});
// Initialize tools immediately
this.initializeTools().catch(error => {
log.error(`Error initializing LLM tools during AIServiceManager construction: ${error.message || String(error)}`);
@@ -456,7 +467,12 @@ export class AIServiceManager implements IAIServiceManager, Disposable {
* Clear all cached providers (forces recreation on next access)
*/
public clearCurrentProvider(): void {
// Clear all cached services
// Clear provider factory cache
if (this.providerFactory) {
this.providerFactory.clearCache();
}
// Clear local cache
for (const provider of this.serviceCache.keys()) {
this.disposeService(provider);
}
@@ -464,87 +480,66 @@ export class AIServiceManager implements IAIServiceManager, Disposable {
}
/**
* Get or create a provider instance with proper caching and TTL
* Get or create a provider instance using the provider factory
*/
private async getOrCreateChatProvider(providerName: ServiceProviders): Promise<AIService | null> {
if (this.disposed) {
throw new Error('AIServiceManager has been disposed');
}
// Check cache first
const cached = this.serviceCache.get(providerName);
if (cached && cached.service.isAvailable()) {
// Update last used time
cached.lastUsed = Date.now();
// Check if service is still within TTL
if (Date.now() - cached.createdAt <= this.SERVICE_TTL_MS) {
log.info(`Using cached ${providerName} service (age: ${Math.round((Date.now() - cached.createdAt) / 1000)}s)`);
return cached.service;
} else {
// Service is stale, dispose and recreate
log.info(`Cached ${providerName} service is stale, recreating`);
this.disposeService(providerName);
}
if (!this.providerFactory) {
throw new Error('Provider factory not initialized');
}
// Create new service for the requested provider
try {
let service: AIService | null = null;
// Map ServiceProviders to ProviderType
const providerTypeMap: Record<ServiceProviders, ProviderType> = {
'openai': ProviderType.OPENAI,
'anthropic': ProviderType.ANTHROPIC,
'ollama': ProviderType.OLLAMA
};
const providerType = providerTypeMap[providerName];
if (!providerType) {
log.error(`Unknown provider name: ${providerName}`);
return null;
}
// Check if provider is configured
switch (providerName) {
case 'openai': {
const apiKey = options.getOption('openaiApiKey');
const baseUrl = options.getOption('openaiBaseUrl');
if (!apiKey && !baseUrl) return null;
service = new OpenAIService();
if (!service.isAvailable()) {
throw new Error('OpenAI service not available');
}
break;
}
case 'anthropic': {
const apiKey = options.getOption('anthropicApiKey');
if (!apiKey) return null;
service = new AnthropicService();
if (!service.isAvailable()) {
throw new Error('Anthropic service not available');
}
break;
}
case 'ollama': {
const baseUrl = options.getOption('ollamaBaseUrl');
if (!baseUrl) return null;
service = new OllamaService();
if (!service.isAvailable()) {
throw new Error('Ollama service not available');
}
break;
}
}
if (service) {
// Cache the new service with metadata
const now = Date.now();
this.serviceCache.set(providerName, {
service,
provider: providerName,
createdAt: now,
lastUsed: now
});
log.info(`Created and cached new ${providerName} service`);
// Use provider factory to create the service
const service = await this.providerFactory.createProvider(providerType);
if (service && service.isAvailable()) {
log.info(`Created ${providerName} service via provider factory`);
return service;
}
throw new Error(`${providerName} service not available`);
} catch (error: any) {
log.error(`Failed to create ${providerName} chat provider: ${error.message || 'Unknown error'}`);
// Provider factory handles fallback internally if configured
return null;
}
return null;
}
/**
@@ -559,6 +554,12 @@ export class AIServiceManager implements IAIServiceManager, Disposable {
// Stop cleanup timer
this.stopCleanupTimer();
// Dispose provider factory
if (this.providerFactory) {
this.providerFactory.dispose();
this.providerFactory = null;
}
// Dispose all cached services
for (const provider of this.serviceCache.keys()) {
this.disposeService(provider);
@@ -766,13 +767,24 @@ export class AIServiceManager implements IAIServiceManager, Disposable {
* Check if a specific provider is available
*/
isProviderAvailable(provider: string): boolean {
// Check if we have a cached service for this provider
const cachedEntry = this.serviceCache.get(provider as ServiceProviders);
if (cachedEntry && !this.isServiceStale(cachedEntry)) {
return cachedEntry.service.isAvailable();
// Check health status from provider factory
if (this.providerFactory) {
const providerTypeMap: Record<string, ProviderType> = {
'openai': ProviderType.OPENAI,
'anthropic': ProviderType.ANTHROPIC,
'ollama': ProviderType.OLLAMA
};
const providerType = providerTypeMap[provider];
if (providerType) {
const healthStatus = this.providerFactory.getHealthStatus(providerType);
if (healthStatus) {
return healthStatus.healthy;
}
}
}
// For other providers, check configuration
// Fallback to configuration check
try {
switch (provider) {
case 'openai':
@@ -793,22 +805,43 @@ export class AIServiceManager implements IAIServiceManager, Disposable {
* Get metadata about a provider
*/
getProviderMetadata(provider: string): ProviderMetadata | null {
// Check if we have a cached service for this provider
const cachedEntry = this.serviceCache.get(provider as ServiceProviders);
if (cachedEntry && !this.isServiceStale(cachedEntry)) {
return {
name: provider,
capabilities: {
chat: true,
streaming: true,
functionCalling: provider === 'openai' // Only OpenAI has function calling
},
models: ['default'], // Placeholder, could be populated from the service
defaultModel: 'default'
// Get capabilities from provider factory
if (this.providerFactory) {
const providerTypeMap: Record<string, ProviderType> = {
'openai': ProviderType.OPENAI,
'anthropic': ProviderType.ANTHROPIC,
'ollama': ProviderType.OLLAMA
};
const providerType = providerTypeMap[provider];
if (providerType) {
const capabilities = this.providerFactory.getCapabilities(providerType);
if (capabilities) {
return {
name: provider,
capabilities: {
chat: true,
streaming: capabilities.streaming,
functionCalling: capabilities.functionCalling
},
models: ['default'], // Could be enhanced to get actual models
defaultModel: 'default'
};
}
}
}
return null;
// Fallback
return {
name: provider,
capabilities: {
chat: true,
streaming: true,
functionCalling: provider === 'openai'
},
models: ['default'],
defaultModel: 'default'
};
}

View File

@@ -0,0 +1,290 @@
/**
* LLM Service Configuration Options
*
* Defines all configurable options for the LLM service that can be
* managed through Trilium's options system.
*/
import optionService from '../../options.js';
import type { OptionNames, FilterOptionsByType } from '@triliumnext/commons';
import { ExportFormat } from '../metrics/metrics_exporter.js';
/**
* LLM configuration options
*/
export interface LLMOptions {
// Circuit Breaker Configuration
circuitBreakerEnabled: boolean;
circuitBreakerFailureThreshold: number;
circuitBreakerFailureWindow: number;
circuitBreakerCooldownPeriod: number;
circuitBreakerSuccessThreshold: number;
// Metrics Configuration
metricsEnabled: boolean;
metricsExportFormat: ExportFormat;
metricsExportEndpoint?: string;
metricsExportInterval: number;
metricsPrometheusEnabled: boolean;
metricsStatsdHost?: string;
metricsStatsdPort?: number;
metricsStatsdPrefix: string;
// Provider Configuration
providerHealthCheckEnabled: boolean;
providerHealthCheckInterval: number;
providerCachingEnabled: boolean;
providerCacheTimeout: number;
providerFallbackEnabled: boolean;
providerFallbackList: string[];
}
/**
* Default LLM options
*/
const DEFAULT_OPTIONS: LLMOptions = {
// Circuit Breaker Defaults
circuitBreakerEnabled: true,
circuitBreakerFailureThreshold: 5,
circuitBreakerFailureWindow: 60000, // 1 minute
circuitBreakerCooldownPeriod: 30000, // 30 seconds
circuitBreakerSuccessThreshold: 2,
// Metrics Defaults
metricsEnabled: true,
metricsExportFormat: ExportFormat.PROMETHEUS,
metricsExportInterval: 60000, // 1 minute
metricsPrometheusEnabled: true,
metricsStatsdPrefix: 'trilium.llm',
// Provider Defaults
providerHealthCheckEnabled: true,
providerHealthCheckInterval: 60000, // 1 minute
providerCachingEnabled: true,
providerCacheTimeout: 300000, // 5 minutes
providerFallbackEnabled: true,
providerFallbackList: ['ollama']
};
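// Reading of these defaults (an interpretation, consistent with the circuit
// breaker tests later in this commit): a provider's breaker opens after 5
// failures inside a 60 s window, rejects calls for a 30 s cooldown, then lets
// trial requests through (half-open) and closes again after 2 consecutive
// successes.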
/**
* Option keys in Trilium's option system
*/
export const LLM_OPTION_KEYS = {
// Circuit Breaker
CIRCUIT_BREAKER_ENABLED: 'llmCircuitBreakerEnabled' as const,
CIRCUIT_BREAKER_FAILURE_THRESHOLD: 'llmCircuitBreakerFailureThreshold' as const,
CIRCUIT_BREAKER_FAILURE_WINDOW: 'llmCircuitBreakerFailureWindow' as const,
CIRCUIT_BREAKER_COOLDOWN_PERIOD: 'llmCircuitBreakerCooldownPeriod' as const,
CIRCUIT_BREAKER_SUCCESS_THRESHOLD: 'llmCircuitBreakerSuccessThreshold' as const,
// Metrics
METRICS_ENABLED: 'llmMetricsEnabled' as const,
METRICS_EXPORT_FORMAT: 'llmMetricsExportFormat' as const,
METRICS_EXPORT_ENDPOINT: 'llmMetricsExportEndpoint' as const,
METRICS_EXPORT_INTERVAL: 'llmMetricsExportInterval' as const,
METRICS_PROMETHEUS_ENABLED: 'llmMetricsPrometheusEnabled' as const,
METRICS_STATSD_HOST: 'llmMetricsStatsdHost' as const,
METRICS_STATSD_PORT: 'llmMetricsStatsdPort' as const,
METRICS_STATSD_PREFIX: 'llmMetricsStatsdPrefix' as const,
// Provider
PROVIDER_HEALTH_CHECK_ENABLED: 'llmProviderHealthCheckEnabled' as const,
PROVIDER_HEALTH_CHECK_INTERVAL: 'llmProviderHealthCheckInterval' as const,
PROVIDER_CACHING_ENABLED: 'llmProviderCachingEnabled' as const,
PROVIDER_CACHE_TIMEOUT: 'llmProviderCacheTimeout' as const,
PROVIDER_FALLBACK_ENABLED: 'llmProviderFallbackEnabled' as const,
PROVIDER_FALLBACK_LIST: 'llmProviderFallbackList' as const
} as const;
/**
* Get LLM options from Trilium's option service
*/
export function getLLMOptions(): LLMOptions {
// Helper function to safely get option with fallback
function getOptionSafe<T>(getter: () => T, defaultValue: T): T {
try {
return getter() ?? defaultValue;
} catch {
return defaultValue;
}
}
return {
// Circuit Breaker
circuitBreakerEnabled: getOptionSafe(
() => optionService.getOptionBool(LLM_OPTION_KEYS.CIRCUIT_BREAKER_ENABLED),
DEFAULT_OPTIONS.circuitBreakerEnabled
),
circuitBreakerFailureThreshold: getOptionSafe(
() => optionService.getOptionInt(LLM_OPTION_KEYS.CIRCUIT_BREAKER_FAILURE_THRESHOLD),
DEFAULT_OPTIONS.circuitBreakerFailureThreshold
),
circuitBreakerFailureWindow: getOptionSafe(
() => optionService.getOptionInt(LLM_OPTION_KEYS.CIRCUIT_BREAKER_FAILURE_WINDOW),
DEFAULT_OPTIONS.circuitBreakerFailureWindow
),
circuitBreakerCooldownPeriod: getOptionSafe(
() => optionService.getOptionInt(LLM_OPTION_KEYS.CIRCUIT_BREAKER_COOLDOWN_PERIOD),
DEFAULT_OPTIONS.circuitBreakerCooldownPeriod
),
circuitBreakerSuccessThreshold: getOptionSafe(
() => optionService.getOptionInt(LLM_OPTION_KEYS.CIRCUIT_BREAKER_SUCCESS_THRESHOLD),
DEFAULT_OPTIONS.circuitBreakerSuccessThreshold
),
// Metrics
metricsEnabled: getOptionSafe(
() => optionService.getOptionBool(LLM_OPTION_KEYS.METRICS_ENABLED),
DEFAULT_OPTIONS.metricsEnabled
),
metricsExportFormat: getOptionSafe(
() => optionService.getOption(LLM_OPTION_KEYS.METRICS_EXPORT_FORMAT) as ExportFormat,
DEFAULT_OPTIONS.metricsExportFormat
),
metricsExportEndpoint: getOptionSafe(
() => optionService.getOption(LLM_OPTION_KEYS.METRICS_EXPORT_ENDPOINT),
undefined
),
metricsExportInterval: getOptionSafe(
() => optionService.getOptionInt(LLM_OPTION_KEYS.METRICS_EXPORT_INTERVAL),
DEFAULT_OPTIONS.metricsExportInterval
),
metricsPrometheusEnabled: getOptionSafe(
() => optionService.getOptionBool(LLM_OPTION_KEYS.METRICS_PROMETHEUS_ENABLED),
DEFAULT_OPTIONS.metricsPrometheusEnabled
),
metricsStatsdHost: getOptionSafe(
() => optionService.getOption(LLM_OPTION_KEYS.METRICS_STATSD_HOST),
undefined
),
metricsStatsdPort: getOptionSafe(
() => optionService.getOptionInt(LLM_OPTION_KEYS.METRICS_STATSD_PORT),
undefined
),
metricsStatsdPrefix: getOptionSafe(
() => optionService.getOption(LLM_OPTION_KEYS.METRICS_STATSD_PREFIX),
DEFAULT_OPTIONS.metricsStatsdPrefix
),
// Provider
providerHealthCheckEnabled: getOptionSafe(
() => optionService.getOptionBool(LLM_OPTION_KEYS.PROVIDER_HEALTH_CHECK_ENABLED),
DEFAULT_OPTIONS.providerHealthCheckEnabled
),
providerHealthCheckInterval: getOptionSafe(
() => optionService.getOptionInt(LLM_OPTION_KEYS.PROVIDER_HEALTH_CHECK_INTERVAL),
DEFAULT_OPTIONS.providerHealthCheckInterval
),
providerCachingEnabled: getOptionSafe(
() => optionService.getOptionBool(LLM_OPTION_KEYS.PROVIDER_CACHING_ENABLED),
DEFAULT_OPTIONS.providerCachingEnabled
),
providerCacheTimeout: getOptionSafe(
() => optionService.getOptionInt(LLM_OPTION_KEYS.PROVIDER_CACHE_TIMEOUT),
DEFAULT_OPTIONS.providerCacheTimeout
),
providerFallbackEnabled: getOptionSafe(
() => optionService.getOptionBool(LLM_OPTION_KEYS.PROVIDER_FALLBACK_ENABLED),
DEFAULT_OPTIONS.providerFallbackEnabled
),
providerFallbackList: getOptionSafe(
() => optionService.getOption(LLM_OPTION_KEYS.PROVIDER_FALLBACK_LIST).split(',').map((s: string) => s.trim()).filter(Boolean),
DEFAULT_OPTIONS.providerFallbackList
)
};
}
/**
* Set an LLM option
*/
export async function setLLMOption(key: OptionNames, value: any): Promise<void> {
await optionService.setOption(key, value);
}
/**
* Initialize LLM options with defaults if not set
*/
export async function initializeLLMOptions(): Promise<void> {
// Set defaults for any unset options
const keysToCheck = Object.values(LLM_OPTION_KEYS) as OptionNames[];
for (const key of keysToCheck) {
try {
const currentValue = optionService.getOption(key);
if (currentValue === null || currentValue === undefined) {
// Set default based on key
const defaultKey = Object.entries(LLM_OPTION_KEYS)
.find(([_, v]) => v === key)?.[0];
if (defaultKey) {
// Map e.g. 'CIRCUIT_BREAKER_ENABLED' to the camelCase key 'circuitBreakerEnabled'
// (the key must be lowercased first, otherwise the /_([a-z])/ pattern never matches)
const defaultPath = defaultKey
.toLowerCase()
.replace(/_([a-z])/g, (_, char) => char.toUpperCase());
const defaultValue = (DEFAULT_OPTIONS as any)[defaultPath];
if (defaultValue !== undefined) {
await setLLMOption(key,
Array.isArray(defaultValue) ? defaultValue.join(',') : defaultValue
);
}
}
}
} catch {
// Option doesn't exist yet, create it with default
const defaultKey = Object.entries(LLM_OPTION_KEYS)
.find(([_, v]) => v === key)?.[0];
if (defaultKey) {
// Same SNAKE_CASE to camelCase mapping as above
const defaultPath = defaultKey
.toLowerCase()
.replace(/_([a-z])/g, (_, char) => char.toUpperCase());
const defaultValue = (DEFAULT_OPTIONS as any)[defaultPath];
if (defaultValue !== undefined) {
await setLLMOption(key,
Array.isArray(defaultValue) ? defaultValue.join(',') : defaultValue
);
}
}
}
}
}
/**
* Create provider factory options from LLM options
*/
export function createProviderFactoryOptions() {
const options = getLLMOptions();
return {
enableHealthChecks: options.providerHealthCheckEnabled,
healthCheckInterval: options.providerHealthCheckInterval,
enableFallback: options.providerFallbackEnabled,
fallbackProviders: options.providerFallbackList as any[],
enableCaching: options.providerCachingEnabled,
cacheTimeout: options.providerCacheTimeout,
enableMetrics: options.metricsEnabled,
enableCircuitBreaker: options.circuitBreakerEnabled,
circuitBreakerConfig: {
failureThreshold: options.circuitBreakerFailureThreshold,
failureWindow: options.circuitBreakerFailureWindow,
cooldownPeriod: options.circuitBreakerCooldownPeriod,
successThreshold: options.circuitBreakerSuccessThreshold,
enableLogging: true
},
metricsExporterConfig: {
enabled: options.metricsEnabled,
format: options.metricsExportFormat,
endpoint: options.metricsExportEndpoint,
interval: options.metricsExportInterval,
statsdHost: options.metricsStatsdHost,
statsdPort: options.metricsStatsdPort,
prefix: options.metricsStatsdPrefix
}
};
}
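A minimal startup sketch tying this file to the factory; './llm_options.js' is a hypothetical name for this module, while getProviderFactory taking an options object matches its use in AIServiceManager above:
import { getProviderFactory } from '../providers/provider_factory.js';
// './llm_options.js' is a placeholder path for this module.
import { initializeLLMOptions, createProviderFactoryOptions } from './llm_options.js';
async function bootstrapLLM() {
await initializeLLMOptions(); // persist defaults for any unset options
// Build the factory from whatever is currently configured.
return getProviderFactory(createProviderFactoryOptions());
}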

View File

@@ -0,0 +1,796 @@
/**
* Metrics Export System for LLM Service
*
* Provides unified metrics collection and export to various monitoring systems:
* - Prometheus format endpoint
* - StatsD/DataDog format
* - OpenTelemetry format
*/
import log from '../../log.js';
import { EventEmitter } from 'events';
import type { ProviderType } from '../providers/provider_factory.js';
/**
* Metric types
*/
export enum MetricType {
COUNTER = 'counter',
GAUGE = 'gauge',
HISTOGRAM = 'histogram',
SUMMARY = 'summary'
}
/**
* Metric data point
*/
export interface MetricDataPoint {
name: string;
type: MetricType;
value: number;
timestamp: Date;
labels: Record<string, string>;
unit?: string;
description?: string;
}
/**
* Provider metrics
*/
export interface ProviderMetrics {
provider: string;
requests: number;
failures: number;
successRate: number;
averageLatency: number;
p50Latency: number;
p95Latency: number;
p99Latency: number;
totalTokens: number;
inputTokens: number;
outputTokens: number;
averageTokensPerRequest: number;
errorRate: number;
lastError?: string;
lastUpdated: Date;
}
/**
* System metrics
*/
export interface SystemMetrics {
totalRequests: number;
totalFailures: number;
averageLatency: number;
activePipelines: number;
cacheHitRate: number;
memoryUsage: number;
uptime: number;
timestamp: Date;
}
/**
* Export format types
*/
export enum ExportFormat {
PROMETHEUS = 'prometheus',
STATSD = 'statsd',
OPENTELEMETRY = 'opentelemetry',
JSON = 'json'
}
/**
* Exporter configuration
*/
export interface ExporterConfig {
enabled: boolean;
format: ExportFormat;
endpoint?: string;
interval?: number;
prefix?: string;
labels?: Record<string, string>;
includeHistograms?: boolean;
histogramBuckets?: number[];
statsdHost?: string;
statsdPort?: number;
statsdPrefix?: string;
}
/**
* Base metrics collector
*/
export class MetricsCollector extends EventEmitter {
private metrics: Map<string, MetricDataPoint[]> = new Map();
private providerMetrics: Map<string, ProviderMetrics> = new Map();
private systemMetrics: SystemMetrics;
private startTime: Date;
private latencyHistogram: Map<string, number[]> = new Map();
private readonly maxDataPoints = 10000;
private readonly maxHistogramSize = 1000;
constructor() {
super();
this.startTime = new Date();
this.systemMetrics = this.createDefaultSystemMetrics();
}
/**
* Record a metric
*/
public record(metric: MetricDataPoint): void {
const key = this.getMetricKey(metric);
if (!this.metrics.has(key)) {
this.metrics.set(key, []);
}
const dataPoints = this.metrics.get(key)!;
dataPoints.push(metric);
// Limit stored data points
if (dataPoints.length > this.maxDataPoints) {
dataPoints.shift();
}
// Update provider metrics if applicable
if (metric.labels.provider) {
this.updateProviderMetrics(metric);
}
// Update system metrics
this.updateSystemMetrics(metric);
// Emit metric event
this.emit('metric', metric);
}
/**
* Record latency
*/
public recordLatency(provider: string, latency: number): void {
this.record({
name: 'llm_request_latency',
type: MetricType.HISTOGRAM,
value: latency,
timestamp: new Date(),
labels: { provider },
unit: 'ms',
description: 'LLM request latency'
});
// Update latency histogram
if (!this.latencyHistogram.has(provider)) {
this.latencyHistogram.set(provider, []);
}
const histogram = this.latencyHistogram.get(provider)!;
histogram.push(latency);
if (histogram.length > this.maxHistogramSize) {
histogram.shift();
}
}
/**
* Record token usage
*/
public recordTokenUsage(
provider: string,
inputTokens: number,
outputTokens: number
): void {
this.record({
name: 'llm_tokens_used',
type: MetricType.COUNTER,
value: inputTokens + outputTokens,
timestamp: new Date(),
labels: { provider, type: 'total' },
unit: 'tokens',
description: 'Total tokens used'
});
this.record({
name: 'llm_input_tokens',
type: MetricType.COUNTER,
value: inputTokens,
timestamp: new Date(),
labels: { provider },
unit: 'tokens',
description: 'Input tokens used'
});
this.record({
name: 'llm_output_tokens',
type: MetricType.COUNTER,
value: outputTokens,
timestamp: new Date(),
labels: { provider },
unit: 'tokens',
description: 'Output tokens generated'
});
}
/**
* Record error
*/
public recordError(provider: string, error: string): void {
this.record({
name: 'llm_errors',
type: MetricType.COUNTER,
value: 1,
timestamp: new Date(),
labels: { provider, error_type: this.classifyError(error) },
description: 'LLM request errors'
});
// Update provider metrics
const metrics = this.getProviderMetrics(provider);
metrics.failures++;
metrics.lastError = error;
// Guard against division by zero when an error is recorded before any request
if (metrics.requests > 0) {
metrics.errorRate = metrics.failures / metrics.requests;
}
}
/**
* Record request
*/
public recordRequest(provider: string, success: boolean): void {
this.record({
name: 'llm_requests',
type: MetricType.COUNTER,
value: 1,
timestamp: new Date(),
labels: { provider, status: success ? 'success' : 'failure' },
description: 'LLM requests'
});
const metrics = this.getProviderMetrics(provider);
metrics.requests++;
if (!success) {
metrics.failures++;
}
metrics.successRate = (metrics.requests - metrics.failures) / metrics.requests;
}
/**
* Get or create provider metrics
*/
private getProviderMetrics(provider: string): ProviderMetrics {
if (!this.providerMetrics.has(provider)) {
this.providerMetrics.set(provider, {
provider,
requests: 0,
failures: 0,
successRate: 1,
averageLatency: 0,
p50Latency: 0,
p95Latency: 0,
p99Latency: 0,
totalTokens: 0,
inputTokens: 0,
outputTokens: 0,
averageTokensPerRequest: 0,
errorRate: 0,
lastUpdated: new Date()
});
}
return this.providerMetrics.get(provider)!;
}
/**
* Update provider metrics
*/
private updateProviderMetrics(metric: MetricDataPoint): void {
const provider = metric.labels.provider;
if (!provider) return;
const metrics = this.getProviderMetrics(provider);
metrics.lastUpdated = new Date();
// Update token metrics
if (metric.name.includes('tokens')) {
if (metric.name === 'llm_input_tokens') {
metrics.inputTokens += metric.value;
} else if (metric.name === 'llm_output_tokens') {
metrics.outputTokens += metric.value;
}
metrics.totalTokens = metrics.inputTokens + metrics.outputTokens;
if (metrics.requests > 0) {
metrics.averageTokensPerRequest = metrics.totalTokens / metrics.requests;
}
}
// Update latency metrics
if (metric.name === 'llm_request_latency') {
const histogram = this.latencyHistogram.get(provider);
if (histogram && histogram.length > 0) {
const sorted = [...histogram].sort((a, b) => a - b);
metrics.averageLatency = sorted.reduce((a, b) => a + b, 0) / sorted.length;
metrics.p50Latency = this.percentile(sorted, 50);
metrics.p95Latency = this.percentile(sorted, 95);
metrics.p99Latency = this.percentile(sorted, 99);
}
}
}
/**
* Update system metrics
*/
private updateSystemMetrics(metric: MetricDataPoint): void {
if (metric.name === 'llm_requests') {
this.systemMetrics.totalRequests++;
if (metric.labels.status === 'failure') {
this.systemMetrics.totalFailures++;
}
}
this.systemMetrics.uptime = Date.now() - this.startTime.getTime();
this.systemMetrics.timestamp = new Date();
this.systemMetrics.memoryUsage = process.memoryUsage().heapUsed;
}
/**
* Calculate percentile
*/
private percentile(sorted: number[], p: number): number {
const index = Math.ceil((p / 100) * sorted.length) - 1;
return sorted[Math.max(0, index)];
}
/**
* Classify error type
*/
private classifyError(error: string): string {
const errorLower = error.toLowerCase();
if (errorLower.includes('timeout')) return 'timeout';
if (errorLower.includes('rate')) return 'rate_limit';
if (errorLower.includes('auth')) return 'authentication';
if (errorLower.includes('network')) return 'network';
if (errorLower.includes('circuit')) return 'circuit_breaker';
return 'unknown';
}
/**
* Get metric key
*/
private getMetricKey(metric: MetricDataPoint): string {
const labelStr = Object.entries(metric.labels)
.sort(([a], [b]) => a.localeCompare(b))
.map(([k, v]) => `${k}=${v}`)
.join(',');
return `${metric.name}{${labelStr}}`;
}
/**
* Create default system metrics
*/
private createDefaultSystemMetrics(): SystemMetrics {
return {
totalRequests: 0,
totalFailures: 0,
averageLatency: 0,
activePipelines: 0,
cacheHitRate: 0,
memoryUsage: 0,
uptime: 0,
timestamp: new Date()
};
}
/**
* Get all metrics
*/
public getAllMetrics(): Map<string, MetricDataPoint[]> {
return new Map(this.metrics);
}
/**
* Get provider metrics
*/
public getProviderMetricsMap(): Map<string, ProviderMetrics> {
return new Map(this.providerMetrics);
}
/**
* Get system metrics
*/
public getSystemMetrics(): SystemMetrics {
return { ...this.systemMetrics };
}
/**
* Clear metrics
*/
public clear(): void {
this.metrics.clear();
this.providerMetrics.clear();
this.latencyHistogram.clear();
this.systemMetrics = this.createDefaultSystemMetrics();
}
}
/**
* Prometheus format exporter
*/
export class PrometheusExporter {
constructor(private collector: MetricsCollector) {}
/**
* Export metrics in Prometheus format
*/
public export(): string {
const lines: string[] = [];
const metrics = this.collector.getAllMetrics();
// Add help and type comments
const metricTypes = new Map<string, MetricType>();
const metricDescriptions = new Map<string, string>();
for (const [key, dataPoints] of metrics) {
if (dataPoints.length === 0) continue;
const latest = dataPoints[dataPoints.length - 1];
const metricName = latest.name;
if (!metricTypes.has(metricName)) {
metricTypes.set(metricName, latest.type);
metricDescriptions.set(metricName, latest.description || '');
lines.push(`# HELP ${metricName} ${latest.description || ''}`);
lines.push(`# TYPE ${metricName} ${this.mapMetricType(latest.type)}`);
}
// Add metric value
const labelStr = Object.entries(latest.labels)
.map(([k, v]) => `${k}="${v}"`)
.join(',');
const metricLine = labelStr
? `${metricName}{${labelStr}} ${latest.value}`
: `${metricName} ${latest.value}`;
lines.push(metricLine);
}
// Add provider-specific metrics
for (const [provider, metrics] of this.collector.getProviderMetricsMap()) {
lines.push(`# HELP llm_provider_success_rate Success rate by provider`);
lines.push(`# TYPE llm_provider_success_rate gauge`);
lines.push(`llm_provider_success_rate{provider="${provider}"} ${metrics.successRate}`);
lines.push(`# HELP llm_provider_avg_latency Average latency by provider`);
lines.push(`# TYPE llm_provider_avg_latency gauge`);
lines.push(`llm_provider_avg_latency{provider="${provider}"} ${metrics.averageLatency}`);
}
// Add system metrics
const systemMetrics = this.collector.getSystemMetrics();
lines.push(`# HELP llm_system_uptime System uptime in milliseconds`);
lines.push(`# TYPE llm_system_uptime counter`);
lines.push(`llm_system_uptime ${systemMetrics.uptime}`);
lines.push(`# HELP llm_system_memory_usage Memory usage in bytes`);
lines.push(`# TYPE llm_system_memory_usage gauge`);
lines.push(`llm_system_memory_usage ${systemMetrics.memoryUsage}`);
return lines.join('\n');
}
/**
* Map internal metric type to Prometheus type
*/
private mapMetricType(type: MetricType): string {
switch (type) {
case MetricType.COUNTER:
return 'counter';
case MetricType.GAUGE:
return 'gauge';
case MetricType.HISTOGRAM:
return 'histogram';
case MetricType.SUMMARY:
return 'summary';
default:
return 'gauge';
}
}
}
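// Illustrative output from export() after a single successful OpenAI request
// (values are made up):
//
// # HELP llm_requests LLM requests
// # TYPE llm_requests counter
// llm_requests{provider="openai",status="success"} 1
// # HELP llm_provider_success_rate Success rate by provider
// # TYPE llm_provider_success_rate gauge
// llm_provider_success_rate{provider="openai"} 1
// # HELP llm_system_uptime System uptime in milliseconds
// # TYPE llm_system_uptime counter
// llm_system_uptime 123456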
/**
* StatsD format exporter
*/
export class StatsDExporter {
constructor(
private collector: MetricsCollector,
private prefix: string = 'llm'
) {}
/**
* Export metrics in StatsD format
*/
public export(): string[] {
const lines: string[] = [];
const metrics = this.collector.getAllMetrics();
for (const [_, dataPoints] of metrics) {
if (dataPoints.length === 0) continue;
const latest = dataPoints[dataPoints.length - 1];
const metricName = this.formatMetricName(latest.name, latest.labels);
switch (latest.type) {
case MetricType.COUNTER:
lines.push(`${metricName}:${latest.value}|c`);
break;
case MetricType.GAUGE:
lines.push(`${metricName}:${latest.value}|g`);
break;
case MetricType.HISTOGRAM:
lines.push(`${metricName}:${latest.value}|h`);
break;
default:
lines.push(`${metricName}:${latest.value}|g`);
}
}
return lines;
}
/**
* Format metric name for StatsD
*/
private formatMetricName(name: string, labels: Record<string, string>): string {
const parts = [this.prefix, name];
// Add important labels to the metric name
if (labels.provider) {
parts.push(labels.provider);
}
return parts.join('.');
}
}
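// Illustrative StatsD lines for provider 'openai' with the default 'llm'
// prefix (values are made up):
//
// llm.llm_requests.openai:1|c
// llm.llm_request_latency.openai:42|h
// llm.llm_tokens_used.openai:150|c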
/**
* OpenTelemetry format exporter
*/
export class OpenTelemetryExporter {
constructor(private collector: MetricsCollector) {}
/**
* Export metrics in OpenTelemetry format
*/
public export(): object {
const metrics = this.collector.getAllMetrics();
const providerMetrics = this.collector.getProviderMetricsMap();
const systemMetrics = this.collector.getSystemMetrics();
const resource = {
attributes: {
'service.name': 'trilium-llm',
'service.version': '1.0.0'
}
};
const scopeMetrics = {
scope: {
name: 'trilium.llm.metrics',
version: '1.0.0'
},
metrics: [] as any[]
};
// Convert internal metrics to OTLP format
for (const [key, dataPoints] of metrics) {
if (dataPoints.length === 0) continue;
const latest = dataPoints[dataPoints.length - 1];
const metric = {
name: latest.name,
description: latest.description,
unit: latest.unit,
data: {
dataPoints: dataPoints.map(dp => ({
attributes: dp.labels,
timeUnixNano: dp.timestamp.getTime() * 1000000,
value: dp.value
}))
}
};
scopeMetrics.metrics.push(metric);
}
return {
resourceMetrics: [{
resource,
scopeMetrics: [scopeMetrics]
}]
};
}
}
/**
* Metrics Exporter Manager
*/
export class MetricsExporter {
private static instance: MetricsExporter | null = null;
private collector: MetricsCollector;
private exporters: Map<ExportFormat, any> = new Map();
private exportTimer?: NodeJS.Timeout;
private config: ExporterConfig;
constructor(config?: Partial<ExporterConfig>) {
this.collector = new MetricsCollector();
this.config = {
enabled: config?.enabled ?? false,
format: config?.format ?? ExportFormat.PROMETHEUS,
interval: config?.interval ?? 60000, // 1 minute
prefix: config?.prefix ?? 'llm',
includeHistograms: config?.includeHistograms ?? true,
histogramBuckets: config?.histogramBuckets ?? [10, 25, 50, 100, 250, 500, 1000, 2500, 5000, 10000],
...config
};
this.initializeExporters();
if (this.config.enabled && this.config.interval) {
this.startAutoExport();
}
}
/**
* Get singleton instance
*/
public static getInstance(config?: Partial<ExporterConfig>): MetricsExporter {
if (!MetricsExporter.instance) {
MetricsExporter.instance = new MetricsExporter(config);
}
return MetricsExporter.instance;
}
/**
* Initialize exporters
*/
private initializeExporters(): void {
this.exporters.set(
ExportFormat.PROMETHEUS,
new PrometheusExporter(this.collector)
);
this.exporters.set(
ExportFormat.STATSD,
new StatsDExporter(this.collector, this.config.prefix)
);
this.exporters.set(
ExportFormat.OPENTELEMETRY,
new OpenTelemetryExporter(this.collector)
);
}
/**
* Start auto export
*/
private startAutoExport(): void {
if (this.exportTimer) {
clearInterval(this.exportTimer);
}
this.exportTimer = setInterval(() => {
this.export();
}, this.config.interval);
}
/**
* Export metrics
*/
public export(format?: ExportFormat): any {
const exportFormat = format || this.config.format;
const exporter = this.exporters.get(exportFormat);
if (!exporter) {
log.error(`[MetricsExporter] Unknown export format: ${exportFormat}`);
return null;
}
try {
const data = exporter.export();
if (this.config.endpoint) {
this.sendToEndpoint(data, exportFormat);
}
return data;
} catch (error) {
log.error(`[MetricsExporter] Export failed: ${error}`);
return null;
}
}
/**
* Send metrics to endpoint
*/
private async sendToEndpoint(data: any, format: ExportFormat): Promise<void> {
if (!this.config.endpoint) return;
try {
const contentType = this.getContentType(format);
const body = typeof data === 'string' ? data : JSON.stringify(data);
// This would be replaced with actual HTTP client
log.info(`[MetricsExporter] Would send metrics to ${this.config.endpoint}`);
// await fetch(this.config.endpoint, {
// method: 'POST',
// headers: { 'Content-Type': contentType },
// body
// });
} catch (error) {
log.error(`[MetricsExporter] Failed to send metrics: ${error}`);
}
}
/**
* Get content type for format
*/
private getContentType(format: ExportFormat): string {
switch (format) {
case ExportFormat.PROMETHEUS:
return 'text/plain; version=0.0.4';
case ExportFormat.STATSD:
return 'text/plain';
case ExportFormat.OPENTELEMETRY:
return 'application/json';
default:
return 'application/json';
}
}
/**
* Get metrics collector
*/
public getCollector(): MetricsCollector {
return this.collector;
}
/**
* Enable/disable exporter
*/
public setEnabled(enabled: boolean): void {
this.config.enabled = enabled;
if (enabled && this.config.interval && !this.exportTimer) {
this.startAutoExport();
} else if (!enabled && this.exportTimer) {
clearInterval(this.exportTimer);
this.exportTimer = undefined;
}
}
/**
* Update configuration
*/
public updateConfig(config: Partial<ExporterConfig>): void {
this.config = { ...this.config, ...config };
if (this.config.enabled && this.config.interval) {
this.startAutoExport();
}
}
/**
* Dispose exporter
*/
public dispose(): void {
if (this.exportTimer) {
clearInterval(this.exportTimer);
this.exportTimer = undefined;
}
this.collector.clear();
MetricsExporter.instance = null;
}
}
// Export singleton getter
export const getMetricsExporter = (config?: Partial<ExporterConfig>): MetricsExporter => {
return MetricsExporter.getInstance(config);
};
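A usage sketch for the exporter; where these calls actually happen in the provider factory is not shown here, so the sequence below is illustrative:
const exporter = getMetricsExporter({ enabled: true });
const collector = exporter.getCollector();
collector.recordRequest('openai', true);
collector.recordLatency('openai', 420);
collector.recordTokenUsage('openai', 120, 30);
// Render the current snapshot in Prometheus text format.
console.log(exporter.export(ExportFormat.PROMETHEUS));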

View File

@@ -0,0 +1,319 @@
/**
* Circuit Breaker Tests
*/
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
import {
CircuitBreaker,
CircuitBreakerManager,
CircuitState,
CircuitOpenError
} from '../circuit_breaker.js';
describe('CircuitBreaker', () => {
let breaker: CircuitBreaker;
beforeEach(() => {
breaker = new CircuitBreaker('test-provider', {
failureThreshold: 3,
failureWindow: 1000,
cooldownPeriod: 500,
successThreshold: 2,
enableLogging: false
});
});
afterEach(() => {
breaker.dispose();
});
describe('State Transitions', () => {
it('should start in CLOSED state', () => {
expect(breaker.getState()).toBe(CircuitState.CLOSED);
});
it('should open after failure threshold', async () => {
const failingFn = () => Promise.reject(new Error('Test failure'));
// First two failures - should remain closed
await expect(breaker.execute(failingFn)).rejects.toThrow('Test failure');
expect(breaker.getState()).toBe(CircuitState.CLOSED);
await expect(breaker.execute(failingFn)).rejects.toThrow('Test failure');
expect(breaker.getState()).toBe(CircuitState.CLOSED);
// Third failure - should open
await expect(breaker.execute(failingFn)).rejects.toThrow('Test failure');
expect(breaker.getState()).toBe(CircuitState.OPEN);
});
it('should reject requests when open', async () => {
// Force open
breaker.forceOpen('Test');
const fn = () => Promise.resolve('success');
await expect(breaker.execute(fn)).rejects.toThrow(CircuitOpenError);
});
it('should transition to HALF_OPEN after cooldown', async () => {
// Force open
breaker.forceOpen('Test');
expect(breaker.getState()).toBe(CircuitState.OPEN);
// Wait for cooldown
await new Promise(resolve => setTimeout(resolve, 600));
expect(breaker.getState()).toBe(CircuitState.HALF_OPEN);
});
it('should close after success threshold in HALF_OPEN', async () => {
// Force to half-open
breaker.forceOpen('Test');
await new Promise(resolve => setTimeout(resolve, 600));
expect(breaker.getState()).toBe(CircuitState.HALF_OPEN);
const successFn = () => Promise.resolve('success');
// First success
await breaker.execute(successFn);
expect(breaker.getState()).toBe(CircuitState.HALF_OPEN);
// Second success - should close
await breaker.execute(successFn);
expect(breaker.getState()).toBe(CircuitState.CLOSED);
});
it('should reopen on failure in HALF_OPEN', async () => {
// Force to half-open
breaker.forceOpen('Test');
await new Promise(resolve => setTimeout(resolve, 600));
expect(breaker.getState()).toBe(CircuitState.HALF_OPEN);
const failingFn = () => Promise.reject(new Error('Test failure'));
// Failure in half-open should immediately open
await expect(breaker.execute(failingFn)).rejects.toThrow('Test failure');
expect(breaker.getState()).toBe(CircuitState.OPEN);
});
});
describe('Failure Window', () => {
it('should reset failures outside window', async () => {
const failingFn = () => Promise.reject(new Error('Test failure'));
// Two failures
await expect(breaker.execute(failingFn)).rejects.toThrow();
await expect(breaker.execute(failingFn)).rejects.toThrow();
expect(breaker.getState()).toBe(CircuitState.CLOSED);
// Wait for window to expire
await new Promise(resolve => setTimeout(resolve, 1100));
// Two more failures - should still be closed (window reset)
await expect(breaker.execute(failingFn)).rejects.toThrow();
await expect(breaker.execute(failingFn)).rejects.toThrow();
expect(breaker.getState()).toBe(CircuitState.CLOSED);
});
});
describe('Statistics', () => {
it('should track statistics', async () => {
const successFn = () => Promise.resolve('success');
const failingFn = () => Promise.reject(new Error('Test failure'));
await breaker.execute(successFn);
await expect(breaker.execute(failingFn)).rejects.toThrow();
const stats = breaker.getStats();
expect(stats.totalRequests).toBe(2);
expect(stats.successes).toBe(1);
expect(stats.failures).toBe(1);
expect(stats.rejectedRequests).toBe(0);
});
it('should track rejected requests', async () => {
breaker.forceOpen('Test');
const fn = () => Promise.resolve('success');
await expect(breaker.execute(fn)).rejects.toThrow(CircuitOpenError);
await expect(breaker.execute(fn)).rejects.toThrow(CircuitOpenError);
const stats = breaker.getStats();
expect(stats.rejectedRequests).toBe(2);
});
it('should track state history', async () => {
breaker.forceOpen('Test open');
breaker.forceClose('Test close');
const stats = breaker.getStats();
expect(stats.stateHistory).toHaveLength(2);
expect(stats.stateHistory[0].state).toBe(CircuitState.OPEN);
expect(stats.stateHistory[1].state).toBe(CircuitState.CLOSED);
});
});
describe('Timeout', () => {
it('should apply timeout in HALF_OPEN state', async () => {
// Force to half-open
breaker.forceOpen('Test');
await new Promise(resolve => setTimeout(resolve, 600));
const slowFn = () => new Promise(resolve =>
setTimeout(() => resolve('success'), 10000)
);
// Should timeout (half-open timeout is 5000ms in config)
await expect(breaker.execute(slowFn)).rejects.toThrow(/timed out/);
});
});
});
describe('CircuitBreakerManager', () => {
let manager: CircuitBreakerManager;
beforeEach(() => {
manager = new CircuitBreakerManager({
failureThreshold: 2,
failureWindow: 1000,
cooldownPeriod: 500,
enableLogging: false
});
});
afterEach(() => {
manager.dispose();
});
describe('Breaker Management', () => {
it('should create breakers on demand', () => {
const breaker1 = manager.getBreaker('provider1');
const breaker2 = manager.getBreaker('provider2');
expect(breaker1).toBeDefined();
expect(breaker2).toBeDefined();
expect(breaker1).not.toBe(breaker2);
});
it('should return same breaker for same provider', () => {
const breaker1 = manager.getBreaker('provider1');
const breaker2 = manager.getBreaker('provider1');
expect(breaker1).toBe(breaker2);
});
it('should execute with breaker protection', async () => {
const fn = () => Promise.resolve('success');
const result = await manager.execute('provider1', fn);
expect(result).toBe('success');
});
});
describe('Health Summary', () => {
it('should provide health summary', async () => {
const failingFn = () => Promise.reject(new Error('Test'));
const successFn = () => Promise.resolve('success');
// Create some breakers in different states
await manager.execute('provider1', successFn);
// Force provider2 to open
const breaker2 = manager.getBreaker('provider2');
breaker2.forceOpen('Test');
const summary = manager.getHealthSummary();
expect(summary.total).toBe(2);
expect(summary.closed).toBe(1);
expect(summary.open).toBe(1);
expect(summary.availableProviders).toContain('provider1');
expect(summary.unavailableProviders).toContain('provider2');
});
});
describe('Global Operations', () => {
it('should reset all breakers', () => {
const breaker1 = manager.getBreaker('provider1');
const breaker2 = manager.getBreaker('provider2');
breaker1.forceOpen('Test');
breaker2.forceOpen('Test');
manager.resetAll();
expect(breaker1.getState()).toBe(CircuitState.CLOSED);
expect(breaker2.getState()).toBe(CircuitState.CLOSED);
});
it('should get all stats', async () => {
await manager.execute('provider1', () => Promise.resolve('success'));
await manager.execute('provider2', () => Promise.resolve('success'));
const allStats = manager.getAllStats();
expect(allStats.size).toBe(2);
expect(allStats.has('provider1')).toBe(true);
expect(allStats.has('provider2')).toBe(true);
});
});
});
describe('Circuit Breaker Integration', () => {
it('should handle rapid failures gracefully', async () => {
const breaker = new CircuitBreaker('rapid-test', {
failureThreshold: 3,
failureWindow: 1000,
cooldownPeriod: 100,
enableLogging: false
});
const failingFn = () => Promise.reject(new Error('Rapid failure'));
const promises: Promise<any>[] = [];
// Fire 10 rapid requests
for (let i = 0; i < 10; i++) {
promises.push(
breaker.execute(failingFn).catch(err => err.message)
);
}
const results = await Promise.all(promises);
// First 3 should be actual failures
expect(results.slice(0, 3)).toEqual([
'Rapid failure',
'Rapid failure',
'Rapid failure'
]);
// Rest should be circuit open errors
const openErrors = results.slice(3).filter((msg: string) =>
msg.includes('Circuit breaker is OPEN')
);
expect(openErrors.length).toBeGreaterThan(0);
breaker.dispose();
});
it('should handle concurrent successes correctly', async () => {
const breaker = new CircuitBreaker('concurrent-test', {
failureThreshold: 3,
enableLogging: false
});
let counter = 0;
const successFn = () => {
counter++;
return Promise.resolve(counter);
};
const promises: Promise<number>[] = [];
for (let i = 0; i < 5; i++) {
promises.push(breaker.execute(successFn));
}
const results = await Promise.all(promises);
expect(results).toEqual([1, 2, 3, 4, 5]);
expect(breaker.getState()).toBe(CircuitState.CLOSED);
breaker.dispose();
});
});
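Beyond the tests, a sketch of how the manager might guard a real provider call; the factory integration in this commit presumably does something similar, and the wrapper below is illustrative only:
import { CircuitBreakerManager, CircuitOpenError } from '../circuit_breaker.js';
import type { AIService, Message } from '../../ai_interface.js';
const breakers = new CircuitBreakerManager({ failureThreshold: 5, cooldownPeriod: 30000 });
async function guardedCompletion(service: AIService, messages: Message[]) {
try {
return await breakers.execute(service.getName(), () =>
service.generateChatCompletion(messages)
);
} catch (err) {
if (err instanceof CircuitOpenError) {
// Provider is cooling down; the caller can fall back to another provider.
return null;
}
throw err;
}
}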

View File

@@ -0,0 +1,482 @@
/**
* Mock Providers for Testing
*
* Provides mock implementations of AI service providers for testing purposes
*/
import type { AIService, ChatCompletionOptions, ChatResponse, Message } from '../../ai_interface.js';
import type { UnifiedStreamChunk } from '../unified_stream_handler.js';
/**
* Mock provider configuration
*/
export interface MockProviderConfig {
name: string;
available: boolean;
responseDelay?: number;
errorRate?: number;
streamingSupported?: boolean;
toolsSupported?: boolean;
defaultResponse?: string;
throwError?: Error;
}
/**
* Base mock provider implementation
*/
export class MockProvider implements AIService {
protected config: MockProviderConfig;
private callCount: number = 0;
private streamCallCount: number = 0;
constructor(config: Partial<MockProviderConfig> = {}) {
this.config = {
name: config.name || 'mock',
available: config.available !== false,
responseDelay: config.responseDelay || 0,
errorRate: config.errorRate || 0,
streamingSupported: config.streamingSupported !== false,
toolsSupported: config.toolsSupported !== false,
defaultResponse: config.defaultResponse || 'Mock response',
throwError: config.throwError
};
}
isAvailable(): boolean {
return this.config.available;
}
getName(): string {
return this.config.name;
}
async generateChatCompletion(
messages: Message[],
options: ChatCompletionOptions = {}
): Promise<ChatResponse> {
this.callCount++;
// Simulate delay
if (this.config.responseDelay) {
await new Promise(resolve => setTimeout(resolve, this.config.responseDelay));
}
// Simulate errors
if (this.config.throwError) {
throw this.config.throwError;
}
if (this.config.errorRate && Math.random() < this.config.errorRate) {
throw new Error(`Mock provider error (${this.config.name})`);
}
// Handle streaming
if (options.stream && options.streamCallback) {
return this.generateStreamingResponse(messages, options);
}
// Generate response based on options
const response: ChatResponse = {
text: this.generateContent(messages, options),
model: `${this.config.name}-model`,
provider: this.config.name,
usage: {
promptTokens: this.calculateTokens(messages),
completionTokens: 10,
totalTokens: this.calculateTokens(messages) + 10
}
};
// Add tool calls if requested
if (options.tools && this.config.toolsSupported) {
response.tool_calls = this.generateToolCalls(options.tools);
}
return response;
}
protected async generateStreamingResponse(
messages: Message[],
options: ChatCompletionOptions
): Promise<ChatResponse> {
this.streamCallCount++;
const content = this.generateContent(messages, options);
const chunks = this.splitIntoChunks(content, 5);
let fullContent = '';
for (const chunk of chunks) {
fullContent += chunk;
// Call stream callback
if (options.streamCallback) {
await options.streamCallback(chunk, false);
// Simulate delay between chunks
if (this.config.responseDelay) {
await new Promise(resolve =>
setTimeout(resolve, this.config.responseDelay! / chunks.length)
);
}
}
}
// Send final callback
if (options.streamCallback) {
await options.streamCallback('', true);
}
return {
text: fullContent,
model: `${this.config.name}-model`,
provider: this.config.name,
usage: {
promptTokens: this.calculateTokens(messages),
completionTokens: Math.floor(fullContent.length / 4),
totalTokens: this.calculateTokens(messages) + Math.floor(fullContent.length / 4)
}
};
}
protected generateContent(messages: Message[], options: ChatCompletionOptions): string {
// Return JSON if requested
if (options.expectsJsonResponse) {
return JSON.stringify({
type: 'mock_response',
provider: this.config.name,
messageCount: messages.length
});
}
// Use custom response if provided
if (this.config.defaultResponse) {
return this.config.defaultResponse;
}
// Generate response based on last message
const lastMessage = messages[messages.length - 1];
return `Mock ${this.config.name} response to: ${lastMessage.content}`;
}
protected generateToolCalls(tools: any[]): any[] {
return tools.slice(0, 1).map((tool, index) => ({
id: `call_mock_${index}`,
type: 'function',
function: {
name: tool.function?.name || 'mock_tool',
arguments: JSON.stringify({ mock: true })
}
}));
}
protected calculateTokens(messages: Message[]): number {
return messages.reduce((sum, msg) => {
const content = typeof msg.content === 'string' ? msg.content : '';
return sum + Math.floor(content.length / 4);
}, 0);
}
protected splitIntoChunks(text: string, chunkCount: number): string[] {
const chunkSize = Math.ceil(text.length / chunkCount);
const chunks: string[] = [];
for (let i = 0; i < text.length; i += chunkSize) {
chunks.push(text.slice(i, i + chunkSize));
}
return chunks;
}
// Test helper methods
getCallCount(): number {
return this.callCount;
}
getStreamCallCount(): number {
return this.streamCallCount;
}
resetCallCounts(): void {
this.callCount = 0;
this.streamCallCount = 0;
}
setAvailable(available: boolean): void {
this.config.available = available;
}
setErrorRate(rate: number): void {
this.config.errorRate = rate;
}
setResponseDelay(delay: number): void {
this.config.responseDelay = delay;
}
dispose(): void {
// Cleanup mock resources
this.resetCallCounts();
}
}
/**
* Mock OpenAI provider
*/
export class MockOpenAIProvider extends MockProvider {
constructor(config: Partial<MockProviderConfig> = {}) {
super({
name: 'openai',
...config
});
}
supportsStreaming(): boolean {
return this.config.streamingSupported!;
}
supportsTools(): boolean {
return this.config.toolsSupported!;
}
async *streamCompletion(
messages: Message[],
options: ChatCompletionOptions = {}
): AsyncGenerator<any> {
const content = this.generateContent(messages, options);
const chunks = this.splitIntoChunks(content, 5);
for (const chunk of chunks) {
yield {
choices: [{
delta: { content: chunk },
index: 0
}],
model: 'gpt-4-mock'
};
if (this.config.responseDelay) {
await new Promise(resolve =>
setTimeout(resolve, this.config.responseDelay! / chunks.length)
);
}
}
yield {
choices: [{
delta: {},
finish_reason: 'stop',
index: 0
}],
usage: {
prompt_tokens: this.calculateTokens(messages),
completion_tokens: Math.floor(content.length / 4),
total_tokens: this.calculateTokens(messages) + Math.floor(content.length / 4)
}
};
}
}
/**
* Mock Anthropic provider
*/
export class MockAnthropicProvider extends MockProvider {
constructor(config: Partial<MockProviderConfig> = {}) {
super({
name: 'anthropic',
...config
});
}
async *streamCompletion(
messages: Message[],
options: ChatCompletionOptions = {}
): AsyncGenerator<any> {
const content = this.generateContent(messages, options);
const chunks = this.splitIntoChunks(content, 5);
// Message start
yield {
type: 'message_start',
message: { id: 'msg_mock_123' }
};
// Content blocks
for (const chunk of chunks) {
yield {
type: 'content_block_delta',
delta: {
type: 'text_delta',
text: chunk
}
};
if (this.config.responseDelay) {
await new Promise(resolve =>
setTimeout(resolve, this.config.responseDelay! / chunks.length)
);
}
}
// Message end
yield {
type: 'message_delta',
delta: { stop_reason: 'end_turn' },
usage: {
input_tokens: this.calculateTokens(messages),
output_tokens: Math.floor(content.length / 4)
}
};
yield {
type: 'message_stop'
};
}
}
/**
* Mock Ollama provider
*/
export class MockOllamaProvider extends MockProvider {
constructor(config: Partial<MockProviderConfig> = {}) {
super({
name: 'ollama',
...config
});
}
async *streamCompletion(
messages: Message[],
options: ChatCompletionOptions = {}
): AsyncGenerator<any> {
const content = this.generateContent(messages, options);
const chunks = this.splitIntoChunks(content, 5);
for (let i = 0; i < chunks.length; i++) {
yield {
message: { content: chunks[i] },
model: 'llama2-mock',
done: false
};
if (this.config.responseDelay) {
await new Promise(resolve =>
setTimeout(resolve, this.config.responseDelay! / chunks.length)
);
}
}
// Final chunk with usage
yield {
message: { content: '' },
model: 'llama2-mock',
done: true,
prompt_eval_count: this.calculateTokens(messages),
eval_count: Math.floor(content.length / 4)
};
}
}
/**
* Factory for creating mock providers
*/
export class MockProviderFactory {
private providers: Map<string, MockProvider> = new Map();
createProvider(type: 'openai' | 'anthropic' | 'ollama', config?: Partial<MockProviderConfig>): MockProvider {
let provider: MockProvider;
switch (type) {
case 'openai':
provider = new MockOpenAIProvider(config);
break;
case 'anthropic':
provider = new MockAnthropicProvider(config);
break;
case 'ollama':
provider = new MockOllamaProvider(config);
break;
default:
provider = new MockProvider({ name: type, ...config });
}
this.providers.set(type, provider);
return provider;
}
getProvider(type: string): MockProvider | undefined {
return this.providers.get(type);
}
getAllProviders(): MockProvider[] {
return Array.from(this.providers.values());
}
resetAll(): void {
for (const provider of this.providers.values()) {
provider.resetCallCounts();
}
}
disposeAll(): void {
for (const provider of this.providers.values()) {
provider.dispose();
}
this.providers.clear();
}
}
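/*
 * Sketch of the factory lifecycle as used by the test suites below:
 *
 *   const mocks = new MockProviderFactory();
 *   const openai = mocks.createProvider('openai', { responseDelay: 5 });
 *   // ... exercise the provider in a test ...
 *   mocks.resetAll();    // clear call counters between tests
 *   mocks.disposeAll();  // release all mocks in afterEach
 */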
/**
* Create a mock provider with predefined behaviors
*/
export function createMockProvider(behavior: 'success' | 'error' | 'slow' | 'flaky'): MockProvider {
const configs: Record<string, Partial<MockProviderConfig>> = {
success: {
available: true,
responseDelay: 10
},
error: {
available: true,
throwError: new Error('Mock provider error')
},
slow: {
available: true,
responseDelay: 1000
},
flaky: {
available: true,
errorRate: 0.5,
responseDelay: 100
}
};
return new MockProvider(configs[behavior] || configs.success);
}
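/*
 * Example: a flaky provider for retry/error-recovery tests. With
 * errorRate: 0.5, roughly half of the calls are expected to throw.
 * (generateChatCompletion is assumed from the MockProvider base class
 * defined earlier in this file.)
 *
 *   const flaky = createMockProvider('flaky');
 *   try {
 *     await flaky.generateChatCompletion([{ role: 'user', content: 'hi' }]);
 *   } catch (err) {
 *     // expected for roughly half of the calls
 *   }
 */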
/**
* Create a mock streaming response for testing
*/
export async function* createMockStream(
chunks: string[],
delay: number = 10
): AsyncGenerator<UnifiedStreamChunk> {
for (const chunk of chunks) {
yield {
type: 'content',
content: chunk,
metadata: {
provider: 'mock'
}
};
await new Promise(resolve => setTimeout(resolve, delay));
}
yield {
type: 'done',
metadata: {
provider: 'mock',
finishReason: 'stop'
}
};
}
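/*
 * Example: collecting content chunks from the mock stream.
 *
 *   const pieces: string[] = [];
 *   for await (const chunk of createMockStream(['Hello', ' World'], 5)) {
 *     if (chunk.type === 'content' && chunk.content) {
 *       pieces.push(chunk.content);
 *     }
 *   }
 *   // pieces.join('') === 'Hello World'
 */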

View File

@@ -0,0 +1,502 @@
/**
* Provider Performance Benchmarks
*
* Performance benchmark suite for AI service providers
*/
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
import { performance } from 'perf_hooks';
import { ProviderFactory, ProviderType } from '../provider_factory.js';
import {
MockProviderFactory,
createMockProvider
} from './mock_providers.js';
import {
StreamAggregator,
createStreamHandler
} from '../unified_stream_handler.js';
import type { AIService, Message } from '../../ai_interface.js';
// Mock providers
vi.mock('../openai_service.js');
vi.mock('../anthropic_service.js');
vi.mock('../ollama_service.js');
import { OpenAIService } from '../openai_service.js';
import { AnthropicService } from '../anthropic_service.js';
import { OllamaService } from '../ollama_service.js';
/**
* Performance metrics interface
*/
interface PerformanceMetrics {
operation: string;
provider: string;
duration: number;
throughput?: number;
latency?: number;
memoryUsed?: number;
}
/**
* Benchmark runner class
*/
class BenchmarkRunner {
private metrics: PerformanceMetrics[] = [];
async runBenchmark(
name: string,
provider: string,
fn: () => Promise<void>,
iterations: number = 100
): Promise<PerformanceMetrics> {
const startMemory = process.memoryUsage().heapUsed;
const startTime = performance.now();
for (let i = 0; i < iterations; i++) {
await fn();
}
const endTime = performance.now();
const endMemory = process.memoryUsage().heapUsed;
const metrics: PerformanceMetrics = {
operation: name,
provider,
duration: endTime - startTime,
throughput: iterations / ((endTime - startTime) / 1000),
latency: (endTime - startTime) / iterations,
memoryUsed: endMemory - startMemory
};
this.metrics.push(metrics);
return metrics;
}
getMetrics(): PerformanceMetrics[] {
return [...this.metrics];
}
printSummary(): void {
console.table(this.metrics.map(m => ({
Operation: m.operation,
Provider: m.provider,
'Avg Latency (ms)': m.latency?.toFixed(2),
'Throughput (ops/s)': m.throughput?.toFixed(2),
'Memory (MB)': ((m.memoryUsed || 0) / 1024 / 1024).toFixed(2)
})));
}
reset(): void {
this.metrics = [];
}
}
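/*
 * Minimal usage sketch for BenchmarkRunner:
 *
 *   const runner = new BenchmarkRunner();
 *   const m = await runner.runBenchmark('noop', 'mock', async () => {}, 50);
 *   console.log(`${m.latency?.toFixed(3)} ms/op at ${m.throughput?.toFixed(0)} ops/s`);
 *   runner.printSummary();
 */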
describe('Provider Performance Benchmarks', () => {
let factory: ProviderFactory;
let mockFactory: MockProviderFactory;
let runner: BenchmarkRunner;
beforeEach(() => {
// Clear singleton
const existing = ProviderFactory.getInstance();
if (existing) {
existing.dispose();
}
factory = new ProviderFactory({
enableHealthChecks: false,
enableMetrics: false,
enableCaching: true,
cacheTimeout: 60000
});
mockFactory = new MockProviderFactory();
runner = new BenchmarkRunner();
});
afterEach(() => {
factory.dispose();
mockFactory.disposeAll();
vi.clearAllMocks();
});
describe('Provider Creation Performance', () => {
it('should benchmark provider creation speed', async () => {
const providers = ['openai', 'anthropic', 'ollama'] as const;
for (const providerName of providers) {
const mock = createMockProvider('success');
mock.setResponseDelay(0); // No delay for creation benchmarks
switch (providerName) {
case 'openai':
(OpenAIService as any).mockImplementation(() => mock);
break;
case 'anthropic':
(AnthropicService as any).mockImplementation(() => mock);
break;
case 'ollama':
(OllamaService as any).mockImplementation(() => mock);
break;
}
const metrics = await runner.runBenchmark(
'Provider Creation',
providerName,
async () => {
await factory.createProvider(
ProviderType[providerName.toUpperCase() as keyof typeof ProviderType]
);
},
100
);
expect(metrics.latency).toBeLessThan(10); // Should be fast (< 10ms per creation)
expect(metrics.throughput).toBeGreaterThan(100); // > 100 ops/sec
}
if (process.env.SHOW_BENCHMARKS) {
runner.printSummary();
}
});
it('should benchmark cached vs uncached provider creation', async () => {
const mock = createMockProvider('success');
(OpenAIService as any).mockImplementation(() => mock);
// Benchmark uncached (first creation)
const uncachedFactory = new ProviderFactory({
enableCaching: false,
enableHealthChecks: false
});
const uncachedMetrics = await runner.runBenchmark(
'Uncached Creation',
'openai',
async () => {
await uncachedFactory.createProvider(ProviderType.OPENAI);
},
50
);
// Benchmark cached
runner.reset();
const cachedMetrics = await runner.runBenchmark(
'Cached Creation',
'openai',
async () => {
await factory.createProvider(ProviderType.OPENAI);
},
50
);
// Cached should be significantly faster
expect(cachedMetrics.latency).toBeLessThan(uncachedMetrics.latency! * 0.5);
uncachedFactory.dispose();
});
});
describe('Chat Completion Performance', () => {
it('should benchmark chat completion latency', async () => {
const messages: Message[] = [
{ role: 'user', content: 'Hello, how are you?' }
];
const providers = ['openai', 'anthropic', 'ollama'] as const;
for (const providerName of providers) {
const mock = createMockProvider('success');
mock.setResponseDelay(10); // Simulate 10ms response time
switch (providerName) {
case 'openai':
(OpenAIService as any).mockImplementation(() => mock);
break;
case 'anthropic':
(AnthropicService as any).mockImplementation(() => mock);
break;
case 'ollama':
(OllamaService as any).mockImplementation(() => mock);
break;
}
const provider = await factory.createProvider(
ProviderType[providerName.toUpperCase() as keyof typeof ProviderType]
);
const metrics = await runner.runBenchmark(
'Chat Completion',
providerName,
async () => {
await provider.generateChatCompletion(messages);
},
20
);
expect(metrics.latency).toBeGreaterThan(10); // At least the mock delay
expect(metrics.latency).toBeLessThan(50); // But not too slow
}
if (process.env.SHOW_BENCHMARKS) {
runner.printSummary();
}
});
it('should benchmark streaming vs non-streaming performance', async () => {
const mock = mockFactory.createProvider('openai');
mock.setResponseDelay(50);
(OpenAIService as any).mockImplementation(() => mock);
const provider = await factory.createProvider(ProviderType.OPENAI);
const messages: Message[] = [
{ role: 'user', content: 'Tell me a story' }
];
// Benchmark non-streaming
const nonStreamMetrics = await runner.runBenchmark(
'Non-Streaming',
'openai',
async () => {
await provider.generateChatCompletion(messages, {
stream: false
});
},
10
);
// Benchmark streaming
runner.reset();
const streamMetrics = await runner.runBenchmark(
'Streaming',
'openai',
async () => {
const chunks: string[] = [];
await provider.generateChatCompletion(messages, {
stream: true,
streamCallback: async (chunk) => {
chunks.push(chunk);
}
});
},
10
);
// Streaming might have different characteristics
expect(streamMetrics.latency).toBeDefined();
expect(nonStreamMetrics.latency).toBeDefined();
});
});
describe('Concurrent Operations Performance', () => {
it('should benchmark concurrent provider operations', async () => {
const mock = createMockProvider('success');
mock.setResponseDelay(5);
(OpenAIService as any).mockImplementation(() => mock);
const provider = await factory.createProvider(ProviderType.OPENAI);
const messages: Message[] = [
{ role: 'user', content: 'Test' }
];
// Sequential benchmark
const sequentialStart = performance.now();
for (let i = 0; i < 10; i++) {
await provider.generateChatCompletion(messages);
}
const sequentialDuration = performance.now() - sequentialStart;
// Concurrent benchmark
const concurrentStart = performance.now();
await Promise.all(
Array(10).fill(null).map(() =>
provider.generateChatCompletion(messages)
)
);
const concurrentDuration = performance.now() - concurrentStart;
// Concurrent should be faster
expect(concurrentDuration).toBeLessThan(sequentialDuration);
const speedup = sequentialDuration / concurrentDuration;
expect(speedup).toBeGreaterThan(1.5); // At least 1.5x speedup
});
});
describe('Memory Performance', () => {
it('should benchmark memory usage with cache management', async () => {
const mock = createMockProvider('success');
(OpenAIService as any).mockImplementation(() => mock);
(AnthropicService as any).mockImplementation(() => mock);
(OllamaService as any).mockImplementation(() => mock);
const startMemory = process.memoryUsage().heapUsed;
// Create many providers
for (let i = 0; i < 100; i++) {
await factory.createProvider(ProviderType.OPENAI);
await factory.createProvider(ProviderType.ANTHROPIC);
await factory.createProvider(ProviderType.OLLAMA);
}
const midMemory = process.memoryUsage().heapUsed;
const memoryGrowth = midMemory - startMemory;
// Clear cache
factory.clearCache();
const endMemory = process.memoryUsage().heapUsed;
const memoryReclaimed = midMemory - endMemory;
// Should reclaim some memory
expect(memoryReclaimed).toBeGreaterThan(0);
// Memory growth should be reasonable (< 50MB for 300 operations)
expect(memoryGrowth).toBeLessThan(50 * 1024 * 1024);
});
});
describe('Stream Processing Performance', () => {
it('should benchmark stream chunk processing speed', async () => {
const aggregator = new StreamAggregator();
const handler = createStreamHandler({
provider: 'openai',
onChunk: (chunk) => aggregator.addChunk(chunk)
});
const chunks = Array(100).fill(null).map((_, i) => ({
choices: [{
delta: { content: `Chunk ${i}` },
index: 0
}]
}));
const metrics = await runner.runBenchmark(
'Stream Processing',
'openai',
async () => {
aggregator.reset();
for (const chunk of chunks) {
await handler.processChunk(chunk);
}
},
10
);
// Should process chunks quickly
const chunksPerSecond = (chunks.length * 10) / (metrics.duration / 1000);
expect(chunksPerSecond).toBeGreaterThan(1000); // > 1000 chunks/sec
});
});
describe('Health Check Performance', () => {
it('should benchmark health check operations', async () => {
const providers = [ProviderType.OPENAI, ProviderType.ANTHROPIC, ProviderType.OLLAMA];
for (const providerType of providers) {
const mock = createMockProvider('success');
mock.setResponseDelay(20); // Simulate network latency
switch (providerType) {
case ProviderType.OPENAI:
(OpenAIService as any).mockImplementation(() => mock);
break;
case ProviderType.ANTHROPIC:
(AnthropicService as any).mockImplementation(() => mock);
break;
case ProviderType.OLLAMA:
(OllamaService as any).mockImplementation(() => mock);
break;
}
const metrics = await runner.runBenchmark(
'Health Check',
providerType,
async () => {
await factory.checkProviderHealth(providerType);
},
10
);
// Health checks should complete reasonably quickly
expect(metrics.latency).toBeLessThan(100); // < 100ms per check
}
});
});
describe('Fallback Performance', () => {
it('should benchmark fallback provider switching', async () => {
const fallbackFactory = new ProviderFactory({
enableHealthChecks: false,
enableFallback: true,
fallbackProviders: [ProviderType.ANTHROPIC, ProviderType.OLLAMA],
enableCaching: false
});
let attemptCount = 0;
// OpenAI fails first 2 times
(OpenAIService as any).mockImplementation(() => {
attemptCount++;
if (attemptCount <= 2) {
throw new Error('OpenAI unavailable');
}
return createMockProvider('success');
});
// Anthropic always fails
(AnthropicService as any).mockImplementation(() => {
throw new Error('Anthropic unavailable');
});
// Ollama succeeds
const ollamaMock = createMockProvider('success');
(OllamaService as any).mockImplementation(() => ollamaMock);
const metrics = await runner.runBenchmark(
'Fallback Switch',
'multi',
async () => {
attemptCount = 0;
await fallbackFactory.createProvider(ProviderType.OPENAI);
},
10
);
// Fallback should add some overhead but still be reasonable
expect(metrics.latency).toBeLessThan(50); // < 50ms including fallback
fallbackFactory.dispose();
});
});
// Only run when explicitly requested (e.g. in CI, via RUN_FULL_BENCHMARKS)
if (process.env.RUN_FULL_BENCHMARKS) {
describe('Load Testing', () => {
it('should handle high load scenarios', async () => {
const mock = createMockProvider('success');
mock.setResponseDelay(1);
(OpenAIService as any).mockImplementation(() => mock);
const provider = await factory.createProvider(ProviderType.OPENAI);
const messages: Message[] = [{ role: 'user', content: 'Load test' }];
const loadTestStart = performance.now();
const promises = Array(1000).fill(null).map(() =>
provider.generateChatCompletion(messages)
);
await Promise.all(promises);
const loadTestDuration = performance.now() - loadTestStart;
const requestsPerSecond = 1000 / (loadTestDuration / 1000);
// Should handle at least 100 requests per second
expect(requestsPerSecond).toBeGreaterThan(100);
console.log(`Load test: ${requestsPerSecond.toFixed(2)} requests/second`);
});
});
}
});

View File

@@ -0,0 +1,434 @@
/**
* Provider Factory Tests
*
* Comprehensive test suite for the provider factory pattern implementation
*/
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
import {
ProviderFactory,
ProviderType,
type ProviderCapabilities,
type ProviderHealthStatus,
getProviderFactory
} from '../provider_factory.js';
import { OpenAIService } from '../openai_service.js';
import { AnthropicService } from '../anthropic_service.js';
import { OllamaService } from '../ollama_service.js';
import type { AIService, ChatResponse } from '../../ai_interface.js';
// Mock the services
vi.mock('../openai_service.js');
vi.mock('../anthropic_service.js');
vi.mock('../ollama_service.js');
vi.mock('../../log.js', () => ({
default: {
info: vi.fn(),
warn: vi.fn(),
error: vi.fn()
}
}));
describe('ProviderFactory', () => {
let factory: ProviderFactory;
beforeEach(() => {
// Clear any existing singleton
const existingFactory = ProviderFactory.getInstance();
if (existingFactory) {
existingFactory.dispose();
}
// Create new factory instance for testing
factory = new ProviderFactory({
enableHealthChecks: false, // Disable for tests
enableMetrics: false,
cacheTimeout: 1000 // Short timeout for tests
});
});
afterEach(() => {
// Cleanup
factory.dispose();
vi.clearAllMocks();
});
describe('Provider Creation', () => {
it('should create OpenAI provider', async () => {
// Mock OpenAI service
const mockService = {
isAvailable: vi.fn().mockReturnValue(true),
generateChatCompletion: vi.fn().mockResolvedValue({
content: 'test response',
role: 'assistant'
})
};
(OpenAIService as any).mockImplementation(() => mockService);
const service = await factory.createProvider(ProviderType.OPENAI);
expect(service).toBeDefined();
expect(mockService.isAvailable).toHaveBeenCalled();
});
it('should create Anthropic provider', async () => {
// Mock Anthropic service
const mockService = {
isAvailable: vi.fn().mockReturnValue(true),
generateChatCompletion: vi.fn().mockResolvedValue({
content: 'test response',
role: 'assistant'
})
};
(AnthropicService as any).mockImplementation(() => mockService);
const service = await factory.createProvider(ProviderType.ANTHROPIC);
expect(service).toBeDefined();
expect(mockService.isAvailable).toHaveBeenCalled();
});
it('should create Ollama provider', async () => {
// Mock Ollama service
const mockService = {
isAvailable: vi.fn().mockReturnValue(true),
generateChatCompletion: vi.fn().mockResolvedValue({
content: 'test response',
role: 'assistant'
})
};
(OllamaService as any).mockImplementation(() => mockService);
const service = await factory.createProvider(ProviderType.OLLAMA);
expect(service).toBeDefined();
expect(mockService.isAvailable).toHaveBeenCalled();
});
it('should throw error for unavailable provider', async () => {
// Mock service as unavailable
const mockService = {
isAvailable: vi.fn().mockReturnValue(false)
};
(OpenAIService as any).mockImplementation(() => mockService);
await expect(factory.createProvider(ProviderType.OPENAI))
.rejects.toThrow('OpenAI service is not available');
});
it('should throw error for custom provider (not implemented)', async () => {
await expect(factory.createProvider(ProviderType.CUSTOM))
.rejects.toThrow('Custom providers not yet implemented');
});
});
describe('Provider Caching', () => {
it('should cache created providers', async () => {
const mockService = {
isAvailable: vi.fn().mockReturnValue(true),
generateChatCompletion: vi.fn()
};
(OpenAIService as any).mockImplementation(() => mockService);
const service1 = await factory.createProvider(ProviderType.OPENAI);
const service2 = await factory.createProvider(ProviderType.OPENAI);
// Should return same instance
expect(service1).toBe(service2);
// Constructor should only be called once
expect(OpenAIService).toHaveBeenCalledTimes(1);
});
it('should respect cache timeout', async () => {
const mockService = {
isAvailable: vi.fn().mockReturnValue(true),
generateChatCompletion: vi.fn()
};
(OpenAIService as any).mockImplementation(() => mockService);
const service1 = await factory.createProvider(ProviderType.OPENAI);
// Wait for cache to expire
await new Promise(resolve => setTimeout(resolve, 1100));
const service2 = await factory.createProvider(ProviderType.OPENAI);
// Should create new instance after timeout
expect(service1).not.toBe(service2);
expect(OpenAIService).toHaveBeenCalledTimes(2);
});
it('should cache providers with different configurations separately', async () => {
const mockService1 = {
isAvailable: vi.fn().mockReturnValue(true)
};
const mockService2 = {
isAvailable: vi.fn().mockReturnValue(true)
};
let callCount = 0;
(OpenAIService as any).mockImplementation(() => {
callCount++;
return callCount === 1 ? mockService1 : mockService2;
});
const service1 = await factory.createProvider(ProviderType.OPENAI, { baseUrl: 'url1' });
const service2 = await factory.createProvider(ProviderType.OPENAI, { baseUrl: 'url2' });
expect(service1).not.toBe(service2);
expect(OpenAIService).toHaveBeenCalledTimes(2);
});
});
describe('Capabilities Detection', () => {
it('should return default capabilities for providers', () => {
const openAICaps = factory.getCapabilities(ProviderType.OPENAI);
expect(openAICaps).toBeDefined();
expect(openAICaps?.streaming).toBe(true);
expect(openAICaps?.functionCalling).toBe(true);
expect(openAICaps?.vision).toBe(true);
expect(openAICaps?.contextWindow).toBe(128000);
});
it('should allow registering custom capabilities', () => {
const customCaps: ProviderCapabilities = {
streaming: false,
functionCalling: false,
vision: false,
contextWindow: 2048,
maxOutputTokens: 512,
supportsSystemPrompt: false,
supportsTools: false,
supportedModalities: ['text'],
customEndpoints: true,
batchProcessing: false
};
factory.registerCapabilities(ProviderType.CUSTOM, customCaps);
const retrieved = factory.getCapabilities(ProviderType.CUSTOM);
expect(retrieved).toEqual(customCaps);
});
});
describe('Health Checks', () => {
it('should perform health check on provider', async () => {
const mockService = {
isAvailable: vi.fn().mockReturnValue(true),
generateChatCompletion: vi.fn().mockResolvedValue({
content: 'Hi',
role: 'assistant'
})
};
(OpenAIService as any).mockImplementation(() => mockService);
const health = await factory.checkProviderHealth(ProviderType.OPENAI);
expect(health.provider).toBe(ProviderType.OPENAI);
expect(health.healthy).toBe(true);
expect(health.lastChecked).toBeInstanceOf(Date);
expect(health.latency).toBeDefined();
});
it('should report unhealthy provider on error', async () => {
const mockService = {
isAvailable: vi.fn().mockReturnValue(true),
generateChatCompletion: vi.fn().mockRejectedValue(new Error('API Error'))
};
(OpenAIService as any).mockImplementation(() => mockService);
const health = await factory.checkProviderHealth(ProviderType.OPENAI);
expect(health.provider).toBe(ProviderType.OPENAI);
expect(health.healthy).toBe(false);
expect(health.error).toBe('API Error');
});
it('should store health status', async () => {
const mockService = {
isAvailable: vi.fn().mockReturnValue(true),
generateChatCompletion: vi.fn().mockResolvedValue({
content: 'Hi',
role: 'assistant'
})
};
(OpenAIService as any).mockImplementation(() => mockService);
await factory.checkProviderHealth(ProviderType.OPENAI);
const status = factory.getHealthStatus(ProviderType.OPENAI);
expect(status).toBeDefined();
expect(status?.healthy).toBe(true);
});
});
describe('Fallback Mechanism', () => {
it('should fallback to alternative provider on failure', async () => {
// Create factory with fallback enabled
const fallbackFactory = new ProviderFactory({
enableHealthChecks: false,
enableFallback: true,
fallbackProviders: [ProviderType.OLLAMA],
enableCaching: false
});
// Mock OpenAI to fail
(OpenAIService as any).mockImplementation(() => {
throw new Error('OpenAI unavailable');
});
// Mock Ollama to succeed
const mockOllamaService = {
isAvailable: vi.fn().mockReturnValue(true),
generateChatCompletion: vi.fn()
};
(OllamaService as any).mockImplementation(() => mockOllamaService);
// Should fallback to Ollama
const service = await fallbackFactory.createProvider(ProviderType.OPENAI);
expect(service).toBeDefined();
expect(OllamaService).toHaveBeenCalled();
fallbackFactory.dispose();
});
});
describe('Statistics', () => {
it('should track usage statistics', async () => {
const mockService = {
isAvailable: vi.fn().mockReturnValue(true),
generateChatCompletion: vi.fn()
};
(OpenAIService as any).mockImplementation(() => mockService);
// Create providers
await factory.createProvider(ProviderType.OPENAI);
await factory.createProvider(ProviderType.OPENAI); // Uses cache
const stats = factory.getStatistics();
expect(stats.cachedProviders).toBe(1);
expect(stats.totalUsage).toBe(2); // Created once, used twice
expect(stats.providerUsage['openai']).toBe(2);
});
});
describe('Cache Management', () => {
it('should clear all cached providers', async () => {
const mockService = {
isAvailable: vi.fn().mockReturnValue(true),
dispose: vi.fn()
};
(OpenAIService as any).mockImplementation(() => mockService);
(AnthropicService as any).mockImplementation(() => mockService);
// Create multiple providers
await factory.createProvider(ProviderType.OPENAI);
await factory.createProvider(ProviderType.ANTHROPIC);
const statsBefore = factory.getStatistics();
expect(statsBefore.cachedProviders).toBe(2);
factory.clearCache();
const statsAfter = factory.getStatistics();
expect(statsAfter.cachedProviders).toBe(0);
expect(mockService.dispose).toHaveBeenCalledTimes(2);
});
it('should clean up expired cache entries', async () => {
const mockService = {
isAvailable: vi.fn().mockReturnValue(true),
dispose: vi.fn()
};
(OpenAIService as any).mockImplementation(() => mockService);
await factory.createProvider(ProviderType.OPENAI);
// Wait for cache to expire
await new Promise(resolve => setTimeout(resolve, 1100));
factory.cleanupExpiredCache();
const stats = factory.getStatistics();
expect(stats.cachedProviders).toBe(0);
expect(mockService.dispose).toHaveBeenCalled();
});
});
describe('Singleton Pattern', () => {
it('should return same instance via getInstance', () => {
const instance1 = ProviderFactory.getInstance();
const instance2 = ProviderFactory.getInstance();
expect(instance1).toBe(instance2);
instance1.dispose();
});
it('should create new instance after disposal', () => {
const instance1 = ProviderFactory.getInstance();
instance1.dispose();
const instance2 = ProviderFactory.getInstance();
expect(instance1).not.toBe(instance2);
instance2.dispose();
});
});
describe('Error Handling', () => {
it('should handle provider creation errors gracefully', async () => {
(OpenAIService as any).mockImplementation(() => {
throw new Error('Constructor error');
});
await expect(factory.createProvider(ProviderType.OPENAI))
.rejects.toThrow('Constructor error');
});
it('should throw error when factory is disposed', async () => {
factory.dispose();
await expect(factory.createProvider(ProviderType.OPENAI))
.rejects.toThrow('ProviderFactory has been disposed');
});
});
});
describe('getProviderFactory Helper', () => {
it('should return factory instance', () => {
const factory = getProviderFactory();
expect(factory).toBeInstanceOf(ProviderFactory);
factory.dispose();
});
it('should pass options to factory', () => {
const factory = getProviderFactory({
enableHealthChecks: false,
enableMetrics: false
});
expect(factory).toBeInstanceOf(ProviderFactory);
factory.dispose();
});
});

View File

@@ -0,0 +1,418 @@
/**
* Provider Integration Tests
*
* Integration tests for provider factory with AI Service Manager
*/
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
import { ProviderFactory, ProviderType } from '../provider_factory.js';
import {
MockProviderFactory,
MockProvider,
createMockProvider,
createMockStream
} from './mock_providers.js';
import type { AIService, ChatCompletionOptions } from '../../ai_interface.js';
import {
UnifiedStreamChunk,
StreamAggregator,
createStreamHandler
} from '../unified_stream_handler.js';
// Mock the actual provider imports
vi.mock('../openai_service.js', () => ({
OpenAIService: vi.fn()
}));
vi.mock('../anthropic_service.js', () => ({
AnthropicService: vi.fn()
}));
vi.mock('../ollama_service.js', () => ({
OllamaService: vi.fn()
}));
// Import mocked modules
import { OpenAIService } from '../openai_service.js';
import { AnthropicService } from '../anthropic_service.js';
import { OllamaService } from '../ollama_service.js';
describe('Provider Factory Integration', () => {
let factory: ProviderFactory;
let mockFactory: MockProviderFactory;
beforeEach(() => {
// Clear singleton
const existing = ProviderFactory.getInstance();
if (existing) {
existing.dispose();
}
factory = new ProviderFactory({
enableHealthChecks: false,
enableMetrics: true,
cacheTimeout: 5000
});
mockFactory = new MockProviderFactory();
});
afterEach(() => {
factory.dispose();
mockFactory.disposeAll();
vi.clearAllMocks();
});
describe('Multi-Provider Management', () => {
it('should manage multiple providers simultaneously', async () => {
// Setup mock providers
const openaiMock = mockFactory.createProvider('openai');
const anthropicMock = mockFactory.createProvider('anthropic');
const ollamaMock = mockFactory.createProvider('ollama');
(OpenAIService as any).mockImplementation(() => openaiMock);
(AnthropicService as any).mockImplementation(() => anthropicMock);
(OllamaService as any).mockImplementation(() => ollamaMock);
// Create providers
const openai = await factory.createProvider(ProviderType.OPENAI);
const anthropic = await factory.createProvider(ProviderType.ANTHROPIC);
const ollama = await factory.createProvider(ProviderType.OLLAMA);
// Test all are available
expect(openai.isAvailable()).toBe(true);
expect(anthropic.isAvailable()).toBe(true);
expect(ollama.isAvailable()).toBe(true);
// Test statistics
const stats = factory.getStatistics();
expect(stats.cachedProviders).toBe(3);
});
it('should handle provider-specific configurations', async () => {
const customConfig = {
baseUrl: 'https://custom.api.endpoint',
timeout: 30000
};
const mock = mockFactory.createProvider('openai');
(OpenAIService as any).mockImplementation(() => mock);
await factory.createProvider(ProviderType.OPENAI, customConfig);
await factory.createProvider(ProviderType.OPENAI); // No config object, so cached separately
// Should create two separate instances
const stats = factory.getStatistics();
expect(stats.cachedProviders).toBe(2);
});
});
describe('Fallback Scenarios', () => {
it('should fallback through provider chain on failures', async () => {
const failingFactory = new ProviderFactory({
enableHealthChecks: false,
enableFallback: true,
fallbackProviders: [ProviderType.ANTHROPIC, ProviderType.OLLAMA],
enableCaching: false
});
// OpenAI fails
(OpenAIService as any).mockImplementation(() => {
throw new Error('OpenAI unavailable');
});
// Anthropic fails
(AnthropicService as any).mockImplementation(() => {
throw new Error('Anthropic unavailable');
});
// Ollama succeeds
const ollamaMock = mockFactory.createProvider('ollama');
(OllamaService as any).mockImplementation(() => ollamaMock);
const provider = await failingFactory.createProvider(ProviderType.OPENAI);
expect(provider).toBeDefined();
expect(provider.isAvailable()).toBe(true);
expect(OllamaService).toHaveBeenCalled();
failingFactory.dispose();
});
it('should handle complete fallback failure', async () => {
const failingFactory = new ProviderFactory({
enableHealthChecks: false,
enableFallback: true,
fallbackProviders: [ProviderType.ANTHROPIC],
enableCaching: false
});
// All providers fail
(OpenAIService as any).mockImplementation(() => {
throw new Error('OpenAI unavailable');
});
(AnthropicService as any).mockImplementation(() => {
throw new Error('Anthropic unavailable');
});
await expect(failingFactory.createProvider(ProviderType.OPENAI))
.rejects.toThrow('OpenAI unavailable');
failingFactory.dispose();
});
});
describe('Health Monitoring', () => {
it('should perform health checks across all providers', async () => {
// Setup healthy providers
const openaiMock = createMockProvider('success');
const anthropicMock = createMockProvider('success');
const ollamaMock = createMockProvider('success');
(OpenAIService as any).mockImplementation(() => openaiMock);
(AnthropicService as any).mockImplementation(() => anthropicMock);
(OllamaService as any).mockImplementation(() => ollamaMock);
// Perform health checks
const openaiHealth = await factory.checkProviderHealth(ProviderType.OPENAI);
const anthropicHealth = await factory.checkProviderHealth(ProviderType.ANTHROPIC);
const ollamaHealth = await factory.checkProviderHealth(ProviderType.OLLAMA);
expect(openaiHealth.healthy).toBe(true);
expect(anthropicHealth.healthy).toBe(true);
expect(ollamaHealth.healthy).toBe(true);
// Check all statuses
const allStatuses = factory.getAllHealthStatuses();
expect(allStatuses.size).toBe(3);
});
it('should detect unhealthy providers', async () => {
const errorMock = createMockProvider('error');
(OpenAIService as any).mockImplementation(() => errorMock);
const health = await factory.checkProviderHealth(ProviderType.OPENAI);
expect(health.healthy).toBe(false);
expect(health.error).toBeDefined();
});
it('should measure provider latency', async () => {
const slowMock = createMockProvider('slow');
slowMock.setResponseDelay(100);
(OpenAIService as any).mockImplementation(() => slowMock);
const health = await factory.checkProviderHealth(ProviderType.OPENAI);
expect(health.latency).toBeGreaterThan(100);
});
});
describe('Streaming Integration', () => {
it('should handle streaming across providers', async () => {
const mock = mockFactory.createProvider('openai');
(OpenAIService as any).mockImplementation(() => mock);
const provider = await factory.createProvider(ProviderType.OPENAI);
const messages = [{ role: 'user' as const, content: 'Hello' }];
const chunks: string[] = [];
const response = await provider.generateChatCompletion(messages, {
stream: true,
streamCallback: async (chunk, isDone) => {
if (!isDone) {
chunks.push(chunk);
}
}
});
expect(chunks.length).toBeGreaterThan(0);
expect(response.text).toBe(chunks.join(''));
});
it('should unify streaming formats', async () => {
const aggregator = new StreamAggregator();
// Test OpenAI format
const openaiHandler = createStreamHandler({
provider: 'openai',
onChunk: (chunk) => aggregator.addChunk(chunk)
});
await openaiHandler.processChunk({
choices: [{
delta: { content: 'Hello from OpenAI' }
}]
});
// Test Anthropic format
aggregator.reset();
const anthropicHandler = createStreamHandler({
provider: 'anthropic',
onChunk: (chunk) => aggregator.addChunk(chunk)
});
await anthropicHandler.processChunk(
'event: content_block_delta\ndata: {"delta":{"type":"text_delta","text":"Hello from Anthropic"}}'
);
// Test Ollama format
aggregator.reset();
const ollamaHandler = createStreamHandler({
provider: 'ollama',
onChunk: (chunk) => aggregator.addChunk(chunk)
});
await ollamaHandler.processChunk({
message: { content: 'Hello from Ollama' },
done: false
});
// Each handler normalizes to the same unified chunk shape; verify the last one processed
const response = aggregator.getResponse();
expect(response.text).toContain('Hello from Ollama');
});
});
describe('Performance and Caching', () => {
it('should cache providers efficiently', async () => {
const mock = mockFactory.createProvider('openai');
(OpenAIService as any).mockImplementation(() => mock);
const startTime = Date.now();
// First call - creates provider
await factory.createProvider(ProviderType.OPENAI);
const firstCallTime = Date.now() - startTime;
// Second call - uses cache
const cachedStartTime = Date.now();
await factory.createProvider(ProviderType.OPENAI);
const cachedCallTime = Date.now() - cachedStartTime;
// Cached call should be much faster
expect(cachedCallTime).toBeLessThan(firstCallTime);
expect(OpenAIService).toHaveBeenCalledTimes(1);
});
it('should track usage statistics', async () => {
const mock = mockFactory.createProvider('openai');
(OpenAIService as any).mockImplementation(() => mock);
// Create and use provider multiple times
for (let i = 0; i < 5; i++) {
await factory.createProvider(ProviderType.OPENAI);
}
const stats = factory.getStatistics();
expect(stats.totalUsage).toBe(5);
expect(stats.providerUsage['openai']).toBe(5);
});
it('should clean up expired cache automatically', async () => {
const shortCacheFactory = new ProviderFactory({
enableHealthChecks: false,
cacheTimeout: 100
});
const mock = mockFactory.createProvider('openai');
(OpenAIService as any).mockImplementation(() => mock);
await shortCacheFactory.createProvider(ProviderType.OPENAI);
let stats = shortCacheFactory.getStatistics();
expect(stats.cachedProviders).toBe(1);
// Wait for cache to expire
await new Promise(resolve => setTimeout(resolve, 150));
shortCacheFactory.cleanupExpiredCache();
stats = shortCacheFactory.getStatistics();
expect(stats.cachedProviders).toBe(0);
shortCacheFactory.dispose();
});
});
describe('Error Recovery', () => {
it('should recover from transient errors', async () => {
const flakyMock = createMockProvider('flaky');
(OpenAIService as any).mockImplementation(() => flakyMock);
const provider = await factory.createProvider(ProviderType.OPENAI);
let successCount = 0;
let errorCount = 0;
// Try multiple requests
for (let i = 0; i < 10; i++) {
try {
await provider.generateChatCompletion([
{ role: 'user', content: 'Test' }
]);
successCount++;
} catch (error) {
errorCount++;
}
}
// Should have some successes and some failures
expect(successCount).toBeGreaterThan(0);
expect(errorCount).toBeGreaterThan(0);
});
it('should handle provider disposal gracefully', async () => {
const mock = mockFactory.createProvider('openai');
mock.dispose = vi.fn();
(OpenAIService as any).mockImplementation(() => mock);
await factory.createProvider(ProviderType.OPENAI);
factory.clearCache();
expect(mock.dispose).toHaveBeenCalled();
});
});
describe('Concurrent Operations', () => {
it('should handle concurrent provider creation', async () => {
const mock = mockFactory.createProvider('openai');
(OpenAIService as any).mockImplementation(() => mock);
// Create multiple providers concurrently
const promises = Array(10).fill(null).map(() =>
factory.createProvider(ProviderType.OPENAI)
);
const providers = await Promise.all(promises);
// All should get the same cached instance
const firstProvider = providers[0];
expect(providers.every(p => p === firstProvider)).toBe(true);
// Constructor should only be called once
expect(OpenAIService).toHaveBeenCalledTimes(1);
});
it('should handle concurrent health checks', async () => {
const openaiMock = createMockProvider('success');
const anthropicMock = createMockProvider('success');
const ollamaMock = createMockProvider('success');
(OpenAIService as any).mockImplementation(() => openaiMock);
(AnthropicService as any).mockImplementation(() => anthropicMock);
(OllamaService as any).mockImplementation(() => ollamaMock);
// Perform health checks concurrently
const healthChecks = await Promise.all([
factory.checkProviderHealth(ProviderType.OPENAI),
factory.checkProviderHealth(ProviderType.ANTHROPIC),
factory.checkProviderHealth(ProviderType.OLLAMA)
]);
expect(healthChecks).toHaveLength(3);
expect(healthChecks.every(h => h.healthy)).toBe(true);
});
});
});

View File

@@ -0,0 +1,577 @@
/**
* Unified Stream Handler Tests
*
* Test suite for the unified streaming interface
*/
import { describe, it, expect, beforeEach, vi } from 'vitest';
import {
UnifiedStreamChunk,
StreamHandlerConfig,
OpenAIStreamHandler,
AnthropicStreamHandler,
OllamaStreamHandler,
createStreamHandler,
StreamAggregator,
unifiedStream
} from '../unified_stream_handler.js';
import type { ChatResponse } from '../../ai_interface.js';
vi.mock('../../log.js', () => ({
default: {
info: vi.fn(),
warn: vi.fn(),
error: vi.fn()
}
}));
describe('OpenAIStreamHandler', () => {
let handler: OpenAIStreamHandler;
let chunks: UnifiedStreamChunk[];
let config: StreamHandlerConfig;
beforeEach(() => {
chunks = [];
config = {
provider: 'openai',
onChunk: (chunk) => { chunks.push(chunk); },
onError: vi.fn(),
onComplete: vi.fn()
};
handler = new OpenAIStreamHandler(config);
});
describe('Content Streaming', () => {
it('should process content chunks', async () => {
const chunk = {
choices: [{
delta: { content: 'Hello' },
index: 0
}],
model: 'gpt-4'
};
await handler.processChunk(JSON.stringify(chunk));
expect(chunks).toHaveLength(1);
expect(chunks[0]).toEqual({
type: 'content',
content: 'Hello',
metadata: {
provider: 'openai',
model: 'gpt-4'
}
});
});
it('should handle multiple content chunks', async () => {
const chunk1 = {
choices: [{
delta: { content: 'Hello' }
}]
};
const chunk2 = {
choices: [{
delta: { content: ' World' }
}]
};
await handler.processChunk(JSON.stringify(chunk1));
await handler.processChunk(JSON.stringify(chunk2));
expect(chunks).toHaveLength(2);
expect(chunks[0].content).toBe('Hello');
expect(chunks[1].content).toBe(' World');
});
it('should handle SSE format', async () => {
const sseChunk = 'data: {"choices":[{"delta":{"content":"Test"}}]}';
await handler.processChunk(sseChunk);
expect(chunks).toHaveLength(1);
expect(chunks[0].content).toBe('Test');
});
it('should handle [DONE] marker', async () => {
await handler.processChunk('data: [DONE]');
expect(chunks).toHaveLength(1);
expect(chunks[0].type).toBe('done');
});
});
describe('Tool Calls', () => {
it('should process tool call chunks', async () => {
const chunk = {
choices: [{
delta: {
tool_calls: [{
index: 0,
id: 'call_123',
function: {
name: 'get_weather',
arguments: '{"location":'
}
}]
}
}]
};
await handler.processChunk(JSON.stringify(chunk));
expect(chunks).toHaveLength(1);
expect(chunks[0]).toEqual({
type: 'tool_call',
toolCall: {
id: 'call_123',
name: 'get_weather',
arguments: '{"location":'
},
metadata: {
provider: 'openai'
}
});
});
it('should accumulate tool call arguments', async () => {
const chunk1 = {
choices: [{
delta: {
tool_calls: [{
index: 0,
id: 'call_123',
function: {
name: 'get_weather',
arguments: '{"location":'
}
}]
}
}]
};
const chunk2 = {
choices: [{
delta: {
tool_calls: [{
index: 0,
function: {
arguments: '"New York"}'
}
}]
}
}]
};
await handler.processChunk(JSON.stringify(chunk1));
await handler.processChunk(JSON.stringify(chunk2));
expect(chunks).toHaveLength(2);
expect(chunks[1].toolCall?.arguments).toBe('{"location":"New York"}');
});
});
describe('Completion', () => {
it('should handle finish reason', async () => {
const chunk = {
choices: [{
delta: { content: 'Done' },
finish_reason: 'stop'
}],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15
}
};
await handler.processChunk(JSON.stringify(chunk));
const response = await handler.complete();
expect(response.text).toBe('Done');
// finishReason is not directly on ChatResponse anymore
expect(response.usage).toEqual({
promptTokens: 10,
completionTokens: 5,
totalTokens: 15
});
});
it('should call onComplete callback', async () => {
await handler.processChunk('data: [DONE]');
const response = await handler.complete();
expect(config.onComplete).toHaveBeenCalledWith(response);
});
});
describe('Error Handling', () => {
it('should handle parse errors', async () => {
await handler.processChunk('invalid json');
expect(config.onError).toHaveBeenCalled();
expect(chunks.find(c => c.type === 'error')).toBeDefined();
});
it('should handle timeout', async () => {
const timeoutConfig = { ...config, timeout: 100 };
const timeoutHandler = new OpenAIStreamHandler(timeoutConfig);
await new Promise(resolve => setTimeout(resolve, 150));
expect(config.onError).toHaveBeenCalledWith(
expect.objectContaining({
message: expect.stringContaining('timeout')
})
);
});
});
});
describe('AnthropicStreamHandler', () => {
let handler: AnthropicStreamHandler;
let chunks: UnifiedStreamChunk[];
let config: StreamHandlerConfig;
beforeEach(() => {
chunks = [];
config = {
provider: 'anthropic',
onChunk: (chunk) => { chunks.push(chunk); },
onError: vi.fn(),
onComplete: vi.fn()
};
handler = new AnthropicStreamHandler(config);
});
describe('Content Streaming', () => {
it('should process text delta events', async () => {
const event = 'event: content_block_delta\ndata: {"delta":{"type":"text_delta","text":"Hello"},"model":"claude-3"}';
await handler.processChunk(event);
expect(chunks).toHaveLength(1);
expect(chunks[0]).toEqual({
type: 'content',
content: 'Hello',
metadata: {
provider: 'anthropic',
model: 'claude-3'
}
});
});
it('should handle message start event', async () => {
const event = 'event: message_start\ndata: {"message":{"id":"msg_123"}}';
await handler.processChunk(event);
// Message start doesn't produce chunks
expect(chunks).toHaveLength(0);
});
it('should handle message stop event', async () => {
const event = 'event: message_stop\ndata: {}';
await handler.processChunk(event);
expect(chunks).toHaveLength(1);
expect(chunks[0].type).toBe('done');
});
});
describe('Usage Tracking', () => {
it('should track token usage', async () => {
const event = 'event: message_delta\ndata: {"usage":{"input_tokens":10,"output_tokens":5}}';
await handler.processChunk(event);
const response = await handler.complete();
expect(response.usage).toEqual({
promptTokens: 10,
completionTokens: 5,
totalTokens: 15
});
});
it('should handle stop reason', async () => {
const event = 'event: message_delta\ndata: {"delta":{"stop_reason":"end_turn"}}';
await handler.processChunk(event);
const response = await handler.complete();
// finishReason is not directly on ChatResponse anymore
});
});
describe('Error Handling', () => {
it('should handle error events', async () => {
const event = 'event: error\ndata: {"error":{"message":"API Error"}}';
await handler.processChunk(event);
expect(config.onError).toHaveBeenCalledWith(
expect.objectContaining({
message: 'API Error'
})
);
});
});
});
describe('OllamaStreamHandler', () => {
let handler: OllamaStreamHandler;
let chunks: UnifiedStreamChunk[];
let config: StreamHandlerConfig;
beforeEach(() => {
chunks = [];
config = {
provider: 'ollama',
onChunk: (chunk) => { chunks.push(chunk); },
onError: vi.fn(),
onComplete: vi.fn()
};
handler = new OllamaStreamHandler(config);
});
describe('Content Streaming', () => {
it('should process content chunks', async () => {
const chunk = {
message: { content: 'Hello' },
model: 'llama2',
done: false
};
await handler.processChunk(chunk);
expect(chunks).toHaveLength(1);
expect(chunks[0]).toEqual({
type: 'content',
content: 'Hello',
metadata: {
provider: 'ollama',
model: 'llama2'
}
});
});
it('should handle completion', async () => {
const chunk = {
message: { content: 'Final' },
done: true,
prompt_eval_count: 10,
eval_count: 5
};
await handler.processChunk(chunk);
expect(chunks).toHaveLength(2);
expect(chunks[0].type).toBe('content');
expect(chunks[1].type).toBe('done');
expect(chunks[1].metadata?.usage).toEqual({
promptTokens: 10,
completionTokens: 5,
totalTokens: 15
});
});
});
describe('Tool Calls', () => {
it('should process tool calls', async () => {
const chunk = {
message: {
tool_calls: [{
id: 'tool_1',
function: {
name: 'search',
arguments: { query: 'test' }
}
}]
},
done: false
};
await handler.processChunk(chunk);
expect(chunks).toHaveLength(1);
expect(chunks[0]).toEqual({
type: 'tool_call',
toolCall: {
id: 'tool_1',
name: 'search',
arguments: '{"query":"test"}'
},
metadata: {
provider: 'ollama'
}
});
});
});
});
describe('createStreamHandler', () => {
it('should create OpenAI handler', () => {
const handler = createStreamHandler({
provider: 'openai',
onChunk: vi.fn()
});
expect(handler).toBeInstanceOf(OpenAIStreamHandler);
});
it('should create Anthropic handler', () => {
const handler = createStreamHandler({
provider: 'anthropic',
onChunk: vi.fn()
});
expect(handler).toBeInstanceOf(AnthropicStreamHandler);
});
it('should create Ollama handler', () => {
const handler = createStreamHandler({
provider: 'ollama',
onChunk: vi.fn()
});
expect(handler).toBeInstanceOf(OllamaStreamHandler);
});
it('should throw for unsupported provider', () => {
expect(() => createStreamHandler({
provider: 'unsupported' as any,
onChunk: vi.fn()
})).toThrow('Unsupported provider: unsupported');
});
});
describe('StreamAggregator', () => {
let aggregator: StreamAggregator;
beforeEach(() => {
aggregator = new StreamAggregator();
});
it('should aggregate content chunks', () => {
aggregator.addChunk({
type: 'content',
content: 'Hello'
});
aggregator.addChunk({
type: 'content',
content: ' World'
});
const response = aggregator.getResponse();
expect(response.text).toBe('Hello World');
});
it('should aggregate tool calls', () => {
aggregator.addChunk({
type: 'tool_call',
toolCall: {
id: '1',
name: 'search',
arguments: '{}'
}
});
const response = aggregator.getResponse();
expect(response.tool_calls).toHaveLength(1);
expect(response.tool_calls?.[0]).toEqual({
id: '1',
name: 'search',
arguments: '{}'
});
});
it('should aggregate metadata', () => {
aggregator.addChunk({
type: 'done',
metadata: {
provider: 'openai',
finishReason: 'stop',
usage: {
promptTokens: 10,
completionTokens: 5,
totalTokens: 15
}
}
});
const response = aggregator.getResponse();
// finishReason is not directly on ChatResponse anymore
expect(response.usage).toEqual({
promptTokens: 10,
completionTokens: 5,
totalTokens: 15
});
});
it('should return all chunks', () => {
const chunk1: UnifiedStreamChunk = { type: 'content', content: 'Test' };
const chunk2: UnifiedStreamChunk = { type: 'done' };
aggregator.addChunk(chunk1);
aggregator.addChunk(chunk2);
const chunks = aggregator.getChunks();
expect(chunks).toHaveLength(2);
expect(chunks[0]).toEqual(chunk1);
expect(chunks[1]).toEqual(chunk2);
});
it('should reset state', () => {
aggregator.addChunk({ type: 'content', content: 'Test' });
aggregator.reset();
const response = aggregator.getResponse();
expect(response.text).toBe('');
expect(aggregator.getChunks()).toHaveLength(0);
});
});
describe('unifiedStream', () => {
it('should convert async iterable to unified stream', async () => {
async function* mockStream() {
yield JSON.stringify({
choices: [{
delta: { content: 'Hello' }
}]
});
yield JSON.stringify({
choices: [{
delta: { content: ' World' }
}]
});
yield 'data: [DONE]';
}
const chunks: UnifiedStreamChunk[] = [];
for await (const chunk of unifiedStream(mockStream(), 'openai')) {
chunks.push(chunk);
}
expect(chunks.length).toBeGreaterThan(0);
expect(chunks.find(c => c.type === 'content')).toBeDefined();
expect(chunks.find(c => c.type === 'done')).toBeDefined();
});
it('should handle errors in stream', async () => {
async function* errorStream() {
yield 'invalid json that will cause error';
}
const chunks: UnifiedStreamChunk[] = [];
for await (const chunk of unifiedStream(errorStream(), 'openai')) {
chunks.push(chunk);
}
expect(chunks.find(c => c.type === 'error')).toBeDefined();
});
});

View File

@@ -0,0 +1,457 @@
/**
* Circuit Breaker Pattern Implementation for LLM Providers
*
* Implements a circuit breaker to prevent hammering failing providers.
* States:
* - CLOSED: Normal operation, requests pass through
* - OPEN: Provider is failing, requests are rejected immediately
* - HALF_OPEN: Testing if provider has recovered
*/
import log from '../../log.js';
/**
* Circuit breaker states
*/
export enum CircuitState {
CLOSED = 'CLOSED',
OPEN = 'OPEN',
HALF_OPEN = 'HALF_OPEN'
}
/**
* Circuit breaker configuration
*/
export interface CircuitBreakerConfig {
/** Number of failures before opening circuit */
failureThreshold: number;
/** Time window for counting failures (ms) */
failureWindow: number;
/** Cooldown period before attempting half-open (ms) */
cooldownPeriod: number;
/** Number of successes in half-open to close circuit */
successThreshold: number;
/** Request timeout for half-open state (ms) */
halfOpenTimeout: number;
/** Whether to log state transitions */
enableLogging: boolean;
}
/**
* Circuit breaker statistics
*/
export interface CircuitBreakerStats {
state: CircuitState;
failures: number;
successes: number;
lastFailureTime?: Date;
lastSuccessTime?: Date;
lastStateChange: Date;
totalRequests: number;
rejectedRequests: number;
stateHistory: Array<{
state: CircuitState;
timestamp: Date;
reason: string;
}>;
}
/**
* Error type for circuit breaker rejections
*/
export class CircuitOpenError extends Error {
constructor(public readonly providerName: string, public readonly nextRetryTime: Date) {
super(`Circuit breaker is OPEN for provider ${providerName}. Will retry after ${nextRetryTime.toISOString()}`);
this.name = 'CircuitOpenError';
}
}
/**
* Circuit Breaker implementation
*/
export class CircuitBreaker {
private state: CircuitState = CircuitState.CLOSED;
private failures: number = 0;
private successes: number = 0;
private failureTimestamps: Date[] = [];
private lastStateChangeTime: Date = new Date();
private cooldownTimer?: NodeJS.Timeout;
private stats: CircuitBreakerStats;
private readonly config: CircuitBreakerConfig;
constructor(
private readonly name: string,
config?: Partial<CircuitBreakerConfig>
) {
this.config = {
failureThreshold: config?.failureThreshold ?? 5,
failureWindow: config?.failureWindow ?? 60000, // 1 minute
cooldownPeriod: config?.cooldownPeriod ?? 30000, // 30 seconds
successThreshold: config?.successThreshold ?? 2,
halfOpenTimeout: config?.halfOpenTimeout ?? 5000, // 5 seconds
enableLogging: config?.enableLogging ?? true
};
this.stats = {
state: this.state,
failures: 0,
successes: 0,
lastStateChange: this.lastStateChangeTime,
totalRequests: 0,
rejectedRequests: 0,
stateHistory: []
};
}
/**
* Execute a function with circuit breaker protection
*/
public async execute<T>(
fn: () => Promise<T>,
timeout?: number
): Promise<T> {
this.stats.totalRequests++;
// Check if circuit is open
if (this.state === CircuitState.OPEN) {
this.stats.rejectedRequests++;
const nextRetryTime = new Date(this.lastStateChangeTime.getTime() + this.config.cooldownPeriod);
throw new CircuitOpenError(this.name, nextRetryTime);
}
// Apply timeout for half-open state
const executionTimeout = this.state === CircuitState.HALF_OPEN
? this.config.halfOpenTimeout
: timeout;
try {
const result = await this.executeWithTimeout(fn, executionTimeout);
this.onSuccess();
return result;
} catch (error) {
this.onFailure(error);
throw error;
}
}
/**
* Execute function with timeout
*/
private async executeWithTimeout<T>(
fn: () => Promise<T>,
timeout?: number
): Promise<T> {
if (!timeout) {
return fn();
}
return Promise.race([
fn(),
new Promise<T>((_, reject) =>
setTimeout(() => reject(new Error(`Operation timed out after ${timeout}ms`)), timeout)
)
]);
}
/**
* Handle successful execution
*/
private onSuccess(): void {
this.successes++;
this.stats.successes++;
this.stats.lastSuccessTime = new Date();
switch (this.state) {
case CircuitState.HALF_OPEN:
if (this.successes >= this.config.successThreshold) {
this.transitionTo(CircuitState.CLOSED, 'Success threshold reached');
this.reset();
}
break;
case CircuitState.CLOSED:
// Clear old failure timestamps
this.cleanupFailureTimestamps();
break;
}
}
/**
* Handle failed execution
*/
private onFailure(error: any): void {
const now = new Date();
this.failures++;
this.stats.failures++;
this.stats.lastFailureTime = now;
this.failureTimestamps.push(now);
switch (this.state) {
case CircuitState.HALF_OPEN:
// Immediately open on failure in half-open state
this.transitionTo(CircuitState.OPEN, `Failure in HALF_OPEN state: ${error.message}`);
this.scheduleCooldown();
break;
case CircuitState.CLOSED:
// Check if we've exceeded failure threshold
this.cleanupFailureTimestamps();
if (this.failureTimestamps.length >= this.config.failureThreshold) {
this.transitionTo(CircuitState.OPEN, `Failure threshold exceeded: ${this.failureTimestamps.length} failures within ${this.config.failureWindow}ms`);
this.scheduleCooldown();
}
break;
}
}
/**
* Clean up old failure timestamps outside the window
*/
private cleanupFailureTimestamps(): void {
const now = Date.now();
const windowStart = now - this.config.failureWindow;
this.failureTimestamps = this.failureTimestamps.filter(
timestamp => timestamp.getTime() > windowStart
);
}
/**
* Transition to a new state
*/
private transitionTo(newState: CircuitState, reason: string): void {
const oldState = this.state;
this.state = newState;
this.lastStateChangeTime = new Date();
this.stats.state = newState;
this.stats.lastStateChange = this.lastStateChangeTime;
// Add to state history
this.stats.stateHistory.push({
state: newState,
timestamp: this.lastStateChangeTime,
reason
});
// Keep only last 100 state transitions
if (this.stats.stateHistory.length > 100) {
this.stats.stateHistory = this.stats.stateHistory.slice(-100);
}
if (this.config.enableLogging) {
log.info(`[CircuitBreaker:${this.name}] State transition: ${oldState} -> ${newState}. Reason: ${reason}`);
}
}
/**
* Schedule cooldown period
*/
private scheduleCooldown(): void {
if (this.cooldownTimer) {
clearTimeout(this.cooldownTimer);
}
this.cooldownTimer = setTimeout(() => {
if (this.state === CircuitState.OPEN) {
this.transitionTo(CircuitState.HALF_OPEN, 'Cooldown period expired');
this.successes = 0; // Reset success counter for half-open state
}
}, this.config.cooldownPeriod);
}
/**
* Reset counters
*/
private reset(): void {
this.failures = 0;
this.successes = 0;
this.failureTimestamps = [];
if (this.cooldownTimer) {
clearTimeout(this.cooldownTimer);
this.cooldownTimer = undefined;
}
}
/**
* Get current state
*/
public getState(): CircuitState {
return this.state;
}
/**
* Get statistics
*/
public getStats(): CircuitBreakerStats {
return { ...this.stats };
}
/**
* Force open the circuit (for testing or manual intervention)
*/
public forceOpen(reason: string = 'Manual intervention'): void {
this.transitionTo(CircuitState.OPEN, reason);
this.scheduleCooldown();
}
/**
* Force close the circuit (for testing or manual intervention)
*/
public forceClose(reason: string = 'Manual intervention'): void {
this.transitionTo(CircuitState.CLOSED, reason);
this.reset();
}
/**
* Check if circuit allows requests
*/
public isAvailable(): boolean {
return this.state !== CircuitState.OPEN;
}
/**
* Get time until next retry (if circuit is open)
*/
public getNextRetryTime(): Date | null {
if (this.state !== CircuitState.OPEN) {
return null;
}
return new Date(this.lastStateChangeTime.getTime() + this.config.cooldownPeriod);
}
/**
* Cleanup resources
*/
public dispose(): void {
if (this.cooldownTimer) {
clearTimeout(this.cooldownTimer);
this.cooldownTimer = undefined;
}
}
}
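/*
 * Usage sketch (callProvider is a placeholder for any async provider call):
 *
 *   const breaker = new CircuitBreaker('openai', { failureThreshold: 3 });
 *   try {
 *     const result = await breaker.execute(() => callProvider());
 *   } catch (err) {
 *     if (err instanceof CircuitOpenError) {
 *       // circuit is OPEN; back off until err.nextRetryTime
 *     } else {
 *       // underlying provider error; counted toward the failure threshold
 *     }
 *   }
 */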
/**
* Circuit Breaker Manager for managing multiple circuit breakers
*/
export class CircuitBreakerManager {
private static instance: CircuitBreakerManager | null = null;
private breakers: Map<string, CircuitBreaker> = new Map();
private defaultConfig: Partial<CircuitBreakerConfig>;
constructor(defaultConfig?: Partial<CircuitBreakerConfig>) {
this.defaultConfig = defaultConfig || {};
}
/**
* Get singleton instance
*/
public static getInstance(defaultConfig?: Partial<CircuitBreakerConfig>): CircuitBreakerManager {
if (!CircuitBreakerManager.instance) {
CircuitBreakerManager.instance = new CircuitBreakerManager(defaultConfig);
}
return CircuitBreakerManager.instance;
}
/**
* Get or create a circuit breaker for a provider
*/
public getBreaker(
providerName: string,
config?: Partial<CircuitBreakerConfig>
): CircuitBreaker {
if (!this.breakers.has(providerName)) {
const breakerConfig = { ...this.defaultConfig, ...config };
this.breakers.set(providerName, new CircuitBreaker(providerName, breakerConfig));
}
return this.breakers.get(providerName)!;
}
/**
* Execute with circuit breaker protection
*/
public async execute<T>(
providerName: string,
fn: () => Promise<T>,
config?: Partial<CircuitBreakerConfig>
): Promise<T> {
const breaker = this.getBreaker(providerName, config);
return breaker.execute(fn);
}
/**
* Get all circuit breaker stats
*/
public getAllStats(): Map<string, CircuitBreakerStats> {
const stats = new Map<string, CircuitBreakerStats>();
for (const [name, breaker] of this.breakers) {
stats.set(name, breaker.getStats());
}
return stats;
}
/**
* Get health summary
*/
public getHealthSummary(): {
total: number;
closed: number;
open: number;
halfOpen: number;
availableProviders: string[];
unavailableProviders: string[];
} {
const summary = {
total: this.breakers.size,
closed: 0,
open: 0,
halfOpen: 0,
availableProviders: [] as string[],
unavailableProviders: [] as string[]
};
for (const [name, breaker] of this.breakers) {
const state = breaker.getState();
switch (state) {
case CircuitState.CLOSED:
summary.closed++;
summary.availableProviders.push(name);
break;
case CircuitState.OPEN:
summary.open++;
summary.unavailableProviders.push(name);
break;
case CircuitState.HALF_OPEN:
summary.halfOpen++;
summary.availableProviders.push(name);
break;
}
}
return summary;
}
/**
* Reset all circuit breakers
*/
public resetAll(): void {
for (const breaker of this.breakers.values()) {
breaker.forceClose('Global reset');
}
}
/**
* Dispose all circuit breakers
*/
public dispose(): void {
for (const breaker of this.breakers.values()) {
breaker.dispose();
}
this.breakers.clear();
CircuitBreakerManager.instance = null;
}
}
// Export singleton getter
export const getCircuitBreakerManager = (config?: Partial<CircuitBreakerConfig>): CircuitBreakerManager => {
return CircuitBreakerManager.getInstance(config);
};
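// A minimal usage sketch (`callOpenAI` and `messages` are hypothetical, and the
// config value is illustrative, not a default taken from this module):
//
//     const manager = getCircuitBreakerManager({ cooldownPeriod: 30_000 });
//     try {
//         const reply = await manager.execute('openai', () => callOpenAI(messages));
//     } catch (err) {
//         if (err instanceof CircuitOpenError) {
//             // The circuit is open; back off and retry after the cooldown expires
//         }
//     }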

View File

@@ -0,0 +1,770 @@
/**
* Enhanced Provider Configuration
*
* Provides advanced configuration options for AI service providers,
* including custom endpoints, model detection, and optimization settings.
*/
import log from '../../log.js';
import options from '../../options.js';
import type { ModelMetadata } from './provider_options.js';
/**
* Provider configuration with enhanced settings
*/
export interface EnhancedProviderConfig {
// Basic settings
provider: 'openai' | 'anthropic' | 'ollama' | 'custom';
apiKey?: string;
baseUrl?: string;
// Advanced settings
customHeaders?: Record<string, string>;
timeout?: number;
maxRetries?: number;
retryDelay?: number;
proxy?: string;
// Model settings
defaultModel?: string;
availableModels?: string[];
modelAliases?: Record<string, string>;
// Performance settings
maxConcurrentRequests?: number;
requestQueueSize?: number;
rateLimitPerMinute?: number;
// Feature flags
enableStreaming?: boolean;
enableTools?: boolean;
enableVision?: boolean;
enableCaching?: boolean;
// Custom endpoints
endpoints?: {
chat?: string;
completions?: string;
embeddings?: string;
models?: string;
health?: string;
};
// Optimization settings
optimization?: {
batchSize?: number;
cacheTimeout?: number;
compressionEnabled?: boolean;
connectionPoolSize?: number;
};
}
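// Example (values illustrative): a self-hosted OpenAI-compatible endpoint with
// a tightened timeout and a request cap:
//
//     const config: EnhancedProviderConfig = {
//         provider: 'custom',
//         baseUrl: 'https://llm.internal.example/v1',
//         timeout: 30000,
//         maxRetries: 2,
//         rateLimitPerMinute: 60,
//         enableStreaming: true
//     };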
/**
* Model information with detailed capabilities
*/
export interface ModelInfo {
id: string;
name: string;
provider: string;
contextWindow: number;
maxOutputTokens: number;
supportedModalities: string[];
costPerMillion?: {
input: number;
output: number;
};
capabilities: {
chat: boolean;
completion: boolean;
embedding: boolean;
functionCalling: boolean;
vision: boolean;
audio: boolean;
streaming: boolean;
};
performance?: {
averageLatency?: number;
tokensPerSecond?: number;
};
}
/**
* Provider configuration manager
*/
export class ProviderConfigurationManager {
private configs: Map<string, EnhancedProviderConfig> = new Map();
private modelRegistry: Map<string, ModelInfo> = new Map();
private modelCache: Map<string, ModelInfo[]> = new Map();
private lastModelFetch: Map<string, number> = new Map();
private readonly MODEL_CACHE_TTL = 3600000; // 1 hour
constructor() {
this.initializeDefaultConfigs();
this.initializeModelRegistry();
}
/**
* Initialize default provider configurations
*/
private initializeDefaultConfigs(): void {
// OpenAI configuration
this.configs.set('openai', {
provider: 'openai',
baseUrl: 'https://api.openai.com/v1',
timeout: 60000,
maxRetries: 3,
retryDelay: 1000,
enableStreaming: true,
enableTools: true,
enableVision: true,
enableCaching: true,
endpoints: {
chat: '/chat/completions',
completions: '/completions',
embeddings: '/embeddings',
models: '/models'
},
optimization: {
batchSize: 10,
cacheTimeout: 300000,
compressionEnabled: true,
connectionPoolSize: 10
}
});
// Anthropic configuration
this.configs.set('anthropic', {
provider: 'anthropic',
baseUrl: 'https://api.anthropic.com',
timeout: 60000,
maxRetries: 3,
retryDelay: 1000,
enableStreaming: true,
enableTools: true,
enableVision: true,
enableCaching: true,
endpoints: {
chat: '/v1/messages'
},
optimization: {
batchSize: 5,
cacheTimeout: 300000,
compressionEnabled: true,
connectionPoolSize: 5
}
});
// Ollama configuration
this.configs.set('ollama', {
provider: 'ollama',
baseUrl: 'http://localhost:11434',
timeout: 120000, // Longer timeout for local models
maxRetries: 2,
retryDelay: 500,
enableStreaming: true,
enableTools: true,
enableVision: false,
enableCaching: true,
endpoints: {
chat: '/api/chat',
models: '/api/tags'
},
optimization: {
batchSize: 1, // Local processing, no batching
cacheTimeout: 600000,
compressionEnabled: false,
connectionPoolSize: 2
}
});
}
/**
* Initialize model registry with known models
*/
private initializeModelRegistry(): void {
// OpenAI models
this.registerModel({
id: 'gpt-4-turbo-preview',
name: 'GPT-4 Turbo',
provider: 'openai',
contextWindow: 128000,
maxOutputTokens: 4096,
supportedModalities: ['text', 'image'],
costPerMillion: {
input: 10,
output: 30
},
capabilities: {
chat: true,
completion: false,
embedding: false,
functionCalling: true,
vision: true,
audio: false,
streaming: true
}
});
this.registerModel({
id: 'gpt-4o',
name: 'GPT-4 Omni',
provider: 'openai',
contextWindow: 128000,
maxOutputTokens: 4096,
supportedModalities: ['text', 'image', 'audio'],
costPerMillion: {
input: 5,
output: 15
},
capabilities: {
chat: true,
completion: false,
embedding: false,
functionCalling: true,
vision: true,
audio: true,
streaming: true
}
});
this.registerModel({
id: 'gpt-3.5-turbo',
name: 'GPT-3.5 Turbo',
provider: 'openai',
contextWindow: 16385,
maxOutputTokens: 4096,
supportedModalities: ['text'],
costPerMillion: {
input: 0.5,
output: 1.5
},
capabilities: {
chat: true,
completion: false,
embedding: false,
functionCalling: true,
vision: false,
audio: false,
streaming: true
}
});
// Anthropic models
this.registerModel({
id: 'claude-3-opus-20240229',
name: 'Claude 3 Opus',
provider: 'anthropic',
contextWindow: 200000,
maxOutputTokens: 4096,
supportedModalities: ['text', 'image'],
costPerMillion: {
input: 15,
output: 75
},
capabilities: {
chat: true,
completion: false,
embedding: false,
functionCalling: true,
vision: true,
audio: false,
streaming: true
}
});
this.registerModel({
id: 'claude-3-sonnet-20240229',
name: 'Claude 3 Sonnet',
provider: 'anthropic',
contextWindow: 200000,
maxOutputTokens: 4096,
supportedModalities: ['text', 'image'],
costPerMillion: {
input: 3,
output: 15
},
capabilities: {
chat: true,
completion: false,
embedding: false,
functionCalling: true,
vision: true,
audio: false,
streaming: true
}
});
this.registerModel({
id: 'claude-3-haiku-20240307',
name: 'Claude 3 Haiku',
provider: 'anthropic',
contextWindow: 200000,
maxOutputTokens: 4096,
supportedModalities: ['text', 'image'],
costPerMillion: {
input: 0.25,
output: 1.25
},
capabilities: {
chat: true,
completion: false,
embedding: false,
functionCalling: true,
vision: true,
audio: false,
streaming: true
}
});
        // Common Ollama models (defaults; actual specs depend on the locally installed models)
this.registerModel({
id: 'llama3',
name: 'Llama 3',
provider: 'ollama',
contextWindow: 8192,
maxOutputTokens: 2048,
supportedModalities: ['text'],
capabilities: {
chat: true,
completion: true,
embedding: false,
functionCalling: true,
vision: false,
audio: false,
streaming: true
}
});
this.registerModel({
id: 'mixtral',
name: 'Mixtral',
provider: 'ollama',
contextWindow: 32768,
maxOutputTokens: 4096,
supportedModalities: ['text'],
capabilities: {
chat: true,
completion: true,
embedding: false,
functionCalling: true,
vision: false,
audio: false,
streaming: true
}
});
}
/**
* Register a model in the registry
*/
public registerModel(model: ModelInfo): void {
this.modelRegistry.set(model.id, model);
// Also register by provider
const providerModels = this.modelCache.get(model.provider) || [];
if (!providerModels.some(m => m.id === model.id)) {
providerModels.push(model);
this.modelCache.set(model.provider, providerModels);
}
}
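    // Example: registering a custom model at runtime (values illustrative):
    //
    //     providerConfigManager.registerModel({
    //         id: 'my-finetune-v1',
    //         name: 'My Finetune v1',
    //         provider: 'custom',
    //         contextWindow: 16384,
    //         maxOutputTokens: 2048,
    //         supportedModalities: ['text'],
    //         capabilities: {
    //             chat: true, completion: false, embedding: false,
    //             functionCalling: false, vision: false, audio: false,
    //             streaming: true
    //         }
    //     });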
/**
* Get configuration for a provider
*/
public getProviderConfig(provider: string): EnhancedProviderConfig | undefined {
// First check if we have a stored config
let config = this.configs.get(provider);
if (!config) {
// Try to build config from options
config = this.buildConfigFromOptions(provider);
if (config) {
this.configs.set(provider, config);
}
}
return config;
}
/**
* Build configuration from Trilium options
*/
private buildConfigFromOptions(provider: string): EnhancedProviderConfig | undefined {
switch (provider) {
case 'openai': {
const apiKey = options.getOption('openaiApiKey');
const baseUrl = options.getOption('openaiBaseUrl');
const defaultModel = options.getOption('openaiDefaultModel');
if (!apiKey && !baseUrl) return undefined;
return {
...this.configs.get('openai')!,
apiKey,
baseUrl: baseUrl || this.configs.get('openai')!.baseUrl,
defaultModel
};
}
case 'anthropic': {
const apiKey = options.getOption('anthropicApiKey');
const baseUrl = options.getOption('anthropicBaseUrl');
const defaultModel = options.getOption('anthropicDefaultModel');
if (!apiKey) return undefined;
return {
...this.configs.get('anthropic')!,
apiKey,
baseUrl: baseUrl || this.configs.get('anthropic')!.baseUrl,
defaultModel
};
}
case 'ollama': {
const baseUrl = options.getOption('ollamaBaseUrl');
const defaultModel = options.getOption('ollamaDefaultModel');
if (!baseUrl) return undefined;
return {
...this.configs.get('ollama')!,
baseUrl,
defaultModel
};
}
default:
return undefined;
}
}
/**
* Update provider configuration
*/
public updateProviderConfig(provider: string, config: Partial<EnhancedProviderConfig>): void {
const existing = this.getProviderConfig(provider) || { provider: provider as any };
this.configs.set(provider, { ...existing, ...config });
}
/**
* Get available models for a provider
*/
public async getAvailableModels(provider: string): Promise<ModelInfo[]> {
// Check cache first
const cached = this.modelCache.get(provider);
const lastFetch = this.lastModelFetch.get(provider) || 0;
if (cached && Date.now() - lastFetch < this.MODEL_CACHE_TTL) {
return cached;
}
// Try to fetch fresh model list
try {
const models = await this.fetchProviderModels(provider);
this.modelCache.set(provider, models);
this.lastModelFetch.set(provider, Date.now());
return models;
} catch (error) {
log.info(`Failed to fetch models for ${provider}: ${error}`);
// Return cached if available, otherwise registry models
return cached || Array.from(this.modelRegistry.values())
.filter(m => m.provider === provider);
}
}
/**
* Fetch models from provider API
*/
private async fetchProviderModels(provider: string): Promise<ModelInfo[]> {
const config = this.getProviderConfig(provider);
if (!config) {
throw new Error(`No configuration for provider: ${provider}`);
}
switch (provider) {
case 'openai':
return this.fetchOpenAIModels(config);
case 'ollama':
return this.fetchOllamaModels(config);
case 'anthropic':
// Anthropic doesn't have a models endpoint, use registry
return Array.from(this.modelRegistry.values())
.filter(m => m.provider === 'anthropic');
default:
return [];
}
}
/**
* Fetch OpenAI models
*/
private async fetchOpenAIModels(config: EnhancedProviderConfig): Promise<ModelInfo[]> {
try {
const url = `${config.baseUrl}${config.endpoints?.models || '/models'}`;
const response = await fetch(url, {
headers: {
'Authorization': `Bearer ${config.apiKey}`,
...config.customHeaders
}
});
if (!response.ok) {
throw new Error(`HTTP ${response.status}: ${response.statusText}`);
}
const data = await response.json();
return data.data.map((model: any) => {
// Check if we have detailed info in registry
const registered = this.modelRegistry.get(model.id);
if (registered) {
return registered;
}
// Create basic model info
return {
id: model.id,
name: model.id,
provider: 'openai',
contextWindow: 4096, // Default
maxOutputTokens: 4096,
supportedModalities: ['text'],
capabilities: {
chat: model.id.includes('gpt'),
completion: !model.id.includes('gpt'),
embedding: model.id.includes('embedding'),
functionCalling: model.id.includes('gpt'),
vision: model.id.includes('vision'),
audio: model.id.includes('whisper'),
streaming: true
}
} as ModelInfo;
});
} catch (error) {
log.error(`Failed to fetch OpenAI models: ${error}`);
throw error;
}
}
/**
* Fetch Ollama models
*/
private async fetchOllamaModels(config: EnhancedProviderConfig): Promise<ModelInfo[]> {
try {
const url = `${config.baseUrl}${config.endpoints?.models || '/api/tags'}`;
const response = await fetch(url);
if (!response.ok) {
throw new Error(`HTTP ${response.status}: ${response.statusText}`);
}
const data = await response.json();
return data.models.map((model: any) => {
// Check if we have detailed info in registry
const registered = this.modelRegistry.get(model.name);
if (registered) {
return registered;
}
// Create basic model info from Ollama data
return {
id: model.name,
name: model.name,
provider: 'ollama',
                contextWindow: 4096, // Default; /api/tags reports parameter size (e.g. "8.0B"), not a context window
maxOutputTokens: 2048,
supportedModalities: ['text'],
capabilities: {
chat: true,
completion: true,
embedding: model.name.includes('embed'),
functionCalling: true,
vision: model.name.includes('vision') || model.name.includes('llava'),
audio: false,
streaming: true
},
performance: {
tokensPerSecond: model.details?.tokens_per_second
}
} as ModelInfo;
});
} catch (error) {
log.error(`Failed to fetch Ollama models: ${error}`);
throw error;
}
}
/**
* Get model information
*/
public getModelInfo(modelId: string): ModelInfo | undefined {
return this.modelRegistry.get(modelId);
}
/**
* Detect best model for a use case
*/
public detectBestModel(
provider: string,
requirements: {
minContextWindow?: number;
needsVision?: boolean;
needsTools?: boolean;
maxCostPerMillion?: number;
preferFast?: boolean;
}
): ModelInfo | undefined {
const models = Array.from(this.modelRegistry.values())
.filter(m => m.provider === provider);
// Filter by requirements
let candidates = models.filter(m => {
if (requirements.minContextWindow && m.contextWindow < requirements.minContextWindow) {
return false;
}
if (requirements.needsVision && !m.capabilities.vision) {
return false;
}
if (requirements.needsTools && !m.capabilities.functionCalling) {
return false;
}
if (requirements.maxCostPerMillion && m.costPerMillion) {
const avgCost = (m.costPerMillion.input + m.costPerMillion.output) / 2;
if (avgCost > requirements.maxCostPerMillion) {
return false;
}
}
return true;
});
if (candidates.length === 0) {
return undefined;
}
// Sort by preference
if (requirements.preferFast) {
            // Prefer cheaper models, using cost as a rough proxy for size and speed
candidates.sort((a, b) => {
const costA = a.costPerMillion ? (a.costPerMillion.input + a.costPerMillion.output) / 2 : 1000;
const costB = b.costPerMillion ? (b.costPerMillion.input + b.costPerMillion.output) / 2 : 1000;
return costA - costB;
});
} else {
// Prefer more capable models
candidates.sort((a, b) => b.contextWindow - a.contextWindow);
}
return candidates[0];
}
/**
* Validate provider configuration
*/
public validateConfig(config: EnhancedProviderConfig): {
valid: boolean;
errors: string[];
warnings: string[];
} {
const errors: string[] = [];
const warnings: string[] = [];
// Check required fields
if (!config.provider) {
errors.push('Provider type is required');
}
// Provider-specific validation
switch (config.provider) {
case 'openai':
case 'anthropic':
if (!config.apiKey && !config.baseUrl?.includes('localhost')) {
errors.push('API key is required for cloud providers');
}
break;
case 'ollama':
if (!config.baseUrl) {
errors.push('Base URL is required for Ollama');
}
break;
}
// Validate URLs
if (config.baseUrl) {
try {
new URL(config.baseUrl);
} catch {
errors.push('Invalid base URL format');
}
}
// Validate timeout
if (config.timeout && config.timeout < 1000) {
warnings.push('Timeout less than 1 second may cause issues');
}
// Validate rate limits
if (config.rateLimitPerMinute && config.rateLimitPerMinute < 1) {
errors.push('Rate limit must be at least 1 request per minute');
}
return {
valid: errors.length === 0,
errors,
warnings
};
}
/**
* Export configuration as JSON
*/
public exportConfig(provider: string): string {
const config = this.getProviderConfig(provider);
if (!config) {
throw new Error(`No configuration for provider: ${provider}`);
}
// Remove sensitive data
const exported = { ...config };
if (exported.apiKey) {
exported.apiKey = '***REDACTED***';
}
return JSON.stringify(exported, null, 2);
}
/**
* Import configuration from JSON
*/
public importConfig(provider: string, json: string): void {
try {
const config = JSON.parse(json) as EnhancedProviderConfig;
// Validate before importing
const validation = this.validateConfig(config);
if (!validation.valid) {
throw new Error(`Invalid configuration: ${validation.errors.join(', ')}`);
}
// Don't import if API key is redacted
if (config.apiKey === '***REDACTED***') {
delete config.apiKey;
}
this.updateProviderConfig(provider, config);
log.info(`Imported configuration for ${provider}`);
} catch (error) {
log.error(`Failed to import configuration: ${error}`);
throw error;
}
}
}
// Export singleton instance
export const providerConfigManager = new ProviderConfigurationManager();
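// A usage sketch (the requirement values are illustrative):
//
//     const models = await providerConfigManager.getAvailableModels('ollama');
//     const best = providerConfigManager.detectBestModel('openai', {
//         minContextWindow: 32000,
//         needsTools: true,
//         maxCostPerMillion: 20,
//         preferFast: true
//     });
//     if (best) {
//         log.info(`Selected model: ${best.id}`);
//     }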

View File

@@ -0,0 +1,888 @@
/**
* Provider Factory Pattern Implementation
*
* This module implements a factory pattern for clean provider instantiation,
* unified streaming interfaces, capability detection, and provider-specific
* feature support.
*/
import log from '../../log.js';
import type { AIService, ChatCompletionOptions } from '../ai_interface.js';
import { OpenAIService } from './openai_service.js';
import { AnthropicService } from './anthropic_service.js';
import { OllamaService } from './ollama_service.js';
import type {
OpenAIOptions,
AnthropicOptions,
OllamaOptions,
ModelMetadata
} from './provider_options.js';
import {
getOpenAIOptions,
getAnthropicOptions,
getOllamaOptions
} from './providers.js';
import {
CircuitBreakerManager,
CircuitOpenError,
type CircuitBreakerConfig
} from './circuit_breaker.js';
import {
MetricsExporter,
ExportFormat,
type ExporterConfig
} from '../metrics/metrics_exporter.js';
/**
* Provider type enumeration
*/
export enum ProviderType {
OPENAI = 'openai',
ANTHROPIC = 'anthropic',
OLLAMA = 'ollama',
CUSTOM = 'custom'
}
/**
* Provider capabilities interface
*/
export interface ProviderCapabilities {
streaming: boolean;
functionCalling: boolean;
vision: boolean;
contextWindow: number;
maxOutputTokens: number;
supportsSystemPrompt: boolean;
supportsTools: boolean;
supportedModalities: string[];
customEndpoints: boolean;
batchProcessing: boolean;
}
/**
* Provider health status
*/
export interface ProviderHealthStatus {
provider: ProviderType;
healthy: boolean;
lastChecked: Date;
latency?: number;
error?: string;
version?: string;
}
/**
* Provider configuration
*/
export interface ProviderConfig {
type: ProviderType;
apiKey?: string;
baseUrl?: string;
timeout?: number;
maxRetries?: number;
retryDelay?: number;
customHeaders?: Record<string, string>;
proxy?: string;
}
/**
* Factory creation options
*/
export interface ProviderFactoryOptions {
enableHealthChecks?: boolean;
healthCheckInterval?: number;
enableFallback?: boolean;
fallbackProviders?: ProviderType[];
enableCaching?: boolean;
cacheTimeout?: number;
enableMetrics?: boolean;
enableCircuitBreaker?: boolean;
circuitBreakerConfig?: Partial<CircuitBreakerConfig>;
metricsExporterConfig?: Partial<ExporterConfig>;
}
/**
* Provider instance with metadata
*/
interface ProviderInstance {
service: AIService;
type: ProviderType;
capabilities: ProviderCapabilities;
config: ProviderConfig;
createdAt: Date;
lastUsed: Date;
usageCount: number;
healthStatus?: ProviderHealthStatus;
}
/**
* Provider Factory Class
*
* Manages creation, caching, and lifecycle of AI service providers
*/
export class ProviderFactory {
private static instance: ProviderFactory | null = null;
private providers: Map<string, ProviderInstance> = new Map();
private capabilities: Map<ProviderType, ProviderCapabilities> = new Map();
private healthStatuses: Map<ProviderType, ProviderHealthStatus> = new Map();
private options: ProviderFactoryOptions;
private healthCheckTimer?: NodeJS.Timeout;
private disposed: boolean = false;
private circuitBreakerManager?: CircuitBreakerManager;
private metricsExporter?: MetricsExporter;
constructor(options: ProviderFactoryOptions = {}) {
this.options = {
enableHealthChecks: options.enableHealthChecks ?? true,
healthCheckInterval: options.healthCheckInterval ?? 60000, // 1 minute
enableFallback: options.enableFallback ?? true,
fallbackProviders: options.fallbackProviders ?? [ProviderType.OLLAMA],
enableCaching: options.enableCaching ?? true,
cacheTimeout: options.cacheTimeout ?? 300000, // 5 minutes
enableMetrics: options.enableMetrics ?? true,
enableCircuitBreaker: options.enableCircuitBreaker ?? true,
circuitBreakerConfig: options.circuitBreakerConfig,
metricsExporterConfig: options.metricsExporterConfig
};
this.initializeCapabilities();
// Initialize circuit breaker if enabled
if (this.options.enableCircuitBreaker) {
this.circuitBreakerManager = CircuitBreakerManager.getInstance(
this.options.circuitBreakerConfig
);
}
// Initialize metrics exporter if enabled
if (this.options.enableMetrics) {
this.metricsExporter = MetricsExporter.getInstance({
enabled: true,
...this.options.metricsExporterConfig
});
}
if (this.options.enableHealthChecks) {
this.startHealthChecks();
}
}
/**
* Get singleton instance
*/
public static getInstance(options?: ProviderFactoryOptions): ProviderFactory {
if (!ProviderFactory.instance) {
ProviderFactory.instance = new ProviderFactory(options);
}
return ProviderFactory.instance;
}
/**
* Initialize provider capabilities registry
*/
private initializeCapabilities(): void {
// OpenAI capabilities
this.capabilities.set(ProviderType.OPENAI, {
streaming: true,
functionCalling: true,
vision: true,
contextWindow: 128000, // GPT-4 Turbo
maxOutputTokens: 4096,
supportsSystemPrompt: true,
supportsTools: true,
supportedModalities: ['text', 'image'],
customEndpoints: true,
batchProcessing: true
});
// Anthropic capabilities
this.capabilities.set(ProviderType.ANTHROPIC, {
streaming: true,
functionCalling: true,
vision: true,
contextWindow: 200000, // Claude 3
maxOutputTokens: 4096,
supportsSystemPrompt: true,
supportsTools: true,
supportedModalities: ['text', 'image'],
customEndpoints: false,
batchProcessing: false
});
// Ollama capabilities (default, can be overridden per model)
this.capabilities.set(ProviderType.OLLAMA, {
streaming: true,
functionCalling: true,
vision: false,
contextWindow: 8192, // Default, varies by model
maxOutputTokens: 2048,
supportsSystemPrompt: true,
supportsTools: true,
supportedModalities: ['text'],
customEndpoints: true,
batchProcessing: false
});
}
/**
* Create a provider instance
*/
public async createProvider(
type: ProviderType,
config?: Partial<ProviderConfig>,
options?: ChatCompletionOptions
): Promise<AIService> {
if (this.disposed) {
throw new Error('ProviderFactory has been disposed');
}
const cacheKey = this.getCacheKey(type, config);
// Check cache if enabled
if (this.options.enableCaching) {
const cached = this.providers.get(cacheKey);
if (cached && this.isInstanceValid(cached)) {
cached.lastUsed = new Date();
cached.usageCount++;
if (this.options.enableMetrics) {
log.info(`[ProviderFactory] Using cached ${type} provider (usage: ${cached.usageCount})`);
}
return cached.service;
}
}
// Create new provider instance
const service = await this.instantiateProvider(type, config, options);
if (!service) {
throw new Error(`Failed to create provider of type: ${type}`);
}
// Get capabilities for this provider
const capabilities = await this.detectCapabilities(type, service);
// Create provider instance
const instance: ProviderInstance = {
service,
type,
capabilities,
config: { type, ...config },
createdAt: new Date(),
lastUsed: new Date(),
usageCount: 1
};
// Cache the instance
if (this.options.enableCaching) {
this.providers.set(cacheKey, instance);
// Schedule cache cleanup
setTimeout(() => {
this.cleanupCache(cacheKey);
}, this.options.cacheTimeout);
}
if (this.options.enableMetrics) {
log.info(`[ProviderFactory] Created new ${type} provider`);
}
return service;
}
/**
* Instantiate a specific provider
*/
private async instantiateProvider(
type: ProviderType,
config?: Partial<ProviderConfig>,
options?: ChatCompletionOptions
): Promise<AIService | null> {
const startTime = Date.now();
try {
// Use circuit breaker if enabled
if (this.circuitBreakerManager) {
const breaker = this.circuitBreakerManager.getBreaker(type);
// Check if circuit is open
if (!breaker.isAvailable()) {
const nextRetry = breaker.getNextRetryTime();
log.info(`[ProviderFactory] Circuit breaker OPEN for ${type}. Next retry: ${nextRetry?.toISOString()}`);
// Record metric
if (this.metricsExporter) {
this.metricsExporter.getCollector().recordError(type, 'Circuit breaker open');
}
// Try fallback immediately
if (this.options.enableFallback && this.options.fallbackProviders?.length) {
return this.tryFallbackProvider(options);
}
throw new CircuitOpenError(type, nextRetry!);
}
// Execute with circuit breaker protection
return await breaker.execute(async () => {
const service = await this.createProviderInternal(type, config, options);
// Record success metric
if (this.metricsExporter && service) {
const latency = Date.now() - startTime;
this.metricsExporter.getCollector().recordLatency(type, latency);
this.metricsExporter.getCollector().recordRequest(type, true);
}
return service;
});
} else {
// No circuit breaker, create directly
const service = await this.createProviderInternal(type, config, options);
// Record metrics
if (this.metricsExporter && service) {
const latency = Date.now() - startTime;
this.metricsExporter.getCollector().recordLatency(type, latency);
this.metricsExporter.getCollector().recordRequest(type, true);
}
return service;
}
} catch (error: any) {
log.error(`[ProviderFactory] Error creating ${type} provider: ${error.message}`);
// Record failure metric
if (this.metricsExporter) {
this.metricsExporter.getCollector().recordRequest(type, false);
this.metricsExporter.getCollector().recordError(type, error.message);
}
// Try fallback if enabled and not a circuit breaker error
if (!(error instanceof CircuitOpenError) &&
this.options.enableFallback &&
this.options.fallbackProviders?.length) {
return this.tryFallbackProvider(options);
}
throw error;
}
}
/**
* Internal provider creation logic
*/
private async createProviderInternal(
type: ProviderType,
config?: Partial<ProviderConfig>,
options?: ChatCompletionOptions
): Promise<AIService | null> {
switch (type) {
case ProviderType.OPENAI:
return this.createOpenAIProvider(config, options);
case ProviderType.ANTHROPIC:
return this.createAnthropicProvider(config, options);
case ProviderType.OLLAMA:
return await this.createOllamaProvider(config, options);
case ProviderType.CUSTOM:
return this.createCustomProvider(config, options);
default:
log.error(`[ProviderFactory] Unknown provider type: ${type}`);
return null;
}
}
/**
* Create OpenAI provider
*/
private createOpenAIProvider(
config?: Partial<ProviderConfig>,
options?: ChatCompletionOptions
): AIService {
const service = new OpenAIService();
if (!service.isAvailable()) {
throw new Error('OpenAI service is not available');
}
return service;
}
/**
* Create Anthropic provider
*/
private createAnthropicProvider(
config?: Partial<ProviderConfig>,
options?: ChatCompletionOptions
): AIService {
const service = new AnthropicService();
if (!service.isAvailable()) {
throw new Error('Anthropic service is not available');
}
return service;
}
/**
* Create Ollama provider
*/
private async createOllamaProvider(
config?: Partial<ProviderConfig>,
options?: ChatCompletionOptions
): Promise<AIService> {
const service = new OllamaService();
if (!service.isAvailable()) {
throw new Error('Ollama service is not available');
}
// Ollama might need model pulling or other async setup
// This is handled internally by the service
return service;
}
/**
* Create custom provider (for future extensibility)
*/
private createCustomProvider(
config?: Partial<ProviderConfig>,
options?: ChatCompletionOptions
): AIService {
throw new Error('Custom providers not yet implemented');
}
/**
* Try fallback providers
*/
private async tryFallbackProvider(options?: ChatCompletionOptions): Promise<AIService | null> {
if (!this.options.fallbackProviders) {
return null;
}
for (const fallbackType of this.options.fallbackProviders) {
try {
log.info(`[ProviderFactory] Trying fallback provider: ${fallbackType}`);
const service = await this.instantiateProvider(fallbackType, undefined, options);
if (service && service.isAvailable()) {
log.info(`[ProviderFactory] Fallback to ${fallbackType} successful`);
return service;
}
} catch (error) {
log.error(`[ProviderFactory] Fallback to ${fallbackType} failed: ${error}`);
}
}
return null;
}
/**
* Detect capabilities for a provider
*/
private async detectCapabilities(
type: ProviderType,
service: AIService
): Promise<ProviderCapabilities> {
        // Start with a copy of the default capabilities so detection below
        // does not mutate the shared registry entry
        let capabilities = { ...(this.capabilities.get(type) || this.getDefaultCapabilities()) };
// Try to detect actual capabilities from the service
try {
// Check for streaming support
if ('supportsStreaming' in service && typeof service.supportsStreaming === 'function') {
capabilities.streaming = (service as any).supportsStreaming();
}
// Check for tool support
if ('supportsTools' in service && typeof service.supportsTools === 'function') {
capabilities.supportsTools = (service as any).supportsTools();
}
// For Ollama, try to get model-specific capabilities
if (type === ProviderType.OLLAMA) {
capabilities = await this.detectOllamaCapabilities(service, capabilities);
}
} catch (error) {
log.info(`[ProviderFactory] Could not detect capabilities for ${type}: ${error}`);
}
return capabilities;
}
/**
* Detect Ollama-specific capabilities
*/
private async detectOllamaCapabilities(
service: AIService,
defaultCaps: ProviderCapabilities
): Promise<ProviderCapabilities> {
// This would query the Ollama API for model info
// For now, return defaults
return defaultCaps;
}
/**
* Get default capabilities
*/
private getDefaultCapabilities(): ProviderCapabilities {
return {
streaming: true,
functionCalling: false,
vision: false,
contextWindow: 4096,
maxOutputTokens: 1024,
supportsSystemPrompt: true,
supportsTools: false,
supportedModalities: ['text'],
customEndpoints: false,
batchProcessing: false
};
}
/**
* Perform health check on a provider
*/
public async checkProviderHealth(type: ProviderType): Promise<ProviderHealthStatus> {
const startTime = Date.now();
try {
const service = await this.createProvider(type);
// Try a simple completion to test the service
const testMessages = [{ role: 'user' as const, content: 'Hi' }];
            await service.generateChatCompletion(testMessages, {
maxTokens: 1,
temperature: 0
});
const latency = Date.now() - startTime;
const status: ProviderHealthStatus = {
provider: type,
healthy: true,
lastChecked: new Date(),
latency
};
this.healthStatuses.set(type, status);
return status;
} catch (error: any) {
const status: ProviderHealthStatus = {
provider: type,
healthy: false,
lastChecked: new Date(),
error: error.message || 'Unknown error'
};
this.healthStatuses.set(type, status);
return status;
}
}
/**
* Start periodic health checks
*/
private startHealthChecks(): void {
if (this.healthCheckTimer) {
return;
}
this.healthCheckTimer = setInterval(async () => {
if (this.disposed) {
return;
}
for (const type of this.capabilities.keys()) {
try {
await this.checkProviderHealth(type);
} catch (error) {
log.error(`[ProviderFactory] Health check failed for ${type}: ${error}`);
}
}
}, this.options.healthCheckInterval);
// Perform initial health check
this.performInitialHealthCheck();
}
/**
* Perform initial health check
*/
private async performInitialHealthCheck(): Promise<void> {
for (const type of this.capabilities.keys()) {
try {
await this.checkProviderHealth(type);
} catch (error) {
log.error(`[ProviderFactory] Initial health check failed for ${type}: ${error}`);
}
}
}
/**
* Get health status for a provider
*/
public getHealthStatus(type: ProviderType): ProviderHealthStatus | undefined {
return this.healthStatuses.get(type);
}
/**
* Get all health statuses
*/
public getAllHealthStatuses(): Map<ProviderType, ProviderHealthStatus> {
return new Map(this.healthStatuses);
}
/**
* Get capabilities for a provider
*/
public getCapabilities(type: ProviderType): ProviderCapabilities | undefined {
return this.capabilities.get(type);
}
/**
* Register custom provider capabilities
*/
public registerCapabilities(type: ProviderType, capabilities: ProviderCapabilities): void {
this.capabilities.set(type, capabilities);
}
/**
* Get cache key for provider
*/
private getCacheKey(type: ProviderType, config?: Partial<ProviderConfig>): string {
const baseKey = type;
if (config?.baseUrl) {
return `${baseKey}:${config.baseUrl}`;
}
return baseKey;
}
/**
* Check if cached instance is still valid
*/
private isInstanceValid(instance: ProviderInstance): boolean {
if (!this.options.cacheTimeout) {
return true;
}
const age = Date.now() - instance.createdAt.getTime();
return age < this.options.cacheTimeout;
}
/**
* Cleanup specific cache entry
*/
private cleanupCache(key: string): void {
const instance = this.providers.get(key);
if (instance && !this.isInstanceValid(instance)) {
this.disposeProvider(instance);
this.providers.delete(key);
if (this.options.enableMetrics) {
log.info(`[ProviderFactory] Cleaned up cached provider: ${key}`);
}
}
}
/**
* Cleanup all expired cache entries
*/
public cleanupExpiredCache(): void {
const keys = Array.from(this.providers.keys());
for (const key of keys) {
this.cleanupCache(key);
}
}
/**
* Dispose a provider instance
*/
private disposeProvider(instance: ProviderInstance): void {
try {
if ('dispose' in instance.service && typeof (instance.service as any).dispose === 'function') {
(instance.service as any).dispose();
}
} catch (error) {
log.error(`[ProviderFactory] Error disposing provider: ${error}`);
}
}
/**
* Get provider statistics
*/
public getStatistics(): {
cachedProviders: number;
totalUsage: number;
providerUsage: Record<string, number>;
healthyProviders: number;
unhealthyProviders: number;
} {
const stats = {
cachedProviders: this.providers.size,
totalUsage: 0,
providerUsage: {} as Record<string, number>,
healthyProviders: 0,
unhealthyProviders: 0
};
// Calculate usage statistics
for (const [key, instance] of this.providers) {
stats.totalUsage += instance.usageCount;
const type = instance.type.toString();
stats.providerUsage[type] = (stats.providerUsage[type] || 0) + instance.usageCount;
}
// Calculate health statistics
for (const status of this.healthStatuses.values()) {
if (status.healthy) {
stats.healthyProviders++;
} else {
stats.unhealthyProviders++;
}
}
return stats;
}
/**
* Clear all cached providers
*/
public clearCache(): void {
for (const instance of this.providers.values()) {
this.disposeProvider(instance);
}
this.providers.clear();
if (this.options.enableMetrics) {
log.info('[ProviderFactory] Cleared all cached providers');
}
}
/**
* Get circuit breaker status
*/
public getCircuitBreakerStatus(): any {
if (!this.circuitBreakerManager) {
return null;
}
return {
summary: this.circuitBreakerManager.getHealthSummary(),
details: Array.from(this.circuitBreakerManager.getAllStats().entries()).map(([name, stats]) => ({
provider: name,
...stats
}))
};
}
/**
* Get metrics summary
*/
public getMetricsSummary(): any {
if (!this.metricsExporter) {
return null;
}
const collector = this.metricsExporter.getCollector();
return {
providers: Array.from(collector.getProviderMetricsMap().values()),
system: collector.getSystemMetrics()
};
}
/**
* Export metrics in specified format
*/
public exportMetrics(format?: 'prometheus' | 'statsd' | 'opentelemetry' | 'json'): any {
if (!this.metricsExporter) {
return null;
}
const exportFormat = format ? {
prometheus: ExportFormat.PROMETHEUS,
statsd: ExportFormat.STATSD,
opentelemetry: ExportFormat.OPENTELEMETRY,
json: ExportFormat.JSON
}[format] : undefined;
return this.metricsExporter.export(exportFormat);
}
/**
* Reset circuit breaker for a specific provider
*/
public resetCircuitBreaker(provider: ProviderType): void {
if (!this.circuitBreakerManager) {
return;
}
const breaker = this.circuitBreakerManager.getBreaker(provider);
breaker.forceClose('Manual reset');
log.info(`[ProviderFactory] Circuit breaker reset for ${provider}`);
}
/**
* Configure metrics export
*/
public configureMetricsExport(config: Partial<ExporterConfig>): void {
if (!this.metricsExporter) {
return;
}
this.metricsExporter.updateConfig(config);
log.info('[ProviderFactory] Metrics export configuration updated');
}
/**
* Dispose the factory and cleanup resources
*/
public dispose(): void {
if (this.disposed) {
return;
}
this.disposed = true;
// Stop health checks
if (this.healthCheckTimer) {
clearInterval(this.healthCheckTimer);
this.healthCheckTimer = undefined;
}
// Dispose circuit breaker manager
if (this.circuitBreakerManager) {
this.circuitBreakerManager.dispose();
}
// Dispose metrics exporter
if (this.metricsExporter) {
this.metricsExporter.dispose();
}
// Clear cache
this.clearCache();
// Clear singleton instance
ProviderFactory.instance = null;
log.info('[ProviderFactory] Disposed successfully');
}
}
// Export singleton instance getter
export const getProviderFactory = (options?: ProviderFactoryOptions): ProviderFactory => {
return ProviderFactory.getInstance(options);
};
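// A minimal usage sketch (the message payload and option values are illustrative):
//
//     const factory = getProviderFactory({ enableCircuitBreaker: true });
//     const service = await factory.createProvider(ProviderType.OPENAI);
//     const response = await service.generateChatCompletion(
//         [{ role: 'user', content: 'Hello' }],
//         { maxTokens: 64 }
//     );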

View File

@@ -0,0 +1,662 @@
/**
* Unified Stream Handler
*
* Provides a consistent streaming interface across all providers,
* handling provider-specific stream formats and normalizing them
* into a unified format.
*/
import log from '../../log.js';
import type { ChatResponse } from '../ai_interface.js';
/**
* Unified stream chunk format
*/
export interface UnifiedStreamChunk {
type: 'content' | 'tool_call' | 'error' | 'done';
content?: string;
toolCall?: {
id: string;
name: string;
arguments: string;
};
error?: string;
metadata?: {
provider: string;
model?: string;
finishReason?: string;
usage?: {
promptTokens?: number;
completionTokens?: number;
totalTokens?: number;
};
};
}
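// Example content chunk (field values illustrative):
//
//     const chunk: UnifiedStreamChunk = {
//         type: 'content',
//         content: 'Hello',
//         metadata: { provider: 'openai', model: 'gpt-4o' }
//     };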
/**
* Stream handler configuration
*/
export interface StreamHandlerConfig {
provider: 'openai' | 'anthropic' | 'ollama';
onChunk: (chunk: UnifiedStreamChunk) => void | Promise<void>;
onError?: (error: Error) => void;
onComplete?: (response: ChatResponse) => void;
bufferSize?: number;
timeout?: number;
}
/**
* Abstract base class for provider-specific stream handlers
*/
export abstract class BaseStreamHandler {
protected config: StreamHandlerConfig;
protected buffer: string = '';
protected response: Partial<ChatResponse> = {};
protected finishReason?: string;
protected isComplete: boolean = false;
protected timeoutTimer?: NodeJS.Timeout;
constructor(config: StreamHandlerConfig) {
this.config = config;
if (config.timeout) {
this.setTimeoutTimer(config.timeout);
}
}
/**
* Process a stream chunk from the provider
*/
public abstract processChunk(chunk: any): Promise<void>;
/**
* Complete the stream processing
*/
public abstract complete(): Promise<ChatResponse>;
/**
* Handle stream error
*/
public handleError(error: Error): void {
this.clearTimeoutTimer();
if (this.config.onError) {
this.config.onError(error);
} else {
log.error(`[StreamHandler] Stream error: ${error.message}`);
}
// Send error chunk
this.sendChunk({
type: 'error',
error: error.message,
metadata: {
provider: this.config.provider
}
});
}
/**
* Send a unified chunk to the consumer
*/
protected async sendChunk(chunk: UnifiedStreamChunk): Promise<void> {
try {
await this.config.onChunk(chunk);
} catch (error) {
log.error(`[StreamHandler] Error in chunk handler: ${error}`);
}
}
/**
* Set timeout timer
*/
protected setTimeoutTimer(timeout: number): void {
this.timeoutTimer = setTimeout(() => {
this.handleError(new Error(`Stream timeout after ${timeout}ms`));
}, timeout);
}
/**
* Clear timeout timer
*/
protected clearTimeoutTimer(): void {
if (this.timeoutTimer) {
clearTimeout(this.timeoutTimer);
this.timeoutTimer = undefined;
}
}
/**
* Reset timeout timer
*/
protected resetTimeoutTimer(): void {
if (this.config.timeout) {
this.clearTimeoutTimer();
this.setTimeoutTimer(this.config.timeout);
}
}
}
/**
* OpenAI stream handler
*/
export class OpenAIStreamHandler extends BaseStreamHandler {
    private toolCalls: Map<number, any> = new Map();
    private model?: string;
public async processChunk(chunk: any): Promise<void> {
this.resetTimeoutTimer();
try {
// Parse SSE format if needed
const data = this.parseSSEChunk(chunk);
if (!data || data === '[DONE]') {
await this.sendComplete();
return;
}
            const parsed = typeof data === 'string' ? JSON.parse(data) : data;
            if (parsed.model) {
                this.model = parsed.model;
            }
            const choice = parsed.choices?.[0];
if (!choice) {
return;
}
// Handle content delta
if (choice.delta?.content) {
this.buffer += choice.delta.content;
await this.sendChunk({
type: 'content',
content: choice.delta.content,
metadata: {
provider: 'openai',
model: parsed.model
}
});
}
// Handle tool calls
if (choice.delta?.tool_calls) {
for (const toolCall of choice.delta.tool_calls) {
await this.processToolCall(toolCall);
}
}
// Check if stream is done
if (choice.finish_reason) {
this.finishReason = choice.finish_reason;
if (parsed.usage) {
this.response.usage = {
promptTokens: parsed.usage.prompt_tokens,
completionTokens: parsed.usage.completion_tokens,
totalTokens: parsed.usage.total_tokens
};
}
}
} catch (error) {
log.error(`[OpenAIStreamHandler] Error processing chunk: ${error}`);
this.handleError(error as Error);
}
}
private parseSSEChunk(chunk: any): string | null {
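        // Note: only the first "data:" line in a chunk is used, so callers are
        // assumed to deliver one SSE event per chunk; non-string chunks pass
        // through unchanged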
if (typeof chunk === 'string') {
const lines = chunk.split('\n');
for (const line of lines) {
if (line.startsWith('data: ')) {
return line.slice(6);
}
}
}
return chunk;
}
private async processToolCall(toolCall: any): Promise<void> {
const index = toolCall.index || 0;
if (!this.toolCalls.has(index)) {
this.toolCalls.set(index, {
id: toolCall.id || '',
type: 'function',
function: {
name: '',
arguments: ''
}
});
}
const existing = this.toolCalls.get(index)!;
if (toolCall.id) {
existing.id = toolCall.id;
}
if (toolCall.function?.name) {
existing.function.name = toolCall.function.name;
}
if (toolCall.function?.arguments) {
existing.function.arguments += toolCall.function.arguments;
}
// Send tool call chunk
await this.sendChunk({
type: 'tool_call',
toolCall: {
id: existing.id,
name: existing.function.name,
arguments: existing.function.arguments
},
metadata: {
provider: 'openai'
}
});
}
private async sendComplete(): Promise<void> {
this.isComplete = true;
this.clearTimeoutTimer();
await this.sendChunk({
type: 'done',
metadata: {
provider: 'openai',
finishReason: this.finishReason,
usage: this.response.usage
}
});
}
public async complete(): Promise<ChatResponse> {
if (!this.isComplete) {
await this.sendComplete();
}
const response: ChatResponse = {
text: this.buffer,
            model: this.model || 'openai-model',
provider: 'openai',
usage: this.response.usage
};
if (this.toolCalls.size > 0) {
response.tool_calls = Array.from(this.toolCalls.values());
}
if (this.config.onComplete) {
this.config.onComplete(response);
}
return response;
}
}
/**
* Anthropic stream handler
*/
export class AnthropicStreamHandler extends BaseStreamHandler {
    private messageId?: string;
    private stopReason?: string;
    private model?: string;
public async processChunk(chunk: any): Promise<void> {
this.resetTimeoutTimer();
try {
const event = this.parseAnthropicEvent(chunk);
if (!event) {
return;
}
switch (event.type) {
                case 'message_start':
                    this.messageId = event.message?.id;
                    this.model = event.message?.model;
                    break;
case 'content_block_start':
// Content block started
break;
case 'content_block_delta':
if (event.delta?.type === 'text_delta') {
const text = event.delta.text || '';
this.buffer += text;
await this.sendChunk({
type: 'content',
content: text,
metadata: {
provider: 'anthropic',
                            model: this.model
}
});
}
break;
case 'content_block_stop':
// Content block completed
break;
case 'message_delta':
if (event.delta?.stop_reason) {
this.stopReason = event.delta.stop_reason;
}
if (event.usage) {
this.response.usage = {
promptTokens: event.usage.input_tokens,
completionTokens: event.usage.output_tokens,
totalTokens: (event.usage.input_tokens || 0) + (event.usage.output_tokens || 0)
};
}
break;
case 'message_stop':
await this.sendComplete();
break;
case 'error':
this.handleError(new Error(event.error?.message || 'Unknown error'));
break;
}
} catch (error) {
log.error(`[AnthropicStreamHandler] Error processing chunk: ${error}`);
this.handleError(error as Error);
}
}
private parseAnthropicEvent(chunk: any): any {
if (typeof chunk === 'string') {
try {
// Parse SSE format
const lines = chunk.split('\n');
let eventType = '';
let eventData = '';
for (const line of lines) {
if (line.startsWith('event: ')) {
eventType = line.slice(7);
} else if (line.startsWith('data: ')) {
eventData = line.slice(6);
}
}
if (eventType && eventData) {
const parsed = JSON.parse(eventData);
return { ...parsed, type: eventType };
}
} catch (error) {
log.error(`[AnthropicStreamHandler] Error parsing event: ${error}`);
}
}
return chunk;
}
private async sendComplete(): Promise<void> {
this.isComplete = true;
this.clearTimeoutTimer();
await this.sendChunk({
type: 'done',
metadata: {
provider: 'anthropic',
finishReason: this.stopReason,
usage: this.response.usage
}
});
}
public async complete(): Promise<ChatResponse> {
if (!this.isComplete) {
await this.sendComplete();
}
const response: ChatResponse = {
text: this.buffer,
            model: this.model || 'anthropic-model',
provider: 'anthropic',
usage: this.response.usage
};
if (this.config.onComplete) {
this.config.onComplete(response);
}
return response;
}
}
/**
* Ollama stream handler
*/
export class OllamaStreamHandler extends BaseStreamHandler {
private model?: string;
private toolCalls: any[] = [];
public async processChunk(chunk: any): Promise<void> {
this.resetTimeoutTimer();
try {
const data = typeof chunk === 'string' ? JSON.parse(chunk) : chunk;
// Handle content
if (data.message?.content) {
const content = data.message.content;
this.buffer += content;
await this.sendChunk({
type: 'content',
content: content,
metadata: {
provider: 'ollama',
model: data.model || this.model
}
});
}
// Handle tool calls
if (data.message?.tool_calls) {
this.toolCalls = data.message.tool_calls;
for (const toolCall of this.toolCalls) {
await this.sendChunk({
type: 'tool_call',
toolCall: {
id: toolCall.id || `tool_${Date.now()}`,
name: toolCall.function?.name || '',
arguments: JSON.stringify(toolCall.function?.arguments || {})
},
metadata: {
provider: 'ollama'
}
});
}
}
// Store model info
if (data.model) {
this.model = data.model;
}
// Check if done
if (data.done) {
// Calculate token usage if available
if (data.prompt_eval_count || data.eval_count) {
this.response.usage = {
promptTokens: data.prompt_eval_count,
completionTokens: data.eval_count,
totalTokens: (data.prompt_eval_count || 0) + (data.eval_count || 0)
};
}
await this.sendComplete();
}
} catch (error) {
log.error(`[OllamaStreamHandler] Error processing chunk: ${error}`);
this.handleError(error as Error);
}
}
private async sendComplete(): Promise<void> {
this.isComplete = true;
this.clearTimeoutTimer();
await this.sendChunk({
type: 'done',
metadata: {
provider: 'ollama',
model: this.model,
usage: this.response.usage
}
});
}
public async complete(): Promise<ChatResponse> {
if (!this.isComplete) {
await this.sendComplete();
}
const response: ChatResponse = {
text: this.buffer,
model: this.model || 'ollama-model',
provider: 'ollama',
usage: this.response.usage
};
if (this.toolCalls.length > 0) {
response.tool_calls = this.toolCalls;
}
if (this.config.onComplete) {
this.config.onComplete(response);
}
return response;
}
}
/**
* Factory function to create appropriate stream handler
*/
export function createStreamHandler(config: StreamHandlerConfig): BaseStreamHandler {
switch (config.provider) {
case 'openai':
return new OpenAIStreamHandler(config);
case 'anthropic':
return new AnthropicStreamHandler(config);
case 'ollama':
return new OllamaStreamHandler(config);
default:
throw new Error(`Unsupported provider: ${config.provider}`);
}
}
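// Wiring a handler manually (`providerStream` and `render` are hypothetical;
// a sketch of one way to consume a handler, not the only one):
//
//     const handler = createStreamHandler({
//         provider: 'anthropic',
//         timeout: 30000,
//         onChunk: (c) => { if (c.type === 'content') render(c.content); }
//     });
//     for await (const raw of providerStream) {
//         await handler.processChunk(raw);
//     }
//     const response = await handler.complete();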
/**
* Utility to convert async iterable to unified stream
*/
export async function* unifiedStream(
asyncIterable: AsyncIterable<any>,
provider: 'openai' | 'anthropic' | 'ollama'
): AsyncGenerator<UnifiedStreamChunk> {
const chunks: UnifiedStreamChunk[] = [];
let handler: BaseStreamHandler | null = null;
try {
handler = createStreamHandler({
provider,
onChunk: (chunk) => { chunks.push(chunk); }
});
for await (const chunk of asyncIterable) {
await handler.processChunk(chunk);
            // Yield accumulated chunks
            while (chunks.length > 0) {
                yield chunks.shift()!;
            }
}
// Complete the stream
await handler.complete();
        // Yield any remaining chunks
        while (chunks.length > 0) {
            yield chunks.shift()!;
        }
} catch (error) {
log.error(`[unifiedStream] Error: ${error}`);
yield {
type: 'error',
error: (error as Error).message,
metadata: { provider }
};
}
}
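// A consumption sketch (`providerStream` is a hypothetical AsyncIterable of
// raw provider chunks, e.g. SSE strings from OpenAI):
//
//     for await (const chunk of unifiedStream(providerStream, 'openai')) {
//         if (chunk.type === 'content') process.stdout.write(chunk.content ?? '');
//         if (chunk.type === 'error') throw new Error(chunk.error);
//     }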
/**
* Stream aggregator for collecting stream chunks into a complete response
*/
export class StreamAggregator {
private chunks: UnifiedStreamChunk[] = [];
private content: string = '';
private toolCalls: any[] = [];
private metadata: any = {};
public addChunk(chunk: UnifiedStreamChunk): void {
this.chunks.push(chunk);
switch (chunk.type) {
case 'content':
if (chunk.content) {
this.content += chunk.content;
}
break;
case 'tool_call':
if (chunk.toolCall) {
this.toolCalls.push(chunk.toolCall);
}
break;
case 'done':
if (chunk.metadata) {
this.metadata = { ...this.metadata, ...chunk.metadata };
}
break;
}
}
public getResponse(): ChatResponse {
const response: ChatResponse = {
text: this.content,
model: this.metadata.model || 'unknown-model',
provider: this.metadata.provider || 'unknown',
usage: this.metadata.usage
};
if (this.toolCalls.length > 0) {
response.tool_calls = this.toolCalls;
}
return response;
}
public getChunks(): UnifiedStreamChunk[] {
return [...this.chunks];
}
public reset(): void {
this.chunks = [];
this.content = '';
this.toolCalls = [];
this.metadata = {};
}
}
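// Combining the two utilities above (`providerStream` as in the earlier sketches):
//
//     const aggregator = new StreamAggregator();
//     for await (const chunk of unifiedStream(providerStream, 'ollama')) {
//         aggregator.addChunk(chunk);
//     }
//     const finalResponse = aggregator.getResponse(); // complete ChatResponse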