diff --git a/packages/ws/README.md b/packages/ws/README.md index 9da53cd4f..fa171deed 100644 --- a/packages/ws/README.md +++ b/packages/ws/README.md @@ -42,13 +42,16 @@ bun add @discordjs/ws ```ts import { WebSocketManager, WebSocketShardEvents, CompressionMethod } from '@discordjs/ws'; import { REST } from '@discordjs/rest'; +import type { RESTGetAPIGatewayBotResult } from 'discord-api-types/v10'; const rest = new REST().setToken(process.env.DISCORD_TOKEN); // This example will spawn Discord's recommended shard count, all under the current process. const manager = new WebSocketManager({ token: process.env.DISCORD_TOKEN, intents: 0, // for no intents - rest, + fetchGatewayInformation() { + return rest.get(Routes.gatewayBot()) as Promise; + }, // uncomment if you have zlib-sync installed and want to use compression // compression: CompressionMethod.ZlibSync, @@ -70,8 +73,10 @@ await manager.connect(); const manager = new WebSocketManager({ token: process.env.DISCORD_TOKEN, intents: 0, - rest, shardCount: 4, + fetchGatewayInformation() { + return rest.get(Routes.gatewayBot()) as Promise; + }, }); // The manager also supports being responsible for only a subset of your shards: @@ -81,21 +86,25 @@ const manager = new WebSocketManager({ const manager = new WebSocketManager({ token: process.env.DISCORD_TOKEN, intents: 0, - rest, shardCount: 8, shardIds: [0, 2, 4, 6], + fetchGatewayInformation() { + return rest.get(Routes.gatewayBot()) as Promise; + }, }); // Alternatively, if your shards are consecutive, you can pass in a range const manager = new WebSocketManager({ token: process.env.DISCORD_TOKEN, intents: 0, - rest, shardCount: 8, shardIds: { start: 0, end: 4, }, + fetchGatewayInformation() { + return rest.get(Routes.gatewayBot()) as Promise; + }, }); ``` @@ -111,8 +120,10 @@ const rest = new REST().setToken(process.env.DISCORD_TOKEN); const manager = new WebSocketManager({ token: process.env.DISCORD_TOKEN, intents: 0, - rest, shardCount: 6, + fetchGatewayInformation() { + return rest.get(Routes.gatewayBot()) as Promise; + }, // This will cause 3 workers to spawn, 2 shards per each buildStrategy: (manager) => new WorkerShardingStrategy(manager, { shardsPerWorker: 2 }), // Or maybe you want all your shards under a single worker @@ -130,7 +141,9 @@ const rest = new REST().setToken(process.env.DISCORD_TOKEN); const manager = new WebSocketManager({ token: process.env.DISCORD_TOKEN, intents: 0, - rest, + fetchGatewayInformation() { + return rest.get(Routes.gatewayBot()) as Promise; + }, buildStrategy: (manager) => new WorkerShardingStrategy(manager, { shardsPerWorker: 2, diff --git a/packages/ws/__tests__/strategy/WorkerShardingStrategy.test.ts b/packages/ws/__tests__/strategy/WorkerShardingStrategy.test.ts index 4a96408d3..6d547aabd 100644 --- a/packages/ws/__tests__/strategy/WorkerShardingStrategy.test.ts +++ b/packages/ws/__tests__/strategy/WorkerShardingStrategy.test.ts @@ -1,12 +1,8 @@ /* eslint-disable id-length */ import { setImmediate } from 'node:timers'; import { REST } from '@discordjs/rest'; -import { - GatewayDispatchEvents, - GatewayOpcodes, - type GatewayDispatchPayload, - type GatewaySendPayload, -} from 'discord-api-types/v10'; +import type { RESTGetAPIGatewayBotResult, GatewayDispatchPayload, GatewaySendPayload } from 'discord-api-types/v10'; +import { GatewayDispatchEvents, GatewayOpcodes, Routes } from 'discord-api-types/v10'; import { MockAgent, type Interceptable } from 'undici'; import { beforeEach, test, vi, expect, afterEach } from 'vitest'; import { @@ -159,7 +155,9 @@ test('spawn, connect, send a message, session info, and destroy', async () => { const manager = new WebSocketManager({ token: 'A-Very-Fake-Token', intents: 0, - rest, + async fetchGatewayInformation() { + return rest.get(Routes.gatewayBot()) as Promise; + }, shardIds: [0, 1], retrieveSessionInfo: mockRetrieveSessionInfo, updateSessionInfo: mockUpdateSessionInfo, diff --git a/packages/ws/__tests__/ws/WebSocketManager.test.ts b/packages/ws/__tests__/ws/WebSocketManager.test.ts index 877e339ca..16419ebb3 100644 --- a/packages/ws/__tests__/ws/WebSocketManager.test.ts +++ b/packages/ws/__tests__/ws/WebSocketManager.test.ts @@ -1,5 +1,6 @@ import { REST } from '@discordjs/rest'; -import { GatewayOpcodes, type APIGatewayBotInfo, type GatewaySendPayload } from 'discord-api-types/v10'; +import type { RESTGetAPIGatewayBotResult, APIGatewayBotInfo, GatewaySendPayload } from 'discord-api-types/v10'; +import { GatewayOpcodes, Routes } from 'discord-api-types/v10'; import { MockAgent, type Interceptable } from 'undici'; import { beforeEach, describe, expect, test, vi } from 'vitest'; import { WebSocketManager, type IShardingStrategy } from '../../src/index.js'; @@ -20,7 +21,13 @@ global.Date.now = NOW; test('fetch gateway information', async () => { const rest = new REST().setAgent(mockAgent).setToken('A-Very-Fake-Token'); - const manager = new WebSocketManager({ token: 'A-Very-Fake-Token', intents: 0, rest }); + const manager = new WebSocketManager({ + token: 'A-Very-Fake-Token', + intents: 0, + async fetchGatewayInformation() { + return rest.get(Routes.gatewayBot()) as Promise; + }, + }); const data: APIGatewayBotInfo = { shards: 1, @@ -89,7 +96,14 @@ test('fetch gateway information', async () => { describe('get shard count', () => { test('with shard count', async () => { const rest = new REST().setAgent(mockAgent).setToken('A-Very-Fake-Token'); - const manager = new WebSocketManager({ token: 'A-Very-Fake-Token', intents: 0, rest, shardCount: 2 }); + const manager = new WebSocketManager({ + token: 'A-Very-Fake-Token', + intents: 0, + shardCount: 2, + async fetchGatewayInformation() { + return rest.get(Routes.gatewayBot()) as Promise; + }, + }); expect(await manager.getShardCount()).toBe(2); }); @@ -97,7 +111,14 @@ describe('get shard count', () => { test('with shard ids array', async () => { const rest = new REST().setAgent(mockAgent).setToken('A-Very-Fake-Token'); const shardIds = [5, 9]; - const manager = new WebSocketManager({ token: 'A-Very-Fake-Token', intents: 0, rest, shardIds }); + const manager = new WebSocketManager({ + token: 'A-Very-Fake-Token', + intents: 0, + shardIds, + async fetchGatewayInformation() { + return rest.get(Routes.gatewayBot()) as Promise; + }, + }); expect(await manager.getShardCount()).toBe(shardIds.at(-1)! + 1); }); @@ -105,7 +126,14 @@ describe('get shard count', () => { test('with shard id range', async () => { const rest = new REST().setAgent(mockAgent).setToken('A-Very-Fake-Token'); const shardIds = { start: 5, end: 9 }; - const manager = new WebSocketManager({ token: 'A-Very-Fake-Token', intents: 0, rest, shardIds }); + const manager = new WebSocketManager({ + token: 'A-Very-Fake-Token', + intents: 0, + shardIds, + async fetchGatewayInformation() { + return rest.get(Routes.gatewayBot()) as Promise; + }, + }); expect(await manager.getShardCount()).toBe(shardIds.end + 1); }); @@ -113,7 +141,14 @@ describe('get shard count', () => { test('update shard count', async () => { const rest = new REST().setAgent(mockAgent).setToken('A-Very-Fake-Token'); - const manager = new WebSocketManager({ token: 'A-Very-Fake-Token', intents: 0, rest, shardCount: 2 }); + const manager = new WebSocketManager({ + token: 'A-Very-Fake-Token', + intents: 0, + shardCount: 2, + async fetchGatewayInformation() { + return rest.get(Routes.gatewayBot()) as Promise; + }, + }); const data: APIGatewayBotInfo = { shards: 1, @@ -162,7 +197,15 @@ test('update shard count', async () => { test('it handles passing in both shardIds and shardCount', async () => { const rest = new REST().setAgent(mockAgent).setToken('A-Very-Fake-Token'); const shardIds = { start: 2, end: 3 }; - const manager = new WebSocketManager({ token: 'A-Very-Fake-Token', intents: 0, rest, shardIds, shardCount: 4 }); + const manager = new WebSocketManager({ + token: 'A-Very-Fake-Token', + intents: 0, + shardIds, + shardCount: 4, + async fetchGatewayInformation() { + return rest.get(Routes.gatewayBot()) as Promise; + }, + }); expect(await manager.getShardCount()).toBe(4); expect(await manager.getShardIds()).toStrictEqual([2, 3]); diff --git a/packages/ws/src/utils/constants.ts b/packages/ws/src/utils/constants.ts index 9b41994be..413daecf4 100644 --- a/packages/ws/src/utils/constants.ts +++ b/packages/ws/src/utils/constants.ts @@ -68,7 +68,7 @@ export const DefaultWebSocketManagerOptions = { handshakeTimeout: 30_000, helloTimeout: 60_000, readyTimeout: 15_000, -} as const satisfies Omit; +} as const satisfies Omit; export const ImportantGatewayOpcodes = new Set([ GatewayOpcodes.Heartbeat, diff --git a/packages/ws/src/ws/WebSocketManager.ts b/packages/ws/src/ws/WebSocketManager.ts index 1f369ab99..ba73b43be 100644 --- a/packages/ws/src/ws/WebSocketManager.ts +++ b/packages/ws/src/ws/WebSocketManager.ts @@ -63,10 +63,6 @@ export interface RequiredWebSocketManagerOptions { * The intents to request */ intents: GatewayIntentBits | 0; - /** - * The REST instance to use for fetching gateway information - */ - rest: REST; } /** @@ -103,6 +99,21 @@ export interface OptionalWebSocketManagerOptions { * @defaultValue `'json'` */ encoding: Encoding; + /** + * Fetches the initial gateway URL used to connect to Discord. When missing, this will default to the gateway URL + * that Discord returns from the `/gateway/bot` route. + * + * @example + * ```ts + * const manager = new WebSocketManager({ + * token: process.env.DISCORD_TOKEN, + * fetchGatewayInformation() { + * return rest.get(Routes.gatewayBot()); + * }, + * }) + * ``` + */ + fetchGatewayInformation(): Awaitable; /** * How long to wait for a shard to connect before giving up */ @@ -127,6 +138,12 @@ export interface OptionalWebSocketManagerOptions { * How long to wait for a shard's READY packet before giving up */ readyTimeout: number | null; + /** + * The REST instance to use for fetching gateway information + * + * @deprecated Providing a REST instance is deprecated. Provide the `fetchGatewayInformation` function instead. + */ + rest?: REST; /** * Function used to retrieve session information (and attempt to resume) for a given shard * @@ -257,8 +274,24 @@ export class WebSocketManager extends AsyncEventEmitter i } public constructor(options: CreateWebSocketManagerOptions) { + if (!options.rest && !options.fetchGatewayInformation) { + throw new RangeError('Either a REST instance or a fetchGatewayInformation function must be provided'); + } + super(); - this.options = { ...DefaultWebSocketManagerOptions, ...options }; + this.options = { + ...DefaultWebSocketManagerOptions, + fetchGatewayInformation: + options.fetchGatewayInformation ?? + (async () => { + if (!options.rest) { + throw new RangeError('A REST instance must be provided if no fetchGatewayInformation function is provided'); + } + + return options.rest.get(Routes.gatewayBot()) as Promise; + }), + ...options, + }; this.strategy = this.options.buildStrategy(this); this.#token = options.token ?? null; } @@ -277,7 +310,7 @@ export class WebSocketManager extends AsyncEventEmitter i } } - const data = (await this.options.rest.get(Routes.gatewayBot())) as RESTGetAPIGatewayBotResult; + const data = await this.options.fetchGatewayInformation(); // For single sharded bots session_start_limit.reset_after will be 0, use 5 seconds as a minimum expiration time this.gatewayInformation = { data, expiresAt: Date.now() + (data.session_start_limit.reset_after || 5_000) };