fix: guard remote media fetches with SSRF checks

This commit is contained in:
Peter Steinberger
2026-02-02 04:04:27 -08:00
parent d842b28a15
commit 81c68f582d
11 changed files with 422 additions and 241 deletions

View File

@@ -22,6 +22,7 @@ Docs: https://docs.openclaw.ai
### Fixes
- Security: guard remote media fetches with SSRF protections (block private/localhost, DNS pinning).
- Plugins: validate plugin/hook install paths and reject traversal-like names.
- Telegram: add download timeouts for file fetches. (#6914) Thanks @hclsys.
- Telegram: enforce thread specs for DM vs forum sends. (#6833) Thanks @obviyus.

View File

@@ -1,13 +1,8 @@
import type { Dispatcher } from "undici";
import { Type } from "@sinclair/typebox";
import type { OpenClawConfig } from "../../config/config.js";
import type { AnyAgentTool } from "./common.js";
import {
closeDispatcher,
createPinnedDispatcher,
resolvePinnedHostname,
SsrFBlockedError,
} from "../../infra/net/ssrf.js";
import { fetchWithSsrFGuard } from "../../infra/net/fetch-guard.js";
import { SsrFBlockedError } from "../../infra/net/ssrf.js";
import { wrapExternalContent, wrapWebContent } from "../../security/external-content.js";
import { stringEnum } from "../schema/typebox.js";
import { jsonResult, readNumberParam, readStringParam } from "./common.js";
@@ -184,79 +179,6 @@ function looksLikeHtml(value: string): boolean {
return head.startsWith("<!doctype html") || head.startsWith("<html");
}
function isRedirectStatus(status: number): boolean {
return status === 301 || status === 302 || status === 303 || status === 307 || status === 308;
}
async function fetchWithRedirects(params: {
url: string;
maxRedirects: number;
timeoutSeconds: number;
userAgent: string;
}): Promise<{ response: Response; finalUrl: string; dispatcher: Dispatcher }> {
const signal = withTimeout(undefined, params.timeoutSeconds * 1000);
const visited = new Set<string>();
let currentUrl = params.url;
let redirectCount = 0;
while (true) {
let parsedUrl: URL;
try {
parsedUrl = new URL(currentUrl);
} catch {
throw new Error("Invalid URL: must be http or https");
}
if (!["http:", "https:"].includes(parsedUrl.protocol)) {
throw new Error("Invalid URL: must be http or https");
}
const pinned = await resolvePinnedHostname(parsedUrl.hostname);
const dispatcher = createPinnedDispatcher(pinned);
let res: Response;
try {
res = await fetch(parsedUrl.toString(), {
method: "GET",
headers: {
Accept: "*/*",
"User-Agent": params.userAgent,
"Accept-Language": "en-US,en;q=0.9",
},
signal,
redirect: "manual",
dispatcher,
} as RequestInit);
} catch (err) {
await closeDispatcher(dispatcher);
throw err;
}
if (isRedirectStatus(res.status)) {
const location = res.headers.get("location");
if (!location) {
await closeDispatcher(dispatcher);
throw new Error(`Redirect missing location header (${res.status})`);
}
redirectCount += 1;
if (redirectCount > params.maxRedirects) {
await closeDispatcher(dispatcher);
throw new Error(`Too many redirects (limit: ${params.maxRedirects})`);
}
const nextUrl = new URL(location, parsedUrl).toString();
if (visited.has(nextUrl)) {
await closeDispatcher(dispatcher);
throw new Error("Redirect loop detected");
}
visited.add(nextUrl);
void res.body?.cancel();
await closeDispatcher(dispatcher);
currentUrl = nextUrl;
continue;
}
return { response: res, finalUrl: currentUrl, dispatcher };
}
}
function formatWebFetchErrorDetail(params: {
detail: string;
contentType?: string | null;
@@ -465,18 +387,22 @@ async function runWebFetch(params: {
const start = Date.now();
let res: Response;
let dispatcher: Dispatcher | null = null;
let release: (() => Promise<void>) | null = null;
let finalUrl = params.url;
try {
const result = await fetchWithRedirects({
const result = await fetchWithSsrFGuard({
url: params.url,
maxRedirects: params.maxRedirects,
timeoutSeconds: params.timeoutSeconds,
userAgent: params.userAgent,
timeoutMs: params.timeoutSeconds * 1000,
headers: {
Accept: "*/*",
"User-Agent": params.userAgent,
"Accept-Language": "en-US,en;q=0.9",
},
});
res = result.response;
finalUrl = result.finalUrl;
dispatcher = result.dispatcher;
release = result.release;
} catch (error) {
if (error instanceof SsrFBlockedError) {
throw error;
@@ -630,7 +556,9 @@ async function runWebFetch(params: {
writeCache(FETCH_CACHE, cacheKey, payload, params.cacheTtlMs);
return payload;
} finally {
await closeDispatcher(dispatcher);
if (release) {
await release();
}
}
}

View File

@@ -0,0 +1,170 @@
import type { Dispatcher } from "undici";
import {
closeDispatcher,
createPinnedDispatcher,
resolvePinnedHostname,
resolvePinnedHostnameWithPolicy,
type LookupFn,
type SsrFPolicy,
} from "./ssrf.js";
/** Minimal fetch signature so callers/tests can inject a custom fetcher. */
type FetchLike = (input: RequestInfo | URL, init?: RequestInit) => Promise<Response>;
/** Options accepted by the guarded fetch helper. */
export type GuardedFetchOptions = {
  /** Absolute http(s) URL to fetch. */
  url: string;
  /** Override for the global fetch implementation (used by tests). */
  fetchImpl?: FetchLike;
  /** HTTP method; the runtime default applies when omitted. */
  method?: string;
  /** Headers sent on every hop of the manually-followed redirect chain. */
  headers?: HeadersInit;
  /** Redirect budget; falls back to DEFAULT_MAX_REDIRECTS when not a finite number. */
  maxRedirects?: number;
  /** Overall timeout in milliseconds, covering the whole redirect chain. */
  timeoutMs?: number;
  /** External abort signal, combined with the timeout when both are present. */
  signal?: AbortSignal;
  /** SSRF policy escape hatch (allow private networks / specific hostnames). */
  policy?: SsrFPolicy;
  /** DNS lookup override (used by tests). */
  lookupFn?: LookupFn;
};
/** Result of a guarded fetch; `release()` must be called once done with `response`. */
export type GuardedFetchResult = {
  response: Response;
  /** URL of the final hop after redirects were followed. */
  finalUrl: string;
  /** Closes the pinned dispatcher and clears the timeout; safe to call more than once. */
  release: () => Promise<void>;
};
// Default redirect budget when the caller does not supply one.
const DEFAULT_MAX_REDIRECTS = 3;
/** True when `status` is one of the HTTP redirect codes we follow manually. */
function isRedirectStatus(status: number): boolean {
  const redirectCodes = [301, 302, 303, 307, 308];
  return redirectCodes.includes(status);
}
/**
 * Combines an optional caller-supplied AbortSignal with an optional timeout
 * into a single signal. `cleanup()` must be invoked once the request settles
 * to clear the timer and detach the forwarding listener.
 */
function buildAbortSignal(params: { timeoutMs?: number; signal?: AbortSignal }): {
  signal?: AbortSignal;
  cleanup: () => void;
} {
  const { timeoutMs, signal } = params;
  const noop = () => {};
  // Without a timeout there is nothing to compose: pass the caller's signal
  // through unchanged (or no signal at all).
  if (!timeoutMs) {
    return signal ? { signal, cleanup: noop } : { signal: undefined, cleanup: noop };
  }
  const controller = new AbortController();
  const timer = setTimeout(() => controller.abort(), timeoutMs);
  const forwardAbort = () => controller.abort();
  if (signal) {
    // Mirror the caller's signal onto our controller; an already-aborted
    // signal aborts immediately instead of registering a listener.
    if (signal.aborted) {
      controller.abort();
    } else {
      signal.addEventListener("abort", forwardAbort, { once: true });
    }
  }
  return {
    signal: controller.signal,
    cleanup: () => {
      clearTimeout(timer);
      signal?.removeEventListener("abort", forwardAbort);
    },
  };
}
/**
 * Fetches `params.url` with SSRF protections: every hop of the redirect chain
 * is re-resolved and DNS-pinned via the ssrf helpers before the request is
 * issued, and redirects are followed manually (never by the runtime) so each
 * target is re-validated.
 *
 * Returns the final response plus a `release()` callback the caller MUST
 * invoke (typically in a `finally`) to close the pinned dispatcher and clear
 * the timeout timer.
 *
 * @throws SsrFBlockedError (propagated from the resolve helpers) for blocked
 *   destinations; plain Error for invalid URLs, redirect problems, or
 *   transport failures.
 */
export async function fetchWithSsrFGuard(params: GuardedFetchOptions): Promise<GuardedFetchResult> {
  const fetcher: FetchLike | undefined = params.fetchImpl ?? globalThis.fetch;
  if (!fetcher) {
    throw new Error("fetch is not available");
  }
  // Normalize the redirect budget: missing/non-finite -> default; negative -> 0.
  const maxRedirects =
    typeof params.maxRedirects === "number" && Number.isFinite(params.maxRedirects)
      ? Math.max(0, Math.floor(params.maxRedirects))
      : DEFAULT_MAX_REDIRECTS;
  const { signal, cleanup } = buildAbortSignal({
    timeoutMs: params.timeoutMs,
    signal: params.signal,
  });
  // `release` is idempotent (guarded by `released`), so error paths below may
  // call it and the caller may call it again without double-closing.
  let released = false;
  const release = async (dispatcher?: Dispatcher | null) => {
    if (released) {
      return;
    }
    released = true;
    cleanup();
    await closeDispatcher(dispatcher ?? undefined);
  };
  const visited = new Set<string>();
  let currentUrl = params.url;
  let redirectCount = 0;
  while (true) {
    let parsedUrl: URL;
    try {
      parsedUrl = new URL(currentUrl);
    } catch {
      await release();
      throw new Error("Invalid URL: must be http or https");
    }
    if (!["http:", "https:"].includes(parsedUrl.protocol)) {
      await release();
      throw new Error("Invalid URL: must be http or https");
    }
    let dispatcher: Dispatcher | null = null;
    try {
      // Take the policy-aware resolver only when the policy actually relaxes
      // something; otherwise use the strict default resolver.
      const usePolicy = Boolean(
        params.policy?.allowPrivateNetwork || params.policy?.allowedHostnames?.length,
      );
      const pinned = usePolicy
        ? await resolvePinnedHostnameWithPolicy(parsedUrl.hostname, {
            lookupFn: params.lookupFn,
            policy: params.policy,
          })
        : await resolvePinnedHostname(parsedUrl.hostname, params.lookupFn);
      // Pin the connection to the resolved addresses so a DNS change between
      // resolve and connect (rebinding) cannot retarget the request.
      dispatcher = createPinnedDispatcher(pinned);
      const init: RequestInit & { dispatcher?: Dispatcher } = {
        redirect: "manual", // redirects are followed by this loop, not the runtime
        dispatcher,
        ...(params.method ? { method: params.method } : {}),
        ...(params.headers ? { headers: params.headers } : {}),
        ...(signal ? { signal } : {}),
      };
      const response = await fetcher(parsedUrl.toString(), init);
      if (isRedirectStatus(response.status)) {
        const location = response.headers.get("location");
        if (!location) {
          await release(dispatcher);
          throw new Error(`Redirect missing location header (${response.status})`);
        }
        redirectCount += 1;
        if (redirectCount > maxRedirects) {
          await release(dispatcher);
          throw new Error(`Too many redirects (limit: ${maxRedirects})`);
        }
        // Location may be relative; resolve it against the current hop's URL.
        const nextUrl = new URL(location, parsedUrl).toString();
        if (visited.has(nextUrl)) {
          await release(dispatcher);
          throw new Error("Redirect loop detected");
        }
        visited.add(nextUrl);
        // Discard the redirect body and close this hop's dispatcher before
        // looping; the next hop gets a freshly pinned dispatcher. Note this
        // deliberately does NOT call `release` so later hops can still use it.
        void response.body?.cancel();
        await closeDispatcher(dispatcher);
        currentUrl = nextUrl;
        continue;
      }
      return {
        response,
        finalUrl: currentUrl,
        // Bind the final hop's dispatcher into the caller-facing release.
        release: async () => release(dispatcher),
      };
    } catch (err) {
      await release(dispatcher);
      throw err;
    }
  }
}

View File

@@ -15,7 +15,12 @@ export class SsrFBlockedError extends Error {
}
}
type LookupFn = typeof dnsLookup;
export type LookupFn = typeof dnsLookup;
export type SsrFPolicy = {
allowPrivateNetwork?: boolean;
allowedHostnames?: string[];
};
const PRIVATE_IPV6_PREFIXES = ["fe80:", "fec0:", "fc", "fd"];
const BLOCKED_HOSTNAMES = new Set(["localhost", "metadata.google.internal"]);
@@ -28,6 +33,13 @@ function normalizeHostname(hostname: string): string {
return normalized;
}
function normalizeHostnameSet(values?: string[]): Set<string> {
if (!values || values.length === 0) {
return new Set<string>();
}
return new Set(values.map((value) => normalizeHostname(value)).filter(Boolean));
}
function parseIpv4(address: string): number[] | null {
const parts = address.split(".");
if (parts.length !== 4) {
@@ -206,31 +218,40 @@ export type PinnedHostname = {
lookup: typeof dnsLookupCb;
};
export async function resolvePinnedHostname(
export async function resolvePinnedHostnameWithPolicy(
hostname: string,
lookupFn: LookupFn = dnsLookup,
params: { lookupFn?: LookupFn; policy?: SsrFPolicy } = {},
): Promise<PinnedHostname> {
const normalized = normalizeHostname(hostname);
if (!normalized) {
throw new Error("Invalid hostname");
}
if (isBlockedHostname(normalized)) {
throw new SsrFBlockedError(`Blocked hostname: ${hostname}`);
}
if (isPrivateIpAddress(normalized)) {
throw new SsrFBlockedError("Blocked: private/internal IP address");
const allowPrivateNetwork = Boolean(params.policy?.allowPrivateNetwork);
const allowedHostnames = normalizeHostnameSet(params.policy?.allowedHostnames);
const isExplicitAllowed = allowedHostnames.has(normalized);
if (!allowPrivateNetwork && !isExplicitAllowed) {
if (isBlockedHostname(normalized)) {
throw new SsrFBlockedError(`Blocked hostname: ${hostname}`);
}
if (isPrivateIpAddress(normalized)) {
throw new SsrFBlockedError("Blocked: private/internal IP address");
}
}
const lookupFn = params.lookupFn ?? dnsLookup;
const results = await lookupFn(normalized, { all: true });
if (results.length === 0) {
throw new Error(`Unable to resolve hostname: ${hostname}`);
}
for (const entry of results) {
if (isPrivateIpAddress(entry.address)) {
throw new SsrFBlockedError("Blocked: resolves to private/internal IP address");
if (!allowPrivateNetwork && !isExplicitAllowed) {
for (const entry of results) {
if (isPrivateIpAddress(entry.address)) {
throw new SsrFBlockedError("Blocked: resolves to private/internal IP address");
}
}
}
@@ -246,6 +267,13 @@ export async function resolvePinnedHostname(
};
}
export async function resolvePinnedHostname(
hostname: string,
lookupFn: LookupFn = dnsLookup,
): Promise<PinnedHostname> {
return await resolvePinnedHostnameWithPolicy(hostname, { lookupFn });
}
export function createPinnedDispatcher(pinned: PinnedHostname): Dispatcher {
return new Agent({
connect: {

View File

@@ -1,4 +1,4 @@
import { describe, expect, it } from "vitest";
import { describe, expect, it, vi } from "vitest";
import { fetchRemoteMedia } from "./fetch.js";
function makeStream(chunks: Uint8Array[]) {
@@ -14,6 +14,7 @@ function makeStream(chunks: Uint8Array[]) {
describe("fetchRemoteMedia", () => {
it("rejects when content-length exceeds maxBytes", async () => {
const lookupFn = vi.fn(async () => [{ address: "93.184.216.34", family: 4 }]);
const fetchImpl = async () =>
new Response(makeStream([new Uint8Array([1, 2, 3, 4, 5])]), {
status: 200,
@@ -25,11 +26,13 @@ describe("fetchRemoteMedia", () => {
url: "https://example.com/file.bin",
fetchImpl,
maxBytes: 4,
lookupFn,
}),
).rejects.toThrow("exceeds maxBytes");
});
it("rejects when streamed payload exceeds maxBytes", async () => {
const lookupFn = vi.fn(async () => [{ address: "93.184.216.34", family: 4 }]);
const fetchImpl = async () =>
new Response(makeStream([new Uint8Array([1, 2, 3]), new Uint8Array([4, 5, 6])]), {
status: 200,
@@ -40,7 +43,20 @@ describe("fetchRemoteMedia", () => {
url: "https://example.com/file.bin",
fetchImpl,
maxBytes: 4,
lookupFn,
}),
).rejects.toThrow("exceeds maxBytes");
});
it("blocks private IP literals before fetching", async () => {
const fetchImpl = vi.fn();
await expect(
fetchRemoteMedia({
url: "http://127.0.0.1/secret.jpg",
fetchImpl,
maxBytes: 1024,
}),
).rejects.toThrow(/private|internal|blocked/i);
expect(fetchImpl).not.toHaveBeenCalled();
});
});

View File

@@ -1,4 +1,6 @@
import path from "node:path";
import type { LookupFn, SsrFPolicy } from "../infra/net/ssrf.js";
import { fetchWithSsrFGuard } from "../infra/net/fetch-guard.js";
import { detectMime, extensionForMime } from "./mime.js";
type FetchMediaResult = {
@@ -26,6 +28,9 @@ type FetchMediaOptions = {
fetchImpl?: FetchLike;
filePathHint?: string;
maxBytes?: number;
maxRedirects?: number;
ssrfPolicy?: SsrFPolicy;
lookupFn?: LookupFn;
};
function stripQuotes(value: string): string {
@@ -73,83 +78,96 @@ async function readErrorBodySnippet(res: Response, maxChars = 200): Promise<stri
}
export async function fetchRemoteMedia(options: FetchMediaOptions): Promise<FetchMediaResult> {
const { url, fetchImpl, filePathHint, maxBytes } = options;
const fetcher: FetchLike | undefined = fetchImpl ?? globalThis.fetch;
if (!fetcher) {
throw new Error("fetch is not available");
}
const { url, fetchImpl, filePathHint, maxBytes, maxRedirects, ssrfPolicy, lookupFn } = options;
let res: Response;
let finalUrl = url;
let release: (() => Promise<void>) | null = null;
try {
res = await fetcher(url);
const result = await fetchWithSsrFGuard({
url,
fetchImpl,
maxRedirects,
policy: ssrfPolicy,
lookupFn,
});
res = result.response;
finalUrl = result.finalUrl;
release = result.release;
} catch (err) {
throw new MediaFetchError("fetch_failed", `Failed to fetch media from ${url}: ${String(err)}`);
}
if (!res.ok) {
const statusText = res.statusText ? ` ${res.statusText}` : "";
const redirected = res.url && res.url !== url ? ` (redirected to ${res.url})` : "";
let detail = `HTTP ${res.status}${statusText}`;
if (!res.body) {
detail = `HTTP ${res.status}${statusText}; empty response body`;
} else {
const snippet = await readErrorBodySnippet(res);
if (snippet) {
detail += `; body: ${snippet}`;
try {
if (!res.ok) {
const statusText = res.statusText ? ` ${res.statusText}` : "";
const redirected = finalUrl !== url ? ` (redirected to ${finalUrl})` : "";
let detail = `HTTP ${res.status}${statusText}`;
if (!res.body) {
detail = `HTTP ${res.status}${statusText}; empty response body`;
} else {
const snippet = await readErrorBodySnippet(res);
if (snippet) {
detail += `; body: ${snippet}`;
}
}
}
throw new MediaFetchError(
"http_error",
`Failed to fetch media from ${url}${redirected}: ${detail}`,
);
}
const contentLength = res.headers.get("content-length");
if (maxBytes && contentLength) {
const length = Number(contentLength);
if (Number.isFinite(length) && length > maxBytes) {
throw new MediaFetchError(
"max_bytes",
`Failed to fetch media from ${url}: content length ${length} exceeds maxBytes ${maxBytes}`,
"http_error",
`Failed to fetch media from ${url}${redirected}: ${detail}`,
);
}
}
const buffer = maxBytes
? await readResponseWithLimit(res, maxBytes)
: Buffer.from(await res.arrayBuffer());
let fileNameFromUrl: string | undefined;
try {
const parsed = new URL(url);
const base = path.basename(parsed.pathname);
fileNameFromUrl = base || undefined;
} catch {
// ignore parse errors; leave undefined
}
const contentLength = res.headers.get("content-length");
if (maxBytes && contentLength) {
const length = Number(contentLength);
if (Number.isFinite(length) && length > maxBytes) {
throw new MediaFetchError(
"max_bytes",
`Failed to fetch media from ${url}: content length ${length} exceeds maxBytes ${maxBytes}`,
);
}
}
const headerFileName = parseContentDispositionFileName(res.headers.get("content-disposition"));
let fileName =
headerFileName || fileNameFromUrl || (filePathHint ? path.basename(filePathHint) : undefined);
const buffer = maxBytes
? await readResponseWithLimit(res, maxBytes)
: Buffer.from(await res.arrayBuffer());
let fileNameFromUrl: string | undefined;
try {
const parsed = new URL(finalUrl);
const base = path.basename(parsed.pathname);
fileNameFromUrl = base || undefined;
} catch {
// ignore parse errors; leave undefined
}
const filePathForMime =
headerFileName && path.extname(headerFileName) ? headerFileName : (filePathHint ?? url);
const contentType = await detectMime({
buffer,
headerMime: res.headers.get("content-type"),
filePath: filePathForMime,
});
if (fileName && !path.extname(fileName) && contentType) {
const ext = extensionForMime(contentType);
if (ext) {
fileName = `${fileName}${ext}`;
const headerFileName = parseContentDispositionFileName(res.headers.get("content-disposition"));
let fileName =
headerFileName || fileNameFromUrl || (filePathHint ? path.basename(filePathHint) : undefined);
const filePathForMime =
headerFileName && path.extname(headerFileName) ? headerFileName : (filePathHint ?? finalUrl);
const contentType = await detectMime({
buffer,
headerMime: res.headers.get("content-type"),
filePath: filePathForMime,
});
if (fileName && !path.extname(fileName) && contentType) {
const ext = extensionForMime(contentType);
if (ext) {
fileName = `${fileName}${ext}`;
}
}
return {
buffer,
contentType: contentType ?? undefined,
fileName,
};
} finally {
if (release) {
await release();
}
}
return {
buffer,
contentType: contentType ?? undefined,
fileName,
};
}
async function readResponseWithLimit(res: Response, maxBytes: number): Promise<Buffer> {

View File

@@ -1,9 +1,4 @@
import type { Dispatcher } from "undici";
import {
closeDispatcher,
createPinnedDispatcher,
resolvePinnedHostname,
} from "../infra/net/ssrf.js";
import { fetchWithSsrFGuard } from "../infra/net/fetch-guard.js";
import { logWarn } from "../logger.js";
type CanvasModule = typeof import("@napi-rs/canvas");
@@ -112,10 +107,6 @@ export const DEFAULT_INPUT_PDF_MAX_PAGES = 4;
export const DEFAULT_INPUT_PDF_MAX_PIXELS = 4_000_000;
export const DEFAULT_INPUT_PDF_MIN_TEXT_CHARS = 200;
function isRedirectStatus(status: number): boolean {
return status === 301 || status === 302 || status === 303 || status === 307 || status === 308;
}
export function normalizeMimeType(value: string | undefined): string | undefined {
if (!value) {
return undefined;
@@ -151,72 +142,39 @@ export async function fetchWithGuard(params: {
timeoutMs: number;
maxRedirects: number;
}): Promise<InputFetchResult> {
let currentUrl = params.url;
let redirectCount = 0;
const controller = new AbortController();
const timeoutId = setTimeout(() => controller.abort(), params.timeoutMs);
const { response, release } = await fetchWithSsrFGuard({
url: params.url,
maxRedirects: params.maxRedirects,
timeoutMs: params.timeoutMs,
headers: { "User-Agent": "OpenClaw-Gateway/1.0" },
});
try {
while (true) {
const parsedUrl = new URL(currentUrl);
if (!["http:", "https:"].includes(parsedUrl.protocol)) {
throw new Error(`Invalid URL protocol: ${parsedUrl.protocol}. Only HTTP/HTTPS allowed.`);
}
const pinned = await resolvePinnedHostname(parsedUrl.hostname);
const dispatcher = createPinnedDispatcher(pinned);
if (!response.ok) {
throw new Error(`Failed to fetch: ${response.status} ${response.statusText}`);
}
try {
const response = await fetch(parsedUrl, {
signal: controller.signal,
headers: { "User-Agent": "OpenClaw-Gateway/1.0" },
redirect: "manual",
dispatcher,
} as RequestInit & { dispatcher: Dispatcher });
if (isRedirectStatus(response.status)) {
const location = response.headers.get("location");
if (!location) {
throw new Error(`Redirect missing location header (${response.status})`);
}
redirectCount += 1;
if (redirectCount > params.maxRedirects) {
throw new Error(`Too many redirects (limit: ${params.maxRedirects})`);
}
void response.body?.cancel();
currentUrl = new URL(location, parsedUrl).toString();
continue;
}
if (!response.ok) {
throw new Error(`Failed to fetch: ${response.status} ${response.statusText}`);
}
const contentLength = response.headers.get("content-length");
if (contentLength) {
const size = parseInt(contentLength, 10);
if (size > params.maxBytes) {
throw new Error(`Content too large: ${size} bytes (limit: ${params.maxBytes} bytes)`);
}
}
const buffer = Buffer.from(await response.arrayBuffer());
if (buffer.byteLength > params.maxBytes) {
throw new Error(
`Content too large: ${buffer.byteLength} bytes (limit: ${params.maxBytes} bytes)`,
);
}
const contentType = response.headers.get("content-type") || undefined;
const parsed = parseContentType(contentType);
const mimeType = parsed.mimeType ?? "application/octet-stream";
return { buffer, mimeType, contentType };
} finally {
await closeDispatcher(dispatcher);
const contentLength = response.headers.get("content-length");
if (contentLength) {
const size = parseInt(contentLength, 10);
if (size > params.maxBytes) {
throw new Error(`Content too large: ${size} bytes (limit: ${params.maxBytes} bytes)`);
}
}
const buffer = Buffer.from(await response.arrayBuffer());
if (buffer.byteLength > params.maxBytes) {
throw new Error(
`Content too large: ${buffer.byteLength} bytes (limit: ${params.maxBytes} bytes)`,
);
}
const contentType = response.headers.get("content-type") || undefined;
const parsed = parseContentType(contentType);
const mimeType = parsed.mimeType ?? "application/octet-stream";
return { buffer, mimeType, contentType };
} finally {
clearTimeout(timeoutId);
await release();
}
}

View File

@@ -1,4 +1,5 @@
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import * as ssrf from "../../infra/net/ssrf.js";
// Store original fetch
const originalFetch = globalThis.fetch;
@@ -171,11 +172,21 @@ describe("resolveSlackMedia", () => {
beforeEach(() => {
mockFetch = vi.fn();
globalThis.fetch = mockFetch as typeof fetch;
vi.spyOn(ssrf, "resolvePinnedHostname").mockImplementation(async (hostname) => {
const normalized = hostname.trim().toLowerCase().replace(/\.$/, "");
const addresses = ["93.184.216.34"];
return {
hostname: normalized,
addresses,
lookup: ssrf.createPinnedLookup({ hostname: normalized, addresses }),
};
});
});
afterEach(() => {
globalThis.fetch = originalFetch;
vi.resetModules();
vi.restoreAllMocks();
});
it("prefers url_private_download over url_private", async () => {

View File

@@ -44,6 +44,38 @@ function assertSlackFileUrl(rawUrl: string): URL {
return parsed;
}
/** Extracts the URL string from any fetch-style input (string, URL, or Request-like). */
function resolveRequestUrl(input: RequestInfo | URL): string {
  if (typeof input === "string") {
    return input;
  }
  if (input instanceof URL) {
    return input.href;
  }
  // Request-like objects expose a string `url`; anything else falls back to
  // default string conversion.
  const maybeUrl = (input as { url?: unknown }).url;
  return typeof maybeUrl === "string" ? maybeUrl : String(input);
}
/**
 * Builds a fetcher for Slack media that attaches the bearer token ONLY to the
 * first request. Subsequent calls (redirect hops) strip Authorization so the
 * token never leaks to a cross-origin redirect target.
 */
function createSlackMediaFetch(token: string): FetchLike {
  let authPending = true;
  return async (input, init) => {
    const targetUrl = resolveRequestUrl(input);
    // Drop any caller-supplied redirect mode; we always use "manual".
    const { headers: baseHeaders, redirect: _ignored, ...restInit } = init ?? {};
    const requestHeaders = new Headers(baseHeaders);
    if (!authPending) {
      // Redirect hop: make sure no credentials ride along.
      requestHeaders.delete("Authorization");
      return fetch(targetUrl, { ...restInit, headers: requestHeaders, redirect: "manual" });
    }
    // First request: flip the flag before validation so a thrown error still
    // marks auth as consumed, then require a Slack file URL and attach the token.
    authPending = false;
    const validated = assertSlackFileUrl(targetUrl);
    requestHeaders.set("Authorization", `Bearer ${token}`);
    return fetch(validated.href, { ...restInit, headers: requestHeaders, redirect: "manual" });
  };
}
/**
* Fetches a URL with Authorization header, handling cross-origin redirects.
* Node.js fetch strips Authorization headers on cross-origin redirects for security.
@@ -100,13 +132,9 @@ export async function resolveSlackMedia(params: {
}
try {
// Note: fetchRemoteMedia calls fetchImpl(url) with the URL string today and
// handles size limits internally. We ignore init options because
// fetchWithSlackAuth handles redirect/auth behavior specially.
const fetchImpl: FetchLike = (input) => {
const inputUrl =
typeof input === "string" ? input : input instanceof URL ? input.href : input.url;
return fetchWithSlackAuth(inputUrl, params.token);
};
// handles size limits internally. Provide a fetcher that uses auth once, then lets
// the redirect chain continue without credentials.
const fetchImpl = createSlackMediaFetch(params.token);
const fetched = await fetchRemoteMedia({
url,
fetchImpl,

View File

@@ -2,7 +2,8 @@ import fs from "node:fs/promises";
import os from "node:os";
import path from "node:path";
import sharp from "sharp";
import { afterEach, describe, expect, it, vi } from "vitest";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import * as ssrf from "../infra/net/ssrf.js";
import { optimizeImageToPng } from "../media/image-ops.js";
import { loadWebMedia, loadWebMediaRaw, optimizeImageToJpeg } from "./media.js";
@@ -31,9 +32,22 @@ function buildDeterministicBytes(length: number): Buffer {
afterEach(async () => {
await Promise.all(tmpFiles.map((file) => fs.rm(file, { force: true })));
tmpFiles.length = 0;
vi.restoreAllMocks();
});
describe("web media loading", () => {
beforeEach(() => {
vi.spyOn(ssrf, "resolvePinnedHostname").mockImplementation(async (hostname) => {
const normalized = hostname.trim().toLowerCase().replace(/\.$/, "");
const addresses = ["93.184.216.34"];
return {
hostname: normalized,
addresses,
lookup: ssrf.createPinnedLookup({ hostname: normalized, addresses }),
};
});
});
it("compresses large local images under the provided cap", async () => {
const buffer = await sharp({
create: {

View File

@@ -1,6 +1,7 @@
import fs from "node:fs/promises";
import path from "node:path";
import { fileURLToPath } from "node:url";
import type { SsrFPolicy } from "../infra/net/ssrf.js";
import { logVerbose, shouldLogVerbose } from "../globals.js";
import { type MediaKind, maxBytesForKind, mediaKindFromMime } from "../media/constants.js";
import { fetchRemoteMedia } from "../media/fetch.js";
@@ -23,6 +24,7 @@ export type WebMediaResult = {
type WebMediaOptions = {
maxBytes?: number;
optimizeImages?: boolean;
ssrfPolicy?: SsrFPolicy;
};
const HEIC_MIME_RE = /^image\/hei[cf]$/i;
@@ -122,7 +124,7 @@ async function loadWebMediaInternal(
mediaUrl: string,
options: WebMediaOptions = {},
): Promise<WebMediaResult> {
const { maxBytes, optimizeImages = true } = options;
const { maxBytes, optimizeImages = true, ssrfPolicy } = options;
// Use fileURLToPath for proper handling of file:// URLs (handles file://localhost/path, etc.)
if (mediaUrl.startsWith("file://")) {
try {
@@ -209,7 +211,7 @@ async function loadWebMediaInternal(
: optimizeImages
? Math.max(maxBytes, defaultFetchCap)
: maxBytes;
const fetched = await fetchRemoteMedia({ url: mediaUrl, maxBytes: fetchCap });
const fetched = await fetchRemoteMedia({ url: mediaUrl, maxBytes: fetchCap, ssrfPolicy });
const { buffer, contentType, fileName } = fetched;
const kind = mediaKindFromMime(contentType);
return await clampAndFinalize({ buffer, contentType, kind, fileName });
@@ -239,20 +241,27 @@ async function loadWebMediaInternal(
});
}
export async function loadWebMedia(mediaUrl: string, maxBytes?: number): Promise<WebMediaResult> {
export async function loadWebMedia(
mediaUrl: string,
maxBytes?: number,
options?: { ssrfPolicy?: SsrFPolicy },
): Promise<WebMediaResult> {
return await loadWebMediaInternal(mediaUrl, {
maxBytes,
optimizeImages: true,
ssrfPolicy: options?.ssrfPolicy,
});
}
export async function loadWebMediaRaw(
mediaUrl: string,
maxBytes?: number,
options?: { ssrfPolicy?: SsrFPolicy },
): Promise<WebMediaResult> {
return await loadWebMediaInternal(mediaUrl, {
maxBytes,
optimizeImages: false,
ssrfPolicy: options?.ssrfPolicy,
});
}