diff --git a/src/gateway/server-methods/wizard.ts b/src/gateway/server-methods/wizard.ts index 8585a066c..1fab55822 100644 --- a/src/gateway/server-methods/wizard.ts +++ b/src/gateway/server-methods/wizard.ts @@ -1,5 +1,6 @@ +import type { ErrorObject } from "ajv"; import { randomUUID } from "node:crypto"; -import type { GatewayRequestHandlers } from "./types.js"; +import type { GatewayRequestHandlers, RespondFn } from "./types.js"; import { defaultRuntime } from "../../runtime.js"; import { WizardSession } from "../../wizard/session.js"; import { @@ -13,17 +14,40 @@ import { } from "../protocol/index.js"; import { formatForLog } from "../ws-log.js"; +type Validator = ((params: unknown) => params is T) & { + errors?: ErrorObject[] | null; +}; + +function assertValidParams( + params: unknown, + validate: Validator, + method: string, + respond: RespondFn, +): params is T { + if (validate(params)) { + return true; + } + respond( + false, + undefined, + errorShape( + ErrorCodes.INVALID_REQUEST, + `invalid ${method} params: ${formatValidationErrors(validate.errors)}`, + ), + ); + return false; +} + +function readWizardStatus(session: WizardSession) { + return { + status: session.getStatus(), + error: session.getError(), + }; +} + export const wizardHandlers: GatewayRequestHandlers = { "wizard.start": async ({ params, respond, context }) => { - if (!validateWizardStartParams(params)) { - respond( - false, - undefined, - errorShape( - ErrorCodes.INVALID_REQUEST, - `invalid wizard.start params: ${formatValidationErrors(validateWizardStartParams.errors)}`, - ), - ); + if (!assertValidParams(params, validateWizardStartParams, "wizard.start", respond)) { return; } const running = context.findRunningWizard(); @@ -47,15 +71,7 @@ export const wizardHandlers: GatewayRequestHandlers = { respond(true, { sessionId, ...result }, undefined); }, "wizard.next": async ({ params, respond, context }) => { - if (!validateWizardNextParams(params)) { - respond( - false, - undefined, - errorShape( - ErrorCodes.INVALID_REQUEST, - `invalid wizard.next params: ${formatValidationErrors(validateWizardNextParams.errors)}`, - ), - ); + if (!assertValidParams(params, validateWizardNextParams, "wizard.next", respond)) { return; } const sessionId = params.sessionId; @@ -84,15 +100,7 @@ export const wizardHandlers: GatewayRequestHandlers = { respond(true, result, undefined); }, "wizard.cancel": ({ params, respond, context }) => { - if (!validateWizardCancelParams(params)) { - respond( - false, - undefined, - errorShape( - ErrorCodes.INVALID_REQUEST, - `invalid wizard.cancel params: ${formatValidationErrors(validateWizardCancelParams.errors)}`, - ), - ); + if (!assertValidParams(params, validateWizardCancelParams, "wizard.cancel", respond)) { return; } const sessionId = params.sessionId; @@ -102,23 +110,12 @@ export const wizardHandlers: GatewayRequestHandlers = { return; } session.cancel(); - const status = { - status: session.getStatus(), - error: session.getError(), - }; + const status = readWizardStatus(session); context.wizardSessions.delete(sessionId); respond(true, status, undefined); }, "wizard.status": ({ params, respond, context }) => { - if (!validateWizardStatusParams(params)) { - respond( - false, - undefined, - errorShape( - ErrorCodes.INVALID_REQUEST, - `invalid wizard.status params: ${formatValidationErrors(validateWizardStatusParams.errors)}`, - ), - ); + if (!assertValidParams(params, validateWizardStatusParams, "wizard.status", respond)) { return; } const sessionId = params.sessionId; @@ -127,10 +124,7 @@ export const wizardHandlers: GatewayRequestHandlers = { respond(false, undefined, errorShape(ErrorCodes.INVALID_REQUEST, "wizard not found")); return; } - const status = { - status: session.getStatus(), - error: session.getError(), - }; + const status = readWizardStatus(session); if (status.status !== "running") { context.wizardSessions.delete(sessionId); }