Merge branch 'main' into feat/scim

This commit is contained in:
Philipinho
2026-02-02 16:43:03 +00:00
484 changed files with 27163 additions and 6019 deletions
+50 -28
View File
@@ -1,6 +1,6 @@
{
"name": "server",
"version": "0.20.4",
"version": "0.25.0-beta.1",
"description": "",
"author": "",
"private": true,
@@ -30,61 +30,84 @@
"test:e2e": "jest --config test/jest-e2e.json"
},
"dependencies": {
"@ai-sdk/google": "^3.0.9",
"@ai-sdk/openai": "^3.0.11",
"@ai-sdk/openai-compatible": "^2.0.12",
"@aws-sdk/client-s3": "3.701.0",
"@aws-sdk/lib-storage": "3.701.0",
"@aws-sdk/s3-request-presigner": "3.701.0",
"@casl/ability": "^6.7.3",
"@fastify/cookie": "^11.0.2",
"@fastify/multipart": "^9.0.3",
"@fastify/static": "^8.2.0",
"@nestjs/bullmq": "^11.0.2",
"@nestjs/common": "^11.1.3",
"@fastify/multipart": "^9.3.0",
"@fastify/static": "^8.3.0",
"@langchain/core": "1.1.13",
"@langchain/textsplitters": "1.0.1",
"@nestjs-labs/nestjs-ioredis": "^11.0.4",
"@nestjs/bullmq": "^11.0.4",
"@nestjs/common": "^11.1.11",
"@nestjs/config": "^4.0.2",
"@nestjs/core": "^11.1.3",
"@nestjs/core": "^11.1.11",
"@nestjs/event-emitter": "^3.0.1",
"@nestjs/jwt": "^11.0.0",
"@nestjs/jwt": "11.0.0",
"@nestjs/mapped-types": "^2.1.0",
"@nestjs/passport": "^11.0.5",
"@nestjs/platform-fastify": "^11.1.3",
"@nestjs/platform-socket.io": "^11.1.3",
"@nestjs/schedule": "^6.0.0",
"@nestjs/platform-fastify": "^11.1.11",
"@nestjs/platform-socket.io": "^11.1.11",
"@nestjs/schedule": "^6.1.0",
"@nestjs/terminus": "^11.0.0",
"@nestjs/websockets": "^11.1.3",
"@node-saml/passport-saml": "^5.0.1",
"@nestjs/websockets": "^11.1.11",
"@node-saml/passport-saml": "^5.1.0",
"@react-email/components": "0.0.28",
"@react-email/render": "1.0.2",
"@socket.io/redis-adapter": "^8.3.0",
"bcrypt": "^5.1.1",
"bullmq": "^5.53.2",
"ai": "^6.0.37",
"ai-sdk-ollama": "^3.1.1",
"bcrypt": "^6.0.0",
"bullmq": "^5.65.0",
"cache-manager": "^6.4.3",
"cheerio": "^1.1.0",
"cheerio": "^1.1.2",
"class-transformer": "^0.5.1",
"class-validator": "^0.14.1",
"cookie": "^1.0.2",
"fs-extra": "^11.3.0",
"happy-dom": "^15.11.6",
"jsonwebtoken": "^9.0.2",
"class-validator": "^0.14.3",
"cookie": "^1.1.1",
"fs-extra": "^11.3.3",
"happy-dom": "20.1.0",
"ioredis": "^5.4.1",
"jsonwebtoken": "^9.0.3",
"kysely": "^0.28.2",
"kysely-migration-cli": "^0.4.2",
"kysely-postgres-js": "^3.0.0",
"ldapts": "^7.4.0",
"lib0": "^0.2.117",
"mammoth": "^1.11.0",
"mime-types": "^2.1.35",
"msgpackr": "^1.11.8",
"nanoid": "3.3.11",
"nestjs-kysely": "^1.2.0",
"nodemailer": "^7.0.3",
"nestjs-pino": "^4.5.0",
"nodemailer": "^7.0.12",
"openid-client": "^5.7.1",
"otpauth": "^9.4.1",
"p-limit": "^6.2.0",
"passport-google-oauth20": "^2.0.0",
"passport-jwt": "^4.0.1",
"pg": "^8.16.0",
"pdfjs-dist": "^5.4.394",
"pg-tsquery": "^8.4.2",
"pgvector": "^0.2.1",
"postgres": "^3.4.8",
"pino-http": "^11.0.0",
"pino-pretty": "^13.1.3",
"postmark": "^4.0.5",
"react": "^18.3.1",
"reflect-metadata": "^0.2.2",
"rxjs": "^7.8.2",
"sanitize-filename-ts": "^1.0.2",
"sanitize-filename-ts": "1.0.2",
"scimmy": "1.3.5",
"socket.io": "^4.8.1",
"sharp": "0.34.3",
"socket.io": "^4.8.3",
"stripe": "^17.5.0",
"tmp-promise": "^3.0.3",
"ws": "^8.18.2",
"tseep": "^1.3.1",
"typesense": "^2.1.0",
"ws": "^8.19.0",
"yauzl": "^3.2.0"
},
"devDependencies": {
@@ -101,7 +124,6 @@
"@types/nodemailer": "^6.4.17",
"@types/passport-google-oauth20": "^2.0.16",
"@types/passport-jwt": "^4.0.1",
"@types/pg": "^8.11.11",
"@types/supertest": "^6.0.2",
"@types/ws": "^8.5.14",
"@types/yauzl": "^2.10.3",
@@ -109,7 +131,7 @@
"eslint-config-prettier": "^10.0.1",
"globals": "^15.15.0",
"jest": "^29.7.0",
"kysely-codegen": "^0.17.0",
"kysely-codegen": "^0.19.0",
"prettier": "^3.5.1",
"react-email": "3.0.2",
"source-map-support": "^0.5.21",
+7
View File
@@ -16,6 +16,9 @@ import { ExportModule } from './integrations/export/export.module';
import { ImportModule } from './integrations/import/import.module';
import { SecurityModule } from './integrations/security/security.module';
import { TelemetryModule } from './integrations/telemetry/telemetry.module';
import { RedisModule } from '@nestjs-labs/nestjs-ioredis';
import { RedisConfigService } from './integrations/redis/redis-config.service';
import { LoggerModule } from './common/logger/logger.module';
const enterpriseModules = [];
try {
@@ -33,9 +36,13 @@ try {
@Module({
imports: [
LoggerModule,
CoreModule,
DatabaseModule,
EnvironmentModule,
RedisModule.forRootAsync({
useClass: RedisConfigService,
}),
CollaborationModule,
WsModule,
QueueModule,
@@ -30,14 +30,22 @@ export class CollabWsAdapter {
return this.wss;
}
public destroy() {
public close() {
try {
this.wss.clients.forEach((client) => {
client.terminate();
});
this.wss.close();
} catch (err) {
console.error(err);
}
}
public destroy() {
try {
this.wss.close();
this.wss.clients.forEach((client) => {
client.terminate();
});
} catch (err) {
console.error(err);
}
}
}
@@ -1,10 +1,9 @@
import { Hocuspocus, Server as HocuspocusServer } from '@hocuspocus/server';
import { Hocuspocus } from '@hocuspocus/server';
import { IncomingMessage } from 'http';
import WebSocket from 'ws';
import { AuthenticationExtension } from './extensions/authentication.extension';
import { PersistenceExtension } from './extensions/persistence.extension';
import { Injectable } from '@nestjs/common';
import { Redis } from '@hocuspocus/extension-redis';
import { EnvironmentService } from '../integrations/environment/environment.service';
import {
createRetryStrategy,
@@ -12,21 +11,41 @@ import {
RedisConfig,
} from '../common/helpers';
import { LoggerExtension } from './extensions/logger.extension';
import {
RedisSyncExtension,
SerializedHTTPRequest,
} from './extensions/redis-sync';
import { WsSocketWrapper } from './extensions/redis-sync/ws-socket-wrapper';
import RedisClient from 'ioredis';
import { pack, unpack } from 'msgpackr';
import { nanoid } from 'nanoid';
import * as os from 'node:os';
import { CollabWsAdapter } from './adapter/collab-ws.adapter';
import {
CollaborationHandler,
CollabEventHandlers,
} from './collaboration.handler';
@Injectable()
export class CollaborationGateway {
private hocuspocus: Hocuspocus;
private readonly hocuspocus: Hocuspocus;
private redisConfig: RedisConfig;
// @ts-ignore
private readonly redisSync: RedisSyncExtension<CollabEventHandlers> | null =
null;
private readonly withRedis: boolean;
constructor(
private authenticationExtension: AuthenticationExtension,
private persistenceExtension: PersistenceExtension,
private loggerExtension: LoggerExtension,
private environmentService: EnvironmentService,
private collabEventsService: CollaborationHandler,
) {
this.redisConfig = parseRedisUrl(this.environmentService.getRedisUrl());
this.withRedis = !this.environmentService.isCollabDisableRedis();
this.hocuspocus = HocuspocusServer.configure({
this.hocuspocus = new Hocuspocus({
debounce: 10000,
maxDebounce: 45000,
unloadImmediately: false,
@@ -34,26 +53,80 @@ export class CollaborationGateway {
this.authenticationExtension,
this.persistenceExtension,
this.loggerExtension,
...(this.environmentService.isCollabDisableRedis()
? []
: [
new Redis({
host: this.redisConfig.host,
port: this.redisConfig.port,
options: {
password: this.redisConfig.password,
db: this.redisConfig.db,
family: this.redisConfig.family,
retryStrategy: createRetryStrategy(),
},
}),
]),
],
});
if (this.withRedis) {
// @ts-ignore
this.redisSync = new RedisSyncExtension({
redis: new RedisClient({
host: this.redisConfig.host,
port: this.redisConfig.port,
password: this.redisConfig.password,
db: this.redisConfig.db,
family: this.redisConfig.family,
retryStrategy: createRetryStrategy(),
}),
serverId: `collab-${os?.hostname()}-${nanoid(10)}`,
prefix: 'collab',
pack,
unpack,
// @ts-ignore
customEvents: this.collabEventsService.getHandlers(this.hocuspocus),
});
this.hocuspocus.configuration.extensions.push(this.redisSync);
// @ts-ignore
this.redisSync.onConfigure({ instance: this.hocuspocus });
}
}
private serializeRequest(request: IncomingMessage): SerializedHTTPRequest {
return {
method: request.method ?? 'GET',
url: request.url ?? '/',
headers: {
'sec-websocket-key': request.headers['sec-websocket-key'] ?? '',
'sec-websocket-protocol':
request.headers['sec-websocket-protocol'] ?? '',
},
socket: { remoteAddress: request.socket?.remoteAddress ?? '' },
};
}
handleConnection(client: WebSocket, request: IncomingMessage): any {
this.hocuspocus.handleConnection(client, request);
if (this.redisSync) {
const serializedHTTPRequest = this.serializeRequest(request);
const socketId = serializedHTTPRequest.headers['sec-websocket-key'];
// Create wrapper socket that only receives events via emit()
// This prevents double-handling since Hocuspocus won't listen to raw WebSocket events
const wrappedSocket = new WsSocketWrapper(client);
// Route through RedisSync extension (this calls handleConnection internally)
this.redisSync.onSocketOpen(wrappedSocket as any, serializedHTTPRequest);
// Forward raw WebSocket messages to the extension
client.on('message', (data: ArrayBuffer) => {
this.redisSync!.onSocketMessage(
wrappedSocket as any,
serializedHTTPRequest,
data,
);
});
// Forward close events
client.on('close', (code: number, reason: Buffer) => {
this.redisSync!.onSocketClose(socketId, code, reason);
});
// Forward pong events for keepalive
client.on('pong', (data: Buffer) => {
wrappedSocket.emit('pong', data);
});
} else {
// Fallback to direct Hocuspocus connection
this.hocuspocus.handleConnection(client, request);
}
}
getConnectionCount() {
@@ -64,7 +137,52 @@ export class CollaborationGateway {
return this.hocuspocus.getDocumentsCount();
}
async destroy(): Promise<void> {
await this.hocuspocus.destroy();
handleYjsEvent<TName extends keyof CollabEventHandlers>(
eventName: TName,
documentName: string,
payload: Parameters<CollabEventHandlers[TName]>[1],
) {
return this.redisSync?.handleEvent(eventName, documentName, payload);
}
openDirectConnection(documentName: string, context?: any) {
return this.hocuspocus.openDirectConnection(documentName, context);
}
/*
*Can be used before calling openDirectConnection directly
*/
async lockDocument(documentName: string) {
return this.redisSync.lockDocument(documentName);
}
/*
*Releases a document lock and stops the interval that maintains it.
*/
async releaseLock(documentName: string) {
return this.redisSync.releaseLock(documentName);
}
async destroy(collabWsAdapter: CollabWsAdapter): Promise<void> {
// eslint-disable-next-line no-async-promise-executor
await new Promise(async (resolve) => {
try {
// Wait for all documents to unload
this.hocuspocus.configuration.extensions.push({
async afterUnloadDocument({ instance }) {
if (instance.getDocumentsCount() === 0) resolve('');
},
});
collabWsAdapter?.close();
if (this.hocuspocus.getDocumentsCount() === 0) resolve('');
this.hocuspocus.closeConnections();
} catch (error) {
console.error(error);
}
});
await this.hocuspocus.hooks('onDestroy', { instance: this.hocuspocus });
}
}
@@ -0,0 +1,42 @@
import { Injectable, Logger } from '@nestjs/common';
import { Hocuspocus, Document } from '@hocuspocus/server';
export type CollabEventHandlers = ReturnType<
CollaborationHandler['getHandlers']
>;
@Injectable()
export class CollaborationHandler {
private readonly logger = new Logger(CollaborationHandler.name);
constructor() {}
getHandlers(hocuspocus: Hocuspocus) {
return {
alterState: async (documentName: string, payload: { pageId: string }) => {
// dummy
// this.logger.log('Processing', documentName, payload);
// await this.withYdocConnection(hocuspocus, documentName, {}, (doc) => {
// const fragment = doc.getXmlFragment('default');
//});
},
};
}
async withYdocConnection(
hocuspocus: Hocuspocus,
documentName: string,
context: any = {},
fn: (doc: Document) => void,
): Promise<void> {
const connection = await hocuspocus.openDirectConnection(
documentName,
context,
);
try {
await connection.transact(fn);
} finally {
await connection.disconnect();
}
}
}
@@ -9,6 +9,7 @@ import { WebSocket } from 'ws';
import { TokenModule } from '../core/auth/token.module';
import { HistoryListener } from './listeners/history.listener';
import { LoggerExtension } from './extensions/logger.extension';
import { CollaborationHandler } from './collaboration.handler';
@Module({
providers: [
@@ -17,6 +18,7 @@ import { LoggerExtension } from './extensions/logger.extension';
PersistenceExtension,
LoggerExtension,
HistoryListener,
CollaborationHandler,
],
exports: [CollaborationGateway],
imports: [TokenModule],
@@ -46,16 +48,12 @@ export class CollaborationModule implements OnModuleInit, OnModuleDestroy {
});
wss.on('error', (error) =>
this.logger.log('WebSocket server error:', error),
this.logger.error('WebSocket server error:', error),
);
}
async onModuleDestroy(): Promise<void> {
if (this.collaborationGateway) {
await this.collaborationGateway.destroy();
}
if (this.collabWsAdapter) {
this.collabWsAdapter.destroy();
}
await this.collaborationGateway?.destroy(this.collabWsAdapter);
this.collabWsAdapter?.destroy();
}
}
@@ -1,18 +1,14 @@
import { StarterKit } from '@tiptap/starter-kit';
import { TextAlign } from '@tiptap/extension-text-align';
import { TaskList } from '@tiptap/extension-task-list';
import { TaskItem } from '@tiptap/extension-task-item';
import { Underline } from '@tiptap/extension-underline';
import { Superscript } from '@tiptap/extension-superscript';
import SubScript from '@tiptap/extension-subscript';
import { Highlight } from '@tiptap/extension-highlight';
import { Typography } from '@tiptap/extension-typography';
import { TextStyle } from '@tiptap/extension-text-style';
import { Color } from '@tiptap/extension-color';
import { Youtube } from '@tiptap/extension-youtube';
import Table from '@tiptap/extension-table';
import TableHeader from '@tiptap/extension-table-header';
import { TaskList, TaskItem } from '@tiptap/extension-list';
import {
Heading,
Callout,
Comment,
CustomCodeBlock,
@@ -22,8 +18,10 @@ import {
LinkExtension,
MathBlock,
MathInline,
TableHeader,
TableCell,
TableRow,
CustomTable,
TiptapImage,
TiptapVideo,
TrailingNode,
@@ -31,25 +29,38 @@ import {
Drawio,
Excalidraw,
Embed,
Mention
Mention,
Subpages,
Highlight,
UniqueID,
addUniqueIdsToDoc,
} from '@docmost/editor-ext';
import { generateText, getSchema, JSONContent } from '@tiptap/core';
import { generateHTML } from '../common/helpers/prosemirror/html';
import { generateHTML, generateJSON } from '../common/helpers/prosemirror/html';
// @tiptap/html library works best for generating prosemirror json state but not HTML
// see: https://github.com/ueberdosis/tiptap/issues/5352
// see:https://github.com/ueberdosis/tiptap/issues/4089
import { generateJSON } from '@tiptap/html';
import { Node } from '@tiptap/pm/model';
//import { generateJSON } from '@tiptap/html';
import { Node, Schema } from '@tiptap/pm/model';
import { Logger } from '@nestjs/common';
export const tiptapExtensions = [
StarterKit.configure({
codeBlock: false,
link: false,
trailingNode: false,
heading: false,
}),
Heading,
UniqueID.configure({
types: ['heading', 'paragraph'],
}),
Comment,
TextAlign.configure({ types: ["heading", "paragraph"] }),
TextAlign.configure({ types: ['heading', 'paragraph'] }),
TaskList,
TaskItem,
Underline,
TaskItem.configure({
nested: true,
}),
LinkExtension,
Superscript,
SubScript,
@@ -63,10 +74,10 @@ export const tiptapExtensions = [
Details,
DetailsContent,
DetailsSummary,
Table,
TableHeader,
TableRow,
CustomTable,
TableCell,
TableRow,
TableHeader,
Youtube,
TiptapImage,
TiptapVideo,
@@ -76,7 +87,8 @@ export const tiptapExtensions = [
Drawio,
Excalidraw,
Embed,
Mention
Mention,
Subpages,
] as any;
export function jsonToHtml(tiptapJson: any) {
@@ -84,7 +96,14 @@ export function jsonToHtml(tiptapJson: any) {
}
export function htmlToJson(html: string) {
return generateJSON(html, tiptapExtensions);
const pmJson = generateJSON(html, tiptapExtensions);
try {
return addUniqueIdsToDoc(pmJson, tiptapExtensions);
} catch (error) {
console.warn('failed to add unique ids to doc', error);
return pmJson;
}
}
export function jsonToText(tiptapJson: JSONContent) {
@@ -92,9 +111,53 @@ export function jsonToText(tiptapJson: JSONContent) {
}
export function jsonToNode(tiptapJson: JSONContent) {
return Node.fromJSON(getSchema(tiptapExtensions), tiptapJson);
const schema = getSchema(tiptapExtensions);
try {
return Node.fromJSON(schema, tiptapJson);
} catch (error) {
if (
error instanceof RangeError &&
error.message.includes('Unknown node type')
) {
Logger.warn('Stripping unknown node types from document:', error.message);
const cleanedJson = stripUnknownNodes(tiptapJson, schema);
return Node.fromJSON(schema, cleanedJson);
}
throw error;
}
}
export function getPageId(documentName: string) {
return documentName.split('.')[1];
}
function stripUnknownNodes(
json: JSONContent,
schema: Schema,
): JSONContent | null {
if (!json || typeof json !== 'object') return json;
// Recursively clean children first, flattening any unwrapped content
if (json.content && Array.isArray(json.content)) {
const newContent: JSONContent[] = [];
for (const child of json.content) {
const cleaned = stripUnknownNodes(child, schema);
if (Array.isArray(cleaned)) {
newContent.push(...cleaned);
} else if (cleaned) {
newContent.push(cleaned);
}
}
json.content = newContent;
}
// Check if this node is unknown AFTER processing children
if (json.type && !schema.nodes[json.type]) {
// Unwrap: return cleaned children directly instead of wrapping
return (
json.content && json.content.length > 0 ? json.content : null
) as any;
}
return json;
}
@@ -46,6 +46,10 @@ export class AuthenticationExtension implements Extension {
throw new UnauthorizedException();
}
if (user.deactivatedAt || user.deletedAt) {
throw new UnauthorizedException();
}
const page = await this.pageRepo.findById(pageId);
if (!page) {
this.logger.warn(`Page not found: ${pageId}`);
@@ -65,7 +69,7 @@ export class AuthenticationExtension implements Extension {
}
if (userSpaceRole === SpaceRole.READER) {
data.connection.readOnly = true;
data.connectionConfig.readOnly = true;
this.logger.debug(`User granted readonly access to page: ${pageId}`);
}
@@ -9,11 +9,11 @@ import { Injectable, Logger } from '@nestjs/common';
export class LoggerExtension implements Extension {
private readonly logger = new Logger('Collab' + LoggerExtension.name);
async onDisconnect(data: onDisconnectPayload) {
this.logger.debug(`User disconnected from "${data.documentName}".`);
}
async afterUnloadDocument(data: onLoadDocumentPayload) {
this.logger.debug('Unloaded ' + data.documentName + ' from memory');
}
async onDisconnect(data: onDisconnectPayload) {
this.logger.debug('User disconnected from ' + data.documentName);
}
}
@@ -35,6 +35,7 @@ export class PersistenceExtension implements Extension {
@InjectKysely() private readonly db: KyselyDB,
private eventEmitter: EventEmitter2,
@InjectQueue(QueueName.GENERAL_QUEUE) private generalQueue: Queue,
@InjectQueue(QueueName.AI_QUEUE) private aiQueue: Queue,
) {}
async onLoadDocument(data: onLoadDocumentPayload) {
@@ -168,6 +169,11 @@ export class PersistenceExtension implements Extension {
workspaceId: page.workspaceId,
mentions: pageMentions,
} as IPageBacklinkJob);
await this.aiQueue.add(QueueJob.PAGE_CONTENT_UPDATED, {
pageIds: [pageId],
workspaceId: page.workspaceId,
});
}
}
@@ -0,0 +1,70 @@
import type RedisClient from 'ioredis';
import { EventEmitter } from 'tseep';
import type {
Pack,
RSAMessageClose,
RSAMessagePing,
RSAMessageSend,
} from './redis-sync.types';
export class CollabProxySocket extends EventEmitter {
private readonly replyTo: string;
private readonly serverChannel: string;
private readonly socketId: string;
private pub: RedisClient;
private readonly pack: Pack;
readyState = 1;
constructor(
pub: RedisClient,
pack: Pack,
replyTo: string,
serverChannel: string,
socketId: string,
) {
super();
this.replyTo = replyTo;
this.socketId = socketId;
this.serverChannel = serverChannel;
this.pub = pub;
this.pack = pack;
this.once('close', () => {
this.readyState = 3;
});
}
private publish(msg: RSAMessageClose | RSAMessagePing | RSAMessageSend) {
this.pub.publish(this.replyTo, this.pack(msg));
}
close(code?: number, reason?: string) {
if (this.readyState !== 1) return;
const msg: RSAMessageClose = {
type: 'close',
code,
reason,
socketId: this.socketId,
};
this.publish(msg);
}
ping() {
if (this.readyState !== 1) return;
const msg: RSAMessagePing = {
type: 'ping',
socketId: this.socketId,
replyTo: this.serverChannel,
};
this.publish(msg);
}
send(message: Uint8Array) {
if (this.readyState !== 1) return;
const msg: RSAMessageSend = {
type: 'send',
socketId: this.socketId,
message,
};
this.publish(msg);
}
}
@@ -0,0 +1,2 @@
export * from './redis-sync.extension';
export type { SerializedHTTPRequest } from './redis-sync.extension';
@@ -0,0 +1,378 @@
// Source https://github.com/ueberdosis/hocuspocus/pull/1008 - MIT
import {
Extension,
Hocuspocus,
IncomingMessage,
afterUnloadDocumentPayload,
onConfigurePayload,
onLoadDocumentPayload,
} from '@hocuspocus/server';
import RedisClient from 'ioredis';
import { readVarString } from 'lib0/decoding.js';
import { CollabProxySocket } from './collab-proxy-socket';
import {
BaseWebSocket,
Configuration,
CustomEvents,
Pack,
RSAMessage,
RSAMessageCloseProxy,
RSAMessageCustomEventComplete,
RSAMessageCustomEventStart,
RSAMessagePong,
RSAMessageProxy,
RSAMessageUnload,
SerializedHTTPRequest,
Unpack,
} from './redis-sync.types';
export type { Pack, SerializedHTTPRequest } from './redis-sync.types';
type ServerId = string;
type DocumentName = string;
type SocketId = string;
export class RedisSyncExtension<TCE extends CustomEvents> implements Extension {
priority = 1000;
private readonly pub: RedisClient;
private sub: RedisClient;
private readonly pack: Pack;
private readonly unpack: Unpack;
private originSockets: Record<SocketId, BaseWebSocket> = {};
private locks: Record<DocumentName, NodeJS.Timeout> = {};
private lockPromises: Record<DocumentName, Promise<ServerId | null>> = {};
private proxySockets: Record<SocketId, CollabProxySocket> = {};
private readonly prefix: string;
private readonly lockPrefix: string;
private readonly msgChannel: string;
private readonly serverId: ServerId;
private readonly customEventTTL: number;
private readonly lockTTL: number;
private instance!: Hocuspocus;
private readonly customEvents: TCE;
private replyIdCounter: number = 0;
// @ts-ignore
private pendingReplies: Record<number, PromiseWithResolvers<any>['resolve']> =
{};
constructor(configuration: Configuration<TCE>) {
const {
redis,
pack,
unpack,
serverId,
lockTTL,
prefix,
customEvents,
customEventTTL,
} = configuration;
this.pub = redis.duplicate();
this.sub = redis.duplicate();
this.pack = pack;
this.unpack = unpack;
this.serverId = serverId;
this.lockTTL = lockTTL ?? 10_000;
this.customEventTTL = customEventTTL ?? 30_000;
this.prefix = prefix ?? 'collab';
this.lockPrefix = `${this.prefix}Lock`;
this.msgChannel = `${this.prefix}Msg`;
this.customEvents = (customEvents as any) ?? ({} as any as CustomEvents);
this.sub.subscribe(this.msgChannel, `${this.msgChannel}:${this.serverId}`);
this.sub.on('messageBuffer', this.handleRedisMessage);
this.pub.on('error', () => {});
this.sub.on('error', () => {});
}
private getKey(documentName: string) {
return `${this.lockPrefix}:${documentName}`;
}
private closeProxy(socketId: string) {
const proxySocket = this.proxySockets[socketId];
if (proxySocket) {
proxySocket.emit(
'close',
1000,
Buffer.from('provider_initiated', 'utf-8'),
);
delete this.proxySockets[socketId];
}
}
private pongProxy(socketId: string) {
this.proxySockets[socketId]?.emit('pong');
}
private handleProxyMessage(
msg: Pick<RSAMessageProxy, 'replyTo' | 'message' | 'serializedHTTPRequest'>,
) {
const { replyTo, message, serializedHTTPRequest } = msg;
const { headers } = serializedHTTPRequest;
const socketId = headers['sec-websocket-key']!;
let socket = this.proxySockets[socketId];
if (!socket) {
socket = new CollabProxySocket(
this.pub,
this.pack,
replyTo,
`${this.msgChannel}:${this.serverId}`,
socketId,
);
this.proxySockets[socketId] = socket;
this.instance.handleConnection(
socket as any,
serializedHTTPRequest as any,
{},
);
}
socket.emit('message', message);
}
private getOrClaimLock(documentName: string) {
const lockPromise = this.pub.set(
this.getKey(documentName),
this.serverId,
'PX',
this.lockTTL,
'NX',
'GET',
);
this.lockPromises[documentName] = lockPromise;
// Briefly cache the serverId that claimed the doc to reduce load on redis
// When the claimant unloads the doc, it will send an unload message to immediately clear this
// a lockTTL / 2 guarantees stale reads < lockTTL upon server crash
setTimeout(() => {
delete this.lockPromises[documentName];
}, this.lockTTL / 2);
return lockPromise;
}
private getOrClaimLockThrottled(documentName: string) {
const existingWorkerIdPromise = this.lockPromises[documentName];
if (existingWorkerIdPromise) return existingWorkerIdPromise;
return this.getOrClaimLock(documentName);
}
private handleRedisMessage = async (
_channel: Buffer,
packedMessage: Buffer,
) => {
const msg = this.unpack(packedMessage) as RSAMessage;
const { type } = msg;
if (type === 'proxy') {
this.handleProxyMessage(msg);
return;
}
if (type === 'closeProxy') {
this.closeProxy(msg.socketId);
return;
}
if (type === 'pong') {
this.pongProxy(msg.socketId);
return;
}
if (type === 'unload') {
delete this.lockPromises[msg.documentName];
return;
}
if (type === 'customEventStart') {
const { documentName, eventName, payload, replyTo, replyId } = msg;
const res = await this.handleEventLocally(
eventName as Extract<keyof TCE, string>,
documentName,
payload,
);
const reply: RSAMessageCustomEventComplete = {
type: 'customEventComplete',
replyId,
payload: res,
};
this.pub.publish(`${replyTo}`, this.pack(reply));
return;
}
if (type === 'customEventComplete') {
const { replyId, payload } = msg;
const resolveFn = this.pendingReplies[replyId];
if (!resolveFn) return;
delete this.pendingReplies[replyId];
resolveFn(payload);
return;
}
const { socketId } = msg;
const socket = this.originSockets[socketId];
if (!socket) {
// origin socket already cleaned up
return;
}
if (type === 'close') {
socket.close(msg.code, msg.reason);
} else if (type === 'ping') {
// Reply instantly to the proxy socket, without forwarding to client
// The origin socket handles heartbeat for itself
const { replyTo, socketId } = msg;
const reply: RSAMessagePong = {
type: 'pong',
socketId,
};
this.pub.publish(`${replyTo}`, this.pack(reply));
} else if (type === 'send') {
socket.send(msg.message);
}
};
async maintainLock(documentName: string) {
this.locks[documentName] = setInterval(() => {
this.pub.set(
this.getKey(documentName),
this.serverId,
'PX',
this.lockTTL,
);
}, this.lockTTL / 2);
}
async releaseLock(documentName: string) {
clearInterval(this.locks[documentName]);
delete this.locks[documentName];
return this.pub.del(this.getKey(documentName));
}
private async handleEventLocally<TName extends Extract<keyof TCE, string>>(
eventName: TName,
documentName: string,
payload: any,
) {
const handler = this.customEvents[eventName];
if (!handler) throw new Error(`Invalid eventName: ${eventName}`);
const result = await handler(documentName, payload);
return result as Promise<ReturnType<TCE[TName]>>;
}
async handleEvent<TName extends Extract<keyof TCE, string>>(
eventName: TName,
documentName: string,
payload: any,
) {
const isDocLoadedOnInstance = this.instance.documents.has(documentName);
if (isDocLoadedOnInstance) {
return this.handleEventLocally(eventName, documentName, payload);
}
const proxyTo = await this.getOrClaimLockThrottled(documentName);
if (proxyTo && proxyTo !== this.serverId) {
++this.replyIdCounter; // bug in biome thinks this.replyIdCounter is not used if written on the line below
const replyId = this.replyIdCounter;
// another server owns the doc
const proxyMessage: RSAMessageCustomEventStart = {
eventName,
documentName,
payload,
replyTo: `${this.msgChannel}:${this.serverId}`,
replyId,
type: 'customEventStart',
};
const msg = this.pack(proxyMessage);
this.pub.publish(`${this.msgChannel}:${proxyTo}`, msg);
// @ts-ignore
const { promise, resolve, reject } = Promise.withResolvers();
this.pendingReplies[replyId] = resolve;
setTimeout(() => {
reject('TIMEOUT');
}, this.customEventTTL);
return promise as Promise<ReturnType<TCE[TName]>>;
}
// This server owns the document, but hocuspocus hasn't loaded it yet
return this.handleEventLocally(eventName, documentName, payload);
}
async lockDocument(documentName: string) {
const proxyTo = await this.getOrClaimLockThrottled(documentName);
if (proxyTo && proxyTo !== this.serverId) {
throw new Error(`Could not lock document: ${documentName}`);
}
this.maintainLock(documentName);
return () => this.releaseLock(documentName);
}
/* WebSocket Server Hooks */
onSocketOpen(
ws: BaseWebSocket,
serializedHTTPRequest: SerializedHTTPRequest,
context = {},
) {
const socketId = serializedHTTPRequest.headers['sec-websocket-key']!;
this.originSockets[socketId] = ws;
this.instance.handleConnection(
ws as any,
serializedHTTPRequest as any,
context,
);
}
async onSocketMessage(
ws: BaseWebSocket,
serializedHTTPRequest: SerializedHTTPRequest,
detachableMsg: ArrayBuffer,
) {
const message = new Uint8Array(detachableMsg.slice());
const tmpMsg = new IncomingMessage(detachableMsg);
const documentName = readVarString(tmpMsg.decoder);
const isDocLoadedOnInstance = this.instance.documents.has(documentName);
if (isDocLoadedOnInstance) {
ws.emit('message', message);
return;
}
const proxyTo = await this.getOrClaimLockThrottled(documentName);
if (proxyTo && proxyTo !== this.serverId) {
// another server owns the doc
const proxyMessage: RSAMessageProxy = {
serializedHTTPRequest: serializedHTTPRequest,
replyTo: `${this.msgChannel}:${this.serverId}`,
message,
type: 'proxy',
};
const msg = this.pack(proxyMessage);
this.pub.publish(`${this.msgChannel}:${proxyTo}`, msg);
return;
}
// This server owns the document, but hocuspocus hasn't loaded it yet
ws.emit('message', message);
}
onSocketClose(socketId: string, code?: number, reason?: ArrayBuffer) {
const socket = this.originSockets[socketId];
if (!socket) return;
// at this point the socket is considered GC'd and we cannot call close
// The origin socket did not set up any connections for the proxy, so none of the hooks will work if we just emit
socket?.emit('close', code, reason);
delete this.originSockets[socketId];
const msg: RSAMessageCloseProxy = { type: 'closeProxy', socketId };
this.pub.publish(this.msgChannel, this.pack(msg)).catch(() => {});
}
/* Hocuspocus hooks */
async onConfigure({ instance }: onConfigurePayload) {
this.instance = instance;
}
async onLoadDocument(data: onLoadDocumentPayload) {
const { documentName } = data;
// Refresh the lock TTL
this.maintainLock(documentName);
}
async afterUnloadDocument(data: afterUnloadDocumentPayload) {
const { documentName } = data;
this.releaseLock(documentName);
// Broadcast to cluster to immediately remove the cached redis value
const msg: RSAMessageUnload = { type: 'unload', documentName };
this.pub.publish(this.msgChannel, this.pack(msg));
}
async onDestroy() {
this.pub.disconnect(false);
this.sub.disconnect(false);
}
}
@@ -0,0 +1,121 @@
import EventEmitter from 'node:events';
import { IncomingHttpHeaders } from 'node:http2';
import RedisClient from 'ioredis';
export type SecondParam<T> = T extends (
arg1: unknown,
arg2: infer A,
...args: unknown[]
) => unknown
? A
: never;
export type SerializedHTTPRequest = {
method: string;
url: string;
headers: IncomingHttpHeaders;
socket: { remoteAddress: string };
};
export type RSAMessageProxy = {
type: 'proxy';
replyTo: string;
message: Uint8Array<ArrayBufferLike>;
serializedHTTPRequest: SerializedHTTPRequest;
};
export type RSAMessageCloseProxy = {
type: 'closeProxy';
socketId: string;
};
export type RSAMessageUnload = {
type: 'unload';
documentName: string;
};
export type RSAMessageClose = {
type: 'close';
code?: number;
reason?: string;
socketId: string;
};
export type RSAMessagePing = {
type: 'ping';
socketId: string;
replyTo: string;
};
export type RSAMessagePong = {
type: 'pong';
socketId: string;
};
export type RSAMessageSend = {
type: 'send';
// @ts-ignore
message: Uint8Array<ArrayBufferLike>;
socketId: string;
};
export type RSAMessageCustomEventStart<TName = string, TPayload = unknown> = {
type: 'customEventStart';
documentName: string;
eventName: TName;
payload: TPayload;
replyTo: string;
replyId: number;
};
export type RSAMessageCustomEventComplete = {
type: 'customEventComplete';
replyId: number;
payload: unknown;
};
export type RSAMessage =
| RSAMessageProxy
| RSAMessageCloseProxy
| RSAMessageUnload
| RSAMessageClose
| RSAMessagePing
| RSAMessagePong
| RSAMessageSend
| RSAMessageCustomEventStart
| RSAMessageCustomEventComplete;
// @ts-ignore
export type Pack = (msg: RSAMessage) => string | Buffer<ArrayBufferLike>;
export type Unpack = (
// @ts-ignore
packedMessage: Uint8Array | Buffer<ArrayBufferLike>,
) => RSAMessage;
type ServerId = string;
type DocumentName = string;
type CustomEventName = string;
export type CustomEvents = Record<
CustomEventName,
(documentName: string, payload: unknown) => Promise<unknown>
>;
export interface Configuration<TCE> {
redis: RedisClient;
pack: Pack;
unpack: Unpack;
serverId: ServerId;
lockTTL?: number;
customEventTTL?: number;
prefix?: string;
customEvents?: TCE;
}
export type BaseWebSocket = EventEmitter & {
readyState: number;
close(code?: number, reason?: string): void;
ping(): void;
send(message: Uint8Array): void;
};
@@ -0,0 +1,47 @@
import { EventEmitter } from 'events';
import type WebSocket from 'ws';
/**
* Wrapper around ws WebSocket that only receives events via emit().
* This prevents double-handling when used with RedisSyncExtension.
*/
export class WsSocketWrapper extends EventEmitter {
private ws: WebSocket;
readyState = 1;
constructor(ws: WebSocket) {
super();
this.ws = ws;
this.once('close', () => {
this.readyState = 3;
});
}
close(code?: number, reason?: string) {
if (this.readyState !== 1) return;
this.readyState = 3;
try {
this.ws.close(code, reason);
} catch (e) {
// Socket already closed
}
}
ping() {
if (this.readyState !== 1) return;
try {
this.ws.ping();
} catch (e) {
// Socket already closed
}
}
send(message: Uint8Array) {
if (this.readyState !== 1) return;
try {
this.ws.send(message);
} catch (e) {
// Socket already closed
}
}
}
@@ -8,9 +8,11 @@ import { QueueModule } from '../../integrations/queue/queue.module';
import { EventEmitterModule } from '@nestjs/event-emitter';
import { HealthModule } from '../../integrations/health/health.module';
import { CollaborationController } from './collaboration.controller';
import { LoggerModule } from '../../common/logger/logger.module';
@Module({
imports: [
LoggerModule,
DatabaseModule,
EnvironmentModule,
CollaborationModule,
@@ -5,22 +5,27 @@ import {
NestFastifyApplication,
} from '@nestjs/platform-fastify';
import { TransformHttpResponseInterceptor } from '../../common/interceptors/http-response.interceptor';
import { InternalLogFilter } from '../../common/logger/internal-log-filter';
import { Logger } from '@nestjs/common';
import { Logger as PinoLogger } from 'nestjs-pino';
async function bootstrap() {
const app = await NestFactory.create<NestFastifyApplication>(
CollabAppModule,
new FastifyAdapter({
ignoreTrailingSlash: true,
ignoreDuplicateSlashes: true,
maxParamLength: 500,
routerOptions: {
maxParamLength: 1000,
ignoreTrailingSlash: true,
ignoreDuplicateSlashes: true,
},
}),
{
logger: new InternalLogFilter(),
logger: false,
bufferLogs: false,
},
);
app.useLogger(app.get(PinoLogger));
app.setGlobalPrefix('api', { exclude: ['/'] });
app.enableCors();
@@ -32,7 +37,8 @@ async function bootstrap() {
const logger = new Logger('CollabServer');
const port = process.env.COLLAB_PORT || 3001;
await app.listen(port, '0.0.0.0', () => {
const host = process.env.HOST || '0.0.0.0';
await app.listen(port, host, () => {
logger.log(`Listening on http://127.0.0.1:${port}`);
});
}
@@ -1,3 +1,18 @@
export enum EventName {
COLLAB_PAGE_UPDATED = 'collab.page.updated',
}
PAGE_CREATED = 'page.created',
PAGE_UPDATED = 'page.updated',
PAGE_CONTENT_UPDATED = 'page-content-updated',
PAGE_MOVED_TO_SPACE = 'page-moved-to-space',
PAGE_DELETED = 'page.deleted',
PAGE_SOFT_DELETED = 'page.soft_deleted',
PAGE_RESTORED = 'page.restored',
SPACE_CREATED = 'space.created',
SPACE_UPDATED = 'space.updated',
SPACE_DELETED = 'space.deleted',
WORKSPACE_CREATED = 'workspace.created',
WORKSPACE_UPDATED = 'workspace.updated',
WORKSPACE_DELETED = 'workspace.deleted',
}
@@ -0,0 +1,71 @@
// https://github.com/WebReflection/html-escaper
/**
* Copyright (C) 2017-present by Andrea Giammarchi - @WebReflection
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
const { replace } = '';
// escape
const es = /&(?:amp|#38|lt|#60|gt|#62|apos|#39|quot|#34);/g;
const ca = /[&<>'"]/g;
const esca = {
'&': '&amp;',
'<': '&lt;',
'>': '&gt;',
"'": '&#39;',
'"': '&quot;',
};
const pe = (m) => esca[m];
/**
* Safely escape HTML entities such as `&`, `<`, `>`, `"`, and `'`.
* @param {string} es the input to safely escape
* @returns {string} the escaped input, and it **throws** an error if
* the input type is unexpected, except for boolean and numbers,
* converted as string.
*/
export const htmlEscape = (es) => replace.call(es, ca, pe);
// unescape
const unes = {
'&amp;': '&',
'&#38;': '&',
'&lt;': '<',
'&#60;': '<',
'&gt;': '>',
'&#62;': '>',
'&apos;': "'",
'&#39;': "'",
'&quot;': '"',
'&#34;': '"',
};
const cape = (m) => unes[m];
/**
* Safely unescape previously escaped entities such as `&`, `<`, `>`, `"`,
* and `'`.
* @param {string} un a previously escaped string
* @returns {string} the unescaped input, and it **throws** an error if
* the input type is unexpected, except for boolean and numbers,
* converted as string.
*/
export const htmlUnescape = (un) => replace.call(un, es, cape);
@@ -1,21 +1,29 @@
import { Extensions, getSchema, JSONContent } from '@tiptap/core';
import { DOMSerializer, Node } from '@tiptap/pm/model';
import { Window } from 'happy-dom';
import { type Extensions, type JSONContent, getSchema } from '@tiptap/core';
import { Node } from '@tiptap/pm/model';
import { getHTMLFromFragment } from './getHTMLFromFragment';
/**
* This function generates HTML from a ProseMirror JSON content object.
*
* @remarks **Important**: This function requires `happy-dom` to be installed in your project.
* @param doc - The ProseMirror JSON content object.
* @param extensions - The Tiptap extensions used to build the schema.
* @returns The generated HTML string.
* @example
* ```js
* const html = generateHTML(doc, extensions)
* console.log(html)
* ```
*/
export function generateHTML(doc: JSONContent, extensions: Extensions): string {
if (typeof window !== 'undefined') {
throw new Error(
'generateHTML can only be used in a Node environment\nIf you want to use this in a browser environment, use the `@tiptap/html` import instead.',
);
}
const schema = getSchema(extensions);
const contentNode = Node.fromJSON(schema, doc);
const window = new Window();
const fragment = DOMSerializer.fromSchema(schema).serializeFragment(
contentNode.content,
{
document: window.document as unknown as Document,
},
);
const serializer = new window.XMLSerializer();
// @ts-ignore
return serializer.serializeToString(fragment as unknown as Node);
return getHTMLFromFragment(contentNode, schema);
}
@@ -1,21 +1,55 @@
import { Extensions, getSchema } from '@tiptap/core';
import { DOMParser, ParseOptions } from '@tiptap/pm/model';
import type { Extensions } from '@tiptap/core';
import { getSchema } from '@tiptap/core';
import { type ParseOptions, DOMParser as PMDOMParser } from '@tiptap/pm/model';
import { Window } from 'happy-dom';
// this function does not work as intended
// it has issues with closing tags
/**
* Generates a JSON object from the given HTML string and converts it into a Prosemirror node with content.
* @remarks **Important**: This function requires `happy-dom` to be installed in your project.
* @param {string} html - The HTML string to be converted into a Prosemirror node.
* @param {Extensions} extensions - The extensions to be used for generating the schema.
* @param {ParseOptions} options - The options to be supplied to the parser.
* @returns {Promise<Record<string, any>>} - A promise with the generated JSON object.
* @example
* const html = '<p>Hello, world!</p>'
* const extensions = [...]
* const json = generateJSON(html, extensions)
* console.log(json) // { type: 'doc', content: [{ type: 'paragraph', content: [{ type: 'text', text: 'Hello, world!' }] }] }
*/
export function generateJSON(
html: string,
extensions: Extensions,
options?: ParseOptions,
): Record<string, any> {
const schema = getSchema(extensions);
if (typeof window !== 'undefined') {
throw new Error(
'generateJSON can only be used in a Node environment\nIf you want to use this in a browser environment, use the `@tiptap/html` import instead.',
);
}
const window = new Window();
const document = window.document;
document.body.innerHTML = html;
const localWindow = new Window();
const localDOMParser = new localWindow.DOMParser();
let result: Record<string, any>;
return DOMParser.fromSchema(schema)
.parse(document as never, options)
.toJSON();
try {
const schema = getSchema(extensions);
let doc: ReturnType<typeof localDOMParser.parseFromString> | null = null;
const htmlString = `<!DOCTYPE html><html><body>${html}</body></html>`;
doc = localDOMParser.parseFromString(htmlString, 'text/html');
if (!doc) {
throw new Error('Failed to parse HTML string');
}
result = PMDOMParser.fromSchema(schema)
.parse(doc.body as unknown as Node, options)
.toJSON();
} finally {
// clean up happy-dom to avoid memory leaks
localWindow.happyDOM.abort();
localWindow.happyDOM.close();
}
return result;
}
@@ -0,0 +1,54 @@
import type { Node, Schema } from '@tiptap/pm/model';
import { DOMSerializer } from '@tiptap/pm/model';
import { Window } from 'happy-dom';
/**
* Returns the HTML string representation of a given document node.
*
* @remarks **Important**: This function requires `happy-dom` to be installed in your project.
* @param doc - The document node to serialize.
* @param schema - The Prosemirror schema to use for serialization.
* @returns A promise containing the HTML string representation of the document fragment.
*
* @example
* ```typescript
* const html = getHTMLFromFragment(doc, schema)
* ```
*/
export function getHTMLFromFragment(
doc: Node,
schema: Schema,
options?: { document?: Document },
): string {
if (options?.document) {
const wrap = options.document.createElement('div');
DOMSerializer.fromSchema(schema).serializeFragment(
doc.content,
{ document: options.document },
wrap,
);
return wrap.innerHTML;
}
const localWindow = new Window();
let result: string;
try {
const fragment = DOMSerializer.fromSchema(schema).serializeFragment(
doc.content,
{
document: localWindow.document as unknown as Document,
},
);
const serializer = new localWindow.XMLSerializer();
result = serializer.serializeToString(fragment as any);
} finally {
// clean up happy-dom to avoid memory leaks
localWindow.happyDOM.abort();
localWindow.happyDOM.close();
}
return result;
}
@@ -0,0 +1,16 @@
export type ExportPageMetadata = {
pageId: string;
slugId: string;
icon: string | null;
position: string;
parentPath: string | null;
createdAt: string;
updatedAt: string;
};
export type ExportMetadata = {
exportedAt: string;
source: 'docmost';
version: string;
pages: Record<string, ExportPageMetadata>;
};
+67 -1
View File
@@ -1,6 +1,8 @@
import * as path from 'path';
import * as bcrypt from 'bcrypt';
import { sanitize } from 'sanitize-filename-ts';
import { FastifyRequest } from 'fastify';
import { Readable, Transform } from 'stream';
export const envPath = path.resolve(process.cwd(), '..', '..', '.env');
@@ -16,6 +18,12 @@ export async function comparePasswordHash(
return bcrypt.compare(plainPassword, passwordHash);
}
export function generateRandomSuffixNumbers(length: number) {
return Math.random()
.toFixed(length)
.substring(2, 2 + length);
}
export type RedisConfig = {
host: string;
port: number;
@@ -65,6 +73,64 @@ export function extractDateFromUuid7(uuid7: string) {
}
export function sanitizeFileName(fileName: string): string {
const sanitizedFilename = sanitize(fileName).replace(/ /g, '_');
const sanitizedFilename = sanitize(fileName)
.replace(/ /g, '_')
.replace(/#/g, '_');
return sanitizedFilename.slice(0, 255);
}
export function removeAccent(str: string): string {
if (!str) return str;
return str.normalize('NFD').replace(/[\u0300-\u036f]/g, '');
}
export function extractBearerTokenFromHeader(
request: FastifyRequest,
): string | undefined {
const [type, token] = request.headers.authorization?.split(' ') ?? [];
return type === 'Bearer' ? token : undefined;
}
export function hasLicenseOrEE(opts: {
licenseKey: string;
plan: string;
isCloud: boolean;
}): boolean {
const { licenseKey, plan, isCloud } = opts;
return Boolean(licenseKey) || (isCloud && plan === 'business');
}
/**
* Normalizes a database URL for postgres.js compatibility.
* - Removes `sslmode=no-verify` (not supported by postgres.js), keeps other sslmode values
* - Removes `schema` parameter (has no effect via connection string)
* Note: If we don't strip them, the connection will fail
*/
export function normalizePostgresUrl(url: string): string {
const parsed = new URL(url);
const newParams = new URLSearchParams();
for (const [key, value] of parsed.searchParams) {
if (key === 'sslmode' && value === 'no-verify') continue;
if (key === 'schema') continue;
newParams.append(key, value);
}
parsed.search = newParams.toString();
return parsed.toString();
}
export function createByteCountingStream(source: Readable) {
let bytesRead = 0;
const stream = new Transform({
transform(chunk, encoding, callback) {
bytesRead += chunk.length;
callback(null, chunk);
},
});
source.pipe(stream);
source.on('error', (err) => stream.emit('error', err));
return { stream, getBytesRead: () => bytesRead };
}
@@ -12,10 +12,14 @@ export class InternalLogFilter extends ConsoleLogger {
constructor() {
super();
this.allowedLogLevels =
process.env.NODE_ENV === 'production'
? ['log', 'error', 'fatal']
: ['log', 'debug', 'verbose', 'warn', 'error', 'fatal'];
const isProduction = process.env.NODE_ENV === 'production';
const isDebugMode = process.env.DEBUG_MODE === 'true';
if (isProduction && !isDebugMode) {
this.allowedLogLevels = ['log', 'error', 'fatal'];
} else {
this.allowedLogLevels = ['log', 'debug', 'verbose', 'warn', 'error', 'fatal'];
}
}
private isLogLevelAllowed(level: string): boolean {
@@ -0,0 +1,9 @@
import { Module } from '@nestjs/common';
import { LoggerModule as PinoLoggerModule } from 'nestjs-pino';
import { createPinoConfig } from './pino.config';
@Module({
imports: [PinoLoggerModule.forRoot(createPinoConfig())],
exports: [PinoLoggerModule],
})
export class LoggerModule {}
@@ -0,0 +1,84 @@
import { Params } from 'nestjs-pino';
import { stdTimeFunctions } from 'pino';
const CONTEXTS_TO_IGNORE = [
'InstanceLoader',
'RoutesResolver',
'RouterExplorer',
'LegacyRouteConverter',
'WebSocketsController',
];
export function createPinoConfig(): Params {
const isProduction = process.env.NODE_ENV?.toLowerCase() === 'production';
const isDebugMode = process.env.DEBUG_MODE?.toLowerCase() === 'true';
const logHttp = process.env.LOG_HTTP?.toLowerCase() === 'true';
const level = isProduction && !isDebugMode ? 'info' : 'debug';
return {
pinoHttp: {
level,
timestamp: stdTimeFunctions.isoTime,
transport: !isProduction
? {
target: 'pino-pretty',
options: {
colorize: true,
singleLine: true,
translateTime: 'SYS:standard',
ignore: 'pid,hostname',
},
}
: undefined,
formatters: {
level: (label) => ({ level: label }),
},
hooks: {
logMethod(inputArgs, method) {
if (isProduction && !isDebugMode) {
for (const arg of inputArgs) {
if (typeof arg === 'object' && arg !== null && 'context' in arg) {
const context = (arg as Record<string, unknown>)['context'];
if (typeof context === 'string' && CONTEXTS_TO_IGNORE.includes(context)) {
return;
}
}
}
}
return method.apply(this, inputArgs);
},
},
serializers: {
req: (req) => {
const forwardedFor = req.headers?.['x-forwarded-for'];
const ip =
req.headers?.['cf-connecting-ip'] ||
(typeof forwardedFor === 'string' ? forwardedFor.split(',')[0]?.trim() : undefined) ||
req.remoteAddress;
return {
method: req.method,
url: req.url,
ip,
userAgent: req.headers?.['user-agent'],
};
},
res: (res) => ({
statusCode: res.statusCode,
}),
},
customLogLevel: (_req, res, err) => {
if (res.statusCode >= 500 || err) return 'error';
if (res.statusCode >= 400) return 'warn';
return 'info';
},
autoLogging: logHttp
? {
ignore: (req) =>
req.url === '/api/health' || req.url === '/api/health/live',
}
: false,
},
};
}
@@ -0,0 +1,34 @@
// MIT - https://github.com/typestack/class-validator/pull/2626
import isISO6391Validator from 'validator/lib/isISO6391';
import { buildMessage, ValidateBy, ValidationOptions } from 'class-validator';
export const IS_ISO6391 = 'isISO6391';
/**
* Check if the string is a valid [ISO 639-1](https://en.wikipedia.org/wiki/ISO_639-1) officially assigned language code.
*/
export function isISO6391(value: unknown): boolean {
return typeof value === 'string' && isISO6391Validator(value);
}
/**
* Check if the string is a valid [ISO 639-1](https://en.wikipedia.org/wiki/ISO_639-1) officially assigned language code.
*/
export function IsISO6391(
validationOptions?: ValidationOptions,
): PropertyDecorator {
return ValidateBy(
{
name: IS_ISO6391,
validator: {
validate: (value, args): boolean => isISO6391(value),
defaultMessage: buildMessage(
(eachPrefix) =>
eachPrefix + '$property must be a valid ISO 639-1 language code',
validationOptions,
),
},
},
validationOptions,
);
}
@@ -1,12 +1,12 @@
export enum AttachmentType {
Avatar = 'avatar',
WorkspaceLogo = 'workspace-logo',
SpaceLogo = 'space-logo',
WorkspaceIcon = 'workspace-icon',
SpaceIcon = 'space-icon',
File = 'file',
}
export const validImageExtensions = ['.jpg', '.png', '.jpeg'];
export const MAX_AVATAR_SIZE = '5MB';
export const MAX_AVATAR_SIZE = '10MB';
export const inlineFileExtensions = [
'.jpg',
@@ -1,5 +1,6 @@
import {
BadRequestException,
Body,
Controller,
ForbiddenException,
Get,
@@ -50,6 +51,8 @@ import { validate as isValidUUID } from 'uuid';
import { EnvironmentService } from '../../integrations/environment/environment.service';
import { TokenService } from '../auth/services/token.service';
import { JwtAttachmentPayload, JwtType } from '../auth/dto/jwt-payload';
import * as path from 'path';
import { RemoveIconDto } from './dto/attachment.dto';
@Controller()
export class AttachmentController {
@@ -178,7 +181,9 @@ export class AttachmentController {
}
try {
const fileStream = await this.storageService.read(attachment.filePath);
const fileStream = await this.storageService.readStream(
attachment.filePath,
);
res.headers({
'Content-Type': attachment.mimeType,
'Cache-Control': 'private, max-age=3600',
@@ -238,7 +243,9 @@ export class AttachmentController {
}
try {
const fileStream = await this.storageService.read(attachment.filePath);
const fileStream = await this.storageService.readStream(
attachment.filePath,
);
res.headers({
'Content-Type': attachment.mimeType,
'Cache-Control': 'public, max-age=3600',
@@ -301,7 +308,7 @@ export class AttachmentController {
throw new BadRequestException('Invalid image attachment type');
}
if (attachmentType === AttachmentType.WorkspaceLogo) {
if (attachmentType === AttachmentType.WorkspaceIcon) {
const ability = this.workspaceAbility.createForUser(user, workspace);
if (
ability.cannot(
@@ -313,7 +320,7 @@ export class AttachmentController {
}
}
if (attachmentType === AttachmentType.SpaceLogo) {
if (attachmentType === AttachmentType.SpaceIcon) {
if (!spaceId) {
throw new BadRequestException('spaceId is required');
}
@@ -356,18 +363,74 @@ export class AttachmentController {
throw new BadRequestException('Invalid image attachment type');
}
const filenameWithoutExt = path.basename(fileName, path.extname(fileName));
if (!isValidUUID(filenameWithoutExt)) {
throw new BadRequestException('Invalid file id');
}
const filePath = `${getAttachmentFolderPath(attachmentType, workspace.id)}/${fileName}`;
try {
const fileStream = await this.storageService.read(filePath);
const fileStream = await this.storageService.readStream(filePath);
res.headers({
'Content-Type': getMimeType(filePath),
'Cache-Control': 'private, max-age=86400',
});
return res.send(fileStream);
} catch (err) {
this.logger.error(err);
// this.logger.error(err);
throw new NotFoundException('File not found');
}
}
@UseGuards(JwtAuthGuard)
@HttpCode(HttpStatus.OK)
@Post('attachments/remove-icon')
async removeIcon(
@Body() dto: RemoveIconDto,
@AuthUser() user: User,
@AuthWorkspace() workspace: Workspace,
) {
const { type, spaceId } = dto;
// remove current user avatar
if (type === AttachmentType.Avatar) {
await this.attachmentService.removeUserAvatar(user);
return;
}
// remove space icon
if (type === AttachmentType.SpaceIcon) {
if (!spaceId) {
throw new BadRequestException(
'spaceId is required to change space icons',
);
}
const spaceAbility = await this.spaceAbility.createForUser(user, spaceId);
if (
spaceAbility.cannot(SpaceCaslAction.Manage, SpaceCaslSubject.Settings)
) {
throw new ForbiddenException();
}
await this.attachmentService.removeSpaceIcon(spaceId, workspace.id);
return;
}
// remove workspace icon
if (type === AttachmentType.WorkspaceIcon) {
const ability = this.workspaceAbility.createForUser(user, workspace);
if (
ability.cannot(
WorkspaceCaslAction.Manage,
WorkspaceCaslSubject.Settings,
)
) {
throw new ForbiddenException();
}
await this.attachmentService.removeWorkspaceIcon(workspace);
return;
}
}
}
@@ -1,19 +1,21 @@
import { MultipartFile } from '@fastify/multipart';
import { randomBytes } from 'crypto';
import { sanitize } from 'sanitize-filename-ts';
import * as path from 'path';
import { AttachmentType } from './attachment.constants';
import { sanitizeFileName } from '../../common/helpers';
import * as sharp from 'sharp';
export interface PreparedFile {
buffer: Buffer;
buffer?: Buffer;
fileName: string;
fileSize: number;
fileExtension: string;
mimeType: string;
multiPartFile?: MultipartFile;
}
export async function prepareFile(
filePromise: Promise<MultipartFile>,
options: { skipBuffer?: boolean } = {},
): Promise<PreparedFile> {
const file = await filePromise;
@@ -22,12 +24,16 @@ export async function prepareFile(
}
try {
const rand = randomBytes(8).toString('hex');
let buffer: Buffer | undefined;
let fileSize = 0;
const buffer = await file.toBuffer();
const sanitizedFilename = sanitize(file.filename).replace(/ /g, '_');
if (!options.skipBuffer) {
buffer = await file.toBuffer();
fileSize = buffer.length;
}
const sanitizedFilename = sanitizeFileName(file.filename);
const fileName = sanitizedFilename.slice(0, 255);
const fileSize = buffer.length;
const fileExtension = path.extname(file.filename).toLowerCase();
return {
@@ -36,6 +42,7 @@ export async function prepareFile(
fileSize,
fileExtension,
mimeType: file.mimetype,
multiPartFile: file,
};
} catch (error) {
throw error;
@@ -58,9 +65,9 @@ export function getAttachmentFolderPath(
switch (type) {
case AttachmentType.Avatar:
return `${workspaceId}/avatars`;
case AttachmentType.WorkspaceLogo:
return `${workspaceId}/workspace-logo`;
case AttachmentType.SpaceLogo:
case AttachmentType.WorkspaceIcon:
return `${workspaceId}/workspace-logos`;
case AttachmentType.SpaceIcon:
return `${workspaceId}/space-logos`;
case AttachmentType.File:
return `${workspaceId}/files`;
@@ -70,3 +77,51 @@ export function getAttachmentFolderPath(
}
export const validAttachmentTypes = Object.values(AttachmentType);
export async function compressAndResizeIcon(
buffer: Buffer,
attachmentType?: AttachmentType,
): Promise<Buffer> {
try {
let sharpInstance = sharp(buffer);
const metadata = await sharpInstance.metadata();
const targetWidth = 300;
const targetHeight = 300;
// Only resize if image is larger than target dimensions
if (metadata.width > targetWidth || metadata.height > targetHeight) {
sharpInstance = sharpInstance.resize(targetWidth, targetHeight, {
fit: 'inside',
withoutEnlargement: true,
});
}
// Handle based on original format
if (metadata.format === 'png') {
// Only flatten avatars to remove transparency
if (attachmentType === AttachmentType.Avatar) {
sharpInstance = sharpInstance.flatten({
background: { r: 255, g: 255, b: 255 },
});
}
return await sharpInstance
.png({
quality: 85,
compressionLevel: 6,
})
.toBuffer();
} else {
return await sharpInstance
.jpeg({
quality: 85,
progressive: true,
mozjpeg: true,
})
.toBuffer();
}
} catch (err) {
throw err;
}
}
@@ -0,0 +1,17 @@
import { IsEnum, IsIn, IsNotEmpty, IsOptional, IsUUID } from 'class-validator';
import { AttachmentType } from '../attachment.constants';
export class RemoveIconDto {
@IsEnum(AttachmentType)
@IsIn([
AttachmentType.Avatar,
AttachmentType.SpaceIcon,
AttachmentType.WorkspaceIcon,
])
@IsNotEmpty()
type: AttachmentType;
@IsOptional()
@IsUUID()
spaceId: string;
}
@@ -1,3 +0,0 @@
import { IsOptional, IsString, IsUUID } from 'class-validator';
export class AvatarUploadDto {}
@@ -1,20 +0,0 @@
import {
IsDefined,
IsNotEmpty,
IsOptional,
IsString,
IsUUID,
} from 'class-validator';
export class UploadFileDto {
@IsString()
@IsNotEmpty()
attachmentType: string;
@IsOptional()
@IsUUID()
pageId: string;
@IsDefined()
file: any;
}
@@ -3,16 +3,19 @@ import { OnWorkerEvent, Processor, WorkerHost } from '@nestjs/bullmq';
import { Job } from 'bullmq';
import { AttachmentService } from '../services/attachment.service';
import { QueueJob, QueueName } from 'src/integrations/queue/constants';
import { Space } from '@docmost/db/types/entity.types';
import { ModuleRef } from '@nestjs/core';
@Processor(QueueName.ATTACHMENT_QUEUE)
export class AttachmentProcessor extends WorkerHost implements OnModuleDestroy {
private readonly logger = new Logger(AttachmentProcessor.name);
constructor(private readonly attachmentService: AttachmentService) {
constructor(
private readonly attachmentService: AttachmentService,
private moduleRef: ModuleRef,
) {
super();
}
async process(job: Job<Space, void>): Promise<void> {
async process(job: Job<any, void>): Promise<void> {
try {
if (job.name === QueueJob.DELETE_SPACE_ATTACHMENTS) {
await this.attachmentService.handleDeleteSpaceAttachments(job.data.id);
@@ -20,6 +23,38 @@ export class AttachmentProcessor extends WorkerHost implements OnModuleDestroy {
if (job.name === QueueJob.DELETE_USER_AVATARS) {
await this.attachmentService.handleDeleteUserAvatars(job.data.id);
}
if (job.name === QueueJob.DELETE_PAGE_ATTACHMENTS) {
await this.attachmentService.handleDeletePageAttachments(
job.data.pageId,
);
}
if (
job.name === QueueJob.ATTACHMENT_INDEX_CONTENT ||
job.name === QueueJob.ATTACHMENT_INDEXING
) {
let AttachmentEeModule: any;
try {
// eslint-disable-next-line @typescript-eslint/no-require-imports
AttachmentEeModule = require('./../../../ee/attachments-ee/attachment-ee.service');
} catch (err) {
this.logger.debug(
'Attachment enterprise module requested but EE module not bundled in this build',
);
return;
}
const attachmentEeService = this.moduleRef.get(
AttachmentEeModule.AttachmentEeService,
{ strict: false },
);
if (job.name === QueueJob.ATTACHMENT_INDEX_CONTENT) {
await attachmentEeService.indexAttachment(job.data.attachmentId);
} else if (job.name === QueueJob.ATTACHMENT_INDEXING) {
await attachmentEeService.indexAttachments(
job.data.workspaceId,
);
}
}
} catch (err) {
throw err;
}
@@ -32,9 +67,15 @@ export class AttachmentProcessor extends WorkerHost implements OnModuleDestroy {
@OnWorkerEvent('failed')
onError(job: Job) {
this.logger.error(
`Error processing ${job.name} job. Reason: ${job.failedReason}`,
);
if (job.name === QueueJob.ATTACHMENT_INDEX_CONTENT) {
this.logger.debug(
`Error processing ${job.name} job for attachment ${job.data?.attachmentId}. Reason: ${job.failedReason}`,
);
} else {
this.logger.error(
`Error processing ${job.name} job. Reason: ${job.failedReason}`,
);
}
}
@OnWorkerEvent('completed')
@@ -4,9 +4,11 @@ import {
Logger,
NotFoundException,
} from '@nestjs/common';
import { Readable } from 'stream';
import { StorageService } from '../../../integrations/storage/storage.service';
import { MultipartFile } from '@fastify/multipart';
import {
compressAndResizeIcon,
getAttachmentFolderPath,
PreparedFile,
prepareFile,
@@ -16,12 +18,16 @@ import { v4 as uuid4, v7 as uuid7 } from 'uuid';
import { AttachmentRepo } from '@docmost/db/repos/attachment/attachment.repo';
import { AttachmentType, validImageExtensions } from '../attachment.constants';
import { KyselyDB, KyselyTransaction } from '@docmost/db/types/kysely.types';
import { Attachment } from '@docmost/db/types/entity.types';
import { Attachment, User, Workspace } from '@docmost/db/types/entity.types';
import { InjectKysely } from 'nestjs-kysely';
import { executeTx } from '@docmost/db/utils';
import { UserRepo } from '@docmost/db/repos/user/user.repo';
import { WorkspaceRepo } from '@docmost/db/repos/workspace/workspace.repo';
import { SpaceRepo } from '@docmost/db/repos/space/space.repo';
import { InjectQueue } from '@nestjs/bullmq';
import { QueueJob, QueueName } from '../../../integrations/queue/constants';
import { Queue } from 'bullmq';
import { createByteCountingStream } from '../../../common/helpers/utils';
@Injectable()
export class AttachmentService {
@@ -33,6 +39,7 @@ export class AttachmentService {
private readonly workspaceRepo: WorkspaceRepo,
private readonly spaceRepo: SpaceRepo,
@InjectKysely() private readonly db: KyselyDB,
@InjectQueue(QueueName.ATTACHMENT_QUEUE) private attachmentQueue: Queue,
) {}
async uploadFile(opts: {
@@ -44,7 +51,9 @@ export class AttachmentService {
attachmentId?: string;
}) {
const { filePromise, pageId, spaceId, userId, workspaceId } = opts;
const preparedFile: PreparedFile = await prepareFile(filePromise);
const preparedFile: PreparedFile = await prepareFile(filePromise, {
skipBuffer: true,
});
let isUpdate = false;
let attachmentId = null;
@@ -76,7 +85,14 @@ export class AttachmentService {
const filePath = `${getAttachmentFolderPath(AttachmentType.File, workspaceId)}/${attachmentId}/${preparedFile.fileName}`;
await this.uploadToDrive(filePath, preparedFile.buffer);
const { stream, getBytesRead } = createByteCountingStream(
preparedFile.multiPartFile.file,
);
await this.uploadToDrive(filePath, stream);
// Update fileSize from the consumed stream
preparedFile.fileSize = getBytesRead();
let attachment: Attachment = null;
try {
@@ -99,6 +115,23 @@ export class AttachmentService {
pageId,
});
}
// Only index PDFs and DOCX files
if (['.pdf', '.docx'].includes(attachment.fileExt.toLowerCase())) {
await this.attachmentQueue.add(
QueueJob.ATTACHMENT_INDEX_CONTENT,
{
attachmentId: attachmentId,
},
{
attempts: 2,
backoff: {
type: 'exponential',
delay: 10000,
},
},
);
}
} catch (err) {
// delete uploaded file on error
this.logger.error(err);
@@ -111,8 +144,8 @@ export class AttachmentService {
filePromise: Promise<MultipartFile>,
type:
| AttachmentType.Avatar
| AttachmentType.WorkspaceLogo
| AttachmentType.SpaceLogo,
| AttachmentType.WorkspaceIcon
| AttachmentType.SpaceIcon,
userId: string,
workspaceId: string,
spaceId?: string,
@@ -120,6 +153,12 @@ export class AttachmentService {
const preparedFile: PreparedFile = await prepareFile(filePromise);
validateFileType(preparedFile.fileExtension, validImageExtensions);
const processedBuffer = await compressAndResizeIcon(
preparedFile.buffer,
type,
);
preparedFile.buffer = processedBuffer;
preparedFile.fileSize = processedBuffer.length;
preparedFile.fileName = uuid4() + preparedFile.fileExtension;
const filePath = `${getAttachmentFolderPath(type, workspaceId)}/${preparedFile.fileName}`;
@@ -153,7 +192,7 @@ export class AttachmentService {
workspaceId,
trx,
);
} else if (type === AttachmentType.WorkspaceLogo) {
} else if (type === AttachmentType.WorkspaceIcon) {
const workspace = await this.workspaceRepo.findById(workspaceId, {
trx,
});
@@ -165,7 +204,7 @@ export class AttachmentService {
workspaceId,
trx,
);
} else if (type === AttachmentType.SpaceLogo && spaceId) {
} else if (type === AttachmentType.SpaceIcon && spaceId) {
const space = await this.spaceRepo.findById(spaceId, workspaceId, {
trx,
});
@@ -184,7 +223,6 @@ export class AttachmentService {
});
} catch (err) {
// delete uploaded file on db update failure
this.logger.error('Image upload error:', err);
await this.deleteRedundantFile(filePath);
throw new BadRequestException('Failed to upload image');
}
@@ -208,9 +246,9 @@ export class AttachmentService {
}
}
async uploadToDrive(filePath: string, fileBuffer: any) {
async uploadToDrive(filePath: string, fileContent: Buffer | Readable) {
try {
await this.storageService.upload(filePath, fileBuffer);
await this.storageService.upload(filePath, fileContent);
} catch (err) {
this.logger.error('Error uploading file to drive:', err);
throw new BadRequestException('Error uploading file to drive');
@@ -321,4 +359,87 @@ export class AttachmentService {
throw err;
}
}
async handleDeletePageAttachments(pageId: string) {
try {
// Fetch attachments for this page from database
const attachments = await this.db
.selectFrom('attachments')
.select(['id', 'filePath'])
.where('pageId', '=', pageId)
.execute();
if (!attachments || attachments.length === 0) {
return;
}
const failedDeletions = [];
await Promise.all(
attachments.map(async (attachment) => {
try {
// Delete from storage
await this.storageService.delete(attachment.filePath);
// Delete from database
await this.attachmentRepo.deleteAttachmentById(attachment.id);
} catch (err) {
failedDeletions.push(attachment.id);
this.logger.error(
`Failed to delete attachment ${attachment.id} for page ${pageId}:`,
err,
);
}
}),
);
if (failedDeletions.length > 0) {
this.logger.warn(
`Failed to delete ${failedDeletions.length} attachments for page ${pageId}`,
);
}
} catch (err) {
this.logger.error(
`Error in handleDeletePageAttachments for page ${pageId}:`,
err,
);
throw err;
}
}
async removeUserAvatar(user: User) {
if (user.avatarUrl && !user.avatarUrl.toLowerCase().startsWith('http')) {
const filePath = `${getAttachmentFolderPath(AttachmentType.Avatar, user.workspaceId)}/${user.avatarUrl}`;
await this.deleteRedundantFile(filePath);
}
await this.userRepo.updateUser(
{ avatarUrl: null },
user.id,
user.workspaceId,
);
}
async removeSpaceIcon(spaceId: string, workspaceId: string) {
const space = await this.spaceRepo.findById(spaceId, workspaceId);
if (!space) {
throw new NotFoundException('Space not found');
}
if (space.logo && !space.logo.toLowerCase().startsWith('http')) {
const filePath = `${getAttachmentFolderPath(AttachmentType.SpaceIcon, workspaceId)}/${space.logo}`;
await this.deleteRedundantFile(filePath);
}
await this.spaceRepo.updateSpace({ logo: null }, spaceId, workspaceId);
}
async removeWorkspaceIcon(workspace: Workspace) {
if (workspace.logo && !workspace.logo.toLowerCase().startsWith('http')) {
const filePath = `${getAttachmentFolderPath(AttachmentType.WorkspaceIcon, workspace.id)}/${workspace.logo}`;
await this.deleteRedundantFile(filePath);
}
await this.workspaceRepo.updateWorkspace({ logo: null }, workspace.id);
}
}
+60 -8
View File
@@ -1,13 +1,12 @@
import {
BadRequestException,
Body,
Controller,
HttpCode,
HttpStatus,
Post,
Req,
Res,
UseGuards,
Logger,
} from '@nestjs/common';
import { LoginDto } from './dto/login.dto';
import { AuthService } from './services/auth.service';
@@ -23,14 +22,17 @@ import { ForgotPasswordDto } from './dto/forgot-password.dto';
import { PasswordResetDto } from './dto/password-reset.dto';
import { VerifyUserTokenDto } from './dto/verify-user-token.dto';
import { FastifyReply } from 'fastify';
import { addDays } from 'date-fns';
import { validateSsoEnforcement } from './auth.util';
import { ModuleRef } from '@nestjs/core';
@Controller('auth')
export class AuthController {
private readonly logger = new Logger(AuthController.name);
constructor(
private authService: AuthService,
private environmentService: EnvironmentService,
private moduleRef: ModuleRef,
) {}
@HttpCode(HttpStatus.OK)
@@ -42,6 +44,45 @@ export class AuthController {
) {
validateSsoEnforcement(workspace);
let MfaModule: any;
let isMfaModuleReady = false;
try {
// eslint-disable-next-line @typescript-eslint/no-require-imports
MfaModule = require('./../../ee/mfa/services/mfa.service');
isMfaModuleReady = true;
} catch (err) {
this.logger.debug(
'MFA module requested but EE module not bundled in this build',
);
isMfaModuleReady = false;
}
if (isMfaModuleReady) {
const mfaService = this.moduleRef.get(MfaModule.MfaService, {
strict: false,
});
const mfaResult = await mfaService.checkMfaRequirements(
loginInput,
workspace,
res,
);
if (mfaResult) {
// If user has MFA enabled OR workspace enforces MFA, require MFA verification
if (mfaResult.userHasMfa || mfaResult.requiresMfaSetup) {
return {
userHasMfa: mfaResult.userHasMfa,
requiresMfaSetup: mfaResult.requiresMfaSetup,
isMfaEnforced: mfaResult.isMfaEnforced,
};
} else if (mfaResult.authToken) {
// User doesn't have MFA and workspace doesn't require it
this.setAuthCookie(res, mfaResult.authToken);
return;
}
}
}
const authToken = await this.authService.login(loginInput, workspace.id);
this.setAuthCookie(res, authToken);
}
@@ -88,11 +129,22 @@ export class AuthController {
@Body() passwordResetDto: PasswordResetDto,
@AuthWorkspace() workspace: Workspace,
) {
const authToken = await this.authService.passwordReset(
const result = await this.authService.passwordReset(
passwordResetDto,
workspace.id,
workspace,
);
this.setAuthCookie(res, authToken);
if (result.requiresLogin) {
return {
requiresLogin: true,
};
}
// Set auth cookie if no MFA is required
this.setAuthCookie(res, result.authToken);
return {
requiresLogin: false,
};
}
@HttpCode(HttpStatus.OK)
@@ -111,7 +163,7 @@ export class AuthController {
@AuthUser() user: User,
@AuthWorkspace() workspace: Workspace,
) {
return this.authService.getCollabToken(user.id, workspace.id);
return this.authService.getCollabToken(user, workspace.id);
}
@UseGuards(JwtAuthGuard)
@@ -125,7 +177,7 @@ export class AuthController {
res.setCookie('authToken', token, {
httpOnly: true,
path: '/',
expires: addDays(new Date(), 30),
expires: this.environmentService.getCookieExpiresIn(),
secure: this.environmentService.isHttps(),
});
}
+13
View File
@@ -6,3 +6,16 @@ export function validateSsoEnforcement(workspace: Workspace) {
throw new BadRequestException('This workspace has enforced SSO login.');
}
}
export function validateAllowedEmail(userEmail: string, workspace: Workspace) {
const emailParts = userEmail.split('@');
const emailDomain = emailParts[1].toLowerCase();
if (
workspace.emailDomains?.length > 0 &&
!workspace.emailDomains.includes(emailDomain)
) {
throw new BadRequestException(
`The email domain "${emailDomain}" is not approved for this workspace.`,
);
}
}
@@ -1,6 +1,12 @@
import { IsNotEmpty, IsString, MaxLength, MinLength } from 'class-validator';
import {
IsNotEmpty,
IsOptional,
IsString,
MaxLength,
MinLength,
} from 'class-validator';
import { CreateUserDto } from './create-user.dto';
import {Transform, TransformFnParams} from "class-transformer";
import { Transform, TransformFnParams } from 'class-transformer';
export class CreateAdminUserDto extends CreateUserDto {
@IsNotEmpty()
@@ -9,10 +15,17 @@ export class CreateAdminUserDto extends CreateUserDto {
@Transform(({ value }: TransformFnParams) => value?.trim())
name: string;
@IsNotEmpty()
@MinLength(3)
@IsOptional()
@MinLength(1)
@MaxLength(50)
@IsString()
@Transform(({ value }: TransformFnParams) => value?.trim())
workspaceName: string;
@IsOptional()
@MinLength(4)
@MaxLength(50)
@IsString()
@Transform(({ value }: TransformFnParams) => value?.trim())
hostname?: string;
}
@@ -3,6 +3,8 @@ export enum JwtType {
COLLAB = 'collab',
EXCHANGE = 'exchange',
ATTACHMENT = 'attachment',
MFA_TOKEN = 'mfa_token',
API_KEY = 'api_key',
}
export type JwtPayload = {
sub: string;
@@ -30,3 +32,15 @@ export type JwtAttachmentPayload = {
type: 'attachment';
};
export interface JwtMfaTokenPayload {
sub: string;
workspaceId: string;
type: 'mfa_token';
}
export type JwtApiKeyPayload = {
sub: string;
workspaceId: string;
apiKeyId: string;
type: 'api_key';
};
@@ -22,7 +22,7 @@ import { ForgotPasswordDto } from '../dto/forgot-password.dto';
import ForgotPasswordEmail from '@docmost/transactional/emails/forgot-password-email';
import { UserTokenRepo } from '@docmost/db/repos/user-token/user-token.repo';
import { PasswordResetDto } from '../dto/password-reset.dto';
import { UserToken, Workspace } from '@docmost/db/types/entity.types';
import { User, UserToken, Workspace } from '@docmost/db/types/entity.types';
import { UserTokenType } from '../auth.constants';
import { KyselyDB } from '@docmost/db/types/kysely.types';
import { InjectKysely } from 'nestjs-kysely';
@@ -47,7 +47,7 @@ export class AuthService {
includePassword: true,
});
const errorMessage = 'email or password does not match';
const errorMessage = 'Email or password does not match';
if (!user || user?.deletedAt) {
throw new UnauthorizedException(errorMessage);
}
@@ -106,6 +106,7 @@ export class AuthService {
await this.userRepo.updateUser(
{
password: newPasswordHash,
hasGeneratedPassword: false,
},
userId,
workspaceId,
@@ -156,10 +157,13 @@ export class AuthService {
});
}
async passwordReset(passwordResetDto: PasswordResetDto, workspaceId: string) {
async passwordReset(
passwordResetDto: PasswordResetDto,
workspace: Workspace,
) {
const userToken = await this.userTokenRepo.findById(
passwordResetDto.token,
workspaceId,
workspace.id,
);
if (
@@ -170,7 +174,9 @@ export class AuthService {
throw new BadRequestException('Invalid or expired token');
}
const user = await this.userRepo.findById(userToken.userId, workspaceId);
const user = await this.userRepo.findById(userToken.userId, workspace.id, {
includeUserMfa: true,
});
if (!user || user.deletedAt) {
throw new NotFoundException('User not found');
}
@@ -181,9 +187,10 @@ export class AuthService {
await this.userRepo.updateUser(
{
password: newPasswordHash,
hasGeneratedPassword: false,
},
user.id,
workspaceId,
workspace.id,
trx,
);
@@ -201,7 +208,18 @@ export class AuthService {
template: emailTemplate,
});
return this.tokenService.generateAccessToken(user);
// Check if user has MFA enabled or workspace enforces MFA
const userHasMfa = user?.['mfa']?.isEnabled || false;
const workspaceEnforcesMfa = workspace.enforceMfa || false;
if (userHasMfa || workspaceEnforcesMfa) {
return {
requiresLogin: true,
};
}
const authToken = await this.tokenService.generateAccessToken(user);
return { authToken };
}
async verifyUserToken(
@@ -222,9 +240,9 @@ export class AuthService {
}
}
async getCollabToken(userId: string, workspaceId: string) {
async getCollabToken(user: User, workspaceId: string) {
const token = await this.tokenService.generateCollabToken(
userId,
user,
workspaceId,
);
return { token };
@@ -92,7 +92,8 @@ export class SignupService {
// create workspace with full setup
const workspaceData: CreateWorkspaceDto = {
name: createAdminUserDto.workspaceName,
name: createAdminUserDto.workspaceName || 'My workspace',
hostname: createAdminUserDto.hostname,
};
workspace = await this.workspaceService.create(
@@ -6,9 +6,11 @@ import {
import { JwtService } from '@nestjs/jwt';
import { EnvironmentService } from '../../../integrations/environment/environment.service';
import {
JwtApiKeyPayload,
JwtAttachmentPayload,
JwtCollabPayload,
JwtExchangePayload,
JwtMfaTokenPayload,
JwtPayload,
JwtType,
} from '../dto/jwt-payload';
@@ -22,7 +24,7 @@ export class TokenService {
) {}
async generateAccessToken(user: User): Promise<string> {
if (user.deletedAt) {
if (user.deactivatedAt || user.deletedAt) {
throw new ForbiddenException();
}
@@ -35,12 +37,13 @@ export class TokenService {
return this.jwtService.sign(payload);
}
async generateCollabToken(
userId: string,
workspaceId: string,
): Promise<string> {
async generateCollabToken(user: User, workspaceId: string): Promise<string> {
if (user.deactivatedAt || user.deletedAt) {
throw new ForbiddenException();
}
const payload: JwtCollabPayload = {
sub: userId,
sub: user.id,
workspaceId,
type: JwtType.COLLAB,
};
@@ -75,6 +78,40 @@ export class TokenService {
return this.jwtService.sign(payload, { expiresIn: '1h' });
}
async generateMfaToken(user: User, workspaceId: string): Promise<string> {
if (user.deactivatedAt || user.deletedAt) {
throw new ForbiddenException();
}
const payload: JwtMfaTokenPayload = {
sub: user.id,
workspaceId,
type: JwtType.MFA_TOKEN,
};
return this.jwtService.sign(payload, { expiresIn: '5m' });
}
async generateApiToken(opts: {
apiKeyId: string;
user: User;
workspaceId: string;
expiresIn?: string | number;
}): Promise<string> {
const { apiKeyId, user, workspaceId, expiresIn } = opts;
if (user.deactivatedAt || user.deletedAt) {
throw new ForbiddenException();
}
const payload: JwtApiKeyPayload = {
sub: user.id,
apiKeyId: apiKeyId,
workspaceId,
type: JwtType.API_KEY,
};
return this.jwtService.sign(payload, expiresIn ? { expiresIn } : {});
}
async verifyJwt(token: string, tokenType: string) {
const payload = await this.jwtService.verifyAsync(token, {
secret: this.environmentService.getAppSecret(),
@@ -2,10 +2,12 @@ import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
import { PassportStrategy } from '@nestjs/passport';
import { Strategy } from 'passport-jwt';
import { EnvironmentService } from '../../../integrations/environment/environment.service';
import { JwtPayload, JwtType } from '../dto/jwt-payload';
import { JwtApiKeyPayload, JwtPayload, JwtType } from '../dto/jwt-payload';
import { WorkspaceRepo } from '@docmost/db/repos/workspace/workspace.repo';
import { UserRepo } from '@docmost/db/repos/user/user.repo';
import { FastifyRequest } from 'fastify';
import { extractBearerTokenFromHeader } from '../../../common/helpers';
import { ModuleRef } from '@nestjs/core';
@Injectable()
export class JwtStrategy extends PassportStrategy(Strategy, 'jwt') {
@@ -15,10 +17,11 @@ export class JwtStrategy extends PassportStrategy(Strategy, 'jwt') {
private userRepo: UserRepo,
private workspaceRepo: WorkspaceRepo,
private readonly environmentService: EnvironmentService,
private moduleRef: ModuleRef,
) {
super({
jwtFromRequest: (req: FastifyRequest) => {
return req.cookies?.authToken || this.extractTokenFromHeader(req);
return req.cookies?.authToken || extractBearerTokenFromHeader(req);
},
ignoreExpiration: false,
secretOrKey: environmentService.getAppSecret(),
@@ -26,8 +29,8 @@ export class JwtStrategy extends PassportStrategy(Strategy, 'jwt') {
});
}
async validate(req: any, payload: JwtPayload) {
if (!payload.workspaceId || payload.type !== JwtType.ACCESS) {
async validate(req: any, payload: JwtPayload | JwtApiKeyPayload) {
if (!payload.workspaceId) {
throw new UnauthorizedException();
}
@@ -35,6 +38,14 @@ export class JwtStrategy extends PassportStrategy(Strategy, 'jwt') {
throw new UnauthorizedException('Workspace does not match');
}
if (payload.type === JwtType.API_KEY) {
return this.validateApiKey(req, payload as JwtApiKeyPayload);
}
if (payload.type !== JwtType.ACCESS) {
throw new UnauthorizedException();
}
const workspace = await this.workspaceRepo.findById(payload.workspaceId);
if (!workspace) {
@@ -42,15 +53,36 @@ export class JwtStrategy extends PassportStrategy(Strategy, 'jwt') {
}
const user = await this.userRepo.findById(payload.sub, payload.workspaceId);
if (!user || user.deletedAt) {
if (!user || user.deactivatedAt || user.deletedAt) {
throw new UnauthorizedException();
}
return { user, workspace };
}
private extractTokenFromHeader(request: FastifyRequest): string | undefined {
const [type, token] = request.headers.authorization?.split(' ') ?? [];
return type === 'Bearer' ? token : undefined;
private async validateApiKey(req: any, payload: JwtApiKeyPayload) {
let ApiKeyModule: any;
let isApiKeyModuleReady = false;
try {
// eslint-disable-next-line @typescript-eslint/no-require-imports
ApiKeyModule = require('./../../../ee/api-key/api-key.service');
isApiKeyModuleReady = true;
} catch (err) {
this.logger.debug(
'API Key module requested but enterprise module not bundled in this build',
);
isApiKeyModuleReady = false;
}
if (isApiKeyModuleReady) {
const ApiKeyService = this.moduleRef.get(ApiKeyModule.ApiKeyService, {
strict: false,
});
return ApiKeyService.validateApiKey(payload);
}
throw new UnauthorizedException('Enterprise API Key module missing');
}
}
@@ -40,6 +40,7 @@ function buildWorkspaceOwnerAbility() {
can(WorkspaceCaslAction.Manage, WorkspaceCaslSubject.Group);
can(WorkspaceCaslAction.Manage, WorkspaceCaslSubject.Member);
can(WorkspaceCaslAction.Manage, WorkspaceCaslSubject.Attachment);
can(WorkspaceCaslAction.Manage, WorkspaceCaslSubject.API);
return build();
}
@@ -55,6 +56,7 @@ function buildWorkspaceAdminAbility() {
can(WorkspaceCaslAction.Manage, WorkspaceCaslSubject.Group);
can(WorkspaceCaslAction.Manage, WorkspaceCaslSubject.Member);
can(WorkspaceCaslAction.Manage, WorkspaceCaslSubject.Attachment);
can(WorkspaceCaslAction.Manage, WorkspaceCaslSubject.API);
return build();
}
@@ -68,6 +70,7 @@ function buildWorkspaceMemberAbility() {
can(WorkspaceCaslAction.Read, WorkspaceCaslSubject.Space);
can(WorkspaceCaslAction.Read, WorkspaceCaslSubject.Group);
can(WorkspaceCaslAction.Manage, WorkspaceCaslSubject.Attachment);
can(WorkspaceCaslAction.Create, WorkspaceCaslSubject.API);
return build();
}
@@ -11,6 +11,7 @@ export enum WorkspaceCaslSubject {
Space = 'space',
Group = 'group',
Attachment = 'attachment',
API = 'api_key',
}
export type IWorkspaceAbility =
@@ -18,4 +19,5 @@ export type IWorkspaceAbility =
| [WorkspaceCaslAction, WorkspaceCaslSubject.Member]
| [WorkspaceCaslAction, WorkspaceCaslSubject.Space]
| [WorkspaceCaslAction, WorkspaceCaslSubject.Group]
| [WorkspaceCaslAction, WorkspaceCaslSubject.Attachment];
| [WorkspaceCaslAction, WorkspaceCaslSubject.Attachment]
| [WorkspaceCaslAction, WorkspaceCaslSubject.API];
@@ -43,7 +43,7 @@ export class CommentController {
@AuthWorkspace() workspace: Workspace,
) {
const page = await this.pageRepo.findById(createCommentDto.pageId);
if (!page) {
if (!page || page.deletedAt) {
throw new NotFoundException('Page not found');
}
@@ -53,9 +53,11 @@ export class CommentController {
}
return this.commentService.create(
user.id,
page.id,
workspace.id,
{
userId: user.id,
page,
workspaceId: workspace.id,
},
createCommentDto,
);
}
@@ -67,7 +69,6 @@ export class CommentController {
@Body()
pagination: PaginationOptions,
@AuthUser() user: User,
// @AuthWorkspace() workspace: Workspace,
) {
const page = await this.pageRepo.findById(input.pageId);
if (!page) {
@@ -89,12 +90,10 @@ export class CommentController {
throw new NotFoundException('Comment not found');
}
const page = await this.pageRepo.findById(comment.pageId);
if (!page) {
throw new NotFoundException('Page not found');
}
const ability = await this.spaceAbility.createForUser(user, page.spaceId);
const ability = await this.spaceAbility.createForUser(
user,
comment.spaceId,
);
if (ability.cannot(SpaceCaslAction.Read, SpaceCaslSubject.Page)) {
throw new ForbiddenException();
}
@@ -103,19 +102,76 @@ export class CommentController {
@HttpCode(HttpStatus.OK)
@Post('update')
update(@Body() updateCommentDto: UpdateCommentDto, @AuthUser() user: User) {
//TODO: only comment creators can update their comments
return this.commentService.update(
updateCommentDto.commentId,
updateCommentDto,
async update(@Body() dto: UpdateCommentDto, @AuthUser() user: User) {
const comment = await this.commentRepo.findById(dto.commentId);
if (!comment) {
throw new NotFoundException('Comment not found');
}
const ability = await this.spaceAbility.createForUser(
user,
comment.spaceId,
);
// must be a space member with edit permission
if (ability.cannot(SpaceCaslAction.Edit, SpaceCaslSubject.Page)) {
throw new ForbiddenException(
'You must have space edit permission to edit comments',
);
}
return this.commentService.update(comment, dto, user);
}
@HttpCode(HttpStatus.OK)
@Post('delete')
remove(@Body() input: CommentIdDto, @AuthUser() user: User) {
// TODO: only comment creators and admins can delete their comments
return this.commentService.remove(input.commentId, user);
async delete(@Body() input: CommentIdDto, @AuthUser() user: User) {
const comment = await this.commentRepo.findById(input.commentId);
if (!comment) {
throw new NotFoundException('Comment not found');
}
const ability = await this.spaceAbility.createForUser(
user,
comment.spaceId,
);
// must be a space member with edit permission
if (ability.cannot(SpaceCaslAction.Edit, SpaceCaslSubject.Page)) {
throw new ForbiddenException();
}
// Check if user is the comment owner
const isOwner = comment.creatorId === user.id;
if (isOwner) {
/*
// Check if comment has children from other users
const hasChildrenFromOthers =
await this.commentRepo.hasChildrenFromOtherUsers(comment.id, user.id);
// Owner can delete if no children from other users
if (!hasChildrenFromOthers) {
await this.commentRepo.deleteComment(comment.id);
return;
}
// If has children from others, only space admin can delete
if (ability.cannot(SpaceCaslAction.Manage, SpaceCaslSubject.Settings)) {
throw new ForbiddenException(
'Only space admins can delete comments with replies from other users',
);
}*/
await this.commentRepo.deleteComment(comment.id);
return;
}
// Space admin can delete any comment
if (ability.cannot(SpaceCaslAction.Manage, SpaceCaslSubject.Settings)) {
throw new ForbiddenException(
'You can only delete your own comments or must be a space admin',
);
}
await this.commentRepo.deleteComment(comment.id);
}
}
+15 -38
View File
@@ -7,10 +7,10 @@ import {
import { CreateCommentDto } from './dto/create-comment.dto';
import { UpdateCommentDto } from './dto/update-comment.dto';
import { CommentRepo } from '@docmost/db/repos/comment/comment.repo';
import { Comment, User } from '@docmost/db/types/entity.types';
import { Comment, Page, User } from '@docmost/db/types/entity.types';
import { PaginationOptions } from '@docmost/db/pagination/pagination-options';
import { PaginationResult } from '@docmost/db/pagination/pagination';
import { PageRepo } from '@docmost/db/repos/page/page.repo';
import { CursorPaginationResult } from '@docmost/db/pagination/cursor-pagination';
@Injectable()
export class CommentService {
@@ -22,6 +22,7 @@ export class CommentService {
async findById(commentId: string) {
const comment = await this.commentRepo.findById(commentId, {
includeCreator: true,
includeResolvedBy: true,
});
if (!comment) {
throw new NotFoundException('Comment not found');
@@ -30,11 +31,10 @@ export class CommentService {
}
async create(
userId: string,
pageId: string,
workspaceId: string,
opts: { userId: string; page: Page; workspaceId: string },
createCommentDto: CreateCommentDto,
) {
const { userId, page, workspaceId } = opts;
const commentContent = JSON.parse(createCommentDto.content);
if (createCommentDto.parentCommentId) {
@@ -42,7 +42,7 @@ export class CommentService {
createCommentDto.parentCommentId,
);
if (!parentComment || parentComment.pageId !== pageId) {
if (!parentComment || parentComment.pageId !== page.id) {
throw new BadRequestException('Parent comment not found');
}
@@ -51,49 +51,38 @@ export class CommentService {
}
}
const createdComment = await this.commentRepo.insertComment({
pageId: pageId,
return await this.commentRepo.insertComment({
pageId: page.id,
content: commentContent,
selection: createCommentDto?.selection?.substring(0, 250),
type: 'inline',
parentCommentId: createCommentDto?.parentCommentId,
creatorId: userId,
workspaceId: workspaceId,
spaceId: page.spaceId,
});
return createdComment;
}
async findByPageId(
pageId: string,
pagination: PaginationOptions,
): Promise<PaginationResult<Comment>> {
): Promise<CursorPaginationResult<Comment>> {
const page = await this.pageRepo.findById(pageId);
if (!page) {
throw new BadRequestException('Page not found');
}
const pageComments = await this.commentRepo.findPageComments(
pageId,
pagination,
);
return pageComments;
return this.commentRepo.findPageComments(pageId, pagination);
}
async update(
commentId: string,
comment: Comment,
updateCommentDto: UpdateCommentDto,
authUser: User,
): Promise<Comment> {
const commentContent = JSON.parse(updateCommentDto.content);
const comment = await this.commentRepo.findById(commentId);
if (!comment) {
throw new NotFoundException('Comment not found');
}
if (comment.creatorId !== authUser.id) {
throw new ForbiddenException('You can only edit your own comments');
}
@@ -104,26 +93,14 @@ export class CommentService {
{
content: commentContent,
editedAt: editedAt,
updatedAt: editedAt,
},
commentId,
comment.id,
);
comment.content = commentContent;
comment.editedAt = editedAt;
comment.updatedAt = editedAt;
return comment;
}
async remove(commentId: string, authUser: User): Promise<void> {
const comment = await this.commentRepo.findById(commentId);
if (!comment) {
throw new NotFoundException('Comment not found');
}
if (comment.creatorId !== authUser.id) {
throw new ForbiddenException('You can only delete your own comments');
}
await this.commentRepo.deleteComment(commentId);
}
}
@@ -11,7 +11,7 @@ import {Transform, TransformFnParams} from "class-transformer";
export class CreateGroupDto {
@MinLength(2)
@MaxLength(50)
@MaxLength(100)
@IsString()
@Transform(({ value }: TransformFnParams) => value?.trim())
name: string;
@@ -11,7 +11,7 @@ import { UpdateGroupDto } from '../dto/update-group.dto';
import { KyselyTransaction } from '@docmost/db/types/kysely.types';
import { GroupRepo } from '@docmost/db/repos/group/group.repo';
import { Group, InsertableGroup, User } from '@docmost/db/types/entity.types';
import { PaginationResult } from '@docmost/db/pagination/pagination';
import { CursorPaginationResult } from '@docmost/db/pagination/cursor-pagination';
import { GroupUserService } from './group-user.service';
@Injectable()
@@ -132,12 +132,8 @@ export class GroupService {
async getWorkspaceGroups(
workspaceId: string,
paginationOptions: PaginationOptions,
): Promise<PaginationResult<Group>> {
const groups = await this.groupRepo.getGroupsPaginated(
workspaceId,
paginationOptions,
);
return groups;
): Promise<CursorPaginationResult<Group>> {
return this.groupRepo.getGroupsPaginated(workspaceId, paginationOptions);
}
async deleteGroup(groupId: string, workspaceId: string): Promise<void> {
@@ -1,7 +1,7 @@
import { IsNotEmpty, IsString } from 'class-validator';
export class GetFileDto {
@IsString()
export class DeletedPageDto {
@IsNotEmpty()
attachmentId: string;
}
@IsString()
spaceId: string;
}
@@ -1,13 +1,13 @@
import { IsString, IsNotEmpty } from 'class-validator';
import { IsString, IsNotEmpty, IsOptional } from 'class-validator';
export class CopyPageToSpaceDto {
export class DuplicatePageDto {
@IsNotEmpty()
@IsString()
pageId: string;
@IsNotEmpty()
@IsOptional()
@IsString()
spaceId: string;
spaceId?: string;
}
export type CopyPageMapEntry = {
@@ -31,3 +31,9 @@ export class PageInfoDto extends PageIdDto {
@IsBoolean()
includeContent: boolean;
}
export class DeletePageDto extends PageIdDto {
@IsOptional()
@IsBoolean()
permanentlyDelete?: boolean;
}
@@ -1,7 +1,11 @@
import { IsOptional, IsString } from 'class-validator';
import { IsOptional, IsString, IsUUID } from 'class-validator';
import { SpaceIdDto } from './page.dto';
export class SidebarPageDto extends SpaceIdDto {
export class SidebarPageDto {
@IsOptional()
@IsUUID()
spaceId: string;
@IsOptional()
@IsString()
pageId: string;
+135 -47
View File
@@ -1,19 +1,24 @@
import {
Controller,
Post,
BadRequestException,
Body,
Controller,
ForbiddenException,
HttpCode,
HttpStatus,
UseGuards,
ForbiddenException,
NotFoundException,
BadRequestException,
Post,
UseGuards,
} from '@nestjs/common';
import { PageService } from './services/page.service';
import { CreatePageDto } from './dto/create-page.dto';
import { UpdatePageDto } from './dto/update-page.dto';
import { MovePageDto, MovePageToSpaceDto } from './dto/move-page.dto';
import { PageHistoryIdDto, PageIdDto, PageInfoDto } from './dto/page.dto';
import {
DeletePageDto,
PageHistoryIdDto,
PageIdDto,
PageInfoDto,
} from './dto/page.dto';
import { PageHistoryService } from './services/page-history.service';
import { AuthUser } from '../../common/decorators/auth-user.decorator';
import { AuthWorkspace } from '../../common/decorators/auth-workspace.decorator';
@@ -28,7 +33,8 @@ import {
import SpaceAbilityFactory from '../casl/abilities/space-ability.factory';
import { PageRepo } from '@docmost/db/repos/page/page.repo';
import { RecentPageDto } from './dto/recent-page.dto';
import { CopyPageToSpaceDto } from './dto/copy-page.dto';
import { DuplicatePageDto } from './dto/duplicate-page.dto';
import { DeletedPageDto } from './dto/deleted-page.dto';
@UseGuards(JwtAuthGuard)
@Controller('pages')
@@ -100,7 +106,47 @@ export class PageController {
@HttpCode(HttpStatus.OK)
@Post('delete')
async delete(@Body() pageIdDto: PageIdDto, @AuthUser() user: User) {
async delete(
@Body() deletePageDto: DeletePageDto,
@AuthUser() user: User,
@AuthWorkspace() workspace: Workspace,
) {
const page = await this.pageRepo.findById(deletePageDto.pageId);
if (!page) {
throw new NotFoundException('Page not found');
}
const ability = await this.spaceAbility.createForUser(user, page.spaceId);
if (deletePageDto.permanentlyDelete) {
// Permanent deletion requires space admin permissions
if (ability.cannot(SpaceCaslAction.Manage, SpaceCaslSubject.Settings)) {
throw new ForbiddenException(
'Only space admins can permanently delete pages',
);
}
await this.pageService.forceDelete(deletePageDto.pageId, workspace.id);
} else {
// Soft delete requires page manage permissions
if (ability.cannot(SpaceCaslAction.Manage, SpaceCaslSubject.Page)) {
throw new ForbiddenException();
}
await this.pageService.removePage(
deletePageDto.pageId,
user.id,
workspace.id,
);
}
}
@HttpCode(HttpStatus.OK)
@Post('restore')
async restore(
@Body() pageIdDto: PageIdDto,
@AuthUser() user: User,
@AuthWorkspace() workspace: Workspace,
) {
const page = await this.pageRepo.findById(pageIdDto.pageId);
if (!page) {
@@ -111,13 +157,12 @@ export class PageController {
if (ability.cannot(SpaceCaslAction.Manage, SpaceCaslSubject.Page)) {
throw new ForbiddenException();
}
await this.pageService.forceDelete(pageIdDto.pageId);
}
@HttpCode(HttpStatus.OK)
@Post('restore')
async restore(@Body() pageIdDto: PageIdDto) {
// await this.pageService.restore(deletePageDto.id);
await this.pageRepo.restorePage(pageIdDto.pageId, workspace.id);
return this.pageRepo.findById(pageIdDto.pageId, {
includeHasChildren: true,
});
}
@HttpCode(HttpStatus.OK)
@@ -146,6 +191,30 @@ export class PageController {
return this.pageService.getRecentPages(user.id, pagination);
}
@HttpCode(HttpStatus.OK)
@Post('trash')
async getDeletedPages(
@Body() deletedPageDto: DeletedPageDto,
@Body() pagination: PaginationOptions,
@AuthUser() user: User,
) {
if (deletedPageDto.spaceId) {
const ability = await this.spaceAbility.createForUser(
user,
deletedPageDto.spaceId,
);
if (ability.cannot(SpaceCaslAction.Manage, SpaceCaslSubject.Page)) {
throw new ForbiddenException();
}
return this.pageService.getDeletedSpacePages(
deletedPageDto.spaceId,
pagination,
);
}
}
// TODO: scope to workspaces
@HttpCode(HttpStatus.OK)
@Post('/history')
@@ -155,6 +224,10 @@ export class PageController {
@AuthUser() user: User,
) {
const page = await this.pageRepo.findById(dto.pageId);
if (!page) {
throw new NotFoundException('Page not found');
}
const ability = await this.spaceAbility.createForUser(user, page.spaceId);
if (ability.cannot(SpaceCaslAction.Read, SpaceCaslSubject.Page)) {
throw new ForbiddenException();
@@ -191,21 +264,28 @@ export class PageController {
@Body() pagination: PaginationOptions,
@AuthUser() user: User,
) {
const ability = await this.spaceAbility.createForUser(user, dto.spaceId);
if (!dto.spaceId && !dto.pageId) {
throw new BadRequestException(
'Either spaceId or pageId must be provided',
);
}
let spaceId = dto.spaceId;
if (dto.pageId) {
const page = await this.pageRepo.findById(dto.pageId);
if (!page) {
throw new ForbiddenException();
}
spaceId = page.spaceId;
}
const ability = await this.spaceAbility.createForUser(user, spaceId);
if (ability.cannot(SpaceCaslAction.Read, SpaceCaslSubject.Page)) {
throw new ForbiddenException();
}
let pageId = null;
if (dto.pageId) {
const page = await this.pageRepo.findById(dto.pageId);
if (page.spaceId !== dto.spaceId) {
throw new ForbiddenException();
}
pageId = page.id;
}
return this.pageService.getSidebarPages(dto.spaceId, pagination, pageId);
return this.pageService.getSidebarPages(spaceId, pagination, dto.pageId);
}
@HttpCode(HttpStatus.OK)
@@ -239,33 +319,41 @@ export class PageController {
}
@HttpCode(HttpStatus.OK)
@Post('copy-to-space')
async copyPageToSpace(
@Body() dto: CopyPageToSpaceDto,
@AuthUser() user: User,
) {
@Post('duplicate')
async duplicatePage(@Body() dto: DuplicatePageDto, @AuthUser() user: User) {
const copiedPage = await this.pageRepo.findById(dto.pageId);
if (!copiedPage) {
throw new NotFoundException('Page to copy not found');
}
if (copiedPage.spaceId === dto.spaceId) {
throw new BadRequestException('Page is already in this space');
// If spaceId is provided, it's a copy to different space
if (dto.spaceId) {
const abilities = await Promise.all([
this.spaceAbility.createForUser(user, copiedPage.spaceId),
this.spaceAbility.createForUser(user, dto.spaceId),
]);
if (
abilities.some((ability) =>
ability.cannot(SpaceCaslAction.Edit, SpaceCaslSubject.Page),
)
) {
throw new ForbiddenException();
}
return this.pageService.duplicatePage(copiedPage, dto.spaceId, user);
} else {
// If no spaceId, it's a duplicate in same space
const ability = await this.spaceAbility.createForUser(
user,
copiedPage.spaceId,
);
if (ability.cannot(SpaceCaslAction.Edit, SpaceCaslSubject.Page)) {
throw new ForbiddenException();
}
return this.pageService.duplicatePage(copiedPage, undefined, user);
}
const abilities = await Promise.all([
this.spaceAbility.createForUser(user, copiedPage.spaceId),
this.spaceAbility.createForUser(user, dto.spaceId),
]);
if (
abilities.some((ability) =>
ability.cannot(SpaceCaslAction.Edit, SpaceCaslSubject.Page),
)
) {
throw new ForbiddenException();
}
return this.pageService.copyPageToSpace(copiedPage, dto.spaceId, user);
}
@HttpCode(HttpStatus.OK)
+3 -2
View File
@@ -2,12 +2,13 @@ import { Module } from '@nestjs/common';
import { PageService } from './services/page.service';
import { PageController } from './page.controller';
import { PageHistoryService } from './services/page-history.service';
import { TrashCleanupService } from './services/trash-cleanup.service';
import { StorageModule } from '../../integrations/storage/storage.module';
@Module({
controllers: [PageController],
providers: [PageService, PageHistoryService],
providers: [PageService, PageHistoryService, TrashCleanupService],
exports: [PageService, PageHistoryService],
imports: [StorageModule]
imports: [StorageModule],
})
export class PageModule {}
@@ -2,7 +2,7 @@ import { Injectable } from '@nestjs/common';
import { PageHistoryRepo } from '@docmost/db/repos/page/page-history.repo';
import { PageHistory } from '@docmost/db/types/entity.types';
import { PaginationOptions } from '@docmost/db/pagination/pagination-options';
import { PaginationResult } from '@docmost/db/pagination/pagination';
import { CursorPaginationResult } from '@docmost/db/pagination/cursor-pagination';
@Injectable()
export class PageHistoryService {
@@ -15,12 +15,10 @@ export class PageHistoryService {
async findHistoryByPageId(
pageId: string,
paginationOptions: PaginationOptions,
): Promise<PaginationResult<any>> {
const pageHistory = await this.pageHistoryRepo.findPageHistoryByPageId(
): Promise<CursorPaginationResult<PageHistory>> {
return this.pageHistoryRepo.findPageHistoryByPageId(
pageId,
paginationOptions,
);
return pageHistory;
}
}
+218 -165
View File
@@ -10,15 +10,13 @@ import { PageRepo } from '@docmost/db/repos/page/page.repo';
import { InsertablePage, Page, User } from '@docmost/db/types/entity.types';
import { PaginationOptions } from '@docmost/db/pagination/pagination-options';
import {
executeWithPagination,
PaginationResult,
} from '@docmost/db/pagination/pagination';
CursorPaginationResult,
executeWithCursorPagination,
} from '@docmost/db/pagination/cursor-pagination';
import { InjectKysely } from 'nestjs-kysely';
import { KyselyDB } from '@docmost/db/types/kysely.types';
import { generateJitteredKeyBetween } from 'fractional-indexing-jittered';
import { MovePageDto } from '../dto/move-page.dto';
import { ExpressionBuilder } from 'kysely';
import { DB } from '@docmost/db/types/db';
import { generateSlugId } from '../../../common/helpers';
import { executeTx } from '@docmost/db/utils';
import { AttachmentRepo } from '@docmost/db/repos/attachment/attachment.repo';
@@ -31,9 +29,17 @@ import {
removeMarkTypeFromDoc,
} from '../../../common/helpers/prosemirror/utils';
import { jsonToNode, jsonToText } from 'src/collaboration/collaboration.util';
import { CopyPageMapEntry, ICopyPageAttachment } from '../dto/copy-page.dto';
import {
CopyPageMapEntry,
ICopyPageAttachment,
} from '../dto/duplicate-page.dto';
import { Node as PMNode } from '@tiptap/pm/model';
import { StorageService } from '../../../integrations/storage/storage.service';
import { InjectQueue } from '@nestjs/bullmq';
import { Queue } from 'bullmq';
import { QueueJob, QueueName } from '../../../integrations/queue/constants';
import { EventName } from '../../../common/events/event.contants';
import { EventEmitter2 } from '@nestjs/event-emitter';
@Injectable()
export class PageService {
@@ -44,6 +50,9 @@ export class PageService {
private attachmentRepo: AttachmentRepo,
@InjectKysely() private readonly db: KyselyDB,
private readonly storageService: StorageService,
@InjectQueue(QueueName.ATTACHMENT_QUEUE) private attachmentQueue: Queue,
@InjectQueue(QueueName.AI_QUEUE) private aiQueue: Queue,
private eventEmitter: EventEmitter2,
) {}
async findById(
@@ -104,7 +113,8 @@ export class PageService {
.selectFrom('pages')
.select(['position'])
.where('spaceId', '=', spaceId)
.orderBy('position', 'desc')
.where('deletedAt', 'is', null)
.orderBy('position', (ob) => ob.collate('C').desc())
.limit(1);
if (parentPageId) {
@@ -166,28 +176,11 @@ export class PageService {
});
}
withHasChildren(eb: ExpressionBuilder<DB, 'pages'>) {
return eb
.selectFrom('pages as child')
.select((eb) =>
eb
.case()
.when(eb.fn.countAll(), '>', 0)
.then(true)
.else(false)
.end()
.as('count'),
)
.whereRef('child.parentPageId', '=', 'pages.id')
.limit(1)
.as('hasChildren');
}
async getSidebarPages(
spaceId: string,
pagination: PaginationOptions,
pageId?: string,
): Promise<any> {
): Promise<CursorPaginationResult<Partial<Page> & { hasChildren: boolean }>> {
let query = this.db
.selectFrom('pages')
.select([
@@ -199,9 +192,10 @@ export class PageService {
'parentPageId',
'spaceId',
'creatorId',
'deletedAt',
])
.select((eb) => this.withHasChildren(eb))
.orderBy('position', 'asc')
.select((eb) => this.pageRepo.withHasChildren(eb))
.where('deletedAt', 'is', null)
.where('spaceId', '=', spaceId);
if (pageId) {
@@ -210,12 +204,19 @@ export class PageService {
query = query.where('parentPageId', 'is', null);
}
const result = executeWithPagination(query, {
page: pagination.page,
return executeWithCursorPagination(query, {
perPage: 250,
cursor: pagination.cursor,
beforeCursor: pagination.beforeCursor,
fields: [
{ expression: 'position', direction: 'asc', orderModifier: (ob) => ob.collate('C').asc() },
{ expression: 'id', direction: 'asc' },
],
parseCursor: (cursor) => ({
position: cursor.position,
id: cursor.id,
}),
});
return result;
}
async movePageToSpace(rootPage: Page, spaceId: string) {
@@ -240,29 +241,54 @@ export class PageService {
);
}
// update spaceId in shares
if (pageIds.length > 0) {
// update spaceId in shares
await trx
.updateTable('shares')
.set({ spaceId: spaceId })
.where('pageId', 'in', pageIds)
.execute();
}
// Update attachments
await this.attachmentRepo.updateAttachmentsByPageId(
{ spaceId },
pageIds,
trx,
);
// Update comments
await trx
.updateTable('comments')
.set({ spaceId: spaceId })
.where('pageId', 'in', pageIds)
.execute();
// Update attachments
await this.attachmentRepo.updateAttachmentsByPageId(
{ spaceId },
pageIds,
trx,
);
await this.aiQueue.add(QueueJob.PAGE_MOVED_TO_SPACE, {
pageId: pageIds,
workspaceId: rootPage.workspaceId,
});
}
});
}
async copyPageToSpace(rootPage: Page, spaceId: string, authUser: User) {
//TODO:
// i. maintain internal links within copied pages
async duplicatePage(
rootPage: Page,
targetSpaceId: string | undefined,
authUser: User,
) {
const spaceId = targetSpaceId || rootPage.spaceId;
const isDuplicateInSameSpace =
!targetSpaceId || targetSpaceId === rootPage.spaceId;
const nextPosition = await this.nextPagePosition(spaceId);
let nextPosition: string;
if (isDuplicateInSameSpace) {
// For duplicate in same space, position right after the original page
nextPosition = generateJitteredKeyBetween(rootPage.position, null);
} else {
// For copy to different space, position at the end
nextPosition = await this.nextPagePosition(spaceId);
}
const pages = await this.pageRepo.getPageAndDescendants(rootPage.id, {
includeContent: true,
@@ -326,12 +352,38 @@ export class PageService {
});
}
// Update internal page links in mention nodes
prosemirrorDoc.descendants((node: PMNode) => {
if (
node.type.name === 'mention' &&
node.attrs.entityType === 'page'
) {
const referencedPageId = node.attrs.entityId;
// Check if the referenced page is within the pages being copied
if (referencedPageId && pageMap.has(referencedPageId)) {
const mappedPage = pageMap.get(referencedPageId);
//@ts-ignore
node.attrs.entityId = mappedPage.newPageId;
//@ts-ignore
node.attrs.slugId = mappedPage.newSlugId;
}
}
});
const prosemirrorJson = prosemirrorDoc.toJSON();
// Add "Copy of " prefix to the root page title only for duplicates in same space
let title = page.title;
if (isDuplicateInSameSpace && page.id === rootPage.id) {
const originalTitle = page.title || 'Untitled';
title = `Copy of ${originalTitle}`;
}
return {
id: pageFromMap.newPageId,
slugId: pageFromMap.newSlugId,
title: page.title,
title: title,
icon: page.icon,
content: prosemirrorJson,
textContent: jsonToText(prosemirrorJson),
@@ -341,15 +393,26 @@ export class PageService {
workspaceId: page.workspaceId,
creatorId: authUser.id,
lastUpdatedById: authUser.id,
parentPageId: page.parentPageId
? pageMap.get(page.parentPageId)?.newPageId
: null,
parentPageId:
page.id === rootPage.id
? isDuplicateInSameSpace
? rootPage.parentPageId
: null
: page.parentPageId
? pageMap.get(page.parentPageId)?.newPageId
: null,
};
}),
);
await this.db.insertInto('pages').values(insertablePages).execute();
const insertedPageIds = insertablePages.map((page) => page.id);
this.eventEmitter.emit(EventName.PAGE_CREATED, {
pageIds: insertedPageIds,
workspaceId: authUser.workspaceId,
});
//TODO: best to handle this in a queue
const attachmentsIds = Array.from(attachmentMap.keys());
if (attachmentsIds.length > 0) {
@@ -377,33 +440,50 @@ export class PageService {
attachment.id,
newAttachmentId,
);
await this.storageService.copy(attachment.filePath, newPathFile);
await this.db
.insertInto('attachments')
.values({
id: newAttachmentId,
type: attachment.type,
filePath: newPathFile,
fileName: attachment.fileName,
fileSize: attachment.fileSize,
mimeType: attachment.mimeType,
fileExt: attachment.fileExt,
creatorId: attachment.creatorId,
workspaceId: attachment.workspaceId,
pageId: newPageId,
spaceId: spaceId,
})
.execute();
try {
await this.storageService.copy(attachment.filePath, newPathFile);
await this.db
.insertInto('attachments')
.values({
id: newAttachmentId,
type: attachment.type,
filePath: newPathFile,
fileName: attachment.fileName,
fileSize: attachment.fileSize,
mimeType: attachment.mimeType,
fileExt: attachment.fileExt,
creatorId: attachment.creatorId,
workspaceId: attachment.workspaceId,
pageId: newPageId,
spaceId: spaceId,
})
.execute();
} catch (err) {
this.logger.error(
`Duplicate page: failed to copy attachment ${attachment.id}`,
err,
);
// Continue with other attachments even if one fails
}
} catch (err) {
this.logger.log(err);
this.logger.error(err);
}
}
}
const newPageId = pageMap.get(rootPage.id).newPageId;
return await this.pageRepo.findById(newPageId, {
const duplicatedPage = await this.pageRepo.findById(newPageId, {
includeSpace: true,
});
const hasChildren = pages.length > 1;
return {
...duplicatedPage,
hasChildren,
};
}
async movePage(dto: MovePageDto, movedPage: Page) {
@@ -450,9 +530,11 @@ export class PageService {
'position',
'parentPageId',
'spaceId',
'deletedAt',
])
.select((eb) => this.withHasChildren(eb))
.select((eb) => this.pageRepo.withHasChildren(eb))
.where('id', '=', childPageId)
.where('deletedAt', 'is', null)
.unionAll((exp) =>
exp
.selectFrom('pages as p')
@@ -464,6 +546,7 @@ export class PageService {
'p.position',
'p.parentPageId',
'p.spaceId',
'p.deletedAt',
])
.select(
exp
@@ -478,11 +561,13 @@ export class PageService {
.as('count'),
)
.whereRef('child.parentPageId', '=', 'id')
.where('child.deletedAt', 'is', null)
.limit(1)
.as('hasChildren'),
)
//.select((eb) => this.withHasChildren(eb))
.innerJoin('page_ancestors as pa', 'pa.parentPageId', 'p.id'),
.innerJoin('page_ancestors as pa', 'pa.parentPageId', 'p.id')
.where('p.deletedAt', 'is', null),
),
)
.selectFrom('page_ancestors')
@@ -495,109 +580,77 @@ export class PageService {
async getRecentSpacePages(
spaceId: string,
pagination: PaginationOptions,
): Promise<PaginationResult<Page>> {
return await this.pageRepo.getRecentPagesInSpace(spaceId, pagination);
): Promise<CursorPaginationResult<Page>> {
return this.pageRepo.getRecentPagesInSpace(spaceId, pagination);
}
async getRecentPages(
userId: string,
pagination: PaginationOptions,
): Promise<PaginationResult<Page>> {
return await this.pageRepo.getRecentPages(userId, pagination);
): Promise<CursorPaginationResult<Page>> {
return this.pageRepo.getRecentPages(userId, pagination);
}
async forceDelete(pageId: string): Promise<void> {
await this.pageRepo.deletePage(pageId);
async getDeletedSpacePages(
spaceId: string,
pagination: PaginationOptions,
): Promise<CursorPaginationResult<Page>> {
return this.pageRepo.getDeletedPagesInSpace(spaceId, pagination);
}
async forceDelete(pageId: string, workspaceId: string): Promise<void> {
// Get all descendant IDs (including the page itself) using recursive CTE
const descendants = await this.db
.withRecursive('page_descendants', (db) =>
db
.selectFrom('pages')
.select(['id'])
.where('id', '=', pageId)
.unionAll((exp) =>
exp
.selectFrom('pages as p')
.select(['p.id'])
.innerJoin('page_descendants as pd', 'pd.id', 'p.parentPageId'),
),
)
.selectFrom('page_descendants')
.selectAll()
.execute();
const pageIds = descendants.map((d) => d.id);
// Queue attachment deletion for all pages with unique job IDs to prevent duplicates
for (const id of pageIds) {
await this.attachmentQueue.add(
QueueJob.DELETE_PAGE_ATTACHMENTS,
{
pageId: id,
},
{
jobId: `delete-page-attachments-${id}`,
attempts: 3,
backoff: {
type: 'exponential',
delay: 5000,
},
},
);
}
if (pageIds.length > 0) {
await this.db.deleteFrom('pages').where('id', 'in', pageIds).execute();
this.eventEmitter.emit(EventName.PAGE_DELETED, {
pageIds: pageIds,
workspaceId,
});
}
}
async removePage(
pageId: string,
userId: string,
workspaceId: string,
): Promise<void> {
await this.pageRepo.removePage(pageId, userId, workspaceId);
}
}
/*
// TODO: page deletion and restoration
async delete(pageId: string): Promise<void> {
await this.dataSource.transaction(async (manager: EntityManager) => {
const page = await manager
.createQueryBuilder(Page, 'page')
.where('page.id = :pageId', { pageId })
.select(['page.id', 'page.workspaceId'])
.getOne();
if (!page) {
throw new NotFoundException(`Page not found`);
}
await this.softDeleteChildrenRecursive(page.id, manager);
await this.pageOrderingService.removePageFromHierarchy(page, manager);
await manager.softDelete(Page, pageId);
});
}
private async softDeleteChildrenRecursive(
parentId: string,
manager: EntityManager,
): Promise<void> {
const childrenPage = await manager
.createQueryBuilder(Page, 'page')
.where('page.parentPageId = :parentId', { parentId })
.select(['page.id', 'page.title', 'page.parentPageId'])
.getMany();
for (const child of childrenPage) {
await this.softDeleteChildrenRecursive(child.id, manager);
await manager.softDelete(Page, child.id);
}
}
async restore(pageId: string): Promise<void> {
await this.dataSource.transaction(async (manager: EntityManager) => {
const isDeleted = await manager
.createQueryBuilder(Page, 'page')
.where('page.id = :pageId', { pageId })
.withDeleted()
.getCount();
if (!isDeleted) {
return;
}
await manager.recover(Page, { id: pageId });
await this.restoreChildrenRecursive(pageId, manager);
// Fetch the page details to find out its parent and workspace
const restoredPage = await manager
.createQueryBuilder(Page, 'page')
.where('page.id = :pageId', { pageId })
.select(['page.id', 'page.title', 'page.spaceId', 'page.parentPageId'])
.getOne();
if (!restoredPage) {
throw new NotFoundException(`Restored page not found.`);
}
// add page back to its hierarchy
await this.pageOrderingService.addPageToOrder(
restoredPage.spaceId,
pageId,
restoredPage.parentPageId,
);
});
}
private async restoreChildrenRecursive(
parentId: string,
manager: EntityManager,
): Promise<void> {
const childrenPage = await manager
.createQueryBuilder(Page, 'page')
.setLock('pessimistic_write')
.where('page.parentPageId = :parentId', { parentId })
.select(['page.id', 'page.title', 'page.parentPageId'])
.withDeleted()
.getMany();
for (const child of childrenPage) {
await this.restoreChildrenRecursive(child.id, manager);
await manager.recover(Page, { id: child.id });
}
}
*/
@@ -0,0 +1,116 @@
import { Injectable, Logger } from '@nestjs/common';
import { Interval } from '@nestjs/schedule';
import { InjectKysely } from 'nestjs-kysely';
import { KyselyDB } from '@docmost/db/types/kysely.types';
import { InjectQueue } from '@nestjs/bullmq';
import { Queue } from 'bullmq';
import { QueueJob, QueueName } from '../../../integrations/queue/constants';
@Injectable()
export class TrashCleanupService {
private readonly logger = new Logger(TrashCleanupService.name);
private readonly RETENTION_DAYS = 30;
constructor(
@InjectKysely() private readonly db: KyselyDB,
@InjectQueue(QueueName.ATTACHMENT_QUEUE) private attachmentQueue: Queue,
) {}
@Interval('trash-cleanup', 24 * 60 * 60 * 1000) // every 24 hours
async cleanupOldTrash() {
try {
this.logger.debug('Starting trash cleanup job');
const retentionDate = new Date();
retentionDate.setDate(retentionDate.getDate() - this.RETENTION_DAYS);
// Get all pages that were deleted more than 30 days ago
const oldDeletedPages = await this.db
.selectFrom('pages')
.select(['id', 'spaceId', 'workspaceId'])
.where('deletedAt', '<', retentionDate)
.execute();
if (oldDeletedPages.length === 0) {
this.logger.debug('No old trash items to clean up');
return;
}
this.logger.debug(`Found ${oldDeletedPages.length} pages to clean up`);
// Process each page
for (const page of oldDeletedPages) {
try {
await this.cleanupPage(page.id);
} catch (error) {
this.logger.error(
`Failed to cleanup page ${page.id}: ${error instanceof Error ? error.message : 'Unknown error'}`,
error instanceof Error ? error.stack : undefined,
);
}
}
this.logger.debug('Trash cleanup job completed');
} catch (error) {
this.logger.error(
'Trash cleanup job failed',
error instanceof Error ? error.stack : undefined,
);
}
}
private async cleanupPage(pageId: string) {
// Get all descendants using recursive CTE (including the page itself)
const descendants = await this.db
.withRecursive('page_descendants', (db) =>
db
.selectFrom('pages')
.select(['id'])
.where('id', '=', pageId)
.unionAll((exp) =>
exp
.selectFrom('pages as p')
.select(['p.id'])
.innerJoin('page_descendants as pd', 'pd.id', 'p.parentPageId'),
),
)
.selectFrom('page_descendants')
.selectAll()
.execute();
const pageIds = descendants.map((d) => d.id);
this.logger.debug(
`Cleaning up page ${pageId} with ${pageIds.length - 1} descendants`,
);
// Queue attachment deletion for all pages with unique job IDs to prevent duplicates
for (const id of pageIds) {
await this.attachmentQueue.add(
QueueJob.DELETE_PAGE_ATTACHMENTS,
{
pageId: id,
},
{
jobId: `delete-page-attachments-${id}`,
attempts: 3,
backoff: {
type: 'exponential',
delay: 5000,
},
},
);
}
try {
if (pageIds.length > 0) {
await this.db.deleteFrom('pages').where('id', 'in', pageIds).execute();
}
} catch (error) {
// Log but don't throw - pages might have been deleted by another node
this.logger.warn(
`Error deleting pages, they may have been already deleted: ${error instanceof Error ? error.message : 'Unknown error'}`,
);
}
}
}
@@ -1,3 +1,5 @@
import { Space } from '@docmost/db/types/entity.types';
export class SearchResponseDto {
id: string;
title: string;
@@ -8,4 +10,5 @@ export class SearchResponseDto {
highlight: string;
createdAt: Date;
updatedAt: Date;
space: Partial<Space>;
}
@@ -5,15 +5,13 @@ import {
IsOptional,
IsString,
} from 'class-validator';
import { PartialType } from '@nestjs/mapped-types';
import { CreateWorkspaceDto } from '../../workspace/dto/create-workspace.dto';
export class SearchDTO {
@IsNotEmpty()
@IsString()
query: string;
@IsNotEmpty()
@IsOptional()
@IsString()
spaceId: string;
@@ -5,6 +5,7 @@ import {
ForbiddenException,
HttpCode,
HttpStatus,
Logger,
Post,
UseGuards,
} from '@nestjs/common';
@@ -24,13 +25,19 @@ import {
} from '../casl/interfaces/space-ability.type';
import { AuthUser } from '../../common/decorators/auth-user.decorator';
import { Public } from 'src/common/decorators/public.decorator';
import { EnvironmentService } from '../../integrations/environment/environment.service';
import { ModuleRef } from '@nestjs/core';
@UseGuards(JwtAuthGuard)
@Controller('search')
export class SearchController {
private readonly logger = new Logger(SearchController.name);
constructor(
private readonly searchService: SearchService,
private readonly spaceAbility: SpaceAbilityFactory,
private readonly environmentService: EnvironmentService,
private moduleRef: ModuleRef,
) {}
@HttpCode(HttpStatus.OK)
@@ -53,7 +60,14 @@ export class SearchController {
}
}
return this.searchService.searchPage(searchDto.query, searchDto, {
if (this.environmentService.getSearchDriver() === 'typesense') {
return this.searchTypesense(searchDto, {
userId: user.id,
workspaceId: workspace.id,
});
}
return this.searchService.searchPage(searchDto, {
userId: user.id,
workspaceId: workspace.id,
});
@@ -81,8 +95,47 @@ export class SearchController {
throw new BadRequestException('shareId is required');
}
return this.searchService.searchPage(searchDto.query, searchDto, {
if (this.environmentService.getSearchDriver() === 'typesense') {
return this.searchTypesense(searchDto, {
workspaceId: workspace.id,
});
}
return this.searchService.searchPage(searchDto, {
workspaceId: workspace.id,
});
}
async searchTypesense(
searchParams: SearchDTO,
opts: {
userId?: string;
workspaceId: string;
},
) {
const { userId, workspaceId } = opts;
let TypesenseModule: any;
try {
// eslint-disable-next-line @typescript-eslint/no-require-imports
TypesenseModule = require('./../../ee/typesense/services/page-search.service');
const PageSearchService = this.moduleRef.get(
TypesenseModule.PageSearchService,
{
strict: false,
},
);
return PageSearchService.searchPage(searchParams, {
userId: userId,
workspaceId,
});
} catch (err) {
this.logger.debug(
'Typesense module requested but enterprise module not bundled in this build',
);
}
throw new BadRequestException('Enterprise Typesense search module missing');
}
}
+56 -28
View File
@@ -21,15 +21,16 @@ export class SearchService {
) {}
async searchPage(
query: string,
searchParams: SearchDTO,
opts: {
userId?: string;
workspaceId: string;
},
): Promise<SearchResponseDto[]> {
): Promise<{ items: SearchResponseDto[] }> {
const { query } = searchParams;
if (query.length < 1) {
return;
return { items: [] };
}
const searchQuery = tsquery(query.trim() + '*');
@@ -44,17 +45,24 @@ export class SearchService {
'creatorId',
'createdAt',
'updatedAt',
sql<number>`ts_rank(tsv, to_tsquery(${searchQuery}))`.as('rank'),
sql<string>`ts_headline('english', text_content, to_tsquery(${searchQuery}),'MinWords=9, MaxWords=10, MaxFragments=3')`.as(
sql<number>`ts_rank(tsv, to_tsquery('english', f_unaccent(${searchQuery})))`.as(
'rank',
),
sql<string>`ts_headline('english', text_content, to_tsquery('english', f_unaccent(${searchQuery})),'MinWords=9, MaxWords=10, MaxFragments=3')`.as(
'highlight',
),
])
.where('tsv', '@@', sql<string>`to_tsquery(${searchQuery})`)
.where(
'tsv',
'@@',
sql<string>`to_tsquery('english', f_unaccent(${searchQuery}))`,
)
.$if(Boolean(searchParams.creatorId), (qb) =>
qb.where('creatorId', '=', searchParams.creatorId),
)
.where('deletedAt', 'is', null)
.orderBy('rank', 'desc')
.limit(searchParams.limit | 20)
.limit(searchParams.limit || 25)
.offset(searchParams.offset || 0);
if (!searchParams.shareId) {
@@ -66,22 +74,19 @@ export class SearchService {
queryResults = queryResults.where('spaceId', '=', searchParams.spaceId);
} else if (opts.userId && !searchParams.spaceId) {
// only search spaces the user is a member of
const userSpaceIds = await this.spaceMemberRepo.getUserSpaceIds(
opts.userId,
);
if (userSpaceIds.length > 0) {
queryResults = queryResults
.where('spaceId', 'in', userSpaceIds)
.where('workspaceId', '=', opts.workspaceId);
} else {
return [];
}
queryResults = queryResults
.where(
'spaceId',
'in',
this.spaceMemberRepo.getUserSpaceIdsQuery(opts.userId),
)
.where('workspaceId', '=', opts.workspaceId);
} else if (searchParams.shareId && !searchParams.spaceId && !opts.userId) {
// search in shares
const shareId = searchParams.shareId;
const share = await this.shareRepo.findById(shareId);
if (!share || share.workspaceId !== opts.workspaceId) {
return [];
return { items: [] };
}
const pageIdsToSearch = [];
@@ -103,10 +108,10 @@ export class SearchService {
.where('id', 'in', pageIdsToSearch)
.where('workspaceId', '=', opts.workspaceId);
} else {
return [];
return { items: [] };
}
} else {
return [];
return { items: [] };
}
//@ts-ignore
@@ -122,7 +127,7 @@ export class SearchService {
return result;
});
return searchResults;
return { items: searchResults };
}
async searchSuggestions(
@@ -138,21 +143,37 @@ export class SearchService {
const query = suggestion.query.toLowerCase().trim();
if (suggestion.includeUsers) {
users = await this.db
const userQuery = this.db
.selectFrom('users')
.select(['id', 'name', 'avatarUrl'])
.where((eb) => eb(sql`LOWER(users.name)`, 'like', `%${query}%`))
.select(['id', 'name', 'email', 'avatarUrl'])
.where('workspaceId', '=', workspaceId)
.where('deletedAt', 'is', null)
.limit(limit)
.execute();
.where((eb) =>
eb.or([
eb(
sql`LOWER(f_unaccent(users.name))`,
'like',
sql`LOWER(f_unaccent(${`%${query}%`}))`,
),
eb(sql`users.email`, 'ilike', sql`f_unaccent(${`%${query}%`})`),
]),
)
.limit(limit);
users = await userQuery.execute();
}
if (suggestion.includeGroups) {
groups = await this.db
.selectFrom('groups')
.select(['id', 'name', 'description'])
.where((eb) => eb(sql`LOWER(groups.name)`, 'like', `%${query}%`))
.where((eb) =>
eb(
sql`LOWER(f_unaccent(groups.name))`,
'like',
sql`LOWER(f_unaccent(${`%${query}%`}))`,
),
)
.where('workspaceId', '=', workspaceId)
.limit(limit)
.execute();
@@ -162,7 +183,14 @@ export class SearchService {
let pageSearch = this.db
.selectFrom('pages')
.select(['id', 'slugId', 'title', 'icon', 'spaceId'])
.where((eb) => eb(sql`LOWER(pages.title)`, 'like', `%${query}%`))
.where((eb) =>
eb(
sql`LOWER(f_unaccent(pages.title))`,
'like',
sql`LOWER(f_unaccent(${`%${query}%`}))`,
),
)
.where('deletedAt', 'is', null)
.where('workspaceId', '=', workspaceId)
.limit(limit);
@@ -7,6 +7,7 @@ import { validate as isValidUUID } from 'uuid';
import { WorkspaceRepo } from '@docmost/db/repos/workspace/workspace.repo';
import { EnvironmentService } from '../../integrations/environment/environment.service';
import { Workspace } from '@docmost/db/types/entity.types';
import { htmlEscape } from '../../common/helpers/html-escaper';
@Controller('share')
export class ShareSeoController {
@@ -68,7 +69,7 @@ export class ShareSeoController {
return this.sendIndex(indexFilePath, res);
}
const rawTitle = share.sharedPage.title ?? 'untitled';
const rawTitle = htmlEscape(share?.sharedPage.title ?? 'untitled');
const metaTitle =
rawTitle.length > 80 ? `${rawTitle.slice(0, 77)}` : rawTitle;
+11 -6
View File
@@ -31,6 +31,7 @@ import { Public } from '../../common/decorators/public.decorator';
import { ShareRepo } from '@docmost/db/repos/share/share.repo';
import { PaginationOptions } from '@docmost/db/pagination/pagination-options';
import { EnvironmentService } from '../../integrations/environment/environment.service';
import { hasLicenseOrEE } from '../../common/helpers';
@UseGuards(JwtAuthGuard)
@Controller('shares')
@@ -65,9 +66,11 @@ export class ShareController {
return {
...(await this.shareService.getSharedPage(dto, workspace.id)),
hasLicenseKey:
Boolean(workspace.licenseKey) ||
(this.environmentService.isCloud() && workspace.plan === 'business'),
hasLicenseKey: hasLicenseOrEE({
licenseKey: workspace.licenseKey,
isCloud: this.environmentService.isCloud(),
plan: workspace.plan,
}),
};
}
@@ -175,9 +178,11 @@ export class ShareController {
) {
return {
...(await this.shareService.getShareTree(dto.shareId, workspace.id)),
hasLicenseKey:
Boolean(workspace.licenseKey) ||
(this.environmentService.isCloud() && workspace.plan === 'business'),
hasLicenseKey: hasLicenseOrEE({
licenseKey: workspace.licenseKey,
isCloud: this.environmentService.isCloud(),
plan: workspace.plan,
}),
};
}
}
+56 -52
View File
@@ -69,8 +69,8 @@ export class ShareService {
return await this.shareRepo.insertShare({
key: nanoIdGen().toLowerCase(),
pageId: page.id,
includeSubPages: createShareDto.includeSubPages || true,
searchIndexing: createShareDto.searchIndexing || true,
includeSubPages: createShareDto.includeSubPages ?? false,
searchIndexing: createShareDto.searchIndexing ?? false,
creatorId: authUserId,
spaceId: page.spaceId,
workspaceId,
@@ -108,12 +108,12 @@ export class ShareService {
includeCreator: true,
});
page.content = await this.updatePublicAttachments(page);
if (!page) {
if (!page || page.deletedAt) {
throw new NotFoundException('Shared page not found');
}
page.content = await this.updatePublicAttachments(page);
return { page, share };
}
@@ -123,78 +123,82 @@ export class ShareService {
.withRecursive('page_hierarchy', (cte) =>
cte
.selectFrom('pages')
.leftJoin('shares', 'shares.pageId', 'pages.id')
.select([
'id',
'slugId',
'pages.id',
'pages.slugId',
'pages.title',
'pages.icon',
'parentPageId',
'pages.parentPageId',
sql`0`.as('level'),
'shares.id as shareId',
'shares.key as shareKey',
'shares.includeSubPages',
'shares.searchIndexing',
'shares.creatorId',
'shares.spaceId',
'shares.workspaceId',
'shares.createdAt',
])
.where(isValidUUID(pageId) ? 'id' : 'slugId', '=', pageId)
.unionAll((union) =>
union
.selectFrom('pages as p')
.select([
'p.id',
'p.slugId',
'p.title',
'p.icon',
'p.parentPageId',
// Increase the level by 1 for each ancestor.
sql`ph.level + 1`.as('level'),
])
.innerJoin('page_hierarchy as ph', 'ph.parentPageId', 'p.id'),
.where(isValidUUID(pageId) ? 'pages.id' : 'pages.slugId', '=', pageId)
.where('pages.deletedAt', 'is', null)
.unionAll(
(union) =>
union
.selectFrom('pages as p')
.innerJoin('page_hierarchy as ph', 'ph.parentPageId', 'p.id')
.leftJoin('shares as s', 's.pageId', 'p.id')
.select([
'p.id',
'p.slugId',
'p.title',
'p.icon',
'p.parentPageId',
sql`ph.level + 1`.as('level'),
's.id as shareId',
's.key as shareKey',
's.includeSubPages',
's.searchIndexing',
's.creatorId',
's.spaceId',
's.workspaceId',
's.createdAt',
])
.where('p.deletedAt', 'is', null)
.where(sql`ph.share_id`, 'is', null) // stop if share found
.where(sql`ph.level`, '<', sql`25`), // prevent loop
),
)
.selectFrom('page_hierarchy')
.leftJoin('shares', 'shares.pageId', 'page_hierarchy.id')
.select([
'page_hierarchy.id as sharedPageId',
'page_hierarchy.slugId as sharedPageSlugId',
'page_hierarchy.title as sharedPageTitle',
'page_hierarchy.icon as sharedPageIcon',
'page_hierarchy.level as level',
'shares.id',
'shares.key',
'shares.pageId',
'shares.includeSubPages',
'shares.searchIndexing',
'shares.creatorId',
'shares.spaceId',
'shares.workspaceId',
'shares.createdAt',
'shares.updatedAt',
])
.where('shares.id', 'is not', null)
.orderBy('page_hierarchy.level', 'asc')
.selectAll()
.where('shareId', 'is not', null)
.limit(1)
.executeTakeFirst();
if (!share || share.workspaceId != workspaceId) {
if (!share || share.workspaceId !== workspaceId) {
return undefined;
}
if (share.level === 1 && !share.includeSubPages) {
// we can only show a page if its shared ancestor permits it
if ((share.level as number) > 0 && !share.includeSubPages) {
return undefined;
}
return {
id: share.id,
key: share.key,
id: share.shareId,
key: share.shareKey,
includeSubPages: share.includeSubPages,
searchIndexing: share.searchIndexing,
pageId: share.pageId,
pageId: share.id,
creatorId: share.creatorId,
spaceId: share.spaceId,
workspaceId: share.workspaceId,
createdAt: share.createdAt,
level: share.level,
sharedPage: {
id: share.sharedPageId,
slugId: share.sharedPageSlugId,
title: share.sharedPageTitle,
icon: share.sharedPageIcon,
id: share.id,
slugId: share.slugId,
title: share.title,
icon: share.icon,
},
};
}
@@ -9,7 +9,7 @@ import {Transform, TransformFnParams} from "class-transformer";
export class CreateSpaceDto {
@MinLength(2)
@MaxLength(50)
@MaxLength(100)
@IsString()
@Transform(({ value }: TransformFnParams) => value?.trim())
name: string;
@@ -19,7 +19,7 @@ export class CreateSpaceDto {
description?: string;
@MinLength(2)
@MaxLength(50)
@MaxLength(100)
@IsAlphanumeric()
slug: string;
}
@@ -13,7 +13,7 @@ import { SpaceRepo } from '@docmost/db/repos/space/space.repo';
import { RemoveSpaceMemberDto } from '../dto/remove-space-member.dto';
import { UpdateSpaceMemberRoleDto } from '../dto/update-space-member-role.dto';
import { SpaceRole } from '../../../common/helpers/types/permission';
import { PaginationResult } from '@docmost/db/pagination/pagination';
import { CursorPaginationResult } from '@docmost/db/pagination/cursor-pagination';
@Injectable()
export class SpaceMemberService {
@@ -68,18 +68,16 @@ export class SpaceMemberService {
spaceId: string,
workspaceId: string,
pagination: PaginationOptions,
) {
): Promise<CursorPaginationResult<any>> {
const space = await this.spaceRepo.findById(spaceId, workspaceId);
if (!space) {
throw new NotFoundException('Space not found');
}
const members = await this.spaceMemberRepo.getSpaceMembersPaginated(
return await this.spaceMemberRepo.getSpaceMembersPaginated(
spaceId,
pagination,
);
return members;
}
async addMembersToSpaceBatch(
@@ -276,7 +274,7 @@ export class SpaceMemberService {
async getUserSpaces(
userId: string,
pagination: PaginationOptions,
): Promise<PaginationResult<Space>> {
return await this.spaceMemberRepo.getUserSpaces(userId, pagination);
): Promise<CursorPaginationResult<Space>> {
return this.spaceMemberRepo.getUserSpaces(userId, pagination);
}
}
@@ -8,7 +8,6 @@ import { PaginationOptions } from '@docmost/db/pagination/pagination-options';
import { SpaceRepo } from '@docmost/db/repos/space/space.repo';
import { KyselyDB, KyselyTransaction } from '@docmost/db/types/kysely.types';
import { Space, User } from '@docmost/db/types/entity.types';
import { PaginationResult } from '@docmost/db/pagination/pagination';
import { UpdateSpaceDto } from '../dto/update-space.dto';
import { executeTx } from '@docmost/db/utils';
import { InjectKysely } from 'nestjs-kysely';
@@ -17,6 +16,7 @@ import { SpaceRole } from '../../../common/helpers/types/permission';
import { QueueJob, QueueName } from 'src/integrations/queue/constants';
import { Queue } from 'bullmq';
import { InjectQueue } from '@nestjs/bullmq';
import { CursorPaginationResult } from '@docmost/db/pagination/cursor-pagination';
@Injectable()
export class SpaceService {
@@ -130,13 +130,8 @@ export class SpaceService {
async getWorkspaceSpaces(
workspaceId: string,
pagination: PaginationOptions,
): Promise<PaginationResult<Space>> {
const spaces = await this.spaceRepo.getSpacesInWorkspace(
workspaceId,
pagination,
);
return spaces;
): Promise<CursorPaginationResult<Space>> {
return this.spaceRepo.getSpacesInWorkspace(workspaceId, pagination);
}
async deleteSpace(spaceId: string, workspaceId: string): Promise<void> {
@@ -1,5 +1,13 @@
import { OmitType, PartialType } from '@nestjs/mapped-types';
import { IsBoolean, IsOptional, IsString } from 'class-validator';
import {
IsBoolean,
IsIn,
IsNotEmpty,
IsOptional,
IsString,
MaxLength,
MinLength,
} from 'class-validator';
import { CreateUserDto } from '../../auth/dto/create-user.dto';
export class UpdateUserDto extends PartialType(
@@ -13,7 +21,18 @@ export class UpdateUserDto extends PartialType(
@IsBoolean()
fullPageWidth: boolean;
@IsOptional()
@IsString()
@IsIn(['read', 'edit'])
pageEditMode: string;
@IsOptional()
@IsString()
locale: string;
@IsOptional()
@MinLength(8)
@MaxLength(70)
@IsString()
confirmPassword: string;
}
+1 -1
View File
@@ -50,6 +50,6 @@ export class UserController {
@AuthUser() user: User,
@AuthWorkspace() workspace: Workspace,
) {
return this.userService.update(updateUserDto, user.id, workspace.id);
return this.userService.update(updateUserDto, user.id, workspace);
}
}
+41 -4
View File
@@ -3,8 +3,12 @@ import {
BadRequestException,
Injectable,
NotFoundException,
UnauthorizedException,
} from '@nestjs/common';
import { UpdateUserDto } from './dto/update-user.dto';
import { comparePasswordHash } from 'src/common/helpers/utils';
import { Workspace } from '@docmost/db/types/entity.types';
import { validateSsoEnforcement } from '../auth/auth.util';
@Injectable()
export class UserService {
@@ -17,9 +21,14 @@ export class UserService {
async update(
updateUserDto: UpdateUserDto,
userId: string,
workspaceId: string,
workspace: Workspace,
) {
const user = await this.userRepo.findById(userId, workspaceId);
const includePassword =
updateUserDto.email != null && updateUserDto.confirmPassword != null;
const user = await this.userRepo.findById(userId, workspace.id, {
includePassword,
});
if (!user) {
throw new NotFoundException('User not found');
@@ -34,14 +43,40 @@ export class UserService {
);
}
if (typeof updateUserDto.pageEditMode !== 'undefined') {
return this.userRepo.updatePreference(
userId,
'pageEditMode',
updateUserDto.pageEditMode.toLowerCase(),
);
}
if (updateUserDto.name) {
user.name = updateUserDto.name;
}
if (updateUserDto.email && user.email != updateUserDto.email) {
if (await this.userRepo.findByEmail(updateUserDto.email, workspaceId)) {
validateSsoEnforcement(workspace);
if (!updateUserDto.confirmPassword) {
throw new BadRequestException(
'You must provide a password to change your email',
);
}
const isPasswordMatch = await comparePasswordHash(
updateUserDto.confirmPassword,
user.password,
);
if (!isPasswordMatch) {
throw new BadRequestException('You must provide the correct password to change your email');
}
if (await this.userRepo.findByEmail(updateUserDto.email, workspace.id)) {
throw new BadRequestException('A user with this email already exists');
}
user.email = updateUserDto.email;
}
@@ -53,7 +88,9 @@ export class UserService {
user.locale = updateUserDto.locale;
}
await this.userRepo.updateUser(updateUserDto, userId, workspaceId);
delete updateUserDto.confirmPassword;
await this.userRepo.updateUser(updateUserDto, userId, workspace.id);
return user;
}
}
@@ -30,7 +30,6 @@ import {
WorkspaceCaslAction,
WorkspaceCaslSubject,
} from '../../casl/interfaces/workspace-ability.type';
import { addDays } from 'date-fns';
import { FastifyReply } from 'fastify';
import { EnvironmentService } from '../../../integrations/environment/environment.service';
import { CheckHostnameDto } from '../dto/check-hostname.dto';
@@ -180,10 +179,13 @@ export class WorkspaceController {
@Public()
@HttpCode(HttpStatus.OK)
@Post('invites/info')
async getInvitationById(@Body() dto: InvitationIdDto, @Req() req: any) {
async getInvitationById(
@Body() dto: InvitationIdDto,
@AuthWorkspace() workspace: Workspace,
) {
return this.workspaceInvitationService.getInvitationById(
dto.invitationId,
req.raw.workspaceId,
workspace,
);
}
@@ -253,20 +255,30 @@ export class WorkspaceController {
@Post('invites/accept')
async acceptInvite(
@Body() acceptInviteDto: AcceptInviteDto,
@Req() req: any,
@AuthWorkspace() workspace: Workspace,
@Res({ passthrough: true }) res: FastifyReply,
) {
const authToken = await this.workspaceInvitationService.acceptInvitation(
const result = await this.workspaceInvitationService.acceptInvitation(
acceptInviteDto,
req.raw.workspaceId,
workspace,
);
res.setCookie('authToken', authToken, {
if (result.requiresLogin) {
return {
requiresLogin: true,
};
}
res.setCookie('authToken', result.authToken, {
httpOnly: true,
path: '/',
expires: addDays(new Date(), 30),
expires: this.environmentService.getCookieExpiresIn(),
secure: this.environmentService.isHttps(),
});
return {
requiresLogin: false,
};
}
@Public()
@@ -14,4 +14,20 @@ export class UpdateWorkspaceDto extends PartialType(CreateWorkspaceDto) {
@IsOptional()
@IsBoolean()
enforceSso: boolean;
@IsOptional()
@IsBoolean()
enforceMfa: boolean;
@IsOptional()
@IsBoolean()
restrictApiToAdmins: boolean;
@IsOptional()
@IsBoolean()
aiSearch: boolean;
@IsOptional()
@IsBoolean()
generativeAi: boolean;
}
@@ -8,6 +8,7 @@ import { AcceptInviteDto, InviteUserDto } from '../dto/invitation.dto';
import { UserRepo } from '@docmost/db/repos/user/user.repo';
import { InjectKysely } from 'nestjs-kysely';
import { KyselyDB } from '@docmost/db/types/kysely.types';
import { sql } from 'kysely';
import { executeTx } from '@docmost/db/utils';
import {
Group,
@@ -22,12 +23,16 @@ import InvitationAcceptedEmail from '@docmost/transactional/emails/invitation-ac
import { TokenService } from '../../auth/services/token.service';
import { nanoIdGen } from '../../../common/helpers';
import { PaginationOptions } from '@docmost/db/pagination/pagination-options';
import { executeWithPagination } from '@docmost/db/pagination/pagination';
import { executeWithCursorPagination } from '@docmost/db/pagination/cursor-pagination';
import { DomainService } from 'src/integrations/environment/domain.service';
import { InjectQueue } from '@nestjs/bullmq';
import { QueueJob, QueueName } from '../../../integrations/queue/constants';
import { Queue } from 'bullmq';
import { EnvironmentService } from '../../../integrations/environment/environment.service';
import {
validateAllowedEmail,
validateSsoEnforcement,
} from '../../auth/auth.util';
@Injectable()
export class WorkspaceInvitationService {
@@ -51,31 +56,36 @@ export class WorkspaceInvitationService {
if (pagination.query) {
query = query.where((eb) =>
eb('email', 'ilike', `%${pagination.query}%`),
eb(
sql`email`,
'ilike',
sql`f_unaccent(${'%' + pagination.query + '%'})`,
),
);
}
const result = executeWithPagination(query, {
page: pagination.page,
return executeWithCursorPagination(query, {
perPage: pagination.limit,
cursor: pagination.cursor,
beforeCursor: pagination.beforeCursor,
fields: [{ expression: 'id', direction: 'asc' }],
parseCursor: (cursor) => ({ id: cursor.id }),
});
return result;
}
async getInvitationById(invitationId: string, workspaceId: string) {
async getInvitationById(invitationId: string, workspace: Workspace) {
const invitation = await this.db
.selectFrom('workspaceInvitations')
.select(['id', 'email', 'createdAt'])
.where('id', '=', invitationId)
.where('workspaceId', '=', workspaceId)
.where('workspaceId', '=', workspace.id)
.executeTakeFirst();
if (!invitation) {
throw new NotFoundException('Invitation not found');
}
return invitation;
return { ...invitation, enforceSso: workspace.enforceSso };
}
async getInvitationTokenById(invitationId: string, workspaceId: string) {
@@ -141,6 +151,10 @@ export class WorkspaceInvitationService {
groupIds: validGroups?.map((group: Partial<Group>) => group.id),
}));
if (invitesToInsert.length < 1) {
return;
}
invites = await trx
.insertInto('workspaceInvitations')
.values(invitesToInsert)
@@ -169,12 +183,19 @@ export class WorkspaceInvitationService {
}
}
async acceptInvitation(dto: AcceptInviteDto, workspaceId: string) {
async acceptInvitation(
dto: AcceptInviteDto,
workspace: Workspace,
): Promise<{
authToken?: string;
requiresLogin?: boolean;
message?: string;
}> {
const invitation = await this.db
.selectFrom('workspaceInvitations')
.selectAll()
.where('id', '=', dto.invitationId)
.where('workspaceId', '=', workspaceId)
.where('workspaceId', '=', workspace.id)
.executeTakeFirst();
if (!invitation) {
@@ -185,6 +206,9 @@ export class WorkspaceInvitationService {
throw new BadRequestException('Invalid invitation token');
}
validateSsoEnforcement(workspace);
validateAllowedEmail(invitation.email, workspace);
let newUser: User;
try {
@@ -197,7 +221,7 @@ export class WorkspaceInvitationService {
password: dto.password,
role: invitation.role,
invitedById: invitation.invitedById,
workspaceId: workspaceId,
workspaceId: workspace.id,
},
trx,
);
@@ -205,7 +229,7 @@ export class WorkspaceInvitationService {
// add user to default group
await this.groupUserRepo.addUserToDefaultGroup(
newUser.id,
workspaceId,
workspace.id,
trx,
);
@@ -215,7 +239,7 @@ export class WorkspaceInvitationService {
.selectFrom('groups')
.select(['id', 'name'])
.where('groups.id', 'in', invitation.groupIds)
.where('groups.workspaceId', '=', workspaceId)
.where('groups.workspaceId', '=', workspace.id)
.execute();
if (validGroups && validGroups.length > 0) {
@@ -256,7 +280,7 @@ export class WorkspaceInvitationService {
// notify the inviter
const invitedByUser = await this.userRepo.findById(
invitation.invitedById,
workspaceId,
workspace.id,
);
if (invitedByUser) {
@@ -273,10 +297,19 @@ export class WorkspaceInvitationService {
}
if (this.environmentService.isCloud()) {
await this.billingQueue.add(QueueJob.STRIPE_SEATS_SYNC, { workspaceId });
await this.billingQueue.add(QueueJob.STRIPE_SEATS_SYNC, {
workspaceId: workspace.id,
});
}
return this.tokenService.generateAccessToken(newUser);
if (workspace.enforceMfa) {
return {
requiresLogin: true,
};
}
const authToken = await this.tokenService.generateAccessToken(newUser);
return { authToken };
}
async resendInvitation(
@@ -19,7 +19,6 @@ import { User } from '@docmost/db/types/entity.types';
import { GroupUserRepo } from '@docmost/db/repos/group/group-user.repo';
import { GroupRepo } from '@docmost/db/repos/group/group.repo';
import { PaginationOptions } from '@docmost/db/pagination/pagination-options';
import { PaginationResult } from '@docmost/db/pagination/pagination';
import { UpdateWorkspaceUserRoleDto } from '../dto/update-workspace-user-role.dto';
import { UserRepo } from '@docmost/db/repos/user/user.repo';
import { EnvironmentService } from '../../../integrations/environment/environment.service';
@@ -28,10 +27,12 @@ import { jsonArrayFrom } from 'kysely/helpers/postgres';
import { addDays } from 'date-fns';
import { DISALLOWED_HOSTNAMES, WorkspaceStatus } from '../workspace.constants';
import { v4 } from 'uuid';
import { AttachmentType } from 'src/core/attachment/attachment.constants';
import { InjectQueue } from '@nestjs/bullmq';
import { QueueJob, QueueName } from '../../../integrations/queue/constants';
import { Queue } from 'bullmq';
import { generateRandomSuffixNumbers } from '../../../common/helpers';
import { isPageEmbeddingsTableExists } from '@docmost/db/helpers/helpers';
import { CursorPaginationResult } from '@docmost/db/pagination/cursor-pagination';
@Injectable()
export class WorkspaceService {
@@ -49,6 +50,7 @@ export class WorkspaceService {
@InjectKysely() private readonly db: KyselyDB,
@InjectQueue(QueueName.ATTACHMENT_QUEUE) private attachmentQueue: Queue,
@InjectQueue(QueueName.BILLING_QUEUE) private billingQueue: Queue,
@InjectQueue(QueueName.AI_QUEUE) private aiQueue: Queue,
) {}
async findById(workspaceId: string) {
@@ -302,6 +304,60 @@ export class WorkspaceService {
}
}
if (typeof updateWorkspaceDto.restrictApiToAdmins !== 'undefined') {
await this.workspaceRepo.updateApiSettings(
workspaceId,
'restrictToAdmins',
updateWorkspaceDto.restrictApiToAdmins,
);
delete updateWorkspaceDto.restrictApiToAdmins;
}
if (typeof updateWorkspaceDto.aiSearch !== 'undefined') {
await this.workspaceRepo.updateAiSettings(
workspaceId,
'search',
updateWorkspaceDto.aiSearch,
);
if (updateWorkspaceDto.aiSearch) {
const tableExists = await isPageEmbeddingsTableExists(this.db);
if (!tableExists) {
throw new BadRequestException(
'Failed to activate. Make sure pgvector postgres extension is installed.',
);
}
await this.aiQueue.add(QueueJob.WORKSPACE_CREATE_EMBEDDINGS, {
workspaceId,
});
} else {
// Schedule deletion after 24 hours
const deleteJobId = `ai-search-disabled-${workspaceId}`;
await this.aiQueue.add(
QueueJob.WORKSPACE_DELETE_EMBEDDINGS,
{ workspaceId },
{
jobId: deleteJobId,
delay: 24 * 60 * 60 * 1000,
removeOnComplete: true,
removeOnFail: true,
},
);
}
delete updateWorkspaceDto.aiSearch;
}
if (typeof updateWorkspaceDto.generativeAi !== 'undefined') {
await this.workspaceRepo.updateAiSettings(
workspaceId,
'generative',
updateWorkspaceDto.generativeAi,
);
delete updateWorkspaceDto.generativeAi;
}
await this.workspaceRepo.updateWorkspace(updateWorkspaceDto, workspaceId);
const workspace = await this.workspaceRepo.findById(workspaceId, {
@@ -319,13 +375,8 @@ export class WorkspaceService {
async getWorkspaceUsers(
workspaceId: string,
pagination: PaginationOptions,
): Promise<PaginationResult<User>> {
const users = await this.userRepo.getUsersPaginated(
workspaceId,
pagination,
);
return users;
): Promise<CursorPaginationResult<User>> {
return this.userRepo.getUsersPaginated(workspaceId, pagination);
}
async updateWorkspaceUserRole(
@@ -377,24 +428,20 @@ export class WorkspaceService {
name: string,
trx?: KyselyTransaction,
): Promise<string> {
const generateRandomSuffix = (length: number) =>
Math.random()
.toFixed(length)
.substring(2, 2 + length);
let subdomain = name
.toLowerCase()
.replace(/[^a-z0-9]/g, '')
.substring(0, 20);
.replace(/[^a-z0-9-]/g, '')
.substring(0, 20)
.replace(/^-+|-+$/g, ''); //remove any hyphen at the start or end
// Ensure we leave room for a random suffix.
const maxSuffixLength = 6;
if (subdomain.length < 4) {
subdomain = `${subdomain}-${generateRandomSuffix(maxSuffixLength)}`;
subdomain = `${subdomain}-${generateRandomSuffixNumbers(maxSuffixLength)}`;
}
if (DISALLOWED_HOSTNAMES.includes(subdomain)) {
subdomain = `workspace-${generateRandomSuffix(maxSuffixLength)}`;
subdomain = `workspace-${generateRandomSuffixNumbers(maxSuffixLength)}`;
}
let uniqueHostname = subdomain;
@@ -408,7 +455,7 @@ export class WorkspaceService {
break;
}
// Append a random suffix and retry.
const randomSuffix = generateRandomSuffix(maxSuffixLength);
const randomSuffix = generateRandomSuffixNumbers(maxSuffixLength);
uniqueHostname = `${subdomain}-${randomSuffix}`.substring(0, 25);
}
+27 -22
View File
@@ -7,8 +7,7 @@ import {
} from '@nestjs/common';
import { InjectKysely, KyselyModule } from 'nestjs-kysely';
import { EnvironmentService } from '../integrations/environment/environment.service';
import { CamelCasePlugin, LogEvent, PostgresDialect, sql } from 'kysely';
import { Pool, types } from 'pg';
import { CamelCasePlugin, LogEvent, sql } from 'kysely';
import { GroupRepo } from '@docmost/db/repos/group/group.repo';
import { WorkspaceRepo } from '@docmost/db/repos/workspace/workspace.repo';
import { UserRepo } from '@docmost/db/repos/user/user.repo';
@@ -25,9 +24,10 @@ import { MigrationService } from '@docmost/db/services/migration.service';
import { UserTokenRepo } from './repos/user-token/user-token.repo';
import { BacklinkRepo } from '@docmost/db/repos/backlink/backlink.repo';
import { ShareRepo } from '@docmost/db/repos/share/share.repo';
// https://github.com/brianc/node-postgres/issues/811
types.setTypeParser(types.builtins.INT8, (val) => Number(val));
import { PageListener } from '@docmost/db/listeners/page.listener';
import { PostgresJSDialect } from 'kysely-postgres-js';
import * as postgres from 'postgres';
import { normalizePostgresUrl } from '../common/helpers';
@Global()
@Module({
@@ -36,26 +36,30 @@ types.setTypeParser(types.builtins.INT8, (val) => Number(val));
imports: [],
inject: [EnvironmentService],
useFactory: (environmentService: EnvironmentService) => ({
dialect: new PostgresDialect({
pool: new Pool({
connectionString: environmentService.getDatabaseURL(),
max: environmentService.getDatabaseMaxPool(),
}).on('error', (err) => {
console.error('Database error:', err.message);
}),
dialect: new PostgresJSDialect({
postgres: postgres(
normalizePostgresUrl(environmentService.getDatabaseURL()),
{
max: environmentService.getDatabaseMaxPool(),
onnotice: () => {},
types: {
bigint: {
to: 20,
from: [20, 1700],
serialize: (value: number) => value.toString(),
parse: (value: string) => Number.parseInt(value),
},
},
},
),
}),
plugins: [new CamelCasePlugin()],
log: (event: LogEvent) => {
if (environmentService.getNodeEnv() !== 'development') return;
const logger = new Logger(DatabaseModule.name);
if (event.level) {
if (process.env.DEBUG_DB?.toLowerCase() === 'true') {
logger.debug(event.query.sql);
logger.debug('query time: ' + event.queryDurationMillis + ' ms');
//if (event.query.parameters.length > 0) {
// logger.debug('parameters: ' + event.query.parameters);
//}
}
if (process.env.DEBUG_DB?.toLowerCase() === 'true') {
logger.debug(event.query.sql);
logger.debug('query time: ' + event.queryDurationMillis + ' ms');
}
},
}),
@@ -75,7 +79,8 @@ types.setTypeParser(types.builtins.INT8, (val) => Number(val));
AttachmentRepo,
UserTokenRepo,
BacklinkRepo,
ShareRepo
ShareRepo,
PageListener,
],
exports: [
WorkspaceRepo,
@@ -90,7 +95,7 @@ types.setTypeParser(types.builtins.INT8, (val) => Number(val));
AttachmentRepo,
UserTokenRepo,
BacklinkRepo,
ShareRepo
ShareRepo,
],
})
export class DatabaseModule
@@ -0,0 +1,22 @@
import { sql } from 'kysely';
import { KyselyDB } from '@docmost/db/types/kysely.types';
export async function isPageEmbeddingsTableExists(db: KyselyDB) {
return tableExists({ db, tableName: 'page_embeddings' });
}
export async function tableExists(opts: {
db: KyselyDB;
tableName: string;
}): Promise<boolean> {
const { db, tableName } = opts;
const result = await sql<{ exists: boolean }>`
SELECT EXISTS (
SELECT 1 FROM information_schema.tables
WHERE table_schema = COALESCE(current_schema(), 'public')
AND table_name = ${tableName}
) as exists
`.execute(db);
return result.rows[0]?.exists ?? false;
}
@@ -0,0 +1,80 @@
import { Injectable, Logger } from '@nestjs/common';
import { OnEvent } from '@nestjs/event-emitter';
import { EventName } from '../../common/events/event.contants';
import { InjectQueue } from '@nestjs/bullmq';
import { QueueJob, QueueName } from '../../integrations/queue/constants';
import { Queue } from 'bullmq';
import { EnvironmentService } from '../../integrations/environment/environment.service';
export class PageEvent {
pageIds: string[];
workspaceId: string;
}
@Injectable()
export class PageListener {
private readonly logger = new Logger(PageListener.name);
constructor(
private readonly environmentService: EnvironmentService,
@InjectQueue(QueueName.SEARCH_QUEUE) private searchQueue: Queue,
@InjectQueue(QueueName.AI_QUEUE) private aiQueue: Queue,
) {}
@OnEvent(EventName.PAGE_CREATED)
async handlePageCreated(event: PageEvent) {
const { pageIds, workspaceId } = event;
if (this.isTypesense()) {
await this.searchQueue.add(QueueJob.PAGE_CREATED, {
pageIds,
});
}
await this.aiQueue.add(QueueJob.PAGE_CREATED, { pageIds, workspaceId });
}
@OnEvent(EventName.PAGE_UPDATED)
async handlePageUpdated(event: PageEvent) {
const { pageIds } = event;
await this.searchQueue.add(QueueJob.PAGE_UPDATED, { pageIds });
}
@OnEvent(EventName.PAGE_DELETED)
async handlePageDeleted(event: PageEvent) {
const { pageIds, workspaceId } = event;
if (this.isTypesense()) {
await this.searchQueue.add(QueueJob.PAGE_DELETED, { pageIds });
}
await this.aiQueue.add(QueueJob.PAGE_DELETED, { pageIds, workspaceId });
}
@OnEvent(EventName.PAGE_SOFT_DELETED)
async handlePageSoftDeleted(event: PageEvent) {
const { pageIds, workspaceId } = event;
if (this.isTypesense()) {
await this.searchQueue.add(QueueJob.PAGE_SOFT_DELETED, { pageIds });
}
await this.aiQueue.add(QueueJob.PAGE_SOFT_DELETED, {
pageIds,
workspaceId,
});
}
@OnEvent(EventName.PAGE_RESTORED)
async handlePageRestored(event: PageEvent) {
const { pageIds, workspaceId } = event;
if (this.isTypesense()) {
await this.searchQueue.add(QueueJob.PAGE_RESTORED, { pageIds });
}
await this.aiQueue.add(QueueJob.PAGE_RESTORED, { pageIds, workspaceId });
}
isTypesense(): boolean {
return this.environmentService.getSearchDriver() === 'typesense';
}
}
@@ -0,0 +1,36 @@
import { Injectable, Logger } from '@nestjs/common';
import { OnEvent } from '@nestjs/event-emitter';
import { EventName } from '../../common/events/event.contants';
import { InjectQueue } from '@nestjs/bullmq';
import { QueueJob, QueueName } from '../../integrations/queue/constants';
import { Queue } from 'bullmq';
import { EnvironmentService } from '../../integrations/environment/environment.service';
export class SpaceEvent {
spaceId: string;
}
@Injectable()
export class SpaceListener {
private readonly logger = new Logger(SpaceListener.name);
constructor(
private readonly environmentService: EnvironmentService,
@InjectQueue(QueueName.SEARCH_QUEUE) private searchQueue: Queue,
@InjectQueue(QueueName.AI_QUEUE) private aiQueue: Queue,
) {}
@OnEvent(EventName.SPACE_DELETED)
async handleSpaceDeleted(event: SpaceEvent) {
const { spaceId } = event;
if (this.isTypesense()) {
await this.searchQueue.add(QueueJob.SPACE_DELETED, { spaceId });
}
await this.aiQueue.add(QueueJob.SPACE_DELETED, { spaceId });
}
isTypesense(): boolean {
return this.environmentService.getSearchDriver() === 'typesense';
}
}
@@ -0,0 +1,36 @@
import { Injectable, Logger } from '@nestjs/common';
import { OnEvent } from '@nestjs/event-emitter';
import { EventName } from '../../common/events/event.contants';
import { InjectQueue } from '@nestjs/bullmq';
import { QueueJob, QueueName } from '../../integrations/queue/constants';
import { Queue } from 'bullmq';
import { EnvironmentService } from '../../integrations/environment/environment.service';
export class WorkspaceEvent {
workspaceId: string;
}
@Injectable()
export class WorkspaceListener {
private readonly logger = new Logger(WorkspaceListener.name);
constructor(
private readonly environmentService: EnvironmentService,
@InjectQueue(QueueName.SEARCH_QUEUE) private searchQueue: Queue,
@InjectQueue(QueueName.AI_QUEUE) private aiQueue: Queue,
) {}
@OnEvent(EventName.WORKSPACE_DELETED)
async handlePageDeleted(event: WorkspaceEvent) {
const { workspaceId } = event;
if (this.isTypesense()) {
await this.searchQueue.add(QueueJob.WORKSPACE_DELETED, { workspaceId });
}
await this.aiQueue.add(QueueJob.WORKSPACE_DELETED, { workspaceId });
}
isTypesense(): boolean {
return this.environmentService.getSearchDriver() === 'typesense';
}
}
+6 -12
View File
@@ -1,25 +1,19 @@
import * as path from 'path';
import { promises as fs } from 'fs';
import pg from 'pg';
import {
Kysely,
Migrator,
PostgresDialect,
FileMigrationProvider,
} from 'kysely';
import { Kysely, Migrator, FileMigrationProvider } from 'kysely';
import { run } from 'kysely-migration-cli';
import * as dotenv from 'dotenv';
import { envPath } from '../common/helpers/utils';
import { envPath, normalizePostgresUrl } from '../common/helpers';
import { PostgresJSDialect } from 'kysely-postgres-js';
import postgres from 'postgres';
dotenv.config({ path: envPath });
const migrationFolder = path.join(__dirname, './migrations');
const db = new Kysely<any>({
dialect: new PostgresDialect({
pool: new pg.Pool({
connectionString: process.env.DATABASE_URL,
}) as any,
dialect: new PostgresJSDialect({
postgres: postgres(normalizePostgresUrl(process.env.DATABASE_URL)),
}),
});
@@ -0,0 +1,23 @@
import { type Kysely } from 'kysely';
export async function up(db: Kysely<any>): Promise<void> {
await db.schema
.alterTable('billing')
.addColumn('billing_scheme', 'varchar', (col) => col)
.addColumn('tiered_up_to', 'varchar', (col) => col)
.addColumn('tiered_flat_amount', 'int8', (col) => col)
.addColumn('tiered_unit_amount', 'int8', (col) => col)
.addColumn('plan_name', 'varchar', (col) => col)
.execute();
}
export async function down(db: Kysely<any>): Promise<void> {
await db.schema
.alterTable('billing')
.dropColumn('billing_scheme')
.dropColumn('tiered_up_to')
.dropColumn('tiered_flat_amount')
.dropColumn('tiered_unit_amount')
.dropColumn('plan_name')
.execute();
}
@@ -0,0 +1,39 @@
import { Kysely, sql } from 'kysely';
export async function up(db: Kysely<any>): Promise<void> {
await db.schema
.createTable('user_mfa')
.addColumn('id', 'uuid', (col) =>
col.primaryKey().defaultTo(sql`gen_uuid_v7()`),
)
.addColumn('user_id', 'uuid', (col) =>
col.references('users.id').onDelete('cascade').notNull(),
)
.addColumn('method', 'varchar', (col) => col.notNull().defaultTo('totp'))
.addColumn('secret', 'text', (col) => col)
.addColumn('is_enabled', 'boolean', (col) => col.defaultTo(false))
.addColumn('backup_codes', sql`text[]`, (col) => col)
.addColumn('workspace_id', 'uuid', (col) =>
col.references('workspaces.id').onDelete('cascade').notNull(),
)
.addColumn('created_at', 'timestamptz', (col) =>
col.notNull().defaultTo(sql`now()`),
)
.addColumn('updated_at', 'timestamptz', (col) =>
col.notNull().defaultTo(sql`now()`),
)
.addUniqueConstraint('user_mfa_user_id_unique', ['user_id'])
.execute();
// Add MFA policy columns to workspaces
await db.schema
.alterTable('workspaces')
.addColumn('enforce_mfa', 'boolean', (col) => col.defaultTo(false))
.execute();
}
export async function down(db: Kysely<any>): Promise<void> {
await db.schema.alterTable('workspaces').dropColumn('enforce_mfa').execute();
await db.schema.dropTable('user_mfa').execute();
}
@@ -0,0 +1,61 @@
import { type Kysely, sql } from 'kysely';
export async function up(db: Kysely<any>): Promise<void> {
// Add last_edited_by_id column to comments table
await db.schema
.alterTable('comments')
.addColumn('last_edited_by_id', 'uuid', (col) =>
col.references('users.id').onDelete('set null'),
)
.execute();
// Add resolved_by_id column to comments table
await db.schema
.alterTable('comments')
.addColumn('resolved_by_id', 'uuid', (col) =>
col.references('users.id').onDelete('set null'),
)
.execute();
// Add updated_at timestamp column to comments table
await db.schema
.alterTable('comments')
.addColumn('updated_at', 'timestamptz', (col) =>
col.notNull().defaultTo(sql`now()`),
)
.execute();
// Add space_id column to comments table
await db.schema
.alterTable('comments')
.addColumn('space_id', 'uuid', (col) =>
col.references('spaces.id').onDelete('cascade'),
)
.execute();
// Backfill space_id from the related pages
await db
.updateTable('comments as c')
.set((eb) => ({
space_id: eb.ref('p.space_id'),
}))
.from('pages as p')
.whereRef('c.page_id', '=', 'p.id')
.execute();
// Make space_id NOT NULL after populating data
await db.schema
.alterTable('comments')
.alterColumn('space_id', (col) => col.setNotNull())
.execute();
}
export async function down(db: Kysely<any>): Promise<void> {
await db.schema
.alterTable('comments')
.dropColumn('last_edited_by_id')
.execute();
await db.schema.alterTable('comments').dropColumn('resolved_by_id').execute();
await db.schema.alterTable('comments').dropColumn('updated_at').execute();
await db.schema.alterTable('comments').dropColumn('space_id').execute();
}
@@ -0,0 +1,50 @@
import { type Kysely, sql } from 'kysely';
export async function up(db: Kysely<any>): Promise<void> {
// Create unaccent extension
await sql`CREATE EXTENSION IF NOT EXISTS unaccent`.execute(db);
// Create pg_trgm extension
await sql`CREATE EXTENSION IF NOT EXISTS pg_trgm`.execute(db);
// Create IMMUTABLE wrapper function for unaccent
// This allows us to create indexes on unaccented columns for better performance
// https://stackoverflow.com/a/11007216/8299075
await sql`
CREATE OR REPLACE FUNCTION f_unaccent(text) RETURNS text
AS $$
SELECT unaccent('unaccent', $1);
$$ LANGUAGE sql IMMUTABLE PARALLEL SAFE STRICT;
`.execute(db);
// Update the pages tsvector trigger to use the immutable function
await sql`
CREATE OR REPLACE FUNCTION pages_tsvector_trigger() RETURNS trigger AS $$
begin
new.tsv :=
setweight(to_tsvector('english', f_unaccent(coalesce(new.title, ''))), 'A') ||
setweight(to_tsvector('english', f_unaccent(substring(coalesce(new.text_content, ''), 1, 1000000))), 'B');
return new;
end;
$$ LANGUAGE plpgsql;
`.execute(db);
}
export async function down(db: Kysely<any>): Promise<void> {
await sql`
CREATE OR REPLACE FUNCTION pages_tsvector_trigger() RETURNS trigger AS $$
begin
new.tsv :=
setweight(to_tsvector('english', coalesce(new.title, '')), 'A') ||
setweight(to_tsvector('english', coalesce(new.text_content, '')), 'B');
return new;
end;
$$ LANGUAGE plpgsql;
`.execute(db);
await sql`DROP FUNCTION IF EXISTS f_unaccent(text)`.execute(db);
await sql`DROP EXTENSION IF EXISTS pg_trgm`.execute(db);
await sql`DROP EXTENSION IF EXISTS unaccent`.execute(db);
}
@@ -0,0 +1,15 @@
import { type Kysely } from 'kysely';
export async function up(db: Kysely<any>): Promise<void> {
await db.schema
.alterTable('auth_providers')
.addColumn('group_sync', 'boolean', (col) => col.defaultTo(false).notNull())
.execute();
}
export async function down(db: Kysely<any>): Promise<void> {
await db.schema
.alterTable('auth_providers')
.dropColumn('group_sync')
.execute();
}
@@ -0,0 +1,68 @@
import { type Kysely, sql } from 'kysely';
export async function up(db: Kysely<any>): Promise<void> {
// switch type to text column since you can't add value to PG types in a transaction
await db.schema
.alterTable('auth_providers')
.alterColumn('type', (col) => col.setDataType('text'))
.execute();
await db.schema.dropType('auth_provider_type').ifExists().execute();
await db.schema
.alterTable('users')
.addColumn('has_generated_password', 'boolean', (col) =>
col.notNull().defaultTo(false).ifNotExists(),
)
.execute();
await db.schema
.alterTable('auth_providers')
.addColumn('ldap_url', 'varchar', (col) => col)
.addColumn('ldap_bind_dn', 'varchar', (col) => col)
.addColumn('ldap_bind_password', 'varchar', (col) => col)
.addColumn('ldap_base_dn', 'varchar', (col) => col)
.addColumn('ldap_user_search_filter', 'varchar', (col) => col)
.addColumn('ldap_user_attributes', 'jsonb', (col) =>
col.defaultTo(sql`'{}'::jsonb`),
)
.addColumn('ldap_tls_enabled', 'boolean', (col) => col.defaultTo(false))
.addColumn('ldap_tls_ca_cert', 'text', (col) => col)
.addColumn('ldap_config', 'jsonb', (col) => col.defaultTo(sql`'{}'::jsonb`))
.addColumn('settings', 'jsonb', (col) => col.defaultTo(sql`'{}'::jsonb`))
.execute();
}
export async function down(db: Kysely<any>): Promise<void> {
await db.schema
.alterTable('users')
.dropColumn('has_generated_password')
.execute();
await db.schema
.alterTable('auth_providers')
.dropColumn('ldap_url')
.dropColumn('ldap_bind_dn')
.dropColumn('ldap_bind_password')
.dropColumn('ldap_base_dn')
.dropColumn('ldap_user_search_filter')
.dropColumn('ldap_user_attributes')
.dropColumn('ldap_tls_enabled')
.dropColumn('ldap_tls_ca_cert')
.dropColumn('ldap_config')
.dropColumn('settings')
.execute();
await db.schema
.createType('auth_provider_type')
.asEnum(['saml', 'oidc', 'google'])
.execute();
await db.deleteFrom('auth_providers').where('type', '=', 'ldap').execute();
await sql`
ALTER TABLE auth_providers
ALTER COLUMN type TYPE auth_provider_type
USING type::auth_provider_type
`.execute(db);
}
@@ -0,0 +1,29 @@
import { type Kysely, sql } from 'kysely';
export async function up(db: Kysely<any>): Promise<void> {
await db.schema
.alterTable('attachments')
.addColumn('text_content', 'text', (col) => col)
.addColumn('tsv', sql`tsvector`, (col) => col)
.execute();
await db.schema
.createIndex('attachments_tsv_idx')
.on('attachments')
.using('GIN')
.column('tsv')
.execute();
}
export async function down(db: Kysely<any>): Promise<void> {
await db.schema
.alterTable('attachments')
.dropIndex('attachments_tsv_idx')
.execute();
await db.schema
.alterTable('attachments')
.dropColumn('text_content')
.dropColumn('tsv')
.execute();
}
@@ -0,0 +1,30 @@
import { Kysely, sql } from 'kysely';
export async function up(db: Kysely<any>): Promise<void> {
await db.schema
.createTable('api_keys')
.addColumn('id', 'uuid', (col) =>
col.primaryKey().defaultTo(sql`gen_uuid_v7()`),
)
.addColumn('name', 'text', (col) => col)
.addColumn('creator_id', 'uuid', (col) =>
col.notNull().references('users.id').onDelete('cascade'),
)
.addColumn('workspace_id', 'uuid', (col) =>
col.notNull().references('workspaces.id').onDelete('cascade'),
)
.addColumn('expires_at', 'timestamptz')
.addColumn('last_used_at', 'timestamptz')
.addColumn('created_at', 'timestamptz', (col) =>
col.notNull().defaultTo(sql`now()`),
)
.addColumn('updated_at', 'timestamptz', (col) =>
col.notNull().defaultTo(sql`now()`),
)
.addColumn('deleted_at', 'timestamptz', (col) => col)
.execute();
}
export async function down(db: Kysely<any>): Promise<void> {
await db.schema.dropTable('api_keys').execute();
}
@@ -0,0 +1,348 @@
// adapted from https://github.com/charlie-hadden/kysely-paginate/blob/main/src/cursor.ts - MIT
import {
OrderByDirection,
OrderByModifiers,
ReferenceExpression,
SelectQueryBuilder,
StringReference,
} from 'kysely';
type SortField<DB, TB extends keyof DB, O> =
| {
expression:
| (StringReference<DB, TB> & keyof O & string)
| (StringReference<DB, TB> & `${string}.${keyof O & string}`);
direction: OrderByDirection;
orderModifier?: OrderByModifiers;
key?: keyof O & string;
}
| {
expression: ReferenceExpression<DB, TB>;
direction: OrderByDirection;
orderModifier?: OrderByModifiers;
key: keyof O & string;
};
type ExtractSortFieldKey<
DB,
TB extends keyof DB,
O,
T extends SortField<DB, TB, O>,
> = T['key'] extends keyof O & string
? T['key']
: T['expression'] extends keyof O & string
? T['expression']
: T['expression'] extends `${string}.${infer K}`
? K extends keyof O & string
? K
: never
: never;
type Fields<DB, TB extends keyof DB, O> = ReadonlyArray<
Readonly<SortField<DB, TB, O>>
>;
type FieldNames<DB, TB extends keyof DB, O, T extends Fields<DB, TB, O>> = {
[TIndex in keyof T]: ExtractSortFieldKey<DB, TB, O, T[TIndex]>;
};
type EncodeCursorValues<
DB,
TB extends keyof DB,
O,
T extends Fields<DB, TB, O>,
> = {
[TIndex in keyof T]: [
ExtractSortFieldKey<DB, TB, O, T[TIndex]>,
O[ExtractSortFieldKey<DB, TB, O, T[TIndex]>],
];
};
export type CursorEncoder<
DB,
TB extends keyof DB,
O,
T extends Fields<DB, TB, O>,
> = (values: EncodeCursorValues<DB, TB, O, T>) => string;
type DecodedCursor<DB, TB extends keyof DB, O, T extends Fields<DB, TB, O>> = {
[TField in ExtractSortFieldKey<DB, TB, O, T[number]>]: string;
};
export type CursorDecoder<
DB,
TB extends keyof DB,
O,
T extends Fields<DB, TB, O>,
> = (
cursor: string,
fields: FieldNames<DB, TB, O, T>,
) => DecodedCursor<DB, TB, O, T>;
type ParsedCursorValues<
DB,
TB extends keyof DB,
O,
T extends Fields<DB, TB, O>,
> = {
[TField in ExtractSortFieldKey<DB, TB, O, T[number]>]: O[TField];
};
export type CursorParser<
DB,
TB extends keyof DB,
O,
T extends Fields<DB, TB, O>,
> = (cursor: DecodedCursor<DB, TB, O, T>) => ParsedCursorValues<DB, TB, O, T>;
type CursorPaginationResultRow<
TRow,
TCursorKey extends string | boolean | undefined,
> = TRow & {
[K in TCursorKey extends undefined
? never
: TCursorKey extends false
? never
: TCursorKey extends true
? '$cursor'
: TCursorKey]: string;
};
type CursorPaginationMeta = {
limit: number;
hasNextPage: boolean;
hasPrevPage: boolean;
nextCursor: string | null;
prevCursor: string | null;
};
export type CursorPaginationResult<
TRow,
TCursorKey extends string | boolean | undefined = undefined,
> = {
meta: CursorPaginationMeta;
items: CursorPaginationResultRow<TRow, TCursorKey>[];
};
export async function executeWithCursorPagination<
DB,
TB extends keyof DB,
O,
const TFields extends Fields<DB, TB, O>,
TCursorKey extends string | boolean | undefined = undefined,
>(
qb: SelectQueryBuilder<DB, TB, O>,
opts: {
perPage: number;
cursor?: string;
beforeCursor?: string;
cursorPerRow?: TCursorKey;
fields: TFields;
encodeCursor?: CursorEncoder<DB, TB, O, TFields>;
decodeCursor?: CursorDecoder<DB, TB, O, TFields>;
parseCursor:
| CursorParser<DB, TB, O, TFields>
| { parse: CursorParser<DB, TB, O, TFields> };
},
): Promise<CursorPaginationResult<O, TCursorKey>> {
const encodeCursor = opts.encodeCursor ?? defaultEncodeCursor;
const decodeCursor = opts.decodeCursor ?? defaultDecodeCursor;
const parseCursor =
typeof opts.parseCursor === 'function'
? opts.parseCursor
: opts.parseCursor.parse;
const fields = opts.fields.map((field) => {
let key = field.key;
if (!key && typeof field.expression === 'string') {
const expressionParts = field.expression.split('.');
key = (expressionParts[1] ?? expressionParts[0]) as
| (keyof O & string)
| undefined;
}
if (!key) throw new Error('missing key');
return { ...field, key };
});
function generateCursor(row: O): string {
const cursorFieldValues = fields.map(({ key }) => [
key,
row[key],
]) as EncodeCursorValues<DB, TB, O, TFields>;
return encodeCursor(cursorFieldValues);
}
const fieldNames = fields.map((field) => field.key) as FieldNames<
DB,
TB,
O,
TFields
>;
function applyCursor(
qb: SelectQueryBuilder<DB, TB, O>,
encoded: string,
defaultDirection: OrderByDirection,
) {
const decoded = decodeCursor(encoded, fieldNames);
const cursor = parseCursor(decoded);
return qb.where(({ and, or, eb }) => {
let expression;
for (let i = fields.length - 1; i >= 0; --i) {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const field = fields[i]!;
const comparison = field.direction === defaultDirection ? '>' : '<';
const value = cursor[field.key as keyof typeof cursor];
const conditions = [eb(field.expression, comparison, value)];
if (expression) {
conditions.push(and([eb(field.expression, '=', value), expression]));
}
expression = or(conditions);
}
if (!expression) {
throw new Error('Error building cursor expression');
}
return expression;
});
}
if (opts.cursor) qb = applyCursor(qb, opts.cursor, 'asc');
if (opts.beforeCursor) qb = applyCursor(qb, opts.beforeCursor, 'desc');
const reversed = !!opts.beforeCursor && !opts.cursor;
for (const { expression, direction, orderModifier } of fields) {
qb = qb.orderBy(
expression,
orderModifier ??
(reversed ? (direction === 'asc' ? 'desc' : 'asc') : direction),
);
}
const rows = await qb.limit(opts.perPage + 1).execute();
const hasNextPage = rows.length > opts.perPage;
// If we fetched an extra row to determine if we have a next page, that
// shouldn't be in the returned results
if (rows.length > opts.perPage) rows.pop();
if (reversed) rows.reverse();
const startRow = rows[0];
const endRow = rows[rows.length - 1];
const hasPrevPage = !!opts.cursor;
const prevCursor = hasPrevPage && startRow ? generateCursor(startRow) : null;
const nextCursor = hasNextPage && endRow ? generateCursor(endRow) : null;
return {
items: rows.map((row) => {
if (opts.cursorPerRow) {
const cursorKey =
typeof opts.cursorPerRow === 'string' ? opts.cursorPerRow : '$cursor';
(row as any)[cursorKey] = generateCursor(row);
}
return row as CursorPaginationResultRow<O, TCursorKey>;
}),
meta: {
limit: opts.perPage,
hasNextPage,
hasPrevPage,
nextCursor,
prevCursor,
},
};
}
export function defaultEncodeCursor<
DB,
TB extends keyof DB,
O,
T extends Fields<DB, TB, O>,
>(values: EncodeCursorValues<DB, TB, O, T>) {
const cursor = new URLSearchParams();
for (const [key, value] of values) {
switch (typeof value) {
case 'string':
cursor.set(key, value);
break;
case 'number':
case 'bigint':
cursor.set(key, value.toString(10));
break;
case 'object': {
if (value instanceof Date) {
cursor.set(key, value.toISOString());
break;
}
}
// eslint-disable-next-line no-fallthrough
default:
throw new Error(`Unable to encode '${key.toString()}'`);
}
}
return Buffer.from(cursor.toString(), 'utf8').toString('base64url');
}
export function defaultDecodeCursor<
DB,
TB extends keyof DB,
O,
T extends Fields<DB, TB, O>,
>(
cursor: string,
fields: FieldNames<DB, TB, O, T>,
): DecodedCursor<DB, TB, O, T> {
let parsed;
try {
parsed = [
...new URLSearchParams(
Buffer.from(cursor, 'base64url').toString('utf8'),
).entries(),
];
} catch {
throw new Error('Unparsable cursor');
}
if (parsed.length !== fields.length) {
throw new Error('Unexpected number of fields');
}
for (let i = 0; i < fields.length; i++) {
const field = parsed[i];
const expectedName = fields[i];
if (!field) {
throw new Error('Unable to find field');
}
if (field[0] !== expectedName) {
throw new Error('Unexpected field name');
}
}
return Object.fromEntries(parsed) as DecodedCursor<DB, TB, O, T>;
}
@@ -1,4 +1,5 @@
import {
IsBoolean,
IsNumber,
IsOptional,
IsPositive,
@@ -8,11 +9,6 @@ import {
} from 'class-validator';
export class PaginationOptions {
@IsOptional()
@IsNumber()
@Min(1)
page = 1;
@IsOptional()
@IsNumber()
@IsPositive()
@@ -20,7 +16,19 @@ export class PaginationOptions {
@Max(100)
limit = 20;
@IsOptional()
@IsString()
cursor?: string;
@IsOptional()
@IsString()
beforeCursor?: string;
@IsOptional()
@IsString()
query: string;
@IsOptional()
@IsBoolean()
adminView: boolean;
}
@@ -12,6 +12,23 @@ import {
export class AttachmentRepo {
constructor(@InjectKysely() private readonly db: KyselyDB) {}
private baseFields: Array<keyof Attachment> = [
'id',
'fileName',
'filePath',
'fileSize',
'fileExt',
'mimeType',
'type',
'creatorId',
'pageId',
'spaceId',
'workspaceId',
'createdAt',
'updatedAt',
'deletedAt',
];
async findById(
attachmentId: string,
opts?: {
@@ -22,7 +39,7 @@ export class AttachmentRepo {
return db
.selectFrom('attachments')
.selectAll()
.select(this.baseFields)
.where('id', '=', attachmentId)
.executeTakeFirst();
}
@@ -36,7 +53,7 @@ export class AttachmentRepo {
return db
.insertInto('attachments')
.values(insertableAttachment)
.returningAll()
.returning(this.baseFields)
.executeTakeFirst();
}
@@ -50,7 +67,7 @@ export class AttachmentRepo {
return db
.selectFrom('attachments')
.selectAll()
.select(this.baseFields)
.where('spaceId', '=', spaceId)
.execute();
}
@@ -64,6 +81,7 @@ export class AttachmentRepo {
.updateTable('attachments')
.set(updatableAttachment)
.where('pageId', 'in', pageIds)
.returning(this.baseFields)
.executeTakeFirst();
}
@@ -75,7 +93,7 @@ export class AttachmentRepo {
.updateTable('attachments')
.set(updatableAttachment)
.where('id', '=', attachmentId)
.returningAll()
.returning(this.baseFields)
.executeTakeFirst();
}
@@ -8,7 +8,7 @@ import {
UpdatableComment,
} from '@docmost/db/types/entity.types';
import { PaginationOptions } from '@docmost/db/pagination/pagination-options';
import { executeWithPagination } from '@docmost/db/pagination/pagination';
import { executeWithCursorPagination } from '@docmost/db/pagination/cursor-pagination';
import { ExpressionBuilder } from 'kysely';
import { DB } from '@docmost/db/types/db';
import { jsonObjectFrom } from 'kysely/helpers/postgres';
@@ -20,12 +20,13 @@ export class CommentRepo {
// todo, add workspaceId
async findById(
commentId: string,
opts?: { includeCreator: boolean },
opts?: { includeCreator: boolean; includeResolvedBy: boolean },
): Promise<Comment> {
return await this.db
.selectFrom('comments')
.selectAll('comments')
.$if(opts?.includeCreator, (qb) => qb.select(this.withCreator))
.$if(opts?.includeResolvedBy, (qb) => qb.select(this.withResolvedBy))
.where('id', '=', commentId)
.executeTakeFirst();
}
@@ -35,15 +36,16 @@ export class CommentRepo {
.selectFrom('comments')
.selectAll('comments')
.select((eb) => this.withCreator(eb))
.where('pageId', '=', pageId)
.orderBy('createdAt', 'asc');
.select((eb) => this.withResolvedBy(eb))
.where('pageId', '=', pageId);
const result = executeWithPagination(query, {
page: pagination.page,
return executeWithCursorPagination(query, {
perPage: pagination.limit,
cursor: pagination.cursor,
beforeCursor: pagination.beforeCursor,
fields: [{ expression: 'id', direction: 'asc' }],
parseCursor: (cursor) => ({ id: cursor.id }),
});
return result;
}
async updateComment(
@@ -80,7 +82,37 @@ export class CommentRepo {
).as('creator');
}
withResolvedBy(eb: ExpressionBuilder<DB, 'comments'>) {
return jsonObjectFrom(
eb
.selectFrom('users')
.select(['users.id', 'users.name', 'users.avatarUrl'])
.whereRef('users.id', '=', 'comments.resolvedById'),
).as('resolvedBy');
}
async deleteComment(commentId: string): Promise<void> {
await this.db.deleteFrom('comments').where('id', '=', commentId).execute();
}
async hasChildren(commentId: string): Promise<boolean> {
const result = await this.db
.selectFrom('comments')
.select((eb) => eb.fn.count('id').as('count'))
.where('parentCommentId', '=', commentId)
.executeTakeFirst();
return Number(result?.count) > 0;
}
async hasChildrenFromOtherUsers(commentId: string, userId: string): Promise<boolean> {
const result = await this.db
.selectFrom('comments')
.select((eb) => eb.fn.count('id').as('count'))
.where('parentCommentId', '=', commentId)
.where('creatorId', '!=', userId)
.executeTakeFirst();
return Number(result?.count) > 0;
}
}
@@ -6,9 +6,10 @@ import {
import { InjectKysely } from 'nestjs-kysely';
import { KyselyDB, KyselyTransaction } from '@docmost/db/types/kysely.types';
import { dbOrTx, executeTx } from '@docmost/db/utils';
import { sql } from 'kysely';
import { GroupUser, InsertableGroupUser } from '@docmost/db/types/entity.types';
import { PaginationOptions } from '../../pagination/pagination-options';
import { executeWithPagination } from '@docmost/db/pagination/pagination';
import { executeWithCursorPagination } from '@docmost/db/pagination/cursor-pagination';
import { GroupRepo } from '@docmost/db/repos/group/group.repo';
import { UserRepo } from '@docmost/db/repos/user/user.repo';
@@ -51,18 +52,20 @@ export class GroupUserRepo {
.selectFrom('groupUsers')
.innerJoin('users', 'users.id', 'groupUsers.userId')
.selectAll('users')
.where('groupId', '=', groupId)
.orderBy('createdAt', 'asc');
.where('groupId', '=', groupId);
if (pagination.query) {
query = query.where((eb) =>
eb('users.name', 'ilike', `%${pagination.query}%`),
eb(sql`f_unaccent(users.name)`, 'ilike', sql`f_unaccent(${'%' + pagination.query + '%'})`),
);
}
const result = await executeWithPagination(query, {
page: pagination.page,
const result = await executeWithCursorPagination(query, {
perPage: pagination.limit,
cursor: pagination.cursor,
beforeCursor: pagination.beforeCursor,
fields: [{ expression: 'users.id', direction: 'asc', key: 'id' }],
parseCursor: (cursor) => ({ id: cursor.id }),
});
result.items.map((user) => {
@@ -10,8 +10,8 @@ import {
import { ExpressionBuilder, sql } from 'kysely';
import { PaginationOptions } from '../../pagination/pagination-options';
import { DB } from '@docmost/db/types/db';
import { executeWithPagination } from '@docmost/db/pagination/pagination';
import { DefaultGroup } from '../../../core/group/dto/create-group.dto';
import { executeWithCursorPagination } from '@docmost/db/pagination/cursor-pagination';
@Injectable()
export class GroupRepo {
@@ -107,30 +107,44 @@ export class GroupRepo {
}
async getGroupsPaginated(workspaceId: string, pagination: PaginationOptions) {
let query = this.db
let baseQuery = this.db
.selectFrom('groups')
.selectAll('groups')
.select((eb) => this.withMemberCount(eb))
.where('workspaceId', '=', workspaceId)
.orderBy('memberCount', 'desc')
.orderBy('createdAt', 'asc');
.where('workspaceId', '=', workspaceId);
if (pagination.query) {
query = query.where((eb) =>
eb('name', 'ilike', `%${pagination.query}%`).or(
'description',
baseQuery = baseQuery.where((eb) =>
eb(
sql`f_unaccent(name)`,
'ilike',
`%${pagination.query}%`,
sql`f_unaccent(${'%' + pagination.query + '%'})`,
).or(
sql`f_unaccent(description)`,
'ilike',
sql`f_unaccent(${'%' + pagination.query + '%'})`,
),
);
}
const result = executeWithPagination(query, {
page: pagination.page,
const query = this.db.selectFrom(baseQuery.as('sub')).selectAll('sub');
return executeWithCursorPagination(query, {
perPage: pagination.limit,
cursor: pagination.cursor,
beforeCursor: pagination.beforeCursor,
fields: [
{
expression: 'sub.memberCount',
direction: 'desc',
key: 'memberCount',
},
{ expression: 'sub.id', direction: 'asc', key: 'id' },
],
parseCursor: (cursor) => ({
memberCount: parseInt(cursor.memberCount, 10),
id: cursor.id,
}),
});
return result;
}
withMemberCount(eb: ExpressionBuilder<DB, 'groups'>) {
@@ -8,7 +8,7 @@ import {
PageHistory,
} from '@docmost/db/types/entity.types';
import { PaginationOptions } from '@docmost/db/pagination/pagination-options';
import { executeWithPagination } from '@docmost/db/pagination/pagination';
import { executeWithCursorPagination } from '@docmost/db/pagination/cursor-pagination';
import { jsonObjectFrom } from 'kysely/helpers/postgres';
import { ExpressionBuilder } from 'kysely';
import { DB } from '@docmost/db/types/db';
@@ -65,15 +65,15 @@ export class PageHistoryRepo {
.selectFrom('pageHistory')
.selectAll()
.select((eb) => this.withLastUpdatedBy(eb))
.where('pageId', '=', pageId)
.orderBy('createdAt', 'desc');
.where('pageId', '=', pageId);
const result = executeWithPagination(query, {
page: pagination.page,
return executeWithCursorPagination(query, {
perPage: pagination.limit,
cursor: pagination.cursor,
beforeCursor: pagination.beforeCursor,
fields: [{ expression: 'id', direction: 'desc' }],
parseCursor: (cursor) => ({ id: cursor.id }),
});
return result;
}
async findPageLastHistory(pageId: string, trx?: KyselyTransaction) {
+240 -20
View File
@@ -1,25 +1,28 @@
import { Injectable } from '@nestjs/common';
import { InjectKysely } from 'nestjs-kysely';
import { KyselyDB, KyselyTransaction } from '../../types/kysely.types';
import { dbOrTx } from '../../utils';
import { dbOrTx, executeTx } from '../../utils';
import {
InsertablePage,
Page,
UpdatablePage,
} from '@docmost/db/types/entity.types';
import { PaginationOptions } from '@docmost/db/pagination/pagination-options';
import { executeWithPagination } from '@docmost/db/pagination/pagination';
import { executeWithCursorPagination } from '@docmost/db/pagination/cursor-pagination';
import { validate as isValidUUID } from 'uuid';
import { ExpressionBuilder, sql } from 'kysely';
import { DB } from '@docmost/db/types/db';
import { jsonArrayFrom, jsonObjectFrom } from 'kysely/helpers/postgres';
import { SpaceMemberRepo } from '@docmost/db/repos/space/space-member.repo';
import { EventEmitter2 } from '@nestjs/event-emitter';
import { EventName } from '../../../common/events/event.contants';
@Injectable()
export class PageRepo {
constructor(
@InjectKysely() private readonly db: KyselyDB,
private spaceMemberRepo: SpaceMemberRepo,
private eventEmitter: EventEmitter2,
) {}
private baseFields: Array<keyof Page> = [
@@ -45,11 +48,13 @@ export class PageRepo {
pageId: string,
opts?: {
includeContent?: boolean;
includeTextContent?: boolean;
includeYdoc?: boolean;
includeSpace?: boolean;
includeCreator?: boolean;
includeLastUpdatedBy?: boolean;
includeContributors?: boolean;
includeHasChildren?: boolean;
withLock?: boolean;
trx?: KyselyTransaction;
},
@@ -60,7 +65,11 @@ export class PageRepo {
.selectFrom('pages')
.select(this.baseFields)
.$if(opts?.includeContent, (qb) => qb.select('content'))
.$if(opts?.includeYdoc, (qb) => qb.select('ydoc'));
.$if(opts?.includeYdoc, (qb) => qb.select('ydoc'))
.$if(opts?.includeTextContent, (qb) => qb.select('textContent'))
.$if(opts?.includeHasChildren, (qb) =>
qb.select((eb) => this.withHasChildren(eb)),
);
if (opts?.includeCreator) {
query = query.select((eb) => this.withCreator(eb));
@@ -104,7 +113,7 @@ export class PageRepo {
pageIds: string[],
trx?: KyselyTransaction,
) {
return dbOrTx(this.db, trx)
const result = await dbOrTx(this.db, trx)
.updateTable('pages')
.set({ ...updatePageData, updatedAt: new Date() })
.where(
@@ -113,6 +122,13 @@ export class PageRepo {
pageIds,
)
.executeTakeFirst();
this.eventEmitter.emit(EventName.PAGE_UPDATED, {
pageIds: pageIds,
workspaceId: updatePageData.workspaceId,
});
return result;
}
async insertPage(
@@ -120,11 +136,18 @@ export class PageRepo {
trx?: KyselyTransaction,
): Promise<Page> {
const db = dbOrTx(this.db, trx);
return db
const result = await db
.insertInto('pages')
.values(insertablePage)
.returning(this.baseFields)
.executeTakeFirst();
this.eventEmitter.emit(EventName.PAGE_CREATED, {
pageIds: [result.id],
workspaceId: result.workspaceId,
});
return result;
}
async deletePage(pageId: string): Promise<void> {
@@ -139,40 +162,204 @@ export class PageRepo {
await query.execute();
}
async removePage(
pageId: string,
deletedById: string,
workspaceId: string,
): Promise<void> {
const currentDate = new Date();
const descendants = await this.db
.withRecursive('page_descendants', (db) =>
db
.selectFrom('pages')
.select(['id'])
.where('id', '=', pageId)
.unionAll((exp) =>
exp
.selectFrom('pages as p')
.select(['p.id'])
.innerJoin('page_descendants as pd', 'pd.id', 'p.parentPageId'),
),
)
.selectFrom('page_descendants')
.selectAll()
.execute();
const pageIds = descendants.map((d) => d.id);
if (pageIds.length > 0) {
await executeTx(this.db, async (trx) => {
await trx
.updateTable('pages')
.set({
deletedById: deletedById,
deletedAt: currentDate,
})
.where('id', 'in', pageIds)
.execute();
await trx.deleteFrom('shares').where('pageId', 'in', pageIds).execute();
});
this.eventEmitter.emit(EventName.PAGE_SOFT_DELETED, {
pageIds: pageIds,
workspaceId,
});
}
}
async restorePage(pageId: string, workspaceId: string): Promise<void> {
// First, check if the page being restored has a deleted parent
const pageToRestore = await this.db
.selectFrom('pages')
.select(['id', 'parentPageId'])
.where('id', '=', pageId)
.executeTakeFirst();
if (!pageToRestore) {
return;
}
// Check if the parent is also deleted
let shouldDetachFromParent = false;
if (pageToRestore.parentPageId) {
const parent = await this.db
.selectFrom('pages')
.select(['id', 'deletedAt'])
.where('id', '=', pageToRestore.parentPageId)
.executeTakeFirst();
// If parent is deleted, we should detach this page from it
shouldDetachFromParent = parent?.deletedAt !== null;
}
// Find all descendants to restore
const pages = await this.db
.withRecursive('page_descendants', (db) =>
db
.selectFrom('pages')
.select(['id'])
.where('id', '=', pageId)
.unionAll((exp) =>
exp
.selectFrom('pages as p')
.select(['p.id'])
.innerJoin('page_descendants as pd', 'pd.id', 'p.parentPageId'),
),
)
.selectFrom('page_descendants')
.selectAll()
.execute();
const pageIds = pages.map((p) => p.id);
// Restore all pages, but only detach the root page if its parent is deleted
await this.db
.updateTable('pages')
.set({ deletedById: null, deletedAt: null })
.where('id', 'in', pageIds)
.execute();
// If we need to detach the restored page from its deleted parent
if (shouldDetachFromParent) {
await this.db
.updateTable('pages')
.set({ parentPageId: null })
.where('id', '=', pageId)
.execute();
}
this.eventEmitter.emit(EventName.PAGE_RESTORED, {
pageIds: pageIds,
workspaceId: workspaceId,
});
}
async getRecentPagesInSpace(spaceId: string, pagination: PaginationOptions) {
const query = this.db
.selectFrom('pages')
.select(this.baseFields)
.select((eb) => this.withSpace(eb))
.where('spaceId', '=', spaceId)
.orderBy('updatedAt', 'desc');
.where('deletedAt', 'is', null);
const result = executeWithPagination(query, {
page: pagination.page,
return executeWithCursorPagination(query, {
perPage: pagination.limit,
cursor: pagination.cursor,
beforeCursor: pagination.beforeCursor,
fields: [
{ expression: 'updatedAt', direction: 'desc' },
{ expression: 'id', direction: 'desc' },
],
parseCursor: (cursor) => ({
updatedAt: new Date(cursor.updatedAt),
id: cursor.id,
}),
});
return result;
}
async getRecentPages(userId: string, pagination: PaginationOptions) {
const userSpaceIds = await this.spaceMemberRepo.getUserSpaceIds(userId);
const query = this.db
.selectFrom('pages')
.select(this.baseFields)
.select((eb) => this.withSpace(eb))
.where('spaceId', 'in', userSpaceIds)
.orderBy('updatedAt', 'desc');
.where('spaceId', 'in', this.spaceMemberRepo.getUserSpaceIdsQuery(userId))
.where('deletedAt', 'is', null);
const hasEmptyIds = userSpaceIds.length === 0;
const result = executeWithPagination(query, {
page: pagination.page,
return executeWithCursorPagination(query, {
perPage: pagination.limit,
hasEmptyIds,
cursor: pagination.cursor,
beforeCursor: pagination.beforeCursor,
fields: [
{ expression: 'updatedAt', direction: 'desc' },
{ expression: 'id', direction: 'desc' },
],
parseCursor: (cursor) => ({
updatedAt: new Date(cursor.updatedAt),
id: cursor.id,
}),
});
}
return result;
async getDeletedPagesInSpace(spaceId: string, pagination: PaginationOptions) {
const query = this.db
.selectFrom('pages')
.select(this.baseFields)
.select('content')
.select((eb) => this.withSpace(eb))
.select((eb) => this.withDeletedBy(eb))
.where('spaceId', '=', spaceId)
.where('deletedAt', 'is not', null)
// Only include pages that are either root pages (no parent) or whose parent is not deleted
// This prevents showing orphaned pages when their parent has been soft-deleted
.where((eb) =>
eb.or([
eb('parentPageId', 'is', null),
eb.not(
eb.exists(
eb
.selectFrom('pages as parent')
.select('parent.id')
.where('parent.id', '=', eb.ref('pages.parentPageId'))
.where('parent.deletedAt', 'is not', null),
),
),
]),
);
return executeWithCursorPagination(query, {
perPage: pagination.limit,
cursor: pagination.cursor,
beforeCursor: pagination.beforeCursor,
fields: [
{ expression: 'deletedAt', direction: 'desc' },
{ expression: 'id', direction: 'desc' },
],
parseCursor: (cursor) => ({
deletedAt: new Date(cursor.deletedAt),
id: cursor.id,
}),
});
}
withSpace(eb: ExpressionBuilder<DB, 'pages'>) {
@@ -202,6 +389,15 @@ export class PageRepo {
).as('lastUpdatedBy');
}
withDeletedBy(eb: ExpressionBuilder<DB, 'pages'>) {
return jsonObjectFrom(
eb
.selectFrom('users')
.select(['users.id', 'users.name', 'users.avatarUrl'])
.whereRef('users.id', '=', 'pages.deletedById'),
).as('deletedBy');
}
withContributors(eb: ExpressionBuilder<DB, 'pages'>) {
return jsonArrayFrom(
eb
@@ -211,6 +407,24 @@ export class PageRepo {
).as('contributors');
}
withHasChildren(eb: ExpressionBuilder<DB, 'pages'>) {
return eb
.selectFrom('pages as child')
.select((eb) =>
eb
.case()
.when(eb.fn.countAll(), '>', 0)
.then(true)
.else(false)
.end()
.as('count'),
)
.whereRef('child.parentPageId', '=', 'pages.id')
.where('child.deletedAt', 'is', null)
.limit(1)
.as('hasChildren');
}
async getPageAndDescendants(
parentPageId: string,
opts: { includeContent: boolean },
@@ -228,9 +442,12 @@ export class PageRepo {
'parentPageId',
'spaceId',
'workspaceId',
'createdAt',
'updatedAt',
])
.$if(opts?.includeContent, (qb) => qb.select('content'))
.where('id', '=', parentPageId)
.where('deletedAt', 'is', null)
.unionAll((exp) =>
exp
.selectFrom('pages as p')
@@ -243,9 +460,12 @@ export class PageRepo {
'p.parentPageId',
'p.spaceId',
'p.workspaceId',
'p.createdAt',
'p.updatedAt',
])
.$if(opts?.includeContent, (qb) => qb.select('p.content'))
.innerJoin('page_hierarchy as ph', 'p.parentPageId', 'ph.id'),
.innerJoin('page_hierarchy as ph', 'p.parentPageId', 'ph.id')
.where('p.deletedAt', 'is', null),
),
)
.selectFrom('page_hierarchy')
@@ -8,7 +8,7 @@ import {
UpdatableShare,
} from '@docmost/db/types/entity.types';
import { PaginationOptions } from '@docmost/db/pagination/pagination-options';
import { executeWithPagination } from '@docmost/db/pagination/pagination';
import { executeWithCursorPagination } from '@docmost/db/pagination/cursor-pagination';
import { validate as isValidUUID } from 'uuid';
import { ExpressionBuilder, sql } from 'kysely';
import { DB } from '@docmost/db/types/db';
@@ -137,25 +137,27 @@ export class ShareRepo {
}
async getShares(userId: string, pagination: PaginationOptions) {
const userSpaceIds = await this.spaceMemberRepo.getUserSpaceIds(userId);
const query = this.db
.selectFrom('shares')
.select(this.baseFields)
.select((eb) => this.withPage(eb))
.select((eb) => this.withSpace(eb, userId))
.select((eb) => this.withCreator(eb))
.where('spaceId', 'in', userSpaceIds)
.orderBy('updatedAt', 'desc');
.where('spaceId', 'in', this.spaceMemberRepo.getUserSpaceIdsQuery(userId));
const hasEmptyIds = userSpaceIds.length === 0;
const result = executeWithPagination(query, {
page: pagination.page,
return executeWithCursorPagination(query, {
perPage: pagination.limit,
hasEmptyIds,
cursor: pagination.cursor,
beforeCursor: pagination.beforeCursor,
fields: [
{ expression: 'updatedAt', direction: 'desc' },
{ expression: 'id', direction: 'desc' },
],
parseCursor: (cursor) => ({
updatedAt: new Date(cursor.updatedAt),
id: cursor.id,
}),
});
return result;
}
withPage(eb: ExpressionBuilder<DB, 'shares'>) {

Some files were not shown because too many files have changed in this diff Show More