From 8f552a0e17c0eca71063e7a4353b9b351bcdf9fd Mon Sep 17 00:00:00 2001 From: DD Date: Fri, 2 Dec 2022 15:04:09 +0200 Subject: [PATCH] refactor(WebSocketShard): identify throttling (#8888) * refactor(WebSocketShard): identify throttling * chore: add worker handling * refactor: worker handling * chore: update tests * chore: use satisfies where applicable * chore: add informative comment * chore: apply suggestions * refactor(SimpleContextFetchingStrategy): support multiple managers --- .../strategy/WorkerShardingStrategy.test.ts | 14 ++++++-- .../context/IContextFetchingStrategy.ts | 1 + .../context/SimpleContextFetchingStrategy.ts | 26 +++++++++++++- .../context/WorkerContextFetchingStrategy.ts | 30 ++++++++++++---- .../sharding/SimpleShardingStrategy.ts | 5 --- .../sharding/WorkerShardingStrategy.ts | 28 ++++++++++----- packages/ws/src/strategies/sharding/worker.ts | 8 +++-- packages/ws/src/utils/IdentifyThrottler.ts | 36 ++++++++++++------- packages/ws/src/ws/WebSocketShard.ts | 3 ++ 9 files changed, 113 insertions(+), 38 deletions(-) diff --git a/packages/ws/__tests__/strategy/WorkerShardingStrategy.test.ts b/packages/ws/__tests__/strategy/WorkerShardingStrategy.test.ts index 79e402675..141dfa4b4 100644 --- a/packages/ws/__tests__/strategy/WorkerShardingStrategy.test.ts +++ b/packages/ws/__tests__/strategy/WorkerShardingStrategy.test.ts @@ -27,7 +27,7 @@ const mockConstructor = vi.fn(); const mockSend = vi.fn(); const mockTerminate = vi.fn(); -const memberChunkData: GatewayDispatchPayload = { +const memberChunkData = { op: GatewayOpcodes.Dispatch, s: 123, t: GatewayDispatchEvents.GuildMembersChunk, @@ -35,13 +35,14 @@ const memberChunkData: GatewayDispatchPayload = { guild_id: '123', members: [], }, -}; +} as unknown as GatewayDispatchPayload; const sessionInfo: SessionInfo = { shardId: 0, shardCount: 2, sequence: 123, sessionId: 'abc', + resumeURL: 'wss://ehehe.gg', }; vi.mock('node:worker_threads', async () => { @@ -109,6 +110,10 @@ vi.mock('node:worker_threads', async () => { this.emit('message', session); break; } + + case WorkerSendPayloadOp.ShardCanIdentify: { + break; + } } } @@ -181,7 +186,10 @@ test('spawn, connect, send a message, session info, and destroy', async () => { expect.objectContaining({ workerData: expect.objectContaining({ shardIds: [0, 1] }) }), ); - const payload: GatewaySendPayload = { op: GatewayOpcodes.RequestGuildMembers, d: { guild_id: '123', limit: 0 } }; + const payload = { + op: GatewayOpcodes.RequestGuildMembers, + d: { guild_id: '123', limit: 0, query: '' }, + } satisfies GatewaySendPayload; await manager.send(0, payload); expect(mockSend).toHaveBeenCalledWith(0, payload); expect(managerEmitSpy).toHaveBeenCalledWith(WebSocketShardEvents.Dispatch, { diff --git a/packages/ws/src/strategies/context/IContextFetchingStrategy.ts b/packages/ws/src/strategies/context/IContextFetchingStrategy.ts index 7a51303cf..e9c0ad64e 100644 --- a/packages/ws/src/strategies/context/IContextFetchingStrategy.ts +++ b/packages/ws/src/strategies/context/IContextFetchingStrategy.ts @@ -18,6 +18,7 @@ export interface IContextFetchingStrategy { readonly options: FetchingStrategyOptions; retrieveSessionInfo(shardId: number): Awaitable; updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null): Awaitable; + waitForIdentify(): Promise; } export async function managerToFetchingStrategyOptions(manager: WebSocketManager): Promise { diff --git a/packages/ws/src/strategies/context/SimpleContextFetchingStrategy.ts b/packages/ws/src/strategies/context/SimpleContextFetchingStrategy.ts index 9a942dd4b..4865f68e1 100644 --- a/packages/ws/src/strategies/context/SimpleContextFetchingStrategy.ts +++ b/packages/ws/src/strategies/context/SimpleContextFetchingStrategy.ts @@ -1,8 +1,28 @@ +import { IdentifyThrottler } from '../../utils/IdentifyThrottler.js'; import type { SessionInfo, WebSocketManager } from '../../ws/WebSocketManager.js'; import type { FetchingStrategyOptions, IContextFetchingStrategy } from './IContextFetchingStrategy.js'; export class SimpleContextFetchingStrategy implements IContextFetchingStrategy { - public constructor(private readonly manager: WebSocketManager, public readonly options: FetchingStrategyOptions) {} + // This strategy assumes every shard is running under the same process - therefore we need a single + // IdentifyThrottler per manager. + private static throttlerCache = new WeakMap(); + + private static ensureThrottler(manager: WebSocketManager): IdentifyThrottler { + const existing = SimpleContextFetchingStrategy.throttlerCache.get(manager); + if (existing) { + return existing; + } + + const throttler = new IdentifyThrottler(manager); + SimpleContextFetchingStrategy.throttlerCache.set(manager, throttler); + return throttler; + } + + private readonly throttler: IdentifyThrottler; + + public constructor(private readonly manager: WebSocketManager, public readonly options: FetchingStrategyOptions) { + this.throttler = SimpleContextFetchingStrategy.ensureThrottler(manager); + } public async retrieveSessionInfo(shardId: number): Promise { return this.manager.options.retrieveSessionInfo(shardId); @@ -11,4 +31,8 @@ export class SimpleContextFetchingStrategy implements IContextFetchingStrategy { public updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null) { return this.manager.options.updateSessionInfo(shardId, sessionInfo); } + + public async waitForIdentify(): Promise { + await this.throttler.waitForIdentify(); + } } diff --git a/packages/ws/src/strategies/context/WorkerContextFetchingStrategy.ts b/packages/ws/src/strategies/context/WorkerContextFetchingStrategy.ts index 8962c3140..79c2c17a4 100644 --- a/packages/ws/src/strategies/context/WorkerContextFetchingStrategy.ts +++ b/packages/ws/src/strategies/context/WorkerContextFetchingStrategy.ts @@ -12,6 +12,8 @@ import type { FetchingStrategyOptions, IContextFetchingStrategy } from './IConte export class WorkerContextFetchingStrategy implements IContextFetchingStrategy { private readonly sessionPromises = new Collection void>(); + private readonly waitForIdentifyPromises = new Collection void>(); + public constructor(public readonly options: FetchingStrategyOptions) { if (isMainThread) { throw new Error('Cannot instantiate WorkerContextFetchingStrategy on the main thread'); @@ -19,20 +21,24 @@ export class WorkerContextFetchingStrategy implements IContextFetchingStrategy { parentPort!.on('message', (payload: WorkerSendPayload) => { if (payload.op === WorkerSendPayloadOp.SessionInfoResponse) { - const resolve = this.sessionPromises.get(payload.nonce); - resolve?.(payload.session); + this.sessionPromises.get(payload.nonce)?.(payload.session); this.sessionPromises.delete(payload.nonce); } + + if (payload.op === WorkerSendPayloadOp.ShardCanIdentify) { + this.waitForIdentifyPromises.get(payload.nonce)?.(); + this.waitForIdentifyPromises.delete(payload.nonce); + } }); } public async retrieveSessionInfo(shardId: number): Promise { const nonce = Math.random(); - const payload: WorkerRecievePayload = { + const payload = { op: WorkerRecievePayloadOp.RetrieveSessionInfo, shardId, nonce, - }; + } satisfies WorkerRecievePayload; // eslint-disable-next-line no-promise-executor-return const promise = new Promise((resolve) => this.sessionPromises.set(nonce, resolve)); parentPort!.postMessage(payload); @@ -40,11 +46,23 @@ export class WorkerContextFetchingStrategy implements IContextFetchingStrategy { } public updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null) { - const payload: WorkerRecievePayload = { + const payload = { op: WorkerRecievePayloadOp.UpdateSessionInfo, shardId, session: sessionInfo, - }; + } satisfies WorkerRecievePayload; parentPort!.postMessage(payload); } + + public async waitForIdentify(): Promise { + const nonce = Math.random(); + const payload = { + op: WorkerRecievePayloadOp.WaitForIdentify, + nonce, + } satisfies WorkerRecievePayload; + // eslint-disable-next-line no-promise-executor-return + const promise = new Promise((resolve) => this.waitForIdentifyPromises.set(nonce, resolve)); + parentPort!.postMessage(payload); + return promise; + } } diff --git a/packages/ws/src/strategies/sharding/SimpleShardingStrategy.ts b/packages/ws/src/strategies/sharding/SimpleShardingStrategy.ts index 01276c175..d2592af69 100644 --- a/packages/ws/src/strategies/sharding/SimpleShardingStrategy.ts +++ b/packages/ws/src/strategies/sharding/SimpleShardingStrategy.ts @@ -1,6 +1,5 @@ import { Collection } from '@discordjs/collection'; import type { GatewaySendPayload } from 'discord-api-types/v10'; -import { IdentifyThrottler } from '../../utils/IdentifyThrottler.js'; import type { WebSocketManager } from '../../ws/WebSocketManager'; import { WebSocketShard, WebSocketShardEvents, type WebSocketShardDestroyOptions } from '../../ws/WebSocketShard.js'; import { managerToFetchingStrategyOptions } from '../context/IContextFetchingStrategy.js'; @@ -15,11 +14,8 @@ export class SimpleShardingStrategy implements IShardingStrategy { private readonly shards = new Collection(); - private readonly throttler: IdentifyThrottler; - public constructor(manager: WebSocketManager) { this.manager = manager; - this.throttler = new IdentifyThrottler(manager); } /** @@ -46,7 +42,6 @@ export class SimpleShardingStrategy implements IShardingStrategy { const promises = []; for (const shard of this.shards.values()) { - await this.throttler.waitForIdentify(); promises.push(shard.connect()); } diff --git a/packages/ws/src/strategies/sharding/WorkerShardingStrategy.ts b/packages/ws/src/strategies/sharding/WorkerShardingStrategy.ts index 3ff4b3802..c2f0df500 100644 --- a/packages/ws/src/strategies/sharding/WorkerShardingStrategy.ts +++ b/packages/ws/src/strategies/sharding/WorkerShardingStrategy.ts @@ -18,10 +18,12 @@ export enum WorkerSendPayloadOp { Destroy, Send, SessionInfoResponse, + ShardCanIdentify, } export type WorkerSendPayload = | { nonce: number; op: WorkerSendPayloadOp.SessionInfoResponse; session: SessionInfo | null } + | { nonce: number; op: WorkerSendPayloadOp.ShardCanIdentify } | { op: WorkerSendPayloadOp.Connect; shardId: number } | { op: WorkerSendPayloadOp.Destroy; options?: WebSocketShardDestroyOptions; shardId: number } | { op: WorkerSendPayloadOp.Send; payload: GatewaySendPayload; shardId: number }; @@ -32,12 +34,14 @@ export enum WorkerRecievePayloadOp { Event, RetrieveSessionInfo, UpdateSessionInfo, + WaitForIdentify, } export type WorkerRecievePayload = // Can't seem to get a type-safe union based off of the event, so I'm sadly leaving data as any for now | { data: any; event: WebSocketShardEvents; op: WorkerRecievePayloadOp.Event; shardId: number } | { nonce: number; op: WorkerRecievePayloadOp.RetrieveSessionInfo; shardId: number } + | { nonce: number; op: WorkerRecievePayloadOp.WaitForIdentify } | { op: WorkerRecievePayloadOp.Connected; shardId: number } | { op: WorkerRecievePayloadOp.Destroyed; shardId: number } | { op: WorkerRecievePayloadOp.UpdateSessionInfo; session: SessionInfo | null; shardId: number }; @@ -118,12 +122,10 @@ export class WorkerShardingStrategy implements IShardingStrategy { const promises = []; for (const [shardId, worker] of this.#workerByShardId.entries()) { - await this.throttler.waitForIdentify(); - - const payload: WorkerSendPayload = { + const payload = { op: WorkerSendPayloadOp.Connect, shardId, - }; + } satisfies WorkerSendPayload; // eslint-disable-next-line no-promise-executor-return const promise = new Promise((resolve) => this.connectPromises.set(shardId, resolve)); @@ -141,11 +143,11 @@ export class WorkerShardingStrategy implements IShardingStrategy { const promises = []; for (const [shardId, worker] of this.#workerByShardId.entries()) { - const payload: WorkerSendPayload = { + const payload = { op: WorkerSendPayloadOp.Destroy, shardId, options, - }; + } satisfies WorkerSendPayload; promises.push( // eslint-disable-next-line no-promise-executor-return, promise/prefer-await-to-then @@ -169,11 +171,11 @@ export class WorkerShardingStrategy implements IShardingStrategy { throw new Error(`No worker found for shard ${shardId}`); } - const payload: WorkerSendPayload = { + const payload = { op: WorkerSendPayloadOp.Send, shardId, payload: data, - }; + } satisfies WorkerSendPayload; worker.postMessage(payload); } @@ -213,6 +215,16 @@ export class WorkerShardingStrategy implements IShardingStrategy { await this.manager.options.updateSessionInfo(payload.shardId, payload.session); break; } + + case WorkerRecievePayloadOp.WaitForIdentify: { + await this.throttler.waitForIdentify(); + const response: WorkerSendPayload = { + op: WorkerSendPayloadOp.ShardCanIdentify, + nonce: payload.nonce, + }; + worker.postMessage(response); + break; + } } } } diff --git a/packages/ws/src/strategies/sharding/worker.ts b/packages/ws/src/strategies/sharding/worker.ts index fc489aa70..4bc3347b5 100644 --- a/packages/ws/src/strategies/sharding/worker.ts +++ b/packages/ws/src/strategies/sharding/worker.ts @@ -40,12 +40,12 @@ for (const shardId of data.shardIds) { for (const event of Object.values(WebSocketShardEvents)) { // @ts-expect-error: Event types incompatible shard.on(event, (data) => { - const payload: WorkerRecievePayload = { + const payload = { op: WorkerRecievePayloadOp.Event, event, data, shardId, - }; + } satisfies WorkerRecievePayload; parentPort!.postMessage(payload); }); } @@ -93,5 +93,9 @@ parentPort! case WorkerSendPayloadOp.SessionInfoResponse: { break; } + + case WorkerSendPayloadOp.ShardCanIdentify: { + break; + } } }); diff --git a/packages/ws/src/utils/IdentifyThrottler.ts b/packages/ws/src/utils/IdentifyThrottler.ts index 90faa3fac..45c35c5e7 100644 --- a/packages/ws/src/utils/IdentifyThrottler.ts +++ b/packages/ws/src/utils/IdentifyThrottler.ts @@ -1,7 +1,10 @@ import { setTimeout as sleep } from 'node:timers/promises'; -import type { WebSocketManager } from '../ws/WebSocketManager'; +import { AsyncQueue } from '@sapphire/async-queue'; +import type { WebSocketManager } from '../ws/WebSocketManager.js'; export class IdentifyThrottler { + private readonly queue = new AsyncQueue(); + private identifyState = { remaining: 0, resetsAt: Number.POSITIVE_INFINITY, @@ -10,20 +13,27 @@ export class IdentifyThrottler { public constructor(private readonly manager: WebSocketManager) {} public async waitForIdentify(): Promise { - if (this.identifyState.remaining <= 0) { - const diff = this.identifyState.resetsAt - Date.now(); - if (diff <= 5_000) { - const time = diff + Math.random() * 1_500; - await sleep(time); + await this.queue.wait(); + + try { + if (this.identifyState.remaining <= 0) { + const diff = this.identifyState.resetsAt - Date.now(); + if (diff <= 5_000) { + // To account for the latency the IDENTIFY payload goes through, we add a bit more wait time + const time = diff + Math.random() * 1_500; + await sleep(time); + } + + const info = await this.manager.fetchGatewayInformation(); + this.identifyState = { + remaining: info.session_start_limit.max_concurrency, + resetsAt: Date.now() + 5_000, + }; } - const info = await this.manager.fetchGatewayInformation(); - this.identifyState = { - remaining: info.session_start_limit.max_concurrency, - resetsAt: Date.now() + 5_000, - }; + this.identifyState.remaining--; + } finally { + this.queue.shift(); } - - this.identifyState.remaining--; } } diff --git a/packages/ws/src/ws/WebSocketShard.ts b/packages/ws/src/ws/WebSocketShard.ts index 04624381f..21f072fff 100644 --- a/packages/ws/src/ws/WebSocketShard.ts +++ b/packages/ws/src/ws/WebSocketShard.ts @@ -295,6 +295,9 @@ export class WebSocketShard extends AsyncEventEmitter { `intents: ${this.strategy.options.intents}`, `compression: ${this.inflate ? 'zlib-stream' : this.useIdentifyCompress ? 'identify' : 'none'}`, ]); + + await this.strategy.waitForIdentify(); + const d: GatewayIdentifyData = { token: this.strategy.options.token, properties: this.strategy.options.identifyProperties,