feat(rest): use undici (#7747)

Co-authored-by: Vlad Frangu <kingdgrizzle@gmail.com>
Co-authored-by: ckohen <chaikohen@gmail.com>
This commit is contained in:
Khafra
2022-05-12 16:49:15 -04:00
committed by GitHub
parent 4515a1ea80
commit d1ec8c37ff
19 changed files with 964 additions and 605 deletions

View File

@@ -1,7 +1,6 @@
import { EventEmitter } from 'node:events';
import type { AgentOptions } from 'node:https';
import type Collection from '@discordjs/collection';
import type { RequestInit, Response } from 'node-fetch';
import type { request, Dispatcher } from 'undici';
import { CDN } from './CDN';
import {
HandlerRequestData,
@@ -20,10 +19,9 @@ import { DefaultRestOptions, RESTEvents } from './utils/constants';
*/
export interface RESTOptions {
/**
* HTTPS Agent options
* @default {}
* The agent to set globally
*/
agent: Omit<AgentOptions, 'keepAlive'>;
agent: Dispatcher;
/**
* The base api path, without version
* @default 'https://discord.com/api'
@@ -169,7 +167,7 @@ export interface APIRequest {
/**
* Additional HTTP options for this request
*/
options: RequestInit;
options: RequestOptions;
/**
* The data that was used to form the body of this request
*/
@@ -195,8 +193,7 @@ export interface RestEvents {
invalidRequestWarning: [invalidRequestInfo: InvalidRequestWarningData];
restDebug: [info: string];
rateLimited: [rateLimitInfo: RateLimitData];
request: [request: APIRequest];
response: [request: APIRequest, response: Response];
response: [request: APIRequest, response: Dispatcher.ResponseData];
newListener: [name: string, listener: (...args: any) => void];
removeListener: [name: string, listener: (...args: any) => void];
hashSweep: [sweptHashes: Collection<string, HashData>];
@@ -220,6 +217,8 @@ export interface REST {
(<S extends string | symbol>(event?: Exclude<S, keyof RestEvents>) => this);
}
export type RequestOptions = Exclude<Parameters<typeof request>[1], undefined>;
export class REST extends EventEmitter {
public readonly cdn: CDN;
public readonly requestManager: RequestManager;
@@ -234,13 +233,29 @@ export class REST extends EventEmitter {
.on(RESTEvents.HashSweep, this.emit.bind(this, RESTEvents.HashSweep));
this.on('newListener', (name, listener) => {
if (name === RESTEvents.Request || name === RESTEvents.Response) this.requestManager.on(name, listener);
if (name === RESTEvents.Response) this.requestManager.on(name, listener);
});
this.on('removeListener', (name, listener) => {
if (name === RESTEvents.Request || name === RESTEvents.Response) this.requestManager.off(name, listener);
if (name === RESTEvents.Response) this.requestManager.off(name, listener);
});
}
/**
* Gets the agent set for this instance
*/
public getAgent() {
return this.requestManager.agent;
}
/**
* Sets the default agent to use for requests performed by this instance
* @param agent Sets the agent to use
*/
public setAgent(agent: Dispatcher) {
this.requestManager.setAgent(agent);
return this;
}
/**
* Sets the authorization token that should be used for requests
* @param token The authorization token to use

View File

@@ -1,14 +1,13 @@
import { Blob } from 'node:buffer';
import { EventEmitter } from 'node:events';
import { Agent as httpAgent } from 'node:http';
import { Agent as httpsAgent } from 'node:https';
import Collection from '@discordjs/collection';
import { DiscordSnowflake } from '@sapphire/snowflake';
import FormData from 'form-data';
import type { RequestInit, BodyInit } from 'node-fetch';
import type { RESTOptions, RestEvents } from './REST';
import { FormData, type RequestInit, type BodyInit, type Dispatcher, Agent } from 'undici';
import type { RESTOptions, RestEvents, RequestOptions } from './REST';
import type { IHandler } from './handlers/IHandler';
import { SequentialHandler } from './handlers/SequentialHandler';
import { DefaultRestOptions, DefaultUserAgent, RESTEvents } from './utils/constants';
import { resolveBody } from './utils/utils';
/**
* Represents a file to be added to the request
@@ -53,6 +52,10 @@ export interface RequestData {
* If providing as BodyInit, set `passThroughBody: true`
*/
body?: BodyInit | unknown;
/**
* The {@link https://undici.nodejs.org/#/docs/api/Agent Agent} to use for the request.
*/
dispatcher?: Agent;
/**
* Files to be attached to this request
*/
@@ -94,11 +97,11 @@ export interface RequestHeaders {
* Possible API methods to be used when doing requests
*/
export const enum RequestMethod {
Delete = 'delete',
Get = 'get',
Patch = 'patch',
Post = 'post',
Put = 'put',
Delete = 'DELETE',
Get = 'GET',
Patch = 'PATCH',
Post = 'POST',
Put = 'PUT',
}
export type RouteLike = `/${string}`;
@@ -157,6 +160,11 @@ export interface RequestManager {
* Represents the class that manages handlers for endpoints
*/
export class RequestManager extends EventEmitter {
/**
* The {@link https://undici.nodejs.org/#/docs/api/Agent Agent} for all requests
* performed by this manager.
*/
public agent: Dispatcher | null = null;
/**
* The number of requests remaining in the global bucket
*/
@@ -187,7 +195,6 @@ export class RequestManager extends EventEmitter {
private hashTimer!: NodeJS.Timer;
private handlerTimer!: NodeJS.Timer;
private agent: httpsAgent | httpAgent | null = null;
public readonly options: RESTOptions;
@@ -196,6 +203,7 @@ export class RequestManager extends EventEmitter {
this.options = { ...DefaultRestOptions, ...options };
this.options.offset = Math.max(0, this.options.offset);
this.globalRemaining = this.options.globalRequestsPerSecond;
this.agent = options.agent ?? null;
// Start sweepers
this.setupSweepers();
@@ -263,6 +271,15 @@ export class RequestManager extends EventEmitter {
}
}
/**
* Sets the default agent to use for requests performed by this manager
* @param agent The agent to use
*/
public setAgent(agent: Dispatcher) {
this.agent = agent;
return this;
}
/**
* Sets the authorization token that should be used for requests
* @param token The authorization token to use
@@ -291,8 +308,8 @@ export class RequestManager extends EventEmitter {
this.handlers.get(`${hash.value}:${routeId.majorParameter}`) ??
this.createHandler(hash.value, routeId.majorParameter);
// Resolve the request into usable fetch/node-fetch options
const { url, fetchOptions } = this.resolveRequest(request);
// Resolve the request into usable fetch options
const { url, fetchOptions } = await this.resolveRequest(request);
// Queue the request
return handler.queueRequest(routeId, url, fetchOptions, {
@@ -321,13 +338,9 @@ export class RequestManager extends EventEmitter {
* Formats the request data to a usable format for fetch
* @param request The request data
*/
private resolveRequest(request: InternalRequest): { url: string; fetchOptions: RequestInit } {
private async resolveRequest(request: InternalRequest): Promise<{ url: string; fetchOptions: RequestOptions }> {
const { options } = this;
this.agent ??= options.api.startsWith('https')
? new httpsAgent({ ...options.agent, keepAlive: true })
: new httpAgent({ ...options.agent, keepAlive: true });
let query = '';
// If a query option is passed, use it
@@ -372,7 +385,18 @@ export class RequestManager extends EventEmitter {
// Attach all files to the request
for (const [index, file] of request.files.entries()) {
formData.append(file.key ?? `files[${index}]`, file.data, file.name);
const fileKey = file.key ?? `files[${index}]`;
// https://developer.mozilla.org/en-US/docs/Web/API/FormData/append#parameters
// FormData.append only accepts a string or Blob.
// https://developer.mozilla.org/en-US/docs/Web/API/Blob/Blob#parameters
// The Blob constructor accepts TypedArray/ArrayBuffer, strings, and Blobs.
if (Buffer.isBuffer(file.data) || typeof file.data === 'string') {
formData.append(fileKey, new Blob([file.data]), file.name);
} else {
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
formData.append(fileKey, new Blob([`${file.data}`]), file.name);
}
}
// If a JSON body was added as well, attach it to the form data, using payload_json unless otherwise specified
@@ -389,8 +413,6 @@ export class RequestManager extends EventEmitter {
// Set the final body to the form data
finalBody = formData;
// Set the additional headers to the form data ones
additionalHeaders = formData.getHeaders();
// eslint-disable-next-line no-eq-null
} else if (request.body != null) {
@@ -404,14 +426,21 @@ export class RequestManager extends EventEmitter {
}
}
const fetchOptions = {
agent: this.agent,
body: finalBody,
finalBody = await resolveBody(finalBody);
const fetchOptions: RequestOptions = {
// eslint-disable-next-line @typescript-eslint/consistent-type-assertions
headers: { ...(request.headers ?? {}), ...additionalHeaders, ...headers } as Record<string, string>,
method: request.method,
method: request.method.toUpperCase() as Dispatcher.HttpMethod,
};
if (finalBody !== undefined) {
fetchOptions.body = finalBody as Exclude<RequestOptions['body'], undefined>;
}
// Prioritize setting an agent per request, use the agent for this instance otherwise.
fetchOptions.dispatcher = request.dispatcher ?? this.agent ?? undefined!;
return { url, fetchOptions };
}

View File

@@ -8,7 +8,6 @@ export class HTTPError extends Error {
public requestBody: RequestBody;
/**
* @param message The error message
* @param name The name of the error
* @param status The status code of the response
* @param method The method of the request that erred
@@ -16,14 +15,13 @@ export class HTTPError extends Error {
* @param bodyData The unparsed data for the request that errored
*/
public constructor(
message: string,
public override name: string,
public status: number,
public method: string,
public url: string,
bodyData: Pick<InternalRequest, 'files' | 'body'>,
) {
super(message);
super();
this.requestBody = { files: bodyData.files, json: bodyData.body };
}

View File

@@ -1,11 +1,11 @@
import type { RequestInit } from 'node-fetch';
import type { RequestOptions } from '../REST';
import type { HandlerRequestData, RouteData } from '../RequestManager';
export interface IHandler {
queueRequest: (
routeId: RouteData,
url: string,
options: RequestInit,
options: RequestOptions,
requestData: HandlerRequestData,
) => Promise<unknown>;
// eslint-disable-next-line @typescript-eslint/method-signature-style -- This is meant to be a getter returning a bool

View File

@@ -1,14 +1,14 @@
import { setTimeout as sleep } from 'node:timers/promises';
import { AsyncQueue } from '@sapphire/async-queue';
import fetch, { RequestInit, Response } from 'node-fetch';
import { request, type Dispatcher } from 'undici';
import type { IHandler } from './IHandler';
import type { RateLimitData } from '../REST';
import type { RateLimitData, RequestOptions } from '../REST';
import type { HandlerRequestData, RequestManager, RouteData } from '../RequestManager';
import { DiscordAPIError, DiscordErrorData, OAuthErrorData } from '../errors/DiscordAPIError';
import { HTTPError } from '../errors/HTTPError';
import { RateLimitError } from '../errors/RateLimitError';
import { RESTEvents } from '../utils/constants';
import { hasSublimit, parseResponse } from '../utils/utils';
import { hasSublimit, parseHeader, parseResponse } from '../utils/utils';
/* Invalid request limiting is done on a per-IP basis, not a per-token basis.
* The best we can do is track invalid counts process-wide (on the theory that
@@ -168,7 +168,7 @@ export class SequentialHandler implements IHandler {
public async queueRequest(
routeId: RouteData,
url: string,
options: RequestInit,
options: RequestOptions,
requestData: HandlerRequestData,
): Promise<unknown> {
let queue = this.#asyncQueue;
@@ -218,14 +218,14 @@ export class SequentialHandler implements IHandler {
* The method that actually makes the request to the api, and updates info about the bucket accordingly
* @param routeId The generalized api route with literal ids for major parameters
* @param url The fully resolved url to make the request to
* @param options The node-fetch options needed to make the request
* @param options The fetch options needed to make the request
* @param requestData Extra data from the user's request needed for errors and additional processing
* @param retries The number of retries this request has already attempted (recursion)
*/
private async runRequest(
routeId: RouteData,
url: string,
options: RequestInit,
options: RequestOptions,
requestData: HandlerRequestData,
retries = 0,
): Promise<unknown> {
@@ -287,26 +287,12 @@ export class SequentialHandler implements IHandler {
const method = options.method ?? 'get';
if (this.manager.listenerCount(RESTEvents.Request)) {
this.manager.emit(RESTEvents.Request, {
method,
path: routeId.original,
route: routeId.bucketRoute,
options,
data: requestData,
retries,
});
}
const controller = new AbortController();
const timeout = setTimeout(() => controller.abort(), this.manager.options.timeout).unref();
let res: Response;
let res: Dispatcher.ResponseData;
try {
// node-fetch typings are a bit weird, so we have to cast to any to get the correct signature
// Type 'AbortSignal' is not assignable to type 'import("discord.js-modules/node_modules/@types/node-fetch/externals").AbortSignal'
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
res = await fetch(url, { ...options, signal: controller.signal as any });
res = await request(url, { ...options, signal: controller.signal });
} catch (error: unknown) {
// Retry the specified number of times for possible timed out requests
if (error instanceof Error && error.name === 'AbortError' && retries !== this.manager.options.retries) {
@@ -329,17 +315,18 @@ export class SequentialHandler implements IHandler {
data: requestData,
retries,
},
res.clone(),
{ ...res },
);
}
const status = res.statusCode;
let retryAfter = 0;
const limit = res.headers.get('X-RateLimit-Limit');
const remaining = res.headers.get('X-RateLimit-Remaining');
const reset = res.headers.get('X-RateLimit-Reset-After');
const hash = res.headers.get('X-RateLimit-Bucket');
const retry = res.headers.get('Retry-After');
const limit = parseHeader(res.headers['x-ratelimit-limit']);
const remaining = parseHeader(res.headers['x-ratelimit-remaining']);
const reset = parseHeader(res.headers['x-ratelimit-reset-after']);
const hash = parseHeader(res.headers['x-ratelimit-bucket']);
const retry = parseHeader(res.headers['retry-after']);
// Update the total number of requests that can be made before the rate limit resets
this.limit = limit ? Number(limit) : Infinity;
@@ -371,7 +358,7 @@ export class SequentialHandler implements IHandler {
// Handle retryAfter, which means we have actually hit a rate limit
let sublimitTimeout: number | null = null;
if (retryAfter > 0) {
if (res.headers.get('X-RateLimit-Global')) {
if (res.headers['x-ratelimit-global'] !== undefined) {
this.manager.globalRemaining = 0;
this.manager.globalReset = Date.now() + retryAfter;
} else if (!this.localLimited) {
@@ -385,7 +372,7 @@ export class SequentialHandler implements IHandler {
}
// Count the invalid requests
if (res.status === 401 || res.status === 403 || res.status === 429) {
if (status === 401 || status === 403 || status === 429) {
if (!invalidCountResetTime || invalidCountResetTime < Date.now()) {
invalidCountResetTime = Date.now() + 1000 * 60 * 10;
invalidCount = 0;
@@ -404,9 +391,9 @@ export class SequentialHandler implements IHandler {
}
}
if (res.ok) {
if (status === 200) {
return parseResponse(res);
} else if (res.status === 429) {
} else if (status === 429) {
// A rate limit was hit - this may happen if the route isn't associated with an official bucket hash yet, or when first globally rate limited
const isGlobal = this.globalLimited;
let limit: number;
@@ -468,24 +455,24 @@ export class SequentialHandler implements IHandler {
}
// Since this is not a server side issue, the next request should pass, so we don't bump the retries counter
return this.runRequest(routeId, url, options, requestData, retries);
} else if (res.status >= 500 && res.status < 600) {
} else if (status >= 500 && status < 600) {
// Retry the specified number of times for possible server side issues
if (retries !== this.manager.options.retries) {
return this.runRequest(routeId, url, options, requestData, ++retries);
}
// We are out of retries, throw an error
throw new HTTPError(res.statusText, res.constructor.name, res.status, method, url, requestData);
throw new HTTPError(res.constructor.name, status, method, url, requestData);
} else {
// Handle possible malformed requests
if (res.status >= 400 && res.status < 500) {
if (status >= 400 && status < 500) {
// If we receive this status code, it means the token we had is no longer valid.
if (res.status === 401 && requestData.auth) {
if (status === 401 && requestData.auth) {
this.manager.setToken(null!);
}
// The request will not succeed for some reason, parse the error returned from the api
const data = (await parseResponse(res)) as DiscordErrorData | OAuthErrorData;
// throw the API error
throw new DiscordAPIError(data, 'code' in data ? data.code : data.error, res.status, method, url, requestData);
throw new DiscordAPIError(data, 'code' in data ? data.code : data.error, status, method, url, requestData);
}
return null;
}

View File

@@ -1,4 +1,5 @@
import { APIVersion } from 'discord-api-types/v10';
import { getGlobalDispatcher } from 'undici';
import type { RESTOptions } from '../REST';
// eslint-disable-next-line @typescript-eslint/no-var-requires, @typescript-eslint/no-require-imports, @typescript-eslint/no-unsafe-assignment
const Package = require('../../../package.json');
@@ -7,7 +8,9 @@ const Package = require('../../../package.json');
export const DefaultUserAgent = `DiscordBot (${Package.homepage}, ${Package.version})`;
export const DefaultRestOptions: Required<RESTOptions> = {
agent: {},
get agent() {
return getGlobalDispatcher();
},
api: 'https://discord.com/api',
authPrefix: 'Bot',
cdn: 'https://cdn.discordapp.com',
@@ -32,7 +35,6 @@ export const enum RESTEvents {
Debug = 'restDebug',
InvalidRequestWarning = 'invalidRequestWarning',
RateLimited = 'rateLimited',
Request = 'request',
Response = 'response',
HashSweep = 'hashSweep',
HandlerSweep = 'handlerSweep',

View File

@@ -1,7 +1,21 @@
import { Blob } from 'node:buffer';
import { URLSearchParams } from 'node:url';
import { types } from 'node:util';
import type { RESTPatchAPIChannelJSONBody } from 'discord-api-types/v10';
import type { Response } from 'node-fetch';
import { FormData, type Dispatcher, type RequestInit } from 'undici';
import type { RequestOptions } from '../REST';
import { RequestMethod } from '../RequestManager';
export function parseHeader(header: string | string[] | undefined): string | undefined {
if (header === undefined) {
return header;
} else if (typeof header === 'string') {
return header;
}
return header.join(';');
}
function serializeSearchParam(value: unknown): string | null {
switch (typeof value) {
case 'string':
@@ -43,14 +57,15 @@ export function makeURLSearchParams(options?: Record<string, unknown>) {
/**
* Converts the response to usable data
* @param res The node-fetch response
* @param res The fetch response
*/
export function parseResponse(res: Response): Promise<unknown> {
if (res.headers.get('Content-Type')?.startsWith('application/json')) {
return res.json();
export function parseResponse(res: Dispatcher.ResponseData): Promise<unknown> {
const header = parseHeader(res.headers['content-type']);
if (header?.startsWith('application/json')) {
return res.body.json();
}
return res.arrayBuffer();
return res.body.arrayBuffer();
}
/**
@@ -75,3 +90,48 @@ export function hasSublimit(bucketRoute: string, body?: unknown, method?: string
// If we are checking if a request has a sublimit on a route not checked above, sublimit all requests to avoid a flood of 429s
return true;
}
export async function resolveBody(body: RequestInit['body']): Promise<RequestOptions['body']> {
// eslint-disable-next-line no-eq-null
if (body == null) {
return null;
} else if (typeof body === 'string') {
return body;
} else if (types.isUint8Array(body)) {
return body;
} else if (types.isArrayBuffer(body)) {
return new Uint8Array(body);
} else if (body instanceof URLSearchParams) {
return body.toString();
} else if (body instanceof DataView) {
return new Uint8Array(body.buffer);
} else if (body instanceof Blob) {
return new Uint8Array(await body.arrayBuffer());
} else if (body instanceof FormData) {
return body;
// eslint-disable-next-line @typescript-eslint/no-unnecessary-condition
} else if ((body as Iterable<Uint8Array>)[Symbol.iterator]) {
const chunks = [...(body as Iterable<Uint8Array>)];
const length = chunks.reduce((a, b) => a + b.length, 0);
const uint8 = new Uint8Array(length);
let lengthUsed = 0;
return chunks.reduce((a, b) => {
a.set(b, lengthUsed);
lengthUsed += b.length;
return a;
}, uint8);
// eslint-disable-next-line @typescript-eslint/no-unnecessary-condition
} else if ((body as AsyncIterable<Uint8Array>)[Symbol.asyncIterator]) {
const chunks: Uint8Array[] = [];
for await (const chunk of body as AsyncIterable<Uint8Array>) {
chunks.push(chunk);
}
return Buffer.concat(chunks);
}
throw new TypeError(`Unable to resolve body.`);
}