diff --git a/src/agents/auth-profiles/session-override.test.ts b/src/agents/auth-profiles/session-override.test.ts new file mode 100644 index 000000000..e4c90c3a7 --- /dev/null +++ b/src/agents/auth-profiles/session-override.test.ts @@ -0,0 +1,61 @@ +import fs from "node:fs/promises"; +import os from "node:os"; +import path from "node:path"; +import { describe, expect, it } from "vitest"; + +import type { ClawdbotConfig } from "../../config/config.js"; +import type { SessionEntry } from "../../config/sessions.js"; +import { resolveSessionAuthProfileOverride } from "./session-override.js"; + +async function writeAuthStore(agentDir: string) { + const authPath = path.join(agentDir, "auth-profiles.json"); + const payload = { + version: 1, + profiles: { + "zai:work": { type: "api_key", provider: "zai", key: "sk-test" }, + }, + order: { + zai: ["zai:work"], + }, + }; + await fs.writeFile(authPath, JSON.stringify(payload), "utf-8"); +} + +describe("resolveSessionAuthProfileOverride", () => { + it("keeps user override when provider alias differs", async () => { + const tmpDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdbot-auth-")); + const prevStateDir = process.env.CLAWDBOT_STATE_DIR; + process.env.CLAWDBOT_STATE_DIR = tmpDir; + try { + const agentDir = path.join(tmpDir, "agent"); + await fs.mkdir(agentDir, { recursive: true }); + await writeAuthStore(agentDir); + + const sessionEntry: SessionEntry = { + sessionId: "s1", + updatedAt: Date.now(), + authProfileOverride: "zai:work", + authProfileOverrideSource: "user", + }; + const sessionStore = { "agent:main:main": sessionEntry }; + + const resolved = await resolveSessionAuthProfileOverride({ + cfg: {} as ClawdbotConfig, + provider: "z.ai", + agentDir, + sessionEntry, + sessionStore, + sessionKey: "agent:main:main", + storePath: undefined, + isNewSession: false, + }); + + expect(resolved).toBe("zai:work"); + expect(sessionEntry.authProfileOverride).toBe("zai:work"); + } finally { + if (prevStateDir === undefined) delete process.env.CLAWDBOT_STATE_DIR; + else process.env.CLAWDBOT_STATE_DIR = prevStateDir; + await fs.rm(tmpDir, { recursive: true, force: true }); + } + }); +}); diff --git a/src/agents/auth-profiles/session-override.ts b/src/agents/auth-profiles/session-override.ts new file mode 100644 index 000000000..29e6dc01b --- /dev/null +++ b/src/agents/auth-profiles/session-override.ts @@ -0,0 +1,139 @@ +import type { ClawdbotConfig } from "../../config/config.js"; +import { updateSessionStore, type SessionEntry } from "../../config/sessions.js"; +import { normalizeProviderId } from "../model-selection.js"; +import { + ensureAuthProfileStore, + isProfileInCooldown, + resolveAuthProfileOrder, +} from "../auth-profiles.js"; + +function isProfileForProvider(params: { + provider: string; + profileId: string; + store: ReturnType; +}): boolean { + const entry = params.store.profiles[params.profileId]; + if (!entry?.provider) return false; + return normalizeProviderId(entry.provider) === normalizeProviderId(params.provider); +} + +export async function clearSessionAuthProfileOverride(params: { + sessionEntry: SessionEntry; + sessionStore: Record; + sessionKey: string; + storePath?: string; +}) { + const { sessionEntry, sessionStore, sessionKey, storePath } = params; + delete sessionEntry.authProfileOverride; + delete sessionEntry.authProfileOverrideSource; + delete sessionEntry.authProfileOverrideCompactionCount; + sessionEntry.updatedAt = Date.now(); + sessionStore[sessionKey] = sessionEntry; + if (storePath) { + await updateSessionStore(storePath, (store) => { + store[sessionKey] = sessionEntry; + }); + } +} + +export async function resolveSessionAuthProfileOverride(params: { + cfg: ClawdbotConfig; + provider: string; + agentDir: string; + sessionEntry?: SessionEntry; + sessionStore?: Record; + sessionKey?: string; + storePath?: string; + isNewSession: boolean; +}): Promise { + const { + cfg, + provider, + agentDir, + sessionEntry, + sessionStore, + sessionKey, + storePath, + isNewSession, + } = params; + if (!sessionEntry || !sessionStore || !sessionKey) return sessionEntry?.authProfileOverride; + + const store = ensureAuthProfileStore(agentDir, { allowKeychainPrompt: false }); + const order = resolveAuthProfileOrder({ cfg, store, provider }); + let current = sessionEntry.authProfileOverride?.trim(); + + if (current && !store.profiles[current]) { + await clearSessionAuthProfileOverride({ sessionEntry, sessionStore, sessionKey, storePath }); + current = undefined; + } + + if (current && !isProfileForProvider({ provider, profileId: current, store })) { + await clearSessionAuthProfileOverride({ sessionEntry, sessionStore, sessionKey, storePath }); + current = undefined; + } + + if (current && order.length > 0 && !order.includes(current)) { + await clearSessionAuthProfileOverride({ sessionEntry, sessionStore, sessionKey, storePath }); + current = undefined; + } + + if (order.length === 0) return undefined; + + const pickFirstAvailable = () => + order.find((profileId) => !isProfileInCooldown(store, profileId)) ?? order[0]; + const pickNextAvailable = (active: string) => { + const startIndex = order.indexOf(active); + if (startIndex < 0) return pickFirstAvailable(); + for (let offset = 1; offset <= order.length; offset += 1) { + const candidate = order[(startIndex + offset) % order.length]; + if (!isProfileInCooldown(store, candidate)) return candidate; + } + return order[startIndex] ?? order[0]; + }; + + const compactionCount = sessionEntry.compactionCount ?? 0; + const storedCompaction = + typeof sessionEntry.authProfileOverrideCompactionCount === "number" + ? sessionEntry.authProfileOverrideCompactionCount + : compactionCount; + + const source = + sessionEntry.authProfileOverrideSource ?? + (typeof sessionEntry.authProfileOverrideCompactionCount === "number" + ? "auto" + : current + ? "user" + : undefined); + if (source === "user" && current && !isNewSession) { + return current; + } + + let next = current; + if (isNewSession) { + next = current ? pickNextAvailable(current) : pickFirstAvailable(); + } else if (current && compactionCount > storedCompaction) { + next = pickNextAvailable(current); + } else if (!current || isProfileInCooldown(store, current)) { + next = pickFirstAvailable(); + } + + if (!next) return current; + const shouldPersist = + next !== sessionEntry.authProfileOverride || + sessionEntry.authProfileOverrideSource !== "auto" || + sessionEntry.authProfileOverrideCompactionCount !== compactionCount; + if (shouldPersist) { + sessionEntry.authProfileOverride = next; + sessionEntry.authProfileOverrideSource = "auto"; + sessionEntry.authProfileOverrideCompactionCount = compactionCount; + sessionEntry.updatedAt = Date.now(); + sessionStore[sessionKey] = sessionEntry; + if (storePath) { + await updateSessionStore(storePath, (store) => { + store[sessionKey] = sessionEntry; + }); + } + } + + return next; +} diff --git a/src/agents/pi-embedded-runner.run-embedded-pi-agent.auth-profile-rotation.test.ts b/src/agents/pi-embedded-runner.run-embedded-pi-agent.auth-profile-rotation.test.ts index 6837c4cb1..b931230af 100644 --- a/src/agents/pi-embedded-runner.run-embedded-pi-agent.auth-profile-rotation.test.ts +++ b/src/agents/pi-embedded-runner.run-embedded-pi-agent.auth-profile-rotation.test.ts @@ -92,13 +92,16 @@ const makeConfig = (): ClawdbotConfig => }, }) satisfies ClawdbotConfig; -const writeAuthStore = async (agentDir: string) => { +const writeAuthStore = async (agentDir: string, opts?: { includeAnthropic?: boolean }) => { const authPath = path.join(agentDir, "auth-profiles.json"); const payload = { version: 1, profiles: { "openai:p1": { type: "api_key", provider: "openai", key: "sk-one" }, "openai:p2": { type: "api_key", provider: "openai", key: "sk-two" }, + ...(opts?.includeAnthropic + ? { "anthropic:default": { type: "api_key", provider: "anthropic", key: "sk-anth" } } + : {}), }, usageStats: { "openai:p1": { lastUsed: 1 }, @@ -206,4 +209,43 @@ describe("runEmbeddedPiAgent auth profile rotation", () => { await fs.rm(workspaceDir, { recursive: true, force: true }); } }); + + it("ignores user-locked profile when provider mismatches", async () => { + const agentDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdbot-agent-")); + const workspaceDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdbot-workspace-")); + try { + await writeAuthStore(agentDir, { includeAnthropic: true }); + + runEmbeddedAttemptMock.mockResolvedValueOnce( + makeAttempt({ + assistantTexts: ["ok"], + lastAssistant: buildAssistant({ + stopReason: "stop", + content: [{ type: "text", text: "ok" }], + }), + }), + ); + + await runEmbeddedPiAgent({ + sessionId: "session:test", + sessionKey: "agent:test:mismatch", + sessionFile: path.join(workspaceDir, "session.jsonl"), + workspaceDir, + agentDir, + config: makeConfig(), + prompt: "hello", + provider: "openai", + model: "mock-1", + authProfileId: "anthropic:default", + authProfileIdSource: "user", + timeoutMs: 5_000, + runId: "run:mismatch", + }); + + expect(runEmbeddedAttemptMock).toHaveBeenCalledTimes(1); + } finally { + await fs.rm(agentDir, { recursive: true, force: true }); + await fs.rm(workspaceDir, { recursive: true, force: true }); + } + }); }); diff --git a/src/agents/pi-embedded-runner/run.ts b/src/agents/pi-embedded-runner/run.ts index 45a0db943..8d63a2904 100644 --- a/src/agents/pi-embedded-runner/run.ts +++ b/src/agents/pi-embedded-runner/run.ts @@ -23,6 +23,7 @@ import { resolveAuthProfileOrder, type ResolvedProviderAuth, } from "../model-auth.js"; +import { normalizeProviderId } from "../model-selection.js"; import { ensureClawdbotModelsJson } from "../models-config.js"; import { classifyFailoverReason, @@ -116,8 +117,16 @@ export async function runEmbeddedPiAgent( const authStore = ensureAuthProfileStore(agentDir, { allowKeychainPrompt: false }); const preferredProfileId = params.authProfileId?.trim(); - const lockedProfileId = - params.authProfileIdSource === "user" ? preferredProfileId : undefined; + let lockedProfileId = params.authProfileIdSource === "user" ? preferredProfileId : undefined; + if (lockedProfileId) { + const lockedProfile = authStore.profiles[lockedProfileId]; + if ( + !lockedProfile || + normalizeProviderId(lockedProfile.provider) !== normalizeProviderId(provider) + ) { + lockedProfileId = undefined; + } + } const profileOrder = resolveAuthProfileOrder({ cfg: params.config, store: authStore, diff --git a/src/agents/tools/session-status-tool.ts b/src/agents/tools/session-status-tool.ts index dcf0decf8..f2cc6e980 100644 --- a/src/agents/tools/session-status-tool.ts +++ b/src/agents/tools/session-status-tool.ts @@ -36,6 +36,7 @@ import { DEFAULT_AGENT_ID, resolveAgentIdFromSessionKey, } from "../../routing/session-key.js"; +import { applyModelOverrideToSessionEntry } from "../../sessions/model-overrides.js"; import type { AnyAgentTool } from "./common.js"; import { readStringParam } from "./common.js"; import { resolveInternalSessionKey, resolveMainSessionAlias } from "./sessions-helpers.js"; @@ -240,6 +241,7 @@ export function createSessionStatusTool(opts?: { throw new Error(`Unknown sessionKey: ${requestedKeyRaw}`); } + const configured = resolveDefaultModelForAgent({ cfg, agentId }); const modelRaw = readStringParam(params, "model"); let changedModel = false; if (typeof modelRaw === "string") { @@ -249,33 +251,33 @@ export function createSessionStatusTool(opts?: { sessionEntry: resolved.entry, agentId, }); - const nextEntry: SessionEntry = { - ...resolved.entry, - updatedAt: Date.now(), - }; - if (selection.kind === "reset" || selection.isDefault) { - delete nextEntry.providerOverride; - delete nextEntry.modelOverride; - delete nextEntry.authProfileOverride; - delete nextEntry.authProfileOverrideSource; - delete nextEntry.authProfileOverrideCompactionCount; - } else { - nextEntry.providerOverride = selection.provider; - nextEntry.modelOverride = selection.model; - delete nextEntry.authProfileOverride; - delete nextEntry.authProfileOverrideSource; - delete nextEntry.authProfileOverrideCompactionCount; - } - store[resolved.key] = nextEntry; - await updateSessionStore(storePath, (nextStore) => { - nextStore[resolved.key] = nextEntry; + const nextEntry: SessionEntry = { ...resolved.entry }; + const applied = applyModelOverrideToSessionEntry({ + entry: nextEntry, + selection: + selection.kind === "reset" + ? { + provider: configured.provider, + model: configured.model, + isDefault: true, + } + : { + provider: selection.provider, + model: selection.model, + isDefault: selection.isDefault, + }, }); - resolved.entry = nextEntry; - changedModel = true; + if (applied.updated) { + store[resolved.key] = nextEntry; + await updateSessionStore(storePath, (nextStore) => { + nextStore[resolved.key] = nextEntry; + }); + resolved.entry = nextEntry; + changedModel = true; + } } const agentDir = resolveAgentDir(cfg, agentId); - const configured = resolveDefaultModelForAgent({ cfg, agentId }); const providerForCard = resolved.entry.providerOverride?.trim() || configured.provider; const usageProvider = resolveUsageProviderId(providerForCard); let usageLine: string | undefined; diff --git a/src/auto-reply/reply/directive-handling.impl.ts b/src/auto-reply/reply/directive-handling.impl.ts index 0189be238..33f19ee3f 100644 --- a/src/auto-reply/reply/directive-handling.impl.ts +++ b/src/auto-reply/reply/directive-handling.impl.ts @@ -10,6 +10,7 @@ import { type SessionEntry, updateSessionStore } from "../../config/sessions.js" import type { ExecAsk, ExecHost, ExecSecurity } from "../../infra/exec-approvals.js"; import { enqueueSystemEvent } from "../../infra/system-events.js"; import { applyVerboseOverride } from "../../sessions/level-overrides.js"; +import { applyModelOverrideToSessionEntry } from "../../sessions/model-overrides.js"; import { formatThinkingLevels, formatXHighModelHint, supportsXHighThinking } from "../thinking.js"; import type { ReplyPayload } from "../types.js"; import { @@ -340,22 +341,11 @@ export async function handleDirectiveOnly(params: { } } if (modelSelection) { - if (modelSelection.isDefault) { - delete sessionEntry.providerOverride; - delete sessionEntry.modelOverride; - } else { - sessionEntry.providerOverride = modelSelection.provider; - sessionEntry.modelOverride = modelSelection.model; - } - if (profileOverride) { - sessionEntry.authProfileOverride = profileOverride; - sessionEntry.authProfileOverrideSource = "user"; - delete sessionEntry.authProfileOverrideCompactionCount; - } else if (directives.hasModelDirective) { - delete sessionEntry.authProfileOverride; - delete sessionEntry.authProfileOverrideSource; - delete sessionEntry.authProfileOverrideCompactionCount; - } + applyModelOverrideToSessionEntry({ + entry: sessionEntry, + selection: modelSelection, + profileOverride, + }); } if (directives.hasQueueDirective && directives.queueReset) { delete sessionEntry.queueMode; diff --git a/src/auto-reply/reply/directive-handling.persist.ts b/src/auto-reply/reply/directive-handling.persist.ts index e8dab44d2..4733418ba 100644 --- a/src/auto-reply/reply/directive-handling.persist.ts +++ b/src/auto-reply/reply/directive-handling.persist.ts @@ -16,6 +16,7 @@ import type { ClawdbotConfig } from "../../config/config.js"; import { type SessionEntry, updateSessionStore } from "../../config/sessions.js"; import { enqueueSystemEvent } from "../../infra/system-events.js"; import { applyVerboseOverride } from "../../sessions/level-overrides.js"; +import { applyModelOverrideToSessionEntry } from "../../sessions/model-overrides.js"; import { resolveProfileOverride } from "./directive-handling.auth.js"; import type { InlineDirectives } from "./directive-handling.parse.js"; import { formatElevatedEvent, formatReasoningEvent } from "./directive-handling.shared.js"; @@ -164,22 +165,15 @@ export async function persistInlineDirectives(params: { } const isDefault = resolved.ref.provider === defaultProvider && resolved.ref.model === defaultModel; - if (isDefault) { - delete sessionEntry.providerOverride; - delete sessionEntry.modelOverride; - } else { - sessionEntry.providerOverride = resolved.ref.provider; - sessionEntry.modelOverride = resolved.ref.model; - } - if (profileOverride) { - sessionEntry.authProfileOverride = profileOverride; - sessionEntry.authProfileOverrideSource = "user"; - delete sessionEntry.authProfileOverrideCompactionCount; - } else if (directives.hasModelDirective) { - delete sessionEntry.authProfileOverride; - delete sessionEntry.authProfileOverrideSource; - delete sessionEntry.authProfileOverrideCompactionCount; - } + const { updated: modelUpdated } = applyModelOverrideToSessionEntry({ + entry: sessionEntry, + selection: { + provider: resolved.ref.provider, + model: resolved.ref.model, + isDefault, + }, + profileOverride, + }); provider = resolved.ref.provider; model = resolved.ref.model; const nextLabel = `${provider}/${model}`; @@ -189,7 +183,7 @@ export async function persistInlineDirectives(params: { contextKey: `model:${nextLabel}`, }); } - updated = true; + updated = updated || modelUpdated; } } } diff --git a/src/auto-reply/reply/get-reply-run.ts b/src/auto-reply/reply/get-reply-run.ts index 5c5b67c1d..aa2281de6 100644 --- a/src/auto-reply/reply/get-reply-run.ts +++ b/src/auto-reply/reply/get-reply-run.ts @@ -5,16 +5,11 @@ import { isEmbeddedPiRunStreaming, resolveEmbeddedSessionLane, } from "../../agents/pi-embedded.js"; -import { - ensureAuthProfileStore, - isProfileInCooldown, - resolveAuthProfileOrder, -} from "../../agents/auth-profiles.js"; +import { resolveSessionAuthProfileOverride } from "../../agents/auth-profiles/session-override.js"; import type { ExecToolDefaults } from "../../agents/bash-tools.js"; import type { ClawdbotConfig } from "../../config/config.js"; import { resolveSessionFilePath, - saveSessionStore, type SessionEntry, updateSessionStore, } from "../../config/sessions.js"; @@ -108,92 +103,6 @@ type RunPreparedReplyParams = { abortedLastRun: boolean; }; -async function resolveSessionAuthProfileOverride(params: { - cfg: ClawdbotConfig; - provider: string; - agentDir: string; - sessionEntry?: SessionEntry; - sessionStore?: Record; - sessionKey?: string; - storePath?: string; - isNewSession: boolean; -}): Promise { - const { - cfg, - provider, - agentDir, - sessionEntry, - sessionStore, - sessionKey, - storePath, - isNewSession, - } = params; - if (!sessionEntry || !sessionStore || !sessionKey) return sessionEntry?.authProfileOverride; - - const store = ensureAuthProfileStore(agentDir, { allowKeychainPrompt: false }); - const order = resolveAuthProfileOrder({ cfg, store, provider }); - if (order.length === 0) return sessionEntry.authProfileOverride; - - const pickFirstAvailable = () => - order.find((profileId) => !isProfileInCooldown(store, profileId)) ?? order[0]; - const pickNextAvailable = (current: string) => { - const startIndex = order.indexOf(current); - if (startIndex < 0) return pickFirstAvailable(); - for (let offset = 1; offset <= order.length; offset += 1) { - const candidate = order[(startIndex + offset) % order.length]; - if (!isProfileInCooldown(store, candidate)) return candidate; - } - return order[startIndex] ?? order[0]; - }; - - const compactionCount = sessionEntry.compactionCount ?? 0; - const storedCompaction = - typeof sessionEntry.authProfileOverrideCompactionCount === "number" - ? sessionEntry.authProfileOverrideCompactionCount - : compactionCount; - - let current = sessionEntry.authProfileOverride?.trim(); - if (current && !order.includes(current)) current = undefined; - - const source = - sessionEntry.authProfileOverrideSource ?? - (typeof sessionEntry.authProfileOverrideCompactionCount === "number" - ? "auto" - : current - ? "user" - : undefined); - if (source === "user" && current && !isNewSession) { - return current; - } - - let next = current; - if (isNewSession) { - next = current ? pickNextAvailable(current) : pickFirstAvailable(); - } else if (current && compactionCount > storedCompaction) { - next = pickNextAvailable(current); - } else if (!current || isProfileInCooldown(store, current)) { - next = pickFirstAvailable(); - } - - if (!next) return current; - const shouldPersist = - next !== sessionEntry.authProfileOverride || - sessionEntry.authProfileOverrideSource !== "auto" || - sessionEntry.authProfileOverrideCompactionCount !== compactionCount; - if (shouldPersist) { - sessionEntry.authProfileOverride = next; - sessionEntry.authProfileOverrideSource = "auto"; - sessionEntry.authProfileOverrideCompactionCount = compactionCount; - sessionEntry.updatedAt = Date.now(); - sessionStore[sessionKey] = sessionEntry; - if (storePath) { - await saveSessionStore(storePath, sessionStore); - } - } - - return next; -} - export async function runPreparedReply( params: RunPreparedReplyParams, ): Promise { diff --git a/src/auto-reply/reply/model-selection.ts b/src/auto-reply/reply/model-selection.ts index 37f742f1a..f1d9948a8 100644 --- a/src/auto-reply/reply/model-selection.ts +++ b/src/auto-reply/reply/model-selection.ts @@ -11,6 +11,8 @@ import { } from "../../agents/model-selection.js"; import type { ClawdbotConfig } from "../../config/config.js"; import { type SessionEntry, updateSessionStore } from "../../config/sessions.js"; +import { clearSessionAuthProfileOverride } from "../../agents/auth-profiles/session-override.js"; +import { applyModelOverrideToSessionEntry } from "../../sessions/model-overrides.js"; import type { ThinkLevel } from "./directives.js"; export type ModelDirectiveSelection = { @@ -184,16 +186,19 @@ export async function createModelSelectionState(params: { if (overrideModel) { const key = modelKey(overrideProvider, overrideModel); if (allowedModelKeys.size > 0 && !allowedModelKeys.has(key)) { - delete sessionEntry.providerOverride; - delete sessionEntry.modelOverride; - sessionEntry.updatedAt = Date.now(); - sessionStore[sessionKey] = sessionEntry; - if (storePath) { - await updateSessionStore(storePath, (store) => { - store[sessionKey] = sessionEntry; - }); + const { updated } = applyModelOverrideToSessionEntry({ + entry: sessionEntry, + selection: { provider: defaultProvider, model: defaultModel, isDefault: true }, + }); + if (updated) { + sessionStore[sessionKey] = sessionEntry; + if (storePath) { + await updateSessionStore(storePath, (store) => { + store[sessionKey] = sessionEntry; + }); + } } - resetModelOverride = true; + resetModelOverride = updated; } } } @@ -215,17 +220,14 @@ export async function createModelSelectionState(params: { allowKeychainPrompt: false, }); const profile = store.profiles[sessionEntry.authProfileOverride]; - if (!profile || profile.provider !== provider) { - delete sessionEntry.authProfileOverride; - delete sessionEntry.authProfileOverrideSource; - delete sessionEntry.authProfileOverrideCompactionCount; - sessionEntry.updatedAt = Date.now(); - sessionStore[sessionKey] = sessionEntry; - if (storePath) { - await updateSessionStore(storePath, (store) => { - store[sessionKey] = sessionEntry; - }); - } + const providerKey = normalizeProviderId(provider); + if (!profile || normalizeProviderId(profile.provider) !== providerKey) { + await clearSessionAuthProfileOverride({ + sessionEntry, + sessionStore, + sessionKey, + storePath, + }); } } diff --git a/src/auto-reply/reply/session-reset-model.test.ts b/src/auto-reply/reply/session-reset-model.test.ts index 07aeaf2f9..db840038c 100644 --- a/src/auto-reply/reply/session-reset-model.test.ts +++ b/src/auto-reply/reply/session-reset-model.test.ts @@ -42,6 +42,39 @@ describe("applyResetModelOverride", () => { expect(sessionCtx.BodyStripped).toBe("summarize"); }); + it("clears auth profile overrides when reset applies a model", async () => { + const cfg = {} as ClawdbotConfig; + const aliasIndex = buildModelAliasIndex({ cfg, defaultProvider: "openai" }); + const sessionEntry = { + sessionId: "s1", + updatedAt: Date.now(), + authProfileOverride: "anthropic:default", + authProfileOverrideSource: "user", + authProfileOverrideCompactionCount: 2, + }; + const sessionStore = { "agent:main:dm:1": sessionEntry }; + const sessionCtx = { BodyStripped: "minimax summarize" }; + const ctx = { ChatType: "direct" }; + + await applyResetModelOverride({ + cfg, + resetTriggered: true, + bodyStripped: "minimax summarize", + sessionCtx, + ctx, + sessionEntry, + sessionStore, + sessionKey: "agent:main:dm:1", + defaultProvider: "openai", + defaultModel: "gpt-4o-mini", + aliasIndex, + }); + + expect(sessionEntry.authProfileOverride).toBeUndefined(); + expect(sessionEntry.authProfileOverrideSource).toBeUndefined(); + expect(sessionEntry.authProfileOverrideCompactionCount).toBeUndefined(); + }); + it("skips when resetTriggered is false", async () => { const cfg = {} as ClawdbotConfig; const aliasIndex = buildModelAliasIndex({ cfg, defaultProvider: "openai" }); diff --git a/src/auto-reply/reply/session-reset-model.ts b/src/auto-reply/reply/session-reset-model.ts index d14982920..9250faadd 100644 --- a/src/auto-reply/reply/session-reset-model.ts +++ b/src/auto-reply/reply/session-reset-model.ts @@ -12,6 +12,7 @@ import { updateSessionStore } from "../../config/sessions.js"; import type { MsgContext, TemplateContext } from "../templating.js"; import { formatInboundBodyWithSenderMeta } from "./inbound-sender-meta.js"; import { resolveModelDirectiveSelection, type ModelDirectiveSelection } from "./model-selection.js"; +import { applyModelOverrideToSessionEntry } from "../../sessions/model-overrides.js"; type ResetModelResult = { selection?: ModelDirectiveSelection; @@ -62,25 +63,11 @@ function applySelectionToSession(params: { }) { const { selection, sessionEntry, sessionStore, sessionKey, storePath } = params; if (!sessionEntry || !sessionStore || !sessionKey) return; - let updated = false; - if (selection.isDefault) { - if (sessionEntry.providerOverride || sessionEntry.modelOverride) { - delete sessionEntry.providerOverride; - delete sessionEntry.modelOverride; - updated = true; - } - } else { - if (sessionEntry.providerOverride !== selection.provider) { - sessionEntry.providerOverride = selection.provider; - updated = true; - } - if (sessionEntry.modelOverride !== selection.model) { - sessionEntry.modelOverride = selection.model; - updated = true; - } - } + const { updated } = applyModelOverrideToSessionEntry({ + entry: sessionEntry, + selection, + }); if (!updated) return; - sessionEntry.updatedAt = Date.now(); sessionStore[sessionKey] = sessionEntry; if (storePath) { updateSessionStore(storePath, (store) => { diff --git a/src/commands/agent.ts b/src/commands/agent.ts index bc7f79941..9033759f4 100644 --- a/src/commands/agent.ts +++ b/src/commands/agent.ts @@ -50,6 +50,8 @@ import { defaultRuntime, type RuntimeEnv } from "../runtime.js"; import { formatCliCommand } from "../cli/command-format.js"; import { applyVerboseOverride } from "../sessions/level-overrides.js"; import { resolveSendPolicy } from "../sessions/send-policy.js"; +import { applyModelOverrideToSessionEntry } from "../sessions/model-overrides.js"; +import { clearSessionAuthProfileOverride } from "../agents/auth-profiles/session-override.js"; import { resolveMessageChannel } from "../utils/message-channel.js"; import { deliverAgentCommandResult } from "./agent/delivery.js"; import { resolveAgentRunContext } from "./agent/run-context.js"; @@ -283,13 +285,16 @@ export async function agentCommand( allowedModelKeys.size > 0 && !allowedModelKeys.has(key) ) { - delete entry.providerOverride; - delete entry.modelOverride; - entry.updatedAt = Date.now(); - sessionStore[sessionKey] = entry; - await updateSessionStore(storePath, (store) => { - store[sessionKey] = entry; + const { updated } = applyModelOverrideToSessionEntry({ + entry, + selection: { provider: defaultProvider, model: defaultModel, isDefault: true }, }); + if (updated) { + sessionStore[sessionKey] = entry; + await updateSessionStore(storePath, (store) => { + store[sessionKey] = entry; + }); + } } } } @@ -315,14 +320,12 @@ export async function agentCommand( const store = ensureAuthProfileStore(); const profile = store.profiles[authProfileId]; if (!profile || profile.provider !== provider) { - delete entry.authProfileOverride; - delete entry.authProfileOverrideSource; - delete entry.authProfileOverrideCompactionCount; - entry.updatedAt = Date.now(); if (sessionStore && sessionKey) { - sessionStore[sessionKey] = entry; - await updateSessionStore(storePath, (store) => { - store[sessionKey] = entry; + await clearSessionAuthProfileOverride({ + sessionEntry: entry, + sessionStore, + sessionKey, + storePath, }); } } diff --git a/src/gateway/sessions-patch.test.ts b/src/gateway/sessions-patch.test.ts index cd330fa0a..96ba0bf75 100644 --- a/src/gateway/sessions-patch.test.ts +++ b/src/gateway/sessions-patch.test.ts @@ -57,4 +57,32 @@ describe("gateway sessions patch", () => { if (res.ok) return; expect(res.error.message).toContain("invalid elevatedLevel"); }); + + test("clears auth overrides when model patch changes", async () => { + const store: Record = { + "agent:main:main": { + sessionId: "sess", + updatedAt: 1, + providerOverride: "anthropic", + modelOverride: "claude-opus-4-5", + authProfileOverride: "anthropic:default", + authProfileOverrideSource: "user", + authProfileOverrideCompactionCount: 3, + } as SessionEntry, + }; + const res = await applySessionsPatchToStore({ + cfg: {} as ClawdbotConfig, + store, + storeKey: "agent:main:main", + patch: { model: "openai/gpt-5.2" }, + loadGatewayModelCatalog: async () => [{ provider: "openai", id: "gpt-5.2" }], + }); + expect(res.ok).toBe(true); + if (!res.ok) return; + expect(res.entry.providerOverride).toBe("openai"); + expect(res.entry.modelOverride).toBe("gpt-5.2"); + expect(res.entry.authProfileOverride).toBeUndefined(); + expect(res.entry.authProfileOverrideSource).toBeUndefined(); + expect(res.entry.authProfileOverrideCompactionCount).toBeUndefined(); + }); }); diff --git a/src/gateway/sessions-patch.ts b/src/gateway/sessions-patch.ts index e7d7780cd..1a3736971 100644 --- a/src/gateway/sessions-patch.ts +++ b/src/gateway/sessions-patch.ts @@ -19,6 +19,7 @@ import { isSubagentSessionKey } from "../routing/session-key.js"; import { applyVerboseOverride, parseVerboseOverride } from "../sessions/level-overrides.js"; import { normalizeSendPolicy } from "../sessions/send-policy.js"; import { parseSessionLabel } from "../sessions/session-label.js"; +import { applyModelOverrideToSessionEntry } from "../sessions/model-overrides.js"; import { ErrorCodes, type ErrorShape, @@ -220,18 +221,23 @@ export async function applySessionsPatchToStore(params: { if ("model" in patch) { const raw = patch.model; + const resolvedDefault = resolveConfiguredModelRef({ + cfg, + defaultProvider: DEFAULT_PROVIDER, + defaultModel: DEFAULT_MODEL, + }); if (raw === null) { - delete next.providerOverride; - delete next.modelOverride; + applyModelOverrideToSessionEntry({ + entry: next, + selection: { + provider: resolvedDefault.provider, + model: resolvedDefault.model, + isDefault: true, + }, + }); } else if (raw !== undefined) { const trimmed = String(raw).trim(); if (!trimmed) return invalid("invalid model: empty"); - - const resolvedDefault = resolveConfiguredModelRef({ - cfg, - defaultProvider: DEFAULT_PROVIDER, - defaultModel: DEFAULT_MODEL, - }); if (!params.loadGatewayModelCatalog) { return { ok: false, @@ -249,16 +255,17 @@ export async function applySessionsPatchToStore(params: { if ("error" in resolved) { return invalid(resolved.error); } - if ( + const isDefault = resolved.ref.provider === resolvedDefault.provider && - resolved.ref.model === resolvedDefault.model - ) { - delete next.providerOverride; - delete next.modelOverride; - } else { - next.providerOverride = resolved.ref.provider; - next.modelOverride = resolved.ref.model; - } + resolved.ref.model === resolvedDefault.model; + applyModelOverrideToSessionEntry({ + entry: next, + selection: { + provider: resolved.ref.provider, + model: resolved.ref.model, + isDefault, + }, + }); } } diff --git a/src/sessions/model-overrides.ts b/src/sessions/model-overrides.ts new file mode 100644 index 000000000..97795d98d --- /dev/null +++ b/src/sessions/model-overrides.ts @@ -0,0 +1,72 @@ +import type { SessionEntry } from "../config/sessions.js"; + +export type ModelOverrideSelection = { + provider: string; + model: string; + isDefault?: boolean; +}; + +export function applyModelOverrideToSessionEntry(params: { + entry: SessionEntry; + selection: ModelOverrideSelection; + profileOverride?: string; + profileOverrideSource?: "auto" | "user"; +}): { updated: boolean } { + const { entry, selection, profileOverride } = params; + const profileOverrideSource = params.profileOverrideSource ?? "user"; + let updated = false; + + if (selection.isDefault) { + if (entry.providerOverride) { + delete entry.providerOverride; + updated = true; + } + if (entry.modelOverride) { + delete entry.modelOverride; + updated = true; + } + } else { + if (entry.providerOverride !== selection.provider) { + entry.providerOverride = selection.provider; + updated = true; + } + if (entry.modelOverride !== selection.model) { + entry.modelOverride = selection.model; + updated = true; + } + } + + if (profileOverride) { + if (entry.authProfileOverride !== profileOverride) { + entry.authProfileOverride = profileOverride; + updated = true; + } + if (entry.authProfileOverrideSource !== profileOverrideSource) { + entry.authProfileOverrideSource = profileOverrideSource; + updated = true; + } + if (entry.authProfileOverrideCompactionCount !== undefined) { + delete entry.authProfileOverrideCompactionCount; + updated = true; + } + } else { + if (entry.authProfileOverride) { + delete entry.authProfileOverride; + updated = true; + } + if (entry.authProfileOverrideSource) { + delete entry.authProfileOverrideSource; + updated = true; + } + if (entry.authProfileOverrideCompactionCount !== undefined) { + delete entry.authProfileOverrideCompactionCount; + updated = true; + } + } + + if (updated) { + entry.updatedAt = Date.now(); + } + + return { updated }; +}