refactor(web): split trusted and strict web tool fetch paths

This commit is contained in:
Peter Steinberger
2026-03-02 01:13:57 +00:00
parent e07c51b045
commit e7cd4bf1bd
5 changed files with 61 additions and 40 deletions

View File

@@ -5,13 +5,14 @@ import {
} from "../../infra/net/fetch-guard.js";
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
export const WEB_TOOLS_TRUSTED_NETWORK_SSRF_POLICY: SsrFPolicy = {
const WEB_TOOLS_TRUSTED_NETWORK_SSRF_POLICY: SsrFPolicy = {
dangerouslyAllowPrivateNetwork: true,
};
type WebToolGuardedFetchOptions = Omit<GuardedFetchOptions, "proxy"> & {
timeoutSeconds?: number;
};
type WebToolEndpointFetchOptions = Omit<WebToolGuardedFetchOptions, "policy">;
function resolveTimeoutMs(params: {
timeoutMs?: number;
@@ -37,7 +38,7 @@ export async function fetchWithWebToolsNetworkGuard(
});
}
export async function withWebToolsNetworkGuard<T>(
async function withWebToolsNetworkGuard<T>(
params: WebToolGuardedFetchOptions,
run: (result: { response: Response; finalUrl: string }) => Promise<T>,
): Promise<T> {
@@ -48,3 +49,23 @@ export async function withWebToolsNetworkGuard<T>(
await release();
}
}
export async function withTrustedWebToolsEndpoint<T>(
params: WebToolEndpointFetchOptions,
run: (result: { response: Response; finalUrl: string }) => Promise<T>,
): Promise<T> {
return await withWebToolsNetworkGuard(
{
...params,
policy: WEB_TOOLS_TRUSTED_NETWORK_SSRF_POLICY,
},
run,
);
}
export async function withStrictWebToolsEndpoint<T>(
params: WebToolEndpointFetchOptions,
run: (result: { response: Response; finalUrl: string }) => Promise<T>,
): Promise<T> {
return await withWebToolsNetworkGuard(params, run);
}

View File

@@ -0,0 +1,22 @@
import { withStrictWebToolsEndpoint } from "./web-guarded-fetch.js";
const REDIRECT_TIMEOUT_MS = 5000;
/**
* Resolve a citation redirect URL to its final destination using a HEAD request.
* Returns the original URL if resolution fails or times out.
*/
export async function resolveCitationRedirectUrl(url: string): Promise<string> {
try {
return await withStrictWebToolsEndpoint(
{
url,
init: { method: "HEAD" },
timeoutMs: REDIRECT_TIMEOUT_MS,
},
async ({ finalUrl }) => finalUrl || url,
);
} catch {
return url;
}
}

View File

@@ -6,10 +6,8 @@ import { wrapWebContent } from "../../security/external-content.js";
import { normalizeSecretInput } from "../../utils/normalize-secret-input.js";
import type { AnyAgentTool } from "./common.js";
import { jsonResult, readNumberParam, readStringParam } from "./common.js";
import {
WEB_TOOLS_TRUSTED_NETWORK_SSRF_POLICY,
withWebToolsNetworkGuard,
} from "./web-guarded-fetch.js";
import { withTrustedWebToolsEndpoint } from "./web-guarded-fetch.js";
import { resolveCitationRedirectUrl } from "./web-search-citation-redirect.js";
import {
CacheEntry,
DEFAULT_CACHE_TTL_MINUTES,
@@ -609,12 +607,11 @@ async function withTrustedWebSearchEndpoint<T>(
},
run: (response: Response) => Promise<T>,
): Promise<T> {
return withWebToolsNetworkGuard(
return withTrustedWebToolsEndpoint(
{
url: params.url,
init: params.init,
timeoutSeconds: params.timeoutSeconds,
policy: WEB_TOOLS_TRUSTED_NETWORK_SSRF_POLICY,
},
async ({ response }) => run(response),
);
@@ -696,7 +693,7 @@ async function runGeminiSearch(params: {
const batch = rawCitations.slice(i, i + MAX_CONCURRENT_REDIRECTS);
const resolved = await Promise.all(
batch.map(async (citation) => {
const resolvedUrl = await resolveRedirectUrl(citation.url);
const resolvedUrl = await resolveCitationRedirectUrl(citation.url);
return { ...citation, url: resolvedUrl };
}),
);
@@ -708,27 +705,6 @@ async function runGeminiSearch(params: {
);
}
const REDIRECT_TIMEOUT_MS = 5000;
/**
* Resolve a redirect URL to its final destination using a HEAD request.
* Returns the original URL if resolution fails or times out.
*/
async function resolveRedirectUrl(url: string): Promise<string> {
try {
return await withWebToolsNetworkGuard(
{
url,
init: { method: "HEAD" },
timeoutMs: REDIRECT_TIMEOUT_MS,
},
async ({ finalUrl }) => finalUrl || url,
);
} catch {
return url;
}
}
function resolveSearchCount(value: unknown, fallback: number): number {
const parsed = typeof value === "number" && Number.isFinite(value) ? value : fallback;
const clamped = Math.max(1, Math.min(MAX_SEARCH_COUNT, Math.floor(parsed)));
@@ -1492,5 +1468,5 @@ export const __testing = {
resolveKimiModel,
resolveKimiBaseUrl,
extractKimiCitations,
resolveRedirectUrl,
resolveRedirectUrl: resolveCitationRedirectUrl,
} as const;

View File

@@ -5,17 +5,21 @@ export type DiscordGatewayHandle = {
disconnect?: () => void;
};
export function getDiscordGatewayEmitter(gateway?: unknown): EventEmitter | undefined {
return (gateway as { emitter?: EventEmitter } | undefined)?.emitter;
}
export async function waitForDiscordGatewayStop(params: {
export type WaitForDiscordGatewayStopParams = {
gateway?: DiscordGatewayHandle;
abortSignal?: AbortSignal;
onGatewayError?: (err: unknown) => void;
shouldStopOnError?: (err: unknown) => boolean;
registerForceStop?: (forceStop: (err: unknown) => void) => void;
}): Promise<void> {
};
export function getDiscordGatewayEmitter(gateway?: unknown): EventEmitter | undefined {
return (gateway as { emitter?: EventEmitter } | undefined)?.emitter;
}
export async function waitForDiscordGatewayStop(
params: WaitForDiscordGatewayStopParams,
): Promise<void> {
const { gateway, abortSignal, onGatewayError, shouldStopOnError } = params;
const emitter = gateway?.emitter;
return await new Promise<void>((resolve, reject) => {

View File

@@ -2,9 +2,7 @@ import { EventEmitter } from "node:events";
import type { Client } from "@buape/carbon";
import { beforeEach, describe, expect, it, vi } from "vitest";
import type { RuntimeEnv } from "../../runtime.js";
import type { waitForDiscordGatewayStop } from "../monitor.gateway.js";
type WaitForDiscordGatewayStopParams = Parameters<typeof waitForDiscordGatewayStop>[0];
import type { WaitForDiscordGatewayStopParams } from "../monitor.gateway.js";
const {
attachDiscordGatewayLoggingMock,