diff --git a/CHANGELOG.md b/CHANGELOG.md index 9074f8c88..66ab42fca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ Docs: https://docs.openclaw.ai ### Fixes +- Security: guard remote media fetches with SSRF protections (block private/localhost, DNS pinning). - Plugins: validate plugin/hook install paths and reject traversal-like names. - Telegram: add download timeouts for file fetches. (#6914) Thanks @hclsys. - Telegram: enforce thread specs for DM vs forum sends. (#6833) Thanks @obviyus. diff --git a/src/agents/tools/web-fetch.ts b/src/agents/tools/web-fetch.ts index 229e1e52f..1df9da8c6 100644 --- a/src/agents/tools/web-fetch.ts +++ b/src/agents/tools/web-fetch.ts @@ -1,13 +1,8 @@ -import type { Dispatcher } from "undici"; import { Type } from "@sinclair/typebox"; import type { OpenClawConfig } from "../../config/config.js"; import type { AnyAgentTool } from "./common.js"; -import { - closeDispatcher, - createPinnedDispatcher, - resolvePinnedHostname, - SsrFBlockedError, -} from "../../infra/net/ssrf.js"; +import { fetchWithSsrFGuard } from "../../infra/net/fetch-guard.js"; +import { SsrFBlockedError } from "../../infra/net/ssrf.js"; import { wrapExternalContent, wrapWebContent } from "../../security/external-content.js"; import { stringEnum } from "../schema/typebox.js"; import { jsonResult, readNumberParam, readStringParam } from "./common.js"; @@ -184,79 +179,6 @@ function looksLikeHtml(value: string): boolean { return head.startsWith(" { - const signal = withTimeout(undefined, params.timeoutSeconds * 1000); - const visited = new Set(); - let currentUrl = params.url; - let redirectCount = 0; - - while (true) { - let parsedUrl: URL; - try { - parsedUrl = new URL(currentUrl); - } catch { - throw new Error("Invalid URL: must be http or https"); - } - if (!["http:", "https:"].includes(parsedUrl.protocol)) { - throw new Error("Invalid URL: must be http or https"); - } - - const pinned = await resolvePinnedHostname(parsedUrl.hostname); - const dispatcher = createPinnedDispatcher(pinned); - let res: Response; - try { - res = await fetch(parsedUrl.toString(), { - method: "GET", - headers: { - Accept: "*/*", - "User-Agent": params.userAgent, - "Accept-Language": "en-US,en;q=0.9", - }, - signal, - redirect: "manual", - dispatcher, - } as RequestInit); - } catch (err) { - await closeDispatcher(dispatcher); - throw err; - } - - if (isRedirectStatus(res.status)) { - const location = res.headers.get("location"); - if (!location) { - await closeDispatcher(dispatcher); - throw new Error(`Redirect missing location header (${res.status})`); - } - redirectCount += 1; - if (redirectCount > params.maxRedirects) { - await closeDispatcher(dispatcher); - throw new Error(`Too many redirects (limit: ${params.maxRedirects})`); - } - const nextUrl = new URL(location, parsedUrl).toString(); - if (visited.has(nextUrl)) { - await closeDispatcher(dispatcher); - throw new Error("Redirect loop detected"); - } - visited.add(nextUrl); - void res.body?.cancel(); - await closeDispatcher(dispatcher); - currentUrl = nextUrl; - continue; - } - - return { response: res, finalUrl: currentUrl, dispatcher }; - } -} - function formatWebFetchErrorDetail(params: { detail: string; contentType?: string | null; @@ -465,18 +387,22 @@ async function runWebFetch(params: { const start = Date.now(); let res: Response; - let dispatcher: Dispatcher | null = null; + let release: (() => Promise) | null = null; let finalUrl = params.url; try { - const result = await fetchWithRedirects({ + const result = await fetchWithSsrFGuard({ url: params.url, maxRedirects: params.maxRedirects, - timeoutSeconds: params.timeoutSeconds, - userAgent: params.userAgent, + timeoutMs: params.timeoutSeconds * 1000, + headers: { + Accept: "*/*", + "User-Agent": params.userAgent, + "Accept-Language": "en-US,en;q=0.9", + }, }); res = result.response; finalUrl = result.finalUrl; - dispatcher = result.dispatcher; + release = result.release; } catch (error) { if (error instanceof SsrFBlockedError) { throw error; @@ -630,7 +556,9 @@ async function runWebFetch(params: { writeCache(FETCH_CACHE, cacheKey, payload, params.cacheTtlMs); return payload; } finally { - await closeDispatcher(dispatcher); + if (release) { + await release(); + } } } diff --git a/src/infra/net/fetch-guard.ts b/src/infra/net/fetch-guard.ts new file mode 100644 index 000000000..6e8dc5359 --- /dev/null +++ b/src/infra/net/fetch-guard.ts @@ -0,0 +1,170 @@ +import type { Dispatcher } from "undici"; +import { + closeDispatcher, + createPinnedDispatcher, + resolvePinnedHostname, + resolvePinnedHostnameWithPolicy, + type LookupFn, + type SsrFPolicy, +} from "./ssrf.js"; + +type FetchLike = (input: RequestInfo | URL, init?: RequestInit) => Promise; + +export type GuardedFetchOptions = { + url: string; + fetchImpl?: FetchLike; + method?: string; + headers?: HeadersInit; + maxRedirects?: number; + timeoutMs?: number; + signal?: AbortSignal; + policy?: SsrFPolicy; + lookupFn?: LookupFn; +}; + +export type GuardedFetchResult = { + response: Response; + finalUrl: string; + release: () => Promise; +}; + +const DEFAULT_MAX_REDIRECTS = 3; + +function isRedirectStatus(status: number): boolean { + return status === 301 || status === 302 || status === 303 || status === 307 || status === 308; +} + +function buildAbortSignal(params: { timeoutMs?: number; signal?: AbortSignal }): { + signal?: AbortSignal; + cleanup: () => void; +} { + const { timeoutMs, signal } = params; + if (!timeoutMs && !signal) { + return { signal: undefined, cleanup: () => {} }; + } + + if (!timeoutMs) { + return { signal, cleanup: () => {} }; + } + + const controller = new AbortController(); + const timeoutId = setTimeout(() => controller.abort(), timeoutMs); + const onAbort = () => controller.abort(); + if (signal) { + if (signal.aborted) { + controller.abort(); + } else { + signal.addEventListener("abort", onAbort, { once: true }); + } + } + + const cleanup = () => { + clearTimeout(timeoutId); + if (signal) { + signal.removeEventListener("abort", onAbort); + } + }; + + return { signal: controller.signal, cleanup }; +} + +export async function fetchWithSsrFGuard(params: GuardedFetchOptions): Promise { + const fetcher: FetchLike | undefined = params.fetchImpl ?? globalThis.fetch; + if (!fetcher) { + throw new Error("fetch is not available"); + } + + const maxRedirects = + typeof params.maxRedirects === "number" && Number.isFinite(params.maxRedirects) + ? Math.max(0, Math.floor(params.maxRedirects)) + : DEFAULT_MAX_REDIRECTS; + + const { signal, cleanup } = buildAbortSignal({ + timeoutMs: params.timeoutMs, + signal: params.signal, + }); + + let released = false; + const release = async (dispatcher?: Dispatcher | null) => { + if (released) { + return; + } + released = true; + cleanup(); + await closeDispatcher(dispatcher ?? undefined); + }; + + const visited = new Set(); + let currentUrl = params.url; + let redirectCount = 0; + + while (true) { + let parsedUrl: URL; + try { + parsedUrl = new URL(currentUrl); + } catch { + await release(); + throw new Error("Invalid URL: must be http or https"); + } + if (!["http:", "https:"].includes(parsedUrl.protocol)) { + await release(); + throw new Error("Invalid URL: must be http or https"); + } + + let dispatcher: Dispatcher | null = null; + try { + const usePolicy = Boolean( + params.policy?.allowPrivateNetwork || params.policy?.allowedHostnames?.length, + ); + const pinned = usePolicy + ? await resolvePinnedHostnameWithPolicy(parsedUrl.hostname, { + lookupFn: params.lookupFn, + policy: params.policy, + }) + : await resolvePinnedHostname(parsedUrl.hostname, params.lookupFn); + dispatcher = createPinnedDispatcher(pinned); + + const init: RequestInit & { dispatcher?: Dispatcher } = { + redirect: "manual", + dispatcher, + ...(params.method ? { method: params.method } : {}), + ...(params.headers ? { headers: params.headers } : {}), + ...(signal ? { signal } : {}), + }; + + const response = await fetcher(parsedUrl.toString(), init); + + if (isRedirectStatus(response.status)) { + const location = response.headers.get("location"); + if (!location) { + await release(dispatcher); + throw new Error(`Redirect missing location header (${response.status})`); + } + redirectCount += 1; + if (redirectCount > maxRedirects) { + await release(dispatcher); + throw new Error(`Too many redirects (limit: ${maxRedirects})`); + } + const nextUrl = new URL(location, parsedUrl).toString(); + if (visited.has(nextUrl)) { + await release(dispatcher); + throw new Error("Redirect loop detected"); + } + visited.add(nextUrl); + void response.body?.cancel(); + await closeDispatcher(dispatcher); + currentUrl = nextUrl; + continue; + } + + return { + response, + finalUrl: currentUrl, + release: async () => release(dispatcher), + }; + } catch (err) { + await release(dispatcher); + throw err; + } + } +} diff --git a/src/infra/net/ssrf.ts b/src/infra/net/ssrf.ts index 026113706..a017bff1e 100644 --- a/src/infra/net/ssrf.ts +++ b/src/infra/net/ssrf.ts @@ -15,7 +15,12 @@ export class SsrFBlockedError extends Error { } } -type LookupFn = typeof dnsLookup; +export type LookupFn = typeof dnsLookup; + +export type SsrFPolicy = { + allowPrivateNetwork?: boolean; + allowedHostnames?: string[]; +}; const PRIVATE_IPV6_PREFIXES = ["fe80:", "fec0:", "fc", "fd"]; const BLOCKED_HOSTNAMES = new Set(["localhost", "metadata.google.internal"]); @@ -28,6 +33,13 @@ function normalizeHostname(hostname: string): string { return normalized; } +function normalizeHostnameSet(values?: string[]): Set { + if (!values || values.length === 0) { + return new Set(); + } + return new Set(values.map((value) => normalizeHostname(value)).filter(Boolean)); +} + function parseIpv4(address: string): number[] | null { const parts = address.split("."); if (parts.length !== 4) { @@ -206,31 +218,40 @@ export type PinnedHostname = { lookup: typeof dnsLookupCb; }; -export async function resolvePinnedHostname( +export async function resolvePinnedHostnameWithPolicy( hostname: string, - lookupFn: LookupFn = dnsLookup, + params: { lookupFn?: LookupFn; policy?: SsrFPolicy } = {}, ): Promise { const normalized = normalizeHostname(hostname); if (!normalized) { throw new Error("Invalid hostname"); } - if (isBlockedHostname(normalized)) { - throw new SsrFBlockedError(`Blocked hostname: ${hostname}`); - } - - if (isPrivateIpAddress(normalized)) { - throw new SsrFBlockedError("Blocked: private/internal IP address"); + const allowPrivateNetwork = Boolean(params.policy?.allowPrivateNetwork); + const allowedHostnames = normalizeHostnameSet(params.policy?.allowedHostnames); + const isExplicitAllowed = allowedHostnames.has(normalized); + + if (!allowPrivateNetwork && !isExplicitAllowed) { + if (isBlockedHostname(normalized)) { + throw new SsrFBlockedError(`Blocked hostname: ${hostname}`); + } + + if (isPrivateIpAddress(normalized)) { + throw new SsrFBlockedError("Blocked: private/internal IP address"); + } } + const lookupFn = params.lookupFn ?? dnsLookup; const results = await lookupFn(normalized, { all: true }); if (results.length === 0) { throw new Error(`Unable to resolve hostname: ${hostname}`); } - for (const entry of results) { - if (isPrivateIpAddress(entry.address)) { - throw new SsrFBlockedError("Blocked: resolves to private/internal IP address"); + if (!allowPrivateNetwork && !isExplicitAllowed) { + for (const entry of results) { + if (isPrivateIpAddress(entry.address)) { + throw new SsrFBlockedError("Blocked: resolves to private/internal IP address"); + } } } @@ -246,6 +267,13 @@ export async function resolvePinnedHostname( }; } +export async function resolvePinnedHostname( + hostname: string, + lookupFn: LookupFn = dnsLookup, +): Promise { + return await resolvePinnedHostnameWithPolicy(hostname, { lookupFn }); +} + export function createPinnedDispatcher(pinned: PinnedHostname): Dispatcher { return new Agent({ connect: { diff --git a/src/media/fetch.test.ts b/src/media/fetch.test.ts index 2af4f4663..d08f67dc9 100644 --- a/src/media/fetch.test.ts +++ b/src/media/fetch.test.ts @@ -1,4 +1,4 @@ -import { describe, expect, it } from "vitest"; +import { describe, expect, it, vi } from "vitest"; import { fetchRemoteMedia } from "./fetch.js"; function makeStream(chunks: Uint8Array[]) { @@ -14,6 +14,7 @@ function makeStream(chunks: Uint8Array[]) { describe("fetchRemoteMedia", () => { it("rejects when content-length exceeds maxBytes", async () => { + const lookupFn = vi.fn(async () => [{ address: "93.184.216.34", family: 4 }]); const fetchImpl = async () => new Response(makeStream([new Uint8Array([1, 2, 3, 4, 5])]), { status: 200, @@ -25,11 +26,13 @@ describe("fetchRemoteMedia", () => { url: "https://example.com/file.bin", fetchImpl, maxBytes: 4, + lookupFn, }), ).rejects.toThrow("exceeds maxBytes"); }); it("rejects when streamed payload exceeds maxBytes", async () => { + const lookupFn = vi.fn(async () => [{ address: "93.184.216.34", family: 4 }]); const fetchImpl = async () => new Response(makeStream([new Uint8Array([1, 2, 3]), new Uint8Array([4, 5, 6])]), { status: 200, @@ -40,7 +43,20 @@ describe("fetchRemoteMedia", () => { url: "https://example.com/file.bin", fetchImpl, maxBytes: 4, + lookupFn, }), ).rejects.toThrow("exceeds maxBytes"); }); + + it("blocks private IP literals before fetching", async () => { + const fetchImpl = vi.fn(); + await expect( + fetchRemoteMedia({ + url: "http://127.0.0.1/secret.jpg", + fetchImpl, + maxBytes: 1024, + }), + ).rejects.toThrow(/private|internal|blocked/i); + expect(fetchImpl).not.toHaveBeenCalled(); + }); }); diff --git a/src/media/fetch.ts b/src/media/fetch.ts index b47213da7..59a4d0919 100644 --- a/src/media/fetch.ts +++ b/src/media/fetch.ts @@ -1,4 +1,6 @@ import path from "node:path"; +import type { LookupFn, SsrFPolicy } from "../infra/net/ssrf.js"; +import { fetchWithSsrFGuard } from "../infra/net/fetch-guard.js"; import { detectMime, extensionForMime } from "./mime.js"; type FetchMediaResult = { @@ -26,6 +28,9 @@ type FetchMediaOptions = { fetchImpl?: FetchLike; filePathHint?: string; maxBytes?: number; + maxRedirects?: number; + ssrfPolicy?: SsrFPolicy; + lookupFn?: LookupFn; }; function stripQuotes(value: string): string { @@ -73,83 +78,96 @@ async function readErrorBodySnippet(res: Response, maxChars = 200): Promise { - const { url, fetchImpl, filePathHint, maxBytes } = options; - const fetcher: FetchLike | undefined = fetchImpl ?? globalThis.fetch; - if (!fetcher) { - throw new Error("fetch is not available"); - } + const { url, fetchImpl, filePathHint, maxBytes, maxRedirects, ssrfPolicy, lookupFn } = options; let res: Response; + let finalUrl = url; + let release: (() => Promise) | null = null; try { - res = await fetcher(url); + const result = await fetchWithSsrFGuard({ + url, + fetchImpl, + maxRedirects, + policy: ssrfPolicy, + lookupFn, + }); + res = result.response; + finalUrl = result.finalUrl; + release = result.release; } catch (err) { throw new MediaFetchError("fetch_failed", `Failed to fetch media from ${url}: ${String(err)}`); } - if (!res.ok) { - const statusText = res.statusText ? ` ${res.statusText}` : ""; - const redirected = res.url && res.url !== url ? ` (redirected to ${res.url})` : ""; - let detail = `HTTP ${res.status}${statusText}`; - if (!res.body) { - detail = `HTTP ${res.status}${statusText}; empty response body`; - } else { - const snippet = await readErrorBodySnippet(res); - if (snippet) { - detail += `; body: ${snippet}`; + try { + if (!res.ok) { + const statusText = res.statusText ? ` ${res.statusText}` : ""; + const redirected = finalUrl !== url ? ` (redirected to ${finalUrl})` : ""; + let detail = `HTTP ${res.status}${statusText}`; + if (!res.body) { + detail = `HTTP ${res.status}${statusText}; empty response body`; + } else { + const snippet = await readErrorBodySnippet(res); + if (snippet) { + detail += `; body: ${snippet}`; + } } - } - throw new MediaFetchError( - "http_error", - `Failed to fetch media from ${url}${redirected}: ${detail}`, - ); - } - - const contentLength = res.headers.get("content-length"); - if (maxBytes && contentLength) { - const length = Number(contentLength); - if (Number.isFinite(length) && length > maxBytes) { throw new MediaFetchError( - "max_bytes", - `Failed to fetch media from ${url}: content length ${length} exceeds maxBytes ${maxBytes}`, + "http_error", + `Failed to fetch media from ${url}${redirected}: ${detail}`, ); } - } - const buffer = maxBytes - ? await readResponseWithLimit(res, maxBytes) - : Buffer.from(await res.arrayBuffer()); - let fileNameFromUrl: string | undefined; - try { - const parsed = new URL(url); - const base = path.basename(parsed.pathname); - fileNameFromUrl = base || undefined; - } catch { - // ignore parse errors; leave undefined - } + const contentLength = res.headers.get("content-length"); + if (maxBytes && contentLength) { + const length = Number(contentLength); + if (Number.isFinite(length) && length > maxBytes) { + throw new MediaFetchError( + "max_bytes", + `Failed to fetch media from ${url}: content length ${length} exceeds maxBytes ${maxBytes}`, + ); + } + } - const headerFileName = parseContentDispositionFileName(res.headers.get("content-disposition")); - let fileName = - headerFileName || fileNameFromUrl || (filePathHint ? path.basename(filePathHint) : undefined); + const buffer = maxBytes + ? await readResponseWithLimit(res, maxBytes) + : Buffer.from(await res.arrayBuffer()); + let fileNameFromUrl: string | undefined; + try { + const parsed = new URL(finalUrl); + const base = path.basename(parsed.pathname); + fileNameFromUrl = base || undefined; + } catch { + // ignore parse errors; leave undefined + } - const filePathForMime = - headerFileName && path.extname(headerFileName) ? headerFileName : (filePathHint ?? url); - const contentType = await detectMime({ - buffer, - headerMime: res.headers.get("content-type"), - filePath: filePathForMime, - }); - if (fileName && !path.extname(fileName) && contentType) { - const ext = extensionForMime(contentType); - if (ext) { - fileName = `${fileName}${ext}`; + const headerFileName = parseContentDispositionFileName(res.headers.get("content-disposition")); + let fileName = + headerFileName || fileNameFromUrl || (filePathHint ? path.basename(filePathHint) : undefined); + + const filePathForMime = + headerFileName && path.extname(headerFileName) ? headerFileName : (filePathHint ?? finalUrl); + const contentType = await detectMime({ + buffer, + headerMime: res.headers.get("content-type"), + filePath: filePathForMime, + }); + if (fileName && !path.extname(fileName) && contentType) { + const ext = extensionForMime(contentType); + if (ext) { + fileName = `${fileName}${ext}`; + } + } + + return { + buffer, + contentType: contentType ?? undefined, + fileName, + }; + } finally { + if (release) { + await release(); } } - - return { - buffer, - contentType: contentType ?? undefined, - fileName, - }; } async function readResponseWithLimit(res: Response, maxBytes: number): Promise { diff --git a/src/media/input-files.ts b/src/media/input-files.ts index 677bba74f..915bc6b70 100644 --- a/src/media/input-files.ts +++ b/src/media/input-files.ts @@ -1,9 +1,4 @@ -import type { Dispatcher } from "undici"; -import { - closeDispatcher, - createPinnedDispatcher, - resolvePinnedHostname, -} from "../infra/net/ssrf.js"; +import { fetchWithSsrFGuard } from "../infra/net/fetch-guard.js"; import { logWarn } from "../logger.js"; type CanvasModule = typeof import("@napi-rs/canvas"); @@ -112,10 +107,6 @@ export const DEFAULT_INPUT_PDF_MAX_PAGES = 4; export const DEFAULT_INPUT_PDF_MAX_PIXELS = 4_000_000; export const DEFAULT_INPUT_PDF_MIN_TEXT_CHARS = 200; -function isRedirectStatus(status: number): boolean { - return status === 301 || status === 302 || status === 303 || status === 307 || status === 308; -} - export function normalizeMimeType(value: string | undefined): string | undefined { if (!value) { return undefined; @@ -151,72 +142,39 @@ export async function fetchWithGuard(params: { timeoutMs: number; maxRedirects: number; }): Promise { - let currentUrl = params.url; - let redirectCount = 0; - - const controller = new AbortController(); - const timeoutId = setTimeout(() => controller.abort(), params.timeoutMs); + const { response, release } = await fetchWithSsrFGuard({ + url: params.url, + maxRedirects: params.maxRedirects, + timeoutMs: params.timeoutMs, + headers: { "User-Agent": "OpenClaw-Gateway/1.0" }, + }); try { - while (true) { - const parsedUrl = new URL(currentUrl); - if (!["http:", "https:"].includes(parsedUrl.protocol)) { - throw new Error(`Invalid URL protocol: ${parsedUrl.protocol}. Only HTTP/HTTPS allowed.`); - } - const pinned = await resolvePinnedHostname(parsedUrl.hostname); - const dispatcher = createPinnedDispatcher(pinned); + if (!response.ok) { + throw new Error(`Failed to fetch: ${response.status} ${response.statusText}`); + } - try { - const response = await fetch(parsedUrl, { - signal: controller.signal, - headers: { "User-Agent": "OpenClaw-Gateway/1.0" }, - redirect: "manual", - dispatcher, - } as RequestInit & { dispatcher: Dispatcher }); - - if (isRedirectStatus(response.status)) { - const location = response.headers.get("location"); - if (!location) { - throw new Error(`Redirect missing location header (${response.status})`); - } - redirectCount += 1; - if (redirectCount > params.maxRedirects) { - throw new Error(`Too many redirects (limit: ${params.maxRedirects})`); - } - void response.body?.cancel(); - currentUrl = new URL(location, parsedUrl).toString(); - continue; - } - - if (!response.ok) { - throw new Error(`Failed to fetch: ${response.status} ${response.statusText}`); - } - - const contentLength = response.headers.get("content-length"); - if (contentLength) { - const size = parseInt(contentLength, 10); - if (size > params.maxBytes) { - throw new Error(`Content too large: ${size} bytes (limit: ${params.maxBytes} bytes)`); - } - } - - const buffer = Buffer.from(await response.arrayBuffer()); - if (buffer.byteLength > params.maxBytes) { - throw new Error( - `Content too large: ${buffer.byteLength} bytes (limit: ${params.maxBytes} bytes)`, - ); - } - - const contentType = response.headers.get("content-type") || undefined; - const parsed = parseContentType(contentType); - const mimeType = parsed.mimeType ?? "application/octet-stream"; - return { buffer, mimeType, contentType }; - } finally { - await closeDispatcher(dispatcher); + const contentLength = response.headers.get("content-length"); + if (contentLength) { + const size = parseInt(contentLength, 10); + if (size > params.maxBytes) { + throw new Error(`Content too large: ${size} bytes (limit: ${params.maxBytes} bytes)`); } } + + const buffer = Buffer.from(await response.arrayBuffer()); + if (buffer.byteLength > params.maxBytes) { + throw new Error( + `Content too large: ${buffer.byteLength} bytes (limit: ${params.maxBytes} bytes)`, + ); + } + + const contentType = response.headers.get("content-type") || undefined; + const parsed = parseContentType(contentType); + const mimeType = parsed.mimeType ?? "application/octet-stream"; + return { buffer, mimeType, contentType }; } finally { - clearTimeout(timeoutId); + await release(); } } diff --git a/src/slack/monitor/media.test.ts b/src/slack/monitor/media.test.ts index b9c604984..5d8565e21 100644 --- a/src/slack/monitor/media.test.ts +++ b/src/slack/monitor/media.test.ts @@ -1,4 +1,5 @@ import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import * as ssrf from "../../infra/net/ssrf.js"; // Store original fetch const originalFetch = globalThis.fetch; @@ -171,11 +172,21 @@ describe("resolveSlackMedia", () => { beforeEach(() => { mockFetch = vi.fn(); globalThis.fetch = mockFetch as typeof fetch; + vi.spyOn(ssrf, "resolvePinnedHostname").mockImplementation(async (hostname) => { + const normalized = hostname.trim().toLowerCase().replace(/\.$/, ""); + const addresses = ["93.184.216.34"]; + return { + hostname: normalized, + addresses, + lookup: ssrf.createPinnedLookup({ hostname: normalized, addresses }), + }; + }); }); afterEach(() => { globalThis.fetch = originalFetch; vi.resetModules(); + vi.restoreAllMocks(); }); it("prefers url_private_download over url_private", async () => { diff --git a/src/slack/monitor/media.ts b/src/slack/monitor/media.ts index 666fe1f27..31be4f200 100644 --- a/src/slack/monitor/media.ts +++ b/src/slack/monitor/media.ts @@ -44,6 +44,38 @@ function assertSlackFileUrl(rawUrl: string): URL { return parsed; } +function resolveRequestUrl(input: RequestInfo | URL): string { + if (typeof input === "string") { + return input; + } + if (input instanceof URL) { + return input.toString(); + } + if ("url" in input && typeof input.url === "string") { + return input.url; + } + return String(input); +} + +function createSlackMediaFetch(token: string): FetchLike { + let includeAuth = true; + return async (input, init) => { + const url = resolveRequestUrl(input); + const { headers: initHeaders, redirect: _redirect, ...rest } = init ?? {}; + const headers = new Headers(initHeaders); + + if (includeAuth) { + includeAuth = false; + const parsed = assertSlackFileUrl(url); + headers.set("Authorization", `Bearer ${token}`); + return fetch(parsed.href, { ...rest, headers, redirect: "manual" }); + } + + headers.delete("Authorization"); + return fetch(url, { ...rest, headers, redirect: "manual" }); + }; +} + /** * Fetches a URL with Authorization header, handling cross-origin redirects. * Node.js fetch strips Authorization headers on cross-origin redirects for security. @@ -100,13 +132,9 @@ export async function resolveSlackMedia(params: { } try { // Note: fetchRemoteMedia calls fetchImpl(url) with the URL string today and - // handles size limits internally. We ignore init options because - // fetchWithSlackAuth handles redirect/auth behavior specially. - const fetchImpl: FetchLike = (input) => { - const inputUrl = - typeof input === "string" ? input : input instanceof URL ? input.href : input.url; - return fetchWithSlackAuth(inputUrl, params.token); - }; + // handles size limits internally. Provide a fetcher that uses auth once, then lets + // the redirect chain continue without credentials. + const fetchImpl = createSlackMediaFetch(params.token); const fetched = await fetchRemoteMedia({ url, fetchImpl, diff --git a/src/web/media.test.ts b/src/web/media.test.ts index b16e0dff4..ff40ef0c7 100644 --- a/src/web/media.test.ts +++ b/src/web/media.test.ts @@ -2,7 +2,8 @@ import fs from "node:fs/promises"; import os from "node:os"; import path from "node:path"; import sharp from "sharp"; -import { afterEach, describe, expect, it, vi } from "vitest"; +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import * as ssrf from "../infra/net/ssrf.js"; import { optimizeImageToPng } from "../media/image-ops.js"; import { loadWebMedia, loadWebMediaRaw, optimizeImageToJpeg } from "./media.js"; @@ -31,9 +32,22 @@ function buildDeterministicBytes(length: number): Buffer { afterEach(async () => { await Promise.all(tmpFiles.map((file) => fs.rm(file, { force: true }))); tmpFiles.length = 0; + vi.restoreAllMocks(); }); describe("web media loading", () => { + beforeEach(() => { + vi.spyOn(ssrf, "resolvePinnedHostname").mockImplementation(async (hostname) => { + const normalized = hostname.trim().toLowerCase().replace(/\.$/, ""); + const addresses = ["93.184.216.34"]; + return { + hostname: normalized, + addresses, + lookup: ssrf.createPinnedLookup({ hostname: normalized, addresses }), + }; + }); + }); + it("compresses large local images under the provided cap", async () => { const buffer = await sharp({ create: { diff --git a/src/web/media.ts b/src/web/media.ts index f3f74c779..edc172f35 100644 --- a/src/web/media.ts +++ b/src/web/media.ts @@ -1,6 +1,7 @@ import fs from "node:fs/promises"; import path from "node:path"; import { fileURLToPath } from "node:url"; +import type { SsrFPolicy } from "../infra/net/ssrf.js"; import { logVerbose, shouldLogVerbose } from "../globals.js"; import { type MediaKind, maxBytesForKind, mediaKindFromMime } from "../media/constants.js"; import { fetchRemoteMedia } from "../media/fetch.js"; @@ -23,6 +24,7 @@ export type WebMediaResult = { type WebMediaOptions = { maxBytes?: number; optimizeImages?: boolean; + ssrfPolicy?: SsrFPolicy; }; const HEIC_MIME_RE = /^image\/hei[cf]$/i; @@ -122,7 +124,7 @@ async function loadWebMediaInternal( mediaUrl: string, options: WebMediaOptions = {}, ): Promise { - const { maxBytes, optimizeImages = true } = options; + const { maxBytes, optimizeImages = true, ssrfPolicy } = options; // Use fileURLToPath for proper handling of file:// URLs (handles file://localhost/path, etc.) if (mediaUrl.startsWith("file://")) { try { @@ -209,7 +211,7 @@ async function loadWebMediaInternal( : optimizeImages ? Math.max(maxBytes, defaultFetchCap) : maxBytes; - const fetched = await fetchRemoteMedia({ url: mediaUrl, maxBytes: fetchCap }); + const fetched = await fetchRemoteMedia({ url: mediaUrl, maxBytes: fetchCap, ssrfPolicy }); const { buffer, contentType, fileName } = fetched; const kind = mediaKindFromMime(contentType); return await clampAndFinalize({ buffer, contentType, kind, fileName }); @@ -239,20 +241,27 @@ async function loadWebMediaInternal( }); } -export async function loadWebMedia(mediaUrl: string, maxBytes?: number): Promise { +export async function loadWebMedia( + mediaUrl: string, + maxBytes?: number, + options?: { ssrfPolicy?: SsrFPolicy }, +): Promise { return await loadWebMediaInternal(mediaUrl, { maxBytes, optimizeImages: true, + ssrfPolicy: options?.ssrfPolicy, }); } export async function loadWebMediaRaw( mediaUrl: string, maxBytes?: number, + options?: { ssrfPolicy?: SsrFPolicy }, ): Promise { return await loadWebMediaInternal(mediaUrl, { maxBytes, optimizeImages: false, + ssrfPolicy: options?.ssrfPolicy, }); }