diff --git a/src/memory/embeddings-gemini.ts b/src/memory/embeddings-gemini.ts index 414ad9075..01e7dbb23 100644 --- a/src/memory/embeddings-gemini.ts +++ b/src/memory/embeddings-gemini.ts @@ -4,12 +4,15 @@ import { } from "../agents/api-key-rotation.js"; import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js"; import { parseGeminiAuth } from "../infra/gemini-auth.js"; +import type { SsrFPolicy } from "../infra/net/ssrf.js"; import { debugEmbeddingsLog } from "./embeddings-debug.js"; import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js"; +import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js"; export type GeminiEmbeddingClient = { baseUrl: string; headers: Record; + ssrfPolicy?: SsrFPolicy; model: string; modelPath: string; apiKeys: string[]; @@ -73,19 +76,26 @@ export async function createGeminiEmbeddingProvider( ...authHeaders.headers, ...client.headers, }; - const res = await fetch(endpoint, { - method: "POST", - headers, - body: JSON.stringify(body), + const payload = await withRemoteHttpResponse({ + url: endpoint, + ssrfPolicy: client.ssrfPolicy, + init: { + method: "POST", + headers, + body: JSON.stringify(body), + }, + onResponse: async (res) => { + if (!res.ok) { + const text = await res.text(); + throw new Error(`gemini embeddings failed: ${res.status} ${text}`); + } + return (await res.json()) as { + embedding?: { values?: number[] }; + embeddings?: Array<{ values?: number[] }>; + }; + }, }); - if (!res.ok) { - const payload = await res.text(); - throw new Error(`gemini embeddings failed: ${res.status} ${payload}`); - } - return (await res.json()) as { - embedding?: { values?: number[] }; - embeddings?: Array<{ values?: number[] }>; - }; + return payload; }; const embedQuery = async (text: string): Promise => { @@ -158,6 +168,7 @@ export async function resolveGeminiEmbeddingClient( const providerConfig = options.config.models?.providers?.google; const rawBaseUrl = remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_GEMINI_BASE_URL; const baseUrl = normalizeGeminiBaseUrl(rawBaseUrl); + const ssrfPolicy = buildRemoteBaseUrlPolicy(baseUrl); const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers); const headers: Record = { ...headerOverrides, @@ -176,5 +187,5 @@ export async function resolveGeminiEmbeddingClient( embedEndpoint: `${baseUrl}/${modelPath}:embedContent`, batchEndpoint: `${baseUrl}/${modelPath}:batchEmbedContents`, }); - return { baseUrl, headers, model, modelPath, apiKeys }; + return { baseUrl, headers, ssrfPolicy, model, modelPath, apiKeys }; } diff --git a/src/memory/embeddings-openai.ts b/src/memory/embeddings-openai.ts index b319fbcd2..02b92e68f 100644 --- a/src/memory/embeddings-openai.ts +++ b/src/memory/embeddings-openai.ts @@ -1,3 +1,4 @@ +import type { SsrFPolicy } from "../infra/net/ssrf.js"; import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js"; import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js"; import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js"; @@ -5,6 +6,7 @@ import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.j export type OpenAiEmbeddingClient = { baseUrl: string; headers: Record; + ssrfPolicy?: SsrFPolicy; model: string; }; @@ -40,6 +42,7 @@ export async function createOpenAiEmbeddingProvider( return await fetchRemoteEmbeddingVectors({ url, headers: client.headers, + ssrfPolicy: client.ssrfPolicy, body: { model: client.model, input }, errorPrefix: "openai embeddings failed", }); @@ -63,11 +66,11 @@ export async function createOpenAiEmbeddingProvider( export async function resolveOpenAiEmbeddingClient( options: EmbeddingProviderOptions, ): Promise { - const { baseUrl, headers } = await resolveRemoteEmbeddingBearerClient({ + const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({ provider: "openai", options, defaultBaseUrl: DEFAULT_OPENAI_BASE_URL, }); const model = normalizeOpenAiModel(options.model); - return { baseUrl, headers, model }; + return { baseUrl, headers, ssrfPolicy, model }; } diff --git a/src/memory/embeddings-remote-client.ts b/src/memory/embeddings-remote-client.ts index dc99717e7..c3ec1106b 100644 --- a/src/memory/embeddings-remote-client.ts +++ b/src/memory/embeddings-remote-client.ts @@ -1,5 +1,7 @@ import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js"; +import type { SsrFPolicy } from "../infra/net/ssrf.js"; import type { EmbeddingProviderOptions } from "./embeddings.js"; +import { buildRemoteBaseUrlPolicy } from "./remote-http.js"; type RemoteEmbeddingProviderId = "openai" | "voyage"; @@ -7,7 +9,7 @@ export async function resolveRemoteEmbeddingBearerClient(params: { provider: RemoteEmbeddingProviderId; options: EmbeddingProviderOptions; defaultBaseUrl: string; -}): Promise<{ baseUrl: string; headers: Record }> { +}): Promise<{ baseUrl: string; headers: Record; ssrfPolicy?: SsrFPolicy }> { const remote = params.options.remote; const remoteApiKey = remote?.apiKey?.trim(); const remoteBaseUrl = remote?.baseUrl?.trim(); @@ -29,5 +31,5 @@ export async function resolveRemoteEmbeddingBearerClient(params: { Authorization: `Bearer ${apiKey}`, ...headerOverrides, }; - return { baseUrl, headers }; + return { baseUrl, headers, ssrfPolicy: buildRemoteBaseUrlPolicy(baseUrl) }; } diff --git a/src/memory/embeddings-voyage.test.ts b/src/memory/embeddings-voyage.test.ts index 08e59474a..4851d3743 100644 --- a/src/memory/embeddings-voyage.test.ts +++ b/src/memory/embeddings-voyage.test.ts @@ -84,7 +84,7 @@ describe("voyage embedding provider", () => { model: "voyage-4-lite", fallback: "none", remote: { - baseUrl: "https://proxy.example.com", + baseUrl: "https://example.com", apiKey: "remote-override-key", headers: { "X-Custom": "123" }, }, @@ -95,7 +95,7 @@ describe("voyage embedding provider", () => { const call = fetchMock.mock.calls[0]; expect(call).toBeDefined(); const [url, init] = call as [RequestInfo | URL, RequestInit | undefined]; - expect(url).toBe("https://proxy.example.com/embeddings"); + expect(url).toBe("https://example.com/embeddings"); const headers = (init?.headers ?? {}) as Record; expect(headers.Authorization).toBe("Bearer remote-override-key"); diff --git a/src/memory/embeddings-voyage.ts b/src/memory/embeddings-voyage.ts index faf82c5f1..faf9fe1c8 100644 --- a/src/memory/embeddings-voyage.ts +++ b/src/memory/embeddings-voyage.ts @@ -1,3 +1,4 @@ +import type { SsrFPolicy } from "../infra/net/ssrf.js"; import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js"; import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js"; import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js"; @@ -5,6 +6,7 @@ import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.j export type VoyageEmbeddingClient = { baseUrl: string; headers: Record; + ssrfPolicy?: SsrFPolicy; model: string; }; @@ -48,6 +50,7 @@ export async function createVoyageEmbeddingProvider( return await fetchRemoteEmbeddingVectors({ url, headers: client.headers, + ssrfPolicy: client.ssrfPolicy, body, errorPrefix: "voyage embeddings failed", }); @@ -71,11 +74,11 @@ export async function createVoyageEmbeddingProvider( export async function resolveVoyageEmbeddingClient( options: EmbeddingProviderOptions, ): Promise { - const { baseUrl, headers } = await resolveRemoteEmbeddingBearerClient({ + const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({ provider: "voyage", options, defaultBaseUrl: DEFAULT_VOYAGE_BASE_URL, }); const model = normalizeVoyageModel(options.model); - return { baseUrl, headers, model }; + return { baseUrl, headers, ssrfPolicy, model }; } diff --git a/src/memory/embeddings.test.ts b/src/memory/embeddings.test.ts index 5c60415d4..8a327da3a 100644 --- a/src/memory/embeddings.test.ts +++ b/src/memory/embeddings.test.ts @@ -93,7 +93,7 @@ describe("embedding provider remote overrides", () => { models: { providers: { openai: { - baseUrl: "https://provider.example/v1", + baseUrl: "https://api.openai.com/v1", headers: { "X-Provider": "p", "X-Shared": "provider", @@ -107,7 +107,7 @@ describe("embedding provider remote overrides", () => { config: cfg as never, provider: "openai", remote: { - baseUrl: "https://remote.example/v1", + baseUrl: "https://example.com/v1", apiKey: " remote-key ", headers: { "X-Shared": "remote", @@ -124,7 +124,7 @@ describe("embedding provider remote overrides", () => { expect(authModule.resolveApiKeyForProvider).not.toHaveBeenCalled(); const url = fetchMock.mock.calls[0]?.[0]; const init = fetchMock.mock.calls[0]?.[1] as RequestInit | undefined; - expect(url).toBe("https://remote.example/v1/embeddings"); + expect(url).toBe("https://example.com/v1/embeddings"); const headers = (init?.headers ?? {}) as Record; expect(headers.Authorization).toBe("Bearer remote-key"); expect(headers["Content-Type"]).toBe("application/json"); @@ -142,7 +142,7 @@ describe("embedding provider remote overrides", () => { models: { providers: { openai: { - baseUrl: "https://provider.example/v1", + baseUrl: "https://api.openai.com/v1", }, }, }, @@ -152,7 +152,7 @@ describe("embedding provider remote overrides", () => { config: cfg as never, provider: "openai", remote: { - baseUrl: "https://remote.example/v1", + baseUrl: "https://example.com/v1", apiKey: " ", }, model: "text-embedding-3-small",