refactor(WebSocketShard): identify throttling (#8888)

* refactor(WebSocketShard): identify throttling

* chore: add worker handling

* refactor: worker handling

* chore: update tests

* chore: use satisfies where applicable

* chore: add informative comment

* chore: apply suggestions

* refactor(SimpleContextFetchingStrategy): support multiple managers
This commit is contained in:
DD
2022-12-02 15:04:09 +02:00
committed by GitHub
parent 3fca638a84
commit 8f552a0e17
9 changed files with 113 additions and 38 deletions

View File

@@ -27,7 +27,7 @@ const mockConstructor = vi.fn();
const mockSend = vi.fn(); const mockSend = vi.fn();
const mockTerminate = vi.fn(); const mockTerminate = vi.fn();
const memberChunkData: GatewayDispatchPayload = { const memberChunkData = {
op: GatewayOpcodes.Dispatch, op: GatewayOpcodes.Dispatch,
s: 123, s: 123,
t: GatewayDispatchEvents.GuildMembersChunk, t: GatewayDispatchEvents.GuildMembersChunk,
@@ -35,13 +35,14 @@ const memberChunkData: GatewayDispatchPayload = {
guild_id: '123', guild_id: '123',
members: [], members: [],
}, },
}; } as unknown as GatewayDispatchPayload;
const sessionInfo: SessionInfo = { const sessionInfo: SessionInfo = {
shardId: 0, shardId: 0,
shardCount: 2, shardCount: 2,
sequence: 123, sequence: 123,
sessionId: 'abc', sessionId: 'abc',
resumeURL: 'wss://ehehe.gg',
}; };
vi.mock('node:worker_threads', async () => { vi.mock('node:worker_threads', async () => {
@@ -109,6 +110,10 @@ vi.mock('node:worker_threads', async () => {
this.emit('message', session); this.emit('message', session);
break; break;
} }
case WorkerSendPayloadOp.ShardCanIdentify: {
break;
}
} }
} }
@@ -181,7 +186,10 @@ test('spawn, connect, send a message, session info, and destroy', async () => {
expect.objectContaining({ workerData: expect.objectContaining({ shardIds: [0, 1] }) }), expect.objectContaining({ workerData: expect.objectContaining({ shardIds: [0, 1] }) }),
); );
const payload: GatewaySendPayload = { op: GatewayOpcodes.RequestGuildMembers, d: { guild_id: '123', limit: 0 } }; const payload = {
op: GatewayOpcodes.RequestGuildMembers,
d: { guild_id: '123', limit: 0, query: '' },
} satisfies GatewaySendPayload;
await manager.send(0, payload); await manager.send(0, payload);
expect(mockSend).toHaveBeenCalledWith(0, payload); expect(mockSend).toHaveBeenCalledWith(0, payload);
expect(managerEmitSpy).toHaveBeenCalledWith(WebSocketShardEvents.Dispatch, { expect(managerEmitSpy).toHaveBeenCalledWith(WebSocketShardEvents.Dispatch, {

View File

@@ -18,6 +18,7 @@ export interface IContextFetchingStrategy {
readonly options: FetchingStrategyOptions; readonly options: FetchingStrategyOptions;
retrieveSessionInfo(shardId: number): Awaitable<SessionInfo | null>; retrieveSessionInfo(shardId: number): Awaitable<SessionInfo | null>;
updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null): Awaitable<void>; updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null): Awaitable<void>;
waitForIdentify(): Promise<void>;
} }
export async function managerToFetchingStrategyOptions(manager: WebSocketManager): Promise<FetchingStrategyOptions> { export async function managerToFetchingStrategyOptions(manager: WebSocketManager): Promise<FetchingStrategyOptions> {

View File

@@ -1,8 +1,28 @@
import { IdentifyThrottler } from '../../utils/IdentifyThrottler.js';
import type { SessionInfo, WebSocketManager } from '../../ws/WebSocketManager.js'; import type { SessionInfo, WebSocketManager } from '../../ws/WebSocketManager.js';
import type { FetchingStrategyOptions, IContextFetchingStrategy } from './IContextFetchingStrategy.js'; import type { FetchingStrategyOptions, IContextFetchingStrategy } from './IContextFetchingStrategy.js';
export class SimpleContextFetchingStrategy implements IContextFetchingStrategy { export class SimpleContextFetchingStrategy implements IContextFetchingStrategy {
public constructor(private readonly manager: WebSocketManager, public readonly options: FetchingStrategyOptions) {} // This strategy assumes every shard is running under the same process - therefore we need a single
// IdentifyThrottler per manager.
private static throttlerCache = new WeakMap<WebSocketManager, IdentifyThrottler>();
private static ensureThrottler(manager: WebSocketManager): IdentifyThrottler {
const existing = SimpleContextFetchingStrategy.throttlerCache.get(manager);
if (existing) {
return existing;
}
const throttler = new IdentifyThrottler(manager);
SimpleContextFetchingStrategy.throttlerCache.set(manager, throttler);
return throttler;
}
private readonly throttler: IdentifyThrottler;
public constructor(private readonly manager: WebSocketManager, public readonly options: FetchingStrategyOptions) {
this.throttler = SimpleContextFetchingStrategy.ensureThrottler(manager);
}
public async retrieveSessionInfo(shardId: number): Promise<SessionInfo | null> { public async retrieveSessionInfo(shardId: number): Promise<SessionInfo | null> {
return this.manager.options.retrieveSessionInfo(shardId); return this.manager.options.retrieveSessionInfo(shardId);
@@ -11,4 +31,8 @@ export class SimpleContextFetchingStrategy implements IContextFetchingStrategy {
public updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null) { public updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null) {
return this.manager.options.updateSessionInfo(shardId, sessionInfo); return this.manager.options.updateSessionInfo(shardId, sessionInfo);
} }
public async waitForIdentify(): Promise<void> {
await this.throttler.waitForIdentify();
}
} }

View File

@@ -12,6 +12,8 @@ import type { FetchingStrategyOptions, IContextFetchingStrategy } from './IConte
export class WorkerContextFetchingStrategy implements IContextFetchingStrategy { export class WorkerContextFetchingStrategy implements IContextFetchingStrategy {
private readonly sessionPromises = new Collection<number, (session: SessionInfo | null) => void>(); private readonly sessionPromises = new Collection<number, (session: SessionInfo | null) => void>();
private readonly waitForIdentifyPromises = new Collection<number, () => void>();
public constructor(public readonly options: FetchingStrategyOptions) { public constructor(public readonly options: FetchingStrategyOptions) {
if (isMainThread) { if (isMainThread) {
throw new Error('Cannot instantiate WorkerContextFetchingStrategy on the main thread'); throw new Error('Cannot instantiate WorkerContextFetchingStrategy on the main thread');
@@ -19,20 +21,24 @@ export class WorkerContextFetchingStrategy implements IContextFetchingStrategy {
parentPort!.on('message', (payload: WorkerSendPayload) => { parentPort!.on('message', (payload: WorkerSendPayload) => {
if (payload.op === WorkerSendPayloadOp.SessionInfoResponse) { if (payload.op === WorkerSendPayloadOp.SessionInfoResponse) {
const resolve = this.sessionPromises.get(payload.nonce); this.sessionPromises.get(payload.nonce)?.(payload.session);
resolve?.(payload.session);
this.sessionPromises.delete(payload.nonce); this.sessionPromises.delete(payload.nonce);
} }
if (payload.op === WorkerSendPayloadOp.ShardCanIdentify) {
this.waitForIdentifyPromises.get(payload.nonce)?.();
this.waitForIdentifyPromises.delete(payload.nonce);
}
}); });
} }
public async retrieveSessionInfo(shardId: number): Promise<SessionInfo | null> { public async retrieveSessionInfo(shardId: number): Promise<SessionInfo | null> {
const nonce = Math.random(); const nonce = Math.random();
const payload: WorkerRecievePayload = { const payload = {
op: WorkerRecievePayloadOp.RetrieveSessionInfo, op: WorkerRecievePayloadOp.RetrieveSessionInfo,
shardId, shardId,
nonce, nonce,
}; } satisfies WorkerRecievePayload;
// eslint-disable-next-line no-promise-executor-return // eslint-disable-next-line no-promise-executor-return
const promise = new Promise<SessionInfo | null>((resolve) => this.sessionPromises.set(nonce, resolve)); const promise = new Promise<SessionInfo | null>((resolve) => this.sessionPromises.set(nonce, resolve));
parentPort!.postMessage(payload); parentPort!.postMessage(payload);
@@ -40,11 +46,23 @@ export class WorkerContextFetchingStrategy implements IContextFetchingStrategy {
} }
public updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null) { public updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null) {
const payload: WorkerRecievePayload = { const payload = {
op: WorkerRecievePayloadOp.UpdateSessionInfo, op: WorkerRecievePayloadOp.UpdateSessionInfo,
shardId, shardId,
session: sessionInfo, session: sessionInfo,
}; } satisfies WorkerRecievePayload;
parentPort!.postMessage(payload); parentPort!.postMessage(payload);
} }
public async waitForIdentify(): Promise<void> {
const nonce = Math.random();
const payload = {
op: WorkerRecievePayloadOp.WaitForIdentify,
nonce,
} satisfies WorkerRecievePayload;
// eslint-disable-next-line no-promise-executor-return
const promise = new Promise<void>((resolve) => this.waitForIdentifyPromises.set(nonce, resolve));
parentPort!.postMessage(payload);
return promise;
}
} }

View File

@@ -1,6 +1,5 @@
import { Collection } from '@discordjs/collection'; import { Collection } from '@discordjs/collection';
import type { GatewaySendPayload } from 'discord-api-types/v10'; import type { GatewaySendPayload } from 'discord-api-types/v10';
import { IdentifyThrottler } from '../../utils/IdentifyThrottler.js';
import type { WebSocketManager } from '../../ws/WebSocketManager'; import type { WebSocketManager } from '../../ws/WebSocketManager';
import { WebSocketShard, WebSocketShardEvents, type WebSocketShardDestroyOptions } from '../../ws/WebSocketShard.js'; import { WebSocketShard, WebSocketShardEvents, type WebSocketShardDestroyOptions } from '../../ws/WebSocketShard.js';
import { managerToFetchingStrategyOptions } from '../context/IContextFetchingStrategy.js'; import { managerToFetchingStrategyOptions } from '../context/IContextFetchingStrategy.js';
@@ -15,11 +14,8 @@ export class SimpleShardingStrategy implements IShardingStrategy {
private readonly shards = new Collection<number, WebSocketShard>(); private readonly shards = new Collection<number, WebSocketShard>();
private readonly throttler: IdentifyThrottler;
public constructor(manager: WebSocketManager) { public constructor(manager: WebSocketManager) {
this.manager = manager; this.manager = manager;
this.throttler = new IdentifyThrottler(manager);
} }
/** /**
@@ -46,7 +42,6 @@ export class SimpleShardingStrategy implements IShardingStrategy {
const promises = []; const promises = [];
for (const shard of this.shards.values()) { for (const shard of this.shards.values()) {
await this.throttler.waitForIdentify();
promises.push(shard.connect()); promises.push(shard.connect());
} }

View File

@@ -18,10 +18,12 @@ export enum WorkerSendPayloadOp {
Destroy, Destroy,
Send, Send,
SessionInfoResponse, SessionInfoResponse,
ShardCanIdentify,
} }
export type WorkerSendPayload = export type WorkerSendPayload =
| { nonce: number; op: WorkerSendPayloadOp.SessionInfoResponse; session: SessionInfo | null } | { nonce: number; op: WorkerSendPayloadOp.SessionInfoResponse; session: SessionInfo | null }
| { nonce: number; op: WorkerSendPayloadOp.ShardCanIdentify }
| { op: WorkerSendPayloadOp.Connect; shardId: number } | { op: WorkerSendPayloadOp.Connect; shardId: number }
| { op: WorkerSendPayloadOp.Destroy; options?: WebSocketShardDestroyOptions; shardId: number } | { op: WorkerSendPayloadOp.Destroy; options?: WebSocketShardDestroyOptions; shardId: number }
| { op: WorkerSendPayloadOp.Send; payload: GatewaySendPayload; shardId: number }; | { op: WorkerSendPayloadOp.Send; payload: GatewaySendPayload; shardId: number };
@@ -32,12 +34,14 @@ export enum WorkerRecievePayloadOp {
Event, Event,
RetrieveSessionInfo, RetrieveSessionInfo,
UpdateSessionInfo, UpdateSessionInfo,
WaitForIdentify,
} }
export type WorkerRecievePayload = export type WorkerRecievePayload =
// Can't seem to get a type-safe union based off of the event, so I'm sadly leaving data as any for now // Can't seem to get a type-safe union based off of the event, so I'm sadly leaving data as any for now
| { data: any; event: WebSocketShardEvents; op: WorkerRecievePayloadOp.Event; shardId: number } | { data: any; event: WebSocketShardEvents; op: WorkerRecievePayloadOp.Event; shardId: number }
| { nonce: number; op: WorkerRecievePayloadOp.RetrieveSessionInfo; shardId: number } | { nonce: number; op: WorkerRecievePayloadOp.RetrieveSessionInfo; shardId: number }
| { nonce: number; op: WorkerRecievePayloadOp.WaitForIdentify }
| { op: WorkerRecievePayloadOp.Connected; shardId: number } | { op: WorkerRecievePayloadOp.Connected; shardId: number }
| { op: WorkerRecievePayloadOp.Destroyed; shardId: number } | { op: WorkerRecievePayloadOp.Destroyed; shardId: number }
| { op: WorkerRecievePayloadOp.UpdateSessionInfo; session: SessionInfo | null; shardId: number }; | { op: WorkerRecievePayloadOp.UpdateSessionInfo; session: SessionInfo | null; shardId: number };
@@ -118,12 +122,10 @@ export class WorkerShardingStrategy implements IShardingStrategy {
const promises = []; const promises = [];
for (const [shardId, worker] of this.#workerByShardId.entries()) { for (const [shardId, worker] of this.#workerByShardId.entries()) {
await this.throttler.waitForIdentify(); const payload = {
const payload: WorkerSendPayload = {
op: WorkerSendPayloadOp.Connect, op: WorkerSendPayloadOp.Connect,
shardId, shardId,
}; } satisfies WorkerSendPayload;
// eslint-disable-next-line no-promise-executor-return // eslint-disable-next-line no-promise-executor-return
const promise = new Promise<void>((resolve) => this.connectPromises.set(shardId, resolve)); const promise = new Promise<void>((resolve) => this.connectPromises.set(shardId, resolve));
@@ -141,11 +143,11 @@ export class WorkerShardingStrategy implements IShardingStrategy {
const promises = []; const promises = [];
for (const [shardId, worker] of this.#workerByShardId.entries()) { for (const [shardId, worker] of this.#workerByShardId.entries()) {
const payload: WorkerSendPayload = { const payload = {
op: WorkerSendPayloadOp.Destroy, op: WorkerSendPayloadOp.Destroy,
shardId, shardId,
options, options,
}; } satisfies WorkerSendPayload;
promises.push( promises.push(
// eslint-disable-next-line no-promise-executor-return, promise/prefer-await-to-then // eslint-disable-next-line no-promise-executor-return, promise/prefer-await-to-then
@@ -169,11 +171,11 @@ export class WorkerShardingStrategy implements IShardingStrategy {
throw new Error(`No worker found for shard ${shardId}`); throw new Error(`No worker found for shard ${shardId}`);
} }
const payload: WorkerSendPayload = { const payload = {
op: WorkerSendPayloadOp.Send, op: WorkerSendPayloadOp.Send,
shardId, shardId,
payload: data, payload: data,
}; } satisfies WorkerSendPayload;
worker.postMessage(payload); worker.postMessage(payload);
} }
@@ -213,6 +215,16 @@ export class WorkerShardingStrategy implements IShardingStrategy {
await this.manager.options.updateSessionInfo(payload.shardId, payload.session); await this.manager.options.updateSessionInfo(payload.shardId, payload.session);
break; break;
} }
case WorkerRecievePayloadOp.WaitForIdentify: {
await this.throttler.waitForIdentify();
const response: WorkerSendPayload = {
op: WorkerSendPayloadOp.ShardCanIdentify,
nonce: payload.nonce,
};
worker.postMessage(response);
break;
}
} }
} }
} }

View File

@@ -40,12 +40,12 @@ for (const shardId of data.shardIds) {
for (const event of Object.values(WebSocketShardEvents)) { for (const event of Object.values(WebSocketShardEvents)) {
// @ts-expect-error: Event types incompatible // @ts-expect-error: Event types incompatible
shard.on(event, (data) => { shard.on(event, (data) => {
const payload: WorkerRecievePayload = { const payload = {
op: WorkerRecievePayloadOp.Event, op: WorkerRecievePayloadOp.Event,
event, event,
data, data,
shardId, shardId,
}; } satisfies WorkerRecievePayload;
parentPort!.postMessage(payload); parentPort!.postMessage(payload);
}); });
} }
@@ -93,5 +93,9 @@ parentPort!
case WorkerSendPayloadOp.SessionInfoResponse: { case WorkerSendPayloadOp.SessionInfoResponse: {
break; break;
} }
case WorkerSendPayloadOp.ShardCanIdentify: {
break;
}
} }
}); });

View File

@@ -1,7 +1,10 @@
import { setTimeout as sleep } from 'node:timers/promises'; import { setTimeout as sleep } from 'node:timers/promises';
import type { WebSocketManager } from '../ws/WebSocketManager'; import { AsyncQueue } from '@sapphire/async-queue';
import type { WebSocketManager } from '../ws/WebSocketManager.js';
export class IdentifyThrottler { export class IdentifyThrottler {
private readonly queue = new AsyncQueue();
private identifyState = { private identifyState = {
remaining: 0, remaining: 0,
resetsAt: Number.POSITIVE_INFINITY, resetsAt: Number.POSITIVE_INFINITY,
@@ -10,20 +13,27 @@ export class IdentifyThrottler {
public constructor(private readonly manager: WebSocketManager) {} public constructor(private readonly manager: WebSocketManager) {}
public async waitForIdentify(): Promise<void> { public async waitForIdentify(): Promise<void> {
if (this.identifyState.remaining <= 0) { await this.queue.wait();
const diff = this.identifyState.resetsAt - Date.now();
if (diff <= 5_000) { try {
const time = diff + Math.random() * 1_500; if (this.identifyState.remaining <= 0) {
await sleep(time); const diff = this.identifyState.resetsAt - Date.now();
if (diff <= 5_000) {
// To account for the latency the IDENTIFY payload goes through, we add a bit more wait time
const time = diff + Math.random() * 1_500;
await sleep(time);
}
const info = await this.manager.fetchGatewayInformation();
this.identifyState = {
remaining: info.session_start_limit.max_concurrency,
resetsAt: Date.now() + 5_000,
};
} }
const info = await this.manager.fetchGatewayInformation(); this.identifyState.remaining--;
this.identifyState = { } finally {
remaining: info.session_start_limit.max_concurrency, this.queue.shift();
resetsAt: Date.now() + 5_000,
};
} }
this.identifyState.remaining--;
} }
} }

View File

@@ -295,6 +295,9 @@ export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {
`intents: ${this.strategy.options.intents}`, `intents: ${this.strategy.options.intents}`,
`compression: ${this.inflate ? 'zlib-stream' : this.useIdentifyCompress ? 'identify' : 'none'}`, `compression: ${this.inflate ? 'zlib-stream' : this.useIdentifyCompress ? 'identify' : 'none'}`,
]); ]);
await this.strategy.waitForIdentify();
const d: GatewayIdentifyData = { const d: GatewayIdentifyData = {
token: this.strategy.options.token, token: this.strategy.options.token,
properties: this.strategy.options.identifyProperties, properties: this.strategy.options.identifyProperties,