mirror of
https://github.com/discordjs/discord.js.git
synced 2026-03-10 00:23:30 +01:00
refactor: abstract identify throttling and correct max_concurrency handling (#9375)
* refactor: properly support max_concurrency ratelimit keys * fix: properly block for same key * chore: export session state * chore: throttler no longer requires manager * refactor: abstract throttlers * chore: proper member order * chore: remove leftover debug log * chore: use @link tag in doc comment Co-authored-by: Jiralite <33201955+Jiralite@users.noreply.github.com> * chore: suggested changes * fix(WebSocketShard): cancel identify if the shard closed in the meantime * refactor(throttlers): support abort signals * fix: memory leak * chore: remove leftover --------- Co-authored-by: Jiralite <33201955+Jiralite@users.noreply.github.com> Co-authored-by: kodiakhq[bot] <49736102+kodiakhq[bot]@users.noreply.github.com>
This commit is contained in:
@@ -57,9 +57,9 @@ vi.mock('node:worker_threads', async () => {
|
||||
this.emit('online');
|
||||
// same deal here
|
||||
setImmediate(() => {
|
||||
const message = {
|
||||
const message: WorkerReceivePayload = {
|
||||
op: WorkerReceivePayloadOp.WorkerReady,
|
||||
} satisfies WorkerReceivePayload;
|
||||
};
|
||||
this.emit('message', message);
|
||||
});
|
||||
});
|
||||
@@ -68,39 +68,39 @@ vi.mock('node:worker_threads', async () => {
|
||||
public postMessage(message: WorkerSendPayload) {
|
||||
switch (message.op) {
|
||||
case WorkerSendPayloadOp.Connect: {
|
||||
const response = {
|
||||
const response: WorkerReceivePayload = {
|
||||
op: WorkerReceivePayloadOp.Connected,
|
||||
shardId: message.shardId,
|
||||
} satisfies WorkerReceivePayload;
|
||||
};
|
||||
this.emit('message', response);
|
||||
break;
|
||||
}
|
||||
|
||||
case WorkerSendPayloadOp.Destroy: {
|
||||
const response = {
|
||||
const response: WorkerReceivePayload = {
|
||||
op: WorkerReceivePayloadOp.Destroyed,
|
||||
shardId: message.shardId,
|
||||
} satisfies WorkerReceivePayload;
|
||||
};
|
||||
this.emit('message', response);
|
||||
break;
|
||||
}
|
||||
|
||||
case WorkerSendPayloadOp.Send: {
|
||||
if (message.payload.op === GatewayOpcodes.RequestGuildMembers) {
|
||||
const response = {
|
||||
const response: WorkerReceivePayload = {
|
||||
op: WorkerReceivePayloadOp.Event,
|
||||
shardId: message.shardId,
|
||||
event: WebSocketShardEvents.Dispatch,
|
||||
data: memberChunkData,
|
||||
} satisfies WorkerReceivePayload;
|
||||
};
|
||||
this.emit('message', response);
|
||||
|
||||
// Fetch session info
|
||||
const sessionFetch = {
|
||||
const sessionFetch: WorkerReceivePayload = {
|
||||
op: WorkerReceivePayloadOp.RetrieveSessionInfo,
|
||||
shardId: message.shardId,
|
||||
nonce: Math.random(),
|
||||
} satisfies WorkerReceivePayload;
|
||||
};
|
||||
this.emit('message', sessionFetch);
|
||||
}
|
||||
|
||||
@@ -111,16 +111,16 @@ vi.mock('node:worker_threads', async () => {
|
||||
case WorkerSendPayloadOp.SessionInfoResponse: {
|
||||
message.session ??= sessionInfo;
|
||||
|
||||
const session = {
|
||||
const session: WorkerReceivePayload = {
|
||||
op: WorkerReceivePayloadOp.UpdateSessionInfo,
|
||||
shardId: message.session.shardId,
|
||||
session: { ...message.session, sequence: message.session.sequence + 1 },
|
||||
} satisfies WorkerReceivePayload;
|
||||
};
|
||||
this.emit('message', session);
|
||||
break;
|
||||
}
|
||||
|
||||
case WorkerSendPayloadOp.ShardCanIdentify: {
|
||||
case WorkerSendPayloadOp.ShardIdentifyResponse: {
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -198,10 +198,10 @@ test('spawn, connect, send a message, session info, and destroy', async () => {
|
||||
expect.objectContaining({ workerData: expect.objectContaining({ shardIds: [0, 1] }) }),
|
||||
);
|
||||
|
||||
const payload = {
|
||||
const payload: GatewaySendPayload = {
|
||||
op: GatewayOpcodes.RequestGuildMembers,
|
||||
d: { guild_id: '123', limit: 0, query: '' },
|
||||
} satisfies GatewaySendPayload;
|
||||
};
|
||||
await manager.send(0, payload);
|
||||
expect(mockSend).toHaveBeenCalledWith(0, payload);
|
||||
expect(managerEmitSpy).toHaveBeenCalledWith(WebSocketShardEvents.Dispatch, {
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
import { setTimeout as sleep } from 'node:timers/promises';
|
||||
import { expect, test, vi, type Mock } from 'vitest';
|
||||
import { IdentifyThrottler, type WebSocketManager } from '../../src/index.js';
|
||||
|
||||
vi.mock('node:timers/promises', () => ({
|
||||
setTimeout: vi.fn(),
|
||||
}));
|
||||
|
||||
const fetchGatewayInformation = vi.fn();
|
||||
|
||||
const manager = {
|
||||
fetchGatewayInformation,
|
||||
} as unknown as WebSocketManager;
|
||||
|
||||
const throttler = new IdentifyThrottler(manager);
|
||||
|
||||
vi.useFakeTimers();
|
||||
|
||||
const NOW = vi.fn().mockReturnValue(Date.now());
|
||||
global.Date.now = NOW;
|
||||
|
||||
test('wait for identify', async () => {
|
||||
fetchGatewayInformation.mockReturnValue({
|
||||
session_start_limit: {
|
||||
max_concurrency: 2,
|
||||
},
|
||||
});
|
||||
|
||||
// First call should never wait
|
||||
await throttler.waitForIdentify();
|
||||
expect(sleep).not.toHaveBeenCalled();
|
||||
|
||||
// Second call still won't wait because max_concurrency is 2
|
||||
await throttler.waitForIdentify();
|
||||
expect(sleep).not.toHaveBeenCalled();
|
||||
|
||||
// Third call should wait
|
||||
await throttler.waitForIdentify();
|
||||
expect(sleep).toHaveBeenCalled();
|
||||
|
||||
(sleep as Mock).mockRestore();
|
||||
|
||||
// Fourth call shouldn't wait, because our max_concurrency is 2 and we waited for a reset
|
||||
await throttler.waitForIdentify();
|
||||
expect(sleep).not.toHaveBeenCalled();
|
||||
});
|
||||
32
packages/ws/__tests__/util/SimpleIdentifyThrottler.test.ts
Normal file
32
packages/ws/__tests__/util/SimpleIdentifyThrottler.test.ts
Normal file
@@ -0,0 +1,32 @@
|
||||
import { setTimeout as sleep } from 'node:timers/promises';
|
||||
import { expect, test, vi, type Mock } from 'vitest';
|
||||
import { SimpleIdentifyThrottler } from '../../src/index.js';
|
||||
|
||||
vi.mock('node:timers/promises', () => ({
|
||||
setTimeout: vi.fn(),
|
||||
}));
|
||||
|
||||
const throttler = new SimpleIdentifyThrottler(2);
|
||||
|
||||
vi.useFakeTimers();
|
||||
|
||||
const NOW = vi.fn().mockReturnValue(Date.now());
|
||||
global.Date.now = NOW;
|
||||
|
||||
test('basic case', async () => {
|
||||
// Those shouldn't wait since they're in different keys
|
||||
|
||||
await throttler.waitForIdentify(0);
|
||||
expect(sleep).not.toHaveBeenCalled();
|
||||
|
||||
await throttler.waitForIdentify(1);
|
||||
expect(sleep).not.toHaveBeenCalled();
|
||||
|
||||
// Those should wait
|
||||
|
||||
await throttler.waitForIdentify(2);
|
||||
expect(sleep).toHaveBeenCalledTimes(1);
|
||||
|
||||
await throttler.waitForIdentify(3);
|
||||
expect(sleep).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
@@ -6,8 +6,10 @@ export * from './strategies/sharding/IShardingStrategy.js';
|
||||
export * from './strategies/sharding/SimpleShardingStrategy.js';
|
||||
export * from './strategies/sharding/WorkerShardingStrategy.js';
|
||||
|
||||
export * from './throttling/IIdentifyThrottler.js';
|
||||
export * from './throttling/SimpleIdentifyThrottler.js';
|
||||
|
||||
export * from './utils/constants.js';
|
||||
export * from './utils/IdentifyThrottler.js';
|
||||
export * from './utils/WorkerBootstrapper.js';
|
||||
|
||||
export * from './ws/WebSocketManager.js';
|
||||
|
||||
@@ -5,7 +5,13 @@ import type { SessionInfo, WebSocketManager, WebSocketManagerOptions } from '../
|
||||
export interface FetchingStrategyOptions
|
||||
extends Omit<
|
||||
WebSocketManagerOptions,
|
||||
'buildStrategy' | 'rest' | 'retrieveSessionInfo' | 'shardCount' | 'shardIds' | 'updateSessionInfo'
|
||||
| 'buildIdentifyThrottler'
|
||||
| 'buildStrategy'
|
||||
| 'rest'
|
||||
| 'retrieveSessionInfo'
|
||||
| 'shardCount'
|
||||
| 'shardIds'
|
||||
| 'updateSessionInfo'
|
||||
> {
|
||||
readonly gatewayInformation: APIGatewayBotInfo;
|
||||
readonly shardCount: number;
|
||||
@@ -18,13 +24,25 @@ export interface IContextFetchingStrategy {
|
||||
readonly options: FetchingStrategyOptions;
|
||||
retrieveSessionInfo(shardId: number): Awaitable<SessionInfo | null>;
|
||||
updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null): Awaitable<void>;
|
||||
waitForIdentify(): Promise<void>;
|
||||
/**
|
||||
* Resolves once the given shard should be allowed to identify, or rejects if the operation was aborted
|
||||
*/
|
||||
waitForIdentify(shardId: number, signal: AbortSignal): Promise<void>;
|
||||
}
|
||||
|
||||
export async function managerToFetchingStrategyOptions(manager: WebSocketManager): Promise<FetchingStrategyOptions> {
|
||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
const { buildStrategy, retrieveSessionInfo, updateSessionInfo, shardCount, shardIds, rest, ...managerOptions } =
|
||||
manager.options;
|
||||
/* eslint-disable @typescript-eslint/unbound-method */
|
||||
const {
|
||||
buildIdentifyThrottler,
|
||||
buildStrategy,
|
||||
retrieveSessionInfo,
|
||||
updateSessionInfo,
|
||||
shardCount,
|
||||
shardIds,
|
||||
rest,
|
||||
...managerOptions
|
||||
} = manager.options;
|
||||
/* eslint-enable @typescript-eslint/unbound-method */
|
||||
|
||||
return {
|
||||
...managerOptions,
|
||||
|
||||
@@ -1,28 +1,25 @@
|
||||
import { IdentifyThrottler } from '../../utils/IdentifyThrottler.js';
|
||||
import type { IIdentifyThrottler } from '../../throttling/IIdentifyThrottler.js';
|
||||
import type { SessionInfo, WebSocketManager } from '../../ws/WebSocketManager.js';
|
||||
import type { FetchingStrategyOptions, IContextFetchingStrategy } from './IContextFetchingStrategy.js';
|
||||
|
||||
export class SimpleContextFetchingStrategy implements IContextFetchingStrategy {
|
||||
// This strategy assumes every shard is running under the same process - therefore we need a single
|
||||
// IdentifyThrottler per manager.
|
||||
private static throttlerCache = new WeakMap<WebSocketManager, IdentifyThrottler>();
|
||||
private static throttlerCache = new WeakMap<WebSocketManager, IIdentifyThrottler>();
|
||||
|
||||
private static ensureThrottler(manager: WebSocketManager): IdentifyThrottler {
|
||||
const existing = SimpleContextFetchingStrategy.throttlerCache.get(manager);
|
||||
if (existing) {
|
||||
return existing;
|
||||
private static async ensureThrottler(manager: WebSocketManager): Promise<IIdentifyThrottler> {
|
||||
const throttler = SimpleContextFetchingStrategy.throttlerCache.get(manager);
|
||||
if (throttler) {
|
||||
return throttler;
|
||||
}
|
||||
|
||||
const throttler = new IdentifyThrottler(manager);
|
||||
SimpleContextFetchingStrategy.throttlerCache.set(manager, throttler);
|
||||
return throttler;
|
||||
const newThrottler = await manager.options.buildIdentifyThrottler(manager);
|
||||
SimpleContextFetchingStrategy.throttlerCache.set(manager, newThrottler);
|
||||
|
||||
return newThrottler;
|
||||
}
|
||||
|
||||
private readonly throttler: IdentifyThrottler;
|
||||
|
||||
public constructor(private readonly manager: WebSocketManager, public readonly options: FetchingStrategyOptions) {
|
||||
this.throttler = SimpleContextFetchingStrategy.ensureThrottler(manager);
|
||||
}
|
||||
public constructor(private readonly manager: WebSocketManager, public readonly options: FetchingStrategyOptions) {}
|
||||
|
||||
public async retrieveSessionInfo(shardId: number): Promise<SessionInfo | null> {
|
||||
return this.manager.options.retrieveSessionInfo(shardId);
|
||||
@@ -32,7 +29,8 @@ export class SimpleContextFetchingStrategy implements IContextFetchingStrategy {
|
||||
return this.manager.options.updateSessionInfo(shardId, sessionInfo);
|
||||
}
|
||||
|
||||
public async waitForIdentify(): Promise<void> {
|
||||
await this.throttler.waitForIdentify();
|
||||
public async waitForIdentify(shardId: number, signal: AbortSignal): Promise<void> {
|
||||
const throttler = await SimpleContextFetchingStrategy.ensureThrottler(this.manager);
|
||||
await throttler.waitForIdentify(shardId, signal);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,10 +9,17 @@ import {
|
||||
} from '../sharding/WorkerShardingStrategy.js';
|
||||
import type { FetchingStrategyOptions, IContextFetchingStrategy } from './IContextFetchingStrategy.js';
|
||||
|
||||
// Because the global types are incomplete for whatever reason
|
||||
interface PolyFillAbortSignal {
|
||||
readonly aborted: boolean;
|
||||
addEventListener(type: 'abort', listener: () => void): void;
|
||||
removeEventListener(type: 'abort', listener: () => void): void;
|
||||
}
|
||||
|
||||
export class WorkerContextFetchingStrategy implements IContextFetchingStrategy {
|
||||
private readonly sessionPromises = new Collection<number, (session: SessionInfo | null) => void>();
|
||||
|
||||
private readonly waitForIdentifyPromises = new Collection<number, () => void>();
|
||||
private readonly waitForIdentifyPromises = new Collection<number, { reject(): void; resolve(): void }>();
|
||||
|
||||
public constructor(public readonly options: FetchingStrategyOptions) {
|
||||
if (isMainThread) {
|
||||
@@ -25,8 +32,14 @@ export class WorkerContextFetchingStrategy implements IContextFetchingStrategy {
|
||||
this.sessionPromises.delete(payload.nonce);
|
||||
}
|
||||
|
||||
if (payload.op === WorkerSendPayloadOp.ShardCanIdentify) {
|
||||
this.waitForIdentifyPromises.get(payload.nonce)?.();
|
||||
if (payload.op === WorkerSendPayloadOp.ShardIdentifyResponse) {
|
||||
const promise = this.waitForIdentifyPromises.get(payload.nonce);
|
||||
if (payload.ok) {
|
||||
promise?.resolve();
|
||||
} else {
|
||||
promise?.reject();
|
||||
}
|
||||
|
||||
this.waitForIdentifyPromises.delete(payload.nonce);
|
||||
}
|
||||
});
|
||||
@@ -34,11 +47,11 @@ export class WorkerContextFetchingStrategy implements IContextFetchingStrategy {
|
||||
|
||||
public async retrieveSessionInfo(shardId: number): Promise<SessionInfo | null> {
|
||||
const nonce = Math.random();
|
||||
const payload = {
|
||||
const payload: WorkerReceivePayload = {
|
||||
op: WorkerReceivePayloadOp.RetrieveSessionInfo,
|
||||
shardId,
|
||||
nonce,
|
||||
} satisfies WorkerReceivePayload;
|
||||
};
|
||||
// eslint-disable-next-line no-promise-executor-return
|
||||
const promise = new Promise<SessionInfo | null>((resolve) => this.sessionPromises.set(nonce, resolve));
|
||||
parentPort!.postMessage(payload);
|
||||
@@ -46,23 +59,44 @@ export class WorkerContextFetchingStrategy implements IContextFetchingStrategy {
|
||||
}
|
||||
|
||||
public updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null) {
|
||||
const payload = {
|
||||
const payload: WorkerReceivePayload = {
|
||||
op: WorkerReceivePayloadOp.UpdateSessionInfo,
|
||||
shardId,
|
||||
session: sessionInfo,
|
||||
} satisfies WorkerReceivePayload;
|
||||
};
|
||||
parentPort!.postMessage(payload);
|
||||
}
|
||||
|
||||
public async waitForIdentify(): Promise<void> {
|
||||
public async waitForIdentify(shardId: number, signal: AbortSignal): Promise<void> {
|
||||
const nonce = Math.random();
|
||||
const payload = {
|
||||
|
||||
const payload: WorkerReceivePayload = {
|
||||
op: WorkerReceivePayloadOp.WaitForIdentify,
|
||||
nonce,
|
||||
} satisfies WorkerReceivePayload;
|
||||
// eslint-disable-next-line no-promise-executor-return
|
||||
const promise = new Promise<void>((resolve) => this.waitForIdentifyPromises.set(nonce, resolve));
|
||||
shardId,
|
||||
};
|
||||
const promise = new Promise<void>((resolve, reject) =>
|
||||
// eslint-disable-next-line no-promise-executor-return
|
||||
this.waitForIdentifyPromises.set(nonce, { resolve, reject }),
|
||||
);
|
||||
|
||||
parentPort!.postMessage(payload);
|
||||
return promise;
|
||||
|
||||
const listener = () => {
|
||||
const payload: WorkerReceivePayload = {
|
||||
op: WorkerReceivePayloadOp.CancelIdentify,
|
||||
nonce,
|
||||
};
|
||||
|
||||
parentPort!.postMessage(payload);
|
||||
};
|
||||
|
||||
(signal as unknown as PolyFillAbortSignal).addEventListener('abort', listener);
|
||||
|
||||
try {
|
||||
await promise;
|
||||
} finally {
|
||||
(signal as unknown as PolyFillAbortSignal).removeEventListener('abort', listener);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ export class SimpleShardingStrategy implements IShardingStrategy {
|
||||
*/
|
||||
public async spawn(shardIds: number[]) {
|
||||
const strategyOptions = await managerToFetchingStrategyOptions(this.manager);
|
||||
|
||||
for (const shardId of shardIds) {
|
||||
const strategy = new SimpleContextFetchingStrategy(this.manager, strategyOptions);
|
||||
const shard = new WebSocketShard(strategy, shardId);
|
||||
|
||||
@@ -3,7 +3,7 @@ import { join, isAbsolute, resolve } from 'node:path';
|
||||
import { Worker } from 'node:worker_threads';
|
||||
import { Collection } from '@discordjs/collection';
|
||||
import type { GatewaySendPayload } from 'discord-api-types/v10';
|
||||
import { IdentifyThrottler } from '../../utils/IdentifyThrottler.js';
|
||||
import type { IIdentifyThrottler } from '../../throttling/IIdentifyThrottler';
|
||||
import type { SessionInfo, WebSocketManager } from '../../ws/WebSocketManager';
|
||||
import type { WebSocketShardDestroyOptions, WebSocketShardEvents, WebSocketShardStatus } from '../../ws/WebSocketShard';
|
||||
import { managerToFetchingStrategyOptions, type FetchingStrategyOptions } from '../context/IContextFetchingStrategy.js';
|
||||
@@ -18,14 +18,14 @@ export enum WorkerSendPayloadOp {
|
||||
Destroy,
|
||||
Send,
|
||||
SessionInfoResponse,
|
||||
ShardCanIdentify,
|
||||
ShardIdentifyResponse,
|
||||
FetchStatus,
|
||||
}
|
||||
|
||||
export type WorkerSendPayload =
|
||||
| { nonce: number; ok: boolean; op: WorkerSendPayloadOp.ShardIdentifyResponse }
|
||||
| { nonce: number; op: WorkerSendPayloadOp.FetchStatus; shardId: number }
|
||||
| { nonce: number; op: WorkerSendPayloadOp.SessionInfoResponse; session: SessionInfo | null }
|
||||
| { nonce: number; op: WorkerSendPayloadOp.ShardCanIdentify }
|
||||
| { op: WorkerSendPayloadOp.Connect; shardId: number }
|
||||
| { op: WorkerSendPayloadOp.Destroy; options?: WebSocketShardDestroyOptions; shardId: number }
|
||||
| { op: WorkerSendPayloadOp.Send; payload: GatewaySendPayload; shardId: number };
|
||||
@@ -39,14 +39,16 @@ export enum WorkerReceivePayloadOp {
|
||||
WaitForIdentify,
|
||||
FetchStatusResponse,
|
||||
WorkerReady,
|
||||
CancelIdentify,
|
||||
}
|
||||
|
||||
export type WorkerReceivePayload =
|
||||
// Can't seem to get a type-safe union based off of the event, so I'm sadly leaving data as any for now
|
||||
| { data: any; event: WebSocketShardEvents; op: WorkerReceivePayloadOp.Event; shardId: number }
|
||||
| { nonce: number; op: WorkerReceivePayloadOp.CancelIdentify }
|
||||
| { nonce: number; op: WorkerReceivePayloadOp.FetchStatusResponse; status: WebSocketShardStatus }
|
||||
| { nonce: number; op: WorkerReceivePayloadOp.RetrieveSessionInfo; shardId: number }
|
||||
| { nonce: number; op: WorkerReceivePayloadOp.WaitForIdentify }
|
||||
| { nonce: number; op: WorkerReceivePayloadOp.WaitForIdentify; shardId: number }
|
||||
| { op: WorkerReceivePayloadOp.Connected; shardId: number }
|
||||
| { op: WorkerReceivePayloadOp.Destroyed; shardId: number }
|
||||
| { op: WorkerReceivePayloadOp.UpdateSessionInfo; session: SessionInfo | null; shardId: number }
|
||||
@@ -84,11 +86,12 @@ export class WorkerShardingStrategy implements IShardingStrategy {
|
||||
|
||||
private readonly fetchStatusPromises = new Collection<number, (status: WebSocketShardStatus) => void>();
|
||||
|
||||
private readonly throttler: IdentifyThrottler;
|
||||
private readonly waitForIdentifyControllers = new Collection<number, AbortController>();
|
||||
|
||||
private throttler?: IIdentifyThrottler;
|
||||
|
||||
public constructor(manager: WebSocketManager, options: WorkerShardingStrategyOptions) {
|
||||
this.manager = manager;
|
||||
this.throttler = new IdentifyThrottler(manager);
|
||||
this.options = options;
|
||||
}
|
||||
|
||||
@@ -122,10 +125,10 @@ export class WorkerShardingStrategy implements IShardingStrategy {
|
||||
const promises = [];
|
||||
|
||||
for (const [shardId, worker] of this.#workerByShardId.entries()) {
|
||||
const payload = {
|
||||
const payload: WorkerSendPayload = {
|
||||
op: WorkerSendPayloadOp.Connect,
|
||||
shardId,
|
||||
} satisfies WorkerSendPayload;
|
||||
};
|
||||
|
||||
// eslint-disable-next-line no-promise-executor-return
|
||||
const promise = new Promise<void>((resolve) => this.connectPromises.set(shardId, resolve));
|
||||
@@ -143,11 +146,11 @@ export class WorkerShardingStrategy implements IShardingStrategy {
|
||||
const promises = [];
|
||||
|
||||
for (const [shardId, worker] of this.#workerByShardId.entries()) {
|
||||
const payload = {
|
||||
const payload: WorkerSendPayload = {
|
||||
op: WorkerSendPayloadOp.Destroy,
|
||||
shardId,
|
||||
options,
|
||||
} satisfies WorkerSendPayload;
|
||||
};
|
||||
|
||||
promises.push(
|
||||
// eslint-disable-next-line no-promise-executor-return, promise/prefer-await-to-then
|
||||
@@ -171,11 +174,11 @@ export class WorkerShardingStrategy implements IShardingStrategy {
|
||||
throw new Error(`No worker found for shard ${shardId}`);
|
||||
}
|
||||
|
||||
const payload = {
|
||||
const payload: WorkerSendPayload = {
|
||||
op: WorkerSendPayloadOp.Send,
|
||||
shardId,
|
||||
payload: data,
|
||||
} satisfies WorkerSendPayload;
|
||||
};
|
||||
worker.postMessage(payload);
|
||||
}
|
||||
|
||||
@@ -187,11 +190,11 @@ export class WorkerShardingStrategy implements IShardingStrategy {
|
||||
|
||||
for (const [shardId, worker] of this.#workerByShardId.entries()) {
|
||||
const nonce = Math.random();
|
||||
const payload = {
|
||||
const payload: WorkerSendPayload = {
|
||||
op: WorkerSendPayloadOp.FetchStatus,
|
||||
shardId,
|
||||
nonce,
|
||||
} satisfies WorkerSendPayload;
|
||||
};
|
||||
|
||||
// eslint-disable-next-line no-promise-executor-return
|
||||
const promise = new Promise<WebSocketShardStatus>((resolve) => this.fetchStatusPromises.set(nonce, resolve));
|
||||
@@ -297,10 +300,21 @@ export class WorkerShardingStrategy implements IShardingStrategy {
|
||||
}
|
||||
|
||||
case WorkerReceivePayloadOp.WaitForIdentify: {
|
||||
await this.throttler.waitForIdentify();
|
||||
const throttler = await this.ensureThrottler();
|
||||
|
||||
// If this rejects it means we aborted, in which case we reply elsewhere.
|
||||
try {
|
||||
const controller = new AbortController();
|
||||
this.waitForIdentifyControllers.set(payload.nonce, controller);
|
||||
await throttler.waitForIdentify(payload.shardId, controller.signal);
|
||||
} catch {
|
||||
return;
|
||||
}
|
||||
|
||||
const response: WorkerSendPayload = {
|
||||
op: WorkerSendPayloadOp.ShardCanIdentify,
|
||||
op: WorkerSendPayloadOp.ShardIdentifyResponse,
|
||||
nonce: payload.nonce,
|
||||
ok: true,
|
||||
};
|
||||
worker.postMessage(response);
|
||||
break;
|
||||
@@ -315,6 +329,25 @@ export class WorkerShardingStrategy implements IShardingStrategy {
|
||||
case WorkerReceivePayloadOp.WorkerReady: {
|
||||
break;
|
||||
}
|
||||
|
||||
case WorkerReceivePayloadOp.CancelIdentify: {
|
||||
this.waitForIdentifyControllers.get(payload.nonce)?.abort();
|
||||
this.waitForIdentifyControllers.delete(payload.nonce);
|
||||
|
||||
const response: WorkerSendPayload = {
|
||||
op: WorkerSendPayloadOp.ShardIdentifyResponse,
|
||||
nonce: payload.nonce,
|
||||
ok: false,
|
||||
};
|
||||
worker.postMessage(response);
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async ensureThrottler(): Promise<IIdentifyThrottler> {
|
||||
this.throttler ??= await this.manager.options.buildIdentifyThrottler(this.manager);
|
||||
return this.throttler;
|
||||
}
|
||||
}
|
||||
|
||||
11
packages/ws/src/throttling/IIdentifyThrottler.ts
Normal file
11
packages/ws/src/throttling/IIdentifyThrottler.ts
Normal file
@@ -0,0 +1,11 @@
|
||||
/**
|
||||
* IdentifyThrottlers are responsible for dictating when a shard is allowed to identify.
|
||||
*
|
||||
* @see {@link https://discord.com/developers/docs/topics/gateway#sharding-max-concurrency}
|
||||
*/
|
||||
export interface IIdentifyThrottler {
|
||||
/**
|
||||
* Resolves once the given shard should be allowed to identify, or rejects if the operation was aborted.
|
||||
*/
|
||||
waitForIdentify(shardId: number, signal: AbortSignal): Promise<void>;
|
||||
}
|
||||
50
packages/ws/src/throttling/SimpleIdentifyThrottler.ts
Normal file
50
packages/ws/src/throttling/SimpleIdentifyThrottler.ts
Normal file
@@ -0,0 +1,50 @@
|
||||
import { setTimeout as sleep } from 'node:timers/promises';
|
||||
import { Collection } from '@discordjs/collection';
|
||||
import { AsyncQueue } from '@sapphire/async-queue';
|
||||
import type { IIdentifyThrottler } from './IIdentifyThrottler';
|
||||
|
||||
/**
|
||||
* The state of a rate limit key's identify queue.
|
||||
*/
|
||||
export interface IdentifyState {
|
||||
queue: AsyncQueue;
|
||||
resetsAt: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Local, in-memory identify throttler.
|
||||
*/
|
||||
export class SimpleIdentifyThrottler implements IIdentifyThrottler {
|
||||
private readonly states = new Collection<number, IdentifyState>();
|
||||
|
||||
public constructor(private readonly maxConcurrency: number) {}
|
||||
|
||||
/**
|
||||
* {@inheritDoc IIdentifyThrottler.waitForIdentify}
|
||||
*/
|
||||
public async waitForIdentify(shardId: number, signal: AbortSignal): Promise<void> {
|
||||
const key = shardId % this.maxConcurrency;
|
||||
|
||||
const state = this.states.ensure(key, () => {
|
||||
return {
|
||||
queue: new AsyncQueue(),
|
||||
resetsAt: Number.POSITIVE_INFINITY,
|
||||
};
|
||||
});
|
||||
|
||||
await state.queue.wait({ signal });
|
||||
|
||||
try {
|
||||
const diff = state.resetsAt - Date.now();
|
||||
if (diff <= 5_000) {
|
||||
// To account for the latency the IDENTIFY payload goes through, we add a bit more wait time
|
||||
const time = diff + Math.random() * 1_500;
|
||||
await sleep(time);
|
||||
}
|
||||
|
||||
state.resetsAt = Date.now() + 5_000;
|
||||
} finally {
|
||||
state.queue.shift();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
import { setTimeout as sleep } from 'node:timers/promises';
|
||||
import { AsyncQueue } from '@sapphire/async-queue';
|
||||
import type { WebSocketManager } from '../ws/WebSocketManager.js';
|
||||
|
||||
export class IdentifyThrottler {
|
||||
private readonly queue = new AsyncQueue();
|
||||
|
||||
private identifyState = {
|
||||
remaining: 0,
|
||||
resetsAt: Number.POSITIVE_INFINITY,
|
||||
};
|
||||
|
||||
public constructor(private readonly manager: WebSocketManager) {}
|
||||
|
||||
public async waitForIdentify(): Promise<void> {
|
||||
await this.queue.wait();
|
||||
|
||||
try {
|
||||
if (this.identifyState.remaining <= 0) {
|
||||
const diff = this.identifyState.resetsAt - Date.now();
|
||||
if (diff <= 5_000) {
|
||||
// To account for the latency the IDENTIFY payload goes through, we add a bit more wait time
|
||||
const time = diff + Math.random() * 1_500;
|
||||
await sleep(time);
|
||||
}
|
||||
|
||||
const info = await this.manager.fetchGatewayInformation();
|
||||
this.identifyState = {
|
||||
remaining: info.session_start_limit.max_concurrency,
|
||||
resetsAt: Date.now() + 5_000,
|
||||
};
|
||||
}
|
||||
|
||||
this.identifyState.remaining--;
|
||||
} finally {
|
||||
this.queue.shift();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -117,7 +117,7 @@ export class WorkerBootstrapper {
|
||||
break;
|
||||
}
|
||||
|
||||
case WorkerSendPayloadOp.ShardCanIdentify: {
|
||||
case WorkerSendPayloadOp.ShardIdentifyResponse: {
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -127,11 +127,11 @@ export class WorkerBootstrapper {
|
||||
throw new Error(`Shard ${payload.shardId} does not exist`);
|
||||
}
|
||||
|
||||
const response = {
|
||||
const response: WorkerReceivePayload = {
|
||||
op: WorkerReceivePayloadOp.FetchStatusResponse,
|
||||
status: shard.status,
|
||||
nonce: payload.nonce,
|
||||
} satisfies WorkerReceivePayload;
|
||||
};
|
||||
|
||||
parentPort!.postMessage(response);
|
||||
break;
|
||||
@@ -150,12 +150,12 @@ export class WorkerBootstrapper {
|
||||
for (const event of options.forwardEvents ?? Object.values(WebSocketShardEvents)) {
|
||||
// @ts-expect-error: Event types incompatible
|
||||
shard.on(event, (data) => {
|
||||
const payload = {
|
||||
const payload: WorkerReceivePayload = {
|
||||
op: WorkerReceivePayloadOp.Event,
|
||||
event,
|
||||
data,
|
||||
shardId,
|
||||
} satisfies WorkerReceivePayload;
|
||||
};
|
||||
parentPort!.postMessage(payload);
|
||||
});
|
||||
}
|
||||
@@ -168,9 +168,9 @@ export class WorkerBootstrapper {
|
||||
// Lastly, start listening to messages from the parent thread
|
||||
this.setupThreadEvents();
|
||||
|
||||
const message = {
|
||||
const message: WorkerReceivePayload = {
|
||||
op: WorkerReceivePayloadOp.WorkerReady,
|
||||
} satisfies WorkerReceivePayload;
|
||||
};
|
||||
parentPort!.postMessage(message);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,8 @@ import { Collection } from '@discordjs/collection';
|
||||
import { lazy } from '@discordjs/util';
|
||||
import { APIVersion, GatewayOpcodes } from 'discord-api-types/v10';
|
||||
import { SimpleShardingStrategy } from '../strategies/sharding/SimpleShardingStrategy.js';
|
||||
import type { SessionInfo, OptionalWebSocketManagerOptions } from '../ws/WebSocketManager.js';
|
||||
import { SimpleIdentifyThrottler } from '../throttling/SimpleIdentifyThrottler.js';
|
||||
import type { SessionInfo, OptionalWebSocketManagerOptions, WebSocketManager } from '../ws/WebSocketManager.js';
|
||||
import type { SendRateLimitState } from '../ws/WebSocketShard.js';
|
||||
|
||||
/**
|
||||
@@ -28,6 +29,10 @@ const getDefaultSessionStore = lazy(() => new Collection<number, SessionInfo | n
|
||||
* Default options used by the manager
|
||||
*/
|
||||
export const DefaultWebSocketManagerOptions = {
|
||||
async buildIdentifyThrottler(manager: WebSocketManager) {
|
||||
const info = await manager.fetchGatewayInformation();
|
||||
return new SimpleIdentifyThrottler(info.session_start_limit.max_concurrency);
|
||||
},
|
||||
buildStrategy: (manager) => new SimpleShardingStrategy(manager),
|
||||
shardCount: null,
|
||||
shardIds: null,
|
||||
|
||||
@@ -11,6 +11,7 @@ import {
|
||||
type GatewaySendPayload,
|
||||
} from 'discord-api-types/v10';
|
||||
import type { IShardingStrategy } from '../strategies/sharding/IShardingStrategy';
|
||||
import type { IIdentifyThrottler } from '../throttling/IIdentifyThrottler';
|
||||
import { DefaultWebSocketManagerOptions, type CompressionMethod, type Encoding } from '../utils/constants.js';
|
||||
import type { WebSocketShardDestroyOptions, WebSocketShardEventsMap } from './WebSocketShard.js';
|
||||
|
||||
@@ -55,7 +56,7 @@ export interface RequiredWebSocketManagerOptions {
|
||||
/**
|
||||
* The intents to request
|
||||
*/
|
||||
intents: GatewayIntentBits;
|
||||
intents: GatewayIntentBits | 0;
|
||||
/**
|
||||
* The REST instance to use for fetching gateway information
|
||||
*/
|
||||
@@ -70,6 +71,10 @@ export interface RequiredWebSocketManagerOptions {
|
||||
* Optional additional configuration for the WebSocketManager
|
||||
*/
|
||||
export interface OptionalWebSocketManagerOptions {
|
||||
/**
|
||||
* Builds an identify throttler to use for this manager's shards
|
||||
*/
|
||||
buildIdentifyThrottler(manager: WebSocketManager): Awaitable<IIdentifyThrottler>;
|
||||
/**
|
||||
* Builds the strategy to use for sharding
|
||||
*
|
||||
|
||||
@@ -358,7 +358,21 @@ export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {
|
||||
private async identify() {
|
||||
this.debug(['Waiting for identify throttle']);
|
||||
|
||||
await this.strategy.waitForIdentify();
|
||||
const controller = new AbortController();
|
||||
const closeHandler = () => {
|
||||
controller.abort();
|
||||
};
|
||||
|
||||
this.on(WebSocketShardEvents.Closed, closeHandler);
|
||||
|
||||
try {
|
||||
await this.strategy.waitForIdentify(this.id, controller.signal);
|
||||
} catch {
|
||||
this.debug(['Was waiting for an identify, but the shard closed in the meantime']);
|
||||
return;
|
||||
} finally {
|
||||
this.off(WebSocketShardEvents.Closed, closeHandler);
|
||||
}
|
||||
|
||||
this.debug([
|
||||
'Identifying',
|
||||
|
||||
Reference in New Issue
Block a user