feat(EE): AI vector search (#1691)

* WIP

* AI module - init

* WIP

* sync

* WIP

* refactor naming

* new columns

* sync

* sync

* fix search bug

* stream response

* WIP

* feat embeddings sync

* refine

* Add workspaceId to page events

* refine

* WIP

* add translation string

* sync

* reset ai answer on query change

* hide AI search in cloud

* capture streaming error

* sync
This commit is contained in:
Philip Okugbe
2025-12-01 11:50:25 +00:00
committed by GitHub
parent c3b350d943
commit 9fb16bc842
46 changed files with 1608 additions and 68 deletions
@@ -0,0 +1,113 @@
import React, { useMemo } from "react";
import { Paper, Text, Group, Stack, Loader, Box } from "@mantine/core";
import { IconSparkles, IconFileText } from "@tabler/icons-react";
import { Link } from "react-router-dom";
import { IAiSearchResponse } from "../services/ai-search-service.ts";
import { buildPageUrl } from "@/features/page/page.utils.ts";
import { markdownToHtml } from "@docmost/editor-ext";
import DOMPurify from "dompurify";
import { useTranslation } from "react-i18next";
interface AiSearchResultProps {
result?: IAiSearchResponse;
isLoading?: boolean;
streamingAnswer?: string;
streamingSources?: any[];
}
export function AiSearchResult({
result,
isLoading,
streamingAnswer = "",
streamingSources = [],
}: AiSearchResultProps) {
const { t } = useTranslation();
// Use streaming data if available, otherwise fall back to result
const answer = streamingAnswer || result?.answer || "";
const sources =
streamingSources.length > 0 ? streamingSources : result?.sources || [];
// Deduplicate sources by pageId, keeping the one with highest similarity
const deduplicatedSources = useMemo(() => {
if (!sources || sources.length === 0) return [];
const pageMap = new Map();
sources.forEach((source) => {
const existing = pageMap.get(source.pageId);
if (!existing || source.similarity > existing.similarity) {
pageMap.set(source.pageId, source);
}
});
return Array.from(pageMap.values());
}, [sources]);
if (isLoading && !answer) {
return (
<Paper p="md" radius="md" withBorder>
<Group>
<Loader size="sm" />
<Text size="sm">{t("AI is thinking...")}</Text>
</Group>
</Paper>
);
}
if (!answer && !isLoading) {
return null;
}
return (
<Stack gap="md" p="md">
<Paper p="md" radius="md" withBorder>
<Group gap="xs" mb="sm">
<IconSparkles size={20} color="var(--mantine-color-blue-6)" />
<Text fw={600} size="sm">
{t("AI Answer")}
</Text>
{isLoading && <Loader size="xs" />}
</Group>
<div
dangerouslySetInnerHTML={{
__html: DOMPurify.sanitize(markdownToHtml(answer) as string),
}}
/>
</Paper>
{deduplicatedSources.length > 0 && (
<Stack gap="xs">
<Text size="xs" fw={600} c="dimmed">
{t("Sources")}
</Text>
{deduplicatedSources.map((source) => (
<Box
key={source.pageId}
component={Link}
to={buildPageUrl(source.spaceSlug, source.slugId, source.title)}
style={{
textDecoration: "none",
color: "inherit",
display: "block",
}}
>
<Paper
p="xs"
radius="sm"
withBorder
style={{ cursor: "pointer" }}
>
<Group gap="xs">
<IconFileText size={16} />
<Text size="sm" truncate>
{source.title}
</Text>
</Group>
</Paper>
</Box>
))}
</Stack>
)}
</Stack>
);
}
@@ -0,0 +1,69 @@
import { Group, Text, Switch, MantineSize, Title } from "@mantine/core";
import { useAtom } from "jotai";
import { workspaceAtom } from "@/features/user/atoms/current-user-atom.ts";
import React, { useState } from "react";
import { useTranslation } from "react-i18next";
import { updateWorkspace } from "@/features/workspace/services/workspace-service.ts";
import { notifications } from "@mantine/notifications";
import { isCloud } from "@/lib/config.ts";
import useLicense from "@/ee/hooks/use-license.tsx";
export default function EnableAiSearch() {
const { t } = useTranslation();
return (
<>
<Group justify="space-between" wrap="nowrap" gap="xl">
<div>
<Text size="md">{t("AI-powered search (Ask AI)")}</Text>
<Text size="sm" c="dimmed">
{t(
"AI search uses vector embeddings to provide semantic search capabilities across your workspace content.",
)}
</Text>
</div>
<AiSearchToggle />
</Group>
</>
);
}
interface AiSearchToggleProps {
size?: MantineSize;
label?: string;
}
export function AiSearchToggle({ size, label }: AiSearchToggleProps) {
const { t } = useTranslation();
const [workspace, setWorkspace] = useAtom(workspaceAtom);
const [checked, setChecked] = useState(workspace?.settings?.ai?.search);
const { hasLicenseKey } = useLicense();
const hasAccess = isCloud() || (!isCloud() && hasLicenseKey);
const handleChange = async (event: React.ChangeEvent<HTMLInputElement>) => {
const value = event.currentTarget.checked;
try {
const updatedWorkspace = await updateWorkspace({ aiSearch: value });
setChecked(value);
setWorkspace(updatedWorkspace);
} catch (err) {
notifications.show({
message: err?.response?.data?.message,
color: "red",
});
}
};
return (
<Switch
size={size}
label={label}
labelPosition="left"
defaultChecked={checked}
onChange={handleChange}
disabled={!hasAccess}
aria-label={t("Toggle AI search")}
/>
);
}
@@ -0,0 +1,46 @@
import { useMutation, UseMutationResult } from "@tanstack/react-query";
import { useState, useCallback } from "react";
import { askAi, IAiSearchResponse } from "@/ee/ai/services/ai-search-service.ts";
import { IPageSearchParams } from "@/features/search/types/search.types.ts";
// @ts-ignore
interface UseAiSearchResult extends UseMutationResult<IAiSearchResponse, Error, IPageSearchParams> {
streamingAnswer: string;
streamingSources: any[];
clearStreaming: () => void;
}
export function useAiSearch(): UseAiSearchResult {
const [streamingAnswer, setStreamingAnswer] = useState("");
const [streamingSources, setStreamingSources] = useState<any[]>([]);
const clearStreaming = useCallback(() => {
setStreamingAnswer("");
setStreamingSources([]);
}, []);
const mutation = useMutation({
mutationFn: async (params: IPageSearchParams & { contentType?: string }) => {
setStreamingAnswer("");
setStreamingSources([]);
const { contentType, ...apiParams } = params;
return await askAi(apiParams, (chunk) => {
if (chunk.content) {
setStreamingAnswer((prev) => prev + chunk.content);
}
if (chunk.sources) {
setStreamingSources(chunk.sources);
}
});
},
});
return {
...mutation,
streamingAnswer,
streamingSources,
clearStreaming,
};
}
+61
View File
@@ -0,0 +1,61 @@
import { useState, useCallback, useRef } from "react";
import { useAiGenerateStreamMutation } from "@/ee/ai/queries/ai-query.ts";
import { AiGenerateDto } from "@/ee/ai/types/ai.types.ts";
export function useAiStream() {
const [content, setContent] = useState("");
const [isStreaming, setIsStreaming] = useState(false);
const abortControllerRef = useRef<AbortController | null>(null);
const mutation = useAiGenerateStreamMutation();
const startStream = useCallback(
async (data: AiGenerateDto) => {
setContent("");
setIsStreaming(true);
try {
const controller = await mutation.mutateAsync({
...data,
onChunk: (chunk) => {
setContent((prev) => prev + chunk.content);
},
onError: (error) => {
console.error("AI stream error:", error);
setIsStreaming(false);
},
onComplete: () => {
setIsStreaming(false);
},
});
abortControllerRef.current = controller;
} catch (error) {
console.error("Failed to start stream:", error);
setIsStreaming(false);
}
},
[mutation]
);
const stopStream = useCallback(() => {
if (abortControllerRef.current) {
abortControllerRef.current.abort();
abortControllerRef.current = null;
setIsStreaming(false);
}
}, []);
const resetContent = useCallback(() => {
setContent("");
}, []);
return {
content,
isStreaming,
startStream,
stopStream,
resetContent,
isLoading: mutation.isPending,
error: mutation.error,
};
}
@@ -0,0 +1,46 @@
import { Helmet } from "react-helmet-async";
import { getAppName, isCloud } from "@/lib/config.ts";
import SettingsTitle from "@/components/settings/settings-title.tsx";
import React from "react";
import useUserRole from "@/hooks/use-user-role.tsx";
import { useTranslation } from "react-i18next";
import useLicense from "@/ee/hooks/use-license.tsx";
import EnableAiSearch from "@/ee/ai/components/enable-ai-search.tsx";
import { Alert } from "@mantine/core";
import { IconInfoCircle } from "@tabler/icons-react";
export default function AiSettings() {
const { t } = useTranslation();
const { isAdmin } = useUserRole();
const { hasLicenseKey } = useLicense();
if (!isAdmin) {
return null;
}
const hasAccess = isCloud() || (!isCloud() && hasLicenseKey);
return (
<>
<Helmet>
<title>AI - {getAppName()}</title>
</Helmet>
<SettingsTitle title={t("AI settings")} />
{!hasAccess && (
<Alert
icon={<IconInfoCircle />}
title={t("Enterprise feature")}
color="blue"
mb="lg"
>
{t(
"AI is only available in the Docmost enterprise edition. Contact sales@docmost.com.",
)}
</Alert>
)}
<EnableAiSearch />
</>
);
}
+44
View File
@@ -0,0 +1,44 @@
import {
useMutation,
UseMutationResult,
useQuery,
UseQueryResult,
} from "@tanstack/react-query";
import {
generateAiContent,
generateAiContentStream,
} from "@/ee/ai/services/ai-service.ts";
import {
AiConfigResponse,
AiContentResponse,
AiGenerateDto,
AiStreamChunk,
AiStreamError,
} from "@/ee/ai/types/ai.types.ts";
export function useAiGenerateMutation(): UseMutationResult<
AiContentResponse,
Error,
AiGenerateDto
> {
return useMutation({
mutationFn: (data: AiGenerateDto) => generateAiContent(data),
});
}
interface StreamCallbacks {
onChunk: (chunk: AiStreamChunk) => void;
onError?: (error: AiStreamError) => void;
onComplete?: () => void;
}
export function useAiGenerateStreamMutation(): UseMutationResult<
AbortController,
Error,
AiGenerateDto & StreamCallbacks
> {
return useMutation({
mutationFn: ({ onChunk, onError, onComplete, ...data }) =>
generateAiContentStream(data, onChunk, onError, onComplete),
});
}
@@ -0,0 +1,79 @@
import api from "@/lib/api-client.ts";
import { IPageSearchParams } from "@/features/search/types/search.types.ts";
export interface IAiSearchResponse {
answer: string;
sources?: Array<{
pageId: string;
title: string;
slugId: string;
spaceSlug: string;
similarity: number;
distance: number;
chunkIndex: number;
excerpt: string;
}>;
}
export async function askAi(
params: IPageSearchParams,
onChunk?: (chunk: { content?: string; sources?: any[] }) => void,
): Promise<IAiSearchResponse> {
const response = await fetch("/api/ai/ask", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
credentials: "include",
body: JSON.stringify(params),
});
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
const reader = response.body?.getReader();
const decoder = new TextDecoder();
let answer = "";
let sources: any[] = [];
if (reader) {
while (true) {
const { done, value } = await reader.read();
if (done) break;
const chunk = decoder.decode(value);
const lines = chunk.split("\n");
for (const line of lines) {
if (line.startsWith("data: ")) {
const data = line.slice(6);
if (data === "[DONE]") break;
try {
const parsed = JSON.parse(data);
if (parsed.error) {
throw new Error(parsed.error);
}
if (parsed.content) {
answer += parsed.content;
onChunk?.({ content: parsed.content });
}
if (parsed.sources) {
sources = parsed.sources;
onChunk?.({ sources: parsed.sources });
}
} catch (e) {
if (e instanceof Error) {
throw e;
}
// Skip invalid JSON
}
}
}
}
}
return { answer, sources };
}
@@ -0,0 +1,89 @@
import api from "@/lib/api-client.ts";
import {
AiGenerateDto,
AiContentResponse,
AiStreamChunk,
AiStreamError,
} from "@/ee/ai/types/ai.types.ts";
export async function generateAiContent(
data: AiGenerateDto,
): Promise<AiContentResponse> {
const req = await api.post<AiContentResponse>("/ai/generate", data);
return req.data;
}
export async function generateAiContentStream(
data: AiGenerateDto,
onChunk: (chunk: AiStreamChunk) => void,
onError?: (error: AiStreamError) => void,
onComplete?: () => void,
): Promise<AbortController> {
const abortController = new AbortController();
try {
const response = await fetch("/api/ai/generate/stream", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(data),
signal: abortController.signal,
credentials: "include", // This ensures cookies are sent, matching axios withCredentials
});
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
const reader = response.body?.getReader();
const decoder = new TextDecoder();
if (!reader) {
throw new Error("Response body is not readable");
}
const processStream = async () => {
try {
while (true) {
const { done, value } = await reader.read();
if (done) break;
const chunk = decoder.decode(value, { stream: true });
const lines = chunk.split("\n");
for (const line of lines) {
if (line.startsWith("data: ")) {
const data = line.slice(6);
if (data === "[DONE]") {
onComplete?.();
return;
}
try {
const parsed = JSON.parse(data);
if (parsed.error) {
onError?.(parsed);
} else {
onChunk(parsed);
}
} catch (e) {
// Ignore parse errors for incomplete chunks
}
}
}
}
} catch (error) {
if (error.name !== "AbortError") {
onError?.({ error: error.message });
}
} finally {
reader.releaseLock();
}
};
processStream();
} catch (error) {
onError?.({ error: error.message });
}
return abortController;
}
+40
View File
@@ -0,0 +1,40 @@
export enum AiAction {
IMPROVE_WRITING = "improve_writing",
FIX_SPELLING_GRAMMAR = "fix_spelling_grammar",
MAKE_SHORTER = "make_shorter",
MAKE_LONGER = "make_longer",
SIMPLIFY = "simplify",
CHANGE_TONE = "change_tone",
SUMMARIZE = "summarize",
CONTINUE_WRITING = "continue_writing",
TRANSLATE = "translate",
CUSTOM = "custom",
}
export interface AiGenerateDto {
action?: AiAction;
content: string;
prompt?: string;
}
export interface AiContentResponse {
content: string;
usage?: {
promptTokens: number;
completionTokens: number;
totalTokens: number;
};
}
export interface AiConfigResponse {
configured: boolean;
availableActions: AiAction[];
}
export interface AiStreamChunk {
content: string;
}
export interface AiStreamError {
error: string;
}
@@ -11,7 +11,7 @@ export default function OssDetails() {
withTableBorder
>
<Table.Caption>
To unlock enterprise features like SSO, MFA, Resolve comments, contact sales@docmost.com.
To unlock enterprise features like AI, SSO, MFA, Resolve comments, contact sales@docmost.com.
</Table.Caption>
<Table.Tbody>
<Table.Tr>