fix(memory): enforce guarded remote policy for embeddings
This commit is contained in:
@@ -4,12 +4,15 @@ import {
|
||||
} from "../agents/api-key-rotation.js";
|
||||
import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js";
|
||||
import { parseGeminiAuth } from "../infra/gemini-auth.js";
|
||||
import type { SsrFPolicy } from "../infra/net/ssrf.js";
|
||||
import { debugEmbeddingsLog } from "./embeddings-debug.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||
import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js";
|
||||
|
||||
export type GeminiEmbeddingClient = {
|
||||
baseUrl: string;
|
||||
headers: Record<string, string>;
|
||||
ssrfPolicy?: SsrFPolicy;
|
||||
model: string;
|
||||
modelPath: string;
|
||||
apiKeys: string[];
|
||||
@@ -73,19 +76,26 @@ export async function createGeminiEmbeddingProvider(
|
||||
...authHeaders.headers,
|
||||
...client.headers,
|
||||
};
|
||||
const res = await fetch(endpoint, {
|
||||
method: "POST",
|
||||
headers,
|
||||
body: JSON.stringify(body),
|
||||
const payload = await withRemoteHttpResponse({
|
||||
url: endpoint,
|
||||
ssrfPolicy: client.ssrfPolicy,
|
||||
init: {
|
||||
method: "POST",
|
||||
headers,
|
||||
body: JSON.stringify(body),
|
||||
},
|
||||
onResponse: async (res) => {
|
||||
if (!res.ok) {
|
||||
const text = await res.text();
|
||||
throw new Error(`gemini embeddings failed: ${res.status} ${text}`);
|
||||
}
|
||||
return (await res.json()) as {
|
||||
embedding?: { values?: number[] };
|
||||
embeddings?: Array<{ values?: number[] }>;
|
||||
};
|
||||
},
|
||||
});
|
||||
if (!res.ok) {
|
||||
const payload = await res.text();
|
||||
throw new Error(`gemini embeddings failed: ${res.status} ${payload}`);
|
||||
}
|
||||
return (await res.json()) as {
|
||||
embedding?: { values?: number[] };
|
||||
embeddings?: Array<{ values?: number[] }>;
|
||||
};
|
||||
return payload;
|
||||
};
|
||||
|
||||
const embedQuery = async (text: string): Promise<number[]> => {
|
||||
@@ -158,6 +168,7 @@ export async function resolveGeminiEmbeddingClient(
|
||||
const providerConfig = options.config.models?.providers?.google;
|
||||
const rawBaseUrl = remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_GEMINI_BASE_URL;
|
||||
const baseUrl = normalizeGeminiBaseUrl(rawBaseUrl);
|
||||
const ssrfPolicy = buildRemoteBaseUrlPolicy(baseUrl);
|
||||
const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers);
|
||||
const headers: Record<string, string> = {
|
||||
...headerOverrides,
|
||||
@@ -176,5 +187,5 @@ export async function resolveGeminiEmbeddingClient(
|
||||
embedEndpoint: `${baseUrl}/${modelPath}:embedContent`,
|
||||
batchEndpoint: `${baseUrl}/${modelPath}:batchEmbedContents`,
|
||||
});
|
||||
return { baseUrl, headers, model, modelPath, apiKeys };
|
||||
return { baseUrl, headers, ssrfPolicy, model, modelPath, apiKeys };
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import type { SsrFPolicy } from "../infra/net/ssrf.js";
|
||||
import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js";
|
||||
import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||
@@ -5,6 +6,7 @@ import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.j
|
||||
export type OpenAiEmbeddingClient = {
|
||||
baseUrl: string;
|
||||
headers: Record<string, string>;
|
||||
ssrfPolicy?: SsrFPolicy;
|
||||
model: string;
|
||||
};
|
||||
|
||||
@@ -40,6 +42,7 @@ export async function createOpenAiEmbeddingProvider(
|
||||
return await fetchRemoteEmbeddingVectors({
|
||||
url,
|
||||
headers: client.headers,
|
||||
ssrfPolicy: client.ssrfPolicy,
|
||||
body: { model: client.model, input },
|
||||
errorPrefix: "openai embeddings failed",
|
||||
});
|
||||
@@ -63,11 +66,11 @@ export async function createOpenAiEmbeddingProvider(
|
||||
export async function resolveOpenAiEmbeddingClient(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<OpenAiEmbeddingClient> {
|
||||
const { baseUrl, headers } = await resolveRemoteEmbeddingBearerClient({
|
||||
const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({
|
||||
provider: "openai",
|
||||
options,
|
||||
defaultBaseUrl: DEFAULT_OPENAI_BASE_URL,
|
||||
});
|
||||
const model = normalizeOpenAiModel(options.model);
|
||||
return { baseUrl, headers, model };
|
||||
return { baseUrl, headers, ssrfPolicy, model };
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js";
|
||||
import type { SsrFPolicy } from "../infra/net/ssrf.js";
|
||||
import type { EmbeddingProviderOptions } from "./embeddings.js";
|
||||
import { buildRemoteBaseUrlPolicy } from "./remote-http.js";
|
||||
|
||||
type RemoteEmbeddingProviderId = "openai" | "voyage";
|
||||
|
||||
@@ -7,7 +9,7 @@ export async function resolveRemoteEmbeddingBearerClient(params: {
|
||||
provider: RemoteEmbeddingProviderId;
|
||||
options: EmbeddingProviderOptions;
|
||||
defaultBaseUrl: string;
|
||||
}): Promise<{ baseUrl: string; headers: Record<string, string> }> {
|
||||
}): Promise<{ baseUrl: string; headers: Record<string, string>; ssrfPolicy?: SsrFPolicy }> {
|
||||
const remote = params.options.remote;
|
||||
const remoteApiKey = remote?.apiKey?.trim();
|
||||
const remoteBaseUrl = remote?.baseUrl?.trim();
|
||||
@@ -29,5 +31,5 @@ export async function resolveRemoteEmbeddingBearerClient(params: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
...headerOverrides,
|
||||
};
|
||||
return { baseUrl, headers };
|
||||
return { baseUrl, headers, ssrfPolicy: buildRemoteBaseUrlPolicy(baseUrl) };
|
||||
}
|
||||
|
||||
@@ -84,7 +84,7 @@ describe("voyage embedding provider", () => {
|
||||
model: "voyage-4-lite",
|
||||
fallback: "none",
|
||||
remote: {
|
||||
baseUrl: "https://proxy.example.com",
|
||||
baseUrl: "https://example.com",
|
||||
apiKey: "remote-override-key",
|
||||
headers: { "X-Custom": "123" },
|
||||
},
|
||||
@@ -95,7 +95,7 @@ describe("voyage embedding provider", () => {
|
||||
const call = fetchMock.mock.calls[0];
|
||||
expect(call).toBeDefined();
|
||||
const [url, init] = call as [RequestInfo | URL, RequestInit | undefined];
|
||||
expect(url).toBe("https://proxy.example.com/embeddings");
|
||||
expect(url).toBe("https://example.com/embeddings");
|
||||
|
||||
const headers = (init?.headers ?? {}) as Record<string, string>;
|
||||
expect(headers.Authorization).toBe("Bearer remote-override-key");
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import type { SsrFPolicy } from "../infra/net/ssrf.js";
|
||||
import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js";
|
||||
import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||
@@ -5,6 +6,7 @@ import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.j
|
||||
export type VoyageEmbeddingClient = {
|
||||
baseUrl: string;
|
||||
headers: Record<string, string>;
|
||||
ssrfPolicy?: SsrFPolicy;
|
||||
model: string;
|
||||
};
|
||||
|
||||
@@ -48,6 +50,7 @@ export async function createVoyageEmbeddingProvider(
|
||||
return await fetchRemoteEmbeddingVectors({
|
||||
url,
|
||||
headers: client.headers,
|
||||
ssrfPolicy: client.ssrfPolicy,
|
||||
body,
|
||||
errorPrefix: "voyage embeddings failed",
|
||||
});
|
||||
@@ -71,11 +74,11 @@ export async function createVoyageEmbeddingProvider(
|
||||
export async function resolveVoyageEmbeddingClient(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<VoyageEmbeddingClient> {
|
||||
const { baseUrl, headers } = await resolveRemoteEmbeddingBearerClient({
|
||||
const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({
|
||||
provider: "voyage",
|
||||
options,
|
||||
defaultBaseUrl: DEFAULT_VOYAGE_BASE_URL,
|
||||
});
|
||||
const model = normalizeVoyageModel(options.model);
|
||||
return { baseUrl, headers, model };
|
||||
return { baseUrl, headers, ssrfPolicy, model };
|
||||
}
|
||||
|
||||
@@ -93,7 +93,7 @@ describe("embedding provider remote overrides", () => {
|
||||
models: {
|
||||
providers: {
|
||||
openai: {
|
||||
baseUrl: "https://provider.example/v1",
|
||||
baseUrl: "https://api.openai.com/v1",
|
||||
headers: {
|
||||
"X-Provider": "p",
|
||||
"X-Shared": "provider",
|
||||
@@ -107,7 +107,7 @@ describe("embedding provider remote overrides", () => {
|
||||
config: cfg as never,
|
||||
provider: "openai",
|
||||
remote: {
|
||||
baseUrl: "https://remote.example/v1",
|
||||
baseUrl: "https://example.com/v1",
|
||||
apiKey: " remote-key ",
|
||||
headers: {
|
||||
"X-Shared": "remote",
|
||||
@@ -124,7 +124,7 @@ describe("embedding provider remote overrides", () => {
|
||||
expect(authModule.resolveApiKeyForProvider).not.toHaveBeenCalled();
|
||||
const url = fetchMock.mock.calls[0]?.[0];
|
||||
const init = fetchMock.mock.calls[0]?.[1] as RequestInit | undefined;
|
||||
expect(url).toBe("https://remote.example/v1/embeddings");
|
||||
expect(url).toBe("https://example.com/v1/embeddings");
|
||||
const headers = (init?.headers ?? {}) as Record<string, string>;
|
||||
expect(headers.Authorization).toBe("Bearer remote-key");
|
||||
expect(headers["Content-Type"]).toBe("application/json");
|
||||
@@ -142,7 +142,7 @@ describe("embedding provider remote overrides", () => {
|
||||
models: {
|
||||
providers: {
|
||||
openai: {
|
||||
baseUrl: "https://provider.example/v1",
|
||||
baseUrl: "https://api.openai.com/v1",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -152,7 +152,7 @@ describe("embedding provider remote overrides", () => {
|
||||
config: cfg as never,
|
||||
provider: "openai",
|
||||
remote: {
|
||||
baseUrl: "https://remote.example/v1",
|
||||
baseUrl: "https://example.com/v1",
|
||||
apiKey: " ",
|
||||
},
|
||||
model: "text-embedding-3-small",
|
||||
|
||||
Reference in New Issue
Block a user