feat(llm): add tests for streaming

Author: perf3ct
Date: 2025-06-08 20:30:33 +00:00
parent c1bcb73337
commit c6f2124e9d
7 changed files with 2586 additions and 9 deletions
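
The streaming tests drive the endpoint through a mocked ChatPipeline whose execute() invokes the streamCallback it is given. A minimal sketch of the callback contract those calls imply is shown below; the type and property names are assumptions inferred from the test invocations (e.g. callback('Hello', false, {}) and callback('', false, { thinking: '...' })), not the project's own definitions.

// Sketch of the stream callback the mocked pipeline invokes in the tests.
// Names are assumptions inferred from the test calls, not project types.
type StreamChunkExtras = {
    thinking?: string;                // intermediate "thinking" status text
    toolExecution?: {                 // details of a tool call being executed
        tool: string;
        arguments?: Record<string, unknown>;
        result?: string;
        toolCallId?: string;
        action?: string;
    };
};

type StreamCallback = (
    text: string,    // incremental content chunk ('' when only metadata is sent)
    done: boolean,   // true on the final chunk
    extras: StreamChunkExtras
) => Promise<void>;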

View File

@@ -1,8 +1,9 @@
import { Application } from "express";
-import { beforeAll, describe, expect, it, vi, beforeEach } from "vitest";
+import { beforeAll, describe, expect, it, vi, beforeEach, afterEach } from "vitest";
import supertest from "supertest";
import config from "../../services/config.js";
import { refreshAuth } from "../../services/auth.js";
+import type { WebSocket } from 'ws';
// Mock the CSRF protection middleware to allow tests to pass
vi.mock("../csrf_protection.js", () => ({
@@ -10,6 +11,64 @@ vi.mock("../csrf_protection.js", () => ({
generateToken: () => "mock-csrf-token"
}));
// Mock WebSocket service
vi.mock("../../services/ws.js", () => ({
default: {
sendMessageToAllClients: vi.fn()
}
}));
// Mock log service
vi.mock("../../services/log.js", () => ({
default: {
info: vi.fn(),
error: vi.fn(),
warn: vi.fn()
}
}));
// Mock chat storage service
const mockChatStorage = {
createChat: vi.fn(),
getChat: vi.fn(),
updateChat: vi.fn(),
getAllChats: vi.fn(),
deleteChat: vi.fn()
};
vi.mock("../../services/llm/storage/chat_storage_service.js", () => ({
default: mockChatStorage
}));
// Mock AI service manager
const mockAiServiceManager = {
getOrCreateAnyService: vi.fn()
};
vi.mock("../../services/llm/ai_service_manager.js", () => ({
default: mockAiServiceManager
}));
// Mock chat pipeline
const mockChatPipelineExecute = vi.fn();
const MockChatPipeline = vi.fn().mockImplementation(() => ({
execute: mockChatPipelineExecute
}));
vi.mock("../../services/llm/pipeline/chat_pipeline.js", () => ({
ChatPipeline: MockChatPipeline
}));
// Mock configuration helpers
const mockGetSelectedModelConfig = vi.fn();
vi.mock("../../services/llm/config/configuration_helpers.js", () => ({
getSelectedModelConfig: mockGetSelectedModelConfig
}));
// Mock options service
vi.mock("../../services/options.js", () => ({
default: {
getOptionBool: vi.fn()
}
}));
// Session-based login that properly establishes req.session.loggedIn
async function loginWithSession(app: Application) {
const response = await supertest(app)
@@ -257,7 +316,30 @@ describe("LLM API Tests", () => {
let testChatId: string;
beforeEach(async () => {
// Reset all mocks
vi.clearAllMocks();
// Import options service to access mock
const options = (await import("../../services/options.js")).default;
// Setup default mock behaviors
options.getOptionBool.mockReturnValue(true); // AI enabled
mockAiServiceManager.getOrCreateAnyService.mockResolvedValue({});
mockGetSelectedModelConfig.mockResolvedValue({
model: 'test-model',
provider: 'test-provider'
});
// Create a fresh chat for each test
const mockChat = {
id: 'streaming-test-chat',
title: 'Streaming Test Chat',
messages: [],
createdAt: new Date().toISOString()
};
mockChatStorage.createChat.mockResolvedValue(mockChat);
mockChatStorage.getChat.mockResolvedValue(mockChat);
const createResponse = await supertest(app)
.post("/api/llm/chat")
.set("Cookie", sessionCookie)
@@ -269,7 +351,19 @@ describe("LLM API Tests", () => {
testChatId = createResponse.body.id;
});
afterEach(() => {
vi.clearAllMocks();
});
it("should initiate streaming for a chat message", async () => {
// Setup streaming simulation
mockChatPipelineExecute.mockImplementation(async (input) => {
const callback = input.streamCallback;
// Simulate streaming chunks
await callback('Hello', false, {});
await callback(' world!', true, {});
});
const response = await supertest(app)
.post(`/api/llm/chat/${testChatId}/messages/stream`)
.set("Cookie", sessionCookie)
@@ -286,6 +380,31 @@ describe("LLM API Tests", () => {
success: true,
message: "Streaming initiated successfully"
});
// Import ws service to access mock
const ws = (await import("../../services/ws.js")).default;
// Verify WebSocket messages were sent
expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
type: 'llm-stream',
chatNoteId: testChatId,
thinking: undefined
});
// Verify streaming chunks were sent
expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
type: 'llm-stream',
chatNoteId: testChatId,
content: 'Hello',
done: false
});
expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
type: 'llm-stream',
chatNoteId: testChatId,
content: ' world!',
done: true
});
});
it("should handle empty content for streaming", async () => {
@@ -338,6 +457,29 @@ describe("LLM API Tests", () => {
});
it("should handle streaming with note mentions", async () => {
// Mock becca for note content retrieval
const mockBecca = {
getNote: vi.fn().mockReturnValue({
noteId: 'root',
title: 'Root Note',
getBlob: () => ({
getContent: () => 'Root note content for testing'
})
})
};
vi.mocked(await import('../../becca/becca.js')).default = mockBecca;
// Setup streaming with mention context
mockChatPipelineExecute.mockImplementation(async (input) => {
// Verify mention content is included
expect(input.query).toContain('Tell me about this note');
expect(input.query).toContain('Root note content for testing');
const callback = input.streamCallback;
await callback('The root note contains', false, {});
await callback(' important information.', true, {});
});
const response = await supertest(app)
.post(`/api/llm/chat/${testChatId}/messages/stream`)
.set("Cookie", sessionCookie)
@@ -358,6 +500,250 @@ describe("LLM API Tests", () => {
success: true,
message: "Streaming initiated successfully"
});
// Import ws service to access mock
const ws = (await import("../../services/ws.js")).default;
// Verify thinking message was sent
expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
type: 'llm-stream',
chatNoteId: testChatId,
thinking: 'Initializing streaming LLM response...'
});
});
it("should handle streaming with thinking states", async () => {
mockChatPipelineExecute.mockImplementation(async (input) => {
const callback = input.streamCallback;
// Simulate thinking states
await callback('', false, { thinking: 'Analyzing the question...' });
await callback('', false, { thinking: 'Formulating response...' });
await callback('The answer is', false, {});
await callback(' 42.', true, {});
});
const response = await supertest(app)
.post(`/api/llm/chat/${testChatId}/messages/stream`)
.set("Cookie", sessionCookie)
.send({
content: "What is the meaning of life?",
useAdvancedContext: false,
showThinking: true
});
expect(response.status).toBe(200);
// Import ws service to access mock
const ws = (await import("../../services/ws.js")).default;
// Verify thinking messages
expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
type: 'llm-stream',
chatNoteId: testChatId,
thinking: 'Analyzing the question...',
done: false
});
expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
type: 'llm-stream',
chatNoteId: testChatId,
thinking: 'Formulating response...',
done: false
});
});
it("should handle streaming with tool executions", async () => {
mockChatPipelineExecute.mockImplementation(async (input) => {
const callback = input.streamCallback;
// Simulate tool execution
await callback('Let me calculate that', false, {});
await callback('', false, {
toolExecution: {
tool: 'calculator',
arguments: { expression: '2 + 2' },
result: '4',
toolCallId: 'call_123',
action: 'execute'
}
});
await callback('The result is 4', true, {});
});
const response = await supertest(app)
.post(`/api/llm/chat/${testChatId}/messages/stream`)
.set("Cookie", sessionCookie)
.send({
content: "What is 2 + 2?",
useAdvancedContext: false,
showThinking: false
});
expect(response.status).toBe(200);
// Import ws service to access mock
const ws = (await import("../../services/ws.js")).default;
// Verify tool execution message
expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
type: 'llm-stream',
chatNoteId: testChatId,
toolExecution: {
tool: 'calculator',
args: { expression: '2 + 2' },
result: '4',
toolCallId: 'call_123',
action: 'execute',
error: undefined
},
done: false
});
});
it("should handle streaming errors gracefully", async () => {
mockChatPipelineExecute.mockRejectedValue(new Error('Pipeline error'));
const response = await supertest(app)
.post(`/api/llm/chat/${testChatId}/messages/stream`)
.set("Cookie", sessionCookie)
.send({
content: "This will fail",
useAdvancedContext: false,
showThinking: false
});
expect(response.status).toBe(200); // Still returns 200
// Import ws service to access mock
const ws = (await import("../../services/ws.js")).default;
// Verify error message was sent via WebSocket
expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
type: 'llm-stream',
chatNoteId: testChatId,
error: 'Error during streaming: Pipeline error',
done: true
});
});
it("should handle AI disabled state", async () => {
// Import options service to access mock
const options = (await import("../../services/options.js")).default;
options.getOptionBool.mockReturnValue(false); // AI disabled
const response = await supertest(app)
.post(`/api/llm/chat/${testChatId}/messages/stream`)
.set("Cookie", sessionCookie)
.send({
content: "Hello AI",
useAdvancedContext: false,
showThinking: false
});
expect(response.status).toBe(200);
// Import ws service to access mock
const ws = (await import("../../services/ws.js")).default;
// Verify error message about AI being disabled
expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
type: 'llm-stream',
chatNoteId: testChatId,
error: 'Error during streaming: AI features are disabled. Please enable them in the settings.',
done: true
});
});
it("should save chat messages after streaming completion", async () => {
const completeResponse = 'This is the complete response';
mockChatPipelineExecute.mockImplementation(async (input) => {
const callback = input.streamCallback;
await callback(completeResponse, true, {});
});
await supertest(app)
.post(`/api/llm/chat/${testChatId}/messages/stream`)
.set("Cookie", sessionCookie)
.send({
content: "Save this response",
useAdvancedContext: false,
showThinking: false
});
// Wait for async operations
await new Promise(resolve => setTimeout(resolve, 100));
// Verify chat was updated with the complete response
expect(mockChatStorage.updateChat).toHaveBeenCalledWith(
testChatId,
expect.arrayContaining([
{ role: 'assistant', content: completeResponse }
]),
'Streaming Test Chat'
);
});
it("should handle rapid consecutive streaming requests", async () => {
let callCount = 0;
mockChatPipelineExecute.mockImplementation(async (input) => {
callCount++;
const callback = input.streamCallback;
await callback(`Response ${callCount}`, true, {});
});
// Send multiple requests rapidly
const promises = Array.from({ length: 3 }, (_, i) =>
supertest(app)
.post(`/api/llm/chat/${testChatId}/messages/stream`)
.set("Cookie", sessionCookie)
.send({
content: `Request ${i + 1}`,
useAdvancedContext: false,
showThinking: false
})
);
const responses = await Promise.all(promises);
// All should succeed
responses.forEach(response => {
expect(response.status).toBe(200);
expect(response.body.success).toBe(true);
});
// Verify all were processed
expect(mockChatPipelineExecute).toHaveBeenCalledTimes(3);
});
it("should handle large streaming responses", async () => {
const largeContent = 'x'.repeat(10000); // 10KB of content
mockChatPipelineExecute.mockImplementation(async (input) => {
const callback = input.streamCallback;
// Simulate chunked delivery of large content
for (let i = 0; i < 10; i++) {
await callback(largeContent.slice(i * 1000, (i + 1) * 1000), false, {});
}
await callback('', true, {});
});
const response = await supertest(app)
.post(`/api/llm/chat/${testChatId}/messages/stream`)
.set("Cookie", sessionCookie)
.send({
content: "Generate large response",
useAdvancedContext: false,
showThinking: false
});
expect(response.status).toBe(200);
// Import ws service to access mock
const ws = (await import("../../services/ws.js")).default;
// Verify multiple chunks were sent
const streamCalls = ws.sendMessageToAllClients.mock.calls.filter(
call => call[0].type === 'llm-stream' && call[0].content
);
expect(streamCalls.length).toBeGreaterThan(5);
});
});
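
The assertions above all go through ws.sendMessageToAllClients with type: 'llm-stream'. Roughly, the message shape they expect looks like the sketch below; the interface name is illustrative, and only the fields come from the test expectations.

// Illustrative shape of the 'llm-stream' WebSocket messages asserted above.
// The interface name is an assumption; the fields mirror the assertions.
interface LlmStreamMessage {
    type: 'llm-stream';
    chatNoteId: string;
    content?: string;         // streamed text chunk
    thinking?: string;        // e.g. 'Analyzing the question...'
    toolExecution?: {         // note: forwarded as `args`, not `arguments`
        tool: string;
        args?: Record<string, unknown>;
        result?: string;
        toolCallId?: string;
        action?: string;
        error?: string;
    };
    error?: string;           // e.g. 'Error during streaming: Pipeline error'
    done?: boolean;           // true on the final message or on error
}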

View File

@@ -537,11 +537,11 @@ async function handleStreamingProcess(
}
// Get AI service
-const aiServiceManager = await import('../ai_service_manager.js');
+const aiServiceManager = await import('../../services/llm/ai_service_manager.js');
await aiServiceManager.default.getOrCreateAnyService();
// Use the chat pipeline directly for streaming
-const { ChatPipeline } = await import('../pipeline/chat_pipeline.js');
+const { ChatPipeline } = await import('../../services/llm/pipeline/chat_pipeline.js');
const pipeline = new ChatPipeline({
enableStreaming: true,
enableMetrics: true,
@@ -549,7 +549,7 @@ async function handleStreamingProcess(
});
// Get selected model
-const { getSelectedModelConfig } = await import('../config/configuration_helpers.js');
+const { getSelectedModelConfig } = await import('../../services/llm/config/configuration_helpers.js');
const modelConfig = await getSelectedModelConfig();
if (!modelConfig) {