refactor(discord): unify DM command auth handling

This commit is contained in:
Peter Steinberger
2026-03-01 23:59:55 +00:00
parent 12c1257023
commit 75596e9370
4 changed files with 120 additions and 64 deletions

View File

@@ -17,6 +17,32 @@ export type DiscordDmCommandAccess = {
allowMatch: ReturnType<typeof resolveDiscordAllowListMatch> | { allowed: false };
};
function resolveSenderAllowMatch(params: {
allowEntries: string[];
sender: { id: string; name?: string; tag?: string };
allowNameMatching: boolean;
}) {
const allowList = normalizeDiscordAllowList(params.allowEntries, DISCORD_ALLOW_LIST_PREFIXES);
return allowList
? resolveDiscordAllowListMatch({
allowList,
candidate: params.sender,
allowNameMatching: params.allowNameMatching,
})
: ({ allowed: false } as const);
}
function resolveDmPolicyCommandAuthorization(params: {
dmPolicy: DiscordDmPolicy;
decision: DmGroupAccessDecision;
commandAuthorized: boolean;
}) {
if (params.dmPolicy === "open" && params.decision === "allow") {
return true;
}
return params.commandAuthorized;
}
export async function resolveDiscordDmCommandAccess(params: {
accountId: string;
dmPolicy: DiscordDmPolicy;
@@ -40,30 +66,19 @@ export async function resolveDiscordDmCommandAccess(params: {
allowFrom: params.configuredAllowFrom,
groupAllowFrom: [],
storeAllowFrom,
isSenderAllowed: (allowEntries) => {
const allowList = normalizeDiscordAllowList(allowEntries, DISCORD_ALLOW_LIST_PREFIXES);
const allowMatch = allowList
? resolveDiscordAllowListMatch({
allowList,
candidate: params.sender,
allowNameMatching: params.allowNameMatching,
})
: { allowed: false };
return allowMatch.allowed;
},
isSenderAllowed: (allowEntries) =>
resolveSenderAllowMatch({
allowEntries,
sender: params.sender,
allowNameMatching: params.allowNameMatching,
}).allowed,
});
const commandAllowList = normalizeDiscordAllowList(
access.effectiveAllowFrom,
DISCORD_ALLOW_LIST_PREFIXES,
);
const allowMatch = commandAllowList
? resolveDiscordAllowListMatch({
allowList: commandAllowList,
candidate: params.sender,
allowNameMatching: params.allowNameMatching,
})
: { allowed: false };
const allowMatch = resolveSenderAllowMatch({
allowEntries: access.effectiveAllowFrom,
sender: params.sender,
allowNameMatching: params.allowNameMatching,
});
const commandAuthorized = resolveCommandAuthorizedFromAuthorizers({
useAccessGroups: params.useAccessGroups,
@@ -75,13 +90,15 @@ export async function resolveDiscordDmCommandAccess(params: {
],
modeWhenAccessGroupsOff: "configured",
});
const effectiveCommandAuthorized =
access.decision === "allow" && params.dmPolicy === "open" ? true : commandAuthorized;
return {
decision: access.decision,
reason: access.reason,
commandAuthorized: effectiveCommandAuthorized,
commandAuthorized: resolveDmPolicyCommandAuthorization({
dmPolicy: params.dmPolicy,
decision: access.decision,
commandAuthorized,
}),
allowMatch,
};
}

View File

@@ -0,0 +1,39 @@
import { upsertChannelPairingRequest } from "../../pairing/pairing-store.js";
import type { DiscordDmCommandAccess } from "./dm-command-auth.js";
export async function handleDiscordDmCommandDecision(params: {
dmAccess: DiscordDmCommandAccess;
accountId: string;
sender: {
id: string;
tag?: string;
name?: string;
};
onPairingCreated: (code: string) => Promise<void>;
onUnauthorized: () => Promise<void>;
upsertPairingRequest?: typeof upsertChannelPairingRequest;
}): Promise<boolean> {
if (params.dmAccess.decision === "allow") {
return true;
}
if (params.dmAccess.decision === "pairing") {
const upsertPairingRequest = params.upsertPairingRequest ?? upsertChannelPairingRequest;
const { code, created } = await upsertPairingRequest({
channel: "discord",
id: params.sender.id,
accountId: params.accountId,
meta: {
tag: params.sender.tag,
name: params.sender.name,
},
});
if (created) {
await params.onPairingCreated(code);
}
return false;
}
await params.onUnauthorized();
return false;
}

View File

@@ -25,7 +25,6 @@ import { enqueueSystemEvent } from "../../infra/system-events.js";
import { logDebug } from "../../logger.js";
import { getChildLogger } from "../../logging.js";
import { buildPairingReply } from "../../pairing/pairing-messages.js";
import { upsertChannelPairingRequest } from "../../pairing/pairing-store.js";
import { resolveAgentRoute } from "../../routing/resolve-route.js";
import { DEFAULT_ACCOUNT_ID, resolveAgentIdFromSessionKey } from "../../routing/session-key.js";
import { fetchPluralKitMessageInfo } from "../pluralkit.js";
@@ -42,6 +41,7 @@ import {
resolveGroupDmAllow,
} from "./allow-list.js";
import { resolveDiscordDmCommandAccess } from "./dm-command-auth.js";
import { handleDiscordDmCommandDecision } from "./dm-command-decision.js";
import {
formatDiscordUserTag,
resolveDiscordSystemLocation,
@@ -175,6 +175,7 @@ export async function preflightDiscordMessage(
const dmPolicy = params.discordConfig?.dmPolicy ?? params.discordConfig?.dm?.policy ?? "pairing";
const useAccessGroups = params.cfg.commands?.useAccessGroups !== false;
const resolvedAccountId = params.accountId ?? DEFAULT_ACCOUNT_ID;
const allowNameMatching = isDangerousNameMatchingEnabled(params.discordConfig);
let commandAuthorized = true;
if (isDirectMessage) {
if (dmPolicy === "disabled") {
@@ -190,7 +191,7 @@ export async function preflightDiscordMessage(
name: sender.name,
tag: sender.tag,
},
allowNameMatching: isDangerousNameMatchingEnabled(params.discordConfig),
allowNameMatching,
useAccessGroups,
});
commandAuthorized = dmAccess.commandAuthorized;
@@ -198,17 +199,15 @@ export async function preflightDiscordMessage(
const allowMatchMeta = formatAllowlistMatchMeta(
dmAccess.allowMatch.allowed ? dmAccess.allowMatch : undefined,
);
if (dmAccess.decision === "pairing") {
const { code, created } = await upsertChannelPairingRequest({
channel: "discord",
await handleDiscordDmCommandDecision({
dmAccess,
accountId: resolvedAccountId,
sender: {
id: author.id,
accountId: resolvedAccountId,
meta: {
tag: formatDiscordUserTag(author),
name: author.username ?? undefined,
},
});
if (created) {
tag: formatDiscordUserTag(author),
name: author.username ?? undefined,
},
onPairingCreated: async (code) => {
logVerbose(
`discord pairing request sender=${author.id} tag=${formatDiscordUserTag(author)} (${allowMatchMeta})`,
);
@@ -229,12 +228,13 @@ export async function preflightDiscordMessage(
} catch (err) {
logVerbose(`discord pairing reply failed for ${author.id}: ${String(err)}`);
}
}
} else {
logVerbose(
`Blocked unauthorized discord sender ${sender.id} (dmPolicy=${dmPolicy}, ${allowMatchMeta})`,
);
}
},
onUnauthorized: async () => {
logVerbose(
`Blocked unauthorized discord sender ${sender.id} (dmPolicy=${dmPolicy}, ${allowMatchMeta})`,
);
},
});
return null;
}
}
@@ -570,7 +570,7 @@ export async function preflightDiscordMessage(
guildInfo,
memberRoleIds,
sender,
allowNameMatching: isDangerousNameMatchingEnabled(params.discordConfig),
allowNameMatching,
});
if (!isDirectMessage) {
@@ -587,7 +587,7 @@ export async function preflightDiscordMessage(
name: sender.name,
tag: sender.tag,
},
{ allowNameMatching: isDangerousNameMatchingEnabled(params.discordConfig) },
{ allowNameMatching },
)
: false;
const commandGate = resolveControlCommandGate({

View File

@@ -46,7 +46,6 @@ import { logVerbose } from "../../globals.js";
import { createSubsystemLogger } from "../../logging/subsystem.js";
import { getAgentScopedMediaLocalRoots } from "../../media/local-roots.js";
import { buildPairingReply } from "../../pairing/pairing-messages.js";
import { upsertChannelPairingRequest } from "../../pairing/pairing-store.js";
import { resolveAgentRoute } from "../../routing/resolve-route.js";
import { resolveAgentIdFromSessionKey } from "../../routing/session-key.js";
import { buildUntrustedChannelMetadata } from "../../security/channel-metadata.js";
@@ -65,6 +64,7 @@ import {
resolveDiscordOwnerAllowFrom,
} from "./allow-list.js";
import { resolveDiscordDmCommandAccess } from "./dm-command-auth.js";
import { handleDiscordDmCommandDecision } from "./dm-command-decision.js";
import { resolveDiscordChannelInfo } from "./message-utils.js";
import {
readDiscordModelPickerRecentModels,
@@ -1269,6 +1269,7 @@ async function dispatchDiscordCommandInteraction(params: {
const memberRoleIds = Array.isArray(interaction.rawData.member?.roles)
? interaction.rawData.member.roles.map((roleId: string) => String(roleId))
: [];
const allowNameMatching = isDangerousNameMatchingEnabled(discordConfig);
const ownerAllowList = normalizeDiscordAllowList(
discordConfig?.allowFrom ?? discordConfig?.dm?.allowFrom ?? [],
["discord:", "user:", "pk:"],
@@ -1282,7 +1283,7 @@ async function dispatchDiscordCommandInteraction(params: {
name: sender.name,
tag: sender.tag,
},
{ allowNameMatching: isDangerousNameMatchingEnabled(discordConfig) },
{ allowNameMatching },
)
: false;
const guildInfo = resolveDiscordGuildEntry({
@@ -1366,22 +1367,20 @@ async function dispatchDiscordCommandInteraction(params: {
name: sender.name,
tag: sender.tag,
},
allowNameMatching: isDangerousNameMatchingEnabled(discordConfig),
allowNameMatching,
useAccessGroups,
});
commandAuthorized = dmAccess.commandAuthorized;
if (dmAccess.decision !== "allow") {
if (dmAccess.decision === "pairing") {
const { code, created } = await upsertChannelPairingRequest({
channel: "discord",
await handleDiscordDmCommandDecision({
dmAccess,
accountId,
sender: {
id: user.id,
accountId,
meta: {
tag: sender.tag,
name: sender.name,
},
});
if (created) {
tag: sender.tag,
name: sender.name,
},
onPairingCreated: async (code) => {
await respond(
buildPairingReply({
channel: "discord",
@@ -1390,10 +1389,11 @@ async function dispatchDiscordCommandInteraction(params: {
}),
{ ephemeral: true },
);
}
} else {
await respond("You are not authorized to use this command.", { ephemeral: true });
}
},
onUnauthorized: async () => {
await respond("You are not authorized to use this command.", { ephemeral: true });
},
});
return;
}
}
@@ -1403,7 +1403,7 @@ async function dispatchDiscordCommandInteraction(params: {
guildInfo,
memberRoleIds,
sender,
allowNameMatching: isDangerousNameMatchingEnabled(discordConfig),
allowNameMatching,
});
const authorizers = useAccessGroups
? [
@@ -1509,7 +1509,7 @@ async function dispatchDiscordCommandInteraction(params: {
channelConfig,
guildInfo,
sender: { id: sender.id, name: sender.name, tag: sender.tag },
allowNameMatching: isDangerousNameMatchingEnabled(discordConfig),
allowNameMatching,
});
const ctxPayload = finalizeInboundContext({
Body: prompt,