Sidebar tree socket permissions

This commit is contained in:
Philipinho
2026-02-11 12:58:09 -08:00
parent dacc26dea7
commit 046d9a31f1
8 changed files with 338 additions and 28 deletions
@@ -32,6 +32,8 @@ import {
CursorPaginationResult,
emptyCursorPaginationResult,
} from '@docmost/db/pagination/cursor-pagination';
import { WsService } from '../../../ws/ws.service';
import { WsTreeService } from '../../../ws/ws-tree.service';
export type PageRestrictionInfo = {
id: string;
@@ -51,6 +53,8 @@ export class PagePermissionService {
private pagePermissionRepo: PagePermissionRepo,
private pageRepo: PageRepo,
private spaceAbility: SpaceAbilityFactory,
private wsService: WsService,
private wsTreeService: WsTreeService,
@InjectKysely() private readonly db: KyselyDB,
) {}
@@ -95,6 +99,9 @@ export class PagePermissionService {
trx,
);
});
await this.wsService.invalidateSpaceRestrictionCache(page.spaceId);
await this.wsTreeService.notifyPageRestricted(page, authUser.id);
}
async addPagePermissions(
@@ -181,6 +188,23 @@ export class PagePermissionService {
if (permissionsToAdd.length > 0) {
await this.pagePermissionRepo.insertPagePermissions(permissionsToAdd);
const notifyUserIds = validUsers.map((u) => u.id);
if (validGroups.length > 0) {
const groupMembers = await this.db
.selectFrom('groupUsers')
.select('userId')
.where(
'groupId',
'in',
validGroups.map((g) => g.id),
)
.execute();
notifyUserIds.push(...groupMembers.map((m) => m.userId));
}
await this.wsTreeService.notifyPermissionGranted(page, notifyUserIds);
}
}
@@ -314,6 +338,8 @@ export class PagePermissionService {
}
await this.pagePermissionRepo.deletePageAccess(pageId);
await this.wsService.invalidateSpaceRestrictionCache(page.spaceId);
}
async getPagePermissions(
@@ -217,7 +217,11 @@ export class PageService {
cursor: pagination.cursor,
beforeCursor: pagination.beforeCursor,
fields: [
{ expression: 'position', direction: 'asc', orderModifier: (ob) => ob.collate('C').asc() },
{
expression: 'position',
direction: 'asc',
orderModifier: (ob) => ob.collate('C').asc(),
},
{ expression: 'id', direction: 'asc' },
],
parseCursor: (cursor) => ({
@@ -296,13 +300,19 @@ export class PageService {
// Find inaccessible pages whose parent is being moved - these need to be orphaned
const pagesToOrphan = allPages.filter(
(p) => !accessibleIds.has(p.id) && p.parentPageId && accessibleIds.has(p.parentPageId),
(p) =>
!accessibleIds.has(p.id) &&
p.parentPageId &&
accessibleIds.has(p.parentPageId),
);
await executeTx(this.db, async (trx) => {
// Orphan inaccessible child pages (make them root pages in original space)
for (const page of pagesToOrphan) {
const orphanPosition = await this.nextPagePosition(rootPage.spaceId, null);
const orphanPosition = await this.nextPagePosition(
rootPage.spaceId,
null,
);
await this.pageRepo.updatePage(
{ parentPageId: null, position: orphanPosition },
page.id,
@@ -689,7 +699,10 @@ export class PageService {
userId: string,
pagination: PaginationOptions,
): Promise<CursorPaginationResult<Page>> {
const result = await this.pageRepo.getRecentPagesInSpace(spaceId, pagination);
const result = await this.pageRepo.getRecentPagesInSpace(
spaceId,
pagination,
);
if (result.items.length > 0) {
const pageIds = result.items.map((p) => p.id);
@@ -814,7 +827,9 @@ export class PageService {
* 2. Its parent is also included (or it's the root page)
* This ensures that if a middle page is inaccessible, its entire subtree is excluded.
*/
private async filterAccessibleTreePages<T extends { id: string; parentPageId: string | null }>(
private async filterAccessibleTreePages<
T extends { id: string; parentPageId: string | null },
>(
pages: T[],
rootPageId: string,
userId: string,
@@ -823,12 +838,13 @@ export class PageService {
if (pages.length === 0) return [];
const pageIds = pages.map((p) => p.id);
const accessibleIds =
await this.pagePermissionRepo.filterAccessiblePageIds({
const accessibleIds = await this.pagePermissionRepo.filterAccessiblePageIds(
{
pageIds,
userId,
spaceId,
});
},
);
const accessibleSet = new Set(accessibleIds);
// Prune: include a page only if it's accessible AND its parent chain to root is included
@@ -984,6 +984,47 @@ export class PagePermissionRepo {
return results.map((r) => r.descendantId);
}
/**
* Given a pageId and a set of candidate userIds, return the subset who can
* access the page (have permission on ALL restricted ancestors).
* Returns all userIds if the page has no restricted ancestors.
*/
async getUserIdsWithPageAccess(
pageId: string,
userIds: string[],
): Promise<string[]> {
if (userIds.length === 0) return [];
const results = await sql<{ userId: string }>`
WITH RECURSIVE ancestors AS (
SELECT id AS ancestor_id, parent_page_id
FROM pages
WHERE id = ${pageId}::uuid
UNION ALL
SELECT p.id, p.parent_page_id
FROM pages p
JOIN ancestors a ON a.parent_page_id = p.id
)
SELECT cu.user_id AS "userId"
FROM unnest(${userIds}::uuid[]) AS cu(user_id)
WHERE NOT EXISTS (
SELECT 1
FROM ancestors a
JOIN page_access pa ON pa.page_id = a.ancestor_id
LEFT JOIN page_permissions pp ON pp.page_access_id = pa.id
AND (
pp.user_id = cu.user_id
OR pp.group_id IN (
SELECT gu.group_id FROM group_users gu WHERE gu.user_id = cu.user_id
)
)
WHERE pp.id IS NULL
)
`.execute(this.db);
return results.rows.map((r) => r.userId);
}
private userGroupIdsSubquery(
eb: ExpressionBuilder<any, keyof DB>,
userId: string,
+47
View File
@@ -0,0 +1,47 @@
import { Injectable } from '@nestjs/common';
import { Page } from '@docmost/db/types/entity.types';
import { WsService } from './ws.service';
@Injectable()
export class WsTreeService {
constructor(private readonly wsService: WsService) {}
async notifyPageRestricted(page: Page, excludeUserId: string): Promise<void> {
await this.wsService.emitToSpaceExceptUsers(page.spaceId, [excludeUserId], {
operation: 'deleteTreeNode',
spaceId: page.spaceId,
payload: {
node: {
id: page.id,
slugId: page.slugId,
},
},
});
}
async notifyPermissionGranted(page: Page, userIds: string[]): Promise<void> {
if (userIds.length === 0) return;
await this.wsService.emitToUsers(userIds, {
operation: 'addTreeNode',
spaceId: page.spaceId,
payload: {
parentId: page.parentPageId ?? null,
index: 0,
data: {
id: page.id,
slugId: page.slugId,
name: page.title ?? '',
title: page.title,
icon: page.icon,
position: page.position,
spaceId: page.spaceId,
parentPageId: page.parentPageId,
creatorId: page.creatorId,
hasChildren: false,
children: [],
},
},
});
}
}
+20 -18
View File
@@ -1,6 +1,7 @@
import {
MessageBody,
OnGatewayConnection,
OnGatewayInit,
SubscribeMessage,
WebSocketGateway,
WebSocketServer,
@@ -10,20 +11,30 @@ import { TokenService } from '../core/auth/services/token.service';
import { JwtPayload, JwtType } from '../core/auth/dto/jwt-payload';
import { OnModuleDestroy } from '@nestjs/common';
import { SpaceMemberRepo } from '@docmost/db/repos/space/space-member.repo';
import { WsService } from './ws.service';
import { getSpaceRoomName, getUserRoomName } from './ws.utils';
import * as cookie from 'cookie';
@WebSocketGateway({
cors: { origin: '*' },
transports: ['websocket'],
})
export class WsGateway implements OnGatewayConnection, OnModuleDestroy {
export class WsGateway
implements OnGatewayConnection, OnGatewayInit, OnModuleDestroy
{
@WebSocketServer()
server: Server;
constructor(
private tokenService: TokenService,
private spaceMemberRepo: SpaceMemberRepo,
private wsService: WsService,
) {}
afterInit(server: Server): void {
this.wsService.setServer(server);
}
async handleConnection(client: Socket, ...args: any[]): Promise<void> {
try {
const cookies = cookie.parse(client.handshake.headers.cookie);
@@ -35,12 +46,15 @@ export class WsGateway implements OnGatewayConnection, OnModuleDestroy {
const userId = token.sub;
const workspaceId = token.workspaceId;
client.data.userId = userId;
const userSpaceIds = await this.spaceMemberRepo.getUserSpaceIds(userId);
const userRoom = getUserRoomName(userId);
const workspaceRoom = `workspace-${workspaceId}`;
const spaceRooms = userSpaceIds.map((id) => this.getSpaceRoomName(id));
const spaceRooms = userSpaceIds.map((id) => getSpaceRoomName(id));
client.join([workspaceRoom, ...spaceRooms]);
client.join([userRoom, workspaceRoom, ...spaceRooms]);
} catch (err) {
client.emit('Unauthorized');
client.disconnect();
@@ -48,17 +62,9 @@ export class WsGateway implements OnGatewayConnection, OnModuleDestroy {
}
@SubscribeMessage('message')
handleMessage(client: Socket, data: any): void {
const spaceEvents = [
'updateOne',
'addTreeNode',
'moveTreeNode',
'deleteTreeNode',
];
if (spaceEvents.includes(data?.operation) && data?.spaceId) {
const room = this.getSpaceRoomName(data.spaceId);
client.broadcast.to(room).emit('message', data);
async handleMessage(client: Socket, data: any): Promise<void> {
if (this.wsService.isTreeEvent(data)) {
await this.wsService.handleTreeEvent(client, data);
return;
}
@@ -81,8 +87,4 @@ export class WsGateway implements OnGatewayConnection, OnModuleDestroy {
this.server.close();
}
}
getSpaceRoomName(spaceId: string): string {
return `space-${spaceId}`;
}
}
+6 -2
View File
@@ -1,9 +1,13 @@
import { Module } from '@nestjs/common';
import { Global, Module } from '@nestjs/common';
import { WsGateway } from './ws.gateway';
import { WsService } from './ws.service';
import { WsTreeService } from './ws-tree.service';
import { TokenModule } from '../core/auth/token.module';
@Global()
@Module({
imports: [TokenModule],
providers: [WsGateway],
providers: [WsGateway, WsService, WsTreeService],
exports: [WsService, WsTreeService],
})
export class WsModule {}
+157
View File
@@ -0,0 +1,157 @@
import { Inject, Injectable } from '@nestjs/common';
import { CACHE_MANAGER } from '@nestjs/cache-manager';
import { Cache } from 'cache-manager';
import { Server, Socket } from 'socket.io';
import { PagePermissionRepo } from '@docmost/db/repos/page/page-permission.repo';
import {
TREE_EVENTS,
WS_SPACE_RESTRICTION_CACHE_PREFIX,
WS_CACHE_TTL_MS,
getSpaceRoomName,
getUserRoomName,
} from './ws.utils';
@Injectable()
export class WsService {
private server: Server;
constructor(
private readonly pagePermissionRepo: PagePermissionRepo,
@Inject(CACHE_MANAGER) private readonly cacheManager: Cache,
) {}
setServer(server: Server): void {
this.server = server;
}
async handleTreeEvent(client: Socket, data: any): Promise<void> {
const room = getSpaceRoomName(data.spaceId);
const hasRestrictions = await this.spaceHasRestrictions(data.spaceId);
if (!hasRestrictions) {
client.broadcast.to(room).emit('message', data);
return;
}
const pageId = this.extractPageId(data);
if (!pageId) {
client.broadcast.to(room).emit('message', data);
return;
}
const isRestricted =
await this.pagePermissionRepo.hasRestrictedAncestor(pageId);
if (!isRestricted) {
client.broadcast.to(room).emit('message', data);
return;
}
await this.broadcastToAuthorizedUsers(client, room, pageId, data);
}
async invalidateSpaceRestrictionCache(spaceId: string): Promise<void> {
await this.cacheManager.del(
`${WS_SPACE_RESTRICTION_CACHE_PREFIX}${spaceId}`,
);
}
async emitToUsers(userIds: string[], data: any): Promise<void> {
if (userIds.length === 0) return;
const rooms = userIds.map((id) => getUserRoomName(id));
this.server.to(rooms).emit('message', data);
}
async emitToSpaceExceptUsers(
spaceId: string,
excludeUserIds: string[],
data: any,
): Promise<void> {
const room = getSpaceRoomName(spaceId);
const sockets = await this.server.in(room).fetchSockets();
const excludeSet = new Set(excludeUserIds);
for (const socket of sockets) {
const userId = socket.data.userId as string;
if (userId && !excludeSet.has(userId)) {
socket.emit('message', data);
}
}
}
isTreeEvent(data: any): boolean {
return TREE_EVENTS.has(data?.operation) && !!data?.spaceId;
}
private async broadcastToAuthorizedUsers(
sender: Socket,
room: string,
pageId: string,
data: any,
): Promise<void> {
const sockets = await this.server.in(room).fetchSockets();
const otherSockets = sockets.filter((s) => s.id !== sender.id);
if (otherSockets.length === 0) return;
const userSocketMap = new Map<string, typeof otherSockets>();
for (const socket of otherSockets) {
const userId = socket.data.userId as string;
if (!userId) continue;
const existing = userSocketMap.get(userId);
if (existing) {
existing.push(socket);
} else {
userSocketMap.set(userId, [socket]);
}
}
const candidateUserIds = Array.from(userSocketMap.keys());
if (candidateUserIds.length === 0) return;
const authorizedUserIds =
await this.pagePermissionRepo.getUserIdsWithPageAccess(
pageId,
candidateUserIds,
);
const authorizedSet = new Set(authorizedUserIds);
for (const [userId, userSockets] of userSocketMap) {
if (authorizedSet.has(userId)) {
for (const socket of userSockets) {
socket.emit('message', data);
}
}
}
}
private async spaceHasRestrictions(spaceId: string): Promise<boolean> {
const cacheKey = `${WS_SPACE_RESTRICTION_CACHE_PREFIX}${spaceId}`;
const cached = await this.cacheManager.get<boolean>(cacheKey);
if (cached !== undefined && cached !== null) {
return cached;
}
const hasRestrictions =
await this.pagePermissionRepo.hasRestrictedPagesInSpace(spaceId);
await this.cacheManager.set(cacheKey, hasRestrictions, WS_CACHE_TTL_MS);
return hasRestrictions;
}
private extractPageId(data: any): string | null {
switch (data.operation) {
case 'addTreeNode':
return data.payload?.data?.id ?? null;
case 'moveTreeNode':
return data.payload?.id ?? null;
case 'deleteTreeNode':
return data.payload?.node?.id ?? null;
case 'updateOne':
return data.id ?? null;
default:
return null;
}
}
}
+17
View File
@@ -0,0 +1,17 @@
export const WS_CACHE_TTL_MS = 30_000;
export const WS_SPACE_RESTRICTION_CACHE_PREFIX = 'ws:space-restrictions:';
export function getSpaceRoomName(spaceId: string): string {
return `space-${spaceId}`;
}
export function getUserRoomName(userId: string): string {
return `user-${userId}`;
}
export const TREE_EVENTS = new Set([
'updateOne',
'addTreeNode',
'moveTreeNode',
'deleteTreeNode',
]);