From e7cd4bf1bd45ea829001141a51662ef9b2cd287f Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Mon, 2 Mar 2026 01:13:57 +0000 Subject: [PATCH] refactor(web): split trusted and strict web tool fetch paths --- src/agents/tools/web-guarded-fetch.ts | 25 ++++++++++++-- .../tools/web-search-citation-redirect.ts | 22 ++++++++++++ src/agents/tools/web-search.ts | 34 +++---------------- src/discord/monitor.gateway.ts | 16 +++++---- .../monitor/provider.lifecycle.test.ts | 4 +-- 5 files changed, 61 insertions(+), 40 deletions(-) create mode 100644 src/agents/tools/web-search-citation-redirect.ts diff --git a/src/agents/tools/web-guarded-fetch.ts b/src/agents/tools/web-guarded-fetch.ts index 02b69cd1f..2f905a215 100644 --- a/src/agents/tools/web-guarded-fetch.ts +++ b/src/agents/tools/web-guarded-fetch.ts @@ -5,13 +5,14 @@ import { } from "../../infra/net/fetch-guard.js"; import type { SsrFPolicy } from "../../infra/net/ssrf.js"; -export const WEB_TOOLS_TRUSTED_NETWORK_SSRF_POLICY: SsrFPolicy = { +const WEB_TOOLS_TRUSTED_NETWORK_SSRF_POLICY: SsrFPolicy = { dangerouslyAllowPrivateNetwork: true, }; type WebToolGuardedFetchOptions = Omit & { timeoutSeconds?: number; }; +type WebToolEndpointFetchOptions = Omit; function resolveTimeoutMs(params: { timeoutMs?: number; @@ -37,7 +38,7 @@ export async function fetchWithWebToolsNetworkGuard( }); } -export async function withWebToolsNetworkGuard( +async function withWebToolsNetworkGuard( params: WebToolGuardedFetchOptions, run: (result: { response: Response; finalUrl: string }) => Promise, ): Promise { @@ -48,3 +49,23 @@ export async function withWebToolsNetworkGuard( await release(); } } + +export async function withTrustedWebToolsEndpoint( + params: WebToolEndpointFetchOptions, + run: (result: { response: Response; finalUrl: string }) => Promise, +): Promise { + return await withWebToolsNetworkGuard( + { + ...params, + policy: WEB_TOOLS_TRUSTED_NETWORK_SSRF_POLICY, + }, + run, + ); +} + +export async function withStrictWebToolsEndpoint( + params: WebToolEndpointFetchOptions, + run: (result: { response: Response; finalUrl: string }) => Promise, +): Promise { + return await withWebToolsNetworkGuard(params, run); +} diff --git a/src/agents/tools/web-search-citation-redirect.ts b/src/agents/tools/web-search-citation-redirect.ts new file mode 100644 index 000000000..424fb769e --- /dev/null +++ b/src/agents/tools/web-search-citation-redirect.ts @@ -0,0 +1,22 @@ +import { withStrictWebToolsEndpoint } from "./web-guarded-fetch.js"; + +const REDIRECT_TIMEOUT_MS = 5000; + +/** + * Resolve a citation redirect URL to its final destination using a HEAD request. + * Returns the original URL if resolution fails or times out. + */ +export async function resolveCitationRedirectUrl(url: string): Promise { + try { + return await withStrictWebToolsEndpoint( + { + url, + init: { method: "HEAD" }, + timeoutMs: REDIRECT_TIMEOUT_MS, + }, + async ({ finalUrl }) => finalUrl || url, + ); + } catch { + return url; + } +} diff --git a/src/agents/tools/web-search.ts b/src/agents/tools/web-search.ts index 8456bf9d4..da2f07960 100644 --- a/src/agents/tools/web-search.ts +++ b/src/agents/tools/web-search.ts @@ -6,10 +6,8 @@ import { wrapWebContent } from "../../security/external-content.js"; import { normalizeSecretInput } from "../../utils/normalize-secret-input.js"; import type { AnyAgentTool } from "./common.js"; import { jsonResult, readNumberParam, readStringParam } from "./common.js"; -import { - WEB_TOOLS_TRUSTED_NETWORK_SSRF_POLICY, - withWebToolsNetworkGuard, -} from "./web-guarded-fetch.js"; +import { withTrustedWebToolsEndpoint } from "./web-guarded-fetch.js"; +import { resolveCitationRedirectUrl } from "./web-search-citation-redirect.js"; import { CacheEntry, DEFAULT_CACHE_TTL_MINUTES, @@ -609,12 +607,11 @@ async function withTrustedWebSearchEndpoint( }, run: (response: Response) => Promise, ): Promise { - return withWebToolsNetworkGuard( + return withTrustedWebToolsEndpoint( { url: params.url, init: params.init, timeoutSeconds: params.timeoutSeconds, - policy: WEB_TOOLS_TRUSTED_NETWORK_SSRF_POLICY, }, async ({ response }) => run(response), ); @@ -696,7 +693,7 @@ async function runGeminiSearch(params: { const batch = rawCitations.slice(i, i + MAX_CONCURRENT_REDIRECTS); const resolved = await Promise.all( batch.map(async (citation) => { - const resolvedUrl = await resolveRedirectUrl(citation.url); + const resolvedUrl = await resolveCitationRedirectUrl(citation.url); return { ...citation, url: resolvedUrl }; }), ); @@ -708,27 +705,6 @@ async function runGeminiSearch(params: { ); } -const REDIRECT_TIMEOUT_MS = 5000; - -/** - * Resolve a redirect URL to its final destination using a HEAD request. - * Returns the original URL if resolution fails or times out. - */ -async function resolveRedirectUrl(url: string): Promise { - try { - return await withWebToolsNetworkGuard( - { - url, - init: { method: "HEAD" }, - timeoutMs: REDIRECT_TIMEOUT_MS, - }, - async ({ finalUrl }) => finalUrl || url, - ); - } catch { - return url; - } -} - function resolveSearchCount(value: unknown, fallback: number): number { const parsed = typeof value === "number" && Number.isFinite(value) ? value : fallback; const clamped = Math.max(1, Math.min(MAX_SEARCH_COUNT, Math.floor(parsed))); @@ -1492,5 +1468,5 @@ export const __testing = { resolveKimiModel, resolveKimiBaseUrl, extractKimiCitations, - resolveRedirectUrl, + resolveRedirectUrl: resolveCitationRedirectUrl, } as const; diff --git a/src/discord/monitor.gateway.ts b/src/discord/monitor.gateway.ts index 624153cad..d8e83a12a 100644 --- a/src/discord/monitor.gateway.ts +++ b/src/discord/monitor.gateway.ts @@ -5,17 +5,21 @@ export type DiscordGatewayHandle = { disconnect?: () => void; }; -export function getDiscordGatewayEmitter(gateway?: unknown): EventEmitter | undefined { - return (gateway as { emitter?: EventEmitter } | undefined)?.emitter; -} - -export async function waitForDiscordGatewayStop(params: { +export type WaitForDiscordGatewayStopParams = { gateway?: DiscordGatewayHandle; abortSignal?: AbortSignal; onGatewayError?: (err: unknown) => void; shouldStopOnError?: (err: unknown) => boolean; registerForceStop?: (forceStop: (err: unknown) => void) => void; -}): Promise { +}; + +export function getDiscordGatewayEmitter(gateway?: unknown): EventEmitter | undefined { + return (gateway as { emitter?: EventEmitter } | undefined)?.emitter; +} + +export async function waitForDiscordGatewayStop( + params: WaitForDiscordGatewayStopParams, +): Promise { const { gateway, abortSignal, onGatewayError, shouldStopOnError } = params; const emitter = gateway?.emitter; return await new Promise((resolve, reject) => { diff --git a/src/discord/monitor/provider.lifecycle.test.ts b/src/discord/monitor/provider.lifecycle.test.ts index a4387144c..da4a06d5b 100644 --- a/src/discord/monitor/provider.lifecycle.test.ts +++ b/src/discord/monitor/provider.lifecycle.test.ts @@ -2,9 +2,7 @@ import { EventEmitter } from "node:events"; import type { Client } from "@buape/carbon"; import { beforeEach, describe, expect, it, vi } from "vitest"; import type { RuntimeEnv } from "../../runtime.js"; -import type { waitForDiscordGatewayStop } from "../monitor.gateway.js"; - -type WaitForDiscordGatewayStopParams = Parameters[0]; +import type { WaitForDiscordGatewayStopParams } from "../monitor.gateway.js"; const { attachDiscordGatewayLoggingMock,