From 046d9a31f1d667543b87d1768a67b85d791d4cad Mon Sep 17 00:00:00 2001 From: Philipinho <16838612+Philipinho@users.noreply.github.com> Date: Wed, 11 Feb 2026 12:58:09 -0800 Subject: [PATCH] Sidebar tree socket permissions --- .../page/services/page-permission.service.ts | 26 +++ .../src/core/page/services/page.service.ts | 32 +++- .../repos/page/page-permission.repo.ts | 41 +++++ apps/server/src/ws/ws-tree.service.ts | 47 ++++++ apps/server/src/ws/ws.gateway.ts | 38 +++-- apps/server/src/ws/ws.module.ts | 8 +- apps/server/src/ws/ws.service.ts | 157 ++++++++++++++++++ apps/server/src/ws/ws.utils.ts | 17 ++ 8 files changed, 338 insertions(+), 28 deletions(-) create mode 100644 apps/server/src/ws/ws-tree.service.ts create mode 100644 apps/server/src/ws/ws.service.ts create mode 100644 apps/server/src/ws/ws.utils.ts diff --git a/apps/server/src/core/page/services/page-permission.service.ts b/apps/server/src/core/page/services/page-permission.service.ts index 03722c29..3b35a2fb 100644 --- a/apps/server/src/core/page/services/page-permission.service.ts +++ b/apps/server/src/core/page/services/page-permission.service.ts @@ -32,6 +32,8 @@ import { CursorPaginationResult, emptyCursorPaginationResult, } from '@docmost/db/pagination/cursor-pagination'; +import { WsService } from '../../../ws/ws.service'; +import { WsTreeService } from '../../../ws/ws-tree.service'; export type PageRestrictionInfo = { id: string; @@ -51,6 +53,8 @@ export class PagePermissionService { private pagePermissionRepo: PagePermissionRepo, private pageRepo: PageRepo, private spaceAbility: SpaceAbilityFactory, + private wsService: WsService, + private wsTreeService: WsTreeService, @InjectKysely() private readonly db: KyselyDB, ) {} @@ -95,6 +99,9 @@ export class PagePermissionService { trx, ); }); + + await this.wsService.invalidateSpaceRestrictionCache(page.spaceId); + await this.wsTreeService.notifyPageRestricted(page, authUser.id); } async addPagePermissions( @@ -181,6 +188,23 @@ export class PagePermissionService { if (permissionsToAdd.length > 0) { await this.pagePermissionRepo.insertPagePermissions(permissionsToAdd); + + const notifyUserIds = validUsers.map((u) => u.id); + + if (validGroups.length > 0) { + const groupMembers = await this.db + .selectFrom('groupUsers') + .select('userId') + .where( + 'groupId', + 'in', + validGroups.map((g) => g.id), + ) + .execute(); + notifyUserIds.push(...groupMembers.map((m) => m.userId)); + } + + await this.wsTreeService.notifyPermissionGranted(page, notifyUserIds); } } @@ -314,6 +338,8 @@ export class PagePermissionService { } await this.pagePermissionRepo.deletePageAccess(pageId); + + await this.wsService.invalidateSpaceRestrictionCache(page.spaceId); } async getPagePermissions( diff --git a/apps/server/src/core/page/services/page.service.ts b/apps/server/src/core/page/services/page.service.ts index 85e3ae7f..f97dc63e 100644 --- a/apps/server/src/core/page/services/page.service.ts +++ b/apps/server/src/core/page/services/page.service.ts @@ -217,7 +217,11 @@ export class PageService { cursor: pagination.cursor, beforeCursor: pagination.beforeCursor, fields: [ - { expression: 'position', direction: 'asc', orderModifier: (ob) => ob.collate('C').asc() }, + { + expression: 'position', + direction: 'asc', + orderModifier: (ob) => ob.collate('C').asc(), + }, { expression: 'id', direction: 'asc' }, ], parseCursor: (cursor) => ({ @@ -296,13 +300,19 @@ export class PageService { // Find inaccessible pages whose parent is being moved - these need to be orphaned const pagesToOrphan = allPages.filter( - (p) => !accessibleIds.has(p.id) && p.parentPageId && accessibleIds.has(p.parentPageId), + (p) => + !accessibleIds.has(p.id) && + p.parentPageId && + accessibleIds.has(p.parentPageId), ); await executeTx(this.db, async (trx) => { // Orphan inaccessible child pages (make them root pages in original space) for (const page of pagesToOrphan) { - const orphanPosition = await this.nextPagePosition(rootPage.spaceId, null); + const orphanPosition = await this.nextPagePosition( + rootPage.spaceId, + null, + ); await this.pageRepo.updatePage( { parentPageId: null, position: orphanPosition }, page.id, @@ -689,7 +699,10 @@ export class PageService { userId: string, pagination: PaginationOptions, ): Promise> { - const result = await this.pageRepo.getRecentPagesInSpace(spaceId, pagination); + const result = await this.pageRepo.getRecentPagesInSpace( + spaceId, + pagination, + ); if (result.items.length > 0) { const pageIds = result.items.map((p) => p.id); @@ -814,7 +827,9 @@ export class PageService { * 2. Its parent is also included (or it's the root page) * This ensures that if a middle page is inaccessible, its entire subtree is excluded. */ - private async filterAccessibleTreePages( + private async filterAccessibleTreePages< + T extends { id: string; parentPageId: string | null }, + >( pages: T[], rootPageId: string, userId: string, @@ -823,12 +838,13 @@ export class PageService { if (pages.length === 0) return []; const pageIds = pages.map((p) => p.id); - const accessibleIds = - await this.pagePermissionRepo.filterAccessiblePageIds({ + const accessibleIds = await this.pagePermissionRepo.filterAccessiblePageIds( + { pageIds, userId, spaceId, - }); + }, + ); const accessibleSet = new Set(accessibleIds); // Prune: include a page only if it's accessible AND its parent chain to root is included diff --git a/apps/server/src/database/repos/page/page-permission.repo.ts b/apps/server/src/database/repos/page/page-permission.repo.ts index 6d5f308c..f4fa5ecb 100644 --- a/apps/server/src/database/repos/page/page-permission.repo.ts +++ b/apps/server/src/database/repos/page/page-permission.repo.ts @@ -984,6 +984,47 @@ export class PagePermissionRepo { return results.map((r) => r.descendantId); } + /** + * Given a pageId and a set of candidate userIds, return the subset who can + * access the page (have permission on ALL restricted ancestors). + * Returns all userIds if the page has no restricted ancestors. + */ + async getUserIdsWithPageAccess( + pageId: string, + userIds: string[], + ): Promise { + if (userIds.length === 0) return []; + + const results = await sql<{ userId: string }>` + WITH RECURSIVE ancestors AS ( + SELECT id AS ancestor_id, parent_page_id + FROM pages + WHERE id = ${pageId}::uuid + UNION ALL + SELECT p.id, p.parent_page_id + FROM pages p + JOIN ancestors a ON a.parent_page_id = p.id + ) + SELECT cu.user_id AS "userId" + FROM unnest(${userIds}::uuid[]) AS cu(user_id) + WHERE NOT EXISTS ( + SELECT 1 + FROM ancestors a + JOIN page_access pa ON pa.page_id = a.ancestor_id + LEFT JOIN page_permissions pp ON pp.page_access_id = pa.id + AND ( + pp.user_id = cu.user_id + OR pp.group_id IN ( + SELECT gu.group_id FROM group_users gu WHERE gu.user_id = cu.user_id + ) + ) + WHERE pp.id IS NULL + ) + `.execute(this.db); + + return results.rows.map((r) => r.userId); + } + private userGroupIdsSubquery( eb: ExpressionBuilder, userId: string, diff --git a/apps/server/src/ws/ws-tree.service.ts b/apps/server/src/ws/ws-tree.service.ts new file mode 100644 index 00000000..8aadfa99 --- /dev/null +++ b/apps/server/src/ws/ws-tree.service.ts @@ -0,0 +1,47 @@ +import { Injectable } from '@nestjs/common'; +import { Page } from '@docmost/db/types/entity.types'; +import { WsService } from './ws.service'; + +@Injectable() +export class WsTreeService { + constructor(private readonly wsService: WsService) {} + + async notifyPageRestricted(page: Page, excludeUserId: string): Promise { + await this.wsService.emitToSpaceExceptUsers(page.spaceId, [excludeUserId], { + operation: 'deleteTreeNode', + spaceId: page.spaceId, + payload: { + node: { + id: page.id, + slugId: page.slugId, + }, + }, + }); + } + + async notifyPermissionGranted(page: Page, userIds: string[]): Promise { + if (userIds.length === 0) return; + + await this.wsService.emitToUsers(userIds, { + operation: 'addTreeNode', + spaceId: page.spaceId, + payload: { + parentId: page.parentPageId ?? null, + index: 0, + data: { + id: page.id, + slugId: page.slugId, + name: page.title ?? '', + title: page.title, + icon: page.icon, + position: page.position, + spaceId: page.spaceId, + parentPageId: page.parentPageId, + creatorId: page.creatorId, + hasChildren: false, + children: [], + }, + }, + }); + } +} diff --git a/apps/server/src/ws/ws.gateway.ts b/apps/server/src/ws/ws.gateway.ts index eeaec897..231dfade 100644 --- a/apps/server/src/ws/ws.gateway.ts +++ b/apps/server/src/ws/ws.gateway.ts @@ -1,6 +1,7 @@ import { MessageBody, OnGatewayConnection, + OnGatewayInit, SubscribeMessage, WebSocketGateway, WebSocketServer, @@ -10,20 +11,30 @@ import { TokenService } from '../core/auth/services/token.service'; import { JwtPayload, JwtType } from '../core/auth/dto/jwt-payload'; import { OnModuleDestroy } from '@nestjs/common'; import { SpaceMemberRepo } from '@docmost/db/repos/space/space-member.repo'; +import { WsService } from './ws.service'; +import { getSpaceRoomName, getUserRoomName } from './ws.utils'; import * as cookie from 'cookie'; @WebSocketGateway({ cors: { origin: '*' }, transports: ['websocket'], }) -export class WsGateway implements OnGatewayConnection, OnModuleDestroy { +export class WsGateway + implements OnGatewayConnection, OnGatewayInit, OnModuleDestroy +{ @WebSocketServer() server: Server; + constructor( private tokenService: TokenService, private spaceMemberRepo: SpaceMemberRepo, + private wsService: WsService, ) {} + afterInit(server: Server): void { + this.wsService.setServer(server); + } + async handleConnection(client: Socket, ...args: any[]): Promise { try { const cookies = cookie.parse(client.handshake.headers.cookie); @@ -35,12 +46,15 @@ export class WsGateway implements OnGatewayConnection, OnModuleDestroy { const userId = token.sub; const workspaceId = token.workspaceId; + client.data.userId = userId; + const userSpaceIds = await this.spaceMemberRepo.getUserSpaceIds(userId); + const userRoom = getUserRoomName(userId); const workspaceRoom = `workspace-${workspaceId}`; - const spaceRooms = userSpaceIds.map((id) => this.getSpaceRoomName(id)); + const spaceRooms = userSpaceIds.map((id) => getSpaceRoomName(id)); - client.join([workspaceRoom, ...spaceRooms]); + client.join([userRoom, workspaceRoom, ...spaceRooms]); } catch (err) { client.emit('Unauthorized'); client.disconnect(); @@ -48,17 +62,9 @@ export class WsGateway implements OnGatewayConnection, OnModuleDestroy { } @SubscribeMessage('message') - handleMessage(client: Socket, data: any): void { - const spaceEvents = [ - 'updateOne', - 'addTreeNode', - 'moveTreeNode', - 'deleteTreeNode', - ]; - - if (spaceEvents.includes(data?.operation) && data?.spaceId) { - const room = this.getSpaceRoomName(data.spaceId); - client.broadcast.to(room).emit('message', data); + async handleMessage(client: Socket, data: any): Promise { + if (this.wsService.isTreeEvent(data)) { + await this.wsService.handleTreeEvent(client, data); return; } @@ -81,8 +87,4 @@ export class WsGateway implements OnGatewayConnection, OnModuleDestroy { this.server.close(); } } - - getSpaceRoomName(spaceId: string): string { - return `space-${spaceId}`; - } } diff --git a/apps/server/src/ws/ws.module.ts b/apps/server/src/ws/ws.module.ts index aa2d9b7c..b6fcd66a 100644 --- a/apps/server/src/ws/ws.module.ts +++ b/apps/server/src/ws/ws.module.ts @@ -1,9 +1,13 @@ -import { Module } from '@nestjs/common'; +import { Global, Module } from '@nestjs/common'; import { WsGateway } from './ws.gateway'; +import { WsService } from './ws.service'; +import { WsTreeService } from './ws-tree.service'; import { TokenModule } from '../core/auth/token.module'; +@Global() @Module({ imports: [TokenModule], - providers: [WsGateway], + providers: [WsGateway, WsService, WsTreeService], + exports: [WsService, WsTreeService], }) export class WsModule {} diff --git a/apps/server/src/ws/ws.service.ts b/apps/server/src/ws/ws.service.ts new file mode 100644 index 00000000..e2bf4807 --- /dev/null +++ b/apps/server/src/ws/ws.service.ts @@ -0,0 +1,157 @@ +import { Inject, Injectable } from '@nestjs/common'; +import { CACHE_MANAGER } from '@nestjs/cache-manager'; +import { Cache } from 'cache-manager'; +import { Server, Socket } from 'socket.io'; +import { PagePermissionRepo } from '@docmost/db/repos/page/page-permission.repo'; +import { + TREE_EVENTS, + WS_SPACE_RESTRICTION_CACHE_PREFIX, + WS_CACHE_TTL_MS, + getSpaceRoomName, + getUserRoomName, +} from './ws.utils'; + +@Injectable() +export class WsService { + private server: Server; + + constructor( + private readonly pagePermissionRepo: PagePermissionRepo, + @Inject(CACHE_MANAGER) private readonly cacheManager: Cache, + ) {} + + setServer(server: Server): void { + this.server = server; + } + + async handleTreeEvent(client: Socket, data: any): Promise { + const room = getSpaceRoomName(data.spaceId); + + const hasRestrictions = await this.spaceHasRestrictions(data.spaceId); + if (!hasRestrictions) { + client.broadcast.to(room).emit('message', data); + return; + } + + const pageId = this.extractPageId(data); + if (!pageId) { + client.broadcast.to(room).emit('message', data); + return; + } + + const isRestricted = + await this.pagePermissionRepo.hasRestrictedAncestor(pageId); + if (!isRestricted) { + client.broadcast.to(room).emit('message', data); + return; + } + + await this.broadcastToAuthorizedUsers(client, room, pageId, data); + } + + async invalidateSpaceRestrictionCache(spaceId: string): Promise { + await this.cacheManager.del( + `${WS_SPACE_RESTRICTION_CACHE_PREFIX}${spaceId}`, + ); + } + + async emitToUsers(userIds: string[], data: any): Promise { + if (userIds.length === 0) return; + const rooms = userIds.map((id) => getUserRoomName(id)); + this.server.to(rooms).emit('message', data); + } + + async emitToSpaceExceptUsers( + spaceId: string, + excludeUserIds: string[], + data: any, + ): Promise { + const room = getSpaceRoomName(spaceId); + const sockets = await this.server.in(room).fetchSockets(); + const excludeSet = new Set(excludeUserIds); + + for (const socket of sockets) { + const userId = socket.data.userId as string; + if (userId && !excludeSet.has(userId)) { + socket.emit('message', data); + } + } + } + + isTreeEvent(data: any): boolean { + return TREE_EVENTS.has(data?.operation) && !!data?.spaceId; + } + + private async broadcastToAuthorizedUsers( + sender: Socket, + room: string, + pageId: string, + data: any, + ): Promise { + const sockets = await this.server.in(room).fetchSockets(); + + const otherSockets = sockets.filter((s) => s.id !== sender.id); + if (otherSockets.length === 0) return; + + const userSocketMap = new Map(); + for (const socket of otherSockets) { + const userId = socket.data.userId as string; + if (!userId) continue; + const existing = userSocketMap.get(userId); + if (existing) { + existing.push(socket); + } else { + userSocketMap.set(userId, [socket]); + } + } + + const candidateUserIds = Array.from(userSocketMap.keys()); + if (candidateUserIds.length === 0) return; + + const authorizedUserIds = + await this.pagePermissionRepo.getUserIdsWithPageAccess( + pageId, + candidateUserIds, + ); + + const authorizedSet = new Set(authorizedUserIds); + for (const [userId, userSockets] of userSocketMap) { + if (authorizedSet.has(userId)) { + for (const socket of userSockets) { + socket.emit('message', data); + } + } + } + } + + private async spaceHasRestrictions(spaceId: string): Promise { + const cacheKey = `${WS_SPACE_RESTRICTION_CACHE_PREFIX}${spaceId}`; + + const cached = await this.cacheManager.get(cacheKey); + if (cached !== undefined && cached !== null) { + return cached; + } + + const hasRestrictions = + await this.pagePermissionRepo.hasRestrictedPagesInSpace(spaceId); + + await this.cacheManager.set(cacheKey, hasRestrictions, WS_CACHE_TTL_MS); + + return hasRestrictions; + } + + private extractPageId(data: any): string | null { + switch (data.operation) { + case 'addTreeNode': + return data.payload?.data?.id ?? null; + case 'moveTreeNode': + return data.payload?.id ?? null; + case 'deleteTreeNode': + return data.payload?.node?.id ?? null; + case 'updateOne': + return data.id ?? null; + default: + return null; + } + } +} diff --git a/apps/server/src/ws/ws.utils.ts b/apps/server/src/ws/ws.utils.ts new file mode 100644 index 00000000..0cf460f1 --- /dev/null +++ b/apps/server/src/ws/ws.utils.ts @@ -0,0 +1,17 @@ +export const WS_CACHE_TTL_MS = 30_000; +export const WS_SPACE_RESTRICTION_CACHE_PREFIX = 'ws:space-restrictions:'; + +export function getSpaceRoomName(spaceId: string): string { + return `space-${spaceId}`; +} + +export function getUserRoomName(userId: string): string { + return `user-${userId}`; +} + +export const TREE_EVENTS = new Set([ + 'updateOne', + 'addTreeNode', + 'moveTreeNode', + 'deleteTreeNode', +]);