feat(ws): custom workers (#9004)

* feat(ws): custom workers

* chore: typo

* refactor(WebSocketShard): expose shard id

* chore: remove outdated readme comment

* chore: nits

* chore: remove unnecessary mutation

* feat: fancier resolution

* chore: remove unnecessary exports

* chore: apply suggestions

* refactor: use range errors

Co-authored-by: Aura Román <kyradiscord@gmail.com>
This commit is contained in:
DD
2023-01-10 19:31:56 +02:00
committed by GitHub
parent 39c4de2dbc
commit 828a13b526
11 changed files with 343 additions and 159 deletions

View File

@@ -99,7 +99,9 @@ You can also have the shards spawn in worker threads:
```ts
import { WebSocketManager, WorkerShardingStrategy } from '@discordjs/ws';
import { REST } from '@discordjs/rest';
const rest = new REST().setToken(process.env.DISCORD_TOKEN);
const manager = new WebSocketManager({
token: process.env.DISCORD_TOKEN,
intents: 0,
@@ -113,6 +115,51 @@ manager.setStrategy(new WorkerShardingStrategy(manager, { shardsPerWorker: 2 }))
manager.setStrategy(new WorkerShardingStrategy(manager, { shardsPerWorker: 'all' }));
```
**Note**: By default, this will cause the workers to effectively only be responsible for the WebSocket connection, they simply pass up all the events back to the main process for the manager to emit. If you want to have the workers handle events as well, you can pass in a `workerPath` option to the `WorkerShardingStrategy` constructor:
```ts
import { WebSocketManager, WorkerShardingStrategy } from '@discordjs/ws';
import { REST } from '@discordjs/rest';
const rest = new REST().setToken(process.env.DISCORD_TOKEN);
const manager = new WebSocketManager({
token: process.env.DISCORD_TOKEN,
intents: 0,
rest,
});
manager.setStrategy(
new WorkerShardingStrategy(manager, {
shardsPerWorker: 2,
workerPath: './worker.js',
}),
);
```
And your `worker.ts` file:
```ts
import { WorkerBootstrapper, WebSocketShardEvents } from '@discordjs/ws';
const bootstrapper = new WorkerBootstrapper();
void bootstrapper.bootstrap({
// Those will be sent to the main thread for the manager to emit
forwardEvents: [
WebSocketShardEvents.Closed,
WebSocketShardEvents.Debug,
WebSocketShardEvents.Hello,
WebSocketShardEvents.Ready,
WebSocketShardEvents.Resumed,
],
shardCallback: (shard) => {
shard.on(WebSocketShardEvents.Dispatch, (event) => {
// Process gateway events here however you want (e.g. send them through a message broker)
// You also have access to shard.id if you need it
});
},
});
```
## Links
- [Website][website] ([source][website-source])

View File

@@ -53,45 +53,54 @@ vi.mock('node:worker_threads', async () => {
super();
mockConstructor(...args);
// need to delay this by an event loop cycle to allow the strategy to attach a listener
setImmediate(() => this.emit('online'));
setImmediate(() => {
this.emit('online');
// same deal here
setImmediate(() => {
const message = {
op: WorkerRecievePayloadOp.WorkerReady,
} satisfies WorkerRecievePayload;
this.emit('message', message);
});
});
}
public postMessage(message: WorkerSendPayload) {
switch (message.op) {
case WorkerSendPayloadOp.Connect: {
const response: WorkerRecievePayload = {
const response = {
op: WorkerRecievePayloadOp.Connected,
shardId: message.shardId,
};
} satisfies WorkerRecievePayload;
this.emit('message', response);
break;
}
case WorkerSendPayloadOp.Destroy: {
const response: WorkerRecievePayload = {
const response = {
op: WorkerRecievePayloadOp.Destroyed,
shardId: message.shardId,
};
} satisfies WorkerRecievePayload;
this.emit('message', response);
break;
}
case WorkerSendPayloadOp.Send: {
if (message.payload.op === GatewayOpcodes.RequestGuildMembers) {
const response: WorkerRecievePayload = {
const response = {
op: WorkerRecievePayloadOp.Event,
shardId: message.shardId,
event: WebSocketShardEvents.Dispatch,
data: memberChunkData,
};
} satisfies WorkerRecievePayload;
this.emit('message', response);
// Fetch session info
const sessionFetch: WorkerRecievePayload = {
const sessionFetch = {
op: WorkerRecievePayloadOp.RetrieveSessionInfo,
shardId: message.shardId,
nonce: Math.random(),
};
} satisfies WorkerRecievePayload;
this.emit('message', sessionFetch);
}
@@ -102,11 +111,11 @@ vi.mock('node:worker_threads', async () => {
case WorkerSendPayloadOp.SessionInfoResponse: {
message.session ??= sessionInfo;
const session: WorkerRecievePayload = {
const session = {
op: WorkerRecievePayloadOp.UpdateSessionInfo,
shardId: message.session.shardId,
session: { ...message.session, sequence: message.session.sequence + 1 },
};
} satisfies WorkerRecievePayload;
this.emit('message', session);
break;
}
@@ -186,7 +195,7 @@ test('spawn, connect, send a message, session info, and destroy', async () => {
await manager.connect();
expect(mockConstructor).toHaveBeenCalledWith(
expect.stringContaining('worker.js'),
expect.stringContaining('defaultWorker.js'),
expect.objectContaining({ workerData: expect.objectContaining({ shardIds: [0, 1] }) }),
);

View File

@@ -16,9 +16,15 @@
"module": "./dist/index.mjs",
"typings": "./dist/index.d.ts",
"exports": {
"import": "./dist/index.mjs",
"require": "./dist/index.js",
"types": "./dist/index.d.ts"
".": {
"import": "./dist/index.mjs",
"require": "./dist/index.js",
"types": "./dist/index.d.ts"
},
"./defaultWorker": {
"import": "./dist/defaultWorker.mjs",
"require": "./dist/defaultWorker.js"
}
},
"directories": {
"lib": "src",

View File

@@ -8,6 +8,7 @@ export * from './strategies/sharding/WorkerShardingStrategy.js';
export * from './utils/constants.js';
export * from './utils/IdentifyThrottler.js';
export * from './utils/WorkerBootstrapper.js';
export * from './ws/WebSocketManager.js';
export * from './ws/WebSocketShard.js';

View File

@@ -67,7 +67,10 @@ export class SimpleShardingStrategy implements IShardingStrategy {
*/
public async send(shardId: number, payload: GatewaySendPayload) {
const shard = this.shards.get(shardId);
if (!shard) throw new Error(`Shard ${shardId} not found`);
if (!shard) {
throw new RangeError(`Shard ${shardId} not found`);
}
return shard.send(payload);
}

View File

@@ -1,5 +1,5 @@
import { once } from 'node:events';
import { join } from 'node:path';
import { join, isAbsolute, resolve } from 'node:path';
import { Worker } from 'node:worker_threads';
import { Collection } from '@discordjs/collection';
import type { GatewaySendPayload } from 'discord-api-types/v10';
@@ -38,6 +38,7 @@ export enum WorkerRecievePayloadOp {
UpdateSessionInfo,
WaitForIdentify,
FetchStatusResponse,
WorkerReady,
}
export type WorkerRecievePayload =
@@ -48,7 +49,8 @@ export type WorkerRecievePayload =
| { nonce: number; op: WorkerRecievePayloadOp.WaitForIdentify }
| { op: WorkerRecievePayloadOp.Connected; shardId: number }
| { op: WorkerRecievePayloadOp.Destroyed; shardId: number }
| { op: WorkerRecievePayloadOp.UpdateSessionInfo; session: SessionInfo | null; shardId: number };
| { op: WorkerRecievePayloadOp.UpdateSessionInfo; session: SessionInfo | null; shardId: number }
| { op: WorkerRecievePayloadOp.WorkerReady };
/**
* Options for a {@link WorkerShardingStrategy}
@@ -58,6 +60,10 @@ export interface WorkerShardingStrategyOptions {
* Dictates how many shards should be spawned per worker thread.
*/
shardsPerWorker: number | 'all';
/**
* Path to the worker file to use. The worker requires quite a bit of setup, it is recommended you leverage the {@link WorkerBootstrapper} class.
*/
workerPath?: string;
}
/**
@@ -93,32 +99,20 @@ export class WorkerShardingStrategy implements IShardingStrategy {
const shardsPerWorker = this.options.shardsPerWorker === 'all' ? shardIds.length : this.options.shardsPerWorker;
const strategyOptions = await managerToFetchingStrategyOptions(this.manager);
let shards = 0;
while (shards !== shardIds.length) {
const slice = shardIds.slice(shards, shardsPerWorker + shards);
const loops = Math.ceil(shardIds.length / shardsPerWorker);
const promises: Promise<void>[] = [];
for (let idx = 0; idx < loops; idx++) {
const slice = shardIds.slice(idx * shardsPerWorker, (idx + 1) * shardsPerWorker);
const workerData: WorkerData = {
...strategyOptions,
shardIds: slice,
};
const worker = new Worker(join(__dirname, 'worker.js'), { workerData });
await once(worker, 'online');
worker
.on('error', (err) => {
throw err;
})
.on('messageerror', (err) => {
throw err;
})
.on('message', async (payload: WorkerRecievePayload) => this.onMessage(worker, payload));
this.#workers.push(worker);
for (const shardId of slice) {
this.#workerByShardId.set(shardId, worker);
}
shards += slice.length;
promises.push(this.setupWorker(workerData));
}
await Promise.all(promises);
}
/**
@@ -210,6 +204,63 @@ export class WorkerShardingStrategy implements IShardingStrategy {
return statuses;
}
private async setupWorker(workerData: WorkerData) {
const worker = new Worker(this.resolveWorkerPath(), { workerData });
await once(worker, 'online');
// We do this in case the user has any potentially long running code in their worker
await this.waitForWorkerReady(worker);
worker
.on('error', (err) => {
throw err;
})
.on('messageerror', (err) => {
throw err;
})
.on('message', async (payload: WorkerRecievePayload) => this.onMessage(worker, payload));
this.#workers.push(worker);
for (const shardId of workerData.shardIds) {
this.#workerByShardId.set(shardId, worker);
}
}
private resolveWorkerPath(): string {
const path = this.options.workerPath;
if (!path) {
return join(__dirname, 'defaultWorker.js');
}
if (isAbsolute(path)) {
return path;
}
if (/^\.\.?[/\\]/.test(path)) {
return resolve(path);
}
try {
return require.resolve(path);
} catch {
return resolve(path);
}
}
private async waitForWorkerReady(worker: Worker): Promise<void> {
return new Promise((resolve) => {
const handler = (payload: WorkerRecievePayload) => {
if (payload.op === WorkerRecievePayloadOp.WorkerReady) {
resolve();
worker.off('message', handler);
}
};
worker.on('message', handler);
});
}
private async onMessage(worker: Worker, payload: WorkerRecievePayload) {
switch (payload.op) {
case WorkerRecievePayloadOp.Connected: {
@@ -260,6 +311,10 @@ export class WorkerShardingStrategy implements IShardingStrategy {
this.fetchStatusPromises.delete(payload.nonce);
break;
}
case WorkerRecievePayloadOp.WorkerReady: {
break;
}
}
}
}

View File

@@ -0,0 +1,4 @@
import { WorkerBootstrapper } from '../../utils/WorkerBootstrapper.js';
const bootstrapper = new WorkerBootstrapper();
void bootstrapper.bootstrap();

View File

@@ -1,117 +0,0 @@
import { isMainThread, workerData, parentPort } from 'node:worker_threads';
import { Collection } from '@discordjs/collection';
import { WebSocketShard, WebSocketShardEvents, type WebSocketShardDestroyOptions } from '../../ws/WebSocketShard.js';
import { WorkerContextFetchingStrategy } from '../context/WorkerContextFetchingStrategy.js';
import {
WorkerRecievePayloadOp,
WorkerSendPayloadOp,
type WorkerData,
type WorkerRecievePayload,
type WorkerSendPayload,
} from './WorkerShardingStrategy.js';
if (isMainThread) {
throw new Error('Expected worker script to not be ran within the main thread');
}
const data = workerData as WorkerData;
const shards = new Collection<number, WebSocketShard>();
async function connect(shardId: number) {
const shard = shards.get(shardId);
if (!shard) {
throw new Error(`Shard ${shardId} does not exist`);
}
await shard.connect();
}
async function destroy(shardId: number, options?: WebSocketShardDestroyOptions) {
const shard = shards.get(shardId);
if (!shard) {
throw new Error(`Shard ${shardId} does not exist`);
}
await shard.destroy(options);
}
for (const shardId of data.shardIds) {
const shard = new WebSocketShard(new WorkerContextFetchingStrategy(data), shardId);
for (const event of Object.values(WebSocketShardEvents)) {
// @ts-expect-error: Event types incompatible
shard.on(event, (data) => {
const payload = {
op: WorkerRecievePayloadOp.Event,
event,
data,
shardId,
} satisfies WorkerRecievePayload;
parentPort!.postMessage(payload);
});
}
shards.set(shardId, shard);
}
parentPort!
.on('messageerror', (err) => {
throw err;
})
.on('message', async (payload: WorkerSendPayload) => {
switch (payload.op) {
case WorkerSendPayloadOp.Connect: {
await connect(payload.shardId);
const response: WorkerRecievePayload = {
op: WorkerRecievePayloadOp.Connected,
shardId: payload.shardId,
};
parentPort!.postMessage(response);
break;
}
case WorkerSendPayloadOp.Destroy: {
await destroy(payload.shardId, payload.options);
const response: WorkerRecievePayload = {
op: WorkerRecievePayloadOp.Destroyed,
shardId: payload.shardId,
};
parentPort!.postMessage(response);
break;
}
case WorkerSendPayloadOp.Send: {
const shard = shards.get(payload.shardId);
if (!shard) {
throw new Error(`Shard ${payload.shardId} does not exist`);
}
await shard.send(payload.payload);
break;
}
case WorkerSendPayloadOp.SessionInfoResponse: {
break;
}
case WorkerSendPayloadOp.ShardCanIdentify: {
break;
}
case WorkerSendPayloadOp.FetchStatus: {
const shard = shards.get(payload.shardId);
if (!shard) {
throw new Error(`Shard ${payload.shardId} does not exist`);
}
const response = {
op: WorkerRecievePayloadOp.FetchStatusResponse,
status: shard.status,
nonce: payload.nonce,
} satisfies WorkerRecievePayload;
parentPort!.postMessage(response);
break;
}
}
});

View File

@@ -0,0 +1,176 @@
import { isMainThread, parentPort, workerData } from 'node:worker_threads';
import { Collection } from '@discordjs/collection';
import type { Awaitable } from '@discordjs/util';
import { WorkerContextFetchingStrategy } from '../strategies/context/WorkerContextFetchingStrategy.js';
import {
WorkerRecievePayloadOp,
WorkerSendPayloadOp,
type WorkerData,
type WorkerRecievePayload,
type WorkerSendPayload,
} from '../strategies/sharding/WorkerShardingStrategy.js';
import type { WebSocketShardDestroyOptions } from '../ws/WebSocketShard.js';
import { WebSocketShardEvents, WebSocketShard } from '../ws/WebSocketShard.js';
/**
* Options for bootstrapping the worker
*/
export interface BootstrapOptions {
/**
* Shard events to just arbitrarily forward to the parent thread for the manager to emit
* Note: By default, this will include ALL events
* you most likely want to handle dispatch within the worker itself
*/
forwardEvents?: WebSocketShardEvents[];
/**
* Function to call when a shard is created for additional setup
*/
shardCallback?(shard: WebSocketShard): Awaitable<void>;
}
/**
* Utility class for bootstraping a worker thread to be used for sharding
*/
export class WorkerBootstrapper {
/**
* The data passed to the worker thread
*/
protected readonly data = workerData as WorkerData;
/**
* The shards that are managed by this worker
*/
protected readonly shards = new Collection<number, WebSocketShard>();
public constructor() {
if (isMainThread) {
throw new Error('Expected WorkerBootstrap to not be used within the main thread');
}
}
/**
* Helper method to initiate a shard's connection process
*/
protected async connect(shardId: number): Promise<void> {
const shard = this.shards.get(shardId);
if (!shard) {
throw new RangeError(`Shard ${shardId} does not exist`);
}
await shard.connect();
}
/**
* Helper method to destroy a shard
*/
protected async destroy(shardId: number, options?: WebSocketShardDestroyOptions): Promise<void> {
const shard = this.shards.get(shardId);
if (!shard) {
throw new RangeError(`Shard ${shardId} does not exist`);
}
await shard.destroy(options);
}
/**
* Helper method to attach event listeners to the parentPort
*/
protected setupThreadEvents(): void {
parentPort!
.on('messageerror', (err) => {
throw err;
})
.on('message', async (payload: WorkerSendPayload) => {
switch (payload.op) {
case WorkerSendPayloadOp.Connect: {
await this.connect(payload.shardId);
const response: WorkerRecievePayload = {
op: WorkerRecievePayloadOp.Connected,
shardId: payload.shardId,
};
parentPort!.postMessage(response);
break;
}
case WorkerSendPayloadOp.Destroy: {
await this.destroy(payload.shardId, payload.options);
const response: WorkerRecievePayload = {
op: WorkerRecievePayloadOp.Destroyed,
shardId: payload.shardId,
};
parentPort!.postMessage(response);
break;
}
case WorkerSendPayloadOp.Send: {
const shard = this.shards.get(payload.shardId);
if (!shard) {
throw new RangeError(`Shard ${payload.shardId} does not exist`);
}
await shard.send(payload.payload);
break;
}
case WorkerSendPayloadOp.SessionInfoResponse: {
break;
}
case WorkerSendPayloadOp.ShardCanIdentify: {
break;
}
case WorkerSendPayloadOp.FetchStatus: {
const shard = this.shards.get(payload.shardId);
if (!shard) {
throw new Error(`Shard ${payload.shardId} does not exist`);
}
const response = {
op: WorkerRecievePayloadOp.FetchStatusResponse,
status: shard.status,
nonce: payload.nonce,
} satisfies WorkerRecievePayload;
parentPort!.postMessage(response);
break;
}
}
});
}
/**
* Bootstraps the worker thread with the provided options
*/
public async bootstrap(options: Readonly<BootstrapOptions> = {}): Promise<void> {
// Start by initializing the shards
for (const shardId of this.data.shardIds) {
const shard = new WebSocketShard(new WorkerContextFetchingStrategy(this.data), shardId);
for (const event of options.forwardEvents ?? Object.values(WebSocketShardEvents)) {
// @ts-expect-error: Event types incompatible
shard.on(event, (data) => {
const payload = {
op: WorkerRecievePayloadOp.Event,
event,
data,
shardId,
} satisfies WorkerRecievePayload;
parentPort!.postMessage(payload);
});
}
// Any additional setup the user might want to do
await options.shardCallback?.(shard);
this.shards.set(shardId, shard);
}
// Lastly, start listening to messages from the parent thread
this.setupThreadEvents();
const message = {
op: WorkerRecievePayloadOp.WorkerReady,
} satisfies WorkerRecievePayload;
parentPort!.postMessage(message);
}
}

View File

@@ -81,8 +81,6 @@ export interface SendRateLimitState {
export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {
private connection: WebSocket | null = null;
private readonly id: number;
private useIdentifyCompress = false;
private inflate: Inflate | null = null;
@@ -105,7 +103,9 @@ export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {
private readonly timeouts = new Collection<WebSocketShardEvents, NodeJS.Timeout>();
public readonly strategy: IContextFetchingStrategy;
private readonly strategy: IContextFetchingStrategy;
public readonly id: number;
#status: WebSocketShardStatus = WebSocketShardStatus.Idle;

View File

@@ -4,7 +4,7 @@ import { createTsupConfig } from '../../tsup.config.js';
export default createTsupConfig({
entry: {
index: 'src/index.ts',
worker: 'src/strategies/sharding/worker.ts',
defaultWorker: 'src/strategies/sharding/defaultWorker.ts',
},
external: ['zlib-sync'],
esbuildPlugins: [esbuildPluginVersionInjector()],