refactor: abstract identify throttling and correct max_concurrency handling (#9375)

* refactor: properly support max_concurrency ratelimit keys

* fix: properly block for same key

* chore: export session state

* chore: throttler no longer requires manager

* refactor: abstract throttlers

* chore: proper member order

* chore: remove leftover debug log

* chore: use @link tag in doc comment

Co-authored-by: Jiralite <33201955+Jiralite@users.noreply.github.com>

* chore: suggested changes

* fix(WebSocketShard): cancel identify if the shard closed in the meantime

* refactor(throttlers): support abort signals

* fix: memory leak

* chore: remove leftover

---------

Co-authored-by: Jiralite <33201955+Jiralite@users.noreply.github.com>
Co-authored-by: kodiakhq[bot] <49736102+kodiakhq[bot]@users.noreply.github.com>
This commit is contained in:
DD
2023-04-14 23:26:37 +03:00
committed by GitHub
parent cac3c07729
commit 02dfaf1aa2
16 changed files with 279 additions and 161 deletions

View File

@@ -57,9 +57,9 @@ vi.mock('node:worker_threads', async () => {
this.emit('online'); this.emit('online');
// same deal here // same deal here
setImmediate(() => { setImmediate(() => {
const message = { const message: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.WorkerReady, op: WorkerReceivePayloadOp.WorkerReady,
} satisfies WorkerReceivePayload; };
this.emit('message', message); this.emit('message', message);
}); });
}); });
@@ -68,39 +68,39 @@ vi.mock('node:worker_threads', async () => {
public postMessage(message: WorkerSendPayload) { public postMessage(message: WorkerSendPayload) {
switch (message.op) { switch (message.op) {
case WorkerSendPayloadOp.Connect: { case WorkerSendPayloadOp.Connect: {
const response = { const response: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.Connected, op: WorkerReceivePayloadOp.Connected,
shardId: message.shardId, shardId: message.shardId,
} satisfies WorkerReceivePayload; };
this.emit('message', response); this.emit('message', response);
break; break;
} }
case WorkerSendPayloadOp.Destroy: { case WorkerSendPayloadOp.Destroy: {
const response = { const response: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.Destroyed, op: WorkerReceivePayloadOp.Destroyed,
shardId: message.shardId, shardId: message.shardId,
} satisfies WorkerReceivePayload; };
this.emit('message', response); this.emit('message', response);
break; break;
} }
case WorkerSendPayloadOp.Send: { case WorkerSendPayloadOp.Send: {
if (message.payload.op === GatewayOpcodes.RequestGuildMembers) { if (message.payload.op === GatewayOpcodes.RequestGuildMembers) {
const response = { const response: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.Event, op: WorkerReceivePayloadOp.Event,
shardId: message.shardId, shardId: message.shardId,
event: WebSocketShardEvents.Dispatch, event: WebSocketShardEvents.Dispatch,
data: memberChunkData, data: memberChunkData,
} satisfies WorkerReceivePayload; };
this.emit('message', response); this.emit('message', response);
// Fetch session info // Fetch session info
const sessionFetch = { const sessionFetch: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.RetrieveSessionInfo, op: WorkerReceivePayloadOp.RetrieveSessionInfo,
shardId: message.shardId, shardId: message.shardId,
nonce: Math.random(), nonce: Math.random(),
} satisfies WorkerReceivePayload; };
this.emit('message', sessionFetch); this.emit('message', sessionFetch);
} }
@@ -111,16 +111,16 @@ vi.mock('node:worker_threads', async () => {
case WorkerSendPayloadOp.SessionInfoResponse: { case WorkerSendPayloadOp.SessionInfoResponse: {
message.session ??= sessionInfo; message.session ??= sessionInfo;
const session = { const session: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.UpdateSessionInfo, op: WorkerReceivePayloadOp.UpdateSessionInfo,
shardId: message.session.shardId, shardId: message.session.shardId,
session: { ...message.session, sequence: message.session.sequence + 1 }, session: { ...message.session, sequence: message.session.sequence + 1 },
} satisfies WorkerReceivePayload; };
this.emit('message', session); this.emit('message', session);
break; break;
} }
case WorkerSendPayloadOp.ShardCanIdentify: { case WorkerSendPayloadOp.ShardIdentifyResponse: {
break; break;
} }
@@ -198,10 +198,10 @@ test('spawn, connect, send a message, session info, and destroy', async () => {
expect.objectContaining({ workerData: expect.objectContaining({ shardIds: [0, 1] }) }), expect.objectContaining({ workerData: expect.objectContaining({ shardIds: [0, 1] }) }),
); );
const payload = { const payload: GatewaySendPayload = {
op: GatewayOpcodes.RequestGuildMembers, op: GatewayOpcodes.RequestGuildMembers,
d: { guild_id: '123', limit: 0, query: '' }, d: { guild_id: '123', limit: 0, query: '' },
} satisfies GatewaySendPayload; };
await manager.send(0, payload); await manager.send(0, payload);
expect(mockSend).toHaveBeenCalledWith(0, payload); expect(mockSend).toHaveBeenCalledWith(0, payload);
expect(managerEmitSpy).toHaveBeenCalledWith(WebSocketShardEvents.Dispatch, { expect(managerEmitSpy).toHaveBeenCalledWith(WebSocketShardEvents.Dispatch, {

View File

@@ -1,46 +0,0 @@
import { setTimeout as sleep } from 'node:timers/promises';
import { expect, test, vi, type Mock } from 'vitest';
import { IdentifyThrottler, type WebSocketManager } from '../../src/index.js';
vi.mock('node:timers/promises', () => ({
setTimeout: vi.fn(),
}));
const fetchGatewayInformation = vi.fn();
const manager = {
fetchGatewayInformation,
} as unknown as WebSocketManager;
const throttler = new IdentifyThrottler(manager);
vi.useFakeTimers();
const NOW = vi.fn().mockReturnValue(Date.now());
global.Date.now = NOW;
test('wait for identify', async () => {
fetchGatewayInformation.mockReturnValue({
session_start_limit: {
max_concurrency: 2,
},
});
// First call should never wait
await throttler.waitForIdentify();
expect(sleep).not.toHaveBeenCalled();
// Second call still won't wait because max_concurrency is 2
await throttler.waitForIdentify();
expect(sleep).not.toHaveBeenCalled();
// Third call should wait
await throttler.waitForIdentify();
expect(sleep).toHaveBeenCalled();
(sleep as Mock).mockRestore();
// Fourth call shouldn't wait, because our max_concurrency is 2 and we waited for a reset
await throttler.waitForIdentify();
expect(sleep).not.toHaveBeenCalled();
});

View File

@@ -0,0 +1,32 @@
import { setTimeout as sleep } from 'node:timers/promises';
import { expect, test, vi, type Mock } from 'vitest';
import { SimpleIdentifyThrottler } from '../../src/index.js';
vi.mock('node:timers/promises', () => ({
setTimeout: vi.fn(),
}));
const throttler = new SimpleIdentifyThrottler(2);
vi.useFakeTimers();
const NOW = vi.fn().mockReturnValue(Date.now());
global.Date.now = NOW;
test('basic case', async () => {
// Those shouldn't wait since they're in different keys
await throttler.waitForIdentify(0);
expect(sleep).not.toHaveBeenCalled();
await throttler.waitForIdentify(1);
expect(sleep).not.toHaveBeenCalled();
// Those should wait
await throttler.waitForIdentify(2);
expect(sleep).toHaveBeenCalledTimes(1);
await throttler.waitForIdentify(3);
expect(sleep).toHaveBeenCalledTimes(2);
});

View File

@@ -6,8 +6,10 @@ export * from './strategies/sharding/IShardingStrategy.js';
export * from './strategies/sharding/SimpleShardingStrategy.js'; export * from './strategies/sharding/SimpleShardingStrategy.js';
export * from './strategies/sharding/WorkerShardingStrategy.js'; export * from './strategies/sharding/WorkerShardingStrategy.js';
export * from './throttling/IIdentifyThrottler.js';
export * from './throttling/SimpleIdentifyThrottler.js';
export * from './utils/constants.js'; export * from './utils/constants.js';
export * from './utils/IdentifyThrottler.js';
export * from './utils/WorkerBootstrapper.js'; export * from './utils/WorkerBootstrapper.js';
export * from './ws/WebSocketManager.js'; export * from './ws/WebSocketManager.js';

View File

@@ -5,7 +5,13 @@ import type { SessionInfo, WebSocketManager, WebSocketManagerOptions } from '../
export interface FetchingStrategyOptions export interface FetchingStrategyOptions
extends Omit< extends Omit<
WebSocketManagerOptions, WebSocketManagerOptions,
'buildStrategy' | 'rest' | 'retrieveSessionInfo' | 'shardCount' | 'shardIds' | 'updateSessionInfo' | 'buildIdentifyThrottler'
| 'buildStrategy'
| 'rest'
| 'retrieveSessionInfo'
| 'shardCount'
| 'shardIds'
| 'updateSessionInfo'
> { > {
readonly gatewayInformation: APIGatewayBotInfo; readonly gatewayInformation: APIGatewayBotInfo;
readonly shardCount: number; readonly shardCount: number;
@@ -18,13 +24,25 @@ export interface IContextFetchingStrategy {
readonly options: FetchingStrategyOptions; readonly options: FetchingStrategyOptions;
retrieveSessionInfo(shardId: number): Awaitable<SessionInfo | null>; retrieveSessionInfo(shardId: number): Awaitable<SessionInfo | null>;
updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null): Awaitable<void>; updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null): Awaitable<void>;
waitForIdentify(): Promise<void>; /**
* Resolves once the given shard should be allowed to identify, or rejects if the operation was aborted
*/
waitForIdentify(shardId: number, signal: AbortSignal): Promise<void>;
} }
export async function managerToFetchingStrategyOptions(manager: WebSocketManager): Promise<FetchingStrategyOptions> { export async function managerToFetchingStrategyOptions(manager: WebSocketManager): Promise<FetchingStrategyOptions> {
// eslint-disable-next-line @typescript-eslint/unbound-method /* eslint-disable @typescript-eslint/unbound-method */
const { buildStrategy, retrieveSessionInfo, updateSessionInfo, shardCount, shardIds, rest, ...managerOptions } = const {
manager.options; buildIdentifyThrottler,
buildStrategy,
retrieveSessionInfo,
updateSessionInfo,
shardCount,
shardIds,
rest,
...managerOptions
} = manager.options;
/* eslint-enable @typescript-eslint/unbound-method */
return { return {
...managerOptions, ...managerOptions,

View File

@@ -1,28 +1,25 @@
import { IdentifyThrottler } from '../../utils/IdentifyThrottler.js'; import type { IIdentifyThrottler } from '../../throttling/IIdentifyThrottler.js';
import type { SessionInfo, WebSocketManager } from '../../ws/WebSocketManager.js'; import type { SessionInfo, WebSocketManager } from '../../ws/WebSocketManager.js';
import type { FetchingStrategyOptions, IContextFetchingStrategy } from './IContextFetchingStrategy.js'; import type { FetchingStrategyOptions, IContextFetchingStrategy } from './IContextFetchingStrategy.js';
export class SimpleContextFetchingStrategy implements IContextFetchingStrategy { export class SimpleContextFetchingStrategy implements IContextFetchingStrategy {
// This strategy assumes every shard is running under the same process - therefore we need a single // This strategy assumes every shard is running under the same process - therefore we need a single
// IdentifyThrottler per manager. // IdentifyThrottler per manager.
private static throttlerCache = new WeakMap<WebSocketManager, IdentifyThrottler>(); private static throttlerCache = new WeakMap<WebSocketManager, IIdentifyThrottler>();
private static ensureThrottler(manager: WebSocketManager): IdentifyThrottler { private static async ensureThrottler(manager: WebSocketManager): Promise<IIdentifyThrottler> {
const existing = SimpleContextFetchingStrategy.throttlerCache.get(manager); const throttler = SimpleContextFetchingStrategy.throttlerCache.get(manager);
if (existing) { if (throttler) {
return existing; return throttler;
} }
const throttler = new IdentifyThrottler(manager); const newThrottler = await manager.options.buildIdentifyThrottler(manager);
SimpleContextFetchingStrategy.throttlerCache.set(manager, throttler); SimpleContextFetchingStrategy.throttlerCache.set(manager, newThrottler);
return throttler;
return newThrottler;
} }
private readonly throttler: IdentifyThrottler; public constructor(private readonly manager: WebSocketManager, public readonly options: FetchingStrategyOptions) {}
public constructor(private readonly manager: WebSocketManager, public readonly options: FetchingStrategyOptions) {
this.throttler = SimpleContextFetchingStrategy.ensureThrottler(manager);
}
public async retrieveSessionInfo(shardId: number): Promise<SessionInfo | null> { public async retrieveSessionInfo(shardId: number): Promise<SessionInfo | null> {
return this.manager.options.retrieveSessionInfo(shardId); return this.manager.options.retrieveSessionInfo(shardId);
@@ -32,7 +29,8 @@ export class SimpleContextFetchingStrategy implements IContextFetchingStrategy {
return this.manager.options.updateSessionInfo(shardId, sessionInfo); return this.manager.options.updateSessionInfo(shardId, sessionInfo);
} }
public async waitForIdentify(): Promise<void> { public async waitForIdentify(shardId: number, signal: AbortSignal): Promise<void> {
await this.throttler.waitForIdentify(); const throttler = await SimpleContextFetchingStrategy.ensureThrottler(this.manager);
await throttler.waitForIdentify(shardId, signal);
} }
} }

View File

@@ -9,10 +9,17 @@ import {
} from '../sharding/WorkerShardingStrategy.js'; } from '../sharding/WorkerShardingStrategy.js';
import type { FetchingStrategyOptions, IContextFetchingStrategy } from './IContextFetchingStrategy.js'; import type { FetchingStrategyOptions, IContextFetchingStrategy } from './IContextFetchingStrategy.js';
// Because the global types are incomplete for whatever reason
interface PolyFillAbortSignal {
readonly aborted: boolean;
addEventListener(type: 'abort', listener: () => void): void;
removeEventListener(type: 'abort', listener: () => void): void;
}
export class WorkerContextFetchingStrategy implements IContextFetchingStrategy { export class WorkerContextFetchingStrategy implements IContextFetchingStrategy {
private readonly sessionPromises = new Collection<number, (session: SessionInfo | null) => void>(); private readonly sessionPromises = new Collection<number, (session: SessionInfo | null) => void>();
private readonly waitForIdentifyPromises = new Collection<number, () => void>(); private readonly waitForIdentifyPromises = new Collection<number, { reject(): void; resolve(): void }>();
public constructor(public readonly options: FetchingStrategyOptions) { public constructor(public readonly options: FetchingStrategyOptions) {
if (isMainThread) { if (isMainThread) {
@@ -25,8 +32,14 @@ export class WorkerContextFetchingStrategy implements IContextFetchingStrategy {
this.sessionPromises.delete(payload.nonce); this.sessionPromises.delete(payload.nonce);
} }
if (payload.op === WorkerSendPayloadOp.ShardCanIdentify) { if (payload.op === WorkerSendPayloadOp.ShardIdentifyResponse) {
this.waitForIdentifyPromises.get(payload.nonce)?.(); const promise = this.waitForIdentifyPromises.get(payload.nonce);
if (payload.ok) {
promise?.resolve();
} else {
promise?.reject();
}
this.waitForIdentifyPromises.delete(payload.nonce); this.waitForIdentifyPromises.delete(payload.nonce);
} }
}); });
@@ -34,11 +47,11 @@ export class WorkerContextFetchingStrategy implements IContextFetchingStrategy {
public async retrieveSessionInfo(shardId: number): Promise<SessionInfo | null> { public async retrieveSessionInfo(shardId: number): Promise<SessionInfo | null> {
const nonce = Math.random(); const nonce = Math.random();
const payload = { const payload: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.RetrieveSessionInfo, op: WorkerReceivePayloadOp.RetrieveSessionInfo,
shardId, shardId,
nonce, nonce,
} satisfies WorkerReceivePayload; };
// eslint-disable-next-line no-promise-executor-return // eslint-disable-next-line no-promise-executor-return
const promise = new Promise<SessionInfo | null>((resolve) => this.sessionPromises.set(nonce, resolve)); const promise = new Promise<SessionInfo | null>((resolve) => this.sessionPromises.set(nonce, resolve));
parentPort!.postMessage(payload); parentPort!.postMessage(payload);
@@ -46,23 +59,44 @@ export class WorkerContextFetchingStrategy implements IContextFetchingStrategy {
} }
public updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null) { public updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null) {
const payload = { const payload: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.UpdateSessionInfo, op: WorkerReceivePayloadOp.UpdateSessionInfo,
shardId, shardId,
session: sessionInfo, session: sessionInfo,
} satisfies WorkerReceivePayload; };
parentPort!.postMessage(payload); parentPort!.postMessage(payload);
} }
public async waitForIdentify(): Promise<void> { public async waitForIdentify(shardId: number, signal: AbortSignal): Promise<void> {
const nonce = Math.random(); const nonce = Math.random();
const payload = {
const payload: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.WaitForIdentify, op: WorkerReceivePayloadOp.WaitForIdentify,
nonce, nonce,
} satisfies WorkerReceivePayload; shardId,
// eslint-disable-next-line no-promise-executor-return };
const promise = new Promise<void>((resolve) => this.waitForIdentifyPromises.set(nonce, resolve)); const promise = new Promise<void>((resolve, reject) =>
// eslint-disable-next-line no-promise-executor-return
this.waitForIdentifyPromises.set(nonce, { resolve, reject }),
);
parentPort!.postMessage(payload); parentPort!.postMessage(payload);
return promise;
const listener = () => {
const payload: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.CancelIdentify,
nonce,
};
parentPort!.postMessage(payload);
};
(signal as unknown as PolyFillAbortSignal).addEventListener('abort', listener);
try {
await promise;
} finally {
(signal as unknown as PolyFillAbortSignal).removeEventListener('abort', listener);
}
} }
} }

View File

@@ -23,6 +23,7 @@ export class SimpleShardingStrategy implements IShardingStrategy {
*/ */
public async spawn(shardIds: number[]) { public async spawn(shardIds: number[]) {
const strategyOptions = await managerToFetchingStrategyOptions(this.manager); const strategyOptions = await managerToFetchingStrategyOptions(this.manager);
for (const shardId of shardIds) { for (const shardId of shardIds) {
const strategy = new SimpleContextFetchingStrategy(this.manager, strategyOptions); const strategy = new SimpleContextFetchingStrategy(this.manager, strategyOptions);
const shard = new WebSocketShard(strategy, shardId); const shard = new WebSocketShard(strategy, shardId);

View File

@@ -3,7 +3,7 @@ import { join, isAbsolute, resolve } from 'node:path';
import { Worker } from 'node:worker_threads'; import { Worker } from 'node:worker_threads';
import { Collection } from '@discordjs/collection'; import { Collection } from '@discordjs/collection';
import type { GatewaySendPayload } from 'discord-api-types/v10'; import type { GatewaySendPayload } from 'discord-api-types/v10';
import { IdentifyThrottler } from '../../utils/IdentifyThrottler.js'; import type { IIdentifyThrottler } from '../../throttling/IIdentifyThrottler';
import type { SessionInfo, WebSocketManager } from '../../ws/WebSocketManager'; import type { SessionInfo, WebSocketManager } from '../../ws/WebSocketManager';
import type { WebSocketShardDestroyOptions, WebSocketShardEvents, WebSocketShardStatus } from '../../ws/WebSocketShard'; import type { WebSocketShardDestroyOptions, WebSocketShardEvents, WebSocketShardStatus } from '../../ws/WebSocketShard';
import { managerToFetchingStrategyOptions, type FetchingStrategyOptions } from '../context/IContextFetchingStrategy.js'; import { managerToFetchingStrategyOptions, type FetchingStrategyOptions } from '../context/IContextFetchingStrategy.js';
@@ -18,14 +18,14 @@ export enum WorkerSendPayloadOp {
Destroy, Destroy,
Send, Send,
SessionInfoResponse, SessionInfoResponse,
ShardCanIdentify, ShardIdentifyResponse,
FetchStatus, FetchStatus,
} }
export type WorkerSendPayload = export type WorkerSendPayload =
| { nonce: number; ok: boolean; op: WorkerSendPayloadOp.ShardIdentifyResponse }
| { nonce: number; op: WorkerSendPayloadOp.FetchStatus; shardId: number } | { nonce: number; op: WorkerSendPayloadOp.FetchStatus; shardId: number }
| { nonce: number; op: WorkerSendPayloadOp.SessionInfoResponse; session: SessionInfo | null } | { nonce: number; op: WorkerSendPayloadOp.SessionInfoResponse; session: SessionInfo | null }
| { nonce: number; op: WorkerSendPayloadOp.ShardCanIdentify }
| { op: WorkerSendPayloadOp.Connect; shardId: number } | { op: WorkerSendPayloadOp.Connect; shardId: number }
| { op: WorkerSendPayloadOp.Destroy; options?: WebSocketShardDestroyOptions; shardId: number } | { op: WorkerSendPayloadOp.Destroy; options?: WebSocketShardDestroyOptions; shardId: number }
| { op: WorkerSendPayloadOp.Send; payload: GatewaySendPayload; shardId: number }; | { op: WorkerSendPayloadOp.Send; payload: GatewaySendPayload; shardId: number };
@@ -39,14 +39,16 @@ export enum WorkerReceivePayloadOp {
WaitForIdentify, WaitForIdentify,
FetchStatusResponse, FetchStatusResponse,
WorkerReady, WorkerReady,
CancelIdentify,
} }
export type WorkerReceivePayload = export type WorkerReceivePayload =
// Can't seem to get a type-safe union based off of the event, so I'm sadly leaving data as any for now // Can't seem to get a type-safe union based off of the event, so I'm sadly leaving data as any for now
| { data: any; event: WebSocketShardEvents; op: WorkerReceivePayloadOp.Event; shardId: number } | { data: any; event: WebSocketShardEvents; op: WorkerReceivePayloadOp.Event; shardId: number }
| { nonce: number; op: WorkerReceivePayloadOp.CancelIdentify }
| { nonce: number; op: WorkerReceivePayloadOp.FetchStatusResponse; status: WebSocketShardStatus } | { nonce: number; op: WorkerReceivePayloadOp.FetchStatusResponse; status: WebSocketShardStatus }
| { nonce: number; op: WorkerReceivePayloadOp.RetrieveSessionInfo; shardId: number } | { nonce: number; op: WorkerReceivePayloadOp.RetrieveSessionInfo; shardId: number }
| { nonce: number; op: WorkerReceivePayloadOp.WaitForIdentify } | { nonce: number; op: WorkerReceivePayloadOp.WaitForIdentify; shardId: number }
| { op: WorkerReceivePayloadOp.Connected; shardId: number } | { op: WorkerReceivePayloadOp.Connected; shardId: number }
| { op: WorkerReceivePayloadOp.Destroyed; shardId: number } | { op: WorkerReceivePayloadOp.Destroyed; shardId: number }
| { op: WorkerReceivePayloadOp.UpdateSessionInfo; session: SessionInfo | null; shardId: number } | { op: WorkerReceivePayloadOp.UpdateSessionInfo; session: SessionInfo | null; shardId: number }
@@ -84,11 +86,12 @@ export class WorkerShardingStrategy implements IShardingStrategy {
private readonly fetchStatusPromises = new Collection<number, (status: WebSocketShardStatus) => void>(); private readonly fetchStatusPromises = new Collection<number, (status: WebSocketShardStatus) => void>();
private readonly throttler: IdentifyThrottler; private readonly waitForIdentifyControllers = new Collection<number, AbortController>();
private throttler?: IIdentifyThrottler;
public constructor(manager: WebSocketManager, options: WorkerShardingStrategyOptions) { public constructor(manager: WebSocketManager, options: WorkerShardingStrategyOptions) {
this.manager = manager; this.manager = manager;
this.throttler = new IdentifyThrottler(manager);
this.options = options; this.options = options;
} }
@@ -122,10 +125,10 @@ export class WorkerShardingStrategy implements IShardingStrategy {
const promises = []; const promises = [];
for (const [shardId, worker] of this.#workerByShardId.entries()) { for (const [shardId, worker] of this.#workerByShardId.entries()) {
const payload = { const payload: WorkerSendPayload = {
op: WorkerSendPayloadOp.Connect, op: WorkerSendPayloadOp.Connect,
shardId, shardId,
} satisfies WorkerSendPayload; };
// eslint-disable-next-line no-promise-executor-return // eslint-disable-next-line no-promise-executor-return
const promise = new Promise<void>((resolve) => this.connectPromises.set(shardId, resolve)); const promise = new Promise<void>((resolve) => this.connectPromises.set(shardId, resolve));
@@ -143,11 +146,11 @@ export class WorkerShardingStrategy implements IShardingStrategy {
const promises = []; const promises = [];
for (const [shardId, worker] of this.#workerByShardId.entries()) { for (const [shardId, worker] of this.#workerByShardId.entries()) {
const payload = { const payload: WorkerSendPayload = {
op: WorkerSendPayloadOp.Destroy, op: WorkerSendPayloadOp.Destroy,
shardId, shardId,
options, options,
} satisfies WorkerSendPayload; };
promises.push( promises.push(
// eslint-disable-next-line no-promise-executor-return, promise/prefer-await-to-then // eslint-disable-next-line no-promise-executor-return, promise/prefer-await-to-then
@@ -171,11 +174,11 @@ export class WorkerShardingStrategy implements IShardingStrategy {
throw new Error(`No worker found for shard ${shardId}`); throw new Error(`No worker found for shard ${shardId}`);
} }
const payload = { const payload: WorkerSendPayload = {
op: WorkerSendPayloadOp.Send, op: WorkerSendPayloadOp.Send,
shardId, shardId,
payload: data, payload: data,
} satisfies WorkerSendPayload; };
worker.postMessage(payload); worker.postMessage(payload);
} }
@@ -187,11 +190,11 @@ export class WorkerShardingStrategy implements IShardingStrategy {
for (const [shardId, worker] of this.#workerByShardId.entries()) { for (const [shardId, worker] of this.#workerByShardId.entries()) {
const nonce = Math.random(); const nonce = Math.random();
const payload = { const payload: WorkerSendPayload = {
op: WorkerSendPayloadOp.FetchStatus, op: WorkerSendPayloadOp.FetchStatus,
shardId, shardId,
nonce, nonce,
} satisfies WorkerSendPayload; };
// eslint-disable-next-line no-promise-executor-return // eslint-disable-next-line no-promise-executor-return
const promise = new Promise<WebSocketShardStatus>((resolve) => this.fetchStatusPromises.set(nonce, resolve)); const promise = new Promise<WebSocketShardStatus>((resolve) => this.fetchStatusPromises.set(nonce, resolve));
@@ -297,10 +300,21 @@ export class WorkerShardingStrategy implements IShardingStrategy {
} }
case WorkerReceivePayloadOp.WaitForIdentify: { case WorkerReceivePayloadOp.WaitForIdentify: {
await this.throttler.waitForIdentify(); const throttler = await this.ensureThrottler();
// If this rejects it means we aborted, in which case we reply elsewhere.
try {
const controller = new AbortController();
this.waitForIdentifyControllers.set(payload.nonce, controller);
await throttler.waitForIdentify(payload.shardId, controller.signal);
} catch {
return;
}
const response: WorkerSendPayload = { const response: WorkerSendPayload = {
op: WorkerSendPayloadOp.ShardCanIdentify, op: WorkerSendPayloadOp.ShardIdentifyResponse,
nonce: payload.nonce, nonce: payload.nonce,
ok: true,
}; };
worker.postMessage(response); worker.postMessage(response);
break; break;
@@ -315,6 +329,25 @@ export class WorkerShardingStrategy implements IShardingStrategy {
case WorkerReceivePayloadOp.WorkerReady: { case WorkerReceivePayloadOp.WorkerReady: {
break; break;
} }
case WorkerReceivePayloadOp.CancelIdentify: {
this.waitForIdentifyControllers.get(payload.nonce)?.abort();
this.waitForIdentifyControllers.delete(payload.nonce);
const response: WorkerSendPayload = {
op: WorkerSendPayloadOp.ShardIdentifyResponse,
nonce: payload.nonce,
ok: false,
};
worker.postMessage(response);
break;
}
} }
} }
private async ensureThrottler(): Promise<IIdentifyThrottler> {
this.throttler ??= await this.manager.options.buildIdentifyThrottler(this.manager);
return this.throttler;
}
} }

View File

@@ -0,0 +1,11 @@
/**
* IdentifyThrottlers are responsible for dictating when a shard is allowed to identify.
*
* @see {@link https://discord.com/developers/docs/topics/gateway#sharding-max-concurrency}
*/
export interface IIdentifyThrottler {
/**
* Resolves once the given shard should be allowed to identify, or rejects if the operation was aborted.
*/
waitForIdentify(shardId: number, signal: AbortSignal): Promise<void>;
}

View File

@@ -0,0 +1,50 @@
import { setTimeout as sleep } from 'node:timers/promises';
import { Collection } from '@discordjs/collection';
import { AsyncQueue } from '@sapphire/async-queue';
import type { IIdentifyThrottler } from './IIdentifyThrottler';
/**
* The state of a rate limit key's identify queue.
*/
export interface IdentifyState {
queue: AsyncQueue;
resetsAt: number;
}
/**
* Local, in-memory identify throttler.
*/
export class SimpleIdentifyThrottler implements IIdentifyThrottler {
private readonly states = new Collection<number, IdentifyState>();
public constructor(private readonly maxConcurrency: number) {}
/**
* {@inheritDoc IIdentifyThrottler.waitForIdentify}
*/
public async waitForIdentify(shardId: number, signal: AbortSignal): Promise<void> {
const key = shardId % this.maxConcurrency;
const state = this.states.ensure(key, () => {
return {
queue: new AsyncQueue(),
resetsAt: Number.POSITIVE_INFINITY,
};
});
await state.queue.wait({ signal });
try {
const diff = state.resetsAt - Date.now();
if (diff <= 5_000) {
// To account for the latency the IDENTIFY payload goes through, we add a bit more wait time
const time = diff + Math.random() * 1_500;
await sleep(time);
}
state.resetsAt = Date.now() + 5_000;
} finally {
state.queue.shift();
}
}
}

View File

@@ -1,39 +0,0 @@
import { setTimeout as sleep } from 'node:timers/promises';
import { AsyncQueue } from '@sapphire/async-queue';
import type { WebSocketManager } from '../ws/WebSocketManager.js';
export class IdentifyThrottler {
private readonly queue = new AsyncQueue();
private identifyState = {
remaining: 0,
resetsAt: Number.POSITIVE_INFINITY,
};
public constructor(private readonly manager: WebSocketManager) {}
public async waitForIdentify(): Promise<void> {
await this.queue.wait();
try {
if (this.identifyState.remaining <= 0) {
const diff = this.identifyState.resetsAt - Date.now();
if (diff <= 5_000) {
// To account for the latency the IDENTIFY payload goes through, we add a bit more wait time
const time = diff + Math.random() * 1_500;
await sleep(time);
}
const info = await this.manager.fetchGatewayInformation();
this.identifyState = {
remaining: info.session_start_limit.max_concurrency,
resetsAt: Date.now() + 5_000,
};
}
this.identifyState.remaining--;
} finally {
this.queue.shift();
}
}
}

View File

@@ -117,7 +117,7 @@ export class WorkerBootstrapper {
break; break;
} }
case WorkerSendPayloadOp.ShardCanIdentify: { case WorkerSendPayloadOp.ShardIdentifyResponse: {
break; break;
} }
@@ -127,11 +127,11 @@ export class WorkerBootstrapper {
throw new Error(`Shard ${payload.shardId} does not exist`); throw new Error(`Shard ${payload.shardId} does not exist`);
} }
const response = { const response: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.FetchStatusResponse, op: WorkerReceivePayloadOp.FetchStatusResponse,
status: shard.status, status: shard.status,
nonce: payload.nonce, nonce: payload.nonce,
} satisfies WorkerReceivePayload; };
parentPort!.postMessage(response); parentPort!.postMessage(response);
break; break;
@@ -150,12 +150,12 @@ export class WorkerBootstrapper {
for (const event of options.forwardEvents ?? Object.values(WebSocketShardEvents)) { for (const event of options.forwardEvents ?? Object.values(WebSocketShardEvents)) {
// @ts-expect-error: Event types incompatible // @ts-expect-error: Event types incompatible
shard.on(event, (data) => { shard.on(event, (data) => {
const payload = { const payload: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.Event, op: WorkerReceivePayloadOp.Event,
event, event,
data, data,
shardId, shardId,
} satisfies WorkerReceivePayload; };
parentPort!.postMessage(payload); parentPort!.postMessage(payload);
}); });
} }
@@ -168,9 +168,9 @@ export class WorkerBootstrapper {
// Lastly, start listening to messages from the parent thread // Lastly, start listening to messages from the parent thread
this.setupThreadEvents(); this.setupThreadEvents();
const message = { const message: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.WorkerReady, op: WorkerReceivePayloadOp.WorkerReady,
} satisfies WorkerReceivePayload; };
parentPort!.postMessage(message); parentPort!.postMessage(message);
} }
} }

View File

@@ -3,7 +3,8 @@ import { Collection } from '@discordjs/collection';
import { lazy } from '@discordjs/util'; import { lazy } from '@discordjs/util';
import { APIVersion, GatewayOpcodes } from 'discord-api-types/v10'; import { APIVersion, GatewayOpcodes } from 'discord-api-types/v10';
import { SimpleShardingStrategy } from '../strategies/sharding/SimpleShardingStrategy.js'; import { SimpleShardingStrategy } from '../strategies/sharding/SimpleShardingStrategy.js';
import type { SessionInfo, OptionalWebSocketManagerOptions } from '../ws/WebSocketManager.js'; import { SimpleIdentifyThrottler } from '../throttling/SimpleIdentifyThrottler.js';
import type { SessionInfo, OptionalWebSocketManagerOptions, WebSocketManager } from '../ws/WebSocketManager.js';
import type { SendRateLimitState } from '../ws/WebSocketShard.js'; import type { SendRateLimitState } from '../ws/WebSocketShard.js';
/** /**
@@ -28,6 +29,10 @@ const getDefaultSessionStore = lazy(() => new Collection<number, SessionInfo | n
* Default options used by the manager * Default options used by the manager
*/ */
export const DefaultWebSocketManagerOptions = { export const DefaultWebSocketManagerOptions = {
async buildIdentifyThrottler(manager: WebSocketManager) {
const info = await manager.fetchGatewayInformation();
return new SimpleIdentifyThrottler(info.session_start_limit.max_concurrency);
},
buildStrategy: (manager) => new SimpleShardingStrategy(manager), buildStrategy: (manager) => new SimpleShardingStrategy(manager),
shardCount: null, shardCount: null,
shardIds: null, shardIds: null,

View File

@@ -11,6 +11,7 @@ import {
type GatewaySendPayload, type GatewaySendPayload,
} from 'discord-api-types/v10'; } from 'discord-api-types/v10';
import type { IShardingStrategy } from '../strategies/sharding/IShardingStrategy'; import type { IShardingStrategy } from '../strategies/sharding/IShardingStrategy';
import type { IIdentifyThrottler } from '../throttling/IIdentifyThrottler';
import { DefaultWebSocketManagerOptions, type CompressionMethod, type Encoding } from '../utils/constants.js'; import { DefaultWebSocketManagerOptions, type CompressionMethod, type Encoding } from '../utils/constants.js';
import type { WebSocketShardDestroyOptions, WebSocketShardEventsMap } from './WebSocketShard.js'; import type { WebSocketShardDestroyOptions, WebSocketShardEventsMap } from './WebSocketShard.js';
@@ -55,7 +56,7 @@ export interface RequiredWebSocketManagerOptions {
/** /**
* The intents to request * The intents to request
*/ */
intents: GatewayIntentBits; intents: GatewayIntentBits | 0;
/** /**
* The REST instance to use for fetching gateway information * The REST instance to use for fetching gateway information
*/ */
@@ -70,6 +71,10 @@ export interface RequiredWebSocketManagerOptions {
* Optional additional configuration for the WebSocketManager * Optional additional configuration for the WebSocketManager
*/ */
export interface OptionalWebSocketManagerOptions { export interface OptionalWebSocketManagerOptions {
/**
* Builds an identify throttler to use for this manager's shards
*/
buildIdentifyThrottler(manager: WebSocketManager): Awaitable<IIdentifyThrottler>;
/** /**
* Builds the strategy to use for sharding * Builds the strategy to use for sharding
* *

View File

@@ -358,7 +358,21 @@ export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {
private async identify() { private async identify() {
this.debug(['Waiting for identify throttle']); this.debug(['Waiting for identify throttle']);
await this.strategy.waitForIdentify(); const controller = new AbortController();
const closeHandler = () => {
controller.abort();
};
this.on(WebSocketShardEvents.Closed, closeHandler);
try {
await this.strategy.waitForIdentify(this.id, controller.signal);
} catch {
this.debug(['Was waiting for an identify, but the shard closed in the meantime']);
return;
} finally {
this.off(WebSocketShardEvents.Closed, closeHandler);
}
this.debug([ this.debug([
'Identifying', 'Identifying',