sync with latest

This commit is contained in:
Philipinho
2026-01-26 23:39:09 +00:00
parent 38d0556ac3
commit 4d5e23cad2
3 changed files with 60 additions and 78 deletions
@@ -9,7 +9,7 @@ import type {
export class CollabProxySocket extends EventEmitter { export class CollabProxySocket extends EventEmitter {
private readonly replyTo: string; private readonly replyTo: string;
private readonly pongChannel: string; private readonly serverChannel: string;
private readonly socketId: string; private readonly socketId: string;
private pub: RedisClient; private pub: RedisClient;
private readonly pack: Pack; private readonly pack: Pack;
@@ -19,13 +19,13 @@ export class CollabProxySocket extends EventEmitter {
pub: RedisClient, pub: RedisClient,
pack: Pack, pack: Pack,
replyTo: string, replyTo: string,
pongChannel: string, serverChannel: string,
socketId: string, socketId: string,
) { ) {
super(); super();
this.replyTo = replyTo; this.replyTo = replyTo;
this.pongChannel = pongChannel;
this.socketId = socketId; this.socketId = socketId;
this.serverChannel = serverChannel;
this.pub = pub; this.pub = pub;
this.pack = pack; this.pack = pack;
this.once('close', () => { this.once('close', () => {
@@ -53,7 +53,7 @@ export class CollabProxySocket extends EventEmitter {
const msg: RSAMessagePing = { const msg: RSAMessagePing = {
type: 'ping', type: 'ping',
socketId: this.socketId, socketId: this.socketId,
respondTo: this.pongChannel, replyTo: this.serverChannel,
}; };
this.publish(msg); this.publish(msg);
} }
@@ -1,18 +1,15 @@
// Source https://github.com/ueberdosis/hocuspocus/pull/1008 - MIT // Source https://github.com/ueberdosis/hocuspocus/pull/1008 - MIT
import { IncomingMessage } from 'node:http';
import { import {
Extension, Extension,
Hocuspocus, Hocuspocus,
IncomingMessage as SocketIncomingMessage, IncomingMessage,
afterUnloadDocumentPayload, afterUnloadDocumentPayload,
onConfigurePayload, onConfigurePayload,
onLoadDocumentPayload, onLoadDocumentPayload,
} from '@hocuspocus/server'; } from '@hocuspocus/server';
import RedisClient from 'ioredis'; import RedisClient from 'ioredis';
import { readVarString } from 'lib0/decoding.js'; import { readVarString } from 'lib0/decoding.js';
import { WebSocket } from 'ws';
import { CollabProxySocket } from './collab-proxy-socket'; import { CollabProxySocket } from './collab-proxy-socket';
import { Logger } from '@nestjs/common';
import { import {
BaseWebSocket, BaseWebSocket,
Configuration, Configuration,
@@ -36,7 +33,6 @@ type DocumentName = string;
type SocketId = string; type SocketId = string;
export class RedisSyncExtension<TCE extends CustomEvents> implements Extension { export class RedisSyncExtension<TCE extends CustomEvents> implements Extension {
private readonly logger = new Logger('Collab' + RedisSyncExtension.name);
priority = 1000; priority = 1000;
private readonly pub: RedisClient; private readonly pub: RedisClient;
private sub: RedisClient; private sub: RedisClient;
@@ -45,25 +41,20 @@ export class RedisSyncExtension<TCE extends CustomEvents> implements Extension {
private originSockets: Record<SocketId, BaseWebSocket> = {}; private originSockets: Record<SocketId, BaseWebSocket> = {};
private locks: Record<DocumentName, NodeJS.Timeout> = {}; private locks: Record<DocumentName, NodeJS.Timeout> = {};
private lockPromises: Record<DocumentName, Promise<ServerId | null>> = {}; private lockPromises: Record<DocumentName, Promise<ServerId | null>> = {};
private proxySockets: Record< private proxySockets: Record<SocketId, CollabProxySocket> = {};
SocketId,
{ socket: CollabProxySocket; cleanup: NodeJS.Timeout }
> = {};
private readonly prefix: string; private readonly prefix: string;
private readonly lockPrefix: string; private readonly lockPrefix: string;
private readonly msgChannel: string; private readonly msgChannel: string;
private readonly serverId: ServerId; private readonly serverId: ServerId;
private readonly customEventTTL: number; private readonly customEventTTL: number;
private readonly lockTTL: number; private readonly lockTTL: number;
private readonly proxySocketTTL: number;
private instance!: Hocuspocus; private instance!: Hocuspocus;
private readonly customEvents: TCE; private readonly customEvents: TCE;
private replyIdCounter = 0; private replyIdCounter: number = 0;
private pendingReplies: Record< // @ts-ignore
number, private pendingReplies: Record<number, PromiseWithResolvers<any>['resolve']> =
// @ts-ignore {};
PromiseWithResolvers<unknown>['resolve']
> = {};
constructor(configuration: Configuration<TCE>) { constructor(configuration: Configuration<TCE>) {
const { const {
redis, redis,
@@ -72,7 +63,6 @@ export class RedisSyncExtension<TCE extends CustomEvents> implements Extension {
serverId, serverId,
lockTTL, lockTTL,
prefix, prefix,
proxySocketTTL,
customEvents, customEvents,
customEventTTL, customEventTTL,
} = configuration; } = configuration;
@@ -82,12 +72,11 @@ export class RedisSyncExtension<TCE extends CustomEvents> implements Extension {
this.unpack = unpack; this.unpack = unpack;
this.serverId = serverId; this.serverId = serverId;
this.lockTTL = lockTTL ?? 10_000; this.lockTTL = lockTTL ?? 10_000;
this.proxySocketTTL = proxySocketTTL ?? 30_000;
this.customEventTTL = customEventTTL ?? 30_000; this.customEventTTL = customEventTTL ?? 30_000;
this.prefix = prefix ?? 'collab'; this.prefix = prefix ?? 'collab';
this.lockPrefix = `${this.prefix}Lock`; this.lockPrefix = `${this.prefix}Lock`;
this.msgChannel = `${this.prefix}Msg`; this.msgChannel = `${this.prefix}Msg`;
this.customEvents = (customEvents ?? {}) as unknown as TCE; this.customEvents = (customEvents as any) ?? ({} as any as CustomEvents);
this.sub.subscribe(this.msgChannel, `${this.msgChannel}:${this.serverId}`); this.sub.subscribe(this.msgChannel, `${this.msgChannel}:${this.serverId}`);
this.sub.on('messageBuffer', this.handleRedisMessage); this.sub.on('messageBuffer', this.handleRedisMessage);
} }
@@ -96,18 +85,19 @@ export class RedisSyncExtension<TCE extends CustomEvents> implements Extension {
} }
private closeProxy(socketId: string) { private closeProxy(socketId: string) {
const socketRecord = this.proxySockets[socketId]; const proxySocket = this.proxySockets[socketId];
if (!socketRecord) return; if (proxySocket) {
clearTimeout(socketRecord.cleanup); proxySocket.emit(
socketRecord.socket.emit('close', 1000, 'proxy_cleanup'); 'close',
delete this.proxySockets[socketId]; 1000,
Buffer.from('provider_initiated', 'utf-8'),
);
delete this.proxySockets[socketId];
}
} }
private emitPong(socketId: string) { private pongProxy(socketId: string) {
const socketRecord = this.proxySockets[socketId]; this.proxySockets[socketId]?.emit('pong');
if (socketRecord) {
socketRecord.socket.emit('pong');
}
} }
private handleProxyMessage( private handleProxyMessage(
@@ -115,35 +105,24 @@ export class RedisSyncExtension<TCE extends CustomEvents> implements Extension {
) { ) {
const { replyTo, message, serializedHTTPRequest } = msg; const { replyTo, message, serializedHTTPRequest } = msg;
const { headers } = serializedHTTPRequest; const { headers } = serializedHTTPRequest;
const socketId = headers['sec-websocket-key']; const socketId = headers['sec-websocket-key']!;
let socketRecord = this.proxySockets[socketId]; let socket = this.proxySockets[socketId];
const cleanup = setTimeout(() => { if (!socket) {
const record = this.proxySockets[socketId]; socket = new CollabProxySocket(
if (record) {
record.socket.emit('close', 1000, 'ttl_expired');
delete this.proxySockets[socketId];
}
}, this.proxySocketTTL);
if (!socketRecord) {
const socket = new CollabProxySocket(
this.pub, this.pub,
this.pack, this.pack,
replyTo, replyTo,
`${this.msgChannel}:${this.serverId}`, `${this.msgChannel}:${this.serverId}`,
socketId, socketId,
); );
socketRecord = { socket, cleanup }; this.proxySockets[socketId] = socket;
this.proxySockets[socketId] = socketRecord;
this.instance.handleConnection( this.instance.handleConnection(
socket as unknown as WebSocket, socket as any,
serializedHTTPRequest as unknown as IncomingMessage, serializedHTTPRequest as any,
{}, {},
); );
} else {
clearTimeout(socketRecord.cleanup);
socketRecord.cleanup = cleanup;
} }
socketRecord.socket.emit('message', message); socket.emit('message', message);
} }
private getOrClaimLock(documentName: string) { private getOrClaimLock(documentName: string) {
@@ -185,6 +164,10 @@ export class RedisSyncExtension<TCE extends CustomEvents> implements Extension {
this.closeProxy(msg.socketId); this.closeProxy(msg.socketId);
return; return;
} }
if (type === 'pong') {
this.pongProxy(msg.socketId);
return;
}
if (type === 'unload') { if (type === 'unload') {
delete this.lockPromises[msg.documentName]; delete this.lockPromises[msg.documentName];
return; return;
@@ -201,7 +184,7 @@ export class RedisSyncExtension<TCE extends CustomEvents> implements Extension {
replyId, replyId,
payload: res, payload: res,
}; };
this.pub.publish(`${replyTo}`, this.pack(reply)).catch(() => {}); this.pub.publish(`${replyTo}`, this.pack(reply));
return; return;
} }
if (type === 'customEventComplete') { if (type === 'customEventComplete') {
@@ -212,10 +195,6 @@ export class RedisSyncExtension<TCE extends CustomEvents> implements Extension {
resolveFn(payload); resolveFn(payload);
return; return;
} }
if (type === 'pong') {
this.emitPong(msg.socketId);
return;
}
const { socketId } = msg; const { socketId } = msg;
const socket = this.originSockets[socketId]; const socket = this.originSockets[socketId];
if (!socket) { if (!socket) {
@@ -225,9 +204,14 @@ export class RedisSyncExtension<TCE extends CustomEvents> implements Extension {
if (type === 'close') { if (type === 'close') {
socket.close(msg.code, msg.reason); socket.close(msg.code, msg.reason);
} else if (type === 'ping') { } else if (type === 'ping') {
const { respondTo } = msg; // Reply instantly to the proxy socket, without forwarding to client
const pong: RSAMessagePong = { type: 'pong', socketId }; // The origin socket handles heartbeat for itself
this.pub.publish(respondTo, this.pack(pong)).catch(() => {}); const { replyTo, socketId } = msg;
const reply: RSAMessagePong = {
type: 'pong',
socketId,
};
this.pub.publish(`${replyTo}`, this.pack(reply));
} else if (type === 'send') { } else if (type === 'send') {
socket.send(msg.message); socket.send(msg.message);
} }
@@ -253,7 +237,7 @@ export class RedisSyncExtension<TCE extends CustomEvents> implements Extension {
private async handleEventLocally<TName extends Extract<keyof TCE, string>>( private async handleEventLocally<TName extends Extract<keyof TCE, string>>(
eventName: TName, eventName: TName,
documentName: string, documentName: string,
payload: unknown, payload: any,
) { ) {
const handler = this.customEvents[eventName]; const handler = this.customEvents[eventName];
if (!handler) throw new Error(`Invalid eventName: ${eventName}`); if (!handler) throw new Error(`Invalid eventName: ${eventName}`);
@@ -264,7 +248,7 @@ export class RedisSyncExtension<TCE extends CustomEvents> implements Extension {
async handleEvent<TName extends Extract<keyof TCE, string>>( async handleEvent<TName extends Extract<keyof TCE, string>>(
eventName: TName, eventName: TName,
documentName: string, documentName: string,
payload: unknown, payload: any,
) { ) {
const isDocLoadedOnInstance = this.instance.documents.has(documentName); const isDocLoadedOnInstance = this.instance.documents.has(documentName);
@@ -286,17 +270,13 @@ export class RedisSyncExtension<TCE extends CustomEvents> implements Extension {
type: 'customEventStart', type: 'customEventStart',
}; };
const msg = this.pack(proxyMessage); const msg = this.pack(proxyMessage);
this.pub.publish(`${this.msgChannel}:${proxyTo}`, msg).catch(() => {}); this.pub.publish(`${this.msgChannel}:${proxyTo}`, msg);
// @ts-ignore // @ts-ignore
const { promise, resolve, reject } = Promise.withResolvers(); const { promise, resolve, reject } = Promise.withResolvers();
const timeoutId = setTimeout(() => { this.pendingReplies[replyId] = resolve;
delete this.pendingReplies[replyId]; setTimeout(() => {
reject('TIMEOUT'); reject('TIMEOUT');
}, this.customEventTTL); }, this.customEventTTL);
this.pendingReplies[replyId] = (result: unknown) => {
clearTimeout(timeoutId);
resolve(result);
};
return promise as Promise<ReturnType<TCE[TName]>>; return promise as Promise<ReturnType<TCE[TName]>>;
} }
// This server owns the document, but hocuspocus hasn't loaded it yet // This server owns the document, but hocuspocus hasn't loaded it yet
@@ -318,11 +298,11 @@ export class RedisSyncExtension<TCE extends CustomEvents> implements Extension {
serializedHTTPRequest: SerializedHTTPRequest, serializedHTTPRequest: SerializedHTTPRequest,
context = {}, context = {},
) { ) {
const socketId = serializedHTTPRequest.headers['sec-websocket-key']; const socketId = serializedHTTPRequest.headers['sec-websocket-key']!;
this.originSockets[socketId] = ws; this.originSockets[socketId] = ws;
this.instance.handleConnection( this.instance.handleConnection(
ws as unknown as WebSocket, ws as any,
serializedHTTPRequest as unknown as IncomingMessage, serializedHTTPRequest as any,
context, context,
); );
} }
@@ -332,9 +312,8 @@ export class RedisSyncExtension<TCE extends CustomEvents> implements Extension {
serializedHTTPRequest: SerializedHTTPRequest, serializedHTTPRequest: SerializedHTTPRequest,
detachableMsg: ArrayBuffer, detachableMsg: ArrayBuffer,
) { ) {
// @ts-ignore
const message = new Uint8Array(detachableMsg.slice()); const message = new Uint8Array(detachableMsg.slice());
const tmpMsg = new SocketIncomingMessage(detachableMsg); const tmpMsg = new IncomingMessage(detachableMsg);
const documentName = readVarString(tmpMsg.decoder); const documentName = readVarString(tmpMsg.decoder);
const isDocLoadedOnInstance = this.instance.documents.has(documentName); const isDocLoadedOnInstance = this.instance.documents.has(documentName);
@@ -353,7 +332,7 @@ export class RedisSyncExtension<TCE extends CustomEvents> implements Extension {
type: 'proxy', type: 'proxy',
}; };
const msg = this.pack(proxyMessage); const msg = this.pack(proxyMessage);
this.pub.publish(`${this.msgChannel}:${proxyTo}`, msg).catch(() => {}); this.pub.publish(`${this.msgChannel}:${proxyTo}`, msg);
return; return;
} }
// This server owns the document, but hocuspocus hasn't loaded it yet // This server owns the document, but hocuspocus hasn't loaded it yet
@@ -363,8 +342,10 @@ export class RedisSyncExtension<TCE extends CustomEvents> implements Extension {
onSocketClose(socketId: string, code?: number, reason?: ArrayBuffer) { onSocketClose(socketId: string, code?: number, reason?: ArrayBuffer) {
const socket = this.originSockets[socketId]; const socket = this.originSockets[socketId];
if (!socket) return; if (!socket) return;
// at this point the socket is considered GC'd and we cannot call close
// The origin socket did not set up any connections for the proxy, so none of the hooks will work if we just emit
socket?.emit('close', code, reason);
delete this.originSockets[socketId]; delete this.originSockets[socketId];
socket.emit('close', code, reason);
const msg: RSAMessageCloseProxy = { type: 'closeProxy', socketId }; const msg: RSAMessageCloseProxy = { type: 'closeProxy', socketId };
this.pub.publish(this.msgChannel, this.pack(msg)).catch(() => {}); this.pub.publish(this.msgChannel, this.pack(msg)).catch(() => {});
} }
@@ -385,8 +366,9 @@ export class RedisSyncExtension<TCE extends CustomEvents> implements Extension {
this.releaseLock(documentName); this.releaseLock(documentName);
// Broadcast to cluster to immediately remove the cached redis value // Broadcast to cluster to immediately remove the cached redis value
const msg: RSAMessageUnload = { type: 'unload', documentName }; const msg: RSAMessageUnload = { type: 'unload', documentName };
this.pub.publish(this.msgChannel, this.pack(msg)).catch(() => {}); this.pub.publish(this.msgChannel, this.pack(msg));
} }
async onDestroy() { async onDestroy() {
this.pub.disconnect(false); this.pub.disconnect(false);
this.sub.disconnect(false); this.sub.disconnect(false);
@@ -45,7 +45,7 @@ export type RSAMessageClose = {
export type RSAMessagePing = { export type RSAMessagePing = {
type: 'ping'; type: 'ping';
socketId: string; socketId: string;
respondTo: string; replyTo: string;
}; };
export type RSAMessagePong = { export type RSAMessagePong = {