feat(ws): custom workers (#9004)

* feat(ws): custom workers

* chore: typo

* refactor(WebSocketShard): expose shard id

* chore: remove outdated readme comment

* chore: nits

* chore: remove unnecessary mutation

* feat: fancier resolution

* chore: remove unnecessary exports

* chore: apply suggestions

* refactor: use range errors

Co-authored-by: Aura Román <kyradiscord@gmail.com>
This commit is contained in:
DD
2023-01-10 19:31:56 +02:00
committed by GitHub
parent 39c4de2dbc
commit 828a13b526
11 changed files with 343 additions and 159 deletions

View File

@@ -8,6 +8,7 @@ export * from './strategies/sharding/WorkerShardingStrategy.js';
export * from './utils/constants.js';
export * from './utils/IdentifyThrottler.js';
export * from './utils/WorkerBootstrapper.js';
export * from './ws/WebSocketManager.js';
export * from './ws/WebSocketShard.js';

View File

@@ -67,7 +67,10 @@ export class SimpleShardingStrategy implements IShardingStrategy {
*/
public async send(shardId: number, payload: GatewaySendPayload) {
const shard = this.shards.get(shardId);
if (!shard) throw new Error(`Shard ${shardId} not found`);
if (!shard) {
throw new RangeError(`Shard ${shardId} not found`);
}
return shard.send(payload);
}

View File

@@ -1,5 +1,5 @@
import { once } from 'node:events';
import { join } from 'node:path';
import { join, isAbsolute, resolve } from 'node:path';
import { Worker } from 'node:worker_threads';
import { Collection } from '@discordjs/collection';
import type { GatewaySendPayload } from 'discord-api-types/v10';
@@ -38,6 +38,7 @@ export enum WorkerRecievePayloadOp {
UpdateSessionInfo,
WaitForIdentify,
FetchStatusResponse,
WorkerReady,
}
export type WorkerRecievePayload =
@@ -48,7 +49,8 @@ export type WorkerRecievePayload =
| { nonce: number; op: WorkerRecievePayloadOp.WaitForIdentify }
| { op: WorkerRecievePayloadOp.Connected; shardId: number }
| { op: WorkerRecievePayloadOp.Destroyed; shardId: number }
| { op: WorkerRecievePayloadOp.UpdateSessionInfo; session: SessionInfo | null; shardId: number };
| { op: WorkerRecievePayloadOp.UpdateSessionInfo; session: SessionInfo | null; shardId: number }
| { op: WorkerRecievePayloadOp.WorkerReady };
/**
* Options for a {@link WorkerShardingStrategy}
@@ -58,6 +60,10 @@ export interface WorkerShardingStrategyOptions {
* Dictates how many shards should be spawned per worker thread.
*/
shardsPerWorker: number | 'all';
/**
* Path to the worker file to use. The worker requires quite a bit of setup, it is recommended you leverage the {@link WorkerBootstrapper} class.
*/
workerPath?: string;
}
/**
@@ -93,32 +99,20 @@ export class WorkerShardingStrategy implements IShardingStrategy {
const shardsPerWorker = this.options.shardsPerWorker === 'all' ? shardIds.length : this.options.shardsPerWorker;
const strategyOptions = await managerToFetchingStrategyOptions(this.manager);
let shards = 0;
while (shards !== shardIds.length) {
const slice = shardIds.slice(shards, shardsPerWorker + shards);
const loops = Math.ceil(shardIds.length / shardsPerWorker);
const promises: Promise<void>[] = [];
for (let idx = 0; idx < loops; idx++) {
const slice = shardIds.slice(idx * shardsPerWorker, (idx + 1) * shardsPerWorker);
const workerData: WorkerData = {
...strategyOptions,
shardIds: slice,
};
const worker = new Worker(join(__dirname, 'worker.js'), { workerData });
await once(worker, 'online');
worker
.on('error', (err) => {
throw err;
})
.on('messageerror', (err) => {
throw err;
})
.on('message', async (payload: WorkerRecievePayload) => this.onMessage(worker, payload));
this.#workers.push(worker);
for (const shardId of slice) {
this.#workerByShardId.set(shardId, worker);
}
shards += slice.length;
promises.push(this.setupWorker(workerData));
}
await Promise.all(promises);
}
/**
@@ -210,6 +204,63 @@ export class WorkerShardingStrategy implements IShardingStrategy {
return statuses;
}
private async setupWorker(workerData: WorkerData) {
const worker = new Worker(this.resolveWorkerPath(), { workerData });
await once(worker, 'online');
// We do this in case the user has any potentially long running code in their worker
await this.waitForWorkerReady(worker);
worker
.on('error', (err) => {
throw err;
})
.on('messageerror', (err) => {
throw err;
})
.on('message', async (payload: WorkerRecievePayload) => this.onMessage(worker, payload));
this.#workers.push(worker);
for (const shardId of workerData.shardIds) {
this.#workerByShardId.set(shardId, worker);
}
}
private resolveWorkerPath(): string {
const path = this.options.workerPath;
if (!path) {
return join(__dirname, 'defaultWorker.js');
}
if (isAbsolute(path)) {
return path;
}
if (/^\.\.?[/\\]/.test(path)) {
return resolve(path);
}
try {
return require.resolve(path);
} catch {
return resolve(path);
}
}
private async waitForWorkerReady(worker: Worker): Promise<void> {
return new Promise((resolve) => {
const handler = (payload: WorkerRecievePayload) => {
if (payload.op === WorkerRecievePayloadOp.WorkerReady) {
resolve();
worker.off('message', handler);
}
};
worker.on('message', handler);
});
}
private async onMessage(worker: Worker, payload: WorkerRecievePayload) {
switch (payload.op) {
case WorkerRecievePayloadOp.Connected: {
@@ -260,6 +311,10 @@ export class WorkerShardingStrategy implements IShardingStrategy {
this.fetchStatusPromises.delete(payload.nonce);
break;
}
case WorkerRecievePayloadOp.WorkerReady: {
break;
}
}
}
}

View File

@@ -0,0 +1,4 @@
import { WorkerBootstrapper } from '../../utils/WorkerBootstrapper.js';
const bootstrapper = new WorkerBootstrapper();
void bootstrapper.bootstrap();

View File

@@ -1,117 +0,0 @@
import { isMainThread, workerData, parentPort } from 'node:worker_threads';
import { Collection } from '@discordjs/collection';
import { WebSocketShard, WebSocketShardEvents, type WebSocketShardDestroyOptions } from '../../ws/WebSocketShard.js';
import { WorkerContextFetchingStrategy } from '../context/WorkerContextFetchingStrategy.js';
import {
WorkerRecievePayloadOp,
WorkerSendPayloadOp,
type WorkerData,
type WorkerRecievePayload,
type WorkerSendPayload,
} from './WorkerShardingStrategy.js';
if (isMainThread) {
throw new Error('Expected worker script to not be ran within the main thread');
}
const data = workerData as WorkerData;
const shards = new Collection<number, WebSocketShard>();
async function connect(shardId: number) {
const shard = shards.get(shardId);
if (!shard) {
throw new Error(`Shard ${shardId} does not exist`);
}
await shard.connect();
}
async function destroy(shardId: number, options?: WebSocketShardDestroyOptions) {
const shard = shards.get(shardId);
if (!shard) {
throw new Error(`Shard ${shardId} does not exist`);
}
await shard.destroy(options);
}
for (const shardId of data.shardIds) {
const shard = new WebSocketShard(new WorkerContextFetchingStrategy(data), shardId);
for (const event of Object.values(WebSocketShardEvents)) {
// @ts-expect-error: Event types incompatible
shard.on(event, (data) => {
const payload = {
op: WorkerRecievePayloadOp.Event,
event,
data,
shardId,
} satisfies WorkerRecievePayload;
parentPort!.postMessage(payload);
});
}
shards.set(shardId, shard);
}
parentPort!
.on('messageerror', (err) => {
throw err;
})
.on('message', async (payload: WorkerSendPayload) => {
switch (payload.op) {
case WorkerSendPayloadOp.Connect: {
await connect(payload.shardId);
const response: WorkerRecievePayload = {
op: WorkerRecievePayloadOp.Connected,
shardId: payload.shardId,
};
parentPort!.postMessage(response);
break;
}
case WorkerSendPayloadOp.Destroy: {
await destroy(payload.shardId, payload.options);
const response: WorkerRecievePayload = {
op: WorkerRecievePayloadOp.Destroyed,
shardId: payload.shardId,
};
parentPort!.postMessage(response);
break;
}
case WorkerSendPayloadOp.Send: {
const shard = shards.get(payload.shardId);
if (!shard) {
throw new Error(`Shard ${payload.shardId} does not exist`);
}
await shard.send(payload.payload);
break;
}
case WorkerSendPayloadOp.SessionInfoResponse: {
break;
}
case WorkerSendPayloadOp.ShardCanIdentify: {
break;
}
case WorkerSendPayloadOp.FetchStatus: {
const shard = shards.get(payload.shardId);
if (!shard) {
throw new Error(`Shard ${payload.shardId} does not exist`);
}
const response = {
op: WorkerRecievePayloadOp.FetchStatusResponse,
status: shard.status,
nonce: payload.nonce,
} satisfies WorkerRecievePayload;
parentPort!.postMessage(response);
break;
}
}
});

View File

@@ -0,0 +1,176 @@
import { isMainThread, parentPort, workerData } from 'node:worker_threads';
import { Collection } from '@discordjs/collection';
import type { Awaitable } from '@discordjs/util';
import { WorkerContextFetchingStrategy } from '../strategies/context/WorkerContextFetchingStrategy.js';
import {
WorkerRecievePayloadOp,
WorkerSendPayloadOp,
type WorkerData,
type WorkerRecievePayload,
type WorkerSendPayload,
} from '../strategies/sharding/WorkerShardingStrategy.js';
import type { WebSocketShardDestroyOptions } from '../ws/WebSocketShard.js';
import { WebSocketShardEvents, WebSocketShard } from '../ws/WebSocketShard.js';
/**
* Options for bootstrapping the worker
*/
export interface BootstrapOptions {
/**
* Shard events to just arbitrarily forward to the parent thread for the manager to emit
* Note: By default, this will include ALL events
* you most likely want to handle dispatch within the worker itself
*/
forwardEvents?: WebSocketShardEvents[];
/**
* Function to call when a shard is created for additional setup
*/
shardCallback?(shard: WebSocketShard): Awaitable<void>;
}
/**
* Utility class for bootstraping a worker thread to be used for sharding
*/
export class WorkerBootstrapper {
/**
* The data passed to the worker thread
*/
protected readonly data = workerData as WorkerData;
/**
* The shards that are managed by this worker
*/
protected readonly shards = new Collection<number, WebSocketShard>();
public constructor() {
if (isMainThread) {
throw new Error('Expected WorkerBootstrap to not be used within the main thread');
}
}
/**
* Helper method to initiate a shard's connection process
*/
protected async connect(shardId: number): Promise<void> {
const shard = this.shards.get(shardId);
if (!shard) {
throw new RangeError(`Shard ${shardId} does not exist`);
}
await shard.connect();
}
/**
* Helper method to destroy a shard
*/
protected async destroy(shardId: number, options?: WebSocketShardDestroyOptions): Promise<void> {
const shard = this.shards.get(shardId);
if (!shard) {
throw new RangeError(`Shard ${shardId} does not exist`);
}
await shard.destroy(options);
}
/**
* Helper method to attach event listeners to the parentPort
*/
protected setupThreadEvents(): void {
parentPort!
.on('messageerror', (err) => {
throw err;
})
.on('message', async (payload: WorkerSendPayload) => {
switch (payload.op) {
case WorkerSendPayloadOp.Connect: {
await this.connect(payload.shardId);
const response: WorkerRecievePayload = {
op: WorkerRecievePayloadOp.Connected,
shardId: payload.shardId,
};
parentPort!.postMessage(response);
break;
}
case WorkerSendPayloadOp.Destroy: {
await this.destroy(payload.shardId, payload.options);
const response: WorkerRecievePayload = {
op: WorkerRecievePayloadOp.Destroyed,
shardId: payload.shardId,
};
parentPort!.postMessage(response);
break;
}
case WorkerSendPayloadOp.Send: {
const shard = this.shards.get(payload.shardId);
if (!shard) {
throw new RangeError(`Shard ${payload.shardId} does not exist`);
}
await shard.send(payload.payload);
break;
}
case WorkerSendPayloadOp.SessionInfoResponse: {
break;
}
case WorkerSendPayloadOp.ShardCanIdentify: {
break;
}
case WorkerSendPayloadOp.FetchStatus: {
const shard = this.shards.get(payload.shardId);
if (!shard) {
throw new Error(`Shard ${payload.shardId} does not exist`);
}
const response = {
op: WorkerRecievePayloadOp.FetchStatusResponse,
status: shard.status,
nonce: payload.nonce,
} satisfies WorkerRecievePayload;
parentPort!.postMessage(response);
break;
}
}
});
}
/**
* Bootstraps the worker thread with the provided options
*/
public async bootstrap(options: Readonly<BootstrapOptions> = {}): Promise<void> {
// Start by initializing the shards
for (const shardId of this.data.shardIds) {
const shard = new WebSocketShard(new WorkerContextFetchingStrategy(this.data), shardId);
for (const event of options.forwardEvents ?? Object.values(WebSocketShardEvents)) {
// @ts-expect-error: Event types incompatible
shard.on(event, (data) => {
const payload = {
op: WorkerRecievePayloadOp.Event,
event,
data,
shardId,
} satisfies WorkerRecievePayload;
parentPort!.postMessage(payload);
});
}
// Any additional setup the user might want to do
await options.shardCallback?.(shard);
this.shards.set(shardId, shard);
}
// Lastly, start listening to messages from the parent thread
this.setupThreadEvents();
const message = {
op: WorkerRecievePayloadOp.WorkerReady,
} satisfies WorkerRecievePayload;
parentPort!.postMessage(message);
}
}

View File

@@ -81,8 +81,6 @@ export interface SendRateLimitState {
export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {
private connection: WebSocket | null = null;
private readonly id: number;
private useIdentifyCompress = false;
private inflate: Inflate | null = null;
@@ -105,7 +103,9 @@ export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {
private readonly timeouts = new Collection<WebSocketShardEvents, NodeJS.Timeout>();
public readonly strategy: IContextFetchingStrategy;
private readonly strategy: IContextFetchingStrategy;
public readonly id: number;
#status: WebSocketShardStatus = WebSocketShardStatus.Idle;