diff --git a/src/memory/batch-http.test.ts b/src/memory/batch-http.test.ts new file mode 100644 index 000000000..d70cdf292 --- /dev/null +++ b/src/memory/batch-http.test.ts @@ -0,0 +1,78 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { retryAsync } from "../infra/retry.js"; +import { postJsonWithRetry } from "./batch-http.js"; +import { postJson } from "./post-json.js"; + +vi.mock("../infra/retry.js", () => ({ + retryAsync: vi.fn(async (run: () => Promise) => await run()), +})); + +vi.mock("./post-json.js", () => ({ + postJson: vi.fn(), +})); + +describe("postJsonWithRetry", () => { + const retryAsyncMock = vi.mocked(retryAsync); + const postJsonMock = vi.mocked(postJson); + + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("posts JSON and returns parsed response payload", async () => { + postJsonMock.mockImplementationOnce(async (params) => { + return await params.parse({ ok: true, ids: [1, 2] }); + }); + + const result = await postJsonWithRetry<{ ok: boolean; ids: number[] }>({ + url: "https://memory.example/v1/batch", + headers: { Authorization: "Bearer test" }, + body: { chunks: ["a", "b"] }, + errorPrefix: "memory batch failed", + }); + + expect(result).toEqual({ ok: true, ids: [1, 2] }); + expect(postJsonMock).toHaveBeenCalledWith( + expect.objectContaining({ + url: "https://memory.example/v1/batch", + headers: { Authorization: "Bearer test" }, + body: { chunks: ["a", "b"] }, + errorPrefix: "memory batch failed", + attachStatus: true, + }), + ); + + const retryOptions = retryAsyncMock.mock.calls[0]?.[1] as + | { + attempts: number; + minDelayMs: number; + maxDelayMs: number; + shouldRetry: (err: unknown) => boolean; + } + | undefined; + expect(retryOptions?.attempts).toBe(3); + expect(retryOptions?.minDelayMs).toBe(300); + expect(retryOptions?.maxDelayMs).toBe(2000); + expect(retryOptions?.shouldRetry({ status: 429 })).toBe(true); + expect(retryOptions?.shouldRetry({ status: 503 })).toBe(true); + expect(retryOptions?.shouldRetry({ status: 400 })).toBe(false); + }); + + it("attaches status to non-ok errors", async () => { + postJsonMock.mockRejectedValueOnce( + Object.assign(new Error("memory batch failed: 503 backend down"), { status: 503 }), + ); + + await expect( + postJsonWithRetry({ + url: "https://memory.example/v1/batch", + headers: {}, + body: { chunks: [] }, + errorPrefix: "memory batch failed", + }), + ).rejects.toMatchObject({ + message: expect.stringContaining("memory batch failed: 503 backend down"), + status: 503, + }); + }); +}); diff --git a/src/memory/batch-http.ts b/src/memory/batch-http.ts index de7ad23f4..0610c62e5 100644 --- a/src/memory/batch-http.ts +++ b/src/memory/batch-http.ts @@ -1,6 +1,6 @@ import type { SsrFPolicy } from "../infra/net/ssrf.js"; import { retryAsync } from "../infra/retry.js"; -import { withRemoteHttpResponse } from "./remote-http.js"; +import { postJson } from "./post-json.js"; export async function postJsonWithRetry(params: { url: string; @@ -11,25 +11,14 @@ export async function postJsonWithRetry(params: { }): Promise { return await retryAsync( async () => { - return await withRemoteHttpResponse({ + return await postJson({ url: params.url, + headers: params.headers, ssrfPolicy: params.ssrfPolicy, - init: { - method: "POST", - headers: params.headers, - body: JSON.stringify(params.body), - }, - onResponse: async (res) => { - if (!res.ok) { - const text = await res.text(); - const err = new Error(`${params.errorPrefix}: ${res.status} ${text}`) as Error & { - status?: number; - }; - err.status = res.status; - throw err; - } - return (await res.json()) as T; - }, + body: params.body, + errorPrefix: params.errorPrefix, + attachStatus: true, + parse: async (payload) => payload as T, }); }, { diff --git a/src/memory/embeddings-remote-fetch.test.ts b/src/memory/embeddings-remote-fetch.test.ts new file mode 100644 index 000000000..bcef98faf --- /dev/null +++ b/src/memory/embeddings-remote-fetch.test.ts @@ -0,0 +1,53 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js"; +import { postJson } from "./post-json.js"; + +vi.mock("./post-json.js", () => ({ + postJson: vi.fn(), +})); + +describe("fetchRemoteEmbeddingVectors", () => { + const postJsonMock = vi.mocked(postJson); + + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("maps remote embedding response data to vectors", async () => { + postJsonMock.mockImplementationOnce(async (params) => { + return await params.parse({ + data: [{ embedding: [0.1, 0.2] }, {}, { embedding: [0.3] }], + }); + }); + + const vectors = await fetchRemoteEmbeddingVectors({ + url: "https://memory.example/v1/embeddings", + headers: { Authorization: "Bearer test" }, + body: { input: ["one", "two", "three"] }, + errorPrefix: "embedding fetch failed", + }); + + expect(vectors).toEqual([[0.1, 0.2], [], [0.3]]); + expect(postJsonMock).toHaveBeenCalledWith( + expect.objectContaining({ + url: "https://memory.example/v1/embeddings", + headers: { Authorization: "Bearer test" }, + body: { input: ["one", "two", "three"] }, + errorPrefix: "embedding fetch failed", + }), + ); + }); + + it("throws a status-rich error on non-ok responses", async () => { + postJsonMock.mockRejectedValueOnce(new Error("embedding fetch failed: 403 forbidden")); + + await expect( + fetchRemoteEmbeddingVectors({ + url: "https://memory.example/v1/embeddings", + headers: {}, + body: { input: ["one"] }, + errorPrefix: "embedding fetch failed", + }), + ).rejects.toThrow("embedding fetch failed: 403 forbidden"); + }); +}); diff --git a/src/memory/embeddings-remote-fetch.ts b/src/memory/embeddings-remote-fetch.ts index af8f5b33a..538806e8f 100644 --- a/src/memory/embeddings-remote-fetch.ts +++ b/src/memory/embeddings-remote-fetch.ts @@ -1,5 +1,5 @@ import type { SsrFPolicy } from "../infra/net/ssrf.js"; -import { withRemoteHttpResponse } from "./remote-http.js"; +import { postJson } from "./post-json.js"; export async function fetchRemoteEmbeddingVectors(params: { url: string; @@ -8,23 +8,17 @@ export async function fetchRemoteEmbeddingVectors(params: { body: unknown; errorPrefix: string; }): Promise { - return await withRemoteHttpResponse({ + return await postJson({ url: params.url, + headers: params.headers, ssrfPolicy: params.ssrfPolicy, - init: { - method: "POST", - headers: params.headers, - body: JSON.stringify(params.body), - }, - onResponse: async (res) => { - if (!res.ok) { - const text = await res.text(); - throw new Error(`${params.errorPrefix}: ${res.status} ${text}`); - } - const payload = (await res.json()) as { + body: params.body, + errorPrefix: params.errorPrefix, + parse: (payload) => { + const typedPayload = payload as { data?: Array<{ embedding?: number[] }>; }; - const data = payload.data ?? []; + const data = typedPayload.data ?? []; return data.map((entry) => entry.embedding ?? []); }, }); diff --git a/src/memory/manager.vector-dedupe.test.ts b/src/memory/manager.vector-dedupe.test.ts index 699f6c67e..fcd21a884 100644 --- a/src/memory/manager.vector-dedupe.test.ts +++ b/src/memory/manager.vector-dedupe.test.ts @@ -26,18 +26,27 @@ describe("memory vector dedupe", () => { let indexPath: string; let manager: MemoryIndexManager | null = null; + async function seedMemoryWorkspace(rootDir: string) { + await fs.mkdir(path.join(rootDir, "memory")); + await fs.writeFile(path.join(rootDir, "MEMORY.md"), "Hello memory."); + } + + async function closeManagerIfOpen() { + if (!manager) { + return; + } + await manager.close(); + manager = null; + } + beforeEach(async () => { workspaceDir = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-mem-")); indexPath = path.join(workspaceDir, "index.sqlite"); - await fs.mkdir(path.join(workspaceDir, "memory")); - await fs.writeFile(path.join(workspaceDir, "MEMORY.md"), "Hello memory."); + await seedMemoryWorkspace(workspaceDir); }); afterEach(async () => { - if (manager) { - await manager.close(); - manager = null; - } + await closeManagerIfOpen(); await fs.rm(workspaceDir, { recursive: true, force: true }); }); diff --git a/src/memory/post-json.test.ts b/src/memory/post-json.test.ts new file mode 100644 index 000000000..7e1aaf27c --- /dev/null +++ b/src/memory/post-json.test.ts @@ -0,0 +1,53 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { postJson } from "./post-json.js"; +import { withRemoteHttpResponse } from "./remote-http.js"; + +vi.mock("./remote-http.js", () => ({ + withRemoteHttpResponse: vi.fn(), +})); + +describe("postJson", () => { + const remoteHttpMock = vi.mocked(withRemoteHttpResponse); + + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("parses JSON payload on successful response", async () => { + remoteHttpMock.mockImplementationOnce(async (params) => { + return await params.onResponse( + new Response(JSON.stringify({ data: [{ embedding: [1, 2] }] }), { status: 200 }), + ); + }); + + const result = await postJson({ + url: "https://memory.example/v1/post", + headers: { Authorization: "Bearer test" }, + body: { input: ["x"] }, + errorPrefix: "post failed", + parse: (payload) => payload, + }); + + expect(result).toEqual({ data: [{ embedding: [1, 2] }] }); + }); + + it("attaches status to thrown error when requested", async () => { + remoteHttpMock.mockImplementationOnce(async (params) => { + return await params.onResponse(new Response("bad gateway", { status: 502 })); + }); + + await expect( + postJson({ + url: "https://memory.example/v1/post", + headers: {}, + body: {}, + errorPrefix: "post failed", + attachStatus: true, + parse: () => ({}), + }), + ).rejects.toMatchObject({ + message: expect.stringContaining("post failed: 502 bad gateway"), + status: 502, + }); + }); +}); diff --git a/src/memory/post-json.ts b/src/memory/post-json.ts new file mode 100644 index 000000000..5251fdab4 --- /dev/null +++ b/src/memory/post-json.ts @@ -0,0 +1,35 @@ +import type { SsrFPolicy } from "../infra/net/ssrf.js"; +import { withRemoteHttpResponse } from "./remote-http.js"; + +export async function postJson(params: { + url: string; + headers: Record; + ssrfPolicy?: SsrFPolicy; + body: unknown; + errorPrefix: string; + attachStatus?: boolean; + parse: (payload: unknown) => T | Promise; +}): Promise { + return await withRemoteHttpResponse({ + url: params.url, + ssrfPolicy: params.ssrfPolicy, + init: { + method: "POST", + headers: params.headers, + body: JSON.stringify(params.body), + }, + onResponse: async (res) => { + if (!res.ok) { + const text = await res.text(); + const err = new Error(`${params.errorPrefix}: ${res.status} ${text}`) as Error & { + status?: number; + }; + if (params.attachStatus) { + err.status = res.status; + } + throw err; + } + return await params.parse(await res.json()); + }, + }); +}