64 lines
1.9 KiB
TypeScript
64 lines
1.9 KiB
TypeScript
import type { SsrFPolicy } from "../infra/net/ssrf.js";
|
|
import {
|
|
resolveRemoteEmbeddingBearerClient,
|
|
type RemoteEmbeddingProviderId,
|
|
} from "./embeddings-remote-client.js";
|
|
import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js";
|
|
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
|
|
|
export type RemoteEmbeddingClient = {
|
|
baseUrl: string;
|
|
headers: Record<string, string>;
|
|
ssrfPolicy?: SsrFPolicy;
|
|
model: string;
|
|
};
|
|
|
|
export function createRemoteEmbeddingProvider(params: {
|
|
id: string;
|
|
client: RemoteEmbeddingClient;
|
|
errorPrefix: string;
|
|
maxInputTokens?: number;
|
|
}): EmbeddingProvider {
|
|
const { client } = params;
|
|
const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`;
|
|
|
|
const embed = async (input: string[]): Promise<number[][]> => {
|
|
if (input.length === 0) {
|
|
return [];
|
|
}
|
|
return await fetchRemoteEmbeddingVectors({
|
|
url,
|
|
headers: client.headers,
|
|
ssrfPolicy: client.ssrfPolicy,
|
|
body: { model: client.model, input },
|
|
errorPrefix: params.errorPrefix,
|
|
});
|
|
};
|
|
|
|
return {
|
|
id: params.id,
|
|
model: client.model,
|
|
...(typeof params.maxInputTokens === "number" ? { maxInputTokens: params.maxInputTokens } : {}),
|
|
embedQuery: async (text) => {
|
|
const [vec] = await embed([text]);
|
|
return vec ?? [];
|
|
},
|
|
embedBatch: embed,
|
|
};
|
|
}
|
|
|
|
export async function resolveRemoteEmbeddingClient(params: {
|
|
provider: RemoteEmbeddingProviderId;
|
|
options: EmbeddingProviderOptions;
|
|
defaultBaseUrl: string;
|
|
normalizeModel: (model: string) => string;
|
|
}): Promise<RemoteEmbeddingClient> {
|
|
const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({
|
|
provider: params.provider,
|
|
options: params.options,
|
|
defaultBaseUrl: params.defaultBaseUrl,
|
|
});
|
|
const model = params.normalizeModel(params.options.model);
|
|
return { baseUrl, headers, ssrfPolicy, model };
|
|
}
|