diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 0aa3eefd..8c898a14 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -604,6 +604,7 @@ func provideChannelRouter( emailOutboxService, heartbeatService, queries, + aclService, &commandSkillLoaderAdapter{handler: containerdHandler}, &commandContainerFSAdapter{manager: manager}, )) @@ -779,6 +780,14 @@ func (a *sessionEnsurerAdapter) EnsureActiveSession(ctx context.Context, botID, return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil } +func (a *sessionEnsurerAdapter) GetActiveSession(ctx context.Context, routeID string) (inbound.SessionResult, error) { + sess, err := a.svc.GetActiveForRoute(ctx, routeID) + if err != nil { + return inbound.SessionResult{}, err + } + return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil +} + func (a *sessionEnsurerAdapter) CreateNewSession(ctx context.Context, botID, routeID, channelType, sessionType string) (inbound.SessionResult, error) { sess, err := a.svc.CreateNewSession(ctx, botID, routeID, channelType, sessionType) if err != nil { diff --git a/cmd/memoh/serve.go b/cmd/memoh/serve.go index 3a36cc41..75a608ac 100644 --- a/cmd/memoh/serve.go +++ b/cmd/memoh/serve.go @@ -492,6 +492,7 @@ func provideChannelRouter(log *slog.Logger, registry *channel.Registry, hub *loc emailOutboxService, heartbeatService, queries, + aclService, &commandSkillLoaderAdapter{handler: containerdHandler}, &commandContainerFSAdapter{manager: manager}, )) @@ -877,6 +878,14 @@ func (a *sessionEnsurerAdapter) EnsureActiveSession(ctx context.Context, botID, return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil } +func (a *sessionEnsurerAdapter) GetActiveSession(ctx context.Context, routeID string) (inbound.SessionResult, error) { + sess, err := a.svc.GetActiveForRoute(ctx, routeID) + if err != nil { + return inbound.SessionResult{}, err + } + return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil +} + func (a *sessionEnsurerAdapter) CreateNewSession(ctx context.Context, botID, routeID, channelType, sessionType string) (inbound.SessionResult, error) { sess, err := a.svc.CreateNewSession(ctx, botID, routeID, channelType, sessionType) if err != nil { diff --git a/internal/channel/inbound/channel.go b/internal/channel/inbound/channel.go index 3c20b8b0..ced6c869 100644 --- a/internal/channel/inbound/channel.go +++ b/internal/channel/inbound/channel.go @@ -71,6 +71,7 @@ type ttsModelResolver interface { // SessionEnsurer resolves or creates an active session for a route. type SessionEnsurer interface { EnsureActiveSession(ctx context.Context, botID, routeID, channelType string) (SessionResult, error) + GetActiveSession(ctx context.Context, routeID string) (SessionResult, error) // CreateNewSession always creates a fresh session and sets it as the // active session for the given route, replacing any previous one. // sessionType defaults to "chat" if empty. @@ -298,11 +299,23 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel if isStopCommand(cmdText) && isDirectedAtBot(msg) { return p.handleStopCommand(ctx, cfg, msg, sender, identity) } + if isStatusCommand(cmdText) && isDirectedAtBot(msg) { + return p.handleStatusCommand(ctx, cfg, msg, sender, identity) + } // Skip generic command handler for mode-prefix commands (/btw, /now, /next) // so they pass through to mode detection below. if p.commandHandler != nil && p.commandHandler.IsCommand(cmdText) && !IsModeCommand(cmdText) && isDirectedAtBot(msg) { - reply, err := p.commandHandler.Execute(ctx, strings.TrimSpace(identity.BotID), strings.TrimSpace(identity.ChannelIdentityID), cmdText) + reply, err := p.commandHandler.ExecuteWithInput(ctx, command.ExecuteInput{ + BotID: strings.TrimSpace(identity.BotID), + ChannelIdentityID: strings.TrimSpace(identity.ChannelIdentityID), + UserID: strings.TrimSpace(identity.UserID), + Text: cmdText, + ChannelType: msg.Channel.String(), + ConversationType: strings.TrimSpace(msg.Conversation.Type), + ConversationID: strings.TrimSpace(msg.Conversation.ID), + ThreadID: extractThreadID(msg), + }) if err != nil { reply = "Error: " + err.Error() } @@ -2489,6 +2502,18 @@ func isNewSessionCommand(cmdText string) bool { return parsed.Resource == "new" } +func isStatusCommand(cmdText string) bool { + extracted := command.ExtractCommandText(cmdText) + if extracted == "" { + return false + } + parsed, err := command.Parse(extracted) + if err != nil { + return false + } + return parsed.Resource == "status" +} + // resolveNewSessionType determines the session type for /new command. // /new chat → chat, /new discuss → discuss, /new (no arg) → default by context. // WebUI (local channel) always defaults to chat. @@ -2611,3 +2636,83 @@ func (p *ChannelInboundProcessor) handleNewSessionCommand( Message: channel.Message{Text: fmt.Sprintf("New %s conversation started.", modeLabel)}, }) } + +func (p *ChannelInboundProcessor) handleStatusCommand( + ctx context.Context, + cfg channel.ChannelConfig, + msg channel.InboundMessage, + sender channel.StreamReplySender, + identity InboundIdentity, +) error { + target := strings.TrimSpace(msg.ReplyTarget) + if target == "" { + return errors.New("reply target missing for /status command") + } + if p.routeResolver == nil { + return sender.Send(ctx, channel.OutboundMessage{ + Target: target, + Message: channel.Message{Text: "Error: route resolver not configured."}, + }) + } + if p.commandHandler == nil { + return sender.Send(ctx, channel.OutboundMessage{ + Target: target, + Message: channel.Message{Text: "Error: command handler not configured."}, + }) + } + + threadID := extractThreadID(msg) + routeMetadata := buildRouteMetadata(msg, identity) + p.enrichConversationAvatar(ctx, cfg, msg, routeMetadata) + resolved, err := p.routeResolver.ResolveConversation(ctx, route.ResolveInput{ + BotID: identity.BotID, + Platform: msg.Channel.String(), + ConversationID: msg.Conversation.ID, + ThreadID: threadID, + ConversationType: msg.Conversation.Type, + ChannelIdentityID: identity.UserID, + ChannelConfigID: identity.ChannelConfigID, + ReplyTarget: target, + Metadata: routeMetadata, + }) + if err != nil { + if p.logger != nil { + p.logger.Warn("resolve route for /status command failed", slog.Any("error", err)) + } + return sender.Send(ctx, channel.OutboundMessage{ + Target: target, + Message: channel.Message{Text: "Error: failed to resolve conversation route."}, + }) + } + + sessionID := "" + if p.sessionEnsurer != nil { + sess, sessErr := p.sessionEnsurer.GetActiveSession(ctx, resolved.RouteID) + if sessErr == nil { + sessionID = strings.TrimSpace(sess.ID) + } else if p.logger != nil { + p.logger.Debug("resolve active session for /status command failed", slog.Any("error", sessErr)) + } + } + + reply, execErr := p.commandHandler.ExecuteWithInput(ctx, command.ExecuteInput{ + BotID: strings.TrimSpace(identity.BotID), + ChannelIdentityID: strings.TrimSpace(identity.ChannelIdentityID), + UserID: strings.TrimSpace(identity.UserID), + Text: rawTextForCommand(msg, strings.TrimSpace(msg.Message.PlainText())), + ChannelType: msg.Channel.String(), + ConversationType: strings.TrimSpace(msg.Conversation.Type), + ConversationID: strings.TrimSpace(msg.Conversation.ID), + ThreadID: threadID, + RouteID: strings.TrimSpace(resolved.RouteID), + SessionID: sessionID, + }) + if execErr != nil { + reply = "Error: " + execErr.Error() + } + + return sender.Send(ctx, channel.OutboundMessage{ + Target: target, + Message: channel.Message{Text: reply}, + }) +} diff --git a/internal/channel/inbound/channel_test.go b/internal/channel/inbound/channel_test.go index fe00e9fc..9cbb1ae5 100644 --- a/internal/channel/inbound/channel_test.go +++ b/internal/channel/inbound/channel_test.go @@ -13,11 +13,15 @@ import ( "strings" "testing" + "github.com/jackc/pgx/v5/pgtype" + "github.com/memohai/memoh/internal/acl" "github.com/memohai/memoh/internal/channel" "github.com/memohai/memoh/internal/channel/identities" "github.com/memohai/memoh/internal/channel/route" + "github.com/memohai/memoh/internal/command" "github.com/memohai/memoh/internal/conversation" + dbsqlc "github.com/memohai/memoh/internal/db/sqlc" "github.com/memohai/memoh/internal/media" messagepkg "github.com/memohai/memoh/internal/message" "github.com/memohai/memoh/internal/schedule" @@ -185,6 +189,71 @@ type fakeChatACL struct { lastReq acl.EvaluateRequest } +type fakeSessionEnsurer struct { + activeSession SessionResult + activeErr error + lastRouteID string +} + +func (f *fakeSessionEnsurer) EnsureActiveSession(_ context.Context, _, routeID, _ string) (SessionResult, error) { + f.lastRouteID = routeID + if f.activeErr != nil { + return SessionResult{}, f.activeErr + } + return f.activeSession, nil +} + +func (f *fakeSessionEnsurer) GetActiveSession(_ context.Context, routeID string) (SessionResult, error) { + f.lastRouteID = routeID + if f.activeErr != nil { + return SessionResult{}, f.activeErr + } + return f.activeSession, nil +} + +func (f *fakeSessionEnsurer) CreateNewSession(_ context.Context, _, routeID, _, _ string) (SessionResult, error) { + f.lastRouteID = routeID + if f.activeErr != nil { + return SessionResult{}, f.activeErr + } + return f.activeSession, nil +} + +type fakeCommandQueries struct { + messageCount int64 + usage int64 + cacheRow dbsqlc.GetSessionCacheStatsRow + skills []string +} + +func (*fakeCommandQueries) GetLatestSessionIDByBot(_ context.Context, _ pgtype.UUID) (pgtype.UUID, error) { + return pgtype.UUID{}, errors.New("unexpected latest session lookup") +} + +func (f *fakeCommandQueries) CountMessagesBySession(_ context.Context, _ pgtype.UUID) (int64, error) { + return f.messageCount, nil +} + +func (f *fakeCommandQueries) GetLatestAssistantUsage(_ context.Context, _ pgtype.UUID) (int64, error) { + return f.usage, nil +} + +func (f *fakeCommandQueries) GetSessionCacheStats(_ context.Context, _ pgtype.UUID) (dbsqlc.GetSessionCacheStatsRow, error) { + return f.cacheRow, nil +} + +func (f *fakeCommandQueries) GetSessionUsedSkills(_ context.Context, _ pgtype.UUID) ([]string, error) { + return f.skills, nil +} + +func (*fakeCommandQueries) GetTokenUsageByDayAndType(_ context.Context, _ dbsqlc.GetTokenUsageByDayAndTypeParams) ([]dbsqlc.GetTokenUsageByDayAndTypeRow, error) { + return nil, nil +} + +func (*fakeCommandQueries) GetTokenUsageByModel(_ context.Context, _ dbsqlc.GetTokenUsageByModelParams) ([]dbsqlc.GetTokenUsageByModelRow, error) { + return nil, nil +} + func (f *fakeChatACL) Evaluate(_ context.Context, req acl.EvaluateRequest) (bool, error) { f.calls++ f.lastReq = req @@ -548,6 +617,72 @@ func TestChannelInboundProcessorIgnoreEmpty(t *testing.T) { } } +func TestChannelInboundProcessorStatusUsesRouteSession(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-status"}} + policySvc := &fakePolicyService{} + chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-status", RouteID: "route-status"}} + gateway := &fakeChatGateway{} + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, policySvc, nil, "", 0) + processor.SetSessionEnsurer(&fakeSessionEnsurer{ + activeSession: SessionResult{ID: "11111111-1111-1111-1111-111111111111", Type: "chat"}, + }) + processor.SetCommandHandler(command.NewHandler( + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + &fakeCommandQueries{ + messageCount: 9, + usage: 512, + cacheRow: dbsqlc.GetSessionCacheStatsRow{ + CacheReadTokens: 64, + CacheWriteTokens: 32, + TotalInputTokens: 512, + }, + skills: []string{"search"}, + }, + nil, + nil, + nil, + )) + sender := &fakeReplySender{} + + cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("discord")} + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("discord"), + Message: channel.Message{Text: "/status"}, + ReplyTarget: "discord:status", + Sender: channel.Identity{SubjectID: "user-1"}, + Conversation: channel.Conversation{ + ID: "conv-status", + Type: channel.ConversationTypePrivate, + }, + } + + if err := processor.HandleInbound(context.Background(), cfg, msg, sender); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(sender.sent) != 1 { + t.Fatalf("expected one status reply, got %d", len(sender.sent)) + } + if !strings.Contains(sender.sent[0].Message.Text, "- Scope: current conversation") { + t.Fatalf("expected current conversation scope, got %q", sender.sent[0].Message.Text) + } + if !strings.Contains(sender.sent[0].Message.Text, "- Session ID: 11111111-1111-1111-1111-111111111111") { + t.Fatalf("expected active route session in reply, got %q", sender.sent[0].Message.Text) + } +} + func TestBuildInboundQueryAttachmentOnlyReturnsEmpty(t *testing.T) { t.Parallel() diff --git a/internal/command/access.go b/internal/command/access.go new file mode 100644 index 00000000..5dbd764c --- /dev/null +++ b/internal/command/access.go @@ -0,0 +1,78 @@ +package command + +import ( + "fmt" + "strings" + + "github.com/memohai/memoh/internal/acl" +) + +func (h *Handler) buildAccessGroup() *CommandGroup { + g := newCommandGroup("access", "Inspect identity and permission context") + g.DefaultAction = "show" + g.Register(SubCommand{ + Name: "show", + Usage: "show - Show current identity, write access, and chat ACL context", + Handler: func(cc CommandContext) (string, error) { + writeAccess := "no" + if cc.Role == "owner" { + writeAccess = "yes" + } + + pairs := []kv{ + {"Channel Identity", fallbackValue(cc.ChannelIdentityID)}, + {"Linked User", fallbackValue(cc.UserID)}, + {"Bot Role", fallbackValue(cc.Role)}, + {"Write Commands", writeAccess}, + {"Channel", fallbackValue(cc.ChannelType)}, + {"Conversation Type", fallbackValue(cc.ConversationType)}, + {"Conversation ID", fallbackValue(cc.ConversationID)}, + {"Thread ID", fallbackValue(cc.ThreadID)}, + } + if strings.TrimSpace(cc.RouteID) != "" { + pairs = append(pairs, kv{"Route ID", cc.RouteID}) + } + if strings.TrimSpace(cc.SessionID) != "" { + pairs = append(pairs, kv{"Session ID", cc.SessionID}) + } + + aclStatus := "unavailable" + if h.aclEvaluator != nil && strings.TrimSpace(cc.ChannelType) != "" { + allowed, err := h.aclEvaluator.Evaluate(cc.Ctx, acl.EvaluateRequest{ + BotID: cc.BotID, + ChannelIdentityID: cc.ChannelIdentityID, + ChannelType: cc.ChannelType, + SourceScope: acl.SourceScope{ + ConversationType: cc.ConversationType, + ConversationID: cc.ConversationID, + ThreadID: cc.ThreadID, + }, + }) + switch { + case err != nil: + aclStatus = "error: " + err.Error() + case allowed: + aclStatus = "allow" + default: + aclStatus = "deny" + } + } + pairs = append(pairs, kv{"Chat ACL", aclStatus}) + + return formatKV(pairs), nil + }, + }) + return g +} + +func fallbackValue(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "(none)" + } + return value +} + +func formatChangedValue(label, before, after string) string { + return fmt.Sprintf("%s: %s -> %s", label, fallbackValue(before), fallbackValue(after)) +} diff --git a/internal/command/browser.go b/internal/command/browser.go index d5de3a4d..bc4b77ac 100644 --- a/internal/command/browser.go +++ b/internal/command/browser.go @@ -13,6 +13,9 @@ func (h *Handler) buildBrowserGroup() *CommandGroup { Name: "list", Usage: "list - List all browser contexts", Handler: func(cc CommandContext) (string, error) { + if h.browserCtxService == nil { + return "Browser context service is not available.", nil + } items, err := h.browserCtxService.List(cc.Ctx) if err != nil { return "", err @@ -20,13 +23,37 @@ func (h *Handler) buildBrowserGroup() *CommandGroup { if len(items) == 0 { return "No browser contexts found.", nil } - records := make([][]kv, 0, len(items)) + settingsResp, _ := h.getBotSettings(cc) + currentRecords := make([][]kv, 0, 1) + otherRecords := make([][]kv, 0, len(items)) for _, item := range items { - records = append(records, []kv{ - {"Name", item.Name}, - }) + label := item.Name + record := []kv{{"Name", label}} + if item.ID == settingsResp.BrowserContextID { + label += " [current]" + record[0].value = label + currentRecords = append(currentRecords, record) + continue + } + otherRecords = append(otherRecords, record) } - return formatItems(records), nil + currentRecords = append(currentRecords, otherRecords...) + records := currentRecords + return formatLimitedItems(records, defaultListLimit, "Use /browser current to inspect the active context."), nil + }, + }) + g.Register(SubCommand{ + Name: "current", + Usage: "current - Show the current browser context", + Handler: func(cc CommandContext) (string, error) { + if h.settingsService == nil { + return "Settings service is not available.", nil + } + settingsResp, err := h.getBotSettings(cc) + if err != nil { + return "", err + } + return formatKV([]kv{{"Browser Context", h.resolveBrowserContextName(cc, settingsResp.BrowserContextID)}}), nil }, }) g.Register(SubCommand{ @@ -37,7 +64,11 @@ func (h *Handler) buildBrowserGroup() *CommandGroup { if len(cc.Args) < 1 { return "Usage: /browser set ", nil } + if h.settingsService == nil { + return "Settings service is not available.", nil + } name := cc.Args[0] + before, _ := h.getBotSettings(cc) items, err := h.browserCtxService.List(cc.Ctx) if err != nil { return "", err @@ -50,7 +81,7 @@ func (h *Handler) buildBrowserGroup() *CommandGroup { if err != nil { return "", err } - return fmt.Sprintf("Browser context set to %q.", item.Name), nil + return formatChangedValue("Browser context", h.resolveBrowserContextName(cc, before.BrowserContextID), item.Name), nil } } return fmt.Sprintf("Browser context %q not found.", name), nil diff --git a/internal/command/commands.go b/internal/command/commands.go index b40b3513..919f968a 100644 --- a/internal/command/commands.go +++ b/internal/command/commands.go @@ -16,5 +16,6 @@ func (h *Handler) buildRegistry() *Registry { r.RegisterGroup(h.buildSkillGroup()) r.RegisterGroup(h.buildFSGroup()) r.RegisterGroup(h.buildStatusGroup()) + r.RegisterGroup(h.buildAccessGroup()) return r } diff --git a/internal/command/email_cmd.go b/internal/command/email_cmd.go index 951f70c2..43fddbba 100644 --- a/internal/command/email_cmd.go +++ b/internal/command/email_cmd.go @@ -24,7 +24,7 @@ func (h *Handler) buildEmailGroup() *CommandGroup { {"Provider", item.Provider}, }) } - return formatItems(records), nil + return formatLimitedItems(records, defaultListLimit, "Use /email bindings to inspect bot bindings."), nil }, }) g.Register(SubCommand{ @@ -46,7 +46,7 @@ func (h *Handler) buildEmailGroup() *CommandGroup { {"Permissions", perms}, }) } - return formatItems(records), nil + return formatLimitedItems(records, defaultListLimit, "Use /email outbox to inspect recent sends."), nil }, }) g.Register(SubCommand{ @@ -70,7 +70,7 @@ func (h *Handler) buildEmailGroup() *CommandGroup { {"Sent", item.SentAt.Format("01-02 15:04")}, }) } - return formatItems(records), nil + return formatLimitedItems(records, 10, "Use the Web UI for older outbox entries."), nil }, }) return g diff --git a/internal/command/formatter.go b/internal/command/formatter.go index 0b03996f..1b626ce9 100644 --- a/internal/command/formatter.go +++ b/internal/command/formatter.go @@ -6,19 +6,10 @@ import ( "unicode/utf8" ) +const defaultListLimit = 12 + // formatItems renders a list of records as a Markdown-style list. -// Each record is a slice of kv pairs; the first pair's value is used as the -// bullet title, and subsequent pairs are indented beneath it. -// -// Example output: -// -// - mybot -// Description: A helpful assistant -// ID: abc123 -// -// - another -// Description: Something else -// ID: def456 +// Each record is rendered on a single line so long lists stay readable in IM. func formatItems(items [][]kv) string { if len(items) == 0 { return "" @@ -31,14 +22,41 @@ func formatItems(items [][]kv) string { if i > 0 { b.WriteByte('\n') } - fmt.Fprintf(&b, "- %s\n", record[0].value) + fmt.Fprintf(&b, "- %s", record[0].value) + extras := make([]string, 0, len(record)-1) for _, pair := range record[1:] { - fmt.Fprintf(&b, " %s: %s\n", pair.key, pair.value) + if strings.TrimSpace(pair.value) == "" { + continue + } + extras = append(extras, fmt.Sprintf("%s: %s", pair.key, pair.value)) + } + if len(extras) > 0 { + fmt.Fprintf(&b, " | %s", strings.Join(extras, " | ")) } } return b.String() } +func formatLimitedItems(items [][]kv, limit int, hint string) string { + if len(items) == 0 { + return "" + } + if limit <= 0 { + limit = defaultListLimit + } + total := len(items) + if total <= limit { + return formatItems(items) + } + shown := items[:limit] + result := formatItems(shown) + suffix := fmt.Sprintf("Showing %d of %d items.", len(shown), total) + if strings.TrimSpace(hint) != "" { + suffix += " " + strings.TrimSpace(hint) + } + return result + "\n\n" + suffix +} + // formatKV renders key-value pairs as a simple Markdown list. // // Example output: diff --git a/internal/command/handler.go b/internal/command/handler.go index ecc510be..810d2ab1 100644 --- a/internal/command/handler.go +++ b/internal/command/handler.go @@ -4,10 +4,10 @@ import ( "context" "fmt" "log/slog" + "strings" "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/browsercontexts" - dbsqlc "github.com/memohai/memoh/internal/db/sqlc" emailpkg "github.com/memohai/memoh/internal/email" "github.com/memohai/memoh/internal/heartbeat" "github.com/memohai/memoh/internal/mcp" @@ -56,13 +56,28 @@ type Handler struct { emailService *emailpkg.Service emailOutboxService *emailpkg.OutboxService heartbeatService *heartbeat.Service - queries *dbsqlc.Queries + queries CommandQueries + aclEvaluator AccessEvaluator skillLoader SkillLoader containerFS ContainerFS logger *slog.Logger } +// ExecuteInput carries the caller identity and channel context for command execution. +type ExecuteInput struct { + BotID string + ChannelIdentityID string + UserID string + Text string + ChannelType string + ConversationType string + ConversationID string + ThreadID string + RouteID string + SessionID string +} + // NewHandler creates a Handler with all required services. func NewHandler( log *slog.Logger, @@ -78,7 +93,8 @@ func NewHandler( emailService *emailpkg.Service, emailOutboxService *emailpkg.OutboxService, heartbeatService *heartbeat.Service, - queries *dbsqlc.Queries, + queries CommandQueries, + aclEvaluator AccessEvaluator, skillLoader SkillLoader, containerFS ContainerFS, ) *Handler { @@ -99,6 +115,7 @@ func NewHandler( emailOutboxService: emailOutboxService, heartbeatService: heartbeatService, queries: queries, + aclEvaluator: aclEvaluator, skillLoader: skillLoader, containerFS: containerFS, logger: log.With(slog.String("component", "command")), @@ -140,7 +157,16 @@ func (h *Handler) IsCommand(text string) bool { // Execute parses and runs a slash command, returning the text reply. func (h *Handler) Execute(ctx context.Context, botID, channelIdentityID, text string) (string, error) { - cmdText := ExtractCommandText(text) + return h.ExecuteWithInput(ctx, ExecuteInput{ + BotID: botID, + ChannelIdentityID: channelIdentityID, + Text: text, + }) +} + +// ExecuteWithInput parses and runs a slash command with channel/session context. +func (h *Handler) ExecuteWithInput(ctx context.Context, input ExecuteInput) (string, error) { + cmdText := ExtractCommandText(input.Text) if cmdText == "" { return h.registry.GlobalHelp(), nil } @@ -151,12 +177,16 @@ func (h *Handler) Execute(ctx context.Context, botID, channelIdentityID, text st // Resolve the user's role in this bot. role := "" - if h.roleResolver != nil && channelIdentityID != "" { - r, err := h.roleResolver.GetMemberRole(ctx, botID, channelIdentityID) + roleIdentityID := input.ChannelIdentityID + if strings.TrimSpace(input.UserID) != "" { + roleIdentityID = strings.TrimSpace(input.UserID) + } + if h.roleResolver != nil && roleIdentityID != "" { + r, err := h.roleResolver.GetMemberRole(ctx, input.BotID, roleIdentityID) if err != nil { h.logger.Warn("failed to resolve member role", - slog.String("bot_id", botID), - slog.String("channel_identity_id", channelIdentityID), + slog.String("bot_id", input.BotID), + slog.String("role_identity_id", roleIdentityID), slog.Any("error", err), ) } else { @@ -165,15 +195,29 @@ func (h *Handler) Execute(ctx context.Context, botID, channelIdentityID, text st } cc := CommandContext{ - Ctx: ctx, - BotID: botID, - Role: role, - Args: parsed.Args, + Ctx: ctx, + BotID: input.BotID, + Role: role, + Args: parsed.Args, + ChannelIdentityID: strings.TrimSpace(input.ChannelIdentityID), + UserID: strings.TrimSpace(input.UserID), + ChannelType: strings.TrimSpace(input.ChannelType), + ConversationType: strings.TrimSpace(input.ConversationType), + ConversationID: strings.TrimSpace(input.ConversationID), + ThreadID: strings.TrimSpace(input.ThreadID), + RouteID: strings.TrimSpace(input.RouteID), + SessionID: strings.TrimSpace(input.SessionID), } // /help if parsed.Resource == "help" { - return h.registry.GlobalHelp(), nil + if parsed.Action == "" { + return h.registry.GlobalHelp(), nil + } + if len(parsed.Args) == 0 { + return h.registry.GroupHelp(parsed.Action), nil + } + return h.registry.ActionHelp(parsed.Action, parsed.Args[0]), nil } // Top-level commands (e.g. /new) are handled by the channel inbound diff --git a/internal/command/handler_test.go b/internal/command/handler_test.go index f8205c4f..8cb79c73 100644 --- a/internal/command/handler_test.go +++ b/internal/command/handler_test.go @@ -5,6 +5,10 @@ import ( "strings" "testing" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + + dbsqlc "github.com/memohai/memoh/internal/db/sqlc" "github.com/memohai/memoh/internal/mcp" "github.com/memohai/memoh/internal/schedule" "github.com/memohai/memoh/internal/settings" @@ -25,9 +29,58 @@ type fakeScheduleService struct { items []schedule.Schedule } +type fakeCommandQueries struct { + latestSessionID pgtype.UUID + latestSessionErr error + messageCount int64 + latestUsage int64 + latestUsageErr error + cacheRow dbsqlc.GetSessionCacheStatsRow + cacheErr error + skills []string +} + +func (f *fakeCommandQueries) GetLatestSessionIDByBot(_ context.Context, _ pgtype.UUID) (pgtype.UUID, error) { + return f.latestSessionID, f.latestSessionErr +} + +func (f *fakeCommandQueries) CountMessagesBySession(_ context.Context, _ pgtype.UUID) (int64, error) { + return f.messageCount, nil +} + +func (f *fakeCommandQueries) GetLatestAssistantUsage(_ context.Context, _ pgtype.UUID) (int64, error) { + if f.latestUsageErr != nil { + return 0, f.latestUsageErr + } + return f.latestUsage, nil +} + +func (f *fakeCommandQueries) GetSessionCacheStats(_ context.Context, _ pgtype.UUID) (dbsqlc.GetSessionCacheStatsRow, error) { + if f.cacheErr != nil { + return dbsqlc.GetSessionCacheStatsRow{}, f.cacheErr + } + return f.cacheRow, nil +} + +func (f *fakeCommandQueries) GetSessionUsedSkills(_ context.Context, _ pgtype.UUID) ([]string, error) { + return f.skills, nil +} + +func (*fakeCommandQueries) GetTokenUsageByDayAndType(_ context.Context, _ dbsqlc.GetTokenUsageByDayAndTypeParams) ([]dbsqlc.GetTokenUsageByDayAndTypeRow, error) { + return nil, nil +} + +func (*fakeCommandQueries) GetTokenUsageByModel(_ context.Context, _ dbsqlc.GetTokenUsageByModelParams) ([]dbsqlc.GetTokenUsageByModelRow, error) { + return nil, nil +} + // newTestHandler creates a Handler with nil services for use in tests. func newTestHandler(roleResolver MemberRoleResolver) *Handler { - return NewHandler(nil, roleResolver, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + return NewHandler(nil, roleResolver, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) +} + +func newTestHandlerWithQueries(roleResolver MemberRoleResolver, queries CommandQueries) *Handler { + return NewHandler(nil, roleResolver, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, queries, nil, nil, nil) } // --- tests --- @@ -73,6 +126,42 @@ func TestExecute_Help(t *testing.T) { if !strings.Contains(result, "Available commands") { t.Errorf("expected help text, got: %s", result) } + if strings.Contains(result, "set-heartbeat") { + t.Errorf("top-level help should not expand nested actions, got: %s", result) + } + if !strings.Contains(result, "- /model - Manage bot models") { + t.Errorf("expected top-level model entry, got: %s", result) + } +} + +func TestExecute_HelpGroup(t *testing.T) { + t.Parallel() + h := newTestHandler(&fakeRoleResolver{role: "owner"}) + result, err := h.Execute(context.Background(), "bot-1", "user-1", "/help model") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(result, "/model - Manage bot models") { + t.Errorf("expected group help, got: %s", result) + } + if !strings.Contains(result, "- set - Set the chat model [owner]") { + t.Errorf("expected compact action summary, got: %s", result) + } +} + +func TestExecute_HelpAction(t *testing.T) { + t.Parallel() + h := newTestHandler(&fakeRoleResolver{role: "owner"}) + result, err := h.Execute(context.Background(), "bot-1", "user-1", "/help model set") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(result, "Usage: /model set | ") { + t.Errorf("expected action usage, got: %s", result) + } + if !strings.Contains(result, "Access: owner only") { + t.Errorf("expected owner hint, got: %s", result) + } } func TestExecute_UnknownCommand(t *testing.T) { @@ -207,8 +296,8 @@ func TestFormatItems(t *testing.T) { if !strings.Contains(result, "- foo") { t.Errorf("expected '- foo' bullet, got: %s", result) } - if !strings.Contains(result, " Type: bar") { - t.Errorf("expected indented 'Type: bar', got: %s", result) + if !strings.Contains(result, "- foo | Type: bar") { + t.Errorf("expected compact line entry, got: %s", result) } if !strings.Contains(result, "- longname") { t.Errorf("expected '- longname' bullet, got: %s", result) @@ -255,7 +344,7 @@ func TestGlobalHelp_AllGroups(t *testing.T) { for _, group := range []string{ "schedule", "mcp", "settings", "model", "memory", "search", "browser", "usage", - "email", "heartbeat", "skill", "fs", + "email", "heartbeat", "skill", "fs", "access", } { if !strings.Contains(help, "/"+group) { t.Errorf("missing /%s in global help", group) @@ -263,6 +352,86 @@ func TestGlobalHelp_AllGroups(t *testing.T) { } } +func TestExecuteWithInput_Access(t *testing.T) { + t.Parallel() + h := newTestHandler(&fakeRoleResolver{role: "owner"}) + result, err := h.ExecuteWithInput(context.Background(), ExecuteInput{ + BotID: "bot-1", + ChannelIdentityID: "channel-id-1", + UserID: "user-id-1", + Text: "/access", + ChannelType: "discord", + ConversationType: "thread", + ConversationID: "conv-1", + ThreadID: "thread-1", + RouteID: "route-1", + SessionID: "session-1", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(result, "- Channel Identity: channel-id-1") { + t.Errorf("expected channel identity in access output, got: %s", result) + } + if !strings.Contains(result, "- Write Commands: yes") { + t.Errorf("expected write access in access output, got: %s", result) + } +} + +func TestExecute_StatusLatest(t *testing.T) { + t.Parallel() + sessionUUID := pgtype.UUID{} + copy(sessionUUID.Bytes[:], []byte{1, 2, 3, 4, 5, 6, 7, 8, 9}) + sessionUUID.Valid = true + h := newTestHandlerWithQueries(&fakeRoleResolver{role: "owner"}, &fakeCommandQueries{ + latestSessionID: sessionUUID, + messageCount: 42, + latestUsage: 1200, + cacheRow: dbsqlc.GetSessionCacheStatsRow{ + CacheReadTokens: 300, + CacheWriteTokens: 150, + TotalInputTokens: 1200, + }, + skills: []string{"search", "browser"}, + }) + result, err := h.Execute(context.Background(), "11111111-1111-1111-1111-111111111111", "user-1", "/status latest") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(result, "- Scope: latest bot session") { + t.Errorf("expected latest scope, got: %s", result) + } + if !strings.Contains(result, "- Messages: 42") { + t.Errorf("expected message count, got: %s", result) + } +} + +func TestExecute_StatusLatestNoRows(t *testing.T) { + t.Parallel() + h := newTestHandlerWithQueries(&fakeRoleResolver{role: "owner"}, &fakeCommandQueries{ + latestSessionErr: pgx.ErrNoRows, + }) + result, err := h.Execute(context.Background(), "11111111-1111-1111-1111-111111111111", "user-1", "/status latest") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(result, "No session found for this bot.") { + t.Errorf("expected no session message, got: %s", result) + } +} + +func TestExecute_StatusShowWithoutSession(t *testing.T) { + t.Parallel() + h := newTestHandlerWithQueries(&fakeRoleResolver{role: "owner"}, &fakeCommandQueries{}) + result, err := h.Execute(context.Background(), "bot-1", "user-1", "/status") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(result, "No active session found for this conversation.") { + t.Errorf("expected route-aware no session message, got: %s", result) + } +} + // Verify write commands are tagged with [owner] in usage. func TestUsage_OwnerTag(t *testing.T) { t.Parallel() diff --git a/internal/command/heartbeat_cmd.go b/internal/command/heartbeat_cmd.go index e57cf7b7..c49ce2bd 100644 --- a/internal/command/heartbeat_cmd.go +++ b/internal/command/heartbeat_cmd.go @@ -40,7 +40,7 @@ func (h *Handler) buildHeartbeatGroup() *CommandGroup { } records = append(records, rec) } - return formatItems(records), nil + return formatLimitedItems(records, 10, "Use the Web UI for older heartbeat logs."), nil }, }) return g diff --git a/internal/command/interfaces.go b/internal/command/interfaces.go index 3f3db905..bb2e0395 100644 --- a/internal/command/interfaces.go +++ b/internal/command/interfaces.go @@ -1,6 +1,13 @@ package command -import "context" +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" + + "github.com/memohai/memoh/internal/acl" + dbsqlc "github.com/memohai/memoh/internal/db/sqlc" +) // Skill represents a single skill loaded from a bot's container. type Skill struct { @@ -25,3 +32,20 @@ type ContainerFS interface { ListDir(ctx context.Context, botID, path string) ([]FSEntry, error) ReadFile(ctx context.Context, botID, path string) (string, error) } + +// CommandQueries captures the sqlc methods used by slash commands. +// *dbsqlc.Queries satisfies this interface directly. +type CommandQueries interface { + GetLatestSessionIDByBot(ctx context.Context, botID pgtype.UUID) (pgtype.UUID, error) + CountMessagesBySession(ctx context.Context, sessionID pgtype.UUID) (int64, error) + GetLatestAssistantUsage(ctx context.Context, sessionID pgtype.UUID) (int64, error) + GetSessionCacheStats(ctx context.Context, sessionID pgtype.UUID) (dbsqlc.GetSessionCacheStatsRow, error) + GetSessionUsedSkills(ctx context.Context, sessionID pgtype.UUID) ([]string, error) + GetTokenUsageByDayAndType(ctx context.Context, arg dbsqlc.GetTokenUsageByDayAndTypeParams) ([]dbsqlc.GetTokenUsageByDayAndTypeRow, error) + GetTokenUsageByModel(ctx context.Context, arg dbsqlc.GetTokenUsageByModelParams) ([]dbsqlc.GetTokenUsageByModelRow, error) +} + +// AccessEvaluator checks whether the current channel context may trigger chat. +type AccessEvaluator interface { + Evaluate(ctx context.Context, req acl.EvaluateRequest) (bool, error) +} diff --git a/internal/command/mcp.go b/internal/command/mcp.go index dad32c40..40d56a03 100644 --- a/internal/command/mcp.go +++ b/internal/command/mcp.go @@ -27,7 +27,7 @@ func (h *Handler) buildMCPGroup() *CommandGroup { {"Status", item.Status}, }) } - return formatItems(records), nil + return formatLimitedItems(records, defaultListLimit, "Use /mcp get for full details."), nil }, }) g.Register(SubCommand{ diff --git a/internal/command/memory.go b/internal/command/memory.go index fce3d0a1..4965b685 100644 --- a/internal/command/memory.go +++ b/internal/command/memory.go @@ -13,6 +13,9 @@ func (h *Handler) buildMemoryGroup() *CommandGroup { Name: "list", Usage: "list - List all memory providers", Handler: func(cc CommandContext) (string, error) { + if h.memProvService == nil { + return "Memory provider service is not available.", nil + } items, err := h.memProvService.List(cc.Ctx) if err != nil { return "", err @@ -20,18 +23,44 @@ func (h *Handler) buildMemoryGroup() *CommandGroup { if len(items) == 0 { return "No memory providers found.", nil } - records := make([][]kv, 0, len(items)) + settingsResp, _ := h.getBotSettings(cc) + currentRecords := make([][]kv, 0, 1) + otherRecords := make([][]kv, 0, len(items)) for _, item := range items { def := "" if item.IsDefault { def = " (default)" } - records = append(records, []kv{ - {"Name", item.Name + def}, + label := item.Name + def + record := []kv{ + {"Name", label}, {"Provider", item.Provider}, - }) + } + if item.ID == settingsResp.MemoryProviderID { + label += " [current]" + record[0].value = label + currentRecords = append(currentRecords, record) + continue + } + otherRecords = append(otherRecords, record) } - return formatItems(records), nil + currentRecords = append(currentRecords, otherRecords...) + records := currentRecords + return formatLimitedItems(records, defaultListLimit, "Use /memory current to inspect the active provider."), nil + }, + }) + g.Register(SubCommand{ + Name: "current", + Usage: "current - Show the current memory provider", + Handler: func(cc CommandContext) (string, error) { + if h.settingsService == nil { + return "Settings service is not available.", nil + } + settingsResp, err := h.getBotSettings(cc) + if err != nil { + return "", err + } + return formatKV([]kv{{"Memory Provider", h.resolveMemoryProviderName(cc, settingsResp.MemoryProviderID)}}), nil }, }) g.Register(SubCommand{ @@ -42,7 +71,11 @@ func (h *Handler) buildMemoryGroup() *CommandGroup { if len(cc.Args) < 1 { return "Usage: /memory set ", nil } + if h.settingsService == nil { + return "Settings service is not available.", nil + } name := cc.Args[0] + before, _ := h.getBotSettings(cc) items, err := h.memProvService.List(cc.Ctx) if err != nil { return "", err @@ -55,7 +88,7 @@ func (h *Handler) buildMemoryGroup() *CommandGroup { if err != nil { return "", err } - return fmt.Sprintf("Memory provider set to %q.", item.Name), nil + return formatChangedValue("Memory provider", h.resolveMemoryProviderName(cc, before.MemoryProviderID), item.Name), nil } } return fmt.Sprintf("Memory provider %q not found.", name), nil diff --git a/internal/command/model.go b/internal/command/model.go index 53bcc644..856ed721 100644 --- a/internal/command/model.go +++ b/internal/command/model.go @@ -1,7 +1,9 @@ package command import ( + "errors" "fmt" + "sort" "strings" "github.com/memohai/memoh/internal/models" @@ -12,36 +14,81 @@ func (h *Handler) buildModelGroup() *CommandGroup { g := newCommandGroup("model", "Manage bot models") g.Register(SubCommand{ Name: "list", - Usage: "list - List all available chat models", + Usage: "list [provider_name] - List available chat models", Handler: func(cc CommandContext) (string, error) { + if h.modelsService == nil { + return "Model service is not available.", nil + } items, err := h.modelsService.ListByType(cc.Ctx, models.ModelTypeChat) if err != nil { return "", err } + filterProvider := "" + if len(cc.Args) > 0 { + filterProvider = strings.TrimSpace(strings.Join(cc.Args, " ")) + } + items = h.filterModelsByProvider(cc, items, filterProvider) if len(items) == 0 { + if filterProvider != "" { + return fmt.Sprintf("No chat models found for provider %q.", filterProvider), nil + } return "No chat models found.", nil } + settingsResp, _ := h.getBotSettings(cc) + sort.SliceStable(items, func(i, j int) bool { + return modelSortRank(items[i], settingsResp) < modelSortRank(items[j], settingsResp) + }) records := make([][]kv, 0, len(items)) for _, item := range items { provName := h.resolveProviderName(cc, item.ProviderID) + label := item.Name + markers := modelMarkers(item.ID, settingsResp) + if len(markers) > 0 { + label += " [" + strings.Join(markers, ", ") + "]" + } records = append(records, []kv{ - {"Model", item.Name}, + {"Model", label}, {"Provider", provName}, {"Model ID", item.ModelID}, }) } - return formatItems(records), nil + hint := "Use /model current to inspect active selections." + if filterProvider == "" { + hint = "Use /model list to narrow results." + } + return formatLimitedItems(records, defaultListLimit, hint), nil + }, + }) + g.Register(SubCommand{ + Name: "current", + Usage: "current - Show current chat and heartbeat models", + Handler: func(cc CommandContext) (string, error) { + if h.settingsService == nil { + return "Settings service is not available.", nil + } + settingsResp, err := h.getBotSettings(cc) + if err != nil { + return "", err + } + return formatKV([]kv{ + {"Chat Model", h.resolveModelName(cc, settingsResp.ChatModelID)}, + {"Heartbeat Model", h.resolveModelName(cc, settingsResp.HeartbeatModelID)}, + }), nil }, }) g.Register(SubCommand{ Name: "set", - Usage: "set - Set the chat model", + Usage: "set | - Set the chat model", IsWrite: true, Handler: func(cc CommandContext) (string, error) { - if len(cc.Args) < 2 { - return "Usage: /model set ", nil + if len(cc.Args) < 1 { + return "Usage: /model set | ", nil } - modelResp, err := h.findModelByProviderAndName(cc, cc.Args[0], cc.Args[1]) + if h.settingsService == nil { + return "Settings service is not available.", nil + } + before, _ := h.getBotSettings(cc) + modelResp, err := h.findModelForSelection(cc, cc.Args) if err != nil { return "", err } @@ -51,18 +98,22 @@ func (h *Handler) buildModelGroup() *CommandGroup { if err != nil { return "", err } - return fmt.Sprintf("Chat model set to %s (%s).", modelResp.Name, cc.Args[0]), nil + return formatChangedValue("Chat model", h.resolveModelName(cc, before.ChatModelID), h.resolveModelName(cc, modelResp.ID)), nil }, }) g.Register(SubCommand{ Name: "set-heartbeat", - Usage: "set-heartbeat - Set the heartbeat model", + Usage: "set-heartbeat | - Set the heartbeat model", IsWrite: true, Handler: func(cc CommandContext) (string, error) { - if len(cc.Args) < 2 { - return "Usage: /model set-heartbeat ", nil + if len(cc.Args) < 1 { + return "Usage: /model set-heartbeat | ", nil } - modelResp, err := h.findModelByProviderAndName(cc, cc.Args[0], cc.Args[1]) + if h.settingsService == nil { + return "Settings service is not available.", nil + } + before, _ := h.getBotSettings(cc) + modelResp, err := h.findModelForSelection(cc, cc.Args) if err != nil { return "", err } @@ -72,7 +123,7 @@ func (h *Handler) buildModelGroup() *CommandGroup { if err != nil { return "", err } - return fmt.Sprintf("Heartbeat model set to %s (%s).", modelResp.Name, cc.Args[0]), nil + return formatChangedValue("Heartbeat model", h.resolveModelName(cc, before.HeartbeatModelID), h.resolveModelName(cc, modelResp.ID)), nil }, }) return g @@ -105,3 +156,86 @@ func (h *Handler) findModelByProviderAndName(cc CommandContext, providerName, mo } return models.GetResponse{}, fmt.Errorf("model %q not found under provider %q", modelName, providerName) } + +func (h *Handler) findModelForSelection(cc CommandContext, args []string) (models.GetResponse, error) { + if h.modelsService == nil { + return models.GetResponse{}, errors.New("model service is not available") + } + if len(args) == 0 { + return models.GetResponse{}, errors.New("model identifier is required") + } + if len(args) == 1 { + return h.findModelByIDOrName(cc, args[0]) + } + return h.findModelByProviderAndName(cc, args[0], strings.Join(args[1:], " ")) +} + +func (h *Handler) findModelByIDOrName(cc CommandContext, target string) (models.GetResponse, error) { + items, err := h.modelsService.ListByType(cc.Ctx, models.ModelTypeChat) + if err != nil { + return models.GetResponse{}, err + } + target = strings.TrimSpace(target) + if target == "" { + return models.GetResponse{}, errors.New("model identifier is required") + } + for _, item := range items { + if strings.EqualFold(item.ModelID, target) { + return item, nil + } + } + matches := make([]models.GetResponse, 0, 4) + for _, item := range items { + if strings.EqualFold(item.Name, target) { + matches = append(matches, item) + } + } + switch len(matches) { + case 0: + return models.GetResponse{}, fmt.Errorf("model %q not found", target) + case 1: + return matches[0], nil + default: + choices := make([]string, 0, len(matches)) + for _, item := range matches { + choices = append(choices, fmt.Sprintf("%s/%s", h.resolveProviderName(cc, item.ProviderID), item.ModelID)) + } + return models.GetResponse{}, fmt.Errorf("model %q is ambiguous; use a model ID or provider-qualified name (%s)", target, strings.Join(choices, ", ")) + } +} + +func (h *Handler) filterModelsByProvider(cc CommandContext, items []models.GetResponse, providerName string) []models.GetResponse { + providerName = strings.TrimSpace(providerName) + if providerName == "" { + return items + } + filtered := make([]models.GetResponse, 0, len(items)) + for _, item := range items { + if strings.EqualFold(h.resolveProviderName(cc, item.ProviderID), providerName) { + filtered = append(filtered, item) + } + } + return filtered +} + +func modelMarkers(modelID string, settingsResp settings.Settings) []string { + var markers []string + if modelID == settingsResp.ChatModelID { + markers = append(markers, "chat") + } + if modelID == settingsResp.HeartbeatModelID { + markers = append(markers, "heartbeat") + } + return markers +} + +func modelSortRank(model models.GetResponse, settingsResp settings.Settings) int { + switch len(modelMarkers(model.ID, settingsResp)) { + case 2: + return 0 + case 1: + return 1 + default: + return 2 + } +} diff --git a/internal/command/parser_test.go b/internal/command/parser_test.go index 3956c1b3..5a16ea06 100644 --- a/internal/command/parser_test.go +++ b/internal/command/parser_test.go @@ -13,6 +13,8 @@ func TestParse_Basic(t *testing.T) { args []string }{ {"/help", "help", "", nil}, + {"/help model", "help", "model", nil}, + {"/help model set", "help", "model", []string{"set"}}, {"/subagent list", "subagent", "list", nil}, {"/subagent get mybot", "subagent", "get", []string{"mybot"}}, {"/schedule create daily \"0 9 * * *\" Send report", "schedule", "create", []string{"daily", "0 9 * * *", "Send", "report"}}, diff --git a/internal/command/registry.go b/internal/command/registry.go index 80de168a..d6d10d53 100644 --- a/internal/command/registry.go +++ b/internal/command/registry.go @@ -8,10 +8,18 @@ import ( // CommandContext carries execution context for a sub-command. type CommandContext struct { - Ctx context.Context - BotID string - Role string // "owner", "admin", "member", or "" (guest) - Args []string + Ctx context.Context + BotID string + Role string // "owner", "admin", "member", or "" (guest) + Args []string + ChannelIdentityID string + UserID string + ChannelType string + ConversationType string + ConversationID string + ThreadID string + RouteID string + SessionID string } // SubCommand describes a single sub-command within a resource group. @@ -50,15 +58,38 @@ func (g *CommandGroup) Usage() string { fmt.Fprintf(&b, "/%s - %s\n", g.Name, g.Description) for _, name := range g.order { sub := g.commands[name] - perm := "" + desc := subSummary(sub) if sub.IsWrite { - perm = " [owner]" + desc += " [owner]" } - fmt.Fprintf(&b, "- %s%s\n", sub.Usage, perm) + fmt.Fprintf(&b, "- %s\n", desc) } + fmt.Fprintf(&b, "\nUse /help %s for details.", g.Name) return b.String() } +func (g *CommandGroup) ActionHelp(action string) string { + sub, ok := g.commands[action] + if !ok { + return fmt.Sprintf("Unknown action %q for /%s.\n\n%s", action, g.Name, g.Usage()) + } + usage, summary := splitUsage(sub.Usage) + var b strings.Builder + fmt.Fprintf(&b, "/%s %s\n", g.Name, sub.Name) + if summary != "" { + fmt.Fprintf(&b, "- Summary: %s\n", summary) + } + if usage == "" { + usage = sub.Name + } + fmt.Fprintf(&b, "- Usage: /%s %s\n", g.Name, usage) + if sub.IsWrite { + b.WriteString("- Access: owner only\n") + } + fmt.Fprintf(&b, "- Tip: use /help %s to view sibling actions.", g.Name) + return strings.TrimRight(b.String(), "\n") +} + // Registry holds all registered command groups. type Registry struct { groups map[string]*CommandGroup @@ -83,12 +114,47 @@ func (r *Registry) GlobalHelp() string { b.WriteString("/help - Show this help message\n") b.WriteString("/new - Start a new conversation (resets session context)\n") b.WriteString("/stop - Stop the current generation\n\n") - for i, name := range r.order { - if i > 0 { - b.WriteByte('\n') - } + for _, name := range r.order { group := r.groups[name] - b.WriteString(group.Usage()) + fmt.Fprintf(&b, "- /%s - %s\n", group.Name, group.Description) } + b.WriteString("\nUse /help to view actions, e.g. /help model") return strings.TrimRight(b.String(), "\n") } + +func (r *Registry) GroupHelp(name string) string { + group, ok := r.groups[name] + if !ok { + return fmt.Sprintf("Unknown command group: /%s\n\n%s", name, r.GlobalHelp()) + } + return group.Usage() +} + +func (r *Registry) ActionHelp(groupName, action string) string { + group, ok := r.groups[groupName] + if !ok { + return fmt.Sprintf("Unknown command group: /%s\n\n%s", groupName, r.GlobalHelp()) + } + return group.ActionHelp(action) +} + +func splitUsage(usage string) (commandUsage string, summary string) { + usage = strings.TrimSpace(usage) + if usage == "" { + return "", "" + } + parts := strings.SplitN(usage, " - ", 2) + commandUsage = strings.TrimSpace(parts[0]) + if len(parts) > 1 { + summary = strings.TrimSpace(parts[1]) + } + return commandUsage, summary +} + +func subSummary(sub SubCommand) string { + usage, summary := splitUsage(sub.Usage) + if summary == "" { + return usage + } + return fmt.Sprintf("%s - %s", sub.Name, summary) +} diff --git a/internal/command/search.go b/internal/command/search.go index a18432c5..19ad41f3 100644 --- a/internal/command/search.go +++ b/internal/command/search.go @@ -13,6 +13,9 @@ func (h *Handler) buildSearchGroup() *CommandGroup { Name: "list", Usage: "list - List all search providers", Handler: func(cc CommandContext) (string, error) { + if h.searchProvService == nil { + return "Search provider service is not available.", nil + } items, err := h.searchProvService.List(cc.Ctx, "") if err != nil { return "", err @@ -20,14 +23,40 @@ func (h *Handler) buildSearchGroup() *CommandGroup { if len(items) == 0 { return "No search providers found.", nil } - records := make([][]kv, 0, len(items)) + settingsResp, _ := h.getBotSettings(cc) + currentRecords := make([][]kv, 0, 1) + otherRecords := make([][]kv, 0, len(items)) for _, item := range items { - records = append(records, []kv{ - {"Name", item.Name}, + label := item.Name + record := []kv{ + {"Name", label}, {"Provider", item.Provider}, - }) + } + if item.ID == settingsResp.SearchProviderID { + label += " [current]" + record[0].value = label + currentRecords = append(currentRecords, record) + continue + } + otherRecords = append(otherRecords, record) } - return formatItems(records), nil + currentRecords = append(currentRecords, otherRecords...) + records := currentRecords + return formatLimitedItems(records, defaultListLimit, "Use /search current to inspect the active provider."), nil + }, + }) + g.Register(SubCommand{ + Name: "current", + Usage: "current - Show the current search provider", + Handler: func(cc CommandContext) (string, error) { + if h.settingsService == nil { + return "Settings service is not available.", nil + } + settingsResp, err := h.getBotSettings(cc) + if err != nil { + return "", err + } + return formatKV([]kv{{"Search Provider", h.resolveSearchProviderName(cc, settingsResp.SearchProviderID)}}), nil }, }) g.Register(SubCommand{ @@ -38,7 +67,11 @@ func (h *Handler) buildSearchGroup() *CommandGroup { if len(cc.Args) < 1 { return "Usage: /search set ", nil } + if h.settingsService == nil { + return "Settings service is not available.", nil + } name := cc.Args[0] + before, _ := h.getBotSettings(cc) items, err := h.searchProvService.List(cc.Ctx, "") if err != nil { return "", err @@ -51,7 +84,7 @@ func (h *Handler) buildSearchGroup() *CommandGroup { if err != nil { return "", err } - return fmt.Sprintf("Search provider set to %q.", item.Name), nil + return formatChangedValue("Search provider", h.resolveSearchProviderName(cc, before.SearchProviderID), item.Name), nil } } return fmt.Sprintf("Search provider %q not found.", name), nil diff --git a/internal/command/settings.go b/internal/command/settings.go index b88d9be5..f8ad64d4 100644 --- a/internal/command/settings.go +++ b/internal/command/settings.go @@ -1,6 +1,7 @@ package command import ( + "errors" "fmt" "strconv" "strings" @@ -106,6 +107,13 @@ func settingsUpdateUsage() string { "- --heartbeat_model_id " } +func (h *Handler) getBotSettings(cc CommandContext) (settings.Settings, error) { + if h.settingsService == nil { + return settings.Settings{}, errors.New("settings service is not available") + } + return h.settingsService.GetBot(cc.Ctx, cc.BotID) +} + // resolveModelName resolves a model UUID to "model_name (provider_name)". func (h *Handler) resolveModelName(cc CommandContext, modelID string) string { if modelID == "" { diff --git a/internal/command/status.go b/internal/command/status.go index b68e1a84..1d19774c 100644 --- a/internal/command/status.go +++ b/internal/command/status.go @@ -6,7 +6,9 @@ import ( "strconv" "strings" + "github.com/google/uuid" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" ) func (h *Handler) buildStatusGroup() *CommandGroup { @@ -16,78 +18,113 @@ func (h *Handler) buildStatusGroup() *CommandGroup { Name: "show", Usage: "show - Show current session status", Handler: func(cc CommandContext) (string, error) { + if strings.TrimSpace(cc.SessionID) == "" { + return "No active session found for this conversation.", nil + } + return h.renderSessionStatus(cc, cc.SessionID, "current conversation") + }, + }) + g.Register(SubCommand{ + Name: "latest", + Usage: "latest - Show the latest session status for this bot", + Handler: func(cc CommandContext) (string, error) { + if h.queries == nil { + return "Session info is not available.", nil + } botUUID, err := parseBotUUID(cc.BotID) if err != nil { return "", err } - sessionID, err := h.queries.GetLatestSessionIDByBot(cc.Ctx, botUUID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - return "No active session found.", nil + return "No session found for this bot.", nil } return "", err } - - msgCount, err := h.queries.CountMessagesBySession(cc.Ctx, sessionID) - if err != nil { - return "", fmt.Errorf("count messages: %w", err) - } - - var usedTokens int64 - latestUsage, err := h.queries.GetLatestAssistantUsage(cc.Ctx, sessionID) - if err != nil && !errors.Is(err, pgx.ErrNoRows) { - return "", fmt.Errorf("get usage: %w", err) - } - if err == nil { - usedTokens = latestUsage - } - - cacheRow, err := h.queries.GetSessionCacheStats(cc.Ctx, sessionID) - if err != nil { - return "", fmt.Errorf("get cache: %w", err) - } - - var contextWindowStr string - if h.settingsService != nil { - s, sErr := h.settingsService.GetBot(cc.Ctx, cc.BotID) - if sErr == nil && s.ChatModelID != "" && h.modelsService != nil { - m, mErr := h.modelsService.GetByID(cc.Ctx, s.ChatModelID) - if mErr == nil && m.Config.ContextWindow != nil { - contextWindowStr = formatTokens(int64(*m.Config.ContextWindow)) - } - } - } - - var cacheHitRate float64 - if cacheRow.TotalInputTokens > 0 { - cacheHitRate = float64(cacheRow.CacheReadTokens) / float64(cacheRow.TotalInputTokens) * 100 - } - - skills, _ := h.queries.GetSessionUsedSkills(cc.Ctx, sessionID) - - var b strings.Builder - b.WriteString("Session Status:\n\n") - fmt.Fprintf(&b, "- Messages: %d\n", msgCount) - if contextWindowStr != "" { - fmt.Fprintf(&b, "- Context: %s / %s\n", formatTokens(usedTokens), contextWindowStr) - } else { - fmt.Fprintf(&b, "- Context: %s\n", formatTokens(usedTokens)) - } - fmt.Fprintf(&b, "- Cache Hit Rate: %.1f%%\n", cacheHitRate) - fmt.Fprintf(&b, "- Cache Read: %s\n", formatTokens(cacheRow.CacheReadTokens)) - fmt.Fprintf(&b, "- Cache Write: %s\n", formatTokens(cacheRow.CacheWriteTokens)) - - if len(skills) > 0 { - fmt.Fprintf(&b, "- Skills: %s\n", strings.Join(skills, ", ")) - } - - return strings.TrimRight(b.String(), "\n"), nil + return h.renderSessionStatus(cc, sessionID.String(), "latest bot session") }, }) return g } +func (h *Handler) renderSessionStatus(cc CommandContext, sessionID string, scope string) (string, error) { + if h.queries == nil { + return "Session info is not available.", nil + } + pgSessionID, err := parseCommandUUID(sessionID) + if err != nil { + return "", err + } + msgCount, err := h.queries.CountMessagesBySession(cc.Ctx, pgSessionID) + if err != nil { + return "", fmt.Errorf("count messages: %w", err) + } + + var usedTokens int64 + latestUsage, err := h.queries.GetLatestAssistantUsage(cc.Ctx, pgSessionID) + if err != nil && !errors.Is(err, pgx.ErrNoRows) { + return "", fmt.Errorf("get usage: %w", err) + } + if err == nil { + usedTokens = latestUsage + } + + cacheRow, err := h.queries.GetSessionCacheStats(cc.Ctx, pgSessionID) + if err != nil { + return "", fmt.Errorf("get cache: %w", err) + } + + var cacheHitRate float64 + if cacheRow.TotalInputTokens > 0 { + cacheHitRate = float64(cacheRow.CacheReadTokens) / float64(cacheRow.TotalInputTokens) * 100 + } + + skills, _ := h.queries.GetSessionUsedSkills(cc.Ctx, pgSessionID) + + contextUsage := formatTokens(usedTokens) + if contextWindow := h.resolveContextWindow(cc); contextWindow != "" { + contextUsage = contextUsage + " / " + contextWindow + } + + pairs := []kv{ + {"Scope", scope}, + {"Session ID", sessionID}, + {"Messages", strconv.FormatInt(msgCount, 10)}, + {"Context", contextUsage}, + {"Cache Hit Rate", fmt.Sprintf("%.1f%%", cacheHitRate)}, + {"Cache Read", formatTokens(cacheRow.CacheReadTokens)}, + {"Cache Write", formatTokens(cacheRow.CacheWriteTokens)}, + } + if len(skills) > 0 { + pairs = append(pairs, kv{"Skills", strings.Join(skills, ", ")}) + } + return formatKV(pairs), nil +} + +func (h *Handler) resolveContextWindow(cc CommandContext) string { + if h.settingsService == nil || h.modelsService == nil { + return "" + } + s, err := h.settingsService.GetBot(cc.Ctx, cc.BotID) + if err != nil || s.ChatModelID == "" { + return "" + } + m, err := h.modelsService.GetByID(cc.Ctx, s.ChatModelID) + if err != nil || m.Config.ContextWindow == nil { + return "" + } + return formatTokens(int64(*m.Config.ContextWindow)) +} + +func parseCommandUUID(id string) (pgtype.UUID, error) { + parsed, err := uuid.Parse(strings.TrimSpace(id)) + if err != nil { + return pgtype.UUID{}, fmt.Errorf("invalid uuid: %w", err) + } + return pgtype.UUID{Bytes: parsed, Valid: true}, nil +} + func formatTokens(n int64) string { if n >= 1_000_000 { return fmt.Sprintf("%.1fM", float64(n)/1_000_000) diff --git a/internal/command/usage.go b/internal/command/usage.go index a08792a8..4ed5b40a 100644 --- a/internal/command/usage.go +++ b/internal/command/usage.go @@ -18,6 +18,9 @@ func (h *Handler) buildUsageGroup() *CommandGroup { Name: "summary", Usage: "summary - Token usage summary (last 7 days)", Handler: func(cc CommandContext) (string, error) { + if h.queries == nil { + return "Usage info is not available.", nil + } botUUID, err := parseBotUUID(cc.BotID) if err != nil { return "", err @@ -89,6 +92,9 @@ func (h *Handler) buildUsageGroup() *CommandGroup { Name: "by-model", Usage: "by-model - Token usage grouped by model", Handler: func(cc CommandContext) (string, error) { + if h.queries == nil { + return "Usage info is not available.", nil + } botUUID, err := parseBotUUID(cc.BotID) if err != nil { return "", err