refactor(WebSocketShard): identify throttling (#8888)

* refactor(WebSocketShard): identify throttling

* chore: add worker handling

* refactor: worker handling

* chore: update tests

* chore: use satisfies where applicable

* chore: add informative comment

* chore: apply suggestions

* refactor(SimpleContextFetchingStrategy): support multiple managers
This commit is contained in:
DD
2022-12-02 15:04:09 +02:00
committed by GitHub
parent 3fca638a84
commit 8f552a0e17
9 changed files with 113 additions and 38 deletions

View File

@@ -27,7 +27,7 @@ const mockConstructor = vi.fn();
const mockSend = vi.fn();
const mockTerminate = vi.fn();
const memberChunkData: GatewayDispatchPayload = {
const memberChunkData = {
op: GatewayOpcodes.Dispatch,
s: 123,
t: GatewayDispatchEvents.GuildMembersChunk,
@@ -35,13 +35,14 @@ const memberChunkData: GatewayDispatchPayload = {
guild_id: '123',
members: [],
},
};
} as unknown as GatewayDispatchPayload;
const sessionInfo: SessionInfo = {
shardId: 0,
shardCount: 2,
sequence: 123,
sessionId: 'abc',
resumeURL: 'wss://ehehe.gg',
};
vi.mock('node:worker_threads', async () => {
@@ -109,6 +110,10 @@ vi.mock('node:worker_threads', async () => {
this.emit('message', session);
break;
}
case WorkerSendPayloadOp.ShardCanIdentify: {
break;
}
}
}
@@ -181,7 +186,10 @@ test('spawn, connect, send a message, session info, and destroy', async () => {
expect.objectContaining({ workerData: expect.objectContaining({ shardIds: [0, 1] }) }),
);
const payload: GatewaySendPayload = { op: GatewayOpcodes.RequestGuildMembers, d: { guild_id: '123', limit: 0 } };
const payload = {
op: GatewayOpcodes.RequestGuildMembers,
d: { guild_id: '123', limit: 0, query: '' },
} satisfies GatewaySendPayload;
await manager.send(0, payload);
expect(mockSend).toHaveBeenCalledWith(0, payload);
expect(managerEmitSpy).toHaveBeenCalledWith(WebSocketShardEvents.Dispatch, {

View File

@@ -18,6 +18,7 @@ export interface IContextFetchingStrategy {
readonly options: FetchingStrategyOptions;
retrieveSessionInfo(shardId: number): Awaitable<SessionInfo | null>;
updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null): Awaitable<void>;
waitForIdentify(): Promise<void>;
}
export async function managerToFetchingStrategyOptions(manager: WebSocketManager): Promise<FetchingStrategyOptions> {

View File

@@ -1,8 +1,28 @@
import { IdentifyThrottler } from '../../utils/IdentifyThrottler.js';
import type { SessionInfo, WebSocketManager } from '../../ws/WebSocketManager.js';
import type { FetchingStrategyOptions, IContextFetchingStrategy } from './IContextFetchingStrategy.js';
export class SimpleContextFetchingStrategy implements IContextFetchingStrategy {
public constructor(private readonly manager: WebSocketManager, public readonly options: FetchingStrategyOptions) {}
// This strategy assumes every shard is running under the same process - therefore we need a single
// IdentifyThrottler per manager.
private static throttlerCache = new WeakMap<WebSocketManager, IdentifyThrottler>();
private static ensureThrottler(manager: WebSocketManager): IdentifyThrottler {
const existing = SimpleContextFetchingStrategy.throttlerCache.get(manager);
if (existing) {
return existing;
}
const throttler = new IdentifyThrottler(manager);
SimpleContextFetchingStrategy.throttlerCache.set(manager, throttler);
return throttler;
}
private readonly throttler: IdentifyThrottler;
public constructor(private readonly manager: WebSocketManager, public readonly options: FetchingStrategyOptions) {
this.throttler = SimpleContextFetchingStrategy.ensureThrottler(manager);
}
public async retrieveSessionInfo(shardId: number): Promise<SessionInfo | null> {
return this.manager.options.retrieveSessionInfo(shardId);
@@ -11,4 +31,8 @@ export class SimpleContextFetchingStrategy implements IContextFetchingStrategy {
public updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null) {
return this.manager.options.updateSessionInfo(shardId, sessionInfo);
}
public async waitForIdentify(): Promise<void> {
await this.throttler.waitForIdentify();
}
}

View File

@@ -12,6 +12,8 @@ import type { FetchingStrategyOptions, IContextFetchingStrategy } from './IConte
export class WorkerContextFetchingStrategy implements IContextFetchingStrategy {
private readonly sessionPromises = new Collection<number, (session: SessionInfo | null) => void>();
private readonly waitForIdentifyPromises = new Collection<number, () => void>();
public constructor(public readonly options: FetchingStrategyOptions) {
if (isMainThread) {
throw new Error('Cannot instantiate WorkerContextFetchingStrategy on the main thread');
@@ -19,20 +21,24 @@ export class WorkerContextFetchingStrategy implements IContextFetchingStrategy {
parentPort!.on('message', (payload: WorkerSendPayload) => {
if (payload.op === WorkerSendPayloadOp.SessionInfoResponse) {
const resolve = this.sessionPromises.get(payload.nonce);
resolve?.(payload.session);
this.sessionPromises.get(payload.nonce)?.(payload.session);
this.sessionPromises.delete(payload.nonce);
}
if (payload.op === WorkerSendPayloadOp.ShardCanIdentify) {
this.waitForIdentifyPromises.get(payload.nonce)?.();
this.waitForIdentifyPromises.delete(payload.nonce);
}
});
}
public async retrieveSessionInfo(shardId: number): Promise<SessionInfo | null> {
const nonce = Math.random();
const payload: WorkerRecievePayload = {
const payload = {
op: WorkerRecievePayloadOp.RetrieveSessionInfo,
shardId,
nonce,
};
} satisfies WorkerRecievePayload;
// eslint-disable-next-line no-promise-executor-return
const promise = new Promise<SessionInfo | null>((resolve) => this.sessionPromises.set(nonce, resolve));
parentPort!.postMessage(payload);
@@ -40,11 +46,23 @@ export class WorkerContextFetchingStrategy implements IContextFetchingStrategy {
}
public updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null) {
const payload: WorkerRecievePayload = {
const payload = {
op: WorkerRecievePayloadOp.UpdateSessionInfo,
shardId,
session: sessionInfo,
};
} satisfies WorkerRecievePayload;
parentPort!.postMessage(payload);
}
public async waitForIdentify(): Promise<void> {
const nonce = Math.random();
const payload = {
op: WorkerRecievePayloadOp.WaitForIdentify,
nonce,
} satisfies WorkerRecievePayload;
// eslint-disable-next-line no-promise-executor-return
const promise = new Promise<void>((resolve) => this.waitForIdentifyPromises.set(nonce, resolve));
parentPort!.postMessage(payload);
return promise;
}
}

View File

@@ -1,6 +1,5 @@
import { Collection } from '@discordjs/collection';
import type { GatewaySendPayload } from 'discord-api-types/v10';
import { IdentifyThrottler } from '../../utils/IdentifyThrottler.js';
import type { WebSocketManager } from '../../ws/WebSocketManager';
import { WebSocketShard, WebSocketShardEvents, type WebSocketShardDestroyOptions } from '../../ws/WebSocketShard.js';
import { managerToFetchingStrategyOptions } from '../context/IContextFetchingStrategy.js';
@@ -15,11 +14,8 @@ export class SimpleShardingStrategy implements IShardingStrategy {
private readonly shards = new Collection<number, WebSocketShard>();
private readonly throttler: IdentifyThrottler;
public constructor(manager: WebSocketManager) {
this.manager = manager;
this.throttler = new IdentifyThrottler(manager);
}
/**
@@ -46,7 +42,6 @@ export class SimpleShardingStrategy implements IShardingStrategy {
const promises = [];
for (const shard of this.shards.values()) {
await this.throttler.waitForIdentify();
promises.push(shard.connect());
}

View File

@@ -18,10 +18,12 @@ export enum WorkerSendPayloadOp {
Destroy,
Send,
SessionInfoResponse,
ShardCanIdentify,
}
export type WorkerSendPayload =
| { nonce: number; op: WorkerSendPayloadOp.SessionInfoResponse; session: SessionInfo | null }
| { nonce: number; op: WorkerSendPayloadOp.ShardCanIdentify }
| { op: WorkerSendPayloadOp.Connect; shardId: number }
| { op: WorkerSendPayloadOp.Destroy; options?: WebSocketShardDestroyOptions; shardId: number }
| { op: WorkerSendPayloadOp.Send; payload: GatewaySendPayload; shardId: number };
@@ -32,12 +34,14 @@ export enum WorkerRecievePayloadOp {
Event,
RetrieveSessionInfo,
UpdateSessionInfo,
WaitForIdentify,
}
export type WorkerRecievePayload =
// Can't seem to get a type-safe union based off of the event, so I'm sadly leaving data as any for now
| { data: any; event: WebSocketShardEvents; op: WorkerRecievePayloadOp.Event; shardId: number }
| { nonce: number; op: WorkerRecievePayloadOp.RetrieveSessionInfo; shardId: number }
| { nonce: number; op: WorkerRecievePayloadOp.WaitForIdentify }
| { op: WorkerRecievePayloadOp.Connected; shardId: number }
| { op: WorkerRecievePayloadOp.Destroyed; shardId: number }
| { op: WorkerRecievePayloadOp.UpdateSessionInfo; session: SessionInfo | null; shardId: number };
@@ -118,12 +122,10 @@ export class WorkerShardingStrategy implements IShardingStrategy {
const promises = [];
for (const [shardId, worker] of this.#workerByShardId.entries()) {
await this.throttler.waitForIdentify();
const payload: WorkerSendPayload = {
const payload = {
op: WorkerSendPayloadOp.Connect,
shardId,
};
} satisfies WorkerSendPayload;
// eslint-disable-next-line no-promise-executor-return
const promise = new Promise<void>((resolve) => this.connectPromises.set(shardId, resolve));
@@ -141,11 +143,11 @@ export class WorkerShardingStrategy implements IShardingStrategy {
const promises = [];
for (const [shardId, worker] of this.#workerByShardId.entries()) {
const payload: WorkerSendPayload = {
const payload = {
op: WorkerSendPayloadOp.Destroy,
shardId,
options,
};
} satisfies WorkerSendPayload;
promises.push(
// eslint-disable-next-line no-promise-executor-return, promise/prefer-await-to-then
@@ -169,11 +171,11 @@ export class WorkerShardingStrategy implements IShardingStrategy {
throw new Error(`No worker found for shard ${shardId}`);
}
const payload: WorkerSendPayload = {
const payload = {
op: WorkerSendPayloadOp.Send,
shardId,
payload: data,
};
} satisfies WorkerSendPayload;
worker.postMessage(payload);
}
@@ -213,6 +215,16 @@ export class WorkerShardingStrategy implements IShardingStrategy {
await this.manager.options.updateSessionInfo(payload.shardId, payload.session);
break;
}
case WorkerRecievePayloadOp.WaitForIdentify: {
await this.throttler.waitForIdentify();
const response: WorkerSendPayload = {
op: WorkerSendPayloadOp.ShardCanIdentify,
nonce: payload.nonce,
};
worker.postMessage(response);
break;
}
}
}
}

View File

@@ -40,12 +40,12 @@ for (const shardId of data.shardIds) {
for (const event of Object.values(WebSocketShardEvents)) {
// @ts-expect-error: Event types incompatible
shard.on(event, (data) => {
const payload: WorkerRecievePayload = {
const payload = {
op: WorkerRecievePayloadOp.Event,
event,
data,
shardId,
};
} satisfies WorkerRecievePayload;
parentPort!.postMessage(payload);
});
}
@@ -93,5 +93,9 @@ parentPort!
case WorkerSendPayloadOp.SessionInfoResponse: {
break;
}
case WorkerSendPayloadOp.ShardCanIdentify: {
break;
}
}
});

View File

@@ -1,7 +1,10 @@
import { setTimeout as sleep } from 'node:timers/promises';
import type { WebSocketManager } from '../ws/WebSocketManager';
import { AsyncQueue } from '@sapphire/async-queue';
import type { WebSocketManager } from '../ws/WebSocketManager.js';
export class IdentifyThrottler {
private readonly queue = new AsyncQueue();
private identifyState = {
remaining: 0,
resetsAt: Number.POSITIVE_INFINITY,
@@ -10,20 +13,27 @@ export class IdentifyThrottler {
public constructor(private readonly manager: WebSocketManager) {}
public async waitForIdentify(): Promise<void> {
if (this.identifyState.remaining <= 0) {
const diff = this.identifyState.resetsAt - Date.now();
if (diff <= 5_000) {
const time = diff + Math.random() * 1_500;
await sleep(time);
await this.queue.wait();
try {
if (this.identifyState.remaining <= 0) {
const diff = this.identifyState.resetsAt - Date.now();
if (diff <= 5_000) {
// To account for the latency the IDENTIFY payload goes through, we add a bit more wait time
const time = diff + Math.random() * 1_500;
await sleep(time);
}
const info = await this.manager.fetchGatewayInformation();
this.identifyState = {
remaining: info.session_start_limit.max_concurrency,
resetsAt: Date.now() + 5_000,
};
}
const info = await this.manager.fetchGatewayInformation();
this.identifyState = {
remaining: info.session_start_limit.max_concurrency,
resetsAt: Date.now() + 5_000,
};
this.identifyState.remaining--;
} finally {
this.queue.shift();
}
this.identifyState.remaining--;
}
}

View File

@@ -295,6 +295,9 @@ export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {
`intents: ${this.strategy.options.intents}`,
`compression: ${this.inflate ? 'zlib-stream' : this.useIdentifyCompress ? 'identify' : 'none'}`,
]);
await this.strategy.waitForIdentify();
const d: GatewayIdentifyData = {
token: this.strategy.options.token,
properties: this.strategy.options.identifyProperties,