diff --git a/src/gateway/http-endpoint-helpers.test.ts b/src/gateway/http-endpoint-helpers.test.ts new file mode 100644 index 000000000..b359c3a56 --- /dev/null +++ b/src/gateway/http-endpoint-helpers.test.ts @@ -0,0 +1,80 @@ +import type { IncomingMessage, ServerResponse } from "node:http"; +import { describe, expect, it, vi } from "vitest"; +import type { ResolvedGatewayAuth } from "./auth.js"; +import { handleGatewayPostJsonEndpoint } from "./http-endpoint-helpers.js"; + +vi.mock("./http-auth-helpers.js", () => { + return { + authorizeGatewayBearerRequestOrReply: vi.fn(), + }; +}); + +vi.mock("./http-common.js", () => { + return { + readJsonBodyOrError: vi.fn(), + sendMethodNotAllowed: vi.fn(), + }; +}); + +const { authorizeGatewayBearerRequestOrReply } = await import("./http-auth-helpers.js"); +const { readJsonBodyOrError, sendMethodNotAllowed } = await import("./http-common.js"); + +describe("handleGatewayPostJsonEndpoint", () => { + it("returns false when path does not match", async () => { + const result = await handleGatewayPostJsonEndpoint( + { + url: "/nope", + method: "POST", + headers: { host: "localhost" }, + } as unknown as IncomingMessage, + {} as unknown as ServerResponse, + { pathname: "/v1/ok", auth: {} as unknown as ResolvedGatewayAuth, maxBodyBytes: 1 }, + ); + expect(result).toBe(false); + }); + + it("returns undefined and replies when method is not POST", async () => { + const mockedSendMethodNotAllowed = vi.mocked(sendMethodNotAllowed); + mockedSendMethodNotAllowed.mockClear(); + const result = await handleGatewayPostJsonEndpoint( + { + url: "/v1/ok", + method: "GET", + headers: { host: "localhost" }, + } as unknown as IncomingMessage, + {} as unknown as ServerResponse, + { pathname: "/v1/ok", auth: {} as unknown as ResolvedGatewayAuth, maxBodyBytes: 1 }, + ); + expect(result).toBeUndefined(); + expect(mockedSendMethodNotAllowed).toHaveBeenCalledTimes(1); + }); + + it("returns undefined when auth fails", async () => { + vi.mocked(authorizeGatewayBearerRequestOrReply).mockResolvedValue(false); + const result = await handleGatewayPostJsonEndpoint( + { + url: "/v1/ok", + method: "POST", + headers: { host: "localhost" }, + } as unknown as IncomingMessage, + {} as unknown as ServerResponse, + { pathname: "/v1/ok", auth: {} as unknown as ResolvedGatewayAuth, maxBodyBytes: 1 }, + ); + expect(result).toBeUndefined(); + }); + + it("returns body when auth succeeds and JSON parsing succeeds", async () => { + vi.mocked(authorizeGatewayBearerRequestOrReply).mockResolvedValue(true); + vi.mocked(readJsonBodyOrError).mockResolvedValue({ hello: "world" }); + const result = await handleGatewayPostJsonEndpoint( + { + url: "/v1/ok", + method: "POST", + headers: { host: "localhost" }, + } as unknown as IncomingMessage, + {} as unknown as ServerResponse, + { pathname: "/v1/ok", auth: {} as unknown as ResolvedGatewayAuth, maxBodyBytes: 123 }, + ); + expect(result).toEqual({ body: { hello: "world" } }); + }); +}); diff --git a/src/gateway/http-endpoint-helpers.ts b/src/gateway/http-endpoint-helpers.ts new file mode 100644 index 000000000..b04864114 --- /dev/null +++ b/src/gateway/http-endpoint-helpers.ts @@ -0,0 +1,45 @@ +import type { IncomingMessage, ServerResponse } from "node:http"; +import type { AuthRateLimiter } from "./auth-rate-limit.js"; +import type { ResolvedGatewayAuth } from "./auth.js"; +import { authorizeGatewayBearerRequestOrReply } from "./http-auth-helpers.js"; +import { readJsonBodyOrError, sendMethodNotAllowed } from "./http-common.js"; + +export async function handleGatewayPostJsonEndpoint( + req: IncomingMessage, + res: ServerResponse, + opts: { + pathname: string; + auth: ResolvedGatewayAuth; + maxBodyBytes: number; + trustedProxies?: string[]; + rateLimiter?: AuthRateLimiter; + }, +): Promise { + const url = new URL(req.url ?? "/", `http://${req.headers.host || "localhost"}`); + if (url.pathname !== opts.pathname) { + return false; + } + + if (req.method !== "POST") { + sendMethodNotAllowed(res); + return undefined; + } + + const authorized = await authorizeGatewayBearerRequestOrReply({ + req, + res, + auth: opts.auth, + trustedProxies: opts.trustedProxies, + rateLimiter: opts.rateLimiter, + }); + if (!authorized) { + return undefined; + } + + const body = await readJsonBodyOrError(req, res, opts.maxBodyBytes); + if (body === undefined) { + return undefined; + } + + return { body }; +} diff --git a/src/gateway/openai-http.ts b/src/gateway/openai-http.ts index 86b0e2d7d..038dc3540 100644 --- a/src/gateway/openai-http.ts +++ b/src/gateway/openai-http.ts @@ -11,14 +11,8 @@ import { buildAgentMessageFromConversationEntries, type ConversationEntry, } from "./agent-prompt.js"; -import { authorizeGatewayBearerRequestOrReply } from "./http-auth-helpers.js"; -import { - readJsonBodyOrError, - sendJson, - sendMethodNotAllowed, - setSseHeaders, - writeDone, -} from "./http-common.js"; +import { sendJson, setSseHeaders, writeDone } from "./http-common.js"; +import { handleGatewayPostJsonEndpoint } from "./http-endpoint-helpers.js"; import { resolveAgentIdForRequest, resolveSessionKey } from "./http-utils.js"; type OpenAiHttpOptions = { @@ -151,33 +145,21 @@ export async function handleOpenAiHttpRequest( res: ServerResponse, opts: OpenAiHttpOptions, ): Promise { - const url = new URL(req.url ?? "/", `http://${req.headers.host || "localhost"}`); - if (url.pathname !== "/v1/chat/completions") { - return false; - } - - if (req.method !== "POST") { - sendMethodNotAllowed(res); - return true; - } - - const authorized = await authorizeGatewayBearerRequestOrReply({ - req, - res, + const handled = await handleGatewayPostJsonEndpoint(req, res, { + pathname: "/v1/chat/completions", auth: opts.auth, trustedProxies: opts.trustedProxies, rateLimiter: opts.rateLimiter, + maxBodyBytes: opts.maxBodyBytes ?? 1024 * 1024, }); - if (!authorized) { + if (handled === false) { + return false; + } + if (!handled) { return true; } - const body = await readJsonBodyOrError(req, res, opts.maxBodyBytes ?? 1024 * 1024); - if (body === undefined) { - return true; - } - - const payload = coerceRequest(body); + const payload = coerceRequest(handled.body); const stream = Boolean(payload.stream); const model = typeof payload.model === "string" ? payload.model : "openclaw"; const user = typeof payload.user === "string" ? payload.user : undefined; diff --git a/src/gateway/openresponses-http.ts b/src/gateway/openresponses-http.ts index df32f643c..896ae9b18 100644 --- a/src/gateway/openresponses-http.ts +++ b/src/gateway/openresponses-http.ts @@ -40,14 +40,8 @@ import { buildAgentMessageFromConversationEntries, type ConversationEntry, } from "./agent-prompt.js"; -import { authorizeGatewayBearerRequestOrReply } from "./http-auth-helpers.js"; -import { - readJsonBodyOrError, - sendJson, - sendMethodNotAllowed, - setSseHeaders, - writeDone, -} from "./http-common.js"; +import { sendJson, setSseHeaders, writeDone } from "./http-common.js"; +import { handleGatewayPostJsonEndpoint } from "./http-endpoint-helpers.js"; import { resolveAgentIdForRequest, resolveSessionKey } from "./http-utils.js"; import { CreateResponseBodySchema, @@ -319,45 +313,61 @@ function createAssistantOutputItem(params: { }; } +async function runResponsesAgentCommand(params: { + message: string; + images: ImageContent[]; + clientTools: ClientToolDefinition[]; + extraSystemPrompt: string; + streamParams: { maxTokens: number } | undefined; + sessionKey: string; + runId: string; + deps: ReturnType; +}) { + return agentCommand( + { + message: params.message, + images: params.images.length > 0 ? params.images : undefined, + clientTools: params.clientTools.length > 0 ? params.clientTools : undefined, + extraSystemPrompt: params.extraSystemPrompt || undefined, + streamParams: params.streamParams ?? undefined, + sessionKey: params.sessionKey, + runId: params.runId, + deliver: false, + messageChannel: "webchat", + bestEffortDeliver: false, + }, + defaultRuntime, + params.deps, + ); +} + export async function handleOpenResponsesHttpRequest( req: IncomingMessage, res: ServerResponse, opts: OpenResponsesHttpOptions, ): Promise { - const url = new URL(req.url ?? "/", `http://${req.headers.host || "localhost"}`); - if (url.pathname !== "/v1/responses") { - return false; - } - - if (req.method !== "POST") { - sendMethodNotAllowed(res); - return true; - } - - const authorized = await authorizeGatewayBearerRequestOrReply({ - req, - res, - auth: opts.auth, - trustedProxies: opts.trustedProxies, - rateLimiter: opts.rateLimiter, - }); - if (!authorized) { - return true; - } - const limits = resolveResponsesLimits(opts.config); const maxBodyBytes = opts.maxBodyBytes ?? (opts.config?.maxBodyBytes ? limits.maxBodyBytes : Math.max(limits.maxBodyBytes, limits.files.maxBytes * 2, limits.images.maxBytes * 2)); - const body = await readJsonBodyOrError(req, res, maxBodyBytes); - if (body === undefined) { + const handled = await handleGatewayPostJsonEndpoint(req, res, { + pathname: "/v1/responses", + auth: opts.auth, + trustedProxies: opts.trustedProxies, + rateLimiter: opts.rateLimiter, + maxBodyBytes, + }); + if (handled === false) { + return false; + } + if (!handled) { return true; } // Validate request body with Zod - const parseResult = CreateResponseBodySchema.safeParse(body); + const parseResult = CreateResponseBodySchema.safeParse(handled.body); if (!parseResult.success) { const issue = parseResult.error.issues[0]; const message = issue ? `${issue.path.join(".")}: ${issue.message}` : "Invalid request body"; @@ -520,22 +530,16 @@ export async function handleOpenResponsesHttpRequest( if (!stream) { try { - const result = await agentCommand( - { - message: prompt.message, - images: images.length > 0 ? images : undefined, - clientTools: resolvedClientTools.length > 0 ? resolvedClientTools : undefined, - extraSystemPrompt: extraSystemPrompt || undefined, - streamParams: streamParams ?? undefined, - sessionKey, - runId: responseId, - deliver: false, - messageChannel: "webchat", - bestEffortDeliver: false, - }, - defaultRuntime, + const result = await runResponsesAgentCommand({ + message: prompt.message, + images, + clientTools: resolvedClientTools, + extraSystemPrompt, + streamParams, + sessionKey, + runId: responseId, deps, - ); + }); const payloads = (result as { payloads?: Array<{ text?: string }> } | null)?.payloads; const usage = extractUsageFromResult(result); @@ -760,22 +764,16 @@ export async function handleOpenResponsesHttpRequest( void (async () => { try { - const result = await agentCommand( - { - message: prompt.message, - images: images.length > 0 ? images : undefined, - clientTools: resolvedClientTools.length > 0 ? resolvedClientTools : undefined, - extraSystemPrompt: extraSystemPrompt || undefined, - streamParams: streamParams ?? undefined, - sessionKey, - runId: responseId, - deliver: false, - messageChannel: "webchat", - bestEffortDeliver: false, - }, - defaultRuntime, + const result = await runResponsesAgentCommand({ + message: prompt.message, + images, + clientTools: resolvedClientTools, + extraSystemPrompt, + streamParams, + sessionKey, + runId: responseId, deps, - ); + }); finalUsage = extractUsageFromResult(result); maybeFinalize();