diff --git a/src/infra/outbound/channel-selection.test.ts b/src/infra/outbound/channel-selection.test.ts new file mode 100644 index 000000000..15642a33b --- /dev/null +++ b/src/infra/outbound/channel-selection.test.ts @@ -0,0 +1,91 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; + +const mocks = vi.hoisted(() => ({ + listChannelPlugins: vi.fn(), +})); + +vi.mock("../../channels/plugins/index.js", () => ({ + listChannelPlugins: mocks.listChannelPlugins, +})); + +import { resolveMessageChannelSelection } from "./channel-selection.js"; + +describe("resolveMessageChannelSelection", () => { + beforeEach(() => { + mocks.listChannelPlugins.mockReset(); + mocks.listChannelPlugins.mockReturnValue([]); + }); + + it("keeps explicit known channels and marks source explicit", async () => { + const selection = await resolveMessageChannelSelection({ + cfg: {} as never, + channel: "telegram", + }); + + expect(selection).toEqual({ + channel: "telegram", + configured: [], + source: "explicit", + }); + }); + + it("falls back to tool context channel when explicit channel is unknown", async () => { + const selection = await resolveMessageChannelSelection({ + cfg: {} as never, + channel: "channel:C123", + fallbackChannel: "slack", + }); + + expect(selection).toEqual({ + channel: "slack", + configured: [], + source: "tool-context-fallback", + }); + }); + + it("uses fallback channel when explicit channel is omitted", async () => { + const selection = await resolveMessageChannelSelection({ + cfg: {} as never, + fallbackChannel: "signal", + }); + + expect(selection).toEqual({ + channel: "signal", + configured: [], + source: "tool-context-fallback", + }); + }); + + it("selects single configured channel when no explicit/fallback channel exists", async () => { + mocks.listChannelPlugins.mockReturnValue([ + { + id: "discord", + config: { + listAccountIds: () => ["default"], + resolveAccount: () => ({}), + isConfigured: async () => true, + }, + }, + ]); + + const selection = await resolveMessageChannelSelection({ + cfg: {} as never, + }); + + expect(selection).toEqual({ + channel: "discord", + configured: ["discord"], + source: "single-configured", + }); + }); + + it("throws unknown channel when explicit and fallback channels are both invalid", async () => { + await expect( + resolveMessageChannelSelection({ + cfg: {} as never, + channel: "channel:C123", + fallbackChannel: "not-a-channel", + }), + ).rejects.toThrow("Unknown channel: channel:c123"); + }); +}); diff --git a/src/infra/outbound/channel-selection.ts b/src/infra/outbound/channel-selection.ts index a8ba2b699..9fbd592a5 100644 --- a/src/infra/outbound/channel-selection.ts +++ b/src/infra/outbound/channel-selection.ts @@ -4,10 +4,15 @@ import type { OpenClawConfig } from "../../config/config.js"; import { listDeliverableMessageChannels, type DeliverableMessageChannel, + isDeliverableMessageChannel, normalizeMessageChannel, } from "../../utils/message-channel.js"; export type MessageChannelId = DeliverableMessageChannel; +export type MessageChannelSelectionSource = + | "explicit" + | "tool-context-fallback" + | "single-configured"; const getMessageChannels = () => listDeliverableMessageChannels(); @@ -15,6 +20,20 @@ function isKnownChannel(value: string): boolean { return getMessageChannels().includes(value as MessageChannelId); } +function resolveKnownChannel(value?: string | null): MessageChannelId | undefined { + const normalized = normalizeMessageChannel(value); + if (!normalized) { + return undefined; + } + if (!isDeliverableMessageChannel(normalized)) { + return undefined; + } + if (!isKnownChannel(normalized)) { + return undefined; + } + return normalized as MessageChannelId; +} + function isAccountEnabled(account: unknown): boolean { if (!account || typeof account !== "object") { return true; @@ -67,21 +86,44 @@ export async function listConfiguredMessageChannels( export async function resolveMessageChannelSelection(params: { cfg: OpenClawConfig; channel?: string | null; -}): Promise<{ channel: MessageChannelId; configured: MessageChannelId[] }> { + fallbackChannel?: string | null; +}): Promise<{ + channel: MessageChannelId; + configured: MessageChannelId[]; + source: MessageChannelSelectionSource; +}> { const normalized = normalizeMessageChannel(params.channel); if (normalized) { if (!isKnownChannel(normalized)) { + const fallback = resolveKnownChannel(params.fallbackChannel); + if (fallback) { + return { + channel: fallback, + configured: await listConfiguredMessageChannels(params.cfg), + source: "tool-context-fallback", + }; + } throw new Error(`Unknown channel: ${String(normalized)}`); } return { channel: normalized as MessageChannelId, configured: await listConfiguredMessageChannels(params.cfg), + source: "explicit", + }; + } + + const fallback = resolveKnownChannel(params.fallbackChannel); + if (fallback) { + return { + channel: fallback, + configured: await listConfiguredMessageChannels(params.cfg), + source: "tool-context-fallback", }; } const configured = await listConfiguredMessageChannels(params.cfg); if (configured.length === 1) { - return { channel: configured[0], configured }; + return { channel: configured[0], configured, source: "single-configured" }; } if (configured.length === 0) { throw new Error("Channel is required (no configured channels detected)."); diff --git a/src/infra/outbound/message-action-normalization.test.ts b/src/infra/outbound/message-action-normalization.test.ts new file mode 100644 index 000000000..8acf557ef --- /dev/null +++ b/src/infra/outbound/message-action-normalization.test.ts @@ -0,0 +1,68 @@ +import { describe, expect, it } from "vitest"; +import { normalizeMessageActionInput } from "./message-action-normalization.js"; + +describe("normalizeMessageActionInput", () => { + it("prefers explicit target and clears legacy target fields", () => { + const normalized = normalizeMessageActionInput({ + action: "send", + args: { + target: "channel:C1", + to: "legacy", + channelId: "legacy-channel", + }, + }); + + expect(normalized.target).toBe("channel:C1"); + expect(normalized.to).toBe("channel:C1"); + expect("channelId" in normalized).toBe(false); + }); + + it("maps legacy target fields into canonical target", () => { + const normalized = normalizeMessageActionInput({ + action: "send", + args: { + to: "channel:C1", + }, + }); + + expect(normalized.target).toBe("channel:C1"); + expect(normalized.to).toBe("channel:C1"); + }); + + it("infers target from tool context when required", () => { + const normalized = normalizeMessageActionInput({ + action: "send", + args: {}, + toolContext: { + currentChannelId: "channel:C1", + }, + }); + + expect(normalized.target).toBe("channel:C1"); + expect(normalized.to).toBe("channel:C1"); + }); + + it("infers channel from tool context provider", () => { + const normalized = normalizeMessageActionInput({ + action: "send", + args: { + target: "channel:C1", + }, + toolContext: { + currentChannelId: "C1", + currentChannelProvider: "slack", + }, + }); + + expect(normalized.channel).toBe("slack"); + }); + + it("throws when required target remains unresolved", () => { + expect(() => + normalizeMessageActionInput({ + action: "send", + args: {}, + }), + ).toThrow(/requires a target/); + }); +}); diff --git a/src/infra/outbound/message-action-normalization.ts b/src/infra/outbound/message-action-normalization.ts new file mode 100644 index 000000000..4047a7e26 --- /dev/null +++ b/src/infra/outbound/message-action-normalization.ts @@ -0,0 +1,70 @@ +import type { + ChannelMessageActionName, + ChannelThreadingToolContext, +} from "../../channels/plugins/types.js"; +import { + isDeliverableMessageChannel, + normalizeMessageChannel, +} from "../../utils/message-channel.js"; +import { applyTargetToParams } from "./channel-target.js"; +import { actionHasTarget, actionRequiresTarget } from "./message-action-spec.js"; + +export function normalizeMessageActionInput(params: { + action: ChannelMessageActionName; + args: Record; + toolContext?: ChannelThreadingToolContext; +}): Record { + const normalizedArgs = { ...params.args }; + const { action, toolContext } = params; + + const explicitTarget = + typeof normalizedArgs.target === "string" ? normalizedArgs.target.trim() : ""; + const hasLegacyTarget = + (typeof normalizedArgs.to === "string" && normalizedArgs.to.trim().length > 0) || + (typeof normalizedArgs.channelId === "string" && normalizedArgs.channelId.trim().length > 0); + + if (explicitTarget && hasLegacyTarget) { + delete normalizedArgs.to; + delete normalizedArgs.channelId; + } + + if ( + !explicitTarget && + !hasLegacyTarget && + actionRequiresTarget(action) && + !actionHasTarget(action, normalizedArgs) + ) { + const inferredTarget = toolContext?.currentChannelId?.trim(); + if (inferredTarget) { + normalizedArgs.target = inferredTarget; + } + } + + if (!explicitTarget && actionRequiresTarget(action) && hasLegacyTarget) { + const legacyTo = typeof normalizedArgs.to === "string" ? normalizedArgs.to.trim() : ""; + const legacyChannelId = + typeof normalizedArgs.channelId === "string" ? normalizedArgs.channelId.trim() : ""; + const legacyTarget = legacyTo || legacyChannelId; + if (legacyTarget) { + normalizedArgs.target = legacyTarget; + delete normalizedArgs.to; + delete normalizedArgs.channelId; + } + } + + const explicitChannel = + typeof normalizedArgs.channel === "string" ? normalizedArgs.channel.trim() : ""; + if (!explicitChannel) { + const inferredChannel = normalizeMessageChannel(toolContext?.currentChannelProvider); + if (inferredChannel && isDeliverableMessageChannel(inferredChannel)) { + normalizedArgs.channel = inferredChannel; + } + } + + applyTargetToParams({ action, args: normalizedArgs }); + if (actionRequiresTarget(action) && !actionHasTarget(action, normalizedArgs)) { + throw new Error(`Action ${action} requires a target.`); + } + + return normalizedArgs; +} diff --git a/src/infra/outbound/message-action-runner.ts b/src/infra/outbound/message-action-runner.ts index 0336db6f2..d8ec94190 100644 --- a/src/infra/outbound/message-action-runner.ts +++ b/src/infra/outbound/message-action-runner.ts @@ -16,19 +16,14 @@ import type { OpenClawConfig } from "../../config/config.js"; import { getAgentScopedMediaLocalRoots } from "../../media/local-roots.js"; import { buildChannelAccountBindings } from "../../routing/bindings.js"; import { normalizeAgentId } from "../../routing/session-key.js"; -import { - isDeliverableMessageChannel, - normalizeMessageChannel, - type GatewayClientMode, - type GatewayClientName, -} from "../../utils/message-channel.js"; +import { type GatewayClientMode, type GatewayClientName } from "../../utils/message-channel.js"; import { throwIfAborted } from "./abort.js"; import { listConfiguredMessageChannels, resolveMessageChannelSelection, } from "./channel-selection.js"; -import { applyTargetToParams } from "./channel-target.js"; import type { OutboundSendDeps } from "./deliver.js"; +import { normalizeMessageActionInput } from "./message-action-normalization.js"; import { hydrateAttachmentParamsForAction, normalizeSandboxMediaList, @@ -41,7 +36,6 @@ import { resolveSlackAutoThreadId, resolveTelegramAutoThreadId, } from "./message-action-params.js"; -import { actionHasTarget, actionRequiresTarget } from "./message-action-spec.js"; import type { MessagePollResult, MessageSendResult } from "./message.js"; import { applyCrossContextDecoration, @@ -222,23 +216,15 @@ async function resolveChannel( params: Record, toolContext?: { currentChannelProvider?: string }, ) { - const channelHint = readStringParam(params, "channel"); - try { - const selection = await resolveMessageChannelSelection({ - cfg, - channel: channelHint, - }); - return selection.channel; - } catch (error) { - if (channelHint && toolContext?.currentChannelProvider) { - const fallback = normalizeMessageChannel(toolContext.currentChannelProvider); - if (fallback && isDeliverableMessageChannel(fallback)) { - params.channel = fallback; - return fallback; - } - } - throw error; + const selection = await resolveMessageChannelSelection({ + cfg, + channel: readStringParam(params, "channel"), + fallbackChannel: toolContext?.currentChannelProvider, + }); + if (selection.source === "tool-context-fallback") { + params.channel = selection.channel; } + return selection.channel; } async function resolveActionTarget(params: { @@ -710,7 +696,7 @@ export async function runMessageAction( input: RunMessageActionParams, ): Promise { const cfg = input.cfg; - const params = { ...input.params }; + let params = { ...input.params }; const resolvedAgentId = input.agentId ?? (input.sessionKey @@ -724,50 +710,11 @@ export async function runMessageAction( if (action === "broadcast") { return handleBroadcastAction(input, params); } - - const explicitTarget = typeof params.target === "string" ? params.target.trim() : ""; - const hasLegacyTarget = - (typeof params.to === "string" && params.to.trim().length > 0) || - (typeof params.channelId === "string" && params.channelId.trim().length > 0); - if (explicitTarget && hasLegacyTarget) { - delete params.to; - delete params.channelId; - } - if ( - !explicitTarget && - !hasLegacyTarget && - actionRequiresTarget(action) && - !actionHasTarget(action, params) - ) { - const inferredTarget = input.toolContext?.currentChannelId?.trim(); - if (inferredTarget) { - params.target = inferredTarget; - } - } - if (!explicitTarget && actionRequiresTarget(action) && hasLegacyTarget) { - const legacyTo = typeof params.to === "string" ? params.to.trim() : ""; - const legacyChannelId = typeof params.channelId === "string" ? params.channelId.trim() : ""; - const legacyTarget = legacyTo || legacyChannelId; - if (legacyTarget) { - params.target = legacyTarget; - delete params.to; - delete params.channelId; - } - } - const explicitChannel = typeof params.channel === "string" ? params.channel.trim() : ""; - if (!explicitChannel) { - const inferredChannel = normalizeMessageChannel(input.toolContext?.currentChannelProvider); - if (inferredChannel && isDeliverableMessageChannel(inferredChannel)) { - params.channel = inferredChannel; - } - } - - applyTargetToParams({ action, args: params }); - if (actionRequiresTarget(action)) { - if (!actionHasTarget(action, params)) { - throw new Error(`Action ${action} requires a target.`); - } - } + params = normalizeMessageActionInput({ + action, + args: params, + toolContext: input.toolContext, + }); const channel = await resolveChannel(cfg, params, input.toolContext); let accountId = readStringParam(params, "accountId") ?? input.defaultAccountId; diff --git a/src/infra/outbound/message.test.ts b/src/infra/outbound/message.test.ts index 36780b995..7cebff01d 100644 --- a/src/infra/outbound/message.test.ts +++ b/src/infra/outbound/message.test.ts @@ -10,6 +10,7 @@ const mocks = vi.hoisted(() => ({ vi.mock("../../channels/plugins/index.js", () => ({ normalizeChannelId: (channel?: string) => channel?.trim().toLowerCase() ?? undefined, getChannelPlugin: mocks.getChannelPlugin, + listChannelPlugins: () => [], })); vi.mock("../../agents/agent-scope.js", () => ({ diff --git a/src/infra/outbound/message.ts b/src/infra/outbound/message.ts index 9bee14f45..f8c09538f 100644 --- a/src/infra/outbound/message.ts +++ b/src/infra/outbound/message.ts @@ -9,10 +9,7 @@ import { type GatewayClientMode, type GatewayClientName, } from "../../utils/message-channel.js"; -import { - normalizeDeliverableOutboundChannel, - resolveOutboundChannelPlugin, -} from "./channel-resolution.js"; +import { resolveOutboundChannelPlugin } from "./channel-resolution.js"; import { resolveMessageChannelSelection } from "./channel-selection.js"; import { deliverOutboundPayloads, @@ -111,14 +108,12 @@ async function resolveRequiredChannel(params: { cfg: OpenClawConfig; channel?: string; }): Promise { - if (params.channel?.trim()) { - const normalized = normalizeDeliverableOutboundChannel(params.channel); - if (!normalized) { - throw new Error(`Unknown channel: ${params.channel}`); - } - return normalized; - } - return (await resolveMessageChannelSelection({ cfg: params.cfg })).channel; + return ( + await resolveMessageChannelSelection({ + cfg: params.cfg, + channel: params.channel, + }) + ).channel; } function resolveRequiredPlugin(channel: string, cfg: OpenClawConfig) {