refactor: abstract identify throttling and correct max_concurrency handling (#9375)

* refactor: properly support max_concurrency ratelimit keys

* fix: properly block for same key

* chore: export session state

* chore: throttler no longer requires manager

* refactor: abstract throttlers

* chore: proper member order

* chore: remove leftover debug log

* chore: use @link tag in doc comment

Co-authored-by: Jiralite <33201955+Jiralite@users.noreply.github.com>

* chore: suggested changes

* fix(WebSocketShard): cancel identify if the shard closed in the meantime

* refactor(throttlers): support abort signals

* fix: memory leak

* chore: remove leftover

---------

Co-authored-by: Jiralite <33201955+Jiralite@users.noreply.github.com>
Co-authored-by: kodiakhq[bot] <49736102+kodiakhq[bot]@users.noreply.github.com>
This commit is contained in:
DD
2023-04-14 23:26:37 +03:00
committed by GitHub
parent cac3c07729
commit 02dfaf1aa2
16 changed files with 279 additions and 161 deletions

View File

@@ -5,7 +5,13 @@ import type { SessionInfo, WebSocketManager, WebSocketManagerOptions } from '../
export interface FetchingStrategyOptions
extends Omit<
WebSocketManagerOptions,
'buildStrategy' | 'rest' | 'retrieveSessionInfo' | 'shardCount' | 'shardIds' | 'updateSessionInfo'
| 'buildIdentifyThrottler'
| 'buildStrategy'
| 'rest'
| 'retrieveSessionInfo'
| 'shardCount'
| 'shardIds'
| 'updateSessionInfo'
> {
readonly gatewayInformation: APIGatewayBotInfo;
readonly shardCount: number;
@@ -18,13 +24,25 @@ export interface IContextFetchingStrategy {
readonly options: FetchingStrategyOptions;
retrieveSessionInfo(shardId: number): Awaitable<SessionInfo | null>;
updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null): Awaitable<void>;
waitForIdentify(): Promise<void>;
/**
* Resolves once the given shard should be allowed to identify, or rejects if the operation was aborted
*/
waitForIdentify(shardId: number, signal: AbortSignal): Promise<void>;
}
export async function managerToFetchingStrategyOptions(manager: WebSocketManager): Promise<FetchingStrategyOptions> {
// eslint-disable-next-line @typescript-eslint/unbound-method
const { buildStrategy, retrieveSessionInfo, updateSessionInfo, shardCount, shardIds, rest, ...managerOptions } =
manager.options;
/* eslint-disable @typescript-eslint/unbound-method */
const {
buildIdentifyThrottler,
buildStrategy,
retrieveSessionInfo,
updateSessionInfo,
shardCount,
shardIds,
rest,
...managerOptions
} = manager.options;
/* eslint-enable @typescript-eslint/unbound-method */
return {
...managerOptions,

View File

@@ -1,28 +1,25 @@
import { IdentifyThrottler } from '../../utils/IdentifyThrottler.js';
import type { IIdentifyThrottler } from '../../throttling/IIdentifyThrottler.js';
import type { SessionInfo, WebSocketManager } from '../../ws/WebSocketManager.js';
import type { FetchingStrategyOptions, IContextFetchingStrategy } from './IContextFetchingStrategy.js';
export class SimpleContextFetchingStrategy implements IContextFetchingStrategy {
// This strategy assumes every shard is running under the same process - therefore we need a single
// IdentifyThrottler per manager.
private static throttlerCache = new WeakMap<WebSocketManager, IdentifyThrottler>();
private static throttlerCache = new WeakMap<WebSocketManager, IIdentifyThrottler>();
private static ensureThrottler(manager: WebSocketManager): IdentifyThrottler {
const existing = SimpleContextFetchingStrategy.throttlerCache.get(manager);
if (existing) {
return existing;
private static async ensureThrottler(manager: WebSocketManager): Promise<IIdentifyThrottler> {
const throttler = SimpleContextFetchingStrategy.throttlerCache.get(manager);
if (throttler) {
return throttler;
}
const throttler = new IdentifyThrottler(manager);
SimpleContextFetchingStrategy.throttlerCache.set(manager, throttler);
return throttler;
const newThrottler = await manager.options.buildIdentifyThrottler(manager);
SimpleContextFetchingStrategy.throttlerCache.set(manager, newThrottler);
return newThrottler;
}
private readonly throttler: IdentifyThrottler;
public constructor(private readonly manager: WebSocketManager, public readonly options: FetchingStrategyOptions) {
this.throttler = SimpleContextFetchingStrategy.ensureThrottler(manager);
}
public constructor(private readonly manager: WebSocketManager, public readonly options: FetchingStrategyOptions) {}
public async retrieveSessionInfo(shardId: number): Promise<SessionInfo | null> {
return this.manager.options.retrieveSessionInfo(shardId);
@@ -32,7 +29,8 @@ export class SimpleContextFetchingStrategy implements IContextFetchingStrategy {
return this.manager.options.updateSessionInfo(shardId, sessionInfo);
}
public async waitForIdentify(): Promise<void> {
await this.throttler.waitForIdentify();
public async waitForIdentify(shardId: number, signal: AbortSignal): Promise<void> {
const throttler = await SimpleContextFetchingStrategy.ensureThrottler(this.manager);
await throttler.waitForIdentify(shardId, signal);
}
}

View File

@@ -9,10 +9,17 @@ import {
} from '../sharding/WorkerShardingStrategy.js';
import type { FetchingStrategyOptions, IContextFetchingStrategy } from './IContextFetchingStrategy.js';
// Because the global types are incomplete for whatever reason
interface PolyFillAbortSignal {
readonly aborted: boolean;
addEventListener(type: 'abort', listener: () => void): void;
removeEventListener(type: 'abort', listener: () => void): void;
}
export class WorkerContextFetchingStrategy implements IContextFetchingStrategy {
private readonly sessionPromises = new Collection<number, (session: SessionInfo | null) => void>();
private readonly waitForIdentifyPromises = new Collection<number, () => void>();
private readonly waitForIdentifyPromises = new Collection<number, { reject(): void; resolve(): void }>();
public constructor(public readonly options: FetchingStrategyOptions) {
if (isMainThread) {
@@ -25,8 +32,14 @@ export class WorkerContextFetchingStrategy implements IContextFetchingStrategy {
this.sessionPromises.delete(payload.nonce);
}
if (payload.op === WorkerSendPayloadOp.ShardCanIdentify) {
this.waitForIdentifyPromises.get(payload.nonce)?.();
if (payload.op === WorkerSendPayloadOp.ShardIdentifyResponse) {
const promise = this.waitForIdentifyPromises.get(payload.nonce);
if (payload.ok) {
promise?.resolve();
} else {
promise?.reject();
}
this.waitForIdentifyPromises.delete(payload.nonce);
}
});
@@ -34,11 +47,11 @@ export class WorkerContextFetchingStrategy implements IContextFetchingStrategy {
public async retrieveSessionInfo(shardId: number): Promise<SessionInfo | null> {
const nonce = Math.random();
const payload = {
const payload: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.RetrieveSessionInfo,
shardId,
nonce,
} satisfies WorkerReceivePayload;
};
// eslint-disable-next-line no-promise-executor-return
const promise = new Promise<SessionInfo | null>((resolve) => this.sessionPromises.set(nonce, resolve));
parentPort!.postMessage(payload);
@@ -46,23 +59,44 @@ export class WorkerContextFetchingStrategy implements IContextFetchingStrategy {
}
public updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null) {
const payload = {
const payload: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.UpdateSessionInfo,
shardId,
session: sessionInfo,
} satisfies WorkerReceivePayload;
};
parentPort!.postMessage(payload);
}
public async waitForIdentify(): Promise<void> {
public async waitForIdentify(shardId: number, signal: AbortSignal): Promise<void> {
const nonce = Math.random();
const payload = {
const payload: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.WaitForIdentify,
nonce,
} satisfies WorkerReceivePayload;
// eslint-disable-next-line no-promise-executor-return
const promise = new Promise<void>((resolve) => this.waitForIdentifyPromises.set(nonce, resolve));
shardId,
};
const promise = new Promise<void>((resolve, reject) =>
// eslint-disable-next-line no-promise-executor-return
this.waitForIdentifyPromises.set(nonce, { resolve, reject }),
);
parentPort!.postMessage(payload);
return promise;
const listener = () => {
const payload: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.CancelIdentify,
nonce,
};
parentPort!.postMessage(payload);
};
(signal as unknown as PolyFillAbortSignal).addEventListener('abort', listener);
try {
await promise;
} finally {
(signal as unknown as PolyFillAbortSignal).removeEventListener('abort', listener);
}
}
}