fix(RequestHandler): only reset tokens for authenticated 401s (#7508)

This commit is contained in:
Vlad Frangu
2022-03-06 21:43:12 +02:00
committed by GitHub
parent c12d61a342
commit b9ff7b0573
6 changed files with 50 additions and 25 deletions

View File

@@ -245,7 +245,7 @@ test('Request and Response Events', async () => {
method: 'get', method: 'get',
path: '/request', path: '/request',
route: '/request', route: '/request',
data: { files: undefined, body: undefined }, data: { files: undefined, body: undefined, auth: true },
retries: 0, retries: 0,
}) as APIRequest, }) as APIRequest,
); );
@@ -254,7 +254,7 @@ test('Request and Response Events', async () => {
method: 'get', method: 'get',
path: '/request', path: '/request',
route: '/request', route: '/request',
data: { files: undefined, body: undefined }, data: { files: undefined, body: undefined, auth: true },
retries: 0, retries: 0,
}) as APIRequest, }) as APIRequest,
expect.objectContaining({ status: 200, statusText: 'OK' }) as Response, expect.objectContaining({ status: 200, statusText: 'OK' }) as Response,

View File

@@ -357,9 +357,19 @@ test('Bad Request', async () => {
}); });
test('Unauthorized', async () => { test('Unauthorized', async () => {
const setTokenSpy = jest.spyOn(invalidAuthApi.requestManager, 'setToken');
// Ensure authless requests don't reset the token
const promiseWithoutTokenClear = invalidAuthApi.get('/unauthorized', { auth: false });
await expect(promiseWithoutTokenClear).rejects.toThrowError('401: Unauthorized');
await expect(promiseWithoutTokenClear).rejects.toBeInstanceOf(DiscordAPIError);
expect(setTokenSpy).not.toHaveBeenCalled();
// Ensure authed requests do reset the token
const promise = invalidAuthApi.get('/unauthorized'); const promise = invalidAuthApi.get('/unauthorized');
await expect(promise).rejects.toThrowError('401: Unauthorized'); await expect(promise).rejects.toThrowError('401: Unauthorized');
await expect(promise).rejects.toBeInstanceOf(DiscordAPIError); await expect(promise).rejects.toBeInstanceOf(DiscordAPIError);
expect(setTokenSpy).toHaveBeenCalledTimes(1);
}); });
test('Reject on RateLimit', async () => { test('Reject on RateLimit', async () => {

View File

@@ -1,6 +1,13 @@
import { EventEmitter } from 'node:events'; import { EventEmitter } from 'node:events';
import { CDN } from './CDN'; import { CDN } from './CDN';
import { InternalRequest, RequestData, RequestManager, RequestMethod, RouteLike } from './RequestManager'; import {
HandlerRequestData,
InternalRequest,
RequestData,
RequestManager,
RequestMethod,
RouteLike,
} from './RequestManager';
import { DefaultRestOptions, RESTEvents } from './utils/constants'; import { DefaultRestOptions, RESTEvents } from './utils/constants';
import type { AgentOptions } from 'node:https'; import type { AgentOptions } from 'node:https';
import type { RequestInit, Response } from 'node-fetch'; import type { RequestInit, Response } from 'node-fetch';
@@ -160,7 +167,7 @@ export interface APIRequest {
/** /**
* The data that was used to form the body of this request * The data that was used to form the body of this request
*/ */
data: Pick<InternalRequest, 'files' | 'body'>; data: HandlerRequestData;
/** /**
* The number of times this request has been attempted * The number of times this request has been attempted
*/ */

View File

@@ -113,6 +113,8 @@ export interface InternalRequest extends RequestData {
fullRoute: RouteLike; fullRoute: RouteLike;
} }
export type HandlerRequestData = Pick<InternalRequest, 'files' | 'body' | 'auth'>;
/** /**
* Parsed route data for an endpoint * Parsed route data for an endpoint
* *
@@ -293,7 +295,11 @@ export class RequestManager extends EventEmitter {
const { url, fetchOptions } = this.resolveRequest(request); const { url, fetchOptions } = this.resolveRequest(request);
// Queue the request // Queue the request
return handler.queueRequest(routeId, url, fetchOptions, { body: request.body, files: request.files }); return handler.queueRequest(routeId, url, fetchOptions, {
body: request.body,
files: request.files,
auth: request.auth !== false,
});
} }
/** /**

View File

@@ -1,13 +1,14 @@
import type { RequestInit } from 'node-fetch'; import type { RequestInit } from 'node-fetch';
import type { InternalRequest, RouteData } from '../RequestManager'; import type { HandlerRequestData, RouteData } from '../RequestManager';
export interface IHandler { export interface IHandler {
queueRequest: ( queueRequest: (
routeId: RouteData, routeId: RouteData,
url: string, url: string,
options: RequestInit, options: RequestInit,
bodyData: Pick<InternalRequest, 'files' | 'body'>, requestData: HandlerRequestData,
) => Promise<unknown>; ) => Promise<unknown>;
readonly inactive: boolean; // eslint-disable-next-line @typescript-eslint/method-signature-style -- This is meant to be a getter returning a bool
get inactive(): boolean;
readonly id: string; readonly id: string;
} }

View File

@@ -4,10 +4,11 @@ import fetch, { RequestInit, Response } from 'node-fetch';
import { DiscordAPIError, DiscordErrorData, OAuthErrorData } from '../errors/DiscordAPIError'; import { DiscordAPIError, DiscordErrorData, OAuthErrorData } from '../errors/DiscordAPIError';
import { HTTPError } from '../errors/HTTPError'; import { HTTPError } from '../errors/HTTPError';
import { RateLimitError } from '../errors/RateLimitError'; import { RateLimitError } from '../errors/RateLimitError';
import type { InternalRequest, RequestManager, RouteData } from '../RequestManager'; import type { HandlerRequestData, RequestManager, RouteData } from '../RequestManager';
import { RESTEvents } from '../utils/constants'; import { RESTEvents } from '../utils/constants';
import { hasSublimit, parseResponse } from '../utils/utils'; import { hasSublimit, parseResponse } from '../utils/utils';
import type { RateLimitData } from '../REST'; import type { RateLimitData } from '../REST';
import type { IHandler } from './IHandler';
/* Invalid request limiting is done on a per-IP basis, not a per-token basis. /* Invalid request limiting is done on a per-IP basis, not a per-token basis.
* The best we can do is track invalid counts process-wide (on the theory that * The best we can do is track invalid counts process-wide (on the theory that
@@ -26,7 +27,7 @@ const enum QueueType {
/** /**
* The structure used to handle requests for a given bucket * The structure used to handle requests for a given bucket
*/ */
export class SequentialHandler { export class SequentialHandler implements IHandler {
/** /**
* The unique id of the handler * The unique id of the handler
*/ */
@@ -162,18 +163,18 @@ export class SequentialHandler {
* @param routeId The generalized api route with literal ids for major parameters * @param routeId The generalized api route with literal ids for major parameters
* @param url The url to do the request on * @param url The url to do the request on
* @param options All the information needed to make a request * @param options All the information needed to make a request
* @param bodyData The data that was used to form the body, passed to any errors generated and for determining whether to sublimit * @param requestData Extra data from the user's request needed for errors and additional processing
*/ */
public async queueRequest( public async queueRequest(
routeId: RouteData, routeId: RouteData,
url: string, url: string,
options: RequestInit, options: RequestInit,
bodyData: Pick<InternalRequest, 'files' | 'body'>, requestData: HandlerRequestData,
): Promise<unknown> { ): Promise<unknown> {
let queue = this.#asyncQueue; let queue = this.#asyncQueue;
let queueType = QueueType.Standard; let queueType = QueueType.Standard;
// Separate sublimited requests when already sublimited // Separate sublimited requests when already sublimited
if (this.#sublimitedQueue && hasSublimit(routeId.bucketRoute, bodyData.body, options.method)) { if (this.#sublimitedQueue && hasSublimit(routeId.bucketRoute, requestData.body, options.method)) {
queue = this.#sublimitedQueue!; queue = this.#sublimitedQueue!;
queueType = QueueType.Sublimit; queueType = QueueType.Sublimit;
} }
@@ -181,7 +182,7 @@ export class SequentialHandler {
await queue.wait(); await queue.wait();
// This set handles retroactively sublimiting requests // This set handles retroactively sublimiting requests
if (queueType === QueueType.Standard) { if (queueType === QueueType.Standard) {
if (this.#sublimitedQueue && hasSublimit(routeId.bucketRoute, bodyData.body, options.method)) { if (this.#sublimitedQueue && hasSublimit(routeId.bucketRoute, requestData.body, options.method)) {
/** /**
* Remove the request from the standard queue, it should never be possible to get here while processing the * Remove the request from the standard queue, it should never be possible to get here while processing the
* sublimit queue so there is no need to worry about shifting the wrong request * sublimit queue so there is no need to worry about shifting the wrong request
@@ -197,7 +198,7 @@ export class SequentialHandler {
} }
try { try {
// Make the request, and return the results // Make the request, and return the results
return await this.runRequest(routeId, url, options, bodyData); return await this.runRequest(routeId, url, options, requestData);
} finally { } finally {
// Allow the next request to fire // Allow the next request to fire
queue.shift(); queue.shift();
@@ -218,14 +219,14 @@ export class SequentialHandler {
* @param routeId The generalized api route with literal ids for major parameters * @param routeId The generalized api route with literal ids for major parameters
* @param url The fully resolved url to make the request to * @param url The fully resolved url to make the request to
* @param options The node-fetch options needed to make the request * @param options The node-fetch options needed to make the request
* @param bodyData The data that was used to form the body, passed to any errors generated * @param requestData Extra data from the user's request needed for errors and additional processing
* @param retries The number of retries this request has already attempted (recursion) * @param retries The number of retries this request has already attempted (recursion)
*/ */
private async runRequest( private async runRequest(
routeId: RouteData, routeId: RouteData,
url: string, url: string,
options: RequestInit, options: RequestInit,
bodyData: Pick<InternalRequest, 'files' | 'body'>, requestData: HandlerRequestData,
retries = 0, retries = 0,
): Promise<unknown> { ): Promise<unknown> {
/* /*
@@ -292,7 +293,7 @@ export class SequentialHandler {
path: routeId.original, path: routeId.original,
route: routeId.bucketRoute, route: routeId.bucketRoute,
options, options,
data: bodyData, data: requestData,
retries, retries,
}); });
} }
@@ -309,7 +310,7 @@ export class SequentialHandler {
} catch (error: unknown) { } catch (error: unknown) {
// Retry the specified number of times for possible timed out requests // Retry the specified number of times for possible timed out requests
if (error instanceof Error && error.name === 'AbortError' && retries !== this.manager.options.retries) { if (error instanceof Error && error.name === 'AbortError' && retries !== this.manager.options.retries) {
return await this.runRequest(routeId, url, options, bodyData, ++retries); return await this.runRequest(routeId, url, options, requestData, ++retries);
} }
throw error; throw error;
@@ -325,7 +326,7 @@ export class SequentialHandler {
path: routeId.original, path: routeId.original,
route: routeId.bucketRoute, route: routeId.bucketRoute,
options, options,
data: bodyData, data: requestData,
retries, retries,
}, },
res.clone(), res.clone(),
@@ -466,25 +467,25 @@ export class SequentialHandler {
} }
} }
// Since this is not a server side issue, the next request should pass, so we don't bump the retries counter // Since this is not a server side issue, the next request should pass, so we don't bump the retries counter
return this.runRequest(routeId, url, options, bodyData, retries); return this.runRequest(routeId, url, options, requestData, retries);
} else if (res.status >= 500 && res.status < 600) { } else if (res.status >= 500 && res.status < 600) {
// Retry the specified number of times for possible server side issues // Retry the specified number of times for possible server side issues
if (retries !== this.manager.options.retries) { if (retries !== this.manager.options.retries) {
return this.runRequest(routeId, url, options, bodyData, ++retries); return this.runRequest(routeId, url, options, requestData, ++retries);
} }
// We are out of retries, throw an error // We are out of retries, throw an error
throw new HTTPError(res.statusText, res.constructor.name, res.status, method, url, bodyData); throw new HTTPError(res.statusText, res.constructor.name, res.status, method, url, requestData);
} else { } else {
// Handle possible malformed requests // Handle possible malformed requests
if (res.status >= 400 && res.status < 500) { if (res.status >= 400 && res.status < 500) {
// If we receive this status code, it means the token we had is no longer valid. // If we receive this status code, it means the token we had is no longer valid.
if (res.status === 401) { if (res.status === 401 && requestData.auth) {
this.manager.setToken(null!); this.manager.setToken(null!);
} }
// The request will not succeed for some reason, parse the error returned from the api // The request will not succeed for some reason, parse the error returned from the api
const data = (await parseResponse(res)) as DiscordErrorData | OAuthErrorData; const data = (await parseResponse(res)) as DiscordErrorData | OAuthErrorData;
// throw the API error // throw the API error
throw new DiscordAPIError(data, 'code' in data ? data.code : data.error, res.status, method, url, bodyData); throw new DiscordAPIError(data, 'code' in data ? data.code : data.error, res.status, method, url, requestData);
} }
return null; return null;
} }