diff --git a/apps/server/src/ws/services/excalidraw-collab.service.ts b/apps/server/src/ws/services/excalidraw-collab.service.ts new file mode 100644 index 00000000..12924296 --- /dev/null +++ b/apps/server/src/ws/services/excalidraw-collab.service.ts @@ -0,0 +1,91 @@ +import { Injectable } from '@nestjs/common'; +import { Server, Socket } from 'socket.io'; +import { ExcalidrawFollowPayload } from '../types/excalidraw.types'; + +@Injectable() +export class ExcalidrawCollabService { + async handleJoinRoom( + client: Socket, + server: Server, + roomId: string, + ): Promise { + await client.join(roomId); + + const sockets = await server.in(roomId).fetchSockets(); + + if (sockets.length <= 1) { + server.to(client.id).emit('first-in-room'); + } else { + client.broadcast.to(roomId).emit('new-user', client.id); + } + + server.in(roomId).emit( + 'room-user-change', + sockets.map((socket) => socket.id), + ); + } + + handleServerBroadcast( + client: Socket, + roomId: string, + encryptedData: ArrayBuffer, + iv: Uint8Array, + ): void { + client.broadcast.to(roomId).emit('client-broadcast', encryptedData, iv); + } + + handleServerVolatileBroadcast( + client: Socket, + roomId: string, + encryptedData: ArrayBuffer, + iv: Uint8Array, + ): void { + client.volatile.broadcast + .to(roomId) + .emit('client-broadcast', encryptedData, iv); + } + + async handleUserFollow( + client: Socket, + server: Server, + payload: ExcalidrawFollowPayload, + ): Promise { + const roomId = `follow@${payload.userToFollow.socketId}`; + + if (payload.action === 'FOLLOW') { + await client.join(roomId); + } else { + await client.leave(roomId); + } + + const sockets = await server.in(roomId).fetchSockets(); + const followedBy = sockets.map((socket) => socket.id); + + server.to(payload.userToFollow.socketId).emit( + 'user-follow-room-change', + followedBy, + ); + } + + async handleDisconnecting(client: Socket, server: Server): Promise { + for (const roomId of Array.from(client.rooms)) { + const otherClients = (await server.in(roomId).fetchSockets()).filter( + (socket) => socket.id !== client.id, + ); + + const isFollowRoom = roomId.startsWith('follow@'); + + if (!isFollowRoom && otherClients.length > 0) { + client.broadcast.to(roomId).emit( + 'room-user-change', + otherClients.map((socket) => socket.id), + ); + } + + if (isFollowRoom && otherClients.length === 0) { + const socketId = roomId.replace('follow@', ''); + server.to(socketId).emit('broadcast-unfollow'); + } + } + } +} diff --git a/apps/server/src/ws/types/excalidraw.types.ts b/apps/server/src/ws/types/excalidraw.types.ts new file mode 100644 index 00000000..75a6d5bf --- /dev/null +++ b/apps/server/src/ws/types/excalidraw.types.ts @@ -0,0 +1,9 @@ +export type ExcalidrawUserToFollow = { + socketId: string; + username: string; +}; + +export type ExcalidrawFollowPayload = { + userToFollow: ExcalidrawUserToFollow; + action: 'FOLLOW' | 'UNFOLLOW'; +}; diff --git a/apps/server/src/ws/ws.gateway.ts b/apps/server/src/ws/ws.gateway.ts index eeaec897..d08d8cee 100644 --- a/apps/server/src/ws/ws.gateway.ts +++ b/apps/server/src/ws/ws.gateway.ts @@ -1,6 +1,7 @@ import { MessageBody, OnGatewayConnection, + OnGatewayDisconnect, SubscribeMessage, WebSocketGateway, WebSocketServer, @@ -11,17 +12,23 @@ import { JwtPayload, JwtType } from '../core/auth/dto/jwt-payload'; import { OnModuleDestroy } from '@nestjs/common'; import { SpaceMemberRepo } from '@docmost/db/repos/space/space-member.repo'; import * as cookie from 'cookie'; +import { ExcalidrawCollabService } from './services/excalidraw-collab.service'; +import { ExcalidrawFollowPayload } from './types/excalidraw.types'; @WebSocketGateway({ cors: { origin: '*' }, transports: ['websocket'], }) -export class WsGateway implements OnGatewayConnection, OnModuleDestroy { +export class WsGateway + implements OnGatewayConnection, OnGatewayDisconnect, OnModuleDestroy +{ @WebSocketServer() server: Server; + constructor( private tokenService: TokenService, private spaceMemberRepo: SpaceMemberRepo, + private excalidrawCollabService: ExcalidrawCollabService, ) {} async handleConnection(client: Socket, ...args: any[]): Promise { @@ -41,6 +48,8 @@ export class WsGateway implements OnGatewayConnection, OnModuleDestroy { const spaceRooms = userSpaceIds.map((id) => this.getSpaceRoomName(id)); client.join([workspaceRoom, ...spaceRooms]); + + this.server.to(client.id).emit('init-room'); } catch (err) { client.emit('Unauthorized'); client.disconnect(); @@ -66,9 +75,15 @@ export class WsGateway implements OnGatewayConnection, OnModuleDestroy { } @SubscribeMessage('join-room') - handleJoinRoom(client: Socket, @MessageBody() roomName: string): void { - // if room is a space, check if user has permissions - //client.join(roomName); + async handleJoinRoom( + client: Socket, + @MessageBody() roomId: string, + ): Promise { + await this.excalidrawCollabService.handleJoinRoom( + client, + this.server, + roomId, + ); } @SubscribeMessage('leave-room') @@ -76,6 +91,48 @@ export class WsGateway implements OnGatewayConnection, OnModuleDestroy { client.leave(roomName); } + @SubscribeMessage('server-broadcast') + handleServerBroadcast( + client: Socket, + [roomId, encryptedData, iv]: [string, ArrayBuffer, Uint8Array], + ): void { + this.excalidrawCollabService.handleServerBroadcast( + client, + roomId, + encryptedData, + iv, + ); + } + + @SubscribeMessage('server-volatile-broadcast') + handleServerVolatileBroadcast( + client: Socket, + [roomId, encryptedData, iv]: [string, ArrayBuffer, Uint8Array], + ): void { + this.excalidrawCollabService.handleServerVolatileBroadcast( + client, + roomId, + encryptedData, + iv, + ); + } + + @SubscribeMessage('user-follow') + async handleUserFollow( + client: Socket, + @MessageBody() payload: ExcalidrawFollowPayload, + ): Promise { + await this.excalidrawCollabService.handleUserFollow( + client, + this.server, + payload, + ); + } + + async handleDisconnect(client: Socket): Promise { + await this.excalidrawCollabService.handleDisconnecting(client, this.server); + } + onModuleDestroy() { if (this.server) { this.server.close(); diff --git a/apps/server/src/ws/ws.module.ts b/apps/server/src/ws/ws.module.ts index aa2d9b7c..d57fc95f 100644 --- a/apps/server/src/ws/ws.module.ts +++ b/apps/server/src/ws/ws.module.ts @@ -1,9 +1,10 @@ import { Module } from '@nestjs/common'; import { WsGateway } from './ws.gateway'; import { TokenModule } from '../core/auth/token.module'; +import { ExcalidrawCollabService } from './services/excalidraw-collab.service'; @Module({ imports: [TokenModule], - providers: [WsGateway], + providers: [WsGateway, ExcalidrawCollabService], }) export class WsModule {}