Mirror of https://github.com/zadam/trilium.git, synced 2025-12-22 16:20:08 +01:00
feat(llm): add tests for streaming
@@ -1,8 +1,9 @@
import { Application } from "express";
-import { beforeAll, describe, expect, it, vi, beforeEach } from "vitest";
+import { beforeAll, describe, expect, it, vi, beforeEach, afterEach } from "vitest";
import supertest from "supertest";
import config from "../../services/config.js";
import { refreshAuth } from "../../services/auth.js";
+import type { WebSocket } from 'ws';

// Mock the CSRF protection middleware to allow tests to pass
vi.mock("../csrf_protection.js", () => ({
@@ -10,6 +11,64 @@ vi.mock("../csrf_protection.js", () => ({
    generateToken: () => "mock-csrf-token"
}));

// Mock WebSocket service
vi.mock("../../services/ws.js", () => ({
    default: {
        sendMessageToAllClients: vi.fn()
    }
}));

// Mock log service
vi.mock("../../services/log.js", () => ({
    default: {
        info: vi.fn(),
        error: vi.fn(),
        warn: vi.fn()
    }
}));

// Mock chat storage service
const mockChatStorage = {
    createChat: vi.fn(),
    getChat: vi.fn(),
    updateChat: vi.fn(),
    getAllChats: vi.fn(),
    deleteChat: vi.fn()
};
vi.mock("../../services/llm/storage/chat_storage_service.js", () => ({
    default: mockChatStorage
}));

// Mock AI service manager
const mockAiServiceManager = {
    getOrCreateAnyService: vi.fn()
};
vi.mock("../../services/llm/ai_service_manager.js", () => ({
    default: mockAiServiceManager
}));

// Mock chat pipeline
const mockChatPipelineExecute = vi.fn();
const MockChatPipeline = vi.fn().mockImplementation(() => ({
    execute: mockChatPipelineExecute
}));
vi.mock("../../services/llm/pipeline/chat_pipeline.js", () => ({
    ChatPipeline: MockChatPipeline
}));

// Mock configuration helpers
const mockGetSelectedModelConfig = vi.fn();
vi.mock("../../services/llm/config/configuration_helpers.js", () => ({
    getSelectedModelConfig: mockGetSelectedModelConfig
}));

// Mock options service
vi.mock("../../services/options.js", () => ({
    default: {
        getOptionBool: vi.fn()
    }
}));

// Session-based login that properly establishes req.session.loggedIn
async function loginWithSession(app: Application) {
    const response = await supertest(app)
@@ -257,7 +316,30 @@ describe("LLM API Tests", () => {
let testChatId: string;

beforeEach(async () => {
    // Reset all mocks
    vi.clearAllMocks();

    // Import options service to access mock
    const options = (await import("../../services/options.js")).default;

    // Setup default mock behaviors
    options.getOptionBool.mockReturnValue(true); // AI enabled
    mockAiServiceManager.getOrCreateAnyService.mockResolvedValue({});
    mockGetSelectedModelConfig.mockResolvedValue({
        model: 'test-model',
        provider: 'test-provider'
    });

    // Create a fresh chat for each test
    const mockChat = {
        id: 'streaming-test-chat',
        title: 'Streaming Test Chat',
        messages: [],
        createdAt: new Date().toISOString()
    };
    mockChatStorage.createChat.mockResolvedValue(mockChat);
    mockChatStorage.getChat.mockResolvedValue(mockChat);

    const createResponse = await supertest(app)
        .post("/api/llm/chat")
        .set("Cookie", sessionCookie)
@@ -269,7 +351,19 @@ describe("LLM API Tests", () => {
    testChatId = createResponse.body.id;
});

afterEach(() => {
    vi.clearAllMocks();
});

it("should initiate streaming for a chat message", async () => {
    // Setup streaming simulation
    mockChatPipelineExecute.mockImplementation(async (input) => {
        const callback = input.streamCallback;
        // Simulate streaming chunks
        await callback('Hello', false, {});
        await callback(' world!', true, {});
    });

    const response = await supertest(app)
        .post(`/api/llm/chat/${testChatId}/messages/stream`)
        .set("Cookie", sessionCookie)
@@ -286,6 +380,31 @@ describe("LLM API Tests", () => {
        success: true,
        message: "Streaming initiated successfully"
    });

    // Import ws service to access mock
    const ws = (await import("../../services/ws.js")).default;

    // Verify WebSocket messages were sent
    expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
        type: 'llm-stream',
        chatNoteId: testChatId,
        thinking: undefined
    });

    // Verify streaming chunks were sent
    expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
        type: 'llm-stream',
        chatNoteId: testChatId,
        content: 'Hello',
        done: false
    });

    expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
        type: 'llm-stream',
        chatNoteId: testChatId,
        content: ' world!',
        done: true
    });
});

it("should handle empty content for streaming", async () => {
@@ -338,6 +457,29 @@ describe("LLM API Tests", () => {
});

it("should handle streaming with note mentions", async () => {
    // Mock becca for note content retrieval
    const mockBecca = {
        getNote: vi.fn().mockReturnValue({
            noteId: 'root',
            title: 'Root Note',
            getBlob: () => ({
                getContent: () => 'Root note content for testing'
            })
        })
    };
    vi.mocked(await import('../../becca/becca.js')).default = mockBecca;

    // Setup streaming with mention context
    mockChatPipelineExecute.mockImplementation(async (input) => {
        // Verify mention content is included
        expect(input.query).toContain('Tell me about this note');
        expect(input.query).toContain('Root note content for testing');

        const callback = input.streamCallback;
        await callback('The root note contains', false, {});
        await callback(' important information.', true, {});
    });

    const response = await supertest(app)
        .post(`/api/llm/chat/${testChatId}/messages/stream`)
        .set("Cookie", sessionCookie)
@@ -358,6 +500,250 @@ describe("LLM API Tests", () => {
        success: true,
        message: "Streaming initiated successfully"
    });

    // Import ws service to access mock
    const ws = (await import("../../services/ws.js")).default;

    // Verify thinking message was sent
    expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
        type: 'llm-stream',
        chatNoteId: testChatId,
        thinking: 'Initializing streaming LLM response...'
    });
});

it("should handle streaming with thinking states", async () => {
|
||||
mockChatPipelineExecute.mockImplementation(async (input) => {
|
||||
const callback = input.streamCallback;
|
||||
// Simulate thinking states
|
||||
await callback('', false, { thinking: 'Analyzing the question...' });
|
||||
await callback('', false, { thinking: 'Formulating response...' });
|
||||
await callback('The answer is', false, {});
|
||||
await callback(' 42.', true, {});
|
||||
});
|
||||
|
||||
const response = await supertest(app)
|
||||
.post(`/api/llm/chat/${testChatId}/messages/stream`)
|
||||
.set("Cookie", sessionCookie)
|
||||
.send({
|
||||
content: "What is the meaning of life?",
|
||||
useAdvancedContext: false,
|
||||
showThinking: true
|
||||
});
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
|
||||
// Import ws service to access mock
|
||||
const ws = (await import("../../services/ws.js")).default;
|
||||
|
||||
// Verify thinking messages
|
||||
expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
|
||||
type: 'llm-stream',
|
||||
chatNoteId: testChatId,
|
||||
thinking: 'Analyzing the question...',
|
||||
done: false
|
||||
});
|
||||
|
||||
expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
|
||||
type: 'llm-stream',
|
||||
chatNoteId: testChatId,
|
||||
thinking: 'Formulating response...',
|
||||
done: false
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle streaming with tool executions", async () => {
|
||||
mockChatPipelineExecute.mockImplementation(async (input) => {
|
||||
const callback = input.streamCallback;
|
||||
// Simulate tool execution
|
||||
await callback('Let me calculate that', false, {});
|
||||
await callback('', false, {
|
||||
toolExecution: {
|
||||
tool: 'calculator',
|
||||
arguments: { expression: '2 + 2' },
|
||||
result: '4',
|
||||
toolCallId: 'call_123',
|
||||
action: 'execute'
|
||||
}
|
||||
});
|
||||
await callback('The result is 4', true, {});
|
||||
});
|
||||
|
||||
const response = await supertest(app)
|
||||
.post(`/api/llm/chat/${testChatId}/messages/stream`)
|
||||
.set("Cookie", sessionCookie)
|
||||
.send({
|
||||
content: "What is 2 + 2?",
|
||||
useAdvancedContext: false,
|
||||
showThinking: false
|
||||
});
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
|
||||
// Import ws service to access mock
|
||||
const ws = (await import("../../services/ws.js")).default;
|
||||
|
||||
// Verify tool execution message
|
||||
expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
|
||||
type: 'llm-stream',
|
||||
chatNoteId: testChatId,
|
||||
toolExecution: {
|
||||
tool: 'calculator',
|
||||
args: { expression: '2 + 2' },
|
||||
result: '4',
|
||||
toolCallId: 'call_123',
|
||||
action: 'execute',
|
||||
error: undefined
|
||||
},
|
||||
done: false
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle streaming errors gracefully", async () => {
|
||||
mockChatPipelineExecute.mockRejectedValue(new Error('Pipeline error'));
|
||||
|
||||
const response = await supertest(app)
|
||||
.post(`/api/llm/chat/${testChatId}/messages/stream`)
|
||||
.set("Cookie", sessionCookie)
|
||||
.send({
|
||||
content: "This will fail",
|
||||
useAdvancedContext: false,
|
||||
showThinking: false
|
||||
});
|
||||
|
||||
expect(response.status).toBe(200); // Still returns 200
|
||||
|
||||
// Import ws service to access mock
|
||||
const ws = (await import("../../services/ws.js")).default;
|
||||
|
||||
// Verify error message was sent via WebSocket
|
||||
expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
|
||||
type: 'llm-stream',
|
||||
chatNoteId: testChatId,
|
||||
error: 'Error during streaming: Pipeline error',
|
||||
done: true
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle AI disabled state", async () => {
|
||||
// Import options service to access mock
|
||||
const options = (await import("../../services/options.js")).default;
|
||||
options.getOptionBool.mockReturnValue(false); // AI disabled
|
||||
|
||||
const response = await supertest(app)
|
||||
.post(`/api/llm/chat/${testChatId}/messages/stream`)
|
||||
.set("Cookie", sessionCookie)
|
||||
.send({
|
||||
content: "Hello AI",
|
||||
useAdvancedContext: false,
|
||||
showThinking: false
|
||||
});
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
|
||||
// Import ws service to access mock
|
||||
const ws = (await import("../../services/ws.js")).default;
|
||||
|
||||
// Verify error message about AI being disabled
|
||||
expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
|
||||
type: 'llm-stream',
|
||||
chatNoteId: testChatId,
|
||||
error: 'Error during streaming: AI features are disabled. Please enable them in the settings.',
|
||||
done: true
|
||||
});
|
||||
});
|
||||
|
||||
it("should save chat messages after streaming completion", async () => {
|
||||
const completeResponse = 'This is the complete response';
|
||||
mockChatPipelineExecute.mockImplementation(async (input) => {
|
||||
const callback = input.streamCallback;
|
||||
await callback(completeResponse, true, {});
|
||||
});
|
||||
|
||||
await supertest(app)
|
||||
.post(`/api/llm/chat/${testChatId}/messages/stream`)
|
||||
.set("Cookie", sessionCookie)
|
||||
.send({
|
||||
content: "Save this response",
|
||||
useAdvancedContext: false,
|
||||
showThinking: false
|
||||
});
|
||||
|
||||
// Wait for async operations
|
||||
await new Promise(resolve => setTimeout(resolve, 100));
|
||||
|
||||
// Verify chat was updated with the complete response
|
||||
expect(mockChatStorage.updateChat).toHaveBeenCalledWith(
|
||||
testChatId,
|
||||
expect.arrayContaining([
|
||||
{ role: 'assistant', content: completeResponse }
|
||||
]),
|
||||
'Streaming Test Chat'
|
||||
);
|
||||
});
|
||||
|
||||
it("should handle rapid consecutive streaming requests", async () => {
|
||||
let callCount = 0;
|
||||
mockChatPipelineExecute.mockImplementation(async (input) => {
|
||||
callCount++;
|
||||
const callback = input.streamCallback;
|
||||
await callback(`Response ${callCount}`, true, {});
|
||||
});
|
||||
|
||||
// Send multiple requests rapidly
|
||||
const promises = Array.from({ length: 3 }, (_, i) =>
|
||||
supertest(app)
|
||||
.post(`/api/llm/chat/${testChatId}/messages/stream`)
|
||||
.set("Cookie", sessionCookie)
|
||||
.send({
|
||||
content: `Request ${i + 1}`,
|
||||
useAdvancedContext: false,
|
||||
showThinking: false
|
||||
})
|
||||
);
|
||||
|
||||
const responses = await Promise.all(promises);
|
||||
|
||||
// All should succeed
|
||||
responses.forEach(response => {
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.success).toBe(true);
|
||||
});
|
||||
|
||||
// Verify all were processed
|
||||
expect(mockChatPipelineExecute).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
|
||||
it("should handle large streaming responses", async () => {
|
||||
const largeContent = 'x'.repeat(10000); // 10KB of content
|
||||
mockChatPipelineExecute.mockImplementation(async (input) => {
|
||||
const callback = input.streamCallback;
|
||||
// Simulate chunked delivery of large content
|
||||
for (let i = 0; i < 10; i++) {
|
||||
await callback(largeContent.slice(i * 1000, (i + 1) * 1000), false, {});
|
||||
}
|
||||
await callback('', true, {});
|
||||
});
|
||||
|
||||
const response = await supertest(app)
|
||||
.post(`/api/llm/chat/${testChatId}/messages/stream`)
|
||||
.set("Cookie", sessionCookie)
|
||||
.send({
|
||||
content: "Generate large response",
|
||||
useAdvancedContext: false,
|
||||
showThinking: false
|
||||
});
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
|
||||
// Import ws service to access mock
|
||||
const ws = (await import("../../services/ws.js")).default;
|
||||
|
||||
// Verify multiple chunks were sent
|
||||
const streamCalls = ws.sendMessageToAllClients.mock.calls.filter(
|
||||
call => call[0].type === 'llm-stream' && call[0].content
|
||||
);
|
||||
expect(streamCalls.length).toBeGreaterThan(5);
|
||||
});
|
||||
});
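The tests above pin down an implicit contract between the chat pipeline and the WebSocket layer: the pipeline invokes a stream callback for each chunk, and the route forwards each chunk as an 'llm-stream' message. The following TypeScript sketch summarizes that contract as exercised by the assertions; the type names are illustrative and not taken from the Trilium source.

// Sketch of the streaming contract the tests above assume (names are illustrative).
type StreamCallback = (
    content: string,                  // text chunk; may be '' when only metadata is sent
    done: boolean,                    // true on the final chunk
    extra: {
        thinking?: string;            // intermediate status, surfaced when showThinking is true
        toolExecution?: {
            tool: string;
            arguments?: Record<string, unknown>;
            result?: string;
            toolCallId?: string;
            action?: string;
        };
    }
) => Promise<void>;

// Shape of the messages the tests expect ws.sendMessageToAllClients to receive.
interface LlmStreamMessage {
    type: 'llm-stream';
    chatNoteId: string;
    content?: string;
    thinking?: string;
    toolExecution?: {
        tool: string;
        args?: Record<string, unknown>;
        result?: string;
        toolCallId?: string;
        action?: string;
        error?: string;
    };
    done?: boolean;
    error?: string;
}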
@@ -537,11 +537,11 @@ async function handleStreamingProcess(
    }

    // Get AI service
-    const aiServiceManager = await import('../ai_service_manager.js');
+    const aiServiceManager = await import('../../services/llm/ai_service_manager.js');
    await aiServiceManager.default.getOrCreateAnyService();

    // Use the chat pipeline directly for streaming
-    const { ChatPipeline } = await import('../pipeline/chat_pipeline.js');
+    const { ChatPipeline } = await import('../../services/llm/pipeline/chat_pipeline.js');
    const pipeline = new ChatPipeline({
        enableStreaming: true,
        enableMetrics: true,
@@ -549,7 +549,7 @@ async function handleStreamingProcess(
    });

    // Get selected model
-    const { getSelectedModelConfig } = await import('../config/configuration_helpers.js');
+    const { getSelectedModelConfig } = await import('../../services/llm/config/configuration_helpers.js');
    const modelConfig = await getSelectedModelConfig();

    if (!modelConfig) {
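For context, the assertions in the test file imply roughly the following forwarding behaviour inside the streaming handler. This is a hedged sketch under assumptions (the Broadcaster type and makeStreamCallback helper are hypothetical), not the code changed by this commit.

// Hypothetical sketch: each streamCallback invocation becomes one 'llm-stream'
// broadcast, with optional content/thinking/toolExecution fields and a done flag.
type Broadcaster = { sendMessageToAllClients(message: object): void };

function makeStreamCallback(ws: Broadcaster, chatNoteId: string) {
    return async (
        content: string,
        done: boolean,
        extra: { thinking?: string; toolExecution?: object }
    ): Promise<void> => {
        ws.sendMessageToAllClients({
            type: 'llm-stream',
            chatNoteId,
            ...(content ? { content } : {}),
            ...(extra.thinking ? { thinking: extra.thinking } : {}),
            ...(extra.toolExecution ? { toolExecution: extra.toolExecution } : {}),
            done
        });
    };
}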