refactor(memory): reuse batch utils in gemini
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
|
||||
import { isTruthyEnvValue } from "../infra/env.js";
|
||||
import { createSubsystemLogger } from "../logging/subsystem.js";
|
||||
import { buildBatchHeaders, normalizeBatchBaseUrl, splitBatchRequests } from "./batch-utils.js";
|
||||
import { hashText, runWithConcurrency } from "./internal.js";
|
||||
|
||||
export type GeminiBatchRequest = {
|
||||
@@ -45,26 +46,6 @@ const debugLog = (message: string, meta?: Record<string, unknown>) => {
|
||||
log.raw(`${message}${suffix}`);
|
||||
};
|
||||
|
||||
function getGeminiBaseUrl(gemini: GeminiEmbeddingClient): string {
|
||||
return gemini.baseUrl?.replace(/\/$/, "") ?? "";
|
||||
}
|
||||
|
||||
function getGeminiHeaders(
|
||||
gemini: GeminiEmbeddingClient,
|
||||
params: { json: boolean },
|
||||
): Record<string, string> {
|
||||
const headers = gemini.headers ? { ...gemini.headers } : {};
|
||||
if (params.json) {
|
||||
if (!headers["Content-Type"] && !headers["content-type"]) {
|
||||
headers["Content-Type"] = "application/json";
|
||||
}
|
||||
} else {
|
||||
delete headers["Content-Type"];
|
||||
delete headers["content-type"];
|
||||
}
|
||||
return headers;
|
||||
}
|
||||
|
||||
function getGeminiUploadUrl(baseUrl: string): string {
|
||||
if (baseUrl.includes("/v1beta")) {
|
||||
return baseUrl.replace(/\/v1beta\/?$/, "/upload/v1beta");
|
||||
@@ -72,17 +53,6 @@ function getGeminiUploadUrl(baseUrl: string): string {
|
||||
return `${baseUrl.replace(/\/$/, "")}/upload`;
|
||||
}
|
||||
|
||||
function splitGeminiBatchRequests(requests: GeminiBatchRequest[]): GeminiBatchRequest[][] {
|
||||
if (requests.length <= GEMINI_BATCH_MAX_REQUESTS) {
|
||||
return [requests];
|
||||
}
|
||||
const groups: GeminiBatchRequest[][] = [];
|
||||
for (let i = 0; i < requests.length; i += GEMINI_BATCH_MAX_REQUESTS) {
|
||||
groups.push(requests.slice(i, i + GEMINI_BATCH_MAX_REQUESTS));
|
||||
}
|
||||
return groups;
|
||||
}
|
||||
|
||||
function buildGeminiUploadBody(params: { jsonl: string; displayName: string }): {
|
||||
body: Blob;
|
||||
contentType: string;
|
||||
@@ -113,7 +83,7 @@ async function submitGeminiBatch(params: {
|
||||
requests: GeminiBatchRequest[];
|
||||
agentId: string;
|
||||
}): Promise<GeminiBatchStatus> {
|
||||
const baseUrl = getGeminiBaseUrl(params.gemini);
|
||||
const baseUrl = normalizeBatchBaseUrl(params.gemini);
|
||||
const jsonl = params.requests
|
||||
.map((request) =>
|
||||
JSON.stringify({
|
||||
@@ -137,7 +107,7 @@ async function submitGeminiBatch(params: {
|
||||
const fileRes = await fetch(uploadUrl, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
...getGeminiHeaders(params.gemini, { json: false }),
|
||||
...buildBatchHeaders(params.gemini, { json: false }),
|
||||
"Content-Type": uploadPayload.contentType,
|
||||
},
|
||||
body: uploadPayload.body,
|
||||
@@ -168,7 +138,7 @@ async function submitGeminiBatch(params: {
|
||||
});
|
||||
const batchRes = await fetch(batchEndpoint, {
|
||||
method: "POST",
|
||||
headers: getGeminiHeaders(params.gemini, { json: true }),
|
||||
headers: buildBatchHeaders(params.gemini, { json: true }),
|
||||
body: JSON.stringify(batchBody),
|
||||
});
|
||||
if (batchRes.ok) {
|
||||
@@ -187,14 +157,14 @@ async function fetchGeminiBatchStatus(params: {
|
||||
gemini: GeminiEmbeddingClient;
|
||||
batchName: string;
|
||||
}): Promise<GeminiBatchStatus> {
|
||||
const baseUrl = getGeminiBaseUrl(params.gemini);
|
||||
const baseUrl = normalizeBatchBaseUrl(params.gemini);
|
||||
const name = params.batchName.startsWith("batches/")
|
||||
? params.batchName
|
||||
: `batches/${params.batchName}`;
|
||||
const statusUrl = `${baseUrl}/${name}`;
|
||||
debugLog("memory embeddings: gemini batch status", { statusUrl });
|
||||
const res = await fetch(statusUrl, {
|
||||
headers: getGeminiHeaders(params.gemini, { json: true }),
|
||||
headers: buildBatchHeaders(params.gemini, { json: true }),
|
||||
});
|
||||
if (!res.ok) {
|
||||
const text = await res.text();
|
||||
@@ -207,12 +177,12 @@ async function fetchGeminiFileContent(params: {
|
||||
gemini: GeminiEmbeddingClient;
|
||||
fileId: string;
|
||||
}): Promise<string> {
|
||||
const baseUrl = getGeminiBaseUrl(params.gemini);
|
||||
const baseUrl = normalizeBatchBaseUrl(params.gemini);
|
||||
const file = params.fileId.startsWith("files/") ? params.fileId : `files/${params.fileId}`;
|
||||
const downloadUrl = `${baseUrl}/${file}:download`;
|
||||
debugLog("memory embeddings: gemini batch download", { downloadUrl });
|
||||
const res = await fetch(downloadUrl, {
|
||||
headers: getGeminiHeaders(params.gemini, { json: true }),
|
||||
headers: buildBatchHeaders(params.gemini, { json: true }),
|
||||
});
|
||||
if (!res.ok) {
|
||||
const text = await res.text();
|
||||
@@ -290,7 +260,7 @@ export async function runGeminiEmbeddingBatches(params: {
|
||||
if (params.requests.length === 0) {
|
||||
return new Map();
|
||||
}
|
||||
const groups = splitGeminiBatchRequests(params.requests);
|
||||
const groups = splitBatchRequests(params.requests, GEMINI_BATCH_MAX_REQUESTS);
|
||||
const byCustomId = new Map<string, number[]>();
|
||||
|
||||
const tasks = groups.map((group, groupIndex) => async () => {
|
||||
|
||||
Reference in New Issue
Block a user