From 5c5a5832b94cd4d371cc99c4f9c3384523dabeeb Mon Sep 17 00:00:00 2001 From: DD Date: Sun, 19 Feb 2023 20:57:31 +0200 Subject: [PATCH] refactor(WebSocketManager): passing in strategy (#9122) * refactor(WebSocketManager): passing in strategy * chore: update tests * chore: requested nits --------- Co-authored-by: kodiakhq[bot] <49736102+kodiakhq[bot]@users.noreply.github.com> --- packages/ws/README.md | 21 ++++++++--------- .../strategy/WorkerShardingStrategy.test.ts | 4 +--- .../ws/__tests__/ws/WebSocketManager.test.ts | 15 ++++++++---- packages/ws/src/utils/constants.ts | 2 ++ packages/ws/src/ws/WebSocketManager.ts | 23 +++++++++++++------ 5 files changed, 39 insertions(+), 26 deletions(-) diff --git a/packages/ws/README.md b/packages/ws/README.md index cfeb39cd1..87b5b2567 100644 --- a/packages/ws/README.md +++ b/packages/ws/README.md @@ -107,12 +107,11 @@ const manager = new WebSocketManager({ intents: 0, rest, shardCount: 6, + // This will cause 3 workers to spawn, 2 shards per each + buildStrategy: (manager) => new WorkerShardingStrategy(manager, { shardsPerWorker: 2 }), + // Or maybe you want all your shards under a single worker + buildStrategy: (manager) => new WorkerShardingStrategy(manager, { shardsPerWorker: 'all' }), }); - -// This will cause 3 workers to spawn, 2 shards per each -manager.setStrategy(new WorkerShardingStrategy(manager, { shardsPerWorker: 2 })); -// Or maybe you want all your shards under a single worker -manager.setStrategy(new WorkerShardingStrategy(manager, { shardsPerWorker: 'all' })); ``` **Note**: By default, this will cause the workers to effectively only be responsible for the WebSocket connection, they simply pass up all the events back to the main process for the manager to emit. If you want to have the workers handle events as well, you can pass in a `workerPath` option to the `WorkerShardingStrategy` constructor: @@ -126,14 +125,12 @@ const manager = new WebSocketManager({ token: process.env.DISCORD_TOKEN, intents: 0, rest, + buildStrategy: (manager) => + new WorkerShardingStrategy(manager, { + shardsPerWorker: 2, + workerPath: './worker.js', + }), }); - -manager.setStrategy( - new WorkerShardingStrategy(manager, { - shardsPerWorker: 2, - workerPath: './worker.js', - }), -); ``` And your `worker.ts` file: diff --git a/packages/ws/__tests__/strategy/WorkerShardingStrategy.test.ts b/packages/ws/__tests__/strategy/WorkerShardingStrategy.test.ts index 8106cb9d6..472f518ca 100644 --- a/packages/ws/__tests__/strategy/WorkerShardingStrategy.test.ts +++ b/packages/ws/__tests__/strategy/WorkerShardingStrategy.test.ts @@ -163,6 +163,7 @@ test('spawn, connect, send a message, session info, and destroy', async () => { shardIds: [0, 1], retrieveSessionInfo: mockRetrieveSessionInfo, updateSessionInfo: mockUpdateSessionInfo, + buildStrategy: (manager) => new WorkerShardingStrategy(manager, { shardsPerWorker: 'all' }), }); const managerEmitSpy = vi.spyOn(manager, 'emit'); @@ -191,9 +192,6 @@ test('spawn, connect, send a message, session info, and destroy', async () => { }, })); - const strategy = new WorkerShardingStrategy(manager, { shardsPerWorker: 'all' }); - manager.setStrategy(strategy); - await manager.connect(); expect(mockConstructor).toHaveBeenCalledWith( expect.stringContaining('defaultWorker.js'), diff --git a/packages/ws/__tests__/ws/WebSocketManager.test.ts b/packages/ws/__tests__/ws/WebSocketManager.test.ts index fad53ae1c..d34a4dd84 100644 --- a/packages/ws/__tests__/ws/WebSocketManager.test.ts +++ b/packages/ws/__tests__/ws/WebSocketManager.test.ts @@ -177,14 +177,21 @@ test('strategies', async () => { public destroy = vi.fn(); public send = vi.fn(); + + public fetchStatus = vi.fn(); } + const strategy = new MockStrategy(); + const rest = new REST().setAgent(mockAgent).setToken('A-Very-Fake-Token'); const shardIds = [0, 1, 2]; - const manager = new WebSocketManager({ token: 'A-Very-Fake-Token', intents: 0, rest, shardIds }); - - const strategy = new MockStrategy(); - manager.setStrategy(strategy); + const manager = new WebSocketManager({ + token: 'A-Very-Fake-Token', + intents: 0, + rest, + shardIds, + buildStrategy: () => strategy, + }); const data: APIGatewayBotInfo = { shards: 1, diff --git a/packages/ws/src/utils/constants.ts b/packages/ws/src/utils/constants.ts index eebeef201..b0208d00f 100644 --- a/packages/ws/src/utils/constants.ts +++ b/packages/ws/src/utils/constants.ts @@ -2,6 +2,7 @@ import process from 'node:process'; import { Collection } from '@discordjs/collection'; import { lazy } from '@discordjs/util'; import { APIVersion, GatewayOpcodes } from 'discord-api-types/v10'; +import { SimpleShardingStrategy } from '../strategies/sharding/SimpleShardingStrategy.js'; import type { SessionInfo, OptionalWebSocketManagerOptions } from '../ws/WebSocketManager.js'; import type { SendRateLimitState } from '../ws/WebSocketShard.js'; @@ -27,6 +28,7 @@ const getDefaultSessionStore = lazy(() => new Collection new SimpleShardingStrategy(manager), shardCount: null, shardIds: null, largeThreshold: null, diff --git a/packages/ws/src/ws/WebSocketManager.ts b/packages/ws/src/ws/WebSocketManager.ts index 5321535c4..a18986b4a 100644 --- a/packages/ws/src/ws/WebSocketManager.ts +++ b/packages/ws/src/ws/WebSocketManager.ts @@ -11,7 +11,6 @@ import { type GatewaySendPayload, } from 'discord-api-types/v10'; import type { IShardingStrategy } from '../strategies/sharding/IShardingStrategy'; -import { SimpleShardingStrategy } from '../strategies/sharding/SimpleShardingStrategy.js'; import { DefaultWebSocketManagerOptions, type CompressionMethod, type Encoding } from '../utils/constants.js'; import type { WebSocketShardDestroyOptions, WebSocketShardEventsMap } from './WebSocketShard.js'; @@ -71,6 +70,20 @@ export interface RequiredWebSocketManagerOptions { * Optional additional configuration for the WebSocketManager */ export interface OptionalWebSocketManagerOptions { + /** + * Builds the strategy to use for sharding + * + * @example + * ```ts + * const manager = new WebSocketManager({ + * token: process.env.DISCORD_TOKEN, + * intents: 0, // for no intents + * rest, + * buildStrategy: (manager) => new WorkerShardingStrategy(manager, { shardsPerWorker: 2 }), + * }); + * ``` + */ + buildStrategy(manager: WebSocketManager): IShardingStrategy; /** * The compression method to use * @@ -192,16 +205,12 @@ export class WebSocketManager extends AsyncEventEmitter { * * @defaultValue `SimpleShardingStrategy` */ - private strategy: IShardingStrategy = new SimpleShardingStrategy(this); + private readonly strategy: IShardingStrategy; public constructor(options: Partial & RequiredWebSocketManagerOptions) { super(); this.options = { ...DefaultWebSocketManagerOptions, ...options }; - } - - public setStrategy(strategy: IShardingStrategy) { - this.strategy = strategy; - return this; + this.strategy = this.options.buildStrategy(this); } /**