feat(ee): ai chat (#2098)

* feat: ai chat

* feat: ai chat

* sync

* cleanup

* view space button
This commit is contained in:
Philip Okugbe
2026-04-10 19:23:47 +01:00
committed by GitHub
parent da9b43681e
commit 57efb91bd3
63 changed files with 4149 additions and 48 deletions
+2
View File
@@ -75,10 +75,12 @@
"class-transformer": "^0.5.1",
"class-validator": "^0.15.1",
"cookie": "^1.1.1",
"fast-bm25": "0.0.5",
"fastify-ip": "^2.0.0",
"fs-extra": "^11.3.4",
"happy-dom": "20.8.9",
"ioredis": "^5.10.1",
"js-tiktoken": "^1.0.21",
"jsonwebtoken": "^9.0.3",
"kysely": "^0.28.14",
"kysely-migration-cli": "^0.4.2",
@@ -3,6 +3,7 @@ export enum AttachmentType {
WorkspaceIcon = 'workspace-icon',
SpaceIcon = 'space-icon',
File = 'file',
Chat = 'chat',
}
export const validImageExtensions = ['.jpg', '.png', '.jpeg'];
@@ -178,21 +178,29 @@ export class AttachmentController {
}
const attachment = await this.attachmentRepo.findById(fileId);
if (
!attachment ||
attachment.workspaceId !== workspace.id ||
!attachment.pageId ||
!attachment.spaceId
) {
if (!attachment || attachment.workspaceId !== workspace.id) {
throw new NotFoundException();
}
const page = await this.pageRepo.findById(attachment.pageId);
if (!page) {
throw new NotFoundException();
}
if (attachment.aiChatId) {
// Chat-owned attachment: only the user who uploaded (and therefore
// owns the chat, per AttachmentRepo.claimAttachmentsForChat) can
// read it back.
if (attachment.creatorId !== user.id) {
throw new NotFoundException();
}
} else {
if (!attachment.pageId || !attachment.spaceId) {
throw new NotFoundException();
}
await this.pageAccessService.validateCanView(page, user);
const page = await this.pageRepo.findById(attachment.pageId);
if (!page) {
throw new NotFoundException();
}
await this.pageAccessService.validateCanView(page, user);
}
try {
return await this.sendFileResponse(req, res, attachment, 'private');
@@ -71,6 +71,8 @@ export function getAttachmentFolderPath(
return `${workspaceId}/space-logos`;
case AttachmentType.File:
return `${workspaceId}/files`;
case AttachmentType.Chat:
return `${workspaceId}/chat-files`;
default:
return `${workspaceId}/files`;
}
@@ -28,6 +28,11 @@ export class AttachmentProcessor extends WorkerHost implements OnModuleDestroy {
job.data.pageId,
);
}
if (job.name === QueueJob.DELETE_AI_CHAT_ATTACHMENTS) {
await this.attachmentService.handleDeleteAiChatAttachments(
job.data.aiChatId,
);
}
if (
job.name === QueueJob.ATTACHMENT_INDEX_CONTENT ||
job.name === QueueJob.ATTACHMENT_INDEXING
@@ -289,6 +289,31 @@ export class AttachmentService {
);
}
async handleDeleteAiChatAttachments(aiChatId: string) {
try {
const attachments = await this.attachmentRepo.findByAiChatId(aiChatId);
if (!attachments || attachments.length === 0) {
return;
}
await Promise.all(
attachments.map(async (attachment) => {
try {
await this.storageService.delete(attachment.filePath);
await this.attachmentRepo.deleteAttachmentById(attachment.id);
} catch (err) {
this.logger.log(
`DeleteAiChatAttachments: failed to delete attachment ${attachment.id}:`,
err,
);
}
}),
);
} catch (err) {
throw err;
}
}
async handleDeleteSpaceAttachments(spaceId: string) {
try {
const attachments = await this.attachmentRepo.findBySpaceId(spaceId);
+8 -3
View File
@@ -11,6 +11,10 @@ import {
Logger,
} from '@nestjs/common';
import { SkipThrottle, ThrottlerGuard } from '@nestjs/throttler';
import {
AI_CHAT_THROTTLER,
AUTH_THROTTLER,
} from '../../integrations/throttle/throttler-names';
import { LoginDto } from './dto/login.dto';
import { AuthService } from './services/auth.service';
import { SessionService } from '../session/session.service';
@@ -34,6 +38,7 @@ import {
IAuditService,
} from '../../integrations/audit/audit.service';
@SkipThrottle({ [AI_CHAT_THROTTLER]: true })
@UseGuards(ThrottlerGuard)
@Controller('auth')
export class AuthController {
@@ -113,7 +118,7 @@ export class AuthController {
return workspace;
}
@SkipThrottle()
@SkipThrottle({ [AUTH_THROTTLER]: true })
@UseGuards(JwtAuthGuard)
@HttpCode(HttpStatus.OK)
@Post('change-password')
@@ -176,7 +181,7 @@ export class AuthController {
return this.authService.verifyUserToken(verifyUserTokenDto, workspace.id);
}
@SkipThrottle()
@SkipThrottle({ [AUTH_THROTTLER]: true })
@UseGuards(JwtAuthGuard)
@HttpCode(HttpStatus.OK)
@Post('collab-token')
@@ -187,7 +192,7 @@ export class AuthController {
return this.authService.getCollabToken(user, workspace.id);
}
@SkipThrottle()
@SkipThrottle({ [AUTH_THROTTLER]: true })
@UseGuards(JwtAuthGuard)
@HttpCode(HttpStatus.OK)
@Post('logout')
@@ -46,6 +46,10 @@ export class UpdateWorkspaceDto extends PartialType(CreateWorkspaceDto) {
@IsBoolean()
mcpEnabled: boolean;
@IsOptional()
@IsBoolean()
aiChat: boolean;
@IsOptional()
@IsInt()
@Min(1)
@@ -142,7 +142,7 @@ export class WorkspaceService {
status = WorkspaceStatus.Active;
plan = 'standard';
billingEmail = user.email;
settings = { ai: { generative: true } };
settings = { ai: { generative: true, chat: true } };
}
// create workspace
@@ -458,11 +458,26 @@ export class WorkspaceService {
);
}
if (typeof updateWorkspaceDto.aiChat !== 'undefined') {
const prev = settingsBefore?.ai?.chat ?? false;
if (prev !== updateWorkspaceDto.aiChat) {
before.aiChat = prev;
after.aiChat = updateWorkspaceDto.aiChat;
}
await this.workspaceRepo.updateAiSettings(
workspaceId,
'chat',
updateWorkspaceDto.aiChat,
trx,
);
}
delete updateWorkspaceDto.restrictApiToAdmins;
delete updateWorkspaceDto.aiSearch;
delete updateWorkspaceDto.generativeAi;
delete updateWorkspaceDto.disablePublicSharing;
delete updateWorkspaceDto.mcpEnabled;
delete updateWorkspaceDto.aiChat;
await this.workspaceRepo.updateWorkspace(
updateWorkspaceDto,
@@ -0,0 +1,118 @@
import { type Kysely, sql } from 'kysely';
export async function up(db: Kysely<any>): Promise<void> {
await db.schema
.createTable('ai_chats')
.ifNotExists()
.addColumn('id', 'uuid', (col) =>
col.primaryKey().defaultTo(sql`gen_uuid_v7()`),
)
.addColumn('workspace_id', 'uuid', (col) =>
col.references('workspaces.id').onDelete('cascade').notNull(),
)
.addColumn('creator_id', 'uuid', (col) =>
col.references('users.id').notNull(),
)
.addColumn('title', 'varchar', (col) => col)
.addColumn('created_at', 'timestamptz', (col) =>
col.notNull().defaultTo(sql`now()`),
)
.addColumn('updated_at', 'timestamptz', (col) =>
col.notNull().defaultTo(sql`now()`),
)
.addColumn('deleted_at', 'timestamptz', (col) => col)
.execute();
await db.schema
.createIndex('idx_ai_chats_workspace_creator')
.ifNotExists()
.on('ai_chats')
.columns(['workspace_id', 'creator_id', 'id'])
.execute();
await db.schema
.createTable('ai_chat_messages')
.ifNotExists()
.addColumn('id', 'uuid', (col) =>
col.primaryKey().defaultTo(sql`gen_uuid_v7()`),
)
.addColumn('chat_id', 'uuid', (col) =>
col.references('ai_chats.id').onDelete('cascade').notNull(),
)
.addColumn('workspace_id', 'uuid', (col) =>
col.references('workspaces.id').onDelete('cascade').notNull(),
)
.addColumn('user_id', 'uuid', (col) =>
col.references('users.id').onDelete('set null'),
)
.addColumn('role', 'varchar', (col) => col.notNull())
.addColumn('content', 'text', (col) => col)
.addColumn('tool_calls', 'jsonb', (col) => col)
.addColumn('metadata', 'jsonb', (col) => col)
.addColumn('tsv', sql`tsvector`, (col) => col)
.addColumn('created_at', 'timestamptz', (col) =>
col.notNull().defaultTo(sql`now()`),
)
.addColumn('updated_at', 'timestamptz', (col) =>
col.notNull().defaultTo(sql`now()`),
)
.addColumn('deleted_at', 'timestamptz', (col) => col)
.execute();
await db.schema
.createIndex('idx_ai_chat_messages_chat_id')
.ifNotExists()
.on('ai_chat_messages')
.columns(['chat_id', 'id'])
.execute();
await db.schema
.createIndex('idx_ai_chat_messages_tsv')
.ifNotExists()
.on('ai_chat_messages')
.using('GIN')
.column('tsv')
.execute();
//ts-vector
await sql`
CREATE OR REPLACE FUNCTION ai_chat_messages_tsvector_trigger() RETURNS trigger AS $$
BEGIN
NEW.tsv := to_tsvector('english', f_unaccent(substring(coalesce(NEW.content, ''), 1, 100000)));
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
`.execute(db);
await sql`
CREATE OR REPLACE TRIGGER ai_chat_messages_tsvector_update
BEFORE INSERT OR UPDATE ON ai_chat_messages
FOR EACH ROW EXECUTE FUNCTION ai_chat_messages_tsvector_trigger();
`.execute(db);
await db.schema
.alterTable('attachments')
.addColumn('ai_chat_id', 'uuid', (col) => col)
.execute();
await db.schema
.createIndex('idx_attachments_ai_chat_id')
.ifNotExists()
.on('attachments')
.column('ai_chat_id')
.execute();
}
export async function down(db: Kysely<any>): Promise<void> {
await db.schema.dropIndex('idx_attachments_ai_chat_id').execute();
await db.schema.alterTable('attachments').dropColumn('ai_chat_id').execute();
await sql`DROP TRIGGER IF EXISTS ai_chat_messages_tsvector_update ON ai_chat_messages`.execute(
db,
);
await sql`DROP FUNCTION IF EXISTS ai_chat_messages_tsvector_trigger`.execute(
db,
);
await db.schema.dropTable('ai_chat_messages').execute();
await db.schema.dropTable('ai_chats').execute();
}
@@ -7,6 +7,7 @@ import {
InsertableAttachment,
UpdatableAttachment,
} from '@docmost/db/types/entity.types';
import { AttachmentType } from '../../../core/attachment/attachment.constants';
@Injectable()
export class AttachmentRepo {
@@ -23,6 +24,7 @@ export class AttachmentRepo {
'creatorId',
'pageId',
'spaceId',
'aiChatId',
'workspaceId',
'createdAt',
'updatedAt',
@@ -44,6 +46,21 @@ export class AttachmentRepo {
.executeTakeFirst();
}
async findByIdWithContent(
attachmentId: string,
opts?: {
trx?: KyselyTransaction;
},
): Promise<Attachment> {
const db = dbOrTx(this.db, opts?.trx);
return db
.selectFrom('attachments')
.select([...this.baseFields, 'textContent'])
.where('id', '=', attachmentId)
.executeTakeFirst();
}
async insertAttachment(
insertableAttachment: InsertableAttachment,
trx?: KyselyTransaction,
@@ -72,6 +89,21 @@ export class AttachmentRepo {
.execute();
}
async findByAiChatId(
aiChatId: string,
opts?: {
trx?: KyselyTransaction;
},
): Promise<Attachment[]> {
const db = dbOrTx(this.db, opts?.trx);
return db
.selectFrom('attachments')
.select(this.baseFields)
.where('aiChatId', '=', aiChatId)
.execute();
}
updateAttachmentsByPageId(
updatableAttachment: UpdatableAttachment,
pageIds: string[],
@@ -97,6 +129,25 @@ export class AttachmentRepo {
.executeTakeFirst();
}
async claimAttachmentsForChat(
attachmentIds: string[],
aiChatId: string,
creatorId: string,
workspaceId: string,
): Promise<void> {
if (attachmentIds.length === 0) return;
await this.db
.updateTable('attachments')
.set({ aiChatId })
.where('id', 'in', attachmentIds)
.where('creatorId', '=', creatorId)
.where('workspaceId', '=', workspaceId)
.where('type', '=', AttachmentType.Chat)
.where('aiChatId', 'is', null)
.execute();
}
async deleteAttachmentById(attachmentId: string): Promise<void> {
await this.db
.deleteFrom('attachments')
+28
View File
@@ -43,6 +43,7 @@ export interface ApiKeys {
}
export interface Attachments {
aiChatId: string | null;
createdAt: Generated<Timestamp>;
creatorId: string;
deletedAt: Timestamp | null;
@@ -429,6 +430,31 @@ export interface PagePermissions {
updatedAt: Generated<Timestamp>;
}
export interface AiChats {
id: Generated<string>;
workspaceId: string;
creatorId: string;
title: string | null;
createdAt: Generated<Timestamp>;
updatedAt: Generated<Timestamp>;
deletedAt: Timestamp | null;
}
export interface AiChatMessages {
id: Generated<string>;
chatId: string;
workspaceId: string;
userId: string | null;
role: string;
content: string | null;
toolCalls: Json | null;
metadata: Json | null;
tsv: string | null;
createdAt: Generated<Timestamp>;
updatedAt: Generated<Timestamp>;
deletedAt: Timestamp | null;
}
export interface UserSessions {
id: Generated<string>;
userId: string;
@@ -445,6 +471,8 @@ export interface UserSessions {
}
export interface DB {
aiChats: AiChats;
aiChatMessages: AiChatMessages;
apiKeys: ApiKeys;
attachments: Attachments;
audit: Audit;
@@ -1,5 +1,7 @@
import { Insertable, Selectable, Updateable } from 'kysely';
import {
AiChats,
AiChatMessages,
Attachments,
Comments,
Groups,
@@ -29,6 +31,21 @@ import {
} from './db';
import { PageEmbeddings } from '@docmost/db/types/embeddings.types';
// AI Chat
export type AiChat = Selectable<AiChats>;
export type InsertableAiChat = Insertable<AiChats>;
export type UpdatableAiChat = Updateable<Omit<AiChats, 'id'>>;
// AI Chat Message
// `tsv` is an internal tsvector column maintained by a trigger for
// full-text search. It is omitted from the public type so it never leaks
// into HTTP responses or the chat history fed to the language model.
export type AiChatMessage = Omit<Selectable<AiChatMessages>, 'tsv'>;
export type InsertableAiChatMessage = Omit<
Insertable<AiChatMessages>,
'tsv'
>;
// Workspace
export type Workspace = Selectable<Workspaces>;
export type InsertableWorkspace = Insertable<Workspaces>;
@@ -252,6 +252,13 @@ export class EnvironmentService {
return this.configService.get<string>('AI_COMPLETION_MODEL');
}
getAiChatModel(): string {
return (
this.configService.get<string>('AI_CHAT_MODEL') ||
this.configService.get<string>('AI_COMPLETION_MODEL')
);
}
getAiEmbeddingDimension(): number {
return parseInt(
this.configService.get<string>('AI_EMBEDDING_DIMENSION'),
@@ -353,7 +353,7 @@ export class ExportService {
if (attachmentIds.length > 0) {
const attachments = await this.db
.selectFrom('attachments')
.selectAll()
.select(['id', 'fileName', 'filePath'])
.where('id', 'in', attachmentIds)
.where('spaceId', '=', spaceId)
.execute();
@@ -17,6 +17,7 @@ export enum QueueJob {
ATTACHMENT_INDEX_CONTENT = 'attachment-index-content',
ATTACHMENT_INDEXING = 'attachment-indexing',
DELETE_PAGE_ATTACHMENTS = 'delete-page-attachments',
DELETE_AI_CHAT_ATTACHMENTS = 'delete-ai-chat-attachments',
DELETE_USER_AVATARS = 'delete-user-avatars',
@@ -4,6 +4,7 @@ import { ThrottlerStorageRedisService } from '@nest-lab/throttler-storage-redis'
import { EnvironmentService } from '../environment/environment.service';
import { EnvironmentModule } from '../environment/environment.module';
import { parseRedisUrl } from '../../common/helpers';
import { AUTH_THROTTLER, AI_CHAT_THROTTLER } from './throttler-names';
import Redis from 'ioredis';
@Module({
@@ -14,7 +15,10 @@ import Redis from 'ioredis';
const redisConfig = parseRedisUrl(environmentService.getRedisUrl());
return {
throttlers: [{ name: 'auth', ttl: 60_000, limit: 10 }],
throttlers: [
{ name: AUTH_THROTTLER, ttl: 60_000, limit: 10 },
{ name: AI_CHAT_THROTTLER, ttl: 60_000, limit: 25 },
],
errorMessage: 'Too many requests',
storage: new ThrottlerStorageRedisService(
new Redis({
@@ -0,0 +1,2 @@
export const AUTH_THROTTLER = 'auth';
export const AI_CHAT_THROTTLER = 'ai-chat';
@@ -0,0 +1,13 @@
import { Injectable } from '@nestjs/common';
import { ThrottlerGuard } from '@nestjs/throttler';
type AuthedRequest = { user?: { id?: string } };
@Injectable()
export class UserThrottlerGuard extends ThrottlerGuard {
protected async getTracker(req: AuthedRequest): Promise<string> {
const userId = req.user?.id;
if (userId) return `user:${userId}`;
return super.getTracker(req as Parameters<ThrottlerGuard['getTracker']>[0]);
}
}