mirror of
https://github.com/docmost/docmost.git
synced 2026-05-08 23:33:09 +08:00
feat: ai chat
This commit is contained in:
@@ -3,6 +3,7 @@ export enum AttachmentType {
|
||||
WorkspaceIcon = 'workspace-icon',
|
||||
SpaceIcon = 'space-icon',
|
||||
File = 'file',
|
||||
Chat = 'chat',
|
||||
}
|
||||
|
||||
export const validImageExtensions = ['.jpg', '.png', '.jpeg'];
|
||||
|
||||
@@ -178,21 +178,29 @@ export class AttachmentController {
|
||||
}
|
||||
|
||||
const attachment = await this.attachmentRepo.findById(fileId);
|
||||
if (
|
||||
!attachment ||
|
||||
attachment.workspaceId !== workspace.id ||
|
||||
!attachment.pageId ||
|
||||
!attachment.spaceId
|
||||
) {
|
||||
if (!attachment || attachment.workspaceId !== workspace.id) {
|
||||
throw new NotFoundException();
|
||||
}
|
||||
|
||||
const page = await this.pageRepo.findById(attachment.pageId);
|
||||
if (!page) {
|
||||
throw new NotFoundException();
|
||||
}
|
||||
if (attachment.aiChatId) {
|
||||
// Chat-owned attachment: only the user who uploaded (and therefore
|
||||
// owns the chat, per AttachmentRepo.claimAttachmentsForChat) can
|
||||
// read it back.
|
||||
if (attachment.creatorId !== user.id) {
|
||||
throw new NotFoundException();
|
||||
}
|
||||
} else {
|
||||
if (!attachment.pageId || !attachment.spaceId) {
|
||||
throw new NotFoundException();
|
||||
}
|
||||
|
||||
await this.pageAccessService.validateCanView(page, user);
|
||||
const page = await this.pageRepo.findById(attachment.pageId);
|
||||
if (!page) {
|
||||
throw new NotFoundException();
|
||||
}
|
||||
|
||||
await this.pageAccessService.validateCanView(page, user);
|
||||
}
|
||||
|
||||
try {
|
||||
return await this.sendFileResponse(req, res, attachment, 'private');
|
||||
|
||||
@@ -71,6 +71,8 @@ export function getAttachmentFolderPath(
|
||||
return `${workspaceId}/space-logos`;
|
||||
case AttachmentType.File:
|
||||
return `${workspaceId}/files`;
|
||||
case AttachmentType.Chat:
|
||||
return `${workspaceId}/chat-files`;
|
||||
default:
|
||||
return `${workspaceId}/files`;
|
||||
}
|
||||
|
||||
@@ -28,6 +28,11 @@ export class AttachmentProcessor extends WorkerHost implements OnModuleDestroy {
|
||||
job.data.pageId,
|
||||
);
|
||||
}
|
||||
if (job.name === QueueJob.DELETE_AI_CHAT_ATTACHMENTS) {
|
||||
await this.attachmentService.handleDeleteAiChatAttachments(
|
||||
job.data.aiChatId,
|
||||
);
|
||||
}
|
||||
if (
|
||||
job.name === QueueJob.ATTACHMENT_INDEX_CONTENT ||
|
||||
job.name === QueueJob.ATTACHMENT_INDEXING
|
||||
|
||||
@@ -289,6 +289,31 @@ export class AttachmentService {
|
||||
);
|
||||
}
|
||||
|
||||
async handleDeleteAiChatAttachments(aiChatId: string) {
|
||||
try {
|
||||
const attachments = await this.attachmentRepo.findByAiChatId(aiChatId);
|
||||
if (!attachments || attachments.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
await Promise.all(
|
||||
attachments.map(async (attachment) => {
|
||||
try {
|
||||
await this.storageService.delete(attachment.filePath);
|
||||
await this.attachmentRepo.deleteAttachmentById(attachment.id);
|
||||
} catch (err) {
|
||||
this.logger.log(
|
||||
`DeleteAiChatAttachments: failed to delete attachment ${attachment.id}:`,
|
||||
err,
|
||||
);
|
||||
}
|
||||
}),
|
||||
);
|
||||
} catch (err) {
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
async handleDeleteSpaceAttachments(spaceId: string) {
|
||||
try {
|
||||
const attachments = await this.attachmentRepo.findBySpaceId(spaceId);
|
||||
|
||||
@@ -11,6 +11,10 @@ import {
|
||||
Logger,
|
||||
} from '@nestjs/common';
|
||||
import { SkipThrottle, ThrottlerGuard } from '@nestjs/throttler';
|
||||
import {
|
||||
AI_CHAT_THROTTLER,
|
||||
AUTH_THROTTLER,
|
||||
} from '../../integrations/throttle/throttler-names';
|
||||
import { LoginDto } from './dto/login.dto';
|
||||
import { AuthService } from './services/auth.service';
|
||||
import { SessionService } from '../session/session.service';
|
||||
@@ -34,6 +38,7 @@ import {
|
||||
IAuditService,
|
||||
} from '../../integrations/audit/audit.service';
|
||||
|
||||
@SkipThrottle({ [AI_CHAT_THROTTLER]: true })
|
||||
@UseGuards(ThrottlerGuard)
|
||||
@Controller('auth')
|
||||
export class AuthController {
|
||||
@@ -113,7 +118,7 @@ export class AuthController {
|
||||
return workspace;
|
||||
}
|
||||
|
||||
@SkipThrottle()
|
||||
@SkipThrottle({ [AUTH_THROTTLER]: true })
|
||||
@UseGuards(JwtAuthGuard)
|
||||
@HttpCode(HttpStatus.OK)
|
||||
@Post('change-password')
|
||||
@@ -176,7 +181,7 @@ export class AuthController {
|
||||
return this.authService.verifyUserToken(verifyUserTokenDto, workspace.id);
|
||||
}
|
||||
|
||||
@SkipThrottle()
|
||||
@SkipThrottle({ [AUTH_THROTTLER]: true })
|
||||
@UseGuards(JwtAuthGuard)
|
||||
@HttpCode(HttpStatus.OK)
|
||||
@Post('collab-token')
|
||||
@@ -187,7 +192,7 @@ export class AuthController {
|
||||
return this.authService.getCollabToken(user, workspace.id);
|
||||
}
|
||||
|
||||
@SkipThrottle()
|
||||
@SkipThrottle({ [AUTH_THROTTLER]: true })
|
||||
@UseGuards(JwtAuthGuard)
|
||||
@HttpCode(HttpStatus.OK)
|
||||
@Post('logout')
|
||||
|
||||
@@ -142,7 +142,7 @@ export class WorkspaceService {
|
||||
status = WorkspaceStatus.Active;
|
||||
plan = 'standard';
|
||||
billingEmail = user.email;
|
||||
settings = { ai: { generative: true } };
|
||||
settings = { ai: { generative: true, chat: true } };
|
||||
}
|
||||
|
||||
// create workspace
|
||||
|
||||
+58
@@ -3,6 +3,7 @@ import { type Kysely, sql } from 'kysely';
|
||||
export async function up(db: Kysely<any>): Promise<void> {
|
||||
await db.schema
|
||||
.createTable('ai_chats')
|
||||
.ifNotExists()
|
||||
.addColumn('id', 'uuid', (col) =>
|
||||
col.primaryKey().defaultTo(sql`gen_uuid_v7()`),
|
||||
)
|
||||
@@ -19,16 +20,19 @@ export async function up(db: Kysely<any>): Promise<void> {
|
||||
.addColumn('updated_at', 'timestamptz', (col) =>
|
||||
col.notNull().defaultTo(sql`now()`),
|
||||
)
|
||||
.addColumn('deleted_at', 'timestamptz', (col) => col)
|
||||
.execute();
|
||||
|
||||
await db.schema
|
||||
.createIndex('idx_ai_chats_workspace_creator')
|
||||
.ifNotExists()
|
||||
.on('ai_chats')
|
||||
.columns(['workspace_id', 'creator_id', 'id'])
|
||||
.execute();
|
||||
|
||||
await db.schema
|
||||
.createTable('ai_chat_messages')
|
||||
.ifNotExists()
|
||||
.addColumn('id', 'uuid', (col) =>
|
||||
col.primaryKey().defaultTo(sql`gen_uuid_v7()`),
|
||||
)
|
||||
@@ -38,23 +42,77 @@ export async function up(db: Kysely<any>): Promise<void> {
|
||||
.addColumn('workspace_id', 'uuid', (col) =>
|
||||
col.references('workspaces.id').onDelete('cascade').notNull(),
|
||||
)
|
||||
.addColumn('user_id', 'uuid', (col) =>
|
||||
col.references('users.id').onDelete('set null'),
|
||||
)
|
||||
.addColumn('role', 'varchar', (col) => col.notNull())
|
||||
.addColumn('content', 'text', (col) => col)
|
||||
.addColumn('tool_calls', 'jsonb', (col) => col)
|
||||
.addColumn('metadata', 'jsonb', (col) => col)
|
||||
.addColumn('tsv', sql`tsvector`, (col) => col)
|
||||
.addColumn('created_at', 'timestamptz', (col) =>
|
||||
col.notNull().defaultTo(sql`now()`),
|
||||
)
|
||||
.addColumn('updated_at', 'timestamptz', (col) =>
|
||||
col.notNull().defaultTo(sql`now()`),
|
||||
)
|
||||
.addColumn('deleted_at', 'timestamptz', (col) => col)
|
||||
.execute();
|
||||
|
||||
await db.schema
|
||||
.createIndex('idx_ai_chat_messages_chat_id')
|
||||
.ifNotExists()
|
||||
.on('ai_chat_messages')
|
||||
.columns(['chat_id', 'id'])
|
||||
.execute();
|
||||
|
||||
await db.schema
|
||||
.createIndex('idx_ai_chat_messages_tsv')
|
||||
.ifNotExists()
|
||||
.on('ai_chat_messages')
|
||||
.using('GIN')
|
||||
.column('tsv')
|
||||
.execute();
|
||||
|
||||
//ts-vector
|
||||
await sql`
|
||||
CREATE OR REPLACE FUNCTION ai_chat_messages_tsvector_trigger() RETURNS trigger AS $$
|
||||
BEGIN
|
||||
NEW.tsv := to_tsvector('english', f_unaccent(substring(coalesce(NEW.content, ''), 1, 100000)));
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
`.execute(db);
|
||||
|
||||
await sql`
|
||||
CREATE OR REPLACE TRIGGER ai_chat_messages_tsvector_update
|
||||
BEFORE INSERT OR UPDATE ON ai_chat_messages
|
||||
FOR EACH ROW EXECUTE FUNCTION ai_chat_messages_tsvector_trigger();
|
||||
`.execute(db);
|
||||
|
||||
await db.schema
|
||||
.alterTable('attachments')
|
||||
.addColumn('ai_chat_id', 'uuid', (col) => col)
|
||||
.execute();
|
||||
|
||||
await db.schema
|
||||
.createIndex('idx_attachments_ai_chat_id')
|
||||
.ifNotExists()
|
||||
.on('attachments')
|
||||
.column('ai_chat_id')
|
||||
.execute();
|
||||
}
|
||||
|
||||
export async function down(db: Kysely<any>): Promise<void> {
|
||||
await db.schema.dropIndex('idx_attachments_ai_chat_id').execute();
|
||||
await db.schema.alterTable('attachments').dropColumn('ai_chat_id').execute();
|
||||
|
||||
await sql`DROP TRIGGER IF EXISTS ai_chat_messages_tsvector_update ON ai_chat_messages`.execute(
|
||||
db,
|
||||
);
|
||||
await sql`DROP FUNCTION IF EXISTS ai_chat_messages_tsvector_trigger`.execute(
|
||||
db,
|
||||
);
|
||||
await db.schema.dropTable('ai_chat_messages').execute();
|
||||
await db.schema.dropTable('ai_chats').execute();
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
InsertableAttachment,
|
||||
UpdatableAttachment,
|
||||
} from '@docmost/db/types/entity.types';
|
||||
import { AttachmentType } from '../../../core/attachment/attachment.constants';
|
||||
|
||||
@Injectable()
|
||||
export class AttachmentRepo {
|
||||
@@ -23,6 +24,7 @@ export class AttachmentRepo {
|
||||
'creatorId',
|
||||
'pageId',
|
||||
'spaceId',
|
||||
'aiChatId',
|
||||
'workspaceId',
|
||||
'createdAt',
|
||||
'updatedAt',
|
||||
@@ -87,6 +89,21 @@ export class AttachmentRepo {
|
||||
.execute();
|
||||
}
|
||||
|
||||
async findByAiChatId(
|
||||
aiChatId: string,
|
||||
opts?: {
|
||||
trx?: KyselyTransaction;
|
||||
},
|
||||
): Promise<Attachment[]> {
|
||||
const db = dbOrTx(this.db, opts?.trx);
|
||||
|
||||
return db
|
||||
.selectFrom('attachments')
|
||||
.select(this.baseFields)
|
||||
.where('aiChatId', '=', aiChatId)
|
||||
.execute();
|
||||
}
|
||||
|
||||
updateAttachmentsByPageId(
|
||||
updatableAttachment: UpdatableAttachment,
|
||||
pageIds: string[],
|
||||
@@ -112,6 +129,25 @@ export class AttachmentRepo {
|
||||
.executeTakeFirst();
|
||||
}
|
||||
|
||||
async claimAttachmentsForChat(
|
||||
attachmentIds: string[],
|
||||
aiChatId: string,
|
||||
creatorId: string,
|
||||
workspaceId: string,
|
||||
): Promise<void> {
|
||||
if (attachmentIds.length === 0) return;
|
||||
|
||||
await this.db
|
||||
.updateTable('attachments')
|
||||
.set({ aiChatId })
|
||||
.where('id', 'in', attachmentIds)
|
||||
.where('creatorId', '=', creatorId)
|
||||
.where('workspaceId', '=', workspaceId)
|
||||
.where('type', '=', AttachmentType.Chat)
|
||||
.where('aiChatId', 'is', null)
|
||||
.execute();
|
||||
}
|
||||
|
||||
async deleteAttachmentById(attachmentId: string): Promise<void> {
|
||||
await this.db
|
||||
.deleteFrom('attachments')
|
||||
|
||||
+6
@@ -43,6 +43,7 @@ export interface ApiKeys {
|
||||
}
|
||||
|
||||
export interface Attachments {
|
||||
aiChatId: string | null;
|
||||
createdAt: Generated<Timestamp>;
|
||||
creatorId: string;
|
||||
deletedAt: Timestamp | null;
|
||||
@@ -436,17 +437,22 @@ export interface AiChats {
|
||||
title: string | null;
|
||||
createdAt: Generated<Timestamp>;
|
||||
updatedAt: Generated<Timestamp>;
|
||||
deletedAt: Timestamp | null;
|
||||
}
|
||||
|
||||
export interface AiChatMessages {
|
||||
id: Generated<string>;
|
||||
chatId: string;
|
||||
workspaceId: string;
|
||||
userId: string | null;
|
||||
role: string;
|
||||
content: string | null;
|
||||
toolCalls: Json | null;
|
||||
metadata: Json | null;
|
||||
tsv: string | null;
|
||||
createdAt: Generated<Timestamp>;
|
||||
updatedAt: Generated<Timestamp>;
|
||||
deletedAt: Timestamp | null;
|
||||
}
|
||||
|
||||
export interface UserSessions {
|
||||
|
||||
@@ -37,8 +37,14 @@ export type InsertableAiChat = Insertable<AiChats>;
|
||||
export type UpdatableAiChat = Updateable<Omit<AiChats, 'id'>>;
|
||||
|
||||
// AI Chat Message
|
||||
export type AiChatMessage = Selectable<AiChatMessages>;
|
||||
export type InsertableAiChatMessage = Insertable<AiChatMessages>;
|
||||
// `tsv` is an internal tsvector column maintained by a trigger for
|
||||
// full-text search. It is omitted from the public type so it never leaks
|
||||
// into HTTP responses or the chat history fed to the language model.
|
||||
export type AiChatMessage = Omit<Selectable<AiChatMessages>, 'tsv'>;
|
||||
export type InsertableAiChatMessage = Omit<
|
||||
Insertable<AiChatMessages>,
|
||||
'tsv'
|
||||
>;
|
||||
|
||||
// Workspace
|
||||
export type Workspace = Selectable<Workspaces>;
|
||||
|
||||
+1
-1
Submodule apps/server/src/ee updated: a9d3d46869...a3e4e9c72c
@@ -353,7 +353,7 @@ export class ExportService {
|
||||
if (attachmentIds.length > 0) {
|
||||
const attachments = await this.db
|
||||
.selectFrom('attachments')
|
||||
.selectAll()
|
||||
.select(['id', 'fileName', 'filePath'])
|
||||
.where('id', 'in', attachmentIds)
|
||||
.where('spaceId', '=', spaceId)
|
||||
.execute();
|
||||
|
||||
@@ -17,6 +17,7 @@ export enum QueueJob {
|
||||
ATTACHMENT_INDEX_CONTENT = 'attachment-index-content',
|
||||
ATTACHMENT_INDEXING = 'attachment-indexing',
|
||||
DELETE_PAGE_ATTACHMENTS = 'delete-page-attachments',
|
||||
DELETE_AI_CHAT_ATTACHMENTS = 'delete-ai-chat-attachments',
|
||||
|
||||
DELETE_USER_AVATARS = 'delete-user-avatars',
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import { ThrottlerStorageRedisService } from '@nest-lab/throttler-storage-redis'
|
||||
import { EnvironmentService } from '../environment/environment.service';
|
||||
import { EnvironmentModule } from '../environment/environment.module';
|
||||
import { parseRedisUrl } from '../../common/helpers';
|
||||
import { AUTH_THROTTLER, AI_CHAT_THROTTLER } from './throttler-names';
|
||||
import Redis from 'ioredis';
|
||||
|
||||
@Module({
|
||||
@@ -14,7 +15,10 @@ import Redis from 'ioredis';
|
||||
const redisConfig = parseRedisUrl(environmentService.getRedisUrl());
|
||||
|
||||
return {
|
||||
throttlers: [{ name: 'auth', ttl: 60_000, limit: 10 }],
|
||||
throttlers: [
|
||||
{ name: AUTH_THROTTLER, ttl: 60_000, limit: 10 },
|
||||
{ name: AI_CHAT_THROTTLER, ttl: 60_000, limit: 25 },
|
||||
],
|
||||
errorMessage: 'Too many requests',
|
||||
storage: new ThrottlerStorageRedisService(
|
||||
new Redis({
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
export const AUTH_THROTTLER = 'auth';
|
||||
export const AI_CHAT_THROTTLER = 'ai-chat';
|
||||
@@ -0,0 +1,13 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { ThrottlerGuard } from '@nestjs/throttler';
|
||||
|
||||
type AuthedRequest = { user?: { id?: string } };
|
||||
|
||||
@Injectable()
|
||||
export class UserThrottlerGuard extends ThrottlerGuard {
|
||||
protected async getTracker(req: AuthedRequest): Promise<string> {
|
||||
const userId = req.user?.id;
|
||||
if (userId) return `user:${userId}`;
|
||||
return super.getTracker(req as Parameters<ThrottlerGuard['getTracker']>[0]);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user