fix(memory): enforce guarded remote policy for embeddings

This commit is contained in:
Peter Steinberger
2026-02-22 18:13:44 +01:00
parent f6feb4144c
commit f87db7c627
6 changed files with 45 additions and 26 deletions

View File

@@ -4,12 +4,15 @@ import {
} from "../agents/api-key-rotation.js";
import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js";
import { parseGeminiAuth } from "../infra/gemini-auth.js";
import type { SsrFPolicy } from "../infra/net/ssrf.js";
import { debugEmbeddingsLog } from "./embeddings-debug.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js";
export type GeminiEmbeddingClient = {
baseUrl: string;
headers: Record<string, string>;
ssrfPolicy?: SsrFPolicy;
model: string;
modelPath: string;
apiKeys: string[];
@@ -73,19 +76,26 @@ export async function createGeminiEmbeddingProvider(
...authHeaders.headers,
...client.headers,
};
const res = await fetch(endpoint, {
method: "POST",
headers,
body: JSON.stringify(body),
const payload = await withRemoteHttpResponse({
url: endpoint,
ssrfPolicy: client.ssrfPolicy,
init: {
method: "POST",
headers,
body: JSON.stringify(body),
},
onResponse: async (res) => {
if (!res.ok) {
const text = await res.text();
throw new Error(`gemini embeddings failed: ${res.status} ${text}`);
}
return (await res.json()) as {
embedding?: { values?: number[] };
embeddings?: Array<{ values?: number[] }>;
};
},
});
if (!res.ok) {
const payload = await res.text();
throw new Error(`gemini embeddings failed: ${res.status} ${payload}`);
}
return (await res.json()) as {
embedding?: { values?: number[] };
embeddings?: Array<{ values?: number[] }>;
};
return payload;
};
const embedQuery = async (text: string): Promise<number[]> => {
@@ -158,6 +168,7 @@ export async function resolveGeminiEmbeddingClient(
const providerConfig = options.config.models?.providers?.google;
const rawBaseUrl = remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_GEMINI_BASE_URL;
const baseUrl = normalizeGeminiBaseUrl(rawBaseUrl);
const ssrfPolicy = buildRemoteBaseUrlPolicy(baseUrl);
const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers);
const headers: Record<string, string> = {
...headerOverrides,
@@ -176,5 +187,5 @@ export async function resolveGeminiEmbeddingClient(
embedEndpoint: `${baseUrl}/${modelPath}:embedContent`,
batchEndpoint: `${baseUrl}/${modelPath}:batchEmbedContents`,
});
return { baseUrl, headers, model, modelPath, apiKeys };
return { baseUrl, headers, ssrfPolicy, model, modelPath, apiKeys };
}

View File

@@ -1,3 +1,4 @@
import type { SsrFPolicy } from "../infra/net/ssrf.js";
import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js";
import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
@@ -5,6 +6,7 @@ import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.j
export type OpenAiEmbeddingClient = {
baseUrl: string;
headers: Record<string, string>;
ssrfPolicy?: SsrFPolicy;
model: string;
};
@@ -40,6 +42,7 @@ export async function createOpenAiEmbeddingProvider(
return await fetchRemoteEmbeddingVectors({
url,
headers: client.headers,
ssrfPolicy: client.ssrfPolicy,
body: { model: client.model, input },
errorPrefix: "openai embeddings failed",
});
@@ -63,11 +66,11 @@ export async function createOpenAiEmbeddingProvider(
export async function resolveOpenAiEmbeddingClient(
options: EmbeddingProviderOptions,
): Promise<OpenAiEmbeddingClient> {
const { baseUrl, headers } = await resolveRemoteEmbeddingBearerClient({
const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({
provider: "openai",
options,
defaultBaseUrl: DEFAULT_OPENAI_BASE_URL,
});
const model = normalizeOpenAiModel(options.model);
return { baseUrl, headers, model };
return { baseUrl, headers, ssrfPolicy, model };
}

View File

@@ -1,5 +1,7 @@
import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js";
import type { SsrFPolicy } from "../infra/net/ssrf.js";
import type { EmbeddingProviderOptions } from "./embeddings.js";
import { buildRemoteBaseUrlPolicy } from "./remote-http.js";
type RemoteEmbeddingProviderId = "openai" | "voyage";
@@ -7,7 +9,7 @@ export async function resolveRemoteEmbeddingBearerClient(params: {
provider: RemoteEmbeddingProviderId;
options: EmbeddingProviderOptions;
defaultBaseUrl: string;
}): Promise<{ baseUrl: string; headers: Record<string, string> }> {
}): Promise<{ baseUrl: string; headers: Record<string, string>; ssrfPolicy?: SsrFPolicy }> {
const remote = params.options.remote;
const remoteApiKey = remote?.apiKey?.trim();
const remoteBaseUrl = remote?.baseUrl?.trim();
@@ -29,5 +31,5 @@ export async function resolveRemoteEmbeddingBearerClient(params: {
Authorization: `Bearer ${apiKey}`,
...headerOverrides,
};
return { baseUrl, headers };
return { baseUrl, headers, ssrfPolicy: buildRemoteBaseUrlPolicy(baseUrl) };
}

View File

@@ -84,7 +84,7 @@ describe("voyage embedding provider", () => {
model: "voyage-4-lite",
fallback: "none",
remote: {
baseUrl: "https://proxy.example.com",
baseUrl: "https://example.com",
apiKey: "remote-override-key",
headers: { "X-Custom": "123" },
},
@@ -95,7 +95,7 @@ describe("voyage embedding provider", () => {
const call = fetchMock.mock.calls[0];
expect(call).toBeDefined();
const [url, init] = call as [RequestInfo | URL, RequestInit | undefined];
expect(url).toBe("https://proxy.example.com/embeddings");
expect(url).toBe("https://example.com/embeddings");
const headers = (init?.headers ?? {}) as Record<string, string>;
expect(headers.Authorization).toBe("Bearer remote-override-key");

View File

@@ -1,3 +1,4 @@
import type { SsrFPolicy } from "../infra/net/ssrf.js";
import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js";
import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
@@ -5,6 +6,7 @@ import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.j
export type VoyageEmbeddingClient = {
baseUrl: string;
headers: Record<string, string>;
ssrfPolicy?: SsrFPolicy;
model: string;
};
@@ -48,6 +50,7 @@ export async function createVoyageEmbeddingProvider(
return await fetchRemoteEmbeddingVectors({
url,
headers: client.headers,
ssrfPolicy: client.ssrfPolicy,
body,
errorPrefix: "voyage embeddings failed",
});
@@ -71,11 +74,11 @@ export async function createVoyageEmbeddingProvider(
export async function resolveVoyageEmbeddingClient(
options: EmbeddingProviderOptions,
): Promise<VoyageEmbeddingClient> {
const { baseUrl, headers } = await resolveRemoteEmbeddingBearerClient({
const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({
provider: "voyage",
options,
defaultBaseUrl: DEFAULT_VOYAGE_BASE_URL,
});
const model = normalizeVoyageModel(options.model);
return { baseUrl, headers, model };
return { baseUrl, headers, ssrfPolicy, model };
}

View File

@@ -93,7 +93,7 @@ describe("embedding provider remote overrides", () => {
models: {
providers: {
openai: {
baseUrl: "https://provider.example/v1",
baseUrl: "https://api.openai.com/v1",
headers: {
"X-Provider": "p",
"X-Shared": "provider",
@@ -107,7 +107,7 @@ describe("embedding provider remote overrides", () => {
config: cfg as never,
provider: "openai",
remote: {
baseUrl: "https://remote.example/v1",
baseUrl: "https://example.com/v1",
apiKey: " remote-key ",
headers: {
"X-Shared": "remote",
@@ -124,7 +124,7 @@ describe("embedding provider remote overrides", () => {
expect(authModule.resolveApiKeyForProvider).not.toHaveBeenCalled();
const url = fetchMock.mock.calls[0]?.[0];
const init = fetchMock.mock.calls[0]?.[1] as RequestInit | undefined;
expect(url).toBe("https://remote.example/v1/embeddings");
expect(url).toBe("https://example.com/v1/embeddings");
const headers = (init?.headers ?? {}) as Record<string, string>;
expect(headers.Authorization).toBe("Bearer remote-key");
expect(headers["Content-Type"]).toBe("application/json");
@@ -142,7 +142,7 @@ describe("embedding provider remote overrides", () => {
models: {
providers: {
openai: {
baseUrl: "https://provider.example/v1",
baseUrl: "https://api.openai.com/v1",
},
},
},
@@ -152,7 +152,7 @@ describe("embedding provider remote overrides", () => {
config: cfg as never,
provider: "openai",
remote: {
baseUrl: "https://remote.example/v1",
baseUrl: "https://example.com/v1",
apiKey: " ",
},
model: "text-embedding-3-small",