diff --git a/src/discord/monitor/dm-command-auth.ts b/src/discord/monitor/dm-command-auth.ts index 45700abe3..2a9e18be0 100644 --- a/src/discord/monitor/dm-command-auth.ts +++ b/src/discord/monitor/dm-command-auth.ts @@ -17,6 +17,32 @@ export type DiscordDmCommandAccess = { allowMatch: ReturnType | { allowed: false }; }; +function resolveSenderAllowMatch(params: { + allowEntries: string[]; + sender: { id: string; name?: string; tag?: string }; + allowNameMatching: boolean; +}) { + const allowList = normalizeDiscordAllowList(params.allowEntries, DISCORD_ALLOW_LIST_PREFIXES); + return allowList + ? resolveDiscordAllowListMatch({ + allowList, + candidate: params.sender, + allowNameMatching: params.allowNameMatching, + }) + : ({ allowed: false } as const); +} + +function resolveDmPolicyCommandAuthorization(params: { + dmPolicy: DiscordDmPolicy; + decision: DmGroupAccessDecision; + commandAuthorized: boolean; +}) { + if (params.dmPolicy === "open" && params.decision === "allow") { + return true; + } + return params.commandAuthorized; +} + export async function resolveDiscordDmCommandAccess(params: { accountId: string; dmPolicy: DiscordDmPolicy; @@ -40,30 +66,19 @@ export async function resolveDiscordDmCommandAccess(params: { allowFrom: params.configuredAllowFrom, groupAllowFrom: [], storeAllowFrom, - isSenderAllowed: (allowEntries) => { - const allowList = normalizeDiscordAllowList(allowEntries, DISCORD_ALLOW_LIST_PREFIXES); - const allowMatch = allowList - ? resolveDiscordAllowListMatch({ - allowList, - candidate: params.sender, - allowNameMatching: params.allowNameMatching, - }) - : { allowed: false }; - return allowMatch.allowed; - }, + isSenderAllowed: (allowEntries) => + resolveSenderAllowMatch({ + allowEntries, + sender: params.sender, + allowNameMatching: params.allowNameMatching, + }).allowed, }); - const commandAllowList = normalizeDiscordAllowList( - access.effectiveAllowFrom, - DISCORD_ALLOW_LIST_PREFIXES, - ); - const allowMatch = commandAllowList - ? resolveDiscordAllowListMatch({ - allowList: commandAllowList, - candidate: params.sender, - allowNameMatching: params.allowNameMatching, - }) - : { allowed: false }; + const allowMatch = resolveSenderAllowMatch({ + allowEntries: access.effectiveAllowFrom, + sender: params.sender, + allowNameMatching: params.allowNameMatching, + }); const commandAuthorized = resolveCommandAuthorizedFromAuthorizers({ useAccessGroups: params.useAccessGroups, @@ -75,13 +90,15 @@ export async function resolveDiscordDmCommandAccess(params: { ], modeWhenAccessGroupsOff: "configured", }); - const effectiveCommandAuthorized = - access.decision === "allow" && params.dmPolicy === "open" ? true : commandAuthorized; return { decision: access.decision, reason: access.reason, - commandAuthorized: effectiveCommandAuthorized, + commandAuthorized: resolveDmPolicyCommandAuthorization({ + dmPolicy: params.dmPolicy, + decision: access.decision, + commandAuthorized, + }), allowMatch, }; } diff --git a/src/discord/monitor/dm-command-decision.ts b/src/discord/monitor/dm-command-decision.ts new file mode 100644 index 000000000..a0f64fdfb --- /dev/null +++ b/src/discord/monitor/dm-command-decision.ts @@ -0,0 +1,39 @@ +import { upsertChannelPairingRequest } from "../../pairing/pairing-store.js"; +import type { DiscordDmCommandAccess } from "./dm-command-auth.js"; + +export async function handleDiscordDmCommandDecision(params: { + dmAccess: DiscordDmCommandAccess; + accountId: string; + sender: { + id: string; + tag?: string; + name?: string; + }; + onPairingCreated: (code: string) => Promise; + onUnauthorized: () => Promise; + upsertPairingRequest?: typeof upsertChannelPairingRequest; +}): Promise { + if (params.dmAccess.decision === "allow") { + return true; + } + + if (params.dmAccess.decision === "pairing") { + const upsertPairingRequest = params.upsertPairingRequest ?? upsertChannelPairingRequest; + const { code, created } = await upsertPairingRequest({ + channel: "discord", + id: params.sender.id, + accountId: params.accountId, + meta: { + tag: params.sender.tag, + name: params.sender.name, + }, + }); + if (created) { + await params.onPairingCreated(code); + } + return false; + } + + await params.onUnauthorized(); + return false; +} diff --git a/src/discord/monitor/message-handler.preflight.ts b/src/discord/monitor/message-handler.preflight.ts index 5a13bb1b6..1db20111a 100644 --- a/src/discord/monitor/message-handler.preflight.ts +++ b/src/discord/monitor/message-handler.preflight.ts @@ -25,7 +25,6 @@ import { enqueueSystemEvent } from "../../infra/system-events.js"; import { logDebug } from "../../logger.js"; import { getChildLogger } from "../../logging.js"; import { buildPairingReply } from "../../pairing/pairing-messages.js"; -import { upsertChannelPairingRequest } from "../../pairing/pairing-store.js"; import { resolveAgentRoute } from "../../routing/resolve-route.js"; import { DEFAULT_ACCOUNT_ID, resolveAgentIdFromSessionKey } from "../../routing/session-key.js"; import { fetchPluralKitMessageInfo } from "../pluralkit.js"; @@ -42,6 +41,7 @@ import { resolveGroupDmAllow, } from "./allow-list.js"; import { resolveDiscordDmCommandAccess } from "./dm-command-auth.js"; +import { handleDiscordDmCommandDecision } from "./dm-command-decision.js"; import { formatDiscordUserTag, resolveDiscordSystemLocation, @@ -175,6 +175,7 @@ export async function preflightDiscordMessage( const dmPolicy = params.discordConfig?.dmPolicy ?? params.discordConfig?.dm?.policy ?? "pairing"; const useAccessGroups = params.cfg.commands?.useAccessGroups !== false; const resolvedAccountId = params.accountId ?? DEFAULT_ACCOUNT_ID; + const allowNameMatching = isDangerousNameMatchingEnabled(params.discordConfig); let commandAuthorized = true; if (isDirectMessage) { if (dmPolicy === "disabled") { @@ -190,7 +191,7 @@ export async function preflightDiscordMessage( name: sender.name, tag: sender.tag, }, - allowNameMatching: isDangerousNameMatchingEnabled(params.discordConfig), + allowNameMatching, useAccessGroups, }); commandAuthorized = dmAccess.commandAuthorized; @@ -198,17 +199,15 @@ export async function preflightDiscordMessage( const allowMatchMeta = formatAllowlistMatchMeta( dmAccess.allowMatch.allowed ? dmAccess.allowMatch : undefined, ); - if (dmAccess.decision === "pairing") { - const { code, created } = await upsertChannelPairingRequest({ - channel: "discord", + await handleDiscordDmCommandDecision({ + dmAccess, + accountId: resolvedAccountId, + sender: { id: author.id, - accountId: resolvedAccountId, - meta: { - tag: formatDiscordUserTag(author), - name: author.username ?? undefined, - }, - }); - if (created) { + tag: formatDiscordUserTag(author), + name: author.username ?? undefined, + }, + onPairingCreated: async (code) => { logVerbose( `discord pairing request sender=${author.id} tag=${formatDiscordUserTag(author)} (${allowMatchMeta})`, ); @@ -229,12 +228,13 @@ export async function preflightDiscordMessage( } catch (err) { logVerbose(`discord pairing reply failed for ${author.id}: ${String(err)}`); } - } - } else { - logVerbose( - `Blocked unauthorized discord sender ${sender.id} (dmPolicy=${dmPolicy}, ${allowMatchMeta})`, - ); - } + }, + onUnauthorized: async () => { + logVerbose( + `Blocked unauthorized discord sender ${sender.id} (dmPolicy=${dmPolicy}, ${allowMatchMeta})`, + ); + }, + }); return null; } } @@ -570,7 +570,7 @@ export async function preflightDiscordMessage( guildInfo, memberRoleIds, sender, - allowNameMatching: isDangerousNameMatchingEnabled(params.discordConfig), + allowNameMatching, }); if (!isDirectMessage) { @@ -587,7 +587,7 @@ export async function preflightDiscordMessage( name: sender.name, tag: sender.tag, }, - { allowNameMatching: isDangerousNameMatchingEnabled(params.discordConfig) }, + { allowNameMatching }, ) : false; const commandGate = resolveControlCommandGate({ diff --git a/src/discord/monitor/native-command.ts b/src/discord/monitor/native-command.ts index 3a21118c5..61d446ca2 100644 --- a/src/discord/monitor/native-command.ts +++ b/src/discord/monitor/native-command.ts @@ -46,7 +46,6 @@ import { logVerbose } from "../../globals.js"; import { createSubsystemLogger } from "../../logging/subsystem.js"; import { getAgentScopedMediaLocalRoots } from "../../media/local-roots.js"; import { buildPairingReply } from "../../pairing/pairing-messages.js"; -import { upsertChannelPairingRequest } from "../../pairing/pairing-store.js"; import { resolveAgentRoute } from "../../routing/resolve-route.js"; import { resolveAgentIdFromSessionKey } from "../../routing/session-key.js"; import { buildUntrustedChannelMetadata } from "../../security/channel-metadata.js"; @@ -65,6 +64,7 @@ import { resolveDiscordOwnerAllowFrom, } from "./allow-list.js"; import { resolveDiscordDmCommandAccess } from "./dm-command-auth.js"; +import { handleDiscordDmCommandDecision } from "./dm-command-decision.js"; import { resolveDiscordChannelInfo } from "./message-utils.js"; import { readDiscordModelPickerRecentModels, @@ -1269,6 +1269,7 @@ async function dispatchDiscordCommandInteraction(params: { const memberRoleIds = Array.isArray(interaction.rawData.member?.roles) ? interaction.rawData.member.roles.map((roleId: string) => String(roleId)) : []; + const allowNameMatching = isDangerousNameMatchingEnabled(discordConfig); const ownerAllowList = normalizeDiscordAllowList( discordConfig?.allowFrom ?? discordConfig?.dm?.allowFrom ?? [], ["discord:", "user:", "pk:"], @@ -1282,7 +1283,7 @@ async function dispatchDiscordCommandInteraction(params: { name: sender.name, tag: sender.tag, }, - { allowNameMatching: isDangerousNameMatchingEnabled(discordConfig) }, + { allowNameMatching }, ) : false; const guildInfo = resolveDiscordGuildEntry({ @@ -1366,22 +1367,20 @@ async function dispatchDiscordCommandInteraction(params: { name: sender.name, tag: sender.tag, }, - allowNameMatching: isDangerousNameMatchingEnabled(discordConfig), + allowNameMatching, useAccessGroups, }); commandAuthorized = dmAccess.commandAuthorized; if (dmAccess.decision !== "allow") { - if (dmAccess.decision === "pairing") { - const { code, created } = await upsertChannelPairingRequest({ - channel: "discord", + await handleDiscordDmCommandDecision({ + dmAccess, + accountId, + sender: { id: user.id, - accountId, - meta: { - tag: sender.tag, - name: sender.name, - }, - }); - if (created) { + tag: sender.tag, + name: sender.name, + }, + onPairingCreated: async (code) => { await respond( buildPairingReply({ channel: "discord", @@ -1390,10 +1389,11 @@ async function dispatchDiscordCommandInteraction(params: { }), { ephemeral: true }, ); - } - } else { - await respond("You are not authorized to use this command.", { ephemeral: true }); - } + }, + onUnauthorized: async () => { + await respond("You are not authorized to use this command.", { ephemeral: true }); + }, + }); return; } } @@ -1403,7 +1403,7 @@ async function dispatchDiscordCommandInteraction(params: { guildInfo, memberRoleIds, sender, - allowNameMatching: isDangerousNameMatchingEnabled(discordConfig), + allowNameMatching, }); const authorizers = useAccessGroups ? [ @@ -1509,7 +1509,7 @@ async function dispatchDiscordCommandInteraction(params: { channelConfig, guildInfo, sender: { id: sender.id, name: sender.name, tag: sender.tag }, - allowNameMatching: isDangerousNameMatchingEnabled(discordConfig), + allowNameMatching, }); const ctxPayload = finalizeInboundContext({ Body: prompt,