feat: @discordjs/ws (#8260)

Co-authored-by: Parbez <imranbarbhuiya.fsd@gmail.com>
This commit is contained in:
DD
2022-07-22 20:13:47 +03:00
committed by GitHub
parent 830c670c61
commit 748d7271c4
37 changed files with 3659 additions and 2612 deletions

14
packages/ws/src/index.ts Normal file
View File

@@ -0,0 +1,14 @@
export * from './strategies/context/IContextFetchingStrategy';
export * from './strategies/context/SimpleContextFetchingStrategy';
export * from './strategies/context/WorkerContextFetchingStrategy';
export * from './strategies/sharding/IShardingStrategy';
export * from './strategies/sharding/SimpleShardingStrategy';
export * from './strategies/sharding/WorkerShardingStrategy';
export * from './utils/constants';
export * from './utils/IdentifyThrottler';
export * from './utils/utils';
export * from './ws/WebSocketManager';
export * from './ws/WebSocketShard';

View File

@@ -0,0 +1,31 @@
import type { Awaitable } from '@vladfrangu/async_event_emitter';
import type { APIGatewayBotInfo } from 'discord-api-types/v10';
import type { SessionInfo, WebSocketManager, WebSocketManagerOptions } from '../../ws/WebSocketManager';
export interface FetchingStrategyOptions
extends Omit<
WebSocketManagerOptions,
'retrieveSessionInfo' | 'updateSessionInfo' | 'shardCount' | 'shardIds' | 'rest'
> {
readonly gatewayInformation: APIGatewayBotInfo;
readonly shardCount: number;
}
/**
* Strategies responsible solely for making manager information accessible
*/
export interface IContextFetchingStrategy {
readonly options: FetchingStrategyOptions;
retrieveSessionInfo: (shardId: number) => Awaitable<SessionInfo | null>;
updateSessionInfo: (shardId: number, sessionInfo: SessionInfo | null) => Awaitable<void>;
}
export async function managerToFetchingStrategyOptions(manager: WebSocketManager): Promise<FetchingStrategyOptions> {
const { retrieveSessionInfo, updateSessionInfo, shardCount, shardIds, rest, ...managerOptions } = manager.options;
return {
...managerOptions,
gatewayInformation: await manager.fetchGatewayInformation(),
shardCount: await manager.getShardCount(),
};
}

View File

@@ -0,0 +1,14 @@
import type { FetchingStrategyOptions, IContextFetchingStrategy } from './IContextFetchingStrategy';
import type { SessionInfo, WebSocketManager } from '../../ws/WebSocketManager';
export class SimpleContextFetchingStrategy implements IContextFetchingStrategy {
public constructor(private readonly manager: WebSocketManager, public readonly options: FetchingStrategyOptions) {}
public async retrieveSessionInfo(shardId: number): Promise<SessionInfo | null> {
return this.manager.options.retrieveSessionInfo(shardId);
}
public updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null) {
return this.manager.options.updateSessionInfo(shardId, sessionInfo);
}
}

View File

@@ -0,0 +1,49 @@
import { isMainThread, parentPort } from 'node:worker_threads';
import { Collection } from '@discordjs/collection';
import type { FetchingStrategyOptions, IContextFetchingStrategy } from './IContextFetchingStrategy';
import type { SessionInfo } from '../../ws/WebSocketManager';
import {
WorkerRecievePayload,
WorkerRecievePayloadOp,
WorkerSendPayload,
WorkerSendPayloadOp,
} from '../sharding/WorkerShardingStrategy';
export class WorkerContextFetchingStrategy implements IContextFetchingStrategy {
private readonly sessionPromises = new Collection<number, (session: SessionInfo | null) => void>();
public constructor(public readonly options: FetchingStrategyOptions) {
if (isMainThread) {
throw new Error('Cannot instantiate WorkerContextFetchingStrategy on the main thread');
}
parentPort!.on('message', (payload: WorkerSendPayload) => {
if (payload.op === WorkerSendPayloadOp.SessionInfoResponse) {
const resolve = this.sessionPromises.get(payload.nonce);
resolve?.(payload.session);
this.sessionPromises.delete(payload.nonce);
}
});
}
public async retrieveSessionInfo(shardId: number): Promise<SessionInfo | null> {
const nonce = Math.random();
const payload: WorkerRecievePayload = {
op: WorkerRecievePayloadOp.RetrieveSessionInfo,
shardId,
nonce,
};
const promise = new Promise<SessionInfo | null>((resolve) => this.sessionPromises.set(nonce, resolve));
parentPort!.postMessage(payload);
return promise;
}
public updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null) {
const payload: WorkerRecievePayload = {
op: WorkerRecievePayloadOp.UpdateSessionInfo,
shardId,
session: sessionInfo,
};
parentPort!.postMessage(payload);
}
}

View File

@@ -0,0 +1,25 @@
import type { GatewaySendPayload } from 'discord-api-types/v10';
import type { Awaitable } from '../../utils/utils';
import type { WebSocketShardDestroyOptions } from '../../ws/WebSocketShard';
/**
* Strategies responsible for spawning, initializing connections, destroying shards, and relaying events
*/
export interface IShardingStrategy {
/**
* Spawns all the shards
*/
spawn: (shardIds: number[]) => Awaitable<void>;
/**
* Initializes all the shards
*/
connect: () => Awaitable<void>;
/**
* Destroys all the shards
*/
destroy: (options?: Omit<WebSocketShardDestroyOptions, 'recover'>) => Awaitable<void>;
/**
* Sends a payload to a shard
*/
send: (shardId: number, payload: GatewaySendPayload) => Awaitable<void>;
}

View File

@@ -0,0 +1,64 @@
import { Collection } from '@discordjs/collection';
import type { GatewaySendPayload } from 'discord-api-types/v10';
import type { IShardingStrategy } from './IShardingStrategy';
import { IdentifyThrottler } from '../../utils/IdentifyThrottler';
import type { WebSocketManager } from '../../ws/WebSocketManager';
import { WebSocketShard, WebSocketShardDestroyOptions, WebSocketShardEvents } from '../../ws/WebSocketShard';
import { managerToFetchingStrategyOptions } from '../context/IContextFetchingStrategy';
import { SimpleContextFetchingStrategy } from '../context/SimpleContextFetchingStrategy';
/**
* Simple strategy that just spawns shards in the current process
*/
export class SimpleShardingStrategy implements IShardingStrategy {
private readonly manager: WebSocketManager;
private readonly shards = new Collection<number, WebSocketShard>();
private readonly throttler: IdentifyThrottler;
public constructor(manager: WebSocketManager) {
this.manager = manager;
this.throttler = new IdentifyThrottler(manager);
}
public async spawn(shardIds: number[]) {
const strategyOptions = await managerToFetchingStrategyOptions(this.manager);
for (const shardId of shardIds) {
const strategy = new SimpleContextFetchingStrategy(this.manager, strategyOptions);
const shard = new WebSocketShard(strategy, shardId);
for (const event of Object.values(WebSocketShardEvents)) {
// @ts-expect-error
shard.on(event, (payload) => this.manager.emit(event, { ...payload, shardId }));
}
this.shards.set(shardId, shard);
}
}
public async connect() {
const promises = [];
for (const shard of this.shards.values()) {
await this.throttler.waitForIdentify();
promises.push(shard.connect());
}
await Promise.all(promises);
}
public async destroy(options?: Omit<WebSocketShardDestroyOptions, 'recover'>) {
const promises = [];
for (const shard of this.shards.values()) {
promises.push(shard.destroy(options));
}
await Promise.all(promises);
this.shards.clear();
}
public send(shardId: number, payload: GatewaySendPayload) {
const shard = this.shards.get(shardId);
if (!shard) throw new Error(`Shard ${shardId} not found`);
return shard.send(payload);
}
}

View File

@@ -0,0 +1,203 @@
import { once } from 'node:events';
import { join } from 'node:path';
import { Worker } from 'node:worker_threads';
import { Collection } from '@discordjs/collection';
import type { GatewaySendPayload } from 'discord-api-types/v10';
import type { IShardingStrategy } from './IShardingStrategy';
import { IdentifyThrottler } from '../../utils/IdentifyThrottler';
import type { SessionInfo, WebSocketManager } from '../../ws/WebSocketManager';
import type { WebSocketShardDestroyOptions, WebSocketShardEvents } from '../../ws/WebSocketShard';
import { FetchingStrategyOptions, managerToFetchingStrategyOptions } from '../context/IContextFetchingStrategy';
export interface WorkerData extends FetchingStrategyOptions {
shardIds: number[];
}
export enum WorkerSendPayloadOp {
Connect,
Destroy,
Send,
SessionInfoResponse,
}
export type WorkerSendPayload =
| { op: WorkerSendPayloadOp.Connect; shardId: number }
| { op: WorkerSendPayloadOp.Destroy; shardId: number; options?: WebSocketShardDestroyOptions }
| { op: WorkerSendPayloadOp.Send; shardId: number; payload: GatewaySendPayload }
| { op: WorkerSendPayloadOp.SessionInfoResponse; nonce: number; session: SessionInfo | null };
export enum WorkerRecievePayloadOp {
Connected,
Destroyed,
Event,
RetrieveSessionInfo,
UpdateSessionInfo,
}
export type WorkerRecievePayload =
| { op: WorkerRecievePayloadOp.Connected; shardId: number }
| { op: WorkerRecievePayloadOp.Destroyed; shardId: number }
// Can't seem to get a type-safe union based off of the event, so I'm sadly leaving data as any for now
| { op: WorkerRecievePayloadOp.Event; shardId: number; event: WebSocketShardEvents; data: any }
| { op: WorkerRecievePayloadOp.RetrieveSessionInfo; shardId: number; nonce: number }
| { op: WorkerRecievePayloadOp.UpdateSessionInfo; shardId: number; session: SessionInfo | null };
/**
* Options for a {@link WorkerShardingStrategy}
*/
export interface WorkerShardingStrategyOptions {
/**
* Dictates how many shards should be spawned per worker thread.
*/
shardsPerWorker: number | 'all';
}
/**
* Strategy used to spawn threads in worker_threads
*/
export class WorkerShardingStrategy implements IShardingStrategy {
private readonly manager: WebSocketManager;
private readonly options: WorkerShardingStrategyOptions;
#workers: Worker[] = [];
readonly #workerByShardId = new Collection<number, Worker>();
private readonly connectPromises = new Collection<number, () => void>();
private readonly destroyPromises = new Collection<number, () => void>();
private readonly throttler: IdentifyThrottler;
public constructor(manager: WebSocketManager, options: WorkerShardingStrategyOptions) {
this.manager = manager;
this.throttler = new IdentifyThrottler(manager);
this.options = options;
}
public async spawn(shardIds: number[]) {
const shardsPerWorker = this.options.shardsPerWorker === 'all' ? shardIds.length : this.options.shardsPerWorker;
const strategyOptions = await managerToFetchingStrategyOptions(this.manager);
let shards = 0;
while (shards !== shardIds.length) {
const slice = shardIds.slice(shards, shardsPerWorker + shards);
const workerData: WorkerData = {
...strategyOptions,
shardIds: slice,
};
const worker = new Worker(join(__dirname, 'worker.cjs'), { workerData });
await once(worker, 'online');
worker
.on('error', (err) => {
throw err;
})
.on('messageerror', (err) => {
throw err;
})
// eslint-disable-next-line @typescript-eslint/no-misused-promises
.on('message', (payload: WorkerRecievePayload) => this.onMessage(worker, payload));
this.#workers.push(worker);
for (const shardId of slice) {
this.#workerByShardId.set(shardId, worker);
}
shards += slice.length;
}
}
public async connect() {
const promises = [];
for (const [shardId, worker] of this.#workerByShardId.entries()) {
await this.throttler.waitForIdentify();
const payload: WorkerSendPayload = {
op: WorkerSendPayloadOp.Connect,
shardId,
};
const promise = new Promise<void>((resolve) => this.connectPromises.set(shardId, resolve));
worker.postMessage(payload);
promises.push(promise);
}
await Promise.all(promises);
}
public async destroy(options: Omit<WebSocketShardDestroyOptions, 'recover'> = {}) {
const promises = [];
for (const [shardId, worker] of this.#workerByShardId.entries()) {
const payload: WorkerSendPayload = {
op: WorkerSendPayloadOp.Destroy,
shardId,
options,
};
promises.push(
new Promise<void>((resolve) => this.destroyPromises.set(shardId, resolve)).then(() => worker.terminate()),
);
worker.postMessage(payload);
}
this.#workers = [];
this.#workerByShardId.clear();
await Promise.all(promises);
}
public send(shardId: number, data: GatewaySendPayload) {
const worker = this.#workerByShardId.get(shardId);
if (!worker) {
throw new Error(`No worker found for shard ${shardId}`);
}
const payload: WorkerSendPayload = {
op: WorkerSendPayloadOp.Send,
shardId,
payload: data,
};
worker.postMessage(payload);
}
private async onMessage(worker: Worker, payload: WorkerRecievePayload) {
switch (payload.op) {
case WorkerRecievePayloadOp.Connected: {
const resolve = this.connectPromises.get(payload.shardId)!;
resolve();
this.connectPromises.delete(payload.shardId);
break;
}
case WorkerRecievePayloadOp.Destroyed: {
const resolve = this.destroyPromises.get(payload.shardId)!;
resolve();
this.destroyPromises.delete(payload.shardId);
break;
}
case WorkerRecievePayloadOp.Event: {
// eslint-disable-next-line @typescript-eslint/no-unsafe-argument
this.manager.emit(payload.event, { ...payload.data, shardId: payload.shardId });
break;
}
case WorkerRecievePayloadOp.RetrieveSessionInfo: {
const session = await this.manager.options.retrieveSessionInfo(payload.shardId);
const response: WorkerSendPayload = {
op: WorkerSendPayloadOp.SessionInfoResponse,
nonce: payload.nonce,
session,
};
worker.postMessage(response);
break;
}
case WorkerRecievePayloadOp.UpdateSessionInfo: {
await this.manager.options.updateSessionInfo(payload.shardId, payload.session);
break;
}
}
}
}

View File

@@ -0,0 +1,93 @@
import { isMainThread, workerData, parentPort } from 'node:worker_threads';
import { Collection } from '@discordjs/collection';
import {
WorkerData,
WorkerRecievePayload,
WorkerRecievePayloadOp,
WorkerSendPayload,
WorkerSendPayloadOp,
} from './WorkerShardingStrategy';
import { WebSocketShard, WebSocketShardDestroyOptions, WebSocketShardEvents } from '../../ws/WebSocketShard';
import { WorkerContextFetchingStrategy } from '../context/WorkerContextFetchingStrategy';
if (isMainThread) {
throw new Error('Expected worker script to not be ran within the main thread');
}
const data = workerData as WorkerData;
const shards = new Collection<number, WebSocketShard>();
async function connect(shardId: number) {
const shard = shards.get(shardId);
if (!shard) {
throw new Error(`Shard ${shardId} does not exist`);
}
await shard.connect();
}
async function destroy(shardId: number, options?: WebSocketShardDestroyOptions) {
const shard = shards.get(shardId);
if (!shard) {
throw new Error(`Shard ${shardId} does not exist`);
}
await shard.destroy(options);
}
for (const shardId of data.shardIds) {
const shard = new WebSocketShard(new WorkerContextFetchingStrategy(data), shardId);
for (const event of Object.values(WebSocketShardEvents)) {
// @ts-expect-error
shard.on(event, (data) => {
const payload: WorkerRecievePayload = {
op: WorkerRecievePayloadOp.Event,
event,
data,
shardId,
};
parentPort!.postMessage(payload);
});
}
shards.set(shardId, shard);
}
parentPort!
.on('messageerror', (err) => {
throw err;
})
// eslint-disable-next-line @typescript-eslint/no-misused-promises
.on('message', async (payload: WorkerSendPayload) => {
switch (payload.op) {
case WorkerSendPayloadOp.Connect: {
await connect(payload.shardId);
const response: WorkerRecievePayload = {
op: WorkerRecievePayloadOp.Connected,
shardId: payload.shardId,
};
parentPort!.postMessage(response);
break;
}
case WorkerSendPayloadOp.Destroy: {
await destroy(payload.shardId, payload.options);
const response: WorkerRecievePayload = {
op: WorkerRecievePayloadOp.Destroyed,
shardId: payload.shardId,
};
parentPort!.postMessage(response);
break;
}
case WorkerSendPayloadOp.Send: {
const shard = shards.get(payload.shardId);
if (!shard) {
throw new Error(`Shard ${payload.shardId} does not exist`);
}
await shard.send(payload.payload);
break;
}
case WorkerSendPayloadOp.SessionInfoResponse: {
break;
}
}
});

View File

@@ -0,0 +1,29 @@
import { setTimeout as sleep } from 'node:timers/promises';
import type { WebSocketManager } from '../ws/WebSocketManager';
export class IdentifyThrottler {
private identifyState = {
remaining: 0,
resetsAt: Infinity,
};
public constructor(private readonly manager: WebSocketManager) {}
public async waitForIdentify(): Promise<void> {
if (this.identifyState.remaining <= 0) {
const diff = this.identifyState.resetsAt - Date.now();
if (diff <= 5_000) {
const time = diff + Math.random() * 1_500;
await sleep(time);
}
const info = await this.manager.fetchGatewayInformation();
this.identifyState = {
remaining: info.session_start_limit.max_concurrency,
resetsAt: Date.now() + 5_000,
};
}
this.identifyState.remaining--;
}
}

View File

@@ -0,0 +1,68 @@
import { readFileSync } from 'node:fs';
import { join } from 'node:path';
import { Collection } from '@discordjs/collection';
import { APIVersion, GatewayOpcodes } from 'discord-api-types/v10';
import { lazy } from './utils';
import type { OptionalWebSocketManagerOptions, SessionInfo } from '../ws/WebSocketManager';
/**
* Valid encoding types
*/
export enum Encoding {
JSON = 'json',
}
/**
* Valid compression methods
*/
export enum CompressionMethod {
ZlibStream = 'zlib-stream',
}
const packageJson = readFileSync(join(__dirname, '..', '..', 'package.json'), 'utf8');
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
const Package = JSON.parse(packageJson);
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions, @typescript-eslint/no-unsafe-member-access
export const DefaultDeviceProperty = `@discordjs/ws ${Package.version}`;
const getDefaultSessionStore = lazy(() => new Collection<number, SessionInfo | null>());
/**
* Default options used by the manager
*/
export const DefaultWebSocketManagerOptions: OptionalWebSocketManagerOptions = {
shardCount: null,
shardIds: null,
largeThreshold: null,
initialPresence: null,
identifyProperties: {
browser: DefaultDeviceProperty,
device: DefaultDeviceProperty,
os: process.platform,
},
version: APIVersion,
encoding: Encoding.JSON,
compression: null,
retrieveSessionInfo(shardId) {
const store = getDefaultSessionStore();
return store.get(shardId) ?? null;
},
updateSessionInfo(shardId: number, info: SessionInfo | null) {
const store = getDefaultSessionStore();
if (info) {
store.set(shardId, info);
} else {
store.delete(shardId);
}
},
handshakeTimeout: 30_000,
helloTimeout: 60_000,
readyTimeout: 15_000,
};
export const ImportantGatewayOpcodes = new Set([
GatewayOpcodes.Heartbeat,
GatewayOpcodes.Identify,
GatewayOpcodes.Resume,
]);

View File

@@ -0,0 +1,20 @@
import type { ShardRange } from '../ws/WebSocketManager';
export type Awaitable<T> = T | Promise<T>;
/**
* Yields the numbers in the given range as an array
* @example
* range({ start: 3, end: 5 }); // [3, 4, 5]
*/
export function range({ start, end }: ShardRange): number[] {
return Array.from({ length: end - start + 1 }, (_, i) => i + start);
}
/**
* Lazily evaluate a callback, storing its result
*/
export function lazy<T>(cb: () => T): () => T {
let defaultValue: T;
return () => (defaultValue ??= cb());
}

View File

@@ -0,0 +1,274 @@
import type { REST } from '@discordjs/rest';
import { AsyncEventEmitter } from '@vladfrangu/async_event_emitter';
import {
APIGatewayBotInfo,
GatewayIdentifyProperties,
GatewayPresenceUpdateData,
RESTGetAPIGatewayBotResult,
GatewayIntentBits,
Routes,
GatewaySendPayload,
} from 'discord-api-types/v10';
import type { WebSocketShardDestroyOptions, WebSocketShardEventsMap } from './WebSocketShard';
import type { IShardingStrategy } from '../strategies/sharding/IShardingStrategy';
import { SimpleShardingStrategy } from '../strategies/sharding/SimpleShardingStrategy';
import { CompressionMethod, DefaultWebSocketManagerOptions, Encoding } from '../utils/constants';
import { Awaitable, range } from '../utils/utils';
/**
* Represents a range of shard ids
*/
export interface ShardRange {
start: number;
end: number;
}
/**
* Session information for a given shard, used to resume a session
*/
export interface SessionInfo {
/**
* Session id for this shard
*/
sessionId: string;
/**
* The sequence number of the last message sent by the shard
*/
sequence: number;
/**
* The id of the shard
*/
shardId: number;
/**
* The total number of shards at the time of this shard identifying
*/
shardCount: number;
}
/**
* Required options for the WebSocketManager
*/
export interface RequiredWebSocketManagerOptions {
/**
* The token to use for identifying with the gateway
*/
token: string;
/**
* The intents to request
*/
intents: GatewayIntentBits;
/**
* The REST instance to use for fetching gateway information
*/
rest: REST;
}
/**
* Optional additional configuration for the WebSocketManager
*/
export interface OptionalWebSocketManagerOptions {
/**
* The total number of shards across all WebsocketManagers you intend to instantiate.
* Use `null` to use Discord's recommended shard count
*/
shardCount: number | null;
/**
* The ids of the shards this WebSocketManager should manage.
* Use `null` to simply spawn 0 through `shardCount - 1`
* @example
* const manager = new WebSocketManager({
* shardIds: [1, 3, 7], // spawns shard 1, 3, and 7, nothing else
* });
* @example
* const manager = new WebSocketManager({
* shardIds: {
* start: 3,
* end: 6,
* }, // spawns shards 3, 4, 5, and 6
* });
*/
shardIds: number[] | ShardRange | null;
/**
* Value between 50 and 250, total number of members where the gateway will stop sending offline members in the guild member list
*/
largeThreshold: number | null;
/**
* Initial presence data to send to the gateway when identifying
*/
initialPresence: GatewayPresenceUpdateData | null;
/**
* Properties to send to the gateway when identifying
*/
identifyProperties: GatewayIdentifyProperties;
/**
* The gateway version to use
* @default '10'
*/
version: string;
/**
* The encoding to use
* @default 'json'
*/
encoding: Encoding;
/**
* The compression method to use
* @default null (no compression)
*/
compression: CompressionMethod | null;
/**
* Function used to retrieve session information (and attempt to resume) for a given shard
* @example
* const manager = new WebSocketManager({
* async retrieveSessionInfo(shardId): Awaitable<SessionInfo | null> {
* // Fetch this info from redis or similar
* return { sessionId: string, sequence: number };
* // Return null if no information is found
* },
* });
*/
retrieveSessionInfo: (shardId: number) => Awaitable<SessionInfo | null>;
/**
* Function used to store session information for a given shard
*/
updateSessionInfo: (shardId: number, sessionInfo: SessionInfo | null) => Awaitable<void>;
/**
* How long to wait for a shard to connect before giving up
*/
handshakeTimeout: number | null;
/**
* How long to wait for a shard's HELLO packet before giving up
*/
helloTimeout: number | null;
/**
* How long to wait for a shard's READY packet before giving up
*/
readyTimeout: number | null;
}
export type WebSocketManagerOptions = RequiredWebSocketManagerOptions & OptionalWebSocketManagerOptions;
export type ManagerShardEventsMap = {
[K in keyof WebSocketShardEventsMap]: [
WebSocketShardEventsMap[K] extends [] ? { shardId: number } : WebSocketShardEventsMap[K][0] & { shardId: number },
];
};
export class WebSocketManager extends AsyncEventEmitter<ManagerShardEventsMap> {
/**
* The options being used by this manager
*/
public readonly options: WebSocketManagerOptions;
/**
* Internal cache for a GET /gateway/bot result
*/
private gatewayInformation: {
data: APIGatewayBotInfo;
expiresAt: number;
} | null = null;
/**
* Internal cache for the shard ids
*/
private shardIds: number[] | null = null;
/**
* Strategy used to manage shards
* @default SimpleManagerToShardStrategy
*/
private strategy: IShardingStrategy = new SimpleShardingStrategy(this);
public constructor(options: RequiredWebSocketManagerOptions & Partial<OptionalWebSocketManagerOptions>) {
super();
this.options = { ...DefaultWebSocketManagerOptions, ...options };
}
public setStrategy(strategy: IShardingStrategy) {
this.strategy = strategy;
return this;
}
/**
* Fetches the gateway information from Discord - or returns it from cache if available
* @param force Whether to ignore the cache and force a fresh fetch
*/
public async fetchGatewayInformation(force = false) {
if (this.gatewayInformation) {
if (this.gatewayInformation.expiresAt <= Date.now()) {
this.gatewayInformation = null;
} else if (!force) {
return this.gatewayInformation.data;
}
}
const data = (await this.options.rest.get(Routes.gatewayBot())) as RESTGetAPIGatewayBotResult;
this.gatewayInformation = { data, expiresAt: Date.now() + data.session_start_limit.reset_after };
return this.gatewayInformation.data;
}
/**
* Updates your total shard count on-the-fly, spawning shards as needed
* @param shardCount The new shard count to use
*/
public async updateShardCount(shardCount: number | null) {
await this.strategy.destroy({ reason: 'User is adjusting their shards' });
this.options.shardCount = shardCount;
const shardIds = await this.getShardIds(true);
await this.strategy.spawn(shardIds);
return this;
}
/**
* Yields the total number of shards across for your bot, accounting for Discord recommendations
*/
public async getShardCount(): Promise<number> {
if (this.options.shardCount) {
return this.options.shardCount;
}
const shardIds = await this.getShardIds();
return Math.max(...shardIds) + 1;
}
/**
* Yields the ids of the shards this manager should manage
*/
public async getShardIds(force = false): Promise<number[]> {
if (this.shardIds && !force) {
return this.shardIds;
}
let shardIds: number[];
if (this.options.shardIds) {
if (Array.isArray(this.options.shardIds)) {
shardIds = this.options.shardIds;
} else {
shardIds = range(this.options.shardIds);
}
} else {
const data = await this.fetchGatewayInformation();
shardIds = range({ start: 0, end: (this.options.shardCount ?? data.shards) - 1 });
}
this.shardIds = shardIds;
return shardIds;
}
public async connect() {
const shardCount = await this.getShardCount();
// First, make sure all our shards are spawned
await this.updateShardCount(shardCount);
await this.strategy.connect();
}
public destroy(options?: Omit<WebSocketShardDestroyOptions, 'recover'>) {
return this.strategy.destroy(options);
}
public send(shardId: number, payload: GatewaySendPayload) {
return this.strategy.send(shardId, payload);
}
}

View File

@@ -0,0 +1,549 @@
import { once } from 'node:events';
import { setTimeout } from 'node:timers';
import { setTimeout as sleep } from 'node:timers/promises';
import { TextDecoder } from 'node:util';
import { inflate } from 'node:zlib';
import { Collection } from '@discordjs/collection';
import { AsyncQueue } from '@sapphire/async-queue';
import { AsyncEventEmitter } from '@vladfrangu/async_event_emitter';
import {
GatewayCloseCodes,
GatewayDispatchEvents,
GatewayDispatchPayload,
GatewayIdentifyData,
GatewayOpcodes,
GatewayReceivePayload,
GatewaySendPayload,
} from 'discord-api-types/v10';
import { CONNECTING, OPEN, RawData, WebSocket } from 'ws';
import type { Inflate } from 'zlib-sync';
import type { SessionInfo } from './WebSocketManager';
import type { IContextFetchingStrategy } from '../strategies/context/IContextFetchingStrategy';
import { ImportantGatewayOpcodes } from '../utils/constants';
import { lazy } from '../utils/utils';
const getZlibSync = lazy(() => import('zlib-sync').then((mod) => mod.default).catch(() => null));
export enum WebSocketShardEvents {
Debug = 'debug',
Hello = 'hello',
Ready = 'ready',
Resumed = 'resumed',
Dispatch = 'dispatch',
}
export enum WebSocketShardStatus {
Idle,
Connecting,
Resuming,
Ready,
}
export enum WebSocketShardDestroyRecovery {
Reconnect,
Resume,
}
// eslint-disable-next-line @typescript-eslint/consistent-type-definitions
export type WebSocketShardEventsMap = {
[WebSocketShardEvents.Debug]: [payload: { message: string }];
[WebSocketShardEvents.Hello]: [];
[WebSocketShardEvents.Ready]: [];
[WebSocketShardEvents.Resumed]: [];
[WebSocketShardEvents.Dispatch]: [payload: { data: GatewayDispatchPayload }];
};
export interface WebSocketShardDestroyOptions {
reason?: string;
code?: number;
recover?: WebSocketShardDestroyRecovery;
}
export enum CloseCodes {
Normal = 1000,
Resuming = 4200,
}
export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {
private connection: WebSocket | null = null;
private readonly id: number;
private useIdentifyCompress = false;
private inflate: Inflate | null = null;
private readonly textDecoder = new TextDecoder();
private status: WebSocketShardStatus = WebSocketShardStatus.Idle;
private replayedEvents = 0;
private isAck = true;
private sendRateLimitState = {
remaining: 120,
resetAt: Date.now(),
};
private heartbeatInterval: NodeJS.Timer | null = null;
private lastHeartbeatAt = -1;
private session: SessionInfo | null = null;
private readonly sendQueue = new AsyncQueue();
private readonly timeouts = new Collection<WebSocketShardEvents, NodeJS.Timeout>();
public readonly strategy: IContextFetchingStrategy;
public constructor(strategy: IContextFetchingStrategy, id: number) {
super();
this.strategy = strategy;
this.id = id;
}
public async connect() {
if (this.status !== WebSocketShardStatus.Idle) {
throw new Error("Tried to connect a shard that wasn't idle");
}
const data = this.strategy.options.gatewayInformation;
const { version, encoding, compression } = this.strategy.options;
const params = new URLSearchParams({ v: version, encoding });
if (compression) {
const zlib = await getZlibSync();
if (zlib) {
params.append('compress', compression);
this.inflate = new zlib.Inflate({
chunkSize: 65535,
to: 'string',
});
} else if (!this.useIdentifyCompress) {
this.useIdentifyCompress = true;
console.warn(
'WebSocketShard: Compression is enabled but zlib-sync is not installed, falling back to identify compress',
);
}
}
const url = `${data.url}?${params.toString()}`;
this.debug([`Connecting to ${url}`]);
const connection = new WebSocket(url, { handshakeTimeout: this.strategy.options.handshakeTimeout ?? undefined })
/* eslint-disable @typescript-eslint/no-misused-promises */
.on('message', this.onMessage.bind(this))
.on('error', this.onError.bind(this))
.on('close', this.onClose.bind(this));
/* eslint-enable @typescript-eslint/no-misused-promises */
connection.binaryType = 'arraybuffer';
this.connection = connection;
this.status = WebSocketShardStatus.Connecting;
await this.waitForEvent(WebSocketShardEvents.Hello, this.strategy.options.helloTimeout);
const session = this.session ?? (await this.strategy.retrieveSessionInfo(this.id));
if (session?.shardCount === this.strategy.options.shardCount) {
this.session = session;
await this.resume(session);
} else {
await this.identify();
}
}
public async destroy(options: WebSocketShardDestroyOptions = {}) {
if (this.status === WebSocketShardStatus.Idle) {
this.debug(['Tried to destroy a shard that was idle']);
return;
}
if (!options.code) {
options.code = options.recover === WebSocketShardDestroyRecovery.Resume ? CloseCodes.Resuming : CloseCodes.Normal;
}
this.debug([
'Destroying shard',
`Reason: ${options.reason ?? 'none'}`,
`Code: ${options.code}`,
`Recover: ${options.recover === undefined ? 'none' : WebSocketShardDestroyRecovery[options.recover]!}`,
]);
// Reset state
this.isAck = true;
if (this.heartbeatInterval) {
clearInterval(this.heartbeatInterval);
}
this.lastHeartbeatAt = -1;
// Clear session state if applicable
if (options.recover !== WebSocketShardDestroyRecovery.Resume && this.session) {
this.session = null;
await this.strategy.updateSessionInfo(this.id, null);
}
if (this.connection && (this.connection.readyState === OPEN || this.connection.readyState === CONNECTING)) {
this.connection.close(options.code, options.reason);
}
this.status = WebSocketShardStatus.Idle;
if (options.recover !== undefined) {
return this.connect();
}
}
private async waitForEvent(event: WebSocketShardEvents, timeoutDuration?: number | null) {
this.debug([`Waiting for event ${event} for ${timeoutDuration ? `${timeoutDuration}ms` : 'indefinitely'}`]);
const controller = new AbortController();
const timeout = timeoutDuration ? setTimeout(() => controller.abort(), timeoutDuration).unref() : null;
if (timeout) {
this.timeouts.set(event, timeout);
}
await once(this, event, { signal: controller.signal });
if (timeout) {
clearTimeout(timeout);
this.timeouts.delete(event);
}
}
public async send(payload: GatewaySendPayload) {
if (!this.connection) {
throw new Error("WebSocketShard wasn't connected");
}
if (this.status !== WebSocketShardStatus.Ready && !ImportantGatewayOpcodes.has(payload.op)) {
await once(this, WebSocketShardEvents.Ready);
}
await this.sendQueue.wait();
if (--this.sendRateLimitState.remaining <= 0) {
if (this.sendRateLimitState.resetAt < Date.now()) {
await sleep(Date.now() - this.sendRateLimitState.resetAt);
}
this.sendRateLimitState = {
remaining: 119,
resetAt: Date.now() + 60_000,
};
}
this.sendQueue.shift();
this.connection.send(JSON.stringify(payload));
}
private async identify() {
this.debug([
'Identifying',
`shard id: ${this.id.toString()}`,
`shard count: ${this.strategy.options.shardCount}`,
`intents: ${this.strategy.options.intents}`,
`compression: ${this.inflate ? 'zlib-stream' : this.useIdentifyCompress ? 'identify' : 'none'}`,
]);
const d: GatewayIdentifyData = {
token: this.strategy.options.token,
properties: this.strategy.options.identifyProperties,
intents: this.strategy.options.intents,
compress: this.useIdentifyCompress,
shard: [this.id, this.strategy.options.shardCount],
};
if (this.strategy.options.largeThreshold) {
d.large_threshold = this.strategy.options.largeThreshold;
}
if (this.strategy.options.initialPresence) {
d.presence = this.strategy.options.initialPresence;
}
await this.send({
op: GatewayOpcodes.Identify,
d,
});
await this.waitForEvent(WebSocketShardEvents.Ready, this.strategy.options.readyTimeout);
this.status = WebSocketShardStatus.Ready;
}
private resume(session: SessionInfo) {
this.debug(['Resuming session']);
this.status = WebSocketShardStatus.Resuming;
this.replayedEvents = 0;
return this.send({
op: GatewayOpcodes.Resume,
d: {
token: this.strategy.options.token,
seq: session.sequence,
session_id: session.sessionId,
},
});
}
private async heartbeat(requested = false) {
if (!this.isAck && !requested) {
return this.destroy({ reason: 'Zombie connection', recover: WebSocketShardDestroyRecovery.Resume });
}
await this.send({
op: GatewayOpcodes.Heartbeat,
d: this.session?.sequence ?? null,
});
this.lastHeartbeatAt = Date.now();
this.isAck = false;
}
private async unpackMessage(data: Buffer | ArrayBuffer, isBinary: boolean): Promise<GatewayReceivePayload | null> {
const decompressable = new Uint8Array(data);
// Deal with no compression
if (!isBinary) {
return JSON.parse(this.textDecoder.decode(decompressable)) as GatewayReceivePayload;
}
// Deal with identify compress
if (this.useIdentifyCompress) {
return new Promise((resolve, reject) => {
inflate(decompressable, { chunkSize: 65535 }, (err, result) => {
if (err) {
return reject(err);
}
resolve(JSON.parse(this.textDecoder.decode(result)) as GatewayReceivePayload);
});
});
}
// Deal with gw wide zlib-stream compression
if (this.inflate) {
const l = decompressable.length;
const flush =
l >= 4 &&
decompressable[l - 4] === 0x00 &&
decompressable[l - 3] === 0x00 &&
decompressable[l - 2] === 0xff &&
decompressable[l - 1] === 0xff;
const zlib = (await getZlibSync())!;
this.inflate.push(Buffer.from(decompressable), flush ? zlib.Z_SYNC_FLUSH : zlib.Z_NO_FLUSH);
if (this.inflate.err) {
this.emit('error', `${this.inflate.err}${this.inflate.msg ? `: ${this.inflate.msg}` : ''}`);
}
if (!flush) {
return null;
}
const { result } = this.inflate;
if (!result) {
return null;
}
return JSON.parse(typeof result === 'string' ? result : this.textDecoder.decode(result)) as GatewayReceivePayload;
}
this.debug([
'Received a message we were unable to decompress',
`isBinary: ${isBinary.toString()}`,
`useIdentifyCompress: ${this.useIdentifyCompress.toString()}`,
`inflate: ${Boolean(this.inflate).toString()}`,
]);
return null;
}
private async onMessage(data: RawData, isBinary: boolean) {
const payload = await this.unpackMessage(data as Buffer | ArrayBuffer, isBinary);
if (!payload) {
return;
}
switch (payload.op) {
case GatewayOpcodes.Dispatch: {
if (this.status === WebSocketShardStatus.Ready || this.status === WebSocketShardStatus.Resuming) {
this.emit(WebSocketShardEvents.Dispatch, { data: payload });
}
if (this.status === WebSocketShardStatus.Resuming) {
this.replayedEvents++;
}
switch (payload.t) {
case GatewayDispatchEvents.Ready: {
this.emit(WebSocketShardEvents.Ready);
this.session ??= {
sequence: payload.s,
sessionId: payload.d.session_id,
shardId: this.id,
shardCount: this.strategy.options.shardCount,
};
await this.strategy.updateSessionInfo(this.id, this.session);
break;
}
case GatewayDispatchEvents.Resumed: {
this.status = WebSocketShardStatus.Ready;
this.debug([`Resumed and replayed ${this.replayedEvents} events`]);
this.emit(WebSocketShardEvents.Resumed);
break;
}
default: {
break;
}
}
if (this.session) {
if (payload.s > this.session.sequence) {
this.session.sequence = payload.s;
await this.strategy.updateSessionInfo(this.id, this.session);
}
}
break;
}
case GatewayOpcodes.Heartbeat: {
await this.heartbeat(true);
break;
}
case GatewayOpcodes.Reconnect: {
await this.destroy({
reason: 'Told to reconnect by Discord',
recover: WebSocketShardDestroyRecovery.Resume,
});
break;
}
case GatewayOpcodes.InvalidSession: {
const readyTimeout = this.timeouts.get(WebSocketShardEvents.Ready);
readyTimeout?.refresh();
this.debug([`Invalid session; will attempt to resume: ${payload.d.toString()}`]);
const session = this.session ?? (await this.strategy.retrieveSessionInfo(this.id));
if (payload.d && session) {
await this.resume(session);
} else {
await this.destroy({
reason: 'Invalid session',
recover: WebSocketShardDestroyRecovery.Reconnect,
});
}
break;
}
case GatewayOpcodes.Hello: {
this.emit(WebSocketShardEvents.Hello);
this.debug([`Starting to heartbeat every ${payload.d.heartbeat_interval}ms`]);
this.heartbeatInterval = setInterval(() => void this.heartbeat(), payload.d.heartbeat_interval);
break;
}
case GatewayOpcodes.HeartbeatAck: {
this.isAck = true;
this.debug([`Got heartbeat ack after ${Date.now() - this.lastHeartbeatAt}ms`]);
break;
}
}
}
private onError(err: Error) {
this.emit('error', { err });
}
private async onClose(code: number) {
switch (code) {
case 1000:
case 4200: {
this.debug([`Disconnected normally from code ${code}`]);
break;
}
case GatewayCloseCodes.UnknownError: {
this.debug([`An unknown error occured: ${code}`]);
return this.destroy({ code, recover: WebSocketShardDestroyRecovery.Resume });
}
case GatewayCloseCodes.UnknownOpcode: {
this.debug(['An invalid opcode was sent to Discord.']);
return this.destroy({ code, recover: WebSocketShardDestroyRecovery.Resume });
}
case GatewayCloseCodes.DecodeError: {
this.debug(['An invalid payload was sent to Discord.']);
return this.destroy({ code, recover: WebSocketShardDestroyRecovery.Resume });
}
case GatewayCloseCodes.NotAuthenticated: {
this.debug(['A request was somehow sent before the identify/resume payload.']);
return this.destroy({ code, recover: WebSocketShardDestroyRecovery.Reconnect });
}
case GatewayCloseCodes.AuthenticationFailed: {
throw new Error('Authentication failed');
}
case GatewayCloseCodes.AlreadyAuthenticated: {
this.debug(['More than one auth payload was sent.']);
return this.destroy({ code, recover: WebSocketShardDestroyRecovery.Reconnect });
}
case GatewayCloseCodes.InvalidSeq: {
this.debug(['An invalid sequence was sent.']);
return this.destroy({ code, recover: WebSocketShardDestroyRecovery.Reconnect });
}
case GatewayCloseCodes.RateLimited: {
this.debug(['The WebSocket rate limit has been hit, this should never happen']);
return this.destroy({ code, recover: WebSocketShardDestroyRecovery.Reconnect });
}
case GatewayCloseCodes.SessionTimedOut: {
this.debug(['Session timed out.']);
return this.destroy({ code, recover: WebSocketShardDestroyRecovery.Resume });
}
case GatewayCloseCodes.InvalidShard: {
throw new Error('Invalid shard');
}
case GatewayCloseCodes.ShardingRequired: {
throw new Error('Sharding is required');
}
case GatewayCloseCodes.InvalidAPIVersion: {
throw new Error('Used an invalid API version');
}
case GatewayCloseCodes.InvalidIntents: {
throw new Error('Used invalid intents');
}
case GatewayCloseCodes.DisallowedIntents: {
throw new Error('Used disallowed intents');
}
default: {
this.debug([`The gateway closed with an unexpected code ${code}, attempting to resume.`]);
return this.destroy({ code, recover: WebSocketShardDestroyRecovery.Resume });
}
}
}
private debug(messages: [string, ...string[]]) {
const message = `${messages[0]}${
messages.length > 1
? `\n${messages
.slice(1)
.map((m) => ` ${m}`)
.join('\n')}`
: ''
}`;
this.emit(WebSocketShardEvents.Debug, { message });
}
}