refactor(discord): unify DM command auth handling
This commit is contained in:
@@ -17,6 +17,32 @@ export type DiscordDmCommandAccess = {
|
||||
allowMatch: ReturnType<typeof resolveDiscordAllowListMatch> | { allowed: false };
|
||||
};
|
||||
|
||||
function resolveSenderAllowMatch(params: {
|
||||
allowEntries: string[];
|
||||
sender: { id: string; name?: string; tag?: string };
|
||||
allowNameMatching: boolean;
|
||||
}) {
|
||||
const allowList = normalizeDiscordAllowList(params.allowEntries, DISCORD_ALLOW_LIST_PREFIXES);
|
||||
return allowList
|
||||
? resolveDiscordAllowListMatch({
|
||||
allowList,
|
||||
candidate: params.sender,
|
||||
allowNameMatching: params.allowNameMatching,
|
||||
})
|
||||
: ({ allowed: false } as const);
|
||||
}
|
||||
|
||||
function resolveDmPolicyCommandAuthorization(params: {
|
||||
dmPolicy: DiscordDmPolicy;
|
||||
decision: DmGroupAccessDecision;
|
||||
commandAuthorized: boolean;
|
||||
}) {
|
||||
if (params.dmPolicy === "open" && params.decision === "allow") {
|
||||
return true;
|
||||
}
|
||||
return params.commandAuthorized;
|
||||
}
|
||||
|
||||
export async function resolveDiscordDmCommandAccess(params: {
|
||||
accountId: string;
|
||||
dmPolicy: DiscordDmPolicy;
|
||||
@@ -40,30 +66,19 @@ export async function resolveDiscordDmCommandAccess(params: {
|
||||
allowFrom: params.configuredAllowFrom,
|
||||
groupAllowFrom: [],
|
||||
storeAllowFrom,
|
||||
isSenderAllowed: (allowEntries) => {
|
||||
const allowList = normalizeDiscordAllowList(allowEntries, DISCORD_ALLOW_LIST_PREFIXES);
|
||||
const allowMatch = allowList
|
||||
? resolveDiscordAllowListMatch({
|
||||
allowList,
|
||||
candidate: params.sender,
|
||||
allowNameMatching: params.allowNameMatching,
|
||||
})
|
||||
: { allowed: false };
|
||||
return allowMatch.allowed;
|
||||
},
|
||||
isSenderAllowed: (allowEntries) =>
|
||||
resolveSenderAllowMatch({
|
||||
allowEntries,
|
||||
sender: params.sender,
|
||||
allowNameMatching: params.allowNameMatching,
|
||||
}).allowed,
|
||||
});
|
||||
|
||||
const commandAllowList = normalizeDiscordAllowList(
|
||||
access.effectiveAllowFrom,
|
||||
DISCORD_ALLOW_LIST_PREFIXES,
|
||||
);
|
||||
const allowMatch = commandAllowList
|
||||
? resolveDiscordAllowListMatch({
|
||||
allowList: commandAllowList,
|
||||
candidate: params.sender,
|
||||
allowNameMatching: params.allowNameMatching,
|
||||
})
|
||||
: { allowed: false };
|
||||
const allowMatch = resolveSenderAllowMatch({
|
||||
allowEntries: access.effectiveAllowFrom,
|
||||
sender: params.sender,
|
||||
allowNameMatching: params.allowNameMatching,
|
||||
});
|
||||
|
||||
const commandAuthorized = resolveCommandAuthorizedFromAuthorizers({
|
||||
useAccessGroups: params.useAccessGroups,
|
||||
@@ -75,13 +90,15 @@ export async function resolveDiscordDmCommandAccess(params: {
|
||||
],
|
||||
modeWhenAccessGroupsOff: "configured",
|
||||
});
|
||||
const effectiveCommandAuthorized =
|
||||
access.decision === "allow" && params.dmPolicy === "open" ? true : commandAuthorized;
|
||||
|
||||
return {
|
||||
decision: access.decision,
|
||||
reason: access.reason,
|
||||
commandAuthorized: effectiveCommandAuthorized,
|
||||
commandAuthorized: resolveDmPolicyCommandAuthorization({
|
||||
dmPolicy: params.dmPolicy,
|
||||
decision: access.decision,
|
||||
commandAuthorized,
|
||||
}),
|
||||
allowMatch,
|
||||
};
|
||||
}
|
||||
|
||||
39
src/discord/monitor/dm-command-decision.ts
Normal file
39
src/discord/monitor/dm-command-decision.ts
Normal file
@@ -0,0 +1,39 @@
|
||||
import { upsertChannelPairingRequest } from "../../pairing/pairing-store.js";
|
||||
import type { DiscordDmCommandAccess } from "./dm-command-auth.js";
|
||||
|
||||
export async function handleDiscordDmCommandDecision(params: {
|
||||
dmAccess: DiscordDmCommandAccess;
|
||||
accountId: string;
|
||||
sender: {
|
||||
id: string;
|
||||
tag?: string;
|
||||
name?: string;
|
||||
};
|
||||
onPairingCreated: (code: string) => Promise<void>;
|
||||
onUnauthorized: () => Promise<void>;
|
||||
upsertPairingRequest?: typeof upsertChannelPairingRequest;
|
||||
}): Promise<boolean> {
|
||||
if (params.dmAccess.decision === "allow") {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (params.dmAccess.decision === "pairing") {
|
||||
const upsertPairingRequest = params.upsertPairingRequest ?? upsertChannelPairingRequest;
|
||||
const { code, created } = await upsertPairingRequest({
|
||||
channel: "discord",
|
||||
id: params.sender.id,
|
||||
accountId: params.accountId,
|
||||
meta: {
|
||||
tag: params.sender.tag,
|
||||
name: params.sender.name,
|
||||
},
|
||||
});
|
||||
if (created) {
|
||||
await params.onPairingCreated(code);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
await params.onUnauthorized();
|
||||
return false;
|
||||
}
|
||||
@@ -25,7 +25,6 @@ import { enqueueSystemEvent } from "../../infra/system-events.js";
|
||||
import { logDebug } from "../../logger.js";
|
||||
import { getChildLogger } from "../../logging.js";
|
||||
import { buildPairingReply } from "../../pairing/pairing-messages.js";
|
||||
import { upsertChannelPairingRequest } from "../../pairing/pairing-store.js";
|
||||
import { resolveAgentRoute } from "../../routing/resolve-route.js";
|
||||
import { DEFAULT_ACCOUNT_ID, resolveAgentIdFromSessionKey } from "../../routing/session-key.js";
|
||||
import { fetchPluralKitMessageInfo } from "../pluralkit.js";
|
||||
@@ -42,6 +41,7 @@ import {
|
||||
resolveGroupDmAllow,
|
||||
} from "./allow-list.js";
|
||||
import { resolveDiscordDmCommandAccess } from "./dm-command-auth.js";
|
||||
import { handleDiscordDmCommandDecision } from "./dm-command-decision.js";
|
||||
import {
|
||||
formatDiscordUserTag,
|
||||
resolveDiscordSystemLocation,
|
||||
@@ -175,6 +175,7 @@ export async function preflightDiscordMessage(
|
||||
const dmPolicy = params.discordConfig?.dmPolicy ?? params.discordConfig?.dm?.policy ?? "pairing";
|
||||
const useAccessGroups = params.cfg.commands?.useAccessGroups !== false;
|
||||
const resolvedAccountId = params.accountId ?? DEFAULT_ACCOUNT_ID;
|
||||
const allowNameMatching = isDangerousNameMatchingEnabled(params.discordConfig);
|
||||
let commandAuthorized = true;
|
||||
if (isDirectMessage) {
|
||||
if (dmPolicy === "disabled") {
|
||||
@@ -190,7 +191,7 @@ export async function preflightDiscordMessage(
|
||||
name: sender.name,
|
||||
tag: sender.tag,
|
||||
},
|
||||
allowNameMatching: isDangerousNameMatchingEnabled(params.discordConfig),
|
||||
allowNameMatching,
|
||||
useAccessGroups,
|
||||
});
|
||||
commandAuthorized = dmAccess.commandAuthorized;
|
||||
@@ -198,17 +199,15 @@ export async function preflightDiscordMessage(
|
||||
const allowMatchMeta = formatAllowlistMatchMeta(
|
||||
dmAccess.allowMatch.allowed ? dmAccess.allowMatch : undefined,
|
||||
);
|
||||
if (dmAccess.decision === "pairing") {
|
||||
const { code, created } = await upsertChannelPairingRequest({
|
||||
channel: "discord",
|
||||
await handleDiscordDmCommandDecision({
|
||||
dmAccess,
|
||||
accountId: resolvedAccountId,
|
||||
sender: {
|
||||
id: author.id,
|
||||
accountId: resolvedAccountId,
|
||||
meta: {
|
||||
tag: formatDiscordUserTag(author),
|
||||
name: author.username ?? undefined,
|
||||
},
|
||||
});
|
||||
if (created) {
|
||||
tag: formatDiscordUserTag(author),
|
||||
name: author.username ?? undefined,
|
||||
},
|
||||
onPairingCreated: async (code) => {
|
||||
logVerbose(
|
||||
`discord pairing request sender=${author.id} tag=${formatDiscordUserTag(author)} (${allowMatchMeta})`,
|
||||
);
|
||||
@@ -229,12 +228,13 @@ export async function preflightDiscordMessage(
|
||||
} catch (err) {
|
||||
logVerbose(`discord pairing reply failed for ${author.id}: ${String(err)}`);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logVerbose(
|
||||
`Blocked unauthorized discord sender ${sender.id} (dmPolicy=${dmPolicy}, ${allowMatchMeta})`,
|
||||
);
|
||||
}
|
||||
},
|
||||
onUnauthorized: async () => {
|
||||
logVerbose(
|
||||
`Blocked unauthorized discord sender ${sender.id} (dmPolicy=${dmPolicy}, ${allowMatchMeta})`,
|
||||
);
|
||||
},
|
||||
});
|
||||
return null;
|
||||
}
|
||||
}
|
||||
@@ -570,7 +570,7 @@ export async function preflightDiscordMessage(
|
||||
guildInfo,
|
||||
memberRoleIds,
|
||||
sender,
|
||||
allowNameMatching: isDangerousNameMatchingEnabled(params.discordConfig),
|
||||
allowNameMatching,
|
||||
});
|
||||
|
||||
if (!isDirectMessage) {
|
||||
@@ -587,7 +587,7 @@ export async function preflightDiscordMessage(
|
||||
name: sender.name,
|
||||
tag: sender.tag,
|
||||
},
|
||||
{ allowNameMatching: isDangerousNameMatchingEnabled(params.discordConfig) },
|
||||
{ allowNameMatching },
|
||||
)
|
||||
: false;
|
||||
const commandGate = resolveControlCommandGate({
|
||||
|
||||
@@ -46,7 +46,6 @@ import { logVerbose } from "../../globals.js";
|
||||
import { createSubsystemLogger } from "../../logging/subsystem.js";
|
||||
import { getAgentScopedMediaLocalRoots } from "../../media/local-roots.js";
|
||||
import { buildPairingReply } from "../../pairing/pairing-messages.js";
|
||||
import { upsertChannelPairingRequest } from "../../pairing/pairing-store.js";
|
||||
import { resolveAgentRoute } from "../../routing/resolve-route.js";
|
||||
import { resolveAgentIdFromSessionKey } from "../../routing/session-key.js";
|
||||
import { buildUntrustedChannelMetadata } from "../../security/channel-metadata.js";
|
||||
@@ -65,6 +64,7 @@ import {
|
||||
resolveDiscordOwnerAllowFrom,
|
||||
} from "./allow-list.js";
|
||||
import { resolveDiscordDmCommandAccess } from "./dm-command-auth.js";
|
||||
import { handleDiscordDmCommandDecision } from "./dm-command-decision.js";
|
||||
import { resolveDiscordChannelInfo } from "./message-utils.js";
|
||||
import {
|
||||
readDiscordModelPickerRecentModels,
|
||||
@@ -1269,6 +1269,7 @@ async function dispatchDiscordCommandInteraction(params: {
|
||||
const memberRoleIds = Array.isArray(interaction.rawData.member?.roles)
|
||||
? interaction.rawData.member.roles.map((roleId: string) => String(roleId))
|
||||
: [];
|
||||
const allowNameMatching = isDangerousNameMatchingEnabled(discordConfig);
|
||||
const ownerAllowList = normalizeDiscordAllowList(
|
||||
discordConfig?.allowFrom ?? discordConfig?.dm?.allowFrom ?? [],
|
||||
["discord:", "user:", "pk:"],
|
||||
@@ -1282,7 +1283,7 @@ async function dispatchDiscordCommandInteraction(params: {
|
||||
name: sender.name,
|
||||
tag: sender.tag,
|
||||
},
|
||||
{ allowNameMatching: isDangerousNameMatchingEnabled(discordConfig) },
|
||||
{ allowNameMatching },
|
||||
)
|
||||
: false;
|
||||
const guildInfo = resolveDiscordGuildEntry({
|
||||
@@ -1366,22 +1367,20 @@ async function dispatchDiscordCommandInteraction(params: {
|
||||
name: sender.name,
|
||||
tag: sender.tag,
|
||||
},
|
||||
allowNameMatching: isDangerousNameMatchingEnabled(discordConfig),
|
||||
allowNameMatching,
|
||||
useAccessGroups,
|
||||
});
|
||||
commandAuthorized = dmAccess.commandAuthorized;
|
||||
if (dmAccess.decision !== "allow") {
|
||||
if (dmAccess.decision === "pairing") {
|
||||
const { code, created } = await upsertChannelPairingRequest({
|
||||
channel: "discord",
|
||||
await handleDiscordDmCommandDecision({
|
||||
dmAccess,
|
||||
accountId,
|
||||
sender: {
|
||||
id: user.id,
|
||||
accountId,
|
||||
meta: {
|
||||
tag: sender.tag,
|
||||
name: sender.name,
|
||||
},
|
||||
});
|
||||
if (created) {
|
||||
tag: sender.tag,
|
||||
name: sender.name,
|
||||
},
|
||||
onPairingCreated: async (code) => {
|
||||
await respond(
|
||||
buildPairingReply({
|
||||
channel: "discord",
|
||||
@@ -1390,10 +1389,11 @@ async function dispatchDiscordCommandInteraction(params: {
|
||||
}),
|
||||
{ ephemeral: true },
|
||||
);
|
||||
}
|
||||
} else {
|
||||
await respond("You are not authorized to use this command.", { ephemeral: true });
|
||||
}
|
||||
},
|
||||
onUnauthorized: async () => {
|
||||
await respond("You are not authorized to use this command.", { ephemeral: true });
|
||||
},
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
@@ -1403,7 +1403,7 @@ async function dispatchDiscordCommandInteraction(params: {
|
||||
guildInfo,
|
||||
memberRoleIds,
|
||||
sender,
|
||||
allowNameMatching: isDangerousNameMatchingEnabled(discordConfig),
|
||||
allowNameMatching,
|
||||
});
|
||||
const authorizers = useAccessGroups
|
||||
? [
|
||||
@@ -1509,7 +1509,7 @@ async function dispatchDiscordCommandInteraction(params: {
|
||||
channelConfig,
|
||||
guildInfo,
|
||||
sender: { id: sender.id, name: sender.name, tag: sender.tag },
|
||||
allowNameMatching: isDangerousNameMatchingEnabled(discordConfig),
|
||||
allowNameMatching,
|
||||
});
|
||||
const ctxPayload = finalizeInboundContext({
|
||||
Body: prompt,
|
||||
|
||||
Reference in New Issue
Block a user